[torchax] fix einsum op conversion (#8657)
yaochengji authored Jan 31, 2025
1 parent 39dd795 commit 8572e75
Showing 3 changed files with 100 additions and 33 deletions.
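For context: the commit registers lowerings for aten.softmax and aten.where in jaten.py, fixes operand unpacking in the torch.einsum wrapper in jtorch.py, and adds tests for all three. The new test_einsum exercises an attention-style contraction; the sketch below shows the same call in plain PyTorch (shapes taken from the test, everything else illustrative).

    import torch

    # b=1 batch, s=t=2 positions, h=4 heads, d=8 head dim, as in the new test.
    q = torch.ones((1, 2, 4, 8), dtype=torch.float32)
    k = torch.ones((1, 2, 4, 8), dtype=torch.float32)
    scores = torch.einsum("bshd,bthd->bsht", q, k)
    print(scores.shape)  # torch.Size([1, 2, 4, 2])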
119 changes: 92 additions & 27 deletions torchax/test/test_core_aten_ops.py
@@ -8,17 +8,35 @@
from torch.utils import _pytree as pytree


def diff_output(testcase, output1, output2, rtol, atol, equal_nan=True, check_dtype=False):
def diff_output(testcase,
output1,
output2,
rtol,
atol,
equal_nan=True,
check_dtype=False):
if isinstance(output1, torch.Tensor):
testcase.assertIsInstance(output2, torch.Tensor)
output2_cpu = output2.detach().cpu()
torch.testing.assert_close(
output1, output2_cpu, atol=atol, rtol=rtol, equal_nan=equal_nan, check_dtype=check_dtype)
output1,
output2_cpu,
atol=atol,
rtol=rtol,
equal_nan=equal_nan,
check_dtype=check_dtype)
elif isinstance(output1, (tuple, list)):
testcase.assertIsInstance(output2, (tuple, list))
testcase.assertEqual(len(output1), len(output2))
for o1, o2 in zip(output1, output2):
diff_output(testcase, o1, o2, rtol, atol, equal_nan=equal_nan, check_dtype=check_dtype)
diff_output(
testcase,
o1,
o2,
rtol,
atol,
equal_nan=equal_nan,
check_dtype=check_dtype)
else:
testcase.assertEqual(output1, output2)

@@ -53,7 +71,13 @@ def run_export_and_compare(testcase,
check_dtype=check_dtype)
else:
diff_output(
testcase, res, res2, atol=atol, rtol=rtol, equal_nan=equal_nan, check_dtype=check_dtype)
testcase,
res,
res2,
atol=atol,
rtol=rtol,
equal_nan=equal_nan,
check_dtype=check_dtype)


class TestCoreAtenOps(unittest.TestCase):
@@ -1066,19 +1090,22 @@ def test_aten_convolution_2(self):
)
kwargs = dict()
run_export_and_compare(self, torch.ops.aten.convolution, args, kwargs)

def test_aten_copy_0(self):
args = (torch.randn((10, 10)).to(torch.float32), torch.randn((10, 10)).to(torch.float32))
args = (torch.randn((10, 10)).to(torch.float32), torch.randn(
(10, 10)).to(torch.float32))
kwargs = dict()
run_export_and_compare(self, torch.ops.aten.copy, args, kwargs)

def test_aten_copy_broadcast(self):
args = (torch.randn((10, 10)).to(torch.float32), torch.tensor(1.0, dtype=torch.float32))
args = (torch.randn(
(10, 10)).to(torch.float32), torch.tensor(1.0, dtype=torch.float32))
kwargs = dict()
run_export_and_compare(self, torch.ops.aten.copy, args, kwargs)

def test_aten_copy_cast_dtype(self):
args = (torch.randn((10, 10)).to(torch.float32), torch.randn((10, 10)).to(torch.int64))
args = (torch.randn((10, 10)).to(torch.float32), torch.randn(
(10, 10)).to(torch.int64))
kwargs = dict()
run_export_and_compare(self, torch.ops.aten.copy, args, kwargs)

@@ -1835,7 +1862,7 @@ def test_aten_index_select_2(self):
)
kwargs = dict()
run_export_and_compare(self, torch.ops.aten.index_select, args, kwargs)

def test_aten_index_select_int32_index(self):
args = (
torch.randint(0, 10, (2, 10)).to(torch.int32),
@@ -2151,7 +2178,13 @@ def test_aten_logit_0(self):
def test_aten_logit_1(self):
args = (torch.randn((10, 10)).to(torch.float16),)
kwargs = dict()
run_export_and_compare(self, torch.ops.aten.logit, args, kwargs, atol=0.01,)
run_export_and_compare(
self,
torch.ops.aten.logit,
args,
kwargs,
atol=0.01,
)

def test_aten_logit_2(self):
args = (torch.randint(0, 10, (10, 10)).to(torch.int32),)
@@ -2738,7 +2771,7 @@ def test_aten_native_layer_norm_0(self):
)
kwargs = dict()
run_export_and_compare(self, torch.ops.aten.native_layer_norm, args, kwargs)

def test_aten_native_layer_norm_1(self):
args = (
torch.randn((1, 10, 10, 10)).to(torch.float32),
@@ -2754,7 +2787,7 @@ def test_aten_native_batch_norm_legit(self):
batch = 3
channel = 2
args = (
torch.randn((batch,channel,2,2)).to(torch.float32),
torch.randn((batch, channel, 2, 2)).to(torch.float32),
torch.ones(channel),
torch.zeros(channel),
torch.zeros(channel),
@@ -2764,13 +2797,14 @@ def test_aten_native_batch_norm_legit(self):
1,
)
kwargs = dict()
run_export_and_compare(self, torch.ops.aten._native_batch_norm_legit, args, kwargs)
run_export_and_compare(self, torch.ops.aten._native_batch_norm_legit, args,
kwargs)

def test_aten_native_batch_norm_legit_none(self):
batch = 3
channel = 2
args = (
torch.randn((batch,channel,4,4)).to(torch.float32),
torch.randn((batch, channel, 4, 4)).to(torch.float32),
None,
None,
torch.ones(channel),
@@ -2780,13 +2814,14 @@ def test_aten_native_batch_norm_legit_none(self):
1,
)
kwargs = dict()
run_export_and_compare(self, torch.ops.aten._native_batch_norm_legit, args, kwargs)
run_export_and_compare(self, torch.ops.aten._native_batch_norm_legit, args,
kwargs)

def test_aten_native_batch_norm_legit_training_none(self):
batch = 3
channel = 2
args = (
torch.randn((batch,channel,4,3)).to(torch.float32),
torch.randn((batch, channel, 4, 3)).to(torch.float32),
None,
None,
torch.zeros(channel),
@@ -2796,13 +2831,14 @@ def test_aten_native_batch_norm_legit_training_none(self):
2e-5,
)
kwargs = dict()
run_export_and_compare(self, torch.ops.aten._native_batch_norm_legit, args, kwargs)
run_export_and_compare(self, torch.ops.aten._native_batch_norm_legit, args,
kwargs)

def test_aten_native_batch_norm_legit_no_training(self):
batch = 3
channel = 2
args = (
torch.randn((batch,channel,4,3)).to(torch.float32),
torch.randn((batch, channel, 4, 3)).to(torch.float32),
torch.ones(channel),
torch.zeros(channel),
torch.zeros(channel),
@@ -2811,13 +2847,15 @@ def test_aten_native_batch_norm_legit_no_training(self):
2e-5,
)
kwargs = dict()
run_export_and_compare(self, torch.ops.aten._native_batch_norm_legit_no_training, args, kwargs)
run_export_and_compare(self,
torch.ops.aten._native_batch_norm_legit_no_training,
args, kwargs)

def test_aten_native_batch_norm_training(self):
batch = 3
channel = 2
args = (
torch.randn((batch,channel,4,3)).to(torch.float32),
torch.randn((batch, channel, 4, 3)).to(torch.float32),
torch.ones(channel),
torch.zeros(channel),
torch.zeros(channel),
@@ -2833,7 +2871,7 @@ def test_aten_native_batch_norm_training_none(self):
batch = 3
channel = 2
args = (
torch.randn((batch,channel,4,3)).to(torch.float32),
torch.randn((batch, channel, 4, 3)).to(torch.float32),
None,
None,
torch.zeros(channel),
@@ -2849,7 +2887,7 @@ def test_aten_native_batch_norm_eval(self):
batch = 3
channel = 2
args = (
torch.randn((batch,channel,4,3)).to(torch.float32),
torch.randn((batch, channel, 4, 3)).to(torch.float32),
torch.ones(channel),
torch.zeros(channel),
torch.zeros(channel),
@@ -3810,6 +3848,14 @@ def test_aten__softmax_1(self):
kwargs = dict()
run_export_and_compare(self, torch.ops.aten._softmax, args, kwargs)

def test_aten_softmax(self):
args = (
torch.randn((10, 10)).to(torch.float32),
1,
)
kwargs = dict()
run_export_and_compare(self, torch.ops.aten.softmax, args, kwargs)

def _compare_sorted_result(self, args):
res = torch.ops.aten.sort(*args)
with self.subTest("torchax_eval"):
@@ -4390,20 +4436,39 @@ def test_aten_where_self_0(self):
kwargs = dict()
run_export_and_compare(self, torch.ops.aten.where.self, args, kwargs)

def test_aten_where(self):
args = (torch.randn((10, 10)).to(torch.bool),)
kwargs = dict()
run_export_and_compare(self, torch.ops.aten.where, args, kwargs)

def test_aten_copy_dtype(self):
args = (
torch.ones((3, 3), dtype=torch.int32),
torch.zeros((3, 3), dtype=torch.float32),
)
kwargs = dict()
run_export_and_compare(self, torch.ops.aten.copy_, args, kwargs, check_dtype=True)
run_export_and_compare(
self, torch.ops.aten.copy_, args, kwargs, check_dtype=True)

def test_aten_rand_like(self):
args = (torch.ones((3, 3), dtype=torch.bfloat16),)
kwargs = dict()
run_export_and_compare(self, torch.ops.aten.rand_like, args, kwargs, atol=math.inf, check_dtype=True)
run_export_and_compare(
self,
torch.ops.aten.rand_like,
args,
kwargs,
atol=math.inf,
check_dtype=True)

def test_einsum(self):
args = (
"bshd,bthd->bsht",
torch.ones((1, 2, 4, 8), dtype=torch.float32),
torch.ones((1, 2, 4, 8), dtype=torch.float32),
)
kwargs = dict()
run_export_and_compare(self, torch.einsum, args, kwargs, check_dtype=True)


if __name__ == "__main__":
6 changes: 4 additions & 2 deletions torchax/torchax/ops/jaten.py
@@ -461,7 +461,8 @@ def _aten_stack(tensors, dim=0):


@op(torch.ops.aten._softmax)
def _aten_softmax(x, dim, halftofloat):
@op(torch.ops.aten.softmax)
def _aten_softmax(x, dim, halftofloat=False):
if x.shape == ():
return jax.nn.softmax(x.reshape([1]), axis=0).reshape([])
return jax.nn.softmax(x, dim)
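A side note on the extra @op(torch.ops.aten.softmax) registration (explanatory, not part of the commit): torch.softmax dispatches through the aten::softmax packet rather than aten::_softmax, so registering both lets either entry point reach this lowering, with the trailing argument defaulted when the shorter call form is used (as in the new test_aten_softmax). In eager PyTorch the two ops agree:

    import torch

    x = torch.randn(10, 10)
    y1 = torch.ops.aten.softmax(x, 1)          # the form the new test exercises
    y2 = torch.ops.aten._softmax(x, 1, False)  # the previously registered form
    print(torch.allclose(y1, y2))  # True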
@@ -3281,11 +3282,12 @@ def _aten_unique_consecutive(input_tensor,


# aten.where
@op(torch.ops.aten.where)
@op(torch.ops.aten.where.self)
@op(torch.ops.aten.where.ScalarSelf)
@op(torch.ops.aten.where.ScalarOther)
@op(torch.ops.aten.where.Scalar)
def _aten_where(condition, x, y):
def _aten_where(condition, x=None, y=None):
return jnp.where(condition, x, y)
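For reference (illustrative, not from the commit): making x and y optional covers the single-argument form that the new @op(torch.ops.aten.where) registration and test_aten_where exercise — a bare condition returns index tensors rather than a selection, and jnp.where behaves the same way when x and y are left as None. In eager PyTorch:

    import torch

    cond = torch.tensor([[True, False],
                         [False, True]])
    # Single-argument where: indices of the True elements,
    # equivalent to torch.nonzero(cond, as_tuple=True).
    rows, cols = torch.where(cond)
    print(rows.tolist(), cols.tolist())  # [0, 1] [0, 1]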


8 changes: 4 additions & 4 deletions torchax/torchax/ops/jtorch.py
@@ -86,14 +86,14 @@ def _diag(input, diagonal=0):
def _einsum(equation, *operands):
def get_params(*a):
inner_list = a[0]
if len(inner_list) == 1:
if not isinstance(inner_list, jax.Array):
if len(inner_list) == 1:
A = inner_list
return A
elif len(inner_list) == 2:
elif len(inner_list) == 2:
A, B = inner_list
return A, B
else:
return operands
return operands
assert isinstance(equation, str), 'Only accept str equation'
filtered_operands = get_params(*operands)
return jnp.einsum(equation, *filtered_operands)
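Why the isinstance check was added (explanatory sketch, not part of the commit): torch.einsum accepts its operands either individually, torch.einsum(eq, A, B), or wrapped in a single list, torch.einsum(eq, [A, B]). The old get_params only looked at len(a[0]), so an ordinary first operand whose leading dimension happened to be 1 or 2 could be mistaken for that wrapper list and the remaining operands dropped; checking for jax.Array first lets already-unpacked arrays pass straight through to jnp.einsum. Both call shapes in plain PyTorch, for reference:

    import torch

    A = torch.ones((1, 2, 4, 8))
    B = torch.ones((1, 2, 4, 8))

    # Operands passed individually (the case the new test_einsum covers) ...
    out1 = torch.einsum("bshd,bthd->bsht", A, B)
    # ... and the equivalent single-list form that get_params has to unwrap.
    out2 = torch.einsum("bshd,bthd->bsht", [A, B])
    print(out1.shape, torch.equal(out1, out2))  # torch.Size([1, 2, 4, 2]) True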