
Commit 2404c85

Update model_lib.py
Update optimizers.
1 parent f8f2aec commit 2404c85

1 file changed (+29, -21 lines)

hero/model_lib.py (+29, -21)
@@ -13,6 +13,7 @@
 # limitations under the License.
 """All modeling components including architecture, training and inference."""
 
+import abc
 import collections
 import copy
 import dataclasses
@@ -1120,17 +1121,27 @@ def predict_probs(
 ## Optimizers.
 
 
+
 def get_init_steps():
   return jax.lax.with_sharding_constraint(
       jnp.array(0, dtype=jnp.int32), mesh_sharding(None))
 
 
-@OptimizerRegistry.register
-class SGD:
-  """Stochastic gradient descent optimizer."""
+class Optimizer(abc.ABC):
+  """An ultra-simplified version of `flax.nn.Module`."""
 
-  def __init__(self):
-    pass
+  def init(self, params):
+    """Initializes the state associated with the optimizer."""
+
+  @abc.abstractmethod
+  def apply(self, state, grad):
+    """Applies the update rule to the optimizer state and the gradient."""
+
+
+@OptimizerRegistry.register
+@dataclasses.dataclass(frozen=True)
+class SGD(Optimizer):
+  """Stochastic Gradient Descent Optimizer."""
 
   def init(self, params):
     state = {}
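The new `Optimizer` base class fixes a two-method surface (`init` builds the optimizer state, `apply` consumes the state plus a gradient), and `@dataclasses.dataclass(frozen=True)` turns each registered optimizer into an immutable bag of hyperparameters, so every mutable quantity lives in the `state` dict. The diff cuts off SGD's method bodies; below is a minimal sketch of the pattern, assuming the `Optimizer` class above, with the `learning_rate` field and the update rule invented for illustration rather than taken from this commit:

# A hedged sketch of the new pattern, not the repository's implementation:
# the learning_rate field and both method bodies are assumed.
import dataclasses
import jax


@dataclasses.dataclass(frozen=True)
class SGDSketch(Optimizer):
  """Stochastic Gradient Descent, illustrating the dataclass pattern."""

  learning_rate: float = 0.01  # Hypothetical field; not shown in the diff.

  def init(self, params):
    # SGD keeps no running statistics, so the state is just the parameters.
    return {'params': params}

  def apply(self, state, grad):
    # Pure function: returns a new state rather than mutating the old one.
    new_params = jax.tree_util.tree_map(
        lambda p, g: p - self.learning_rate * g, state['params'], grad)
    return {'params': new_params}

Freezing the dataclass also keeps optimizer instances hashable, which helps if an optimizer is ever passed as a static argument under `jax.jit`.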
@@ -1143,17 +1154,15 @@ def apply(self, state, grad):
 
 
 @OptimizerRegistry.register
-class Adam:
-  """The Adam optimizer."""
+@dataclasses.dataclass(frozen=True)
+class Adam(Optimizer):
+  """Adam Optimizer."""
 
-  def __init__(self, beta1: float = 0.9, beta2: float = 0.999,
-               epsilon: float = 1e-6):
-    self.beta1 = beta1
-    self.beta2 = beta2
-    self.epsilon = epsilon
+  beta1: float = 0.9
+  beta2: float = 0.999
+  epsilon: float = 1e-6
 
   def init(self, params):
-    """Initializes the optimizer state."""
     state = {}
     state['params'] = params
     state['m'] = jax.tree_util.tree_map(
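The visible lines of `Adam.init` store the parameters and a first-moment tree `m`; the standard algorithm also needs a second moment and a step count, which presumably follow below the fold. A hedged sketch of the usual Adam update consistent with this state-dict layout, where the keys `v` and `t` and the `lr` argument are assumptions, not code from this file:

# Sketch of the standard Adam rule (Kingma & Ba, 2015); the state keys 'v'
# and 't' and the learning-rate handling are assumed, since the diff
# truncates the method bodies.
import jax
import jax.numpy as jnp


def adam_apply(state, grad, beta1=0.9, beta2=0.999, epsilon=1e-6, lr=1e-3):
  t = state['t'] + 1
  # Exponential moving averages of the gradient and its elementwise square.
  m = jax.tree_util.tree_map(
      lambda m, g: beta1 * m + (1 - beta1) * g, state['m'], grad)
  v = jax.tree_util.tree_map(
      lambda v, g: beta2 * v + (1 - beta2) * g * g, state['v'], grad)

  def update(p, m_i, v_i):
    m_hat = m_i / (1 - beta1**t)  # Bias correction, first moment.
    v_hat = v_i / (1 - beta2**t)  # Bias correction, second moment.
    return p - lr * m_hat / (jnp.sqrt(v_hat) + epsilon)

  params = jax.tree_util.tree_map(update, state['params'], m, v)
  return {'params': params, 'm': m, 'v': v, 't': t}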
@@ -1181,14 +1190,13 @@ def apply(self, state, grad):
 
 
 @OptimizerRegistry.register
-class Lion:
-  """The Lion optimizer."""
-
-  def __init__(self, beta1: float = 0.95, beta2: float = 0.98,
-               momentum_use_bf16=True):
-    self.momentum_dtype = jnp.bfloat16 if momentum_use_bf16 else jnp.float32
-    self.beta1 = jnp.array(beta1, dtype=self.momentum_dtype)
-    self.beta2 = jnp.array(beta2, dtype=self.momentum_dtype)
+@dataclasses.dataclass(frozen=True)
+class Lion(Optimizer):
+  """Lion Optimizer."""
+
+  beta1: float = 0.95
+  beta2: float = 0.98
+  momentum_dtype: jax.typing.DTypeLike = 'bfloat16'
 
   def init(self, params):
     state = {}
0 commit comments