Skip to content

Commit ce8be07

Browse files
authored
Merge pull request #134 from TaekyungHeo/nemo-env-var-fix
Pass all env vars to final command in NeMo launcher test template
2 parents d8acc30 + 166741d commit ce8be07

File tree

2 files changed

+36
-32
lines changed

2 files changed

+36
-32
lines changed

src/cloudai/schema/test_template/nemo_launcher/slurm_command_gen_strategy.py

Lines changed: 15 additions & 22 deletions
Original file line number | Diff line number | Diff line change
@@ -20,17 +20,6 @@
2020

2121
from .slurm_install_strategy import NeMoLauncherSlurmInstallStrategy
2222

23-
REQUIRE_ENV_VARS = [
24-
"NCCL_SOCKET_IFNAME",
25-
"NCCL_IB_GID_INDEX",
26-
"NCCL_IB_TC",
27-
"NCCL_IB_QPS_PER_CONNECTION",
28-
"UCX_IB_GID_INDEX",
29-
"NCCL_IB_ADAPTIVE_ROUTING",
30-
"NCCL_IB_SPLIT_DATA_ON_QPS",
31-
"NCCL_IBEXT_DISABLE",
32-
]
33-
3423

3524
class NeMoLauncherSlurmCommandGenStrategy(SlurmCommandGenStrategy):
3625
"""
@@ -50,10 +39,8 @@ def gen_exec_command(
5039
num_nodes: int,
5140
nodes: List[str],
5241
) -> str:
53-
# Ensure required environment variables are included
54-
for key in REQUIRE_ENV_VARS:
55-
if key not in extra_env_vars:
56-
extra_env_vars[key] = self.slurm_system.global_env_vars[key]
42+
final_env_vars = self._override_env_vars(self.default_env_vars, env_vars)
43+
final_env_vars = self._override_env_vars(final_env_vars, extra_env_vars)
5744

5845
launcher_path = os.path.join(
5946
self.install_path,
@@ -67,7 +54,7 @@ def gen_exec_command(
6754
self.final_cmd_args["base_results_dir"] = output_path
6855
self.final_cmd_args["training.model.data.index_mapping_dir"] = output_path
6956
self.final_cmd_args["launcher_scripts_path"] = os.path.join(launcher_path, "launcher_scripts")
70-
for key, value in extra_env_vars.items():
57+
for key, value in final_env_vars.items():
7158
self.final_cmd_args[f"env_vars.{key}"] = value
7259
self.final_cmd_args["cluster.partition"] = self.slurm_system.default_partition
7360
nodes = self.slurm_system.parse_nodes(nodes)
@@ -96,7 +83,7 @@ def gen_exec_command(
9683
tokenizer_path = extra_cmd_args.split("training.model.tokenizer.model=")[1].split(" ")[0]
9784
full_cmd += " " + f"container_mounts=[{tokenizer_path}:{tokenizer_path}]"
9885

99-
env_vars_str = " ".join(f"{key}={value}" for key, value in extra_env_vars.items())
86+
env_vars_str = " ".join(f"{key}={value}" for key, value in final_env_vars.items())
10087
full_cmd = f"{env_vars_str} {full_cmd}" if env_vars_str else full_cmd
10188

10289
return full_cmd.strip()
@@ -130,13 +117,19 @@ def _generate_cmd_args_str(self, args: Dict[str, str], nodes: List[str]) -> str:
130117
Returns:
131118
str: A string of command-line arguments.
132119
"""
133-
arg_str_parts = []
120+
cmd_arg_str_parts = []
121+
env_var_str_parts = []
122+
134123
for key, value in args.items():
135-
formatted_key = f"+{key}" if key.startswith("env_vars.") else key
136-
arg_str_parts.append(f"{formatted_key}={value}")
124+
if key.startswith("env_vars."):
125+
if isinstance(value, str) and "," in value:
126+
value = f"\\'{value}\\'"
127+
env_var_str_parts.append(f"+{key}={value}")
128+
else:
129+
cmd_arg_str_parts.append(f"{key}={value}")
137130

138131
if nodes:
139132
nodes_str = ",".join(nodes)
140-
arg_str_parts.append(f"+cluster.nodelist=\\'{nodes_str}\\'")
133+
cmd_arg_str_parts.append(f"+cluster.nodelist=\\'{nodes_str}\\'")
141134

142-
return " ".join(arg_str_parts)
135+
return " ".join(cmd_arg_str_parts + env_var_str_parts)

tests/test_slurm_command_gen_strategy.py

Lines changed: 21 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -17,7 +17,6 @@
1717

1818
import pytest
1919
from cloudai.schema.test_template.nemo_launcher.slurm_command_gen_strategy import (
20-
REQUIRE_ENV_VARS,
2120
NeMoLauncherSlurmCommandGenStrategy,
2221
)
2322
from cloudai.systems import SlurmSystem
@@ -121,15 +120,8 @@ def nemo_cmd_gen(self, slurm_system: SlurmSystem) -> NeMoLauncherSlurmCommandGen
121120
strategy = NeMoLauncherSlurmCommandGenStrategy(slurm_system, env_vars, cmd_args)
122121
return strategy
123122

124-
def test_raises_if_required_env_var_missed(self, nemo_cmd_gen: NeMoLauncherSlurmCommandGenStrategy):
125-
with pytest.raises(KeyError) as exc_info:
126-
nemo_cmd_gen.gen_exec_command(
127-
env_vars={}, cmd_args={}, extra_env_vars={}, extra_cmd_args="", output_path="", num_nodes=1, nodes=[]
128-
)
129-
assert REQUIRE_ENV_VARS[0] in str(exc_info.value)
130-
131123
def test_extra_env_vars_added(self, nemo_cmd_gen: NeMoLauncherSlurmCommandGenStrategy):
132-
extra_env_vars = {v: "fake" for v in REQUIRE_ENV_VARS}
124+
extra_env_vars = {"TEST_VAR_1": "value1", "TEST_VAR_2": "value2"}
133125
cmd_args = {
134126
"docker_image_url": "fake",
135127
"repository_url": "fake",
@@ -148,8 +140,27 @@ def test_extra_env_vars_added(self, nemo_cmd_gen: NeMoLauncherSlurmCommandGenStr
148140
for k, v in extra_env_vars.items():
149141
assert f"{k}={v}" in cmd
150142

143+
def test_env_var_escaping(self, nemo_cmd_gen: NeMoLauncherSlurmCommandGenStrategy):
144+
extra_env_vars = {"TEST_VAR": "value,with,commas"}
145+
cmd_args = {
146+
"docker_image_url": "fake",
147+
"repository_url": "fake",
148+
"repository_commit_hash": "fake",
149+
}
150+
cmd = nemo_cmd_gen.gen_exec_command(
151+
env_vars={},
152+
cmd_args=cmd_args,
153+
extra_env_vars=extra_env_vars,
154+
extra_cmd_args="",
155+
output_path="",
156+
num_nodes=1,
157+
nodes=[],
158+
)
159+
160+
assert "TEST_VAR=\\'value,with,commas\\'" in cmd
161+
151162
def test_tokenizer_handled(self, nemo_cmd_gen: NeMoLauncherSlurmCommandGenStrategy):
152-
extra_env_vars = {v: "fake" for v in REQUIRE_ENV_VARS}
163+
extra_env_vars = {"TEST_VAR_1": "value1"}
153164
cmd_args = {
154165
"docker_image_url": "fake",
155166
"repository_url": "fake",

0 commit comments

Comments
 (0)