diff --git a/LLama.Examples/ExampleRunner.cs b/LLama.Examples/ExampleRunner.cs
index e0feae696..019172fd5 100644
--- a/LLama.Examples/ExampleRunner.cs
+++ b/LLama.Examples/ExampleRunner.cs
@@ -19,8 +19,8 @@ public class ExampleRunner
         { "Executor: Stateless mode chat", StatelessModeExecute.Run },
         { "Save and Load: chat session", SaveAndLoadSession.Run },
         { "Save and Load: state of model and executor", LoadAndSaveState.Run },
-        { "LLama Model: Get embeddings", () => Task.Run(GetEmbeddings.Run) },
-        { "LLama Model: Quantize", () => Task.Run(QuantizeModel.Run) },
+        { "LLama Model: Get embeddings", GetEmbeddings.Run },
+        { "LLama Model: Quantize", QuantizeModel.Run },
         { "Grammar: Constrain response to json format", GrammarJsonResponse.Run },
         { "Kernel Memory: Document Q&A", KernelMemory.Run },
         { "Kernel Memory: Save and Load", KernelMemorySaveAndLoad.Run },
diff --git a/LLama.Examples/Examples/GetEmbeddings.cs b/LLama.Examples/Examples/GetEmbeddings.cs
index ad844004e..a249a5bc4 100644
--- a/LLama.Examples/Examples/GetEmbeddings.cs
+++ b/LLama.Examples/Examples/GetEmbeddings.cs
@@ -1,15 +1,21 @@
 using LLama.Common;
+using LLama.Native;
 
 namespace LLama.Examples.Examples
 {
     public class GetEmbeddings
     {
-        public static void Run()
+        public static async Task Run()
         {
             string modelPath = UserSettings.GetModelPath();
 
             Console.ForegroundColor = ConsoleColor.DarkGray;
-            var @params = new ModelParams(modelPath) { Embeddings = true };
+            var @params = new ModelParams(modelPath)
+            {
+                // Embedding models can return one embedding per token, or all of them can be combined ("pooled") into
+                // one single embedding. Setting PoolingType to "Mean" will combine all of the embeddings using a mean average.
+                PoolingType = LLamaPoolingType.Mean,
+            };
             using var weights = LLamaWeights.LoadFromFile(@params);
             var embedder = new LLamaEmbedder(weights, @params);
 
@@ -17,12 +23,12 @@ public static void Run()
             Console.WriteLine(
                 """
                 This example displays embeddings from a text prompt.
-                Embeddings are numerical codes that represent information like words, images, or concepts.
-                These codes capture important relationships between those objects,
+                Embeddings are vectors that represent information like words, images, or concepts.
+                These vectors capture important relationships between those objects,
                 like how similar words are in meaning or how close images are visually.
                 This allows machine learning models to efficiently understand and process complex data.
                 Embeddings of a text in an LLM are sometimes useful, for example, to train other MLP models.
-                """); // NOTE: this description was AI generated
+                """);
 
             while (true)
             {
@@ -32,8 +38,13 @@ This allows machine learning models to efficiently understand and process comple
                 var text = Console.ReadLine();
                 Console.ForegroundColor = ConsoleColor.White;
 
-                float[] embeddings = embedder.GetEmbeddings(text).Result;
-                Console.WriteLine($"Embeddings contain {embeddings.Length:N0} floating point values:");
+                // Get embeddings for the text
+                var embeddings = await embedder.GetEmbeddings(text);
+
+                // This should have returned one single embedding vector, because PoolingType was set to Mean above.
+                var embedding = embeddings.Single();
+
+                Console.WriteLine($"Embeddings contain {embedding.Length:N0} floating point values:");
                 Console.ForegroundColor = ConsoleColor.DarkGray;
-                Console.WriteLine(string.Join(", ", embeddings.Take(20)) + ", ...");
+                Console.WriteLine(string.Join(", ", embedding.Take(20)) + ", ...");
                 Console.WriteLine();
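The example above pools everything into one vector. For the per-token case, here is a minimal sketch using only the `LLamaEmbedder` API introduced in this diff; the model path is a placeholder:

```csharp
using LLama;
using LLama.Common;
using LLama.Native;

// With PoolingType.None the embedder returns one vector per input token
// instead of a single pooled vector for the whole string.
var @params = new ModelParams("path/to/model.gguf")
{
    PoolingType = LLamaPoolingType.None,
};
using var weights = LLamaWeights.LoadFromFile(@params);
using var embedder = new LLamaEmbedder(weights, @params);

var perToken = await embedder.GetEmbeddings("The cat is cute");
Console.WriteLine($"{perToken.Count} token embeddings of {embedder.EmbeddingSize} floats each");
```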
diff --git a/LLama.Examples/Examples/QuantizeModel.cs b/LLama.Examples/Examples/QuantizeModel.cs
index 233b59678..a1f7ca1bd 100644
--- a/LLama.Examples/Examples/QuantizeModel.cs
+++ b/LLama.Examples/Examples/QuantizeModel.cs
@@ -1,8 +1,8 @@
-namespace LLama.Examples.Examples
+namespace LLama.Examples.Examples
 {
     public class QuantizeModel
     {
-        public static void Run()
+        public static async Task Run()
         {
             string inputPath = UserSettings.GetModelPath();
diff --git a/LLama.KernelMemory/LLamaSharpTextEmbeddingGenerator.cs b/LLama.KernelMemory/LLamaSharpTextEmbeddingGenerator.cs
index 2f6e332d5..79eaa1514 100644
--- a/LLama.KernelMemory/LLamaSharpTextEmbeddingGenerator.cs
+++ b/LLama.KernelMemory/LLamaSharpTextEmbeddingGenerator.cs
@@ -1,5 +1,6 @@
 using LLama;
 using LLama.Common;
+using LLama.Native;
 using Microsoft.KernelMemory;
 using Microsoft.KernelMemory.AI;
 
@@ -35,7 +36,8 @@ public LLamaSharpTextEmbeddingGenerator(LLamaSharpConfig config)
                 GpuLayerCount = config.GpuLayerCount ?? 20,
                 Embeddings = true,
                 MainGpu = config.MainGpu,
-                SplitMode = config.SplitMode
+                SplitMode = config.SplitMode,
+                PoolingType = LLamaPoolingType.Mean,
             };
             _weights = LLamaWeights.LoadFromFile(@params);
             _embedder = new LLamaEmbedder(_weights, @params);
@@ -59,7 +61,8 @@ public LLamaSharpTextEmbeddingGenerator(LLamaSharpConfig config, LLamaWeights we
                 GpuLayerCount = config.GpuLayerCount ?? 20,
                 Embeddings = true,
                 MainGpu = config.MainGpu,
-                SplitMode = config.SplitMode
+                SplitMode = config.SplitMode,
+                PoolingType = LLamaPoolingType.Mean,
             };
             _weights = weights;
             _embedder = new LLamaEmbedder(_weights, @params);
@@ -92,7 +95,7 @@ public void Dispose()
         public async Task<Embedding> GenerateEmbeddingAsync(string text, CancellationToken cancellationToken = default)
         {
             var embeddings = await _embedder.GetEmbeddings(text, cancellationToken);
-            return new Embedding(embeddings);
+            return new Embedding(embeddings.First());
         }
 
         /// <inheritdoc/>
diff --git a/LLama.SemanticKernel/TextEmbedding/LLamaSharpEmbeddingGeneration.cs b/LLama.SemanticKernel/TextEmbedding/LLamaSharpEmbeddingGeneration.cs
index 9514e1711..d50945117 100644
--- a/LLama.SemanticKernel/TextEmbedding/LLamaSharpEmbeddingGeneration.cs
+++ b/LLama.SemanticKernel/TextEmbedding/LLamaSharpEmbeddingGeneration.cs
@@ -4,7 +4,8 @@
 
 namespace LLamaSharp.SemanticKernel.TextEmbedding;
 
-public sealed class LLamaSharpEmbeddingGeneration : ITextEmbeddingGenerationService
+public sealed class LLamaSharpEmbeddingGeneration
+    : ITextEmbeddingGenerationService
 {
     private readonly LLamaEmbedder _embedder;
 
@@ -23,7 +24,7 @@ public async Task<IList<ReadOnlyMemory<float>>> GenerateEmbeddingsAsync(IList<string> data, Kernel? kernel, CancellationToken cancellationToken)
         var result = new List<ReadOnlyMemory<float>>();
 
         foreach (var item in data)
-            result.Add(await _embedder.GetEmbeddings(item, cancellationToken));
+            result.Add((await _embedder.GetEmbeddings(item, cancellationToken)).First());
 
         return result;
     }
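The unit tests below compare embeddings after in-place `EuclideanNormalization` (the extension added in LLama/Extensions/SpanNormalizationExtensions.cs further down). A hedged aside: once two vectors are L2-normalized, cosine similarity reduces to a plain dot product, which is how consumers of these embeddings would typically compare them. A minimal sketch, assuming `float[]` vectors of equal length as returned by `GetEmbeddings(...).Single()`:

```csharp
using LLama.Extensions;

// Sketch: cosine similarity between two pooled embeddings.
static float CosineSimilarity(float[] a, float[] b)
{
    // EuclideanNormalization mutates in-place and returns the same array,
    // so after these calls both vectors have unit length.
    a.EuclideanNormalization();
    b.EuclideanNormalization();

    var dot = 0f;
    for (var i = 0; i < a.Length; i++)
        dot += a[i] * b[i];
    return dot; // == cosine of the angle between the original vectors
}
```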
diff --git a/LLama.Unittest/LLamaEmbedderTests.cs b/LLama.Unittest/LLamaEmbedderTests.cs
index e9d9359f2..f48d1ef45 100644
--- a/LLama.Unittest/LLamaEmbedderTests.cs
+++ b/LLama.Unittest/LLamaEmbedderTests.cs
@@ -1,4 +1,6 @@
 using LLama.Common;
+using LLama.Extensions;
+using LLama.Native;
 using Xunit.Abstractions;
 
 namespace LLama.Unittest;
@@ -24,19 +26,19 @@ private async Task CompareEmbeddings(string modelPath)
         {
             ContextSize = 8,
             Threads = 4,
-            Embeddings = true,
             GpuLayerCount = Constants.CIGpuLayerCount,
+            PoolingType = LLamaPoolingType.Mean,
         };
         using var weights = LLamaWeights.LoadFromFile(@params);
         using var embedder = new LLamaEmbedder(weights, @params);
 
-        var cat = await embedder.GetEmbeddings("The cat is cute");
+        var cat = (await embedder.GetEmbeddings("The cat is cute")).Single().EuclideanNormalization();
         Assert.DoesNotContain(float.NaN, cat);
 
-        var kitten = await embedder.GetEmbeddings("The kitten is kawaii");
+        var kitten = (await embedder.GetEmbeddings("The kitten is cute")).Single().EuclideanNormalization();
         Assert.DoesNotContain(float.NaN, kitten);
 
-        var spoon = await embedder.GetEmbeddings("The spoon is not real");
+        var spoon = (await embedder.GetEmbeddings("The spoon is not real")).Single().EuclideanNormalization();
         Assert.DoesNotContain(float.NaN, spoon);
 
         _testOutputHelper.WriteLine($"Cat = [{string.Join(",", cat.AsMemory().Slice(0, 7).ToArray())}...]");
@@ -64,4 +66,33 @@ public async Task EmbedCompareGenerateModel()
     {
         await CompareEmbeddings(Constants.GenerativeModelPath);
     }
+
+    private async Task NonPooledEmbeddings(string modelPath)
+    {
+        var @params = new ModelParams(modelPath)
+        {
+            ContextSize = 8,
+            Threads = 4,
+            GpuLayerCount = Constants.CIGpuLayerCount,
+            PoolingType = LLamaPoolingType.None,
+        };
+        using var weights = LLamaWeights.LoadFromFile(@params);
+        using var embedder = new LLamaEmbedder(weights, @params);
+
+        var kitten = await embedder.GetEmbeddings("the kitten is kawaii");
+        foreach (var embd in kitten)
+            Assert.DoesNotContain(float.NaN, embd);
+    }
+
+    [Fact]
+    public async Task EmbeddingModelNonPooledEmbeddings()
+    {
+        await NonPooledEmbeddings(Constants.EmbeddingModelPath);
+    }
+
+    [Fact]
+    public async Task GenerativeModelNonPooledEmbeddings()
+    {
+        await NonPooledEmbeddings(Constants.GenerativeModelPath);
+    }
 }
\ No newline at end of file
diff --git a/LLama/Extensions/SpanNormalizationExtensions.cs b/LLama/Extensions/SpanNormalizationExtensions.cs
new file mode 100644
index 000000000..8ed827b64
--- /dev/null
+++ b/LLama/Extensions/SpanNormalizationExtensions.cs
@@ -0,0 +1,126 @@
+using System;
+using System.Numerics.Tensors;
+
+namespace LLama.Extensions;
+
+/// <summary>
+/// Extensions to span which apply in-place normalization
+/// </summary>
+public static class SpanNormalizationExtensions
+{
+    /// <summary>
+    /// In-place multiply every element by 32760 and divide every element in the span by the max absolute value in the span
+    /// </summary>
+    /// <param name="vector"></param>
+    /// <returns>The same array</returns>
+    public static float[] MaxAbsoluteNormalization(this float[] vector)
+    {
+        vector.AsSpan().MaxAbsoluteNormalization();
+        return vector;
+    }
+
+    /// <summary>
+    /// In-place multiply every element by 32760 and divide every element in the span by the max absolute value in the span
+    /// </summary>
+    /// <param name="vector"></param>
+    /// <returns>The same span</returns>
+    public static Span<float> MaxAbsoluteNormalization(this Span<float> vector)
+    {
+        var factor = 32760 / TensorPrimitives.MaxMagnitude(vector);
+        TensorPrimitives.Multiply(vector, factor, vector);
+        return vector;
+    }
+
+    /// <summary>
+    /// In-place divide every element in the array by the sum of absolute values in the array
+    /// </summary>
+    /// <remarks>Also known as "Manhattan normalization".</remarks>
+    /// <param name="vector"></param>
+    /// <returns>The same array</returns>
+    public static float[] TaxicabNormalization(this float[] vector)
+    {
+        vector.AsSpan().TaxicabNormalization();
+        return vector;
+    }
+
+    /// <summary>
+    /// In-place divide every element in the span by the sum of absolute values in the span
+    /// </summary>
+    /// <remarks>Also known as "Manhattan normalization".</remarks>
+    /// <param name="vector"></param>
+    /// <returns>The same span</returns>
+    public static Span<float> TaxicabNormalization(this Span<float> vector)
+    {
+        var sumAbs = TensorPrimitives.SumOfMagnitudes(vector);
+        TensorPrimitives.Divide(vector, sumAbs, vector);
+        return vector;
+    }
+
+    /// <summary>
+    /// In-place divide every element by the euclidean length of the vector
+    /// </summary>
+    /// <remarks>Also known as "L2 normalization".</remarks>
+    /// <param name="vector"></param>
+    /// <returns>The same array</returns>
+    public static float[] EuclideanNormalization(this float[] vector)
+    {
+        vector.AsSpan().EuclideanNormalization();
+        return vector;
+    }
+
+    /// <summary>
+    /// In-place divide every element by the euclidean length of the vector
+    /// </summary>
+    /// <remarks>Also known as "L2 normalization".</remarks>
+    /// <param name="vector"></param>
+    /// <returns>The same span</returns>
+    public static Span<float> EuclideanNormalization(this Span<float> vector)
+    {
+        var norm = TensorPrimitives.Norm(vector);
+        TensorPrimitives.Divide(vector, norm, vector);
+        return vector;
+    }
+
+    /// <summary>
+    /// In-place apply p-normalization. https://en.wikipedia.org/wiki/Norm_(mathematics)#p-norm
+    /// <list type="bullet">
+    /// <item>For p = 1, this is taxicab normalization</item>
+    /// <item>For p = 2, this is euclidean normalization</item>
+    /// <item>As p => infinity, this approaches infinity norm or maximum norm</item>
+    /// </list>
+    /// </summary>
+    /// <param name="vector"></param>
+    /// <param name="p"></param>
+    /// <returns>The same array</returns>
+    public static float[] PNormalization(this float[] vector, int p)
+    {
+        vector.AsSpan().PNormalization(p);
+        return vector;
+    }
+
+    /// <summary>
+    /// In-place apply p-normalization. https://en.wikipedia.org/wiki/Norm_(mathematics)#p-norm
+    /// <list type="bullet">
+    /// <item>For p = 1, this is taxicab normalization</item>
+    /// <item>For p = 2, this is euclidean normalization</item>
+    /// <item>As p => infinity, this approaches infinity norm or maximum norm</item>
+    /// </list>
+    /// </summary>
+    /// <param name="vector"></param>
+    /// <param name="p"></param>
+    /// <returns>The same span</returns>
+    public static Span<float> PNormalization(this Span<float> vector, int p)
+    {
+        if (p == 2)
+            return vector.EuclideanNormalization();
+
+        // The p-norm sums |x|^p, so take the absolute value of each element
+        var sum = 0.0;
+        for (var i = 0; i < vector.Length; i++)
+            sum += MathF.Pow(MathF.Abs(vector[i]), p);
+        var divisor = (float)Math.Pow(sum, 1.0 / p);
+
+        TensorPrimitives.Divide(vector, divisor, vector);
+
+        return vector;
+    }
+}
\ No newline at end of file
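A brief usage sketch for these extensions (the values are illustrative only), showing the in-place, return-same-instance convention that the tests above rely on for chaining:

```csharp
using LLama.Extensions;

// L2: [3, -4] has euclidean length 5, so v becomes [0.6, -0.8].
// The same array instance is returned, which enables fluent chaining.
var v = new[] { 3f, -4f };
var unit = v.EuclideanNormalization();

// p = 1 divides by the sum of absolute values (|3| + |-4| = 7),
// giving approximately [0.4286, -0.5714].
var taxicab = new[] { 3f, -4f }.TaxicabNormalization();
```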
diff --git a/LLama/LLamaContext.cs b/LLama/LLamaContext.cs
index ca38d49e4..9dd2a6394 100644
--- a/LLama/LLamaContext.cs
+++ b/LLama/LLamaContext.cs
@@ -379,6 +379,28 @@ public bool ShouldAddBosToken()
     }
 
     #region eval overloads
+    /// <summary>
+    /// </summary>
+    /// <param name="batch"></param>
+    public EncodeResult Encode(LLamaBatch batch)
+    {
+        if (batch.TokenCount == 0)
+            return EncodeResult.Ok;
+        if (batch.TokenCount > BatchSize)
+            throw new ArgumentException("Input contains more tokens than configured batch size", nameof(batch));
+
+        return (EncodeResult)NativeHandle.Encode(batch);
+    }
+
+    /// <summary>
+    /// </summary>
+    /// <param name="batch"></param>
+    /// <param name="cancellationToken"></param>
+    public Task<EncodeResult> EncodeAsync(LLamaBatch batch, CancellationToken cancellationToken = default)
+    {
+        return Task.Run(() => Encode(batch), cancellationToken);
+    }
+
     /// <summary>
     /// </summary>
     /// <param name="batch"></param>
diff --git a/LLama/LLamaEmbedder.cs b/LLama/LLamaEmbedder.cs
index d050707e8..ed6240359 100644
--- a/LLama/LLamaEmbedder.cs
+++ b/LLama/LLamaEmbedder.cs
@@ -1,138 +1,118 @@
-using LLama.Native;
 using System;
-using LLama.Exceptions;
-using LLama.Abstractions;
-using Microsoft.Extensions.Logging;
+using System.Collections.Generic;
 using System.Threading;
 using System.Threading.Tasks;
+using LLama.Abstractions;
+using LLama.Exceptions;
+using LLama.Native;
+using Microsoft.Extensions.Logging;
 
-namespace LLama
+namespace LLama;
+
+/// <summary>
+/// Generate high dimensional embedding vectors from text
+/// </summary>
+public sealed class LLamaEmbedder
+    : IDisposable
 {
     /// <summary>
-    /// The embedder for LLama, which supports getting embeddings from text.
+    /// Dimension of embedding vectors
     /// </summary>
-    public sealed class LLamaEmbedder
-        : IDisposable
-    {
-        /// <summary>
-        /// Dimension of embedding vectors
-        /// </summary>
-        public int EmbeddingSize => Context.EmbeddingSize;
-
-        /// <summary>
-        /// LLama Context
-        /// </summary>
-        public LLamaContext Context { get; }
-
-        /// <summary>
-        /// Create a new embedder, using the given LLamaWeights
-        /// </summary>
-        /// <param name="weights"></param>
-        /// <param name="params"></param>
-        /// <param name="logger"></param>
-        public LLamaEmbedder(LLamaWeights weights, IContextParams @params, ILogger? logger = null)
-        {
-            if (!@params.Embeddings)
-                throw new ArgumentException("Embeddings must be true", nameof(@params));
-
-            Context = weights.CreateContext(@params, logger);
-        }
-
-        /// <summary>
-        /// Get the embeddings of the text.
-        /// </summary>
-        /// <param name="text"></param>
-        /// <param name="cancellationToken"></param>
-        /// <returns></returns>
-        public Task<float[]> GetEmbeddings(string text, CancellationToken cancellationToken = default)
-        {
-            return GetEmbeddings(text, true, cancellationToken);
-        }
+    public int EmbeddingSize => Context.EmbeddingSize;
 
-        /// <summary>
-        /// Get the embeddings of the text.
-        /// </summary>
-        /// <param name="text"></param>
-        /// <param name="addBos">Add bos to the text.</param>
-        /// <param name="cancellationToken"></param>
-        /// <returns></returns>
-        public async Task<float[]> GetEmbeddings(string text, bool addBos, CancellationToken cancellationToken = default)
-        {
-            var tokens = Context.Tokenize(text, addBos);
-            if (tokens.Length > Context.ContextSize)
-                throw new ArgumentException($"Embedding prompt is longer than the context window ({tokens.Length} > {Context.ContextSize})", nameof(text));
-
-            // Evaluate prompt in batch-size chunks
-            var n_past = 0;
-            var batch = new LLamaBatch();
-            var batchSize = (int)Context.BatchSize;
-            for (var i = 0; i < tokens.Length; i += batchSize)
-            {
-                var n_eval = tokens.Length - i;
-                if (n_eval > batchSize)
-                    n_eval = batchSize;
+    /// <summary>
+    /// LLama Context
+    /// </summary>
+    public LLamaContext Context { get; }
 
-                batch.Clear();
-                batch.AddRange(tokens.AsSpan(i, n_eval), n_past, LLamaSeqId.Zero, true);
-                n_past += n_eval;
+    /// <summary>
+    /// Create a new embedder, using the given LLamaWeights
+    /// </summary>
+    /// <param name="weights"></param>
+    /// <param name="params"></param>
+    /// <param name="logger"></param>
+    public LLamaEmbedder(LLamaWeights weights, IContextParams @params, ILogger? logger = null)
+    {
+        if (@params.UBatchSize != @params.BatchSize)
+            throw new ArgumentException("For non-causal models, batch size must be equal to ubatch size", nameof(@params));
+        if (weights.NativeHandle is { HasEncoder: true, HasDecoder: true })
+            throw new NotSupportedException("Computing embeddings in encoder-decoder models is not supported");
 
-                var returnCode = await Context.DecodeAsync(batch, cancellationToken);
-                if (returnCode != 0)
-                    throw new LLamaDecodeError(returnCode);
-            }
+        Context = weights.CreateContext(@params, logger);
+        NativeApi.llama_set_embeddings(Context.NativeHandle, true);
+    }
 
-            var embeddings = GetEmbeddingsArray();
+    /// <inheritdoc />
+    public void Dispose()
+    {
+        Context.Dispose();
+    }
 
-            // Remove everything we just evaluated from the context cache
-            Context.NativeHandle.KvCacheClear();
+    /// <summary>
+    /// Get high dimensional embedding vectors for the given text. Depending on the pooling type used when constructing
+    /// this <see cref="LLamaEmbedder"/>, this may return an embedding vector per token, or one single embedding vector for the entire string.
+    /// </summary>
+    /// <remarks>Embedding vectors are not normalized, consider using one of the extensions in <see cref="SpanNormalizationExtensions"/>.</remarks>
+    /// <param name="input"></param>
+    /// <param name="cancellationToken"></param>
+    /// <returns></returns>
+    /// <exception cref="RuntimeError"></exception>
+    /// <exception cref="NotSupportedException"></exception>
+    public async Task<IReadOnlyList<float[]>> GetEmbeddings(string input, CancellationToken cancellationToken = default)
+    {
+        // Add all of the tokens to the batch
+        var tokens = Context.Tokenize(input);
+        var batch = new LLamaBatch();
+        for (var i = 0; i < tokens.Length; i++)
+            batch.Add(tokens[i], i, LLamaSeqId.Zero, true);
 
-            // Normalize the embeddings vector
-            // https://github.com/ggerganov/llama.cpp/blob/2891c8aa9af17f4ff636ff3868bc34ff72b56e25/examples/embedding/embedding.cpp#L92
-            Normalize(embeddings);
+        // clear previous kv_cache values
+        Context.NativeHandle.KvCacheClear();
 
-            return embeddings;
-        }
+        // Check if we should cancel the work, just before doing anything expensive (encode/decode)
+        cancellationToken.ThrowIfCancellationRequested();
 
-        private float[] GetEmbeddingsArray()
+        // Run model
+        switch (Context.NativeHandle.ModelHandle.HasEncoder, Context.NativeHandle.ModelHandle.HasDecoder)
         {
-            unsafe
+            case (true, false):
             {
-                var embeddings = NativeApi.llama_get_embeddings(Context.NativeHandle);
-
-                if (embeddings == null)
-                    embeddings = NativeApi.llama_get_embeddings_seq(Context.NativeHandle, LLamaSeqId.Zero);
-
-                if (embeddings == null)
-                    return [ ];
+                var result = await Context.EncodeAsync(batch, cancellationToken);
+                if (result != EncodeResult.Ok)
+                    throw new RuntimeError($"Failed to encode: {result}");
+                break;
+            }
 
-                return new Span<float>(embeddings, Context.EmbeddingSize).ToArray();
+            case (false, true):
+            {
+                var result = await Context.DecodeAsync(batch, cancellationToken);
+                if (result != DecodeResult.Ok)
+                    throw new RuntimeError($"Failed to decode: {result}");
+                break;
             }
+
+            default:
+                throw new NotSupportedException("Unsupported model type");
         }
 
-        private static void Normalize(Span<float> embeddings)
+        // Extract results
+        var poolingType = Context.NativeHandle.PoolingType;
+        var resultsCount = poolingType == LLamaPoolingType.None ? tokens.Length : 1;
+        var results = new List<float[]>(resultsCount);
+
+        if (poolingType == LLamaPoolingType.None)
         {
-            // Calculate length
-            var lengthSqr = 0.0;
-            foreach (var value in embeddings)
-                lengthSqr += value * value;
-            var length = (float)Math.Sqrt(lengthSqr);
-
-            // Do not divide by length if it is zero
-            if (length <= float.Epsilon)
-                return;
-
-            // Normalize
-            for (var i = 0; i < embeddings.Length; i++)
-                embeddings[i] /= length;
+            var positions = batch.GetLogitPositions();
+            foreach (var (_, pos) in positions)
+                results.Add(Context.NativeHandle.GetEmbeddingsIth(pos).ToArray());
         }
-
-        /// <inheritdoc />
-        public void Dispose()
+        else
         {
-            Context.Dispose();
+            results.Add(Context.NativeHandle.GetEmbeddingsSeq(LLamaSeqId.Zero).ToArray());
         }
+        Context.NativeHandle.KvCacheClear();
+
+        return results;
     }
-}
+}
\ No newline at end of file
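The switch on (HasEncoder, HasDecoder) above routes encoder-only models through llama_encode and decoder-only models through llama_decode, and the pooling type decides the shape of the result: one `float[]` per token for `LLamaPoolingType.None`, otherwise exactly one. A hedged sketch of a caller-side helper built on that contract (the helper name is illustrative, not part of this PR):

```csharp
using LLama;
using LLama.Native;

// Hypothetical helper: always returns exactly one vector, mean-pooling
// per-token embeddings on the client side when pooling was disabled.
static async Task<float[]> GetSingleEmbedding(LLamaEmbedder embedder, string text)
{
    var results = await embedder.GetEmbeddings(text);
    if (results.Count == 1)
        return results[0]; // pooled (Mean/CLS/Last): one vector for the whole string

    // PoolingType.None: one vector per token, so average them ourselves.
    var mean = new float[embedder.EmbeddingSize];
    foreach (var token in results)
        for (var i = 0; i < mean.Length; i++)
            mean[i] += token[i] / results.Count;
    return mean;
}
```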
diff --git a/LLama/LLamaSharp.csproj b/LLama/LLamaSharp.csproj
index addda27f2..7b6e11c50 100644
--- a/LLama/LLamaSharp.csproj
+++ b/LLama/LLamaSharp.csproj
@@ -50,6 +50,7 @@
+    <PackageReference Include="System.Numerics.Tensors" Version="8.0.0" />
@@ -83,11 +84,6 @@
         OverwriteReadOnlyFiles="true"
         Include="*.dll;*.so;*.dylib;*.metal;" />
-
diff --git a/LLama/Native/EncodeResult.cs b/LLama/Native/EncodeResult.cs
new file mode 100644
index 000000000..31bafc098
--- /dev/null
+++ b/LLama/Native/EncodeResult.cs
@@ -0,0 +1,17 @@
+namespace LLama.Native;
+
+/// <summary>
+/// Return codes from llama_encode
+/// </summary>
+public enum EncodeResult
+{
+    /// <summary>
+    /// An unspecified error
+    /// </summary>
+    Error = -1,
+
+    /// <summary>
+    /// Ok.
+    /// </summary>
+    Ok = 0
+}
\ No newline at end of file
diff --git a/LLama/Native/LLamaBatch.cs b/LLama/Native/LLamaBatch.cs
index 91a2fafbc..c66bd9277 100644
--- a/LLama/Native/LLamaBatch.cs
+++ b/LLama/Native/LLamaBatch.cs
@@ -281,11 +281,8 @@ public void Clear()
     /// <summary>
     /// Get the positions where logits can be sampled from
    /// </summary>
     /// <returns></returns>
-    internal Span<(LLamaSeqId, int)> GetLogitPositions(Span<(LLamaSeqId, int)> dest)
+    internal IReadOnlyList<(LLamaSeqId, int)> GetLogitPositions()
     {
-        for (var i = 0; i < _logitPositions.Count; i++)
-            dest[i] = _logitPositions[i];
-
-        return dest.Slice(0, _logitPositions.Count);
+        return _logitPositions;
     }
 }
\ No newline at end of file
diff --git a/LLama/Native/LLamaPoolingType.cs b/LLama/Native/LLamaPoolingType.cs
index 31c615d7e..ab0b75457 100644
--- a/LLama/Native/LLamaPoolingType.cs
+++ b/LLama/Native/LLamaPoolingType.cs
@@ -1,3 +1,5 @@
+using LLama.Abstractions;
+
 namespace LLama.Native;
 
 /// <summary>
@@ -6,9 +8,24 @@ namespace LLama.Native;
 /// <remarks>llama_pooling_type</remarks>
 public enum LLamaPoolingType
 {
+    /// <summary>
+    /// No specific pooling type. Use the model default if this is specified in <see cref="IContextParams.PoolingType"/>
+    /// </summary>
     Unspecified = -1,
+
+    /// <summary>
+    /// Do not pool embeddings (per-token embeddings)
+    /// </summary>
     None = 0,
+
+    /// <summary>
+    /// Take the mean of every token embedding
+    /// </summary>
     Mean = 1,
+
+    /// <summary>
+    /// Return the embedding for the special "CLS" token
+    /// </summary>
     CLS = 2,
 
     Last = 3,
diff --git a/LLama/Native/NativeApi.cs b/LLama/Native/NativeApi.cs
index 8d967e670..46ec79813 100644
--- a/LLama/Native/NativeApi.cs
+++ b/LLama/Native/NativeApi.cs
@@ -131,30 +131,6 @@ public static void llama_empty_call()
     [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
     public static extern uint llama_n_seq_max(SafeLLamaContextHandle ctx);
 
-    /// <summary>
-    /// Get the pooling type for this context
-    /// </summary>
-    /// <param name="ctx"></param>
-    /// <returns></returns>
-    [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
-    public static extern LLamaPoolingType llama_pooling_type(SafeLLamaContextHandle ctx);
-
-    /// <summary>
-    /// Get the embeddings for the a specific sequence.
-    /// Equivalent to: llama_get_embeddings(ctx) + ctx->output_ids[i]*n_embd
-    /// </summary>
-    /// <returns>A pointer to the first float in an embedding, length = ctx.EmbeddingSize</returns>
-    [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
-    public static extern unsafe float* llama_get_embeddings_seq(SafeLLamaContextHandle ctx, LLamaSeqId id);
-
-    /// <summary>
-    /// Get the embeddings for the ith sequence.
-    /// Equivalent to: llama_get_embeddings(ctx) + ctx->output_ids[i]*n_embd
-    /// </summary>
-    /// <returns>A pointer to the first float in an embedding, length = ctx.EmbeddingSize</returns>
-    [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
-    public static extern unsafe float* llama_get_embeddings_ith(SafeLLamaContextHandle ctx, int i);
-
     /// <summary>
     /// Get all output token embeddings.
     /// When pooling_type == LLAMA_POOLING_TYPE_NONE or when using a generative model, the embeddings for which
diff --git a/LLama/Native/SafeLLamaContextHandle.cs b/LLama/Native/SafeLLamaContextHandle.cs
index b5932aa04..dee74f590 100644
--- a/LLama/Native/SafeLLamaContextHandle.cs
+++ b/LLama/Native/SafeLLamaContextHandle.cs
@@ -1,5 +1,6 @@
 using System;
 using System.Collections.Generic;
+using System.Diagnostics;
 using System.Text;
 using LLama.Exceptions;
 
@@ -62,6 +63,11 @@ public uint BatchThreads
         set => llama_set_n_threads(this, GenerationThreads, value);
     }
 
+    /// <summary>
+    /// Get the pooling type for this context
+    /// </summary>
+    public LLamaPoolingType PoolingType => llama_pooling_type(this);
+
     /// <summary>
     /// Get the model which this context is using
     /// </summary>
@@ -169,7 +175,7 @@ static SafeLLamaContextHandle()
     private static extern int llama_decode(SafeLLamaContextHandle ctx, LLamaNativeBatch batch);
 
     /// <summary>
-    /// Processes a batch of tokens with the ecoder part of the encoder-decoder model. Stores the encoder output
+    /// Processes a batch of tokens with the encoder part of the encoder-decoder model. Stores the encoder output
     /// internally for later use by the decoder cross-attention layers.
     /// </summary>
     /// <param name="ctx"></param>
@@ -365,6 +371,30 @@ static SafeLLamaContextHandle()
     [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
     private static extern int llama_lora_adapter_clear(SafeLLamaContextHandle context);
 
+
+    /// <summary>
+    /// Get the pooling type for this context
+    /// </summary>
+    /// <param name="ctx"></param>
+    /// <returns></returns>
+    [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
+    private static extern LLamaPoolingType llama_pooling_type(SafeLLamaContextHandle ctx);
+
+    /// <summary>
+    /// Get the embeddings for a specific sequence.
+    /// Equivalent to: llama_get_embeddings(ctx) + ctx->output_ids[i]*n_embd
+    /// </summary>
+    /// <returns>A pointer to the first float in an embedding, length = ctx.EmbeddingSize</returns>
+    [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
+    private static extern unsafe float* llama_get_embeddings_seq(SafeLLamaContextHandle ctx, LLamaSeqId id);
+
+    /// <summary>
+    /// Get the embeddings for the ith token.
+    /// Equivalent to: llama_get_embeddings(ctx) + ctx->output_ids[i]*n_embd
+    /// </summary>
+    /// <returns>A pointer to the first float in an embedding, length = ctx.EmbeddingSize</returns>
+    [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
+    private static extern unsafe float* llama_get_embeddings_ith(SafeLLamaContextHandle ctx, int i);
     #endregion
 
     #region LoRA
@@ -410,6 +440,7 @@ public void ClearLoraAdapters()
     }
     #endregion
 
+    #region GetLogits
     /// <summary>
     /// Token logits obtained from the last call to llama_decode
     /// The logits for the last token are stored in the last row
@@ -444,6 +475,43 @@ public Span<float> GetLogitsIth(int i)
             return new Span<float>(logits, model.VocabCount);
         }
     }
+    #endregion
+
+    #region GetEmbeddings()
+    /// <summary>
+    /// Get the embeddings for the ith token.
+    /// Equivalent to: llama_get_embeddings(ctx) + ctx->output_ids[i]*n_embd
+    /// </summary>
+    /// <returns>A pointer to the first float in an embedding, length = ctx.EmbeddingSize</returns>
+    public Span<float> GetEmbeddingsIth(LLamaPos pos)
+    {
+        var model = ThrowIfDisposed();
+
+        unsafe
+        {
+            var embd = llama_get_embeddings_ith(this, pos.Value);
+            Debug.Assert(embd != null);
+            return new Span<float>(embd, model.EmbeddingSize);
+        }
+    }
+
+    /// <summary>
+    /// Get the embeddings for a specific sequence.
+    /// Equivalent to: llama_get_embeddings(ctx) + ctx->output_ids[i]*n_embd
+    /// </summary>
+    /// <returns>A pointer to the first float in an embedding, length = ctx.EmbeddingSize</returns>
+    public Span<float> GetEmbeddingsSeq(LLamaSeqId seq)
+    {
+        var model = ThrowIfDisposed();
+
+        unsafe
+        {
+            var embd = llama_get_embeddings_seq(this, seq);
+            Debug.Assert(embd != null);
+            return new Span<float>(embd, model.EmbeddingSize);
+        }
+    }
+    #endregion
 
     #region tokens
     /// <summary>
@@ -495,6 +563,22 @@ public void Synchronize()
         llama_synchronize(this);
     }
 
+    /// <summary>
+    /// Processes a batch of tokens with the encoder part of the encoder-decoder model. Stores the encoder output
+    /// internally for later use by the decoder cross-attention layers.
+    /// </summary>
+    /// <param name="batch"></param>
+    /// <returns>0 = success <br/> &lt; 0 = error</returns>
+    public DecodeResult Encode(LLamaBatch batch)
+    {
+        if (batch.TokenCount == 0)
+            return DecodeResult.Ok;
+
+        lock (GlobalInferenceLock)
+        using (batch.ToNativeBatch(out var nb))
+            return (DecodeResult)llama_encode(this, nb);
+    }
+
     /// <summary>
     /// </summary>
     /// <param name="batch"></param>
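Finally, a hedged sketch of consuming the new handle-level helpers directly; `context` here stands for any LLamaContext whose batch has just been decoded, and the higher-level LLamaEmbedder shown earlier is the supported path:

```csharp
using LLama.Native;

// Assumes `context` was created with a pooling type other than None and a
// batch has just been decoded on sequence zero, as LLamaEmbedder does internally.
var handle = context.NativeHandle;
if (handle.PoolingType != LLamaPoolingType.None)
{
    // GetEmbeddingsSeq returns a Span<float> over native memory owned by the
    // context; copy it (ToArray) before the next Decode/Encode overwrites it.
    float[] pooled = handle.GetEmbeddingsSeq(LLamaSeqId.Zero).ToArray();
    Console.WriteLine($"Pooled embedding with {pooled.Length} components");
}
```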