diff --git a/torch_xla/runtime.py b/torch_xla/runtime.py index e4560df6c70..1946ae05a52 100644 --- a/torch_xla/runtime.py +++ b/torch_xla/runtime.py @@ -253,6 +253,15 @@ def use_spmd(auto: Optional[bool] = False): torch_xla._XLAC._xla_set_auto_sharding() os.environ["XLA_AUTO_SPMD"] = "1" + if device_type() == 'NEURON': + # In case of Neuron, reset the initialization environment to accommodate SPMD. + try: + from torch_neuronx.initialization import initialize + + initialize() + except ImportError: + pass + def is_spmd(): """Returns if SPMD is set for execution."""