diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml new file mode 100644 index 000000000..b00368fb0 --- /dev/null +++ b/.github/workflows/main.yml @@ -0,0 +1,55 @@ +name: CI +on: + push: + branches: [master] + pull_request: + branches: [master] + +jobs: + build: + name: Test + runs-on: ${{ matrix.os }} + strategy: + fail-fast: false + matrix: + build: [linux-debug, linux-release, macos-debug, macos-release, windows-debug, windows-release] + include: + - build: linux-debug + os: ubuntu-latest + config: debug + - build: linux-release + os: ubuntu-latest + config: release + - build: macos-debug + os: macos-latest + config: debug + - build: macos-release + os: macos-latest + config: release + - build: windows-debug + os: windows-2019 + config: debug + - build: windows-release + os: windows-2019 + config: release + steps: + - uses: actions/checkout@v2 + - uses: actions/setup-dotnet@v1 + with: + dotnet-version: | + 6.0.x + 7.0.x + - name: Cache Gradle packages + uses: actions/cache@v3 + with: + key: "unit_test_models" + path: LLama.Unittest/Models + # workaround for actions/setup-dotnet#155 + - name: Clear package cache + run: dotnet clean LLamaSharp.sln && dotnet nuget locals all --clear + - name: Restore packages + run: dotnet restore LLamaSharp.sln + - name: Build + run: dotnet build LLamaSharp.sln -c ${{ matrix.config }} --no-restore + - name: Test + run: dotnet test LLamaSharp.sln -c ${{ matrix.config }} diff --git a/.gitignore b/.gitignore index d1d0ba40b..2f38fac02 100644 --- a/.gitignore +++ b/.gitignore @@ -341,4 +341,5 @@ test/TensorFlowNET.Examples/mnist *.xsd # docs -site/ \ No newline at end of file +site/ +/LLama.Unittest/Models/*.bin diff --git a/LLama.Examples/NewVersion/ChatSessionStripRoleName.cs b/LLama.Examples/NewVersion/ChatSessionStripRoleName.cs index ce677c40c..6a8e81c93 100644 --- a/LLama.Examples/NewVersion/ChatSessionStripRoleName.cs +++ b/LLama.Examples/NewVersion/ChatSessionStripRoleName.cs @@ -14,7 +14,8 @@ public static void Run() Console.Write("Please input your model path: "); string modelPath = Console.ReadLine(); var prompt = File.ReadAllText("Assets/chat-with-bob.txt").Trim(); - InteractiveExecutor ex = new(new LLamaModel(new ModelParams(modelPath, contextSize: 1024, seed: 1337, gpuLayerCount: 5))); + LLamaModel model = new LLamaModel(new ModelParams(modelPath, contextSize: 1024, seed: 1337, gpuLayerCount: 5)); + InteractiveExecutor ex = new(new LLamaModelContext(model)); ChatSession session = new ChatSession(ex).WithOutputTransform(new LLamaTransforms.KeywordTextOutputStreamTransform(new string[] { "User:", "Bob:" }, redundancyLength: 8)); Console.ForegroundColor = ConsoleColor.Yellow; diff --git a/LLama.Examples/NewVersion/ChatSessionWithRoleName.cs b/LLama.Examples/NewVersion/ChatSessionWithRoleName.cs index cbf9333b9..3634d6e55 100644 --- a/LLama.Examples/NewVersion/ChatSessionWithRoleName.cs +++ b/LLama.Examples/NewVersion/ChatSessionWithRoleName.cs @@ -14,7 +14,8 @@ public static void Run() Console.Write("Please input your model path: "); string modelPath = Console.ReadLine(); var prompt = File.ReadAllText("Assets/chat-with-bob.txt").Trim(); - InteractiveExecutor ex = new(new LLamaModel(new ModelParams(modelPath, contextSize: 1024, seed: 1337, gpuLayerCount: 5))); + LLamaModel model = new LLamaModel(new ModelParams(modelPath, contextSize: 1024, seed: 1337, gpuLayerCount: 5)); + InteractiveExecutor ex = new(new LLamaModelContext(model)); ChatSession session = new ChatSession(ex); // The only change is to remove the transform for the 
output text stream. Console.ForegroundColor = ConsoleColor.Yellow; diff --git a/LLama.Examples/NewVersion/InstructModeExecute.cs b/LLama.Examples/NewVersion/InstructModeExecute.cs index 303c8644b..97ce002d3 100644 --- a/LLama.Examples/NewVersion/InstructModeExecute.cs +++ b/LLama.Examples/NewVersion/InstructModeExecute.cs @@ -15,7 +15,8 @@ public static void Run() string modelPath = Console.ReadLine(); var prompt = File.ReadAllText("Assets/dan.txt").Trim(); - InstructExecutor ex = new(new LLamaModel(new ModelParams(modelPath, contextSize: 1024))); + LLamaModel model = new LLamaModel(new ModelParams(modelPath, contextSize: 1024)); + InstructExecutor ex = new(new LLamaModelContext(model)); Console.ForegroundColor = ConsoleColor.Yellow; Console.WriteLine("The executor has been enabled. In this example, the LLM will follow your instructions. For example, you can input \"Write a story about a fox who want to " + diff --git a/LLama.Examples/NewVersion/InteractiveModeExecute.cs b/LLama.Examples/NewVersion/InteractiveModeExecute.cs index 23afcadf6..f2cef5c3c 100644 --- a/LLama.Examples/NewVersion/InteractiveModeExecute.cs +++ b/LLama.Examples/NewVersion/InteractiveModeExecute.cs @@ -15,7 +15,8 @@ public async static Task Run() string modelPath = Console.ReadLine(); var prompt = File.ReadAllText("Assets/chat-with-bob.txt").Trim(); - InteractiveExecutor ex = new(new LLamaModel(new ModelParams(modelPath, contextSize: 256))); + LLamaModel model = new LLamaModel(new ModelParams(modelPath, contextSize: 256)); + InteractiveExecutor ex = new(new LLamaModelContext(model)); Console.ForegroundColor = ConsoleColor.Yellow; Console.WriteLine("The executor has been enabled. In this example, the prompt is printed, the maximum tokens is set to 128 and the context size is 256. (an example for small scale usage)"); diff --git a/LLama.Examples/NewVersion/LoadAndSaveSession.cs b/LLama.Examples/NewVersion/LoadAndSaveSession.cs index 722ec3e0e..a2f00c758 100644 --- a/LLama.Examples/NewVersion/LoadAndSaveSession.cs +++ b/LLama.Examples/NewVersion/LoadAndSaveSession.cs @@ -15,7 +15,8 @@ public static void Run() Console.Write("Please input your model path: "); string modelPath = Console.ReadLine(); var prompt = File.ReadAllText("Assets/chat-with-bob.txt").Trim(); - InteractiveExecutor ex = new(new LLamaModel(new ModelParams(modelPath, contextSize: 1024, seed: 1337, gpuLayerCount: 5))); + LLamaModel model = new LLamaModel(new ModelParams(modelPath, contextSize: 1024, seed: 1337, gpuLayerCount: 5)); + InteractiveExecutor ex = new(new LLamaModelContext(model)); ChatSession session = new ChatSession(ex); // The only change is to remove the transform for the output text stream. 
Console.ForegroundColor = ConsoleColor.Yellow; @@ -45,8 +46,10 @@ public static void Run() Console.WriteLine("Saved session!"); Console.ForegroundColor = ConsoleColor.White; - ex.Model.Dispose(); - ex = new(new LLamaModel(new ModelParams(modelPath, contextSize: 1024, seed: 1337, gpuLayerCount: 5))); + ex.Context.Dispose(); + + //LLamaModel model = new LLamaModel(new ModelParams(modelPath, contextSize: 1024, seed: 1337, gpuLayerCount: 5)); + ex = new(new LLamaModelContext(model)); session = new ChatSession(ex); session.LoadSession(statePath); diff --git a/LLama.Examples/NewVersion/LoadAndSaveState.cs b/LLama.Examples/NewVersion/LoadAndSaveState.cs index dc3031414..df2998a27 100644 --- a/LLama.Examples/NewVersion/LoadAndSaveState.cs +++ b/LLama.Examples/NewVersion/LoadAndSaveState.cs @@ -15,7 +15,8 @@ public static void Run() string modelPath = Console.ReadLine(); var prompt = File.ReadAllText("Assets/chat-with-bob.txt").Trim(); - InteractiveExecutor ex = new(new LLamaModel(new ModelParams(modelPath, contextSize: 256))); + LLamaModel model = new LLamaModel(new ModelParams(modelPath, contextSize: 256)); + InteractiveExecutor ex = new(new LLamaModelContext(model)); Console.ForegroundColor = ConsoleColor.Yellow; Console.WriteLine("The executor has been enabled. In this example, the prompt is printed, the maximum tokens is set to 64 and the context size is 256. (an example for small scale usage)"); @@ -37,7 +38,7 @@ public static void Run() { Console.Write("Your path to save model state: "); string modelStatePath = Console.ReadLine(); - ex.Model.SaveState(modelStatePath); + ex.Context.SaveState(modelStatePath); Console.Write("Your path to save executor state: "); string executorStatePath = Console.ReadLine(); @@ -47,9 +48,9 @@ public static void Run() Console.WriteLine("All states saved!"); Console.ForegroundColor = ConsoleColor.White; - var model = ex.Model; - model.LoadState(modelStatePath); - ex = new InteractiveExecutor(model); + var context = ex.Context; + context.LoadState(modelStatePath); + ex = new InteractiveExecutor(context); ex.LoadState(executorStatePath); Console.ForegroundColor = ConsoleColor.Yellow; Console.WriteLine("Loaded state!"); diff --git a/LLama.Examples/NewVersion/StatelessModeExecute.cs b/LLama.Examples/NewVersion/StatelessModeExecute.cs index 3c4852319..6e7268cd0 100644 --- a/LLama.Examples/NewVersion/StatelessModeExecute.cs +++ b/LLama.Examples/NewVersion/StatelessModeExecute.cs @@ -14,7 +14,8 @@ public static void Run() Console.Write("Please input your model path: "); string modelPath = Console.ReadLine(); - StatelessExecutor ex = new(new LLamaModel(new ModelParams(modelPath, contextSize: 256))); + LLamaModel model = new LLamaModel(new ModelParams(modelPath, contextSize: 256)); + StatelessExecutor ex = new(new LLamaModelContext(model)); Console.ForegroundColor = ConsoleColor.Yellow; Console.WriteLine("The executor has been enabled. In this example, the inference is an one-time job. 
That says, the previous input and response has " + diff --git a/LLama.Unittest/BasicTest.cs b/LLama.Unittest/BasicTest.cs index 291784329..fd602cf03 100644 --- a/LLama.Unittest/BasicTest.cs +++ b/LLama.Unittest/BasicTest.cs @@ -1,11 +1,14 @@ +using LLama.Common; + namespace LLama.Unittest { public class BasicTest { [Fact] - public void SimpleQA() + public void LoadModel() { - + var model = new LLamaModel(new ModelParams("Models/llama-2-7b-chat.ggmlv3.q3_K_S.bin", contextSize: 256)); + model.Dispose(); } } } \ No newline at end of file diff --git a/LLama.Unittest/LLama.Unittest.csproj b/LLama.Unittest/LLama.Unittest.csproj index 93922e813..65dbd1616 100644 --- a/LLama.Unittest/LLama.Unittest.csproj +++ b/LLama.Unittest/LLama.Unittest.csproj @@ -23,8 +23,23 @@ + + + + + + + + + + + + PreserveNewest + + + diff --git a/LLama.Web/Common/LLamaOptions.cs b/LLama.Web/Common/LLamaOptions.cs index 1ac0d829f..2348dd133 100644 --- a/LLama.Web/Common/LLamaOptions.cs +++ b/LLama.Web/Common/LLamaOptions.cs @@ -3,18 +3,9 @@ public class LLamaOptions { public List Models { get; set; } - public List Prompts { get; set; } = new List(); - public List Parameters { get; set; } = new List(); public void Initialize() { - foreach (var prompt in Prompts) - { - if (File.Exists(prompt.Path)) - { - prompt.Prompt = File.ReadAllText(prompt.Path).Trim(); - } - } } } } diff --git a/LLama.Web/Common/ModelOptions.cs b/LLama.Web/Common/ModelOptions.cs index 4f8f97e61..744b54a76 100644 --- a/LLama.Web/Common/ModelOptions.cs +++ b/LLama.Web/Common/ModelOptions.cs @@ -1,15 +1,32 @@ -using LLama.Common; +using LLama.Abstractions; namespace LLama.Web.Common { - public class ModelOptions : ModelParams + public class ModelOptions : IModelParams { - public ModelOptions() : base("", 512, 20, 1337, true, true, false, false, "", "", -1, 512, false, false) - { - } - - public string Name { get; set; } public int MaxInstances { get; set; } + public string Name { get; set; } = "unknown"; + public int ContextSize { get; set; } = 512; + public int MainGpu { get; set; } = 0; + public bool LowVram { get; set; } = false; + public int GpuLayerCount { get; set; } = 20; + public int Seed { get; set; } = 1686349486; + public bool UseFp16Memory { get; set; } = true; + public bool UseMemorymap { get; set; } = true; + public bool UseMemoryLock { get; set; } = false; + public bool Perplexity { get; set; } = false; + public string ModelPath { get; set; } + public string LoraAdapter { get; set; } = string.Empty; + public string LoraBase { get; set; } = string.Empty; + public int Threads { get; set; } = Math.Max(Environment.ProcessorCount / 2, 1); + public int BatchSize { get; set; } = 512; + public bool ConvertEosToNewLine { get; set; } = false; + public bool EmbeddingMode { get; set; } = false; + public float[] TensorSplits { get; set; } = new float[] { 0 }; + public int GroupedQueryAttention { get; set; } = 1; + public float RmsNormEpsilon { get; set; } = 5e-6f; + public float RopeFrequencyBase { get; set; } = 10000.0f; + public float RopeFrequencyScale { get; set; } = 1.0f; } } diff --git a/LLama.Web/Common/ParameterOptions.cs b/LLama.Web/Common/ParameterOptions.cs deleted file mode 100644 index 3cdd37019..000000000 --- a/LLama.Web/Common/ParameterOptions.cs +++ /dev/null @@ -1,9 +0,0 @@ -using LLama.Common; - -namespace LLama.Web.Common -{ - public class ParameterOptions : InferenceParams - { - public string Name { get; set; } - } -} diff --git a/LLama.Web/Hubs/SessionConnectionHub.cs b/LLama.Web/Hubs/SessionConnectionHub.cs index 
080866c6b..efaf28c88 100644 --- a/LLama.Web/Hubs/SessionConnectionHub.cs +++ b/LLama.Web/Hubs/SessionConnectionHub.cs @@ -2,7 +2,6 @@ using LLama.Web.Models; using LLama.Web.Services; using Microsoft.AspNetCore.SignalR; -using System.Diagnostics; namespace LLama.Web.Hubs { @@ -38,15 +37,13 @@ public override async Task OnDisconnectedAsync(Exception? exception) [HubMethodName("LoadModel")] - public async Task OnLoadModel(LLamaExecutorType executorType, string modelName, string promptName, string parameterName) + public async Task OnLoadModel(CreateSessionModel sessionModel) { - _logger.Log(LogLevel.Information, "[OnLoadModel] - Load new model, Connection: {0}, Model: {1}, Prompt: {2}, Parameter: {3}", Context.ConnectionId, modelName, promptName, parameterName); - - // Remove existing connections session - await _modelSessionService.RemoveAsync(Context.ConnectionId); + _logger.Log(LogLevel.Information, "[OnLoadModel] - Load new model, Connection: {0}", Context.ConnectionId); + // Create model session - var modelSessionResult = await _modelSessionService.CreateAsync(executorType, Context.ConnectionId, modelName, promptName, parameterName); + var modelSessionResult = await _modelSessionService.CreateAsync(Context.ConnectionId, sessionModel); if (modelSessionResult.HasError) { await Clients.Caller.OnError(modelSessionResult.Error); @@ -63,35 +60,11 @@ public async Task OnSendPrompt(string prompt) { _logger.Log(LogLevel.Information, "[OnSendPrompt] - New prompt received, Connection: {0}", Context.ConnectionId); - // Get connections session - var modelSession = await _modelSessionService.GetAsync(Context.ConnectionId); - if (modelSession is null) - { - await Clients.Caller.OnError("No model has been loaded"); - return; - } - - - // Create unique response id - var responseId = Guid.NewGuid().ToString(); - - // Send begin of response - await Clients.Caller.OnResponse(new ResponseFragment(responseId, isFirst: true)); - - // Send content of response - var stopwatch = Stopwatch.GetTimestamp(); - await foreach (var fragment in modelSession.InferAsync(prompt, CancellationTokenSource.CreateLinkedTokenSource(Context.ConnectionAborted))) + // Send Infer response + await foreach (var responseFragment in _modelSessionService.InferAsync(Context.ConnectionId, prompt, CancellationTokenSource.CreateLinkedTokenSource(Context.ConnectionAborted))) { - await Clients.Caller.OnResponse(new ResponseFragment(responseId, fragment)); + await Clients.Caller.OnResponse(responseFragment); } - - // Send end of response - var elapsedTime = Stopwatch.GetElapsedTime(stopwatch); - var signature = modelSession.IsInferCanceled() - ? 
$"Inference cancelled after {elapsedTime.TotalSeconds:F0} seconds" - : $"Inference completed in {elapsedTime.TotalSeconds:F0} seconds"; - await Clients.Caller.OnResponse(new ResponseFragment(responseId, signature, isLast: true)); - _logger.Log(LogLevel.Information, "[OnSendPrompt] - Inference complete, Connection: {0}, Elapsed: {1}, Canceled: {2}", Context.ConnectionId, elapsedTime, modelSession.IsInferCanceled()); } } diff --git a/LLama.Web/Models/CreateSessionModel.cs b/LLama.Web/Models/CreateSessionModel.cs new file mode 100644 index 000000000..aec93daf3 --- /dev/null +++ b/LLama.Web/Models/CreateSessionModel.cs @@ -0,0 +1,39 @@ +using LLama.Abstractions; +using LLama.Common; +using LLama.Web.Common; + +namespace LLama.Web.Models +{ + public class CreateSessionModel : IInferenceParams + { + public string Model { get; set; } + public string Prompt { get; set; } + public LLamaExecutorType ExecutorType { get; set; } = LLamaExecutorType.Interactive; + public string AntiPrompt { get; set; } = string.Empty; + public string OutputFilter { get; set; } = string.Empty; + + public int TokensKeep { get; set; } = 0; + public int MaxTokens { get; set; } = -1; + public IEnumerable AntiPrompts { get; set; } = Array.Empty(); + public string InputSuffix { get; set; } = string.Empty; + public string InputPrefix { get; set; } = string.Empty; + public int TopK { get; set; } = 40; + public float TopP { get; set; } = 0.95f; + public float TfsZ { get; set; } = 1.0f; + public float TypicalP { get; set; } = 1.0f; + public float Temperature { get; set; } = 0.8f; + public float RepeatPenalty { get; set; } = 1.1f; + public int RepeatLastTokensCount { get; set; } = 64; + public float FrequencyPenalty { get; set; } = .0f; + public float PresencePenalty { get; set; } = .0f; + public MirostatType Mirostat { get; set; } = MirostatType.Disable; + public float MirostatTau { get; set; } = 5.0f; + public float MirostatEta { get; set; } = 0.1f; + public bool PenalizeNL { get; set; } = true; + + + // TODO: Ensure overpost protected + public Dictionary LogitBias { get; set; } + public string PathSession { get; set; } = string.Empty; + } +} diff --git a/LLama.Web/Models/ModelSession.cs b/LLama.Web/Models/ModelSession.cs index d6d428133..5ea5c08b1 100644 --- a/LLama.Web/Models/ModelSession.cs +++ b/LLama.Web/Models/ModelSession.cs @@ -3,31 +3,31 @@ namespace LLama.Web.Models { - public class ModelSession : IDisposable + public class ModelSession { private bool _isFirstInteraction = true; - private ModelOptions _modelOptions; + private IModelParams _modelParams; private PromptOptions _promptOptions; - private ParameterOptions _inferenceOptions; + private IInferenceParams _inferenceParams; private ITextStreamTransform _outputTransform; private ILLamaExecutor _executor; private CancellationTokenSource _cancellationTokenSource; - public ModelSession(ILLamaExecutor executor, ModelOptions modelOptions, PromptOptions promptOptions, ParameterOptions parameterOptions) + public ModelSession(ILLamaExecutor executor, IModelParams modelOptions, PromptOptions promptOptions, IInferenceParams inferenceParams) { _executor = executor; - _modelOptions = modelOptions; + _modelParams = modelOptions; _promptOptions = promptOptions; - _inferenceOptions = parameterOptions; + _inferenceParams = inferenceParams; - _inferenceOptions.AntiPrompts = _promptOptions.AntiPrompt?.Concat(_inferenceOptions.AntiPrompts ?? Enumerable.Empty()).Distinct() ?? 
_inferenceOptions.AntiPrompts; + _inferenceParams.AntiPrompts = _promptOptions.AntiPrompt?.Concat(_inferenceParams.AntiPrompts ?? Enumerable.Empty()).Distinct() ?? _inferenceParams.AntiPrompts; if (_promptOptions.OutputFilter?.Count > 0) _outputTransform = new LLamaTransforms.KeywordTextOutputStreamTransform(_promptOptions.OutputFilter, redundancyLength: 5); } public string ModelName { - get { return _modelOptions.Name; } + get { return _modelParams.Name; } } public IAsyncEnumerable InferAsync(string message, CancellationTokenSource cancellationTokenSource) @@ -36,13 +36,13 @@ public IAsyncEnumerable InferAsync(string message, CancellationTokenSour if (_isFirstInteraction) { _isFirstInteraction = false; - message = _promptOptions.Prompt + message; + message = string.Join(" ", _promptOptions.Prompt , message); } if (_outputTransform is not null) - return _outputTransform.TransformAsync(_executor.InferAsync(message, _inferenceOptions, _cancellationTokenSource.Token)); + return _outputTransform.TransformAsync(_executor.InferAsync(message, _inferenceParams, _cancellationTokenSource.Token)); - return _executor.InferAsync(message, _inferenceOptions, _cancellationTokenSource.Token); + return _executor.InferAsync(message, _inferenceParams, _cancellationTokenSource.Token); } @@ -53,15 +53,7 @@ public void CancelInfer() public bool IsInferCanceled() { - return _cancellationTokenSource.IsCancellationRequested; - } - - public void Dispose() - { - _inferenceOptions = null; - _outputTransform = null; - _executor.Model?.Dispose(); - _executor = null; + return _cancellationTokenSource?.IsCancellationRequested ?? false; } } } diff --git a/LLama.Web/Models/ResponseFragment.cs b/LLama.Web/Models/ResponseFragment.cs index 02f27f13e..10ab51f3d 100644 --- a/LLama.Web/Models/ResponseFragment.cs +++ b/LLama.Web/Models/ResponseFragment.cs @@ -2,17 +2,11 @@ { public class ResponseFragment { - public ResponseFragment(string id, string content = null, bool isFirst = false, bool isLast = false) - { - Id = id; - IsLast = isLast; - IsFirst = isFirst; - Content = content; - } - public string Id { get; set; } public string Content { get; set; } public bool IsLast { get; set; } public bool IsFirst { get; set; } + public bool IsCancelled { get; set; } + public int Elapsed { get; set; } } } diff --git a/LLama.Web/Pages/Executor/Instruct.cshtml b/LLama.Web/Pages/Executor/Instruct.cshtml deleted file mode 100644 index 9f8cb2d89..000000000 --- a/LLama.Web/Pages/Executor/Instruct.cshtml +++ /dev/null @@ -1,96 +0,0 @@ -@page -@model InstructModel -@{ - -} -@Html.AntiForgeryToken() -
-    [Instruct page body: "Instruct" heading, "Hub: Disconnected" status badge, Model / Parameters / Prompt selection fields, output area, and chat input controls]
- -@{ await Html.RenderPartialAsync("_ChatTemplates"); } - -@section Scripts { - - -} \ No newline at end of file diff --git a/LLama.Web/Pages/Executor/Instruct.cshtml.cs b/LLama.Web/Pages/Executor/Instruct.cshtml.cs deleted file mode 100644 index 18a58253b..000000000 --- a/LLama.Web/Pages/Executor/Instruct.cshtml.cs +++ /dev/null @@ -1,34 +0,0 @@ -using LLama.Web.Common; -using LLama.Web.Models; -using LLama.Web.Services; -using Microsoft.AspNetCore.Mvc; -using Microsoft.AspNetCore.Mvc.RazorPages; -using Microsoft.Extensions.Options; - -namespace LLama.Web.Pages -{ - public class InstructModel : PageModel - { - private readonly ILogger _logger; - private readonly ConnectionSessionService _modelSessionService; - - public InstructModel(ILogger logger, IOptions options, ConnectionSessionService modelSessionService) - { - _logger = logger; - Options = options.Value; - _modelSessionService = modelSessionService; - } - - public LLamaOptions Options { get; set; } - - public void OnGet() - { - } - - public async Task OnPostCancel(CancelModel model) - { - await _modelSessionService.CancelAsync(model.ConnectionId); - return new JsonResult(default); - } - } -} \ No newline at end of file diff --git a/LLama.Web/Pages/Executor/Instruct.cshtml.css b/LLama.Web/Pages/Executor/Instruct.cshtml.css deleted file mode 100644 index ed9a1d59f..000000000 --- a/LLama.Web/Pages/Executor/Instruct.cshtml.css +++ /dev/null @@ -1,4 +0,0 @@ -.section-content { - flex: 1; - overflow-y: scroll; -} diff --git a/LLama.Web/Pages/Executor/Interactive.cshtml b/LLama.Web/Pages/Executor/Interactive.cshtml deleted file mode 100644 index 916b59ca8..000000000 --- a/LLama.Web/Pages/Executor/Interactive.cshtml +++ /dev/null @@ -1,96 +0,0 @@ -@page -@model InteractiveModel -@{ - -} -@Html.AntiForgeryToken() -
-    [Interactive page body: "Interactive" heading, "Hub: Disconnected" status badge, Model / Parameters / Prompt selection fields, output area, and chat input controls]
- -@{ await Html.RenderPartialAsync("_ChatTemplates");} - -@section Scripts { - - -} \ No newline at end of file diff --git a/LLama.Web/Pages/Executor/Interactive.cshtml.cs b/LLama.Web/Pages/Executor/Interactive.cshtml.cs deleted file mode 100644 index 7179a4405..000000000 --- a/LLama.Web/Pages/Executor/Interactive.cshtml.cs +++ /dev/null @@ -1,34 +0,0 @@ -using LLama.Web.Common; -using LLama.Web.Models; -using LLama.Web.Services; -using Microsoft.AspNetCore.Mvc; -using Microsoft.AspNetCore.Mvc.RazorPages; -using Microsoft.Extensions.Options; - -namespace LLama.Web.Pages -{ - public class InteractiveModel : PageModel - { - private readonly ILogger _logger; - private readonly ConnectionSessionService _modelSessionService; - - public InteractiveModel(ILogger logger, IOptions options, ConnectionSessionService modelSessionService) - { - _logger = logger; - Options = options.Value; - _modelSessionService = modelSessionService; - } - - public LLamaOptions Options { get; set; } - - public void OnGet() - { - } - - public async Task OnPostCancel(CancelModel model) - { - await _modelSessionService.CancelAsync(model.ConnectionId); - return new JsonResult(default); - } - } -} \ No newline at end of file diff --git a/LLama.Web/Pages/Executor/Interactive.cshtml.css b/LLama.Web/Pages/Executor/Interactive.cshtml.css deleted file mode 100644 index ed9a1d59f..000000000 --- a/LLama.Web/Pages/Executor/Interactive.cshtml.css +++ /dev/null @@ -1,4 +0,0 @@ -.section-content { - flex: 1; - overflow-y: scroll; -} diff --git a/LLama.Web/Pages/Executor/Stateless.cshtml b/LLama.Web/Pages/Executor/Stateless.cshtml deleted file mode 100644 index b5d8eea37..000000000 --- a/LLama.Web/Pages/Executor/Stateless.cshtml +++ /dev/null @@ -1,97 +0,0 @@ -@page -@model StatelessModel -@{ - -} -@Html.AntiForgeryToken() -
-    [Stateless page body: "Stateless" heading, "Hub: Disconnected" status badge, Model / Parameters / Prompt selection fields, output area, and chat input controls]
- -@{ await Html.RenderPartialAsync("_ChatTemplates"); } - - -@section Scripts { - - -} \ No newline at end of file diff --git a/LLama.Web/Pages/Executor/Stateless.cshtml.cs b/LLama.Web/Pages/Executor/Stateless.cshtml.cs deleted file mode 100644 index f88c4b832..000000000 --- a/LLama.Web/Pages/Executor/Stateless.cshtml.cs +++ /dev/null @@ -1,34 +0,0 @@ -using LLama.Web.Common; -using LLama.Web.Models; -using LLama.Web.Services; -using Microsoft.AspNetCore.Mvc; -using Microsoft.AspNetCore.Mvc.RazorPages; -using Microsoft.Extensions.Options; - -namespace LLama.Web.Pages -{ - public class StatelessModel : PageModel - { - private readonly ILogger _logger; - private readonly ConnectionSessionService _modelSessionService; - - public StatelessModel(ILogger logger, IOptions options, ConnectionSessionService modelSessionService) - { - _logger = logger; - Options = options.Value; - _modelSessionService = modelSessionService; - } - - public LLamaOptions Options { get; set; } - - public void OnGet() - { - } - - public async Task OnPostCancel(CancelModel model) - { - await _modelSessionService.CancelAsync(model.ConnectionId); - return new JsonResult(default); - } - } -} \ No newline at end of file diff --git a/LLama.Web/Pages/Executor/Stateless.cshtml.css b/LLama.Web/Pages/Executor/Stateless.cshtml.css deleted file mode 100644 index ed9a1d59f..000000000 --- a/LLama.Web/Pages/Executor/Stateless.cshtml.css +++ /dev/null @@ -1,4 +0,0 @@ -.section-content { - flex: 1; - overflow-y: scroll; -} diff --git a/LLama.Web/Pages/Index.cshtml b/LLama.Web/Pages/Index.cshtml index b5f0c15fc..3e97195dd 100644 --- a/LLama.Web/Pages/Index.cshtml +++ b/LLama.Web/Pages/Index.cshtml @@ -1,10 +1,111 @@ @page @model IndexModel @{ - ViewData["Title"] = "Home page"; + ViewData["Title"] = "Inference Demo"; } -
-    [removed default home page body: "Welcome" heading and "Learn about building Web apps with ASP.NET Core."]
+@Html.AntiForgeryToken()
+    [new inference demo body: @ViewData["Title"] heading and "Socket: Disconnected" status badge]
+    Model
+    @Html.DropDownListFor(m => m.SessionOptions.Model, new SelectList(Model.Options.Models, "Name", "Name"), new { @class = "form-control", required = "required", autocomplete = "off" })
+    Executor
+    @Html.DropDownListFor(m => m.SessionOptions.ExecutorType, Html.GetEnumSelectList(), new { @class = "form-control", required = "required", autocomplete = "off" })
+    [remaining session form fields, output area, and chat input controls]
+ +@{ + await Html.RenderPartialAsync("_ChatTemplates"); +} + +@section Scripts { + + +} \ No newline at end of file diff --git a/LLama.Web/Pages/Index.cshtml.cs b/LLama.Web/Pages/Index.cshtml.cs index 477c9bfbe..8cd694bfa 100644 --- a/LLama.Web/Pages/Index.cshtml.cs +++ b/LLama.Web/Pages/Index.cshtml.cs @@ -1,20 +1,43 @@ -using Microsoft.AspNetCore.Mvc; +using LLama.Web.Common; +using LLama.Web.Models; +using LLama.Web.Services; +using Microsoft.AspNetCore.Mvc; using Microsoft.AspNetCore.Mvc.RazorPages; +using Microsoft.Extensions.Options; namespace LLama.Web.Pages { public class IndexModel : PageModel { private readonly ILogger _logger; + private readonly ConnectionSessionService _modelSessionService; - public IndexModel(ILogger logger) + public IndexModel(ILogger logger, IOptions options, ConnectionSessionService modelSessionService) { _logger = logger; + Options = options.Value; + _modelSessionService = modelSessionService; } + public LLamaOptions Options { get; set; } + + [BindProperty] + public CreateSessionModel SessionOptions { get; set; } + public void OnGet() { + SessionOptions = new CreateSessionModel + { + Prompt = "Below is an instruction that describes a task. Write a response that appropriately completes the request.", + AntiPrompt = "User:", + OutputFilter = "User:, Response:" + }; + } + public async Task OnPostCancel(CancelModel model) + { + await _modelSessionService.CancelAsync(model.ConnectionId); + return new JsonResult(default); } } } \ No newline at end of file diff --git a/LLama.Web/Pages/Shared/_ChatTemplates.cshtml b/LLama.Web/Pages/Shared/_ChatTemplates.cshtml index 156440124..cd768f1f5 100644 --- a/LLama.Web/Pages/Shared/_ChatTemplates.cshtml +++ b/LLama.Web/Pages/Shared/_ChatTemplates.cshtml @@ -12,7 +12,7 @@
- {{text}} + {{text}}
{{date}}
@@ -26,9 +26,7 @@
- - - +
@@ -41,20 +39,6 @@
- \ No newline at end of file diff --git a/LLama.Web/Pages/Shared/_Layout.cshtml b/LLama.Web/Pages/Shared/_Layout.cshtml index 23132bfa4..b5d1ef02e 100644 --- a/LLama.Web/Pages/Shared/_Layout.cshtml +++ b/LLama.Web/Pages/Shared/_Layout.cshtml @@ -21,16 +21,7 @@ diff --git a/LLama.Web/Pages/Shared/_Parameters.cshtml b/LLama.Web/Pages/Shared/_Parameters.cshtml new file mode 100644 index 000000000..165b65a87 --- /dev/null +++ b/LLama.Web/Pages/Shared/_Parameters.cshtml @@ -0,0 +1,137 @@ +@page +@model LLama.Abstractions.IInferenceParams +@{ +} + +
+    [form layout markup omitted; labelled range inputs:]
+    MaxTokens
+    @Html.TextBoxFor(m => m.MaxTokens, new { @type="range", @class = "slider", min="-1", max="2048", step="1" })
+    TokensKeep
+    @Html.TextBoxFor(m => m.TokensKeep, new { @type="range", @class = "slider", min="0", max="2048", step="1" })
+    TopK
+    @Html.TextBoxFor(m => m.TopK, new { @type="range", @class = "slider", min="-1", max="100", step="1" })
+    TopP
+    @Html.TextBoxFor(m => m.TopP, new { @type="range", @class = "slider", min="0.0", max="1.0", step="0.01" })
+    TypicalP
+    @Html.TextBoxFor(m => m.TypicalP, new { @type="range", @class = "slider", min="0.0", max="1.0", step="0.01" })
+    Temperature
+    @Html.TextBoxFor(m => m.Temperature, new { @type="range", @class = "slider", min="0.0", max="1.5", step="0.01" })
+    RepeatPenalty
+    @Html.TextBoxFor(m => m.RepeatPenalty, new { @type="range", @class = "slider", min="0.0", max="2.0", step="0.01" })
+    RepeatLastTokensCount
+    @Html.TextBoxFor(m => m.RepeatLastTokensCount, new { @type="range", @class = "slider", min="0", max="2048", step="1" })
+    FrequencyPenalty
+    @Html.TextBoxFor(m => m.FrequencyPenalty, new { @type="range", @class = "slider", min="0.0", max="1.0", step="0.01" })
+    PresencePenalty
+    @Html.TextBoxFor(m => m.PresencePenalty, new { @type="range", @class = "slider", min="0.0", max="1.0", step="0.01" })
+    TfsZ
+    @Html.TextBoxFor(m => m.TfsZ, new { @type="range", @class = "slider", min="0.0", max="1.0", step="0.01" })
+    Mirostat
+    @Html.DropDownListFor(m => m.Mirostat, Html.GetEnumSelectList(), new { @class = "form-control form-select" })
+    MirostatTau
+    @Html.TextBoxFor(m => m.MirostatTau, new { @type="range", @class = "slider", min="0.0", max="10.0", step="0.01" })
+    MirostatEta
+    @Html.TextBoxFor(m => m.MirostatEta, new { @type="range", @class = "slider", min="0.0", max="1.0", step="0.01" })
\ No newline at end of file diff --git a/LLama.Web/Program.cs b/LLama.Web/Program.cs index 6db653a14..0d2aa35de 100644 --- a/LLama.Web/Program.cs +++ b/LLama.Web/Program.cs @@ -20,6 +20,7 @@ public static void Main(string[] args) .BindConfiguration(nameof(LLamaOptions)); // Services DI + builder.Services.AddSingleton(); builder.Services.AddSingleton(); var app = builder.Build(); diff --git a/LLama.Web/README.md b/LLama.Web/README.md index 9b6786e6b..bee508652 100644 --- a/LLama.Web/README.md +++ b/LLama.Web/README.md @@ -1,41 +1,45 @@ ## LLama.Web - Basic ASP.NET Core examples of LLamaSharp in action -LLama.Web has no heavy dependencies and no extra frameworks ove bootstrap and jquery to keep the examples clean and easy to copy over to your own project +LLama.Web has no heavy dependencies and no extra frameworks except bootstrap and jquery to keep the examples clean and easy to copy over to your own project + ## Websockets Using signalr websockets simplifys the streaming of responses and model per connection management - ## Setup -You can setup Models, Prompts and Inference parameters in the appsettings.json +You can setup Models, Ports etc in the appsettings.json **Models** You can add multiple models to the options for quick selection in the UI, options are based on ModelParams so its fully configurable -**Parameters** -You can add multiple sets of inference parameters to the options for quick selection in the UI, options are based on InferenceParams so its fully configurable - -**Prompts** -You can add multiple sets of prompts to the options for quick selection in the UI Example: ```json - { - "Name": "Alpaca", - "Path": "D:\\Repositories\\AI\\Prompts\\alpaca.txt", - "Prompt": "Alternativly to can set a prompt text directly and omit the Path" - "AntiPrompt": [ - "User:" - ], - "OutputFilter": [ - "Response:", - "User:" - ] + "LLamaOptions": { + "Models": [ + { + "Name": "WizardLM-7B", + "MaxInstances": 2, + "ModelPath": "\\Models\\wizardLM-7B.ggmlv3.q4_0.bin", + "ContextSize": 2048 + }, + { + "Name": "WizardLM-13B", + "MaxInstances": 2, + "ModelPath": "\\Models\\wizardLM-13B.ggmlv3.q4_0.bin", + "ContextSize": 1024, + "GpuLayerCount": 16, + "Threads": 15 + } + ] } ``` -## Interactive UI -The interactive UI is a simple example of using LLammaSharp -![demo-ui](https://i.imgur.com/nQsnWP1.png) +## Inference Demo +The Inference Demo UI is a simple example of using LLamaSharp +![demo-ui](https://i.imgur.com/FG0YEzw.png) + +Inference Parameters +![demo-ui2](https://i.imgur.com/fZEQTQ5.png) diff --git a/LLama.Web/Services/ConnectionSessionService.cs b/LLama.Web/Services/ConnectionSessionService.cs index 6c266f14f..c94968b6b 100644 --- a/LLama.Web/Services/ConnectionSessionService.cs +++ b/LLama.Web/Services/ConnectionSessionService.cs @@ -3,7 +3,7 @@ using LLama.Web.Models; using Microsoft.Extensions.Options; using System.Collections.Concurrent; -using System.Drawing; +using System.Diagnostics; namespace LLama.Web.Services { @@ -15,80 +15,162 @@ public class ConnectionSessionService : IModelSessionService { private readonly LLamaOptions _options; private readonly ILogger _logger; + private readonly IModelCacheService _modelCacheService; private readonly ConcurrentDictionary _modelSessions; - public ConnectionSessionService(ILogger logger, IOptions options) + + public ConnectionSessionService(ILogger logger, IOptions options, IModelCacheService modelCacheService) { _logger = logger; _options = options.Value; + _modelCacheService = modelCacheService; _modelSessions = new ConcurrentDictionary(); } 
- public Task GetAsync(string connectionId) + public Task GetAsync(string sessionId) { - _modelSessions.TryGetValue(connectionId, out var modelSession); + _modelSessions.TryGetValue(sessionId, out var modelSession); return Task.FromResult(modelSession); } - public Task> CreateAsync(LLamaExecutorType executorType, string connectionId, string modelName, string promptName, string parameterName) - { - var modelOption = _options.Models.FirstOrDefault(x => x.Name == modelName); - if (modelOption is null) - return Task.FromResult(ServiceResult.FromError($"Model option '{modelName}' not found")); - var promptOption = _options.Prompts.FirstOrDefault(x => x.Name == promptName); - if (promptOption is null) - return Task.FromResult(ServiceResult.FromError($"Prompt option '{promptName}' not found")); + public async Task> CreateAsync(string sessionId, CreateSessionModel sessionModel) + { + // Remove existing connections session + await RemoveAsync(sessionId); - var parameterOption = _options.Parameters.FirstOrDefault(x => x.Name == parameterName); - if (parameterOption is null) - return Task.FromResult(ServiceResult.FromError($"Parameter option '{parameterName}' not found")); + var modelOption = _options.Models.FirstOrDefault(x => x.Name == sessionModel.Model); + if (modelOption is null) + return ServiceResult.FromError($"Model option '{sessionModel.Model}' not found"); //Max instance var currentInstances = _modelSessions.Count(x => x.Value.ModelName == modelOption.Name); if (modelOption.MaxInstances > -1 && currentInstances >= modelOption.MaxInstances) - return Task.FromResult(ServiceResult.FromError("Maximum model instances reached")); + return ServiceResult.FromError("Maximum model instances reached"); - // Create model - var llamaModel = new LLamaModel(modelOption); + // Create Model/Context + var llamaModelContext = await CreateModelContext(sessionId, modelOption); // Create executor - ILLamaExecutor executor = executorType switch + ILLamaExecutor executor = sessionModel.ExecutorType switch { - LLamaExecutorType.Interactive => new InteractiveExecutor(llamaModel), - LLamaExecutorType.Instruct => new InstructExecutor(llamaModel), - LLamaExecutorType.Stateless => new StatelessExecutor(llamaModel), + LLamaExecutorType.Interactive => new InteractiveExecutor(llamaModelContext), + LLamaExecutorType.Instruct => new InstructExecutor(llamaModelContext), + LLamaExecutorType.Stateless => new StatelessExecutor(llamaModelContext), _ => default }; + // Create Prompt + var promptOption = new PromptOptions + { + Name = "Custom", + Prompt = sessionModel.Prompt, + AntiPrompt = CreateListFromCSV(sessionModel.AntiPrompt), + OutputFilter = CreateListFromCSV(sessionModel.OutputFilter), + }; + // Create session - var modelSession = new ModelSession(executor, modelOption, promptOption, parameterOption); - if (!_modelSessions.TryAdd(connectionId, modelSession)) - return Task.FromResult(ServiceResult.FromError("Failed to create model session")); + var modelSession = new ModelSession(executor, modelOption, promptOption, sessionModel); + if (!_modelSessions.TryAdd(sessionId, modelSession)) + return ServiceResult.FromError("Failed to create model session"); - return Task.FromResult(ServiceResult.FromValue(modelSession)); + return ServiceResult.FromValue(default); } - public Task RemoveAsync(string connectionId) + + public async IAsyncEnumerable InferAsync(string sessionId, string prompt, CancellationTokenSource cancellationTokenSource) { - if (_modelSessions.TryRemove(connectionId, out var modelSession)) + var modelSession = 
await GetAsync(sessionId); + if (modelSession is null) + yield break; + + // Create unique response id + var responseId = Guid.NewGuid().ToString(); + + // Send begin of response + var stopwatch = Stopwatch.GetTimestamp(); + yield return new ResponseFragment + { + Id = responseId, + IsFirst = true + }; + + // Send content of response + await foreach (var fragment in modelSession.InferAsync(prompt, cancellationTokenSource)) + { + yield return new ResponseFragment + { + Id = responseId, + Content = fragment + }; + } + + // Send end of response + var elapsedTime = Stopwatch.GetElapsedTime(stopwatch); + var signature = modelSession.IsInferCanceled() + ? $"Inference cancelled after {elapsedTime.TotalSeconds:F0} seconds" + : $"Inference completed in {elapsedTime.TotalSeconds:F0} seconds"; + yield return new ResponseFragment + { + Id = responseId, + IsLast = true, + Content = signature, + IsCancelled = modelSession.IsInferCanceled(), + Elapsed = (int)elapsedTime.TotalMilliseconds + }; + } + + + public async Task RemoveAsync(string sessionId) + { + if (_modelSessions.TryRemove(sessionId, out var modelSession)) { modelSession.CancelInfer(); - modelSession.Dispose(); - return Task.FromResult(true); + var llamaModel = await _modelCacheService.Get(modelSession.ModelName); + if (llamaModel is null) + return false; + + return await llamaModel.RemoveContext(sessionId); } - return Task.FromResult(false); + return false; } - public Task CancelAsync(string connectionId) + public Task CancelAsync(string sessionId) { - if (_modelSessions.TryGetValue(connectionId, out var modelSession)) + if (_modelSessions.TryGetValue(sessionId, out var modelSession)) { modelSession.CancelInfer(); return Task.FromResult(true); } return Task.FromResult(false); } + + private async Task CreateModelContext(string sessionId, ModelOptions modelOption) + { + // Create model + var llamaModel = await _modelCacheService.Get(modelOption.Name) + ?? await _modelCacheService.Create(modelOption); + if (llamaModel is null) + throw new Exception($"Failed to create model, modelName: {modelOption.Name}"); + + //Create context + var llamaModelContext = await llamaModel.GetContext(sessionId) + ?? 
await llamaModel.CreateContext(sessionId); + if (llamaModelContext is null) + throw new Exception($"Failed to create model, connectionId: {sessionId}"); + + return llamaModelContext; + } + + private List CreateListFromCSV(string csv) + { + if(string.IsNullOrEmpty(csv)) + return null; + + return csv.Split(",") + .Select(x => x.Trim()) + .ToList(); + } } } diff --git a/LLama.Web/Services/IModelCacheService.cs b/LLama.Web/Services/IModelCacheService.cs new file mode 100644 index 000000000..ed627cbae --- /dev/null +++ b/LLama.Web/Services/IModelCacheService.cs @@ -0,0 +1,11 @@ +using LLama.Web.Common; + +namespace LLama.Web.Services +{ + public interface IModelCacheService + { + Task Create(ModelOptions modelOptions); + Task Get(string modelName); + Task Remove(string modelName); + } +} \ No newline at end of file diff --git a/LLama.Web/Services/IModelSessionService.cs b/LLama.Web/Services/IModelSessionService.cs index 4ee0d483f..fb177ccf1 100644 --- a/LLama.Web/Services/IModelSessionService.cs +++ b/LLama.Web/Services/IModelSessionService.cs @@ -7,10 +7,9 @@ namespace LLama.Web.Services public interface IModelSessionService { Task GetAsync(string sessionId); - Task> CreateAsync(LLamaExecutorType executorType, string sessionId, string modelName, string promptName, string parameterName); + Task> CreateAsync(string sessionId, CreateSessionModel sessionModel); Task RemoveAsync(string sessionId); Task CancelAsync(string sessionId); } - } diff --git a/LLama.Web/Services/ModelCacheService.cs b/LLama.Web/Services/ModelCacheService.cs new file mode 100644 index 000000000..f2d1f7ed5 --- /dev/null +++ b/LLama.Web/Services/ModelCacheService.cs @@ -0,0 +1,40 @@ +using System.Collections.Concurrent; +using LLama.Web.Services; + +namespace LLama.Web.Common +{ + public class ModelCacheService : IModelCacheService + { + private readonly ConcurrentDictionary _modelInstances = new ConcurrentDictionary(); + + public Task Create(ModelOptions modelOptions) + { + if (_modelInstances.TryGetValue(modelOptions.Name, out LLamaModel model)) + return Task.FromResult(model); + + model = new LLamaModel(modelOptions); + if (!_modelInstances.TryAdd(modelOptions.Name, model)) + throw new Exception($"Failed to cache model {modelOptions.Name}."); + + return Task.FromResult(model); + } + + + public Task Get(string modelName) + { + _modelInstances.TryGetValue(modelName, out LLamaModel model); + return Task.FromResult(model); + } + + + public Task Remove(string modelName) + { + if (_modelInstances.TryRemove(modelName, out LLamaModel model)) + { + model?.Dispose(); + return Task.FromResult(true); + } + return Task.FromResult(false); + } + } +} diff --git a/LLama.Web/appsettings.json b/LLama.Web/appsettings.json index 9f340a9c7..652fd36ba 100644 --- a/LLama.Web/appsettings.json +++ b/LLama.Web/appsettings.json @@ -6,49 +6,15 @@ } }, "AllowedHosts": "*", + "Urls": "https://localhost:5000", "LLamaOptions": { "Models": [ { "Name": "WizardLM-7B", "MaxInstances": 2, - "ModelPath": "D:\\Repositories\\AI\\Models\\wizardLM-7B.ggmlv3.q4_0.bin", + "ModelPath": "\\Models\\wizardLM-7B.ggmlv3.q4_0.bin", "ContextSize": 2048 } - ], - "Parameters": [ - { - "Name": "Default", - "Temperature": 0.6 - } - ], - "Prompts": [ - { - "Name": "None", - "Prompt": "" - }, - { - "Name": "Alpaca", - "Path": "D:\\Repositories\\AI\\Prompts\\alpaca.txt", - "AntiPrompt": [ - "User:" - ], - "OutputFilter": [ - "Response:", - "User:" - ] - }, - { - "Name": "ChatWithBob", - "Path": "D:\\Repositories\\AI\\Prompts\\chat-with-bob.txt", - "AntiPrompt": [ - "User:" - ], - 
"OutputFilter": [ - "Bob:", - "User:" - ] - } ] - } } diff --git a/LLama.Web/wwwroot/css/site.css b/LLama.Web/wwwroot/css/site.css index d10ef9757..a31573cf2 100644 --- a/LLama.Web/wwwroot/css/site.css +++ b/LLama.Web/wwwroot/css/site.css @@ -1,4 +1,4 @@ -html, body { +html, body { font-size: 14px; height: 100%; display: flex; @@ -31,4 +31,21 @@ footer { box-shadow: 0 0 0 0.1rem white, 0 0 0 0.25rem #258cfb; } +#scroll-container { + flex: 1; + overflow-y: scroll; +} + +#output-container .content { + white-space: break-spaces; +} + + +.slider-container > .slider { + width: 100%; +} +.slider-container > label { + width: 50px; + text-align: center; +} \ No newline at end of file diff --git a/LLama.Web/wwwroot/js/sessionConnectionChat.js b/LLama.Web/wwwroot/js/sessionConnectionChat.js index 472b59718..99c7a3418 100644 --- a/LLama.Web/wwwroot/js/sessionConnectionChat.js +++ b/LLama.Web/wwwroot/js/sessionConnectionChat.js @@ -1,9 +1,9 @@ -const createConnectionSessionChat = (LLamaExecutorType) => { +const createConnectionSessionChat = () => { const outputErrorTemplate = $("#outputErrorTemplate").html(); const outputInfoTemplate = $("#outputInfoTemplate").html(); const outputUserTemplate = $("#outputUserTemplate").html(); const outputBotTemplate = $("#outputBotTemplate").html(); - const sessionDetailsTemplate = $("#sessionDetailsTemplate").html(); + const signatureTemplate = $("#signatureTemplate").html(); let connectionId; const connection = new signalR.HubConnectionBuilder().withUrl("/SessionConnectionHub").build(); @@ -20,7 +20,6 @@ const createConnectionSessionChat = (LLamaExecutorType) => { } else if (status == Enums.SessionConnectionStatus.Loaded) { enableControls(); - $("#session-details").html(Mustache.render(sessionDetailsTemplate, { model: getSelectedModel(), prompt: getSelectedPrompt(), parameter: getSelectedParameter() })); onInfo(`New model session successfully started`) } } @@ -53,7 +52,7 @@ const createConnectionSessionChat = (LLamaExecutorType) => { if (response.isLast) { enableControls(); - responseContainer.find(".signature").append(response.content); + responseContainer.find(".signature").append(Mustache.render(signatureTemplate, response)); scrollToBottom(); } else { @@ -71,10 +70,10 @@ const createConnectionSessionChat = (LLamaExecutorType) => { const sendPrompt = async () => { const text = chatInput.val(); if (text) { + chatInput.val(null); disableControls(); outputContainer.append(Mustache.render(outputUserTemplate, { text: text, date: getDateTime() })); await connection.invoke('SendPrompt', text); - chatInput.val(null); scrollToBottom(true); } } @@ -84,16 +83,32 @@ const createConnectionSessionChat = (LLamaExecutorType) => { } const loadModel = async () => { - const modelName = getSelectedModel(); - const promptName = getSelectedPrompt(); - const parameterName = getSelectedParameter(); - if (!modelName || !promptName || !parameterName) { - onError("Please select a valid Model, Parameter and Prompt"); - return; - } - disableControls(); - await connection.invoke('LoadModel', LLamaExecutorType, modelName, promptName, parameterName); + await connection.invoke('LoadModel', serializeFormToJson('SessionParameters')); + } + + + const serializeFormToJson = (form) => { + const formDataJson = {}; + const formData = new FormData(document.getElementById(form)); + formData.forEach((value, key) => { + + if (key.includes(".")) + key = key.split(".")[1]; + + // Convert number strings to numbers + if (!isNaN(value) && value.trim() !== "") { + formDataJson[key] = parseFloat(value); + } + // 
Convert boolean strings to booleans + else if (value === "true" || value === "false") { + formDataJson[key] = (value === "true"); + } + else { + formDataJson[key] = value; + } + }); + return formDataJson; } @@ -118,21 +133,6 @@ const createConnectionSessionChat = (LLamaExecutorType) => { } - const getSelectedModel = () => { - return $("option:selected", "#Model").val(); - } - - - const getSelectedParameter = () => { - return $("option:selected", "#Parameter").val(); - } - - - const getSelectedPrompt = () => { - return $("option:selected", "#Prompt").val(); - } - - const getDateTime = () => { const dateTime = new Date(); return dateTime.toLocaleString(); @@ -165,7 +165,10 @@ const createConnectionSessionChat = (LLamaExecutorType) => { sendPrompt(); } }); - + $(".slider").on("input", function (e) { + const slider = $(this); + slider.next().text(slider.val()); + }).trigger("input"); // Map signalr functions diff --git a/LLama.WebAPI/Services/StatefulChatService.cs b/LLama.WebAPI/Services/StatefulChatService.cs index ab89b5176..d2a54f702 100644 --- a/LLama.WebAPI/Services/StatefulChatService.cs +++ b/LLama.WebAPI/Services/StatefulChatService.cs @@ -1,7 +1,5 @@  using LLama.WebAPI.Models; -using Microsoft; -using System.Runtime.CompilerServices; namespace LLama.WebAPI.Services; @@ -9,6 +7,7 @@ public class StatefulChatService : IDisposable { private readonly ChatSession _session; private readonly LLamaModel _model; + private readonly LLamaModelContext _context; private bool _continue = false; private const string SystemPrompt = "Transcript of a dialog, where the User interacts with an Assistant. Assistant is helpful, kind, honest, good at writing, and never fails to answer the User's requests immediately and with precision.\n\n" @@ -17,12 +16,13 @@ public class StatefulChatService : IDisposable public StatefulChatService(IConfiguration configuration) { _model = new LLamaModel(new Common.ModelParams(configuration["ModelPath"], contextSize: 512)); - _session = new ChatSession(new InteractiveExecutor(_model)); + _context = new LLamaModelContext(_model); + _session = new ChatSession(new InteractiveExecutor(_context)); } public void Dispose() { - _model?.Dispose(); + _context?.Dispose(); } public string Send(SendMessageInput input) diff --git a/LLama.WebAPI/Services/StatelessChatService.cs b/LLama.WebAPI/Services/StatelessChatService.cs index c1356646e..c186954f3 100644 --- a/LLama.WebAPI/Services/StatelessChatService.cs +++ b/LLama.WebAPI/Services/StatelessChatService.cs @@ -1,5 +1,4 @@ using LLama.Common; -using Microsoft.AspNetCore.Http; using System.Text; using static LLama.LLamaTransforms; @@ -8,13 +7,15 @@ namespace LLama.WebAPI.Services public class StatelessChatService { private readonly LLamaModel _model; + private readonly LLamaModelContext _context; private readonly ChatSession _session; public StatelessChatService(IConfiguration configuration) { _model = new LLamaModel(new ModelParams(configuration["ModelPath"], contextSize: 512)); + _context = new LLamaModelContext(_model); // TODO: replace with a stateless executor - _session = new ChatSession(new InteractiveExecutor(_model)) + _session = new ChatSession(new InteractiveExecutor(_context)) .WithOutputTransform(new LLamaTransforms.KeywordTextOutputStreamTransform(new string[] { "User:", "Assistant:" }, redundancyLength: 8)) .WithHistoryTransform(new HistoryTransform()); } diff --git a/LLama/Abstractions/IInferenceParams.cs b/LLama/Abstractions/IInferenceParams.cs new file mode 100644 index 000000000..de51f8986 --- /dev/null +++ 
b/LLama/Abstractions/IInferenceParams.cs @@ -0,0 +1,29 @@ +using System.Collections.Generic; +using LLama.Common; + +namespace LLama.Abstractions +{ + public interface IInferenceParams + { + IEnumerable AntiPrompts { get; set; } + float FrequencyPenalty { get; set; } + string InputPrefix { get; set; } + string InputSuffix { get; set; } + Dictionary? LogitBias { get; set; } + int MaxTokens { get; set; } + MirostatType Mirostat { get; set; } + float MirostatEta { get; set; } + float MirostatTau { get; set; } + string PathSession { get; set; } + bool PenalizeNL { get; set; } + float PresencePenalty { get; set; } + int RepeatLastTokensCount { get; set; } + float RepeatPenalty { get; set; } + float Temperature { get; set; } + float TfsZ { get; set; } + int TokensKeep { get; set; } + int TopK { get; set; } + float TopP { get; set; } + float TypicalP { get; set; } + } +} \ No newline at end of file diff --git a/LLama/Abstractions/ILLamaExecutor.cs b/LLama/Abstractions/ILLamaExecutor.cs index d35e075e9..4bdcc6bd9 100644 --- a/LLama/Abstractions/ILLamaExecutor.cs +++ b/LLama/Abstractions/ILLamaExecutor.cs @@ -1,5 +1,4 @@ -using LLama.Common; -using System; +using System; using System.Collections.Generic; using System.Text; using System.Threading; @@ -14,7 +13,7 @@ public interface ILLamaExecutor /// /// The loaded model for this executor. /// - public LLamaModel Model { get; } + public LLamaModelContext Context { get; } /// /// Infers a response from the model. @@ -23,7 +22,7 @@ public interface ILLamaExecutor /// Any additional parameters /// A cancellation token. /// - IEnumerable Infer(string text, InferenceParams? inferenceParams = null, CancellationToken token = default); + IEnumerable Infer(string text, IInferenceParams? inferenceParams = null, CancellationToken token = default); /// /// Asynchronously infers a response from the model. @@ -32,6 +31,6 @@ public interface ILLamaExecutor /// Any additional parameters /// A cancellation token. /// - IAsyncEnumerable InferAsync(string text, InferenceParams? inferenceParams = null, CancellationToken token = default); + IAsyncEnumerable InferAsync(string text, IInferenceParams? 
inferenceParams = null, CancellationToken token = default); } } diff --git a/LLama/Abstractions/IModelParams.cs b/LLama/Abstractions/IModelParams.cs new file mode 100644 index 000000000..695de2284 --- /dev/null +++ b/LLama/Abstractions/IModelParams.cs @@ -0,0 +1,28 @@ +namespace LLama.Abstractions +{ + public interface IModelParams + { + int BatchSize { get; set; } + int ContextSize { get; set; } + bool ConvertEosToNewLine { get; set; } + bool EmbeddingMode { get; set; } + int GpuLayerCount { get; set; } + int GroupedQueryAttention { get; set; } + string LoraAdapter { get; set; } + string LoraBase { get; set; } + bool LowVram { get; set; } + int MainGpu { get; set; } + string ModelPath { get; set; } + string Name { get; set; } + bool Perplexity { get; set; } + float RmsNormEpsilon { get; set; } + float RopeFrequencyBase { get; set; } + float RopeFrequencyScale { get; set; } + int Seed { get; set; } + float[] TensorSplits { get; set; } + int Threads { get; set; } + bool UseFp16Memory { get; set; } + bool UseMemoryLock { get; set; } + bool UseMemorymap { get; set; } + } +} \ No newline at end of file diff --git a/LLama/ChatSession.cs b/LLama/ChatSession.cs index b87e8984f..5b1e7eeac 100644 --- a/LLama/ChatSession.cs +++ b/LLama/ChatSession.cs @@ -91,7 +91,7 @@ public virtual void SaveSession(string path) { Directory.CreateDirectory(path); } - _executor.Model.SaveState(Path.Combine(path, _modelStateFilename)); + _executor.Context.SaveState(Path.Combine(path, _modelStateFilename)); if(Executor is StatelessExecutor) { @@ -116,7 +116,7 @@ public virtual void LoadSession(string path) { throw new FileNotFoundException($"Directory {path} does not exist."); } - _executor.Model.LoadState(Path.Combine(path, _modelStateFilename)); + _executor.Context.LoadState(Path.Combine(path, _modelStateFilename)); if (Executor is StatelessExecutor) { diff --git a/LLama/Common/FixedSizeQueue.cs b/LLama/Common/FixedSizeQueue.cs index e059ef705..4b082feb7 100644 --- a/LLama/Common/FixedSizeQueue.cs +++ b/LLama/Common/FixedSizeQueue.cs @@ -30,9 +30,11 @@ public FixedSizeQueue(int size) /// public FixedSizeQueue(int size, IEnumerable data) { +#if NETCOREAPP3_0_OR_GREATER // Try an early check on the amount of data supplied (if possible) if (data.TryGetNonEnumeratedCount(out var count) && count > size) throw new ArgumentException($"The max size set for the quene is {size}, but got {count} initial values."); +#endif // Size of "data" is unknown, copy it all into a list _maxSize = size; @@ -40,7 +42,7 @@ public FixedSizeQueue(int size, IEnumerable data) // Now check if that list is a valid size if (_storage.Count > _maxSize) - throw new ArgumentException($"The max size set for the quene is {size}, but got {count} initial values."); + throw new ArgumentException($"The max size set for the quene is {size}, but got {_storage.Count} initial values."); } /// diff --git a/LLama/Common/InferenceParams.cs b/LLama/Common/InferenceParams.cs index 77af7eafe..b21bd32ee 100644 --- a/LLama/Common/InferenceParams.cs +++ b/LLama/Common/InferenceParams.cs @@ -1,5 +1,6 @@ using System; using System.Collections.Generic; +using LLama.Abstractions; namespace LLama.Common { @@ -7,7 +8,7 @@ namespace LLama.Common /// /// The paramters used for inference. 
/// - public class InferenceParams + public class InferenceParams : IInferenceParams { /// /// number of tokens to keep from initial prompt diff --git a/LLama/Common/ModelParams.cs b/LLama/Common/ModelParams.cs index 2a591bcd7..d917deaa8 100644 --- a/LLama/Common/ModelParams.cs +++ b/LLama/Common/ModelParams.cs @@ -1,13 +1,14 @@ using System; using System.Collections.Generic; using System.Text; +using LLama.Abstractions; namespace LLama.Common { /// /// The parameters for initializing a LLama model. /// - public class ModelParams + public class ModelParams : IModelParams { /// /// Model context size (n_ctx) @@ -50,9 +51,9 @@ public class ModelParams /// public string ModelPath { get; set; } /// - /// model alias + /// model name /// - public string ModelAlias { get; set; } = "unknown"; + public string Name { get; set; } = "unknown"; /// /// lora adapter path (lora_adapter) /// @@ -86,6 +87,26 @@ public class ModelParams /// public float[] TensorSplits { get; set; } = new float[] { 0 }; + /// + /// Grouped-Query Attention + /// + public int GroupedQueryAttention { get; set; } = 1; + + /// + /// RMS Norm Epsilon + /// + public float RmsNormEpsilon { get; set; } = 5e-6f; + + /// + /// RoPE base frequency + /// + public float RopeFrequencyBase { get; set; } = 10000.0f; + + /// + /// RoPE frequency scaling factor + /// + public float RopeFrequencyScale { get; set; } = 1.0f; + /// /// /// diff --git a/LLama/LLamaEmbedder.cs b/LLama/LLamaEmbedder.cs index 4bbb61d2e..325a7520f 100644 --- a/LLama/LLamaEmbedder.cs +++ b/LLama/LLamaEmbedder.cs @@ -5,6 +5,7 @@ using LLama.Exceptions; using System.Linq; using LLama.Common; +using LLama.Abstractions; namespace LLama { @@ -28,7 +29,7 @@ internal LLamaEmbedder(SafeLLamaContextHandle ctx) /// /// /// - public LLamaEmbedder(ModelParams @params) + public LLamaEmbedder(IModelParams @params) { @params.EmbeddingMode = true; _ctx = Utils.InitLLamaContextFromModelParams(@params); diff --git a/LLama/LLamaExecutorBase.cs b/LLama/LLamaExecutorBase.cs index afbc0f258..5b4e84062 100644 --- a/LLama/LLamaExecutorBase.cs +++ b/LLama/LLamaExecutorBase.cs @@ -19,9 +19,9 @@ namespace LLama public abstract class StatefulExecutorBase : ILLamaExecutor { /// - /// The loaded model for this executor. + /// The loaded context for this executor. /// - protected readonly LLamaModel _model; + protected readonly LLamaModelContext _context; /// /// The logger used by this executor. /// @@ -63,9 +63,9 @@ public abstract class StatefulExecutorBase : ILLamaExecutor /// protected FixedSizeQueue _last_n_tokens; /// - /// The mode used by the executor. + /// The context used by the executor. /// - public LLamaModel Model => _model; + public LLamaModelContext Context => _context; /// /// Current "mu" value for mirostate sampling @@ -75,16 +75,16 @@ public abstract class StatefulExecutorBase : ILLamaExecutor /// /// /// - /// + /// /// - protected StatefulExecutorBase(LLamaModel model, ILLamaLogger? logger = null) + protected StatefulExecutorBase(LLamaModelContext context, ILLamaLogger? 
logger = null) { - _model = model; + _context = context; _logger = logger; _pastTokensCount = 0; _consumedTokensCount = 0; _n_session_consumed = 0; - _last_n_tokens = new FixedSizeQueue(_model.ContextSize).FillWith(0); + _last_n_tokens = new FixedSizeQueue(_context.ContextSize).FillWith(0); } /// @@ -104,9 +104,9 @@ public unsafe StatefulExecutorBase WithSessionFile(string filename) if (File.Exists(filename)) { _logger?.Log("LLamaExecutor", $"Attempting to load saved session from {filename}", ILLamaLogger.LogLevel.Info); - llama_token[] session_tokens = new llama_token[_model.ContextSize]; + llama_token[] session_tokens = new llama_token[_context.ContextSize]; ulong n_token_count_out = 0; - if (!NativeApi.llama_load_session_file(_model.NativeHandle, _pathSession, session_tokens, (ulong)_model.ContextSize, &n_token_count_out)) + if (!NativeApi.llama_load_session_file(_context.NativeHandle, _pathSession, session_tokens, (ulong)_context.ContextSize, &n_token_count_out)) { _logger?.Log("LLamaExecutor", $"Failed to load session file {filename}", ILLamaLogger.LogLevel.Error); throw new RuntimeError($"Failed to load session file {_pathSession}"); @@ -156,7 +156,7 @@ public unsafe StatefulExecutorBase WithSessionFile(string filename) public void SaveSessionFile(string filename) { var session_token_array = _session_tokens.ToArray(); - NativeApi.llama_save_session_file(_model.NativeHandle, filename, session_token_array, (ulong)session_token_array.Length); + NativeApi.llama_save_session_file(_context.NativeHandle, filename, session_token_array, (ulong)session_token_array.Length); } /// @@ -173,7 +173,7 @@ protected virtual void HandleRunOutOfContext(int tokensToKeep) _pastTokensCount = Math.Max(1, tokensToKeep); // insert n_left/2 tokens at the start of embed from last_n_tokens - _embeds.InsertRange(0, _last_n_tokens.Take(_last_n_tokens.Count - _embeds.Count).Skip(_model.ContextSize - n_left / 2 - _embeds.Count)); + _embeds.InsertRange(0, _last_n_tokens.Take(_last_n_tokens.Count - _embeds.Count).Skip(_context.ContextSize - n_left / 2 - _embeds.Count)); // stop saving session if we run out of context _pathSession = string.Empty; @@ -231,13 +231,13 @@ protected virtual void TryReuseMathingPrefix() /// /// /// - protected abstract bool PostProcess(InferenceParams inferenceParams, InferStateArgs args, out IEnumerable? extraOutputs); + protected abstract bool PostProcess(IInferenceParams inferenceParams, InferStateArgs args, out IEnumerable? extraOutputs); /// /// The core inference logic. /// /// /// - protected abstract void InferInternal(InferenceParams inferenceParams, InferStateArgs args); + protected abstract void InferInternal(IInferenceParams inferenceParams, InferStateArgs args); /// /// Save the current state to a file. /// @@ -267,7 +267,7 @@ protected virtual void TryReuseMathingPrefix() /// /// /// - public virtual IEnumerable Infer(string text, InferenceParams? inferenceParams = null, CancellationToken cancellationToken = default) + public virtual IEnumerable Infer(string text, IInferenceParams? inferenceParams = null, CancellationToken cancellationToken = default) { cancellationToken.ThrowIfCancellationRequested(); if (inferenceParams is null) @@ -296,7 +296,7 @@ public virtual IEnumerable Infer(string text, InferenceParams? inference if (args.ReturnValue) { - foreach (var item in _model.GenerateResult(_embeds)) + foreach (var item in _context.GenerateResult(_embeds)) { yield return item; } @@ -324,7 +324,7 @@ public virtual IEnumerable Infer(string text, InferenceParams? 
inference /// /// /// - public virtual async IAsyncEnumerable InferAsync(string text, InferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) + public virtual async IAsyncEnumerable InferAsync(string text, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) { foreach (var result in Infer(text, inferenceParams, cancellationToken)) { diff --git a/LLama/LLamaInstructExecutor.cs b/LLama/LLamaInstructExecutor.cs index 89fbac59f..d9ba2c46f 100644 --- a/LLama/LLamaInstructExecutor.cs +++ b/LLama/LLamaInstructExecutor.cs @@ -1,4 +1,5 @@ -using LLama.Common; +using LLama.Abstractions; +using LLama.Common; using LLama.Native; using System; using System.Collections.Generic; @@ -26,11 +27,11 @@ public class InstructExecutor : StatefulExecutorBase /// /// /// - public InstructExecutor(LLamaModel model, string instructionPrefix = "\n\n### Instruction:\n\n", + public InstructExecutor(LLamaModelContext model, string instructionPrefix = "\n\n### Instruction:\n\n", string instructionSuffix = "\n\n### Response:\n\n") : base(model) { - _inp_pfx = _model.Tokenize(instructionPrefix, true).ToArray(); - _inp_sfx = _model.Tokenize(instructionSuffix, false).ToArray(); + _inp_pfx = _context.Tokenize(instructionPrefix, true).ToArray(); + _inp_sfx = _context.Tokenize(instructionSuffix, false).ToArray(); _instructionPrefix = instructionPrefix; } @@ -116,7 +117,7 @@ protected override void PreprocessInputs(string text, InferStateArgs args) { // When running the first input (prompt) in inteactive mode, we should specially process it. text = " " + text; - _embed_inps = _model.Tokenize(text, true).ToList(); + _embed_inps = _context.Tokenize(text, true).ToList(); } else { @@ -127,7 +128,7 @@ protected override void PreprocessInputs(string text, InferStateArgs args) _consumedTokensCount = _embed_inps.Count; _embed_inps.AddRange(_inp_pfx); - var line_inp = _model.Tokenize(text, false); + var line_inp = _context.Tokenize(text, false); _embed_inps.AddRange(line_inp); _embed_inps.AddRange(_inp_sfx); @@ -136,7 +137,7 @@ protected override void PreprocessInputs(string text, InferStateArgs args) } } /// - protected override bool PostProcess(InferenceParams inferenceParams, InferStateArgs args, out IEnumerable? extraOutputs) + protected override bool PostProcess(IInferenceParams inferenceParams, InferStateArgs args, out IEnumerable? 
extraOutputs) { extraOutputs = null; if (_embed_inps.Count <= _consumedTokensCount) @@ -146,7 +147,7 @@ protected override bool PostProcess(InferenceParams inferenceParams, InferStateA string last_output = ""; foreach (var id in _last_n_tokens) { - last_output += Utils.PtrToString(NativeApi.llama_token_to_str(_model.NativeHandle, id), _model.Encoding); + last_output += Utils.PtrToString(NativeApi.llama_token_to_str(_context.NativeHandle, id), _context.Encoding); } foreach (var antiprompt in args.Antiprompts) @@ -179,18 +180,18 @@ protected override bool PostProcess(InferenceParams inferenceParams, InferStateA return false; } /// - protected override void InferInternal(InferenceParams inferenceParams, InferStateArgs args) + protected override void InferInternal(IInferenceParams inferenceParams, InferStateArgs args) { if (_embeds.Count > 0) { _is_prompt_run = false; - if (_pastTokensCount + _embeds.Count > _model.ContextSize) + if (_pastTokensCount + _embeds.Count > _context.ContextSize) { HandleRunOutOfContext(inferenceParams.TokensKeep); } TryReuseMathingPrefix(); - _pastTokensCount = _model.Eval(_embeds.ToArray(), _pastTokensCount); + _pastTokensCount = _context.Eval(_embeds.ToArray(), _pastTokensCount); if (_embeds.Count > 0 && !string.IsNullOrEmpty(_pathSession)) { @@ -203,7 +204,7 @@ protected override void InferInternal(InferenceParams inferenceParams, InferStat if (_embed_inps.Count <= _consumedTokensCount && !args.WaitForInput) { - var repeat_last_n = inferenceParams.RepeatLastTokensCount < 0 ? _model.ContextSize : inferenceParams.RepeatLastTokensCount; + var repeat_last_n = inferenceParams.RepeatLastTokensCount < 0 ? _context.ContextSize : inferenceParams.RepeatLastTokensCount; // optionally save the session on first sample (for faster prompt loading next time) if (!string.IsNullOrEmpty(_pathSession) && args.NeedToSaveSession) @@ -212,11 +213,11 @@ protected override void InferInternal(InferenceParams inferenceParams, InferStat SaveSessionFile(_pathSession); } - var tokenDataArray = _model.ApplyPenalty(_last_n_tokens, inferenceParams.LogitBias, repeat_last_n, + var tokenDataArray = _context.ApplyPenalty(_last_n_tokens, inferenceParams.LogitBias, repeat_last_n, inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL); var mu = MirostateMu; - var id = _model.Sample( + var id = _context.Sample( tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau, inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP ); @@ -236,7 +237,7 @@ protected override void InferInternal(InferenceParams inferenceParams, InferStat _embeds.Add(_embed_inps[_consumedTokensCount]); _last_n_tokens.Enqueue(_embed_inps[_consumedTokensCount]); _consumedTokensCount++; - if (_embeds.Count >= _model.Params.BatchSize) + if (_embeds.Count >= _context.Params.BatchSize) { break; } diff --git a/LLama/LLamaInteractExecutor.cs b/LLama/LLamaInteractExecutor.cs index bc3a242e8..e15f1f452 100644 --- a/LLama/LLamaInteractExecutor.cs +++ b/LLama/LLamaInteractExecutor.cs @@ -1,4 +1,5 @@ -using LLama.Common; +using LLama.Abstractions; +using LLama.Common; using LLama.Native; using System; using System.Collections.Generic; @@ -22,9 +23,9 @@ public class InteractiveExecutor : StatefulExecutorBase /// /// /// - public InteractiveExecutor(LLamaModel model) : base(model) + public InteractiveExecutor(LLamaModelContext model) : base(model) { - 
_llama_token_newline = Utils.Tokenize(_model.NativeHandle, "\n", false, _model.Encoding).ToArray(); + _llama_token_newline = Utils.Tokenize(_context.NativeHandle, "\n", false, _context.Encoding).ToArray(); } /// @@ -103,7 +104,7 @@ protected override void PreprocessInputs(string text, InferStateArgs args) { // When running the first input (prompt) in inteactive mode, we should specially process it. text = " " + text; - _embed_inps = _model.Tokenize(text, true).ToList(); + _embed_inps = _context.Tokenize(text, true).ToList(); } else { @@ -111,7 +112,7 @@ protected override void PreprocessInputs(string text, InferStateArgs args) { text += "\n"; } - var line_inp = _model.Tokenize(text, false); + var line_inp = _context.Tokenize(text, false); _embed_inps.AddRange(line_inp); args.RemainedTokens -= line_inp.Count(); } @@ -122,7 +123,7 @@ protected override void PreprocessInputs(string text, InferStateArgs args) /// /// /// - protected override bool PostProcess(InferenceParams inferenceParams, InferStateArgs args, out IEnumerable? extraOutputs) + protected override bool PostProcess(IInferenceParams inferenceParams, InferStateArgs args, out IEnumerable? extraOutputs) { extraOutputs = null; if (_embed_inps.Count <= _consumedTokensCount) @@ -132,7 +133,7 @@ protected override bool PostProcess(InferenceParams inferenceParams, InferStateA string last_output = ""; foreach (var id in _last_n_tokens) { - last_output += Utils.PtrToString(NativeApi.llama_token_to_str(_model.NativeHandle, id), _model.Encoding); + last_output += Utils.PtrToString(NativeApi.llama_token_to_str(_context.NativeHandle, id), _context.Encoding); } foreach (var antiprompt in args.Antiprompts) @@ -166,18 +167,18 @@ protected override bool PostProcess(InferenceParams inferenceParams, InferStateA } /// - protected override void InferInternal(InferenceParams inferenceParams, InferStateArgs args) + protected override void InferInternal(IInferenceParams inferenceParams, InferStateArgs args) { if (_embeds.Count > 0) { _is_prompt_run = false; - if (_pastTokensCount + _embeds.Count > _model.ContextSize) + if (_pastTokensCount + _embeds.Count > _context.ContextSize) { HandleRunOutOfContext(inferenceParams.TokensKeep); } TryReuseMathingPrefix(); - _pastTokensCount = _model.Eval(_embeds.ToArray(), _pastTokensCount); + _pastTokensCount = _context.Eval(_embeds.ToArray(), _pastTokensCount); if (_embeds.Count > 0 && !string.IsNullOrEmpty(_pathSession)) { @@ -190,7 +191,7 @@ protected override void InferInternal(InferenceParams inferenceParams, InferStat if (_embed_inps.Count <= _consumedTokensCount && !args.WaitForInput) { - var repeat_last_n = inferenceParams.RepeatLastTokensCount < 0 ? _model.ContextSize : inferenceParams.RepeatLastTokensCount; + var repeat_last_n = inferenceParams.RepeatLastTokensCount < 0 ? 
_context.ContextSize : inferenceParams.RepeatLastTokensCount; // optionally save the session on first sample (for faster prompt loading next time) if (!string.IsNullOrEmpty(_pathSession) && args.NeedToSaveSession) @@ -199,11 +200,11 @@ protected override void InferInternal(InferenceParams inferenceParams, InferStat SaveSessionFile(_pathSession); } - var tokenDataArray = _model.ApplyPenalty(_last_n_tokens, inferenceParams.LogitBias, repeat_last_n, + var tokenDataArray = _context.ApplyPenalty(_last_n_tokens, inferenceParams.LogitBias, repeat_last_n, inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL); var mu = MirostateMu; - var id = _model.Sample( + var id = _context.Sample( tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau, inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP ); @@ -216,7 +217,7 @@ protected override void InferInternal(InferenceParams inferenceParams, InferStat id = _llama_token_newline.First(); if (args.Antiprompts is not null && args.Antiprompts.Count > 0) { - var first_antiprompt = _model.Tokenize(args.Antiprompts[0], false); + var first_antiprompt = _context.Tokenize(args.Antiprompts[0], false); _embed_inps.AddRange(first_antiprompt); } } @@ -233,7 +234,7 @@ protected override void InferInternal(InferenceParams inferenceParams, InferStat _embeds.Add(_embed_inps[_consumedTokensCount]); _last_n_tokens.Enqueue(_embed_inps[_consumedTokensCount]); _consumedTokensCount++; - if (_embeds.Count >= _model.Params.BatchSize) + if (_embeds.Count >= _context.Params.BatchSize) { break; } diff --git a/LLama/LLamaModel.cs b/LLama/LLamaModel.cs index 78118116f..b5b4d46ef 100644 --- a/LLama/LLamaModel.cs +++ b/LLama/LLamaModel.cs @@ -1,386 +1,104 @@ -using LLama.Exceptions; +using LLama.Abstractions; +using LLama.Common; using LLama.Native; using System; -using System.Collections.Generic; -using System.Linq; -using System.Text; -using System.IO; -using System.IO.MemoryMappedFiles; -using LLama.Common; -using System.Runtime.InteropServices; -using LLama.Extensions; -using Microsoft.Win32.SafeHandles; +using System.Collections.Concurrent; +using System.Threading.Tasks; namespace LLama { - using llama_token = Int32; /// /// The abstraction of a LLama model, which holds the context in the native library. /// - public class LLamaModel: IDisposable + public class LLamaModel : IDisposable { // TODO: expose more properties. ILLamaLogger? _logger; - Encoding _encoding; - SafeLLamaContextHandle _ctx; - /// - /// The context size. - /// - public int ContextSize { get; } + SafeLlamaModelHandle _model; + ConcurrentDictionary _contexts; /// /// The model params set for this model. /// - public ModelParams Params { get; set; } + public IModelParams Params { get; set; } /// /// The native handle, which is used to be passed to the native APIs. Please avoid using it /// unless you know what is the usage of the Native API. /// - public SafeLLamaContextHandle NativeHandle => _ctx; - /// - /// The encoding set for this model to deal with text input. - /// - public Encoding Encoding => _encoding; + public SafeLlamaModelHandle NativeHandle => _model; + /// /// /// - /// Model params. - /// Encoding to deal with text input. + /// Model params. /// The logger. - public LLamaModel(ModelParams Params, string encoding = "UTF-8", ILLamaLogger? logger = null) + public LLamaModel(IModelParams modelParams, ILLamaLogger? 
logger = null) { _logger = logger; - this.Params = Params; - _encoding = Encoding.GetEncoding(encoding); - _logger?.Log(nameof(LLamaModel), $"Initializing LLama model with params: {this.Params}", ILLamaLogger.LogLevel.Info); - _ctx = Utils.InitLLamaContextFromModelParams(this.Params); - ContextSize = NativeApi.llama_n_ctx(_ctx); - } - - /// - /// Tokenize a string. - /// - /// - /// Whether to add a bos to the text. - /// - public IEnumerable Tokenize(string text, bool addBos = true) - { - // TODO: reconsider whether to convert to array here. - return Utils.Tokenize(_ctx, text, addBos, _encoding); + _contexts = new ConcurrentDictionary(); + _logger?.Log(nameof(LLamaModelContext), $"Initializing LLama model with params: {modelParams}", ILLamaLogger.LogLevel.Info); + + Params = modelParams; + var contextParams = Utils.CreateContextParams(modelParams); + _model = SafeLlamaModelHandle.LoadFromFile(modelParams.ModelPath, contextParams); + if (!string.IsNullOrEmpty(modelParams.LoraAdapter)) + _model.ApplyLoraFromFile(modelParams.LoraAdapter, modelParams.LoraBase, modelParams.Threads); } /// - /// Detokenize the tokens to text. + /// Creates a new context session on this model /// - /// - /// - public string DeTokenize(IEnumerable tokens) + /// The unique context identifier + /// The contexts text encoding + /// LLamaModelContext for this LLamaModel + /// Context exists + public Task CreateContext(string contextId, string encoding = "UTF-8") { - StringBuilder sb = new(); - foreach(var token in tokens) - { - sb.Append(Utils.PtrToString(NativeApi.llama_token_to_str(_ctx, token), _encoding)); - } - return sb.ToString(); - } - - /// - /// Save the state to specified path. - /// - /// - public void SaveState(string filename) - { - // Delete that file before overwriting it - if (File.Exists(filename)) - File.Delete(filename); - - // Estimate size of state to write to disk, this is always equal to or greater than the actual size - var estimatedStateSize = (long)NativeApi.llama_get_state_size(_ctx); + if (_contexts.TryGetValue(contextId, out var context)) + throw new Exception($"Context with id {contextId} already exists."); - // Map the file and write the bytes directly to it. This saves copying the bytes into a C# array - long writtenBytes; - using (var file = MemoryMappedFile.CreateFromFile(filename, FileMode.Create, null, estimatedStateSize)) - using (var view = file.CreateViewAccessor(0, estimatedStateSize)) - { - unsafe - { - byte* ptr = null; - view.SafeMemoryMappedViewHandle.AcquirePointer(ref ptr); - writtenBytes = (long)NativeApi.llama_copy_state_data(_ctx, ptr); - view.SafeMemoryMappedViewHandle.ReleasePointer(); - } - } + context = new LLamaModelContext(this, encoding, _logger); + if (_contexts.TryAdd(contextId, context)) + return Task.FromResult(context); - // Truncate the file to the actual size of data that was written - using (var fileStream = new FileStream(filename, FileMode.Open)) - fileStream.SetLength(writtenBytes); + return Task.FromResult(null); } /// - /// Get the state data as a byte array. 
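// A minimal sketch of the new model/context split (illustration only, not part of the patch).
// CreateContext is assumed to return Task<LLamaModelContext> — generic arguments were stripped when this
// diff was rendered — so it is awaited here. The weights are loaded once per LLamaModel and shared by
// every context created from it; "path/to/model.bin" is a placeholder path.
var model = new LLamaModel(new ModelParams("path/to/model.bin", contextSize: 1024));
var chatContext    = await model.CreateContext("chat");
var summaryContext = await model.CreateContext("summaries"); // a second, independent context
// Re-using an existing id throws, per the TryGetValue check above.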
+ /// Get a contexts belonging to this model /// - /// - [Obsolete("Use `GetState` instead, this supports larger states (over 2GB)")] - public byte[] GetStateData() + /// The unique context identifier + /// LLamaModelContext for this LLamaModel with the specified contextId + public Task GetContext(string contextId) { - var stateSize = NativeApi.llama_get_state_size(_ctx); - byte[] stateMemory = new byte[stateSize]; - NativeApi.llama_copy_state_data(_ctx, stateMemory); - return stateMemory; - } + if (_contexts.TryGetValue(contextId, out var context)) + return Task.FromResult(context); - /// - /// Get the state data as an opaque handle - /// - /// - public State GetState() - { - var stateSize = NativeApi.llama_get_state_size(_ctx); - - unsafe - { - var bigMemory = Marshal.AllocHGlobal((nint)stateSize); - var smallMemory = IntPtr.Zero; - try - { - // Copy the state data into "big memory", discover the actual size required - var actualSize = NativeApi.llama_copy_state_data(_ctx, (byte*)bigMemory); - - // Allocate a smaller buffer - smallMemory = Marshal.AllocHGlobal((nint)actualSize); - - // Copy into the smaller buffer and free the large one to save excess memory usage - Buffer.MemoryCopy(bigMemory.ToPointer(), smallMemory.ToPointer(), actualSize, actualSize); - Marshal.FreeHGlobal(bigMemory); - bigMemory = IntPtr.Zero; - - return new State(smallMemory); - } - catch - { - if (bigMemory != IntPtr.Zero) - Marshal.FreeHGlobal(bigMemory); - if (smallMemory != IntPtr.Zero) - Marshal.FreeHGlobal(smallMemory); - throw; - } - } + return Task.FromResult(null); } /// - /// Load the state from specified path. + /// Remove a context from this model /// - /// - /// - public void LoadState(string filename) + /// The unique context identifier + /// true if removed, otherwise false + public Task RemoveContext(string contextId) { - // Map state file into memory and pass that pointer directly to `llama_set_state_data` to load from - using (var file = MemoryMappedFile.CreateFromFile(filename, FileMode.Open, null)) - using (var view = file.CreateViewAccessor()) - { - unsafe - { - byte* ptr = null; - view.SafeMemoryMappedViewHandle.AcquirePointer(ref ptr); - NativeApi.llama_set_state_data(_ctx, ptr); - view.SafeMemoryMappedViewHandle.ReleasePointer(); - } - } - } + if (!_contexts.TryRemove(contextId, out var context)) + return Task.FromResult(false); - /// - /// Load the state from memory. - /// - /// - /// - public void LoadState(byte[] stateData) - { - int stateSize = (int)NativeApi.llama_get_state_size(_ctx); - if (stateData.Length > stateSize) - { - throw new RuntimeError("Failed to validate state size."); - } - NativeApi.llama_set_state_data(_ctx, stateData); - } - - /// - /// Load the state from memory. - /// - /// - /// - public void LoadState(State state) - { - unsafe - { - NativeApi.llama_set_state_data(_ctx, (byte*)state.DangerousGetHandle().ToPointer()); - } - } - - /// - /// Perform the sampling. Please don't use it unless you fully know what it does. 
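// Continuing the sketch above (illustration only): contexts are later looked up or removed by id.
// GetContext is assumed to resolve to the context (or null when the id is unknown) and RemoveContext
// to a bool, based on the Task.FromResult calls above.
var existing = await model.GetContext("chat");         // null if no context has that id
bool removed = await model.RemoveContext("summaries"); // the removed context is disposed as well
model.Dispose();                                       // disposes the remaining contexts and the weights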
- /// - /// - /// - /// - /// - /// - /// - /// - /// - /// - /// - /// - public llama_token Sample(LLamaTokenDataArray candidates, ref float mirostat_mu, float temperature = 0.8f, MiroStatType mirostat = MiroStatType.Disable, - float mirostatTau = 5.0f, float mirostatEta = 0.1f, int topK = 40, float topP = 0.95f, float tfsZ = 1.0f, float typicalP = 1.0f) - { - llama_token id; - if (temperature <= 0) - { - // Greedy sampling - id = SamplingApi.llama_sample_token_greedy(_ctx, candidates); - } - else - { - if (float.IsNaN(mirostat_mu)) - mirostat_mu = 2 * mirostatTau; - - if (mirostat == MiroStatType.MiroStat) - { - const int mirostat_m = 100; - SamplingApi.llama_sample_temperature(_ctx, candidates, temperature); - id = SamplingApi.llama_sample_token_mirostat(_ctx, candidates, mirostatTau, mirostatEta, mirostat_m, ref mirostat_mu); - } - else if (mirostat == MirostatType.Mirostat2) - { - SamplingApi.llama_sample_temperature(_ctx, candidates, temperature); - id = SamplingApi.llama_sample_token_mirostat_v2(_ctx, candidates, mirostatTau, mirostatEta, ref mirostat_mu); - } - else - { - // Temperature sampling - SamplingApi.llama_sample_top_k(_ctx, candidates, topK, 1); - SamplingApi.llama_sample_tail_free(_ctx, candidates, tfsZ, 1); - SamplingApi.llama_sample_typical(_ctx, candidates, typicalP, 1); - SamplingApi.llama_sample_top_p(_ctx, candidates, topP, 1); - SamplingApi.llama_sample_temperature(_ctx, candidates, temperature); - id = SamplingApi.llama_sample_token(_ctx, candidates); - } - } - return id; - } - - /// - /// Apply the penalty for the tokens. Please don't use it unless you fully know what it does. - /// - /// - /// - /// - /// - /// - /// - /// - /// - public LLamaTokenDataArray ApplyPenalty(IEnumerable lastTokens, Dictionary? logitBias = null, - int repeatLastTokensCount = 64, float repeatPenalty = 1.1f, float alphaFrequency = .0f, float alphaPresence = .0f, - bool penalizeNL = true) - { - var n_vocab = NativeApi.llama_n_vocab(_ctx); - var logits = Utils.GetLogits(_ctx, n_vocab); - - // Apply params.logit_bias map - if(logitBias is not null) - { - foreach (var (key, value) in logitBias) - { - logits[key] += value; - } - } - - var candidates = new LLamaTokenData[n_vocab]; - for (llama_token token_id = 0; token_id < n_vocab; token_id++) - candidates[token_id] = new LLamaTokenData(token_id, logits[token_id], 0.0f); - LLamaTokenDataArray candidates_p = new LLamaTokenDataArray(candidates); - - // Apply penalties - float nl_logit = logits[NativeApi.llama_token_nl()]; - int lastTokensCount = lastTokens.Count(); - var last_n_repeat = Math.Min(Math.Min(lastTokensCount, repeatLastTokensCount), ContextSize); - SamplingApi.llama_sample_repetition_penalty(_ctx, candidates_p, - lastTokens.Skip(lastTokensCount - last_n_repeat).ToArray(), - (ulong)last_n_repeat, repeatPenalty); - SamplingApi.llama_sample_frequency_and_presence_penalties(_ctx, candidates_p, - lastTokens.Skip(lastTokensCount - last_n_repeat).ToArray(), - (ulong)last_n_repeat, alphaFrequency, alphaPresence); - if (!penalizeNL) - { - logits[NativeApi.llama_token_nl()] = nl_logit; - } - - return candidates_p; - } - - /// - /// - /// - /// - /// - /// The updated `pastTokensCount`. 
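// A rough sketch of the low-level generation loop the executors are built on; these members now live on
// LLamaModelContext rather than LLamaModel (illustration only — token ids are assumed to be int, and
// "context"/"prompt" stand for an existing LLamaModelContext and a prompt string).
var promptTokens = context.Tokenize(prompt, addBos: true).ToArray();
int pastTokensCount = context.Eval(promptTokens, 0);
var lastTokens = new List<int>(promptTokens);
float mu = float.NaN;
for (int i = 0; i < 32; i++)
{
    var candidates = context.ApplyPenalty(lastTokens);              // repetition/frequency penalties
    int id = context.Sample(candidates, ref mu, temperature: 0.8f); // MirostatType.Disable by default
    Console.Write(context.DeTokenize(new[] { id }));
    lastTokens.Add(id);
    pastTokensCount = context.Eval(new[] { id }, pastTokensCount);
}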
- /// - public llama_token Eval(llama_token[] tokens, llama_token pastTokensCount) - { - int total = tokens.Length; - for(int i = 0; i < total; i += Params.BatchSize) - { - int n_eval = total - i; - if(n_eval > Params.BatchSize) - { - n_eval = Params.BatchSize; - } - - if(Utils.Eval(_ctx, tokens, i, n_eval, pastTokensCount, Params.Threads) != 0) - { - _logger?.Log(nameof(LLamaModel), "Failed to eval.", ILLamaLogger.LogLevel.Error); - throw new RuntimeError("Failed to eval."); - } - - pastTokensCount += n_eval; - } - return pastTokensCount; - } - - // TODO: add comment - internal IEnumerable GenerateResult(IEnumerable ids) - { - foreach(var id in ids) - { - yield return Utils.TokenToString(id, _ctx, _encoding); - } + context?.Dispose(); + return Task.FromResult(true); } /// public virtual void Dispose() { - _ctx.Dispose(); - } - - /// - /// The state of this model, which can be reloaded later - /// - public class State - : SafeHandleZeroOrMinusOneIsInvalid - { - internal State(IntPtr memory) - : base(true) - { - SetHandle(memory); - } - - /// - protected override bool ReleaseHandle() + foreach (var context in _contexts.Values) { - Marshal.FreeHGlobal(handle); - return true; + context?.Dispose(); } + _model.Dispose(); } } } diff --git a/LLama/LLamaModelContext.cs b/LLama/LLamaModelContext.cs new file mode 100644 index 000000000..15d8298b2 --- /dev/null +++ b/LLama/LLamaModelContext.cs @@ -0,0 +1,388 @@ +using LLama.Abstractions; +using LLama.Common; +using LLama.Exceptions; +using LLama.Extensions; +using LLama.Native; +using Microsoft.Win32.SafeHandles; +using System; +using System.Collections.Generic; +using System.IO; +using System.IO.MemoryMappedFiles; +using System.Linq; +using System.Runtime.InteropServices; +using System.Text; + +namespace LLama +{ + using llama_token = Int32; + /// + /// The abstraction of a context over a LLama model + /// + public class LLamaModelContext : IDisposable + { + // TODO: expose more properties. + ILLamaLogger? _logger; + Encoding _encoding; + SafeLLamaContextHandle _ctx; + /// + /// The context size. + /// + public int ContextSize { get; } + /// + /// The model params set for this model. + /// + public IModelParams Params { get; set; } + /// + /// The native handle, which is used to be passed to the native APIs. Please avoid using it + /// unless you know what is the usage of the Native API. + /// + public SafeLLamaContextHandle NativeHandle => _ctx; + /// + /// The encoding set for this model to deal with text input. + /// + public Encoding Encoding => _encoding; + + /// + /// + /// + /// Model instance. + /// Encoding to deal with text input. + /// The logger. + public LLamaModelContext(LLamaModel model, string encoding = "UTF-8", ILLamaLogger? logger = null) + { + _logger = logger; + Params = model.Params; + _encoding = Encoding.GetEncoding(encoding); + _logger?.Log(nameof(LLamaModelContext), $"Initializing LLama model with params: {Params}", ILLamaLogger.LogLevel.Info); + + var contextParams = Utils.CreateContextParams(Params); + _ctx = SafeLLamaContextHandle.Create(model.NativeHandle, contextParams); + ContextSize = NativeApi.llama_n_ctx(_ctx); + } + /// + /// Tokenize a string. + /// + /// + /// Whether to add a bos to the text. + /// + public IEnumerable Tokenize(string text, bool addBos = true) + { + // TODO: reconsider whether to convert to array here. + return Utils.Tokenize(_ctx, text, addBos, _encoding); + } + + /// + /// Detokenize the tokens to text. 
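// A small round-trip sketch (illustration only): tokenization is per context, because it uses the
// context's native handle and the text encoding passed to the LLamaModelContext constructor.
var ids  = context.Tokenize("Hello world", addBos: true);
var text = context.DeTokenize(ids); // should give back approximately the original text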
+ /// + /// + /// + public string DeTokenize(IEnumerable tokens) + { + StringBuilder sb = new(); + foreach (var token in tokens) + { + sb.Append(Utils.PtrToString(NativeApi.llama_token_to_str(_ctx, token), _encoding)); + } + return sb.ToString(); + } + + /// + /// Save the state to specified path. + /// + /// + public void SaveState(string filename) + { + // Delete that file before overwriting it + if (File.Exists(filename)) + File.Delete(filename); + + // Estimate size of state to write to disk, this is always equal to or greater than the actual size + var estimatedStateSize = (long)NativeApi.llama_get_state_size(_ctx); + + // Map the file and write the bytes directly to it. This saves copying the bytes into a C# array + long writtenBytes; + using (var file = MemoryMappedFile.CreateFromFile(filename, FileMode.Create, null, estimatedStateSize)) + using (var view = file.CreateViewAccessor(0, estimatedStateSize)) + { + unsafe + { + byte* ptr = null; + view.SafeMemoryMappedViewHandle.AcquirePointer(ref ptr); + writtenBytes = (long)NativeApi.llama_copy_state_data(_ctx, ptr); + view.SafeMemoryMappedViewHandle.ReleasePointer(); + } + } + + // Truncate the file to the actual size of data that was written + using (var fileStream = new FileStream(filename, FileMode.Open)) + fileStream.SetLength(writtenBytes); + } + + /// + /// Get the state data as a byte array. + /// + /// + [Obsolete("Use `GetState` instead, this supports larger states (over 2GB)")] + public byte[] GetStateData() + { + var stateSize = NativeApi.llama_get_state_size(_ctx); + byte[] stateMemory = new byte[stateSize]; + NativeApi.llama_copy_state_data(_ctx, stateMemory); + return stateMemory; + } + + /// + /// Get the state data as an opaque handle + /// + /// + public State GetState() + { + var stateSize = NativeApi.llama_get_state_size(_ctx); + + unsafe + { + var bigMemory = Marshal.AllocHGlobal((nint)stateSize); + var smallMemory = IntPtr.Zero; + try + { + // Copy the state data into "big memory", discover the actual size required + var actualSize = NativeApi.llama_copy_state_data(_ctx, (byte*)bigMemory); + + // Allocate a smaller buffer + smallMemory = Marshal.AllocHGlobal((nint)actualSize); + + // Copy into the smaller buffer and free the large one to save excess memory usage + Buffer.MemoryCopy(bigMemory.ToPointer(), smallMemory.ToPointer(), actualSize, actualSize); + Marshal.FreeHGlobal(bigMemory); + bigMemory = IntPtr.Zero; + + return new State(smallMemory); + } + catch + { + if (bigMemory != IntPtr.Zero) + Marshal.FreeHGlobal(bigMemory); + if (smallMemory != IntPtr.Zero) + Marshal.FreeHGlobal(smallMemory); + throw; + } + } + } + + /// + /// Load the state from specified path. + /// + /// + /// + public void LoadState(string filename) + { + // Map state file into memory and pass that pointer directly to `llama_set_state_data` to load from + using (var file = MemoryMappedFile.CreateFromFile(filename, FileMode.Open, null)) + using (var view = file.CreateViewAccessor()) + { + unsafe + { + byte* ptr = null; + view.SafeMemoryMappedViewHandle.AcquirePointer(ref ptr); + NativeApi.llama_set_state_data(_ctx, ptr); + view.SafeMemoryMappedViewHandle.ReleasePointer(); + } + } + } + + /// + /// Load the state from memory. 
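// A sketch of per-context state snapshots (illustration only): state now belongs to the context rather
// than the model, which is what ChatSession.SaveSession/LoadSession rely on via executor.Context above;
// "chat.state" is a placeholder file name.
context.SaveState("chat.state");   // write the current context state to disk
var snapshot = context.GetState(); // or keep it in memory as an opaque, disposable handle
// ... run some inference ...
context.LoadState("chat.state");   // restore from disk
context.LoadState(snapshot);       // or restore from the in-memory snapshot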
+ /// + /// + /// + public void LoadState(byte[] stateData) + { + int stateSize = (int)NativeApi.llama_get_state_size(_ctx); + if (stateData.Length > stateSize) + { + throw new RuntimeError("Failed to validate state size."); + } + NativeApi.llama_set_state_data(_ctx, stateData); + } + + /// + /// Load the state from memory. + /// + /// + /// + public void LoadState(State state) + { + unsafe + { + NativeApi.llama_set_state_data(_ctx, (byte*)state.DangerousGetHandle().ToPointer()); + } + } + + /// + /// Perform the sampling. Please don't use it unless you fully know what it does. + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + public llama_token Sample(LLamaTokenDataArray candidates, ref float mirostat_mu, float temperature = 0.8f, MirostatType mirostat = MirostatType.Disable, + float mirostatTau = 5.0f, float mirostatEta = 0.1f, int topK = 40, float topP = 0.95f, float tfsZ = 1.0f, float typicalP = 1.0f) + { + llama_token id; + if (temperature <= 0) + { + // Greedy sampling + id = SamplingApi.llama_sample_token_greedy(_ctx, candidates); + } + else + { + if (float.IsNaN(mirostat_mu)) + mirostat_mu = 2 * mirostatTau; + + if (mirostat == MirostatType.Mirostat) + { + const int mirostat_m = 100; + SamplingApi.llama_sample_temperature(_ctx, candidates, temperature); + id = SamplingApi.llama_sample_token_mirostat(_ctx, candidates, mirostatTau, mirostatEta, mirostat_m, ref mirostat_mu); + } + else if (mirostat == MirostatType.Mirostat2) + { + SamplingApi.llama_sample_temperature(_ctx, candidates, temperature); + id = SamplingApi.llama_sample_token_mirostat_v2(_ctx, candidates, mirostatTau, mirostatEta, ref mirostat_mu); + } + else + { + // Temperature sampling + SamplingApi.llama_sample_top_k(_ctx, candidates, topK, 1); + SamplingApi.llama_sample_tail_free(_ctx, candidates, tfsZ, 1); + SamplingApi.llama_sample_typical(_ctx, candidates, typicalP, 1); + SamplingApi.llama_sample_top_p(_ctx, candidates, topP, 1); + SamplingApi.llama_sample_temperature(_ctx, candidates, temperature); + id = SamplingApi.llama_sample_token(_ctx, candidates); + } + } + return id; + } + + /// + /// Apply the penalty for the tokens. Please don't use it unless you fully know what it does. + /// + /// + /// + /// + /// + /// + /// + /// + /// + public LLamaTokenDataArray ApplyPenalty(IEnumerable lastTokens, Dictionary? 
logitBias = null, + int repeatLastTokensCount = 64, float repeatPenalty = 1.1f, float alphaFrequency = .0f, float alphaPresence = .0f, + bool penalizeNL = true) + { + var n_vocab = NativeApi.llama_n_vocab(_ctx); + var logits = Utils.GetLogits(_ctx, n_vocab); + + // Apply params.logit_bias map + if (logitBias is not null) + { + foreach (var (key, value) in logitBias) + { + logits[key] += value; + } + } + + var candidates = new LLamaTokenData[n_vocab]; + for (llama_token token_id = 0; token_id < n_vocab; token_id++) + candidates[token_id] = new LLamaTokenData(token_id, logits[token_id], 0.0f); + LLamaTokenDataArray candidates_p = new LLamaTokenDataArray(candidates); + + // Apply penalties + float nl_logit = logits[NativeApi.llama_token_nl()]; + int lastTokensCount = lastTokens.Count(); + var last_n_repeat = Math.Min(Math.Min(lastTokensCount, repeatLastTokensCount), ContextSize); + SamplingApi.llama_sample_repetition_penalty(_ctx, candidates_p, + lastTokens.Skip(lastTokensCount - last_n_repeat).ToArray(), + (ulong)last_n_repeat, repeatPenalty); + SamplingApi.llama_sample_frequency_and_presence_penalties(_ctx, candidates_p, + lastTokens.Skip(lastTokensCount - last_n_repeat).ToArray(), + (ulong)last_n_repeat, alphaFrequency, alphaPresence); + if (!penalizeNL) + { + logits[NativeApi.llama_token_nl()] = nl_logit; + } + + return candidates_p; + } + + /// + /// + /// + /// + /// + /// The updated `pastTokensCount`. + /// + public llama_token Eval(llama_token[] tokens, llama_token pastTokensCount) + { + int total = tokens.Length; + for (int i = 0; i < total; i += Params.BatchSize) + { + int n_eval = total - i; + if (n_eval > Params.BatchSize) + { + n_eval = Params.BatchSize; + } + + if (Utils.Eval(_ctx, tokens, i, n_eval, pastTokensCount, Params.Threads) != 0) + { + _logger?.Log(nameof(LLamaModel), "Failed to eval.", ILLamaLogger.LogLevel.Error); + throw new RuntimeError("Failed to eval."); + } + + pastTokensCount += n_eval; + } + return pastTokensCount; + } + + // TODO: add comment + internal IEnumerable GenerateResult(IEnumerable ids) + { + foreach (var id in ids) + { + yield return Utils.TokenToString(id, _ctx, _encoding); + } + } + + /// + public virtual void Dispose() + { + _ctx.Dispose(); + } + + /// + /// The state of this model, which can be reloaded later + /// + public class State + : SafeHandleZeroOrMinusOneIsInvalid + { + internal State(IntPtr memory) + : base(true) + { + SetHandle(memory); + } + + /// + protected override bool ReleaseHandle() + { + Marshal.FreeHGlobal(handle); + return true; + } + } + } +} \ No newline at end of file diff --git a/LLama/LLamaStatelessExecutor.cs b/LLama/LLamaStatelessExecutor.cs index 88fa16954..8148df3d4 100644 --- a/LLama/LLamaStatelessExecutor.cs +++ b/LLama/LLamaStatelessExecutor.cs @@ -16,17 +16,17 @@ namespace LLama /// public class StatelessExecutor : ILLamaExecutor { - private LLamaModel _model; - private LLamaModel.State _originalState; + private LLamaModelContext _model; + private LLamaModelContext.State _originalState; /// /// The mode used by the executor when running the inference. /// - public LLamaModel Model => _model; + public LLamaModelContext Context => _model; /// /// /// /// The LLama model. - public StatelessExecutor(LLamaModel model) + public StatelessExecutor(LLamaModelContext model) { _model = model; @@ -36,7 +36,7 @@ public StatelessExecutor(LLamaModel model) } /// - public IEnumerable Infer(string text, InferenceParams? 
inferenceParams = null, CancellationToken cancellationToken = default) + public IEnumerable Infer(string text, IInferenceParams? inferenceParams = null, CancellationToken cancellationToken = default) { cancellationToken.ThrowIfCancellationRequested(); int n_past = 1; @@ -123,7 +123,7 @@ public IEnumerable Infer(string text, InferenceParams? inferenceParams = } /// - public async IAsyncEnumerable InferAsync(string text, InferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) + public async IAsyncEnumerable InferAsync(string text, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) { foreach (var result in Infer(text, inferenceParams, cancellationToken)) { diff --git a/LLama/Native/LLamaContextParams.cs b/LLama/Native/LLamaContextParams.cs index 42f2be3fc..17b79ae5f 100644 --- a/LLama/Native/LLamaContextParams.cs +++ b/LLama/Native/LLamaContextParams.cs @@ -47,7 +47,7 @@ public struct LLamaContextParams /// /// how to split layers across multiple GPUs /// - public float[] tensor_split; + public IntPtr tensor_split; /// /// ref: https://github.com/ggerganov/llama.cpp/pull/2054 diff --git a/LLama/ResettableLLamaModel.cs b/LLama/ResettableLLamaModel.cs index f2862dc77..8788aa942 100644 --- a/LLama/ResettableLLamaModel.cs +++ b/LLama/ResettableLLamaModel.cs @@ -8,7 +8,7 @@ namespace LLama /// /// A LLamaModel what could be reset. Note that using this class will consume about 10% more memories. /// - public class ResettableLLamaModel : LLamaModel + public class ResettableLLamaModel : LLamaModelContext { /// /// The initial state of the model @@ -17,9 +17,9 @@ public class ResettableLLamaModel : LLamaModel /// /// /// - /// + /// /// - public ResettableLLamaModel(ModelParams Params, string encoding = "UTF-8") : base(Params, encoding) + public ResettableLLamaModel(LLamaModel model, string encoding = "UTF-8") : base(model, encoding) { OriginalState = GetState(); } diff --git a/LLama/Utils.cs b/LLama/Utils.cs index c08912cf6..3a2ece76f 100644 --- a/LLama/Utils.cs +++ b/LLama/Utils.cs @@ -1,4 +1,5 @@ -using LLama.Common; +using LLama.Abstractions; +using LLama.Common; using LLama.Exceptions; using LLama.Native; using System; @@ -13,41 +14,49 @@ namespace LLama using llama_token = Int32; internal static class Utils { - public static SafeLLamaContextHandle InitLLamaContextFromModelParams(ModelParams @params) + public static SafeLLamaContextHandle InitLLamaContextFromModelParams(IModelParams @params) + { + var lparams = CreateContextParams(@params); + var model = SafeLlamaModelHandle.LoadFromFile(@params.ModelPath, lparams); + var ctx = SafeLLamaContextHandle.Create(model, lparams); + + if (!string.IsNullOrEmpty(@params.LoraAdapter)) + model.ApplyLoraFromFile(@params.LoraAdapter, @params.LoraBase, @params.Threads); + + return ctx; + } + + public static LLamaContextParams CreateContextParams(IModelParams modelParams) { var lparams = NativeApi.llama_context_default_params(); - lparams.n_ctx = @params.ContextSize; - lparams.n_batch = @params.BatchSize; - lparams.main_gpu = @params.MainGpu; - lparams.n_gpu_layers = @params.GpuLayerCount; - lparams.seed = @params.Seed; - lparams.f16_kv = @params.UseFp16Memory; - lparams.use_mmap = @params.UseMemoryLock; - lparams.use_mlock = @params.UseMemoryLock; - lparams.logits_all = @params.Perplexity; - lparams.embedding = @params.EmbeddingMode; - lparams.low_vram = @params.LowVram; - - if (@params.TensorSplits.Length != 1) - { - throw new 
ArgumentException("Currently multi-gpu support is not supported by " + - "both llama.cpp and LLamaSharp."); - } - lparams.tensor_split = @params.TensorSplits; + lparams.n_ctx = modelParams.ContextSize; + lparams.n_batch = modelParams.BatchSize; + lparams.main_gpu = modelParams.MainGpu; + lparams.n_gpu_layers = modelParams.GpuLayerCount; + lparams.seed = modelParams.Seed; + lparams.f16_kv = modelParams.UseFp16Memory; + lparams.use_mmap = modelParams.UseMemoryLock; + lparams.use_mlock = modelParams.UseMemoryLock; + lparams.logits_all = modelParams.Perplexity; + lparams.embedding = modelParams.EmbeddingMode; + lparams.low_vram = modelParams.LowVram; - if (!File.Exists(@params.ModelPath)) + if (modelParams.TensorSplits.Length != 1) { - throw new FileNotFoundException($"The model file does not exist: {@params.ModelPath}"); + throw new ArgumentException("Currently multi-gpu support is not supported by both llama.cpp and LLamaSharp."); } - var model = SafeLlamaModelHandle.LoadFromFile(@params.ModelPath, lparams); - var ctx = SafeLLamaContextHandle.Create(model, lparams); + // Allocate memory for the 'tensor_split' array in C++, + lparams.tensor_split = Marshal.AllocHGlobal(modelParams.TensorSplits.Length * sizeof(float)); + Marshal.Copy(modelParams.TensorSplits, 0, lparams.tensor_split, modelParams.TensorSplits.Length); - if (!string.IsNullOrEmpty(@params.LoraAdapter)) - model.ApplyLoraFromFile(@params.LoraAdapter, @params.LoraBase, @params.Threads); + if (!File.Exists(modelParams.ModelPath)) + { + throw new FileNotFoundException($"The model file does not exist: {modelParams.ModelPath}"); + } - return ctx; + return lparams; } public static IEnumerable Tokenize(SafeLLamaContextHandle ctx, string text, bool add_bos, Encoding encoding) diff --git a/LLama/runtimes/libllama-cuda12.dll b/LLama/runtimes/libllama-cuda12.dll index 97c1a9b13..d392d4c29 100644 Binary files a/LLama/runtimes/libllama-cuda12.dll and b/LLama/runtimes/libllama-cuda12.dll differ diff --git a/LLama/runtimes/libllama.dll b/LLama/runtimes/libllama.dll index 6b9e5f9d0..e0d8ca490 100644 Binary files a/LLama/runtimes/libllama.dll and b/LLama/runtimes/libllama.dll differ diff --git a/LLama/runtimes/libllama.so b/LLama/runtimes/libllama.so index 5acc506a2..aabcd434b 100644 Binary files a/LLama/runtimes/libllama.so and b/LLama/runtimes/libllama.so differ diff --git a/README.md b/README.md index 7c126e942..e6673a33a 100644 --- a/README.md +++ b/README.md @@ -77,8 +77,10 @@ string modelPath = "" // change it to your own model path var prompt = "Transcript of a dialog, where the User interacts with an Assistant named Bob. Bob is helpful, kind, honest, good at writing, and never fails to answer the User's requests immediately and with precision.\r\n\r\nUser: Hello, Bob.\r\nBob: Hello. How may I help you today?\r\nUser: Please tell me the largest city in Europe.\r\nBob: Sure. The largest city in Europe is Moscow, the capital of Russia.\r\nUser:"; // use the "chat-with-bob" prompt here. 
// Initialize a chat session -var ex = new InteractiveExecutor(new LLamaModel(new ModelParams(modelPath, contextSize: 1024, seed: 1337, gpuLayerCount: 5))); -ChatSession session = new ChatSession(ex); +LLamaModel model = new LLamaModel(new ModelParams(modelPath, contextSize: 1024, seed: 1337, gpuLayerCount: 5)); +LLamaModelContext context = await model.CreateContext("context1"); +InteractiveExecutor executor = new InteractiveExecutor(context); +ChatSession session = new ChatSession(executor); // show the prompt Console.WriteLine(); @@ -124,10 +126,25 @@ We provide the integration of ASP.NET core [here](./LLama.WebAPI). Since current Since we are in short of hands, if you're familiar with ASP.NET core, we'll appreciate it if you would like to help upgrading the Web API integration. -## Demo +## Console Demo ![demo-console](Assets/console_demo.gif) + + + + +## Web Demo + +Model Parameters +![demo-ui](https://i.imgur.com/FG0YEzw.png) + + +Inference Parameters +![demo-ui2](https://i.imgur.com/fZEQTQ5.png) + + + ## Roadmap ---
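For stateless scenarios such as the Web API, `StatelessExecutor` follows the same pattern and now wraps a `LLamaModelContext` as well. A minimal sketch (it assumes `InferAsync` streams strings and that `CreateContext` returns `Task<LLamaModelContext>`, since generic arguments were stripped in the diff above; the model path is a placeholder):

```cs
LLamaModel model = new LLamaModel(new ModelParams("path/to/model.bin", contextSize: 1024));
LLamaModelContext context = await model.CreateContext("stateless");
var executor = new StatelessExecutor(context);

await foreach (var token in executor.InferAsync("What is the largest city in Europe?", new InferenceParams { MaxTokens = 64 }))
{
    Console.Write(token);
}
```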