Skip to content

Commit b14f614

Browse files
committed
CUDA / full / fa tests
1 parent 418d70f commit b14f614

1 file changed

Lines changed: 72 additions & 0 deletions

File tree

Tests/Runtime/TestLLM.cs

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,40 @@
77
using System.Collections;
88
using System.IO;
99
using UnityEngine.TestTools;
10+
using UnityEditor;
11+
using UnityEditor.TestTools.TestRunner.Api;
1012

1113
namespace LLMUnityTests
1214
{
15+
[InitializeOnLoad]
16+
public static class TestRunListener
17+
{
18+
static TestRunListener()
19+
{
20+
var api = ScriptableObject.CreateInstance<TestRunnerApi>();
21+
api.RegisterCallbacks(new TestRunCallbacks());
22+
}
23+
}
24+
25+
public class TestRunCallbacks : ICallbacks
26+
{
27+
public void RunStarted(ITestAdaptor testsToRun){}
28+
29+
public void RunFinished(ITestResultAdaptor result)
30+
{
31+
LLMUnitySetup.FullLlamaLib = false;
32+
}
33+
34+
public void TestStarted(ITestAdaptor test) {
35+
LLMUnitySetup.FullLlamaLib = test.FullName.Contains("CUDA_full");
36+
}
37+
38+
public void TestFinished(ITestResultAdaptor result)
39+
{
40+
LLMUnitySetup.FullLlamaLib = false;
41+
}
42+
}
43+
1344
public class TestLLMLoraAssignment
1445
{
1546
[Test]
@@ -459,6 +490,10 @@ public override LLMCharacter CreateLLMCharacter()
459490
LLMCharacter llmCharacter = base.CreateLLMCharacter();
460491
llmCharacter.save = saveName;
461492
llmCharacter.saveCache = true;
493+
foreach (string filename in new string[]{
494+
llmCharacter.GetJsonSavePath(saveName),
495+
llmCharacter.GetCacheSavePath(saveName)
496+
}) if (File.Exists(filename)) File.Delete(filename);
462497
return llmCharacter;
463498
}
464499

@@ -492,4 +527,41 @@ public void TestSave()
492527
}
493528
}
494529
}
530+
531+
public class TestLLM_CUDA : TestLLM
532+
{
533+
public override LLM CreateLLM()
534+
{
535+
LLM llm = base.CreateLLM();
536+
llm.numGPULayers = 10;
537+
return llm;
538+
}
539+
}
540+
541+
public class TestLLM_CUDA_full : TestLLM_CUDA
542+
{
543+
public override void SetParameters()
544+
{
545+
base.SetParameters();
546+
reply1 = "To increase your meme production output, you might consider using more advanced tools and techniques to generate memes faster";
547+
reply2 = "To increase your meme production output, you might consider using more advanced tools and techniques to generate memes faster";
548+
}
549+
}
550+
551+
public class TestLLM_CUDA_full_attention : TestLLM_CUDA
552+
{
553+
public override LLM CreateLLM()
554+
{
555+
LLM llm = base.CreateLLM();
556+
llm.flashAttention = true;
557+
return llm;
558+
}
559+
560+
public override void SetParameters()
561+
{
562+
base.SetParameters();
563+
reply1 = "To increase your meme production output, you might consider using more advanced tools and techniques to generate memes faster";
564+
reply2 = "To increase your meme production output, you can try using various tools and techniques to generate more memes.";
565+
}
566+
}
495567
}

0 commit comments

Comments
 (0)