@@ -1377,6 +1377,32 @@ def test_data_loader_with_non_batch_size_and_mini_batch(self):
     ):
       data, _ = iter(train_device_loader).__next__()
 
+  def test_fallback(self):
+    device = torch_xla.device()
+
+    theta: float = 10000
+    dim = 16
+    end = 2048
+
+    torch_xla.sync()
+    freqs = 1.0 / (
+        theta
+        **(torch.arange(0, dim, 2, device=device)[:(dim // 2)].float() / dim))
+    t = torch.arange(end, device=freqs.device)
+    freqs = torch.outer(t, freqs).float()
+    freqs_cis = torch.polar(torch.ones_like(freqs, device=device),
+                            freqs)  # complex64
+    # torch.polar will fall back to CPU, so the result tensor should not have any sharding spec.
+    self.assertIn("ShardingSpec: None",
+                  torch_xla._XLAC._get_xla_tensor_debug_info(freqs_cis))
+    # The result is a CPU tensor; since no sharding spec is specified, it won't be moved to the device yet.
+    self.assertIn("Tensor on host: with size [2048, 8]",
+                  torch_xla._XLAC._get_xla_tensor_debug_info(freqs_cis))
+    torch_xla.sync()
+    # The data should be on the device and replicated now.
+    self.assertIn("Data Shape: c64[2048,8]\n OpSharding: {replicated}",
+                  torch_xla._XLAC._get_xla_tensor_debug_info(freqs_cis))
+
 
 if __name__ == '__main__':
   test = unittest.main()
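
A minimal standalone sketch of the CPU-fallback behavior this test asserts on, assuming torch and torch_xla are installed and an XLA device is available; the tensor shapes and variable names below are illustrative and not taken from the test itself.

```python
import torch
import torch_xla

device = torch_xla.device()

abs_part = torch.ones(4, 4, device=device)
angle = torch.zeros(4, 4, device=device)

# torch.polar falls back to CPU on XLA, so the result starts out as host data
# with no sharding spec attached.
z = torch.polar(abs_part, angle)
print(torch_xla._XLAC._get_xla_tensor_debug_info(z))

# After a sync, the host data is transferred to the XLA device.
torch_xla.sync()
print(torch_xla._XLAC._get_xla_tensor_debug_info(z))
```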