 # limitations under the License.
 """All modeling components including architecture, training and inference."""
 
+import abc
 import collections
 import copy
 import dataclasses
@@ -1120,17 +1121,27 @@ def predict_probs(
 ## Optimizers.
 
 
+
 def get_init_steps():
   return jax.lax.with_sharding_constraint(
       jnp.array(0, dtype=jnp.int32), mesh_sharding(None))
 
 
-@OptimizerRegistry.register
-class SGD:
-  """Stochastic gradient descent optimizer."""
+class Optimizer(abc.ABC):
+  """An ultra-simplified version of `flax.nn.Module`."""
 
-  def __init__(self):
-    pass
+  def init(self, params):
+    """Initializes the state associated with the optimizer."""
+
+  @abc.abstractmethod
+  def apply(self, state, grad):
+    """Applies the update rule to the optimizer state and the gradient."""
+
+
+@OptimizerRegistry.register
+@dataclasses.dataclass(frozen=True)
+class SGD(Optimizer):
+  """Stochastic Gradient Descent Optimizer."""
 
   def init(self, params):
     state = {}
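A rough usage sketch of the refactored optimizer API follows. It is an illustration only: the concrete `params`/`grad` pytrees and the convention that `apply` returns the updated state are assumptions based on the method signatures visible in this diff, not code from this commit.

# Hypothetical usage of the new frozen-dataclass optimizers (illustration only).
import jax
import jax.numpy as jnp

params = {'w': jnp.ones((3,)), 'b': jnp.zeros((3,))}
grad = jax.tree_util.tree_map(lambda p: 0.1 * jnp.ones_like(p), params)

opt = Adam(beta1=0.9, beta2=0.999, epsilon=1e-6)  # hyperparameters are dataclass fields
state = opt.init(params)                          # builds the {'params', 'm', ...} state
state = opt.apply(state, grad)                    # one update step per the Adam rule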
@@ -1143,17 +1154,15 @@ def apply(self, state, grad):
 
 
 @OptimizerRegistry.register
-class Adam:
-  """The Adam optimizer."""
+@dataclasses.dataclass(frozen=True)
+class Adam(Optimizer):
+  """Adam Optimizer."""
 
-  def __init__(self, beta1: float = 0.9, beta2: float = 0.999,
-               epsilon: float = 1e-6):
-    self.beta1 = beta1
-    self.beta2 = beta2
-    self.epsilon = epsilon
+  beta1: float = 0.9
+  beta2: float = 0.999
+  epsilon: float = 1e-6
 
   def init(self, params):
-    """Initializes the optimizer state."""
     state = {}
     state['params'] = params
     state['m'] = jax.tree_util.tree_map(
@@ -1181,14 +1190,13 @@ def apply(self, state, grad):
 
 
 @OptimizerRegistry.register
-class Lion:
-  """The Lion optimizer."""
-
-  def __init__(self, beta1: float = 0.95, beta2: float = 0.98,
-               momentum_use_bf16=True):
-    self.momentum_dtype = jnp.bfloat16 if momentum_use_bf16 else jnp.float32
-    self.beta1 = jnp.array(beta1, dtype=self.momentum_dtype)
-    self.beta2 = jnp.array(beta2, dtype=self.momentum_dtype)
+@dataclasses.dataclass(frozen=True)
+class Lion(Optimizer):
+  """Lion Optimizer."""
+
+  beta1: float = 0.95
+  beta2: float = 0.98
+  momentum_dtype: jax.typing.DTypeLike = 'bfloat16'
 
   def init(self, params):
     state = {}
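For reference, a sketch of what Lion's `apply` could look like under this dataclass layout. It follows the published Lion rule (take the sign of an interpolated momentum, then refresh the momentum); it is not necessarily this commit's implementation, and learning-rate/step-size handling is omitted because it is not visible in the diff.

# Hypothetical Lion update sketch (assumed standard Lion rule, not this commit's code).
def apply(self, state, grad):
  m_dtype = self.momentum_dtype
  beta1 = jnp.asarray(self.beta1, dtype=m_dtype)
  beta2 = jnp.asarray(self.beta2, dtype=m_dtype)
  # Update direction: sign(beta1 * m + (1 - beta1) * g).
  update = jax.tree_util.tree_map(
      lambda m, g: jnp.sign(beta1 * m + (1 - beta1) * g.astype(m_dtype)),
      state['m'], grad)
  # Momentum refresh: m <- beta2 * m + (1 - beta2) * g, kept in the momentum dtype.
  new_m = jax.tree_util.tree_map(
      lambda m, g: beta2 * m + (1 - beta2) * g.astype(m_dtype),
      state['m'], grad)
  # Step-size scaling is assumed to happen elsewhere in the training loop.
  new_params = jax.tree_util.tree_map(
      lambda p, u: p - u.astype(p.dtype), state['params'], update)
  return {'params': new_params, 'm': new_m}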