diff --git a/LLama/Native/NativeApi.Sampling.cs b/LLama/Native/NativeApi.Sampling.cs
index a8b29509f..631b87136 100644
--- a/LLama/Native/NativeApi.Sampling.cs
+++ b/LLama/Native/NativeApi.Sampling.cs
@@ -26,7 +26,7 @@ public unsafe partial class NativeApi
///
///
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
- public static extern void llama_sample_repetition_penalty(SafeLLamaContextHandle ctx, ref LLamaTokenDataArrayNative candidates, llama_token[] last_tokens, ulong last_tokens_size, float penalty);
+ public static extern void llama_sample_repetition_penalty(SafeLLamaContextHandle ctx, ref LLamaTokenDataArrayNative candidates, llama_token* last_tokens, ulong last_tokens_size, float penalty);
///
/// Frequency and presence penalties described in OpenAI API https://platform.openai.com/docs/api-reference/parameter-details.
@@ -38,7 +38,7 @@ public unsafe partial class NativeApi
///
///
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
- public static extern void llama_sample_frequency_and_presence_penalties(SafeLLamaContextHandle ctx, ref LLamaTokenDataArrayNative candidates, llama_token[] last_tokens, ulong last_tokens_size, float alpha_frequency, float alpha_presence);
+ public static extern void llama_sample_frequency_and_presence_penalties(SafeLLamaContextHandle ctx, ref LLamaTokenDataArrayNative candidates, llama_token* last_tokens, ulong last_tokens_size, float alpha_frequency, float alpha_presence);
///
/// Apply classifier-free guidance to the logits as described in academic paper "Stay on topic with Classifier-Free Guidance" https://arxiv.org/abs/2306.17806
@@ -118,7 +118,7 @@ public unsafe partial class NativeApi
/// Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal.
///
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
- public static extern llama_token llama_sample_token_mirostat(SafeLLamaContextHandle ctx, ref LLamaTokenDataArrayNative candidates, float tau, float eta, int m, float* mu);
+ public static extern llama_token llama_sample_token_mirostat(SafeLLamaContextHandle ctx, ref LLamaTokenDataArrayNative candidates, float tau, float eta, int m, ref float mu);
///
/// Mirostat 2.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
@@ -130,7 +130,7 @@ public unsafe partial class NativeApi
/// Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal.
///
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
- public static extern llama_token llama_sample_token_mirostat_v2(SafeLLamaContextHandle ctx, ref LLamaTokenDataArrayNative candidates, float tau, float eta, float* mu);
+ public static extern llama_token llama_sample_token_mirostat_v2(SafeLLamaContextHandle ctx, ref LLamaTokenDataArrayNative candidates, float tau, float eta, ref float mu);
///
/// Selects the token with the highest probability.
diff --git a/LLama/Native/SamplingApi.cs b/LLama/Native/SamplingApi.cs
index fde2311b1..d67ac9a8c 100644
--- a/LLama/Native/SamplingApi.cs
+++ b/LLama/Native/SamplingApi.cs
@@ -25,10 +25,12 @@ public static void llama_sample_grammar(SafeLLamaContextHandle ctx, LLamaTokenDa
///
///
///
- public static void llama_sample_repetition_penalty(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, llama_token[] last_tokens, ulong last_tokens_size, float penalty)
+ public static void llama_sample_repetition_penalty(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, Memory last_tokens, ulong last_tokens_size, float penalty)
{
using var handle = LLamaTokenDataArrayNative.Create(candidates, out var st);
- NativeApi.llama_sample_repetition_penalty(ctx, ref st, last_tokens, last_tokens_size, penalty);
+ using var last_tokens_handle = last_tokens.Pin();
+
+ NativeApi.llama_sample_repetition_penalty(ctx, ref st, (int*)last_tokens_handle.Pointer, last_tokens_size, penalty);
}
///
@@ -40,10 +42,12 @@ public static void llama_sample_repetition_penalty(SafeLLamaContextHandle ctx, L
///
///
///
- public static void llama_sample_frequency_and_presence_penalties(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, llama_token[] last_tokens, ulong last_tokens_size, float alpha_frequency, float alpha_presence)
+ public static void llama_sample_frequency_and_presence_penalties(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, Memory last_tokens, ulong last_tokens_size, float alpha_frequency, float alpha_presence)
{
using var handle = LLamaTokenDataArrayNative.Create(candidates, out var st);
- NativeApi.llama_sample_frequency_and_presence_penalties(ctx, ref st, last_tokens, last_tokens_size, alpha_frequency, alpha_presence);
+ using var last_tokens_handle = last_tokens.Pin();
+
+ NativeApi.llama_sample_frequency_and_presence_penalties(ctx, ref st, (int*)last_tokens_handle.Pointer, last_tokens_size, alpha_frequency, alpha_presence);
}
///
@@ -128,10 +132,7 @@ public static void llama_sample_temperature(SafeLLamaContextHandle ctx, LLamaTok
public static llama_token llama_sample_token_mirostat(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, float tau, float eta, int m, ref float mu)
{
using var handle = LLamaTokenDataArrayNative.Create(candidates, out var st);
- fixed(float* pmu = &mu)
- {
- return NativeApi.llama_sample_token_mirostat(ctx, ref st, tau, eta, m, pmu);
- }
+ return NativeApi.llama_sample_token_mirostat(ctx, ref st, tau, eta, m, ref mu);
}
///
@@ -146,10 +147,7 @@ public static llama_token llama_sample_token_mirostat(SafeLLamaContextHandle ctx
public static llama_token llama_sample_token_mirostat_v2(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, float tau, float eta, ref float mu)
{
using var handle = LLamaTokenDataArrayNative.Create(candidates, out var st);
- fixed (float* pmu = &mu)
- {
- return NativeApi.llama_sample_token_mirostat_v2(ctx, ref st, tau, eta, pmu);
- }
+ return NativeApi.llama_sample_token_mirostat_v2(ctx, ref st, tau, eta, ref mu);
}
///