Setup allowing to use extra features: flash attention and IQ quants #216


Merged · 8 commits · Aug 21, 2024
1 change: 1 addition & 0 deletions Editor/LLMEditor.cs
@@ -74,6 +74,7 @@ public void AddModelSettings(SerializedObject llmScriptSO)
if (llmScriptSO.FindProperty("advancedOptions").boolValue)
{
attributeClasses.Add(typeof(ModelAdvancedAttribute));
+ if (LLMUnitySetup.FullLlamaLib) attributeClasses.Add(typeof(ModelExtrasAttribute));
}
ShowPropertiesOfClass("", llmScriptSO, attributeClasses, false);
Space();
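
The editor change above only gathers `ModelExtras`-tagged fields when the full llama.cpp library is installed, so extras-only options stay hidden otherwise. Below is a minimal sketch of how such a marker attribute and the conditional filtering could look; treating `ModelExtrasAttribute` (and the other attribute names besides `ModelAdvancedAttribute`) as plain `PropertyAttribute` subclasses is an assumption, not something shown in this diff.

```csharp
using System;
using System.Collections.Generic;
using UnityEngine;

// Assumed shapes: marker attributes used to group inspector fields.
// Only ModelAdvancedAttribute/ModelExtrasAttribute appear in the diff; the rest is illustrative.
public class ModelAttribute : PropertyAttribute {}
public class ModelAdvancedAttribute : PropertyAttribute {}
public class ModelExtrasAttribute : PropertyAttribute {}

public static class ExtrasFilterSketch
{
    // Mirrors the logic added in LLMEditor.AddModelSettings above:
    // extras-tagged fields are only drawn when the full llama.cpp build is installed.
    public static List<Type> VisibleAttributeClasses(bool advancedOptions, bool fullLlamaLib)
    {
        List<Type> attributeClasses = new List<Type> { typeof(ModelAttribute) };
        if (advancedOptions)
        {
            attributeClasses.Add(typeof(ModelAdvancedAttribute));
            if (fullLlamaLib) attributeClasses.Add(typeof(ModelExtrasAttribute));
        }
        return attributeClasses;
    }
}
```
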
18 changes: 8 additions & 10 deletions Editor/PropertyEditor.cs
@@ -16,17 +16,11 @@ public void AddScript(SerializedObject llmScriptSO)
EditorGUILayout.PropertyField(scriptProp);
}

- public void AddOptionsToggle(SerializedObject llmScriptSO, string propertyName, string name)
+ public bool ToggleButton(string text, bool activated)
{
- SerializedProperty advancedOptionsProp = llmScriptSO.FindProperty(propertyName);
- string toggleText = (advancedOptionsProp.boolValue ? "Hide" : "Show") + " " + name;
GUIStyle style = new GUIStyle("Button");
- if (advancedOptionsProp.boolValue)
- style.normal = new GUIStyleState() { background = Texture2D.grayTexture };
- if (GUILayout.Button(toggleText, style, GUILayout.Width(buttonWidth)))
- {
- advancedOptionsProp.boolValue = !advancedOptionsProp.boolValue;
- }
+ if (activated) style.normal = new GUIStyleState() { background = Texture2D.grayTexture };
+ return GUILayout.Button(text, style, GUILayout.Width(buttonWidth));
}

public void AddSetupSettings(SerializedObject llmScriptSO)
@@ -54,8 +48,12 @@ public void AddChatSettings(SerializedObject llmScriptSO)
public void AddOptionsToggles(SerializedObject llmScriptSO)
{
LLMUnitySetup.SetDebugMode((LLMUnitySetup.DebugModeType)EditorGUILayout.EnumPopup("Log Level", LLMUnitySetup.DebugMode));

EditorGUILayout.BeginHorizontal();
- AddOptionsToggle(llmScriptSO, "advancedOptions", "Advanced Options");
+ SerializedProperty advancedOptionsProp = llmScriptSO.FindProperty("advancedOptions");
+ string toggleText = (advancedOptionsProp.boolValue ? "Hide" : "Show") + " Advanced Options";
+ if (ToggleButton(toggleText, advancedOptionsProp.boolValue)) advancedOptionsProp.boolValue = !advancedOptionsProp.boolValue;
+ if (ToggleButton("Use extras", LLMUnitySetup.FullLlamaLib)) LLMUnitySetup.SetFullLlamaLib(!LLMUnitySetup.FullLlamaLib);
EditorGUILayout.EndHorizontal();
Space();
}
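
`ToggleButton` now only draws a button that appears pressed when `activated` is true and returns whether it was clicked; the caller decides what state to flip. That is what lets one helper drive both the serialized `advancedOptions` property and the static `LLMUnitySetup.FullLlamaLib` flag. A hedged sketch of reusing it for a third flag inside the same editor class; the `showTokenStats` flag and its EditorPrefs key are hypothetical, not part of the PR.

```csharp
// Hypothetical extra toggle, assumed to live in a method of the same editor class (sketch only).
bool showTokenStats = EditorPrefs.GetBool("LLMUnity.showTokenStats", false);
EditorGUILayout.BeginHorizontal();
if (ToggleButton("Token stats", showTokenStats))
{
    // ToggleButton reports the click; the caller owns the state change.
    EditorPrefs.SetBool("LLMUnity.showTokenStats", !showTokenStats);
}
EditorGUILayout.EndHorizontal();
```
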
3 changes: 3 additions & 0 deletions README.md
@@ -345,6 +345,7 @@ If you have loaded a model locally you need to set its URL through the expanded

- `Show/Hide Advanced Options` Toggle to show/hide advanced options from below
- `Log Level` select how verbose the log messages are
+ - `Use extras` select to install and allow the use of extra features (flash attention and IQ quants)

#### 💻 Setup Settings

@@ -381,6 +382,7 @@ If the user's GPU is not supported, the LLM will fall back to the CPU
- `Model` the path of the model being used (relative to the Assets/StreamingAssets folder)
- `Chat Template` the chat template being used for the LLM
- `Lora` the path of the LoRA being used (relative to the Assets/StreamingAssets folder)
+ - `Flash Attention` click to use flash attention in the model (if `Use extras` is enabled)

</details>

@@ -395,6 +397,7 @@

- `Show/Hide Advanced Options` Toggle to show/hide advanced options from below
- `Log Level` select how verbose the log messages are
+ - `Use extras` select to install and allow the use of extra features (flash attention and IQ quants)

#### 💻 Setup Settings
<div>
9 changes: 6 additions & 3 deletions Runtime/LLM.cs
@@ -22,7 +22,7 @@ public LLMException(string message, int errorCode) : base(message)
}
}

- public class DestroyException : Exception { }
+ public class DestroyException : Exception {}
/// \endcond

[DefaultExecutionOrder(-1)]
@@ -74,6 +74,8 @@ public class LLM : MonoBehaviour
/// <summary> the paths of the LORA models being used (relative to the Assets/StreamingAssets folder).
/// Models with .gguf format are allowed.</summary>
[ModelAdvanced] public string lora = "";
+ /// <summary> enable use of flash attention </summary>
+ [ModelExtras] public bool flashAttention = false;

/// \cond HIDE

@@ -297,6 +299,7 @@ protected virtual string GetLlamaccpArguments()
if (numThreadsToUse > 0) arguments += $" -t {numThreadsToUse}";
arguments += loraArgument;
arguments += $" -ngl {numGPULayers}";
+ if (LLMUnitySetup.FullLlamaLib && flashAttention) arguments += $" --flash-attn";
return arguments;
}

@@ -383,7 +386,7 @@ private void StartService()
{
llmThread = new Thread(() => llmlib.LLM_Start(LLMObject));
llmThread.Start();
- while (!llmlib.LLM_Started(LLMObject)) { }
+ while (!llmlib.LLM_Started(LLMObject)) {}
loraWeights = new List<float>();
for (int i = 0; i < lora.Split(" ").Count(); i++) loraWeights.Add(1f);
started = true;
@@ -607,7 +610,7 @@ public async Task<string> Slot(string json)
public async Task<string> Completion(string json, Callback<string> streamCallback = null)
{
AssertStarted();
- if (streamCallback == null) streamCallback = (string s) => { };
+ if (streamCallback == null) streamCallback = (string s) => {};
StreamWrapper streamWrapper = ConstructStreamWrapper(streamCallback);
await Task.Run(() => llmlib.LLM_Completion(LLMObject, json, streamWrapper.GetStringWrapper()));
if (!started) return null;
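
At runtime the new field has exactly one effect: ` --flash-attn` is appended to the llama.cpp server arguments, and only when both the extras library is installed and the per-model toggle is on. A rough usage sketch from a script, assuming the `LLMUnity` namespace and an `LLM` component on the same GameObject:

```csharp
using LLMUnity;
using UnityEngine;

public class FlashAttentionSetup : MonoBehaviour
{
    void Awake()
    {
        LLM llm = GetComponent<LLM>();
        // Request flash attention; it only reaches llama.cpp (" --flash-attn")
        // when the "Use extras" (full) libraries are installed, per GetLlamaccpArguments above.
        llm.flashAttention = true;
    }
}
```
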
10 changes: 7 additions & 3 deletions Runtime/LLMBuilder.cs
@@ -104,15 +104,19 @@ static void AddActionAddMeta(string target)

public static void HideLibraryPlatforms(string platform)
{
- List<string> platforms = new List<string>(){ "windows", "macos", "linux", "android", "ios" };
+ List<string> platforms = new List<string>(){ "windows", "macos", "linux", "android", "ios", "setup" };
platforms.Remove(platform);
foreach (string source in Directory.GetDirectories(LLMUnitySetup.libraryPath))
{
+ string sourceName = Path.GetFileName(source);
foreach (string platformPrefix in platforms)
{
- if (Path.GetFileName(source).StartsWith(platformPrefix))
+ bool move = sourceName.StartsWith(platformPrefix);
+ move = move || (sourceName.Contains("cuda") && !sourceName.Contains("full") && LLMUnitySetup.FullLlamaLib);
+ move = move || (sourceName.Contains("cuda") && sourceName.Contains("full") && !LLMUnitySetup.FullLlamaLib);
+ if (move)
{
- string target = Path.Combine(BuildTempDir, Path.GetFileName(source));
+ string target = Path.Combine(BuildTempDir, sourceName);
MoveAction(source, target);
MoveAction(source + ".meta", target + ".meta");
}
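
`HideLibraryPlatforms` now filters by flavour as well as platform: it moves out CUDA folders that do not match the current `Use extras` setting (plain `cuda` folders when the full build is selected, `-full` folders when it is not), and the new `"setup"` entry keeps the setup folder out of every platform build. A self-contained sketch of just that predicate; the wrapper class and the example folder name in the comment are illustrative, not taken from the PR.

```csharp
using System.Collections.Generic;
using System.Linq;

public static class LibraryFilterSketch
{
    // Should a library folder be moved out of the current build?
    // Mirrors the predicate added to HideLibraryPlatforms above (sketch, not the exact code).
    // e.g. ShouldHide("windows-cuda-cu12.2.0", fullLlamaLib: true, otherPlatforms) would be true (illustrative name).
    public static bool ShouldHide(string sourceName, bool fullLlamaLib, List<string> otherPlatforms)
    {
        bool move = otherPlatforms.Any(prefix => sourceName.StartsWith(prefix));
        // Full build selected: the plain CUDA folders are not needed.
        move = move || (sourceName.Contains("cuda") && !sourceName.Contains("full") && fullLlamaLib);
        // Plain build selected: the "-full" CUDA folders are not needed.
        move = move || (sourceName.Contains("cuda") && sourceName.Contains("full") && !fullLlamaLib);
        return move;
    }
}
```
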
12 changes: 10 additions & 2 deletions Runtime/LLMLib.cs
@@ -326,8 +326,16 @@ public static List<string> PossibleArchitectures(bool gpu = false)
{
if (gpu)
{
architectures.Add("cuda-cu12.2.0");
architectures.Add("cuda-cu11.7.1");
if (LLMUnitySetup.FullLlamaLib)
{
architectures.Add("cuda-cu12.2.0-full");
architectures.Add("cuda-cu11.7.1-full");
}
else
{
architectures.Add("cuda-cu12.2.0");
architectures.Add("cuda-cu11.7.1");
}
architectures.Add("hip");
architectures.Add("vulkan");
}
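
On the loading side the fallback order is unchanged; only the CUDA entries are swapped for their `-full` variants when extras are enabled, so the runtime probes the flavour that was actually downloaded. A condensed sketch of the resulting GPU candidates (helper name is ours, not from the PR):

```csharp
// Sketch: which CUDA builds get probed, depending on the extras setting.
static string[] CudaCandidates(bool fullLlamaLib)
{
    return fullLlamaLib
        ? new[] { "cuda-cu12.2.0-full", "cuda-cu11.7.1-full" }
        : new[] { "cuda-cu12.2.0", "cuda-cu11.7.1" };
}
// In both cases "hip" and "vulkan" follow as further GPU fallbacks.
```
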