Skip to content

Commit

Permalink
Replaced nint with float[]? in Model params, which is much more user friendly!
Browse files Browse the repository at this point in the history
  • Loading branch information
martindevans committed Aug 6, 2023
1 parent 18b1df6 commit 685eb3b
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 23 deletions.
2 changes: 1 addition & 1 deletion LLama/Abstractions/IModelParams.cs
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ public interface IModelParams
/// <summary>
/// how split tensors should be distributed across GPUs
/// </summary>
nint TensorSplits { get; set; }
float[]? TensorSplits { get; set; }

/// <summary>
/// Grouped-Query Attention
Expand Down
7 changes: 3 additions & 4 deletions LLama/Common/ModelParams.cs
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
using LLama.Abstractions;
using System;
using System.Collections.Generic;
using System.Text;

namespace LLama.Common
{
/// <summary>
/// The parameters for initializing a LLama model.
/// </summary>
public class ModelParams : IModelParams
public class ModelParams
: IModelParams
{
/// <summary>
/// Model context size (n_ctx)
Expand Down Expand Up @@ -85,7 +84,7 @@ public class ModelParams : IModelParams
/// <summary>
/// how split tensors should be distributed across GPUs
/// </summary>
public nint TensorSplits { get; set; }
public float[]? TensorSplits { get; set; }

/// <summary>
/// Grouped-Query Attention
Expand Down
35 changes: 17 additions & 18 deletions LLama/Utils.cs
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,13 @@ public static class Utils
{
public static SafeLLamaContextHandle InitLLamaContextFromModelParams(IModelParams @params)
{
var lparams = NativeApi.llama_context_default_params();
if (!File.Exists(@params.ModelPath))
throw new FileNotFoundException($"The model file does not exist: {@params.ModelPath}");

if (@params.TensorSplits != null && @params.TensorSplits.Length != 1)
throw new ArgumentException("Currently multi-gpu support is not supported by both llama.cpp and LLamaSharp.");

var lparams = NativeApi.llama_context_default_params();
lparams.n_ctx = @params.ContextSize;
lparams.n_batch = @params.BatchSize;
lparams.main_gpu = @params.MainGpu;
Expand All @@ -34,27 +39,21 @@ public static SafeLLamaContextHandle InitLLamaContextFromModelParams(IModelParam
lparams.rope_freq_scale = @params.RopeFrequencyScale;
lparams.mul_mat_q = @params.MulMatQ;

/*
if (@params.TensorSplits.Length != 1)
{
throw new ArgumentException("Currently multi-gpu support is not supported by " +
"both llama.cpp and LLamaSharp.");
}*/

lparams.tensor_split = @params.TensorSplits;

if (!File.Exists(@params.ModelPath))
unsafe
{
throw new FileNotFoundException($"The model file does not exist: {@params.ModelPath}");
}
fixed (float* splits = @params.TensorSplits)
{
lparams.tensor_split = (nint)splits;

var model = SafeLlamaModelHandle.LoadFromFile(@params.ModelPath, lparams);
var ctx = SafeLLamaContextHandle.Create(model, lparams);
var model = SafeLlamaModelHandle.LoadFromFile(@params.ModelPath, lparams);
var ctx = SafeLLamaContextHandle.Create(model, lparams);

if (!string.IsNullOrEmpty(@params.LoraAdapter))
model.ApplyLoraFromFile(@params.LoraAdapter, @params.LoraBase, @params.Threads);
if (!string.IsNullOrEmpty(@params.LoraAdapter))
model.ApplyLoraFromFile(@params.LoraAdapter, @params.LoraBase, @params.Threads);

return ctx;
return ctx;
}
}
}

public static IEnumerable<llama_token> Tokenize(SafeLLamaContextHandle ctx, string text, bool add_bos, Encoding encoding)
Expand Down

0 comments on commit 685eb3b

Please sign in to comment.