diff --git a/LLama/LLamaExecutorBase.cs b/LLama/LLamaExecutorBase.cs
index dbd1b593a..2caaa8e50 100644
--- a/LLama/LLamaExecutorBase.cs
+++ b/LLama/LLamaExecutorBase.cs
@@ -68,9 +68,9 @@ public abstract class StatefulExecutorBase : ILLamaExecutor
         public LLamaModel Model => _model;
 
         /// <summary>
-        /// Current "mu" value for mirostate sampling
+        /// Current "mu" value for mirostat sampling
         /// </summary>
-        protected float MirostateMu { get; set; } = float.NaN;
+        protected float? MirostatMu { get; set; }
 
         /// <summary>
         ///
@@ -391,8 +391,8 @@ public class ExecutorBaseState
            [JsonPropertyName("last_tokens_maximum_count")]
            public int LastTokensCapacity { get; set; }
 
-            [JsonPropertyName("mirostate_mu")]
-            public float MirostateMu { get; set; }
+            [JsonPropertyName("mirostat_mu")]
+            public float? MirostatMu { get; set; }
        }
    }
}
diff --git a/LLama/LLamaInstructExecutor.cs b/LLama/LLamaInstructExecutor.cs
index e055c1475..5c976b534 100644
--- a/LLama/LLamaInstructExecutor.cs
+++ b/LLama/LLamaInstructExecutor.cs
@@ -53,7 +53,7 @@ public override ExecutorBaseState GetStateData()
                 SessionFilePath = _pathSession,
                 SessionTokens = _session_tokens,
                 LastTokensCapacity = _last_n_tokens.Capacity,
-                MirostateMu = MirostateMu
+                MirostatMu = MirostatMu
             };
             return state;
         }
@@ -216,12 +216,12 @@ protected override void InferInternal(IInferenceParams inferenceParams, InferSta
                 var tokenDataArray = _model.ApplyPenalty(_last_n_tokens, inferenceParams.LogitBias, repeat_last_n,
                     inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL);
 
-                var mu = MirostateMu;
+                var mu = MirostatMu;
                 var id = _model.Sample(
                     tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat,
                     inferenceParams.MirostatTau, inferenceParams.MirostatEta, inferenceParams.TopK,
                     inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP
                 );
-                MirostateMu = mu;
+                MirostatMu = mu;
 
                 _last_n_tokens.Enqueue(id);
diff --git a/LLama/LLamaInteractExecutor.cs b/LLama/LLamaInteractExecutor.cs
index f5c1583ec..6a75e126d 100644
--- a/LLama/LLamaInteractExecutor.cs
+++ b/LLama/LLamaInteractExecutor.cs
@@ -45,7 +45,7 @@ public override ExecutorBaseState GetStateData()
                 SessionFilePath = _pathSession,
                 SessionTokens = _session_tokens,
                 LastTokensCapacity = _last_n_tokens.Capacity,
-                MirostateMu = MirostateMu
+                MirostatMu = MirostatMu
             };
             return state;
         }
@@ -203,12 +203,12 @@ protected override void InferInternal(IInferenceParams inferenceParams, InferSta
                 var tokenDataArray = _model.ApplyPenalty(_last_n_tokens, inferenceParams.LogitBias, repeat_last_n,
                     inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL);
 
-                var mu = MirostateMu;
+                var mu = MirostatMu;
                 var id = _model.Sample(
                     tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat,
                     inferenceParams.MirostatTau, inferenceParams.MirostatEta, inferenceParams.TopK,
                     inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP
                 );
-                MirostateMu = mu;
+                MirostatMu = mu;
 
                 _last_n_tokens.Enqueue(id);
diff --git a/LLama/LLamaModel.cs b/LLama/LLamaModel.cs
index 2bd31199f..4bffda397 100644
--- a/LLama/LLamaModel.cs
+++ b/LLama/LLamaModel.cs
@@ -231,7 +231,7 @@ public void LoadState(State state)
        ///
        ///
        ///
-        public llama_token Sample(LLamaTokenDataArray candidates, ref float mirostat_mu, float temperature = 0.8f, MirostatType mirostat = MirostatType.Disable,
+        public llama_token Sample(LLamaTokenDataArray candidates, ref float? mirostat_mu, float temperature = 0.8f, MirostatType mirostat = MirostatType.Disable,
             float mirostatTau = 5.0f, float mirostatEta = 0.1f, int topK = 40, float topP = 0.95f, float tfsZ = 1.0f, float typicalP = 1.0f)
         {
             llama_token id;
@@ -242,30 +242,31 @@ public llama_token Sample(LLamaTokenDataArray candidates, ref float mirostat_mu,
             }
             else
             {
-                if (float.IsNaN(mirostat_mu))
-                    mirostat_mu = 2 * mirostatTau;
-
-                if (mirostat == MirostatType.Mirostat)
-                {
-                    const int mirostat_m = 100;
-                    SamplingApi.llama_sample_temperature(_ctx, candidates, temperature);
-                    id = SamplingApi.llama_sample_token_mirostat(_ctx, candidates, mirostatTau, mirostatEta, mirostat_m, ref mirostat_mu);
-                }
-                else if (mirostat == MirostatType.Mirostat2)
-                {
-                    SamplingApi.llama_sample_temperature(_ctx, candidates, temperature);
-                    id = SamplingApi.llama_sample_token_mirostat_v2(_ctx, candidates, mirostatTau, mirostatEta, ref mirostat_mu);
-                }
-                else
+                var mu = mirostat_mu ?? (2 * mirostatTau);
                 {
-                    // Temperature sampling
-                    SamplingApi.llama_sample_top_k(_ctx, candidates, topK, 1);
-                    SamplingApi.llama_sample_tail_free(_ctx, candidates, tfsZ, 1);
-                    SamplingApi.llama_sample_typical(_ctx, candidates, typicalP, 1);
-                    SamplingApi.llama_sample_top_p(_ctx, candidates, topP, 1);
-                    SamplingApi.llama_sample_temperature(_ctx, candidates, temperature);
-                    id = SamplingApi.llama_sample_token(_ctx, candidates);
+                    if (mirostat == MirostatType.Mirostat)
+                    {
+                        const int mirostat_m = 100;
+                        SamplingApi.llama_sample_temperature(_ctx, candidates, temperature);
+                        id = SamplingApi.llama_sample_token_mirostat(_ctx, candidates, mirostatTau, mirostatEta, mirostat_m, ref mu);
+                    }
+                    else if (mirostat == MirostatType.Mirostat2)
+                    {
+                        SamplingApi.llama_sample_temperature(_ctx, candidates, temperature);
+                        id = SamplingApi.llama_sample_token_mirostat_v2(_ctx, candidates, mirostatTau, mirostatEta, ref mu);
+                    }
+                    else
+                    {
+                        // Temperature sampling
+                        SamplingApi.llama_sample_top_k(_ctx, candidates, topK, 1);
+                        SamplingApi.llama_sample_tail_free(_ctx, candidates, tfsZ, 1);
+                        SamplingApi.llama_sample_typical(_ctx, candidates, typicalP, 1);
+                        SamplingApi.llama_sample_top_p(_ctx, candidates, topP, 1);
+                        SamplingApi.llama_sample_temperature(_ctx, candidates, temperature);
+                        id = SamplingApi.llama_sample_token(_ctx, candidates);
+                    }
                 }
+                mirostat_mu = mu;
             }
             return id;
         }
diff --git a/LLama/LLamaStatelessExecutor.cs b/LLama/LLamaStatelessExecutor.cs
index 06b5159d2..088e10663 100644
--- a/LLama/LLamaStatelessExecutor.cs
+++ b/LLama/LLamaStatelessExecutor.cs
@@ -57,7 +57,7 @@ public IEnumerable<string> Infer(string text, IInferenceParams? inferenceParams
             lastTokens.AddRange(tokens);
             n_past += n_prompt_tokens;
 
-            var mu = float.NaN;
+            var mu = (float?)null;
             int max_tokens = inferenceParams.MaxTokens < 0 ? int.MaxValue : inferenceParams.MaxTokens;
             for(int i = 0; i < max_tokens; i++)
             {
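For context, the core change above replaces the `float.NaN` sentinel for the mirostat "mu" state with a nullable `float?`, seeded on first use via `??`. A minimal standalone sketch of that pattern follows; `MirostatMuDemo`, `FakeMirostatStep`, and its update rule are hypothetical stand-ins for illustration, not part of LLamaSharp or llama.cpp:

using System;

class MirostatMuDemo
{
    // Hypothetical stand-in for SamplingApi.llama_sample_token_mirostat:
    // only nudges mu toward the target surprise (tau) and returns a dummy id.
    static int FakeMirostatStep(float tau, float eta, ref float mu)
    {
        mu -= eta * (mu - tau); // crude feedback update, for illustration only
        return 0;               // a real sampler returns the chosen token id
    }

    // Mirrors the patched Sample(): a null mu means "first call", so it is
    // seeded with 2 * tau via ?? instead of the old float.IsNaN sentinel check.
    static int Sample(ref float? mirostatMu, float tau = 5.0f, float eta = 0.1f)
    {
        var mu = mirostatMu ?? (2 * tau);            // seed on first use
        var id = FakeMirostatStep(tau, eta, ref mu);
        mirostatMu = mu;                             // persist state for the next call
        return id;
    }

    static void Main()
    {
        float? mu = null; // equivalent of a fresh executor's MirostatMu
        for (int i = 0; i < 3; i++)
        {
            Sample(ref mu);
            Console.WriteLine($"step {i}: mu = {mu}");
        }
    }
}

A likely payoff of `float?` over the NaN sentinel shows up in `ExecutorBaseState`: under default options `System.Text.Json` cannot write `NaN` as a JSON number, whereas a nullable float round-trips cleanly as `null`. It also explains the stateless executor's change from `var mu = float.NaN;` to `var mu = (float?)null;`, since `ref float?` callers must hold the state in a `float?` local.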