|
7 | 7 | using System.Collections; |
8 | 8 | using System.IO; |
9 | 9 | using UnityEngine.TestTools; |
| 10 | +using UnityEditor; |
| 11 | +using UnityEditor.TestTools.TestRunner.Api; |
10 | 12 |
|
11 | 13 | namespace LLMUnityTests |
12 | 14 | { |
| 15 | + [InitializeOnLoad] |
| 16 | + public static class TestRunListener |
| 17 | + { |
| 18 | + static TestRunListener() |
| 19 | + { |
| 20 | + var api = ScriptableObject.CreateInstance<TestRunnerApi>(); |
| 21 | + api.RegisterCallbacks(new TestRunCallbacks()); |
| 22 | + } |
| 23 | + } |
| 24 | + |
| 25 | + public class TestRunCallbacks : ICallbacks |
| 26 | + { |
| 27 | + public void RunStarted(ITestAdaptor testsToRun){} |
| 28 | + |
| 29 | + public void RunFinished(ITestResultAdaptor result) |
| 30 | + { |
| 31 | + LLMUnitySetup.FullLlamaLib = false; |
| 32 | + } |
| 33 | + |
| 34 | + public void TestStarted(ITestAdaptor test) { |
| 35 | + LLMUnitySetup.FullLlamaLib = test.FullName.Contains("CUDA_full"); |
| 36 | + } |
| 37 | + |
| 38 | + public void TestFinished(ITestResultAdaptor result) |
| 39 | + { |
| 40 | + LLMUnitySetup.FullLlamaLib = false; |
| 41 | + } |
| 42 | + } |
| 43 | + |
13 | 44 | public class TestLLMLoraAssignment |
14 | 45 | { |
15 | 46 | [Test] |
@@ -459,6 +490,10 @@ public override LLMCharacter CreateLLMCharacter() |
459 | 490 | LLMCharacter llmCharacter = base.CreateLLMCharacter(); |
460 | 491 | llmCharacter.save = saveName; |
461 | 492 | llmCharacter.saveCache = true; |
| 493 | + foreach (string filename in new string[]{ |
| 494 | + llmCharacter.GetJsonSavePath(saveName), |
| 495 | + llmCharacter.GetCacheSavePath(saveName) |
| 496 | + }) if (File.Exists(filename)) File.Delete(filename); |
462 | 497 | return llmCharacter; |
463 | 498 | } |
464 | 499 |
|
@@ -492,4 +527,41 @@ public void TestSave() |
492 | 527 | } |
493 | 528 | } |
494 | 529 | } |
| 530 | + |
| 531 | + public class TestLLM_CUDA : TestLLM |
| 532 | + { |
| 533 | + public override LLM CreateLLM() |
| 534 | + { |
| 535 | + LLM llm = base.CreateLLM(); |
| 536 | + llm.numGPULayers = 10; |
| 537 | + return llm; |
| 538 | + } |
| 539 | + } |
| 540 | + |
| 541 | + public class TestLLM_CUDA_full : TestLLM_CUDA |
| 542 | + { |
| 543 | + public override void SetParameters() |
| 544 | + { |
| 545 | + base.SetParameters(); |
| 546 | + reply1 = "To increase your meme production output, you might consider using more advanced tools and techniques to generate memes faster"; |
| 547 | + reply2 = "To increase your meme production output, you might consider using more advanced tools and techniques to generate memes faster"; |
| 548 | + } |
| 549 | + } |
| 550 | + |
| 551 | + public class TestLLM_CUDA_full_attention : TestLLM_CUDA |
| 552 | + { |
| 553 | + public override LLM CreateLLM() |
| 554 | + { |
| 555 | + LLM llm = base.CreateLLM(); |
| 556 | + llm.flashAttention = true; |
| 557 | + return llm; |
| 558 | + } |
| 559 | + |
| 560 | + public override void SetParameters() |
| 561 | + { |
| 562 | + base.SetParameters(); |
| 563 | + reply1 = "To increase your meme production output, you might consider using more advanced tools and techniques to generate memes faster"; |
| 564 | + reply2 = "To increase your meme production output, you can try using various tools and techniques to generate more memes."; |
| 565 | + } |
| 566 | + } |
495 | 567 | } |
0 commit comments