@@ -36,6 +36,9 @@ public class LLMCharacter : LLMCaller
3636 /// <summary> grammar file used for the LLMCharacter (.gbnf format) </summary>
3737 [ Tooltip ( "grammar file used for the LLMCharacter (.gbnf format)" ) ]
3838 [ ModelAdvanced ] public string grammar = null ;
39+ /// <summary> grammar file used for the LLMCharacter (.json format) </summary>
40+ [ Tooltip ( "grammar file used for the LLMCharacter (.json format)" ) ]
41+ [ ModelAdvanced ] public string grammarJSON = null ;
3942 /// <summary> cache the processed prompt to avoid reprocessing the entire prompt every time (default: true, recommended!) </summary>
4043 [ Tooltip ( "cache the processed prompt to avoid reprocessing the entire prompt every time (default: true, recommended!)" ) ]
4144 [ ModelAdvanced ] public bool cachePrompt = true ;
@@ -124,8 +127,11 @@ public class LLMCharacter : LLMCaller
124127 [ Tooltip ( "the chat history as list of chat messages" ) ]
125128 public List < ChatMessage > chat = new List < ChatMessage > ( ) ;
126129 /// <summary> the grammar to use </summary>
127- [ Tooltip ( "the grammar to use" ) ]
130+ [ Tooltip ( "the grammar to use (GBNF) " ) ]
128131 public string grammarString ;
132+ /// <summary> the grammar to use </summary>
133+ [ Tooltip ( "the grammar to use (JSON schema)" ) ]
134+ public string grammarJSONString ;
129135
130136 /// \cond HIDE
131137 protected SemaphoreSlim chatLock = new SemaphoreSlim ( 1 , 1 ) ;
@@ -269,9 +275,17 @@ protected virtual async Task<bool> InitNKeep()
269275
270276 protected virtual void InitGrammar ( )
271277 {
272- if ( grammar != null && grammar != "" )
278+ grammarString = "" ;
279+ grammarJSONString = "" ;
280+ if ( ! String . IsNullOrEmpty ( grammar ) )
273281 {
274282 grammarString = File . ReadAllText ( LLMUnitySetup . GetAssetPath ( grammar ) ) ;
283+ if ( ! String . IsNullOrEmpty ( grammarJSON ) )
284+ LLMUnitySetup . LogWarning ( "Both GBNF and JSON grammars are set, only the GBNF will be used" ) ;
285+ }
286+ else if ( ! String . IsNullOrEmpty ( grammarJSON ) )
287+ {
288+ grammarJSONString = File . ReadAllText ( LLMUnitySetup . GetAssetPath ( grammarJSON ) ) ;
275289 }
276290 }
277291
@@ -308,16 +322,35 @@ public virtual async Task LoadTemplate()
308322 /// Sets the grammar file of the LLMCharacter
309323 /// </summary>
310324 /// <param name="path">path to the grammar file</param>
311- public virtual async void SetGrammar ( string path )
325+ public virtual async Task SetGrammarFile ( string path , bool gnbf )
312326 {
313327#if UNITY_EDITOR
314328 if ( ! EditorApplication . isPlaying ) path = LLMUnitySetup . AddAsset ( path ) ;
315329#endif
316330 await LLMUnitySetup . AndroidExtractAsset ( path , true ) ;
317- grammar = path ;
331+ if ( gnbf ) grammar = path ;
332+ else grammarJSON = path ;
318333 InitGrammar ( ) ;
319334 }
320335
336+ /// <summary>
337+ /// Sets the grammar file of the LLMCharacter (GBNF)
338+ /// </summary>
339+ /// <param name="path">path to the grammar file</param>
340+ public virtual async Task SetGrammar ( string path )
341+ {
342+ await SetGrammarFile ( path , true ) ;
343+ }
344+
345+ /// <summary>
346+ /// Sets the grammar file of the LLMCharacter (JSON schema)
347+ /// </summary>
348+ /// <param name="path">path to the grammar file</param>
349+ public virtual async Task SetJSONGrammar ( string path )
350+ {
351+ await SetGrammarFile ( path , false ) ;
352+ }
353+
321354 protected virtual List < string > GetStopwords ( )
322355 {
323356 if ( ! CheckTemplate ( ) ) return null ;
@@ -352,6 +385,7 @@ protected virtual ChatRequest GenerateRequest(string prompt)
352385 chatRequest . mirostat_tau = mirostatTau ;
353386 chatRequest . mirostat_eta = mirostatEta ;
354387 chatRequest . grammar = grammarString ;
388+ chatRequest . json_schema = grammarJSONString ;
355389 chatRequest . seed = seed ;
356390 chatRequest . ignore_eos = ignoreEos ;
357391 chatRequest . logit_bias = logitBias ;
@@ -418,8 +452,30 @@ protected virtual string TemplateContent(TemplateResult result)
418452 return result . template ;
419453 }
420454
421- protected virtual async Task < string > CompletionRequest ( string json , Callback < string > callback = null )
455+ protected virtual string ChatRequestToJson ( ChatRequest request )
456+ {
457+ string json = JsonUtility . ToJson ( request ) ;
458+ int grammarIndex = json . LastIndexOf ( '}' ) ;
459+ if ( ! String . IsNullOrEmpty ( request . grammar ) )
460+ {
461+ GrammarWrapper grammarWrapper = new GrammarWrapper { grammar = request . grammar } ;
462+ string grammarToJSON = JsonUtility . ToJson ( grammarWrapper ) ;
463+ int start = grammarToJSON . IndexOf ( ":\" " ) + 2 ;
464+ int end = grammarToJSON . LastIndexOf ( "\" " ) ;
465+ string grammarSerialised = grammarToJSON . Substring ( start , end - start ) ;
466+ json = json . Insert ( grammarIndex , $ ",\" grammar\" : \" { grammarSerialised } \" ") ;
467+ }
468+ else if ( ! String . IsNullOrEmpty ( request . json_schema ) )
469+ {
470+ json = json . Insert ( grammarIndex , $ ",\" json_schema\" :{ request . json_schema } ") ;
471+ }
472+ Debug . Log ( json ) ;
473+ return json ;
474+ }
475+
476+ protected virtual async Task < string > CompletionRequest ( ChatRequest request , Callback < string > callback = null )
422477 {
478+ string json = ChatRequestToJson ( request ) ;
423479 string result = "" ;
424480 if ( stream )
425481 {
@@ -470,8 +526,8 @@ public virtual async Task<string> Chat(string query, Callback<string> callback =
470526 if ( ! CheckTemplate ( ) ) return null ;
471527 if ( ! await InitNKeep ( ) ) return null ;
472528
473- string json = JsonUtility . ToJson ( await PromptWithQuery ( query ) ) ;
474- string result = await CompletionRequest ( json , callback ) ;
529+ ChatRequest request = await PromptWithQuery ( query ) ;
530+ string result = await CompletionRequest ( request , callback ) ;
475531
476532 if ( addToHistory && result != null )
477533 {
@@ -508,8 +564,8 @@ public virtual async Task<string> Complete(string prompt, Callback<string> callb
508564 // call the completionCallback function when the answer is fully received
509565 await LoadTemplate ( ) ;
510566
511- string json = JsonUtility . ToJson ( GenerateRequest ( prompt ) ) ;
512- string result = await CompletionRequest ( json , callback ) ;
567+ ChatRequest request = GenerateRequest ( prompt ) ;
568+ string result = await CompletionRequest ( request , callback ) ;
513569 completionCallback ? . Invoke ( ) ;
514570 return result ;
515571 }
@@ -553,8 +609,7 @@ public virtual async Task Warmup(string query, EmptyCallback completionCallback
553609 }
554610
555611 request . n_predict = 0 ;
556- string json = JsonUtility . ToJson ( request ) ;
557- await CompletionRequest ( json ) ;
612+ await CompletionRequest ( request ) ;
558613 completionCallback ? . Invoke ( ) ;
559614 }
560615
0 commit comments