Skip to content

Commit

Permalink
Address code review comments (create custom exception, move printing …
Browse files Browse the repository at this point in the history
…to the ParseState class, rethrow error).
  • Loading branch information
Mihaiii committed Aug 30, 2023
1 parent 71f02e0 commit 60790c5
Show file tree
Hide file tree
Showing 3 changed files with 189 additions and 166 deletions.
21 changes: 21 additions & 0 deletions LLama/Exceptions/GrammarFormatException.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace LLama.Exceptions
{
/// <summary>
/// Exception raised when the text of a grammar is malformed and cannot be parsed.
/// </summary>
public class GrammarFormatException
    : Exception
{
    /// <summary>
    /// Creates the exception with the framework's default message.
    /// </summary>
    public GrammarFormatException()
    {
    }

    /// <summary>
    /// Creates the exception with a message describing what was wrong with the grammar.
    /// </summary>
    /// <param name="message">Human readable description of the format error.</param>
    public GrammarFormatException(string message)
        : base(message)
    {
    }

    /// <summary>
    /// Creates the exception with a message and the lower-level exception that caused it,
    /// so the original failure is kept available through <see cref="Exception.InnerException"/>.
    /// </summary>
    /// <param name="message">Human readable description of the format error.</param>
    /// <param name="innerException">The exception that triggered this one.</param>
    public GrammarFormatException(string message, Exception innerException)
        : base(message, innerException)
    {
    }
}
}
177 changes: 12 additions & 165 deletions LLama/Grammar/GrammarParser.cs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
using LLama.Native;
using LLama.Exceptions;
using LLama.Native;
using System;
using System.Collections.Generic;
using System.IO;
using System.Text;

namespace LLama.Grammar
Expand Down Expand Up @@ -108,7 +108,7 @@ private uint ParseHex(ref ReadOnlySpan<byte> src, int size)

if (pos != end)
{
throw new InvalidOperationException($"Expecting {size} hex chars at {Encoding.UTF8.GetString(src.ToArray())}");
throw new GrammarFormatException($"Expecting {size} hex chars at {Encoding.UTF8.GetString(src.ToArray())}");
}
src = src.Slice(pos);
return value;
Expand Down Expand Up @@ -145,7 +145,7 @@ private ReadOnlySpan<byte> ParseName(ReadOnlySpan<byte> src)
}
if (pos == 0)
{
throw new InvalidOperationException($"Expecting name at {Encoding.UTF8.GetString(src.ToArray())}");
throw new GrammarFormatException($"Expecting name at {Encoding.UTF8.GetString(src.ToArray())}");
}
return src.Slice(pos);
}
Expand Down Expand Up @@ -176,15 +176,15 @@ private uint ParseChar(ref ReadOnlySpan<byte> src)
case (byte)']':
return chr;
default:
throw new Exception("Unknown escape at " + Encoding.UTF8.GetString(src.ToArray()));
throw new GrammarFormatException("Unknown escape at " + Encoding.UTF8.GetString(src.ToArray()));
}
}
else if (!src.IsEmpty)
{
return DecodeUTF8(ref src);
}

throw new Exception("Unexpected end of input");
throw new GrammarFormatException("Unexpected end of input");
}

private ReadOnlySpan<byte> ParseSequence(
Expand Down Expand Up @@ -258,15 +258,15 @@ private ReadOnlySpan<byte> ParseSequence(
outElements.Add(new LLamaGrammarElement { Type = LLamaGrammarElementType.RULE_REF, Value = subRuleId });
if (pos[0] != ')')
{
throw new Exception($"Expecting ')' at {Encoding.UTF8.GetString(pos.ToArray())}");
throw new GrammarFormatException($"Expecting ')' at {Encoding.UTF8.GetString(pos.ToArray())}");
}
pos = ParseSpace(pos.Slice(1), isNested);
}
else if (pos[0] == '*' || pos[0] == '+' || pos[0] == '?') // repetition operator
{
if (lastSymStart == outElements.Count)
{
throw new Exception($"Expecting preceding item to */+/? at {Encoding.UTF8.GetString(pos.ToArray())}");
throw new GrammarFormatException($"Expecting preceding item to */+/? at {Encoding.UTF8.GetString(pos.ToArray())}");
}

// apply transformation to previous symbol (lastSymStart to end) according to
Expand Down Expand Up @@ -349,7 +349,7 @@ private ReadOnlySpan<byte> ParseRule(ParseState state, ReadOnlySpan<byte> src)

if (!(pos[0] == ':' && pos[1] == ':' && pos[2] == '='))
{
throw new Exception($"Expecting ::= at {Encoding.UTF8.GetString(pos.ToArray())}");
throw new GrammarFormatException($"Expecting ::= at {Encoding.UTF8.GetString(pos.ToArray())}");
}
pos = ParseSpace(pos.Slice(3), true);

Expand All @@ -365,7 +365,7 @@ private ReadOnlySpan<byte> ParseRule(ParseState state, ReadOnlySpan<byte> src)
}
else if (!pos.IsEmpty)
{
throw new Exception($"Expecting newline or end at {Encoding.UTF8.GetString(pos.ToArray())}");
throw new GrammarFormatException($"Expecting newline or end at {Encoding.UTF8.GetString(pos.ToArray())}");
}
return ParseSpace(pos, true);
}
Expand All @@ -386,163 +386,10 @@ public ParseState Parse(string input)

