
Commit 1cd9067

Allow JSON schema grammars

1 parent 04daedb, commit 1cd9067

4 files changed, +93 -14 lines

Editor/LLMCallerEditor.cs (+11)

```diff
@@ -35,6 +35,17 @@ public override void AddModelSettings(SerializedObject llmScriptSO)
                 }
             };
         }
+        if (GUILayout.Button("Load JSON grammar", GUILayout.Width(buttonWidth)))
+        {
+            EditorApplication.delayCall += () =>
+            {
+                string path = EditorUtility.OpenFilePanelWithFilters("Select a json schema grammar file", "", new string[] { "Grammar Files", "json" });
+                if (!string.IsNullOrEmpty(path))
+                {
+                    ((LLMCharacter)target).SetJSONGrammar(path);
+                }
+            };
+        }
         EditorGUILayout.EndHorizontal();

         ShowPropertiesOfClass("", llmScriptSO, new List<Type> { typeof(ModelAdvancedAttribute) }, false);
```
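
The new editor button simply forwards the selected file to `LLMCharacter.SetJSONGrammar`. The same call works from a runtime script; the sketch below is a hypothetical example (the component wiring, the `grammars/reply_schema.json` path and the `LLMUnity` namespace are assumptions, not part of this commit):

``` c#
using UnityEngine;
using LLMUnity; // assumed package namespace

// Hypothetical example: load a JSON schema grammar at runtime instead of via the editor button.
public class JsonGrammarLoader : MonoBehaviour
{
    public LLMCharacter llmCharacter;   // assign in the Inspector

    async void Start()
    {
        // "grammars/reply_schema.json" is an illustrative path; resolution goes through LLMUnitySetup.GetAssetPath
        await llmCharacter.SetJSONGrammar("grammars/reply_schema.json");
    }
}
```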

README.md (+7 -2)

````diff
@@ -176,13 +176,18 @@ The [MobileDemo](Samples~/MobileDemo) is an example application for Android / iO
 <details>
 <summary>Restrict the output of the LLM / Function calling</summary>

-To restrict the output of the LLM you can use a GBNF grammar, read more [here](https://github.com/ggerganov/llama.cpp/tree/master/grammars).<br>
+To restrict the output of the LLM you can use a grammar, read more [here](https://github.com/ggerganov/llama.cpp/tree/master/grammars).<br>
 The grammar can be saved in a .gbnf file and loaded at the LLMCharacter with the `Load Grammar` button (Advanced options).<br>
 For instance to receive replies in json format you can use the [json.gbnf](https://github.com/ggerganov/llama.cpp/blob/b4218/grammars/json.gbnf) grammar.<br>
+Grammars in JSON schema format are also supported and can be loaded with the `Load JSON Grammar` button (Advanced options).<br>

 Alternatively you can set the grammar directly with code:
 ``` c#
-llmCharacter.grammarString = "your grammar here";
+// GBNF grammar
+llmCharacter.grammarString = "your GBNF grammar here";
+
+// or JSON schema grammar
+llmCharacter.grammarJSONString = "your JSON schema grammar here";
 ```

 For function calling you can define similarly a grammar that allows only the function names as output, and then call the respective function.<br>
````
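
To make the `grammarJSONString` route concrete, a minimal JSON schema that restricts replies to an object with a single string field could be assigned like this (the schema itself is illustrative, not part of the commit):

``` c#
// Example only: constrain replies to the shape {"reply": "<string>"}
llmCharacter.grammarJSONString = @"{
    ""type"": ""object"",
    ""properties"": { ""reply"": { ""type"": ""string"" } },
    ""required"": [""reply""]
}";
```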

Runtime/LLMCharacter.cs (+66 -11)

```diff
@@ -36,6 +36,9 @@ public class LLMCharacter : LLMCaller
         /// <summary> grammar file used for the LLMCharacter (.gbnf format) </summary>
         [Tooltip("grammar file used for the LLMCharacter (.gbnf format)")]
         [ModelAdvanced] public string grammar = null;
+        /// <summary> grammar file used for the LLMCharacter (.json format) </summary>
+        [Tooltip("grammar file used for the LLMCharacter (.json format)")]
+        [ModelAdvanced] public string grammarJSON = null;
         /// <summary> cache the processed prompt to avoid reprocessing the entire prompt every time (default: true, recommended!) </summary>
         [Tooltip("cache the processed prompt to avoid reprocessing the entire prompt every time (default: true, recommended!)")]
         [ModelAdvanced] public bool cachePrompt = true;
@@ -124,8 +127,11 @@ public class LLMCharacter : LLMCaller
         [Tooltip("the chat history as list of chat messages")]
         public List<ChatMessage> chat = new List<ChatMessage>();
         /// <summary> the grammar to use </summary>
-        [Tooltip("the grammar to use")]
+        [Tooltip("the grammar to use (GBNF)")]
         public string grammarString;
+        /// <summary> the grammar to use </summary>
+        [Tooltip("the grammar to use (JSON schema)")]
+        public string grammarJSONString;

         /// \cond HIDE
         protected SemaphoreSlim chatLock = new SemaphoreSlim(1, 1);
@@ -269,9 +275,17 @@ protected virtual async Task<bool> InitNKeep()

         protected virtual void InitGrammar()
         {
-            if (grammar != null && grammar != "")
+            grammarString = "";
+            grammarJSONString = "";
+            if (!String.IsNullOrEmpty(grammar))
             {
                 grammarString = File.ReadAllText(LLMUnitySetup.GetAssetPath(grammar));
+                if (!String.IsNullOrEmpty(grammarJSON))
+                    LLMUnitySetup.LogWarning("Both GBNF and JSON grammars are set, only the GBNF will be used");
+            }
+            else if (!String.IsNullOrEmpty(grammarJSON))
+            {
+                grammarJSONString = File.ReadAllText(LLMUnitySetup.GetAssetPath(grammarJSON));
             }
         }

@@ -308,16 +322,35 @@ public virtual async Task LoadTemplate()
         /// Sets the grammar file of the LLMCharacter
         /// </summary>
         /// <param name="path">path to the grammar file</param>
-        public virtual async void SetGrammar(string path)
+        public virtual async Task SetGrammarFile(string path, bool gnbf)
         {
 #if UNITY_EDITOR
             if (!EditorApplication.isPlaying) path = LLMUnitySetup.AddAsset(path);
 #endif
             await LLMUnitySetup.AndroidExtractAsset(path, true);
-            grammar = path;
+            if (gnbf) grammar = path;
+            else grammarJSON = path;
             InitGrammar();
         }

+        /// <summary>
+        /// Sets the grammar file of the LLMCharacter (GBNF)
+        /// </summary>
+        /// <param name="path">path to the grammar file</param>
+        public virtual async Task SetGrammar(string path)
+        {
+            await SetGrammarFile(path, true);
+        }
+
+        /// <summary>
+        /// Sets the grammar file of the LLMCharacter (JSON schema)
+        /// </summary>
+        /// <param name="path">path to the grammar file</param>
+        public virtual async Task SetJSONGrammar(string path)
+        {
+            await SetGrammarFile(path, false);
+        }
+
         protected virtual List<string> GetStopwords()
         {
             if (!CheckTemplate()) return null;
@@ -352,6 +385,7 @@ protected virtual ChatRequest GenerateRequest(string prompt)
             chatRequest.mirostat_tau = mirostatTau;
             chatRequest.mirostat_eta = mirostatEta;
             chatRequest.grammar = grammarString;
+            chatRequest.json_schema = grammarJSONString;
             chatRequest.seed = seed;
             chatRequest.ignore_eos = ignoreEos;
             chatRequest.logit_bias = logitBias;
@@ -418,8 +452,30 @@ protected virtual string TemplateContent(TemplateResult result)
             return result.template;
         }

-        protected virtual async Task<string> CompletionRequest(string json, Callback<string> callback = null)
+        protected virtual string ChatRequestToJson(ChatRequest request)
+        {
+            string json = JsonUtility.ToJson(request);
+            int grammarIndex = json.LastIndexOf('}');
+            if (!String.IsNullOrEmpty(request.grammar))
+            {
+                GrammarWrapper grammarWrapper = new GrammarWrapper { grammar = request.grammar };
+                string grammarToJSON = JsonUtility.ToJson(grammarWrapper);
+                int start = grammarToJSON.IndexOf(":\"") + 2;
+                int end = grammarToJSON.LastIndexOf("\"");
+                string grammarSerialised = grammarToJSON.Substring(start, end - start);
+                json = json.Insert(grammarIndex, $",\"grammar\": \"{grammarSerialised}\"");
+            }
+            else if (!String.IsNullOrEmpty(request.json_schema))
+            {
+                json = json.Insert(grammarIndex, $",\"json_schema\":{request.json_schema}");
+            }
+            Debug.Log(json);
+            return json;
+        }
+
+        protected virtual async Task<string> CompletionRequest(ChatRequest request, Callback<string> callback = null)
         {
+            string json = ChatRequestToJson(request);
             string result = "";
             if (stream)
             {
@@ -470,8 +526,8 @@ public virtual async Task<string> Chat(string query, Callback<string> callback =
             if (!CheckTemplate()) return null;
             if (!await InitNKeep()) return null;

-            string json = JsonUtility.ToJson(await PromptWithQuery(query));
-            string result = await CompletionRequest(json, callback);
+            ChatRequest request = await PromptWithQuery(query);
+            string result = await CompletionRequest(request, callback);

             if (addToHistory && result != null)
             {
@@ -508,8 +564,8 @@ public virtual async Task<string> Complete(string prompt, Callback<string> callb
             // call the completionCallback function when the answer is fully received
             await LoadTemplate();

-            string json = JsonUtility.ToJson(GenerateRequest(prompt));
-            string result = await CompletionRequest(json, callback);
+            ChatRequest request = GenerateRequest(prompt);
+            string result = await CompletionRequest(request, callback);
             completionCallback?.Invoke();
             return result;
         }
@@ -553,8 +609,7 @@ public virtual async Task Warmup(string query, EmptyCallback completionCallback
         }

         request.n_predict = 0;
-        string json = JsonUtility.ToJson(request);
-        await CompletionRequest(json);
+        await CompletionRequest(request);
         completionCallback?.Invoke();
     }
```
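
The splice in `ChatRequestToJson` relies on `JsonUtility.ToJson` ending its output with a single closing brace, so inserting the extra field just before it still yields a valid top-level object. A minimal standalone sketch of the same trick, on a toy struct rather than the plugin's `ChatRequest`:

``` c#
using System;
using UnityEngine;

[Serializable]
public struct ToyRequest
{
    public string prompt;
    public int n_predict;
}

public static class SchemaSplice
{
    // Appends a raw JSON object field that JsonUtility cannot serialise on its own.
    public static string WithSchema(ToyRequest request, string schemaJson)
    {
        string json = JsonUtility.ToJson(request);        // {"prompt":"hi","n_predict":0}
        int closing = json.LastIndexOf('}');
        return json.Insert(closing, $",\"json_schema\":{schemaJson}");
        // -> {"prompt":"hi","n_predict":0,"json_schema":{ ... }}
    }
}
```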

Runtime/LLMInterface.cs (+9 -1)

```diff
@@ -30,7 +30,9 @@ public struct ChatRequest
         public int mirostat;
         public float mirostat_tau;
         public float mirostat_eta;
-        public string grammar;
+        // EXCLUDE grammars from JsonUtility serialization, serialise them manually
+        [NonSerialized] public string grammar;
+        [NonSerialized] public string json_schema;
         public int seed;
         public bool ignore_eos;
         public Dictionary<int, string> logit_bias;
@@ -39,6 +41,12 @@ public struct ChatRequest
         public List<ChatMessage> messages;
     }

+    [Serializable]
+    public struct GrammarWrapper
+    {
+        public string grammar;
+    }
+
     [Serializable]
     public struct SystemPromptRequest
     {
```
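
Marking the grammar fields `[NonSerialized]` keeps `JsonUtility` from emitting them at all, which is what allows `ChatRequestToJson` to splice them in manually; `GrammarWrapper` only exists so `JsonUtility` can escape the GBNF text (quotes, newlines) before it is re-inserted. A small illustrative check of that behaviour, assuming the package's `LLMUnity` namespace (the toy struct and values are not from the plugin):

``` c#
using System;
using UnityEngine;
using LLMUnity; // assumed namespace for GrammarWrapper

[Serializable]
public struct ToyChatRequest
{
    public string prompt;
    [NonSerialized] public string grammar;   // skipped by JsonUtility, like ChatRequest.grammar
}

public class SerializationCheck : MonoBehaviour
{
    void Start()
    {
        var request = new ToyChatRequest { prompt = "hi", grammar = "root ::= \"yes\" | \"no\"" };
        Debug.Log(JsonUtility.ToJson(request));   // {"prompt":"hi"} - the NonSerialized field is absent

        var wrapper = new GrammarWrapper { grammar = request.grammar };
        Debug.Log(JsonUtility.ToJson(wrapper));   // {"grammar":"root ::= \"yes\" | \"no\""} - quotes escaped by JsonUtility
    }
}
```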
