From 1e8bead5b78a51e01a62e01ece2a2508ef5fdcc5 Mon Sep 17 00:00:00 2001 From: Antonis Makropoulos Date: Tue, 23 Jul 2024 15:58:21 +0300 Subject: [PATCH 01/26] button width static --- Editor/PropertyEditor.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Editor/PropertyEditor.cs b/Editor/PropertyEditor.cs index e5175a87..87a40938 100644 --- a/Editor/PropertyEditor.cs +++ b/Editor/PropertyEditor.cs @@ -8,7 +8,7 @@ namespace LLMUnity { public class PropertyEditor : Editor { - protected int buttonWidth = 150; + public static int buttonWidth = 150; public void AddScript(SerializedObject llmScriptSO) { From 85674381d5952be32c05fc844bd05aa42b7921e7 Mon Sep 17 00:00:00 2001 From: Antonis Makropoulos Date: Tue, 23 Jul 2024 15:59:01 +0300 Subject: [PATCH 02/26] LLM manager --- Editor/LLMManagerEditor.cs | 186 ++++++++++++++++++++++++++++++++ Editor/LLMManagerEditor.cs.meta | 11 ++ Runtime/LLMManager.cs | 64 +++++++++++ Runtime/LLMManager.cs.meta | 11 ++ Runtime/LLMUnitySetup.cs | 5 + 5 files changed, 277 insertions(+) create mode 100644 Editor/LLMManagerEditor.cs create mode 100644 Editor/LLMManagerEditor.cs.meta create mode 100644 Runtime/LLMManager.cs create mode 100644 Runtime/LLMManager.cs.meta diff --git a/Editor/LLMManagerEditor.cs b/Editor/LLMManagerEditor.cs new file mode 100644 index 00000000..d4adaa18 --- /dev/null +++ b/Editor/LLMManagerEditor.cs @@ -0,0 +1,186 @@ +using UnityEditor; +using UnityEngine; +using UnityEditorInternal; +using System; +using System.Collections.Generic; + +namespace LLMUnity +{ + [CustomEditor(typeof(LLMManager))] + public class LLMManagerEditor : Editor + { + private ReorderableList modelList; + static float nameColumnWidth = 250f; + static float textColumnWidth = 150f; + static float includeInBuildColumnWidth = 50f; + static float actionColumnWidth = 30f; + static int elementPadding = 10; + static GUIContent trashIcon; + static List modelOptions; + static List modelURLs; + + static void ResetModelOptions() + { + List existingOptions = new List(); + foreach (ModelEntry entry in LLMManager.modelEntries) existingOptions.Add(entry.url); + modelOptions = new List(); + modelURLs = new List(); + for (int i = 0; i < LLMUnitySetup.modelOptions.Length; i++) + { + string url = LLMUnitySetup.modelOptions[i].Item2; + if (existingOptions.Contains(url)) continue; + modelOptions.Add(LLMUnitySetup.modelOptions[i].Item1); + modelURLs.Add(url); + } + } + + List getColumnPositions(float offsetX) + { + List offsets = new List(); + float[] widths = new float[] {actionColumnWidth, nameColumnWidth, textColumnWidth, textColumnWidth, includeInBuildColumnWidth}; + float offset = offsetX; + foreach (float width in widths) + { + offsets.Add(offset); + offset += width + elementPadding; + } + return new List(){offsets.ToArray(), widths}; + } + + void UpdateModels(bool resetOptions = false) + { + LLMManager.Save(); + if (resetOptions) ResetModelOptions(); + Repaint(); + } + + void OnEnable() + { + ResetModelOptions(); + trashIcon = new GUIContent(Resources.Load("llmunity_trash_icon"), "Delete Model"); + + modelList = new ReorderableList(LLMManager.modelEntries, typeof(ModelEntry), true, true, true, true) + { + drawElementCallback = async(rect, index, isActive, isFocused) => + { + if (index >= LLMManager.modelEntries.Count) return; + + List positions = getColumnPositions(rect.x); + float[] offsets = positions[0]; + float[] widths = positions[1]; + var actionRect = new Rect(offsets[0], rect.y, widths[0], EditorGUIUtility.singleLineHeight); + var nameRect = new Rect(offsets[1], rect.y, widths[1], EditorGUIUtility.singleLineHeight); + var urlRect = new Rect(offsets[2], rect.y, widths[2], EditorGUIUtility.singleLineHeight); + var pathRect = new Rect(offsets[3], rect.y, widths[3], EditorGUIUtility.singleLineHeight); + var includeInBuildRect = new Rect(offsets[4], rect.y, widths[4], EditorGUIUtility.singleLineHeight); + var entry = LLMManager.modelEntries[index]; + + bool hasPath = entry.localPath != null && entry.localPath != ""; + bool hasURL = entry.url != null && entry.url != ""; + + + if (GUI.Button(actionRect, trashIcon)) + { + LLMManager.modelEntries.Remove(entry); + UpdateModels(true); + } + + DrawCopyableLabel(nameRect, entry.name); + + if (hasURL) + { + DrawCopyableLabel(urlRect, entry.url); + } + else if (hasPath) + { + string newURL = EditorGUI.TextField(urlRect, entry.url); + if (newURL != entry.url) + { + entry.url = newURL; + UpdateModels(); + } + } + else + { + urlRect.width = PropertyEditor.buttonWidth; + int newIndex = EditorGUI.Popup(urlRect, 0, modelOptions.ToArray()); + if (newIndex != 0) + { + await LLMManager.DownloadModel(entry, modelURLs[newIndex], modelOptions[newIndex]); + UpdateModels(true); + } + } + + if (hasPath) + { + DrawCopyableLabel(pathRect, entry.localPath); + } + else + { + pathRect.width = PropertyEditor.buttonWidth; + if (GUI.Button(pathRect, "Load model")) + { + EditorApplication.delayCall += () => + { + string path = EditorUtility.OpenFilePanelWithFilters("Select a gguf model file", "", new string[] { "Model Files", "gguf" }); + if (!string.IsNullOrEmpty(path)) + { + entry.localPath = path; + entry.name = LLMManager.ModelPathToName(path); + UpdateModels(); + } + }; + } + } + + bool includeInBuild = EditorGUI.ToggleLeft(includeInBuildRect, "", entry.includeInBuild); + if (includeInBuild != entry.includeInBuild) + { + entry.includeInBuild = includeInBuild; + UpdateModels(); + } + }, + drawHeaderCallback = (rect) => + { + List positions = getColumnPositions(rect.x + ReorderableList.Defaults.dragHandleWidth - ReorderableList.Defaults.padding + 1); + float[] offsets = positions[0]; + float[] widths = positions[1]; + EditorGUI.LabelField(new Rect(offsets[0], rect.y, widths[0], EditorGUIUtility.singleLineHeight), ""); + EditorGUI.LabelField(new Rect(offsets[1], rect.y, widths[1], EditorGUIUtility.singleLineHeight), "Model"); + EditorGUI.LabelField(new Rect(offsets[2], rect.y, widths[2], EditorGUIUtility.singleLineHeight), "URL"); + EditorGUI.LabelField(new Rect(offsets[3], rect.y, widths[3], EditorGUIUtility.singleLineHeight), "Local Path"); + EditorGUI.LabelField(new Rect(offsets[4], rect.y, widths[4], EditorGUIUtility.singleLineHeight), "Build"); + } + }; + } + + private void DrawCopyableLabel(Rect rect, string text) + { + EditorGUI.LabelField(rect, text); + if (Event.current.type == EventType.ContextClick && rect.Contains(Event.current.mousePosition)) + { + GenericMenu menu = new GenericMenu(); + menu.AddItem(new GUIContent("Copy"), false, () => CopyToClipboard(text)); + menu.ShowAsContext(); + Event.current.Use(); + } + } + + private void CopyToClipboard(string text) + { + TextEditor te = new TextEditor + { + text = text + }; + te.SelectAll(); + te.Copy(); + } + + public override void OnInspectorGUI() + { + serializedObject.Update(); + modelList.DoLayoutList(); + serializedObject.ApplyModifiedProperties(); + } + } +} diff --git a/Editor/LLMManagerEditor.cs.meta b/Editor/LLMManagerEditor.cs.meta new file mode 100644 index 00000000..8e49e889 --- /dev/null +++ b/Editor/LLMManagerEditor.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: 4209594efd29689d490881e0f61b9270 +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: diff --git a/Runtime/LLMManager.cs b/Runtime/LLMManager.cs new file mode 100644 index 00000000..0d8d3539 --- /dev/null +++ b/Runtime/LLMManager.cs @@ -0,0 +1,64 @@ +#if UNITY_EDITOR +using System; +using System.Collections.Generic; +using System.IO; +using System.Threading.Tasks; +using UnityEditor; +using UnityEngine; + +namespace LLMUnity +{ + [Serializable] + public class ModelEntry + { + public string name; + public string url; + public string localPath; + public bool includeInBuild; + } + + [Serializable] + public class ModelEntryList + { + public List modelEntries; + } + + public class LLMManager : MonoBehaviour + { + public static List modelEntries = new List(); + + [InitializeOnLoadMethod] + static void InitializeOnLoad() + { + Load(); + } + + public static string ModelPathToName(string path) + { + return Path.GetFileNameWithoutExtension(path.Split("?")[0]); + } + + public static async Task DownloadModel(ModelEntry entry, string url, string name = null) + { + string modelName = Path.GetFileName(url).Split("?")[0]; + string modelPath = Path.Combine(LLMUnitySetup.modelDownloadPath, modelName); + await LLMUnitySetup.DownloadFile(url, modelPath); + entry.name = name == null ? ModelPathToName(url) : name; + entry.url = url; + entry.localPath = modelPath; + } + + public static void Save() + { + Directory.CreateDirectory(Path.GetDirectoryName(LLMUnitySetup.modelListPath)); + File.WriteAllText(LLMUnitySetup.modelListPath, JsonUtility.ToJson(new ModelEntryList { modelEntries = modelEntries })); + } + + public static void Load() + { + if (!File.Exists(LLMUnitySetup.modelListPath)) return; + modelEntries = JsonUtility.FromJson(File.ReadAllText(LLMUnitySetup.modelListPath)).modelEntries; + } + } +} +#endif diff --git a/Runtime/LLMManager.cs.meta b/Runtime/LLMManager.cs.meta new file mode 100644 index 00000000..ea565c07 --- /dev/null +++ b/Runtime/LLMManager.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: 936a5c66e859e31489f7ab1b78acb987 +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: diff --git a/Runtime/LLMUnitySetup.cs b/Runtime/LLMUnitySetup.cs index bc9d9fb1..cb5d2b12 100644 --- a/Runtime/LLMUnitySetup.cs +++ b/Runtime/LLMUnitySetup.cs @@ -78,6 +78,10 @@ public class LLMUnitySetup public static string LlamaLibURL = $"https://github.com/undreamai/LlamaLib/releases/download/{LlamaLibVersion}/undreamai-{LlamaLibVersion}-llamacpp.zip"; /// LlamaLib path public static string libraryPath = GetAssetPath(Path.GetFileName(LlamaLibURL).Replace(".zip", "")); + /// Model download path + public static string modelDownloadPath = Path.Combine(Environment.GetFolderPath(Environment.SpecialFolder.ApplicationData), "LLMUnity"); + /// Model list for project + public static string modelListPath = Path.Combine(Application.temporaryCachePath, "modelCache.json"); /// Default models for download [HideInInspector] public static readonly (string, string)[] modelOptions = new(string, string)[] @@ -86,6 +90,7 @@ public class LLMUnitySetup ("Mistral 7B Instruct v0.2 (medium, best overall)", "https://huggingface.co/TheBloke/Mistral-7B-Instruct-v0.2-GGUF/resolve/main/mistral-7b-instruct-v0.2.Q4_K_M.gguf?download=true"), ("OpenHermes 2.5 7B (medium, best for conversation)", "https://huggingface.co/TheBloke/OpenHermes-2.5-Mistral-7B-GGUF/resolve/main/openhermes-2.5-mistral-7b.Q4_K_M.gguf?download=true"), ("Phi 3 (small, great)", "https://huggingface.co/microsoft/Phi-3-mini-4k-instruct-gguf/resolve/main/Phi-3-mini-4k-instruct-q4.gguf?download=true"), + ("Test", "https://huggingface.co/afrideva/smol_llama-220M-openhermes-GGUF/resolve/main/smol_llama-220m-openhermes.q4_k_m.gguf?download=true"), }; /// Add callback function to call for error logs From c7db57b07f461862dcd543c3e8b221b1372e28b9 Mon Sep 17 00:00:00 2001 From: Antonis Makropoulos Date: Tue, 23 Jul 2024 17:12:38 +0300 Subject: [PATCH 03/26] move LLMManager Editor inside LLM --- Editor/LLMEditor.cs | 198 +++++++++++++++++++++++++++----- Editor/LLMManagerEditor.cs | 186 ------------------------------ Editor/LLMManagerEditor.cs.meta | 11 -- Runtime/LLM.cs | 1 + Runtime/LLMManager.cs | 2 +- 5 files changed, 171 insertions(+), 227 deletions(-) delete mode 100644 Editor/LLMManagerEditor.cs delete mode 100644 Editor/LLMManagerEditor.cs.meta diff --git a/Editor/LLMEditor.cs b/Editor/LLMEditor.cs index cb778f9c..7c5c0920 100644 --- a/Editor/LLMEditor.cs +++ b/Editor/LLMEditor.cs @@ -2,6 +2,7 @@ using System.Collections.Generic; using System.Linq; using UnityEditor; +using UnityEditorInternal; using UnityEngine; namespace LLMUnity @@ -9,6 +10,16 @@ namespace LLMUnity [CustomEditor(typeof(LLM))] public class LLMEditor : PropertyEditor { + private ReorderableList modelList; + static float nameColumnWidth = 250f; + static float textColumnWidth = 150f; + static float includeInBuildColumnWidth = 50f; + static float actionColumnWidth = 30f; + static int elementPadding = 10; + static GUIContent trashIcon; + static List modelOptions; + static List modelURLs; + protected override Type[] GetPropertyTypes() { return new Type[] { typeof(LLM) }; @@ -24,37 +35,10 @@ public void AddModelLoadersSettings(SerializedObject llmScriptSO, LLM llmScript) public void AddModelLoaders(SerializedObject llmScriptSO, LLM llmScript) { - EditorGUILayout.BeginHorizontal(); - - string[] options = new string[LLMUnitySetup.modelOptions.Length]; - for (int i = 0; i < LLMUnitySetup.modelOptions.Length; i++) - { - options[i] = LLMUnitySetup.modelOptions[i].Item1; - } - - int newIndex = EditorGUILayout.Popup("Model", llmScript.SelectedModel, options); - if (newIndex != llmScript.SelectedModel) - { - llmScript.DownloadDefaultModel(newIndex); - } - - if (GUILayout.Button("Load model", GUILayout.Width(buttonWidth))) - { - EditorApplication.delayCall += () => - { - string path = EditorUtility.OpenFilePanelWithFilters("Select a gguf model file", "", new string[] { "Model Files", "gguf" }); - if (!string.IsNullOrEmpty(path)) - { - llmScript.ResetSelectedModel(); - llmScript.SetModel(path); - } - }; - } - EditorGUILayout.EndHorizontal(); - + modelList.DoLayoutList(); string[] templateOptions = ChatTemplate.templatesDescription.Keys.ToList().ToArray(); int index = Array.IndexOf(ChatTemplate.templatesDescription.Values.ToList().ToArray(), llmScript.chatTemplate); - newIndex = EditorGUILayout.Popup("Chat Template", index, templateOptions); + int newIndex = EditorGUILayout.Popup("Chat Template", index, templateOptions); if (newIndex != index) { llmScript.SetTemplate(ChatTemplate.templatesDescription[templateOptions[newIndex]]); @@ -102,6 +86,162 @@ void ShowProgress(float progress, string progressText) if (progress != 1) EditorGUI.ProgressBar(EditorGUILayout.GetControlRect(), progress, progressText); } + static void ResetModelOptions() + { + List existingOptions = new List(); + foreach (ModelEntry entry in LLMManager.modelEntries) existingOptions.Add(entry.url); + modelOptions = new List(); + modelURLs = new List(); + for (int i = 0; i < LLMUnitySetup.modelOptions.Length; i++) + { + string url = LLMUnitySetup.modelOptions[i].Item2; + if (existingOptions.Contains(url)) continue; + modelOptions.Add(LLMUnitySetup.modelOptions[i].Item1); + modelURLs.Add(url); + } + } + + List getColumnPositions(float offsetX) + { + List offsets = new List(); + float[] widths = new float[] {actionColumnWidth, nameColumnWidth, textColumnWidth, textColumnWidth, includeInBuildColumnWidth}; + float offset = offsetX; + foreach (float width in widths) + { + offsets.Add(offset); + offset += width + elementPadding; + } + return new List(){offsets.ToArray(), widths}; + } + + void UpdateModels(bool resetOptions = false) + { + LLMManager.Save(); + if (resetOptions) ResetModelOptions(); + Repaint(); + } + + void OnEnable() + { + ResetModelOptions(); + trashIcon = new GUIContent(Resources.Load("llmunity_trash_icon"), "Delete Model"); + + modelList = new ReorderableList(LLMManager.modelEntries, typeof(ModelEntry), true, true, true, true) + { + drawElementCallback = async(rect, index, isActive, isFocused) => + { + if (index >= LLMManager.modelEntries.Count) return; + + List positions = getColumnPositions(rect.x); + float[] offsets = positions[0]; + float[] widths = positions[1]; + var actionRect = new Rect(offsets[0], rect.y, widths[0], EditorGUIUtility.singleLineHeight); + var nameRect = new Rect(offsets[1], rect.y, widths[1], EditorGUIUtility.singleLineHeight); + var urlRect = new Rect(offsets[2], rect.y, widths[2], EditorGUIUtility.singleLineHeight); + var pathRect = new Rect(offsets[3], rect.y, widths[3], EditorGUIUtility.singleLineHeight); + var includeInBuildRect = new Rect(offsets[4], rect.y, widths[4], EditorGUIUtility.singleLineHeight); + var entry = LLMManager.modelEntries[index]; + + bool hasPath = entry.localPath != null && entry.localPath != ""; + bool hasURL = entry.url != null && entry.url != ""; + + if (GUI.Button(actionRect, trashIcon)) + { + LLMManager.modelEntries.Remove(entry); + UpdateModels(true); + } + + DrawCopyableLabel(nameRect, entry.name); + + if (hasURL) + { + DrawCopyableLabel(urlRect, entry.url); + } + else if (hasPath) + { + string newURL = EditorGUI.TextField(urlRect, entry.url); + if (newURL != entry.url) + { + entry.url = newURL; + UpdateModels(); + } + } + else + { + urlRect.width = buttonWidth; + int newIndex = EditorGUI.Popup(urlRect, 0, modelOptions.ToArray()); + if (newIndex != 0) + { + await LLMManager.DownloadModel(entry, modelURLs[newIndex], modelOptions[newIndex]); + UpdateModels(true); + } + } + + if (hasPath) + { + DrawCopyableLabel(pathRect, entry.localPath); + } + else + { + pathRect.width = buttonWidth; + if (GUI.Button(pathRect, "Load model")) + { + EditorApplication.delayCall += () => + { + string path = EditorUtility.OpenFilePanelWithFilters("Select a gguf model file", "", new string[] { "Model Files", "gguf" }); + if (!string.IsNullOrEmpty(path)) + { + entry.localPath = path; + entry.name = LLMManager.ModelPathToName(path); + UpdateModels(); + } + }; + } + } + + bool includeInBuild = EditorGUI.ToggleLeft(includeInBuildRect, "", entry.includeInBuild); + if (includeInBuild != entry.includeInBuild) + { + entry.includeInBuild = includeInBuild; + UpdateModels(); + } + }, + drawHeaderCallback = (rect) => + { + List positions = getColumnPositions(rect.x + ReorderableList.Defaults.dragHandleWidth - ReorderableList.Defaults.padding + 1); + float[] offsets = positions[0]; + float[] widths = positions[1]; + EditorGUI.LabelField(new Rect(offsets[0], rect.y, widths[0], EditorGUIUtility.singleLineHeight), ""); + EditorGUI.LabelField(new Rect(offsets[1], rect.y, widths[1], EditorGUIUtility.singleLineHeight), "Model"); + EditorGUI.LabelField(new Rect(offsets[2], rect.y, widths[2], EditorGUIUtility.singleLineHeight), "URL"); + EditorGUI.LabelField(new Rect(offsets[3], rect.y, widths[3], EditorGUIUtility.singleLineHeight), "Local Path"); + EditorGUI.LabelField(new Rect(offsets[4], rect.y, widths[4], EditorGUIUtility.singleLineHeight), "Build"); + } + }; + } + + private void DrawCopyableLabel(Rect rect, string text) + { + EditorGUI.LabelField(rect, text); + if (Event.current.type == EventType.ContextClick && rect.Contains(Event.current.mousePosition)) + { + GenericMenu menu = new GenericMenu(); + menu.AddItem(new GUIContent("Copy"), false, () => CopyToClipboard(text)); + menu.ShowAsContext(); + Event.current.Use(); + } + } + + private void CopyToClipboard(string text) + { + TextEditor te = new TextEditor + { + text = text + }; + te.SelectAll(); + te.Copy(); + } + public override void OnInspectorGUI() { LLM llmScript = (LLM)target; diff --git a/Editor/LLMManagerEditor.cs b/Editor/LLMManagerEditor.cs deleted file mode 100644 index d4adaa18..00000000 --- a/Editor/LLMManagerEditor.cs +++ /dev/null @@ -1,186 +0,0 @@ -using UnityEditor; -using UnityEngine; -using UnityEditorInternal; -using System; -using System.Collections.Generic; - -namespace LLMUnity -{ - [CustomEditor(typeof(LLMManager))] - public class LLMManagerEditor : Editor - { - private ReorderableList modelList; - static float nameColumnWidth = 250f; - static float textColumnWidth = 150f; - static float includeInBuildColumnWidth = 50f; - static float actionColumnWidth = 30f; - static int elementPadding = 10; - static GUIContent trashIcon; - static List modelOptions; - static List modelURLs; - - static void ResetModelOptions() - { - List existingOptions = new List(); - foreach (ModelEntry entry in LLMManager.modelEntries) existingOptions.Add(entry.url); - modelOptions = new List(); - modelURLs = new List(); - for (int i = 0; i < LLMUnitySetup.modelOptions.Length; i++) - { - string url = LLMUnitySetup.modelOptions[i].Item2; - if (existingOptions.Contains(url)) continue; - modelOptions.Add(LLMUnitySetup.modelOptions[i].Item1); - modelURLs.Add(url); - } - } - - List getColumnPositions(float offsetX) - { - List offsets = new List(); - float[] widths = new float[] {actionColumnWidth, nameColumnWidth, textColumnWidth, textColumnWidth, includeInBuildColumnWidth}; - float offset = offsetX; - foreach (float width in widths) - { - offsets.Add(offset); - offset += width + elementPadding; - } - return new List(){offsets.ToArray(), widths}; - } - - void UpdateModels(bool resetOptions = false) - { - LLMManager.Save(); - if (resetOptions) ResetModelOptions(); - Repaint(); - } - - void OnEnable() - { - ResetModelOptions(); - trashIcon = new GUIContent(Resources.Load("llmunity_trash_icon"), "Delete Model"); - - modelList = new ReorderableList(LLMManager.modelEntries, typeof(ModelEntry), true, true, true, true) - { - drawElementCallback = async(rect, index, isActive, isFocused) => - { - if (index >= LLMManager.modelEntries.Count) return; - - List positions = getColumnPositions(rect.x); - float[] offsets = positions[0]; - float[] widths = positions[1]; - var actionRect = new Rect(offsets[0], rect.y, widths[0], EditorGUIUtility.singleLineHeight); - var nameRect = new Rect(offsets[1], rect.y, widths[1], EditorGUIUtility.singleLineHeight); - var urlRect = new Rect(offsets[2], rect.y, widths[2], EditorGUIUtility.singleLineHeight); - var pathRect = new Rect(offsets[3], rect.y, widths[3], EditorGUIUtility.singleLineHeight); - var includeInBuildRect = new Rect(offsets[4], rect.y, widths[4], EditorGUIUtility.singleLineHeight); - var entry = LLMManager.modelEntries[index]; - - bool hasPath = entry.localPath != null && entry.localPath != ""; - bool hasURL = entry.url != null && entry.url != ""; - - - if (GUI.Button(actionRect, trashIcon)) - { - LLMManager.modelEntries.Remove(entry); - UpdateModels(true); - } - - DrawCopyableLabel(nameRect, entry.name); - - if (hasURL) - { - DrawCopyableLabel(urlRect, entry.url); - } - else if (hasPath) - { - string newURL = EditorGUI.TextField(urlRect, entry.url); - if (newURL != entry.url) - { - entry.url = newURL; - UpdateModels(); - } - } - else - { - urlRect.width = PropertyEditor.buttonWidth; - int newIndex = EditorGUI.Popup(urlRect, 0, modelOptions.ToArray()); - if (newIndex != 0) - { - await LLMManager.DownloadModel(entry, modelURLs[newIndex], modelOptions[newIndex]); - UpdateModels(true); - } - } - - if (hasPath) - { - DrawCopyableLabel(pathRect, entry.localPath); - } - else - { - pathRect.width = PropertyEditor.buttonWidth; - if (GUI.Button(pathRect, "Load model")) - { - EditorApplication.delayCall += () => - { - string path = EditorUtility.OpenFilePanelWithFilters("Select a gguf model file", "", new string[] { "Model Files", "gguf" }); - if (!string.IsNullOrEmpty(path)) - { - entry.localPath = path; - entry.name = LLMManager.ModelPathToName(path); - UpdateModels(); - } - }; - } - } - - bool includeInBuild = EditorGUI.ToggleLeft(includeInBuildRect, "", entry.includeInBuild); - if (includeInBuild != entry.includeInBuild) - { - entry.includeInBuild = includeInBuild; - UpdateModels(); - } - }, - drawHeaderCallback = (rect) => - { - List positions = getColumnPositions(rect.x + ReorderableList.Defaults.dragHandleWidth - ReorderableList.Defaults.padding + 1); - float[] offsets = positions[0]; - float[] widths = positions[1]; - EditorGUI.LabelField(new Rect(offsets[0], rect.y, widths[0], EditorGUIUtility.singleLineHeight), ""); - EditorGUI.LabelField(new Rect(offsets[1], rect.y, widths[1], EditorGUIUtility.singleLineHeight), "Model"); - EditorGUI.LabelField(new Rect(offsets[2], rect.y, widths[2], EditorGUIUtility.singleLineHeight), "URL"); - EditorGUI.LabelField(new Rect(offsets[3], rect.y, widths[3], EditorGUIUtility.singleLineHeight), "Local Path"); - EditorGUI.LabelField(new Rect(offsets[4], rect.y, widths[4], EditorGUIUtility.singleLineHeight), "Build"); - } - }; - } - - private void DrawCopyableLabel(Rect rect, string text) - { - EditorGUI.LabelField(rect, text); - if (Event.current.type == EventType.ContextClick && rect.Contains(Event.current.mousePosition)) - { - GenericMenu menu = new GenericMenu(); - menu.AddItem(new GUIContent("Copy"), false, () => CopyToClipboard(text)); - menu.ShowAsContext(); - Event.current.Use(); - } - } - - private void CopyToClipboard(string text) - { - TextEditor te = new TextEditor - { - text = text - }; - te.SelectAll(); - te.Copy(); - } - - public override void OnInspectorGUI() - { - serializedObject.Update(); - modelList.DoLayoutList(); - serializedObject.ApplyModifiedProperties(); - } - } -} diff --git a/Editor/LLMManagerEditor.cs.meta b/Editor/LLMManagerEditor.cs.meta deleted file mode 100644 index 8e49e889..00000000 --- a/Editor/LLMManagerEditor.cs.meta +++ /dev/null @@ -1,11 +0,0 @@ -fileFormatVersion: 2 -guid: 4209594efd29689d490881e0f61b9270 -MonoImporter: - externalObjects: {} - serializedVersion: 2 - defaultReferences: [] - executionOrder: 0 - icon: {instanceID: 0} - userData: - assetBundleName: - assetBundleVariant: diff --git a/Runtime/LLM.cs b/Runtime/LLM.cs index 76e31ab4..1811b942 100644 --- a/Runtime/LLM.cs +++ b/Runtime/LLM.cs @@ -97,6 +97,7 @@ public class LLM : MonoBehaviour public bool failed { get; protected set; } = false; /// \cond HIDE + public LLMManager llmManager = new LLMManager(); public int SelectedModel = 0; [HideInInspector] public float modelProgress = 1; [HideInInspector] public float loraProgress = 1; diff --git a/Runtime/LLMManager.cs b/Runtime/LLMManager.cs index 0d8d3539..2496a8b6 100644 --- a/Runtime/LLMManager.cs +++ b/Runtime/LLMManager.cs @@ -23,7 +23,7 @@ public class ModelEntryList public List modelEntries; } - public class LLMManager : MonoBehaviour + public class LLMManager { public static List modelEntries = new List(); From 9083761f6c76511251dbf577038493eaa4278efc Mon Sep 17 00:00:00 2001 From: Antonis Makropoulos Date: Tue, 23 Jul 2024 20:05:17 +0300 Subject: [PATCH 04/26] add trash icon --- Resources.meta | 8 ++ Resources/llmunity_trash_icon.png | Bin 0 -> 3334 bytes Resources/llmunity_trash_icon.png.meta | 140 +++++++++++++++++++++++++ 3 files changed, 148 insertions(+) create mode 100644 Resources.meta create mode 100644 Resources/llmunity_trash_icon.png create mode 100644 Resources/llmunity_trash_icon.png.meta diff --git a/Resources.meta b/Resources.meta new file mode 100644 index 00000000..cd69ccb5 --- /dev/null +++ b/Resources.meta @@ -0,0 +1,8 @@ +fileFormatVersion: 2 +guid: 688bae55bf18bd75dbc7fee333923c15 +folderAsset: yes +DefaultImporter: + externalObjects: {} + userData: + assetBundleName: + assetBundleVariant: diff --git a/Resources/llmunity_trash_icon.png b/Resources/llmunity_trash_icon.png new file mode 100644 index 0000000000000000000000000000000000000000..7457cc94936e3f63e821d5fdea61645c7af4e978 GIT binary patch literal 3334 zcmeH}cQD-T7RP@q!U}=}Q6hMwCM=1z5_PpG5hX>H$jZq>74F_sR8m$^g{f<3YH8os(bdy`@X*l6*yNF^xrL>bwT-PE+}_dY zAI>hW9-dy_2p=TM*B>1a7!>?8G%O-A>KP^`HZDFP>G_M~l+;(R)3dU%Ik|cH1%>a6 zic3n%$}1|XK2+msYU}Fp4UJ9BEgwI%wRd!W?)uW*)7$rTU~p)7WOQtNVseT=oSvDT zn_pO5`nJ46T3y@N-1@$~v%9x{aCmh5lff|g#T0frny{5mZsz3mr?t?A zsTlhcwq{PPQ%pF3o&E2G`gr$HBWz}}x}1o7y}F_p#1^|rm{&S3}!#Lf|Qw+#rG{@-gv>> z5@)647lutMnb6$8Er>EoG9wZe(!+;)2y6;$+()knJ;UMMGzpo$`G2^m`VM(-J~j)v z)subcMozlC{rIzynMwMg;DGj0zEMM7Z85xkmlgOlKN+XFid1Z@8h6!=m#m%e$kbX| zvoLaPw_+s5jE!cE=J3vRC`PX^;gIr}keP)8>d~oPd^Udpa(IAoGp~|2(r-$X~iZ6zu<3C)nnVNt4!zFS$f0ic^S!ZyZf#09 zB49T0(XR{)3oKmRZ6&L6I_2bjUsf_$tpFeA*I=6?*I%jHf`@F3FG%^7^k{fg@1Zzn zSqi&cg_m3h9g_qDaBof+=NY|eBt=a9SI|9oxrKU4$cdQ+h4-84+3e1xw1f3o zEfKlScaK9Xuqs9rV1rP~%UFO~Sh+JC5&Of`SJs?5D#6L{xMhs~Q#E$bmi1AA_p#V+ ztYSO7Gdfp>8$0yKel=^?;A?5^QYb^9Iw3^X=*t-=_u#|ZN>qH2_99(gpJ7jY?5{^`QZc$}>4EVWZt};&Rvx1AQ-L z9|H|9<=cZx?w{t}0=b=!s6oiLG~Ed^VcVNBJyhu)CO|XjdzP`6RKf5O zrLexU!tVD!Jl$dAYo`>wNZ+jfpCN}@dx?UsHqf$(R;}-K*ZE2&gNj=ooM~sPua_1I zLSq5xBw0;1p5($i_|Kmlfih8ka;u zQ2A;-!%}t}(UB?jHr4W|&(rmpwGXJ%QOssiNoGUyaqIc{2)A_CMPpZ&&AK*$-epHF z+sEhQmUl96rOIj(43i_v1Uq)2m?uGHAxB%NoOorQ?TcsCz2dbTS-xnSqbv&QeX>>m$)NaAh_>8}=bXr01}l|L+}IKH}7+P5?=!s z`OB1c!aKKKN*e5wXrke451$ihn%{v~3!!gL(d_SUREvDta}(zSk`DxVOS?RH{9Jk0 z_FvQVCsfEY_CICcleZY*nB!iNQceFo`9q^3oGrkWPCih?o0m~gCeuvsEV#VD7~$)} zD$*D%4H_KkD0pj>QawhB#yH#o8Q@;;=)}|>22E|ynKmqGc2M3)^&;+W6IsWU>vM!8 zDb&@OGT#Mi{Bm1rzM%ix=Eyylf^+w4=)L0a9*7z|p@70OE!&fyBMgzJv^{aOa*Qe; zMs3R;D~J*`z2lW%*!%xVpS5n>Eb31{P9eml>Kcw&D&KY|**bw>9nJ;cOi_06-lx|_ zsK(ldO;~jU36h`dt3^X%9li9;Wki-opY7f@y{V%eP0s^s#}C#=zH|_COq9KdP||^o z)m9tkhjRtV6dK%*GRB;v{u*BZ(WnUM0n#pk2e1 zJUUi&J`2=gsDigQ#@A$qD5X)XT>icdAE9v37*wVZ3Mq#Y%9;&Jl+!AWAX*iGej>Q= z!L9LZW?GGLK^>Wg13rtl@uUrOhN&pCc}ZnN&=>9%D*U8Ay3H1`Bhg>FqsA%|k4uSK+f|8sBGfxVi($5(DN%|GI@mY+sry=;d& zesHvEo^`y7in;(rb2$>%KfefEWAkMWF*ZHlshL~C6t62;)D%niGOR}L>-U+-4Rqou zQlCzCa^g1BiWL0L;j(5Cv;KX}T|eHGs>~txYlIrxUhT*B&B*Q#d%Vi5&G$HlBTsoY z<#^^?XTC>j84Xl|%Jg0|%*}u;5vf;5`EIE~JUUwhHkr&@Xy$BG4kmbyoxZao2@!CQC%jJSvaNa`L2Rlk(8p<}!d;)yC7he|@kAP^`7f)Vxe{4ayMr@gZy`mYT~egh+91E8(04=Y!- G3I8V^H`opU literal 0 HcmV?d00001 diff --git a/Resources/llmunity_trash_icon.png.meta b/Resources/llmunity_trash_icon.png.meta new file mode 100644 index 00000000..9b334b18 --- /dev/null +++ b/Resources/llmunity_trash_icon.png.meta @@ -0,0 +1,140 @@ +fileFormatVersion: 2 +guid: 0e04eced3ed2d120e84e7c10c8b32ddc +TextureImporter: + internalIDToNameTable: [] + externalObjects: {} + serializedVersion: 12 + mipmaps: + mipMapMode: 0 + enableMipMap: 1 + sRGBTexture: 1 + linearTexture: 0 + fadeOut: 0 + borderMipMap: 0 + mipMapsPreserveCoverage: 0 + alphaTestReferenceValue: 0.5 + mipMapFadeDistanceStart: 1 + mipMapFadeDistanceEnd: 3 + bumpmap: + convertToNormalMap: 0 + externalNormalMap: 0 + heightScale: 0.25 + normalMapFilter: 0 + flipGreenChannel: 0 + isReadable: 0 + streamingMipmaps: 0 + streamingMipmapsPriority: 0 + vTOnly: 0 + ignoreMipmapLimit: 0 + grayScaleToAlpha: 0 + generateCubemap: 6 + cubemapConvolution: 0 + seamlessCubemap: 0 + textureFormat: 1 + maxTextureSize: 2048 + textureSettings: + serializedVersion: 2 + filterMode: 1 + aniso: 1 + mipBias: 0 + wrapU: 0 + wrapV: 0 + wrapW: 0 + nPOTScale: 1 + lightmap: 0 + compressionQuality: 50 + spriteMode: 0 + spriteExtrude: 1 + spriteMeshType: 1 + alignment: 0 + spritePivot: {x: 0.5, y: 0.5} + spritePixelsToUnits: 100 + spriteBorder: {x: 0, y: 0, z: 0, w: 0} + spriteGenerateFallbackPhysicsShape: 1 + alphaUsage: 1 + alphaIsTransparency: 0 + spriteTessellationDetail: -1 + textureType: 0 + textureShape: 1 + singleChannelComponent: 0 + flipbookRows: 1 + flipbookColumns: 1 + maxTextureSizeSet: 0 + compressionQualitySet: 0 + textureFormatSet: 0 + ignorePngGamma: 0 + applyGammaDecoding: 0 + swizzle: 50462976 + cookieLightType: 0 + platformSettings: + - serializedVersion: 3 + buildTarget: DefaultTexturePlatform + maxTextureSize: 2048 + resizeAlgorithm: 0 + textureFormat: -1 + textureCompression: 1 + compressionQuality: 50 + crunchedCompression: 0 + allowsAlphaSplitting: 0 + overridden: 0 + ignorePlatformSupport: 0 + androidETC2FallbackOverride: 0 + forceMaximumCompressionQuality_BC6H_BC7: 0 + - serializedVersion: 3 + buildTarget: Standalone + maxTextureSize: 2048 + resizeAlgorithm: 0 + textureFormat: -1 + textureCompression: 1 + compressionQuality: 50 + crunchedCompression: 0 + allowsAlphaSplitting: 0 + overridden: 0 + ignorePlatformSupport: 0 + androidETC2FallbackOverride: 0 + forceMaximumCompressionQuality_BC6H_BC7: 0 + - serializedVersion: 3 + buildTarget: Android + maxTextureSize: 2048 + resizeAlgorithm: 0 + textureFormat: -1 + textureCompression: 1 + compressionQuality: 50 + crunchedCompression: 0 + allowsAlphaSplitting: 0 + overridden: 0 + ignorePlatformSupport: 0 + androidETC2FallbackOverride: 0 + forceMaximumCompressionQuality_BC6H_BC7: 0 + - serializedVersion: 3 + buildTarget: Server + maxTextureSize: 2048 + resizeAlgorithm: 0 + textureFormat: -1 + textureCompression: 1 + compressionQuality: 50 + crunchedCompression: 0 + allowsAlphaSplitting: 0 + overridden: 0 + ignorePlatformSupport: 0 + androidETC2FallbackOverride: 0 + forceMaximumCompressionQuality_BC6H_BC7: 0 + spriteSheet: + serializedVersion: 2 + sprites: [] + outline: [] + physicsShape: [] + bones: [] + spriteID: + internalID: 0 + vertices: [] + indices: + edges: [] + weights: [] + secondaryTextures: [] + nameFileIdTable: {} + mipmapLimitGroupName: + pSDRemoveMatte: 0 + userData: + assetBundleName: + assetBundleVariant: From abdb8aab58517de221735b4f43ef9a54c5a0af60 Mon Sep 17 00:00:00 2001 From: Antonis Makropoulos Date: Tue, 23 Jul 2024 20:05:58 +0300 Subject: [PATCH 05/26] migrate LLMManager Editor to LLM Editor --- Editor/LLMBuildProcessor.cs | 4 +- Editor/LLMEditor.cs | 169 ++++++++++++++++++++---------------- Runtime/LLM.cs | 118 +++---------------------- Runtime/LLMManager.cs | 74 ++++++++++++++-- Runtime/LLMUnitySetup.cs | 7 ++ 5 files changed, 181 insertions(+), 191 deletions(-) diff --git a/Editor/LLMBuildProcessor.cs b/Editor/LLMBuildProcessor.cs index 2e17bf84..05cc6238 100644 --- a/Editor/LLMBuildProcessor.cs +++ b/Editor/LLMBuildProcessor.cs @@ -117,8 +117,8 @@ static void HideModels() { foreach (LLM llm in FindObjectsOfType()) { - if (!llm.downloadOnBuild) continue; - if (llm.modelURL != "") MoveAssetAndMeta(LLMUnitySetup.GetAssetPath(llm.model), Path.Combine(tempDir, Path.GetFileName(llm.model))); + // if (!llm.downloadOnBuild) continue; + // if (llm.modelURL != "") MoveAssetAndMeta(LLMUnitySetup.GetAssetPath(llm.model), Path.Combine(tempDir, Path.GetFileName(llm.model))); if (llm.loraURL != "") MoveAssetAndMeta(LLMUnitySetup.GetAssetPath(llm.lora), Path.Combine(tempDir, Path.GetFileName(llm.lora))); } } diff --git a/Editor/LLMEditor.cs b/Editor/LLMEditor.cs index 7c5c0920..63f95455 100644 --- a/Editor/LLMEditor.cs +++ b/Editor/LLMEditor.cs @@ -11,14 +11,16 @@ namespace LLMUnity public class LLMEditor : PropertyEditor { private ReorderableList modelList; - static float nameColumnWidth = 250f; + static float nameColumnWidth = 150f; + static float templateColumnWidth = 100f; static float textColumnWidth = 150f; static float includeInBuildColumnWidth = 50f; - static float actionColumnWidth = 30f; + static float actionColumnWidth = 20f; static int elementPadding = 10; static GUIContent trashIcon; static List modelOptions; static List modelURLs; + string[] templateOptions; protected override Type[] GetPropertyTypes() { @@ -35,14 +37,12 @@ public void AddModelLoadersSettings(SerializedObject llmScriptSO, LLM llmScript) public void AddModelLoaders(SerializedObject llmScriptSO, LLM llmScript) { + float[] widths = GetColumnWidths(); + float listWidth = ReorderableList.Defaults.dragHandleWidth; + foreach (float width in widths) listWidth += width + (listWidth == 0 ? 0 : elementPadding); + EditorGUILayout.BeginVertical(GUILayout.Width(listWidth)); modelList.DoLayoutList(); - string[] templateOptions = ChatTemplate.templatesDescription.Keys.ToList().ToArray(); - int index = Array.IndexOf(ChatTemplate.templatesDescription.Values.ToList().ToArray(), llmScript.chatTemplate); - int newIndex = EditorGUILayout.Popup("Chat Template", index, templateOptions); - if (newIndex != index) - { - llmScript.SetTemplate(ChatTemplate.templatesDescription[templateOptions[newIndex]]); - } + EditorGUILayout.EndVertical(); } public void AddModelAddonLoaders(SerializedObject llmScriptSO, LLM llmScript, bool layout = true) @@ -70,14 +70,11 @@ public void AddModelAddonLoaders(SerializedObject llmScriptSO, LLM llmScript, bo public void AddModelSettings(SerializedObject llmScriptSO) { List attributeClasses = new List { typeof(ModelAttribute) }; - List excludeAttributeClasses = new List { typeof(ModelDownloadAttribute), typeof(ModelDownloadAdvancedAttribute) }; - if (llmScriptSO.FindProperty("downloadOnBuild").boolValue) excludeAttributeClasses.Remove(typeof(ModelDownloadAttribute)); if (llmScriptSO.FindProperty("advancedOptions").boolValue) { attributeClasses.Add(typeof(ModelAdvancedAttribute)); - if (llmScriptSO.FindProperty("downloadOnBuild").boolValue) excludeAttributeClasses.Remove(typeof(ModelDownloadAdvancedAttribute)); } - ShowPropertiesOfClass("", llmScriptSO, attributeClasses, false, excludeAttributeClasses); + ShowPropertiesOfClass("", llmScriptSO, attributeClasses, false); Space(); } @@ -95,23 +92,29 @@ static void ResetModelOptions() for (int i = 0; i < LLMUnitySetup.modelOptions.Length; i++) { string url = LLMUnitySetup.modelOptions[i].Item2; - if (existingOptions.Contains(url)) continue; + if (i > 0 && existingOptions.Contains(url)) continue; modelOptions.Add(LLMUnitySetup.modelOptions[i].Item1); modelURLs.Add(url); } } - List getColumnPositions(float offsetX) + float[] GetColumnWidths() + { + float[] widths = new float[] {actionColumnWidth, nameColumnWidth, templateColumnWidth, textColumnWidth, textColumnWidth, includeInBuildColumnWidth, actionColumnWidth}; + return widths; + } + + List CreateColumnRects(float x, float y) { - List offsets = new List(); - float[] widths = new float[] {actionColumnWidth, nameColumnWidth, textColumnWidth, textColumnWidth, includeInBuildColumnWidth}; - float offset = offsetX; + float[] widths = GetColumnWidths(); + float offset = x; + List rects = new List(); foreach (float width in widths) { - offsets.Add(offset); + rects.Add(new Rect(offset, y, width, EditorGUIUtility.singleLineHeight)); offset += width + elementPadding; } - return new List(){offsets.ToArray(), widths}; + return rects; } void UpdateModels(bool resetOptions = false) @@ -123,41 +126,54 @@ void UpdateModels(bool resetOptions = false) void OnEnable() { + var llmScript = (LLM)target; ResetModelOptions(); + templateOptions = ChatTemplate.templatesDescription.Keys.ToList().ToArray(); trashIcon = new GUIContent(Resources.Load("llmunity_trash_icon"), "Delete Model"); modelList = new ReorderableList(LLMManager.modelEntries, typeof(ModelEntry), true, true, true, true) { - drawElementCallback = async(rect, index, isActive, isFocused) => + drawElementCallback = (rect, index, isActive, isFocused) => { if (index >= LLMManager.modelEntries.Count) return; - List positions = getColumnPositions(rect.x); - float[] offsets = positions[0]; - float[] widths = positions[1]; - var actionRect = new Rect(offsets[0], rect.y, widths[0], EditorGUIUtility.singleLineHeight); - var nameRect = new Rect(offsets[1], rect.y, widths[1], EditorGUIUtility.singleLineHeight); - var urlRect = new Rect(offsets[2], rect.y, widths[2], EditorGUIUtility.singleLineHeight); - var pathRect = new Rect(offsets[3], rect.y, widths[3], EditorGUIUtility.singleLineHeight); - var includeInBuildRect = new Rect(offsets[4], rect.y, widths[4], EditorGUIUtility.singleLineHeight); + List rects = CreateColumnRects(rect.x, rect.y); + var selectRect = rects[0]; + var nameRect = rects[1]; + var templateRect = rects[2]; + var urlRect = rects[3]; + var pathRect = rects[4]; + var includeInBuildRect = rects[5]; + var actionRect = rects[6]; var entry = LLMManager.modelEntries[index]; bool hasPath = entry.localPath != null && entry.localPath != ""; bool hasURL = entry.url != null && entry.url != ""; - if (GUI.Button(actionRect, trashIcon)) + bool isSelected = llmScript.model == entry.localPath; + bool newSelected = EditorGUI.Toggle(selectRect, isSelected, EditorStyles.radioButton); + if (newSelected && !isSelected) { - LLMManager.modelEntries.Remove(entry); - UpdateModels(true); + llmScript.model = entry.localPath; + llmScript.SetTemplate(entry.chatTemplate); } DrawCopyableLabel(nameRect, entry.name); + int templateIndex = Array.IndexOf(ChatTemplate.templatesDescription.Values.ToList().ToArray(), entry.chatTemplate); + int newTemplateIndex = EditorGUI.Popup(templateRect, templateIndex, templateOptions); + if (newTemplateIndex != templateIndex) + { + entry.chatTemplate = ChatTemplate.templatesDescription[templateOptions[newTemplateIndex]]; + if (isSelected) llmScript.SetTemplate(entry.chatTemplate); + UpdateModels(); + } + if (hasURL) { DrawCopyableLabel(urlRect, entry.url); } - else if (hasPath) + else { string newURL = EditorGUI.TextField(urlRect, entry.url); if (newURL != entry.url) @@ -166,56 +182,63 @@ void OnEnable() UpdateModels(); } } - else + DrawCopyableLabel(pathRect, entry.localPath); + + bool includeInBuild = EditorGUI.ToggleLeft(includeInBuildRect, "", entry.includeInBuild); + if (includeInBuild != entry.includeInBuild) { - urlRect.width = buttonWidth; - int newIndex = EditorGUI.Popup(urlRect, 0, modelOptions.ToArray()); - if (newIndex != 0) - { - await LLMManager.DownloadModel(entry, modelURLs[newIndex], modelOptions[newIndex]); - UpdateModels(true); - } + entry.includeInBuild = includeInBuild; + UpdateModels(); } - if (hasPath) + if (GUI.Button(actionRect, trashIcon)) { - DrawCopyableLabel(pathRect, entry.localPath); + LLMManager.modelEntries.Remove(entry); + UpdateModels(true); } - else + }, + drawHeaderCallback = (rect) => + { + List rects = CreateColumnRects(rect.x + ReorderableList.Defaults.dragHandleWidth - ReorderableList.Defaults.padding + 1, rect.y); + EditorGUI.LabelField(rects[0], ""); + EditorGUI.LabelField(rects[1], "Model"); + EditorGUI.LabelField(rects[2], "Chat template"); + EditorGUI.LabelField(rects[3], "URL"); + EditorGUI.LabelField(rects[4], "Path"); + EditorGUI.LabelField(rects[5], "Build"); + EditorGUI.LabelField(rects[6], ""); + }, + drawFooterCallback = async(rect) => + { + Rect downloadRect = new Rect(rect.x, rect.y, buttonWidth, EditorGUIUtility.singleLineHeight); + Rect loadRect = new Rect(rect.x + buttonWidth + elementPadding, rect.y, buttonWidth, EditorGUIUtility.singleLineHeight); + + int newIndex = EditorGUI.Popup(downloadRect, 0, modelOptions.ToArray()); + if (newIndex != 0) { - pathRect.width = buttonWidth; - if (GUI.Button(pathRect, "Load model")) + await LLMManager.DownloadModel(modelURLs[newIndex], modelOptions[newIndex]); + UpdateModels(true); + } + + if (GUI.Button(loadRect, "Load model")) + { + EditorApplication.delayCall += () => { - EditorApplication.delayCall += () => + string path = EditorUtility.OpenFilePanelWithFilters("Select a gguf model file", "", new string[] { "Model Files", "gguf" }); + if (!string.IsNullOrEmpty(path)) { - string path = EditorUtility.OpenFilePanelWithFilters("Select a gguf model file", "", new string[] { "Model Files", "gguf" }); - if (!string.IsNullOrEmpty(path)) - { - entry.localPath = path; - entry.name = LLMManager.ModelPathToName(path); - UpdateModels(); - } - }; - } + LLMManager.LoadModel(path); + UpdateModels(); + } + }; } - bool includeInBuild = EditorGUI.ToggleLeft(includeInBuildRect, "", entry.includeInBuild); - if (includeInBuild != entry.includeInBuild) + bool downloadOnBuild = EditorGUILayout.Toggle("Download on Build", LLMManager.downloadOnBuild); + if (downloadOnBuild != LLMManager.downloadOnBuild) { - entry.includeInBuild = includeInBuild; + LLMManager.downloadOnBuild = downloadOnBuild; UpdateModels(); } - }, - drawHeaderCallback = (rect) => - { - List positions = getColumnPositions(rect.x + ReorderableList.Defaults.dragHandleWidth - ReorderableList.Defaults.padding + 1); - float[] offsets = positions[0]; - float[] widths = positions[1]; - EditorGUI.LabelField(new Rect(offsets[0], rect.y, widths[0], EditorGUIUtility.singleLineHeight), ""); - EditorGUI.LabelField(new Rect(offsets[1], rect.y, widths[1], EditorGUIUtility.singleLineHeight), "Model"); - EditorGUI.LabelField(new Rect(offsets[2], rect.y, widths[2], EditorGUIUtility.singleLineHeight), "URL"); - EditorGUI.LabelField(new Rect(offsets[3], rect.y, widths[3], EditorGUIUtility.singleLineHeight), "Local Path"); - EditorGUI.LabelField(new Rect(offsets[4], rect.y, widths[4], EditorGUIUtility.singleLineHeight), "Build"); } }; } @@ -250,14 +273,12 @@ public override void OnInspectorGUI() OnInspectorGUIStart(llmScriptSO); ShowProgress(LLMUnitySetup.libraryProgress, "Setup Library"); - ShowProgress(llmScript.modelProgress, "Model Downloading"); - ShowProgress(llmScript.modelCopyProgress, "Model Copying"); + ShowProgress(LLMManager.modelProgress, "Model Downloading"); + GUI.enabled = LLMUnitySetup.libraryProgress == 1 && LLMManager.modelProgress == 1; - GUI.enabled = LLMUnitySetup.libraryProgress == 1 && llmScript.modelProgress == 1 && llmScript.modelCopyProgress == 1; AddOptionsToggles(llmScriptSO); AddSetupSettings(llmScriptSO); AddModelLoadersSettings(llmScriptSO, llmScript); - GUI.enabled = true; AddChatSettings(llmScriptSO); OnInspectorGUIEnd(llmScriptSO); diff --git a/Runtime/LLM.cs b/Runtime/LLM.cs index 1811b942..6af7ea3b 100644 --- a/Runtime/LLM.cs +++ b/Runtime/LLM.cs @@ -68,14 +68,6 @@ public class LLM : MonoBehaviour [LLMAdvanced] public bool asynchronousStartup = true; /// select to not destroy the LLM GameObject when loading a new Scene. [LLMAdvanced] public bool dontDestroyOnLoad = true; - /// toggle to enable model download on build - [Model] public bool downloadOnBuild = false; - /// the path of the model being used (relative to the Assets/StreamingAssets folder). - /// Models with .gguf format are allowed. - [Model] public string model = ""; - /// the URL of the model to use. - /// Models with .gguf format are allowed. - [ModelDownload] public string modelURL = ""; /// the path of the LORA model being used (relative to the Assets/StreamingAssets folder). /// Models with .bin format are allowed. [ModelAdvanced] public string lora = ""; @@ -90,20 +82,13 @@ public class LLM : MonoBehaviour /// a base prompt to use as a base for all LLMCharacter objects [TextArea(5, 10), ChatAdvanced] public string basePrompt = ""; /// Boolean set to true if the server has started and is ready to receive requests, false otherwise. - public bool modelsDownloaded { get; protected set; } = false; - /// Boolean set to true if the server has started and is ready to receive requests, false otherwise. public bool started { get; protected set; } = false; /// Boolean set to true if the server has failed to start. public bool failed { get; protected set; } = false; /// \cond HIDE public LLMManager llmManager = new LLMManager(); - public int SelectedModel = 0; - [HideInInspector] public float modelProgress = 1; - [HideInInspector] public float loraProgress = 1; - [HideInInspector] public float modelCopyProgress = 1; - [HideInInspector] public bool modelHide = true; - + public string model = ""; public string chatTemplate = ChatTemplate.DefaultTemplate; IntPtr LLMObject = IntPtr.Zero; @@ -112,93 +97,9 @@ public class LLM : MonoBehaviour StreamWrapper logStreamWrapper = null; Thread llmThread = null; List streamWrappers = new List(); - List> modelProgressCallbacks = new List>(); - List> loraProgressCallbacks = new List>(); - - public void SetModelProgress(float progress) - { - modelProgress = progress; - foreach (Callback modelProgressCallback in modelProgressCallbacks) modelProgressCallback?.Invoke(progress); - } - - public void SetLoraProgress(float progress) - { - loraProgress = progress; - foreach (Callback loraProgressCallback in loraProgressCallbacks) loraProgressCallback?.Invoke(progress); - } /// \endcond - string CopyAsset(string path) - { -#if UNITY_EDITOR - if (!EditorApplication.isPlaying) - { - modelCopyProgress = 0; - path = LLMUnitySetup.AddAsset(path, LLMUnitySetup.GetAssetPath()); - modelCopyProgress = 1; - } -#endif - return path; - } - - public void ResetSelectedModel() - { - SelectedModel = 0; - modelURL = ""; - model = ""; - } - - public async Task DownloadDefaultModel(int optionIndex) - { - // download default model and disable model editor properties until the model is set - if (optionIndex == 0) - { - ResetSelectedModel(); - return; - } - SelectedModel = optionIndex; - string modelUrl = LLMUnitySetup.modelOptions[optionIndex].Item2; - modelURL = modelUrl; - string modelName = Path.GetFileName(modelUrl).Split("?")[0]; - await DownloadModel(modelUrl, modelName); - } - - public async Task DownloadModel(string modelUrl, string modelName, bool overwrite = false, bool setTemplate = true) - { - modelProgress = 0; - string modelPath = LLMUnitySetup.GetAssetPath(modelName); - await LLMUnitySetup.DownloadFile(modelUrl, modelPath, overwrite, (string path) => SetModel(path, setTemplate), SetModelProgress); - } - - public async Task DownloadLora(string loraUrl, string loraName, bool overwrite = false) - { - loraProgress = 0; - string loraPath = LLMUnitySetup.GetAssetPath(loraName); - await LLMUnitySetup.DownloadFile(loraUrl, loraPath, overwrite, SetLora, SetLoraProgress); - } - - public async Task DownloadModels(bool overwrite = false) - { - if (modelURL != "") await DownloadModel(modelURL, model, overwrite, false); - if (loraURL != "") await DownloadLora(loraURL, lora, overwrite); - } - - public async Task AndroidExtractModels() - { - if (!downloadOnBuild || modelURL == "") await LLMUnitySetup.AndroidExtractFile(model); - if (!downloadOnBuild || loraURL == "") await LLMUnitySetup.AndroidExtractFile(lora); - } - - public async Task WaitUntilModelsDownloaded(Callback modelProgressCallback = null, Callback loraProgressCallback = null) - { - if (modelProgressCallback != null) modelProgressCallbacks.Add(modelProgressCallback); - if (loraProgressCallback != null) loraProgressCallbacks.Add(loraProgressCallback); - while (!modelsDownloaded) await Task.Yield(); - if (modelProgressCallback != null) modelProgressCallbacks.Remove(modelProgressCallback); - if (loraProgressCallback != null) loraProgressCallbacks.Remove(loraProgressCallback); - } - public async Task WaitUntilReady() { while (!started) await Task.Yield(); @@ -210,13 +111,16 @@ public async Task WaitUntilReady() /// Models supported are in .gguf format. /// /// path to model to use (.gguf format) - public void SetModel(string path, bool setTemplate = true) + public void SetModel(string path) { // set the model and enable the model editor properties - model = CopyAsset(path); - if (setTemplate) SetTemplate(ChatTemplate.FromGGUF(LLMUnitySetup.GetAssetPath(model))); #if UNITY_EDITOR + ModelEntry entry = LLMManager.LoadModel(path); + model = entry.localPath; + SetTemplate(entry.chatTemplate); if (!EditorApplication.isPlaying) EditorUtility.SetDirty(this); +#else + model = path; #endif } @@ -228,7 +132,7 @@ public void SetModel(string path, bool setTemplate = true) /// path to LORA model to use (.bin format) public void SetLora(string path) { - lora = CopyAsset(path); + lora = path; #if UNITY_EDITOR if (!EditorApplication.isPlaying) EditorUtility.SetDirty(this); #endif @@ -297,9 +201,9 @@ protected virtual string GetLlamaccpArguments() public async void Awake() { if (!enabled) return; - if (downloadOnBuild) await DownloadModels(); - modelsDownloaded = true; - if (Application.platform == RuntimePlatform.Android) await AndroidExtractModels(); + // if (downloadOnBuild) await DownloadModels(); + // modelsDownloaded = true; + // if (Application.platform == RuntimePlatform.Android) await AndroidExtractModels(); string arguments = GetLlamaccpArguments(); if (arguments == null) return; if (asynchronousStartup) await Task.Run(() => StartLLMServer(arguments)); diff --git a/Runtime/LLMManager.cs b/Runtime/LLMManager.cs index 2496a8b6..d18cebfe 100644 --- a/Runtime/LLMManager.cs +++ b/Runtime/LLMManager.cs @@ -12,21 +12,32 @@ namespace LLMUnity public class ModelEntry { public string name; + public string chatTemplate; public string url; public string localPath; public bool includeInBuild; } [Serializable] - public class ModelEntryList + public class LLMManagerStore { + public bool downloadOnBuild; public List modelEntries; } public class LLMManager { + public static bool downloadOnBuild = false; public static List modelEntries = new List(); + /// Boolean set to true if the server has started and is ready to receive requests, false otherwise. + public static bool modelsDownloaded { get; protected set; } = false; + static List> modelProgressCallbacks = new List>(); + static List> loraProgressCallbacks = new List>(); + + [HideInInspector] public static float modelProgress = 1; + // [HideInInspector] public static float loraProgress = 1; + [InitializeOnLoadMethod] static void InitializeOnLoad() { @@ -38,26 +49,73 @@ public static string ModelPathToName(string path) return Path.GetFileNameWithoutExtension(path.Split("?")[0]); } - public static async Task DownloadModel(ModelEntry entry, string url, string name = null) + public static ModelEntry CreateEntry(string path, string url = null, string name = null) { - string modelName = Path.GetFileName(url).Split("?")[0]; - string modelPath = Path.Combine(LLMUnitySetup.modelDownloadPath, modelName); - await LLMUnitySetup.DownloadFile(url, modelPath); + ModelEntry entry = new ModelEntry(); entry.name = name == null ? ModelPathToName(url) : name; + entry.chatTemplate = ChatTemplate.FromGGUF(path); entry.url = url; - entry.localPath = modelPath; + entry.localPath = Path.GetFullPath(path).Replace('\\', '/'); + return entry; + } + + public static ModelEntry AddEntry(string path, string url = null, string name = null) + { + ModelEntry entry = CreateEntry(path, url, name); + modelEntries.Add(entry); + return entry; + } + + public static async Task WaitUntilModelsDownloaded(Callback modelProgressCallback = null, Callback loraProgressCallback = null) + { + if (modelProgressCallback != null) modelProgressCallbacks.Add(modelProgressCallback); + if (loraProgressCallback != null) loraProgressCallbacks.Add(loraProgressCallback); + while (!modelsDownloaded) await Task.Yield(); + if (modelProgressCallback != null) modelProgressCallbacks.Remove(modelProgressCallback); + if (loraProgressCallback != null) loraProgressCallbacks.Remove(loraProgressCallback); + } + + public static async Task DownloadModel(string url, string name = null) + { + foreach (ModelEntry modelEntry in modelEntries) + { + if (modelEntry.url == url) return modelEntry; + } + string modelName = Path.GetFileName(url).Split("?")[0]; + string modelPath = Path.Combine(LLMUnitySetup.modelDownloadPath, modelName); + modelProgress = 0; + await LLMUnitySetup.DownloadFile(url, modelPath, false, null, SetModelProgress); + return AddEntry(modelPath, url, name); + } + + public static ModelEntry LoadModel(string path) + { + string fullPath = Path.GetFullPath(path).Replace('\\', '/'); + foreach (ModelEntry modelEntry in modelEntries) + { + if (modelEntry.localPath == fullPath) return modelEntry; + } + return AddEntry(path); + } + + public static void SetModelProgress(float progress) + { + modelProgress = progress; + foreach (Callback modelProgressCallback in modelProgressCallbacks) modelProgressCallback?.Invoke(progress); } public static void Save() { Directory.CreateDirectory(Path.GetDirectoryName(LLMUnitySetup.modelListPath)); - File.WriteAllText(LLMUnitySetup.modelListPath, JsonUtility.ToJson(new ModelEntryList { modelEntries = modelEntries })); + File.WriteAllText(LLMUnitySetup.modelListPath, JsonUtility.ToJson(new LLMManagerStore { modelEntries = modelEntries, downloadOnBuild = downloadOnBuild })); } public static void Load() { if (!File.Exists(LLMUnitySetup.modelListPath)) return; - modelEntries = JsonUtility.FromJson(File.ReadAllText(LLMUnitySetup.modelListPath)).modelEntries; + LLMManagerStore store = JsonUtility.FromJson(File.ReadAllText(LLMUnitySetup.modelListPath)); + modelEntries = store.modelEntries; + downloadOnBuild = store.downloadOnBuild; } } } diff --git a/Runtime/LLMUnitySetup.cs b/Runtime/LLMUnitySetup.cs index cb5d2b12..f74b12fe 100644 --- a/Runtime/LLMUnitySetup.cs +++ b/Runtime/LLMUnitySetup.cs @@ -258,6 +258,13 @@ public static async Task AndroidExtractFile(string assetName, bool overwrite = f } } + public static bool IsSubPath(string childPath, string parentPath) + { + string fullParentPath = Path.GetFullPath(parentPath).Replace('\\', '/'); + string fullChildPath = Path.GetFullPath(childPath).Replace('\\', '/'); + return fullChildPath.StartsWith(fullParentPath, StringComparison.OrdinalIgnoreCase); + } + #if UNITY_EDITOR [HideInInspector] public static float libraryProgress = 1; From fa6a2afa0dc6e9ef80c9f4bdb2c6198fc5f240f9 Mon Sep 17 00:00:00 2001 From: Antonis Makropoulos Date: Wed, 24 Jul 2024 15:54:29 +0300 Subject: [PATCH 06/26] add custom url option --- Runtime/LLMUnitySetup.cs | 1 + 1 file changed, 1 insertion(+) diff --git a/Runtime/LLMUnitySetup.cs b/Runtime/LLMUnitySetup.cs index f74b12fe..a24cf3a2 100644 --- a/Runtime/LLMUnitySetup.cs +++ b/Runtime/LLMUnitySetup.cs @@ -87,6 +87,7 @@ public class LLMUnitySetup [HideInInspector] public static readonly (string, string)[] modelOptions = new(string, string)[] { ("Download model", null), + ("Custom URL", null), ("Mistral 7B Instruct v0.2 (medium, best overall)", "https://huggingface.co/TheBloke/Mistral-7B-Instruct-v0.2-GGUF/resolve/main/mistral-7b-instruct-v0.2.Q4_K_M.gguf?download=true"), ("OpenHermes 2.5 7B (medium, best for conversation)", "https://huggingface.co/TheBloke/OpenHermes-2.5-Mistral-7B-GGUF/resolve/main/openhermes-2.5-mistral-7b.Q4_K_M.gguf?download=true"), ("Phi 3 (small, great)", "https://huggingface.co/microsoft/Phi-3-mini-4k-instruct-gguf/resolve/main/Phi-3-mini-4k-instruct-q4.gguf?download=true"), From e65f14a09ef2dcd7cbbf46e940428dfd652fb4c9 Mon Sep 17 00:00:00 2001 From: Antonis Makropoulos Date: Wed, 24 Jul 2024 15:55:13 +0300 Subject: [PATCH 07/26] implement loras to model selection --- Editor/LLMEditor.cs | 223 ++++++++++++++++++++++++++++-------------- Runtime/LLM.cs | 2 +- Runtime/LLMManager.cs | 138 +++++++++++++++++++++----- 3 files changed, 263 insertions(+), 100 deletions(-) diff --git a/Editor/LLMEditor.cs b/Editor/LLMEditor.cs index 63f95455..7f4196c5 100644 --- a/Editor/LLMEditor.cs +++ b/Editor/LLMEditor.cs @@ -1,6 +1,7 @@ using System; using System.Collections.Generic; using System.Linq; +using System.Threading.Tasks; using UnityEditor; using UnityEditorInternal; using UnityEngine; @@ -21,6 +22,11 @@ public class LLMEditor : PropertyEditor static List modelOptions; static List modelURLs; string[] templateOptions; + string elementFocus = ""; + bool showCustomURL = false; + string customURL = ""; + bool customURLLora = false; + bool customURLFocus = false; protected override Type[] GetPropertyTypes() { @@ -31,7 +37,6 @@ public void AddModelLoadersSettings(SerializedObject llmScriptSO, LLM llmScript) { EditorGUILayout.LabelField("Model Settings", EditorStyles.boldLabel); AddModelLoaders(llmScriptSO, llmScript); - AddModelAddonLoaders(llmScriptSO, llmScript); AddModelSettings(llmScriptSO); } @@ -42,29 +47,13 @@ public void AddModelLoaders(SerializedObject llmScriptSO, LLM llmScript) foreach (float width in widths) listWidth += width + (listWidth == 0 ? 0 : elementPadding); EditorGUILayout.BeginVertical(GUILayout.Width(listWidth)); modelList.DoLayoutList(); - EditorGUILayout.EndVertical(); - } - - public void AddModelAddonLoaders(SerializedObject llmScriptSO, LLM llmScript, bool layout = true) - { - if (llmScriptSO.FindProperty("advancedOptions").boolValue) + bool downloadOnBuild = EditorGUILayout.Toggle("Download on Build", LLMManager.downloadOnBuild); + if (downloadOnBuild != LLMManager.downloadOnBuild) { - EditorGUILayout.BeginHorizontal(); - GUILayout.Label("Lora", GUILayout.Width(EditorGUIUtility.labelWidth)); - - if (GUILayout.Button("Load lora", GUILayout.Width(buttonWidth))) - { - EditorApplication.delayCall += () => - { - string path = EditorUtility.OpenFilePanelWithFilters("Select a bin lora file", "", new string[] { "Model Files", "bin" }); - if (!string.IsNullOrEmpty(path)) - { - llmScript.SetLora(path); - } - }; - } - EditorGUILayout.EndHorizontal(); + LLMManager.downloadOnBuild = downloadOnBuild; + LLMManager.Save(); } + EditorGUILayout.EndVertical(); } public void AddModelSettings(SerializedObject llmScriptSO) @@ -89,11 +78,10 @@ static void ResetModelOptions() foreach (ModelEntry entry in LLMManager.modelEntries) existingOptions.Add(entry.url); modelOptions = new List(); modelURLs = new List(); - for (int i = 0; i < LLMUnitySetup.modelOptions.Length; i++) + foreach ((string name, string url) in LLMUnitySetup.modelOptions) { - string url = LLMUnitySetup.modelOptions[i].Item2; - if (i > 0 && existingOptions.Contains(url)) continue; - modelOptions.Add(LLMUnitySetup.modelOptions[i].Item1); + if (url != null && existingOptions.Contains(url)) continue; + modelOptions.Add(name); modelURLs.Add(url); } } @@ -104,14 +92,14 @@ float[] GetColumnWidths() return widths; } - List CreateColumnRects(float x, float y) + List CreateColumnRects(Rect rect) { float[] widths = GetColumnWidths(); - float offset = x; + float offset = rect.x; List rects = new List(); foreach (float width in widths) { - rects.Add(new Rect(offset, y, width, EditorGUIUtility.singleLineHeight)); + rects.Add(new Rect(offset, rect.y, width, EditorGUIUtility.singleLineHeight)); offset += width + elementPadding; } return rects; @@ -124,20 +112,112 @@ void UpdateModels(bool resetOptions = false) Repaint(); } + void showCustomURLField(bool lora) + { + customURL = ""; + customURLLora = lora; + showCustomURL = true; + customURLFocus = true; + Repaint(); + } + + async Task createCustomURLField(Rect rect) + { + bool submit; + Event e = Event.current; + if (e.type == EventType.KeyDown && (e.keyCode == KeyCode.Return || e.keyCode == KeyCode.KeypadEnter)) + { + submit = true; + e.Use(); + } + else + { + Rect labelRect = new Rect(rect.x, rect.y, 100, EditorGUIUtility.singleLineHeight); + Rect textRect = new Rect(rect.x + labelRect.width + elementPadding, rect.y, buttonWidth, EditorGUIUtility.singleLineHeight); + Rect submitRect = new Rect(rect.x + labelRect.width + buttonWidth + elementPadding * 2 , rect.y, buttonWidth, EditorGUIUtility.singleLineHeight); + + EditorGUI.LabelField(labelRect, "Enter URL:"); + GUI.SetNextControlName("customURLFocus"); + customURL = EditorGUI.TextField(textRect, customURL); + submit = GUI.Button(submitRect, "Submit"); + + if (customURLFocus) + { + customURLFocus = false; + elementFocus = "customURLFocus"; + } + } + + if (submit) + { + showCustomURL = false; + elementFocus = "dummy"; + Repaint(); + await LLMManager.Download(customURL, customURLLora); + UpdateModels(true); + } + } + + async Task createButtons(Rect rect, LLM llmScript) + { + Rect downloadModelRect = new Rect(rect.x, rect.y, buttonWidth, EditorGUIUtility.singleLineHeight); + Rect loadModelRect = new Rect(rect.x + buttonWidth + elementPadding, rect.y, buttonWidth, EditorGUIUtility.singleLineHeight); + Rect downloadLoraRect = new Rect(rect.x + (buttonWidth + elementPadding) * 2, rect.y, buttonWidth, EditorGUIUtility.singleLineHeight); + Rect loadLoraRect = new Rect(rect.x + (buttonWidth + elementPadding) * 3, rect.y, buttonWidth, EditorGUIUtility.singleLineHeight); int modelIndex = EditorGUI.Popup(downloadModelRect, 0, modelOptions.ToArray()); + + if (modelIndex == 1) + { + showCustomURLField(false); + } + else if (modelIndex > 1) + { + await LLMManager.DownloadModel(modelURLs[modelIndex], modelOptions[modelIndex]); + UpdateModels(true); + } + + if (GUI.Button(loadModelRect, "Load model")) + { + EditorApplication.delayCall += () => + { + string path = EditorUtility.OpenFilePanelWithFilters("Select a gguf model file", "", new string[] { "Model Files", "gguf" }); + if (!string.IsNullOrEmpty(path)) + { + LLMManager.LoadModel(path); + UpdateModels(); + } + }; + } + + if (GUI.Button(downloadLoraRect, "Download LoRA")) + { + showCustomURLField(true); + } + if (GUI.Button(loadLoraRect, "Load LoRA")) + { + EditorApplication.delayCall += () => + { + string path = EditorUtility.OpenFilePanelWithFilters("Select a bin lora file", "", new string[] { "Model Files", "bin" }); + if (!string.IsNullOrEmpty(path)) + { + llmScript.SetLora(path); + } + }; + } + } + void OnEnable() { - var llmScript = (LLM)target; + LLM llmScript = (LLM)target; ResetModelOptions(); templateOptions = ChatTemplate.templatesDescription.Keys.ToList().ToArray(); trashIcon = new GUIContent(Resources.Load("llmunity_trash_icon"), "Delete Model"); - - modelList = new ReorderableList(LLMManager.modelEntries, typeof(ModelEntry), true, true, true, true) + modelList = new ReorderableList(LLMManager.modelEntries, typeof(ModelEntry), false, true, false, false) { drawElementCallback = (rect, index, isActive, isFocused) => { if (index >= LLMManager.modelEntries.Count) return; - List rects = CreateColumnRects(rect.x, rect.y); + List rects = CreateColumnRects(rect); var selectRect = rects[0]; var nameRect = rects[1]; var templateRect = rects[2]; @@ -150,23 +230,34 @@ void OnEnable() bool hasPath = entry.localPath != null && entry.localPath != ""; bool hasURL = entry.url != null && entry.url != ""; - bool isSelected = llmScript.model == entry.localPath; - bool newSelected = EditorGUI.Toggle(selectRect, isSelected, EditorStyles.radioButton); - if (newSelected && !isSelected) + bool isSelected = false; + if (!entry.lora) { - llmScript.model = entry.localPath; - llmScript.SetTemplate(entry.chatTemplate); + isSelected = llmScript.model == entry.localPath; + bool newSelected = EditorGUI.Toggle(selectRect, isSelected, EditorStyles.radioButton); + if (newSelected && !isSelected) llmScript.SetModel(entry.localPath); } + else + { + isSelected = llmScript.lora == entry.localPath; + bool newSelected = EditorGUI.Toggle(selectRect, isSelected, EditorStyles.radioButton); + if (newSelected && !isSelected) llmScript.SetLora(entry.localPath); + else if (!newSelected && isSelected) llmScript.SetLora(""); + } + DrawCopyableLabel(nameRect, entry.name); - int templateIndex = Array.IndexOf(ChatTemplate.templatesDescription.Values.ToList().ToArray(), entry.chatTemplate); - int newTemplateIndex = EditorGUI.Popup(templateRect, templateIndex, templateOptions); - if (newTemplateIndex != templateIndex) + if (!entry.lora) { - entry.chatTemplate = ChatTemplate.templatesDescription[templateOptions[newTemplateIndex]]; - if (isSelected) llmScript.SetTemplate(entry.chatTemplate); - UpdateModels(); + int templateIndex = Array.IndexOf(ChatTemplate.templatesDescription.Values.ToList().ToArray(), entry.chatTemplate); + int newTemplateIndex = EditorGUI.Popup(templateRect, templateIndex, templateOptions); + if (newTemplateIndex != templateIndex) + { + entry.chatTemplate = ChatTemplate.templatesDescription[templateOptions[newTemplateIndex]]; + if (isSelected) llmScript.SetTemplate(entry.chatTemplate); + UpdateModels(); + } } if (hasURL) @@ -193,13 +284,13 @@ void OnEnable() if (GUI.Button(actionRect, trashIcon)) { - LLMManager.modelEntries.Remove(entry); + LLMManager.Remove(entry); UpdateModels(true); } }, drawHeaderCallback = (rect) => { - List rects = CreateColumnRects(rect.x + ReorderableList.Defaults.dragHandleWidth - ReorderableList.Defaults.padding + 1, rect.y); + List rects = CreateColumnRects(rect); EditorGUI.LabelField(rects[0], ""); EditorGUI.LabelField(rects[1], "Model"); EditorGUI.LabelField(rects[2], "Chat template"); @@ -210,35 +301,8 @@ void OnEnable() }, drawFooterCallback = async(rect) => { - Rect downloadRect = new Rect(rect.x, rect.y, buttonWidth, EditorGUIUtility.singleLineHeight); - Rect loadRect = new Rect(rect.x + buttonWidth + elementPadding, rect.y, buttonWidth, EditorGUIUtility.singleLineHeight); - - int newIndex = EditorGUI.Popup(downloadRect, 0, modelOptions.ToArray()); - if (newIndex != 0) - { - await LLMManager.DownloadModel(modelURLs[newIndex], modelOptions[newIndex]); - UpdateModels(true); - } - - if (GUI.Button(loadRect, "Load model")) - { - EditorApplication.delayCall += () => - { - string path = EditorUtility.OpenFilePanelWithFilters("Select a gguf model file", "", new string[] { "Model Files", "gguf" }); - if (!string.IsNullOrEmpty(path)) - { - LLMManager.LoadModel(path); - UpdateModels(); - } - }; - } - - bool downloadOnBuild = EditorGUILayout.Toggle("Download on Build", LLMManager.downloadOnBuild); - if (downloadOnBuild != LLMManager.downloadOnBuild) - { - LLMManager.downloadOnBuild = downloadOnBuild; - UpdateModels(); - } + if (showCustomURL) await createCustomURLField(rect); + else await createButtons(rect, llmScript); } }; } @@ -267,6 +331,12 @@ private void CopyToClipboard(string text) public override void OnInspectorGUI() { + if (elementFocus != "") + { + EditorGUI.FocusTextInControl(elementFocus); + elementFocus = ""; + } + LLM llmScript = (LLM)target; SerializedObject llmScriptSO = new SerializedObject(llmScript); @@ -274,7 +344,8 @@ public override void OnInspectorGUI() ShowProgress(LLMUnitySetup.libraryProgress, "Setup Library"); ShowProgress(LLMManager.modelProgress, "Model Downloading"); - GUI.enabled = LLMUnitySetup.libraryProgress == 1 && LLMManager.modelProgress == 1; + ShowProgress(LLMManager.loraProgress, "LoRA Downloading"); + GUI.enabled = LLMUnitySetup.libraryProgress == 1 && LLMManager.modelProgress == 1 && LLMManager.loraProgress == 1; AddOptionsToggles(llmScriptSO); AddSetupSettings(llmScriptSO); diff --git a/Runtime/LLM.cs b/Runtime/LLM.cs index 6af7ea3b..a68f338d 100644 --- a/Runtime/LLM.cs +++ b/Runtime/LLM.cs @@ -115,7 +115,7 @@ public void SetModel(string path) { // set the model and enable the model editor properties #if UNITY_EDITOR - ModelEntry entry = LLMManager.LoadModel(path); + ModelEntry entry = LLMManager.Get(LLMManager.LoadModel(path)); model = entry.localPath; SetTemplate(entry.chatTemplate); if (!EditorApplication.isPlaying) EditorUtility.SetDirty(this); diff --git a/Runtime/LLMManager.cs b/Runtime/LLMManager.cs index d18cebfe..1e928ded 100644 --- a/Runtime/LLMManager.cs +++ b/Runtime/LLMManager.cs @@ -2,6 +2,7 @@ using System; using System.Collections.Generic; using System.IO; +using System.Linq; using System.Threading.Tasks; using UnityEditor; using UnityEngine; @@ -12,6 +13,7 @@ namespace LLMUnity public class ModelEntry { public string name; + public bool lora; public string chatTemplate; public string url; public string localPath; @@ -36,7 +38,7 @@ public class LLMManager static List> loraProgressCallbacks = new List>(); [HideInInspector] public static float modelProgress = 1; - // [HideInInspector] public static float loraProgress = 1; + [HideInInspector] public static float loraProgress = 1; [InitializeOnLoadMethod] static void InitializeOnLoad() @@ -49,21 +51,29 @@ public static string ModelPathToName(string path) return Path.GetFileNameWithoutExtension(path.Split("?")[0]); } - public static ModelEntry CreateEntry(string path, string url = null, string name = null) + public static string AddEntry(string path, bool lora = false, string name = null, string url = null) { + string key = name == null ? ModelPathToName(url) : name; ModelEntry entry = new ModelEntry(); - entry.name = name == null ? ModelPathToName(url) : name; - entry.chatTemplate = ChatTemplate.FromGGUF(path); + entry.name = key; + entry.lora = lora; + entry.chatTemplate = lora ? null : ChatTemplate.FromGGUF(path); entry.url = url; entry.localPath = Path.GetFullPath(path).Replace('\\', '/'); - return entry; - } - - public static ModelEntry AddEntry(string path, string url = null, string name = null) - { - ModelEntry entry = CreateEntry(path, url, name); - modelEntries.Add(entry); - return entry; + int indexToInsert = modelEntries.Count; + if (!lora) + { + for (int i = modelEntries.Count - 1; i >= 0; i--) + { + if (!modelEntries[i].lora) + { + indexToInsert = i + 1; + break; + } + } + } + modelEntries.Insert(indexToInsert, entry); + return key; } public static async Task WaitUntilModelsDownloaded(Callback modelProgressCallback = null, Callback loraProgressCallback = null) @@ -75,27 +85,103 @@ public static async Task WaitUntilModelsDownloaded(Callback modelProgress if (loraProgressCallback != null) loraProgressCallbacks.Remove(loraProgressCallback); } - public static async Task DownloadModel(string url, string name = null) + public static async Task Download(string url, bool lora = false, string name = null) { - foreach (ModelEntry modelEntry in modelEntries) + foreach (ModelEntry entry in modelEntries) { - if (modelEntry.url == url) return modelEntry; + if (entry.url == url) return entry.name; } string modelName = Path.GetFileName(url).Split("?")[0]; string modelPath = Path.Combine(LLMUnitySetup.modelDownloadPath, modelName); - modelProgress = 0; - await LLMUnitySetup.DownloadFile(url, modelPath, false, null, SetModelProgress); - return AddEntry(modelPath, url, name); + if (!lora) + { + modelProgress = 0; + try + { + await LLMUnitySetup.DownloadFile(url, modelPath, false, null, SetModelProgress); + } + catch (Exception ex) + { + modelProgress = 1; + throw ex; + } + } + else + { + loraProgress = 0; + try + { + await LLMUnitySetup.DownloadFile(url, modelPath, false, null, SetLoraProgress); + } + catch (Exception ex) + { + loraProgress = 1; + throw ex; + } + } + return AddEntry(modelPath, lora, name, url); } - public static ModelEntry LoadModel(string path) + public static string Load(string path, bool lora = false, string name = null) { string fullPath = Path.GetFullPath(path).Replace('\\', '/'); - foreach (ModelEntry modelEntry in modelEntries) + foreach (ModelEntry entry in modelEntries) + { + if (entry.localPath == fullPath) return entry.name; + } + return AddEntry(path, lora, name); + } + + public static async Task DownloadModel(string url, string name = null) + { + return await Download(url, false, name); + } + + public static async Task DownloadLora(string url, string name = null) + { + return await Download(url, true, name); + } + + public static string LoadModel(string url, string name = null) + { + return Load(url, false, name); + } + + public static string LoadLora(string url, string name = null) + { + return Load(url, true, name); + } + + public static void SetModelTemplate(string name, string chatTemplate) + { + foreach (ModelEntry entry in modelEntries) { - if (modelEntry.localPath == fullPath) return modelEntry; + if (entry.name == name) + { + entry.chatTemplate = chatTemplate; + break; + } } - return AddEntry(path); + } + + public static ModelEntry Get(string name) + { + foreach (ModelEntry entry in modelEntries) + { + if (entry.name == name) return entry; + } + return null; + } + + public static void Remove(string name) + { + Remove(Get(name)); + } + + public static void Remove(ModelEntry entry) + { + if (entry == null) return; + modelEntries.Remove(entry); } public static void SetModelProgress(float progress) @@ -104,6 +190,12 @@ public static void SetModelProgress(float progress) foreach (Callback modelProgressCallback in modelProgressCallbacks) modelProgressCallback?.Invoke(progress); } + public static void SetLoraProgress(float progress) + { + loraProgress = progress; + foreach (Callback loraProgressCallback in loraProgressCallbacks) loraProgressCallback?.Invoke(progress); + } + public static void Save() { Directory.CreateDirectory(Path.GetDirectoryName(LLMUnitySetup.modelListPath)); @@ -114,8 +206,8 @@ public static void Load() { if (!File.Exists(LLMUnitySetup.modelListPath)) return; LLMManagerStore store = JsonUtility.FromJson(File.ReadAllText(LLMUnitySetup.modelListPath)); - modelEntries = store.modelEntries; downloadOnBuild = store.downloadOnBuild; + modelEntries = store.modelEntries; } } } From 5f9fe9816508be706cf2fca4a1eb5ae3f09369c9 Mon Sep 17 00:00:00 2001 From: Antonis Makropoulos Date: Wed, 24 Jul 2024 16:56:31 +0300 Subject: [PATCH 08/26] json and button beautification --- Editor/LLMEditor.cs | 5 +++-- Runtime/LLMManager.cs | 2 +- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/Editor/LLMEditor.cs b/Editor/LLMEditor.cs index 7f4196c5..f1febad4 100644 --- a/Editor/LLMEditor.cs +++ b/Editor/LLMEditor.cs @@ -162,9 +162,10 @@ async Task createButtons(Rect rect, LLM llmScript) { Rect downloadModelRect = new Rect(rect.x, rect.y, buttonWidth, EditorGUIUtility.singleLineHeight); Rect loadModelRect = new Rect(rect.x + buttonWidth + elementPadding, rect.y, buttonWidth, EditorGUIUtility.singleLineHeight); - Rect downloadLoraRect = new Rect(rect.x + (buttonWidth + elementPadding) * 2, rect.y, buttonWidth, EditorGUIUtility.singleLineHeight); - Rect loadLoraRect = new Rect(rect.x + (buttonWidth + elementPadding) * 3, rect.y, buttonWidth, EditorGUIUtility.singleLineHeight); int modelIndex = EditorGUI.Popup(downloadModelRect, 0, modelOptions.ToArray()); + Rect downloadLoraRect = new Rect(rect.width - 2 * buttonWidth - elementPadding, rect.y, buttonWidth, EditorGUIUtility.singleLineHeight); + Rect loadLoraRect = new Rect(rect.width - buttonWidth, rect.y, buttonWidth, EditorGUIUtility.singleLineHeight); + int modelIndex = EditorGUI.Popup(downloadModelRect, 0, modelOptions.ToArray()); if (modelIndex == 1) { showCustomURLField(false); diff --git a/Runtime/LLMManager.cs b/Runtime/LLMManager.cs index 1e928ded..073f03f1 100644 --- a/Runtime/LLMManager.cs +++ b/Runtime/LLMManager.cs @@ -199,7 +199,7 @@ public static void SetLoraProgress(float progress) public static void Save() { Directory.CreateDirectory(Path.GetDirectoryName(LLMUnitySetup.modelListPath)); - File.WriteAllText(LLMUnitySetup.modelListPath, JsonUtility.ToJson(new LLMManagerStore { modelEntries = modelEntries, downloadOnBuild = downloadOnBuild })); + File.WriteAllText(LLMUnitySetup.modelListPath, JsonUtility.ToJson(new LLMManagerStore { modelEntries = modelEntries, downloadOnBuild = downloadOnBuild }, true)); } public static void Load() From 8f0162828a0a854be514f8ee355f6be52888da8c Mon Sep 17 00:00:00 2001 From: Antonis Makropoulos Date: Wed, 24 Jul 2024 20:26:08 +0300 Subject: [PATCH 09/26] lora as argument --- Editor/LLMEditor.cs | 22 ++++++++++++++++------ Runtime/LLM.cs | 7 +------ 2 files changed, 17 insertions(+), 12 deletions(-) diff --git a/Editor/LLMEditor.cs b/Editor/LLMEditor.cs index f1febad4..9873968e 100644 --- a/Editor/LLMEditor.cs +++ b/Editor/LLMEditor.cs @@ -95,12 +95,13 @@ float[] GetColumnWidths() List CreateColumnRects(Rect rect) { float[] widths = GetColumnWidths(); - float offset = rect.x; + float offsetX = rect.x; + float offsetY = rect.y + (rect.height - EditorGUIUtility.singleLineHeight) / 2; List rects = new List(); foreach (float width in widths) { - rects.Add(new Rect(offset, rect.y, width, EditorGUIUtility.singleLineHeight)); - offset += width + elementPadding; + rects.Add(new Rect(offsetX, offsetY, width, EditorGUIUtility.singleLineHeight)); + offsetX += width + elementPadding; } return rects; } @@ -162,8 +163,8 @@ async Task createButtons(Rect rect, LLM llmScript) { Rect downloadModelRect = new Rect(rect.x, rect.y, buttonWidth, EditorGUIUtility.singleLineHeight); Rect loadModelRect = new Rect(rect.x + buttonWidth + elementPadding, rect.y, buttonWidth, EditorGUIUtility.singleLineHeight); - Rect downloadLoraRect = new Rect(rect.width - 2 * buttonWidth - elementPadding, rect.y, buttonWidth, EditorGUIUtility.singleLineHeight); - Rect loadLoraRect = new Rect(rect.width - buttonWidth, rect.y, buttonWidth, EditorGUIUtility.singleLineHeight); + Rect downloadLoraRect = new Rect(rect.xMax - 2 * buttonWidth - elementPadding, rect.y, buttonWidth, EditorGUIUtility.singleLineHeight); + Rect loadLoraRect = new Rect(rect.xMax - buttonWidth, rect.y, buttonWidth, EditorGUIUtility.singleLineHeight); int modelIndex = EditorGUI.Popup(downloadModelRect, 0, modelOptions.ToArray()); if (modelIndex == 1) @@ -212,11 +213,16 @@ void OnEnable() ResetModelOptions(); templateOptions = ChatTemplate.templatesDescription.Keys.ToList().ToArray(); trashIcon = new GUIContent(Resources.Load("llmunity_trash_icon"), "Delete Model"); + Texture2D loraLineTexture = new Texture2D(1, 1); + loraLineTexture.SetPixel(0, 0, Color.black); + loraLineTexture.Apply(); + modelList = new ReorderableList(LLMManager.modelEntries, typeof(ModelEntry), false, true, false, false) { drawElementCallback = (rect, index, isActive, isFocused) => { if (index >= LLMManager.modelEntries.Count) return; + var entry = LLMManager.modelEntries[index]; List rects = CreateColumnRects(rect); var selectRect = rects[0]; @@ -226,7 +232,6 @@ void OnEnable() var pathRect = rects[4]; var includeInBuildRect = rects[5]; var actionRect = rects[6]; - var entry = LLMManager.modelEntries[index]; bool hasPath = entry.localPath != null && entry.localPath != ""; bool hasURL = entry.url != null && entry.url != ""; @@ -288,6 +293,11 @@ void OnEnable() LLMManager.Remove(entry); UpdateModels(true); } + + if (!entry.lora && index < LLMManager.modelEntries.Count - 1 && LLMManager.modelEntries[index + 1].lora) + { + GUI.DrawTexture(new Rect(rect.x - ReorderableList.Defaults.padding, rect.yMax, rect.width + ReorderableList.Defaults.padding * 2, 1), loraLineTexture); + } }, drawHeaderCallback = (rect) => { diff --git a/Runtime/LLM.cs b/Runtime/LLM.cs index a68f338d..87d48fa2 100644 --- a/Runtime/LLM.cs +++ b/Runtime/LLM.cs @@ -68,12 +68,6 @@ public class LLM : MonoBehaviour [LLMAdvanced] public bool asynchronousStartup = true; /// select to not destroy the LLM GameObject when loading a new Scene. [LLMAdvanced] public bool dontDestroyOnLoad = true; - /// the path of the LORA model being used (relative to the Assets/StreamingAssets folder). - /// Models with .bin format are allowed. - [ModelAdvanced] public string lora = ""; - /// the URL of the LORA to use. - /// Models with .bin format are allowed. - [ModelDownloadAdvanced] public string loraURL = ""; /// Size of the prompt context (0 = context size of the model). /// This is the number of tokens the model can take as input when generating responses. [ModelAdvanced] public int contextSize = 0; @@ -88,6 +82,7 @@ public class LLM : MonoBehaviour /// \cond HIDE public LLMManager llmManager = new LLMManager(); + public string lora = ""; public string model = ""; public string chatTemplate = ChatTemplate.DefaultTemplate; From b333ef5040f045dfb7fa676ed34931103a48d41c Mon Sep 17 00:00:00 2001 From: Antonis Makropoulos Date: Wed, 24 Jul 2024 20:26:49 +0300 Subject: [PATCH 10/26] UI improvements --- Editor/LLMEditor.cs | 184 ++++++++++++++++++++++++++---------------- Runtime/LLMManager.cs | 42 +++++----- 2 files changed, 132 insertions(+), 94 deletions(-) diff --git a/Editor/LLMEditor.cs b/Editor/LLMEditor.cs index 9873968e..3a160d70 100644 --- a/Editor/LLMEditor.cs +++ b/Editor/LLMEditor.cs @@ -13,9 +13,9 @@ public class LLMEditor : PropertyEditor { private ReorderableList modelList; static float nameColumnWidth = 150f; - static float templateColumnWidth = 100f; + static float templateColumnWidth = 150f; static float textColumnWidth = 150f; - static float includeInBuildColumnWidth = 50f; + static float includeInBuildColumnWidth = 30f; static float actionColumnWidth = 20f; static int elementPadding = 10; static GUIContent trashIcon; @@ -42,18 +42,25 @@ public void AddModelLoadersSettings(SerializedObject llmScriptSO, LLM llmScript) public void AddModelLoaders(SerializedObject llmScriptSO, LLM llmScript) { - float[] widths = GetColumnWidths(); - float listWidth = ReorderableList.Defaults.dragHandleWidth; - foreach (float width in widths) listWidth += width + (listWidth == 0 ? 0 : elementPadding); - EditorGUILayout.BeginVertical(GUILayout.Width(listWidth)); - modelList.DoLayoutList(); - bool downloadOnBuild = EditorGUILayout.Toggle("Download on Build", LLMManager.downloadOnBuild); - if (downloadOnBuild != LLMManager.downloadOnBuild) + if (LLMManager.modelEntries.Count == 0) { - LLMManager.downloadOnBuild = downloadOnBuild; + DrawFooter(EditorGUILayout.GetControlRect()); + } + else + { + float[] widths = GetColumnWidths(llmScript.advancedOptions); + float listWidth = 2 * ReorderableList.Defaults.padding * 2; + foreach (float width in widths) listWidth += width + (listWidth == 0 ? 0 : elementPadding); + EditorGUILayout.BeginVertical(GUILayout.Width(listWidth)); + modelList.DoLayoutList(); + EditorGUILayout.EndVertical(); + } + bool downloadOnStart = EditorGUILayout.Toggle("Download on Start", LLMManager.downloadOnStart); + if (downloadOnStart != LLMManager.downloadOnStart) + { + LLMManager.downloadOnStart = downloadOnStart; LLMManager.Save(); } - EditorGUILayout.EndVertical(); } public void AddModelSettings(SerializedObject llmScriptSO) @@ -86,15 +93,17 @@ static void ResetModelOptions() } } - float[] GetColumnWidths() + float[] GetColumnWidths(bool expandedView) { - float[] widths = new float[] {actionColumnWidth, nameColumnWidth, templateColumnWidth, textColumnWidth, textColumnWidth, includeInBuildColumnWidth, actionColumnWidth}; - return widths; + List widths = new List(){actionColumnWidth, nameColumnWidth, templateColumnWidth}; + if (expandedView) widths.AddRange(new List(){textColumnWidth, textColumnWidth}); + widths.AddRange(new List(){includeInBuildColumnWidth, actionColumnWidth}); + return widths.ToArray(); } - List CreateColumnRects(Rect rect) + List CreateColumnRects(Rect rect, bool expandedView) { - float[] widths = GetColumnWidths(); + float[] widths = GetColumnWidths(expandedView); float offsetX = rect.x; float offsetY = rect.y + (rect.height - EditorGUIUtility.singleLineHeight) / 2; List rects = new List(); @@ -122,25 +131,39 @@ void showCustomURLField(bool lora) Repaint(); } + void SetModelIfNone() + { + LLM llmScript = (LLM)target; + if (llmScript.model == "" && LLMManager.modelEntries.Count == 1) llmScript.SetModel(LLMManager.modelEntries[0].localPath); + } + async Task createCustomURLField(Rect rect) { - bool submit; + bool submit = false; + bool exit = false; Event e = Event.current; if (e.type == EventType.KeyDown && (e.keyCode == KeyCode.Return || e.keyCode == KeyCode.KeypadEnter)) { submit = true; e.Use(); } + else if (e.type == EventType.KeyDown && (e.keyCode == KeyCode.Escape)) + { + exit = true; + e.Use(); + } else { Rect labelRect = new Rect(rect.x, rect.y, 100, EditorGUIUtility.singleLineHeight); Rect textRect = new Rect(rect.x + labelRect.width + elementPadding, rect.y, buttonWidth, EditorGUIUtility.singleLineHeight); - Rect submitRect = new Rect(rect.x + labelRect.width + buttonWidth + elementPadding * 2 , rect.y, buttonWidth, EditorGUIUtility.singleLineHeight); + Rect submitRect = new Rect(rect.x + labelRect.width + buttonWidth + elementPadding * 2, rect.y, buttonWidth / 2f, EditorGUIUtility.singleLineHeight); + Rect backRect = new Rect(rect.x + labelRect.width + buttonWidth * 1.5f + elementPadding * 3, rect.y, buttonWidth / 2f, EditorGUIUtility.singleLineHeight); EditorGUI.LabelField(labelRect, "Enter URL:"); GUI.SetNextControlName("customURLFocus"); customURL = EditorGUI.TextField(textRect, customURL); submit = GUI.Button(submitRect, "Submit"); + exit = GUI.Button(backRect, "Back"); if (customURLFocus) { @@ -149,13 +172,17 @@ async Task createCustomURLField(Rect rect) } } - if (submit) + if (exit || submit) { showCustomURL = false; elementFocus = "dummy"; Repaint(); - await LLMManager.Download(customURL, customURLLora); - UpdateModels(true); + if (submit && customURL != "") + { + await LLMManager.Download(customURL, customURLLora); + SetModelIfNone(); + UpdateModels(true); + } } } @@ -165,7 +192,6 @@ async Task createButtons(Rect rect, LLM llmScript) Rect loadModelRect = new Rect(rect.x + buttonWidth + elementPadding, rect.y, buttonWidth, EditorGUIUtility.singleLineHeight); Rect downloadLoraRect = new Rect(rect.xMax - 2 * buttonWidth - elementPadding, rect.y, buttonWidth, EditorGUIUtility.singleLineHeight); Rect loadLoraRect = new Rect(rect.xMax - buttonWidth, rect.y, buttonWidth, EditorGUIUtility.singleLineHeight); - int modelIndex = EditorGUI.Popup(downloadModelRect, 0, modelOptions.ToArray()); if (modelIndex == 1) { @@ -174,6 +200,7 @@ async Task createButtons(Rect rect, LLM llmScript) else if (modelIndex > 1) { await LLMManager.DownloadModel(modelURLs[modelIndex], modelOptions[modelIndex]); + SetModelIfNone(); UpdateModels(true); } @@ -190,23 +217,33 @@ async Task createButtons(Rect rect, LLM llmScript) }; } - if (GUI.Button(downloadLoraRect, "Download LoRA")) - { - showCustomURLField(true); - } - if (GUI.Button(loadLoraRect, "Load LoRA")) + if (llmScript.advancedOptions) { - EditorApplication.delayCall += () => + if (GUI.Button(downloadLoraRect, "Download LoRA")) { - string path = EditorUtility.OpenFilePanelWithFilters("Select a bin lora file", "", new string[] { "Model Files", "bin" }); - if (!string.IsNullOrEmpty(path)) + showCustomURLField(true); + } + if (GUI.Button(loadLoraRect, "Load LoRA")) + { + EditorApplication.delayCall += () => { - llmScript.SetLora(path); - } - }; + string path = EditorUtility.OpenFilePanelWithFilters("Select a bin lora file", "", new string[] { "Model Files", "bin" }); + if (!string.IsNullOrEmpty(path)) + { + llmScript.SetLora(path); + } + }; + } } } + async void DrawFooter(Rect rect) + { + LLM llmScript = (LLM)target; + if (showCustomURL) await createCustomURLField(rect); + else await createButtons(rect, llmScript); + } + void OnEnable() { LLM llmScript = (LLM)target; @@ -222,16 +259,22 @@ void OnEnable() drawElementCallback = (rect, index, isActive, isFocused) => { if (index >= LLMManager.modelEntries.Count) return; - var entry = LLMManager.modelEntries[index]; - - List rects = CreateColumnRects(rect); - var selectRect = rects[0]; - var nameRect = rects[1]; - var templateRect = rects[2]; - var urlRect = rects[3]; - var pathRect = rects[4]; - var includeInBuildRect = rects[5]; - var actionRect = rects[6]; + ModelEntry entry = LLMManager.modelEntries[index]; + + List rects = CreateColumnRects(rect, llmScript.advancedOptions); + int col = 0; + Rect selectRect = rects[col++]; + Rect nameRect = rects[col++]; + Rect templateRect = rects[col++]; + Rect urlRect = new Rect(); + Rect pathRect = new Rect(); + if (llmScript.advancedOptions) + { + urlRect = rects[col++]; + pathRect = rects[col++]; + } + Rect includeInBuildRect = rects[col++]; + Rect actionRect = rects[col++]; bool hasPath = entry.localPath != null && entry.localPath != ""; bool hasURL = entry.url != null && entry.url != ""; @@ -251,7 +294,6 @@ void OnEnable() else if (!newSelected && isSelected) llmScript.SetLora(""); } - DrawCopyableLabel(nameRect, entry.name); if (!entry.lora) @@ -266,20 +308,23 @@ void OnEnable() } } - if (hasURL) + if (llmScript.advancedOptions) { - DrawCopyableLabel(urlRect, entry.url); - } - else - { - string newURL = EditorGUI.TextField(urlRect, entry.url); - if (newURL != entry.url) + if (hasURL) { - entry.url = newURL; - UpdateModels(); + DrawCopyableLabel(urlRect, entry.url); } + else + { + string newURL = EditorGUI.TextField(urlRect, entry.url); + if (newURL != entry.url) + { + entry.url = newURL; + UpdateModels(); + } + } + DrawCopyableLabel(pathRect, entry.localPath); } - DrawCopyableLabel(pathRect, entry.localPath); bool includeInBuild = EditorGUI.ToggleLeft(includeInBuildRect, "", entry.includeInBuild); if (includeInBuild != entry.includeInBuild) @@ -301,20 +346,20 @@ void OnEnable() }, drawHeaderCallback = (rect) => { - List rects = CreateColumnRects(rect); - EditorGUI.LabelField(rects[0], ""); - EditorGUI.LabelField(rects[1], "Model"); - EditorGUI.LabelField(rects[2], "Chat template"); - EditorGUI.LabelField(rects[3], "URL"); - EditorGUI.LabelField(rects[4], "Path"); - EditorGUI.LabelField(rects[5], "Build"); - EditorGUI.LabelField(rects[6], ""); + List rects = CreateColumnRects(rect, llmScript.advancedOptions); + int col = 0; + EditorGUI.LabelField(rects[col++], ""); + EditorGUI.LabelField(rects[col++], "Model"); + EditorGUI.LabelField(rects[col++], "Chat template"); + if (llmScript.advancedOptions) + { + EditorGUI.LabelField(rects[col++], "URL"); + EditorGUI.LabelField(rects[col++], "Path"); + } + EditorGUI.LabelField(rects[col++], "Build"); + EditorGUI.LabelField(rects[col++], ""); }, - drawFooterCallback = async(rect) => - { - if (showCustomURL) await createCustomURLField(rect); - else await createButtons(rect, llmScript); - } + drawFooterCallback = DrawFooter, }; } @@ -332,10 +377,7 @@ private void DrawCopyableLabel(Rect rect, string text) private void CopyToClipboard(string text) { - TextEditor te = new TextEditor - { - text = text - }; + TextEditor te = new TextEditor {text = text}; te.SelectAll(); te.Copy(); } diff --git a/Runtime/LLMManager.cs b/Runtime/LLMManager.cs index 073f03f1..d1469ab2 100644 --- a/Runtime/LLMManager.cs +++ b/Runtime/LLMManager.cs @@ -2,7 +2,6 @@ using System; using System.Collections.Generic; using System.IO; -using System.Linq; using System.Threading.Tasks; using UnityEditor; using UnityEngine; @@ -23,13 +22,13 @@ public class ModelEntry [Serializable] public class LLMManagerStore { - public bool downloadOnBuild; + public bool downloadOnStart; public List modelEntries; } public class LLMManager { - public static bool downloadOnBuild = false; + public static bool downloadOnStart = false; public static List modelEntries = new List(); /// Boolean set to true if the server has started and is ready to receive requests, false otherwise. @@ -53,13 +52,14 @@ public static string ModelPathToName(string path) public static string AddEntry(string path, bool lora = false, string name = null, string url = null) { - string key = name == null ? ModelPathToName(url) : name; + string key = name == null ? ModelPathToName(path) : name; ModelEntry entry = new ModelEntry(); entry.name = key; entry.lora = lora; entry.chatTemplate = lora ? null : ChatTemplate.FromGGUF(path); entry.url = url; entry.localPath = Path.GetFullPath(path).Replace('\\', '/'); + entry.includeInBuild = true; int indexToInsert = modelEntries.Count; if (!lora) { @@ -93,31 +93,27 @@ public static async Task Download(string url, bool lora = false, string } string modelName = Path.GetFileName(url).Split("?")[0]; string modelPath = Path.Combine(LLMUnitySetup.modelDownloadPath, modelName); - if (!lora) + float preModelProgress = modelProgress; + float preLoraProgress = loraProgress; + try { - modelProgress = 0; - try + if (!lora) { + modelProgress = 0; await LLMUnitySetup.DownloadFile(url, modelPath, false, null, SetModelProgress); } - catch (Exception ex) + else { - modelProgress = 1; - throw ex; + loraProgress = 0; + await LLMUnitySetup.DownloadFile(url, modelPath, false, null, SetLoraProgress); } } - else + catch (Exception ex) { - loraProgress = 0; - try - { - await LLMUnitySetup.DownloadFile(url, modelPath, false, null, SetLoraProgress); - } - catch (Exception ex) - { - loraProgress = 1; - throw ex; - } + modelProgress = preModelProgress; + loraProgress = preLoraProgress; + LLMUnitySetup.LogError($"Error downloading the model from URL '{url}': " + ex.Message); + return null; } return AddEntry(modelPath, lora, name, url); } @@ -199,14 +195,14 @@ public static void SetLoraProgress(float progress) public static void Save() { Directory.CreateDirectory(Path.GetDirectoryName(LLMUnitySetup.modelListPath)); - File.WriteAllText(LLMUnitySetup.modelListPath, JsonUtility.ToJson(new LLMManagerStore { modelEntries = modelEntries, downloadOnBuild = downloadOnBuild }, true)); + File.WriteAllText(LLMUnitySetup.modelListPath, JsonUtility.ToJson(new LLMManagerStore { modelEntries = modelEntries, downloadOnStart = downloadOnStart }, true)); } public static void Load() { if (!File.Exists(LLMUnitySetup.modelListPath)) return; LLMManagerStore store = JsonUtility.FromJson(File.ReadAllText(LLMUnitySetup.modelListPath)); - downloadOnBuild = store.downloadOnBuild; + downloadOnStart = store.downloadOnStart; modelEntries = store.modelEntries; } } From 5661ac7919d20d21c51d0bdedbaeb719994a0e5f Mon Sep 17 00:00:00 2001 From: Antonis Makropoulos Date: Thu, 25 Jul 2024 14:39:49 +0300 Subject: [PATCH 11/26] add label field, register LLM to LLMManager and remove model if not there anymore --- Runtime/LLM.cs | 30 ++++++++++---- Runtime/LLMManager.cs | 96 ++++++++++++++++++++++++++++--------------- 2 files changed, 86 insertions(+), 40 deletions(-) diff --git a/Runtime/LLM.cs b/Runtime/LLM.cs index 87d48fa2..e2ce0b0f 100644 --- a/Runtime/LLM.cs +++ b/Runtime/LLM.cs @@ -80,11 +80,17 @@ public class LLM : MonoBehaviour /// Boolean set to true if the server has failed to start. public bool failed { get; protected set; } = false; - /// \cond HIDE - public LLMManager llmManager = new LLMManager(); - public string lora = ""; + /// the LLM model to use. + /// Models with .gguf format are allowed. public string model = ""; + /// Chat template used for the model public string chatTemplate = ChatTemplate.DefaultTemplate; + /// the path of the LORA model being used (relative to the Assets/StreamingAssets folder). + /// Models with .bin format are allowed. + public string lora = ""; + + /// \cond HIDE + public LLMManager llmManager = new LLMManager(); IntPtr LLMObject = IntPtr.Zero; List clients = new List(); @@ -95,6 +101,14 @@ public class LLM : MonoBehaviour /// \endcond +#if UNITY_EDITOR + public LLM() + { + LLMManager.Register(this); + } + +#endif + public async Task WaitUntilReady() { while (!started) await Task.Yield(); @@ -109,13 +123,10 @@ public async Task WaitUntilReady() public void SetModel(string path) { // set the model and enable the model editor properties + model = path; #if UNITY_EDITOR - ModelEntry entry = LLMManager.Get(LLMManager.LoadModel(path)); - model = entry.localPath; - SetTemplate(entry.chatTemplate); + SetTemplate(LLMManager.Get(path).chatTemplate); if (!EditorApplication.isPlaying) EditorUtility.SetDirty(this); -#else - model = path; #endif } @@ -486,6 +497,9 @@ public void Destroy() public void OnDestroy() { Destroy(); +#if UNITY_EDITOR + LLMManager.Unregister(this); +#endif } } } diff --git a/Runtime/LLMManager.cs b/Runtime/LLMManager.cs index d1469ab2..04e08d8b 100644 --- a/Runtime/LLMManager.cs +++ b/Runtime/LLMManager.cs @@ -11,11 +11,12 @@ namespace LLMUnity [Serializable] public class ModelEntry { - public string name; + public string label; + public string filename; + public string path; public bool lora; public string chatTemplate; public string url; - public string localPath; public bool includeInBuild; } @@ -38,6 +39,7 @@ public class LLMManager [HideInInspector] public static float modelProgress = 1; [HideInInspector] public static float loraProgress = 1; + static List llms = new List(); [InitializeOnLoadMethod] static void InitializeOnLoad() @@ -45,20 +47,15 @@ static void InitializeOnLoad() Load(); } - public static string ModelPathToName(string path) + public static string AddEntry(string path, bool lora = false, string label = null, string url = null) { - return Path.GetFileNameWithoutExtension(path.Split("?")[0]); - } - - public static string AddEntry(string path, bool lora = false, string name = null, string url = null) - { - string key = name == null ? ModelPathToName(path) : name; ModelEntry entry = new ModelEntry(); - entry.name = key; + entry.filename = Path.GetFileName(path.Split("?")[0]); + entry.label = label == null ? entry.filename : label; entry.lora = lora; entry.chatTemplate = lora ? null : ChatTemplate.FromGGUF(path); entry.url = url; - entry.localPath = Path.GetFullPath(path).Replace('\\', '/'); + entry.path = Path.GetFullPath(path).Replace('\\', '/'); entry.includeInBuild = true; int indexToInsert = modelEntries.Count; if (!lora) @@ -73,7 +70,7 @@ public static string AddEntry(string path, bool lora = false, string name = null } } modelEntries.Insert(indexToInsert, entry); - return key; + return entry.filename; } public static async Task WaitUntilModelsDownloaded(Callback modelProgressCallback = null, Callback loraProgressCallback = null) @@ -85,11 +82,11 @@ public static async Task WaitUntilModelsDownloaded(Callback modelProgress if (loraProgressCallback != null) loraProgressCallbacks.Remove(loraProgressCallback); } - public static async Task Download(string url, bool lora = false, string name = null) + public static async Task Download(string url, bool lora = false, string label = null) { foreach (ModelEntry entry in modelEntries) { - if (entry.url == url) return entry.name; + if (entry.url == url) return entry.filename; } string modelName = Path.GetFileName(url).Split("?")[0]; string modelPath = Path.Combine(LLMUnitySetup.modelDownloadPath, modelName); @@ -115,44 +112,44 @@ public static async Task Download(string url, bool lora = false, string LLMUnitySetup.LogError($"Error downloading the model from URL '{url}': " + ex.Message); return null; } - return AddEntry(modelPath, lora, name, url); + return AddEntry(modelPath, lora, label, url); } - public static string Load(string path, bool lora = false, string name = null) + public static string Load(string path, bool lora = false, string label = null) { string fullPath = Path.GetFullPath(path).Replace('\\', '/'); foreach (ModelEntry entry in modelEntries) { - if (entry.localPath == fullPath) return entry.name; + if (entry.path == fullPath) return entry.filename; } - return AddEntry(path, lora, name); + return AddEntry(fullPath, lora, label); } - public static async Task DownloadModel(string url, string name = null) + public static async Task DownloadModel(string url, string label = null) { - return await Download(url, false, name); + return await Download(url, false, label); } - public static async Task DownloadLora(string url, string name = null) + public static async Task DownloadLora(string url, string label = null) { - return await Download(url, true, name); + return await Download(url, true, label); } - public static string LoadModel(string url, string name = null) + public static string LoadModel(string url, string label = null) { - return Load(url, false, name); + return Load(url, false, label); } - public static string LoadLora(string url, string name = null) + public static string LoadLora(string url, string label = null) { - return Load(url, true, name); + return Load(url, true, label); } - public static void SetModelTemplate(string name, string chatTemplate) + public static void SetModelTemplate(string filename, string chatTemplate) { foreach (ModelEntry entry in modelEntries) { - if (entry.name == name) + if (entry.filename == filename) { entry.chatTemplate = chatTemplate; break; @@ -160,24 +157,59 @@ public static void SetModelTemplate(string name, string chatTemplate) } } - public static ModelEntry Get(string name) + public static ModelEntry Get(string filename) { foreach (ModelEntry entry in modelEntries) { - if (entry.name == name) return entry; + if (entry.filename == filename) return entry; } return null; } - public static void Remove(string name) + public static void Remove(string filename) { - Remove(Get(name)); + Remove(Get(filename)); } public static void Remove(ModelEntry entry) { if (entry == null) return; modelEntries.Remove(entry); + foreach (LLM llm in llms) + { + if (!entry.lora && llm.model == entry.filename) llm.model = ""; + else if (entry.lora && llm.lora == entry.filename) llm.lora = ""; + } + } + + public static int Num(bool lora) + { + int num = 0; + foreach (ModelEntry entry in modelEntries) + { + if (entry.lora == lora) num++; + } + return num; + } + + public static int NumModels() + { + return Num(false); + } + + public static int NumLoras() + { + return Num(true); + } + + public static void Register(LLM llm) + { + llms.Add(llm); + } + + public static void Unregister(LLM llm) + { + llms.Remove(llm); } public static void SetModelProgress(float progress) From 8d4b2eb36543f1b6ac7a8d85fe88b56b7e14b2fb Mon Sep 17 00:00:00 2001 From: Antonis Makropoulos Date: Thu, 25 Jul 2024 14:40:56 +0300 Subject: [PATCH 12/26] expand button, button improvements, set model/lora if not in LLM --- Editor/LLMEditor.cs | 112 ++++++++++++++++++++++++-------------------- 1 file changed, 62 insertions(+), 50 deletions(-) diff --git a/Editor/LLMEditor.cs b/Editor/LLMEditor.cs index 3a160d70..c78cef3f 100644 --- a/Editor/LLMEditor.cs +++ b/Editor/LLMEditor.cs @@ -27,6 +27,7 @@ public class LLMEditor : PropertyEditor string customURL = ""; bool customURLLora = false; bool customURLFocus = false; + bool expandedView = false; protected override Type[] GetPropertyTypes() { @@ -42,19 +43,27 @@ public void AddModelLoadersSettings(SerializedObject llmScriptSO, LLM llmScript) public void AddModelLoaders(SerializedObject llmScriptSO, LLM llmScript) { - if (LLMManager.modelEntries.Count == 0) + if (LLMManager.modelEntries.Count > 0) { - DrawFooter(EditorGUILayout.GetControlRect()); - } - else - { - float[] widths = GetColumnWidths(llmScript.advancedOptions); - float listWidth = 2 * ReorderableList.Defaults.padding * 2; + float[] widths = GetColumnWidths(expandedView); + float listWidth = 2 * ReorderableList.Defaults.padding; foreach (float width in widths) listWidth += width + (listWidth == 0 ? 0 : elementPadding); + EditorGUILayout.BeginHorizontal(GUILayout.Width(listWidth + actionColumnWidth)); + EditorGUILayout.BeginVertical(GUILayout.Width(listWidth)); modelList.DoLayoutList(); EditorGUILayout.EndVertical(); + + Rect expandedRect = GUILayoutUtility.GetRect(actionColumnWidth, modelList.elementHeight + ReorderableList.Defaults.padding); + expandedRect.y += modelList.GetHeight() - modelList.elementHeight - ReorderableList.Defaults.padding; + if (GUI.Button(expandedRect, expandedView ? "«" : "»")) + { + expandedView = !expandedView; + Repaint(); + } + EditorGUILayout.EndHorizontal(); } + AddLoadButtons(); bool downloadOnStart = EditorGUILayout.Toggle("Download on Start", LLMManager.downloadOnStart); if (downloadOnStart != LLMManager.downloadOnStart) { @@ -101,7 +110,7 @@ float[] GetColumnWidths(bool expandedView) return widths.ToArray(); } - List CreateColumnRects(Rect rect, bool expandedView) + List CreateColumnRects(Rect rect) { float[] widths = GetColumnWidths(expandedView); float offsetX = rect.x; @@ -131,13 +140,15 @@ void showCustomURLField(bool lora) Repaint(); } - void SetModelIfNone() + void SetModelIfNone(string filename, bool lora) { LLM llmScript = (LLM)target; - if (llmScript.model == "" && LLMManager.modelEntries.Count == 1) llmScript.SetModel(LLMManager.modelEntries[0].localPath); + int num = LLMManager.Num(lora); + if (!lora && llmScript.model == "" && num == 1) llmScript.SetModel(filename); + if (lora && llmScript.lora == "" && num == 1) llmScript.SetLora(filename); } - async Task createCustomURLField(Rect rect) + async Task createCustomURLField() { bool submit = false; bool exit = false; @@ -154,16 +165,13 @@ async Task createCustomURLField(Rect rect) } else { - Rect labelRect = new Rect(rect.x, rect.y, 100, EditorGUIUtility.singleLineHeight); - Rect textRect = new Rect(rect.x + labelRect.width + elementPadding, rect.y, buttonWidth, EditorGUIUtility.singleLineHeight); - Rect submitRect = new Rect(rect.x + labelRect.width + buttonWidth + elementPadding * 2, rect.y, buttonWidth / 2f, EditorGUIUtility.singleLineHeight); - Rect backRect = new Rect(rect.x + labelRect.width + buttonWidth * 1.5f + elementPadding * 3, rect.y, buttonWidth / 2f, EditorGUIUtility.singleLineHeight); - - EditorGUI.LabelField(labelRect, "Enter URL:"); + EditorGUILayout.BeginHorizontal(); + EditorGUILayout.LabelField("Enter URL", GUILayout.Width(100)); GUI.SetNextControlName("customURLFocus"); - customURL = EditorGUI.TextField(textRect, customURL); - submit = GUI.Button(submitRect, "Submit"); - exit = GUI.Button(backRect, "Back"); + customURL = EditorGUILayout.TextField(customURL, GUILayout.Width(buttonWidth)); + submit = GUILayout.Button("Submit", GUILayout.Width(buttonWidth)); + exit = GUILayout.Button("Back", GUILayout.Width(buttonWidth)); + EditorGUILayout.EndHorizontal(); if (customURLFocus) { @@ -179,32 +187,33 @@ async Task createCustomURLField(Rect rect) Repaint(); if (submit && customURL != "") { - await LLMManager.Download(customURL, customURLLora); - SetModelIfNone(); + string filename = await LLMManager.Download(customURL, customURLLora); + SetModelIfNone(filename, customURLLora); UpdateModels(true); } } } - async Task createButtons(Rect rect, LLM llmScript) + async Task createButtons() { - Rect downloadModelRect = new Rect(rect.x, rect.y, buttonWidth, EditorGUIUtility.singleLineHeight); - Rect loadModelRect = new Rect(rect.x + buttonWidth + elementPadding, rect.y, buttonWidth, EditorGUIUtility.singleLineHeight); - Rect downloadLoraRect = new Rect(rect.xMax - 2 * buttonWidth - elementPadding, rect.y, buttonWidth, EditorGUIUtility.singleLineHeight); - Rect loadLoraRect = new Rect(rect.xMax - buttonWidth, rect.y, buttonWidth, EditorGUIUtility.singleLineHeight); - int modelIndex = EditorGUI.Popup(downloadModelRect, 0, modelOptions.ToArray()); + LLM llmScript = (LLM)target; + EditorGUILayout.BeginHorizontal(); + + GUIStyle centeredPopupStyle = new GUIStyle(EditorStyles.popup); + centeredPopupStyle.alignment = TextAnchor.MiddleCenter; + int modelIndex = EditorGUILayout.Popup(0, modelOptions.ToArray(), centeredPopupStyle, GUILayout.Width(buttonWidth)); if (modelIndex == 1) { showCustomURLField(false); } else if (modelIndex > 1) { - await LLMManager.DownloadModel(modelURLs[modelIndex], modelOptions[modelIndex]); - SetModelIfNone(); + string filename = await LLMManager.DownloadModel(modelURLs[modelIndex], modelOptions[modelIndex]); + SetModelIfNone(filename, false); UpdateModels(true); } - if (GUI.Button(loadModelRect, "Load model")) + if (GUILayout.Button("Load model", GUILayout.Width(buttonWidth))) { EditorApplication.delayCall += () => { @@ -216,14 +225,16 @@ async Task createButtons(Rect rect, LLM llmScript) } }; } + EditorGUILayout.EndHorizontal(); if (llmScript.advancedOptions) { - if (GUI.Button(downloadLoraRect, "Download LoRA")) + EditorGUILayout.BeginHorizontal(); + if (GUILayout.Button("Download LoRA", GUILayout.Width(buttonWidth))) { showCustomURLField(true); } - if (GUI.Button(loadLoraRect, "Load LoRA")) + if (GUILayout.Button("Load LoRA", GUILayout.Width(buttonWidth))) { EditorApplication.delayCall += () => { @@ -234,14 +245,14 @@ async Task createButtons(Rect rect, LLM llmScript) } }; } + EditorGUILayout.EndHorizontal(); } } - async void DrawFooter(Rect rect) + async Task AddLoadButtons() { - LLM llmScript = (LLM)target; - if (showCustomURL) await createCustomURLField(rect); - else await createButtons(rect, llmScript); + if (showCustomURL) await createCustomURLField(); + else await createButtons(); } void OnEnable() @@ -261,14 +272,14 @@ void OnEnable() if (index >= LLMManager.modelEntries.Count) return; ModelEntry entry = LLMManager.modelEntries[index]; - List rects = CreateColumnRects(rect, llmScript.advancedOptions); + List rects = CreateColumnRects(rect); int col = 0; Rect selectRect = rects[col++]; Rect nameRect = rects[col++]; Rect templateRect = rects[col++]; Rect urlRect = new Rect(); Rect pathRect = new Rect(); - if (llmScript.advancedOptions) + if (expandedView) { urlRect = rects[col++]; pathRect = rects[col++]; @@ -276,25 +287,25 @@ void OnEnable() Rect includeInBuildRect = rects[col++]; Rect actionRect = rects[col++]; - bool hasPath = entry.localPath != null && entry.localPath != ""; + bool hasPath = entry.path != null && entry.path != ""; bool hasURL = entry.url != null && entry.url != ""; bool isSelected = false; if (!entry.lora) { - isSelected = llmScript.model == entry.localPath; + isSelected = llmScript.model == entry.filename; bool newSelected = EditorGUI.Toggle(selectRect, isSelected, EditorStyles.radioButton); - if (newSelected && !isSelected) llmScript.SetModel(entry.localPath); + if (newSelected && !isSelected) llmScript.SetModel(entry.filename); } else { - isSelected = llmScript.lora == entry.localPath; + isSelected = llmScript.lora == entry.filename; bool newSelected = EditorGUI.Toggle(selectRect, isSelected, EditorStyles.radioButton); - if (newSelected && !isSelected) llmScript.SetLora(entry.localPath); + if (newSelected && !isSelected) llmScript.SetLora(entry.filename); else if (!newSelected && isSelected) llmScript.SetLora(""); } - DrawCopyableLabel(nameRect, entry.name); + DrawCopyableLabel(nameRect, entry.label); if (!entry.lora) { @@ -308,7 +319,7 @@ void OnEnable() } } - if (llmScript.advancedOptions) + if (expandedView) { if (hasURL) { @@ -323,7 +334,7 @@ void OnEnable() UpdateModels(); } } - DrawCopyableLabel(pathRect, entry.localPath); + DrawCopyableLabel(pathRect, entry.path); } bool includeInBuild = EditorGUI.ToggleLeft(includeInBuildRect, "", entry.includeInBuild); @@ -346,12 +357,12 @@ void OnEnable() }, drawHeaderCallback = (rect) => { - List rects = CreateColumnRects(rect, llmScript.advancedOptions); + List rects = CreateColumnRects(rect); int col = 0; EditorGUI.LabelField(rects[col++], ""); EditorGUI.LabelField(rects[col++], "Model"); EditorGUI.LabelField(rects[col++], "Chat template"); - if (llmScript.advancedOptions) + if (expandedView) { EditorGUI.LabelField(rects[col++], "URL"); EditorGUI.LabelField(rects[col++], "Path"); @@ -359,7 +370,8 @@ void OnEnable() EditorGUI.LabelField(rects[col++], "Build"); EditorGUI.LabelField(rects[col++], ""); }, - drawFooterCallback = DrawFooter, + drawFooterCallback = {}, + footerHeight = 0, }; } From 467f07833b66310e6ee498ef2277313d57ae5e6e Mon Sep 17 00:00:00 2001 From: Antonis Makropoulos Date: Thu, 25 Jul 2024 14:49:10 +0300 Subject: [PATCH 13/26] simplify AddAsset --- Runtime/LLMCharacter.cs | 2 +- Runtime/LLMUnitySetup.cs | 25 +++++++++---------------- 2 files changed, 10 insertions(+), 17 deletions(-) diff --git a/Runtime/LLMCharacter.cs b/Runtime/LLMCharacter.cs index 7fb45489..cd3a192c 100644 --- a/Runtime/LLMCharacter.cs +++ b/Runtime/LLMCharacter.cs @@ -322,7 +322,7 @@ public async Task LoadTemplate() public void SetGrammar(string path) { #if UNITY_EDITOR - if (!EditorApplication.isPlaying) path = LLMUnitySetup.AddAsset(path, LLMUnitySetup.GetAssetPath()); + if (!EditorApplication.isPlaying) path = LLMUnitySetup.AddAsset(path); #endif grammar = path; InitGrammar(); diff --git a/Runtime/LLMUnitySetup.cs b/Runtime/LLMUnitySetup.cs index a24cf3a2..abe96d72 100644 --- a/Runtime/LLMUnitySetup.cs +++ b/Runtime/LLMUnitySetup.cs @@ -288,30 +288,23 @@ private static void SetLibraryProgress(float progress) libraryProgress = progress; } - public static string AddAsset(string assetPath, string basePath) + public static string AddAsset(string assetPath) { if (!File.Exists(assetPath)) { LogError($"{assetPath} does not exist!"); return null; } - // add an asset to the basePath directory if it is not already there and return the relative path - string basePathSlash = basePath.Replace('\\', '/'); - string fullPath = Path.GetFullPath(assetPath).Replace('\\', '/'); - Directory.CreateDirectory(basePathSlash); - if (!fullPath.StartsWith(basePathSlash)) + string filename = Path.GetFileName(assetPath); + string fullPath = GetAssetPath(filename); + AssetDatabase.StartAssetEditing(); + foreach (string path in new string[] {fullPath, fullPath + ".meta"}) { - // if the asset is not in the assets dir copy it over - fullPath = Path.Combine(basePathSlash, Path.GetFileName(assetPath)); - AssetDatabase.StartAssetEditing(); - foreach (string filename in new string[] {fullPath, fullPath + ".meta"}) - { - if (File.Exists(filename)) File.Delete(filename); - } - CreateSymlink(assetPath, fullPath); - AssetDatabase.StopAssetEditing(); + if (File.Exists(path)) File.Delete(path); } - return fullPath.Substring(basePathSlash.Length + 1); + File.Copy(assetPath, fullPath); + AssetDatabase.StopAssetEditing(); + return filename; } public static void CreateSymlink(string sourcePath, string targetPath) From ee7c6ce2d3250d0cdb29e88d1adcf6370340f38d Mon Sep 17 00:00:00 2001 From: Antonis Makropoulos Date: Thu, 25 Jul 2024 15:00:51 +0300 Subject: [PATCH 14/26] use LLM manager on Editor mode otherwise GetAssetPath --- Runtime/LLM.cs | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/Runtime/LLM.cs b/Runtime/LLM.cs index e2ce0b0f..6db36a56 100644 --- a/Runtime/LLM.cs +++ b/Runtime/LLM.cs @@ -90,7 +90,6 @@ public class LLM : MonoBehaviour public string lora = ""; /// \cond HIDE - public LLMManager llmManager = new LLMManager(); IntPtr LLMObject = IntPtr.Zero; List clients = new List(); @@ -102,6 +101,9 @@ public class LLM : MonoBehaviour /// \endcond #if UNITY_EDITOR + + public LLMManager llmManager = new LLMManager(); + public LLM() { LLMManager.Register(this); @@ -171,7 +173,11 @@ protected virtual string GetLlamaccpArguments() LLMUnitySetup.LogError("No model file provided!"); return null; } +#if UNITY_EDITOR + string modelPath = LLMManager.Get(model).path; +#else string modelPath = LLMUnitySetup.GetAssetPath(model); +#endif if (!File.Exists(modelPath)) { LLMUnitySetup.LogError($"File {modelPath} not found!"); @@ -180,7 +186,11 @@ protected virtual string GetLlamaccpArguments() string loraPath = ""; if (lora != "") { +#if UNITY_EDITOR + loraPath = LLMManager.Get(lora).path; +#else loraPath = LLMUnitySetup.GetAssetPath(lora); +#endif if (!File.Exists(loraPath)) { LLMUnitySetup.LogError($"File {loraPath} not found!"); From a1a0b7165f9b54abcded4a8d136b0a84eac93e6c Mon Sep 17 00:00:00 2001 From: Antonis Makropoulos Date: Fri, 26 Jul 2024 13:15:09 +0300 Subject: [PATCH 15/26] improve build process --- Editor/LLMBuildProcessor.cs | 142 +++++++----------------------------- Editor/LLMEditor.cs | 6 +- Runtime/LLMBuilder.cs | 120 ++++++++++++++++++++++++++++++ Runtime/LLMBuilder.cs.meta | 11 +++ Runtime/LLMUnitySetup.cs | 78 +++++++++++++------- 5 files changed, 209 insertions(+), 148 deletions(-) create mode 100644 Runtime/LLMBuilder.cs create mode 100644 Runtime/LLMBuilder.cs.meta diff --git a/Editor/LLMBuildProcessor.cs b/Editor/LLMBuildProcessor.cs index 05cc6238..b98caf0f 100644 --- a/Editor/LLMBuildProcessor.cs +++ b/Editor/LLMBuildProcessor.cs @@ -2,150 +2,58 @@ using UnityEditor.Build; using UnityEditor.Build.Reporting; using UnityEngine; -using System.IO; -using System.Collections.Generic; -using System; namespace LLMUnity { - public class LLMBuildProcessor : MonoBehaviour, IPreprocessBuildWithReport, IPostprocessBuildWithReport + public class LLMBuildProcessor : IPreprocessBuildWithReport, IPostprocessBuildWithReport { public int callbackOrder => 0; - static string tempDir = Path.Combine(Application.temporaryCachePath, "LLMBuildProcessor", Path.GetFileName(LLMUnitySetup.libraryPath)); - static List movedPairs = new List(); - static string movedCache = Path.Combine(tempDir, "moved.json"); - [InitializeOnLoadMethod] - private static void InitializeOnLoad() - { - if (!Directory.Exists(tempDir)) Directory.CreateDirectory(tempDir); - else ResetMoves(); - } - - // CALLED BEFORE THE BUILD + // called before the build public void OnPreprocessBuild(BuildReport report) { - // Start listening for errors when build starts Application.logMessageReceived += OnBuildError; - HideLibraryPlatforms(report.summary.platform); - HideModels(); - if (movedPairs.Count > 0) AssetDatabase.Refresh(); - } - - // CALLED DURING BUILD TO CHECK FOR ERRORS - private void OnBuildError(string condition, string stacktrace, LogType type) - { - if (type == LogType.Error) - { - // FAILED TO BUILD, STOP LISTENING FOR ERRORS - BuildCompleted(); - } - } - - // CALLED AFTER THE BUILD - public void OnPostprocessBuild(BuildReport report) - { - BuildCompleted(); - } - - public void BuildCompleted() - { - Application.logMessageReceived -= OnBuildError; - ResetMoves(); - } - - static bool MovePath(string source, string target) - { - bool moved = false; - if (File.Exists(source)) - { - File.Move(source, target); - moved = true; - } - else if (Directory.Exists(source)) - { - Directory.Move(source, target); - moved = true; - } - if (moved) - { - movedPairs.Add(new MovedPair {source = source, target = target}); - File.WriteAllText(movedCache, JsonUtility.ToJson(new FoldersMovedWrapper { movedPairs = movedPairs })); - } - return moved; - } - - static void MoveAssetAndMeta(string source, string target) - { - MovePath(source + ".meta", target + ".meta"); - MovePath(source, target); - } - - static void HideLibraryPlatforms(BuildTarget buildPlatform) - { - List platforms = new List(){ "windows", "macos", "linux", "android" }; - switch (buildPlatform) + string platform = null; + switch (report.summary.platform) { case BuildTarget.StandaloneWindows: case BuildTarget.StandaloneWindows64: - platforms.Remove("windows"); + platform = "windows"; break; case BuildTarget.StandaloneLinux64: - platforms.Remove("linux"); + platform = "linux"; break; case BuildTarget.StandaloneOSX: - platforms.Remove("macos"); + platform = "macos"; break; case BuildTarget.Android: - platforms.Remove("android"); + platform = "android"; + break; + case BuildTarget.iOS: + platform = "ios"; break; } - - foreach (string dirname in Directory.GetDirectories(LLMUnitySetup.libraryPath)) - { - foreach (string platform in platforms) - { - if (Path.GetFileName(dirname).StartsWith(platform)) - { - MoveAssetAndMeta(dirname, Path.Combine(tempDir, Path.GetFileName(dirname))); - } - } - } + LLMBuilder.HideLibraryPlatforms(platform); + LLMBuilder.CopyModels(); + AssetDatabase.Refresh(); } - static void HideModels() + // called during build to check for errors + private void OnBuildError(string condition, string stacktrace, LogType type) { - foreach (LLM llm in FindObjectsOfType()) - { - // if (!llm.downloadOnBuild) continue; - // if (llm.modelURL != "") MoveAssetAndMeta(LLMUnitySetup.GetAssetPath(llm.model), Path.Combine(tempDir, Path.GetFileName(llm.model))); - if (llm.loraURL != "") MoveAssetAndMeta(LLMUnitySetup.GetAssetPath(llm.lora), Path.Combine(tempDir, Path.GetFileName(llm.lora))); - } + if (type == LogType.Error) BuildCompleted(); } - static void ResetMoves() + // called after the build + public void OnPostprocessBuild(BuildReport report) { - if (!File.Exists(movedCache)) return; - List movedPairs = JsonUtility.FromJson(File.ReadAllText(movedCache)).movedPairs; - if (movedPairs == null) return; - - bool refresh = false; - foreach (var pair in movedPairs) refresh |= MovePath(pair.target, pair.source); - if (refresh) AssetDatabase.Refresh(); - File.Delete(movedCache); + BuildCompleted(); } - } - [Serializable] - public struct MovedPair - { - public string source; - public string target; - } - - [Serializable] - public class FoldersMovedWrapper - { - public List movedPairs; + public void BuildCompleted() + { + Application.logMessageReceived -= OnBuildError; + LLMBuilder.Reset(); + } } } diff --git a/Editor/LLMEditor.cs b/Editor/LLMEditor.cs index c78cef3f..cb482931 100644 --- a/Editor/LLMEditor.cs +++ b/Editor/LLMEditor.cs @@ -63,7 +63,7 @@ public void AddModelLoaders(SerializedObject llmScriptSO, LLM llmScript) } EditorGUILayout.EndHorizontal(); } - AddLoadButtons(); + _ = AddLoadButtons(); bool downloadOnStart = EditorGUILayout.Toggle("Download on Start", LLMManager.downloadOnStart); if (downloadOnStart != LLMManager.downloadOnStart) { @@ -92,8 +92,8 @@ static void ResetModelOptions() { List existingOptions = new List(); foreach (ModelEntry entry in LLMManager.modelEntries) existingOptions.Add(entry.url); - modelOptions = new List(); - modelURLs = new List(); + modelOptions = new List(){"Download model", "Custom URL"}; + modelURLs = new List(){null, null}; foreach ((string name, string url) in LLMUnitySetup.modelOptions) { if (url != null && existingOptions.Contains(url)) continue; diff --git a/Runtime/LLMBuilder.cs b/Runtime/LLMBuilder.cs new file mode 100644 index 00000000..ff433961 --- /dev/null +++ b/Runtime/LLMBuilder.cs @@ -0,0 +1,120 @@ +using UnityEditor; +using UnityEngine; +using System.IO; +using System.Collections.Generic; +using System; + +#if UNITY_EDITOR +namespace LLMUnity +{ + public class LLMBuilder + { + static List movedPairs = new List(); + static string movedCache = Path.Combine(LLMUnitySetup.buildTempDir, "moved.json"); + + [InitializeOnLoadMethod] + private static void InitializeOnLoad() + { + Directory.CreateDirectory(LLMUnitySetup.buildTempDir); + Reset(); + } + + public delegate void ActionCallback(string source, string target); + + static void AddMovedPair(string source, string target) + { + movedPairs.Add(new MovedPair {source = source, target = target}); + File.WriteAllText(movedCache, JsonUtility.ToJson(new FoldersMovedWrapper { movedPairs = movedPairs }, true)); + } + + static bool MoveAction(string source, string target, bool addEntry = true) + { + ActionCallback moveCallback; + if (File.Exists(source)) moveCallback = File.Move; + else if (Directory.Exists(source)) moveCallback = LLMUnitySetup.MovePath; + else return false; + + if (addEntry) AddMovedPair(source, target); + moveCallback(source, target); + return true; + } + + static bool CopyAction(string source, string target, bool addEntry = true) + { + ActionCallback copyCallback; + if (File.Exists(source)) copyCallback = File.Copy; + else if (Directory.Exists(source)) copyCallback = LLMUnitySetup.CopyPath; + else return false; + + if (addEntry) AddMovedPair("", target); + copyCallback(source, target); + return true; + } + + static bool DeleteAction(string source) + { + return LLMUnitySetup.DeletePath(source); + } + + public static void HideLibraryPlatforms(string platform) + { + List platforms = new List(){ "windows", "macos", "linux", "android", "ios" }; + platforms.Remove(platform); + foreach (string source in Directory.GetDirectories(LLMUnitySetup.libraryPath)) + { + foreach (string platformPrefix in platforms) + { + if (Path.GetFileName(source).StartsWith(platformPrefix)) + { + string target = Path.Combine(LLMUnitySetup.buildTempDir, Path.GetFileName(source)); + MoveAction(source, target); + MoveAction(source + ".meta", target + ".meta"); + } + } + } + } + + public static void CopyModels() + { + if (LLMManager.downloadOnStart) return; + foreach (ModelEntry modelEntry in LLMManager.modelEntries) + { + string source = modelEntry.path; + string target = LLMUnitySetup.GetAssetPath(modelEntry.filename); + if (!modelEntry.includeInBuild || File.Exists(target)) continue; + CopyAction(source, target); + AddMovedPair("", target + ".meta"); + } + } + + public static void Reset() + { + if (!File.Exists(movedCache)) return; + List movedPairs = JsonUtility.FromJson(File.ReadAllText(movedCache)).movedPairs; + if (movedPairs == null) return; + + bool refresh = false; + foreach (var pair in movedPairs) + { + if (pair.source == "") refresh |= DeleteAction(pair.target); + else refresh |= MoveAction(pair.target, pair.source, false); + } + if (refresh) AssetDatabase.Refresh(); + LLMUnitySetup.DeletePath(movedCache); + } + } + + [Serializable] + public struct MovedPair + { + public string source; + public string target; + } + + [Serializable] + public class FoldersMovedWrapper + { + public List movedPairs; + } +} +#endif diff --git a/Runtime/LLMBuilder.cs.meta b/Runtime/LLMBuilder.cs.meta new file mode 100644 index 00000000..14615c3e --- /dev/null +++ b/Runtime/LLMBuilder.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: e52304e7914527ae0801d752670d7bec +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: diff --git a/Runtime/LLMUnitySetup.cs b/Runtime/LLMUnitySetup.cs index abe96d72..63ceab3c 100644 --- a/Runtime/LLMUnitySetup.cs +++ b/Runtime/LLMUnitySetup.cs @@ -78,16 +78,18 @@ public class LLMUnitySetup public static string LlamaLibURL = $"https://github.com/undreamai/LlamaLib/releases/download/{LlamaLibVersion}/undreamai-{LlamaLibVersion}-llamacpp.zip"; /// LlamaLib path public static string libraryPath = GetAssetPath(Path.GetFileName(LlamaLibURL).Replace(".zip", "")); + /// LLMnity store path + public static string LLMUnityStore = Path.Combine(Environment.GetFolderPath(Environment.SpecialFolder.ApplicationData), "LLMUnity"); /// Model download path - public static string modelDownloadPath = Path.Combine(Environment.GetFolderPath(Environment.SpecialFolder.ApplicationData), "LLMUnity"); + public static string modelDownloadPath = Path.Combine(LLMUnityStore, "models"); /// Model list for project public static string modelListPath = Path.Combine(Application.temporaryCachePath, "modelCache.json"); + /// Temporary dir for build + public static string buildTempDir = Path.Combine(Application.temporaryCachePath, "LLMUnityBuild"); /// Default models for download [HideInInspector] public static readonly (string, string)[] modelOptions = new(string, string)[] { - ("Download model", null), - ("Custom URL", null), ("Mistral 7B Instruct v0.2 (medium, best overall)", "https://huggingface.co/TheBloke/Mistral-7B-Instruct-v0.2-GGUF/resolve/main/mistral-7b-instruct-v0.2.Q4_K_M.gguf?download=true"), ("OpenHermes 2.5 7B (medium, best for conversation)", "https://huggingface.co/TheBloke/OpenHermes-2.5-Mistral-7B-GGUF/resolve/main/openhermes-2.5-mistral-7b.Q4_K_M.gguf?download=true"), ("Phi 3 (small, great)", "https://huggingface.co/microsoft/Phi-3-mini-4k-instruct-gguf/resolve/main/Phi-3-mini-4k-instruct-q4.gguf?download=true"), @@ -228,11 +230,11 @@ public static async Task AndroidExtractFile(string assetName, bool overwrite = f string target = GetAssetPath(assetName); if (!overwrite && File.Exists(target)) { - Debug.Log($"File {target} already exists"); + Log($"File {target} already exists"); return; } - Debug.Log($"Extracting {source} to {target}"); + Log($"Extracting {source} to {target}"); // UnityWebRequest to read the file from StreamingAssets UnityWebRequest www = UnityWebRequest.Get(source); @@ -242,7 +244,7 @@ public static async Task AndroidExtractFile(string assetName, bool overwrite = f while (!operation.isDone) await Task.Delay(1); if (www.result != UnityWebRequest.Result.Success) { - Debug.LogError("Failed to load file from StreamingAssets: " + www.error); + LogError("Failed to load file from StreamingAssets: " + www.error); } else { @@ -267,6 +269,44 @@ public static bool IsSubPath(string childPath, string parentPath) } #if UNITY_EDITOR + + public static void CopyPath(string source, string target) + { + if (File.Exists(source)) + { + File.Copy(source, target); + } + else if (Directory.Exists(source)) + { + Directory.CreateDirectory(target); + List filesAndDirs = new List(); + filesAndDirs.AddRange(Directory.GetFiles(source)); + filesAndDirs.AddRange(Directory.GetDirectories(source)); + foreach (string path in filesAndDirs) + { + CopyPath(path, Path.Combine(target, Path.GetFileName(path))); + } + } + } + + public static void MovePath(string source, string target) + { + CopyPath(source, target); + DeletePath(source); + } + + public static bool DeletePath(string path) + { + if (!IsSubPath(path, GetAssetPath()) && !IsSubPath(path, buildTempDir)) + { + LogError($"Safeguard: {path} will not be deleted because it may not be safe"); + return false; + } + if (File.Exists(path)) File.Delete(path); + else if (Directory.Exists(path)) Directory.Delete(path, true); + return true; + } + [HideInInspector] public static float libraryProgress = 1; private static async Task DownloadLibrary() @@ -277,7 +317,9 @@ private static async Task DownloadLibrary() if (!Directory.Exists(libraryPath)) { await DownloadFile(LlamaLibURL, libZip, true, null, SetLibraryProgress); + AssetDatabase.StartAssetEditing(); ZipFile.ExtractToDirectory(libZip, libraryPath); + AssetDatabase.StopAssetEditing(); File.Delete(libZip); } libraryProgress = 1; @@ -307,26 +349,6 @@ public static string AddAsset(string assetPath) return filename; } - public static void CreateSymlink(string sourcePath, string targetPath) - { - bool isDirectory = Directory.Exists(sourcePath); - if (!isDirectory && !File.Exists(sourcePath)) throw new FileNotFoundException($"Source path does not exist: {sourcePath}"); - - bool success; -#if UNITY_STANDALONE_WIN - success = CreateSymbolicLink(targetPath, sourcePath, (int)isDirectory); -#else - success = symlink(sourcePath, targetPath) == 0; -#endif - if (!success) throw new IOException($"Failed to create symbolic link: {targetPath}"); - } - - [DllImport("kernel32.dll", CharSet = CharSet.Unicode)] - private static extern bool CreateSymbolicLink(string lpSymlinkFileName, string lpTargetFileName, int dwFlags); - - [DllImport("libc", SetLastError = true)] - private static extern int symlink(string oldpath, string newpath); - #endif /// \endcond public static int GetMaxFreqKHz(int cpuId) @@ -422,7 +444,7 @@ public static int AndroidGetNumBigCores() } catch (Exception e) { - Debug.LogError(e.Message); + LogError(e.Message); } int numBigCores = 0; @@ -474,7 +496,7 @@ public static int AndroidGetNumBigCoresCapacity() } catch (Exception e) { - Debug.LogError(e.Message); + LogError(e.Message); } int numBigCores = 0; From 9f29ecab2f67b75bad3ed720dbd4b03e435ca72f Mon Sep 17 00:00:00 2001 From: Antonis Makropoulos Date: Fri, 26 Jul 2024 13:43:14 +0300 Subject: [PATCH 16/26] store models in a player pref --- Runtime/LLMManager.cs | 13 +++++++++---- Runtime/LLMUnitySetup.cs | 3 --- 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/Runtime/LLMManager.cs b/Runtime/LLMManager.cs index 04e08d8b..b5d05fac 100644 --- a/Runtime/LLMManager.cs +++ b/Runtime/LLMManager.cs @@ -29,6 +29,7 @@ public class LLMManagerStore public class LLMManager { + static string LLMManagerPref = "LLMManager"; public static bool downloadOnStart = false; public static List modelEntries = new List(); @@ -70,6 +71,7 @@ public static string AddEntry(string path, bool lora = false, string label = nul } } modelEntries.Insert(indexToInsert, entry); + Save(); return entry.filename; } @@ -175,6 +177,7 @@ public static void Remove(ModelEntry entry) { if (entry == null) return; modelEntries.Remove(entry); + Save(); foreach (LLM llm in llms) { if (!entry.lora && llm.model == entry.filename) llm.model = ""; @@ -226,14 +229,16 @@ public static void SetLoraProgress(float progress) public static void Save() { - Directory.CreateDirectory(Path.GetDirectoryName(LLMUnitySetup.modelListPath)); - File.WriteAllText(LLMUnitySetup.modelListPath, JsonUtility.ToJson(new LLMManagerStore { modelEntries = modelEntries, downloadOnStart = downloadOnStart }, true)); + string pref = JsonUtility.ToJson(new LLMManagerStore { modelEntries = modelEntries, downloadOnStart = downloadOnStart }, true); + PlayerPrefs.SetString(LLMManagerPref, pref); + PlayerPrefs.Save(); } public static void Load() { - if (!File.Exists(LLMUnitySetup.modelListPath)) return; - LLMManagerStore store = JsonUtility.FromJson(File.ReadAllText(LLMUnitySetup.modelListPath)); + string pref = PlayerPrefs.GetString(LLMManagerPref); + if (pref == null || pref == "") return; + LLMManagerStore store = JsonUtility.FromJson(pref); downloadOnStart = store.downloadOnStart; modelEntries = store.modelEntries; } diff --git a/Runtime/LLMUnitySetup.cs b/Runtime/LLMUnitySetup.cs index 63ceab3c..dc70d709 100644 --- a/Runtime/LLMUnitySetup.cs +++ b/Runtime/LLMUnitySetup.cs @@ -7,7 +7,6 @@ using System; using System.IO.Compression; using System.Collections.Generic; -using System.Runtime.InteropServices; using UnityEngine.Networking; /// @defgroup llm LLM @@ -82,8 +81,6 @@ public class LLMUnitySetup public static string LLMUnityStore = Path.Combine(Environment.GetFolderPath(Environment.SpecialFolder.ApplicationData), "LLMUnity"); /// Model download path public static string modelDownloadPath = Path.Combine(LLMUnityStore, "models"); - /// Model list for project - public static string modelListPath = Path.Combine(Application.temporaryCachePath, "modelCache.json"); /// Temporary dir for build public static string buildTempDir = Path.Combine(Application.temporaryCachePath, "LLMUnityBuild"); From a86ed07551e8757a4824c1a1e0bb98c32de6d631 Mon Sep 17 00:00:00 2001 From: Antonis Makropoulos Date: Fri, 26 Jul 2024 17:13:15 +0300 Subject: [PATCH 17/26] download models with LLMManager --- Editor/LLMBuildProcessor.cs | 2 +- Editor/LLMEditor.cs | 16 ++- Runtime/LLM.cs | 46 +++++---- Runtime/LLMBuilder.cs | 57 +++++------ Runtime/LLMManager.cs | 150 +++++++++++++++++++++++----- Runtime/LLMUnitySetup.cs | 20 +++- Runtime/ResumingWebClient.cs | 9 +- Samples~/AndroidDemo/AndroidDemo.cs | 4 +- 8 files changed, 210 insertions(+), 94 deletions(-) diff --git a/Editor/LLMBuildProcessor.cs b/Editor/LLMBuildProcessor.cs index b98caf0f..23165cbf 100644 --- a/Editor/LLMBuildProcessor.cs +++ b/Editor/LLMBuildProcessor.cs @@ -34,7 +34,7 @@ public void OnPreprocessBuild(BuildReport report) break; } LLMBuilder.HideLibraryPlatforms(platform); - LLMBuilder.CopyModels(); + LLMBuilder.BuildModels(); AssetDatabase.Refresh(); } diff --git a/Editor/LLMEditor.cs b/Editor/LLMEditor.cs index cb482931..b9354cfa 100644 --- a/Editor/LLMEditor.cs +++ b/Editor/LLMEditor.cs @@ -21,7 +21,6 @@ public class LLMEditor : PropertyEditor static GUIContent trashIcon; static List modelOptions; static List modelURLs; - string[] templateOptions; string elementFocus = ""; bool showCustomURL = false; string customURL = ""; @@ -126,7 +125,6 @@ List CreateColumnRects(Rect rect) void UpdateModels(bool resetOptions = false) { - LLMManager.Save(); if (resetOptions) ResetModelOptions(); Repaint(); } @@ -259,7 +257,6 @@ void OnEnable() { LLM llmScript = (LLM)target; ResetModelOptions(); - templateOptions = ChatTemplate.templatesDescription.Keys.ToList().ToArray(); trashIcon = new GUIContent(Resources.Load("llmunity_trash_icon"), "Delete Model"); Texture2D loraLineTexture = new Texture2D(1, 1); loraLineTexture.SetPixel(0, 0, Color.black); @@ -309,12 +306,13 @@ void OnEnable() if (!entry.lora) { - int templateIndex = Array.IndexOf(ChatTemplate.templatesDescription.Values.ToList().ToArray(), entry.chatTemplate); - int newTemplateIndex = EditorGUI.Popup(templateRect, templateIndex, templateOptions); + string[] templateDescriptions = ChatTemplate.templatesDescription.Keys.ToList().ToArray(); + string[] templates = ChatTemplate.templatesDescription.Values.ToList().ToArray(); + int templateIndex = Array.IndexOf(templates, entry.chatTemplate); + int newTemplateIndex = EditorGUI.Popup(templateRect, templateIndex, templateDescriptions); if (newTemplateIndex != templateIndex) { - entry.chatTemplate = ChatTemplate.templatesDescription[templateOptions[newTemplateIndex]]; - if (isSelected) llmScript.SetTemplate(entry.chatTemplate); + LLMManager.SetTemplate(entry.filename, templates[newTemplateIndex]); UpdateModels(); } } @@ -330,7 +328,7 @@ void OnEnable() string newURL = EditorGUI.TextField(urlRect, entry.url); if (newURL != entry.url) { - entry.url = newURL; + LLMManager.SetURL(entry, newURL); UpdateModels(); } } @@ -340,7 +338,7 @@ void OnEnable() bool includeInBuild = EditorGUI.ToggleLeft(includeInBuildRect, "", entry.includeInBuild); if (includeInBuild != entry.includeInBuild) { - entry.includeInBuild = includeInBuild; + LLMManager.SetIncludeInBuild(entry, includeInBuild); UpdateModels(); } diff --git a/Runtime/LLM.cs b/Runtime/LLM.cs index 6db36a56..4e111541 100644 --- a/Runtime/LLM.cs +++ b/Runtime/LLM.cs @@ -79,6 +79,8 @@ public class LLM : MonoBehaviour public bool started { get; protected set; } = false; /// Boolean set to true if the server has failed to start. public bool failed { get; protected set; } = false; + /// Boolean set to true if the server has started and is ready to receive requests, false otherwise. + public static bool modelsDownloaded { get; protected set; } = false; /// the LLM model to use. /// Models with .gguf format are allowed. @@ -100,6 +102,26 @@ public class LLM : MonoBehaviour /// \endcond + /// + /// The Unity Awake function that starts the LLM server. + /// The server can be started asynchronously if the asynchronousStartup option is set. + /// + public async void Awake() + { + if (!enabled) return; +#if !UNITY_EDITOR + await LLMManager.DownloadModels(); +#endif + modelsDownloaded = true; + // if (Application.platform == RuntimePlatform.Android) await AndroidExtractModels(); + string arguments = GetLlamaccpArguments(); + if (arguments == null) return; + if (asynchronousStartup) await Task.Run(() => StartLLMServer(arguments)); + else StartLLMServer(arguments); + if (dontDestroyOnLoad) DontDestroyOnLoad(transform.root.gameObject); + if (basePrompt != "") await SetBasePrompt(basePrompt); + } + #if UNITY_EDITOR public LLMManager llmManager = new LLMManager(); @@ -116,6 +138,12 @@ public async Task WaitUntilReady() while (!started) await Task.Yield(); } + public static async Task WaitUntilModelsDownloaded(Callback downloadProgressCallback = null) + { + if (downloadProgressCallback != null) LLMManager.downloadProgressCallbacks.Add(downloadProgressCallback); + while (!modelsDownloaded) await Task.Yield(); + } + /// /// Allows to set the model used by the LLM. /// The model provided is copied to the Assets/StreamingAssets folder that allows it to also work in the build. @@ -210,24 +238,6 @@ protected virtual string GetLlamaccpArguments() return arguments; } - /// - /// The Unity Awake function that starts the LLM server. - /// The server can be started asynchronously if the asynchronousStartup option is set. - /// - public async void Awake() - { - if (!enabled) return; - // if (downloadOnBuild) await DownloadModels(); - // modelsDownloaded = true; - // if (Application.platform == RuntimePlatform.Android) await AndroidExtractModels(); - string arguments = GetLlamaccpArguments(); - if (arguments == null) return; - if (asynchronousStartup) await Task.Run(() => StartLLMServer(arguments)); - else StartLLMServer(arguments); - if (dontDestroyOnLoad) DontDestroyOnLoad(transform.root.gameObject); - if (basePrompt != "") await SetBasePrompt(basePrompt); - } - private void SetupLogging() { logStreamWrapper = ConstructStreamWrapper(LLMUnitySetup.LogWarning, true); diff --git a/Runtime/LLMBuilder.cs b/Runtime/LLMBuilder.cs index ff433961..db7e80c3 100644 --- a/Runtime/LLMBuilder.cs +++ b/Runtime/LLMBuilder.cs @@ -2,29 +2,31 @@ using UnityEngine; using System.IO; using System.Collections.Generic; -using System; #if UNITY_EDITOR namespace LLMUnity { public class LLMBuilder { - static List movedPairs = new List(); - static string movedCache = Path.Combine(LLMUnitySetup.buildTempDir, "moved.json"); + static List movedPairs = new List(); + static string movedCache = Path.Combine(LLMUnitySetup.BuildTempDir, "moved.json"); [InitializeOnLoadMethod] private static void InitializeOnLoad() { - Directory.CreateDirectory(LLMUnitySetup.buildTempDir); + Directory.CreateDirectory(LLMUnitySetup.BuildTempDir); Reset(); } - public delegate void ActionCallback(string source, string target); - static void AddMovedPair(string source, string target) { - movedPairs.Add(new MovedPair {source = source, target = target}); - File.WriteAllText(movedCache, JsonUtility.ToJson(new FoldersMovedWrapper { movedPairs = movedPairs }, true)); + movedPairs.Add(new StringPair {source = source, target = target}); + File.WriteAllText(movedCache, JsonUtility.ToJson(new ListStringPair { pairs = movedPairs }, true)); + } + + static void AddTargetPair(string target) + { + AddMovedPair("", target); } static bool MoveAction(string source, string target, bool addEntry = true) @@ -46,11 +48,17 @@ static bool CopyAction(string source, string target, bool addEntry = true) else if (Directory.Exists(source)) copyCallback = LLMUnitySetup.CopyPath; else return false; - if (addEntry) AddMovedPair("", target); + if (addEntry) AddTargetPair(target); copyCallback(source, target); return true; } + static void CopyActionAddMeta(string source, string target) + { + CopyAction(source, target); + AddTargetPair(target + ".meta"); + } + static bool DeleteAction(string source) { return LLMUnitySetup.DeletePath(source); @@ -66,7 +74,7 @@ public static void HideLibraryPlatforms(string platform) { if (Path.GetFileName(source).StartsWith(platformPrefix)) { - string target = Path.Combine(LLMUnitySetup.buildTempDir, Path.GetFileName(source)); + string target = Path.Combine(LLMUnitySetup.BuildTempDir, Path.GetFileName(source)); MoveAction(source, target); MoveAction(source + ".meta", target + ".meta"); } @@ -74,23 +82,17 @@ public static void HideLibraryPlatforms(string platform) } } - public static void CopyModels() + public static void BuildModels() { - if (LLMManager.downloadOnStart) return; - foreach (ModelEntry modelEntry in LLMManager.modelEntries) - { - string source = modelEntry.path; - string target = LLMUnitySetup.GetAssetPath(modelEntry.filename); - if (!modelEntry.includeInBuild || File.Exists(target)) continue; - CopyAction(source, target); - AddMovedPair("", target + ".meta"); - } + LLMUnitySetup.DeletePath(LLMUnitySetup.BuildFile); + LLMManager.Build(CopyActionAddMeta); + if (File.Exists(LLMUnitySetup.BuildFile)) AddTargetPair(LLMUnitySetup.BuildFile); } public static void Reset() { if (!File.Exists(movedCache)) return; - List movedPairs = JsonUtility.FromJson(File.ReadAllText(movedCache)).movedPairs; + List movedPairs = JsonUtility.FromJson(File.ReadAllText(movedCache)).pairs; if (movedPairs == null) return; bool refresh = false; @@ -103,18 +105,5 @@ public static void Reset() LLMUnitySetup.DeletePath(movedCache); } } - - [Serializable] - public struct MovedPair - { - public string source; - public string target; - } - - [Serializable] - public class FoldersMovedWrapper - { - public List movedPairs; - } } #endif diff --git a/Runtime/LLMManager.cs b/Runtime/LLMManager.cs index b5d05fac..e644d72b 100644 --- a/Runtime/LLMManager.cs +++ b/Runtime/LLMManager.cs @@ -1,4 +1,3 @@ -#if UNITY_EDITOR using System; using System.Collections.Generic; using System.IO; @@ -8,6 +7,7 @@ namespace LLMUnity { +#if UNITY_EDITOR [Serializable] public class ModelEntry { @@ -26,18 +26,78 @@ public class LLMManagerStore public bool downloadOnStart; public List modelEntries; } +#endif public class LLMManager { + public static float downloadProgress = 1; + public static List> downloadProgressCallbacks = new List>(); + static long totalSize; + static long currFileSize; + static long completedSize; + + public static void SetDownloadProgress(float progress) + { + downloadProgress = (completedSize + progress * currFileSize) / totalSize; + foreach (Callback downloadProgressCallback in downloadProgressCallbacks) downloadProgressCallback?.Invoke(downloadProgress); + } + + public static async Task DownloadModels() + { + if (!File.Exists(LLMUnitySetup.BuildFile)) return; + + List downloads = new List(); + using (FileStream fs = new FileStream(LLMUnitySetup.BuildFile, FileMode.Open, FileAccess.Read)) + { + using (BinaryReader reader = new BinaryReader(fs)) + { + List downloadsToDo = JsonUtility.FromJson(reader.ReadString()).pairs; + foreach (StringPair pair in downloadsToDo) + { + string target = LLMUnitySetup.GetAssetPath(pair.target); + if (!File.Exists(target)) downloads.Add(new StringPair {source = pair.source, target = target}); + } + } + } + if (downloads.Count == 0) return; + + try + { + downloadProgress = 0; + totalSize = 0; + completedSize = 0; + + ResumingWebClient client = new ResumingWebClient(); + Dictionary fileSizes = new Dictionary(); + foreach (StringPair pair in downloads) + { + long size = client.GetURLFileSize(pair.source); + fileSizes[pair.source] = size; + totalSize += size; + } + + foreach (StringPair pair in downloads) + { + currFileSize = fileSizes[pair.source]; + await LLMUnitySetup.DownloadFile(pair.source, pair.target, false, null, SetDownloadProgress); + totalSize += currFileSize; + } + + completedSize = totalSize; + SetDownloadProgress(0); + } + catch (Exception ex) + { + LLMUnitySetup.LogError($"Error downloading the models"); + throw ex; + } + } + +#if UNITY_EDITOR static string LLMManagerPref = "LLMManager"; public static bool downloadOnStart = false; public static List modelEntries = new List(); - /// Boolean set to true if the server has started and is ready to receive requests, false otherwise. - public static bool modelsDownloaded { get; protected set; } = false; - static List> modelProgressCallbacks = new List>(); - static List> loraProgressCallbacks = new List>(); - [HideInInspector] public static float modelProgress = 1; [HideInInspector] public static float loraProgress = 1; static List llms = new List(); @@ -75,15 +135,6 @@ public static string AddEntry(string path, bool lora = false, string label = nul return entry.filename; } - public static async Task WaitUntilModelsDownloaded(Callback modelProgressCallback = null, Callback loraProgressCallback = null) - { - if (modelProgressCallback != null) modelProgressCallbacks.Add(modelProgressCallback); - if (loraProgressCallback != null) loraProgressCallbacks.Add(loraProgressCallback); - while (!modelsDownloaded) await Task.Yield(); - if (modelProgressCallback != null) modelProgressCallbacks.Remove(modelProgressCallback); - if (loraProgressCallback != null) loraProgressCallbacks.Remove(loraProgressCallback); - } - public static async Task Download(string url, bool lora = false, string label = null) { foreach (ModelEntry entry in modelEntries) @@ -147,16 +198,44 @@ public static string LoadLora(string url, string label = null) return Load(url, true, label); } - public static void SetModelTemplate(string filename, string chatTemplate) + public static void SetTemplate(string filename, string chatTemplate) { - foreach (ModelEntry entry in modelEntries) + SetTemplate(Get(filename), chatTemplate); + } + + public static void SetTemplate(ModelEntry entry, string chatTemplate) + { + if (entry == null) return; + entry.chatTemplate = chatTemplate; + foreach (LLM llm in llms) { - if (entry.filename == filename) - { - entry.chatTemplate = chatTemplate; - break; - } + if (llm.model == entry.filename) llm.SetTemplate(chatTemplate); } + Save(); + } + + public static void SetURL(string filename, string url) + { + SetURL(Get(filename), url); + } + + public static void SetURL(ModelEntry entry, string url) + { + if (entry == null) return; + entry.url = url; + Save(); + } + + public static void SetIncludeInBuild(string filename, bool includeInBuild) + { + SetIncludeInBuild(Get(filename), includeInBuild); + } + + public static void SetIncludeInBuild(ModelEntry entry, bool includeInBuild) + { + if (entry == null) return; + entry.includeInBuild = includeInBuild; + Save(); } public static ModelEntry Get(string filename) @@ -218,13 +297,11 @@ public static void Unregister(LLM llm) public static void SetModelProgress(float progress) { modelProgress = progress; - foreach (Callback modelProgressCallback in modelProgressCallbacks) modelProgressCallback?.Invoke(progress); } public static void SetLoraProgress(float progress) { loraProgress = progress; - foreach (Callback loraProgressCallback in loraProgressCallbacks) loraProgressCallback?.Invoke(progress); } public static void Save() @@ -242,6 +319,29 @@ public static void Load() downloadOnStart = store.downloadOnStart; modelEntries = store.modelEntries; } + + public static void Build(ActionCallback copyCallback) + { + List downloads = new List(); + foreach (ModelEntry modelEntry in modelEntries) + { + if (!modelEntry.includeInBuild) continue; + string target = LLMUnitySetup.GetAssetPath(modelEntry.filename); + if (File.Exists(target)) continue; + if (!downloadOnStart) copyCallback(modelEntry.path, target); + else downloads.Add(new StringPair { source = modelEntry.url, target = modelEntry.filename }); + } + + if (downloads.Count > 0) + { + string downloadJSON = JsonUtility.ToJson(new ListStringPair { pairs = downloads }, true); + using (FileStream fs = new FileStream(LLMUnitySetup.BuildFile, FileMode.Create, FileAccess.Write)) + { + using (BinaryWriter writer = new BinaryWriter(fs)) writer.Write(downloadJSON); + } + } + } + +#endif } } -#endif diff --git a/Runtime/LLMUnitySetup.cs b/Runtime/LLMUnitySetup.cs index dc70d709..e3638ccd 100644 --- a/Runtime/LLMUnitySetup.cs +++ b/Runtime/LLMUnitySetup.cs @@ -60,6 +60,20 @@ public NotImplementedException() : base("The method needs to be implemented by s public delegate void Callback(T message); public delegate Task TaskCallback(T message); public delegate T2 ContentCallback(T message); + public delegate void ActionCallback(string source, string target); + + [Serializable] + public struct StringPair + { + public string source; + public string target; + } + + [Serializable] + public class ListStringPair + { + public List pairs; + } /// \endcond /// @ingroup utils @@ -82,7 +96,9 @@ public class LLMUnitySetup /// Model download path public static string modelDownloadPath = Path.Combine(LLMUnityStore, "models"); /// Temporary dir for build - public static string buildTempDir = Path.Combine(Application.temporaryCachePath, "LLMUnityBuild"); + public static string BuildTempDir = Path.Combine(Application.temporaryCachePath, "LLMUnityBuild"); + /// Temporary dir for build + public static string BuildFile = GetAssetPath("LLMUnityBuild.bin"); /// Default models for download [HideInInspector] public static readonly (string, string)[] modelOptions = new(string, string)[] @@ -294,7 +310,7 @@ public static void MovePath(string source, string target) public static bool DeletePath(string path) { - if (!IsSubPath(path, GetAssetPath()) && !IsSubPath(path, buildTempDir)) + if (!IsSubPath(path, GetAssetPath()) && !IsSubPath(path, BuildTempDir)) { LogError($"Safeguard: {path} will not be deleted because it may not be safe"); return false; diff --git a/Runtime/ResumingWebClient.cs b/Runtime/ResumingWebClient.cs index b6e8d4d5..8a678d5c 100644 --- a/Runtime/ResumingWebClient.cs +++ b/Runtime/ResumingWebClient.cs @@ -19,7 +19,12 @@ public ResumingWebClient() _context = SynchronizationContext.Current ?? new SynchronizationContext(); } - private long GetRemoteFileSizeAsync(Uri address) + public long GetURLFileSize(string address) + { + return GetURLFileSize(new Uri(address)); + } + + public long GetURLFileSize(Uri address) { WebRequest request = GetWebRequest(address); request.Method = "HEAD"; @@ -45,7 +50,7 @@ public Task DownloadFileTaskAsyncResume(Uri address, string fileName, bool resum WebRequest request = GetWebRequest(address); if (request is HttpWebRequest webRequest && bytesToSkip > 0) { - long remoteFileSize = GetRemoteFileSizeAsync(address); + long remoteFileSize = GetURLFileSize(address); if (bytesToSkip >= remoteFileSize) { LLMUnitySetup.Log($"File is already fully downloaded: {fileName}"); diff --git a/Samples~/AndroidDemo/AndroidDemo.cs b/Samples~/AndroidDemo/AndroidDemo.cs index 681f6dfb..c697fb46 100644 --- a/Samples~/AndroidDemo/AndroidDemo.cs +++ b/Samples~/AndroidDemo/AndroidDemo.cs @@ -8,7 +8,6 @@ namespace LLMUnitySamples { public class AndroidDemo : MonoBehaviour { - public LLM llm; public LLMCharacter llmCharacter; public GameObject ChatPanel; @@ -32,14 +31,13 @@ async Task ShowDownloadScreen() { ChatPanel.SetActive(false); DownloadPanel.SetActive(true); - await llm.WaitUntilModelDownloaded(SetProgress); + await LLM.WaitUntilModelsDownloaded(SetProgress); DownloadPanel.SetActive(false); ChatPanel.SetActive(true); } async Task WarmUp() { - llm.SetTemplate("alpaca"); cores = LLMUnitySetup.AndroidGetNumBigCores(); AIText.text += $"Warming up the model...\nWill use {cores} cores"; await llmCharacter.Warmup(); From a785c2b8395a08e85e2cd8e7950867d11ce4056c Mon Sep 17 00:00:00 2001 From: Antonis Makropoulos Date: Fri, 26 Jul 2024 18:22:59 +0300 Subject: [PATCH 18/26] extract files on android --- Runtime/LLM.cs | 11 ++++++++++- Runtime/LLMCharacter.cs | 3 ++- Runtime/LLMManager.cs | 3 ++- Runtime/LLMUnitySetup.cs | 10 ++++++---- 4 files changed, 20 insertions(+), 7 deletions(-) diff --git a/Runtime/LLM.cs b/Runtime/LLM.cs index 4e111541..0d6f1d2b 100644 --- a/Runtime/LLM.cs +++ b/Runtime/LLM.cs @@ -113,7 +113,7 @@ public async void Awake() await LLMManager.DownloadModels(); #endif modelsDownloaded = true; - // if (Application.platform == RuntimePlatform.Android) await AndroidExtractModels(); + await AndroidSetup(); string arguments = GetLlamaccpArguments(); if (arguments == null) return; if (asynchronousStartup) await Task.Run(() => StartLLMServer(arguments)); @@ -122,6 +122,15 @@ public async void Awake() if (basePrompt != "") await SetBasePrompt(basePrompt); } + public async Task AndroidSetup() + { + if (Application.platform != RuntimePlatform.Android) return; + foreach (string path in new string[] {model, lora}) + { + if (path != "" && !File.Exists(LLMUnitySetup.GetAssetPath(path))) await LLMUnitySetup.AndroidExtractFile(path); + } + } + #if UNITY_EDITOR public LLMManager llmManager = new LLMManager(); diff --git a/Runtime/LLMCharacter.cs b/Runtime/LLMCharacter.cs index cd3a192c..4a80fd9c 100644 --- a/Runtime/LLMCharacter.cs +++ b/Runtime/LLMCharacter.cs @@ -319,11 +319,12 @@ public async Task LoadTemplate() /// Set the grammar file of the LLMCharacter /// /// path to the grammar file - public void SetGrammar(string path) + public async void SetGrammar(string path) { #if UNITY_EDITOR if (!EditorApplication.isPlaying) path = LLMUnitySetup.AddAsset(path); #endif + if (Application.platform == RuntimePlatform.Android) await LLMUnitySetup.AndroidExtractFile(path); grammar = path; InitGrammar(); } diff --git a/Runtime/LLMManager.cs b/Runtime/LLMManager.cs index e644d72b..6d2fd0d2 100644 --- a/Runtime/LLMManager.cs +++ b/Runtime/LLMManager.cs @@ -44,6 +44,7 @@ public static void SetDownloadProgress(float progress) public static async Task DownloadModels() { + if (Application.platform == RuntimePlatform.Android) await LLMUnitySetup.AndroidExtractFile(LLMUnitySetup.BuildFilename); if (!File.Exists(LLMUnitySetup.BuildFile)) return; List downloads = new List(); @@ -80,7 +81,7 @@ public static async Task DownloadModels() { currFileSize = fileSizes[pair.source]; await LLMUnitySetup.DownloadFile(pair.source, pair.target, false, null, SetDownloadProgress); - totalSize += currFileSize; + completedSize += currFileSize; } completedSize = totalSize; diff --git a/Runtime/LLMUnitySetup.cs b/Runtime/LLMUnitySetup.cs index e3638ccd..07157409 100644 --- a/Runtime/LLMUnitySetup.cs +++ b/Runtime/LLMUnitySetup.cs @@ -97,8 +97,10 @@ public class LLMUnitySetup public static string modelDownloadPath = Path.Combine(LLMUnityStore, "models"); /// Temporary dir for build public static string BuildTempDir = Path.Combine(Application.temporaryCachePath, "LLMUnityBuild"); - /// Temporary dir for build - public static string BuildFile = GetAssetPath("LLMUnityBuild.bin"); + /// Name of file with build information for runtime + public static string BuildFilename = "LLMUnityBuild.bin"; + /// Path of file with build information for runtime + public static string BuildFile = GetAssetPath(BuildFilename); /// Default models for download [HideInInspector] public static readonly (string, string)[] modelOptions = new(string, string)[] @@ -237,13 +239,13 @@ public static async Task DownloadFile( callback?.Invoke(savePath); } - public static async Task AndroidExtractFile(string assetName, bool overwrite = false, int chunkSize = 1024*1024) + public static async Task AndroidExtractFile(string assetName, bool overwrite = false, bool log = true, int chunkSize = 1024*1024) { string source = "jar:file://" + Application.dataPath + "!/assets/" + assetName; string target = GetAssetPath(assetName); if (!overwrite && File.Exists(target)) { - Log($"File {target} already exists"); + if (log) Log($"File {target} already exists"); return; } From 6556cd65f5bdaf5d2b28d38b7e21931a560538d9 Mon Sep 17 00:00:00 2001 From: Antonis Makropoulos Date: Fri, 26 Jul 2024 18:29:21 +0300 Subject: [PATCH 19/26] move android libraries to Plugins --- Runtime/LLMUnitySetup.cs | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/Runtime/LLMUnitySetup.cs b/Runtime/LLMUnitySetup.cs index 07157409..599d70dd 100644 --- a/Runtime/LLMUnitySetup.cs +++ b/Runtime/LLMUnitySetup.cs @@ -334,6 +334,14 @@ private static async Task DownloadLibrary() await DownloadFile(LlamaLibURL, libZip, true, null, SetLibraryProgress); AssetDatabase.StartAssetEditing(); ZipFile.ExtractToDirectory(libZip, libraryPath); + foreach (string librarySubPath in Directory.GetDirectories(libraryPath)) + { + if (Path.GetFileName(librarySubPath).StartsWith("android")) + { + string pluginPath = Path.Combine(Application.dataPath, "Plugins", "Android", Path.GetFileName(librarySubPath)); + Directory.Move(librarySubPath, pluginPath); + } + } AssetDatabase.StopAssetEditing(); File.Delete(libZip); } From 5d40198c1b6b17aedf20d149ce6da45a20a8c980 Mon Sep 17 00:00:00 2001 From: Antonis Makropoulos Date: Fri, 26 Jul 2024 18:36:23 +0300 Subject: [PATCH 20/26] move android library directly --- Runtime/LLMUnitySetup.cs | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/Runtime/LLMUnitySetup.cs b/Runtime/LLMUnitySetup.cs index 599d70dd..5786f77b 100644 --- a/Runtime/LLMUnitySetup.cs +++ b/Runtime/LLMUnitySetup.cs @@ -334,6 +334,12 @@ private static async Task DownloadLibrary() await DownloadFile(LlamaLibURL, libZip, true, null, SetLibraryProgress); AssetDatabase.StartAssetEditing(); ZipFile.ExtractToDirectory(libZip, libraryPath); + string androidDir = Path.Combine(libraryPath, "android"); + if (Directory.Exists(androidDir)) + { + string androidPluginDir = Path.Combine(Application.dataPath, "Plugins", "Android", Path.GetFileName(libraryPath)); + Directory.Move(androidDir, androidPluginDir); + } foreach (string librarySubPath in Directory.GetDirectories(libraryPath)) { if (Path.GetFileName(librarySubPath).StartsWith("android")) From 94a6f4b8ac321f9f395f35061247ccec4f6d9d68 Mon Sep 17 00:00:00 2001 From: Antonis Makropoulos Date: Fri, 26 Jul 2024 18:36:48 +0300 Subject: [PATCH 21/26] bump LlamaLib to v1.1.6 --- Runtime/LLMUnitySetup.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Runtime/LLMUnitySetup.cs b/Runtime/LLMUnitySetup.cs index 5786f77b..499ee0d6 100644 --- a/Runtime/LLMUnitySetup.cs +++ b/Runtime/LLMUnitySetup.cs @@ -86,7 +86,7 @@ public class LLMUnitySetup /// LLM for Unity version public static string Version = "v2.0.3"; /// LlamaLib version - public static string LlamaLibVersion = "v1.1.5"; + public static string LlamaLibVersion = "v1.1.6"; /// LlamaLib url public static string LlamaLibURL = $"https://github.com/undreamai/LlamaLib/releases/download/{LlamaLibVersion}/undreamai-{LlamaLibVersion}-llamacpp.zip"; /// LlamaLib path From af3f14bd9fa0ba6ade29419b1bbff8c2d08d859e Mon Sep 17 00:00:00 2001 From: Antonis Makropoulos Date: Fri, 26 Jul 2024 18:49:18 +0300 Subject: [PATCH 22/26] create plugin dir first --- Runtime/LLMUnitySetup.cs | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/Runtime/LLMUnitySetup.cs b/Runtime/LLMUnitySetup.cs index 499ee0d6..17b1d2b9 100644 --- a/Runtime/LLMUnitySetup.cs +++ b/Runtime/LLMUnitySetup.cs @@ -337,8 +337,9 @@ private static async Task DownloadLibrary() string androidDir = Path.Combine(libraryPath, "android"); if (Directory.Exists(androidDir)) { - string androidPluginDir = Path.Combine(Application.dataPath, "Plugins", "Android", Path.GetFileName(libraryPath)); - Directory.Move(androidDir, androidPluginDir); + string androidPluginDir = Path.Combine(Application.dataPath, "Plugins", "Android"); + Directory.CreateDirectory(androidPluginDir); + Directory.Move(androidDir, Path.Combine(androidPluginDir, Path.GetFileName(libraryPath))); } foreach (string librarySubPath in Directory.GetDirectories(libraryPath)) { From e495b62d04a2ca2d3660260214d5ec4c79b99ecf Mon Sep 17 00:00:00 2001 From: Antonis Makropoulos Date: Fri, 26 Jul 2024 18:49:33 +0300 Subject: [PATCH 23/26] remove null warning --- Runtime/ResumingWebClient.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Runtime/ResumingWebClient.cs b/Runtime/ResumingWebClient.cs index 8a678d5c..b282caa2 100644 --- a/Runtime/ResumingWebClient.cs +++ b/Runtime/ResumingWebClient.cs @@ -35,7 +35,7 @@ public long GetURLFileSize(Uri address) public Task DownloadFileTaskAsyncResume(Uri address, string fileName, bool resume = false, Callback progressCallback = null) { var tcs = new TaskCompletionSource(address); - FileStream? fs = null; + FileStream fs = null; long bytesToSkip = 0; try From 849f4d312277b7d77ef1d75a3765c096065c0122 Mon Sep 17 00:00:00 2001 From: Antonis Makropoulos Date: Fri, 26 Jul 2024 19:35:48 +0300 Subject: [PATCH 24/26] allow only one download access --- Runtime/LLMManager.cs | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/Runtime/LLMManager.cs b/Runtime/LLMManager.cs index 6d2fd0d2..c01eec8d 100644 --- a/Runtime/LLMManager.cs +++ b/Runtime/LLMManager.cs @@ -32,6 +32,8 @@ public class LLMManager { public static float downloadProgress = 1; public static List> downloadProgressCallbacks = new List>(); + static Task downloadModelsTask; + static readonly object lockObject = new object(); static long totalSize; static long currFileSize; static long completedSize; @@ -42,7 +44,16 @@ public static void SetDownloadProgress(float progress) foreach (Callback downloadProgressCallback in downloadProgressCallbacks) downloadProgressCallback?.Invoke(downloadProgress); } - public static async Task DownloadModels() + public static Task DownloadModels() + { + lock (lockObject) + { + if (downloadModelsTask == null) downloadModelsTask = DownloadModelsOnce(); + } + return downloadModelsTask; + } + + public static async Task DownloadModelsOnce() { if (Application.platform == RuntimePlatform.Android) await LLMUnitySetup.AndroidExtractFile(LLMUnitySetup.BuildFilename); if (!File.Exists(LLMUnitySetup.BuildFile)) return; From abe062825000c148ad28606ceaa5138c936e8a0f Mon Sep 17 00:00:00 2001 From: Antonis Makropoulos Date: Fri, 26 Jul 2024 19:53:09 +0300 Subject: [PATCH 25/26] capture and expose download errors --- Runtime/LLM.cs | 14 +++++++++----- Runtime/LLMManager.cs | 15 ++++++++------- 2 files changed, 17 insertions(+), 12 deletions(-) diff --git a/Runtime/LLM.cs b/Runtime/LLM.cs index 0d6f1d2b..544d9727 100644 --- a/Runtime/LLM.cs +++ b/Runtime/LLM.cs @@ -79,8 +79,10 @@ public class LLM : MonoBehaviour public bool started { get; protected set; } = false; /// Boolean set to true if the server has failed to start. public bool failed { get; protected set; } = false; + /// Boolean set to true if the models were not downloaded successfully. + public static bool downloadFailed { get; protected set; } = false; /// Boolean set to true if the server has started and is ready to receive requests, false otherwise. - public static bool modelsDownloaded { get; protected set; } = false; + public static bool downloadComplete { get; protected set; } = false; /// the LLM model to use. /// Models with .gguf format are allowed. @@ -110,9 +112,10 @@ public async void Awake() { if (!enabled) return; #if !UNITY_EDITOR - await LLMManager.DownloadModels(); + downloadFailed = !await LLMManager.DownloadModels(); #endif - modelsDownloaded = true; + downloadComplete = true; + if (downloadFailed) return; await AndroidSetup(); string arguments = GetLlamaccpArguments(); if (arguments == null) return; @@ -147,10 +150,11 @@ public async Task WaitUntilReady() while (!started) await Task.Yield(); } - public static async Task WaitUntilModelsDownloaded(Callback downloadProgressCallback = null) + public static async Task WaitUntilModelsDownloaded(Callback downloadProgressCallback = null) { if (downloadProgressCallback != null) LLMManager.downloadProgressCallbacks.Add(downloadProgressCallback); - while (!modelsDownloaded) await Task.Yield(); + while (!downloadComplete) await Task.Yield(); + return !downloadFailed; } /// diff --git a/Runtime/LLMManager.cs b/Runtime/LLMManager.cs index c01eec8d..efbb33d7 100644 --- a/Runtime/LLMManager.cs +++ b/Runtime/LLMManager.cs @@ -32,7 +32,7 @@ public class LLMManager { public static float downloadProgress = 1; public static List> downloadProgressCallbacks = new List>(); - static Task downloadModelsTask; + static Task downloadModelsTask; static readonly object lockObject = new object(); static long totalSize; static long currFileSize; @@ -44,7 +44,7 @@ public static void SetDownloadProgress(float progress) foreach (Callback downloadProgressCallback in downloadProgressCallbacks) downloadProgressCallback?.Invoke(downloadProgress); } - public static Task DownloadModels() + public static Task DownloadModels() { lock (lockObject) { @@ -53,10 +53,10 @@ public static Task DownloadModels() return downloadModelsTask; } - public static async Task DownloadModelsOnce() + public static async Task DownloadModelsOnce() { if (Application.platform == RuntimePlatform.Android) await LLMUnitySetup.AndroidExtractFile(LLMUnitySetup.BuildFilename); - if (!File.Exists(LLMUnitySetup.BuildFile)) return; + if (!File.Exists(LLMUnitySetup.BuildFile)) return true; List downloads = new List(); using (FileStream fs = new FileStream(LLMUnitySetup.BuildFile, FileMode.Open, FileAccess.Read)) @@ -71,7 +71,7 @@ public static async Task DownloadModelsOnce() } } } - if (downloads.Count == 0) return; + if (downloads.Count == 0) return true; try { @@ -100,9 +100,10 @@ public static async Task DownloadModelsOnce() } catch (Exception ex) { - LLMUnitySetup.LogError($"Error downloading the models"); - throw ex; + LLMUnitySetup.LogError($"Error downloading the models: {ex.Message}"); + return false; } + return true; } #if UNITY_EDITOR From 3c656ef2d6d1975da2c62868db01354cb138564f Mon Sep 17 00:00:00 2001 From: Antonis Makropoulos Date: Fri, 26 Jul 2024 19:53:36 +0300 Subject: [PATCH 26/26] fix android dll name --- Runtime/LLMLib.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Runtime/LLMLib.cs b/Runtime/LLMLib.cs index 95cd5576..c8d49ccd 100644 --- a/Runtime/LLMLib.cs +++ b/Runtime/LLMLib.cs @@ -404,7 +404,7 @@ public static string GetArchitecturePath(string arch) } else if (Application.platform == RuntimePlatform.Android) { - return "libundreamai_android_plugin.so"; + return "libundreamai_android.so"; } else {