Commit f249937

Pulled conversion of an IModelParams into a LLamaContextParams out into an extension method which can be used in other places.
martindevans committed Aug 7, 2023
1 parent f1111a9 commit f249937
Showing 2 changed files with 63 additions and 36 deletions.
54 changes: 54 additions & 0 deletions LLama/Extensions/IModelParamsExtensions.cs
@@ -0,0 +1,54 @@
using System.IO;
using System;
using System.Buffers;
using LLama.Abstractions;
using LLama.Native;

namespace LLama.Extensions
{
    internal static class IModelParamsExtensions
    {
        /// <summary>
        /// Convert the given `IModelParams` into a `LLamaContextParams`
        /// </summary>
        /// <param name="params">The parameters to convert</param>
        /// <param name="result">The converted native context parameters</param>
        /// <returns>A `MemoryHandle` pinning `TensorSplits`; dispose it when the parameters are no longer needed</returns>
        /// <exception cref="FileNotFoundException"></exception>
        /// <exception cref="ArgumentException"></exception>
        public static MemoryHandle ToLlamaContextParams(this IModelParams @params, out LLamaContextParams result)
        {
            if (!File.Exists(@params.ModelPath))
                throw new FileNotFoundException($"The model file does not exist: {@params.ModelPath}");

            if (@params.TensorSplits != null && @params.TensorSplits.Length != 1)
                throw new ArgumentException("Currently multi-gpu support is not supported by both llama.cpp and LLamaSharp.");

            result = NativeApi.llama_context_default_params();
            result.n_ctx = @params.ContextSize;
            result.n_batch = @params.BatchSize;
            result.main_gpu = @params.MainGpu;
            result.n_gpu_layers = @params.GpuLayerCount;
            result.seed = @params.Seed;
            result.f16_kv = @params.UseFp16Memory;
            result.use_mmap = @params.UseMemoryLock;
            result.use_mlock = @params.UseMemoryLock;
            result.logits_all = @params.Perplexity;
            result.embedding = @params.EmbeddingMode;
            result.low_vram = @params.LowVram;
            result.n_gqa = @params.GroupedQueryAttention;
            result.rms_norm_eps = @params.RmsNormEpsilon;
            result.rope_freq_base = @params.RopeFrequencyBase;
            result.rope_freq_scale = @params.RopeFrequencyScale;
            result.mul_mat_q = @params.MulMatQ;

            // Pin the tensor split array so native code can safely hold a pointer to it
            var pin = @params.TensorSplits.AsMemory().Pin();
            unsafe
            {
                result.tensor_split = (nint)pin.Pointer;
            }

            return pin;
        }
    }
}
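For reference, a minimal usage sketch of the new extension method (the `modelParams` variable is hypothetical, standing in for any `IModelParams` instance). The returned `MemoryHandle` pins `TensorSplits`, so it is kept alive via `using` until the native model and context handles have been created, mirroring the updated Utils.cs below.

// Usage sketch (hypothetical `modelParams` instance of IModelParams).
// The MemoryHandle returned by ToLlamaContextParams pins TensorSplits,
// so it is disposed only after the native handles have been created.
using (modelParams.ToLlamaContextParams(out var lparams))
{
    var model = SafeLlamaModelHandle.LoadFromFile(modelParams.ModelPath, lparams);
    var ctx = SafeLLamaContextHandle.Create(model, lparams);
}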
45 changes: 9 additions & 36 deletions LLama/Utils.cs
@@ -1,12 +1,12 @@
 using LLama.Abstractions;
+using LLama.Exceptions;
 using LLama.Native;
 using System;
 using System.Collections.Generic;
 using System.IO;
 using System.Linq;
 using System.Runtime.InteropServices;
 using System.Text;
-using LLama.Exceptions;
+using LLama.Extensions;
 
 namespace LLama
 {
@@ -15,43 +15,16 @@ public static class Utils
     {
         public static SafeLLamaContextHandle InitLLamaContextFromModelParams(IModelParams @params)
         {
-            if (!File.Exists(@params.ModelPath))
-                throw new FileNotFoundException($"The model file does not exist: {@params.ModelPath}");
-
-            if (@params.TensorSplits != null && @params.TensorSplits.Length != 1)
-                throw new ArgumentException("Currently multi-gpu support is not supported by both llama.cpp and LLamaSharp.");
-
-            var lparams = NativeApi.llama_context_default_params();
-            lparams.n_ctx = @params.ContextSize;
-            lparams.n_batch = @params.BatchSize;
-            lparams.main_gpu = @params.MainGpu;
-            lparams.n_gpu_layers = @params.GpuLayerCount;
-            lparams.seed = @params.Seed;
-            lparams.f16_kv = @params.UseFp16Memory;
-            lparams.use_mmap = @params.UseMemoryLock;
-            lparams.use_mlock = @params.UseMemoryLock;
-            lparams.logits_all = @params.Perplexity;
-            lparams.embedding = @params.EmbeddingMode;
-            lparams.low_vram = @params.LowVram;
-            lparams.n_gqa = @params.GroupedQueryAttention;
-            lparams.rms_norm_eps = @params.RmsNormEpsilon;
-            lparams.rope_freq_base = @params.RopeFrequencyBase;
-            lparams.rope_freq_scale = @params.RopeFrequencyScale;
-            lparams.mul_mat_q = @params.MulMatQ;
-
-            using var pin = @params.TensorSplits.AsMemory().Pin();
-            unsafe
-            {
-                lparams.tensor_split = (nint)pin.Pointer;
-            }
-
-            var model = SafeLlamaModelHandle.LoadFromFile(@params.ModelPath, lparams);
-            var ctx = SafeLLamaContextHandle.Create(model, lparams);
-
-            if (!string.IsNullOrEmpty(@params.LoraAdapter))
-                model.ApplyLoraFromFile(@params.LoraAdapter, @params.LoraBase, @params.Threads);
-
-            return ctx;
+            using (@params.ToLlamaContextParams(out var lparams))
+            {
+                var model = SafeLlamaModelHandle.LoadFromFile(@params.ModelPath, lparams);
+                var ctx = SafeLLamaContextHandle.Create(model, lparams);
+
+                if (!string.IsNullOrEmpty(@params.LoraAdapter))
+                    model.ApplyLoraFromFile(@params.LoraAdapter, @params.LoraBase, @params.Threads);
+
+                return ctx;
+            }
         }
 
         public static IEnumerable<llama_token> Tokenize(SafeLLamaContextHandle ctx, string text, bool add_bos, Encoding encoding)