diff --git a/Editor/LLMEditor.cs b/Editor/LLMEditor.cs
index bebaf578..ed50046d 100644
--- a/Editor/LLMEditor.cs
+++ b/Editor/LLMEditor.cs
@@ -142,7 +142,7 @@ void SetModelIfNone(string filename, bool lora)
LLM llmScript = (LLM)target;
int num = LLMManager.Num(lora);
if (!lora && llmScript.model == "" && num == 1) llmScript.SetModel(filename);
- if (lora && llmScript.lora == "" && num == 1) llmScript.SetLora(filename);
+ if (lora) llmScript.AddLora(filename);
}
async Task createCustomURLField()
@@ -237,7 +237,7 @@ async Task createButtons()
{
EditorApplication.delayCall += () =>
{
- string path = EditorUtility.OpenFilePanelWithFilters("Select a bin lora file", "", new string[] { "Model Files", "bin" });
+ string path = EditorUtility.OpenFilePanelWithFilters("Select a gguf lora file", "", new string[] { "Model Files", "gguf" });
if (!string.IsNullOrEmpty(path))
{
string filename = LLMManager.LoadLora(path, true);
@@ -299,10 +299,10 @@ void OnEnable()
}
else
{
- isSelected = llmScript.lora == entry.filename;
- bool newSelected = EditorGUI.Toggle(selectRect, isSelected, EditorStyles.radioButton);
- if (newSelected && !isSelected) llmScript.SetLora(entry.filename);
- else if (!newSelected && isSelected) llmScript.SetLora("");
+ isSelected = llmScript.lora.Split(" ").Contains(entry.filename);
+ bool newSelected = EditorGUI.Toggle(selectRect, isSelected);
+ if (newSelected && !isSelected) llmScript.AddLora(entry.filename);
+ else if (!newSelected && isSelected) llmScript.RemoveLora(entry.filename);
}
DrawCopyableLabel(nameRect, entry.label, entry.filename);
@@ -347,6 +347,11 @@ void OnEnable()
if (GUI.Button(actionRect, trashIcon))
{
+ if (isSelected)
+ {
+ if (!entry.lora) llmScript.SetModel("");
+ else llmScript.RemoveLora(entry.filename);
+ }
LLMManager.Remove(entry);
UpdateModels(true);
}
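
For reference, a minimal sketch of how the multi-LoRA selection the editor now drives maps onto the runtime API (the `llm` reference and adapter filenames are illustrative, not part of this diff):

```csharp
// Sketch: exercising the multi-LoRA API introduced in this diff.
// The adapter filenames are placeholders for illustration.
using LLMUnity;
using UnityEngine;

public class LoraSelectionExample : MonoBehaviour
{
    public LLM llm;

    void Awake()
    {
        llm.SetLora("adapter_a.gguf");    // resets the list to a single adapter
        llm.AddLora("adapter_b.gguf");    // appends; llm.lora holds a space-separated list
        llm.RemoveLora("adapter_a.gguf"); // removes one entry, keeping the rest
    }
}
```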
diff --git a/Runtime/LLM.cs b/Runtime/LLM.cs
index 315c5cd9..a691cf29 100644
--- a/Runtime/LLM.cs
+++ b/Runtime/LLM.cs
@@ -3,6 +3,7 @@
using System;
using System.Collections.Generic;
using System.IO;
+using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using UnityEditor;
@@ -68,7 +69,7 @@ public class LLM : MonoBehaviour
[ModelAdvanced] public string model = "";
/// Chat template used for the model
[ModelAdvanced] public string chatTemplate = ChatTemplate.DefaultTemplate;
- /// the path of the LORA model being used (relative to the Assets/StreamingAssets folder).
+ /// the paths of the LORA models being used (relative to the Assets/StreamingAssets folder).
/// Models with .bin format are allowed.
[ModelAdvanced] public string lora = "";
@@ -81,6 +82,7 @@ public class LLM : MonoBehaviour
Thread llmThread = null;
List<StreamWrapper> streamWrappers = new List<StreamWrapper>();
public LLMManager llmManager = new LLMManager();
+ List<float> loraWeights = new List<float>();
/// \endcond
@@ -128,14 +130,14 @@ public static async Task<bool> WaitUntilModelSetup(Callback<float> downloadProgr
return !modelSetupFailed;
}
- public string GetModelLoraPath(string path)
+ public string GetModelLoraPathRuntime(string path)
{
string assetPath = LLMManager.GetAssetPath(path);
if (!string.IsNullOrEmpty(assetPath)) return assetPath;
return path;
}
- public string SetModelLoraPath(string path, bool lora)
+ public string GetModelLoraPath(string path, bool lora)
{
if (string.IsNullOrEmpty(path)) return path;
ModelEntry modelEntry = LLMManager.Get(path);
@@ -167,11 +169,11 @@ public string SetModelLoraPath(string path, bool lora)
/// <param name="path">path to model to use (.gguf format)</param>
public void SetModel(string path)
{
- model = SetModelLoraPath(path, false);
+ model = GetModelLoraPath(path, false);
if (!string.IsNullOrEmpty(model))
{
ModelEntry modelEntry = LLMManager.Get(model);
- string template = modelEntry != null ? modelEntry.chatTemplate : ChatTemplate.FromGGUF(GetModelLoraPath(model));
+ string template = modelEntry != null ? modelEntry.chatTemplate : ChatTemplate.FromGGUF(GetModelLoraPathRuntime(model));
SetTemplate(template);
}
#if UNITY_EDITOR
@@ -187,7 +189,43 @@ public void SetModel(string path)
/// <param name="path">path to LORA model to use (.bin format)</param>
public void SetLora(string path)
{
- lora = SetModelLoraPath(path, true);
+ lora = "";
+ AddLora(path);
+ }
+
+ /// <summary>
+ /// Allows to add a LORA model to use in the LLM.
+ /// The model provided is copied to the Assets/StreamingAssets folder that allows it to also work in the build.
+ /// Models supported are in .gguf format.
+ /// </summary>
+ /// <param name="path">path to LORA model to use (.gguf format)</param>
+ public void AddLora(string path)
+ {
+ string loraPath = GetModelLoraPath(path, true);
+ if (lora.Split(" ").Contains(loraPath)) return;
+ if (lora != "") lora += " ";
+ lora += loraPath;
+#if UNITY_EDITOR
+ if (!EditorApplication.isPlaying) EditorUtility.SetDirty(this);
+#endif
+ }
+
+ /// <summary>
+ /// Allows to remove a LORA model from the LLM.
+ /// Models supported are in .gguf format.
+ /// </summary>
+ /// <param name="path">path to LORA model to remove (.gguf format)</param>
+ public void RemoveLora(string path)
+ {
+ string loraPath = GetModelLoraPath(path, true);
+ List<string> loras = new List<string>(lora.Split(" "));
+ loras.Remove(loraPath);
+ lora = "";
+ for (int i = 0; i < loras.Count; i++)
+ {
+ if (i > 0) lora += " ";
+ lora += loras[i];
+ }
#if UNITY_EDITOR
if (!EditorApplication.isPlaying) EditorUtility.SetDirty(this);
#endif
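
Aside: the rebuild loop in RemoveLora is behaviorally equivalent to a join over the remaining entries. A more compact form, assuming the same space-separated encoding of the `lora` field:

```csharp
// Equivalent, more compact rewrite of the RemoveLora bookkeeping (sketch, same behavior):
List<string> loras = new List<string>(lora.Split(" "));
loras.Remove(loraPath);
lora = string.Join(" ", loras);
```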
@@ -223,21 +261,23 @@ protected virtual string GetLlamaccpArguments()
LLMUnitySetup.LogError("No model file provided!");
return null;
}
- string modelPath = GetModelLoraPath(model);
+ string modelPath = GetModelLoraPathRuntime(model);
if (!File.Exists(modelPath))
{
LLMUnitySetup.LogError($"File {modelPath} not found!");
return null;
}
- string loraPath = "";
- if (lora != "")
+ string loraArgument = "";
+ foreach (string loraFile in lora.Trim().Split(" "))
{
- loraPath = GetModelLoraPath(lora);
+ if (loraFile == "") continue;
+ string loraPath = GetModelLoraPathRuntime(loraFile);
if (!File.Exists(loraPath))
{
LLMUnitySetup.LogError($"File {loraPath} not found!");
return null;
}
+ loraArgument += $" --lora \"{loraPath}\"";
}
int numThreadsToUse = numThreads;
@@ -247,7 +287,7 @@ protected virtual string GetLlamaccpArguments()
string arguments = $"-m \"{modelPath}\" -c {contextSize} -b {batchSize} --log-disable -np {slots}";
if (remote) arguments += $" --port {port} --host 0.0.0.0";
if (numThreadsToUse > 0) arguments += $" -t {numThreadsToUse}";
- if (loraPath != "") arguments += $" --lora \"{loraPath}\"";
+ arguments += loraArgument;
arguments += $" -ngl {numGPULayers}";
return arguments;
}
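
With two adapters loaded, the loop above emits one `--lora` flag per entry, so the generated server arguments look roughly like this (paths and parameter values are illustrative):

```
-m "model.gguf" -c 4096 -b 512 --log-disable -np 1 -t 4 --lora "adapter_a.gguf" --lora "adapter_b.gguf" -ngl 0
```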
@@ -323,6 +363,8 @@ private void StartService()
llmThread = new Thread(() => llmlib.LLM_Start(LLMObject));
llmThread.Start();
while (!llmlib.LLM_Started(LLMObject)) {}
+ loraWeights = new List();
+ for (int i = 0; i < lora.Split(" ").Count(); i++) loraWeights.Add(1f);
started = true;
}
@@ -345,7 +387,7 @@ protected int GetNumClients()
/// \cond HIDE
public delegate void LLMStatusCallback(IntPtr LLMObject, IntPtr stringWrapper);
- public delegate void LLMSimpleCallback(IntPtr LLMObject, string json_data);
+ public delegate void LLMNoInputReplyCallback(IntPtr LLMObject, IntPtr stringWrapper);
public delegate void LLMReplyCallback(IntPtr LLMObject, string json_data, IntPtr stringWrapper);
/// \endcond
@@ -400,6 +442,17 @@ void CheckLLMStatus(bool log = true)
}
}
+ async Task<string> LLMNoInputReply(LLMNoInputReplyCallback callback)
+ {
+ AssertStarted();
+ IntPtr stringWrapper = llmlib.StringWrapper_Construct();
+ await Task.Run(() => callback(LLMObject, stringWrapper));
+ string result = llmlib?.GetStringWrapperResult(stringWrapper);
+ llmlib?.StringWrapper_Delete(stringWrapper);
+ CheckLLMStatus();
+ return result;
+ }
+
async Task<string> LLMReply(LLMReplyCallback callback, string json)
{
AssertStarted();
@@ -441,6 +494,74 @@ public async Task<string> Detokenize(string json)
return await LLMReply(callback, json);
}
+ /// <summary>
+ /// Computes the embeddings of the provided query.
+ /// </summary>
+ /// <param name="json">json request containing the query</param>
+ /// <returns>embeddings result</returns>
+ public async Task<string> Embeddings(string json)
+ {
+ AssertStarted();
+ LLMReplyCallback callback = (IntPtr LLMObject, string jsonData, IntPtr strWrapper) =>
+ {
+ llmlib.LLM_Embeddings(LLMObject, jsonData, strWrapper);
+ };
+ return await LLMReply(callback, json);
+ }
+
+ /// <summary>
+ /// Sets the LoRA scale; this only works after the LLM service has started.
+ /// </summary>
+ /// <returns>switch result</returns>
+ public async Task<string> SetLoraScale(string loraToScale, float scale)
+ {
+ AssertStarted();
+ List<string> loras = new List<string>(lora.Split(" "));
+ string loraToScalePath = GetModelLoraPath(loraToScale, true);
+
+ int index = loras.IndexOf(loraToScale);
+ if (index == -1) index = loras.IndexOf(loraToScalePath);
+ if (index == -1)
+ {
+ LLMUnitySetup.LogError($"LoRA {loraToScale} not loaded with the LLM");
+ return "";
+ }
+
+ loraWeights[index] = scale;
+ LoraWeightRequestList loraWeightRequest = new LoraWeightRequestList();
+ loraWeightRequest.loraWeights = new List<LoraWeightRequest>();
+ for (int i = 0; i < loraWeights.Count; i++)
+ {
+ loraWeightRequest.loraWeights.Add(new LoraWeightRequest() {id = i, scale = loraWeights[i]});
+ }
+
+ string json = JsonUtility.ToJson(loraWeightRequest);
+ int startIndex = json.IndexOf("[");
+ int endIndex = json.LastIndexOf("]") + 1;
+ json = json.Substring(startIndex, endIndex - startIndex);
+
+ LLMReplyCallback callback = (IntPtr LLMObject, string jsonData, IntPtr strWrapper) =>
+ {
+ llmlib.LLM_Lora_Weight(LLMObject, jsonData, strWrapper);
+ };
+ return await LLMReply(callback, json);
+ }
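
The Substring step above strips the JsonUtility wrapper object so that only a bare JSON array reaches the library; schematically (exact float formatting may differ):

```
// JsonUtility.ToJson(loraWeightRequest) yields the wrapper object:
{"loraWeights":[{"id":0,"scale":1.0},{"id":1,"scale":0.5}]}
// after the IndexOf("[") / LastIndexOf("]") Substring, only the array is sent:
[{"id":0,"scale":1.0},{"id":1,"scale":0.5}]
```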
+
+ /// <summary>
+ /// Gets a list of the lora adapters
+ /// </summary>
+ /// <returns>list of lora adapters</returns>
+ public async Task<string> ListLora()
+ {
+ AssertStarted();
+ LLMNoInputReplyCallback callback = (IntPtr LLMObject, IntPtr strWrapper) =>
+ {
+ llmlib.LLM_LoraList(LLMObject, strWrapper);
+ };
+ return await LLMNoInputReply(callback);
+ }
+
/// <summary>
/// Allows to save / restore the state of a slot
/// </summary>
diff --git a/Runtime/LLMCharacter.cs b/Runtime/LLMCharacter.cs
index 9f37721a..c76239aa 100644
--- a/Runtime/LLMCharacter.cs
+++ b/Runtime/LLMCharacter.cs
@@ -430,16 +430,22 @@ protected List<int> TokenizeContent(TokenizeResult result)
return result.tokens;
}
- protected string SlotContent(SlotResult result)
+ protected string DetokenizeContent(TokenizeRequest result)
{
- // get the tokens from a tokenize result received from the endpoint
- return result.filename;
+ // get the content from a detokenize result received from the endpoint
+ return result.content;
}
- protected string DetokenizeContent(TokenizeRequest result)
+ protected List<float> EmbeddingsContent(EmbeddingsResult result)
{
// get content from a chat result received from the endpoint
- return result.content;
+ return result.embedding;
+ }
+
+ protected string SlotContent(SlotResult result)
+ {
+ // get the filename from a slot result received from the endpoint
+ return result.filename;
}
/// <summary>
@@ -572,6 +578,21 @@ public async Task<string> Detokenize(List<int> tokens, Callback<string> callback
return await PostRequest(json, "detokenize", DetokenizeContent, callback);
}
+ /// <summary>
+ /// Computes the embeddings of the provided input.
+ /// </summary>
+ /// <param name="query">input to compute the embeddings for</param>
+ /// <param name="callback">callback function called with the computed embeddings</param>
+ /// <returns>the computed embeddings</returns>
+ public async Task<List<float>> Embeddings(string query, Callback<List<float>> callback = null)
+ {
+ // reuse the tokenize request payload to send the query to the embeddings endpoint
+ TokenizeRequest tokenizeRequest = new TokenizeRequest();
+ tokenizeRequest.content = query;
+ string json = JsonUtility.ToJson(tokenizeRequest);
+ return await PostRequest<EmbeddingsResult, List<float>>(json, "embeddings", EmbeddingsContent, callback);
+ }
+
private async Task Slot(string filepath, string action)
{
SlotRequest slotRequest = new SlotRequest();
@@ -682,6 +703,9 @@ protected async Task PostRequestLocal(string json, string endpoin
case "detokenize":
callResult = await llm.Detokenize(json);
break;
+ case "embeddings":
+ callResult = await llm.Embeddings(json);
+ break;
case "slots":
callResult = await llm.Slot(json);
break;
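
End to end, the new endpoint is used as in the test added below; a minimal sketch (the component wiring is assumed):

```csharp
// Sketch: requesting embeddings through LLMCharacter (field wiring assumed).
using System.Collections.Generic;
using LLMUnity;
using UnityEngine;

public class EmbeddingsExample : MonoBehaviour
{
    public LLMCharacter llmCharacter;

    async void Start()
    {
        List<float> embeddings = await llmCharacter.Embeddings("hi how are you?");
        Debug.Log($"embedding dimension: {embeddings.Count}");
    }
}
```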
diff --git a/Runtime/LLMInterface.cs b/Runtime/LLMInterface.cs
index f5086595..36a11c37 100644
--- a/Runtime/LLMInterface.cs
+++ b/Runtime/LLMInterface.cs
@@ -93,6 +93,25 @@ public struct TokenizeResult
public List<int> tokens;
}
+ [Serializable]
+ public struct EmbeddingsResult
+ {
+ public List<float> embedding;
+ }
+
+ [Serializable]
+ public struct LoraWeightRequest
+ {
+ public int id;
+ public float scale;
+ }
+
+ [Serializable]
+ public struct LoraWeightRequestList
+ {
+ public List<LoraWeightRequest> loraWeights;
+ }
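
For orientation, EmbeddingsResult is deserialized with JsonUtility, so it expects a response of roughly this shape (values illustrative):

```
{"embedding":[0.0123,-0.0456,0.0789]}
```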
+
[Serializable]
public struct TemplateResult
{
diff --git a/Runtime/LLMLib.cs b/Runtime/LLMLib.cs
index c8d49ccd..fbaa050d 100644
--- a/Runtime/LLMLib.cs
+++ b/Runtime/LLMLib.cs
@@ -298,6 +298,9 @@ public LLMLib(string arch)
LLM_SetTemplate = LibraryLoader.GetSymbolDelegate<LLM_SetTemplateDelegate>(libraryHandle, "LLM_SetTemplate");
LLM_Tokenize = LibraryLoader.GetSymbolDelegate<LLM_TokenizeDelegate>(libraryHandle, "LLM_Tokenize");
LLM_Detokenize = LibraryLoader.GetSymbolDelegate<LLM_DetokenizeDelegate>(libraryHandle, "LLM_Detokenize");
+ LLM_Embeddings = LibraryLoader.GetSymbolDelegate<LLM_EmbeddingsDelegate>(libraryHandle, "LLM_Embeddings");
+ LLM_Lora_Weight = LibraryLoader.GetSymbolDelegate<LLM_LoraWeightDelegate>(libraryHandle, "LLM_Lora_Weight");
+ LLM_LoraList = LibraryLoader.GetSymbolDelegate<LLM_LoraListDelegate>(libraryHandle, "LLM_Lora_List");
LLM_Completion = LibraryLoader.GetSymbolDelegate<LLM_CompletionDelegate>(libraryHandle, "LLM_Completion");
LLM_Slot = LibraryLoader.GetSymbolDelegate<LLM_SlotDelegate>(libraryHandle, "LLM_Slot");
LLM_Cancel = LibraryLoader.GetSymbolDelegate<LLM_CancelDelegate>(libraryHandle, "LLM_Cancel");
LLM_Cancel = LibraryLoader.GetSymbolDelegate(libraryHandle, "LLM_Cancel");
@@ -452,6 +455,9 @@ public string GetStringWrapperResult(IntPtr stringWrapper)
public delegate void LLM_SetTemplateDelegate(IntPtr LLMObject, string chatTemplate);
public delegate void LLM_TokenizeDelegate(IntPtr LLMObject, string jsonData, IntPtr stringWrapper);
public delegate void LLM_DetokenizeDelegate(IntPtr LLMObject, string jsonData, IntPtr stringWrapper);
+ public delegate void LLM_EmbeddingsDelegate(IntPtr LLMObject, string jsonData, IntPtr stringWrapper);
+ public delegate void LLM_LoraWeightDelegate(IntPtr LLMObject, string jsonData, IntPtr stringWrapper);
+ public delegate void LLM_LoraListDelegate(IntPtr LLMObject, IntPtr stringWrapper);
public delegate void LLM_CompletionDelegate(IntPtr LLMObject, string jsonData, IntPtr stringWrapper);
public delegate void LLM_SlotDelegate(IntPtr LLMObject, string jsonData, IntPtr stringWrapper);
public delegate void LLM_CancelDelegate(IntPtr LLMObject, int idSlot);
@@ -474,6 +480,9 @@ public string GetStringWrapperResult(IntPtr stringWrapper)
public LLM_TokenizeDelegate LLM_Tokenize;
public LLM_DetokenizeDelegate LLM_Detokenize;
public LLM_CompletionDelegate LLM_Completion;
+ public LLM_EmbeddingsDelegate LLM_Embeddings;
+ public LLM_LoraWeightDelegate LLM_Lora_Weight;
+ public LLM_LoraListDelegate LLM_LoraList;
public LLM_SlotDelegate LLM_Slot;
public LLM_CancelDelegate LLM_Cancel;
public LLM_StatusDelegate LLM_Status;
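
The three new symbols follow the library's existing load pattern: resolve a native symbol, then wrap it as a typed delegate. A generic sketch of that pattern (illustrative; not LibraryLoader's exact implementation, and GetSymbol stands in for the platform dlsym / GetProcAddress wrapper):

```csharp
// Sketch of the symbol-to-delegate pattern (illustrative only).
using System;
using System.Runtime.InteropServices;

static class SymbolLoaderSketch
{
    public static T GetSymbolDelegate<T>(IntPtr library, string name) where T : Delegate
    {
        IntPtr symbol = GetSymbol(library, name);
        if (symbol == IntPtr.Zero) throw new Exception($"Symbol {name} not found");
        return Marshal.GetDelegateForFunctionPointer<T>(symbol);
    }

    static IntPtr GetSymbol(IntPtr library, string name)
    {
        // platform-specific resolution elided in this sketch
        throw new NotImplementedException();
    }
}
```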
diff --git a/Runtime/LLMUnitySetup.cs b/Runtime/LLMUnitySetup.cs
index 833f0ce7..4ee87a66 100644
--- a/Runtime/LLMUnitySetup.cs
+++ b/Runtime/LLMUnitySetup.cs
@@ -86,7 +86,7 @@ public class LLMUnitySetup
/// LLM for Unity version
public static string Version = "v2.1.2";
/// LlamaLib version
- public static string LlamaLibVersion = "v1.1.6";
+ public static string LlamaLibVersion = "v1.1.8";
/// LlamaLib url
public static string LlamaLibURL = $"https://github.com/undreamai/LlamaLib/releases/download/{LlamaLibVersion}/undreamai-{LlamaLibVersion}-llamacpp.zip";
/// LlamaLib path
diff --git a/Tests/Runtime/TestLLM.cs b/Tests/Runtime/TestLLM.cs
index b5afc880..affca075 100644
--- a/Tests/Runtime/TestLLM.cs
+++ b/Tests/Runtime/TestLLM.cs
@@ -7,6 +7,7 @@
using System.Collections;
using UnityEngine.TestTools;
using System.IO;
+using NUnit.Framework.Internal;
namespace LLMUnityTests
{
@@ -86,6 +87,8 @@ public virtual async Task RunTests()
llmCharacter.SetPrompt(prompt);
await llmCharacter.Chat("hi");
TestInitParameters((await llmCharacter.Tokenize(prompt)).Count + 2, 3);
+ List<float> embeddings = await llmCharacter.Embeddings("hi how are you?");
+ TestEmbeddings(embeddings);
llm.OnDestroy();
}
catch (Exception e)
@@ -126,7 +129,7 @@ public void TestWarmup()
public void TestChat(string reply)
{
- string AIReply = "To increase your meme production/output, you can consider the following:\n1. Use";
+ string AIReply = "One way to increase your meme production/output is by creating a more complex and customized";
Assert.That(reply.Trim() == AIReply);
}
@@ -141,6 +144,11 @@ public void TestPostChat(int num)
Assert.That(llmCharacter.chat.Count == num);
}
+ public void TestEmbeddings(List<float> embeddings)
+ {
+ Assert.That(embeddings.Count == 1024);
+ }
+
public virtual void OnDestroy()
{
LLMManager.Remove(filename);