diff --git a/bionemo-recipes/models/llama3/modeling_llama_te.py b/bionemo-recipes/models/llama3/modeling_llama_te.py index 033eb5ebe3..4e859b4868 100644 --- a/bionemo-recipes/models/llama3/modeling_llama_te.py +++ b/bionemo-recipes/models/llama3/modeling_llama_te.py @@ -17,10 +17,12 @@ import warnings from collections import OrderedDict +from contextlib import nullcontext from typing import ClassVar, Unpack import torch import torch.nn as nn +import transformer_engine.common.recipe import transformer_engine.pytorch import transformers from transformer_engine.pytorch.attention import InferenceParams @@ -50,6 +52,7 @@ class NVLlamaConfig(LlamaConfig): # "thd" = Total tokens (packed/unpadded), Head, Dimension (sequence packing format) attn_input_format: str = "thd" self_attn_mask_type: str = "padding_causal" + layer_precision: list[str | None] | None = None class NVLlamaPreTrainedModel(PreTrainedModel): @@ -159,11 +162,54 @@ def _init_method(x): self.rotary_emb = RotaryPositionEmbedding(config.hidden_size // config.num_attention_heads) self.rotary_emb.inv_freq = LlamaRotaryEmbedding(config=config).inv_freq + self._fp8_recipe: transformer_engine.common.recipe.Recipe | None = None + self._fp4_recipe: transformer_engine.common.recipe.Recipe | None = None + self.gradient_checkpointing = False # Initialize weights and apply final processing self.post_init() + def set_recipes( + self, + fp8_recipe: transformer_engine.common.recipe.Recipe | None = None, + fp4_recipe: transformer_engine.common.recipe.Recipe | None = None, + ) -> None: + """Attach quantization recipe objects for per-layer autocast. + + Recipes are not serializable and must be set at runtime after model creation + and sharding (FSDP/DDP) but before training. The per-layer precision + assignments are read from ``self.config.layer_precision``. + + Args: + fp8_recipe: The FP8 recipe instance (e.g., MXFP8BlockScaling), or None. + fp4_recipe: The FP4 recipe instance (e.g., NVFP4BlockScaling), or None. 
+ """ + self._fp8_recipe = fp8_recipe + self._fp4_recipe = fp4_recipe + + def get_layer_autocast(self, layer_number: int): + """Return the appropriate TE autocast context manager for a given layer. + + The context interacts with the outer FP8 autocast in the training script: + - FP8 layer: nullcontext() -- lets the outer FP8 autocast take effect. + - FP4 layer: te.pytorch.autocast(enabled=True, recipe=fp4_recipe) -- overrides to FP4. + - BF16 layer: te.pytorch.autocast(enabled=False) -- disables quantized compute. + + Args: + layer_number: The 0-indexed layer number. + + Returns: + A context manager for the layer's quantization mode. + """ + precision = self.config.layer_precision[layer_number] if self.config.layer_precision is not None else None + if precision == "fp8": + return nullcontext() + elif precision == "fp4": + return transformer_engine.pytorch.autocast(enabled=True, recipe=self._fp4_recipe) + else: + return transformer_engine.pytorch.autocast(enabled=False) + def forward( self, input_ids: torch.Tensor | None = None, @@ -240,23 +286,27 @@ def forward( if te_rope_emb.dtype == torch.float32: warnings.warn("Rotary embeddings should be in float32 for optimal performance.", UserWarning) - for decoder_layer in self.layers[: self.config.num_hidden_layers]: - if output_hidden_states: - all_hidden_states = (*all_hidden_states, hidden_states) - - hidden_states = decoder_layer( - hidden_states, - attention_mask=None if self.config.attn_input_format == "thd" else attention_mask, - rotary_pos_emb=te_rope_emb, - inference_params=past_key_values, - cu_seqlens_q=kwargs.get("cu_seq_lens_q", None), - cu_seqlens_kv=kwargs.get("cu_seq_lens_k", None), - cu_seqlens_q_padded=kwargs.get("cu_seq_lens_q_padded", None), - cu_seqlens_kv_padded=kwargs.get("cu_seq_lens_k_padded", None), - max_seqlen_q=kwargs.get("max_length_q", None), - max_seqlen_kv=kwargs.get("max_length_k", None), - pad_between_seqs=kwargs.get("pad_between_seqs", None), - ) + # Outer FP8 autocast enables FP8 
compute for the decoder stack. Per-layer overrides (FP4, BF16) are handled + # by get_layer_autocast(), which nests inside this context. + with transformer_engine.pytorch.autocast(enabled=self._fp8_recipe is not None, recipe=self._fp8_recipe): + for layer_number, decoder_layer in enumerate(self.layers[: self.config.num_hidden_layers]): + if output_hidden_states: + all_hidden_states = (*all_hidden_states, hidden_states) + + with self.get_layer_autocast(layer_number): + hidden_states = decoder_layer( + hidden_states, + attention_mask=None if self.config.attn_input_format == "thd" else attention_mask, + rotary_pos_emb=te_rope_emb, + inference_params=past_key_values, + cu_seqlens_q=kwargs.get("cu_seq_lens_q", None), + cu_seqlens_kv=kwargs.get("cu_seq_lens_k", None), + cu_seqlens_q_padded=kwargs.get("cu_seq_lens_q_padded", None), + cu_seqlens_kv_padded=kwargs.get("cu_seq_lens_k_padded", None), + max_seqlen_q=kwargs.get("max_length_q", None), + max_seqlen_kv=kwargs.get("max_length_k", None), + pad_between_seqs=kwargs.get("pad_between_seqs", None), + ) hidden_states = self.norm(hidden_states) diff --git a/bionemo-recipes/models/llama3/tests/test_distributed_fp8.py b/bionemo-recipes/models/llama3/tests/test_distributed_fp8.py new file mode 100644 index 0000000000..eb93415d50 --- /dev/null +++ b/bionemo-recipes/models/llama3/tests/test_distributed_fp8.py @@ -0,0 +1,243 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import pickle +import subprocess + +import pytest +import torch +from transformer_engine.pytorch.fp8 import check_fp8_support + + +def requires_fp8(func): + """Decorator to skip tests that require FP8 support.""" + fp8_available, reason = check_fp8_support() + return pytest.mark.skipif(not fp8_available, reason=f"FP8 is not supported on this GPU: {reason}")(func) + + +requires_multi_gpu = pytest.mark.skipif( + not torch.cuda.is_available() or torch.cuda.device_count() < 2, + reason="Test requires at least 2 GPUs", +) + + +@pytest.mark.parametrize("strategy", ["ddp", "fsdp2"]) +@requires_fp8 +def test_single_process_attaches_correct_fp8_recipe(strategy, unused_tcp_port): + cmd = [ + "torchrun", + "--nproc_per_node=1", + "--rdzv-backend=c10d", + f"--rdzv-endpoint=localhost:{unused_tcp_port}", + os.path.relpath(__file__), + "--strategy", + strategy, + ] + + result = subprocess.run( + cmd, + check=False, + text=True, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + timeout=240, + ) + if result.returncode != 0: + print(f"STDOUT:\n{result.stdout}") + print(f"STDERR:\n{result.stderr}") + pytest.fail(f"Command failed with exit code {result.returncode}") + + +@pytest.mark.parametrize("strategy", ["ddp", "fsdp2"]) +@requires_fp8 +@requires_multi_gpu +def test_multi_process_fp8_recipes_are_synced(strategy, unused_tcp_port): + cmd = [ + "torchrun", + "--nproc_per_node=2", + "--rdzv-backend=c10d", + f"--rdzv-endpoint=localhost:{unused_tcp_port}", + os.path.relpath(__file__), + "--strategy", + strategy, + ] + + result = subprocess.run( + cmd, + check=False, + text=True, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + timeout=240, + ) + if result.returncode != 0: + print(f"STDOUT:\n{result.stdout}") + print(f"STDERR:\n{result.stderr}") + pytest.fail(f"Command failed with exit code {result.returncode}") + + +if __name__ == "__main__": + import argparse 
+ import enum + import os + import sys + from dataclasses import dataclass, field + from pathlib import Path + + # Ensure the model directory is on sys.path for bare module imports. + sys.path.insert(0, Path(__file__).resolve().parent.parent.as_posix()) + + import torch.distributed as dist + from torch.distributed.device_mesh import init_device_mesh + from torch.distributed.fsdp import fully_shard + from torch.optim import AdamW + from transformer_engine.pytorch.fp8 import DelayedScaling, Format + + from modeling_llama_te import NVLlamaConfig, NVLlamaForCausalLM + + def recursive_assert(a, b, path=""): + if isinstance(a, dict) and isinstance(b, dict): + assert a.keys() == b.keys(), f"Dictionary keys mismatch: {a.keys()} != {b.keys()} at {path}" + for k in a: + recursive_assert(a[k], b[k], path=f"{path}.{k}") + elif isinstance(a, list) and isinstance(b, list): + assert len(a) == len(b), f"List lengths mismatch: {len(a)} != {len(b)} at {path}" + for i in range(len(a)): + recursive_assert(a[i], b[i], path=f"{path}.{i}") + elif isinstance(a, torch.Tensor) and isinstance(b, torch.Tensor): + torch.testing.assert_close(a, b, msg=f"Tensor mismatch at {path}") + else: + assert a == b, f"Value mismatch at {path}: {a} != {b}" + + class Strategy(enum.StrEnum): + DDP = "ddp" + FSDP2 = "fsdp2" + + @dataclass + class DistributedConfig: + """Class to track distributed ranks.""" + + rank: int = field(default_factory=dist.get_rank) + local_rank: int = field(default_factory=lambda: int(os.environ["LOCAL_RANK"])) + world_size: int = field(default_factory=dist.get_world_size) + + def is_main_process(self) -> bool: + """This is the global rank 0 process, to be used for wandb logging, etc.""" + return self.rank == 0 + + parser = argparse.ArgumentParser() + parser.add_argument("--strategy", type=Strategy, default=Strategy.DDP, choices=[Strategy.FSDP2, Strategy.DDP]) + args = parser.parse_args() + + torch.distributed.init_process_group(backend="nccl") + dist_config = DistributedConfig() + 
torch.cuda.set_device(dist_config.local_rank) + device_mesh = init_device_mesh( + "cuda", + mesh_shape=(dist_config.world_size, 1), + mesh_dim_names=("dp", "tp"), + ) + device = f"cuda:{dist_config.local_rank}" + + fp8_recipe = DelayedScaling(fp8_format=Format.HYBRID, amax_compute_algo="max", amax_history_len=10) + + config = NVLlamaConfig( + hidden_size=256, + intermediate_size=512, + num_hidden_layers=6, + num_attention_heads=8, + num_key_value_heads=4, + vocab_size=100, + dtype=torch.bfloat16, + ) + config.layer_precision = ["fp8"] * config.num_hidden_layers + model = NVLlamaForCausalLM(config) + + if args.strategy is Strategy.FSDP2: + for layer in model.model.layers: + fully_shard(layer, mesh=device_mesh["dp"]) + fully_shard(model, mesh=device_mesh["dp"]) + model.to(device) + + elif args.strategy is Strategy.DDP: + model.to(device) + model = torch.nn.parallel.DistributedDataParallel( + model, + device_ids=[dist_config.local_rank], + output_device=dist_config.local_rank, + device_mesh=device_mesh["dp"], + ) + + optimizer = AdamW(model.parameters()) + + # Attach FP8 recipes to the model (layer precision is already on config). 
+ llama_model = model.module.model if args.strategy is Strategy.DDP else model.model + llama_model.set_recipes(fp8_recipe=fp8_recipe, fp4_recipe=None) + + model.train() + + generator = torch.Generator() + generator.manual_seed(torch.distributed.get_rank()) + + for _ in range(3): + input_data = { + "input_ids": torch.randint(0, config.vocab_size, (1, 32), generator=generator), + "labels": torch.randint(0, config.vocab_size, (1, 32), generator=generator), + "attention_mask": torch.ones(1, 32), + } + input_data = {k: v.to(torch.cuda.current_device()) for k, v in input_data.items()} + + with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16): + outputs = model(**input_data) + + outputs.loss.backward() + + # Access FP8 extra states directly from modules instead of state_dict() + # since state_dict() now filters them out for HuggingFace compatibility + fp8_extra_states = {} + for name, module in model.named_modules(): + if hasattr(module, "_extra_state") and callable(module._extra_state): + extra_state = module._extra_state() + if extra_state is not None and len(extra_state) > 0: + fp8_extra_states[f"{name}._extra_state"] = extra_state + + # lm_head is BF16, not FP8, so exclude it from FP8 checks + fp8_extra_states = {key: val for key, val in fp8_extra_states.items() if "lm_head." 
not in key} + + # 2 ranks, test to ensure that both ranks have the same FP8 extra states + if torch.distributed.get_world_size() == 2: + outputs_list = [None] * torch.distributed.get_world_size() if torch.distributed.get_rank() == 0 else None + torch.distributed.gather_object(fp8_extra_states, outputs_list, dst=0) + if torch.distributed.get_rank() == 0: + assert outputs_list is not None + + for key in outputs_list[0]: + state_1 = outputs_list[0][key] + state_2 = outputs_list[1][key] + assert len(state_1) > 0, f"No FP8 extra states for {key}, rank 0" + assert len(state_2) > 0, f"No FP8 extra states for {key}, rank 1" + dict_1 = pickle.loads(state_1.detach().numpy(force=True).tobytes()) + dict_2 = pickle.loads(state_2.detach().numpy(force=True).tobytes()) + recursive_assert(dict_1, dict_2) + + # One rank, test to ensure the correct FP8 extra states are saved + if torch.distributed.get_world_size() == 1: + for key, val in fp8_extra_states.items(): + assert len(val) > 0, f"No FP8 extra states for {key}" + fp8_meta_dict = pickle.loads(val.detach().numpy(force=True).tobytes()) + assert fp8_meta_dict["recipe"] == fp8_recipe, f"Recipe mismatch for {key}" + + torch.distributed.destroy_process_group() diff --git a/bionemo-recipes/models/llama3/tests/test_layer_quantization.py b/bionemo-recipes/models/llama3/tests/test_layer_quantization.py new file mode 100644 index 0000000000..a80ff80f2c --- /dev/null +++ b/bionemo-recipes/models/llama3/tests/test_layer_quantization.py @@ -0,0 +1,180 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for NVLlamaModel.set_recipes and get_layer_autocast.""" + +from contextlib import nullcontext +from unittest.mock import patch + +import pytest +import transformer_engine.common.recipe +import transformer_engine.pytorch + +from modeling_llama_te import NVLlamaConfig, NVLlamaModel + + +@pytest.fixture +def model(): + """Create a small NVLlamaModel for testing.""" + config = NVLlamaConfig( + hidden_size=256, + intermediate_size=512, + num_hidden_layers=6, + num_attention_heads=8, + num_key_value_heads=4, + vocab_size=100, + ) + return NVLlamaModel(config) + + +# -- set_recipes -- + + +def test_all_fp8(model): + model.config.layer_precision = ["fp8"] * 6 + fp8_recipe = transformer_engine.common.recipe.DelayedScaling() + model.set_recipes(fp8_recipe=fp8_recipe, fp4_recipe=None) + assert model._fp8_recipe is fp8_recipe + assert model._fp4_recipe is None + assert all(p == "fp8" for p in model.config.layer_precision) + + +def test_all_fp4(model): + model.config.layer_precision = ["fp4"] * 6 + fp4_recipe = transformer_engine.common.recipe.NVFP4BlockScaling() + model.set_recipes(fp8_recipe=None, fp4_recipe=fp4_recipe) + assert model._fp8_recipe is None + assert model._fp4_recipe is fp4_recipe + assert all(p == "fp4" for p in model.config.layer_precision) + + +def test_all_bf16(model): + model.config.layer_precision = [None] * 6 + model.set_recipes(fp8_recipe=None, fp4_recipe=None) + assert all(p is None for p in model.config.layer_precision) + + +def test_mixed_fp8_fp4(model): + model.config.layer_precision = ["fp8", "fp8", "fp8", "fp4", "fp4", 
"fp4"] + fp8_recipe = transformer_engine.common.recipe.DelayedScaling() + fp4_recipe = transformer_engine.common.recipe.NVFP4BlockScaling() + model.set_recipes(fp8_recipe=fp8_recipe, fp4_recipe=fp4_recipe) + assert model.config.layer_precision == ["fp8", "fp8", "fp8", "fp4", "fp4", "fp4"] + + +def test_mixed_fp8_bf16(model): + model.config.layer_precision = ["fp8", None, "fp8", None, "fp8", None] + fp8_recipe = transformer_engine.common.recipe.DelayedScaling() + model.set_recipes(fp8_recipe=fp8_recipe, fp4_recipe=None) + assert model.config.layer_precision == ["fp8", None, "fp8", None, "fp8", None] + + +def test_mixed_all_three(model): + model.config.layer_precision = ["fp8", "fp8", None, None, "fp4", "fp4"] + fp8_recipe = transformer_engine.common.recipe.DelayedScaling() + fp4_recipe = transformer_engine.common.recipe.NVFP4BlockScaling() + model.set_recipes(fp8_recipe=fp8_recipe, fp4_recipe=fp4_recipe) + assert model.config.layer_precision == ["fp8", "fp8", None, None, "fp4", "fp4"] + + +def test_covers_all_layers(model): + model.config.layer_precision = ["fp8"] + [None] * 5 + model.set_recipes(fp8_recipe=transformer_engine.common.recipe.DelayedScaling(), fp4_recipe=None) + assert len(model.config.layer_precision) == 6 + + +def test_recipes_stored_as_attributes(model): + model.config.layer_precision = ["fp8", "fp4", None, None, None, None] + fp8_recipe = transformer_engine.common.recipe.DelayedScaling() + fp4_recipe = transformer_engine.common.recipe.NVFP4BlockScaling() + model.set_recipes(fp8_recipe=fp8_recipe, fp4_recipe=fp4_recipe) + assert model._fp8_recipe is fp8_recipe + assert model._fp4_recipe is fp4_recipe + # The precision list only contains strings/None, not recipe objects. 
+ for v in model.config.layer_precision: + assert v is None or isinstance(v, str) + + +# -- get_layer_autocast -- + + +def test_fp8_layer_returns_nullcontext(model): + model.config.layer_precision = ["fp8"] + [None] * 5 + model.set_recipes(fp8_recipe=transformer_engine.common.recipe.DelayedScaling(), fp4_recipe=None) + ctx = model.get_layer_autocast(0) + assert isinstance(ctx, nullcontext) + + +def test_fp4_layer_returns_te_autocast(model): + fp4_recipe = transformer_engine.common.recipe.NVFP4BlockScaling() + model.config.layer_precision = ["fp4"] + [None] * 5 + model.set_recipes(fp8_recipe=None, fp4_recipe=fp4_recipe) + with patch.object(transformer_engine.pytorch, "autocast") as mock_autocast: + mock_autocast.return_value = "fp4_context" + ctx = model.get_layer_autocast(0) + mock_autocast.assert_called_once_with(enabled=True, recipe=fp4_recipe) + assert ctx == "fp4_context" + + +def test_bf16_layer_returns_te_autocast_disabled(model): + model.config.layer_precision = [None] * 6 + model.set_recipes(fp8_recipe=None, fp4_recipe=None) + with patch.object(transformer_engine.pytorch, "autocast") as mock_autocast: + mock_autocast.return_value = "bf16_context" + ctx = model.get_layer_autocast(0) + mock_autocast.assert_called_once_with(enabled=False) + assert ctx == "bf16_context" + + +def test_uninitialized_defaults_to_bf16(model): + """When layer_precision is None (default), all layers default to BF16.""" + assert model.config.layer_precision is None + with patch.object(transformer_engine.pytorch, "autocast") as mock_autocast: + mock_autocast.return_value = "bf16_context" + ctx = model.get_layer_autocast(0) + mock_autocast.assert_called_once_with(enabled=False) + assert ctx == "bf16_context" + + +def test_mixed_layers_return_correct_contexts(model): + fp8_recipe = transformer_engine.common.recipe.DelayedScaling() + fp4_recipe = transformer_engine.common.recipe.NVFP4BlockScaling() + model.config.layer_precision = ["fp8", "fp8", "fp4", "fp4", None, None] + 
model.set_recipes(fp8_recipe=fp8_recipe, fp4_recipe=fp4_recipe) + + # FP8 layers -> nullcontext + assert isinstance(model.get_layer_autocast(0), nullcontext) + assert isinstance(model.get_layer_autocast(1), nullcontext) + + # FP4 layers -> te.pytorch.autocast + with patch.object(transformer_engine.pytorch, "autocast") as mock_autocast: + mock_autocast.return_value = "fp4_context" + model.get_layer_autocast(2) + mock_autocast.assert_called_with(enabled=True, recipe=fp4_recipe) + + # BF16 layers -> te.pytorch.autocast(enabled=False) + with patch.object(transformer_engine.pytorch, "autocast") as mock_autocast: + mock_autocast.return_value = "bf16_context" + model.get_layer_autocast(4) + mock_autocast.assert_called_with(enabled=False) + + +def test_layer_precision_is_pickleable(model): + """The config.layer_precision list should be trivially pickleable.""" + import pickle + + model.config.layer_precision = ["fp8", "fp8", "fp4", "fp4", None, None] + roundtripped = pickle.loads(pickle.dumps(model.config.layer_precision)) + assert roundtripped == model.config.layer_precision diff --git a/bionemo-recipes/recipes/llama3_native_te/README.md b/bionemo-recipes/recipes/llama3_native_te/README.md index 7f593157ab..1042855cc9 100644 --- a/bionemo-recipes/recipes/llama3_native_te/README.md +++ b/bionemo-recipes/recipes/llama3_native_te/README.md @@ -1,8 +1,8 @@ # TransformerEngine-accelerated Llama 3 training with native PyTorch training loop This folder demonstrates how to train TE-accelerated Llama 3 with a native PyTorch training loop, including sequence -packing and FP8 precision, using fully sharded data parallel (FSDP) for distributed training. This recipe is configured -for genomic sequences using a custom nucleotide tokenizer. +packing, FP8/MXFP8/NVFP4 precision with layer-wise control, using fully sharded data parallel (FSDP) for distributed +training. This recipe is configured for genomic sequences using a custom nucleotide tokenizer. 
## How to use this recipe @@ -16,9 +16,9 @@ bionemo-framework repository. You can download a zipped directory of this folder ## Supported Models and Training Features -| Model | BF16 | FP8[1] | THD Input Format | FP8 with THD Input Format | MXFP8[2] | Context Parallelism | Tensor Parallelism | -| ---------------------------------------- | ---- | ----------------- | ---------------- | ------------------------- | ------------------- | ------------------- | ------------------ | -| [Llama 3](../../models/llama3/README.md) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | 🚧 | +| Model | BF16 | FP8[1] | MXFP8[2] | NVFP4[3] | THD Input Format | Context Parallelism | Tensor Parallelism | +| ---------------------------------------- | ---- | ----------------- | ------------------- | ------------------- | ---------------- | ------------------- | ------------------ | +| [Llama 3](../../models/llama3/README.md) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | 🚧 | ✅: Supported
🚧: Under development
@@ -26,6 +26,7 @@ bionemo-framework repository. You can download a zipped directory of this folder \[1\]: Requires [compute capability](https://developer.nvidia.com/cuda-gpus) 9.0 and above (Hopper+)
\[2\]: Requires [compute capability](https://developer.nvidia.com/cuda-gpus) 10.0 and 10.3 (Blackwell), 12.0 support pending
+\[3\]: Requires [compute capability](https://developer.nvidia.com/cuda-gpus) 10.0 and above (Blackwell+)
### Installing Dependencies @@ -64,11 +65,16 @@ def compute_model_pflops(seq_len, global_batch_size, step_time_s): return model_flops / 1e15 ``` +### Low precision performance benchmarks + +![Performance Benchmarks Low Precision](../../../docs/docs/assets/images/llama3/llama3_8gpu_tflops.png) +In the above plot, we can see that performance increases as we lower the precision of our transformer layers across the 1B and 8B variants of Llama 3. + ### Convergence Benchmarks

- Llama 3 Lingua 1B Loss Curve - Llama 3 Lingua 1B Step Time + Llama 3 Lingua 1B Loss Curve + Llama 3 Lingua 1B Step Time

We compared the convergence of this Llama3 recipe (with FSDP2) against NeMo 2.0 @@ -88,6 +94,10 @@ are due checkpointing, further work will be done to improve training step time s Models were trained on 64 NVIDIA H100 GPUs with a micro batch size of 4 and a context length of 4096 for 60,000 steps. Training was performed with BF16 precision. +### Low Precision convergence benchmarks + + + ### Distributed Training This recipe supports distributed training using DDP, FSDP2, and FSDP2 with Context Parallelism, shown in three separate training entrypoints: @@ -127,35 +137,71 @@ batch size while running on a smaller number of GPUs. python train_fsdp2.py --config-name L0_sanity grad_acc_steps=2 ``` -### FP8 Training +### Quantized Training (FP8 / MXFP8 / NVFP4) To run training with FP8, enable it by overriding the `fp8_config.enabled=true` configuration parameter. Additional FP8 -configuration parameters, including switching to `MXFP8BlockScaling`, can be set via the hydra configuration. +configuration parameters, including switching to `MXFP8BlockScaling`, can be set using the hydra configuration. ```bash python train_fsdp2.py --config-name L0_sanity fp8_config.enabled=true ``` -#### FP8 Debugging +Similarly, to train with NVFP4 quantization: + +```bash +python train_fsdp2.py --config-name L0_sanity fp4_config.enabled=true +``` + +Note: This feature is available for the `train_ddp` and the `train_fsdp2` scripts. NVFP4 stats logging is not yet +supported and will be enabled in a future TransformerEngine release; FP8/MXFP8 stats logging works today. + +Additional recipe parameters (e.g., switching to `MXFP8BlockScaling`) can be set via the hydra configuration. + +#### Layer-Wise Precision + +You can control which transformer layers use FP8 or FP4 by specifying 1-indexed layer numbers via `fp8_layers` and +`fp4_layers`. Layers not assigned to either format will run in BF16. 
+ +For example, to run layers 1-3 in FP8, layers 4-6 in FP4, and the rest in BF16 on a model with more than 6 layers: + +```bash +python train_fsdp2.py --config-name L0_sanity \ + fp8_config.enabled=true \ + fp4_config.enabled=true \ + 'fp8_layers=[1,2,3]' \ + 'fp4_layers=[4,5,6]' +``` + +When both `fp8_config` and `fp4_config` are enabled but only one layer list is provided, the other format automatically +claims the remaining layers. For example, if `fp8_layers=[1,2,3]` is set and `fp4_config.enabled=true` with no +`fp4_layers`, then layers 4 through N will default to FP4. + +#### Quantization Stats Debugging -We also provide a mechanism to receive tensor data related to FP8 layers during training which may include activations, weights and gradients. +We provide a mechanism to log tensor statistics (activations, weights, gradients) for quantized layers during training. +When layer-wise precision is used, the stats config is automatically updated so that only the relevant layers are +tracked. -To enable this please select the following config options. +To enable stats logging: ```bash python train_fsdp2.py \ - fp8_stats_config.enabled=True \ - fp8_stats_config.fp8_log_dir=./logs/fp8_stats_logs_dummy \ - fp8_stats_config.fp8_stats_file=./fp8_debugging_stats.yaml \ - fp8_config.enabled=True + quant_stats_config.enabled=true \ + quant_stats_config.quant_log_dir=./logs/quant_stats \ + quant_stats_config.quant_stats_file=./fp8_debugging_stats.yaml \ + fp8_config.enabled=true ``` -Note: This feature is available for the `train_ddp` and the `train_fsdp2` scripts. +Note: This feature is available for the `train_ddp` and the `train_fsdp2` scripts. NVFP4 stats logging is not yet +supported and will be enabled in a future TransformerEngine release; FP8/MXFP8 stats logging works today. 
-The config file structure [fp8_debugging_stats.yaml](fp8_debugging_stats.yaml) is explained in the [NVIDIA Transformer Engine config file documentation](https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/debug/2_config_file_structure.html) in more detail. Below we will cover some very basic elements of the file structure. +The config file structure [fp8_debugging_stats.yaml](fp8_debugging_stats.yaml) is explained in the +[NVIDIA Transformer Engine config file documentation](https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/debug/2_config_file_structure.html) +in more detail. -This comes as a performance cost that is dependent on the `freq` parameter mentioned above. `freq=1` collects stats on every step which in our -experiments caused a ~29% decrease in throughput (executed on a single RTX 5090). We recommend using `freq>=10` to reduce this performance hit. +Stats collection has a performance cost dependent on the `freq` parameter in the config file. `freq=1` collects stats +on every step which in our experiments caused a ~29% decrease in throughput (executed on a single RTX 5090). We +recommend using `freq>=10` to reduce this performance hit. 
### Sequence Packing (THD input format) diff --git a/bionemo-recipes/recipes/llama3_native_te/fp4_debugging_stats.yaml b/bionemo-recipes/recipes/llama3_native_te/fp4_debugging_stats.yaml new file mode 100644 index 0000000000..9046d44caf --- /dev/null +++ b/bionemo-recipes/recipes/llama3_native_te/fp4_debugging_stats.yaml @@ -0,0 +1,33 @@ +example_fp4_tensor_stat_collection: + enabled: True + layers: + # Use regex to select layers (1-indexed as layers.1 through layers.N in the naming) + # This matches: model.model.layers.[1-5].*.(layernorm_qkv|proj|fc1|fc2) + layer_name_regex_pattern: 'model\.model\.layers\.[1-5]\..*(layernorm_qkv|proj|fc1|fc2)' + transformer_engine: + LogNvfp4TensorStats: + enabled: True + tensors_struct: + - tensor: activation + stats: [underflows%, mse] + freq: 100 + - tensor: gradient + stats: [underflows%, mse] + freq: 100 + +example_fp8_tensor_stat_collection: + enabled: True + layers: + # Use regex to select layers (1-indexed as layers.1 through layers.N in the naming) + # This matches: model.model.layers.[6-10].*.(layernorm_qkv|proj|fc1|fc2) + layer_name_regex_pattern: 'model\.model\.layers\.([6-9]|10)\..*(layernorm_qkv|proj|fc1|fc2)' + transformer_engine: + LogFp8TensorStats: + enabled: True + tensors_struct: + - tensor: activation + stats: [mxfp8_underflows%, mxfp8_scale_inv_min, mxfp8_scale_inv_max, mxfp8_mse] + freq: 100 + - tensor: gradient + stats: [mxfp8_underflows%, mxfp8_scale_inv_min, mxfp8_scale_inv_max, mxfp8_mse] + freq: 100 diff --git a/bionemo-recipes/recipes/llama3_native_te/fp8_debugging.py b/bionemo-recipes/recipes/llama3_native_te/fp8_debugging.py deleted file mode 100644 index d01024f04c..0000000000 --- a/bionemo-recipes/recipes/llama3_native_te/fp8_debugging.py +++ /dev/null @@ -1,64 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
-# SPDX-License-Identifier: LicenseRef-Apache2 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import logging -import os -from pathlib import Path - -import nvdlfw_inspect.api as debug_api -import transformer_engine - -from distributed_config import DistributedConfig - - -logger = logging.getLogger(__name__) -logger.setLevel(logging.INFO) - - -def initialize_fp8_debugging( - dist_config: DistributedConfig, - enabled: bool, - fp8_stats_file: str, - fp8_log_dir: str | os.PathLike, - fp8_enabled: bool, -) -> None: - """Initialize FP8 debugging. - - Args: - dist_config: The distributed configuration. - enabled: Whether to enable FP8 debugging. - fp8_stats_file: The file containing the FP8 stats. - fp8_log_dir: The directory to log the FP8 stats to. - fp8_enabled: Whether FP8 autocast is enabled. 
- """ - if not enabled: - return - - if not fp8_enabled: - raise ValueError( - "fp8_stats_config.enabled is true but fp8_config.enabled is false, " - "please set fp8_config.enabled to true in the config if you wish to collect FP8 stats" - ) - - fp8_log_dir = Path(fp8_log_dir) / f"rank_{dist_config.rank}" - fp8_log_dir.mkdir(parents=True, exist_ok=True) - logger.info(f"Logging FP8 stats to {fp8_log_dir}") - te_features_dir = str(Path(transformer_engine.__file__).parent / "debug" / "features") - debug_api.initialize( - config_file=fp8_stats_file, - feature_dirs=[te_features_dir], - log_dir=fp8_log_dir.as_posix(), - default_logging_enabled=True, - ) diff --git a/bionemo-recipes/recipes/llama3_native_te/fp8_debugging_stats.yaml b/bionemo-recipes/recipes/llama3_native_te/fp8_debugging_stats.yaml index 7544bbedcf..ba640a6cbb 100644 --- a/bionemo-recipes/recipes/llama3_native_te/fp8_debugging_stats.yaml +++ b/bionemo-recipes/recipes/llama3_native_te/fp8_debugging_stats.yaml @@ -2,7 +2,7 @@ example_fp8_tensor_stat_collection: enabled: True layers: # Match the actual linear layers within attention that support FP8 stats - layer_types: [layernorm_qkv] + layer_types: [layernorm_qkv, proj, fc1, fc2] transformer_engine: LogFp8TensorStats: enabled: True @@ -16,3 +16,8 @@ example_fp8_tensor_stat_collection: - tensor: weight stats: [underflows%, scale_inv_min, scale_inv_max, mse] freq: 10 + LogTensorStats: + enabled: True + stats: [max, min, mean, std, l1_norm] + tensors: [dgrad, wgrad] + freq: 1 diff --git a/bionemo-recipes/recipes/llama3_native_te/hydra_config/defaults.yaml b/bionemo-recipes/recipes/llama3_native_te/hydra_config/defaults.yaml index d6c181598f..be71c57e78 100644 --- a/bionemo-recipes/recipes/llama3_native_te/hydra_config/defaults.yaml +++ b/bionemo-recipes/recipes/llama3_native_te/hydra_config/defaults.yaml @@ -44,6 +44,12 @@ fp8_config: quantized_model_init_kwargs: enabled: false # If this is set to true, fp8_config.enabled must also be set to true. 
+fp4_config: + enabled: false + fp4_recipe: transformer_engine.common.recipe.NVFP4BlockScaling + fp4_format: "E2M1" + fp4_recipe_kwargs: {} + # Optimizer config adamw_kwargs: lr: 3e-3 @@ -70,10 +76,15 @@ checkpoint: logger: frequency: 100 -fp8_stats_config: +quant_stats_config: enabled: false - fp8_stats_file: ./fp8_debugging_stats.yaml - fp8_log_dir: ./log_fp8_stats + quant_stats_file: ./fp8_debugging_stats.yaml + quant_log_dir: ./log_quant_stats + +# Note: The layers are going to come in 1 indexed and we convert them to be 0 indexed at runtime. +fp8_layers: null +fp4_layers: null +use_fp32_master_weights: null profiler: enabled: false diff --git a/bionemo-recipes/recipes/llama3_native_te/modeling_llama_te.py b/bionemo-recipes/recipes/llama3_native_te/modeling_llama_te.py index 033eb5ebe3..4e859b4868 100644 --- a/bionemo-recipes/recipes/llama3_native_te/modeling_llama_te.py +++ b/bionemo-recipes/recipes/llama3_native_te/modeling_llama_te.py @@ -17,10 +17,12 @@ import warnings from collections import OrderedDict +from contextlib import nullcontext from typing import ClassVar, Unpack import torch import torch.nn as nn +import transformer_engine.common.recipe import transformer_engine.pytorch import transformers from transformer_engine.pytorch.attention import InferenceParams @@ -50,6 +52,7 @@ class NVLlamaConfig(LlamaConfig): # "thd" = Total tokens (packed/unpadded), Head, Dimension (sequence packing format) attn_input_format: str = "thd" self_attn_mask_type: str = "padding_causal" + layer_precision: list[str | None] | None = None class NVLlamaPreTrainedModel(PreTrainedModel): @@ -159,11 +162,54 @@ def _init_method(x): self.rotary_emb = RotaryPositionEmbedding(config.hidden_size // config.num_attention_heads) self.rotary_emb.inv_freq = LlamaRotaryEmbedding(config=config).inv_freq + self._fp8_recipe: transformer_engine.common.recipe.Recipe | None = None + self._fp4_recipe: transformer_engine.common.recipe.Recipe | None = None + self.gradient_checkpointing = False # 
Initialize weights and apply final processing self.post_init() + def set_recipes( + self, + fp8_recipe: transformer_engine.common.recipe.Recipe | None = None, + fp4_recipe: transformer_engine.common.recipe.Recipe | None = None, + ) -> None: + """Attach quantization recipe objects for per-layer autocast. + + Recipes are not serializable and must be set at runtime after model creation + and sharding (FSDP/DDP) but before training. The per-layer precision + assignments are read from ``self.config.layer_precision``. + + Args: + fp8_recipe: The FP8 recipe instance (e.g., MXFP8BlockScaling), or None. + fp4_recipe: The FP4 recipe instance (e.g., NVFP4BlockScaling), or None. + """ + self._fp8_recipe = fp8_recipe + self._fp4_recipe = fp4_recipe + + def get_layer_autocast(self, layer_number: int): + """Return the appropriate TE autocast context manager for a given layer. + + The context interacts with the outer FP8 autocast in the training script: + - FP8 layer: nullcontext() -- lets the outer FP8 autocast take effect. + - FP4 layer: te.pytorch.autocast(enabled=True, recipe=fp4_recipe) -- overrides to FP4. + - BF16 layer: te.pytorch.autocast(enabled=False) -- disables quantized compute. + + Args: + layer_number: The 0-indexed layer number. + + Returns: + A context manager for the layer's quantization mode. 
+ """ + precision = self.config.layer_precision[layer_number] if self.config.layer_precision is not None else None + if precision == "fp8": + return nullcontext() + elif precision == "fp4": + return transformer_engine.pytorch.autocast(enabled=True, recipe=self._fp4_recipe) + else: + return transformer_engine.pytorch.autocast(enabled=False) + def forward( self, input_ids: torch.Tensor | None = None, @@ -240,23 +286,27 @@ def forward( if te_rope_emb.dtype == torch.float32: warnings.warn("Rotary embeddings should be in float32 for optimal performance.", UserWarning) - for decoder_layer in self.layers[: self.config.num_hidden_layers]: - if output_hidden_states: - all_hidden_states = (*all_hidden_states, hidden_states) - - hidden_states = decoder_layer( - hidden_states, - attention_mask=None if self.config.attn_input_format == "thd" else attention_mask, - rotary_pos_emb=te_rope_emb, - inference_params=past_key_values, - cu_seqlens_q=kwargs.get("cu_seq_lens_q", None), - cu_seqlens_kv=kwargs.get("cu_seq_lens_k", None), - cu_seqlens_q_padded=kwargs.get("cu_seq_lens_q_padded", None), - cu_seqlens_kv_padded=kwargs.get("cu_seq_lens_k_padded", None), - max_seqlen_q=kwargs.get("max_length_q", None), - max_seqlen_kv=kwargs.get("max_length_k", None), - pad_between_seqs=kwargs.get("pad_between_seqs", None), - ) + # Outer FP8 autocast enables FP8 compute for the decoder stack. Per-layer overrides (FP4, BF16) are handled + # by get_layer_autocast(), which nests inside this context. 
+ with transformer_engine.pytorch.autocast(enabled=self._fp8_recipe is not None, recipe=self._fp8_recipe): + for layer_number, decoder_layer in enumerate(self.layers[: self.config.num_hidden_layers]): + if output_hidden_states: + all_hidden_states = (*all_hidden_states, hidden_states) + + with self.get_layer_autocast(layer_number): + hidden_states = decoder_layer( + hidden_states, + attention_mask=None if self.config.attn_input_format == "thd" else attention_mask, + rotary_pos_emb=te_rope_emb, + inference_params=past_key_values, + cu_seqlens_q=kwargs.get("cu_seq_lens_q", None), + cu_seqlens_kv=kwargs.get("cu_seq_lens_k", None), + cu_seqlens_q_padded=kwargs.get("cu_seq_lens_q_padded", None), + cu_seqlens_kv_padded=kwargs.get("cu_seq_lens_k_padded", None), + max_seqlen_q=kwargs.get("max_length_q", None), + max_seqlen_kv=kwargs.get("max_length_k", None), + pad_between_seqs=kwargs.get("pad_between_seqs", None), + ) hidden_states = self.norm(hidden_states) diff --git a/bionemo-recipes/recipes/llama3_native_te/perf_logger.py b/bionemo-recipes/recipes/llama3_native_te/perf_logger.py index 726eb19e8e..4b1a8d4ec7 100644 --- a/bionemo-recipes/recipes/llama3_native_te/perf_logger.py +++ b/bionemo-recipes/recipes/llama3_native_te/perf_logger.py @@ -91,7 +91,7 @@ def __init__(self, dist_config: DistributedConfig, args: DictConfig, start_step: self.grad_acc_step_count = 0 # Whether to step debug_api.step() after each step - self.fp8_stats_enabled = args.fp8_stats_config.enabled + self.quant_stats_config = args.quant_stats_config.enabled @nvtx.annotate("PerfLogger.log_micro_step", color="pink") def log_micro_step(self, step: int, batch: dict[str, torch.Tensor], outputs: CausalLMOutputWithPast): @@ -150,7 +150,7 @@ def log_step( if self._profiler is not None: self._profiler.step(step) - if self.fp8_stats_enabled: + if self.quant_stats_config: debug_api.step() if step % self.logging_frequency == 0 and step > 0: @@ -201,15 +201,15 @@ def log_step( def finish(self): """Finish the 
logger and close the progress bar.""" + if self.quant_stats_config: + debug_api.end_debug() + if not self._dist_config.is_main_process(): return wandb.finish() self._progress_bar.close() - if self.fp8_stats_enabled: - debug_api.end_debug() - class NsightProfiler: """Nsight Systems profiler wrapper for performance analysis. diff --git a/bionemo-recipes/recipes/llama3_native_te/quantization.py b/bionemo-recipes/recipes/llama3_native_te/quantization.py new file mode 100644 index 0000000000..e479b13c02 --- /dev/null +++ b/bionemo-recipes/recipes/llama3_native_te/quantization.py @@ -0,0 +1,223 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Utilities for layer-wise quantization configuration (FP8/FP4).""" + +import logging +import tempfile +from pathlib import Path + +import yaml + + +logger = logging.getLogger(__name__) + + +def generate_layer_regex(layer_numbers: list[int] | None) -> str: + """Generate a regex pattern to match specific layer numbers (1-indexed). + + The debug API (nvdlfw_inspect) uses 1-indexed layer names after ``infer_and_assign_layer_names``. + + Args: + layer_numbers: List of layer numbers (1-indexed, as shown in debug logs). + If empty or None, returns a pattern that matches nothing. + + Returns: + Regex pattern string for matching those layers' linear sublayers. 
+ """ + if not layer_numbers: + return r"model\.model\.layers\.DISABLED_NO_LAYERS_SPECIFIED" + layer_pattern = "|".join(str(n) for n in sorted(layer_numbers)) + return rf"model\.model\.layers\.({layer_pattern})\..*(layernorm_qkv|proj|fc1|fc2)" + + +def update_quant_stats_config( + config_file: str, + fp4_layers: list[int] | None, + fp8_layers: list[int] | None, +) -> str: + """Update the quant stats YAML config with layer-specific regex patterns. + + Args: + config_file: Path to the original YAML config file. + fp4_layers: List of layer numbers for FP4 (1-indexed). + fp8_layers: List of layer numbers for FP8 (1-indexed). + + Returns: + Path to the updated config file (a temp file). + """ + with open(config_file, "r") as f: + config = yaml.safe_load(f) + + if "example_fp4_tensor_stat_collection" in config: + # TODO: Remove this block and replace with FP8-style regex update once a TransformerEngine + # release with LogNvfp4TensorStats support is available. At that point, this becomes: + # fp4_regex = generate_layer_regex(fp4_layers) + # config["example_fp4_tensor_stat_collection"]["layers"]["layer_name_regex_pattern"] = fp4_regex + config["example_fp4_tensor_stat_collection"]["enabled"] = False + if fp4_layers: + logger.warning( + "NVFP4 quant stats logging is not yet supported (requires a future TransformerEngine release). " + f"Disabling FP4 stats collection for layers {fp4_layers}. FP8 stats will still be collected." 
+ ) + else: + logger.info("FP4 stats section disabled (no FP4 layers and feature not yet supported)") + + if "example_fp8_tensor_stat_collection" in config: + fp8_regex = generate_layer_regex(fp8_layers) + config["example_fp8_tensor_stat_collection"]["layers"]["layer_name_regex_pattern"] = fp8_regex + if fp8_layers: + logger.info(f"Updated FP8 layer regex to match layers: {fp8_layers}") + else: + logger.info("FP8 layers empty - regex set to match nothing") + + temp_file = tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) + yaml.dump(config, temp_file, default_flow_style=False) + temp_file.close() + + config_str = yaml.dump(config, default_flow_style=False) + logger.info(f"Created updated quant stats config at: {temp_file.name}") + logger.info(f"Updated quant stats config contents:\n{config_str}") + + return temp_file.name + + +def initialize_quant_stats_logging( + quant_stats_file: str, + quant_log_dir: str, + rank: int, + layer_precision: list[str | None], +) -> None: + """Set up quantization stats logging via nvdlfw_inspect. + + Updates the quant stats YAML config with resolved layer regex patterns, creates + the per-rank log directory, and initializes the debug API. + + Args: + quant_stats_file: Path to the base quant stats YAML config file. + quant_log_dir: Base directory for quant stats logs (a rank subdirectory will be created). + rank: The global rank of this process. + layer_precision: Per-layer precision list (0-indexed by position). Each element is + ``"fp8"``, ``"fp4"``, or ``None``. + """ + import nvdlfw_inspect.api as debug_api + import transformer_engine + + # Derive 1-indexed layer lists for the debug API, which uses 1-indexed layer names. 
+ fp8_layers_1indexed = [i + 1 for i, p in enumerate(layer_precision) if p == "fp8"] or None + fp4_layers_1indexed = [i + 1 for i, p in enumerate(layer_precision) if p == "fp4"] or None + updated_config = update_quant_stats_config( + config_file=quant_stats_file, + fp4_layers=fp4_layers_1indexed, + fp8_layers=fp8_layers_1indexed, + ) + + rank_log_dir = Path(quant_log_dir) / f"rank_{rank}" + rank_log_dir.mkdir(parents=True, exist_ok=True) + logger.info(f"Logging quant stats to {rank_log_dir}") + + te_features_dir = str(Path(transformer_engine.__file__).parent / "debug" / "features") + debug_api.initialize( + config_file=updated_config, + feature_dirs=[te_features_dir], + log_dir=rank_log_dir, + default_logging_enabled=True, + ) + + +def resolve_layer_precision( + num_layers: int, + fp8_enabled: bool, + fp4_enabled: bool, + fp8_layers: list[int] | None, + fp4_layers: list[int] | None, +) -> list[str | None]: + """Resolve layer-wise quantization assignments from user config. + + Takes 1-indexed layer lists (as specified by the user in YAML config) and returns a per-layer + precision list (0-indexed by position). When a quantization format is enabled but no layer list + is provided, all layers default to that format. When one format has explicit layers and the other + is enabled without a layer list, the unspecified format defaults to the remaining (unclaimed) layers. + + Args: + num_layers: Total number of transformer layers in the model. + fp8_enabled: Whether FP8 quantization is enabled. + fp4_enabled: Whether FP4 quantization is enabled. + fp8_layers: 1-indexed list of layers for FP8, or None if not specified. + fp4_layers: 1-indexed list of layers for FP4, or None if not specified. + + Returns: + A list of length ``num_layers`` where each element is ``"fp8"``, ``"fp4"``, or ``None`` + (BF16 fallback), indexed by layer position (0-indexed). + + Raises: + ValueError: If both formats are enabled with no layer lists, or if layer lists overlap. 
+ """ + all_layers = set(range(1, num_layers + 1)) + + if fp8_enabled and fp4_enabled and fp8_layers is None and fp4_layers is None: + raise ValueError( + "Both fp8_config and fp4_config are enabled but neither fp8_layers nor fp4_layers is specified. " + "When both are enabled, you must explicitly provide layer lists to indicate which layers use which format." + ) + + # When one format has explicit layers and the other defaults, fill in the remaining layers. + if fp8_enabled and fp8_layers is None: + claimed_by_fp4 = set(fp4_layers) if fp4_layers is not None else set() + fp8_layers = sorted(all_layers - claimed_by_fp4) + if claimed_by_fp4: + logger.warning( + f"fp8_config.enabled=True with no fp8_layers specified, but fp4_layers={sorted(claimed_by_fp4)} " + f"are already claimed by FP4. Defaulting FP8 to the remaining layers: {fp8_layers}" + ) + else: + logger.info( + f"fp8_config.enabled=True with no fp8_layers specified, defaulting all {num_layers} layers to FP8" + ) + + if fp4_enabled and fp4_layers is None: + claimed_by_fp8 = set(fp8_layers) if fp8_layers is not None else set() + fp4_layers = sorted(all_layers - claimed_by_fp8) + if claimed_by_fp8: + logger.warning( + f"fp4_config.enabled=True with no fp4_layers specified, but fp8_layers={sorted(claimed_by_fp8)} " + f"are already claimed by FP8. Defaulting FP4 to the remaining layers: {fp4_layers}" + ) + else: + logger.info( + f"fp4_config.enabled=True with no fp4_layers specified, defaulting all {num_layers} layers to FP4" + ) + + # Disable layer lists when corresponding config is not enabled. + if not fp8_enabled: + fp8_layers = None + if not fp4_enabled: + fp4_layers = None + + # Validate no overlap between FP8 and FP4 layer assignments. + if fp8_layers is not None and fp4_layers is not None: + overlap = set(fp8_layers) & set(fp4_layers) + if overlap: + raise ValueError( + f"fp8_layers and fp4_layers cannot have overlapping layer numbers. 
Found overlap: {sorted(overlap)}" + ) + + # Build per-layer precision list (0-indexed by position, 1-indexed for lookup). + fp8_set = set(fp8_layers) if fp8_layers is not None else set() + fp4_set = set(fp4_layers) if fp4_layers is not None else set() + return [ + "fp8" if layer_1indexed in fp8_set else "fp4" if layer_1indexed in fp4_set else None + for layer_1indexed in range(1, num_layers + 1) + ] diff --git a/bionemo-recipes/recipes/llama3_native_te/tests/conftest.py b/bionemo-recipes/recipes/llama3_native_te/tests/conftest.py index 08330b12f7..bb7a2d8ed6 100644 --- a/bionemo-recipes/recipes/llama3_native_te/tests/conftest.py +++ b/bionemo-recipes/recipes/llama3_native_te/tests/conftest.py @@ -56,6 +56,8 @@ def pytest_collection_modifyitems(items): stats_test_names = { "test_sanity_ddp_fp8_stats_logging", "test_sanity_fsdp2_fp8_stats_logging", + "test_sanity_ddp_fp8_partial_layers_stats_logging", + "test_sanity_fsdp2_fp8_partial_layers_stats_logging", } stats_tests = [item for item in items if item.name in stats_test_names] other_tests = [item for item in items if item.name not in stats_test_names] diff --git a/bionemo-recipes/recipes/llama3_native_te/tests/test_perf_logger.py b/bionemo-recipes/recipes/llama3_native_te/tests/test_perf_logger.py index aebdfe17ef..d919278d4a 100644 --- a/bionemo-recipes/recipes/llama3_native_te/tests/test_perf_logger.py +++ b/bionemo-recipes/recipes/llama3_native_te/tests/test_perf_logger.py @@ -34,7 +34,7 @@ def _make_args(logging_frequency=1, num_train_steps=100): "wandb": {"project": "test", "mode": "disabled"}, "num_train_steps": num_train_steps, "profiler": {"enabled": False}, - "fp8_stats_config": {"enabled": False}, + "quant_stats_config": {"enabled": False}, } ) diff --git a/bionemo-recipes/recipes/llama3_native_te/tests/test_quantization.py b/bionemo-recipes/recipes/llama3_native_te/tests/test_quantization.py new file mode 100644 index 0000000000..2d6e02b050 --- /dev/null +++ 
b/bionemo-recipes/recipes/llama3_native_te/tests/test_quantization.py @@ -0,0 +1,332 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import re +import sys +from pathlib import Path + +import pytest +import yaml + + +sys.path.append(Path(__file__).parent.parent.as_posix()) + +from quantization import generate_layer_regex, resolve_layer_precision, update_quant_stats_config + + +# -- resolve_layer_precision -- + + +def test_fp8_enabled_no_layers_defaults_all(): + """When fp8 is enabled with no explicit layers, all layers should default to FP8.""" + result = resolve_layer_precision( + num_layers=6, fp8_enabled=True, fp4_enabled=False, fp8_layers=None, fp4_layers=None + ) + assert result == ["fp8", "fp8", "fp8", "fp8", "fp8", "fp8"] + + +def test_fp4_enabled_no_layers_defaults_all(): + """When fp4 is enabled with no explicit layers, all layers should default to FP4.""" + result = resolve_layer_precision( + num_layers=6, fp8_enabled=False, fp4_enabled=True, fp8_layers=None, fp4_layers=None + ) + assert result == ["fp4", "fp4", "fp4", "fp4", "fp4", "fp4"] + + +def test_fp8_explicit_layers(): + """Explicit 1-indexed fp8_layers should produce fp8 at those positions.""" + result = resolve_layer_precision( + num_layers=6, fp8_enabled=True, fp4_enabled=False, fp8_layers=[1, 3, 5], fp4_layers=None + ) + assert result == ["fp8", 
None, "fp8", None, "fp8", None] + + +def test_fp4_explicit_layers(): + """Explicit 1-indexed fp4_layers should produce fp4 at those positions.""" + result = resolve_layer_precision( + num_layers=6, fp8_enabled=False, fp4_enabled=True, fp8_layers=None, fp4_layers=[2, 4, 6] + ) + assert result == [None, "fp4", None, "fp4", None, "fp4"] + + +def test_mixed_fp8_fp4_explicit(): + """Both enabled with explicit non-overlapping layers should work correctly.""" + result = resolve_layer_precision( + num_layers=6, fp8_enabled=True, fp4_enabled=True, fp8_layers=[1, 3, 4], fp4_layers=[2, 5] + ) + assert result == ["fp8", "fp4", "fp8", "fp8", "fp4", None] + + +def test_both_enabled_no_layers_raises(): + """Both enabled with no layer lists should raise ValueError.""" + with pytest.raises(ValueError, match="Both fp8_config and fp4_config are enabled"): + resolve_layer_precision(num_layers=6, fp8_enabled=True, fp4_enabled=True, fp8_layers=None, fp4_layers=None) + + +def test_overlapping_layers_raises(): + """Overlapping layer assignments should raise ValueError.""" + with pytest.raises(ValueError, match="fp8_layers and fp4_layers cannot have overlapping"): + resolve_layer_precision( + num_layers=6, fp8_enabled=True, fp4_enabled=True, fp8_layers=[1, 2, 3], fp4_layers=[3, 4, 5] + ) + + +def test_disabled_ignores_layers(): + """When a format is disabled, its layers should be ignored.""" + result = resolve_layer_precision( + num_layers=6, fp8_enabled=False, fp4_enabled=False, fp8_layers=[1, 2, 3], fp4_layers=[4, 5, 6] + ) + assert result == [None, None, None, None, None, None] + + +def test_both_disabled(): + """Both disabled with no layers should return all None.""" + result = resolve_layer_precision( + num_layers=6, fp8_enabled=False, fp4_enabled=False, fp8_layers=None, fp4_layers=None + ) + assert result == [None, None, None, None, None, None] + + +def test_large_model_defaults_all(): + """Auto-population should work correctly for larger models (e.g. 
36 layers).""" + result = resolve_layer_precision( + num_layers=36, fp8_enabled=True, fp4_enabled=False, fp8_layers=None, fp4_layers=None + ) + assert result == ["fp8"] * 36 + + +def test_fp8_enabled_empty_list(): + """An explicit empty list should remain empty (not default to all).""" + result = resolve_layer_precision(num_layers=6, fp8_enabled=True, fp4_enabled=False, fp8_layers=[], fp4_layers=None) + assert result == [None, None, None, None, None, None] + + +def test_both_enabled_fp8_specified_fp4_defaults_to_remaining(): + """When both enabled, FP8 has explicit layers, FP4 should default to the remaining layers.""" + result = resolve_layer_precision( + num_layers=6, fp8_enabled=True, fp4_enabled=True, fp8_layers=[1, 2, 3], fp4_layers=None + ) + assert result == ["fp8", "fp8", "fp8", "fp4", "fp4", "fp4"] + + +def test_both_enabled_fp4_specified_fp8_defaults_to_remaining(): + """When both enabled, FP4 has explicit layers, FP8 should default to the remaining layers.""" + result = resolve_layer_precision( + num_layers=6, fp8_enabled=True, fp4_enabled=True, fp8_layers=None, fp4_layers=[4, 5, 6] + ) + assert result == ["fp8", "fp8", "fp8", "fp4", "fp4", "fp4"] + + +def test_returns_correct_length(): + """Result list length should always equal num_layers.""" + for n in [1, 6, 48]: + result = resolve_layer_precision( + num_layers=n, fp8_enabled=False, fp4_enabled=False, fp8_layers=None, fp4_layers=None + ) + assert len(result) == n + + +# -- generate_layer_regex -- + + +def test_single_layer(): + """Single layer should produce a simple regex.""" + regex = generate_layer_regex([3]) + assert re.search(regex, "model.model.layers.3.self_attention.layernorm_qkv") + assert not re.search(regex, "model.model.layers.2.self_attention.layernorm_qkv") + + +def test_multiple_layers(): + """Multiple layers should match any of them.""" + regex = generate_layer_regex([1, 2, 3]) + assert re.search(regex, "model.model.layers.1.self_attention.layernorm_qkv") + assert re.search(regex, 
"model.model.layers.2.layernorm_mlp.fc1") + assert re.search(regex, "model.model.layers.3.layernorm_mlp.fc2") + assert not re.search(regex, "model.model.layers.4.self_attention.proj") + + +def test_matches_correct_sublayers(): + """Regex should only match layernorm_qkv, proj, fc1, fc2.""" + regex = generate_layer_regex([1]) + assert re.search(regex, "model.model.layers.1.self_attention.layernorm_qkv_something") + assert re.search(regex, "model.model.layers.1.self_attention.proj_something") + assert re.search(regex, "model.model.layers.1.layernorm_mlp.fc1_something") + assert re.search(regex, "model.model.layers.1.layernorm_mlp.fc2_something") + # Should not match unrelated sublayer names + assert not re.search(regex, "model.model.layers.1.self_attention.some_other_thing") + + +def test_none_returns_disabled_pattern(): + """None should return a pattern that matches nothing.""" + regex = generate_layer_regex(None) + assert "DISABLED" in regex + assert not re.search(regex, "model.model.layers.1.self_attention.layernorm_qkv") + + +def test_empty_list_returns_disabled_pattern(): + """Empty list should return a pattern that matches nothing.""" + regex = generate_layer_regex([]) + assert "DISABLED" in regex + + +def test_1indexed_layer_names(): + """Regex should use 1-indexed layer numbers (matching debug API naming).""" + regex = generate_layer_regex([1]) + # Should match layers.1 (1-indexed first layer) + assert re.search(regex, "model.model.layers.1.self_attention.layernorm_qkv") + # Should NOT match layers.0 (0-indexed first layer) + assert not re.search(regex, "model.model.layers.0.self_attention.layernorm_qkv") + + +# -- update_quant_stats_config -- + + +@pytest.fixture +def fp8_only_config(tmp_path): + """Create an FP8-only stats config file.""" + config = { + "example_fp8_tensor_stat_collection": { + "enabled": True, + "layers": { + "layer_name_regex_pattern": "PLACEHOLDER", + }, + "transformer_engine": { + "LogFp8TensorStats": { + "enabled": True, + 
"tensors_struct": [{"tensor": "activation", "stats": ["underflows%"], "freq": 10}], + } + }, + } + } + config_path = tmp_path / "fp8_stats.yaml" + with open(config_path, "w") as f: + yaml.dump(config, f) + return str(config_path) + + +@pytest.fixture +def fp4_fp8_config(tmp_path): + """Create a combined FP4+FP8 stats config file.""" + config = { + "example_fp4_tensor_stat_collection": { + "enabled": True, + "layers": { + "layer_name_regex_pattern": "PLACEHOLDER", + }, + "transformer_engine": { + "LogNvfp4TensorStats": {"enabled": True}, + }, + }, + "example_fp8_tensor_stat_collection": { + "enabled": True, + "layers": { + "layer_name_regex_pattern": "PLACEHOLDER", + }, + "transformer_engine": { + "LogFp8TensorStats": {"enabled": True}, + }, + }, + } + config_path = tmp_path / "fp4_fp8_stats.yaml" + with open(config_path, "w") as f: + yaml.dump(config, f) + return str(config_path) + + +def test_fp8_layers_updates_regex(fp8_only_config): + """FP8 layer list should update the regex in the output config.""" + output_path = update_quant_stats_config(config_file=fp8_only_config, fp4_layers=None, fp8_layers=[1, 2, 3]) + with open(output_path) as f: + result = yaml.safe_load(f) + regex = result["example_fp8_tensor_stat_collection"]["layers"]["layer_name_regex_pattern"] + assert re.search(regex, "model.model.layers.1.self_attention.layernorm_qkv") + assert re.search(regex, "model.model.layers.3.layernorm_mlp.fc2") + assert not re.search(regex, "model.model.layers.4.self_attention.proj") + + +def test_none_layers_disables_matching(fp8_only_config): + """None layers should set regex to match nothing.""" + output_path = update_quant_stats_config(config_file=fp8_only_config, fp4_layers=None, fp8_layers=None) + with open(output_path) as f: + result = yaml.safe_load(f) + regex = result["example_fp8_tensor_stat_collection"]["layers"]["layer_name_regex_pattern"] + assert "DISABLED" in regex + + +def test_fp4_section_disabled_fp8_still_updated(fp4_fp8_config): + """FP4 stats section 
should be disabled (not yet supported), FP8 should still be updated.""" + output_path = update_quant_stats_config(config_file=fp4_fp8_config, fp4_layers=[1, 2, 3], fp8_layers=[4, 5, 6]) + with open(output_path) as f: + result = yaml.safe_load(f) + + # FP4 section should be disabled + assert result["example_fp4_tensor_stat_collection"]["enabled"] is False + + # FP8 regex should still match layers 4-6 + fp8_regex = result["example_fp8_tensor_stat_collection"]["layers"]["layer_name_regex_pattern"] + assert re.search(fp8_regex, "model.model.layers.5.self_attention.proj") + assert not re.search(fp8_regex, "model.model.layers.2.self_attention.proj") + + +def test_original_file_not_modified(fp8_only_config): + """update_quant_stats_config should write to a temp file, not modify the original.""" + with open(fp8_only_config) as f: + original_content = f.read() + + output_path = update_quant_stats_config(config_file=fp8_only_config, fp4_layers=None, fp8_layers=[1, 2]) + + assert output_path != fp8_only_config + with open(fp8_only_config) as f: + assert f.read() == original_content + + +def test_preserves_other_config_fields(fp8_only_config): + """Non-layer fields in the config should be preserved.""" + output_path = update_quant_stats_config(config_file=fp8_only_config, fp4_layers=None, fp8_layers=[1]) + with open(output_path) as f: + result = yaml.safe_load(f) + # The transformer_engine section should still be there + assert result["example_fp8_tensor_stat_collection"]["transformer_engine"]["LogFp8TensorStats"]["enabled"] is True + + +def test_missing_section_is_skipped(fp8_only_config): + """If fp4 section doesn't exist in config, it should be silently skipped.""" + # fp8_only_config has no fp4 section -- passing fp4_layers should not error + output_path = update_quant_stats_config(config_file=fp8_only_config, fp4_layers=[1, 2], fp8_layers=[3, 4]) + with open(output_path) as f: + result = yaml.safe_load(f) + # Only FP8 section should exist and be updated + assert 
"example_fp4_tensor_stat_collection" not in result + regex = result["example_fp8_tensor_stat_collection"]["layers"]["layer_name_regex_pattern"] + assert re.search(regex, "model.model.layers.3.self_attention.layernorm_qkv") + + +def test_with_real_fp4_config(): + """Test with the actual fp4_debugging_stats.yaml file.""" + config_path = Path(__file__).parent.parent / "fp4_debugging_stats.yaml" + if not config_path.exists(): + pytest.skip("fp4_debugging_stats.yaml not found") + + output_path = update_quant_stats_config(config_file=str(config_path), fp4_layers=[1, 2, 3], fp8_layers=[4, 5, 6]) + with open(output_path) as f: + result = yaml.safe_load(f) + + # FP4 section should be disabled (not yet supported in current TE release) + assert result["example_fp4_tensor_stat_collection"]["enabled"] is False + + # FP8 section should still be updated and working + fp8_regex = result["example_fp8_tensor_stat_collection"]["layers"]["layer_name_regex_pattern"] + assert re.search(fp8_regex, "model.model.layers.5.self_attention.proj") + assert not re.search(fp8_regex, "model.model.layers.2.self_attention.proj") diff --git a/bionemo-recipes/recipes/llama3_native_te/tests/test_train.py b/bionemo-recipes/recipes/llama3_native_te/tests/test_train.py index 89e85068de..be8bd48fe3 100644 --- a/bionemo-recipes/recipes/llama3_native_te/tests/test_train.py +++ b/bionemo-recipes/recipes/llama3_native_te/tests/test_train.py @@ -452,8 +452,8 @@ def test_sanity_ddp_fp8_stats_logging(tmp_path, recipe_path): f"checkpoint.ckpt_dir={tmp_path}", "+dataset.pad_sequences_to_be_divisible_by=16", "fp8_config.enabled=true", - "fp8_stats_config.enabled=true", - f"fp8_stats_config.fp8_log_dir={fp8_log_dir}", + "quant_stats_config.enabled=true", + f"quant_stats_config.quant_log_dir={fp8_log_dir}", "num_train_steps=4", ], ) @@ -493,8 +493,8 @@ def test_sanity_fsdp2_fp8_stats_logging(tmp_path, recipe_path): f"checkpoint.ckpt_dir={tmp_path}", "fp8_config.enabled=true", 
"+dataset.pad_sequences_to_be_divisible_by=16", - "fp8_stats_config.enabled=true", - f"fp8_stats_config.fp8_log_dir={fp8_log_dir}", + "quant_stats_config.enabled=true", + f"quant_stats_config.quant_log_dir={fp8_log_dir}", "num_train_steps=4", ], ) @@ -507,6 +507,65 @@ def test_sanity_fsdp2_fp8_stats_logging(tmp_path, recipe_path): assert (fp8_log_dir / "rank_0" / "nvdlfw_inspect_statistics_logs" / "nvdlfw_inspect_globalrank-0.log").exists() +@requires_fp8 +def test_sanity_ddp_fp8_partial_layers_stats_logging(tmp_path, recipe_path): + """Test DDP training with layer-wise FP8 stats (layers 1-3 only).""" + quant_log_dir = tmp_path / "quant_stats_logs" + + with initialize_config_dir(config_dir=str(recipe_path / "hydra_config"), version_base="1.2"): + sanity_config = compose( + config_name="L0_sanity", + overrides=[ + f"+wandb_init_args.dir={tmp_path}", + f"checkpoint.ckpt_dir={tmp_path}", + "+dataset.pad_sequences_to_be_divisible_by=16", + "fp8_config.enabled=true", + "fp8_layers=[1,2,3]", + "quant_stats_config.enabled=true", + f"quant_stats_config.quant_log_dir={quant_log_dir}", + "num_train_steps=4", + ], + ) + + main_ddp(sanity_config) + + # Verify the log directory structure was created + assert quant_log_dir.exists(), "Quant log directory was not created" + assert (quant_log_dir / "rank_0").exists(), "rank_0 directory was not created" + assert (quant_log_dir / "rank_0" / "nvdlfw_inspect_logs").exists(), "nvdlfw_inspect_logs directory was not created" + assert (quant_log_dir / "rank_0" / "nvdlfw_inspect_statistics_logs").exists(), ( + "nvdlfw_inspect_statistics_logs directory was not created" + ) + + +@requires_fp8 +def test_sanity_fsdp2_fp8_partial_layers_stats_logging(tmp_path, recipe_path): + """Test FSDP2 training with layer-wise FP8 stats (layers 1-3 only).""" + quant_log_dir = tmp_path / "quant_stats_logs" + + with initialize_config_dir(config_dir=str(recipe_path / "hydra_config"), version_base="1.2"): + sanity_config = compose( + config_name="L0_sanity", + 
overrides=[ + f"+wandb_init_args.dir={tmp_path}", + f"checkpoint.ckpt_dir={tmp_path}", + "+dataset.pad_sequences_to_be_divisible_by=16", + "fp8_config.enabled=true", + "fp8_layers=[1,2,3]", + "quant_stats_config.enabled=true", + f"quant_stats_config.quant_log_dir={quant_log_dir}", + "num_train_steps=4", + ], + ) + + main_fsdp2(sanity_config) + + # Verify log structure + assert quant_log_dir.exists() + assert (quant_log_dir / "rank_0" / "nvdlfw_inspect_logs" / "nvdlfw_inspect_globalrank-0.log").exists() + assert (quant_log_dir / "rank_0" / "nvdlfw_inspect_statistics_logs" / "nvdlfw_inspect_globalrank-0.log").exists() + + def run_train_cmd(cmd, recipe_path): """Run a training command and check for errors. diff --git a/bionemo-recipes/recipes/llama3_native_te/train_ddp.py b/bionemo-recipes/recipes/llama3_native_te/train_ddp.py index 0a25c02940..4ae9c751aa 100644 --- a/bionemo-recipes/recipes/llama3_native_te/train_ddp.py +++ b/bionemo-recipes/recipes/llama3_native_te/train_ddp.py @@ -43,9 +43,9 @@ from checkpoint import load_checkpoint_ddp, save_checkpoint_ddp, save_final_model_ddp, should_save_checkpoint from dataset import create_bshd_dataloader, create_thd_dataloader from distributed_config import DistributedConfig -from fp8_debugging import initialize_fp8_debugging from modeling_llama_te import NVLlamaConfig, NVLlamaForCausalLM from perf_logger import PerfLogger +from quantization import initialize_quant_stats_logging, resolve_layer_precision from scheduler import get_cosine_annealing_schedule_with_warmup @@ -67,18 +67,12 @@ def main(args: DictConfig) -> float | None: torch.distributed.init_process_group(backend="nccl", device_id=device) torch.cuda.set_device(dist_config.local_rank) - # TE Debug feature logging - if args.fp8_stats_config.enabled: - initialize_fp8_debugging(dist_config, **args.fp8_stats_config, fp8_enabled=args.fp8_config.enabled) + if args.use_fp32_master_weights: + raise ValueError("FP32 master weights are not supported with DDP. 
Use train_fsdp2.py instead.") # Create a device mesh for DDP. While this isn't strictly necessary, it mirrors the device mesh we create for FSDP2. device_mesh = init_device_mesh("cuda", mesh_shape=(dist_config.world_size,), mesh_dim_names=("dp",)) - # --- Model Configuration --- - fp8_recipe = hydra.utils.get_class(args.fp8_config.fp8_recipe)( - fp8_format=Format[args.fp8_config.fp8_format], **args.fp8_config.fp8_recipe_kwargs - ) - if args.use_te: config_class = NVLlamaConfig model_class = NVLlamaForCausalLM @@ -86,9 +80,40 @@ def main(args: DictConfig) -> float | None: config_class = LlamaConfig model_class = LlamaForCausalLM - # --- Model Initialization --- + # --- Model Configuration --- config = config_class.from_pretrained(args.config_name_or_path, dtype=torch.bfloat16, **args.config_kwargs) + # Resolve layer-wise quantization assignments and store on config. + layer_precision = resolve_layer_precision( + num_layers=config.num_hidden_layers, + fp8_enabled=args.fp8_config.enabled, + fp4_enabled=args.fp4_config.enabled, + fp8_layers=OmegaConf.to_container(args.fp8_layers, resolve=True) if args.fp8_layers is not None else None, + fp4_layers=OmegaConf.to_container(args.fp4_layers, resolve=True) if args.fp4_layers is not None else None, + ) + config.layer_precision = layer_precision + + if args.quant_stats_config.enabled: + initialize_quant_stats_logging( + quant_stats_file=args.quant_stats_config.quant_stats_file, + quant_log_dir=args.quant_stats_config.quant_log_dir, + rank=dist_config.rank, + layer_precision=layer_precision, + ) + + # Create quantization recipes -- these are only used if FP8/FP4 is enabled in the config. 
+ fp8_recipe = None + fp4_recipe = None + if args.fp8_config.enabled: + fp8_recipe = hydra.utils.get_class(args.fp8_config.fp8_recipe)( + fp8_format=Format[args.fp8_config.fp8_format], **args.fp8_config.fp8_recipe_kwargs + ) + if args.fp4_config.enabled: + fp4_recipe = hydra.utils.get_class(args.fp4_config.fp4_recipe)( + fp4_format=Format[args.fp4_config.fp4_format], **args.fp4_config.fp4_recipe_kwargs + ) + + # --- Model Initialization --- # Optionally use transformer engine to initialize only fp8 versions of weights by setting # `fp8_config.quantized_model_init_kwargs.enabled` to `True`, as opposed to using the default where both bfloat16 # and fp8 versions of weights are kept. @@ -99,8 +124,12 @@ def main(args: DictConfig) -> float | None: logger.info("Initialized Model:\n%s", model) + # Attach quantization recipes to the model (layer precision is already on config). + if isinstance(model, NVLlamaForCausalLM): + model.model.set_recipes(fp8_recipe=fp8_recipe, fp4_recipe=fp4_recipe) + # --- Distributed Wrapping (DDP) --- - if args.fp8_stats_config.enabled: + if args.quant_stats_config.enabled: debug_api.infer_and_assign_layer_names(model) model = model.to(device=device) @@ -161,9 +190,8 @@ def main(args: DictConfig) -> float | None: micro_step += 1 # DDP requires no_sync to skip all-reduce until the last microbatch in the accumulation window. with model.no_sync() if micro_step % args.grad_acc_steps != 0 else nullcontext(): - # Forward pass with mixed precision. - with transformer_engine.pytorch.autocast(enabled=args.fp8_config.enabled, recipe=fp8_recipe): - outputs = model(**batch) + # Forward pass - quantization autocast is handled inside the model via set_recipes(). 
+ outputs = model(**batch) # Backward pass - scale loss by grad_acc_steps for proper gradient averaging loss = outputs.loss / args.grad_acc_steps diff --git a/bionemo-recipes/recipes/llama3_native_te/train_fsdp2.py b/bionemo-recipes/recipes/llama3_native_te/train_fsdp2.py index 4d88f2e0c0..3588c207f1 100644 --- a/bionemo-recipes/recipes/llama3_native_te/train_fsdp2.py +++ b/bionemo-recipes/recipes/llama3_native_te/train_fsdp2.py @@ -34,7 +34,7 @@ import transformer_engine.pytorch from omegaconf import DictConfig, OmegaConf from torch.distributed.device_mesh import init_device_mesh -from torch.distributed.fsdp import fully_shard +from torch.distributed.fsdp import MixedPrecisionPolicy, fully_shard from torch.optim import AdamW from transformer_engine.common.recipe import Format from transformers.models.llama.configuration_llama import LlamaConfig @@ -49,9 +49,9 @@ ) from dataset import create_bshd_dataloader, create_thd_dataloader from distributed_config import DistributedConfig -from fp8_debugging import initialize_fp8_debugging from modeling_llama_te import NVLlamaConfig, NVLlamaForCausalLM from perf_logger import PerfLogger +from quantization import initialize_quant_stats_logging, resolve_layer_precision from scheduler import get_cosine_annealing_schedule_with_warmup @@ -73,17 +73,8 @@ def main(args: DictConfig) -> float | None: torch.distributed.init_process_group(backend="cpu:gloo,cuda:nccl", device_id=device) torch.cuda.set_device(dist_config.local_rank) - # TE Debug feature logging - MUST be done BEFORE FSDP wrapping - if args.fp8_stats_config.enabled: - initialize_fp8_debugging(dist_config, **args.fp8_stats_config, fp8_enabled=args.fp8_config.enabled) - device_mesh = init_device_mesh("cuda", mesh_shape=(dist_config.world_size,), mesh_dim_names=("dp",)) - # --- Model Configuration --- - fp8_recipe = hydra.utils.get_class(args.fp8_config.fp8_recipe)( - fp8_format=Format[args.fp8_config.fp8_format], **args.fp8_config.fp8_recipe_kwargs - ) - if args.use_te: 
config_class = NVLlamaConfig model_class = NVLlamaForCausalLM @@ -91,9 +82,44 @@ def main(args: DictConfig) -> float | None: config_class = LlamaConfig model_class = LlamaForCausalLM - # --- Model Initialization --- - config = config_class.from_pretrained(args.config_name_or_path, dtype=torch.bfloat16, **args.config_kwargs) + # --- Model Configuration --- + config = config_class.from_pretrained( + args.config_name_or_path, + dtype=torch.float32 if args.use_fp32_master_weights else torch.bfloat16, + **args.config_kwargs, + ) + + # Resolve layer-wise quantization assignments and store on config. + layer_precision = resolve_layer_precision( + num_layers=config.num_hidden_layers, + fp8_enabled=args.fp8_config.enabled, + fp4_enabled=args.fp4_config.enabled, + fp8_layers=OmegaConf.to_container(args.fp8_layers, resolve=True) if args.fp8_layers is not None else None, + fp4_layers=OmegaConf.to_container(args.fp4_layers, resolve=True) if args.fp4_layers is not None else None, + ) + config.layer_precision = layer_precision + + if args.quant_stats_config.enabled: + initialize_quant_stats_logging( + quant_stats_file=args.quant_stats_config.quant_stats_file, + quant_log_dir=args.quant_stats_config.quant_log_dir, + rank=dist_config.rank, + layer_precision=layer_precision, + ) + # Create quantization recipes -- these are only used if FP8/FP4 is enabled in the config. 
+ fp8_recipe = None + fp4_recipe = None + if args.fp8_config.enabled: + fp8_recipe = hydra.utils.get_class(args.fp8_config.fp8_recipe)( + fp8_format=Format[args.fp8_config.fp8_format], **args.fp8_config.fp8_recipe_kwargs + ) + if args.fp4_config.enabled: + fp4_recipe = hydra.utils.get_class(args.fp4_config.fp4_recipe)( + fp4_format=Format[args.fp4_config.fp4_format], **args.fp4_config.fp4_recipe_kwargs + ) + + # --- Model Initialization --- # Optionally use transformer engine to initialize only fp8 versions of weights by setting # `fp8_config.quantized_model_init_kwargs.enabled` to `True`, as opposed to using the default where both bfloat16 # and fp8 versions of weights are kept. @@ -108,10 +134,24 @@ def main(args: DictConfig) -> float | None: logger.info("Initialized Model:\n%s", model) # --- Distributed Wrapping (FSDP2) --- + if args.use_fp32_master_weights: + mp_policy = MixedPrecisionPolicy( + param_dtype=torch.bfloat16, + reduce_dtype=torch.float32, + output_dtype=torch.bfloat16, + cast_forward_inputs=False, + ) + else: + mp_policy = MixedPrecisionPolicy() + # Each decoder layer should be individually sharded before sharding the full model. for layer in model.model.layers: - fully_shard(layer, mesh=device_mesh["dp"]) - fully_shard(model, mesh=device_mesh["dp"]) + fully_shard(layer, mesh=device_mesh["dp"], mp_policy=mp_policy) + fully_shard(model, mesh=device_mesh["dp"], mp_policy=mp_policy) + + # Attach quantization recipes to the model (layer precision is already on config). + if isinstance(model, NVLlamaForCausalLM): + model.model.set_recipes(fp8_recipe=fp8_recipe, fp4_recipe=fp4_recipe) # If we're using meta device, we need to move sharded weights to the cuda device and initialize the parameters. 
if args.use_meta_device and isinstance(model, NVLlamaForCausalLM): @@ -123,7 +163,7 @@ def main(args: DictConfig) -> float | None: model.apply(model._init_weights) # Assign names to layers so debug API can identify them - if args.fp8_stats_config.enabled: + if args.quant_stats_config.enabled: debug_api.infer_and_assign_layer_names(model) # --- Optimizer & Scheduler --- @@ -175,9 +215,8 @@ def main(args: DictConfig) -> float | None: micro_step += 1 - # Forward pass with mixed precision. - with transformer_engine.pytorch.autocast(enabled=args.fp8_config.enabled, recipe=fp8_recipe): - outputs = model(**batch) + # Forward pass - quantization autocast is handled inside the model via set_recipes(). + outputs = model(**batch) # Backward pass - scale loss by grad_acc_steps for proper gradient averaging loss = outputs.loss / args.grad_acc_steps diff --git a/bionemo-recipes/recipes/llama3_native_te/train_fsdp2_cp.py b/bionemo-recipes/recipes/llama3_native_te/train_fsdp2_cp.py index 06fb6630ba..d67a0b05cb 100644 --- a/bionemo-recipes/recipes/llama3_native_te/train_fsdp2_cp.py +++ b/bionemo-recipes/recipes/llama3_native_te/train_fsdp2_cp.py @@ -29,12 +29,13 @@ from pathlib import Path import hydra +import nvdlfw_inspect.api as debug_api import nvtx import torch import transformer_engine.pytorch from omegaconf import DictConfig, OmegaConf from torch.distributed.device_mesh import init_device_mesh -from torch.distributed.fsdp import fully_shard +from torch.distributed.fsdp import MixedPrecisionPolicy, fully_shard from torch.optim import AdamW from transformer_engine.common.recipe import Format @@ -50,6 +51,7 @@ from distributed_config import DistributedConfig from modeling_llama_te import NVLlamaConfig, NVLlamaForCausalLM from perf_logger import PerfLogger +from quantization import initialize_quant_stats_logging, resolve_layer_precision from scheduler import get_cosine_annealing_schedule_with_warmup @@ -79,13 +81,43 @@ def main(args: DictConfig) -> float | None: 
logger.info("Created device mesh: %s", device_mesh) # --- Model Configuration --- - fp8_recipe = hydra.utils.get_class(args.fp8_config.fp8_recipe)( - fp8_format=Format[args.fp8_config.fp8_format], **args.fp8_config.fp8_recipe_kwargs + config = NVLlamaConfig.from_pretrained( + args.config_name_or_path, + dtype=torch.float32 if args.use_fp32_master_weights else torch.bfloat16, + **args.config_kwargs, ) - # --- Model Initialization --- - config = NVLlamaConfig.from_pretrained(args.config_name_or_path, dtype=torch.bfloat16, **args.config_kwargs) + # Resolve layer-wise quantization assignments and store on config. + layer_precision = resolve_layer_precision( + num_layers=config.num_hidden_layers, + fp8_enabled=args.fp8_config.enabled, + fp4_enabled=args.fp4_config.enabled, + fp8_layers=OmegaConf.to_container(args.fp8_layers, resolve=True) if args.fp8_layers is not None else None, + fp4_layers=OmegaConf.to_container(args.fp4_layers, resolve=True) if args.fp4_layers is not None else None, + ) + config.layer_precision = layer_precision + + if args.quant_stats_config.enabled: + initialize_quant_stats_logging( + quant_stats_file=args.quant_stats_config.quant_stats_file, + quant_log_dir=args.quant_stats_config.quant_log_dir, + rank=dist_config.rank, + layer_precision=layer_precision, + ) + + # Create quantization recipes -- these are only used if FP8/FP4 is enabled in the config. 
+ fp8_recipe = None + fp4_recipe = None + if args.fp8_config.enabled: + fp8_recipe = hydra.utils.get_class(args.fp8_config.fp8_recipe)( + fp8_format=Format[args.fp8_config.fp8_format], **args.fp8_config.fp8_recipe_kwargs + ) + if args.fp4_config.enabled: + fp4_recipe = hydra.utils.get_class(args.fp4_config.fp4_recipe)( + fp4_format=Format[args.fp4_config.fp4_format], **args.fp4_config.fp4_recipe_kwargs + ) + # --- Model Initialization --- # Optionally use transformer engine to initialize only fp8 versions of weights by setting # `fp8_config.quantized_model_init_kwargs.enabled` to `True`, as opposed to using the default where both bfloat16 # and fp8 versions of weights are kept. @@ -102,11 +134,21 @@ def main(args: DictConfig) -> float | None: # --- Distributed Wrapping (FSDP2 + CP) --- cp_dp_mesh = device_mesh["dp", "cp"]._flatten(mesh_dim_name="dp_shard_cp") + if args.use_fp32_master_weights: + mp_policy = MixedPrecisionPolicy( + param_dtype=torch.bfloat16, + reduce_dtype=torch.float32, + output_dtype=torch.bfloat16, + cast_forward_inputs=False, + ) + else: + mp_policy = MixedPrecisionPolicy() + # Shard the transformer layers with FSDP. For Llama3, the transformer stack is in model.model.layers. # Each decoder layer should be individually sharded before sharding the full model. for layer in model.model.layers: - fully_shard(layer, mesh=cp_dp_mesh) - fully_shard(model, mesh=cp_dp_mesh) + fully_shard(layer, mesh=cp_dp_mesh, mp_policy=mp_policy) + fully_shard(model, mesh=cp_dp_mesh, mp_policy=mp_policy) # Attach the CP group to the model. for layer in model.model.layers: @@ -116,10 +158,17 @@ def main(args: DictConfig) -> float | None: torch.cuda.Stream(), ) + # Attach quantization recipes to the model (layer precision is already on config). + model.model.set_recipes(fp8_recipe=fp8_recipe, fp4_recipe=fp4_recipe) + if args.use_meta_device: # TE layers require special handling to initialize the weights from the meta device. 
model.init_empty_weights() + # Assign names to layers so debug API can identify them + if args.quant_stats_config.enabled: + debug_api.infer_and_assign_layer_names(model) + # --- Optimizer & Scheduler --- # Convert OmegaConf to regular dict to avoid serialization issues (BIONEMO-2873). optimizer = AdamW(model.parameters(), **OmegaConf.to_container(args.adamw_kwargs, resolve=True)) # type: ignore @@ -193,10 +242,9 @@ def main(args: DictConfig) -> float | None: micro_step += 1 - # Forward pass with mixed precision. + # Forward pass - quantization autocast is handled inside the model via set_recipes(). with nvtx.annotate("Forward pass", color="green"): - with transformer_engine.pytorch.autocast(enabled=args.fp8_config.enabled, recipe=fp8_recipe): - outputs = model(**batch) + outputs = model(**batch) # Backward pass - scale loss by grad_acc_steps for proper gradient averaging loss = outputs.loss / args.grad_acc_steps diff --git a/docs/docs/assets/images/recipes/lingua-1b-loss-curve.png b/docs/docs/assets/images/llama3/lingua-1b-loss-curve.png similarity index 100% rename from docs/docs/assets/images/recipes/lingua-1b-loss-curve.png rename to docs/docs/assets/images/llama3/lingua-1b-loss-curve.png diff --git a/docs/docs/assets/images/recipes/lingua-1b-step-time.png b/docs/docs/assets/images/llama3/lingua-1b-step-time.png similarity index 100% rename from docs/docs/assets/images/recipes/lingua-1b-step-time.png rename to docs/docs/assets/images/llama3/lingua-1b-step-time.png diff --git a/docs/docs/assets/images/llama3/llama3_1b_fsdp2_tflops.png b/docs/docs/assets/images/llama3/llama3_1b_fsdp2_tflops.png new file mode 100644 index 0000000000..daed0ce0ba Binary files /dev/null and b/docs/docs/assets/images/llama3/llama3_1b_fsdp2_tflops.png differ diff --git a/docs/docs/assets/images/llama3/llama3_8gpu_tflops.png b/docs/docs/assets/images/llama3/llama3_8gpu_tflops.png new file mode 100644 index 0000000000..a304fa29b7 Binary files /dev/null and 
b/docs/docs/assets/images/llama3/llama3_8gpu_tflops.png differ