Skip to content

Commit

Permalink
Merge pull request SciSharp#136 from Mihaiii/grammar_parser
Browse files Browse the repository at this point in the history
Translating the grammar parser
  • Loading branch information
martindevans authored Aug 30, 2023
2 parents e344918 + 24d3e1b commit 40e76a7
Show file tree
Hide file tree
Showing 9 changed files with 957 additions and 1 deletion.
27 changes: 27 additions & 0 deletions LLama.Examples/Assets/json.gbnf
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
# https://github.com/ggerganov/llama.cpp/blob/8183159cf3def112f6d1fe94815fce70e1bffa12/grammars/json.gbnf

root ::= object
value ::= object | array | string | number | ("true" | "false" | "null") ws

object ::=
"{" ws (
string ":" ws value
("," ws string ":" ws value)*
)? "}" ws

array ::=
"[" ws (
value
("," ws value)*
)? "]" ws

string ::=
"\"" (
[^"\\] |
"\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F]) # escapes
)* "\"" ws

number ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? ws

# Optional space: by convention, applied in this grammar after literal chars when allowed
ws ::= ([ \t\n] ws)?
3 changes: 3 additions & 0 deletions LLama.Examples/LLama.Examples.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,9 @@
<None Update="Assets\dan.txt">
<CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
</None>
<None Update="Assets\json.gbnf">
<CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
</None>
<None Update="Assets\reason-act.txt">
<CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
</None>
Expand Down
55 changes: 55 additions & 0 deletions LLama.Examples/NewVersion/GrammarJsonResponse.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
using LLama.Common;
using LLama.Grammar;
using LLama.Native;

namespace LLama.Examples.NewVersion
{
public class GrammarJsonResponse
{
public static void Run()
{
var grammarBytes = File.ReadAllText("Assets/json.gbnf").Trim();
var parsedGrammar = new GrammarParser();

Console.Write("Please input your model path: ");
var modelPath = Console.ReadLine();

var parameters = new ModelParams(modelPath)
{
ContextSize = 1024,
Seed = 1337,
GpuLayerCount = 5
};
using var model = LLamaWeights.LoadFromFile(parameters);
var ex = new StatelessExecutor(model, parameters);
ParseState state = parsedGrammar.Parse(grammarBytes);
using var grammar = SafeLLamaGrammarHandle.Create(state.Rules, 0);

Console.ForegroundColor = ConsoleColor.Yellow;
Console.WriteLine("The executor has been enabled. In this example, the LLM will follow your instructions and always respond in a JSON format. For example, you can input \"Tell me the attributes of a good dish\"");
Console.ForegroundColor = ConsoleColor.White;

var inferenceParams = new InferenceParams()
{
Temperature = 0.6f,
AntiPrompts = new List<string> { "Question:", "#", "Question: ", ".\n" },
MaxTokens = 50,
Grammar = grammar
};

while (true)
{
Console.Write("\nQuestion: ");
Console.ForegroundColor = ConsoleColor.Green;
var prompt = Console.ReadLine();
Console.ForegroundColor = ConsoleColor.White;
Console.Write("Answer: ");
prompt = $"Question: {prompt?.Trim()} Answer: ";
foreach (var text in ex.Infer(prompt, inferenceParams))
{
Console.Write(text);
}
}
}
}
}
5 changes: 5 additions & 0 deletions LLama.Examples/NewVersion/TestRunner.cs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ public static async Task Run()
Console.WriteLine("7: Get embeddings from LLama model.");
Console.WriteLine("8: Quantize the model.");
Console.WriteLine("9: Automatic conversation.");
Console.WriteLine("10: Constrain response to json format using grammar.");

while (true)
{
Expand Down Expand Up @@ -63,6 +64,10 @@ public static async Task Run()
{
await TalkToYourself.Run();
}
else if (choice == 10)
{
GrammarJsonResponse.Run();
}
else
{
Console.WriteLine("Cannot parse your choice. Please select again.");
Expand Down
274 changes: 274 additions & 0 deletions LLama.Unittest/GrammarParserTest.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,274 @@
using LLama.Common;
using LLama.Native;
using System.Diagnostics;
using LLama.Grammar;
using Newtonsoft.Json.Linq;

namespace LLama.Unittest
{
/// <summary>
/// Source:
/// https://github.com/ggerganov/llama.cpp/blob/6381d4e110bd0ec02843a60bbeb8b6fc37a9ace9/tests/test-grammar-parser.cpp
///
/// The commit hash from URL is the actual commit hash that reflects current C# code.
/// </summary>
public sealed class GrammarParserTest
{
[Fact]
public void ParseComplexGrammar()
{
GrammarParser parsedGrammar = new GrammarParser();
string grammarBytes = @"root ::= (expr ""="" term ""\n"")+
expr ::= term ([-+*/] term)*
term ::= [0-9]+";

ParseState state = parsedGrammar.Parse(grammarBytes);

List<KeyValuePair<string, uint>> expected = new List<KeyValuePair<string, uint>>
{
new KeyValuePair<string, uint>("expr", 2),
new KeyValuePair<string, uint>("expr_5", 5),
new KeyValuePair<string, uint>("expr_6", 6),
new KeyValuePair<string, uint>("root", 0),
new KeyValuePair<string, uint>("root_1", 1),
new KeyValuePair<string, uint>("root_4", 4),
new KeyValuePair<string, uint>("term", 3),
new KeyValuePair<string, uint>("term_7", 7),
};

uint index = 0;
foreach (var it in state.SymbolIds)
{
string key = it.Key;
uint value = it.Value;
var expectedPair = expected[(int)index];

// pretty print error message before asserting
if (expectedPair.Key != key || expectedPair.Value != value)
{
Console.Error.WriteLine($"expectedPair: {expectedPair.Key}, {expectedPair.Value}");
Console.Error.WriteLine($"actualPair: {key}, {value}");
Console.Error.WriteLine("expectedPair != actualPair");
}
Assert.Equal(expectedPair.Key, key);
Assert.Equal(expectedPair.Value, value);

index++;
}
Assert.NotEmpty(state.SymbolIds);


var expectedRules = new List<LLamaGrammarElement>
{
new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 4),
new LLamaGrammarElement(LLamaGrammarElementType.END, 0),
new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 2),
new LLamaGrammarElement(LLamaGrammarElementType.CHAR, 61),
new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 3),
new LLamaGrammarElement(LLamaGrammarElementType.CHAR, 10),
new LLamaGrammarElement(LLamaGrammarElementType.END, 0),
new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 3),
new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 6),
new LLamaGrammarElement(LLamaGrammarElementType.END, 0),
new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 7),
new LLamaGrammarElement(LLamaGrammarElementType.END, 0),
new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 1),
new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 4),
new LLamaGrammarElement(LLamaGrammarElementType.ALT, 0),
new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 1),
new LLamaGrammarElement(LLamaGrammarElementType.END, 0),
new LLamaGrammarElement(LLamaGrammarElementType.CHAR, 45),
new LLamaGrammarElement(LLamaGrammarElementType.CHAR_ALT, 43),
new LLamaGrammarElement(LLamaGrammarElementType.CHAR_ALT, 42),
new LLamaGrammarElement(LLamaGrammarElementType.CHAR_ALT, 47),
new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 3),
new LLamaGrammarElement(LLamaGrammarElementType.END, 0),
new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 5),
new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 6),
new LLamaGrammarElement(LLamaGrammarElementType.ALT, 0),
new LLamaGrammarElement(LLamaGrammarElementType.END, 0),
new LLamaGrammarElement(LLamaGrammarElementType.CHAR, 48),
new LLamaGrammarElement(LLamaGrammarElementType.CHAR_RNG_UPPER, 57),
new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 7),
new LLamaGrammarElement(LLamaGrammarElementType.ALT, 0),
new LLamaGrammarElement(LLamaGrammarElementType.CHAR, 48),
new LLamaGrammarElement(LLamaGrammarElementType.CHAR_RNG_UPPER, 57),
new LLamaGrammarElement(LLamaGrammarElementType.END, 0),
};

