Skip to content

Commit ff743cf

Browse files
Add JAX version metadata to xspace
PiperOrigin-RevId: 840138208
1 parent 3648d95 commit ff743cf

File tree

3 files changed

+83
-8
lines changed

3 files changed

+83
-8
lines changed

jax/_src/lib/_profiler.pyi

Lines changed: 43 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,43 @@
1+
from typing import Any
2+
3+
class ProfilerServer: ...
4+
def start_server(port: int) -> ProfilerServer: ...
5+
6+
def register_plugin_profiler(c_api: Any) -> None: ...
7+
8+
def get_profiled_instructions_proto(tensorboard_dir: str) -> bytes: ...
9+
def get_instructions_profile(tensorboard_dir: str) -> list[tuple[str, float]]: ...
10+
def get_fdo_profile(
11+
xspace: bytes, as_textproto: bool = ...
12+
) -> bytes | str: ...
13+
14+
class ProfilerSession:
15+
def __init__(self, options: ProfileOptions | None = ...) -> None: ...
16+
def stop(self) -> bytes: ...
17+
def stop_and_export(self, tensorboard_dir: str) -> None: ...
18+
def export(self, xspace: bytes, tensorboard_dir: str) -> None:...
19+
20+
class ProfileOptions:
21+
include_dataset_ops: bool
22+
host_tracer_level: int
23+
python_tracer_level: int
24+
enable_hlo_proto: bool
25+
start_timestamp_ns: int
26+
duration_ms: int
27+
repository_path: str
28+
raise_error_on_start_failure: bool
29+
advanced_configuration: dict[str, Any]
30+
session_id: str
31+
32+
def aggregate_profiled_instructions(profiles: list[bytes], percentile: int) -> str: ...
33+
34+
class TraceMe:
35+
def __init__(self, name: str, **kwargs: Any) -> None: ...
36+
def __enter__(self) -> TraceMe: ...
37+
def __exit__(self, exc_type, exc_val, exc_tb) -> None: ...
38+
def set_metadata(self, **kwargs: Any) -> None: ...
39+
@staticmethod
40+
def is_enabled() -> bool: ...
41+
42+
def set_metadata(key: str, value: str) -> None: ...
43+
def clear_metadata() -> None: ...

jax/_src/profiler.py

Lines changed: 19 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -138,14 +138,23 @@ def start_trace(
138138
# session. Otherwise on Cloud TPU, libtpu may not be initialized before
139139
# creating the tracer, which will cause the TPU tracer initialization to
140140
# fail and no TPU operations will be included in the profile.
141-
xla_bridge.get_backend()
142-
143-
if profiler_options is None:
144-
_profile_state.profile_session = _profiler.ProfilerSession()
145-
else:
146-
_profile_state.profile_session = _profiler.ProfilerSession(
147-
profiler_options
148-
)
141+
client = xla_bridge.get_backend()
142+
143+
options = profiler_options
144+
if options is None:
145+
options = ProfileOptions()
146+
_profiler.clear_metadata() # pytype: disable=module-attr
147+
try:
148+
import jax # type: ignore
149+
_profiler.set_metadata("jax_version", jax.__version__) # pytype: disable=module-attr
150+
jaxlib_version = jax.lib.version.__version__
151+
if client.platform == "tpu":
152+
jaxlib_version += f" ({client.platform_version})"
153+
_profiler.set_metadata("jaxlib_version", jaxlib_version) # pytype: disable=module-attr
154+
except (ImportError, AttributeError):
155+
_profiler.set_metadata("jax_version", "unknown") # pytype: disable=module-attr
156+
_profiler.set_metadata("jaxlib_version", "unknown") # pytype: disable=module-attr
157+
_profile_state.profile_session = _profiler.ProfilerSession(options)
149158
_profile_state.create_perfetto_link = create_perfetto_link
150159
_profile_state.create_perfetto_trace = (
151160
create_perfetto_trace or create_perfetto_link)
@@ -226,6 +235,7 @@ def stop_trace():
226235
if _profile_state.create_perfetto_link:
227236
_host_perfetto_trace_file(abs_filename)
228237
_profile_state.reset()
238+
_profiler.clear_metadata() # pytype: disable=module-attr
229239

230240

231241
def stop_and_get_fdo_profile() -> bytes | str:
@@ -240,6 +250,7 @@ def stop_and_get_fdo_profile() -> bytes | str:
240250
xspace = _profile_state.profile_session.stop()
241251
fdo_profile = _profiler.get_fdo_profile(xspace)
242252
_profile_state.reset()
253+
_profiler.clear_metadata() # pytype: disable=module-attr
243254
return fdo_profile
244255

245256

tests/profiler_test.py

Lines changed: 21 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -437,6 +437,27 @@ def on_profile(port, logdir, worker_start):
437437
thread_worker.join(120)
438438
self._check_xspace_pb_exist(logdir)
439439

440+
def testDeviceVersionSavedToMetadata(self):
441+
with tempfile.TemporaryDirectory() as tmpdir_string:
442+
tmpdir = pathlib.Path(tmpdir_string)
443+
with jax.profiler.trace(tmpdir):
444+
jax.pmap(lambda x: jax.lax.psum(x + 1, 'i'), axis_name='i')(
445+
jnp.ones(jax.local_device_count()))
446+
447+
proto_path = tuple(tmpdir.rglob("*.xplane.pb"))
448+
self.assertEqual(len(proto_path), 1)
449+
(proto_file,) = proto_path
450+
proto = proto_file.read_bytes()
451+
452+
self.assertIn(b"jax_version", proto)
453+
self.assertIn(b"jaxlib_version", proto)
454+
if jtu.test_device_matches(["tpu"]):
455+
self.assertIn(b"libtpu_version", proto)
456+
if jtu.test_device_matches(["gpu"]):
457+
self.assertIn(b"cuda_version", proto)
458+
self.assertIn(b"cuda_runtime_version", proto)
459+
self.assertIn(b"cuda_driver_version", proto)
460+
440461
@unittest.skipIf(
441462
not (portpicker and _pywrap_profiler_plugin),
442463
"Test requires xprof and portpicker")

0 commit comments

Comments
 (0)