Skip to content

Commit

Permalink
March Binary Update (#565)
Browse files Browse the repository at this point in the history
* Updated binaries to llama.cpp `3ab8b3a92ede46df88bc5a2dfca3777de4a2b2b6` (build run: https://github.com/SciSharp/LLamaSharp/actions/runs/8118890586)

* Added abort callback

* Added properties to get/set thread count on `LLamaContext`

* Fixed LLamaLogLevel numbering
  • Loading branch information
martindevans authored Mar 6, 2024
1 parent 6f03d5a commit a8ba9f0
Show file tree
Hide file tree
Showing 49 changed files with 2,773 additions and 915 deletions.
6 changes: 6 additions & 0 deletions LLama.Web/Common/ModelOptions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -103,5 +103,11 @@ public class ModelOptions

/// <inheritdoc />
public bool VocabOnly { get; set; }

/// <inheritdoc />
public float DefragThreshold { get; set; }

/// <inheritdoc />
public bool DoPooling { get; set; }
}
}
10 changes: 10 additions & 0 deletions LLama/Abstractions/IContextParams.cs
Original file line number Diff line number Diff line change
Expand Up @@ -98,4 +98,14 @@ public interface IContextParams
/// Whether to disable offloading the KQV cache to the GPU
/// </summary>
bool NoKqvOffload { get; }

/// <summary>
/// Defragment the KV cache if holes/size &gt; defrag_threshold. Set to a value &lt; 0 to disable (default).
/// </summary>
float DefragThreshold { get; }

/// <summary>
/// Whether to pool (sum) embedding results by sequence id (ignored if no pooling layer)
/// </summary>
bool DoPooling { get; }
}
24 changes: 12 additions & 12 deletions LLama/Abstractions/IModelParams.cs
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,7 @@ public MetadataOverride(string key, int value)
{
Key = key;
_valueInt = value;
Type = LLamaModelKvOverrideType.LLAMA_KV_OVERRIDE_INT;
Type = LLamaModelKvOverrideType.Int;
}

/// <summary>
Expand All @@ -263,7 +263,7 @@ public MetadataOverride(string key, float value)
{
Key = key;
_valueFloat = value;
Type = LLamaModelKvOverrideType.LLAMA_KV_OVERRIDE_FLOAT;
Type = LLamaModelKvOverrideType.Float;
}

/// <summary>
Expand All @@ -275,20 +275,20 @@ public MetadataOverride(string key, bool value)
{
Key = key;
_valueBool = value;
Type = LLamaModelKvOverrideType.LLAMA_KV_OVERRIDE_BOOL;
Type = LLamaModelKvOverrideType.Bool;
}

