From 377ebf3664f0f3f47462bda05171d38d0191808c Mon Sep 17 00:00:00 2001 From: Martin Evans Date: Sat, 27 Apr 2024 23:31:07 +0100 Subject: [PATCH] - Added `LoadFromFileAsync` method for `LLavaWeights` - Fixed checking for invalid handles in `clip_model_load` --- .../Examples/LlavaInteractiveModeExecute.cs | 4 ++-- LLama/LLavaWeights.cs | 19 ++++++++++++++++--- LLama/Native/SafeLlavaModelHandle.cs | 7 +++++-- 3 files changed, 23 insertions(+), 7 deletions(-) diff --git a/LLama.Examples/Examples/LlavaInteractiveModeExecute.cs b/LLama.Examples/Examples/LlavaInteractiveModeExecute.cs index 170bab0c7..89b4ae41a 100644 --- a/LLama.Examples/Examples/LlavaInteractiveModeExecute.cs +++ b/LLama.Examples/Examples/LlavaInteractiveModeExecute.cs @@ -24,9 +24,9 @@ public static async Task Run() using var context = model.CreateContext(parameters); // Llava Init - using var clipModel = LLavaWeights.LoadFromFile(multiModalProj); + using var clipModel = await LLavaWeights.LoadFromFileAsync(multiModalProj); - var ex = new InteractiveExecutor(context, clipModel ); + var ex = new InteractiveExecutor(context, clipModel); Console.ForegroundColor = ConsoleColor.Yellow; Console.WriteLine("The executor has been enabled. In this example, the prompt is printed, the maximum tokens is set to {0} and the context size is {1}.", maxTokens, parameters.ContextSize ); diff --git a/LLama/LLavaWeights.cs b/LLama/LLavaWeights.cs index ff8959a50..9594dcdbb 100644 --- a/LLama/LLavaWeights.cs +++ b/LLama/LLavaWeights.cs @@ -1,5 +1,7 @@ using System; +using System.Threading; +using System.Threading.Tasks; using LLama.Native; namespace LLama; @@ -13,9 +15,9 @@ public sealed class LLavaWeights : IDisposable /// The native handle, which is used in the native APIs /// /// Be careful how you use this! - public SafeLlavaModelHandle NativeHandle { get; } - - internal LLavaWeights(SafeLlavaModelHandle weights) + public SafeLlavaModelHandle NativeHandle { get; } + + private LLavaWeights(SafeLlavaModelHandle weights) { NativeHandle = weights; } @@ -31,6 +33,17 @@ public static LLavaWeights LoadFromFile(string mmProject) return new LLavaWeights(weights); } + /// + /// Load weights into memory + /// + /// path to the "mmproj" model file + /// + /// + public static Task LoadFromFileAsync(string mmProject, CancellationToken token = default) + { + return Task.Run(() => LoadFromFile(mmProject), token); + } + /// /// Create the Image Embeddings from the bytes of an image. /// diff --git a/LLama/Native/SafeLlavaModelHandle.cs b/LLama/Native/SafeLlavaModelHandle.cs index 2edb7aee7..fd898b536 100644 --- a/LLama/Native/SafeLlavaModelHandle.cs +++ b/LLama/Native/SafeLlavaModelHandle.cs @@ -39,8 +39,11 @@ public static SafeLlavaModelHandle LoadFromFile(string modelPath, int verbosity if (!fs.CanRead) throw new InvalidOperationException($"Llava MMP Model file '{modelPath}' is not readable"); - return clip_model_load(modelPath, verbosity) - ?? throw new LoadWeightsFailedException(modelPath); + var handle = clip_model_load(modelPath, verbosity); + if (handle.IsInvalid) + throw new LoadWeightsFailedException(modelPath); + + return handle; } ///