Skip to content

Commit 70a62a3

Browse files
committed
Improve tokenizer path handling in NeMo Launcher Slurm strategy
1 parent ce8be07 commit 70a62a3

File tree

2 files changed: 45 additions & 7 deletions

2 files changed: 45 additions & 7 deletions

src/cloudai/schema/test_template/nemo_launcher/slurm_command_gen_strategy.py

Lines changed: 12 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -78,10 +78,18 @@ def gen_exec_command(
7878
full_cmd = f"python {launcher_path}/launcher_scripts/main.py {cmd_args_str}"
7979

8080
if extra_cmd_args:
81-
full_cmd += " " + extra_cmd_args
82-
if "training.model.tokenizer.model" in extra_cmd_args:
83-
tokenizer_path = extra_cmd_args.split("training.model.tokenizer.model=")[1].split(" ")[0]
84-
full_cmd += " " + f"container_mounts=[{tokenizer_path}:{tokenizer_path}]"
81+
full_cmd += f" {extra_cmd_args}"
82+
tokenizer_key = "training.model.tokenizer.model="
83+
if tokenizer_key in extra_cmd_args:
84+
tokenizer_path = extra_cmd_args.split(tokenizer_key, 1)[1].split(" ", 1)[0]
85+
if not os.path.isfile(tokenizer_path):
86+
raise ValueError(
87+
f"The provided tokenizer path '{tokenizer_path}' is not valid. "
88+
"Please review the test schema file to ensure the tokenizer path is correct. "
89+
"If it contains a placeholder value, refer to USER_GUIDE.md to download the tokenizer "
90+
"and update the schema file accordingly."
91+
)
92+
full_cmd += f" container_mounts=[{tokenizer_path}:{tokenizer_path}]"
8593

8694
env_vars_str = " ".join(f"{key}={value}" for key, value in final_env_vars.items())
8795
full_cmd = f"{env_vars_str} {full_cmd}" if env_vars_str else full_cmd

tests/test_slurm_command_gen_strategy.py

Lines changed: 33 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -159,24 +159,54 @@ def test_env_var_escaping(self, nemo_cmd_gen: NeMoLauncherSlurmCommandGenStrateg
159159

160160
assert "TEST_VAR=\\'value,with,commas\\'" in cmd
161161

162-
def test_tokenizer_handled(self, nemo_cmd_gen: NeMoLauncherSlurmCommandGenStrategy):
162+
def test_tokenizer_handled(self, nemo_cmd_gen: NeMoLauncherSlurmCommandGenStrategy, tmp_path: Path):
163163
extra_env_vars = {"TEST_VAR_1": "value1"}
164164
cmd_args = {
165165
"docker_image_url": "fake",
166166
"repository_url": "fake",
167167
"repository_commit_hash": "fake",
168168
}
169+
tokenizer_path = tmp_path / "tokenizer"
170+
tokenizer_path.touch()
171+
169172
cmd = nemo_cmd_gen.gen_exec_command(
170173
env_vars={},
171174
cmd_args=cmd_args,
172175
extra_env_vars=extra_env_vars,
173-
extra_cmd_args="training.model.tokenizer.model=value",
176+
extra_cmd_args=f"training.model.tokenizer.model={tokenizer_path}",
174177
output_path="",
175178
num_nodes=1,
176179
nodes=[],
177180
)
178181

179-
assert "container_mounts=[value:value]" in cmd
182+
assert f"container_mounts=[{tokenizer_path}:{tokenizer_path}]" in cmd
183+
184+
def test_invalid_tokenizer_path(self, nemo_cmd_gen: NeMoLauncherSlurmCommandGenStrategy):
185+
extra_env_vars = {"TEST_VAR_1": "value1"}
186+
cmd_args = {
187+
"docker_image_url": "fake",
188+
"repository_url": "fake",
189+
"repository_commit_hash": "fake",
190+
}
191+
invalid_tokenizer_path = "/invalid/path/to/tokenizer"
192+
193+
with pytest.raises(
194+
ValueError,
195+
match=(
196+
r"The provided tokenizer path '/invalid/path/to/tokenizer' is not valid. Please review the test "
197+
r"schema file to ensure the tokenizer path is correct. If it contains a placeholder value, refer to "
198+
r"USER_GUIDE.md to download the tokenizer and update the schema file accordingly."
199+
),
200+
):
201+
nemo_cmd_gen.gen_exec_command(
202+
env_vars={},
203+
cmd_args=cmd_args,
204+
extra_env_vars=extra_env_vars,
205+
extra_cmd_args=f"training.model.tokenizer.model={invalid_tokenizer_path}",
206+
output_path="",
207+
num_nodes=1,
208+
nodes=[],
209+
)
180210

181211

182212
class TestWriteSbatchScript:

0 commit comments

Comments
 (0)