Skip to content

Commit b1103a0

Browse files
authored
nsys-jax: bugfix and expanded testing (#1132)
- Fix profiling of traced code that is not attributed to a named file. - Add more test coverage. - Clean up `nsys-jax` handling of the `-o` and `-f` options.
1 parent 1dad010 commit b1103a0

File tree

10 files changed

+264
-89
lines changed

10 files changed

+264
-89
lines changed

.github/container/Dockerfile.base

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -200,13 +200,20 @@ COPY check-shm.sh /opt/nvidia/entrypoint.d/
200200
## Add helper scripts for profiling with Nsight Systems
201201
##
202202
## The scripts saved to /opt/jax_nsys are embedded in the output archives
203-
## written by nsys-jax, while the nsys-jax wrapper is used inside the container.
203+
## written by nsys-jax, while the nsys-jax and nsys-jax-combine scripts are
204+
## only used inside the containers.
204205
###############################################################################
205-
206206
ADD nsys-jax nsys-jax-combine /usr/local/bin/
207207
ADD jax_nsys/ /opt/jax_nsys
208+
# The jax_nsys package should be installed inside the containers, so nsys-jax
209+
# can eagerly execute analysis recipes (--nsys-jax-analysis) in the container
210+
# environment, without an extra layer of virtual environment indirection.
208211
RUN echo "-e /opt/jax_nsys/python/jax_nsys" > /opt/pip-tools.d/requirements-nsys-jax.in
212+
# install-protoc should be embedded in output archives and be runnable inside containers
209213
RUN ln -s /opt/jax_nsys/install-protoc /usr/local/bin/
214+
# Should be available for execution inside the containers, but should not be
215+
# embedded in the output archives.
216+
ADD jax_nsys_tests/ /opt/jax_nsys_tests
210217

211218
###############################################################################
212219
## Copy manifest file to the container

.github/container/jax_nsys/python/jax_nsys/pyproject.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,3 +11,7 @@ dependencies = [
1111
"uncertainties", # communication analysis recipe
1212
]
1313
requires-python = ">= 3.10"
14+
[project.optional-dependencies]
15+
test = [
16+
"pytest"
17+
]

.github/container/jax_nsys/python/jax_nsys_analysis/summary.py

100755100644
File mode changed.
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
import jax
2+
3+
4+
@jax.jit
def distinctively_named_function(x):
    """
    Return the matrix product of ``x`` with its own transpose.

    The function is jitted and deliberately given an unusual name so that it is
    easy to pick out by name in profiling output.
    """
    transposed = x.T
    return x @ transposed
7+
8+
9+
# Use a fixed PRNG key so the 32x32 input is the same on every run.
key = jax.random.key(1)
square = jax.random.normal(key, (32, 32))
# Call the jitted function five times, feeding each result back in as the input.
for _ in range(5):
    square = distinctively_named_function(square)
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
import subprocess
2+
import tempfile
3+
4+
5+
def nsys_jax(command):
    """
    Run ``command`` under ``nsys-jax`` and return the output archive handle.

    The output path is the name of a ``.zip``-suffixed ``NamedTemporaryFile``, so
    each call gets a unique archive. The open temporary-file object is returned;
    the archive is removed automatically when that object is closed or destroyed.

    ``command`` is the list of arguments appended after the ``nsys-jax`` options.
    A non-zero exit status raises ``subprocess.CalledProcessError``.
    """
    archive_file = tempfile.NamedTemporaryFile(suffix=".zip")
    # NOTE(review): --force-overwrite is presumably needed because the temporary
    # file already exists on disk when nsys-jax starts -- confirm.
    nsys_jax_argv = ["nsys-jax", "--force-overwrite", "--output", archive_file.name]
    subprocess.run(nsys_jax_argv + command, check=True)
    return archive_file
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
import os
2+
import subprocess
3+
import sys
4+
import tempfile
5+
import zipfile
6+
7+
helper_dir = os.path.join(os.path.dirname(__file__), "jax_nsys_test_helpers")
8+
if helper_dir not in sys.path:
9+
sys.path.insert(0, helper_dir)
10+
from jax_nsys_test_helpers import nsys_jax # noqa: E402
11+
12+
13+
def test_program_without_gpu_activity():
    """
    nsys-jax should succeed even when the profiled program does nothing but print.
    """
    trivial_command = [sys.executable, "-c", "print('Hello world!')"]
    nsys_jax(trivial_command)
18+
19+
20+
def test_stacktrace_entry_with_file():
    """
    A source file that appears in the traceback of a JITed JAX function must be
    bundled, byte-for-byte, into the nsys-jax output archive.
    """
    src_code = "import jax\njax.jit(lambda x: x*2)(4)\n"
    with tempfile.TemporaryDirectory() as tmpdir:
        archive = f"{tmpdir}/out.zip"
        src_file = f"{tmpdir}/test.py"
        # The expected archive member name is "sources" followed directly by the
        # source path, which only yields "sources/..." if the path is absolute.
        assert os.path.isabs(src_file), src_file
        with open(src_file, "w") as f:
            f.write(src_code)
        nsys_jax_argv = ["nsys-jax", "--output", archive, sys.executable, src_file]
        subprocess.run(nsys_jax_argv, check=True)
        with zipfile.ZipFile(archive) as ifile:
            expected_member = f"sources{src_file}"
            assert expected_member in ifile.namelist()
            archived_code = ifile.read(expected_member).decode()
            assert archived_code == src_code
40+
41+
42+
def test_stacktrace_entry_without_file():
    """
    Test that tracing code that does not come from a named file works (bug 4931958).

    The profiled program is passed inline via ``-c``. It is launched with
    ``sys.executable`` rather than whatever ``python`` is first on ``PATH``, so it
    runs in the same interpreter (and therefore the same installed packages) as
    the test session, matching the other tests in this file.
    """
    archive = nsys_jax(
        [sys.executable, "-c", "import jax; jax.jit(lambda x: x*2)(4)"]
    )
    with zipfile.ZipFile(archive.name) as ifile:
        # The combination of -c and JAX suppressing references to its own source code
        # should mean that no source code files are gathered.
        assert not any(x.startswith("sources/") for x in ifile.namelist())
Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
from jax_nsys import (
2+
ensure_compiled_protos_are_importable,
3+
load_profiler_data,
4+
)
5+
import os
6+
import pathlib
7+
import pytest # type: ignore
8+
import sys
9+
import tempfile
10+
import zipfile
11+
12+
helper_dir = os.path.join(os.path.dirname(__file__), "jax_nsys_test_helpers")
13+
if helper_dir not in sys.path:
14+
sys.path.insert(0, helper_dir)
15+
from jax_nsys_test_helpers import nsys_jax # noqa: E402
16+
17+
18+
@pytest.fixture(scope="module")
def example_program():
    """
    Fixture that yields an extracted archive of the result of profiling
    example_program.py with nsys-jax.

    Returns the ``tempfile.TemporaryDirectory`` holding the extracted files; the
    directory lives for as long as that object does.
    """
    tmpdir = tempfile.TemporaryDirectory()
    archive = nsys_jax(
        [sys.executable, os.path.join(os.path.dirname(__file__), "example_program.py")]
    )
    # Open the archive by path instead of through the temporary-file handle, so the
    # data read is whatever nsys-jax left at that path. Extract directly into the
    # target directory rather than changing the process-wide working directory.
    with zipfile.ZipFile(archive.name) as zf:
        zf.extractall(path=tmpdir.name)
    # Make sure the protobuf bindings can be imported, the generated .py will go into
    # a temporary directory that is not currently cleaned up. The bindings cannot be
    # un-imported from the test process, so there is a tacit assumption that in a given
    # test session there will only be one set of .proto files and it doesn't matter
    # which ones are picked up.
    ensure_compiled_protos_are_importable(prefix=pathlib.Path(tmpdir.name))
    return tmpdir
42+
43+
44+
@pytest.fixture(scope="module")
def profiler_data(example_program):
    """
    Module-scoped fixture: the result of ``load_profiler_data`` applied to the
    directory held by the ``example_program`` fixture.
    """
    extracted_dir = pathlib.Path(example_program.name)
    return load_profiler_data(extracted_dir)
47+
48+
49+
def test_comms(profiler_data):
    """
    The profile of example_program.py should contain no communication entries.
    """
    n_communication_rows = len(profiler_data.communication)
    assert n_communication_rows == 0
52+
53+
54+
def test_modules(profiler_data):
    """
    Check the module-level records for the jitted function in example_program.py:
    five executions, one program id, device 0, execution indices 0..n-1.
    """
    # NOTE(review): profiler_data.module is presumably a pandas DataFrame -- confirm.
    modules = profiler_data.module
    test_func_mask = modules["Name"] == "jit_distinctively_named_function"
    assert sum(test_func_mask) == 5
    test_func_data = modules[test_func_mask]
    index = test_func_data.index
    assert index.names == ["ProgramId", "ProgramExecution", "Device"]
    # All executions should have the same program id
    program_ids = index.get_level_values("ProgramId")
    assert all(program_ids == program_ids[0])
    # All executions should be on device 0
    execution_devices = index.get_level_values("Device")
    assert all(execution_devices == 0)
    # Execution indices should count from 0..n-1
    execution_indices = index.get_level_values("ProgramExecution")
    expected_indices = range(len(test_func_data))
    assert all(execution_indices == expected_indices)

0 commit comments

Comments
 (0)