Skip to content

Commit ce54b79

Browse files
Add JAX version metadata to xspace
PiperOrigin-RevId: 840138208
1 parent 49561f6 commit ce54b79

File tree

2 files changed

+69
-8
lines changed

2 files changed

+69
-8
lines changed

jax/_src/lib/_profiler.pyi

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
from typing import Any
2+
3+
class ProfilerServer: ...
4+
def start_server(port: int) -> ProfilerServer: ...
5+
6+
def register_plugin_profiler(c_api: Any) -> None: ...
7+
8+
def get_profiled_instructions_proto(tensorboard_dir: str) -> bytes: ...
9+
def get_instructions_profile(tensorboard_dir: str) -> list[tuple[str, float]]: ...
10+
def get_fdo_profile(
11+
xspace: bytes, as_textproto: bool = ...
12+
) -> bytes | str: ...
13+
14+
class ProfilerSession:
15+
def __init__(self, options: ProfileOptions | None = ...) -> None: ...
16+
def stop(self) -> bytes: ...
17+
def stop_and_export(self, tensorboard_dir: str) -> None: ...
18+
def export(self, xspace: bytes, tensorboard_dir: str) -> None:...
19+
20+
class ProfileOptions:
21+
include_dataset_ops: bool
22+
host_tracer_level: int
23+
python_tracer_level: int
24+
enable_hlo_proto: bool
25+
start_timestamp_ns: int
26+
duration_ms: int
27+
repository_path: str
28+
raise_error_on_start_failure: bool
29+
advanced_configuration: dict[str, Any]
30+
session_id: str
31+
32+
def aggregate_profiled_instructions(profiles: list[bytes], percentile: int) -> str: ...
33+
34+
class TraceMe:
35+
def __init__(self, name: str, **kwargs: Any) -> None: ...
36+
def __enter__(self) -> TraceMe: ...
37+
def __exit__(self, exc_type, exc_val, exc_tb) -> None: ...
38+
def set_metadata(self, **kwargs: Any) -> None: ...
39+
@staticmethod
40+
def is_enabled() -> bool: ...
41+
42+
def set_metadata(key: str, value: str) -> None: ...
43+
def clear_metadata() -> None: ...

jax/_src/profiler.py

Lines changed: 26 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,12 @@
3535
from jax._src.lib import _profiler
3636
from jax._src.lib import _profile_data
3737

38+
# TODO: remove messy fallback for set_metadata / clear_metadata
39+
if not hasattr(_profiler, "set_metadata"):
40+
_profiler.set_metadata = lambda key, value: None
41+
if not hasattr(_profiler, "clear_metadata"):
42+
_profiler.clear_metadata = lambda: None
43+
3844
ProfileData = _profile_data.ProfileData
3945
ProfileEvent = _profile_data.ProfileEvent
4046
ProfilePlane = _profile_data.ProfilePlane
@@ -138,14 +144,24 @@ def start_trace(
138144
# session. Otherwise on Cloud TPU, libtpu may not be initialized before
139145
# creating the tracer, which will cause the TPU tracer initialization to
140146
# fail and no TPU operations will be included in the profile.
141-
xla_bridge.get_backend()
142-
143-
if profiler_options is None:
144-
_profile_state.profile_session = _profiler.ProfilerSession()
145-
else:
146-
_profile_state.profile_session = _profiler.ProfilerSession(
147-
profiler_options
148-
)
147+
client = xla_bridge.get_backend()
148+
149+
options = profiler_options
150+
if options is None:
151+
options = ProfileOptions()
152+
_profiler.clear_metadata()
153+
try:
154+
from jax._src.lib import version as version_lib
155+
version = ".".join(map(str, version_lib))
156+
_profiler.set_metadata("jax_version", version)
157+
jaxlib_version = version
158+
if client.platform == "tpu":
159+
jaxlib_version += f" ({client.platform_version})"
160+
_profiler.set_metadata("jaxlib_version", jaxlib_version)
161+
except (ImportError, AttributeError):
162+
_profiler.set_metadata("jax_version", "unknown")
163+
_profiler.set_metadata("jaxlib_version", "unknown")
164+
_profile_state.profile_session = _profiler.ProfilerSession(options)
149165
_profile_state.create_perfetto_link = create_perfetto_link
150166
_profile_state.create_perfetto_trace = (
151167
create_perfetto_trace or create_perfetto_link)
@@ -226,6 +242,7 @@ def stop_trace():
226242
if _profile_state.create_perfetto_link:
227243
_host_perfetto_trace_file(abs_filename)
228244
_profile_state.reset()
245+
_profiler.clear_metadata()
229246

230247

231248
def stop_and_get_fdo_profile() -> bytes | str:
@@ -240,6 +257,7 @@ def stop_and_get_fdo_profile() -> bytes | str:
240257
xspace = _profile_state.profile_session.stop()
241258
fdo_profile = _profiler.get_fdo_profile(xspace)
242259
_profile_state.reset()
260+
_profiler.clear_metadata()
243261
return fdo_profile
244262

245263

0 commit comments

Comments
 (0)