Skip to content

Commit fcaa58f

Browse files
authored
nsys-jax: use pytest tmp_path[_factory] to allow saving profiles (#1748)
This aims to save all of the original `.zip` files, but neither the extracted versions of them nor any derived `.zip` files (as those should be deterministically derivable from the original `.zip` files).
1 parent b5ad330 commit fcaa58f

File tree

9 files changed

+54
-33
lines changed

9 files changed

+54
-33
lines changed

.github/container/nsys_jax/tests/nsys_jax_test_helpers/__init__.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,25 +7,25 @@
77
import zipfile
88

99

10-
def nsys_jax_with_result(command):
10+
def nsys_jax_with_result(command, *, out_dir):
1111
"""
1212
Helper to run nsys-jax with a unique output file created in `out_dir`; the file is
1313
not deleted automatically. Explicitly returns the `subprocess.CompletedProcess`
1414
instance.
1515
"""
16-
output = tempfile.NamedTemporaryFile(suffix=".zip")
16+
output = tempfile.NamedTemporaryFile(delete=False, dir=out_dir, suffix=".zip")
1717
result = subprocess.run(
1818
["nsys-jax", "--force-overwrite", "--output", output.name] + command,
1919
)
2020
return output, result
2121

2222

23-
def nsys_jax(command):
23+
def nsys_jax(command, *, out_dir):
2424
"""
2525
Helper to run nsys-jax with a unique output file created in `out_dir`; the file is
2626
not deleted automatically. Throws if running `nsys-jax` does not succeed.
2727
"""
28-
output, result = nsys_jax_with_result(command)
28+
output, result = nsys_jax_with_result(command, out_dir=out_dir)
2929
result.check_returncode()
3030
return output
3131

@@ -42,11 +42,11 @@ def extract(archive):
4242
return tmpdir
4343

4444

45-
def nsys_jax_archive(command):
45+
def nsys_jax_archive(command, *, out_dir):
4646
"""
4747
Helper to run nsys-jax and automatically extract the output, yielding a directory.
4848
"""
49-
archive = nsys_jax(command)
49+
archive = nsys_jax(command, out_dir=out_dir)
5050
tmpdir = extract(archive)
5151
# Make sure the protobuf bindings can be imported, the generated .py will go into
5252
# a temporary directory that is not currently cleaned up. The bindings cannot be
@@ -58,13 +58,17 @@ def nsys_jax_archive(command):
5858

5959

6060
def multi_process_nsys_jax(
61-
num_processes: int, command: typing.Callable[[int], list[str]]
61+
num_processes: int,
62+
command: typing.Callable[[int], list[str]],
63+
*,
64+
out_dir,
6265
):
6366
"""
6467
Helper to run a multi-process test under nsys-jax and yield several .zip files.
6568
"""
6669
child_outputs = [
67-
tempfile.NamedTemporaryFile(suffix=".zip") for _ in range(num_processes)
70+
tempfile.NamedTemporaryFile(delete=False, dir=out_dir, suffix=".zip")
71+
for _ in range(num_processes)
6872
]
6973
children = [
7074
subprocess.Popen(

.github/container/nsys_jax/tests/test_basics.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,14 +9,14 @@
99
from nsys_jax_test_helpers import nsys_jax # noqa: E402
1010

1111

12-
def test_program_without_gpu_activity():
12+
def test_program_without_gpu_activity(tmp_path):
1313
"""
1414
Profiling a program that doesn't do anything should succeed.
1515
"""
16-
nsys_jax([sys.executable, "-c", "print('Hello world!')"])
16+
nsys_jax([sys.executable, "-c", "print('Hello world!')"], out_dir=tmp_path)
1717

1818

19-
def test_stacktrace_entry_with_file():
19+
def test_stacktrace_entry_with_file(tmp_path):
2020
"""
2121
Test that if a source file appears in the traceback of a JITed JAX function then
2222
the source file is bundled into the nsys-jax output archive.
@@ -27,19 +27,21 @@ def test_stacktrace_entry_with_file():
2727
src_code = "import jax\njax.jit(lambda x: x*2)(4)\n"
2828
with open(src_file, "w") as f:
2929
f.write(src_code)
30-
archive = nsys_jax([sys.executable, src_file])
30+
archive = nsys_jax([sys.executable, src_file], out_dir=tmp_path)
3131
with zipfile.ZipFile(archive) as ifile:
3232
src_file_in_archive = f"sources{src_file}"
3333
assert src_file_in_archive in ifile.namelist()
3434
with ifile.open(src_file_in_archive, "r") as archived_file:
3535
assert archived_file.read().decode() == src_code
3636

3737

38-
def test_stacktrace_entry_without_file():
38+
def test_stacktrace_entry_without_file(tmp_path):
3939
"""
4040
Test that tracing code that does not come from a named file works (bug 4931958).
4141
"""
42-
archive = nsys_jax(["python", "-c", "import jax; jax.jit(lambda x: x*2)(4)"])
42+
archive = nsys_jax(
43+
["python", "-c", "import jax; jax.jit(lambda x: x*2)(4)"], out_dir=tmp_path
44+
)
4345
with zipfile.ZipFile(archive.name) as ifile:
4446
# The combination of -c and JAX suppressing references to its own source code
4547
# should mean that no source code files are gathered.

.github/container/nsys_jax/tests/test_example_program.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,13 @@
1111

1212

1313
@pytest.fixture(scope="module")
14-
def profiler_data():
14+
def profiler_data(tmp_path_factory):
1515
"""
1616
Fixture that yields the loaded result of profiling example_program.py with nsys-jax.
1717
"""
1818
outdir = nsys_jax_archive(
19-
[sys.executable, os.path.join(os.path.dirname(__file__), "example_program.py")]
19+
[sys.executable, os.path.join(os.path.dirname(__file__), "example_program.py")],
20+
out_dir=tmp_path_factory.mktemp("test_example_program"),
2021
)
2122
return load_profiler_data(pathlib.Path(outdir.name))
2223

.github/container/nsys_jax/tests/test_exit_code_handling.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020

2121

2222
@pytest.mark.parametrize("kill", kill_args.keys())
23-
def test_program_that_is_killed_by_nsys(kill):
23+
def test_program_that_is_killed_by_nsys(kill, tmp_path):
2424
"""
2525
The default --capture-range-end=stop-shutdown behaviour of nsys profile causes nsys
2626
to kill the profiled process after cudaProfilerStop if
@@ -38,12 +38,13 @@ def test_program_that_is_killed_by_nsys(kill):
3838
sys.executable,
3939
cuda_profiler_api,
4040
"sleep",
41-
]
41+
],
42+
out_dir=tmp_path,
4243
)
4344
assert os.path.isfile(output_zip.name)
4445

