
Minimal working example of safetensors support for hezar #157

Draft
wants to merge 6 commits into base: main

Conversation

Adversarian

Pull Request

Description

The save and load methods of the hezar.models.Model class have been altered to accommodate safetensors support.

Changes

  • Loading safetensors exposed through the keyword argument load_safetensors under Model.load().
  • Saving safetensors exposed through the keyword argument safe_serialization under Model.save().
  • SAFETENSORS added to the Backends enum in constants.py.
  • A minimal working example unittest is available at test_safetensors.py. It serves only to demonstrate the changes and is by no means a comprehensive, production-ready test.

Related Issues

Resolves #153

Additional Comments

Please keep in mind that this is meant to serve as a draft PR and is by no means production-ready. The changes made to the codebase are very crude; the aim was only to show how safetensors support might be incorporated into hezar. Once the architectural details and design decisions regarding how and where this change should be introduced are approved, I can happily edit this PR (or submit a new one) with cleaner code that adheres to this library's standards.

Notes

  • Model.push_to_hub() also requires changes, but for the purposes of this quick prototype it was left unchanged for the time being.

@Adversarian
Author

@arxyzan I think this is what you meant I should be doing 😅. If there's anything further wrong with the PR, do let me know please.

@arxyzan
Member

arxyzan commented Apr 24, 2024

@Adversarian Thanks Arian jan, I tested your code and it works perfectly. As a next step, I will go through the code and apply any necessary changes. In the meantime, can you please test this on other models too? (Note that you can create models from configs on the Hub instead of downloading the weights from there, or you can use Google Colab, which downloads models in seconds.)
Overall, I think your changes adhere perfectly to the standards of this library. The most important thing here is to design and implement a flawless migration pipeline.

As a bonus, this change will also enable the model download counters on the Hub! :))) (see #56)

@Adversarian
Author

@arxyzan My pleasure! Sure, I can try it on other models; I'll add them to test_safetensors.py and update the PR. I'm also psyched about seeing counters go up on HF now! Let me know if there's anything else I can do to help with this.

@Adversarian
Author

@arxyzan I added a more comprehensive test suite for safetensors and tried it out on Colab since I don't have access to premium internet at the moment. It looks like we're failing on the mask-filling and text-generation tasks on RoBERTa and GPT-2 respectively, while the remaining two tasks are passing with their respective models.

I'll try to investigate myself soon but in the meantime, here's the notebook with the tests performed if you'd like to take a look yourself.

@arxyzan
Member

arxyzan commented Apr 28, 2024

@Adversarian Thanks Arian, I think that's exactly the case that failed for me too. Back then I didn't test other models since I thought my conversion code was buggy.
Can you please open an issue on this if you can?

@Adversarian
Author

@arxyzan No problem at all Aryan jan. I don't think this is an issue with safetensors or huggingface, unless you meant for me to open an issue here, on hezar. I've been swamped again, so I haven't had time to investigate this fully, but judging by HF's save_pretrained method, saving a model in safetensors format isn't as trivial as I made it out to be in my draft PR. I may have skipped a ton of preprocessing which we would have to add back into the save method.

I will try to take a look at this again as soon as I'm able to but is there really no way to subclass HF's models for hezar.models.Model so that we can just directly use save_pretrained? Is there something that's holding you back from doing this? I'm only asking because that seems to be the cleanest solution if at all possible and the burden of any further development on this specific matter would be offloaded to HF with their multitude of smart (and paid!) developers.

Again, sorry for the late response and looking forward to hearing your thoughts on the subject.

@arxyzan
Member

arxyzan commented Apr 30, 2024

@Adversarian Thanks for putting the time into it man.
The fact is, subclassing HF models (transformers.PreTrainedModel) for the Model class in Hezar is overkill and adds a lot of redundant and unnecessary add-ins to the hezar.models.Model class. Our primary goal is to keep Hezar simple, meaning that, for example, subclassing hezar.models.Model should be really similar to subclassing torch.nn.Module.
Plus, HF's PreTrainedModel has additional behavior injected into the models that is specific to the Transformers architecture, like tied layers, precision control, etc., which makes it unreasonable to use it as the base class for non-Transformer models.
I accept that most of Hezar's current models rely on Transformers and use them directly inside the model modules, but that does not add boilerplate code to the base model class.
Idk if I explained this well but I hope you see what I mean.

@Adversarian
Author

@arxyzan I understand, thanks for the explanation. We can make it happen. As I mentioned before, I'm a tiny bit swamped at the moment, but I'll get back to this as soon as I can.

Successfully merging this pull request may close these issues.

Weight conversion to safetensors format