Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix serialization error due to NaN #88

Merged
merged 2 commits into from
Aug 8, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions LLama/LLamaExecutorBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -68,9 +68,9 @@ public abstract class StatefulExecutorBase : ILLamaExecutor
public LLamaModel Model => _model;

/// <summary>
/// Current "mu" value for mirostate sampling
/// Current "mu" value for mirostat sampling
/// </summary>
protected float MirostateMu { get; set; } = float.NaN;
protected float? MirostatMu { get; set; }

/// <summary>
///
Expand Down Expand Up @@ -391,8 +391,8 @@ public class ExecutorBaseState
[JsonPropertyName("last_tokens_maximum_count")]
public int LastTokensCapacity { get; set; }

[JsonPropertyName("mirostate_mu")]
public float MirostateMu { get; set; }
[JsonPropertyName("mirostat_mu")]
public float? MirostatMu { get; set; }
}
}
}
6 changes: 3 additions & 3 deletions LLama/LLamaInstructExecutor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -50,10 +50,10 @@
InputSuffixTokens = _inp_sfx,
MatchingSessionTokensCount = _n_matching_session_tokens,
PastTokensCount = _pastTokensCount,
SessionFilePath = _pathSession,

Check warning on line 53 in LLama/LLamaInstructExecutor.cs

View workflow job for this annotation

GitHub Actions / Test (linux-release)

Possible null reference assignment.
SessionTokens = _session_tokens,
LastTokensCapacity = _last_n_tokens.Capacity,
MirostateMu = MirostateMu
MirostatMu = MirostatMu
};
return state;
}
Expand Down Expand Up @@ -216,12 +216,12 @@
var tokenDataArray = _model.ApplyPenalty(_last_n_tokens, inferenceParams.LogitBias, repeat_last_n,
inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL);

var mu = MirostateMu;
var mu = MirostatMu;
var id = _model.Sample(
tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau,
inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP
);
MirostateMu = mu;
MirostatMu = mu;

_last_n_tokens.Enqueue(id);

Expand Down
6 changes: 3 additions & 3 deletions LLama/LLamaInteractExecutor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,10 @@
LLamaNewlineTokens = _llama_token_newline,
MatchingSessionTokensCount = _n_matching_session_tokens,
PastTokensCount = _pastTokensCount,
SessionFilePath = _pathSession,

Check warning on line 45 in LLama/LLamaInteractExecutor.cs

View workflow job for this annotation

GitHub Actions / Test (linux-debug)

Possible null reference assignment.
SessionTokens = _session_tokens,
LastTokensCapacity = _last_n_tokens.Capacity,
MirostateMu = MirostateMu
MirostatMu = MirostatMu
};
return state;
}
Expand Down Expand Up @@ -203,12 +203,12 @@
var tokenDataArray = _model.ApplyPenalty(_last_n_tokens, inferenceParams.LogitBias, repeat_last_n,
inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL);

var mu = MirostateMu;
var mu = MirostatMu;
var id = _model.Sample(
tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau,
inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP
);
MirostateMu = mu;
MirostatMu = mu;

_last_n_tokens.Enqueue(id);

Expand Down
47 changes: 24 additions & 23 deletions LLama/LLamaModel.cs
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,7 @@ public void LoadState(State state)
/// <param name="tfsZ"></param>
/// <param name="typicalP"></param>
/// <returns></returns>
public llama_token Sample(LLamaTokenDataArray candidates, ref float mirostat_mu, float temperature = 0.8f, MirostatType mirostat = MirostatType.Disable,
public llama_token Sample(LLamaTokenDataArray candidates, ref float? mirostat_mu, float temperature = 0.8f, MirostatType mirostat = MirostatType.Disable,
float mirostatTau = 5.0f, float mirostatEta = 0.1f, int topK = 40, float topP = 0.95f, float tfsZ = 1.0f, float typicalP = 1.0f)
{
llama_token id;
Expand All @@ -242,30 +242,31 @@ public llama_token Sample(LLamaTokenDataArray candidates, ref float mirostat_mu,
}
else
{
if (float.IsNaN(mirostat_mu))
mirostat_mu = 2 * mirostatTau;

if (mirostat == MirostatType.Mirostat)
{
const int mirostat_m = 100;
SamplingApi.llama_sample_temperature(_ctx, candidates, temperature);
id = SamplingApi.llama_sample_token_mirostat(_ctx, candidates, mirostatTau, mirostatEta, mirostat_m, ref mirostat_mu);
}
else if (mirostat == MirostatType.Mirostat2)
{
SamplingApi.llama_sample_temperature(_ctx, candidates, temperature);
id = SamplingApi.llama_sample_token_mirostat_v2(_ctx, candidates, mirostatTau, mirostatEta, ref mirostat_mu);
}
else
var mu = mirostat_mu ?? (2 * mirostatTau);
{
// Temperature sampling
SamplingApi.llama_sample_top_k(_ctx, candidates, topK, 1);
SamplingApi.llama_sample_tail_free(_ctx, candidates, tfsZ, 1);
SamplingApi.llama_sample_typical(_ctx, candidates, typicalP, 1);
SamplingApi.llama_sample_top_p(_ctx, candidates, topP, 1);
SamplingApi.llama_sample_temperature(_ctx, candidates, temperature);
id = SamplingApi.llama_sample_token(_ctx, candidates);
if (mirostat == MirostatType.Mirostat)
{
const int mirostat_m = 100;
SamplingApi.llama_sample_temperature(_ctx, candidates, temperature);
id = SamplingApi.llama_sample_token_mirostat(_ctx, candidates, mirostatTau, mirostatEta, mirostat_m, ref mu);
}
else if (mirostat == MirostatType.Mirostat2)
{
SamplingApi.llama_sample_temperature(_ctx, candidates, temperature);
id = SamplingApi.llama_sample_token_mirostat_v2(_ctx, candidates, mirostatTau, mirostatEta, ref mu);
}
else
{
// Temperature sampling
SamplingApi.llama_sample_top_k(_ctx, candidates, topK, 1);
SamplingApi.llama_sample_tail_free(_ctx, candidates, tfsZ, 1);
SamplingApi.llama_sample_typical(_ctx, candidates, typicalP, 1);
SamplingApi.llama_sample_top_p(_ctx, candidates, topP, 1);
SamplingApi.llama_sample_temperature(_ctx, candidates, temperature);
id = SamplingApi.llama_sample_token(_ctx, candidates);
}
}
mirostat_mu = mu;
}
return id;
}
Expand Down
2 changes: 1 addition & 1 deletion LLama/LLamaStatelessExecutor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ public IEnumerable<string> Infer(string text, IInferenceParams? inferenceParams
lastTokens.AddRange(tokens);
n_past += n_prompt_tokens;

var mu = float.NaN;
var mu = (float?)null;
int max_tokens = inferenceParams.MaxTokens < 0 ? int.MaxValue : inferenceParams.MaxTokens;
for(int i = 0; i < max_tokens; i++)
{
Expand Down
Loading