return state;
}
catch (Exception err)
catch(Exception err)
{
Console.Error.WriteLine($"{nameof(Parse)}: error parsing grammar: {err.Message}");
return new ParseState();
}
}

/// <summary>
/// Writes a single grammar code point to <paramref name="file"/>.
/// Values from 0x20 to 0x7F inclusive are written as the character itself;
/// anything else is written as a <c>&lt;U+XXXX&gt;</c> placeholder.
/// </summary>
/// <param name="file">Destination writer.</param>
/// <param name="c">Code point to print.</param>
private void PrintGrammarChar(StreamWriter file, uint c)
{
    bool outsidePlainRange = c < 0x20 || c > 0x7F;
    if (outsidePlainRange)
    {
        // cop out of encoding UTF-8
        file.Write($"<U+{c:X4}>");
        return;
    }

    file.Write((char)c);
}

/// <summary>
/// Tells whether <paramref name="elem"/> is one of the character element types
/// (CHAR, CHAR_NOT, CHAR_ALT or CHAR_RNG_UPPER).
/// </summary>
/// <param name="elem">Grammar element to classify.</param>
/// <returns><c>true</c> for the four character element types, otherwise <c>false</c>.</returns>
private bool IsCharElement(LLamaGrammarElement elem)
{
    var kind = elem.Type;
    return kind == LLamaGrammarElementType.CHAR
        || kind == LLamaGrammarElementType.CHAR_NOT
        || kind == LLamaGrammarElementType.CHAR_ALT
        || kind == LLamaGrammarElementType.CHAR_RNG_UPPER;
}

/// <summary>
/// Dumps a rule element by element in its raw form: each element is written as its
/// type name followed by its value in parentheses, and the line is terminated at the end.
/// Character elements show the value as a quoted character, the others as a number.
/// </summary>
/// <param name="file">Destination writer.</param>
/// <param name="rule">Elements of the rule to dump.</param>
public void PrintRuleBinary(StreamWriter file, List<LLamaGrammarElement> rule)
{
    foreach (var element in rule)
    {
        string label;
        bool valueIsChar;

        switch (element.Type)
        {
            case LLamaGrammarElementType.END:
                label = "END";
                valueIsChar = false;
                break;
            case LLamaGrammarElementType.ALT:
                label = "ALT";
                valueIsChar = false;
                break;
            case LLamaGrammarElementType.RULE_REF:
                label = "RULE_REF";
                valueIsChar = false;
                break;
            case LLamaGrammarElementType.CHAR:
                label = "CHAR";
                valueIsChar = true;
                break;
            case LLamaGrammarElementType.CHAR_NOT:
                label = "CHAR_NOT";
                valueIsChar = true;
                break;
            case LLamaGrammarElementType.CHAR_RNG_UPPER:
                label = "CHAR_RNG_UPPER";
                valueIsChar = true;
                break;
            case LLamaGrammarElementType.CHAR_ALT:
                label = "CHAR_ALT";
                valueIsChar = true;
                break;
            default:
                // Unrecognised element types produce no output at all.
                continue;
        }

        file.Write(label);

        if (valueIsChar)
        {
            file.Write("(\"");
            PrintGrammarChar(file, element.Value);
            file.Write("\") ");
        }
        else
        {
            file.Write($"({element.Value}) ");
        }
    }

    file.WriteLine();
}