index = 0;
foreach (var rule in state.Rules)
{
// compare rule to expected rule
for (uint i = 0; i < rule.Count; i++)
{
var element = rule[(int)i];
var expectedElement = expectedRules[(int)index];

// Pretty print error message before asserting
if (expectedElement.Type != element.Type || expectedElement.Value != element.Value)
{
Console.Error.WriteLine($"index: {index}");
Console.Error.WriteLine($"expected_element: {expectedElement.Type}, {expectedElement.Value}");
Console.Error.WriteLine($"actual_element: {element.Type}, {element.Value}");
Console.Error.WriteLine("expected_element != actual_element");
}
Assert.Equal(expectedElement.Type, element.Type);
Assert.Equal(expectedElement.Value, element.Value);
index++;
}
}
Assert.NotEmpty(state.Rules);
}

[Fact]
public void ParseExtraComplexGrammar()
{
GrammarParser parsedGrammar = new GrammarParser();
string grammarBytes = @"
root ::= (expr ""="" ws term ""\n"")+
expr ::= term ([-+*/] term)*
term ::= ident | num | ""("" ws expr "")"" ws
ident ::= [a-z] [a-z0-9_]* ws
num ::= [0-9]+ ws
ws ::= [ \t\n]*
";

ParseState state = parsedGrammar.Parse(grammarBytes);

List<KeyValuePair<string, uint>> expected = new List<KeyValuePair<string, uint>>
{
new KeyValuePair<string, uint>("expr", 2),
new KeyValuePair<string, uint>("expr_6", 6),
new KeyValuePair<string, uint>("expr_7", 7),
new KeyValuePair<string, uint>("ident", 8),
new KeyValuePair<string, uint>("ident_10", 10),
new KeyValuePair<string, uint>("num", 9),
new KeyValuePair<string, uint>("num_11", 11),
new KeyValuePair<string, uint>("root", 0),
new KeyValuePair<string, uint>("root_1", 1),
new KeyValuePair<string, uint>("root_5", 5),
new KeyValuePair<string, uint>("term", 4),
new KeyValuePair<string, uint>("ws", 3),
new KeyValuePair<string, uint>("ws_12", 12),
};

uint index = 0;
foreach (var it in state.SymbolIds)
{
string key = it.Key;
uint value = it.Value;
var expectedPair = expected[(int)index];

// pretty print error message before asserting
if (expectedPair.Key != key || expectedPair.Value != value)
{
Console.Error.WriteLine($"expectedPair: {expectedPair.Key}, {expectedPair.Value}");
Console.Error.WriteLine($"actualPair: {key}, {value}");
Console.Error.WriteLine("expectedPair != actualPair");
}
Assert.Equal(expectedPair.Key, key);
Assert.Equal(expectedPair.Value, value);

index++;
}
Assert.NotEmpty(state.SymbolIds);


var expectedRules = new List<LLamaGrammarElement>
{
new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 5),
new LLamaGrammarElement(LLamaGrammarElementType.END, 0),
new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 2),
new LLamaGrammarElement(LLamaGrammarElementType.CHAR, 61),
new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 3),
new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 4),
new LLamaGrammarElement(LLamaGrammarElementType.CHAR, 10),
new LLamaGrammarElement(LLamaGrammarElementType.END, 0),
new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 4),
new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 7),
new LLamaGrammarElement(LLamaGrammarElementType.END, 0),
new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 12),
new LLamaGrammarElement(LLamaGrammarElementType.END, 0),
new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 8),
new LLamaGrammarElement(LLamaGrammarElementType.ALT, 0),
new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 9),
new LLamaGrammarElement(LLamaGrammarElementType.ALT, 0),
new LLamaGrammarElement(LLamaGrammarElementType.CHAR, 40),
new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 3),
new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 2),
new LLamaGrammarElement(LLamaGrammarElementType.CHAR, 41),
new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 3),
new LLamaGrammarElement(LLamaGrammarElementType.END, 0),
new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 1),
new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 5),
new LLamaGrammarElement(LLamaGrammarElementType.ALT, 0),
new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 1),
new LLamaGrammarElement(LLamaGrammarElementType.END, 0),
new LLamaGrammarElement(LLamaGrammarElementType.CHAR, 45),
new LLamaGrammarElement(LLamaGrammarElementType.CHAR_ALT, 43),
new LLamaGrammarElement(LLamaGrammarElementType.CHAR_ALT, 42),
new LLamaGrammarElement(LLamaGrammarElementType.CHAR_ALT, 47),
new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 4),
new LLamaGrammarElement(LLamaGrammarElementType.END, 0),
new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 6),
new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 7),
new LLamaGrammarElement(LLamaGrammarElementType.ALT, 0),
new LLamaGrammarElement(LLamaGrammarElementType.END, 0),
new LLamaGrammarElement(LLamaGrammarElementType.CHAR, 97),
new LLamaGrammarElement(LLamaGrammarElementType.CHAR_RNG_UPPER, 122),
new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 10),
new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 3),
new LLamaGrammarElement(LLamaGrammarElementType.END, 0),
new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 11),
new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 3),
new LLamaGrammarElement(LLamaGrammarElementType.END, 0),
new LLamaGrammarElement(LLamaGrammarElementType.CHAR, 97),
new LLamaGrammarElement(LLamaGrammarElementType.CHAR_RNG_UPPER, 122),
new LLamaGrammarElement(LLamaGrammarElementType.CHAR_ALT, 48),
new LLamaGrammarElement(LLamaGrammarElementType.CHAR_RNG_UPPER, 57),
new LLamaGrammarElement(LLamaGrammarElementType.CHAR_ALT, 95),
new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 10),
new LLamaGrammarElement(LLamaGrammarElementType.ALT, 0),
new LLamaGrammarElement(LLamaGrammarElementType.END, 0),
new LLamaGrammarElement(LLamaGrammarElementType.CHAR, 48),
new LLamaGrammarElement(LLamaGrammarElementType.CHAR_RNG_UPPER, 57),
new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 11),
new LLamaGrammarElement(LLamaGrammarElementType.ALT, 0),
new LLamaGrammarElement(LLamaGrammarElementType.CHAR, 48),
new LLamaGrammarElement(LLamaGrammarElementType.CHAR_RNG_UPPER, 57),
new LLamaGrammarElement(LLamaGrammarElementType.END, 0),
new LLamaGrammarElement(LLamaGrammarElementType.CHAR, 32),
new LLamaGrammarElement(LLamaGrammarElementType.CHAR_ALT, 9),
new LLamaGrammarElement(LLamaGrammarElementType.CHAR_ALT, 10),
new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 12),
new LLamaGrammarElement(LLamaGrammarElementType.ALT, 0),
new LLamaGrammarElement(LLamaGrammarElementType.END, 0)
};

index = 0;
foreach (var rule in state.Rules)
{
// compare rule to expected rule
for (uint i = 0; i < rule.Count; i++)
{
var element = rule[(int)i];
var expectedElement = expectedRules[(int)index];

// Pretty print error message before asserting
if (expectedElement.Type != element.Type || expectedElement.Value != element.Value)
{
Console.Error.WriteLine($"index: {index}");
Console.Error.WriteLine($"expected_element: {expectedElement.Type}, {expectedElement.Value}");
Console.Error.WriteLine($"actual_element: {element.Type}, {element.Value}");
Console.Error.WriteLine("expected_element != actual_element");
}
Assert.Equal(expectedElement.Type, element.Type);
Assert.Equal(expectedElement.Value, element.Value);
index++;
}
}
Assert.NotEmpty(state.Rules);
}
}
}
Loading

0 comments on commit 40e76a7

Please sign in to comment.