Fix incorrect xla type for view_as_complex (#8382)
JackCaoG authored Nov 15, 2024
1 parent 5b3c331 commit bd4006e
Showing 3 changed files with 44 additions and 4 deletions.
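Before this fix, ViewAsComplexCopy built its IR node with the real input's XLA shape (f32/f64) and the tensor method hardcoded the logical dtype to ComplexFloat, so view_as_complex carried an incorrect XLA type and always reported complex64, even for double inputs. A minimal sketch of the behavior this commit establishes, assuming a working torch_xla install (see the new tests below):

```python
import torch
import torch_xla

# The complex dtype must track the precision of the real input.
x32 = torch.randn(4, 2, device=torch_xla.device())
x64 = torch.randn(4, 2, dtype=torch.float64, device=torch_xla.device())
assert torch.view_as_complex(x32).dtype == torch.complex64
assert torch.view_as_complex(x64).dtype == torch.complex128  # was complex64 before this fix
```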
27 changes: 27 additions & 0 deletions test/test_operations.py
@@ -817,6 +817,33 @@ def test_view_as_real_c128(self):
     self.assertIn("f64[4,2]",
                   torch_xla._XLAC._get_xla_tensors_text([real]).split('\n')[-3])

+  @skipIfFunctionalizationDisabled("view_as_complex unsupported")
+  def test_view_as_complex_f32(self):
+    xla_device = torch_xla.device()
+    x = torch.randn(4, 2, device=xla_device)
+    complex = torch.view_as_complex(x)
+    self.assertEqual(complex.dtype, torch.complex64)
+    # XLA type of the complex tensor needs to be c64
+    self.assertIn("c64[4]", torch_xla._XLAC._get_xla_tensor_debug_info(complex))
+    # HLO generated needs to have type c64 as well
+    self.assertIn(
+        "c64[4]",
+        torch_xla._XLAC._get_xla_tensors_text([complex]).split('\n')[-3])
+
+  @skipIfFunctionalizationDisabled("view_as_complex unsupported")
+  def test_view_as_complex_f64(self):
+    xla_device = torch_xla.device()
+    x = torch.randn(4, 2, dtype=torch.float64, device=xla_device)
+    complex = torch.view_as_complex(x)
+    self.assertEqual(complex.dtype, torch.complex128)
+    # XLA type of the complex tensor needs to be c128
+    self.assertIn("c128[4]",
+                  torch_xla._XLAC._get_xla_tensor_debug_info(complex))
+    # HLO generated needs to have type c128 as well
+    self.assertIn(
+        "c128[4]",
+        torch_xla._XLAC._get_xla_tensors_text([complex]).split('\n')[-3])
+
   def test_index_put(self):
     xla_device = xm.xla_device()
     a = torch.tensor([1, 1, 1, 1]).to(xla_device).to(dtype=torch.float32)
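The assertions above read torch_xla's IR dumps directly. A standalone sketch of the same check, using the same internal APIs as the tests (assumes torch_xla is installed):

```python
import torch
import torch_xla

# After this fix, both dumps report the complex XLA type "c64[4]"
# for an f32 input rather than the real input's f32 type.
c = torch.view_as_complex(torch.randn(4, 2, device=torch_xla.device()))
print(torch_xla._XLAC._get_xla_tensor_debug_info(c))
print(torch_xla._XLAC._get_xla_tensors_text([c]))
```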
19 changes: 16 additions & 3 deletions torch_xla/csrc/ops/ops.cpp
@@ -981,11 +981,24 @@ torch::lazy::NodePtr ViewAsComplexCopy(const torch::lazy::Value& input) {
     return node.ReturnOp(xla::Complex(zero_dim, first_dim), loctx);
   };

-  xla::Shape result_shape = GetXlaShape(input);
-  result_shape.DeleteDimension(result_shape.rank() - 1);
+  xla::Shape input_shape = GetXlaShape(input);
+  xla::Shape res_shape;
+  switch (input_shape.element_type()) {
+    case xla::PrimitiveType::F32:
+      res_shape = xla::ShapeUtil::MakeShape(xla::PrimitiveType::C64,
+                                            input_shape.dimensions());
+      break;
+    case xla::PrimitiveType::F64:
+      res_shape = xla::ShapeUtil::MakeShape(xla::PrimitiveType::C128,
+                                            input_shape.dimensions());
+      break;
+    default:
+      XLA_ERROR() << "input shape type not supported: " << input_shape;
+  }
+  res_shape.DeleteDimension(res_shape.rank() - 1);

   return GenericOp(torch::lazy::OpKind(at::aten::view_as_complex_copy), {input},
-                   result_shape, std::move(lower_fn));
+                   res_shape, std::move(lower_fn));
 }

 torch::lazy::NodePtr ViewAsRealCopy(const torch::lazy::Value& input) {
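The switch above derives the complex result's XLA element type from the real input's type instead of reusing the input shape unchanged. A hypothetical Python rendering of that mapping (illustrative only; `complex_type_for` is not a torch_xla API):

```python
import torch

def complex_type_for(real_dtype: torch.dtype) -> torch.dtype:
  # Mirrors the C++ switch in ViewAsComplexCopy: F32 -> C64, F64 -> C128,
  # anything else is rejected.
  if real_dtype == torch.float32:
    return torch.complex64  # xla::PrimitiveType::C64
  if real_dtype == torch.float64:
    return torch.complex128  # xla::PrimitiveType::C128
  raise TypeError(f"input shape type not supported: {real_dtype}")
```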
2 changes: 1 addition & 1 deletion torch_xla/csrc/tensor_methods.cpp
@@ -3562,7 +3562,7 @@ XLATensorPtr view_symint(const XLATensorPtr& input,

 XLATensorPtr view_as_complex_copy(const XLATensorPtr& input) {
   return input->CreateFrom(ViewAsComplexCopy(input->GetIrValue()),
-                           at::ScalarType::ComplexFloat);
+                           /*logical_element_type=*/std::nullopt);
 }

 XLATensorPtr view_as_real_copy(const XLATensorPtr& input) {
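Passing std::nullopt for logical_element_type lets CreateFrom derive the torch dtype from the IR value's XLA shape (now c64 or c128), rather than forcing ComplexFloat as before. One user-visible consequence, sketched under the assumption of a working torch_xla install: double precision should now survive a view_as_complex/view_as_real round trip.

```python
import torch
import torch_xla

# Round-trip sketch: the logical dtype now follows the c128 XLA shape,
# so the real view of the complex result comes back as float64.
x = torch.randn(4, 2, dtype=torch.float64, device=torch_xla.device())
rt = torch.view_as_real(torch.view_as_complex(x))
assert rt.dtype == torch.float64
```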
