1717from collections .abc import Callable , Mapping , Sequence
1818import copy
1919import logging
20- from typing import Any , Optional , Tuple , Union
20+ from typing import Any , Optional , Union
2121import warnings
2222
2323import jax
3232ApplyFn = orbax_export_typing .ApplyFn
3333
3434
35+ def _get_shared_value (
36+ m : Mapping [str , obm_configs .Jax2ObmOptions ] | obm_configs .Jax2ObmOptions ,
37+ keys : Sequence [str ] | None ,
38+ field_name : str ,
39+ ) -> Any :
40+ """Returns attribute `field_name` from `m` or checks if it is same for all values in mapping `m`.
41+
42+ If `m` is a Jax2ObmOptions object, returns `getattr(m, field_name)`. If `m`
43+ is a mapping, it checks if all keys in `keys` are present in `m`, and if
44+ `field_name` has the same value across all values in `m`.
45+
46+ Args:
47+ m: A Jax2ObmOptions object or a mapping where values are Jax2ObmOptions
48+ objects.
49+ keys: A sequence of keys that must be present in `m` if `m` is a mapping.
50+ field_name: The attribute name to get or check for shared value.
51+
52+ Returns:
53+ The attribute value if `m` is a Jax2ObmOptions object, or the shared
54+ attribute value if `m` is a mapping and all values for `field_name` are
55+ the same.
56+
57+ Raises:
58+ ValueError: If `m` is empty, or if `m` is a mapping and any key in `keys`
59+ is not found in `m`, or if values for `field_name` are not all same.
60+ AttributeError: If `field_name` is not an attribute of `Jax2ObmOptions`.
61+ """
62+ if not m :
63+ raise ValueError ('Input mapping is empty.' )
64+ if isinstance (m , obm_configs .Jax2ObmOptions ):
65+ return getattr (m , field_name )
66+ if keys :
67+ for key in keys :
68+ if key not in m :
69+ raise ValueError (f'Key { key } is not found in mapping { m } .' )
70+ value = getattr (next (iter (m .values ())), field_name )
71+ if not all (getattr (v , field_name ) == value for v in m .values ()):
72+ raise ValueError (
73+ f'Not all values for `{ field_name } ` in the mapping are the same.'
74+ )
75+ return value
76+
77+
3578class ObmModule (orbax_module_base .OrbaxModuleBase ):
3679 """A data module for encapsulating the data for a Jax model to be serialized through the Orbax Model export flow."""
3780
@@ -50,7 +93,11 @@ def __init__(
5093 input_polymorphic_shape_symbol_values : Union [
5194 Mapping [str , PyTree ], Mapping [str , Mapping [str , PyTree ]], None
5295 ] = None ,
53- jax2obm_options : obm_configs .Jax2ObmOptions | None = None ,
96+ jax2obm_options : (
97+ obm_configs .Jax2ObmOptions
98+ | Mapping [str , obm_configs .Jax2ObmOptions ]
99+ | None
100+ ) = None ,
54101 jax2obm_kwargs : Union [Mapping [str , Any ], None ] = None ,
55102 ):
56103 """Data container for Orbax Model export.
@@ -70,12 +117,25 @@ def __init__(
70117 same keys (e.g. { 'serving_default': { 'b': (1, 2), 'l': (128, 512)}).
71118 When this argument is set, the polymorphic shape will be concretized to
72119 a set of all possible concretized input shape combinations.
73- jax2obm_options: Options for jax2obm conversion.
120+ jax2obm_options: Options for jax2obm conversion. If `apply_fn` is a
121+ mapping, this can also be a mapping from method keys to
122+ `Jax2ObmOptions`. When it is a mapping, all options must be shared
123+ across different apply functions, except for `enable_auto_layout` and
124+ `native_serialization_disabled_checks`.
74125 jax2obm_kwargs: DEPRECATED. Use `jax2obm_options` instead. A dictionary of
75126 kwargs to pass to the jax2obm conversion library. Accepted arguments to
76127 jax2obm_kwargs are 'native_serialization_platforms', 'weights_name',
77128 'checkpoint_path' and 'polymorphic_constraints'.
78129 """
130+ if (
131+ input_polymorphic_shape is None
132+ and input_polymorphic_shape_symbol_values is not None
133+ ):
134+ raise ValueError (
135+ '`input_polymorphic_shape` is required when'
136+ ' `input_polymorphic_shape_symbol_values` is provided.'
137+ )
138+
79139 if jax2obm_kwargs :
80140 if jax2obm_options is not None :
81141 raise ValueError (
@@ -91,18 +151,22 @@ def __init__(
91151 jax2obm_options = obm_configs .Jax2ObmOptions ()
92152 self ._jax2obm_options = jax2obm_options
93153
94- if (
95- input_polymorphic_shape is None
96- and input_polymorphic_shape_symbol_values is not None
97- ):
98- raise ValueError (
99- '`input_polymorphic_shape` is required when'
100- ' `input_polymorphic_shape_symbol_values` is provided.'
101- )
102-
103- enable_bf16_optimization = self ._jax2obm_options .enable_bf16_optimization
154+ if isinstance (self ._jax2obm_options , Mapping ):
155+ if not isinstance (apply_fn , Mapping ):
156+ raise ValueError (
157+ 'If `jax2obm_options` is a mapping, `apply_fn` must be a mapping.'
158+ )
159+ self ._apply_fn_keys = list (apply_fn .keys ())
160+ else :
161+ self ._apply_fn_keys = None
162+ # self._process_jax2obm_options(self._apply_fn_keys)
104163
105- if enable_bf16_optimization :
164+ self ._enable_bf16_optimization = _get_shared_value (
165+ self ._jax2obm_options ,
166+ self ._apply_fn_keys ,
167+ constants .ENABLE_BF16_OPTIMIZATION ,
168+ )
169+ if self ._enable_bf16_optimization :
106170 mapped_apply_fn = utils .to_bfloat16 (apply_fn )
107171 self ._params_args_spec = utils .to_bfloat16 (params )
108172 else :
@@ -117,19 +181,37 @@ def __init__(
117181 input_polymorphic_shape ,
118182 input_polymorphic_shape_symbol_values ,
119183 )
120-
121- self ._jax_mesh = self ._jax2obm_options .jax_mesh
122-
184+ self ._jax_mesh = _get_shared_value (
185+ self ._jax2obm_options ,
186+ self ._apply_fn_keys ,
187+ constants .JAX_MESH ,
188+ )
123189 self .polymorphic_constraints = self ._maybe_set_polymorphic_constraints ()
124190 self ._native_serialization_platforms = utils .get_lowering_platforms (
125- self ._jax2obm_options .native_serialization_platforms
191+ _get_shared_value (
192+ self ._jax2obm_options ,
193+ self ._apply_fn_keys ,
194+ constants .NATIVE_SERIALIZATION_PLATFORMS ,
195+ )
196+ )
197+ xla_flags_per_platform = _get_shared_value (
198+ self ._jax2obm_options ,
199+ self ._apply_fn_keys ,
200+ constants .XLA_FLAGS_PER_PLATFORM ,
201+ )
202+ persist_xla_flags = _get_shared_value (
203+ self ._jax2obm_options ,
204+ self ._apply_fn_keys ,
205+ constants .PERSIST_XLA_FLAGS ,
126206 )
127207
128208 self ._checkpoint_path : str = None
129209 # Set the Orbax checkpoint path if provided in the jax2obm_kwargs.
130210 self ._maybe_set_orbax_checkpoint_path ()
131- self ._load_all_checkpoint_weights = (
132- self ._jax2obm_options .load_all_checkpoint_weights
211+ self ._load_all_checkpoint_weights = _get_shared_value (
212+ self ._jax2obm_options ,
213+ self ._apply_fn_keys ,
214+ constants .LOAD_ALL_CHECKPOINT_WEIGHTS ,
133215 )
134216
135217 def _jax2obm_kwargs_to_options (
@@ -171,7 +253,7 @@ def _normalize_apply_fn_map(
171253 input_polymorphic_shape_symbol_values : Union [
172254 PyTree , Mapping [str , PyTree ], None
173255 ],
174- ) -> Tuple [
256+ ) -> tuple [
175257 Mapping [
176258 str , orbax_export_typing .ApplyFn | orbax_export_typing .ApplyFnInfo
177259 ],
@@ -255,17 +337,25 @@ def _normalize_apply_fn_map(
255337 )
256338
257339 def _maybe_set_orbax_checkpoint_path (self ):
258- if self ._jax2obm_options .checkpoint_path is None :
340+ if (
341+ _get_shared_value (
342+ self ._jax2obm_options ,
343+ self ._apply_fn_keys ,
344+ constants .CHECKPOINT_PATH ,
345+ )
346+ is None
347+ ):
259348 self ._weights_name = None
260349 return
261350
262351 # TODO: b/374195447 - Add a version check for the Orbax checkpointer.
263- self ._checkpoint_path = self ._jax2obm_options .checkpoint_path
264- self ._weights_name = (
265- self ._jax2obm_options .weights_name
266- if self ._jax2obm_options .weights_name is not None
267- else constants .DEFAULT_WEIGHTS_NAME
352+ self ._checkpoint_path = _get_shared_value (
353+ self ._jax2obm_options , self ._apply_fn_keys , constants .CHECKPOINT_PATH
268354 )
355+ weights_name = _get_shared_value (
356+ self ._jax2obm_options , self ._apply_fn_keys , constants .WEIGHTS_NAME
357+ )
358+ self ._weights_name = weights_name or constants .DEFAULT_WEIGHTS_NAME
269359
270360 def _maybe_set_polymorphic_constraints (self ) -> Mapping [str , Sequence [str ]]:
271361 """Sets the polymorphic constraints for the model.
@@ -279,10 +369,13 @@ def _maybe_set_polymorphic_constraints(self) -> Mapping[str, Sequence[str]]:
279369 size of the apply_fn_map or if a key in apply_fn_map is not found in
280370 polymorphic_constraints.
281371 """
282- if isinstance (self ._jax2obm_options .polymorphic_constraints , Mapping ):
283- polymorphic_constraints_mapping = (
284- self ._jax2obm_options .polymorphic_constraints
285- )
372+ polymorphic_constraints = _get_shared_value (
373+ self ._jax2obm_options ,
374+ self ._apply_fn_keys ,
375+ constants .POLYMORPHIC_CONSTRAINTS ,
376+ )
377+ if isinstance (polymorphic_constraints , Mapping ):
378+ polymorphic_constraints_mapping = polymorphic_constraints
286379 if len (polymorphic_constraints_mapping ) != len (self ._apply_fn_map ):
287380 raise ValueError (
288381 'The size of'
@@ -291,17 +384,17 @@ def _maybe_set_polymorphic_constraints(self) -> Mapping[str, Sequence[str]]:
291384 f' apply_fn_map:{ len (self ._apply_fn_map )} .'
292385 )
293386 for key in self ._apply_fn_map :
294- if key not in self . _jax2obm_options . polymorphic_constraints :
387+ if key not in polymorphic_constraints :
295388 raise ValueError (
296389 f'The key { key } is not found in polymorphic_constraints:'
297- f' { self . _jax2obm_options . polymorphic_constraints } '
390+ f' { polymorphic_constraints } '
298391 )
299392 else :
300393 polymorphic_constraints_mapping = {}
301- if self . _jax2obm_options . polymorphic_constraints is None :
394+ if polymorphic_constraints is None :
302395 polymorphic_constraints = ()
303396 else :
304- polymorphic_constraints = self . _jax2obm_options . polymorphic_constraints
397+ polymorphic_constraints = polymorphic_constraints
305398 if not isinstance (polymorphic_constraints , Sequence ):
306399 raise TypeError (
307400 'If not a Mapping, polymorphic_constraints needs to be a'
@@ -366,6 +459,8 @@ def jax_methods(self) -> Mapping[str, Callable[..., Any]]:
366459 raise NotImplementedError ('apply_fn_map is not implemented for ObmModule.' )
367460
368461 @property
369- def jax2obm_options (self ) -> obm_configs .Jax2ObmOptions :
462+ def jax2obm_options (
463+ self ,
464+ ) -> obm_configs .Jax2ObmOptions | Mapping [str , obm_configs .Jax2ObmOptions ]:
370465 """Returns the jax2obm_options."""
371466 return self ._jax2obm_options
0 commit comments