diff --git a/candle-examples/examples/llama/main.rs b/candle-examples/examples/llama/main.rs index 93f1e50825..7a555b00af 100644 --- a/candle-examples/examples/llama/main.rs +++ b/candle-examples/examples/llama/main.rs @@ -35,6 +35,10 @@ enum Which { V31, V3Instruct, V31Instruct, + V32_1b, + V32_1bInstruct, + V32_3b, + V32_3bInstruct, #[value(name = "solar-10.7b")] Solar10_7B, #[value(name = "tiny-llama-1.1b-chat")] @@ -137,6 +141,10 @@ fn main() -> Result<()> { Which::V3Instruct => "meta-llama/Meta-Llama-3-8B-Instruct".to_string(), Which::V31 => "meta-llama/Meta-Llama-3.1-8B".to_string(), Which::V31Instruct => "meta-llama/Meta-Llama-3.1-8B-Instruct".to_string(), + Which::V32_1b => "meta-llama/Llama-3.2-1B".to_string(), + Which::V32_1bInstruct => "meta-llama/Llama-3.2-1B-Instruct".to_string(), + Which::V32_3b => "meta-llama/Llama-3.2-3B".to_string(), + Which::V32_3bInstruct => "meta-llama/Llama-3.2-3B-Instruct".to_string(), Which::Solar10_7B => "upstage/SOLAR-10.7B-v1.0".to_string(), Which::TinyLlama1_1BChat => "TinyLlama/TinyLlama-1.1B-Chat-v1.0".to_string(), }); @@ -156,10 +164,14 @@ fn main() -> Result<()> { | Which::V3Instruct | Which::V31 | Which::V31Instruct + | Which::V32_3b + | Which::V32_3bInstruct | Which::Solar10_7B => { candle_examples::hub_load_safetensors(&api, "model.safetensors.index.json")? } - Which::TinyLlama1_1BChat => vec![api.get("model.safetensors")?], + Which::V32_1b | Which::V32_1bInstruct | Which::TinyLlama1_1BChat => { + vec![api.get("model.safetensors")?] + } }; let cache = model::Cache::new(!args.no_kv_cache, dtype, &config, &device)?;