Skip to content

Commit d5bd618

Browse files
committed
download on start functionality
1 parent ee9d050 commit d5bd618

File tree

2 files changed

+42
-8
lines changed

2 files changed

+42
-8
lines changed

Runtime/LLM.cs

+39-7
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,11 @@ public class LLM : MonoBehaviour
7171
/// <summary> the path of the model being used (relative to the Assets/StreamingAssets folder).
7272
/// Models with .gguf format are allowed.</summary>
7373
[Model] public string model = "";
74+
/// <summary> toggle to enable model download on build </summary>
75+
[Model] public bool downloadOnBuild = false;
76+
/// <summary> the URL of the model to use.
77+
/// Models with .gguf format are allowed.</summary>
78+
[ModelDownload] public string modelURL = "";
7479
/// <summary> the path of the LORA model being used (relative to the Assets/StreamingAssets folder).
7580
/// Models with .bin format are allowed.</summary>
7681
[ModelAdvanced] public string lora = "";
@@ -81,7 +86,8 @@ public class LLM : MonoBehaviour
8186
[ModelAdvanced] public int batchSize = 512;
8287
/// <summary> a base prompt to use as a base for all LLMCharacter objects </summary>
8388
[TextArea(5, 10), ChatAdvanced] public string basePrompt = "";
84-
89+
/// <summary> Boolean set to true if the server has started and is ready to receive requests, false otherwise. </summary>
90+
public bool modelDownloaded { get; protected set; } = false;
8591
/// <summary> Boolean set to true if the server has started and is ready to receive requests, false otherwise. </summary>
8692
public bool started { get; protected set; } = false;
8793
/// <summary> Boolean set to true if the server has failed to start. </summary>
@@ -101,10 +107,12 @@ public class LLM : MonoBehaviour
101107
StreamWrapper logStreamWrapper = null;
102108
Thread llmThread = null;
103109
List<StreamWrapper> streamWrappers = new List<StreamWrapper>();
110+
List<Callback<float>> progressCallbacks = new List<Callback<float>>();
104111

105112
public void SetModelProgress(float progress)
106113
{
107114
modelProgress = progress;
115+
foreach (Callback<float> progressCallback in progressCallbacks) progressCallback?.Invoke(progress);
108116
}
109117

110118
/// \endcond
@@ -122,22 +130,32 @@ async Task<string> CopyAsset(string path)
122130
return path;
123131
}
124132

133+
public void ResetSelectedModel()
134+
{
135+
SelectedModel = 0;
136+
modelURL = "";
137+
model = "";
138+
}
139+
125140
public async Task DownloadDefaultModel(int optionIndex)
126141
{
127142
// download default model and disable model editor properties until the model is set
143+
if (optionIndex == 0)
144+
{
145+
ResetSelectedModel();
146+
return;
147+
}
128148
SelectedModel = optionIndex;
129149
string modelUrl = LLMUnitySetup.modelOptions[optionIndex].Item2;
130-
if (modelUrl == null) return;
150+
modelURL = modelUrl;
131151
string modelName = Path.GetFileName(modelUrl).Split("?")[0];
132152
await DownloadModel(modelUrl, modelName);
133153
}
134154

135-
public async Task DownloadModel(string modelUrl, string modelName = null, Callback<float> progressCallback = null, bool overwrite = false)
155+
public async Task DownloadModel(string modelUrl, string modelName, Callback<float> progressCallback = null, bool overwrite = false)
136156
{
137157
modelProgress = 0;
138-
if (modelName == null) modelName = model;
139158
string modelPath = LLMUnitySetup.GetAssetPath(modelName);
140-
141159
Callback<float> callback = (floatArg) =>
142160
{
143161
progressCallback?.Invoke(floatArg);
@@ -146,9 +164,21 @@ public async Task DownloadModel(string modelUrl, string modelName = null, Callba
146164
await LLMUnitySetup.DownloadFile(modelUrl, modelPath, overwrite, SetModel, callback);
147165
}
148166

149-
public async Task DownloadModel(string modelUrl, Callback<float> progressCallback = null, bool overwrite = false)
167+
public async Task DownloadModel()
168+
{
169+
await DownloadModel(modelURL, model);
170+
}
171+
172+
public async Task WaitUntilModelDownloaded(Callback<float> progressCallback = null)
173+
{
174+
if (progressCallback != null) progressCallbacks.Add(progressCallback);
175+
while (!modelDownloaded) await Task.Yield();
176+
if (progressCallback != null) progressCallbacks.Remove(progressCallback);
177+
}
178+
179+
public async Task WaitUntilReady()
150180
{
151-
await DownloadModel(modelUrl, null, progressCallback, overwrite);
181+
while (!started) await Task.Yield();
152182
}
153183

154184
/// <summary>
@@ -244,6 +274,8 @@ protected virtual string GetLlamaccpArguments()
244274
public async void Awake()
245275
{
246276
if (!enabled) return;
277+
if (downloadOnBuild) await DownloadModel();
278+
modelDownloaded = true;
247279
string arguments = GetLlamaccpArguments();
248280
if (arguments == null) return;
249281
if (asynchronousStartup) await Task.Run(() => StartLLMServer(arguments));

Runtime/LLMUnitySetup.cs

+3-1
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ public class LocalRemoteAttribute : PropertyAttribute {}
4343
public class RemoteAttribute : PropertyAttribute {}
4444
public class LocalAttribute : PropertyAttribute {}
4545
public class ModelAttribute : PropertyAttribute {}
46+
public class ModelDownloadAttribute : PropertyAttribute {}
4647
public class ModelAdvancedAttribute : PropertyAttribute {}
4748
public class ChatAttribute : PropertyAttribute {}
4849
public class ChatAdvancedAttribute : PropertyAttribute {}
@@ -149,7 +150,8 @@ public static void SetDebugMode(DebugModeType newDebugMode)
149150
public static string GetAssetPath(string relPath = "")
150151
{
151152
// Path to store llm server binaries and models
152-
return Path.Combine(Application.streamingAssetsPath, relPath).Replace('\\', '/');
153+
string assetsDir = Application.platform == RuntimePlatform.Android ? Application.persistentDataPath : Application.streamingAssetsPath;
154+
return Path.Combine(assetsDir, relPath).Replace('\\', '/');
153155
}
154156

155157
#if UNITY_EDITOR

0 commit comments

Comments
 (0)