Skip to content

Commit b499369

Browse files
committed
review: use generateRange function for gpus
Also update the volume to be added to the container. This is still pending testing on AWS! Signed-off-by: vsoch <vsoch@users.noreply.github.com>
1 parent b9cee29 commit b499369

File tree

2 files changed

+4
-6
lines changed

2 files changed

+4
-6
lines changed

pkg/constants/constants.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,7 @@ const (
116116

117117
// Ensure MPI has full memory of the host
118118
FluxMemoryVolumeName = "shared-memory"
119+
FluxMemoryVolumePath = "/dev/shm"
119120

120121
// emptyDir volume used for the complete spack view software
121122
FluxSpackViewVolumeName = "spack-install"

pkg/runtime/framework/plugins/flux/flux.go

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -201,6 +201,7 @@ func (f *Flux) EnforceMLPolicy(info *runtime.Info, trainJob *trainer.TrainJob) e
201201
*corev1ac.VolumeMount().WithName(constants.FluxSpackViewVolumeName).WithMountPath(constants.FluxSpackViewVolumePath),
202202
*corev1ac.VolumeMount().WithName(configMapName).WithMountPath(constants.FluxConfigVolumeName).WithReadOnly(true),
203203
*corev1ac.VolumeMount().WithName(constants.FluxCurveVolumeName).WithMountPath(constants.FluxCurveVolumePath).WithReadOnly(true),
204+
*corev1ac.VolumeMount().WithName(constants.FluxMemoryVolumeName).WithMountPath(constants.FluxMemoryVolumePath).WithReadOnly(true),
204205
)
205206
}
206207
}
@@ -429,12 +430,8 @@ func (f *Flux) generateFluxEntrypoint(trainJob *trainer.TrainJob, info *runtime.
429430
Rspec := fmt.Sprintf("--cores=0-%d", tasks-1)
430431
if gpus > 0 {
431432
flags = fmt.Sprintf("%s -g %d", flags, gpus)
432-
gpus = gpus - 1
433-
if gpus == 0 {
434-
Rspec = fmt.Sprintf("%s --gpu=0", Rspec)
435-
} else {
436-
Rspec = fmt.Sprintf("%s --gpu=0-%d", Rspec, gpus)
437-
}
433+
gpuSpec := generateRange(int32(gpus), 0)
434+
Rspec = fmt.Sprintf("%s --gpu=%s", Rspec, gpuSpec)
438435
}
439436
return fmt.Sprintf(entrypointTemplate, Rspec, mainHost, flags)
440437
}

0 commit comments

Comments
 (0)