-
-
Notifications
You must be signed in to change notification settings - Fork 181
Expand file tree
/
Copy pathLLMAgent.cs
More file actions
422 lines (374 loc) · 14.7 KB
/
LLMAgent.cs
File metadata and controls
422 lines (374 loc) · 14.7 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
/// @file
/// @brief File implementing the LLM chat agent functionality for Unity.
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Threading.Tasks;
using UndreamAI.LlamaLib;
using UnityEngine;
namespace LLMUnity
{
[DefaultExecutionOrder(-1)]
/// @ingroup llm
/// <summary>
/// Unity MonoBehaviour that implements a conversational AI agent with persistent chat history.
/// Extends LLMClient to provide chat-specific functionality including role management,
/// conversation history persistence, and specialized chat completion methods.
/// </summary>
public class LLMAgent : LLMClient
{
    #region Inspector Fields
    /// <summary>Filename for saving chat history (saved in persistentDataPath)</summary>
    [Tooltip("Filename for saving chat history (saved in Application.persistentDataPath)")]
    [LLM] public string save = "";
    /// <summary>Debug LLM prompts</summary>
    [Tooltip("Debug LLM prompts")]
    [LLM] public bool debugPrompt = false;
    /// <summary>Server slot to use for processing (affects caching behavior)</summary>
    [Tooltip("Server slot to use for processing (affects caching behavior)")]
    [ModelAdvanced, SerializeField] protected int _slot = -1;
    /// <summary>Role name for user messages in conversation</summary>
    [Tooltip("Role name for user messages in conversation")]
    [Chat, SerializeField] protected string _userRole = "user";
    /// <summary>Role name for AI assistant messages in conversation</summary>
    [Tooltip("Role name for AI assistant messages in conversation")]
    [Chat, SerializeField] protected string _assistantRole = "assistant";
    /// <summary>System prompt that defines the AI's personality and behavior</summary>
    [Tooltip("System prompt that defines the AI's personality and behavior")]
    [TextArea(5, 10), Chat, SerializeField]
    protected string _systemPrompt = "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.";
    #endregion

    #region Public Properties
    /// <summary>
    /// Server slot ID for this agent's requests (-1 for auto-assignment).
    /// Changes are forwarded to the underlying LlamaLib agent once it exists.
    /// </summary>
    public int slot
    {
        get => _slot;
        set
        {
            if (_slot != value)
            {
                _slot = value;
                if (llmAgent != null) llmAgent.SlotId = _slot;
            }
        }
    }

    /// <summary>
    /// Role identifier for user messages.
    /// Changes are forwarded to the underlying LlamaLib agent once it exists.
    /// </summary>
    public string userRole
    {
        get => _userRole;
        set
        {
            if (_userRole != value)
            {
                _userRole = value;
                if (llmAgent != null) llmAgent.UserRole = _userRole;
            }
        }
    }

    /// <summary>
    /// Role identifier for assistant messages.
    /// Changes are forwarded to the underlying LlamaLib agent once it exists.
    /// </summary>
    public string assistantRole
    {
        get => _assistantRole;
        set
        {
            if (_assistantRole != value)
            {
                _assistantRole = value;
                if (llmAgent != null) llmAgent.AssistantRole = _assistantRole;
            }
        }
    }

    /// <summary>
    /// System prompt defining the agent's behavior and personality.
    /// Changes are forwarded to the underlying LlamaLib agent once it exists.
    /// </summary>
    public string systemPrompt
    {
        get => _systemPrompt;
        set
        {
            if (_systemPrompt != value)
            {
                _systemPrompt = value;
                if (llmAgent != null) llmAgent.SystemPrompt = _systemPrompt;
            }
        }
    }

    /// <summary>The underlying LLMAgent instance from LlamaLib (null until setup succeeds)</summary>
    public UndreamAI.LlamaLib.LLMAgent llmAgent { get; protected set; }

