@@ -71,6 +71,11 @@ public class LLM : MonoBehaviour
     /// <summary> the path of the model being used (relative to the Assets/StreamingAssets folder).
     /// Models with .gguf format are allowed.</summary>
     [Model] public string model = "";
+    /// <summary> toggle to enable model download on build </summary>
+    [Model] public bool downloadOnBuild = false;
+    /// <summary> the URL of the model to use.
+    /// Models with .gguf format are allowed.</summary>
+    [ModelDownload] public string modelURL = "";
     /// <summary> the path of the LORA model being used (relative to the Assets/StreamingAssets folder).
     /// Models with .bin format are allowed.</summary>
     [ModelAdvanced] public string lora = "";
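
For orientation, a minimal setup sketch under stated assumptions: it sets the two new fields from code instead of the inspector, assumes the LLMUnity namespace, and uses a placeholder URL, file name, and LLMBootstrap class that are not part of this diff. The host object is created inactive so that LLM.Awake only runs once the fields are assigned.

// Hypothetical bootstrap (not from this PR): enable runtime model download in code.
using UnityEngine;
using LLMUnity;

public class LLMBootstrap : MonoBehaviour
{
    public LLM BuildLLM()
    {
        // Keep the host inactive so LLM.Awake does not fire before the fields are set.
        GameObject host = new GameObject("LLM");
        host.SetActive(false);
        LLM llm = host.AddComponent<LLM>();
        llm.downloadOnBuild = true;                            // download the model at runtime instead of bundling it
        llm.modelURL = "https://example.com/some-model.gguf";  // placeholder .gguf URL
        llm.model = "some-model.gguf";                         // local file name under Assets/StreamingAssets
        host.SetActive(true);                                  // LLM.Awake now runs and can trigger the download
        return llm;
    }
}
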
@@ -81,7 +86,8 @@ public class LLM : MonoBehaviour
     [ModelAdvanced] public int batchSize = 512;
     /// <summary> a base prompt to use as a base for all LLMCharacter objects </summary>
     [TextArea(5, 10), ChatAdvanced] public string basePrompt = "";
-
+    /// <summary> Boolean set to true if the model has been downloaded, false otherwise. </summary>
+    public bool modelDownloaded { get; protected set; } = false;
     /// <summary> Boolean set to true if the server has started and is ready to receive requests, false otherwise. </summary>
     public bool started { get; protected set; } = false;
     /// <summary> Boolean set to true if the server has failed to start. </summary>
@@ -101,10 +107,12 @@ public class LLM : MonoBehaviour
     StreamWrapper logStreamWrapper = null;
     Thread llmThread = null;
     List<StreamWrapper> streamWrappers = new List<StreamWrapper>();
+    List<Callback<float>> progressCallbacks = new List<Callback<float>>();

     public void SetModelProgress(float progress)
     {
         modelProgress = progress;
+        foreach (Callback<float> progressCallback in progressCallbacks) progressCallback?.Invoke(progress);
     }

     /// \endcond
@@ -122,22 +130,32 @@ async Task<string> CopyAsset(string path)
         return path;
     }

+    public void ResetSelectedModel()
+    {
+        SelectedModel = 0;
+        modelURL = "";
+        model = "";
+    }
+
     public async Task DownloadDefaultModel(int optionIndex)
     {
         // download default model and disable model editor properties until the model is set
+        if (optionIndex == 0)
+        {
+            ResetSelectedModel();
+            return;
+        }
         SelectedModel = optionIndex;
         string modelUrl = LLMUnitySetup.modelOptions[optionIndex].Item2;
-        if (modelUrl == null) return;
+        modelURL = modelUrl;
         string modelName = Path.GetFileName(modelUrl).Split("?")[0];
         await DownloadModel(modelUrl, modelName);
     }

-    public async Task DownloadModel(string modelUrl, string modelName = null, Callback<float> progressCallback = null, bool overwrite = false)
+    public async Task DownloadModel(string modelUrl, string modelName, Callback<float> progressCallback = null, bool overwrite = false)
     {
         modelProgress = 0;
-        if (modelName == null) modelName = model;
         string modelPath = LLMUnitySetup.GetAssetPath(modelName);
-
         Callback<float> callback = (floatArg) =>
         {
             progressCallback?.Invoke(floatArg);
@@ -146,9 +164,21 @@ public async Task DownloadModel(string modelUrl, string modelName = null, Callba
         await LLMUnitySetup.DownloadFile(modelUrl, modelPath, overwrite, SetModel, callback);
     }

-    public async Task DownloadModel(string modelUrl, Callback<float> progressCallback = null, bool overwrite = false)
+    public async Task DownloadModel()
+    {
+        await DownloadModel(modelURL, model);
+    }
+
+    public async Task WaitUntilModelDownloaded(Callback<float> progressCallback = null)
+    {
+        if (progressCallback != null) progressCallbacks.Add(progressCallback);
+        while (!modelDownloaded) await Task.Yield();
+        if (progressCallback != null) progressCallbacks.Remove(progressCallback);
+    }
+
+    public async Task WaitUntilReady()
     {
-        await DownloadModel(modelUrl, null, progressCallback, overwrite);
+        while (!started) await Task.Yield();
     }

     /// <summary>
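
As a usage sketch only (assuming the LLMUnity namespace, an LLM reference assigned in the inspector, and the illustrative ModelLoadingScreen name, none of which come from this diff), a companion script could await the new WaitUntilModelDownloaded and WaitUntilReady methods before issuing its first request:

// Hypothetical consumer (not from this PR): wait for model download and server startup.
using UnityEngine;
using LLMUnity;

public class ModelLoadingScreen : MonoBehaviour
{
    public LLM llm;  // assigned in the inspector

    async void Start()
    {
        // Receives download progress through the progressCallbacks list added above.
        await llm.WaitUntilModelDownloaded(progress => Debug.Log($"Model download progress: {progress}"));
        // Polls the started flag until the llama.cpp server is ready for requests.
        await llm.WaitUntilReady();
        Debug.Log("LLM ready");
    }
}

The callback passed to WaitUntilModelDownloaded is only registered for the duration of the await and is removed again once modelDownloaded becomes true, so a loading bar can hook in without leaking callbacks.
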
@@ -244,6 +274,8 @@ protected virtual string GetLlamaccpArguments()
     public async void Awake()
     {
         if (!enabled) return;
+        if (downloadOnBuild) await DownloadModel();
+        modelDownloaded = true;
         string arguments = GetLlamaccpArguments();
         if (arguments == null) return;
         if (asynchronousStartup) await Task.Run(() => StartLLMServer(arguments));