diff --git a/.github/doxygen/Doxyfile b/.github/doxygen/Doxyfile
index f765ee81..05989312 100644
--- a/.github/doxygen/Doxyfile
+++ b/.github/doxygen/Doxyfile
@@ -48,7 +48,7 @@ PROJECT_NAME = "LLM for Unity"
# could be handy for archiving the generated documentation or if some version
# control system is used.
-PROJECT_NUMBER = v2.1.1
+PROJECT_NUMBER = v2.2.0
# Using the PROJECT_BRIEF tag one can provide an optional one line description
# for a project that appears at the top of each page and should give viewer a
diff --git a/CHANGELOG.md b/CHANGELOG.md
index 31631aac..5562a81e 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,3 +1,23 @@
+## v2.2.0
+#### 🚀 Features
+
+- Implement embedding and lora adapter functionality (PR: #210)
+- Read context length and warn if it is very large (PR: #211)
+- Setup allowing to use extra features: flash attention and IQ quants (PR: #216)
+- Allow HTTP request retries for remote server (PR: #217)
+- Allow to set lora weights at startup, add unit test (PR: #219)
+- Allow relative StreamingAssets paths for models (PR: #221)
+
+#### 🐛 Fixes
+
+- Fix set template for remote setup (PR: #208)
+- Fix crash when stopping scene before LLM creation (PR: #214)
+
+#### 📦 General
+
+- Documentation/point to gguf format for lora (PR: #215)
+
+
## v2.1.1
#### 🐛 Fixes
diff --git a/CHANGELOG.release.md b/CHANGELOG.release.md
index 379083c3..d405932e 100644
--- a/CHANGELOG.release.md
+++ b/CHANGELOG.release.md
@@ -1,3 +1,18 @@
+### 🚀 Features
+
+- Implement embedding and lora adapter functionality (PR: #210)
+- Read context length and warn if it is very large (PR: #211)
+- Setup allowing to use extra features: flash attention and IQ quants (PR: #216)
+- Allow HTTP request retries for remote server (PR: #217)
+- Allow to set lora weights at startup, add unit test (PR: #219)
+- Allow relative StreamingAssets paths for models (PR: #221)
+
### 🐛 Fixes
-- Resolve build directory creation
\ No newline at end of file
+- Fix set template for remote setup (PR: #208)
+- Fix crash when stopping scene before LLM creation (PR: #214)
+
+### 📦 General
+
+- Documentation/point to gguf format for lora (PR: #215)
+
diff --git a/Editor/LLMEditor.cs b/Editor/LLMEditor.cs
index bebaf578..dcf5cfed 100644
--- a/Editor/LLMEditor.cs
+++ b/Editor/LLMEditor.cs
@@ -74,6 +74,7 @@ public void AddModelSettings(SerializedObject llmScriptSO)
if (llmScriptSO.FindProperty("advancedOptions").boolValue)
{
attributeClasses.Add(typeof(ModelAdvancedAttribute));
+ if (LLMUnitySetup.FullLlamaLib) attributeClasses.Add(typeof(ModelExtrasAttribute));
}
ShowPropertiesOfClass("", llmScriptSO, attributeClasses, false);
Space();
@@ -142,7 +143,7 @@ void SetModelIfNone(string filename, bool lora)
LLM llmScript = (LLM)target;
int num = LLMManager.Num(lora);
if (!lora && llmScript.model == "" && num == 1) llmScript.SetModel(filename);
- if (lora && llmScript.lora == "" && num == 1) llmScript.SetLora(filename);
+ if (lora) llmScript.AddLora(filename);
}
async Task createCustomURLField()
@@ -205,7 +206,7 @@ async Task createButtons()
}
else if (modelIndex > 1)
{
- if (modelLicenses[modelIndex] != null) Debug.LogWarning($"The {modelOptions[modelIndex]} model is released under the following license: {modelLicenses[modelIndex]}. By using this model, you agree to the terms of the license.");
+ if (modelLicenses[modelIndex] != null) LLMUnitySetup.LogWarning($"The {modelOptions[modelIndex]} model is released under the following license: {modelLicenses[modelIndex]}. By using this model, you agree to the terms of the license.");
string filename = await LLMManager.DownloadModel(modelURLs[modelIndex], true, modelOptions[modelIndex]);
SetModelIfNone(filename, false);
UpdateModels(true);
@@ -237,7 +238,7 @@ async Task createButtons()
{
EditorApplication.delayCall += () =>
{
- string path = EditorUtility.OpenFilePanelWithFilters("Select a bin lora file", "", new string[] { "Model Files", "bin" });
+ string path = EditorUtility.OpenFilePanelWithFilters("Select a gguf lora file", "", new string[] { "Model Files", "gguf" });
if (!string.IsNullOrEmpty(path))
{
string filename = LLMManager.LoadLora(path, true);
@@ -299,10 +300,10 @@ void OnEnable()
}
else
{
- isSelected = llmScript.lora == entry.filename;
- bool newSelected = EditorGUI.Toggle(selectRect, isSelected, EditorStyles.radioButton);
- if (newSelected && !isSelected) llmScript.SetLora(entry.filename);
- else if (!newSelected && isSelected) llmScript.SetLora("");
+ isSelected = llmScript.loraManager.Contains(entry.filename);
+ bool newSelected = EditorGUI.Toggle(selectRect, isSelected);
+ if (newSelected && !isSelected) llmScript.AddLora(entry.filename);
+ else if (!newSelected && isSelected) llmScript.RemoveLora(entry.filename);
}
DrawCopyableLabel(nameRect, entry.label, entry.filename);
@@ -347,6 +348,11 @@ void OnEnable()
if (GUI.Button(actionRect, trashIcon))
{
+ if (isSelected)
+ {
+ if (!entry.lora) llmScript.SetModel("");
+ else llmScript.RemoveLora(entry.filename);
+ }
LLMManager.Remove(entry);
UpdateModels(true);
}
diff --git a/Editor/PropertyEditor.cs b/Editor/PropertyEditor.cs
index 87a40938..d3510ac8 100644
--- a/Editor/PropertyEditor.cs
+++ b/Editor/PropertyEditor.cs
@@ -16,17 +16,11 @@ public void AddScript(SerializedObject llmScriptSO)
EditorGUILayout.PropertyField(scriptProp);
}
- public void AddOptionsToggle(SerializedObject llmScriptSO, string propertyName, string name)
+ public bool ToggleButton(string text, bool activated)
{
- SerializedProperty advancedOptionsProp = llmScriptSO.FindProperty(propertyName);
- string toggleText = (advancedOptionsProp.boolValue ? "Hide" : "Show") + " " + name;
GUIStyle style = new GUIStyle("Button");
- if (advancedOptionsProp.boolValue)
- style.normal = new GUIStyleState() { background = Texture2D.grayTexture };
- if (GUILayout.Button(toggleText, style, GUILayout.Width(buttonWidth)))
- {
- advancedOptionsProp.boolValue = !advancedOptionsProp.boolValue;
- }
+ if (activated) style.normal = new GUIStyleState() { background = Texture2D.grayTexture };
+ return GUILayout.Button(text, style, GUILayout.Width(buttonWidth));
}
public void AddSetupSettings(SerializedObject llmScriptSO)
@@ -54,8 +48,12 @@ public void AddChatSettings(SerializedObject llmScriptSO)
public void AddOptionsToggles(SerializedObject llmScriptSO)
{
LLMUnitySetup.SetDebugMode((LLMUnitySetup.DebugModeType)EditorGUILayout.EnumPopup("Log Level", LLMUnitySetup.DebugMode));
+
EditorGUILayout.BeginHorizontal();
- AddOptionsToggle(llmScriptSO, "advancedOptions", "Advanced Options");
+ SerializedProperty advancedOptionsProp = llmScriptSO.FindProperty("advancedOptions");
+ string toggleText = (advancedOptionsProp.boolValue ? "Hide" : "Show") + " Advanced Options";
+ if (ToggleButton(toggleText, advancedOptionsProp.boolValue)) advancedOptionsProp.boolValue = !advancedOptionsProp.boolValue;
+ if (ToggleButton("Use extras", LLMUnitySetup.FullLlamaLib)) LLMUnitySetup.SetFullLlamaLib(!LLMUnitySetup.FullLlamaLib);
EditorGUILayout.EndHorizontal();
Space();
}
diff --git a/README.md b/README.md
index 8b2e15d0..c4792950 100644
--- a/README.md
+++ b/README.md
@@ -46,8 +46,10 @@ LLM for Unity is built on top of the awesome [llama.cpp](https://github.com/gger
## How to help
- [⭐ Star](https://github.com/undreamai/LLMUnity) the repo, leave us a [review](https://assetstore.unity.com/packages/slug/273604) and spread the word about the project!
-- Join us at [Discord](https://discord.gg/RwXKQb6zdv) and say hi!
-- [Contribute](CONTRIBUTING.md) by submitting feature requests or bugs as issues or even submiting a PR and become a collaborator!
+- Join us at [Discord](https://discord.gg/RwXKQb6zdv) and say hi.
+- [Contribute](CONTRIBUTING.md) by submitting feature requests, bugs or even your own PR.
+- [Sponsor](https://github.com/sponsors/amakropoulos) this work to allow even cooler features!
+
## Games using LLM for Unity
- [Verbal Verdict](https://store.steampowered.com/app/2778780/Verbal_Verdict/)
@@ -56,6 +58,7 @@ LLM for Unity is built on top of the awesome [llama.cpp](https://github.com/gger
- [Murder in Aisle 4](https://roadedlich.itch.io/murder-in-aisle-4)
- [Finicky Food Delivery AI](https://helixngc7293.itch.io/finicky-food-delivery-ai)
- [AI Emotional Girlfriend](https://whynames.itch.io/aiemotionalgirlfriend)
+- [AI Speak](https://jdscogin.wixsite.com/aispeak)
## Setup
_Method 1: Install using the asset store_
@@ -247,8 +250,9 @@ public class MyScript : MonoBehaviour
// The model needs to be added to the LLM model manager (see LLM model management) by loading or downloading it.
// Otherwise the model file can be copied directly inside the StreamingAssets folder.
llm.SetModel("Phi-3-mini-4k-instruct-q4.gguf");
- // optional: you can also set a lora in a similar fashion
- llm.SetLora("my-lora.bin");
+    // optional: you can also set loras in a similar fashion and set their weights (if needed)
+    llm.AddLora("my-lora.gguf");
+    llm.SetLoraWeight("my-lora.gguf", 0.5f);
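+    // optional: the weights of several loras can be set at once with SetLoraWeights (illustrative)
+    // llm.SetLoraWeights(new Dictionary<string, float> { { "my-lora.gguf", 0.5f } });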
// optional: you can set the chat template of the model if it is not correctly identified
// You can find a list of chat templates in the ChatTemplate.templates.Keys
llm.SetTemplate("phi-3");
@@ -290,6 +294,15 @@ You can use a remote server to carry out the processing and implement characters
- Create a second project with the game characters using the `LLMCharacter` script as described above.
Enable the `Remote` option and configure the host with the IP address (starting with "http://") and port of the server.
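+
+A minimal sketch of the same configuration from code (the IP address is a placeholder and `remote` corresponds to the `Remote` toggle above; `host`, `port` and `numRetries` are `LLMCharacter` fields):
+``` c#
+    llmCharacter.remote = true;
+    llmCharacter.host = "http://192.168.1.10"; // IP address of the server
+    llmCharacter.port = 13333;                 // port of the server
+    llmCharacter.numRetries = 10;              // HTTP request retries (-1 = infinite)
+```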
+
+
+Compute embeddings using an LLM
+
+The `Embeddings` function can be used to obtain the embeddings of a phrase:
+``` c#
+    List<float> embeddings = await llmCharacter.Embeddings("hi, how are you?");
+```
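+
+For instance, two phrases can be compared with a cosine similarity over their embeddings (illustrative helper, not part of the API; assumes both embeddings have the same length):
+``` c#
+    List<float> a = await llmCharacter.Embeddings("hi, how are you?");
+    List<float> b = await llmCharacter.Embeddings("hello, how is it going?");
+    float dot = 0f, normA = 0f, normB = 0f;
+    for (int i = 0; i < a.Count; i++)
+    {
+        dot += a[i] * b[i];
+        normA += a[i] * a[i];
+        normB += b[i] * b[i];
+    }
+    float similarity = dot / (Mathf.Sqrt(normA) * Mathf.Sqrt(normB));
+```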
+
A detailed documentation on function level can be found here:
@@ -345,6 +358,7 @@ If you have loaded a model locally you need to set its URL through the expanded
- `Show/Hide Advanced Options` Toggle to show/hide advanced options from below
- `Log Level` select how verbose the log messages are
+- `Use extras` select to install and allow the use of extra features (flash attention and IQ quants)
#### 💻 Setup Settings
@@ -374,13 +388,15 @@ If the user's GPU is not supported, the LLM will fall back to the CPU
- Advanced options
- - `Download lora` click to download a LoRA model in .bin format
- - `Load lora` click to load a LoRA model in .bin format
+ - `Download lora` click to download a LoRA model in .gguf format
+ - `Load lora` click to load a LoRA model in .gguf format
- Context Size
size of the prompt context (0 = context size of the model)
This is the number of tokens the model can take as input when generating responses. Higher values use more RAM or VRAM (if using GPU).
- `Batch Size` batch size for prompt processing (default: 512)
- `Model` the path of the model being used (relative to the Assets/StreamingAssets folder)
- `Chat Template` the chat template being used for the LLM
- - `Lora` the path of the LoRA being used (relative to the Assets/StreamingAssets folder)
+  - `Lora` the paths of the LoRAs being used (relative to the Assets/StreamingAssets folder)
+ - `Lora Weights` the weights of the LoRAs being used
+ - `Flash Attention` click to use flash attention in the model (if `Use extras` is enabled)
@@ -395,6 +411,7 @@ If the user's GPU is not supported, the LLM will fall back to the CPU
- `Show/Hide Advanced Options` Toggle to show/hide advanced options from below
- `Log Level` select how verbose the log messages are
+- `Use extras` select to install and allow the use of extra features (flash attention and IQ quants)
#### 💻 Setup Settings
@@ -403,8 +420,9 @@ If the user's GPU is not supported, the LLM will fall back to the CPU
- `Remote` whether the LLM used is remote or local
- `LLM` the LLM GameObject (if `Remote` is not set)
-- `Hort` ip of the LLM (if `Remote` is set)
-- `Port` port of the LLM (if `Remote` is set)
+- `Host` IP address of the LLM server (if `Remote` is set)
+- `Port` port of the LLM server (if `Remote` is set)
+- `Num Retries` number of retries for HTTP requests to the LLM server (if `Remote` is set)
-
Save
save filename or relative path
If set, the chat history and LLM state (if save cache is enabled) are automatically saved to the file specified.
The chat history is saved with a json suffix and the LLM state with a cache suffix.
Both files are saved in the [persistentDataPath folder of Unity](https://docs.unity3d.com/ScriptReference/Application-persistentDataPath.html).
- `Save Cache` select to save the LLM state along with the chat history. The LLM state is typically around 100MB+.
- `Debug Prompt` select to log the constructed prompts in the Unity Editor
@@ -446,4 +464,6 @@ If it is not selected, the full reply from the model is received in one go
## License
-The license of LLM for Unity is MIT ([LICENSE.md](LICENSE.md)) and uses third-party software with MIT and Apache licenses ([Third Party Notices.md](
)).
+LLM for Unity is licensed under the MIT license ([LICENSE.md](LICENSE.md)) and uses third-party software under MIT and Apache licenses.
+Some models included in the asset define their own license terms; please review them before using each model.
+Third-party licenses can be found in [Third Party Notices.md](<Third Party Notices.md>).
diff --git a/Runtime/LLM.cs b/Runtime/LLM.cs
index 2f4c57cf..95c51cf2 100644
--- a/Runtime/LLM.cs
+++ b/Runtime/LLM.cs
@@ -10,18 +10,6 @@
namespace LLMUnity
{
- /// \cond HIDE
- public class LLMException : Exception
- {
- public int ErrorCode { get; private set; }
-
- public LLMException(string message, int errorCode) : base(message)
- {
- ErrorCode = errorCode;
- }
- }
- /// \endcond
-
[DefaultExecutionOrder(-1)]
/// @ingroup llm
///
@@ -68,9 +56,13 @@ public class LLM : MonoBehaviour
[ModelAdvanced] public string model = "";
/// Chat template used for the model
[ModelAdvanced] public string chatTemplate = ChatTemplate.DefaultTemplate;
- /// the path of the LORA model being used (relative to the Assets/StreamingAssets folder).
- /// Models with .bin format are allowed.
+ /// the paths of the LORA models being used (relative to the Assets/StreamingAssets folder).
+ /// Models with .gguf format are allowed.
[ModelAdvanced] public string lora = "";
+ /// the weights of the LORA models being used.
+ [ModelAdvanced] public string loraWeights = "";
+ /// enable use of flash attention
+ [ModelExtras] public bool flashAttention = false;
/// \cond HIDE
@@ -81,6 +73,10 @@ public class LLM : MonoBehaviour
Thread llmThread = null;
    List<StreamWrapper> streamWrappers = new List<StreamWrapper>();
public LLMManager llmManager = new LLMManager();
+ private readonly object startLock = new object();
+ public LoraManager loraManager = new LoraManager();
+ string loraPre = "";
+ string loraWeightsPre = "";
/// \endcond
@@ -89,6 +85,15 @@ public LLM()
LLMManager.Register(this);
}
+ void OnValidate()
+ {
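+        // re-parse the serialized lora fields when they are edited in the inspector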
+ if (lora != loraPre || loraWeights != loraWeightsPre)
+ {
+ loraManager.FromStrings(lora, loraWeights);
+ (loraPre, loraWeightsPre) = (lora, loraWeights);
+ }
+ }
+
///
/// The Unity Awake function that starts the LLM server.
/// The server can be started asynchronously if the asynchronousStartup option is set.
@@ -112,6 +117,7 @@ public async void Awake()
return;
}
await Task.Run(() => StartLLMServer(arguments));
+ if (!started) return;
if (dontDestroyOnLoad) DontDestroyOnLoad(transform.root.gameObject);
if (basePrompt != "") await SetBasePrompt(basePrompt);
}
@@ -128,35 +134,55 @@ public static async Task WaitUntilModelSetup(Callback downloadProgr
return !modelSetupFailed;
}
- public string GetModelLoraPath(string path)
+ public static string GetLLMManagerAsset(string path)
{
- string assetPath = LLMManager.GetAssetPath(path);
- if (!string.IsNullOrEmpty(assetPath)) return assetPath;
- return path;
+#if UNITY_EDITOR
+ if (!EditorApplication.isPlaying) return GetLLMManagerAssetEditor(path);
+#endif
+ return GetLLMManagerAssetRuntime(path);
}
- public string SetModelLoraPath(string path, bool lora)
+ public static string GetLLMManagerAssetEditor(string path)
{
+ // empty
if (string.IsNullOrEmpty(path)) return path;
+ // LLMManager - return location the file will be stored in StreamingAssets
ModelEntry modelEntry = LLMManager.Get(path);
if (modelEntry != null) return modelEntry.filename;
-
- string modelType = lora ? "Lora" : "Model";
- string assetPath = LLMUnitySetup.GetAssetPath(path);
+ // StreamingAssets - return relative location within StreamingAssets
+ string assetPath = LLMUnitySetup.GetAssetPath(path); // Note: this will return the full path if a full path is passed
+ string basePath = LLMUnitySetup.GetAssetPath();
+ if (File.Exists(assetPath))
+ {
+ if (LLMUnitySetup.IsSubPath(assetPath, basePath)) return LLMUnitySetup.RelativePath(assetPath, basePath);
+ }
+ // full path
if (!File.Exists(assetPath))
{
- LLMUnitySetup.LogError($"The {modelType} file {path} was not found.");
- return path;
+ LLMUnitySetup.LogError($"Model {path} was not found.");
}
-
- if (!LLMUnitySetup.IsSubPath(assetPath, LLMUnitySetup.GetAssetPath()))
+ else
{
- string errorMessage = $"The {modelType} file {path} was loaded locally. If you want to include it in the build:";
- errorMessage += $"\n-Copy the {modelType} inside the StreamingAssets folder and use its relative path or";
- errorMessage += $"\n-Load the {modelType} with the LLMManager: `string filename=LLMManager.Load{modelType}(path); llm.Set{modelType}(filename)`";
+ string errorMessage = $"The model {path} was loaded locally. You can include it in the build in one of these ways:";
+ errorMessage += $"\n-Copy the model inside the StreamingAssets folder and use its StreamingAssets path";
+ errorMessage += $"\n-Load the model with the model manager inside the LLM GameObject and use its filename";
LLMUnitySetup.LogWarning(errorMessage);
}
- return assetPath;
+ return path;
+ }
+
+ public static string GetLLMManagerAssetRuntime(string path)
+ {
+ // empty
+ if (string.IsNullOrEmpty(path)) return path;
+ // LLMManager
+ string managerPath = LLMManager.GetAssetPath(path);
+ if (!string.IsNullOrEmpty(managerPath) && File.Exists(managerPath)) return managerPath;
+ // StreamingAssets
+ string assetPath = LLMUnitySetup.GetAssetPath(path);
+ if (File.Exists(assetPath)) return assetPath;
+ // give up
+ return path;
}
///
@@ -167,12 +193,16 @@ public string SetModelLoraPath(string path, bool lora)
/// path to model to use (.gguf format)
public void SetModel(string path)
{
- model = SetModelLoraPath(path, false);
+ model = GetLLMManagerAsset(path);
if (!string.IsNullOrEmpty(model))
{
ModelEntry modelEntry = LLMManager.Get(model);
- string template = modelEntry != null ? modelEntry.chatTemplate : ChatTemplate.FromGGUF(GetModelLoraPath(model));
- SetTemplate(template);
+ if (modelEntry == null) modelEntry = new ModelEntry(GetLLMManagerAssetRuntime(model));
+ SetTemplate(modelEntry.chatTemplate);
+ if (contextSize == 0 && modelEntry.contextLength > 32768)
+ {
+ LLMUnitySetup.LogWarning($"The model {path} has very large context size ({modelEntry.contextLength}), consider setting it to a smaller value (<=32768) to avoid filling up the RAM");
+ }
}
#if UNITY_EDITOR
if (!EditorApplication.isPlaying) EditorUtility.SetDirty(this);
@@ -182,12 +212,78 @@ public void SetModel(string path)
///
/// Allows to set a LORA model to use in the LLM.
/// The model provided is copied to the Assets/StreamingAssets folder that allows it to also work in the build.
- /// Models supported are in .bin format.
+ /// Models supported are in .gguf format.
+ ///
+ /// path to LORA model to use (.gguf format)
+ public void SetLora(string path, float weight = 1)
+ {
+ AssertNotStarted();
+ loraManager.Clear();
+ AddLora(path, weight);
+ }
+
+ ///
+ /// Allows to add a LORA model to use in the LLM.
+ /// The model provided is copied to the Assets/StreamingAssets folder that allows it to also work in the build.
+ /// Models supported are in .gguf format.
+ ///
+ /// path to LORA model to use (.gguf format)
+ public void AddLora(string path, float weight = 1)
+ {
+ AssertNotStarted();
+ loraManager.Add(path, weight);
+ UpdateLoras();
+ }
+
+ ///
+ /// Allows to remove a LORA model from the LLM.
+ /// Models supported are in .gguf format.
///
- /// path to LORA model to use (.bin format)
- public void SetLora(string path)
+ /// path to LORA model to remove (.gguf format)
+ public void RemoveLora(string path)
{
- lora = SetModelLoraPath(path, true);
+ AssertNotStarted();
+ loraManager.Remove(path);
+ UpdateLoras();
+ }
+
+ ///
+ /// Allows to remove all LORA models from the LLM.
+ ///
+ public void RemoveLoras()
+ {
+ AssertNotStarted();
+ loraManager.Clear();
+ UpdateLoras();
+ }
+
+ ///
+ /// Allows to change the weight (scale) of a LORA model in the LLM.
+ ///
+ /// path of LORA model to change (.gguf format)
+ /// weight of LORA
+ public void SetLoraWeight(string path, float weight)
+ {
+ loraManager.SetWeight(path, weight);
+ UpdateLoras();
+ if (started) ApplyLoras();
+ }
+
+ ///
+ /// Allows to change the weights (scale) of the LORA models in the LLM.
+ ///
+ /// Dictionary (string, float) mapping the path of LORA models with weights to change
+    public void SetLoraWeights(Dictionary<string, float> loraToWeight)
+    {
+        foreach (KeyValuePair<string, float> entry in loraToWeight) loraManager.SetWeight(entry.Key, entry.Value);
+ UpdateLoras();
+ if (started) ApplyLoras();
+ }
+
+ public void UpdateLoras()
+ {
+ (lora, loraWeights) = loraManager.ToStrings();
+ (loraPre, loraWeightsPre) = (lora, loraWeights);
#if UNITY_EDITOR
if (!EditorApplication.isPlaying) EditorUtility.SetDirty(this);
#endif
@@ -223,22 +319,25 @@ protected virtual string GetLlamaccpArguments()
LLMUnitySetup.LogError("No model file provided!");
return null;
}
- string modelPath = GetModelLoraPath(model);
+ string modelPath = GetLLMManagerAssetRuntime(model);
if (!File.Exists(modelPath))
{
LLMUnitySetup.LogError($"File {modelPath} not found!");
return null;
}
- string loraPath = "";
- if (lora != "")
+ string loraArgument = "";
+ foreach (string lora in lora.Trim().Split(" "))
{
- loraPath = GetModelLoraPath(lora);
+ if (lora == "") continue;
+ string loraPath = GetLLMManagerAssetRuntime(lora);
if (!File.Exists(loraPath))
{
LLMUnitySetup.LogError($"File {loraPath} not found!");
return null;
}
+ loraArgument += $" --lora \"{loraPath}\"";
}
+ loraManager.FromStrings(lora, loraWeights);
int numThreadsToUse = numThreads;
if (Application.platform == RuntimePlatform.Android && numThreads <= 0) numThreadsToUse = LLMUnitySetup.AndroidGetNumBigCores();
@@ -247,8 +346,9 @@ protected virtual string GetLlamaccpArguments()
string arguments = $"-m \"{modelPath}\" -c {contextSize} -b {batchSize} --log-disable -np {slots}";
if (remote) arguments += $" --port {port} --host 0.0.0.0";
if (numThreadsToUse > 0) arguments += $" -t {numThreadsToUse}";
- if (loraPath != "") arguments += $" --lora \"{loraPath}\"";
+ arguments += loraArgument;
arguments += $" -ngl {numGPULayers}";
+ if (LLMUnitySetup.FullLlamaLib && flashAttention) arguments += $" --flash-attn";
return arguments;
}
@@ -278,7 +378,7 @@ private void StartLLMServer(string arguments)
try
{
InitLib(arch);
- InitServer(arguments);
+ InitService(arguments);
LLMUnitySetup.Log($"Using architecture: {arch}");
break;
}
@@ -287,6 +387,10 @@ private void StartLLMServer(string arguments)
error = e.Message;
Destroy();
}
+ catch (DestroyException)
+ {
+ break;
+ }
catch (Exception e)
{
error = $"{e.GetType()}: {e.Message}";
@@ -299,7 +403,7 @@ private void StartLLMServer(string arguments)
failed = true;
return;
}
- StartService();
+ CallWithLock(StartService);
LLMUnitySetup.Log("LLM service created");
}
@@ -309,13 +413,22 @@ private void InitLib(string arch)
CheckLLMStatus(false);
}
- private void InitServer(string arguments)
+ void CallWithLock(EmptyCallback fn, bool checkNull = true)
{
- if (debug) SetupLogging();
- LLMObject = llmlib.LLM_Construct(arguments);
- if (remote) llmlib.LLM_StartServer(LLMObject);
- SetTemplate(chatTemplate, false);
- CheckLLMStatus(false);
+ lock (startLock)
+ {
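+            // if Destroy() ran while this call waited on the lock, llmlib is null: abort startup via DestroyException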
+ if (checkNull && llmlib == null) throw new DestroyException();
+ fn();
+ }
+ }
+
+ private void InitService(string arguments)
+ {
+ if (debug) CallWithLock(SetupLogging);
+ CallWithLock(() => { LLMObject = llmlib.LLM_Construct(arguments); });
+ CallWithLock(() => llmlib.LLM_SetTemplate(LLMObject, chatTemplate));
+ if (remote) CallWithLock(() => llmlib.LLM_StartServer(LLMObject));
+ CallWithLock(() => CheckLLMStatus(false));
}
private void StartService()
@@ -323,6 +436,7 @@ private void StartService()
llmThread = new Thread(() => llmlib.LLM_Start(LLMObject));
llmThread.Start();
while (!llmlib.LLM_Started(LLMObject)) {}
+ ApplyLoras();
started = true;
}
@@ -345,7 +459,7 @@ protected int GetNumClients()
/// \cond HIDE
public delegate void LLMStatusCallback(IntPtr LLMObject, IntPtr stringWrapper);
- public delegate void LLMSimpleCallback(IntPtr LLMObject, string json_data);
+ public delegate void LLMNoInputReplyCallback(IntPtr LLMObject, IntPtr stringWrapper);
public delegate void LLMReplyCallback(IntPtr LLMObject, string json_data, IntPtr stringWrapper);
/// \endcond
@@ -381,9 +495,19 @@ void AssertStarted()
}
}
+ void AssertNotStarted()
+ {
+ if (started)
+ {
+ string error = "This method can't be called when the LLM has started";
+ LLMUnitySetup.LogError(error);
+ throw new Exception(error);
+ }
+ }
+
void CheckLLMStatus(bool log = true)
{
- if (llmlib == null) {return;}
+ if (llmlib == null) { return; }
IntPtr stringWrapper = llmlib.StringWrapper_Construct();
int status = llmlib.LLM_Status(LLMObject, stringWrapper);
string result = llmlib.GetStringWrapperResult(stringWrapper);
@@ -400,6 +524,17 @@ void CheckLLMStatus(bool log = true)
}
}
+    async Task<string> LLMNoInputReply(LLMNoInputReplyCallback callback)
+ {
+ AssertStarted();
+ IntPtr stringWrapper = llmlib.StringWrapper_Construct();
+ await Task.Run(() => callback(LLMObject, stringWrapper));
+ string result = llmlib?.GetStringWrapperResult(stringWrapper);
+ llmlib?.StringWrapper_Delete(stringWrapper);
+ CheckLLMStatus();
+ return result;
+ }
+
    async Task<string> LLMReply(LLMReplyCallback callback, string json)
{
AssertStarted();
@@ -441,6 +576,62 @@ public async Task<string> Detokenize(string json)
return await LLMReply(callback, json);
}
+ ///
+ /// Computes the embeddings of the provided query.
+ ///
+ /// json request containing the query
+ /// embeddings result
+    public async Task<string> Embeddings(string json)
+ {
+ AssertStarted();
+ LLMReplyCallback callback = (IntPtr LLMObject, string jsonData, IntPtr strWrapper) =>
+ {
+ llmlib.LLM_Embeddings(LLMObject, jsonData, strWrapper);
+ };
+ return await LLMReply(callback, json);
+ }
+
+ ///
+    /// Sets the lora weights (scale); only works after the LLM service has started
+    ///
+ public void ApplyLoras()
+ {
+ LoraWeightRequestList loraWeightRequest = new LoraWeightRequestList();
+        loraWeightRequest.loraWeights = new List<LoraWeightRequest>();
+ float[] weights = loraManager.GetWeights();
+ for (int i = 0; i < weights.Length; i++)
+ {
+ loraWeightRequest.loraWeights.Add(new LoraWeightRequest() { id = i, scale = weights[i] });
+ }
+
+ string json = JsonUtility.ToJson(loraWeightRequest);
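+        // JsonUtility cannot serialize a bare list: serialize the wrapper object and keep only the [...] array part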
+ int startIndex = json.IndexOf("[");
+ int endIndex = json.LastIndexOf("]") + 1;
+ json = json.Substring(startIndex, endIndex - startIndex);
+
+ IntPtr stringWrapper = llmlib.StringWrapper_Construct();
+ llmlib.LLM_Lora_Weight(LLMObject, json, stringWrapper);
+ llmlib.StringWrapper_Delete(stringWrapper);
+ }
+
+ ///
+ /// Gets a list of the lora adapters
+ ///
+    /// list of lora adapters
+    public async Task<List<LoraWeightResult>> ListLoras()
+ {
+ AssertStarted();
+ LLMNoInputReplyCallback callback = (IntPtr LLMObject, IntPtr strWrapper) =>
+ {
+ llmlib.LLM_LoraList(LLMObject, strWrapper);
+ };
+ string json = await LLMNoInputReply(callback);
+ if (String.IsNullOrEmpty(json)) return null;
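+        // JsonUtility cannot parse a bare JSON array: wrap it in an object before deserializing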
+        LoraWeightResultList loraRequest = JsonUtility.FromJson<LoraWeightResultList>("{\"loraWeights\": " + json + "}");
+ return loraRequest.loraWeights;
+ }
+
///
/// Allows to save / restore the state of a slot
///
@@ -479,7 +670,7 @@ public async Task Completion(string json, Callback streamCallbac
public async Task SetBasePrompt(string base_prompt)
{
AssertStarted();
- SystemPromptRequest request = new SystemPromptRequest(){system_prompt = base_prompt, prompt = " ", n_predict = 0};
+ SystemPromptRequest request = new SystemPromptRequest() { system_prompt = base_prompt, prompt = " ", n_predict = 0 };
await Completion(JsonUtility.ToJson(request));
}
@@ -499,29 +690,32 @@ public void CancelRequest(int id_slot)
///
public void Destroy()
{
- try
+ CallWithLock(() =>
{
- if (llmlib != null)
+ try
{
- if (LLMObject != IntPtr.Zero)
+ if (llmlib != null)
{
- llmlib.LLM_Stop(LLMObject);
- if (remote) llmlib.LLM_StopServer(LLMObject);
- StopLogging();
- llmThread?.Join();
- llmlib.LLM_Delete(LLMObject);
- LLMObject = IntPtr.Zero;
+ if (LLMObject != IntPtr.Zero)
+ {
+ llmlib.LLM_Stop(LLMObject);
+ if (remote) llmlib.LLM_StopServer(LLMObject);
+ StopLogging();
+ llmThread?.Join();
+ llmlib.LLM_Delete(LLMObject);
+ LLMObject = IntPtr.Zero;
+ }
+ llmlib.Destroy();
+ llmlib = null;
}
- llmlib.Destroy();
+ started = false;
+ failed = false;
}
- started = false;
- failed = false;
- llmlib = null;
- }
- catch (Exception e)
- {
- LLMUnitySetup.LogError(e.Message);
- }
+ catch (Exception e)
+ {
+ LLMUnitySetup.LogError(e.Message);
+ }
+ }, false);
}
///
diff --git a/Runtime/LLMBuilder.cs b/Runtime/LLMBuilder.cs
index a3cf0070..542520f9 100644
--- a/Runtime/LLMBuilder.cs
+++ b/Runtime/LLMBuilder.cs
@@ -104,15 +104,19 @@ static void AddActionAddMeta(string target)
public static void HideLibraryPlatforms(string platform)
{
-        List<string> platforms = new List<string>(){ "windows", "macos", "linux", "android", "ios" };
+        List<string> platforms = new List<string>(){ "windows", "macos", "linux", "android", "ios", "setup" };
platforms.Remove(platform);
foreach (string source in Directory.GetDirectories(LLMUnitySetup.libraryPath))
{
+ string sourceName = Path.GetFileName(source);
foreach (string platformPrefix in platforms)
{
- if (Path.GetFileName(source).StartsWith(platformPrefix))
+ bool move = sourceName.StartsWith(platformPrefix);
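+                // also hide the cuda build variant that does not match the "Use extras" (full library) setting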
+ move = move || (sourceName.Contains("cuda") && !sourceName.Contains("full") && LLMUnitySetup.FullLlamaLib);
+ move = move || (sourceName.Contains("cuda") && sourceName.Contains("full") && !LLMUnitySetup.FullLlamaLib);
+ if (move)
{
- string target = Path.Combine(BuildTempDir, Path.GetFileName(source));
+ string target = Path.Combine(BuildTempDir, sourceName);
MoveAction(source, target);
MoveAction(source + ".meta", target + ".meta");
}
diff --git a/Runtime/LLMCharacter.cs b/Runtime/LLMCharacter.cs
index 9f37721a..c5ce99fd 100644
--- a/Runtime/LLMCharacter.cs
+++ b/Runtime/LLMCharacter.cs
@@ -27,6 +27,8 @@ public class LLMCharacter : MonoBehaviour
[Remote] public string host = "localhost";
/// port to use for the LLM server
[Remote] public int port = 13333;
+ /// number of retries to use for the LLM server requests (-1 = infinite)
+ [Remote] public int numRetries = -1;
/// file to save the chat history.
/// The file is saved only for Chat calls with addToHistory set to true.
/// The file will be saved within the persistentDataPath directory (see https://docs.unity3d.com/ScriptReference/Application-persistentDataPath.html).
@@ -118,7 +120,7 @@ public class LLMCharacter : MonoBehaviour
    public List<ChatMessage> chat;
private SemaphoreSlim chatLock = new SemaphoreSlim(1, 1);
private string chatTemplate;
- private ChatTemplate template;
+ private ChatTemplate template = null;
public string grammarString;
protected int id_slot = -1;
private List<(string, string)> requestHeaders = new List<(string, string)> { ("Content-Type", "application/json") };
@@ -270,11 +272,22 @@ public void SetPrompt(string newPrompt, bool clearChat = true)
InitPrompt(clearChat);
}
+ private bool CheckTemplate()
+ {
+ if (template == null)
+ {
+ LLMUnitySetup.LogError("Template not set!");
+ return false;
+ }
+ return true;
+ }
+
private async Task InitNKeep()
{
if (setNKeepToPrompt && nKeep == -1)
{
-            string systemPrompt = template.ComputePrompt(new List<ChatMessage>(){chat[0]}, "", false);
+            if (!CheckTemplate()) return;
+            string systemPrompt = template.ComputePrompt(new List<ChatMessage>(){chat[0]}, playerName, "", false);
await Tokenize(systemPrompt, SetNKeep);
}
}
@@ -311,7 +324,8 @@ public async Task LoadTemplate()
if (llmTemplate != chatTemplate)
{
chatTemplate = llmTemplate;
- template = ChatTemplate.GetTemplate(chatTemplate);
+ template = chatTemplate == null ? null : ChatTemplate.GetTemplate(chatTemplate);
+ nKeep = -1;
}
}
@@ -331,6 +345,7 @@ public async void SetGrammar(string path)
    List<string> GetStopwords()
{
+ if (!CheckTemplate()) return null;
        List<string> stopAll = new List<string>(template.GetStop(playerName, AIName));
if (stop != null) stopAll.AddRange(stop);
return stopAll;
@@ -430,16 +445,22 @@ protected List<int> TokenizeContent(TokenizeResult result)
return result.tokens;
}
- protected string SlotContent(SlotResult result)
+ protected string DetokenizeContent(TokenizeRequest result)
{
- // get the tokens from a tokenize result received from the endpoint
- return result.filename;
+        // get the content from a detokenize result received from the endpoint
+ return result.content;
}
- protected string DetokenizeContent(TokenizeRequest result)
+    protected List<float> EmbeddingsContent(EmbeddingsResult result)
{
// get content from a chat result received from the endpoint
- return result.content;
+ return result.embedding;
+ }
+
+ protected string SlotContent(SlotResult result)
+ {
+        // get the filename from a slot result received from the endpoint
+ return result.filename;
}
///
@@ -459,6 +480,7 @@ public async Task Chat(string query, Callback callback = null, E
// call the callback function while the answer is received
// call the completionCallback function when the answer is fully received
await LoadTemplate();
+ if (!CheckTemplate()) return null;
await InitNKeep();
string json;
@@ -466,7 +488,7 @@ public async Task Chat(string query, Callback callback = null, E
try
{
AddPlayerMessage(query);
- string prompt = template.ComputePrompt(chat, AIName);
+ string prompt = template.ComputePrompt(chat, playerName, AIName);
json = JsonUtility.ToJson(GenerateRequest(prompt));
chat.RemoveAt(chat.Count - 1);
}
@@ -572,6 +594,21 @@ public async Task Detokenize(List tokens, Callback callback
        return await PostRequest<TokenizeRequest, string>(json, "detokenize", DetokenizeContent, callback);
}
+ ///
+ /// Computes the embeddings of the provided input.
+ ///
+ /// input to compute the embeddings for
+    /// callback function called with the computed embeddings
+ /// the computed embeddings
+    public async Task<List<float>> Embeddings(string query, Callback<List<float>> callback = null)
+    {
+        // reuse the tokenize request payload to send the query content to the embeddings endpoint
+        TokenizeRequest tokenizeRequest = new TokenizeRequest();
+        tokenizeRequest.content = query;
+        string json = JsonUtility.ToJson(tokenizeRequest);
+        return await PostRequest<EmbeddingsResult, List<float>>(json, "embeddings", EmbeddingsContent, callback);
+ }
+
private async Task Slot(string filepath, string action)
{
SlotRequest slotRequest = new SlotRequest();
@@ -682,6 +719,9 @@ protected async Task PostRequestLocal(string json, string endpoin
case "detokenize":
callResult = await llm.Detokenize(json);
break;
+ case "embeddings":
+ callResult = await llm.Embeddings(json);
+ break;
case "slots":
callResult = await llm.Slot(json);
break;
@@ -726,38 +766,57 @@ protected async Task PostRequestRemote(string json, string endpoi
Ret result = default;
byte[] jsonToSend = new System.Text.UTF8Encoding().GetBytes(json);
- using (var request = UnityWebRequest.Put($"{host}:{port}/{endpoint}", jsonToSend))
- {
- WIPRequests.Add(request);
+ UnityWebRequest request = null;
+ string error = null;
+ int tryNr = numRetries;
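+        // countdown to 0; with numRetries = -1 the request is retried indefinitely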
- request.method = "POST";
- if (requestHeaders != null)
+ while (tryNr != 0)
+ {
+ using (request = UnityWebRequest.Put($"{host}:{port}/{endpoint}", jsonToSend))
{
- for (int i = 0; i < requestHeaders.Count; i++)
- request.SetRequestHeader(requestHeaders[i].Item1, requestHeaders[i].Item2);
- }
+ WIPRequests.Add(request);
- // Start the request asynchronously
- var asyncOperation = request.SendWebRequest();
- float lastProgress = 0f;
- // Continue updating progress until the request is completed
- while (!asyncOperation.isDone)
- {
- float currentProgress = request.downloadProgress;
- // Check if progress has changed
- if (currentProgress != lastProgress && callback != null)
+ request.method = "POST";
+ if (requestHeaders != null)
+ {
+ for (int i = 0; i < requestHeaders.Count; i++)
+ request.SetRequestHeader(requestHeaders[i].Item1, requestHeaders[i].Item2);
+ }
+
+ // Start the request asynchronously
+ var asyncOperation = request.SendWebRequest();
+ float lastProgress = 0f;
+ // Continue updating progress until the request is completed
+ while (!asyncOperation.isDone)
{
- callback?.Invoke(ConvertContent(request.downloadHandler.text, getContent));
- lastProgress = currentProgress;
+ float currentProgress = request.downloadProgress;
+ // Check if progress has changed
+ if (currentProgress != lastProgress && callback != null)
+ {
+ callback?.Invoke(ConvertContent(request.downloadHandler.text, getContent));
+ lastProgress = currentProgress;
+ }
+ // Wait for the next frame
+ await Task.Yield();
+ }
+ WIPRequests.Remove(request);
+ if (request.result == UnityWebRequest.Result.Success)
+ {
+ result = ConvertContent(request.downloadHandler.text, getContent);
+ error = null;
+ break;
+ }
+ else
+ {
+ result = default;
+ error = request.error;
}
- // Wait for the next frame
- await Task.Yield();
}
- WIPRequests.Remove(request);
- if (request.result != UnityWebRequest.Result.Success) LLMUnitySetup.LogError(request.error);
- else result = ConvertContent(request.downloadHandler.text, getContent);
- callback?.Invoke(result);
+ tryNr--;
}
+
+ if (error != null) LLMUnitySetup.LogError(error);
+ callback?.Invoke(result);
return result;
}
diff --git a/Runtime/LLMChatTemplates.cs b/Runtime/LLMChatTemplates.cs
index f6aa22cc..0820078d 100644
--- a/Runtime/LLMChatTemplates.cs
+++ b/Runtime/LLMChatTemplates.cs
@@ -1,5 +1,6 @@
/// @file
/// @brief File implementing the chat templates.
+using System;
using System.Collections.Generic;
using System.IO;
using UnityEngine;
@@ -36,6 +37,7 @@ static ChatTemplate()
{
new ChatMLTemplate(),
new AlpacaTemplate(),
+ new GemmaTemplate(),
new MistralChatTemplate(),
new MistralInstructTemplate(),
new LLama3ChatTemplate(),
@@ -113,9 +115,12 @@ public static string FromTemplate(string template)
/// template name
public static string FromGGUF(string path)
{
- GGUFReader reader = new GGUFReader(path);
- string name;
+ return FromGGUF(new GGUFReader(path), path);
+ }
+ public static string FromGGUF(GGUFReader reader, string path)
+ {
+ string name;
name = FromTemplate(reader.GetStringField("tokenizer.chat_template"));
if (name != null) return name;
@@ -165,7 +170,7 @@ public static ChatTemplate GetTemplate(string template)
/// the AI name
/// whether to end the prompt with the AI prefix
/// prompt
-        public virtual string ComputePrompt(List<ChatMessage> messages, string AIName, bool endWithPrefix = true)
+        public virtual string ComputePrompt(List<ChatMessage> messages, string playerName, string AIName, bool endWithPrefix = true)
{
string chatPrompt = PromptPrefix();
int start = 0;
@@ -332,6 +337,47 @@ public override string[] GetStop(string playerName, string AIName)
}
}
+ /// @ingroup template
+ ///
+ /// Class implementing the Gemma template
+ ///
+ public class GemmaTemplate : ChatTemplate
+ {
+ public override string GetName() { return "gemma"; }
+ public override string GetDescription() { return "gemma"; }
+ public override string[] GetNameMatches() { return new string[] {"gemma"}; }
+
+        protected override string RequestSuffix() { return "<end_of_turn>\n"; }
+        protected override string PairSuffix() { return "<end_of_turn>\n"; }
+
+        protected override string PlayerPrefix(string playerName) { return "<start_of_turn>" + playerName + "\n"; }
+        protected override string AIPrefix(string AIName) { return "<start_of_turn>" + AIName + "\n"; }
+
+        public override string ComputePrompt(List<ChatMessage> messages, string playerName, string AIName, bool endWithPrefix = true)
+ {
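+            // Gemma has no system role: fold a system prompt into the first user message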
+            List<ChatMessage> messagesSystemPrompt = messages;
+ if (messages[0].role == "system")
+ {
+ string firstUserMessage = messages[0].content;
+ int start = 1;
+ if (messages.Count > 1)
+ {
+ if (firstUserMessage != "") firstUserMessage += "\n\n";
+ firstUserMessage += messages[1].content;
+ start = 2;
+ }
+                messagesSystemPrompt = new List<ChatMessage>(){new ChatMessage { role = playerName, content = firstUserMessage }};
+ messagesSystemPrompt.AddRange(messages.GetRange(start, messages.Count - start));
+ }
+ return base.ComputePrompt(messagesSystemPrompt, playerName, AIName, endWithPrefix);
+ }
+
+ public override string[] GetStop(string playerName, string AIName)
+ {
+            return AddStopNewlines(new string[] { "<start_of_turn>", "<end_of_turn>" });
+ }
+ }
+
/// @ingroup template
///
/// Class implementing the Alpaca template
@@ -417,7 +463,7 @@ public class Phi3Template : ChatTemplate
protected override string PairSuffix() { return "<|end|>\n"; }
-        public override string ComputePrompt(List<ChatMessage> messages, string AIName, bool endWithPrefix = true)
+        public override string ComputePrompt(List<ChatMessage> messages, string playerName, string AIName, bool endWithPrefix = true)
{
        List<ChatMessage> messagesSystemPrompt = messages;
if (messages[0].role == "system")
@@ -433,7 +479,7 @@ public override string ComputePrompt(List messages, string AIName,
            messagesSystemPrompt = new List<ChatMessage>(){new ChatMessage { role = "user", content = firstUserMessage }};
messagesSystemPrompt.AddRange(messages.GetRange(start, messages.Count - start));
}
- return base.ComputePrompt(messagesSystemPrompt, AIName, endWithPrefix);
+ return base.ComputePrompt(messagesSystemPrompt, playerName, AIName, endWithPrefix);
}
public override string[] GetStop(string playerName, string AIName)
diff --git a/Runtime/LLMGGUF.cs b/Runtime/LLMGGUF.cs
index 4db81061..2ebeaf80 100644
--- a/Runtime/LLMGGUF.cs
+++ b/Runtime/LLMGGUF.cs
@@ -125,6 +125,13 @@ public ReaderField GetField(string key)
return null;
}
+ public byte[] GetGenericField(string key)
+ {
+ ReaderField field = GetField(key);
+ if (field == null || field.parts.Count == 0) return null;
+ return (byte[])field.parts[field.parts.Count - 1];
+ }
+
///
/// Allows to retrieve a string GGUF field.
///
@@ -132,9 +139,21 @@ public ReaderField GetField(string key)
/// Retrieved GGUF value
public string GetStringField(string key)
{
- ReaderField field = GetField(key);
- if (field == null || field.parts.Count == 0) return null;
- return System.Text.Encoding.UTF8.GetString((byte[])field.parts[field.parts.Count - 1]);
+ byte[] value = GetGenericField(key);
+ if (value == null) return null;
+ return System.Text.Encoding.UTF8.GetString(value);
+ }
+
+ ///
+ /// Allows to retrieve an integer GGUF field.
+ ///
+ /// GGUF field to retrieve
+ /// Retrieved GGUF value
+ public int GetIntField(string key)
+ {
+ byte[] value = GetGenericField(key);
+ if (value == null) return -1;
+ return BitConverter.ToInt32(value, 0);
}
private byte[] ReadBytes(int offset, int count)
diff --git a/Runtime/LLMInterface.cs b/Runtime/LLMInterface.cs
index f5086595..a9090632 100644
--- a/Runtime/LLMInterface.cs
+++ b/Runtime/LLMInterface.cs
@@ -93,6 +93,39 @@ public struct TokenizeResult
        public List<int> tokens;
}
+ [Serializable]
+ public struct EmbeddingsResult
+ {
+        public List<float> embedding;
+ }
+
+ [Serializable]
+ public struct LoraWeightRequest
+ {
+ public int id;
+ public float scale;
+ }
+
+ [Serializable]
+ public struct LoraWeightRequestList
+ {
+        public List<LoraWeightRequest> loraWeights;
+ }
+
+ [Serializable]
+ public struct LoraWeightResult
+ {
+ public int id;
+ public string path;
+ public float scale;
+ }
+
+ [Serializable]
+ public struct LoraWeightResultList
+ {
+        public List<LoraWeightResult> loraWeights;
+ }
+
[Serializable]
public struct TemplateResult
{
diff --git a/Runtime/LLMLib.cs b/Runtime/LLMLib.cs
index c8d49ccd..e3501769 100644
--- a/Runtime/LLMLib.cs
+++ b/Runtime/LLMLib.cs
@@ -281,7 +281,6 @@ static LLMLib()
public LLMLib(string arch)
{
- LLMUnitySetup.Log(GetArchitecturePath(arch));
libraryHandle = LibraryLoader.LoadLibrary(GetArchitecturePath(arch));
if (libraryHandle == IntPtr.Zero)
{
@@ -298,6 +297,9 @@ public LLMLib(string arch)
        LLM_SetTemplate = LibraryLoader.GetSymbolDelegate<LLM_SetTemplateDelegate>(libraryHandle, "LLM_SetTemplate");
        LLM_Tokenize = LibraryLoader.GetSymbolDelegate<LLM_TokenizeDelegate>(libraryHandle, "LLM_Tokenize");
        LLM_Detokenize = LibraryLoader.GetSymbolDelegate<LLM_DetokenizeDelegate>(libraryHandle, "LLM_Detokenize");
+        LLM_Embeddings = LibraryLoader.GetSymbolDelegate<LLM_EmbeddingsDelegate>(libraryHandle, "LLM_Embeddings");
+        LLM_Lora_Weight = LibraryLoader.GetSymbolDelegate<LLM_LoraWeightDelegate>(libraryHandle, "LLM_Lora_Weight");
+        LLM_LoraList = LibraryLoader.GetSymbolDelegate<LLM_LoraListDelegate>(libraryHandle, "LLM_Lora_List");
        LLM_Completion = LibraryLoader.GetSymbolDelegate<LLM_CompletionDelegate>(libraryHandle, "LLM_Completion");
        LLM_Slot = LibraryLoader.GetSymbolDelegate<LLM_SlotDelegate>(libraryHandle, "LLM_Slot");
        LLM_Cancel = LibraryLoader.GetSymbolDelegate<LLM_CancelDelegate>(libraryHandle, "LLM_Cancel");
@@ -324,8 +326,16 @@ public static List<string> PossibleArchitectures(bool gpu = false)
{
if (gpu)
{
- architectures.Add("cuda-cu12.2.0");
- architectures.Add("cuda-cu11.7.1");
+ if (LLMUnitySetup.FullLlamaLib)
+ {
+ architectures.Add("cuda-cu12.2.0-full");
+ architectures.Add("cuda-cu11.7.1-full");
+ }
+ else
+ {
+ architectures.Add("cuda-cu12.2.0");
+ architectures.Add("cuda-cu11.7.1");
+ }
architectures.Add("hip");
architectures.Add("vulkan");
}
@@ -452,6 +462,9 @@ public string GetStringWrapperResult(IntPtr stringWrapper)
public delegate void LLM_SetTemplateDelegate(IntPtr LLMObject, string chatTemplate);
public delegate void LLM_TokenizeDelegate(IntPtr LLMObject, string jsonData, IntPtr stringWrapper);
public delegate void LLM_DetokenizeDelegate(IntPtr LLMObject, string jsonData, IntPtr stringWrapper);
+ public delegate void LLM_EmbeddingsDelegate(IntPtr LLMObject, string jsonData, IntPtr stringWrapper);
+ public delegate void LLM_LoraWeightDelegate(IntPtr LLMObject, string jsonData, IntPtr stringWrapper);
+ public delegate void LLM_LoraListDelegate(IntPtr LLMObject, IntPtr stringWrapper);
public delegate void LLM_CompletionDelegate(IntPtr LLMObject, string jsonData, IntPtr stringWrapper);
public delegate void LLM_SlotDelegate(IntPtr LLMObject, string jsonData, IntPtr stringWrapper);
public delegate void LLM_CancelDelegate(IntPtr LLMObject, int idSlot);
@@ -474,6 +487,9 @@ public string GetStringWrapperResult(IntPtr stringWrapper)
public LLM_TokenizeDelegate LLM_Tokenize;
public LLM_DetokenizeDelegate LLM_Detokenize;
public LLM_CompletionDelegate LLM_Completion;
+ public LLM_EmbeddingsDelegate LLM_Embeddings;
+ public LLM_LoraWeightDelegate LLM_Lora_Weight;
+ public LLM_LoraListDelegate LLM_LoraList;
public LLM_SlotDelegate LLM_Slot;
public LLM_CancelDelegate LLM_Cancel;
public LLM_StatusDelegate LLM_Status;
diff --git a/Runtime/LLMManager.cs b/Runtime/LLMManager.cs
index ff1101d8..f08edc87 100644
--- a/Runtime/LLMManager.cs
+++ b/Runtime/LLMManager.cs
@@ -17,17 +17,36 @@ public class ModelEntry
public string chatTemplate;
public string url;
public bool includeInBuild;
+ public int contextLength;
+ public static string GetFilenameOrRelativeAssetPath(string path)
+ {
+ string assetPath = LLMUnitySetup.GetAssetPath(path); // Note: this will return the full path if a full path is passed
+ string basePath = LLMUnitySetup.GetAssetPath();
+ if (File.Exists(assetPath) && LLMUnitySetup.IsSubPath(assetPath, basePath))
+ {
+ return LLMUnitySetup.RelativePath(assetPath, basePath);
+ }
+ return path;
+ }
public ModelEntry(string path, bool lora = false, string label = null, string url = null)
{
- filename = Path.GetFileName(path);
+ filename = GetFilenameOrRelativeAssetPath(path);
this.label = label == null ? filename : label;
this.lora = lora;
- this.path = Path.GetFullPath(path).Replace('\\', '/');
- chatTemplate = lora ? null : ChatTemplate.FromGGUF(this.path);
+ this.path = LLMUnitySetup.GetFullPath(path);
this.url = url;
includeInBuild = true;
+ chatTemplate = null;
+ contextLength = -1;
+ if (!lora)
+ {
+ GGUFReader reader = new GGUFReader(this.path);
+ chatTemplate = ChatTemplate.FromGGUF(reader, this.path);
+ string arch = reader.GetStringField("general.architecture");
+ if (arch != null) contextLength = reader.GetIntField($"{arch}.context_length");
+ }
}
public ModelEntry OnlyRequiredFields()
@@ -154,7 +173,7 @@ public static void SetTemplate(ModelEntry entry, string chatTemplate)
public static ModelEntry Get(string path)
{
string filename = Path.GetFileName(path);
- string fullPath = Path.GetFullPath(path).Replace('\\', '/');
+ string fullPath = LLMUnitySetup.GetFullPath(path);
foreach (ModelEntry entry in modelEntries)
{
if (entry.filename == filename || entry.path == fullPath) return entry;
@@ -379,7 +398,7 @@ public static void Remove(ModelEntry entry)
foreach (LLM llm in llms)
{
if (!entry.lora && llm.model == entry.filename) llm.model = "";
- else if (entry.lora && llm.lora == entry.filename) llm.lora = "";
+ else if (entry.lora) llm.RemoveLora(entry.filename);
}
}
diff --git a/Runtime/LLMUnitySetup.cs b/Runtime/LLMUnitySetup.cs
index 8150b22f..064a031d 100644
--- a/Runtime/LLMUnitySetup.cs
+++ b/Runtime/LLMUnitySetup.cs
@@ -47,6 +47,7 @@ public class ModelAttribute : PropertyAttribute {}
public class ModelDownloadAttribute : ModelAttribute {}
public class ModelDownloadAdvancedAttribute : ModelAdvancedAttribute {}
public class ModelAdvancedAttribute : PropertyAttribute {}
+ public class ModelExtrasAttribute : PropertyAttribute {}
public class ChatAttribute : PropertyAttribute {}
public class ChatAdvancedAttribute : PropertyAttribute {}
public class LLMUnityAttribute : PropertyAttribute {}
@@ -84,11 +85,15 @@ public class LLMUnitySetup
{
// DON'T CHANGE! the version is autocompleted with a GitHub action
/// LLM for Unity version
- public static string Version = "v2.1.1";
+ public static string Version = "v2.2.0";
/// LlamaLib version
- public static string LlamaLibVersion = "v1.1.6";
+ public static string LlamaLibVersion = "v1.1.10";
+ /// LlamaLib release url
+ public static string LlamaLibReleaseURL = $"https://github.com/undreamai/LlamaLib/releases/download/{LlamaLibVersion}";
/// LlamaLib url
- public static string LlamaLibURL = $"https://github.com/undreamai/LlamaLib/releases/download/{LlamaLibVersion}/undreamai-{LlamaLibVersion}-llamacpp.zip";
+ public static string LlamaLibURL = $"{LlamaLibReleaseURL}/undreamai-{LlamaLibVersion}-llamacpp.zip";
+ /// LlamaLib extension url
+ public static string LlamaLibExtensionURL = $"{LlamaLibReleaseURL}/undreamai-{LlamaLibVersion}-llamacpp-full.zip";
/// LlamaLib path
public static string libraryPath = GetAssetPath(Path.GetFileName(LlamaLibURL).Replace(".zip", ""));
/// LLMnity store path
@@ -101,32 +106,23 @@ public class LLMUnitySetup
/// Default models for download
[HideInInspector] public static readonly (string, string, string)[] modelOptions = new(string, string, string)[]
{
- ("Llama 3 7B (medium, best overall)", "https://huggingface.co/lmstudio-community/Meta-Llama-3-8B-Instruct-GGUF/resolve/main/Meta-Llama-3-8B-Instruct-Q4_K_M.gguf?download=true", "https://huggingface.co/meta-llama/Meta-Llama-3-8B/blob/main/LICENSE"),
+ ("Llama 3.1 8B (medium, best overall)", "https://huggingface.co/bartowski/Meta-Llama-3.1-8B-Instruct-GGUF/resolve/main/Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf?download=true", "https://huggingface.co/meta-llama/Meta-Llama-3.1-8B/blob/main/LICENSE"),
+ ("Gemma 2 9B it (medium, great overall)", "https://huggingface.co/bartowski/gemma-2-9b-it-GGUF/resolve/main/gemma-2-9b-it-Q4_K_M.gguf?download=true", "https://ai.google.dev/gemma/terms"),
("Mistral 7B Instruct v0.2 (medium, great overall)", "https://huggingface.co/TheBloke/Mistral-7B-Instruct-v0.2-GGUF/resolve/main/mistral-7b-instruct-v0.2.Q4_K_M.gguf?download=true", null),
("OpenHermes 2.5 7B (medium, good for conversation)", "https://huggingface.co/TheBloke/OpenHermes-2.5-Mistral-7B-GGUF/resolve/main/openhermes-2.5-mistral-7b.Q4_K_M.gguf?download=true", null),
("Phi 3 (small, great small model)", "https://huggingface.co/microsoft/Phi-3-mini-4k-instruct-gguf/resolve/main/Phi-3-mini-4k-instruct-q4.gguf?download=true", null),
("Qwen 2 0.5B (tiny, useful for mobile)", "https://huggingface.co/Qwen/Qwen2-0.5B-Instruct-GGUF/resolve/main/qwen2-0_5b-instruct-q4_k_m.gguf?download=true", null),
};
- /// Add callback function to call for error logs
-    public static void AddErrorCallBack(Callback<string> callback)
- {
- errorCallbacks.Add(callback);
- }
-
- /// Remove callback function added for error logs
-    public static void RemoveErrorCallBack(Callback<string> callback)
- {
- errorCallbacks.Remove(callback);
- }
-
- /// Remove all callback function added for error logs
- public static void ClearErrorCallBacks()
- {
- errorCallbacks.Clear();
- }
-
/// \cond HIDE
+ [LLMUnity] public static DebugModeType DebugMode = DebugModeType.All;
+ static string DebugModeKey = "DebugMode";
+ public static bool FullLlamaLib = false;
+ static string FullLlamaLibKey = "FullLlamaLib";
+    static List<Callback<string>> errorCallbacks = new List<Callback<string>>();
+ static readonly object lockObject = new object();
+    static Dictionary<string, Task> androidExtractTasks = new Dictionary<string, Task>();
+
public enum DebugModeType
{
All,
@@ -134,10 +130,6 @@ public enum DebugModeType
Error,
None
}
- [LLMUnity] public static DebugModeType DebugMode = DebugModeType.All;
-    static List<Callback<string>> errorCallbacks = new List<Callback<string>>();
- static readonly object lockObject = new object();
-    static Dictionary<string, Task> androidExtractTasks = new Dictionary<string, Task>();
public static void Log(string message)
{
@@ -158,10 +150,10 @@ public static void LogError(string message)
        foreach (Callback<string> errorCallback in errorCallbacks) errorCallback(message);
}
- static string DebugModeKey = "DebugMode";
- static void LoadDebugMode()
+ static void LoadPlayerPrefs()
{
DebugMode = (DebugModeType)PlayerPrefs.GetInt(DebugModeKey, (int)DebugModeType.All);
+ FullLlamaLib = PlayerPrefs.GetInt(FullLlamaLibKey, 0) == 1;
}
public static void SetDebugMode(DebugModeType newDebugMode)
@@ -172,6 +164,18 @@ public static void SetDebugMode(DebugModeType newDebugMode)
PlayerPrefs.Save();
}
+#if UNITY_EDITOR
+ public static void SetFullLlamaLib(bool value)
+ {
+ if (FullLlamaLib == value) return;
+ FullLlamaLib = value;
+ PlayerPrefs.SetInt(FullLlamaLibKey, value ? 1 : 0);
+ PlayerPrefs.Save();
+ _ = DownloadLibrary();
+ }
+
+#endif
+
public static string GetAssetPath(string relPath = "")
{
// Path to store llm server binaries and models
@@ -183,15 +187,15 @@ public static string GetAssetPath(string relPath = "")
[InitializeOnLoadMethod]
static async Task InitializeOnLoad()
{
+ LoadPlayerPrefs();
await DownloadLibrary();
- LoadDebugMode();
}
#else
[RuntimeInitializeOnLoadMethod(RuntimeInitializeLoadType.BeforeSceneLoad)]
void InitializeOnLoad()
{
- LoadDebugMode();
+ LoadPlayerPrefs();
}
#endif
@@ -295,51 +299,118 @@ public static async Task AndroidExtractAsset(string path, bool overwrite = false
await AndroidExtractFile(Path.GetFileName(path), overwrite);
}
+ public static string GetFullPath(string path)
+ {
+ return Path.GetFullPath(path).Replace('\\', '/');
+ }
+
public static bool IsSubPath(string childPath, string parentPath)
{
- string fullParentPath = Path.GetFullPath(parentPath).Replace('\\', '/');
- string fullChildPath = Path.GetFullPath(childPath).Replace('\\', '/');
- return fullChildPath.StartsWith(fullParentPath, StringComparison.OrdinalIgnoreCase);
+ return GetFullPath(childPath).StartsWith(GetFullPath(parentPath), StringComparison.OrdinalIgnoreCase);
+ }
+
+ public static string RelativePath(string fullPath, string basePath)
+ {
+ // Get the full paths and replace backslashes with forward slashes (or vice versa)
+ string fullParentPath = GetFullPath(basePath).TrimEnd('/');
+ string fullChildPath = GetFullPath(fullPath);
+
+ string relativePath = fullChildPath;
+ if (fullChildPath.StartsWith(fullParentPath, StringComparison.OrdinalIgnoreCase))
+ {
+ relativePath = fullChildPath.Substring(fullParentPath.Length);
+ while (relativePath.StartsWith("/")) relativePath = relativePath.Substring(1);
+ }
+ return relativePath;
}
#if UNITY_EDITOR
[HideInInspector] public static float libraryProgress = 1;
- private static async Task DownloadLibrary()
+ public static void CreateEmptyFile(string path)
+ {
+ File.Create(path).Dispose();
+ }
+
+ static void ExtractInsideDirectory(string zipPath, string extractPath, bool overwrite = true)
+ {
+ using (ZipArchive archive = ZipFile.OpenRead(zipPath))
+ {
+ foreach (ZipArchiveEntry entry in archive.Entries)
+ {
+ if (string.IsNullOrEmpty(entry.Name)) continue;
+ string destinationPath = Path.Combine(extractPath, entry.FullName);
+ Directory.CreateDirectory(Path.GetDirectoryName(destinationPath));
+ entry.ExtractToFile(destinationPath, overwrite);
+ }
+ }
+ }
+
+ static async Task DownloadAndExtractInsideDirectory(string url, string path, string setupDir)
+ {
+ string urlName = Path.GetFileName(url);
+ string setupFile = Path.Combine(setupDir, urlName + ".complete");
+ if (File.Exists(setupFile)) return;
+
+ string zipPath = Path.Combine(Application.temporaryCachePath, urlName);
+ await DownloadFile(url, zipPath, true, null, SetLibraryProgress);
+
+ AssetDatabase.StartAssetEditing();
+ ExtractInsideDirectory(zipPath, path);
+ CreateEmptyFile(setupFile);
+ AssetDatabase.StopAssetEditing();
+
+ File.Delete(zipPath);
+ }
+
+ static async Task DownloadLibrary()
{
- if (libraryProgress < 1) return;
- libraryProgress = 0;
- string libZip = Path.Combine(Application.temporaryCachePath, Path.GetFileName(LlamaLibURL));
- if (!Directory.Exists(libraryPath))
+ void DeleteFileAndMeta(string path)
{
- await DownloadFile(LlamaLibURL, libZip, true, null, SetLibraryProgress);
+ if (File.Exists(path + ".meta")) File.Delete(path + ".meta");
+ if (File.Exists(path)) File.Delete(path);
+ }
+
+ try
+ {
+ string setupDir = Path.Combine(libraryPath, "setup");
+ Directory.CreateDirectory(setupDir);
+
+ string lockFile = Path.Combine(setupDir, "LLMUnitySetup.lock");
+ if (File.Exists(lockFile)) return;
+ CreateEmptyFile(lockFile);
+
+ libraryProgress = 0;
+ await DownloadAndExtractInsideDirectory(LlamaLibURL, libraryPath, setupDir);
+
AssetDatabase.StartAssetEditing();
- ZipFile.ExtractToDirectory(libZip, libraryPath);
string androidDir = Path.Combine(libraryPath, "android");
if (Directory.Exists(androidDir))
{
- string androidPluginDir = Path.Combine(Application.dataPath, "Plugins", "Android");
- Directory.CreateDirectory(androidPluginDir);
- Directory.Move(androidDir, Path.Combine(androidPluginDir, Path.GetFileName(libraryPath)));
- }
- foreach (string librarySubPath in Directory.GetDirectories(libraryPath))
- {
- if (Path.GetFileName(librarySubPath).StartsWith("android"))
- {
- string pluginPath = Path.Combine(Application.dataPath, "Plugins", "Android", Path.GetFileName(librarySubPath));
- Directory.Move(librarySubPath, pluginPath);
- }
+ string androidPluginsDir = Path.Combine(Application.dataPath, "Plugins", "Android");
+ Directory.CreateDirectory(androidPluginsDir);
+ string pluginDir = Path.Combine(androidPluginsDir, Path.GetFileName(libraryPath));
+ if (Directory.Exists(pluginDir)) Directory.Delete(pluginDir, true);
+ Directory.Move(androidDir, pluginDir);
+ if (File.Exists(androidDir + ".meta")) File.Delete(androidDir + ".meta");
}
AssetDatabase.StopAssetEditing();
- File.Delete(libZip);
+
+ if (FullLlamaLib) await DownloadAndExtractInsideDirectory(LlamaLibExtensionURL, libraryPath, setupDir);
+
+ libraryProgress = 1;
+ DeleteFileAndMeta(lockFile);
+ }
+ catch (Exception e)
+ {
+ LogError(e.Message);
}
- libraryProgress = 1;
}
private static void SetLibraryProgress(float progress)
{
- libraryProgress = progress;
+ libraryProgress = Math.Min(0.99f, progress);
}
public static string AddAsset(string assetPath)
@@ -363,6 +434,25 @@ public static string AddAsset(string assetPath)
#endif
/// \endcond
+
+ /// Add a callback function to be called for error logs
+ public static void AddErrorCallBack(Callback<string> callback)
+ {
+ errorCallbacks.Add(callback);
+ }
+
+ /// Remove a callback function added for error logs
+ public static void RemoveErrorCallBack(Callback<string> callback)
+ {
+ errorCallbacks.Remove(callback);
+ }
+
+ /// Remove all callback functions added for error logs
+ public static void ClearErrorCallBacks()
+ {
+ errorCallbacks.Clear();
+ }
+
public static int GetMaxFreqKHz(int cpuId)
{
string[] paths = new string[]
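For reference, a minimal sketch (not part of the diff) of how the new error-callback hooks could be consumed from a game script. The `LLMErrorForwarder` class is hypothetical; only `AddErrorCallBack`/`RemoveErrorCallBack` and the `Callback<string>` delegate come from the code above.

```csharp
using UnityEngine;
using LLMUnity;

// Hypothetical consumer: forwards LLMUnity error logs to the Unity console.
public class LLMErrorForwarder : MonoBehaviour
{
    void OnEnable()
    {
        // Subscribe to errors raised anywhere inside LLM for Unity
        LLMUnitySetup.AddErrorCallBack(OnLLMError);
    }

    void OnDisable()
    {
        // Unsubscribe to avoid stale callbacks after this object is disabled
        LLMUnitySetup.RemoveErrorCallBack(OnLLMError);
    }

    void OnLLMError(string message)
    {
        Debug.LogError($"[LLMUnity] {message}");
    }
}
```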
diff --git a/Runtime/LLMUtils.cs b/Runtime/LLMUtils.cs
new file mode 100644
index 00000000..6be9c6ff
--- /dev/null
+++ b/Runtime/LLMUtils.cs
@@ -0,0 +1,137 @@
+/// @file
+/// @brief File implementing LLM helper code.
+using System;
+using System.Collections.Generic;
+
+namespace LLMUnity
+{
+ /// \cond HIDE
+ public class LLMException : Exception
+ {
+ public int ErrorCode { get; private set; }
+
+ public LLMException(string message, int errorCode) : base(message)
+ {
+ ErrorCode = errorCode;
+ }
+ }
+
+ public class DestroyException : Exception {}
+
+ public class LoraAsset
+ {
+ public string assetPath;
+ public string fullPath;
+ public float weight;
+
+ public LoraAsset(string path, float weight = 1)
+ {
+ assetPath = LLM.GetLLMManagerAsset(path);
+ fullPath = RuntimePath(path);
+ this.weight = weight;
+ }
+
+ public static string RuntimePath(string path)
+ {
+ return LLMUnitySetup.GetFullPath(LLM.GetLLMManagerAssetRuntime(path));
+ }
+ }
+
+ public class LoraManager
+ {
+ List<LoraAsset> loras = new List<LoraAsset>();
+
+ public void Clear()
+ {
+ loras.Clear();
+ }
+
+ public int IndexOf(string path)
+ {
+ string fullPath = LoraAsset.RuntimePath(path);
+ for (int i = 0; i < loras.Count; i++)
+ {
+ LoraAsset lora = loras[i];
+ if (lora.assetPath == path || lora.fullPath == fullPath) return i;
+ }
+ return -1;
+ }
+
+ public bool Contains(string path)
+ {
+ return IndexOf(path) != -1;
+ }
+
+ public void Add(string path, float weight = 1)
+ {
+ if (Contains(path)) return;
+ loras.Add(new LoraAsset(path, weight));
+ }
+
+ public void Remove(string path)
+ {
+ int index = IndexOf(path);
+ if (index != -1) loras.RemoveAt(index);
+ }
+
+ public void SetWeight(string path, float weight)
+ {
+ int index = IndexOf(path);
+ if (index == -1)
+ {
+ LLMUnitySetup.LogError($"LoRA {path} not loaded with the LLM");
+ return;
+ }
+ loras[index].weight = weight;
+ }
+
+ public void FromStrings(string loraString, string loraWeightsString)
+ {
+ if (string.IsNullOrEmpty(loraString) && string.IsNullOrEmpty(loraWeightsString))
+ {
+ Clear();
+ return;
+ }
+
+ try
+ {
+ List<string> loraStringArr = new List<string>(loraString.Split(" "));
+ List<string> loraWeightsStringArr = new List<string>(loraWeightsString.Split(" "));
+ if (loraStringArr.Count != loraWeightsStringArr.Count) throw new Exception($"Number of LoRAs ({loraString}) doesn't match the number of weights ({loraWeightsString})");
+
+ List<LoraAsset> lorasNew = new List<LoraAsset>();
+ for (int i = 0; i < loraStringArr.Count; i++) lorasNew.Add(new LoraAsset(loraStringArr[i], float.Parse(loraWeightsStringArr[i])));
+ loras = lorasNew;
+ }
+ catch (Exception e)
+ {
+ LLMUnitySetup.LogError($"Loras not set: {e.Message}");
+ }
+ }
+
+ public (string, string) ToStrings()
+ {
+ string loraString = "";
+ string loraWeightsString = "";
+ for (int i = 0; i < loras.Count; i++)
+ {
+ if (i > 0)
+ {
+ loraString += " ";
+ loraWeightsString += " ";
+ }
+ loraString += loras[i].assetPath;
+ loraWeightsString += loras[i].weight;
+ }
+ return (loraString, loraWeightsString);
+ }
+
+ public float[] GetWeights()
+ {
+ float[] weights = new float[loras.Count];
+ for (int i = 0; i < loras.Count; i++) weights[i] = loras[i].weight;
+ return weights;
+ }
+ }
+ /// \endcond
+}
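A sketch (not from the PR) of how the `LoraManager` bookkeeping above fits together; the adapter paths are hypothetical.

```csharp
using LLMUnity;

public static class LoraManagerSketch
{
    public static void Run()
    {
        LoraManager manager = new LoraManager();

        manager.Add("loras/style.gguf", 0.8f);       // register an adapter with weight 0.8
        manager.SetWeight("loras/style.gguf", 0.5f); // adjust the weight later

        // Serialise to the two space-separated strings the LLM component stores ...
        (string loras, string weights) = manager.ToStrings();

        // ... and rebuild the same state from them (e.g. after deserialisation)
        manager.FromStrings(loras, weights);
    }
}
```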
diff --git a/Runtime/LLMUtils.cs.meta b/Runtime/LLMUtils.cs.meta
new file mode 100644
index 00000000..93974a4f
--- /dev/null
+++ b/Runtime/LLMUtils.cs.meta
@@ -0,0 +1,11 @@
+fileFormatVersion: 2
+guid: 2ae6a2ce57e8af0fc876d0c380ed8a2f
+MonoImporter:
+ externalObjects: {}
+ serializedVersion: 2
+ defaultReferences: []
+ executionOrder: 0
+ icon: {instanceID: 0}
+ userData:
+ assetBundleName:
+ assetBundleVariant:
diff --git a/Tests/Runtime/TestLLM.cs b/Tests/Runtime/TestLLM.cs
index b5afc880..edf0907b 100644
--- a/Tests/Runtime/TestLLM.cs
+++ b/Tests/Runtime/TestLLM.cs
@@ -5,54 +5,185 @@
using System.Collections.Generic;
using System;
using System.Collections;
-using UnityEngine.TestTools;
using System.IO;
+using UnityEngine.TestTools;
namespace LLMUnityTests
{
+ public class TestLLMLoraAssignment
+ {
+ [Test]
+ public void TestLoras()
+ {
+ GameObject gameObject = new GameObject();
+ gameObject.SetActive(false);
+ LLM llm = gameObject.AddComponent<LLM>();
+
+ string lora1 = "/tmp/lala";
+ string lora2Rel = "test/lala";
+ string lora2 = LLMUnitySetup.GetAssetPath(lora2Rel);
+ LLMUnitySetup.CreateEmptyFile(lora1);
+ Directory.CreateDirectory(Path.GetDirectoryName(lora2));
+ LLMUnitySetup.CreateEmptyFile(lora2);
+
+ llm.AddLora(lora1);
+ llm.AddLora(lora2);
+ Assert.AreEqual(llm.lora, lora1 + " " + lora2);
+ Assert.AreEqual(llm.loraWeights, "1 1");
+
+ llm.RemoveLoras();
+ Assert.AreEqual(llm.lora, "");
+ Assert.AreEqual(llm.loraWeights, "");
+
+ llm.AddLora(lora1, 0.8f);
+ llm.AddLora(lora2Rel, 0.9f);
+ Assert.AreEqual(llm.lora, lora1 + " " + lora2);
+ Assert.AreEqual(llm.loraWeights, "0.8 0.9");
+
+ llm.SetLoraWeight(lora2Rel, 0.7f);
+ Assert.AreEqual(llm.lora, lora1 + " " + lora2);
+ Assert.AreEqual(llm.loraWeights, "0.8 0.7");
+
+ llm.RemoveLora(lora2Rel);
+ Assert.AreEqual(llm.lora, lora1);
+ Assert.AreEqual(llm.loraWeights, "0.8");
+
+ llm.AddLora(lora2Rel);
+ llm.SetLoraWeight(lora2Rel, 0.5f);
+ Assert.AreEqual(llm.lora, lora1 + " " + lora2);
+ Assert.AreEqual(llm.loraWeights, "0.8 0.5");
+
+ llm.SetLoraWeight(lora2, 0.1f);
+ Assert.AreEqual(llm.lora, lora1 + " " + lora2);
+ Assert.AreEqual(llm.loraWeights, "0.8 0.1");
+
+ Dictionary<string, float> loraToWeight = new Dictionary<string, float>();
+ loraToWeight[lora1] = 0;
+ loraToWeight[lora2] = 0.2f;
+ llm.SetLoraWeights(loraToWeight);
+ Assert.AreEqual(llm.lora, lora1 + " " + lora2);
+ Assert.AreEqual(llm.loraWeights, "0 0.2");
+
+ File.Delete(lora1);
+ File.Delete(lora2);
+ }
+ }
+
public class TestLLM
{
+ protected static string modelUrl = "https://huggingface.co/Qwen/Qwen2-0.5B-Instruct-GGUF/resolve/main/qwen2-0_5b-instruct-q4_k_m.gguf?download=true";
+ protected string modelNameLLManager;
+
protected GameObject gameObject;
protected LLM llm;
protected LLMCharacter llmCharacter;
- protected static string modelUrl = "https://huggingface.co/afrideva/smol_llama-220M-openhermes-GGUF/resolve/main/smol_llama-220m-openhermes.q4_k_m.gguf?download=true";
- protected static string filename = Path.GetFileName(modelUrl).Split("?")[0];
- Exception error = null;
- string prompt = "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.";
+ protected Exception error = null;
+ protected string prompt;
+ protected string query;
+ protected string reply1;
+ protected string reply2;
+ protected int tokens1;
+ protected int tokens2;
+
public TestLLM()
{
- LLMUnitySetup.SetDebugMode(LLMUnitySetup.DebugModeType.All);
Task task = Init();
task.Wait();
}
public virtual async Task Init()
{
+ SetParameters();
+ await DownloadModels();
gameObject = new GameObject();
gameObject.SetActive(false);
- await SetLLM();
- SetLLMCharacter();
+ llm = CreateLLM();
+ llmCharacter = CreateLLMCharacter();
gameObject.SetActive(true);
}
- public async Task EmptyTask()
+ public virtual void SetParameters()
{
- await Task.Delay(1);
+ prompt = "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.";
+ query = "How can I increase my meme production/output? Currently, I only create them in ancient babylonian which is time consuming.";
+ reply1 = "To increase your meme production/output, you can try using more modern tools and techniques. For instance,";
+ reply2 = "To increase your meme production/output, you can try the following strategies:\n\n1. Use a meme generator";
+ tokens1 = 32;
+ tokens2 = 9;
}
- public virtual async Task SetLLM()
+ public virtual async Task DownloadModels()
{
- llm = gameObject.AddComponent<LLM>();
- string filename = await LLMManager.DownloadModel(modelUrl);
- llm.SetModel(filename);
+ modelNameLLManager = await LLMManager.DownloadModel(modelUrl);
+ }
+
+ [Test]
+ public void TestGetLLMManagerAssetRuntime()
+ {
+ string path = "";
+ string managerPath = LLM.GetLLMManagerAssetRuntime(path);
+ Assert.AreEqual(managerPath, path);
+
+ path = "/tmp/lala";
+ LLMUnitySetup.CreateEmptyFile(path);
+ managerPath = LLM.GetLLMManagerAssetRuntime(path);
+ Assert.AreEqual(managerPath, path);
+ File.Delete(path);
+
+ path = modelNameLLManager;
+ managerPath = LLM.GetLLMManagerAssetRuntime(path);
+ Assert.AreEqual(managerPath, LLMManager.GetAssetPath(path));
+
+ path = LLMUnitySetup.GetAssetPath("lala");
+ LLMUnitySetup.CreateEmptyFile(path);
+ managerPath = LLM.GetLLMManagerAssetRuntime(path);
+ Assert.AreEqual(managerPath, path);
+ File.Delete(path);
+ }
+
+ [Test]
+ public void TestGetLLMManagerAssetEditor()
+ {
+ string path = "";
+ string managerPath = LLM.GetLLMManagerAssetEditor(path);
+ Assert.AreEqual(managerPath, path);
+
+ path = modelNameLLManager;
+ managerPath = LLM.GetLLMManagerAssetEditor(path);
+ Assert.AreEqual(managerPath, modelNameLLManager);
+
+ path = LLMManager.Get(modelNameLLManager).path;
+ managerPath = LLM.GetLLMManagerAssetEditor(path);
+ Assert.AreEqual(managerPath, modelNameLLManager);
+
+ string filename = "lala";
+ path = LLMUnitySetup.GetAssetPath(filename);
+ LLMUnitySetup.CreateEmptyFile(path);
+ managerPath = LLM.GetLLMManagerAssetEditor(filename);
+ Assert.AreEqual(managerPath, filename);
+ managerPath = LLM.GetLLMManagerAssetEditor(path);
+ Assert.AreEqual(managerPath, filename);
+ File.Delete(path);
+
+ path = "/tmp/lala";
+ LLMUnitySetup.CreateEmptyFile(path);
+ managerPath = LLM.GetLLMManagerAssetEditor(path);
+ Assert.AreEqual(managerPath, path);
+ File.Delete(path);
+ }
+
+ public virtual LLM CreateLLM()
+ {
+ LLM llm = gameObject.AddComponent<LLM>();
+ llm.SetModel(modelNameLLManager);
llm.parallelPrompts = 1;
- llm.SetTemplate("alpaca");
+ return llm;
}
- public virtual void SetLLMCharacter()
+ public virtual LLMCharacter CreateLLMCharacter()
{
- llmCharacter = gameObject.AddComponent<LLMCharacter>();
+ LLMCharacter llmCharacter = gameObject.AddComponent<LLMCharacter>();
llmCharacter.llm = llm;
llmCharacter.playerName = "Instruction";
llmCharacter.AIName = "Response";
@@ -61,31 +192,28 @@ public virtual void SetLLMCharacter()
llmCharacter.seed = 0;
llmCharacter.stream = false;
llmCharacter.numPredict = 20;
+ return llmCharacter;
}
- public virtual async Task RunTests()
+ [UnityTest]
+ public IEnumerator RunTests()
+ {
+ Task task = RunTestsTask();
+ while (!task.IsCompleted) yield return null;
+ if (error != null)
+ {
+ Debug.LogError(error.ToString());
+ throw (error);
+ }
+ OnDestroy();
+ }
+
+ public async Task RunTestsTask()
{
error = null;
try
{
- llm.Awake();
- llmCharacter.Awake();
- await llmCharacter.Tokenize("I", TestTokens);
- await llmCharacter.Warmup();
- TestInitParameters((await llmCharacter.Tokenize(prompt)).Count + 2, 1);
- TestWarmup();
- await llmCharacter.Chat("How can I increase my meme production/output? Currently, I only create them in ancient babylonian which is time consuming.", TestChat);
- TestPostChat(3);
- llmCharacter.SetPrompt(llmCharacter.prompt);
- llmCharacter.AIName = "False response";
- await llmCharacter.Chat("How can I increase my meme production/output? Currently, I only create them in ancient babylonian which is time consuming.", TestChat2);
- TestPostChat(3);
- await llmCharacter.Chat("bye!");
- TestPostChat(5);
- prompt = "How are you?";
- llmCharacter.SetPrompt(prompt);
- await llmCharacter.Chat("hi");
- TestInitParameters((await llmCharacter.Tokenize(prompt)).Count + 2, 3);
+ await Tests();
llm.OnDestroy();
}
catch (Exception e)
@@ -94,17 +222,26 @@ public virtual async Task RunTests()
}
}
- [UnityTest]
- public IEnumerator RunTestsWait()
+ public virtual async Task Tests()
{
- Task task = RunTests();
- while (!task.IsCompleted) yield return null;
- if (error != null)
- {
- Debug.LogError(error.ToString());
- throw (error);
- }
- OnDestroy();
+ await llmCharacter.Tokenize("I", TestTokens);
+ await llmCharacter.Warmup();
+ TestInitParameters(tokens1, 1);
+ TestWarmup();
+ await llmCharacter.Chat(query, (string reply) => TestChat(reply, reply1));
+ TestPostChat(3);
+ llmCharacter.SetPrompt(llmCharacter.prompt);
+ llmCharacter.AIName = "False response";
+ await llmCharacter.Chat(query, (string reply) => TestChat(reply, reply2));
+ TestPostChat(3);
+ await llmCharacter.Chat("bye!");
+ TestPostChat(5);
+ prompt = "How are you?";
+ llmCharacter.SetPrompt(prompt);
+ await llmCharacter.Chat("hi");
+ TestInitParameters(tokens2, 3);
+ List<float> embeddings = await llmCharacter.Embeddings("hi how are you?");
+ TestEmbeddings(embeddings);
}
public void TestInitParameters(int nkeep, int chats)
@@ -116,7 +253,7 @@ public void TestInitParameters(int nkeep, int chats)
public void TestTokens(List<int> tokens)
{
- Assert.AreEqual(tokens, new List<int> {306});
+ Assert.AreEqual(tokens, new List<int> {40});
}
public void TestWarmup()
@@ -124,16 +261,9 @@ public void TestWarmup()
Assert.That(llmCharacter.chat.Count == 1);
}
- public void TestChat(string reply)
+ public void TestChat(string reply, string replyGT)
{
- string AIReply = "To increase your meme production/output, you can consider the following:\n1. Use";
- Assert.That(reply.Trim() == AIReply);
- }
-
- public void TestChat2(string reply)
- {
- string AIReply = "One possible solution is to use a more advanced natural language processing library like NLTK or sp";
- Assert.That(reply.Trim() == AIReply);
+ Assert.That(reply.Trim() == replyGT);
}
public void TestPostChat(int num)
@@ -141,57 +271,158 @@ public void TestPostChat(int num)
Assert.That(llmCharacter.chat.Count == num);
}
- public virtual void OnDestroy()
+ public void TestEmbeddings(List<float> embeddings)
{
- LLMManager.Remove(filename);
+ Assert.That(embeddings.Count == 896);
}
+
+ public virtual void OnDestroy() {}
}
public class TestLLM_LLMManager_Load : TestLLM
{
- public override Task SetLLM()
+ public override LLM CreateLLM()
{
- llm = gameObject.AddComponent<LLM>();
+ LLM llm = gameObject.AddComponent<LLM>();
+ string filename = Path.GetFileName(modelUrl).Split("?")[0];
string sourcePath = Path.Combine(LLMUnitySetup.modelDownloadPath, filename);
filename = LLMManager.LoadModel(sourcePath);
llm.SetModel(filename);
llm.parallelPrompts = 1;
- llm.SetTemplate("alpaca");
- return Task.CompletedTask;
+ return llm;
}
}
public class TestLLM_StreamingAssets_Load : TestLLM
{
- public override Task SetLLM()
+ string loadPath;
+
+ public override LLM CreateLLM()
{
- llm = gameObject.AddComponent<LLM>();
+ LLM llm = gameObject.AddComponent<LLM>();
+ string filename = Path.GetFileName(modelUrl).Split("?")[0];
string sourcePath = Path.Combine(LLMUnitySetup.modelDownloadPath, filename);
- string targetPath = LLMUnitySetup.GetAssetPath(filename);
- if (!File.Exists(targetPath)) File.Copy(sourcePath, targetPath);
- llm.SetModel(filename);
+ loadPath = LLMUnitySetup.GetAssetPath(filename);
+ if (!File.Exists(loadPath)) File.Copy(sourcePath, loadPath);
+ llm.SetModel(loadPath);
llm.parallelPrompts = 1;
- llm.SetTemplate("alpaca");
- return Task.CompletedTask;
+ return llm;
}
public override void OnDestroy()
{
- string targetPath = LLMUnitySetup.GetAssetPath(filename);
- if (!File.Exists(targetPath)) File.Delete(targetPath);
+ if (File.Exists(loadPath)) File.Delete(loadPath);
}
}
public class TestLLM_SetModel_Warning : TestLLM
{
- public override Task SetLLM()
+ public override LLM CreateLLM()
{
- llm = gameObject.AddComponent<LLM>();
- string sourcePath = Path.Combine(LLMUnitySetup.modelDownloadPath, filename);
- llm.SetModel(sourcePath);
+ LLM llm = gameObject.AddComponent<LLM>();
+ string filename = Path.GetFileName(modelUrl).Split("?")[0];
+ string loadPath = Path.Combine(LLMUnitySetup.modelDownloadPath, filename);
+ llm.SetModel(loadPath);
llm.parallelPrompts = 1;
- llm.SetTemplate("alpaca");
- return Task.CompletedTask;
+ return llm;
+ }
+ }
+
+ public class TestLLM_Remote : TestLLM
+ {
+ public override LLM CreateLLM()
+ {
+ LLM llm = base.CreateLLM();
+ llm.remote = true;
+ return llm;
+ }
+
+ public override LLMCharacter CreateLLMCharacter()
+ {
+ LLMCharacter llmCharacter = base.CreateLLMCharacter();
+ llmCharacter.remote = true;
+ return llmCharacter;
+ }
+ }
+
+ public class TestLLM_Lora : TestLLM
+ {
+ protected string loraUrl = "https://huggingface.co/undreamer/Qwen2-0.5B-Instruct-ru-lora/resolve/main/Qwen2-0.5B-Instruct-ru-lora.gguf?download=true";
+ protected string loraNameLLManager;
+ protected float loraWeight;
+
+ public override async Task DownloadModels()
+ {
+ await base.DownloadModels();
+ loraNameLLManager = await LLMManager.DownloadLora(loraUrl);
+ }
+
+ public override LLM CreateLLM()
+ {
+ LLM llm = base.CreateLLM();
+ llm.AddLora(loraNameLLManager, loraWeight);
+ return llm;
+ }
+
+ public override void SetParameters()
+ {
+ prompt = "";
+ query = "кто ты?";
+ reply1 = "Я - искусственный интеллект, создан для общения и понимания людей.";
+ reply2 = "Идиот";
+ tokens1 = 5;
+ tokens2 = 9;
+ loraWeight = 0.9f;
+ }
+
+ public override async Task Tests()
+ {
+ await base.Tests();
+ TestModelPaths();
+ await TestLoraWeight();
+ }
+
+ public void TestModelPaths()
+ {
+ Assert.AreEqual(llm.model, Path.Combine(LLMUnitySetup.modelDownloadPath, Path.GetFileName(modelUrl).Split("?")[0]));
+ Assert.AreEqual(llm.lora, Path.Combine(LLMUnitySetup.modelDownloadPath, Path.GetFileName(loraUrl).Split("?")[0]));
+ }
+
+ public async Task TestLoraWeight()
+ {
+ List<LoraWeightResult> loras = await llm.ListLoras();
+ Assert.AreEqual(loras[0].scale, loraWeight);
+ }
+ }
+
+
+ public class TestLLM_Lora_ChangeWeight : TestLLM_Lora
+ {
+ public override async Task Tests()
+ {
+ await base.Tests();
+ loraWeight = 0.6f;
+ llm.SetLoraWeight(loraNameLLManager, loraWeight);
+ await TestLoraWeight();
+ }
+ }
+
+ public class TestLLM_Double : TestLLM
+ {
+ LLM llm1;
+ LLMCharacter llmCharacter1;
+
+ public override async Task Init()
+ {
+ SetParameters();
+ await DownloadModels();
+ gameObject = new GameObject();
+ gameObject.SetActive(false);
+ llm = CreateLLM();
+ llmCharacter = CreateLLMCharacter();
+ llm1 = CreateLLM();
+ llmCharacter1 = CreateLLMCharacter();
+ gameObject.SetActive(true);
}
}
}
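A sketch of the per-component LoRA API that `TestLLMLoraAssignment` above exercises. Assumptions: an `LLM` component is assigned in the inspector and the adapter file exists; the path used here is hypothetical.

```csharp
using System.Collections.Generic;
using UnityEngine;
using LLMUnity;

// Hypothetical game script driving the new LoRA API on an LLM component.
public class LoraSetup : MonoBehaviour
{
    public LLM llm;

    void Awake()
    {
        llm.AddLora("loras/russian.gguf", 0.9f);       // load an adapter at weight 0.9
        llm.SetLoraWeight("loras/russian.gguf", 0.5f); // re-weight it later

        // Or set several weights at once via a dictionary of path -> weight
        llm.SetLoraWeights(new Dictionary<string, float> { { "loras/russian.gguf", 0.2f } });

        // llm.RemoveLoras(); // drop all adapters again
    }
}
```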
diff --git a/Tests/Runtime/TestLLMChatTemplates.cs b/Tests/Runtime/TestLLMChatTemplates.cs
index b48a2bef..bfedf6c8 100644
--- a/Tests/Runtime/TestLLMChatTemplates.cs
+++ b/Tests/Runtime/TestLLMChatTemplates.cs
@@ -21,16 +21,25 @@ public class TestChatTemplate
public void TestChatML()
{
Assert.AreEqual(
- new ChatMLTemplate().ComputePrompt(messages, "assistant"),
+ new ChatMLTemplate().ComputePrompt(messages, "user", "assistant"),
"<|im_start|>system\nyou are a bot<|im_end|>\n<|im_start|>user\nHello, how are you?<|im_end|>\n<|im_start|>assistant\nI'm doing great. How can I help you today?<|im_end|>\n<|im_start|>user\nI'd like to show off how chat templating works!<|im_end|>\n<|im_start|>assistant\nchat template is awesome<|im_end|>\n<|im_start|>user\ndo you think so?<|im_end|>\n<|im_start|>assistant\n"
);
}
+ [Test]
+ public void TestGemma()
+ {
+ Assert.AreEqual(
+ new GemmaTemplate().ComputePrompt(messages, "user", "assistant"),
+ "user\nyou are a bot\n\nHello, how are you?\nassistant\nI'm doing great. How can I help you today?\nuser\nI'd like to show off how chat templating works!\nassistant\nchat template is awesome\nuser\ndo you think so?\nassistant\n"
+ );
+ }
+
[Test]
public void TestMistralInstruct()
{
Assert.AreEqual(
- new MistralInstructTemplate().ComputePrompt(messages, "assistant"),
+ new MistralInstructTemplate().ComputePrompt(messages, "user", "assistant"),
"[INST] you are a bot\n\nHello, how are you? [/INST]I'm doing great. How can I help you today?[INST] I'd like to show off how chat templating works! [/INST]chat template is awesome[INST] do you think so? [/INST]"
);
}
@@ -39,7 +48,7 @@ public void TestMistralInstruct()
public void TestMistralChat()
{
Assert.AreEqual(
- new MistralChatTemplate().ComputePrompt(messages, "assistant"),
+ new MistralChatTemplate().ComputePrompt(messages, "user", "assistant"),
"[INST] you are a bot\n\n### user: Hello, how are you? [/INST]### assistant: I'm doing great. How can I help you today?[INST] ### user: I'd like to show off how chat templating works! [/INST]### assistant: chat template is awesome[INST] ### user: do you think so? [/INST]### assistant:"
);
}
@@ -48,7 +57,7 @@ public void TestMistralChat()
public void TestLLama2()
{
Assert.AreEqual(
- new LLama2Template().ComputePrompt(messages, "assistant"),
+ new LLama2Template().ComputePrompt(messages, "user", "assistant"),
"[INST] <>\nyou are a bot\n<> Hello, how are you? [/INST]I'm doing great. How can I help you today? [INST] I'd like to show off how chat templating works! [/INST]chat template is awesome [INST] do you think so? [/INST]"
);
}
@@ -57,7 +66,7 @@ public void TestLLama2()
public void TestLLama2Chat()
{
Assert.AreEqual(
- new LLama2ChatTemplate().ComputePrompt(messages, "assistant"),
+ new LLama2ChatTemplate().ComputePrompt(messages, "user", "assistant"),
"[INST] <>\nyou are a bot\n<> ### user: Hello, how are you? [/INST]### assistant: I'm doing great. How can I help you today? [INST] ### user: I'd like to show off how chat templating works! [/INST]### assistant: chat template is awesome [INST] ### user: do you think so? [/INST]### assistant:"
);
}
@@ -66,7 +75,7 @@ public void TestLLama2Chat()
public void TestLLama3Chat()
{
Assert.AreEqual(
- new LLama3ChatTemplate().ComputePrompt(messages, "assistant"),
+ new LLama3ChatTemplate().ComputePrompt(messages, "user", "assistant"),
"<|start_header_id|>system<|end_header_id|>\n\nyou are a bot<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nHello, how are you?<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\nI'm doing great. How can I help you today?<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nI'd like to show off how chat templating works!<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\nchat template is awesome<|eot_id|><|start_header_id|>user<|end_header_id|>\n\ndo you think so?<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
);
}
@@ -75,7 +84,7 @@ public void TestLLama3Chat()
public void TestAlpaca()
{
Assert.AreEqual(
- new AlpacaTemplate().ComputePrompt(messages, "assistant"),
+ new AlpacaTemplate().ComputePrompt(messages, "user", "assistant"),
"you are a bot\n\n### user: Hello, how are you?\n### assistant: I'm doing great. How can I help you today?\n### user: I'd like to show off how chat templating works!\n### assistant: chat template is awesome\n### user: do you think so?\n### assistant:"
);
}
@@ -84,7 +93,7 @@ public void TestAlpaca()
public void TestVicuna()
{
Assert.AreEqual(
- new VicunaTemplate().ComputePrompt(messages, "assistant"),
+ new VicunaTemplate().ComputePrompt(messages, "user", "assistant"),
"you are a bot\n\nuser: Hello, how are you?\nassistant: I'm doing great. How can I help you today?\nuser: I'd like to show off how chat templating works!\nassistant: chat template is awesome\nuser: do you think so?\nassistant:"
);
}
@@ -93,7 +102,7 @@ public void TestVicuna()
public void TestPhi2()
{
Assert.AreEqual(
- new Phi2Template().ComputePrompt(messages, "assistant"),
+ new Phi2Template().ComputePrompt(messages, "user", "assistant"),
"you are a bot\n\nuser: Hello, how are you?\nassistant: I'm doing great. How can I help you today?\nuser: I'd like to show off how chat templating works!\nassistant: chat template is awesome\nuser: do you think so?\nassistant:"
);
}
@@ -102,7 +111,7 @@ public void TestPhi2()
public void TestPhi3()
{
Assert.AreEqual(
- new Phi3Template().ComputePrompt(messages, "assistant"),
+ new Phi3Template().ComputePrompt(messages, "user", "assistant"),
"<|user|>\nyou are a bot\n\nHello, how are you?<|end|>\n<|assistant|>\nI'm doing great. How can I help you today?<|end|>\n<|user|>\nI'd like to show off how chat templating works!<|end|>\n<|assistant|>\nchat template is awesome<|end|>\n<|user|>\ndo you think so?<|end|>\n<|assistant|>\n"
);
}
@@ -111,7 +120,7 @@ public void TestPhi3()
public void TestZephyr()
{
Assert.AreEqual(
- new ZephyrTemplate().ComputePrompt(messages, "assistant"),
+ new ZephyrTemplate().ComputePrompt(messages, "user", "assistant"),
"<|system|>\nyou are a bot\n<|user|>\nHello, how are you?\n<|assistant|>\nI'm doing great. How can I help you today?\n<|user|>\nI'd like to show off how chat templating works!\n<|assistant|>\nchat template is awesome\n<|user|>\ndo you think so?\n<|assistant|>\n"
);
}
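A sketch of the updated `ComputePrompt` signature, which now receives the player (user) role name in addition to the AI role name. The `ChatMessage` role/content initialization below is an assumption based on how these tests build their message list.

```csharp
using System.Collections.Generic;
using LLMUnity;

public static class TemplateSketch
{
    public static string BuildPrompt()
    {
        // Assumed ChatMessage shape: public role/content fields, as used by the tests
        List<ChatMessage> messages = new List<ChatMessage>
        {
            new ChatMessage { role = "system", content = "you are a bot" },
            new ChatMessage { role = "user", content = "Hello, how are you?" }
        };
        // New signature: (messages, playerName, AIName)
        return new ChatMLTemplate().ComputePrompt(messages, "user", "assistant");
    }
}
```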
diff --git a/Third Party Notices.md b/Third Party Notices.md
index ca86d120..3daa05a2 100644
--- a/Third Party Notices.md
+++ b/Third Party Notices.md
@@ -26,19 +26,35 @@ License: [link](https://github.com/Mozilla-Ocho/llamafile/blob/main/LICENSE)
The following models can be downloaded with LLMUnity:
-### meta-llama/Meta-Llama-3-8B-Instruct
+### meta-llama/Meta-Llama-3.1-8B-Instruct
Developer: Meta
-Origin: [link](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct)
-License Type: "llama3"
-License: [link](https://huggingface.co/meta-llama/Meta-Llama-3-8B/blob/main/LICENSE)
+Origin: [link](https://huggingface.co/meta-llama/Meta-Llama-3.1-8B-Instruct)
+License Type: "llama3.1"
+License: [link](https://huggingface.co/meta-llama/Meta-Llama-3.1-8B/blob/main/LICENSE)
-##### modified by: lmstudio-community/Meta-Llama-3-8B-Instruct-GGUF
+##### modified by: bartowski/Meta-Llama-3.1-8B-Instruct-GGUF
-Developer: LM Studio
-Origin: [link](https://huggingface.co/TheBloke/Mistral-7B-Instruct-v0.2-GGUF)
-License Type: "llama3"
-License: [link](https://huggingface.co/meta-llama/Meta-Llama-3-8B/blob/main/LICENSE)
+Developer: bartowski
+Origin: [link](https://huggingface.co/bartowski/Meta-Llama-3.1-8B-Instruct-GGUF)
+License Type: "llama3.1"
+License: [link](https://huggingface.co/meta-llama/Meta-Llama-3.1-8B/blob/main/LICENSE)
+
+
+
+### google/gemma-2-9b-it
+
+Developer: Google
+Origin: [link](https://huggingface.co/google/gemma-2-9b-it)
+License Type: "gemma"
+License: [link](https://ai.google.dev/gemma/terms)
+
+##### modified by: bartowski/gemma-2-9b-it-GGUF
+
+Developer: bartowski
+Origin: [link](https://huggingface.co/bartowski/gemma-2-9b-it-GGUF)
+License Type: "gemma"
+License: [link](https://ai.google.dev/gemma/terms)
diff --git a/VERSION b/VERSION
index 826e1424..a4b6ac3d 100644
--- a/VERSION
+++ b/VERSION
@@ -1 +1 @@
-v2.1.1
+v2.2.0
diff --git a/package.json b/package.json
index 795878ac..40182afe 100644
--- a/package.json
+++ b/package.json
@@ -1,6 +1,6 @@
{
"name": "ai.undream.llm",
- "version": "2.1.1",
+ "version": "2.2.0",
"displayName": "LLM for Unity",
"description": "LLM for Unity allows to run and distribute Large Language Models (LLMs) in the Unity engine.",
"unity": "2022.3",