Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move StateDict to Haliax, clean it up a lot. #90

Open
wants to merge 18 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 13 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions docs/state_dict.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# Serialization
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(The main "draft-y" part of this PR is just the docs.


Haliax supports serialization of modules (including any [equinox.Module][]) to and from PyTorch-compatible
state dicts using the [safetensors](https://github.com/huggingface/safetensors) library. For details on
how state dicts work in PyTorch, see the [PyTorch documentation](https://pytorch.org/docs/stable/notes/serialization.html).

A state dict is a Python dictionary that maps string keys to tensors. It is used to store the parameters
of a model (though typically not the model's structure or hyperparameters). The keys are typically the names of the
model's parameters, arranged as `.`-separated paths. For example, a model with a `conv1` layer might have a
state dict with keys like `conv1.weight` and `conv1.bias`. Sequences of modules (e.g., for lists of layers) are
serialize with keys like `layer.0.weight`, `layer.1.weight`, etc.


## Saving a state dict

To serialize a module to a Pytorch-compatible state dict, use the [haliax.state_dict.to_torch_compatible_state_dict][]
function. This function takes a module and returns a state dict. To save the state dict to a file, use the
[haliax.state_dict.save_state_dict][] function, which writes the state dict to a file in safetensor format.
`to_torch_compatible_state_dict` flattens [haliax.nn.Linear] module Input and Output axis specs to a format that
is compatible with PyTorch Linear modules (though `out_first=True` is necessary to match PyTorch's default).
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ dependencies = [
"equinox>=0.10.6",
"jaxtyping>=0.2.20",
"jmp>=0.0.4",
"safetensors>=0.4.3"
]
dynamic =[ "version" ]

Expand Down Expand Up @@ -66,3 +67,4 @@ src_paths = ["src", "tests"]
[project.urls]
"Homepage" = "https://github.com/stanford-crfm/haliax"
"Bug Tracker" = "https://github.com/stanford-crfm/haliax/issues/"
"Documentation" = "https://haliax.readthedocs.io/en/latest/"
Loading
Loading