Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 15 additions & 6 deletions src/cloudai/workloads/nccl_test/kubernetes_json_gen_strategy.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# SPDX-FileCopyrightText: NVIDIA CORPORATION & AFFILIATES
# Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
Expand Down Expand Up @@ -31,6 +31,10 @@ class NcclTestKubernetesJsonGenStrategy(JsonGenStrategy):
This strategy generates an MPIJob configuration for running NCCL tests.
"""

@property
def ssh_port(self) -> int:
return 2222

def gen_json(self) -> dict[Any, Any]:
k8s_system = cast(KubernetesSystem, self.system)
job_name = self.sanitize_k8s_job_name(self.test_run.name)
Expand Down Expand Up @@ -68,6 +72,7 @@ def _create_launcher_spec(self) -> dict[str, Any]:
"replicas": 1,
"template": {
"spec": {
"hostNetwork": True,
"containers": [
{
"image": self.container_url,
Expand All @@ -90,10 +95,12 @@ def _create_worker_spec(self) -> dict[str, Any]:
"replicas": self.test_run.nnodes,
"template": {
"spec": {
"hostNetwork": True,
"containers": [
{
"image": self.container_url,
"name": "nccl-test-worker",
"ports": [{"containerPort": self.ssh_port, "name": "ssh"}],
"imagePullPolicy": "IfNotPresent",
"securityContext": {"privileged": True},
"env": self._generate_env_list(env_vars),
Expand All @@ -116,8 +123,8 @@ def _generate_worker_command(self) -> str:

If the SSH daemon is not installed, it will be installed and the SSH keys will be generated.
"""
return """
set -e
return f"""
set -ex
if ! command -v sshd &> /dev/null; then
apt-get update && apt-get install -y --no-install-recommends openssh-server
fi
Expand All @@ -126,9 +133,11 @@ def _generate_worker_command(self) -> str:
PermitRootLogin yes
PubkeyAuthentication yes
StrictModes no
Port {self.ssh_port}
EOF
ssh-keygen -A
exec /usr/sbin/sshd -D
service ssh restart
sleep infinity
""".strip()

def _get_merged_env_vars(self) -> dict[str, str | list[str]]:
Expand All @@ -154,8 +163,8 @@ def _generate_mpi_args(self) -> List[str]:
mpi_args = [
f"-np {total_processes}",
"-bind-to none",
# Disable strict host key checking for SSH
"-mca plm_rsh_args '-o StrictHostKeyChecking=no -o UserKnownHostsFile=/dev/null'",
# Disable strict host key checking for SSH and ensure correct port is used
f"-mca plm_rsh_args '-p {self.ssh_port} -o StrictHostKeyChecking=no -o UserKnownHostsFile=/dev/null'",
]

return mpi_args
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# SPDX-FileCopyrightText: NVIDIA CORPORATION & AFFILIATES
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
Expand Down Expand Up @@ -39,7 +39,9 @@ def _parse_data_rows(file: TextIOWrapper) -> List[List[str]]:
for line in file:
line: str = line.strip()
if re.match(r"^\d", line):
parsed_data_rows.append(re.split(r"\s+", line))
parts = re.split(r"\s+", line)
if len(parts) == 13:
parsed_data_rows.append(parts)
return parsed_data_rows


Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# SPDX-FileCopyrightText: NVIDIA CORPORATION & AFFILIATES
# Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
Expand Down Expand Up @@ -141,7 +141,11 @@ def test_launcher_command_generation(self, test_run_with_extra_args: TestRun, k8
assert "mpirun" in launcher_args
assert f"-np {test_run_with_extra_args.nnodes * k8s_system.gpus_per_node}" in launcher_args
assert "-bind-to none" in launcher_args
assert "-mca plm_rsh_args '-o StrictHostKeyChecking=no -o UserKnownHostsFile=/dev/null'" in launcher_args
assert (
f"-mca plm_rsh_args '-p {json_gen_strategy.ssh_port}"
+ " -o StrictHostKeyChecking=no -o UserKnownHostsFile=/dev/null'"
in launcher_args
)
assert nccl.cmd_args.subtest_name in launcher_args
assert f"--nthreads {nccl.cmd_args.nthreads}" in launcher_args
assert f"--ngpus {nccl.cmd_args.ngpus}" in launcher_args