diff --git a/LLama/LLamaExecutorBase.cs b/LLama/LLamaExecutorBase.cs
index dbd1b593a..2caaa8e50 100644
--- a/LLama/LLamaExecutorBase.cs
+++ b/LLama/LLamaExecutorBase.cs
@@ -68,9 +68,9 @@ public abstract class StatefulExecutorBase : ILLamaExecutor
public LLamaModel Model => _model;
/// <summary>
- /// Current "mu" value for mirostate sampling
+ /// Current "mu" value for mirostat sampling
/// </summary>
- protected float MirostateMu { get; set; } = float.NaN;
+ protected float? MirostatMu { get; set; }
///
///
@@ -391,8 +391,8 @@ public class ExecutorBaseState
[JsonPropertyName("last_tokens_maximum_count")]
public int LastTokensCapacity { get; set; }
- [JsonPropertyName("mirostate_mu")]
- public float MirostateMu { get; set; }
+ [JsonPropertyName("mirostat_mu")]
+ public float? MirostatMu { get; set; }
}
}
}
diff --git a/LLama/LLamaInstructExecutor.cs b/LLama/LLamaInstructExecutor.cs
index e055c1475..5c976b534 100644
--- a/LLama/LLamaInstructExecutor.cs
+++ b/LLama/LLamaInstructExecutor.cs
@@ -53,7 +53,7 @@ public override ExecutorBaseState GetStateData()
SessionFilePath = _pathSession,
SessionTokens = _session_tokens,
LastTokensCapacity = _last_n_tokens.Capacity,
- MirostateMu = MirostateMu
+ MirostatMu = MirostatMu
};
return state;
}
@@ -216,12 +216,12 @@ protected override void InferInternal(IInferenceParams inferenceParams, InferSta
var tokenDataArray = _model.ApplyPenalty(_last_n_tokens, inferenceParams.LogitBias, repeat_last_n,
inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL);
- var mu = MirostateMu;
+ var mu = MirostatMu;
var id = _model.Sample(
tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau,
inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP
);
- MirostateMu = mu;
+ MirostatMu = mu;
_last_n_tokens.Enqueue(id);
diff --git a/LLama/LLamaInteractExecutor.cs b/LLama/LLamaInteractExecutor.cs
index f5c1583ec..6a75e126d 100644
--- a/LLama/LLamaInteractExecutor.cs
+++ b/LLama/LLamaInteractExecutor.cs
@@ -45,7 +45,7 @@ public override ExecutorBaseState GetStateData()
SessionFilePath = _pathSession,
SessionTokens = _session_tokens,
LastTokensCapacity = _last_n_tokens.Capacity,
- MirostateMu = MirostateMu
+ MirostatMu = MirostatMu
};
return state;
}
@@ -203,12 +203,12 @@ protected override void InferInternal(IInferenceParams inferenceParams, InferSta
var tokenDataArray = _model.ApplyPenalty(_last_n_tokens, inferenceParams.LogitBias, repeat_last_n,
inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL);
- var mu = MirostateMu;
+ var mu = MirostatMu;
var id = _model.Sample(
tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau,
inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP
);
- MirostateMu = mu;
+ MirostatMu = mu;
_last_n_tokens.Enqueue(id);
diff --git a/LLama/LLamaModel.cs b/LLama/LLamaModel.cs
index 2bd31199f..4bffda397 100644
--- a/LLama/LLamaModel.cs
+++ b/LLama/LLamaModel.cs
@@ -231,7 +231,7 @@ public void LoadState(State state)
///
///
///
- public llama_token Sample(LLamaTokenDataArray candidates, ref float mirostat_mu, float temperature = 0.8f, MirostatType mirostat = MirostatType.Disable,
+ public llama_token Sample(LLamaTokenDataArray candidates, ref float? mirostat_mu, float temperature = 0.8f, MirostatType mirostat = MirostatType.Disable,
float mirostatTau = 5.0f, float mirostatEta = 0.1f, int topK = 40, float topP = 0.95f, float tfsZ = 1.0f, float typicalP = 1.0f)
{
llama_token id;
@@ -242,30 +242,31 @@ public llama_token Sample(LLamaTokenDataArray candidates, ref float mirostat_mu,
}
else
{
- if (float.IsNaN(mirostat_mu))
- mirostat_mu = 2 * mirostatTau;
-
- if (mirostat == MirostatType.Mirostat)
- {
- const int mirostat_m = 100;
- SamplingApi.llama_sample_temperature(_ctx, candidates, temperature);
- id = SamplingApi.llama_sample_token_mirostat(_ctx, candidates, mirostatTau, mirostatEta, mirostat_m, ref mirostat_mu);
- }
- else if (mirostat == MirostatType.Mirostat2)
- {
- SamplingApi.llama_sample_temperature(_ctx, candidates, temperature);
- id = SamplingApi.llama_sample_token_mirostat_v2(_ctx, candidates, mirostatTau, mirostatEta, ref mirostat_mu);
- }
- else
+ var mu = mirostat_mu ?? (2 * mirostatTau);
{
- // Temperature sampling
- SamplingApi.llama_sample_top_k(_ctx, candidates, topK, 1);
- SamplingApi.llama_sample_tail_free(_ctx, candidates, tfsZ, 1);
- SamplingApi.llama_sample_typical(_ctx, candidates, typicalP, 1);
- SamplingApi.llama_sample_top_p(_ctx, candidates, topP, 1);
- SamplingApi.llama_sample_temperature(_ctx, candidates, temperature);
- id = SamplingApi.llama_sample_token(_ctx, candidates);
+ if (mirostat == MirostatType.Mirostat)
+ {
+ const int mirostat_m = 100;
+ SamplingApi.llama_sample_temperature(_ctx, candidates, temperature);
+ id = SamplingApi.llama_sample_token_mirostat(_ctx, candidates, mirostatTau, mirostatEta, mirostat_m, ref mu);
+ }
+ else if (mirostat == MirostatType.Mirostat2)
+ {
+ SamplingApi.llama_sample_temperature(_ctx, candidates, temperature);
+ id = SamplingApi.llama_sample_token_mirostat_v2(_ctx, candidates, mirostatTau, mirostatEta, ref mu);
+ }
+ else
+ {
+ // Temperature sampling
+ SamplingApi.llama_sample_top_k(_ctx, candidates, topK, 1);
+ SamplingApi.llama_sample_tail_free(_ctx, candidates, tfsZ, 1);
+ SamplingApi.llama_sample_typical(_ctx, candidates, typicalP, 1);
+ SamplingApi.llama_sample_top_p(_ctx, candidates, topP, 1);
+ SamplingApi.llama_sample_temperature(_ctx, candidates, temperature);
+ id = SamplingApi.llama_sample_token(_ctx, candidates);
+ }
}
+ mirostat_mu = mu;
}
return id;
}
diff --git a/LLama/LLamaStatelessExecutor.cs b/LLama/LLamaStatelessExecutor.cs
index 06b5159d2..088e10663 100644
--- a/LLama/LLamaStatelessExecutor.cs
+++ b/LLama/LLamaStatelessExecutor.cs
@@ -57,7 +57,7 @@ public IEnumerable Infer(string text, IInferenceParams? inferenceParams
lastTokens.AddRange(tokens);
n_past += n_prompt_tokens;
- var mu = float.NaN;
+ var mu = (float?)null;
int max_tokens = inferenceParams.MaxTokens < 0 ? int.MaxValue : inferenceParams.MaxTokens;
for(int i = 0; i < max_tokens; i++)
{