Commit

Merge pull request #404 from martindevans/switched_to_LLamaToken_struct

LLamaToken Struct

martindevans authored Jan 9, 2024
2 parents d9b4e1f + 82727c4 commit 402a110

Showing 29 changed files with 196 additions and 168 deletions.
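
Note on the new type: LLamaToken itself lives in LLama.Native and its definition is not among the files shown on this page. Purely for orientation, a minimal, hypothetical sketch of such a wrapper is given below; it assumes only the two conversions the diff actually relies on (implicit int to LLamaToken in the unit tests, explicit LLamaToken to int when indexing logits) and is not the committed implementation.

// Hypothetical sketch only; not the code from this commit.
namespace LLama.Native
{
    public readonly record struct LLamaToken
    {
        private readonly int _value;

        private LLamaToken(int value) => _value = value;

        // Allows collection initialisers such as `new LLamaToken[] { 1, 450, 4996, ... }` in the tests below.
        public static implicit operator LLamaToken(int value) => new(value);

        // Allows indexing a logits array with `(int)token`, as in LLamaContext.ApplyPenalty below.
        public static explicit operator int(LLamaToken token) => token._value;

        public override string ToString() => _value.ToString();
    }
}
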
3 changes: 1 addition & 2 deletions LLama.Examples/Examples/BatchedDecoding.cs
@@ -1,6 +1,5 @@
using System.Diagnostics;
using System.Text;
using LLama.Abstractions;
using LLama.Common;
using LLama.Native;

@@ -94,7 +93,7 @@ public static async Task Run()
var n_cur = batch.NativeBatch.n_tokens;
var n_decode = 0;

var streams = new List<int>[n_parallel];
var streams = new List<LLamaToken>[n_parallel];
for (var i = 0; i < n_parallel; i++)
streams[i] = new();

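For context on the stream change above: each parallel sequence now collects LLamaToken values instead of raw ints, and the obsolete DeTokenize path gives way to the StreamingTokenDecoder referenced later in this diff. The sketch below shows how such a stream could be decoded to text; it assumes the decoder exposes Add and Read members (an assumption, since the decoder itself is not shown here) and that context is an already-constructed LLamaContext.

using System.Collections.Generic;
using LLama;
using LLama.Native;

// Sketch under the assumptions stated above; model loading and the sampling loop are omitted.
const int n_parallel = 4;
var streams = new List<LLamaToken>[n_parallel];
for (var i = 0; i < n_parallel; i++)
    streams[i] = new();

// ... the sampling loop appends each chosen token with streams[i].Add(token) ...

var decoder = new StreamingTokenDecoder(context);   // constructor usage matches LLamaExecutorBase below
foreach (var token in streams[0])
    decoder.Add(token);                              // Add(LLamaToken) is assumed
var text = decoder.Read();                           // Read() returning the decoded text is assumed
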
7 changes: 4 additions & 3 deletions LLama.Unittest/LLamaContextTests.cs
@@ -1,4 +1,5 @@
using LLama.Common;
using LLama.Native;

namespace LLama.Unittest
{
@@ -37,23 +38,23 @@ public void Tokenize()
{
var tokens = _context.Tokenize("The quick brown fox", true);

Assert.Equal(new[] { 1, 450, 4996, 17354, 1701, 29916 }, tokens);
Assert.Equal(new LLamaToken[] { 1, 450, 4996, 17354, 1701, 29916 }, tokens);
}

[Fact]
public void TokenizeWithoutBOS()
{
var tokens = _context.Tokenize("The quick brown fox", false);

Assert.Equal(new[] { 450, 4996, 17354, 1701, 29916 }, tokens);
Assert.Equal(new LLamaToken[] { 450, 4996, 17354, 1701, 29916 }, tokens);
}

[Fact]
public void TokenizeEmpty()
{
var tokens = _context.Tokenize("", false);

Assert.Equal(Array.Empty<int>(), tokens);
Assert.Equal(Array.Empty<LLamaToken>(), tokens);
}
}
}
2 changes: 1 addition & 1 deletion LLama.Web/Common/InferenceOptions.cs
@@ -17,7 +17,7 @@ public class InferenceOptions
public int MaxTokens { get; set; } = -1;

/// <inheritdoc />
public Dictionary<int, float>? LogitBias { get; set; } = null;
public Dictionary<LLamaToken, float>? LogitBias { get; set; } = null;

/// <inheritdoc />
public IReadOnlyList<string> AntiPrompts { get; set; } = Array.Empty<string>();
2 changes: 1 addition & 1 deletion LLama/Abstractions/IInferenceParams.cs
@@ -24,7 +24,7 @@ public interface IInferenceParams
/// <summary>
/// logit bias for specific tokens
/// </summary>
public Dictionary<int, float>? LogitBias { get; set; }
public Dictionary<LLamaToken, float>? LogitBias { get; set; }

/// <summary>
/// Sequences where the model will stop generating further tokens.
4 changes: 1 addition & 3 deletions LLama/Common/InferenceParams.cs
@@ -6,8 +6,6 @@

namespace LLama.Common
{
using llama_token = Int32;

/// <summary>
/// The paramters used for inference.
/// </summary>
@@ -28,7 +26,7 @@ public record InferenceParams
/// <summary>
/// logit bias for specific tokens
/// </summary>
public Dictionary<llama_token, float>? LogitBias { get; set; } = null;
public Dictionary<LLamaToken, float>? LogitBias { get; set; } = null;

/// <summary>
/// Sequences where the model will stop generating further tokens.
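Since the LogitBias dictionaries above are now keyed by LLamaToken, callers build them with tokens; the implicit int-to-LLamaToken conversion keeps integer literals usable. A minimal usage sketch follows (the token ids are placeholders chosen for illustration, not meaningful vocabulary entries):

using System.Collections.Generic;
using LLama.Common;
using LLama.Native;

var inferenceParams = new InferenceParams
{
    // Keys are LLamaToken; integer literals convert implicitly.
    LogitBias = new Dictionary<LLamaToken, float>
    {
        [450]   = 5.0f,    // nudge a (placeholder) token upwards
        [29916] = -100.0f, // effectively suppress a (placeholder) token
    },
    MaxTokens = 64,
};
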
4 changes: 2 additions & 2 deletions LLama/Extensions/IReadOnlyListExtensions.cs
@@ -38,7 +38,7 @@ internal static class IReadOnlyListExtensions
/// <returns></returns>
[Obsolete("Use an Antiprompt processor instead")]
internal static bool TokensEndsWithAnyString<TTokens, TQueries>(this TTokens tokens, TQueries? queries, SafeLlamaModelHandle model, Encoding encoding)
where TTokens : IReadOnlyList<int>
where TTokens : IReadOnlyList<LLamaToken>
where TQueries : IReadOnlyList<string>
{
if (queries == null || queries.Count == 0 || tokens.Count == 0)
@@ -79,7 +79,7 @@ internal static bool TokensEndsWithAnyString<TTokens, TQueries>(this TTokens tok
/// <returns></returns>
[Obsolete("Use an Antiprompt processor instead")]
internal static bool TokensEndsWithAnyString<TTokens>(this TTokens tokens, IList<string>? queries, SafeLlamaModelHandle model, Encoding encoding)
where TTokens : IReadOnlyList<int>
where TTokens : IReadOnlyList<LLamaToken>
{
if (queries == null || queries.Count == 0 || tokens.Count == 0)
return false;
36 changes: 17 additions & 19 deletions LLama/LLamaContext.cs
@@ -15,8 +15,6 @@

namespace LLama
{
using llama_token = Int32;

/// <summary>
/// A llama_context, which holds all the context required to interact with a model
/// </summary>
@@ -93,7 +91,7 @@ public void SetSeed(uint seed)
/// <param name="addBos">Whether to add a bos to the text.</param>
/// <param name="special">Allow tokenizing special and/or control tokens which otherwise are not exposed and treated as plaintext.</param>
/// <returns></returns>
public llama_token[] Tokenize(string text, bool addBos = true, bool special = false)
public LLamaToken[] Tokenize(string text, bool addBos = true, bool special = false)
{
return NativeHandle.Tokenize(text, addBos, special, Encoding);
}
@@ -104,7 +102,7 @@ public llama_token[] Tokenize(string text, bool addBos = true, bool special = fa
/// <param name="tokens"></param>
/// <returns></returns>
[Obsolete("Use a `StreamingTokenDecoder` instead")]
public string DeTokenize(IReadOnlyList<llama_token> tokens)
public string DeTokenize(IReadOnlyList<LLamaToken> tokens)
{
// Do **not** use this method as an example of how to correctly use the StreamingTokenDecoder!
// It should be kept around for the entire time you are decoding one stream of tokens.
@@ -219,7 +217,7 @@ public void LoadState(State state)
/// <param name="pipeline">The pipeline to use to process the logits and to select a token</param>
/// <param name="lastTokens">The tokens recently returned from the model</param>
/// <returns>The selected token</returns>
public llama_token Sample(ISamplingPipeline pipeline, ReadOnlySpan<llama_token> lastTokens)
public LLamaToken Sample(ISamplingPipeline pipeline, ReadOnlySpan<LLamaToken> lastTokens)
{
return pipeline.Sample(NativeHandle, NativeHandle.GetLogits(), lastTokens);
}
@@ -240,11 +238,11 @@ public llama_token Sample(ISamplingPipeline pipeline, ReadOnlySpan<llama_token>
/// <param name="grammar"></param>
/// <param name="minP"></param>
/// <returns></returns>
public llama_token Sample(LLamaTokenDataArray candidates, ref float? mirostat_mu, float temperature, MirostatType mirostat,
float mirostatTau, float mirostatEta, int topK, float topP, float tfsZ, float typicalP,
SafeLLamaGrammarHandle? grammar, float minP)
public LLamaToken Sample(LLamaTokenDataArray candidates, ref float? mirostat_mu, float temperature, MirostatType mirostat,
float mirostatTau, float mirostatEta, int topK, float topP, float tfsZ, float typicalP,
SafeLLamaGrammarHandle? grammar, float minP)
{
llama_token id;
LLamaToken id;

if (grammar != null)
{
@@ -301,7 +299,7 @@ public llama_token Sample(LLamaTokenDataArray candidates, ref float? mirostat_mu
/// <param name="alphaPresence"></param>
/// <param name="penalizeNL"></param>
/// <returns></returns>
public LLamaTokenDataArray ApplyPenalty(IEnumerable<llama_token> lastTokens, Dictionary<llama_token, float>? logitBias = null,
public LLamaTokenDataArray ApplyPenalty(IEnumerable<LLamaToken> lastTokens, Dictionary<LLamaToken, float>? logitBias = null,
int repeatLastTokensCount = 64, float repeatPenalty = 1.1f, float alphaFrequency = .0f, float alphaPresence = .0f,
bool penalizeNL = true)
{
Expand All @@ -311,12 +309,12 @@ public LLamaTokenDataArray ApplyPenalty(IEnumerable<llama_token> lastTokens, Dic
if (logitBias is not null)
{
foreach (var (key, value) in logitBias)
logits[key] += value;
logits[(int)key] += value;
}

// Save the newline logit value
var nl_token = NativeApi.llama_token_nl(NativeHandle.ModelHandle);
var nl_logit = logits[nl_token];
var nl_token = NativeApi.llama_token_nl(NativeHandle.ModelHandle);
var nl_logit = logits[(int)nl_token];

// Convert logits into token candidates
var candidates_p = LLamaTokenDataArray.Create(logits);
@@ -353,7 +351,7 @@ public LLamaTokenDataArray ApplyPenalty(IEnumerable<llama_token> lastTokens, Dic
/// <returns>The updated `pastTokensCount`.</returns>
/// <exception cref="RuntimeError"></exception>
[Obsolete("use llama_decode() instead")]
public int Eval(llama_token[] tokens, int pastTokensCount)
public int Eval(LLamaToken[] tokens, int pastTokensCount)
{
return Eval(tokens.AsSpan(), pastTokensCount);
}
@@ -366,7 +364,7 @@ public int Eval(llama_token[] tokens, int pastTokensCount)
/// <returns>The updated `pastTokensCount`.</returns>
/// <exception cref="RuntimeError"></exception>
[Obsolete("use llama_decode() instead")]
public int Eval(List<llama_token> tokens, int pastTokensCount)
public int Eval(List<LLamaToken> tokens, int pastTokensCount)
{
#if NET5_0_OR_GREATER
var span = CollectionsMarshal.AsSpan(tokens);
@@ -376,15 +374,15 @@ public int Eval(List<llama_token> tokens, int pastTokensCount)
// the list. Instead rent an array and copy the data into it. This avoids an allocation, but can't
// avoid the copying.

var rented = System.Buffers.ArrayPool<llama_token>.Shared.Rent(tokens.Count);
var rented = System.Buffers.ArrayPool<LLamaToken>.Shared.Rent(tokens.Count);
try
{
tokens.CopyTo(rented, 0);
return Eval(rented.AsSpan(0, tokens.Count), pastTokensCount);
}
finally
{
System.Buffers.ArrayPool<llama_token>.Shared.Return(rented);
System.Buffers.ArrayPool<LLamaToken>.Shared.Return(rented);
}
#endif
}
@@ -397,7 +395,7 @@ public int Eval(List<llama_token> tokens, int pastTokensCount)
/// <returns>The updated `pastTokensCount`.</returns>
/// <exception cref="RuntimeError"></exception>
[Obsolete("use llama_decode() instead")]
public int Eval(ReadOnlyMemory<llama_token> tokens, int pastTokensCount)
public int Eval(ReadOnlyMemory<LLamaToken> tokens, int pastTokensCount)
{
return Eval(tokens.Span, pastTokensCount);
}
@@ -410,7 +408,7 @@ public int Eval(ReadOnlyMemory<llama_token> tokens, int pastTokensCount)
/// <returns>The updated `pastTokensCount`.</returns>
/// <exception cref="RuntimeError"></exception>
[Obsolete("use llama_decode() instead")]
public int Eval(ReadOnlySpan<llama_token> tokens, int pastTokensCount)
public int Eval(ReadOnlySpan<LLamaToken> tokens, int pastTokensCount)
{
var total = tokens.Length;
for(var i = 0; i < total; i += (int)Params.BatchSize)
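Taken together, the LLamaContext changes above mean tokenization, penalty application and sampling now pass LLamaToken end to end. The sketch below traces that flow; it assumes context is an already-constructed LLamaContext, omits model loading and the decode/eval step, uses arbitrary sampling hyper-parameters, and the MirostatType.Disable member name is an assumption.

using System.Collections.Generic;
using LLama;
using LLama.Common;
using LLama.Native;

// Sketch only: `context` is an existing LLamaContext; evaluating the prompt is omitted.
LLamaToken[] prompt = context.Tokenize("The quick brown fox", addBos: true);

// ... run the prompt through the model so logits are available ...

var lastTokens = new List<LLamaToken>(prompt);
LLamaTokenDataArray candidates = context.ApplyPenalty(lastTokens, repeatPenalty: 1.1f);

float? mirostatMu = null;
LLamaToken next = context.Sample(
    candidates, ref mirostatMu,
    temperature: 0.8f, mirostat: MirostatType.Disable,   // enum member name assumed
    mirostatTau: 5.0f, mirostatEta: 0.1f,
    topK: 40, topP: 0.95f, tfsZ: 1.0f, typicalP: 1.0f,
    grammar: null, minP: 0.05f);
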
21 changes: 10 additions & 11 deletions LLama/LLamaExecutorBase.cs
@@ -14,7 +14,6 @@

namespace LLama
{
using llama_token = Int32;
/// <summary>
/// The base class for stateful LLama executors.
/// </summary>
@@ -47,19 +46,19 @@ public abstract class StatefulExecutorBase : ILLamaExecutor
/// <summary>
/// A container of the tokens to be processed and after processed.
/// </summary>
protected List<llama_token> _embeds = new(); // embd
protected List<LLamaToken> _embeds = new(); // embd
/// <summary>
/// A container for the tokens of input.
/// </summary>
protected List<llama_token> _embed_inps = new();
protected List<LLamaToken> _embed_inps = new();
/// <summary>
///
/// </summary>
protected List<llama_token> _session_tokens = new();
protected List<LLamaToken> _session_tokens = new();
/// <summary>
/// The last tokens generated by the model.
/// </summary>
protected FixedSizeQueue<llama_token> _last_n_tokens;
protected FixedSizeQueue<LLamaToken> _last_n_tokens;
/// <summary>
/// The context used by the executor.
/// </summary>
@@ -84,7 +83,7 @@ protected StatefulExecutorBase(LLamaContext context, ILogger? logger = null)
_pastTokensCount = 0;
_consumedTokensCount = 0;
_n_session_consumed = 0;
_last_n_tokens = new FixedSizeQueue<llama_token>(Context.ContextSize);
_last_n_tokens = new FixedSizeQueue<LLamaToken>(Context.ContextSize);
_decoder = new StreamingTokenDecoder(context);
}

@@ -105,7 +104,7 @@ public StatefulExecutorBase WithSessionFile(string filename)
if (File.Exists(filename))
{
_logger?.LogInformation($"[LLamaExecutor] Attempting to load saved session from {filename}");
var session_tokens = new llama_token[Context.ContextSize];
var session_tokens = new LLamaToken[Context.ContextSize];
if (!NativeApi.llama_load_session_file(Context.NativeHandle, _pathSession, session_tokens, (ulong)Context.ContextSize, out var n_token_count_out))
{
_logger?.LogError($"[LLamaExecutor] Failed to load session file {filename}");
@@ -361,16 +360,16 @@ public class ExecutorBaseState
public string? SessionFilePath { get; set; }

[JsonPropertyName("embd")]
public List<llama_token> Embeds { get; set; }
public List<LLamaToken> Embeds { get; set; }

Check warning on line 363 in LLama/LLamaExecutorBase.cs (GitHub Actions / Test, windows-release): Non-nullable property 'Embeds' must contain a non-null value when exiting constructor. Consider declaring the property as nullable.

[JsonPropertyName("embd_inps")]
public List<llama_token> EmbedInps { get; set; }
public List<LLamaToken> EmbedInps { get; set; }

Check warning on line 366 in LLama/LLamaExecutorBase.cs (GitHub Actions / Test, windows-release): Non-nullable property 'EmbedInps' must contain a non-null value when exiting constructor. Consider declaring the property as nullable.

[JsonPropertyName("session_tokens")]
public List<llama_token> SessionTokens { get; set; }
public List<LLamaToken> SessionTokens { get; set; }

Check warning on line 369 in LLama/LLamaExecutorBase.cs (GitHub Actions / Test, windows-release): Non-nullable property 'SessionTokens' must contain a non-null value when exiting constructor. Consider declaring the property as nullable.

[JsonPropertyName("last_n_tokens")]
public llama_token[] LastTokens { get; set; }
public LLamaToken[] LastTokens { get; set; }

Check warning on line 372 in LLama/LLamaExecutorBase.cs (GitHub Actions / Test, windows-release): Non-nullable property 'LastTokens' must contain a non-null value when exiting constructor. Consider declaring the property as nullable.

[JsonPropertyName("last_tokens_maximum_count")]
public int LastTokensCapacity { get; set; }
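The ExecutorBaseState properties above carry JsonPropertyName attributes, so the List<LLamaToken> and LLamaToken[] members are round-tripped through System.Text.Json. How LLamaToken itself serializes is not shown in this commit; one plausible approach, sketched purely as an assumption, is a converter that writes the underlying integer id so previously saved executor state stays readable.

// Hypothetical sketch, NOT part of this commit.
using System;
using System.Text.Json;
using System.Text.Json.Serialization;
using LLama.Native;

public sealed class LLamaTokenJsonConverter : JsonConverter<LLamaToken>
{
    public override LLamaToken Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options)
        => reader.GetInt32();                   // implicit int -> LLamaToken conversion assumed

    public override void Write(Utf8JsonWriter writer, LLamaToken value, JsonSerializerOptions options)
        => writer.WriteNumberValue((int)value); // explicit LLamaToken -> int conversion assumed
}
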
13 changes: 6 additions & 7 deletions LLama/LLamaInstructExecutor.cs
@@ -13,7 +13,6 @@

namespace LLama
{
using llama_token = Int32;
/// <summary>
/// The LLama executor for instruct mode.
/// </summary>
Expand All @@ -22,8 +21,8 @@ public class InstructExecutor
{
private bool _is_prompt_run = true;
private readonly string _instructionPrefix;
private llama_token[] _inp_pfx;
private llama_token[] _inp_sfx;
private LLamaToken[] _inp_pfx;
private LLamaToken[] _inp_sfx;

/// <summary>
///
@@ -75,7 +74,7 @@ public override Task LoadState(ExecutorBaseState data)
_is_prompt_run = state.IsPromptRun;
_consumedTokensCount = state.ConsumedTokensCount;
_embeds = state.Embeds;
_last_n_tokens = new FixedSizeQueue<llama_token>(state.LastTokensCapacity, state.LastTokens);
_last_n_tokens = new FixedSizeQueue<LLamaToken>(state.LastTokensCapacity, state.LastTokens);
_inp_pfx = state.InputPrefixTokens;
_inp_sfx = state.InputSuffixTokens;
_n_matching_session_tokens = state.MatchingSessionTokensCount;
@@ -210,7 +209,7 @@ protected override Task InferInternal(IInferenceParams inferenceParams, InferSta
SaveSessionFile(_pathSession);
}

llama_token id;
LLamaToken id;
if (inferenceParams.SamplingPipeline is not null)
{
id = inferenceParams.SamplingPipeline.Sample(Context.NativeHandle, Context.NativeHandle.GetLogits(), _last_n_tokens.ToArray());
@@ -266,12 +265,12 @@ public class InstructExecutorState : ExecutorBaseState
/// Instruction prefix tokens.
/// </summary>
[JsonPropertyName("inp_pfx")]
public llama_token[] InputPrefixTokens { get; set; }
public LLamaToken[] InputPrefixTokens { get; set; }
/// <summary>
/// Instruction suffix tokens.
/// </summary>
[JsonPropertyName("inp_sfx")]
public llama_token[] InputSuffixTokens { get; set; }
public LLamaToken[] InputSuffixTokens { get; set; }
}
}
}
7 changes: 3 additions & 4 deletions LLama/LLamaInteractExecutor.cs
@@ -13,14 +13,13 @@

namespace LLama
{
using llama_token = Int32;
/// <summary>
/// The LLama executor for interactive mode.
/// </summary>
public class InteractiveExecutor : StatefulExecutorBase
{
private bool _is_prompt_run = true;
private readonly llama_token _llama_token_newline;
private readonly LLamaToken _llama_token_newline;

/// <summary>
///
@@ -63,7 +62,7 @@ public override Task LoadState(ExecutorBaseState data)
_is_prompt_run = state.IsPromptRun;
_consumedTokensCount = state.ConsumedTokensCount;
_embeds = state.Embeds;
_last_n_tokens = new FixedSizeQueue<llama_token>(state.LastTokensCapacity, state.LastTokens);
_last_n_tokens = new FixedSizeQueue<LLamaToken>(state.LastTokensCapacity, state.LastTokens);
_n_matching_session_tokens = state.MatchingSessionTokensCount;
_pastTokensCount = state.PastTokensCount;
_pathSession = state.SessionFilePath;
@@ -189,7 +188,7 @@ protected override async Task InferInternal(IInferenceParams inferenceParams, In
SaveSessionFile(_pathSession);

Check warning on line 188 in LLama/LLamaInteractExecutor.cs (GitHub Actions / Test, linux-release and osx-release): Possible null reference argument for parameter 'filename' in 'void StatefulExecutorBase.SaveSessionFile(string filename)'.
}

llama_token id;
LLamaToken id;
if (inferenceParams.SamplingPipeline is not null)
{
id = inferenceParams.SamplingPipeline.Sample(Context.NativeHandle, Context.NativeHandle.GetLogits(), _last_n_tokens.ToArray());