Skip to content

Commit d45294a

Browse files
hejiang0116Orbax Authors
authored andcommitted
Allow per-function Jax2ObmOptions and add native serialization disabled checks.
PiperOrigin-RevId: 875863458
1 parent 6dd0254 commit d45294a

File tree

5 files changed

+197
-39
lines changed

5 files changed

+197
-39
lines changed

export/orbax/export/jax_module.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,11 @@ def __init__(
7575
allow_multi_axis_sharding_consolidation: Optional[bool] = None,
7676
export_version: constants.ExportModelType = constants.ExportModelType.TF_SAVEDMODEL,
7777
jax2obm_kwargs: Optional[Mapping[str, Any]] = None,
78-
jax2obm_options: obm_configs.Jax2ObmOptions | None = None,
78+
jax2obm_options: (
79+
obm_configs.Jax2ObmOptions
80+
| Mapping[str, obm_configs.Jax2ObmOptions]
81+
| None
82+
) = None,
7983
):
8084
"""JaxModule constructor.
8185
@@ -137,7 +141,11 @@ def __init__(
137141
to the Orbax Model export. Accepted arguments are
138142
'native_serialization_platforms' which must be a tuple of
139143
OrbaxNativeSerializationType.
140-
jax2obm_options: Options for jax2obm conversion.
144+
jax2obm_options: Options for jax2obm conversion. If `apply_fn` is a
145+
mapping, this can also be a mapping from method keys to
146+
`Jax2ObmOptions`.Currently, when it is a mapping, most options must be
147+
shared across different apply functions, except for `enable_auto_layout`
148+
and `native_serialization_disabled_checks`.
141149
142150
Raises:
143151
ValueError: If `jax2obm_kwargs` and `jax2obm_options` are both provided,

export/orbax/export/modules/obm_module.py

Lines changed: 131 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from collections.abc import Callable, Mapping, Sequence
1818
import copy
1919
import logging
20-
from typing import Any, Optional, Tuple, Union
20+
from typing import Any, Optional, Union
2121
import warnings
2222

2323
import jax
@@ -32,6 +32,49 @@
3232
ApplyFn = orbax_export_typing.ApplyFn
3333

3434

35+
def _get_shared_value(
36+
m: Mapping[str, obm_configs.Jax2ObmOptions] | obm_configs.Jax2ObmOptions,
37+
keys: Sequence[str] | None,
38+
field_name: str,
39+
) -> Any:
40+
"""Returns attribute `field_name` from `m` or checks if it is same for all values in mapping `m`.
41+
42+
If `m` is a Jax2ObmOptions object, returns `getattr(m, field_name)`. If `m`
43+
is a mapping, it checks if all keys in `keys` are present in `m`, and if
44+
`field_name` has the same value across all values in `m`.
45+
46+
Args:
47+
m: A Jax2ObmOptions object or a mapping where values are Jax2ObmOptions
48+
objects.
49+
keys: A sequence of keys that must be present in `m` if `m` is a mapping.
50+
field_name: The attribute name to get or check for shared value.
51+
52+
Returns:
53+
The attribute value if `m` is a Jax2ObmOptions object, or the shared
54+
attribute value if `m` is a mapping and all values for `field_name` are
55+
the same.
56+
57+
Raises:
58+
ValueError: If `m` is empty, or if `m` is a mapping and any key in `keys`
59+
is not found in `m`, or if values for `field_name` are not all same.
60+
AttributeError: If `field_name` is not an attribute of `Jax2ObmOptions`.
61+
"""
62+
if not m:
63+
raise ValueError('Input mapping is empty.')
64+
if isinstance(m, obm_configs.Jax2ObmOptions):
65+
return getattr(m, field_name)
66+
if keys:
67+
for key in keys:
68+
if key not in m:
69+
raise ValueError(f'Key {key} is not found in mapping {m}.')
70+
value = getattr(next(iter(m.values())), field_name)
71+
if not all(getattr(v, field_name) == value for v in m.values()):
72+
raise ValueError(
73+
f'Not all values for `{field_name}` in the mapping are the same.'
74+
)
75+
return value
76+
77+
3578
class ObmModule(orbax_module_base.OrbaxModuleBase):
3679
"""A data module for encapsulating the data for a Jax model to be serialized through the Orbax Model export flow."""
3780

@@ -50,7 +93,11 @@ def __init__(
5093
input_polymorphic_shape_symbol_values: Union[
5194
Mapping[str, PyTree], Mapping[str, Mapping[str, PyTree]], None
5295
] = None,
53-
jax2obm_options: obm_configs.Jax2ObmOptions | None = None,
96+
jax2obm_options: (
97+
obm_configs.Jax2ObmOptions
98+
| Mapping[str, obm_configs.Jax2ObmOptions]
99+
| None
100+
) = None,
54101
jax2obm_kwargs: Union[Mapping[str, Any], None] = None,
55102
):
56103
"""Data container for Orbax Model export.
@@ -70,12 +117,25 @@ def __init__(
70117
same keys (e.g. { 'serving_default': { 'b': (1, 2), 'l': (128, 512)}).
71118
When this argument is set, the polymorphic shape will be concretized to
72119
a set of all possible concretized input shape combinations.
73-
jax2obm_options: Options for jax2obm conversion.
120+
jax2obm_options: Options for jax2obm conversion. If `apply_fn` is a
121+
mapping, this can also be a mapping from method keys to
122+
`Jax2ObmOptions`. When it is a mapping, all options must be shared
123+
across different apply functions, except for `enable_auto_layout` and
124+
`native_serialization_disabled_checks`.
74125
jax2obm_kwargs: DEPRECATED. Use `jax2obm_options` instead. A dictionary of
75126
kwargs to pass to the jax2obm conversion library. Accepted arguments to
76127
jax2obm_kwargs are 'native_serialization_platforms', 'weights_name',
77128
'checkpoint_path' and 'polymorphic_constraints'.
78129
"""
130+
if (
131+
input_polymorphic_shape is None
132+
and input_polymorphic_shape_symbol_values is not None
133+
):
134+
raise ValueError(
135+
'`input_polymorphic_shape` is required when'
136+
' `input_polymorphic_shape_symbol_values` is provided.'
137+
)
138+
79139
if jax2obm_kwargs:
80140
if jax2obm_options is not None:
81141
raise ValueError(
@@ -91,18 +151,22 @@ def __init__(
91151
jax2obm_options = obm_configs.Jax2ObmOptions()
92152
self._jax2obm_options = jax2obm_options
93153

94-
if (
95-
input_polymorphic_shape is None
96-
and input_polymorphic_shape_symbol_values is not None
97-
):
98-
raise ValueError(
99-
'`input_polymorphic_shape` is required when'
100-
' `input_polymorphic_shape_symbol_values` is provided.'
101-
)
102-
103-
enable_bf16_optimization = self._jax2obm_options.enable_bf16_optimization
154+
if isinstance(self._jax2obm_options, Mapping):
155+
if not isinstance(apply_fn, Mapping):
156+
raise ValueError(
157+
'If `jax2obm_options` is a mapping, `apply_fn` must be a mapping.'
158+
)
159+
self._apply_fn_keys = list(apply_fn.keys())
160+
else:
161+
self._apply_fn_keys = None
162+
# self._process_jax2obm_options(self._apply_fn_keys)
104163

105-
if enable_bf16_optimization:
164+
self._enable_bf16_optimization = _get_shared_value(
165+
self._jax2obm_options,
166+
self._apply_fn_keys,
167+
constants.ENABLE_BF16_OPTIMIZATION,
168+
)
169+
if self._enable_bf16_optimization:
106170
mapped_apply_fn = utils.to_bfloat16(apply_fn)
107171
self._params_args_spec = utils.to_bfloat16(params)
108172
else:
@@ -117,19 +181,37 @@ def __init__(
117181
input_polymorphic_shape,
118182
input_polymorphic_shape_symbol_values,
119183
)
120-
121-
self._jax_mesh = self._jax2obm_options.jax_mesh
122-
184+
self._jax_mesh = _get_shared_value(
185+
self._jax2obm_options,
186+
self._apply_fn_keys,
187+
constants.JAX_MESH,
188+
)
123189
self.polymorphic_constraints = self._maybe_set_polymorphic_constraints()
124190
self._native_serialization_platforms = utils.get_lowering_platforms(
125-
self._jax2obm_options.native_serialization_platforms
191+
_get_shared_value(
192+
self._jax2obm_options,
193+
self._apply_fn_keys,
194+
constants.NATIVE_SERIALIZATION_PLATFORMS,
195+
)
196+
)
197+
xla_flags_per_platform = _get_shared_value(
198+
self._jax2obm_options,
199+
self._apply_fn_keys,
200+
constants.XLA_FLAGS_PER_PLATFORM,
201+
)
202+
persist_xla_flags = _get_shared_value(
203+
self._jax2obm_options,
204+
self._apply_fn_keys,
205+
constants.PERSIST_XLA_FLAGS,
126206
)
127207

128208
self._checkpoint_path: str = None
129209
# Set the Orbax checkpoint path if provided in the jax2obm_kwargs.
130210
self._maybe_set_orbax_checkpoint_path()
131-
self._load_all_checkpoint_weights = (
132-
self._jax2obm_options.load_all_checkpoint_weights
211+
self._load_all_checkpoint_weights = _get_shared_value(
212+
self._jax2obm_options,
213+
self._apply_fn_keys,
214+
constants.LOAD_ALL_CHECKPOINT_WEIGHTS,
133215
)
134216

135217
def _jax2obm_kwargs_to_options(
@@ -171,7 +253,7 @@ def _normalize_apply_fn_map(
171253
input_polymorphic_shape_symbol_values: Union[
172254
PyTree, Mapping[str, PyTree], None
173255
],
174-
) -> Tuple[
256+
) -> tuple[
175257
Mapping[
176258
str, orbax_export_typing.ApplyFn | orbax_export_typing.ApplyFnInfo
177259
],
@@ -255,17 +337,25 @@ def _normalize_apply_fn_map(
255337
)
256338

257339
def _maybe_set_orbax_checkpoint_path(self):
258-
if self._jax2obm_options.checkpoint_path is None:
340+
if (
341+
_get_shared_value(
342+
self._jax2obm_options,
343+
self._apply_fn_keys,
344+
constants.CHECKPOINT_PATH,
345+
)
346+
is None
347+
):
259348
self._weights_name = None
260349
return
261350

262351
# TODO: b/374195447 - Add a version check for the Orbax checkpointer.
263-
self._checkpoint_path = self._jax2obm_options.checkpoint_path
264-
self._weights_name = (
265-
self._jax2obm_options.weights_name
266-
if self._jax2obm_options.weights_name is not None
267-
else constants.DEFAULT_WEIGHTS_NAME
352+
self._checkpoint_path = _get_shared_value(
353+
self._jax2obm_options, self._apply_fn_keys, constants.CHECKPOINT_PATH
268354
)
355+
weights_name = _get_shared_value(
356+
self._jax2obm_options, self._apply_fn_keys, constants.WEIGHTS_NAME
357+
)
358+
self._weights_name = weights_name or constants.DEFAULT_WEIGHTS_NAME
269359

270360
def _maybe_set_polymorphic_constraints(self) -> Mapping[str, Sequence[str]]:
271361
"""Sets the polymorphic constraints for the model.
@@ -279,10 +369,13 @@ def _maybe_set_polymorphic_constraints(self) -> Mapping[str, Sequence[str]]:
279369
size of the apply_fn_map or if a key in apply_fn_map is not found in
280370
polymorphic_constraints.
281371
"""
282-
if isinstance(self._jax2obm_options.polymorphic_constraints, Mapping):
283-
polymorphic_constraints_mapping = (
284-
self._jax2obm_options.polymorphic_constraints
285-
)
372+
polymorphic_constraints = _get_shared_value(
373+
self._jax2obm_options,
374+
self._apply_fn_keys,
375+
constants.POLYMORPHIC_CONSTRAINTS,
376+
)
377+
if isinstance(polymorphic_constraints, Mapping):
378+
polymorphic_constraints_mapping = polymorphic_constraints
286379
if len(polymorphic_constraints_mapping) != len(self._apply_fn_map):
287380
raise ValueError(
288381
'The size of'
@@ -291,17 +384,17 @@ def _maybe_set_polymorphic_constraints(self) -> Mapping[str, Sequence[str]]:
291384
f' apply_fn_map:{len(self._apply_fn_map)}.'
292385
)
293386
for key in self._apply_fn_map:
294-
if key not in self._jax2obm_options.polymorphic_constraints:
387+
if key not in polymorphic_constraints:
295388
raise ValueError(
296389
f'The key {key} is not found in polymorphic_constraints:'
297-
f' {self._jax2obm_options.polymorphic_constraints}'
390+
f' {polymorphic_constraints}'
298391
)
299392
else:
300393
polymorphic_constraints_mapping = {}
301-
if self._jax2obm_options.polymorphic_constraints is None:
394+
if polymorphic_constraints is None:
302395
polymorphic_constraints = ()
303396
else:
304-
polymorphic_constraints = self._jax2obm_options.polymorphic_constraints
397+
polymorphic_constraints = polymorphic_constraints
305398
if not isinstance(polymorphic_constraints, Sequence):
306399
raise TypeError(
307400
'If not a Mapping, polymorphic_constraints needs to be a'
@@ -366,6 +459,8 @@ def jax_methods(self) -> Mapping[str, Callable[..., Any]]:
366459
raise NotImplementedError('apply_fn_map is not implemented for ObmModule.')
367460

368461
@property
369-
def jax2obm_options(self) -> obm_configs.Jax2ObmOptions:
462+
def jax2obm_options(
463+
self,
464+
) -> obm_configs.Jax2ObmOptions | Mapping[str, obm_configs.Jax2ObmOptions]:
370465
"""Returns the jax2obm_options."""
371466
return self._jax2obm_options

export/orbax/export/modules/obm_module_test.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import jax.numpy as jnp
1919
import numpy as np
2020
from orbax.export import constants
21+
from orbax.export import obm_configs
2122
from orbax.export.modules import obm_module
2223

2324
@jax.jit
@@ -384,5 +385,54 @@ def test_obm_module_bfloat16_conversion(self, enable_bf16_optimization):
384385
self.assertEqual(module.model_params['b'].dtype, expected_dtype)
385386

386387

388+
class GetSharedValueTest(parameterized.TestCase):
389+
390+
def test_get_shared_value_single_object(self):
391+
options = obm_configs.Jax2ObmOptions(enable_bf16_optimization=True)
392+
self.assertTrue(
393+
obm_module._get_shared_value(options, None, 'enable_bf16_optimization')
394+
)
395+
396+
def test_get_shared_value_mapping_same_values(self):
397+
options_map = {
398+
'a': obm_configs.Jax2ObmOptions(enable_bf16_optimization=True),
399+
'b': obm_configs.Jax2ObmOptions(enable_bf16_optimization=True),
400+
}
401+
self.assertTrue(
402+
obm_module._get_shared_value(
403+
options_map, ['a', 'b'], 'enable_bf16_optimization'
404+
)
405+
)
406+
407+
def test_get_shared_value_mapping_different_values(self):
408+
options_map = {
409+
'a': obm_configs.Jax2ObmOptions(enable_bf16_optimization=True),
410+
'b': obm_configs.Jax2ObmOptions(enable_bf16_optimization=False),
411+
}
412+
with self.assertRaisesRegex(
413+
ValueError,
414+
r'Not all values for `enable_bf16_optimization` in the mapping are the'
415+
r' same.',
416+
):
417+
obm_module._get_shared_value(
418+
options_map, ['a', 'b'], 'enable_bf16_optimization'
419+
)
420+
421+
def test_get_shared_value_mapping_missing_key(self):
422+
options_map = {
423+
'a': obm_configs.Jax2ObmOptions(enable_bf16_optimization=True),
424+
}
425+
with self.assertRaisesRegex(
426+
ValueError, r'Key b is not found in mapping {\'a\':*'
427+
):
428+
obm_module._get_shared_value(
429+
options_map, ['a', 'b'], 'enable_bf16_optimization'
430+
)
431+
432+
def test_get_shared_value_empty_mapping(self):
433+
with self.assertRaisesRegex(ValueError, r'Input mapping is empty.'):
434+
obm_module._get_shared_value({}, ['a'], 'enable_bf16_optimization')
435+
436+
387437
if __name__ == '__main__':
388438
absltest.main()

export/orbax/export/obm_configs.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import itertools
2121
import logging
2222
import jax
23+
from jax import export as jax_export
2324

2425

2526
# LINT.IfChange
@@ -309,6 +310,8 @@ class Jax2ObmOptions:
309310
bfloat16 to save memory.
310311
enable_auto_layout: If set to True, automatically generate the optimal
311312
layout for model parameters to improve serving performance.
313+
native_serialization_disabled_checks: A sequence of
314+
`jax_export.DisabledSafetyCheck` to disable when exporting.
312315
"""
313316

314317
# TODO: b/448900820 - Consider constraint the type to the proto enums.
@@ -326,6 +329,9 @@ class Jax2ObmOptions:
326329
persist_xla_flags: bool = True
327330
enable_bf16_optimization: bool = False
328331
enable_auto_layout: bool = False
332+
native_serialization_disabled_checks: Sequence[
333+
jax_export.DisabledSafetyCheck
334+
] = ()
329335

330336

331337
@dataclasses.dataclass(kw_only=True)

export/orbax/export/obm_export_test.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@
2525
import jax
2626
import jax.numpy as jnp
2727
from jaxtyping import PyTree
28-
from orbax.export import config
2928
from orbax.export import constants
3029
from orbax.export import export_testing_utils
3130
from orbax.export import jax_module

0 commit comments

Comments
 (0)