Commit 50a8d3d: Merge branch 'SciSharp:master' into Development

SignalRT authored Sep 8, 2023
2 parents: c272b18 + d3b8ee9
Showing 14 changed files with 467 additions and 39 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/main.yml
@@ -52,4 +52,4 @@ jobs:
      - name: Build
        run: dotnet build LLamaSharp.sln -c ${{ matrix.config }} --no-restore
      - name: Test
-       run: dotnet test LLamaSharp.sln -c ${{ matrix.config }}
+       run: dotnet test LLamaSharp.sln -c ${{ matrix.config }} -l "console;verbosity=detailed"
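The added -l "console;verbosity=detailed" switch attaches the console test logger at detailed verbosity, so per-test output (for example, anything written through xunit's ITestOutputHelper) appears directly in the CI log rather than being swallowed.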
63 changes: 63 additions & 0 deletions LLama.Unittest/BeamTests.cs
@@ -0,0 +1,63 @@
using System.Text;
using LLama.Common;
using LLama.Native;
using Xunit.Abstractions;

namespace LLama.Unittest;

public sealed class BeamTests
: IDisposable
{
private readonly ITestOutputHelper _testOutputHelper;
private readonly ModelParams _params;
private readonly LLamaWeights _model;

public BeamTests(ITestOutputHelper testOutputHelper)
{
_testOutputHelper = testOutputHelper;
_params = new ModelParams(Constants.ModelPath)
{
ContextSize = 2048
};
_model = LLamaWeights.LoadFromFile(_params);
}

public void Dispose()
{
_model.Dispose();
}

[Fact(Skip = "Very very slow in CI")]
public void BasicBeam()
{
const int num_beams = 2;
const int n_predict = 3;

var context = _model.CreateContext(_params);

var result = new StringBuilder();

var initial_tokens = context.Tokenize("The cat sat on");
result.Append(context.DeTokenize(initial_tokens.ToArray()));
context.Eval(initial_tokens, 0);

NativeApi.llama_beam_search(context.NativeHandle, (data, state) =>
{
for (var i = 0; i < state.Beams.Length; i++)
{
ref var view = ref state.Beams[i];
var tokens = context.DeTokenize(view.Tokens.ToArray());
_testOutputHelper.WriteLine($"B{i} ({view.CumulativeProbability}) => '{tokens}'");
}
if (state.CommonPrefixLength > 0)
{
var view = state.Beams[0];
result.Append(context.DeTokenize(view.Tokens.Slice(0, (int)state.CommonPrefixLength).ToArray()));
}
}, IntPtr.Zero, num_beams, initial_tokens.Length, n_predict, Math.Max(1, Environment.ProcessorCount / 2));

_testOutputHelper.WriteLine($"Final: {result}");
}
}
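A note on the callback contract exercised above: llama_beam_search drives generation itself and invokes the callback once per step with the current set of beams. A nonzero state.CommonPrefixLength means that many leading tokens are already identical across every beam, so they are final and can be committed to the output, which is why the test appends that slice from Beams[0]. The final argument is the number of worker threads, here capped at half the logical processor count with a minimum of one.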
2 changes: 1 addition & 1 deletion LLama.Unittest/GrammarTest.cs
@@ -66,7 +66,7 @@ public void SampleWithTrivialGrammar()
Grammar = grammar,
};

var result = executor.Infer("Question: What is your favourite number?\nAnswer: ", inferenceParams).ToList();
var result = executor.Infer("Q. 7 + 12\nA. ", inferenceParams).ToList();

Assert.Equal("cat", result[0]);
}
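The trivial grammar in this test forces the sampler to emit "cat" regardless of the prompt, so the assertion is unchanged; the shorter prompt presumably just makes the test cheaper to run.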
75 changes: 75 additions & 0 deletions LLama.Unittest/TokenTests.cs
@@ -0,0 +1,75 @@
using System.Text;
using LLama.Common;
using LLama.Extensions;

namespace LLama.Unittest;

public sealed class TokenTests
: IDisposable
{
private readonly ModelParams _params;
private readonly LLamaWeights _model;

public TokenTests()
{
_params = new ModelParams(Constants.ModelPath)
{
ContextSize = 2048
};
_model = LLamaWeights.LoadFromFile(_params);
}

public void Dispose()
{
_model.Dispose();
}

[Fact]
public void TokensEndWith()
{
var tokens = _model.NativeHandle.Tokenize("The cat sat on the edge of the mat", false, Encoding.UTF8);

var result = tokens.TokensEndsWithAnyString(new[]
{
"a fish",
"the mat",
"this is an improbably long query to be using for this method"
}, _model.NativeHandle, Encoding.UTF8);
Assert.True(result);
}

[Fact]
public void TokensEndSubstring()
{
var tokens = _model.NativeHandle.Tokenize("The cat sat on the edge of the mat", false, Encoding.UTF8);

var result = tokens.TokensEndsWithAnyString(new[]
{
"at",
}, _model.NativeHandle, Encoding.UTF8);
Assert.True(result);
}

[Fact]
public void TokensNotEndWith()
{
var tokens = _model.NativeHandle.Tokenize("The cat sat on the edge of the mat", false, Encoding.UTF8);

var result = tokens.TokensEndsWithAnyString(new[]
{
"a fish",
"The cat sat on the edge of the ma",
"this is an improbably long query to be using for this method"
}, _model.NativeHandle, Encoding.UTF8);
Assert.False(result);
}

[Fact]
public void TokensNotEndWithNothing()
{
var tokens = _model.NativeHandle.Tokenize("The cat sat on the edge of the mat", false, Encoding.UTF8);

var result = tokens.TokensEndsWithAnyString(Array.Empty<string>(), _model.NativeHandle, Encoding.UTF8);
Assert.False(result);
}
}
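Note what the TokensEndSubstring and TokensNotEndWith cases pin down: matching is performed on the detokenized text rather than on token boundaries. The two-character query "at" matches even though it is only part of the final token, while a query missing just the last character of the text ("...of the ma") does not match.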
3 changes: 3 additions & 0 deletions LLama/Common/FixedSizeQueue.cs
@@ -15,6 +15,8 @@ public class FixedSizeQueue<T>
private readonly int _maxSize;
private readonly List<T> _storage;

+ internal IReadOnlyList<T> Items => _storage;

/// <summary>
/// Number of items in this queue
/// </summary>
@@ -57,6 +59,7 @@ public FixedSizeQueue(int size, IEnumerable<T> data)
if (_storage.Count > _maxSize)
throw new ArgumentException($"The max size set for the queue is {size}, but got {_storage.Count} initial values.");
}

/// <summary>
/// Replace every item in the queue with the given value
/// </summary>
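The new internal Items property exposes the queue's backing list as a read-only view; it is what lets LLamaInteractExecutor further down pass its recent-token window (_last_n_tokens.Items) straight into the new TokensEndsWithAnyString helper without copying.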
33 changes: 33 additions & 0 deletions LLama/Extensions/EncodingExtensions.cs
@@ -0,0 +1,33 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace LLama.Extensions;

internal static class EncodingExtensions
{
#if NETSTANDARD2_0
public static int GetChars(this Encoding encoding, ReadOnlySpan<byte> bytes, Span<char> output)
{
unsafe
{
fixed (byte* bytePtr = bytes)
fixed (char* charPtr = output)
{
return encoding.GetChars(bytePtr, bytes.Length, charPtr, output.Length);
}
}
}

public static int GetCharCount(this Encoding encoding, ReadOnlySpan<byte> bytes)
{
unsafe
{
fixed (byte* bytePtr = bytes)
{
return encoding.GetCharCount(bytePtr, bytes.Length);
}
}
}
#endif
}
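These shims only compile for netstandard2.0, where System.Text.Encoding lacks the span-based overloads that newer targets get from the BCL; on newer targets the built-in instance methods win overload resolution and the extensions are never used. A minimal usage sketch (a standalone illustration, not part of the diff; EncodingExtensions is internal, so on netstandard2.0 this shape is only valid inside the LLama assembly):

    using System;
    using System.Text;

    class EncodingShimDemo
    {
        static void Main()
        {
            // On netstandard2.0 these calls resolve to the extension methods
            // above; on net6.0 and later they resolve to the built-in span
            // overloads, so the behaviour is the same either way.
            ReadOnlySpan<byte> utf8 = Encoding.UTF8.GetBytes("héllo");
            Span<char> chars = new char[Encoding.UTF8.GetCharCount(utf8)];
            int written = Encoding.UTF8.GetChars(utf8, chars);
            Console.WriteLine(chars.Slice(0, written).ToString()); // prints: héllo
        }
    }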
78 changes: 78 additions & 0 deletions LLama/Extensions/IReadOnlyListExtensions.cs
@@ -1,5 +1,9 @@
  using System;
+ using System.Buffers;
+ using System.Collections;
  using System.Collections.Generic;
+ using System.Text;
+ using LLama.Native;

namespace LLama.Extensions
{
@@ -16,5 +20,79 @@ internal static class IReadOnlyListExtensions

return null;
}

/// <summary>
/// Check if the given set of tokens ends with any of the given strings
/// </summary>
/// <param name="tokens">Tokens to check</param>
/// <param name="queries">Strings to search for</param>
/// <param name="model">Model to use to convert tokens into bytes</param>
/// <param name="encoding">Encoding to use to convert bytes into characters</param>
/// <returns></returns>
internal static bool TokensEndsWithAnyString<TTokens, TQueries>(this TTokens tokens, TQueries? queries, SafeLlamaModelHandle model, Encoding encoding)
where TTokens : IReadOnlyList<int>
where TQueries : IReadOnlyList<string>
{
if (queries == null || queries.Count == 0 || tokens.Count == 0)
return false;

// Find the length of the longest query
var longest = 0;
foreach (var candidate in queries)
longest = Math.Max(longest, candidate.Length);

// Rent an array to detokenize into
var builderArray = ArrayPool<char>.Shared.Rent(longest);
try
{
// Convert as many tokens as possible into the builderArray
var characters = model.TokensToSpan(tokens, builderArray.AsSpan(0, longest), encoding);

// Check every query to see if it's present
foreach (var query in queries)
if (characters.EndsWith(query.AsSpan()))
return true;

return false;
}
finally
{
ArrayPool<char>.Shared.Return(builderArray);
}
}

internal static bool TokensEndsWithAnyString<TTokens>(this TTokens tokens, IList<string>? queries, SafeLlamaModelHandle model, Encoding encoding)
where TTokens : IReadOnlyList<int>
{
if (queries == null || queries.Count == 0 || tokens.Count == 0)
return false;

return tokens.TokensEndsWithAnyString(new ReadonlyWrapper<string>(queries), model, encoding);
}

private readonly struct ReadonlyWrapper<T>
: IReadOnlyList<T>
{
private readonly IList<T> _list;

public int Count => _list.Count;

public T this[int index] => _list[index];

public ReadonlyWrapper(IList<T> list)
{
_list = list;
}

public IEnumerator<T> GetEnumerator()
{
return _list.GetEnumerator();
}

IEnumerator IEnumerable.GetEnumerator()
{
return ((IEnumerable)_list).GetEnumerator();
}
}
}
}
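A note on the shape of this API: the generic TTokens and TQueries parameters (rather than plain interface-typed parameters) let the JIT specialize the call per concrete list type, and the second overload adapts an IList<string> through the private ReadonlyWrapper struct instead of copying it, so neither path allocates beyond the pooled character buffer. A hedged sketch of the call shape (the helper is internal, so this only compiles inside the LLama assembly; `weights` is assumed to be an already-loaded LLamaWeights):

    // Sketch: does the generated text end with a stop string?
    var handle = weights.NativeHandle; // SafeLlamaModelHandle (assumed source)
    var tokens = handle.Tokenize("User: hi\nBot: hello\nUser:", false, Encoding.UTF8);
    var stop = tokens.TokensEndsWithAnyString(new[] { "User:" }, handle, Encoding.UTF8);
    // Only the trailing characters, up to the length of the longest query,
    // are detokenized into a pooled buffer, so the check stays cheap.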
2 changes: 1 addition & 1 deletion LLama/LLamaInstructExecutor.cs
@@ -86,7 +86,7 @@ public override void LoadState(ExecutorBaseState data)
public override void SaveState(string filename)
{
var state = (InstructExecutorState)GetStateData();
- using (var fs = new FileStream(filename, FileMode.OpenOrCreate, FileAccess.Write))
+ using (var fs = new FileStream(filename, FileMode.Create, FileAccess.Write))
{
JsonSerializer.Serialize(fs, state);
}
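This is a genuine bug fix: FileMode.OpenOrCreate opens an existing file without truncating it, so saving a state that serializes shorter than the previous one left stale bytes at the end of the file and produced unparseable JSON; FileMode.Create truncates first. A minimal repro sketch of the failure mode (plain BCL, nothing LLamaSharp-specific):

    using System.IO;
    using System.Text.Json;

    File.WriteAllText("state.json", "{\"LongPropertyName\":123456789}");
    using (var fs = new FileStream("state.json", FileMode.OpenOrCreate, FileAccess.Write))
        JsonSerializer.Serialize(fs, new { A = 1 });
    // state.json now starts with {"A":1} but keeps the old file's tail,
    // so it is no longer valid JSON and a later load would fail.

The same one-line fix is applied to LLamaInteractExecutor below.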
23 changes: 4 additions & 19 deletions LLama/LLamaInteractExecutor.cs
@@ -8,6 +8,7 @@
using System.Text.Json;
using System.Text.Json.Serialization;
using System.Text;
+ using LLama.Extensions;

namespace LLama
{
@@ -72,7 +73,7 @@ public override void LoadState(ExecutorBaseState data)
public override void SaveState(string filename)
{
InteractiveExecutorState state = (InteractiveExecutorState)GetStateData();
- using(FileStream fs = new FileStream(filename, FileMode.OpenOrCreate, FileAccess.Write))
+ using(FileStream fs = new FileStream(filename, FileMode.Create, FileAccess.Write))
{
JsonSerializer.Serialize(fs, state);
}
@@ -128,27 +129,11 @@ protected override bool PostProcess(IInferenceParams inferenceParams, InferState
extraOutputs = null;
if (_embed_inps.Count <= _consumedTokensCount)
{
- if (args.Antiprompts is not null && args.Antiprompts.Count > 0)
- {
-     var last_output_builder = new StringBuilder();
-     foreach (var token in _last_n_tokens)
-         Context.NativeHandle.TokenToString(token, Context.Encoding, last_output_builder);
-     var last_output = last_output_builder.ToString();
-
-     foreach (var antiprompt in args.Antiprompts)
-     {
-         if (last_output.EndsWith(antiprompt))
-         {
-             args.WaitForInput = true;
-             break;
-         }
-     }
- }
+ if (_last_n_tokens.Items.TokensEndsWithAnyString(args.Antiprompts, Context.NativeHandle.ModelHandle, Context.Encoding))
+     args.WaitForInput = true;

if (_pastTokensCount > 0 && args.WaitForInput)
{
return true;
}
}

if (_embeds.Count > 0 && _embeds.Last() == NativeApi.llama_token_eos(Context.NativeHandle))
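The refactor here replaces the per-token StringBuilder scan, which rebuilt the entire recent-output string on every generation step, with the span-based TokensEndsWithAnyString helper added above, which detokenizes only enough trailing characters to cover the longest antiprompt into a pooled buffer. LLamaStatelessExecutor.EndsWithAntiprompt below gets the same treatment.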
19 changes: 2 additions & 17 deletions LLama/LLamaStatelessExecutor.cs
@@ -4,8 +4,8 @@
using System.Collections.Generic;
using System.Linq;
using System.Runtime.CompilerServices;
- using System.Text;
  using System.Threading;
+ using LLama.Extensions;

namespace LLama
{
@@ -138,22 +138,7 @@ public IEnumerable<string> Infer(string text, IInferenceParams? inferenceParams
/// <returns></returns>
private bool EndsWithAntiprompt(IReadOnlyList<llama_token> tokens, IReadOnlyList<string> antiprompts)
{
- if (antiprompts.Count == 0 || tokens.Count == 0)
-     return false;
-
- var builder = new StringBuilder();
- foreach (var token in tokens)
-     builder.Append(Context.TokenToString(token));
-
- var last_output = builder.ToString();
-
- foreach (var antiprompt in antiprompts)
- {
-     if (last_output.EndsWith(antiprompt))
-         return true;
- }
-
- return false;
+ return tokens.TokensEndsWithAnyString(antiprompts, Context.NativeHandle.ModelHandle, Context.Encoding);
}

/// <inheritdoc />