Skip to content

Commit b12170e

Browse files
committed
[torchax] fix einsum op conversion
1 parent e583c2c commit b12170e

File tree

2 files changed

+13
-4
lines changed

2 files changed

+13
-4
lines changed

torchax/test/test_ops.py

+9
Original file line numberDiff line numberDiff line change
@@ -227,6 +227,15 @@ def test_reference_eager(self, device, dtype, op):
227227
run_export_and_compare(self, op, sample_input, check_output,
228228
ignore_indices=ignore_index)
229229

230+
def test_einsum(self):
231+
args = (
232+
"bshd,bthd->bsht",
233+
torch.ones((1, 2, 4, 8), dtype=torch.float32),
234+
torch.ones((1, 2, 4, 8), dtype=torch.float32),
235+
)
236+
kwargs = dict()
237+
run_export_and_compare(self, torch.einsum, args, kwargs, check_dtype=True)
238+
230239

231240
instantiate_device_type_tests(TestOpInfo, globals(), only_for={'cpu'})
232241

torchax/torchax/ops/jtorch.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -86,14 +86,14 @@ def _diag(input, diagonal=0):
8686
def _einsum(equation, *operands):
8787
def get_params(*a):
8888
inner_list = a[0]
89-
if len(inner_list) == 1:
89+
if not isinstance(inner_list, jax.Array):
90+
if len(inner_list) == 1:
9091
A = inner_list
9192
return A
92-
elif len(inner_list) == 2:
93+
elif len(inner_list) == 2:
9394
A, B = inner_list
9495
return A, B
95-
else:
96-
return operands
96+
return operands
9797
assert isinstance(equation, str), 'Only accept str equation'
9898
filtered_operands = get_params(*operands)
9999
return jnp.einsum(equation, *filtered_operands)

0 commit comments

Comments
 (0)