Commit 719abab

TF2JAXDev authored and committed
[tf2jax] Fix BroadcastArgs and Fill for polymorphic shapes.

The `BroadcastArgs` and `Fill` ops failed when input shapes contained symbolic dimensions. This CL fixes `BroadcastArgs` by using `jnp.broadcast_shapes`, which is designed to handle polymorphic shapes during JAX tracing, and fixes `Fill` by correctly selecting `jnp` as the backend when the inputs contain symbolic dimensions.

PiperOrigin-RevId: 861734224
1 parent f202766 commit 719abab
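For context, a minimal sketch (not part of the commit) of the property the `BroadcastArgs` fix relies on: `jnp.broadcast_shapes` accepts the symbolic dimensions used by `jax.export` shape polymorphism, so broadcast shapes can be computed at trace time without materializing any array.

```python
import jax.numpy as jnp
from jax import export

# Symbolic dimensions standing in for sizes unknown until call time.
a, b = export.symbolic_shape("a, b")

# Broadcasting (a, 1) against (1, b) yields the symbolic shape (a, b);
# no array is ever allocated.
print(jnp.broadcast_shapes((a, 1), (1, b)))  # (a, b)
```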

File tree: 3 files changed, +167 −3 lines.

tf2jax/_src/numpy_compat.py (11 additions, 2 deletions)

```diff
@@ -89,9 +89,18 @@ def is_poly_dim(x) -> bool:
   return False
 
 
+def _is_np(x):
+  """Checks if `x` is a numpy like type."""
+  # Special case for polymorphic shape tensors. Shape tensors are stored as 1-D
+  # np.array.
+  if isinstance(x, np.ndarray) and x.ndim == 1 and any(map(is_poly_dim, x)):
+    return False
+  return isinstance(x, _NP_LIKES) or is_poly_dim(x)
+
+
 def _get_np(*args):
   """Select numpy backend based on input types."""
-  no_jax = all((isinstance(x, _NP_LIKES) or is_poly_dim(x)) for x in args)
+  no_jax = all(map(_is_np, args))
   return np if no_jax else jnp
 
 
@@ -156,7 +165,7 @@ def empty(shape, dtype: tf.DType, init: bool):
 
 
 def full(shape, fill_value, dtype: tf.DType):
-  dtype = _get_dtypes(shape)[dtype]
+  dtype = _get_dtypes(shape, fill_value)[dtype]
   return _get_np(shape, fill_value).full(shape, fill_value, dtype=dtype)
 
 
```
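To make the new special case concrete, here is an illustrative sketch (mine, not from the commit): tf2jax stores shape tensors as 1-D `np.ndarray`s, and under shape polymorphism some entries are symbolic dimension objects that plain `numpy` cannot materialize, which is why `_get_np` must fall through to `jnp` for them.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax import export

(n,) = export.symbolic_shape("n")

# A shape tensor as tf2jax stores it: a 1-D np.ndarray. With polymorphism,
# one entry is a symbolic dimension object rather than a Python int, so the
# old `isinstance(x, _NP_LIKES)` check alone would wrongly pick plain numpy.
poly_shape = np.array([n, 4])

# np.full cannot allocate a symbolically sized array, but jnp.full can
# while tracing for export:
exported = export.export(jax.jit(lambda x: jnp.full(x.shape, 0.0)))(
    jax.ShapeDtypeStruct((n, 4), jnp.float32)
)
print(exported.out_avals)  # e.g. (ShapedArray(float32[n,4]),)
```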
tf2jax/_src/ops.py (1 addition, 1 deletion)

```diff
@@ -454,7 +454,7 @@ def _bit_cast(proto):
 @register_operation("BroadcastArgs")
 def _broadcast_args(proto):
   _check_attrs(proto, {"T"})
-  return lambda s0, s1: np.array(np.broadcast(np.zeros(s0), np.zeros(s1)).shape)
+  return lambda s0, s1: np.array(jnp.broadcast_shapes(s0, s1))
 
 
 class _CaseOp(_HigherOrderFunction):
```
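As a sanity check on the rewrite (a concrete-shape sketch, not from the commit): the old lambda allocated zero-filled arrays just to read the broadcast shape off `np.broadcast`, which only works for concrete integer sizes; `jnp.broadcast_shapes` computes the same result directly and also accepts symbolic dimensions.

```python
import jax.numpy as jnp
import numpy as np

s0, s1 = (2, 3, 1), (1, 5)

# Old: allocate arrays only to query the broadcast shape; needs concrete ints.
old = np.array(np.broadcast(np.zeros(s0), np.zeros(s1)).shape)

# New: compute the shape directly, with no allocation.
new = np.array(jnp.broadcast_shapes(s0, s1))

np.testing.assert_array_equal(old, new)  # both [2 3 5]
```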

tf2jax/_src/ops_test.py (155 additions, 0 deletions)

```diff
@@ -15,11 +15,14 @@
 """Tests for tf2jax."""
 
 import contextlib
+import dataclasses
+from typing import Any
 
 from absl.testing import parameterized
 
 import chex
 import jax
+from jax import export
 from jax.experimental import checkify
 import numpy as np
 
```

```diff
@@ -36,6 +39,15 @@ def _reorder(vals, inds):
   return [vals[idx] for idx in inds]
 
 
+@dataclasses.dataclass
+class _PolymorphicInput:
+  """Wrapper class containing information for polymorphic inputs."""
+
+  tf_spec: tf.TensorSpec
+  jax_spec: jax.ShapeDtypeStruct
+  concrete_value: Any
+
+
 class OpsTest(test_util.TestCase):
 
   def test_get_unsupported(self):
```
```diff
@@ -84,6 +96,84 @@ def _test_convert(
 
     return jax_results, new_jax_params
 
+  def _test_convert_polymorphic(
+      self,
+      tf_func,
+      inputs,
+      *,
+      check_shape_only=False,
+      functional=True,
+      jit_compile=True,
+      atol=1e-5,
+  ):
+    if not isinstance(inputs, (list, tuple)):
+      inputs = (inputs,)
+
+    # Use self._test_convert instead if there is no _PolymorphicInput.
+    self.assertTrue(any(isinstance(x, _PolymorphicInput) for x in inputs))
+
+    if not hasattr(tf_func, "get_concrete_function"):
+      tf_func = tf.function(tf_func, jit_compile=jit_compile)
+
+    def get_poly_attr_or_else(attr, else_fn=None):
+      """Returns the attr of a _PolymorphicInput, otherwise applies `else_fn`."""
+
+      def mapper(x):
+        if isinstance(x, _PolymorphicInput):
+          return getattr(x, attr)
+        if else_fn is not None:
+          return else_fn(x)
+        return x
+
+      return mapper
+
+    jax_func, jax_params = tf2jax.convert(
+        tf_func,
+        *tree.map_structure(
+            get_poly_attr_or_else("tf_spec", np.zeros_like), inputs
+        ),
+    )
+    if functional:
+      self.assertEmpty(jax_params, "Expected no parameters for pure Ops.")
+
+    jax_func = self.variant(jax_func)
+
+    concrete_inputs = tree.map_structure(
+        get_poly_attr_or_else("concrete_value"), inputs
+    )
+    tf_results = tf_func(*concrete_inputs)
+
+    def assert_same(tf_results, jax_results):
+      """Compares the results of the TF and JAX functions."""
+      for tf_res, jax_res in utils.safe_zip(
+          tree.flatten(tf_results), tree.flatten(jax_results)
+      ):
+        self.assertEqual(tf_res.shape, jax_res.shape)
+        if not check_shape_only:
+          self.assertAllClose(
+              np.asarray(tf_res), np.asarray(jax_res), atol=atol
+          )
+
+    # Check the converted JAX function.
+    rng = jax.random.PRNGKey(42)
+    jax_results, new_jax_params = jax_func(
+        jax_params, *concrete_inputs, rng=rng
+    )
+    assert_same(tf_results, jax_results)
+
+    # Check the exported JAX function.
+    exp_func = export.export(jax_func)(
+        jax_params,
+        *tree.map_structure(
+            get_poly_attr_or_else("jax_spec", np.zeros_like), inputs
+        ),
+    )
+    exp_results, new_exp_params = exp_func.call(jax_params, *concrete_inputs)
+    assert_same(tf_results, exp_results)
+    assert_same(new_jax_params, new_exp_params)
+
+    return jax_results, new_jax_params
+
   @chex.variants(with_jit=True, without_jit=True)
   @parameterized.parameters("log_softmax", "sigmoid", "softmax", "softplus",
                             "tanh", "relu", "relu6", "elu", "leaky_relu")
```
```diff
@@ -527,6 +617,54 @@ def raw_func(x):
       return tf.raw_ops.Bitcast(input=x, type=tf.float32)
     self._test_convert(raw_func, inputs)
 
+  @chex.variants(with_jit=True, without_jit=True)
+  @parameterized.parameters(
+      ([1, 2], [3, 1]),
+      ([2, 3, 1], [1, 5]),
+      ([], [1]),
+      ([1], []),
+      ([], []),
+      ([3, 1, 2], [1, 5, 1]),
+  )
+  def test_broadcast_args(self, s0, s1):
+    x = np.zeros(s0, dtype=np.float32)
+    y = np.zeros(s1, dtype=np.float32)
+
+    def broadcast_args(x, y):
+      return tf.broadcast_to(
+          0.0, tf.broadcast_dynamic_shape(tf.shape(x), tf.shape(y))
+      )
+
+    self._test_convert(broadcast_args, [x, y])
+
+  @chex.variants(with_jit=True, without_jit=False)
+  def test_broadcast_args_polymorphic(self):
+
+    @tf.function
+    def broadcast_args(x, y):
+      return tf.broadcast_to(
+          0.0, tf.broadcast_dynamic_shape(tf.shape(x), tf.shape(y))
+      )
+
+    x = np.zeros((1, 2), dtype=np.float32)
+    y = np.zeros((3, 1), dtype=np.float32)
+    x_spec, y_spec = export.symbolic_args_specs((x, y), ("(_, x)", "(y, _)"))
+    self._test_convert_polymorphic(
+        broadcast_args,
+        [
+            _PolymorphicInput(
+                tf_spec=tf.TensorSpec(shape=(1, None), dtype=tf.float32),
+                jax_spec=x_spec,
+                concrete_value=x,
+            ),
+            _PolymorphicInput(
+                tf_spec=tf.TensorSpec(shape=(None, 1), dtype=tf.float32),
+                jax_spec=y_spec,
+                concrete_value=y,
+            ),
+        ],
+    )
+
   @chex.variants(with_jit=True, without_jit=True)
   def test_broadcast_to(self):
     inputs, shape = np.array([1, 2, 3]), (3, 3)
```
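For readers unfamiliar with the machinery in these tests: `export.symbolic_args_specs` turns example arrays plus shape strings into `jax.ShapeDtypeStruct`s, where `_` keeps the concrete size from the example and a name introduces a symbolic dimension; the exported function can then be called with any matching concrete sizes. A standalone sketch with a toy function `f` of my own (not from the commit):

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax import export

def f(x, y):
  return x + y  # relies on broadcasting (1, n) + (m, 1) -> (m, n)

x = np.zeros((1, 2), dtype=np.float32)
y = np.zeros((3, 1), dtype=np.float32)

# "_" keeps the concrete size from the example array; "n" and "m" become
# symbolic dimensions.
x_spec, y_spec = export.symbolic_args_specs((x, y), ("(_, n)", "(m, _)"))
print(x_spec.shape, y_spec.shape)  # (1, n) (m, 1)

# Export once with symbolic shapes, then call with any matching sizes.
exported = export.export(jax.jit(f))(x_spec, y_spec)
out = exported.call(np.ones((1, 5), np.float32), np.ones((4, 1), np.float32))
print(out.shape)  # (4, 5)
```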
```diff
@@ -932,6 +1070,23 @@ def fill_static():
       return tf.zeros(fill(value))
     self._test_convert(fill_static, [])
 
+  @chex.variants(with_jit=True, without_jit=False)
+  def test_fill_polymorphic(self):
+    @tf.function
+    def fill(x):
+      return tf.zeros(shape=tf.shape(x), dtype=tf.float32)
+
+    x = np.zeros((2, 3), dtype=np.float32)
+    x_spec = export.symbolic_args_specs(x, "(a, b)")
+    self._test_convert_polymorphic(
+        fill,
+        _PolymorphicInput(
+            tf_spec=tf.TensorSpec(shape=(None, None), dtype=tf.float32),
+            jax_spec=x_spec,
+            concrete_value=x,
+        ),
+    )
+
   @chex.variants(with_jit=True, without_jit=True)
   @parameterized.named_parameters(
       chex.params_product(
```
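Putting the `Fill` fix together end to end, here is a sketch based on the test above (it assumes the same `tf2jax.convert` calling convention the test uses, where the converted function takes params first and returns outputs plus new params):

```python
import numpy as np
import tensorflow as tf
import tf2jax

@tf.function
def fill(x):
  return tf.zeros(shape=tf.shape(x), dtype=tf.float32)

# Convert with an unknown-shaped spec; the resulting JAX function works for
# any concrete input shape because Fill now routes shape tensors containing
# symbolic dimensions through jnp.
jax_func, params = tf2jax.convert(
    fill, tf.TensorSpec(shape=(None, None), dtype=tf.float32)
)
out, _ = jax_func(params, np.zeros((4, 7), np.float32))
print(out.shape)  # (4, 7)
```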
