Commit 9cce1c2

add lora test
1 parent f51d09f commit 9cce1c2

1 file changed

Tests/Runtime/TestLLM.cs

+119 −38
@@ -23,6 +23,7 @@ public void TestLLMLorasAssign()
             string lora2Rel = "test/lala";
             string lora2 = LLMUnitySetup.GetAssetPath(lora2Rel);
             LLMUnitySetup.CreateEmptyFile(lora1);
+            Directory.CreateDirectory(Path.GetDirectoryName(lora2));
             LLMUnitySetup.CreateEmptyFile(lora2);

             llm.AddLora(lora1);
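
The added Directory.CreateDirectory line matters because the second LoRA path ("test/lala") sits in a subfolder that does not exist yet, and creating an empty file inside a missing directory throws. A minimal stand-alone sketch of the same guard using plain System.IO (the helper name below is hypothetical, not an LLMUnitySetup API):

    using System.IO;

    public static class TestFileUtils
    {
        // Hypothetical helper: make sure the parent folder exists, then create the empty placeholder file.
        public static void CreateEmptyFileWithParents(string path)
        {
            string dir = Path.GetDirectoryName(path);
            // CreateDirectory builds all missing intermediate folders and is a no-op if they already exist.
            if (!string.IsNullOrEmpty(dir)) Directory.CreateDirectory(dir);
            using (File.Create(path)) { }
        }
    }
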
@@ -63,14 +64,19 @@ public void TestLLMLorasAssign()

     public class TestLLM
     {
-        protected static string modelUrl = "https://huggingface.co/afrideva/smol_llama-220M-openhermes-GGUF/resolve/main/smol_llama-220m-openhermes.q4_k_m.gguf?download=true";
+        protected static string modelUrl = "https://huggingface.co/Qwen/Qwen2-0.5B-Instruct-GGUF/resolve/main/qwen2-0_5b-instruct-q4_k_m.gguf?download=true";
         protected string modelNameLLManager;

         protected GameObject gameObject;
         protected LLM llm;
         protected LLMCharacter llmCharacter;
         Exception error = null;
-        string prompt = "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.";
+        protected string prompt;
+        protected string query;
+        protected string reply1;
+        protected string reply2;
+        protected int tokens1;
+        protected int tokens2;


         public TestLLM()
@@ -81,14 +87,30 @@ public TestLLM()

         public virtual async Task Init()
         {
-            modelNameLLManager = await LLMManager.DownloadModel(modelUrl);
+            SetParameters();
+            await DownloadModels();
             gameObject = new GameObject();
             gameObject.SetActive(false);
-            SetLLM();
-            SetLLMCharacter();
+            llm = CreateLLM();
+            llmCharacter = CreateLLMCharacter();
             gameObject.SetActive(true);
         }

+        public virtual void SetParameters()
+        {
+            prompt = "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.";
+            query = "How can I increase my meme production/output? Currently, I only create them in ancient babylonian which is time consuming.";
+            reply1 = "To increase your meme production/output, you can try using more modern tools and techniques. For instance,";
+            reply2 = "To increase your meme production/output, you can try the following strategies:\n\n1. Use a meme generator";
+            tokens1 = 32;
+            tokens2 = 9;
+        }
+
+        public virtual async Task DownloadModels()
+        {
+            modelNameLLManager = await LLMManager.DownloadModel(modelUrl);
+        }
+
         [Test]
         public void TestGetLLMManagerAssetRuntime()
         {
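
The Init() refactor turns setup into overridable hooks: SetParameters() defines the expected prompt, query, replies, and token counts, DownloadModels() fetches the assets, and CreateLLM()/CreateLLMCharacter() build the components. A sketch of how a derived test plugs into these hooks (the class name and expected values below are made up for illustration):

    // Illustrative subclass only; values are placeholders, not real expected outputs.
    public class TestLLM_Example : TestLLM
    {
        public override void SetParameters()
        {
            prompt = "";                      // system prompt given to the character
            query = "hello";                  // user message sent during RunTestsTask
            reply1 = "expected first reply";  // ground-truth replies checked by TestChat
            reply2 = "expected second reply";
            tokens1 = 5;                      // expected values passed to TestInitParameters
            tokens2 = 9;
        }

        public override LLM CreateLLM()
        {
            LLM llm = base.CreateLLM();       // base implementation sets the model and parallelPrompts
            // per-test tweaks to the LLM component would go here
            return llm;
        }
    }
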
@@ -144,17 +166,17 @@ public void TestGetLLMManagerAssetEditor()
             File.Delete(path);
         }

-        public virtual void SetLLM()
+        public virtual LLM CreateLLM()
         {
-            llm = gameObject.AddComponent<LLM>();
+            LLM llm = gameObject.AddComponent<LLM>();
             llm.SetModel(modelNameLLManager);
             llm.parallelPrompts = 1;
-            llm.SetTemplate("alpaca");
+            return llm;
         }

-        public virtual void SetLLMCharacter()
+        public virtual LLMCharacter CreateLLMCharacter()
         {
-            llmCharacter = gameObject.AddComponent<LLMCharacter>();
+            LLMCharacter llmCharacter = gameObject.AddComponent<LLMCharacter>();
             llmCharacter.llm = llm;
             llmCharacter.playerName = "Instruction";
             llmCharacter.AIName = "Response";
@@ -163,6 +185,7 @@ public virtual void SetLLMCharacter()
             llmCharacter.seed = 0;
             llmCharacter.stream = false;
             llmCharacter.numPredict = 20;
+            return llmCharacter;
         }

         [UnityTest]
@@ -183,26 +206,22 @@ public async Task RunTestsTask()
             error = null;
             try
             {
-                // await llm.WaitUntilReady();
-
-                // llm.Awake();
-                // llmCharacter.Awake();
                 await llmCharacter.Tokenize("I", TestTokens);
                 await llmCharacter.Warmup();
-                TestInitParameters((await llmCharacter.Tokenize(prompt)).Count + 2, 1);
+                TestInitParameters(tokens1, 1);
                 TestWarmup();
-                await llmCharacter.Chat("How can I increase my meme production/output? Currently, I only create them in ancient babylonian which is time consuming.", TestChat);
+                await llmCharacter.Chat(query, (string reply) => TestChat(reply, reply1));
                 TestPostChat(3);
                 llmCharacter.SetPrompt(llmCharacter.prompt);
                 llmCharacter.AIName = "False response";
-                await llmCharacter.Chat("How can I increase my meme production/output? Currently, I only create them in ancient babylonian which is time consuming.", TestChat2);
+                await llmCharacter.Chat(query, (string reply) => TestChat(reply, reply2));
                 TestPostChat(3);
                 await llmCharacter.Chat("bye!");
                 TestPostChat(5);
                 prompt = "How are you?";
                 llmCharacter.SetPrompt(prompt);
                 await llmCharacter.Chat("hi");
-                TestInitParameters((await llmCharacter.Tokenize(prompt)).Count + 2, 3);
+                TestInitParameters(tokens2, 3);
                 List<float> embeddings = await llmCharacter.Embeddings("hi how are you?");
                 TestEmbeddings(embeddings);
                 llm.OnDestroy();
@@ -222,24 +241,17 @@ public void TestInitParameters(int nkeep, int chats)

         public void TestTokens(List<int> tokens)
         {
-            Assert.AreEqual(tokens, new List<int> {306});
+            Assert.AreEqual(tokens, new List<int> {40});
         }

         public void TestWarmup()
         {
             Assert.That(llmCharacter.chat.Count == 1);
         }

-        public void TestChat(string reply)
-        {
-            string AIReply = "One way to increase your meme production/output is by creating a more complex and customized";
-            Assert.That(reply.Trim() == AIReply);
-        }
-
-        public void TestChat2(string reply)
+        public void TestChat(string reply, string replyGT)
         {
-            string AIReply = "One possible solution is to use a more advanced natural language processing library like NLTK or sp";
-            Assert.That(reply.Trim() == AIReply);
+            Assert.That(reply.Trim() == replyGT);
         }

         public void TestPostChat(int num)
@@ -249,40 +261,40 @@ public void TestPostChat(int num)

         public void TestEmbeddings(List<float> embeddings)
        {
-            Assert.That(embeddings.Count == 1024);
+            Assert.That(embeddings.Count == 896);
         }

         public virtual void OnDestroy() {}
     }

     public class TestLLM_LLMManager_Load : TestLLM
     {
-        public override void SetLLM()
+        public override LLM CreateLLM()
         {
-            llm = gameObject.AddComponent<LLM>();
+            LLM llm = gameObject.AddComponent<LLM>();
             string filename = Path.GetFileName(modelUrl).Split("?")[0];
             string sourcePath = Path.Combine(LLMUnitySetup.modelDownloadPath, filename);
             filename = LLMManager.LoadModel(sourcePath);
             llm.SetModel(filename);
             llm.parallelPrompts = 1;
-            llm.SetTemplate("alpaca");
+            return llm;
         }
     }

     public class TestLLM_StreamingAssets_Load : TestLLM
     {
         string loadPath;

-        public override void SetLLM()
+        public override LLM CreateLLM()
         {
-            llm = gameObject.AddComponent<LLM>();
+            LLM llm = gameObject.AddComponent<LLM>();
             string filename = Path.GetFileName(modelUrl).Split("?")[0];
             string sourcePath = Path.Combine(LLMUnitySetup.modelDownloadPath, filename);
             loadPath = LLMUnitySetup.GetAssetPath(filename);
             if (!File.Exists(loadPath)) File.Copy(sourcePath, loadPath);
             llm.SetModel(loadPath);
             llm.parallelPrompts = 1;
-            llm.SetTemplate("alpaca");
+            return llm;
         }

         public override void OnDestroy()
@@ -293,14 +305,83 @@ public override void OnDestroy()

     public class TestLLM_SetModel_Warning : TestLLM
     {
-        public override void SetLLM()
+        public override LLM CreateLLM()
         {
-            llm = gameObject.AddComponent<LLM>();
+            LLM llm = gameObject.AddComponent<LLM>();
             string filename = Path.GetFileName(modelUrl).Split("?")[0];
             string loadPath = Path.Combine(LLMUnitySetup.modelDownloadPath, filename);
             llm.SetModel(loadPath);
             llm.parallelPrompts = 1;
-            llm.SetTemplate("alpaca");
+            return llm;
+        }
+    }
+
+    public class TestLLM_NoLora : TestLLM
+    {
+        public override void SetParameters()
+        {
+            prompt = "";
+            query = "кто ты?";
+            reply1 = "Я - искусственный интеллект, который помогаю вам с информацией и задачами";
+            reply2 = "I'm sorry, but I didn't understand your request. Could you please provide more information or clarify";
+            tokens1 = 5;
+            tokens2 = 9;
+        }
+    }
+
+    public class TestLLM_Lora : TestLLM
+    {
+        string loraUrl = "https://huggingface.co/undreamer/Qwen2-0.5B-Instruct-ru-lora/resolve/main/Qwen2-0.5B-Instruct-ru-lora.gguf?download=true";
+        string loraNameLLManager;
+
+        public override async Task DownloadModels()
+        {
+            await base.DownloadModels();
+            loraNameLLManager = await LLMManager.DownloadLora(loraUrl);
+        }
+
+        public override LLM CreateLLM()
+        {
+            LLM llm = base.CreateLLM();
+            llm.AddLora(loraNameLLManager);
+            return llm;
+        }
+
+        public override void SetParameters()
+        {
+            prompt = "";
+            query = "кто ты?";
+            reply1 = "Я - искусственный интеллект, созданный для помощи и общения с людьми";
+            reply2 = "Идиот";
+            tokens1 = 5;
+            tokens2 = 9;
+        }
+
+        [Test]
+        public void TestModelPaths()
+        {
+            Assert.AreEqual(llm.model, Path.Combine(LLMUnitySetup.modelDownloadPath, Path.GetFileName(modelUrl).Split("?")[0]));
+            Assert.AreEqual(llm.lora, Path.Combine(LLMUnitySetup.modelDownloadPath, Path.GetFileName(loraUrl).Split("?")[0]));
+        }
+    }
+
+
+    public class TestLLM_Double : TestLLM
+    {
+        LLM llm1;
+        LLMCharacter lLMCharacter1;
+
+        public override async Task Init()
+        {
+            SetParameters();
+            await DownloadModels();
+            gameObject = new GameObject();
+            gameObject.SetActive(false);
+            llm = CreateLLM();
+            llmCharacter = CreateLLMCharacter();
+            llm1 = CreateLLM();
+            lLMCharacter1 = CreateLLMCharacter();
+            gameObject.SetActive(true);
         }
     }
 }
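
For reference, the runtime flow that the new TestLLM_Lora class exercises, condensed into a plain MonoBehaviour. This is a sketch based only on the calls visible in this diff (DownloadModel, DownloadLora, SetModel, AddLora, Chat); readiness/warmup handling and error checks are omitted, and the namespace import is an assumption:

    using UnityEngine;
    using LLMUnity;   // assumed namespace of LLM, LLMCharacter and LLMManager

    public class LoraChatExample : MonoBehaviour
    {
        async void Start()
        {
            // download the base model and the LoRA adapter through LLMManager (URLs taken from the test above)
            string model = await LLMManager.DownloadModel("https://huggingface.co/Qwen/Qwen2-0.5B-Instruct-GGUF/resolve/main/qwen2-0_5b-instruct-q4_k_m.gguf?download=true");
            string lora = await LLMManager.DownloadLora("https://huggingface.co/undreamer/Qwen2-0.5B-Instruct-ru-lora/resolve/main/Qwen2-0.5B-Instruct-ru-lora.gguf?download=true");

            // create the LLM and attach the adapter
            LLM llm = gameObject.AddComponent<LLM>();
            llm.SetModel(model);
            llm.AddLora(lora);

            // wire a character to the LLM and chat, as the test does with its Russian query
            LLMCharacter character = gameObject.AddComponent<LLMCharacter>();
            character.llm = llm;
            await character.Chat("кто ты?", (string reply) => Debug.Log(reply));
        }
    }
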
