Skip to content

Commit

Permalink
[torchax] fix einsum op conversion
Browse files Browse the repository at this point in the history
  • Loading branch information
yaochengji committed Jan 31, 2025
1 parent e583c2c commit 72ad006
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 9 deletions.
16 changes: 15 additions & 1 deletion torchax/test/test_core_aten_ops.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,15 @@
import math
import os
import sys
import unittest

import torch
from torchax import tensor

from . import test_base
test_folder = os.path.dirname(os.path.dirname(os.path.abspath(sys.argv[0])))
sys.path.append(test_folder)

import test_base
from torch.utils import _pytree as pytree


Expand Down Expand Up @@ -4405,6 +4410,15 @@ def test_aten_rand_like(self):
kwargs = dict()
run_export_and_compare(self, torch.ops.aten.rand_like, args, kwargs, atol=math.inf, check_dtype=True)

def test_einsum(self):
args = (
"bshd,bthd->bsht",
torch.ones((1, 2, 4, 8), dtype=torch.float32),
torch.ones((1, 2, 4, 8), dtype=torch.float32),
)
kwargs = dict()
run_export_and_compare(self, torch.einsum, args, kwargs, check_dtype=True)


if __name__ == "__main__":
test_base.main()
16 changes: 8 additions & 8 deletions torchax/torchax/ops/jtorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,14 +86,14 @@ def _diag(input, diagonal=0):
def _einsum(equation, *operands):
def get_params(*a):
inner_list = a[0]
if len(inner_list) == 1:
A = inner_list
return A
elif len(inner_list) == 2:
A, B = inner_list
return A, B
else:
return operands
if not isinstance(inner_list, jax.Array):
if len(inner_list) == 1:
A = inner_list
return A
elif len(inner_list) == 2:
A, B = inner_list
return A, B
return operands
assert isinstance(equation, str), 'Only accept str equation'
filtered_operands = get_params(*operands)
return jnp.einsum(equation, *filtered_operands)
Expand Down

0 comments on commit 72ad006

Please sign in to comment.