
Commit 252bdca

Revert "Always use local rank 0 to import the container image"
This change introduced a potential deadlock when the job is cancelled during the image import. Local rank 0 would then get terminated without signalling the other local ranks.

This reverts commit 6170baf.

Signed-off-by: Felix Abecassis <fabecassis@nvidia.com>
1 parent 8ffbea8 commit 252bdca
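The hazard described above can be reproduced outside of Slurm. Below is a minimal standalone sketch, not pyxis code: the "importer" and "waiter" processes, the sleep durations, and the SIGKILL are made up for illustration, error handling is omitted, and the program intentionally hangs at the final waitpid() to show the deadlock (build with cc -pthread).

/* Sketch of the reverted scheme: one process is expected to set `initialized`
 * and broadcast, the others wait on a process-shared condition variable.
 * If the broadcaster is killed first, pthread_cond_wait() never returns. */
#include <pthread.h>
#include <signal.h>
#include <stdbool.h>
#include <stdio.h>
#include <sys/mman.h>
#include <sys/wait.h>
#include <unistd.h>

struct shared {
        pthread_mutex_t mutex;
        pthread_cond_t cond;
        bool initialized;
};

int main(void)
{
        struct shared *shm;
        pthread_mutexattr_t ma;
        pthread_condattr_t ca;
        pid_t importer, waiter;

        shm = mmap(NULL, sizeof(*shm), PROT_READ | PROT_WRITE,
                   MAP_SHARED | MAP_ANONYMOUS, -1, 0);

        pthread_mutexattr_init(&ma);
        pthread_mutexattr_setpshared(&ma, PTHREAD_PROCESS_SHARED);
        pthread_mutex_init(&shm->mutex, &ma);
        pthread_condattr_init(&ca);
        pthread_condattr_setpshared(&ca, PTHREAD_PROCESS_SHARED);
        pthread_cond_init(&shm->cond, &ca);
        shm->initialized = false;

        importer = fork();
        if (importer == 0) {
                sleep(10);                      /* "importing the image" */
                pthread_mutex_lock(&shm->mutex);
                shm->initialized = true;        /* never reached: killed below */
                pthread_cond_broadcast(&shm->cond);
                pthread_mutex_unlock(&shm->mutex);
                _exit(0);
        }

        waiter = fork();
        if (waiter == 0) {
                pthread_mutex_lock(&shm->mutex);
                while (!shm->initialized)       /* blocks forever once the importer is gone */
                        pthread_cond_wait(&shm->cond, &shm->mutex);
                pthread_mutex_unlock(&shm->mutex);
                _exit(0);
        }

        sleep(1);
        kill(importer, SIGKILL);                /* job cancelled mid-import */
        waitpid(importer, NULL, 0);
        printf("importer killed; waiter %d is stuck in pthread_cond_wait()\n", (int)waiter);
        waitpid(waiter, NULL, 0);               /* never returns */
        return 0;
}

This mirrors the pthread_cond_wait() loop removed in the diff below, where local rank 0 could be terminated during the import before reaching the broadcast.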

1 file changed: +22 / -55 lines


pyxis_slurmstepd.c

Lines changed: 22 additions & 55 deletions
@@ -49,16 +49,14 @@ struct job_info {
         gid_t gid;
         uint32_t jobid;
         uint32_t stepid;
-        uint32_t taskid;
         uint32_t local_task_count;
         char **environ;
         char cwd[PATH_MAX];
 };

 struct shared_memory {
         pthread_mutex_t mutex;
-        pthread_cond_t cond;
-        bool initialized;
+        uint32_t init_tasks;
         uint32_t started_tasks;
         pid_t pid;
 };
@@ -79,7 +77,7 @@ static struct plugin_context context = {
         .log_fd = -1,
         .config = { .runtime_path = { 0 } },
         .args = NULL,
-        .job = { .uid = -1, .gid = -1, .jobid = 0, .stepid = 0, .taskid = -1, .environ = NULL, .cwd = { 0 } },
+        .job = { .uid = -1, .gid = -1, .jobid = 0, .stepid = 0, .environ = NULL, .cwd = { 0 } },
         .container = { .name = NULL, .save_path = NULL, .reuse_rootfs = false, .reuse_pid = false, .temporary = false, .userns_fd = -1, .mntns_fd = -1, .cgroupns_fd = -1, .cwd_fd = -1 },
         .user_init_rv = 0,
 };
@@ -767,7 +765,6 @@ static struct shared_memory *shm_init(void)
 {
         struct shared_memory *shm = NULL;
         pthread_mutexattr_t mutex_attr;
-        pthread_condattr_t cond_attr;
         int ret;

         shm = mmap(0, sizeof(*shm), PROT_READ | PROT_WRITE, MAP_SHARED | MAP_ANONYMOUS, -1, 0);
@@ -785,24 +782,10 @@ static struct shared_memory *shm_init(void)
                 goto fail;

         ret = pthread_mutex_init(&shm->mutex, &mutex_attr);
-        pthread_mutexattr_destroy(&mutex_attr);
         if (ret < 0)
                 goto fail;

-        ret = pthread_condattr_init(&cond_attr);
-        if (ret < 0)
-                goto fail;
-
-        ret = pthread_condattr_setpshared(&cond_attr, PTHREAD_PROCESS_SHARED);
-        if (ret < 0)
-                goto fail;
-
-        ret = pthread_cond_init(&shm->cond, &cond_attr);
-        pthread_condattr_destroy(&cond_attr);
-        if (ret < 0)
-                goto fail;
-
-        shm->initialized = false;
+        shm->init_tasks = 0;
         shm->started_tasks = 0;
         shm->pid = -1;

@@ -825,10 +808,6 @@ static int shm_destroy(struct shared_memory *shm)
         if (ret < 0)
                 return (-1);

-        ret = pthread_cond_destroy(&shm->cond);
-        if (ret < 0)
-                return (-1);
-
         ret = munmap(shm, sizeof(*shm));
         if (ret < 0)
                 return (-1);
@@ -975,40 +954,35 @@ static int pytorch_setup(spank_t sp)

 static int enroot_start_once(struct container *container, struct shared_memory *shm)
 {
-        bool created = false;
+        int ret;
+        int rv = -1;

-        if (container->reuse_pid)
-                goto done;
+        pthread_mutex_lock(&shm->mutex);

-        if (container->reuse_rootfs)
-                created = true;
+        shm->init_tasks += 1;

-        pthread_mutex_lock(&shm->mutex);
+        /* The first task will create and/or start the enroot container */
+        if (shm->init_tasks == 1) {
+                if (!container->reuse_pid) {
+                        if (!container->reuse_rootfs) {
+                                ret = enroot_container_create();
+                                if (ret < 0)
+                                        goto fail;
+                        }

-        /* Local rank 0 will create and/or start the enroot container */
-        if (context.job.taskid == 0) {
-                if (!created) {
-                        if (enroot_container_create() == 0)
-                                created = true;
+                        shm->pid = enroot_container_start();
                 }
+        }

-                if (created)
-                        shm->pid = enroot_container_start();
+        if (shm->pid < 0)
+                goto fail;

-                shm->initialized = true;
-                pthread_cond_broadcast(&shm->cond);
-        } else {
-                while (!shm->initialized)
-                        pthread_cond_wait(&shm->cond, &shm->mutex);
-        }
+        rv = 0;

+fail:
         pthread_mutex_unlock(&shm->mutex);

-done:
-        if (shm->pid < 0)
-                return (-1);
-
-        return (0);
+        return (rv);
 }

 static int enroot_stop_once(struct container *container, struct shared_memory *shm)
@@ -1041,7 +1015,6 @@ static int enroot_stop_once(struct container *container, struct shared_memory *s
 int slurm_spank_task_init(spank_t sp, int ac, char **av)
 {
         int ret;
-        spank_err_t spank_ret;
         int rv = -1;

         if (!context.enabled)
@@ -1055,12 +1028,6 @@ int slurm_spank_task_init(spank_t sp, int ac, char **av)
         if (ret < 0)
                 goto fail;

-        spank_ret = spank_get_item(sp, S_TASK_ID, &context.job.taskid);
-        if (spank_ret != ESPANK_SUCCESS) {
-                slurm_error("pyxis: couldn't get job task ID: %s", spank_strerror(spank_ret));
-                goto fail;
-        }
-
         ret = enroot_start_once(&context.container, context.shm);
         if (ret < 0) {
                 slurm_error("pyxis: couldn't start container");
