Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added a converter similar to the OpenAI one #315

Merged
merged 2 commits into from
Nov 24, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 59 additions & 0 deletions LLama.SemanticKernel/ChatCompletion/ChatRequestSettings.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
using Microsoft.SemanticKernel.AI;
using System.Text.Json;
using System.Text.Json.Serialization;

namespace LLamaSharp.SemanticKernel.ChatCompletion;

Expand All @@ -8,47 +10,104 @@ public class ChatRequestSettings : AIRequestSettings
/// Temperature controls the randomness of the completion.
/// The higher the temperature, the more random the completion.
/// </summary>
[JsonPropertyName("temperature")]
public double Temperature { get; set; } = 0;

/// <summary>
/// TopP controls the diversity of the completion.
/// The higher the TopP, the more diverse the completion.
/// </summary>
[JsonPropertyName("top_p")]
public double TopP { get; set; } = 0;

/// <summary>
/// Number between -2.0 and 2.0. Positive values penalize new tokens
/// based on whether they appear in the text so far, increasing the
/// model's likelihood to talk about new topics.
/// </summary>
[JsonPropertyName("presence_penalty")]
public double PresencePenalty { get; set; } = 0;

/// <summary>
/// Number between -2.0 and 2.0. Positive values penalize new tokens
/// based on their existing frequency in the text so far, decreasing
/// the model's likelihood to repeat the same line verbatim.
/// </summary>
[JsonPropertyName("frequency_penalty")]
public double FrequencyPenalty { get; set; } = 0;

/// <summary>
/// Sequences where the completion will stop generating further tokens.
/// </summary>
[JsonPropertyName("stop_sequences")]
public IList<string> StopSequences { get; set; } = Array.Empty<string>();

/// <summary>
/// How many completions to generate for each prompt. Default is 1.
/// Note: Because this parameter generates many completions, it can quickly consume your token quota.
/// Use carefully and ensure that you have reasonable settings for max_tokens and stop.
/// </summary>
[JsonPropertyName("results_per_prompt")]
public int ResultsPerPrompt { get; set; } = 1;

/// <summary>
/// The maximum number of tokens to generate in the completion.
/// </summary>
[JsonPropertyName("max_tokens")]
public int? MaxTokens { get; set; }

/// <summary>
/// Modify the likelihood of specified tokens appearing in the completion.
/// </summary>
[JsonPropertyName("token_selection_biases")]
public IDictionary<int, int> TokenSelectionBiases { get; set; } = new Dictionary<int, int>();

/// <summary>
/// Create a new settings object with the values from another settings object.
/// </summary>
/// <param name="requestSettings">Template configuration</param>
/// <param name="defaultMaxTokens">Default max tokens</param>
/// <returns>An instance of ChatRequestSettings</returns>
public static ChatRequestSettings FromRequestSettings(AIRequestSettings? requestSettings, int? defaultMaxTokens = null)
{
    // No template supplied: return defaults, honouring the caller-provided token cap.
    if (requestSettings is null)
    {
        return new ChatRequestSettings()
        {
            MaxTokens = defaultMaxTokens
        };
    }

    // Already the right concrete type: reuse it directly, no copying needed.
    if (requestSettings is ChatRequestSettings requestSettingsChatRequestSettings)
    {
        return requestSettingsChatRequestSettings;
    }

    // Some other AIRequestSettings subtype: round-trip through JSON so that
    // compatible properties are mapped onto a new ChatRequestSettings
    // (s_options carries the converter that accepts multiple name styles).
    var json = JsonSerializer.Serialize(requestSettings);
    var chatRequestSettings = JsonSerializer.Deserialize<ChatRequestSettings>(json, s_options);

    if (chatRequestSettings is not null)
    {
        return chatRequestSettings;
    }

    throw new ArgumentException($"Invalid request settings, cannot convert to {nameof(ChatRequestSettings)}", nameof(requestSettings));
}

// Serializer options used by FromRequestSettings when converting a foreign
// AIRequestSettings subtype via a JSON round-trip. Includes the custom
// converter so snake_case and PascalCase property names are both accepted.
private static readonly JsonSerializerOptions s_options = CreateOptions();

