From 6a1346aca8026215f91c902cb73cbecd962f32e2 Mon Sep 17 00:00:00 2001 From: Antonis Makropoulos Date: Tue, 20 Aug 2024 17:56:02 +0300 Subject: [PATCH] fix crash when stopping scene before LLM creation --- Runtime/LLM.cs | 70 ++++++++++++++++++++++++++++++----------------- Runtime/LLMLib.cs | 1 - 2 files changed, 45 insertions(+), 26 deletions(-) diff --git a/Runtime/LLM.cs b/Runtime/LLM.cs index 3d7b599b..7fb460c4 100644 --- a/Runtime/LLM.cs +++ b/Runtime/LLM.cs @@ -21,6 +21,8 @@ public LLMException(string message, int errorCode) : base(message) ErrorCode = errorCode; } } + + public class DestroyException : Exception {} /// \endcond [DefaultExecutionOrder(-1)] @@ -83,6 +85,7 @@ public class LLM : MonoBehaviour List streamWrappers = new List(); public LLMManager llmManager = new LLMManager(); List loraWeights = new List(); + private readonly object startLock = new object(); /// \endcond @@ -114,6 +117,7 @@ public async void Awake() return; } await Task.Run(() => StartLLMServer(arguments)); + if (!started) return; if (dontDestroyOnLoad) DontDestroyOnLoad(transform.root.gameObject); if (basePrompt != "") await SetBasePrompt(basePrompt); } @@ -322,7 +326,7 @@ private void StartLLMServer(string arguments) try { InitLib(arch); - InitServer(arguments); + InitService(arguments); LLMUnitySetup.Log($"Using architecture: {arch}"); break; } @@ -331,6 +335,10 @@ private void StartLLMServer(string arguments) error = e.Message; Destroy(); } + catch (DestroyException) + { + break; + } catch (Exception e) { error = $"{e.GetType()}: {e.Message}"; @@ -343,7 +351,7 @@ private void StartLLMServer(string arguments) failed = true; return; } - StartService(); + CallIfNotDestroyed(() => StartService()); LLMUnitySetup.Log("LLM service created"); } @@ -353,13 +361,22 @@ private void InitLib(string arch) CheckLLMStatus(false); } - private void InitServer(string arguments) + void CallIfNotDestroyed(EmptyCallback fn) { - if (debug) SetupLogging(); - LLMObject = llmlib.LLM_Construct(arguments); - if (remote) llmlib.LLM_StartServer(LLMObject); - llmlib.LLM_SetTemplate(LLMObject, chatTemplate); - CheckLLMStatus(false); + lock (startLock) + { + if (llmlib == null) throw new DestroyException(); + fn(); + } + } + + private void InitService(string arguments) + { + if (debug) CallIfNotDestroyed(() => SetupLogging()); + CallIfNotDestroyed(() => {LLMObject = llmlib.LLM_Construct(arguments);}); + if (remote) CallIfNotDestroyed(() => llmlib.LLM_StartServer(LLMObject)); + CallIfNotDestroyed(() => llmlib.LLM_SetTemplate(LLMObject, chatTemplate)); + CallIfNotDestroyed(() => CheckLLMStatus(false)); } private void StartService() @@ -624,28 +641,31 @@ public void CancelRequest(int id_slot) /// public void Destroy() { - try + lock (startLock) { - if (llmlib != null) + try { - if (LLMObject != IntPtr.Zero) + if (llmlib != null) { - llmlib.LLM_Stop(LLMObject); - if (remote) llmlib.LLM_StopServer(LLMObject); - StopLogging(); - llmThread?.Join(); - llmlib.LLM_Delete(LLMObject); - LLMObject = IntPtr.Zero; + if (LLMObject != IntPtr.Zero) + { + llmlib.LLM_Stop(LLMObject); + if (remote) llmlib.LLM_StopServer(LLMObject); + StopLogging(); + llmThread?.Join(); + llmlib.LLM_Delete(LLMObject); + LLMObject = IntPtr.Zero; + } + llmlib.Destroy(); + llmlib = null; } - llmlib.Destroy(); + started = false; + failed = false; + } + catch (Exception e) + { + LLMUnitySetup.LogError(e.Message); } - started = false; - failed = false; - llmlib = null; - } - catch (Exception e) - { - LLMUnitySetup.LogError(e.Message); } } diff --git a/Runtime/LLMLib.cs b/Runtime/LLMLib.cs index fbaa050d..edb60601 100644 --- a/Runtime/LLMLib.cs +++ b/Runtime/LLMLib.cs @@ -281,7 +281,6 @@ static LLMLib() public LLMLib(string arch) { - LLMUnitySetup.Log(GetArchitecturePath(arch)); libraryHandle = LibraryLoader.LoadLibrary(GetArchitecturePath(arch)); if (libraryHandle == IntPtr.Zero) {