diff --git a/docs/state-dict.md b/docs/state-dict.md
new file mode 100644
index 0000000..7cdc54e
--- /dev/null
+++ b/docs/state-dict.md
@@ -0,0 +1,160 @@
+# Serialization
+
+Haliax supports serialization of modules (including any [equinox.Module][]) to and from PyTorch-compatible
+state dicts using the [safetensors](https://github.com/huggingface/safetensors) library. For details on
+how state dicts work in PyTorch, see the [PyTorch documentation](https://pytorch.org/docs/stable/notes/serialization.html#saving-and-loading-torch-nn-modules).
+
+A state dict is a Python dictionary that maps string keys to tensors. It is used to store the parameters
+of a model (though typically not the model's structure or hyperparameters). The keys are typically the names of
+the model's parameters, arranged as `.`-separated paths. For example, a model with a `conv1` layer might have a
+state dict with keys like `conv1.weight` and `conv1.bias`. Sequences of modules (e.g., for lists of layers) are
+serialized with keys like `layer.0.weight`, `layer.1.weight`, etc.
+
+Haliax uses the [safetensors](https://github.com/huggingface/safetensors) library to serialize state dicts.
+Safetensors is a safer, more portable format developed by Hugging Face: serializing a native PyTorch state dict
+requires PyTorch itself, a dependency we want to avoid, and PyTorch's native format is built on pickle, which is
+generally not safe to deserialize from untrusted sources.
+
+This does mean that you can't load a Haliax-saved state dict with `torch.load`, but safetensors is lightweight
+and easy to use, and Hugging Face supports it natively in their libraries.
+
+## Saving a State Dict
+
+To serialize a module to a PyTorch-compatible state dict, use the [haliax.state_dict.to_torch_compatible_state_dict][]
+function. This function takes a module and returns a state dict. To save the state dict to a file, use the
+[haliax.state_dict.save_state_dict][] function, which writes the state dict to a file in safetensors format.
+`to_torch_compatible_state_dict` flattens [haliax.nn.Linear][] module input and output axis specs to a format that
+is compatible with PyTorch Linear modules (note that `out_first=True`, the default, is needed to match the layout
+of PyTorch's Linear module).
+
+```python
+import haliax
+import jax.random as jrandom
+
+# Create a module
+Heads = haliax.Axis("Heads", 8)
+Dim = haliax.Axis("Dim", 16)
+Out = haliax.Axis("Out", 5)
+module = haliax.nn.Linear.init(In=(Heads, Dim), Out=Out, key=jrandom.PRNGKey(0))
+
+# Serialize the module to a state dict
+state_dict = haliax.state_dict.to_torch_compatible_state_dict(module)
+
+# Save the state dict to a file
+haliax.state_dict.save_state_dict(state_dict, 'state_dict.safetensors')
+```
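+
+The resulting keys are PyTorch-style dotted paths, and the Linear weight is flattened to 2-d. A quick
+sanity check (a sketch; the printed shapes assume the example above, with the default `out_first=True`):
+
+```python
+# The In axes (Heads=8, Dim=16) are flattened into a single input dim of 8 * 16 = 128
+for key, value in state_dict.items():
+    print(key, value.shape)
+# weight (5, 128)
+# bias (5,)
+```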
+
+The state dict is saved in the [safetensors](https://github.com/huggingface/safetensors) format, so you can
+load it into a PyTorch model using safetensors directly:
+
+```python
+import torch
+from safetensors.torch import load_model
+
+# in_features = 8 * 16 = 128, matching the flattened In axes of the example above
+model = torch.nn.Linear(128, 5)
+
+# load_model loads the weights into the model in place
+load_model(model, 'state_dict.safetensors')
+```
+
+## Loading a State Dict
+
+Similarly, you can load a state dict from a file using the [haliax.state_dict.load_state_dict][] function. This
+function reads a state dict from a file in safetensors format and returns a dictionary. To load the state dict
+into a module, use the [haliax.state_dict.from_torch_compatible_state_dict][] function.
+
+```python
+import haliax as hax
+import jax.random as jrandom
+
+# Create a module
+Heads = hax.Axis("Heads", 8)
+Dim = hax.Axis("Dim", 16)
+Out = hax.Axis("Out", 5)
+module = hax.nn.Linear.init(In=(Heads, Dim), Out=Out, key=jrandom.PRNGKey(0))
+
+# Load the state dict from a file
+state_dict = hax.state_dict.load_state_dict('state_dict.safetensors')
+
+# this will unflatten the state dict and load it into the module
+module = hax.state_dict.from_torch_compatible_state_dict(module, state_dict)
+```
+
+The `from_torch_compatible_state_dict` function unflattens the state dict and loads it into the module. Note
+that the module must have the same structure as the module that was serialized to the state dict. If the module
+structure has changed, you may need to manually update the state dict keys to match the new structure.
+
+
+## Customizing Serialization
+
+### Changing the State Dict Key Names
+
+If you want to use different names in the serialized state dict (e.g. because your field names differ from
+those of a Hugging Face implementation), you can extend your class from
+[haliax.state_dict.ModuleWithStateDictSerialization][] and use `_state_dict_key_map` to rename keys. For
+instance, the `Gpt2Transformer` class in Levanter has this method:
+
+```python
+from typing import Optional
+from haliax.state_dict import ModuleWithStateDictSerialization
+
+class Gpt2Transformer(ModuleWithStateDictSerialization):
+    ...
+
+    def _state_dict_key_map(self) -> dict[str, Optional[str]]:
+        return {"blocks": "h"}
+```
+
+This says that the field called `blocks` in this class should be (de)serialized as `h`, because the Hugging
+Face GPT-2 implementation uses the (rather opaque) name `h`. You can also "flatten" the submodules of a field
+into its parent by mapping the field to `None`, as in the sketch below.
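+
+For example, a hypothetical wrapper module whose `transformer` field should not contribute a `transformer.`
+segment to the keys might look like this (a sketch; `Gpt2Transformer` as above):
+
+```python
+class Gpt2LMHeadModel(ModuleWithStateDictSerialization):
+    transformer: Gpt2Transformer
+
+    def _state_dict_key_map(self) -> dict[str, Optional[str]]:
+        # keys from `transformer` are spliced directly into this module's namespace,
+        # e.g. "h.0.attn.weight" rather than "transformer.h.0.attn.weight"
+        return {"transformer": None}
+```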
+
+### Custom Serialization Logic
+
+If your module needs fancier logic, you'll need to extend your class from `ModuleWithStateDictSerialization`
+and override the default methods `to_state_dict()` and `from_state_dict()`. Both operate on a
+[haliax.state_dict.StateDict][], a plain `dict` mapping keys to tensors. As of June 2024, we almost never need
+this in Levanter.
+
+For implementation, there are a few helper methods from `haliax.state_dict` that you can use:
+- To join a prefix to the keys of a Hugging Face state dict, you can use the helper function `with_prefix()`.
+  The prefix comes from the names of attributes defined at the beginning of your model class.
+
+For example, below is the implementation of `to_state_dict()` in [levanter.models.backpack.BackpackLMHeadModel][].
+In this class, we want to preserve HF compatibility by saving untied output embeddings. (We chose not to implement
+non-weight-tied embeddings.)
+
+```python
+from typing import Optional
+
+from haliax.state_dict import ModuleWithStateDictSerialization, StateDict, with_prefix
+
+
+class BackpackLMHeadModel(ModuleWithStateDictSerialization):
+    ...
+
+    def to_state_dict(self, prefix: Optional[str] = None) -> StateDict:
+        state_dict = super().to_state_dict(prefix=prefix)
+        # In levanter's implementation, we have a shared embedding matrix for both the word
+        # embeddings and the sense embeddings
+        state_dict[with_prefix(prefix, "backpack.word_embeddings.weight")] = state_dict[
+            with_prefix(prefix, "backpack.gpt2_model.wte.weight")
+        ]
+        state_dict[with_prefix(prefix, "backpack.position_embeddings.weight")] = state_dict[
+            with_prefix(prefix, "backpack.gpt2_model.wpe.weight")
+        ]
+        return state_dict
+```
+
+Similarly, to load weights from the state dict, you might need to implement `from_state_dict`. This method
+takes in a state dict and returns the module with updated weights. You can use the `with_prefix()` helper to
+join the prefix to the keys of the state dict.
+
+```python
+    def from_state_dict(self, state_dict: StateDict, prefix: Optional[str] = None) -> "BackpackLMHeadModel":
+        ...
+```
+
+## API Reference
+
+::: haliax.state_dict
diff --git a/mkdocs.yml b/mkdocs.yml
index d90c971..3e194a9 100644
--- a/mkdocs.yml
+++ b/mkdocs.yml
@@ -88,5 +88,6 @@ nav:
     - Partitioning: 'partitioning.md'
     - Higher Order Functions: 'hof.md'
     - FP8: 'fp8.md'
+    - Serialization: 'state-dict.md'
   - API Reference: 'api.md'
   - FAQ: 'faq.md'
diff --git a/pyproject.toml b/pyproject.toml
index 4b1d0f7..1d40bf8 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -26,6 +26,7 @@ dependencies = [
     "equinox>=0.10.6",
     "jaxtyping>=0.2.20",
     "jmp>=0.0.4",
+    "safetensors>=0.4.3"
 ]
 
 dynamic =[ "version" ]
@@ -68,3 +69,4 @@ src_paths = ["src", "tests"]
 [project.urls]
 "Homepage" = "https://github.com/stanford-crfm/haliax"
 "Bug Tracker" = "https://github.com/stanford-crfm/haliax/issues/"
+"Documentation" = "https://haliax.readthedocs.io/en/latest/"
diff --git a/src/haliax/__init__.py b/src/haliax/__init__.py
index 7aea845..56bd65c 100644
--- a/src/haliax/__init__.py
+++ b/src/haliax/__init__.py
@@ -14,6 +14,7 @@
 import haliax.nn as nn
 import haliax.quantization as quantization
 import haliax.random as random
+import haliax.state_dict as state_dict
 import haliax.tree_util as tree_util
 import haliax.util as util
 
@@ -889,6 +890,7 @@ def true_divide(x1: NamedOrNumeric, x2: NamedOrNumeric, /) -> NamedOrNumeric:
     "random",
     "tree_util",
     "nn",
+    "state_dict",
     "Axis",
     "AxisSpec",
     "AxisSelection",
diff --git a/src/haliax/_src/einsum.py b/src/haliax/_src/einsum.py
index 6a65f44..5d08358 100644
--- a/src/haliax/_src/einsum.py
+++ b/src/haliax/_src/einsum.py
@@ -9,6 +9,7 @@
 from ..axis import Axis, AxisSelector, axis_name, eliminate_axes, rearrange_for_partial_order, union_axes
 from ..core import NamedArray
 from ..jax_utils import _jittable_dg_einsum
+from ..quantization import DotGeneralOp
 from ..types import DTypeLike, PrecisionLike
 from ..util import ensure_tuple
 from .parsing import AliasTable, parse_einsum, raise_parse_error
@@ -19,8 +20,8 @@ def einsum(
     *arrays: NamedArray,
     precision: PrecisionLike = None,
     preferred_element_type: Optional[DTypeLike] = None,
-    _dot_general=jax.lax.dot_general,
-    **axis_aliases,
+    _dot_general: DotGeneralOp = jax.lax.dot_general,
+    **axis_aliases: AxisSelector,
 ) -> NamedArray:
     """Compute the tensor contraction of the input arrays according to Haliax's named variant of the Einstein
     summation convention.
diff --git a/src/haliax/_src/state_dict.py b/src/haliax/_src/state_dict.py new file mode 100644 index 0000000..205c08c --- /dev/null +++ b/src/haliax/_src/state_dict.py @@ -0,0 +1,495 @@ +# Module to support torch-style "state dict" serialization via safetensors +import dataclasses +import re +import typing +from typing import Any, Optional, Sequence, TypeVar, cast + +import equinox as eqx +import jax +import jax.numpy as jnp +import numpy as np +from jax import ShapeDtypeStruct +from jax.experimental.multihost_utils import sync_global_devices +from jax.sharding import Mesh, NamedSharding, PartitionSpec +from jax.tree_util import DictKey, FlattenedIndexKey, GetAttrKey, SequenceKey +from jaxtyping import PyTree + +import haliax.partitioning as partitioning +from haliax._src.util import index_where +from haliax.core import NamedArray, named +from haliax.jax_utils import is_jax_array_like + + +try: + import safetensors +except ImportError: + safetensors = None + + +StateDict = dict[str, Any] +Mod = TypeVar("Mod", bound=eqx.Module) +T = TypeVar("T") + + +def from_torch_compatible_state_dict( + t: T, state_dict: StateDict, *, unflatten_linear: bool = True, prefix: Optional[str] = None +) -> T: + """ + Convert a state dict to a tree that is compatible with the structure of `t`. + + This applies [haliax.state_dict.from_state_dict][] followed by [haliax.state_dict.unflatten_linear_layers][]. + """ + if unflatten_linear: + t = _flatten_to_unflatten(t, state_dict, prefix) + else: + t = from_state_dict(t, state_dict, prefix=prefix) + + return t + + +def _flatten_to_unflatten(t, state_dict, prefix): + """ + Flatten the torch compatible state_dict before loading into t, and then recover the unflattened layers. + """ + # typically, `t` is a bunch of ShapeDtypeStructs, which can't be transposed etc. so we instead have to zeros() + # into real arrays (that aren't actually real b/c this is inside a jit) + def _dt_struct_to_array(struct): + if not isinstance(struct, ShapeDtypeStruct): + return struct + return jnp.zeros(struct.shape, struct.dtype) + + t = jax.tree.map(_dt_struct_to_array, t) + flat_t = flatten_linear_layers(t) + flat_t = from_state_dict(flat_t, state_dict, prefix=prefix) + t = unflatten_linear_layers(t, flat_t) + return t + + +@typing.overload +def with_prefix(prefix: str | None, leaf: str) -> str: + ... + + +@typing.overload +def with_prefix(prefix: str, leaf: None) -> str: + ... + + +@typing.overload +def with_prefix(prefix: Optional[str], leaf: Optional[str]) -> Optional[str]: + ... 
+ + +def with_prefix(prefix: Optional[str], leaf: Optional[str]) -> Optional[str]: + """Joins two optional path strings in a way compatible with pytorch state dict serialization""" + if prefix is None: + return leaf + elif leaf is None: + return prefix + else: + return f"{prefix}.{leaf}" + + +class ModuleWithStateDictSerialization(eqx.Module): + """An eqx.Module that can be serialized to a torch-style state dict.""" + + def to_state_dict(self, prefix: Optional[str] = None) -> StateDict: + return default_eqx_module_to_state_dict(self, prefix) + + def from_state_dict(self: Mod, state_dict: StateDict, prefix: Optional[str] = None) -> Mod: + return default_eqx_module_from_state_dict(self, state_dict, prefix) + + def _state_dict_key_map(self) -> dict[str, Optional[str]]: + """Returns a dict mapping eqx.Module keys to torch keys that need to be renamed for serialization""" + return {} + + +def from_state_dict(tree: T, state_dict: StateDict, prefix: Optional[str] = None) -> T: + """ + Given a (template) tree and a state dict, return a new tree with the same structure as the input tree, but with + the values from the state dict. + + Args: + tree: The template tree + state_dict: The state dict + prefix: The prefix to use when looking up keys in the state dict + + Returns: + A new tree with the same structure as the input tree, but with the values from the state dict. + + """ + # TODO: assert compatibility of old and new values (type, shape, etc.) + if isinstance(tree, eqx.Module): + if hasattr(tree, "from_state_dict"): + return tree.from_state_dict(state_dict, prefix) + else: + return default_eqx_module_from_state_dict(tree, state_dict, prefix) + elif isinstance(tree, list): + return [from_state_dict(item, state_dict, with_prefix(prefix, str(i))) for i, item in enumerate(tree)] # type: ignore + elif isinstance(tree, dict): + return {k: from_state_dict(v, state_dict, prefix=with_prefix(prefix, k)) for k, v in tree.items()} # type: ignore + elif isinstance(tree, NamedArray): + if prefix is None: + raise ValueError("Cannot extract a leaf value from a torch dict without a prefix") + + array = state_dict[prefix] + + if isinstance(array, np.ndarray): + mesh = partitioning._get_mesh() + # TODO: modernize this + if mesh.devices.size > 1: # this happens with the default mesh + pspec = partitioning.pspec_for_axis(tree.axes) + sharding = jax.sharding.NamedSharding(mesh, pspec) + array = jax.make_array_from_callback(tree.array.shape, sharding, lambda indices: array[indices]) + else: + array = jnp.array(array) + array = named(array, tree.axes) + else: + array = named(array, tree.axes) + array = partitioning.auto_sharded(array) + + return array + elif is_jax_array_like(tree): + if prefix is None: + raise ValueError("Cannot extract a leaf value from a state dict without a prefix") + # TODO: add "strict" flag so we can return None in cases where it's just missing + return jnp.array(state_dict[prefix]) + else: + if prefix is None: + return tree + return state_dict.get(prefix, tree) + + +def to_state_dict(tree: PyTree, prefix: Optional[str] = None) -> StateDict: + """ + Convert a PyTree to a state dict. + + Returns: + The state dict representation of the input tree. 
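+
+    Example (a sketch; a `Linear` module stores its non-static `weight` and `bias` fields):
+
+        import haliax as hax
+        import jax.random as jrandom
+
+        linear = hax.nn.Linear.init(In=hax.Axis("In", 4), Out=hax.Axis("Out", 3), key=jrandom.PRNGKey(0))
+        to_state_dict(linear)  # {"weight": <array>, "bias": <array>}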
+ """ + if isinstance(tree, eqx.Module): + if hasattr(tree, "to_state_dict"): + state_dict = tree.to_state_dict(prefix) + else: + state_dict = default_eqx_module_to_state_dict(tree, prefix) + elif isinstance(tree, list): + state_dict = {} + for i, item in enumerate(tree): + child = to_state_dict(item, prefix=with_prefix(prefix, str(i))) + # TODO: check for conflicts? + state_dict.update(child) + elif isinstance(tree, dict): + state_dict = {} + for k, v in tree.items(): + child = to_state_dict(v, prefix=with_prefix(prefix, k)) + # TODO: check for conflicts? + state_dict.update(child) + elif isinstance(tree, NamedArray): + if prefix is None: + raise ValueError("Cannot convert a leaf value to a state dict without a prefix") + if tree.array is not None: + state_dict = {prefix: tree.array} + else: + state_dict = {} + elif is_jax_array_like(tree): + if prefix is not None: + if tree is not None: + state_dict = {prefix: tree} + else: + state_dict = {} + else: + raise ValueError("Cannot convert a leaf value to a state dict without a prefix") + else: + raise ValueError(f"Unsupported type {type(tree)}") + + return state_dict + + +def default_eqx_module_from_state_dict(mod: Mod, state_dict: StateDict, prefix: Optional[str] = None) -> Mod: + key_map: Dict[str, Optional[str]] = getattr(mod, "_state_dict_key_map", lambda: {})() # type: ignore + names = [] + values = [] + for field in dataclasses.fields(mod): + if field.metadata.get("static", False): + continue + key = key_map.get(field.name, field.name) + value = getattr(mod, field.name) + # TODO: might want to add a flag that allows missing keys? + new = from_state_dict(value, state_dict, with_prefix(prefix, key)) + # Do not try to update parameters that are never defined + if value is None and new is None: + continue + names.append(field.name) + values.append(new) + return eqx.tree_at(lambda m: [getattr(m, name) for name in names], mod, values) + + +def default_eqx_module_to_state_dict(mod: eqx.Module, prefix: Optional[str] = None) -> StateDict: + """ + Convert an eqx.Module to a state dict. This is the default implementation of the to_state_dict method for + eqx.Modules. It works by iterating over the fields of the module and calling to_state_dict on each field. + Args: + mod: + prefix: + + Returns: + + """ + state_dict: StateDict = {} + key_map: Dict[str, Optional[str]] = getattr(mod, "_state_dict_key_map", lambda: {})() # type: ignore + for field in dataclasses.fields(mod): + if field.metadata.get("static", False): + continue + key = key_map.get(field.name, field.name) + value = getattr(mod, field.name) + child = to_state_dict(value, with_prefix(prefix, key)) + # TODO: should we check for conflicts? + state_dict.update(child) + return state_dict + + +def format_path_for_state_dict(prefix: Optional[str], path: Sequence) -> str: + res = "".join(_format_key_path_element(path_elem) for path_elem in path) + # res will have a . + if prefix is not None: + res = f"{prefix}{res}" + elif res.startswith("."): + res = res[1:] + + return res + + +# Torch compatible KeyPath formatting. Torch just always uses . 
+def _format_key_path_element(path_elem) -> str:
+    match path_elem:
+        case SequenceKey(idx):  # type: ignore
+            return f".{idx}"
+        case DictKey(key):  # type: ignore
+            return f".{key}"
+        case GetAttrKey():  # type: ignore
+            return str(path_elem)
+        case FlattenedIndexKey(idx):  # type: ignore
+            return f".{idx}"
+        case _:
+            # The convention in JAX is to include the separator in the element itself,
+            # so we expect it to already start with one
+            path_elem = str(path_elem)
+            if path_elem.startswith("."):
+                return path_elem
+            else:
+                return f".{path_elem}"
+
+
+def to_numpy_state_dict(model, prefix: Optional[str] = None) -> StateDict:
+    """
+    Convert a model to a state dict by first creating desharded copies of all parameters in host CPU
+    memory as numpy arrays.
+
+    This method is especially useful for saving models distributed across multiple hosts.
+    """
+
+    with jax.default_device(jax.local_devices(backend="cpu")[0]):
+
+        def get_to_cpu(arr):
+            if not is_jax_array_like(arr):
+                return arr
+            elif isinstance(arr, np.ndarray):
+                return arr
+            elif arr.is_fully_addressable:
+                r = np.array(arr)
+                return r
+            else:
+                # unfortunately, jax's allgather seems to replicate to every device rather than every host
+                # which doesn't work for ~7B parameter models on TPU (assuming we also have optimizer state)
+                # this approach limits us to <64B parameters, but that's good enough for now
+                # we're going to do something a bit fancy, where we shard the model into a (process, device) mesh,
+                # then look for some axis along which we can shard the array, and then we'll do an allgather
+                # via pjit. If we can't find one, we'll just fully replicate since it probably isn't that big.
+                # TODO: ensure that this mesh arranges devices correctly
+                # (jax seems to do this internally itself, so we should be fine?)
+                process_mesh = Mesh(np.array(jax.devices()).reshape((jax.process_count(), -1)), ("process", "device"))
+                # now we need to find an axis along which we can shard the array.
+                # for this, we need to find an axis s.t. size(axis) % total_devices == 0
+
+                try:
+                    axis_to_shard = index_where(
+                        lambda axis_size: axis_size % process_mesh.devices.size == 0, arr.shape
+                    )
+                except ValueError:
+                    return np.array(arr)
+
+                shardings = [None if i != axis_to_shard else "device" for i in range(len(arr.shape))]
+                sharding = NamedSharding(process_mesh, PartitionSpec(*shardings))
+                out = jax.jit(lambda x: x, out_shardings=sharding)(arr)
+                return np.array(out)
+
+        # need to make sure the model is on *this machine* and *this machine's CPU* before saving
+        model = jax.tree.map(lambda arr: get_to_cpu(arr), model)
+        # TODO: it would be nice if safetensors supported an iterator or something so we could do the allgather one at a time
+        state_dict = to_state_dict(model, prefix=prefix)
+        return state_dict
+
+
+_GLOBAL_SAVE_COUNT = 0
+
+
+def save_state_dict(state_dict: StateDict, path):
+    """
+    Save a state dict to a file in safetensors format. Tensors should already be host-local numpy arrays
+    (see [haliax.state_dict.to_numpy_state_dict][]); entries that are `None` are dropped.
+    """
+    state_dict = {k: v for k, v in state_dict.items() if v is not None}
+    # the tensors are already host-local, so only one process needs to write the file
+    if jax.process_index() == 0:
+        # the "pt" is a lie but it doesn't seem to actually matter and HF demands it
+        safetensors.numpy.save_file(state_dict, path, metadata={"format": "pt"})
+    global _GLOBAL_SAVE_COUNT
+    sync_global_devices(f"save_state_dict {_GLOBAL_SAVE_COUNT}")
+    _GLOBAL_SAVE_COUNT += 1
+
+
+def load_state_dict(path):
+    """
+    Load a state dict from a file in safetensors format, returning the tensors as numpy arrays.
+    """
+    state_dict = safetensors.numpy.load_file(path)
+    return state_dict
+
+
+def stack_state_dict(state_dict: StateDict, prefix: Optional[str] = None) -> StateDict:
+    """
+    Stack all keys matching the prefix in a new state dict, returning a state dict that has all keys matching
+    the prefix stacked, but is otherwise the same.
+
+    Keys of the form `<prefix>.0.<key>`, `<prefix>.1.<key>`, etc. (the layout produced by e.g. a
+    `torch.nn.Sequential`) are combined into a single `<prefix>.<key>` stacked along a new leading axis.
+
+    Mostly for use with [haliax.nn.Stacked][].
+    """
+    vectorized_dict: StateDict = {}
+
+    tensors_to_vectorize: dict[str, list[Optional[Any]]] = {}
+    if prefix is not None:
+        prefix_for_pat = re.escape(prefix + ".")
+    else:
+        prefix_for_pat = ""
+    pattern = re.compile(rf"{prefix_for_pat}(\d+)\.(.*)")
+
+    for k, v in state_dict.items():
+        match = pattern.match(k)
+        if match:
+            block_idx = int(match.group(1))
+            block_key = match.group(2)
+            tensors = tensors_to_vectorize.setdefault(block_key, [])
+            if len(tensors) <= block_idx:
+                tensors.extend([None] * (block_idx - len(tensors) + 1))
+            assert tensors[block_idx] is None, f"Duplicate key {k}"
+            tensors[block_idx] = v
+        else:
+            vectorized_dict[k] = v
+
+    # now we have to vectorize the tensors
+    for k, tensors in tensors_to_vectorize.items():
+        vectorized_dict[cast(str, with_prefix(prefix, k))] = jnp.stack(tensors, axis=0)
+
+    return vectorized_dict
+
+
+def unstack_state_dict(state_dict: StateDict, prefix: Optional[str] = None) -> StateDict:
+    """
+    Unstack all keys matching the prefix in a new state dict, returning a state dict that has all keys matching
+    the prefix unstacked, but is otherwise the same. This is the inverse of [haliax.state_dict.stack_state_dict][]:
+    each stacked tensor under `<prefix>.<key>` is split along its leading axis into `<prefix>.0.<key>`,
+    `<prefix>.1.<key>`, etc., the layout a `torch.nn.Sequential` expects.
+
+    Mostly for use with [haliax.nn.Stacked][].
+    """
+    new_dict: StateDict = {}
+    prefix = with_prefix(prefix, "")
+    assert prefix is not None
+
+    for k, v in state_dict.items():
+        if k.startswith(prefix) and v is not None:
+            for i, v_i in enumerate(v):
+                new_dict[f"{prefix}{i}.{k[len(prefix):]}"] = v_i
+        else:
+            new_dict[k] = v
+
+    return new_dict
+
+
+def flatten_linear_layers(tree: T) -> T:
+    """
+    In PyTorch, linear layers are stored as a 2d weight matrix and a 1d bias vector. In Haliax,
+    linear layers can have arbitrary dimensions, grouped into input and output axes. This function
+    flattens the linear layers in a tree to be compatible with PyTorch-style state dicts.
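+
+    For example (a sketch; the shapes follow `test_flatten_linear_layers` in `tests/test_state_dict.py`):
+
+        H, W = hax.Axis("H", 10), hax.Axis("W", 20)
+        D, B = hax.Axis("D", 30), hax.Axis("B", 40)
+        linear = hax.nn.Linear.init((H, W), (D, B), key=key)  # weight axes (D, B, H, W) with out_first=True
+        flat = flatten_linear_layers(linear)
+        # flat.weight.array.shape == (1200, 200), i.e. (D.size * B.size, H.size * W.size)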
+
+    Args:
+        tree: The tree to process. Any [haliax.nn.Linear][] layers are flattened; other leaves pass through.
+
+    Returns:
+        The same tree, with each Linear's weight flattened to 2-d and its bias to 1-d.
+    """
+    from haliax.nn import Linear
+
+    def _flatten_linear(layer):
+        if not isinstance(layer, Linear):
+            return layer
+
+        weight = layer.weight
+        bias = layer.bias
+
+        if weight.array is not None:
+            out_first = layer.out_first
+            weight = weight.flatten_axes(layer.Out, "__OUT__").flatten_axes(layer.In, "__IN__")
+
+            if out_first:
+                weight = weight.rearrange((..., "__OUT__", "__IN__"))
+            else:
+                weight = weight.rearrange((..., "__IN__", "__OUT__"))
+
+            if bias is not None:
+                bias = bias.flatten_axes(layer.Out, "__OUT__")
+
+            In = weight.resolve_axis("__IN__")
+            Out = weight.resolve_axis("__OUT__")
+
+            return dataclasses.replace(layer, weight=weight, bias=bias, In=In, Out=Out)  # type: ignore
+        else:
+            return layer
+
+    return jax.tree.map(_flatten_linear, tree, is_leaf=lambda x: isinstance(x, Linear))
+
+
+def unflatten_linear_layers(template: T, tree_with_flattened_linears: T) -> T:
+    """
+    Unflattens linear layers in a tree that was flattened with [haliax.state_dict.flatten_linear_layers][].
+    Template has the same structure as the tree that was flattened, but with the original (unflattened)
+    linear layers.
+
+    Returns:
+        The same tree as `tree_with_flattened_linears`, but with the linear layers unflattened to match
+        the structure of `template`.
+    """
+
+    from haliax.nn import Linear
+
+    def _unflatten_linear(template, flattened):
+        assert isinstance(template, Linear) == isinstance(flattened, Linear)
+
+        if not isinstance(template, Linear):
+            return flattened
+
+        weight = flattened.weight
+        bias = flattened.bias
+
+        if weight.array is not None:
+            weight = weight.unflatten_axis("__OUT__", template.Out).unflatten_axis("__IN__", template.In)
+            weight = weight.rearrange(template.weight.axes)
+
+        if bias is not None:
+            bias = bias.unflatten_axis("__OUT__", template.Out)
+            assert template.bias is not None, "Flattened bias but template has no bias"
+            bias = bias.rearrange(template.bias.axes)
+
+        return dataclasses.replace(template, weight=weight, bias=bias)  # type: ignore
+
+    return jax.tree.map(
+        _unflatten_linear, template, tree_with_flattened_linears, is_leaf=lambda x: isinstance(x, Linear)
+    )
diff --git a/src/haliax/core.py b/src/haliax/core.py
index 57550ed..f97a48b 100644
--- a/src/haliax/core.py
+++ b/src/haliax/core.py
@@ -1205,7 +1205,9 @@ def unflatten_axis(array: NamedArray, axis: AxisSelector, new_axes: AxisSpec) ->
         raise ValueError("Must specify at least one axis to split")
 
     if axis_size != prod(ax.size for ax in new_axes):
-        raise ValueError(f"Cannot split {axis} into {new_axes}: size mismatch")
+        raise ValueError(
+            f"Cannot split {axis} into {new_axes}: size mismatch ({axis_size} != {prod(ax.size for ax in new_axes)})"
+        )
 
     resolved_new_axes = array.axes[:old_index] + tuple(new_axes) + array.axes[old_index + 1 :]
     new_array = jnp.reshape(array.array, [ax.size for ax in resolved_new_axes])
diff --git a/src/haliax/nn/linear.py b/src/haliax/nn/linear.py
index ffba87c..cbf697e 100644
--- a/src/haliax/nn/linear.py
+++ b/src/haliax/nn/linear.py
@@ -1,9 +1,8 @@
 import math
-from typing import Callable, Optional
+from typing import Optional
 
 import equinox as eqx
-import jax.lax
-from jaxtyping import PRNGKeyArray
+from jax.random import PRNGKey
 
 import haliax as hax
 
@@ -29,14 +28,13 @@ def init(
         In: AxisSpec,
         Out: AxisSpec,
         *,
-        key,
-        use_bias=True,
-        out_first: bool = False,
-        dot_general=None,
+        key: PRNGKey,
+        use_bias: bool = True,
+        out_first: bool = True,
+        dot_general: Optional[DotGeneralOp] = None,
         init_scale: float = 1.0,
     ) -> "Linear":
         """
-
         Args:
             In: AxisSpec:
The input axis spec Out: AxisSpec: The output axis spec @@ -57,7 +55,7 @@ def init( return Linear(weight, bias, In, Out, dot_general=dot_general) @named_call - def __call__(self, inputs, *, key: Optional[PRNGKeyArray] = None): + def __call__(self, inputs, *, key: Optional[PRNGKey] = None): """ Args: inputs (NamedArray): Input array diff --git a/src/haliax/nn/pool.py b/src/haliax/nn/pool.py index 9aa5597..8c3c048 100644 --- a/src/haliax/nn/pool.py +++ b/src/haliax/nn/pool.py @@ -48,6 +48,7 @@ def pool( of `n` `(low, high)` integer pairs that give the padding to apply before and after each spatial dimension, or an integer to pad all dimensions. use_ceil: if True, will use ceil instead of floor to compute the output shape + Returns: The output of the reduction for each window slice. """ @@ -154,12 +155,12 @@ def max_pool( Window: the size of the window to pool over inputs: input data with dimensions (batch, window dims..., features). stride: a sequence of `n` integers, representing the inter-window - stride (default: `(1, ..., 1)`). + stride (default: `(1, ..., 1)`). padding: either the string `'SAME'`, the string `'VALID'`, or a sequence - of `n` `(low, high)` integer pairs that give the padding to apply before - and after each spatial dimension. + of `n` `(low, high)` integer pairs that give the padding to apply before + and after each spatial dimension. Returns: - The maximum value in each window slice. + The maximum value in each window slice. """ return pool(Window, inputs, -float("inf"), jax.lax.max, stride, padding, use_ceil=use_ceil) @@ -184,7 +185,7 @@ def min_pool( and after each spatial dimension. use_ceil: if True, will use ceil instead of floor to compute the output shape Returns: - The minimum value in each window slice. + The minimum value in each window slice. """ return pool(Window, inputs, float("inf"), jax.lax.min, stride, padding, use_ceil=use_ceil) @@ -204,11 +205,10 @@ def mean_pool( Args: Window: the size of the window to pool over inputs: input data with dimensions (batch, window dims..., features). - stride: a sequence of `n` integers, representing the inter-window - stride (default: `(1, ..., 1)`). + stride: a sequence of `n` integers, representing the inter-window stride (default: `(1, ..., 1)`). padding: either the string `'SAME'`, the string `'VALID'`, or a sequence - of `n` `(low, high)` integer pairs that give the padding to apply before - and after each spatial dimension. + of `n` `(low, high)` integer pairs that give the padding to apply before + and after each spatial dimension. Returns: The mean value in each window slice. """ diff --git a/src/haliax/nn/scan.py b/src/haliax/nn/scan.py index 2baf888..dfe25b4 100644 --- a/src/haliax/nn/scan.py +++ b/src/haliax/nn/scan.py @@ -8,6 +8,13 @@ import haliax.util from haliax.jax_utils import filter_checkpoint +from .._src.state_dict import ( + ModuleWithStateDictSerialization, + StateDict, + stack_state_dict, + unstack_state_dict, + with_prefix, +) from ..axis import Axis @@ -49,7 +56,7 @@ def unstacked(self) -> Sequence[M]: ... -class BlockSeq(eqx.Module, Generic[M]): +class BlockSeq(ModuleWithStateDictSerialization, Generic[M]): """ A "BlockSeq" wraps another module and produces a "sequential" version of it, where an input is applied to each instance of the sequential module in sequence. This is useful for e.g. 
transformers
@@ -134,8 +141,25 @@ def _slice_out(Block, i, x):
     def _state_dict_key_map(self) -> Dict[str, Optional[str]]:
         return {"blocks": None}
 
+    def from_state_dict(self: M, state_dict: StateDict, prefix: Optional[str] = None) -> M:
+        out_blocks = []
+        for i, block in enumerate(self.blocks):
+            my_prefix = with_prefix(prefix, str(i))
+            block = block.from_state_dict(state_dict, my_prefix)
+            out_blocks.append(block)
+
+        return eqx.tree_at(lambda m: m.blocks, self, out_blocks)
+
+    def to_state_dict(self, prefix: Optional[str] = None) -> StateDict:
+        state_dict: StateDict = {}
+        for i, block in enumerate(self.blocks):
+            my_prefix = with_prefix(prefix, str(i))
+            state_dict.update(block.to_state_dict(my_prefix))
+
+        return state_dict
 
-class Stacked(eqx.Module, Generic[M]):
+
+class Stacked(ModuleWithStateDictSerialization, Generic[M]):
     """
     A "Stacked" wraps another module and produces a "stacked" version of it, where an input is applied
     to each instance of the stacked module in sequence. This is useful for e.g. transformers
@@ -254,3 +278,17 @@ def unbatch_leaf(x):
         # now we need to transpose the leaves
         unstacked_leaves = tuple(zip(*unstacked_leaves))
         return tuple(map(lambda x: jax.tree_util.tree_unflatten(structure, x), unstacked_leaves))
+
+    def to_state_dict(self, prefix: Optional[str] = None) -> StateDict:
+        # this method needs to "devectorize" the blocks, so that we have a list of blocks h.0.FOO, h.1.FOO, etc.
+        # first just do the normal thing with our own dict, which we'll post-process
+        state_dict: StateDict = super().to_state_dict(prefix)
+
+        return unstack_state_dict(state_dict, prefix)
+
+    def from_state_dict(self: M, state_dict: StateDict, prefix: Optional[str] = None) -> M:
+        # this method needs to "vectorize" the blocks, so that we have a single stacked block h.FOO
+        # first stack the per-block entries of the state dict, then load as usual
+        stacked = stack_state_dict(state_dict, prefix=prefix)
+        out = super().from_state_dict(stacked, prefix=prefix)  # type: ignore
+        return out
diff --git a/src/haliax/partitioning.py b/src/haliax/partitioning.py
index 95c5d86..17908c1 100644
--- a/src/haliax/partitioning.py
+++ b/src/haliax/partitioning.py
@@ -382,7 +382,7 @@ def named_jit(
     donate_args: Optional[PyTree] = None,
     donate_kwargs: Optional[PyTree] = None,
     **pjit_args,
-):
+) -> typing.Union[WrappedCallable[Args, R], typing.Callable[[Callable[Args, R]], WrappedCallable[Args, R]]]:
     """
     A version of pjit that uses NamedArrays and the provided resource mapping to infer resource partitions
     for sharded computation for.
@@ -395,6 +395,8 @@
 
+    This function can be used as a decorator or as a function.
+
     Functionally this is very similar to something like:
+ ```python def wrapped_fn(arg): result = fn(arg) @@ -425,7 +427,7 @@ def wrapped_fn(arg): if fn is None: return functools.partial( # type: ignore - named_jit, + named_jit, # type: ignore axis_resources=axis_resources, in_axis_resources=in_axis_resources, out_axis_resources=out_axis_resources, diff --git a/src/haliax/state_dict.py b/src/haliax/state_dict.py new file mode 100644 index 0000000..67d7542 --- /dev/null +++ b/src/haliax/state_dict.py @@ -0,0 +1,45 @@ +from typing import Optional, TypeVar + +from ._src.state_dict import ( + ModuleWithStateDictSerialization, + StateDict, + flatten_linear_layers, + from_state_dict, + from_torch_compatible_state_dict, + load_state_dict, + save_state_dict, + to_numpy_state_dict, + to_state_dict, + unflatten_linear_layers, + with_prefix, +) + + +T = TypeVar("T") + + +def to_torch_compatible_state_dict(t: T, *, flatten_linear: bool = True, prefix: Optional[str] = None) -> StateDict: + """ + Convert a tree to a state dict that is compatible with torch-style state dicts. + + This applies [haliax.state_dict.flatten_linear_layers][] followed by [haliax.state_dict.to_state_dict][] + """ + if flatten_linear: + t = flatten_linear_layers(t) + return to_numpy_state_dict(t, prefix=prefix) + + +__all__ = [ + "ModuleWithStateDictSerialization", + "from_torch_compatible_state_dict", + "load_state_dict", + "save_state_dict", + "from_state_dict", + "flatten_linear_layers", + "unflatten_linear_layers", + "with_prefix", + "to_state_dict", + "to_numpy_state_dict", + "StateDict", + "to_torch_compatible_state_dict", +] diff --git a/tests/test_scan.py b/tests/test_scan.py index d01896a..8a21d5e 100644 --- a/tests/test_scan.py +++ b/tests/test_scan.py @@ -68,3 +68,34 @@ def init(named, array, static): y_seq = m_seq.fold(x, key=jax.random.split(jax.random.PRNGKey(2), Block.size)) assert hax.all(hax.isclose(y, y_seq, atol=1e-5)) + + +def test_stacked_to_state_dict(): + class Module(eqx.Module): + named: hax.NamedArray + array: jax.Array + static: int = eqx.static_field() + + def __call__(self, x, *, key): + return x + self.array + self.static + hax.random.normal(key, x.axes) + + @staticmethod + def init(named, array, static): + return Module(named=named, array=array, static=static) + + Block = hax.Axis("block", 4) + E = hax.Axis("E", 10) + + initial_named = hax.random.uniform(jax.random.PRNGKey(0), (Block, E)) + + m = Stacked.init(Block, Module)(named=initial_named, array=jax.numpy.ones(Block.size), static=1) + + state_dict = m.to_state_dict() + m2 = m.from_state_dict(state_dict) + input = hax.random.uniform(jax.random.PRNGKey(1), (E,)) + key = jax.random.split(jax.random.PRNGKey(2), Block.size) + + y = m.fold(input, key=key) + y2 = m2.fold(input, key=key) + + assert hax.all(hax.equal(y, y2)) diff --git a/tests/test_state_dict.py b/tests/test_state_dict.py new file mode 100644 index 0000000..49aceae --- /dev/null +++ b/tests/test_state_dict.py @@ -0,0 +1,129 @@ +from typing import Any + +import equinox as eqx +import jax +import jax.numpy as jnp +import pytest + +import haliax as hax +from haliax._src.state_dict import stack_state_dict, unstack_state_dict +from haliax.nn import Linear +from haliax.state_dict import flatten_linear_layers, from_state_dict, to_state_dict, unflatten_linear_layers + + +@pytest.mark.parametrize("out_dims_first", [True, False]) +def test_flatten_linear_layers(out_dims_first: bool): + H = hax.Axis("H", 10) + W = hax.Axis("W", 20) + D = hax.Axis("D", 30) + B = hax.Axis("B", 40) + linear = hax.nn.Linear.init((H, W), (D, B), key=jax.random.PRNGKey(0), 
use_bias=True, out_first=out_dims_first) + + if out_dims_first: + assert linear.weight.axes == (D, B, H, W) + else: + assert linear.weight.axes == (H, W, D, B) + + flat_linear = flatten_linear_layers(linear) + + flat_state_dict = to_state_dict(flat_linear) + if out_dims_first: + assert flat_state_dict["weight"].shape == (D.size * B.size, H.size * W.size) + else: + assert flat_state_dict["weight"].shape == (H.size * W.size, D.size * B.size) + assert flat_state_dict["bias"].shape == (D.size * B.size,) + assert flat_state_dict["weight"].dtype == flat_state_dict["bias"].dtype == linear.weight.dtype + + # now unflatten it + linear2 = Linear.init((H, W), (D, B), key=jax.random.PRNGKey(1), use_bias=True, out_first=out_dims_first) + new_linear = unflatten_linear_layers(linear2, flat_linear) + + if out_dims_first: + assert new_linear.weight.axes == (D, B, H, W) + else: + assert new_linear.weight.axes == (H, W, D, B) + assert new_linear.bias.axes == (D, B) # type: ignore + + assert linear == new_linear + + +# Test cases for stack_state_dict +@pytest.mark.parametrize( + "input_dict, prefix, expected_output", + [ + # Single block stacking + ( + { + "block.0.weight": jnp.array([1, 2]), + "block.0.bias": jnp.array([3]), + "block.1.weight": jnp.array([4, 5]), + "block.1.bias": jnp.array([6]), + }, + "block", + { + "block.weight": jnp.array([[1, 2], [4, 5]]), + "block.bias": jnp.array([[3], [6]]), + }, + ), + # Mixed data types and unmatched items remain unchanged + ( + { + "block.0.weight": jnp.array([1, 2]), + "block.0.bias": jnp.array([3]), + "block.1.weight": jnp.array([4, 5]), + "block.1.bias": jnp.array([6.0]), + "unrelated.item": jnp.array([7]), + }, + "block", + { + "block.weight": jnp.array([[1, 2], [4, 5]]), + "block.bias": jnp.array([[3.0], [6.0]]), + "unrelated.item": jnp.array([7]), + }, + ), + # No items match prefix, all items should remain unchanged + ( + { + "module.0.param": jnp.array([1]), + "module.1.param": jnp.array([2]), + }, + "block", + { + "module.0.param": jnp.array([1]), + "module.1.param": jnp.array([2]), + }, + ), + ], +) +def test_stack_state_dict(input_dict, prefix, expected_output): + result = stack_state_dict(input_dict, prefix) + for key in expected_output: + assert jnp.all(jnp.array_equal(result[key], expected_output[key])), f"Failed on key: {key}" + + # now unstack it + unstacked = unstack_state_dict(result, prefix) + for key in input_dict: + assert jnp.all(jnp.array_equal(unstacked[key], input_dict[key])), f"Failed on key: {key}" + + +class M(eqx.Module): + a: Any + b: Any + + def __init__(self, a, b): + self.a = a + self.b = b + + +def test_to_from_state_dict(): + a = jnp.array([1, 2]) + b = jnp.array([3, 4]) + m = M(a, b) + + state_dict = to_state_dict(m) + assert state_dict == {"a": a, "b": b} + + m2 = M(jnp.array([0, 0]), jnp.array([0, 0])) + m2 = from_state_dict(m2, state_dict) + assert jnp.all(m2.a == a) + assert jnp.all(m2.b == b)
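+
+
+def test_torch_compatible_round_trip(tmp_path):
+    # A sketch of an end-to-end round trip through the torch-compatible format.
+    # Assumes the default out_first=True Linear layout; the axis names are arbitrary.
+    H = hax.Axis("H", 8)
+    D = hax.Axis("D", 4)
+    linear = Linear.init(In=H, Out=D, key=jax.random.PRNGKey(0))
+
+    # flatten, convert to numpy, and save in safetensors format
+    state_dict = hax.state_dict.to_torch_compatible_state_dict(linear)
+    path = str(tmp_path / "linear.safetensors")
+    hax.state_dict.save_state_dict(state_dict, path)
+
+    # load back and unflatten into a module with the same structure
+    loaded = hax.state_dict.load_state_dict(path)
+    restored = hax.state_dict.from_torch_compatible_state_dict(linear, loaded)
+
+    assert hax.all(hax.equal(restored.weight, linear.weight))
+    assert hax.all(hax.equal(restored.bias, linear.bias))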