Improved Tensor Splits #81

Merged · 4 commits · Aug 7, 2023
Changes from all commits
2 changes: 1 addition & 1 deletion LLama.Web/Common/ModelOptions.cs
@@ -84,7 +84,7 @@ public class ModelOptions : IModelParams
/// <summary>
/// how split tensors should be distributed across GPUs
/// </summary>
- public nint TensorSplits { get; set; }
+ public float[] TensorSplits { get; set; }

/// <summary>
/// Grouped-Query Attention
2 changes: 1 addition & 1 deletion LLama/Abstractions/IModelParams.cs
@@ -93,7 +93,7 @@ public interface IModelParams
/// <summary>
/// how split tensors should be distributed across GPUs
/// </summary>
- nint TensorSplits { get; set; }
+ float[]? TensorSplits { get; set; }

/// <summary>
/// Grouped-Query Attention
7 changes: 3 additions & 4 deletions LLama/Common/ModelParams.cs
@@ -1,14 +1,13 @@
using LLama.Abstractions;
using System;
using System.Collections.Generic;
using System.Text;

namespace LLama.Common
{
/// <summary>
/// The parameters for initializing a LLama model.
/// </summary>
- public class ModelParams : IModelParams
+ public class ModelParams
+     : IModelParams
{
/// <summary>
/// Model context size (n_ctx)
@@ -85,7 +84,7 @@ public class ModelParams : IModelParams
/// <summary>
/// how split tensors should be distributed across GPUs
/// </summary>
- public nint TensorSplits { get; set; }
+ public float[]? TensorSplits { get; set; }

/// <summary>
/// Grouped-Query Attention
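With this change a caller expresses the GPU split as an array of per-device proportions rather than a raw native pointer. A minimal caller-side sketch, not part of this PR's diff; the model path is a placeholder, and a ModelParams constructor taking just the path is assumed:

using LLama.Common;

// Hypothetical usage of the new property; "path/to/model.bin" is illustrative.
var modelParams = new ModelParams("path/to/model.bin")
{
    // One proportion per GPU. The new ToLlamaContextParams extension throws
    // an ArgumentException for arrays longer than one entry, since multi-GPU
    // is not yet supported by llama.cpp / LLamaSharp.
    TensorSplits = new float[] { 1.0f }
};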
54 changes: 54 additions & 0 deletions LLama/Extensions/IModelParamsExtensions.cs
@@ -0,0 +1,54 @@
using System.IO;
using System;
using System.Buffers;
using LLama.Abstractions;
using LLama.Native;

namespace LLama.Extensions
{
    internal static class IModelParamsExtensions
    {
        /// <summary>
        /// Convert the given `IModelParams` into a `LLamaContextParams`
        /// </summary>
        /// <param name="params"></param>
        /// <param name="result"></param>
        /// <returns></returns>
        /// <exception cref="FileNotFoundException"></exception>
        /// <exception cref="ArgumentException"></exception>
        public static MemoryHandle ToLlamaContextParams(this IModelParams @params, out LLamaContextParams result)
        {
            if (!File.Exists(@params.ModelPath))
                throw new FileNotFoundException($"The model file does not exist: {@params.ModelPath}");

            if (@params.TensorSplits != null && @params.TensorSplits.Length != 1)
                throw new ArgumentException("Currently multi-gpu support is not supported by both llama.cpp and LLamaSharp.");

            result = NativeApi.llama_context_default_params();
            result.n_ctx = @params.ContextSize;
            result.n_batch = @params.BatchSize;
            result.main_gpu = @params.MainGpu;
            result.n_gpu_layers = @params.GpuLayerCount;
            result.seed = @params.Seed;
            result.f16_kv = @params.UseFp16Memory;
            result.use_mmap = @params.UseMemoryLock;
            result.use_mlock = @params.UseMemoryLock;
            result.logits_all = @params.Perplexity;
            result.embedding = @params.EmbeddingMode;
            result.low_vram = @params.LowVram;
            result.n_gqa = @params.GroupedQueryAttention;
            result.rms_norm_eps = @params.RmsNormEpsilon;
            result.rope_freq_base = @params.RopeFrequencyBase;
            result.rope_freq_scale = @params.RopeFrequencyScale;
            result.mul_mat_q = @params.MulMatQ;

            var pin = @params.TensorSplits.AsMemory().Pin();
            unsafe
            {
                result.tensor_split = (nint)pin.Pointer;
            }

            return pin;
        }
    }
}
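The MemoryHandle returned above is what keeps the managed float[] pinned while native code reads it through the pointer stored in tensor_split; the caller must keep the handle alive until the native calls finish, which is why Utils.cs below wraps it in a using block. A self-contained sketch of the same BCL pinning pattern, assuming compilation with unsafe blocks enabled and using no LLamaSharp types:

using System;
using System.Buffers;

class PinSketch
{
    static unsafe void Main()
    {
        float[] splits = { 1.0f };

        // Pin the array so the GC cannot move it while a raw pointer is held.
        using (MemoryHandle pin = splits.AsMemory().Pin())
        {
            // Valid only while `pin` is alive; this mirrors what the
            // extension stores into llama_context_params.tensor_split.
            nint tensorSplitPtr = (nint)pin.Pointer;
            Console.WriteLine($"pinned at 0x{(long)tensorSplitPtr:X}");
        } // Disposing the handle unpins the array; the pointer is then invalid.
    }
}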
46 changes: 8 additions & 38 deletions LLama/Utils.cs
@@ -1,12 +1,12 @@
using LLama.Abstractions;
using LLama.Exceptions;
using LLama.Native;
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Runtime.InteropServices;
using System.Text;
using LLama.Exceptions;
using LLama.Extensions;

namespace LLama
{
@@ -15,46 +15,16 @@
    {
        public static SafeLLamaContextHandle InitLLamaContextFromModelParams(IModelParams @params)
        {
-            var lparams = NativeApi.llama_context_default_params();
-
-            lparams.n_ctx = @params.ContextSize;
-            lparams.n_batch = @params.BatchSize;
-            lparams.main_gpu = @params.MainGpu;
-            lparams.n_gpu_layers = @params.GpuLayerCount;
-            lparams.seed = @params.Seed;
-            lparams.f16_kv = @params.UseFp16Memory;
-            lparams.use_mmap = @params.UseMemoryLock;
-            lparams.use_mlock = @params.UseMemoryLock;
-            lparams.logits_all = @params.Perplexity;
-            lparams.embedding = @params.EmbeddingMode;
-            lparams.low_vram = @params.LowVram;
-            lparams.n_gqa = @params.GroupedQueryAttention;
-            lparams.rms_norm_eps = @params.RmsNormEpsilon;
-            lparams.rope_freq_base = @params.RopeFrequencyBase;
-            lparams.rope_freq_scale = @params.RopeFrequencyScale;
-            lparams.mul_mat_q = @params.MulMatQ;
-
-            /*
-            if (@params.TensorSplits.Length != 1)
-            {
-                throw new ArgumentException("Currently multi-gpu support is not supported by " +
-                    "both llama.cpp and LLamaSharp.");
-            }*/
-
-            lparams.tensor_split = @params.TensorSplits;
-
-            if (!File.Exists(@params.ModelPath))
-            {
-                throw new FileNotFoundException($"The model file does not exist: {@params.ModelPath}");
-            }
-
-            var model = SafeLlamaModelHandle.LoadFromFile(@params.ModelPath, lparams);
-            var ctx = SafeLLamaContextHandle.Create(model, lparams);
-
-            if (!string.IsNullOrEmpty(@params.LoraAdapter))
-                model.ApplyLoraFromFile(@params.LoraAdapter, @params.LoraBase, @params.Threads);
-
-            return ctx;
+            using (@params.ToLlamaContextParams(out var lparams))
+            {
+                var model = SafeLlamaModelHandle.LoadFromFile(@params.ModelPath, lparams);
+                var ctx = SafeLLamaContextHandle.Create(model, lparams);
+
+                if (!string.IsNullOrEmpty(@params.LoraAdapter))
+                    model.ApplyLoraFromFile(@params.LoraAdapter, @params.LoraBase, @params.Threads);
+
+                return ctx;
+            }
        }

public static IEnumerable<llama_token> Tokenize(SafeLLamaContextHandle ctx, string text, bool add_bos, Encoding encoding)
@@ -96,15 +66,15 @@
#if NET6_0_OR_GREATER
if(encoding == Encoding.UTF8)
{
return Marshal.PtrToStringUTF8(ptr);

Check warning on line 69 in LLama/Utils.cs (GitHub Actions / Test: linux-debug, linux-release, windows-debug, windows-release): Possible null reference return.
}
else if(encoding == Encoding.Unicode)
{
return Marshal.PtrToStringUni(ptr);

Check warning on line 73 in LLama/Utils.cs (GitHub Actions / Test: linux-debug, linux-release, windows-debug, windows-release): Possible null reference return.
}
else
{
return Marshal.PtrToStringAuto(ptr);

Check warning on line 77 in LLama/Utils.cs (GitHub Actions / Test: linux-debug, linux-release, windows-debug, windows-release): Possible null reference return.
}
#else
byte* tp = (byte*)ptr.ToPointer();
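The repeated warnings flagged above come from Marshal.PtrToStringUTF8, PtrToStringUni, and PtrToStringAuto being annotated as returning string?. A sketch of one way a follow-up could silence them, assuming .NET 6+ and a hypothetical helper name; this is not part of the PR:

using System;
using System.Runtime.InteropServices;
using System.Text;

internal static class PtrToStringSketch
{
    // Marshal.PtrToString* may return null for a null pointer; coalescing to
    // an empty string keeps a non-null contract and removes the CI warnings.
    public static string PtrToString(IntPtr ptr, Encoding encoding)
    {
        if (encoding == Encoding.UTF8)
            return Marshal.PtrToStringUTF8(ptr) ?? string.Empty;
        if (encoding == Encoding.Unicode)
            return Marshal.PtrToStringUni(ptr) ?? string.Empty;
        return Marshal.PtrToStringAuto(ptr) ?? string.Empty;
    }
}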