From 4e5e994ddac2a838c6605e896791ead6261b4402 Mon Sep 17 00:00:00 2001 From: Martin Evans Date: Tue, 2 Jan 2024 18:18:48 +0000 Subject: [PATCH] - directly returning a SafeLlamaModelHandle, instead of an IntPtr which is wrapped in a handle. - made `llama_backend_init` private. This is automatically called, there is no way it can correctly be used externally. - made `llama_token_to_piece` safe (Span instead of pointer) --- LLama/Native/NativeApi.cs | 27 ++++++++++++-------- LLama/Native/SafeLlamaModelHandle.cs | 37 ++++++++-------------------- 2 files changed, 27 insertions(+), 37 deletions(-) diff --git a/LLama/Native/NativeApi.cs b/LLama/Native/NativeApi.cs index 38ba1bc6d..2a34820d6 100644 --- a/LLama/Native/NativeApi.cs +++ b/LLama/Native/NativeApi.cs @@ -71,15 +71,13 @@ public static partial class NativeApi public static extern bool llama_mlock_supported(); /// - /// Various functions for loading a ggml llama model. - /// Allocate (almost) all memory needed for the model. - /// Return NULL on failure + /// Load all of the weights of a model into memory. /// /// /// - /// + /// The loaded model, or null on failure. [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] - public static extern IntPtr llama_load_model_from_file(string path_model, LLamaModelParams @params); + public static extern SafeLlamaModelHandle llama_load_model_from_file(string path_model, LLamaModelParams @params); /// /// Create a new llama_context with the given model. @@ -92,12 +90,11 @@ public static partial class NativeApi public static extern IntPtr llama_new_context_with_model(SafeLlamaModelHandle model, LLamaContextParams @params); /// - /// not great API - very likely to change. /// Initialize the llama + ggml backend /// Call once at the start of the program /// [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] - public static extern void llama_backend_init(bool numa); + private static extern void llama_backend_init(bool numa); /// /// Frees all allocated memory in the given llama_context @@ -510,10 +507,20 @@ public static int llama_model_meta_val_str_by_index(SafeLlamaModelHandle model, /// /// /// buffer to write string into - /// size of the buffer /// The length written, or if the buffer is too small a negative that indicates the length required - [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] - public static extern unsafe int llama_token_to_piece(SafeLlamaModelHandle model, int llamaToken, byte* buffer, int length); + public static int llama_token_to_piece(SafeLlamaModelHandle model, llama_token llamaToken, Span buffer) + { + unsafe + { + fixed (byte* bufferPtr = buffer) + { + return llama_token_to_piece_native(model, llamaToken, bufferPtr, buffer.Length); + } + } + + [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl, EntryPoint = "llama_token_to_piece")] + static extern unsafe int llama_token_to_piece_native(SafeLlamaModelHandle model, llama_token llamaToken, byte* buffer, int length); + } /// /// Convert text into tokens diff --git a/LLama/Native/SafeLlamaModelHandle.cs b/LLama/Native/SafeLlamaModelHandle.cs index fa71af3f1..291cfbc20 100644 --- a/LLama/Native/SafeLlamaModelHandle.cs +++ b/LLama/Native/SafeLlamaModelHandle.cs @@ -16,44 +16,33 @@ public sealed class SafeLlamaModelHandle /// /// Total number of tokens in vocabulary of this model /// - public int VocabCount { get; } + public int VocabCount => NativeApi.llama_n_vocab(this); /// /// Total number of tokens in the context /// - public int ContextSize { get; } + public int ContextSize => NativeApi.llama_n_ctx_train(this); /// /// Dimension of embedding vectors /// - public int EmbeddingSize { get; } + public int EmbeddingSize => NativeApi.llama_n_embd(this); /// /// Get the size of this model in bytes /// - public ulong SizeInBytes { get; } + public ulong SizeInBytes => NativeApi.llama_model_size(this); /// /// Get the number of parameters in this model /// - public ulong ParameterCount { get; } + public ulong ParameterCount => NativeApi.llama_model_n_params(this); /// /// Get the number of metadata key/value pairs /// /// - public int MetadataCount { get; } - - internal SafeLlamaModelHandle(IntPtr handle) - : base(handle) - { - VocabCount = NativeApi.llama_n_vocab(this); - ContextSize = NativeApi.llama_n_ctx_train(this); - EmbeddingSize = NativeApi.llama_n_embd(this); - SizeInBytes = NativeApi.llama_model_size(this); - ParameterCount = NativeApi.llama_model_n_params(this); - MetadataCount = NativeApi.llama_model_meta_count(this); - } + public int MetadataCount => NativeApi.llama_model_meta_count(this); /// protected override bool ReleaseHandle() @@ -73,10 +62,10 @@ protected override bool ReleaseHandle() public static SafeLlamaModelHandle LoadFromFile(string modelPath, LLamaModelParams lparams) { var model_ptr = NativeApi.llama_load_model_from_file(modelPath, lparams); - if (model_ptr == IntPtr.Zero) + if (model_ptr == null) throw new RuntimeError($"Failed to load model {modelPath}."); - return new SafeLlamaModelHandle(model_ptr); + return model_ptr; } #region LoRA @@ -114,14 +103,8 @@ public void ApplyLoraFromFile(string lora, float scale, string? modelBase = null /// The size of this token. **nothing will be written** if this is larger than `dest` public int TokenToSpan(int llama_token, Span dest) { - unsafe - { - fixed (byte* destPtr = dest) - { - var length = NativeApi.llama_token_to_piece(this, llama_token, destPtr, dest.Length); - return Math.Abs(length); - } - } + var length = NativeApi.llama_token_to_piece(this, llama_token, dest); + return Math.Abs(length); } ///