diff --git a/LLama/Extensions/IModelParamsExtensions.cs b/LLama/Extensions/IModelParamsExtensions.cs
new file mode 100644
index 000000000..0cd24cff9
--- /dev/null
+++ b/LLama/Extensions/IModelParamsExtensions.cs
@@ -0,0 +1,54 @@
+using System.IO;
+using System;
+using System.Buffers;
+using LLama.Abstractions;
+using LLama.Native;
+
+namespace LLama.Extensions
+{
+    internal static class IModelParamsExtensions
+    {
+        /// <summary>
+        /// Convert the given `IModelParams` into a `LLamaContextParams`
+        /// </summary>
+        /// <param name="params"></param>
+        /// <param name="result"></param>
+        /// <returns></returns>
+        /// <exception cref="FileNotFoundException"></exception>
+        /// <exception cref="ArgumentException"></exception>
+        public static MemoryHandle ToLlamaContextParams(this IModelParams @params, out LLamaContextParams result)
+        {
+            if (!File.Exists(@params.ModelPath))
+                throw new FileNotFoundException($"The model file does not exist: {@params.ModelPath}");
+
+            if (@params.TensorSplits != null && @params.TensorSplits.Length != 1)
+                throw new ArgumentException("Currently multi-gpu support is not supported by both llama.cpp and LLamaSharp.");
+
+            result = NativeApi.llama_context_default_params();
+            result.n_ctx = @params.ContextSize;
+            result.n_batch = @params.BatchSize;
+            result.main_gpu = @params.MainGpu;
+            result.n_gpu_layers = @params.GpuLayerCount;
+            result.seed = @params.Seed;
+            result.f16_kv = @params.UseFp16Memory;
+            result.use_mmap = @params.UseMemoryLock;
+            result.use_mlock = @params.UseMemoryLock;
+            result.logits_all = @params.Perplexity;
+            result.embedding = @params.EmbeddingMode;
+            result.low_vram = @params.LowVram;
+            result.n_gqa = @params.GroupedQueryAttention;
+            result.rms_norm_eps = @params.RmsNormEpsilon;
+            result.rope_freq_base = @params.RopeFrequencyBase;
+            result.rope_freq_scale = @params.RopeFrequencyScale;
+            result.mul_mat_q = @params.MulMatQ;
+
+            var pin = @params.TensorSplits.AsMemory().Pin();
+            unsafe
+            {
+                result.tensor_split = (nint)pin.Pointer;
+            }
+
+            return pin;
+        }
+    }
+}
diff --git a/LLama/Utils.cs b/LLama/Utils.cs
index 8ab5e711f..391a5cc14 100644
--- a/LLama/Utils.cs
+++ b/LLama/Utils.cs
@@ -1,12 +1,12 @@
 using LLama.Abstractions;
-using LLama.Exceptions;
 using LLama.Native;
 using System;
 using System.Collections.Generic;
-using System.IO;
 using System.Linq;
 using System.Runtime.InteropServices;
 using System.Text;
+using LLama.Exceptions;
+using LLama.Extensions;
 
 namespace LLama
 {
@@ -15,43 +15,16 @@ public static class Utils
     {
         public static SafeLLamaContextHandle InitLLamaContextFromModelParams(IModelParams @params)
         {
-            if (!File.Exists(@params.ModelPath))
-                throw new FileNotFoundException($"The model file does not exist: {@params.ModelPath}");
-
-            if (@params.TensorSplits != null && @params.TensorSplits.Length != 1)
-                throw new ArgumentException("Currently multi-gpu support is not supported by both llama.cpp and LLamaSharp.");
-
-            var lparams = NativeApi.llama_context_default_params();
-            lparams.n_ctx = @params.ContextSize;
-            lparams.n_batch = @params.BatchSize;
-            lparams.main_gpu = @params.MainGpu;
-            lparams.n_gpu_layers = @params.GpuLayerCount;
-            lparams.seed = @params.Seed;
-            lparams.f16_kv = @params.UseFp16Memory;
-            lparams.use_mmap = @params.UseMemoryLock;
-            lparams.use_mlock = @params.UseMemoryLock;
-            lparams.logits_all = @params.Perplexity;
-            lparams.embedding = @params.EmbeddingMode;
-            lparams.low_vram = @params.LowVram;
-            lparams.n_gqa = @params.GroupedQueryAttention;
-            lparams.rms_norm_eps = @params.RmsNormEpsilon;
-            lparams.rope_freq_base = @params.RopeFrequencyBase;
-            lparams.rope_freq_scale = @params.RopeFrequencyScale;
-            lparams.mul_mat_q = @params.MulMatQ;
-
-            using var pin = @params.TensorSplits.AsMemory().Pin();
-            unsafe
+            using (@params.ToLlamaContextParams(out var lparams))
             {
-                lparams.tensor_split = (nint)pin.Pointer;
-            }
+                var model = SafeLlamaModelHandle.LoadFromFile(@params.ModelPath, lparams);
+                var ctx = SafeLLamaContextHandle.Create(model, lparams);
 
-            var model = SafeLlamaModelHandle.LoadFromFile(@params.ModelPath, lparams);
-            var ctx = SafeLLamaContextHandle.Create(model, lparams);
+                if (!string.IsNullOrEmpty(@params.LoraAdapter))
+                    model.ApplyLoraFromFile(@params.LoraAdapter, @params.LoraBase, @params.Threads);
 
-            if (!string.IsNullOrEmpty(@params.LoraAdapter))
-                model.ApplyLoraFromFile(@params.LoraAdapter, @params.LoraBase, @params.Threads);
-
-            return ctx;
+                return ctx;
+            }
         }
 
         public static IEnumerable<llama_token> Tokenize(SafeLLamaContextHandle ctx, string text, bool add_bos, Encoding encoding)
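For reference, a minimal sketch of the call pattern this refactor sets up. The extension is declared `internal`, so this mirrors the in-assembly usage in `Utils.InitLLamaContextFromModelParams`; the `ModelParams` type (from `LLama.Common`) and the model path are illustrative stand-ins for any `IModelParams` implementation:

    using LLama.Common;
    using LLama.Extensions;
    using LLama.Native;

    // Illustrative only: ModelParams is assumed to implement IModelParams,
    // and "model.bin" is a placeholder path.
    var @params = new ModelParams("model.bin");

    // ToLlamaContextParams pins TensorSplits and hands back the MemoryHandle,
    // so the pin must stay alive until the native context has been created.
    using (@params.ToLlamaContextParams(out var lparams))
    {
        var model = SafeLlamaModelHandle.LoadFromFile(@params.ModelPath, lparams);
        var ctx = SafeLLamaContextHandle.Create(model, lparams);
        // ... use ctx ...
    }

Returning the `MemoryHandle` rather than disposing it inside the extension method is what lets the caller scope the pin with `using`, keeping the pinned `tensor_split` pointer valid for the duration of the native calls.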