4546

46-
def test_program_that_fails_after_cuda_profiler_stop():
47+
def test_program_that_fails_after_cuda_profiler_stop(tmp_path):
4748
"""
4849
With --capture-range=stop then nsys-jax should still propagate a failure code from
4950
the application.
@@ -56,14 +57,17 @@ def test_program_that_fails_after_cuda_profiler_stop():
5657
sys.executable,
5758
cuda_profiler_api,
5859
"exit42",
59-
]
60+
],
61+
out_dir=tmp_path,
6062
)
6163
assert result.returncode == 42, result
6264
assert os.path.isfile(output_zip.name)
6365

6466

6567
@pytest.mark.parametrize("kill", kill_args.keys())
66-
def test_program_that_fails_after_cuda_profiler_stop_as_nsys_tries_to_kill_it(kill):
68+
def test_program_that_fails_after_cuda_profiler_stop_as_nsys_tries_to_kill_it(
69+
tmp_path, kill
70+
):
6771
"""
6872
The racy case, where either nsys sends SIGTERM or the application exits with 42.
6973
Also cover the case where --capture-range-end is not passed explicitly.
@@ -76,7 +80,8 @@ def test_program_that_fails_after_cuda_profiler_stop_as_nsys_tries_to_kill_it(ki
7680
sys.executable,
7781
cuda_profiler_api,
7882
"exit42",
79-
]
83+
],
84+
out_dir=tmp_path,
8085
)
8186
# If nsys sends SIGTERM fast enough, the child process will exit due to that and
8287
# nsys-jax will return 0 (because that is the expected result). The application

.github/container/nsys_jax/tests/test_hlo_program.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818

1919

2020
@pytest.fixture(scope="module")
21-
def profiler_data():
21+
def profiler_data(tmp_path_factory):
2222
"""
2323
Fixture that yields the loaded result of profiling example.hlo with nsys-jax.
2424
"""
@@ -27,7 +27,8 @@ def profiler_data():
2727
hlo_runner_main,
2828
f"--num_repeats={num_repeats}",
2929
os.path.join(os.path.dirname(__file__), "example.hlo"),
30-
]
30+
],
31+
out_dir=tmp_path_factory.mktemp("test_hlo_program"),
3132
)
3233
return load_profiler_data(pathlib.Path(outdir.name))
3334

