diff --git a/LLama/LLamaExecutorBase.cs b/LLama/LLamaExecutorBase.cs
index dbd1b593a..18e0168b9 100644
--- a/LLama/LLamaExecutorBase.cs
+++ b/LLama/LLamaExecutorBase.cs
@@ -70,7 +70,7 @@ public abstract class StatefulExecutorBase : ILLamaExecutor
         /// <summary>
         /// Current "mu" value for mirostate sampling
         /// </summary>
-        protected float MirostateMu { get; set; } = float.NaN;
+        protected float? MirostateMu { get; set; }
 
         /// <summary>
         ///
@@ -392,7 +392,7 @@ public class ExecutorBaseState
             public int LastTokensCapacity { get; set; }
 
             [JsonPropertyName("mirostate_mu")]
-            public float MirostateMu { get; set; }
+            public float? MirostateMu { get; set; }
         }
     }
 }
diff --git a/LLama/LLamaModel.cs b/LLama/LLamaModel.cs
index 2bd31199f..4bffda397 100644
--- a/LLama/LLamaModel.cs
+++ b/LLama/LLamaModel.cs
@@ -231,7 +231,7 @@ public void LoadState(State state)
         ///
         ///
         ///
-        public llama_token Sample(LLamaTokenDataArray candidates, ref float mirostat_mu, float temperature = 0.8f, MirostatType mirostat = MirostatType.Disable,
+        public llama_token Sample(LLamaTokenDataArray candidates, ref float? mirostat_mu, float temperature = 0.8f, MirostatType mirostat = MirostatType.Disable,
             float mirostatTau = 5.0f, float mirostatEta = 0.1f, int topK = 40, float topP = 0.95f, float tfsZ = 1.0f, float typicalP = 1.0f)
         {
             llama_token id;
@@ -242,30 +242,31 @@ public llama_token Sample(LLamaTokenDataArray candidates, ref float mirostat_mu,
             }
             else
             {
-                if (float.IsNaN(mirostat_mu))
-                    mirostat_mu = 2 * mirostatTau;
-
-                if (mirostat == MirostatType.Mirostat)
-                {
-                    const int mirostat_m = 100;
-                    SamplingApi.llama_sample_temperature(_ctx, candidates, temperature);
-                    id = SamplingApi.llama_sample_token_mirostat(_ctx, candidates, mirostatTau, mirostatEta, mirostat_m, ref mirostat_mu);
-                }
-                else if (mirostat == MirostatType.Mirostat2)
-                {
-                    SamplingApi.llama_sample_temperature(_ctx, candidates, temperature);
-                    id = SamplingApi.llama_sample_token_mirostat_v2(_ctx, candidates, mirostatTau, mirostatEta, ref mirostat_mu);
-                }
-                else
+                var mu = mirostat_mu ?? (2 * mirostatTau);
                 {
-                    // Temperature sampling
-                    SamplingApi.llama_sample_top_k(_ctx, candidates, topK, 1);
-                    SamplingApi.llama_sample_tail_free(_ctx, candidates, tfsZ, 1);
-                    SamplingApi.llama_sample_typical(_ctx, candidates, typicalP, 1);
-                    SamplingApi.llama_sample_top_p(_ctx, candidates, topP, 1);
-                    SamplingApi.llama_sample_temperature(_ctx, candidates, temperature);
-                    id = SamplingApi.llama_sample_token(_ctx, candidates);
+                    if (mirostat == MirostatType.Mirostat)
+                    {
+                        const int mirostat_m = 100;
+                        SamplingApi.llama_sample_temperature(_ctx, candidates, temperature);
+                        id = SamplingApi.llama_sample_token_mirostat(_ctx, candidates, mirostatTau, mirostatEta, mirostat_m, ref mu);
+                    }
+                    else if (mirostat == MirostatType.Mirostat2)
+                    {
+                        SamplingApi.llama_sample_temperature(_ctx, candidates, temperature);
+                        id = SamplingApi.llama_sample_token_mirostat_v2(_ctx, candidates, mirostatTau, mirostatEta, ref mu);
+                    }
+                    else
+                    {
+                        // Temperature sampling
+                        SamplingApi.llama_sample_top_k(_ctx, candidates, topK, 1);
+                        SamplingApi.llama_sample_tail_free(_ctx, candidates, tfsZ, 1);
+                        SamplingApi.llama_sample_typical(_ctx, candidates, typicalP, 1);
+                        SamplingApi.llama_sample_top_p(_ctx, candidates, topP, 1);
+                        SamplingApi.llama_sample_temperature(_ctx, candidates, temperature);
+                        id = SamplingApi.llama_sample_token(_ctx, candidates);
+                    }
                 }
+                mirostat_mu = mu;
             }
             return id;
         }
diff --git a/LLama/LLamaStatelessExecutor.cs b/LLama/LLamaStatelessExecutor.cs
index 06b5159d2..088e10663 100644
--- a/LLama/LLamaStatelessExecutor.cs
+++ b/LLama/LLamaStatelessExecutor.cs
@@ -57,7 +57,7 @@ public IEnumerable<string> Infer(string text, IInferenceParams? inferenceParams
             lastTokens.AddRange(tokens);
             n_past += n_prompt_tokens;
 
-            var mu = float.NaN;
+            var mu = (float?)null;
             int max_tokens = inferenceParams.MaxTokens < 0 ? int.MaxValue : inferenceParams.MaxTokens;
             for(int i = 0; i < max_tokens; i++)
             {