Skip to content

Commit 4516dbb

Browse files
mingxu1067terrykong
authored andcommitted
Support FM32 to OWG parameters.
Signed-off-by: Ming Huang <[email protected]>
1 parent 65b0055 commit 4516dbb

File tree

1 file changed

+19
-1
lines changed

1 file changed

+19
-1
lines changed

paxml/trainer_lib.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from etils import epath
2727
import fiddle as fdl
2828
from flax import struct as flax_struct
29+
from flax.linen.fp8_ops import fm32
2930
import jax
3031
from jax import numpy as jnp
3132
from jax.experimental import pjit
@@ -35,7 +36,7 @@
3536
from paxml import sgf
3637
from paxml import tasks_lib
3738
from paxml import train_states
38-
from paxml.contrib.gpu.scripts_gpu.te_helper import TransformerEngineHelper, DEFAULT_INIT_MUTABLE_LIST
39+
from paxml.contrib.gpu.scripts_gpu.te_helper import DEFAULT_INIT_MUTABLE_LIST
3940
from praxis import asserts
4041
from praxis import base_hyperparams
4142
from praxis import base_input
@@ -804,6 +805,21 @@ def _default_apply_fn(
804805
)
805806

806807

808+
def _maybe_to_fm32_vars(mdl_vars, var_weight_hparams):
809+
asserts.assert_same_structure(mdl_vars, var_weight_hparams)
810+
811+
def _maybe_fm32_var_fn(var, var_param):
812+
if base_layer.var_overwrite_with_gradient(var_param):
813+
return jax.lax.convert_element_type(var, fm32)
814+
else:
815+
return var
816+
817+
is_leaf = lambda x: not isinstance(x, (tuple, dict, list))
818+
return jax.tree_util.tree_map(
819+
_maybe_fm32_var_fn, mdl_vars, var_weight_hparams, is_leaf=is_leaf
820+
)
821+
822+
807823
class LossFnProtocol(Protocol):
808824

809825
def __call__(
@@ -834,6 +850,8 @@ def _loss_fn(
834850
else:
835851
assert NotImplementedError(f'fprop_dtype {fprop_dtype} not supported.')
836852

853+
mdl_vars = _maybe_to_fm32_vars(mdl_vars, var_weight_hparams)
854+
837855
with base_layer.JaxContext.new_context(hparams=context_p):
838856
k1, k2, k3 = jax.random.split(prng_key, 3)
839857
(metrics, per_example_output), updated_vars = apply_fn(

0 commit comments

Comments
 (0)