.github/container/nsys_jax/tests/test_jax_nccl_test.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ def set_env(monkeypatch):
3838

3939

4040
@pytest.mark.parametrize("collection", ["full", "partial"])
41-
def test_jax_nccl_single_process(monkeypatch, collection):
41+
def test_jax_nccl_single_process(monkeypatch, tmp_path, collection):
4242
set_env(monkeypatch)
4343
nsys_jax(
4444
capture_args(collection)
@@ -49,7 +49,8 @@ def test_jax_nccl_single_process(monkeypatch, collection):
4949
"summary",
5050
"--",
5151
"jax-nccl-test",
52-
]
52+
],
53+
out_dir=tmp_path,
5354
)
5455

5556

@@ -65,7 +66,7 @@ def test_jax_nccl_single_process(monkeypatch, collection):
6566

6667
@pytest.mark.parametrize("process_count", process_counts_to_test)
6768
@pytest.mark.parametrize("collection", ["full", "partial"])
68-
def test_jax_nccl_multi_process(monkeypatch, process_count, collection):
69+
def test_jax_nccl_multi_process(monkeypatch, tmp_path, process_count, collection):
6970
assert device_count % process_count == 0, (device_count, process_count)
7071
gpus_per_process = device_count // process_count
7172
set_env(monkeypatch)
@@ -85,6 +86,7 @@ def test_jax_nccl_multi_process(monkeypatch, process_count, collection):
8586
str(gpus_per_process),
8687
"--distributed",
8788
],
89+
out_dir=tmp_path,
8890
)
8991
combined_output = tempfile.NamedTemporaryFile(suffix=".zip")
9092
subprocess.run(

.github/container/nsys_jax/tests/test_multi_process_program.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818

1919

2020
@pytest.fixture(scope="module")
21-
def individual_results():
21+
def individual_results(tmp_path_factory):
2222
"""
2323
Fixture that yields the .zip files from individual subprocesses.
2424
"""
@@ -33,6 +33,7 @@ def individual_results():
3333
"--rank",
3434
str(rank),
3535
],
36+
out_dir=tmp_path_factory.mktemp("test_multi_progress_program"),
3637
)
3738

3839

.github/container/nsys_jax/tests/test_overlap_program.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,16 +14,19 @@
1414

1515

1616
@pytest.fixture(scope="module")
17-
def profiler_data_full():
17+
def profiler_data_full(tmp_path_factory):
1818
"""
1919
Fixture that yields the loaded result of profiling overlap_program.py with nsys-jax.
2020
"""
21-
outdir = nsys_jax_archive([sys.executable, overlap_program])
21+
outdir = nsys_jax_archive(
22+
[sys.executable, overlap_program],
23+
out_dir=tmp_path_factory.mktemp("test_overlap_program_full"),
24+
)
2225
return load_profiler_data(pathlib.Path(outdir.name))
2326

2427

2528
@pytest.fixture(scope="module")
26-
def profiler_data_narrow():
29+
def profiler_data_narrow(tmp_path_factory):
2730
"""
2831
Fixture that yields the loaded result of profiling overlap_program.py with nsys-jax.
2932
"""
@@ -34,7 +37,8 @@ def profiler_data_narrow():
3437
"--",
3538
sys.executable,
3639
overlap_program,
37-
]
40+
],
41+
out_dir=tmp_path_factory.mktemp("test_overlap_program_narrow"),
3842
)
3943
return load_profiler_data(pathlib.Path(outdir.name))
4044

.github/workflows/_ci.yaml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -343,7 +343,7 @@ jobs:
343343
pip install pytest-reportlog nsys-jax[test]
344344
# abuse knowledge that nsys-jax is installed editable, so the tests exist
345345
test_path=$(python -c 'import importlib.resources; print(importlib.resources.files("nsys_jax").joinpath("..", "tests").resolve())')
346-
pytest --report-log=/opt/output/pytest-report.jsonl "${test_path}"
346+
pytest --basetemp=/opt/output/pytest-tmp --report-log=/opt/output/pytest-report.jsonl "${test_path}"
347347
EOF
348348
STATISTICS_SCRIPT: |
349349
summary_line=$(tail -n1 test-nsys-jax.log)
@@ -359,6 +359,7 @@ jobs:
359359
# pytest-driven part
360360
test-nsys-jax.log
361361
pytest-report.jsonl
362+
pytest-tmp/
362363
secrets: inherit
363364

364365
# test-nsys-jax generates several fresh .zip archive outputs by running nsys-jax with real GPU hardware; this test

0 commit comments

Comments
 (0)