1616
1717from typing import Any , Dict , List , Union , cast
1818
19+ import yaml
20+
1921from cloudai .core import JsonGenStrategy
22+ from cloudai .systems .kubernetes import KubernetesSystem
2023
2124from .nccl import NCCLTestDefinition
2225
2326
2427class NcclTestKubernetesJsonGenStrategy (JsonGenStrategy ):
25- """JSON generation strategy for NCCL tests on Kubernetes systems."""
28+ """
29+ JSON generation strategy for NCCL tests on Kubernetes systems.
2630
27- SSH_PORT : int = 2222
31+ This strategy generates an MPIJob configuration for running NCCL tests.
32+ """
2833
2934 def gen_json (self ) -> dict [Any , Any ]:
30- return {
35+ k8s_system = cast (KubernetesSystem , self .system )
36+ job_name = self .sanitize_k8s_job_name (self .test_run .name )
37+
38+ deployment = {
3139 "apiVersion" : "kubeflow.org/v2beta1" ,
3240 "kind" : "MPIJob" ,
3341 "metadata" : {
34- "name" : self .sanitize_k8s_job_name ("nccl-test" ),
42+ "name" : job_name ,
43+ "namespace" : k8s_system .default_namespace ,
3544 },
3645 "spec" : {
37- "slotsPerWorker" : 1 ,
46+ "slotsPerWorker" : k8s_system . gpus_per_node ,
3847 "runPolicy" : {"cleanPodPolicy" : "Running" },
3948 "mpiReplicaSpecs" : {
4049 "Launcher" : self ._create_launcher_spec (),
@@ -43,47 +52,53 @@ def gen_json(self) -> dict[Any, Any]:
4352 },
4453 }
4554
46- def _create_launcher_spec (self ) -> dict [str , Any ]:
55+ with open (self .test_run .output_path / "deployment.yaml" , "w" ) as f :
56+ yaml .dump (deployment , f )
57+
58+ return deployment
59+
60+ @property
61+ def container_url (self ) -> str :
4762 tdef : NCCLTestDefinition = cast (NCCLTestDefinition , self .test_run .test )
63+ return tdef .cmd_args .docker_image_url .replace ("#" , "/" )
64+
65+ def _create_launcher_spec (self ) -> dict [str , Any ]:
4866 env_vars = self ._get_merged_env_vars ()
4967 return {
5068 "replicas" : 1 ,
5169 "template" : {
5270 "spec" : {
53- "hostNetwork" : True ,
5471 "containers" : [
5572 {
56- "image" : tdef . cmd_args . docker_image_url ,
73+ "image" : self . container_url ,
5774 "name" : "nccl-test-launcher" ,
5875 "imagePullPolicy" : "IfNotPresent" ,
5976 "securityContext" : {"privileged" : True },
6077 "env" : self ._generate_env_list (env_vars ),
6178 "command" : ["/bin/bash" , "-c" ],
62- "args" : [self ._generate_launcher_command (env_vars )],
79+ "args" : [self ._generate_launcher_command ()],
80+ "resources" : self ._prepare_launcher_resources (),
6381 }
6482 ],
6583 },
6684 },
6785 }
6886
6987 def _create_worker_spec (self ) -> dict [str , Any ]:
70- tdef : NCCLTestDefinition = cast (NCCLTestDefinition , self .test_run .test )
7188 env_vars = self ._get_merged_env_vars ()
7289 return {
73- "replicas" : self .test_run .num_nodes ,
90+ "replicas" : self .test_run .nnodes ,
7491 "template" : {
7592 "spec" : {
76- "hostNetwork" : True ,
7793 "containers" : [
7894 {
79- "image" : tdef . cmd_args . docker_image_url ,
95+ "image" : self . container_url ,
8096 "name" : "nccl-test-worker" ,
8197 "imagePullPolicy" : "IfNotPresent" ,
8298 "securityContext" : {"privileged" : True },
83- "ports" : [{"containerPort" : self .SSH_PORT , "name" : "ssh" }],
8499 "env" : self ._generate_env_list (env_vars ),
85- "command" : ["/bin/bash" ],
86- "args" : ["-c" , f"/usr/sbin/sshd -p { self .SSH_PORT } ; sleep infinity" ],
100+ "command" : ["/bin/bash" , "-c" ],
101+ "args" : [self ._generate_worker_command () ],
87102 "resources" : self ._prepare_worker_resources (),
88103 "volumeMounts" : [
89104 {"mountPath" : "/dev/shm" , "name" : "dev-shm" },
@@ -95,31 +110,54 @@ def _create_worker_spec(self) -> dict[str, Any]:
95110 },
96111 }
97112
113+ def _generate_worker_command (self ) -> str :
114+ """
115+ Generate command for worker pods that starts the SSH daemon.
116+
117+ If the SSH daemon is not installed, it will be installed and the SSH keys will be generated.
118+ """
119+ return """
120+ set -e
121+ if ! command -v sshd &> /dev/null; then
122+ apt-get update && apt-get install -y --no-install-recommends openssh-server
123+ fi
124+ mkdir -p /var/run/sshd
125+ cat >> /etc/ssh/sshd_config << EOF
126+ PermitRootLogin yes
127+ PubkeyAuthentication yes
128+ StrictModes no
129+ EOF
130+ ssh-keygen -A
131+ exec /usr/sbin/sshd -D
132+ """ .strip ()
133+
98134 def _get_merged_env_vars (self ) -> dict [str , str | list [str ]]:
99135 final_env_vars = self .system .global_env_vars .copy ()
100136 final_env_vars .update (self .test_run .test .extra_env_vars )
101137 return final_env_vars
102138
103139 def _generate_env_list (self , env_vars : Dict [str , Union [str , List [str ]]]) -> List [Dict [str , str ]]:
104- env_list = [{"name" : "OMPI_ALLOW_RUN_AS_ROOT" , "value" : "1" }]
140+ env_list = [
141+ {"name" : "OMPI_ALLOW_RUN_AS_ROOT" , "value" : "1" },
142+ {"name" : "OMPI_ALLOW_RUN_AS_ROOT_CONFIRM" , "value" : "1" },
143+ ]
105144 for key , value in env_vars .items ():
106145 if isinstance (value , list ):
107146 value = "," .join (value )
108147 env_list .append ({"name" : key , "value" : value })
109148 return env_list
110149
111- def _generate_mpi_args (self , env_vars : Dict [str , Union [str , List [str ]]]) -> List [str ]:
150+ def _generate_mpi_args (self ) -> List [str ]:
151+ k8s_system = cast (KubernetesSystem , self .system )
152+ total_processes = self .test_run .nnodes * k8s_system .gpus_per_node
153+
112154 mpi_args = [
113- "--allow-run-as-root" ,
114- f"--mca plm_rsh_args '-p { self .SSH_PORT } '" ,
115- "-c 2" ,
116- "-bind-to none -map-by slot" ,
117- "-mca btl tcp,self" ,
155+ f"-np { total_processes } " ,
156+ "-bind-to none" ,
157+ # Disable strict host key checking for SSH
158+ "-mca plm_rsh_args '-o StrictHostKeyChecking=no -o UserKnownHostsFile=/dev/null'" ,
118159 ]
119160
120- if "NCCL_SOCKET_IFNAME" in env_vars :
121- mpi_args .append (f"-mca btl_tcp_if_include { env_vars ['NCCL_SOCKET_IFNAME' ]} " )
122-
123161 return mpi_args
124162
125163 def _generate_nccl_args (self , cmd_args_dict : Dict [str , Any ]) -> List [str ]:
@@ -136,7 +174,7 @@ def _generate_extra_args(self, extra_cmd_args: Dict[str, str]) -> List[str]:
136174 extra_args .append (f"{ key } { value } " if value else key )
137175 return extra_args
138176
139- def _generate_launcher_command (self , env_vars : dict [ str , str | list [ str ]] ) -> str :
177+ def _generate_launcher_command (self ) -> str :
140178 tdef : NCCLTestDefinition = cast (NCCLTestDefinition , self .test_run .test )
141179 tdef_cmd_args = tdef .cmd_args
142180
@@ -146,7 +184,7 @@ def _generate_launcher_command(self, env_vars: dict[str, str | list[str]]) -> st
146184
147185 command_parts = [
148186 "mpirun" ,
149- " " .join (self ._generate_mpi_args (env_vars )),
187+ " " .join (self ._generate_mpi_args ()),
150188 tdef_cmd_args .subtest_name ,
151189 " " .join (self ._generate_nccl_args (cmd_args_dict )),
152190 ]
@@ -163,4 +201,6 @@ def _prepare_launcher_resources(self) -> Dict[str, Dict[str, str]]:
163201 }
164202
165203 def _prepare_worker_resources (self ) -> Dict [str , Dict [str , str ]]:
166- return {"requests" : {"nvidia.com/gpu" : "8" }, "limits" : {"nvidia.com/gpu" : "8" }}
204+ k8s_system = cast (KubernetesSystem , self .system )
205+ gpu_count = str (k8s_system .gpus_per_node )
206+ return {"requests" : {"nvidia.com/gpu" : gpu_count }, "limits" : {"nvidia.com/gpu" : gpu_count }}
0 commit comments