From 91a7967869ec1d3441e9b819bc164d178d772b96 Mon Sep 17 00:00:00 2001 From: Martin Evans Date: Sun, 25 Feb 2024 02:12:00 +0000 Subject: [PATCH] `ReadOnlySpan` in ISamplingPipeline (#538) * - Modified ISamplingPipeline to accept `ReadOnlySpan` of logits directly. This moves responsibility to copy the logits into the pipeline. - Added a flag to `BaseSamplingPipeline` indicating if a logit copy is necessary. Skipping it in most cases. * Fixed `RestoreProtectedTokens` not working if logit processing is skipped * - Implemented a new greedy sampling pipeline (always sample most likely token) - Moved `Grammar` into `BaseSamplingPipeline` - Removed "protected tokens" concept from `BaseSamplingPipeline`. Was introducing a lot of incidental complexity. - Implemented newline logit save/restore in `DefaultSamplingPipeline` (only place protected tokens was used) * Implemented pipelines for mirostat v1 and v2 --- .../Examples/BatchedExecutorFork.cs | 3 +- .../Examples/BatchedExecutorRewind.cs | 8 +- LLama/LLamaContext.cs | 3 +- LLama/Sampling/BaseSamplingPipeline.cs | 86 +++------------ LLama/Sampling/DefaultSamplingPipeline.cs | 100 +++++++++++++----- LLama/Sampling/GreedySamplingPipeline.cs | 32 ++++++ LLama/Sampling/ISamplingPipeline.cs | 4 +- LLama/Sampling/Mirostat2SamplingPipeline.cs | 71 +++++++++++++ LLama/Sampling/MirostatSamplingPipeline.cs | 72 +++++++++++++ 9 files changed, 271 insertions(+), 108 deletions(-) create mode 100644 LLama/Sampling/GreedySamplingPipeline.cs create mode 100644 LLama/Sampling/Mirostat2SamplingPipeline.cs create mode 100644 LLama/Sampling/MirostatSamplingPipeline.cs diff --git a/LLama.Examples/Examples/BatchedExecutorFork.cs b/LLama.Examples/Examples/BatchedExecutorFork.cs index b42f436bb..861eecc75 100644 --- a/LLama.Examples/Examples/BatchedExecutorFork.cs +++ b/LLama.Examples/Examples/BatchedExecutorFork.cs @@ -91,8 +91,7 @@ public void Sample() // Sample one token var ctx = _conversation.Executor.Context.NativeHandle; - var logitsCopy = _conversation.Sample().ToArray(); - var token = _sampler.Sample(ctx, logitsCopy, Array.Empty()); + var token = _sampler.Sample(ctx, _conversation.Sample(), Array.Empty()); _sampler.Accept(ctx, token); _decoder.Add(token); diff --git a/LLama.Examples/Examples/BatchedExecutorRewind.cs b/LLama.Examples/Examples/BatchedExecutorRewind.cs index 4a8c3ab29..54c6e5f90 100644 --- a/LLama.Examples/Examples/BatchedExecutorRewind.cs +++ b/LLama.Examples/Examples/BatchedExecutorRewind.cs @@ -88,7 +88,7 @@ public Node(LLamaContext context) public LLamaToken Sample(Conversation conversation) { - var token = Sampler.Sample(_context.NativeHandle, conversation.Sample().ToArray(), Array.Empty()); + var token = Sampler.Sample(_context.NativeHandle, conversation.Sample(), Array.Empty()); _tokens.Add(token); return token; } @@ -100,14 +100,12 @@ public void Write(int n_rewind, int depth) for (var i = 0; i < _tokens.Count - n_rewind; i++) decoder.Add(_tokens[i]); - Console.ForegroundColor = ConsoleColor.Green; - Console.Write(new string(' ', depth * 3) + decoder.Read().ReplaceLineEndings(" ")); + AnsiConsole.MarkupLine($"[green]{new string(' ', depth * 3) + decoder.Read().ReplaceLineEndings(" ")}[/]"); for (var i = _tokens.Count - n_rewind; i < _tokens.Count; i++) decoder.Add(_tokens[i]); - Console.ForegroundColor = ConsoleColor.DarkRed; - Console.WriteLine(decoder.Read().ReplaceLineEndings(" ")); + AnsiConsole.MarkupLine($"[maroon]{decoder.Read().ReplaceLineEndings(" ")}[/]"); } public LLamaToken GetToken(int index) diff --git a/LLama/LLamaContext.cs b/LLama/LLamaContext.cs index 9075c89fc..44531a9f7 100644 --- a/LLama/LLamaContext.cs +++ b/LLama/LLamaContext.cs @@ -147,8 +147,9 @@ public void SaveState(string filename) } /// - /// Get the state data as an opaque handle + /// Get the state data as an opaque handle, which can be loaded later using /// + /// Use if you intend to save this state to disk. /// public State GetState() { diff --git a/LLama/Sampling/BaseSamplingPipeline.cs b/LLama/Sampling/BaseSamplingPipeline.cs index b86001aa0..aafb59329 100644 --- a/LLama/Sampling/BaseSamplingPipeline.cs +++ b/LLama/Sampling/BaseSamplingPipeline.cs @@ -1,6 +1,4 @@ using System; -using System.Buffers; -using System.Collections.Generic; using LLama.Native; namespace LLama.Sampling; @@ -11,84 +9,28 @@ namespace LLama.Sampling; public abstract class BaseSamplingPipeline : ISamplingPipeline { - private int _savedLogitsCount; - private (LLamaToken index, float logit)[]? _savedLogits; - - /// - public LLamaToken Sample(SafeLLamaContextHandle ctx, Span logits, ReadOnlySpan lastTokens) - { - var protectedLogits = GetProtectedTokens(ctx); - _savedLogitsCount = protectedLogits.Count; - _savedLogits = ArrayPool<(LLamaToken, float)>.Shared.Rent(_savedLogitsCount); - try - { - // Save the values of protected logits - for (var i = 0; i < protectedLogits.Count; i++) - { - var index = protectedLogits[i]; - var value = logits[(int)index]; - _savedLogits[i] = (index, value); - } - - // Process raw logits - ProcessLogits(ctx, logits, lastTokens); - - // Automatically restore saved logit values after processing - RestoreProtectedTokens(logits); - - // Convert logits into token candidates - var candidates = LLamaTokenDataArray.Create(logits); - - // Process token data array - return ProcessTokenDataArray(ctx, candidates, lastTokens); - } - finally - { - ArrayPool<(LLamaToken, float)>.Shared.Return(_savedLogits); - _savedLogits = null; - _savedLogitsCount = 0; - } - } - - /// - public abstract void Accept(SafeLLamaContextHandle ctx, LLamaToken token); - - #region protected tokens /// - /// Get all of the "protected" tokens that cannot be changed by ProcessLogits + /// Grammar to constrain valid tokens /// - /// - protected abstract IReadOnlyList GetProtectedTokens(SafeLLamaContextHandle ctx); + public SafeLLamaGrammarHandle? Grammar { get; set; } - /// - /// Restore the value of the "protected" tokens which were saved before the call to ProcessLogits - /// - /// - protected void RestoreProtectedTokens(Span logits) + /// + public LLamaToken Sample(SafeLLamaContextHandle ctx, ReadOnlySpan logits, ReadOnlySpan lastTokens) { - if (_savedLogits == null) - return; + // Apply processing to raw logit values + logits = ProcessLogits(ctx, logits, lastTokens); - // The array may be bigger than necessary, get a span of the valid bit - var saved = _savedLogits.AsSpan(0, _savedLogitsCount); - - // Restore the values of protected logits - for (var i = 0; i < saved.Length; i++) - logits[(int)saved[i].index] = saved[i].logit; + // Process token data array to select a final token + var candidates = LLamaTokenDataArray.Create(logits); + candidates.ApplyGrammar(ctx, Grammar); + return ProcessTokenDataArray(ctx, candidates, lastTokens); } - /// - /// Restore the value of the "protected" tokens which were saved before the call to ProcessLogits - /// - /// - protected void RestoreProtectedTokens(LLamaTokenDataArray candidates) + /// + public virtual void Accept(SafeLLamaContextHandle ctx, LLamaToken token) { - if (_savedLogits == null || _savedLogits.Length == 0) - return; - - candidates.OverwriteLogits(_savedLogits.AsSpan(0, _savedLogitsCount)); + Grammar?.AcceptToken(ctx, token); } - #endregion /// /// Process the raw logit values @@ -96,7 +38,7 @@ protected void RestoreProtectedTokens(LLamaTokenDataArray candidates) /// The context being sampled from /// The logits produced by the model /// A list of tokens recently returned by the model - protected abstract void ProcessLogits(SafeLLamaContextHandle ctx, Span logits, ReadOnlySpan lastTokens); + protected abstract ReadOnlySpan ProcessLogits(SafeLLamaContextHandle ctx, ReadOnlySpan logits, ReadOnlySpan lastTokens); /// /// Process the LLamaTokenDataArray and select a single token diff --git a/LLama/Sampling/DefaultSamplingPipeline.cs b/LLama/Sampling/DefaultSamplingPipeline.cs index 531f34faa..071b5c19d 100644 --- a/LLama/Sampling/DefaultSamplingPipeline.cs +++ b/LLama/Sampling/DefaultSamplingPipeline.cs @@ -16,15 +16,10 @@ public sealed class DefaultSamplingPipeline /// public Dictionary LogitBias { get; } = new(); - /// - /// Grammar to constrain valid tokens - /// - public SafeLLamaGrammarHandle? Grammar { get; set; } - /// /// Repetition penalty, as described in https://arxiv.org/abs/1909.05858 /// - public float RepeatPenalty { get; set; } = 1.1f; + public float RepeatPenalty { get; set; } /// /// Frequency penalty as described by OpenAI: https://platform.openai.com/docs/api-reference/chat/create
@@ -43,7 +38,7 @@ public float AlphaFrequency _alphaFreq = value; } } - private float _alphaFreq = 0.1f; + private float _alphaFreq; /// /// Presence penalty as described by OpenAI: https://platform.openai.com/docs/api-reference/chat/create
@@ -62,7 +57,7 @@ public float AlphaPresence _alphaPresence = value; } } - private float _alphaPresence = 0.1f; + private float _alphaPresence; /// /// Temperature to apply (higher temperature is more "creative") @@ -99,33 +94,46 @@ public float AlphaPresence /// public bool PenalizeNewline { get; set; } = false; - private readonly LLamaToken[] _newlineToken = new LLamaToken[1]; + private float[]? _logits; /// - protected override IReadOnlyList GetProtectedTokens(SafeLLamaContextHandle ctx) + protected override ReadOnlySpan ProcessLogits(SafeLLamaContextHandle ctx, ReadOnlySpan logits, ReadOnlySpan lastTokens) { - if (PenalizeNewline) - return Array.Empty(); + // Skip work if possible + if (LogitBias.Count == 0) + return logits; - _newlineToken[0] = NativeApi.llama_token_nl(ctx.ModelHandle); - return _newlineToken; - } + // Create a temporary array to hold logits + if (_logits == null || _logits.Length < logits.Length) + _logits = new float[logits.Length]; - /// - protected override void ProcessLogits(SafeLLamaContextHandle ctx, Span logits, ReadOnlySpan lastTokens) - { + // Copy logits + logits.CopyTo(_logits); + var mutable = _logits.AsSpan(0, logits.Length); + + // Apply logit bias foreach (var (key, value) in LogitBias) - logits[key] += value; + mutable[key] += value; + + return mutable; } /// protected override LLamaToken ProcessTokenDataArray(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, ReadOnlySpan lastTokens) { - // Apply penalties to candidates - candidates.RepetitionPenalty(ctx, lastTokens, RepeatPenalty, AlphaFrequency, AlphaPresence); + // Only apply repetition penalty if we really must. Otherwise avoid all this work + if (lastTokens.Length > 0 && (RepeatPenalty != 0 || AlphaFrequency != 0 || AlphaPresence != 0)) + { + // Save the logit value for the newline token + var (nlIndex, nlLogit) = PenalizeNewline ? GetNewlineLogit(ctx, candidates) : (-1, 0); + + // Apply penalties to candidates + candidates.RepetitionPenalty(ctx, lastTokens, RepeatPenalty, AlphaFrequency, AlphaPresence); - // Restore protected tokens, so they are not affected by repetition penalties - RestoreProtectedTokens(candidates); + // Restore newline token + if (!PenalizeNewline) + SetNewlineLogit(ctx, candidates, nlIndex, nlLogit); + } // Apply the normal llama.cpp pipeline candidates.ApplyGrammar(ctx, Grammar); @@ -135,12 +143,52 @@ protected override LLamaToken ProcessTokenDataArray(SafeLLamaContextHandle ctx, candidates.TopP(ctx, TopP); candidates.MinP(ctx, MinP); candidates.Temperature(ctx, Temperature); - var id = candidates.SampleToken(ctx); + return candidates.SampleToken(ctx); + } + + private static (int, float) GetNewlineLogit(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates) + { + var nlToken = NativeApi.llama_token_nl(ctx.ModelHandle); + + // Try using the ID as an index + if (candidates.data.Span[(int)nlToken].id == nlToken) + return ((int)nlToken, candidates.data.Span[(int)nlToken].logit); + + // Exhaustive search + var span = candidates.data.Span; + for (var i = 0; i < span.Length; i++) + { + if (span[i].id == nlToken) + return (i, span[i].logit); + } - Grammar?.AcceptToken(ctx, id); - return id; + return (-1, 0); } + private static void SetNewlineLogit(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, int indexHint, float logit) + { + var nlToken = NativeApi.llama_token_nl(ctx.ModelHandle); + + // Try checking the index where we found it last time. It might not be there if `RepetitionPenalty` changed order + if (indexHint >= 0 && candidates.data.Span[indexHint].id == nlToken) + { + candidates.data.Span[indexHint].logit = logit; + return; + } + + // Didn't find it, do an exhaustive search for it + var span = candidates.data.Span; + for (var i = 0; i < candidates.data.Length; i++) + { + if (span[i].id == nlToken) + { + span[i].logit = logit; + return; + } + } + } + + /// public override void Accept(SafeLLamaContextHandle ctx, LLamaToken token) { Grammar?.AcceptToken(ctx, token); diff --git a/LLama/Sampling/GreedySamplingPipeline.cs b/LLama/Sampling/GreedySamplingPipeline.cs new file mode 100644 index 000000000..81b2d3cdd --- /dev/null +++ b/LLama/Sampling/GreedySamplingPipeline.cs @@ -0,0 +1,32 @@ +using System; +using LLama.Native; + +namespace LLama.Sampling; + +/// +/// A sampling pipeline which always selects the most likely token +/// +public class GreedySamplingPipeline + : BaseSamplingPipeline +{ + /// + protected override ReadOnlySpan ProcessLogits(SafeLLamaContextHandle ctx, ReadOnlySpan logits, ReadOnlySpan lastTokens) + { + return logits; + } + + /// + protected override LLamaToken ProcessTokenDataArray(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, ReadOnlySpan lastTokens) + { + return candidates.SampleTokenGreedy(ctx); + } + + /// + public override ISamplingPipeline Clone() + { + return new GreedySamplingPipeline + { + Grammar = Grammar?.Clone() + }; + } +} \ No newline at end of file diff --git a/LLama/Sampling/ISamplingPipeline.cs b/LLama/Sampling/ISamplingPipeline.cs index b538d1feb..53c8c7c66 100644 --- a/LLama/Sampling/ISamplingPipeline.cs +++ b/LLama/Sampling/ISamplingPipeline.cs @@ -19,7 +19,7 @@ public interface ISamplingPipeline /// The logits produced by the model /// A span of tokens recently returned by the model /// - LLamaToken Sample(SafeLLamaContextHandle ctx, Span logits, ReadOnlySpan lastTokens); + LLamaToken Sample(SafeLLamaContextHandle ctx, ReadOnlySpan logits, ReadOnlySpan lastTokens); /// /// Update the pipeline, with knowledge that a particular token was just accepted @@ -53,7 +53,7 @@ public static class ISamplingPipelineExtensions /// The logits produced by the model /// A list of tokens recently returned by the model /// - public static LLamaToken Sample(this ISamplingPipeline pipeline, SafeLLamaContextHandle ctx, Span logits, List lastTokens) + public static LLamaToken Sample(this ISamplingPipeline pipeline, SafeLLamaContextHandle ctx, ReadOnlySpan logits, List lastTokens) { #if NET5_0_OR_GREATER var span = CollectionsMarshal.AsSpan(lastTokens); diff --git a/LLama/Sampling/Mirostat2SamplingPipeline.cs b/LLama/Sampling/Mirostat2SamplingPipeline.cs new file mode 100644 index 000000000..dcdc4197b --- /dev/null +++ b/LLama/Sampling/Mirostat2SamplingPipeline.cs @@ -0,0 +1,71 @@ +using System; +using LLama.Native; + +namespace LLama.Sampling; + +/// +/// A sampling pipeline which uses mirostat (v2) to select tokens +/// +public class Mirostate2SamplingPipeline + : BaseSamplingPipeline +{ + private const float DEFAULT_TAU = 5; + + private float _mu = DEFAULT_TAU * 2; + /// + /// Currently learned mu value + /// + public float Mu => _mu; + + private float _tau = DEFAULT_TAU; + /// + /// target entropy + /// + public float Tau + { + get => _tau; + set + { + _tau = value; + _mu = value * 2; + } + } + + /// + /// learning rate + /// + public float Eta { get; set; } = 0.1f; + + /// + protected override ReadOnlySpan ProcessLogits(SafeLLamaContextHandle ctx, ReadOnlySpan logits, ReadOnlySpan lastTokens) + { + return logits; + } + + /// + protected override LLamaToken ProcessTokenDataArray(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, ReadOnlySpan lastTokens) + { + return candidates.SampleTokenMirostat2(ctx, Tau, Eta, ref _mu); + } + + /// + public override void Reset() + { + base.Reset(); + + _mu = Tau * 2; + } + + /// + public override ISamplingPipeline Clone() + { + return new Mirostate2SamplingPipeline + { + Grammar = Grammar?.Clone(), + + _mu = _mu, + _tau = _tau, + Eta = Eta + }; + } +} \ No newline at end of file diff --git a/LLama/Sampling/MirostatSamplingPipeline.cs b/LLama/Sampling/MirostatSamplingPipeline.cs new file mode 100644 index 000000000..65d360073 --- /dev/null +++ b/LLama/Sampling/MirostatSamplingPipeline.cs @@ -0,0 +1,72 @@ +using System; +using LLama.Native; + +namespace LLama.Sampling; + +/// +/// A sampling pipeline which uses mirostat (v1) to select tokens +/// +public class MirostateSamplingPipeline + : BaseSamplingPipeline +{ + private const int MIROSTAT_M = 100; + private const float DEFAULT_TAU = 5; + + private float _mu = DEFAULT_TAU * 2; + /// + /// Currently learned mu value + /// + public float Mu => _mu; + + private float _tau = DEFAULT_TAU; + /// + /// target entropy + /// + public float Tau + { + get => _tau; + set + { + _tau = value; + _mu = value * 2; + } + } + + /// + /// learning rate + /// + public float Eta { get; set; } = 0.1f; + + /// + protected override ReadOnlySpan ProcessLogits(SafeLLamaContextHandle ctx, ReadOnlySpan logits, ReadOnlySpan lastTokens) + { + return logits; + } + + /// + protected override LLamaToken ProcessTokenDataArray(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, ReadOnlySpan lastTokens) + { + return candidates.SampleTokenMirostat(ctx, Tau, Eta, MIROSTAT_M, ref _mu); + } + + /// + public override void Reset() + { + base.Reset(); + + _mu = Tau * 2; + } + + /// + public override ISamplingPipeline Clone() + { + return new MirostateSamplingPipeline + { + Grammar = Grammar?.Clone(), + + _mu = _mu, + _tau = _tau, + Eta = Eta + }; + } +} \ No newline at end of file