Skip to content

Commit

Permalink
Merge pull request #715 from martindevans/llama-templater
Browse files Browse the repository at this point in the history
Llama Text Templater
  • Loading branch information
martindevans authored May 10, 2024
2 parents b25f93b + b326624 commit 44bd5b3
Show file tree
Hide file tree
Showing 5 changed files with 612 additions and 4 deletions.
252 changes: 252 additions & 0 deletions LLama.Unittest/TemplateTests.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,252 @@
using System.Text;
using LLama.Common;
using LLama.Native;

namespace LLama.Unittest;

public sealed class TemplateTests
: IDisposable
{
private readonly LLamaWeights _model;

public TemplateTests()
{
var @params = new ModelParams(Constants.GenerativeModelPath)
{
ContextSize = 1,
GpuLayerCount = Constants.CIGpuLayerCount
};
_model = LLamaWeights.LoadFromFile(@params);
}

public void Dispose()
{
_model.Dispose();
}

[Fact]
public void BasicTemplate()
{
var templater = new LLamaTemplate(_model);

Assert.Equal(0, templater.Count);
templater.Add("assistant", "hello");
Assert.Equal(1, templater.Count);
templater.Add("user", "world");
Assert.Equal(2, templater.Count);
templater.Add("assistant", "111");
Assert.Equal(3, templater.Count);
templater.Add("user", "aaa");
Assert.Equal(4, templater.Count);
templater.Add("assistant", "222");
Assert.Equal(5, templater.Count);
templater.Add("user", "bbb");
Assert.Equal(6, templater.Count);
templater.Add("assistant", "333");
Assert.Equal(7, templater.Count);
templater.Add("user", "ccc");
Assert.Equal(8, templater.Count);

// Call once with empty array to discover length
var length = templater.Apply(Array.Empty<byte>());
var dest = new byte[length];

Assert.Equal(8, templater.Count);

// Call again to get contents
length = templater.Apply(dest);

Assert.Equal(8, templater.Count);

var templateResult = Encoding.UTF8.GetString(dest.AsSpan(0, length));
const string expected = "<|im_start|>assistant\nhello<|im_end|>\n" +
"<|im_start|>user\nworld<|im_end|>\n" +
"<|im_start|>assistant\n" +
"111<|im_end|>" +
"\n<|im_start|>user\n" +
"aaa<|im_end|>\n" +
"<|im_start|>assistant\n" +
"222<|im_end|>\n" +
"<|im_start|>user\n" +
"bbb<|im_end|>\n" +
"<|im_start|>assistant\n" +
"333<|im_end|>\n" +
"<|im_start|>user\n" +
"ccc<|im_end|>\n";

Assert.Equal(expected, templateResult);
}

[Fact]
public void CustomTemplate()
{
var templater = new LLamaTemplate("gemma");

Assert.Equal(0, templater.Count);
templater.Add("assistant", "hello");
Assert.Equal(1, templater.Count);
templater.Add("user", "world");
Assert.Equal(2, templater.Count);
templater.Add("assistant", "111");
Assert.Equal(3, templater.Count);
templater.Add("user", "aaa");
Assert.Equal(4, templater.Count);

// Call once with empty array to discover length
var length = templater.Apply(Array.Empty<byte>());
var dest = new byte[length];

Assert.Equal(4, templater.Count);

// Call again to get contents
length = templater.Apply(dest);

Assert.Equal(4, templater.Count);

var templateResult = Encoding.UTF8.GetString(dest.AsSpan(0, length));
const string expected = "<start_of_turn>model\n" +
"hello<end_of_turn>\n" +
"<start_of_turn>user\n" +
"world<end_of_turn>\n" +
"<start_of_turn>model\n" +
"111<end_of_turn>\n" +
"<start_of_turn>user\n" +
"aaa<end_of_turn>\n";

Assert.Equal(expected, templateResult);
}

[Fact]
public void BasicTemplateWithAddAssistant()
{
var templater = new LLamaTemplate(_model)
{
AddAssistant = true,
};

Assert.Equal(0, templater.Count);
templater.Add("assistant", "hello");
Assert.Equal(1, templater.Count);
templater.Add("user", "world");
Assert.Equal(2, templater.Count);
templater.Add("assistant", "111");
Assert.Equal(3, templater.Count);
templater.Add("user", "aaa");
Assert.Equal(4, templater.Count);
templater.Add("assistant", "222");
Assert.Equal(5, templater.Count);
templater.Add("user", "bbb");
Assert.Equal(6, templater.Count);
templater.Add("assistant", "333");
Assert.Equal(7, templater.Count);
templater.Add("user", "ccc");
Assert.Equal(8, templater.Count);

// Call once with empty array to discover length
var length = templater.Apply(Array.Empty<byte>());
var dest = new byte[length];

Assert.Equal(8, templater.Count);

// Call again to get contents
length = templater.Apply(dest);

Assert.Equal(8, templater.Count);

var templateResult = Encoding.UTF8.GetString(dest.AsSpan(0, length));
const string expected = "<|im_start|>assistant\nhello<|im_end|>\n" +
"<|im_start|>user\nworld<|im_end|>\n" +
"<|im_start|>assistant\n" +
"111<|im_end|>" +
"\n<|im_start|>user\n" +
"aaa<|im_end|>\n" +
"<|im_start|>assistant\n" +
"222<|im_end|>\n" +
"<|im_start|>user\n" +
"bbb<|im_end|>\n" +
"<|im_start|>assistant\n" +
"333<|im_end|>\n" +
"<|im_start|>user\n" +
"ccc<|im_end|>\n" +
"<|im_start|>assistant\n";

Assert.Equal(expected, templateResult);
}

[Fact]
public void GetOutOfRangeThrows()
{
var templater = new LLamaTemplate(_model);

Assert.Throws<ArgumentOutOfRangeException>(() => templater[0]);

templater.Add("assistant", "1");
templater.Add("user", "2");

Assert.Throws<ArgumentOutOfRangeException>(() => templater[-1]);
Assert.Throws<ArgumentOutOfRangeException>(() => templater[2]);
}

[Fact]
public void RemoveMid()
{
var templater = new LLamaTemplate(_model);

templater.Add("assistant", "1");
templater.Add("user", "2");
templater.Add("assistant", "3");
templater.Add("user", "4a");
templater.Add("user", "4b");
templater.Add("assistant", "5");

Assert.Equal("user", templater[3].Role);
Assert.Equal("4a", templater[3].Content);

Assert.Equal("assistant", templater[5].Role);
Assert.Equal("5", templater[5].Content);

Assert.Equal(6, templater.Count);
templater.RemoveAt(3);
Assert.Equal(5, templater.Count);

Assert.Equal("user", templater[3].Role);
Assert.Equal("4b", templater[3].Content);

Assert.Equal("assistant", templater[4].Role);
Assert.Equal("5", templater[4].Content);
}

[Fact]
public void RemoveLast()
{
var templater = new LLamaTemplate(_model);

templater.Add("assistant", "1");
templater.Add("user", "2");
templater.Add("assistant", "3");
templater.Add("user", "4a");
templater.Add("user", "4b");
templater.Add("assistant", "5");

Assert.Equal(6, templater.Count);
templater.RemoveAt(5);
Assert.Equal(5, templater.Count);

Assert.Equal("user", templater[4].Role);
Assert.Equal("4b", templater[4].Content);
}

[Fact]
public void RemoveOutOfRange()
{
var templater = new LLamaTemplate(_model);

Assert.Throws<ArgumentOutOfRangeException>(() => templater.RemoveAt(0));

templater.Add("assistant", "1");
templater.Add("user", "2");

Assert.Throws<ArgumentOutOfRangeException>(() => templater.RemoveAt(-1));
Assert.Throws<ArgumentOutOfRangeException>(() => templater.RemoveAt(2));
}
}
Loading

0 comments on commit 44bd5b3

Please sign in to comment.