diff --git a/Editor/LLMEditor.cs b/Editor/LLMEditor.cs index 9601579f..dcf5cfed 100644 --- a/Editor/LLMEditor.cs +++ b/Editor/LLMEditor.cs @@ -206,7 +206,7 @@ async Task createButtons() } else if (modelIndex > 1) { - if (modelLicenses[modelIndex] != null) Debug.LogWarning($"The {modelOptions[modelIndex]} model is released under the following license: {modelLicenses[modelIndex]}. By using this model, you agree to the terms of the license."); + if (modelLicenses[modelIndex] != null) LLMUnitySetup.LogWarning($"The {modelOptions[modelIndex]} model is released under the following license: {modelLicenses[modelIndex]}. By using this model, you agree to the terms of the license."); string filename = await LLMManager.DownloadModel(modelURLs[modelIndex], true, modelOptions[modelIndex]); SetModelIfNone(filename, false); UpdateModels(true); @@ -300,7 +300,7 @@ void OnEnable() } else { - isSelected = llmScript.lora.Split(" ").Contains(entry.filename); + isSelected = llmScript.loraManager.Contains(entry.filename); bool newSelected = EditorGUI.Toggle(selectRect, isSelected); if (newSelected && !isSelected) llmScript.AddLora(entry.filename); else if (!newSelected && isSelected) llmScript.RemoveLora(entry.filename); diff --git a/README.md b/README.md index 8dee1a57..d179f17a 100644 --- a/README.md +++ b/README.md @@ -46,8 +46,10 @@ LLM for Unity is built on top of the awesome [llama.cpp](https://github.com/gger ## How to help - [⭐ Star](https://github.com/undreamai/LLMUnity) the repo, leave us a [review](https://assetstore.unity.com/packages/slug/273604) and spread the word about the project! -- Join us at [Discord](https://discord.gg/RwXKQb6zdv) and say hi! -- [Contribute](CONTRIBUTING.md) by submitting feature requests or bugs as issues or even submiting a PR and become a collaborator! +- Join us at [Discord](https://discord.gg/RwXKQb6zdv) and say hi. +- [Contribute](CONTRIBUTING.md) by submitting feature requests, bugs or even your own PR. +- [![](https://img.shields.io/static/v1?label=Sponsor&message=%E2%9D%A4&logo=GitHub&color=%23fe8e86)](https://github.com/sponsors/amakropoulos) this work to allow even cooler features! + ## Games using LLM for Unity - [Verbal Verdict](https://store.steampowered.com/app/2778780/Verbal_Verdict/) diff --git a/Runtime/LLM.cs b/Runtime/LLM.cs index 77b71398..f2c1a82b 100644 --- a/Runtime/LLM.cs +++ b/Runtime/LLM.cs @@ -3,7 +3,6 @@ using System; using System.Collections.Generic; using System.IO; -using System.Linq; using System.Threading; using System.Threading.Tasks; using UnityEditor; @@ -11,20 +10,6 @@ namespace LLMUnity { - /// \cond HIDE - public class LLMException : Exception - { - public int ErrorCode { get; private set; } - - public LLMException(string message, int errorCode) : base(message) - { - ErrorCode = errorCode; - } - } - - public class DestroyException : Exception {} - /// \endcond - [DefaultExecutionOrder(-1)] /// @ingroup llm /// @@ -74,6 +59,8 @@ public class LLM : MonoBehaviour /// the paths of the LORA models being used (relative to the Assets/StreamingAssets folder). /// Models with .gguf format are allowed. [ModelAdvanced] public string lora = ""; + /// the weights of the LORA models being used. + [ModelAdvanced] public string loraWeights = ""; /// enable use of flash attention [ModelExtras] public bool flashAttention = false; @@ -86,8 +73,10 @@ public class LLM : MonoBehaviour Thread llmThread = null; List streamWrappers = new List(); public LLMManager llmManager = new LLMManager(); - List loraWeights = new List(); private readonly object startLock = new object(); + public LoraManager loraManager = new LoraManager(); + string loraPre = ""; + string loraWeightsPre = ""; /// \endcond @@ -96,6 +85,15 @@ public LLM() LLMManager.Register(this); } + void OnValidate() + { + if (lora != loraPre || loraWeights != loraWeightsPre) + { + loraManager.FromStrings(lora, loraWeights); + (loraPre, loraWeightsPre) = (lora, loraWeights); + } + } + /// /// The Unity Awake function that starts the LLM server. /// The server can be started asynchronously if the asynchronousStartup option is set. @@ -136,35 +134,55 @@ public static async Task WaitUntilModelSetup(Callback downloadProgr return !modelSetupFailed; } - public string GetModelLoraPathRuntime(string path) + public static string GetLLMManagerAsset(string path) { - string assetPath = LLMManager.GetAssetPath(path); - if (!string.IsNullOrEmpty(assetPath)) return assetPath; - return path; +#if UNITY_EDITOR + if (!EditorApplication.isPlaying) return GetLLMManagerAssetEditor(path); +#endif + return GetLLMManagerAssetRuntime(path); } - public string GetModelLoraPath(string path, bool lora) + public static string GetLLMManagerAssetEditor(string path) { + // empty if (string.IsNullOrEmpty(path)) return path; + // LLMManager - return location the file will be stored in StreamingAssets ModelEntry modelEntry = LLMManager.Get(path); if (modelEntry != null) return modelEntry.filename; - - string modelType = lora ? "Lora" : "Model"; - string assetPath = LLMUnitySetup.GetAssetPath(path); + // StreamingAssets - return relative location within StreamingAssets + string assetPath = LLMUnitySetup.GetAssetPath(path); // Note: this will return the full path if a full path is passed + string basePath = LLMUnitySetup.GetAssetPath(); + if (File.Exists(assetPath)) + { + if (LLMUnitySetup.IsSubPath(assetPath, basePath)) return LLMUnitySetup.RelativePath(assetPath, basePath); + } + // full path if (!File.Exists(assetPath)) { - LLMUnitySetup.LogError($"The {modelType} file {path} was not found."); - return path; + LLMUnitySetup.LogError($"Model {path} was not found."); } - - if (!LLMUnitySetup.IsSubPath(assetPath, LLMUnitySetup.GetAssetPath())) + else { - string errorMessage = $"The {modelType} file {path} was loaded locally. If you want to include it in the build:"; - errorMessage += $"\n-Copy the {modelType} inside the StreamingAssets folder and use its relative path or"; - errorMessage += $"\n-Load the {modelType} with the LLMManager: `string filename=LLMManager.Load{modelType}(path); llm.Set{modelType}(filename)`"; + string errorMessage = $"The model {path} was loaded locally. You can include it in the build in one of these ways:"; + errorMessage += $"\n-Copy the model inside the StreamingAssets folder and use its StreamingAssets path"; + errorMessage += $"\n-Load the model with the model manager inside the LLM GameObject and use its filename"; LLMUnitySetup.LogWarning(errorMessage); } - return assetPath; + return path; + } + + public static string GetLLMManagerAssetRuntime(string path) + { + // empty + if (string.IsNullOrEmpty(path)) return path; + // LLMManager + string managerPath = LLMManager.GetAssetPath(path); + if (!string.IsNullOrEmpty(managerPath) && File.Exists(managerPath)) return managerPath; + // StreamingAssets + string assetPath = LLMUnitySetup.GetAssetPath(path); + if (File.Exists(assetPath)) return assetPath; + // give up + return path; } /// @@ -175,11 +193,11 @@ public string GetModelLoraPath(string path, bool lora) /// path to model to use (.gguf format) public void SetModel(string path) { - model = GetModelLoraPath(path, false); + model = GetLLMManagerAsset(path); if (!string.IsNullOrEmpty(model)) { ModelEntry modelEntry = LLMManager.Get(model); - if (modelEntry == null) modelEntry = new ModelEntry(GetModelLoraPathRuntime(model)); + if (modelEntry == null) modelEntry = new ModelEntry(GetLLMManagerAssetRuntime(model)); SetTemplate(modelEntry.chatTemplate); if (contextSize == 0 && modelEntry.contextLength > 32768) { @@ -197,10 +215,11 @@ public void SetModel(string path) /// Models supported are in .gguf format. /// /// path to LORA model to use (.gguf format) - public void SetLora(string path) + public void SetLora(string path, float weight = 1) { - lora = ""; - AddLora(path); + AssertNotStarted(); + loraManager.Clear(); + AddLora(path, weight); } /// @@ -209,15 +228,11 @@ public void SetLora(string path) /// Models supported are in .gguf format. /// /// path to LORA model to use (.gguf format) - public void AddLora(string path) + public void AddLora(string path, float weight = 1) { - string loraPath = GetModelLoraPath(path, true); - if (lora.Split(" ").Contains(loraPath)) return; - if (lora != "") lora += " "; - lora += loraPath; -#if UNITY_EDITOR - if (!EditorApplication.isPlaying) EditorUtility.SetDirty(this); -#endif + AssertNotStarted(); + loraManager.Add(path, weight); + UpdateLoras(); } /// @@ -227,15 +242,37 @@ public void AddLora(string path) /// path to LORA model to remove (.gguf format) public void RemoveLora(string path) { - string loraPath = GetModelLoraPath(path, true); - List loras = new List(lora.Split(" ")); - loras.Remove(loraPath); - lora = ""; - for (int i = 0; i < loras.Count; i++) - { - if (i > 0) lora += " "; - lora += loras[i]; - } + AssertNotStarted(); + loraManager.Remove(path); + UpdateLoras(); + } + + /// + /// Allows to remove all LORA models from the LLM. + /// + public void RemoveLoras() + { + AssertNotStarted(); + loraManager.Clear(); + UpdateLoras(); + } + + /// + /// Allows to change the scale (weight) of a LORA model in the LLM. + /// + /// path of LORA model to change (.gguf format) + /// scale of LORA + public void SetLoraScale(string path, float scale) + { + loraManager.SetWeight(path, scale); + UpdateLoras(); + if (started) ApplyLoras(); + } + + public void UpdateLoras() + { + (lora, loraWeights) = loraManager.ToStrings(); + (loraPre, loraWeightsPre) = (lora, loraWeights); #if UNITY_EDITOR if (!EditorApplication.isPlaying) EditorUtility.SetDirty(this); #endif @@ -271,7 +308,7 @@ protected virtual string GetLlamaccpArguments() LLMUnitySetup.LogError("No model file provided!"); return null; } - string modelPath = GetModelLoraPathRuntime(model); + string modelPath = GetLLMManagerAssetRuntime(model); if (!File.Exists(modelPath)) { LLMUnitySetup.LogError($"File {modelPath} not found!"); @@ -281,7 +318,7 @@ protected virtual string GetLlamaccpArguments() foreach (string lora in lora.Trim().Split(" ")) { if (lora == "") continue; - string loraPath = GetModelLoraPathRuntime(lora); + string loraPath = GetLLMManagerAssetRuntime(lora); if (!File.Exists(loraPath)) { LLMUnitySetup.LogError($"File {loraPath} not found!"); @@ -387,8 +424,7 @@ private void StartService() llmThread = new Thread(() => llmlib.LLM_Start(LLMObject)); llmThread.Start(); while (!llmlib.LLM_Started(LLMObject)) {} - loraWeights = new List(); - for (int i = 0; i < lora.Split(" ").Count(); i++) loraWeights.Add(1f); + ApplyLoras(); started = true; } @@ -447,6 +483,16 @@ void AssertStarted() } } + void AssertNotStarted() + { + if (started) + { + string error = "This method can't be called when the LLM has started"; + LLMUnitySetup.LogError(error); + throw new Exception(error); + } + } + void CheckLLMStatus(bool log = true) { if (llmlib == null) { return; } @@ -537,46 +583,31 @@ public async Task Embeddings(string json) /// Sets the lora scale, only works after the LLM service has started /// /// switch result - public async Task SetLoraScale(string loraToScale, float scale) + public void ApplyLoras() { - AssertStarted(); - List loras = new List(lora.Split(" ")); - string loraToScalePath = GetModelLoraPath(loraToScale, true); - - int index = loras.IndexOf(loraToScale); - if (index == -1) index = loras.IndexOf(loraToScalePath); - if (index == -1) - { - LLMUnitySetup.LogError($"LoRA {loraToScale} not loaded with the LLM"); - return ""; - } - - loraWeights[index] = scale; LoraWeightRequestList loraWeightRequest = new LoraWeightRequestList(); loraWeightRequest.loraWeights = new List(); - for (int i = 0; i < loraWeights.Count; i++) + float[] weights = loraManager.GetWeights(); + for (int i = 0; i < weights.Length; i++) { - loraWeightRequest.loraWeights.Add(new LoraWeightRequest() { id = i, scale = loraWeights[i] }); + loraWeightRequest.loraWeights.Add(new LoraWeightRequest() { id = i, scale = weights[i] }); } - ; string json = JsonUtility.ToJson(loraWeightRequest); int startIndex = json.IndexOf("["); int endIndex = json.LastIndexOf("]") + 1; json = json.Substring(startIndex, endIndex - startIndex); - LLMReplyCallback callback = (IntPtr LLMObject, string jsonData, IntPtr strWrapper) => - { - llmlib.LLM_Lora_Weight(LLMObject, jsonData, strWrapper); - }; - return await LLMReply(callback, json); + IntPtr stringWrapper = llmlib.StringWrapper_Construct(); + llmlib.LLM_Lora_Weight(LLMObject, json, stringWrapper); + llmlib.StringWrapper_Delete(stringWrapper); } /// /// Gets a list of the lora adapters /// /// list of lara adapters - public async Task ListLora() + public async Task ListLoras() { AssertStarted(); LLMNoInputReplyCallback callback = (IntPtr LLMObject, IntPtr strWrapper) => diff --git a/Runtime/LLMManager.cs b/Runtime/LLMManager.cs index 3c7cbc24..34301fac 100644 --- a/Runtime/LLMManager.cs +++ b/Runtime/LLMManager.cs @@ -24,7 +24,7 @@ public ModelEntry(string path, bool lora = false, string label = null, string ur filename = Path.GetFileName(path); this.label = label == null ? filename : label; this.lora = lora; - this.path = Path.GetFullPath(path).Replace('\\', '/'); + this.path = LLMUnitySetup.GetFullPath(path); this.url = url; includeInBuild = true; chatTemplate = null; @@ -162,7 +162,7 @@ public static void SetTemplate(ModelEntry entry, string chatTemplate) public static ModelEntry Get(string path) { string filename = Path.GetFileName(path); - string fullPath = Path.GetFullPath(path).Replace('\\', '/'); + string fullPath = LLMUnitySetup.GetFullPath(path); foreach (ModelEntry entry in modelEntries) { if (entry.filename == filename || entry.path == fullPath) return entry; @@ -387,7 +387,7 @@ public static void Remove(ModelEntry entry) foreach (LLM llm in llms) { if (!entry.lora && llm.model == entry.filename) llm.model = ""; - else if (entry.lora && llm.lora == entry.filename) llm.lora = ""; + else if (entry.lora) llm.RemoveLora(entry.filename); } } diff --git a/Runtime/LLMUnitySetup.cs b/Runtime/LLMUnitySetup.cs index 428c2da5..c86c1f95 100644 --- a/Runtime/LLMUnitySetup.cs +++ b/Runtime/LLMUnitySetup.cs @@ -299,18 +299,36 @@ public static async Task AndroidExtractAsset(string path, bool overwrite = false await AndroidExtractFile(Path.GetFileName(path), overwrite); } + public static string GetFullPath(string path) + { + return Path.GetFullPath(path).Replace('\\', '/'); + } + public static bool IsSubPath(string childPath, string parentPath) { - string fullParentPath = Path.GetFullPath(parentPath).Replace('\\', '/'); - string fullChildPath = Path.GetFullPath(childPath).Replace('\\', '/'); - return fullChildPath.StartsWith(fullParentPath, StringComparison.OrdinalIgnoreCase); + return GetFullPath(childPath).StartsWith(GetFullPath(parentPath), StringComparison.OrdinalIgnoreCase); + } + + public static string RelativePath(string fullPath, string basePath) + { + // Get the full paths and replace backslashes with forward slashes (or vice versa) + string fullParentPath = GetFullPath(basePath).TrimEnd('/'); + string fullChildPath = GetFullPath(fullPath); + + string relativePath = fullChildPath; + if (fullChildPath.StartsWith(fullParentPath, StringComparison.OrdinalIgnoreCase)) + { + relativePath = fullChildPath.Substring(fullParentPath.Length); + while (relativePath.StartsWith("/")) relativePath = relativePath.Substring(1); + } + return relativePath; } #if UNITY_EDITOR [HideInInspector] public static float libraryProgress = 1; - static void CreateEmptyFile(string path) + public static void CreateEmptyFile(string path) { File.Create(path).Dispose(); } diff --git a/Runtime/LLMUtils.cs b/Runtime/LLMUtils.cs new file mode 100644 index 00000000..5e99950d --- /dev/null +++ b/Runtime/LLMUtils.cs @@ -0,0 +1,129 @@ +/// @file +/// @brief File implementing LLM helper code. +using System; +using System.Collections.Generic; + +namespace LLMUnity +{ + /// \cond HIDE + public class LLMException : Exception + { + public int ErrorCode { get; private set; } + + public LLMException(string message, int errorCode) : base(message) + { + ErrorCode = errorCode; + } + } + + public class DestroyException : Exception {} + + public class LoraAsset + { + public string assetPath; + public float weight; + + public LoraAsset(string path, float weight = 1) + { + assetPath = LLM.GetLLMManagerAsset(path); + this.weight = weight; + } + + public override bool Equals(object obj) + { + string RuntimePath(string path) {return LLMUnitySetup.GetFullPath(LLM.GetLLMManagerAssetRuntime(path));} + + if (obj == null || obj.GetType() != this.GetType()) return false; + LoraAsset other = (LoraAsset)obj; + return assetPath == other.assetPath || RuntimePath(assetPath) == RuntimePath(other.assetPath); + } + + public override int GetHashCode() + { + return (assetPath + "," + weight.ToString()).GetHashCode(); + } + } + + public class LoraManager + { + List loras = new List(); + + public void Clear() + { + loras.Clear(); + } + + public bool Contains(string path) + { + LoraAsset lora = new LoraAsset(path); + return loras.Contains(lora); + } + + public void Add(string path, float weight = 1) + { + LoraAsset lora = new LoraAsset(path, weight); + if (loras.Contains(lora)) return; + loras.Add(lora); + } + + public void Remove(string path) + { + loras.Remove(new LoraAsset(path)); + } + + public void SetWeight(string path, float weight) + { + LoraAsset lora = new LoraAsset(path); + int index = loras.IndexOf(lora); + if (index == -1) + { + LLMUnitySetup.LogError($"LoRA {path} not loaded with the LLM"); + return; + } + loras[index].weight = weight; + } + + public void FromStrings(string loraString, string loraWeightsString) + { + try + { + List loraStringArr = new List(loraString.Split(" ")); + List loraWeightsStringArr = new List(loraWeightsString.Split(" ")); + if (loraStringArr.Count != loraWeightsStringArr.Count) throw new Exception($"LoRAs number ({loraString}) doesn't match the number of weights ({loraWeightsString})"); + + List lorasNew = new List(); + for (int i = 0; i < loraStringArr.Count; i++) lorasNew.Add(new LoraAsset(loraStringArr[i], float.Parse(loraWeightsStringArr[i]))); + loras = lorasNew; + } + catch (Exception e) + { + LLMUnitySetup.LogError($"Loras not set: {e.Message}"); + } + } + + public (string, string) ToStrings() + { + string loraString = ""; + string loraWeightsString = ""; + for (int i = 0; i < loras.Count; i++) + { + if (i > 0) + { + loraString += " "; + loraWeightsString += " "; + } + loraString += loras[i].assetPath; + loraWeightsString += loras[i].weight; + } + return (loraString, loraWeightsString); + } + + public float[] GetWeights() + { + float[] weights = new float[loras.Count]; + for (int i = 0; i < loras.Count; i++) weights[i] = loras[i].weight; + return weights; + } + } + /// \endcond +} diff --git a/Runtime/LLMUtils.cs.meta b/Runtime/LLMUtils.cs.meta new file mode 100644 index 00000000..93974a4f --- /dev/null +++ b/Runtime/LLMUtils.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: 2ae6a2ce57e8af0fc876d0c380ed8a2f +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: diff --git a/Tests/Runtime/TestLLM.cs b/Tests/Runtime/TestLLM.cs index affca075..ef7f28f6 100644 --- a/Tests/Runtime/TestLLM.cs +++ b/Tests/Runtime/TestLLM.cs @@ -5,55 +5,178 @@ using System.Collections.Generic; using System; using System.Collections; -using UnityEngine.TestTools; using System.IO; -using NUnit.Framework.Internal; +using UnityEngine.TestTools; namespace LLMUnityTests { + public class TestLLMLoras + { + [Test] + public void TestLLMLorasAssign() + { + GameObject gameObject = new GameObject(); + gameObject.SetActive(false); + LLM llm = gameObject.AddComponent(); + + string lora1 = "/tmp/lala"; + string lora2Rel = "test/lala"; + string lora2 = LLMUnitySetup.GetAssetPath(lora2Rel); + LLMUnitySetup.CreateEmptyFile(lora1); + Directory.CreateDirectory(Path.GetDirectoryName(lora2)); + LLMUnitySetup.CreateEmptyFile(lora2); + + llm.AddLora(lora1); + llm.AddLora(lora2); + Assert.AreEqual(llm.lora, lora1 + " " + lora2); + Assert.AreEqual(llm.loraWeights, "1 1"); + + llm.RemoveLoras(); + Assert.AreEqual(llm.lora, ""); + Assert.AreEqual(llm.loraWeights, ""); + + llm.AddLora(lora1, 0.8f); + llm.AddLora(lora2Rel, 0.9f); + Assert.AreEqual(llm.lora, lora1 + " " + lora2); + Assert.AreEqual(llm.loraWeights, "0.8 0.9"); + + llm.SetLoraScale(lora2Rel, 0.7f); + Assert.AreEqual(llm.lora, lora1 + " " + lora2); + Assert.AreEqual(llm.loraWeights, "0.8 0.7"); + + llm.RemoveLora(lora2Rel); + Assert.AreEqual(llm.lora, lora1); + Assert.AreEqual(llm.loraWeights, "0.8"); + + llm.AddLora(lora2Rel); + llm.SetLoraScale(lora2Rel, 0.5f); + Assert.AreEqual(llm.lora, lora1 + " " + lora2); + Assert.AreEqual(llm.loraWeights, "0.8 0.5"); + + llm.SetLoraScale(lora2, 0.1f); + Assert.AreEqual(llm.lora, lora1 + " " + lora2); + Assert.AreEqual(llm.loraWeights, "0.8 0.1"); + + File.Delete(lora1); + File.Delete(lora2); + } + } + public class TestLLM { + protected static string modelUrl = "https://huggingface.co/Qwen/Qwen2-0.5B-Instruct-GGUF/resolve/main/qwen2-0_5b-instruct-q4_k_m.gguf?download=true"; + protected string modelNameLLManager; + protected GameObject gameObject; protected LLM llm; protected LLMCharacter llmCharacter; - protected static string modelUrl = "https://huggingface.co/afrideva/smol_llama-220M-openhermes-GGUF/resolve/main/smol_llama-220m-openhermes.q4_k_m.gguf?download=true"; - protected static string filename = Path.GetFileName(modelUrl).Split("?")[0]; Exception error = null; - string prompt = "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request."; + protected string prompt; + protected string query; + protected string reply1; + protected string reply2; + protected int tokens1; + protected int tokens2; + public TestLLM() { - LLMUnitySetup.SetDebugMode(LLMUnitySetup.DebugModeType.All); Task task = Init(); task.Wait(); } public virtual async Task Init() { + SetParameters(); + await DownloadModels(); gameObject = new GameObject(); gameObject.SetActive(false); - await SetLLM(); - SetLLMCharacter(); + llm = CreateLLM(); + llmCharacter = CreateLLMCharacter(); gameObject.SetActive(true); } - public async Task EmptyTask() + public virtual void SetParameters() { - await Task.Delay(1); + prompt = "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request."; + query = "How can I increase my meme production/output? Currently, I only create them in ancient babylonian which is time consuming."; + reply1 = "To increase your meme production/output, you can try using more modern tools and techniques. For instance,"; + reply2 = "To increase your meme production/output, you can try the following strategies:\n\n1. Use a meme generator"; + tokens1 = 32; + tokens2 = 9; } - public virtual async Task SetLLM() + public virtual async Task DownloadModels() { - llm = gameObject.AddComponent(); - string filename = await LLMManager.DownloadModel(modelUrl); - llm.SetModel(filename); + modelNameLLManager = await LLMManager.DownloadModel(modelUrl); + } + + [Test] + public void TestGetLLMManagerAssetRuntime() + { + string path = ""; + string managerPath = LLM.GetLLMManagerAssetRuntime(path); + Assert.AreEqual(managerPath, path); + + path = "/tmp/lala"; + LLMUnitySetup.CreateEmptyFile(path); + managerPath = LLM.GetLLMManagerAssetRuntime(path); + Assert.AreEqual(managerPath, path); + File.Delete(path); + + path = modelNameLLManager; + managerPath = LLM.GetLLMManagerAssetRuntime(path); + Assert.AreEqual(managerPath, LLMManager.GetAssetPath(path)); + + path = LLMUnitySetup.GetAssetPath("lala"); + LLMUnitySetup.CreateEmptyFile(path); + managerPath = LLM.GetLLMManagerAssetRuntime(path); + Assert.AreEqual(managerPath, path); + File.Delete(path); + } + + [Test] + public void TestGetLLMManagerAssetEditor() + { + string path = ""; + string managerPath = LLM.GetLLMManagerAssetEditor(path); + Assert.AreEqual(managerPath, path); + + path = modelNameLLManager; + managerPath = LLM.GetLLMManagerAssetEditor(path); + Assert.AreEqual(managerPath, modelNameLLManager); + + path = LLMManager.Get(modelNameLLManager).path; + managerPath = LLM.GetLLMManagerAssetEditor(path); + Assert.AreEqual(managerPath, modelNameLLManager); + + string filename = "lala"; + path = LLMUnitySetup.GetAssetPath(filename); + LLMUnitySetup.CreateEmptyFile(path); + managerPath = LLM.GetLLMManagerAssetEditor(filename); + Assert.AreEqual(managerPath, filename); + managerPath = LLM.GetLLMManagerAssetEditor(path); + Assert.AreEqual(managerPath, filename); + File.Delete(path); + + path = "/tmp/lala"; + LLMUnitySetup.CreateEmptyFile(path); + managerPath = LLM.GetLLMManagerAssetEditor(path); + Assert.AreEqual(managerPath, path); + File.Delete(path); + } + + public virtual LLM CreateLLM() + { + LLM llm = gameObject.AddComponent(); + llm.SetModel(modelNameLLManager); llm.parallelPrompts = 1; - llm.SetTemplate("alpaca"); + return llm; } - public virtual void SetLLMCharacter() + public virtual LLMCharacter CreateLLMCharacter() { - llmCharacter = gameObject.AddComponent(); + LLMCharacter llmCharacter = gameObject.AddComponent(); llmCharacter.llm = llm; llmCharacter.playerName = "Instruction"; llmCharacter.AIName = "Response"; @@ -62,31 +185,43 @@ public virtual void SetLLMCharacter() llmCharacter.seed = 0; llmCharacter.stream = false; llmCharacter.numPredict = 20; + return llmCharacter; } - public virtual async Task RunTests() + [UnityTest] + public IEnumerator RunTests() + { + Task task = RunTestsTask(); + while (!task.IsCompleted) yield return null; + if (error != null) + { + Debug.LogError(error.ToString()); + throw (error); + } + OnDestroy(); + } + + public async Task RunTestsTask() { error = null; try { - llm.Awake(); - llmCharacter.Awake(); await llmCharacter.Tokenize("I", TestTokens); await llmCharacter.Warmup(); - TestInitParameters((await llmCharacter.Tokenize(prompt)).Count + 2, 1); + TestInitParameters(tokens1, 1); TestWarmup(); - await llmCharacter.Chat("How can I increase my meme production/output? Currently, I only create them in ancient babylonian which is time consuming.", TestChat); + await llmCharacter.Chat(query, (string reply) => TestChat(reply, reply1)); TestPostChat(3); llmCharacter.SetPrompt(llmCharacter.prompt); llmCharacter.AIName = "False response"; - await llmCharacter.Chat("How can I increase my meme production/output? Currently, I only create them in ancient babylonian which is time consuming.", TestChat2); + await llmCharacter.Chat(query, (string reply) => TestChat(reply, reply2)); TestPostChat(3); await llmCharacter.Chat("bye!"); TestPostChat(5); prompt = "How are you?"; llmCharacter.SetPrompt(prompt); await llmCharacter.Chat("hi"); - TestInitParameters((await llmCharacter.Tokenize(prompt)).Count + 2, 3); + TestInitParameters(tokens2, 3); List embeddings = await llmCharacter.Embeddings("hi how are you?"); TestEmbeddings(embeddings); llm.OnDestroy(); @@ -97,19 +232,6 @@ public virtual async Task RunTests() } } - [UnityTest] - public IEnumerator RunTestsWait() - { - Task task = RunTests(); - while (!task.IsCompleted) yield return null; - if (error != null) - { - Debug.LogError(error.ToString()); - throw (error); - } - OnDestroy(); - } - public void TestInitParameters(int nkeep, int chats) { Assert.That(llmCharacter.nKeep == nkeep); @@ -119,7 +241,7 @@ public void TestInitParameters(int nkeep, int chats) public void TestTokens(List tokens) { - Assert.AreEqual(tokens, new List {306}); + Assert.AreEqual(tokens, new List {40}); } public void TestWarmup() @@ -127,16 +249,9 @@ public void TestWarmup() Assert.That(llmCharacter.chat.Count == 1); } - public void TestChat(string reply) - { - string AIReply = "One way to increase your meme production/output is by creating a more complex and customized"; - Assert.That(reply.Trim() == AIReply); - } - - public void TestChat2(string reply) + public void TestChat(string reply, string replyGT) { - string AIReply = "One possible solution is to use a more advanced natural language processing library like NLTK or sp"; - Assert.That(reply.Trim() == AIReply); + Assert.That(reply.Trim() == replyGT); } public void TestPostChat(int num) @@ -146,60 +261,127 @@ public void TestPostChat(int num) public void TestEmbeddings(List embeddings) { - Assert.That(embeddings.Count == 1024); + Assert.That(embeddings.Count == 896); } - public virtual void OnDestroy() - { - LLMManager.Remove(filename); - } + public virtual void OnDestroy() {} } public class TestLLM_LLMManager_Load : TestLLM { - public override Task SetLLM() + public override LLM CreateLLM() { - llm = gameObject.AddComponent(); + LLM llm = gameObject.AddComponent(); + string filename = Path.GetFileName(modelUrl).Split("?")[0]; string sourcePath = Path.Combine(LLMUnitySetup.modelDownloadPath, filename); filename = LLMManager.LoadModel(sourcePath); llm.SetModel(filename); llm.parallelPrompts = 1; - llm.SetTemplate("alpaca"); - return Task.CompletedTask; + return llm; } } public class TestLLM_StreamingAssets_Load : TestLLM { - public override Task SetLLM() + string loadPath; + + public override LLM CreateLLM() { - llm = gameObject.AddComponent(); + LLM llm = gameObject.AddComponent(); + string filename = Path.GetFileName(modelUrl).Split("?")[0]; string sourcePath = Path.Combine(LLMUnitySetup.modelDownloadPath, filename); - string targetPath = LLMUnitySetup.GetAssetPath(filename); - if (!File.Exists(targetPath)) File.Copy(sourcePath, targetPath); - llm.SetModel(filename); + loadPath = LLMUnitySetup.GetAssetPath(filename); + if (!File.Exists(loadPath)) File.Copy(sourcePath, loadPath); + llm.SetModel(loadPath); llm.parallelPrompts = 1; - llm.SetTemplate("alpaca"); - return Task.CompletedTask; + return llm; } public override void OnDestroy() { - string targetPath = LLMUnitySetup.GetAssetPath(filename); - if (!File.Exists(targetPath)) File.Delete(targetPath); + if (!File.Exists(loadPath)) File.Delete(loadPath); } } public class TestLLM_SetModel_Warning : TestLLM { - public override Task SetLLM() + public override LLM CreateLLM() { - llm = gameObject.AddComponent(); - string sourcePath = Path.Combine(LLMUnitySetup.modelDownloadPath, filename); - llm.SetModel(sourcePath); + LLM llm = gameObject.AddComponent(); + string filename = Path.GetFileName(modelUrl).Split("?")[0]; + string loadPath = Path.Combine(LLMUnitySetup.modelDownloadPath, filename); + llm.SetModel(loadPath); llm.parallelPrompts = 1; - llm.SetTemplate("alpaca"); - return Task.CompletedTask; + return llm; + } + } + + public class TestLLM_NoLora : TestLLM + { + public override void SetParameters() + { + prompt = ""; + query = "кто ты?"; + reply1 = "Я - искусственный интеллект, который помогаю вам с информацией и задачами"; + reply2 = "I'm sorry, but I didn't understand your request. Could you please provide more information or clarify"; + tokens1 = 5; + tokens2 = 9; + } + } + + public class TestLLM_Lora : TestLLM + { + string loraUrl = "https://huggingface.co/undreamer/Qwen2-0.5B-Instruct-ru-lora/resolve/main/Qwen2-0.5B-Instruct-ru-lora.gguf?download=true"; + string loraNameLLManager; + + public override async Task DownloadModels() + { + await base.DownloadModels(); + loraNameLLManager = await LLMManager.DownloadLora(loraUrl); + } + + public override LLM CreateLLM() + { + LLM llm = base.CreateLLM(); + llm.AddLora(loraNameLLManager); + return llm; + } + + public override void SetParameters() + { + prompt = ""; + query = "кто ты?"; + reply1 = "Я - искусственный интеллект, созданный для помощи и общения с людьми"; + reply2 = "Идиот"; + tokens1 = 5; + tokens2 = 9; + } + + [Test] + public void TestModelPaths() + { + Assert.AreEqual(llm.model, Path.Combine(LLMUnitySetup.modelDownloadPath, Path.GetFileName(modelUrl).Split("?")[0])); + Assert.AreEqual(llm.lora, Path.Combine(LLMUnitySetup.modelDownloadPath, Path.GetFileName(loraUrl).Split("?")[0])); + } + } + + + public class TestLLM_Double : TestLLM + { + LLM llm1; + LLMCharacter lLMCharacter1; + + public override async Task Init() + { + SetParameters(); + await DownloadModels(); + gameObject = new GameObject(); + gameObject.SetActive(false); + llm = CreateLLM(); + llmCharacter = CreateLLMCharacter(); + llm1 = CreateLLM(); + lLMCharacter1 = CreateLLMCharacter(); + gameObject.SetActive(true); } } }