```diff
@@ -26,6 +26,7 @@
 from etils import epath
 import fiddle as fdl
 from flax import struct as flax_struct
+from flax.linen.fp8_ops import fm32
 import jax
 from jax import numpy as jnp
 from jax.experimental import pjit
@@ -35,7 +36,7 @@
 from paxml import sgf
 from paxml import tasks_lib
 from paxml import train_states
-from paxml.contrib.gpu.scripts_gpu.te_helper import TransformerEngineHelper, DEFAULT_INIT_MUTABLE_LIST
+from paxml.contrib.gpu.scripts_gpu.te_helper import DEFAULT_INIT_MUTABLE_LIST
 from praxis import asserts
 from praxis import base_hyperparams
 from praxis import base_input
```
```diff
@@ -804,6 +805,21 @@ def _default_apply_fn(
   )


+def _maybe_to_fm32_vars(mdl_vars, var_weight_hparams):
+  asserts.assert_same_structure(mdl_vars, var_weight_hparams)
+
+  def _maybe_fm32_var_fn(var, var_param):
+    if base_layer.var_overwrite_with_gradient(var_param):
+      return jax.lax.convert_element_type(var, fm32)
+    else:
+      return var
+
+  is_leaf = lambda x: not isinstance(x, (tuple, dict, list))
+  return jax.tree_util.tree_map(
+      _maybe_fm32_var_fn, mdl_vars, var_weight_hparams, is_leaf=is_leaf
+  )
+
+
 class LossFnProtocol(Protocol):

   def __call__(
```
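The new `_maybe_to_fm32_vars` helper walks the model variables and their weight hparams in lockstep and casts only the variables flagged as overwrite-with-gradient (in practice, fp8 scale and amax-history bookkeeping) to flax's `fm32` extended dtype; everything else passes through unchanged. Below is a minimal, self-contained sketch of the same pattern. `ToyWeightHParams` and `maybe_convert_vars` are illustrative stand-ins rather than paxml API, and `jnp.bfloat16` stands in for `fm32` so the snippet runs with plain JAX:

```python
import dataclasses

import jax
from jax import numpy as jnp


@dataclasses.dataclass(frozen=True)
class ToyWeightHParams:
  """Stand-in for the per-variable hparams carried alongside each value."""
  overwrite_with_gradient: bool = False


def maybe_convert_vars(mdl_vars, var_weight_hparams, dtype=jnp.bfloat16):
  """Casts only the leaves whose paired hparams carry the flag."""

  def _convert(var, hparam):
    return var.astype(dtype) if hparam.overwrite_with_gradient else var

  # Treat any non-container value as a leaf so that values and hparams
  # pair up one-to-one during the joint tree_map.
  is_leaf = lambda x: not isinstance(x, (tuple, dict, list))
  return jax.tree_util.tree_map(
      _convert, mdl_vars, var_weight_hparams, is_leaf=is_leaf
  )


mdl_vars = {'w': jnp.ones((2, 2)), 'fp8_amax': jnp.zeros(())}
hparams = {
    'w': ToyWeightHParams(),
    'fp8_amax': ToyWeightHParams(overwrite_with_gradient=True),
}
converted = maybe_convert_vars(mdl_vars, hparams)
assert converted['w'].dtype == jnp.float32
assert converted['fp8_amax'].dtype == jnp.bfloat16
```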
```diff
@@ -834,6 +850,8 @@ def _loss_fn(
   else:
     assert NotImplementedError(f'fprop_dtype {fprop_dtype} not supported.')

+  mdl_vars = _maybe_to_fm32_vars(mdl_vars, var_weight_hparams)
+
   with base_layer.JaxContext.new_context(hparams=context_p):
     k1, k2, k3 = jax.random.split(prng_key, 3)
     (metrics, per_example_output), updated_vars = apply_fn(
```
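In `_loss_fn`, the cast happens after the optional bfloat16 conversion and just before the model is applied, so the gradients flowing back to these variables carry the `fm32` dtype as well. The apparent point of the special dtype is the combining rule: when gradients are accumulated across microbatches, ordinary parameters sum, while fp8 bookkeeping state such as amax is only meaningful under a running maximum, which is the special accumulation treatment `flax.linen.fp8_ops` provides for `fm32` tensors. A plain-`jnp` illustration of the two rules (the values and names here are made up):

```python
from jax import numpy as jnp

# Two microbatches each produce a weight gradient and an amax observation.
grad_w = [jnp.float32(0.1), jnp.float32(0.2)]
amax = [jnp.float32(3.0), jnp.float32(5.0)]

# Ordinary parameter gradients accumulate by summation.
acc_grad_w = grad_w[0] + grad_w[1]        # 0.3

# Summing amax observations (3.0 + 5.0 = 8.0) is not a valid amax; the
# meaningful combination for overwrite-with-gradient state is the maximum.
acc_amax = jnp.maximum(amax[0], amax[1])  # 5.0
```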