forked from SciSharp/LLamaSharp
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge branch 'SciSharp:master' into Development
- Loading branch information
Showing
14 changed files
with
467 additions
and
39 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
using System.Text; | ||
using LLama.Common; | ||
using LLama.Native; | ||
using Xunit.Abstractions; | ||
|
||
namespace LLama.Unittest; | ||
|
||
public sealed class BeamTests | ||
: IDisposable | ||
{ | ||
private readonly ITestOutputHelper _testOutputHelper; | ||
private readonly ModelParams _params; | ||
private readonly LLamaWeights _model; | ||
|
||
public BeamTests(ITestOutputHelper testOutputHelper) | ||
{ | ||
_testOutputHelper = testOutputHelper; | ||
_params = new ModelParams(Constants.ModelPath) | ||
{ | ||
ContextSize = 2048 | ||
}; | ||
_model = LLamaWeights.LoadFromFile(_params); | ||
} | ||
|
||
public void Dispose() | ||
{ | ||
_model.Dispose(); | ||
} | ||
|
||
[Fact(Skip = "Very very slow in CI")] | ||
public void BasicBeam() | ||
{ | ||
const int num_beams = 2; | ||
const int n_predict = 3; | ||
|
||
var context = _model.CreateContext(_params); | ||
|
||
var result = new StringBuilder(); | ||
|
||
var initial_tokens = context.Tokenize("The cat sat on"); | ||
result.Append(context.DeTokenize(initial_tokens.ToArray())); | ||
context.Eval(initial_tokens, 0); | ||
|
||
NativeApi.llama_beam_search(context.NativeHandle, (data, state) => | ||
{ | ||
for (var i = 0; i < state.Beams.Length; i++) | ||
{ | ||
ref var view = ref state.Beams[i]; | ||
var tokens = context.DeTokenize(view.Tokens.ToArray()); | ||
_testOutputHelper.WriteLine($"B{i} ({view.CumulativeProbability}) => '{tokens}'"); | ||
} | ||
if (state.CommonPrefixLength > 0) | ||
{ | ||
var view = state.Beams[0]; | ||
result.Append(context.DeTokenize(view.Tokens.Slice(0, (int)state.CommonPrefixLength).ToArray())); | ||
} | ||
}, IntPtr.Zero, num_beams, initial_tokens.Length, n_predict, Math.Max(1, Environment.ProcessorCount / 2)); | ||
|
||
_testOutputHelper.WriteLine($"Final: {result}"); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,75 @@ | ||
using System.Text; | ||
using LLama.Common; | ||
using LLama.Extensions; | ||
|
||
namespace LLama.Unittest; | ||
|
||
public sealed class TokenTests | ||
: IDisposable | ||
{ | ||
private readonly ModelParams _params; | ||
private readonly LLamaWeights _model; | ||
|
||
public TokenTests() | ||
{ | ||
_params = new ModelParams(Constants.ModelPath) | ||
{ | ||
ContextSize = 2048 | ||
}; | ||
_model = LLamaWeights.LoadFromFile(_params); | ||
} | ||
|
||
public void Dispose() | ||
{ | ||
_model.Dispose(); | ||
} | ||
|
||
[Fact] | ||
public void TokensEndWith() | ||
{ | ||
var tokens = _model.NativeHandle.Tokenize("The cat sat on the edge of the mat", false, Encoding.UTF8); | ||
|
||
var result = tokens.TokensEndsWithAnyString(new[] | ||
{ | ||
"a fish", | ||
"the mat", | ||
"this is an improbably long query to be using for this method" | ||
}, _model.NativeHandle, Encoding.UTF8); | ||
Assert.True(result); | ||
} | ||
|
||
[Fact] | ||
public void TokensEndSubstring() | ||
{ | ||
var tokens = _model.NativeHandle.Tokenize("The cat sat on the edge of the mat", false, Encoding.UTF8); | ||
|
||
var result = tokens.TokensEndsWithAnyString(new[] | ||
{ | ||
"at", | ||
}, _model.NativeHandle, Encoding.UTF8); | ||
Assert.True(result); | ||
} | ||
|
||
[Fact] | ||
public void TokensNotEndWith() | ||
{ | ||
var tokens = _model.NativeHandle.Tokenize("The cat sat on the edge of the mat", false, Encoding.UTF8); | ||
|
||
var result = tokens.TokensEndsWithAnyString(new[] | ||
{ | ||
"a fish", | ||
"The cat sat on the edge of the ma", | ||
"this is an improbably long query to be using for this method" | ||
}, _model.NativeHandle, Encoding.UTF8); | ||
Assert.False(result); | ||
} | ||
|
||
[Fact] | ||
public void TokensNotEndWithNothing() | ||
{ | ||
var tokens = _model.NativeHandle.Tokenize("The cat sat on the edge of the mat", false, Encoding.UTF8); | ||
|
||
var result = tokens.TokensEndsWithAnyString(Array.Empty<string>(), _model.NativeHandle, Encoding.UTF8); | ||
Assert.False(result); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
using System; | ||
using System.Collections.Generic; | ||
using System.Text; | ||
|
||
namespace LLama.Extensions; | ||
|
||
internal static class EncodingExtensions | ||
{ | ||
#if NETSTANDARD2_0 | ||
public static int GetChars(this Encoding encoding, ReadOnlySpan<byte> bytes, Span<char> output) | ||
{ | ||
unsafe | ||
{ | ||
fixed (byte* bytePtr = bytes) | ||
fixed (char* charPtr = output) | ||
{ | ||
return encoding.GetChars(bytePtr, bytes.Length, charPtr, output.Length); | ||
} | ||
} | ||
} | ||
|
||
public static int GetCharCount(this Encoding encoding, ReadOnlySpan<byte> bytes) | ||
{ | ||
unsafe | ||
{ | ||
fixed (byte* bytePtr = bytes) | ||
{ | ||
return encoding.GetCharCount(bytePtr, bytes.Length); | ||
} | ||
} | ||
} | ||
#endif | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.