diff --git a/README.md b/README.md
index bdf6350b..8dee1a57 100644
--- a/README.md
+++ b/README.md
@@ -406,8 +406,9 @@ If the user's GPU is not supported, the LLM will fall back to the CPU
- `Remote` whether the LLM used is remote or local
- `LLM` the LLM GameObject (if `Remote` is not set)
-- `Hort` ip of the LLM (if `Remote` is set)
-- `Port` port of the LLM (if `Remote` is set)
+- `Host` ip of the LLM server (if `Remote` is set)
+- `Port` port of the LLM server (if `Remote` is set)
+- `Num Retries` number of HTTP request retries for requests to the LLM server (if `Remote` is set), -1 for infinite retries
- Save
save filename or relative path
If set, the chat history and LLM state (if save cache is enabled) is automatically saved to file specified.
The chat history is saved with a json suffix and the LLM state with a cache suffix.
Both files are saved in the [persistentDataPath folder of Unity](https://docs.unity3d.com/ScriptReference/Application-persistentDataPath.html).
- `Save Cache` select to save the LLM state along with the chat history. The LLM state is typically around 100MB+.
- `Debug Prompt` select to log the constructed prompts in the Unity Editor
diff --git a/Runtime/LLM.cs b/Runtime/LLM.cs
index 9e50c513..77b71398 100644
--- a/Runtime/LLM.cs
+++ b/Runtime/LLM.cs
@@ -354,7 +354,7 @@ private void StartLLMServer(string arguments)
failed = true;
return;
}
- CallIfNotDestroyed(() => StartService());
+ CallWithLock(StartService);
LLMUnitySetup.Log("LLM service created");
}
@@ -364,22 +364,22 @@ private void InitLib(string arch)
CheckLLMStatus(false);
}
- void CallIfNotDestroyed(EmptyCallback fn)
+ void CallWithLock(EmptyCallback fn, bool checkNull = true)
{
lock (startLock)
{
- if (llmlib == null) throw new DestroyException();
+ if (checkNull && llmlib == null) throw new DestroyException();
fn();
}
}
private void InitService(string arguments)
{
- if (debug) CallIfNotDestroyed(() => SetupLogging());
- CallIfNotDestroyed(() => { LLMObject = llmlib.LLM_Construct(arguments); });
- if (remote) CallIfNotDestroyed(() => llmlib.LLM_StartServer(LLMObject));
- CallIfNotDestroyed(() => llmlib.LLM_SetTemplate(LLMObject, chatTemplate));
- CallIfNotDestroyed(() => CheckLLMStatus(false));
+ if (debug) CallWithLock(SetupLogging);
+ CallWithLock(() => { LLMObject = llmlib.LLM_Construct(arguments); });
+ CallWithLock(() => llmlib.LLM_SetTemplate(LLMObject, chatTemplate));
+ if (remote) CallWithLock(() => llmlib.LLM_StartServer(LLMObject));
+ CallWithLock(() => CheckLLMStatus(false));
}
private void StartService()
@@ -644,7 +644,7 @@ public void CancelRequest(int id_slot)
///
public void Destroy()
{
- lock (startLock)
+ CallWithLock(() =>
{
try
{
@@ -669,7 +669,7 @@ public void Destroy()
{
LLMUnitySetup.LogError(e.Message);
}
- }
+ }, false);
}
///
diff --git a/Runtime/LLMCharacter.cs b/Runtime/LLMCharacter.cs
index 3b002564..c5ce99fd 100644
--- a/Runtime/LLMCharacter.cs
+++ b/Runtime/LLMCharacter.cs
@@ -27,6 +27,8 @@ public class LLMCharacter : MonoBehaviour
[Remote] public string host = "localhost";
/// port to use for the LLM server
[Remote] public int port = 13333;
+ /// number of retries to use for the LLM server requests (-1 = infinite)
+ [Remote] public int numRetries = -1;
/// file to save the chat history.
/// The file is saved only for Chat calls with addToHistory set to true.
/// The file will be saved within the persistentDataPath directory (see https://docs.unity3d.com/ScriptReference/Application-persistentDataPath.html).
@@ -118,7 +120,7 @@ public class LLMCharacter : MonoBehaviour
public List chat;
private SemaphoreSlim chatLock = new SemaphoreSlim(1, 1);
private string chatTemplate;
- private ChatTemplate template;
+ private ChatTemplate template = null;
public string grammarString;
protected int id_slot = -1;
private List<(string, string)> requestHeaders = new List<(string, string)> { ("Content-Type", "application/json") };
@@ -270,10 +272,21 @@ public void SetPrompt(string newPrompt, bool clearChat = true)
InitPrompt(clearChat);
}
+ private bool CheckTemplate()
+ {
+ if (template == null)
+ {
+ LLMUnitySetup.LogError("Template not set!");
+ return false;
+ }
+ return true;
+ }
+
private async Task InitNKeep()
{
if (setNKeepToPrompt && nKeep == -1)
{
+ if (!CheckTemplate()) return;
string systemPrompt = template.ComputePrompt(new List(){chat[0]}, playerName, "", false);
await Tokenize(systemPrompt, SetNKeep);
}
@@ -311,7 +324,8 @@ public async Task LoadTemplate()
if (llmTemplate != chatTemplate)
{
chatTemplate = llmTemplate;
- template = ChatTemplate.GetTemplate(chatTemplate);
+ template = chatTemplate == null ? null : ChatTemplate.GetTemplate(chatTemplate);
+ nKeep = -1;
}
}
@@ -331,6 +345,7 @@ public async void SetGrammar(string path)
List GetStopwords()
{
+ if (!CheckTemplate()) return null;
List stopAll = new List(template.GetStop(playerName, AIName));
if (stop != null) stopAll.AddRange(stop);
return stopAll;
@@ -465,6 +480,7 @@ public async Task Chat(string query, Callback callback = null, E
// call the callback function while the answer is received
// call the completionCallback function when the answer is fully received
await LoadTemplate();
+ if (!CheckTemplate()) return null;
await InitNKeep();
string json;
@@ -750,38 +766,57 @@ protected async Task PostRequestRemote(string json, string endpoi
Ret result = default;
byte[] jsonToSend = new System.Text.UTF8Encoding().GetBytes(json);
- using (var request = UnityWebRequest.Put($"{host}:{port}/{endpoint}", jsonToSend))
- {
- WIPRequests.Add(request);
+ UnityWebRequest request = null;
+ string error = null;
+ int tryNr = numRetries;
- request.method = "POST";
- if (requestHeaders != null)
+ while (tryNr != 0)
+ {
+ using (request = UnityWebRequest.Put($"{host}:{port}/{endpoint}", jsonToSend))
{
- for (int i = 0; i < requestHeaders.Count; i++)
- request.SetRequestHeader(requestHeaders[i].Item1, requestHeaders[i].Item2);
- }
+ WIPRequests.Add(request);
- // Start the request asynchronously
- var asyncOperation = request.SendWebRequest();
- float lastProgress = 0f;
- // Continue updating progress until the request is completed
- while (!asyncOperation.isDone)
- {
- float currentProgress = request.downloadProgress;
- // Check if progress has changed
- if (currentProgress != lastProgress && callback != null)
+ request.method = "POST";
+ if (requestHeaders != null)
{
- callback?.Invoke(ConvertContent(request.downloadHandler.text, getContent));
- lastProgress = currentProgress;
+ for (int i = 0; i < requestHeaders.Count; i++)
+ request.SetRequestHeader(requestHeaders[i].Item1, requestHeaders[i].Item2);
+ }
+
+ // Start the request asynchronously
+ var asyncOperation = request.SendWebRequest();
+ float lastProgress = 0f;
+ // Continue updating progress until the request is completed
+ while (!asyncOperation.isDone)
+ {
+ float currentProgress = request.downloadProgress;
+ // Check if progress has changed
+ if (currentProgress != lastProgress && callback != null)
+ {
+ callback?.Invoke(ConvertContent(request.downloadHandler.text, getContent));
+ lastProgress = currentProgress;
+ }
+ // Wait for the next frame
+ await Task.Yield();
+ }
+ WIPRequests.Remove(request);
+ if (request.result == UnityWebRequest.Result.Success)
+ {
+ result = ConvertContent(request.downloadHandler.text, getContent);
+ error = null;
+ break;
+ }
+ else
+ {
+ result = default;
+ error = request.error;
}
- // Wait for the next frame
- await Task.Yield();
}
- WIPRequests.Remove(request);
- if (request.result != UnityWebRequest.Result.Success) LLMUnitySetup.LogError(request.error);
- else result = ConvertContent(request.downloadHandler.text, getContent);
- callback?.Invoke(result);
+ tryNr--;
}
+
+ if (error != null) LLMUnitySetup.LogError(error);
+ callback?.Invoke(result);
return result;
}