@@ -138,14 +138,23 @@ def start_trace(
138138 # session. Otherwise on Cloud TPU, libtpu may not be initialized before
139139 # creating the tracer, which will cause the TPU tracer initialization to
140140 # fail and no TPU operations will be included in the profile.
141- xla_bridge .get_backend ()
142-
143- if profiler_options is None :
144- _profile_state .profile_session = _profiler .ProfilerSession ()
145- else :
146- _profile_state .profile_session = _profiler .ProfilerSession (
147- profiler_options
148- )
141+ client = xla_bridge .get_backend ()
142+
143+ options = profiler_options
144+ if options is None :
145+ options = ProfileOptions ()
146+ _profiler .clear_metadata () # pytype: disable=module-attr
147+ try :
148+ import jax # type: ignore
149+ _profiler .set_metadata ("jax_version" , jax .__version__ ) # pytype: disable=module-attr
150+ jaxlib_version = jax .lib .version .__version__
151+ if client .platform == "tpu" :
152+ jaxlib_version += f" ({ client .platform_version } )"
153+ _profiler .set_metadata ("jaxlib_version" , jaxlib_version ) # pytype: disable=module-attr
154+ except (ImportError , AttributeError ):
155+ _profiler .set_metadata ("jax_version" , "unknown" ) # pytype: disable=module-attr
156+ _profiler .set_metadata ("jaxlib_version" , "unknown" ) # pytype: disable=module-attr
157+ _profile_state .profile_session = _profiler .ProfilerSession (options )
149158 _profile_state .create_perfetto_link = create_perfetto_link
150159 _profile_state .create_perfetto_trace = (
151160 create_perfetto_trace or create_perfetto_link )
@@ -226,6 +235,7 @@ def stop_trace():
226235 if _profile_state .create_perfetto_link :
227236 _host_perfetto_trace_file (abs_filename )
228237 _profile_state .reset ()
238+ _profiler .clear_metadata () # pytype: disable=module-attr
229239
230240
231241def stop_and_get_fdo_profile () -> bytes | str :
@@ -240,6 +250,7 @@ def stop_and_get_fdo_profile() -> bytes | str:
240250 xspace = _profile_state .profile_session .stop ()
241251 fdo_profile = _profiler .get_fdo_profile (xspace )
242252 _profile_state .reset ()
253+ _profiler .clear_metadata () # pytype: disable=module-attr
243254 return fdo_profile
244255
245256
0 commit comments