From 91f74ba1c290ad302e002b71fc9cb7a5c97cc433 Mon Sep 17 00:00:00 2001 From: Antonis Makropoulos Date: Fri, 16 Aug 2024 17:13:59 +0300 Subject: [PATCH 01/67] fix set template for remote setup --- Runtime/LLM.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Runtime/LLM.cs b/Runtime/LLM.cs index 2f4c57cf..315c5cd9 100644 --- a/Runtime/LLM.cs +++ b/Runtime/LLM.cs @@ -314,7 +314,7 @@ private void InitServer(string arguments) if (debug) SetupLogging(); LLMObject = llmlib.LLM_Construct(arguments); if (remote) llmlib.LLM_StartServer(LLMObject); - SetTemplate(chatTemplate, false); + llmlib.LLM_SetTemplate(LLMObject, chatTemplate); CheckLLMStatus(false); } From 897fbe63fb1a1be8141e92ed8e7da3ebaee32cbd Mon Sep 17 00:00:00 2001 From: amakropoulos Date: Fri, 16 Aug 2024 14:14:44 +0000 Subject: [PATCH 02/67] update changelogs --- CHANGELOG.md | 6 ++++++ CHANGELOG.release.md | 3 ++- 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 31631aac..bba21a49 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,9 @@ +## v2.1.2 +#### 🐛 Fixes + +- Fix set template for remote setup (PR: #208) + + ## v2.1.1 #### 🐛 Fixes diff --git a/CHANGELOG.release.md b/CHANGELOG.release.md index 379083c3..50bd4944 100644 --- a/CHANGELOG.release.md +++ b/CHANGELOG.release.md @@ -1,3 +1,4 @@ ### 🐛 Fixes -- Resolve build directory creation \ No newline at end of file +- Fix set template for remote setup (PR: #208) + From a9ca159d48c3b5f4e49d98b169f5d0fc365a488e Mon Sep 17 00:00:00 2001 From: amakropoulos Date: Fri, 16 Aug 2024 14:14:59 +0000 Subject: [PATCH 03/67] update VERSION --- .github/doxygen/Doxyfile | 2 +- Runtime/LLMUnitySetup.cs | 2 +- VERSION | 2 +- package.json | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/.github/doxygen/Doxyfile b/.github/doxygen/Doxyfile index f765ee81..1bbe3027 100644 --- a/.github/doxygen/Doxyfile +++ b/.github/doxygen/Doxyfile @@ -48,7 +48,7 @@ PROJECT_NAME = "LLM for Unity" # could be handy for archiving the generated documentation or if some version # control system is used. -PROJECT_NUMBER = v2.1.1 +PROJECT_NUMBER = v2.1.2 # Using the PROJECT_BRIEF tag one can provide an optional one line description # for a project that appears at the top of each page and should give viewer a diff --git a/Runtime/LLMUnitySetup.cs b/Runtime/LLMUnitySetup.cs index 8150b22f..833f0ce7 100644 --- a/Runtime/LLMUnitySetup.cs +++ b/Runtime/LLMUnitySetup.cs @@ -84,7 +84,7 @@ public class LLMUnitySetup { // DON'T CHANGE! 
the version is autocompleted with a GitHub action
     /// LLM for Unity version
-    public static string Version = "v2.1.1";
+    public static string Version = "v2.1.2";
     /// LlamaLib version
     public static string LlamaLibVersion = "v1.1.6";
     /// LlamaLib url
diff --git a/VERSION b/VERSION
index 826e1424..59696826 100644
--- a/VERSION
+++ b/VERSION
@@ -1 +1 @@
-v2.1.1
+v2.1.2
diff --git a/package.json b/package.json
index 795878ac..a22c93a3 100644
--- a/package.json
+++ b/package.json
@@ -1,6 +1,6 @@
 {
   "name": "ai.undream.llm",
-  "version": "2.1.1",
+  "version": "2.1.2",
   "displayName": "LLM for Unity",
   "description": "LLM for Unity allows to run and distribute Large Language Models (LLMs) in the Unity engine.",
   "unity": "2022.3",

From df32009c3dc83ffc087ef32f148139ee6fc9657c Mon Sep 17 00:00:00 2001
From: Antonis Makropoulos
Date: Sun, 18 Aug 2024 18:05:16 +0300
Subject: [PATCH 04/67] bump LlamaLib to v1.1.8

---
 Runtime/LLMUnitySetup.cs | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/Runtime/LLMUnitySetup.cs b/Runtime/LLMUnitySetup.cs
index 833f0ce7..4ee87a66 100644
--- a/Runtime/LLMUnitySetup.cs
+++ b/Runtime/LLMUnitySetup.cs
@@ -86,7 +86,7 @@ public class LLMUnitySetup
     /// LLM for Unity version
     public static string Version = "v2.1.2";
     /// LlamaLib version
-    public static string LlamaLibVersion = "v1.1.6";
+    public static string LlamaLibVersion = "v1.1.8";
     /// LlamaLib url
     public static string LlamaLibURL = $"https://github.com/undreamai/LlamaLib/releases/download/{LlamaLibVersion}/undreamai-{LlamaLibVersion}-llamacpp.zip";
     /// LlamaLib path

From afe275244b35a45b16036685e9fb4ab4800e24da Mon Sep 17 00:00:00 2001
From: Antonis Makropoulos
Date: Mon, 19 Aug 2024 15:54:01 +0300
Subject: [PATCH 05/67] add embedding functionality

---
 Runtime/LLMCharacter.cs | 34 +++++++++++++++++++++++++++++-----
 1 file changed, 29 insertions(+), 5 deletions(-)

diff --git a/Runtime/LLMCharacter.cs b/Runtime/LLMCharacter.cs
index 9f37721a..c76239aa 100644
--- a/Runtime/LLMCharacter.cs
+++ b/Runtime/LLMCharacter.cs
@@ -430,16 +430,22 @@ protected List<int> TokenizeContent(TokenizeResult result)
         return result.tokens;
     }

-    protected string SlotContent(SlotResult result)
+    protected string DetokenizeContent(TokenizeRequest result)
     {
-        // get the tokens from a tokenize result received from the endpoint
-        return result.filename;
+        // get the content from a detokenize result received from the endpoint
+        return result.content;
     }

-    protected string DetokenizeContent(TokenizeRequest result)
+    protected List<float> EmbeddingsContent(EmbeddingsResult result)
     {
         // get the embeddings from an embeddings result received from the endpoint
-        return result.content;
+        return result.embedding;
+    }
+
+    protected string SlotContent(SlotResult result)
+    {
+        // get the filename from a slot result received from the endpoint
+        return result.filename;
     }

     ///
@@ -572,6 +578,21 @@ public async Task Detokenize(List tokens, Callback callback)
         return await PostRequest(json, "detokenize", DetokenizeContent, callback);
     }

+    /// <summary>
+    /// Computes the embeddings of the provided input.
+    /// </summary>
+    /// <param name="query">input to compute the embeddings for</param>
+    /// <param name="callback">callback function called with the computed embeddings</param>
+    /// <returns>the computed embeddings</returns>
+    public async Task<List<float>> Embeddings(string query, Callback<List<float>> callback = null)
+    {
+        // compute the embeddings of the user query
+        TokenizeRequest tokenizeRequest = new TokenizeRequest();
+        tokenizeRequest.content = query;
+        string json = JsonUtility.ToJson(tokenizeRequest);
+        return await PostRequest<EmbeddingsResult, List<float>>(json, "embeddings", EmbeddingsContent, callback);
+    }
+
     private async Task<string> Slot(string filepath, string action)
     {
         SlotRequest slotRequest = new SlotRequest();
@@ -682,6 +703,9 @@ protected async Task PostRequestLocal(string json, string endpoin
             case "detokenize":
                 callResult = await llm.Detokenize(json);
                 break;
+            case "embeddings":
+                callResult = await llm.Embeddings(json);
+                break;
             case "slots":
                 callResult = await llm.Slot(json);
                 break;

From 5ee7930375b644ee735371861459ed4428719252 Mon Sep 17 00:00:00 2001
From: Antonis Makropoulos
Date: Mon, 19 Aug 2024 15:54:30 +0300
Subject: [PATCH 06/67] import embeddings and lora adapters

---
 Runtime/LLMLib.cs | 9 +++++++++
 1 file changed, 9 insertions(+)

diff --git a/Runtime/LLMLib.cs b/Runtime/LLMLib.cs
index c8d49ccd..fbaa050d 100644
--- a/Runtime/LLMLib.cs
+++ b/Runtime/LLMLib.cs
@@ -298,6 +298,9 @@ public LLMLib(string arch)
             LLM_SetTemplate = LibraryLoader.GetSymbolDelegate<LLM_SetTemplateDelegate>(libraryHandle, "LLM_SetTemplate");
             LLM_Tokenize = LibraryLoader.GetSymbolDelegate<LLM_TokenizeDelegate>(libraryHandle, "LLM_Tokenize");
             LLM_Detokenize = LibraryLoader.GetSymbolDelegate<LLM_DetokenizeDelegate>(libraryHandle, "LLM_Detokenize");
+            LLM_Embeddings = LibraryLoader.GetSymbolDelegate<LLM_EmbeddingsDelegate>(libraryHandle, "LLM_Embeddings");
+            LLM_Lora_Weight = LibraryLoader.GetSymbolDelegate<LLM_LoraWeightDelegate>(libraryHandle, "LLM_Lora_Weight");
+            LLM_LoraList = LibraryLoader.GetSymbolDelegate<LLM_LoraListDelegate>(libraryHandle, "LLM_Lora_List");
             LLM_Completion = LibraryLoader.GetSymbolDelegate<LLM_CompletionDelegate>(libraryHandle, "LLM_Completion");
             LLM_Slot = LibraryLoader.GetSymbolDelegate<LLM_SlotDelegate>(libraryHandle, "LLM_Slot");
             LLM_Cancel = LibraryLoader.GetSymbolDelegate<LLM_CancelDelegate>(libraryHandle, "LLM_Cancel");
@@ -452,6 +455,9 @@ public string GetStringWrapperResult(IntPtr stringWrapper)
     public delegate void LLM_SetTemplateDelegate(IntPtr LLMObject, string chatTemplate);
     public delegate void LLM_TokenizeDelegate(IntPtr LLMObject, string jsonData, IntPtr stringWrapper);
     public delegate void LLM_DetokenizeDelegate(IntPtr LLMObject, string jsonData, IntPtr stringWrapper);
+    public delegate void LLM_EmbeddingsDelegate(IntPtr LLMObject, string jsonData, IntPtr stringWrapper);
+    public delegate void LLM_LoraWeightDelegate(IntPtr LLMObject, string jsonData, IntPtr stringWrapper);
+    public delegate void LLM_LoraListDelegate(IntPtr LLMObject, IntPtr stringWrapper);
     public delegate void LLM_CompletionDelegate(IntPtr LLMObject, string jsonData, IntPtr stringWrapper);
     public delegate void LLM_SlotDelegate(IntPtr LLMObject, string jsonData, IntPtr stringWrapper);
     public delegate void LLM_CancelDelegate(IntPtr LLMObject, int idSlot);
@@ -474,6 +480,9 @@ public string GetStringWrapperResult(IntPtr stringWrapper)
     public LLM_TokenizeDelegate LLM_Tokenize;
     public LLM_DetokenizeDelegate LLM_Detokenize;
     public LLM_CompletionDelegate LLM_Completion;
+    public LLM_EmbeddingsDelegate LLM_Embeddings;
+    public LLM_LoraWeightDelegate LLM_Lora_Weight;
+    public LLM_LoraListDelegate LLM_LoraList;
     public LLM_SlotDelegate LLM_Slot;
     public LLM_CancelDelegate LLM_Cancel;
     public LLM_StatusDelegate LLM_Status;

From 4a80c560a35e2397029a1b4a6554277329288e6c Mon Sep 17 00:00:00 2001
From: Antonis Makropoulos
Date: Mon, 19 
Aug 2024 15:54:53 +0300
Subject: [PATCH 07/67] add structs for embeddings and lora adapters

---
 Runtime/LLMInterface.cs | 19 +++++++++++++++
 1 file changed, 19 insertions(+)

diff --git a/Runtime/LLMInterface.cs b/Runtime/LLMInterface.cs
index f5086595..36a11c37 100644
--- a/Runtime/LLMInterface.cs
+++ b/Runtime/LLMInterface.cs
@@ -93,6 +93,25 @@ public struct TokenizeResult
         public List<int> tokens;
     }

+    [Serializable]
+    public struct EmbeddingsResult
+    {
+        public List<float> embedding;
+    }
+
+    [Serializable]
+    public struct LoraWeightRequest
+    {
+        public int id;
+        public float scale;
+    }
+
+    [Serializable]
+    public struct LoraWeightRequestList
+    {
+        public List<LoraWeightRequest> loraWeights;
+    }
+
     [Serializable]
     public struct TemplateResult
     {

From 5e8d05c8e071e8199c268de4e008b447ee33e76e Mon Sep 17 00:00:00 2001
From: Antonis Makropoulos
Date: Mon, 19 Aug 2024 15:55:10 +0300
Subject: [PATCH 08/67] allow multiple loras

---
 Editor/LLMEditor.cs | 17 +++++++++++------
 1 file changed, 11 insertions(+), 6 deletions(-)

diff --git a/Editor/LLMEditor.cs b/Editor/LLMEditor.cs
index bebaf578..ed50046d 100644
--- a/Editor/LLMEditor.cs
+++ b/Editor/LLMEditor.cs
@@ -142,7 +142,7 @@ void SetModelIfNone(string filename, bool lora)
         LLM llmScript = (LLM)target;
         int num = LLMManager.Num(lora);
         if (!lora && llmScript.model == "" && num == 1) llmScript.SetModel(filename);
-        if (lora && llmScript.lora == "" && num == 1) llmScript.SetLora(filename);
+        if (lora) llmScript.AddLora(filename);
     }

     async Task createCustomURLField()
     {
@@ -237,7 +237,7 @@ async Task createButtons()
         {
             EditorApplication.delayCall += () =>
             {
-                string path = EditorUtility.OpenFilePanelWithFilters("Select a bin lora file", "", new string[] { "Model Files", "bin" });
+                string path = EditorUtility.OpenFilePanelWithFilters("Select a gguf lora file", "", new string[] { "Model Files", "gguf" });
                 if (!string.IsNullOrEmpty(path))
                 {
                     string filename = LLMManager.LoadLora(path, true);
@@ -299,10 +299,10 @@ void OnEnable()
             }
             else
             {
-                isSelected = llmScript.lora == entry.filename;
-                bool newSelected = EditorGUI.Toggle(selectRect, isSelected, EditorStyles.radioButton);
-                if (newSelected && !isSelected) llmScript.SetLora(entry.filename);
-                else if (!newSelected && isSelected) llmScript.SetLora("");
+                isSelected = llmScript.lora.Split(" ").Contains(entry.filename);
+                bool newSelected = EditorGUI.Toggle(selectRect, isSelected);
+                if (newSelected && !isSelected) llmScript.AddLora(entry.filename);
+                else if (!newSelected && isSelected) llmScript.RemoveLora(entry.filename);
             }
             DrawCopyableLabel(nameRect, entry.label, entry.filename);
@@ -347,6 +347,11 @@ void OnEnable()
             if (GUI.Button(actionRect, trashIcon))
             {
+                if (isSelected)
+                {
+                    if (!entry.lora) llmScript.SetModel("");
+                    else llmScript.RemoveLora(entry.filename);
+                }
                 LLMManager.Remove(entry);
                 UpdateModels(true);
             }

From dfb0e6a3f5286dc91ecbea634c078d51e3ad942c Mon Sep 17 00:00:00 2001
From: Antonis Makropoulos
Date: Mon, 19 Aug 2024 15:56:05 +0300
Subject: [PATCH 09/67] implement callback functionality for embeddings and lora adapters

---
 Runtime/LLM.cs | 138 +++++++++++++++++++++++++++++++++++++++++++++----
 1 file changed, 129 insertions(+), 9 deletions(-)

diff --git a/Runtime/LLM.cs b/Runtime/LLM.cs
index 315c5cd9..fad6b5d6 100644
--- a/Runtime/LLM.cs
+++ b/Runtime/LLM.cs
@@ -3,6 +3,7 @@
 using System;
 using System.Collections.Generic;
 using System.IO;
+using System.Linq;
 using System.Threading;
 using System.Threading.Tasks;
 using UnityEditor;
@@ -68,7 +69,7 @@ public class LLM : MonoBehaviour
     [ModelAdvanced] public string model = ""; 
/// Chat template used for the model [ModelAdvanced] public string chatTemplate = ChatTemplate.DefaultTemplate; - /// the path of the LORA model being used (relative to the Assets/StreamingAssets folder). + /// the paths of the LORA models being used (relative to the Assets/StreamingAssets folder). /// Models with .bin format are allowed. [ModelAdvanced] public string lora = ""; @@ -81,6 +82,7 @@ public class LLM : MonoBehaviour Thread llmThread = null; List streamWrappers = new List(); public LLMManager llmManager = new LLMManager(); + List loraWeights = new List(); /// \endcond @@ -135,7 +137,7 @@ public string GetModelLoraPath(string path) return path; } - public string SetModelLoraPath(string path, bool lora) + public string GetModelLoraPath(string path, bool lora) { if (string.IsNullOrEmpty(path)) return path; ModelEntry modelEntry = LLMManager.Get(path); @@ -167,7 +169,7 @@ public string SetModelLoraPath(string path, bool lora) /// path to model to use (.gguf format) public void SetModel(string path) { - model = SetModelLoraPath(path, false); + model = GetModelLoraPath(path, false); if (!string.IsNullOrEmpty(model)) { ModelEntry modelEntry = LLMManager.Get(model); @@ -187,7 +189,43 @@ public void SetModel(string path) /// path to LORA model to use (.bin format) public void SetLora(string path) { - lora = SetModelLoraPath(path, true); + lora = ""; + AddLora(path); + } + + /// + /// Allows to add a LORA model to use in the LLM. + /// The model provided is copied to the Assets/StreamingAssets folder that allows it to also work in the build. + /// Models supported are in .bin format. + /// + /// path to LORA model to use (.bin format) + public void AddLora(string path) + { + string loraPath = GetModelLoraPath(path, true); + if (lora.Split(" ").Contains(loraPath)) return; + if (lora != "") lora += " "; + lora += loraPath; +#if UNITY_EDITOR + if (!EditorApplication.isPlaying) EditorUtility.SetDirty(this); +#endif + } + + /// + /// Allows to remove a LORA model from the LLM. + /// Models supported are in .bin format. 
+ /// + /// path to LORA model to remove (.bin format) + public void RemoveLora(string path) + { + string loraPath = GetModelLoraPath(path, true); + List loras = new List(lora.Split(" ")); + loras.Remove(loraPath); + lora = ""; + for (int i = 0; i < loras.Count; i++) + { + if (i > 0) lora += " "; + lora += loras[i]; + } #if UNITY_EDITOR if (!EditorApplication.isPlaying) EditorUtility.SetDirty(this); #endif @@ -229,15 +267,16 @@ protected virtual string GetLlamaccpArguments() LLMUnitySetup.LogError($"File {modelPath} not found!"); return null; } - string loraPath = ""; - if (lora != "") + string loraArgument = ""; + foreach (string lora in lora.Split(" ")) { - loraPath = GetModelLoraPath(lora); + string loraPath = GetModelLoraPath(lora); if (!File.Exists(loraPath)) { LLMUnitySetup.LogError($"File {loraPath} not found!"); return null; } + loraArgument += $" --lora \"{loraPath}\""; } int numThreadsToUse = numThreads; @@ -247,7 +286,7 @@ protected virtual string GetLlamaccpArguments() string arguments = $"-m \"{modelPath}\" -c {contextSize} -b {batchSize} --log-disable -np {slots}"; if (remote) arguments += $" --port {port} --host 0.0.0.0"; if (numThreadsToUse > 0) arguments += $" -t {numThreadsToUse}"; - if (loraPath != "") arguments += $" --lora \"{loraPath}\""; + arguments += loraArgument; arguments += $" -ngl {numGPULayers}"; return arguments; } @@ -323,6 +362,8 @@ private void StartService() llmThread = new Thread(() => llmlib.LLM_Start(LLMObject)); llmThread.Start(); while (!llmlib.LLM_Started(LLMObject)) {} + loraWeights = new List(); + for (int i = 0; i < lora.Split(" ").Count(); i++) loraWeights.Add(1f); started = true; } @@ -345,7 +386,7 @@ protected int GetNumClients() /// \cond HIDE public delegate void LLMStatusCallback(IntPtr LLMObject, IntPtr stringWrapper); - public delegate void LLMSimpleCallback(IntPtr LLMObject, string json_data); + public delegate void LLMNoInputReplyCallback(IntPtr LLMObject, IntPtr stringWrapper); public delegate void LLMReplyCallback(IntPtr LLMObject, string json_data, IntPtr stringWrapper); /// \endcond @@ -400,6 +441,17 @@ void CheckLLMStatus(bool log = true) } } + async Task LLMNoInputReply(LLMNoInputReplyCallback callback) + { + AssertStarted(); + IntPtr stringWrapper = llmlib.StringWrapper_Construct(); + await Task.Run(() => callback(LLMObject, stringWrapper)); + string result = llmlib?.GetStringWrapperResult(stringWrapper); + llmlib?.StringWrapper_Delete(stringWrapper); + CheckLLMStatus(); + return result; + } + async Task LLMReply(LLMReplyCallback callback, string json) { AssertStarted(); @@ -441,6 +493,74 @@ public async Task Detokenize(string json) return await LLMReply(callback, json); } + /// + /// Computes the embeddings of the provided query. 
+    /// </summary>
+    /// <param name="json">json request containing the query</param>
+    /// <returns>embeddings result</returns>
+    public async Task<string> Embeddings(string json)
+    {
+        AssertStarted();
+        LLMReplyCallback callback = (IntPtr LLMObject, string jsonData, IntPtr strWrapper) =>
+        {
+            llmlib.LLM_Embeddings(LLMObject, jsonData, strWrapper);
+        };
+        return await LLMReply(callback, json);
+    }
+
+    /// <summary>
+    /// Sets the lora scale; this only works after the LLM service has started.
+    /// </summary>
+    /// <returns>switch result</returns>
+    public async Task<string> SetLoraScale(string loraToScale, float scale)
+    {
+        AssertStarted();
+        List<string> loras = new List<string>(lora.Split(" "));
+        string loraToScalePath = GetModelLoraPath(loraToScale, true);
+
+        int index = loras.IndexOf(loraToScale);
+        if (index == -1) index = loras.IndexOf(loraToScalePath);
+        if (index == -1)
+        {
+            LLMUnitySetup.LogError($"LoRA {loraToScale} not loaded with the LLM");
+            return "";
+        }
+
+        loraWeights[index] = scale;
+        LoraWeightRequestList loraWeightRequest = new LoraWeightRequestList();
+        loraWeightRequest.loraWeights = new List<LoraWeightRequest>();
+        for (int i = 0; i < loraWeights.Count; i++)
+        {
+            loraWeightRequest.loraWeights.Add(new LoraWeightRequest() {id = i, scale = loraWeights[i]});
+        }
+
+        string json = JsonUtility.ToJson(loraWeightRequest);
+        int startIndex = json.IndexOf("[");
+        int endIndex = json.LastIndexOf("]") + 1;
+        json = json.Substring(startIndex, endIndex - startIndex);
+
+        LLMReplyCallback callback = (IntPtr LLMObject, string jsonData, IntPtr strWrapper) =>
+        {
+            llmlib.LLM_Lora_Weight(LLMObject, jsonData, strWrapper);
+        };
+        return await LLMReply(callback, json);
+    }
+
+    /// <summary>
+    /// Gets a list of the lora adapters
+    /// </summary>
+    /// <returns>list of lora adapters</returns>
+    public async Task<string> ListLora()
+    {
+        AssertStarted();
+        LLMNoInputReplyCallback callback = (IntPtr LLMObject, IntPtr strWrapper) =>
+        {
+            llmlib.LLM_LoraList(LLMObject, strWrapper);
+        };
+        return await LLMNoInputReply(callback);
+    }
+
     /// <summary>
     /// Allows to save / restore the state of a slot
     /// </summary>

From 3dc59af6b93fe57425ac588a75f8b429a5c16ab0 Mon Sep 17 00:00:00 2001
From: Antonis Makropoulos
Date: Mon, 19 Aug 2024 16:48:13 +0300
Subject: [PATCH 10/67] fix for lora splitting

---
 Runtime/LLM.cs | 11 ++++++-----
 1 file changed, 6 insertions(+), 5 deletions(-)

diff --git a/Runtime/LLM.cs b/Runtime/LLM.cs
index fad6b5d6..a691cf29 100644
--- a/Runtime/LLM.cs
+++ b/Runtime/LLM.cs
@@ -130,7 +130,7 @@ public static async Task WaitUntilModelSetup(Callback downloadProgr
         return !modelSetupFailed;
     }

-    public string GetModelLoraPath(string path)
+    public string GetModelLoraPathRuntime(string path)
     {
         string assetPath = LLMManager.GetAssetPath(path);
         if (!string.IsNullOrEmpty(assetPath)) return assetPath;
@@ -173,7 +173,7 @@ public void SetModel(string path)
         if (!string.IsNullOrEmpty(model))
         {
             ModelEntry modelEntry = LLMManager.Get(model);
-            string template = modelEntry != null ? 
modelEntry.chatTemplate : ChatTemplate.FromGGUF(GetModelLoraPathRuntime(model)); SetTemplate(template); } #if UNITY_EDITOR @@ -261,16 +261,17 @@ protected virtual string GetLlamaccpArguments() LLMUnitySetup.LogError("No model file provided!"); return null; } - string modelPath = GetModelLoraPath(model); + string modelPath = GetModelLoraPathRuntime(model); if (!File.Exists(modelPath)) { LLMUnitySetup.LogError($"File {modelPath} not found!"); return null; } string loraArgument = ""; - foreach (string lora in lora.Split(" ")) + foreach (string lora in lora.Trim().Split(" ")) { - string loraPath = GetModelLoraPath(lora); + if (lora == "") continue; + string loraPath = GetModelLoraPathRuntime(lora); if (!File.Exists(loraPath)) { LLMUnitySetup.LogError($"File {loraPath} not found!"); From 7c85dc4fc62ecd5ca224e54642e2ca45a29bd9af Mon Sep 17 00:00:00 2001 From: Antonis Makropoulos Date: Mon, 19 Aug 2024 16:49:01 +0300 Subject: [PATCH 11/67] update to latest LlamaLib, add embedding test --- Tests/Runtime/TestLLM.cs | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/Tests/Runtime/TestLLM.cs b/Tests/Runtime/TestLLM.cs index b5afc880..affca075 100644 --- a/Tests/Runtime/TestLLM.cs +++ b/Tests/Runtime/TestLLM.cs @@ -7,6 +7,7 @@ using System.Collections; using UnityEngine.TestTools; using System.IO; +using NUnit.Framework.Internal; namespace LLMUnityTests { @@ -86,6 +87,8 @@ public virtual async Task RunTests() llmCharacter.SetPrompt(prompt); await llmCharacter.Chat("hi"); TestInitParameters((await llmCharacter.Tokenize(prompt)).Count + 2, 3); + List embeddings = await llmCharacter.Embeddings("hi how are you?"); + TestEmbeddings(embeddings); llm.OnDestroy(); } catch (Exception e) @@ -126,7 +129,7 @@ public void TestWarmup() public void TestChat(string reply) { - string AIReply = "To increase your meme production/output, you can consider the following:\n1. 
Use"; + string AIReply = "One way to increase your meme production/output is by creating a more complex and customized"; Assert.That(reply.Trim() == AIReply); } @@ -141,6 +144,11 @@ public void TestPostChat(int num) Assert.That(llmCharacter.chat.Count == num); } + public void TestEmbeddings(List embeddings) + { + Assert.That(embeddings.Count == 1024); + } + public virtual void OnDestroy() { LLMManager.Remove(filename); From 5c4c04d67d0c84b769d1b430e284ef6eaa8cb174 Mon Sep 17 00:00:00 2001 From: amakropoulos Date: Mon, 19 Aug 2024 13:51:12 +0000 Subject: [PATCH 12/67] update changelogs --- CHANGELOG.md | 4 ++++ CHANGELOG.release.md | 4 ++++ 2 files changed, 8 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index bba21a49..018166f8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,4 +1,8 @@ ## v2.1.2 +#### 🚀 Features + +- Implement embedding and lora adapter functionality (PR: #210) + #### 🐛 Fixes - Fix set template for remote setup (PR: #208) diff --git a/CHANGELOG.release.md b/CHANGELOG.release.md index 50bd4944..aff199ee 100644 --- a/CHANGELOG.release.md +++ b/CHANGELOG.release.md @@ -1,3 +1,7 @@ +### 🚀 Features + +- Implement embedding and lora adapter functionality (PR: #210) + ### 🐛 Fixes - Fix set template for remote setup (PR: #208) From 3ab6cb417902c5037985e9c73f27a294ef680657 Mon Sep 17 00:00:00 2001 From: Antonis Makropoulos Date: Mon, 19 Aug 2024 16:53:04 +0300 Subject: [PATCH 13/67] update changelogs --- CHANGELOG.md | 1 + CHANGELOG.release.md | 1 + 2 files changed, 2 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 018166f8..d93d512c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,7 @@ ## v2.1.2 #### 🚀 Features +- Update to latest llama.cpp (b3600) (PR: #210) - Implement embedding and lora adapter functionality (PR: #210) #### 🐛 Fixes diff --git a/CHANGELOG.release.md b/CHANGELOG.release.md index aff199ee..388aeb28 100644 --- a/CHANGELOG.release.md +++ b/CHANGELOG.release.md @@ -1,5 +1,6 @@ ### 🚀 Features +- Update to latest llama.cpp (b3600) (PR: #210) - Implement embedding and lora adapter functionality (PR: #210) ### 🐛 Fixes From b88e13b1b3080eecc9d3a57cc19564b557fd764f Mon Sep 17 00:00:00 2001 From: Antonis Makropoulos Date: Mon, 19 Aug 2024 18:56:35 +0300 Subject: [PATCH 14/67] Read context length and warn if it is very large --- Runtime/LLM.cs | 9 +++++++-- Runtime/LLMChatTemplates.cs | 8 ++++++-- Runtime/LLMGGUF.cs | 25 ++++++++++++++++++++++--- Runtime/LLMManager.cs | 12 ++++++++++-- 4 files changed, 45 insertions(+), 9 deletions(-) diff --git a/Runtime/LLM.cs b/Runtime/LLM.cs index a691cf29..41001fb8 100644 --- a/Runtime/LLM.cs +++ b/Runtime/LLM.cs @@ -173,8 +173,13 @@ public void SetModel(string path) if (!string.IsNullOrEmpty(model)) { ModelEntry modelEntry = LLMManager.Get(model); - string template = modelEntry != null ? 
modelEntry.chatTemplate : ChatTemplate.FromGGUF(GetModelLoraPathRuntime(model)); - SetTemplate(template); + if (modelEntry == null) modelEntry = new ModelEntry(GetModelLoraPathRuntime(model)); + SetTemplate(modelEntry.chatTemplate); + Debug.Log(modelEntry.contextLength); + if (contextSize == 0 && modelEntry.contextLength > 32768) + { + LLMUnitySetup.LogWarning($"The model {path} has very large context size ({modelEntry.contextLength}), consider setting it to a smaller value (<=32768) to avoid filling up the RAM"); + } } #if UNITY_EDITOR if (!EditorApplication.isPlaying) EditorUtility.SetDirty(this); diff --git a/Runtime/LLMChatTemplates.cs b/Runtime/LLMChatTemplates.cs index f6aa22cc..060b3599 100644 --- a/Runtime/LLMChatTemplates.cs +++ b/Runtime/LLMChatTemplates.cs @@ -1,5 +1,6 @@ /// @file /// @brief File implementing the chat templates. +using System; using System.Collections.Generic; using System.IO; using UnityEngine; @@ -113,9 +114,12 @@ public static string FromTemplate(string template) /// template name public static string FromGGUF(string path) { - GGUFReader reader = new GGUFReader(path); - string name; + return FromGGUF(new GGUFReader(path), path); + } + public static string FromGGUF(GGUFReader reader, string path) + { + string name; name = FromTemplate(reader.GetStringField("tokenizer.chat_template")); if (name != null) return name; diff --git a/Runtime/LLMGGUF.cs b/Runtime/LLMGGUF.cs index 4db81061..2ebeaf80 100644 --- a/Runtime/LLMGGUF.cs +++ b/Runtime/LLMGGUF.cs @@ -125,6 +125,13 @@ public ReaderField GetField(string key) return null; } + public byte[] GetGenericField(string key) + { + ReaderField field = GetField(key); + if (field == null || field.parts.Count == 0) return null; + return (byte[])field.parts[field.parts.Count - 1]; + } + /// /// Allows to retrieve a string GGUF field. /// @@ -132,9 +139,21 @@ public ReaderField GetField(string key) /// Retrieved GGUF value public string GetStringField(string key) { - ReaderField field = GetField(key); - if (field == null || field.parts.Count == 0) return null; - return System.Text.Encoding.UTF8.GetString((byte[])field.parts[field.parts.Count - 1]); + byte[] value = GetGenericField(key); + if (value == null) return null; + return System.Text.Encoding.UTF8.GetString(value); + } + + /// + /// Allows to retrieve an integer GGUF field. + /// + /// GGUF field to retrieve + /// Retrieved GGUF value + public int GetIntField(string key) + { + byte[] value = GetGenericField(key); + if (value == null) return -1; + return BitConverter.ToInt32(value, 0); } private byte[] ReadBytes(int offset, int count) diff --git a/Runtime/LLMManager.cs b/Runtime/LLMManager.cs index ff1101d8..3c7cbc24 100644 --- a/Runtime/LLMManager.cs +++ b/Runtime/LLMManager.cs @@ -17,7 +17,7 @@ public class ModelEntry public string chatTemplate; public string url; public bool includeInBuild; - + public int contextLength; public ModelEntry(string path, bool lora = false, string label = null, string url = null) { @@ -25,9 +25,17 @@ public ModelEntry(string path, bool lora = false, string label = null, string ur this.label = label == null ? filename : label; this.lora = lora; this.path = Path.GetFullPath(path).Replace('\\', '/'); - chatTemplate = lora ? 
null : ChatTemplate.FromGGUF(this.path);
         this.url = url;
         includeInBuild = true;
+        chatTemplate = null;
+        contextLength = -1;
+        if (!lora)
+        {
+            GGUFReader reader = new GGUFReader(this.path);
+            chatTemplate = ChatTemplate.FromGGUF(reader, this.path);
+            string arch = reader.GetStringField("general.architecture");
+            if (arch != null) contextLength = reader.GetIntField($"{arch}.context_length");
+        }
     }

     public ModelEntry OnlyRequiredFields()
     {

From 56ec876d3bbc0828b12aecf8c3630a6796098dd1 Mon Sep 17 00:00:00 2001
From: amakropoulos
Date: Mon, 19 Aug 2024 15:57:34 +0000
Subject: [PATCH 15/67] update changelogs

---
 CHANGELOG.md | 2 +-
 CHANGELOG.release.md | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index d93d512c..f3394aea 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,8 +1,8 @@
 ## v2.1.2
 #### 🚀 Features

-- Update to latest llama.cpp (b3600) (PR: #210)
 - Implement embedding and lora adapter functionality (PR: #210)
+- Read context length and warn if it is very large (PR: #211)

 #### 🐛 Fixes

diff --git a/CHANGELOG.release.md b/CHANGELOG.release.md
index 388aeb28..7e80ff57 100644
--- a/CHANGELOG.release.md
+++ b/CHANGELOG.release.md
@@ -1,7 +1,7 @@
 ### 🚀 Features

-- Update to latest llama.cpp (b3600) (PR: #210)
 - Implement embedding and lora adapter functionality (PR: #210)
+- Read context length and warn if it is very large (PR: #211)

 ### 🐛 Fixes

From 9fdd5857342b8d381989f6cb7d134533faec72ee Mon Sep 17 00:00:00 2001
From: Antonis Makropoulos
Date: Mon, 19 Aug 2024 19:01:24 +0300
Subject: [PATCH 16/67] remove debug message

---
 Runtime/LLM.cs | 1 -
 1 file changed, 1 deletion(-)

diff --git a/Runtime/LLM.cs b/Runtime/LLM.cs
index 41001fb8..3d7b599b 100644
--- a/Runtime/LLM.cs
+++ b/Runtime/LLM.cs
@@ -175,7 +175,6 @@ public void SetModel(string path)
             ModelEntry modelEntry = LLMManager.Get(model);
             if (modelEntry == null) modelEntry = new ModelEntry(GetModelLoraPathRuntime(model));
             SetTemplate(modelEntry.chatTemplate);
-            Debug.Log(modelEntry.contextLength);
             if (contextSize == 0 && modelEntry.contextLength > 32768)
             {

From 501577e19e6b90511b8c7ef334f3f911b66eb205 Mon Sep 17 00:00:00 2001
From: Antonis Makropoulos
Date: Mon, 19 Aug 2024 19:26:12 +0300
Subject: [PATCH 17/67] add Llama 3.1 and Gemma2 models

---
 README.md | 4 +++-
 Runtime/LLMUnitySetup.cs | 3 ++-
 Third Party Notices.md | 34 +++++++++++++++++++++++++----------
 3 files changed, 30 insertions(+), 11 deletions(-)

diff --git a/README.md b/README.md
index 8b2e15d0..f7a828e9 100644
--- a/README.md
+++ b/README.md
@@ -446,4 +446,6 @@ If it is not selected, the full reply from the model is received in one go

 ## License

-The license of LLM for Unity is MIT ([LICENSE.md](LICENSE.md)) and uses third-party software with MIT and Apache licenses ([Third Party Notices.md]()).
+The license of LLM for Unity is MIT ([LICENSE.md](LICENSE.md)) and uses third-party software with MIT and Apache licenses.
+Some models included in the asset define their own license terms; please review them before using each model.
+Third-party licenses can be found in [Third Party Notices.md]().
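The ModelEntry constructor in the LLMManager patch above reads the chat template and context length straight from the GGUF metadata. A minimal sketch of querying the same GGUFReader fields directly, under the assumption that a model file is already in place (the path below is hypothetical):

```csharp
using LLMUnity;
using UnityEngine;

public class ContextLengthCheck : MonoBehaviour
{
    void Start()
    {
        // parse the GGUF header; the model path is a hypothetical example
        GGUFReader reader = new GGUFReader("Assets/StreamingAssets/model.gguf");
        // the context length key is architecture-specific, e.g. "llama.context_length"
        string arch = reader.GetStringField("general.architecture");
        int contextLength = arch == null ? -1 : reader.GetIntField($"{arch}.context_length");
        if (contextLength > 32768) Debug.LogWarning($"Very large context size ({contextLength})");
    }
}
```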
diff --git a/Runtime/LLMUnitySetup.cs b/Runtime/LLMUnitySetup.cs index 4ee87a66..285eee78 100644 --- a/Runtime/LLMUnitySetup.cs +++ b/Runtime/LLMUnitySetup.cs @@ -101,7 +101,8 @@ public class LLMUnitySetup /// Default models for download [HideInInspector] public static readonly (string, string, string)[] modelOptions = new(string, string, string)[] { - ("Llama 3 7B (medium, best overall)", "https://huggingface.co/lmstudio-community/Meta-Llama-3-8B-Instruct-GGUF/resolve/main/Meta-Llama-3-8B-Instruct-Q4_K_M.gguf?download=true", "https://huggingface.co/meta-llama/Meta-Llama-3-8B/blob/main/LICENSE"), + ("Llama 3.1 8B (medium, best overall)", "https://huggingface.co/bartowski/Meta-Llama-3.1-8B-Instruct-GGUF/resolve/main/Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf?download=true", "https://huggingface.co/meta-llama/Meta-Llama-3.1-8B/blob/main/LICENSE"), + ("Gemma 2 9B it (medium, great overall)", "https://huggingface.co/bartowski/gemma-2-9b-it-GGUF/resolve/main/gemma-2-9b-it-Q4_K_M.gguf?download=true", "https://ai.google.dev/gemma/terms"), ("Mistral 7B Instruct v0.2 (medium, great overall)", "https://huggingface.co/TheBloke/Mistral-7B-Instruct-v0.2-GGUF/resolve/main/mistral-7b-instruct-v0.2.Q4_K_M.gguf?download=true", null), ("OpenHermes 2.5 7B (medium, good for conversation)", "https://huggingface.co/TheBloke/OpenHermes-2.5-Mistral-7B-GGUF/resolve/main/openhermes-2.5-mistral-7b.Q4_K_M.gguf?download=true", null), ("Phi 3 (small, great small model)", "https://huggingface.co/microsoft/Phi-3-mini-4k-instruct-gguf/resolve/main/Phi-3-mini-4k-instruct-q4.gguf?download=true", null), diff --git a/Third Party Notices.md b/Third Party Notices.md index ca86d120..3daa05a2 100644 --- a/Third Party Notices.md +++ b/Third Party Notices.md @@ -26,19 +26,35 @@ License: [link](https://github.com/Mozilla-Ocho/llamafile/blob/main/LICENSE) The following models can be downloaded with LLMUnity: -### meta-llama/Meta-Llama-3-8B-Instruct +### meta-llama/Meta-Llama-3.1-8B-Instruct Developer: Meta
-Origin: [link](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct)
-License Type: "llama3"
-License: [link](https://huggingface.co/meta-llama/Meta-Llama-3-8B/blob/main/LICENSE) +Origin: [link](https://huggingface.co/meta-llama/Meta-Llama-3.1-8B-Instruct)
+License Type: "llama3.1"
+License: [link](https://huggingface.co/meta-llama/Meta-Llama-3.1-8B/blob/main/LICENSE) -##### modified by: lmstudio-community/Meta-Llama-3-8B-Instruct-GGUF +##### modified by: bartowski/Meta-Llama-3.1-8B-Instruct-GGUF -Developer: LM Studio<br>
-Origin: [link](https://huggingface.co/TheBloke/Mistral-7B-Instruct-v0.2-GGUF)
-License Type: "llama3"
-License: [link](https://huggingface.co/meta-llama/Meta-Llama-3-8B/blob/main/LICENSE) +Developer: bartowski<br>
+Origin: [link](https://huggingface.co/bartowski/Meta-Llama-3.1-8B-Instruct-GGUF)
+License Type: "llama3.1"
+License: [link](https://huggingface.co/meta-llama/Meta-Llama-3.1-8B/blob/main/LICENSE) + +
+ +### google/gemma-2-9b-it + +Developer: Google
+Origin: [link](https://huggingface.co/google/gemma-2-9b-it)
+License Type: "gemma"
+License: [link](https://ai.google.dev/gemma/terms) + +##### modified by: bartowski/gemma-2-9b-it-GGUF + +Developer: bartowski<br>
+Origin: [link](https://huggingface.co/bartowski/gemma-2-9b-it-GGUF)
+License Type: "gemma"
+License: [link](https://ai.google.dev/gemma/terms)
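Taken together, the embedding and LoRA patches above expose a small public API on LLM and LLMCharacter. A minimal usage sketch, assuming both components are assigned in the inspector; the adapter file names are hypothetical:

```csharp
using System.Collections.Generic;
using System.Threading.Tasks;
using LLMUnity;
using UnityEngine;

public class EmbeddingAndLoraExample : MonoBehaviour
{
    public LLM llm;                   // assumed to be assigned in the inspector
    public LLMCharacter llmCharacter; // assumed to be assigned in the inspector

    async void Start()
    {
        // adapters are stored space-separated in llm.lora; the file names are hypothetical
        llm.SetLora("style-a.gguf");  // replaces any previously set adapters
        llm.AddLora("style-b.gguf");  // appends a second adapter

        // SetLoraScale asserts that the service has started, so wait for startup first
        while (!llm.started) await Task.Yield();
        await llm.SetLoraScale("style-b.gguf", 0.5f); // scale the second adapter to 50%
        Debug.Log(await llm.ListLora());              // print the loaded adapters

        // compute the embeddings of a query via the "embeddings" endpoint
        List<float> embedding = await llmCharacter.Embeddings("hi how are you?");
        Debug.Log($"embedding dimension: {embedding.Count}");
    }
}
```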
From ff1f300770e1959e178d48bbeff95e2f8cdaec3e Mon Sep 17 00:00:00 2001
From: Antonis Makropoulos
Date: Mon, 19 Aug 2024 20:41:17 +0300
Subject: [PATCH 18/67] add Gemma chat template

---
 Runtime/LLMCharacter.cs | 4 +--
 Runtime/LLMChatTemplates.cs | 48 +++++++++++++++++++++++++--
 Tests/Runtime/TestLLMChatTemplates.cs | 31 +++++++++++------
 3 files changed, 67 insertions(+), 16 deletions(-)

diff --git a/Runtime/LLMCharacter.cs b/Runtime/LLMCharacter.cs
index c76239aa..3b002564 100644
--- a/Runtime/LLMCharacter.cs
+++ b/Runtime/LLMCharacter.cs
@@ -274,7 +274,7 @@ private async Task InitNKeep()
     {
         if (setNKeepToPrompt && nKeep == -1)
         {
-            string systemPrompt = template.ComputePrompt(new List<ChatMessage>(){chat[0]}, "", false);
+            string systemPrompt = template.ComputePrompt(new List<ChatMessage>(){chat[0]}, playerName, "", false);
             await Tokenize(systemPrompt, SetNKeep);
         }
     }
@@ -472,7 +472,7 @@ public async Task Chat(string query, Callback callback = null, E
         try
         {
             AddPlayerMessage(query);
-            string prompt = template.ComputePrompt(chat, AIName);
+            string prompt = template.ComputePrompt(chat, playerName, AIName);
             json = JsonUtility.ToJson(GenerateRequest(prompt));
             chat.RemoveAt(chat.Count - 1);
         }
diff --git a/Runtime/LLMChatTemplates.cs b/Runtime/LLMChatTemplates.cs
index 060b3599..0820078d 100644
--- a/Runtime/LLMChatTemplates.cs
+++ b/Runtime/LLMChatTemplates.cs
@@ -37,6 +37,7 @@ static ChatTemplate()
         {
             new ChatMLTemplate(),
             new AlpacaTemplate(),
+            new GemmaTemplate(),
             new MistralChatTemplate(),
             new MistralInstructTemplate(),
             new LLama3ChatTemplate(),
@@ -169,7 +170,7 @@ public static ChatTemplate GetTemplate(string template)
    /// <param name="AIName">the AI name</param>
    /// <param name="endWithPrefix">whether to end the prompt with the AI prefix</param>
    /// <returns>prompt</returns>
-    public virtual string ComputePrompt(List<ChatMessage> messages, string AIName, bool endWithPrefix = true)
+    public virtual string ComputePrompt(List<ChatMessage> messages, string playerName, string AIName, bool endWithPrefix = true)
     {
         string chatPrompt = PromptPrefix();
         int start = 0;
@@ -336,6 +337,47 @@ public override string[] GetStop(string playerName, string AIName)
         }
     }

+    /// @ingroup template
+    /// <summary>
+    /// Class implementing the Gemma template
+    /// </summary>
+    public class GemmaTemplate : ChatTemplate
+    {
+        public override string GetName() { return "gemma"; }
+        public override string GetDescription() { return "gemma"; }
+        public override string[] GetNameMatches() { return new string[] {"gemma"}; }
+
+        protected override string RequestSuffix() { return "<end_of_turn>\n"; }
+        protected override string PairSuffix() { return "<end_of_turn>\n"; }
+
+        protected override string PlayerPrefix(string playerName) { return "<start_of_turn>" + playerName + "\n"; }
+        protected override string AIPrefix(string AIName) { return "<start_of_turn>" + AIName + "\n"; }
+
+        public override string ComputePrompt(List<ChatMessage> messages, string playerName, string AIName, bool endWithPrefix = true)
+        {
+            List<ChatMessage> messagesSystemPrompt = messages;
+            if (messages[0].role == "system")
+            {
+                string firstUserMessage = messages[0].content;
+                int start = 1;
+                if (messages.Count > 1)
+                {
+                    if (firstUserMessage != "") firstUserMessage += "\n\n";
+                    firstUserMessage += messages[1].content;
+                    start = 2;
+                }
+                messagesSystemPrompt = new List<ChatMessage>(){new ChatMessage { role = playerName, content = firstUserMessage }};
+                messagesSystemPrompt.AddRange(messages.GetRange(start, messages.Count - start));
+            }
+            return base.ComputePrompt(messagesSystemPrompt, playerName, AIName, endWithPrefix);
+        }
+
+        public override string[] GetStop(string playerName, string AIName)
+        {
+            return AddStopNewlines(new string[] { "<start_of_turn>", "<end_of_turn>" });
+        }
+    }
+
     /// @ingroup template
     /// <summary>
     /// Class implementing 
the Alpaca template
@@ -421,7 +463,7 @@ public class Phi3Template : ChatTemplate

         protected override string PairSuffix() { return "<|end|>\n"; }

-        public override string ComputePrompt(List<ChatMessage> messages, string AIName, bool endWithPrefix = true)
+        public override string ComputePrompt(List<ChatMessage> messages, string playerName, string AIName, bool endWithPrefix = true)
         {
             List<ChatMessage> messagesSystemPrompt = messages;
             if (messages[0].role == "system")
             {
@@ -437,7 +479,7 @@ public override string ComputePrompt(List messages, string AIName,
                 messagesSystemPrompt = new List<ChatMessage>(){new ChatMessage { role = "user", content = firstUserMessage }};
                 messagesSystemPrompt.AddRange(messages.GetRange(start, messages.Count - start));
             }
-            return base.ComputePrompt(messagesSystemPrompt, AIName, endWithPrefix);
+            return base.ComputePrompt(messagesSystemPrompt, playerName, AIName, endWithPrefix);
         }

         public override string[] GetStop(string playerName, string AIName)
diff --git a/Tests/Runtime/TestLLMChatTemplates.cs b/Tests/Runtime/TestLLMChatTemplates.cs
index b48a2bef..bfedf6c8 100644
--- a/Tests/Runtime/TestLLMChatTemplates.cs
+++ b/Tests/Runtime/TestLLMChatTemplates.cs
@@ -21,16 +21,25 @@ public class TestChatTemplate
     public void TestChatML()
     {
         Assert.AreEqual(
-            new ChatMLTemplate().ComputePrompt(messages, "assistant"),
+            new ChatMLTemplate().ComputePrompt(messages, "user", "assistant"),
             "<|im_start|>system\nyou are a bot<|im_end|>\n<|im_start|>user\nHello, how are you?<|im_end|>\n<|im_start|>assistant\nI'm doing great. How can I help you today?<|im_end|>\n<|im_start|>user\nI'd like to show off how chat templating works!<|im_end|>\n<|im_start|>assistant\nchat template is awesome<|im_end|>\n<|im_start|>user\ndo you think so?<|im_end|>\n<|im_start|>assistant\n"
         );
     }

+    [Test]
+    public void TestGemma()
+    {
+        Assert.AreEqual(
+            new GemmaTemplate().ComputePrompt(messages, "user", "assistant"),
+            "<start_of_turn>user\nyou are a bot\n\nHello, how are you?<end_of_turn>\n<start_of_turn>assistant\nI'm doing great. How can I help you today?<end_of_turn>\n<start_of_turn>user\nI'd like to show off how chat templating works!<end_of_turn>\n<start_of_turn>assistant\nchat template is awesome<end_of_turn>\n<start_of_turn>user\ndo you think so?<end_of_turn>\n<start_of_turn>assistant\n"
+        );
+    }
+
     [Test]
     public void TestMistralInstruct()
     {
         Assert.AreEqual(
-            new MistralInstructTemplate().ComputePrompt(messages, "assistant"),
+            new MistralInstructTemplate().ComputePrompt(messages, "user", "assistant"),
             "[INST] you are a bot\n\nHello, how are you? [/INST]I'm doing great. How can I help you today?[INST] I'd like to show off how chat templating works! [/INST]chat template is awesome[INST] do you think so? [/INST]"
         );
     }

     [Test]
     public void TestMistralChat()
     {
         Assert.AreEqual(
-            new MistralChatTemplate().ComputePrompt(messages, "assistant"),
+            new MistralChatTemplate().ComputePrompt(messages, "user", "assistant"),
             "[INST] you are a bot\n\n### user: Hello, how are you? [/INST]### assistant: I'm doing great. How can I help you today?[INST] ### user: I'd like to show off how chat templating works! [/INST]### assistant: chat template is awesome[INST] ### user: do you think so? [/INST]### assistant:"
         );
     }

     [Test]
     public void TestLLama2()
     {
         Assert.AreEqual(
-            new LLama2Template().ComputePrompt(messages, "assistant"),
+            new LLama2Template().ComputePrompt(messages, "user", "assistant"),
             "[INST] <<SYS>>\nyou are a bot\n<</SYS>> Hello, how are you? [/INST]I'm doing great. How can I help you today? [INST] I'd like to show off how chat templating works! [/INST]chat template is awesome [INST] do you think so? 
[/INST]"
         );
     }

     [Test]
     public void TestLLama2Chat()
     {
         Assert.AreEqual(
-            new LLama2ChatTemplate().ComputePrompt(messages, "assistant"),
+            new LLama2ChatTemplate().ComputePrompt(messages, "user", "assistant"),
             "[INST] <<SYS>>\nyou are a bot\n<</SYS>> ### user: Hello, how are you? [/INST]### assistant: I'm doing great. How can I help you today? [INST] ### user: I'd like to show off how chat templating works! [/INST]### assistant: chat template is awesome [INST] ### user: do you think so? [/INST]### assistant:"
         );
     }

     [Test]
     public void TestLLama3Chat()
     {
         Assert.AreEqual(
-            new LLama3ChatTemplate().ComputePrompt(messages, "assistant"),
+            new LLama3ChatTemplate().ComputePrompt(messages, "user", "assistant"),
             "<|start_header_id|>system<|end_header_id|>\n\nyou are a bot<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nHello, how are you?<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\nI'm doing great. How can I help you today?<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nI'd like to show off how chat templating works!<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\nchat template is awesome<|eot_id|><|start_header_id|>user<|end_header_id|>\n\ndo you think so?<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
         );
     }

     [Test]
     public void TestAlpaca()
     {
         Assert.AreEqual(
-            new AlpacaTemplate().ComputePrompt(messages, "assistant"),
+            new AlpacaTemplate().ComputePrompt(messages, "user", "assistant"),
             "you are a bot\n\n### user: Hello, how are you?\n### assistant: I'm doing great. How can I help you today?\n### user: I'd like to show off how chat templating works!\n### assistant: chat template is awesome\n### user: do you think so?\n### assistant:"
         );
     }

     [Test]
     public void TestVicuna()
     {
         Assert.AreEqual(
-            new VicunaTemplate().ComputePrompt(messages, "assistant"),
+            new VicunaTemplate().ComputePrompt(messages, "user", "assistant"),
             "you are a bot\n\nuser: Hello, how are you?\nassistant: I'm doing great. How can I help you today?\nuser: I'd like to show off how chat templating works!\nassistant: chat template is awesome\nuser: do you think so?\nassistant:"
         );
     }

     [Test]
     public void TestPhi2()
     {
         Assert.AreEqual(
-            new Phi2Template().ComputePrompt(messages, "assistant"),
+            new Phi2Template().ComputePrompt(messages, "user", "assistant"),
             "you are a bot\n\nuser: Hello, how are you?\nassistant: I'm doing great. How can I help you today?\nuser: I'd like to show off how chat templating works!\nassistant: chat template is awesome\nuser: do you think so?\nassistant:"
         );
     }

     [Test]
     public void TestPhi3()
     {
         Assert.AreEqual(
-            new Phi3Template().ComputePrompt(messages, "assistant"),
+            new Phi3Template().ComputePrompt(messages, "user", "assistant"),
             "<|user|>\nyou are a bot\n\nHello, how are you?<|end|>\n<|assistant|>\nI'm doing great. How can I help you today?<|end|>\n<|user|>\nI'd like to show off how chat templating works!<|end|>\n<|assistant|>\nchat template is awesome<|end|>\n<|user|>\ndo you think so?<|end|>\n<|assistant|>\n"
         );
     }

     [Test]
     public void TestZephyr()
     {
         Assert.AreEqual(
-            new ZephyrTemplate().ComputePrompt(messages, "assistant"),
+            new ZephyrTemplate().ComputePrompt(messages, "user", "assistant"),
             "<|system|>\nyou are a bot\n<|user|>\nHello, how are you?\n<|assistant|>\nI'm doing great. 
How can I help you today?\n<|user|>\nI'd like to show off how chat templating works!\n<|assistant|>\nchat template is awesome\n<|user|>\ndo you think so?\n<|assistant|>\n" ); } From e8cb44f39b7bc90773a224cd47b280abf3deb1d2 Mon Sep 17 00:00:00 2001 From: Antonis Makropoulos Date: Tue, 20 Aug 2024 17:56:02 +0300 Subject: [PATCH 19/67] fix crash when stopping scene before LLM creation --- Runtime/LLM.cs | 70 ++++++++++++++++++++++++++++++----------------- Runtime/LLMLib.cs | 1 - 2 files changed, 45 insertions(+), 26 deletions(-) diff --git a/Runtime/LLM.cs b/Runtime/LLM.cs index 3d7b599b..7fb460c4 100644 --- a/Runtime/LLM.cs +++ b/Runtime/LLM.cs @@ -21,6 +21,8 @@ public LLMException(string message, int errorCode) : base(message) ErrorCode = errorCode; } } + + public class DestroyException : Exception {} /// \endcond [DefaultExecutionOrder(-1)] @@ -83,6 +85,7 @@ public class LLM : MonoBehaviour List streamWrappers = new List(); public LLMManager llmManager = new LLMManager(); List loraWeights = new List(); + private readonly object startLock = new object(); /// \endcond @@ -114,6 +117,7 @@ public async void Awake() return; } await Task.Run(() => StartLLMServer(arguments)); + if (!started) return; if (dontDestroyOnLoad) DontDestroyOnLoad(transform.root.gameObject); if (basePrompt != "") await SetBasePrompt(basePrompt); } @@ -322,7 +326,7 @@ private void StartLLMServer(string arguments) try { InitLib(arch); - InitServer(arguments); + InitService(arguments); LLMUnitySetup.Log($"Using architecture: {arch}"); break; } @@ -331,6 +335,10 @@ private void StartLLMServer(string arguments) error = e.Message; Destroy(); } + catch (DestroyException) + { + break; + } catch (Exception e) { error = $"{e.GetType()}: {e.Message}"; @@ -343,7 +351,7 @@ private void StartLLMServer(string arguments) failed = true; return; } - StartService(); + CallIfNotDestroyed(() => StartService()); LLMUnitySetup.Log("LLM service created"); } @@ -353,13 +361,22 @@ private void InitLib(string arch) CheckLLMStatus(false); } - private void InitServer(string arguments) + void CallIfNotDestroyed(EmptyCallback fn) { - if (debug) SetupLogging(); - LLMObject = llmlib.LLM_Construct(arguments); - if (remote) llmlib.LLM_StartServer(LLMObject); - llmlib.LLM_SetTemplate(LLMObject, chatTemplate); - CheckLLMStatus(false); + lock (startLock) + { + if (llmlib == null) throw new DestroyException(); + fn(); + } + } + + private void InitService(string arguments) + { + if (debug) CallIfNotDestroyed(() => SetupLogging()); + CallIfNotDestroyed(() => {LLMObject = llmlib.LLM_Construct(arguments);}); + if (remote) CallIfNotDestroyed(() => llmlib.LLM_StartServer(LLMObject)); + CallIfNotDestroyed(() => llmlib.LLM_SetTemplate(LLMObject, chatTemplate)); + CallIfNotDestroyed(() => CheckLLMStatus(false)); } private void StartService() @@ -624,28 +641,31 @@ public void CancelRequest(int id_slot) /// public void Destroy() { - try + lock (startLock) { - if (llmlib != null) + try { - if (LLMObject != IntPtr.Zero) + if (llmlib != null) { - llmlib.LLM_Stop(LLMObject); - if (remote) llmlib.LLM_StopServer(LLMObject); - StopLogging(); - llmThread?.Join(); - llmlib.LLM_Delete(LLMObject); - LLMObject = IntPtr.Zero; + if (LLMObject != IntPtr.Zero) + { + llmlib.LLM_Stop(LLMObject); + if (remote) llmlib.LLM_StopServer(LLMObject); + StopLogging(); + llmThread?.Join(); + llmlib.LLM_Delete(LLMObject); + LLMObject = IntPtr.Zero; + } + llmlib.Destroy(); + llmlib = null; } - llmlib.Destroy(); + started = false; + failed = false; + } + catch (Exception e) + { + 
LLMUnitySetup.LogError(e.Message);
            }
        }
    }

diff --git a/Runtime/LLMLib.cs b/Runtime/LLMLib.cs
index fbaa050d..edb60601 100644
--- a/Runtime/LLMLib.cs
+++ b/Runtime/LLMLib.cs
@@ -281,7 +281,6 @@ static LLMLib()

     public LLMLib(string arch)
     {
-        LLMUnitySetup.Log(GetArchitecturePath(arch));
         libraryHandle = LibraryLoader.LoadLibrary(GetArchitecturePath(arch));
         if (libraryHandle == IntPtr.Zero)
         {

From f99ee2f7da35edec2c47ebe80379e11276f7ce9d Mon Sep 17 00:00:00 2001
From: amakropoulos
Date: Tue, 20 Aug 2024 14:57:03 +0000
Subject: [PATCH 20/67] update changelogs

---
 CHANGELOG.md | 1 +
 CHANGELOG.release.md | 1 +
 2 files changed, 2 insertions(+)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index f3394aea..d6776062 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -7,6 +7,7 @@
 #### 🐛 Fixes

 - Fix set template for remote setup (PR: #208)
+- Fix crash when stopping scene before LLM creation (PR: #214)

 ## v2.1.1

diff --git a/CHANGELOG.release.md b/CHANGELOG.release.md
index 7e80ff57..e6701df4 100644
--- a/CHANGELOG.release.md
+++ b/CHANGELOG.release.md
@@ -6,4 +6,5 @@
 ### 🐛 Fixes

 - Fix set template for remote setup (PR: #208)
+- Fix crash when stopping scene before LLM creation (PR: #214)

From 225957d2cee15e8d435611db34e9260ac0842ba0 Mon Sep 17 00:00:00 2001
From: ltoniazzi
Date: Tue, 20 Aug 2024 22:43:31 +0100
Subject: [PATCH 21/67] Point to gguf format for lora

---
 README.md | 6 +++---
 Runtime/LLM.cs | 28 ++++++++++++++--------------
 2 files changed, 17 insertions(+), 17 deletions(-)

diff --git a/README.md b/README.md
index f7a828e9..d8e79bba 100644
--- a/README.md
+++ b/README.md
@@ -248,7 +248,7 @@ public class MyScript : MonoBehaviour
     // Otherwise the model file can be copied directly inside the StreamingAssets folder.
     llm.SetModel("Phi-3-mini-4k-instruct-q4.gguf");
     // optional: you can also set a lora in a similar fashion
-    llm.SetLora("my-lora.bin");
+    llm.SetLora("my-lora.gguf");
     // optional: you can set the chat template of the model if it is not correctly identified
     // You can find a list of chat templates in the ChatTemplate.templates.Keys
     llm.SetTemplate("phi-3");
@@ -374,8 +374,8 @@ If the user's GPU is not supported, the LLM will fall back to the CPU
-<br>
Advanced options - - `Download lora` click to download a LoRA model in .bin format - - `Load lora` click to load a LoRA model in .bin format + - `Download lora` click to download a LoRA model in .gguf format + - `Load lora` click to load a LoRA model in .gguf format -
`Context Size` size of the prompt context (0 = context size of the model) This is the number of tokens the model can take as input when generating responses. Higher values use more RAM or VRAM (if using GPU).<br>
- `Batch Size` batch size for prompt processing (default: 512) - `Model` the path of the model being used (relative to the Assets/StreamingAssets folder) diff --git a/Runtime/LLM.cs b/Runtime/LLM.cs index 7fb460c4..54395e8c 100644 --- a/Runtime/LLM.cs +++ b/Runtime/LLM.cs @@ -22,7 +22,7 @@ public LLMException(string message, int errorCode) : base(message) } } - public class DestroyException : Exception {} + public class DestroyException : Exception { } /// \endcond [DefaultExecutionOrder(-1)] @@ -72,7 +72,7 @@ public class LLM : MonoBehaviour /// Chat template used for the model [ModelAdvanced] public string chatTemplate = ChatTemplate.DefaultTemplate; /// the paths of the LORA models being used (relative to the Assets/StreamingAssets folder). - /// Models with .bin format are allowed. + /// Models with .gguf format are allowed.
[ModelAdvanced] public string lora = ""; /// \cond HIDE @@ -192,9 +192,9 @@ public void SetModel(string path) /// /// Allows to set a LORA model to use in the LLM. /// The model provided is copied to the Assets/StreamingAssets folder that allows it to also work in the build. - /// Models supported are in .bin format. + /// Models supported are in .gguf format. /// - /// path to LORA model to use (.bin format) + /// path to LORA model to use (.gguf format) public void SetLora(string path) { lora = ""; @@ -204,9 +204,9 @@ public void SetLora(string path) /// /// Allows to add a LORA model to use in the LLM. /// The model provided is copied to the Assets/StreamingAssets folder that allows it to also work in the build. - /// Models supported are in .bin format. + /// Models supported are in .gguf format. /// - /// path to LORA model to use (.bin format) + /// path to LORA model to use (.gguf format) public void AddLora(string path) { string loraPath = GetModelLoraPath(path, true); @@ -220,9 +220,9 @@ public void AddLora(string path) /// /// Allows to remove a LORA model from the LLM. - /// Models supported are in .bin format. + /// Models supported are in .gguf format. /// - /// path to LORA model to remove (.bin format) + /// path to LORA model to remove (.gguf format) public void RemoveLora(string path) { string loraPath = GetModelLoraPath(path, true); @@ -373,7 +373,7 @@ void CallIfNotDestroyed(EmptyCallback fn) private void InitService(string arguments) { if (debug) CallIfNotDestroyed(() => SetupLogging()); - CallIfNotDestroyed(() => {LLMObject = llmlib.LLM_Construct(arguments);}); + CallIfNotDestroyed(() => { LLMObject = llmlib.LLM_Construct(arguments); }); if (remote) CallIfNotDestroyed(() => llmlib.LLM_StartServer(LLMObject)); CallIfNotDestroyed(() => llmlib.LLM_SetTemplate(LLMObject, chatTemplate)); CallIfNotDestroyed(() => CheckLLMStatus(false)); @@ -383,7 +383,7 @@ private void StartService() { llmThread = new Thread(() => llmlib.LLM_Start(LLMObject)); llmThread.Start(); - while (!llmlib.LLM_Started(LLMObject)) {} + while (!llmlib.LLM_Started(LLMObject)) { } loraWeights = new List(); for (int i = 0; i < lora.Split(" ").Count(); i++) loraWeights.Add(1f); started = true; @@ -446,7 +446,7 @@ void AssertStarted() void CheckLLMStatus(bool log = true) { - if (llmlib == null) {return;} + if (llmlib == null) { return; } IntPtr stringWrapper = llmlib.StringWrapper_Construct(); int status = llmlib.LLM_Status(LLMObject, stringWrapper); string result = llmlib.GetStringWrapperResult(stringWrapper); @@ -553,7 +553,7 @@ public async Task SetLoraScale(string loraToScale, float scale) loraWeightRequest.loraWeights = new List(); for (int i = 0; i < loraWeights.Count; i++) { - loraWeightRequest.loraWeights.Add(new LoraWeightRequest() {id = i, scale = loraWeights[i]}); + loraWeightRequest.loraWeights.Add(new LoraWeightRequest() { id = i, scale = loraWeights[i] }); } ; @@ -607,7 +607,7 @@ public async Task Slot(string json) public async Task Completion(string json, Callback streamCallback = null) { AssertStarted(); - if (streamCallback == null) streamCallback = (string s) => {}; + if (streamCallback == null) streamCallback = (string s) => { }; StreamWrapper streamWrapper = ConstructStreamWrapper(streamCallback); await Task.Run(() => llmlib.LLM_Completion(LLMObject, json, streamWrapper.GetStringWrapper())); if (!started) return null; @@ -621,7 +621,7 @@ public async Task Completion(string json, Callback streamCallbac public async Task SetBasePrompt(string base_prompt) { AssertStarted(); - 
SystemPromptRequest request = new SystemPromptRequest(){system_prompt = base_prompt, prompt = " ", n_predict = 0}; + SystemPromptRequest request = new SystemPromptRequest() { system_prompt = base_prompt, prompt = " ", n_predict = 0 }; await Completion(JsonUtility.ToJson(request)); } From cf2f1b7461eb437695b9e836da158b5cf69cac60 Mon Sep 17 00:00:00 2001 From: Antonis Makropoulos Date: Wed, 21 Aug 2024 09:11:59 +0300 Subject: [PATCH 22/67] move around code --- Runtime/LLMUnitySetup.cs | 48 +++++++++++++++++++++------------------- 1 file changed, 25 insertions(+), 23 deletions(-) diff --git a/Runtime/LLMUnitySetup.cs b/Runtime/LLMUnitySetup.cs index 285eee78..3aefe646 100644 --- a/Runtime/LLMUnitySetup.cs +++ b/Runtime/LLMUnitySetup.cs @@ -109,25 +109,13 @@ public class LLMUnitySetup ("Qwen 2 0.5B (tiny, useful for mobile)", "https://huggingface.co/Qwen/Qwen2-0.5B-Instruct-GGUF/resolve/main/qwen2-0_5b-instruct-q4_k_m.gguf?download=true", null), }; - /// Add callback function to call for error logs - public static void AddErrorCallBack(Callback callback) - { - errorCallbacks.Add(callback); - } - - /// Remove callback function added for error logs - public static void RemoveErrorCallBack(Callback callback) - { - errorCallbacks.Remove(callback); - } - - /// Remove all callback function added for error logs - public static void ClearErrorCallBacks() - { - errorCallbacks.Clear(); - } - /// \cond HIDE + [LLMUnity] public static DebugModeType DebugMode = DebugModeType.All; + static List> errorCallbacks = new List>(); + static readonly object lockObject = new object(); + static Dictionary androidExtractTasks = new Dictionary(); + static string DebugModeKey = "DebugMode"; + public enum DebugModeType { All, @@ -135,10 +123,6 @@ public enum DebugModeType Error, None } - [LLMUnity] public static DebugModeType DebugMode = DebugModeType.All; - static List> errorCallbacks = new List>(); - static readonly object lockObject = new object(); - static Dictionary androidExtractTasks = new Dictionary(); public static void Log(string message) { @@ -159,7 +143,6 @@ public static void LogError(string message) foreach (Callback errorCallback in errorCallbacks) errorCallback(message); } - static string DebugModeKey = "DebugMode"; static void LoadDebugMode() { DebugMode = (DebugModeType)PlayerPrefs.GetInt(DebugModeKey, (int)DebugModeType.All); @@ -364,6 +347,25 @@ public static string AddAsset(string assetPath) #endif /// \endcond + + /// Add callback function to call for error logs + public static void AddErrorCallBack(Callback callback) + { + errorCallbacks.Add(callback); + } + + /// Remove callback function added for error logs + public static void RemoveErrorCallBack(Callback callback) + { + errorCallbacks.Remove(callback); + } + + /// Remove all callback function added for error logs + public static void ClearErrorCallBacks() + { + errorCallbacks.Clear(); + } + public static int GetMaxFreqKHz(int cpuId) { string[] paths = new string[] From d19e7e017e735898a630f2ca0602356a0cc1ad9b Mon Sep 17 00:00:00 2001 From: Antonis Makropoulos Date: Wed, 21 Aug 2024 14:55:15 +0300 Subject: [PATCH 23/67] improve library setup and include build for extras --- Runtime/LLMUnitySetup.cs | 111 +++++++++++++++++++++++++++++++-------- 1 file changed, 89 insertions(+), 22 deletions(-) diff --git a/Runtime/LLMUnitySetup.cs b/Runtime/LLMUnitySetup.cs index 3aefe646..9b6a22af 100644 --- a/Runtime/LLMUnitySetup.cs +++ b/Runtime/LLMUnitySetup.cs @@ -47,6 +47,7 @@ public class ModelAttribute : PropertyAttribute {} public class 
ModelDownloadAttribute : ModelAttribute {} public class ModelDownloadAdvancedAttribute : ModelAdvancedAttribute {} public class ModelAdvancedAttribute : PropertyAttribute {} + public class ModelExtrasAttribute : PropertyAttribute {} public class ChatAttribute : PropertyAttribute {} public class ChatAdvancedAttribute : PropertyAttribute {} public class LLMUnityAttribute : PropertyAttribute {} @@ -87,8 +88,12 @@ public class LLMUnitySetup public static string Version = "v2.1.2"; /// LlamaLib version public static string LlamaLibVersion = "v1.1.8"; + /// LlamaLib release url + public static string LlamaLibReleaseURL = $"https://github.com/undreamai/LlamaLib/releases/download/{LlamaLibVersion}"; /// LlamaLib url - public static string LlamaLibURL = $"https://github.com/undreamai/LlamaLib/releases/download/{LlamaLibVersion}/undreamai-{LlamaLibVersion}-llamacpp.zip"; + public static string LlamaLibURL = $"{LlamaLibReleaseURL}/undreamai-{LlamaLibVersion}-llamacpp.zip"; + /// LlamaLib extension url + public static string LlamaLibExtensionURL = $"{LlamaLibReleaseURL}/undreamai-{LlamaLibVersion}-llamacpp-full.zip"; /// LlamaLib path public static string libraryPath = GetAssetPath(Path.GetFileName(LlamaLibURL).Replace(".zip", "")); /// LLMUnity store path @@ -111,10 +116,12 @@ public class LLMUnitySetup /// \cond HIDE [LLMUnity] public static DebugModeType DebugMode = DebugModeType.All; + static string DebugModeKey = "DebugMode"; + public static bool FullLlamaLib = false; + static string FullLlamaLibKey = "FullLlamaLib"; static List<Callback<string>> errorCallbacks = new List<Callback<string>>(); static readonly object lockObject = new object(); static Dictionary<string, Task> androidExtractTasks = new Dictionary<string, Task>(); - static string DebugModeKey = "DebugMode"; public enum DebugModeType { @@ -143,9 +150,10 @@ public static void LogError(string message) foreach (Callback<string> errorCallback in errorCallbacks) errorCallback(message); } - static void LoadDebugMode() + static void LoadPlayerPrefs() { DebugMode = (DebugModeType)PlayerPrefs.GetInt(DebugModeKey, (int)DebugModeType.All); + FullLlamaLib = PlayerPrefs.GetInt(FullLlamaLibKey, 0) == 1; } public static void SetDebugMode(DebugModeType newDebugMode) @@ -156,6 +164,18 @@ public static void SetDebugMode(DebugModeType newDebugMode) PlayerPrefs.Save(); } +#if UNITY_EDITOR + public static void SetFullLlamaLib(bool value) + { + if (FullLlamaLib == value) return; + FullLlamaLib = value; + PlayerPrefs.SetInt(FullLlamaLibKey, value ? 
1 : 0); + PlayerPrefs.Save(); + _ = DownloadLibrary(); + } + +#endif + public static string GetAssetPath(string relPath = "") { // Path to store llm server binaries and models @@ -168,14 +188,14 @@ public static string GetAssetPath(string relPath = "") static async Task InitializeOnLoad() { await DownloadLibrary(); - LoadDebugMode(); + LoadPlayerPrefs(); } #else [RuntimeInitializeOnLoadMethod(RuntimeInitializeLoadType.BeforeSceneLoad)] void InitializeOnLoad() { - LoadDebugMode(); + LoadPlayerPrefs(); } #endif @@ -290,35 +310,82 @@ public static bool IsSubPath(string childPath, string parentPath) [HideInInspector] public static float libraryProgress = 1; - private static async Task DownloadLibrary() + static void CreateEmptyFile(string path) + { + File.Create(path).Dispose(); + } + + static void ExtractInsideDirectory(string zipPath, string extractPath, bool overwrite = true) { - if (libraryProgress < 1) return; - libraryProgress = 0; - string libZip = Path.Combine(Application.temporaryCachePath, Path.GetFileName(LlamaLibURL)); - if (!Directory.Exists(libraryPath)) + using (ZipArchive archive = ZipFile.OpenRead(zipPath)) { - await DownloadFile(LlamaLibURL, libZip, true, null, SetLibraryProgress); + foreach (ZipArchiveEntry entry in archive.Entries) + { + if (string.IsNullOrEmpty(entry.Name)) continue; + string destinationPath = Path.Combine(extractPath, entry.FullName); + Directory.CreateDirectory(Path.GetDirectoryName(destinationPath)); + entry.ExtractToFile(destinationPath, overwrite); + } + } + } + + static async Task DownloadAndExtractInsideDirectory(string url, string path, string setupDir) + { + string urlName = Path.GetFileName(url); + string setupFile = Path.Combine(setupDir, urlName + ".complete"); + if (File.Exists(setupFile)) return; + + string zipPath = Path.Combine(Application.temporaryCachePath, urlName); + await DownloadFile(url, zipPath, true, null, SetLibraryProgress); + + AssetDatabase.StartAssetEditing(); + ExtractInsideDirectory(zipPath, path); + CreateEmptyFile(setupFile); + AssetDatabase.StopAssetEditing(); + + File.Delete(zipPath); + } + + static async Task DownloadLibrary() + { + void DeleteFileAndMeta(string path) + { + if (File.Exists(path + ".meta")) File.Delete(path + ".meta"); + if (File.Exists(path)) File.Delete(path); + } + + try + { + string setupDir = Path.Combine(libraryPath, "setup"); + Directory.CreateDirectory(setupDir); + + string lockFile = Path.Combine(setupDir, "LLMUnitySetup.lock"); + if (File.Exists(lockFile)) return; + CreateEmptyFile(lockFile); + + libraryProgress = 0; + await DownloadAndExtractInsideDirectory(LlamaLibURL, libraryPath, setupDir); + AssetDatabase.StartAssetEditing(); - ZipFile.ExtractToDirectory(libZip, libraryPath); string androidDir = Path.Combine(libraryPath, "android"); if (Directory.Exists(androidDir)) { string androidPluginDir = Path.Combine(Application.dataPath, "Plugins", "Android"); Directory.CreateDirectory(androidPluginDir); Directory.Move(androidDir, Path.Combine(androidPluginDir, Path.GetFileName(libraryPath))); - } - foreach (string librarySubPath in Directory.GetDirectories(libraryPath)) - { - if (Path.GetFileName(librarySubPath).StartsWith("android")) - { - string pluginPath = Path.Combine(Application.dataPath, "Plugins", "Android", Path.GetFileName(librarySubPath)); - Directory.Move(librarySubPath, pluginPath); - } + if (File.Exists(androidDir + ".meta")) File.Delete(androidDir + ".meta"); } AssetDatabase.StopAssetEditing(); - File.Delete(libZip); + + if (FullLlamaLib) await 
DownloadAndExtractInsideDirectory(LlamaLibExtensionURL, libraryPath, setupDir); + + libraryProgress = 1; + DeleteFileAndMeta(lockFile); + } + catch (Exception e) + { + LogError(e.Message); } - libraryProgress = 1; } private static void SetLibraryProgress(float progress) From 41cea64be5b3610fff60082169a63601069150eb Mon Sep 17 00:00:00 2001 From: Antonis Makropoulos Date: Wed, 21 Aug 2024 14:56:23 +0300 Subject: [PATCH 24/67] use full library for cuda if set --- Runtime/LLMLib.cs | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/Runtime/LLMLib.cs b/Runtime/LLMLib.cs index edb60601..e3501769 100644 --- a/Runtime/LLMLib.cs +++ b/Runtime/LLMLib.cs @@ -326,8 +326,16 @@ public static List PossibleArchitectures(bool gpu = false) { if (gpu) { - architectures.Add("cuda-cu12.2.0"); - architectures.Add("cuda-cu11.7.1"); + if (LLMUnitySetup.FullLlamaLib) + { + architectures.Add("cuda-cu12.2.0-full"); + architectures.Add("cuda-cu11.7.1-full"); + } + else + { + architectures.Add("cuda-cu12.2.0"); + architectures.Add("cuda-cu11.7.1"); + } architectures.Add("hip"); architectures.Add("vulkan"); } From 1608ff17127f019ebdc935f7c1d37b5a1eac6f55 Mon Sep 17 00:00:00 2001 From: Antonis Makropoulos Date: Wed, 21 Aug 2024 14:56:50 +0300 Subject: [PATCH 25/67] include full build if set, remove setup dir --- Runtime/LLMBuilder.cs | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/Runtime/LLMBuilder.cs b/Runtime/LLMBuilder.cs index a3cf0070..542520f9 100644 --- a/Runtime/LLMBuilder.cs +++ b/Runtime/LLMBuilder.cs @@ -104,15 +104,19 @@ static void AddActionAddMeta(string target) public static void HideLibraryPlatforms(string platform) { - List platforms = new List(){ "windows", "macos", "linux", "android", "ios" }; + List platforms = new List(){ "windows", "macos", "linux", "android", "ios", "setup" }; platforms.Remove(platform); foreach (string source in Directory.GetDirectories(LLMUnitySetup.libraryPath)) { + string sourceName = Path.GetFileName(source); foreach (string platformPrefix in platforms) { - if (Path.GetFileName(source).StartsWith(platformPrefix)) + bool move = sourceName.StartsWith(platformPrefix); + move = move || (sourceName.Contains("cuda") && !sourceName.Contains("full") && LLMUnitySetup.FullLlamaLib); + move = move || (sourceName.Contains("cuda") && sourceName.Contains("full") && !LLMUnitySetup.FullLlamaLib); + if (move) { - string target = Path.Combine(BuildTempDir, Path.GetFileName(source)); + string target = Path.Combine(BuildTempDir, sourceName); MoveAction(source, target); MoveAction(source + ".meta", target + ".meta"); } From 247c8eb87281368c93f270b27c59704e3de567ca Mon Sep 17 00:00:00 2001 From: Antonis Makropoulos Date: Wed, 21 Aug 2024 14:57:16 +0300 Subject: [PATCH 26/67] add flash attention argument --- Runtime/LLM.cs | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/Runtime/LLM.cs b/Runtime/LLM.cs index 54395e8c..9e50c513 100644 --- a/Runtime/LLM.cs +++ b/Runtime/LLM.cs @@ -22,7 +22,7 @@ public LLMException(string message, int errorCode) : base(message) } } - public class DestroyException : Exception { } + public class DestroyException : Exception {} /// \endcond [DefaultExecutionOrder(-1)] @@ -74,6 +74,8 @@ public class LLM : MonoBehaviour /// the paths of the LORA models being used (relative to the Assets/StreamingAssets folder). /// Models with .gguf format are allowed. 
[ModelAdvanced] public string lora = ""; + /// enable use of flash attention + [ModelExtras] public bool flashAttention = false; /// \cond HIDE @@ -297,6 +299,7 @@ protected virtual string GetLlamaccpArguments() if (numThreadsToUse > 0) arguments += $" -t {numThreadsToUse}"; arguments += loraArgument; arguments += $" -ngl {numGPULayers}"; + if (LLMUnitySetup.FullLlamaLib && flashAttention) arguments += $" --flash-attn"; return arguments; } @@ -383,7 +386,7 @@ private void StartService() { llmThread = new Thread(() => llmlib.LLM_Start(LLMObject)); llmThread.Start(); - while (!llmlib.LLM_Started(LLMObject)) { } + while (!llmlib.LLM_Started(LLMObject)) {} loraWeights = new List(); for (int i = 0; i < lora.Split(" ").Count(); i++) loraWeights.Add(1f); started = true; @@ -607,7 +610,7 @@ public async Task Slot(string json) public async Task Completion(string json, Callback streamCallback = null) { AssertStarted(); - if (streamCallback == null) streamCallback = (string s) => { }; + if (streamCallback == null) streamCallback = (string s) => {}; StreamWrapper streamWrapper = ConstructStreamWrapper(streamCallback); await Task.Run(() => llmlib.LLM_Completion(LLMObject, json, streamWrapper.GetStringWrapper())); if (!started) return null; From 137ba0502ba55a22cef111a175c171a257ee1588 Mon Sep 17 00:00:00 2001 From: Antonis Makropoulos Date: Wed, 21 Aug 2024 14:57:51 +0300 Subject: [PATCH 27/67] add setup extras button, show extras arguments if set --- Editor/LLMEditor.cs | 1 + Editor/PropertyEditor.cs | 18 ++++++++---------- 2 files changed, 9 insertions(+), 10 deletions(-) diff --git a/Editor/LLMEditor.cs b/Editor/LLMEditor.cs index ed50046d..9601579f 100644 --- a/Editor/LLMEditor.cs +++ b/Editor/LLMEditor.cs @@ -74,6 +74,7 @@ public void AddModelSettings(SerializedObject llmScriptSO) if (llmScriptSO.FindProperty("advancedOptions").boolValue) { attributeClasses.Add(typeof(ModelAdvancedAttribute)); + if (LLMUnitySetup.FullLlamaLib) attributeClasses.Add(typeof(ModelExtrasAttribute)); } ShowPropertiesOfClass("", llmScriptSO, attributeClasses, false); Space(); diff --git a/Editor/PropertyEditor.cs b/Editor/PropertyEditor.cs index 87a40938..d3510ac8 100644 --- a/Editor/PropertyEditor.cs +++ b/Editor/PropertyEditor.cs @@ -16,17 +16,11 @@ public void AddScript(SerializedObject llmScriptSO) EditorGUILayout.PropertyField(scriptProp); } - public void AddOptionsToggle(SerializedObject llmScriptSO, string propertyName, string name) + public bool ToggleButton(string text, bool activated) { - SerializedProperty advancedOptionsProp = llmScriptSO.FindProperty(propertyName); - string toggleText = (advancedOptionsProp.boolValue ? 
"Hide" : "Show") + " " + name; GUIStyle style = new GUIStyle("Button"); - if (advancedOptionsProp.boolValue) - style.normal = new GUIStyleState() { background = Texture2D.grayTexture }; - if (GUILayout.Button(toggleText, style, GUILayout.Width(buttonWidth))) - { - advancedOptionsProp.boolValue = !advancedOptionsProp.boolValue; - } + if (activated) style.normal = new GUIStyleState() { background = Texture2D.grayTexture }; + return GUILayout.Button(text, style, GUILayout.Width(buttonWidth)); } public void AddSetupSettings(SerializedObject llmScriptSO) @@ -54,8 +48,12 @@ public void AddChatSettings(SerializedObject llmScriptSO) public void AddOptionsToggles(SerializedObject llmScriptSO) { LLMUnitySetup.SetDebugMode((LLMUnitySetup.DebugModeType)EditorGUILayout.EnumPopup("Log Level", LLMUnitySetup.DebugMode)); + EditorGUILayout.BeginHorizontal(); - AddOptionsToggle(llmScriptSO, "advancedOptions", "Advanced Options"); + SerializedProperty advancedOptionsProp = llmScriptSO.FindProperty("advancedOptions"); + string toggleText = (advancedOptionsProp.boolValue ? "Hide" : "Show") + " Advanced Options"; + if (ToggleButton(toggleText, advancedOptionsProp.boolValue)) advancedOptionsProp.boolValue = !advancedOptionsProp.boolValue; + if (ToggleButton("Use extras", LLMUnitySetup.FullLlamaLib)) LLMUnitySetup.SetFullLlamaLib(!LLMUnitySetup.FullLlamaLib); EditorGUILayout.EndHorizontal(); Space(); } From e08f8431c6df2a6784ba618e9573a0913f8581e3 Mon Sep 17 00:00:00 2001 From: Antonis Makropoulos Date: Wed, 21 Aug 2024 15:20:14 +0300 Subject: [PATCH 28/67] download library after loading prefs, set max download progress of 0.99 to avoid enabling editor, handle overwrite plugin --- Runtime/LLMUnitySetup.cs | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/Runtime/LLMUnitySetup.cs b/Runtime/LLMUnitySetup.cs index 9b6a22af..428c2da5 100644 --- a/Runtime/LLMUnitySetup.cs +++ b/Runtime/LLMUnitySetup.cs @@ -187,8 +187,8 @@ public static string GetAssetPath(string relPath = "") [InitializeOnLoadMethod] static async Task InitializeOnLoad() { - await DownloadLibrary(); LoadPlayerPrefs(); + await DownloadLibrary(); } #else @@ -370,9 +370,11 @@ void DeleteFileAndMeta(string path) string androidDir = Path.Combine(libraryPath, "android"); if (Directory.Exists(androidDir)) { - string androidPluginDir = Path.Combine(Application.dataPath, "Plugins", "Android"); - Directory.CreateDirectory(androidPluginDir); - Directory.Move(androidDir, Path.Combine(androidPluginDir, Path.GetFileName(libraryPath))); + string androidPluginsDir = Path.Combine(Application.dataPath, "Plugins", "Android"); + Directory.CreateDirectory(androidPluginsDir); + string pluginDir = Path.Combine(androidPluginsDir, Path.GetFileName(libraryPath)); + if (Directory.Exists(pluginDir)) Directory.Delete(pluginDir, true); + Directory.Move(androidDir, pluginDir); if (File.Exists(androidDir + ".meta")) File.Delete(androidDir + ".meta"); } AssetDatabase.StopAssetEditing(); @@ -390,7 +392,7 @@ void DeleteFileAndMeta(string path) private static void SetLibraryProgress(float progress) { - libraryProgress = progress; + libraryProgress = Math.Min(0.99f, progress); } public static string AddAsset(string assetPath) From d170f876ccc9d3bda32efae330832b73e719216b Mon Sep 17 00:00:00 2001 From: Antonis Makropoulos Date: Wed, 21 Aug 2024 15:37:32 +0300 Subject: [PATCH 29/67] add extra and flash attention options to readme --- README.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/README.md b/README.md index d8e79bba..bdf6350b 100644 --- 
a/README.md +++ b/README.md @@ -345,6 +345,7 @@ If you have loaded a model locally you need to set its URL through the expanded - `Show/Hide Advanced Options` Toggle to show/hide advanced options from below - `Log Level` select how verbose the log messages are +- `Use extras` select to install and allow the use of extra features (flash attention and IQ quants) #### 💻 Setup Settings @@ -381,6 +382,7 @@ If the user's GPU is not supported, the LLM will fall back to the CPU - `Model` the path of the model being used (relative to the Assets/StreamingAssets folder) - `Chat Template` the chat template being used for the LLM - `Lora` the path of the LoRA being used (relative to the Assets/StreamingAssets folder) + - `Flash Attention` click to use flash attention in the model (if `Use extras` is enabled) @@ -395,6 +397,7 @@ If the user's GPU is not supported, the LLM will fall back to the CPU - `Show/Hide Advanced Options` Toggle to show/hide advanced options from below - `Log Level` select how verbose the log messages are +- `Use extras` select to install and allow the use of extra features (flash attention and IQ quants) #### 💻 Setup Settings
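For reference, the extras workflow that patches 23-29 introduce can be driven from a script roughly as sketched below. This is illustrative only: the `LLM.flashAttention` field and the `LLMUnitySetup.FullLlamaLib` flag come from the diffs above, while the component name, the placeholder model filename, and the inactive-GameObject pattern (borrowed from the package tests later in this series) are assumptions, not part of these patches.

```csharp
using LLMUnity;
using UnityEngine;

// Hypothetical bootstrap component, not part of the package.
public class ExtrasBootstrap : MonoBehaviour
{
    void Start()
    {
        // Create the LLM on an inactive GameObject so that its Awake (which
        // builds the llama.cpp arguments) only runs after configuration.
        GameObject go = new GameObject("LLM");
        go.SetActive(false);
        LLM llm = go.AddComponent<LLM>();
        llm.SetModel("model.gguf"); // placeholder filename for illustration
        // --flash-attn is only appended to the server arguments when the
        // extras ("full") LlamaLib build is installed, i.e. when the
        // "Use extras" toggle (LLMUnitySetup.FullLlamaLib) is enabled.
        llm.flashAttention = true;
        go.SetActive(true);
    }
}
```

Note that installing the extras build itself is an editor-side step: the `Use extras` toggle (or `LLMUnitySetup.SetFullLlamaLib(true)` in editor code) downloads the full LlamaLib archive in addition to the base one.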
From 54398eef558f2f954b633314d9edfa4c8da15761 Mon Sep 17 00:00:00 2001 From: amakropoulos Date: Wed, 21 Aug 2024 12:39:10 +0000 Subject: [PATCH 30/67] update changelogs --- CHANGELOG.md | 5 +++++ CHANGELOG.release.md | 5 +++++ 2 files changed, 10 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index d6776062..d2a54e5a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -3,12 +3,17 @@ - Implement embedding and lora adapter functionality (PR: #210) - Read context length and warn if it is very large (PR: #211) +- Setup allowing to use extra features: flash attention and IQ quants (PR: #216) #### 🐛 Fixes - Fix set template for remote setup (PR: #208) - fix crash when stopping scene before LLM creation (PR: #214) +#### 📦 General + +- Documentation/point to gguf format for lora (PR: #215) + ## v2.1.1 #### 🐛 Fixes diff --git a/CHANGELOG.release.md b/CHANGELOG.release.md index e6701df4..10b5dc2e 100644 --- a/CHANGELOG.release.md +++ b/CHANGELOG.release.md @@ -2,9 +2,14 @@ - Implement embedding and lora adapter functionality (PR: #210) - Read context length and warn if it is very large (PR: #211) +- Setup allowing to use extra features: flash attention and IQ quants (PR: #216) ### 🐛 Fixes - Fix set template for remote setup (PR: #208) - fix crash when stopping scene before LLM creation (PR: #214) +### 📦 General + +- Documentation/point to gguf format for lora (PR: #215) + From 5f8deb7f77533d95efc4a74edf9f0649d26f85a8 Mon Sep 17 00:00:00 2001 From: Antonis Makropoulos Date: Wed, 21 Aug 2024 17:49:09 +0300 Subject: [PATCH 31/67] use CallWithLock function to lock lib calls --- Runtime/LLM.cs | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/Runtime/LLM.cs b/Runtime/LLM.cs index 9e50c513..140bc75f 100644 --- a/Runtime/LLM.cs +++ b/Runtime/LLM.cs @@ -354,7 +354,7 @@ private void StartLLMServer(string arguments) failed = true; return; } - CallIfNotDestroyed(() => StartService()); + CallWithLock(StartService); LLMUnitySetup.Log("LLM service created"); } @@ -364,22 +364,22 @@ private void InitLib(string arch) CheckLLMStatus(false); } - void CallIfNotDestroyed(EmptyCallback fn) + void CallWithLock(EmptyCallback fn, bool checkNull = true) { lock (startLock) { - if (llmlib == null) throw new DestroyException(); + if (checkNull && llmlib == null) throw new DestroyException(); fn(); } } private void InitService(string arguments) { - if (debug) CallIfNotDestroyed(() => SetupLogging()); - CallIfNotDestroyed(() => { LLMObject = llmlib.LLM_Construct(arguments); }); - if (remote) CallIfNotDestroyed(() => llmlib.LLM_StartServer(LLMObject)); - CallIfNotDestroyed(() => llmlib.LLM_SetTemplate(LLMObject, chatTemplate)); - CallIfNotDestroyed(() => CheckLLMStatus(false)); + if (debug) CallWithLock(SetupLogging); + CallWithLock(() => { LLMObject = llmlib.LLM_Construct(arguments); }); + if (remote) CallWithLock(() => llmlib.LLM_StartServer(LLMObject)); + CallWithLock(() => llmlib.LLM_SetTemplate(LLMObject, chatTemplate)); + CallWithLock(() => CheckLLMStatus(false)); } private void StartService() @@ -644,7 +644,7 @@ public void CancelRequest(int id_slot) ///
public void Destroy() { - lock (startLock) + CallWithLock(() => { try { @@ -669,7 +669,7 @@ public void Destroy() { LLMUnitySetup.LogError(e.Message); } - } + }, false); } /// From 45fc661a37d8922ba4ab41d2ffb710773ce6316c Mon Sep 17 00:00:00 2001 From: Antonis Makropoulos Date: Thu, 22 Aug 2024 12:34:52 +0300 Subject: [PATCH 32/67] remote request with retries --- Runtime/LLMCharacter.cs | 71 ++++++++++++++++++++++++++--------------- 1 file changed, 46 insertions(+), 25 deletions(-) diff --git a/Runtime/LLMCharacter.cs b/Runtime/LLMCharacter.cs index 3b002564..5732b1cb 100644 --- a/Runtime/LLMCharacter.cs +++ b/Runtime/LLMCharacter.cs @@ -27,6 +27,8 @@ public class LLMCharacter : MonoBehaviour [Remote] public string host = "localhost"; /// port to use for the LLM server [Remote] public int port = 13333; + /// number of retries to use for the LLM server requests (-1 = infinite) + [Remote] public int numRetries = -1; /// file to save the chat history. /// The file is saved only for Chat calls with addToHistory set to true. /// The file will be saved within the persistentDataPath directory (see https://docs.unity3d.com/ScriptReference/Application-persistentDataPath.html). @@ -750,38 +752,57 @@ protected async Task PostRequestRemote(string json, string endpoi Ret result = default; byte[] jsonToSend = new System.Text.UTF8Encoding().GetBytes(json); - using (var request = UnityWebRequest.Put($"{host}:{port}/{endpoint}", jsonToSend)) - { - WIPRequests.Add(request); + UnityWebRequest request = null; + string error = null; + int tryNr = numRetries; - request.method = "POST"; - if (requestHeaders != null) + while (tryNr != 0) + { + using (request = UnityWebRequest.Put($"{host}:{port}/{endpoint}", jsonToSend)) { - for (int i = 0; i < requestHeaders.Count; i++) - request.SetRequestHeader(requestHeaders[i].Item1, requestHeaders[i].Item2); - } + WIPRequests.Add(request); - // Start the request asynchronously - var asyncOperation = request.SendWebRequest(); - float lastProgress = 0f; - // Continue updating progress until the request is completed - while (!asyncOperation.isDone) - { - float currentProgress = request.downloadProgress; - // Check if progress has changed - if (currentProgress != lastProgress && callback != null) + request.method = "POST"; + if (requestHeaders != null) + { + for (int i = 0; i < requestHeaders.Count; i++) + request.SetRequestHeader(requestHeaders[i].Item1, requestHeaders[i].Item2); + } + + // Start the request asynchronously + var asyncOperation = request.SendWebRequest(); + float lastProgress = 0f; + // Continue updating progress until the request is completed + while (!asyncOperation.isDone) { - callback?.Invoke(ConvertContent(request.downloadHandler.text, getContent)); - lastProgress = currentProgress; + float currentProgress = request.downloadProgress; + // Check if progress has changed + if (currentProgress != lastProgress && callback != null) + { + callback?.Invoke(ConvertContent(request.downloadHandler.text, getContent)); + lastProgress = currentProgress; + } + // Wait for the next frame + await Task.Yield(); + } + WIPRequests.Remove(request); + if (request.result == UnityWebRequest.Result.Success) + { + result = ConvertContent(request.downloadHandler.text, getContent); + error = null; + break; + } + else + { + result = default; + error = request.error; } - // Wait for the next frame - await Task.Yield(); } - WIPRequests.Remove(request); - if (request.result != UnityWebRequest.Result.Success) LLMUnitySetup.LogError(request.error); - else result = 
ConvertContent(request.downloadHandler.text, getContent); - callback?.Invoke(result); + tryNr--; } + + if (error != null) LLMUnitySetup.LogError(error); + callback?.Invoke(result); return result; } From 92efe9cf01fa4304ce92eb916aba4197fc210123 Mon Sep 17 00:00:00 2001 From: Antonis Makropoulos Date: Thu, 22 Aug 2024 12:35:31 +0300 Subject: [PATCH 33/67] check template if null before using it --- Runtime/LLMCharacter.cs | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/Runtime/LLMCharacter.cs b/Runtime/LLMCharacter.cs index 5732b1cb..1bf68236 100644 --- a/Runtime/LLMCharacter.cs +++ b/Runtime/LLMCharacter.cs @@ -120,7 +120,7 @@ public class LLMCharacter : MonoBehaviour public List chat; private SemaphoreSlim chatLock = new SemaphoreSlim(1, 1); private string chatTemplate; - private ChatTemplate template; + private ChatTemplate template = null; public string grammarString; protected int id_slot = -1; private List<(string, string)> requestHeaders = new List<(string, string)> { ("Content-Type", "application/json") }; @@ -272,10 +272,21 @@ public void SetPrompt(string newPrompt, bool clearChat = true) InitPrompt(clearChat); } + private bool CheckTemplate() + { + if (template == null) + { + LLMUnitySetup.LogError("Template not set!"); + return false; + } + return true; + } + private async Task InitNKeep() { if (setNKeepToPrompt && nKeep == -1) { + if (!CheckTemplate()) return; string systemPrompt = template.ComputePrompt(new List(){chat[0]}, playerName, "", false); await Tokenize(systemPrompt, SetNKeep); } @@ -313,7 +324,7 @@ public async Task LoadTemplate() if (llmTemplate != chatTemplate) { chatTemplate = llmTemplate; - template = ChatTemplate.GetTemplate(chatTemplate); + template = chatTemplate == null ? null : ChatTemplate.GetTemplate(chatTemplate); } } @@ -333,6 +344,7 @@ public async void SetGrammar(string path) List GetStopwords() { + if (!CheckTemplate()) return null; List stopAll = new List(template.GetStop(playerName, AIName)); if (stop != null) stopAll.AddRange(stop); return stopAll; @@ -467,6 +479,7 @@ public async Task Chat(string query, Callback callback = null, E // call the callback function while the answer is received // call the completionCallback function when the answer is fully received await LoadTemplate(); + if (!CheckTemplate()) return null; await InitNKeep(); string json; From c478d9313071edaa76d86fa0258715f5b42e6741 Mon Sep 17 00:00:00 2001 From: Antonis Makropoulos Date: Thu, 22 Aug 2024 12:40:50 +0300 Subject: [PATCH 34/67] recompute nkeep when template changes --- Runtime/LLMCharacter.cs | 1 + 1 file changed, 1 insertion(+) diff --git a/Runtime/LLMCharacter.cs b/Runtime/LLMCharacter.cs index 1bf68236..c5ce99fd 100644 --- a/Runtime/LLMCharacter.cs +++ b/Runtime/LLMCharacter.cs @@ -325,6 +325,7 @@ public async Task LoadTemplate() { chatTemplate = llmTemplate; template = chatTemplate == null ? 
null : ChatTemplate.GetTemplate(chatTemplate); + nKeep = -1; } } From 68560b7bf9ce515489ce12503e61f37355cfc0d1 Mon Sep 17 00:00:00 2001 From: Antonis Makropoulos Date: Thu, 22 Aug 2024 12:42:38 +0300 Subject: [PATCH 35/67] add retries to readme --- README.md | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index bdf6350b..8dee1a57 100644 --- a/README.md +++ b/README.md @@ -406,8 +406,9 @@ If the user's GPU is not supported, the LLM will fall back to the CPU - `Remote` whether the LLM used is remote or local - `LLM` the LLM GameObject (if `Remote` is not set) -- `Hort` ip of the LLM (if `Remote` is set) -- `Port` port of the LLM (if `Remote` is set) +- `Host` ip of the LLM server (if `Remote` is set) +- `Port` port of the LLM server (if `Remote` is set) +- `Num Retries` number of HTTP request retries from the LLM server (if `Remote` is set) -
`Save` save filename or relative path. If set, the chat history and LLM state (if save cache is enabled) are automatically saved to the specified file.
The chat history is saved with a `.json` suffix and the LLM state with a `.cache` suffix.
Both files are saved in the [persistentDataPath folder of Unity](https://docs.unity3d.com/ScriptReference/Application-persistentDataPath.html).
- `Save Cache` select to save the LLM state along with the chat history. The LLM state is typically around 100MB+. - `Debug Prompt` select to log the constructed prompts in the Unity Editor From 9128847d95ca8535899d57a4a0d45ed9dab826b8 Mon Sep 17 00:00:00 2001 From: Antonis Makropoulos Date: Thu, 22 Aug 2024 12:46:15 +0300 Subject: [PATCH 36/67] set template before starting server --- Runtime/LLM.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Runtime/LLM.cs b/Runtime/LLM.cs index 140bc75f..77b71398 100644 --- a/Runtime/LLM.cs +++ b/Runtime/LLM.cs @@ -377,8 +377,8 @@ private void InitService(string arguments) { if (debug) CallWithLock(SetupLogging); CallWithLock(() => { LLMObject = llmlib.LLM_Construct(arguments); }); - if (remote) CallWithLock(() => llmlib.LLM_StartServer(LLMObject)); CallWithLock(() => llmlib.LLM_SetTemplate(LLMObject, chatTemplate)); + if (remote) CallWithLock(() => llmlib.LLM_StartServer(LLMObject)); CallWithLock(() => CheckLLMStatus(false)); } From 72e24dd46a0c88ad19fb8395e2d6282ae7e1bb5a Mon Sep 17 00:00:00 2001 From: amakropoulos Date: Thu, 22 Aug 2024 09:46:49 +0000 Subject: [PATCH 37/67] update changelogs --- CHANGELOG.md | 1 + CHANGELOG.release.md | 1 + 2 files changed, 2 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index d2a54e5a..3154f6c1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,7 @@ - Implement embedding and lora adapter functionality (PR: #210) - Read context length and warn if it is very large (PR: #211) - Setup allowing to use extra features: flash attention and IQ quants (PR: #216) +- Allow HTTP request retries for remote server (PR: #217) #### 🐛 Fixes diff --git a/CHANGELOG.release.md b/CHANGELOG.release.md index 10b5dc2e..439abc9a 100644 --- a/CHANGELOG.release.md +++ b/CHANGELOG.release.md @@ -3,6 +3,7 @@ - Implement embedding and lora adapter functionality (PR: #210) - Read context length and warn if it is very large (PR: #211) - Setup allowing to use extra features: flash attention and IQ quants (PR: #216) +- Allow HTTP request retries for remote server (PR: #217) ### 🐛 Fixes From 51d27b2d0c41739ee6ddb99eb09f5cef30d0fd79 Mon Sep 17 00:00:00 2001 From: Antonis Makropoulos Date: Thu, 22 Aug 2024 18:11:09 +0300 Subject: [PATCH 38/67] modify how to help --- README.md | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 8dee1a57..d179f17a 100644 --- a/README.md +++ b/README.md @@ -46,8 +46,10 @@ LLM for Unity is built on top of the awesome [llama.cpp](https://github.com/gger ## How to help - [⭐ Star](https://github.com/undreamai/LLMUnity) the repo, leave us a [review](https://assetstore.unity.com/packages/slug/273604) and spread the word about the project! -- Join us at [Discord](https://discord.gg/RwXKQb6zdv) and say hi! -- [Contribute](CONTRIBUTING.md) by submitting feature requests or bugs as issues or even submiting a PR and become a collaborator! +- Join us at [Discord](https://discord.gg/RwXKQb6zdv) and say hi. +- [Contribute](CONTRIBUTING.md) by submitting feature requests, bugs or even your own PR. +- [![](https://img.shields.io/static/v1?label=Sponsor&message=%E2%9D%A4&logo=GitHub&color=%23fe8e86)](https://github.com/sponsors/amakropoulos) this work to allow even cooler features! 
+ ## Games using LLM for Unity - [Verbal Verdict](https://store.steampowered.com/app/2778780/Verbal_Verdict/) From 34c300fd4ca08ced8397e2d5712a9104afc868bd Mon Sep 17 00:00:00 2001 From: Antonis Makropoulos Date: Fri, 23 Aug 2024 14:02:18 +0300 Subject: [PATCH 39/67] add relative path function, expose create empty file --- Runtime/LLMUnitySetup.cs | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/Runtime/LLMUnitySetup.cs b/Runtime/LLMUnitySetup.cs index 428c2da5..614f52e8 100644 --- a/Runtime/LLMUnitySetup.cs +++ b/Runtime/LLMUnitySetup.cs @@ -306,11 +306,26 @@ public static bool IsSubPath(string childPath, string parentPath) return fullChildPath.StartsWith(fullParentPath, StringComparison.OrdinalIgnoreCase); } + public static string RelativePath(string fullPath, string basePath) + { + // Get the full paths and replace backslashes with forward slashes (or vice versa) + string fullParentPath = Path.GetFullPath(basePath).Replace('\\', '/').TrimEnd('/'); + string fullChildPath = Path.GetFullPath(fullPath).Replace('\\', '/'); + + string relativePath = fullChildPath; + if (fullChildPath.StartsWith(fullParentPath, StringComparison.OrdinalIgnoreCase)) + { + relativePath = fullChildPath.Substring(fullParentPath.Length); + while (relativePath.StartsWith("/")) relativePath = relativePath.Substring(1); + } + return relativePath; + } + #if UNITY_EDITOR [HideInInspector] public static float libraryProgress = 1; - static void CreateEmptyFile(string path) + public static void CreateEmptyFile(string path) { File.Create(path).Dispose(); } From 5a086d5a6403363f1549f0596b6bf5b7df3a6d32 Mon Sep 17 00:00:00 2001 From: Antonis Makropoulos Date: Fri, 23 Aug 2024 14:24:16 +0300 Subject: [PATCH 40/67] explicitly specify editor and runtime asset management --- Runtime/LLM.cs | 68 +++++++++++++++++++++++++++++++++----------------- 1 file changed, 45 insertions(+), 23 deletions(-) diff --git a/Runtime/LLM.cs b/Runtime/LLM.cs index 77b71398..dfe56331 100644 --- a/Runtime/LLM.cs +++ b/Runtime/LLM.cs @@ -136,35 +136,57 @@ public static async Task WaitUntilModelSetup(Callback downloadProgr return !modelSetupFailed; } - public string GetModelLoraPathRuntime(string path) + public string GetLLMManagerAsset(string path) { - string assetPath = LLMManager.GetAssetPath(path); - if (!string.IsNullOrEmpty(assetPath)) return assetPath; - return path; +#if UNITY_EDITOR + if (!EditorApplication.isPlaying) return GetLLMManagerAssetEditor(path); +#endif + return GetLLMManagerAssetRuntime(path); } - public string GetModelLoraPath(string path, bool lora) + public static string GetLLMManagerAssetEditor(string path) { + // empty if (string.IsNullOrEmpty(path)) return path; + // LLMManager - return location the file will be stored in StreamingAssets ModelEntry modelEntry = LLMManager.Get(path); if (modelEntry != null) return modelEntry.filename; - - string modelType = lora ? 
"Lora" : "Model"; - string assetPath = LLMUnitySetup.GetAssetPath(path); + // StreamingAssets - return relative location within StreamingAssets + string assetPath = LLMUnitySetup.GetAssetPath(path); // Note: this will return the full path if a full path is passed + string basePath = LLMUnitySetup.GetAssetPath(); + if (File.Exists(assetPath)) + { + if (LLMUnitySetup.IsSubPath(assetPath, basePath)) return LLMUnitySetup.RelativePath(assetPath, basePath); + } + // full path if (!File.Exists(assetPath)) { - LLMUnitySetup.LogError($"The {modelType} file {path} was not found."); - return path; + LLMUnitySetup.LogError($"Model {path} was not found."); } - - if (!LLMUnitySetup.IsSubPath(assetPath, LLMUnitySetup.GetAssetPath())) + else { - string errorMessage = $"The {modelType} file {path} was loaded locally. If you want to include it in the build:"; - errorMessage += $"\n-Copy the {modelType} inside the StreamingAssets folder and use its relative path or"; - errorMessage += $"\n-Load the {modelType} with the LLMManager: `string filename=LLMManager.Load{modelType}(path); llm.Set{modelType}(filename)`"; + string errorMessage = $"The model {path} was loaded locally. You can include it in the build in one of these ways:"; + errorMessage += $"\n-Copy the model inside the StreamingAssets folder and use its StreamingAssets path"; + errorMessage += $"\n-Load the model with the model manager inside the LLM GameObject and use its filename"; LLMUnitySetup.LogWarning(errorMessage); } - return assetPath; + return path; + } + + public static string GetLLMManagerAssetRuntime(string path) + { + // empty + if (string.IsNullOrEmpty(path)) return path; + // full path + if (File.Exists(path)) return path; + // LLMManager + string managerPath = LLMManager.GetAssetPath(path); + if (!string.IsNullOrEmpty(managerPath) && File.Exists(managerPath)) return managerPath; + // StreamingAssets + string assetPath = LLMUnitySetup.GetAssetPath(path); + if (File.Exists(assetPath)) return assetPath; + // give up + return path; } /// @@ -175,11 +197,11 @@ public string GetModelLoraPath(string path, bool lora) /// path to model to use (.gguf format) public void SetModel(string path) { - model = GetModelLoraPath(path, false); + model = GetLLMManagerAsset(path); if (!string.IsNullOrEmpty(model)) { ModelEntry modelEntry = LLMManager.Get(model); - if (modelEntry == null) modelEntry = new ModelEntry(GetModelLoraPathRuntime(model)); + if (modelEntry == null) modelEntry = new ModelEntry(GetLLMManagerAssetRuntime(model)); SetTemplate(modelEntry.chatTemplate); if (contextSize == 0 && modelEntry.contextLength > 32768) { @@ -211,7 +233,7 @@ public void SetLora(string path) /// path to LORA model to use (.gguf format) public void AddLora(string path) { - string loraPath = GetModelLoraPath(path, true); + string loraPath = GetLLMManagerAsset(path); if (lora.Split(" ").Contains(loraPath)) return; if (lora != "") lora += " "; lora += loraPath; @@ -227,7 +249,7 @@ public void AddLora(string path) /// path to LORA model to remove (.gguf format) public void RemoveLora(string path) { - string loraPath = GetModelLoraPath(path, true); + string loraPath = GetLLMManagerAsset(path); List loras = new List(lora.Split(" ")); loras.Remove(loraPath); lora = ""; @@ -271,7 +293,7 @@ protected virtual string GetLlamaccpArguments() LLMUnitySetup.LogError("No model file provided!"); return null; } - string modelPath = GetModelLoraPathRuntime(model); + string modelPath = GetLLMManagerAssetRuntime(model); if (!File.Exists(modelPath)) { LLMUnitySetup.LogError($"File 
{modelPath} not found!"); @@ -281,7 +303,7 @@ protected virtual string GetLlamaccpArguments() foreach (string lora in lora.Trim().Split(" ")) { if (lora == "") continue; - string loraPath = GetModelLoraPathRuntime(lora); + string loraPath = GetLLMManagerAssetRuntime(lora); if (!File.Exists(loraPath)) { LLMUnitySetup.LogError($"File {loraPath} not found!"); @@ -541,7 +563,7 @@ public async Task SetLoraScale(string loraToScale, float scale) { AssertStarted(); List loras = new List(lora.Split(" ")); - string loraToScalePath = GetModelLoraPath(loraToScale, true); + string loraToScalePath = GetLLMManagerAssetRuntime(loraToScale); int index = loras.IndexOf(loraToScale); if (index == -1) index = loras.IndexOf(loraToScalePath); From f4112a1b522145dcea87184ae123fecb4b2fcd49 Mon Sep 17 00:00:00 2001 From: Antonis Makropoulos Date: Fri, 23 Aug 2024 14:24:57 +0300 Subject: [PATCH 41/67] tests for explicit editor and runtime asset management --- Tests/Runtime/TestLLM.cs | 138 ++++++++++++++++++++++++++------------- 1 file changed, 94 insertions(+), 44 deletions(-) diff --git a/Tests/Runtime/TestLLM.cs b/Tests/Runtime/TestLLM.cs index affca075..a55d1609 100644 --- a/Tests/Runtime/TestLLM.cs +++ b/Tests/Runtime/TestLLM.cs @@ -5,48 +5,98 @@ using System.Collections.Generic; using System; using System.Collections; -using UnityEngine.TestTools; using System.IO; -using NUnit.Framework.Internal; +using UnityEngine.TestTools; namespace LLMUnityTests { public class TestLLM { + protected static string modelUrl = "https://huggingface.co/afrideva/smol_llama-220M-openhermes-GGUF/resolve/main/smol_llama-220m-openhermes.q4_k_m.gguf?download=true"; + protected string modelNameLLManager; + protected GameObject gameObject; protected LLM llm; protected LLMCharacter llmCharacter; - protected static string modelUrl = "https://huggingface.co/afrideva/smol_llama-220M-openhermes-GGUF/resolve/main/smol_llama-220m-openhermes.q4_k_m.gguf?download=true"; - protected static string filename = Path.GetFileName(modelUrl).Split("?")[0]; Exception error = null; string prompt = "Below is an instruction that describes a task, paired with an input that provides further context. 
Write a response that appropriately completes the request."; + public TestLLM() { - LLMUnitySetup.SetDebugMode(LLMUnitySetup.DebugModeType.All); Task task = Init(); task.Wait(); } public virtual async Task Init() { + modelNameLLManager = await LLMManager.DownloadModel(modelUrl); gameObject = new GameObject(); gameObject.SetActive(false); - await SetLLM(); + SetLLM(); SetLLMCharacter(); gameObject.SetActive(true); } - public async Task EmptyTask() + [Test] + public void TestGetLLMManagerAssetRuntime() { - await Task.Delay(1); + string path = ""; + string managerPath = LLM.GetLLMManagerAssetRuntime(path); + Assert.AreEqual(managerPath, path); + + path = "/tmp/lala"; + LLMUnitySetup.CreateEmptyFile(path); + managerPath = LLM.GetLLMManagerAssetRuntime(path); + Assert.AreEqual(managerPath, path); + File.Delete(path); + + path = modelNameLLManager; + managerPath = LLM.GetLLMManagerAssetRuntime(path); + Assert.AreEqual(managerPath, LLMManager.GetAssetPath(path)); + + path = LLMUnitySetup.GetAssetPath("lala"); + LLMUnitySetup.CreateEmptyFile(path); + managerPath = LLM.GetLLMManagerAssetRuntime(path); + Assert.AreEqual(managerPath, path); + File.Delete(path); } - public virtual async Task SetLLM() + [Test] + public void TestGetLLMManagerAssetEditor() + { + string path = ""; + string managerPath = LLM.GetLLMManagerAssetEditor(path); + Assert.AreEqual(managerPath, path); + + path = modelNameLLManager; + managerPath = LLM.GetLLMManagerAssetEditor(path); + Assert.AreEqual(managerPath, modelNameLLManager); + + path = LLMManager.Get(modelNameLLManager).path; + managerPath = LLM.GetLLMManagerAssetEditor(path); + Assert.AreEqual(managerPath, modelNameLLManager); + + string filename = "lala"; + path = LLMUnitySetup.GetAssetPath(filename); + LLMUnitySetup.CreateEmptyFile(path); + managerPath = LLM.GetLLMManagerAssetEditor(filename); + Assert.AreEqual(managerPath, filename); + managerPath = LLM.GetLLMManagerAssetEditor(path); + Assert.AreEqual(managerPath, filename); + File.Delete(path); + + path = "/tmp/lala"; + LLMUnitySetup.CreateEmptyFile(path); + managerPath = LLM.GetLLMManagerAssetEditor(path); + Assert.AreEqual(managerPath, path); + File.Delete(path); + } + + public virtual void SetLLM() { llm = gameObject.AddComponent(); - string filename = await LLMManager.DownloadModel(modelUrl); - llm.SetModel(filename); + llm.SetModel(modelNameLLManager); llm.parallelPrompts = 1; llm.SetTemplate("alpaca"); } @@ -64,13 +114,28 @@ public virtual void SetLLMCharacter() llmCharacter.numPredict = 20; } - public virtual async Task RunTests() + [UnityTest] + public IEnumerator RunTests() + { + Task task = RunTestsTask(); + while (!task.IsCompleted) yield return null; + if (error != null) + { + Debug.LogError(error.ToString()); + throw (error); + } + OnDestroy(); + } + + public async Task RunTestsTask() { error = null; try { - llm.Awake(); - llmCharacter.Awake(); + // await llm.WaitUntilReady(); + + // llm.Awake(); + // llmCharacter.Awake(); await llmCharacter.Tokenize("I", TestTokens); await llmCharacter.Warmup(); TestInitParameters((await llmCharacter.Tokenize(prompt)).Count + 2, 1); @@ -97,19 +162,6 @@ public virtual async Task RunTests() } } - [UnityTest] - public IEnumerator RunTestsWait() - { - Task task = RunTests(); - while (!task.IsCompleted) yield return null; - if (error != null) - { - Debug.LogError(error.ToString()); - throw (error); - } - OnDestroy(); - } - public void TestInitParameters(int nkeep, int chats) { Assert.That(llmCharacter.nKeep == nkeep); @@ -149,57 +201,55 @@ public void TestEmbeddings(List 
embeddings) Assert.That(embeddings.Count == 1024); } - public virtual void OnDestroy() - { - LLMManager.Remove(filename); - } + public virtual void OnDestroy() {} } public class TestLLM_LLMManager_Load : TestLLM { - public override Task SetLLM() + public override void SetLLM() { llm = gameObject.AddComponent(); + string filename = Path.GetFileName(modelUrl).Split("?")[0]; string sourcePath = Path.Combine(LLMUnitySetup.modelDownloadPath, filename); filename = LLMManager.LoadModel(sourcePath); llm.SetModel(filename); llm.parallelPrompts = 1; llm.SetTemplate("alpaca"); - return Task.CompletedTask; } } public class TestLLM_StreamingAssets_Load : TestLLM { - public override Task SetLLM() + string loadPath; + + public override void SetLLM() { llm = gameObject.AddComponent(); + string filename = Path.GetFileName(modelUrl).Split("?")[0]; string sourcePath = Path.Combine(LLMUnitySetup.modelDownloadPath, filename); - string targetPath = LLMUnitySetup.GetAssetPath(filename); - if (!File.Exists(targetPath)) File.Copy(sourcePath, targetPath); - llm.SetModel(filename); + loadPath = LLMUnitySetup.GetAssetPath(filename); + if (!File.Exists(loadPath)) File.Copy(sourcePath, loadPath); + llm.SetModel(loadPath); llm.parallelPrompts = 1; llm.SetTemplate("alpaca"); - return Task.CompletedTask; } public override void OnDestroy() { - string targetPath = LLMUnitySetup.GetAssetPath(filename); - if (!File.Exists(targetPath)) File.Delete(targetPath); + if (!File.Exists(loadPath)) File.Delete(loadPath); } } public class TestLLM_SetModel_Warning : TestLLM { - public override Task SetLLM() + public override void SetLLM() { llm = gameObject.AddComponent(); - string sourcePath = Path.Combine(LLMUnitySetup.modelDownloadPath, filename); - llm.SetModel(sourcePath); + string filename = Path.GetFileName(modelUrl).Split("?")[0]; + string loadPath = Path.Combine(LLMUnitySetup.modelDownloadPath, filename); + llm.SetModel(loadPath); llm.parallelPrompts = 1; llm.SetTemplate("alpaca"); - return Task.CompletedTask; } } } From 59494fdf9b767911db1c8cca998215dd7218e5dc Mon Sep 17 00:00:00 2001 From: Antonis Makropoulos Date: Fri, 23 Aug 2024 16:35:05 +0300 Subject: [PATCH 42/67] add full path function --- Runtime/LLMUnitySetup.cs | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/Runtime/LLMUnitySetup.cs b/Runtime/LLMUnitySetup.cs index 614f52e8..c86c1f95 100644 --- a/Runtime/LLMUnitySetup.cs +++ b/Runtime/LLMUnitySetup.cs @@ -299,18 +299,21 @@ public static async Task AndroidExtractAsset(string path, bool overwrite = false await AndroidExtractFile(Path.GetFileName(path), overwrite); } + public static string GetFullPath(string path) + { + return Path.GetFullPath(path).Replace('\\', '/'); + } + public static bool IsSubPath(string childPath, string parentPath) { - string fullParentPath = Path.GetFullPath(parentPath).Replace('\\', '/'); - string fullChildPath = Path.GetFullPath(childPath).Replace('\\', '/'); - return fullChildPath.StartsWith(fullParentPath, StringComparison.OrdinalIgnoreCase); + return GetFullPath(childPath).StartsWith(GetFullPath(parentPath), StringComparison.OrdinalIgnoreCase); } public static string RelativePath(string fullPath, string basePath) { // Get the full paths and replace backslashes with forward slashes (or vice versa) - string fullParentPath = Path.GetFullPath(basePath).Replace('\\', '/').TrimEnd('/'); - string fullChildPath = Path.GetFullPath(fullPath).Replace('\\', '/'); + string fullParentPath = GetFullPath(basePath).TrimEnd('/'); + string fullChildPath = GetFullPath(fullPath); 
string relativePath = fullChildPath; if (fullChildPath.StartsWith(fullParentPath, StringComparison.OrdinalIgnoreCase)) From d9cd26e16056e1baa4911c101df438752cda709b Mon Sep 17 00:00:00 2001 From: Antonis Makropoulos Date: Fri, 23 Aug 2024 16:35:35 +0300 Subject: [PATCH 43/67] use the LLMUnitySetup.GetFullPath function --- Runtime/LLMManager.cs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/Runtime/LLMManager.cs b/Runtime/LLMManager.cs index 3c7cbc24..81ac8644 100644 --- a/Runtime/LLMManager.cs +++ b/Runtime/LLMManager.cs @@ -24,7 +24,7 @@ public ModelEntry(string path, bool lora = false, string label = null, string ur filename = Path.GetFileName(path); this.label = label == null ? filename : label; this.lora = lora; - this.path = Path.GetFullPath(path).Replace('\\', '/'); + this.path = LLMUnitySetup.GetFullPath(path); this.url = url; includeInBuild = true; chatTemplate = null; @@ -162,7 +162,7 @@ public static void SetTemplate(ModelEntry entry, string chatTemplate) public static ModelEntry Get(string path) { string filename = Path.GetFileName(path); - string fullPath = Path.GetFullPath(path).Replace('\\', '/'); + string fullPath = LLMUnitySetup.GetFullPath(path); foreach (ModelEntry entry in modelEntries) { if (entry.filename == filename || entry.path == fullPath) return entry; From 13f5eaee72a90aab414018fd48d0e6b9c69e98db Mon Sep 17 00:00:00 2001 From: Antonis Makropoulos Date: Fri, 23 Aug 2024 16:37:05 +0300 Subject: [PATCH 44/67] implement lora manager to allow easy setup and switch --- Runtime/LLM.cs | 127 +++++++++++++++++++-------------------- Runtime/LLMUtils.cs | 126 ++++++++++++++++++++++++++++++++++++++ Runtime/LLMUtils.cs.meta | 11 ++++ 3 files changed, 200 insertions(+), 64 deletions(-) create mode 100644 Runtime/LLMUtils.cs create mode 100644 Runtime/LLMUtils.cs.meta diff --git a/Runtime/LLM.cs b/Runtime/LLM.cs index dfe56331..96b2e10b 100644 --- a/Runtime/LLM.cs +++ b/Runtime/LLM.cs @@ -3,7 +3,6 @@ using System; using System.Collections.Generic; using System.IO; -using System.Linq; using System.Threading; using System.Threading.Tasks; using UnityEditor; @@ -11,20 +10,6 @@ namespace LLMUnity { - /// \cond HIDE - public class LLMException : Exception - { - public int ErrorCode { get; private set; } - - public LLMException(string message, int errorCode) : base(message) - { - ErrorCode = errorCode; - } - } - - public class DestroyException : Exception {} - /// \endcond - [DefaultExecutionOrder(-1)] /// @ingroup llm /// @@ -74,6 +59,8 @@ public class LLM : MonoBehaviour /// the paths of the LORA models being used (relative to the Assets/StreamingAssets folder). /// Models with .gguf format are allowed. [ModelAdvanced] public string lora = ""; + /// the weights of the LORA models being used. 
+ [ModelAdvanced] public string loraWeights = ""; /// enable use of flash attention [ModelExtras] public bool flashAttention = false; @@ -86,8 +73,8 @@ public class LLM : MonoBehaviour Thread llmThread = null; List streamWrappers = new List(); public LLMManager llmManager = new LLMManager(); - List loraWeights = new List(); private readonly object startLock = new object(); + LoraManager loraManager = new LoraManager(); /// \endcond @@ -136,7 +123,7 @@ public static async Task WaitUntilModelSetup(Callback downloadProgr return !modelSetupFailed; } - public string GetLLMManagerAsset(string path) + public static string GetLLMManagerAsset(string path) { #if UNITY_EDITOR if (!EditorApplication.isPlaying) return GetLLMManagerAssetEditor(path); @@ -177,8 +164,6 @@ public static string GetLLMManagerAssetRuntime(string path) { // empty if (string.IsNullOrEmpty(path)) return path; - // full path - if (File.Exists(path)) return path; // LLMManager string managerPath = LLMManager.GetAssetPath(path); if (!string.IsNullOrEmpty(managerPath) && File.Exists(managerPath)) return managerPath; @@ -219,10 +204,11 @@ public void SetModel(string path) /// Models supported are in .gguf format. /// /// path to LORA model to use (.gguf format) - public void SetLora(string path) + public void SetLora(string path, float weight = 1) { - lora = ""; - AddLora(path); + AssertNotStarted(); + loraManager.Clear(); + AddLora(path, weight); } /// @@ -231,15 +217,11 @@ public void SetLora(string path) /// Models supported are in .gguf format. /// /// path to LORA model to use (.gguf format) - public void AddLora(string path) + public void AddLora(string path, float weight = 1) { - string loraPath = GetLLMManagerAsset(path); - if (lora.Split(" ").Contains(loraPath)) return; - if (lora != "") lora += " "; - lora += loraPath; -#if UNITY_EDITOR - if (!EditorApplication.isPlaying) EditorUtility.SetDirty(this); -#endif + AssertNotStarted(); + loraManager.Add(path, weight); + UpdateLoras(); } /// @@ -249,15 +231,38 @@ public void AddLora(string path) /// path to LORA model to remove (.gguf format) public void RemoveLora(string path) { - string loraPath = GetLLMManagerAsset(path); - List loras = new List(lora.Split(" ")); - loras.Remove(loraPath); - lora = ""; - for (int i = 0; i < loras.Count; i++) - { - if (i > 0) lora += " "; - lora += loras[i]; - } + AssertNotStarted(); + loraManager.Remove(path); + UpdateLoras(); + } + + /// + /// Allows to remove all LORA models from the LLM. + /// + public void RemoveLoras() + { + AssertNotStarted(); + loraManager.Clear(); + UpdateLoras(); + } + + /// + /// Allows to change the scale (weight) of a LORA model in the LLM. 
+ /// + /// path of LORA model to change (.gguf format) + /// scale of LORA + public void SetLoraScale(string path, float scale) + { + loraManager.SetWeight(path, scale); + UpdateLoras(); + if (started) ApplyLoras(); + } + + public void UpdateLoras() + { + StringPair pair = loraManager.ToStrings(); + lora = pair.source; + loraWeights = pair.target; #if UNITY_EDITOR if (!EditorApplication.isPlaying) EditorUtility.SetDirty(this); #endif @@ -409,8 +414,7 @@ private void StartService() llmThread = new Thread(() => llmlib.LLM_Start(LLMObject)); llmThread.Start(); while (!llmlib.LLM_Started(LLMObject)) {} - loraWeights = new List(); - for (int i = 0; i < lora.Split(" ").Count(); i++) loraWeights.Add(1f); + ApplyLoras(); started = true; } @@ -469,6 +473,16 @@ void AssertStarted() } } + void AssertNotStarted() + { + if (started) + { + string error = "This method can't be called when the LLM has started"; + LLMUnitySetup.LogError(error); + throw new Exception(error); + } + } + void CheckLLMStatus(bool log = true) { if (llmlib == null) { return; } @@ -559,46 +573,31 @@ public async Task Embeddings(string json) /// Sets the lora scale, only works after the LLM service has started /// /// switch result - public async Task SetLoraScale(string loraToScale, float scale) + public void ApplyLoras() { - AssertStarted(); - List loras = new List(lora.Split(" ")); - string loraToScalePath = GetLLMManagerAssetRuntime(loraToScale); - - int index = loras.IndexOf(loraToScale); - if (index == -1) index = loras.IndexOf(loraToScalePath); - if (index == -1) - { - LLMUnitySetup.LogError($"LoRA {loraToScale} not loaded with the LLM"); - return ""; - } - - loraWeights[index] = scale; LoraWeightRequestList loraWeightRequest = new LoraWeightRequestList(); loraWeightRequest.loraWeights = new List(); - for (int i = 0; i < loraWeights.Count; i++) + float[] weights = loraManager.GetWeights(); + for (int i = 0; i < weights.Length; i++) { - loraWeightRequest.loraWeights.Add(new LoraWeightRequest() { id = i, scale = loraWeights[i] }); + loraWeightRequest.loraWeights.Add(new LoraWeightRequest() { id = i, scale = weights[i] }); } - ; string json = JsonUtility.ToJson(loraWeightRequest); int startIndex = json.IndexOf("["); int endIndex = json.LastIndexOf("]") + 1; json = json.Substring(startIndex, endIndex - startIndex); - LLMReplyCallback callback = (IntPtr LLMObject, string jsonData, IntPtr strWrapper) => - { - llmlib.LLM_Lora_Weight(LLMObject, jsonData, strWrapper); - }; - return await LLMReply(callback, json); + IntPtr stringWrapper = llmlib.StringWrapper_Construct(); + llmlib.LLM_Lora_Weight(LLMObject, json, stringWrapper); + llmlib.StringWrapper_Delete(stringWrapper); } /// /// Gets a list of the lora adapters /// /// list of lara adapters - public async Task ListLora() + public async Task ListLoras() { AssertStarted(); LLMNoInputReplyCallback callback = (IntPtr LLMObject, IntPtr strWrapper) => diff --git a/Runtime/LLMUtils.cs b/Runtime/LLMUtils.cs new file mode 100644 index 00000000..859a3a4b --- /dev/null +++ b/Runtime/LLMUtils.cs @@ -0,0 +1,126 @@ +/// @file +/// @brief File implementing LLM helper code. 
+using System; +using System.Collections.Generic; +using System.IO; +using System.Threading; +using System.Threading.Tasks; +using UnityEditor; +using UnityEngine; + +namespace LLMUnity +{ + /// \cond HIDE + public class LLMException : Exception + { + public int ErrorCode { get; private set; } + + public LLMException(string message, int errorCode) : base(message) + { + ErrorCode = errorCode; + } + } + + public class DestroyException : Exception {} + + public class LoraAsset + { + public string assetPath; + public float weight; + + public LoraAsset(string path, float weight = 1) + { + assetPath = LLM.GetLLMManagerAsset(path); + this.weight = weight; + } + + public override bool Equals(object obj) + { + string RuntimePath(string path) {return LLMUnitySetup.GetFullPath(LLM.GetLLMManagerAssetRuntime(path));} + + if (obj == null || obj.GetType() != this.GetType()) return false; + LoraAsset other = (LoraAsset)obj; + return assetPath == other.assetPath || RuntimePath(assetPath) == RuntimePath(other.assetPath); + } + + public override int GetHashCode() + { + return (assetPath + "," + weight.ToString()).GetHashCode(); + } + } + + public class LoraManager + { + List loras = new List(); + + public void Clear() + { + loras.Clear(); + } + + public void Add(string path, float weight = 1) + { + LoraAsset lora = new LoraAsset(path, weight); + if (loras.Contains(lora)) return; + loras.Add(lora); + } + + public void Remove(string path) + { + loras.Remove(new LoraAsset(path)); + } + + public void SetWeight(string path, float weight) + { + LoraAsset lora = new LoraAsset(path); + int index = loras.IndexOf(lora); + if (index == -1) + { + LLMUnitySetup.LogError($"LoRA {path} not loaded with the LLM"); + return; + } + loras[index].weight = weight; + } + + public void FromStrings(string loraString, string loraWeightsString) + { + Clear(); + List loraStringArr = new List(loraString.Split(" ")); + List loraWeightsStringArr = new List(loraWeightsString.Split(" ")); + if (loraStringArr.Count != loraWeightsStringArr.Count) + { + LLMUnitySetup.LogError($"LoRAs number ({loraString}) doesn't match the number of weights ({loraWeightsString})"); + return; + } + for (int i = 0; i < loraStringArr.Count; i++) + { + Add(loraStringArr[i], float.Parse(loraWeightsStringArr[i])); + } + } + + public StringPair ToStrings() + { + string loraString = ""; + string loraWeightsString = ""; + for (int i = 0; i < loras.Count; i++) + { + if (i > 0) + { + loraString += " "; + loraWeightsString += " "; + } + loraString += loras[i].assetPath; + loraWeightsString += loras[i].weight; + } + return new StringPair(){source = loraString, target = loraWeightsString}; + } + + public float[] GetWeights() + { + float[] weights = new float[loras.Count]; + for (int i = 0; i < loras.Count; i++) weights[i] = loras[i].weight; + return weights; + } + } + /// \endcond +} diff --git a/Runtime/LLMUtils.cs.meta b/Runtime/LLMUtils.cs.meta new file mode 100644 index 00000000..93974a4f --- /dev/null +++ b/Runtime/LLMUtils.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: 2ae6a2ce57e8af0fc876d0c380ed8a2f +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: From 77ee95023caa5824b59704599efced37480d72f9 Mon Sep 17 00:00:00 2001 From: Antonis Makropoulos Date: Fri, 23 Aug 2024 16:37:42 +0300 Subject: [PATCH 45/67] la --- Runtime/LLMUtils.cs | 5 ----- 1 file changed, 5 deletions(-) diff --git a/Runtime/LLMUtils.cs b/Runtime/LLMUtils.cs index 
859a3a4b..75b72721 100644 --- a/Runtime/LLMUtils.cs +++ b/Runtime/LLMUtils.cs @@ -2,11 +2,6 @@ /// @brief File implementing LLM helper code. using System; using System.Collections.Generic; -using System.IO; -using System.Threading; -using System.Threading.Tasks; -using UnityEditor; -using UnityEngine; namespace LLMUnity { From 61da5f6fe354c385c3d68722394d6509aba88108 Mon Sep 17 00:00:00 2001 From: Antonis Makropoulos Date: Sat, 24 Aug 2024 09:11:15 +0300 Subject: [PATCH 46/67] fixes to lora functions, update when changed --- Editor/LLMEditor.cs | 4 ++-- Runtime/LLM.cs | 16 +++++++++++++++- Runtime/LLMManager.cs | 2 +- Runtime/LLMUtils.cs | 24 ++++++++++++++++-------- 4 files changed, 34 insertions(+), 12 deletions(-) diff --git a/Editor/LLMEditor.cs b/Editor/LLMEditor.cs index 9601579f..dcf5cfed 100644 --- a/Editor/LLMEditor.cs +++ b/Editor/LLMEditor.cs @@ -206,7 +206,7 @@ async Task createButtons() } else if (modelIndex > 1) { - if (modelLicenses[modelIndex] != null) Debug.LogWarning($"The {modelOptions[modelIndex]} model is released under the following license: {modelLicenses[modelIndex]}. By using this model, you agree to the terms of the license."); + if (modelLicenses[modelIndex] != null) LLMUnitySetup.LogWarning($"The {modelOptions[modelIndex]} model is released under the following license: {modelLicenses[modelIndex]}. By using this model, you agree to the terms of the license."); string filename = await LLMManager.DownloadModel(modelURLs[modelIndex], true, modelOptions[modelIndex]); SetModelIfNone(filename, false); UpdateModels(true); @@ -300,7 +300,7 @@ void OnEnable() } else { - isSelected = llmScript.lora.Split(" ").Contains(entry.filename); + isSelected = llmScript.loraManager.Contains(entry.filename); bool newSelected = EditorGUI.Toggle(selectRect, isSelected); if (newSelected && !isSelected) llmScript.AddLora(entry.filename); else if (!newSelected && isSelected) llmScript.RemoveLora(entry.filename); diff --git a/Runtime/LLM.cs b/Runtime/LLM.cs index 96b2e10b..6b9123c4 100644 --- a/Runtime/LLM.cs +++ b/Runtime/LLM.cs @@ -74,7 +74,9 @@ public class LLM : MonoBehaviour List streamWrappers = new List(); public LLMManager llmManager = new LLMManager(); private readonly object startLock = new object(); - LoraManager loraManager = new LoraManager(); + public LoraManager loraManager = new LoraManager(); + string loraPre = ""; + string loraWeightsPre = ""; /// \endcond @@ -83,6 +85,16 @@ public LLM() LLMManager.Register(this); } + void OnValidate() + { + if (lora != loraPre || loraWeights != loraWeightsPre) + { + loraManager.FromStrings(lora, loraWeights); + loraPre = lora; + loraWeightsPre = loraWeights; + } + } + /// /// The Unity Awake function that starts the LLM server. /// The server can be started asynchronously if the asynchronousStartup option is set. 
@@ -263,6 +275,8 @@ public void UpdateLoras() StringPair pair = loraManager.ToStrings(); lora = pair.source; loraWeights = pair.target; + loraPre = lora; + loraWeightsPre = loraWeights; #if UNITY_EDITOR if (!EditorApplication.isPlaying) EditorUtility.SetDirty(this); #endif diff --git a/Runtime/LLMManager.cs b/Runtime/LLMManager.cs index 81ac8644..34301fac 100644 --- a/Runtime/LLMManager.cs +++ b/Runtime/LLMManager.cs @@ -387,7 +387,7 @@ public static void Remove(ModelEntry entry) foreach (LLM llm in llms) { if (!entry.lora && llm.model == entry.filename) llm.model = ""; - else if (entry.lora && llm.lora == entry.filename) llm.lora = ""; + else if (entry.lora) llm.RemoveLora(entry.filename); } } diff --git a/Runtime/LLMUtils.cs b/Runtime/LLMUtils.cs index 75b72721..0c5c8904 100644 --- a/Runtime/LLMUtils.cs +++ b/Runtime/LLMUtils.cs @@ -53,6 +53,12 @@ public void Clear() loras.Clear(); } + public bool Contains(string path) + { + LoraAsset lora = new LoraAsset(path); + return loras.Contains(lora); + } + public void Add(string path, float weight = 1) { LoraAsset lora = new LoraAsset(path, weight); @@ -79,17 +85,19 @@ public void SetWeight(string path, float weight) public void FromStrings(string loraString, string loraWeightsString) { - Clear(); - List loraStringArr = new List(loraString.Split(" ")); - List loraWeightsStringArr = new List(loraWeightsString.Split(" ")); - if (loraStringArr.Count != loraWeightsStringArr.Count) + try { - LLMUnitySetup.LogError($"LoRAs number ({loraString}) doesn't match the number of weights ({loraWeightsString})"); - return; + List loraStringArr = new List(loraString.Split(" ")); + List loraWeightsStringArr = new List(loraWeightsString.Split(" ")); + if (loraStringArr.Count != loraWeightsStringArr.Count) throw new Exception($"LoRAs number ({loraString}) doesn't match the number of weights ({loraWeightsString})"); + + List lorasNew = new List(); + for (int i = 0; i < loraStringArr.Count; i++) lorasNew.Add(new LoraAsset(loraStringArr[i], float.Parse(loraWeightsStringArr[i]))); + loras = lorasNew; } - for (int i = 0; i < loraStringArr.Count; i++) + catch (Exception e) { - Add(loraStringArr[i], float.Parse(loraWeightsStringArr[i])); + LLMUnitySetup.LogError($"Loras not set: {e.Message}"); } } From 96a3b78cc1e4caeb6a04a117ea321822701ec1bf Mon Sep 17 00:00:00 2001 From: Antonis Makropoulos Date: Sat, 24 Aug 2024 09:19:35 +0300 Subject: [PATCH 47/67] shorten lora string assignment --- Runtime/LLM.cs | 10 +++------- Runtime/LLMUtils.cs | 4 ++-- 2 files changed, 5 insertions(+), 9 deletions(-) diff --git a/Runtime/LLM.cs b/Runtime/LLM.cs index 6b9123c4..f2c1a82b 100644 --- a/Runtime/LLM.cs +++ b/Runtime/LLM.cs @@ -90,8 +90,7 @@ void OnValidate() if (lora != loraPre || loraWeights != loraWeightsPre) { loraManager.FromStrings(lora, loraWeights); - loraPre = lora; - loraWeightsPre = loraWeights; + (loraPre, loraWeightsPre) = (lora, loraWeights); } } @@ -272,11 +271,8 @@ public void SetLoraScale(string path, float scale) public void UpdateLoras() { - StringPair pair = loraManager.ToStrings(); - lora = pair.source; - loraWeights = pair.target; - loraPre = lora; - loraWeightsPre = loraWeights; + (lora, loraWeights) = loraManager.ToStrings(); + (loraPre, loraWeightsPre) = (lora, loraWeights); #if UNITY_EDITOR if (!EditorApplication.isPlaying) EditorUtility.SetDirty(this); #endif diff --git a/Runtime/LLMUtils.cs b/Runtime/LLMUtils.cs index 0c5c8904..5e99950d 100644 --- a/Runtime/LLMUtils.cs +++ b/Runtime/LLMUtils.cs @@ -101,7 +101,7 @@ public void FromStrings(string loraString, 
string loraWeightsString) } } - public StringPair ToStrings() + public (string, string) ToStrings() { string loraString = ""; string loraWeightsString = ""; @@ -115,7 +115,7 @@ public StringPair ToStrings() loraString += loras[i].assetPath; loraWeightsString += loras[i].weight; } - return new StringPair(){source = loraString, target = loraWeightsString}; + return (loraString, loraWeightsString); } public float[] GetWeights() From 4c318711906bf3127d1477f976fcfcbb43b1f7f8 Mon Sep 17 00:00:00 2001 From: Antonis Makropoulos Date: Sat, 24 Aug 2024 10:28:15 +0300 Subject: [PATCH 48/67] add lora assignment tests --- Tests/Runtime/TestLLM.cs | 51 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 51 insertions(+) diff --git a/Tests/Runtime/TestLLM.cs b/Tests/Runtime/TestLLM.cs index a55d1609..f06cb911 100644 --- a/Tests/Runtime/TestLLM.cs +++ b/Tests/Runtime/TestLLM.cs @@ -10,6 +10,57 @@ namespace LLMUnityTests { + public class TestLLMLoras + { + [Test] + public void TestLLMLorasAssign() + { + GameObject gameObject = new GameObject(); + gameObject.SetActive(false); + LLM llm = gameObject.AddComponent(); + + string lora1 = "/tmp/lala"; + string lora2Rel = "test/lala"; + string lora2 = LLMUnitySetup.GetAssetPath(lora2Rel); + LLMUnitySetup.CreateEmptyFile(lora1); + LLMUnitySetup.CreateEmptyFile(lora2); + + llm.AddLora(lora1); + llm.AddLora(lora2); + Assert.AreEqual(llm.lora, lora1 + " " + lora2); + Assert.AreEqual(llm.loraWeights, "1 1"); + + llm.RemoveLoras(); + Assert.AreEqual(llm.lora, ""); + Assert.AreEqual(llm.loraWeights, ""); + + llm.AddLora(lora1, 0.8f); + llm.AddLora(lora2Rel, 0.9f); + Assert.AreEqual(llm.lora, lora1 + " " + lora2); + Assert.AreEqual(llm.loraWeights, "0.8 0.9"); + + llm.SetLoraScale(lora2Rel, 0.7f); + Assert.AreEqual(llm.lora, lora1 + " " + lora2); + Assert.AreEqual(llm.loraWeights, "0.8 0.7"); + + llm.RemoveLora(lora2Rel); + Assert.AreEqual(llm.lora, lora1); + Assert.AreEqual(llm.loraWeights, "0.8"); + + llm.AddLora(lora2Rel); + llm.SetLoraScale(lora2Rel, 0.5f); + Assert.AreEqual(llm.lora, lora1 + " " + lora2); + Assert.AreEqual(llm.loraWeights, "0.8 0.5"); + + llm.SetLoraScale(lora2, 0.1f); + Assert.AreEqual(llm.lora, lora1 + " " + lora2); + Assert.AreEqual(llm.loraWeights, "0.8 0.1"); + + File.Delete(lora1); + File.Delete(lora2); + } + } + public class TestLLM { protected static string modelUrl = "https://huggingface.co/afrideva/smol_llama-220M-openhermes-GGUF/resolve/main/smol_llama-220m-openhermes.q4_k_m.gguf?download=true"; From 920e4a17036d5f8b5456bd59f73d6ac383a690c3 Mon Sep 17 00:00:00 2001 From: Antonis Makropoulos Date: Mon, 26 Aug 2024 11:49:47 +0300 Subject: [PATCH 49/67] add lora test --- Tests/Runtime/TestLLM.cs | 157 +++++++++++++++++++++++++++++---------- 1 file changed, 119 insertions(+), 38 deletions(-) diff --git a/Tests/Runtime/TestLLM.cs b/Tests/Runtime/TestLLM.cs index f06cb911..ef7f28f6 100644 --- a/Tests/Runtime/TestLLM.cs +++ b/Tests/Runtime/TestLLM.cs @@ -23,6 +23,7 @@ public void TestLLMLorasAssign() string lora2Rel = "test/lala"; string lora2 = LLMUnitySetup.GetAssetPath(lora2Rel); LLMUnitySetup.CreateEmptyFile(lora1); + Directory.CreateDirectory(Path.GetDirectoryName(lora2)); LLMUnitySetup.CreateEmptyFile(lora2); llm.AddLora(lora1); @@ -63,14 +64,19 @@ public void TestLLMLorasAssign() public class TestLLM { - protected static string modelUrl = "https://huggingface.co/afrideva/smol_llama-220M-openhermes-GGUF/resolve/main/smol_llama-220m-openhermes.q4_k_m.gguf?download=true"; + protected static string modelUrl = 
"https://huggingface.co/Qwen/Qwen2-0.5B-Instruct-GGUF/resolve/main/qwen2-0_5b-instruct-q4_k_m.gguf?download=true"; protected string modelNameLLManager; protected GameObject gameObject; protected LLM llm; protected LLMCharacter llmCharacter; Exception error = null; - string prompt = "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request."; + protected string prompt; + protected string query; + protected string reply1; + protected string reply2; + protected int tokens1; + protected int tokens2; public TestLLM() @@ -81,14 +87,30 @@ public TestLLM() public virtual async Task Init() { - modelNameLLManager = await LLMManager.DownloadModel(modelUrl); + SetParameters(); + await DownloadModels(); gameObject = new GameObject(); gameObject.SetActive(false); - SetLLM(); - SetLLMCharacter(); + llm = CreateLLM(); + llmCharacter = CreateLLMCharacter(); gameObject.SetActive(true); } + public virtual void SetParameters() + { + prompt = "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request."; + query = "How can I increase my meme production/output? Currently, I only create them in ancient babylonian which is time consuming."; + reply1 = "To increase your meme production/output, you can try using more modern tools and techniques. For instance,"; + reply2 = "To increase your meme production/output, you can try the following strategies:\n\n1. Use a meme generator"; + tokens1 = 32; + tokens2 = 9; + } + + public virtual async Task DownloadModels() + { + modelNameLLManager = await LLMManager.DownloadModel(modelUrl); + } + [Test] public void TestGetLLMManagerAssetRuntime() { @@ -144,17 +166,17 @@ public void TestGetLLMManagerAssetEditor() File.Delete(path); } - public virtual void SetLLM() + public virtual LLM CreateLLM() { - llm = gameObject.AddComponent(); + LLM llm = gameObject.AddComponent(); llm.SetModel(modelNameLLManager); llm.parallelPrompts = 1; - llm.SetTemplate("alpaca"); + return llm; } - public virtual void SetLLMCharacter() + public virtual LLMCharacter CreateLLMCharacter() { - llmCharacter = gameObject.AddComponent(); + LLMCharacter llmCharacter = gameObject.AddComponent(); llmCharacter.llm = llm; llmCharacter.playerName = "Instruction"; llmCharacter.AIName = "Response"; @@ -163,6 +185,7 @@ public virtual void SetLLMCharacter() llmCharacter.seed = 0; llmCharacter.stream = false; llmCharacter.numPredict = 20; + return llmCharacter; } [UnityTest] @@ -183,26 +206,22 @@ public async Task RunTestsTask() error = null; try { - // await llm.WaitUntilReady(); - - // llm.Awake(); - // llmCharacter.Awake(); await llmCharacter.Tokenize("I", TestTokens); await llmCharacter.Warmup(); - TestInitParameters((await llmCharacter.Tokenize(prompt)).Count + 2, 1); + TestInitParameters(tokens1, 1); TestWarmup(); - await llmCharacter.Chat("How can I increase my meme production/output? Currently, I only create them in ancient babylonian which is time consuming.", TestChat); + await llmCharacter.Chat(query, (string reply) => TestChat(reply, reply1)); TestPostChat(3); llmCharacter.SetPrompt(llmCharacter.prompt); llmCharacter.AIName = "False response"; - await llmCharacter.Chat("How can I increase my meme production/output? 
Currently, I only create them in ancient babylonian which is time consuming.", TestChat2); + await llmCharacter.Chat(query, (string reply) => TestChat(reply, reply2)); TestPostChat(3); await llmCharacter.Chat("bye!"); TestPostChat(5); prompt = "How are you?"; llmCharacter.SetPrompt(prompt); await llmCharacter.Chat("hi"); - TestInitParameters((await llmCharacter.Tokenize(prompt)).Count + 2, 3); + TestInitParameters(tokens2, 3); List embeddings = await llmCharacter.Embeddings("hi how are you?"); TestEmbeddings(embeddings); llm.OnDestroy(); @@ -222,7 +241,7 @@ public void TestInitParameters(int nkeep, int chats) public void TestTokens(List tokens) { - Assert.AreEqual(tokens, new List {306}); + Assert.AreEqual(tokens, new List {40}); } public void TestWarmup() @@ -230,16 +249,9 @@ public void TestWarmup() Assert.That(llmCharacter.chat.Count == 1); } - public void TestChat(string reply) - { - string AIReply = "One way to increase your meme production/output is by creating a more complex and customized"; - Assert.That(reply.Trim() == AIReply); - } - - public void TestChat2(string reply) + public void TestChat(string reply, string replyGT) { - string AIReply = "One possible solution is to use a more advanced natural language processing library like NLTK or sp"; - Assert.That(reply.Trim() == AIReply); + Assert.That(reply.Trim() == replyGT); } public void TestPostChat(int num) @@ -249,7 +261,7 @@ public void TestPostChat(int num) public void TestEmbeddings(List embeddings) { - Assert.That(embeddings.Count == 1024); + Assert.That(embeddings.Count == 896); } public virtual void OnDestroy() {} @@ -257,15 +269,15 @@ public virtual void OnDestroy() {} public class TestLLM_LLMManager_Load : TestLLM { - public override void SetLLM() + public override LLM CreateLLM() { - llm = gameObject.AddComponent(); + LLM llm = gameObject.AddComponent(); string filename = Path.GetFileName(modelUrl).Split("?")[0]; string sourcePath = Path.Combine(LLMUnitySetup.modelDownloadPath, filename); filename = LLMManager.LoadModel(sourcePath); llm.SetModel(filename); llm.parallelPrompts = 1; - llm.SetTemplate("alpaca"); + return llm; } } @@ -273,16 +285,16 @@ public class TestLLM_StreamingAssets_Load : TestLLM { string loadPath; - public override void SetLLM() + public override LLM CreateLLM() { - llm = gameObject.AddComponent(); + LLM llm = gameObject.AddComponent(); string filename = Path.GetFileName(modelUrl).Split("?")[0]; string sourcePath = Path.Combine(LLMUnitySetup.modelDownloadPath, filename); loadPath = LLMUnitySetup.GetAssetPath(filename); if (!File.Exists(loadPath)) File.Copy(sourcePath, loadPath); llm.SetModel(loadPath); llm.parallelPrompts = 1; - llm.SetTemplate("alpaca"); + return llm; } public override void OnDestroy() @@ -293,14 +305,83 @@ public override void OnDestroy() public class TestLLM_SetModel_Warning : TestLLM { - public override void SetLLM() + public override LLM CreateLLM() { - llm = gameObject.AddComponent(); + LLM llm = gameObject.AddComponent(); string filename = Path.GetFileName(modelUrl).Split("?")[0]; string loadPath = Path.Combine(LLMUnitySetup.modelDownloadPath, filename); llm.SetModel(loadPath); llm.parallelPrompts = 1; - llm.SetTemplate("alpaca"); + return llm; + } + } + + public class TestLLM_NoLora : TestLLM + { + public override void SetParameters() + { + prompt = ""; + query = "кто ты?"; + reply1 = "Я - искусственный интеллект, который помогаю вам с информацией и задачами"; + reply2 = "I'm sorry, but I didn't understand your request. 
Could you please provide more information or clarify"; + tokens1 = 5; + tokens2 = 9; + } + } + + public class TestLLM_Lora : TestLLM + { + string loraUrl = "https://huggingface.co/undreamer/Qwen2-0.5B-Instruct-ru-lora/resolve/main/Qwen2-0.5B-Instruct-ru-lora.gguf?download=true"; + string loraNameLLManager; + + public override async Task DownloadModels() + { + await base.DownloadModels(); + loraNameLLManager = await LLMManager.DownloadLora(loraUrl); + } + + public override LLM CreateLLM() + { + LLM llm = base.CreateLLM(); + llm.AddLora(loraNameLLManager); + return llm; + } + + public override void SetParameters() + { + prompt = ""; + query = "кто ты?"; + reply1 = "Я - искусственный интеллект, созданный для помощи и общения с людьми"; + reply2 = "Идиот"; + tokens1 = 5; + tokens2 = 9; + } + + [Test] + public void TestModelPaths() + { + Assert.AreEqual(llm.model, Path.Combine(LLMUnitySetup.modelDownloadPath, Path.GetFileName(modelUrl).Split("?")[0])); + Assert.AreEqual(llm.lora, Path.Combine(LLMUnitySetup.modelDownloadPath, Path.GetFileName(loraUrl).Split("?")[0])); + } + } + + + public class TestLLM_Double : TestLLM + { + LLM llm1; + LLMCharacter lLMCharacter1; + + public override async Task Init() + { + SetParameters(); + await DownloadModels(); + gameObject = new GameObject(); + gameObject.SetActive(false); + llm = CreateLLM(); + llmCharacter = CreateLLMCharacter(); + llm1 = CreateLLM(); + lLMCharacter1 = CreateLLMCharacter(); + gameObject.SetActive(true); } } } From 38c5d8b3e336d47aee31f4d8653c53025d06272a Mon Sep 17 00:00:00 2001 From: amakropoulos Date: Mon, 26 Aug 2024 08:51:17 +0000 Subject: [PATCH 50/67] update changelogs --- CHANGELOG.md | 1 + CHANGELOG.release.md | 1 + 2 files changed, 2 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 3154f6c1..7c66ad7c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,7 @@ - Read context length and warn if it is very large (PR: #211) - Setup allowing to use extra features: flash attention and IQ quants (PR: #216) - Allow HTTP request retries for remote server (PR: #217) +- Allow to set lora weights at startup, add unit test (PR: #219) #### 🐛 Fixes diff --git a/CHANGELOG.release.md b/CHANGELOG.release.md index 439abc9a..c58310a9 100644 --- a/CHANGELOG.release.md +++ b/CHANGELOG.release.md @@ -4,6 +4,7 @@ - Read context length and warn if it is very large (PR: #211) - Setup allowing to use extra features: flash attention and IQ quants (PR: #216) - Allow HTTP request retries for remote server (PR: #217) +- Allow to set lora weights at startup, add unit test (PR: #219) ### 🐛 Fixes From 675f02cbcb573ae3d1340c44a4ff55ee9776761a Mon Sep 17 00:00:00 2001 From: Antonis Makropoulos Date: Mon, 26 Aug 2024 13:13:24 +0300 Subject: [PATCH 51/67] load manager strings at start --- Runtime/LLM.cs | 1 + 1 file changed, 1 insertion(+) diff --git a/Runtime/LLM.cs b/Runtime/LLM.cs index f2c1a82b..5fe98492 100644 --- a/Runtime/LLM.cs +++ b/Runtime/LLM.cs @@ -326,6 +326,7 @@ protected virtual string GetLlamaccpArguments() } loraArgument += $" --lora \"{loraPath}\""; } + loraManager.FromStrings(lora, loraWeights); int numThreadsToUse = numThreads; if (Application.platform == RuntimePlatform.Android && numThreads <= 0) numThreadsToUse = LLMUnitySetup.AndroidGetNumBigCores(); From 61eaaaa12c4bc911f9454fa5fd06aeed8faf069f Mon Sep 17 00:00:00 2001 From: Antonis Makropoulos Date: Mon, 26 Aug 2024 13:24:36 +0300 Subject: [PATCH 52/67] fix contains and fromStrings --- Runtime/LLMUtils.cs | 46 ++++++++++++++++++++++++++------------------- 1 file changed, 27 
insertions(+), 19 deletions(-) diff --git a/Runtime/LLMUtils.cs b/Runtime/LLMUtils.cs index 5e99950d..6be9c6ff 100644 --- a/Runtime/LLMUtils.cs +++ b/Runtime/LLMUtils.cs @@ -21,26 +21,19 @@ public class DestroyException : Exception {} public class LoraAsset { public string assetPath; + public string fullPath; public float weight; public LoraAsset(string path, float weight = 1) { assetPath = LLM.GetLLMManagerAsset(path); + fullPath = RuntimePath(path); this.weight = weight; } - public override bool Equals(object obj) + public static string RuntimePath(string path) { - string RuntimePath(string path) {return LLMUnitySetup.GetFullPath(LLM.GetLLMManagerAssetRuntime(path));} - - if (obj == null || obj.GetType() != this.GetType()) return false; - LoraAsset other = (LoraAsset)obj; - return assetPath == other.assetPath || RuntimePath(assetPath) == RuntimePath(other.assetPath); - } - - public override int GetHashCode() - { - return (assetPath + "," + weight.ToString()).GetHashCode(); + return LLMUnitySetup.GetFullPath(LLM.GetLLMManagerAssetRuntime(path)); } } @@ -53,28 +46,37 @@ public void Clear() loras.Clear(); } + public int IndexOf(string path) + { + string fullPath = LoraAsset.RuntimePath(path); + for (int i = 0; i < loras.Count; i++) + { + LoraAsset lora = loras[i]; + if (lora.assetPath == path || lora.fullPath == fullPath) return i; + } + return -1; + } + public bool Contains(string path) { - LoraAsset lora = new LoraAsset(path); - return loras.Contains(lora); + return IndexOf(path) != -1; } public void Add(string path, float weight = 1) { - LoraAsset lora = new LoraAsset(path, weight); - if (loras.Contains(lora)) return; - loras.Add(lora); + if (Contains(path)) return; + loras.Add(new LoraAsset(path, weight)); } public void Remove(string path) { - loras.Remove(new LoraAsset(path)); + int index = IndexOf(path); + if (index != -1) loras.RemoveAt(index); } public void SetWeight(string path, float weight) { - LoraAsset lora = new LoraAsset(path); - int index = loras.IndexOf(lora); + int index = IndexOf(path); if (index == -1) { LLMUnitySetup.LogError($"LoRA {path} not loaded with the LLM"); @@ -85,6 +87,12 @@ public void SetWeight(string path, float weight) public void FromStrings(string loraString, string loraWeightsString) { + if (string.IsNullOrEmpty(loraString) && string.IsNullOrEmpty(loraWeightsString)) + { + Clear(); + return; + } + try { List loraStringArr = new List(loraString.Split(" ")); From 38ceb0b56d64c83736d818034ec0ea7cfad8438f Mon Sep 17 00:00:00 2001 From: Antonis Makropoulos Date: Mon, 26 Aug 2024 13:24:59 +0300 Subject: [PATCH 53/67] bump LlamaLib to v1.1.9 --- Runtime/LLMUnitySetup.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Runtime/LLMUnitySetup.cs b/Runtime/LLMUnitySetup.cs index c86c1f95..9638e59c 100644 --- a/Runtime/LLMUnitySetup.cs +++ b/Runtime/LLMUnitySetup.cs @@ -87,7 +87,7 @@ public class LLMUnitySetup /// LLM for Unity version public static string Version = "v2.1.2"; /// LlamaLib version - public static string LlamaLibVersion = "v1.1.8"; + public static string LlamaLibVersion = "v1.1.9"; /// LlamaLib release url public static string LlamaLibReleaseURL = $"https://github.com/undreamai/LlamaLib/releases/download/{LlamaLibVersion}"; /// LlamaLib url From eb7382450c8ba0f8d67aeee99bcb916230624589 Mon Sep 17 00:00:00 2001 From: Antonis Makropoulos Date: Mon, 26 Aug 2024 13:37:33 +0300 Subject: [PATCH 54/67] rename SetLoraScale to SetLoraWeight --- README.md | 5 +++-- Runtime/LLM.cs | 2 +- Tests/Runtime/TestLLM.cs | 6 +++--- 3 files changed, 7 
insertions(+), 6 deletions(-) diff --git a/README.md b/README.md index d179f17a..0cc071ef 100644 --- a/README.md +++ b/README.md @@ -249,8 +249,9 @@ public class MyScript : MonoBehaviour // The model needs to be added to the LLM model manager (see LLM model management) by loading or downloading it. // Otherwise the model file can be copied directly inside the StreamingAssets folder. llm.SetModel("Phi-3-mini-4k-instruct-q4.gguf"); - // optional: you can also set a lora in a similar fashion - llm.SetLora("my-lora.gguf"); + // optional: you can also set loras in a similar fashion and set their weights (if needed) + llm.AddLora("my-lora.gguf"); + llm.SetLoraScale(0.5f); // optional: you can set the chat template of the model if it is not correctly identified // You can find a list of chat templates in the ChatTemplate.templates.Keys llm.SetTemplate("phi-3"); diff --git a/Runtime/LLM.cs b/Runtime/LLM.cs index 5fe98492..af950cb7 100644 --- a/Runtime/LLM.cs +++ b/Runtime/LLM.cs @@ -262,7 +262,7 @@ public void RemoveLoras() /// /// path of LORA model to change (.gguf format) /// scale of LORA - public void SetLoraScale(string path, float scale) + public void SetLoraWeight(string path, float scale) { loraManager.SetWeight(path, scale); UpdateLoras(); diff --git a/Tests/Runtime/TestLLM.cs b/Tests/Runtime/TestLLM.cs index ef7f28f6..a0fe47a4 100644 --- a/Tests/Runtime/TestLLM.cs +++ b/Tests/Runtime/TestLLM.cs @@ -40,7 +40,7 @@ public void TestLLMLorasAssign() Assert.AreEqual(llm.lora, lora1 + " " + lora2); Assert.AreEqual(llm.loraWeights, "0.8 0.9"); - llm.SetLoraScale(lora2Rel, 0.7f); + llm.SetLoraWeight(lora2Rel, 0.7f); Assert.AreEqual(llm.lora, lora1 + " " + lora2); Assert.AreEqual(llm.loraWeights, "0.8 0.7"); @@ -49,11 +49,11 @@ public void TestLLMLorasAssign() Assert.AreEqual(llm.loraWeights, "0.8"); llm.AddLora(lora2Rel); - llm.SetLoraScale(lora2Rel, 0.5f); + llm.SetLoraWeight(lora2Rel, 0.5f); Assert.AreEqual(llm.lora, lora1 + " " + lora2); Assert.AreEqual(llm.loraWeights, "0.8 0.5"); - llm.SetLoraScale(lora2, 0.1f); + llm.SetLoraWeight(lora2, 0.1f); Assert.AreEqual(llm.lora, lora1 + " " + lora2); Assert.AreEqual(llm.loraWeights, "0.8 0.1"); From 70122d80ef7135b5188d0b6bcace71178eb79462 Mon Sep 17 00:00:00 2001 From: Antonis Makropoulos Date: Mon, 26 Aug 2024 13:40:31 +0300 Subject: [PATCH 55/67] add lora weights to readme --- README.md | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 0cc071ef..52ddcf4e 100644 --- a/README.md +++ b/README.md @@ -251,7 +251,7 @@ public class MyScript : MonoBehaviour llm.SetModel("Phi-3-mini-4k-instruct-q4.gguf"); // optional: you can also set loras in a similar fashion and set their weights (if needed) llm.AddLora("my-lora.gguf"); - llm.SetLoraScale(0.5f); + llm.SetLoraWeight(0.5f); // optional: you can set the chat template of the model if it is not correctly identified // You can find a list of chat templates in the ChatTemplate.templates.Keys llm.SetTemplate("phi-3"); @@ -384,7 +384,8 @@ If the user's GPU is not supported, the LLM will fall back to the CPU - `Batch Size` batch size for prompt processing (default: 512) - `Model` the path of the model being used (relative to the Assets/StreamingAssets folder) - `Chat Template` the chat template being used for the LLM - - `Lora` the path of the LoRA being used (relative to the Assets/StreamingAssets folder) + - `Lora` the path of the LoRAs being used (relative to the Assets/StreamingAssets folder) + - `Lora Weights` the weights of the LoRAs being used - `Flash 
Attention` click to use flash attention in the model (if `Use extras` is enabled) From f8eef51b9b3babc0535aa8cb82c611e91954720f Mon Sep 17 00:00:00 2001 From: Antonis Makropoulos Date: Mon, 26 Aug 2024 13:46:36 +0300 Subject: [PATCH 56/67] update changelogs --- CHANGELOG.md | 4 +++- CHANGELOG.release.md | 4 +++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 7c66ad7c..ee83d0df 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,9 +1,11 @@ ## v2.1.2 #### 🚀 Features +- Update to latest llama.cpp (b3617) (PR: #210) +- Integrate Llama 3.1 and Gemma2 models in model dropdown - Implement embedding and lora adapter functionality (PR: #210) - Read context length and warn if it is very large (PR: #211) -- Setup allowing to use extra features: flash attention and IQ quants (PR: #216) +- Setup to allow to use extra features: flash attention and IQ quants (PR: #216) - Allow HTTP request retries for remote server (PR: #217) - Allow to set lora weights at startup, add unit test (PR: #219) diff --git a/CHANGELOG.release.md b/CHANGELOG.release.md index c58310a9..ced75b15 100644 --- a/CHANGELOG.release.md +++ b/CHANGELOG.release.md @@ -1,8 +1,10 @@ ### 🚀 Features +- Update to latest llama.cpp (b3617) (PR: #210) +- Integrate Llama 3.1 and Gemma2 models in model dropdown - Implement embedding and lora adapter functionality (PR: #210) - Read context length and warn if it is very large (PR: #211) -- Setup allowing to use extra features: flash attention and IQ quants (PR: #216) +- Setup to allow to use extra features: flash attention and IQ quants (PR: #216) - Allow HTTP request retries for remote server (PR: #217) - Allow to set lora weights at startup, add unit test (PR: #219) From f68a16696955b3fd00bce50a508200180123e0d7 Mon Sep 17 00:00:00 2001 From: Antonis Makropoulos Date: Mon, 26 Aug 2024 13:47:58 +0300 Subject: [PATCH 57/67] add AI Speak game --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index 52ddcf4e..7f0b43d1 100644 --- a/README.md +++ b/README.md @@ -58,6 +58,7 @@ LLM for Unity is built on top of the awesome [llama.cpp](https://github.com/gger - [Murder in Aisle 4](https://roadedlich.itch.io/murder-in-aisle-4) - [Finicky Food Delivery AI](https://helixngc7293.itch.io/finicky-food-delivery-ai) - [AI Emotional Girlfriend](https://whynames.itch.io/aiemotionalgirlfriend) +- [AI Speak](https://jdscogin.wixsite.com/aispeak) ## Setup _Method 1: Install using the asset store_ From 651cd495a495b76e6595e02b44a2cf4ec574b883 Mon Sep 17 00:00:00 2001 From: amakropoulos Date: Mon, 26 Aug 2024 10:49:04 +0000 Subject: [PATCH 58/67] update VERSION --- .github/doxygen/Doxyfile | 2 +- Runtime/LLMUnitySetup.cs | 2 +- VERSION | 2 +- package.json | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/.github/doxygen/Doxyfile b/.github/doxygen/Doxyfile index 1bbe3027..05989312 100644 --- a/.github/doxygen/Doxyfile +++ b/.github/doxygen/Doxyfile @@ -48,7 +48,7 @@ PROJECT_NAME = "LLM for Unity" # could be handy for archiving the generated documentation or if some version # control system is used. -PROJECT_NUMBER = v2.1.2 +PROJECT_NUMBER = v2.2.0 # Using the PROJECT_BRIEF tag one can provide an optional one line description # for a project that appears at the top of each page and should give viewer a diff --git a/Runtime/LLMUnitySetup.cs b/Runtime/LLMUnitySetup.cs index 9638e59c..748eb4e8 100644 --- a/Runtime/LLMUnitySetup.cs +++ b/Runtime/LLMUnitySetup.cs @@ -85,7 +85,7 @@ public class LLMUnitySetup { // DON'T CHANGE! 
the version is autocompleted with a GitHub action /// LLM for Unity version - public static string Version = "v2.1.2"; + public static string Version = "v2.2.0"; /// LlamaLib version public static string LlamaLibVersion = "v1.1.9"; /// LlamaLib release url diff --git a/VERSION b/VERSION index 59696826..a4b6ac3d 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -v2.1.2 +v2.2.0 diff --git a/package.json b/package.json index a22c93a3..40182afe 100644 --- a/package.json +++ b/package.json @@ -1,6 +1,6 @@ { "name": "ai.undream.llm", - "version": "2.1.2", + "version": "2.2.0", "displayName": "LLM for Unity", "description": "LLM for Unity allows to run and distribute Large Language Models (LLMs) in the Unity engine.", "unity": "2022.3", From b14fdd526b5455c6f5bb47ac68e1eb3fc2c37e07 Mon Sep 17 00:00:00 2001 From: Antonis Makropoulos Date: Mon, 26 Aug 2024 19:30:46 +0300 Subject: [PATCH 59/67] add Embeddings to Readme --- README.md | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/README.md b/README.md index 7f0b43d1..c4792950 100644 --- a/README.md +++ b/README.md @@ -294,6 +294,15 @@ You can use a remote server to carry out the processing and implement characters - Create a second project with the game characters using the `LLMCharacter` script as described above. Enable the `Remote` option and configure the host with the IP address (starting with "http://") and port of the server. + +
+<details>
+<summary>Compute embeddings using a LLM</summary>
+
+The `Embeddings` function can be used to obtain the embeddings of a phrase:
+``` c#
+    List<float> embeddings = await llmCharacter.Embeddings("hi, how are you?");
+```
+</details>
A detailed documentation on function level can be found here: From b6c177da3c9947e08a86f58dff350ec3531b4e9c Mon Sep 17 00:00:00 2001 From: Antonis Makropoulos Date: Mon, 26 Aug 2024 20:30:15 +0300 Subject: [PATCH 60/67] bump LlamaLib to v1.1.10 --- Runtime/LLMUnitySetup.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Runtime/LLMUnitySetup.cs b/Runtime/LLMUnitySetup.cs index 748eb4e8..064a031d 100644 --- a/Runtime/LLMUnitySetup.cs +++ b/Runtime/LLMUnitySetup.cs @@ -87,7 +87,7 @@ public class LLMUnitySetup /// LLM for Unity version public static string Version = "v2.2.0"; /// LlamaLib version - public static string LlamaLibVersion = "v1.1.9"; + public static string LlamaLibVersion = "v1.1.10"; /// LlamaLib release url public static string LlamaLibReleaseURL = $"https://github.com/undreamai/LlamaLib/releases/download/{LlamaLibVersion}"; /// LlamaLib url From b3f0d50d652982def9adc2145a64177c93bb6af3 Mon Sep 17 00:00:00 2001 From: Antonis Makropoulos Date: Tue, 27 Aug 2024 12:29:42 +0300 Subject: [PATCH 61/67] add test for lora weight change --- Tests/Runtime/TestLLM.cs | 96 ++++++++++++++++++++++++---------------- 1 file changed, 57 insertions(+), 39 deletions(-) diff --git a/Tests/Runtime/TestLLM.cs b/Tests/Runtime/TestLLM.cs index a0fe47a4..768eacf9 100644 --- a/Tests/Runtime/TestLLM.cs +++ b/Tests/Runtime/TestLLM.cs @@ -10,10 +10,10 @@ namespace LLMUnityTests { - public class TestLLMLoras + public class TestLLMLoraAssignment { [Test] - public void TestLLMLorasAssign() + public void TestLoras() { GameObject gameObject = new GameObject(); gameObject.SetActive(false); @@ -70,7 +70,7 @@ public class TestLLM protected GameObject gameObject; protected LLM llm; protected LLMCharacter llmCharacter; - Exception error = null; + protected Exception error = null; protected string prompt; protected string query; protected string reply1; @@ -206,24 +206,7 @@ public async Task RunTestsTask() error = null; try { - await llmCharacter.Tokenize("I", TestTokens); - await llmCharacter.Warmup(); - TestInitParameters(tokens1, 1); - TestWarmup(); - await llmCharacter.Chat(query, (string reply) => TestChat(reply, reply1)); - TestPostChat(3); - llmCharacter.SetPrompt(llmCharacter.prompt); - llmCharacter.AIName = "False response"; - await llmCharacter.Chat(query, (string reply) => TestChat(reply, reply2)); - TestPostChat(3); - await llmCharacter.Chat("bye!"); - TestPostChat(5); - prompt = "How are you?"; - llmCharacter.SetPrompt(prompt); - await llmCharacter.Chat("hi"); - TestInitParameters(tokens2, 3); - List embeddings = await llmCharacter.Embeddings("hi how are you?"); - TestEmbeddings(embeddings); + await Tests(); llm.OnDestroy(); } catch (Exception e) @@ -232,6 +215,28 @@ public async Task RunTestsTask() } } + public virtual async Task Tests() + { + await llmCharacter.Tokenize("I", TestTokens); + await llmCharacter.Warmup(); + TestInitParameters(tokens1, 1); + TestWarmup(); + await llmCharacter.Chat(query, (string reply) => TestChat(reply, reply1)); + TestPostChat(3); + llmCharacter.SetPrompt(llmCharacter.prompt); + llmCharacter.AIName = "False response"; + await llmCharacter.Chat(query, (string reply) => TestChat(reply, reply2)); + TestPostChat(3); + await llmCharacter.Chat("bye!"); + TestPostChat(5); + prompt = "How are you?"; + llmCharacter.SetPrompt(prompt); + await llmCharacter.Chat("hi"); + TestInitParameters(tokens2, 3); + List embeddings = await llmCharacter.Embeddings("hi how are you?"); + TestEmbeddings(embeddings); + } + public void TestInitParameters(int nkeep, int chats) { 
Assert.That(llmCharacter.nKeep == nkeep); @@ -316,23 +321,11 @@ public override LLM CreateLLM() } } - public class TestLLM_NoLora : TestLLM - { - public override void SetParameters() - { - prompt = ""; - query = "кто ты?"; - reply1 = "Я - искусственный интеллект, который помогаю вам с информацией и задачами"; - reply2 = "I'm sorry, but I didn't understand your request. Could you please provide more information or clarify"; - tokens1 = 5; - tokens2 = 9; - } - } - public class TestLLM_Lora : TestLLM { - string loraUrl = "https://huggingface.co/undreamer/Qwen2-0.5B-Instruct-ru-lora/resolve/main/Qwen2-0.5B-Instruct-ru-lora.gguf?download=true"; - string loraNameLLManager; + protected string loraUrl = "https://huggingface.co/undreamer/Qwen2-0.5B-Instruct-ru-lora/resolve/main/Qwen2-0.5B-Instruct-ru-lora.gguf?download=true"; + protected string loraNameLLManager; + protected float loraWeight; public override async Task DownloadModels() { @@ -343,7 +336,7 @@ public override async Task DownloadModels() public override LLM CreateLLM() { LLM llm = base.CreateLLM(); - llm.AddLora(loraNameLLManager); + llm.AddLora(loraNameLLManager, loraWeight); return llm; } @@ -351,21 +344,46 @@ public override void SetParameters() { prompt = ""; query = "кто ты?"; - reply1 = "Я - искусственный интеллект, созданный для помощи и общения с людьми"; + reply1 = "Я - искусственный интеллект, создан для общения и понимания людей."; reply2 = "Идиот"; tokens1 = 5; tokens2 = 9; + loraWeight = 0.9f; + } + + public override async Task Tests() + { + await base.Tests(); + TestModelPaths(); + await TestLoraWeight(); } - [Test] public void TestModelPaths() { Assert.AreEqual(llm.model, Path.Combine(LLMUnitySetup.modelDownloadPath, Path.GetFileName(modelUrl).Split("?")[0])); Assert.AreEqual(llm.lora, Path.Combine(LLMUnitySetup.modelDownloadPath, Path.GetFileName(loraUrl).Split("?")[0])); } + + public async Task TestLoraWeight() + { + string json = await llm.ListLoras(); + LoraWeightResultList loraRequest = JsonUtility.FromJson("{\"loraWeights\": " + json + "}"); + Assert.AreEqual(loraRequest.loraWeights[0].scale, loraWeight); + } } + public class TestLLM_Lora_ChangeWeight : TestLLM_Lora + { + public override async Task Tests() + { + await base.Tests(); + loraWeight = 0.6f; + llm.SetLoraWeight(loraNameLLManager, loraWeight); + await TestLoraWeight(); + } + } + public class TestLLM_Double : TestLLM { LLM llm1; From 53e9ae5ffa36de5ef4d22469ff4b96b3b323cb1a Mon Sep 17 00:00:00 2001 From: Antonis Makropoulos Date: Tue, 27 Aug 2024 12:30:00 +0300 Subject: [PATCH 62/67] add test for remote setup --- Tests/Runtime/TestLLM.cs | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/Tests/Runtime/TestLLM.cs b/Tests/Runtime/TestLLM.cs index 768eacf9..6dcb75fa 100644 --- a/Tests/Runtime/TestLLM.cs +++ b/Tests/Runtime/TestLLM.cs @@ -321,6 +321,23 @@ public override LLM CreateLLM() } } + public class TestLLM_Remote : TestLLM + { + public override LLM CreateLLM() + { + LLM llm = base.CreateLLM(); + llm.remote = true; + return llm; + } + + public override LLMCharacter CreateLLMCharacter() + { + LLMCharacter llmCharacter = base.CreateLLMCharacter(); + llmCharacter.remote = true; + return llmCharacter; + } + } + public class TestLLM_Lora : TestLLM { protected string loraUrl = "https://huggingface.co/undreamer/Qwen2-0.5B-Instruct-ru-lora/resolve/main/Qwen2-0.5B-Instruct-ru-lora.gguf?download=true"; From c7b8da63cad701ba93d667eb1691af82c38c0da3 Mon Sep 17 00:00:00 2001 From: Antonis Makropoulos Date: Tue, 27 Aug 2024 12:34:30 +0300 Subject: 
[PATCH 63/67] return list of loras instead of json

---
 Runtime/LLM.cs           |  7 +++++--
 Runtime/LLMInterface.cs  | 14 ++++++++++++++
 Tests/Runtime/TestLLM.cs |  5 ++---
 3 files changed, 21 insertions(+), 5 deletions(-)

diff --git a/Runtime/LLM.cs b/Runtime/LLM.cs
index af950cb7..b511db6a 100644
--- a/Runtime/LLM.cs
+++ b/Runtime/LLM.cs
@@ -608,14 +608,17 @@ public void ApplyLoras()
     /// <summary>
     /// Gets a list of the lora adapters
     /// </summary>
     /// <returns>list of lora adapters</returns>
-    public async Task<string> ListLoras()
+    public async Task<List<LoraWeightResult>> ListLoras()
     {
         AssertStarted();
         LLMNoInputReplyCallback callback = (IntPtr LLMObject, IntPtr strWrapper) =>
         {
             llmlib.LLM_LoraList(LLMObject, strWrapper);
         };
-        return await LLMNoInputReply(callback);
+        string json = await LLMNoInputReply(callback);
+        if (String.IsNullOrEmpty(json)) return null;
+        LoraWeightResultList loraRequest = JsonUtility.FromJson<LoraWeightResultList>("{\"loraWeights\": " + json + "}");
+        return loraRequest.loraWeights;
     }
 
     /// <summary>
diff --git a/Runtime/LLMInterface.cs b/Runtime/LLMInterface.cs
index 36a11c37..a9090632 100644
--- a/Runtime/LLMInterface.cs
+++ b/Runtime/LLMInterface.cs
@@ -112,6 +112,20 @@ public struct LoraWeightRequestList
         public List<LoraWeightRequest> loraWeights;
     }
 
+    [Serializable]
+    public struct LoraWeightResult
+    {
+        public int id;
+        public string path;
+        public float scale;
+    }
+
+    [Serializable]
+    public struct LoraWeightResultList
+    {
+        public List<LoraWeightResult> loraWeights;
+    }
+
     [Serializable]
     public struct TemplateResult
     {
diff --git a/Tests/Runtime/TestLLM.cs b/Tests/Runtime/TestLLM.cs
index 6dcb75fa..16bf37c5 100644
--- a/Tests/Runtime/TestLLM.cs
+++ b/Tests/Runtime/TestLLM.cs
@@ -383,9 +383,8 @@ public void TestModelPaths()
 
     public async Task TestLoraWeight()
     {
-        string json = await llm.ListLoras();
-        LoraWeightResultList loraRequest = JsonUtility.FromJson<LoraWeightResultList>("{\"loraWeights\": " + json + "}");
-        Assert.AreEqual(loraRequest.loraWeights[0].scale, loraWeight);
+        List<LoraWeightResult> loras = await llm.ListLoras();
+        Assert.AreEqual(loras[0].scale, loraWeight);
     }
 }

From c93a61877b0293f332b0dc6f9cc1c284f0888f11 Mon Sep 17 00:00:00 2001
From: Antonis Makropoulos
Date: Tue, 27 Aug 2024 12:44:14 +0300
Subject: [PATCH 64/67] add function to change multiple lora weights

---
 Runtime/LLM.cs | 19 +++++++++++++++----
 1 file changed, 15 insertions(+), 4 deletions(-)

diff --git a/Runtime/LLM.cs b/Runtime/LLM.cs
index b511db6a..95c51cf2 100644
--- a/Runtime/LLM.cs
+++ b/Runtime/LLM.cs
@@ -258,13 +258,24 @@ public void RemoveLoras()
     }
 
     /// <summary>
-    /// Allows to change the scale (weight) of a LORA model in the LLM.
+    /// Allows to change the weight (scale) of a LORA model in the LLM.
     /// </summary>
     /// <param name="path">path of LORA model to change (.gguf format)</param>
-    /// <param name="scale">scale of LORA</param>
-    public void SetLoraWeight(string path, float scale)
+    /// <param name="weight">weight of LORA</param>
+    public void SetLoraWeight(string path, float weight)
     {
-        loraManager.SetWeight(path, scale);
+        loraManager.SetWeight(path, weight);
+        UpdateLoras();
+        if (started) ApplyLoras();
+    }
+
+    /// <summary>
+    /// Allows to change the weights (scale) of the LORA models in the LLM.
+    /// </summary>
+    /// <param name="loraToWeight">Dictionary (string, float) mapping the path of LORA models with weights to change</param>
+    public void SetLoraWeights(Dictionary<string, float> loraToWeight)
+    {
+        foreach (KeyValuePair<string, float> entry in loraToWeight) loraManager.SetWeight(entry.Key, entry.Value);
         UpdateLoras();
         if (started) ApplyLoras();
     }

From a4165aa1829984663c0e10a5ca7e02bbaf2e2a57 Mon Sep 17 00:00:00 2001
From: Antonis Makropoulos
Date: Tue, 27 Aug 2024 12:44:25 +0300
Subject: [PATCH 65/67] add test for function to change multiple lora weights

---
 Tests/Runtime/TestLLM.cs | 7 +++++++
 1 file changed, 7 insertions(+)

diff --git a/Tests/Runtime/TestLLM.cs b/Tests/Runtime/TestLLM.cs
index 16bf37c5..edf0907b 100644
--- a/Tests/Runtime/TestLLM.cs
+++ b/Tests/Runtime/TestLLM.cs
@@ -57,6 +57,13 @@ public void TestLoras()
         Assert.AreEqual(llm.lora, lora1 + " " + lora2);
         Assert.AreEqual(llm.loraWeights, "0.8 0.1");
 
+        Dictionary<string, float> loraToWeight = new Dictionary<string, float>();
+        loraToWeight[lora1] = 0;
+        loraToWeight[lora2] = 0.2f;
+        llm.SetLoraWeights(loraToWeight);
+        Assert.AreEqual(llm.lora, lora1 + " " + lora2);
+        Assert.AreEqual(llm.loraWeights, "0 0.2");
+
         File.Delete(lora1);
         File.Delete(lora2);
     }

From cca81b869a12ee3b306d44591d2108fab73dc5d5 Mon Sep 17 00:00:00 2001
From: Antonis Makropoulos
Date: Tue, 27 Aug 2024 13:21:30 +0300
Subject: [PATCH 66/67] allow relative StreamingAssets paths for models

---
 Runtime/LLMManager.cs | 13 ++++++++++++-
 1 file changed, 12 insertions(+), 1 deletion(-)

diff --git a/Runtime/LLMManager.cs b/Runtime/LLMManager.cs
index 34301fac..f08edc87 100644
--- a/Runtime/LLMManager.cs
+++ b/Runtime/LLMManager.cs
@@ -19,9 +19,20 @@ public class ModelEntry
         public bool includeInBuild;
         public int contextLength;
 
+        public static string GetFilenameOrRelativeAssetPath(string path)
+        {
+            string assetPath = LLMUnitySetup.GetAssetPath(path); // Note: this will return the full path if a full path is passed
+            string basePath = LLMUnitySetup.GetAssetPath();
+            if (File.Exists(assetPath) && LLMUnitySetup.IsSubPath(assetPath, basePath))
+            {
+                return LLMUnitySetup.RelativePath(assetPath, basePath);
+            }
+            return path;
+        }
+
         public ModelEntry(string path, bool lora = false, string label = null, string url = null)
         {
-            filename = Path.GetFileName(path);
+            filename = GetFilenameOrRelativeAssetPath(path);
             this.label = label == null ? filename : label;
            this.lora = lora;
            this.path = LLMUnitySetup.GetFullPath(path);

From fdd5e85104f6304a8c6ea36cf100e35676b433fb Mon Sep 17 00:00:00 2001
From: amakropoulos
Date: Tue, 27 Aug 2024 10:22:48 +0000
Subject: [PATCH 67/67] update changelogs

---
 CHANGELOG.md         | 7 +++----
 CHANGELOG.release.md | 5 ++---
 2 files changed, 5 insertions(+), 7 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index ee83d0df..5562a81e 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,13 +1,12 @@
-## v2.1.2
+## v2.2.0
 #### 🚀 Features
 
-- Update to latest llama.cpp (b3617) (PR: #210)
-- Integrate Llama 3.1 and Gemma2 models in model dropdown
 - Implement embedding and lora adapter functionality (PR: #210)
 - Read context length and warn if it is very large (PR: #211)
-- Setup to allow to use extra features: flash attention and IQ quants (PR: #216)
+- Setup allowing to use extra features: flash attention and IQ quants (PR: #216)
 - Allow HTTP request retries for remote server (PR: #217)
 - Allow to set lora weights at startup, add unit test (PR: #219)
+- Allow relative StreamingAssets paths for models (PR: #221)
 
 #### 🐛 Fixes
 
diff --git a/CHANGELOG.release.md b/CHANGELOG.release.md
index ced75b15..d405932e 100644
--- a/CHANGELOG.release.md
+++ b/CHANGELOG.release.md
@@ -1,12 +1,11 @@
 ### 🚀 Features
 
-- Update to latest llama.cpp (b3617) (PR: #210)
-- Integrate Llama 3.1 and Gemma2 models in model dropdown
 - Implement embedding and lora adapter functionality (PR: #210)
 - Read context length and warn if it is very large (PR: #211)
-- Setup to allow to use extra features: flash attention and IQ quants (PR: #216)
+- Setup allowing to use extra features: flash attention and IQ quants (PR: #216)
 - Allow HTTP request retries for remote server (PR: #217)
 - Allow to set lora weights at startup, add unit test (PR: #219)
+- Allow relative StreamingAssets paths for models (PR: #221)
 
 ### 🐛 Fixes
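Taken together, the LoRA patches above define the public adapter API (`AddLora`, `SetLoraWeight`, `SetLoraWeights`, `ListLoras`). The following is a minimal usage sketch for reference; it is not part of any patch, the model and adapter filenames are placeholders, and it assumes the adapters are added before the LLM starts, as the unit tests above do.

``` c#
using System.Collections.Generic;
using UnityEngine;
using LLMUnity;

public class LoraUsageSketch : MonoBehaviour
{
    LLM llm;

    void Setup()
    {
        // adapters are registered before the LLM starts, with optional initial weights
        llm = gameObject.AddComponent<LLM>();
        llm.SetModel("Phi-3-mini-4k-instruct-q4.gguf");   // placeholder model
        llm.AddLora("my-lora.gguf", 0.8f);                // placeholder adapter
    }

    public async void AdjustWeights()
    {
        // once the server runs, weight changes are pushed live through ApplyLoras
        llm.SetLoraWeight("my-lora.gguf", 0.5f);
        llm.SetLoraWeights(new Dictionary<string, float> { { "my-lora.gguf", 0.3f } });

        // ListLoras reports the weights as the server sees them
        List<LoraWeightResult> loras = await llm.ListLoras();
        foreach (LoraWeightResult lora in loras) Debug.Log($"{lora.path}: {lora.scale}");
    }
}
```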