Description
Hi,
I am trying to combine GSPMD with PyTorch compile (torch.compile), but it doesn't work.
I took a copy of the test script "test_train_spmd_imagenet.py" and ran it in Colab, where it started normally. However, after I added the compile call:
device = xm.xla_device()
model = get_model_property('model_fn')().to(device)
model = torch.compile(model, backend='aot_torchxla_trace_once')
It crashed.
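In case it helps with triage, here is a minimal sketch of how the full traceback from the failing step could be captured and pasted into a report. The names `run_and_capture` and `failing_step` are hypothetical helpers for illustration only and are not part of torch or torch_xla; `failing_step` is just a stand-in for the first forward pass of the compiled model. This only catches Python exceptions, so it would not help if the process dies with a hard crash such as a segfault.

```python
import traceback


def run_and_capture(step_fn):
    """Run step_fn() and return (result, None) on success,
    or (None, traceback_text) if it raises a Python exception."""
    try:
        return step_fn(), None
    except Exception:
        # format_exc() returns the same text Python would print to stderr
        return None, traceback.format_exc()


def failing_step():
    # Hypothetical stand-in for the compiled model's first forward pass
    raise RuntimeError("placeholder failure")


if __name__ == "__main__":
    _, tb_text = run_and_capture(failing_step)
    print(tb_text)
```

In the real script, the callable passed to `run_and_capture` would wrap the first training step after the compile call, so the complete error text can be attached here.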
Here is a Colab notebook that reproduces the crash:
https://colab.research.google.com/drive/1KNcBydAfZXLATpSo-CXILxtHJkK8JD-2?usp=sharing