Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

LLamaEmbedder 2.0 #902

Merged
merged 2 commits into from
Aug 31, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions LLama.Examples/ExampleRunner.cs
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@ public class ExampleRunner
{ "Executor: Stateless mode chat", StatelessModeExecute.Run },
{ "Save and Load: chat session", SaveAndLoadSession.Run },
{ "Save and Load: state of model and executor", LoadAndSaveState.Run },
{ "LLama Model: Get embeddings", () => Task.Run(GetEmbeddings.Run) },
{ "LLama Model: Quantize", () => Task.Run(QuantizeModel.Run) },
{ "LLama Model: Get embeddings", GetEmbeddings.Run },
{ "LLama Model: Quantize", QuantizeModel.Run },
{ "Grammar: Constrain response to json format", GrammarJsonResponse.Run },
{ "Kernel Memory: Document Q&A", KernelMemory.Run },
{ "Kernel Memory: Save and Load", KernelMemorySaveAndLoad.Run },
Expand Down
25 changes: 18 additions & 7 deletions LLama.Examples/Examples/GetEmbeddings.cs
Original file line number Diff line number Diff line change
@@ -1,28 +1,34 @@
using LLama.Common;
using LLama.Native;

namespace LLama.Examples.Examples
{
public class GetEmbeddings
{
public static void Run()
public static async Task Run()
{
string modelPath = UserSettings.GetModelPath();

Console.ForegroundColor = ConsoleColor.DarkGray;
var @params = new ModelParams(modelPath) { Embeddings = true };
var @params = new ModelParams(modelPath)
{
// Embedding models can return one embedding per token, or all of them can be combined ("pooled") into
// one single embedding. Setting PoolingType to "Mean" will combine all of the embeddings using mean average.
PoolingType = LLamaPoolingType.Mean,
};
using var weights = LLamaWeights.LoadFromFile(@params);
var embedder = new LLamaEmbedder(weights, @params);

Console.ForegroundColor = ConsoleColor.Yellow;
Console.WriteLine(
"""
This example displays embeddings from a text prompt.
Embeddings are numerical codes that represent information like words, images, or concepts.
These codes capture important relationships between those objects,
Embeddings are vectors that represent information like words, images, or concepts.
These vector capture important relationships between those objects,
like how similar words are in meaning or how close images are visually.
This allows machine learning models to efficiently understand and process complex data.
Embeddings of a text in LLM is sometimes useful, for example, to train other MLP models.
"""); // NOTE: this description was AI generated
""");

while (true)
{
Expand All @@ -32,8 +38,13 @@ This allows machine learning models to efficiently understand and process comple
var text = Console.ReadLine();
Console.ForegroundColor = ConsoleColor.White;

float[] embeddings = embedder.GetEmbeddings(text).Result;
Console.WriteLine($"Embeddings contain {embeddings.Length:N0} floating point values:");
// Get embeddings for the text
var embeddings = await embedder.GetEmbeddings(text);

// This should have returned one single embedding vector, because PoolingType was set to Mean above.
var embedding = embeddings.Single();

Console.WriteLine($"Embeddings contain {embedding.Length:N0} floating point values:");
Console.ForegroundColor = ConsoleColor.DarkGray;
Console.WriteLine(string.Join(", ", embeddings.Take(20)) + ", ...");
Console.WriteLine();
Expand Down
4 changes: 2 additions & 2 deletions LLama.Examples/Examples/QuantizeModel.cs
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
namespace LLama.Examples.Examples
namespace LLama.Examples.Examples
{
public class QuantizeModel
{
public static void Run()
public static async Task Run()
{
string inputPath = UserSettings.GetModelPath();

Expand Down
9 changes: 6 additions & 3 deletions LLama.KernelMemory/LLamaSharpTextEmbeddingGenerator.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using LLama;
using LLama.Common;
using LLama.Native;
using Microsoft.KernelMemory;
using Microsoft.KernelMemory.AI;

Expand Down Expand Up @@ -35,7 +36,8 @@ public LLamaSharpTextEmbeddingGenerator(LLamaSharpConfig config)
GpuLayerCount = config.GpuLayerCount ?? 20,
Embeddings = true,
MainGpu = config.MainGpu,
SplitMode = config.SplitMode
SplitMode = config.SplitMode,
PoolingType = LLamaPoolingType.Mean,
};
_weights = LLamaWeights.LoadFromFile(@params);
_embedder = new LLamaEmbedder(_weights, @params);
Expand All @@ -59,7 +61,8 @@ public LLamaSharpTextEmbeddingGenerator(LLamaSharpConfig config, LLamaWeights we
GpuLayerCount = config.GpuLayerCount ?? 20,
Embeddings = true,
MainGpu = config.MainGpu,
SplitMode = config.SplitMode
SplitMode = config.SplitMode,
PoolingType = LLamaPoolingType.Mean,
};
_weights = weights;
_embedder = new LLamaEmbedder(_weights, @params);
Expand Down Expand Up @@ -92,7 +95,7 @@ public void Dispose()
public async Task<Embedding> GenerateEmbeddingAsync(string text, CancellationToken cancellationToken = default)
{
var embeddings = await _embedder.GetEmbeddings(text, cancellationToken);
return new Embedding(embeddings);
return new Embedding(embeddings.First());
}

/// <inheritdoc/>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@

namespace LLamaSharp.SemanticKernel.TextEmbedding;

public sealed class LLamaSharpEmbeddingGeneration : ITextEmbeddingGenerationService
public sealed class LLamaSharpEmbeddingGeneration
: ITextEmbeddingGenerationService
{
private readonly LLamaEmbedder _embedder;

Expand All @@ -23,7 +24,7 @@ public async Task<IList<ReadOnlyMemory<float>>> GenerateEmbeddingsAsync(IList<st
var result = new List<ReadOnlyMemory<float>>();

foreach (var item in data)
result.Add(await _embedder.GetEmbeddings(item, cancellationToken));
result.Add((await _embedder.GetEmbeddings(item, cancellationToken)).First());

return result;
}
Expand Down
39 changes: 35 additions & 4 deletions LLama.Unittest/LLamaEmbedderTests.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
using LLama.Common;
using LLama.Extensions;
using LLama.Native;
using Xunit.Abstractions;

namespace LLama.Unittest;
Expand All @@ -24,19 +26,19 @@ private async Task CompareEmbeddings(string modelPath)
{
ContextSize = 8,
Threads = 4,
Embeddings = true,
GpuLayerCount = Constants.CIGpuLayerCount,
PoolingType = LLamaPoolingType.Mean,
};
using var weights = LLamaWeights.LoadFromFile(@params);
using var embedder = new LLamaEmbedder(weights, @params);

var cat = await embedder.GetEmbeddings("The cat is cute");
var cat = (await embedder.GetEmbeddings("The cat is cute")).Single().EuclideanNormalization();
Assert.DoesNotContain(float.NaN, cat);

var kitten = await embedder.GetEmbeddings("The kitten is kawaii");
var kitten = (await embedder.GetEmbeddings("The kitten is cute")).Single().EuclideanNormalization();
Assert.DoesNotContain(float.NaN, kitten);

var spoon = await embedder.GetEmbeddings("The spoon is not real");
var spoon = (await embedder.GetEmbeddings("The spoon is not real")).Single().EuclideanNormalization();
Assert.DoesNotContain(float.NaN, spoon);

_testOutputHelper.WriteLine($"Cat = [{string.Join(",", cat.AsMemory().Slice(0, 7).ToArray())}...]");
Expand Down Expand Up @@ -64,4 +66,33 @@ public async Task EmbedCompareGenerateModel()
{
await CompareEmbeddings(Constants.GenerativeModelPath);
}

// Shared body for the two non-pooled embedding tests below. With PoolingType.None the
// embedder returns a collection of embedding vectors (one per token, per the pooling
// description elsewhere in this PR — confirm) instead of a single pooled vector.
private async Task NonPooledEmbeddings(string modelPath)
{
var @params = new ModelParams(modelPath)
{
// Tiny context: keep the test fast; prompts below are only a few tokens.
ContextSize = 8,
Threads = 4,
GpuLayerCount = Constants.CIGpuLayerCount,
PoolingType = LLamaPoolingType.None,
};
using var weights = LLamaWeights.LoadFromFile(@params);
using var embedder = new LLamaEmbedder(weights, @params);

// Every returned embedding vector must be free of NaNs.
var kitten = await embedder.GetEmbeddings("the kitten is kawaii");
foreach (var embd in kitten)
Assert.DoesNotContain(float.NaN, embd);
}

// Non-pooled embeddings should work with a dedicated embedding model.
[Fact]
public async Task EmbeddingModelNonPooledEmbeddings()
{
await NonPooledEmbeddings(Constants.EmbeddingModelPath);
}

// Non-pooled embeddings should also work with an ordinary generative model.
[Fact]
public async Task GenerativeModelNonPooledEmbeddings()
{
await NonPooledEmbeddings(Constants.GenerativeModelPath);
}
}
126 changes: 126 additions & 0 deletions LLama/Extensions/SpanNormalizationExtensions.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
using System;
using System.Numerics.Tensors;

namespace LLama.Extensions;

/// <summary>
/// Extensions to span which apply <b>in-place</b> normalization
/// </summary>
public static class SpanNormalizationExtensions
{
    /// <summary>
    /// <b>In-place</b> multiply every element by 32760 and divide every element in the array by the max absolute value in the array
    /// </summary>
    /// <param name="vector">The array to normalize (modified in place)</param>
    /// <returns>The same array</returns>
    public static float[] MaxAbsoluteNormalization(this float[] vector)
    {
        vector.AsSpan().MaxAbsoluteNormalization();
        return vector;
    }

    /// <summary>
    /// <b>In-place</b> multiply every element by 32760 and divide every element in the span by the max absolute value in the span
    /// </summary>
    /// <param name="vector">The span to normalize (modified in place)</param>
    /// <returns>The same span</returns>
    public static Span<float> MaxAbsoluteNormalization(this Span<float> vector)
    {
        // Rescale so the element with the largest magnitude maps to +/-32760
        // (presumably chosen to fit within the Int16 range with headroom — confirm).
        var factor = 32760 / TensorPrimitives.MaxMagnitude(vector);
        TensorPrimitives.Multiply(vector, factor, vector);
        return vector;
    }

    /// <summary>
    /// <b>In-place</b> divide every element in the array by the sum of absolute values in the array
    /// </summary>
    /// <remarks>Also known as "Manhattan normalization". A zero vector is not special-cased (division by zero).</remarks>
    /// <param name="vector">The array to normalize (modified in place)</param>
    /// <returns>The same array</returns>
    public static float[] TaxicabNormalization(this float[] vector)
    {
        vector.AsSpan().TaxicabNormalization();
        return vector;
    }

    /// <summary>
    /// <b>In-place</b> divide every element in the span by the sum of absolute values in the span
    /// </summary>
    /// <remarks>Also known as "Manhattan normalization". A zero vector is not special-cased (division by zero).</remarks>
    /// <param name="vector">The span to normalize (modified in place)</param>
    /// <returns>The same span</returns>
    public static Span<float> TaxicabNormalization(this Span<float> vector)
    {
        var sumAbs = TensorPrimitives.SumOfMagnitudes(vector);
        TensorPrimitives.Divide(vector, sumAbs, vector);
        return vector;
    }

    /// <summary>
    /// <b>In-place</b> divide every element by the euclidean length of the vector
    /// </summary>
    /// <remarks>Also known as "L2 normalization". A zero vector is not special-cased (division by zero).</remarks>
    /// <param name="vector">The array to normalize (modified in place)</param>
    /// <returns>The same array</returns>
    public static float[] EuclideanNormalization(this float[] vector)
    {
        vector.AsSpan().EuclideanNormalization();
        return vector;
    }

    /// <summary>
    /// <b>In-place</b> divide every element by the euclidean length of the vector
    /// </summary>
    /// <remarks>Also known as "L2 normalization". A zero vector is not special-cased (division by zero).</remarks>
    /// <param name="vector">The span to normalize (modified in place)</param>
    /// <returns>The same span</returns>
    public static Span<float> EuclideanNormalization(this Span<float> vector)
    {
        var norm = TensorPrimitives.Norm(vector);
        TensorPrimitives.Divide(vector, norm, vector);
        return vector;
    }

    /// <summary>
    /// <b>In-place</b> apply p-normalization. https://en.wikipedia.org/wiki/Norm_(mathematics)#p-norm
    /// <list type="bullet">
    /// <item>For p = 1, this is taxicab normalization</item>
    /// <item>For p = 2, this is euclidean normalization</item>
    /// <item>As p => infinity, this approaches infinity norm or maximum norm</item>
    /// </list>
    /// </summary>
    /// <param name="vector">The array to normalize (modified in place)</param>
    /// <param name="p">Order of the norm. Must be &gt;= 1</param>
    /// <returns>The same array</returns>
    /// <exception cref="ArgumentOutOfRangeException">Thrown if <paramref name="p"/> is less than 1</exception>
    public static float[] PNormalization(this float[] vector, int p)
    {
        vector.AsSpan().PNormalization(p);
        return vector;
    }

    /// <summary>
    /// <b>In-place</b> apply p-normalization. https://en.wikipedia.org/wiki/Norm_(mathematics)#p-norm
    /// <list type="bullet">
    /// <item>For p = 1, this is taxicab normalization</item>
    /// <item>For p = 2, this is euclidean normalization</item>
    /// <item>As p => infinity, this approaches infinity norm or maximum norm</item>
    /// </list>
    /// </summary>
    /// <param name="vector">The span to normalize (modified in place)</param>
    /// <param name="p">Order of the norm. Must be &gt;= 1</param>
    /// <returns>The same span</returns>
    /// <exception cref="ArgumentOutOfRangeException">Thrown if <paramref name="p"/> is less than 1</exception>
    public static Span<float> PNormalization(this Span<float> vector, int p)
    {
        // The p-norm is only defined for p >= 1.
        if (p < 1)
            throw new ArgumentOutOfRangeException(nameof(p), "p-norm requires p >= 1");

        // Delegate to the specialized (vectorized) implementations where possible.
        if (p == 1)
            return vector.TaxicabNormalization();
        if (p == 2)
            return vector.EuclideanNormalization();

        // General case: ||v||_p = (sum |v_i|^p)^(1/p). The absolute value is required —
        // without it, odd p and negative elements produce an incorrect (or negative) sum.
        var sum = 0.0;
        for (var i = 0; i < vector.Length; i++)
            sum += MathF.Pow(MathF.Abs(vector[i]), p);
        var divisor = (float)Math.Pow(sum, 1.0 / p);

        TensorPrimitives.Divide(vector, divisor, vector);

        return vector;
    }
}
22 changes: 22 additions & 0 deletions LLama/LLamaContext.cs
Original file line number Diff line number Diff line change
Expand Up @@ -379,6 +379,28 @@ public bool ShouldAddBosToken()
}

#region eval overloads
/// <summary>
/// Encode a batch of tokens via the native handle. An empty batch is a no-op and returns 0
/// immediately (presumably the success value of <see cref="EncodeResult"/> — confirm).
/// </summary>
/// <param name="batch">The batch of tokens to encode. Must hold at most <see cref="BatchSize"/> tokens.</param>
/// <returns>The result code of the native encode call, cast to <see cref="EncodeResult"/>.</returns>
/// <exception cref="ArgumentException">Thrown if the batch contains more tokens than the configured batch size.</exception>
public EncodeResult Encode(LLamaBatch batch)
{
// Nothing to encode — skip the native call entirely.
if (batch.TokenCount == 0)
return 0;
// Oversized batches are rejected up front rather than passed to native code.
if (batch.TokenCount > BatchSize)
throw new ArgumentException("Input contains more tokens than configured batch size", nameof(batch));

return (EncodeResult)NativeHandle.Encode(batch);
}

/// <summary>
/// Encode a batch of tokens on a thread-pool thread by wrapping <see cref="Encode"/> in <see cref="Task.Run(Action)"/>.
/// </summary>
/// <param name="batch">The batch of tokens to encode.</param>
/// <param name="cancellationToken">Observed by Task.Run only before the work starts; it does not abort an encode already in progress.</param>
/// <returns>A task producing the <see cref="EncodeResult"/> of the synchronous encode.</returns>
public Task<EncodeResult> EncodeAsync(LLamaBatch batch, CancellationToken cancellationToken = default)
{
return Task.Run(() => Encode(batch), cancellationToken);
}

/// <summary>
/// </summary>
/// <param name="batch"></param>
Expand Down
Loading
Loading