Effortlessly cache PyTorch module outputs on-the-fly with torchcache.
torchcache is particularly useful for caching and serving the outputs of computationally expensive, large pre-trained PyTorch modules, such as vision transformers. Note that gradients will not flow through the cached outputs.
- Cache PyTorch module outputs either in-memory or persistently to disk.
- Simple decorator-based interface for easy usage.
- Uses an MRU (most-recently-used) cache to limit memory/disk usage.
```bash
pip install torchcache
```
Quickly cache the output of your PyTorch module with a single decorator:
```python
import torch
from torch import nn

from torchcache import torchcache

@torchcache()
class MyModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(10, 10)

    def forward(self, x):
        # This output will be cached
        return self.linear(x)

model = MyModule()
input_tensor = torch.ones(10, dtype=torch.float32)

# Output is cached during the first call...
output = model(input_tensor)

# ...and is retrieved from the cache for the next one
output_cached = model(input_tensor)
```
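As noted above, gradients will not flow through cached outputs. A minimal illustration, reusing the `model` defined in the quickstart (the comments restate the documented behavior rather than asserting autograd internals):

```python
x = torch.ones(10, dtype=torch.float32, requires_grad=True)

y_first = model(x)   # computed by the module and stored in the cache
y_second = model(x)  # served from the cache; gradients will not flow through this output
```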
See documentation at torchcache.readthedocs.io for more examples.
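For the disk-backed (persistent) caching mentioned in the feature list, a hedged configuration sketch is shown below. The keyword arguments (`persistent`, `persistent_cache_dir`) are assumptions used only for illustration, so consult the documentation for the exact parameter names and defaults.

```python
import tempfile

from torch import nn
from torchcache import torchcache

# Hedged sketch: the keyword arguments below are assumptions, not a confirmed API;
# they illustrate the in-memory vs. persistent (disk) caching choice from the feature list.
@torchcache(persistent=True, persistent_cache_dir=tempfile.gettempdir())
class ExpensiveEncoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(768, 768)

    def forward(self, x):
        # Cached to disk, so cached results can be reused across separate runs.
        return self.linear(x)
```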
To ensure seamless operation, torchcache assumes the following:

- Your module is a subclass of `nn.Module`.
- The module's forward method accepts any number of positional arguments with shapes `(B, *)`, where `B` is the batch size and `*` represents any number of dimensions. All tensors should be on the same device and have the same dtype.
- The forward method returns a single tensor of shape `(B, *)`.
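A minimal sketch of a module that satisfies these assumptions (the class name and shapes are hypothetical, chosen only for illustration): it accepts two positional tensors of shape `(B, *)` on the same device and dtype, and returns a single tensor of shape `(B, *)`.

```python
import torch
from torch import nn

from torchcache import torchcache

@torchcache()
class PairEncoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(16, 8)

    def forward(self, a, b):
        # a and b both have shape (B, 16): same device, same dtype
        return self.proj(a) + self.proj(b)  # single output of shape (B, 8)

model = PairEncoder()
a = torch.ones(4, 16)
b = torch.zeros(4, 16)
out = model(a, b)  # shape (4, 8); an identical second call is served from the cache
```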
- Ensure you have Python installed.
- Install `poetry`.
- Run `poetry install` to set up dependencies.
- Run `poetry run pre-commit install` to install pre-commit hooks.
- Create a branch, make your changes, and open a pull request.