internal void WriteValue(ref LLamaModelMetadataOverride dest)
{
switch (Type)
{
case LLamaModelKvOverrideType.LLAMA_KV_OVERRIDE_INT:
case LLamaModelKvOverrideType.Int:
dest.IntValue = _valueInt;
break;
case LLamaModelKvOverrideType.LLAMA_KV_OVERRIDE_FLOAT:
case LLamaModelKvOverrideType.Float:
dest.FloatValue = _valueFloat;
break;
case LLamaModelKvOverrideType.LLAMA_KV_OVERRIDE_BOOL:
case LLamaModelKvOverrideType.Bool:
dest.BoolValue = _valueBool ? -1L : 0;
break;
default:
Expand All @@ -300,13 +300,13 @@ internal void WriteValue(Utf8JsonWriter writer)
{
switch (Type)
{
case LLamaModelKvOverrideType.LLAMA_KV_OVERRIDE_INT:
case LLamaModelKvOverrideType.Int:
writer.WriteNumberValue(_valueInt);
break;
case LLamaModelKvOverrideType.LLAMA_KV_OVERRIDE_FLOAT:
case LLamaModelKvOverrideType.Float:
writer.WriteNumberValue(_valueFloat);
break;
case LLamaModelKvOverrideType.LLAMA_KV_OVERRIDE_BOOL:
case LLamaModelKvOverrideType.Bool:
writer.WriteBooleanValue(_valueBool);
break;
default:
Expand All @@ -328,9 +328,9 @@ public override MetadataOverride Read(ref Utf8JsonReader reader, Type typeToConv

return ((LLamaModelKvOverrideType)ktv.Type) switch
{
LLamaModelKvOverrideType.LLAMA_KV_OVERRIDE_INT => new MetadataOverride(ktv.Key, ktv.Value.GetInt32()),
LLamaModelKvOverrideType.LLAMA_KV_OVERRIDE_FLOAT => new MetadataOverride(ktv.Key, ktv.Value.GetSingle()),
LLamaModelKvOverrideType.LLAMA_KV_OVERRIDE_BOOL => new MetadataOverride(ktv.Key, ktv.Value.GetBoolean()),
LLamaModelKvOverrideType.Int => new MetadataOverride(ktv.Key, ktv.Value.GetInt32()),
LLamaModelKvOverrideType.Float => new MetadataOverride(ktv.Key, ktv.Value.GetSingle()),
LLamaModelKvOverrideType.Bool => new MetadataOverride(ktv.Key, ktv.Value.GetBoolean()),
_ => throw new JsonException(),
};
}
Expand Down
4 changes: 2 additions & 2 deletions LLama/Batched/Conversation.cs
Original file line number Diff line number Diff line change
Expand Up @@ -262,9 +262,9 @@ public void Remove(LLamaPos start, int count)
/// <param name="start">Start position (inclusive)</param>
/// <param name="end">End position (exclusive)</param>
/// <param name="delta">Amount to add on to each token position</param>
public void Shift(LLamaPos start, LLamaPos end, int delta)
public void Add(LLamaPos start, LLamaPos end, int delta)
{
_conversation.Executor.Context.NativeHandle.KvCacheSequenceShift(_conversation.ConversationId, start, end, delta);
_conversation.Executor.Context.NativeHandle.KvCacheSequenceAdd(_conversation.ConversationId, start, end, delta);
}
#endregion

Expand Down
2 changes: 1 addition & 1 deletion LLama/Batched/ConversationExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ public static void ShiftLeft(this Conversation conversation, int count, int keep
kv.Remove(keep, count);
// Shift the C's
kv.Shift(keep + count, end, -count);
kv.Add(keep + count, end, -count);
// Update total count
return end.Value - count;
Expand Down
6 changes: 6 additions & 0 deletions LLama/Common/ModelParams.cs
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,12 @@ public record ModelParams
/// <inheritdoc />
public bool NoKqvOffload { get; set; }

/// <inheritdoc />
public float DefragThreshold { get; set; }

/// <inheritdoc />
public bool DoPooling { get; set; }

/// <inheritdoc />
public bool VocabOnly { get; set; }

Expand Down
5 changes: 4 additions & 1 deletion LLama/Extensions/IContextParamsExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -34,14 +34,17 @@ public static void ToLlamaContextParams(this IContextParams @params, out LLamaCo
result.yarn_beta_fast = @params.YarnBetaFast ?? 32f;
result.yarn_beta_slow = @params.YarnBetaSlow ?? 1f;
result.yarn_orig_ctx = @params.YarnOriginalContext ?? 0;
result.rope_scaling_type = @params.YarnScalingType ?? RopeScalingType.LLAMA_ROPE_SCALING_UNSPECIFIED;
result.rope_scaling_type = @params.YarnScalingType ?? RopeScalingType.Unspecified;

result.defrag_threshold = @params.DefragThreshold;

result.cb_eval = IntPtr.Zero;
result.cb_eval_user_data = IntPtr.Zero;

result.type_k = @params.TypeK ?? GGMLType.GGML_TYPE_F16;
result.type_k = @params.TypeV ?? GGMLType.GGML_TYPE_F16;
result.offload_kqv = !@params.NoKqvOffload;
result.do_pooling = @params.DoPooling;

result.n_threads = Threads(@params.Threads);
result.n_threads_batch = Threads(@params.BatchThreads);
Expand Down
33 changes: 33 additions & 0 deletions LLama/LLamaContext.cs
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,35 @@ public sealed class LLamaContext
/// </summary>
public Encoding Encoding { get; }

private uint _generationThreads;
private uint _batchThreads;

/// <summary>
/// Get or set the number of threads used for token generation.
/// Assigning a value immediately pushes the updated thread counts to the native context.
/// </summary>
public uint GenerationThreads
{
    get => _generationThreads;
    set
    {
        _generationThreads = value;
        NativeHandle.SetThreads(value, _batchThreads);
    }
}

/// <summary>
/// Get or set the number of threads used for batch processing.
/// Assigning a value immediately pushes the updated thread counts to the native context.
/// </summary>
public uint BatchThreads
{
    get => _batchThreads;
    set
    {
        _batchThreads = value;
        NativeHandle.SetThreads(_generationThreads, value);
    }
}

/// <summary>
/// Create a new LLamaContext for the given LLamaWeights
/// </summary>
Expand All @@ -75,6 +104,10 @@ public LLamaContext(LLamaWeights model, IContextParams @params, ILogger? logger

@params.ToLlamaContextParams(out var lparams);
NativeHandle = SafeLLamaContextHandle.Create(model.NativeHandle, lparams);

// It's not possible to get these values back from llama.cpp, so store a copy of them here.
_generationThreads = lparams.n_threads;
_batchThreads = lparams.n_threads_batch;
}

/// <summary>
Expand Down
14 changes: 12 additions & 2 deletions LLama/LLamaQuantizer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ public static bool Quantize(string srcFileName, string dstFilename, string ftype
private static bool ValidateFtype(LLamaFtype ftype)
{
// Validation copies from here:
// https://github.com/ggerganov/llama.cpp/blob/d71ac90985854b0905e1abba778e407e17f9f887/llama.cpp#L9613
// https://github.com/ggerganov/llama.cpp/blob/3ab8b3a92ede46df88bc5a2dfca3777de4a2b2b6/llama.cpp#L10965

switch (ftype)
{
Expand All @@ -74,7 +74,7 @@ private static bool ValidateFtype(LLamaFtype ftype)
case LLamaFtype.LLAMA_FTYPE_MOSTLY_Q2_K_S:
case LLamaFtype.LLAMA_FTYPE_MOSTLY_Q2_K:

case LLamaFtype.LLAMA_FTYPE_MOSTLY_Q3_K_XS:
case LLamaFtype.LLAMA_FTYPE_MOSTLY_IQ3_K_XS:
case LLamaFtype.LLAMA_FTYPE_MOSTLY_Q3_K_S:
case LLamaFtype.LLAMA_FTYPE_MOSTLY_Q3_K_M:
case LLamaFtype.LLAMA_FTYPE_MOSTLY_Q3_K_L:
Expand All @@ -89,8 +89,18 @@ private static bool ValidateFtype(LLamaFtype ftype)

case LLamaFtype.LLAMA_FTYPE_MOSTLY_IQ2_XXS:
case LLamaFtype.LLAMA_FTYPE_MOSTLY_IQ2_XS:
case LLamaFtype.LLAMA_FTYPE_MOSTLY_IQ2_S:
case LLamaFtype.LLAMA_FTYPE_MOSTLY_IQ2_M:

case LLamaFtype.LLAMA_FTYPE_MOSTLY_IQ3_XXS:

case LLamaFtype.LLAMA_FTYPE_MOSTLY_IQ1_S:

case LLamaFtype.LLAMA_FTYPE_MOSTLY_IQ4_NL:
case LLamaFtype.LLAMA_FTYPE_MOSTLY_IQ4_XS:

case LLamaFtype.LLAMA_FTYPE_MOSTLY_IQ3_S:
case LLamaFtype.LLAMA_FTYPE_MOSTLY_IQ3_M:
return true;

case LLamaFtype.LLAMA_FTYPE_MOSTLY_Q4_1_SOME_F16:
Expand Down
2 changes: 1 addition & 1 deletion LLama/LLamaStatelessExecutor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ public async IAsyncEnumerable<string> InferAsync(string prompt, IInferenceParams
var n_discard = n_left / 2;

NativeApi.llama_kv_cache_seq_rm(Context.NativeHandle, (LLamaSeqId)0, inferenceParams.TokensKeep + 1, inferenceParams.TokensKeep + n_discard + 1);
NativeApi.llama_kv_cache_seq_shift(Context.NativeHandle, (LLamaSeqId)0, inferenceParams.TokensKeep + 1 + n_discard, n_past, -n_discard);
NativeApi.llama_kv_cache_seq_add(Context.NativeHandle, (LLamaSeqId)0, inferenceParams.TokensKeep + 1 + n_discard, n_past, -n_discard);

n_past -= n_discard;
}
Expand Down
11 changes: 11 additions & 0 deletions LLama/Native/LLamaChatMessage.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
namespace LLama.Native;

/// <summary>
/// A single message in a chat conversation, mirroring the native llama.cpp
/// <c>llama_chat_message</c> struct. Holds raw native string pointers; the
/// lifetime of the pointed-to memory is managed by the caller
/// (presumably null-terminated UTF-8 — TODO confirm against llama.h).
/// </summary>
/// <remarks>llama_chat_message</remarks>
public unsafe struct LLamaChatMessage
{
// Pointer to the native string naming the message role.
public byte* role;
// Pointer to the native string holding the message text.
public byte* content;
}
35 changes: 25 additions & 10 deletions LLama/Native/LLamaContextParams.cs
Original file line number Diff line number Diff line change
Expand Up @@ -51,32 +51,42 @@ public struct LLamaContextParams
/// RoPE base frequency, 0 = from model
/// </summary>
public float rope_freq_base;

/// <summary>
/// RoPE frequency scaling factor, 0 = from model
/// </summary>
public float rope_freq_scale;
public float rope_freq_scale;

/// <summary>
/// YaRN extrapolation mix factor, negative = from model
/// </summary>
public float yarn_ext_factor;
public float yarn_ext_factor;

/// <summary>
/// YaRN magnitude scaling factor
/// </summary>
public float yarn_attn_factor;
public float yarn_attn_factor;

/// <summary>
/// YaRN low correction dim
/// </summary>
public float yarn_beta_fast;
public float yarn_beta_fast;

/// <summary>
/// YaRN high correction dim
/// </summary>
public float yarn_beta_slow;
public float yarn_beta_slow;

/// <summary>
/// YaRN original context size
/// </summary>
public uint yarn_orig_ctx;

/// <summary>
/// Defragment the KV cache if holes/size &gt; defrag_threshold. Set to a value &lt; 0 to disable (default).
/// </summary>
public float defrag_threshold;

/// <summary>
/// ggml_backend_sched_eval_callback
/// </summary>
Expand All @@ -97,11 +107,6 @@ public struct LLamaContextParams
/// </summary>
public GGMLType type_v;

/// <summary>
/// Deprecated!
/// </summary>
private sbyte _mul_mat_q;

/// <summary>
/// Deprecated!
/// </summary>
Expand All @@ -126,6 +131,16 @@ public bool offload_kqv
set => _offload_kqv = Convert.ToSByte(value);
}
private sbyte _offload_kqv;

/// <summary>
/// Whether to pool (sum) embedding results by sequence id (ignored if no pooling layer)
/// </summary>
public bool do_pooling
{
readonly get => Convert.ToBoolean(_do_pooling);
set => _do_pooling = Convert.ToSByte(value);
}
private sbyte _do_pooling;
}
}

37 changes: 36 additions & 1 deletion LLama/Native/LLamaFtype.cs
Original file line number Diff line number Diff line change
Expand Up @@ -124,13 +124,48 @@ public enum LLamaFtype
/// <summary>
/// except 1d tensors
/// </summary>
LLAMA_FTYPE_MOSTLY_Q3_K_XS = 22,
LLAMA_FTYPE_MOSTLY_IQ3_K_XS = 22,

/// <summary>
/// except 1d tensors
/// </summary>
LLAMA_FTYPE_MOSTLY_IQ3_XXS = 23,

/// <summary>
/// except 1d tensors
/// </summary>
LLAMA_FTYPE_MOSTLY_IQ1_S = 24,

/// <summary>
/// except 1d tensors
/// </summary>
LLAMA_FTYPE_MOSTLY_IQ4_NL = 25,

/// <summary>
/// except 1d tensors
/// </summary>
LLAMA_FTYPE_MOSTLY_IQ3_S = 26,

/// <summary>
/// except 1d tensors
/// </summary>
LLAMA_FTYPE_MOSTLY_IQ3_M = 27,

/// <summary>
/// except 1d tensors
/// </summary>
LLAMA_FTYPE_MOSTLY_IQ2_S = 28,

/// <summary>
/// except 1d tensors
/// </summary>
LLAMA_FTYPE_MOSTLY_IQ2_M = 29,

/// <summary>
/// except 1d tensors
/// </summary>
LLAMA_FTYPE_MOSTLY_IQ4_XS = 30,

/// <summary>
/// File type was not specified
/// </summary>
Expand Down
Loading

0 comments on commit a8ba9f0

Please sign in to comment.