forked from SciSharp/LLamaSharp
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request SciSharp#136 from Mihaiii/grammar_parser
Translating the grammar parser
- Loading branch information
Showing
9 changed files
with
957 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
# https://github.com/ggerganov/llama.cpp/blob/8183159cf3def112f6d1fe94815fce70e1bffa12/grammars/json.gbnf | ||
|
||
# Entry rule: the top-level value must be an object. | ||
root ::= object | ||
# A value is an object, an array, a string, a number, or one of the literals true/false/null followed by ws. | ||
value ::= object | array | string | number | ("true" | "false" | "null") ws | ||
|
||
# Object: braces around an optional comma-separated list of string ":" value members. | ||
object ::= | ||
"{" ws ( | ||
string ":" ws value | ||
("," ws string ":" ws value)* | ||
)? "}" ws | ||
|
||
# Array: brackets around an optional comma-separated list of values. | ||
array ::= | ||
"[" ws ( | ||
value | ||
("," ws value)* | ||
)? "]" ws | ||
|
||
# String: double-quoted; each character is either anything except a quote or backslash, | ||
# or a backslash escape (one of " \ / b f n r t, or "u" followed by four hex digits). | ||
string ::= | ||
"\"" ( | ||
[^"\\] | | ||
"\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F]) # escapes | ||
)* "\"" ws | ||
|
||
# Number: optional minus sign and integer part, then an optional fraction and an optional exponent. | ||
number ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? ws | ||
|
||
# Optional space: by convention, applied in this grammar after literal chars when allowed | ||
# Defined right-recursively: zero or more of space, tab or newline. | ||
ws ::= ([ \t\n] ws)? |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
using LLama.Common; | ||
using LLama.Grammar; | ||
using LLama.Native; | ||
|
||
namespace LLama.Examples.NewVersion | ||
{ | ||
/// <summary>
/// Console example that constrains the model's output with the JSON grammar
/// stored in <c>Assets/json.gbnf</c>.
/// </summary>
public class GrammarJsonResponse
{
    /// <summary>
    /// Reads the grammar file, asks for a model path on standard input, then answers
    /// questions in a loop. Every answer is sampled under the parsed grammar.
    /// The loop ends when standard input is closed.
    /// </summary>
    public static void Run()
    {
        var grammarText = File.ReadAllText("Assets/json.gbnf").Trim();
        var grammarParser = new GrammarParser();

        Console.Write("Please input your model path: ");
        var modelPath = Console.ReadLine();

        // Console.ReadLine returns null once input is exhausted; without a path there is nothing to load.
        if (string.IsNullOrWhiteSpace(modelPath))
        {
            Console.Error.WriteLine("No model path was provided.");
            return;
        }

        var parameters = new ModelParams(modelPath)
        {
            ContextSize = 1024,
            Seed = 1337,
            GpuLayerCount = 5
        };
        using var model = LLamaWeights.LoadFromFile(parameters);
        var ex = new StatelessExecutor(model, parameters);
        ParseState state = grammarParser.Parse(grammarText);
        using var grammar = SafeLLamaGrammarHandle.Create(state.Rules, 0);

        Console.ForegroundColor = ConsoleColor.Yellow;
        Console.WriteLine("The executor has been enabled. In this example, the LLM will follow your instructions and always respond in a JSON format. For example, you can input \"Tell me the attributes of a good dish\"");
        Console.ForegroundColor = ConsoleColor.White;

        var inferenceParams = new InferenceParams()
        {
            Temperature = 0.6f,
            AntiPrompts = new List<string> { "Question:", "#", "Question: ", ".\n" },
            MaxTokens = 50,
            Grammar = grammar
        };

        while (true)
        {
            Console.Write("\nQuestion: ");
            Console.ForegroundColor = ConsoleColor.Green;
            var question = Console.ReadLine();
            Console.ForegroundColor = ConsoleColor.White;

            // A null line means end of input; stop instead of spinning forever on empty questions.
            if (question is null)
            {
                break;
            }

            Console.Write("Answer: ");
            var prompt = $"Question: {question.Trim()} Answer: ";
            foreach (var text in ex.Infer(prompt, inferenceParams))
            {
                Console.Write(text);
            }
        }
    }
}
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,274 @@ | ||
using LLama.Common; | ||
using LLama.Native; | ||
using System.Diagnostics; | ||
using LLama.Grammar; | ||
using Newtonsoft.Json.Linq; | ||
|
||
namespace LLama.Unittest | ||
{ | ||
/// <summary>
/// Source:
/// https://github.com/ggerganov/llama.cpp/blob/6381d4e110bd0ec02843a60bbeb8b6fc37a9ace9/tests/test-grammar-parser.cpp
///
/// The commit hash from URL is the actual commit hash that reflects current C# code.
/// </summary>
public sealed class GrammarParserTest
{
    /// <summary>
    /// Parses a three-rule arithmetic grammar and checks the generated symbol table
    /// and the flattened rule elements against the expected values.
    /// </summary>
    [Fact]
    public void ParseComplexGrammar()
    {
        GrammarParser parsedGrammar = new GrammarParser();
        string grammarBytes = @"root ::= (expr ""="" term ""\n"")+
expr ::= term ([-+*/] term)*
term ::= [0-9]+";

        ParseState state = parsedGrammar.Parse(grammarBytes);

        List<KeyValuePair<string, uint>> expected = new List<KeyValuePair<string, uint>>
        {
            new KeyValuePair<string, uint>("expr", 2),
            new KeyValuePair<string, uint>("expr_5", 5),
            new KeyValuePair<string, uint>("expr_6", 6),
            new KeyValuePair<string, uint>("root", 0),
            new KeyValuePair<string, uint>("root_1", 1),
            new KeyValuePair<string, uint>("root_4", 4),
            new KeyValuePair<string, uint>("term", 3),
            new KeyValuePair<string, uint>("term_7", 7),
        };

        AssertSymbolIds(expected, state);

        var expectedRules = new List<LLamaGrammarElement>
        {
            new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 4),
            new LLamaGrammarElement(LLamaGrammarElementType.END, 0),
            new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 2),
            new LLamaGrammarElement(LLamaGrammarElementType.CHAR, 61),
            new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 3),
            new LLamaGrammarElement(LLamaGrammarElementType.CHAR, 10),
            new LLamaGrammarElement(LLamaGrammarElementType.END, 0),
            new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 3),
            new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 6),
            new LLamaGrammarElement(LLamaGrammarElementType.END, 0),
            new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 7),
            new LLamaGrammarElement(LLamaGrammarElementType.END, 0),
            new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 1),
            new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 4),
            new LLamaGrammarElement(LLamaGrammarElementType.ALT, 0),
            new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 1),
            new LLamaGrammarElement(LLamaGrammarElementType.END, 0),
            new LLamaGrammarElement(LLamaGrammarElementType.CHAR, 45),
            new LLamaGrammarElement(LLamaGrammarElementType.CHAR_ALT, 43),
            new LLamaGrammarElement(LLamaGrammarElementType.CHAR_ALT, 42),
            new LLamaGrammarElement(LLamaGrammarElementType.CHAR_ALT, 47),
            new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 3),
            new LLamaGrammarElement(LLamaGrammarElementType.END, 0),
            new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 5),
            new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 6),
            new LLamaGrammarElement(LLamaGrammarElementType.ALT, 0),
            new LLamaGrammarElement(LLamaGrammarElementType.END, 0),
            new LLamaGrammarElement(LLamaGrammarElementType.CHAR, 48),
            new LLamaGrammarElement(LLamaGrammarElementType.CHAR_RNG_UPPER, 57),
            new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 7),
            new LLamaGrammarElement(LLamaGrammarElementType.ALT, 0),
            new LLamaGrammarElement(LLamaGrammarElementType.CHAR, 48),
            new LLamaGrammarElement(LLamaGrammarElementType.CHAR_RNG_UPPER, 57),
            new LLamaGrammarElement(LLamaGrammarElementType.END, 0),
        };

        AssertRules(expectedRules, state);
    }

    /// <summary>
    /// Parses a larger grammar (identifiers, numbers, parentheses, whitespace) and checks
    /// the generated symbol table and the flattened rule elements against the expected values.
    /// </summary>
    [Fact]
    public void ParseExtraComplexGrammar()
    {
        GrammarParser parsedGrammar = new GrammarParser();
        string grammarBytes = @"
root ::= (expr ""="" ws term ""\n"")+
expr ::= term ([-+*/] term)*
term ::= ident | num | ""("" ws expr "")"" ws
ident ::= [a-z] [a-z0-9_]* ws
num ::= [0-9]+ ws
ws ::= [ \t\n]*
";

        ParseState state = parsedGrammar.Parse(grammarBytes);

        List<KeyValuePair<string, uint>> expected = new List<KeyValuePair<string, uint>>
        {
            new KeyValuePair<string, uint>("expr", 2),
            new KeyValuePair<string, uint>("expr_6", 6),
            new KeyValuePair<string, uint>("expr_7", 7),
            new KeyValuePair<string, uint>("ident", 8),
            new KeyValuePair<string, uint>("ident_10", 10),
            new KeyValuePair<string, uint>("num", 9),
            new KeyValuePair<string, uint>("num_11", 11),
            new KeyValuePair<string, uint>("root", 0),
            new KeyValuePair<string, uint>("root_1", 1),
            new KeyValuePair<string, uint>("root_5", 5),
            new KeyValuePair<string, uint>("term", 4),
            new KeyValuePair<string, uint>("ws", 3),
            new KeyValuePair<string, uint>("ws_12", 12),
        };

        AssertSymbolIds(expected, state);

        var expectedRules = new List<LLamaGrammarElement>
        {
            new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 5),
            new LLamaGrammarElement(LLamaGrammarElementType.END, 0),
            new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 2),
            new LLamaGrammarElement(LLamaGrammarElementType.CHAR, 61),
            new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 3),
            new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 4),
            new LLamaGrammarElement(LLamaGrammarElementType.CHAR, 10),
            new LLamaGrammarElement(LLamaGrammarElementType.END, 0),
            new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 4),
            new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 7),
            new LLamaGrammarElement(LLamaGrammarElementType.END, 0),
            new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 12),
            new LLamaGrammarElement(LLamaGrammarElementType.END, 0),
            new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 8),
            new LLamaGrammarElement(LLamaGrammarElementType.ALT, 0),
            new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 9),
            new LLamaGrammarElement(LLamaGrammarElementType.ALT, 0),
            new LLamaGrammarElement(LLamaGrammarElementType.CHAR, 40),
            new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 3),
            new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 2),
            new LLamaGrammarElement(LLamaGrammarElementType.CHAR, 41),
            new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 3),
            new LLamaGrammarElement(LLamaGrammarElementType.END, 0),
            new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 1),
            new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 5),
            new LLamaGrammarElement(LLamaGrammarElementType.ALT, 0),
            new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 1),
            new LLamaGrammarElement(LLamaGrammarElementType.END, 0),
            new LLamaGrammarElement(LLamaGrammarElementType.CHAR, 45),
            new LLamaGrammarElement(LLamaGrammarElementType.CHAR_ALT, 43),
            new LLamaGrammarElement(LLamaGrammarElementType.CHAR_ALT, 42),
            new LLamaGrammarElement(LLamaGrammarElementType.CHAR_ALT, 47),
            new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 4),
            new LLamaGrammarElement(LLamaGrammarElementType.END, 0),
            new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 6),
            new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 7),
            new LLamaGrammarElement(LLamaGrammarElementType.ALT, 0),
            new LLamaGrammarElement(LLamaGrammarElementType.END, 0),
            new LLamaGrammarElement(LLamaGrammarElementType.CHAR, 97),
            new LLamaGrammarElement(LLamaGrammarElementType.CHAR_RNG_UPPER, 122),
            new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 10),
            new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 3),
            new LLamaGrammarElement(LLamaGrammarElementType.END, 0),
            new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 11),
            new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 3),
            new LLamaGrammarElement(LLamaGrammarElementType.END, 0),
            new LLamaGrammarElement(LLamaGrammarElementType.CHAR, 97),
            new LLamaGrammarElement(LLamaGrammarElementType.CHAR_RNG_UPPER, 122),
            new LLamaGrammarElement(LLamaGrammarElementType.CHAR_ALT, 48),
            new LLamaGrammarElement(LLamaGrammarElementType.CHAR_RNG_UPPER, 57),
            new LLamaGrammarElement(LLamaGrammarElementType.CHAR_ALT, 95),
            new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 10),
            new LLamaGrammarElement(LLamaGrammarElementType.ALT, 0),
            new LLamaGrammarElement(LLamaGrammarElementType.END, 0),
            new LLamaGrammarElement(LLamaGrammarElementType.CHAR, 48),
            new LLamaGrammarElement(LLamaGrammarElementType.CHAR_RNG_UPPER, 57),
            new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 11),
            new LLamaGrammarElement(LLamaGrammarElementType.ALT, 0),
            new LLamaGrammarElement(LLamaGrammarElementType.CHAR, 48),
            new LLamaGrammarElement(LLamaGrammarElementType.CHAR_RNG_UPPER, 57),
            new LLamaGrammarElement(LLamaGrammarElementType.END, 0),
            new LLamaGrammarElement(LLamaGrammarElementType.CHAR, 32),
            new LLamaGrammarElement(LLamaGrammarElementType.CHAR_ALT, 9),
            new LLamaGrammarElement(LLamaGrammarElementType.CHAR_ALT, 10),
            new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 12),
            new LLamaGrammarElement(LLamaGrammarElementType.ALT, 0),
            new LLamaGrammarElement(LLamaGrammarElementType.END, 0)
        };

        AssertRules(expectedRules, state);
    }

    /// <summary>
    /// Checks that enumerating <c>state.SymbolIds</c> yields exactly the expected
    /// name/id pairs, in the expected order and with no entries missing or extra.
    /// </summary>
    /// <param name="expected">Expected pairs in enumeration order.</param>
    /// <param name="state">Parser output under test.</param>
    private static void AssertSymbolIds(IReadOnlyList<KeyValuePair<string, uint>> expected, ParseState state)
    {
        var index = 0;
        foreach (var it in state.SymbolIds)
        {
            string key = it.Key;
            uint value = it.Value;

            // Fail with a readable message instead of an ArgumentOutOfRangeException on extra symbols.
            Assert.True(index < expected.Count, $"unexpected extra symbol at index {index}: {key}, {value}");
            var expectedPair = expected[index];

            // pretty print error message before asserting
            if (expectedPair.Key != key || expectedPair.Value != value)
            {
                Console.Error.WriteLine($"expectedPair: {expectedPair.Key}, {expectedPair.Value}");
                Console.Error.WriteLine($"actualPair: {key}, {value}");
                Console.Error.WriteLine("expectedPair != actualPair");
            }
            Assert.Equal(expectedPair.Key, key);
            Assert.Equal(expectedPair.Value, value);

            index++;
        }

        // A parser that drops symbols must fail the test; this also covers the empty case.
        Assert.Equal(expected.Count, index);
    }

    /// <summary>
    /// Checks that the elements of all rules in <c>state.Rules</c>, taken in order,
    /// match the expected flat element list exactly, with none missing or extra.
    /// </summary>
    /// <param name="expectedRules">Expected elements of every rule, concatenated in rule order.</param>
    /// <param name="state">Parser output under test.</param>
    private static void AssertRules(IReadOnlyList<LLamaGrammarElement> expectedRules, ParseState state)
    {
        var index = 0;
        foreach (var rule in state.Rules)
        {
            // compare rule to expected rule
            for (var i = 0; i < rule.Count; i++)
            {
                var element = rule[i];

                // Fail with a readable message instead of an ArgumentOutOfRangeException on extra elements.
                Assert.True(index < expectedRules.Count, $"unexpected extra element at index {index}: {element.Type}, {element.Value}");
                var expectedElement = expectedRules[index];

                // Pretty print error message before asserting
                if (expectedElement.Type != element.Type || expectedElement.Value != element.Value)
                {
                    Console.Error.WriteLine($"index: {index}");
                    Console.Error.WriteLine($"expected_element: {expectedElement.Type}, {expectedElement.Value}");
                    Console.Error.WriteLine($"actual_element: {element.Type}, {element.Value}");
                    Console.Error.WriteLine("expected_element != actual_element");
                }
                Assert.Equal(expectedElement.Type, element.Type);
                Assert.Equal(expectedElement.Value, element.Value);
                index++;
            }
        }

        // A parser that drops elements must fail the test; this also covers the empty case.
        Assert.Equal(expectedRules.Count, index);
    }
}
} |
Oops, something went wrong.