Skip to content

Commit 65b8d8d

Browse files
committed
Use host network by default
+ make ssh port usage more reliable
1 parent 26f5ece commit 65b8d8d

File tree

1 file changed

+14
-5
lines changed

1 file changed

+14
-5
lines changed

src/cloudai/workloads/nccl_test/kubernetes_json_gen_strategy.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,10 @@ class NcclTestKubernetesJsonGenStrategy(JsonGenStrategy):
3131
This strategy generates an MPIJob configuration for running NCCL tests.
3232
"""
3333

34+
@property
35+
def ssh_port(self) -> int:
36+
return 2222
37+
3438
def gen_json(self) -> dict[Any, Any]:
3539
k8s_system = cast(KubernetesSystem, self.system)
3640
job_name = self.sanitize_k8s_job_name(self.test_run.name)
@@ -68,6 +72,7 @@ def _create_launcher_spec(self) -> dict[str, Any]:
6872
"replicas": 1,
6973
"template": {
7074
"spec": {
75+
"hostNetwork": True,
7176
"containers": [
7277
{
7378
"image": self.container_url,
@@ -90,10 +95,12 @@ def _create_worker_spec(self) -> dict[str, Any]:
9095
"replicas": self.test_run.nnodes,
9196
"template": {
9297
"spec": {
98+
"hostNetwork": True,
9399
"containers": [
94100
{
95101
"image": self.container_url,
96102
"name": "nccl-test-worker",
103+
"ports": [{"containerPort": self.ssh_port, "name": "ssh"}],
97104
"imagePullPolicy": "IfNotPresent",
98105
"securityContext": {"privileged": True},
99106
"env": self._generate_env_list(env_vars),
@@ -116,8 +123,8 @@ def _generate_worker_command(self) -> str:
116123
117124
If the SSH daemon is not installed, it will be installed and the SSH keys will be generated.
118125
"""
119-
return """
120-
set -e
126+
return f"""
127+
set -ex
121128
if ! command -v sshd &> /dev/null; then
122129
apt-get update && apt-get install -y --no-install-recommends openssh-server
123130
fi
@@ -126,9 +133,11 @@ def _generate_worker_command(self) -> str:
126133
PermitRootLogin yes
127134
PubkeyAuthentication yes
128135
StrictModes no
136+
Port {self.ssh_port}
129137
EOF
130138
ssh-keygen -A
131-
exec /usr/sbin/sshd -D
139+
service ssh restart
140+
sleep infinity
132141
""".strip()
133142

134143
def _get_merged_env_vars(self) -> dict[str, str | list[str]]:
@@ -154,8 +163,8 @@ def _generate_mpi_args(self) -> List[str]:
154163
mpi_args = [
155164
f"-np {total_processes}",
156165
"-bind-to none",
157-
# Disable strict host key checking for SSH
158-
"-mca plm_rsh_args '-o StrictHostKeyChecking=no -o UserKnownHostsFile=/dev/null'",
166+
# Disable strict host key checking for SSH and ensure correct port is used
167+
f"-mca plm_rsh_args '-p {self.ssh_port} -o StrictHostKeyChecking=no -o UserKnownHostsFile=/dev/null'",
159168
]
160169

161170
return mpi_args

0 commit comments

Comments
 (0)