From d7a8917617d013c9e748eb70b76fe32764a5b93f Mon Sep 17 00:00:00 2001
From: Martin Evans
Date: Sat, 4 May 2024 15:53:54 +0100
Subject: [PATCH] - Added `LLamaTemplate` which efficiently formats a series
 of messages according to the model template.
 - Fixed `llama_chat_apply_template` method (wrong entrypoint, couldn't
 handle null model)

---
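Editor's note (git ignores text between the "---" separator and the diff
stat, so this section is not part of the commit itself): the new
`LLamaTemplate` is driven with a two-call pattern, exercised by the tests
below -- call `Apply` once with an empty buffer to discover the required
length, then call it again with a buffer of at least that size. A minimal
sketch of that usage, assuming the API exactly as added in this patch (the
model path is a placeholder):

    using System;
    using System.Text;
    using LLama;
    using LLama.Common;

    // Load model weights (placeholder path).
    using var weights = LLamaWeights.LoadFromFile(new ModelParams("model.gguf"));

    // Build a template from the model's own chat template, ending the
    // prompt with the tokens that start an assistant turn.
    var template = new LLamaTemplate(weights) { AddAssistant = true };
    template.Add("user", "Hello, world!");

    // First call with an empty buffer: returns the length needed.
    var length = template.Apply(Array.Empty<byte>());

    // Second call with a correctly sized buffer: writes the prompt bytes.
    var buffer = new byte[length];
    template.Apply(buffer);

    var prompt = Encoding.UTF8.GetString(buffer, 0, length);

The entrypoint fix also allows a null model handle, which is what backs the
custom-template constructor:

    // Uses a named llama.cpp template instead of a model's built-in one,
    // so no model needs to be loaded at all.
    var gemma = new LLamaTemplate("gemma");
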
 LLama.Unittest/TemplateTests.cs  | 245 +++++++++++++++++++++++++
 LLama/LLamaTemplate.cs           | 301 +++++++++++++++++++++++++++++++
 LLama/Native/LLamaChatMessage.cs |  12 +-
 LLama/Native/NativeApi.cs        |  11 +-
 LLama/Usings.cs                  |   1 +
 5 files changed, 566 insertions(+), 4 deletions(-)
 create mode 100644 LLama.Unittest/TemplateTests.cs
 create mode 100644 LLama/LLamaTemplate.cs
 create mode 100644 LLama/Usings.cs

diff --git a/LLama.Unittest/TemplateTests.cs b/LLama.Unittest/TemplateTests.cs
new file mode 100644
index 000000000..05d131a59
--- /dev/null
+++ b/LLama.Unittest/TemplateTests.cs
@@ -0,0 +1,245 @@
+using System.Text;
+using LLama.Common;
+using LLama.Native;
+
+namespace LLama.Unittest;
+
+public sealed class TemplateTests
+    : IDisposable
+{
+    private readonly LLamaWeights _model;
+
+    public TemplateTests()
+    {
+        var @params = new ModelParams(Constants.GenerativeModelPath)
+        {
+            ContextSize = 1,
+            GpuLayerCount = Constants.CIGpuLayerCount
+        };
+        _model = LLamaWeights.LoadFromFile(@params);
+    }
+
+    public void Dispose()
+    {
+        _model.Dispose();
+    }
+
+    [Fact]
+    public void BasicTemplate()
+    {
+        var templater = new LLamaTemplate(_model);
+
+        Assert.Equal(0, templater.Count);
+        templater.Add("assistant", "hello");
+        Assert.Equal(1, templater.Count);
+        templater.Add("user", "world");
+        Assert.Equal(2, templater.Count);
+        templater.Add("assistant", "111");
+        Assert.Equal(3, templater.Count);
+        templater.Add("user", "aaa");
+        Assert.Equal(4, templater.Count);
+        templater.Add("assistant", "222");
+        Assert.Equal(5, templater.Count);
+        templater.Add("user", "bbb");
+        Assert.Equal(6, templater.Count);
+        templater.Add("assistant", "333");
+        Assert.Equal(7, templater.Count);
+        templater.Add("user", "ccc");
+        Assert.Equal(8, templater.Count);
+
+        // Call once with empty array to discover length
+        var length = templater.Apply(Array.Empty<byte>());
+        var dest = new byte[length];
+
+        Assert.Equal(8, templater.Count);
+
+        // Call again to get contents
+        length = templater.Apply(dest);
+
+        Assert.Equal(8, templater.Count);
+
+        var templateResult = Encoding.UTF8.GetString(dest.AsSpan(0, length));
+        const string expected = "<|im_start|>assistant\nhello<|im_end|>\n" +
+                                "<|im_start|>user\nworld<|im_end|>\n" +
+                                "<|im_start|>assistant\n" +
+                                "111<|im_end|>" +
+                                "\n<|im_start|>user\n" +
+                                "aaa<|im_end|>\n" +
+                                "<|im_start|>assistant\n" +
+                                "222<|im_end|>\n" +
+                                "<|im_start|>user\n" +
+                                "bbb<|im_end|>\n" +
+                                "<|im_start|>assistant\n" +
+                                "333<|im_end|>\n" +
+                                "<|im_start|>user\n" +
+                                "ccc<|im_end|>\n";
+
+        Assert.Equal(expected, templateResult);
+    }
+
+    [Fact]
+    public void CustomTemplate()
+    {
+        var templater = new LLamaTemplate("gemma");
+
+        Assert.Equal(0, templater.Count);
+        templater.Add("assistant", "hello");
+        Assert.Equal(1, templater.Count);
+        templater.Add("user", "world");
+        Assert.Equal(2, templater.Count);
+        templater.Add("assistant", "111");
+        Assert.Equal(3, templater.Count);
+        templater.Add("user", "aaa");
+        Assert.Equal(4, templater.Count);
+
+        // Call once with empty array to discover length
+        var length = templater.Apply(Array.Empty<byte>());
+        var dest = new byte[length];
+
+        Assert.Equal(4, templater.Count);
+
+        // Call again to get contents
+        length = templater.Apply(dest);
+
+        Assert.Equal(4, templater.Count);
+
+        var templateResult = Encoding.UTF8.GetString(dest.AsSpan(0, length));
+        const string expected = "<start_of_turn>model\n" +
+                                "hello<end_of_turn>\n" +
+                                "<start_of_turn>user\n" +
+                                "world<end_of_turn>\n" +
+                                "<start_of_turn>model\n" +
+                                "111<end_of_turn>\n" +
+                                "<start_of_turn>user\n" +
+                                "aaa<end_of_turn>\n";
+
+        Assert.Equal(expected, templateResult);
+    }
+
+    [Fact]
+    public void BasicTemplateWithAddAssistant()
+    {
+        var templater = new LLamaTemplate(_model)
+        {
+            AddAssistant = true,
+        };
+
+        Assert.Equal(0, templater.Count);
+        templater.Add("assistant", "hello");
+        Assert.Equal(1, templater.Count);
+        templater.Add("user", "world");
+        Assert.Equal(2, templater.Count);
+        templater.Add("assistant", "111");
+        Assert.Equal(3, templater.Count);
+        templater.Add("user", "aaa");
+        Assert.Equal(4, templater.Count);
+        templater.Add("assistant", "222");
+        Assert.Equal(5, templater.Count);
+        templater.Add("user", "bbb");
+        Assert.Equal(6, templater.Count);
+        templater.Add("assistant", "333");
+        Assert.Equal(7, templater.Count);
+        templater.Add("user", "ccc");
+        Assert.Equal(8, templater.Count);
+
+        // Call once with empty array to discover length
+        var length = templater.Apply(Array.Empty<byte>());
+        var dest = new byte[length];
+
+        Assert.Equal(8, templater.Count);
+
+        // Call again to get contents
+        length = templater.Apply(dest);
+
+        Assert.Equal(8, templater.Count);
+
+        var templateResult = Encoding.UTF8.GetString(dest.AsSpan(0, length));
+        const string expected = "<|im_start|>assistant\nhello<|im_end|>\n" +
+                                "<|im_start|>user\nworld<|im_end|>\n" +
+                                "<|im_start|>assistant\n" +
+                                "111<|im_end|>" +
+                                "\n<|im_start|>user\n" +
+                                "aaa<|im_end|>\n" +
+                                "<|im_start|>assistant\n" +
+                                "222<|im_end|>\n" +
+                                "<|im_start|>user\n" +
+                                "bbb<|im_end|>\n" +
+                                "<|im_start|>assistant\n" +
+                                "333<|im_end|>\n" +
+                                "<|im_start|>user\n" +
+                                "ccc<|im_end|>\n" +
+                                "<|im_start|>assistant\n";
+
+        Assert.Equal(expected, templateResult);
+    }
+
+    [Fact]
+    public void GetOutOfRangeThrows()
+    {
+        var templater = new LLamaTemplate(_model);
+
+        Assert.Throws<ArgumentOutOfRangeException>(() => templater[0]);
+
+        templater.Add("assistant", "1");
+        templater.Add("user", "2");
+
+        Assert.Throws<ArgumentOutOfRangeException>(() => templater[-1]);
+        Assert.Throws<ArgumentOutOfRangeException>(() => templater[2]);
+    }
+
+    [Fact]
+    public void RemoveMid()
+    {
+        var templater = new LLamaTemplate(_model);
+
+        templater.Add("assistant", "1");
+        templater.Add("user", "2");
+        templater.Add("assistant", "3");
+        templater.Add("user", "4a");
+        templater.Add("user", "4b");
+        templater.Add("assistant", "5");
+
+        Assert.Equal(("user", "4a"), templater[3]);
+        Assert.Equal(("assistant", "5"), templater[5]);
+
+        Assert.Equal(6, templater.Count);
+        templater.RemoveAt(3);
+        Assert.Equal(5, templater.Count);
+
+        Assert.Equal(("user", "4b"), templater[3]);
+        Assert.Equal(("assistant", "5"), templater[4]);
+    }
+
+    [Fact]
+    public void RemoveLast()
+    {
+        var templater = new LLamaTemplate(_model);
+
+        templater.Add("assistant", "1");
+        templater.Add("user", "2");
+        templater.Add("assistant", "3");
+        templater.Add("user", "4a");
+        templater.Add("user", "4b");
+        templater.Add("assistant", "5");
+
+        Assert.Equal(6, templater.Count);
+        templater.RemoveAt(5);
+        Assert.Equal(5, templater.Count);
+
+        Assert.Equal(("user", "4b"), templater[4]);
+    }
+
+    [Fact]
+    public void RemoveOutOfRange()
+    {
+        var templater = new LLamaTemplate(_model);
+
+        Assert.Throws<ArgumentOutOfRangeException>(() => templater.RemoveAt(0));
+
+        templater.Add("assistant", "1");
+        templater.Add("user", "2");
+
+        Assert.Throws<ArgumentOutOfRangeException>(() => templater.RemoveAt(-1));
+        Assert.Throws<ArgumentOutOfRangeException>(() => templater.RemoveAt(2));
+    }
+}
\ No newline at end of file
diff --git a/LLama/LLamaTemplate.cs b/LLama/LLamaTemplate.cs
new file mode 100644
index 000000000..f3032adc2
--- /dev/null
+++ b/LLama/LLamaTemplate.cs
@@ -0,0 +1,301 @@
+using System;
+using System.Buffers;
+using System.Collections.Generic;
+using System.Diagnostics;
+using System.Text;
+using LLama.Native;
+
+namespace LLama;
+
+/// <summary>
+/// Converts a sequence of messages into text according to a model template
+/// </summary>
+public sealed class LLamaTemplate
+{
+    #region private state
+    /// <summary>
+    /// The model this template is for. May be null if a custom template was supplied to the constructor.
+    /// </summary>
+    private readonly SafeLlamaModelHandle? _model;
+
+    /// <summary>
+    /// Custom template. May be null if a model was supplied to the constructor.
+    /// </summary>
+    private readonly byte[]? _customTemplate;
+
+    /// <summary>
+    /// Keep a cache of roles converted into bytes. Roles are very frequently re-used, so this saves converting them many times.
+    /// </summary>
+    private readonly Dictionary<string, ReadOnlyMemory<byte>> _roleCache = new();
+
+    /// <summary>
+    /// Array of messages. The <see cref="Count"/> property indicates how many messages there are
+    /// </summary>
+    private Message[] _messages = new Message[4];
+
+    /// <summary>
+    /// Backing field for <see cref="AddAssistant"/>
+    /// </summary>
+    private bool _addAssistant;
+
+    /// <summary>
+    /// Temporary array of messages in the format llama.cpp needs, used when applying the template
+    /// </summary>
+    private LLamaChatMessage[] _nativeChatMessages = new LLamaChatMessage[4];
+
+    /// <summary>
+    /// Indicates how many bytes are in the <see cref="_result"/> array
+    /// </summary>
+    private int _resultLength;
+
+    /// <summary>
+    /// Result bytes of last call to <see cref="Apply(Memory{byte})"/>
+    /// </summary>
+    private byte[] _result = Array.Empty<byte>();
+
+    /// <summary>
+    /// Indicates if this template has been modified and needs regenerating
+    /// </summary>
+    private bool _dirty = true;
+    #endregion
+
+    #region properties
+    /// <summary>
+    /// Number of messages added to this template
+    /// </summary>
+    public int Count { get; private set; }
+
+    /// <summary>
+    /// Get the message at the given index
+    /// </summary>
+    /// <param name="index"></param>
+    /// <returns></returns>
+    /// <exception cref="ArgumentOutOfRangeException">Thrown if index is less than zero or greater than or equal to <see cref="Count"/></exception>
+    public (string role, string content) this[int index]
+    {
+        get
+        {
+            if (index < 0)
+                throw new ArgumentOutOfRangeException(nameof(index), "Index must be >= 0");
+            if (index >= Count)
+                throw new ArgumentOutOfRangeException(nameof(index), "Index must be < Count");
+
+            return (_messages[index].Role, _messages[index].Content);
+        }
+    }
+
+    /// <summary>
+    /// Whether to end the prompt with the token(s) that indicate the start of an assistant message.
+    /// </summary>
+    public bool AddAssistant
+    {
+        get => _addAssistant;
+        set
+        {
+            if (value != _addAssistant)
+            {
+                _dirty = true;
+                _addAssistant = value;
+            }
+        }
+    }
+    #endregion
+
+    #region construction
+    /// <summary>
+    /// Construct a new template, using the default model template
+    /// </summary>
+    /// <param name="model"></param>
+    public LLamaTemplate(SafeLlamaModelHandle model)
+    {
+        _model = model;
+    }
+
+    /// <summary>
+    /// Construct a new template, using the default model template
+    /// </summary>
+    /// <param name="weights"></param>
+    public LLamaTemplate(LLamaWeights weights)
+        : this(weights.NativeHandle)
+    {
+    }
+
+    /// <summary>
+    /// Construct a new template, using a custom template.
+    /// </summary>
+    /// <remarks>
+    /// Only a pre-defined list of templates is supported. See more:
+    /// https://github.com/ggerganov/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template
+    /// </remarks>
+    /// <param name="customTemplate"></param>
+    public LLamaTemplate(string customTemplate)
+    {
+        _customTemplate = Encoding.UTF8.GetBytes(customTemplate + "\0");
+    }
+    #endregion
+
+    /// <summary>
+    /// Add a new message to the end of this template
+    /// </summary>
+    /// <param name="role"></param>
+    /// <param name="content"></param>
+    public void Add(string role, string content)
+    {
+        // Expand messages array if necessary
+        if (Count == _messages.Length)
+            Array.Resize(ref _messages, _messages.Length * 2);
+
+        // Add message
+        _messages[Count] = new Message(role, content, _roleCache);
+        Count++;
+
+        // Mark as dirty to ensure template is recalculated
+        _dirty = true;
+    }
+
+    /// <summary>
+    /// Remove a message at the given index
+    /// </summary>
+    /// <param name="index"></param>
+    public void RemoveAt(int index)
+    {
+        if (index < 0)
+            throw new ArgumentOutOfRangeException(nameof(index), "Index must be greater than or equal to zero");
+        if (index >= Count)
+            throw new ArgumentOutOfRangeException(nameof(index), "Index must be less than Count");
+
+        _dirty = true;
+        Count--;
+
+        // Copy all items after index down by one
+        if (index < Count)
+            Array.Copy(_messages, index + 1, _messages, index, Count - index);
+
+        _messages[Count] = default;
+    }
+
+    /// <summary>
+    /// Apply the template to the messages and write it into the output buffer
+    /// </summary>
+    /// <param name="dest">Destination to write template bytes into</param>
+    /// <returns>The length of the template. If this is longer than dest.Length this method should be called again with a larger dest buffer</returns>
+    public int Apply(Memory<byte> dest)
+    {
+        // Recalculate template if necessary
+        if (_dirty)
+        {
+            _dirty = false;
+
+            using var group = new GroupDisposable();
+            unsafe
+            {
+                // Convert all the messages
+                var totalInputBytes = 0;
+                if (_nativeChatMessages.Length < _messages.Length)
+                    Array.Resize(ref _nativeChatMessages, _messages.Length);
+                for (var i = 0; i < Count; i++)
+                {
+                    ref var m = ref _messages[i];
+                    totalInputBytes += m.RoleBytes.Length + m.ContentBytes.Length;
+
+                    // Pin byte arrays in place
+                    var r = m.RoleBytes.Pin();
+                    group.Add(r);
+                    var c = m.ContentBytes.Pin();
+                    group.Add(c);
+
+                    _nativeChatMessages[i] = new LLamaChatMessage
+                    {
+                        role = (byte*)r.Pointer,
+                        content = (byte*)c.Pointer
+                    };
+                }
+
+                // Get an array that's twice as large as the amount of input, hopefully that's large enough!
+                var output = ArrayPool<byte>.Shared.Rent(Math.Max(32, totalInputBytes * 2));
+                try
+                {
+                    // Run templater and discover true length
+                    var outputLength = ApplyInternal(_nativeChatMessages.AsSpan(0, Count), output);
+
+                    // If length was too big for output buffer run it again
+                    if (outputLength > output.Length)
+                    {
+                        // Array was too small, rent another one that's exactly the size needed
+                        ArrayPool<byte>.Shared.Return(output, true);
+                        output = ArrayPool<byte>.Shared.Rent(outputLength);
+
+                        // Run again, but this time with an output that is definitely large enough
+                        ApplyInternal(_nativeChatMessages.AsSpan(0, Count), output);
+                    }
+
+                    // Grow result buffer if necessary
+                    if (_result.Length < outputLength)
+                        Array.Resize(ref _result, Math.Max(_result.Length * 2, outputLength));
+
+                    // Copy to result buffer
+                    output.AsSpan(0, outputLength).CopyTo(_result);
+                    _resultLength = outputLength;
+                }
+                finally
+                {
+                    ArrayPool<byte>.Shared.Return(output, true);
+                }
+            }
+        }
+
+        // Now that the template has been applied and is in the result buffer, copy it to the dest
+        _result.AsSpan(0, Math.Min(dest.Length, _resultLength)).CopyTo(dest.Span);
+        return _resultLength;
+
+        unsafe int ApplyInternal(Span<LLamaChatMessage> messages, byte[] output)
+        {
+            fixed (byte* customTemplatePtr = _customTemplate)
+            fixed (byte* outputPtr = output)
+            fixed (LLamaChatMessage* messagesPtr = messages)
+            {
+                return NativeApi.llama_chat_apply_template(_model, customTemplatePtr, messagesPtr, (nuint)messages.Length, AddAssistant, outputPtr, output.Length);
+            }
+        }
+    }
+
+    /// <summary>
+    /// A message that has been added to the template, contains role and content converted into UTF8 bytes.
+    /// </summary>
+    private readonly record struct Message
+    {
+        public string Role { get; }
+        public string Content { get; }
+
+        public ReadOnlyMemory<byte> RoleBytes { get; }
+        public ReadOnlyMemory<byte> ContentBytes { get; }
+
+        public Message(string role, string content, Dictionary<string, ReadOnlyMemory<byte>> roleCache)
+        {
+            Role = role;
+            Content = content;
+
+            // Get bytes for role from cache
+            if (!roleCache.TryGetValue(role, out var roleBytes))
+            {
+                // Convert role. Add one to length so there is a null byte at the end.
+                var rArr = new byte[Encoding.UTF8.GetByteCount(role) + 1];
+                var encodedRoleLength = Encoding.UTF8.GetBytes(role.AsSpan(), rArr);
+                Debug.Assert(rArr.Length == encodedRoleLength + 1);
+                roleBytes = rArr;
+
+                // Add to cache for future use.
+                // To ensure the cache cannot grow infinitely add a hard limit to size.
+                if (roleCache.Count < 128)
+                    roleCache.Add(role, rArr);
+            }
+            RoleBytes = roleBytes;
+
+            // Convert content. Add one to length so there is a null byte at the end.
+            var contentArray = new byte[Encoding.UTF8.GetByteCount(content) + 1];
+            var encodedContentLength = Encoding.UTF8.GetBytes(content.AsSpan(), contentArray);
+            Debug.Assert(contentArray.Length == encodedContentLength + 1);
+            ContentBytes = contentArray;
+        }
+    }
+}
\ No newline at end of file
diff --git a/LLama/Native/LLamaChatMessage.cs b/LLama/Native/LLamaChatMessage.cs
index 3e70f3e78..e731901fa 100644
--- a/LLama/Native/LLamaChatMessage.cs
+++ b/LLama/Native/LLamaChatMessage.cs
@@ -1,11 +1,21 @@
-namespace LLama.Native;
+using System.Runtime.InteropServices;
+
+namespace LLama.Native;
 
 /// <summary>
 /// 
 /// </summary>
 /// <remarks>llama_chat_message</remarks>
+[StructLayout(LayoutKind.Sequential)]
 public unsafe struct LLamaChatMessage
 {
+    /// <summary>
+    /// Pointer to the null terminated bytes that make up the role string
+    /// </summary>
     public byte* role;
+
+    /// <summary>
+    /// Pointer to the null terminated bytes that make up the content string
+    /// </summary>
     public byte* content;
 }
\ No newline at end of file
diff --git a/LLama/Native/NativeApi.cs b/LLama/Native/NativeApi.cs
index 11a8690fd..841642765 100644
--- a/LLama/Native/NativeApi.cs
+++ b/LLama/Native/NativeApi.cs
@@ -1,4 +1,4 @@
-using System;
+using System;
 using System.Runtime.InteropServices;
 
 #pragma warning disable IDE1006 // Naming Styles
@@ -187,8 +187,13 @@ public static void llama_empty_call()
     /// <param name="buf">A buffer to hold the output formatted prompt. The recommended alloc size is 2 * (total number of characters of all messages)</param>
     /// <param name="length">The size of the allocated buffer</param>
     /// <returns>The total number of bytes of the formatted prompt. If it is larger than the size of buffer, you may need to re-alloc it and then re-apply the template.</returns>
-    [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl, EntryPoint = "llama_get_embeddings")]
-    public static extern unsafe int llama_chat_apply_template(SafeLlamaModelHandle model, char* tmpl, LLamaChatMessage* chat, nuint n_msg, bool add_ass, char* buf, int length);
+    public static unsafe int llama_chat_apply_template(SafeLlamaModelHandle? model, byte* tmpl, LLamaChatMessage* chat, nuint n_msg, bool add_ass, byte* buf, int length)
+    {
+        return internal_llama_chat_apply_template(model?.DangerousGetHandle() ?? IntPtr.Zero, tmpl, chat, n_msg, add_ass, buf, length);
+
+        [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl, EntryPoint = "llama_chat_apply_template")]
+        static extern int internal_llama_chat_apply_template(IntPtr model, byte* tmpl, LLamaChatMessage* chat, nuint n_msg, bool add_ass, byte* buf, int length);
+    }
 
     /// <summary>
     /// Returns -1 if unknown, 1 for true or 0 for false.
diff --git a/LLama/Usings.cs b/LLama/Usings.cs
new file mode 100644
index 000000000..1510815ab
--- /dev/null
+++ b/LLama/Usings.cs
@@ -0,0 +1 @@
+global using LLama.Extensions;
\ No newline at end of file