Skip to content

Fix support for extras (flash attention, iQ quants) #292

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Dec 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions Editor/LLMEditor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ public override void AddModelSettings(SerializedObject llmScriptSO)
if (llmScriptSO.FindProperty("advancedOptions").boolValue)
{
attributeClasses.Add(typeof(ModelAdvancedAttribute));
if (LLMUnitySetup.FullLlamaLib) attributeClasses.Add(typeof(ModelExtrasAttribute));
}
ShowPropertiesOfClass("", llmScriptSO, attributeClasses, false);
Space();
Expand Down Expand Up @@ -444,12 +445,18 @@ private void CopyToClipboard(string text)
te.Copy();
}

public void AddExtrasToggle()
{
if (ToggleButton("Use extras", LLMUnitySetup.FullLlamaLib)) LLMUnitySetup.SetFullLlamaLib(!LLMUnitySetup.FullLlamaLib);
}

public override void AddOptionsToggles(SerializedObject llmScriptSO)
{
AddDebugModeToggle();

EditorGUILayout.BeginHorizontal();
AddAdvancedOptionsToggle(llmScriptSO);
AddExtrasToggle();
EditorGUILayout.EndHorizontal();
Space();
}
Expand Down
5 changes: 4 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -499,7 +499,8 @@ Save the scene, run and enjoy!
### LLM Settings

- `Show/Hide Advanced Options` Toggle to show/hide advanced options from below
- `Log Level` select how verbose the log messages arequants)
- `Log Level` select how verbose the log messages are
- `Use extras` select to install and allow the use of extra features (flash attention and IQ quants)

#### 💻 Setup Settings

Expand Down Expand Up @@ -550,13 +551,15 @@ If the user's GPU is not supported, the LLM will fall back to the CPU
- `Chat Template` the chat template being used for the LLM
- `Lora` the path of the LoRAs being used (relative to the Assets/StreamingAssets folder)
- `Lora Weights` the weights of the LoRAs being used
- `Flash Attention` click to use flash attention in the model (if `Use extras` is enabled)

</details>

### LLMCharacter Settings

- `Show/Hide Advanced Options` Toggle to show/hide advanced options from below
- `Log Level` select how verbose the log messages are
- `Use extras` select to install and allow the use of extra features (flash attention and IQ quants)

#### 💻 Setup Settings
<div>
Expand Down
3 changes: 3 additions & 0 deletions Runtime/LLM.cs
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,8 @@ public class LLM : MonoBehaviour
[ModelAdvanced] public string lora = "";
/// <summary> the weights of the LORA models being used.</summary>
[ModelAdvanced] public string loraWeights = "";
/// <summary> enable use of flash attention </summary>
[ModelExtras] public bool flashAttention = false;

