Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optimize Semantic Kernel LLamaSharpChatCompletion when running with StatefulExecutorBase models #671

Merged
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 34 additions & 2 deletions LLama.SemanticKernel/ChatCompletion/LLamaSharpChatCompletion.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
using System.IO;
using System.Runtime.CompilerServices;
using System.Text;
using static LLama.InteractiveExecutor;
using static LLama.LLamaTransforms;

namespace LLamaSharp.SemanticKernel.ChatCompletion;
Expand All @@ -22,6 +23,7 @@ public sealed class LLamaSharpChatCompletion : IChatCompletionService
private readonly ITextStreamTransform outputTransform;

private readonly Dictionary<string, object?> _attributes = new();
private readonly bool _isStatefulExecutor;

public IReadOnlyDictionary<string, object?> Attributes => this._attributes;

Expand All @@ -42,6 +44,7 @@ public LLamaSharpChatCompletion(ILLamaExecutor model,
ITextStreamTransform? outputTransform = null)
{
this._model = model;
this._isStatefulExecutor = this._model is StatefulExecutorBase;
this.defaultRequestSettings = defaultRequestSettings ?? GetDefaultSettings();
this.historyTransform = historyTransform ?? new HistoryTransform();
this.outputTransform = outputTransform ?? new KeywordTextOutputStreamTransform(new[] { $"{LLama.Common.AuthorRole.User}:",
Expand All @@ -67,8 +70,8 @@ public async Task<IReadOnlyList<ChatMessageContent>> GetChatMessageContentsAsync
var settings = executionSettings != null
? ChatRequestSettings.FromRequestSettings(executionSettings)
: defaultRequestSettings;
var prompt = historyTransform.HistoryToText(chatHistory.ToLLamaSharpChatHistory());

string prompt = this._getFormattedPrompt(chatHistory);
var result = _model.InferAsync(prompt, settings.ToLLamaSharpInferenceParams(), cancellationToken);

var output = outputTransform.TransformAsync(result);
Expand All @@ -88,8 +91,8 @@ public async IAsyncEnumerable<StreamingChatMessageContent> GetStreamingChatMessa
var settings = executionSettings != null
? ChatRequestSettings.FromRequestSettings(executionSettings)
: defaultRequestSettings;
var prompt = historyTransform.HistoryToText(chatHistory.ToLLamaSharpChatHistory());

string prompt = this._getFormattedPrompt(chatHistory);
var result = _model.InferAsync(prompt, settings.ToLLamaSharpInferenceParams(), cancellationToken);

var output = outputTransform.TransformAsync(result);
Expand All @@ -99,4 +102,33 @@ public async IAsyncEnumerable<StreamingChatMessageContent> GetStreamingChatMessa
yield return new StreamingChatMessageContent(AuthorRole.Assistant, token);
}
}

/// <summary>
/// Return either the entire formatted chatHistory or just the most recent message based on
/// whether the model extends StatefulExecutorBase or not.
/// </summary>
/// <param name="chatHistory"></param>
/// <returns>The formatted prompt</returns>
/// <summary>
/// Return either the entire formatted chatHistory or just the most recent message based on
/// whether the model extends StatefulExecutorBase or not.
/// </summary>
/// <param name="chatHistory">The Semantic Kernel chat history to format into a prompt.</param>
/// <returns>The formatted prompt</returns>
private string _getFormattedPrompt(ChatHistory chatHistory)
{
    // A stateless executor has no memory of prior turns, so it always needs
    // the full conversation rendered into the prompt.
    if (!this._isStatefulExecutor)
    {
        return historyTransform.HistoryToText(chatHistory.ToLLamaSharpChatHistory());
    }

    InteractiveExecutorState state = (InteractiveExecutorState)((StatefulExecutorBase)this._model).GetStateData();

    // On the very first run the stateful executor has no context yet, so the
    // full history must still be sent once.
    if (state.IsPromptRun)
    {
        return historyTransform.HistoryToText(chatHistory.ToLLamaSharpChatHistory());
    }

    // Subsequent turns: the executor already holds the conversation context,
    // so only the most recent user message needs to be forwarded.
    ChatHistory temp_history = new();
    // ChatMessageContent.Content is nullable; coalesce to avoid an
    // ArgumentNullException from AddUserMessage on a null message body.
    temp_history.AddUserMessage(chatHistory.Last().Content ?? string.Empty);
    return historyTransform.HistoryToText(temp_history.ToLLamaSharpChatHistory());
}
}
Loading