Skip to content

Commit

Permalink
Totally rewrote the LLamaEmbedder based on https://github.com/ggerg…
Browse files Browse the repository at this point in the history
…anov/llama.cpp/tree/master/examples/embedding. New embedder properly handles pooling, either returning one embedding for the whole sequence or one per token.

 - Added `Encode` methods to `LLamaContext`
 - Moved some native methods from `NativeApi` to `SafeLLamaContextHandle` and wrapped them properly
 - Added `HasDecoder` property to `SafeLlamaModelHandle`. This property doesn't exist in the current version of llama.cpp, and will need to be hooked up in the next binary update
 - Added some normalization methods as extensions on span/array. This required adding a dependency on `System.Numerics.Tensors`
  • Loading branch information
martindevans committed Aug 23, 2024
1 parent 3945705 commit 9c6e664
Show file tree
Hide file tree
Showing 12 changed files with 408 additions and 155 deletions.
9 changes: 6 additions & 3 deletions LLama.KernelMemory/LLamaSharpTextEmbeddingGenerator.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using LLama;
using LLama.Common;
using LLama.Native;
using Microsoft.KernelMemory;
using Microsoft.KernelMemory.AI;

Expand Down Expand Up @@ -35,7 +36,8 @@ public LLamaSharpTextEmbeddingGenerator(LLamaSharpConfig config)
GpuLayerCount = config.GpuLayerCount ?? 20,
Embeddings = true,
MainGpu = config.MainGpu,
SplitMode = config.SplitMode
SplitMode = config.SplitMode,
PoolingType = LLamaPoolingType.Mean,
};
_weights = LLamaWeights.LoadFromFile(@params);
_embedder = new LLamaEmbedder(_weights, @params);
Expand All @@ -59,7 +61,8 @@ public LLamaSharpTextEmbeddingGenerator(LLamaSharpConfig config, LLamaWeights we
GpuLayerCount = config.GpuLayerCount ?? 20,
Embeddings = true,
MainGpu = config.MainGpu,
SplitMode = config.SplitMode
SplitMode = config.SplitMode,
PoolingType = LLamaPoolingType.Mean,
};
_weights = weights;
_embedder = new LLamaEmbedder(_weights, @params);
Expand Down Expand Up @@ -92,7 +95,7 @@ public void Dispose()
/// <inheritdoc/>
public async Task<Embedding> GenerateEmbeddingAsync(string text, CancellationToken cancellationToken = default)
{
    // The embedder is configured with mean pooling (see the constructors), so
    // GetEmbeddings returns a single pooled vector for the whole input text.
    var results = await _embedder.GetEmbeddings(text, cancellationToken);
    return new Embedding(results.First());
}

/// <inheritdoc/>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@

namespace LLamaSharp.SemanticKernel.TextEmbedding;

public sealed class LLamaSharpEmbeddingGeneration : ITextEmbeddingGenerationService
public sealed class LLamaSharpEmbeddingGeneration
: ITextEmbeddingGenerationService
{
private readonly LLamaEmbedder _embedder;

Expand All @@ -23,7 +24,7 @@ public async Task<IList<ReadOnlyMemory<float>>> GenerateEmbeddingsAsync(IList<st
var result = new List<ReadOnlyMemory<float>>();

foreach (var item in data)
result.Add(await _embedder.GetEmbeddings(item, cancellationToken));
result.Add((await _embedder.GetEmbeddings(item, cancellationToken)).First());

return result;
}
Expand Down
39 changes: 36 additions & 3 deletions LLama.Unittest/LLamaEmbedderTests.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
using LLama.Common;
using LLama.Extensions;
using LLama.Native;
using Xunit.Abstractions;

namespace LLama.Unittest;
Expand Down Expand Up @@ -26,17 +28,18 @@ private async Task CompareEmbeddings(string modelPath)
Threads = 4,
Embeddings = true,
GpuLayerCount = Constants.CIGpuLayerCount,
PoolingType = LLamaPoolingType.Mean,
};
using var weights = LLamaWeights.LoadFromFile(@params);
using var embedder = new LLamaEmbedder(weights, @params);

var cat = await embedder.GetEmbeddings("The cat is cute");
var cat = (await embedder.GetEmbeddings("The cat is cute")).Single().EuclideanNormalization();
Assert.DoesNotContain(float.NaN, cat);

var kitten = await embedder.GetEmbeddings("The kitten is kawaii");
var kitten = (await embedder.GetEmbeddings("The kitten is cute")).Single().EuclideanNormalization();
Assert.DoesNotContain(float.NaN, kitten);

