Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Two small improvements to the native sampling API #124

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions LLama/Native/NativeApi.Sampling.cs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ public unsafe partial class NativeApi
/// <param name="last_tokens_size"></param>
/// <param name="penalty"></param>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
// last_tokens is a raw pointer so callers can pass pinned Memory<T>/Span<T> data without an array copy.
public static extern void llama_sample_repetition_penalty(SafeLLamaContextHandle ctx, ref LLamaTokenDataArrayNative candidates, llama_token* last_tokens, ulong last_tokens_size, float penalty);

/// <summary>
/// Frequency and presence penalties described in OpenAI API https://platform.openai.com/docs/api-reference/parameter-details.
Expand All @@ -38,7 +38,7 @@ public unsafe partial class NativeApi
/// <param name="alpha_frequency"></param>
/// <param name="alpha_presence"></param>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
// last_tokens is a raw pointer so callers can pass pinned Memory<T>/Span<T> data without an array copy.
public static extern void llama_sample_frequency_and_presence_penalties(SafeLLamaContextHandle ctx, ref LLamaTokenDataArrayNative candidates, llama_token* last_tokens, ulong last_tokens_size, float alpha_frequency, float alpha_presence);

/// <summary>
/// Apply classifier-free guidance to the logits as described in academic paper "Stay on topic with Classifier-Free Guidance" https://arxiv.org/abs/2306.17806
Expand Down Expand Up @@ -118,7 +118,7 @@ public unsafe partial class NativeApi
/// <param name="mu">Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal.</param>
/// <returns></returns>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
// mu is passed by ref (marshalled as a pointer) because the native side updates it across calls.
public static extern llama_token llama_sample_token_mirostat(SafeLLamaContextHandle ctx, ref LLamaTokenDataArrayNative candidates, float tau, float eta, int m, ref float mu);

/// <summary>
/// Mirostat 2.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
Expand All @@ -130,7 +130,7 @@ public unsafe partial class NativeApi
/// <param name="mu">Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal.</param>
/// <returns></returns>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
// mu is passed by ref (marshalled as a pointer) because the native side updates it across calls.
public static extern llama_token llama_sample_token_mirostat_v2(SafeLLamaContextHandle ctx, ref LLamaTokenDataArrayNative candidates, float tau, float eta, ref float mu);

/// <summary>
/// Selects the token with the highest probability.
Expand Down
22 changes: 10 additions & 12 deletions LLama/Native/SamplingApi.cs
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,12 @@ public static void llama_sample_grammar(SafeLLamaContextHandle ctx, LLamaTokenDa
/// <param name="last_tokens"></param>
/// <param name="last_tokens_size"></param>
/// <param name="penalty"></param>
public static void llama_sample_repetition_penalty(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, Memory<llama_token> last_tokens, ulong last_tokens_size, float penalty)
{
    using var handle = LLamaTokenDataArrayNative.Create(candidates, out var st);

    // Pin the token memory for the duration of the native call; the handle is
    // disposed (unpinned) automatically when the method exits.
    using var last_tokens_handle = last_tokens.Pin();

    NativeApi.llama_sample_repetition_penalty(ctx, ref st, (int*)last_tokens_handle.Pointer, last_tokens_size, penalty);
}

/// <summary>
Expand All @@ -40,10 +42,12 @@ public static void llama_sample_repetition_penalty(SafeLLamaContextHandle ctx, L
/// <param name="last_tokens_size"></param>
/// <param name="alpha_frequency"></param>
/// <param name="alpha_presence"></param>
public static void llama_sample_frequency_and_presence_penalties(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, Memory<llama_token> last_tokens, ulong last_tokens_size, float alpha_frequency, float alpha_presence)
{
    using var handle = LLamaTokenDataArrayNative.Create(candidates, out var st);

    // Pin the token memory for the duration of the native call; the handle is
    // disposed (unpinned) automatically when the method exits.
    using var last_tokens_handle = last_tokens.Pin();

    NativeApi.llama_sample_frequency_and_presence_penalties(ctx, ref st, (int*)last_tokens_handle.Pointer, last_tokens_size, alpha_frequency, alpha_presence);
}

/// <summary>
Expand Down Expand Up @@ -128,10 +132,7 @@ public static void llama_sample_temperature(SafeLLamaContextHandle ctx, LLamaTok
public static llama_token llama_sample_token_mirostat(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, float tau, float eta, int m, ref float mu)
{
    using var handle = LLamaTokenDataArrayNative.Create(candidates, out var st);

    // The native signature takes `ref float mu`, so the marshaller pins it for us —
    // no manual `fixed` block is required.
    return NativeApi.llama_sample_token_mirostat(ctx, ref st, tau, eta, m, ref mu);
}

/// <summary>
Expand All @@ -146,10 +147,7 @@ public static llama_token llama_sample_token_mirostat(SafeLLamaContextHandle ctx
public static llama_token llama_sample_token_mirostat_v2(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, float tau, float eta, ref float mu)
{
    using var handle = LLamaTokenDataArrayNative.Create(candidates, out var st);

    // The native signature takes `ref float mu`, so the marshaller pins it for us —
    // no manual `fixed` block is required.
    return NativeApi.llama_sample_token_mirostat_v2(ctx, ref st, tau, eta, ref mu);
}

/// <summary>
Expand Down
Loading