diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml
index 0cafa6d8e..a81871369 100644
--- a/.github/workflows/main.yml
+++ b/.github/workflows/main.yml
@@ -52,4 +52,4 @@ jobs:
       - name: Build
         run: dotnet build LLamaSharp.sln -c ${{ matrix.config }} --no-restore
       - name: Test
-        run: dotnet test LLamaSharp.sln -c ${{ matrix.config }}
+        run: dotnet test LLamaSharp.sln -c ${{ matrix.config }} -l "console;verbosity=detailed"
diff --git a/LLama.Unittest/BeamTests.cs b/LLama.Unittest/BeamTests.cs
new file mode 100644
index 000000000..f8d5cf01c
--- /dev/null
+++ b/LLama.Unittest/BeamTests.cs
@@ -0,0 +1,63 @@
+using System.Text;
+using LLama.Common;
+using LLama.Native;
+using Xunit.Abstractions;
+
+namespace LLama.Unittest;
+
+public sealed class BeamTests
+    : IDisposable
+{
+    private readonly ITestOutputHelper _testOutputHelper;
+    private readonly ModelParams _params;
+    private readonly LLamaWeights _model;
+
+    public BeamTests(ITestOutputHelper testOutputHelper)
+    {
+        _testOutputHelper = testOutputHelper;
+        _params = new ModelParams(Constants.ModelPath)
+        {
+            ContextSize = 2048
+        };
+        _model = LLamaWeights.LoadFromFile(_params);
+    }
+
+    public void Dispose()
+    {
+        _model.Dispose();
+    }
+
+    [Fact(Skip = "Very very slow in CI")]
+    public void BasicBeam()
+    {
+        const int num_beams = 2;
+        const int n_predict = 3;
+
+        var context = _model.CreateContext(_params);
+
+        var result = new StringBuilder();
+
+        var initial_tokens = context.Tokenize("The cat sat on");
+        result.Append(context.DeTokenize(initial_tokens.ToArray()));
+        context.Eval(initial_tokens, 0);
+
+        NativeApi.llama_beam_search(context.NativeHandle, (data, state) =>
+        {
+            for (var i = 0; i < state.Beams.Length; i++)
+            {
+                ref var view = ref state.Beams[i];
+                var tokens = context.DeTokenize(view.Tokens.ToArray());
+                _testOutputHelper.WriteLine($"B{i} ({view.CumulativeProbability}) => '{tokens}'");
+            }
+
+            if (state.CommonPrefixLength > 0)
+            {
+                var view = state.Beams[0];
+                result.Append(context.DeTokenize(view.Tokens.Slice(0, (int)state.CommonPrefixLength).ToArray()));
+            }
+
+        }, IntPtr.Zero, num_beams, initial_tokens.Length, n_predict, Math.Max(1, Environment.ProcessorCount / 2));
+
+        _testOutputHelper.WriteLine($"Final: {result}");
+    }
+}
\ No newline at end of file
diff --git a/LLama.Unittest/GrammarTest.cs b/LLama.Unittest/GrammarTest.cs
index 152ede935..b86a0f40d 100644
--- a/LLama.Unittest/GrammarTest.cs
+++ b/LLama.Unittest/GrammarTest.cs
@@ -66,7 +66,7 @@ public void SampleWithTrivialGrammar()
             Grammar = grammar,
         };
 
-        var result = executor.Infer("Question: What is your favourite number?\nAnswer: ", inferenceParams).ToList();
+        var result = executor.Infer("Q. 7 + 12\nA. ", inferenceParams).ToList();
 
         Assert.Equal("cat", result[0]);
    }
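The skipped `BasicBeam` test above doubles as the reference usage pattern for the new beam search API. For readers outside the test project, here is a minimal sketch of the same flow as a console program; the model path and prompt are placeholders, and the beam count, prediction length, and thread count are arbitrary.

```csharp
using System;
using System.Linq;
using System.Text;
using LLama;
using LLama.Common;
using LLama.Native;

// Minimal sketch of the BasicBeam flow outside xunit.
// "path/to/model.gguf" is a placeholder, not a real artifact.
var @params = new ModelParams("path/to/model.gguf") { ContextSize = 2048 };
using var model = LLamaWeights.LoadFromFile(@params);
using var context = model.CreateContext(@params);

var result = new StringBuilder();
var initial = context.Tokenize("The cat sat on");
result.Append(context.DeTokenize(initial.ToArray()));
context.Eval(initial, 0);

NativeApi.llama_beam_search(context.NativeHandle, (data, state) =>
{
    // Tokens shared by every beam are final: accumulate them now, because
    // they are shifted out of all beams before the next callback.
    if (state.CommonPrefixLength > 0)
    {
        var prefix = state.Beams[0].Tokens.Slice(0, (int)state.CommonPrefixLength);
        result.Append(context.DeTokenize(prefix.ToArray()));
    }
}, IntPtr.Zero, 2, initial.Length, 16, Math.Max(1, Environment.ProcessorCount / 2));

Console.WriteLine(result);
```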
", inferenceParams).ToList(); Assert.Equal("cat", result[0]); } diff --git a/LLama.Unittest/TokenTests.cs b/LLama.Unittest/TokenTests.cs new file mode 100644 index 000000000..fd17727b3 --- /dev/null +++ b/LLama.Unittest/TokenTests.cs @@ -0,0 +1,75 @@ +using System.Text; +using LLama.Common; +using LLama.Extensions; + +namespace LLama.Unittest; + +public sealed class TokenTests + : IDisposable +{ + private readonly ModelParams _params; + private readonly LLamaWeights _model; + + public TokenTests() + { + _params = new ModelParams(Constants.ModelPath) + { + ContextSize = 2048 + }; + _model = LLamaWeights.LoadFromFile(_params); + } + + public void Dispose() + { + _model.Dispose(); + } + + [Fact] + public void TokensEndWith() + { + var tokens = _model.NativeHandle.Tokenize("The cat sat on the edge of the mat", false, Encoding.UTF8); + + var result = tokens.TokensEndsWithAnyString(new[] + { + "a fish", + "the mat", + "this is an improbably long query to be using for this method" + }, _model.NativeHandle, Encoding.UTF8); + Assert.True(result); + } + + [Fact] + public void TokensEndSubstring() + { + var tokens = _model.NativeHandle.Tokenize("The cat sat on the edge of the mat", false, Encoding.UTF8); + + var result = tokens.TokensEndsWithAnyString(new[] + { + "at", + }, _model.NativeHandle, Encoding.UTF8); + Assert.True(result); + } + + [Fact] + public void TokensNotEndWith() + { + var tokens = _model.NativeHandle.Tokenize("The cat sat on the edge of the mat", false, Encoding.UTF8); + + var result = tokens.TokensEndsWithAnyString(new[] + { + "a fish", + "The cat sat on the edge of the ma", + "this is an improbably long query to be using for this method" + }, _model.NativeHandle, Encoding.UTF8); + Assert.False(result); + } + + [Fact] + public void TokensNotEndWithNothing() + { + var tokens = _model.NativeHandle.Tokenize("The cat sat on the edge of the mat", false, Encoding.UTF8); + + var result = tokens.TokensEndsWithAnyString(Array.Empty(), _model.NativeHandle, Encoding.UTF8); + Assert.False(result); + } +} \ No newline at end of file diff --git a/LLama/Common/FixedSizeQueue.cs b/LLama/Common/FixedSizeQueue.cs index 2c331e5a0..97a4d6ee6 100644 --- a/LLama/Common/FixedSizeQueue.cs +++ b/LLama/Common/FixedSizeQueue.cs @@ -15,6 +15,8 @@ public class FixedSizeQueue private readonly int _maxSize; private readonly List _storage; + internal IReadOnlyList Items => _storage; + /// /// Number of items in this queue /// @@ -57,6 +59,7 @@ public FixedSizeQueue(int size, IEnumerable data) if (_storage.Count > _maxSize) throw new ArgumentException($"The max size set for the quene is {size}, but got {_storage.Count} initial values."); } + /// /// Replace every item in the queue with the given value /// diff --git a/LLama/Extensions/EncodingExtensions.cs b/LLama/Extensions/EncodingExtensions.cs new file mode 100644 index 000000000..29073fea5 --- /dev/null +++ b/LLama/Extensions/EncodingExtensions.cs @@ -0,0 +1,33 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace LLama.Extensions; + +internal static class EncodingExtensions +{ +#if NETSTANDARD2_0 + public static int GetChars(this Encoding encoding, ReadOnlySpan bytes, Span output) + { + unsafe + { + fixed (byte* bytePtr = bytes) + fixed (char* charPtr = output) + { + return encoding.GetChars(bytePtr, bytes.Length, charPtr, output.Length); + } + } + } + + public static int GetCharCount(this Encoding encoding, ReadOnlySpan bytes) + { + unsafe + { + fixed (byte* bytePtr = bytes) + { + return encoding.GetCharCount(bytePtr, 
diff --git a/LLama/Extensions/IReadOnlyListExtensions.cs b/LLama/Extensions/IReadOnlyListExtensions.cs
index 51b365be0..b07d90cfa 100644
--- a/LLama/Extensions/IReadOnlyListExtensions.cs
+++ b/LLama/Extensions/IReadOnlyListExtensions.cs
@@ -1,5 +1,9 @@
 using System;
+using System.Buffers;
+using System.Collections;
 using System.Collections.Generic;
+using System.Text;
+using LLama.Native;
 
 namespace LLama.Extensions
 {
@@ -16,5 +20,79 @@ internal static class IReadOnlyListExtensions
 
             return null;
         }
+
+        /// <summary>
+        /// Check if the given set of tokens ends with any of the given strings
+        /// </summary>
+        /// <param name="tokens">Tokens to check</param>
+        /// <param name="queries">Strings to search for</param>
+        /// <param name="model">Model to use to convert tokens into bytes</param>
+        /// <param name="encoding">Encoding to use to convert bytes into characters</param>
+        /// <returns></returns>
+        internal static bool TokensEndsWithAnyString<TTokens, TQueries>(this TTokens tokens, TQueries? queries, SafeLlamaModelHandle model, Encoding encoding)
+            where TTokens : IReadOnlyList<int>
+            where TQueries : IReadOnlyList<string>
+        {
+            if (queries == null || queries.Count == 0 || tokens.Count == 0)
+                return false;
+
+            // Find the length of the longest query
+            var longest = 0;
+            foreach (var candidate in queries)
+                longest = Math.Max(longest, candidate.Length);
+
+            // Rent an array to detokenize into
+            var builderArray = ArrayPool<char>.Shared.Rent(longest);
+            try
+            {
+                // Convert as many tokens as possible into the builderArray
+                var characters = model.TokensToSpan(tokens, builderArray.AsSpan(0, longest), encoding);
+
+                // Check every query to see if it's present
+                foreach (var query in queries)
+                    if (characters.EndsWith(query.AsSpan()))
+                        return true;
+
+                return false;
+            }
+            finally
+            {
+                ArrayPool<char>.Shared.Return(builderArray);
+            }
+        }
+
+        internal static bool TokensEndsWithAnyString<TTokens>(this TTokens tokens, IList<string>? queries, SafeLlamaModelHandle model, Encoding encoding)
+            where TTokens : IReadOnlyList<int>
+        {
+            if (queries == null || queries.Count == 0 || tokens.Count == 0)
+                return false;
+
+            return tokens.TokensEndsWithAnyString(new ReadonlyWrapper<string>(queries), model, encoding);
+        }
+
+        private readonly struct ReadonlyWrapper<T>
+            : IReadOnlyList<T>
+        {
+            private readonly IList<T> _list;
+
+            public int Count => _list.Count;
+
+            public T this[int index] => _list[index];
+
+            public ReadonlyWrapper(IList<T> list)
+            {
+                _list = list;
+            }
+
+            public IEnumerator<T> GetEnumerator()
+            {
+                return _list.GetEnumerator();
+            }
+
+            IEnumerator IEnumerable.GetEnumerator()
+            {
+                return ((IEnumerable)_list).GetEnumerator();
+            }
+        }
     }
 }
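`TokensEndsWithAnyString` is the new allocation-light antiprompt check: it detokenizes only as many trailing tokens as the longest query needs, into a pooled buffer, instead of building the whole output string. It is `internal`, so the sketch below only compiles inside the `LLama` assembly (or a project granted access); the model path is a placeholder.

```csharp
using System.Text;
using LLama;
using LLama.Common;
using LLama.Extensions;

// Hedged sketch of the antiprompt check the executors now use.
var @params = new ModelParams("path/to/model.gguf");
using var model = LLamaWeights.LoadFromFile(@params);

var tokens = model.NativeHandle.Tokenize("User: hi\nAssistant:", false, Encoding.UTF8);

// True: the detokenized tail of the token stream ends with one of the queries.
var hit = tokens.TokensEndsWithAnyString(new[] { "Assistant:", "User:" }, model.NativeHandle, Encoding.UTF8);
```

The `ReadonlyWrapper<T>` struct exists so `IList<string>` callers can route into the generic overload without a heap-allocated adapter: as a struct type argument it is never boxed.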
diff --git a/LLama/LLamaInstructExecutor.cs b/LLama/LLamaInstructExecutor.cs
index a7d53cc81..712c2c239 100644
--- a/LLama/LLamaInstructExecutor.cs
+++ b/LLama/LLamaInstructExecutor.cs
@@ -86,7 +86,7 @@ public override void LoadState(ExecutorBaseState data)
         public override void SaveState(string filename)
         {
             var state = (InstructExecutorState)GetStateData();
-            using (var fs = new FileStream(filename, FileMode.OpenOrCreate, FileAccess.Write))
+            using (var fs = new FileStream(filename, FileMode.Create, FileAccess.Write))
             {
                 JsonSerializer.Serialize(fs, state);
             }
diff --git a/LLama/LLamaInteractExecutor.cs b/LLama/LLamaInteractExecutor.cs
index 4f29b9984..6b4c21047 100644
--- a/LLama/LLamaInteractExecutor.cs
+++ b/LLama/LLamaInteractExecutor.cs
@@ -8,6 +8,7 @@
 using System.Text.Json;
 using System.Text.Json.Serialization;
 using System.Text;
+using LLama.Extensions;
 
 namespace LLama
 {
@@ -72,7 +73,7 @@ public override void LoadState(ExecutorBaseState data)
         public override void SaveState(string filename)
         {
             InteractiveExecutorState state = (InteractiveExecutorState)GetStateData();
-            using(FileStream fs = new FileStream(filename, FileMode.OpenOrCreate, FileAccess.Write))
+            using(FileStream fs = new FileStream(filename, FileMode.Create, FileAccess.Write))
             {
                 JsonSerializer.Serialize(fs, state);
             }
@@ -128,27 +129,11 @@ protected override bool PostProcess(IInferenceParams inferenceParams, InferStateArgs args, out IEnumerable<string>? extraOutputs)
             extraOutputs = null;
             if (_embed_inps.Count <= _consumedTokensCount)
             {
-                if (args.Antiprompts is not null && args.Antiprompts.Count > 0)
-                {
-                    var last_output_builder = new StringBuilder();
-                    foreach (var token in _last_n_tokens)
-                        Context.NativeHandle.TokenToString(token, Context.Encoding, last_output_builder);
-                    var last_output = last_output_builder.ToString();
-
-                    foreach (var antiprompt in args.Antiprompts)
-                    {
-                        if (last_output.EndsWith(antiprompt))
-                        {
-                            args.WaitForInput = true;
-                            break;
-                        }
-                    }
-                }
+                if (_last_n_tokens.Items.TokensEndsWithAnyString(args.Antiprompts, Context.NativeHandle.ModelHandle, Context.Encoding))
+                    args.WaitForInput = true;
 
                 if (_pastTokensCount > 0 && args.WaitForInput)
-                {
                     return true;
-                }
             }
 
             if (_embeds.Count > 0 && _embeds.Last() == NativeApi.llama_token_eos(Context.NativeHandle))
diff --git a/LLama/LLamaStatelessExecutor.cs b/LLama/LLamaStatelessExecutor.cs
index d3f0c0e21..5c4960378 100644
--- a/LLama/LLamaStatelessExecutor.cs
+++ b/LLama/LLamaStatelessExecutor.cs
@@ -4,8 +4,8 @@
 using System.Collections.Generic;
 using System.Linq;
 using System.Runtime.CompilerServices;
-using System.Text;
 using System.Threading;
+using LLama.Extensions;
 
 namespace LLama
 {
@@ -138,22 +138,7 @@ public IEnumerable<string> Infer(string text, IInferenceParams? inferenceParams = null, CancellationToken cancellationToken = default)
         /// </summary>
         private bool EndsWithAntiprompt(IReadOnlyList<int> tokens, IReadOnlyList<string> antiprompts)
         {
-            if (antiprompts.Count == 0 || tokens.Count == 0)
-                return false;
-
-            var builder = new StringBuilder();
-            foreach (var token in tokens)
-                builder.Append(Context.TokenToString(token));
-
-            var last_output = builder.ToString();
-
-            foreach (var antiprompt in antiprompts)
-            {
-                if (last_output.EndsWith(antiprompt))
-                    return true;
-            }
-
-            return false;
+            return tokens.TokensEndsWithAnyString(antiprompts, Context.NativeHandle.ModelHandle, Context.Encoding);
         }
 
         /// <summary>
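The `FileMode.OpenOrCreate` to `FileMode.Create` change in both executors fixes a real corruption bug: `OpenOrCreate` does not truncate, so saving a state file shorter than the previous one left stale JSON at the tail. A self-contained repro of the old failure mode (file name is illustrative):

```csharp
using System;
using System.IO;
using System.Text;

// Write a long file, then overwrite it with a shorter payload the way
// SaveState used to. OpenOrCreate opens at offset 0 but does not truncate.
File.WriteAllText("state.json", "{\"name\":\"a-much-longer-value\"}");

using (var fs = new FileStream("state.json", FileMode.OpenOrCreate, FileAccess.Write))
{
    var bytes = Encoding.UTF8.GetBytes("{\"name\":\"x\"}");
    fs.Write(bytes, 0, bytes.Length);
}

// Stale tail survives: {"name":"x"}uch-longer-value"} -- not valid JSON.
// FileMode.Create truncates the file first, so this cannot happen.
Console.WriteLine(File.ReadAllText("state.json"));
```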
diff --git a/LLama/Native/LLamaBeamView.cs b/LLama/Native/LLamaBeamView.cs
new file mode 100644
index 000000000..e6a6c39f5
--- /dev/null
+++ b/LLama/Native/LLamaBeamView.cs
@@ -0,0 +1,42 @@
+using System;
+using System.Runtime.InteropServices;
+
+namespace LLama.Native;
+
+using llama_token = Int32;
+
+/// <summary>
+/// Information about a single beam in a beam search
+/// </summary>
+[StructLayout(LayoutKind.Sequential)]
+public struct LLamaBeamView
+{
+    private readonly unsafe llama_token* tokens;
+    private readonly nint n_tokens;
+
+    /// <summary>
+    /// Cumulative beam probability (renormalized relative to all beams)
+    /// </summary>
+    public readonly float CumulativeProbability;
+
+    /// <summary>
+    /// Callback should set this to true when a beam is at end-of-beam.
+    /// </summary>
+    public bool EndOfBeam;
+
+    /// <summary>
+    /// Tokens in this beam
+    /// </summary>
+    public readonly Span<llama_token> Tokens
+    {
+        get
+        {
+            unsafe
+            {
+                if (n_tokens > int.MaxValue)
+                    throw new InvalidOperationException("More than 2147483647 tokens is not supported");
+                return new Span<llama_token>(tokens, (int)n_tokens);
+            }
+        }
+    }
+}
\ No newline at end of file
diff --git a/LLama/Native/LLamaBeamsState.cs b/LLama/Native/LLamaBeamsState.cs
new file mode 100644
index 000000000..6f0a447d7
--- /dev/null
+++ b/LLama/Native/LLamaBeamsState.cs
@@ -0,0 +1,49 @@
+using System;
+using System.Runtime.InteropServices;
+
+namespace LLama.Native;
+
+/// <summary>
+/// Passed to beam_search_callback function.
+/// Whenever 0 < common_prefix_length, this number of tokens should be copied from any of the beams
+/// (e.g. beams[0]) as they will be removed (shifted) from all beams in all subsequent callbacks.
+/// </summary>
+[StructLayout(LayoutKind.Sequential)]
+public readonly struct LLamaBeamsState
+{
+    /// <summary>
+    /// The state of each individual beam
+    /// </summary>
+    private readonly unsafe LLamaBeamView* beam_views;
+
+    /// <summary>
+    /// Number of elements in beam_views
+    /// </summary>
+    private readonly nint n_beams;
+
+    /// <summary>
+    /// Current max length of prefix tokens shared by all beams.
+    /// </summary>
+    public readonly ulong CommonPrefixLength;
+
+    /// <summary>
+    /// True iff this is the last callback invocation.
+    /// </summary>
+    public readonly bool LastCall;
+
+    /// <summary>
+    /// The current state of each beam
+    /// </summary>
+    public Span<LLamaBeamView> Beams
+    {
+        get
+        {
+            unsafe
+            {
+                if (n_beams > int.MaxValue)
+                    throw new InvalidOperationException("More than 2147483647 beams is not supported");
+                return new Span<LLamaBeamView>(beam_views, (int)n_beams);
+            }
+        }
+    }
+}
\ No newline at end of file
diff --git a/LLama/Native/NativeApi.BeamSearch.cs b/LLama/Native/NativeApi.BeamSearch.cs
new file mode 100644
index 000000000..1049dbe3a
--- /dev/null
+++ b/LLama/Native/NativeApi.BeamSearch.cs
@@ -0,0 +1,25 @@
+using System;
+using System.Runtime.InteropServices;
+
+namespace LLama.Native;
+
+public partial class NativeApi
+{
+    /// <summary>
+    /// Type of pointer to the beam_search_callback function.
+    /// </summary>
+    /// <param name="callback_data">callback_data is any custom data passed to llama_beam_search, that is subsequently passed back to beam_search_callback</param>
+    /// <param name="state"></param>
+    public delegate void LLamaBeamSearchCallback(IntPtr callback_data, LLamaBeamsState state);
+
+    /// <summary>Deterministically returns entire sentence constructed by a beam search.</summary>
+    /// <param name="ctx">Pointer to the llama_context.</param>
+    /// <param name="callback">Invoked for each iteration of the beam_search loop, passing in beams_state.</param>
+    /// <param name="callback_data">A pointer that is simply passed back to callback.</param>
+    /// <param name="n_beams">Number of beams to use.</param>
+    /// <param name="n_past">Number of tokens already evaluated.</param>
+    /// <param name="n_predict">Maximum number of tokens to predict. EOS may occur earlier.</param>
+    /// <param name="n_threads">Number of threads.</param>
+    [DllImport(libraryName, EntryPoint = "llama_beam_search", CallingConvention = CallingConvention.Cdecl)]
+    public static extern void llama_beam_search(SafeLLamaContextHandle ctx, LLamaBeamSearchCallback callback, IntPtr callback_data, ulong n_beams, int n_past, int n_predict, int n_threads);
+}
\ No newline at end of file
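`LLamaBeamView.EndOfBeam` is deliberately left writable: `state.Beams` is a `Span` over the native beam array, so writes made inside the callback are visible to llama.cpp. A hedged sketch of a callback that finishes beams early; the hard-coded EOS token id is illustrative (in practice it should come from `NativeApi.llama_token_eos`).

```csharp
using LLama.Native;

// Sketch only: flags any beam whose last token is EOS so that, per the
// struct's doc comment, llama.cpp treats it as at end-of-beam.
const int eosToken = 2; // illustrative value, query it from the model in real code

NativeApi.LLamaBeamSearchCallback callback = (callbackData, state) =>
{
    for (var i = 0; i < state.Beams.Length; i++)
    {
        // Beams is a span over native memory, so this write is seen by llama.cpp.
        ref var beam = ref state.Beams[i];

        var tokens = beam.Tokens;
        if (tokens.Length > 0 && tokens[tokens.Length - 1] == eosToken)
            beam.EndOfBeam = true;
    }
};
```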
diff --git a/LLama/Native/SafeLlamaModelHandle.cs b/LLama/Native/SafeLlamaModelHandle.cs
index 7074fddbc..615889d55 100644
--- a/LLama/Native/SafeLlamaModelHandle.cs
+++ b/LLama/Native/SafeLlamaModelHandle.cs
@@ -1,7 +1,10 @@
 using System;
+using System.Buffers;
+using System.Collections.Generic;
 using System.Diagnostics;
 using System.Text;
 using LLama.Exceptions;
+using LLama.Extensions;
 
 namespace LLama.Native
 {
@@ -160,6 +163,93 @@ public void TokenToString(int llama_token, Encoding encoding, StringBuilder dest)
             }
         }
 
+        /// <summary>
+        /// Convert a sequence of tokens into characters.
+        /// </summary>
+        /// <param name="tokens"></param>
+        /// <param name="dest"></param>
+        /// <param name="encoding"></param>
+        /// <returns>The section of the span which has valid data in it.
+        /// If there was insufficient space in the output span this will be
+        /// filled with as many characters as possible, starting from the _last_ token.
+        /// </returns>
+        internal Span<char> TokensToSpan(IReadOnlyList<int> tokens, Span<char> dest, Encoding encoding)
+        {
+            // Rent an array to detokenize into
+            var tokenBytesArr = ArrayPool<byte>.Shared.Rent(16);
+            var tokenCharsArr = ArrayPool<char>.Shared.Rent(16);
+            try
+            {
+                var totalCharacters = 0;
+                var unused = dest;
+
+                for (var i = tokens.Count - 1; i >= 0; i--)
+                {
+                    var token = tokens[i];
+
+                    // Get bytes for this token
+                    var tokenBytes = TokenToBytes(ref tokenBytesArr, token, this);
+
+                    // Get chars for this token
+                    var tokenChars = BytesToChars(ref tokenCharsArr, tokenBytes, encoding);
+
+                    // Trim down number of characters if there are too many
+                    if (tokenChars.Length > unused.Length)
+                        tokenChars = tokenChars.Slice(tokenChars.Length - unused.Length, unused.Length);
+
+                    // Copy characters
+                    tokenChars.CopyTo(unused.Slice(unused.Length - tokenChars.Length, tokenChars.Length));
+                    unused = unused.Slice(0, unused.Length - tokenChars.Length);
+                    totalCharacters += tokenChars.Length;
+
+                    // Break out if we've run out of space
+                    if (unused.Length == 0)
+                        break;
+                }
+
+                return dest.Slice(dest.Length - totalCharacters, totalCharacters);
+            }
+            finally
+            {
+                ArrayPool<byte>.Shared.Return(tokenBytesArr);
+                ArrayPool<char>.Shared.Return(tokenCharsArr);
+            }
+
+            // vvv Local Functions vvv
+
+            static Span<byte> TokenToBytes(ref byte[] bytes, int token, SafeLlamaModelHandle model)
+            {
+                // Try to get bytes, if that fails we know the length
+                var l = model.TokenToSpan(token, bytes);
+
+                // Array was too small, get a bigger one
+                if (l < 0)
+                {
+                    ArrayPool<byte>.Shared.Return(bytes);
+                    bytes = ArrayPool<byte>.Shared.Rent(-l * 2);
+
+                    // Get bytes, this time it can't fail
+                    l = model.TokenToSpan(token, bytes);
+                }
+
+                Debug.Assert(l >= 0);
+                return new Span<byte>(bytes, 0, l);
+            }
+
+            static Span<char> BytesToChars(ref char[] chars, ReadOnlySpan<byte> bytes, Encoding encoding)
+            {
+                var count = encoding.GetCharCount(bytes);
+                if (count > chars.Length)
+                {
+                    ArrayPool<char>.Shared.Return(chars);
+                    chars = ArrayPool<char>.Shared.Rent(count * 2);
+                }
+
+                encoding.GetChars(bytes, chars);
+                return chars.AsSpan(0, count);
+            }
+        }
+
         /// <summary>
         /// Convert a string of text into tokens
         /// </summary>
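`TokensToSpan` fills the destination from the last token backwards, so when the buffer is too small it is the start of the text that is sacrificed, never the end. That is exactly the right trade-off for antiprompt matching, which only inspects the tail. An illustration of the contract (the method is `internal`, so this compiles only with access to the `LLama` assembly; the model path is a placeholder):

```csharp
using System;
using System.Text;
using LLama;
using LLama.Common;

var @params = new ModelParams("path/to/model.gguf");
using var model = LLamaWeights.LoadFromFile(@params);

var tokens = model.NativeHandle.Tokenize("The cat sat on the edge of the mat", false, Encoding.UTF8);

// The destination holds only 7 chars, far fewer than the full text, so the
// returned slice contains the *trailing* characters of the detokenized text.
Span<char> dest = stackalloc char[7];
var chars = model.NativeHandle.TokensToSpan(tokens, dest, Encoding.UTF8);

Console.WriteLine(chars.ToString()); // "the mat"
```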