Skip to content

Commit

Permalink
Using a nullable float instead of NaN, this should fix the serializat…
Browse files Browse the repository at this point in the history
…ion issue reported in #85
  • Loading branch information
martindevans committed Aug 7, 2023
1 parent 250c024 commit be52737
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 26 deletions.
4 changes: 2 additions & 2 deletions LLama/LLamaExecutorBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ public abstract class StatefulExecutorBase : ILLamaExecutor
/// <summary>
/// Current "mu" value for mirostate sampling
/// </summary>
protected float MirostateMu { get; set; } = float.NaN;
protected float? MirostateMu { get; set; }

/// <summary>
///
Expand Down Expand Up @@ -392,7 +392,7 @@ public class ExecutorBaseState
public int LastTokensCapacity { get; set; }

[JsonPropertyName("mirostate_mu")]
public float MirostateMu { get; set; }
public float? MirostateMu { get; set; }
}
}
}
47 changes: 24 additions & 23 deletions LLama/LLamaModel.cs
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,7 @@ public void LoadState(State state)
/// <param name="tfsZ"></param>
/// <param name="typicalP"></param>
/// <returns></returns>
public llama_token Sample(LLamaTokenDataArray candidates, ref float mirostat_mu, float temperature = 0.8f, MirostatType mirostat = MirostatType.Disable,
public llama_token Sample(LLamaTokenDataArray candidates, ref float? mirostat_mu, float temperature = 0.8f, MirostatType mirostat = MirostatType.Disable,
float mirostatTau = 5.0f, float mirostatEta = 0.1f, int topK = 40, float topP = 0.95f, float tfsZ = 1.0f, float typicalP = 1.0f)
{
llama_token id;
Expand All @@ -242,30 +242,31 @@ public llama_token Sample(LLamaTokenDataArray candidates, ref float mirostat_mu,
}
else
{
if (float.IsNaN(mirostat_mu))
mirostat_mu = 2 * mirostatTau;

if (mirostat == MirostatType.Mirostat)
{
const int mirostat_m = 100;
SamplingApi.llama_sample_temperature(_ctx, candidates, temperature);
id = SamplingApi.llama_sample_token_mirostat(_ctx, candidates, mirostatTau, mirostatEta, mirostat_m, ref mirostat_mu);
}
else if (mirostat == MirostatType.Mirostat2)
{
SamplingApi.llama_sample_temperature(_ctx, candidates, temperature);
id = SamplingApi.llama_sample_token_mirostat_v2(_ctx, candidates, mirostatTau, mirostatEta, ref mirostat_mu);
}
else
var mu = mirostat_mu ?? (2 * mirostatTau);
{
// Temperature sampling
SamplingApi.llama_sample_top_k(_ctx, candidates, topK, 1);
SamplingApi.llama_sample_tail_free(_ctx, candidates, tfsZ, 1);
SamplingApi.llama_sample_typical(_ctx, candidates, typicalP, 1);
SamplingApi.llama_sample_top_p(_ctx, candidates, topP, 1);
SamplingApi.llama_sample_temperature(_ctx, candidates, temperature);
id = SamplingApi.llama_sample_token(_ctx, candidates);
if (mirostat == MirostatType.Mirostat)
{
const int mirostat_m = 100;
SamplingApi.llama_sample_temperature(_ctx, candidates, temperature);
id = SamplingApi.llama_sample_token_mirostat(_ctx, candidates, mirostatTau, mirostatEta, mirostat_m, ref mu);
}
else if (mirostat == MirostatType.Mirostat2)
{
SamplingApi.llama_sample_temperature(_ctx, candidates, temperature);
id = SamplingApi.llama_sample_token_mirostat_v2(_ctx, candidates, mirostatTau, mirostatEta, ref mu);
}
else
{
// Temperature sampling
SamplingApi.llama_sample_top_k(_ctx, candidates, topK, 1);
SamplingApi.llama_sample_tail_free(_ctx, candidates, tfsZ, 1);
SamplingApi.llama_sample_typical(_ctx, candidates, typicalP, 1);
SamplingApi.llama_sample_top_p(_ctx, candidates, topP, 1);
SamplingApi.llama_sample_temperature(_ctx, candidates, temperature);
id = SamplingApi.llama_sample_token(_ctx, candidates);
}
}
mirostat_mu = mu;
}
return id;
}
Expand Down
2 changes: 1 addition & 1 deletion LLama/LLamaStatelessExecutor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ public IEnumerable<string> Infer(string text, IInferenceParams? inferenceParams
lastTokens.AddRange(tokens);
n_past += n_prompt_tokens;

var mu = float.NaN;
var mu = (float?)null;
int max_tokens = inferenceParams.MaxTokens < 0 ? int.MaxValue : inferenceParams.MaxTokens;
for(int i = 0; i < max_tokens; i++)
{
Expand Down

0 comments on commit be52737

Please sign in to comment.