1515"""Tests for tf2jax."""
1616
1717import contextlib
18+ import dataclasses
19+ from typing import Any
1820
1921from absl .testing import parameterized
2022
2123import chex
2224import jax
25+ from jax import export
2326from jax .experimental import checkify
2427import numpy as np
2528
@@ -36,6 +39,15 @@ def _reorder(vals, inds):
3639 return [vals [idx ] for idx in inds ]
3740
3841
42+ @dataclasses .dataclass
43+ class _PolymorphicInput :
44+ """Wrapper class containing information for polymorphic inputs."""
45+
46+ tf_spec : tf .TensorSpec
47+ jax_spec : jax .ShapeDtypeStruct
48+ concrete_value : Any
49+
50+
3951class OpsTest (test_util .TestCase ):
4052
4153 def test_get_unsupported (self ):
@@ -84,6 +96,84 @@ def _test_convert(
8496
8597 return jax_results , new_jax_params
8698
99+ def _test_convert_polymorphic (
100+ self ,
101+ tf_func ,
102+ inputs ,
103+ * ,
104+ check_shape_only = False ,
105+ functional = True ,
106+ jit_compile = True ,
107+ atol = 1e-5 ,
108+ ):
109+ if not isinstance (inputs , (list , tuple )):
110+ inputs = (inputs ,)
111+
112+ # Call self._test_convert if there is no _PolymorphicInput.
113+ self .assertTrue (any (isinstance (x , _PolymorphicInput ) for x in inputs ))
114+
115+ if not hasattr (tf_func , "get_concrete_function" ):
116+ tf_func = tf .function (tf_func , jit_compile = jit_compile )
117+
118+ def get_poly_attr_or_else (attr , else_fn = None ):
119+ """Returns the attr of a _PolymorphicInput otherwise apply `else_fn`."""
120+
121+ def mapper (x ):
122+ if isinstance (x , _PolymorphicInput ):
123+ return getattr (x , attr )
124+ if else_fn is not None :
125+ return else_fn (x )
126+ return x
127+
128+ return mapper
129+
130+ jax_func , jax_params = tf2jax .convert (
131+ tf_func ,
132+ * tree .map_structure (
133+ get_poly_attr_or_else ("tf_spec" , np .zeros_like ), inputs
134+ ),
135+ )
136+ if functional :
137+ self .assertEmpty (jax_params , "Expected no parameters for pure Ops." )
138+
139+ jax_func = self .variant (jax_func )
140+
141+ concrete_inputs = tree .map_structure (
142+ get_poly_attr_or_else ("concrete_value" ), inputs
143+ )
144+ tf_results = tf_func (* concrete_inputs )
145+
146+ def assert_same (tf_results , jax_results ):
147+ """Compares the results of the TF and JAX functions."""
148+ for tf_res , jax_res in utils .safe_zip (
149+ tree .flatten (tf_results ), tree .flatten (jax_results )
150+ ):
151+ self .assertEqual (tf_res .shape , jax_res .shape )
152+ if not check_shape_only :
153+ self .assertAllClose (
154+ np .asarray (tf_res ), np .asarray (jax_res ), atol = atol
155+ )
156+
157+ # Check the converted JAX function.
158+ rng = jax .random .PRNGKey (42 )
159+ jax_results , new_jax_params = jax_func (
160+ jax_params , * concrete_inputs , rng = rng
161+ )
162+ assert_same (tf_results , jax_results )
163+
164+ # Check exported JAX function.
165+ exp_func = export .export (jax_func )(
166+ jax_params ,
167+ * tree .map_structure (
168+ get_poly_attr_or_else ("jax_spec" , np .zeros_like ), inputs
169+ ),
170+ )
171+ exp_results , new_exp_params = exp_func .call (jax_params , * concrete_inputs )
172+ assert_same (tf_results , exp_results )
173+ assert_same (new_jax_params , new_exp_params )
174+
175+ return jax_results , new_jax_params
176+
87177 @chex .variants (with_jit = True , without_jit = True )
88178 @parameterized .parameters ("log_softmax" , "sigmoid" , "softmax" , "softplus" ,
89179 "tanh" , "relu" , "relu6" , "elu" , "leaky_relu" )
@@ -527,6 +617,54 @@ def raw_func(x):
527617 return tf .raw_ops .Bitcast (input = x , type = tf .float32 )
528618 self ._test_convert (raw_func , inputs )
529619
620+ @chex .variants (with_jit = True , without_jit = True )
621+ @parameterized .parameters (
622+ ([1 , 2 ], [3 , 1 ]),
623+ ([2 , 3 , 1 ], [1 , 5 ]),
624+ ([], [1 ]),
625+ ([1 ], []),
626+ ([], []),
627+ ([3 , 1 , 2 ], [1 , 5 , 1 ]),
628+ )
629+ def test_broadcast_args (self , s0 , s1 ):
630+ x = np .zeros (s0 , dtype = np .float32 )
631+ y = np .zeros (s1 , dtype = np .float32 )
632+
633+ def broadcast_args (x , y ):
634+ return tf .broadcast_to (
635+ 0.0 , tf .broadcast_dynamic_shape (tf .shape (x ), tf .shape (y ))
636+ )
637+
638+ self ._test_convert (broadcast_args , [x , y ])
639+
640+ @chex .variants (with_jit = True , without_jit = False )
641+ def test_broadcast_args_polymorphic (self ):
642+
643+ @tf .function
644+ def broadcast_args (x , y ):
645+ return tf .broadcast_to (
646+ 0.0 , tf .broadcast_dynamic_shape (tf .shape (x ), tf .shape (y ))
647+ )
648+
649+ x = np .zeros ((1 , 2 ), dtype = np .float32 )
650+ y = np .zeros ((3 , 1 ), dtype = np .float32 )
651+ x_spec , y_spec = export .symbolic_args_specs ((x , y ), ("(_, x)" , "(y, _)" ))
652+ self ._test_convert_polymorphic (
653+ broadcast_args ,
654+ [
655+ _PolymorphicInput (
656+ tf_spec = tf .TensorSpec (shape = (1 , None ), dtype = tf .float32 ),
657+ jax_spec = x_spec ,
658+ concrete_value = x ,
659+ ),
660+ _PolymorphicInput (
661+ tf_spec = tf .TensorSpec (shape = (None , 1 ), dtype = tf .float32 ),
662+ jax_spec = y_spec ,
663+ concrete_value = y ,
664+ ),
665+ ],
666+ )
667+
530668 @chex .variants (with_jit = True , without_jit = True )
531669 def test_broadcast_to (self ):
532670 inputs , shape = np .array ([1 , 2 , 3 ]), (3 , 3 )
@@ -932,6 +1070,23 @@ def fill_static():
9321070 return tf .zeros (fill (value ))
9331071 self ._test_convert (fill_static , [])
9341072
1073+ @chex .variants (with_jit = True , without_jit = False )
1074+ def test_fill_polymorphic (self ):
1075+ @tf .function
1076+ def fill (x ):
1077+ return tf .zeros (shape = tf .shape (x ), dtype = tf .float32 )
1078+
1079+ x = np .zeros ((2 , 3 ), dtype = np .float32 )
1080+ x_spec = export .symbolic_args_specs (x , "(a, b)" )
1081+ self ._test_convert_polymorphic (
1082+ fill ,
1083+ _PolymorphicInput (
1084+ tf_spec = tf .TensorSpec (shape = (None , None ), dtype = tf .float32 ),
1085+ jax_spec = x_spec ,
1086+ concrete_value = x ,
1087+ ),
1088+ )
1089+
9351090 @chex .variants (with_jit = True , without_jit = True )
9361091 @parameterized .named_parameters (
9371092 chex .params_product (
0 commit comments