From f3511e390f5df427bb741773665ada2624acc937 Mon Sep 17 00:00:00 2001 From: Martin Evans Date: Tue, 8 Aug 2023 01:17:33 +0100 Subject: [PATCH 01/13] WIP demonstrating changes to support multi-context. You can see this in use in `TalkToYourself`, along with notes on what still needs improving. The biggest single change is renaming `LLamaModel` to `LLamaContext` --- .../NewVersion/ChatSessionStripRoleName.cs | 2 +- .../NewVersion/ChatSessionWithRoleName.cs | 2 +- .../NewVersion/InstructModeExecute.cs | 2 +- .../NewVersion/InteractiveModeExecute.cs | 2 +- .../NewVersion/LoadAndSaveSession.cs | 6 +- LLama.Examples/NewVersion/LoadAndSaveState.cs | 6 +- .../NewVersion/StatelessModeExecute.cs | 2 +- LLama.Examples/NewVersion/TalkToYourself.cs | 80 +++++++++++++++++++ LLama.Examples/NewVersion/TestRunner.cs | 5 ++ LLama.Unittest/BasicTest.cs | 2 +- LLama.Web/Models/ModelSession.cs | 3 +- .../Services/ConnectionSessionService.cs | 2 +- LLama.WebAPI/Services/StatefulChatService.cs | 8 +- LLama.WebAPI/Services/StatelessChatService.cs | 6 +- LLama/Abstractions/ILLamaExecutor.cs | 4 +- LLama/ChatSession.cs | 4 +- LLama/{LLamaModel.cs => LLamaContext.cs} | 79 +++++++++++++----- LLama/LLamaExecutorBase.cs | 26 +++--- LLama/LLamaInstructExecutor.cs | 28 +++---- LLama/LLamaInteractExecutor.cs | 32 ++++---- LLama/LLamaStatelessExecutor.cs | 44 +++++----- LLama/Native/SafeLLamaContextHandle.cs | 8 +- LLama/Native/SafeLlamaModelHandle.cs | 12 +++ LLama/ResettableLLamaModel.cs | 43 ---------- LLama/Utils.cs | 2 +- 25 files changed, 253 insertions(+), 157 deletions(-) create mode 100644 LLama.Examples/NewVersion/TalkToYourself.cs rename LLama/{LLamaModel.cs => LLamaContext.cs} (86%) delete mode 100644 LLama/ResettableLLamaModel.cs diff --git a/LLama.Examples/NewVersion/ChatSessionStripRoleName.cs b/LLama.Examples/NewVersion/ChatSessionStripRoleName.cs index ce677c40c..6402e360c 100644 --- a/LLama.Examples/NewVersion/ChatSessionStripRoleName.cs +++ b/LLama.Examples/NewVersion/ChatSessionStripRoleName.cs @@ -14,7 +14,7 @@ public static void Run() Console.Write("Please input your model path: "); string modelPath = Console.ReadLine(); var prompt = File.ReadAllText("Assets/chat-with-bob.txt").Trim(); - InteractiveExecutor ex = new(new LLamaModel(new ModelParams(modelPath, contextSize: 1024, seed: 1337, gpuLayerCount: 5))); + InteractiveExecutor ex = new(new LLamaContext(new ModelParams(modelPath, contextSize: 1024, seed: 1337, gpuLayerCount: 5))); ChatSession session = new ChatSession(ex).WithOutputTransform(new LLamaTransforms.KeywordTextOutputStreamTransform(new string[] { "User:", "Bob:" }, redundancyLength: 8)); Console.ForegroundColor = ConsoleColor.Yellow; diff --git a/LLama.Examples/NewVersion/ChatSessionWithRoleName.cs b/LLama.Examples/NewVersion/ChatSessionWithRoleName.cs index cbf9333b9..d1cbf34b2 100644 --- a/LLama.Examples/NewVersion/ChatSessionWithRoleName.cs +++ b/LLama.Examples/NewVersion/ChatSessionWithRoleName.cs @@ -14,7 +14,7 @@ public static void Run() Console.Write("Please input your model path: "); string modelPath = Console.ReadLine(); var prompt = File.ReadAllText("Assets/chat-with-bob.txt").Trim(); - InteractiveExecutor ex = new(new LLamaModel(new ModelParams(modelPath, contextSize: 1024, seed: 1337, gpuLayerCount: 5))); + InteractiveExecutor ex = new(new LLamaContext(new ModelParams(modelPath, contextSize: 1024, seed: 1337, gpuLayerCount: 5))); ChatSession session = new ChatSession(ex); // The only change is to remove the transform for the output text stream. Console.ForegroundColor = ConsoleColor.Yellow; diff --git a/LLama.Examples/NewVersion/InstructModeExecute.cs b/LLama.Examples/NewVersion/InstructModeExecute.cs index 303c8644b..f81f2f587 100644 --- a/LLama.Examples/NewVersion/InstructModeExecute.cs +++ b/LLama.Examples/NewVersion/InstructModeExecute.cs @@ -15,7 +15,7 @@ public static void Run() string modelPath = Console.ReadLine(); var prompt = File.ReadAllText("Assets/dan.txt").Trim(); - InstructExecutor ex = new(new LLamaModel(new ModelParams(modelPath, contextSize: 1024))); + InstructExecutor ex = new(new LLamaContext(new ModelParams(modelPath, contextSize: 1024))); Console.ForegroundColor = ConsoleColor.Yellow; Console.WriteLine("The executor has been enabled. In this example, the LLM will follow your instructions. For example, you can input \"Write a story about a fox who want to " + diff --git a/LLama.Examples/NewVersion/InteractiveModeExecute.cs b/LLama.Examples/NewVersion/InteractiveModeExecute.cs index 23afcadf6..aaacabbed 100644 --- a/LLama.Examples/NewVersion/InteractiveModeExecute.cs +++ b/LLama.Examples/NewVersion/InteractiveModeExecute.cs @@ -15,7 +15,7 @@ public async static Task Run() string modelPath = Console.ReadLine(); var prompt = File.ReadAllText("Assets/chat-with-bob.txt").Trim(); - InteractiveExecutor ex = new(new LLamaModel(new ModelParams(modelPath, contextSize: 256))); + InteractiveExecutor ex = new(new LLamaContext(new ModelParams(modelPath, contextSize: 256))); Console.ForegroundColor = ConsoleColor.Yellow; Console.WriteLine("The executor has been enabled. In this example, the prompt is printed, the maximum tokens is set to 128 and the context size is 256. (an example for small scale usage)"); diff --git a/LLama.Examples/NewVersion/LoadAndSaveSession.cs b/LLama.Examples/NewVersion/LoadAndSaveSession.cs index 722ec3e0e..cbed9179d 100644 --- a/LLama.Examples/NewVersion/LoadAndSaveSession.cs +++ b/LLama.Examples/NewVersion/LoadAndSaveSession.cs @@ -15,7 +15,7 @@ public static void Run() Console.Write("Please input your model path: "); string modelPath = Console.ReadLine(); var prompt = File.ReadAllText("Assets/chat-with-bob.txt").Trim(); - InteractiveExecutor ex = new(new LLamaModel(new ModelParams(modelPath, contextSize: 1024, seed: 1337, gpuLayerCount: 5))); + InteractiveExecutor ex = new(new LLamaContext(new ModelParams(modelPath, contextSize: 1024, seed: 1337, gpuLayerCount: 5))); ChatSession session = new ChatSession(ex); // The only change is to remove the transform for the output text stream. Console.ForegroundColor = ConsoleColor.Yellow; @@ -45,8 +45,8 @@ public static void Run() Console.WriteLine("Saved session!"); Console.ForegroundColor = ConsoleColor.White; - ex.Model.Dispose(); - ex = new(new LLamaModel(new ModelParams(modelPath, contextSize: 1024, seed: 1337, gpuLayerCount: 5))); + ex.Context.Dispose(); + ex = new(new LLamaContext(new ModelParams(modelPath, contextSize: 1024, seed: 1337, gpuLayerCount: 5))); session = new ChatSession(ex); session.LoadSession(statePath); diff --git a/LLama.Examples/NewVersion/LoadAndSaveState.cs b/LLama.Examples/NewVersion/LoadAndSaveState.cs index dc3031414..15f2f815f 100644 --- a/LLama.Examples/NewVersion/LoadAndSaveState.cs +++ b/LLama.Examples/NewVersion/LoadAndSaveState.cs @@ -15,7 +15,7 @@ public static void Run() string modelPath = Console.ReadLine(); var prompt = File.ReadAllText("Assets/chat-with-bob.txt").Trim(); - InteractiveExecutor ex = new(new LLamaModel(new ModelParams(modelPath, contextSize: 256))); + InteractiveExecutor ex = new(new LLamaContext(new ModelParams(modelPath, contextSize: 256))); Console.ForegroundColor = ConsoleColor.Yellow; Console.WriteLine("The executor has been enabled. In this example, the prompt is printed, the maximum tokens is set to 64 and the context size is 256. (an example for small scale usage)"); @@ -37,7 +37,7 @@ public static void Run() { Console.Write("Your path to save model state: "); string modelStatePath = Console.ReadLine(); - ex.Model.SaveState(modelStatePath); + ex.Context.SaveState(modelStatePath); Console.Write("Your path to save executor state: "); string executorStatePath = Console.ReadLine(); @@ -47,7 +47,7 @@ public static void Run() Console.WriteLine("All states saved!"); Console.ForegroundColor = ConsoleColor.White; - var model = ex.Model; + var model = ex.Context; model.LoadState(modelStatePath); ex = new InteractiveExecutor(model); ex.LoadState(executorStatePath); diff --git a/LLama.Examples/NewVersion/StatelessModeExecute.cs b/LLama.Examples/NewVersion/StatelessModeExecute.cs index 3c4852319..8ff2c0a1a 100644 --- a/LLama.Examples/NewVersion/StatelessModeExecute.cs +++ b/LLama.Examples/NewVersion/StatelessModeExecute.cs @@ -14,7 +14,7 @@ public static void Run() Console.Write("Please input your model path: "); string modelPath = Console.ReadLine(); - StatelessExecutor ex = new(new LLamaModel(new ModelParams(modelPath, contextSize: 256))); + StatelessExecutor ex = new(new LLamaContext(new ModelParams(modelPath, contextSize: 256))); Console.ForegroundColor = ConsoleColor.Yellow; Console.WriteLine("The executor has been enabled. In this example, the inference is an one-time job. That says, the previous input and response has " + diff --git a/LLama.Examples/NewVersion/TalkToYourself.cs b/LLama.Examples/NewVersion/TalkToYourself.cs new file mode 100644 index 000000000..784d952a3 --- /dev/null +++ b/LLama.Examples/NewVersion/TalkToYourself.cs @@ -0,0 +1,80 @@ +using System.Security.Cryptography; +using System.Text; +using LLama.Abstractions; +using LLama.Common; +using LLama.Extensions; +using LLama.Native; + +namespace LLama.Examples.NewVersion +{ + public class TalkToYourself + { + public static async Task Run() + { + Console.Write("Please input your model path: "); + string modelPath = "C:\\Users\\Martin\\Documents\\Python\\oobabooga_windows\\text-generation-webui\\models\\llama-2-7b-chat.ggmlv3.q6_K.bin"; + + // todo: model path is passed here, but isn't needed + var @params = new ModelParams(modelPath) + { + Seed = RandomNumberGenerator.GetInt32(int.MaxValue) + }; + + // todo: all this pin stuff is ugly and should be hidden in the higher level wrapper + using var pin = @params.ToLlamaContextParams(out var lparams); + + // todo: we need a higher level wrapper around the model weights (LLamaWeights??) + var weights = SafeLlamaModelHandle.LoadFromFile(modelPath, lparams); + + // todo: need a method on the LLamaWeights which does this + var ctx1 = new LLamaContext(weights.CreateContext(lparams), @params, Encoding.UTF8); + var ctx2 = new LLamaContext(weights.CreateContext(lparams), @params, Encoding.UTF8); + + var alice = new InteractiveExecutor(ctx1); + var bob = new InteractiveExecutor(ctx2); + + // Initial alice prompt + var alicePrompt = "Transcript of a dialog, where the Alice interacts a person named Bob. Alice is friendly, kind, honest and good at writing.\nAlice: Hello"; + var aliceResponse = await Prompt(alice, ConsoleColor.Green, alicePrompt, false, false); + + // Initial bob prompt + var bobPrompt = $"Transcript of a dialog, where the Bob interacts a person named Alice. Bob is smart, intellectual and good at writing.\nAlice: Hello{aliceResponse}"; + var bobResponse = await Prompt(alice, ConsoleColor.Red, bobPrompt, true, true); + + // swap back and forth from Alice to Bob + while (true) + { + aliceResponse = await Prompt(alice, ConsoleColor.Green, bobResponse, false, true); + bobResponse = await Prompt(alice, ConsoleColor.Red, aliceResponse, false, true); + Thread.Sleep(1000); + } + } + + private static async Task Prompt(ILLamaExecutor executor, ConsoleColor color, string prompt, bool showPrompt, bool showResponse) + { + var inferenceParams = new InferenceParams + { + Temperature = 0.9f, + AntiPrompts = new List { "Alice:", "Bob:", "User:" }, + MaxTokens = 128, + Mirostat = MirostatType.Mirostat2, + MirostatTau = 10, + }; + + Console.ForegroundColor = ConsoleColor.White; + if (showPrompt) + Console.Write(prompt); + + Console.ForegroundColor = color; + var builder = new StringBuilder(); + await foreach (var text in executor.InferAsync(prompt, inferenceParams)) + { + builder.Append(text); + if (showResponse) + Console.Write(text); + } + + return builder.ToString(); + } + } +} diff --git a/LLama.Examples/NewVersion/TestRunner.cs b/LLama.Examples/NewVersion/TestRunner.cs index 23c9ae6b7..c90bc78de 100644 --- a/LLama.Examples/NewVersion/TestRunner.cs +++ b/LLama.Examples/NewVersion/TestRunner.cs @@ -22,6 +22,7 @@ public static async Task Run() Console.WriteLine("6: Load and save state of model and executor."); Console.WriteLine("7: Get embeddings from LLama model."); Console.WriteLine("8: Quantize the model."); + Console.WriteLine("9: Automatic conversation."); while (true) { @@ -64,6 +65,10 @@ public static async Task Run() { QuantizeModel.Run(); } + else if (choice == 9) + { + await TalkToYourself.Run(); + } else { Console.WriteLine("Cannot parse your choice. Please select again."); diff --git a/LLama.Unittest/BasicTest.cs b/LLama.Unittest/BasicTest.cs index 308b13ad4..5a07c86cc 100644 --- a/LLama.Unittest/BasicTest.cs +++ b/LLama.Unittest/BasicTest.cs @@ -8,7 +8,7 @@ public class BasicTest [Fact] public void LoadModel() { - var model = new LLamaModel(new ModelParams("Models/llama-2-7b-chat.ggmlv3.q3_K_S.bin", contextSize: 256)); + var model = new LLamaContext(new ModelParams("Models/llama-2-7b-chat.ggmlv3.q3_K_S.bin", contextSize: 256)); model.Dispose(); } } diff --git a/LLama.Web/Models/ModelSession.cs b/LLama.Web/Models/ModelSession.cs index d6d428133..c53676f24 100644 --- a/LLama.Web/Models/ModelSession.cs +++ b/LLama.Web/Models/ModelSession.cs @@ -60,7 +60,8 @@ public void Dispose() { _inferenceOptions = null; _outputTransform = null; - _executor.Model?.Dispose(); + + _executor?.Context.Dispose(); _executor = null; } } diff --git a/LLama.Web/Services/ConnectionSessionService.cs b/LLama.Web/Services/ConnectionSessionService.cs index 6c266f14f..7dfcde397 100644 --- a/LLama.Web/Services/ConnectionSessionService.cs +++ b/LLama.Web/Services/ConnectionSessionService.cs @@ -51,7 +51,7 @@ public Task> CreateAsync(LLamaExecutorType executor return Task.FromResult(ServiceResult.FromError("Maximum model instances reached")); // Create model - var llamaModel = new LLamaModel(modelOption); + var llamaModel = new LLamaContext(modelOption); // Create executor ILLamaExecutor executor = executorType switch diff --git a/LLama.WebAPI/Services/StatefulChatService.cs b/LLama.WebAPI/Services/StatefulChatService.cs index ab89b5176..d6924e6c8 100644 --- a/LLama.WebAPI/Services/StatefulChatService.cs +++ b/LLama.WebAPI/Services/StatefulChatService.cs @@ -8,7 +8,7 @@ namespace LLama.WebAPI.Services; public class StatefulChatService : IDisposable { private readonly ChatSession _session; - private readonly LLamaModel _model; + private readonly LLamaContext _context; private bool _continue = false; private const string SystemPrompt = "Transcript of a dialog, where the User interacts with an Assistant. Assistant is helpful, kind, honest, good at writing, and never fails to answer the User's requests immediately and with precision.\n\n" @@ -16,13 +16,13 @@ public class StatefulChatService : IDisposable public StatefulChatService(IConfiguration configuration) { - _model = new LLamaModel(new Common.ModelParams(configuration["ModelPath"], contextSize: 512)); - _session = new ChatSession(new InteractiveExecutor(_model)); + _context = new LLamaContext(new Common.ModelParams(configuration["ModelPath"], contextSize: 512)); + _session = new ChatSession(new InteractiveExecutor(_context)); } public void Dispose() { - _model?.Dispose(); + _context?.Dispose(); } public string Send(SendMessageInput input) diff --git a/LLama.WebAPI/Services/StatelessChatService.cs b/LLama.WebAPI/Services/StatelessChatService.cs index c1356646e..27c508dd6 100644 --- a/LLama.WebAPI/Services/StatelessChatService.cs +++ b/LLama.WebAPI/Services/StatelessChatService.cs @@ -7,14 +7,14 @@ namespace LLama.WebAPI.Services { public class StatelessChatService { - private readonly LLamaModel _model; + private readonly LLamaContext _context; private readonly ChatSession _session; public StatelessChatService(IConfiguration configuration) { - _model = new LLamaModel(new ModelParams(configuration["ModelPath"], contextSize: 512)); + _context = new LLamaContext(new ModelParams(configuration["ModelPath"], contextSize: 512)); // TODO: replace with a stateless executor - _session = new ChatSession(new InteractiveExecutor(_model)) + _session = new ChatSession(new InteractiveExecutor(_context)) .WithOutputTransform(new LLamaTransforms.KeywordTextOutputStreamTransform(new string[] { "User:", "Assistant:" }, redundancyLength: 8)) .WithHistoryTransform(new HistoryTransform()); } diff --git a/LLama/Abstractions/ILLamaExecutor.cs b/LLama/Abstractions/ILLamaExecutor.cs index 6a7508955..43a6092bd 100644 --- a/LLama/Abstractions/ILLamaExecutor.cs +++ b/LLama/Abstractions/ILLamaExecutor.cs @@ -12,9 +12,9 @@ namespace LLama.Abstractions public interface ILLamaExecutor { /// - /// The loaded model for this executor. + /// The loaded context for this executor. /// - public LLamaModel Model { get; } + public LLamaContext Context { get; } /// /// Infers a response from the model. diff --git a/LLama/ChatSession.cs b/LLama/ChatSession.cs index 4a4544b0e..56ef47f50 100644 --- a/LLama/ChatSession.cs +++ b/LLama/ChatSession.cs @@ -91,7 +91,7 @@ public virtual void SaveSession(string path) { Directory.CreateDirectory(path); } - _executor.Model.SaveState(Path.Combine(path, _modelStateFilename)); + _executor.Context.SaveState(Path.Combine(path, _modelStateFilename)); if(Executor is StatelessExecutor) { @@ -116,7 +116,7 @@ public virtual void LoadSession(string path) { throw new FileNotFoundException($"Directory {path} does not exist."); } - _executor.Model.LoadState(Path.Combine(path, _modelStateFilename)); + _executor.Context.LoadState(Path.Combine(path, _modelStateFilename)); if (Executor is StatelessExecutor) { diff --git a/LLama/LLamaModel.cs b/LLama/LLamaContext.cs similarity index 86% rename from LLama/LLamaModel.cs rename to LLama/LLamaContext.cs index 4bcf0fec1..c6b8749a6 100644 --- a/LLama/LLamaModel.cs +++ b/LLama/LLamaContext.cs @@ -15,28 +15,43 @@ namespace LLama { using llama_token = Int32; + /// - /// The abstraction of a LLama model, which holds the context in the native library. + /// A llama_context, which holds all the context required to interact with a model /// - public class LLamaModel: IDisposable + public class LLamaContext + : IDisposable { - // TODO: expose more properties. - ILLamaLogger? _logger; - Encoding _encoding; - SafeLLamaContextHandle _ctx; + private readonly ILLamaLogger? _logger; + private readonly Encoding _encoding; + private readonly SafeLLamaContextHandle _ctx; + + /// + /// Total number of tokens in vocabulary of this model + /// + public int VocabCount => _ctx.VocabCount; + + /// + /// Total number of tokens in the context + /// + public int ContextSize => _ctx.ContextSize; + /// - /// The context size. + /// Dimension of embedding vectors /// - public int ContextSize { get; } + public int EmbeddingCount => _ctx.EmbeddingCount; + /// /// The model params set for this model. /// public IModelParams Params { get; set; } + /// - /// The native handle, which is used to be passed to the native APIs. Please avoid using it - /// unless you know what is the usage of the Native API. + /// The native handle, which is used to be passed to the native APIs /// + /// Be careful how you use this! public SafeLLamaContextHandle NativeHandle => _ctx; + /// /// The encoding set for this model to deal with text input. /// @@ -45,17 +60,46 @@ public class LLamaModel: IDisposable /// /// /// - /// Model params. + /// Model params. /// Encoding to deal with text input. /// The logger. - public LLamaModel(IModelParams Params, string encoding = "UTF-8", ILLamaLogger? logger = null) + public LLamaContext(IModelParams @params, string encoding = "UTF-8", ILLamaLogger? logger = null) { + Params = @params; + _logger = logger; - this.Params = Params; _encoding = Encoding.GetEncoding(encoding); - _logger?.Log(nameof(LLamaModel), $"Initializing LLama model with params: {this.Params}", ILLamaLogger.LogLevel.Info); - _ctx = Utils.InitLLamaContextFromModelParams(this.Params); - ContextSize = NativeApi.llama_n_ctx(_ctx); + + _logger?.Log(nameof(LLamaContext), $"Initializing LLama model with params: {this.Params}", ILLamaLogger.LogLevel.Info); + _ctx = Utils.InitLLamaContextFromModelParams(Params); + } + + //todo make this internal + public LLamaContext(SafeLLamaContextHandle nativeContext, IModelParams @params, Encoding encoding, ILLamaLogger? logger = null) + { + Params = @params; + + _logger = logger; + _encoding = encoding; + _ctx = nativeContext; + } + + /// + /// Create a copy of the current state of this context + /// + /// + public LLamaContext Clone() + { + using var pin = Params.ToLlamaContextParams(out var lparams); + + // Create a blank new context for the model + var ctx = new LLamaContext(SafeLLamaContextHandle.Create(NativeHandle.ModelHandle, lparams), Params, _encoding); + + // Copy across the state + using var state = GetState(); + ctx.LoadState(state); + + return ctx; } /// @@ -338,7 +382,7 @@ public llama_token Eval(llama_token[] tokens, llama_token pastTokensCount) if (!_ctx.Eval(tokens.AsMemory(i, n_eval), pastTokensCount, Params.Threads)) { - _logger?.Log(nameof(LLamaModel), "Failed to eval.", ILLamaLogger.LogLevel.Error); + _logger?.Log(nameof(LLamaContext), "Failed to eval.", ILLamaLogger.LogLevel.Error); throw new RuntimeError("Failed to eval."); } @@ -347,7 +391,6 @@ public llama_token Eval(llama_token[] tokens, llama_token pastTokensCount) return pastTokensCount; } - // TODO: add comment internal IEnumerable GenerateResult(IEnumerable ids) { foreach(var id in ids) diff --git a/LLama/LLamaExecutorBase.cs b/LLama/LLamaExecutorBase.cs index 2caaa8e50..5f2b129e5 100644 --- a/LLama/LLamaExecutorBase.cs +++ b/LLama/LLamaExecutorBase.cs @@ -18,10 +18,6 @@ namespace LLama /// public abstract class StatefulExecutorBase : ILLamaExecutor { - /// - /// The loaded model for this executor. - /// - protected readonly LLamaModel _model; /// /// The logger used by this executor. /// @@ -63,9 +59,9 @@ public abstract class StatefulExecutorBase : ILLamaExecutor /// protected FixedSizeQueue _last_n_tokens; /// - /// The mode used by the executor. + /// The context used by the executor. /// - public LLamaModel Model => _model; + public LLamaContext Context { get; } /// /// Current "mu" value for mirostat sampling @@ -75,16 +71,16 @@ public abstract class StatefulExecutorBase : ILLamaExecutor /// /// /// - /// + /// /// - protected StatefulExecutorBase(LLamaModel model, ILLamaLogger? logger = null) + protected StatefulExecutorBase(LLamaContext context, ILLamaLogger? logger = null) { - _model = model; + Context = context; _logger = logger; _pastTokensCount = 0; _consumedTokensCount = 0; _n_session_consumed = 0; - _last_n_tokens = new FixedSizeQueue(_model.ContextSize).FillWith(0); + _last_n_tokens = new FixedSizeQueue(Context.ContextSize).FillWith(0); } /// @@ -104,9 +100,9 @@ public unsafe StatefulExecutorBase WithSessionFile(string filename) if (File.Exists(filename)) { _logger?.Log("LLamaExecutor", $"Attempting to load saved session from {filename}", ILLamaLogger.LogLevel.Info); - llama_token[] session_tokens = new llama_token[_model.ContextSize]; + llama_token[] session_tokens = new llama_token[Context.ContextSize]; ulong n_token_count_out = 0; - if (!NativeApi.llama_load_session_file(_model.NativeHandle, _pathSession, session_tokens, (ulong)_model.ContextSize, &n_token_count_out)) + if (!NativeApi.llama_load_session_file(Context.NativeHandle, _pathSession, session_tokens, (ulong)Context.ContextSize, &n_token_count_out)) { _logger?.Log("LLamaExecutor", $"Failed to load session file {filename}", ILLamaLogger.LogLevel.Error); throw new RuntimeError($"Failed to load session file {_pathSession}"); @@ -156,7 +152,7 @@ public unsafe StatefulExecutorBase WithSessionFile(string filename) public void SaveSessionFile(string filename) { var session_token_array = _session_tokens.ToArray(); - NativeApi.llama_save_session_file(_model.NativeHandle, filename, session_token_array, (ulong)session_token_array.Length); + NativeApi.llama_save_session_file(Context.NativeHandle, filename, session_token_array, (ulong)session_token_array.Length); } /// @@ -173,7 +169,7 @@ protected virtual void HandleRunOutOfContext(int tokensToKeep) _pastTokensCount = Math.Max(1, tokensToKeep); // insert n_left/2 tokens at the start of embed from last_n_tokens - _embeds.InsertRange(0, _last_n_tokens.Take(_last_n_tokens.Count - _embeds.Count).Skip(_model.ContextSize - n_left / 2 - _embeds.Count)); + _embeds.InsertRange(0, _last_n_tokens.Take(_last_n_tokens.Count - _embeds.Count).Skip(Context.ContextSize - n_left / 2 - _embeds.Count)); // stop saving session if we run out of context _pathSession = string.Empty; @@ -296,7 +292,7 @@ public virtual IEnumerable Infer(string text, IInferenceParams? inferenc if (args.ReturnValue) { - foreach (var item in _model.GenerateResult(_embeds)) + foreach (var item in Context.GenerateResult(_embeds)) { yield return item; } diff --git a/LLama/LLamaInstructExecutor.cs b/LLama/LLamaInstructExecutor.cs index 7a065ce52..6773cdde5 100644 --- a/LLama/LLamaInstructExecutor.cs +++ b/LLama/LLamaInstructExecutor.cs @@ -24,14 +24,14 @@ public class InstructExecutor : StatefulExecutorBase /// /// /// - /// + /// /// /// - public InstructExecutor(LLamaModel model, string instructionPrefix = "\n\n### Instruction:\n\n", - string instructionSuffix = "\n\n### Response:\n\n") : base(model) + public InstructExecutor(LLamaContext context, string instructionPrefix = "\n\n### Instruction:\n\n", + string instructionSuffix = "\n\n### Response:\n\n") : base(context) { - _inp_pfx = _model.Tokenize(instructionPrefix, true); - _inp_sfx = _model.Tokenize(instructionSuffix, false); + _inp_pfx = Context.Tokenize(instructionPrefix, true); + _inp_sfx = Context.Tokenize(instructionSuffix, false); _instructionPrefix = instructionPrefix; } @@ -117,7 +117,7 @@ protected override void PreprocessInputs(string text, InferStateArgs args) { // When running the first input (prompt) in inteactive mode, we should specially process it. text = " " + text; - _embed_inps = _model.Tokenize(text, true).ToList(); + _embed_inps = Context.Tokenize(text, true).ToList(); } else { @@ -128,7 +128,7 @@ protected override void PreprocessInputs(string text, InferStateArgs args) _consumedTokensCount = _embed_inps.Count; _embed_inps.AddRange(_inp_pfx); - var line_inp = _model.Tokenize(text, false); + var line_inp = Context.Tokenize(text, false); _embed_inps.AddRange(line_inp); _embed_inps.AddRange(_inp_sfx); @@ -146,7 +146,7 @@ protected override bool PostProcess(IInferenceParams inferenceParams, InferState { string last_output = ""; foreach (var id in _last_n_tokens) - last_output += _model.NativeHandle.TokenToString(id, _model.Encoding); + last_output += Context.NativeHandle.TokenToString(id, Context.Encoding); foreach (var antiprompt in args.Antiprompts) { @@ -183,13 +183,13 @@ protected override void InferInternal(IInferenceParams inferenceParams, InferSta if (_embeds.Count > 0) { _is_prompt_run = false; - if (_pastTokensCount + _embeds.Count > _model.ContextSize) + if (_pastTokensCount + _embeds.Count > Context.ContextSize) { HandleRunOutOfContext(inferenceParams.TokensKeep); } TryReuseMathingPrefix(); - _pastTokensCount = _model.Eval(_embeds.ToArray(), _pastTokensCount); + _pastTokensCount = Context.Eval(_embeds.ToArray(), _pastTokensCount); if (_embeds.Count > 0 && !string.IsNullOrEmpty(_pathSession)) { @@ -202,7 +202,7 @@ protected override void InferInternal(IInferenceParams inferenceParams, InferSta if (_embed_inps.Count <= _consumedTokensCount && !args.WaitForInput) { - var repeat_last_n = inferenceParams.RepeatLastTokensCount < 0 ? _model.ContextSize : inferenceParams.RepeatLastTokensCount; + var repeat_last_n = inferenceParams.RepeatLastTokensCount < 0 ? Context.ContextSize : inferenceParams.RepeatLastTokensCount; // optionally save the session on first sample (for faster prompt loading next time) if (!string.IsNullOrEmpty(_pathSession) && args.NeedToSaveSession) @@ -211,11 +211,11 @@ protected override void InferInternal(IInferenceParams inferenceParams, InferSta SaveSessionFile(_pathSession); } - var tokenDataArray = _model.ApplyPenalty(_last_n_tokens, inferenceParams.LogitBias, repeat_last_n, + var tokenDataArray = Context.ApplyPenalty(_last_n_tokens, inferenceParams.LogitBias, repeat_last_n, inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL); var mu = MirostatMu; - var id = _model.Sample( + var id = Context.Sample( tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau, inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP ); @@ -235,7 +235,7 @@ protected override void InferInternal(IInferenceParams inferenceParams, InferSta _embeds.Add(_embed_inps[_consumedTokensCount]); _last_n_tokens.Enqueue(_embed_inps[_consumedTokensCount]); _consumedTokensCount++; - if (_embeds.Count >= _model.Params.BatchSize) + if (_embeds.Count >= Context.Params.BatchSize) { break; } diff --git a/LLama/LLamaInteractExecutor.cs b/LLama/LLamaInteractExecutor.cs index 52d8d3bc4..4cf193db1 100644 --- a/LLama/LLamaInteractExecutor.cs +++ b/LLama/LLamaInteractExecutor.cs @@ -22,10 +22,10 @@ public class InteractiveExecutor : StatefulExecutorBase /// /// /// - /// - public InteractiveExecutor(LLamaModel model) : base(model) + /// + public InteractiveExecutor(LLamaContext context) : base(context) { - _llama_token_newline = _model.NativeHandle.Tokenize("\n", false, _model.Encoding); + _llama_token_newline = Context.NativeHandle.Tokenize("\n", false, Context.Encoding); } /// @@ -45,7 +45,7 @@ public override ExecutorBaseState GetStateData() SessionFilePath = _pathSession, SessionTokens = _session_tokens, LastTokensCapacity = _last_n_tokens.Capacity, - MirostatMu = MirostatMu + MirostateMu = MirostateMu }; return state; } @@ -104,7 +104,7 @@ protected override void PreprocessInputs(string text, InferStateArgs args) { // When running the first input (prompt) in inteactive mode, we should specially process it. text = " " + text; - _embed_inps = _model.Tokenize(text, true).ToList(); + _embed_inps = Context.Tokenize(text, true).ToList(); } else { @@ -112,7 +112,7 @@ protected override void PreprocessInputs(string text, InferStateArgs args) { text += "\n"; } - var line_inp = _model.Tokenize(text, false); + var line_inp = Context.Tokenize(text, false); _embed_inps.AddRange(line_inp); args.RemainedTokens -= line_inp.Length; } @@ -133,7 +133,7 @@ protected override bool PostProcess(IInferenceParams inferenceParams, InferState string last_output = ""; foreach (var id in _last_n_tokens) { - last_output += _model.NativeHandle.TokenToString(id, _model.Encoding); + last_output += Context.NativeHandle.TokenToString(id, Context.Encoding); } foreach (var antiprompt in args.Antiprompts) @@ -172,13 +172,13 @@ protected override void InferInternal(IInferenceParams inferenceParams, InferSta if (_embeds.Count > 0) { _is_prompt_run = false; - if (_pastTokensCount + _embeds.Count > _model.ContextSize) + if (_pastTokensCount + _embeds.Count > Context.ContextSize) { HandleRunOutOfContext(inferenceParams.TokensKeep); } TryReuseMathingPrefix(); - _pastTokensCount = _model.Eval(_embeds.ToArray(), _pastTokensCount); + _pastTokensCount = Context.Eval(_embeds.ToArray(), _pastTokensCount); if (_embeds.Count > 0 && !string.IsNullOrEmpty(_pathSession)) { @@ -191,7 +191,7 @@ protected override void InferInternal(IInferenceParams inferenceParams, InferSta if (_embed_inps.Count <= _consumedTokensCount && !args.WaitForInput) { - var repeat_last_n = inferenceParams.RepeatLastTokensCount < 0 ? _model.ContextSize : inferenceParams.RepeatLastTokensCount; + var repeat_last_n = inferenceParams.RepeatLastTokensCount < 0 ? Context.ContextSize : inferenceParams.RepeatLastTokensCount; // optionally save the session on first sample (for faster prompt loading next time) if (!string.IsNullOrEmpty(_pathSession) && args.NeedToSaveSession) @@ -200,15 +200,15 @@ protected override void InferInternal(IInferenceParams inferenceParams, InferSta SaveSessionFile(_pathSession); } - var tokenDataArray = _model.ApplyPenalty(_last_n_tokens, inferenceParams.LogitBias, repeat_last_n, + var tokenDataArray = Context.ApplyPenalty(_last_n_tokens, inferenceParams.LogitBias, repeat_last_n, inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL); - var mu = MirostatMu; - var id = _model.Sample( + var mu = MirostateMu; + var id = Context.Sample( tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau, inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP ); - MirostatMu = mu; + MirostateMu = mu; _last_n_tokens.Enqueue(id); @@ -217,7 +217,7 @@ protected override void InferInternal(IInferenceParams inferenceParams, InferSta id = _llama_token_newline.First(); if (args.Antiprompts is not null && args.Antiprompts.Count > 0) { - var first_antiprompt = _model.Tokenize(args.Antiprompts[0], false); + var first_antiprompt = Context.Tokenize(args.Antiprompts[0], false); _embed_inps.AddRange(first_antiprompt); } } @@ -234,7 +234,7 @@ protected override void InferInternal(IInferenceParams inferenceParams, InferSta _embeds.Add(_embed_inps[_consumedTokensCount]); _last_n_tokens.Enqueue(_embed_inps[_consumedTokensCount]); _consumedTokensCount++; - if (_embeds.Count >= _model.Params.BatchSize) + if (_embeds.Count >= Context.Params.BatchSize) { break; } diff --git a/LLama/LLamaStatelessExecutor.cs b/LLama/LLamaStatelessExecutor.cs index be32a5afc..f09ff7dd9 100644 --- a/LLama/LLamaStatelessExecutor.cs +++ b/LLama/LLamaStatelessExecutor.cs @@ -16,23 +16,23 @@ namespace LLama /// public class StatelessExecutor : ILLamaExecutor { - private LLamaModel _model; - private LLamaModel.State _originalState; + private LLamaContext _context; + private LLamaContext.State _originalState; /// - /// The mode used by the executor when running the inference. + /// The context used by the executor when running the inference. /// - public LLamaModel Model => _model; + public LLamaContext Context => _context; /// /// /// - /// The LLama model. - public StatelessExecutor(LLamaModel model) + /// The LLama model. + public StatelessExecutor(LLamaContext context) { - _model = model; + _context = context; - var tokens = model.Tokenize(" ", true).ToArray(); - _model.NativeHandle.Eval(tokens.AsMemory(0, tokens.Length), 0, _model.Params.Threads); - _originalState = model.GetState(); + var tokens = context.Tokenize(" ", true).ToArray(); + _context.NativeHandle.Eval(tokens.AsMemory(0, tokens.Length), 0, _context.Params.Threads); + _originalState = context.GetState(); } /// @@ -49,10 +49,10 @@ public IEnumerable Infer(string text, IInferenceParams? inferenceParams { lastTokens[i] = 0; } - List tokens = _model.Tokenize(text, true).ToList(); + List tokens = _context.Tokenize(text, true).ToList(); int n_prompt_tokens = tokens.Count; - _model.NativeHandle.Eval(tokens.ToArray().AsMemory(0, n_prompt_tokens), n_past, _model.Params.Threads); + _context.NativeHandle.Eval(tokens.ToArray().AsMemory(0, n_prompt_tokens), n_past, _context.Params.Threads); lastTokens.AddRange(tokens); n_past += n_prompt_tokens; @@ -63,20 +63,20 @@ public IEnumerable Infer(string text, IInferenceParams? inferenceParams { if (cancellationToken.IsCancellationRequested) { - _model.LoadState(_originalState); + _context.LoadState(_originalState); break; } - var repeat_last_n = inferenceParams.RepeatLastTokensCount < 0 ? _model.ContextSize : inferenceParams.RepeatLastTokensCount; + var repeat_last_n = inferenceParams.RepeatLastTokensCount < 0 ? _context.ContextSize : inferenceParams.RepeatLastTokensCount; - var tokenDataArray = _model.ApplyPenalty(lastTokens, inferenceParams.LogitBias, repeat_last_n, + var tokenDataArray = _context.ApplyPenalty(lastTokens, inferenceParams.LogitBias, repeat_last_n, inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL); - var id = _model.Sample(tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau, + var id = _context.Sample(tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau, inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP); lastTokens.Add(id); - string response = _model.NativeHandle.TokenToString(id, _model.Encoding); + string response = _context.NativeHandle.TokenToString(id, _context.Encoding); yield return response; tokens.Clear(); @@ -87,7 +87,7 @@ public IEnumerable Infer(string text, IInferenceParams? inferenceParams string last_output = ""; foreach (var token in lastTokens) { - last_output += _model.NativeHandle.TokenToString(token, _model.Encoding); + last_output += _context.NativeHandle.TokenToString(token, _context.Encoding); } bool should_break = false; @@ -106,20 +106,20 @@ public IEnumerable Infer(string text, IInferenceParams? inferenceParams } // when run out of context - if (n_past + tokens.Count > _model.ContextSize) + if (n_past + tokens.Count > _context.ContextSize) { int n_left = n_past - inferenceParams.TokensKeep; n_past = Math.Max(1, inferenceParams.TokensKeep); // insert n_left/2 tokens at the start of embed from last_n_tokens - tokens.InsertRange(0, lastTokens.Take(lastTokens.Count - tokens.Count).Skip(_model.ContextSize - n_left / 2 - tokens.Count)); + tokens.InsertRange(0, lastTokens.Take(lastTokens.Count - tokens.Count).Skip(_context.ContextSize - n_left / 2 - tokens.Count)); } - n_past = _model.Eval(tokens.ToArray(), n_past); + n_past = _context.Eval(tokens.ToArray(), n_past); } - _model.LoadState(_originalState); + _context.LoadState(_originalState); } /// diff --git a/LLama/Native/SafeLLamaContextHandle.cs b/LLama/Native/SafeLLamaContextHandle.cs index fa54f73ec..a6760b7d5 100644 --- a/LLama/Native/SafeLLamaContextHandle.cs +++ b/LLama/Native/SafeLLamaContextHandle.cs @@ -28,8 +28,10 @@ public class SafeLLamaContextHandle public int EmbeddingCount => ThrowIfDisposed().EmbeddingCount; /// - /// This field guarantees that a reference to the model is held for as long as this handle is held + /// Get the model which this context is using /// + public SafeLlamaModelHandle ModelHandle => ThrowIfDisposed(); + private SafeLlamaModelHandle? _model; #endregion @@ -55,7 +57,7 @@ protected override bool ReleaseHandle() { // Decrement refcount on model _model?.DangerousRelease(); - _model = null; + _model = null!; NativeApi.llama_free(handle); SetHandle(IntPtr.Zero); @@ -69,7 +71,7 @@ private SafeLlamaModelHandle ThrowIfDisposed() if (_model == null || _model.IsClosed) throw new ObjectDisposedException("Cannot use this `SafeLLamaContextHandle` - `SafeLlamaModelHandle` has been disposed"); - return _model; + return _model!; } /// diff --git a/LLama/Native/SafeLlamaModelHandle.cs b/LLama/Native/SafeLlamaModelHandle.cs index dbb1b0707..01d2ca8c6 100644 --- a/LLama/Native/SafeLlamaModelHandle.cs +++ b/LLama/Native/SafeLlamaModelHandle.cs @@ -157,5 +157,17 @@ public int[] Tokenize(string text, bool add_bos, Encoding encoding) } } #endregion + + #region context + /// + /// Create a new context for this model + /// + /// + /// + public SafeLLamaContextHandle CreateContext(LLamaContextParams @params) + { + return SafeLLamaContextHandle.Create(this, @params); + } + #endregion } } diff --git a/LLama/ResettableLLamaModel.cs b/LLama/ResettableLLamaModel.cs deleted file mode 100644 index d9b4e8226..000000000 --- a/LLama/ResettableLLamaModel.cs +++ /dev/null @@ -1,43 +0,0 @@ -using LLama.Abstractions; -using System; -using System.Collections.Generic; -using System.Text; - -namespace LLama -{ - /// - /// A LLamaModel what could be reset. Note that using this class will consume about 10% more memories. - /// - public class ResettableLLamaModel : LLamaModel - { - /// - /// The initial state of the model - /// - public State OriginalState { get; set; } - /// - /// - /// - /// - /// - public ResettableLLamaModel(IModelParams Params, string encoding = "UTF-8") : base(Params, encoding) - { - OriginalState = GetState(); - } - - /// - /// Reset the state to the initial state. - /// - public void Reset() - { - LoadState(OriginalState); - } - - /// - public override void Dispose() - { - OriginalState.Dispose(); - - base.Dispose(); - } - } -} diff --git a/LLama/Utils.cs b/LLama/Utils.cs index de363a3ed..42172737f 100644 --- a/LLama/Utils.cs +++ b/LLama/Utils.cs @@ -17,7 +17,7 @@ public static SafeLLamaContextHandle InitLLamaContextFromModelParams(IModelParam { var model = SafeLlamaModelHandle.LoadFromFile(@params.ModelPath, lparams); var ctx = SafeLLamaContextHandle.Create(model, lparams); - + if (!string.IsNullOrEmpty(@params.LoraAdapter)) model.ApplyLoraFromFile(@params.LoraAdapter, @params.LoraBase, @params.Threads); From f31bdf6b932cb88854cc21eab0ebafa71a79a5ba Mon Sep 17 00:00:00 2001 From: Martin Evans Date: Tue, 8 Aug 2023 01:19:45 +0100 Subject: [PATCH 02/13] Using the right context for Bob --- LLama.Examples/NewVersion/TalkToYourself.cs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/LLama.Examples/NewVersion/TalkToYourself.cs b/LLama.Examples/NewVersion/TalkToYourself.cs index 784d952a3..4a61f88f0 100644 --- a/LLama.Examples/NewVersion/TalkToYourself.cs +++ b/LLama.Examples/NewVersion/TalkToYourself.cs @@ -39,13 +39,13 @@ public static async Task Run() // Initial bob prompt var bobPrompt = $"Transcript of a dialog, where the Bob interacts a person named Alice. Bob is smart, intellectual and good at writing.\nAlice: Hello{aliceResponse}"; - var bobResponse = await Prompt(alice, ConsoleColor.Red, bobPrompt, true, true); + var bobResponse = await Prompt(bob, ConsoleColor.Red, bobPrompt, true, true); // swap back and forth from Alice to Bob while (true) { aliceResponse = await Prompt(alice, ConsoleColor.Green, bobResponse, false, true); - bobResponse = await Prompt(alice, ConsoleColor.Red, aliceResponse, false, true); + bobResponse = await Prompt(bob, ConsoleColor.Red, aliceResponse, false, true); Thread.Sleep(1000); } } From fda7e1c0389dbc982fee24a51ee7a0f51f24bd18 Mon Sep 17 00:00:00 2001 From: Martin Evans Date: Tue, 8 Aug 2023 14:51:07 +0100 Subject: [PATCH 03/13] Fixed mirostat/mirostate --- LLama/LLamaInteractExecutor.cs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/LLama/LLamaInteractExecutor.cs b/LLama/LLamaInteractExecutor.cs index 4cf193db1..533a18639 100644 --- a/LLama/LLamaInteractExecutor.cs +++ b/LLama/LLamaInteractExecutor.cs @@ -45,7 +45,7 @@ public override ExecutorBaseState GetStateData() SessionFilePath = _pathSession, SessionTokens = _session_tokens, LastTokensCapacity = _last_n_tokens.Capacity, - MirostateMu = MirostateMu + MirostatMu = MirostatMu }; return state; } @@ -203,12 +203,12 @@ protected override void InferInternal(IInferenceParams inferenceParams, InferSta var tokenDataArray = Context.ApplyPenalty(_last_n_tokens, inferenceParams.LogitBias, repeat_last_n, inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL); - var mu = MirostateMu; + var mu = MirostatMu; var id = Context.Sample( tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau, inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP ); - MirostateMu = mu; + MirostatMu = mu; _last_n_tokens.Enqueue(id); From e2fe08a9a2098ee7cd0de120d28477cb3d30744d Mon Sep 17 00:00:00 2001 From: Martin Evans Date: Tue, 8 Aug 2023 15:12:42 +0100 Subject: [PATCH 04/13] Added a higher level `LLamaWeights` wrapper around `SafeLlamaModelHandle` --- LLama.Examples/NewVersion/TalkToYourself.cs | 28 ++++------ LLama/LLamaContext.cs | 11 ++++ LLama/LLamaWeights.cs | 57 +++++++++++++++++++++ 3 files changed, 79 insertions(+), 17 deletions(-) create mode 100644 LLama/LLamaWeights.cs diff --git a/LLama.Examples/NewVersion/TalkToYourself.cs b/LLama.Examples/NewVersion/TalkToYourself.cs index 4a61f88f0..35a652417 100644 --- a/LLama.Examples/NewVersion/TalkToYourself.cs +++ b/LLama.Examples/NewVersion/TalkToYourself.cs @@ -2,8 +2,6 @@ using System.Text; using LLama.Abstractions; using LLama.Common; -using LLama.Extensions; -using LLama.Native; namespace LLama.Examples.NewVersion { @@ -12,26 +10,20 @@ public class TalkToYourself public static async Task Run() { Console.Write("Please input your model path: "); - string modelPath = "C:\\Users\\Martin\\Documents\\Python\\oobabooga_windows\\text-generation-webui\\models\\llama-2-7b-chat.ggmlv3.q6_K.bin"; + var modelPath = Console.ReadLine(); - // todo: model path is passed here, but isn't needed + // Load weights into memory var @params = new ModelParams(modelPath) { Seed = RandomNumberGenerator.GetInt32(int.MaxValue) }; + using var weights = LLamaWeights.LoadFromFile(@params); - // todo: all this pin stuff is ugly and should be hidden in the higher level wrapper - using var pin = @params.ToLlamaContextParams(out var lparams); - - // todo: we need a higher level wrapper around the model weights (LLamaWeights??) - var weights = SafeLlamaModelHandle.LoadFromFile(modelPath, lparams); - - // todo: need a method on the LLamaWeights which does this - var ctx1 = new LLamaContext(weights.CreateContext(lparams), @params, Encoding.UTF8); - var ctx2 = new LLamaContext(weights.CreateContext(lparams), @params, Encoding.UTF8); - - var alice = new InteractiveExecutor(ctx1); - var bob = new InteractiveExecutor(ctx2); + // Create 2 contexts sharing the same weights + using var aliceCtx = weights.CreateContext(@params, Encoding.UTF8); + var alice = new InteractiveExecutor(aliceCtx); + using var bobCtx = weights.CreateContext(@params, Encoding.UTF8); + var bob = new InteractiveExecutor(bobCtx); // Initial alice prompt var alicePrompt = "Transcript of a dialog, where the Alice interacts a person named Bob. Alice is friendly, kind, honest and good at writing.\nAlice: Hello"; @@ -46,7 +38,9 @@ public static async Task Run() { aliceResponse = await Prompt(alice, ConsoleColor.Green, bobResponse, false, true); bobResponse = await Prompt(bob, ConsoleColor.Red, aliceResponse, false, true); - Thread.Sleep(1000); + + if (Console.KeyAvailable) + break; } } diff --git a/LLama/LLamaContext.cs b/LLama/LLamaContext.cs index c6b8749a6..8eb2b9aa3 100644 --- a/LLama/LLamaContext.cs +++ b/LLama/LLamaContext.cs @@ -84,6 +84,17 @@ public LLamaContext(SafeLLamaContextHandle nativeContext, IModelParams @params, _ctx = nativeContext; } + public LLamaContext(LLamaWeights model, IModelParams @params, Encoding encoding, ILLamaLogger? logger = null) + { + Params = @params; + + _logger = logger; + _encoding = encoding; + + using var pin = @params.ToLlamaContextParams(out var lparams); + _ctx = SafeLLamaContextHandle.Create(model.NativeHandle, lparams); + } + /// /// Create a copy of the current state of this context /// diff --git a/LLama/LLamaWeights.cs b/LLama/LLamaWeights.cs new file mode 100644 index 000000000..be21c6f51 --- /dev/null +++ b/LLama/LLamaWeights.cs @@ -0,0 +1,57 @@ +using System; +using System.Text; +using LLama.Common; +using LLama.Extensions; +using LLama.Native; + +namespace LLama +{ + /// + /// A set of model weights, loaded into memory. + /// + public class LLamaWeights + : IDisposable + { + private readonly SafeLlamaModelHandle _weights; + + /// + /// The native handle, which is used in the native APIs + /// + /// Be careful how you use this! + public SafeLlamaModelHandle NativeHandle => _weights; + + private LLamaWeights(SafeLlamaModelHandle weights) + { + _weights = weights; + } + + /// + /// Load weights into memory + /// + /// + /// + public static LLamaWeights LoadFromFile(ModelParams @params) + { + using var pin = @params.ToLlamaContextParams(out var lparams); + var weights = SafeLlamaModelHandle.LoadFromFile(@params.ModelPath, lparams); + return new LLamaWeights(weights); + } + + /// + public void Dispose() + { + _weights.Dispose(); + } + + /// + /// Create a llama_context using this model + /// + /// + /// + /// + public LLamaContext CreateContext(ModelParams @params, Encoding utf8) + { + return new LLamaContext(this, @params, Encoding.UTF8); + } + } +} From 20bdc2ec6f9d32c3cb907d48461d94cd3a24cf4a Mon Sep 17 00:00:00 2001 From: Martin Evans Date: Tue, 8 Aug 2023 15:53:14 +0100 Subject: [PATCH 05/13] - Apply LoRA in `LLamaWeights.LoadFromFile` - Sanity checking that weights are not disposed when creating a context from them - Further simplified `Utils.InitLLamaContextFromModelParams` --- LLama/LLamaContext.cs | 6 ++++-- LLama/LLamaWeights.cs | 14 +++++++++----- LLama/Utils.cs | 12 +++--------- 3 files changed, 16 insertions(+), 16 deletions(-) diff --git a/LLama/LLamaContext.cs b/LLama/LLamaContext.cs index 8eb2b9aa3..f0817f8f9 100644 --- a/LLama/LLamaContext.cs +++ b/LLama/LLamaContext.cs @@ -74,8 +74,7 @@ public LLamaContext(IModelParams @params, string encoding = "UTF-8", ILLamaLogge _ctx = Utils.InitLLamaContextFromModelParams(Params); } - //todo make this internal - public LLamaContext(SafeLLamaContextHandle nativeContext, IModelParams @params, Encoding encoding, ILLamaLogger? logger = null) + internal LLamaContext(SafeLLamaContextHandle nativeContext, IModelParams @params, Encoding encoding, ILLamaLogger? logger = null) { Params = @params; @@ -86,6 +85,9 @@ public LLamaContext(SafeLLamaContextHandle nativeContext, IModelParams @params, public LLamaContext(LLamaWeights model, IModelParams @params, Encoding encoding, ILLamaLogger? logger = null) { + if (model.NativeHandle.IsClosed) + throw new ObjectDisposedException("Cannot create context, model weights have been disposed"); + Params = @params; _logger = logger; diff --git a/LLama/LLamaWeights.cs b/LLama/LLamaWeights.cs index be21c6f51..8226753f4 100644 --- a/LLama/LLamaWeights.cs +++ b/LLama/LLamaWeights.cs @@ -1,6 +1,6 @@ using System; using System.Text; -using LLama.Common; +using LLama.Abstractions; using LLama.Extensions; using LLama.Native; @@ -30,10 +30,14 @@ private LLamaWeights(SafeLlamaModelHandle weights) /// /// /// - public static LLamaWeights LoadFromFile(ModelParams @params) + public static LLamaWeights LoadFromFile(IModelParams @params) { using var pin = @params.ToLlamaContextParams(out var lparams); var weights = SafeLlamaModelHandle.LoadFromFile(@params.ModelPath, lparams); + + if (!string.IsNullOrEmpty(@params.LoraAdapter)) + weights.ApplyLoraFromFile(@params.LoraAdapter, @params.LoraBase, @params.Threads); + return new LLamaWeights(weights); } @@ -47,11 +51,11 @@ public void Dispose() /// Create a llama_context using this model /// /// - /// + /// /// - public LLamaContext CreateContext(ModelParams @params, Encoding utf8) + public LLamaContext CreateContext(IModelParams @params, Encoding encoding) { - return new LLamaContext(this, @params, Encoding.UTF8); + return new LLamaContext(this, @params, encoding); } } } diff --git a/LLama/Utils.cs b/LLama/Utils.cs index 42172737f..9f4fb3fa0 100644 --- a/LLama/Utils.cs +++ b/LLama/Utils.cs @@ -13,16 +13,10 @@ public static class Utils { public static SafeLLamaContextHandle InitLLamaContextFromModelParams(IModelParams @params) { - using (@params.ToLlamaContextParams(out var lparams)) - { - var model = SafeLlamaModelHandle.LoadFromFile(@params.ModelPath, lparams); - var ctx = SafeLLamaContextHandle.Create(model, lparams); - - if (!string.IsNullOrEmpty(@params.LoraAdapter)) - model.ApplyLoraFromFile(@params.LoraAdapter, @params.LoraBase, @params.Threads); + using var weights = LLamaWeights.LoadFromFile(@params); - return ctx; - } + using (@params.ToLlamaContextParams(out var lparams)) + return SafeLLamaContextHandle.Create(weights.NativeHandle, lparams); } [Obsolete("Use SafeLLamaContextHandle Tokenize method instead")] From 4d741d24f24aacbbfbe265c00fd600506a7e6cdb Mon Sep 17 00:00:00 2001 From: Martin Evans Date: Tue, 8 Aug 2023 15:57:08 +0100 Subject: [PATCH 06/13] Marked old `LLamaContext` constructor obsolete --- LLama/LLamaContext.cs | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/LLama/LLamaContext.cs b/LLama/LLamaContext.cs index f0817f8f9..4516517d7 100644 --- a/LLama/LLamaContext.cs +++ b/LLama/LLamaContext.cs @@ -63,6 +63,7 @@ public class LLamaContext /// Model params. /// Encoding to deal with text input. /// The logger. + [Obsolete("Use the LLamaWeights.CreateContext instead")] public LLamaContext(IModelParams @params, string encoding = "UTF-8", ILLamaLogger? logger = null) { Params = @params; @@ -83,6 +84,14 @@ internal LLamaContext(SafeLLamaContextHandle nativeContext, IModelParams @params _ctx = nativeContext; } + /// + /// Create a new LLamaContext for the given LLamaWeights + /// + /// + /// + /// + /// + /// public LLamaContext(LLamaWeights model, IModelParams @params, Encoding encoding, ILLamaLogger? logger = null) { if (model.NativeHandle.IsClosed) From d0a7a8fcd636493087d33e2bd20f9df8b3e5e7a7 Mon Sep 17 00:00:00 2001 From: Martin Evans Date: Wed, 9 Aug 2023 00:35:32 +0100 Subject: [PATCH 07/13] - Cleaned up disposal in LLamaContext - sealed some classes not intended to be extended --- LLama/LLamaContext.cs | 7 ++++--- LLama/LLamaWeights.cs | 2 +- LLama/Native/SafeLLamaContextHandle.cs | 2 +- LLama/Native/SafeLlamaModelHandle.cs | 2 +- 4 files changed, 7 insertions(+), 6 deletions(-) diff --git a/LLama/LLamaContext.cs b/LLama/LLamaContext.cs index 4516517d7..3e526749a 100644 --- a/LLama/LLamaContext.cs +++ b/LLama/LLamaContext.cs @@ -422,6 +422,8 @@ internal IEnumerable GenerateResult(IEnumerable ids) /// public virtual void Dispose() { + GC.SuppressFinalize(this); + _ctx.Dispose(); } @@ -429,12 +431,11 @@ public virtual void Dispose() /// The state of this model, which can be reloaded later /// public class State - : SafeHandleZeroOrMinusOneIsInvalid + : SafeLLamaHandleBase { internal State(IntPtr memory) - : base(true) + : base(memory) { - SetHandle(memory); } /// diff --git a/LLama/LLamaWeights.cs b/LLama/LLamaWeights.cs index 8226753f4..cb237a701 100644 --- a/LLama/LLamaWeights.cs +++ b/LLama/LLamaWeights.cs @@ -9,7 +9,7 @@ namespace LLama /// /// A set of model weights, loaded into memory. /// - public class LLamaWeights + public sealed class LLamaWeights : IDisposable { private readonly SafeLlamaModelHandle _weights; diff --git a/LLama/Native/SafeLLamaContextHandle.cs b/LLama/Native/SafeLLamaContextHandle.cs index a6760b7d5..49a31e65e 100644 --- a/LLama/Native/SafeLLamaContextHandle.cs +++ b/LLama/Native/SafeLLamaContextHandle.cs @@ -8,7 +8,7 @@ namespace LLama.Native /// /// A safe wrapper around a llama_context /// - public class SafeLLamaContextHandle + public sealed class SafeLLamaContextHandle : SafeLLamaHandleBase { #region properties and fields diff --git a/LLama/Native/SafeLlamaModelHandle.cs b/LLama/Native/SafeLlamaModelHandle.cs index 01d2ca8c6..182d503f8 100644 --- a/LLama/Native/SafeLlamaModelHandle.cs +++ b/LLama/Native/SafeLlamaModelHandle.cs @@ -7,7 +7,7 @@ namespace LLama.Native /// /// A reference to a set of llama model weights /// - public class SafeLlamaModelHandle + public sealed class SafeLlamaModelHandle : SafeLLamaHandleBase { /// From 479ff57853f7b893ea018f36283ef083057d6c79 Mon Sep 17 00:00:00 2001 From: Martin Evans Date: Wed, 9 Aug 2023 00:41:52 +0100 Subject: [PATCH 08/13] Renamed `EmbeddingCount` to `EmbeddingSize` --- LLama/LLamaContext.cs | 2 +- LLama/Native/SafeLLamaContextHandle.cs | 2 +- LLama/Native/SafeLlamaModelHandle.cs | 4 ++-- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/LLama/LLamaContext.cs b/LLama/LLamaContext.cs index 3e526749a..c8927b9d2 100644 --- a/LLama/LLamaContext.cs +++ b/LLama/LLamaContext.cs @@ -39,7 +39,7 @@ public class LLamaContext /// /// Dimension of embedding vectors /// - public int EmbeddingCount => _ctx.EmbeddingCount; + public int EmbeddingCount => _ctx.EmbeddingSize; /// /// The model params set for this model. diff --git a/LLama/Native/SafeLLamaContextHandle.cs b/LLama/Native/SafeLLamaContextHandle.cs index 49a31e65e..04663d77b 100644 --- a/LLama/Native/SafeLLamaContextHandle.cs +++ b/LLama/Native/SafeLLamaContextHandle.cs @@ -25,7 +25,7 @@ public sealed class SafeLLamaContextHandle /// /// Dimension of embedding vectors /// - public int EmbeddingCount => ThrowIfDisposed().EmbeddingCount; + public int EmbeddingSize => ThrowIfDisposed().EmbeddingSize; /// /// Get the model which this context is using diff --git a/LLama/Native/SafeLlamaModelHandle.cs b/LLama/Native/SafeLlamaModelHandle.cs index 182d503f8..4dadee500 100644 --- a/LLama/Native/SafeLlamaModelHandle.cs +++ b/LLama/Native/SafeLlamaModelHandle.cs @@ -23,14 +23,14 @@ public sealed class SafeLlamaModelHandle /// /// Dimension of embedding vectors /// - public int EmbeddingCount { get; } + public int EmbeddingSize { get; } internal SafeLlamaModelHandle(IntPtr handle) : base(handle) { VocabCount = NativeApi.llama_n_vocab_from_model(this); ContextSize = NativeApi.llama_n_ctx_from_model(this); - EmbeddingCount = NativeApi.llama_n_embd_from_model(this); + EmbeddingSize = NativeApi.llama_n_embd_from_model(this); } /// From f5a260926f2f1f872e9f3b9bed36641865a15842 Mon Sep 17 00:00:00 2001 From: Martin Evans Date: Wed, 9 Aug 2023 00:56:47 +0100 Subject: [PATCH 09/13] Renamed `EmbeddingCount` to `EmbeddingSize` in higher level class --- LLama/LLamaContext.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/LLama/LLamaContext.cs b/LLama/LLamaContext.cs index c8927b9d2..1ef2a8db5 100644 --- a/LLama/LLamaContext.cs +++ b/LLama/LLamaContext.cs @@ -39,7 +39,7 @@ public class LLamaContext /// /// Dimension of embedding vectors /// - public int EmbeddingCount => _ctx.EmbeddingSize; + public int EmbeddingSize => _ctx.EmbeddingSize; /// /// The model params set for this model. From 1b35be2e0cc2168b27949ddbf3a7aef9f5aa2f61 Mon Sep 17 00:00:00 2001 From: Martin Evans Date: Wed, 9 Aug 2023 01:07:42 +0100 Subject: [PATCH 10/13] Added some additional basic tests --- LLama.Unittest/BasicTest.cs | 1 - LLama.Unittest/LLamaContextTests.cs | 36 +++++++++++++++++++++++ LLama.Unittest/LLamaEmbedderTests.cs | 44 ++++++++++++++++++++++++++++ LLama/LLamaEmbedder.cs | 12 +++++--- 4 files changed, 88 insertions(+), 5 deletions(-) create mode 100644 LLama.Unittest/LLamaContextTests.cs create mode 100644 LLama.Unittest/LLamaEmbedderTests.cs diff --git a/LLama.Unittest/BasicTest.cs b/LLama.Unittest/BasicTest.cs index 5a07c86cc..2a88f25dd 100644 --- a/LLama.Unittest/BasicTest.cs +++ b/LLama.Unittest/BasicTest.cs @@ -1,4 +1,3 @@ -using LLama; using LLama.Common; namespace LLama.Unittest diff --git a/LLama.Unittest/LLamaContextTests.cs b/LLama.Unittest/LLamaContextTests.cs new file mode 100644 index 000000000..a34f58cb3 --- /dev/null +++ b/LLama.Unittest/LLamaContextTests.cs @@ -0,0 +1,36 @@ +using System.Text; +using LLama.Common; + +namespace LLama.Unittest +{ + public class LLamaContextTests + : IDisposable + { + private readonly LLamaWeights _weights; + private readonly LLamaContext _context; + + public LLamaContextTests() + { + var @params = new ModelParams("Models/llama-2-7b-chat.ggmlv3.q3_K_S.bin") + { + ContextSize = 768, + }; + _weights = LLamaWeights.LoadFromFile(@params); + _context = _weights.CreateContext(@params, Encoding.UTF8); + } + + public void Dispose() + { + _weights.Dispose(); + _context.Dispose(); + } + + [Fact] + public void CheckProperties() + { + Assert.Equal(768, _context.ContextSize); + Assert.Equal(4096, _context.EmbeddingSize); + Assert.Equal(32000, _context.VocabCount); + } + } +} diff --git a/LLama.Unittest/LLamaEmbedderTests.cs b/LLama.Unittest/LLamaEmbedderTests.cs new file mode 100644 index 000000000..1c4b9fd7f --- /dev/null +++ b/LLama.Unittest/LLamaEmbedderTests.cs @@ -0,0 +1,44 @@ +using LLama.Common; + +namespace LLama.Unittest +{ + public class LLamaEmbedderTests + : IDisposable + { + private readonly LLamaEmbedder _embedder = new(new ModelParams("Models/llama-2-7b-chat.ggmlv3.q3_K_S.bin")); + + public void Dispose() + { + _embedder.Dispose(); + } + + private static float Dot(float[] a, float[] b) + { + Assert.Equal(a.Length, b.Length); + return a.Zip(b, (x, y) => x + y).Sum(); + } + + [Fact] + public void EmbedHello() + { + var hello = _embedder.GetEmbeddings("Hello"); + + Assert.NotNull(hello); + Assert.NotEmpty(hello); + Assert.Equal(_embedder.EmbeddingSize, hello.Length); + } + + [Fact] + public void EmbedCompare() + { + var cat = _embedder.GetEmbeddings("cat"); + var kitten = _embedder.GetEmbeddings("kitten"); + var spoon = _embedder.GetEmbeddings("spoon"); + + var close = Dot(cat, kitten); + var far = Dot(cat, spoon); + + Assert.True(close < far); + } + } +} diff --git a/LLama/LLamaEmbedder.cs b/LLama/LLamaEmbedder.cs index a74f11ee2..6b82c4d8c 100644 --- a/LLama/LLamaEmbedder.cs +++ b/LLama/LLamaEmbedder.cs @@ -1,9 +1,7 @@ using LLama.Native; using System; -using System.Collections.Generic; using System.Text; using LLama.Exceptions; -using System.Linq; using LLama.Abstractions; namespace LLama @@ -11,9 +9,15 @@ namespace LLama /// /// The embedder for LLama, which supports getting embeddings from text. /// - public class LLamaEmbedder : IDisposable + public class LLamaEmbedder + : IDisposable { - SafeLLamaContextHandle _ctx; + private readonly SafeLLamaContextHandle _ctx; + + /// + /// Dimension of embedding vectors + /// + public int EmbeddingSize => _ctx.EmbeddingSize; /// /// Warning: must ensure the original model has params.embedding = true; From 6473f8d5e518193d4ee4e415bf9c4cb5ad529a85 Mon Sep 17 00:00:00 2001 From: Martin Evans Date: Thu, 10 Aug 2023 17:51:51 +0100 Subject: [PATCH 11/13] Temporarily added a `Console.WriteLine` into the test, to print the embedding vector for "cat" in CI --- LLama.Unittest/LLamaEmbedderTests.cs | 2 ++ LLama/LLamaEmbedder.cs | 4 ++-- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/LLama.Unittest/LLamaEmbedderTests.cs b/LLama.Unittest/LLamaEmbedderTests.cs index 1c4b9fd7f..773153109 100644 --- a/LLama.Unittest/LLamaEmbedderTests.cs +++ b/LLama.Unittest/LLamaEmbedderTests.cs @@ -35,6 +35,8 @@ public void EmbedCompare() var kitten = _embedder.GetEmbeddings("kitten"); var spoon = _embedder.GetEmbeddings("spoon"); + Console.WriteLine(string.Join(",", cat)); + var close = Dot(cat, kitten); var far = Dot(cat, spoon); diff --git a/LLama/LLamaEmbedder.cs b/LLama/LLamaEmbedder.cs index 6b82c4d8c..5acf756bf 100644 --- a/LLama/LLamaEmbedder.cs +++ b/LLama/LLamaEmbedder.cs @@ -53,7 +53,7 @@ public unsafe float[] GetEmbeddings(string text, int threads = -1, bool addBos = { threads = Math.Max(Environment.ProcessorCount / 2, 1); } - int n_past = 0; + if (addBos) { text = text.Insert(0, " "); @@ -65,7 +65,7 @@ public unsafe float[] GetEmbeddings(string text, int threads = -1, bool addBos = if (embed_inp_array.Length > 0) { - if (NativeApi.llama_eval(_ctx, embed_inp_array, embed_inp_array.Length, n_past, threads) != 0) + if (NativeApi.llama_eval(_ctx, embed_inp_array, embed_inp_array.Length, 0, threads) != 0) { throw new RuntimeError("Failed to eval."); } From aeb7943710a36026745113bc5423c3f26b86f635 Mon Sep 17 00:00:00 2001 From: Martin Evans Date: Thu, 10 Aug 2023 17:59:11 +0100 Subject: [PATCH 12/13] Removed Console.WriteLine --- LLama.Unittest/LLamaEmbedderTests.cs | 2 -- 1 file changed, 2 deletions(-) diff --git a/LLama.Unittest/LLamaEmbedderTests.cs b/LLama.Unittest/LLamaEmbedderTests.cs index 773153109..1c4b9fd7f 100644 --- a/LLama.Unittest/LLamaEmbedderTests.cs +++ b/LLama.Unittest/LLamaEmbedderTests.cs @@ -35,8 +35,6 @@ public void EmbedCompare() var kitten = _embedder.GetEmbeddings("kitten"); var spoon = _embedder.GetEmbeddings("spoon"); - Console.WriteLine(string.Join(",", cat)); - var close = Dot(cat, kitten); var far = Dot(cat, spoon); From 76d991f376a1bf42bc5bf40888763c05a6d76918 Mon Sep 17 00:00:00 2001 From: Martin Evans Date: Thu, 10 Aug 2023 18:09:01 +0100 Subject: [PATCH 13/13] Removed embedding test, moved to another PR --- LLama.Unittest/LLamaEmbedderTests.cs | 44 ---------------------------- 1 file changed, 44 deletions(-) delete mode 100644 LLama.Unittest/LLamaEmbedderTests.cs diff --git a/LLama.Unittest/LLamaEmbedderTests.cs b/LLama.Unittest/LLamaEmbedderTests.cs deleted file mode 100644 index 1c4b9fd7f..000000000 --- a/LLama.Unittest/LLamaEmbedderTests.cs +++ /dev/null @@ -1,44 +0,0 @@ -using LLama.Common; - -namespace LLama.Unittest -{ - public class LLamaEmbedderTests - : IDisposable - { - private readonly LLamaEmbedder _embedder = new(new ModelParams("Models/llama-2-7b-chat.ggmlv3.q3_K_S.bin")); - - public void Dispose() - { - _embedder.Dispose(); - } - - private static float Dot(float[] a, float[] b) - { - Assert.Equal(a.Length, b.Length); - return a.Zip(b, (x, y) => x + y).Sum(); - } - - [Fact] - public void EmbedHello() - { - var hello = _embedder.GetEmbeddings("Hello"); - - Assert.NotNull(hello); - Assert.NotEmpty(hello); - Assert.Equal(_embedder.EmbeddingSize, hello.Length); - } - - [Fact] - public void EmbedCompare() - { - var cat = _embedder.GetEmbeddings("cat"); - var kitten = _embedder.GetEmbeddings("kitten"); - var spoon = _embedder.GetEmbeddings("spoon"); - - var close = Dot(cat, kitten); - var far = Dot(cat, spoon); - - Assert.True(close < far); - } - } -}