Add Transformer Engine support to Paxml #46

Open

wants to merge 12 commits into base: main
13 changes: 8 additions & 5 deletions paxml/contrib/gpu/scripts_gpu/configs.py
@@ -28,6 +28,7 @@
from paxml.contrib.gpu.scripts_gpu.tasks import LambadaDataset
from paxml.contrib.gpu.scripts_gpu.tasks import PileUnsupervisedDataset
from paxml.tasks.lm.model_params import maybe_setup_moe_params
from paxml.contrib.gpu.scripts_gpu.te_helper import TransformerEngineHelper
from paxml.tasks.lm.params.c4 import TransformerLmSpmdAdam
from paxml.tasks.lm.params.lm_cloud import SyntheticDataset
from praxis import base_layer
@@ -116,7 +117,7 @@ class GPT126MBase(TransformerLmSpmdAdam):

MAX_SEQ_LEN = 2048
VOCAB_SIZE = 50304
PACKED_INPUT = True
PACKED_INPUT = False
PERCORE_BATCH_SIZE = 4

NUM_LAYERS = 12
@@ -171,10 +172,12 @@ def task(self) -> pax_fiddle.Config[tasks_lib.SingleTask]:
fdl.get_callable(stacked_p), transformers.StackedTransformerRepeated
):
stacked_p = stacked_p.block
transformer_layer_p = stacked_p.transformer_layer_params_tpl
transformer_layer_p.ln_tpl.reductions_in_fp32 = True
transformer_layer_p.tr_fflayer_tpl.ln_tpl.reductions_in_fp32 = True

task_p.model.lm_tpl.final_ln_tpl.reductions_in_fp32 = True
if not TransformerEngineHelper.is_enabled_te():
transformer_layer_p = stacked_p.transformer_layer_params_tpl
transformer_layer_p.ln_tpl.reductions_in_fp32 = True
transformer_layer_p.tr_fflayer_tpl.ln_tpl.reductions_in_fp32 = True

model_p.params_init = WeightInit.Gaussian(self.INIT_STD)
softmax_init = WeightInit.Gaussian(self.SOFTMAX_INIT_STD)
@@ -239,7 +242,7 @@ class GPT175BBase(GPT126MBase):
# Known as MLP_DIM in t5x
HIDDEN_DIMS = MODEL_DIMS * 4
# Defaults to MODEL_DIMS // NUM_HEADS.
DIMS_PER_HEAD = None
DIMS_PER_HEAD = 128
# Known as NUM_EMBEDDINGS in t5x
VOCAB_SIZE = 50257
USE_REPEATED_LAYER = True
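A quick sanity check on the DIMS_PER_HEAD change above, assuming GPT175BBase keeps the usual GPT-3 175B geometry (MODEL_DIMS and NUM_HEADS are defined outside this hunk, so the exact values below are an assumption, not taken from the diff):

# Hypothetical values for illustration only.
MODEL_DIMS, NUM_HEADS = 12288, 96
assert MODEL_DIMS // NUM_HEADS == 128  # matches the now-explicit DIMS_PER_HEAD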
76 changes: 76 additions & 0 deletions paxml/contrib/gpu/scripts_gpu/te_helper.py
@@ -0,0 +1,76 @@
import os
from contextlib import contextmanager

from praxis import base_layer

try:
import transformer_engine.jax as te
from transformer_engine.common import recipe
_IS_TRANSFORMER_ENGINE_INSTALLED = True
DEFAULT_INIT_MUTABLE_LIST = base_layer.DEFAULT_INIT_MUTABLE_LIST + [te.fp8.FP8Helper.FP8_COLLECTION_NAME]

except ModuleNotFoundError as e:
_IS_TRANSFORMER_ENGINE_INSTALLED = False
DEFAULT_INIT_MUTABLE_LIST = base_layer.DEFAULT_INIT_MUTABLE_LIST


class TransformerEngineHelperBase:

@staticmethod
@contextmanager
def fp8_autocast(dp_mesh_axis="replica", tp_mesh_axis="mdl", fsdp_mesh_axis="data"):
raise NotImplementedError


class TENotInstalledHelper(TransformerEngineHelperBase):

@staticmethod
@contextmanager
def fp8_autocast(dp_mesh_axis="replica", tp_mesh_axis="mdl", fsdp_mesh_axis="data"):
try:
yield
finally:
pass


class TEInstalledHelper(TransformerEngineHelperBase):

@staticmethod
@contextmanager
def fp8_autocast(dp_mesh_axis="replica", tp_mesh_axis="mdl", fsdp_mesh_axis="data"):
fp8_recipe = recipe.DelayedScaling(margin=0, interval=1, fp8_format=recipe.Format.HYBRID,
amax_history_len=1024, amax_compute_algo='max')

enable_fp8 = bool(int(os.environ.get("ENABLE_FP8", "0")))
try:
with te.fp8_autocast(enabled=enable_fp8,
fp8_recipe=fp8_recipe,
mesh_resource=te.MeshResource(dp_resource=dp_mesh_axis,
tp_resource=tp_mesh_axis,
fsdp_resource=fsdp_mesh_axis)):
yield
finally:
pass


class TransformerEngineHelper(TransformerEngineHelperBase):

@staticmethod
def is_enabled_te():
enable_te = bool(int(os.environ.get("ENABLE_TE", "0")))
return (_IS_TRANSFORMER_ENGINE_INSTALLED and enable_te)

@staticmethod
def get_helper():
if TransformerEngineHelper.is_enabled_te():
return TEInstalledHelper
return TENotInstalledHelper

@staticmethod
@contextmanager
def fp8_autocast(dp_mesh_axis="replica", tp_mesh_axis="mdl", fsdp_mesh_axis="data"):
try:
with TransformerEngineHelper.get_helper().fp8_autocast(dp_mesh_axis, tp_mesh_axis, fsdp_mesh_axis):
yield
finally:
pass
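A minimal usage sketch of the helper above (illustrative only, not part of the diff); the two environment variables are the ones read by te_helper.py, everything else is a placeholder:

import os

# Must be set before the helpers are queried; both parse "0"/"1" strings.
os.environ['ENABLE_TE'] = '1'   # route through TEInstalledHelper when TE is importable
os.environ['ENABLE_FP8'] = '1'  # only consulted inside TEInstalledHelper.fp8_autocast

from paxml.contrib.gpu.scripts_gpu.te_helper import TransformerEngineHelper

# True only when transformer_engine imports successfully AND ENABLE_TE is non-zero;
# otherwise every call falls back to the no-op TENotInstalledHelper.
print(TransformerEngineHelper.is_enabled_te())

with TransformerEngineHelper.fp8_autocast(
    dp_mesh_axis='replica', tp_mesh_axis='mdl', fsdp_mesh_axis='data'):
  pass  # build the experiment config and run training here, as main.py does below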
73 changes: 38 additions & 35 deletions paxml/main.py
@@ -52,6 +52,7 @@
from paxml import trainer_lib
from paxml import tuning_lib
from paxml import ml_monitoring
from paxml.contrib.gpu.scripts_gpu.te_helper import TransformerEngineHelper
from praxis import pax_fiddle
from praxis import py_utils

@@ -519,37 +520,38 @@ def create_experiment_config():
),
)

if FLAGS.exp is not None:
experiment_config = get_experiment(FLAGS.exp)()
elif absl_flags.fdl_flags_supplied():
# Use the legacy Fiddle flags API to parse command line Fiddle flags.
cfg = absl_flags.create_buildable_from_flags(
module=None, allow_imports=True)
experiment_config = pax_fiddle.build(cfg)
logging.warning(
'Legacy Fiddle flags API usage detected. Please use the new Fiddle'
' command line flag `fdl` with various commands to specify the'
' config and any overrides. Please see'
' `fiddle/docs/flags_code_lab.md` for more'
' documentation on Fiddle flags usage.'
)
elif _FIDDLE_CONFIG.value is not None:
# This uses the new Fiddle flags API `DEFINE_fiddle_config()` to parse
# command line Fiddle flags. See
# `fiddle/docs/flags_code_lab.md` for details on the new
# Fiddle flags API.
logging.info(
'Using pax_fiddle_config from the command line: %s',
_FIDDLE_CONFIG.value,
)
experiment_config = pax_fiddle.build(_FIDDLE_CONFIG.value)
else:
raise app.UsageError(
'No experiment provided. At least one of --exp, --fdl,'
' --fdl_config, or --fdl_config_file is required.'
)
with TransformerEngineHelper.fp8_autocast('replica', 'mdl', 'data'):
if FLAGS.exp is not None:
experiment_config = get_experiment(FLAGS.exp)()
elif absl_flags.fdl_flags_supplied():
# Use the legacy Fiddle flags API to parse command line Fiddle flags.
cfg = absl_flags.create_buildable_from_flags(
module=None, allow_imports=True)
experiment_config = pax_fiddle.build(cfg)
logging.warning(
'Legacy Fiddle flags API usage detected. Please use the new Fiddle'
' command line flag `fdl` with various commands to specify the'
' config and any overrides. Please see'
' `fiddle/docs/flags_code_lab.md` for more'
' documentation on Fiddle flags usage.'
)
elif _FIDDLE_CONFIG.value is not None:
# This uses the new Fiddle flags API `DEFINE_fiddle_config()` to parse
# command line Fiddle flags. See
# `fiddle/docs/flags_code_lab.md` for details on the new
# Fiddle flags API.
logging.info(
'Using pax_fiddle_config from the command line: %s',
_FIDDLE_CONFIG.value,
)
experiment_config = pax_fiddle.build(_FIDDLE_CONFIG.value)
else:
raise app.UsageError(
'No experiment provided. At least one of --exp, --fdl,'
' --fdl_config, or --fdl_config_file is required.'
)

experiment_config.validate()
experiment_config.validate()
return experiment_config


@@ -565,11 +567,12 @@ def _main(argv: Sequence[str]) -> None:
with ml_monitoring.ml_event_logger(ml_monitoring.MlEvent.INITIALIZE_SETUP):
experiment_config = create_experiment_config()

run(
experiment_config=experiment_config,
enable_checkpoint_saving=FLAGS.enable_checkpoint_saving,
startup_random_jitter_max_secs=FLAGS.startup_random_jitter_max_secs,
)
with TransformerEngineHelper.fp8_autocast('replica', 'mdl', 'data'):
run(
experiment_config=experiment_config,
enable_checkpoint_saving=FLAGS.enable_checkpoint_saving,
startup_random_jitter_max_secs=FLAGS.startup_random_jitter_max_secs,
)


_TASK_HANDLE_RE = re.compile(r'(?:logs\.)?(\d+)\.(.*)\.([^.]+)\.\d+')
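For context on the literal axis names passed to fp8_autocast above: 'replica', 'mdl', and 'data' are expected to match the axis names of the device mesh the Pax experiment builds, and te_helper.py maps them onto TE's MeshResource. A short, hypothetical sketch of that correspondence (device count and mesh shape are made up):

import numpy as np
import jax
from jax.sharding import Mesh

# Hypothetical 8-device host arranged as (replica=1, data=2, mdl=4).
devices = np.asarray(jax.devices()).reshape(1, 2, 4)
mesh = Mesh(devices, axis_names=('replica', 'data', 'mdl'))
# TEInstalledHelper.fp8_autocast then builds
# te.MeshResource(dp_resource='replica', tp_resource='mdl', fsdp_resource='data'),
# i.e. data parallelism over 'replica', tensor parallelism over 'mdl', FSDP over 'data'.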
3 changes: 2 additions & 1 deletion paxml/tasks_lib.py
@@ -43,6 +43,7 @@
from paxml import io_utils
from paxml import learners as learners_lib
from paxml import train_states
from paxml.contrib.gpu.scripts_gpu.te_helper import DEFAULT_INIT_MUTABLE_LIST
from praxis import asserts
from praxis import base_hyperparams
from praxis import base_input
@@ -1786,7 +1787,7 @@ def _apply_init_checkpoint_rule(
)
# Initialize with a dummy seed
var_weight_hparams = ckpt_task.model.abstract_init_with_metadata(
inputs_shape_dtype)
inputs_shape_dtype, extra_mutable_list=DEFAULT_INIT_MUTABLE_LIST)
ckpt_train_state = ckpt_task.create_train_state_padded_shapes(
var_weight_hparams)
train_state_pspecs = ckpt_task.create_train_state_partition_specs(
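For context on the new extra_mutable_list argument (used here and throughout trainer_lib.py below): Flax only creates variables for collections that are mutable during init, and TE keeps its FP8 scaling state in a collection of its own, so that collection has to be included when abstract-initializing the model. A toy sketch of the mechanism; ToyFp8Dense and the collection name 'fp8_metas' are stand-ins, not TE's real classes or its actual FP8_COLLECTION_NAME:

import jax
import jax.numpy as jnp
import flax.linen as nn


class ToyFp8Dense(nn.Module):
  features: int = 8

  @nn.compact
  def __call__(self, x):
    w = self.param('w', nn.initializers.lecun_normal(), (x.shape[-1], self.features))
    # Stand-in for FP8 bookkeeping that lives outside 'params'.
    self.variable('fp8_metas', 'amax_history', jnp.zeros, (16,))
    return x @ w


variables = ToyFp8Dense().init(
    jax.random.PRNGKey(0), jnp.ones((2, 4)),
    mutable=['params', 'fp8_metas'])
print(sorted(variables.keys()))  # ['fp8_metas', 'params'] -- both collections captured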
40 changes: 31 additions & 9 deletions paxml/trainer_lib.py
@@ -26,6 +26,7 @@
from etils import epath
import fiddle as fdl
from flax import struct as flax_struct
from flax.linen.fp8_ops import fm32
import jax
from jax import numpy as jnp
from jax.experimental import pjit
@@ -35,6 +36,7 @@
from paxml import sgf
from paxml import tasks_lib
from paxml import train_states
from paxml.contrib.gpu.scripts_gpu.te_helper import DEFAULT_INIT_MUTABLE_LIST
from praxis import asserts
from praxis import base_hyperparams
from praxis import base_input
@@ -167,8 +169,7 @@ def create_train_state_metadata(
A TrainStateMetadata instance.
"""
var_weight_hparams = jax_task.model.abstract_init_with_metadata(
train_shape_dtype, do_eval=do_eval
)
train_shape_dtype, do_eval=do_eval, extra_mutable_list=DEFAULT_INIT_MUTABLE_LIST)
padded_global_shapes = jax_task.create_train_state_padded_shapes(
var_weight_hparams, discard_opt_states=discard_opt_states
)
@@ -217,7 +218,8 @@ def write_post_init_model_hparams_file(
logging.info('post_init_model_params: %s', params_fpath)
job_log_dir.mkdir(parents=True, exist_ok=True)
hyper_params = model.abstract_init_with_mdl_config(
train_state_metadata.input_shape_dtype, do_eval=do_eval
train_state_metadata.input_shape_dtype, do_eval=do_eval,
extra_mutable_list=DEFAULT_INIT_MUTABLE_LIST
)
with params_fpath.open('w') as params_file:
hyper_params_dump = base_hyperparams.nested_struct_to_text(hyper_params)
@@ -379,7 +381,8 @@ def initialize_model_state(
is_eval_for_init = is_eval
if not var_weight_hparams:
var_weight_hparams = model.abstract_init_with_metadata(
inputs_shape_dtype, do_eval=is_eval_for_init
inputs_shape_dtype, do_eval=is_eval_for_init,
extra_mutable_list=DEFAULT_INIT_MUTABLE_LIST
)
logging.info('init_var prng_seed: %s', init_key)
logging.info('var_weight_hparams: %s', var_weight_hparams)
@@ -396,7 +399,7 @@ def init_fn(init_key):
inputs = jax.tree.map(jnp.zeros_like, inputs_shape_dtype)
if model.hparams.fprop_dtype == jnp.bfloat16:
inputs = jax.tree.map(_maybe_to_bfloat16, inputs)
return model.init(init_key, inputs)
return model.init(init_key, inputs, mutable=DEFAULT_INIT_MUTABLE_LIST)

initial_vars = init_fn(init_key)
logging.info('initial_vars: %s', jax.tree.map(jnp.shape, initial_vars))
@@ -802,14 +805,28 @@ def _default_apply_fn(
)


def _maybe_to_fm32_vars(mdl_vars, var_weight_hparams):
asserts.assert_same_structure(mdl_vars, var_weight_hparams)

def _maybe_fm32_var_fn(var, var_param):
if base_layer.var_overwrite_with_gradient(var_param):
return jax.lax.convert_element_type(var, fm32)
else:
return var

is_leaf = lambda x: not isinstance(x, (tuple, dict, list))
return jax.tree_util.tree_map(
_maybe_fm32_var_fn, mdl_vars, var_weight_hparams, is_leaf=is_leaf
)


class LossFnProtocol(Protocol):

def __call__(
self, mdl_vars: NestedJTensor, inputs: NestedMap, prng_key: PRNGKey
) -> tuple[JTensor, sgf.GradAuxInfo]:
"""Produces losses and grad info by passing the inputs through a model."""


def _get_default_loss_fn(
jax_task: tasks_lib.SingleTask,
context_p: base_layer.JaxContext.HParams,
@@ -833,6 +850,8 @@ def _loss_fn(
else:
assert NotImplementedError(f'fprop_dtype {fprop_dtype} not supported.')

mdl_vars = _maybe_to_fm32_vars(mdl_vars, var_weight_hparams)

with base_layer.JaxContext.new_context(hparams=context_p):
k1, k2, k3 = jax.random.split(prng_key, 3)
(metrics, per_example_output), updated_vars = apply_fn(
@@ -994,6 +1013,7 @@ def get_excluded_var_masks(
excluded_for_grad = tasks_lib.get_excluded_var_mask_for_grad(
var_weight_hparams, learner
)

_log_bprop_include_exclude_list(var_weight_hparams, excluded_for_grad)

# Excluded for optimizer states.
@@ -1090,7 +1110,7 @@ def train_step_single_learner(

if not var_weight_hparams:
with base_layer.JaxContext.new_context(hparams=context_p):
var_weight_hparams = model.abstract_init_with_metadata(inputs)
var_weight_hparams = model.abstract_init_with_metadata(inputs, extra_mutable_list=DEFAULT_INIT_MUTABLE_LIST)
updated_model_vars = jax_task.maybe_adjust_train_state( # pytype: disable=wrong-arg-types # jax-ndarray
step=states.step,
mdl_vars=states.mdl_vars,
@@ -1162,6 +1182,7 @@ def train_step_single_learner(
wps_with_opt = tasks_lib.filter_vars_for_grad_or_opt(
var_weight_hparams, excluded_for_learner
)

transformed_grads, new_opt_states = learner.update_states(
grads, states.opt_states[0], vars_with_opt, wps_with_opt
)
@@ -1197,6 +1218,7 @@ def train_step_single_learner(
states.mdl_vars,
mdl_vars,
)

new_states = states.new_state(
mdl_vars=mdl_vars, opt_states=[new_opt_states], extra_state=()
)
@@ -1300,7 +1322,7 @@ def eval_step_single_learner(
var_weight_hparams = model.abstract_init_with_metadata(
inputs,
do_eval=not jax_task.hparams.train.always_use_train_for_model_init,
)
extra_mutable_list=DEFAULT_INIT_MUTABLE_LIST)

if fprop_dtype == jnp.float32:
pass
@@ -1554,7 +1576,7 @@ def initialize_partitioned_model_states(
model = jax_task.model
if not var_weight_hparams:
var_weight_hparams = model.abstract_init_with_metadata(
global_input_shapes, do_eval=is_eval
global_input_shapes, do_eval=is_eval, extra_mutable_list=DEFAULT_INIT_MUTABLE_LIST
)

train_state_partition_specs = (