    /// <summary>
    /// Current conversation history as a list of chat messages.
    /// The getter builds a new list of new <see cref="ChatMessage"/> objects on every call
    /// (an empty list before the agent is initialized), so mutating it does not change the history.
    /// The setter replaces the whole history (null clears it); it is ignored before initialization.
    /// </summary>
    public List<ChatMessage> chat
    {
        get
        {
            if (llmAgent == null) return new List<ChatMessage>();
            // convert each UndreamAI.LlamaLib.ChatMessage to LLMUnity.ChatMessage
            return llmAgent.GetHistory()
                .Select(m => new ChatMessage(m))
                .ToList();
        }
        set
        {
            if (llmAgent != null)
            {
                // convert LLMUnity.ChatMessage back to UndreamAI.LlamaLib.ChatMessage
                var history = value?.Select(m => (UndreamAI.LlamaLib.ChatMessage)m).ToList()
                    ?? new List<UndreamAI.LlamaLib.ChatMessage>();
                llmAgent.SetHistory(history);
            }
        }
    }
    #endregion

    #region Unity Lifecycle and Initialization
    /// <summary>
    /// Unity Awake: registers this agent with the local LLM (when not remote) before running the base initialization.
    /// </summary>
    public override void Awake()
    {
        if (!remote) llm?.Register(this);
        base.Awake();
    }

    /// <summary>
    /// Creates the LlamaLib agent on top of the client created by the base class,
    /// using the current system prompt and role names. Logs an error if construction fails.
    /// </summary>
    protected override async Task SetupCallerObject()
    {
        await base.SetupCallerObject();
        string exceptionMessage = "";
        try
        {
            llmAgent = new UndreamAI.LlamaLib.LLMAgent(llmClient, systemPrompt, userRole, assistantRole);
        }
        catch (Exception ex)
        {
            exceptionMessage = ex.Message;
        }
        if (llmAgent == null || exceptionMessage != "")
        {
            string error = "LLMAgent not initialized";
            if (exceptionMessage != "") error += ", error: " + exceptionMessage;
            LLMUnitySetup.LogError(error, true);
        }
    }

    /// <summary>
    /// Initialisation after setting up the LLM client (local or remote):
    /// applies an explicitly configured slot and initialises the chat history.
    /// </summary>
    protected override async Task PostSetupCallerObject()
    {
        await base.PostSetupCallerObject();
        // -1 means auto-assignment, so only push an explicit slot to the agent.
        if (slot != -1) llmAgent.SlotId = slot;
        await InitHistory();
    }

    /// <summary>
    /// Editor validation: reports an error when the slot is outside [-1, parallelPrompts - 1]
    /// (only checked when the LLM defines a parallelPrompts limit).
    /// </summary>
    protected override void OnValidate()
    {
        base.OnValidate();
        // Validate slot configuration
        if (llm != null && llm.parallelPrompts > -1 && (slot < -1 || slot >= llm.parallelPrompts))
        {
            LLMUnitySetup.LogError($"Slot must be between 0 and {llm.parallelPrompts - 1}, or -1 for auto-assignment");
        }
    }

    /// <summary>Returns the LlamaLib agent as the object used for LLM calls.</summary>
    protected override LLMLocal GetCaller()
    {
        return llmAgent;
    }

    /// <summary>
    /// Initializes conversation history by clearing current state and loading from file if available.
    /// </summary>
    protected virtual async Task InitHistory()
    {
        await ClearHistory();
        if (!string.IsNullOrEmpty(save) && File.Exists(GetSavePath()))
        {
            await LoadHistory();
        }
    }
    #endregion

    #region File Path Management
    /// <summary>
    /// Gets the full path of the chat-history file named by <see cref="save"/>
    /// inside Application.persistentDataPath, using forward slashes.
    /// </summary>
    /// <returns>Full file path, or null (after logging an error) when <see cref="save"/> is empty</returns>
    public virtual string GetSavePath()
    {
        if (string.IsNullOrEmpty(save))
        {
            LLMUnitySetup.LogError("No save path specified");
            return null;
        }
        return Path.Combine(Application.persistentDataPath, save).Replace('\\', '/');
    }
    #endregion

