diff --git a/Editor/LLMEditor.cs b/Editor/LLMEditor.cs index ed50046d..9601579f 100644 --- a/Editor/LLMEditor.cs +++ b/Editor/LLMEditor.cs @@ -74,6 +74,7 @@ public void AddModelSettings(SerializedObject llmScriptSO) if (llmScriptSO.FindProperty("advancedOptions").boolValue) { attributeClasses.Add(typeof(ModelAdvancedAttribute)); + if (LLMUnitySetup.FullLlamaLib) attributeClasses.Add(typeof(ModelExtrasAttribute)); } ShowPropertiesOfClass("", llmScriptSO, attributeClasses, false); Space(); diff --git a/Editor/PropertyEditor.cs b/Editor/PropertyEditor.cs index 87a40938..d3510ac8 100644 --- a/Editor/PropertyEditor.cs +++ b/Editor/PropertyEditor.cs @@ -16,17 +16,11 @@ public void AddScript(SerializedObject llmScriptSO) EditorGUILayout.PropertyField(scriptProp); } - public void AddOptionsToggle(SerializedObject llmScriptSO, string propertyName, string name) + public bool ToggleButton(string text, bool activated) { - SerializedProperty advancedOptionsProp = llmScriptSO.FindProperty(propertyName); - string toggleText = (advancedOptionsProp.boolValue ? "Hide" : "Show") + " " + name; GUIStyle style = new GUIStyle("Button"); - if (advancedOptionsProp.boolValue) - style.normal = new GUIStyleState() { background = Texture2D.grayTexture }; - if (GUILayout.Button(toggleText, style, GUILayout.Width(buttonWidth))) - { - advancedOptionsProp.boolValue = !advancedOptionsProp.boolValue; - } + if (activated) style.normal = new GUIStyleState() { background = Texture2D.grayTexture }; + return GUILayout.Button(text, style, GUILayout.Width(buttonWidth)); } public void AddSetupSettings(SerializedObject llmScriptSO) @@ -54,8 +48,12 @@ public void AddChatSettings(SerializedObject llmScriptSO) public void AddOptionsToggles(SerializedObject llmScriptSO) { LLMUnitySetup.SetDebugMode((LLMUnitySetup.DebugModeType)EditorGUILayout.EnumPopup("Log Level", LLMUnitySetup.DebugMode)); + EditorGUILayout.BeginHorizontal(); - AddOptionsToggle(llmScriptSO, "advancedOptions", "Advanced Options"); + SerializedProperty advancedOptionsProp = llmScriptSO.FindProperty("advancedOptions"); + string toggleText = (advancedOptionsProp.boolValue ? "Hide" : "Show") + " Advanced Options"; + if (ToggleButton(toggleText, advancedOptionsProp.boolValue)) advancedOptionsProp.boolValue = !advancedOptionsProp.boolValue; + if (ToggleButton("Use extras", LLMUnitySetup.FullLlamaLib)) LLMUnitySetup.SetFullLlamaLib(!LLMUnitySetup.FullLlamaLib); EditorGUILayout.EndHorizontal(); Space(); } diff --git a/README.md b/README.md index d8e79bba..bdf6350b 100644 --- a/README.md +++ b/README.md @@ -345,6 +345,7 @@ If you have loaded a model locally you need to set its URL through the expanded - `Show/Hide Advanced Options` Toggle to show/hide advanced options from below - `Log Level` select how verbose the log messages are +- `Use extras` select to install and allow the use of extra features (flash attention and IQ quants) #### 💻 Setup Settings @@ -381,6 +382,7 @@ If the user's GPU is not supported, the LLM will fall back to the CPU - `Model` the path of the model being used (relative to the Assets/StreamingAssets folder) - `Chat Template` the chat template being used for the LLM - `Lora` the path of the LoRA being used (relative to the Assets/StreamingAssets folder) + - `Flash Attention` click to use flash attention in the model (if `Use extras` is enabled) @@ -395,6 +397,7 @@ If the user's GPU is not supported, the LLM will fall back to the CPU - `Show/Hide Advanced Options` Toggle to show/hide advanced options from below - `Log Level` select how verbose the log messages are +- `Use extras` select to install and allow the use of extra features (flash attention and IQ quants) #### 💻 Setup Settings
diff --git a/Runtime/LLM.cs b/Runtime/LLM.cs index 54395e8c..9e50c513 100644 --- a/Runtime/LLM.cs +++ b/Runtime/LLM.cs @@ -22,7 +22,7 @@ public LLMException(string message, int errorCode) : base(message) } } - public class DestroyException : Exception { } + public class DestroyException : Exception {} /// \endcond [DefaultExecutionOrder(-1)] @@ -74,6 +74,8 @@ public class LLM : MonoBehaviour /// the paths of the LORA models being used (relative to the Assets/StreamingAssets folder). /// Models with .gguf format are allowed. [ModelAdvanced] public string lora = ""; + /// enable use of flash attention + [ModelExtras] public bool flashAttention = false; /// \cond HIDE @@ -297,6 +299,7 @@ protected virtual string GetLlamaccpArguments() if (numThreadsToUse > 0) arguments += $" -t {numThreadsToUse}"; arguments += loraArgument; arguments += $" -ngl {numGPULayers}"; + if (LLMUnitySetup.FullLlamaLib && flashAttention) arguments += $" --flash-attn"; return arguments; } @@ -383,7 +386,7 @@ private void StartService() { llmThread = new Thread(() => llmlib.LLM_Start(LLMObject)); llmThread.Start(); - while (!llmlib.LLM_Started(LLMObject)) { } + while (!llmlib.LLM_Started(LLMObject)) {} loraWeights = new List(); for (int i = 0; i < lora.Split(" ").Count(); i++) loraWeights.Add(1f); started = true; @@ -607,7 +610,7 @@ public async Task Slot(string json) public async Task Completion(string json, Callback streamCallback = null) { AssertStarted(); - if (streamCallback == null) streamCallback = (string s) => { }; + if (streamCallback == null) streamCallback = (string s) => {}; StreamWrapper streamWrapper = ConstructStreamWrapper(streamCallback); await Task.Run(() => llmlib.LLM_Completion(LLMObject, json, streamWrapper.GetStringWrapper())); if (!started) return null; diff --git a/Runtime/LLMBuilder.cs b/Runtime/LLMBuilder.cs index a3cf0070..542520f9 100644 --- a/Runtime/LLMBuilder.cs +++ b/Runtime/LLMBuilder.cs @@ -104,15 +104,19 @@ static void AddActionAddMeta(string target) public static void HideLibraryPlatforms(string platform) { - List platforms = new List(){ "windows", "macos", "linux", "android", "ios" }; + List platforms = new List(){ "windows", "macos", "linux", "android", "ios", "setup" }; platforms.Remove(platform); foreach (string source in Directory.GetDirectories(LLMUnitySetup.libraryPath)) { + string sourceName = Path.GetFileName(source); foreach (string platformPrefix in platforms) { - if (Path.GetFileName(source).StartsWith(platformPrefix)) + bool move = sourceName.StartsWith(platformPrefix); + move = move || (sourceName.Contains("cuda") && !sourceName.Contains("full") && LLMUnitySetup.FullLlamaLib); + move = move || (sourceName.Contains("cuda") && sourceName.Contains("full") && !LLMUnitySetup.FullLlamaLib); + if (move) { - string target = Path.Combine(BuildTempDir, Path.GetFileName(source)); + string target = Path.Combine(BuildTempDir, sourceName); MoveAction(source, target); MoveAction(source + ".meta", target + ".meta"); } diff --git a/Runtime/LLMLib.cs b/Runtime/LLMLib.cs index edb60601..e3501769 100644 --- a/Runtime/LLMLib.cs +++ b/Runtime/LLMLib.cs @@ -326,8 +326,16 @@ public static List PossibleArchitectures(bool gpu = false) { if (gpu) { - architectures.Add("cuda-cu12.2.0"); - architectures.Add("cuda-cu11.7.1"); + if (LLMUnitySetup.FullLlamaLib) + { + architectures.Add("cuda-cu12.2.0-full"); + architectures.Add("cuda-cu11.7.1-full"); + } + else + { + architectures.Add("cuda-cu12.2.0"); + architectures.Add("cuda-cu11.7.1"); + } architectures.Add("hip"); architectures.Add("vulkan"); } diff --git a/Runtime/LLMUnitySetup.cs b/Runtime/LLMUnitySetup.cs index 285eee78..428c2da5 100644 --- a/Runtime/LLMUnitySetup.cs +++ b/Runtime/LLMUnitySetup.cs @@ -47,6 +47,7 @@ public class ModelAttribute : PropertyAttribute {} public class ModelDownloadAttribute : ModelAttribute {} public class ModelDownloadAdvancedAttribute : ModelAdvancedAttribute {} public class ModelAdvancedAttribute : PropertyAttribute {} + public class ModelExtrasAttribute : PropertyAttribute {} public class ChatAttribute : PropertyAttribute {} public class ChatAdvancedAttribute : PropertyAttribute {} public class LLMUnityAttribute : PropertyAttribute {} @@ -87,8 +88,12 @@ public class LLMUnitySetup public static string Version = "v2.1.2"; /// LlamaLib version public static string LlamaLibVersion = "v1.1.8"; + /// LlamaLib release url + public static string LlamaLibReleaseURL = $"https://github.com/undreamai/LlamaLib/releases/download/{LlamaLibVersion}"; /// LlamaLib url - public static string LlamaLibURL = $"https://github.com/undreamai/LlamaLib/releases/download/{LlamaLibVersion}/undreamai-{LlamaLibVersion}-llamacpp.zip"; + public static string LlamaLibURL = $"{LlamaLibReleaseURL}/undreamai-{LlamaLibVersion}-llamacpp.zip"; + /// LlamaLib extension url + public static string LlamaLibExtensionURL = $"{LlamaLibReleaseURL}/undreamai-{LlamaLibVersion}-llamacpp-full.zip"; /// LlamaLib path public static string libraryPath = GetAssetPath(Path.GetFileName(LlamaLibURL).Replace(".zip", "")); /// LLMnity store path @@ -109,25 +114,15 @@ public class LLMUnitySetup ("Qwen 2 0.5B (tiny, useful for mobile)", "https://huggingface.co/Qwen/Qwen2-0.5B-Instruct-GGUF/resolve/main/qwen2-0_5b-instruct-q4_k_m.gguf?download=true", null), }; - /// Add callback function to call for error logs - public static void AddErrorCallBack(Callback callback) - { - errorCallbacks.Add(callback); - } - - /// Remove callback function added for error logs - public static void RemoveErrorCallBack(Callback callback) - { - errorCallbacks.Remove(callback); - } - - /// Remove all callback function added for error logs - public static void ClearErrorCallBacks() - { - errorCallbacks.Clear(); - } - /// \cond HIDE + [LLMUnity] public static DebugModeType DebugMode = DebugModeType.All; + static string DebugModeKey = "DebugMode"; + public static bool FullLlamaLib = false; + static string FullLlamaLibKey = "FullLlamaLib"; + static List> errorCallbacks = new List>(); + static readonly object lockObject = new object(); + static Dictionary androidExtractTasks = new Dictionary(); + public enum DebugModeType { All, @@ -135,10 +130,6 @@ public enum DebugModeType Error, None } - [LLMUnity] public static DebugModeType DebugMode = DebugModeType.All; - static List> errorCallbacks = new List>(); - static readonly object lockObject = new object(); - static Dictionary androidExtractTasks = new Dictionary(); public static void Log(string message) { @@ -159,10 +150,10 @@ public static void LogError(string message) foreach (Callback errorCallback in errorCallbacks) errorCallback(message); } - static string DebugModeKey = "DebugMode"; - static void LoadDebugMode() + static void LoadPlayerPrefs() { DebugMode = (DebugModeType)PlayerPrefs.GetInt(DebugModeKey, (int)DebugModeType.All); + FullLlamaLib = PlayerPrefs.GetInt(FullLlamaLibKey, 0) == 1; } public static void SetDebugMode(DebugModeType newDebugMode) @@ -173,6 +164,18 @@ public static void SetDebugMode(DebugModeType newDebugMode) PlayerPrefs.Save(); } +#if UNITY_EDITOR + public static void SetFullLlamaLib(bool value) + { + if (FullLlamaLib == value) return; + FullLlamaLib = value; + PlayerPrefs.SetInt(FullLlamaLibKey, value ? 1 : 0); + PlayerPrefs.Save(); + _ = DownloadLibrary(); + } + +#endif + public static string GetAssetPath(string relPath = "") { // Path to store llm server binaries and models @@ -184,15 +187,15 @@ public static string GetAssetPath(string relPath = "") [InitializeOnLoadMethod] static async Task InitializeOnLoad() { + LoadPlayerPrefs(); await DownloadLibrary(); - LoadDebugMode(); } #else [RuntimeInitializeOnLoadMethod(RuntimeInitializeLoadType.BeforeSceneLoad)] void InitializeOnLoad() { - LoadDebugMode(); + LoadPlayerPrefs(); } #endif @@ -307,40 +310,89 @@ public static bool IsSubPath(string childPath, string parentPath) [HideInInspector] public static float libraryProgress = 1; - private static async Task DownloadLibrary() + static void CreateEmptyFile(string path) + { + File.Create(path).Dispose(); + } + + static void ExtractInsideDirectory(string zipPath, string extractPath, bool overwrite = true) { - if (libraryProgress < 1) return; - libraryProgress = 0; - string libZip = Path.Combine(Application.temporaryCachePath, Path.GetFileName(LlamaLibURL)); - if (!Directory.Exists(libraryPath)) + using (ZipArchive archive = ZipFile.OpenRead(zipPath)) { - await DownloadFile(LlamaLibURL, libZip, true, null, SetLibraryProgress); + foreach (ZipArchiveEntry entry in archive.Entries) + { + if (string.IsNullOrEmpty(entry.Name)) continue; + string destinationPath = Path.Combine(extractPath, entry.FullName); + Directory.CreateDirectory(Path.GetDirectoryName(destinationPath)); + entry.ExtractToFile(destinationPath, overwrite); + } + } + } + + static async Task DownloadAndExtractInsideDirectory(string url, string path, string setupDir) + { + string urlName = Path.GetFileName(url); + string setupFile = Path.Combine(setupDir, urlName + ".complete"); + if (File.Exists(setupFile)) return; + + string zipPath = Path.Combine(Application.temporaryCachePath, urlName); + await DownloadFile(url, zipPath, true, null, SetLibraryProgress); + + AssetDatabase.StartAssetEditing(); + ExtractInsideDirectory(zipPath, path); + CreateEmptyFile(setupFile); + AssetDatabase.StopAssetEditing(); + + File.Delete(zipPath); + } + + static async Task DownloadLibrary() + { + void DeleteFileAndMeta(string path) + { + if (File.Exists(path + ".meta")) File.Delete(path + ".meta"); + if (File.Exists(path)) File.Delete(path); + } + + try + { + string setupDir = Path.Combine(libraryPath, "setup"); + Directory.CreateDirectory(setupDir); + + string lockFile = Path.Combine(setupDir, "LLMUnitySetup.lock"); + if (File.Exists(lockFile)) return; + CreateEmptyFile(lockFile); + + libraryProgress = 0; + await DownloadAndExtractInsideDirectory(LlamaLibURL, libraryPath, setupDir); + AssetDatabase.StartAssetEditing(); - ZipFile.ExtractToDirectory(libZip, libraryPath); string androidDir = Path.Combine(libraryPath, "android"); if (Directory.Exists(androidDir)) { - string androidPluginDir = Path.Combine(Application.dataPath, "Plugins", "Android"); - Directory.CreateDirectory(androidPluginDir); - Directory.Move(androidDir, Path.Combine(androidPluginDir, Path.GetFileName(libraryPath))); - } - foreach (string librarySubPath in Directory.GetDirectories(libraryPath)) - { - if (Path.GetFileName(librarySubPath).StartsWith("android")) - { - string pluginPath = Path.Combine(Application.dataPath, "Plugins", "Android", Path.GetFileName(librarySubPath)); - Directory.Move(librarySubPath, pluginPath); - } + string androidPluginsDir = Path.Combine(Application.dataPath, "Plugins", "Android"); + Directory.CreateDirectory(androidPluginsDir); + string pluginDir = Path.Combine(androidPluginsDir, Path.GetFileName(libraryPath)); + if (Directory.Exists(pluginDir)) Directory.Delete(pluginDir, true); + Directory.Move(androidDir, pluginDir); + if (File.Exists(androidDir + ".meta")) File.Delete(androidDir + ".meta"); } AssetDatabase.StopAssetEditing(); - File.Delete(libZip); + + if (FullLlamaLib) await DownloadAndExtractInsideDirectory(LlamaLibExtensionURL, libraryPath, setupDir); + + libraryProgress = 1; + DeleteFileAndMeta(lockFile); + } + catch (Exception e) + { + LogError(e.Message); } - libraryProgress = 1; } private static void SetLibraryProgress(float progress) { - libraryProgress = progress; + libraryProgress = Math.Min(0.99f, progress); } public static string AddAsset(string assetPath) @@ -364,6 +416,25 @@ public static string AddAsset(string assetPath) #endif /// \endcond + + /// Add callback function to call for error logs + public static void AddErrorCallBack(Callback callback) + { + errorCallbacks.Add(callback); + } + + /// Remove callback function added for error logs + public static void RemoveErrorCallBack(Callback callback) + { + errorCallbacks.Remove(callback); + } + + /// Remove all callback function added for error logs + public static void ClearErrorCallBacks() + { + errorCallbacks.Clear(); + } + public static int GetMaxFreqKHz(int cpuId) { string[] paths = new string[]