Move state_dict serialization to Haliax
#35
Comments
I also think this would be nice, b/c, e.g., in my nanoGPT implementation I have to install all of Levanter just to use its serialization features. This seems like it may be a large-ish task; are there any areas I could potentially help with?
I think so! I think we could start by:
I don't think it's a ton of work, but it's not trivial either.
@rohan-mehta-1024 I merged the change I was making into Levanter, so if you're excited about this, I'm happy to let you take it! Happy to discuss more if you want.
Sorry for the delay on this! I've just been busy starting school. I plan to work on it over the weekend.
No problem at all. Happy you're looking at it!
Sorry if this is obvious, but what did you mean by
Sorry, that was very unclear. What I mean is that you could write something like:
and have the mapping work. It's not a necessary step, but I think it's a nicer API. And re: the hf_checkpoints stuff, right: I was thinking of basically just a function that does this https://github.com/stanford-crfm/levanter/blob/16112cc003680e79e15fc7c62b63f917679e7e32/src/levanter/compat/hf_checkpoints.py#L403-L407 plus the safetensors variant.
Ok, thanks for the clarification!
That's a good point. I think hax.field is probably a good thing to do (it's not like it's mandatory that we use it everywhere), but we don't have to decide just yet.
…On Wed, Dec 20, 2023 at 7:29 AM rohan-mehta-1024 wrote:
Ok, thanks for the clarification! eqx.field passes all kwargs that are not converter or static directly to dataclasses.field, which does not accept arbitrary kwargs. So we would either have to do eqx.field(metadata={'state_dict_key': 'foo'}) (which isn't as clean), or find a somewhat hacky way around this. Would it be worth creating a hax.field function that is basically the same as eqx.field, except that if you pass it a kwarg that dataclasses.field does not accept, it automatically shoves it into metadata? The only downside would be replacing a lot of instances of eqx.field... unless this is general enough that we could make a pull request for this change directly in Equinox? I'm not sure how else to get it to work besides these two ways, though...
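The hax.field idea above (forward whatever dataclasses.field accepts, and shove every other keyword into metadata) can be sketched in a few lines. This is a hypothetical helper, not an existing Haliax or Equinox API. To stay self-contained it wraps dataclasses.field directly; a real version would presumably delegate to eqx.field so that converter and static keep working. The state_dict_key name is taken from the thread, and the Block class and state_dict_key_map helper are invented for illustration.

```python
import dataclasses
import inspect
from typing import Any

# Keyword arguments this Python version's dataclasses.field accepts
# (e.g. kw_only only exists on 3.10+), discovered by introspection.
_FIELD_KWARGS = frozenset(inspect.signature(dataclasses.field).parameters)


def field(**kwargs: Any) -> Any:
    """Hypothetical hax.field sketch: pass known kwargs through to
    dataclasses.field and move unknown ones into the field's metadata."""
    metadata = dict(kwargs.pop("metadata", None) or {})
    passthrough = {}
    for name, value in kwargs.items():
        if name in _FIELD_KWARGS:
            passthrough[name] = value
        else:
            metadata[name] = value  # e.g. state_dict_key="c_attn"
    return dataclasses.field(metadata=metadata, **passthrough)


@dataclasses.dataclass
class Block:
    # Both spellings end up in the same place: field.metadata.
    attn: Any = field(default=None, state_dict_key="c_attn")
    proj: Any = field(default=None, metadata={"state_dict_key": "c_proj"})


def state_dict_key_map(obj: Any) -> dict:
    """Collect {attribute name: state_dict key} from field metadata."""
    return {
        f.name: f.metadata["state_dict_key"]
        for f in dataclasses.fields(obj)
        if "state_dict_key" in f.metadata
    }


print(state_dict_key_map(Block()))  # {'attn': 'c_attn', 'proj': 'c_proj'}
```

One trade-off of this design: a misspelled real keyword (say, defualt=) would silently land in metadata instead of raising a TypeError.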
Ok, so for now should I just go ahead with a "soft" version of hax.field and then we can modify this approach as we see fit later?
Not sure what that means? But my guess is probably :-)
…On Wed, Dec 20, 2023 at 10:49 AM rohan-mehta-1024 wrote:
Ok, so for now should I just go ahead with a "soft" version of hax.field and then we can modify this approach as we see fit later?
Yeah, that was poor wording, sorry (basically I think I'm just going to go the route of explicitly using metadata, and then once everything else has been ported over successfully, I guess we can decide from there, since it should just be a pretty minor stylistic change)! I was also wondering if we should make the following stylistic change. When flattening/unflattening layers, the type signature is
Then you only need to do:
Instead of:
Since the function can infer this. However, for stacking/unstacking layers, the function does not explicitly take in the block to be stacked/unstacked. So even if you have:
You still have to do:
Would it make sense to rewrite the function to have it take in the block to stack/unstack, so that you could do this instead:
It would probably complicate the underlying function, but it would also allow for automatically inferring prefixes from state_dict_key_map. Also, it seems like it would be a little more explicit, since you don't have to reason through the prefixes to find out what is being stacked/unstacked. I was just wondering if this line of thought makes sense to you, or if there are some other reasons I haven't identified why it would be better to keep things as they are?
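To make the stacking discussion concrete, here is a rough sketch of what stacking and unstacking a flat state dict means: per-layer keys such as h.0.attn.weight and h.1.attn.weight are merged into one h.attn.weight array with a new leading layer axis, and split apart again on the way back. The function names echo the thread, but the bodies are a hypothetical NumPy re-implementation, not Levanter's code. Like the current API, they take the already-resolved prefix (here "h"), which is exactly the piece the proposal would infer from the parent's state_dict_key_map.

```python
import re
from typing import Dict

import numpy as np

StateDict = Dict[str, np.ndarray]


def stack_state_dict(state_dict: StateDict, prefix: str) -> StateDict:
    """Merge "{prefix}.{i}.{rest}" keys into "{prefix}.{rest}", stacking the
    per-layer arrays along a new leading axis. Other keys pass through."""
    pattern = re.compile(re.escape(prefix) + r"\.(\d+)\.(.+)")
    per_layer: Dict[str, Dict[int, np.ndarray]] = {}
    out: StateDict = {}
    for key, value in state_dict.items():
        match = pattern.fullmatch(key)
        if match is None:
            out[key] = value
        else:
            layer, rest = int(match.group(1)), match.group(2)
            per_layer.setdefault(rest, {})[layer] = value
    for rest, layers in per_layer.items():
        if sorted(layers) != list(range(len(layers))):
            raise ValueError(f"non-contiguous layer indices for {prefix}.*.{rest}")
        out[f"{prefix}.{rest}"] = np.stack([layers[i] for i in range(len(layers))])
    return out


def unstack_state_dict(state_dict: StateDict, prefix: str) -> StateDict:
    """Inverse of stack_state_dict: split each "{prefix}.{rest}" array along
    axis 0 into "{prefix}.{i}.{rest}". Assumes everything under prefix is stacked."""
    out: StateDict = {}
    for key, value in state_dict.items():
        if key.startswith(prefix + "."):
            rest = key[len(prefix) + 1:]
            for i in range(value.shape[0]):
                out[f"{prefix}.{i}.{rest}"] = value[i]
        else:
            out[key] = value
    return out


sd = {
    "h.0.attn.weight": np.zeros((4, 4)),
    "h.1.attn.weight": np.ones((4, 4)),
    "wte.weight": np.zeros((10, 4)),
}
stacked = stack_state_dict(sd, "h")
print(sorted(stacked))  # ['h.attn.weight', 'wte.weight']
print(stacked["h.attn.weight"].shape)  # (2, 4, 4)
```

Because unstacking takes every key under the prefix, the caller has to know which subtree is the stacked one; that is the information a block-taking signature could supply.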
I'm confused. How does it know what layer it is?
…On Thu, Dec 21, 2023 at 8:30 AM rohan-mehta-1024 wrote:
Yeah, that was poor wording, sorry (basically I think I'm just going to go the route of explicitly using metadata, and then once everything else has been ported over successfully, I guess we can decide from there, since it should just be a pretty minor stylistic change)! I was also wondering if we should make the following stylistic change. When flattening/unflattening layers, the type signature is (prefix, statedict: StateDict, layer: hnn.Linear, out_dims_first_in_dict: Optional[bool]). Since we pass the layer itself in, the function can use the state_dict_key_map to look up any necessary prefixes. So, if you have the following:

    def _state_dict_key_map(self):
        return {'attn': 'c_attn', 'proj': 'c_proj'}

then you only need to do:

    unflatten_linear_layers(prefix, state_dict, self.attn, None)

instead of:

    unflatten_linear_layers(apply_prefix(prefix, 'c_attn'), state_dict, self.attn, None)

since the function can infer this. However, for stacking/unstacking layers, the function does not explicitly take in the block to be stacked/unstacked. So even if you have:

    def _state_dict_key_map(self):
        return {
            "blocks": "h",
            "tok_embedding_table": "wte",
            "pos_embedding_table": "wpe",
        }

you still have to do:

    stacked_params = stack_state_dict(state_dict, prefix=apply_prefix(prefix, "h"))

Would it make sense to rewrite the function to have it take in the block to stack/unstack, so that you could do this instead:

    stacked_params = stack_state_dict(state_dict, self.blocks, prefix=prefix)

It would probably complicate the underlying function, but it would also allow for automatically inferring prefixes from state_dict_key_map. Also, it seems like it would be a little more explicit, since you don't have to reason through the prefixes to find out what is being stacked/unstacked. I was just wondering if this line of thought makes sense to you, or if there are some other reasons I haven't identified why it would be better to keep things as they are?
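The flattening signature quoted above takes an out_dims_first_in_dict flag. As a rough illustration of what flattening a linear layer involves, the sketch below collapses a weight whose input and output sides may each span several axes (say, heads by head size) into the 2-D matrix a PyTorch-style state dict stores, optionally transposing so the output dimension comes first, as in torch.nn.Linear's (out_features, in_features) layout. Everything here is a hypothetical NumPy stand-in: the helper names are invented, the unflattened layout (*in_shape, *out_shape) is an assumption, and the flag is a plain bool rather than Optional[bool], so this says nothing about how None is handled in Levanter.

```python
import math
from typing import Sequence, Tuple

import numpy as np


def flatten_linear_weight(
    weight: np.ndarray,
    in_shape: Sequence[int],
    out_shape: Sequence[int],
    out_dims_first: bool,
) -> np.ndarray:
    """Collapse a (*in_shape, *out_shape) weight into a 2-D matrix.

    out_dims_first=True gives (out, in), the torch.nn.Linear layout;
    False gives (in, out)."""
    expected: Tuple[int, ...] = tuple(in_shape) + tuple(out_shape)
    if weight.shape != expected:
        raise ValueError(f"expected shape {expected}, got {weight.shape}")
    flat = weight.reshape(math.prod(in_shape), math.prod(out_shape))
    return flat.T if out_dims_first else flat


def unflatten_linear_weight(
    flat: np.ndarray,
    in_shape: Sequence[int],
    out_shape: Sequence[int],
    out_dims_first: bool,
) -> np.ndarray:
    """Inverse of flatten_linear_weight."""
    if out_dims_first:
        flat = flat.T
    return flat.reshape(tuple(in_shape) + tuple(out_shape))


# e.g. a projection from Embed=6 to (Heads=2, HeadSize=4)
w = np.arange(6 * 2 * 4, dtype=np.float32).reshape(6, 2, 4)
flat = flatten_linear_weight(w, in_shape=(6,), out_shape=(2, 4), out_dims_first=True)
print(flat.shape)  # (8, 6)
```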
Maybe I'm misunderstanding, but the following code, e.g., in unflatten_linear_layers:
has a use_state_dict_keys argument, which allows it to access the state_dict_key_map. And here is the relevant code from leaf_key_paths, where it does access this mapping:
So I don't understand why you would explicitly have to do:
Oh, well, the convention in all these methods is that they receive the path they're supposed to use. self.c_attn doesn't know that it's supposed to be named self.c_attn; it would need information from the parent (…)
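That convention (a submodule never computes its own key; the parent looks the attribute up in its own key map and passes the resulting path down) can be illustrated with a tiny recursive flattener over plain dataclasses. apply_prefix, to_state_dict, and the toy Linear/Attention classes are hypothetical stand-ins written for this example; only the _state_dict_key_map method name and the attn-to-c_attn mapping come from the thread.

```python
from dataclasses import dataclass, fields, is_dataclass
from typing import Any, Dict, Optional


def apply_prefix(prefix: Optional[str], leaf: Optional[str]) -> Optional[str]:
    """Join two key segments with a dot, treating None as 'no segment'."""
    if prefix is None:
        return leaf
    if leaf is None:
        return prefix
    return f"{prefix}.{leaf}"


def to_state_dict(module: Any, prefix: Optional[str] = None) -> Dict[str, Any]:
    """Flatten nested dataclasses into {dotted key: leaf}.

    Each call only knows the path it was handed. The parent renames its
    children via _state_dict_key_map and hands each child its full path."""
    if not is_dataclass(module):
        return {prefix: module}
    key_map = module._state_dict_key_map() if hasattr(module, "_state_dict_key_map") else {}
    out: Dict[str, Any] = {}
    for f in fields(module):
        child_prefix = apply_prefix(prefix, key_map.get(f.name, f.name))
        out.update(to_state_dict(getattr(module, f.name), child_prefix))
    return out


@dataclass
class Linear:
    weight: Any
    bias: Any


@dataclass
class Attention:
    attn: Linear
    proj: Linear

    def _state_dict_key_map(self) -> Dict[str, str]:
        return {"attn": "c_attn", "proj": "c_proj"}


block = Attention(attn=Linear(weight=1, bias=2), proj=Linear(weight=3, bias=4))
print(to_state_dict(block, prefix="h.0"))
# {'h.0.c_attn.weight': 1, 'h.0.c_attn.bias': 2, 'h.0.c_proj.weight': 3, 'h.0.c_proj.bias': 4}
```

Calling to_state_dict(block.attn) on its own yields plain weight and bias keys with no c_attn segment: the name lives in the parent, which is why these functions take the path as an argument.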
Oh ok, I see, that makes sense now (sorry for the confusion). I think I've basically made all the necessary changes at this point (the only other thing I'm wondering about is what you meant by …)
Oh wow, thank you! I think the buffers thing can be a separate PR/future issue. I agree lighter-weight tests are better here, and I agree round-tripping to/from torch (and Haliax itself) is probably the best route. Maybe a few tests that exercise Haliax-only round-trip functionality (dicts, renames, reorders, whatever), and then a few tests that only run when torch is installed to make sure round-trips work there?
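A minimal sketch of that test split, using only the standard library's unittest: one class exercises a pure-Python rename round trip, and a second class is skipped unless torch is importable. The rename_keys helper and the test names are hypothetical, chosen for illustration; NumPy arrays stand in for Haliax arrays, and the torch test only checks that a dict of tensors survives torch.save/torch.load through an in-memory buffer.

```python
import importlib.util
import io
import unittest
from typing import Any, Dict, Mapping

import numpy as np

# The torch round-trip tests only run when torch is importable.
HAS_TORCH = importlib.util.find_spec("torch") is not None


def rename_keys(state_dict: Mapping[str, Any], key_map: Mapping[str, str]) -> Dict[str, Any]:
    """Hypothetical helper: rename the first dotted segment of each key."""
    out: Dict[str, Any] = {}
    for key, value in state_dict.items():
        head, sep, rest = key.partition(".")
        out[key_map.get(head, head) + sep + rest] = value
    return out


class HaliaxOnlyRoundTrip(unittest.TestCase):
    def test_rename_round_trip(self):
        sd = {"blocks.0.weight": np.arange(4.0), "tok_embedding_table.weight": np.ones(3)}
        key_map = {"blocks": "h", "tok_embedding_table": "wte"}
        inverse = {new: old for old, new in key_map.items()}

        renamed = rename_keys(sd, key_map)
        self.assertEqual(set(renamed), {"h.0.weight", "wte.weight"})

        restored = rename_keys(renamed, inverse)
        self.assertEqual(set(restored), set(sd))
        for key, value in sd.items():
            np.testing.assert_array_equal(restored[key], value)


@unittest.skipUnless(HAS_TORCH, "torch is not installed")
class TorchRoundTrip(unittest.TestCase):
    def test_torch_save_load_round_trip(self):
        import torch

        sd = {"h.0.weight": np.arange(6, dtype=np.float32).reshape(2, 3)}
        buffer = io.BytesIO()
        torch.save({k: torch.from_numpy(v) for k, v in sd.items()}, buffer)
        buffer.seek(0)
        loaded = torch.load(buffer)
        self.assertEqual(set(loaded), set(sd))
        for key, value in sd.items():
            np.testing.assert_array_equal(loaded[key].numpy(), value)
```

Running python -m unittest on a file containing these classes reports the torch class as skipped when torch is absent, so the Haliax-only tests keep a light dependency footprint.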
Currently Levanter has a bunch of machinery to support serializing to/from state_dicts, as well as writing to safetensors. I think it would be good to move that functionality to Haliax.
While doing it, I would like to revamp a few things:
eqx.field