var spoon = await embedder.GetEmbeddings("The spoon is not real");
var spoon = (await embedder.GetEmbeddings("The spoon is not real")).Single().EuclideanNormalization();
Assert.DoesNotContain(float.NaN, spoon);

_testOutputHelper.WriteLine($"Cat = [{string.Join(",", cat.AsMemory().Slice(0, 7).ToArray())}...]");
Expand Down Expand Up @@ -64,4 +67,34 @@ public async Task EmbedCompareGenerateModel()
{
await CompareEmbeddings(Constants.GenerativeModelPath);
}

/// <summary>
/// Load the given model with pooling disabled, embed a short prompt and
/// check that none of the resulting per-token embeddings contain NaN.
/// </summary>
/// <param name="modelPath">Path of the GGUF model file to load</param>
private async Task NonPooledEmbeddings(string modelPath)
{
    var modelParams = new ModelParams(modelPath)
    {
        ContextSize = 8,
        Threads = 4,
        Embeddings = true,
        GpuLayerCount = Constants.CIGpuLayerCount,
        PoolingType = LLamaPoolingType.None,
    };

    using var weights = LLamaWeights.LoadFromFile(modelParams);
    using var embedder = new LLamaEmbedder(weights, modelParams);

    // With pooling disabled, GetEmbeddings returns one vector per input token.
    var perTokenEmbeddings = await embedder.GetEmbeddings("the kitten is kawaii");
    foreach (var embedding in perTokenEmbeddings)
        Assert.DoesNotContain(float.NaN, embedding);
}

/// <summary>
/// Check the dedicated embedding model produces NaN-free per-token embeddings.
/// </summary>
[Fact]
public Task EmbeddingModelNonPooledEmbeddings() => NonPooledEmbeddings(Constants.EmbeddingModelPath);

/// <summary>
/// Check a generative (non-embedding) model produces NaN-free per-token embeddings.
/// </summary>
[Fact]
public Task GenerativeModelNonPooledEmbeddings() => NonPooledEmbeddings(Constants.GenerativeModelPath);
}
126 changes: 126 additions & 0 deletions LLama/Extensions/SpanNormalizationExtensions.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
using System;
using System.Numerics.Tensors;

namespace LLama.Extensions;

/// <summary>
/// Extensions to span which apply <b>in-place</b> normalization
/// </summary>
public static class SpanNormalizationExtensions
{
    /// <summary>
    /// <b>In-place</b> rescale every element so that the maximum absolute value in the array is 32760
    /// (i.e. multiply every element by 32760 / max-absolute-value)
    /// </summary>
    /// <param name="vector"></param>
    /// <returns>The same array</returns>
    public static float[] MaxAbsoluteNormalization(this float[] vector)
    {
        vector.AsSpan().MaxAbsoluteNormalization();
        return vector;
    }

    /// <summary>
    /// <b>In-place</b> rescale every element so that the maximum absolute value in the span is 32760
    /// (i.e. multiply every element by 32760 / max-absolute-value)
    /// </summary>
    /// <param name="vector"></param>
    /// <returns>The same span</returns>
    public static Span<float> MaxAbsoluteNormalization(this Span<float> vector)
    {
        var factor = 32760 / TensorPrimitives.MaxMagnitude(vector);
        TensorPrimitives.Multiply(vector, factor, vector);
        return vector;
    }

    /// <summary>
    /// <b>In-place</b> divide every element in the array by the sum of absolute values in the array
    /// </summary>
    /// <remarks>Also known as "Manhattan normalization".</remarks>
    /// <param name="vector"></param>
    /// <returns>The same array</returns>
    public static float[] TaxicabNormalization(this float[] vector)
    {
        vector.AsSpan().TaxicabNormalization();
        return vector;
    }

    /// <summary>
    /// <b>In-place</b> divide every element in the span by the sum of absolute values in the span
    /// </summary>
    /// <remarks>Also known as "Manhattan normalization".</remarks>
    /// <param name="vector"></param>
    /// <returns>The same span</returns>
    public static Span<float> TaxicabNormalization(this Span<float> vector)
    {
        var sumAbs = TensorPrimitives.SumOfMagnitudes(vector);
        TensorPrimitives.Divide(vector, sumAbs, vector);
        return vector;
    }

    /// <summary>
    /// <b>In-place</b> divide every element by the euclidean length of the vector
    /// </summary>
    /// <remarks>Also known as "L2 normalization".</remarks>
    /// <param name="vector"></param>
    /// <returns>The same array</returns>
    public static float[] EuclideanNormalization(this float[] vector)
    {
        vector.AsSpan().EuclideanNormalization();
        return vector;
    }

