diff --git a/LLama/LLamaContext.cs b/LLama/LLamaContext.cs
index 7d0c0f91d..4f8e7d94a 100644
--- a/LLama/LLamaContext.cs
+++ b/LLama/LLamaContext.cs
@@ -1,4 +1,4 @@
-using LLama.Exceptions;
+using LLama.Exceptions;
 using LLama.Native;
 using System;
 using System.Collections.Generic;
@@ -521,6 +521,17 @@ public LLamaTokenDataArray ApplyPenalty(int logits_i, IEnumerable<LLamaToken> la
             return candidates_p;
         }
 
+        /// <summary>
+        /// Gets whether or not the BOS token should be added.
+        /// From common.cpp https://github.com/ggerganov/llama.cpp/blob/60325fa56f61c228464c9f065db3aa6a61f2156e/common/common.cpp#L2417
+        /// </summary>
+        /// <returns></returns>
+        public bool ShouldAddBosToken()
+        {
+            var addBos = NativeApi.llama_add_bos_token(NativeHandle.ModelHandle);
+            return addBos != -1 ? Convert.ToBoolean(addBos) : NativeHandle.LLamaVocabType == LLamaVocabType.SentencePiece;
+        }
+
         #region eval overloads
         /// <summary>
         ///
diff --git a/LLama/LLamaExecutorBase.cs b/LLama/LLamaExecutorBase.cs
index 70081c9f5..9700eb0e8 100644
--- a/LLama/LLamaExecutorBase.cs
+++ b/LLama/LLamaExecutorBase.cs
@@ -1,4 +1,4 @@
-using LLama.Abstractions;
+using LLama.Abstractions;
 using LLama.Common;
 using LLama.Exceptions;
 using LLama.Native;
@@ -195,13 +195,14 @@ protected virtual void HandleRunOutOfContext(int tokensToKeep)
             // if we run out of context:
             // - take the tokensToKeep first tokens from the original prompt (via n_past)
            // - take half of the last (n_ctx - tokensToKeep) tokens and recompute the logits in batches
-            int n_left = _pastTokensCount - tokensToKeep;
+            var n_left = _pastTokensCount - tokensToKeep;
+            var n_discard = n_left / 2;
 
-            _pastTokensCount = Math.Max(1, tokensToKeep);
-
-            // insert n_left/2 tokens at the start of embed from last_n_tokens
-            _embeds.InsertRange(0, _last_n_tokens.Take(_last_n_tokens.Count - _embeds.Count).Skip((int)Context.ContextSize - n_left / 2 - _embeds.Count));
+            NativeApi.llama_kv_cache_seq_rm(Context.NativeHandle, (LLamaSeqId)0, tokensToKeep, tokensToKeep + n_discard);
+            NativeApi.llama_kv_cache_seq_add(Context.NativeHandle, (LLamaSeqId)0, tokensToKeep + n_discard, _pastTokensCount, -n_discard);
+            _pastTokensCount -= n_discard;
 
+            // stop saving session if we run out of context
             _pathSession = string.Empty;
         }
 
diff --git a/LLama/LLamaInstructExecutor.cs b/LLama/LLamaInstructExecutor.cs
index 65d2d6c77..5b2530969 100644
--- a/LLama/LLamaInstructExecutor.cs
+++ b/LLama/LLamaInstructExecutor.cs
@@ -1,4 +1,4 @@
-using LLama.Abstractions;
+using LLama.Abstractions;
 using LLama.Common;
 using LLama.Native;
 using System;
@@ -186,7 +186,10 @@ protected override Task InferInternal(IInferenceParams inferenceParams, InferSta
                 _is_prompt_run = false;
                 if (_pastTokensCount + _embeds.Count > Context.ContextSize)
                 {
-                    HandleRunOutOfContext(inferenceParams.TokensKeep);
+                    // Ported from https://github.com/ggerganov/llama.cpp/blob/60325fa56f61c228464c9f065db3aa6a61f2156e/examples/main/main.cpp#L334
+                    // Instruct always keeps the entire input prompt.
+                    var tokensToKeep = _embed_inps.Count;
+                    HandleRunOutOfContext(tokensToKeep);
                 }
 
                 TryReuseMatchingPrefix();
diff --git a/LLama/LLamaInteractExecutor.cs b/LLama/LLamaInteractExecutor.cs
index fec4f9c4e..226b18ef9 100644
--- a/LLama/LLamaInteractExecutor.cs
+++ b/LLama/LLamaInteractExecutor.cs
@@ -1,4 +1,4 @@
-using LLama.Common;
+using LLama.Common;
 using LLama.Native;
 using LLama.Abstractions;
 using System;
@@ -231,7 +231,19 @@ protected override Task InferInternal(IInferenceParams inferenceParams, InferSta
                 _is_prompt_run = false;
                 if (_pastTokensCount + _embeds.Count > Context.ContextSize)
                 {
-                    HandleRunOutOfContext(inferenceParams.TokensKeep);
+                    // number of tokens to keep when resetting context
+                    // Ported from https://github.com/ggerganov/llama.cpp/blob/60325fa56f61c228464c9f065db3aa6a61f2156e/examples/main/main.cpp#L334
+                    var tokensToKeep = inferenceParams.TokensKeep;
+                    if (tokensToKeep < 0 || tokensToKeep > _embed_inps.Count)
+                    {
+                        tokensToKeep = _embed_inps.Count;
+                    }
+                    else
+                    {
+                        tokensToKeep += Convert.ToInt32(Context.ShouldAddBosToken()); // always keep the BOS token
+                    }
+
+                    HandleRunOutOfContext(tokensToKeep);
                 }
 
                 TryReuseMatchingPrefix();
diff --git a/LLama/LLamaStatelessExecutor.cs b/LLama/LLamaStatelessExecutor.cs
index 39d74f905..ab5f41469 100644
--- a/LLama/LLamaStatelessExecutor.cs
+++ b/LLama/LLamaStatelessExecutor.cs
@@ -1,4 +1,4 @@
-using LLama.Abstractions;
+using LLama.Abstractions;
 using LLama.Common;
 using System;
 using System.Collections.Generic;
@@ -144,11 +144,25 @@ public async IAsyncEnumerable<string> InferAsync(string prompt, IInferenceParams
                 // based on this logic: https://github.com/ggerganov/llama.cpp/blob/master/examples/main/main.cpp#L497
                 if (n_past + tokens.Count >= Context.ContextSize)
                 {
-                    var n_left = n_past - inferenceParams.TokensKeep - 1;
+                    var canAddBos = Context.ShouldAddBosToken();
+                    var tokensKeep = inferenceParams.TokensKeep;
+
+                    // number of tokens to keep when resetting context
+                    // Ported from https://github.com/ggerganov/llama.cpp/blob/60325fa56f61c228464c9f065db3aa6a61f2156e/examples/main/main.cpp#L334
+                    if (tokensKeep < 0 || tokensKeep > tokens.Count)
+                    {
+                        tokensKeep = tokens.Count;
+                    }
+                    else
+                    {
+                        tokensKeep += Convert.ToInt32(canAddBos);
+                    }
+
+                    var n_left = n_past - tokensKeep;
                     var n_discard = n_left / 2;
 
-                    NativeApi.llama_kv_cache_seq_rm(Context.NativeHandle, (LLamaSeqId)0, inferenceParams.TokensKeep + 1, inferenceParams.TokensKeep + n_discard + 1);
-                    NativeApi.llama_kv_cache_seq_add(Context.NativeHandle, (LLamaSeqId)0, inferenceParams.TokensKeep + 1 + n_discard, n_past, -n_discard);
+                    NativeApi.llama_kv_cache_seq_rm(Context.NativeHandle, (LLamaSeqId)0, tokensKeep, tokensKeep + n_discard);
+                    NativeApi.llama_kv_cache_seq_add(Context.NativeHandle, (LLamaSeqId)0, tokensKeep + n_discard, n_past, -n_discard);
 
                     n_past -= n_discard;
                 }
diff --git a/LLama/Native/SafeLLamaContextHandle.cs b/LLama/Native/SafeLLamaContextHandle.cs
index 51568769d..13a7aa1b2 100644
--- a/LLama/Native/SafeLLamaContextHandle.cs
+++ b/LLama/Native/SafeLLamaContextHandle.cs
@@ -1,4 +1,4 @@
-using System;
+using System;
 using System.Collections.Generic;
 using System.Runtime.InteropServices;
 using System.Text;
@@ -19,6 +19,8 @@ public sealed class SafeLLamaContextHandle
         /// </summary>
         public int VocabCount => ThrowIfDisposed().VocabCount;
 
+        public LLamaVocabType LLamaVocabType => ThrowIfDisposed().VocabType;
+
         /// <summary>
         /// Total number of tokens in the context
         /// </summary>
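
Reviewer note: the hunks above all port llama.cpp's "context shift" trick. When the KV cache fills, the first tokensToKeep positions are preserved, the next n_discard positions are evicted, and the surviving tail is slid back so generation continues without re-evaluating the whole prompt. The sketch below restates that arithmetic as a standalone helper; ShiftContext is a hypothetical name used for illustration and is not part of this PR, but the NativeApi calls mirror exactly the ones in the diff.

using LLama.Native;

internal static class ContextShiftSketch
{
    // Hypothetical helper, not part of this PR: discards half of the
    // non-kept tokens from sequence 0 of the KV cache and returns the
    // new past-token count, mirroring HandleRunOutOfContext above.
    public static int ShiftContext(SafeLLamaContextHandle ctx, int nPast, int tokensToKeep)
    {
        // Everything after the kept prefix is eligible for eviction...
        var nLeft = nPast - tokensToKeep;
        // ...but only half is discarded, so the most recent context survives.
        var nDiscard = nLeft / 2;

        // Delete cached positions [tokensToKeep, tokensToKeep + nDiscard).
        NativeApi.llama_kv_cache_seq_rm(ctx, (LLamaSeqId)0, tokensToKeep, tokensToKeep + nDiscard);
        // Slide the tail [tokensToKeep + nDiscard, nPast) back by nDiscard positions.
        NativeApi.llama_kv_cache_seq_add(ctx, (LLamaSeqId)0, tokensToKeep + nDiscard, nPast, -nDiscard);

        return nPast - nDiscard;
    }
}

Worked example: with nPast = 4090 and tokensToKeep = 512, nLeft = 3578 and nDiscard = 1789, so positions 512..2300 are dropped, positions 2301..4089 slide down to 512..2300, and the new nPast is 2301; decoding then resumes from position 2301 onward.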