diff --git a/LLama/Extensions/IModelParamsExtensions.cs b/LLama/Extensions/IModelParamsExtensions.cs
new file mode 100644
index 000000000..0cd24cff9
--- /dev/null
+++ b/LLama/Extensions/IModelParamsExtensions.cs
@@ -0,0 +1,54 @@
+using System.IO;
+using System;
+using System.Buffers;
+using LLama.Abstractions;
+using LLama.Native;
+
+namespace LLama.Extensions
+{
+ internal static class IModelParamsExtensions
+ {
+ ///
+ /// Convert the given `IModelParams` into a `LLamaContextParams`
+ ///
+ ///
+ ///
+ ///
+ ///
+ ///
+ public static MemoryHandle ToLlamaContextParams(this IModelParams @params, out LLamaContextParams result)
+ {
+ if (!File.Exists(@params.ModelPath))
+ throw new FileNotFoundException($"The model file does not exist: {@params.ModelPath}");
+
+ if (@params.TensorSplits != null && @params.TensorSplits.Length != 1)
+ throw new ArgumentException("Currently multi-gpu support is not supported by both llama.cpp and LLamaSharp.");
+
+ result = NativeApi.llama_context_default_params();
+ result.n_ctx = @params.ContextSize;
+ result.n_batch = @params.BatchSize;
+ result.main_gpu = @params.MainGpu;
+ result.n_gpu_layers = @params.GpuLayerCount;
+ result.seed = @params.Seed;
+ result.f16_kv = @params.UseFp16Memory;
+ result.use_mmap = @params.UseMemoryLock;
+ result.use_mlock = @params.UseMemoryLock;
+ result.logits_all = @params.Perplexity;
+ result.embedding = @params.EmbeddingMode;
+ result.low_vram = @params.LowVram;
+ result.n_gqa = @params.GroupedQueryAttention;
+ result.rms_norm_eps = @params.RmsNormEpsilon;
+ result.rope_freq_base = @params.RopeFrequencyBase;
+ result.rope_freq_scale = @params.RopeFrequencyScale;
+ result.mul_mat_q = @params.MulMatQ;
+
+ var pin = @params.TensorSplits.AsMemory().Pin();
+ unsafe
+ {
+ result.tensor_split = (nint)pin.Pointer;
+ }
+
+ return pin;
+ }
+ }
+}
diff --git a/LLama/Utils.cs b/LLama/Utils.cs
index 8ab5e711f..391a5cc14 100644
--- a/LLama/Utils.cs
+++ b/LLama/Utils.cs
@@ -1,12 +1,12 @@
using LLama.Abstractions;
-using LLama.Exceptions;
using LLama.Native;
using System;
using System.Collections.Generic;
-using System.IO;
using System.Linq;
using System.Runtime.InteropServices;
using System.Text;
+using LLama.Exceptions;
+using LLama.Extensions;
namespace LLama
{
@@ -15,43 +15,16 @@ public static class Utils
{
public static SafeLLamaContextHandle InitLLamaContextFromModelParams(IModelParams @params)
{
- if (!File.Exists(@params.ModelPath))
- throw new FileNotFoundException($"The model file does not exist: {@params.ModelPath}");
-
- if (@params.TensorSplits != null && @params.TensorSplits.Length != 1)
- throw new ArgumentException("Currently multi-gpu support is not supported by both llama.cpp and LLamaSharp.");
-
- var lparams = NativeApi.llama_context_default_params();
- lparams.n_ctx = @params.ContextSize;
- lparams.n_batch = @params.BatchSize;
- lparams.main_gpu = @params.MainGpu;
- lparams.n_gpu_layers = @params.GpuLayerCount;
- lparams.seed = @params.Seed;
- lparams.f16_kv = @params.UseFp16Memory;
- lparams.use_mmap = @params.UseMemoryLock;
- lparams.use_mlock = @params.UseMemoryLock;
- lparams.logits_all = @params.Perplexity;
- lparams.embedding = @params.EmbeddingMode;
- lparams.low_vram = @params.LowVram;
- lparams.n_gqa = @params.GroupedQueryAttention;
- lparams.rms_norm_eps = @params.RmsNormEpsilon;
- lparams.rope_freq_base = @params.RopeFrequencyBase;
- lparams.rope_freq_scale = @params.RopeFrequencyScale;
- lparams.mul_mat_q = @params.MulMatQ;
-
- using var pin = @params.TensorSplits.AsMemory().Pin();
- unsafe
+ using (@params.ToLlamaContextParams(out var lparams))
{
- lparams.tensor_split = (nint)pin.Pointer;
- }
+ var model = SafeLlamaModelHandle.LoadFromFile(@params.ModelPath, lparams);
+ var ctx = SafeLLamaContextHandle.Create(model, lparams);
- var model = SafeLlamaModelHandle.LoadFromFile(@params.ModelPath, lparams);
- var ctx = SafeLLamaContextHandle.Create(model, lparams);
+ if (!string.IsNullOrEmpty(@params.LoraAdapter))
+ model.ApplyLoraFromFile(@params.LoraAdapter, @params.LoraBase, @params.Threads);
- if (!string.IsNullOrEmpty(@params.LoraAdapter))
- model.ApplyLoraFromFile(@params.LoraAdapter, @params.LoraBase, @params.Threads);
-
- return ctx;
+ return ctx;
+ }
}
public static IEnumerable Tokenize(SafeLLamaContextHandle ctx, string text, bool add_bos, Encoding encoding)