Skip to content

Commit afaf0d0

Browse files
authored
Added pow to autocast policy and unit test (#8684)
1 parent 8b45e59 commit afaf0d0

File tree

2 files changed

+11
-0
lines changed

2 files changed

+11
-0
lines changed

test/test_autocast_xla.py

+10
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,16 @@ def test_einsum(self):
9393
self.assertRegex(hlo, r".*dot.*bf16")
9494
self.assertNotRegex(hlo, r".*dot.*f32")
9595

96+
def test_pow(self):
97+
data = torch.randn(16, 20).to(torch.bfloat16).to(device)
98+
99+
with torch.autocast("xla"):
100+
output = data.pow(2)
101+
hlo = torch_xla._XLAC._get_xla_tensors_hlo([output])
102+
103+
self.assertRegex(hlo, r".*convert.*f32.*convert.*bf16")
104+
self.assertRegex(hlo, r".*power.*f32.*power.*f32")
105+
96106

97107
if __name__ == "__main__":
98108
unittest.main()

torch_xla/csrc/autocast_mode.cpp

+1
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@ TORCH_LIBRARY_IMPL(aten, AutocastXLA, m) {
6565
KERNEL_XLA(binary_cross_entropy, fp32)
6666
// KERNEL_XLA(grid_sampler, fp32)
6767
// KERNEL_XLA(polar, fp32)
68+
KERNEL_XLA2(pow, Tensor_Scalar, fp32)
6869
KERNEL_XLA(prod, fp32)
6970
KERNEL_XLA2(prod, dim_int, fp32)
7071
KERNEL_XLA2(prod, dim_Dimname, fp32)

0 commit comments

Comments
 (0)