Merge pull request #88 from martindevans/fix_serialization_nan
Fix serialization error due to NaN
martindevans authored Aug 8, 2023
2 parents f612275 + b5de3ee · commit 270c6d5
Showing 5 changed files with 35 additions and 34 deletions.
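
For context, a minimal sketch (not code from this repository) of the failure mode this commit fixes: JSON has no literal for NaN, so System.Text.Json refuses to serialize the old float.NaN sentinel by default, while the nullable replacement serializes as JSON null.

```csharp
// Hypothetical repro: serializing float.NaN with System.Text.Json defaults
// throws, because NaN has no valid JSON encoding.
using System;
using System.Text.Json;

class NaNRepro
{
    static void Main()
    {
        try
        {
            // The old state shape: MirostateMu defaulted to float.NaN.
            JsonSerializer.Serialize(new { mirostate_mu = float.NaN });
        }
        catch (ArgumentException e) // the runtime rejects non-finite numbers
        {
            Console.WriteLine($"Old sentinel fails: {e.Message}");
        }

        // The fix: a nullable float serializes cleanly as JSON null.
        Console.WriteLine(JsonSerializer.Serialize(new { mirostat_mu = (float?)null }));
        // => {"mirostat_mu":null}
    }
}
```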
8 changes: 4 additions & 4 deletions LLama/LLamaExecutorBase.cs
@@ -68,9 +68,9 @@ public abstract class StatefulExecutorBase : ILLamaExecutor
         public LLamaModel Model => _model;

         /// <summary>
-        /// Current "mu" value for mirostate sampling
+        /// Current "mu" value for mirostat sampling
         /// </summary>
-        protected float MirostateMu { get; set; } = float.NaN;
+        protected float? MirostatMu { get; set; }

         /// <summary>
         ///
@@ -391,8 +391,8 @@ public class ExecutorBaseState
             [JsonPropertyName("last_tokens_maximum_count")]
             public int LastTokensCapacity { get; set; }

-            [JsonPropertyName("mirostate_mu")]
-            public float MirostateMu { get; set; }
+            [JsonPropertyName("mirostat_mu")]
+            public float? MirostatMu { get; set; }
         }
     }
 }
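
A quick round-trip sketch (reduced to the one property changed above; an illustration, not the full class) showing that the renamed nullable field survives serialization:

```csharp
using System;
using System.Text.Json;
using System.Text.Json.Serialization;

class ExecutorBaseState
{
    [JsonPropertyName("mirostat_mu")]
    public float? MirostatMu { get; set; }
}

class RoundTrip
{
    static void Main()
    {
        // Fresh state: MirostatMu is null rather than the old NaN sentinel.
        var json = JsonSerializer.Serialize(new ExecutorBaseState());
        Console.WriteLine(json); // {"mirostat_mu":null}

        // Deserializing restores the "no mu yet" state losslessly.
        var state = JsonSerializer.Deserialize<ExecutorBaseState>(json);
        Console.WriteLine(state!.MirostatMu.HasValue); // False
    }
}
```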
6 changes: 3 additions & 3 deletions LLama/LLamaInstructExecutor.cs
@@ -53,7 +53,7 @@ public override ExecutorBaseState GetStateData()
                 SessionFilePath = _pathSession,
                 SessionTokens = _session_tokens,
                 LastTokensCapacity = _last_n_tokens.Capacity,
-                MirostateMu = MirostateMu
+                MirostatMu = MirostatMu
             };
             return state;
         }
@@ -214,12 +214,12 @@ protected override void InferInternal(IInferenceParams inferenceParams, InferSta
             var tokenDataArray = _model.ApplyPenalty(_last_n_tokens, inferenceParams.LogitBias, repeat_last_n,
                 inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL);

-            var mu = MirostateMu;
+            var mu = MirostatMu;
             var id = _model.Sample(
                 tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau,
                 inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP
             );
-            MirostateMu = mu;
+            MirostatMu = mu;

             _last_n_tokens.Enqueue(id);
6 changes: 3 additions & 3 deletions LLama/LLamaInteractExecutor.cs
@@ -45,7 +45,7 @@ public override ExecutorBaseState GetStateData()
                 SessionFilePath = _pathSession,
                 SessionTokens = _session_tokens,
                 LastTokensCapacity = _last_n_tokens.Capacity,
-                MirostateMu = MirostateMu
+                MirostatMu = MirostatMu
             };
             return state;
         }
@@ -203,12 +203,12 @@ protected override void InferInternal(IInferenceParams inferenceParams, InferSta
             var tokenDataArray = _model.ApplyPenalty(_last_n_tokens, inferenceParams.LogitBias, repeat_last_n,
                 inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL);

-            var mu = MirostateMu;
+            var mu = MirostatMu;
             var id = _model.Sample(
                 tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau,
                 inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP
             );
-            MirostateMu = mu;
+            MirostatMu = mu;

             _last_n_tokens.Enqueue(id);
47 changes: 24 additions & 23 deletions LLama/LLamaModel.cs
@@ -228,7 +228,7 @@ public void LoadState(State state)
         /// <param name="tfsZ"></param>
         /// <param name="typicalP"></param>
         /// <returns></returns>
-        public llama_token Sample(LLamaTokenDataArray candidates, ref float mirostat_mu, float temperature = 0.8f, MirostatType mirostat = MirostatType.Disable,
+        public llama_token Sample(LLamaTokenDataArray candidates, ref float? mirostat_mu, float temperature = 0.8f, MirostatType mirostat = MirostatType.Disable,
             float mirostatTau = 5.0f, float mirostatEta = 0.1f, int topK = 40, float topP = 0.95f, float tfsZ = 1.0f, float typicalP = 1.0f)
         {
             llama_token id;
@@ -239,30 +239,31 @@ public llama_token Sample(LLamaTokenDataArray candidates, ref float mirostat_mu,
             }
             else
             {
-                if (float.IsNaN(mirostat_mu))
-                    mirostat_mu = 2 * mirostatTau;
+                var mu = mirostat_mu ?? (2 * mirostatTau);

                 if (mirostat == MirostatType.Mirostat)
                 {
                     const int mirostat_m = 100;
                     SamplingApi.llama_sample_temperature(_ctx, candidates, temperature);
-                    id = SamplingApi.llama_sample_token_mirostat(_ctx, candidates, mirostatTau, mirostatEta, mirostat_m, ref mirostat_mu);
+                    id = SamplingApi.llama_sample_token_mirostat(_ctx, candidates, mirostatTau, mirostatEta, mirostat_m, ref mu);
                 }
                 else if (mirostat == MirostatType.Mirostat2)
                 {
                     SamplingApi.llama_sample_temperature(_ctx, candidates, temperature);
-                    id = SamplingApi.llama_sample_token_mirostat_v2(_ctx, candidates, mirostatTau, mirostatEta, ref mirostat_mu);
+                    id = SamplingApi.llama_sample_token_mirostat_v2(_ctx, candidates, mirostatTau, mirostatEta, ref mu);
                 }
                 else
                 {
                     // Temperature sampling
                     SamplingApi.llama_sample_top_k(_ctx, candidates, topK, 1);
                     SamplingApi.llama_sample_tail_free(_ctx, candidates, tfsZ, 1);
                     SamplingApi.llama_sample_typical(_ctx, candidates, typicalP, 1);
                     SamplingApi.llama_sample_top_p(_ctx, candidates, topP, 1);
                     SamplingApi.llama_sample_temperature(_ctx, candidates, temperature);
                     id = SamplingApi.llama_sample_token(_ctx, candidates);
                 }
+                mirostat_mu = mu;
             }
             return id;
         }
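
The executors above all follow the same updated pattern around the new `ref float?` parameter; here is a sketch of that caller contract (the using directives and surrounding setup are assumptions about the library at this commit, not code from the repository):

```csharp
using LLama;         // LLamaModel (assumed namespace)
using LLama.Common;  // MirostatType (assumed namespace)
using LLama.Native;  // LLamaTokenDataArray (assumed namespace)

class MirostatCaller
{
    // Persisted across inference steps; null means "mu not seeded yet".
    private float? _mirostatMu;

    public int SampleNext(LLamaModel model, LLamaTokenDataArray candidates)
    {
        var mu = _mirostatMu;                  // copy the nullable state into a local
        var id = model.Sample(candidates, ref mu,
            mirostat: MirostatType.Mirostat2); // Sample seeds mu to 2 * mirostatTau when null
        _mirostatMu = mu;                      // write the updated mu back for the next token
        return id;
    }
}
```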
2 changes: 1 addition & 1 deletion LLama/LLamaStatelessExecutor.cs
@@ -57,7 +57,7 @@ public IEnumerable<string> Infer(string text, IInferenceParams? inferenceParams
             lastTokens.AddRange(tokens);
             n_past += n_prompt_tokens;

-            var mu = float.NaN;
+            var mu = (float?)null;
             int max_tokens = inferenceParams.MaxTokens < 0 ? int.MaxValue : inferenceParams.MaxTokens;
             for(int i = 0; i < max_tokens; i++)
             {
