Merge pull request #664 from SignalRT/LLavaResetOnImageChange
Llava Initial approach to clear images
SignalRT authored Apr 16, 2024
2 parents 274ab6e + 0cf6073 commit 399e81d
Showing 6 changed files with 60 additions and 42 deletions.
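In short, the commit switches the executors' image input from file paths (the old `ImagePaths` property) to raw byte arrays (`Images`), and clears the KV cache whenever a prompt carries new images so the conversation restarts cleanly (see the llama.cpp discussion linked in the diff below). A minimal caller-side sketch of the new pattern, assuming `ex` is an `InteractiveExecutor` already created from a `LLamaContext` and `LLavaWeights`, and `imagePaths` is a hypothetical list of local files:

```cs
// Illustrative sketch only, not part of the commit.
// Reset the KV cache so the new image prompt starts a fresh conversation.
ex.Context.NativeHandle.KvCacheRemove(LLamaSeqId.Zero, -1, -1);

// Images are now supplied as raw bytes instead of file paths.
ex.Images.Clear();
foreach (var path in imagePaths)
    ex.Images.Add(await File.ReadAllBytesAsync(path));
```

The executor also clears its own `Images` list once the embeddings have been consumed (see `LLamaInteractExecutor.cs` below), so images are passed per prompt rather than accumulated.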
31 changes: 14 additions & 17 deletions LLama.Examples/Examples/LlavaInteractiveModeExecute.cs
@@ -1,8 +1,7 @@
using System.Text.RegularExpressions;
using LLama.Batched;
using LLama.Common;
using Spectre.Console;
using LLama.Abstractions;
using LLama.Native;

namespace LLama.Examples.Examples
{
@@ -19,12 +18,8 @@ public static async Task Run()

var prompt = $"{{{modelImage}}}\nUSER:\nProvide a full description of the image.\nASSISTANT:\n";

var parameters = new ModelParams(modelPath)
{
ContextSize = 4096,
Seed = 1337,
GpuLayerCount = 10
};
var parameters = new ModelParams(modelPath);

using var model = LLamaWeights.LoadFromFile(parameters);
using var context = model.CreateContext(parameters);

@@ -47,16 +42,16 @@ public static async Task Run()
var imageMatches = Regex.Matches(prompt, "{([^}]*)}").Select(m => m.Value);
var imageCount = imageMatches.Count();
var hasImages = imageCount > 0;
byte[][] imageBytes = null;

if (hasImages)
{
var imagePathsWithCurlyBraces = Regex.Matches(prompt, "{([^}]*)}").Select(m => m.Value);
var imagePaths = Regex.Matches(prompt, "{([^}]*)}").Select(m => m.Groups[1].Value);
var imagePaths = Regex.Matches(prompt, "{([^}]*)}").Select(m => m.Groups[1].Value).ToList();

List<byte[]> imageBytes;
try
{
imageBytes = imagePaths.Select(File.ReadAllBytes).ToArray();
imageBytes = imagePaths.Select(File.ReadAllBytes).ToList();
}
catch (IOException exception)
{
@@ -69,15 +64,17 @@ public static async Task Run()
break;
}

// Each prompt with images we clear cache
// When the prompt contains images we clear KV_CACHE to restart conversation
// See:
// https://github.com/ggerganov/llama.cpp/discussions/3620
ex.Context.NativeHandle.KvCacheRemove( LLamaSeqId.Zero, -1, -1 );

int index = 0;
foreach (var path in imagePathsWithCurlyBraces)
{
// First image replace to tag <image, the rest of the images delete the tag
if (index++ == 0)
prompt = prompt.Replace(path, "<image>");
else
prompt = prompt.Replace(path, "");
prompt = prompt.Replace(path, index++ == 0 ? "<image>" : "");
}


@@ -102,7 +99,7 @@ public static async Task Run()
//
foreach (var image in imagePaths)
{
ex.Images.Add(File.ReadAllBytes(image));
ex.Images.Add(await File.ReadAllBytesAsync(image));
}
}

@@ -118,7 +115,7 @@ public static async Task Run()

// let the user finish with exit
//
if (prompt.Equals("/exit", StringComparison.OrdinalIgnoreCase))
if (prompt != null && prompt.Equals("/exit", StringComparison.OrdinalIgnoreCase))
break;

}
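For readers skimming the hunks above, the prompt handling in this example reduces to the following standalone sketch: image paths are embedded in the prompt inside curly braces, the first tag is replaced with the `<image>` placeholder, and any further tags are stripped. The path is a placeholder and the snippet is illustrative only.

```cs
// Standalone sketch of the brace/tag handling above (not part of the commit).
using System.Linq;
using System.Text.RegularExpressions;

var prompt = "{/tmp/cat.jpg}\nUSER:\nProvide a full description of the image.\nASSISTANT:\n"; // placeholder path

// Full matches keep the braces; group 1 is the bare path.
var tags  = Regex.Matches(prompt, "{([^}]*)}").Select(m => m.Value).ToList();
var paths = Regex.Matches(prompt, "{([^}]*)}").Select(m => m.Groups[1].Value).ToList();

// Only the first image keeps a placeholder tag; extra tags are removed.
var index = 0;
foreach (var tag in tags)
    prompt = prompt.Replace(tag, index++ == 0 ? "<image>" : "");
```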
2 changes: 1 addition & 1 deletion LLama/Abstractions/ILLamaExecutor.cs
@@ -25,7 +25,7 @@ public interface ILLamaExecutor
public LLavaWeights? ClipModel { get; }

/// <summary>
/// List of images: Image filen path, uri or image byte array. See ImageData.
/// List of images: List of images in byte array format.
/// </summary>
public List<byte[]> Images { get; }

2 changes: 1 addition & 1 deletion LLama/LLamaExecutorBase.cs
@@ -79,7 +79,7 @@ public bool IsMultiModal
public LLavaWeights? ClipModel { get; }

/// <inheritdoc />
public List<byte[]> Images { get; set; }
public List<byte[]> Images { get; }

/// <summary>
/// Current "mu" value for mirostat sampling
31 changes: 25 additions & 6 deletions LLama/LLamaInteractExecutor.cs
@@ -11,7 +11,7 @@
using LLama.Exceptions;
using LLama.Extensions;
using Microsoft.Extensions.Logging;
using System.Net.Http;


namespace LLama
{
@@ -136,20 +136,29 @@ protected override Task PreprocessInputs(string text, InferStateArgs args)
text += "\n";
}

var line_inp = Context.Tokenize(text, false);
_embed_inps.AddRange(line_inp);
args.RemainedTokens -= line_inp.Length;
if (!this.IsMultiModal)
{
var line_inp = Context.Tokenize(text, false);
_embed_inps.AddRange(line_inp);
args.RemainedTokens -= line_inp.Length;
}
else
{
PreprocessLlava(text, args, false);
}
}

return Task.CompletedTask;
}

/// <inheritdoc />
private Task PreprocessLlava(string text, InferStateArgs args, bool addBos = true )
{
int usedTokens = 0;

// If the prompt contains the tag <image> extract this.
_imageInPrompt = text.Contains("<image>");
if (_imageInPrompt && ClipModel != null)
if (_imageInPrompt && IsMultiModal )
{
foreach (var image in Images)
{
@@ -170,7 +179,16 @@ private Task PreprocessLlava(string text, InferStateArgs args, bool addBos = tru
}
else
{
_embed_inps = Context.Tokenize(text, true).ToList();
if (addBos)
{
_embed_inps = Context.Tokenize(text, true).ToList();
}
else
{
var line_inp = Context.Tokenize(text, false);
_embed_inps.AddRange(line_inp);
args.RemainedTokens -= line_inp.Length;
}
}
return Task.CompletedTask;
}
@@ -239,6 +257,7 @@ protected override Task InferInternal(IInferenceParams inferenceParams, InferSta

_EmbedImagePosition = -1;
_imageEmbedHandles.Clear();
Images.Clear();
}
else
{
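Putting the executor changes together: non-multimodal input is tokenized directly, multimodal input goes through `PreprocessLlava`, and `InferInternal` now clears `Images` once the image embeddings have been evaluated. A hedged end-to-end usage sketch follows; the `LLavaWeights.LoadFromFile` call, the `InteractiveExecutor(context, clipModel)` constructor and `InferAsync` are assumed from the surrounding example (they do not appear in the hunks above), and all file paths are placeholders.

```cs
// Hedged usage sketch, not part of the commit. Assumes the APIs named in the
// lead-in exist with these shapes; adjust paths and parameters to your setup.
using System;
using System.Collections.Generic;
using System.IO;
using LLama;
using LLama.Common;

var parameters = new ModelParams("<llava-model>.gguf");           // placeholder path
using var model = LLamaWeights.LoadFromFile(parameters);
using var context = model.CreateContext(parameters);
using var clipModel = LLavaWeights.LoadFromFile("<mmproj>.gguf");  // placeholder path

var ex = new InteractiveExecutor(context, clipModel);

// Supply the image as bytes and reference it with the <image> tag in the prompt.
ex.Images.Add(await File.ReadAllBytesAsync("cat.jpg"));            // placeholder image
var prompt = "<image>\nUSER:\nProvide a full description of the image.\nASSISTANT:\n";

var inferenceParams = new InferenceParams { AntiPrompts = new List<string> { "\nUSER:" } };
await foreach (var token in ex.InferAsync(prompt, inferenceParams))
    Console.Write(token);
```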
32 changes: 17 additions & 15 deletions docs/Examples/LLavaInteractiveModeExecute.md
@@ -2,9 +2,9 @@

```cs
using System.Text.RegularExpressions;
using LLama.Batched;
using LLama.Common;
using Spectre.Console;
using LLama.Native;

namespace LLama.Examples.Examples
{
@@ -21,11 +21,8 @@ namespace LLama.Examples.Examples

var prompt = $"{{{modelImage}}}\nUSER:\nProvide a full description of the image.\nASSISTANT:\n";

var parameters = new ModelParams(modelPath)
{
ContextSize = 4096,
Seed = 1337,
};
var parameters = new ModelParams(modelPath);

using var model = LLamaWeights.LoadFromFile(parameters);
using var context = model.CreateContext(parameters);

@@ -48,16 +45,16 @@ namespace LLama.Examples.Examples
var imageMatches = Regex.Matches(prompt, "{([^}]*)}").Select(m => m.Value);
var imageCount = imageMatches.Count();
var hasImages = imageCount > 0;
byte[][] imageBytes = null;

if (hasImages)
{
var imagePathsWithCurlyBraces = Regex.Matches(prompt, "{([^}]*)}").Select(m => m.Value);
var imagePaths = Regex.Matches(prompt, "{([^}]*)}").Select(m => m.Groups[1].Value);
var imagePaths = Regex.Matches(prompt, "{([^}]*)}").Select(m => m.Groups[1].Value).ToList();

List<byte[]> imageBytes;
try
{
imageBytes = imagePaths.Select(File.ReadAllBytes).ToArray();
imageBytes = imagePaths.Select(File.ReadAllBytes).ToList();
}
catch (IOException exception)
{
@@ -70,15 +67,17 @@ namespace LLama.Examples.Examples
break;
}

// Each prompt with images we clear cache
// When the prompt contains images we clear KV_CACHE to restart conversation
// See:
// https://github.com/ggerganov/llama.cpp/discussions/3620
ex.Context.NativeHandle.KvCacheRemove( LLamaSeqId.Zero, -1, -1 );

int index = 0;
foreach (var path in imagePathsWithCurlyBraces)
{
// First image replace to tag <image, the rest of the images delete the tag
if (index++ == 0)
prompt = prompt.Replace(path, "<image>");
else
prompt = prompt.Replace(path, "");
prompt = prompt.Replace(path, index++ == 0 ? "<image>" : "");
}


@@ -101,7 +100,10 @@ namespace LLama.Examples.Examples

// Initilize Images in executor
//
ex.ImagePaths = imagePaths.ToList();
foreach (var image in imagePaths)
{
ex.Images.Add(await File.ReadAllBytesAsync(image));
}
}

Console.ForegroundColor = Color.White;
@@ -116,7 +118,7 @@ namespace LLama.Examples.Examples

// let the user finish with exit
//
if (prompt.Equals("/exit", StringComparison.OrdinalIgnoreCase))
if (prompt != null && prompt.Equals("/exit", StringComparison.OrdinalIgnoreCase))
break;

}
4 changes: 2 additions & 2 deletions docs/Tutorials/Executors.md
@@ -28,9 +28,9 @@ public interface ILLamaExecutor
public LLavaWeights? ClipModel { get; }

/// <summary>
/// List of images: Image filename and path (jpeg images).
/// List of images: List of images in byte array format.
/// </summary>
public List<string> ImagePaths { get; set; }
public List<byte[]> Images { get; }


/// <summary>
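For code that consumed the old interface, the migration shown in this tutorial page is mechanical: read the file yourself and hand the executor the bytes. `executor` and the file name below are placeholders and the snippet is illustrative only. Note that `Images` is get-only, so the list is mutated rather than reassigned.

```cs
// Before (removed API):
// executor.ImagePaths = new List<string> { "image.jpg" };

// After this change (Images is a get-only List<byte[]>):
executor.Images.Add(await File.ReadAllBytesAsync("image.jpg"));
```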