    /// <summary>
    /// <b>In-place</b> divide every element by the euclidean length of the vector
    /// </summary>
    /// <remarks>Also known as "L2 normalization".</remarks>
    /// <param name="vector"></param>
    /// <returns>The same span</returns>
    public static Span<float> EuclideanNormalization(this Span<float> vector)
    {
        var norm = TensorPrimitives.Norm(vector);
        TensorPrimitives.Divide(vector, norm, vector);
        return vector;
    }

    /// <summary>
    /// <b>In-place</b> apply p-normalization. https://en.wikipedia.org/wiki/Norm_(mathematics)#p-norm
    /// <list type="bullet">
    /// <item>For p = 1, this is taxicab normalization</item>
    /// <item>For p = 2, this is euclidean normalization</item>
    /// <item>As p => infinity, this approaches infinity norm or maximum norm</item>
    /// </list>
    /// </summary>
    /// <param name="vector"></param>
    /// <param name="p"></param>
    /// <returns>The same array</returns>
    /// <exception cref="ArgumentOutOfRangeException">Thrown if p is less than 1 (the p-norm is undefined for p &lt; 1)</exception>
    public static float[] PNormalization(this float[] vector, int p)
    {
        vector.AsSpan().PNormalization(p);
        return vector;
    }

    /// <summary>
    /// <b>In-place</b> apply p-normalization. https://en.wikipedia.org/wiki/Norm_(mathematics)#p-norm
    /// <list type="bullet">
    /// <item>For p = 1, this is taxicab normalization</item>
    /// <item>For p = 2, this is euclidean normalization</item>
    /// <item>As p => infinity, this approaches infinity norm or maximum norm</item>
    /// </list>
    /// </summary>
    /// <param name="vector"></param>
    /// <param name="p"></param>
    /// <returns>The same span</returns>
    /// <exception cref="ArgumentOutOfRangeException">Thrown if p is less than 1 (the p-norm is undefined for p &lt; 1)</exception>
    public static Span<float> PNormalization(this Span<float> vector, int p)
    {
        // p < 1 is not a norm (p = 0 would produce an infinite divisor).
        if (p < 1)
            throw new ArgumentOutOfRangeException(nameof(p), "p-normalization requires p >= 1");

        // p = 2 has a dedicated vectorized path.
        if (p == 2)
            return vector.EuclideanNormalization();

        // The p-norm is (sum |x_i|^p)^(1/p). The absolute value is essential:
        // without it, odd values of p over negative elements can make the sum
        // negative and Math.Pow(negative, 1/p) returns NaN.
        var sum = 0.0;
        for (var i = 0; i < vector.Length; i++)
            sum += MathF.Pow(MathF.Abs(vector[i]), p);
        var divisor = (float)Math.Pow(sum, 1.0 / p);

        TensorPrimitives.Divide(vector, divisor, vector);

        return vector;
    }
}
22 changes: 22 additions & 0 deletions LLama/LLamaContext.cs
Original file line number Diff line number Diff line change
Expand Up @@ -379,6 +379,28 @@ public bool ShouldAddBosToken()
}

#region eval overloads
/// <summary>
/// Process a batch of tokens with the model encoder (presumably wraps llama.cpp's
/// encode call via the native handle — confirm against SafeLLamaContextHandle).
/// </summary>
/// <param name="batch">The batch of tokens to encode</param>
/// <returns>The result of the encode operation</returns>
/// <exception cref="ArgumentException">Thrown if the batch holds more tokens than the configured batch size</exception>
public EncodeResult Encode(LLamaBatch batch)
{
    var tokenCount = batch.TokenCount;

    // An empty batch is a no-op.
    if (tokenCount == 0)
        return 0;

    // Reject batches larger than this context was configured to handle.
    if (tokenCount > BatchSize)
        throw new ArgumentException("Input contains more tokens than configured batch size", nameof(batch));

    return (EncodeResult)NativeHandle.Encode(batch);
}

/// <summary>
/// Process a batch of tokens with the model encoder, offloading the work to the thread pool.
/// </summary>
/// <param name="batch">The batch of tokens to encode</param>
/// <param name="cancellationToken">Cancels the task if it fires before the work has started</param>
/// <returns>A task which completes with the result of the encode operation</returns>
public Task<EncodeResult> EncodeAsync(LLamaBatch batch, CancellationToken cancellationToken = default)
    => Task.Run(() => Encode(batch), cancellationToken);

/// <summary>
/// </summary>
/// <param name="batch"></param>
Expand Down
Loading

0 comments on commit 9c6e664

Please sign in to comment.