Skip to content

Commit

Permalink
Fix the fallback op in SPMD (#8386)
Browse files Browse the repository at this point in the history
  • Loading branch information
JackCaoG authored Nov 18, 2024
1 parent 91f5c8a commit defd0a9
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 0 deletions.
26 changes: 26 additions & 0 deletions test/spmd/test_xla_sharding.py
Original file line number Diff line number Diff line change
Expand Up @@ -1377,6 +1377,32 @@ def test_data_loader_with_non_batch_size_and_mini_batch(self):
):
data, _ = iter(train_device_loader).__next__()

def test_fallback(self):
device = torch_xla.device()

theta: float = 10000
dim = 16
end = 2048

torch_xla.sync()
freqs = 1.0 / (
theta
**(torch.arange(0, dim, 2, device=device)[:(dim // 2)].float() / dim))
t = torch.arange(end, device=freqs.device)
freqs = torch.outer(t, freqs).float()
freqs_cis = torch.polar(torch.ones_like(freqs, device=device),
freqs) # complex64
# torch.polar will fallback on CPU, the result tensor should not have any sharding spec
self.assertIn("ShardingSpec: None",
torch_xla._XLAC._get_xla_tensor_debug_info(freqs_cis))
# it will be on a CPU tensor, the sharding spec is not specified so it won't be move to device yet
self.assertIn("Tensor on host: with size [2048, 8]",
torch_xla._XLAC._get_xla_tensor_debug_info(freqs_cis))
torch_xla.sync()
# data should be on device and replicated now
self.assertIn("Data Shape: c64[2048,8]\n OpSharding: {replicated}",
torch_xla._XLAC._get_xla_tensor_debug_info(freqs_cis))


if __name__ == '__main__':
test = unittest.main()
Expand Down
1 change: 1 addition & 0 deletions torch_xla/csrc/tensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -562,6 +562,7 @@ void XLATensor::UpdateFromTensor(at::Tensor tensor, bool sync) {
at::Tensor coyped_tensor = torch::lazy::CopyTensor(tensor, dtype());
SetTensorData(coyped_tensor);
data()->handle = nullptr;
data()->sharding = nullptr;
AssignIrValue(torch::lazy::Value());
if (data()->view != nullptr) {
torch::lazy::Value ir_value = GetIrValueForTensor(coyped_tensor, device);
Expand Down

0 comments on commit defd0a9

Please sign in to comment.