Skip to content

Commit 4067915

Browse files
committed
adds llama3 MXFP8 NVFP4
Signed-off-by: Jonathan Mitchell <jomitchell@nvidia.com>
1 parent 470e10d commit 4067915

File tree

12 files changed

+335
-148
lines changed

12 files changed

+335
-148
lines changed

bionemo-recipes/models/llama3/modeling_llama_te.py

Lines changed: 67 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,12 @@
1717

1818
import warnings
1919
from collections import OrderedDict
20+
from contextlib import nullcontext
2021
from typing import ClassVar, Unpack
2122

2223
import torch
2324
import torch.nn as nn
25+
import transformer_engine.common.recipe
2426
import transformer_engine.pytorch
2527
import transformers
2628
from transformer_engine.pytorch.attention import InferenceParams
@@ -50,6 +52,7 @@ class NVLlamaConfig(LlamaConfig):
5052
# "thd" = Total tokens (packed/unpadded), Head, Dimension (sequence packing format)
5153
attn_input_format: str = "thd"
5254
self_attn_mask_type: str = "padding_causal"
55+
layer_precision: list[str | None] | None = None
5356

5457

5558
class NVLlamaPreTrainedModel(PreTrainedModel):
@@ -159,11 +162,54 @@ def _init_method(x):
159162
self.rotary_emb = RotaryPositionEmbedding(config.hidden_size // config.num_attention_heads)
160163
self.rotary_emb.inv_freq = LlamaRotaryEmbedding(config=config).inv_freq
161164

165+
self._fp8_recipe: transformer_engine.common.recipe.Recipe | None = None
166+
self._fp4_recipe: transformer_engine.common.recipe.Recipe | None = None
167+
162168
self.gradient_checkpointing = False
163169

164170
# Initialize weights and apply final processing
165171
self.post_init()
166172

173+
def set_recipes(
174+
self,
175+
fp8_recipe: transformer_engine.common.recipe.Recipe | None = None,
176+
fp4_recipe: transformer_engine.common.recipe.Recipe | None = None,
177+
) -> None:
178+
"""Attach quantization recipe objects for per-layer autocast.
179+
180+
Recipes are not serializable and must be set at runtime after model creation
181+
and sharding (FSDP/DDP) but before training. The per-layer precision
182+
assignments are read from ``self.config.layer_precision``.
183+
184+
Args:
185+
fp8_recipe: The FP8 recipe instance (e.g., MXFP8BlockScaling), or None.
186+
fp4_recipe: The FP4 recipe instance (e.g., NVFP4BlockScaling), or None.
187+
"""
188+
self._fp8_recipe = fp8_recipe
189+
self._fp4_recipe = fp4_recipe
190+
191+
def get_layer_autocast(self, layer_number: int):
192+
"""Return the appropriate TE autocast context manager for a given layer.
193+
194+
The context nests inside the outer FP8 autocast opened in ``forward``:
195+
- FP8 layer: nullcontext() -- lets the outer FP8 autocast take effect.
196+
- FP4 layer: te.pytorch.autocast(enabled=True, recipe=fp4_recipe) -- overrides to FP4.
197+
- BF16 layer: te.pytorch.autocast(enabled=False) -- disables quantized compute.
198+
199+
Args:
200+
layer_number: The 0-indexed layer number.
201+
202+
Returns:
203+
A context manager for the layer's quantization mode.
204+
"""
205+
precision = self.config.layer_precision[layer_number] if self.config.layer_precision is not None else None
206+
if precision == "fp8":
207+
return nullcontext()
208+
elif precision == "fp4":
209+
return transformer_engine.pytorch.autocast(enabled=True, recipe=self._fp4_recipe)
210+
else:
211+
return transformer_engine.pytorch.autocast(enabled=False)
212+
167213
def forward(
168214
self,
169215
input_ids: torch.Tensor | None = None,
@@ -240,23 +286,27 @@ def forward(
240286
if te_rope_emb.dtype == torch.float32:
241287
warnings.warn("Rotary embeddings should be in float32 for optimal performance.", UserWarning)
242288

243-
for decoder_layer in self.layers[: self.config.num_hidden_layers]:
244-
if output_hidden_states:
245-
all_hidden_states = (*all_hidden_states, hidden_states)
246-
247-
hidden_states = decoder_layer(
248-
hidden_states,
249-
attention_mask=None if self.config.attn_input_format == "thd" else attention_mask,
250-
rotary_pos_emb=te_rope_emb,
251-
inference_params=past_key_values,
252-
cu_seqlens_q=kwargs.get("cu_seq_lens_q", None),
253-
cu_seqlens_kv=kwargs.get("cu_seq_lens_k", None),
254-
cu_seqlens_q_padded=kwargs.get("cu_seq_lens_q_padded", None),
255-
cu_seqlens_kv_padded=kwargs.get("cu_seq_lens_k_padded", None),
256-
max_seqlen_q=kwargs.get("max_length_q", None),
257-
max_seqlen_kv=kwargs.get("max_length_k", None),
258-
pad_between_seqs=kwargs.get("pad_between_seqs", None),
259-
)
289+
# Outer FP8 autocast enables FP8 compute for the decoder stack. Per-layer overrides (FP4, BF16) are handled
290+
# by get_layer_autocast(), which nests inside this context.
291+
with transformer_engine.pytorch.autocast(enabled=self._fp8_recipe is not None, recipe=self._fp8_recipe):
292+
for layer_number, decoder_layer in enumerate(self.layers[: self.config.num_hidden_layers]):
293+
if output_hidden_states:
294+
all_hidden_states = (*all_hidden_states, hidden_states)
295+
296+
with self.get_layer_autocast(layer_number):
297+
hidden_states = decoder_layer(
298+
hidden_states,
299+
attention_mask=None if self.config.attn_input_format == "thd" else attention_mask,
300+
rotary_pos_emb=te_rope_emb,
301+
inference_params=past_key_values,
302+
cu_seqlens_q=kwargs.get("cu_seq_lens_q", None),
303+
cu_seqlens_kv=kwargs.get("cu_seq_lens_k", None),
304+
cu_seqlens_q_padded=kwargs.get("cu_seq_lens_q_padded", None),
305+
cu_seqlens_kv_padded=kwargs.get("cu_seq_lens_k_padded", None),
306+
max_seqlen_q=kwargs.get("max_length_q", None),
307+
max_seqlen_kv=kwargs.get("max_length_k", None),
308+
pad_between_seqs=kwargs.get("pad_between_seqs", None),
309+
)
260310

261311
hidden_states = self.norm(hidden_states)
262312

bionemo-recipes/recipes/llama3_native_te/fp8_debugging.py

Lines changed: 0 additions & 64 deletions
This file was deleted.

bionemo-recipes/recipes/llama3_native_te/fp8_debugging_stats.yaml

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ example_fp8_tensor_stat_collection:
22
enabled: True
33
layers:
44
# Match the actual linear layers (attention and MLP) that support FP8 stats
5-
layer_types: [layernorm_qkv]
5+
layer_types: [layernorm_qkv, proj, fc1, fc2]
66
transformer_engine:
77
LogFp8TensorStats:
88
enabled: True
@@ -16,3 +16,8 @@ example_fp8_tensor_stat_collection:
1616
- tensor: weight
1717
stats: [underflows%, scale_inv_min, scale_inv_max, mse]
1818
freq: 10
19+
LogTensorStats:
20+
enabled: True
21+
stats: [max, min, mean, std, l1_norm]
22+
tensors: [dgrad, wgrad]
23+
freq: 1

bionemo-recipes/recipes/llama3_native_te/hydra_config/defaults.yaml

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,12 @@ fp8_config:
4444
quantized_model_init_kwargs:
4545
enabled: false # If this is set to true, fp8_config.enabled must also be set to true.
4646

47+
fp4_config:
48+
enabled: false
49+
fp4_recipe: transformer_engine.common.recipe.NVFP4BlockScaling
50+
fp4_format: "E2M1"
51+
fp4_recipe_kwargs: {}
52+
4753
# Optimizer config
4854
adamw_kwargs:
4955
lr: 3e-3
@@ -70,10 +76,15 @@ checkpoint:
7076
logger:
7177
frequency: 100
7278

73-
fp8_stats_config:
79+
quant_stats_config:
7480
enabled: false
75-
fp8_stats_file: ./fp8_debugging_stats.yaml
76-
fp8_log_dir: ./log_fp8_stats
81+
quant_stats_file: ./fp8_debugging_stats.yaml
82+
quant_log_dir: ./log_quant_stats
83+
84+
# Note: Layer numbers are specified 1-indexed here and are converted to 0-indexed at runtime.
85+
fp8_layers: null
86+
fp4_layers: null
87+
use_fp32_master_weights: null
7788

7889
profiler:
7990
enabled: false

bionemo-recipes/recipes/llama3_native_te/modeling_llama_te.py

Lines changed: 67 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,12 @@
1717

1818
import warnings
1919
from collections import OrderedDict
20+
from contextlib import nullcontext
2021
from typing import ClassVar, Unpack
2122

2223
import torch
2324
import torch.nn as nn
25+
import transformer_engine.common.recipe
2426
import transformer_engine.pytorch
2527
import transformers
2628
from transformer_engine.pytorch.attention import InferenceParams
@@ -50,6 +52,7 @@ class NVLlamaConfig(LlamaConfig):
5052
# "thd" = Total tokens (packed/unpadded), Head, Dimension (sequence packing format)
5153
attn_input_format: str = "thd"
5254
self_attn_mask_type: str = "padding_causal"
55+
layer_precision: list[str | None] | None = None
5356

5457

5558
class NVLlamaPreTrainedModel(PreTrainedModel):
@@ -159,11 +162,54 @@ def _init_method(x):
159162
self.rotary_emb = RotaryPositionEmbedding(config.hidden_size // config.num_attention_heads)
160163
self.rotary_emb.inv_freq = LlamaRotaryEmbedding(config=config).inv_freq
161164

165+
self._fp8_recipe: transformer_engine.common.recipe.Recipe | None = None
166+
self._fp4_recipe: transformer_engine.common.recipe.Recipe | None = None
167+
162168
self.gradient_checkpointing = False
163169

164170
# Initialize weights and apply final processing
165171
self.post_init()
166172

173+
def set_recipes(
174+
self,
175+
fp8_recipe: transformer_engine.common.recipe.Recipe | None = None,
176+
fp4_recipe: transformer_engine.common.recipe.Recipe | None = None,
177+
) -> None:
178+
"""Attach quantization recipe objects for per-layer autocast.
179+
180+
Recipes are not serializable and must be set at runtime after model creation
181+
and sharding (FSDP/DDP) but before training. The per-layer precision
182+
assignments are read from ``self.config.layer_precision``.
183+
184+
Args:
185+
fp8_recipe: The FP8 recipe instance (e.g., MXFP8BlockScaling), or None.
186+
fp4_recipe: The FP4 recipe instance (e.g., NVFP4BlockScaling), or None.
187+
"""
188+
self._fp8_recipe = fp8_recipe
189+
self._fp4_recipe = fp4_recipe
190+
191+
def get_layer_autocast(self, layer_number: int):
192+
"""Return the appropriate TE autocast context manager for a given layer.
193+
194+
The context nests inside the outer FP8 autocast opened in ``forward``:
195+
- FP8 layer: nullcontext() -- lets the outer FP8 autocast take effect.
196+
- FP4 layer: te.pytorch.autocast(enabled=True, recipe=fp4_recipe) -- overrides to FP4.
197+
- BF16 layer: te.pytorch.autocast(enabled=False) -- disables quantized compute.
198+
199+
Args:
200+
layer_number: The 0-indexed layer number.
201+
202+
Returns:
203+
A context manager for the layer's quantization mode.
204+
"""
205+
precision = self.config.layer_precision[layer_number] if self.config.layer_precision is not None else None
206+
if precision == "fp8":
207+
return nullcontext()
208+
elif precision == "fp4":
209+
return transformer_engine.pytorch.autocast(enabled=True, recipe=self._fp4_recipe)
210+
else:
211+
return transformer_engine.pytorch.autocast(enabled=False)
212+
167213
def forward(
168214
self,
169215
input_ids: torch.Tensor | None = None,
@@ -240,23 +286,27 @@ def forward(
240286
if te_rope_emb.dtype == torch.float32:
241287
warnings.warn("Rotary embeddings should be in float32 for optimal performance.", UserWarning)
242288

243-
for decoder_layer in self.layers[: self.config.num_hidden_layers]:
244-
if output_hidden_states:
245-
all_hidden_states = (*all_hidden_states, hidden_states)
246-
247-
hidden_states = decoder_layer(
248-
hidden_states,
249-
attention_mask=None if self.config.attn_input_format == "thd" else attention_mask,
250-
rotary_pos_emb=te_rope_emb,
251-
inference_params=past_key_values,
252-
cu_seqlens_q=kwargs.get("cu_seq_lens_q", None),
253-
cu_seqlens_kv=kwargs.get("cu_seq_lens_k", None),
254-
cu_seqlens_q_padded=kwargs.get("cu_seq_lens_q_padded", None),
255-
cu_seqlens_kv_padded=kwargs.get("cu_seq_lens_k_padded", None),
256-
max_seqlen_q=kwargs.get("max_length_q", None),
257-
max_seqlen_kv=kwargs.get("max_length_k", None),
258-
pad_between_seqs=kwargs.get("pad_between_seqs", None),
259-
)
289+
# Outer FP8 autocast enables FP8 compute for the decoder stack. Per-layer overrides (FP4, BF16) are handled
290+
# by get_layer_autocast(), which nests inside this context.
291+
with transformer_engine.pytorch.autocast(enabled=self._fp8_recipe is not None, recipe=self._fp8_recipe):
292+
for layer_number, decoder_layer in enumerate(self.layers[: self.config.num_hidden_layers]):
293+
if output_hidden_states:
294+
all_hidden_states = (*all_hidden_states, hidden_states)
295+
296+
with self.get_layer_autocast(layer_number):
297+
hidden_states = decoder_layer(
298+
hidden_states,
299+
attention_mask=None if self.config.attn_input_format == "thd" else attention_mask,
300+
rotary_pos_emb=te_rope_emb,
301+
inference_params=past_key_values,
302+
cu_seqlens_q=kwargs.get("cu_seq_lens_q", None),
303+
cu_seqlens_kv=kwargs.get("cu_seq_lens_k", None),
304+
cu_seqlens_q_padded=kwargs.get("cu_seq_lens_q_padded", None),
305+
cu_seqlens_kv_padded=kwargs.get("cu_seq_lens_k_padded", None),
306+
max_seqlen_q=kwargs.get("max_length_q", None),
307+
max_seqlen_kv=kwargs.get("max_length_k", None),
308+
pad_between_seqs=kwargs.get("pad_between_seqs", None),
309+
)
260310

261311
hidden_states = self.norm(hidden_states)
262312

bionemo-recipes/recipes/llama3_native_te/perf_logger.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ def __init__(self, dist_config: DistributedConfig, args: DictConfig, start_step:
9191
self.grad_acc_step_count = 0
9292

9393
# Whether to call debug_api.step() after each step
94-
self.fp8_stats_enabled = args.fp8_stats_config.enabled
94+
self.quant_stats_config = args.quant_stats_config.enabled
9595

9696
@nvtx.annotate("PerfLogger.log_micro_step", color="pink")
9797
def log_micro_step(self, step: int, batch: dict[str, torch.Tensor], outputs: CausalLMOutputWithPast):
@@ -150,7 +150,7 @@ def log_step(
150150
if self._profiler is not None:
151151
self._profiler.step(step)
152152

153-
if self.fp8_stats_enabled:
153+
if self.quant_stats_config:
154154
debug_api.step()
155155

156156
if step % self.logging_frequency == 0 and step > 0:
@@ -201,15 +201,15 @@ def log_step(
201201

202202
def finish(self):
203203
"""Finish the logger and close the progress bar."""
204+
if self.quant_stats_config:
205+
debug_api.end_debug()
206+
204207
if not self._dist_config.is_main_process():
205208
return
206209

207210
wandb.finish()
208211
self._progress_bar.close()
209212

210-
if self.fp8_stats_enabled:
211-
debug_api.end_debug()
212-
213213

214214
class NsightProfiler:
215215
"""Nsight Systems profiler wrapper for performance analysis.

0 commit comments

Comments
 (0)