diff --git a/LLama/Native/NativeApi.cs b/LLama/Native/NativeApi.cs
index 38ba1bc6d..2a34820d6 100644
--- a/LLama/Native/NativeApi.cs
+++ b/LLama/Native/NativeApi.cs
@@ -71,15 +71,13 @@ public static partial class NativeApi
public static extern bool llama_mlock_supported();
/// <summary>
- /// Various functions for loading a ggml llama model.
- /// Allocate (almost) all memory needed for the model.
- /// Return NULL on failure
+ /// Load all of the weights of a model into memory.
/// </summary>
/// <param name="path_model"></param>
/// <param name="params"></param>
- /// <returns></returns>
+ /// <returns>The loaded model, or null on failure.</returns>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
- public static extern IntPtr llama_load_model_from_file(string path_model, LLamaModelParams @params);
+ public static extern SafeLlamaModelHandle llama_load_model_from_file(string path_model, LLamaModelParams @params);
/// <summary>
/// Create a new llama_context with the given model.
@@ -92,12 +90,11 @@ public static partial class NativeApi
public static extern IntPtr llama_new_context_with_model(SafeLlamaModelHandle model, LLamaContextParams @params);
/// <summary>
- /// not great API - very likely to change.
/// Initialize the llama + ggml backend
/// Call once at the start of the program
/// </summary>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
- public static extern void llama_backend_init(bool numa);
+ private static extern void llama_backend_init(bool numa);
/// <summary>
/// Frees all allocated memory in the given llama_context
@@ -510,10 +507,20 @@ public static int llama_model_meta_val_str_by_index(SafeLlamaModelHandle model,
/// <param name="model"></param>
/// <param name="llamaToken"></param>
/// <param name="buffer">buffer to write string into</param>
- /// <param name="length">size of the buffer</param>
/// <returns>The length written, or if the buffer is too small a negative that indicates the length required</returns>
- [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
- public static extern unsafe int llama_token_to_piece(SafeLlamaModelHandle model, int llamaToken, byte* buffer, int length);
+ public static int llama_token_to_piece(SafeLlamaModelHandle model, llama_token llamaToken, Span<byte> buffer)
+ {
+ unsafe
+ {
+ fixed (byte* bufferPtr = buffer)
+ {
+ return llama_token_to_piece_native(model, llamaToken, bufferPtr, buffer.Length);
+ }
+ }
+
+ [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl, EntryPoint = "llama_token_to_piece")]
+ static extern unsafe int llama_token_to_piece_native(SafeLlamaModelHandle model, llama_token llamaToken, byte* buffer, int length);
+ }
/// <summary>
/// Convert text into tokens
diff --git a/LLama/Native/SafeLlamaModelHandle.cs b/LLama/Native/SafeLlamaModelHandle.cs
index fa71af3f1..291cfbc20 100644
--- a/LLama/Native/SafeLlamaModelHandle.cs
+++ b/LLama/Native/SafeLlamaModelHandle.cs
@@ -16,44 +16,33 @@ public sealed class SafeLlamaModelHandle
/// <summary>
/// Total number of tokens in vocabulary of this model
/// </summary>
- public int VocabCount { get; }
+ public int VocabCount => NativeApi.llama_n_vocab(this);
/// <summary>
/// Total number of tokens in the context
/// </summary>
- public int ContextSize { get; }
+ public int ContextSize => NativeApi.llama_n_ctx_train(this);
/// <summary>
/// Dimension of embedding vectors
/// </summary>
- public int EmbeddingSize { get; }
+ public int EmbeddingSize => NativeApi.llama_n_embd(this);
/// <summary>
/// Get the size of this model in bytes
/// </summary>
- public ulong SizeInBytes { get; }
+ public ulong SizeInBytes => NativeApi.llama_model_size(this);
/// <summary>
/// Get the number of parameters in this model
/// </summary>
- public ulong ParameterCount { get; }
+ public ulong ParameterCount => NativeApi.llama_model_n_params(this);
/// <summary>
/// Get the number of metadata key/value pairs
/// </summary>
/// <returns></returns>
- public int MetadataCount { get; }
-
- internal SafeLlamaModelHandle(IntPtr handle)
- : base(handle)
- {
- VocabCount = NativeApi.llama_n_vocab(this);
- ContextSize = NativeApi.llama_n_ctx_train(this);
- EmbeddingSize = NativeApi.llama_n_embd(this);
- SizeInBytes = NativeApi.llama_model_size(this);
- ParameterCount = NativeApi.llama_model_n_params(this);
- MetadataCount = NativeApi.llama_model_meta_count(this);
- }
+ public int MetadataCount => NativeApi.llama_model_meta_count(this);
/// <inheritdoc />
protected override bool ReleaseHandle()
@@ -73,10 +62,10 @@ protected override bool ReleaseHandle()
public static SafeLlamaModelHandle LoadFromFile(string modelPath, LLamaModelParams lparams)
{
var model_ptr = NativeApi.llama_load_model_from_file(modelPath, lparams);
- if (model_ptr == IntPtr.Zero)
+ if (model_ptr == null)
throw new RuntimeError($"Failed to load model {modelPath}.");
- return new SafeLlamaModelHandle(model_ptr);
+ return model_ptr;
}
#region LoRA
@@ -114,14 +103,8 @@ public void ApplyLoraFromFile(string lora, float scale, string? modelBase = null
/// <returns>The size of this token. **nothing will be written** if this is larger than `dest`</returns>
public int TokenToSpan(int llama_token, Span<byte> dest)
{
- unsafe
- {
- fixed (byte* destPtr = dest)
- {
- var length = NativeApi.llama_token_to_piece(this, llama_token, destPtr, dest.Length);
- return Math.Abs(length);
- }
- }
+ var length = NativeApi.llama_token_to_piece(this, llama_token, dest);
+ return Math.Abs(length);
}
///