3535from jax ._src .lib import _profiler
3636from jax ._src .lib import _profile_data
3737
38+ # TODO: remove messy fallback for set_metadata / clear_metadata
39+ if not hasattr (_profiler , "set_metadata" ):
40+ _profiler .set_metadata = lambda key , value : None
41+ if not hasattr (_profiler , "clear_metadata" ):
42+ _profiler .clear_metadata = lambda : None
43+
3844ProfileData = _profile_data .ProfileData
3945ProfileEvent = _profile_data .ProfileEvent
4046ProfilePlane = _profile_data .ProfilePlane
@@ -138,14 +144,24 @@ def start_trace(
138144 # session. Otherwise on Cloud TPU, libtpu may not be initialized before
139145 # creating the tracer, which will cause the TPU tracer initialization to
140146 # fail and no TPU operations will be included in the profile.
141- xla_bridge .get_backend ()
142-
143- if profiler_options is None :
144- _profile_state .profile_session = _profiler .ProfilerSession ()
145- else :
146- _profile_state .profile_session = _profiler .ProfilerSession (
147- profiler_options
148- )
147+ client = xla_bridge .get_backend ()
148+
149+ options = profiler_options
150+ if options is None :
151+ options = ProfileOptions ()
152+ _profiler .clear_metadata ()
153+ try :
154+ from jax ._src .lib import version as version_lib
155+ version = "." .join (map (str , version_lib ))
156+ _profiler .set_metadata ("jax_version" , version )
157+ jaxlib_version = version
158+ if client .platform == "tpu" :
159+ jaxlib_version += f" ({ client .platform_version } )"
160+ _profiler .set_metadata ("jaxlib_version" , jaxlib_version )
161+ except (ImportError , AttributeError ):
162+ _profiler .set_metadata ("jax_version" , "unknown" )
163+ _profiler .set_metadata ("jaxlib_version" , "unknown" )
164+ _profile_state .profile_session = _profiler .ProfilerSession (options )
149165 _profile_state .create_perfetto_link = create_perfetto_link
150166 _profile_state .create_perfetto_trace = (
151167 create_perfetto_trace or create_perfetto_link )
@@ -226,6 +242,7 @@ def stop_trace():
226242 if _profile_state .create_perfetto_link :
227243 _host_perfetto_trace_file (abs_filename )
228244 _profile_state .reset ()
245+ _profiler .clear_metadata ()
229246
230247
231248def stop_and_get_fdo_profile () -> bytes | str :
@@ -240,6 +257,7 @@ def stop_and_get_fdo_profile() -> bytes | str:
240257 xspace = _profile_state .profile_session .stop ()
241258 fdo_profile = _profiler .get_fdo_profile (xspace )
242259 _profile_state .reset ()
260+ _profiler .clear_metadata ()
243261 return fdo_profile
244262
245263
0 commit comments