@@ -212,6 +212,7 @@ def trace_context():
212212 return (axis_env_state .value , mesh_context_manager .value ,
213213 xla_metadata_context_manager .value ,
214214 abstract_mesh_context_manager .value ,
215+ device_context .value ,
215216 compute_on_context_manager .value , enable_x64 .value ,
216217 numpy_rank_promotion .value , default_matmul_precision .value ,
217218 dynamic_shapes .value ,
@@ -245,6 +246,7 @@ def trace_context():
245246 axis_env_state = ()
246247 mesh_context_manager = ()
247248 abstract_mesh_context_manager = ()
249+ device_context = ()
248250 xla_metadata_context_manager = ()
249251 compute_on_context_manager = ()
250252
@@ -255,12 +257,14 @@ def trace_context():
255257 mesh_context_manager = context .mesh_context_manager
256258 if context and context .abstract_mesh_context_manager :
257259 abstract_mesh_context_manager = context .abstract_mesh_context_manager
260+ if context and context .device_context :
261+ device_context = context .device_context
258262 if context and context .xla_metadata_context_manager :
259263 xla_metadata_context_manager = context .xla_metadata_context_manager
260264 if context and context .compute_on_context_manager :
261265 compute_on_context_manager = context .compute_on_context_manager
262266 return (axis_env_state , mesh_context_manager , abstract_mesh_context_manager ,
263- xla_metadata_context_manager ,
267+ device_context , xla_metadata_context_manager ,
264268 compute_on_context_manager , enable_x64 .value ,
265269 numpy_rank_promotion .value , default_matmul_precision .value ,
266270 dynamic_shapes .value ,
@@ -976,6 +980,7 @@ def enum_flag(name, default, *args, **kwargs) -> Flag[str]:
976980 axis_env_state = config_ext .Config ((), include_in_jit_key = True )
977981 mesh_context_manager = config_ext .Config ((), include_in_jit_key = True )
978982 abstract_mesh_context_manager = config_ext .Config ((), include_in_jit_key = True )
983+ device_context = config_ext .Config ((), include_in_jit_key = True )
979984 compute_on_context_manager = config_ext .Config ((), include_in_jit_key = True )
980985 xla_metadata_context_manager = config_ext .Config ((), include_in_jit_key = True )
981986else :
@@ -1019,6 +1024,7 @@ class _ThreadLocalExtraJitContext(NamedTuple):
10191024 axis_env_state : Hashable = ()
10201025 mesh_context_manager : Hashable = ()
10211026 abstract_mesh_context_manager : Hashable = ()
1027+ device_context : Hashable = ()
10221028 compute_on_context_manager : Hashable = ()
10231029 xla_metadata_context_manager : Hashable = ()
10241030
@@ -1086,6 +1092,7 @@ def set_local(self, value):
10861092 axis_env_state = JitConfig ('axis_env_state' )
10871093 mesh_context_manager = JitConfig ('mesh_context_manager' )
10881094 abstract_mesh_context_manager = JitConfig ('abstract_mesh_context_manager' )
1095+ device_context = JitConfig ('device_context' )
10891096 compute_on_context_manager = JitConfig ('compute_on_context_manager' )
10901097 xla_metadata_context_manager = JitConfig ('xla_metadata_context_manager' )
10911098
0 commit comments