Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: allow customized search path for native library loading. #333

Merged
merged 2 commits into from
Nov 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 17 additions & 7 deletions LLama/ChatSession.cs
Original file line number Diff line number Diff line change
Expand Up @@ -152,14 +152,24 @@ public async IAsyncEnumerable<string> ChatAsync(string prompt, IInferenceParams?
foreach (var inputTransform in InputTransformPipeline)
prompt = inputTransform.Transform(prompt);

History.Messages.Add(new ChatHistory.Message(AuthorRole.User, prompt));

if (_executor is InteractiveExecutor executor)
// TODO: need to be refactored.
if (_executor is InteractiveExecutor executor && ((InteractiveExecutorState)executor.GetStateData()).IsPromptRun)
{
InteractiveExecutorState state = (InteractiveExecutorState)executor.GetStateData();
prompt = state.IsPromptRun
? HistoryTransform.HistoryToText(History)
: prompt;
History.Messages.Add(new ChatHistory.Message(AuthorRole.System, prompt));
var converted_prompt = HistoryTransform.HistoryToText(History);
// Avoid missing anti-prompt.
if (!prompt.EndsWith("\n") && !prompt.EndsWith("\r\n"))
{
prompt = converted_prompt.Trim();
}
else
{
prompt = converted_prompt;
}
}
else
{
History.Messages.Add(new ChatHistory.Message(AuthorRole.User, prompt));
}

StringBuilder sb = new();
Expand Down
20 changes: 13 additions & 7 deletions LLama/Native/NativeApi.Load.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
using System.Collections.Generic;
using System.Diagnostics;
using System.IO;
using System.Linq;
using System.Runtime.InteropServices;
using System.Text.Json;

Expand Down Expand Up @@ -128,7 +129,7 @@
{
return string.Empty;
}
return versionNode.GetString();

Check warning on line 132 in LLama/Native/NativeApi.Load.cs

View workflow job for this annotation

GitHub Actions / Test (linux-release)

Possible null reference return.

Check warning on line 132 in LLama/Native/NativeApi.Load.cs

View workflow job for this annotation

GitHub Actions / Test (windows-release)

Possible null reference return.
}
}
catch (Exception)
Expand Down Expand Up @@ -258,6 +259,7 @@
enableLogging = configuration.Logging;
// We move the flag to avoid loading library when the variable is called else where.
NativeLibraryConfig.LibraryHasLoaded = true;
Log(configuration.ToString(), LogLevel.Information);

if (!string.IsNullOrEmpty(configuration.Path))
{
Expand All @@ -273,26 +275,30 @@

var libraryTryLoadOrder = GetLibraryTryOrder(configuration);

string[] preferredPaths = configuration.SearchDirectories;
string[] possiblePathPrefix = new string[] {
System.AppDomain.CurrentDomain.BaseDirectory,
Path.GetDirectoryName(System.Reflection.Assembly.GetExecutingAssembly().Location) ?? ""
};

var tryFindPath = (string filename) =>
{
int i = 0;
while (!File.Exists(filename))
foreach(var path in preferredPaths)
{
if (i < possiblePathPrefix.Length)
if (File.Exists(Path.Combine(path, filename)))
{
filename = Path.Combine(possiblePathPrefix[i], filename);
i++;
return Path.Combine(path, filename);
}
else
}

foreach(var path in possiblePathPrefix)
{
if (File.Exists(Path.Combine(path, filename)))
{
break;
return Path.Combine(path, filename);
}
}

return filename;
};

Expand Down
73 changes: 70 additions & 3 deletions LLama/Native/NativeLibraryConfig.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
using System;
using System.Collections.Generic;
using System.Linq;

namespace LLama.Native
{
Expand Down Expand Up @@ -27,6 +29,10 @@ public sealed class NativeLibraryConfig
private bool _allowFallback = true;
private bool _skipCheck = false;
private bool _logging = false;
/// <summary>
/// search directory -> priority level, 0 is the lowest.
/// </summary>
private List<string> _searchDirectories = new List<string>();

private static void ThrowIfLoaded()
{
Expand Down Expand Up @@ -120,13 +126,50 @@ public NativeLibraryConfig WithLogs(bool enable = true)
return this;
}

/// <summary>
/// Add self-defined search directories. Note that the file stucture of the added
/// directories must be the same as the default directory. Besides, the directory
/// won't be used recursively.
/// </summary>
/// <param name="directories"></param>
/// <returns></returns>
public NativeLibraryConfig WithSearchDirectories(IEnumerable<string> directories)
{
ThrowIfLoaded();

_searchDirectories.AddRange(directories);
return this;
}

/// <summary>
/// Add self-defined search directories. Note that the file stucture of the added
/// directories must be the same as the default directory. Besides, the directory
/// won't be used recursively.
/// </summary>
/// <param name="directory"></param>
/// <returns></returns>
public NativeLibraryConfig WithSearchDirectory(string directory)
{
ThrowIfLoaded();

_searchDirectories.Add(directory);
return this;
}

internal static Description CheckAndGatherDescription()
{
if (Instance._allowFallback && Instance._skipCheck)
{
throw new ArgumentException("Cannot skip the check when fallback is allowed.");
}
return new Description(Instance._libraryPath, Instance._useCuda, Instance._avxLevel, Instance._allowFallback, Instance._skipCheck, Instance._logging);
return new Description(
Instance._libraryPath,
Instance._useCuda,
Instance._avxLevel,
Instance._allowFallback,
Instance._skipCheck,
Instance._logging,
Instance._searchDirectories.Concat(new string[] { "./" }).ToArray());
}

internal static string AvxLevelToString(AvxLevel level)
Expand Down Expand Up @@ -183,7 +226,31 @@ public enum AvxLevel
Avx512,
}

internal record Description(string Path, bool UseCuda, AvxLevel AvxLevel, bool AllowFallback, bool SkipCheck, bool Logging);
internal record Description(string Path, bool UseCuda, AvxLevel AvxLevel, bool AllowFallback, bool SkipCheck, bool Logging, string[] SearchDirectories)
{
public override string ToString()
{
string avxLevelString = AvxLevel switch
{
AvxLevel.None => "NoAVX",
AvxLevel.Avx => "AVX",
AvxLevel.Avx2 => "AVX2",
AvxLevel.Avx512 => "AVX512",
_ => "Unknown"
};

string searchDirectoriesString = "{ " + string.Join(", ", SearchDirectories) + " }";

return $"NativeLibraryConfig Description:\n" +
$"- Path: {Path}\n" +
$"- PreferCuda: {UseCuda}\n" +
$"- PreferredAvxLevel: {avxLevelString}\n" +
$"- AllowFallback: {AllowFallback}\n" +
$"- SkipCheck: {SkipCheck}\n" +
$"- Logging: {Logging}\n" +
$"- SearchDirectories and Priorities: {searchDirectoriesString}";
}
}
}
#endif
}
}
Loading