diff --git a/.github/workflows/compile.yml b/.github/workflows/compile.yml
index 4039b0bb3..26de28b54 100644
--- a/.github/workflows/compile.yml
+++ b/.github/workflows/compile.yml
@@ -14,6 +14,10 @@ on:
#schedule:
# - cron: "22 22 * * 2"
+env:
+ # Compiler defines common to all platforms
+ COMMON_DEFINE: -DLLAMA_NATIVE=OFF -DLLAMA_BUILD_TESTS=OFF -DLLAMA_BUILD_EXAMPLES=OFF -DLLAMA_BUILD_SERVER=OFF -DBUILD_SHARED_LIBS=ON
+
jobs:
compile-linux:
name: Compile (Linux)
@@ -22,13 +26,13 @@ jobs:
matrix:
include:
- build: 'noavx'
- defines: '-DLLAMA_BUILD_TESTS=OFF -DLLAMA_BUILD_EXAMPLES=OFF -DLLAMA_BUILD_SERVER=OFF -DLLAMA_AVX=OFF -DLLAMA_AVX2=OFF -DLLAMA_FMA=OFF -DBUILD_SHARED_LIBS=ON'
+ defines: '-DLLAMA_AVX=OFF -DLLAMA_AVX2=OFF -DLLAMA_FMA=OFF'
- build: 'avx2'
- defines: '-DLLAMA_BUILD_TESTS=OFF -DLLAMA_BUILD_EXAMPLES=OFF -DLLAMA_BUILD_SERVER=OFF -DBUILD_SHARED_LIBS=ON'
+ defines: ''
- build: 'avx'
- defines: '-DLLAMA_BUILD_TESTS=OFF -DLLAMA_BUILD_EXAMPLES=OFF -DLLAMA_BUILD_SERVER=OFF -DLLAMA_AVX2=OFF -DBUILD_SHARED_LIBS=ON'
+ defines: '-DLLAMA_AVX2=OFF'
- build: 'avx512'
- defines: '-DLLAMA_BUILD_TESTS=OFF -DLLAMA_BUILD_EXAMPLES=OFF -DLLAMA_BUILD_SERVER=OFF -DLLAMA_AVX512=ON -DBUILD_SHARED_LIBS=ON'
+ defines: '-DLLAMA_AVX512=ON'
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
@@ -39,7 +43,7 @@ jobs:
run: |
mkdir build
cd build
- cmake .. ${{ matrix.defines }}
+ cmake .. ${{ env.COMMON_DEFINE }} ${{ matrix.defines }}
cmake --build . --config Release -j ${env:NUMBER_OF_PROCESSORS}
- uses: actions/upload-artifact@v3
with:
@@ -53,13 +57,13 @@ jobs:
matrix:
include:
- build: 'noavx'
- defines: '-DLLAMA_NATIVE=OFF -DLLAMA_BUILD_SERVER=OFF -DLLAMA_AVX=OFF -DLLAMA_AVX2=OFF -DLLAMA_FMA=OFF -DBUILD_SHARED_LIBS=ON'
+ defines: '-DLLAMA_AVX=OFF -DLLAMA_AVX2=OFF -DLLAMA_FMA=OFF'
- build: 'avx2'
- defines: '-DLLAMA_NATIVE=OFF -DLLAMA_BUILD_SERVER=OFF -DBUILD_SHARED_LIBS=ON'
+ defines: ''
- build: 'avx'
- defines: '-DLLAMA_NATIVE=OFF -DLLAMA_BUILD_SERVER=OFF -DLLAMA_AVX2=OFF -DBUILD_SHARED_LIBS=ON'
+ defines: '-DLLAMA_AVX2=OFF'
- build: 'avx512'
- defines: '-DLLAMA_NATIVE=OFF -DLLAMA_BUILD_SERVER=OFF -DLLAMA_AVX512=ON -DBUILD_SHARED_LIBS=ON'
+ defines: '-DLLAMA_AVX512=ON'
runs-on: windows-latest
steps:
- uses: actions/checkout@v3
@@ -71,7 +75,7 @@ jobs:
run: |
mkdir build
cd build
- cmake .. ${{ matrix.defines }}
+ cmake .. ${{ env.COMMON_DEFINE }} ${{ matrix.defines }}
cmake --build . --config Release -j ${env:NUMBER_OF_PROCESSORS}
- name: Upload artifacts
@@ -117,7 +121,7 @@ jobs:
run: |
mkdir build
cd build
- cmake .. -DLLAMA_CUBLAS=ON -DBUILD_SHARED_LIBS=ON -DLLAMA_BUILD_TESTS=OFF -DLLAMA_BUILD_EXAMPLES=OFF -DLLAMA_BUILD_SERVER=OFF
+ cmake .. ${{ env.COMMON_DEFINE }} -DLLAMA_CUBLAS=ON
cmake --build . --config Release -j ${env:NUMBER_OF_PROCESSORS}
ls -R
@@ -142,7 +146,7 @@ jobs:
matrix:
include:
- build: 'metal'
- defines: '-DLLAMA_BUILD_TESTS=OFF -DLLAMA_BUILD_EXAMPLES=OFF -DLLAMA_BUILD_SERVER=OFF -DBUILD_SHARED_LIBS=ON -DLLAMA_NATIVE=OFF -DCMAKE_OSX_ARCHITECTURES=arm64'
+ defines: '-DCMAKE_OSX_ARCHITECTURES=arm64'
runs-on: macos-latest
steps:
- uses: actions/checkout@v3
@@ -157,7 +161,7 @@ jobs:
run: |
mkdir build
cd build
- cmake .. ${{ matrix.defines }}
+ cmake .. ${{ env.COMMON_DEFINE }} ${{ matrix.defines }}
cmake --build . --config Release -j ${env:NUMBER_OF_PROCESSORS}
- name: Upload artifacts
uses: actions/upload-artifact@v3
diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml
index 3fbfc9f50..1c08e6e57 100644
--- a/.github/workflows/main.yml
+++ b/.github/workflows/main.yml
@@ -12,23 +12,14 @@ jobs:
strategy:
fail-fast: false
matrix:
- build: [linux-debug, linux-release, windows-debug, windows-release]
+ build: [linux-release, windows-release]
include:
- - build: linux-debug
- os: ubuntu-latest
- config: debug
- build: linux-release
os: ubuntu-latest
- config: release
- # - build: macos-debug
- # os: macos-latest
- # config: debug
+ config: release
# - build: macos-release
# os: macos-latest
# config: release
- - build: windows-debug
- os: windows-2019
- config: debug
- build: windows-release
os: windows-2019
config: release
diff --git a/LLama.Examples/LLama.Examples.csproj b/LLama.Examples/LLama.Examples.csproj
index c17618297..02066fa0e 100644
--- a/LLama.Examples/LLama.Examples.csproj
+++ b/LLama.Examples/LLama.Examples.csproj
@@ -30,6 +30,7 @@
+
diff --git a/LLama.Examples/NewVersion/GetEmbeddings.cs b/LLama.Examples/NewVersion/GetEmbeddings.cs
index fe9e3ea80..1e5b19be3 100644
--- a/LLama.Examples/NewVersion/GetEmbeddings.cs
+++ b/LLama.Examples/NewVersion/GetEmbeddings.cs
@@ -4,7 +4,7 @@ namespace LLama.Examples.NewVersion
{
public class GetEmbeddings
{
- public static void Run()
+ public static Task Run()
{
Console.Write("Please input your model path: ");
var modelPath = Console.ReadLine();
@@ -23,6 +23,7 @@ public static void Run()
Console.WriteLine(string.Join(", ", embedder.GetEmbeddings(text)));
Console.WriteLine();
}
+ return Task.CompletedTask;
}
}
}
diff --git a/LLama.Examples/NewVersion/QuantizeModel.cs b/LLama.Examples/NewVersion/QuantizeModel.cs
index 71966af8f..456d89290 100644
--- a/LLama.Examples/NewVersion/QuantizeModel.cs
+++ b/LLama.Examples/NewVersion/QuantizeModel.cs
@@ -2,7 +2,7 @@
{
public class QuantizeModel
{
- public static void Run()
+ public static Task Run()
{
Console.Write("Please input your original model path: ");
var inputPath = Console.ReadLine();
@@ -21,6 +21,8 @@ public static void Run()
{
Console.WriteLine("Quantization failed!");
}
+
+ return Task.CompletedTask;
}
}
}
diff --git a/LLama.Examples/NewVersion/TestRunner.cs b/LLama.Examples/NewVersion/TestRunner.cs
index a21a2eed4..c89cba305 100644
--- a/LLama.Examples/NewVersion/TestRunner.cs
+++ b/LLama.Examples/NewVersion/TestRunner.cs
@@ -1,109 +1,54 @@
-namespace LLama.Examples.NewVersion
+using System.Linq.Expressions;
+using Spectre.Console;
+
+namespace LLama.Examples.NewVersion
{
public class NewVersionTestRunner
{
+ static Dictionary<string, Func<Task>> Examples = new Dictionary<string, Func<Task>>
+ {
+ {"Run a chat session without stripping the role names.", () => ChatSessionWithRoleName.Run()},
+ {"Run a chat session with the role names stripped.",()=> ChatSessionStripRoleName.Run()},
+ {"Interactive mode chat by using executor.",()=> InteractiveModeExecute.Run()},
+ {"Instruct mode chat by using executor.",()=> InstructModeExecute.Run()},
+ {"Stateless mode chat by using executor.",()=> StatelessModeExecute.Run()},
+ {"Load and save chat session.",()=> SaveAndLoadSession.Run()},
+ {"Load and save state of model and executor.",()=> LoadAndSaveState.Run()},
+ {"Get embeddings from LLama model.",()=> GetEmbeddings.Run()},
+ {"Quantize the model.",()=> QuantizeModel.Run()},
+ {"Automatic conversation.",()=> TalkToYourself.Run()},
+ {"Constrain response to json format using grammar.",()=> GrammarJsonResponse.Run()},
+ {"Semantic Kernel Prompt.",()=> SemanticKernelPrompt.Run()},
+ {"Semantic Kernel Chat.",()=> SemanticKernelChat.Run()},
+ {"Semantic Kernel Memory.",()=> SemanticKernelMemory.Run()},
+ {"Coding Assistant.",()=> CodingAssistant.Run()},
+ {"Batch Decoding.",()=> BatchedDecoding.Run()},
+ {"SK Kernel Memory.",()=> KernelMemory.Run()},
+ {"Exit", ()=> Task.CompletedTask}
+ };
public static async Task Run()
{
- Console.WriteLine("================LLamaSharp Examples (New Version)==================\n");
-
- Console.WriteLine("Please input a number to choose an example to run:");
- Console.WriteLine("0: Run a chat session without stripping the role names.");
- Console.WriteLine("1: Run a chat session with the role names stripped.");
- Console.WriteLine("2: Interactive mode chat by using executor.");
- Console.WriteLine("3: Instruct mode chat by using executor.");
- Console.WriteLine("4: Stateless mode chat by using executor.");
- Console.WriteLine("5: Load and save chat session.");
- Console.WriteLine("6: Load and save state of model and executor.");
- Console.WriteLine("7: Get embeddings from LLama model.");
- Console.WriteLine("8: Quantize the model.");
- Console.WriteLine("9: Automatic conversation.");
- Console.WriteLine("10: Constrain response to json format using grammar.");
- Console.WriteLine("11: Semantic Kernel Prompt.");
- Console.WriteLine("12: Semantic Kernel Chat.");
- Console.WriteLine("13: Semantic Kernel Memory.");
- Console.WriteLine("14: Coding Assistant.");
- Console.WriteLine("15: Batch Decoding.");
- Console.WriteLine("16: SK Kernel Memory.");
+ AnsiConsole.Write(new Rule("LLamaSharp Examples"));
while (true)
{
- Console.Write("\nYour choice: ");
- int choice = int.Parse(Console.ReadLine());
+ var choice = AnsiConsole.Prompt(
+ new SelectionPrompt<string>()
+ .Title("Please choose[green] an example[/] to run: ")
+ .AddChoices(Examples.Keys));
- if (choice == 0)
- {
- await ChatSessionWithRoleName.Run();
- }
- else if (choice == 1)
- {
- await ChatSessionStripRoleName.Run();
- }
- else if (choice == 2)
- {
- await InteractiveModeExecute.Run();
- }
- else if (choice == 3)
- {
- await InstructModeExecute.Run();
- }
- else if (choice == 4)
- {
- await StatelessModeExecute.Run();
- }
- else if (choice == 5)
- {
- await SaveAndLoadSession.Run();
- }
- else if (choice == 6)
- {
- await LoadAndSaveState.Run();
- }
- else if (choice == 7)
- {
- GetEmbeddings.Run();
- }
- else if (choice == 8)
- {
- QuantizeModel.Run();
- }
- else if (choice == 9)
- {
- await TalkToYourself.Run();
- }
- else if (choice == 10)
- {
- await GrammarJsonResponse.Run();
- }
- else if (choice == 11)
- {
- await SemanticKernelPrompt.Run();
- }
- else if (choice == 12)
- {
- await SemanticKernelChat.Run();
- }
- else if (choice == 13)
- {
- await SemanticKernelMemory.Run();
- }
- else if (choice == 14)
- {
- await CodingAssistant.Run();
- }
- else if (choice == 15)
- {
- await BatchedDecoding.Run();
- }
- else if (choice == 16)
- {
- await KernelMemory.Run();
- }
- else
+
+ if (Examples.TryGetValue(choice, out var example))
{
- Console.WriteLine("Cannot parse your choice. Please select again.");
- continue;
+ if (choice == "Exit")
+ {
+ break;
+ }
+ AnsiConsole.Write(new Rule(choice));
+ await example();
}
- break;
+
+ AnsiConsole.Clear();
}
}
}
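
The rewritten runner above replaces the numbered if/else menu with a dictionary of named examples dispatched through Spectre.Console's selection prompt. A minimal, self-contained sketch of that pattern (the example names and delegates below are illustrative placeholders, not entries from the repository):

```csharp
using System;
using System.Collections.Generic;
using System.Threading.Tasks;
using Spectre.Console;

// Map display names to async delegates and let Spectre.Console render the menu.
var examples = new Dictionary<string, Func<Task>>
{
    { "Say hello", () => { Console.WriteLine("Hello!"); return Task.CompletedTask; } },
    { "Exit", () => Task.CompletedTask }
};

while (true)
{
    var choice = AnsiConsole.Prompt(
        new SelectionPrompt<string>()
            .Title("Please choose an example to run:")
            .AddChoices(examples.Keys));

    if (choice == "Exit")
        break;

    await examples[choice]();   // every entry is a Func<Task>, so all are awaited uniformly
    AnsiConsole.Clear();
}
```

This uniform dispatch is also why GetEmbeddings.Run and QuantizeModel.Run above were changed to return Task: every menu entry can then be awaited the same way.
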
diff --git a/LLama.KernelMemory/BuilderExtensions.cs b/LLama.KernelMemory/BuilderExtensions.cs
index 7d280b494..0b92ca6ed 100644
--- a/LLama.KernelMemory/BuilderExtensions.cs
+++ b/LLama.KernelMemory/BuilderExtensions.cs
@@ -4,6 +4,9 @@
using System.Linq;
using System.Text;
using System.Threading.Tasks;
+using LLama;
+using LLama.Common;
+using Microsoft.KernelMemory.AI;
namespace LLamaSharp.KernelMemory
{
@@ -24,6 +27,18 @@ public static KernelMemoryBuilder WithLLamaSharpTextEmbeddingGeneration(this Ker
return builder;
}
+ /// <summary>
+ /// Adds LLamaSharpTextEmbeddingGeneration to the KernelMemoryBuilder.
+ /// </summary>
+ /// <param name="builder">The KernelMemoryBuilder instance.</param>
+ /// <param name="textEmbeddingGeneration">The LLamaSharpTextEmbeddingGeneration instance.</param>
+ /// <returns>The KernelMemoryBuilder instance with LLamaSharpTextEmbeddingGeneration added.</returns>
+ public static KernelMemoryBuilder WithLLamaSharpTextEmbeddingGeneration(this KernelMemoryBuilder builder, LLamaSharpTextEmbeddingGeneration textEmbeddingGeneration)
+ {
+ builder.WithCustomEmbeddingGeneration(textEmbeddingGeneration);
+ return builder;
+ }
+
///
/// Adds LLamaSharpTextGeneration to the KernelMemoryBuilder.
///
@@ -36,6 +51,18 @@ public static KernelMemoryBuilder WithLLamaSharpTextGeneration(this KernelMemory
return builder;
}
+ /// <summary>
+ /// Adds LLamaSharpTextGeneration to the KernelMemoryBuilder.
+ /// </summary>
+ /// <param name="builder">The KernelMemoryBuilder instance.</param>
+ /// <param name="textGeneration">The LlamaSharpTextGeneration instance.</param>
+ /// <returns>The KernelMemoryBuilder instance with LLamaSharpTextGeneration added.</returns>
+ public static KernelMemoryBuilder WithLLamaSharpTextGeneration(this KernelMemoryBuilder builder, LlamaSharpTextGeneration textGeneration)
+ {
+ builder.WithCustomTextGeneration(textGeneration);
+ return builder;
+ }
+
///
/// Adds LLamaSharpTextEmbeddingGeneration and LLamaSharpTextGeneration to the KernelMemoryBuilder.
///
@@ -44,8 +71,18 @@ public static KernelMemoryBuilder WithLLamaSharpTextGeneration(this KernelMemory
/// The KernelMemoryBuilder instance with LLamaSharpTextEmbeddingGeneration and LLamaSharpTextGeneration added.
public static KernelMemoryBuilder WithLLamaSharpDefaults(this KernelMemoryBuilder builder, LLamaSharpConfig config)
{
- builder.WithLLamaSharpTextEmbeddingGeneration(config);
- builder.WithLLamaSharpTextGeneration(config);
+ var parameters = new ModelParams(config.ModelPath)
+ {
+ ContextSize = config?.ContextSize ?? 2048,
+ Seed = config?.Seed ?? 0,
+ GpuLayerCount = config?.GpuLayerCount ?? 20
+ };
+ var weights = LLamaWeights.LoadFromFile(parameters);
+ var context = weights.CreateContext(parameters);
+ var executor = new StatelessExecutor(weights, parameters);
+ var embedder = new LLamaEmbedder(weights, parameters);
+ builder.WithLLamaSharpTextEmbeddingGeneration(new LLamaSharpTextEmbeddingGeneration(embedder));
+ builder.WithLLamaSharpTextGeneration(new LlamaSharpTextGeneration(weights, context, executor));
return builder;
}
}
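
WithLLamaSharpDefaults now loads one set of weights and shares it between text generation and embeddings, and the new overloads accept pre-built components directly. A hedged usage sketch of that wiring (the model path and parameter values are placeholders, and the KernelMemoryBuilder using is assumed from the Kernel Memory package):

```csharp
using LLama;
using LLama.Common;
using LLamaSharp.KernelMemory;
using Microsoft.KernelMemory;   // assumed namespace for KernelMemoryBuilder

// Illustrative only: build the LLamaSharp pieces once and hand them to the
// builder through the new overloads, so both generators reuse the same weights.
var parameters = new ModelParams("path/to/model.gguf")   // placeholder path
{
    ContextSize = 2048,
    GpuLayerCount = 20
};
using var weights = LLamaWeights.LoadFromFile(parameters);
using var context = weights.CreateContext(parameters);
var executor = new StatelessExecutor(weights, parameters);
using var embedder = new LLamaEmbedder(weights, parameters);

var builder = new KernelMemoryBuilder()
    .WithLLamaSharpTextEmbeddingGeneration(new LLamaSharpTextEmbeddingGeneration(embedder))
    .WithLLamaSharpTextGeneration(new LlamaSharpTextGeneration(weights, context, executor));
// continue configuring / building kernel memory as usual
```

Because the wrappers are handed existing instances here, their Dispose (see the ownership flags in the next two files) leaves the caller-owned weights, context and embedder untouched.
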
diff --git a/LLama.KernelMemory/LLamaSharpTextEmbeddingGeneration.cs b/LLama.KernelMemory/LLamaSharpTextEmbeddingGeneration.cs
index a1681e153..cebbbe64b 100644
--- a/LLama.KernelMemory/LLamaSharpTextEmbeddingGeneration.cs
+++ b/LLama.KernelMemory/LLamaSharpTextEmbeddingGeneration.cs
@@ -1,4 +1,5 @@
using LLama;
+using LLama.Abstractions;
using LLama.Common;
using Microsoft.SemanticKernel.AI.Embeddings;
using System;
@@ -14,9 +15,11 @@ namespace LLamaSharp.KernelMemory
///
public class LLamaSharpTextEmbeddingGeneration : ITextEmbeddingGeneration, IDisposable
{
- private readonly LLamaSharpConfig _config;
+ private readonly LLamaSharpConfig? _config;
+ private readonly LLamaWeights? _weights;
private readonly LLamaEmbedder _embedder;
- private readonly LLamaWeights _weights;
+ private bool _ownsEmbedder = false;
+ private bool _ownsWeights = false;
///
/// Initializes a new instance of the class.
@@ -28,13 +31,46 @@ public LLamaSharpTextEmbeddingGeneration(LLamaSharpConfig config)
var @params = new ModelParams(_config.ModelPath);
_weights = LLamaWeights.LoadFromFile(@params);
_embedder = new LLamaEmbedder(_weights, @params);
+ _ownsWeights = true;
+ _ownsEmbedder = true;
+ }
+
+ /// <summary>
+ /// Initializes a new instance of the class from reused weights.
+ /// </summary>
+ /// <param name="config">The configuration for LLamaSharp.</param>
+ /// <param name="weights">A LLamaWeights object.</param>
+ public LLamaSharpTextEmbeddingGeneration(LLamaSharpConfig config, LLamaWeights weights)
+ {
+ this._config = config;
+ var @params = new ModelParams(_config.ModelPath);
+ _weights = weights;
+ _embedder = new LLamaEmbedder(_weights, @params);
+ _ownsEmbedder = true;
+ }
+
+ /// <summary>
+ /// Initializes a new instance of the class from reused embedder.
+ /// </summary>
+ /// <param name="embedder">A LLamaEmbedder object.</param>
+ public LLamaSharpTextEmbeddingGeneration(LLamaEmbedder embedder)
+ {
+ this._config = null;
+ this._weights = null;
+ _embedder = embedder;
}
///
public void Dispose()
{
- _embedder.Dispose();
- _weights.Dispose();
+ if (_ownsWeights)
+ {
+ _weights?.Dispose();
+ }
+ if(_ownsEmbedder)
+ {
+ _embedder.Dispose();
+ }
}
///
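
The _ownsWeights/_ownsEmbedder flags above give each constructor a different lifetime contract: the config-based constructor owns everything it creates, the (config, weights) overload owns only the embedder it creates, and the (embedder) overload owns nothing. A small hedged sketch of what that means for callers (placeholder path):

```csharp
using LLama;
using LLama.Common;
using LLamaSharp.KernelMemory;

var @params = new ModelParams("path/to/model.gguf");   // placeholder path
using var weights = LLamaWeights.LoadFromFile(@params);
using var embedder = new LLamaEmbedder(weights, @params);

// The wrapper only borrows the embedder, so disposing it must not
// dispose the embedder -- the caller keeps ownership.
var embeddings = new LLamaSharpTextEmbeddingGeneration(embedder);
embeddings.Dispose();

// embedder and weights are still usable here; the using declarations
// above dispose them when this scope ends.
```
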
diff --git a/LLama.KernelMemory/LlamaSharpConfig.cs b/LLama.KernelMemory/LlamaSharpConfig.cs
index 2220bf719..7d3aefbef 100644
--- a/LLama.KernelMemory/LlamaSharpConfig.cs
+++ b/LLama.KernelMemory/LlamaSharpConfig.cs
@@ -7,7 +7,7 @@
namespace LLamaSharp.KernelMemory
{
///
- /// Represents the configuration for LLamaSharp.
+ /// Represents the configuration for LLamaSharp. Available properties are `ModelPath`, `ContextSize`, `Seed`, `GpuLayerCount`.
///
public class LLamaSharpConfig
{
diff --git a/LLama.KernelMemory/LlamaSharpTextGeneration.cs b/LLama.KernelMemory/LlamaSharpTextGeneration.cs
index abc534b3c..c3734ea4d 100644
--- a/LLama.KernelMemory/LlamaSharpTextGeneration.cs
+++ b/LLama.KernelMemory/LlamaSharpTextGeneration.cs
@@ -1,4 +1,5 @@
using LLama;
+using LLama.Abstractions;
using LLama.Common;
using Microsoft.KernelMemory.AI;
using System;
@@ -14,10 +15,12 @@ namespace LLamaSharp.KernelMemory
///
public class LlamaSharpTextGeneration : ITextGeneration, IDisposable
{
- private readonly LLamaSharpConfig _config;
+ private readonly LLamaSharpConfig? _config;
private readonly LLamaWeights _weights;
private readonly StatelessExecutor _executor;
private readonly LLamaContext _context;
+ private bool _ownsContext = false;
+ private bool _ownsWeights = false;
///
/// Initializes a new instance of the class.
@@ -35,13 +38,35 @@ public LlamaSharpTextGeneration(LLamaSharpConfig config)
_weights = LLamaWeights.LoadFromFile(parameters);
_context = _weights.CreateContext(parameters);
_executor = new StatelessExecutor(_weights, parameters);
+ _ownsWeights = _ownsContext = true;
+ }
+
+ /// <summary>
+ /// Initializes a new instance of the class from reused weights, context and executor.
+ /// If executor is not specified, then a StatelessExecutor will be created with `context.Params`. So far only `StatelessExecutor` is expected.
+ /// </summary>
+ /// <param name="weights">A LLamaWeights object.</param>
+ /// <param name="context">A LLamaContext object.</param>
+ /// <param name="executor">An executor. Currently only StatelessExecutor is expected.</param>
+ public LlamaSharpTextGeneration(LLamaWeights weights, LLamaContext context, StatelessExecutor? executor = null)
+ {
+ _config = null;
+ _weights = weights;
+ _context = context;
+ _executor = executor ?? new StatelessExecutor(_weights, _context.Params);
}
///
public void Dispose()
{
- _context.Dispose();
- _weights.Dispose();
+ if (_ownsWeights)
+ {
+ _weights?.Dispose();
+ }
+ if (_ownsContext)
+ {
+ _context.Dispose();
+ }
}
///
diff --git a/LLama.Web/Common/ModelOptions.cs b/LLama.Web/Common/ModelOptions.cs
index 6a63ccc31..8cbf2f091 100644
--- a/LLama.Web/Common/ModelOptions.cs
+++ b/LLama.Web/Common/ModelOptions.cs
@@ -17,9 +17,9 @@ public class ModelOptions
public int MaxInstances { get; set; }
///
- /// Model context size (n_ctx)
+ /// Model context size (n_ctx). Null to use value from model.
///
- public uint ContextSize { get; set; } = 512;
+ public uint? ContextSize { get; set; }
///
/// the GPU that is used for scratch and small tensors
diff --git a/LLama/Abstractions/IContextParams.cs b/LLama/Abstractions/IContextParams.cs
index 8ff6d7ccf..a2ac24f1a 100644
--- a/LLama/Abstractions/IContextParams.cs
+++ b/LLama/Abstractions/IContextParams.cs
@@ -8,9 +8,9 @@ namespace LLama.Abstractions;
public interface IContextParams
{
///
- /// Model context size (n_ctx)
+ /// Model context size (n_ctx). Null to use value from model file.
///
- uint ContextSize { get; set; }
+ uint? ContextSize { get; set; }
///
/// batch size for prompt processing (must be >=32 to use BLAS) (n_batch)
diff --git a/LLama/ChatSession.cs b/LLama/ChatSession.cs
index 457e7e487..7ee995906 100644
--- a/LLama/ChatSession.cs
+++ b/LLama/ChatSession.cs
@@ -1,11 +1,14 @@
using LLama.Abstractions;
using LLama.Common;
+using System;
using System.Collections.Generic;
using System.IO;
+using System.Linq;
using System.Runtime.CompilerServices;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
+using static LLama.InteractiveExecutor;
namespace LLama
{
@@ -95,11 +98,11 @@ public virtual void SaveSession(string path)
Directory.CreateDirectory(path);
}
_executor.Context.SaveState(Path.Combine(path, _modelStateFilename));
- if(Executor is StatelessExecutor)
+ if (Executor is StatelessExecutor)
{
}
- else if(Executor is StatefulExecutorBase statefulExecutor)
+ else if (Executor is StatefulExecutorBase statefulExecutor)
{
statefulExecutor.SaveState(Path.Combine(path, _executorStateFilename));
}
@@ -135,46 +138,90 @@ public virtual void LoadSession(string path)
}
/// <summary>
- /// Get the response from the LLama model. Note that prompt could not only be the preset words,
- /// but also the question you want to ask.
+ /// Generates a response for a given user prompt and manages history state for the user.
+ /// This will always pass the whole history to the model. Don't pass a whole history
+ /// to this method as the user prompt will be appended to the history of the current session.
+ /// If more control is needed, use the other overload of this method that accepts a ChatHistory object.
/// </summary>
/// <param name="prompt"></param>
/// <param name="inferenceParams"></param>
/// <param name="cancellationToken"></param>
- /// <returns></returns>
+ /// <returns>Returns generated text of the assistant message.</returns>
public async IAsyncEnumerable<string> ChatAsync(string prompt, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
{
- foreach(var inputTransform in InputTransformPipeline)
+ foreach (var inputTransform in InputTransformPipeline)
prompt = inputTransform.Transform(prompt);
-
- History.Messages.AddRange(HistoryTransform.TextToHistory(AuthorRole.User, prompt).Messages);
+
+ History.Messages.Add(new ChatHistory.Message(AuthorRole.User, prompt));
+
+ if (_executor is InteractiveExecutor executor)
+ {
+ InteractiveExecutorState state = (InteractiveExecutorState)executor.GetStateData();
+ prompt = state.IsPromptRun
+ ? HistoryTransform.HistoryToText(History)
+ : prompt;
+ }
+
StringBuilder sb = new();
+
await foreach (var result in ChatAsyncInternal(prompt, inferenceParams, cancellationToken))
{
yield return result;
sb.Append(result);
}
- History.Messages.AddRange(HistoryTransform.TextToHistory(AuthorRole.Assistant, sb.ToString()).Messages);
+
+ string assistantMessage = sb.ToString();
+
+ // Remove end tokens from the assistant message
+ // if defined in inferenceParams.AntiPrompts.
+ // We only want the response that was generated and not tokens
+ // that are delimiting the beginning or end of the response.
+ if (inferenceParams?.AntiPrompts != null)
+ {
+ foreach (var stopToken in inferenceParams.AntiPrompts)
+ {
+ assistantMessage = assistantMessage.Replace(stopToken, "");
+ }
+ }
+
+ History.Messages.Add(new ChatHistory.Message(AuthorRole.Assistant, assistantMessage));
}
/// <summary>
- /// Get the response from the LLama model with chat histories.
+ /// Generates a response for a given chat history. This method does not manage history state for the user.
+ /// If you want to e.g. truncate the history of a session to fit into the model's context window,
+ /// use this method and pass the truncated history to it. If you don't need this control, use the other
+ /// overload of this method that accepts a user prompt instead.
/// </summary>
/// <param name="history"></param>
/// <param name="inferenceParams"></param>
/// <param name="cancellationToken"></param>
- /// <returns></returns>
+ /// <returns>Returns generated text of the assistant message.</returns>
public async IAsyncEnumerable<string> ChatAsync(ChatHistory history, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
{
- var prompt = HistoryTransform.HistoryToText(history);
- History.Messages.AddRange(HistoryTransform.TextToHistory(AuthorRole.User, prompt).Messages);
- StringBuilder sb = new();
+ if (history.Messages.Count == 0)
+ {
+ throw new ArgumentException("History must contain at least one message.");
+ }
+
+ string prompt;
+ if (_executor is InteractiveExecutor executor)
+ {
+ InteractiveExecutorState state = (InteractiveExecutorState)executor.GetStateData();
+
+ prompt = state.IsPromptRun
+ ? HistoryTransform.HistoryToText(History)
+ : history.Messages.Last().Content;
+ }
+ else
+ {
+ prompt = history.Messages.Last().Content;
+ }
+
await foreach (var result in ChatAsyncInternal(prompt, inferenceParams, cancellationToken))
{
yield return result;
- sb.Append(result);
}
- History.Messages.AddRange(HistoryTransform.TextToHistory(AuthorRole.Assistant, sb.ToString()).Messages);
}
private async IAsyncEnumerable<string> ChatAsyncInternal(string prompt, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
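
The two ChatAsync overloads now divide responsibility for history: the string overload appends the user prompt and the generated reply to the session's own History, while the ChatHistory overload leaves history management (for example truncation to fit the context window) to the caller. A hedged usage sketch (model setup, prompts and anti-prompt are placeholders):

```csharp
using System;
using System.Collections.Generic;
using LLama;
using LLama.Common;

var parameters = new ModelParams("path/to/model.gguf");        // placeholder path
using var weights = LLamaWeights.LoadFromFile(parameters);
using var context = weights.CreateContext(parameters);
var session = new ChatSession(new InteractiveExecutor(context));
var inferenceParams = new InferenceParams { AntiPrompts = new List<string> { "User:" } };

// 1) Session-managed history: pass only the new user prompt.
await foreach (var token in session.ChatAsync("Hello, who are you?", inferenceParams))
    Console.Write(token);

// 2) Caller-managed history: build (or truncate) a ChatHistory yourself.
var history = new ChatHistory();
history.Messages.Add(new ChatHistory.Message(AuthorRole.User, "Summarize our conversation so far."));
await foreach (var token in session.ChatAsync(history, inferenceParams))
    Console.Write(token);
```
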
diff --git a/LLama/Common/FixedSizeQueue.cs b/LLama/Common/FixedSizeQueue.cs
index d4577a475..37fb1cf51 100644
--- a/LLama/Common/FixedSizeQueue.cs
+++ b/LLama/Common/FixedSizeQueue.cs
@@ -43,7 +43,7 @@ public FixedSizeQueue(int size)
///
public FixedSizeQueue(int size, IEnumerable<T> data)
{
-#if !NETSTANDARD2_0
+#if NET6_0_OR_GREATER
// Try to check the size without enumerating the entire IEnumerable. This may not be able to get the count,
// in which case we'll have to check later
if (data.TryGetNonEnumeratedCount(out var dataCount) && dataCount > size)
diff --git a/LLama/Common/ModelParams.cs b/LLama/Common/ModelParams.cs
index ee5bd3e4c..9561e482e 100644
--- a/LLama/Common/ModelParams.cs
+++ b/LLama/Common/ModelParams.cs
@@ -12,105 +12,68 @@ namespace LLama.Common
public record ModelParams
: ILLamaParams
{
- ///
- /// Model context size (n_ctx)
- ///
- public uint ContextSize { get; set; } = 512;
- ///
- /// the GPU that is used for scratch and small tensors
- ///
+ ///
+ public uint? ContextSize { get; set; }
+
+ ///
public int MainGpu { get; set; } = 0;
- ///
- /// Number of layers to run in VRAM / GPU memory (n_gpu_layers)
- ///
+ ///
public int GpuLayerCount { get; set; } = 20;
- ///
- /// Seed for the random number generator (seed)
- ///
+
+ ///
public uint Seed { get; set; } = 0xFFFFFFFF;
- ///
- /// Use f16 instead of f32 for memory kv (memory_f16)
- ///
+
+ ///
public bool UseFp16Memory { get; set; } = true;
- ///
- /// Use mmap for faster loads (use_mmap)
- ///
+
+ ///
public bool UseMemorymap { get; set; } = true;
- ///
- /// Use mlock to keep model in memory (use_mlock)
- ///
+
+ ///
public bool UseMemoryLock { get; set; }
- ///
- /// Compute perplexity over the prompt (perplexity)
- ///
+
+ ///
public bool Perplexity { get; set; }
- ///
- /// Model path (model)
- ///
+
+ ///
public string ModelPath { get; set; }
- ///
- /// List of LoRAs to apply
- ///
+ ///
public AdapterCollection LoraAdapters { get; set; } = new();
- ///
- /// base model path for the lora adapter (lora_base)
- ///
+ ///
public string LoraBase { get; set; } = string.Empty;
- ///
- /// Number of threads (null = autodetect) (n_threads)
- ///
+ ///
public uint? Threads { get; set; }
- ///
- /// Number of threads to use for batch processing (null = autodetect) (n_threads)
- ///
+ ///
public uint? BatchThreads { get; set; }
- ///
- /// batch size for prompt processing (must be >=32 to use BLAS) (n_batch)
- ///
+ ///
public uint BatchSize { get; set; } = 512;
- ///
- /// Whether to use embedding mode. (embedding) Note that if this is set to true,
- /// The LLamaModel won't produce text response anymore.
- ///
+ ///
public bool EmbeddingMode { get; set; }
- ///
- /// how split tensors should be distributed across GPUs.
- ///
- /// "[ 3, 2 ]" will assign 60% of the data to GPU 0 and 40% to GPU 1.
+ ///
[JsonConverter(typeof(TensorSplitsCollectionConverter))]
public TensorSplitsCollection TensorSplits { get; set; } = new();
- ///
- /// RoPE base frequency
- ///
- public float? RopeFrequencyBase { get; set; }
+ ///
+ public float? RopeFrequencyBase { get; set; }
- ///
- /// RoPE frequency scaling factor
- ///
- public float? RopeFrequencyScale { get; set; }
+ ///
+ public float? RopeFrequencyScale { get; set; }
- ///
- /// Use experimental mul_mat_q kernels
- ///
- public bool MulMatQ { get; set; }
+ ///
+ public bool MulMatQ { get; set; }
- ///
- /// Load vocab only (no weights)
- ///
+ ///
public bool VocabOnly { get; set; }
- ///
- /// The encoding to use to convert text for the model
- ///
+ ///
[JsonConverter(typeof(EncodingConverter))]
public Encoding Encoding { get; set; } = Encoding.UTF8;
diff --git a/LLama/Extensions/DictionaryExtensions.cs b/LLama/Extensions/DictionaryExtensions.cs
index a39ed7e8b..1af0e9e1f 100644
--- a/LLama/Extensions/DictionaryExtensions.cs
+++ b/LLama/Extensions/DictionaryExtensions.cs
@@ -9,6 +9,8 @@ public static TValue GetValueOrDefault<TKey, TValue>(this IReadOnlyDictionary<TKey, TValue> dictionary, TKey key, TValue defaultValue)
diff --git a/LLama/Extensions/EncodingExtensions.cs b/LLama/Extensions/EncodingExtensions.cs
index e88d83a70..5005b16c1 100644
--- a/LLama/Extensions/EncodingExtensions.cs
+++ b/LLama/Extensions/EncodingExtensions.cs
@@ -15,6 +15,8 @@ public static int GetCharCount(this Encoding encoding, ReadOnlySpan<byte> bytes)
{
return GetCharCountImpl(encoding, bytes);
}
+#elif !NET6_0_OR_GREATER && !NETSTANDARD2_1_OR_GREATER
+#error Target framework not supported!
#endif
internal static int GetCharsImpl(Encoding encoding, ReadOnlySpan<byte> bytes, Span<char> output)
diff --git a/LLama/Extensions/IContextParamsExtensions.cs b/LLama/Extensions/IContextParamsExtensions.cs
index fcc9d372a..ed59c9df0 100644
--- a/LLama/Extensions/IContextParamsExtensions.cs
+++ b/LLama/Extensions/IContextParamsExtensions.cs
@@ -21,7 +21,7 @@ public static class IContextParamsExtensions
public static void ToLlamaContextParams(this IContextParams @params, out LLamaContextParams result)
{
result = NativeApi.llama_context_default_params();
- result.n_ctx = @params.ContextSize;
+ result.n_ctx = @params.ContextSize ?? 0;
result.n_batch = @params.BatchSize;
result.seed = @params.Seed;
result.f16_kv = @params.UseFp16Memory;
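
Because ContextSize is now nullable, an unset value flows through this extension as n_ctx = 0, which llama.cpp treats as "use the context length recorded in the model file" (see the `0 = from model` comment on LLamaContextParams below). A short hedged sketch (placeholder path):

```csharp
using LLama.Common;

// Illustrative: leave ContextSize null to defer to the model file (n_ctx = 0),
// or set it explicitly to override the model's value.
var fromModel  = new ModelParams("path/to/model.gguf");                       // ContextSize == null
var explicit4k = new ModelParams("path/to/model.gguf") { ContextSize = 4096 };
```
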
diff --git a/LLama/Extensions/IEnumerableExtensions.cs b/LLama/Extensions/IEnumerableExtensions.cs
index 9e01feb85..17428d297 100644
--- a/LLama/Extensions/IEnumerableExtensions.cs
+++ b/LLama/Extensions/IEnumerableExtensions.cs
@@ -10,6 +10,8 @@ public static IEnumerable<T> TakeLast<T>(this IEnumerable<T> source, int count)
{
return TakeLastImpl(source, count);
}
+#elif !NET6_0_OR_GREATER && !NETSTANDARD2_1_OR_GREATER
+#error Target framework not supported!
#endif
internal static IEnumerable<T> TakeLastImpl<T>(IEnumerable<T> source, int count)
diff --git a/LLama/Extensions/KeyValuePairExtensions.cs b/LLama/Extensions/KeyValuePairExtensions.cs
index 6e12654de..233195ed0 100644
--- a/LLama/Extensions/KeyValuePairExtensions.cs
+++ b/LLama/Extensions/KeyValuePairExtensions.cs
@@ -19,5 +19,7 @@ public static void Deconstruct<TKey, TValue>(this System.Collections.Generic.KeyValuePair<TKey, TValue> pair, out TKey first, out TValue second)
first = pair.Key;
second = pair.Value;
}
+#elif !NET6_0_OR_GREATER && !NETSTANDARD2_1_OR_GREATER
+#error Target framework not supported!
#endif
}
\ No newline at end of file
diff --git a/LLama/Extensions/ListExtensions.cs b/LLama/Extensions/ListExtensions.cs
index 11a1d4f00..eb30a07a0 100644
--- a/LLama/Extensions/ListExtensions.cs
+++ b/LLama/Extensions/ListExtensions.cs
@@ -5,7 +5,7 @@ namespace LLama.Extensions
{
internal static class ListExtensions
{
-#if NETSTANDARD2_0
+#if !NET6_0_OR_GREATER
public static void EnsureCapacity<T>(this List<T> list, int capacity)
{
if (list.Capacity < capacity)
diff --git a/LLama/Native/LLamaContextParams.cs b/LLama/Native/LLamaContextParams.cs
index 50f30c0ae..9a0b2a8e5 100644
--- a/LLama/Native/LLamaContextParams.cs
+++ b/LLama/Native/LLamaContextParams.cs
@@ -22,7 +22,7 @@ public struct LLamaContextParams
public uint seed;
///
- /// text context
+ /// text context, 0 = from model
///
public uint n_ctx;
@@ -42,17 +42,41 @@ public struct LLamaContextParams
public uint n_threads_batch;
///
- /// ref: https://github.com/ggerganov/llama.cpp/pull/2054
- /// RoPE base frequency
+ /// RoPE scaling type, from `enum llama_rope_scaling_type`
///
- public float rope_freq_base;
+ public sbyte rope_scaling_type;
+
///
- /// ref: https://github.com/ggerganov/llama.cpp/pull/2054
- /// RoPE frequency scaling factor
+ /// RoPE base frequency, 0 = from model
///
- public float rope_freq_scale;
-
+ public float rope_freq_base;
+ ///
+ /// RoPE frequency scaling factor, 0 = from model
+ ///
+ public float rope_freq_scale;
+ ///
+ /// YaRN extrapolation mix factor, NaN = from model
+ ///
+ public float yarn_ext_factor;
+ ///
+ /// YaRN magnitude scaling factor
+ ///
+ public float yarn_attn_factor;
+ ///
+ /// YaRN low correction dim
+ ///
+ public float yarn_beta_fast;
+ ///
+ /// YaRN high correction dim
+ ///
+ public float yarn_beta_slow;
+
+ ///
+ /// YaRN original context size
+ ///
+ public uint yarn_orig_ctx;
+
///
/// if true, use experimental mul_mat_q kernels
///
diff --git a/LLama/Native/NativeApi.Sampling.cs b/LLama/Native/NativeApi.Sampling.cs
index e7ee32ba1..365acacfc 100644
--- a/LLama/Native/NativeApi.Sampling.cs
+++ b/LLama/Native/NativeApi.Sampling.cs
@@ -64,6 +64,17 @@ public static extern void llama_sample_repetition_penalties(SafeLLamaContextHand
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern void llama_sample_top_p(SafeLLamaContextHandle ctx, ref LLamaTokenDataArrayNative candidates, float p, ulong min_keep);
+ /// <summary>
+ /// Minimum P sampling as described in https://github.com/ggerganov/llama.cpp/pull/3841
+ /// </summary>
+ /// <param name="ctx"></param>
+ /// <param name="candidates">Pointer to LLamaTokenDataArray</param>
+ /// <param name="p"></param>
+ /// <param name="min_keep"></param>
+ [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
+ public static extern void llama_sample_min_p(SafeLLamaContextHandle ctx, ref LLamaTokenDataArrayNative candidates, float p, ulong min_keep);
+
+
///
/// Tail Free Sampling described in https://www.trentonbricken.com/Tail-Free-Sampling/.
///
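
llama_sample_min_p binds llama.cpp's min-p sampler: candidates whose probability falls below p times the top candidate's probability are dropped, while min_keep guarantees a minimum number of survivors. Purely to illustrate those semantics (this is not the native call, and the probabilities are made up):

```csharp
using System;
using System.Linq;

// Keep tokens with probability >= p * max(probability), but never fewer than minKeep.
static int[] MinPFilter(float[] probs, float p, int minKeep)
{
    var threshold = probs.Max() * p;
    var kept = Enumerable.Range(0, probs.Length)
                         .Where(i => probs[i] >= threshold)
                         .OrderByDescending(i => probs[i])
                         .ToList();
    if (kept.Count < minKeep)
        kept = Enumerable.Range(0, probs.Length)
                         .OrderByDescending(i => probs[i])
                         .Take(minKeep)
                         .ToList();
    return kept.ToArray();
}

// With p = 0.2 the threshold is 0.2 * 0.5 = 0.1, so only the first three token indices survive.
Console.WriteLine(string.Join(", ", MinPFilter(new[] { 0.5f, 0.3f, 0.15f, 0.05f }, 0.2f, 1)));
```
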
diff --git a/LLama/runtimes/ggml-metal.metal b/LLama/runtimes/ggml-metal.metal
index f4b460564..7c35f23a7 100644
--- a/LLama/runtimes/ggml-metal.metal
+++ b/LLama/runtimes/ggml-metal.metal
@@ -184,36 +184,73 @@ kernel void kernel_soft_max(
constant int64_t & ne00,
constant int64_t & ne01,
constant int64_t & ne02,
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint3 tpitg[[thread_position_in_threadgroup]],
- uint3 ntg[[threads_per_threadgroup]]) {
- const int64_t i03 = tgpig[2];
- const int64_t i02 = tgpig[1];
- const int64_t i01 = tgpig[0];
+ threadgroup float * buf [[threadgroup(0)]],
+ uint tgpig[[threadgroup_position_in_grid]],
+ uint tpitg[[thread_position_in_threadgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]],
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint ntg[[threads_per_threadgroup]]) {
+ const int64_t i03 = (tgpig) / (ne02*ne01);
+ const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
+ const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
device const float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
device float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
// parallel max
- float lmax = tpitg[0] < ne00 ? psrc0[tpitg[0]] : -INFINITY;
- for (int i00 = tpitg[0] + ntg[0]; i00 < ne00; i00 += ntg[0]) {
+ float lmax = tpitg < ne00 ? psrc0[tpitg] : -INFINITY;
+
+ for (int i00 = tpitg + ntg; i00 < ne00; i00 += ntg) {
lmax = MAX(lmax, psrc0[i00]);
}
- const float max = simd_max(lmax);
+
+ float max = simd_max(lmax);
+ if (tiisg == 0) {
+ buf[sgitg] = max;
+ }
+
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ // broadcast, simd group number is ntg / 32
+ for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
+ if (tpitg < i) {
+ buf[tpitg] = MAX(buf[tpitg], buf[tpitg + i]);
+ }
+ }
+
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ max = buf[0];
// parallel sum
float lsum = 0.0f;
- for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) {
+ for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
const float exp_psrc0 = exp(psrc0[i00] - max);
lsum += exp_psrc0;
// Remember the result of exp here. exp is expensive, so we really do not
- // whish to compute it twice.
+ // wish to compute it twice.
pdst[i00] = exp_psrc0;
}
- const float sum = simd_sum(lsum);
+ float sum = simd_sum(lsum);
+ if (tiisg == 0) {
+ buf[sgitg] = sum;
+ }
+
+ threadgroup_barrier(mem_flags::mem_threadgroup);
- for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) {
+ // broadcast, simd group number is ntg / 32
+ for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
+ if (tpitg < i) {
+ buf[tpitg] += buf[tpitg + i];
+ }
+ }
+
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ sum = buf[0];
+
+ for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
pdst[i00] /= sum;
}
}
@@ -224,37 +261,73 @@ kernel void kernel_soft_max_4(
constant int64_t & ne00,
constant int64_t & ne01,
constant int64_t & ne02,
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint3 tpitg[[thread_position_in_threadgroup]],
- uint3 ntg[[threads_per_threadgroup]]) {
- const int64_t i03 = tgpig[2];
- const int64_t i02 = tgpig[1];
- const int64_t i01 = tgpig[0];
+ threadgroup float * buf [[threadgroup(0)]],
+ uint tgpig[[threadgroup_position_in_grid]],
+ uint tpitg[[thread_position_in_threadgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]],
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint ntg[[threads_per_threadgroup]]) {
+ const int64_t i03 = (tgpig) / (ne02*ne01);
+ const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
+ const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
device const float4 * psrc4 = (device const float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
device float4 * pdst4 = (device float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
// parallel max
- float4 lmax4 = tpitg[0] < ne00/4 ? psrc4[tpitg[0]] : -INFINITY;
- for (int i00 = tpitg[0] + ntg[0]; i00 < ne00/4; i00 += ntg[0]) {
+ float4 lmax4 = tpitg < ne00/4 ? psrc4[tpitg] : -INFINITY;
+
+ for (int i00 = tpitg + ntg; i00 < ne00/4; i00 += ntg) {
lmax4 = fmax(lmax4, psrc4[i00]);
}
- float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3]));
- const float max = simd_max(lmax);
+ const float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3]));
+ float max = simd_max(lmax);
+ if (tiisg == 0) {
+ buf[sgitg] = max;
+ }
+
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ // broadcast, simd group number is ntg / 32
+ for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
+ if (tpitg < i) {
+ buf[tpitg] = MAX(buf[tpitg], buf[tpitg + i]);
+ }
+ }
+
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ max = buf[0];
// parallel sum
float4 lsum4 = 0.0f;
- for (int i00 = tpitg[0]; i00 < ne00/4; i00 += ntg[0]) {
+ for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
const float4 exp_psrc4 = exp(psrc4[i00] - max);
lsum4 += exp_psrc4;
pdst4[i00] = exp_psrc4;
}
- float lsum = lsum4[0] + lsum4[1] + lsum4[2] + lsum4[3];
- const float sum = simd_sum(lsum);
+ const float lsum = lsum4[0] + lsum4[1] + lsum4[2] + lsum4[3];
+ float sum = simd_sum(lsum);
+ if (tiisg == 0) {
+ buf[sgitg] = sum;
+ }
+
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ // broadcast, simd group number is ntg / 32
+ for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
+ if (tpitg < i) {
+ buf[tpitg] += buf[tpitg + i];
+ }
+ }
+
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ sum = buf[0];
- for (int i00 = tpitg[0]; i00 < ne00/4; i00 += ntg[0]) {
+ for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
pdst4[i00] /= sum;
}
}
@@ -274,7 +347,7 @@ kernel void kernel_diag_mask_inf(
dst[i02*ne01*ne00 + i01*ne00 + i00] = -INFINITY;
} else {
dst[i02*ne01*ne00 + i01*ne00 + i00] = src0[i02*ne01*ne00 + i01*ne00 + i00];
- }
+ }
}
kernel void kernel_diag_mask_inf_8(
@@ -988,6 +1061,45 @@ kernel void kernel_alibi_f32(
}
}
+static float rope_yarn_ramp(const float low, const float high, const int i0) {
+ const float y = (i0 / 2 - low) / max(0.001f, high - low);
+ return 1.0f - min(1.0f, max(0.0f, y));
+}
+
+// YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn
+// MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
+static void rope_yarn(
+ float theta_extrap, float freq_scale, float corr_dims[2], int64_t i0, float ext_factor, float mscale,
+ thread float * cos_theta, thread float * sin_theta
+) {
+ // Get n-d rotational scaling corrected for extrapolation
+ float theta_interp = freq_scale * theta_extrap;
+ float theta = theta_interp;
+ if (ext_factor != 0.0f) {
+ float ramp_mix = rope_yarn_ramp(corr_dims[0], corr_dims[1], i0) * ext_factor;
+ theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
+
+ // Get n-d magnitude scaling corrected for interpolation
+ mscale *= 1.0f + 0.1f * log(1.0f / freq_scale);
+ }
+ *cos_theta = cos(theta) * mscale;
+ *sin_theta = sin(theta) * mscale;
+}
+
+// Apparently solving `n_rot = 2pi * x * base^((2 * max_pos_emb) / n_dims)` for x, we get
+// `corr_fac(n_rot) = n_dims * log(max_pos_emb / (n_rot * 2pi)) / (2 * log(base))`
+static float rope_yarn_corr_factor(int n_dims, int n_orig_ctx, float n_rot, float base) {
+ return n_dims * log(n_orig_ctx / (n_rot * 2 * M_PI_F)) / (2 * log(base));
+}
+
+static void rope_yarn_corr_dims(
+ int n_dims, int n_orig_ctx, float freq_base, float beta_fast, float beta_slow, float dims[2]
+) {
+ // start and end correction dims
+ dims[0] = max(0.0f, floor(rope_yarn_corr_factor(n_dims, n_orig_ctx, beta_fast, freq_base)));
+ dims[1] = min(n_dims - 1.0f, ceil(rope_yarn_corr_factor(n_dims, n_orig_ctx, beta_slow, freq_base)));
+}
+
typedef void (rope_t)(
device const void * src0,
device const int32_t * src1,
@@ -1011,8 +1123,13 @@ typedef void (rope_t)(
constant int & n_past,
constant int & n_dims,
constant int & mode,
+ constant int & n_orig_ctx,
constant float & freq_base,
constant float & freq_scale,
+ constant float & ext_factor,
+ constant float & attn_factor,
+ constant float & beta_fast,
+ constant float & beta_slow,
uint tiitg[[thread_index_in_threadgroup]],
uint3 tptg[[threads_per_threadgroup]],
uint3 tgpig[[threadgroup_position_in_grid]]);
@@ -1041,8 +1158,13 @@ kernel void kernel_rope(
constant int & n_past,
constant int & n_dims,
constant int & mode,
+ constant int & n_orig_ctx,
constant float & freq_base,
constant float & freq_scale,
+ constant float & ext_factor,
+ constant float & attn_factor,
+ constant float & beta_fast,
+ constant float & beta_slow,
uint tiitg[[thread_index_in_threadgroup]],
uint3 tptg[[threads_per_threadgroup]],
uint3 tgpig[[threadgroup_position_in_grid]]) {
@@ -1052,19 +1174,22 @@ kernel void kernel_rope(
const bool is_neox = mode & 2;
+ float corr_dims[2];
+ rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims);
+
device const int32_t * pos = src1;
const int64_t p = pos[i2];
- const float theta_0 = freq_scale * (float)p;
+ const float theta_0 = (float)p;
const float inv_ndims = -1.f/n_dims;
if (!is_neox) {
for (int64_t i0 = 2*tiitg; i0 < ne0; i0 += 2*tptg.x) {
const float theta = theta_0 * pow(freq_base, inv_ndims*i0);
- const float cos_theta = cos(theta);
- const float sin_theta = sin(theta);
+ float cos_theta, sin_theta;
+ rope_yarn(theta, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
@@ -1079,9 +1204,12 @@ kernel void kernel_rope(
for (int64_t ib = 0; ib < ne0/n_dims; ++ib) {
for (int64_t ic = 2*tiitg; ic < n_dims; ic += 2*tptg.x) {
- const float theta = theta_0 * pow(freq_base, inv_ndims*ic - ib);
- const float cos_theta = cos(theta);
- const float sin_theta = sin(theta);
+ // simplified from `(ib * n_dims + ic) * inv_ndims`
+ const float cur_rot = inv_ndims*ic - ib;
+
+ const float theta = theta_0 * pow(freq_base, cur_rot);
+ float cos_theta, sin_theta;
+ rope_yarn(theta, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor, &cos_theta, &sin_theta);
const int64_t i0 = ib*n_dims + ic/2;
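
The rope_yarn_corr_factor helper above comes from counting rotations: assuming the standard RoPE per-dimension frequency base^(-2i/n_dims) (stated here as an assumption, matching the kernel's pow(freq_base, inv_ndims*i0)), dimension pair i completes n_orig_ctx / (2π·base^(2i/n_dims)) full rotations over the original context, and solving for i at a target rotation count n_rot gives the formula in the code:

```latex
n_{\mathrm{rot}} = \frac{n_{\mathrm{orig\_ctx}}}{2\pi \, \mathrm{base}^{2i/n_{\mathrm{dims}}}}
\quad\Longrightarrow\quad
i = \frac{n_{\mathrm{dims}} \,\ln\!\bigl(n_{\mathrm{orig\_ctx}} / (2\pi\, n_{\mathrm{rot}})\bigr)}{2 \ln(\mathrm{base})}
```

rope_yarn_corr_dims then clamps this value to [0, n_dims − 1] for the beta_fast/beta_slow rotation counts, giving the ramp boundaries used by rope_yarn_ramp.
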
diff --git a/LLama/runtimes/libllama-cuda11.dll b/LLama/runtimes/libllama-cuda11.dll
index 70bb5b07a..ab4f4be28 100644
Binary files a/LLama/runtimes/libllama-cuda11.dll and b/LLama/runtimes/libllama-cuda11.dll differ
diff --git a/LLama/runtimes/libllama-cuda11.so b/LLama/runtimes/libllama-cuda11.so
index 45bac80be..146b30abd 100644
Binary files a/LLama/runtimes/libllama-cuda11.so and b/LLama/runtimes/libllama-cuda11.so differ
diff --git a/LLama/runtimes/libllama-cuda12.dll b/LLama/runtimes/libllama-cuda12.dll
index 7f64e0e38..a51954b89 100644
Binary files a/LLama/runtimes/libllama-cuda12.dll and b/LLama/runtimes/libllama-cuda12.dll differ
diff --git a/LLama/runtimes/libllama-cuda12.so b/LLama/runtimes/libllama-cuda12.so
index 4a1e4380b..615d9c704 100644
Binary files a/LLama/runtimes/libllama-cuda12.so and b/LLama/runtimes/libllama-cuda12.so differ
diff --git a/LLama/runtimes/libllama.dll b/LLama/runtimes/libllama.dll
index 00b93ba0f..d2cc2a7be 100644
Binary files a/LLama/runtimes/libllama.dll and b/LLama/runtimes/libllama.dll differ
diff --git a/LLama/runtimes/libllama.dylib b/LLama/runtimes/libllama.dylib
old mode 100755
new mode 100644
index 3f36bb359..54d7a9324
Binary files a/LLama/runtimes/libllama.dylib and b/LLama/runtimes/libllama.dylib differ
diff --git a/LLama/runtimes/libllama.so b/LLama/runtimes/libllama.so
index 5240d696d..e5a01286a 100644
Binary files a/LLama/runtimes/libllama.so and b/LLama/runtimes/libllama.so differ
diff --git a/README.md b/README.md
index 96c9883a1..74d5aee67 100644
--- a/README.md
+++ b/README.md
@@ -11,7 +11,7 @@
**The C#/.NET binding of [llama.cpp](https://github.com/ggerganov/llama.cpp). It provides higher-level APIs to inference the LLaMA Models and deploy it on local device with C#/.NET. It works on
-both Windows, Linux and MAC without requirment for compiling llama.cpp yourself. Even without GPU or not enought GPU memory, you can still apply LLaMA models well with this repo. 🤗**
+both Windows, Linux and MAC without requirement for compiling llama.cpp yourself. Even without GPU or not enough GPU memory, you can still apply LLaMA models well with this repo. 🤗**
**Furthermore, it provides integrations with other projects such as [semantic-kernel](https://github.com/microsoft/semantic-kernel), [kernel-memory](https://github.com/microsoft/kernel-memory) and [BotSharp](https://github.com/SciSharp/BotSharp) to provide higher-level applications.**
@@ -80,7 +80,7 @@ The llama.cpp commit id will help if you want to compile a DLL yourself.
| v0.4.2-preview (cpu,cuda11) |v0.4.2-preview | [Llama2 7b GGML](https://huggingface.co/TheBloke/llama-2-7B-Guanaco-QLoRA-GGML)| 3323112 |
| v0.5.1 | v0.5.1 | [Llama2 7b GGUF](https://huggingface.co/TheBloke/llama-2-7B-Guanaco-QLoRA-GGUF)| 6b73ef1 |
| v0.6.0 | v0.6.0 | | [cb33f43](https://github.com/ggerganov/llama.cpp/commit/cb33f43a2a9f5a5a5f8d290dd97c625d9ba97a2f) |
-| v0.7.0 | v0.7.0 | [Thespis-13B](https://huggingface.co/TheBloke/Thespis-13B-v0.5-GGUF/tree/main?not-for-all-audiences=true), [LLaMA2-7B](https://huggingface.co/TheBloke/Thespis-13B-v0.5-GGUF/tree/main?not-for-all-audiences=true) | [207b519](https://github.com/ggerganov/llama.cpp/commit/207b51900e15cc7f89763a3bb1c565fe11cbb45d) |
+| v0.7.0 | v0.7.0 | [Thespis-13B](https://huggingface.co/TheBloke/Thespis-13B-v0.5-GGUF/tree/main?not-for-all-audiences=true), [LLaMA2-7B](https://huggingface.co/TheBloke/llama-2-7B-Guanaco-QLoRA-GGUF) | [207b519](https://github.com/ggerganov/llama.cpp/commit/207b51900e15cc7f89763a3bb1c565fe11cbb45d) |
Many hands make light work. If you have found any other model resource that could work for a version, we'll appreciate it for opening an PR about it! 😊