test: compression and cache limits
meakbiyik committed Sep 1, 2023
1 parent bc17c1a commit 21662e1
Showing 2 changed files with 118 additions and 6 deletions.
101 changes: 101 additions & 0 deletions tests/test_torchcache.py
@@ -46,6 +46,28 @@ class CachedModule(SimpleModule):
    output_cached = model(input_tensor)
    assert torch.equal(output, output_cached)

    # Third time is the charm, but let's use a bigger batch size
    input_tensor = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=torch.float32)
    output_cached = model(input_tensor)
    assert torch.equal(output, output_cached[:2])

    # Argument checks
    with pytest.raises(ValueError):

        @torchcache(persistent=True, zstd_compression=True, use_mmap_on_load=True)
        class CachedModule(SimpleModule):
            pass

        CachedModule()

    with pytest.raises(ValueError):

        @torchcache(persistent=False, zstd_compression=True)
        class CachedModule(SimpleModule):
            pass

        CachedModule()


# Test caching mechanism with persistent storage.
def test_persistent_caching(tmp_path):
@@ -121,6 +143,85 @@ def test_hashing():
    assert hashes.shape[0] == input_tensor.shape[0]


def test_compression(tmp_path):
    @torchcache(persistent=True, persistent_cache_dir=tmp_path, zstd_compression=True)
    class CachedModule(SimpleModule):
        pass

    model = CachedModule()
    input_tensor = torch.tensor([[1, 2, 3], [4, 5, 6]], dtype=torch.float32)

    # First pass, caching should occur and save to file
    output = model(input_tensor)
    assert torch.equal(output, input_tensor * 2)

    # Check if cache files were created
    assert len(list((tmp_path / model.cache_instance.module_hash).iterdir())) == 2

    # Second pass, should retrieve from the in-memory cache
    output_cached = model(input_tensor)
    assert torch.equal(output, output_cached)

    # Now create a new instance of the model and check that the cache is loaded from disk.
    # We re-define the class to flush the in-memory cache.
    @torchcache(persistent=True, persistent_cache_dir=tmp_path, zstd_compression=True)
    class CachedModule(SimpleModule):
        pass

    model2 = CachedModule()
    original_load_from_file = model2.cache_instance._load_from_file
    model2.cache_instance.original_load_from_file = original_load_from_file
    load_from_file_called = False

    def _load_from_file(*args, **kwargs):
        nonlocal load_from_file_called
        load_from_file_called = True
        return original_load_from_file(*args, **kwargs)

    model2.cache_instance._load_from_file = _load_from_file
    output_cached = model2(input_tensor)
    assert torch.equal(output, output_cached)
    assert load_from_file_called


# Test cache size limits
def test_cache_size(tmp_path):
    # Overhead of saving a tensor to disk is around 700 bytes
    @torchcache(
        persistent=True,
        persistent_cache_dir=tmp_path,
        max_persistent_cache_size=1500,
        max_memory_cache_size=20,
    )
    class CachedModule(SimpleModule):
        pass

    model = CachedModule()
    input_tensor1 = torch.tensor([[1, 2, 3], [4, 5, 6]], dtype=torch.float32)
    input_tensor2 = torch.tensor([[7, 8, 9], [10, 11, 12]], dtype=torch.float32)

    # First pass, caching should occur and save to file
    output = model(input_tensor1)
    assert torch.equal(output, input_tensor1 * 2)

    # Check if cache files were created
    assert len(list((tmp_path / model.cache_instance.module_hash).iterdir())) == 2

    # Check that the persistent flag is not set, but the memory flag is
    assert not model.cache_instance.is_persistent_cache_full
    assert model.cache_instance.is_memory_cache_full

    # Now pass a tensor that does not fit in the remaining cache budget
    output = model(input_tensor2)
    assert torch.equal(output, input_tensor2 * 2)

    # Check that no new cache files were created
    assert len(list((tmp_path / model.cache_instance.module_hash).iterdir())) == 2

    # Check that the persistent flag is now set
    assert model.cache_instance.is_persistent_cache_full


# Test for mixed cache hits
def test_mixed_cache_hits():
    @torchcache(persistent=False)
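A note on the numbers in test_cache_size: each cache entry is written with torch.save, whose zip-based container adds a few hundred bytes of overhead even for a tiny tensor, which is where the "around 700 bytes" comment and the 1500-byte max_persistent_cache_size come from — the first two-row batch fits, the second is rejected. A minimal standalone sketch for measuring that overhead (not part of the commit; the row values and dtype are chosen only for illustration, and the exact byte count varies across torch versions):

import io

import torch

# Serialize one small tensor the way a persistent cache entry would be written
# (torch.save); for tiny tensors the zip container dominates the on-disk size.
row = torch.tensor([2.0, 4.0, 6.0])  # illustrative cached output row
buffer = io.BytesIO()
torch.save(row, buffer)
entry_bytes = len(buffer.getvalue())

print(f"one serialized entry: {entry_bytes} bytes")  # typically several hundred bytes
print(f"entries fitting a 1500-byte budget: {1500 // entry_bytes}")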
23 changes: 17 additions & 6 deletions torchcache/torchcache.py
@@ -152,6 +152,14 @@ def __init__(
        improve the performance for large files, but it cannot be used together
        with compression.
        """
        if not persistent and zstd_compression:
            raise ValueError("Cannot use zstd compression without persistent cache")

        if zstd_compression and use_mmap_on_load:
            raise ValueError(
                "Cannot use zstd compression and mmap on load at the same time"
            )

        # Rolling powers of the hash base, up until 2**15 to fit in float16
        roll_powers = torch.arange(0, subsample_count * 2) % 15
        self.subsample_count = subsample_count
@@ -171,11 +179,6 @@ def __init__(
        self.is_memory_cache_full = False
        self.cache_dtype = cache_dtype

        if self.zstd_compression and self.use_mmap_on_load:
            raise ValueError(
                "Cannot use zstd compression and mmap on load at the same time"
            )

        # We allow explicit overloading of mmap option despite version
        # check so that people can use it with nightly versions
        torch_version = torch.__version__.split(".")
@@ -598,7 +601,15 @@ def _load_from_file(self, hash_val: int) -> Union[Tensor, None]:
        else:
            if self.use_mmap_on_load:
                load_kwargs["mmap"] = True
            embedding = torch.load(str(file_path), **load_kwargs)
            try:
                embedding = torch.load(str(file_path), **load_kwargs)
            except Exception as e:
                logger.error(
                    f"Could not read file {file_path}, skipping loading from file. "
                    f"Error: {e}\nRemoving the file to avoid future errors."
                )
                file_path.unlink(missing_ok=True)
                return None

        logger.debug("Caching to memory before returning")
        self._cache_to_memory(embedding, hash_val)
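To see the shape of the load path above outside the diff: mmap is only applied on the uncompressed branch, and a torch.load failure now deletes the offending file and returns None so the entry is simply recomputed. A rough, self-contained sketch of a compressed-save / guarded-load pair in that spirit (hypothetical save_entry/load_entry helpers, assuming the zstandard package and a torch version that supports the mmap keyword of torch.load; a simplified illustration, not torchcache's exact implementation):

import io
import logging
from pathlib import Path
from typing import Optional

import torch
import zstandard

logger = logging.getLogger(__name__)


def save_entry(embedding: torch.Tensor, file_path: Path, zstd_compression: bool) -> None:
    """Serialize a tensor with torch.save, optionally wrapping it in a zstd frame."""
    buffer = io.BytesIO()
    torch.save(embedding, buffer)
    payload = buffer.getvalue()
    if zstd_compression:
        payload = zstandard.ZstdCompressor().compress(payload)
    file_path.write_bytes(payload)


def load_entry(
    file_path: Path, zstd_compression: bool, use_mmap_on_load: bool
) -> Optional[torch.Tensor]:
    """Load a cached tensor; delete the file and return None if it is unreadable."""
    try:
        if zstd_compression:
            raw = zstandard.ZstdDecompressor().decompress(file_path.read_bytes())
            return torch.load(io.BytesIO(raw))
        # mmap only makes sense for plain, uncompressed torch.save files.
        return torch.load(str(file_path), mmap=use_mmap_on_load)
    except Exception as exc:  # corrupted or truncated cache entry
        logger.error("Could not read %s (%s); removing it.", file_path, exc)
        file_path.unlink(missing_ok=True)
        return None

A caller would try load_entry first and only recompute and save_entry on a miss, which mirrors the cache flow exercised by the tests above.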
