Skip to content

Commit 8926361

Browse files
qdhackOrbax Authors
authored and
Orbax Authors
committed
Adds support for symbolic shapes (and their constraints) to OEX V1's OBM path
PiperOrigin-RevId: 738427406
1 parent e0b27c8 commit 8926361

File tree

5 files changed

+46
-12
lines changed

5 files changed

+46
-12
lines changed

export/orbax/export/constants.py

+3
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,9 @@ class OrbaxNativeSerializationType(enum.Enum):
8080
# attribute on the OrbaxModule.
8181
WEIGHTS_NAME = 'weights_name'
8282

83+
# Jax2obm_kwargs key for input polymorphic constraints.
84+
POLYMORPHIC_CONSTRAINTS = 'polymorphic_constraints'
85+
8386
# Default weights name to use if a checkpoint path is provided but a weights_
8487
# name kwarg was not provided in the jax2obm_kwargs.
8588
DEFAULT_WEIGHTS_NAME = 'weights'

export/orbax/export/jax_module.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -67,8 +67,7 @@ def __init__(
6767
input_polymorphic_shape: the polymorhpic shape for the inputs of
6868
``apply_fn``. If ``apply_fn`` is a mapping, ``input_polymorphic_shape``
6969
must be a mapping of method key to the input polymorphic shape for the
70-
method. Currently input_polymorphic_shape is only relevant for TF
71-
SavedModel export.
70+
method.
7271
jax2tf_kwargs: options passed to jax2tf. ``polymorphic_shape`` is inferred
7372
from ``input_polymorphic_shape`` and should not be set.
7473
``with_gradient``, if set, should be consistent with the ``trainable``
@@ -104,6 +103,7 @@ def __init__(
104103
self._export_module = obm_module.ObmModule(
105104
params=params,
106105
apply_fn=apply_fn,
106+
input_polymorphic_shape=input_polymorphic_shape,
107107
jax2obm_kwargs=jax2obm_kwargs,
108108
)
109109
else:

export/orbax/export/modules/obm_module.py

+40-9
Original file line numberDiff line numberDiff line change
@@ -17,16 +17,17 @@
1717
from collections.abc import Callable, Mapping, Sequence
1818
import copy
1919
import logging
20-
from typing import Any, Optional, Union
20+
from typing import Any, Optional, Tuple, Union
2121

2222
import jax
2323
from orbax.export import constants
2424
from orbax.export import typing as orbax_export_typing
25+
from orbax.export import utils
2526
from orbax.export.modules import orbax_module_base
2627
from orbax.export.typing import PyTree
27-
from orbax.export import utils
2828
import tensorflow as tf
2929

30+
3031
ApplyFn = orbax_export_typing.ApplyFn
3132

3233

@@ -37,13 +38,16 @@ def __init__(
3738
self,
3839
params: PyTree,
3940
apply_fn: Union[ApplyFn, Mapping[str, ApplyFn]],
41+
*,
42+
input_polymorphic_shape: Any = None,
4043
jax2obm_kwargs: Union[Mapping[str, Any], None] = None,
4144
):
4245
"""Data container for Orbax Model export.
4346
4447
Args:
4548
params: The model parameter specs (e.g. `jax.ShapeDtypeStruct`s).
4649
apply_fn: The apply_fn for the model.
50+
input_polymorphic_shape: polymorhpic shape for the inputs of `apply_fn`.
4751
jax2obm_kwargs: A dictionary of kwargs to pass to the jax2obm conversion
4852
library. Accepted arguments to jax2obm_kwargs are
4953
'native_serialization_platforms', 'flatten_signature', 'weights_name'and
@@ -54,16 +58,29 @@ def __init__(
5458
if not jax2obm_kwargs:
5559
jax2obm_kwargs = {}
5660

57-
self._apply_fn_map = self._normalize_apply_fn_map(
58-
self._normalize_apply_fn_map(apply_fn)
61+
self._apply_fn_map, self.input_polymorphic_shape_map = (
62+
self._normalize_apply_fn_map(apply_fn, input_polymorphic_shape)
5963
)
6064

61-
if len(self._apply_fn_map) != 1:
65+
if (
66+
len(self._apply_fn_map) != 1
67+
or len(self.input_polymorphic_shape_map) != 1
68+
):
6269
raise NotImplementedError(
6370
'ObmModule: Currently the ObmExport only supports a single method'
64-
f' for export. Received: {self._apply_fn_map}'
71+
f' for export. Received apply_fn_map: {self._apply_fn_map} and'
72+
f' input_polymorphic_shape_map: {self.input_polymorphic_shape_map}'
6573
)
6674

75+
# TODO(qidichen): Consider if `self.polymorphic_constraints` should support
76+
# a map as well like `input_polymorphic_shape`. For now we assume there is
77+
# only one entry in the `apply_fn` or `input_polymorphic_shape`.
78+
self.polymorphic_constraints = jax2obm_kwargs.get(
79+
constants.POLYMORPHIC_CONSTRAINTS, None
80+
)
81+
if self.polymorphic_constraints is None:
82+
self.polymorphic_constraints = ()
83+
6784
self._native_serialization_platforms = utils.get_lowering_platforms(
6885
jax2obm_kwargs
6986
)
@@ -86,18 +103,32 @@ def __init__(
86103
self._maybe_set_orbax_checkpoint_path(jax2obm_kwargs)
87104

88105
def _normalize_apply_fn_map(
89-
self, apply_fn: Union[ApplyFn, Mapping[str, ApplyFn]]
90-
) -> Mapping[str, ApplyFn]:
106+
self,
107+
apply_fn: Union[ApplyFn, Mapping[str, ApplyFn]],
108+
input_polymorphic_shape: Union[PyTree, Mapping[str, PyTree], None],
109+
) -> Tuple[Mapping[str, ApplyFn], Mapping[str, Union[PyTree, None]]]:
91110
if callable(apply_fn):
92111
apply_fn_map = {constants.DEFAULT_METHOD_KEY: apply_fn}
112+
input_polymorphic_shape_map = {
113+
constants.DEFAULT_METHOD_KEY: input_polymorphic_shape
114+
}
93115
elif len(apply_fn) > 1:
94116
raise NotImplementedError(
95117
'ObmModule: Currently the ObmExport only supports a single method'
96118
f' per module. Received: {apply_fn}'
97119
)
98120
else:
99121
apply_fn_map = apply_fn
100-
return apply_fn_map
122+
if input_polymorphic_shape is None:
123+
input_polymorphic_shape_map = {constants.DEFAULT_METHOD_KEY: None}
124+
elif not isinstance(input_polymorphic_shape, Mapping):
125+
raise TypeError(
126+
'When apply_fn is a mapping, input_polymorphic_shape must also be a'
127+
' mapping.'
128+
)
129+
else:
130+
input_polymorphic_shape_map = input_polymorphic_shape
131+
return apply_fn_map, input_polymorphic_shape_map
101132

102133
def _maybe_set_orbax_checkpoint_path(self, jax2obm_kwargs):
103134
if constants.CHECKPOINT_PATH not in jax2obm_kwargs:

export/orbax/export/obm_export.py

+1
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
from absl import logging
2121
import jax
22+
from jax import export as jax_export
2223
from orbax.export import constants
2324
from orbax.export import export_base
2425
from orbax.export import jax_module

export/orbax/export/obm_export_test.py

-1
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414

1515
import contextlib
1616
import os
17-
from typing import cast
1817

1918
from absl.testing import absltest
2019
from absl.testing import parameterized

0 commit comments

Comments
 (0)