17
17
from collections .abc import Callable , Mapping , Sequence
18
18
import copy
19
19
import logging
20
- from typing import Any , Optional , Union
20
+ from typing import Any , Optional , Tuple , Union
21
21
22
22
import jax
23
23
from orbax .export import constants
24
24
from orbax .export import typing as orbax_export_typing
25
+ from orbax .export import utils
25
26
from orbax .export .modules import orbax_module_base
26
27
from orbax .export .typing import PyTree
27
- from orbax .export import utils
28
28
import tensorflow as tf
29
29
30
+
30
31
ApplyFn = orbax_export_typing .ApplyFn
31
32
32
33
@@ -37,13 +38,16 @@ def __init__(
37
38
self ,
38
39
params : PyTree ,
39
40
apply_fn : Union [ApplyFn , Mapping [str , ApplyFn ]],
41
+ * ,
42
+ input_polymorphic_shape : Any = None ,
40
43
jax2obm_kwargs : Union [Mapping [str , Any ], None ] = None ,
41
44
):
42
45
"""Data container for Orbax Model export.
43
46
44
47
Args:
45
48
params: The model parameter specs (e.g. `jax.ShapeDtypeStruct`s).
46
49
apply_fn: The apply_fn for the model.
50
+ input_polymorphic_shape: polymorhpic shape for the inputs of `apply_fn`.
47
51
jax2obm_kwargs: A dictionary of kwargs to pass to the jax2obm conversion
48
52
library. Accepted arguments to jax2obm_kwargs are
49
53
'native_serialization_platforms', 'flatten_signature', 'weights_name'and
@@ -54,16 +58,29 @@ def __init__(
54
58
if not jax2obm_kwargs :
55
59
jax2obm_kwargs = {}
56
60
57
- self ._apply_fn_map = self ._normalize_apply_fn_map (
58
- self ._normalize_apply_fn_map (apply_fn )
61
+ self ._apply_fn_map , self .input_polymorphic_shape_map = (
62
+ self ._normalize_apply_fn_map (apply_fn , input_polymorphic_shape )
59
63
)
60
64
61
- if len (self ._apply_fn_map ) != 1 :
65
+ if (
66
+ len (self ._apply_fn_map ) != 1
67
+ or len (self .input_polymorphic_shape_map ) != 1
68
+ ):
62
69
raise NotImplementedError (
63
70
'ObmModule: Currently the ObmExport only supports a single method'
64
- f' for export. Received: { self ._apply_fn_map } '
71
+ f' for export. Received apply_fn_map: { self ._apply_fn_map } and'
72
+ f' input_polymorphic_shape_map: { self .input_polymorphic_shape_map } '
65
73
)
66
74
75
+ # TODO(qidichen): Consider if `self.polymorphic_constraints` should support
76
+ # a map as well like `input_polymorphic_shape`. For now we assume there is
77
+ # only one entry in the `apply_fn` or `input_polymorphic_shape`.
78
+ self .polymorphic_constraints = jax2obm_kwargs .get (
79
+ constants .POLYMORPHIC_CONSTRAINTS , None
80
+ )
81
+ if self .polymorphic_constraints is None :
82
+ self .polymorphic_constraints = ()
83
+
67
84
self ._native_serialization_platforms = utils .get_lowering_platforms (
68
85
jax2obm_kwargs
69
86
)
@@ -86,18 +103,32 @@ def __init__(
86
103
self ._maybe_set_orbax_checkpoint_path (jax2obm_kwargs )
87
104
88
105
def _normalize_apply_fn_map (
89
- self , apply_fn : Union [ApplyFn , Mapping [str , ApplyFn ]]
90
- ) -> Mapping [str , ApplyFn ]:
106
+ self ,
107
+ apply_fn : Union [ApplyFn , Mapping [str , ApplyFn ]],
108
+ input_polymorphic_shape : Union [PyTree , Mapping [str , PyTree ], None ],
109
+ ) -> Tuple [Mapping [str , ApplyFn ], Mapping [str , Union [PyTree , None ]]]:
91
110
if callable (apply_fn ):
92
111
apply_fn_map = {constants .DEFAULT_METHOD_KEY : apply_fn }
112
+ input_polymorphic_shape_map = {
113
+ constants .DEFAULT_METHOD_KEY : input_polymorphic_shape
114
+ }
93
115
elif len (apply_fn ) > 1 :
94
116
raise NotImplementedError (
95
117
'ObmModule: Currently the ObmExport only supports a single method'
96
118
f' per module. Received: { apply_fn } '
97
119
)
98
120
else :
99
121
apply_fn_map = apply_fn
100
- return apply_fn_map
122
+ if input_polymorphic_shape is None :
123
+ input_polymorphic_shape_map = {constants .DEFAULT_METHOD_KEY : None }
124
+ elif not isinstance (input_polymorphic_shape , Mapping ):
125
+ raise TypeError (
126
+ 'When apply_fn is a mapping, input_polymorphic_shape must also be a'
127
+ ' mapping.'
128
+ )
129
+ else :
130
+ input_polymorphic_shape_map = input_polymorphic_shape
131
+ return apply_fn_map , input_polymorphic_shape_map
101
132
102
133
def _maybe_set_orbax_checkpoint_path (self , jax2obm_kwargs ):
103
134
if constants .CHECKPOINT_PATH not in jax2obm_kwargs :
0 commit comments