BatchedExecutor Save/Load #681

Merged
1 change: 1 addition & 0 deletions LLama.Examples/ExampleRunner.cs
@@ -26,6 +26,7 @@ public class ExampleRunner
{ "Semantic Kernel: Prompt", SemanticKernelPrompt.Run },
{ "Semantic Kernel: Chat", SemanticKernelChat.Run },
{ "Semantic Kernel: Store", SemanticKernelMemory.Run },
{ "Batched Executor: Save/Load", BatchedExecutorSaveAndLoad.Run },
{ "Batched Executor: Fork", BatchedExecutorFork.Run },
{ "Batched Executor: Rewind", BatchedExecutorRewind.Run },
{ "Batched Executor: Guidance", BatchedExecutorGuidance.Run },
108 changes: 108 additions & 0 deletions LLama.Examples/Examples/BatchedExecutorSaveAndLoad.cs
@@ -0,0 +1,108 @@
using LLama.Batched;
using LLama.Common;
using LLama.Native;
using LLama.Sampling;
using Spectre.Console;

namespace LLama.Examples.Examples;

/// <summary>
/// This demonstrates saving and loading the state of a conversation with a BatchedExecutor, both to/from a file and to/from system memory
/// </summary>
public class BatchedExecutorSaveAndLoad
{
private const int n_len = 18;

public static async Task Run()
{
string modelPath = UserSettings.GetModelPath();

var parameters = new ModelParams(modelPath);
using var model = LLamaWeights.LoadFromFile(parameters);

var prompt = AnsiConsole.Ask("Prompt (or ENTER for default):", "Not many people know that");

// Create an executor that can evaluate a batch of conversations together
using var executor = new BatchedExecutor(model, parameters);

// Print some info
var name = executor.Model.Metadata.GetValueOrDefault("general.name", "unknown model name");
Console.WriteLine($"Created executor with model: {name}");

// Create a conversation
var conversation = executor.Create();
conversation.Prompt(prompt);

// Run inference loop
var decoder = new StreamingTokenDecoder(executor.Context);
var sampler = new DefaultSamplingPipeline();
var lastToken = await GenerateTokens(executor, conversation, sampler, decoder, n_len);

// Can't save a conversation while RequiresInference is true
if (conversation.RequiresInference)
await executor.Infer();

// Save this conversation to a file and dispose it
conversation.Save("demo_conversation.state");
conversation.Dispose();
AnsiConsole.WriteLine($"Saved state: {new FileInfo("demo_conversation.state").Length} bytes");

// Now create a new conversation by loading that state
conversation = executor.Load("demo_conversation.state");
AnsiConsole.WriteLine("Loaded state");

// Prompt it again with the last token, so we can continue generating
conversation.Rewind(1);
conversation.Prompt(lastToken);

// Continue generating text
lastToken = await GenerateTokens(executor, conversation, sampler, decoder, n_len);

// Can't save a conversation while RequiresInference is true
if (conversation.RequiresInference)
await executor.Infer();

// Save the conversation again, this time into system memory
using (var state = conversation.Save())
{
conversation.Dispose();
AnsiConsole.WriteLine($"Saved state to memory: {state.Size} bytes");

// Now create a new conversation by loading that state
conversation = executor.Load(state);
AnsiConsole.WriteLine("Loaded state");
}

// Prompt it again with the last token, so we can continue generating
conversation.Rewind(1);
conversation.Prompt(lastToken);

// Continue generating text
await GenerateTokens(executor, conversation, sampler, decoder, n_len);

// Display final output
AnsiConsole.MarkupLine($"[red]{prompt}{decoder.Read()}[/]");
}

private static async Task<LLamaToken> GenerateTokens(BatchedExecutor executor, Conversation conversation, ISamplingPipeline sampler, StreamingTokenDecoder decoder, int count = 15)
{
var token = (LLamaToken)0;

for (var i = 0; i < count; i++)
{
// Run inference
await executor.Infer();

// Use sampling pipeline to pick a token
token = sampler.Sample(executor.Context.NativeHandle, conversation.Sample(), ReadOnlySpan<LLamaToken>.Empty);

// Add it to the decoder, so it can be converted into text later
decoder.Add(token);

// Prompt the conversation with the token
conversation.Prompt(token);
}

return token;
}
}
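
For orientation, the save/load round trip performed above can be condensed as follows. This is a minimal sketch reusing the example's own variables (executor, conversation, lastToken); it is illustrative only and not part of the diff:

// Hypothetical condensed round trip, assuming executor, conversation and lastToken
// are already set up exactly as in the example above
if (conversation.RequiresInference)
    await executor.Infer();                              // a conversation cannot be saved while inference is pending

conversation.Save("demo_conversation.state");            // writes the KV cache plus a small JSON header to disk
conversation.Dispose();

conversation = executor.Load("demo_conversation.state"); // a fresh conversation containing the saved state
conversation.Rewind(1);                                  // drop the last token from the loaded state...
conversation.Prompt(lastToken);                          // ...and re-prompt it so logits are produced again
await executor.Infer();
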
33 changes: 33 additions & 0 deletions LLama/Batched/BatchedExecutor.cs
@@ -84,6 +84,39 @@ public Conversation Create()
return new Conversation(this, GetNextSequenceId());
}

/// <summary>
/// Load a conversation that was previously saved to a file. Once loaded the conversation will
/// need to be prompted.
/// </summary>
/// <param name="filepath">Path of the file to load the saved conversation state from</param>
/// <returns>A new conversation containing the loaded state</returns>
/// <exception cref="ObjectDisposedException">Thrown if this executor has already been disposed</exception>
public Conversation Load(string filepath)
{
if (IsDisposed)
throw new ObjectDisposedException(nameof(BatchedExecutor));

var conversation = Create();
conversation.Load(filepath);
return conversation;
}

/// <summary>
/// Load a conversation that was previously saved into memory. Once loaded the conversation will need to be prompted.
/// </summary>
/// <param name="state">The previously saved in-memory state to load</param>
/// <returns>A new conversation containing the loaded state</returns>
/// <exception cref="ObjectDisposedException">Thrown if this executor has already been disposed</exception>
public Conversation Load(Conversation.State state)
{
if (IsDisposed)
throw new ObjectDisposedException(nameof(BatchedExecutor));

var conversation = Create();
conversation.Load(state);
return conversation;
}

/// <summary>
/// Run inference for all conversations in the batch which have pending tokens.
///
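
For reference, a minimal sketch of how these two new overloads are called from user code, mirroring the example above; the setup of executor and conversation is assumed and the snippet is illustrative, not part of the diff:

// Restore a conversation from a file written by Conversation.Save(filepath)
var restored = executor.Load("demo_conversation.state");

// Restore a conversation from an in-memory Conversation.State
using (var state = conversation.Save())
{
    conversation.Dispose();
    conversation = executor.Load(state);
}

// In both cases the loaded conversation must be prompted again before sampling can continue.
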
173 changes: 169 additions & 4 deletions LLama/Batched/Conversation.cs
@@ -2,6 +2,7 @@
using System.Buffers;
using System.Collections.Generic;
using System.Runtime.InteropServices;
using System.Text.Json;
using LLama.Native;

namespace LLama.Batched;
@@ -14,7 +14,7 @@ public sealed class Conversation
{
private ulong _requiredEpoch;
private LLamaPos _end;
private int _batchIndex;
private int _batchSampleIndex;
private bool _disposed;
private bool _forked;

Expand Down Expand Up @@ -107,7 +108,7 @@ public Conversation Fork()
// logits, so sampling one conversation may mess up the fork! Setting the "forked" flag on both sequences ensures
// they both copy the logits before the next sampling run, to fix this issue.
_requiredEpoch = _requiredEpoch,
_batchIndex = _batchIndex,
_batchSampleIndex = _batchSampleIndex,
_forked = true,

_end = _end,
Expand Down Expand Up @@ -140,7 +141,7 @@ public Span<float> Sample()
if (_requiredEpoch > Executor.Epoch)
throw new CannotSampleRequiresInferenceException();

var span = Executor.Context.NativeHandle.GetLogitsIth(_batchIndex);
var span = Executor.Context.NativeHandle.GetLogitsIth(_batchSampleIndex);

// If necessary copy the span, to protect it from modification. This is only done when
// this conversation has been forked in this epoch.
Expand Down Expand Up @@ -220,7 +221,7 @@ public void Prompt(ReadOnlySpan<LLamaToken> tokens)

// Add the prompt to the batch
for (var i = 0; i < tokens.Length; i++)
_batchIndex = Executor.Batch.Add(tokens[i], _end++, ConversationId, i == tokens.Length - 1);
_batchSampleIndex = Executor.Batch.Add(tokens[i], _end++, ConversationId, i == tokens.Length - 1);

// Mark this conversation as needing inference/sampling
_requiredEpoch = Executor.Epoch + 1;
Expand Down Expand Up @@ -350,4 +351,168 @@ public void Divide(LLamaPos start, LLamaPos end, int divisor)
/// <returns>The new end token position</returns>
public delegate LLamaPos ModifyKvCache(LLamaPos end, KvAccessor kv);
#endregion

#region save/load
private void AssertCanLoad()
{
AssertNotDisposed();
if (_end.Value > 0)
throw new InvalidOperationException("Cannot load into a non-empty conversation");
}

private void AssertCanSave()
{
AssertNotDisposed();
if (RequiresInference)
throw new CannotSaveWhileRequiresInferenceException();
}


/// <summary>
/// Save the complete state of this conversation to a file. If the file already exists it will be overwritten.
/// </summary>
/// <param name="filepath">Path of the file to save the conversation state to</param>
/// <exception cref="CannotSaveWhileRequiresInferenceException">Thrown if the conversation still requires inference before it can be saved</exception>
public void Save(string filepath)
{
AssertCanSave();

// Prepare extra state to put into file header
var state = GetState();
var bytes = JsonSerializer.SerializeToUtf8Bytes(state);

// Save extra state along with the KV cache
Executor.Context.SaveState(filepath, ConversationId, bytes);
}

/// <summary>
/// Save the complete state of this conversation in system memory.
/// </summary>
/// <returns>An in-memory <see cref="State"/> object holding the complete saved state of this conversation</returns>
public State Save()
{
AssertCanSave();

return new PrivateState(
Executor.Context.GetState(ConversationId),
GetState()
);
}


/// <summary>
/// Load state from a file.
/// This should only ever be called by the BatchedExecutor, on a newly created conversation object!
/// </summary>
/// <param name="filepath">Path of the file to load the saved state from</param>
/// <exception cref="InvalidOperationException">Thrown if the conversation is not empty, or if the extra state in the file header cannot be deserialized</exception>
internal void Load(string filepath)
{
AssertCanLoad();

// Load the state from file into the KV cache
Executor.Context.LoadState(filepath, ConversationId, out var header);

// deserialize the extra state in the file header
var state = JsonSerializer.Deserialize<SerializableConversationState>(header);
if (state == null)
{
Dispose();
throw new InvalidOperationException("Failed to deserialize - deserialized header state was null");
}

Load(state);
}

/// <summary>
/// Load state from a previously saved state.
/// This should only ever be called by the BatchedExecutor, on a newly created conversation object!
/// </summary>
/// <param name="state">The previously saved state to load into this conversation</param>
internal void Load(State state)
{
AssertCanLoad();

// There is only one class that extends State and it is PrivateState, so this cast is safe.
var priv = (PrivateState)state;

// Load the state from file into the KV cache
Executor.Context.LoadState(priv.SequenceState, ConversationId);

Load(priv.ConversationState);
}


private void Load(SerializableConversationState state)
{
if (state.Version != 1)
throw new InvalidOperationException("Failed to deserialize - mismatched version number");

// Load extra conversation state
_end = state.TokenCount;
}

private SerializableConversationState GetState()
{
return new SerializableConversationState(
Version: 1,
TokenCount: TokenCount
);
}


private record SerializableConversationState(int Version, int TokenCount);

private sealed class PrivateState
: State
{
public readonly LLamaContext.SequenceState SequenceState;
public readonly SerializableConversationState ConversationState;

public override ulong Size => SequenceState.Size;

public PrivateState(LLamaContext.SequenceState sequenceState, SerializableConversationState conversationState)
{
SequenceState = sequenceState;
ConversationState = conversationState;
}

/// <inheritdoc />
public override void Dispose()
{
if (IsDisposed)
throw new ObjectDisposedException(nameof(State));
IsDisposed = true;

SequenceState.Dispose();
}
}

/// <summary>
/// In memory saved state of a <see cref="Conversation"/>
/// </summary>
public abstract class State
: IDisposable
{
/// <summary>
/// Indicates if this state has been disposed
/// </summary>
public bool IsDisposed { get; protected set; }

/// <summary>
/// Get the size in bytes of this state object
/// </summary>
public abstract ulong Size { get; }

/// <inheritdoc />
public abstract void Dispose();

/// <summary>
/// Internal constructor prevents anyone outside of LLamaSharp from extending this class
/// </summary>
internal State()
{
}
}
#endregion
}
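
The abstract State type above owns native sequence data, so callers are expected to dispose it once it has been loaded back. A sketch of that lifetime, based only on the members visible in this diff (Size, IsDisposed, Dispose) and intended as illustration rather than part of the change:

// Illustrative lifetime of an in-memory Conversation.State (assumes executor and conversation exist)
using (Conversation.State state = conversation.Save())
{
    Console.WriteLine($"Saved {state.Size} bytes, disposed: {state.IsDisposed}");

    conversation.Dispose();                // the original conversation is no longer needed
    conversation = executor.Load(state);   // loads the saved sequence data into a new conversation
}
// Leaving the using block calls Dispose(), which releases the underlying SequenceState;
// disposing the same State twice throws ObjectDisposedException (see PrivateState.Dispose above).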