|
| 1 | +import re |
| 2 | +import torch |
| 3 | +import torch_xla |
| 4 | +import torch_xla.core.xla_model as xm |
| 5 | +import unittest |
| 6 | + |
| 7 | +device = xm.xla_device() |
| 8 | + |
| 9 | + |
| 10 | +class TestAutocastXla(unittest.TestCase): |
| 11 | + |
| 12 | + def test_cross_entropy_loss(self): |
| 13 | + data = torch.randn(16, 10).to(torch.bfloat16).to(device) |
| 14 | + target = torch.randn(16, 10).to(torch.bfloat16).to(device) |
| 15 | + |
| 16 | + with torch.autocast("xla"): |
| 17 | + loss = torch.nn.CrossEntropyLoss()(data, target) |
| 18 | + hlo = torch_xla._XLAC._get_xla_tensors_hlo([loss]) |
| 19 | + self.assertRegex(hlo, r".*convert.*f32.*convert.*bf16") |
| 20 | + self.assertRegex(hlo, r".*exponential.*f32.*exponential.*f32") |
| 21 | + self.assertRegex(hlo, r".*log.*f32.*log.*f32") |
| 22 | + |
| 23 | + def test_einsum(self): |
| 24 | + # irrespective of input dtype, output dtype will depend on autocast policy. |
| 25 | + # Tests for bf16 and f32 given below. |
| 26 | + |
| 27 | + # input data of type bf16 |
| 28 | + data = torch.randn(16, 10).to(torch.bfloat16).to(device) |
| 29 | + target = torch.randn(5, 10).to(torch.bfloat16).to(device) |
| 30 | + |
| 31 | + with torch.autocast("xla"): |
| 32 | + product = torch.einsum("...n,mn->...m", data, target) |
| 33 | + # test the HLO to see if autocast works for einsum op, which would show up as a dot op in the HLO |
| 34 | + hlo = torch_xla._XLAC._get_xla_tensors_hlo([product]) |
| 35 | + # Verify that dot op has bf16 output and not f32, i.e. the computation is performed in bfloat16 precision by autocast |
| 36 | + self.assertRegex(hlo, r".*dot.*bf16") |
| 37 | + self.assertNotRegex(hlo, r".*dot.*f32") |
| 38 | + |
| 39 | + # input data of type fp32 |
| 40 | + data32 = torch.randn(16, 10).to(torch.float32).to(device) |
| 41 | + target32 = torch.randn(5, 10).to(torch.float32).to(device) |
| 42 | + |
| 43 | + with torch.autocast("xla"): |
| 44 | + product = torch.einsum("...n,mn->...m", data32, target32) |
| 45 | + # test the HLO to see if autocast works for einsum op, which would show up as a dot op in the HLO |
| 46 | + hlo = torch_xla._XLAC._get_xla_tensors_hlo([product]) |
| 47 | + # Verify that dot op has bf16 output and not f32, i.e. the computation is performed in bfloat16 precision by autocast |
| 48 | + self.assertRegex(hlo, r".*dot.*bf16") |
| 49 | + self.assertNotRegex(hlo, r".*dot.*f32") |
| 50 | + |
| 51 | + |
| 52 | +if __name__ == "__main__": |
| 53 | + unittest.main() |
0 commit comments