    #region Chat Management
    /// <summary>
    /// Clears the entire conversation history.
    /// </summary>
    public virtual async Task ClearHistory()
    {
        // History is local state, so no connection check is required here.
        await CheckCaller(checkConnection: false);
        llmAgent.ClearHistory();
    }

    /// <summary>
    /// Adds a message with a specific role to the conversation history.
    /// </summary>
    /// <param name="role">Message role (e.g., userRole, assistantRole, or custom role)</param>
    /// <param name="content">Message content</param>
    public virtual async Task AddMessage(string role, string content)
    {
        await CheckCaller();
        llmAgent.AddMessage(role, content);
    }

    /// <summary>
    /// Adds a user message to the conversation history.
    /// </summary>
    /// <param name="content">User message content</param>
    public virtual async Task AddUserMessage(string content)
    {
        await CheckCaller();
        llmAgent.AddUserMessage(content);
    }

    /// <summary>
    /// Adds an AI assistant message to the conversation history.
    /// </summary>
    /// <param name="content">Assistant message content</param>
    public virtual async Task AddAssistantMessage(string content)
    {
        await CheckCaller();
        llmAgent.AddAssistantMessage(content);
    }
    #endregion

    #region Chat Functionality
    /// \cond HIDE
    [Serializable]
    public class CompletionResponseJson
    {
        public string prompt;
        public string content;
    }
    /// \endcond

    /// <summary>
    /// Processes a user query asynchronously and generates an AI response using conversation context.
    /// The query and response are automatically added to chat history if specified, and the
    /// history is then saved in the background when <see cref="save"/> is set.
    /// </summary>
    /// <param name="query">User's message or question</param>
    /// <param name="callback">Optional streaming callback for partial responses</param>
    /// <param name="completionCallback">Optional callback when response is complete</param>
    /// <param name="addToHistory">Whether to add the exchange to conversation history</param>
    /// <returns>Task that returns the AI assistant's response</returns>
    public virtual async Task<string> Chat(string query, LlamaLib.CharArrayCallback callback = null,
        EmptyCallback completionCallback = null, bool addToHistory = true)
    {
        await CheckCaller();
        // Wrap callback to ensure it runs on the main thread
        LlamaLib.CharArrayCallback wrappedCallback = Utils.WrapCallbackForAsync(callback, this);
        SetCompletionParameters();
        string result = await llmAgent.ChatAsync(query, addToHistory, wrappedCallback, returnResponseJson: debugPrompt);
        // In debug mode the agent returns a JSON envelope with the full prompt and the content;
        // guard against an empty/null reply instead of handing it to the JSON parser.
        if (debugPrompt && !string.IsNullOrEmpty(result))
        {
            CompletionResponseJson responseJson = JsonUtility.FromJson<CompletionResponseJson>(result);
            if (responseJson != null)
            {
                LLMUnitySetup.Log(responseJson.prompt);
                result = responseJson.content;
            }
        }
        // Fire-and-forget save; a null/empty `save` means persistence is disabled.
        if (addToHistory && result != null && !string.IsNullOrEmpty(save)) _ = SaveHistory();
        completionCallback?.Invoke();
        return result;
    }

    /// <summary>
    /// Warms up the model by processing the system prompt without generating output.
    /// This caches the system prompt processing for faster subsequent responses.
    /// </summary>
    /// <param name="completionCallback">Optional callback when warmup completes</param>
    /// <returns>Task that completes when warmup finishes</returns>
    public virtual async Task Warmup(EmptyCallback completionCallback = null)
    {
        await Warmup(null, completionCallback);
    }

