Skip to content

Commit 1485d94

Browse files
rryanTF2JAXDev
authored andcommitted
Support tf2jax.convert in TensorFlow tracing contexts.
Adds a new configuration option `skip_variables_evaluation_inside_tf_tracing` which users can set to True to enable `tf2jax.convert` to work in TensorFlow 2 tracing contexts. Fore example, when exporting a JAX model as a TF saved model, where that JAX model uses `tf2jax.convert` internally. PiperOrigin-RevId: 825829137
1 parent 552931e commit 1485d94

File tree

3 files changed

+124
-11
lines changed

3 files changed

+124
-11
lines changed

tf2jax/_src/config.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,11 @@
3939
# converting `tensorflow_probability`'s `HiddenMarkovModel.posterior_mode`
4040
# to JAX. Here we allow to disable it.
4141
disable_assert_in_tensor_list_get_item=False,
42+
# If True, avoids evaluating variables in a TF tracing context. Set to False
43+
# when using tf2jax.convert from within a tf.function in TF2 mode (e.g. when
44+
# exporting a JAX function using tf2jax.convert internally as a TensorFlow
45+
# saved model).
46+
skip_variables_evaluation_inside_tf_tracing=False,
4247
)
4348

4449

tf2jax/_src/tf2jax.py

Lines changed: 24 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -537,7 +537,7 @@ def maybe_tensor_to_spec(v):
537537
exp_args, exp_kwargs = structured_inputs
538538
if exp_args:
539539
raise ValueError("If function_spec is None then only keyword arguments "
540-
f"are expectd, found args={exp_args} in structure.")
540+
f"are expected, found args={exp_args} in structure.")
541541
parameters = tuple([
542542
# TODO(b/266552275) Remove temporary fix for TF-Hub.
543543
inspect.Parameter(
@@ -999,7 +999,7 @@ def _convert(
999999
constants: A mapping from tensor names to constant values. The keys are a
10001000
subset of captured_input_names.
10011001
library: A mapping from function names to Callable. This is non-empty on
1002-
recurisve calls if the FunctionDefLibrary in the GraphDef is non-empty.
1002+
recursive calls if the FunctionDefLibrary in the GraphDef is non-empty.
10031003
10041004
Returns:
10051005
A tuple: the first is a Jax functions that takes a flat parameter dict and a
@@ -1068,7 +1068,7 @@ def _convert(
10681068
raise ValueError(err_message)
10691069

10701070
# Extract variables.
1071-
if tf.executing_eagerly():
1071+
if tf.executing_eagerly() or tf.inside_function():
10721072
# Uniqueify variables with identical names.
10731073
variables_tf = {}
10741074
var_name_by_ref = {}
@@ -1082,11 +1082,23 @@ def _convert(
10821082
variables_tf[var_name] = v
10831083
var_name_by_ref[v.ref()] = var_name
10841084

1085-
variables = {
1086-
k: Variable(v.numpy(), v.trainable, v.name)
1087-
for k, v in variables_tf.items()
1088-
}
1089-
else:
1085+
if tf.inside_function():
1086+
# We cannot evaluate variables inside a tf.function.
1087+
variables = {}
1088+
if (
1089+
not config.get_config("skip_variables_evaluation_inside_tf_tracing")
1090+
and variable_map
1091+
):
1092+
raise ValueError(
1093+
"Unable to to evaluate variables inside a TF tracing context, "
1094+
"and `skip_variables_evaluation_inside_tf_tracing` is False."
1095+
)
1096+
else:
1097+
variables = {
1098+
k: Variable(v.numpy(), v.trainable, v.name)
1099+
for k, v in variables_tf.items()
1100+
}
1101+
else: # We are in TF1 mode
10901102
variables_tf = {_parse_input(v.name): v for _, v in variable_map.items()}
10911103
var_name_by_ref = {
10921104
v.ref(): _parse_input(v.name) for v in variable_map.values()
@@ -1103,7 +1115,10 @@ def _convert(
11031115

11041116
assert len(variable_map) == len(variables_tf)
11051117
assert len(variable_map) == len(var_name_by_ref)
1106-
assert len(variable_map) == len(variables)
1118+
if not tf.inside_function() or not config.get_config(
1119+
"skip_variables_evaluation_inside_tf_tracing"
1120+
):
1121+
assert len(variable_map) == len(variables)
11071122

11081123
var_by_node = {k: var_name_by_ref[v.ref()] for k, v in variable_map.items()}
11091124
node_by_var = {v: k for k, v in var_by_node.items()}

tf2jax/_src/tf2jax_test.py

Lines changed: 95 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,13 @@
1717
import inspect
1818

1919
from absl.testing import parameterized
20-
2120
import chex
21+
import flax
2222
from flax import linen as nn
2323
import jax
24+
from jax.experimental import jax2tf
2425
import jax.numpy as jnp
2526
import numpy as np
26-
2727
import tensorflow as tf
2828
from tf2jax._src import config
2929
from tf2jax._src import tf2jax
@@ -63,6 +63,34 @@ def __call__(self, x):
6363
return x
6464

6565

66+
class FlaxTestDense(nn.Module):
67+
"""A Flax module that wraps a TestDense saved model."""
68+
69+
saved_model_path: str
70+
input_dim: int
71+
72+
def setup(self):
73+
self.m = tf.saved_model.load(self.saved_model_path)
74+
if self.is_initializing():
75+
_, params = tf2jax.convert(
76+
tf.function(self.m.__call__),
77+
tf.TensorSpec([None, self.input_dim], tf.float32),
78+
)
79+
params = flax.traverse_util.unflatten_dict(params, sep="/")
80+
self.params = self.param("saved_model", lambda _: params)
81+
else:
82+
self.params = self.param("saved_model", lambda _: None)
83+
84+
def __call__(self, x: jax.Array) -> jax.Array:
85+
fn, _ = tf2jax.convert(
86+
tf.function(self.m.__call__),
87+
tf.TensorSpec(x.shape, tf.float32),
88+
)
89+
params = flax.traverse_util.flatten_dict(self.params, sep="/")
90+
y, _ = fn(params, x)
91+
return y
92+
93+
6694
class FeaturesTest(tf.test.TestCase, parameterized.TestCase):
6795

6896
def _setup_saved_model(self, *inputs):
@@ -494,6 +522,71 @@ def tf_func(x):
494522
):
495523
self.variant(jax_func)({"aaa": jax_params["aaa"]}, np_inputs)
496524

525+
def test_export_saved_model_export_jax_module(self):
526+
input_dim = 5
527+
l = TestDense(input_dim=input_dim, output_size=5)
528+
l.__call__ = tf.function(
529+
l.__call__, input_signature=[tf.TensorSpec((None, input_dim))]
530+
)
531+
532+
x_tf = tf.ones((2, input_dim))
533+
y_tf = l(x_tf)
534+
kernel_tf, bias_tf = l.variables
535+
tf_export_dir = self.create_tempdir()
536+
jax_export_dir = self.create_tempdir()
537+
tf.saved_model.save(l, tf_export_dir)
538+
539+
module = FlaxTestDense(tf_export_dir, input_dim=input_dim)
540+
x_jax = jnp.ones((2, input_dim))
541+
y_jax, jax_params = module.init_with_output(jax.random.key(0), x_jax)
542+
543+
kernel_jax, bias_jax = jax.tree.leaves(jax_params)
544+
545+
np.testing.assert_array_equal(kernel_tf.numpy(), kernel_jax)
546+
np.testing.assert_array_equal(bias_tf.numpy(), bias_jax)
547+
np.testing.assert_allclose(y_tf.numpy(), y_jax, atol=1e-6, rtol=1e-6)
548+
549+
def model_fn(params, inputs): # The JAX model function to export.
550+
fm = FlaxTestDense(tf_export_dir, input_dim=input_dim)
551+
return fm.apply(params, inputs)
552+
553+
class ExportWrapper(tf.Module):
554+
555+
def __init__(self, params):
556+
super().__init__()
557+
self._tf_fn = jax2tf.convert(model_fn)
558+
self._params = params
559+
560+
@tf.function(
561+
input_signature=[
562+
tf.TensorSpec(shape=[1, input_dim], dtype=tf.float32)
563+
]
564+
)
565+
def __call__(self, x):
566+
return self._tf_fn(self._params, x)
567+
568+
with config.override_config(
569+
"skip_variables_evaluation_inside_tf_tracing", False
570+
), self.assertRaises(ValueError):
571+
tf.saved_model.save(
572+
ExportWrapper(jax_params),
573+
jax_export_dir,
574+
options=tf.saved_model.SaveOptions(
575+
experimental_custom_gradients=False
576+
),
577+
)
578+
579+
with config.override_config(
580+
"skip_variables_evaluation_inside_tf_tracing", True
581+
):
582+
tf.saved_model.save(
583+
ExportWrapper(jax_params),
584+
jax_export_dir,
585+
options=tf.saved_model.SaveOptions(
586+
experimental_custom_gradients=False
587+
),
588+
)
589+
497590

498591
if __name__ == "__main__":
499592
tf.test.main()

0 commit comments

Comments
 (0)