Merge pull request #402 from martindevans/safer_model_handle_creation
Safer Model Handle Creation
martindevans authored Jan 2, 2024
2 parents a1a8461 + 4e5e994 commit 9b9bcc0
Showing 2 changed files with 27 additions and 37 deletions.
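
The change's core pattern is marshalling the native model pointer as a SafeHandle subclass straight from the P/Invoke declaration, instead of returning a raw IntPtr and wrapping it afterwards. A minimal sketch of that pattern, with illustrative names (ExampleModelHandle, example_load_model, and example_free_model are not LLamaSharp identifiers):

using System;
using System.Runtime.InteropServices;

// Sketch only: illustrates SafeHandle return marshalling, not LLamaSharp's real types.
class ExampleModelHandle : SafeHandle
{
    // The interop marshaller creates the handle through this constructor and
    // assigns the raw pointer before any managed code can observe it.
    public ExampleModelHandle() : base(IntPtr.Zero, ownsHandle: true) { }

    public override bool IsInvalid => handle == IntPtr.Zero;

    // Guaranteed by the runtime to run exactly once, even if an exception
    // interrupts code that would otherwise have freed a raw IntPtr.
    protected override bool ReleaseHandle()
    {
        example_free_model(handle);
        return true;
    }

    // Returning the SafeHandle type rather than IntPtr is what closes the
    // leak window between the native call and handle construction.
    [DllImport("example", CallingConvention = CallingConvention.Cdecl)]
    public static extern ExampleModelHandle example_load_model(string path);

    [DllImport("example", CallingConvention = CallingConvention.Cdecl)]
    private static extern void example_free_model(IntPtr model);
}

The diff below applies this shape to llama_load_model_from_file and SafeLlamaModelHandle.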
27 changes: 17 additions & 10 deletions LLama/Native/NativeApi.cs
@@ -71,15 +71,13 @@ public static partial class NativeApi
         public static extern bool llama_mlock_supported();

         /// <summary>
-        /// Various functions for loading a ggml llama model.
-        /// Allocate (almost) all memory needed for the model.
-        /// Return NULL on failure
+        /// Load all of the weights of a model into memory.
         /// </summary>
         /// <param name="path_model"></param>
         /// <param name="params"></param>
-        /// <returns></returns>
+        /// <returns>The loaded model, or null on failure.</returns>
         [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
-        public static extern IntPtr llama_load_model_from_file(string path_model, LLamaModelParams @params);
+        public static extern SafeLlamaModelHandle llama_load_model_from_file(string path_model, LLamaModelParams @params);

         /// <summary>
         /// Create a new llama_context with the given model.
@@ -92,12 +90,11 @@ public static partial class NativeApi
         public static extern IntPtr llama_new_context_with_model(SafeLlamaModelHandle model, LLamaContextParams @params);

         /// <summary>
-        /// not great API - very likely to change.
         /// Initialize the llama + ggml backend
         /// Call once at the start of the program
         /// </summary>
         [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
-        public static extern void llama_backend_init(bool numa);
+        private static extern void llama_backend_init(bool numa);

         /// <summary>
         /// Frees all allocated memory in the given llama_context
@@ -510,10 +507,20 @@ public static int llama_model_meta_val_str_by_index(SafeLlamaModelHandle model,
         /// <param name="model"></param>
         /// <param name="llamaToken"></param>
         /// <param name="buffer">buffer to write string into</param>
-        /// <param name="length">size of the buffer</param>
         /// <returns>The length written, or if the buffer is too small a negative that indicates the length required</returns>
-        [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
-        public static extern unsafe int llama_token_to_piece(SafeLlamaModelHandle model, int llamaToken, byte* buffer, int length);
+        public static int llama_token_to_piece(SafeLlamaModelHandle model, llama_token llamaToken, Span<byte> buffer)
+        {
+            unsafe
+            {
+                fixed (byte* bufferPtr = buffer)
+                {
+                    return llama_token_to_piece_native(model, llamaToken, bufferPtr, buffer.Length);
+                }
+            }
+
+            [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl, EntryPoint = "llama_token_to_piece")]
+            static extern unsafe int llama_token_to_piece_native(SafeLlamaModelHandle model, llama_token llamaToken, byte* buffer, int length);
+        }

         /// <summary>
         /// Convert text into tokens
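
A hedged usage sketch of the new Span-based wrapper, assuming an already-loaded model handle and a token id (`model` and `token` are placeholders); per the doc comment above, a negative return reports the buffer size required:

// Sketch only: `model` is a loaded SafeLlamaModelHandle, `token` an int token id.
Span<byte> buffer = stackalloc byte[16];
var written = NativeApi.llama_token_to_piece(model, token, buffer);
if (written < 0)
{
    // Buffer too small: the magnitude of the result is the required size.
    buffer = new byte[-written];
    written = NativeApi.llama_token_to_piece(model, token, buffer);
}
var piece = System.Text.Encoding.UTF8.GetString(buffer.Slice(0, written));

Keeping the raw-pointer import as a private local function (re-pointed at the native symbol via EntryPoint) means callers can no longer pass a mismatched buffer/length pair.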
37 changes: 10 additions & 27 deletions LLama/Native/SafeLlamaModelHandle.cs
@@ -16,44 +16,33 @@ public sealed class SafeLlamaModelHandle
         /// <summary>
         /// Total number of tokens in vocabulary of this model
         /// </summary>
-        public int VocabCount { get; }
+        public int VocabCount => NativeApi.llama_n_vocab(this);

         /// <summary>
         /// Total number of tokens in the context
         /// </summary>
-        public int ContextSize { get; }
+        public int ContextSize => NativeApi.llama_n_ctx_train(this);

         /// <summary>
         /// Dimension of embedding vectors
         /// </summary>
-        public int EmbeddingSize { get; }
+        public int EmbeddingSize => NativeApi.llama_n_embd(this);

         /// <summary>
         /// Get the size of this model in bytes
         /// </summary>
-        public ulong SizeInBytes { get; }
+        public ulong SizeInBytes => NativeApi.llama_model_size(this);

         /// <summary>
         /// Get the number of parameters in this model
         /// </summary>
-        public ulong ParameterCount { get; }
+        public ulong ParameterCount => NativeApi.llama_model_n_params(this);

         /// <summary>
         /// Get the number of metadata key/value pairs
         /// </summary>
         /// <returns></returns>
-        public int MetadataCount { get; }
-
-        internal SafeLlamaModelHandle(IntPtr handle)
-            : base(handle)
-        {
-            VocabCount = NativeApi.llama_n_vocab(this);
-            ContextSize = NativeApi.llama_n_ctx_train(this);
-            EmbeddingSize = NativeApi.llama_n_embd(this);
-            SizeInBytes = NativeApi.llama_model_size(this);
-            ParameterCount = NativeApi.llama_model_n_params(this);
-            MetadataCount = NativeApi.llama_model_meta_count(this);
-        }
+        public int MetadataCount => NativeApi.llama_model_meta_count(this);

         /// <inheritdoc />
         protected override bool ReleaseHandle()
@@ -73,10 +62,10 @@ protected override bool ReleaseHandle()
         public static SafeLlamaModelHandle LoadFromFile(string modelPath, LLamaModelParams lparams)
         {
             var model_ptr = NativeApi.llama_load_model_from_file(modelPath, lparams);
-            if (model_ptr == IntPtr.Zero)
+            if (model_ptr == null)
                 throw new RuntimeError($"Failed to load model {modelPath}.");

-            return new SafeLlamaModelHandle(model_ptr);
+            return model_ptr;
         }

         #region LoRA
@@ -114,14 +103,8 @@ public void ApplyLoraFromFile(string lora, float scale, string? modelBase = null
         /// <returns>The size of this token. **nothing will be written** if this is larger than `dest`</returns>
         public int TokenToSpan(int llama_token, Span<byte> dest)
         {
-            unsafe
-            {
-                fixed (byte* destPtr = dest)
-                {
-                    var length = NativeApi.llama_token_to_piece(this, llama_token, destPtr, dest.Length);
-                    return Math.Abs(length);
-                }
-            }
+            var length = NativeApi.llama_token_to_piece(this, llama_token, dest);
+            return Math.Abs(length);
         }

         /// <summary>
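
Taken together, loading and token decoding now run through the safe handle end to end. A hypothetical usage sketch (the LLamaModelParams initialization is assumed here; it is not part of this commit):

// Hypothetical usage; parameter setup is an assumption, not shown in this diff.
var mparams = new LLamaModelParams();
using var model = SafeLlamaModelHandle.LoadFromFile("/path/to/model.gguf", mparams);

// Per the doc comment, nothing is written if the token is larger than `dest`;
// the return value is the token's size in bytes either way.
Span<byte> dest = stackalloc byte[64];
var size = model.TokenToSpan(42, dest); // 42 is an arbitrary example token id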