/// <summary> API key to use for the server (optional) </summary>
public string APIKey;
Expand Down Expand Up @@ -430,6 +432,7 @@ protected virtual string GetLlamaccpArguments()
if (numThreadsToUse > 0) arguments += $" -t {numThreadsToUse}";
arguments += loraArgument;
arguments += $" -ngl {numGPULayers}";
if (LLMUnitySetup.FullLlamaLib && flashAttention) arguments += $" --flash-attn";
if (remote)
{
arguments += $" --port {port} --host 0.0.0.0";
Expand Down
2 changes: 2 additions & 0 deletions Runtime/LLMBuilder.cs
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,8 @@ public static void BuildLibraryPlatforms(string platform)
foreach (string platformPrefix in platforms)
{
bool move = sourceName.StartsWith(platformPrefix);
move = move || (sourceName.Contains("cuda") && !sourceName.Contains("full") && LLMUnitySetup.FullLlamaLib);
move = move || (sourceName.Contains("cuda") && sourceName.Contains("full") && !LLMUnitySetup.FullLlamaLib);
if (move)
{
string target = Path.Combine(BuildTempDir, sourceName);
Expand Down
12 changes: 10 additions & 2 deletions Runtime/LLMLib.cs
Original file line number Diff line number Diff line change
Expand Up @@ -645,8 +645,16 @@ public static List<string> PossibleArchitectures(bool gpu = false)
{
if (gpu)
{
architectures.Add("cuda-cu12.2.0");
architectures.Add("cuda-cu11.7.1");
if (LLMUnitySetup.FullLlamaLib)
{
architectures.Add("cuda-cu12.2.0-full");
architectures.Add("cuda-cu11.7.1-full");
}
else
{
architectures.Add("cuda-cu12.2.0");
architectures.Add("cuda-cu11.7.1");
}
architectures.Add("hip");
architectures.Add("vulkan");
}
Expand Down
23 changes: 22 additions & 1 deletion Runtime/LLMUnitySetup.cs
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ public class LocalRemoteAttribute : PropertyAttribute {}
public class RemoteAttribute : PropertyAttribute {}
public class LocalAttribute : PropertyAttribute {}
public class ModelAttribute : PropertyAttribute {}
public class ModelExtrasAttribute : PropertyAttribute {}
public class ChatAttribute : PropertyAttribute {}
public class LLMUnityAttribute : PropertyAttribute {}

Expand Down Expand Up @@ -102,7 +103,7 @@ public class LLMUnitySetup
/// <summary> LLM for Unity version </summary>
public static string Version = "v2.4.1";
/// <summary> LlamaLib version </summary>
public static string LlamaLibVersion = "v1.2.0-dev";
public static string LlamaLibVersion = "v1.2.1";
/// <summary> LlamaLib release url </summary>
public static string LlamaLibReleaseURL = $"https://github.com/undreamai/LlamaLib/releases/download/{LlamaLibVersion}";
/// <summary> LlamaLib name </summary>
Expand All @@ -111,6 +112,8 @@ public class LLMUnitySetup
public static string libraryPath = GetAssetPath(libraryName);
/// <summary> LlamaLib url </summary>
public static string LlamaLibURL = $"{LlamaLibReleaseURL}/{libraryName}.zip";
/// <summary> LlamaLib extension url </summary>
public static string LlamaLibExtensionURL = $"{LlamaLibReleaseURL}/{libraryName}-full.zip";
/// <summary> LLMnity store path </summary>
public static string LLMUnityStore = Path.Combine(Environment.GetFolderPath(Environment.SpecialFolder.ApplicationData), "LLMUnity");
/// <summary> Model download path </summary>
Expand Down Expand Up @@ -150,6 +153,8 @@ public class LLMUnitySetup
/// \cond HIDE
[LLMUnity] public static DebugModeType DebugMode = DebugModeType.All;
static string DebugModeKey = "DebugMode";
public static bool FullLlamaLib = false;
static string FullLlamaLibKey = "FullLlamaLib";
static List<Callback<string>> errorCallbacks = new List<Callback<string>>();
static readonly object lockObject = new object();
static Dictionary<string, Task> androidExtractTasks = new Dictionary<string, Task>();
Expand Down Expand Up @@ -184,6 +189,7 @@ public static void LogError(string message)
static void LoadPlayerPrefs()
{
DebugMode = (DebugModeType)PlayerPrefs.GetInt(DebugModeKey, (int)DebugModeType.All);
FullLlamaLib = PlayerPrefs.GetInt(FullLlamaLibKey, 0) == 1;
}

public static void SetDebugMode(DebugModeType newDebugMode)
Expand All @@ -194,6 +200,18 @@ public static void SetDebugMode(DebugModeType newDebugMode)
PlayerPrefs.Save();
}

#if UNITY_EDITOR
public static void SetFullLlamaLib(bool value)
{
if (FullLlamaLib == value) return;
FullLlamaLib = value;
PlayerPrefs.SetInt(FullLlamaLibKey, value ? 1 : 0);
PlayerPrefs.Save();
_ = DownloadLibrary();
}

#endif

public static string GetLibraryName(string version)
{
return $"undreamai-{version}-llamacpp";
Expand Down Expand Up @@ -436,6 +454,9 @@ static async Task DownloadLibrary()

// setup LlamaLib in StreamingAssets
await DownloadAndExtractInsideDirectory(LlamaLibURL, libraryPath, setupDir);

// setup LlamaLib extras in StreamingAssets
if (FullLlamaLib) await DownloadAndExtractInsideDirectory(LlamaLibExtensionURL, libraryPath, setupDir);
}
catch (Exception e)
{
Expand Down
Loading