Move StateDict to Haliax, clean it up a lot. #90

Open: wants to merge 18 commits into main

Conversation

@dlwh (Member) commented Jun 1, 2024

No description provided.

@@ -0,0 +1,20 @@
# Serialization

@dlwh (Member Author):

(The main "draft-y" part of this PR is just the docs.)

*,
key,
use_bias=True,
out_first: bool = False,

@rjpower (Contributor):

Nit: this seems like it only affects the internals... for the purpose of simplifying the torch integration, would it be harmful to make out_first default to True (or remove/ignore it entirely and always output in torch-compatible mode)?

If it's a performance issue, maybe transform to the out_first=False mode during loading?

@dlwh (Member Author):

It's more that GPT2's linear was out_first=False and I modeled Haliax's linear on that. I need to go through and make the default True...
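
For readers skimming the thread, here is a minimal sketch of what the two layouts mean for an exported weight, assuming the usual torch.nn.Linear convention of (out_features, in_features); the key name is made up for the example.

```python
# Hypothetical illustration only, not Haliax's actual export code.
import jax.numpy as jnp

In, Out = 4, 3

# out_first=False: GPT-2 / Conv1D-style layout, weight stored as (In, Out).
w_in_first = jnp.ones((In, Out))

# out_first=True: torch.nn.Linear layout, weight stored as (Out, In).
w_out_first = w_in_first.T

# A torch-compatible state dict wants the (Out, In) layout, so a module built
# with out_first=False has to transpose on export (or, per the suggestion above,
# transpose back to its preferred layout at load time).
state_dict = {"linear.weight": w_out_first}
assert state_dict["linear.weight"].shape == (Out, In)
```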

@dlwh dlwh requested a review from rjpower June 10, 2024 04:46
@dlwh dlwh marked this pull request as ready for review June 10, 2024 04:46

@rjpower (Contributor) left a comment:

LGTM, just some nits & thoughts!

Similarly, to load weights from the state dict, you might need to implement `from_state_dict`. For the case of
`BackpackLMHeadModel`, we didn't because we can just not call the parent class's `from_state_dict` method.

@rjpower (Contributor):

> we didn't because we can just not call the parent class's

Suggested rewording: "For BackpackLMHeadModel, we can use the default implementation of from_state_dict."

(It's not obvious to me why you don't need from_state_dict here, given the remapping that happens on the update_state_dict side).

@dlwh (Member Author):

Oops, that's what I get for writing docs late at night.

The issue is that the BackpackLM implementation in Levanter enforces tied weights, but they're not required to be tied in HF. So we have to write the weight twice in the state dict, but we only bother to read it in once. Not ideal, but honestly backpacks are not a high priority and might get chopped.

It's also not the most ideal example, but it's about the only place left where we need update_state_dict.
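
A hedged sketch of the asymmetry described above: the tied weight is written under two keys on export, but only one copy is read back in. The key names and helper functions below are illustrative, not the actual Levanter/Haliax code.

```python
import jax.numpy as jnp

def update_state_dict(embedding_weight, state_dict, prefix=""):
    # HF checkpoints don't require tying, so the shared weight is written twice:
    # once for the embedding table and once for the LM head.
    state_dict[prefix + "word_embeddings.weight"] = embedding_weight
    state_dict[prefix + "lm_head.weight"] = embedding_weight
    return state_dict

def load_tied_weight(state_dict, prefix=""):
    # On the way back in, the module enforces tying, so reading one copy is
    # enough; hence no from_state_dict override is needed for this case.
    return state_dict[prefix + "word_embeddings.weight"]

sd = update_state_dict(jnp.zeros((10, 4)), {}, prefix="backpack.")
assert (load_tied_weight(sd, prefix="backpack.") == sd["backpack.lm_head.weight"]).all()
```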

return t


def _flatten_to_unflatten(t, state_dict, prefix):

@rjpower (Contributor):

Maybe a quick comment:

"""Flatten the torch compatiblestate_dict before loading into t, and then recover the unflattened layers."""

This applies [haliax.state_dict.from_state_dict][] followed by [haliax.state_dict.unflatten_linear_layers][].
"""
if unflatten_linear:
    t = _flatten_to_unflatten(t, state_dict, prefix)

@rjpower (Contributor):

You're probably doing this for a good reason, but it seems epsilon-cleaner to unflatten the state_dict and then restore it into t.

@dlwh (Member Author):

Because the modules themselves know what has to be flattened/unflattened (in terms of their specific member vars), I'd need to implement parallel logic to do it on the state dict side. I did think this solution was a bit too clever, but it works and it lets me be lazy.

@dlwh (Member Author):

(Unless I'm missing something!)

@rjpower (Contributor):

Ah, no, that makes sense: if the flattening is module-specific, then it does seem like kind of a pain to try to hoist something separate alongside. This is fine; it just stuck out to me.
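
To summarize the trick under discussion, here is a hedged sketch of what a helper like _flatten_to_unflatten could look like. It reuses the haliax.state_dict.from_state_dict and haliax.state_dict.unflatten_linear_layers names quoted in the docs above; flatten_linear_layers and the exact signatures are assumptions, so the real implementation may differ.

```python
# Sketch only; the signatures of the haliax.state_dict helpers are assumed.
from haliax.state_dict import flatten_linear_layers, from_state_dict, unflatten_linear_layers

def _flatten_to_unflatten(t, state_dict, prefix):
    """Load a torch-compatible (flattened) state_dict into t, then restore t's
    original unflattened linear layout."""
    # 1. Flatten t's Linear layers so its fields line up with the 2-D,
    #    torch-style entries in the state dict.
    flat_t = flatten_linear_layers(t)
    # 2. Let each module read its own entries; the modules know which of their
    #    fields were flattened, which is why the logic lives on the module side
    #    rather than in a parallel pass over the state dict.
    flat_t = from_state_dict(flat_t, state_dict, prefix)
    # 3. Recover the original named-axis layout, using t as the template.
    return unflatten_linear_layers(t, flat_t)
```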

(Three additional review threads on src/haliax/_src/state_dict.py were marked resolved; the underlying diff is outdated.)


def default_eqx_module_from_state_dict(mod: Mod, state_dict: StateDict, prefix: Optional[str] = None) -> Mod:
    key_map: Dict[str, Optional[str]] = getattr(mod, "_state_dict_key_map", lambda: {})()  # type: ignore

@rjpower (Contributor):

Nit: it feels a little magical to have this (e.g. "block": None) vs having the user override to_state_dict entirely and explicitly specifying the names they want. Just a thought, feel free to ignore.

@dlwh (Member Author):

Hrm, yeah. I guess the parallel in my head was that apply_prefix(None, suffix) == suffix, and the None here means the same thing.
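
For context, a hedged sketch of the key-map convention being discussed; the class and field names are invented for the example, not taken from the PR.

```python
import equinox as eqx

class Model(eqx.Module):
    # Field names here are made up for the example.
    transformer: eqx.Module
    embeddings: eqx.Module

    def _state_dict_key_map(self):
        # "embeddings" is renamed to "wte" in the torch-style checkpoint, while
        # mapping "transformer" to None means its name contributes no path
        # component at all: its children's keys appear one level up, mirroring
        # apply_prefix(None, suffix) == suffix.
        return {"embeddings": "wte", "transformer": None}
```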

(Two further review threads on src/haliax/_src/state_dict.py were marked resolved.)
Stack all keys matching prefix in a new state dict, returning a state dict that has all keys matching
prefix stacked, but otherwise the same.

Stacked in this case means roughly "compatible with a torch.nn.Sequential", which means that the

@rjpower (Contributor):

This comment is the same for stacked and unstacked below. Maybe:

"The unstacked format is compatible with torch.nn.Sequential, with keys of the form (...). The stacked/vectorized format is required for haliax.nn.Stacked and vectorizes all such tensors into a single shared key.".

@dlwh (Member Author):

Oops, thanks.
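
A hedged illustration of the two formats in the suggested wording; the key names and shapes are invented for the example.

```python
import jax.numpy as jnp

n_layers, hidden = 2, 4

# Unstacked (torch.nn.Sequential-style): one entry per layer, with the layer
# index embedded in the key.
unstacked = {
    "h.0.mlp.weight": jnp.zeros((hidden, hidden)),
    "h.1.mlp.weight": jnp.ones((hidden, hidden)),
}

# Stacked/vectorized (haliax.nn.Stacked-style): all layers share a single key,
# with a leading layer axis.
stacked = {
    "h.mlp.weight": jnp.stack([unstacked[f"h.{i}.mlp.weight"] for i in range(n_layers)]),
}

assert stacked["h.mlp.weight"].shape == (n_layers, hidden, hidden)
```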
