Skip to content

Commit 9c4e7a1

Browse files
committed
Keep old nsys logic for Jax tests
1 parent 816a706 commit 9c4e7a1

File tree

4 files changed

+20
-32
lines changed

4 files changed

+20
-32
lines changed

src/cloudai/schema/test_template/jax_toolbox/slurm_command_gen_strategy.py

Lines changed: 10 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from pathlib import Path
1818
from typing import Any, Dict, List, Union, cast
1919

20-
from cloudai import NsysConfiguration, TestRun
20+
from cloudai import TestRun
2121
from cloudai.systems import SlurmSystem
2222
from cloudai.systems.slurm.strategy import SlurmCommandGenStrategy
2323
from cloudai.test_definitions.gpt import GPTTestDefinition
@@ -43,10 +43,6 @@ def _container_mounts(self, tr: TestRun) -> list[str]:
4343

4444
return mounts
4545

46-
def gen_nsys_command(self, tr: TestRun) -> list[str]:
47-
"""NSYS profiling is in the run.sh script and disabled for srun level."""
48-
return []
49-
5046
def gen_exec_command(self, tr: TestRun) -> str:
5147
self.test_name = self._extract_test_name(tr.test.cmd_args)
5248
self._update_env_vars(tr)
@@ -286,22 +282,19 @@ def _generate_python_command(
286282
if stage == "profile":
287283
python_command += " >> /opt/paxml/workspace/profile_stderr_${SLURM_PROCID}.txt 2>&1"
288284

289-
nsys = NsysConfiguration(
290-
enable=True,
291-
nsys_binary="nsys",
292-
task="profile",
293-
sample="none",
294-
output=f"/opt/paxml/workspace/nsys_profile_{stage}",
295-
force_overwrite=True,
296-
capture_range="cudaProfilerApi",
297-
capture_range_end="stop",
298-
cuda_graph_trace="node",
285+
nsys_command = (
286+
"nsys profile \\\n"
287+
" -s none \\\n"
288+
f" -o /opt/paxml/workspace/nsys_profile_{stage} \\\n"
289+
" --force-overwrite true \\\n"
290+
" --capture-range=cudaProfilerApi \\\n"
291+
" --capture-range-end=stop \\\n"
292+
" --cuda-graph-trace=node \\\n"
299293
)
300-
nsys_command = " \\\n ".join(nsys.cmd_args) + " "
301294

302295
slurm_check = (
303296
'if [ "$SLURM_NODEID" -eq 0 ] && [ "$SLURM_PROCID" -eq 0 ]; then\n'
304-
f" {nsys_command}\\\n {python_command}\n"
297+
f" {nsys_command} {python_command}\n"
305298
"else\n"
306299
f" {python_command}\n"
307300
"fi"

tests/ref_data/gpt.run

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,10 @@ export PER_GPU_COMBINE_THRESHOLD=0
55
export XLA_FLAGS="--xla_gpu_all_gather_combine_threshold_bytes=$COMBINE_THRESHOLD --xla_gpu_all_reduce_combine_threshold_bytes=$COMBINE_THRESHOLD --xla_gpu_reduce_scatter_combine_threshold_bytes=$PER_GPU_COMBINE_THRESHOLD"
66

77
if [ "$SLURM_NODEID" -eq 0 ] && [ "$SLURM_PROCID" -eq 0 ]; then
8-
nsys \
9-
profile \
8+
nsys profile \
109
-s none \
1110
-o /opt/paxml/workspace/nsys_profile_profile \
12-
--force-overwrite=true \
11+
--force-overwrite true \
1312
--capture-range=cudaProfilerApi \
1413
--capture-range-end=stop \
1514
--cuda-graph-trace=node \
@@ -73,11 +72,10 @@ export PER_GPU_COMBINE_THRESHOLD=0
7372
export XLA_FLAGS="--xla_gpu_all_gather_combine_threshold_bytes=$COMBINE_THRESHOLD --xla_gpu_all_reduce_combine_threshold_bytes=$COMBINE_THRESHOLD --xla_gpu_reduce_scatter_combine_threshold_bytes=$PER_GPU_COMBINE_THRESHOLD"
7473

7574
if [ "$SLURM_NODEID" -eq 0 ] && [ "$SLURM_PROCID" -eq 0 ]; then
76-
nsys \
77-
profile \
75+
nsys profile \
7876
-s none \
7977
-o /opt/paxml/workspace/nsys_profile_perf \
80-
--force-overwrite=true \
78+
--force-overwrite true \
8179
--capture-range=cudaProfilerApi \
8280
--capture-range-end=stop \
8381
--cuda-graph-trace=node \

tests/ref_data/grok.run

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,10 @@ export PER_GPU_COMBINE_THRESHOLD=0
55
export XLA_FLAGS="--xla_disable_hlo_passes=rematerialization --xla_dump_hlo_pass_re=.* --xla_gpu_all_gather_combine_threshold_bytes=$COMBINE_THRESHOLD --xla_gpu_all_reduce_combine_threshold_bytes=$COMBINE_THRESHOLD --xla_gpu_disable_async_collectives=ALLREDUCE,ALLGATHER,REDUCESCATTER,COLLECTIVEBROADCAST,ALLTOALL,COLLECTIVEPERMUTE --xla_gpu_enable_all_gather_combine_by_dim=false --xla_gpu_enable_highest_priority_async_stream=true --xla_gpu_enable_latency_hiding_scheduler=false --xla_gpu_enable_pipelined_all_gather=true --xla_gpu_enable_pipelined_all_reduce=true --xla_gpu_enable_pipelined_reduce_scatter=true --xla_gpu_enable_reduce_scatter_combine_by_dim=false --xla_gpu_enable_triton_gemm=false --xla_gpu_enable_triton_softmax_fusion=false --xla_gpu_enable_while_loop_double_buffering=true --xla_gpu_graph_level=0 --xla_gpu_reduce_scatter_combine_threshold_bytes=$PER_GPU_COMBINE_THRESHOLD --xla_gpu_run_post_layout_collective_pipeliner=false"
66

77
if [ "$SLURM_NODEID" -eq 0 ] && [ "$SLURM_PROCID" -eq 0 ]; then
8-
nsys \
9-
profile \
8+
nsys profile \
109
-s none \
1110
-o /opt/paxml/workspace/nsys_profile_profile \
12-
--force-overwrite=true \
11+
--force-overwrite true \
1312
--capture-range=cudaProfilerApi \
1413
--capture-range-end=stop \
1514
--cuda-graph-trace=node \
@@ -97,11 +96,10 @@ export PER_GPU_COMBINE_THRESHOLD=0
9796
export XLA_FLAGS="--xla_disable_hlo_passes=rematerialization --xla_dump_hlo_pass_re=.* --xla_gpu_all_gather_combine_threshold_bytes=$COMBINE_THRESHOLD --xla_gpu_all_reduce_combine_threshold_bytes=$COMBINE_THRESHOLD --xla_gpu_enable_all_gather_combine_by_dim=false --xla_gpu_enable_highest_priority_async_stream=true --xla_gpu_enable_latency_hiding_scheduler=true --xla_gpu_enable_pipelined_all_gather=true --xla_gpu_enable_pipelined_all_reduce=true --xla_gpu_enable_pipelined_reduce_scatter=true --xla_gpu_enable_reduce_scatter_combine_by_dim=false --xla_gpu_enable_triton_gemm=false --xla_gpu_enable_triton_softmax_fusion=false --xla_gpu_enable_while_loop_double_buffering=true --xla_gpu_graph_level=0 --xla_gpu_pgle_profile_file_or_directory_path=/opt/paxml/workspace/pgle_output_profile.pbtxt --xla_gpu_reduce_scatter_combine_threshold_bytes=$PER_GPU_COMBINE_THRESHOLD --xla_gpu_run_post_layout_collective_pipeliner=false --xla_gpu_use_memcpy_local_p2p=false"
9897

9998
if [ "$SLURM_NODEID" -eq 0 ] && [ "$SLURM_PROCID" -eq 0 ]; then
100-
nsys \
101-
profile \
99+
nsys profile \
102100
-s none \
103101
-o /opt/paxml/workspace/nsys_profile_perf \
104-
--force-overwrite=true \
102+
--force-overwrite true \
105103
--capture-range=cudaProfilerApi \
106104
--capture-range-end=stop \
107105
--cuda-graph-trace=node \

tests/slurm_command_gen_strategy/test_jax_toolbox_slurm_command_gen_strategy.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -196,11 +196,10 @@ def test_generate_python_command(
196196

197197
assert python_cli == [
198198
'if [ "$SLURM_NODEID" -eq 0 ] && [ "$SLURM_PROCID" -eq 0 ]; then',
199-
" nsys \\",
200-
" profile \\",
199+
" nsys profile \\",
201200
" -s none \\",
202201
f" -o /opt/paxml/workspace/nsys_profile_{stage} \\",
203-
" --force-overwrite=true \\",
202+
" --force-overwrite true \\",
204203
" --capture-range=cudaProfilerApi \\",
205204
" --capture-range-end=stop \\",
206205
" --cuda-graph-trace=node \\",

0 commit comments

Comments
 (0)