diff --git a/LLama.SemanticKernel/ChatCompletion/LLamaSharpChatCompletion.cs b/LLama.SemanticKernel/ChatCompletion/LLamaSharpChatCompletion.cs
index 7bcbaf7be..a7ac6e8e8 100644
--- a/LLama.SemanticKernel/ChatCompletion/LLamaSharpChatCompletion.cs
+++ b/LLama.SemanticKernel/ChatCompletion/LLamaSharpChatCompletion.cs
@@ -7,6 +7,7 @@
 using System.IO;
 using System.Runtime.CompilerServices;
 using System.Text;
+using static LLama.InteractiveExecutor;
 using static LLama.LLamaTransforms;
 
 namespace LLamaSharp.SemanticKernel.ChatCompletion;
@@ -22,6 +23,7 @@ public sealed class LLamaSharpChatCompletion : IChatCompletionService
     private readonly ITextStreamTransform outputTransform;
 
     private readonly Dictionary<string, object?> _attributes = new();
+    private readonly bool _isStatefulExecutor;
 
     public IReadOnlyDictionary<string, object?> Attributes => this._attributes;
 
@@ -42,6 +44,7 @@ public LLamaSharpChatCompletion(ILLamaExecutor model,
         ITextStreamTransform? outputTransform = null)
     {
         this._model = model;
+        this._isStatefulExecutor = this._model is StatefulExecutorBase;
         this.defaultRequestSettings = defaultRequestSettings ?? GetDefaultSettings();
         this.historyTransform = historyTransform ?? new HistoryTransform();
         this.outputTransform = outputTransform ?? new KeywordTextOutputStreamTransform(new[] { $"{LLama.Common.AuthorRole.User}:",
@@ -67,8 +70,8 @@ public async Task<IReadOnlyList<ChatMessageContent>> GetChatMessageContentsAsync
         var settings = executionSettings != null
             ? ChatRequestSettings.FromRequestSettings(executionSettings)
             : defaultRequestSettings;
-        var prompt = historyTransform.HistoryToText(chatHistory.ToLLamaSharpChatHistory());
 
+        string prompt = this._getFormattedPrompt(chatHistory);
         var result = _model.InferAsync(prompt, settings.ToLLamaSharpInferenceParams(), cancellationToken);
 
         var output = outputTransform.TransformAsync(result);
@@ -88,8 +91,8 @@ public async IAsyncEnumerable<StreamingChatMessageContent> GetStreamingChatMessa
         var settings = executionSettings != null
             ? ChatRequestSettings.FromRequestSettings(executionSettings)
             : defaultRequestSettings;
-        var prompt = historyTransform.HistoryToText(chatHistory.ToLLamaSharpChatHistory());
 
+        string prompt = this._getFormattedPrompt(chatHistory);
         var result = _model.InferAsync(prompt, settings.ToLLamaSharpInferenceParams(), cancellationToken);
 
         var output = outputTransform.TransformAsync(result);
@@ -99,4 +102,33 @@ public async IAsyncEnumerable<StreamingChatMessageContent> GetStreamingChatMessa
             yield return new StreamingChatMessageContent(AuthorRole.Assistant, token);
         }
     }
+
+    /// <summary>
+    /// Return either the entire formatted chatHistory or just the most recent message based on
+    /// whether the model extends StatefulExecutorBase or not.
+    /// </summary>
+    /// <param name="chatHistory"></param>
+    /// <returns>The formatted prompt</returns>
+    private string _getFormattedPrompt(ChatHistory chatHistory){
+        string prompt;
+        if (this._isStatefulExecutor){
+            InteractiveExecutorState state = (InteractiveExecutorState)((StatefulExecutorBase)this._model).GetStateData();
+            if (state.IsPromptRun)
+            {
+                prompt = historyTransform.HistoryToText(chatHistory.ToLLamaSharpChatHistory());
+            }
+            else
+            {
+                ChatHistory tempHistory = new();
+                tempHistory.AddUserMessage(chatHistory.Last().Content);
+                prompt = historyTransform.HistoryToText(tempHistory.ToLLamaSharpChatHistory());
+            }
+        }
+        else
+        {
+            prompt = historyTransform.HistoryToText(chatHistory.ToLLamaSharpChatHistory());
+        }
+
+        return prompt;
+    }
 }
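
For reference, here is a minimal usage sketch of the behavior this patch introduces (the sketch is not part of the patch). With a stateful executor such as InteractiveExecutor, only the first call formats the full history; later calls send just the newest user message, because the executor keeps earlier turns in its own state. The model path, ModelParams values, and prompts below are placeholders, and the Semantic Kernel namespaces assume a 1.x-era package.

```csharp
using LLama;
using LLama.Common;
using LLamaSharp.SemanticKernel.ChatCompletion;
using Microsoft.SemanticKernel.ChatCompletion;

// Placeholder model path; any local GGUF chat model works here.
var parameters = new ModelParams("./models/llama-2-7b-chat.Q4_K_M.gguf")
{
    ContextSize = 2048,
};
using var weights = LLamaWeights.LoadFromFile(parameters);
using var context = weights.CreateContext(parameters);

// InteractiveExecutor derives from StatefulExecutorBase, so
// LLamaSharpChatCompletion takes the single-message path after turn one.
var executor = new InteractiveExecutor(context);
var chat = new LLamaSharpChatCompletion(executor);

var history = new ChatHistory();
history.AddSystemMessage("You are a helpful assistant.");
history.AddUserMessage("Hello!");

// Turn 1: IsPromptRun is true, so the whole formatted history is sent.
var first = await chat.GetChatMessageContentsAsync(history);
history.AddAssistantMessage(first[0].Content ?? string.Empty);

// Turn 2: only this message is formatted and sent; earlier turns already
// live in the executor's internal state instead of being re-evaluated.
history.AddUserMessage("Summarize what I said so far.");
var second = await chat.GetChatMessageContentsAsync(history);
```

A stateless ILLamaExecutor still receives the full transformed history on every call, exactly as before this change.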