diff --git a/test/test_operations.py b/test/test_operations.py
index fc6765d0483..8d772a140f5 100644
--- a/test/test_operations.py
+++ b/test/test_operations.py
@@ -817,6 +817,33 @@ def test_view_as_real_c128(self):
     self.assertIn("f64[4,2]",
                   torch_xla._XLAC._get_xla_tensors_text([real]).split('\n')[-3])
 
+  @skipIfFunctionalizationDisabled("view_as_complex unsupported")
+  def test_view_as_complex_f32(self):
+    xla_device = torch_xla.device()
+    x = torch.randn(4, 2, device=xla_device)
+    complex = torch.view_as_complex(x)
+    self.assertEqual(complex.dtype, torch.complex64)
+    # XLA type of the complex tensor needs to be c64
+    self.assertIn("c64[4]", torch_xla._XLAC._get_xla_tensor_debug_info(complex))
+    # HLO generated needs to have type c64 as well
+    self.assertIn(
+        "c64[4]",
+        torch_xla._XLAC._get_xla_tensors_text([complex]).split('\n')[-3])
+
+  @skipIfFunctionalizationDisabled("view_as_complex unsupported")
+  def test_view_as_complex_f64(self):
+    xla_device = torch_xla.device()
+    x = torch.randn(4, 2, dtype=torch.float64, device=xla_device)
+    complex = torch.view_as_complex(x)
+    self.assertEqual(complex.dtype, torch.complex128)
+    # XLA type of the complex tensor needs to be c128
+    self.assertIn("c128[4]",
+                  torch_xla._XLAC._get_xla_tensor_debug_info(complex))
+    # HLO generated needs to have type c128 as well
+    self.assertIn(
+        "c128[4]",
+        torch_xla._XLAC._get_xla_tensors_text([complex]).split('\n')[-3])
+
   def test_index_put(self):
     xla_device = xm.xla_device()
     a = torch.tensor([1, 1, 1, 1]).to(xla_device).to(dtype=torch.float32)
diff --git a/torch_xla/csrc/ops/ops.cpp b/torch_xla/csrc/ops/ops.cpp
index 84a3f623200..39c4bf54321 100644
--- a/torch_xla/csrc/ops/ops.cpp
+++ b/torch_xla/csrc/ops/ops.cpp
@@ -981,11 +981,25 @@ torch::lazy::NodePtr ViewAsComplexCopy(const torch::lazy::Value& input) {
     return node.ReturnOp(xla::Complex(zero_dim, first_dim), loctx);
   };
 
-  xla::Shape result_shape = GetXlaShape(input);
-  result_shape.DeleteDimension(result_shape.rank() - 1);
+  // Derive the complex result type from the real input type.
+  xla::Shape input_shape = GetXlaShape(input);
+  xla::Shape res_shape;
+  switch (input_shape.element_type()) {
+    case xla::PrimitiveType::F32:
+      res_shape = xla::ShapeUtil::MakeShape(xla::PrimitiveType::C64,
+                                            input_shape.dimensions());
+      break;
+    case xla::PrimitiveType::F64:
+      res_shape = xla::ShapeUtil::MakeShape(xla::PrimitiveType::C128,
+                                            input_shape.dimensions());
+      break;
+    default:
+      XLA_ERROR() << "input shape type not supported: " << input_shape;
+  }
+  res_shape.DeleteDimension(res_shape.rank() - 1);
 
   return GenericOp(torch::lazy::OpKind(at::aten::view_as_complex_copy), {input},
-                   result_shape, std::move(lower_fn));
+                   res_shape, std::move(lower_fn));
 }
 
 torch::lazy::NodePtr ViewAsRealCopy(const torch::lazy::Value& input) {
diff --git a/torch_xla/csrc/tensor_methods.cpp b/torch_xla/csrc/tensor_methods.cpp
index f73b5ebb3cf..5c610b4885a 100644
--- a/torch_xla/csrc/tensor_methods.cpp
+++ b/torch_xla/csrc/tensor_methods.cpp
@@ -3562,7 +3562,7 @@ XLATensorPtr view_symint(const XLATensorPtr& input,
 
 XLATensorPtr view_as_complex_copy(const XLATensorPtr& input) {
   return input->CreateFrom(ViewAsComplexCopy(input->GetIrValue()),
-                           at::ScalarType::ComplexFloat);
+                           /*logical_element_type=*/std::nullopt);
 }
 
 XLATensorPtr view_as_real_copy(const XLATensorPtr& input) {