Commit 7b5aca6

Add autocast support for einsum (#8420)
1 parent 5c062ea commit 7b5aca6
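
The user-visible effect of this change is that torch.einsum now runs in the autocast lower-precision dtype (bfloat16 by default on XLA) inside a torch.autocast("xla") region, instead of staying in the input precision. A minimal sketch of that behavior, assuming an XLA device is available and the default autocast dtype (tensor names are illustrative):

import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()
a = torch.randn(16, 10, device=device)  # float32 inputs
b = torch.randn(5, 10, device=device)

with torch.autocast("xla"):
  product = torch.einsum("...n,mn->...m", a, b)

# With einsum on the autocast lower-precision list, the result should be
# bfloat16 even though the inputs are float32.
print(product.dtype)  # expected: torch.bfloat16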

File tree

3 files changed: +54 -29 lines changed

test/test_autocast_xla.py

Lines changed: 53 additions & 0 deletions
@@ -0,0 +1,53 @@
import re
import torch
import torch_xla
import torch_xla.core.xla_model as xm
import unittest

device = xm.xla_device()


class TestAutocastXla(unittest.TestCase):

  def test_cross_entropy_loss(self):
    data = torch.randn(16, 10).to(torch.bfloat16).to(device)
    target = torch.randn(16, 10).to(torch.bfloat16).to(device)

    with torch.autocast("xla"):
      loss = torch.nn.CrossEntropyLoss()(data, target)
      hlo = torch_xla._XLAC._get_xla_tensors_hlo([loss])
      self.assertRegex(hlo, r".*convert.*f32.*convert.*bf16")
      self.assertRegex(hlo, r".*exponential.*f32.*exponential.*f32")
      self.assertRegex(hlo, r".*log.*f32.*log.*f32")

  def test_einsum(self):
    # Irrespective of the input dtype, the output dtype depends on the
    # autocast policy. Both bf16 and f32 inputs are tested below.

    # Input data of type bf16.
    data = torch.randn(16, 10).to(torch.bfloat16).to(device)
    target = torch.randn(5, 10).to(torch.bfloat16).to(device)

    with torch.autocast("xla"):
      product = torch.einsum("...n,mn->...m", data, target)
      # Check the HLO to see whether autocast covers the einsum op,
      # which shows up as a dot op in the HLO.
      hlo = torch_xla._XLAC._get_xla_tensors_hlo([product])
      # Verify that the dot op has bf16 output and not f32, i.e. the
      # computation is performed in bfloat16 precision by autocast.
      self.assertRegex(hlo, r".*dot.*bf16")
      self.assertNotRegex(hlo, r".*dot.*f32")

    # Input data of type f32.
    data32 = torch.randn(16, 10).to(torch.float32).to(device)
    target32 = torch.randn(5, 10).to(torch.float32).to(device)

    with torch.autocast("xla"):
      product = torch.einsum("...n,mn->...m", data32, target32)
      # Check the HLO again for the dot op emitted by einsum.
      hlo = torch_xla._XLAC._get_xla_tensors_hlo([product])
      # Even with f32 inputs, autocast casts the operands, so the dot op
      # should still have bf16 output rather than f32.
      self.assertRegex(hlo, r".*dot.*bf16")
      self.assertNotRegex(hlo, r".*dot.*f32")


if __name__ == "__main__":
  unittest.main()

test/test_bf16_autocast.py

Lines changed: 0 additions & 29 deletions
This file was deleted.

torch_xla/csrc/autocast_mode.cpp

Lines changed: 1 addition & 0 deletions
@@ -48,6 +48,7 @@ TORCH_LIBRARY_IMPL(aten, AutocastXLA, m) {
  KERNEL_XLA(prelu, lower_precision_fp)
  KERNEL_XLA(relu, lower_precision_fp)
  KERNEL_XLA(max_pool2d, lower_precision_fp)
+ KERNEL_XLA(einsum, lower_precision_fp)
  // Disable `scaled_dot_product_attention` for now since it causes
  // undefined symbol with official torch whl.
  // KERNEL_XLA(scaled_dot_product_attention, lower_precision_fp)
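
For context, lower_precision_fp is the autocast policy that casts an op's floating-point inputs to the autocast dtype (bfloat16 by default on XLA) before dispatch, so this line puts einsum on the same cast list as prelu, relu, and max_pool2d above. If a particular einsum should stay in full precision inside an autocast region, the usual PyTorch pattern of nesting a disabled autocast context should still apply; a hedged sketch, assuming an XLA device (variable names are illustrative):

import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()
a = torch.randn(16, 10, device=device)
b = torch.randn(5, 10, device=device)

with torch.autocast("xla"):
  low = torch.einsum("...n,mn->...m", a, b)  # cast to bfloat16 by autocast
  # Opt this region out of autocast to keep the einsum in full precision.
  with torch.autocast("xla", enabled=False):
    full = torch.einsum("...n,mn->...m", a, b)

print(low.dtype, full.dtype)  # expected: torch.bfloat16 torch.float32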
