@@ -17,6 +17,7 @@ limitations under the License.
1717package core
1818
1919import (
20+ "fmt"
2021 "testing"
2122
2223 . "github.com/onsi/gomega"
@@ -31,17 +32,17 @@ import (
3132)
3233
3334func TestPytorchjobWithSFTtrainerFinetuning (t * testing.T ) {
34- runPytorchjobWithSFTtrainer (t , "config.json" )
35+ runPytorchjobWithSFTtrainer (t , "config.json" , 0 )
3536}
3637
3738func TestPytorchjobWithSFTtrainerLoRa (t * testing.T ) {
38- runPytorchjobWithSFTtrainer (t , "config_lora.json" )
39+ runPytorchjobWithSFTtrainer (t , "config_lora.json" , 0 )
3940}
4041func TestPytorchjobWithSFTtrainerQLoRa (t * testing.T ) {
41- runPytorchjobWithSFTtrainer (t , "config_qlora.json" )
42+ runPytorchjobWithSFTtrainer (t , "config_qlora.json" , 1 )
4243}
4344
44- func runPytorchjobWithSFTtrainer (t * testing.T , modelConfigFile string ) {
45+ func runPytorchjobWithSFTtrainer (t * testing.T , modelConfigFile string , numGpus int ) {
4546 test := With (t )
4647
4748 // Create a namespace
@@ -61,7 +62,7 @@ func runPytorchjobWithSFTtrainer(t *testing.T, modelConfigFile string) {
6162 NamespaceSelector : & metav1.LabelSelector {},
6263 ResourceGroups : []kueuev1beta1.ResourceGroup {
6364 {
64- CoveredResources : []corev1.ResourceName {corev1 .ResourceName ("cpu" ), corev1 .ResourceName ("memory" )},
65+ CoveredResources : []corev1.ResourceName {corev1 .ResourceName ("cpu" ), corev1 .ResourceName ("memory" ), corev1 . ResourceName ( "nvidia.com/gpu" ) },
6566 Flavors : []kueuev1beta1.FlavorQuotas {
6667 {
6768 Name : kueuev1beta1 .ResourceFlavorReference (resourceFlavor .Name ),
@@ -74,6 +75,10 @@ func runPytorchjobWithSFTtrainer(t *testing.T, modelConfigFile string) {
7475 Name : corev1 .ResourceMemory ,
7576 NominalQuota : resource .MustParse ("12Gi" ),
7677 },
78+ {
79+ Name : corev1 .ResourceName ("nvidia.com/gpu" ),
80+ NominalQuota : resource .MustParse (fmt .Sprint (numGpus )),
81+ },
7782 },
7883 },
7984 },
@@ -85,7 +90,7 @@ func runPytorchjobWithSFTtrainer(t *testing.T, modelConfigFile string) {
8590 localQueue := CreateKueueLocalQueue (test , namespace .Name , clusterQueue .Name , AsDefaultQueue )
8691
8792 // Create training PyTorch job
88- tuningJob := createPyTorchJob (test , namespace .Name , localQueue .Name , * config )
93+ tuningJob := createPyTorchJob (test , namespace .Name , localQueue .Name , * config , numGpus )
8994
9095 // Make sure the Kueue Workload is admitted
9196 test .Eventually (KueueWorkloads (test , namespace .Name ), TestTimeoutLong ).
@@ -149,14 +154,14 @@ func TestPytorchjobUsingKueueQuota(t *testing.T) {
149154 localQueue := CreateKueueLocalQueue (test , namespace .Name , clusterQueue .Name , AsDefaultQueue )
150155
151156 // Create first training PyTorch job
152- tuningJob := createPyTorchJob (test , namespace .Name , localQueue .Name , * config )
157+ tuningJob := createPyTorchJob (test , namespace .Name , localQueue .Name , * config , 0 )
153158
154159 // Make sure the PyTorch job is running
155160 test .Eventually (PytorchJob (test , namespace .Name , tuningJob .Name ), TestTimeoutLong ).
156161 Should (WithTransform (PytorchJobConditionRunning , Equal (corev1 .ConditionTrue )))
157162
158163 // Create second training PyTorch job
159- secondTuningJob := createPyTorchJob (test , namespace .Name , localQueue .Name , * config )
164+ secondTuningJob := createPyTorchJob (test , namespace .Name , localQueue .Name , * config , 0 )
160165
161166 // Make sure the second PyTorch job is suspended, waiting for first job to finish
162167 test .Eventually (PytorchJob (test , namespace .Name , secondTuningJob .Name ), TestTimeoutShort ).
@@ -175,7 +180,7 @@ func TestPytorchjobUsingKueueQuota(t *testing.T) {
175180 test .T ().Logf ("PytorchJob %s/%s ran successfully" , secondTuningJob .Namespace , secondTuningJob .Name )
176181}
177182
178- func createPyTorchJob (test Test , namespace , localQueueName string , config corev1.ConfigMap ) * kftov1.PyTorchJob {
183+ func createPyTorchJob (test Test , namespace , localQueueName string , config corev1.ConfigMap , numGpus int ) * kftov1.PyTorchJob {
179184 tuningJob := & kftov1.PyTorchJob {
180185 TypeMeta : metav1.TypeMeta {
181186 APIVersion : corev1 .SchemeGroupVersion .String (),
@@ -194,6 +199,12 @@ func createPyTorchJob(test Test, namespace, localQueueName string, config corev1
194199 RestartPolicy : "OnFailure" ,
195200 Template : corev1.PodTemplateSpec {
196201 Spec : corev1.PodSpec {
202+ Tolerations : []corev1.Toleration {
203+ {
204+ Key : "nvidia.com/gpu" ,
205+ Operator : corev1 .TolerationOpExists ,
206+ },
207+ },
197208 InitContainers : []corev1.Container {
198209 {
199210 Name : "copy-model" ,
@@ -238,10 +249,12 @@ func createPyTorchJob(test Test, namespace, localQueueName string, config corev1
238249 Requests : corev1.ResourceList {
239250 corev1 .ResourceCPU : resource .MustParse ("2" ),
240251 corev1 .ResourceMemory : resource .MustParse ("7Gi" ),
252+ "nvidia.com/gpu" : resource .MustParse (fmt .Sprint (numGpus )),
241253 },
242254 Limits : corev1.ResourceList {
243255 corev1 .ResourceCPU : resource .MustParse ("2" ),
244256 corev1 .ResourceMemory : resource .MustParse ("7Gi" ),
257+ "nvidia.com/gpu" : resource .MustParse (fmt .Sprint (numGpus )),
245258 },
246259 },
247260 SecurityContext : & corev1.SecurityContext {
0 commit comments