diff --git a/LLama/LLamaExecutorBase.cs b/LLama/LLamaExecutorBase.cs
index dbd1b593a..18e0168b9 100644
--- a/LLama/LLamaExecutorBase.cs
+++ b/LLama/LLamaExecutorBase.cs
@@ -70,7 +70,7 @@ public abstract class StatefulExecutorBase : ILLamaExecutor
 /// <summary>
 /// Current "mu" value for mirostate sampling
 /// </summary>
- protected float MirostateMu { get; set; } = float.NaN;
+ protected float? MirostateMu { get; set; }
///
///
@@ -392,7 +392,7 @@ public class ExecutorBaseState
public int LastTokensCapacity { get; set; }
[JsonPropertyName("mirostate_mu")]
- public float MirostateMu { get; set; }
+ public float? MirostateMu { get; set; }
}
}
}
diff --git a/LLama/LLamaModel.cs b/LLama/LLamaModel.cs
index 2bd31199f..4bffda397 100644
--- a/LLama/LLamaModel.cs
+++ b/LLama/LLamaModel.cs
@@ -231,7 +231,7 @@ public void LoadState(State state)
///
///
///
- public llama_token Sample(LLamaTokenDataArray candidates, ref float mirostat_mu, float temperature = 0.8f, MirostatType mirostat = MirostatType.Disable,
+ public llama_token Sample(LLamaTokenDataArray candidates, ref float? mirostat_mu, float temperature = 0.8f, MirostatType mirostat = MirostatType.Disable,
float mirostatTau = 5.0f, float mirostatEta = 0.1f, int topK = 40, float topP = 0.95f, float tfsZ = 1.0f, float typicalP = 1.0f)
{
llama_token id;
@@ -242,30 +242,31 @@ public llama_token Sample(LLamaTokenDataArray candidates, ref float mirostat_mu,
}
else
{
- if (float.IsNaN(mirostat_mu))
- mirostat_mu = 2 * mirostatTau;
-
- if (mirostat == MirostatType.Mirostat)
- {
- const int mirostat_m = 100;
- SamplingApi.llama_sample_temperature(_ctx, candidates, temperature);
- id = SamplingApi.llama_sample_token_mirostat(_ctx, candidates, mirostatTau, mirostatEta, mirostat_m, ref mirostat_mu);
- }
- else if (mirostat == MirostatType.Mirostat2)
- {
- SamplingApi.llama_sample_temperature(_ctx, candidates, temperature);
- id = SamplingApi.llama_sample_token_mirostat_v2(_ctx, candidates, mirostatTau, mirostatEta, ref mirostat_mu);
- }
- else
+ var mu = mirostat_mu ?? (2 * mirostatTau);
{
- // Temperature sampling
- SamplingApi.llama_sample_top_k(_ctx, candidates, topK, 1);
- SamplingApi.llama_sample_tail_free(_ctx, candidates, tfsZ, 1);
- SamplingApi.llama_sample_typical(_ctx, candidates, typicalP, 1);
- SamplingApi.llama_sample_top_p(_ctx, candidates, topP, 1);
- SamplingApi.llama_sample_temperature(_ctx, candidates, temperature);
- id = SamplingApi.llama_sample_token(_ctx, candidates);
+ if (mirostat == MirostatType.Mirostat)
+ {
+ const int mirostat_m = 100;
+ SamplingApi.llama_sample_temperature(_ctx, candidates, temperature);
+ id = SamplingApi.llama_sample_token_mirostat(_ctx, candidates, mirostatTau, mirostatEta, mirostat_m, ref mu);
+ }
+ else if (mirostat == MirostatType.Mirostat2)
+ {
+ SamplingApi.llama_sample_temperature(_ctx, candidates, temperature);
+ id = SamplingApi.llama_sample_token_mirostat_v2(_ctx, candidates, mirostatTau, mirostatEta, ref mu);
+ }
+ else
+ {
+ // Temperature sampling
+ SamplingApi.llama_sample_top_k(_ctx, candidates, topK, 1);
+ SamplingApi.llama_sample_tail_free(_ctx, candidates, tfsZ, 1);
+ SamplingApi.llama_sample_typical(_ctx, candidates, typicalP, 1);
+ SamplingApi.llama_sample_top_p(_ctx, candidates, topP, 1);
+ SamplingApi.llama_sample_temperature(_ctx, candidates, temperature);
+ id = SamplingApi.llama_sample_token(_ctx, candidates);
+ }
}
+ mirostat_mu = mu;
}
return id;
}
diff --git a/LLama/LLamaStatelessExecutor.cs b/LLama/LLamaStatelessExecutor.cs
index 06b5159d2..088e10663 100644
--- a/LLama/LLamaStatelessExecutor.cs
+++ b/LLama/LLamaStatelessExecutor.cs
@@ -57,7 +57,7 @@ public IEnumerable Infer(string text, IInferenceParams? inferenceParams
lastTokens.AddRange(tokens);
n_past += n_prompt_tokens;
- var mu = float.NaN;
+ var mu = (float?)null;
int max_tokens = inferenceParams.MaxTokens < 0 ? int.MaxValue : inferenceParams.MaxTokens;
for(int i = 0; i < max_tokens; i++)
{