@@ -49,16 +49,14 @@ struct job_info {
 	gid_t gid;
 	uint32_t jobid;
 	uint32_t stepid;
-	uint32_t taskid;
 	uint32_t local_task_count;
 	char **environ;
 	char cwd[PATH_MAX];
 };
 
 struct shared_memory {
 	pthread_mutex_t mutex;
-	pthread_cond_t cond;
-	bool initialized;
+	uint32_t init_tasks;
 	uint32_t started_tasks;
 	pid_t pid;
 };
@@ -79,7 +77,7 @@ static struct plugin_context context = {
 	.log_fd = -1,
 	.config = { .runtime_path = { 0 } },
 	.args = NULL,
-	.job = { .uid = -1, .gid = -1, .jobid = 0, .stepid = 0, .taskid = -1, .environ = NULL, .cwd = { 0 } },
+	.job = { .uid = -1, .gid = -1, .jobid = 0, .stepid = 0, .environ = NULL, .cwd = { 0 } },
 	.container = { .name = NULL, .save_path = NULL, .reuse_rootfs = false, .reuse_pid = false, .temporary = false, .userns_fd = -1, .mntns_fd = -1, .cgroupns_fd = -1, .cwd_fd = -1 },
 	.user_init_rv = 0,
 };
@@ -767,7 +765,6 @@ static struct shared_memory *shm_init(void)
 {
 	struct shared_memory *shm = NULL;
 	pthread_mutexattr_t mutex_attr;
-	pthread_condattr_t cond_attr;
 	int ret;
 
 	shm = mmap(0, sizeof(*shm), PROT_READ | PROT_WRITE, MAP_SHARED | MAP_ANONYMOUS, -1, 0);
@@ -785,24 +782,10 @@ static struct shared_memory *shm_init(void)
 		goto fail;
 
 	ret = pthread_mutex_init(&shm->mutex, &mutex_attr);
-	pthread_mutexattr_destroy(&mutex_attr);
 	if (ret < 0)
 		goto fail;
 
-	ret = pthread_condattr_init(&cond_attr);
-	if (ret < 0)
-		goto fail;
-
-	ret = pthread_condattr_setpshared(&cond_attr, PTHREAD_PROCESS_SHARED);
-	if (ret < 0)
-		goto fail;
-
-	ret = pthread_cond_init(&shm->cond, &cond_attr);
-	pthread_condattr_destroy(&cond_attr);
-	if (ret < 0)
-		goto fail;
-
-	shm->initialized = false;
+	shm->init_tasks = 0;
 	shm->started_tasks = 0;
 	shm->pid = -1;
 
@@ -825,10 +808,6 @@ static int shm_destroy(struct shared_memory *shm)
 	if (ret < 0)
 		return (-1);
 
-	ret = pthread_cond_destroy(&shm->cond);
-	if (ret < 0)
-		return (-1);
-
 	ret = munmap(shm, sizeof(*shm));
 	if (ret < 0)
 		return (-1);
@@ -975,40 +954,35 @@ static int pytorch_setup(spank_t sp)
 
 static int enroot_start_once(struct container *container, struct shared_memory *shm)
 {
-	bool created = false;
+	int ret;
+	int rv = -1;
 
-	if (container->reuse_pid)
-		goto done;
+	pthread_mutex_lock(&shm->mutex);
 
-	if (container->reuse_rootfs)
-		created = true;
+	shm->init_tasks += 1;
 
-	pthread_mutex_lock(&shm->mutex);
+	/* The first task will create and/or start the enroot container */
+	if (shm->init_tasks == 1) {
+		if (!container->reuse_pid) {
+			if (!container->reuse_rootfs) {
+				ret = enroot_container_create();
+				if (ret < 0)
+					goto fail;
+			}
 
-	/* Local rank 0 will create and/or start the enroot container */
-	if (context.job.taskid == 0) {
-		if (!created) {
-			if (enroot_container_create() == 0)
-				created = true;
+			shm->pid = enroot_container_start();
 		}
+	}
 
-		if (created)
-			shm->pid = enroot_container_start();
+	if (shm->pid < 0)
+		goto fail;
 
-		shm->initialized = true;
-		pthread_cond_broadcast(&shm->cond);
-	} else {
-		while (!shm->initialized)
-			pthread_cond_wait(&shm->cond, &shm->mutex);
-	}
+	rv = 0;
 
+fail:
 	pthread_mutex_unlock(&shm->mutex);
 
-done:
-	if (shm->pid < 0)
-		return (-1);
-
-	return (0);
+	return (rv);
 }
 
 static int enroot_stop_once(struct container *container, struct shared_memory *shm)
@@ -1041,7 +1015,6 @@ static int enroot_stop_once(struct container *container, struct shared_memory *s
 int slurm_spank_task_init(spank_t sp, int ac, char **av)
 {
 	int ret;
-	spank_err_t spank_ret;
 	int rv = -1;
 
 	if (!context.enabled)
@@ -1055,12 +1028,6 @@ int slurm_spank_task_init(spank_t sp, int ac, char **av)
 	if (ret < 0)
 		goto fail;
 
-	spank_ret = spank_get_item(sp, S_TASK_ID, &context.job.taskid);
-	if (spank_ret != ESPANK_SUCCESS) {
-		slurm_error("pyxis: couldn't get job task ID: %s", spank_strerror(spank_ret));
-		goto fail;
-	}
-
 	ret = enroot_start_once(&context.container, context.shm);
 	if (ret < 0) {
 		slurm_error("pyxis: couldn't start container");
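
Note: the synchronization this commit keeps is a PTHREAD_PROCESS_SHARED mutex living in an anonymous MAP_SHARED mapping, plus a counter incremented under the lock so that only the first task performs the one-time setup; the condition-variable handshake becomes unnecessary. Below is a minimal standalone sketch of that pattern under these assumptions; the struct, the fork() driver, and all names are illustrative, not the plugin's code.

#include <pthread.h>
#include <stdint.h>
#include <stdio.h>
#include <sys/mman.h>
#include <sys/wait.h>
#include <unistd.h>

struct shared {
	pthread_mutex_t mutex;
	uint32_t init_tasks;
};

int main(void)
{
	pthread_mutexattr_t attr;
	struct shared *shm;

	/* Anonymous shared mapping: inherited by children across fork(),
	 * the same way forked Slurm tasks inherit the plugin's mapping. */
	shm = mmap(NULL, sizeof(*shm), PROT_READ | PROT_WRITE,
	           MAP_SHARED | MAP_ANONYMOUS, -1, 0);
	if (shm == MAP_FAILED)
		return (1);

	pthread_mutexattr_init(&attr);
	/* Required so the mutex is usable across processes, not just threads. */
	pthread_mutexattr_setpshared(&attr, PTHREAD_PROCESS_SHARED);
	pthread_mutex_init(&shm->mutex, &attr);
	pthread_mutexattr_destroy(&attr);
	shm->init_tasks = 0;

	for (int i = 0; i < 4; ++i) {
		if (fork() == 0) {
			pthread_mutex_lock(&shm->mutex);
			shm->init_tasks += 1;
			/* Whichever task gets the lock first does the setup. */
			if (shm->init_tasks == 1)
				printf("task %d performs the one-time setup\n", i);
			pthread_mutex_unlock(&shm->mutex);
			_exit(0);
		}
	}
	while (wait(NULL) > 0)
		;

	pthread_mutex_destroy(&shm->mutex);
	munmap(shm, sizeof(*shm));
	return (0);
}

Compared with the removed cond/broadcast scheme, every task simply serializes on the mutex; whoever arrives first does the work, and later tasks see its result (here, the counter; in the plugin, shm->pid) once they acquire the lock.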