/// <summary>
/// Builds the serializer options used for settings conversion.
/// </summary>
/// <returns>A configured <see cref="JsonSerializerOptions"/> instance.</returns>
private static JsonSerializerOptions CreateOptions()
{
    var options = new JsonSerializerOptions();
    options.WriteIndented = true;
    options.MaxDepth = 20;
    options.AllowTrailingCommas = true;
    options.PropertyNameCaseInsensitive = true;
    options.ReadCommentHandling = JsonCommentHandling.Skip;
    options.Converters.Add(new ChatRequestSettingsConverter());
    return options;
}
}
105 changes: 105 additions & 0 deletions LLama.SemanticKernel/ChatCompletion/ChatRequestSettingsConverter.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
using System;
using System.Collections.Generic;
using System.Text.Json;
using System.Text.Json.Serialization;

namespace LLamaSharp.SemanticKernel.ChatCompletion;

/// <summary>
/// JSON converter for <see cref="ChatRequestSettings"/>
/// </summary>
public class ChatRequestSettingsConverter : JsonConverter<ChatRequestSettings>
{
    /// <summary>
    /// Reads a <see cref="ChatRequestSettings"/> from JSON, accepting both
    /// snake_case (e.g. "top_p") and PascalCase (e.g. "TopP") property names.
    /// Unrecognised properties are skipped.
    /// </summary>
    /// <inheritdoc/>
    public override ChatRequestSettings? Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options)
    {
        var requestSettings = new ChatRequestSettings();

        while (reader.Read() && reader.TokenType != JsonTokenType.EndObject)
        {
            if (reader.TokenType == JsonTokenType.PropertyName)
            {
                string? propertyName = reader.GetString();

                if (propertyName is not null)
                {
                    // Normalise so both snake_case and PascalCase variants
                    // match the uppercase labels in the switch below.
                    propertyName = propertyName.ToUpperInvariant();
                }

                // Advance to the property's value token.
                reader.Read();

                switch (propertyName)
                {
                    case "TEMPERATURE":
                        requestSettings.Temperature = reader.GetDouble();
                        break;
                    case "TOPP":
                    case "TOP_P":
                        requestSettings.TopP = reader.GetDouble();
                        break;
                    case "FREQUENCYPENALTY":
                    case "FREQUENCY_PENALTY":
                        requestSettings.FrequencyPenalty = reader.GetDouble();
                        break;
                    case "PRESENCEPENALTY":
                    case "PRESENCE_PENALTY":
                        requestSettings.PresencePenalty = reader.GetDouble();
                        break;
                    case "MAXTOKENS":
                    case "MAX_TOKENS":
                        // Write() emits null for an unset MaxTokens; accept it here
                        // instead of letting GetInt32() throw on a Null token, so
                        // serialize -> deserialize round-trips default settings.
                        requestSettings.MaxTokens = reader.TokenType == JsonTokenType.Null
                            ? null
                            : reader.GetInt32();
                        break;
                    case "STOPSEQUENCES":
                    case "STOP_SEQUENCES":
                        requestSettings.StopSequences = JsonSerializer.Deserialize<IList<string>>(ref reader, options) ?? Array.Empty<string>();
                        break;
                    case "RESULTSPERPROMPT":
                    case "RESULTS_PER_PROMPT":
                        requestSettings.ResultsPerPrompt = reader.GetInt32();
                        break;
                    case "TOKENSELECTIONBIASES":
                    case "TOKEN_SELECTION_BIASES":
                        requestSettings.TokenSelectionBiases = JsonSerializer.Deserialize<IDictionary<int, int>>(ref reader, options) ?? new Dictionary<int, int>();
                        break;
                    case "SERVICEID":
                    case "SERVICE_ID":
                        requestSettings.ServiceId = reader.GetString();
                        break;
                    default:
                        // Unknown property: skip its entire value (object/array included).
                        reader.Skip();
                        break;
                }
            }
        }

        return requestSettings;
    }

    /// <summary>
    /// Writes a <see cref="ChatRequestSettings"/> as a JSON object using
    /// snake_case property names.
    /// </summary>
    /// <inheritdoc/>
    public override void Write(Utf8JsonWriter writer, ChatRequestSettings value, JsonSerializerOptions options)
    {
        writer.WriteStartObject();

        writer.WriteNumber("temperature", value.Temperature);
        writer.WriteNumber("top_p", value.TopP);
        writer.WriteNumber("frequency_penalty", value.FrequencyPenalty);
        writer.WriteNumber("presence_penalty", value.PresencePenalty);
        if (value.MaxTokens is null)
        {
            writer.WriteNull("max_tokens");
        }
        else
        {
            writer.WriteNumber("max_tokens", value.MaxTokens.Value);
        }
        writer.WritePropertyName("stop_sequences");
        JsonSerializer.Serialize(writer, value.StopSequences, options);
        writer.WriteNumber("results_per_prompt", value.ResultsPerPrompt);
        writer.WritePropertyName("token_selection_biases");
        JsonSerializer.Serialize(writer, value.TokenSelectionBiases, options);
        writer.WriteString("service_id", value.ServiceId);

        writer.WriteEndObject();
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ public ChatHistory CreateNewChat(string? instructions = "")
public Task<IReadOnlyList<IChatResult>> GetChatCompletionsAsync(ChatHistory chat, AIRequestSettings? requestSettings = null, CancellationToken cancellationToken = default)
{
var settings = requestSettings != null
? (ChatRequestSettings)requestSettings
? ChatRequestSettings.FromRequestSettings(requestSettings)
: defaultRequestSettings;

// This call is not awaited because LLamaSharpChatResult accepts an IAsyncEnumerable.
Expand All @@ -76,7 +76,7 @@ public async IAsyncEnumerable<IChatStreamingResult> GetStreamingChatCompletionsA
#pragma warning restore CS1998
{
var settings = requestSettings != null
? (ChatRequestSettings)requestSettings
? ChatRequestSettings.FromRequestSettings(requestSettings)
: defaultRequestSettings;

// This call is not awaited because LLamaSharpChatResult accepts an IAsyncEnumerable.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ public LLamaSharpTextCompletion(ILLamaExecutor executor)

public async Task<IReadOnlyList<ITextResult>> GetCompletionsAsync(string text, AIRequestSettings? requestSettings, CancellationToken cancellationToken = default)
{
var settings = (ChatRequestSettings?)requestSettings;
var settings = ChatRequestSettings.FromRequestSettings(requestSettings);
var result = executor.InferAsync(text, settings?.ToLLamaSharpInferenceParams(), cancellationToken);
return await Task.FromResult(new List<ITextResult> { new LLamaTextResult(result) }.AsReadOnly()).ConfigureAwait(false);
}
Expand All @@ -30,7 +30,7 @@ public async Task<IReadOnlyList<ITextResult>> GetCompletionsAsync(string text, A
public async IAsyncEnumerable<ITextStreamingResult> GetStreamingCompletionsAsync(string text, AIRequestSettings? requestSettings,[EnumeratorCancellation] CancellationToken cancellationToken = default)
#pragma warning restore CS1998
{
var settings = (ChatRequestSettings?)requestSettings;
var settings = ChatRequestSettings.FromRequestSettings(requestSettings);
var result = executor.InferAsync(text, settings?.ToLLamaSharpInferenceParams(), cancellationToken);
yield return new LLamaTextResult(result);
}
Expand Down
1 change: 1 addition & 0 deletions LLama.Unittest/LLama.Unittest.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
</Target>

<ItemGroup>
<ProjectReference Include="..\LLama.SemanticKernel\LLamaSharp.SemanticKernel.csproj" />
<ProjectReference Include="..\LLama\LLamaSharp.csproj" />
</ItemGroup>

Expand Down
107 changes: 107 additions & 0 deletions LLama.Unittest/SemanticKernel/ChatRequestSettingsConverterTests.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
using LLamaSharp.SemanticKernel.ChatCompletion;
using System.Text.Json;

namespace LLama.Unittest.SemanticKernel
{
    /// <summary>
    /// Tests for <see cref="ChatRequestSettingsConverter"/>: deserialization of an
    /// empty object (defaults) and of snake_case / PascalCase property names.
    /// </summary>
    public class ChatRequestSettingsConverterTests
    {
        [Fact]
        public void ChatRequestSettingsConverter_DeserializeWithDefaults()
        {
            // Arrange: an empty JSON object must yield a settings instance
            // populated entirely from property defaults.
            var serializerOptions = new JsonSerializerOptions
            {
                Converters = { new ChatRequestSettingsConverter() },
            };

            // Act
            var settings = JsonSerializer.Deserialize<ChatRequestSettings>("{}", serializerOptions);

            // Assert
            Assert.NotNull(settings);
            Assert.Equal(0, settings.FrequencyPenalty);
            Assert.Null(settings.MaxTokens);
            Assert.Equal(0, settings.PresencePenalty);
            Assert.Equal(1, settings.ResultsPerPrompt);
            Assert.NotNull(settings.StopSequences);
            Assert.Empty(settings.StopSequences);
            Assert.Equal(0, settings.Temperature);
            Assert.NotNull(settings.TokenSelectionBiases);
            Assert.Empty(settings.TokenSelectionBiases);
            Assert.Equal(0, settings.TopP);
        }

        [Fact]
        public void ChatRequestSettingsConverter_DeserializeWithSnakeCase()
        {
            // Arrange: snake_case property names, with a trailing comma.
            var serializerOptions = new JsonSerializerOptions
            {
                AllowTrailingCommas = true,
                Converters = { new ChatRequestSettingsConverter() },
            };
            var json = @"{
""frequency_penalty"": 0.5,
""max_tokens"": 250,
""presence_penalty"": 0.5,
""results_per_prompt"": -1,
""stop_sequences"": [ ""foo"", ""bar"" ],
""temperature"": 0.5,
""token_selection_biases"": { ""1"": 2, ""3"": 4 },
""top_p"": 0.5,
}";

            // Act
            var settings = JsonSerializer.Deserialize<ChatRequestSettings>(json, serializerOptions);

            // Assert: every property in the payload is mapped.
            Assert.NotNull(settings);
            Assert.Equal(0.5, settings.FrequencyPenalty);
            Assert.Equal(250, settings.MaxTokens);
            Assert.Equal(0.5, settings.PresencePenalty);
            Assert.Equal(-1, settings.ResultsPerPrompt);
            Assert.NotNull(settings.StopSequences);
            Assert.Contains("foo", settings.StopSequences);
            Assert.Contains("bar", settings.StopSequences);
            Assert.Equal(0.5, settings.Temperature);
            Assert.NotNull(settings.TokenSelectionBiases);
            Assert.Equal(2, settings.TokenSelectionBiases[1]);
            Assert.Equal(4, settings.TokenSelectionBiases[3]);
            Assert.Equal(0.5, settings.TopP);
        }

        [Fact]
        public void ChatRequestSettingsConverter_DeserializeWithPascalCase()
        {
            // Arrange: PascalCase property names, with a trailing comma.
            var serializerOptions = new JsonSerializerOptions
            {
                AllowTrailingCommas = true,
                Converters = { new ChatRequestSettingsConverter() },
            };
            var json = @"{
""FrequencyPenalty"": 0.5,
""MaxTokens"": 250,
""PresencePenalty"": 0.5,
""ResultsPerPrompt"": -1,
""StopSequences"": [ ""foo"", ""bar"" ],
""Temperature"": 0.5,
""TokenSelectionBiases"": { ""1"": 2, ""3"": 4 },
""TopP"": 0.5,
}";

            // Act
            var settings = JsonSerializer.Deserialize<ChatRequestSettings>(json, serializerOptions);

            // Assert: PascalCase names map to the same properties as snake_case.
            Assert.NotNull(settings);
            Assert.Equal(0.5, settings.FrequencyPenalty);
            Assert.Equal(250, settings.MaxTokens);
            Assert.Equal(0.5, settings.PresencePenalty);
            Assert.Equal(-1, settings.ResultsPerPrompt);
            Assert.NotNull(settings.StopSequences);
            Assert.Contains("foo", settings.StopSequences);
            Assert.Contains("bar", settings.StopSequences);
            Assert.Equal(0.5, settings.Temperature);
            Assert.NotNull(settings.TokenSelectionBiases);
            Assert.Equal(2, settings.TokenSelectionBiases[1]);
            Assert.Equal(4, settings.TokenSelectionBiases[3]);
            Assert.Equal(0.5, settings.TopP);
        }
    }
}
Loading
Loading