Move StateDict to Haliax, clean it up a lot. #90
base: main
Conversation
docs/state_dict.md
Outdated
@@ -0,0 +1,20 @@
# Serialization
(The main "draft-y" part of this PR is just the docs.)
src/haliax/nn/linear.py
Outdated
    *,
    key,
    use_bias=True,
    out_first: bool = False,
Nit: this seems like it only affects the internals... for the purpose of simplifying the torch integration, would it be harmful to make `out_first` default to True (or remove/ignore it entirely and always output in torch-compatible mode)?

If it's a performance issue, maybe transform to the `out_first=False` mode during loading?
It's more that GPT2's linear was `out_first=False` and I modeled Haliax's linear on that. I need to go through and make the default True...
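To make the layout question concrete: torch's `nn.Linear` stores its `weight` with shape `(out_features, in_features)`, i.e. "out first". A minimal sketch (illustrative only, not Haliax's actual implementation; the helper names are made up) of bridging an in-first in-memory layout to torch's convention at the state-dict boundary:

```python
import numpy as np

In, Out = 4, 3  # example dimensions

def to_torch_weight(w_in_first: np.ndarray) -> np.ndarray:
    """(In, Out) -> (Out, In), matching torch's nn.Linear layout."""
    return w_in_first.T

def from_torch_weight(w_torch: np.ndarray) -> np.ndarray:
    """(Out, In) -> (In, Out), back to the in-first layout."""
    return w_torch.T

w = np.arange(In * Out, dtype=np.float32).reshape(In, Out)
assert to_torch_weight(w).shape == (Out, In)
# The round trip is lossless, so transposing at load time is safe.
assert np.array_equal(from_torch_weight(to_torch_weight(w)), w)
```

This is why `out_first` can be treated as an internal detail: whichever default is chosen, a transpose at save/load time recovers torch compatibility.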
LGTM, just some nits & thoughts!
docs/state-dict.md
Outdated
```

Similarly, to load weights from the state dict, you might need to implement `from_state_dict`. For the case of
`BackpackLMHeadModel`, we didn't because we can just not call the parent class's `from_state_dict` method.
> we didn't because we can just not call the parent class's

Maybe: "For `BackpackLMHeadModel`, we can use the default implementation of `from_state_dict`."

(It's not obvious to me why you don't need `from_state_dict` here, given the remapping that happens on the `update_state_dict` side.)
oops, that's what I get for writing docs late at night.

The issue is that the backpack LM implementation in Levanter enforces tied weights, but they're not required to be tied in HF. So we have to write the weight twice in the state dict, but we only bother to read it in once. Not ideal, but honestly backpacks are not a high priority and might get chopped.

It's also not the most ideal example, but it's about the only place left where we need `update_state_dict`.
    return t


def _flatten_to_unflatten(t, state_dict, prefix):
Maybe a quick comment:

"""Flatten the torch-compatible `state_dict` before loading into `t`, and then recover the unflattened layers."""
    This applies [haliax.state_dict.from_state_dict][] followed by [haliax.state_dict.unflatten_linear_layers][].
    """
    if unflatten_linear:
        t = _flatten_to_unflatten(t, state_dict, prefix)
You're probably doing this for a good reason, but it seems epsilon-cleaner to unflatten `state_dict` and then restore into `t`.
Because the modules themselves know what has to be flattened/unflattened (in terms of their specific member variables), I'd need to implement parallel logic to do it in the state dict. I did think this solution was a bit too clever, but it works and it lets me be lazy.
(Unless I'm missing something!)
Ah, no that makes sense, if the flattening is module-specific, then it does seem kind of a pain to try to hoist something separate alongside. This is fine, it just stuck out to me.
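To illustrate why the flattening is module-specific: a linear layer whose input is split over named axes (e.g. `Heads`, `HeadDim`) stores a 3-D weight, but a torch-compatible state dict wants a flat 2-D matrix. Only the module knows which axes were fused, so only it can invert the reshape. The axis names and helpers below are illustrative, not Haliax's actual internals:

```python
import numpy as np

Heads, HeadDim, Out = 2, 3, 5  # example named-axis sizes

def flatten_linear(w: np.ndarray) -> np.ndarray:
    """(Heads, HeadDim, Out) -> (Heads*HeadDim, Out) for the state dict."""
    return w.reshape(Heads * HeadDim, Out)

def unflatten_linear(w2d: np.ndarray) -> np.ndarray:
    """Recover the named-axis layout; the shapes live on the module."""
    return w2d.reshape(Heads, HeadDim, Out)

w = np.random.default_rng(0).normal(size=(Heads, HeadDim, Out))
assert flatten_linear(w).shape == (Heads * HeadDim, Out)
# The module-side round trip is exact.
assert np.array_equal(unflatten_linear(flatten_linear(w)), w)
```

A state-dict-side `unflatten` would need to duplicate the `(Heads, HeadDim)` knowledge for every module type, which is the "parallel logic" objection above.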
def default_eqx_module_from_state_dict(mod: Mod, state_dict: StateDict, prefix: Optional[str] = None) -> Mod:
    key_map: Dict[str, Optional[str]] = getattr(mod, "_state_dict_key_map", lambda: {})()  # type: ignore
Nit: it feels a little magical to have this (e.g. `"block": None`) vs. having the user override `to_state_dict` entirely and explicitly specify the names they want. Just a thought, feel free to ignore.
hrm yeah, I guess the parallel in my head was that `apply_prefix(None, suffix) == suffix`, and the `None` here means the same thing.
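A sketch of that parallel (the helpers are illustrative; only the `apply_prefix(None, suffix) == suffix` identity comes from the discussion above): mapping a field name to `None` in the key map drops that segment from the state-dict path entirely.

```python
from typing import Dict, Optional

def apply_prefix(prefix: Optional[str], suffix: str) -> str:
    # None means "no prefix": apply_prefix(None, suffix) == suffix.
    return suffix if prefix is None else f"{prefix}.{suffix}"

def remap(field: str, key_map: Dict[str, Optional[str]]) -> Optional[str]:
    # Fields absent from the map keep their own name.
    return key_map.get(field, field)

key_map: Dict[str, Optional[str]] = {"block": None, "wte": "transformer.wte"}

# "block": None erases the segment from the path...
assert apply_prefix(remap("block", key_map), "weight") == "weight"
# ...while a string value renames it, and unmapped fields pass through.
assert apply_prefix(remap("wte", key_map), "weight") == "transformer.wte.weight"
assert apply_prefix(remap("ln_f", key_map), "weight") == "ln_f.weight"
```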
    Stack all keys matching prefix in a new state dict, returning a state dict that has all keys matching
    prefix stacked, but otherwise the same.

    Stacked in this case means roughly "compatible with a torch.nn.Sequential", which means that the
This comment is the same for `stacked` and `unstacked` below. Maybe: "The unstacked format is compatible with `torch.nn.Sequential`, with keys of the form (...). The stacked/vectorized format is required for `haliax.nn.Stacked` and vectorizes all such tensors into a single shared key."
oops, thanks
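For reference, the stacked-vs-unstacked distinction being documented can be sketched as follows (a hypothetical helper, not the real `haliax.state_dict` implementation): the unstacked `torch.nn.Sequential`-style keys `block.0.weight`, `block.1.weight`, ... are vectorized into one `block.weight` with a new leading layer axis, which is the format a stacked/vmapped module would consume.

```python
import re
import numpy as np

def stack_state_dict(sd: dict, prefix: str) -> dict:
    """Stack all keys of the form {prefix}.{i}.{rest} into {prefix}.{rest}."""
    pat = re.compile(rf"^{re.escape(prefix)}\.(\d+)\.(.+)$")
    groups: dict = {}
    out: dict = {}
    for k, v in sd.items():
        m = pat.match(k)
        if m:
            idx, rest = int(m.group(1)), m.group(2)
            groups.setdefault(rest, {})[idx] = v
        else:
            out[k] = v  # keys not under the prefix pass through unchanged
    for rest, by_idx in groups.items():
        arrs = [by_idx[i] for i in sorted(by_idx)]
        out[f"{prefix}.{rest}"] = np.stack(arrs)  # new leading layer axis
    return out

sd = {"block.0.weight": np.zeros((2, 2)),
      "block.1.weight": np.ones((2, 2)),
      "ln.weight": np.ones(2)}
stacked = stack_state_dict(sd, "block")
assert stacked["block.weight"].shape == (2, 2, 2)  # (layers, ...) leading axis
assert "ln.weight" in stacked  # untouched keys survive
```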