diff --git a/.torch_pin b/.torch_pin
new file mode 100644
index 00000000000..9eb602820f8
--- /dev/null
+++ b/.torch_pin
@@ -0,0 +1 @@
+#138470
diff --git a/test/test_operations.py b/test/test_operations.py
index 1af928e6a47..545831acf49 100644
--- a/test/test_operations.py
+++ b/test/test_operations.py
@@ -2912,6 +2912,17 @@ def test_dlpack_xla_to_pytorch_cuda(self):
     cuda_t1[0] = cuda_t1[0] + 20
     self.assertTrue(torch.allclose(xla_t1.cpu(), cuda_t1.cpu()))
 
+  @onlyIfTorchSupportsCUDA
+  @onlyIfPJRTDeviceIsCUDA
+  def test_dlpack_xla_to_pytorch_cuda_protocol_conversion(self):
+    xla_t1 = torch.arange(5).to(xm.xla_device())
+    caps_t1 = torch.utils.dlpack.to_dlpack(xla_t1)
+    cuda_t1 = torch.utils.dlpack.from_dlpack(caps_t1)
+    self.assertEqual(cuda_t1.device.type, 'cuda')
+    self.assertEqual(cuda_t1.device.index, xla_t1.device.index)
+    cuda_t1[0] = cuda_t1[0] + 20
+    self.assertTrue(torch.allclose(xla_t1.cpu(), cuda_t1.cpu()))
+
   @onlyIfTorchSupportsCUDA
   @onlyIfPJRTDeviceIsCUDA
   def test_dlpack_non_default_layout(self):