diff --git a/LLama.Examples/Assets/json.gbnf b/LLama.Examples/Assets/json.gbnf new file mode 100644 index 000000000..a01c4efd7 --- /dev/null +++ b/LLama.Examples/Assets/json.gbnf @@ -0,0 +1,27 @@ +# https://github.com/ggerganov/llama.cpp/blob/8183159cf3def112f6d1fe94815fce70e1bffa12/grammars/json.gbnf + +root ::= object +value ::= object | array | string | number | ("true" | "false" | "null") ws + +object ::= + "{" ws ( + string ":" ws value + ("," ws string ":" ws value)* + )? "}" ws + +array ::= + "[" ws ( + value + ("," ws value)* + )? "]" ws + +string ::= + "\"" ( + [^"\\] | + "\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F]) # escapes + )* "\"" ws + +number ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? ws + +# Optional space: by convention, applied in this grammar after literal chars when allowed +ws ::= ([ \t\n] ws)? \ No newline at end of file diff --git a/LLama.Examples/LLama.Examples.csproj b/LLama.Examples/LLama.Examples.csproj index ef7ac437e..6a1685ed0 100644 --- a/LLama.Examples/LLama.Examples.csproj +++ b/LLama.Examples/LLama.Examples.csproj @@ -49,6 +49,9 @@ PreserveNewest + + PreserveNewest + PreserveNewest diff --git a/LLama.Examples/NewVersion/GrammarJsonResponse.cs b/LLama.Examples/NewVersion/GrammarJsonResponse.cs new file mode 100644 index 000000000..926aa82b1 --- /dev/null +++ b/LLama.Examples/NewVersion/GrammarJsonResponse.cs @@ -0,0 +1,55 @@ +using LLama.Common; +using LLama.Grammar; +using LLama.Native; + +namespace LLama.Examples.NewVersion +{ + public class GrammarJsonResponse + { + public static void Run() + { + var grammarBytes = File.ReadAllText("Assets/json.gbnf").Trim(); + var parsedGrammar = new GrammarParser(); + + Console.Write("Please input your model path: "); + var modelPath = Console.ReadLine(); + + var parameters = new ModelParams(modelPath) + { + ContextSize = 1024, + Seed = 1337, + GpuLayerCount = 5 + }; + using var model = LLamaWeights.LoadFromFile(parameters); + var ex = new StatelessExecutor(model, parameters); + ParseState state = parsedGrammar.Parse(grammarBytes); + using var grammar = SafeLLamaGrammarHandle.Create(state.Rules, 0); + + Console.ForegroundColor = ConsoleColor.Yellow; + Console.WriteLine("The executor has been enabled. In this example, the LLM will follow your instructions and always respond in a JSON format. For example, you can input \"Tell me the attributes of a good dish\""); + Console.ForegroundColor = ConsoleColor.White; + + var inferenceParams = new InferenceParams() + { + Temperature = 0.6f, + AntiPrompts = new List { "Question:", "#", "Question: ", ".\n" }, + MaxTokens = 50, + Grammar = grammar + }; + + while (true) + { + Console.Write("\nQuestion: "); + Console.ForegroundColor = ConsoleColor.Green; + var prompt = Console.ReadLine(); + Console.ForegroundColor = ConsoleColor.White; + Console.Write("Answer: "); + prompt = $"Question: {prompt?.Trim()} Answer: "; + foreach (var text in ex.Infer(prompt, inferenceParams)) + { + Console.Write(text); + } + } + } + } +} diff --git a/LLama.Examples/NewVersion/TestRunner.cs b/LLama.Examples/NewVersion/TestRunner.cs index 6cc3f3dac..f5a10ef4f 100644 --- a/LLama.Examples/NewVersion/TestRunner.cs +++ b/LLama.Examples/NewVersion/TestRunner.cs @@ -17,6 +17,7 @@ public static async Task Run() Console.WriteLine("7: Get embeddings from LLama model."); Console.WriteLine("8: Quantize the model."); Console.WriteLine("9: Automatic conversation."); + Console.WriteLine("10: Constrain response to json format using grammar."); while (true) { @@ -63,6 +64,10 @@ public static async Task Run() { await TalkToYourself.Run(); } + else if (choice == 10) + { + GrammarJsonResponse.Run(); + } else { Console.WriteLine("Cannot parse your choice. Please select again."); diff --git a/LLama.Unittest/GrammarParserTest.cs b/LLama.Unittest/GrammarParserTest.cs new file mode 100644 index 000000000..496f502ce --- /dev/null +++ b/LLama.Unittest/GrammarParserTest.cs @@ -0,0 +1,274 @@ +using LLama.Common; +using LLama.Native; +using System.Diagnostics; +using LLama.Grammar; +using Newtonsoft.Json.Linq; + +namespace LLama.Unittest +{ + /// + /// Source: + /// https://github.com/ggerganov/llama.cpp/blob/6381d4e110bd0ec02843a60bbeb8b6fc37a9ace9/tests/test-grammar-parser.cpp + /// + /// The commit hash from URL is the actual commit hash that reflects current C# code. + /// + public sealed class GrammarParserTest + { + [Fact] + public void ParseComplexGrammar() + { + GrammarParser parsedGrammar = new GrammarParser(); + string grammarBytes = @"root ::= (expr ""="" term ""\n"")+ + expr ::= term ([-+*/] term)* + term ::= [0-9]+"; + + ParseState state = parsedGrammar.Parse(grammarBytes); + + List> expected = new List> + { + new KeyValuePair("expr", 2), + new KeyValuePair("expr_5", 5), + new KeyValuePair("expr_6", 6), + new KeyValuePair("root", 0), + new KeyValuePair("root_1", 1), + new KeyValuePair("root_4", 4), + new KeyValuePair("term", 3), + new KeyValuePair("term_7", 7), + }; + + uint index = 0; + foreach (var it in state.SymbolIds) + { + string key = it.Key; + uint value = it.Value; + var expectedPair = expected[(int)index]; + + // pretty print error message before asserting + if (expectedPair.Key != key || expectedPair.Value != value) + { + Console.Error.WriteLine($"expectedPair: {expectedPair.Key}, {expectedPair.Value}"); + Console.Error.WriteLine($"actualPair: {key}, {value}"); + Console.Error.WriteLine("expectedPair != actualPair"); + } + Assert.Equal(expectedPair.Key, key); + Assert.Equal(expectedPair.Value, value); + + index++; + } + Assert.NotEmpty(state.SymbolIds); + + + var expectedRules = new List + { + new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 4), + new LLamaGrammarElement(LLamaGrammarElementType.END, 0), + new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 2), + new LLamaGrammarElement(LLamaGrammarElementType.CHAR, 61), + new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 3), + new LLamaGrammarElement(LLamaGrammarElementType.CHAR, 10), + new LLamaGrammarElement(LLamaGrammarElementType.END, 0), + new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 3), + new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 6), + new LLamaGrammarElement(LLamaGrammarElementType.END, 0), + new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 7), + new LLamaGrammarElement(LLamaGrammarElementType.END, 0), + new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 1), + new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 4), + new LLamaGrammarElement(LLamaGrammarElementType.ALT, 0), + new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 1), + new LLamaGrammarElement(LLamaGrammarElementType.END, 0), + new LLamaGrammarElement(LLamaGrammarElementType.CHAR, 45), + new LLamaGrammarElement(LLamaGrammarElementType.CHAR_ALT, 43), + new LLamaGrammarElement(LLamaGrammarElementType.CHAR_ALT, 42), + new LLamaGrammarElement(LLamaGrammarElementType.CHAR_ALT, 47), + new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 3), + new LLamaGrammarElement(LLamaGrammarElementType.END, 0), + new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 5), + new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 6), + new LLamaGrammarElement(LLamaGrammarElementType.ALT, 0), + new LLamaGrammarElement(LLamaGrammarElementType.END, 0), + new LLamaGrammarElement(LLamaGrammarElementType.CHAR, 48), + new LLamaGrammarElement(LLamaGrammarElementType.CHAR_RNG_UPPER, 57), + new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 7), + new LLamaGrammarElement(LLamaGrammarElementType.ALT, 0), + new LLamaGrammarElement(LLamaGrammarElementType.CHAR, 48), + new LLamaGrammarElement(LLamaGrammarElementType.CHAR_RNG_UPPER, 57), + new LLamaGrammarElement(LLamaGrammarElementType.END, 0), + }; + + index = 0; + foreach (var rule in state.Rules) + { + // compare rule to expected rule + for (uint i = 0; i < rule.Count; i++) + { + var element = rule[(int)i]; + var expectedElement = expectedRules[(int)index]; + + // Pretty print error message before asserting + if (expectedElement.Type != element.Type || expectedElement.Value != element.Value) + { + Console.Error.WriteLine($"index: {index}"); + Console.Error.WriteLine($"expected_element: {expectedElement.Type}, {expectedElement.Value}"); + Console.Error.WriteLine($"actual_element: {element.Type}, {element.Value}"); + Console.Error.WriteLine("expected_element != actual_element"); + } + Assert.Equal(expectedElement.Type, element.Type); + Assert.Equal(expectedElement.Value, element.Value); + index++; + } + } + Assert.NotEmpty(state.Rules); + } + + [Fact] + public void ParseExtraComplexGrammar() + { + GrammarParser parsedGrammar = new GrammarParser(); + string grammarBytes = @" + root ::= (expr ""="" ws term ""\n"")+ + expr ::= term ([-+*/] term)* + term ::= ident | num | ""("" ws expr "")"" ws + ident ::= [a-z] [a-z0-9_]* ws + num ::= [0-9]+ ws + ws ::= [ \t\n]* + "; + + ParseState state = parsedGrammar.Parse(grammarBytes); + + List> expected = new List> + { + new KeyValuePair("expr", 2), + new KeyValuePair("expr_6", 6), + new KeyValuePair("expr_7", 7), + new KeyValuePair("ident", 8), + new KeyValuePair("ident_10", 10), + new KeyValuePair("num", 9), + new KeyValuePair("num_11", 11), + new KeyValuePair("root", 0), + new KeyValuePair("root_1", 1), + new KeyValuePair("root_5", 5), + new KeyValuePair("term", 4), + new KeyValuePair("ws", 3), + new KeyValuePair("ws_12", 12), + }; + + uint index = 0; + foreach (var it in state.SymbolIds) + { + string key = it.Key; + uint value = it.Value; + var expectedPair = expected[(int)index]; + + // pretty print error message before asserting + if (expectedPair.Key != key || expectedPair.Value != value) + { + Console.Error.WriteLine($"expectedPair: {expectedPair.Key}, {expectedPair.Value}"); + Console.Error.WriteLine($"actualPair: {key}, {value}"); + Console.Error.WriteLine("expectedPair != actualPair"); + } + Assert.Equal(expectedPair.Key, key); + Assert.Equal(expectedPair.Value, value); + + index++; + } + Assert.NotEmpty(state.SymbolIds); + + + var expectedRules = new List + { + new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 5), + new LLamaGrammarElement(LLamaGrammarElementType.END, 0), + new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 2), + new LLamaGrammarElement(LLamaGrammarElementType.CHAR, 61), + new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 3), + new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 4), + new LLamaGrammarElement(LLamaGrammarElementType.CHAR, 10), + new LLamaGrammarElement(LLamaGrammarElementType.END, 0), + new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 4), + new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 7), + new LLamaGrammarElement(LLamaGrammarElementType.END, 0), + new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 12), + new LLamaGrammarElement(LLamaGrammarElementType.END, 0), + new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 8), + new LLamaGrammarElement(LLamaGrammarElementType.ALT, 0), + new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 9), + new LLamaGrammarElement(LLamaGrammarElementType.ALT, 0), + new LLamaGrammarElement(LLamaGrammarElementType.CHAR, 40), + new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 3), + new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 2), + new LLamaGrammarElement(LLamaGrammarElementType.CHAR, 41), + new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 3), + new LLamaGrammarElement(LLamaGrammarElementType.END, 0), + new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 1), + new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 5), + new LLamaGrammarElement(LLamaGrammarElementType.ALT, 0), + new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 1), + new LLamaGrammarElement(LLamaGrammarElementType.END, 0), + new LLamaGrammarElement(LLamaGrammarElementType.CHAR, 45), + new LLamaGrammarElement(LLamaGrammarElementType.CHAR_ALT, 43), + new LLamaGrammarElement(LLamaGrammarElementType.CHAR_ALT, 42), + new LLamaGrammarElement(LLamaGrammarElementType.CHAR_ALT, 47), + new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 4), + new LLamaGrammarElement(LLamaGrammarElementType.END, 0), + new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 6), + new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 7), + new LLamaGrammarElement(LLamaGrammarElementType.ALT, 0), + new LLamaGrammarElement(LLamaGrammarElementType.END, 0), + new LLamaGrammarElement(LLamaGrammarElementType.CHAR, 97), + new LLamaGrammarElement(LLamaGrammarElementType.CHAR_RNG_UPPER, 122), + new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 10), + new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 3), + new LLamaGrammarElement(LLamaGrammarElementType.END, 0), + new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 11), + new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 3), + new LLamaGrammarElement(LLamaGrammarElementType.END, 0), + new LLamaGrammarElement(LLamaGrammarElementType.CHAR, 97), + new LLamaGrammarElement(LLamaGrammarElementType.CHAR_RNG_UPPER, 122), + new LLamaGrammarElement(LLamaGrammarElementType.CHAR_ALT, 48), + new LLamaGrammarElement(LLamaGrammarElementType.CHAR_RNG_UPPER, 57), + new LLamaGrammarElement(LLamaGrammarElementType.CHAR_ALT, 95), + new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 10), + new LLamaGrammarElement(LLamaGrammarElementType.ALT, 0), + new LLamaGrammarElement(LLamaGrammarElementType.END, 0), + new LLamaGrammarElement(LLamaGrammarElementType.CHAR, 48), + new LLamaGrammarElement(LLamaGrammarElementType.CHAR_RNG_UPPER, 57), + new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 11), + new LLamaGrammarElement(LLamaGrammarElementType.ALT, 0), + new LLamaGrammarElement(LLamaGrammarElementType.CHAR, 48), + new LLamaGrammarElement(LLamaGrammarElementType.CHAR_RNG_UPPER, 57), + new LLamaGrammarElement(LLamaGrammarElementType.END, 0), + new LLamaGrammarElement(LLamaGrammarElementType.CHAR, 32), + new LLamaGrammarElement(LLamaGrammarElementType.CHAR_ALT, 9), + new LLamaGrammarElement(LLamaGrammarElementType.CHAR_ALT, 10), + new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 12), + new LLamaGrammarElement(LLamaGrammarElementType.ALT, 0), + new LLamaGrammarElement(LLamaGrammarElementType.END, 0) + }; + + index = 0; + foreach (var rule in state.Rules) + { + // compare rule to expected rule + for (uint i = 0; i < rule.Count; i++) + { + var element = rule[(int)i]; + var expectedElement = expectedRules[(int)index]; + + // Pretty print error message before asserting + if (expectedElement.Type != element.Type || expectedElement.Value != element.Value) + { + Console.Error.WriteLine($"index: {index}"); + Console.Error.WriteLine($"expected_element: {expectedElement.Type}, {expectedElement.Value}"); + Console.Error.WriteLine($"actual_element: {element.Type}, {element.Value}"); + Console.Error.WriteLine("expected_element != actual_element"); + } + Assert.Equal(expectedElement.Type, element.Type); + Assert.Equal(expectedElement.Value, element.Value); + index++; + } + } + Assert.NotEmpty(state.Rules); + } + } +} diff --git a/LLama/Exceptions/GrammarFormatException.cs b/LLama/Exceptions/GrammarFormatException.cs new file mode 100644 index 000000000..5b1299dd4 --- /dev/null +++ b/LLama/Exceptions/GrammarFormatException.cs @@ -0,0 +1,21 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace LLama.Exceptions +{ + public class GrammarFormatException + : Exception + { + public GrammarFormatException() + { + + } + + public GrammarFormatException(string message) + : base(message) + { + + } + } +} diff --git a/LLama/Grammar/GrammarParser.cs b/LLama/Grammar/GrammarParser.cs new file mode 100644 index 000000000..4122e58f6 --- /dev/null +++ b/LLama/Grammar/GrammarParser.cs @@ -0,0 +1,388 @@ +using LLama.Exceptions; +using LLama.Native; +using System; +using System.Collections.Generic; +using System.Text; + +namespace LLama.Grammar +{ + /// + /// Source: + /// https://github.com/ggerganov/llama.cpp/blob/6381d4e110bd0ec02843a60bbeb8b6fc37a9ace9/common/grammar-parser.cpp + /// + /// The commit hash from URL is the actual commit hash that reflects current C# code. + /// + public class GrammarParser + { + // NOTE: assumes valid utf8 (but checks for overrun) + // copied from llama.cpp + private uint DecodeUTF8(ref ReadOnlySpan src) + { + int[] lookup = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 }; + + byte firstByte = src[0]; + byte highbits = (byte)(firstByte >> 4); + int len = lookup[highbits]; + byte mask = (byte)((1 << (8 - len)) - 1); + uint value = (uint)(firstByte & mask); + + int end = len; + int pos = 1; + + for (; pos < end && pos < src.Length; pos++) + { + value = (uint)((value << 6) + (src[pos] & 0x3F)); + } + + src = src.Slice(pos); + + return value; + } + + private uint GetSymbolId(ParseState state, ReadOnlySpan src, int len) + { + uint nextId = (uint)state.SymbolIds.Count; + string key = Encoding.UTF8.GetString(src.Slice(0, src.Length - len).ToArray()); + + if (state.SymbolIds.TryGetValue(key, out uint existingId)) + { + return existingId; + } + else + { + state.SymbolIds[key] = nextId; + return nextId; + } + } + + private uint GenerateSymbolId(ParseState state, string baseName) + { + uint nextId = (uint)state.SymbolIds.Count; + string key = $"{baseName}_{nextId}"; + state.SymbolIds[key] = nextId; + return nextId; + } + + private void AddRule(ParseState state, uint ruleId, List rule) + { + while (state.Rules.Count <= ruleId) + { + state.Rules.Add(new List()); + } + + state.Rules[(int)ruleId] = rule; + } + + private bool IsWordChar(byte c) + { + return ('a' <= c && c <= 'z') || ('A' <= c && c <= 'Z') || c == '-' || ('0' <= c && c <= '9'); + } + + private uint ParseHex(ref ReadOnlySpan src, int size) + { + int pos = 0; + int end = size; + uint value = 0; + + for (; pos < end && pos < src.Length; pos++) + { + value <<= 4; + byte c = src[pos]; + if ('a' <= c && c <= 'f') + { + value += (uint)(c - 'a' + 10); + } + else if ('A' <= c && c <= 'F') + { + value += (uint)(c - 'A' + 10); + } + else if ('0' <= c && c <= '9') + { + value += (uint)(c - '0'); + } + else + { + break; + } + } + + if (pos != end) + { + throw new GrammarFormatException($"Expecting {size} hex chars at {Encoding.UTF8.GetString(src.ToArray())}"); + } + src = src.Slice(pos); + return value; + } + + private ReadOnlySpan ParseSpace(ReadOnlySpan src, bool newlineOk) + { + int pos = 0; + while (pos < src.Length && + (src[pos] == ' ' || src[pos] == '\t' || src[pos] == '#' || + (newlineOk && (src[pos] == '\r' || src[pos] == '\n')))) + { + if (src[pos] == '#') + { + while (pos < src.Length && src[pos] != '\r' && src[pos] != '\n') + { + pos++; + } + } + else + { + pos++; + } + } + return src.Slice(pos); + } + + private ReadOnlySpan ParseName(ReadOnlySpan src) + { + int pos = 0; + while (pos < src.Length && IsWordChar(src[pos])) + { + pos++; + } + if (pos == 0) + { + throw new GrammarFormatException($"Expecting name at {Encoding.UTF8.GetString(src.ToArray())}"); + } + return src.Slice(pos); + } + + private uint ParseChar(ref ReadOnlySpan src) + { + if (src[0] == '\\') + { + var chr = src[1]; + src = src.Slice(2); + switch (chr) + { + case (byte)'x': + return ParseHex(ref src, 2); + case (byte)'u': + return ParseHex(ref src, 4); + case (byte)'U': + return ParseHex(ref src, 8); + case (byte)'t': + return '\t'; + case (byte)'r': + return '\r'; + case (byte)'n': + return '\n'; + case (byte)'\\': + case (byte)'"': + case (byte)'[': + case (byte)']': + return chr; + default: + throw new GrammarFormatException("Unknown escape at " + Encoding.UTF8.GetString(src.ToArray())); + } + } + else if (!src.IsEmpty) + { + return DecodeUTF8(ref src); + } + + throw new GrammarFormatException("Unexpected end of input"); + } + + private ReadOnlySpan ParseSequence( + ParseState state, + ReadOnlySpan pos, + string ruleName, + List outElements, + bool isNested) + { + int lastSymStart = outElements.Count; + + while (!pos.IsEmpty) + { + if (pos[0] == '"') // literal string + { + pos = pos.Slice(1); + lastSymStart = outElements.Count; + + while (!pos.IsEmpty && pos[0] != '"') + { + var charPair = ParseChar(ref pos); + outElements.Add(new LLamaGrammarElement { Type = LLamaGrammarElementType.CHAR, Value = charPair }); + } + pos = ParseSpace(pos.Slice(1), isNested); + } + else if (pos[0] == '[') // char range(s) + { + pos = pos.Slice(1); + var startType = LLamaGrammarElementType.CHAR; + + if (pos[0] == '^') + { + pos = pos.Slice(1); + startType = LLamaGrammarElementType.CHAR_NOT; + } + + lastSymStart = outElements.Count; + + while (!pos.IsEmpty && pos[0] != ']') + { + var charPair = ParseChar(ref pos); + var type = lastSymStart < outElements.Count ? LLamaGrammarElementType.CHAR_ALT : startType; + + outElements.Add(new LLamaGrammarElement { Type = type, Value = charPair }); + + if (pos[0] == '-' && pos[1] != ']') + { + pos = pos.Slice(1); + var endCharPair = ParseChar(ref pos); + outElements.Add(new LLamaGrammarElement { Type = LLamaGrammarElementType.CHAR_RNG_UPPER, Value = endCharPair }); + } + } + pos = ParseSpace(pos.Slice(1), isNested); + } + else if (IsWordChar(pos[0])) // rule reference + { + var nameEnd = ParseName(pos); + uint refRuleId = GetSymbolId(state, pos, nameEnd.Length); + pos = ParseSpace(nameEnd, isNested); + lastSymStart = outElements.Count; + outElements.Add(new LLamaGrammarElement { Type = LLamaGrammarElementType.RULE_REF, Value = refRuleId }); + } + else if (pos[0] == '(') // grouping + { + // parse nested alternates into synthesized rule + pos = ParseSpace(pos.Slice(1), true); + uint subRuleId = GenerateSymbolId(state, ruleName); + pos = ParseAlternates(state, pos, ruleName, subRuleId, true); + lastSymStart = outElements.Count; + // output reference to synthesized rule + outElements.Add(new LLamaGrammarElement { Type = LLamaGrammarElementType.RULE_REF, Value = subRuleId }); + if (pos[0] != ')') + { + throw new GrammarFormatException($"Expecting ')' at {Encoding.UTF8.GetString(pos.ToArray())}"); + } + pos = ParseSpace(pos.Slice(1), isNested); + } + else if (pos[0] == '*' || pos[0] == '+' || pos[0] == '?') // repetition operator + { + if (lastSymStart == outElements.Count) + { + throw new GrammarFormatException($"Expecting preceding item to */+/? at {Encoding.UTF8.GetString(pos.ToArray())}"); + } + + // apply transformation to previous symbol (lastSymStart to end) according to + // rewrite rules: + // S* --> S' ::= S S' | + // S+ --> S' ::= S S' | S + // S? --> S' ::= S | + uint subRuleId = GenerateSymbolId(state, ruleName); + + List subRule = new List(); + + // add preceding symbol to generated rule + subRule.AddRange(outElements.GetRange(lastSymStart, outElements.Count - lastSymStart)); + + if (pos[0] == '*' || pos[0] == '+') + { + // cause generated rule to recurse + subRule.Add(new LLamaGrammarElement { Type = LLamaGrammarElementType.RULE_REF, Value = subRuleId }); + } + + // mark start of alternate def + subRule.Add(new LLamaGrammarElement { Type = LLamaGrammarElementType.ALT, Value = 0 }); + + if (pos[0] == '+') + { + // add preceding symbol as alternate only for '+' (otherwise empty) + subRule.AddRange(outElements.GetRange(lastSymStart, outElements.Count - lastSymStart)); + } + + subRule.Add(new LLamaGrammarElement { Type = LLamaGrammarElementType.END, Value = 0 }); + + AddRule(state, subRuleId, subRule); + + // in original rule, replace previous symbol with reference to generated rule + outElements.RemoveRange(lastSymStart, outElements.Count - lastSymStart); + outElements.Add(new LLamaGrammarElement { Type = LLamaGrammarElementType.RULE_REF, Value = subRuleId }); + + pos = ParseSpace(pos.Slice(1), isNested); + + } + else + { + break; + } + } + + return pos; + } + + private ReadOnlySpan ParseAlternates( + ParseState state, + ReadOnlySpan src, + string ruleName, + uint ruleId, + bool isNested) + { + var rule = new List(); + ReadOnlySpan pos = ParseSequence(state, src, ruleName, rule, isNested); + + while (!pos.IsEmpty && pos[0] == '|') + { + rule.Add(new LLamaGrammarElement { Type = LLamaGrammarElementType.ALT, Value = 0 }); + pos = ParseSpace(pos.Slice(1), true); + pos = ParseSequence(state, pos, ruleName, rule, isNested); + } + + rule.Add(new LLamaGrammarElement { Type = LLamaGrammarElementType.END, Value = 0 }); + AddRule(state, ruleId, rule); + + return pos; + } + + private ReadOnlySpan ParseRule(ParseState state, ReadOnlySpan src) + { + ReadOnlySpan nameEnd = ParseName(src); + ReadOnlySpan pos = ParseSpace(nameEnd, false); + int nameLen = src.Length - nameEnd.Length; + uint ruleId = GetSymbolId(state, src.Slice(0, nameLen), 0); + string name = Encoding.UTF8.GetString(src.Slice(0, nameLen).ToArray()); + + if (!(pos[0] == ':' && pos[1] == ':' && pos[2] == '=')) + { + throw new GrammarFormatException($"Expecting ::= at {Encoding.UTF8.GetString(pos.ToArray())}"); + } + pos = ParseSpace(pos.Slice(3), true); + + pos = ParseAlternates(state, pos, name, ruleId, false); + + if (!pos.IsEmpty && pos[0] == '\r') + { + pos = pos.Slice(pos[1] == '\n' ? 2 : 1); + } + else if (!pos.IsEmpty && pos[0] == '\n') + { + pos = pos.Slice(1); + } + else if (!pos.IsEmpty) + { + throw new GrammarFormatException($"Expecting newline or end at {Encoding.UTF8.GetString(pos.ToArray())}"); + } + return ParseSpace(pos, true); + } + + public ParseState Parse(string input) + { + byte[] byteArray = Encoding.UTF8.GetBytes(input); + ReadOnlySpan src = new ReadOnlySpan(byteArray); + ParseState state = new ParseState(); + ReadOnlySpan pos = ParseSpace(src, true); + + while (!pos.IsEmpty) + { + pos = ParseRule(state, pos); + } + + return state; + } + } +} diff --git a/LLama/Grammar/ParseState.cs b/LLama/Grammar/ParseState.cs new file mode 100644 index 000000000..0c75a8a00 --- /dev/null +++ b/LLama/Grammar/ParseState.cs @@ -0,0 +1,181 @@ +using LLama.Exceptions; +using LLama.Native; +using System; +using System.Collections.Generic; +using System.IO; + +namespace LLama.Grammar +{ + /// + /// Source: + /// https://github.com/ggerganov/llama.cpp/blob/6381d4e110bd0ec02843a60bbeb8b6fc37a9ace9/common/grammar-parser.h + /// + /// The commit hash from URL is the actual commit hash that reflects current C# code. + /// + public class ParseState + { + public SortedDictionary SymbolIds { get; } = new SortedDictionary(); + public List> Rules { get; } = new List>(); + + public IEnumerable> CRules() + { + foreach (var rule in Rules) + { + yield return rule; + } + } + + public void PrintGrammar(StreamWriter file, ParseState state) + { + try + { + Dictionary symbolIdNames = new Dictionary(); + foreach (var kv in state.SymbolIds) + { + symbolIdNames[kv.Value] = kv.Key; + } + for (int i = 0, end = state.Rules.Count; i < end; i++) + { + PrintRule(file, (uint)i, state.Rules[i], symbolIdNames); + } + } + catch(Exception err) + { + Console.Error.WriteLine($"\nError printing grammar: {err.Message}"); + } + } + + public void PrintRuleBinary(StreamWriter file, List rule) + { + foreach (var elem in rule) + { + switch (elem.Type) + { + case LLamaGrammarElementType.END: file.Write("END"); break; + case LLamaGrammarElementType.ALT: file.Write("ALT"); break; + case LLamaGrammarElementType.RULE_REF: file.Write("RULE_REF"); break; + case LLamaGrammarElementType.CHAR: file.Write("CHAR"); break; + case LLamaGrammarElementType.CHAR_NOT: file.Write("CHAR_NOT"); break; + case LLamaGrammarElementType.CHAR_RNG_UPPER: file.Write("CHAR_RNG_UPPER"); break; + case LLamaGrammarElementType.CHAR_ALT: file.Write("CHAR_ALT"); break; + } + switch (elem.Type) + { + case LLamaGrammarElementType.END: + case LLamaGrammarElementType.ALT: + case LLamaGrammarElementType.RULE_REF: + file.Write($"({elem.Value}) "); + break; + case LLamaGrammarElementType.CHAR: + case LLamaGrammarElementType.CHAR_NOT: + case LLamaGrammarElementType.CHAR_RNG_UPPER: + case LLamaGrammarElementType.CHAR_ALT: + file.Write("(\""); + PrintGrammarChar(file, elem.Value); + file.Write("\") "); + break; + } + } + file.WriteLine(); + } + + private void PrintRule( + StreamWriter file, + uint ruleId, + List rule, + Dictionary symbolIdNames) + { + if (rule.Count == 0 || rule[rule.Count - 1].Type != LLamaGrammarElementType.END) + { + throw new GrammarFormatException( + $"Malformed rule, does not end with LLamaGrammarElementType.END: {ruleId}"); + } + + file.Write($"{symbolIdNames[ruleId]} ::= "); + + for (int i = 0, end = rule.Count - 1; i < end; i++) + { + var elem = rule[i]; + switch (elem.Type) + { + case LLamaGrammarElementType.END: + throw new GrammarFormatException( + $"Unexpected end of rule: {ruleId}, {i}"); + case LLamaGrammarElementType.ALT: + file.Write("| "); + break; + case LLamaGrammarElementType.RULE_REF: + file.Write($"{symbolIdNames[elem.Value]} "); + break; + case LLamaGrammarElementType.CHAR: + file.Write("["); + PrintGrammarChar(file, elem.Value); + break; + case LLamaGrammarElementType.CHAR_NOT: + file.Write("[^"); + PrintGrammarChar(file, elem.Value); + break; + case LLamaGrammarElementType.CHAR_RNG_UPPER: + if (i == 0 || !IsCharElement(rule[i - 1])) + { + throw new GrammarFormatException( + $"LLamaGrammarElementType.CHAR_RNG_UPPER without preceding char: {ruleId},{i}"); + } + file.Write("-"); + PrintGrammarChar(file, elem.Value); + break; + case LLamaGrammarElementType.CHAR_ALT: + if (i == 0 || !IsCharElement(rule[i - 1])) + { + throw new GrammarFormatException( + $"LLamaGrammarElementType.CHAR_ALT without preceding char: {ruleId},{i}"); + } + PrintGrammarChar(file, elem.Value); + break; + + } + + if (IsCharElement(elem)) + { + switch (rule[i + 1].Type) + { + case LLamaGrammarElementType.CHAR_ALT: + case LLamaGrammarElementType.CHAR_RNG_UPPER: + break; + default: + file.Write("] "); + break; + } + } + } + file.WriteLine(); + } + + private void PrintGrammarChar(StreamWriter file, uint c) + { + if (c >= 0x20 && c <= 0x7F) + { + file.Write((char)c); + } + else + { + // cop out of encoding UTF-8 + file.Write($""); + } + } + + private bool IsCharElement(LLamaGrammarElement elem) + { + switch (elem.Type) + { + case LLamaGrammarElementType.CHAR: + case LLamaGrammarElementType.CHAR_NOT: + case LLamaGrammarElementType.CHAR_ALT: + case LLamaGrammarElementType.CHAR_RNG_UPPER: + return true; + default: + return false; + } + } + } +} diff --git a/LLama/Native/LLamaGrammarElement.cs b/LLama/Native/LLamaGrammarElement.cs index d097628f7..7c321c5d2 100644 --- a/LLama/Native/LLamaGrammarElement.cs +++ b/LLama/Native/LLamaGrammarElement.cs @@ -1,4 +1,5 @@ -using System.Runtime.InteropServices; +using System.Diagnostics; +using System.Runtime.InteropServices; namespace LLama.Native { @@ -49,6 +50,7 @@ public enum LLamaGrammarElementType /// An element of a grammar /// [StructLayout(LayoutKind.Sequential)] + [DebuggerDisplay("{Type} {Value}")] public struct LLamaGrammarElement { ///