Implement context shifting in executor base #714

Merged · 3 commits merged on May 4, 2024
13 changes: 12 additions & 1 deletion LLama/LLamaContext.cs
@@ -1,4 +1,4 @@
-using LLama.Exceptions;
+using LLama.Exceptions;
using LLama.Native;
using System;
using System.Collections.Generic;
@@ -521,6 +521,17 @@ public LLamaTokenDataArray ApplyPenalty(int logits_i, IEnumerable<LLamaToken> la
return candidates_p;
}

+/// <summary>
+/// Gets whether or not the Bos token should be added.
+/// From common.cpp https://github.com/ggerganov/llama.cpp/blob/60325fa56f61c228464c9f065db3aa6a61f2156e/common/common.cpp#L2417
+/// </summary>
+/// <returns></returns>
+public bool ShouldAddBosToken()
+{
+    var addBos = NativeApi.llama_add_bos_token(NativeHandle.ModelHandle);
+    return addBos != -1 ? Convert.ToBoolean(addBos) : NativeHandle.LLamaVocabType == LLamaVocabType.SentencePiece;
+}

#region eval overloads
/// <summary>
/// </summary>
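The tri-state returned by the native llama_add_bos_token call drives ShouldAddBosToken(): 1 means the model metadata asks for a BOS token, 0 means it forbids one, and -1 means the metadata is silent. A minimal standalone sketch of that mapping (the helper name and parameters below are illustrative, not part of the PR):

static bool InterpretAddBos(int addBos, bool isSentencePieceVocab)
{
    // 1 -> true, 0 -> false; -1 falls back to the vocab-type default:
    // SentencePiece vocabularies (classic LLaMA) add BOS, BPE ones generally do not.
    return addBos != -1 ? Convert.ToBoolean(addBos) : isSentencePieceVocab;
}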
13 changes: 7 additions & 6 deletions LLama/LLamaExecutorBase.cs
@@ -1,4 +1,4 @@
-using LLama.Abstractions;
+using LLama.Abstractions;
using LLama.Common;
using LLama.Exceptions;
using LLama.Native;
@@ -195,13 +195,14 @@
// if we run out of context:
// - take the tokensToKeep first tokens from the original prompt (via n_past)
// - take half of the last (n_ctx - tokensToKeep) tokens and recompute the logits in batches
-int n_left = _pastTokensCount - tokensToKeep;
+var n_left = _pastTokensCount - tokensToKeep;
+var n_discard = n_left / 2;

-_pastTokensCount = Math.Max(1, tokensToKeep);

-// insert n_left/2 tokens at the start of embed from last_n_tokens
-_embeds.InsertRange(0, _last_n_tokens.Take(_last_n_tokens.Count - _embeds.Count).Skip((int)Context.ContextSize - n_left / 2 - _embeds.Count));
+NativeApi.llama_kv_cache_seq_rm(Context.NativeHandle, (LLamaSeqId)0, tokensToKeep, tokensToKeep + n_discard);
+NativeApi.llama_kv_cache_seq_add(Context.NativeHandle, (LLamaSeqId)0, tokensToKeep + n_discard, _pastTokensCount, -n_discard);
+
+_pastTokensCount -= n_discard;

// stop saving session if we run out of context
_pathSession = string.Empty;
}
@@ -419,13 +420,13 @@
public string? SessionFilePath { get; set; }

[JsonPropertyName("embd")]
public LLamaToken[] Embeds { get; set; }

[JsonPropertyName("embd_inps")]
public LLamaToken[] EmbedInps { get; set; }

[JsonPropertyName("session_tokens")]
public LLamaToken[] SessionTokens { get; set; }

[JsonPropertyName("last_n_tokens")]
public LLamaToken[] LastTokens { get; set; }
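To make the shift arithmetic concrete: with a full cache of 4096 past tokens and tokensToKeep = 4, n_left is 4092 and n_discard is 2046, so the cache keeps positions 0..3, frees 4..2049, and slides 2050..4095 down by 2046. A standalone sketch with those hypothetical numbers (the real code drives the llama.cpp KV cache through NativeApi):

int pastTokensCount = 4096;                       // cache is full
int tokensToKeep    = 4;                          // tokens preserved from the prompt start
int n_left    = pastTokensCount - tokensToKeep;   // 4092
int n_discard = n_left / 2;                       // 2046

// llama_kv_cache_seq_rm removes positions [tokensToKeep, tokensToKeep + n_discard),
// freeing cells 4..2049. llama_kv_cache_seq_add then shifts positions
// [tokensToKeep + n_discard, pastTokensCount) by -n_discard, moving cells
// 2050..4095 down to 4..2049 so their position ids stay consistent.
pastTokensCount -= n_discard;                     // 2050 valid positions remain
Console.WriteLine($"kept={tokensToKeep}, discarded={n_discard}, past={pastTokensCount}");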
7 changes: 5 additions & 2 deletions LLama/LLamaInstructExecutor.cs
@@ -1,4 +1,4 @@
-using LLama.Abstractions;
+using LLama.Abstractions;
using LLama.Common;
using LLama.Native;
using System;
@@ -106,7 +106,7 @@
using (var fs = new FileStream(filename, FileMode.Open, FileAccess.Read))
{
var state = await JsonSerializer.DeserializeAsync<InstructExecutorState>(fs);
await LoadState(state);
}
}

@@ -147,11 +147,11 @@
}

/// <inheritdoc />
protected override async Task<(bool, IReadOnlyList<string>)> PostProcess(IInferenceParams inferenceParams, InferStateArgs args)
{
if (_embed_inps.Count <= _consumedTokensCount)
{
if (_last_n_tokens.TokensEndsWithAnyString(args.Antiprompts, Context.NativeHandle.ModelHandle, Context.Encoding))
{
args.WaitForInput = true;
return (true, Array.Empty<string>());
@@ -186,7 +186,10 @@
_is_prompt_run = false;
if (_pastTokensCount + _embeds.Count > Context.ContextSize)
{
-HandleRunOutOfContext(inferenceParams.TokensKeep);
+// Ported from https://github.com/ggerganov/llama.cpp/blob/60325fa56f61c228464c9f065db3aa6a61f2156e/examples/main/main.cpp#L334
+// Instruct always uses input token size.
+var tokensToKeep = _embed_inps.Count;
+HandleRunOutOfContext(tokensToKeep);
}

TryReuseMatchingPrefix();
16 changes: 14 additions & 2 deletions LLama/LLamaInteractExecutor.cs
@@ -1,4 +1,4 @@
-using LLama.Common;
+using LLama.Common;
using LLama.Native;
using LLama.Abstractions;
using System;
@@ -98,7 +98,7 @@
using (var fs = new FileStream(filename, FileMode.Open, FileAccess.Read))
{
var state = await JsonSerializer.DeserializeAsync<InteractiveExecutorState>(fs);
await LoadState(state);
}
}

@@ -159,7 +159,7 @@
{
foreach (var image in Images)
{
_imageEmbedHandles.Add(SafeLlavaImageEmbedHandle.CreateFromMemory(ClipModel.NativeHandle, Context, image));
}

int imageIndex = text.IndexOf("<image>");
@@ -196,11 +196,11 @@
/// <param name="inferenceParams"></param>
/// <param name="args"></param>
/// <returns></returns>
protected override async Task<(bool, IReadOnlyList<string>)> PostProcess(IInferenceParams inferenceParams, InferStateArgs args)
{
if (_embed_inps.Count <= _consumedTokensCount)
{
if (_last_n_tokens.TokensEndsWithAnyString(args.Antiprompts, Context.NativeHandle.ModelHandle, Context.Encoding))
args.WaitForInput = true;

if (_pastTokensCount > 0 && args.WaitForInput)
@@ -231,7 +231,19 @@
_is_prompt_run = false;
if (_pastTokensCount + _embeds.Count > Context.ContextSize)
{
-HandleRunOutOfContext(inferenceParams.TokensKeep);
+// number of tokens to keep when resetting context
+// Ported from https://github.com/ggerganov/llama.cpp/blob/60325fa56f61c228464c9f065db3aa6a61f2156e/examples/main/main.cpp#L334
+var tokensToKeep = inferenceParams.TokensKeep;
+if (tokensToKeep < 0 || tokensToKeep > _embed_inps.Count)
+{
+    tokensToKeep = _embed_inps.Count;
+}
+else
+{
+    tokensToKeep += Convert.ToInt32(Context.ShouldAddBosToken()); // always keep the BOS token
+}
+
+HandleRunOutOfContext(tokensToKeep);
}

TryReuseMatchingPrefix();
@@ -247,7 +259,7 @@

// Images
foreach( var image in _imageEmbedHandles )
ClipModel.EvalImageEmbed(Context, image, ref _pastTokensCount);

// Post-image Tokens
end = Context.NativeHandle.Decode(_embeds.GetRange(_EmbedImagePosition, _embeds.Count - _EmbedImagePosition), LLamaSeqId.Zero, batch, ref _pastTokensCount);
@@ -280,7 +292,7 @@
if (!string.IsNullOrEmpty(_pathSession) && args.NeedToSaveSession)
{
args.NeedToSaveSession = false;
SaveSessionFile(_pathSession);
}

LLamaToken id;
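Where the instruct executor always keeps the whole input prompt, the interactive executor clamps the caller's TokensKeep and reserves a slot for BOS. A sketch of that rule (helper name and parameters are illustrative only):

static int ComputeTokensToKeep(int requestedKeep, int promptTokenCount, bool addsBos)
{
    // A negative or oversized request falls back to keeping the entire prompt.
    if (requestedKeep < 0 || requestedKeep > promptTokenCount)
        return promptTokenCount;

    // Otherwise reserve one extra slot so the BOS token survives the shift.
    return requestedKeep + (addsBos ? 1 : 0);
}

So TokensKeep = -1 keeps the whole prompt, while TokensKeep = 8 on a BOS-adding model keeps 9 positions.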
22 changes: 18 additions & 4 deletions LLama/LLamaStatelessExecutor.cs
@@ -1,4 +1,4 @@
-using LLama.Abstractions;
+using LLama.Abstractions;
using LLama.Common;
using System;
using System.Collections.Generic;
@@ -144,11 +144,25 @@ public async IAsyncEnumerable<string> InferAsync(string prompt, IInferenceParams
// based on this logic: https://github.com/ggerganov/llama.cpp/blob/master/examples/main/main.cpp#L497
if (n_past + tokens.Count >= Context.ContextSize)
{
-var n_left = n_past - inferenceParams.TokensKeep - 1;
+var canAddBos = Context.ShouldAddBosToken();
+var tokensKeep = inferenceParams.TokensKeep;
+
+// number of tokens to keep when resetting context
+// Ported from https://github.com/ggerganov/llama.cpp/blob/60325fa56f61c228464c9f065db3aa6a61f2156e/examples/main/main.cpp#L334
+if (tokensKeep < 0 || tokensKeep > tokens.Count)
+{
+    tokensKeep = tokens.Count;
+}
+else
+{
+    tokensKeep += Convert.ToInt32(canAddBos);
+}
+
+var n_left = n_past - tokensKeep;
var n_discard = n_left / 2;

-NativeApi.llama_kv_cache_seq_rm(Context.NativeHandle, (LLamaSeqId)0, inferenceParams.TokensKeep + 1, inferenceParams.TokensKeep + n_discard + 1);
-NativeApi.llama_kv_cache_seq_add(Context.NativeHandle, (LLamaSeqId)0, inferenceParams.TokensKeep + 1 + n_discard, n_past, -n_discard);
+NativeApi.llama_kv_cache_seq_rm(Context.NativeHandle, (LLamaSeqId)0, tokensKeep, tokensKeep + n_discard);
+NativeApi.llama_kv_cache_seq_add(Context.NativeHandle, (LLamaSeqId)0, tokensKeep + n_discard, n_past, -n_discard);

n_past -= n_discard;
}
4 changes: 3 additions & 1 deletion LLama/Native/SafeLLamaContextHandle.cs
@@ -1,4 +1,4 @@
-using System;
+using System;
using System.Collections.Generic;
using System.Runtime.InteropServices;
using System.Text;
@@ -19,6 +19,8 @@ public sealed class SafeLLamaContextHandle
/// </summary>
public int VocabCount => ThrowIfDisposed().VocabCount;

+public LLamaVocabType LLamaVocabType => ThrowIfDisposed().VocabType;

/// <summary>
/// Total number of tokens in the context
/// </summary>
Expand Down
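Taken together, these changes let a long-running generation survive context overflow instead of failing once the KV cache fills. A hypothetical end-to-end sketch (model path, context size, and prompt are placeholders; the API shapes are assumed to match LLamaSharp at the time of this PR):

using LLama;
using LLama.Common;

var parameters = new ModelParams("models/llama-2-7b.Q4_K_M.gguf") { ContextSize = 2048 };
using var model = LLamaWeights.LoadFromFile(parameters);
using var context = model.CreateContext(parameters);
var executor = new InteractiveExecutor(context);

// TokensKeep bounds how much of the original prompt survives each context shift;
// out-of-range values (e.g. -1) fall back to keeping the whole prompt.
var inferenceParams = new InferenceParams { TokensKeep = 64, AntiPrompts = new[] { "User:" } };

var longPrompt = "User: summarise everything we have discussed so far.";  // imagine many prior turns
await foreach (var text in executor.InferAsync(longPrompt, inferenceParams))
    Console.Write(text);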