diff --git a/LLama.Examples/Assets/json.gbnf b/LLama.Examples/Assets/json.gbnf
new file mode 100644
index 000000000..a01c4efd7
--- /dev/null
+++ b/LLama.Examples/Assets/json.gbnf
@@ -0,0 +1,27 @@
+# https://github.com/ggerganov/llama.cpp/blob/8183159cf3def112f6d1fe94815fce70e1bffa12/grammars/json.gbnf
+
+# Entry rule: the generated text must be a single JSON object.
+root ::= object
+# Any JSON value. The keyword literals are followed by `ws` here; the other
+# alternatives end with `ws` in their own rules.
+value ::= object | array | string | number | ("true" | "false" | "null") ws
+
+# "{" then an optional comma-separated list of `string ":" value` members, then "}".
+object ::=
+ "{" ws (
+ string ":" ws value
+ ("," ws string ":" ws value)*
+ )? "}" ws
+
+# "[" then an optional comma-separated list of values, then "]".
+array ::=
+ "[" ws (
+ value
+ ("," ws value)*
+ )? "]" ws
+
+# Double-quoted text: any character except '"' and '\', or a backslash escape
+# (single-character escape or "u" followed by four hex digits).
+string ::=
+ "\"" (
+ [^"\\] |
+ "\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F]) # escapes
+ )* "\"" ws
+
+# Optional "-", an integer part, an optional fraction and an optional exponent.
+number ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? ws
+
+# Optional space: by convention, applied in this grammar after literal chars when allowed
+# (right-recursive: zero or more of space, tab or newline).
+ws ::= ([ \t\n] ws)?
\ No newline at end of file
diff --git a/LLama.Examples/LLama.Examples.csproj b/LLama.Examples/LLama.Examples.csproj
index ef7ac437e..6a1685ed0 100644
--- a/LLama.Examples/LLama.Examples.csproj
+++ b/LLama.Examples/LLama.Examples.csproj
@@ -49,6 +49,9 @@
PreserveNewest
+
+ PreserveNewest
+
PreserveNewest
diff --git a/LLama.Examples/NewVersion/GrammarJsonResponse.cs b/LLama.Examples/NewVersion/GrammarJsonResponse.cs
new file mode 100644
index 000000000..926aa82b1
--- /dev/null
+++ b/LLama.Examples/NewVersion/GrammarJsonResponse.cs
@@ -0,0 +1,55 @@
+using LLama.Common;
+using LLama.Grammar;
+using LLama.Native;
+
+namespace LLama.Examples.NewVersion
+{
+ /// <summary>
+ /// Console example: loads the grammar in Assets/json.gbnf and passes it to the
+ /// inference parameters so that the answers are sampled under that grammar.
+ /// </summary>
+ public class GrammarJsonResponse
+ {
+     /// <summary>
+     /// Asks for a model path, then answers questions read from the console until
+     /// standard input is closed.
+     /// </summary>
+     public static void Run()
+     {
+         var gbnf = File.ReadAllText("Assets/json.gbnf").Trim();
+         var parser = new GrammarParser();
+
+         Console.Write("Please input your model path: ");
+         var modelPath = Console.ReadLine();
+         if (string.IsNullOrWhiteSpace(modelPath))
+         {
+             // ReadLine returns null once stdin is closed; there is nothing to load.
+             Console.WriteLine("No model path was provided.");
+             return;
+         }
+
+         var parameters = new ModelParams(modelPath)
+         {
+             ContextSize = 1024,
+             Seed = 1337,
+             GpuLayerCount = 5
+         };
+         using var model = LLamaWeights.LoadFromFile(parameters);
+         var ex = new StatelessExecutor(model, parameters);
+         ParseState state = parser.Parse(gbnf);
+         using var grammar = SafeLLamaGrammarHandle.Create(state.Rules, 0);
+
+         Console.ForegroundColor = ConsoleColor.Yellow;
+         Console.WriteLine("The executor has been enabled. In this example, the LLM will follow your instructions and always respond in a JSON format. For example, you can input \"Tell me the attributes of a good dish\"");
+         Console.ForegroundColor = ConsoleColor.White;
+
+         var inferenceParams = new InferenceParams()
+         {
+             Temperature = 0.6f,
+             AntiPrompts = new List<string> { "Question:", "#", "Question: ", ".\n" },
+             MaxTokens = 50,
+             Grammar = grammar
+         };
+
+         while (true)
+         {
+             Console.Write("\nQuestion: ");
+             Console.ForegroundColor = ConsoleColor.Green;
+             var question = Console.ReadLine();
+             Console.ForegroundColor = ConsoleColor.White;
+             if (question is null)
+             {
+                 // End of input: without this check the loop would run inference
+                 // on an empty question forever.
+                 break;
+             }
+
+             Console.Write("Answer: ");
+             var prompt = $"Question: {question.Trim()} Answer: ";
+             foreach (var text in ex.Infer(prompt, inferenceParams))
+             {
+                 Console.Write(text);
+             }
+         }
+     }
+ }
+}
diff --git a/LLama.Examples/NewVersion/TestRunner.cs b/LLama.Examples/NewVersion/TestRunner.cs
index 6cc3f3dac..f5a10ef4f 100644
--- a/LLama.Examples/NewVersion/TestRunner.cs
+++ b/LLama.Examples/NewVersion/TestRunner.cs
@@ -17,6 +17,7 @@ public static async Task Run()
Console.WriteLine("7: Get embeddings from LLama model.");
Console.WriteLine("8: Quantize the model.");
Console.WriteLine("9: Automatic conversation.");
+ Console.WriteLine("10: Constrain response to json format using grammar.");
while (true)
{
@@ -63,6 +64,10 @@ public static async Task Run()
{
await TalkToYourself.Run();
}
+ else if (choice == 10)
+ {
+ GrammarJsonResponse.Run();
+ }
else
{
Console.WriteLine("Cannot parse your choice. Please select again.");
diff --git a/LLama.Unittest/GrammarParserTest.cs b/LLama.Unittest/GrammarParserTest.cs
new file mode 100644
index 000000000..496f502ce
--- /dev/null
+++ b/LLama.Unittest/GrammarParserTest.cs
@@ -0,0 +1,274 @@
+using LLama.Common;
+using LLama.Native;
+using System.Diagnostics;
+using LLama.Grammar;
+using Newtonsoft.Json.Linq;
+
+namespace LLama.Unittest
+{
+ /// <summary>
+ /// Source:
+ /// https://github.com/ggerganov/llama.cpp/blob/6381d4e110bd0ec02843a60bbeb8b6fc37a9ace9/tests/test-grammar-parser.cpp
+ ///
+ /// The commit hash from URL is the actual commit hash that reflects current C# code.
+ /// </summary>
+ public sealed class GrammarParserTest
+ {
+     [Fact]
+     public void ParseComplexGrammar()
+     {
+         GrammarParser parsedGrammar = new GrammarParser();
+         string grammarBytes = @"root ::= (expr ""="" term ""\n"")+
+ expr ::= term ([-+*/] term)*
+ term ::= [0-9]+";
+
+         ParseState state = parsedGrammar.Parse(grammarBytes);
+
+         // Listed in the iteration order of state.SymbolIds (sorted by name).
+         var expected = new List<KeyValuePair<string, uint>>
+         {
+             Sym("expr", 2),
+             Sym("expr_5", 5),
+             Sym("expr_6", 6),
+             Sym("root", 0),
+             Sym("root_1", 1),
+             Sym("root_4", 4),
+             Sym("term", 3),
+             Sym("term_7", 7),
+         };
+         AssertSymbolIds(expected, state);
+
+         // All rules concatenated in rule-id order, one rule per line.
+         var expectedRules = new List<LLamaGrammarElement>
+         {
+             /* 0 root   */ Ref(4), End(),
+             /* 1 root_1 */ Ref(2), Chr(61), Ref(3), Chr(10), End(),
+             /* 2 expr   */ Ref(3), Ref(6), End(),
+             /* 3 term   */ Ref(7), End(),
+             /* 4 root_4 */ Ref(1), Ref(4), Alt(), Ref(1), End(),
+             /* 5 expr_5 */ Chr(45), ChrAlt(43), ChrAlt(42), ChrAlt(47), Ref(3), End(),
+             /* 6 expr_6 */ Ref(5), Ref(6), Alt(), End(),
+             /* 7 term_7 */ Chr(48), RngUpper(57), Ref(7), Alt(), Chr(48), RngUpper(57), End(),
+         };
+         AssertRules(expectedRules, state);
+     }
+
+     [Fact]
+     public void ParseExtraComplexGrammar()
+     {
+         GrammarParser parsedGrammar = new GrammarParser();
+         string grammarBytes = @"
+ root ::= (expr ""="" ws term ""\n"")+
+ expr ::= term ([-+*/] term)*
+ term ::= ident | num | ""("" ws expr "")"" ws
+ ident ::= [a-z] [a-z0-9_]* ws
+ num ::= [0-9]+ ws
+ ws ::= [ \t\n]*
+ ";
+
+         ParseState state = parsedGrammar.Parse(grammarBytes);
+
+         // Listed in the iteration order of state.SymbolIds (sorted by name).
+         var expected = new List<KeyValuePair<string, uint>>
+         {
+             Sym("expr", 2),
+             Sym("expr_6", 6),
+             Sym("expr_7", 7),
+             Sym("ident", 8),
+             Sym("ident_10", 10),
+             Sym("num", 9),
+             Sym("num_11", 11),
+             Sym("root", 0),
+             Sym("root_1", 1),
+             Sym("root_5", 5),
+             Sym("term", 4),
+             Sym("ws", 3),
+             Sym("ws_12", 12),
+         };
+         AssertSymbolIds(expected, state);
+
+         // All rules concatenated in rule-id order, one rule per line.
+         var expectedRules = new List<LLamaGrammarElement>
+         {
+             /*  0 root     */ Ref(5), End(),
+             /*  1 root_1   */ Ref(2), Chr(61), Ref(3), Ref(4), Chr(10), End(),
+             /*  2 expr     */ Ref(4), Ref(7), End(),
+             /*  3 ws       */ Ref(12), End(),
+             /*  4 term     */ Ref(8), Alt(), Ref(9), Alt(), Chr(40), Ref(3), Ref(2), Chr(41), Ref(3), End(),
+             /*  5 root_5   */ Ref(1), Ref(5), Alt(), Ref(1), End(),
+             /*  6 expr_6   */ Chr(45), ChrAlt(43), ChrAlt(42), ChrAlt(47), Ref(4), End(),
+             /*  7 expr_7   */ Ref(6), Ref(7), Alt(), End(),
+             /*  8 ident    */ Chr(97), RngUpper(122), Ref(10), Ref(3), End(),
+             /*  9 num      */ Ref(11), Ref(3), End(),
+             /* 10 ident_10 */ Chr(97), RngUpper(122), ChrAlt(48), RngUpper(57), ChrAlt(95), Ref(10), Alt(), End(),
+             /* 11 num_11   */ Chr(48), RngUpper(57), Ref(11), Alt(), Chr(48), RngUpper(57), End(),
+             /* 12 ws_12    */ Chr(32), ChrAlt(9), ChrAlt(10), Ref(12), Alt(), End(),
+         };
+         AssertRules(expectedRules, state);
+     }
+
+     // Checks that the symbol table holds exactly the expected (name, id) pairs, in order.
+     private static void AssertSymbolIds(List<KeyValuePair<string, uint>> expected, ParseState state)
+     {
+         // Without the count check a table that is too short would pass unnoticed.
+         Assert.Equal(expected.Count, state.SymbolIds.Count);
+
+         int index = 0;
+         foreach (var actual in state.SymbolIds)
+         {
+             var want = expected[index];
+             Assert.True(
+                 want.Key == actual.Key && want.Value == actual.Value,
+                 $"symbol {index}: expected ({want.Key}, {want.Value}) but was ({actual.Key}, {actual.Value})");
+             index++;
+         }
+     }
+
+     // Checks that the rules, flattened in rule-id order, equal the expected element list.
+     private static void AssertRules(List<LLamaGrammarElement> expected, ParseState state)
+     {
+         var actual = new List<LLamaGrammarElement>();
+         foreach (var rule in state.Rules)
+         {
+             actual.AddRange(rule);
+         }
+
+         // Without the count check missing trailing rules would pass unnoticed.
+         Assert.Equal(expected.Count, actual.Count);
+
+         for (int i = 0; i < expected.Count; i++)
+         {
+             Assert.True(
+                 expected[i].Type == actual[i].Type && expected[i].Value == actual[i].Value,
+                 $"element {i}: expected ({expected[i].Type}, {expected[i].Value}) but was ({actual[i].Type}, {actual[i].Value})");
+         }
+     }
+
+     private static KeyValuePair<string, uint> Sym(string name, uint id)
+     {
+         return new KeyValuePair<string, uint>(name, id);
+     }
+
+     private static LLamaGrammarElement Ref(uint ruleId)
+     {
+         return new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, ruleId);
+     }
+
+     private static LLamaGrammarElement End()
+     {
+         return new LLamaGrammarElement(LLamaGrammarElementType.END, 0);
+     }
+
+     private static LLamaGrammarElement Alt()
+     {
+         return new LLamaGrammarElement(LLamaGrammarElementType.ALT, 0);
+     }
+
+     private static LLamaGrammarElement Chr(uint codePoint)
+     {
+         return new LLamaGrammarElement(LLamaGrammarElementType.CHAR, codePoint);
+     }
+
+     private static LLamaGrammarElement ChrAlt(uint codePoint)
+     {
+         return new LLamaGrammarElement(LLamaGrammarElementType.CHAR_ALT, codePoint);
+     }
+
+     private static LLamaGrammarElement RngUpper(uint codePoint)
+     {
+         return new LLamaGrammarElement(LLamaGrammarElementType.CHAR_RNG_UPPER, codePoint);
+     }
+ }
+}
diff --git a/LLama/Exceptions/GrammarFormatException.cs b/LLama/Exceptions/GrammarFormatException.cs
new file mode 100644
index 000000000..5b1299dd4
--- /dev/null
+++ b/LLama/Exceptions/GrammarFormatException.cs
@@ -0,0 +1,21 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+
+namespace LLama.Exceptions
+{
+ /// <summary>
+ /// Signals that grammar text could not be parsed, or that a parsed rule is malformed.
+ /// </summary>
+ public class GrammarFormatException
+     : Exception
+ {
+     /// <summary>Creates the exception with the default message.</summary>
+     public GrammarFormatException()
+     {
+     }
+
+     /// <summary>Creates the exception with a description of the problem.</summary>
+     /// <param name="message">Description of what was expected and where.</param>
+     public GrammarFormatException(string message)
+         : base(message)
+     {
+     }
+
+     /// <summary>Creates the exception with a description and the underlying cause.</summary>
+     /// <param name="message">Description of what was expected and where.</param>
+     /// <param name="innerException">The exception that led to this one.</param>
+     public GrammarFormatException(string message, Exception innerException)
+         : base(message, innerException)
+     {
+     }
+ }
+}
diff --git a/LLama/Grammar/GrammarParser.cs b/LLama/Grammar/GrammarParser.cs
new file mode 100644
index 000000000..4122e58f6
--- /dev/null
+++ b/LLama/Grammar/GrammarParser.cs
@@ -0,0 +1,388 @@
+using LLama.Exceptions;
+using LLama.Native;
+using System;
+using System.Collections.Generic;
+using System.Text;
+
+namespace LLama.Grammar
+{
+ /// <summary>
+ /// Source:
+ /// https://github.com/ggerganov/llama.cpp/blob/6381d4e110bd0ec02843a60bbeb8b6fc37a9ace9/common/grammar-parser.cpp
+ ///
+ /// The commit hash in the URL is the llama.cpp commit that this C# code mirrors.
+ /// </summary>
+ public class GrammarParser
+ {
+ // Byte length of a UTF-8 sequence, indexed by the top four bits of its first byte.
+ // Hoisted to a static field so the table is not re-allocated on every call.
+ private static readonly int[] Utf8LengthByHighNibble = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 };
+
+ /// <summary>
+ /// Decodes one UTF-8 encoded code point from the front of <paramref name="src"/>
+ /// and advances <paramref name="src"/> past the bytes that were consumed.
+ /// </summary>
+ /// <remarks>
+ /// NOTE: assumes valid utf8 (but checks for overrun); copied from llama.cpp.
+ /// <paramref name="src"/> must not be empty.
+ /// </remarks>
+ /// <param name="src">Remaining input; replaced with the unread tail on return.</param>
+ /// <returns>The decoded code point.</returns>
+ private uint DecodeUTF8(ref ReadOnlySpan<byte> src)
+ {
+     byte firstByte = src[0];
+     int len = Utf8LengthByHighNibble[firstByte >> 4];
+
+     // Keep only the payload bits of the first byte (7, 5, 4 or 3 of them).
+     byte mask = (byte)((1 << (8 - len)) - 1);
+     uint value = (uint)(firstByte & mask);
+
+     // Each following byte contributes its low six bits; stop early if input runs out.
+     int pos = 1;
+     while (pos < len && pos < src.Length)
+     {
+         value = (uint)((value << 6) + (src[pos] & 0x3F));
+         pos++;
+     }
+
+     src = src.Slice(pos);
+     return value;
+ }
+
+ /// <summary>
+ /// Returns the id registered for a symbol name, registering the name with the
+ /// next free id when it has not been seen before.
+ /// </summary>
+ /// <param name="state">Parse state that owns the symbol table.</param>
+ /// <param name="src">Input that starts with the symbol name.</param>
+ /// <param name="len">
+ /// Number of bytes at the END of <paramref name="src"/> that are not part of the
+ /// name, i.e. the name is the first <c>src.Length - len</c> bytes.
+ /// </param>
+ /// <returns>The existing or newly assigned id.</returns>
+ private uint GetSymbolId(ParseState state, ReadOnlySpan<byte> src, int len)
+ {
+     string key = Encoding.UTF8.GetString(src.Slice(0, src.Length - len).ToArray());
+
+     if (state.SymbolIds.TryGetValue(key, out uint existingId))
+     {
+         return existingId;
+     }
+
+     // Ids are handed out densely: the next id is the current table size.
+     uint nextId = (uint)state.SymbolIds.Count;
+     state.SymbolIds[key] = nextId;
+     return nextId;
+ }
+
+ /// <summary>
+ /// Registers a fresh synthetic symbol named <c>{baseName}_{id}</c>, where the id is
+ /// the current size of the symbol table, and returns that id.
+ /// </summary>
+ /// <param name="state">Parse state that owns the symbol table.</param>
+ /// <param name="baseName">Name of the rule the synthetic symbol is derived from.</param>
+ /// <returns>The id assigned to the new symbol.</returns>
+ private uint GenerateSymbolId(ParseState state, string baseName)
+ {
+     uint id = (uint)state.SymbolIds.Count;
+     state.SymbolIds[baseName + "_" + id] = id;
+     return id;
+ }
+
+ /// <summary>
+ /// Stores <paramref name="rule"/> at index <paramref name="ruleId"/> of the rule
+ /// list, first padding the list with empty rules if it is too short.
+ /// </summary>
+ /// <param name="state">Parse state that owns the rule list.</param>
+ /// <param name="ruleId">Index at which to store the rule.</param>
+ /// <param name="rule">Elements of the rule.</param>
+ private void AddRule(ParseState state, uint ruleId, List<LLamaGrammarElement> rule)
+ {
+     // Rules can be defined out of id order, so earlier slots may not exist yet.
+     while (state.Rules.Count <= ruleId)
+     {
+         state.Rules.Add(new List<LLamaGrammarElement>());
+     }
+
+     state.Rules[(int)ruleId] = rule;
+ }
+
+ /// <summary>
+ /// Tells whether a byte may appear in a rule name: an ASCII letter, an ASCII digit or '-'.
+ /// </summary>
+ private bool IsWordChar(byte c)
+ {
+     if (c == '-')
+     {
+         return true;
+     }
+
+     bool isLower = c >= 'a' && c <= 'z';
+     bool isUpper = c >= 'A' && c <= 'Z';
+     bool isDigit = c >= '0' && c <= '9';
+     return isLower || isUpper || isDigit;
+ }
+
+ /// <summary>
+ /// Reads exactly <paramref name="size"/> hexadecimal digits from the front of
+ /// <paramref name="src"/> and advances <paramref name="src"/> past them.
+ /// </summary>
+ /// <param name="src">Remaining input; replaced with the unread tail on success.</param>
+ /// <param name="size">Number of hex digits required.</param>
+ /// <returns>The value of the digits.</returns>
+ /// <exception cref="GrammarFormatException">
+ /// Fewer than <paramref name="size"/> hex digits are available.
+ /// </exception>
+ private uint ParseHex(ref ReadOnlySpan<byte> src, int size)
+ {
+     uint value = 0;
+     int pos = 0;
+
+     for (; pos < size && pos < src.Length; pos++)
+     {
+         byte c = src[pos];
+         uint digit;
+         if ('0' <= c && c <= '9')
+         {
+             digit = (uint)(c - '0');
+         }
+         else if ('a' <= c && c <= 'f')
+         {
+             digit = (uint)(c - 'a' + 10);
+         }
+         else if ('A' <= c && c <= 'F')
+         {
+             digit = (uint)(c - 'A' + 10);
+         }
+         else
+         {
+             // Not a hex digit: stop here, the length check below reports the error.
+             break;
+         }
+
+         value = (value << 4) + digit;
+     }
+
+     if (pos != size)
+     {
+         throw new GrammarFormatException($"Expecting {size} hex chars at {Encoding.UTF8.GetString(src.ToArray())}");
+     }
+
+     src = src.Slice(pos);
+     return value;
+ }
+
+ /// <summary>
+ /// Skips spaces, tabs and '#' comments at the front of <paramref name="src"/>;
+ /// line breaks are skipped too when <paramref name="newlineOk"/> is true.
+ /// </summary>
+ /// <param name="src">Remaining input.</param>
+ /// <param name="newlineOk">Whether '\r' and '\n' count as skippable space.</param>
+ /// <returns>The input starting at the first byte that was not skipped.</returns>
+ private ReadOnlySpan<byte> ParseSpace(ReadOnlySpan<byte> src, bool newlineOk)
+ {
+     int pos = 0;
+     while (pos < src.Length)
+     {
+         byte c = src[pos];
+         if (c == '#')
+         {
+             // A comment runs up to, but not including, the line break; whether the
+             // line break itself is skipped is decided on the next iteration.
+             while (pos < src.Length && src[pos] != '\r' && src[pos] != '\n')
+             {
+                 pos++;
+             }
+         }
+         else if (c == ' ' || c == '\t' || (newlineOk && (c == '\r' || c == '\n')))
+         {
+             pos++;
+         }
+         else
+         {
+             break;
+         }
+     }
+
+     return src.Slice(pos);
+ }
+
+ /// <summary>
+ /// Skips the rule name at the front of <paramref name="src"/>.
+ /// </summary>
+ /// <param name="src">Input that is expected to start with a name.</param>
+ /// <returns>The input starting right after the name.</returns>
+ /// <exception cref="GrammarFormatException">
+ /// <paramref name="src"/> does not start with a name character.
+ /// </exception>
+ private ReadOnlySpan<byte> ParseName(ReadOnlySpan<byte> src)
+ {
+     int nameLength = 0;
+     while (nameLength < src.Length && IsWordChar(src[nameLength]))
+     {
+         nameLength++;
+     }
+
+     if (nameLength == 0)
+     {
+         throw new GrammarFormatException($"Expecting name at {Encoding.UTF8.GetString(src.ToArray())}");
+     }
+
+     return src.Slice(nameLength);
+ }
+
+ /// <summary>
+ /// Reads one character (a UTF-8 code point or a backslash escape) from the front
+ /// of <paramref name="src"/> and advances <paramref name="src"/> past it.
+ /// </summary>
+ /// <param name="src">Remaining input; replaced with the unread tail on return.</param>
+ /// <returns>The code point of the character.</returns>
+ /// <exception cref="GrammarFormatException">
+ /// The input is empty, ends right after a backslash, or uses an unknown escape.
+ /// </exception>
+ private uint ParseChar(ref ReadOnlySpan<byte> src)
+ {
+     // Check for exhaustion before indexing; otherwise an empty span raises
+     // IndexOutOfRangeException instead of a grammar error.
+     if (src.IsEmpty)
+     {
+         throw new GrammarFormatException("Unexpected end of input");
+     }
+
+     if (src[0] != '\\')
+     {
+         return DecodeUTF8(ref src);
+     }
+
+     if (src.Length < 2)
+     {
+         // A backslash with nothing after it.
+         throw new GrammarFormatException("Unexpected end of input");
+     }
+
+     // Remember where the escape starts so that an error can show it.
+     ReadOnlySpan<byte> escapeStart = src;
+     byte chr = src[1];
+     src = src.Slice(2);
+
+     switch (chr)
+     {
+         case (byte)'x':
+             return ParseHex(ref src, 2);
+         case (byte)'u':
+             return ParseHex(ref src, 4);
+         case (byte)'U':
+             return ParseHex(ref src, 8);
+         case (byte)'t':
+             return '\t';
+         case (byte)'r':
+             return '\r';
+         case (byte)'n':
+             return '\n';
+         case (byte)'\\':
+         case (byte)'"':
+         case (byte)'[':
+         case (byte)']':
+             return chr;
+         default:
+             throw new GrammarFormatException("Unknown escape at " + Encoding.UTF8.GetString(escapeStart.ToArray()));
+     }
+ }
+
+ /// <summary>
+ /// Parses one sequence (the items between '|' separators) and appends its
+ /// elements to <paramref name="outElements"/>. Groups and the repetition
+ /// operators '*', '+' and '?' are turned into synthesized rules that the
+ /// sequence references.
+ /// </summary>
+ /// <param name="state">Parse state receiving synthesized symbols and rules.</param>
+ /// <param name="pos">Input positioned at the start of the sequence.</param>
+ /// <param name="ruleName">Name of the enclosing rule; base name for synthesized rules.</param>
+ /// <param name="outElements">List the parsed elements are appended to.</param>
+ /// <param name="isNested">True inside parentheses, where line breaks count as space.</param>
+ /// <returns>The input starting at the first byte that is not part of the sequence.</returns>
+ /// <exception cref="GrammarFormatException">
+ /// A literal, character class or group is not terminated, or a repetition
+ /// operator has nothing to apply to.
+ /// </exception>
+ private ReadOnlySpan<byte> ParseSequence(
+     ParseState state,
+     ReadOnlySpan<byte> pos,
+     string ruleName,
+     List<LLamaGrammarElement> outElements,
+     bool isNested)
+ {
+     // Index in outElements where the most recent item starts; a repetition
+     // operator applies to everything from here to the end of the list.
+     int lastSymStart = outElements.Count;
+
+     while (!pos.IsEmpty)
+     {
+         if (pos[0] == '"') // literal string
+         {
+             pos = pos.Slice(1);
+             lastSymStart = outElements.Count;
+
+             while (!pos.IsEmpty && pos[0] != '"')
+             {
+                 uint literalChar = ParseChar(ref pos);
+                 outElements.Add(new LLamaGrammarElement { Type = LLamaGrammarElementType.CHAR, Value = literalChar });
+             }
+
+             if (pos.IsEmpty)
+             {
+                 // The closing quote is missing; slicing past it would throw ArgumentOutOfRangeException.
+                 throw new GrammarFormatException("Unexpected end of input");
+             }
+             pos = ParseSpace(pos.Slice(1), isNested);
+         }
+         else if (pos[0] == '[') // char range(s)
+         {
+             pos = pos.Slice(1);
+             var startType = LLamaGrammarElementType.CHAR;
+
+             if (!pos.IsEmpty && pos[0] == '^')
+             {
+                 pos = pos.Slice(1);
+                 startType = LLamaGrammarElementType.CHAR_NOT;
+             }
+
+             lastSymStart = outElements.Count;
+
+             while (!pos.IsEmpty && pos[0] != ']')
+             {
+                 uint rangeChar = ParseChar(ref pos);
+                 // The first element carries CHAR or CHAR_NOT, every later one is an alternative.
+                 var type = lastSymStart < outElements.Count ? LLamaGrammarElementType.CHAR_ALT : startType;
+                 outElements.Add(new LLamaGrammarElement { Type = type, Value = rangeChar });
+
+                 // "a-z" is a range, but a '-' right before ']' is a literal '-'.
+                 if (pos.Length >= 2 && pos[0] == '-' && pos[1] != ']')
+                 {
+                     pos = pos.Slice(1);
+                     uint upperChar = ParseChar(ref pos);
+                     outElements.Add(new LLamaGrammarElement { Type = LLamaGrammarElementType.CHAR_RNG_UPPER, Value = upperChar });
+                 }
+             }
+
+             if (pos.IsEmpty)
+             {
+                 // The closing bracket is missing.
+                 throw new GrammarFormatException("Unexpected end of input");
+             }
+             pos = ParseSpace(pos.Slice(1), isNested);
+         }
+         else if (IsWordChar(pos[0])) // rule reference
+         {
+             var nameEnd = ParseName(pos);
+             uint refRuleId = GetSymbolId(state, pos, nameEnd.Length);
+             pos = ParseSpace(nameEnd, isNested);
+             lastSymStart = outElements.Count;
+             outElements.Add(new LLamaGrammarElement { Type = LLamaGrammarElementType.RULE_REF, Value = refRuleId });
+         }
+         else if (pos[0] == '(') // grouping
+         {
+             // parse nested alternates into synthesized rule
+             pos = ParseSpace(pos.Slice(1), true);
+             uint groupRuleId = GenerateSymbolId(state, ruleName);
+             pos = ParseAlternates(state, pos, ruleName, groupRuleId, true);
+             lastSymStart = outElements.Count;
+             // output reference to synthesized rule
+             outElements.Add(new LLamaGrammarElement { Type = LLamaGrammarElementType.RULE_REF, Value = groupRuleId });
+
+             if (pos.IsEmpty || pos[0] != ')')
+             {
+                 throw new GrammarFormatException($"Expecting ')' at {Encoding.UTF8.GetString(pos.ToArray())}");
+             }
+             pos = ParseSpace(pos.Slice(1), isNested);
+         }
+         else if (pos[0] == '*' || pos[0] == '+' || pos[0] == '?') // repetition operator
+         {
+             if (lastSymStart == outElements.Count)
+             {
+                 throw new GrammarFormatException($"Expecting preceding item to */+/? at {Encoding.UTF8.GetString(pos.ToArray())}");
+             }
+
+             byte op = pos[0];
+             int symLength = outElements.Count - lastSymStart;
+
+             // apply transformation to previous symbol (lastSymStart to end) according to
+             // rewrite rules:
+             // S* --> S' ::= S S' |
+             // S+ --> S' ::= S S' | S
+             // S? --> S' ::= S |
+             uint repeatRuleId = GenerateSymbolId(state, ruleName);
+             var repeatRule = new List<LLamaGrammarElement>();
+
+             // add preceding symbol to generated rule
+             repeatRule.AddRange(outElements.GetRange(lastSymStart, symLength));
+
+             if (op == '*' || op == '+')
+             {
+                 // cause generated rule to recurse
+                 repeatRule.Add(new LLamaGrammarElement { Type = LLamaGrammarElementType.RULE_REF, Value = repeatRuleId });
+             }
+
+             // mark start of alternate def
+             repeatRule.Add(new LLamaGrammarElement { Type = LLamaGrammarElementType.ALT, Value = 0 });
+
+             if (op == '+')
+             {
+                 // add preceding symbol as alternate only for '+' (otherwise empty)
+                 repeatRule.AddRange(outElements.GetRange(lastSymStart, symLength));
+             }
+
+             repeatRule.Add(new LLamaGrammarElement { Type = LLamaGrammarElementType.END, Value = 0 });
+             AddRule(state, repeatRuleId, repeatRule);
+
+             // in original rule, replace previous symbol with reference to generated rule
+             outElements.RemoveRange(lastSymStart, symLength);
+             outElements.Add(new LLamaGrammarElement { Type = LLamaGrammarElementType.RULE_REF, Value = repeatRuleId });
+
+             pos = ParseSpace(pos.Slice(1), isNested);
+         }
+         else
+         {
+             break;
+         }
+     }
+
+     return pos;
+ }
+
+ /// <summary>
+ /// Parses one or more sequences separated by '|' and stores them as rule
+ /// <paramref name="ruleId"/>, with an ALT element between alternatives and an
+ /// END element at the end.
+ /// </summary>
+ /// <param name="state">Parse state receiving the rule.</param>
+ /// <param name="src">Input positioned at the first alternative.</param>
+ /// <param name="ruleName">Name of the enclosing rule; base name for synthesized rules.</param>
+ /// <param name="ruleId">Id under which the parsed rule is stored.</param>
+ /// <param name="isNested">True inside parentheses, where line breaks count as space.</param>
+ /// <returns>The input starting after the last alternative.</returns>
+ private ReadOnlySpan<byte> ParseAlternates(
+     ParseState state,
+     ReadOnlySpan<byte> src,
+     string ruleName,
+     uint ruleId,
+     bool isNested)
+ {
+     var rule = new List<LLamaGrammarElement>();
+     ReadOnlySpan<byte> pos = ParseSequence(state, src, ruleName, rule, isNested);
+
+     while (!pos.IsEmpty && pos[0] == '|')
+     {
+         rule.Add(new LLamaGrammarElement { Type = LLamaGrammarElementType.ALT, Value = 0 });
+         // Line breaks are always allowed right after '|'.
+         pos = ParseSpace(pos.Slice(1), true);
+         pos = ParseSequence(state, pos, ruleName, rule, isNested);
+     }
+
+     rule.Add(new LLamaGrammarElement { Type = LLamaGrammarElementType.END, Value = 0 });
+     AddRule(state, ruleId, rule);
+
+     return pos;
+ }
+
+ /// <summary>
+ /// Parses one <c>name ::= alternates</c> definition and the line break that ends it.
+ /// </summary>
+ /// <param name="state">Parse state receiving the symbol and the rule.</param>
+ /// <param name="src">Input positioned at the rule name.</param>
+ /// <returns>The input after the rule and any following space, comments and line breaks.</returns>
+ /// <exception cref="GrammarFormatException">
+ /// The name, the "::=" token or the terminating line break is missing.
+ /// </exception>
+ private ReadOnlySpan<byte> ParseRule(ParseState state, ReadOnlySpan<byte> src)
+ {
+     ReadOnlySpan<byte> nameEnd = ParseName(src);
+     ReadOnlySpan<byte> nameBytes = src.Slice(0, src.Length - nameEnd.Length);
+     uint ruleId = GetSymbolId(state, nameBytes, 0);
+     string name = Encoding.UTF8.GetString(nameBytes.ToArray());
+
+     ReadOnlySpan<byte> pos = ParseSpace(nameEnd, false);
+
+     // Check the length first: input may end right after the name.
+     if (pos.Length < 3 || pos[0] != ':' || pos[1] != ':' || pos[2] != '=')
+     {
+         throw new GrammarFormatException($"Expecting ::= at {Encoding.UTF8.GetString(pos.ToArray())}");
+     }
+     pos = ParseSpace(pos.Slice(3), true);
+
+     pos = ParseAlternates(state, pos, name, ruleId, false);
+
+     if (!pos.IsEmpty)
+     {
+         if (pos[0] == '\r')
+         {
+             // Accept both "\r\n" and a lone "\r", including "\r" as the very last byte.
+             bool followedByLineFeed = pos.Length > 1 && pos[1] == '\n';
+             pos = pos.Slice(followedByLineFeed ? 2 : 1);
+         }
+         else if (pos[0] == '\n')
+         {
+             pos = pos.Slice(1);
+         }
+         else
+         {
+             throw new GrammarFormatException($"Expecting newline or end at {Encoding.UTF8.GetString(pos.ToArray())}");
+         }
+     }
+
+     return ParseSpace(pos, true);
+ }
+
+ /// <summary>
+ /// Parses grammar text into a symbol table and a list of rules.
+ /// </summary>
+ /// <param name="input">The grammar text.</param>
+ /// <returns>The state holding every symbol and rule found in <paramref name="input"/>.</returns>
+ /// <exception cref="ArgumentNullException"><paramref name="input"/> is null.</exception>
+ /// <exception cref="GrammarFormatException">The text is not a well-formed grammar.</exception>
+ public ParseState Parse(string input)
+ {
+     if (input == null)
+     {
+         throw new ArgumentNullException(nameof(input));
+     }
+
+     // The parser works on UTF-8 bytes.
+     ReadOnlySpan<byte> src = new ReadOnlySpan<byte>(Encoding.UTF8.GetBytes(input));
+     ParseState state = new ParseState();
+
+     // Skip leading space and comments, then read rules until the input is used up.
+     ReadOnlySpan<byte> pos = ParseSpace(src, true);
+     while (!pos.IsEmpty)
+     {
+         pos = ParseRule(state, pos);
+     }
+
+     return state;
+ }
+ }
+}
diff --git a/LLama/Grammar/ParseState.cs b/LLama/Grammar/ParseState.cs
new file mode 100644
index 000000000..0c75a8a00
--- /dev/null
+++ b/LLama/Grammar/ParseState.cs
@@ -0,0 +1,181 @@
+using LLama.Exceptions;
+using LLama.Native;
+using System;
+using System.Collections.Generic;
+using System.IO;
+
+namespace LLama.Grammar
+{
+ /// <summary>
+ /// Source:
+ /// https://github.com/ggerganov/llama.cpp/blob/6381d4e110bd0ec02843a60bbeb8b6fc37a9ace9/common/grammar-parser.h
+ ///
+ /// The commit hash in the URL is the llama.cpp commit that this C# code mirrors.
+ /// </summary>
+ public class ParseState
+ {
+ /// <summary>
+ /// Rule name to rule id. Ordered by ordinal (byte-wise) name comparison so that
+ /// the iteration order does not depend on the current culture.
+ /// </summary>
+ public SortedDictionary<string, uint> SymbolIds { get; } = new SortedDictionary<string, uint>(StringComparer.Ordinal);
+
+ /// <summary>
+ /// Rule bodies, indexed by rule id.
+ /// </summary>
+ public List<List<LLamaGrammarElement>> Rules { get; } = new List<List<LLamaGrammarElement>>();
+
+ /// <summary>
+ /// Enumerates the rules in rule-id order.
+ /// </summary>
+ /// <returns>A lazily evaluated sequence of the entries of <see cref="Rules"/>.</returns>
+ public IEnumerable<List<LLamaGrammarElement>> CRules()
+ {
+     foreach (List<LLamaGrammarElement> rule in Rules)
+     {
+         yield return rule;
+     }
+ }
+
+ /// <summary>
+ /// Writes every rule of <paramref name="state"/> to <paramref name="file"/> in
+ /// <c>name ::= ...</c> form. Failures are reported on standard error rather
+ /// than thrown.
+ /// </summary>
+ /// <param name="file">Destination writer.</param>
+ /// <param name="state">State whose symbols and rules are printed.</param>
+ public void PrintGrammar(StreamWriter file, ParseState state)
+ {
+     try
+     {
+         // Invert the symbol table so that rule ids can be shown as names.
+         var symbolIdNames = new Dictionary<uint, string>();
+         foreach (KeyValuePair<string, uint> symbol in state.SymbolIds)
+         {
+             symbolIdNames[symbol.Value] = symbol.Key;
+         }
+
+         for (int ruleId = 0; ruleId < state.Rules.Count; ruleId++)
+         {
+             PrintRule(file, (uint)ruleId, state.Rules[ruleId], symbolIdNames);
+         }
+     }
+     catch (Exception err)
+     {
+         Console.Error.WriteLine($"\nError printing grammar: {err.Message}");
+     }
+ }
+
+ /// <summary>
+ /// Writes the raw elements of a rule on one line as <c>TYPE(value)</c> items.
+ /// Character elements show their character in quotes; the others show the
+ /// numeric value.
+ /// </summary>
+ /// <param name="file">Destination writer.</param>
+ /// <param name="rule">Elements to print.</param>
+ public void PrintRuleBinary(StreamWriter file, List<LLamaGrammarElement> rule)
+ {
+     foreach (LLamaGrammarElement elem in rule)
+     {
+         string label;
+         bool isChar;
+         switch (elem.Type)
+         {
+             case LLamaGrammarElementType.END: label = "END"; isChar = false; break;
+             case LLamaGrammarElementType.ALT: label = "ALT"; isChar = false; break;
+             case LLamaGrammarElementType.RULE_REF: label = "RULE_REF"; isChar = false; break;
+             case LLamaGrammarElementType.CHAR: label = "CHAR"; isChar = true; break;
+             case LLamaGrammarElementType.CHAR_NOT: label = "CHAR_NOT"; isChar = true; break;
+             case LLamaGrammarElementType.CHAR_RNG_UPPER: label = "CHAR_RNG_UPPER"; isChar = true; break;
+             case LLamaGrammarElementType.CHAR_ALT: label = "CHAR_ALT"; isChar = true; break;
+             default:
+                 // Unknown element types print nothing.
+                 continue;
+         }
+
+         file.Write(label);
+         if (isChar)
+         {
+             file.Write("(\"");
+             PrintGrammarChar(file, elem.Value);
+             file.Write("\") ");
+         }
+         else
+         {
+             file.Write($"({elem.Value}) ");
+         }
+     }
+     file.WriteLine();
+ }
+
+ /// <summary>
+ /// Writes one rule as <c>name ::= ...</c>, printing character elements as
+ /// bracketed character classes and rule references by name.
+ /// </summary>
+ /// <param name="file">Destination writer.</param>
+ /// <param name="ruleId">Id of the rule, used to look up its name.</param>
+ /// <param name="rule">Elements of the rule; the last one must be END.</param>
+ /// <param name="symbolIdNames">Rule id to rule name.</param>
+ /// <exception cref="GrammarFormatException">
+ /// The rule does not end with END, contains END before its last element, or has
+ /// a CHAR_RNG_UPPER or CHAR_ALT element that does not follow a character element.
+ /// </exception>
+ private void PrintRule(
+     StreamWriter file,
+     uint ruleId,
+     List<LLamaGrammarElement> rule,
+     Dictionary<uint, string> symbolIdNames)
+ {
+     if (rule.Count == 0 || rule[rule.Count - 1].Type != LLamaGrammarElementType.END)
+     {
+         throw new GrammarFormatException(
+             $"Malformed rule, does not end with LLamaGrammarElementType.END: {ruleId}");
+     }
+
+     file.Write($"{symbolIdNames[ruleId]} ::= ");
+
+     // The final END element is not printed, so rule[i + 1] below is always valid.
+     int lastIndex = rule.Count - 1;
+     for (int i = 0; i < lastIndex; i++)
+     {
+         LLamaGrammarElement elem = rule[i];
+         switch (elem.Type)
+         {
+             case LLamaGrammarElementType.END:
+                 throw new GrammarFormatException(
+                     $"Unexpected end of rule: {ruleId}, {i}");
+             case LLamaGrammarElementType.ALT:
+                 file.Write("| ");
+                 break;
+             case LLamaGrammarElementType.RULE_REF:
+                 file.Write($"{symbolIdNames[elem.Value]} ");
+                 break;
+             case LLamaGrammarElementType.CHAR:
+                 file.Write("[");
+                 PrintGrammarChar(file, elem.Value);
+                 break;
+             case LLamaGrammarElementType.CHAR_NOT:
+                 file.Write("[^");
+                 PrintGrammarChar(file, elem.Value);
+                 break;
+             case LLamaGrammarElementType.CHAR_RNG_UPPER:
+                 if (i == 0 || !IsCharElement(rule[i - 1]))
+                 {
+                     throw new GrammarFormatException(
+                         $"LLamaGrammarElementType.CHAR_RNG_UPPER without preceding char: {ruleId},{i}");
+                 }
+                 file.Write("-");
+                 PrintGrammarChar(file, elem.Value);
+                 break;
+             case LLamaGrammarElementType.CHAR_ALT:
+                 if (i == 0 || !IsCharElement(rule[i - 1]))
+                 {
+                     throw new GrammarFormatException(
+                         $"LLamaGrammarElementType.CHAR_ALT without preceding char: {ruleId},{i}");
+                 }
+                 PrintGrammarChar(file, elem.Value);
+                 break;
+         }
+
+         if (IsCharElement(elem))
+         {
+             // Close the bracket unless the next element continues the same class.
+             LLamaGrammarElementType nextType = rule[i + 1].Type;
+             if (nextType != LLamaGrammarElementType.CHAR_ALT
+                 && nextType != LLamaGrammarElementType.CHAR_RNG_UPPER)
+             {
+                 file.Write("] ");
+             }
+         }
+     }
+     file.WriteLine();
+ }
+
+ /// <summary>
+ /// Writes one code point: as the character itself when it lies in 0x20..0x7F,
+ /// otherwise as a <c>&lt;U+XXXX&gt;</c> placeholder.
+ /// </summary>
+ /// <param name="file">Destination writer.</param>
+ /// <param name="c">Code point to write.</param>
+ private void PrintGrammarChar(StreamWriter file, uint c)
+ {
+     if (c >= 0x20 && c <= 0x7F)
+     {
+         file.Write((char)c);
+     }
+     else
+     {
+         // cop out of encoding UTF-8: show the code point in hex (at least four
+         // digits) instead of writing nothing.
+         file.Write($"<U+{c:X4}>");
+     }
+ }
+
+ /// <summary>
+ /// Tells whether an element is one of the character kinds: CHAR, CHAR_NOT,
+ /// CHAR_ALT or CHAR_RNG_UPPER.
+ /// </summary>
+ private bool IsCharElement(LLamaGrammarElement elem)
+ {
+     LLamaGrammarElementType kind = elem.Type;
+     return kind == LLamaGrammarElementType.CHAR
+         || kind == LLamaGrammarElementType.CHAR_NOT
+         || kind == LLamaGrammarElementType.CHAR_ALT
+         || kind == LLamaGrammarElementType.CHAR_RNG_UPPER;
+ }
+ }
+}
diff --git a/LLama/Native/LLamaGrammarElement.cs b/LLama/Native/LLamaGrammarElement.cs
index d097628f7..7c321c5d2 100644
--- a/LLama/Native/LLamaGrammarElement.cs
+++ b/LLama/Native/LLamaGrammarElement.cs
@@ -1,4 +1,5 @@
-using System.Runtime.InteropServices;
+using System.Diagnostics;
+using System.Runtime.InteropServices;
namespace LLama.Native
{
@@ -49,6 +50,7 @@ public enum LLamaGrammarElementType
/// An element of a grammar
///
[StructLayout(LayoutKind.Sequential)]
+ [DebuggerDisplay("{Type} {Value}")]
public struct LLamaGrammarElement
{
///