
[RFC] Kaggle Model Hub Integration for Torchtune #1852

Open
KeijiBranshi opened this issue Oct 16, 2024 · 8 comments
Labels
rfc Request for comments

Comments

@KeijiBranshi

Authors: @KeijiBranshi @rosbo @mrisdal @neshdev at.bflynn

Summary

This RFC proposes extending torchtune to support loading pre-trained and fine-tuned model weights directly from the Kaggle Model Hub. This integration aims to expand the accessibility of models within torchtune and contribute to the adoption of both PyTorch/torchtune and the Kaggle Model Hub by streamlining the experience for Kaggle users.

Motivation

This proposal aligns with PyTorch's objective of integrating with partner platforms to increase Torchtune adoption (KR 3.2). By adding support for Kaggle, we can:

  • Increase Model Accessibility: Provide Torchtune users with a wider selection of pre-trained models and community-shared fine-tuned weights, fostering a more diverse model ecosystem.

  • Strengthen Community Engagement: Foster collaboration between the PyTorch and Kaggle communities, leading to increased contributions from both sides.

Other Potential Benefits

  • Streamline Kaggle Competition Workflow: Enable seamless torchtune model loading within Kaggle competition notebooks, which have internet access restrictions. This eliminates the need for workarounds currently required to use PyTorch models in competitions.
  • Deeper Kaggle Notebook Integration: Using kagglehub allows for better integration with Kaggle Notebooks, enabling features like automatic model detection and UI enhancements within the notebook environment.

Prior Art

Similar functionality for loading models from various hubs exists in other deep learning libraries. Keras, for example, provides a unified mechanism for loading models from both Hugging Face and Kaggle using URI schemes (see the Keras documentation).
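For illustration, KerasNLP exposes this through from_preset, where the handle's prefix selects the hub (a rough sketch based on the Keras documentation; the preset handles below are placeholders, not verified model IDs):

import keras_nlp

# Load from Kaggle Models (Kaggle is the default hub for Keras presets)
gemma = keras_nlp.models.CausalLM.from_preset("kaggle://keras/gemma/keras/gemma_2b_en")

# Load an equivalent preset from the Hugging Face Hub
gemma_hf = keras_nlp.models.CausalLM.from_preset("hf://google/gemma-2b-keras")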

Proposed Implementation

We propose extending torchtune's model loading mechanism to recognize and handle Kaggle Model Hub URIs. This will involve the following:

  1. URI Scheme Recognition: Torchtune can be updated to recognize model URIs using the kaggle:// scheme. While Hugging Face will remain the default source for models, we could also add support for explicit Hugging Face URIs using the hf:// scheme for increased clarity.

  2. Kaggle Hub Integration: Leverage the kagglehub Python library to handle the download and upload of model weights to and from the Kaggle Model Hub.

Using the above, we would modify torchtune's model loading logic to (a rough sketch follows the list below):

  1. Detect the URI scheme (kaggle:// or hf://).
  2. Utilize kagglehub for downloading weights from Kaggle when a kaggle:// URI is provided.
  3. Maintain the existing Hugging Face integration for models without a URI scheme or those explicitly using hf://.
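
For illustration, the scheme detection could look roughly like this (a sketch only; download_from_kaggle and download_from_hf are hypothetical helpers, not existing torchtune functions):

def resolve_source(model_uri: str) -> tuple[str, str]:
    """Split an optional scheme off a model handle.

    "kaggle://metaresearch/llama-3.2/pyTorch/3b" -> ("kaggle", "metaresearch/llama-3.2/pyTorch/3b")
    "hf://meta-llama/Llama-3.2-3B"               -> ("hf", "meta-llama/Llama-3.2-3B")
    "meta-llama/Llama-3.2-3B"                    -> ("hf", "meta-llama/Llama-3.2-3B")  # default source
    """
    if "://" in model_uri:
        scheme, handle = model_uri.split("://", 1)
        return scheme, handle
    return "hf", model_uri


def download(model_uri: str, output_dir: str | None = None) -> str:
    scheme, handle = resolve_source(model_uri)
    if scheme == "kaggle":
        return download_from_kaggle(handle, output_dir)  # would wrap kagglehub
    if scheme == "hf":
        return download_from_hf(handle, output_dir)      # existing Hugging Face path
    raise ValueError(f"Unsupported model URI scheme: {scheme}://")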

Example Usage:

Users will be able to download a model from Kaggle using a command like:

tune download kaggle://metaresearch/llama-3.2/pyTorch/3b \
--output-dir /tmp/llama-3.2-3b \
--kaggle-username <KAGGLE_USERNAME> \
--kaggle-api-key <KAGGLE_API_KEY>
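
Under the hood, the kaggle:// branch could forward those credentials to kagglehub, which authenticates via the KAGGLE_USERNAME/KAGGLE_KEY environment variables (or kaggle.json). A minimal sketch, assuming the hypothetical download_from_kaggle helper mentioned above:

import os

import kagglehub


def download_from_kaggle(handle: str, username: str | None = None, api_key: str | None = None) -> str:
    # Forward CLI-provided credentials to kagglehub via its environment variables.
    if username and api_key:
        os.environ["KAGGLE_USERNAME"] = username
        os.environ["KAGGLE_KEY"] = api_key
    # kagglehub downloads into its local cache and returns the directory path,
    # e.g. for handle "metaresearch/llama-3.2/pyTorch/3b".
    return kagglehub.model_download(handle)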

Considerations

  • Backward Compatibility: This change should not affect existing functionality with Hugging Face models.

  • Dependencies: Torchtune will need to add kagglehub as a dependency. Would the introduction of fsspec for more general URI scheme handling be a desirable enhancement, even if it adds complexity?

  • --output-dir Argument: Since kagglehub utilizes a default cache folder, should the --output-dir argument be optional or required for Kaggle models? What are the preferred behaviors and potential implications of each approach? (A sketch of one possible behavior follows this list.)

  • Documentation: We are willing to contribute to the torchtune documentation to include instructions and examples for using Kaggle Model Hub URIs. Guidance on the documentation update process and procedures would be greatly appreciated.

  • Testing: Develop comprehensive tests to ensure the correct functionality of Kaggle model loading and compatibility with existing features.
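
On the --output-dir question above, one possible behavior (an assumption for discussion, not a committed design) is to treat kagglehub's cache as the default and only copy files out when the flag is given:

import shutil
from pathlib import Path

import kagglehub


def fetch_kaggle_model(handle: str, output_dir: str | None = None) -> Path:
    cache_path = Path(kagglehub.model_download(handle))
    if output_dir is None:
        return cache_path  # --output-dir omitted: keep the files in kagglehub's cache
    dest = Path(output_dir)
    shutil.copytree(cache_path, dest, dirs_exist_ok=True)
    return dest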

Call for Feedback

We’d love feedback from the PyTorch community on this proposal. Please share your thoughts, suggestions, and any potential concerns you may have.

Happy modeling,
The Kaggle Team

joecummings added the rfc (Request for comments) label on Oct 16, 2024
@joecummings
Contributor

Thanks for the RFC @KeijiBranshi - we're really excited about providing another source for model integration. I'm still closely considering some of the implementation details, so I'll respond to those individually.

@joecummings
Contributor

Weight formats:

Kaggle hosts model checkpoints in formats intended for different modeling libraries. These are often differentiated by the names "PyTorch" or "transformers" (Hugging Face). Until now, our only model sources have been the Hugging Face Hub or Meta directly, which means we can guarantee that our models load checkpoints in the "transformers" format or the Meta Llama format. We make no guarantees about any other checkpoint formats.

At first glance, what this practically means is that if a user tries to load in e.g. Gemma 2 using the following path: google/gemma-2-2b-jpn-it/flax, we should throw an error that "flax" is not supported. Where this gets potentially a little trickier is with respect to the "PyTorch" format. For Llama models, this format designates the aforementioned Meta Llama format that we already support. But for e.g. Gemma models, this is the native format released by Google, which means we would have to write additional logic to convert into our torchtune model format.

On a longer time scale, I'd like to be able to support loading in checkpoints in both the format they were released in + the transformers format, but we don't have the bandwidth to do that right now. So practically, we'd want to throw an error if someone tries to load a model in any format other than "transformers" unless the organization is "metaresearch", in which case we would also support "PyTorch".
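
A rough sketch of that gating, purely for illustration (the handle parsing and error message are assumptions, not existing torchtune code):

def check_kaggle_format(handle: str) -> None:
    # Kaggle model handles look like <owner>/<model>/<framework>/<variation>,
    # e.g. "metaresearch/llama-3.2/pyTorch/3b".
    owner, _model, framework, _variation = handle.split("/")
    if framework.lower() == "transformers":
        return  # always supported
    if framework.lower() == "pytorch" and owner == "metaresearch":
        return  # Meta Llama format, already supported
    raise ValueError(
        f"Checkpoint format '{framework}' in '{handle}' is not supported; "
        "expected 'transformers' (or 'PyTorch' for metaresearch models)."
    )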

Please let me know if I'm missing some details here or if this is too restrictive.

@joecummings
Contributor

joecummings commented Oct 16, 2024

URI Scheme Recognition:

I think I have a slight preference for a UX similar to llama-stack wherein the source is specified as a param e.g.:

tune download metaresearch/llama-3.2/pyTorch/3b --source kaggle

The reason for this is that the "path" for models on the Hugging Face Hub is very recognizable as the entry point to downloading any of their models e.g.

AutoModel.from_pretrained("openai/whisper-large-v3-turbo")

or

huggingface-cli download openai/whisper-large-v3-turbo

Attaching an "hf://" prefix would slightly obscure that recognition. In addition, without a prefix there would be no need for more complex URI scheme handling.

Open to thoughts.

@joecummings
Contributor

--output-dir Argument:

We're considering changes to our current download process + checkpointing API that would download the model to the source's cache by default, so it should be no problem to make --output-dir optional.

@KeijiBranshi
Author

Thanks for the thoughtful feedback! We appreciate you taking the time to review our proposal and provide your insights.

@KeijiBranshi
Author

RE: Weight Formats #1852 (comment)

Agreed that focusing on valid PyTorch and Transformers formats is a good first step. Filtering out incompatible frameworks (e.g. flax) with string manipulation seems straightforward. But restricting "PyTorch"-format downloads to just the metaresearch models might yield an awkward experience: when a user publishes a torchtune fine-tuned model to Kaggle in that format, they would not be able to download their own model later using torchtune.

Should we instead consider doing some post-download validation? In other words, download the model payload, but have torchtune check that the files are properly formatted before proceeding?
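
One way such a post-download check could look (a sketch only; the file-name heuristics are assumptions, not torchtune's actual checkpointer logic):

from pathlib import Path


def looks_like_supported_checkpoint(model_dir: str) -> bool:
    p = Path(model_dir)
    # "transformers" format: config.json plus safetensors/bin weight files
    is_hf = (p / "config.json").exists() and (
        any(p.glob("*.safetensors")) or any(p.glob("pytorch_model*.bin"))
    )
    # Meta Llama format: params.json plus consolidated checkpoint shards
    is_meta = (p / "params.json").exists() and any(p.glob("consolidated.*.pth"))
    return is_hf or is_meta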

Open to discussing these options and finding the best approach that balances initial simplicity with long-term flexibility.

@KeijiBranshi
Author

KeijiBranshi commented Oct 17, 2024

RE: URI Scheme Recognition #1852 (comment)

We're happy to defer to your expertise on the use of the --source parameter, especially if it helps make the UX more consistent across similar libraries.

FWIW, HuggingFace supports the hf:// scheme in some of its own tooling (see documentation). But I understand that it’s not a universal concept for model URIs.
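
For reference, that scheme comes from huggingface_hub's fsspec integration (HfFileSystem); an illustrative example, with a placeholder repo path:

import fsspec

# huggingface_hub registers the "hf" protocol with fsspec, so hf:// paths
# work anywhere fsspec-style URLs are accepted.
fs = fsspec.filesystem("hf")
print(fs.ls("datasets/imdb"))  # illustrative dataset repo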

@KeijiBranshi
Author

RE: --output-dir Argument #1852 (comment)

Thanks for sharing those considerations around --output-dir. Currently, kagglehub doesn’t allow users to change the download directory, but we have some related requests on our side to support it.
