@@ -36,6 +36,9 @@ public class LLMCharacter : LLMCaller
36
36
/// <summary> grammar file used for the LLMCharacter (.gbnf format) </summary>
37
37
[ Tooltip ( "grammar file used for the LLMCharacter (.gbnf format)" ) ]
38
38
[ ModelAdvanced ] public string grammar = null ;
39
+ /// <summary> grammar file used for the LLMCharacter (.json format) </summary>
40
+ [ Tooltip ( "grammar file used for the LLMCharacter (.json format)" ) ]
41
+ [ ModelAdvanced ] public string grammarJSON = null ;
39
42
/// <summary> cache the processed prompt to avoid reprocessing the entire prompt every time (default: true, recommended!) </summary>
40
43
[ Tooltip ( "cache the processed prompt to avoid reprocessing the entire prompt every time (default: true, recommended!)" ) ]
41
44
[ ModelAdvanced ] public bool cachePrompt = true ;
@@ -124,8 +127,11 @@ public class LLMCharacter : LLMCaller
124
127
[ Tooltip ( "the chat history as list of chat messages" ) ]
125
128
public List < ChatMessage > chat = new List < ChatMessage > ( ) ;
126
129
/// <summary> the grammar to use </summary>
127
- [ Tooltip ( "the grammar to use" ) ]
130
+ [ Tooltip ( "the grammar to use (GBNF) " ) ]
128
131
public string grammarString ;
132
+ /// <summary> the grammar to use </summary>
133
+ [ Tooltip ( "the grammar to use (JSON schema)" ) ]
134
+ public string grammarJSONString ;
129
135
130
136
/// \cond HIDE
131
137
protected SemaphoreSlim chatLock = new SemaphoreSlim ( 1 , 1 ) ;
@@ -269,9 +275,17 @@ protected virtual async Task<bool> InitNKeep()
269
275
270
276
protected virtual void InitGrammar ( )
271
277
{
272
- if ( grammar != null && grammar != "" )
278
+ grammarString = "" ;
279
+ grammarJSONString = "" ;
280
+ if ( ! String . IsNullOrEmpty ( grammar ) )
273
281
{
274
282
grammarString = File . ReadAllText ( LLMUnitySetup . GetAssetPath ( grammar ) ) ;
283
+ if ( ! String . IsNullOrEmpty ( grammarJSON ) )
284
+ LLMUnitySetup . LogWarning ( "Both GBNF and JSON grammars are set, only the GBNF will be used" ) ;
285
+ }
286
+ else if ( ! String . IsNullOrEmpty ( grammarJSON ) )
287
+ {
288
+ grammarJSONString = File . ReadAllText ( LLMUnitySetup . GetAssetPath ( grammarJSON ) ) ;
275
289
}
276
290
}
277
291
@@ -308,16 +322,35 @@ public virtual async Task LoadTemplate()
308
322
/// Sets the grammar file of the LLMCharacter
309
323
/// </summary>
310
324
/// <param name="path">path to the grammar file</param>
311
- public virtual async void SetGrammar ( string path )
325
+ public virtual async Task SetGrammarFile ( string path , bool gnbf )
312
326
{
313
327
#if UNITY_EDITOR
314
328
if ( ! EditorApplication . isPlaying ) path = LLMUnitySetup . AddAsset ( path ) ;
315
329
#endif
316
330
await LLMUnitySetup . AndroidExtractAsset ( path , true ) ;
317
- grammar = path ;
331
+ if ( gnbf ) grammar = path ;
332
+ else grammarJSON = path ;
318
333
InitGrammar ( ) ;
319
334
}
320
335
336
+ /// <summary>
337
+ /// Sets the grammar file of the LLMCharacter (GBNF)
338
+ /// </summary>
339
+ /// <param name="path">path to the grammar file</param>
340
+ public virtual async Task SetGrammar ( string path )
341
+ {
342
+ await SetGrammarFile ( path , true ) ;
343
+ }
344
+
345
+ /// <summary>
346
+ /// Sets the grammar file of the LLMCharacter (JSON schema)
347
+ /// </summary>
348
+ /// <param name="path">path to the grammar file</param>
349
+ public virtual async Task SetJSONGrammar ( string path )
350
+ {
351
+ await SetGrammarFile ( path , false ) ;
352
+ }
353
+
321
354
protected virtual List < string > GetStopwords ( )
322
355
{
323
356
if ( ! CheckTemplate ( ) ) return null ;
@@ -352,6 +385,7 @@ protected virtual ChatRequest GenerateRequest(string prompt)
352
385
chatRequest . mirostat_tau = mirostatTau ;
353
386
chatRequest . mirostat_eta = mirostatEta ;
354
387
chatRequest . grammar = grammarString ;
388
+ chatRequest . json_schema = grammarJSONString ;
355
389
chatRequest . seed = seed ;
356
390
chatRequest . ignore_eos = ignoreEos ;
357
391
chatRequest . logit_bias = logitBias ;
@@ -418,8 +452,30 @@ protected virtual string TemplateContent(TemplateResult result)
418
452
return result . template ;
419
453
}
420
454
421
- protected virtual async Task < string > CompletionRequest ( string json , Callback < string > callback = null )
455
+ protected virtual string ChatRequestToJson ( ChatRequest request )
456
+ {
457
+ string json = JsonUtility . ToJson ( request ) ;
458
+ int grammarIndex = json . LastIndexOf ( '}' ) ;
459
+ if ( ! String . IsNullOrEmpty ( request . grammar ) )
460
+ {
461
+ GrammarWrapper grammarWrapper = new GrammarWrapper { grammar = request . grammar } ;
462
+ string grammarToJSON = JsonUtility . ToJson ( grammarWrapper ) ;
463
+ int start = grammarToJSON . IndexOf ( ":\" " ) + 2 ;
464
+ int end = grammarToJSON . LastIndexOf ( "\" " ) ;
465
+ string grammarSerialised = grammarToJSON . Substring ( start , end - start ) ;
466
+ json = json . Insert ( grammarIndex , $ ",\" grammar\" : \" { grammarSerialised } \" ") ;
467
+ }
468
+ else if ( ! String . IsNullOrEmpty ( request . json_schema ) )
469
+ {
470
+ json = json . Insert ( grammarIndex , $ ",\" json_schema\" :{ request . json_schema } ") ;
471
+ }
472
+ Debug . Log ( json ) ;
473
+ return json ;
474
+ }
475
+
476
+ protected virtual async Task < string > CompletionRequest ( ChatRequest request , Callback < string > callback = null )
422
477
{
478
+ string json = ChatRequestToJson ( request ) ;
423
479
string result = "" ;
424
480
if ( stream )
425
481
{
@@ -470,8 +526,8 @@ public virtual async Task<string> Chat(string query, Callback<string> callback =
470
526
if ( ! CheckTemplate ( ) ) return null ;
471
527
if ( ! await InitNKeep ( ) ) return null ;
472
528
473
- string json = JsonUtility . ToJson ( await PromptWithQuery ( query ) ) ;
474
- string result = await CompletionRequest ( json , callback ) ;
529
+ ChatRequest request = await PromptWithQuery ( query ) ;
530
+ string result = await CompletionRequest ( request , callback ) ;
475
531
476
532
if ( addToHistory && result != null )
477
533
{
@@ -508,8 +564,8 @@ public virtual async Task<string> Complete(string prompt, Callback<string> callb
508
564
// call the completionCallback function when the answer is fully received
509
565
await LoadTemplate ( ) ;
510
566
511
- string json = JsonUtility . ToJson ( GenerateRequest ( prompt ) ) ;
512
- string result = await CompletionRequest ( json , callback ) ;
567
+ ChatRequest request = GenerateRequest ( prompt ) ;
568
+ string result = await CompletionRequest ( request , callback ) ;
513
569
completionCallback ? . Invoke ( ) ;
514
570
return result ;
515
571
}
@@ -553,8 +609,7 @@ public virtual async Task Warmup(string query, EmptyCallback completionCallback
553
609
}
554
610
555
611
request . n_predict = 0 ;
556
- string json = JsonUtility . ToJson ( request ) ;
557
- await CompletionRequest ( json ) ;
612
+ await CompletionRequest ( request ) ;
558
613
completionCallback ? . Invoke ( ) ;
559
614
}
560
615
0 commit comments