Allow JSON schema grammars #333

Merged · 1 commit · Apr 30, 2025
11 changes: 11 additions & 0 deletions Editor/LLMCallerEditor.cs
@@ -35,6 +35,17 @@ public override void AddModelSettings(SerializedObject llmScriptSO)
}
};
}
if (GUILayout.Button("Load JSON grammar", GUILayout.Width(buttonWidth)))
{
EditorApplication.delayCall += () =>
{
string path = EditorUtility.OpenFilePanelWithFilters("Select a json schema grammar file", "", new string[] { "Grammar Files", "json" });
if (!string.IsNullOrEmpty(path))
{
((LLMCharacter)target).SetJSONGrammar(path);
}
};
}
EditorGUILayout.EndHorizontal();

ShowPropertiesOfClass("", llmScriptSO, new List<Type> { typeof(ModelAdvancedAttribute) }, false);
9 changes: 7 additions & 2 deletions README.md
@@ -176,13 +176,18 @@ The [MobileDemo](Samples~/MobileDemo) is an example application for Android / iO
<details>
<summary>Restrict the output of the LLM / Function calling</summary>

To restrict the output of the LLM you can use a GBNF grammar, read more [here](https://github.com/ggerganov/llama.cpp/tree/master/grammars).<br>
To restrict the output of the LLM you can use a grammar, read more [here](https://github.com/ggerganov/llama.cpp/tree/master/grammars).<br>
The grammar can be saved in a .gbnf file and loaded at the LLMCharacter with the `Load Grammar` button (Advanced options).<br>
For instance, to receive replies in JSON format you can use the [json.gbnf](https://github.com/ggerganov/llama.cpp/blob/b4218/grammars/json.gbnf) grammar.<br>
Grammars in JSON schema format are also supported and can be loaded with the `Load JSON Grammar` button (Advanced options).<br>

Alternatively you can set the grammar directly with code:
``` c#
llmCharacter.grammarString = "your grammar here";
// GBNF grammar
llmCharacter.grammarString = "your GBNF grammar here";

// or JSON schema grammar
llmCharacter.grammarJSONString = "your JSON schema grammar here";
```
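
For instance, a minimal JSON schema that restricts replies to an object with a single `answer` string field could be assigned like this (the schema contents are purely illustrative):
``` c#
// a minimal sketch: replies must be a JSON object with a single string field "answer"
llmCharacter.grammarJSONString = @"{
  ""type"": ""object"",
  ""properties"": { ""answer"": { ""type"": ""string"" } },
  ""required"": [""answer""]
}";
```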

For function calling you can similarly define a grammar that allows only the function names as output, and then call the respective function.<br>
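
For illustration, a GBNF grammar that allows only two hypothetical function names as output could be set directly from code (the function names are made up):
``` c#
// a sketch: the model may only reply with one of these (made-up) function names
llmCharacter.grammarString = "root ::= \"GetWeather\" | \"GetTime\"";
```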
77 changes: 66 additions & 11 deletions Runtime/LLMCharacter.cs
@@ -36,6 +36,9 @@ public class LLMCharacter : LLMCaller
/// <summary> grammar file used for the LLMCharacter (.gbnf format) </summary>
[Tooltip("grammar file used for the LLMCharacter (.gbnf format)")]
[ModelAdvanced] public string grammar = null;
/// <summary> grammar file used for the LLMCharacter (.json format) </summary>
[Tooltip("grammar file used for the LLMCharacter (.json format)")]
[ModelAdvanced] public string grammarJSON = null;
/// <summary> cache the processed prompt to avoid reprocessing the entire prompt every time (default: true, recommended!) </summary>
[Tooltip("cache the processed prompt to avoid reprocessing the entire prompt every time (default: true, recommended!)")]
[ModelAdvanced] public bool cachePrompt = true;
@@ -124,8 +127,11 @@ public class LLMCharacter : LLMCaller
[Tooltip("the chat history as list of chat messages")]
public List<ChatMessage> chat = new List<ChatMessage>();
/// <summary> the grammar to use (GBNF) </summary>
[Tooltip("the grammar to use")]
[Tooltip("the grammar to use (GBNF)")]
public string grammarString;
/// <summary> the grammar to use (JSON schema) </summary>
[Tooltip("the grammar to use (JSON schema)")]
public string grammarJSONString;

/// \cond HIDE
protected SemaphoreSlim chatLock = new SemaphoreSlim(1, 1);
@@ -269,9 +275,17 @@ protected virtual async Task<bool> InitNKeep()

protected virtual void InitGrammar()
{
if (grammar != null && grammar != "")
grammarString = "";
grammarJSONString = "";
if (!String.IsNullOrEmpty(grammar))
{
grammarString = File.ReadAllText(LLMUnitySetup.GetAssetPath(grammar));
if (!String.IsNullOrEmpty(grammarJSON))
LLMUnitySetup.LogWarning("Both GBNF and JSON grammars are set, only the GBNF will be used");
}
else if (!String.IsNullOrEmpty(grammarJSON))
{
grammarJSONString = File.ReadAllText(LLMUnitySetup.GetAssetPath(grammarJSON));
}
}

@@ -308,16 +322,35 @@ public virtual async Task LoadTemplate()
/// Sets the grammar file of the LLMCharacter
/// </summary>
/// <param name="path">path to the grammar file</param>
public virtual async void SetGrammar(string path)
public virtual async Task SetGrammarFile(string path, bool gnbf)
{
#if UNITY_EDITOR
if (!EditorApplication.isPlaying) path = LLMUnitySetup.AddAsset(path);
#endif
await LLMUnitySetup.AndroidExtractAsset(path, true);
grammar = path;
if (gnbf) grammar = path;
else grammarJSON = path;
InitGrammar();
}

/// <summary>
/// Sets the grammar file of the LLMCharacter (GBNF)
/// </summary>
/// <param name="path">path to the grammar file</param>
public virtual async Task SetGrammar(string path)
{
await SetGrammarFile(path, true);
}

/// <summary>
/// Sets the grammar file of the LLMCharacter (JSON schema)
/// </summary>
/// <param name="path">path to the grammar file</param>
public virtual async Task SetJSONGrammar(string path)
{
await SetGrammarFile(path, false);
}
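
From user code, the two entry points could be called as follows (a sketch; the file paths are hypothetical):
``` c#
// load a GBNF grammar file ...
await llmCharacter.SetGrammar("grammars/json.gbnf");
// ... or a JSON schema grammar file (if both are set, the GBNF grammar takes precedence)
await llmCharacter.SetJSONGrammar("grammars/reply_schema.json");
```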

protected virtual List<string> GetStopwords()
{
if (!CheckTemplate()) return null;
@@ -352,6 +385,7 @@ protected virtual ChatRequest GenerateRequest(string prompt)
chatRequest.mirostat_tau = mirostatTau;
chatRequest.mirostat_eta = mirostatEta;
chatRequest.grammar = grammarString;
chatRequest.json_schema = grammarJSONString;
chatRequest.seed = seed;
chatRequest.ignore_eos = ignoreEos;
chatRequest.logit_bias = logitBias;
@@ -418,8 +452,30 @@ protected virtual string TemplateContent(TemplateResult result)
return result.template;
}

protected virtual async Task<string> CompletionRequest(string json, Callback<string> callback = null)
protected virtual string ChatRequestToJson(ChatRequest request)
{
// the grammar fields are [NonSerialized] (see LLMInterface.cs), so JsonUtility skips them;
// serialise them manually and splice them into the request JSON just before the closing brace
string json = JsonUtility.ToJson(request);
int grammarIndex = json.LastIndexOf('}');
if (!String.IsNullOrEmpty(request.grammar))
{
GrammarWrapper grammarWrapper = new GrammarWrapper { grammar = request.grammar };
string grammarToJSON = JsonUtility.ToJson(grammarWrapper);
int start = grammarToJSON.IndexOf(":\"") + 2;
int end = grammarToJSON.LastIndexOf("\"");
string grammarSerialised = grammarToJSON.Substring(start, end - start);
json = json.Insert(grammarIndex, $",\"grammar\": \"{grammarSerialised}\"");
}
else if (!String.IsNullOrEmpty(request.json_schema))
{
json = json.Insert(grammarIndex, $",\"json_schema\":{request.json_schema}");
}
Debug.Log(json);
return json;
}
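
The `GrammarWrapper` round-trip above exists only so that `JsonUtility` escapes the grammar text (quotes, newlines) before it is spliced into the request. A small self-contained sketch of the same trick, assuming the `LLMUnity` namespace:
``` c#
using UnityEngine;  // JsonUtility
using LLMUnity;     // GrammarWrapper (namespace assumed)

public static class GrammarEscapeSketch
{
    // returns the grammar text with JSON escaping applied, e.g. root ::= \"yes\" | \"no\"
    public static string EscapeGrammar(string grammar)
    {
        string wrapped = JsonUtility.ToJson(new GrammarWrapper { grammar = grammar });
        int start = wrapped.IndexOf(":\"") + 2;  // skip past {"grammar":"
        int end = wrapped.LastIndexOf("\"");     // position of the closing quote
        return wrapped.Substring(start, end - start);
    }
}
```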

protected virtual async Task<string> CompletionRequest(ChatRequest request, Callback<string> callback = null)
{
string json = ChatRequestToJson(request);
string result = "";
if (stream)
{
@@ -470,8 +526,8 @@ public virtual async Task<string> Chat(string query, Callback<string> callback =
if (!CheckTemplate()) return null;
if (!await InitNKeep()) return null;

string json = JsonUtility.ToJson(await PromptWithQuery(query));
string result = await CompletionRequest(json, callback);
ChatRequest request = await PromptWithQuery(query);
string result = await CompletionRequest(request, callback);

if (addToHistory && result != null)
{
@@ -508,8 +564,8 @@ public virtual async Task<string> Complete(string prompt, Callback<string> callb
// call the completionCallback function when the answer is fully received
await LoadTemplate();

string json = JsonUtility.ToJson(GenerateRequest(prompt));
string result = await CompletionRequest(json, callback);
ChatRequest request = GenerateRequest(prompt);
string result = await CompletionRequest(request, callback);
completionCallback?.Invoke();
return result;
}
@@ -553,8 +609,7 @@ public virtual async Task Warmup(string query, EmptyCallback completionCallback
}

request.n_predict = 0;
string json = JsonUtility.ToJson(request);
await CompletionRequest(json);
await CompletionRequest(request);
completionCallback?.Invoke();
}

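
Putting the pieces together, one possible end-to-end use of the JSON schema path from a scene script (a sketch: the component wiring, prompt and schema are assumptions, not part of this PR):
``` c#
using UnityEngine;
using LLMUnity;

public class ConstrainedReplyExample : MonoBehaviour
{
    public LLMCharacter llmCharacter;  // assigned in the Inspector

    async void Start()
    {
        // constrain replies to a {"answer": "..."} object (illustrative schema)
        llmCharacter.grammarJSONString =
            "{\"type\":\"object\",\"properties\":{\"answer\":{\"type\":\"string\"}},\"required\":[\"answer\"]}";

        string reply = await llmCharacter.Chat("What colour is the sky?");
        Debug.Log(reply);  // expected to be of the form {"answer": "..."}
    }
}
```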
10 changes: 9 additions & 1 deletion Runtime/LLMInterface.cs
@@ -30,7 +30,9 @@ public struct ChatRequest
public int mirostat;
public float mirostat_tau;
public float mirostat_eta;
public string grammar;
// exclude the grammars from JsonUtility serialization; they are serialised manually in ChatRequestToJson
[NonSerialized] public string grammar;
[NonSerialized] public string json_schema;
public int seed;
public bool ignore_eos;
public Dictionary<int, string> logit_bias;
@@ -39,6 +41,12 @@
public List<ChatMessage> messages;
}

[Serializable]
public struct GrammarWrapper
{
public string grammar;
}

[Serializable]
public struct SystemPromptRequest
{