    /// <summary>
    /// Warms up the model with a specific prompt without adding it to history.
    /// Temporarily sets numPredict to 0 so no tokens are generated, then restores it.
    /// </summary>
    /// <param name="query">Warmup prompt (not added to history)</param>
    /// <param name="completionCallback">Optional callback when warmup completes</param>
    /// <returns>Task that completes when warmup finishes</returns>
    public virtual async Task Warmup(string query, EmptyCallback completionCallback = null)
    {
        int originalNumPredict = numPredict;
        try
        {
            // Set to generate no tokens for warmup
            numPredict = 0;
            await Chat(query, null, completionCallback, false);
        }
        finally
        {
            // Restore original setting even if the chat call failed
            numPredict = originalNumPredict;
            SetCompletionParameters();
        }
    }
    #endregion

    #region Persistence
    /// <summary>
    /// Saves the conversation history to the file named by <see cref="save"/>
    /// (inside Application.persistentDataPath), creating the directory if needed.
    /// Failures are reported through LLMUnitySetup.LogError.
    /// </summary>
    public virtual async Task SaveHistory()
    {
        if (string.IsNullOrEmpty(save))
        {
            LLMUnitySetup.LogError("No save path specified");
            return;
        }
        await CheckCaller();
        // Save chat history
        string jsonPath = GetSavePath();
        try
        {
            // Directory creation is inside the try so I/O failures are reported like save failures
            // (this method is also run fire-and-forget from Chat, where an escaping exception would go unobserved).
            string directory = Path.GetDirectoryName(jsonPath);
            if (!string.IsNullOrEmpty(directory) && !Directory.Exists(directory))
            {
                Directory.CreateDirectory(directory);
            }
            llmAgent.SaveHistory(jsonPath);
            LLMUnitySetup.Log($"Saved chat history to: {jsonPath}");
        }
        catch (Exception ex)
        {
            LLMUnitySetup.LogError($"Failed to save chat history to '{jsonPath}': {ex.Message}", true);
        }
    }

    /// <summary>
    /// Loads the conversation history from the file named by <see cref="save"/>
    /// (inside Application.persistentDataPath). Logs an error and returns if the file does not exist.
    /// </summary>
    public virtual async Task LoadHistory()
    {
        if (string.IsNullOrEmpty(save))
        {
            LLMUnitySetup.LogError("No save path specified");
            return;
        }
        await CheckCaller();
        // Load chat history
        string jsonPath = GetSavePath();
        if (!File.Exists(jsonPath))
        {
            LLMUnitySetup.LogError($"Chat history file not found: {jsonPath}");
            // Nothing to load; avoid a second, misleading "Failed to load" error.
            return;
        }
        try
        {
            llmAgent.LoadHistory(jsonPath);
            LLMUnitySetup.Log($"Loaded chat history from: {jsonPath}");
        }
        catch (Exception ex)
        {
            LLMUnitySetup.LogError($"Failed to load chat history from '{jsonPath}': {ex.Message}", true);
        }
    }
    #endregion

    #region Request Management
    /// <summary>
    /// Cancels any active requests for this agent (no-op if the agent is not initialized).
    /// </summary>
    public void CancelRequests()
    {
        llmAgent?.Cancel();
    }
    #endregion
}
/// <summary>
/// LLMUnity chat message: a role/content pair deriving from the LlamaLib chat message type,
/// so instances can be upcast wherever a LlamaLib message is expected.
/// </summary>
public class ChatMessage : UndreamAI.LlamaLib.ChatMessage
{
    /// <summary>Creates a message with the given role and content.</summary>
    /// <param name="role">Message role (e.g. user or assistant role name)</param>
    /// <param name="content">Message text</param>
    public ChatMessage(string role, string content) : base(role, content) {}

    /// <summary>Creates a message copying the role and content of a LlamaLib chat message.</summary>
    /// <param name="other">Message to copy</param>
    /// <exception cref="ArgumentNullException">Thrown when <paramref name="other"/> is null.</exception>
    public ChatMessage(UndreamAI.LlamaLib.ChatMessage other)
        : base((other ?? throw new ArgumentNullException(nameof(other))).role, other.content) {}
}
}