From 024787225be3754bb13f469f944251e8a98db3c6 Mon Sep 17 00:00:00 2001 From: Martin Evans Date: Sun, 17 Mar 2024 19:54:20 +0000 Subject: [PATCH] `SetDllImportResolver` based loading (#603) - Modified library loading to be based on `SetDllImportResolver`. This replaces the built in loading system and ensures there can't be two libraries loaded at once. - llava and llama are loaded separately, as needed. - All the previous loading logic is still used, within the `SetDllImportResolver` - Split out CUDA, AVX and MacOS paths to separate helper methods. - `Description` now specifies if it is for `llama` or `llava` --- LLama.Examples/Program.cs | 7 +- LLama/Native/NativeApi.Load.cs | 392 +++++++++++++++++----------- LLama/Native/NativeLibraryConfig.cs | 59 ++++- 3 files changed, 298 insertions(+), 160 deletions(-) diff --git a/LLama.Examples/Program.cs b/LLama.Examples/Program.cs index 54d40fb54..f8c7ba608 100644 --- a/LLama.Examples/Program.cs +++ b/LLama.Examples/Program.cs @@ -16,11 +16,14 @@ __ __ ____ __ """); +// Configure native library to use NativeLibraryConfig .Instance .WithCuda() - .WithLogs(LLamaLogLevel.Warning); + .WithLogs(LLamaLogLevel.Info); +// Calling this method forces loading to occur now. NativeApi.llama_empty_call(); -await ExampleRunner.Run(); \ No newline at end of file +await ExampleRunner.Run(); + diff --git a/LLama/Native/NativeApi.Load.cs b/LLama/Native/NativeApi.Load.cs index 1b1868161..28f0564d9 100644 --- a/LLama/Native/NativeApi.Load.cs +++ b/LLama/Native/NativeApi.Load.cs @@ -1,11 +1,9 @@ using LLama.Exceptions; -using Microsoft.Extensions.Logging; using System; -using System.Collections.Generic; -using System.Diagnostics; using System.IO; using System.Runtime.InteropServices; using System.Text.Json; +using System.Collections.Generic; namespace LLama.Native { @@ -13,9 +11,14 @@ public static partial class NativeApi { static NativeApi() { - // Try to load a preferred library, based on CPU feature detection - TryLoadLibrary(); - + // Overwrite the Dll import resolver for this assembly. The resolver gets + // called by the runtime every time that a call into a DLL is required. The + // resolver returns the loaded DLL handle. This allows us to take control of + // which llama.dll is used. + SetDllImportResolver(); + + // Immediately make a call which requires loading the llama DLL. This method call + // can't fail unless the DLL hasn't been loaded. try { llama_empty_call(); @@ -30,39 +33,97 @@ static NativeApi() "4. Try to compile llama.cpp yourself to generate a libllama library, then use `LLama.Native.NativeLibraryConfig.WithLibrary` " + "to specify it at the very beginning of your code. For more informations about compilation, please refer to LLamaSharp repo on github.\n"); } + + // Init llama.cpp backend llama_backend_init(); } - private static void Log(string message, LogLevel level) +#if NET5_0_OR_GREATER + private static IntPtr _loadedLlamaHandle; + private static IntPtr _loadedLlavaSharedHandle; +#endif + + private static void SetDllImportResolver() + { + // NativeLibrary is not available on older runtimes. We'll have to depend on + // the normal runtime dll resolution there. +#if NET5_0_OR_GREATER + NativeLibrary.SetDllImportResolver(typeof(NativeApi).Assembly, (name, _, _) => + { + if (name == "llama") + { + // If we've already loaded llama return the handle that was loaded last time. + if (_loadedLlamaHandle != IntPtr.Zero) + return _loadedLlamaHandle; + + // Try to load a preferred library, based on CPU feature detection + _loadedLlamaHandle = TryLoadLibraries(LibraryName.Llama); + return _loadedLlamaHandle; + } + + if (name == "llava_shared") + { + // If we've already loaded llava return the handle that was loaded last time. + if (_loadedLlavaSharedHandle != IntPtr.Zero) + return _loadedLlavaSharedHandle; + + // Try to load a preferred library, based on CPU feature detection + _loadedLlavaSharedHandle = TryLoadLibraries(LibraryName.LlavaShared); + return _loadedLlavaSharedHandle; + } + + // Return null pointer to indicate that nothing was loaded. + return IntPtr.Zero; + }); +#endif + } + + private static void Log(string message, LLamaLogLevel level) { if (!enableLogging) return; - if ((int)level < (int)logLevel) + if ((int)level > (int)logLevel) return; - ConsoleColor color; - string levelPrefix; - if (level == LogLevel.Information) - { - color = ConsoleColor.Green; - levelPrefix = "[Info]"; - } - else if (level == LogLevel.Error) + var fg = Console.ForegroundColor; + var bg = Console.BackgroundColor; + try { - color = ConsoleColor.Red; - levelPrefix = "[Error]"; + ConsoleColor color; + string levelPrefix; + if (level == LLamaLogLevel.Debug) + { + color = ConsoleColor.Cyan; + levelPrefix = "[Debug]"; + } + else if (level == LLamaLogLevel.Info) + { + color = ConsoleColor.Green; + levelPrefix = "[Info]"; + } + else if (level == LLamaLogLevel.Error) + { + color = ConsoleColor.Red; + levelPrefix = "[Error]"; + } + else + { + color = ConsoleColor.Yellow; + levelPrefix = "[UNK]"; + } + + Console.ForegroundColor = color; + Console.WriteLine($"{loggingPrefix} {levelPrefix} {message}"); } - else + finally { - color = ConsoleColor.Yellow; - levelPrefix = "[Error]"; + Console.ForegroundColor = fg; + Console.BackgroundColor = bg; } - Console.ForegroundColor = color; - Console.WriteLine($"{loggingPrefix} {levelPrefix} {message}"); - Console.ResetColor(); } + #region CUDA version private static int GetCudaMajorVersion() { string? cudaPath; @@ -131,65 +192,33 @@ private static string GetCudaVersionFromPath(string cudaPath) return string.Empty; } } + #endregion #if NET6_0_OR_GREATER - private static string GetAvxLibraryPath(NativeLibraryConfig.AvxLevel avxLevel, string prefix, string suffix, string libraryNamePrefix) - { - var avxStr = NativeLibraryConfig.AvxLevelToString(avxLevel); - if (!string.IsNullOrEmpty(avxStr)) - { - avxStr += "/"; - } - return $"{prefix}{avxStr}{libraryNamePrefix}{libraryName}{suffix}"; - } - - private static List GetLibraryTryOrder(NativeLibraryConfig.Description configuration) + private static IEnumerable GetLibraryTryOrder(NativeLibraryConfig.Description configuration) { - OSPlatform platform; - string prefix, suffix, libraryNamePrefix; - if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) - { - platform = OSPlatform.Windows; - prefix = "runtimes/win-x64/native/"; - suffix = ".dll"; - libraryNamePrefix = ""; - } - else if (RuntimeInformation.IsOSPlatform(OSPlatform.Linux)) - { - platform = OSPlatform.Linux; - prefix = "runtimes/linux-x64/native/"; - suffix = ".so"; - libraryNamePrefix = "lib"; - } - else if (RuntimeInformation.IsOSPlatform(OSPlatform.OSX)) - { - platform = OSPlatform.OSX; - suffix = ".dylib"; + var loadingName = configuration.Library.GetLibraryName(); + Log($"Loading library: '{loadingName}'", LLamaLogLevel.Debug); - prefix = System.Runtime.Intrinsics.Arm.ArmBase.Arm64.IsSupported - ? "runtimes/osx-arm64/native/" - : "runtimes/osx-x64/native/"; - libraryNamePrefix = "lib"; - } - else - { - throw new RuntimeError("Your system plarform is not supported, please open an issue in LLamaSharp."); - } - Log($"Detected OS Platform: {platform}", LogLevel.Information); + // Get platform specific parts of the path (e.g. .so/.dll/.dylib, libName prefix or not) + GetPlatformPathParts(out var platform, out var os, out var ext, out var libPrefix); + Log($"Detected OS Platform: '{platform}'", LLamaLogLevel.Info); + Log($"Detected OS string: '{os}'", LLamaLogLevel.Debug); + Log($"Detected extension string: '{ext}'", LLamaLogLevel.Debug); + Log($"Detected prefix string: '{libPrefix}'", LLamaLogLevel.Debug); - List result = new(); - if (configuration.UseCuda && (platform == OSPlatform.Windows || platform == OSPlatform.Linux)) // no cuda on macos + if (configuration.UseCuda && (platform == OSPlatform.Windows || platform == OSPlatform.Linux)) { - int cudaVersion = GetCudaMajorVersion(); + var cudaVersion = GetCudaMajorVersion(); + Log($"Detected cuda major version {cudaVersion}.", LLamaLogLevel.Info); - // TODO: load cuda library with avx if (cudaVersion == -1 && !configuration.AllowFallback) { // if check skipped, we just try to load cuda libraries one by one. if (configuration.SkipCheck) { - result.Add($"{prefix}cuda12/{libraryNamePrefix}{libraryName}{suffix}"); - result.Add($"{prefix}cuda11/{libraryNamePrefix}{libraryName}{suffix}"); + yield return GetCudaLibraryPath(loadingName, "cuda12"); + yield return GetCudaLibraryPath(loadingName, "cuda11"); } else { @@ -198,121 +227,167 @@ private static List GetLibraryTryOrder(NativeLibraryConfig.Description c } else if (cudaVersion == 11) { - Log($"Detected cuda major version {cudaVersion}.", LogLevel.Information); - result.Add($"{prefix}cuda11/{libraryNamePrefix}{libraryName}{suffix}"); + yield return GetCudaLibraryPath(loadingName, "cuda11"); } else if (cudaVersion == 12) { - Log($"Detected cuda major version {cudaVersion}.", LogLevel.Information); - result.Add($"{prefix}cuda12/{libraryNamePrefix}{libraryName}{suffix}"); + yield return GetCudaLibraryPath(loadingName, "cuda12"); } else if (cudaVersion > 0) { throw new RuntimeError($"Cuda version {cudaVersion} hasn't been supported by LLamaSharp, please open an issue for it."); } + // otherwise no cuda detected but allow fallback } - // use cpu (or mac possibly with metal) - if (!configuration.AllowFallback && platform != OSPlatform.OSX) + // Add the CPU/Metal libraries + if (platform == OSPlatform.OSX) { - result.Add(GetAvxLibraryPath(configuration.AvxLevel, prefix, suffix, libraryNamePrefix)); + // On Mac it's very simple, there's no AVX to consider. + yield return GetMacLibraryPath(loadingName); } - else if (platform != OSPlatform.OSX) // in macos there's absolutely no avx + else { - if (configuration.AvxLevel >= NativeLibraryConfig.AvxLevel.Avx512) - result.Add(GetAvxLibraryPath(NativeLibraryConfig.AvxLevel.Avx512, prefix, suffix, libraryNamePrefix)); + if (configuration.AllowFallback) + { + // Try all of the AVX levels we can support. + if (configuration.AvxLevel >= NativeLibraryConfig.AvxLevel.Avx512) + yield return GetAvxLibraryPath(loadingName, NativeLibraryConfig.AvxLevel.Avx512); - if (configuration.AvxLevel >= NativeLibraryConfig.AvxLevel.Avx2) - result.Add(GetAvxLibraryPath(NativeLibraryConfig.AvxLevel.Avx2, prefix, suffix, libraryNamePrefix)); + if (configuration.AvxLevel >= NativeLibraryConfig.AvxLevel.Avx2) + yield return GetAvxLibraryPath(loadingName, NativeLibraryConfig.AvxLevel.Avx2); - if (configuration.AvxLevel >= NativeLibraryConfig.AvxLevel.Avx) - result.Add(GetAvxLibraryPath(NativeLibraryConfig.AvxLevel.Avx, prefix, suffix, libraryNamePrefix)); + if (configuration.AvxLevel >= NativeLibraryConfig.AvxLevel.Avx) + yield return GetAvxLibraryPath(loadingName, NativeLibraryConfig.AvxLevel.Avx); - result.Add(GetAvxLibraryPath(NativeLibraryConfig.AvxLevel.None, prefix, suffix, libraryNamePrefix)); + yield return GetAvxLibraryPath(loadingName, NativeLibraryConfig.AvxLevel.None); + } + else + { + // Fallback is not allowed - use the exact specified AVX level + yield return GetAvxLibraryPath(loadingName, configuration.AvxLevel); + } } + } - if (platform == OSPlatform.OSX) + private static string GetMacLibraryPath(string libraryName) + { + GetPlatformPathParts(out _, out var os, out var fileExtension, out var libPrefix); + + return $"runtimes/{os}/native/{libPrefix}{libraryName}{fileExtension}"; + } + + /// + /// Given a CUDA version and some path parts, create a complete path to the library file + /// + /// Library being loaded (e.g. "llama") + /// CUDA version (e.g. "cuda11") + /// + private static string GetCudaLibraryPath(string libraryName, string cuda) + { + GetPlatformPathParts(out _, out var os, out var fileExtension, out var libPrefix); + + return $"runtimes/{os}/native/{cuda}/{libPrefix}{libraryName}{fileExtension}"; + } + + /// + /// Given an AVX level and some path parts, create a complete path to the library file + /// + /// Library being loaded (e.g. "llama") + /// + /// + private static string GetAvxLibraryPath(string libraryName, NativeLibraryConfig.AvxLevel avx) + { + GetPlatformPathParts(out _, out var os, out var fileExtension, out var libPrefix); + + var avxStr = NativeLibraryConfig.AvxLevelToString(avx); + if (!string.IsNullOrEmpty(avxStr)) + avxStr += "/"; + + return $"runtimes/{os}/native/{avxStr}{libPrefix}{libraryName}{fileExtension}"; + } + + private static void GetPlatformPathParts(out OSPlatform platform, out string os, out string fileExtension, out string libPrefix) + { + if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) { - result.Add($"{prefix}{libraryNamePrefix}{libraryName}{suffix}"); - result.Add($"{prefix}{libraryNamePrefix}{llavaLibraryName}{suffix}"); + platform = OSPlatform.Windows; + os = "win-x64"; + fileExtension = ".dll"; + libPrefix = ""; + return; + } + + if (RuntimeInformation.IsOSPlatform(OSPlatform.Linux)) + { + platform = OSPlatform.Linux; + os = "linux-x64"; + fileExtension = ".so"; + libPrefix = "lib"; + return; } + + if (RuntimeInformation.IsOSPlatform(OSPlatform.OSX)) + { + platform = OSPlatform.OSX; + fileExtension = ".dylib"; - return result; + os = System.Runtime.Intrinsics.Arm.ArmBase.Arm64.IsSupported + ? "osx-arm64" + : "osx-x64"; + libPrefix = "lib"; + } + else + { + throw new RuntimeError("Your operating system is not supported, please open an issue in LLamaSharp."); + } } #endif /// - /// Try to load libllama, using CPU feature detection to try and load a more specialised DLL if possible + /// Try to load libllama/llava_shared, using CPU feature detection to try and load a more specialised DLL if possible /// /// The library handle to unload later, or IntPtr.Zero if no library was loaded - private static IntPtr TryLoadLibrary() + private static IntPtr TryLoadLibraries(LibraryName lib) { #if NET6_0_OR_GREATER - var configuration = NativeLibraryConfig.CheckAndGatherDescription(); + var configuration = NativeLibraryConfig.CheckAndGatherDescription(lib); enableLogging = configuration.Logging; logLevel = configuration.LogLevel; - // We move the flag to avoid loading library when the variable is called else where. + + // Set the flag to ensure the NativeLibraryConfig can no longer be modified NativeLibraryConfig.LibraryHasLoaded = true; - Log(configuration.ToString(), LogLevel.Information); + // Show the configuration we're working with + Log(configuration.ToString(), LLamaLogLevel.Info); + + // If a specific path is requested, load that or immediately fail if (!string.IsNullOrEmpty(configuration.Path)) { - // When loading the user specified library, there's no fallback. - var success = NativeLibrary.TryLoad(configuration.Path, out var result); - if (!success) - { + if (!NativeLibrary.TryLoad(configuration.Path, out var handle)) throw new RuntimeError($"Failed to load the native library [{configuration.Path}] you specified."); - } - Log($"Successfully loaded the library [{configuration.Path}] specified by user", LogLevel.Information); - return result; + + Log($"Successfully loaded the library [{configuration.Path}] specified by user", LLamaLogLevel.Info); + return handle; } + // Get a list of locations to try loading (in order of preference) var libraryTryLoadOrder = GetLibraryTryOrder(configuration); - var preferredPaths = configuration.SearchDirectories; - var possiblePathPrefix = new[] { - AppDomain.CurrentDomain.BaseDirectory, - Path.GetDirectoryName(System.Reflection.Assembly.GetExecutingAssembly().Location) ?? "" - }; - - string TryFindPath(string filename) - { - foreach (var path in preferredPaths) - { - if (File.Exists(Path.Combine(path, filename))) - { - return Path.Combine(path, filename); - } - } - - foreach (var path in possiblePathPrefix) - { - if (File.Exists(Path.Combine(path, filename))) - { - return Path.Combine(path, filename); - } - } - - return filename; - } - foreach (var libraryPath in libraryTryLoadOrder) { var fullPath = TryFindPath(libraryPath); - var result = TryLoad(fullPath, true); - if (result is not null && result != IntPtr.Zero) + Log($"Trying '{fullPath}'", LLamaLogLevel.Debug); + + var result = TryLoad(fullPath); + if (result != IntPtr.Zero) { - Log($"{fullPath} is selected and loaded successfully.", LogLevel.Information); - - // One we have clear the detection and that llama loads successfully we load LLaVa if exist on the - // same path. - TryLoad( libraryPath.Replace("llama", "llava_shared"), true); - - return (IntPtr)result; + Log($"Loaded '{fullPath}'", LLamaLogLevel.Info); + return result; } - Log($"Tried to load {fullPath} but failed.", LogLevel.Information); + Log($"Failed Loading '{fullPath}'", LLamaLogLevel.Info); } if (!configuration.AllowFallback) @@ -325,20 +400,45 @@ string TryFindPath(string filename) #endif Log($"No library was loaded before calling native apis. " + - $"This is not an error under netstandard2.0 but needs attention with net6 or higher.", LogLevel.Warning); + $"This is not an error under netstandard2.0 but needs attention with net6 or higher.", LLamaLogLevel.Warning); return IntPtr.Zero; #if NET6_0_OR_GREATER - // Try to load a DLL from the path if supported. Returns null if nothing is loaded. - static IntPtr? TryLoad(string path, bool supported = true) + // Try to load a DLL from the path. + // Returns null if nothing is loaded. + static IntPtr TryLoad(string path) { - if (!supported) - return null; - if (NativeLibrary.TryLoad(path, out var handle)) return handle; - return null; + return IntPtr.Zero; + } + + // Try to find the given file in any of the possible search paths + string TryFindPath(string filename) + { + // Try the configured search directories in the configuration + foreach (var path in configuration.SearchDirectories) + { + var candidate = Path.Combine(path, filename); + if (File.Exists(candidate)) + return candidate; + } + + // Try a few other possible paths + var possiblePathPrefix = new[] { + AppDomain.CurrentDomain.BaseDirectory, + Path.GetDirectoryName(System.Reflection.Assembly.GetExecutingAssembly().Location) ?? "" + }; + + foreach (var path in possiblePathPrefix) + { + var candidate = Path.Combine(path, filename); + if (File.Exists(candidate)) + return candidate; + } + + return filename; } #endif } diff --git a/LLama/Native/NativeLibraryConfig.cs b/LLama/Native/NativeLibraryConfig.cs index 19be7c848..c08749ba9 100644 --- a/LLama/Native/NativeLibraryConfig.cs +++ b/LLama/Native/NativeLibraryConfig.cs @@ -11,19 +11,19 @@ namespace LLama.Native /// public sealed class NativeLibraryConfig { - private static readonly Lazy _instance = new(() => new NativeLibraryConfig()); - /// /// Get the config instance /// - public static NativeLibraryConfig Instance => _instance.Value; + public static NativeLibraryConfig Instance { get; } = new(); /// - /// Whether there's already a config for native library. + /// Check if the native library has already been loaded. Configuration cannot be modified if this is true. /// public static bool LibraryHasLoaded { get; internal set; } = false; - private string _libraryPath = string.Empty; + private string? _libraryPath; + private string? _libraryPathLLava; + private bool _useCuda = true; private AvxLevel _avxLevel; private bool _allowFallback = true; @@ -42,17 +42,20 @@ private static void ThrowIfLoaded() throw new InvalidOperationException("NativeLibraryConfig must be configured before using **any** other LLamaSharp methods!"); } + #region configurators /// /// Load a specified native library as backend for LLamaSharp. /// When this method is called, all the other configurations will be ignored. /// - /// + /// The full path to the llama library to load. + /// The full path to the llava library to load. /// Thrown if `LibraryHasLoaded` is true. - public NativeLibraryConfig WithLibrary(string libraryPath) + public NativeLibraryConfig WithLibrary(string? llamaPath, string? llavaPath) { ThrowIfLoaded(); - _libraryPath = libraryPath; + _libraryPath = llamaPath; + _libraryPathLLava = llavaPath; return this; } @@ -172,14 +175,23 @@ public NativeLibraryConfig WithSearchDirectory(string directory) _searchDirectories.Add(directory); return this; } + #endregion - internal static Description CheckAndGatherDescription() + internal static Description CheckAndGatherDescription(LibraryName library) { if (Instance._allowFallback && Instance._skipCheck) throw new ArgumentException("Cannot skip the check when fallback is allowed."); + var path = library switch + { + LibraryName.Llama => Instance._libraryPath, + LibraryName.LlavaShared => Instance._libraryPathLLava, + _ => throw new ArgumentException($"Unknown library name '{library}'", nameof(library)), + }; + return new Description( - Instance._libraryPath, + path, + library, Instance._useCuda, Instance._avxLevel, Instance._allowFallback, @@ -267,7 +279,7 @@ public enum AvxLevel Avx512, } - internal record Description(string Path, bool UseCuda, AvxLevel AvxLevel, bool AllowFallback, bool SkipCheck, bool Logging, LLamaLogLevel LogLevel, string[] SearchDirectories) + internal record Description(string? Path, LibraryName Library, bool UseCuda, AvxLevel AvxLevel, bool AllowFallback, bool SkipCheck, bool Logging, LLamaLogLevel LogLevel, string[] SearchDirectories) { public override string ToString() { @@ -283,7 +295,8 @@ public override string ToString() string searchDirectoriesString = "{ " + string.Join(", ", SearchDirectories) + " }"; return $"NativeLibraryConfig Description:\n" + - $"- Path: {Path}\n" + + $"- LibraryName: {Library}\n" + + $"- Path: '{Path}'\n" + $"- PreferCuda: {UseCuda}\n" + $"- PreferredAvxLevel: {avxLevelString}\n" + $"- AllowFallback: {AllowFallback}\n" + @@ -295,4 +308,26 @@ public override string ToString() } } #endif + + internal enum LibraryName + { + Llama, + LlavaShared + } + + internal static class LibraryNameExtensions + { + public static string GetLibraryName(this LibraryName name) + { + switch (name) + { + case LibraryName.Llama: + return NativeApi.libraryName; + case LibraryName.LlavaShared: + return NativeApi.llavaLibraryName; + default: + throw new ArgumentOutOfRangeException(nameof(name), name, null); + } + } + } }