diff --git a/Editor/LLMEditor.cs b/Editor/LLMEditor.cs
index ed50046d..9601579f 100644
--- a/Editor/LLMEditor.cs
+++ b/Editor/LLMEditor.cs
@@ -74,6 +74,7 @@ public void AddModelSettings(SerializedObject llmScriptSO)
if (llmScriptSO.FindProperty("advancedOptions").boolValue)
{
attributeClasses.Add(typeof(ModelAdvancedAttribute));
+ if (LLMUnitySetup.FullLlamaLib) attributeClasses.Add(typeof(ModelExtrasAttribute));
}
ShowPropertiesOfClass("", llmScriptSO, attributeClasses, false);
Space();
diff --git a/Editor/PropertyEditor.cs b/Editor/PropertyEditor.cs
index 87a40938..d3510ac8 100644
--- a/Editor/PropertyEditor.cs
+++ b/Editor/PropertyEditor.cs
@@ -16,17 +16,11 @@ public void AddScript(SerializedObject llmScriptSO)
EditorGUILayout.PropertyField(scriptProp);
}
- public void AddOptionsToggle(SerializedObject llmScriptSO, string propertyName, string name)
+ public bool ToggleButton(string text, bool activated) // draws a fixed-width button labelled `text` and returns the result of GUILayout.Button
{
- SerializedProperty advancedOptionsProp = llmScriptSO.FindProperty(propertyName);
- string toggleText = (advancedOptionsProp.boolValue ? "Hide" : "Show") + " " + name;
GUIStyle style = new GUIStyle("Button");
- if (advancedOptionsProp.boolValue)
- style.normal = new GUIStyleState() { background = Texture2D.grayTexture };
- if (GUILayout.Button(toggleText, style, GUILayout.Width(buttonWidth)))
- {
- advancedOptionsProp.boolValue = !advancedOptionsProp.boolValue;
- }
+ if (activated) style.normal = new GUIStyleState() { background = Texture2D.grayTexture }; // grey background marks the option as currently on
+ return GUILayout.Button(text, style, GUILayout.Width(buttonWidth));
}
public void AddSetupSettings(SerializedObject llmScriptSO)
@@ -54,8 +48,12 @@ public void AddChatSettings(SerializedObject llmScriptSO)
public void AddOptionsToggles(SerializedObject llmScriptSO)
{
LLMUnitySetup.SetDebugMode((LLMUnitySetup.DebugModeType)EditorGUILayout.EnumPopup("Log Level", LLMUnitySetup.DebugMode));
+
EditorGUILayout.BeginHorizontal();
- AddOptionsToggle(llmScriptSO, "advancedOptions", "Advanced Options");
+ SerializedProperty advancedOptionsProp = llmScriptSO.FindProperty("advancedOptions");
+ string toggleText = (advancedOptionsProp.boolValue ? "Hide" : "Show") + " Advanced Options"; // label describes what a click will do
+ if (ToggleButton(toggleText, advancedOptionsProp.boolValue)) advancedOptionsProp.boolValue = !advancedOptionsProp.boolValue; // flip the serialized flag on click
+ if (ToggleButton("Use extras", LLMUnitySetup.FullLlamaLib)) LLMUnitySetup.SetFullLlamaLib(!LLMUnitySetup.FullLlamaLib); // static setting held by LLMUnitySetup, not a property of the component
EditorGUILayout.EndHorizontal();
Space();
}
diff --git a/README.md b/README.md
index d8e79bba..bdf6350b 100644
--- a/README.md
+++ b/README.md
@@ -345,6 +345,7 @@ If you have loaded a model locally you need to set its URL through the expanded
- `Show/Hide Advanced Options` Toggle to show/hide advanced options from below
- `Log Level` select how verbose the log messages are
+- `Use extras` select to download the extended library and enable its extra features (flash attention and IQ quants)
#### 💻 Setup Settings
@@ -381,6 +382,7 @@ If the user's GPU is not supported, the LLM will fall back to the CPU
- `Model` the path of the model being used (relative to the Assets/StreamingAssets folder)
- `Chat Template` the chat template being used for the LLM
- `Lora` the path of the LoRA being used (relative to the Assets/StreamingAssets folder)
+ - `Flash Attention` click to enable flash attention for the model (only takes effect when `Use extras` is enabled)
@@ -395,6 +397,7 @@ If the user's GPU is not supported, the LLM will fall back to the CPU
- `Show/Hide Advanced Options` Toggle to show/hide advanced options from below
- `Log Level` select how verbose the log messages are
+- `Use extras` select to download the extended library and enable its extra features (flash attention and IQ quants)
#### 💻 Setup Settings
diff --git a/Runtime/LLM.cs b/Runtime/LLM.cs
index 54395e8c..9e50c513 100644
--- a/Runtime/LLM.cs
+++ b/Runtime/LLM.cs
@@ -22,7 +22,7 @@ public LLMException(string message, int errorCode) : base(message)
}
}
- public class DestroyException : Exception { }
+ public class DestroyException : Exception {}
/// \endcond
[DefaultExecutionOrder(-1)]
@@ -74,6 +74,8 @@ public class LLM : MonoBehaviour
/// the paths of the LORA models being used (relative to the Assets/StreamingAssets folder).
/// Models with .gguf format are allowed.
[ModelAdvanced] public string lora = "";
+ /// enable use of flash attention (only passed to the server when the full LlamaLib extras are enabled)
+ [ModelExtras] public bool flashAttention = false;
/// \cond HIDE
@@ -297,6 +299,7 @@ protected virtual string GetLlamaccpArguments()
if (numThreadsToUse > 0) arguments += $" -t {numThreadsToUse}";
arguments += loraArgument;
arguments += $" -ngl {numGPULayers}";
+ if (LLMUnitySetup.FullLlamaLib && flashAttention) arguments += $" --flash-attn";
return arguments;
}
@@ -383,7 +386,7 @@ private void StartService()
{
llmThread = new Thread(() => llmlib.LLM_Start(LLMObject));
llmThread.Start();
- while (!llmlib.LLM_Started(LLMObject)) { }
+ while (!llmlib.LLM_Started(LLMObject)) {}
loraWeights = new List();
for (int i = 0; i < lora.Split(" ").Count(); i++) loraWeights.Add(1f);
started = true;
@@ -607,7 +610,7 @@ public async Task Slot(string json)
public async Task Completion(string json, Callback streamCallback = null)
{
AssertStarted();
- if (streamCallback == null) streamCallback = (string s) => { };
+ if (streamCallback == null) streamCallback = (string s) => {};
StreamWrapper streamWrapper = ConstructStreamWrapper(streamCallback);
await Task.Run(() => llmlib.LLM_Completion(LLMObject, json, streamWrapper.GetStringWrapper()));
if (!started) return null;
diff --git a/Runtime/LLMBuilder.cs b/Runtime/LLMBuilder.cs
index a3cf0070..542520f9 100644
--- a/Runtime/LLMBuilder.cs
+++ b/Runtime/LLMBuilder.cs
@@ -104,15 +104,19 @@ static void AddActionAddMeta(string target)
public static void HideLibraryPlatforms(string platform)
{
- List platforms = new List(){ "windows", "macos", "linux", "android", "ios" };
+ List platforms = new List(){ "windows", "macos", "linux", "android", "ios", "setup" };
platforms.Remove(platform);
foreach (string source in Directory.GetDirectories(LLMUnitySetup.libraryPath))
{
+ string sourceName = Path.GetFileName(source);
foreach (string platformPrefix in platforms)
{
- if (Path.GetFileName(source).StartsWith(platformPrefix))
+ bool move = sourceName.StartsWith(platformPrefix);
+ move = move || (sourceName.Contains("cuda") && !sourceName.Contains("full") && LLMUnitySetup.FullLlamaLib);
+ move = move || (sourceName.Contains("cuda") && sourceName.Contains("full") && !LLMUnitySetup.FullLlamaLib);
+ if (move)
{
- string target = Path.Combine(BuildTempDir, Path.GetFileName(source));
+ string target = Path.Combine(BuildTempDir, sourceName);
MoveAction(source, target);
MoveAction(source + ".meta", target + ".meta");
}
diff --git a/Runtime/LLMLib.cs b/Runtime/LLMLib.cs
index edb60601..e3501769 100644
--- a/Runtime/LLMLib.cs
+++ b/Runtime/LLMLib.cs
@@ -326,8 +326,16 @@ public static List PossibleArchitectures(bool gpu = false)
{
if (gpu)
{
- architectures.Add("cuda-cu12.2.0");
- architectures.Add("cuda-cu11.7.1");
+ if (LLMUnitySetup.FullLlamaLib)
+ {
+ architectures.Add("cuda-cu12.2.0-full");
+ architectures.Add("cuda-cu11.7.1-full");
+ }
+ else
+ {
+ architectures.Add("cuda-cu12.2.0");
+ architectures.Add("cuda-cu11.7.1");
+ }
architectures.Add("hip");
architectures.Add("vulkan");
}
diff --git a/Runtime/LLMUnitySetup.cs b/Runtime/LLMUnitySetup.cs
index 285eee78..428c2da5 100644
--- a/Runtime/LLMUnitySetup.cs
+++ b/Runtime/LLMUnitySetup.cs
@@ -47,6 +47,7 @@ public class ModelAttribute : PropertyAttribute {}
public class ModelDownloadAttribute : ModelAttribute {}
public class ModelDownloadAdvancedAttribute : ModelAdvancedAttribute {}
public class ModelAdvancedAttribute : PropertyAttribute {}
+ public class ModelExtrasAttribute : PropertyAttribute {}
public class ChatAttribute : PropertyAttribute {}
public class ChatAdvancedAttribute : PropertyAttribute {}
public class LLMUnityAttribute : PropertyAttribute {}
@@ -87,8 +88,12 @@ public class LLMUnitySetup
public static string Version = "v2.1.2";
/// LlamaLib version
public static string LlamaLibVersion = "v1.1.8";
+ /// LlamaLib release url
+ public static string LlamaLibReleaseURL = $"https://github.com/undreamai/LlamaLib/releases/download/{LlamaLibVersion}";
/// LlamaLib url
- public static string LlamaLibURL = $"https://github.com/undreamai/LlamaLib/releases/download/{LlamaLibVersion}/undreamai-{LlamaLibVersion}-llamacpp.zip";
+ public static string LlamaLibURL = $"{LlamaLibReleaseURL}/undreamai-{LlamaLibVersion}-llamacpp.zip";
+ /// LlamaLib extension url
+ public static string LlamaLibExtensionURL = $"{LlamaLibReleaseURL}/undreamai-{LlamaLibVersion}-llamacpp-full.zip";
/// LlamaLib path
public static string libraryPath = GetAssetPath(Path.GetFileName(LlamaLibURL).Replace(".zip", ""));
/// LLMnity store path
@@ -109,25 +114,15 @@ public class LLMUnitySetup
("Qwen 2 0.5B (tiny, useful for mobile)", "https://huggingface.co/Qwen/Qwen2-0.5B-Instruct-GGUF/resolve/main/qwen2-0_5b-instruct-q4_k_m.gguf?download=true", null),
};
- /// Add callback function to call for error logs
- public static void AddErrorCallBack(Callback callback)
- {
- errorCallbacks.Add(callback);
- }
-
- /// Remove callback function added for error logs
- public static void RemoveErrorCallBack(Callback callback)
- {
- errorCallbacks.Remove(callback);
- }
-
- /// Remove all callback function added for error logs
- public static void ClearErrorCallBacks()
- {
- errorCallbacks.Clear();
- }
-
/// \cond HIDE
+ [LLMUnity] public static DebugModeType DebugMode = DebugModeType.All;
+ static string DebugModeKey = "DebugMode";
+ public static bool FullLlamaLib = false;
+ static string FullLlamaLibKey = "FullLlamaLib";
+ static List> errorCallbacks = new List>();
+ static readonly object lockObject = new object();
+ static Dictionary androidExtractTasks = new Dictionary();
+
public enum DebugModeType
{
All,
@@ -135,10 +130,6 @@ public enum DebugModeType
Error,
None
}
- [LLMUnity] public static DebugModeType DebugMode = DebugModeType.All;
- static List> errorCallbacks = new List>();
- static readonly object lockObject = new object();
- static Dictionary androidExtractTasks = new Dictionary();
public static void Log(string message)
{
@@ -159,10 +150,10 @@ public static void LogError(string message)
foreach (Callback errorCallback in errorCallbacks) errorCallback(message);
}
- static string DebugModeKey = "DebugMode";
- static void LoadDebugMode()
+ static void LoadPlayerPrefs() // restores the settings kept in PlayerPrefs: log level and the extras flag
{
DebugMode = (DebugModeType)PlayerPrefs.GetInt(DebugModeKey, (int)DebugModeType.All);
+ FullLlamaLib = PlayerPrefs.GetInt(FullLlamaLibKey, 0) == 1; // stored as 0/1, defaults to disabled
}
public static void SetDebugMode(DebugModeType newDebugMode)
@@ -173,6 +164,18 @@ public static void SetDebugMode(DebugModeType newDebugMode)
PlayerPrefs.Save();
}
+#if UNITY_EDITOR
+ /// Stores the "use extras" choice in PlayerPrefs and starts the library download.
+ /// Does nothing when the value is unchanged.
+ public static void SetFullLlamaLib(bool value)
+ {
+     bool changed = FullLlamaLib != value;
+     if (changed)
+     {
+         FullLlamaLib = value;
+         PlayerPrefs.SetInt(FullLlamaLibKey, FullLlamaLib ? 1 : 0);
+         PlayerPrefs.Save();
+         // fire-and-forget: the returned task is deliberately not awaited
+         _ = DownloadLibrary();
+     }
+ }
+
+#endif
+
public static string GetAssetPath(string relPath = "")
{
// Path to store llm server binaries and models
@@ -184,15 +187,15 @@ public static string GetAssetPath(string relPath = "")
[InitializeOnLoadMethod]
static async Task InitializeOnLoad()
{
+ LoadPlayerPrefs(); // settings are loaded first because DownloadLibrary reads FullLlamaLib
await DownloadLibrary();
- LoadDebugMode();
}
#else
+ /// Runtime counterpart of the editor initializer: restores the persisted settings before the first scene loads.
+ /// Declared static because Unity only invokes static methods marked with RuntimeInitializeOnLoadMethod.
[RuntimeInitializeOnLoadMethod(RuntimeInitializeLoadType.BeforeSceneLoad)]
- void InitializeOnLoad()
+ static void InitializeOnLoad()
{
- LoadDebugMode();
+ LoadPlayerPrefs();
}
#endif
@@ -307,40 +310,89 @@ public static bool IsSubPath(string childPath, string parentPath)
[HideInInspector] public static float libraryProgress = 1;
- private static async Task DownloadLibrary()
+ /// Creates an empty file at `path` and releases the file handle right away
+ static void CreateEmptyFile(string path)
+ {
+     using (FileStream stream = File.Create(path))
+     {
+         // nothing is written; the handle is closed when this scope ends
+     }
+ }
+
+ /// Extracts every file entry of the zip archive at `zipPath` below `extractPath`,
+ /// creating sub-directories as needed. Existing files are replaced when `overwrite` is true.
+ /// Throws an IOException when an entry would be written outside `extractPath`.
+ static void ExtractInsideDirectory(string zipPath, string extractPath, bool overwrite = true)
{
- if (libraryProgress < 1) return;
- libraryProgress = 0;
- string libZip = Path.Combine(Application.temporaryCachePath, Path.GetFileName(LlamaLibURL));
- if (!Directory.Exists(libraryPath))
+     // canonical root with a trailing separator, used to reject entries that escape it ("zip slip")
+     string rootPath = Path.GetFullPath(extractPath);
+     if (!rootPath.EndsWith(Path.DirectorySeparatorChar.ToString(), StringComparison.Ordinal)) rootPath += Path.DirectorySeparatorChar;
+
+     using (ZipArchive archive = ZipFile.OpenRead(zipPath))
{
- await DownloadFile(LlamaLibURL, libZip, true, null, SetLibraryProgress);
+         foreach (ZipArchiveEntry entry in archive.Entries)
+         {
+             // entries without a file name are directories; they get created on demand below
+             if (string.IsNullOrEmpty(entry.Name)) continue;
+             string destinationPath = Path.GetFullPath(Path.Combine(rootPath, entry.FullName));
+             // a rooted or ".."-laden entry name resolves outside the root and must not be written
+             if (!destinationPath.StartsWith(rootPath, StringComparison.Ordinal))
+             {
+                 throw new IOException($"Zip entry '{entry.FullName}' would be extracted outside of '{extractPath}'");
+             }
+             Directory.CreateDirectory(Path.GetDirectoryName(destinationPath));
+             entry.ExtractToFile(destinationPath, overwrite);
+         }
+     }
+ }
+
+ /// Downloads the zip archive at `url` and extracts it inside `path`.
+ /// A marker file named after the archive with a ".complete" suffix is written to `setupDir`
+ /// after extraction; the call returns immediately when that marker already exists.
+ static async Task DownloadAndExtractInsideDirectory(string url, string path, string setupDir)
+ {
+     string urlName = Path.GetFileName(url);
+     string setupFile = Path.Combine(setupDir, urlName + ".complete");
+     if (File.Exists(setupFile)) return;
+
+     string zipPath = Path.Combine(Application.temporaryCachePath, urlName);
+     await DownloadFile(url, zipPath, true, null, SetLibraryProgress);
+
+     AssetDatabase.StartAssetEditing();
+     try
+     {
+         ExtractInsideDirectory(zipPath, path);
+         // written last, so a failed extraction leaves no marker and is retried on the next call
+         CreateEmptyFile(setupFile);
+     }
+     finally
+     {
+         // every StartAssetEditing needs a matching StopAssetEditing, even when extraction throws
+         AssetDatabase.StopAssetEditing();
+     }
+
+     File.Delete(zipPath);
+ }
+
+ /// Downloads and installs the LlamaLib binaries (plus the extension archive when FullLlamaLib is set)
+ /// and moves the extracted "android" folder to the project's Plugins/Android directory.
+ /// A lock file in the library's "setup" folder makes a second call return while an installation is running.
+ /// Failures are reported through LogError instead of being thrown.
+ /// NOTE(review): a lock file left behind by an editor crash still blocks later runs — consider a staleness check.
+ static async Task DownloadLibrary()
+ {
+     void DeleteFileAndMeta(string path)
+     {
+         if (File.Exists(path + ".meta")) File.Delete(path + ".meta");
+         if (File.Exists(path)) File.Delete(path);
+     }
+
+     // set only after this call has created the lock file, so a lock owned by another run is never removed
+     string ownedLockFile = null;
+     try
+     {
+         string setupDir = Path.Combine(libraryPath, "setup");
+         Directory.CreateDirectory(setupDir);
+
+         string lockFile = Path.Combine(setupDir, "LLMUnitySetup.lock");
+         if (File.Exists(lockFile)) return;
+         CreateEmptyFile(lockFile);
+         ownedLockFile = lockFile;
+
+         libraryProgress = 0;
+         await DownloadAndExtractInsideDirectory(LlamaLibURL, libraryPath, setupDir);
+
AssetDatabase.StartAssetEditing();
- ZipFile.ExtractToDirectory(libZip, libraryPath);
+         try
+         {
string androidDir = Path.Combine(libraryPath, "android");
if (Directory.Exists(androidDir))
{
- string androidPluginDir = Path.Combine(Application.dataPath, "Plugins", "Android");
- Directory.CreateDirectory(androidPluginDir);
- Directory.Move(androidDir, Path.Combine(androidPluginDir, Path.GetFileName(libraryPath)));
- }
- foreach (string librarySubPath in Directory.GetDirectories(libraryPath))
- {
- if (Path.GetFileName(librarySubPath).StartsWith("android"))
- {
- string pluginPath = Path.Combine(Application.dataPath, "Plugins", "Android", Path.GetFileName(librarySubPath));
- Directory.Move(librarySubPath, pluginPath);
- }
+             string androidPluginsDir = Path.Combine(Application.dataPath, "Plugins", "Android");
+             Directory.CreateDirectory(androidPluginsDir);
+             string pluginDir = Path.Combine(androidPluginsDir, Path.GetFileName(libraryPath));
+             // replace a previously installed copy so Directory.Move does not fail on an existing target
+             if (Directory.Exists(pluginDir)) Directory.Delete(pluginDir, true);
+             Directory.Move(androidDir, pluginDir);
+             if (File.Exists(androidDir + ".meta")) File.Delete(androidDir + ".meta");
}
+         }
+         finally
+         {
+             // always paired with StartAssetEditing above, even when the move throws
AssetDatabase.StopAssetEditing();
- File.Delete(libZip);
+         }
+
+         if (FullLlamaLib) await DownloadAndExtractInsideDirectory(LlamaLibExtensionURL, libraryPath, setupDir);
+     }
+     catch (Exception e)
+     {
+         LogError(e.Message);
}
- libraryProgress = 1;
+     finally
+     {
+         // release the lock and reset the progress on failure too; otherwise one failed
+         // download would leave the lock in place and block every later attempt
+         if (ownedLockFile != null)
+         {
+             libraryProgress = 1;
+             DeleteFileAndMeta(ownedLockFile);
+         }
+     }
}
private static void SetLibraryProgress(float progress)
{
- libraryProgress = progress;
+ libraryProgress = Math.Min(0.99f, progress); // capped below 1: DownloadLibrary itself sets 1 once installation has finished
}
public static string AddAsset(string assetPath)
@@ -364,6 +416,25 @@ public static string AddAsset(string assetPath)
#endif
/// \endcond
+
+ /// Register a callback function to be called for error logs
+ public static void AddErrorCallBack(Callback callback) => errorCallbacks.Add(callback);
+
+ /// Unregister a callback function previously added for error logs
+ public static void RemoveErrorCallBack(Callback callback) => errorCallbacks.Remove(callback);
+
+ /// Unregister every callback function added for error logs
+ public static void ClearErrorCallBacks() => errorCallbacks.Clear();
+
public static int GetMaxFreqKHz(int cpuId)
{
string[] paths = new string[]