Skip to content

Commit

Permalink
Merge pull request #246 from martindevans/codeql_pointer_arithmetic
Browse files Browse the repository at this point in the history
CodeQL Pointer Arithmetic
  • Loading branch information
martindevans authored Nov 4, 2023
2 parents c933a71 + a03fdc4 commit 66986f1
Showing 1 changed file with 23 additions and 21 deletions.
44 changes: 23 additions & 21 deletions LLama/Native/SafeLLamaGrammarHandle.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using System.Runtime.CompilerServices;
using LLama.Exceptions;
using LLama.Grammars;

Expand Down Expand Up @@ -49,38 +50,39 @@ public static SafeLLamaGrammarHandle Create(IReadOnlyList<GrammarRule> rules, ul
// Borrow an array large enough to hold every single element
// and another array large enough to hold a pointer to each rule
var allElements = ArrayPool<LLamaGrammarElement>.Shared.Rent(totalElements);
var pointers = ArrayPool<IntPtr>.Shared.Rent(rules.Count);
var rulePointers = ArrayPool<IntPtr>.Shared.Rent(rules.Count);
try
{
fixed (LLamaGrammarElement* allElementsPtr = allElements)
// We're taking pointers into `allElements` below, so this pin is required to fix
// that memory in place while those pointers are in use!
using var pin = allElements.AsMemory().Pin();

var elementIndex = 0;
var ruleIndex = 0;
foreach (var rule in rules)
{
var elementIndex = 0;
var pointerIndex = 0;
foreach (var rule in rules)
{
// Save a pointer to the start of this rule
pointers[pointerIndex++] = (IntPtr)(allElementsPtr + elementIndex);
// Save a pointer to the start of this rule
rulePointers[ruleIndex++] = (IntPtr)Unsafe.AsPointer(ref allElements[elementIndex]);

// Copy all of the rule elements into the flat array
foreach (var element in rule.Elements)
allElementsPtr[elementIndex++] = element;
}
// Copy all of the rule elements into the flat array
foreach (var element in rule.Elements)
allElements[elementIndex++] = element;
}

// Sanity check some things that should be true if the copy worked as planned
Debug.Assert((ulong)pointerIndex == nrules);
Debug.Assert(elementIndex == totalElements);
// Sanity check some things that should be true if the copy worked as planned
Debug.Assert((ulong)ruleIndex == nrules);
Debug.Assert(elementIndex == totalElements);

// Make the actual call through to llama.cpp
fixed (void* ptr = pointers)
{
return Create((LLamaGrammarElement**)ptr, nrules, start_rule_index);
}
// Make the actual call through to llama.cpp
fixed (void* ptr = rulePointers)
{
return Create((LLamaGrammarElement**)ptr, nrules, start_rule_index);
}
}
finally
{
ArrayPool<LLamaGrammarElement>.Shared.Return(allElements);
ArrayPool<IntPtr>.Shared.Return(pointers);
ArrayPool<IntPtr>.Shared.Return(rulePointers);
}
}
}
Expand Down

0 comments on commit 66986f1

Please sign in to comment.