Commit defd0a9

Fix the fallback op in SPMD (#8386)
1 parent: 91f5c8a

2 files changed: +27 -0 lines

test/spmd/test_xla_sharding.py

Lines changed: 26 additions & 0 deletions
@@ -1377,6 +1377,32 @@ def test_data_loader_with_non_batch_size_and_mini_batch(self):
     ):
       data, _ = iter(train_device_loader).__next__()

+  def test_fallback(self):
+    device = torch_xla.device()
+
+    theta: float = 10000
+    dim = 16
+    end = 2048
+
+    torch_xla.sync()
+    freqs = 1.0 / (
+        theta
+        **(torch.arange(0, dim, 2, device=device)[:(dim // 2)].float() / dim))
+    t = torch.arange(end, device=freqs.device)
+    freqs = torch.outer(t, freqs).float()
+    freqs_cis = torch.polar(torch.ones_like(freqs, device=device),
+                            freqs)  # complex64
+    # torch.polar will fall back to CPU; the result tensor should not have
+    # any sharding spec.
+    self.assertIn("ShardingSpec: None",
+                  torch_xla._XLAC._get_xla_tensor_debug_info(freqs_cis))
+    # The result lives in a CPU tensor; since no sharding spec is set, it
+    # won't be moved to the device yet.
+    self.assertIn("Tensor on host: with size [2048, 8]",
+                  torch_xla._XLAC._get_xla_tensor_debug_info(freqs_cis))
+    torch_xla.sync()
+    # The data should be on device and replicated now.
+    self.assertIn("Data Shape: c64[2048,8]\n OpSharding: {replicated}",
+                  torch_xla._XLAC._get_xla_tensor_debug_info(freqs_cis))


 if __name__ == '__main__':
   test = unittest.main()
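The test above verifies the fallback behavior through the tensor debug info. As a complementary check, a minimal sketch of how CPU fallbacks can be spotted through torch_xla's debug metrics counters (the exact counter name, e.g. aten::polar, is an assumption about how this fallback is recorded):

import torch
import torch_xla
import torch_xla.debug.metrics as met

device = torch_xla.device()
freqs = torch.rand(2048, 8, device=device)
# torch.polar has no native XLA lowering, so it is executed on CPU.
freqs_cis = torch.polar(torch.ones_like(freqs), freqs)

# Ops that fell back to CPU are tracked as aten:: counters in the
# metrics report; expect an entry such as aten::polar here (assumed name).
print([name for name in met.counter_names() if name.startswith('aten::')])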

torch_xla/csrc/tensor.cpp

Lines changed: 1 addition & 0 deletions
@@ -562,6 +562,7 @@ void XLATensor::UpdateFromTensor(at::Tensor tensor, bool sync) {
   at::Tensor coyped_tensor = torch::lazy::CopyTensor(tensor, dtype());
   SetTensorData(coyped_tensor);
   data()->handle = nullptr;
+  data()->sharding = nullptr;
   AssignIrValue(torch::lazy::Value());
   if (data()->view != nullptr) {
     torch::lazy::Value ir_value = GetIrValueForTensor(coyped_tensor, device);
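The one-line fix: XLATensor::UpdateFromTensor replaces the tensor's backing data with a host-side copy, which is the path a CPU fallback result takes. It already dropped the stale device handle and IR value but left any old sharding annotation in place; clearing data()->sharding as well means the fallback result carries no sharding spec and is uploaded as replicated data at the next sync, exactly what the test asserts. A minimal sketch of the user-visible scenario, assuming a standard SPMD setup (the 1-D mesh and partition spec below are illustrative choices, not part of the commit):

import torch
import torch_xla
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs

xr.use_spmd()
device = torch_xla.device()

# Shard an input across all devices along its first dimension.
num_devices = xr.global_runtime_device_count()
mesh = xs.Mesh(list(range(num_devices)), (num_devices,), ('data',))
freqs = torch.rand(2048, 8, device=device)
xs.mark_sharding(freqs, mesh, ('data', None))

# torch.polar falls back to CPU; with this fix its result reports
# "ShardingSpec: None" instead of inheriting a stale spec.
freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
print(torch_xla._XLAC._get_xla_tensor_debug_info(freqs_cis))

torch_xla.sync()
# After the sync, the data is on device and replicated.
print(torch_xla._XLAC._get_xla_tensor_debug_info(freqs_cis))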
