ReadOnlySpan<float> in ISamplingPipeline (#538)
* Modified `ISamplingPipeline` to accept a `ReadOnlySpan<float>` of logits directly. This moves the responsibility for copying the logits into the pipeline.
  - Added a flag to `BaseSamplingPipeline` indicating whether a logit copy is necessary, skipping it in most cases.

* Fixed `RestoreProtectedTokens` not working if logit processing is skipped

* Implemented a new greedy sampling pipeline (always samples the most likely token)
  - Moved `Grammar` into `BaseSamplingPipeline`
  - Removed the "protected tokens" concept from `BaseSamplingPipeline`; it was introducing a lot of incidental complexity.
  - Implemented newline logit save/restore in `DefaultSamplingPipeline` (the only place protected tokens were used)

* Implemented pipelines for mirostat v1 and v2
martindevans authored Feb 25, 2024
1 parent 74a3918 commit 91a7967
Showing 9 changed files with 271 additions and 108 deletions.
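
The call-site effect of the interface change is visible in every example diff below: `ISamplingPipeline.Sample` now takes a `ReadOnlySpan<float>`, so callers can pass the logit span returned by `Conversation.Sample()` straight through instead of making a defensive copy first. A minimal before/after sketch (illustrative only; `_sampler`, `_conversation` and `ctx` stand in for the fields used in the example files):

```csharp
// Before this commit: the caller had to copy the logits into a fresh array,
// because the pipeline received a mutable Span<float> and wrote into it.
var logitsCopy = _conversation.Sample().ToArray();
var token = _sampler.Sample(ctx, logitsCopy, Array.Empty<LLamaToken>());

// After this commit: the ReadOnlySpan<float> is passed directly; the pipeline
// itself is responsible for copying only when it actually needs to mutate.
var token2 = _sampler.Sample(ctx, _conversation.Sample(), Array.Empty<LLamaToken>());
```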
3 changes: 1 addition & 2 deletions LLama.Examples/Examples/BatchedExecutorFork.cs
@@ -91,8 +91,7 @@ public void Sample()
 
             // Sample one token
             var ctx = _conversation.Executor.Context.NativeHandle;
-            var logitsCopy = _conversation.Sample().ToArray();
-            var token = _sampler.Sample(ctx, logitsCopy, Array.Empty<LLamaToken>());
+            var token = _sampler.Sample(ctx, _conversation.Sample(), Array.Empty<LLamaToken>());
             _sampler.Accept(ctx, token);
             _decoder.Add(token);
8 changes: 3 additions & 5 deletions LLama.Examples/Examples/BatchedExecutorRewind.cs
@@ -88,7 +88,7 @@ public Node(LLamaContext context)
 
         public LLamaToken Sample(Conversation conversation)
        {
-            var token = Sampler.Sample(_context.NativeHandle, conversation.Sample().ToArray(), Array.Empty<LLamaToken>());
+            var token = Sampler.Sample(_context.NativeHandle, conversation.Sample(), Array.Empty<LLamaToken>());
             _tokens.Add(token);
             return token;
         }
@@ -100,14 +100,12 @@ public void Write(int n_rewind, int depth)
             for (var i = 0; i < _tokens.Count - n_rewind; i++)
                 decoder.Add(_tokens[i]);
 
-            Console.ForegroundColor = ConsoleColor.Green;
-            Console.Write(new string(' ', depth * 3) + decoder.Read().ReplaceLineEndings(" "));
+            AnsiConsole.MarkupLine($"[green]{new string(' ', depth * 3) + decoder.Read().ReplaceLineEndings(" ")}[/]");
 
             for (var i = _tokens.Count - n_rewind; i < _tokens.Count; i++)
                 decoder.Add(_tokens[i]);
 
-            Console.ForegroundColor = ConsoleColor.DarkRed;
-            Console.WriteLine(decoder.Read().ReplaceLineEndings(" "));
+            AnsiConsole.MarkupLine($"[maroon]{decoder.Read().ReplaceLineEndings(" ")}[/]");
         }
 
         public LLamaToken GetToken(int index)
3 changes: 2 additions & 1 deletion LLama/LLamaContext.cs
@@ -147,8 +147,9 @@ public void SaveState(string filename)
     }
 
     /// <summary>
-    /// Get the state data as an opaque handle
+    /// Get the state data as an opaque handle, which can be loaded later using <see cref="LoadState(State)"/>
     /// </summary>
+    /// <remarks>Use <see cref="SaveState"/> if you intend to save this state to disk.</remarks>
     /// <returns></returns>
     public State GetState()
     {
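
A sketch of how the two state APIs divide responsibilities, following the new doc comments (hypothetical usage; `context` is an `LLamaContext` and the file name is arbitrary):

```csharp
// Illustrative only: snapshot in memory, speculate, then roll back.
var state = context.GetState();
// ... run some speculative decoding ...
context.LoadState(state);

// For persistence, the new remarks recommend SaveState instead:
context.SaveState("state.bin");
```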
86 changes: 14 additions & 72 deletions LLama/Sampling/BaseSamplingPipeline.cs
@@ -1,6 +1,4 @@
 using System;
-using System.Buffers;
-using System.Collections.Generic;
 using LLama.Native;
 
 namespace LLama.Sampling;
@@ -11,92 +9,36 @@ namespace LLama.Sampling;
 public abstract class BaseSamplingPipeline
     : ISamplingPipeline
 {
-    private int _savedLogitsCount;
-    private (LLamaToken index, float logit)[]? _savedLogits;
-
-    /// <inheritdoc/>
-    public LLamaToken Sample(SafeLLamaContextHandle ctx, Span<float> logits, ReadOnlySpan<LLamaToken> lastTokens)
-    {
-        var protectedLogits = GetProtectedTokens(ctx);
-        _savedLogitsCount = protectedLogits.Count;
-        _savedLogits = ArrayPool<(LLamaToken, float)>.Shared.Rent(_savedLogitsCount);
-        try
-        {
-            // Save the values of protected logits
-            for (var i = 0; i < protectedLogits.Count; i++)
-            {
-                var index = protectedLogits[i];
-                var value = logits[(int)index];
-                _savedLogits[i] = (index, value);
-            }
-
-            // Process raw logits
-            ProcessLogits(ctx, logits, lastTokens);
-
-            // Automatically restore saved logit values after processing
-            RestoreProtectedTokens(logits);
-
-            // Convert logits into token candidates
-            var candidates = LLamaTokenDataArray.Create(logits);
-
-            // Process token data array
-            return ProcessTokenDataArray(ctx, candidates, lastTokens);
-        }
-        finally
-        {
-            ArrayPool<(LLamaToken, float)>.Shared.Return(_savedLogits);
-            _savedLogits = null;
-            _savedLogitsCount = 0;
-        }
-    }
-
-    /// <inheritdoc />
-    public abstract void Accept(SafeLLamaContextHandle ctx, LLamaToken token);
-
-    #region protected tokens
     /// <summary>
-    /// Get all of the "protected" tokens that cannot be changed by ProcessLogits
+    /// Grammar to constrain valid tokens
     /// </summary>
-    /// <returns></returns>
-    protected abstract IReadOnlyList<LLamaToken> GetProtectedTokens(SafeLLamaContextHandle ctx);
+    public SafeLLamaGrammarHandle? Grammar { get; set; }
 
-    /// <summary>
-    /// Restore the value of the "protected" tokens which were saved before the call to ProcessLogits
-    /// </summary>
-    /// <param name="logits"></param>
-    protected void RestoreProtectedTokens(Span<float> logits)
+    /// <inheritdoc/>
+    public LLamaToken Sample(SafeLLamaContextHandle ctx, ReadOnlySpan<float> logits, ReadOnlySpan<LLamaToken> lastTokens)
     {
-        if (_savedLogits == null)
-            return;
+        // Apply processing to raw logit values
+        logits = ProcessLogits(ctx, logits, lastTokens);
 
-        // The array may be bigger than necessary, get a span of the valid bit
-        var saved = _savedLogits.AsSpan(0, _savedLogitsCount);
-
-        // Restore the values of protected logits
-        for (var i = 0; i < saved.Length; i++)
-            logits[(int)saved[i].index] = saved[i].logit;
+        // Process token data array to select a final token
+        var candidates = LLamaTokenDataArray.Create(logits);
+        candidates.ApplyGrammar(ctx, Grammar);
+        return ProcessTokenDataArray(ctx, candidates, lastTokens);
     }
 
-    /// <summary>
-    /// Restore the value of the "protected" tokens which were saved before the call to ProcessLogits
-    /// </summary>
-    /// <param name="candidates"></param>
-    protected void RestoreProtectedTokens(LLamaTokenDataArray candidates)
+    /// <inheritdoc />
+    public virtual void Accept(SafeLLamaContextHandle ctx, LLamaToken token)
     {
-        if (_savedLogits == null || _savedLogits.Length == 0)
-            return;
-
-        candidates.OverwriteLogits(_savedLogits.AsSpan(0, _savedLogitsCount));
+        Grammar?.AcceptToken(ctx, token);
     }
-    #endregion
 
     /// <summary>
     /// Process the raw logit values
     /// </summary>
     /// <param name="ctx">The context being sampled from</param>
     /// <param name="logits">The logits produced by the model</param>
     /// <param name="lastTokens">A list of tokens recently returned by the model</param>
-    protected abstract void ProcessLogits(SafeLLamaContextHandle ctx, Span<float> logits, ReadOnlySpan<LLamaToken> lastTokens);
+    protected abstract ReadOnlySpan<float> ProcessLogits(SafeLLamaContextHandle ctx, ReadOnlySpan<float> logits, ReadOnlySpan<LLamaToken> lastTokens);
 
     /// <summary>
     /// Process the LLamaTokenDataArray and select a single token
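
With the protected-tokens machinery gone, `BaseSamplingPipeline` reduces to a small template method: transform logits (optionally zero-copy), build candidates, apply the grammar, select. A hypothetical subclass, not part of this commit, showing the zero-copy path; `TopP`, `Temperature` and `SampleToken` are the same `LLamaTokenDataArray` methods used by `DefaultSamplingPipeline` below:

```csharp
using System;
using LLama.Native;

namespace LLama.Sampling;

// Hypothetical example, not in this commit: a pipeline that never mutates the
// raw logits, so ProcessLogits can return the input span and skip any copy.
public class TemperatureSamplingPipeline
    : BaseSamplingPipeline
{
    public float Temperature { get; set; } = 0.8f;
    public float TopP { get; set; } = 0.95f;

    protected override ReadOnlySpan<float> ProcessLogits(SafeLLamaContextHandle ctx, ReadOnlySpan<float> logits, ReadOnlySpan<LLamaToken> lastTokens)
    {
        // No per-logit edits are needed, so return the input unchanged (no copy)
        return logits;
    }

    protected override LLamaToken ProcessTokenDataArray(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, ReadOnlySpan<LLamaToken> lastTokens)
    {
        // Filter the candidates, then sample among what remains
        candidates.TopP(ctx, TopP);
        candidates.Temperature(ctx, Temperature);
        return candidates.SampleToken(ctx);
    }

    public override ISamplingPipeline Clone()
    {
        return new TemperatureSamplingPipeline
        {
            Temperature = Temperature,
            TopP = TopP,
            Grammar = Grammar?.Clone(),
        };
    }
}
```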
100 changes: 74 additions & 26 deletions LLama/Sampling/DefaultSamplingPipeline.cs
@@ -16,15 +16,10 @@ public sealed class DefaultSamplingPipeline
     /// </summary>
     public Dictionary<int, float> LogitBias { get; } = new();
 
-    /// <summary>
-    /// Grammar to constrain valid tokens
-    /// </summary>
-    public SafeLLamaGrammarHandle? Grammar { get; set; }
-
     /// <summary>
     /// Repetition penalty, as described in https://arxiv.org/abs/1909.05858
     /// </summary>
-    public float RepeatPenalty { get; set; } = 1.1f;
+    public float RepeatPenalty { get; set; }
 
     /// <summary>
     /// Frequency penalty as described by OpenAI: https://platform.openai.com/docs/api-reference/chat/create<br />
@@ -43,7 +38,7 @@ public float AlphaFrequency
             _alphaFreq = value;
         }
     }
-    private float _alphaFreq = 0.1f;
+    private float _alphaFreq;
 
     /// <summary>
     /// Presence penalty as described by OpenAI: https://platform.openai.com/docs/api-reference/chat/create<br />
@@ -62,7 +57,7 @@ public float AlphaPresence
             _alphaPresence = value;
         }
     }
-    private float _alphaPresence = 0.1f;
+    private float _alphaPresence;
 
     /// <summary>
     /// Temperature to apply (higher temperature is more "creative")
@@ -99,33 +94,46 @@ public float AlphaPresence
     /// </summary>
     public bool PenalizeNewline { get; set; } = false;
 
-    private readonly LLamaToken[] _newlineToken = new LLamaToken[1];
+    private float[]? _logits;
 
     /// <inheritdoc />
-    protected override IReadOnlyList<LLamaToken> GetProtectedTokens(SafeLLamaContextHandle ctx)
+    protected override ReadOnlySpan<float> ProcessLogits(SafeLLamaContextHandle ctx, ReadOnlySpan<float> logits, ReadOnlySpan<LLamaToken> lastTokens)
     {
-        if (PenalizeNewline)
-            return Array.Empty<LLamaToken>();
+        // Skip work if possible
+        if (LogitBias.Count == 0)
+            return logits;
 
-        _newlineToken[0] = NativeApi.llama_token_nl(ctx.ModelHandle);
-        return _newlineToken;
-    }
+        // Create a temporary array to hold logits
+        if (_logits == null || _logits.Length < logits.Length)
+            _logits = new float[logits.Length];
 
-    /// <inheritdoc />
-    protected override void ProcessLogits(SafeLLamaContextHandle ctx, Span<float> logits, ReadOnlySpan<LLamaToken> lastTokens)
-    {
+        // Copy logits
+        logits.CopyTo(_logits);
+        var mutable = _logits.AsSpan(0, logits.Length);
 
         // Apply logit bias
         foreach (var (key, value) in LogitBias)
-            logits[key] += value;
+            mutable[key] += value;
+
+        return mutable;
     }
 
     /// <inheritdoc />
     protected override LLamaToken ProcessTokenDataArray(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, ReadOnlySpan<LLamaToken> lastTokens)
     {
-        // Apply penalties to candidates
-        candidates.RepetitionPenalty(ctx, lastTokens, RepeatPenalty, AlphaFrequency, AlphaPresence);
+        // Only apply repetition penalty if we really must. Otherwise avoid all this work
+        if (lastTokens.Length > 0 && (RepeatPenalty != 0 || AlphaFrequency != 0 || AlphaPresence != 0))
+        {
+            // Save the logit value for the newline token
+            var (nlIndex, nlLogit) = PenalizeNewline ? GetNewlineLogit(ctx, candidates) : (-1, 0);
 
-        // Restore protected tokens, so they are not affected by repetition penalties
-        RestoreProtectedTokens(candidates);
+            // Apply penalties to candidates
+            candidates.RepetitionPenalty(ctx, lastTokens, RepeatPenalty, AlphaFrequency, AlphaPresence);
+
+            // Restore newline token
+            if (!PenalizeNewline)
+                SetNewlineLogit(ctx, candidates, nlIndex, nlLogit);
+        }
 
         // Apply the normal llama.cpp pipeline
         candidates.ApplyGrammar(ctx, Grammar);
@@ -135,12 +143,52 @@ protected override LLamaToken ProcessTokenDataArray(SafeLLamaContextHandle ctx,
         candidates.TopP(ctx, TopP);
         candidates.MinP(ctx, MinP);
         candidates.Temperature(ctx, Temperature);
-        var id = candidates.SampleToken(ctx);
-
-        Grammar?.AcceptToken(ctx, id);
-        return id;
+        return candidates.SampleToken(ctx);
     }
 
+    private static (int, float) GetNewlineLogit(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates)
+    {
+        var nlToken = NativeApi.llama_token_nl(ctx.ModelHandle);
+
+        // Try using the ID as an index
+        if (candidates.data.Span[(int)nlToken].id == nlToken)
+            return ((int)nlToken, candidates.data.Span[(int)nlToken].logit);
+
+        // Exhaustive search
+        var span = candidates.data.Span;
+        for (var i = 0; i < span.Length; i++)
+        {
+            if (span[i].id == nlToken)
+                return (i, span[i].logit);
+        }
+
+        return (-1, 0);
+    }
+
+    private static void SetNewlineLogit(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, int indexHint, float logit)
+    {
+        var nlToken = NativeApi.llama_token_nl(ctx.ModelHandle);
+
+        // Try checking the index where we found it last time. It might not be there if `RepetitionPenalty` changed order
+        if (indexHint >= 0 && candidates.data.Span[indexHint].id == nlToken)
+        {
+            candidates.data.Span[indexHint].logit = logit;
+            return;
+        }
+
+        // Didn't find it, do an exhaustive search for it
+        var span = candidates.data.Span;
+        for (var i = 0; i < candidates.data.Length; i++)
+        {
+            if (span[i].id == nlToken)
+            {
+                span[i].logit = logit;
+                return;
+            }
+        }
    }
 
     /// <inheritdoc />
     public override void Accept(SafeLLamaContextHandle ctx, LLamaToken token)
     {
         Grammar?.AcceptToken(ctx, token);
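
For orientation (not part of the diff), this is roughly how the reworked pipeline is driven. The values below are arbitrary illustrations; note that this commit removes the old penalty defaults (`1.1f`, `0.1f`), so with all three penalty knobs at 0 the penalty block above is skipped entirely:

```csharp
using LLama.Sampling;

// Illustrative configuration; values are arbitrary, not defaults from this commit.
var pipeline = new DefaultSamplingPipeline
{
    Temperature = 0.7f,
    TopP = 0.9f,
    RepeatPenalty = 1.1f,            // any non-zero penalty enables the penalty branch
    PenalizeNewline = false,         // newline logit is saved/restored around RepetitionPenalty
    LogitBias = { [15043] = -10f },  // hypothetical token id; a non-empty bias triggers the copy in ProcessLogits
};
```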
32 changes: 32 additions & 0 deletions LLama/Sampling/GreedySamplingPipeline.cs
@@ -0,0 +1,32 @@
using System;
using LLama.Native;

namespace LLama.Sampling;

/// <summary>
/// A sampling pipeline which always selects the most likely token
/// </summary>
public class GreedySamplingPipeline
: BaseSamplingPipeline
{
/// <inheritdoc />
protected override ReadOnlySpan<float> ProcessLogits(SafeLLamaContextHandle ctx, ReadOnlySpan<float> logits, ReadOnlySpan<LLamaToken> lastTokens)
{
return logits;
}

/// <inheritdoc />
protected override LLamaToken ProcessTokenDataArray(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, ReadOnlySpan<LLamaToken> lastTokens)
{
return candidates.SampleTokenGreedy(ctx);
}

/// <inheritdoc />
public override ISamplingPipeline Clone()
{
return new GreedySamplingPipeline
{
Grammar = Grammar?.Clone()
};
}
}
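
A usage sketch (not from the commit): the greedy pipeline drops into the same call sites as any other `ISamplingPipeline`, here mirroring the `BatchedExecutorFork` example above, with `conversation` assumed to be in scope:

```csharp
using System;
using LLama.Native;
using LLama.Sampling;

var sampler = new GreedySamplingPipeline();

// Sample the single most likely token from the conversation's logits
var ctx = conversation.Executor.Context.NativeHandle;
var token = sampler.Sample(ctx, conversation.Sample(), Array.Empty<LLamaToken>());
sampler.Accept(ctx, token);   // advances the grammar state, if a Grammar is set
```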
4 changes: 2 additions & 2 deletions LLama/Sampling/ISamplingPipeline.cs
@@ -19,7 +19,7 @@ public interface ISamplingPipeline
     /// <param name="logits">The logits produced by the model</param>
     /// <param name="lastTokens">A span of tokens recently returned by the model</param>
     /// <returns></returns>
-    LLamaToken Sample(SafeLLamaContextHandle ctx, Span<float> logits, ReadOnlySpan<LLamaToken> lastTokens);
+    LLamaToken Sample(SafeLLamaContextHandle ctx, ReadOnlySpan<float> logits, ReadOnlySpan<LLamaToken> lastTokens);
 
     /// <summary>
     /// Update the pipeline, with knowledge that a particular token was just accepted
@@ -53,7 +53,7 @@ public static class ISamplingPipelineExtensions
     /// <param name="logits">The logits produced by the model</param>
     /// <param name="lastTokens">A list of tokens recently returned by the model</param>
     /// <returns></returns>
-    public static LLamaToken Sample(this ISamplingPipeline pipeline, SafeLLamaContextHandle ctx, Span<float> logits, List<LLamaToken> lastTokens)
+    public static LLamaToken Sample(this ISamplingPipeline pipeline, SafeLLamaContextHandle ctx, ReadOnlySpan<float> logits, List<LLamaToken> lastTokens)
     {
 #if NET5_0_OR_GREATER
         var span = CollectionsMarshal.AsSpan(lastTokens);
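
The extension method exists so callers holding a `List<LLamaToken>` history don't have to convert it by hand; on .NET 5+ it borrows the list's backing array via `CollectionsMarshal.AsSpan`, with no copy. A hypothetical call site (`pipeline`, `ctx` and `logits` assumed to be in scope):

```csharp
using System.Collections.Generic;
using LLama.Native;

var history = new List<LLamaToken>();

// List<LLamaToken> overload: forwards to the span-based interface method
var next = pipeline.Sample(ctx, logits, history);
history.Add(next);
```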