Add DefaultInferenceParams to Kernel Memory #307

Merged · 2 commits · Nov 20, 2023
8 changes: 7 additions & 1 deletion LLama.Examples/Examples/KernelMemory.cs
@@ -17,7 +17,13 @@ public static async Task Run()
         Console.Write("Please input your model path: ");
         var modelPath = Console.ReadLine();
         var memory = new KernelMemoryBuilder()
-            .WithLLamaSharpDefaults(new LLamaSharpConfig(modelPath))
+            .WithLLamaSharpDefaults(new LLamaSharpConfig(modelPath)
+            {
+                DefaultInferenceParams = new Common.InferenceParams
+                {
+                    AntiPrompts = new List<string> { "\n\n" }
+                }
+            })
             .With(new TextPartitioningOptions
             {
                 MaxTokensPerParagraph = 300,
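
The "\n\n" anti-prompt configured above stops generation at the first blank line the model emits. As a minimal standalone sketch of what that parameter does at the executor level, using the same StatelessExecutor.InferAsync pairing this integration calls internally (see the LlamaSharpTextGeneration.cs diff below); the model path and prompt are placeholders:

using System;
using System.Collections.Generic;
using LLama;
using LLama.Common;

// Sketch only: "model.gguf" and the prompt are placeholders.
var parameters = new ModelParams("model.gguf");
using var weights = LLamaWeights.LoadFromFile(parameters);
var executor = new StatelessExecutor(weights, parameters);

var inferenceParams = new InferenceParams
{
    // Generation halts as soon as the model emits a blank line.
    AntiPrompts = new List<string> { "\n\n" },
    MaxTokens = 1024,
};

await foreach (var token in executor.InferAsync("Q: What is Kernel Memory?\nA:", inferenceParams))
    Console.Write(token);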
3 changes: 2 additions & 1 deletion LLama.Examples/Examples/Runner.cs
@@ -42,7 +42,8 @@ public static async Task Run()
             AnsiConsole.Write(new Rule(choice));
             await example();
         }
-
+        Console.WriteLine("Press any key to continue...");
+        Console.ReadKey();
         AnsiConsole.Clear();
     }
 }
2 changes: 1 addition & 1 deletion LLama.KernelMemory/BuilderExtensions.cs
@@ -82,7 +82,7 @@ public static KernelMemoryBuilder WithLLamaSharpDefaults(this KernelMemoryBuilder
         var executor = new StatelessExecutor(weights, parameters);
         var embedder = new LLamaEmbedder(weights, parameters);
         builder.WithLLamaSharpTextEmbeddingGeneration(new LLamaSharpTextEmbeddingGeneration(embedder));
-        builder.WithLLamaSharpTextGeneration(new LlamaSharpTextGeneration(weights, context, executor));
+        builder.WithLLamaSharpTextGeneration(new LlamaSharpTextGeneration(weights, context, executor, config?.DefaultInferenceParams));
         return builder;
     }
 }
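
For callers that wire the pieces up by hand instead of going through WithLLamaSharpDefaults, a sketch of the equivalent direct construction, mirroring the calls shown in this extension method (the model path is a placeholder):

using LLama;
using LLama.Common;
using LLamaSharp.KernelMemory;

// Mirrors the WithLLamaSharpDefaults internals; "model.gguf" is a placeholder.
var parameters = new ModelParams("model.gguf");
using var weights = LLamaWeights.LoadFromFile(parameters);
using var context = weights.CreateContext(parameters);
var executor = new StatelessExecutor(weights, parameters);

var defaults = new InferenceParams { MaxTokens = 1024 };

// The new optional fourth argument threads the default inference
// parameters through to text generation.
var generation = new LlamaSharpTextGeneration(weights, context, executor, defaults);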
9 changes: 8 additions & 1 deletion LLama.KernelMemory/LlamaSharpConfig.cs
@@ -1,4 +1,5 @@
-using System;
+using LLama.Common;
+using System;
 using System.Collections.Generic;
 using System.Linq;
 using System.Text;
@@ -39,5 +40,11 @@ public LLamaSharpConfig(string modelPath)
         /// Gets or sets the number of GPU layers.
         /// </summary>
         public int? GpuLayerCount { get; set; }
+
+
+        /// <summary>
+        /// Set the default inference parameters.
+        /// </summary>
+        public InferenceParams? DefaultInferenceParams { get; set; }
     }
 }
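
Combining the new property with the options already on this class, a configuration might look like the following sketch (all values illustrative; "model.gguf" is a placeholder path):

using System.Collections.Generic;
using LLama.Common;
using LLamaSharp.KernelMemory;

var config = new LLamaSharpConfig("model.gguf")
{
    ContextSize = 2048,      // matches the fallback used when unset
    GpuLayerCount = 20,      // illustrative value
    DefaultInferenceParams = new InferenceParams
    {
        AntiPrompts = new List<string> { "\n\n" },
        MaxTokens = 1024,
    }
};

Any TextGenerationOptions supplied per request are merged with these defaults in OptionsToParams, shown in the next file.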
43 changes: 29 additions & 14 deletions LLama.KernelMemory/LlamaSharpTextGeneration.cs
@@ -15,10 +15,10 @@ namespace LLamaSharp.KernelMemory
     /// </summary>
     public class LlamaSharpTextGeneration : ITextGeneration, IDisposable
     {
-        private readonly LLamaSharpConfig? _config;
         private readonly LLamaWeights _weights;
         private readonly StatelessExecutor _executor;
         private readonly LLamaContext _context;
+        private readonly InferenceParams? _defaultInferenceParams;
         private bool _ownsContext = false;
         private bool _ownsWeights = false;
 
@@ -28,7 +28,6 @@ public class LlamaSharpTextGeneration : ITextGeneration, IDisposable
         /// <param name="config">The configuration for LLamaSharp.</param>
         public LlamaSharpTextGeneration(LLamaSharpConfig config)
         {
-            this._config = config;
             var parameters = new ModelParams(config.ModelPath)
             {
                 ContextSize = config?.ContextSize ?? 2048,
@@ -38,6 +37,7 @@ public LlamaSharpTextGeneration(LLamaSharpConfig config)
             _weights = LLamaWeights.LoadFromFile(parameters);
             _context = _weights.CreateContext(parameters);
             _executor = new StatelessExecutor(_weights, parameters);
+            _defaultInferenceParams = config?.DefaultInferenceParams;
             _ownsWeights = _ownsContext = true;
         }
 
@@ -48,12 +48,12 @@ public LlamaSharpTextGeneration(LLamaSharpConfig config)
         /// <param name="weights">A LLamaWeights object.</param>
        /// <param name="context">A LLamaContext object.</param>
         /// <param name="executor">An executor. Currently only StatelessExecutor is expected.</param>
-        public LlamaSharpTextGeneration(LLamaWeights weights, LLamaContext context, StatelessExecutor? executor = null)
+        public LlamaSharpTextGeneration(LLamaWeights weights, LLamaContext context, StatelessExecutor? executor = null, InferenceParams? inferenceParams = null)
         {
-            _config = null;
             _weights = weights;
             _context = context;
             _executor = executor ?? new StatelessExecutor(_weights, _context.Params);
+            _defaultInferenceParams = inferenceParams;
         }
 
         /// <inheritdoc/>
@@ -72,20 +72,35 @@ public void Dispose()
         /// <inheritdoc/>
         public IAsyncEnumerable<string> GenerateTextAsync(string prompt, TextGenerationOptions options, CancellationToken cancellationToken = default)
         {
-            return _executor.InferAsync(prompt, OptionsToParams(options), cancellationToken: cancellationToken);
+            return _executor.InferAsync(prompt, OptionsToParams(options, this._defaultInferenceParams), cancellationToken: cancellationToken);
         }
 
-        private static InferenceParams OptionsToParams(TextGenerationOptions options)
+        private static InferenceParams OptionsToParams(TextGenerationOptions options, InferenceParams? defaultParams)
        {
-            return new InferenceParams()
+            if (defaultParams != null)
             {
-                AntiPrompts = options.StopSequences.ToList().AsReadOnly(),
-                Temperature = (float)options.Temperature,
-                MaxTokens = options.MaxTokens ?? 1024,
-                FrequencyPenalty = (float)options.FrequencyPenalty,
-                PresencePenalty = (float)options.PresencePenalty,
-                TopP = (float)options.TopP,
-            };
+                return defaultParams with
+                {
+                    AntiPrompts = defaultParams.AntiPrompts.Concat(options.StopSequences).ToList().AsReadOnly(),
+                    Temperature = options.Temperature == defaultParams.Temperature ? defaultParams.Temperature : (float)options.Temperature,
+                    MaxTokens = options.MaxTokens ?? defaultParams.MaxTokens,
+                    FrequencyPenalty = options.FrequencyPenalty == defaultParams.FrequencyPenalty ? defaultParams.FrequencyPenalty : (float)options.FrequencyPenalty,
+                    PresencePenalty = options.PresencePenalty == defaultParams.PresencePenalty ? defaultParams.PresencePenalty : (float)options.PresencePenalty,
+                    TopP = options.TopP == defaultParams.TopP ? defaultParams.TopP : (float)options.TopP
+                };
+            }
+            else
+            {
+                return new InferenceParams()
+                {
+                    AntiPrompts = options.StopSequences.ToList().AsReadOnly(),
+                    Temperature = (float)options.Temperature,
+                    MaxTokens = options.MaxTokens ?? 1024,
+                    FrequencyPenalty = (float)options.FrequencyPenalty,
+                    PresencePenalty = (float)options.PresencePenalty,
+                    TopP = (float)options.TopP,
+                };
+            }
         }
     }
 }
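
The merge in OptionsToParams uses a C# "with" expression, which works because InferenceParams supports nondestructive copying (a record or record struct): the configured defaults are copied rather than mutated, so a single default instance can safely back every request, and the caller's stop sequences are appended to the default anti-prompts rather than replacing them. A self-contained sketch of that copy semantics, using a hypothetical record in place of InferenceParams:

using System;

// Hypothetical stand-in for InferenceParams, for illustration only.
public record SamplingDefaults(float Temperature, int MaxTokens);

public static class WithExpressionDemo
{
    public static void Main()
    {
        var defaults = new SamplingDefaults(Temperature: 0.8f, MaxTokens: 1024);

        // "with" returns a modified copy; the defaults instance is unchanged.
        var perRequest = defaults with { Temperature = 0.2f };

        Console.WriteLine(defaults);   // SamplingDefaults { Temperature = 0.8, MaxTokens = 1024 }
        Console.WriteLine(perRequest); // SamplingDefaults { Temperature = 0.2, MaxTokens = 1024 }
    }
}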