From 72ad00634d31f5b0dffd36e872a725b5a94f3a78 Mon Sep 17 00:00:00 2001 From: Chengji Yao Date: Fri, 31 Jan 2025 00:24:32 +0000 Subject: [PATCH] [torchax] fix einsum op conversion --- torchax/test/test_core_aten_ops.py | 16 +++++++++++++++- torchax/torchax/ops/jtorch.py | 16 ++++++++-------- 2 files changed, 23 insertions(+), 9 deletions(-) diff --git a/torchax/test/test_core_aten_ops.py b/torchax/test/test_core_aten_ops.py index ca0c5c15a9f..3f8caf8c1c4 100644 --- a/torchax/test/test_core_aten_ops.py +++ b/torchax/test/test_core_aten_ops.py @@ -1,10 +1,15 @@ import math +import os +import sys import unittest import torch from torchax import tensor -from . import test_base +test_folder = os.path.dirname(os.path.dirname(os.path.abspath(sys.argv[0]))) +sys.path.append(test_folder) + +import test_base from torch.utils import _pytree as pytree @@ -4405,6 +4410,15 @@ def test_aten_rand_like(self): kwargs = dict() run_export_and_compare(self, torch.ops.aten.rand_like, args, kwargs, atol=math.inf, check_dtype=True) + def test_einsum(self): + args = ( + "bshd,bthd->bsht", + torch.ones((1, 2, 4, 8), dtype=torch.float32), + torch.ones((1, 2, 4, 8), dtype=torch.float32), + ) + kwargs = dict() + run_export_and_compare(self, torch.einsum, args, kwargs, check_dtype=True) + if __name__ == "__main__": test_base.main() diff --git a/torchax/torchax/ops/jtorch.py b/torchax/torchax/ops/jtorch.py index 42e26cc98b8..e6b529664ab 100644 --- a/torchax/torchax/ops/jtorch.py +++ b/torchax/torchax/ops/jtorch.py @@ -86,14 +86,14 @@ def _diag(input, diagonal=0): def _einsum(equation, *operands): def get_params(*a): inner_list = a[0] - if len(inner_list) == 1: - A = inner_list - return A - elif len(inner_list) == 2: - A, B = inner_list - return A, B - else: - return operands + if not isinstance(inner_list, jax.Array): + if len(inner_list) == 1: + A = inner_list + return A + elif len(inner_list) == 2: + A, B = inner_list + return A, B + return operands assert isinstance(equation, str), 'Only accept str equation' filtered_operands = get_params(*operands) return jnp.einsum(equation, *filtered_operands)