Skip to content

Commit f4c3a7f

Browse files
committed
download or extract model and lora
1 parent 061388c commit f4c3a7f

File tree

1 file changed

+43
-20
lines changed

1 file changed

+43
-20
lines changed

Runtime/LLM.cs

+43-20
Original file line numberDiff line numberDiff line change
@@ -68,17 +68,20 @@ public class LLM : MonoBehaviour
6868
[LLMAdvanced] public bool asynchronousStartup = true;
6969
/// <summary> select to not destroy the LLM GameObject when loading a new Scene. </summary>
7070
[LLMAdvanced] public bool dontDestroyOnLoad = true;
71+
/// <summary> toggle to enable model download on build </summary>
72+
[Model] public bool downloadOnBuild = false;
7173
/// <summary> the path of the model being used (relative to the Assets/StreamingAssets folder).
7274
/// Models with .gguf format are allowed.</summary>
7375
[Model] public string model = "";
74-
/// <summary> toggle to enable model download on build </summary>
75-
[Model] public bool downloadOnBuild = false;
7676
/// <summary> the URL of the model to use.
7777
/// Models with .gguf format are allowed.</summary>
7878
[ModelDownload] public string modelURL = "";
7979
/// <summary> the path of the LORA model being used (relative to the Assets/StreamingAssets folder).
8080
/// Models with .bin format are allowed.</summary>
8181
[ModelAdvanced] public string lora = "";
82+
/// <summary> the URL of the LORA to use.
83+
/// Models with .bin format are allowed.</summary>
84+
[ModelDownloadAdvanced] public string loraURL = "";
8285
/// <summary> Size of the prompt context (0 = context size of the model).
8386
/// This is the number of tokens the model can take as input when generating responses. </summary>
8487
[ModelAdvanced] public int contextSize = 0;
@@ -87,7 +90,7 @@ public class LLM : MonoBehaviour
8790
/// <summary> a base prompt to use as a base for all LLMCharacter objects </summary>
8891
[TextArea(5, 10), ChatAdvanced] public string basePrompt = "";
8992
/// <summary> Boolean set to true if the server has started and is ready to receive requests, false otherwise. </summary>
90-
public bool modelDownloaded { get; protected set; } = false;
93+
public bool modelsDownloaded { get; protected set; } = false;
9194
/// <summary> Boolean set to true if the server has started and is ready to receive requests, false otherwise. </summary>
9295
public bool started { get; protected set; } = false;
9396
/// <summary> Boolean set to true if the server has failed to start. </summary>
@@ -96,6 +99,7 @@ public class LLM : MonoBehaviour
9699
/// \cond HIDE
97100
public int SelectedModel = 0;
98101
[HideInInspector] public float modelProgress = 1;
102+
[HideInInspector] public float loraProgress = 1;
99103
[HideInInspector] public float modelCopyProgress = 1;
100104
[HideInInspector] public bool modelHide = true;
101105

@@ -107,12 +111,19 @@ public class LLM : MonoBehaviour
107111
StreamWrapper logStreamWrapper = null;
108112
Thread llmThread = null;
109113
List<StreamWrapper> streamWrappers = new List<StreamWrapper>();
110-
List<Callback<float>> progressCallbacks = new List<Callback<float>>();
114+
List<Callback<float>> modelProgressCallbacks = new List<Callback<float>>();
115+
List<Callback<float>> loraProgressCallbacks = new List<Callback<float>>();
111116

112117
public void SetModelProgress(float progress)
113118
{
114119
modelProgress = progress;
115-
foreach (Callback<float> progressCallback in progressCallbacks) progressCallback?.Invoke(progress);
120+
foreach (Callback<float> modelProgressCallback in modelProgressCallbacks) modelProgressCallback?.Invoke(progress);
121+
}
122+
123+
public void SetLoraProgress(float progress)
124+
{
125+
loraProgress = progress;
126+
foreach (Callback<float> loraProgressCallback in loraProgressCallbacks) loraProgressCallback?.Invoke(progress);
116127
}
117128

118129
/// \endcond
@@ -152,28 +163,39 @@ public async Task DownloadDefaultModel(int optionIndex)
152163
await DownloadModel(modelUrl, modelName);
153164
}
154165

155-
public async Task DownloadModel(string modelUrl, string modelName, Callback<float> progressCallback = null, bool overwrite = false)
166+
public async Task DownloadModel(string modelUrl, string modelName, bool overwrite = false)
156167
{
157168
modelProgress = 0;
158169
string modelPath = LLMUnitySetup.GetAssetPath(modelName);
159-
Callback<float> callback = (floatArg) =>
160-
{
161-
progressCallback?.Invoke(floatArg);
162-
SetModelProgress(floatArg);
163-
};
164-
await LLMUnitySetup.DownloadFile(modelUrl, modelPath, overwrite, SetModel, callback);
170+
await LLMUnitySetup.DownloadFile(modelUrl, modelPath, overwrite, SetModel, SetModelProgress);
171+
}
172+
173+
public async Task DownloadLora(string loraUrl, string loraName, bool overwrite = false)
174+
{
175+
loraProgress = 0;
176+
string loraPath = LLMUnitySetup.GetAssetPath(loraName);
177+
await LLMUnitySetup.DownloadFile(loraUrl, loraPath, overwrite, SetLora, SetLoraProgress);
178+
}
179+
180+
public async Task DownloadModels()
181+
{
182+
if (modelURL != "") await DownloadModel(modelURL, model);
183+
if (loraURL != "") await DownloadLora(loraURL, lora);
165184
}
166185

167-
public async Task DownloadModel()
186+
public async Task AndroidExtractModels()
168187
{
169-
await DownloadModel(modelURL, model);
188+
if (!downloadOnBuild || modelURL == "") await LLMUnitySetup.AndroidExtractFile(model);
189+
if (!downloadOnBuild || loraURL == "") await LLMUnitySetup.AndroidExtractFile(lora);
170190
}
171191

172-
public async Task WaitUntilModelDownloaded(Callback<float> progressCallback = null)
192+
public async Task WaitUntilModelsDownloaded(Callback<float> modelProgressCallback = null, Callback<float> loraProgressCallback = null)
173193
{
174-
if (progressCallback != null) progressCallbacks.Add(progressCallback);
175-
while (!modelDownloaded) await Task.Yield();
176-
if (progressCallback != null) progressCallbacks.Remove(progressCallback);
194+
if (modelProgressCallback != null) modelProgressCallbacks.Add(modelProgressCallback);
195+
if (loraProgressCallback != null) loraProgressCallbacks.Add(loraProgressCallback);
196+
while (!modelsDownloaded) await Task.Yield();
197+
if (modelProgressCallback != null) modelProgressCallbacks.Remove(modelProgressCallback);
198+
if (loraProgressCallback != null) loraProgressCallbacks.Remove(loraProgressCallback);
177199
}
178200

179201
public async Task WaitUntilReady()
@@ -274,8 +296,9 @@ protected virtual string GetLlamaccpArguments()
274296
public async void Awake()
275297
{
276298
if (!enabled) return;
277-
if (downloadOnBuild) await DownloadModel();
278-
modelDownloaded = true;
299+
if (downloadOnBuild) await DownloadModels();
300+
modelsDownloaded = true;
301+
if (Application.platform == RuntimePlatform.Android) await AndroidExtractModels();
279302
string arguments = GetLlamaccpArguments();
280303
if (arguments == null) return;
281304
if (asynchronousStartup) await Task.Run(() => StartLLMServer(arguments));

0 commit comments

Comments
 (0)