Skip to content

Commit

Permalink
Check that HF repo exists
Browse files Browse the repository at this point in the history
  • Loading branch information
gabrielmbmb committed Sep 28, 2024
1 parent 78da0e5 commit bea56e2
Show file tree
Hide file tree
Showing 5 changed files with 37 additions and 26 deletions.
16 changes: 2 additions & 14 deletions candle-holder-models/src/from_pretrained.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ use hf_hub::{
use serde::{Deserialize, Serialize};

use candle_holder::{
get_repo_api,
utils::from_pretrained::{load_model_config, FromPretrainedParameters, MODEL_CONFIG_FILE},
Error, Result,
};
Expand Down Expand Up @@ -193,20 +194,7 @@ pub fn from_pretrained<I: AsRef<str>>(
repo_id: I,
params: Option<FromPretrainedParameters>,
) -> Result<ModelInfo> {
let params = params.unwrap_or_default();

let repo = Repo::with_revision(
repo_id.as_ref().to_string(),
RepoType::Model,
params.revision,
);

let mut builder = ApiBuilder::new();
if let Some(token) = params.auth_token {
builder = builder.with_token(Some(token));
}
let api = builder.build()?;
let api = api.repo(repo);
let api = get_repo_api(repo_id.as_ref(), params)?;

// Get the model configuration from `config.json`
let config = match api.get(MODEL_CONFIG_FILE) {
Expand Down
12 changes: 2 additions & 10 deletions candle-holder-tokenizers/src/from_pretrained.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use std::{collections::HashMap, fs};

use candle_holder::{
get_repo_api,
utils::from_pretrained::{load_model_config, FromPretrainedParameters, MODEL_CONFIG_FILE},
Result,
};
Expand Down Expand Up @@ -414,16 +415,7 @@ pub fn from_pretrained<I: AsRef<str>>(
repo_id: I,
params: Option<FromPretrainedParameters>,
) -> Result<TokenizerInfo> {
let params = params.unwrap_or_default();

let repo = Repo::with_revision(
repo_id.as_ref().to_string(),
RepoType::Model,
params.revision,
);

let api = Api::new()?;
let api = api.repo(repo);
let api = get_repo_api(repo_id.as_ref(), params)?;

let tokenizer_config = match api.get(TOKENIZER_CONFIG_FILE) {
Ok(tokenizer_config_file) => {
Expand Down
2 changes: 2 additions & 0 deletions candle-holder/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ pub enum Error {
// -----------------------------------
// From pretrained errors
// -----------------------------------
#[error("Repository '{0}' not found.")]
RepositoryNotFound(String),
#[error("Model '{0}' is not implemented. Create a new issue in 'https://github.com/gabrielmbmb/candle-holder' to request the implementation.")]
ModelNotImplemented(String),
#[error("Tokenizer '{0}' is not implemented. Create a new issue in 'https://github.com/gabrielmbmb/candle-holder' to request the implementation.")]
Expand Down
2 changes: 1 addition & 1 deletion candle-holder/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ pub mod error;
pub mod utils;

pub use error::Error;
pub use utils::from_pretrained::FromPretrainedParameters;
pub use utils::from_pretrained::{get_repo_api, FromPretrainedParameters};

/// A type alias for `Result<T, Error>` for the `candle-holder` crate.
pub type Result<T> = std::result::Result<T, Error>;
31 changes: 30 additions & 1 deletion candle-holder/src/utils/from_pretrained.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,9 @@
use crate::Result;
use hf_hub::{
api::sync::{Api, ApiRepo},
Repo, RepoType,
};

use crate::{Error, Result};
use std::{collections::HashMap, fs};

pub const MODEL_CONFIG_FILE: &str = "config.json";
Expand All @@ -20,6 +25,30 @@ impl Default for FromPretrainedParameters {
}
}

/// Gets a [`ApiRepo`] instance from the provided repository ID using the provided parameters. It
/// will check if the repository exists.
///
/// # Arguments
///
/// * `repo_id` - The repository ID.
/// * `params` - The parameters to use when creating the API instance.
///
/// # Returns
///
/// The API instance.
pub fn get_repo_api(repo_id: &str, params: Option<FromPretrainedParameters>) -> Result<ApiRepo> {
let params = params.unwrap_or_default();
let repo = Repo::with_revision(repo_id.to_string(), RepoType::Model, params.revision);
let api = Api::new()?.repo(repo);

// Check if the repository exists
if api.info().is_err() {
return Err(Error::RepositoryNotFound(repo_id.to_string()));
}

Ok(api)
}

/// Loads the model configuration from the provided file path.
///
/// # Arguments
Expand Down

0 comments on commit bea56e2

Please sign in to comment.