Implement embedding and lora adapter functionality #210

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged · 8 commits · Aug 19, 2024
17 changes: 11 additions & 6 deletions Editor/LLMEditor.cs
@@ -142,7 +142,7 @@ void SetModelIfNone(string filename, bool lora)
LLM llmScript = (LLM)target;
int num = LLMManager.Num(lora);
if (!lora && llmScript.model == "" && num == 1) llmScript.SetModel(filename);
if (lora && llmScript.lora == "" && num == 1) llmScript.SetLora(filename);
if (lora) llmScript.AddLora(filename);
}

async Task createCustomURLField()
@@ -237,7 +237,7 @@ async Task createButtons()
{
EditorApplication.delayCall += () =>
{
string path = EditorUtility.OpenFilePanelWithFilters("Select a bin lora file", "", new string[] { "Model Files", "bin" });
string path = EditorUtility.OpenFilePanelWithFilters("Select a gguf lora file", "", new string[] { "Model Files", "gguf" });
if (!string.IsNullOrEmpty(path))
{
string filename = LLMManager.LoadLora(path, true);
@@ -299,10 +299,10 @@ void OnEnable()
}
else
{
isSelected = llmScript.lora == entry.filename;
bool newSelected = EditorGUI.Toggle(selectRect, isSelected, EditorStyles.radioButton);
if (newSelected && !isSelected) llmScript.SetLora(entry.filename);
else if (!newSelected && isSelected) llmScript.SetLora("");
isSelected = llmScript.lora.Split(" ").Contains(entry.filename);
bool newSelected = EditorGUI.Toggle(selectRect, isSelected);
if (newSelected && !isSelected) llmScript.AddLora(entry.filename);
else if (!newSelected && isSelected) llmScript.RemoveLora(entry.filename);
}

DrawCopyableLabel(nameRect, entry.label, entry.filename);
@@ -347,6 +347,11 @@ void OnEnable()

if (GUI.Button(actionRect, trashIcon))
{
if (isSelected)
{
if (!entry.lora) llmScript.SetModel("");
else llmScript.RemoveLora(entry.filename);
}
LLMManager.Remove(entry);
UpdateModels(true);
}
145 changes: 133 additions & 12 deletions Runtime/LLM.cs
@@ -3,6 +3,7 @@
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using UnityEditor;
@@ -68,7 +69,7 @@ public class LLM : MonoBehaviour
[ModelAdvanced] public string model = "";
/// <summary> Chat template used for the model </summary>
[ModelAdvanced] public string chatTemplate = ChatTemplate.DefaultTemplate;
/// <summary> the path of the LORA model being used (relative to the Assets/StreamingAssets folder).
/// <summary> the paths of the LORA models being used (relative to the Assets/StreamingAssets folder).
Models with .gguf format are allowed.</summary>
[ModelAdvanced] public string lora = "";

@@ -81,6 +82,7 @@ public class LLM : MonoBehaviour
Thread llmThread = null;
List<StreamWrapper> streamWrappers = new List<StreamWrapper>();
public LLMManager llmManager = new LLMManager();
List<float> loraWeights = new List<float>();

/// \endcond

@@ -128,14 +130,14 @@ public static async Task<bool> WaitUntilModelSetup(Callback<float> downloadProgr
return !modelSetupFailed;
}

public string GetModelLoraPath(string path)
public string GetModelLoraPathRuntime(string path)
{
string assetPath = LLMManager.GetAssetPath(path);
if (!string.IsNullOrEmpty(assetPath)) return assetPath;
return path;
}

public string SetModelLoraPath(string path, bool lora)
public string GetModelLoraPath(string path, bool lora)
{
if (string.IsNullOrEmpty(path)) return path;
ModelEntry modelEntry = LLMManager.Get(path);
@@ -167,11 +169,11 @@ public string SetModelLoraPath(string path, bool lora)
/// <param name="path">path to model to use (.gguf format)</param>
public void SetModel(string path)
{
model = SetModelLoraPath(path, false);
model = GetModelLoraPath(path, false);
if (!string.IsNullOrEmpty(model))
{
ModelEntry modelEntry = LLMManager.Get(model);
string template = modelEntry != null ? modelEntry.chatTemplate : ChatTemplate.FromGGUF(GetModelLoraPath(model));
string template = modelEntry != null ? modelEntry.chatTemplate : ChatTemplate.FromGGUF(GetModelLoraPathRuntime(model));
SetTemplate(template);
}
#if UNITY_EDITOR
@@ -187,7 +189,43 @@ public void SetModel(string path)
/// <param name="path">path to LORA model to use (.bin format)</param>
public void SetLora(string path)
{
lora = SetModelLoraPath(path, true);
lora = "";
AddLora(path);
}

/// <summary>
/// Allows to add a LORA model to use in the LLM.
/// The model provided is copied to the Assets/StreamingAssets folder so that it also works in the build.
/// Models supported are in .gguf format.
/// </summary>
/// <param name="path">path to LORA model to use (.gguf format)</param>
public void AddLora(string path)
{
string loraPath = GetModelLoraPath(path, true);
if (lora.Split(" ").Contains(loraPath)) return;
if (lora != "") lora += " ";
lora += loraPath;
#if UNITY_EDITOR
if (!EditorApplication.isPlaying) EditorUtility.SetDirty(this);
#endif
}

/// <summary>
/// Allows to remove a LORA model from the LLM.
/// Models supported are in .gguf format.
/// </summary>
/// <param name="path">path to LORA model to remove (.gguf format)</param>
public void RemoveLora(string path)
{
string loraPath = GetModelLoraPath(path, true);
List<string> loras = new List<string>(lora.Split(" "));
loras.Remove(loraPath);
lora = "";
for (int i = 0; i < loras.Count; i++)
{
if (i > 0) lora += " ";
lora += loras[i];
}
#if UNITY_EDITOR
if (!EditorApplication.isPlaying) EditorUtility.SetDirty(this);
#endif
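
A minimal usage sketch of the adapter management API introduced above (the adapter file names are hypothetical, and the LLM component is assumed to sit on the same GameObject):

using UnityEngine;

public class LoraSetupExample : MonoBehaviour
{
    void Awake()
    {
        LLM llm = GetComponent<LLM>();
        llm.SetLora("style.gguf");    // resets the list so only this adapter is used
        llm.AddLora("domain.gguf");   // appends a second adapter to the space-separated list
        llm.RemoveLora("style.gguf"); // removes one adapter, leaving "domain.gguf" loaded
    }
}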
@@ -223,21 +261,23 @@ protected virtual string GetLlamaccpArguments()
LLMUnitySetup.LogError("No model file provided!");
return null;
}
string modelPath = GetModelLoraPath(model);
string modelPath = GetModelLoraPathRuntime(model);
if (!File.Exists(modelPath))
{
LLMUnitySetup.LogError($"File {modelPath} not found!");
return null;
}
string loraPath = "";
if (lora != "")
string loraArgument = "";
foreach (string lora in lora.Trim().Split(" "))
{
loraPath = GetModelLoraPath(lora);
if (lora == "") continue;
string loraPath = GetModelLoraPathRuntime(lora);
if (!File.Exists(loraPath))
{
LLMUnitySetup.LogError($"File {loraPath} not found!");
return null;
}
loraArgument += $" --lora \"{loraPath}\"";
}

int numThreadsToUse = numThreads;
@@ -247,7 +287,7 @@
string arguments = $"-m \"{modelPath}\" -c {contextSize} -b {batchSize} --log-disable -np {slots}";
if (remote) arguments += $" --port {port} --host 0.0.0.0";
if (numThreadsToUse > 0) arguments += $" -t {numThreadsToUse}";
if (loraPath != "") arguments += $" --lora \"{loraPath}\"";
arguments += loraArgument;
arguments += $" -ngl {numGPULayers}";
return arguments;
}
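
For illustration (model, adapter paths and parameter values hypothetical), with two adapters loaded the generated server arguments would now look like:

-m "model.gguf" -c 4096 -b 512 --log-disable -np 1 --lora "adapter1.gguf" --lora "adapter2.gguf" -ngl 0

that is, one --lora flag is emitted per adapter instead of the previous single flag.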
@@ -323,6 +363,8 @@ private void StartService()
llmThread = new Thread(() => llmlib.LLM_Start(LLMObject));
llmThread.Start();
while (!llmlib.LLM_Started(LLMObject)) {}
loraWeights = new List<float>();
for (int i = 0; i < lora.Split(" ").Count(); i++) loraWeights.Add(1f);
started = true;
}

@@ -345,7 +387,7 @@ protected int GetNumClients()

/// \cond HIDE
public delegate void LLMStatusCallback(IntPtr LLMObject, IntPtr stringWrapper);
public delegate void LLMSimpleCallback(IntPtr LLMObject, string json_data);
public delegate void LLMNoInputReplyCallback(IntPtr LLMObject, IntPtr stringWrapper);
public delegate void LLMReplyCallback(IntPtr LLMObject, string json_data, IntPtr stringWrapper);
/// \endcond

@@ -400,6 +442,17 @@ void CheckLLMStatus(bool log = true)
}
}

async Task<string> LLMNoInputReply(LLMNoInputReplyCallback callback)
{
AssertStarted();
IntPtr stringWrapper = llmlib.StringWrapper_Construct();
await Task.Run(() => callback(LLMObject, stringWrapper));
string result = llmlib?.GetStringWrapperResult(stringWrapper);
llmlib?.StringWrapper_Delete(stringWrapper);
CheckLLMStatus();
return result;
}

async Task<string> LLMReply(LLMReplyCallback callback, string json)
{
AssertStarted();
@@ -441,6 +494,74 @@ public async Task<string> Detokenize(string json)
return await LLMReply(callback, json);
}

/// <summary>
/// Computes the embeddings of the provided query.
/// </summary>
/// <param name="json">json request containing the query</param>
/// <returns>embeddings result</returns>
public async Task<string> Embeddings(string json)
{
AssertStarted();
LLMReplyCallback callback = (IntPtr LLMObject, string jsonData, IntPtr strWrapper) =>
{
llmlib.LLM_Embeddings(LLMObject, jsonData, strWrapper);
};
return await LLMReply(callback, json);
}
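
A sketch of calling the new endpoint directly; the request shape with a "content" field mirrors the TokenizeRequest JSON that LLMCharacter sends (an assumption here, see LLMCharacter.Embeddings below):

// assumes a started LLM instance named llm
string embeddingsJson = await llm.Embeddings("{\"content\": \"Hello world\"}");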

/// <summary>
/// Sets the scale of a LoRA adapter; this only works after the LLM service has started
/// </summary>
/// <param name="loraToScale">path of the LoRA model to scale</param>
/// <param name="scale">scale to apply (adapters start at 1.0)</param>
/// <returns>the LLM library response</returns>
public async Task<string> SetLoraScale(string loraToScale, float scale)
{
AssertStarted();
List<string> loras = new List<string>(lora.Split(" "));
string loraToScalePath = GetModelLoraPath(loraToScale, true);

int index = loras.IndexOf(loraToScale);
if (index == -1) index = loras.IndexOf(loraToScalePath);
if (index == -1)
{
LLMUnitySetup.LogError($"LoRA {loraToScale} not loaded with the LLM");
return "";
}

loraWeights[index] = scale;
LoraWeightRequestList loraWeightRequest = new LoraWeightRequestList();
loraWeightRequest.loraWeights = new List<LoraWeightRequest>();
for (int i = 0; i < loraWeights.Count; i++)
{
loraWeightRequest.loraWeights.Add(new LoraWeightRequest() {id = i, scale = loraWeights[i]});
}

string json = JsonUtility.ToJson(loraWeightRequest);
int startIndex = json.IndexOf("[");
int endIndex = json.LastIndexOf("]") + 1;
json = json.Substring(startIndex, endIndex - startIndex);

LLMReplyCallback callback = (IntPtr LLMObject, string jsonData, IntPtr strWrapper) =>
{
llmlib.LLM_Lora_Weight(LLMObject, jsonData, strWrapper);
};
return await LLMReply(callback, json);
}
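
The Substring handling above works around a JsonUtility limitation: it cannot serialize a bare top-level array, so the list is wrapped in a LoraWeightRequestList and the inner array is then cut back out. For two adapters with the second scaled down, the extracted payload would look like (illustrative values):

[{"id":0,"scale":1.0},{"id":1,"scale":0.5}]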

/// <summary>
/// Gets a list of the loaded LoRA adapters
/// </summary>
/// <returns>list of LoRA adapters</returns>
public async Task<string> ListLora()
{
AssertStarted();
LLMNoInputReplyCallback callback = (IntPtr LLMObject, IntPtr strWrapper) =>
{
llmlib.LLM_LoraList(LLMObject, strWrapper);
};
return await LLMNoInputReply(callback);
}
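
A hedged sketch of the two runtime adapter calls together (adapter name hypothetical, service assumed started):

await llm.SetLoraScale("domain.gguf", 0.5f); // halve the adapter's influence
string loraList = await llm.ListLora();      // JSON listing of the loaded adapters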

/// <summary>
/// Allows to save / restore the state of a slot
/// </summary>
34 changes: 29 additions & 5 deletions Runtime/LLMCharacter.cs
@@ -430,16 +430,22 @@ protected List<int> TokenizeContent(TokenizeResult result)
return result.tokens;
}

protected string SlotContent(SlotResult result)
protected string DetokenizeContent(TokenizeRequest result)
{
// get the tokens from a tokenize result received from the endpoint
return result.filename;
// get the content from a detokenize result received from the endpoint
return result.content;
}

protected string DetokenizeContent(TokenizeRequest result)
protected List<float> EmbeddingsContent(EmbeddingsResult result)
{
// get content from a chat result received from the endpoint
return result.content;
return result.embedding;
}

protected string SlotContent(SlotResult result)
{
// get the filename from a slot result received from the endpoint
return result.filename;
}

/// <summary>
@@ -572,6 +578,21 @@ public async Task<string> Detokenize(List<int> tokens, Callback<string> callback
return await PostRequest<TokenizeRequest, string>(json, "detokenize", DetokenizeContent, callback);
}

/// <summary>
/// Computes the embeddings of the provided input.
/// </summary>
/// <param name="tokens">input to compute the embeddings for</param>
/// <param name="callback">callback function called with the result string</param>
/// <returns>the computed embeddings</returns>
public async Task<List<float>> Embeddings(string query, Callback<List<float>> callback = null)
{
// build the embeddings request with the user query as content
TokenizeRequest tokenizeRequest = new TokenizeRequest();
tokenizeRequest.content = query;
string json = JsonUtility.ToJson(tokenizeRequest);
return await PostRequest<EmbeddingsResult, List<float>>(json, "embeddings", EmbeddingsContent, callback);
}
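
A usage sketch for the new method; the cosine-similarity computation is an illustration layered on top of the PR, assuming a Unity script with an LLMCharacter reference named llmCharacter and an embedding-capable model:

List<float> a = await llmCharacter.Embeddings("How are you?");
List<float> b = await llmCharacter.Embeddings("How do you feel?");
float dot = 0f, normA = 0f, normB = 0f;
for (int i = 0; i < a.Count; i++)
{
    dot += a[i] * b[i];
    normA += a[i] * a[i];
    normB += b[i] * b[i];
}
float cosine = dot / (Mathf.Sqrt(normA) * Mathf.Sqrt(normB)); // 1.0 means identical direction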

private async Task<string> Slot(string filepath, string action)
{
SlotRequest slotRequest = new SlotRequest();
@@ -682,6 +703,9 @@ protected async Task<Ret> PostRequestLocal<Res, Ret>(string json, string endpoin
case "detokenize":
callResult = await llm.Detokenize(json);
break;
case "embeddings":
callResult = await llm.Embeddings(json);
break;
case "slots":
callResult = await llm.Slot(json);
break;
19 changes: 19 additions & 0 deletions Runtime/LLMInterface.cs
@@ -93,6 +93,25 @@ public struct TokenizeResult
public List<int> tokens;
}

[Serializable]
public struct EmbeddingsResult
{
public List<float> embedding;
}

[Serializable]
public struct LoraWeightRequest
{
public int id;
public float scale;
}

[Serializable]
public struct LoraWeightRequestList
{
public List<LoraWeightRequest> loraWeights;
}

[Serializable]
public struct TemplateResult
{