Skip to content

Commit 4688da3

Browse files
yashk2810Google-ML-Automation
authored andcommitted
Fix jax2tf failure coming from dot_general
PiperOrigin-RevId: 688738110
1 parent f8a1f02 commit 4688da3

File tree

2 files changed

+3
-1
lines changed

2 files changed

+3
-1
lines changed

jax/experimental/jax2tf/impl_no_xla.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -364,6 +364,7 @@ def _conv_general_dilated(
364364
def _dot_general(lhs, rhs, *, dimension_numbers,
365365
precision: tuple[PrecisionType, PrecisionType] | None,
366366
preferred_element_type: DType | None,
367+
out_type=None,
367368
_in_avals: Sequence[core.ShapedArray],
368369
_out_aval: core.ShapedArray):
369370
"""Implementation of lax.dot_general_p in terms of tf.linalg.einsum."""

jax/experimental/jax2tf/jax2tf.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2180,9 +2180,10 @@ def gen_conv(lhs, rhs, preferred_element_type: DType | None):
21802180
tf_impl_with_avals[lax.conv_general_dilated_p] = _conv_general_dilated
21812181

21822182

2183-
def _dot_general(lhs, rhs, *, dimension_numbers, out_type,
2183+
def _dot_general(lhs, rhs, *, dimension_numbers,
21842184
precision: lax_internal.CanonicalPrecision,
21852185
preferred_element_type: DType | None,
2186+
out_type=None,
21862187
_in_avals: Sequence[core.ShapedArray],
21872188
_out_aval: core.ShapedArray):
21882189
"""Implementation of lax.dot_general_p in terms of tf.linalg.einsum."""

0 commit comments

Comments
 (0)