Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Translating the grammar parser #136

Merged
merged 10 commits into from
Aug 30, 2023
27 changes: 27 additions & 0 deletions LLama.Examples/Assets/json.gbnf
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
# https://github.com/ggerganov/llama.cpp/blob/8183159cf3def112f6d1fe94815fce70e1bffa12/grammars/json.gbnf

root ::= object
value ::= object | array | string | number | ("true" | "false" | "null") ws

object ::=
"{" ws (
string ":" ws value
("," ws string ":" ws value)*
)? "}" ws

array ::=
"[" ws (
value
("," ws value)*
)? "]" ws

string ::=
"\"" (
[^"\\] |
"\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F]) # escapes
)* "\"" ws

number ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? ws

# Optional space: by convention, applied in this grammar after literal chars when allowed
ws ::= ([ \t\n] ws)?
3 changes: 3 additions & 0 deletions LLama.Examples/LLama.Examples.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,9 @@
<None Update="Assets\dan.txt">
<CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
</None>
<None Update="Assets\json.gbnf">
<CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
</None>
<None Update="Assets\reason-act.txt">
<CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
</None>
Expand Down
55 changes: 55 additions & 0 deletions LLama.Examples/NewVersion/GrammarJsonResponse.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
using LLama.Common;
using LLama.Grammar;
using LLama.Native;

namespace LLama.Examples.NewVersion
{
public class GrammarJsonResponse
{
public static void Run()
{
var grammarBytes = File.ReadAllText("Assets/json.gbnf").Trim();
var parsedGrammar = new GrammarParser();

Console.Write("Please input your model path: ");
var modelPath = Console.ReadLine();

var parameters = new ModelParams(modelPath)
{
ContextSize = 1024,
Seed = 1337,
GpuLayerCount = 5
};
using var model = LLamaWeights.LoadFromFile(parameters);
var ex = new StatelessExecutor(model, parameters);
ParseState state = parsedGrammar.Parse(grammarBytes);
using var grammar = SafeLLamaGrammarHandle.Create(state.Rules, 0);
martindevans marked this conversation as resolved.
Show resolved Hide resolved

Console.ForegroundColor = ConsoleColor.Yellow;
Console.WriteLine("The executor has been enabled. In this example, the LLM will follow your instructions and always respond in a JSON format. For example, you can input \"Tell me the attributes of a good dish\"");
Console.ForegroundColor = ConsoleColor.White;

var inferenceParams = new InferenceParams()
{
Temperature = 0.6f,
AntiPrompts = new List<string> { "Question:", "#", "Question: ", ".\n" },
MaxTokens = 50,
Grammar = grammar
};

while (true)
{
Console.Write("\nQuestion: ");
Console.ForegroundColor = ConsoleColor.Green;
var prompt = Console.ReadLine();
Console.ForegroundColor = ConsoleColor.White;
Console.Write("Answer: ");
prompt = $"Question: {prompt?.Trim()} Answer: ";
foreach (var text in ex.Infer(prompt, inferenceParams))
{
Console.Write(text);
}
}
}
}
}
5 changes: 5 additions & 0 deletions LLama.Examples/NewVersion/TestRunner.cs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ public static async Task Run()
Console.WriteLine("7: Get embeddings from LLama model.");
Console.WriteLine("8: Quantize the model.");
Console.WriteLine("9: Automatic conversation.");
Console.WriteLine("10: Constrain response to json format using grammar.");

while (true)
{
Expand Down Expand Up @@ -63,6 +64,10 @@ public static async Task Run()
{
await TalkToYourself.Run();
}
else if (choice == 10)
{
GrammarJsonResponse.Run();
}
else
{
Console.WriteLine("Cannot parse your choice. Please select again.");
Expand Down
274 changes: 274 additions & 0 deletions LLama.Unittest/GrammarParserTest.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,274 @@
using LLama.Common;
using LLama.Native;
using System.Diagnostics;
using LLama.Grammar;
using Newtonsoft.Json.Linq;

namespace LLama.Unittest
{
/// <summary>
/// Source:
/// https://github.com/ggerganov/llama.cpp/blob/6381d4e110bd0ec02843a60bbeb8b6fc37a9ace9/tests/test-grammar-parser.cpp
///
/// The commit hash from URL is the actual commit hash that reflects current C# code.
/// </summary>
public sealed class GrammarParserTest
{
[Fact]
public void ParseComplexGrammar()
{
GrammarParser parsedGrammar = new GrammarParser();
string grammarBytes = @"root ::= (expr ""="" term ""\n"")+
expr ::= term ([-+*/] term)*
term ::= [0-9]+";

ParseState state = parsedGrammar.Parse(grammarBytes);

List<KeyValuePair<string, uint>> expected = new List<KeyValuePair<string, uint>>
{
new KeyValuePair<string, uint>("expr", 2),
new KeyValuePair<string, uint>("expr_5", 5),
new KeyValuePair<string, uint>("expr_6", 6),
new KeyValuePair<string, uint>("root", 0),
new KeyValuePair<string, uint>("root_1", 1),
new KeyValuePair<string, uint>("root_4", 4),
new KeyValuePair<string, uint>("term", 3),
new KeyValuePair<string, uint>("term_7", 7),
};

uint index = 0;
foreach (var it in state.SymbolIds)
{
string key = it.Key;
uint value = it.Value;
var expectedPair = expected[(int)index];

// pretty print error message before asserting
if (expectedPair.Key != key || expectedPair.Value != value)
{
Console.Error.WriteLine($"expectedPair: {expectedPair.Key}, {expectedPair.Value}");
Console.Error.WriteLine($"actualPair: {key}, {value}");
Console.Error.WriteLine("expectedPair != actualPair");
}
Assert.Equal(expectedPair.Key, key);
Assert.Equal(expectedPair.Value, value);

index++;
}
Assert.NotEmpty(state.SymbolIds);


var expectedRules = new List<LLamaGrammarElement>
{
new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 4),
new LLamaGrammarElement(LLamaGrammarElementType.END, 0),
new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 2),
new LLamaGrammarElement(LLamaGrammarElementType.CHAR, 61),
new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 3),
new LLamaGrammarElement(LLamaGrammarElementType.CHAR, 10),
new LLamaGrammarElement(LLamaGrammarElementType.END, 0),
new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 3),
new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 6),
new LLamaGrammarElement(LLamaGrammarElementType.END, 0),
new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 7),
new LLamaGrammarElement(LLamaGrammarElementType.END, 0),
new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 1),
new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 4),
new LLamaGrammarElement(LLamaGrammarElementType.ALT, 0),
new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 1),
new LLamaGrammarElement(LLamaGrammarElementType.END, 0),
new LLamaGrammarElement(LLamaGrammarElementType.CHAR, 45),
new LLamaGrammarElement(LLamaGrammarElementType.CHAR_ALT, 43),
new LLamaGrammarElement(LLamaGrammarElementType.CHAR_ALT, 42),
new LLamaGrammarElement(LLamaGrammarElementType.CHAR_ALT, 47),
new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 3),
new LLamaGrammarElement(LLamaGrammarElementType.END, 0),
new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 5),
new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 6),
new LLamaGrammarElement(LLamaGrammarElementType.ALT, 0),
new LLamaGrammarElement(LLamaGrammarElementType.END, 0),
new LLamaGrammarElement(LLamaGrammarElementType.CHAR, 48),
new LLamaGrammarElement(LLamaGrammarElementType.CHAR_RNG_UPPER, 57),
new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 7),
new LLamaGrammarElement(LLamaGrammarElementType.ALT, 0),
new LLamaGrammarElement(LLamaGrammarElementType.CHAR, 48),
new LLamaGrammarElement(LLamaGrammarElementType.CHAR_RNG_UPPER, 57),
new LLamaGrammarElement(LLamaGrammarElementType.END, 0),
};