/// <summary>
/// Writes one rule to <paramref name="file"/> in textual form: the rule's name, the
/// <c>::=</c> separator, then every element except the terminating END, followed by a newline.
/// </summary>
/// <param name="file">Destination writer.</param>
/// <param name="ruleId">Id of the rule; used to look up its name and in error messages.</param>
/// <param name="rule">Elements of the rule; the last one must be of type END.</param>
/// <param name="symbolIdNames">Lookup from symbol id to symbol name, used for the rule itself and for RULE_REF elements.</param>
/// <exception cref="InvalidOperationException">
/// The rule is empty or does not end with END, an END appears before the last element,
/// or a CHAR_RNG_UPPER / CHAR_ALT element is not preceded by a character element.
/// </exception>
private void PrintRule(
StreamWriter file,
uint ruleId,
List<LLamaGrammarElement> rule,
Dictionary<uint, string> symbolIdNames)
{
// A well-formed rule always carries a trailing END marker.
if (rule.Count == 0 || rule[rule.Count - 1].Type != LLamaGrammarElementType.END)
{
throw new InvalidOperationException(
$"Malformed rule, does not end with LLamaGrammarElementType.END: {ruleId}");
}

// The dictionary indexer throws if ruleId has no registered name.
file.Write($"{symbolIdNames[ruleId]} ::= ");

// Stop one short of Count: the trailing END is not printed, and it guarantees
// that rule[i + 1] below is always a valid index.
for (int i = 0, end = rule.Count - 1; i < end; i++)
{
var elem = rule[i];
switch (elem.Type)
{
case LLamaGrammarElementType.END:
// END is only legal as the last element, which this loop never visits.
throw new InvalidOperationException(
$"Unexpected end of rule: {ruleId}, {i}");
case LLamaGrammarElementType.ALT:
file.Write("| ");
break;
case LLamaGrammarElementType.RULE_REF:
file.Write($"{symbolIdNames[elem.Value]} ");
break;
case LLamaGrammarElementType.CHAR:
// Opens a character class; the closing bracket is emitted after the switch.
file.Write("[");
PrintGrammarChar(file, elem.Value);
break;
case LLamaGrammarElementType.CHAR_NOT:
// Opens a negated character class.
file.Write("[^");
PrintGrammarChar(file, elem.Value);
break;
case LLamaGrammarElementType.CHAR_RNG_UPPER:
// Upper bound of a range; only valid directly after another character element.
if (i == 0 || !IsCharElement(rule[i - 1]))
{
throw new InvalidOperationException(
$"LLamaGrammarElementType.CHAR_RNG_UPPER without preceding char: {ruleId},{i}");
}
file.Write("-");
PrintGrammarChar(file, elem.Value);
break;
case LLamaGrammarElementType.CHAR_ALT:
// Additional member of the current class; only valid directly after another character element.
if (i == 0 || !IsCharElement(rule[i - 1]))
{
throw new InvalidOperationException(
$"LLamaGrammarElementType.CHAR_ALT without preceding char: {ruleId},{i}");
}
PrintGrammarChar(file, elem.Value);
break;

}

// Close the character class unless the next element continues it
// (another alternative or the upper bound of a range).
if (IsCharElement(elem))
{
switch (rule[i + 1].Type)
{
case LLamaGrammarElementType.CHAR_ALT:
case LLamaGrammarElementType.CHAR_RNG_UPPER:
break;
default:
file.Write("] ");
break;
}
}
}
file.WriteLine();
}

/// <summary>
/// Writes every rule of <paramref name="state"/> to <paramref name="file"/>, one rule per call
/// to <c>PrintRule</c>, using the rule's index as its id.
/// Any exception is reported on standard error and then rethrown to the caller.
/// </summary>
/// <param name="file">Destination writer.</param>
/// <param name="state">Parse result whose symbol table and rules are printed.</param>
public void PrintGrammar(StreamWriter file, ParseState state)
{
    try
    {
        // Invert the name -> id table so rules can be labelled by name.
        var namesById = new Dictionary<uint, string>();
        foreach (var symbol in state.SymbolIds)
        {
            namesById[symbol.Value] = symbol.Key;
        }

        int ruleCount = state.Rules.Count;
        for (int ruleIndex = 0; ruleIndex < ruleCount; ruleIndex++)
        {
            var rule = state.Rules[ruleIndex];
            PrintRule(file, (uint)ruleIndex, rule, namesById);
        }
    }
    catch (Exception err)
    {
        Console.Error.WriteLine($"\nError printing grammar: {err.Message}");
        throw;
    }
}
}
Expand Down
Loading

0 comments on commit 60790c5

Please sign in to comment.