Add autocast support for einsum (#8420)
aws-nm9 authored Dec 9, 2024
1 parent 5c062ea commit 7b5aca6
Showing 3 changed files with 54 additions and 29 deletions.
53 changes: 53 additions & 0 deletions test/test_autocast_xla.py
@@ -0,0 +1,53 @@
import re
import torch
import torch_xla
import torch_xla.core.xla_model as xm
import unittest

device = xm.xla_device()


class TestAutocastXla(unittest.TestCase):

  def test_cross_entropy_loss(self):
    data = torch.randn(16, 10).to(torch.bfloat16).to(device)
    target = torch.randn(16, 10).to(torch.bfloat16).to(device)

    with torch.autocast("xla"):
      loss = torch.nn.CrossEntropyLoss()(data, target)
      hlo = torch_xla._XLAC._get_xla_tensors_hlo([loss])
      self.assertRegex(hlo, r".*convert.*f32.*convert.*bf16")
      self.assertRegex(hlo, r".*exponential.*f32.*exponential.*f32")
      self.assertRegex(hlo, r".*log.*f32.*log.*f32")

  def test_einsum(self):
    # The output dtype is determined by the autocast policy, irrespective of
    # the input dtype. Both bf16 and f32 inputs are exercised below.

    # bf16 inputs
    data = torch.randn(16, 10).to(torch.bfloat16).to(device)
    target = torch.randn(5, 10).to(torch.bfloat16).to(device)

    with torch.autocast("xla"):
      product = torch.einsum("...n,mn->...m", data, target)
      # Inspect the HLO to check that autocast covers the einsum op, which
      # shows up as a dot op in the HLO.
      hlo = torch_xla._XLAC._get_xla_tensors_hlo([product])
      # The dot op should output bf16 rather than f32, i.e. autocast runs the
      # computation in bfloat16 precision.
      self.assertRegex(hlo, r".*dot.*bf16")
      self.assertNotRegex(hlo, r".*dot.*f32")

    # f32 inputs
    data32 = torch.randn(16, 10).to(torch.float32).to(device)
    target32 = torch.randn(5, 10).to(torch.float32).to(device)

    with torch.autocast("xla"):
      product = torch.einsum("...n,mn->...m", data32, target32)
      # Inspect the HLO to check that autocast covers the einsum op, which
      # shows up as a dot op in the HLO.
      hlo = torch_xla._XLAC._get_xla_tensors_hlo([product])
      # Even with f32 inputs, the dot op should output bf16, since autocast
      # downcasts einsum inputs to bfloat16.
      self.assertRegex(hlo, r".*dot.*bf16")
      self.assertNotRegex(hlo, r".*dot.*f32")


if __name__ == "__main__":
  unittest.main()
29 changes: 0 additions & 29 deletions test/test_bf16_autocast.py

This file was deleted.

1 change: 1 addition & 0 deletions torch_xla/csrc/autocast_mode.cpp
@@ -48,6 +48,7 @@ TORCH_LIBRARY_IMPL(aten, AutocastXLA, m) {
KERNEL_XLA(prelu, lower_precision_fp)
KERNEL_XLA(relu, lower_precision_fp)
KERNEL_XLA(max_pool2d, lower_precision_fp)
KERNEL_XLA(einsum, lower_precision_fp)
// Disable `scaled_dot_product_attention` for now since it causes
// undefined symbol with official torch whl.
// KERNEL_XLA(scaled_dot_product_attention, lower_precision_fp)
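
As a quick illustration of what the new `KERNEL_XLA(einsum, lower_precision_fp)` registration means for user code, here is a minimal sketch (not part of this commit; the shapes and variable names are arbitrary). With autocast enabled on the XLA device, float32 einsum inputs are downcast to bfloat16 before the underlying dot computation:

import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()
a = torch.randn(8, 4, device=device)  # float32 by default
b = torch.randn(6, 4, device=device)  # float32 by default

with torch.autocast("xla"):
  out = torch.einsum("...n,mn->...m", a, b)

# Under the lower_precision_fp policy the inputs are cast down, so the
# result is produced in bfloat16.
print(out.dtype)  # expected: torch.bfloat16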
