
Multi Context #90

Merged: 13 commits into SciSharp:master from proposal_multi_context, Aug 17, 2023
Conversation

@martindevans (Member) commented Aug 8, 2023

Added higher-level multi-context support: multiple contexts sharing the same model weights. Added a new demo, "TalkToYourself", where you can see this in use.

The biggest single change is renaming LLamaModel to LLamaContext, to more closely follow the naming of the original project.
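For illustration, a minimal sketch of the new flow, using the same @params placeholder and calls that appear in the memory-management examples later in this thread:

// Load the model weights once
using var weights = LLamaWeights.LoadFromFile(@params);

// Create several independent contexts, all sharing the same weights
using var ctx1 = weights.CreateContext(@params, Encoding.UTF8);
using var ctx2 = weights.CreateContext(@params, Encoding.UTF8);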

@saddam213 (Collaborator)

This approach works really well :)

As your code comments suggest, we will need a wrapper class for the ModelHandle; I personally think LLamaWeight is a good name.

We could even wrap LLamaWeight and a collection of LLamaContext at a slightly higher level to ensure all contexts are closed and disposed when a LLamaWeight is unloaded.

This would make more sense in another PR, as an appropriate name for that class would probably be LLamaModel, and doing it separately avoids nightmare merging further down the track. I think it would be nice if we managed it a bit for end users: most will be building chat-like apps, and tracking contexts will be a common feature.

Amazing work man!!
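A minimal sketch of the wrapper idea described above (all names hypothetical; neither this class nor the ModelParams type used here is part of this PR):

using System;
using System.Collections.Generic;
using System.Text;

// Hypothetical wrapper: owns the weights plus every context created
// from them, so disposing the wrapper tears everything down together.
public sealed class LLamaModel : IDisposable
{
    private readonly LLamaWeights _weights;
    private readonly List<LLamaContext> _contexts = new();

    public LLamaModel(LLamaWeights weights) => _weights = weights;

    public LLamaContext CreateContext(ModelParams @params, Encoding encoding)
    {
        var ctx = _weights.CreateContext(@params, encoding);
        _contexts.Add(ctx);
        return ctx;
    }

    public void Dispose()
    {
        // Dispose all contexts before releasing the weights
        foreach (var ctx in _contexts)
            ctx.Dispose();
        _contexts.Clear();
        _weights.Dispose();
    }
}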

@martindevans (Member, Author)

Thanks for the feedback! I'll continue developing this a bit more and get rid of some of those TODO comments.

@martindevans (Member, Author)

I think this is ready for review now

@martindevans (Member, Author)

Memory Management Thoughts

At the moment the SafeLlamaModelHandle is reference counted: the LLamaWeights object holds one reference, and each LLamaContext holds another. This means the weights will only be unloaded from memory once the LLamaWeights and all of the LLamaContext objects have been disposed.

// Load weights
var weights = LLamaWeights.LoadFromFile(@params);

// Unload weights
weights.Dispose();

// This will fail, because the weights are unloaded
var ctx = weights.CreateContext(@params, Encoding.UTF8);

However, this is a bit surprising:

// Load weights
var weights = LLamaWeights.LoadFromFile(@params);

// Create a context, increasing the reference count by one
var ctx = weights.CreateContext(@params, Encoding.UTF8);

// Unload weights
weights.Dispose();

// You can use ctx here!

// Now the weights will be unloaded
ctx.Dispose();

In some ways this is actually quite convenient: contexts automatically keep the weights loaded. However, in other ways it means that calling Dispose() on the weights does not free up all the memory, which is odd!

Should this be changed? i.e. should weights.Dispose() immediately unload them from memory and invalidate all the contexts?

@martindevans (Member, Author)

I added a couple of extra tests. Our test coverage is still terrible, but it's a start!

@SignalRT (Collaborator) commented Aug 9, 2023

I tested it on macOS running on the GPU (Metal) and this test fails:

[Fact]
public void EmbedCompare()
{
    var cat = _embedder.GetEmbeddings("cat");
    var kitten = _embedder.GetEmbeddings("kitten");
    var spoon = _embedder.GetEmbeddings("spoon");

    var close = Dot(cat, kitten);
    var far = Dot(cat, spoon);

    Assert.True(close < far);
}
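The Dot helper isn't shown in the snippet; a plain dot-product version (assuming GetEmbeddings returns float[]) might look like the sketch below. Note that a raw dot product is a similarity measure, where higher means closer, so the direction of the assertion depends on how Dot is actually defined.

// Sketch of a plain dot product over two embedding vectors
private static float Dot(float[] a, float[] b)
{
    if (a.Length != b.Length)
        throw new ArgumentException("Vectors must have the same length");

    var sum = 0f;
    for (var i = 0; i < a.Length; i++)
        sum += a[i] * b[i];
    return sum;
}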

I will check whether it works on CPU and where the problem is.

@martindevans (Member, Author)

Thanks for testing; I have no way to test these things on macOS.

If you're debugging it, try printing out the generated vectors; my best guess is that it's returning all zeros for some reason.
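For example, a quick (hypothetical) check for the all-zeros case, reusing the _embedder from the test above:

// Print the first few values and check for an all-zero embedding
// (requires using System; using System.Linq;)
var cat = _embedder.GetEmbeddings("cat");
Console.WriteLine(string.Join(",", cat.Take(8)));
Console.WriteLine(cat.All(v => v == 0) ? "all zeros!" : "non-zero values present");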

@SignalRT (Collaborator) commented Aug 9, 2023

I debugged this, and close (13.xxx) is not less than far (8.yyy). I will try to debug this properly in the next few days.

@martindevans (Member, Author)

Huh, in that case it must just be returning the wrong vectors when the GPU is in use, which is very concerning!

In fact it might be good to add a test for that: generate a few "known good" vectors and then check that those exact vectors are produced in a unit test.
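A sketch of such a regression test (the expected values below are placeholders, to be filled in from a known-good run):

[Fact]
public void EmbeddingMatchesKnownGoodVector()
{
    // Placeholder: paste the first few values from a known-good run here
    var expected = new float[] { /* ... */ };

    var actual = _embedder.GetEmbeddings("cat");

    for (var i = 0; i < expected.Length; i++)
        Assert.Equal(expected[i], actual[i], 3); // compare to 3 decimal places
}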

@SignalRT (Collaborator) commented Aug 9, 2023

Same results with the CPU, so it's not a GPU/CPU problem.

I will review the values returned by the _embedder.GetEmbeddings invocation to compare the differences between platforms.

@martindevans (Member, Author)

@SignalRT do you see the same problems with the code in master, by the way? If so, then at least the macOS test problem isn't an issue with this PR.

@SignalRT (Collaborator)

@martindevans I don't know if this is a macOS-specific issue. This is the output of the first values of "cat" on the three OSes:

[screenshot: first embedding values of "cat" on Windows, Linux and macOS]

On each OS, Debug and Release executions of the test (i.e. separate runs) produce the same values.

@martindevans (Member, Author)

Well, that's clearly broken! I'm fairly confident it's not a problem with this PR at least - I've barely touched the Embedder; it's even using the old model loading method!

@SignalRT (Collaborator)

I will try to look into this issue.

@martindevans (Member, Author) commented Aug 10, 2023

I just added an output to the CI tests to grab the values for the "cat" embedding. They're the same as you reported.

CI Output:

  • Windows: -0.12730388,-0.67805725,-0.08524404,-0.9569152,-0.6386326...
  • Linux: -0.09917596,-0.71790683,-0.008531962,-0.9898389,-0.66339684...

@martindevans (Member, Author)

I've added this test into another PR on top of master (#97), so we can see whether this issue exists independently of this PR. If the issue reproduces in the other PR, I'll remove the test from this one.

@martindevans (Member, Author)

OK, I got exactly the same results in the new PR, so I'm going to remove the test from this PR to unblock it.

@SignalRT (Collaborator)

@martindevans It's happening in master. I used the same approach in my fork.

Commit notes from this PR:

…in use in `TalkToYourself`, along with notes on what still needs improving. The biggest single change is renaming `LLamaModel` to `LLamaContext`.
- Sanity checking that weights are not disposed when creating a context from them
- Further simplified `Utils.InitLLamaContextFromModelParams`
- Sealed some classes not intended to be extended
@martindevans (Member, Author)

@AsakusaRinne this is ready for your review. It's a pretty big change, so I've held off merging it myself.

@AsakusaRinne (Collaborator)

Thank you for all your contributions! Does the difference in embeddings between Windows, macOS and Linux matter? I'm not sure whether it affects the performance of the model.

[Review comment on LLama/LLamaContext.cs, resolved]
@AsakusaRinne (Collaborator)

LGTM, it's really a good job! @Oceania2018 In this PR LLamaModel is renamed to LLamaContext, along with some API changes. Will it have much impact on BotSharp?

@martindevans (Member, Author)

> Thank you for all your contributions! Does the difference in embeddings between Windows, macOS and Linux matter? I'm not sure whether it affects the performance of the model.

We've tracked that down to an issue with llama.cpp itself, so it's not a problem with this PR. PR #97 has a test which fails due to that bug, and saddam213 reported it upstream in ggerganov/llama.cpp#2582.

@AsakusaRinne merged commit 6233185 into SciSharp:master on Aug 17, 2023 (4 checks passed).

@martindevans deleted the proposal_multi_context branch on August 17, 2023.