Skip to content

Commit

Permalink
- Added LoadFromFileAsync method for LLavaWeights
Browse files Browse the repository at this point in the history
 - Fixed checking for invalid handles in `clip_model_load`
  • Loading branch information
martindevans committed Apr 27, 2024
1 parent 84bb5a3 commit 377ebf3
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 7 deletions.
4 changes: 2 additions & 2 deletions LLama.Examples/Examples/LlavaInteractiveModeExecute.cs
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,9 @@ public static async Task Run()
using var context = model.CreateContext(parameters);

// Llava Init
using var clipModel = LLavaWeights.LoadFromFile(multiModalProj);
using var clipModel = await LLavaWeights.LoadFromFileAsync(multiModalProj);

var ex = new InteractiveExecutor(context, clipModel );
var ex = new InteractiveExecutor(context, clipModel);

Console.ForegroundColor = ConsoleColor.Yellow;
Console.WriteLine("The executor has been enabled. In this example, the prompt is printed, the maximum tokens is set to {0} and the context size is {1}.", maxTokens, parameters.ContextSize );
Expand Down
19 changes: 16 additions & 3 deletions LLama/LLavaWeights.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@

using System;
using System.Threading;
using System.Threading.Tasks;
using LLama.Native;

namespace LLama;
Expand All @@ -13,9 +15,9 @@ public sealed class LLavaWeights : IDisposable
/// The native handle, which is used in the native APIs
/// </summary>
/// <remarks>Be careful how you use this!</remarks>
public SafeLlavaModelHandle NativeHandle { get; }
internal LLavaWeights(SafeLlavaModelHandle weights)
public SafeLlavaModelHandle NativeHandle { get; }

/// <summary>
/// Construct a new <see cref="LLavaWeights"/> wrapping the given native handle.
/// </summary>
/// <param name="weights">The native llava (clip) model handle to expose via <see cref="NativeHandle"/>.
/// NOTE(review): presumably this instance takes ownership and disposes the handle — confirm in Dispose.</param>
private LLavaWeights(SafeLlavaModelHandle weights)
{
    NativeHandle = weights;
}
Expand All @@ -31,6 +33,17 @@ public static LLavaWeights LoadFromFile(string mmProject)
return new LLavaWeights(weights);
}

/// <summary>
/// Load weights into memory, asynchronously.
/// </summary>
/// <param name="mmProject">path to the "mmproj" model file</param>
/// <param name="token">Token which can cancel the load before it begins. Once the native
/// load has started it runs to completion; the token cannot interrupt it.</param>
/// <returns>A task which completes with the loaded <see cref="LLavaWeights"/>.</returns>
/// <exception cref="LoadWeightsFailedException">Thrown if loading the model weights fails.</exception>
public static Task<LLavaWeights> LoadFromFileAsync(string mmProject, CancellationToken token = default)
{
    // The native loader is synchronous and blocking, so offload it to a
    // thread-pool thread rather than blocking the caller.
    return Task.Run(() => LoadFromFile(mmProject), token);
}

/// <summary>
/// Create the Image Embeddings from the bytes of an image.
/// </summary>
Expand Down
7 changes: 5 additions & 2 deletions LLama/Native/SafeLlavaModelHandle.cs
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,11 @@ public static SafeLlavaModelHandle LoadFromFile(string modelPath, int verbosity
if (!fs.CanRead)
throw new InvalidOperationException($"Llava MMP Model file '{modelPath}' is not readable");

return clip_model_load(modelPath, verbosity)
?? throw new LoadWeightsFailedException(modelPath);
var handle = clip_model_load(modelPath, verbosity);
if (handle.IsInvalid)
throw new LoadWeightsFailedException(modelPath);

return handle;
}

/// <summary>
Expand Down

0 comments on commit 377ebf3

Please sign in to comment.