index = 0;
foreach (var rule in state.Rules)
{
// compare rule to expected rule
for (uint i = 0; i < rule.Count; i++)
{
var element = rule[(int)i];
var expectedElement = expectedRules[(int)index];

// Pretty print error message before asserting
if (expectedElement.Type != element.Type || expectedElement.Value != element.Value)
{
Console.Error.WriteLine($"index: {index}");
Console.Error.WriteLine($"expected_element: {expectedElement.Type}, {expectedElement.Value}");
Console.Error.WriteLine($"actual_element: {element.Type}, {element.Value}");
Console.Error.WriteLine("expected_element != actual_element");
}
Assert.Equal(expectedElement.Type, element.Type);
Assert.Equal(expectedElement.Value, element.Value);
index++;
}
}
Assert.NotEmpty(state.Rules);
}

[Fact]
public void ParseExtraComplexGrammar()
{
GrammarParser parsedGrammar = new GrammarParser();
string grammarBytes = @"
root ::= (expr ""="" ws term ""\n"")+
expr ::= term ([-+*/] term)*
term ::= ident | num | ""("" ws expr "")"" ws
ident ::= [a-z] [a-z0-9_]* ws
num ::= [0-9]+ ws
ws ::= [ \t\n]*
";

ParseState state = parsedGrammar.Parse(grammarBytes);

List<KeyValuePair<string, uint>> expected = new List<KeyValuePair<string, uint>>
{
new KeyValuePair<string, uint>("expr", 2),
new KeyValuePair<string, uint>("expr_6", 6),
new KeyValuePair<string, uint>("expr_7", 7),
new KeyValuePair<string, uint>("ident", 8),
new KeyValuePair<string, uint>("ident_10", 10),
new KeyValuePair<string, uint>("num", 9),
new KeyValuePair<string, uint>("num_11", 11),
new KeyValuePair<string, uint>("root", 0),
new KeyValuePair<string, uint>("root_1", 1),
new KeyValuePair<string, uint>("root_5", 5),
new KeyValuePair<string, uint>("term", 4),
new KeyValuePair<string, uint>("ws", 3),
new KeyValuePair<string, uint>("ws_12", 12),
};

uint index = 0;
foreach (var it in state.SymbolIds)
{
string key = it.Key;
uint value = it.Value;
var expectedPair = expected[(int)index];

// pretty print error message before asserting
if (expectedPair.Key != key || expectedPair.Value != value)
{
Console.Error.WriteLine($"expectedPair: {expectedPair.Key}, {expectedPair.Value}");
Console.Error.WriteLine($"actualPair: {key}, {value}");
Console.Error.WriteLine("expectedPair != actualPair");
}
Assert.Equal(expectedPair.Key, key);
Assert.Equal(expectedPair.Value, value);

index++;
}
Assert.NotEmpty(state.SymbolIds);


var expectedRules = new List<LLamaGrammarElement>
{
new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 5),
new LLamaGrammarElement(LLamaGrammarElementType.END, 0),
new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 2),
new LLamaGrammarElement(LLamaGrammarElementType.CHAR, 61),
new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 3),
new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 4),
new LLamaGrammarElement(LLamaGrammarElementType.CHAR, 10),
new LLamaGrammarElement(LLamaGrammarElementType.END, 0),
new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 4),
new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 7),
new LLamaGrammarElement(LLamaGrammarElementType.END, 0),
new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 12),
new LLamaGrammarElement(LLamaGrammarElementType.END, 0),
new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 8),
new LLamaGrammarElement(LLamaGrammarElementType.ALT, 0),
new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 9),
new LLamaGrammarElement(LLamaGrammarElementType.ALT, 0),
new LLamaGrammarElement(LLamaGrammarElementType.CHAR, 40),
new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 3),
new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 2),
new LLamaGrammarElement(LLamaGrammarElementType.CHAR, 41),
new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 3),
new LLamaGrammarElement(LLamaGrammarElementType.END, 0),
new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 1),
new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 5),
new LLamaGrammarElement(LLamaGrammarElementType.ALT, 0),
new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 1),
new LLamaGrammarElement(LLamaGrammarElementType.END, 0),
new LLamaGrammarElement(LLamaGrammarElementType.CHAR, 45),
new LLamaGrammarElement(LLamaGrammarElementType.CHAR_ALT, 43),
new LLamaGrammarElement(LLamaGrammarElementType.CHAR_ALT, 42),
new LLamaGrammarElement(LLamaGrammarElementType.CHAR_ALT, 47),
new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 4),
new LLamaGrammarElement(LLamaGrammarElementType.END, 0),
new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 6),
new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 7),
new LLamaGrammarElement(LLamaGrammarElementType.ALT, 0),
new LLamaGrammarElement(LLamaGrammarElementType.END, 0),
new LLamaGrammarElement(LLamaGrammarElementType.CHAR, 97),
new LLamaGrammarElement(LLamaGrammarElementType.CHAR_RNG_UPPER, 122),
new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 10),
new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 3),
new LLamaGrammarElement(LLamaGrammarElementType.END, 0),
new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 11),
new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 3),
new LLamaGrammarElement(LLamaGrammarElementType.END, 0),
new LLamaGrammarElement(LLamaGrammarElementType.CHAR, 97),
new LLamaGrammarElement(LLamaGrammarElementType.CHAR_RNG_UPPER, 122),
new LLamaGrammarElement(LLamaGrammarElementType.CHAR_ALT, 48),
new LLamaGrammarElement(LLamaGrammarElementType.CHAR_RNG_UPPER, 57),
new LLamaGrammarElement(LLamaGrammarElementType.CHAR_ALT, 95),
new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 10),
new LLamaGrammarElement(LLamaGrammarElementType.ALT, 0),
new LLamaGrammarElement(LLamaGrammarElementType.END, 0),
new LLamaGrammarElement(LLamaGrammarElementType.CHAR, 48),
new LLamaGrammarElement(LLamaGrammarElementType.CHAR_RNG_UPPER, 57),
new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 11),
new LLamaGrammarElement(LLamaGrammarElementType.ALT, 0),
new LLamaGrammarElement(LLamaGrammarElementType.CHAR, 48),
new LLamaGrammarElement(LLamaGrammarElementType.CHAR_RNG_UPPER, 57),
new LLamaGrammarElement(LLamaGrammarElementType.END, 0),
new LLamaGrammarElement(LLamaGrammarElementType.CHAR, 32),
new LLamaGrammarElement(LLamaGrammarElementType.CHAR_ALT, 9),
new LLamaGrammarElement(LLamaGrammarElementType.CHAR_ALT, 10),
new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 12),
new LLamaGrammarElement(LLamaGrammarElementType.ALT, 0),
new LLamaGrammarElement(LLamaGrammarElementType.END, 0)
};

index = 0;
foreach (var rule in state.Rules)
{
// compare rule to expected rule
for (uint i = 0; i < rule.Count; i++)
{
var element = rule[(int)i];
var expectedElement = expectedRules[(int)index];

// Pretty print error message before asserting
if (expectedElement.Type != element.Type || expectedElement.Value != element.Value)
{
Console.Error.WriteLine($"index: {index}");
Console.Error.WriteLine($"expected_element: {expectedElement.Type}, {expectedElement.Value}");
Console.Error.WriteLine($"actual_element: {element.Type}, {element.Value}");
Console.Error.WriteLine("expected_element != actual_element");
}
Assert.Equal(expectedElement.Type, element.Type);
Assert.Equal(expectedElement.Value, element.Value);
index++;
}
}
Assert.NotEmpty(state.Rules);
}
}
}
Loading
Loading