From ff3e23c4a13ed676f51e93c424b7485527109007 Mon Sep 17 00:00:00 2001 From: Antonis Makropoulos Date: Sun, 18 Aug 2024 18:05:16 +0300 Subject: [PATCH 1/8] bump LlamaLib to v1.1.8 --- Runtime/LLMUnitySetup.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Runtime/LLMUnitySetup.cs b/Runtime/LLMUnitySetup.cs index 833f0ce7..4ee87a66 100644 --- a/Runtime/LLMUnitySetup.cs +++ b/Runtime/LLMUnitySetup.cs @@ -86,7 +86,7 @@ public class LLMUnitySetup /// LLM for Unity version public static string Version = "v2.1.2"; /// LlamaLib version - public static string LlamaLibVersion = "v1.1.6"; + public static string LlamaLibVersion = "v1.1.8"; /// LlamaLib url public static string LlamaLibURL = $"https://github.com/undreamai/LlamaLib/releases/download/{LlamaLibVersion}/undreamai-{LlamaLibVersion}-llamacpp.zip"; /// LlamaLib path From 37b35e2b22c9bc91af24a700e649359ef039f804 Mon Sep 17 00:00:00 2001 From: Antonis Makropoulos Date: Mon, 19 Aug 2024 15:54:01 +0300 Subject: [PATCH 2/8] add embedding functionality --- Runtime/LLMCharacter.cs | 34 +++++++++++++++++++++++++++++----- 1 file changed, 29 insertions(+), 5 deletions(-) diff --git a/Runtime/LLMCharacter.cs b/Runtime/LLMCharacter.cs index 9f37721a..c76239aa 100644 --- a/Runtime/LLMCharacter.cs +++ b/Runtime/LLMCharacter.cs @@ -430,16 +430,22 @@ protected List TokenizeContent(TokenizeResult result) return result.tokens; } - protected string SlotContent(SlotResult result) + protected string DetokenizeContent(TokenizeRequest result) { - // get the tokens from a tokenize result received from the endpoint - return result.filename; + // get content from a chat result received from the endpoint + return result.content; } - protected string DetokenizeContent(TokenizeRequest result) + protected List EmbeddingsContent(EmbeddingsResult result) { // get content from a chat result received from the endpoint - return result.content; + return result.embedding; + } + + protected string SlotContent(SlotResult result) + { + 
// get the tokens from a tokenize result received from the endpoint + return result.filename; } /// @@ -572,6 +578,21 @@ public async Task Detokenize(List tokens, Callback callback return await PostRequest(json, "detokenize", DetokenizeContent, callback); } + /// + /// Computes the embeddings of the provided input. + /// + /// input to compute the embeddings for + /// callback function called with the result string + /// the computed embeddings + public async Task> Embeddings(string query, Callback> callback = null) + { + // handle the tokenization of a message by the user + TokenizeRequest tokenizeRequest = new TokenizeRequest(); + tokenizeRequest.content = query; + string json = JsonUtility.ToJson(tokenizeRequest); + return await PostRequest>(json, "embeddings", EmbeddingsContent, callback); + } + private async Task Slot(string filepath, string action) { SlotRequest slotRequest = new SlotRequest(); @@ -682,6 +703,9 @@ protected async Task PostRequestLocal(string json, string endpoin case "detokenize": callResult = await llm.Detokenize(json); break; + case "embeddings": + callResult = await llm.Embeddings(json); + break; case "slots": callResult = await llm.Slot(json); break; From 9c35b06281c76be53b1b5cdd2373072dc146edd6 Mon Sep 17 00:00:00 2001 From: Antonis Makropoulos Date: Mon, 19 Aug 2024 15:54:30 +0300 Subject: [PATCH 3/8] import embeddings and lora adapters --- Runtime/LLMLib.cs | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/Runtime/LLMLib.cs b/Runtime/LLMLib.cs index c8d49ccd..fbaa050d 100644 --- a/Runtime/LLMLib.cs +++ b/Runtime/LLMLib.cs @@ -298,6 +298,9 @@ public LLMLib(string arch) LLM_SetTemplate = LibraryLoader.GetSymbolDelegate(libraryHandle, "LLM_SetTemplate"); LLM_Tokenize = LibraryLoader.GetSymbolDelegate(libraryHandle, "LLM_Tokenize"); LLM_Detokenize = LibraryLoader.GetSymbolDelegate(libraryHandle, "LLM_Detokenize"); + LLM_Embeddings = LibraryLoader.GetSymbolDelegate(libraryHandle, "LLM_Embeddings"); + LLM_Lora_Weight = 
LibraryLoader.GetSymbolDelegate(libraryHandle, "LLM_Lora_Weight"); + LLM_LoraList = LibraryLoader.GetSymbolDelegate(libraryHandle, "LLM_Lora_List"); LLM_Completion = LibraryLoader.GetSymbolDelegate(libraryHandle, "LLM_Completion"); LLM_Slot = LibraryLoader.GetSymbolDelegate(libraryHandle, "LLM_Slot"); LLM_Cancel = LibraryLoader.GetSymbolDelegate(libraryHandle, "LLM_Cancel"); @@ -452,6 +455,9 @@ public string GetStringWrapperResult(IntPtr stringWrapper) public delegate void LLM_SetTemplateDelegate(IntPtr LLMObject, string chatTemplate); public delegate void LLM_TokenizeDelegate(IntPtr LLMObject, string jsonData, IntPtr stringWrapper); public delegate void LLM_DetokenizeDelegate(IntPtr LLMObject, string jsonData, IntPtr stringWrapper); + public delegate void LLM_EmbeddingsDelegate(IntPtr LLMObject, string jsonData, IntPtr stringWrapper); + public delegate void LLM_LoraWeightDelegate(IntPtr LLMObject, string jsonData, IntPtr stringWrapper); + public delegate void LLM_LoraListDelegate(IntPtr LLMObject, IntPtr stringWrapper); public delegate void LLM_CompletionDelegate(IntPtr LLMObject, string jsonData, IntPtr stringWrapper); public delegate void LLM_SlotDelegate(IntPtr LLMObject, string jsonData, IntPtr stringWrapper); public delegate void LLM_CancelDelegate(IntPtr LLMObject, int idSlot); @@ -474,6 +480,9 @@ public string GetStringWrapperResult(IntPtr stringWrapper) public LLM_TokenizeDelegate LLM_Tokenize; public LLM_DetokenizeDelegate LLM_Detokenize; public LLM_CompletionDelegate LLM_Completion; + public LLM_EmbeddingsDelegate LLM_Embeddings; + public LLM_LoraWeightDelegate LLM_Lora_Weight; + public LLM_LoraListDelegate LLM_LoraList; public LLM_SlotDelegate LLM_Slot; public LLM_CancelDelegate LLM_Cancel; public LLM_StatusDelegate LLM_Status; From 47d0eb997153d55d98e12d03d9f359ef7dd4cdd8 Mon Sep 17 00:00:00 2001 From: Antonis Makropoulos Date: Mon, 19 Aug 2024 15:54:53 +0300 Subject: [PATCH 4/8] add structs for embeddings and lora adapters --- Runtime/LLMInterface.cs 
| 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/Runtime/LLMInterface.cs b/Runtime/LLMInterface.cs index f5086595..36a11c37 100644 --- a/Runtime/LLMInterface.cs +++ b/Runtime/LLMInterface.cs @@ -93,6 +93,25 @@ public struct TokenizeResult public List tokens; } + [Serializable] + public struct EmbeddingsResult + { + public List embedding; + } + + [Serializable] + public struct LoraWeightRequest + { + public int id; + public float scale; + } + + [Serializable] + public struct LoraWeightRequestList + { + public List loraWeights; + } + [Serializable] public struct TemplateResult { From f5d910a0a674755a15dc11c3aed712f922378de1 Mon Sep 17 00:00:00 2001 From: Antonis Makropoulos Date: Mon, 19 Aug 2024 15:55:10 +0300 Subject: [PATCH 5/8] allow multiple loras --- Editor/LLMEditor.cs | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/Editor/LLMEditor.cs b/Editor/LLMEditor.cs index bebaf578..ed50046d 100644 --- a/Editor/LLMEditor.cs +++ b/Editor/LLMEditor.cs @@ -142,7 +142,7 @@ void SetModelIfNone(string filename, bool lora) LLM llmScript = (LLM)target; int num = LLMManager.Num(lora); if (!lora && llmScript.model == "" && num == 1) llmScript.SetModel(filename); - if (lora && llmScript.lora == "" && num == 1) llmScript.SetLora(filename); + if (lora) llmScript.AddLora(filename); } async Task createCustomURLField() @@ -237,7 +237,7 @@ async Task createButtons() { EditorApplication.delayCall += () => { - string path = EditorUtility.OpenFilePanelWithFilters("Select a bin lora file", "", new string[] { "Model Files", "bin" }); + string path = EditorUtility.OpenFilePanelWithFilters("Select a gguf lora file", "", new string[] { "Model Files", "gguf" }); if (!string.IsNullOrEmpty(path)) { string filename = LLMManager.LoadLora(path, true); @@ -299,10 +299,10 @@ void OnEnable() } else { - isSelected = llmScript.lora == entry.filename; - bool newSelected = EditorGUI.Toggle(selectRect, isSelected, EditorStyles.radioButton); - if 
(newSelected && !isSelected) llmScript.SetLora(entry.filename); - else if (!newSelected && isSelected) llmScript.SetLora(""); + isSelected = llmScript.lora.Split(" ").Contains(entry.filename); + bool newSelected = EditorGUI.Toggle(selectRect, isSelected); + if (newSelected && !isSelected) llmScript.AddLora(entry.filename); + else if (!newSelected && isSelected) llmScript.RemoveLora(entry.filename); } DrawCopyableLabel(nameRect, entry.label, entry.filename); @@ -347,6 +347,11 @@ void OnEnable() if (GUI.Button(actionRect, trashIcon)) { + if (isSelected) + { + if (!entry.lora) llmScript.SetModel(""); + else llmScript.RemoveLora(entry.filename); + } LLMManager.Remove(entry); UpdateModels(true); } From 3e95f4b3b3db17700dd33ea6de81bf477a226995 Mon Sep 17 00:00:00 2001 From: Antonis Makropoulos Date: Mon, 19 Aug 2024 15:56:05 +0300 Subject: [PATCH 6/8] implement callback functionality for embeddings and lora adapters --- Runtime/LLM.cs | 138 +++++++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 129 insertions(+), 9 deletions(-) diff --git a/Runtime/LLM.cs b/Runtime/LLM.cs index 315c5cd9..fad6b5d6 100644 --- a/Runtime/LLM.cs +++ b/Runtime/LLM.cs @@ -3,6 +3,7 @@ using System; using System.Collections.Generic; using System.IO; +using System.Linq; using System.Threading; using System.Threading.Tasks; using UnityEditor; @@ -68,7 +69,7 @@ public class LLM : MonoBehaviour [ModelAdvanced] public string model = ""; /// Chat template used for the model [ModelAdvanced] public string chatTemplate = ChatTemplate.DefaultTemplate; - /// the path of the LORA model being used (relative to the Assets/StreamingAssets folder). + /// the paths of the LORA models being used (relative to the Assets/StreamingAssets folder). /// Models with .bin format are allowed. 
[ModelAdvanced] public string lora = ""; @@ -81,6 +82,7 @@ public class LLM : MonoBehaviour Thread llmThread = null; List streamWrappers = new List(); public LLMManager llmManager = new LLMManager(); + List loraWeights = new List(); /// \endcond @@ -135,7 +137,7 @@ public string GetModelLoraPath(string path) return path; } - public string SetModelLoraPath(string path, bool lora) + public string GetModelLoraPath(string path, bool lora) { if (string.IsNullOrEmpty(path)) return path; ModelEntry modelEntry = LLMManager.Get(path); @@ -167,7 +169,7 @@ public string SetModelLoraPath(string path, bool lora) /// path to model to use (.gguf format) public void SetModel(string path) { - model = SetModelLoraPath(path, false); + model = GetModelLoraPath(path, false); if (!string.IsNullOrEmpty(model)) { ModelEntry modelEntry = LLMManager.Get(model); @@ -187,7 +189,43 @@ public void SetModel(string path) /// path to LORA model to use (.bin format) public void SetLora(string path) { - lora = SetModelLoraPath(path, true); + lora = ""; + AddLora(path); + } + + /// + /// Allows to add a LORA model to use in the LLM. + /// The model provided is copied to the Assets/StreamingAssets folder that allows it to also work in the build. + /// Models supported are in .bin format. + /// + /// path to LORA model to use (.bin format) + public void AddLora(string path) + { + string loraPath = GetModelLoraPath(path, true); + if (lora.Split(" ").Contains(loraPath)) return; + if (lora != "") lora += " "; + lora += loraPath; +#if UNITY_EDITOR + if (!EditorApplication.isPlaying) EditorUtility.SetDirty(this); +#endif + } + + /// + /// Allows to remove a LORA model from the LLM. + /// Models supported are in .bin format. 
+ /// + /// path to LORA model to remove (.bin format) + public void RemoveLora(string path) + { + string loraPath = GetModelLoraPath(path, true); + List loras = new List(lora.Split(" ")); + loras.Remove(loraPath); + lora = ""; + for (int i = 0; i < loras.Count; i++) + { + if (i > 0) lora += " "; + lora += loras[i]; + } #if UNITY_EDITOR if (!EditorApplication.isPlaying) EditorUtility.SetDirty(this); #endif @@ -229,15 +267,16 @@ protected virtual string GetLlamaccpArguments() LLMUnitySetup.LogError($"File {modelPath} not found!"); return null; } - string loraPath = ""; - if (lora != "") + string loraArgument = ""; + foreach (string lora in lora.Split(" ")) { - loraPath = GetModelLoraPath(lora); + string loraPath = GetModelLoraPath(lora); if (!File.Exists(loraPath)) { LLMUnitySetup.LogError($"File {loraPath} not found!"); return null; } + loraArgument += $" --lora \"{loraPath}\""; } int numThreadsToUse = numThreads; @@ -247,7 +286,7 @@ protected virtual string GetLlamaccpArguments() string arguments = $"-m \"{modelPath}\" -c {contextSize} -b {batchSize} --log-disable -np {slots}"; if (remote) arguments += $" --port {port} --host 0.0.0.0"; if (numThreadsToUse > 0) arguments += $" -t {numThreadsToUse}"; - if (loraPath != "") arguments += $" --lora \"{loraPath}\""; + arguments += loraArgument; arguments += $" -ngl {numGPULayers}"; return arguments; } @@ -323,6 +362,8 @@ private void StartService() llmThread = new Thread(() => llmlib.LLM_Start(LLMObject)); llmThread.Start(); while (!llmlib.LLM_Started(LLMObject)) {} + loraWeights = new List(); + for (int i = 0; i < lora.Split(" ").Count(); i++) loraWeights.Add(1f); started = true; } @@ -345,7 +386,7 @@ protected int GetNumClients() /// \cond HIDE public delegate void LLMStatusCallback(IntPtr LLMObject, IntPtr stringWrapper); - public delegate void LLMSimpleCallback(IntPtr LLMObject, string json_data); + public delegate void LLMNoInputReplyCallback(IntPtr LLMObject, IntPtr stringWrapper); public delegate void 
LLMReplyCallback(IntPtr LLMObject, string json_data, IntPtr stringWrapper); /// \endcond @@ -400,6 +441,17 @@ void CheckLLMStatus(bool log = true) } } + async Task LLMNoInputReply(LLMNoInputReplyCallback callback) + { + AssertStarted(); + IntPtr stringWrapper = llmlib.StringWrapper_Construct(); + await Task.Run(() => callback(LLMObject, stringWrapper)); + string result = llmlib?.GetStringWrapperResult(stringWrapper); + llmlib?.StringWrapper_Delete(stringWrapper); + CheckLLMStatus(); + return result; + } + async Task LLMReply(LLMReplyCallback callback, string json) { AssertStarted(); @@ -441,6 +493,74 @@ public async Task Detokenize(string json) return await LLMReply(callback, json); } + /// + /// Computes the embeddings of the provided query. + /// + /// json request containing the query + /// embeddings result + public async Task Embeddings(string json) + { + AssertStarted(); + LLMReplyCallback callback = (IntPtr LLMObject, string jsonData, IntPtr strWrapper) => + { + llmlib.LLM_Embeddings(LLMObject, jsonData, strWrapper); + }; + return await LLMReply(callback, json); + } + + /// + /// Sets the lora scale, only works after the LLM service has started + /// + /// switch result + public async Task SetLoraScale(string loraToScale, float scale) + { + AssertStarted(); + List loras = new List(lora.Split(" ")); + string loraToScalePath = GetModelLoraPath(loraToScale, true); + + int index = loras.IndexOf(loraToScale); + if (index == -1) index = loras.IndexOf(loraToScalePath); + if (index == -1) + { + LLMUnitySetup.LogError($"LoRA {loraToScale} not loaded with the LLM"); + return ""; + } + + loraWeights[index] = scale; + LoraWeightRequestList loraWeightRequest = new LoraWeightRequestList(); + loraWeightRequest.loraWeights = new List(); + for (int i = 0; i < loraWeights.Count; i++) + { + loraWeightRequest.loraWeights.Add(new LoraWeightRequest() {id = i, scale = loraWeights[i]}); + } + ; + + string json = JsonUtility.ToJson(loraWeightRequest); + int startIndex = 
json.IndexOf("["); + int endIndex = json.LastIndexOf("]") + 1; + json = json.Substring(startIndex, endIndex - startIndex); + + LLMReplyCallback callback = (IntPtr LLMObject, string jsonData, IntPtr strWrapper) => + { + llmlib.LLM_Lora_Weight(LLMObject, jsonData, strWrapper); + }; + return await LLMReply(callback, json); + } + + /// + /// Gets a list of the lora adapters + /// + /// list of lara adapters + public async Task ListLora() + { + AssertStarted(); + LLMNoInputReplyCallback callback = (IntPtr LLMObject, IntPtr strWrapper) => + { + llmlib.LLM_LoraList(LLMObject, strWrapper); + }; + return await LLMNoInputReply(callback); + } + /// /// Allows to save / restore the state of a slot /// From bf3ea60d1df7bcfbc292b35790f231b6e9e65e43 Mon Sep 17 00:00:00 2001 From: Antonis Makropoulos Date: Mon, 19 Aug 2024 16:48:13 +0300 Subject: [PATCH 7/8] fix for lora splitting --- Runtime/LLM.cs | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/Runtime/LLM.cs b/Runtime/LLM.cs index fad6b5d6..a691cf29 100644 --- a/Runtime/LLM.cs +++ b/Runtime/LLM.cs @@ -130,7 +130,7 @@ public static async Task WaitUntilModelSetup(Callback downloadProgr return !modelSetupFailed; } - public string GetModelLoraPath(string path) + public string GetModelLoraPathRuntime(string path) { string assetPath = LLMManager.GetAssetPath(path); if (!string.IsNullOrEmpty(assetPath)) return assetPath; @@ -173,7 +173,7 @@ public void SetModel(string path) if (!string.IsNullOrEmpty(model)) { ModelEntry modelEntry = LLMManager.Get(model); - string template = modelEntry != null ? modelEntry.chatTemplate : ChatTemplate.FromGGUF(GetModelLoraPath(model)); + string template = modelEntry != null ? 
modelEntry.chatTemplate : ChatTemplate.FromGGUF(GetModelLoraPathRuntime(model)); SetTemplate(template); } #if UNITY_EDITOR @@ -261,16 +261,17 @@ protected virtual string GetLlamaccpArguments() LLMUnitySetup.LogError("No model file provided!"); return null; } - string modelPath = GetModelLoraPath(model); + string modelPath = GetModelLoraPathRuntime(model); if (!File.Exists(modelPath)) { LLMUnitySetup.LogError($"File {modelPath} not found!"); return null; } string loraArgument = ""; - foreach (string lora in lora.Split(" ")) + foreach (string lora in lora.Trim().Split(" ")) { - string loraPath = GetModelLoraPath(lora); + if (lora == "") continue; + string loraPath = GetModelLoraPathRuntime(lora); if (!File.Exists(loraPath)) { LLMUnitySetup.LogError($"File {loraPath} not found!"); From 89793a02fd19579ca371fb0949263545fd6bd270 Mon Sep 17 00:00:00 2001 From: Antonis Makropoulos Date: Mon, 19 Aug 2024 16:49:01 +0300 Subject: [PATCH 8/8] update to latest LlamaLib, add embedding test --- Tests/Runtime/TestLLM.cs | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/Tests/Runtime/TestLLM.cs b/Tests/Runtime/TestLLM.cs index b5afc880..affca075 100644 --- a/Tests/Runtime/TestLLM.cs +++ b/Tests/Runtime/TestLLM.cs @@ -7,6 +7,7 @@ using System.Collections; using UnityEngine.TestTools; using System.IO; +using NUnit.Framework.Internal; namespace LLMUnityTests { @@ -86,6 +87,8 @@ public virtual async Task RunTests() llmCharacter.SetPrompt(prompt); await llmCharacter.Chat("hi"); TestInitParameters((await llmCharacter.Tokenize(prompt)).Count + 2, 3); + List embeddings = await llmCharacter.Embeddings("hi how are you?"); + TestEmbeddings(embeddings); llm.OnDestroy(); } catch (Exception e) @@ -126,7 +129,7 @@ public void TestWarmup() public void TestChat(string reply) { - string AIReply = "To increase your meme production/output, you can consider the following:\n1. 
Use"; + string AIReply = "One way to increase your meme production/output is by creating a more complex and customized"; Assert.That(reply.Trim() == AIReply); } @@ -141,6 +144,11 @@ public void TestPostChat(int num) Assert.That(llmCharacter.chat.Count == num); } + public void TestEmbeddings(List embeddings) + { + Assert.That(embeddings.Count == 1024); + } + public virtual void OnDestroy() { LLMManager.Remove(filename);