diff --git a/bionemo-recipes/models/llama3/modeling_llama_te.py b/bionemo-recipes/models/llama3/modeling_llama_te.py
index 033eb5ebe3..4e859b4868 100644
--- a/bionemo-recipes/models/llama3/modeling_llama_te.py
+++ b/bionemo-recipes/models/llama3/modeling_llama_te.py
@@ -17,10 +17,12 @@
import warnings
from collections import OrderedDict
+from contextlib import nullcontext
from typing import ClassVar, Unpack
import torch
import torch.nn as nn
+import transformer_engine.common.recipe
import transformer_engine.pytorch
import transformers
from transformer_engine.pytorch.attention import InferenceParams
@@ -50,6 +52,7 @@ class NVLlamaConfig(LlamaConfig):
# "thd" = Total tokens (packed/unpadded), Head, Dimension (sequence packing format)
attn_input_format: str = "thd"
self_attn_mask_type: str = "padding_causal"
+ layer_precision: list[str | None] | None = None
class NVLlamaPreTrainedModel(PreTrainedModel):
@@ -159,11 +162,54 @@ def _init_method(x):
self.rotary_emb = RotaryPositionEmbedding(config.hidden_size // config.num_attention_heads)
self.rotary_emb.inv_freq = LlamaRotaryEmbedding(config=config).inv_freq
+ self._fp8_recipe: transformer_engine.common.recipe.Recipe | None = None
+ self._fp4_recipe: transformer_engine.common.recipe.Recipe | None = None
+
self.gradient_checkpointing = False
# Initialize weights and apply final processing
self.post_init()
+ def set_recipes(
+ self,
+ fp8_recipe: transformer_engine.common.recipe.Recipe | None = None,
+ fp4_recipe: transformer_engine.common.recipe.Recipe | None = None,
+ ) -> None:
+ """Attach quantization recipe objects for per-layer autocast.
+
+ Recipes are not serializable and must be set at runtime after model creation
+ and sharding (FSDP/DDP) but before training. The per-layer precision
+ assignments are read from ``self.config.layer_precision``.
+
+ Args:
+ fp8_recipe: The FP8 recipe instance (e.g., MXFP8BlockScaling), or None.
+ fp4_recipe: The FP4 recipe instance (e.g., NVFP4BlockScaling), or None.
+ """
+ self._fp8_recipe = fp8_recipe
+ self._fp4_recipe = fp4_recipe
+
+ def get_layer_autocast(self, layer_number: int):
+ """Return the appropriate TE autocast context manager for a given layer.
+
+ The context interacts with the outer FP8 autocast in the training script:
+ - FP8 layer: nullcontext() -- lets the outer FP8 autocast take effect.
+ - FP4 layer: te.pytorch.autocast(enabled=True, recipe=fp4_recipe) -- overrides to FP4.
+ - BF16 layer: te.pytorch.autocast(enabled=False) -- disables quantized compute.
+
+ Args:
+ layer_number: The 0-indexed layer number.
+
+ Returns:
+ A context manager for the layer's quantization mode.
+ """
+ precision = self.config.layer_precision[layer_number] if self.config.layer_precision is not None else None
+ if precision == "fp8":
+ return nullcontext()
+ elif precision == "fp4":
+ return transformer_engine.pytorch.autocast(enabled=True, recipe=self._fp4_recipe)
+ else:
+ return transformer_engine.pytorch.autocast(enabled=False)
+
def forward(
self,
input_ids: torch.Tensor | None = None,
@@ -240,23 +286,27 @@ def forward(
if te_rope_emb.dtype == torch.float32:
warnings.warn("Rotary embeddings should be in float32 for optimal performance.", UserWarning)
- for decoder_layer in self.layers[: self.config.num_hidden_layers]:
- if output_hidden_states:
- all_hidden_states = (*all_hidden_states, hidden_states)
-
- hidden_states = decoder_layer(
- hidden_states,
- attention_mask=None if self.config.attn_input_format == "thd" else attention_mask,
- rotary_pos_emb=te_rope_emb,
- inference_params=past_key_values,
- cu_seqlens_q=kwargs.get("cu_seq_lens_q", None),
- cu_seqlens_kv=kwargs.get("cu_seq_lens_k", None),
- cu_seqlens_q_padded=kwargs.get("cu_seq_lens_q_padded", None),
- cu_seqlens_kv_padded=kwargs.get("cu_seq_lens_k_padded", None),
- max_seqlen_q=kwargs.get("max_length_q", None),
- max_seqlen_kv=kwargs.get("max_length_k", None),
- pad_between_seqs=kwargs.get("pad_between_seqs", None),
- )
+ # Outer FP8 autocast enables FP8 compute for the decoder stack. Per-layer overrides (FP4, BF16) are handled
+ # by get_layer_autocast(), which nests inside this context.
+ with transformer_engine.pytorch.autocast(enabled=self._fp8_recipe is not None, recipe=self._fp8_recipe):
+ for layer_number, decoder_layer in enumerate(self.layers[: self.config.num_hidden_layers]):
+ if output_hidden_states:
+ all_hidden_states = (*all_hidden_states, hidden_states)
+
+ with self.get_layer_autocast(layer_number):
+ hidden_states = decoder_layer(
+ hidden_states,
+ attention_mask=None if self.config.attn_input_format == "thd" else attention_mask,
+ rotary_pos_emb=te_rope_emb,
+ inference_params=past_key_values,
+ cu_seqlens_q=kwargs.get("cu_seq_lens_q", None),
+ cu_seqlens_kv=kwargs.get("cu_seq_lens_k", None),
+ cu_seqlens_q_padded=kwargs.get("cu_seq_lens_q_padded", None),
+ cu_seqlens_kv_padded=kwargs.get("cu_seq_lens_k_padded", None),
+ max_seqlen_q=kwargs.get("max_length_q", None),
+ max_seqlen_kv=kwargs.get("max_length_k", None),
+ pad_between_seqs=kwargs.get("pad_between_seqs", None),
+ )
hidden_states = self.norm(hidden_states)
diff --git a/bionemo-recipes/models/llama3/tests/test_distributed_fp8.py b/bionemo-recipes/models/llama3/tests/test_distributed_fp8.py
new file mode 100644
index 0000000000..eb93415d50
--- /dev/null
+++ b/bionemo-recipes/models/llama3/tests/test_distributed_fp8.py
@@ -0,0 +1,243 @@
+# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-Apache2
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import pickle
+import subprocess
+
+import pytest
+import torch
+from transformer_engine.pytorch.fp8 import check_fp8_support
+
+
+def requires_fp8(func):
+ """Decorator to skip tests that require FP8 support."""
+ fp8_available, reason = check_fp8_support()
+ return pytest.mark.skipif(not fp8_available, reason=f"FP8 is not supported on this GPU: {reason}")(func)
+
+
+requires_multi_gpu = pytest.mark.skipif(
+ not torch.cuda.is_available() or torch.cuda.device_count() < 2,
+ reason="Test requires at least 2 GPUs",
+)
+
+
+@pytest.mark.parametrize("strategy", ["ddp", "fsdp2"])
+@requires_fp8
+def test_single_process_attaches_correct_fp8_recipe(strategy, unused_tcp_port):
+ cmd = [
+ "torchrun",
+ "--nproc_per_node=1",
+ "--rdzv-backend=c10d",
+ f"--rdzv-endpoint=localhost:{unused_tcp_port}",
+ os.path.relpath(__file__),
+ "--strategy",
+ strategy,
+ ]
+
+ result = subprocess.run(
+ cmd,
+ check=False,
+ text=True,
+ stdout=subprocess.PIPE,
+ stderr=subprocess.PIPE,
+ timeout=240,
+ )
+ if result.returncode != 0:
+ print(f"STDOUT:\n{result.stdout}")
+ print(f"STDERR:\n{result.stderr}")
+ pytest.fail(f"Command failed with exit code {result.returncode}")
+
+
+@pytest.mark.parametrize("strategy", ["ddp", "fsdp2"])
+@requires_fp8
+@requires_multi_gpu
+def test_multi_process_fp8_recipes_are_synced(strategy, unused_tcp_port):
+ cmd = [
+ "torchrun",
+ "--nproc_per_node=2",
+ "--rdzv-backend=c10d",
+ f"--rdzv-endpoint=localhost:{unused_tcp_port}",
+ os.path.relpath(__file__),
+ "--strategy",
+ strategy,
+ ]
+
+ result = subprocess.run(
+ cmd,
+ check=False,
+ text=True,
+ stdout=subprocess.PIPE,
+ stderr=subprocess.PIPE,
+ timeout=240,
+ )
+ if result.returncode != 0:
+ print(f"STDOUT:\n{result.stdout}")
+ print(f"STDERR:\n{result.stderr}")
+ pytest.fail(f"Command failed with exit code {result.returncode}")
+
+
+if __name__ == "__main__":
+ import argparse
+ import enum
+ import os
+ import sys
+ from dataclasses import dataclass, field
+ from pathlib import Path
+
+ # Ensure the model directory is on sys.path for bare module imports.
+ sys.path.insert(0, Path(__file__).resolve().parent.parent.as_posix())
+
+ import torch.distributed as dist
+ from torch.distributed.device_mesh import init_device_mesh
+ from torch.distributed.fsdp import fully_shard
+ from torch.optim import AdamW
+ from transformer_engine.pytorch.fp8 import DelayedScaling, Format
+
+ from modeling_llama_te import NVLlamaConfig, NVLlamaForCausalLM
+
+ def recursive_assert(a, b, path=""):
+ if isinstance(a, dict) and isinstance(b, dict):
+ assert a.keys() == b.keys(), f"Dictionary keys mismatch: {a.keys()} != {b.keys()} at {path}"
+ for k in a:
+ recursive_assert(a[k], b[k], path=f"{path}.{k}")
+ elif isinstance(a, list) and isinstance(b, list):
+ assert len(a) == len(b), f"List lengths mismatch: {len(a)} != {len(b)} at {path}"
+ for i in range(len(a)):
+ recursive_assert(a[i], b[i], path=f"{path}.{i}")
+ elif isinstance(a, torch.Tensor) and isinstance(b, torch.Tensor):
+ torch.testing.assert_close(a, b, msg=f"Tensor mismatch at {path}")
+ else:
+ assert a == b, f"Value mismatch at {path}: {a} != {b}"
+
+ class Strategy(enum.StrEnum):
+ DDP = "ddp"
+ FSDP2 = "fsdp2"
+
+ @dataclass
+ class DistributedConfig:
+ """Class to track distributed ranks."""
+
+ rank: int = field(default_factory=dist.get_rank)
+ local_rank: int = field(default_factory=lambda: int(os.environ["LOCAL_RANK"]))
+ world_size: int = field(default_factory=dist.get_world_size)
+
+ def is_main_process(self) -> bool:
+ """This is the global rank 0 process, to be used for wandb logging, etc."""
+ return self.rank == 0
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--strategy", type=Strategy, default=Strategy.DDP, choices=[Strategy.FSDP2, Strategy.DDP])
+ args = parser.parse_args()
+
+ torch.distributed.init_process_group(backend="nccl")
+ dist_config = DistributedConfig()
+ torch.cuda.set_device(dist_config.local_rank)
+ device_mesh = init_device_mesh(
+ "cuda",
+ mesh_shape=(dist_config.world_size, 1),
+ mesh_dim_names=("dp", "tp"),
+ )
+ device = f"cuda:{dist_config.local_rank}"
+
+ fp8_recipe = DelayedScaling(fp8_format=Format.HYBRID, amax_compute_algo="max", amax_history_len=10)
+
+ config = NVLlamaConfig(
+ hidden_size=256,
+ intermediate_size=512,
+ num_hidden_layers=6,
+ num_attention_heads=8,
+ num_key_value_heads=4,
+ vocab_size=100,
+ dtype=torch.bfloat16,
+ )
+ config.layer_precision = ["fp8"] * config.num_hidden_layers
+ model = NVLlamaForCausalLM(config)
+
+ if args.strategy is Strategy.FSDP2:
+ for layer in model.model.layers:
+ fully_shard(layer, mesh=device_mesh["dp"])
+ fully_shard(model, mesh=device_mesh["dp"])
+ model.to(device)
+
+ elif args.strategy is Strategy.DDP:
+ model.to(device)
+ model = torch.nn.parallel.DistributedDataParallel(
+ model,
+ device_ids=[dist_config.local_rank],
+ output_device=dist_config.local_rank,
+ device_mesh=device_mesh["dp"],
+ )
+
+ optimizer = AdamW(model.parameters())
+
+ # Attach FP8 recipes to the model (layer precision is already on config).
+ llama_model = model.module.model if args.strategy is Strategy.DDP else model.model
+ llama_model.set_recipes(fp8_recipe=fp8_recipe, fp4_recipe=None)
+
+ model.train()
+
+ generator = torch.Generator()
+ generator.manual_seed(torch.distributed.get_rank())
+
+ for _ in range(3):
+ input_data = {
+ "input_ids": torch.randint(0, config.vocab_size, (1, 32), generator=generator),
+ "labels": torch.randint(0, config.vocab_size, (1, 32), generator=generator),
+ "attention_mask": torch.ones(1, 32),
+ }
+ input_data = {k: v.to(torch.cuda.current_device()) for k, v in input_data.items()}
+
+ with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
+ outputs = model(**input_data)
+
+ outputs.loss.backward()
+
+ # Access FP8 extra states directly from modules instead of state_dict()
+ # since state_dict() now filters them out for HuggingFace compatibility
+ fp8_extra_states = {}
+ for name, module in model.named_modules():
+ if hasattr(module, "_extra_state") and callable(module._extra_state):
+ extra_state = module._extra_state()
+ if extra_state is not None and len(extra_state) > 0:
+ fp8_extra_states[f"{name}._extra_state"] = extra_state
+
+ # lm_head is BF16, not FP8, so exclude it from FP8 checks
+ fp8_extra_states = {key: val for key, val in fp8_extra_states.items() if "lm_head." not in key}
+
+ # 2 ranks, test to ensure that both ranks have the same FP8 extra states
+ if torch.distributed.get_world_size() == 2:
+ outputs_list = [None] * torch.distributed.get_world_size() if torch.distributed.get_rank() == 0 else None
+ torch.distributed.gather_object(fp8_extra_states, outputs_list, dst=0)
+ if torch.distributed.get_rank() == 0:
+ assert outputs_list is not None
+
+ for key in outputs_list[0]:
+ state_1 = outputs_list[0][key]
+ state_2 = outputs_list[1][key]
+ assert len(state_1) > 0, f"No FP8 extra states for {key}, rank 0"
+ assert len(state_2) > 0, f"No FP8 extra states for {key}, rank 1"
+ dict_1 = pickle.loads(state_1.detach().numpy(force=True).tobytes())
+ dict_2 = pickle.loads(state_2.detach().numpy(force=True).tobytes())
+ recursive_assert(dict_1, dict_2)
+
+ # One rank, test to ensure the correct FP8 extra states are saved
+ if torch.distributed.get_world_size() == 1:
+ for key, val in fp8_extra_states.items():
+ assert len(val) > 0, f"No FP8 extra states for {key}"
+ fp8_meta_dict = pickle.loads(val.detach().numpy(force=True).tobytes())
+ assert fp8_meta_dict["recipe"] == fp8_recipe, f"Recipe mismatch for {key}"
+
+ torch.distributed.destroy_process_group()
diff --git a/bionemo-recipes/models/llama3/tests/test_layer_quantization.py b/bionemo-recipes/models/llama3/tests/test_layer_quantization.py
new file mode 100644
index 0000000000..a80ff80f2c
--- /dev/null
+++ b/bionemo-recipes/models/llama3/tests/test_layer_quantization.py
@@ -0,0 +1,180 @@
+# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-Apache2
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Unit tests for NVLlamaModel.set_recipes and get_layer_autocast."""
+
+from contextlib import nullcontext
+from unittest.mock import patch
+
+import pytest
+import transformer_engine.common.recipe
+import transformer_engine.pytorch
+
+from modeling_llama_te import NVLlamaConfig, NVLlamaModel
+
+
+@pytest.fixture
+def model():
+ """Create a small NVLlamaModel for testing."""
+ config = NVLlamaConfig(
+ hidden_size=256,
+ intermediate_size=512,
+ num_hidden_layers=6,
+ num_attention_heads=8,
+ num_key_value_heads=4,
+ vocab_size=100,
+ )
+ return NVLlamaModel(config)
+
+
+# -- set_recipes --
+
+
+def test_all_fp8(model):
+ model.config.layer_precision = ["fp8"] * 6
+ fp8_recipe = transformer_engine.common.recipe.DelayedScaling()
+ model.set_recipes(fp8_recipe=fp8_recipe, fp4_recipe=None)
+ assert model._fp8_recipe is fp8_recipe
+ assert model._fp4_recipe is None
+ assert all(p == "fp8" for p in model.config.layer_precision)
+
+
+def test_all_fp4(model):
+ model.config.layer_precision = ["fp4"] * 6
+ fp4_recipe = transformer_engine.common.recipe.NVFP4BlockScaling()
+ model.set_recipes(fp8_recipe=None, fp4_recipe=fp4_recipe)
+ assert model._fp8_recipe is None
+ assert model._fp4_recipe is fp4_recipe
+ assert all(p == "fp4" for p in model.config.layer_precision)
+
+
+def test_all_bf16(model):
+ model.config.layer_precision = [None] * 6
+ model.set_recipes(fp8_recipe=None, fp4_recipe=None)
+ assert all(p is None for p in model.config.layer_precision)
+
+
+def test_mixed_fp8_fp4(model):
+ model.config.layer_precision = ["fp8", "fp8", "fp8", "fp4", "fp4", "fp4"]
+ fp8_recipe = transformer_engine.common.recipe.DelayedScaling()
+ fp4_recipe = transformer_engine.common.recipe.NVFP4BlockScaling()
+ model.set_recipes(fp8_recipe=fp8_recipe, fp4_recipe=fp4_recipe)
+ assert model.config.layer_precision == ["fp8", "fp8", "fp8", "fp4", "fp4", "fp4"]
+
+
+def test_mixed_fp8_bf16(model):
+ model.config.layer_precision = ["fp8", None, "fp8", None, "fp8", None]
+ fp8_recipe = transformer_engine.common.recipe.DelayedScaling()
+ model.set_recipes(fp8_recipe=fp8_recipe, fp4_recipe=None)
+ assert model.config.layer_precision == ["fp8", None, "fp8", None, "fp8", None]
+
+
+def test_mixed_all_three(model):
+ model.config.layer_precision = ["fp8", "fp8", None, None, "fp4", "fp4"]
+ fp8_recipe = transformer_engine.common.recipe.DelayedScaling()
+ fp4_recipe = transformer_engine.common.recipe.NVFP4BlockScaling()
+ model.set_recipes(fp8_recipe=fp8_recipe, fp4_recipe=fp4_recipe)
+ assert model.config.layer_precision == ["fp8", "fp8", None, None, "fp4", "fp4"]
+
+
+def test_covers_all_layers(model):
+ model.config.layer_precision = ["fp8"] + [None] * 5
+ model.set_recipes(fp8_recipe=transformer_engine.common.recipe.DelayedScaling(), fp4_recipe=None)
+ assert len(model.config.layer_precision) == 6
+
+
+def test_recipes_stored_as_attributes(model):
+ model.config.layer_precision = ["fp8", "fp4", None, None, None, None]
+ fp8_recipe = transformer_engine.common.recipe.DelayedScaling()
+ fp4_recipe = transformer_engine.common.recipe.NVFP4BlockScaling()
+ model.set_recipes(fp8_recipe=fp8_recipe, fp4_recipe=fp4_recipe)
+ assert model._fp8_recipe is fp8_recipe
+ assert model._fp4_recipe is fp4_recipe
+ # The precision list only contains strings/None, not recipe objects.
+ for v in model.config.layer_precision:
+ assert v is None or isinstance(v, str)
+
+
+# -- get_layer_autocast --
+
+
+def test_fp8_layer_returns_nullcontext(model):
+ model.config.layer_precision = ["fp8"] + [None] * 5
+ model.set_recipes(fp8_recipe=transformer_engine.common.recipe.DelayedScaling(), fp4_recipe=None)
+ ctx = model.get_layer_autocast(0)
+ assert isinstance(ctx, nullcontext)
+
+
+def test_fp4_layer_returns_te_autocast(model):
+ fp4_recipe = transformer_engine.common.recipe.NVFP4BlockScaling()
+ model.config.layer_precision = ["fp4"] + [None] * 5
+ model.set_recipes(fp8_recipe=None, fp4_recipe=fp4_recipe)
+ with patch.object(transformer_engine.pytorch, "autocast") as mock_autocast:
+ mock_autocast.return_value = "fp4_context"
+ ctx = model.get_layer_autocast(0)
+ mock_autocast.assert_called_once_with(enabled=True, recipe=fp4_recipe)
+ assert ctx == "fp4_context"
+
+
+def test_bf16_layer_returns_te_autocast_disabled(model):
+ model.config.layer_precision = [None] * 6
+ model.set_recipes(fp8_recipe=None, fp4_recipe=None)
+ with patch.object(transformer_engine.pytorch, "autocast") as mock_autocast:
+ mock_autocast.return_value = "bf16_context"
+ ctx = model.get_layer_autocast(0)
+ mock_autocast.assert_called_once_with(enabled=False)
+ assert ctx == "bf16_context"
+
+
+def test_uninitialized_defaults_to_bf16(model):
+ """When layer_precision is None (default), all layers default to BF16."""
+ assert model.config.layer_precision is None
+ with patch.object(transformer_engine.pytorch, "autocast") as mock_autocast:
+ mock_autocast.return_value = "bf16_context"
+ ctx = model.get_layer_autocast(0)
+ mock_autocast.assert_called_once_with(enabled=False)
+ assert ctx == "bf16_context"
+
+
+def test_mixed_layers_return_correct_contexts(model):
+ fp8_recipe = transformer_engine.common.recipe.DelayedScaling()
+ fp4_recipe = transformer_engine.common.recipe.NVFP4BlockScaling()
+ model.config.layer_precision = ["fp8", "fp8", "fp4", "fp4", None, None]
+ model.set_recipes(fp8_recipe=fp8_recipe, fp4_recipe=fp4_recipe)
+
+ # FP8 layers -> nullcontext
+ assert isinstance(model.get_layer_autocast(0), nullcontext)
+ assert isinstance(model.get_layer_autocast(1), nullcontext)
+
+ # FP4 layers -> te.pytorch.autocast
+ with patch.object(transformer_engine.pytorch, "autocast") as mock_autocast:
+ mock_autocast.return_value = "fp4_context"
+ model.get_layer_autocast(2)
+ mock_autocast.assert_called_with(enabled=True, recipe=fp4_recipe)
+
+ # BF16 layers -> te.pytorch.autocast(enabled=False)
+ with patch.object(transformer_engine.pytorch, "autocast") as mock_autocast:
+ mock_autocast.return_value = "bf16_context"
+ model.get_layer_autocast(4)
+ mock_autocast.assert_called_with(enabled=False)
+
+
+def test_layer_precision_is_pickleable(model):
+ """The config.layer_precision list should be trivially pickleable."""
+ import pickle
+
+ model.config.layer_precision = ["fp8", "fp8", "fp4", "fp4", None, None]
+ roundtripped = pickle.loads(pickle.dumps(model.config.layer_precision))
+ assert roundtripped == model.config.layer_precision
diff --git a/bionemo-recipes/recipes/llama3_native_te/README.md b/bionemo-recipes/recipes/llama3_native_te/README.md
index 7f593157ab..1042855cc9 100644
--- a/bionemo-recipes/recipes/llama3_native_te/README.md
+++ b/bionemo-recipes/recipes/llama3_native_te/README.md
@@ -1,8 +1,8 @@
# TransformerEngine-accelerated Llama 3 training with native PyTorch training loop
This folder demonstrates how to train TE-accelerated Llama 3 with a native PyTorch training loop, including sequence
-packing and FP8 precision, using fully sharded data parallel (FSDP) for distributed training. This recipe is configured
-for genomic sequences using a custom nucleotide tokenizer.
+packing and FP8/MXFP8/NVFP4 precision with layer-wise control, using fully sharded data parallel (FSDP) for distributed
+training. This recipe is configured for genomic sequences using a custom nucleotide tokenizer.
## How to use this recipe
@@ -16,9 +16,9 @@ bionemo-framework repository. You can download a zipped directory of this folder
## Supported Models and Training Features
-| Model | BF16 | FP8[1] | THD Input Format | FP8 with THD Input Format | MXFP8[2] | Context Parallelism | Tensor Parallelism |
-| ---------------------------------------- | ---- | ----------------- | ---------------- | ------------------------- | ------------------- | ------------------- | ------------------ |
-| [Llama 3](../../models/llama3/README.md) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | 🚧 |
+| Model | BF16 | FP8[1] | MXFP8[2] | NVFP4[3] | THD Input Format | Context Parallelism | Tensor Parallelism |
+| ---------------------------------------- | ---- | ----------------- | ------------------- | ------------------- | ---------------- | ------------------- | ------------------ |
+| [Llama 3](../../models/llama3/README.md) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | 🚧 |
✅: Supported
🚧: Under development
@@ -26,6 +26,7 @@ bionemo-framework repository. You can download a zipped directory of this folder
\[1\]: Requires [compute capability](https://developer.nvidia.com/cuda-gpus) 9.0 and above (Hopper+)
\[2\]: Requires [compute capability](https://developer.nvidia.com/cuda-gpus) 10.0 and 10.3 (Blackwell), 12.0 support pending
+\[3\]: Requires [compute capability](https://developer.nvidia.com/cuda-gpus) 10.0 and above (Blackwell+)
### Installing Dependencies
@@ -64,11 +65,16 @@ def compute_model_pflops(seq_len, global_batch_size, step_time_s):
return model_flops / 1e15
```
+### Low precision performance benchmarks
+
+
+In the above plot we can see how performance increases as we lower the precision of the transformer layers, across the 1B and 8B variants of Llama 3.
+
### Convergence Benchmarks
-
-
+
+
We compared the convergence of this Llama3 recipe (with FSDP2) against NeMo 2.0
@@ -88,6 +94,10 @@ are due checkpointing, further work will be done to improve training step time s
Models were trained on 64 NVIDIA H100 GPUs with a micro batch size of 4 and a context length of 4096 for 60,000 steps.
Training was performed with BF16 precision.
+### Low Precision convergence benchmarks
+
+
+
### Distributed Training
This recipe supports distributed training using DDP, FSDP2, and FSDP2 with Context Parallelism, shown in three separate training entrypoints:
@@ -127,35 +137,71 @@ batch size while running on a smaller number of GPUs.
python train_fsdp2.py --config-name L0_sanity grad_acc_steps=2
```
-### FP8 Training
+### Quantized Training (FP8 / MXFP8 / NVFP4)
To run training with FP8, enable it by overriding the `fp8_config.enabled=true` configuration parameter. Additional FP8
-configuration parameters, including switching to `MXFP8BlockScaling`, can be set via the hydra configuration.
+configuration parameters, including switching to `MXFP8BlockScaling`, can be set using the hydra configuration.
```bash
python train_fsdp2.py --config-name L0_sanity fp8_config.enabled=true
```
-#### FP8 Debugging
+Similarly, to train with NVFP4 quantization:
+
+```bash
+python train_fsdp2.py --config-name L0_sanity fp4_config.enabled=true
+```
+
+Note: This feature is available for the `train_ddp` and the `train_fsdp2` scripts. NVFP4 stats logging is not yet
+supported and will be enabled in a future TransformerEngine release; FP8/MXFP8 stats logging works today.
+
+Additional recipe parameters (e.g., switching to `MXFP8BlockScaling`) can be set via the hydra configuration.
+
+#### Layer-Wise Precision
+
+You can control which transformer layers use FP8 or FP4 by specifying 1-indexed layer numbers via `fp8_layers` and
+`fp4_layers`. Layers not assigned to either format will run in BF16.
+
+For example, to run layers 1-3 in FP8, layers 4-6 in FP4, and the rest in BF16 on a model with more than 6 layers:
+
+```bash
+python train_fsdp2.py --config-name L0_sanity \
+ fp8_config.enabled=true \
+ fp4_config.enabled=true \
+ 'fp8_layers=[1,2,3]' \
+ 'fp4_layers=[4,5,6]'
+```
+
+When both `fp8_config` and `fp4_config` are enabled but only one layer list is provided, the other format automatically
+claims the remaining layers. For example, if `fp8_layers=[1,2,3]` is set and `fp4_config.enabled=true` with no
+`fp4_layers`, then layers 4 through N will default to FP4.
+
+#### Quantization Stats Debugging
-We also provide a mechanism to receive tensor data related to FP8 layers during training which may include activations, weights and gradients.
+We provide a mechanism to log tensor statistics (activations, weights, gradients) for quantized layers during training.
+When layer-wise precision is used, the stats config is automatically updated so that only the relevant layers are
+tracked.
-To enable this please select the following config options.
+To enable stats logging:
```bash
python train_fsdp2.py \
- fp8_stats_config.enabled=True \
- fp8_stats_config.fp8_log_dir=./logs/fp8_stats_logs_dummy \
- fp8_stats_config.fp8_stats_file=./fp8_debugging_stats.yaml \
- fp8_config.enabled=True
+ quant_stats_config.enabled=true \
+ quant_stats_config.quant_log_dir=./logs/quant_stats \
+ quant_stats_config.quant_stats_file=./fp8_debugging_stats.yaml \
+ fp8_config.enabled=true
```
-Note: This feature is available for the `train_ddp` and the `train_fsdp2` scripts.
+Note: This feature is available for the `train_ddp` and the `train_fsdp2` scripts. NVFP4 stats logging is not yet
+supported and will be enabled in a future TransformerEngine release; FP8/MXFP8 stats logging works today.
-The config file structure [fp8_debugging_stats.yaml](fp8_debugging_stats.yaml) is explained in the [NVIDIA Transformer Engine config file documentation](https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/debug/2_config_file_structure.html) in more detail. Below we will cover some very basic elements of the file structure.
+The config file structure [fp8_debugging_stats.yaml](fp8_debugging_stats.yaml) is explained in the
+[NVIDIA Transformer Engine config file documentation](https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/debug/2_config_file_structure.html)
+in more detail.
-This comes as a performance cost that is dependent on the `freq` parameter mentioned above. `freq=1` collects stats on every step which in our
-experiments caused a ~29% decrease in throughput (executed on a single RTX 5090). We recommend using `freq>=10` to reduce this performance hit.
+Stats collection has a performance cost dependent on the `freq` parameter in the config file. `freq=1` collects stats
+on every step which in our experiments caused a ~29% decrease in throughput (executed on a single RTX 5090). We
+recommend using `freq>=10` to reduce this performance hit.
### Sequence Packing (THD input format)
diff --git a/bionemo-recipes/recipes/llama3_native_te/fp4_debugging_stats.yaml b/bionemo-recipes/recipes/llama3_native_te/fp4_debugging_stats.yaml
new file mode 100644
index 0000000000..9046d44caf
--- /dev/null
+++ b/bionemo-recipes/recipes/llama3_native_te/fp4_debugging_stats.yaml
@@ -0,0 +1,33 @@
+example_fp4_tensor_stat_collection:
+ enabled: True
+ layers:
+ # Use regex to select layers (1-indexed as layers.1 through layers.N in the naming)
+ # This matches: model.model.layers.[1-5].*.(layernorm_qkv|proj|fc1|fc2)
+ layer_name_regex_pattern: 'model\.model\.layers\.[1-5]\..*(layernorm_qkv|proj|fc1|fc2)'
+ transformer_engine:
+ LogNvfp4TensorStats:
+ enabled: True
+ tensors_struct:
+ - tensor: activation
+ stats: [underflows%, mse]
+ freq: 100
+ - tensor: gradient
+ stats: [underflows%, mse]
+ freq: 100
+
+example_fp8_tensor_stat_collection:
+ enabled: True
+ layers:
+ # Use regex to select layers (1-indexed as layers.1 through layers.N in the naming)
+ # This matches: model.model.layers.[6-10].*.(layernorm_qkv|proj|fc1|fc2)
+ layer_name_regex_pattern: 'model\.model\.layers\.([6-9]|10)\..*(layernorm_qkv|proj|fc1|fc2)'
+ transformer_engine:
+ LogFp8TensorStats:
+ enabled: True
+ tensors_struct:
+ - tensor: activation
+ stats: [mxfp8_underflows%, mxfp8_scale_inv_min, mxfp8_scale_inv_max, mxfp8_mse]
+ freq: 100
+ - tensor: gradient
+ stats: [mxfp8_underflows%, mxfp8_scale_inv_min, mxfp8_scale_inv_max, mxfp8_mse]
+ freq: 100
diff --git a/bionemo-recipes/recipes/llama3_native_te/fp8_debugging.py b/bionemo-recipes/recipes/llama3_native_te/fp8_debugging.py
deleted file mode 100644
index d01024f04c..0000000000
--- a/bionemo-recipes/recipes/llama3_native_te/fp8_debugging.py
+++ /dev/null
@@ -1,64 +0,0 @@
-# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
-# SPDX-License-Identifier: LicenseRef-Apache2
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import logging
-import os
-from pathlib import Path
-
-import nvdlfw_inspect.api as debug_api
-import transformer_engine
-
-from distributed_config import DistributedConfig
-
-
-logger = logging.getLogger(__name__)
-logger.setLevel(logging.INFO)
-
-
-def initialize_fp8_debugging(
- dist_config: DistributedConfig,
- enabled: bool,
- fp8_stats_file: str,
- fp8_log_dir: str | os.PathLike,
- fp8_enabled: bool,
-) -> None:
- """Initialize FP8 debugging.
-
- Args:
- dist_config: The distributed configuration.
- enabled: Whether to enable FP8 debugging.
- fp8_stats_file: The file containing the FP8 stats.
- fp8_log_dir: The directory to log the FP8 stats to.
- fp8_enabled: Whether FP8 autocast is enabled.
- """
- if not enabled:
- return
-
- if not fp8_enabled:
- raise ValueError(
- "fp8_stats_config.enabled is true but fp8_config.enabled is false, "
- "please set fp8_config.enabled to true in the config if you wish to collect FP8 stats"
- )
-
- fp8_log_dir = Path(fp8_log_dir) / f"rank_{dist_config.rank}"
- fp8_log_dir.mkdir(parents=True, exist_ok=True)
- logger.info(f"Logging FP8 stats to {fp8_log_dir}")
- te_features_dir = str(Path(transformer_engine.__file__).parent / "debug" / "features")
- debug_api.initialize(
- config_file=fp8_stats_file,
- feature_dirs=[te_features_dir],
- log_dir=fp8_log_dir.as_posix(),
- default_logging_enabled=True,
- )
diff --git a/bionemo-recipes/recipes/llama3_native_te/fp8_debugging_stats.yaml b/bionemo-recipes/recipes/llama3_native_te/fp8_debugging_stats.yaml
index 7544bbedcf..ba640a6cbb 100644
--- a/bionemo-recipes/recipes/llama3_native_te/fp8_debugging_stats.yaml
+++ b/bionemo-recipes/recipes/llama3_native_te/fp8_debugging_stats.yaml
@@ -2,7 +2,7 @@ example_fp8_tensor_stat_collection:
enabled: True
layers:
# Match the actual linear layers within attention that support FP8 stats
- layer_types: [layernorm_qkv]
+ layer_types: [layernorm_qkv, proj, fc1, fc2]
transformer_engine:
LogFp8TensorStats:
enabled: True
@@ -16,3 +16,8 @@ example_fp8_tensor_stat_collection:
- tensor: weight
stats: [underflows%, scale_inv_min, scale_inv_max, mse]
freq: 10
+ LogTensorStats:
+ enabled: True
+ stats: [max, min, mean, std, l1_norm]
+ tensors: [dgrad, wgrad]
+ freq: 1
diff --git a/bionemo-recipes/recipes/llama3_native_te/hydra_config/defaults.yaml b/bionemo-recipes/recipes/llama3_native_te/hydra_config/defaults.yaml
index d6c181598f..be71c57e78 100644
--- a/bionemo-recipes/recipes/llama3_native_te/hydra_config/defaults.yaml
+++ b/bionemo-recipes/recipes/llama3_native_te/hydra_config/defaults.yaml
@@ -44,6 +44,12 @@ fp8_config:
quantized_model_init_kwargs:
enabled: false # If this is set to true, fp8_config.enabled must also be set to true.
+fp4_config:
+ enabled: false
+ fp4_recipe: transformer_engine.common.recipe.NVFP4BlockScaling
+ fp4_format: "E2M1"
+ fp4_recipe_kwargs: {}
+
# Optimizer config
adamw_kwargs:
lr: 3e-3
@@ -70,10 +76,15 @@ checkpoint:
logger:
frequency: 100
-fp8_stats_config:
+quant_stats_config:
enabled: false
- fp8_stats_file: ./fp8_debugging_stats.yaml
- fp8_log_dir: ./log_fp8_stats
+ quant_stats_file: ./fp8_debugging_stats.yaml
+ quant_log_dir: ./log_quant_stats
+
+# Note: Layer lists are specified 1-indexed in this config and converted to 0-indexed at runtime.
+fp8_layers: null
+fp4_layers: null
+use_fp32_master_weights: null
profiler:
enabled: false
diff --git a/bionemo-recipes/recipes/llama3_native_te/modeling_llama_te.py b/bionemo-recipes/recipes/llama3_native_te/modeling_llama_te.py
index 033eb5ebe3..4e859b4868 100644
--- a/bionemo-recipes/recipes/llama3_native_te/modeling_llama_te.py
+++ b/bionemo-recipes/recipes/llama3_native_te/modeling_llama_te.py
@@ -17,10 +17,12 @@
import warnings
from collections import OrderedDict
+from contextlib import nullcontext
from typing import ClassVar, Unpack
import torch
import torch.nn as nn
+import transformer_engine.common.recipe
import transformer_engine.pytorch
import transformers
from transformer_engine.pytorch.attention import InferenceParams
@@ -50,6 +52,7 @@ class NVLlamaConfig(LlamaConfig):
# "thd" = Total tokens (packed/unpadded), Head, Dimension (sequence packing format)
attn_input_format: str = "thd"
self_attn_mask_type: str = "padding_causal"
+ layer_precision: list[str | None] | None = None
class NVLlamaPreTrainedModel(PreTrainedModel):
@@ -159,11 +162,54 @@ def _init_method(x):
self.rotary_emb = RotaryPositionEmbedding(config.hidden_size // config.num_attention_heads)
self.rotary_emb.inv_freq = LlamaRotaryEmbedding(config=config).inv_freq
+ self._fp8_recipe: transformer_engine.common.recipe.Recipe | None = None
+ self._fp4_recipe: transformer_engine.common.recipe.Recipe | None = None
+
self.gradient_checkpointing = False
# Initialize weights and apply final processing
self.post_init()
+ def set_recipes(
+ self,
+ fp8_recipe: transformer_engine.common.recipe.Recipe | None = None,
+ fp4_recipe: transformer_engine.common.recipe.Recipe | None = None,
+ ) -> None:
+ """Attach quantization recipe objects for per-layer autocast.
+
+ Recipes are not serializable and must be set at runtime after model creation
+ and sharding (FSDP/DDP) but before training. The per-layer precision
+ assignments are read from ``self.config.layer_precision``.
+
+ Args:
+ fp8_recipe: The FP8 recipe instance (e.g., MXFP8BlockScaling), or None.
+ fp4_recipe: The FP4 recipe instance (e.g., NVFP4BlockScaling), or None.
+ """
+ self._fp8_recipe = fp8_recipe
+ self._fp4_recipe = fp4_recipe
+
+ def get_layer_autocast(self, layer_number: int):
+ """Return the appropriate TE autocast context manager for a given layer.
+
+ The context interacts with the outer FP8 autocast in the training script:
+ - FP8 layer: nullcontext() -- lets the outer FP8 autocast take effect.
+ - FP4 layer: te.pytorch.autocast(enabled=True, recipe=fp4_recipe) -- overrides to FP4.
+ - BF16 layer: te.pytorch.autocast(enabled=False) -- disables quantized compute.
+
+ Args:
+ layer_number: The 0-indexed layer number.
+
+ Returns:
+ A context manager for the layer's quantization mode.
+ """
+ precision = self.config.layer_precision[layer_number] if self.config.layer_precision is not None else None
+ if precision == "fp8":
+ return nullcontext()
+ elif precision == "fp4":
+ return transformer_engine.pytorch.autocast(enabled=True, recipe=self._fp4_recipe)
+ else:
+ return transformer_engine.pytorch.autocast(enabled=False)
+
def forward(
self,
input_ids: torch.Tensor | None = None,
@@ -240,23 +286,27 @@ def forward(
if te_rope_emb.dtype == torch.float32:
warnings.warn("Rotary embeddings should be in float32 for optimal performance.", UserWarning)
- for decoder_layer in self.layers[: self.config.num_hidden_layers]:
- if output_hidden_states:
- all_hidden_states = (*all_hidden_states, hidden_states)
-
- hidden_states = decoder_layer(
- hidden_states,
- attention_mask=None if self.config.attn_input_format == "thd" else attention_mask,
- rotary_pos_emb=te_rope_emb,
- inference_params=past_key_values,
- cu_seqlens_q=kwargs.get("cu_seq_lens_q", None),
- cu_seqlens_kv=kwargs.get("cu_seq_lens_k", None),
- cu_seqlens_q_padded=kwargs.get("cu_seq_lens_q_padded", None),
- cu_seqlens_kv_padded=kwargs.get("cu_seq_lens_k_padded", None),
- max_seqlen_q=kwargs.get("max_length_q", None),
- max_seqlen_kv=kwargs.get("max_length_k", None),
- pad_between_seqs=kwargs.get("pad_between_seqs", None),
- )
+ # Outer FP8 autocast enables FP8 compute for the decoder stack. Per-layer overrides (FP4, BF16) are handled
+ # by get_layer_autocast(), which nests inside this context.
+ with transformer_engine.pytorch.autocast(enabled=self._fp8_recipe is not None, recipe=self._fp8_recipe):
+ for layer_number, decoder_layer in enumerate(self.layers[: self.config.num_hidden_layers]):
+ if output_hidden_states:
+ all_hidden_states = (*all_hidden_states, hidden_states)
+
+ with self.get_layer_autocast(layer_number):
+ hidden_states = decoder_layer(
+ hidden_states,
+ attention_mask=None if self.config.attn_input_format == "thd" else attention_mask,
+ rotary_pos_emb=te_rope_emb,
+ inference_params=past_key_values,
+ cu_seqlens_q=kwargs.get("cu_seq_lens_q", None),
+ cu_seqlens_kv=kwargs.get("cu_seq_lens_k", None),
+ cu_seqlens_q_padded=kwargs.get("cu_seq_lens_q_padded", None),
+ cu_seqlens_kv_padded=kwargs.get("cu_seq_lens_k_padded", None),
+ max_seqlen_q=kwargs.get("max_length_q", None),
+ max_seqlen_kv=kwargs.get("max_length_k", None),
+ pad_between_seqs=kwargs.get("pad_between_seqs", None),
+ )
hidden_states = self.norm(hidden_states)
diff --git a/bionemo-recipes/recipes/llama3_native_te/perf_logger.py b/bionemo-recipes/recipes/llama3_native_te/perf_logger.py
index 726eb19e8e..4b1a8d4ec7 100644
--- a/bionemo-recipes/recipes/llama3_native_te/perf_logger.py
+++ b/bionemo-recipes/recipes/llama3_native_te/perf_logger.py
@@ -91,7 +91,7 @@ def __init__(self, dist_config: DistributedConfig, args: DictConfig, start_step:
self.grad_acc_step_count = 0
# Whether to step debug_api.step() after each step
- self.fp8_stats_enabled = args.fp8_stats_config.enabled
+ self.quant_stats_config = args.quant_stats_config.enabled
@nvtx.annotate("PerfLogger.log_micro_step", color="pink")
def log_micro_step(self, step: int, batch: dict[str, torch.Tensor], outputs: CausalLMOutputWithPast):
@@ -150,7 +150,7 @@ def log_step(
if self._profiler is not None:
self._profiler.step(step)
- if self.fp8_stats_enabled:
+ if self.quant_stats_config:
debug_api.step()
if step % self.logging_frequency == 0 and step > 0:
@@ -201,15 +201,15 @@ def log_step(
def finish(self):
"""Finish the logger and close the progress bar."""
+ if self.quant_stats_config:
+ debug_api.end_debug()
+
if not self._dist_config.is_main_process():
return
wandb.finish()
self._progress_bar.close()
- if self.fp8_stats_enabled:
- debug_api.end_debug()
-
class NsightProfiler:
"""Nsight Systems profiler wrapper for performance analysis.
diff --git a/bionemo-recipes/recipes/llama3_native_te/quantization.py b/bionemo-recipes/recipes/llama3_native_te/quantization.py
new file mode 100644
index 0000000000..e479b13c02
--- /dev/null
+++ b/bionemo-recipes/recipes/llama3_native_te/quantization.py
@@ -0,0 +1,223 @@
+# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-Apache2
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Utilities for layer-wise quantization configuration (FP8/FP4)."""
+
+import logging
+import tempfile
+from pathlib import Path
+
+import yaml
+
+
+logger = logging.getLogger(__name__)
+
+
+def generate_layer_regex(layer_numbers: list[int] | None) -> str:
+ """Generate a regex pattern to match specific layer numbers (1-indexed).
+
+ The debug API (nvdlfw_inspect) uses 1-indexed layer names after ``infer_and_assign_layer_names``.
+
+ Args:
+ layer_numbers: List of layer numbers (1-indexed, as shown in debug logs).
+ If empty or None, returns a pattern that matches nothing.
+
+ Returns:
+ Regex pattern string for matching those layers' linear sublayers.
+ """
+ if not layer_numbers:
+ return r"model\.model\.layers\.DISABLED_NO_LAYERS_SPECIFIED"
+ layer_pattern = "|".join(str(n) for n in sorted(layer_numbers))
+ return rf"model\.model\.layers\.({layer_pattern})\..*(layernorm_qkv|proj|fc1|fc2)"
+
+
def update_quant_stats_config(
    config_file: str,
    fp4_layers: list[int] | None,
    fp8_layers: list[int] | None,
) -> str:
    """Update the quant stats YAML config with layer-specific regex patterns.

    The original file is never modified; the updated config is written to a new
    temporary file whose path is returned.

    Args:
        config_file: Path to the original YAML config file.
        fp4_layers: List of layer numbers for FP4 (1-indexed).
        fp8_layers: List of layer numbers for FP8 (1-indexed).

    Returns:
        Path to the updated config file (a temp file).
    """
    with open(config_file, "r") as f:
        config = yaml.safe_load(f)

    if "example_fp4_tensor_stat_collection" in config:
        # TODO: Remove this block and replace with FP8-style regex update once a TransformerEngine
        # release with LogNvfp4TensorStats support is available. At that point, this becomes:
        #   fp4_regex = generate_layer_regex(fp4_layers)
        #   config["example_fp4_tensor_stat_collection"]["layers"]["layer_name_regex_pattern"] = fp4_regex
        config["example_fp4_tensor_stat_collection"]["enabled"] = False
        if fp4_layers:
            logger.warning(
                "NVFP4 quant stats logging is not yet supported (requires a future TransformerEngine release). "
                f"Disabling FP4 stats collection for layers {fp4_layers}. FP8 stats will still be collected."
            )
        else:
            logger.info("FP4 stats section disabled (no FP4 layers and feature not yet supported)")

    if "example_fp8_tensor_stat_collection" in config:
        fp8_regex = generate_layer_regex(fp8_layers)
        config["example_fp8_tensor_stat_collection"]["layers"]["layer_name_regex_pattern"] = fp8_regex
        if fp8_layers:
            logger.info(f"Updated FP8 layer regex to match layers: {fp8_layers}")
        else:
            logger.info("FP8 layers empty - regex set to match nothing")

    # Serialize once and reuse the same text for both the temp file and the log
    # (the original implementation dumped the config twice).
    config_str = yaml.dump(config, default_flow_style=False)
    with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as temp_file:
        temp_file.write(config_str)

    logger.info(f"Created updated quant stats config at: {temp_file.name}")
    logger.info(f"Updated quant stats config contents:\n{config_str}")

    return temp_file.name
+
+
def initialize_quant_stats_logging(
    quant_stats_file: str,
    quant_log_dir: str,
    rank: int,
    layer_precision: list[str | None],
) -> None:
    """Set up quantization stats logging via nvdlfw_inspect.

    Updates the quant stats YAML config with resolved layer regex patterns, creates
    the per-rank log directory, and initializes the debug API.

    Args:
        quant_stats_file: Path to the base quant stats YAML config file.
        quant_log_dir: Base directory for quant stats logs (a rank subdirectory will be created).
        rank: The global rank of this process.
        layer_precision: Per-layer precision list (0-indexed by position). Each element is
            ``"fp8"``, ``"fp4"``, or ``None``.
    """
    # Imported lazily so importing this module does not require the debug stack.
    import nvdlfw_inspect.api as debug_api
    import transformer_engine

    # Derive 1-indexed layer lists for the debug API, which uses 1-indexed layer names.
    fp8_layers_1indexed = [i + 1 for i, p in enumerate(layer_precision) if p == "fp8"] or None
    fp4_layers_1indexed = [i + 1 for i, p in enumerate(layer_precision) if p == "fp4"] or None
    updated_config = update_quant_stats_config(
        config_file=quant_stats_file,
        fp4_layers=fp4_layers_1indexed,
        fp8_layers=fp8_layers_1indexed,
    )

    rank_log_dir = Path(quant_log_dir) / f"rank_{rank}"
    rank_log_dir.mkdir(parents=True, exist_ok=True)
    logger.info(f"Logging quant stats to {rank_log_dir}")

    te_features_dir = str(Path(transformer_engine.__file__).parent / "debug" / "features")
    debug_api.initialize(
        config_file=updated_config,
        feature_dirs=[te_features_dir],
        # Pass a plain string rather than a Path: the previous (fp8-only)
        # implementation passed .as_posix(), and the debug API expects a str.
        log_dir=rank_log_dir.as_posix(),
        default_logging_enabled=True,
    )
+
+
+def resolve_layer_precision(
+ num_layers: int,
+ fp8_enabled: bool,
+ fp4_enabled: bool,
+ fp8_layers: list[int] | None,
+ fp4_layers: list[int] | None,
+) -> list[str | None]:
+ """Resolve layer-wise quantization assignments from user config.
+
+ Takes 1-indexed layer lists (as specified by the user in YAML config) and returns a per-layer
+ precision list (0-indexed by position). When a quantization format is enabled but no layer list
+ is provided, all layers default to that format. When one format has explicit layers and the other
+ is enabled without a layer list, the unspecified format defaults to the remaining (unclaimed) layers.
+
+ Args:
+ num_layers: Total number of transformer layers in the model.
+ fp8_enabled: Whether FP8 quantization is enabled.
+ fp4_enabled: Whether FP4 quantization is enabled.
+ fp8_layers: 1-indexed list of layers for FP8, or None if not specified.
+ fp4_layers: 1-indexed list of layers for FP4, or None if not specified.
+
+ Returns:
+ A list of length ``num_layers`` where each element is ``"fp8"``, ``"fp4"``, or ``None``
+ (BF16 fallback), indexed by layer position (0-indexed).
+
+ Raises:
+ ValueError: If both formats are enabled with no layer lists, or if layer lists overlap.
+ """
+ all_layers = set(range(1, num_layers + 1))
+
+ if fp8_enabled and fp4_enabled and fp8_layers is None and fp4_layers is None:
+ raise ValueError(
+ "Both fp8_config and fp4_config are enabled but neither fp8_layers nor fp4_layers is specified. "
+ "When both are enabled, you must explicitly provide layer lists to indicate which layers use which format."
+ )
+
+ # When one format has explicit layers and the other defaults, fill in the remaining layers.
+ if fp8_enabled and fp8_layers is None:
+ claimed_by_fp4 = set(fp4_layers) if fp4_layers is not None else set()
+ fp8_layers = sorted(all_layers - claimed_by_fp4)
+ if claimed_by_fp4:
+ logger.warning(
+ f"fp8_config.enabled=True with no fp8_layers specified, but fp4_layers={sorted(claimed_by_fp4)} "
+ f"are already claimed by FP4. Defaulting FP8 to the remaining layers: {fp8_layers}"
+ )
+ else:
+ logger.info(
+ f"fp8_config.enabled=True with no fp8_layers specified, defaulting all {num_layers} layers to FP8"
+ )
+
+ if fp4_enabled and fp4_layers is None:
+ claimed_by_fp8 = set(fp8_layers) if fp8_layers is not None else set()
+ fp4_layers = sorted(all_layers - claimed_by_fp8)
+ if claimed_by_fp8:
+ logger.warning(
+ f"fp4_config.enabled=True with no fp4_layers specified, but fp8_layers={sorted(claimed_by_fp8)} "
+ f"are already claimed by FP8. Defaulting FP4 to the remaining layers: {fp4_layers}"
+ )
+ else:
+ logger.info(
+ f"fp4_config.enabled=True with no fp4_layers specified, defaulting all {num_layers} layers to FP4"
+ )
+
+ # Disable layer lists when corresponding config is not enabled.
+ if not fp8_enabled:
+ fp8_layers = None
+ if not fp4_enabled:
+ fp4_layers = None
+
+ # Validate no overlap between FP8 and FP4 layer assignments.
+ if fp8_layers is not None and fp4_layers is not None:
+ overlap = set(fp8_layers) & set(fp4_layers)
+ if overlap:
+ raise ValueError(
+ f"fp8_layers and fp4_layers cannot have overlapping layer numbers. Found overlap: {sorted(overlap)}"
+ )
+
+ # Build per-layer precision list (0-indexed by position, 1-indexed for lookup).
+ fp8_set = set(fp8_layers) if fp8_layers is not None else set()
+ fp4_set = set(fp4_layers) if fp4_layers is not None else set()
+ return [
+ "fp8" if layer_1indexed in fp8_set else "fp4" if layer_1indexed in fp4_set else None
+ for layer_1indexed in range(1, num_layers + 1)
+ ]
diff --git a/bionemo-recipes/recipes/llama3_native_te/tests/conftest.py b/bionemo-recipes/recipes/llama3_native_te/tests/conftest.py
index 08330b12f7..bb7a2d8ed6 100644
--- a/bionemo-recipes/recipes/llama3_native_te/tests/conftest.py
+++ b/bionemo-recipes/recipes/llama3_native_te/tests/conftest.py
@@ -56,6 +56,8 @@ def pytest_collection_modifyitems(items):
stats_test_names = {
"test_sanity_ddp_fp8_stats_logging",
"test_sanity_fsdp2_fp8_stats_logging",
+ "test_sanity_ddp_fp8_partial_layers_stats_logging",
+ "test_sanity_fsdp2_fp8_partial_layers_stats_logging",
}
stats_tests = [item for item in items if item.name in stats_test_names]
other_tests = [item for item in items if item.name not in stats_test_names]
diff --git a/bionemo-recipes/recipes/llama3_native_te/tests/test_perf_logger.py b/bionemo-recipes/recipes/llama3_native_te/tests/test_perf_logger.py
index aebdfe17ef..d919278d4a 100644
--- a/bionemo-recipes/recipes/llama3_native_te/tests/test_perf_logger.py
+++ b/bionemo-recipes/recipes/llama3_native_te/tests/test_perf_logger.py
@@ -34,7 +34,7 @@ def _make_args(logging_frequency=1, num_train_steps=100):
"wandb": {"project": "test", "mode": "disabled"},
"num_train_steps": num_train_steps,
"profiler": {"enabled": False},
- "fp8_stats_config": {"enabled": False},
+ "quant_stats_config": {"enabled": False},
}
)
diff --git a/bionemo-recipes/recipes/llama3_native_te/tests/test_quantization.py b/bionemo-recipes/recipes/llama3_native_te/tests/test_quantization.py
new file mode 100644
index 0000000000..2d6e02b050
--- /dev/null
+++ b/bionemo-recipes/recipes/llama3_native_te/tests/test_quantization.py
@@ -0,0 +1,332 @@
+# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-Apache2
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import re
+import sys
+from pathlib import Path
+
+import pytest
+import yaml
+
+
+sys.path.append(Path(__file__).parent.parent.as_posix())
+
+from quantization import generate_layer_regex, resolve_layer_precision, update_quant_stats_config
+
+
+# -- resolve_layer_precision --
+
+
+def test_fp8_enabled_no_layers_defaults_all():
+ """When fp8 is enabled with no explicit layers, all layers should default to FP8."""
+ result = resolve_layer_precision(
+ num_layers=6, fp8_enabled=True, fp4_enabled=False, fp8_layers=None, fp4_layers=None
+ )
+ assert result == ["fp8", "fp8", "fp8", "fp8", "fp8", "fp8"]
+
+
+def test_fp4_enabled_no_layers_defaults_all():
+ """When fp4 is enabled with no explicit layers, all layers should default to FP4."""
+ result = resolve_layer_precision(
+ num_layers=6, fp8_enabled=False, fp4_enabled=True, fp8_layers=None, fp4_layers=None
+ )
+ assert result == ["fp4", "fp4", "fp4", "fp4", "fp4", "fp4"]
+
+
+def test_fp8_explicit_layers():
+ """Explicit 1-indexed fp8_layers should produce fp8 at those positions."""
+ result = resolve_layer_precision(
+ num_layers=6, fp8_enabled=True, fp4_enabled=False, fp8_layers=[1, 3, 5], fp4_layers=None
+ )
+ assert result == ["fp8", None, "fp8", None, "fp8", None]
+
+
+def test_fp4_explicit_layers():
+ """Explicit 1-indexed fp4_layers should produce fp4 at those positions."""
+ result = resolve_layer_precision(
+ num_layers=6, fp8_enabled=False, fp4_enabled=True, fp8_layers=None, fp4_layers=[2, 4, 6]
+ )
+ assert result == [None, "fp4", None, "fp4", None, "fp4"]
+
+
+def test_mixed_fp8_fp4_explicit():
+ """Both enabled with explicit non-overlapping layers should work correctly."""
+ result = resolve_layer_precision(
+ num_layers=6, fp8_enabled=True, fp4_enabled=True, fp8_layers=[1, 3, 4], fp4_layers=[2, 5]
+ )
+ assert result == ["fp8", "fp4", "fp8", "fp8", "fp4", None]
+
+
+def test_both_enabled_no_layers_raises():
+ """Both enabled with no layer lists should raise ValueError."""
+ with pytest.raises(ValueError, match="Both fp8_config and fp4_config are enabled"):
+ resolve_layer_precision(num_layers=6, fp8_enabled=True, fp4_enabled=True, fp8_layers=None, fp4_layers=None)
+
+
+def test_overlapping_layers_raises():
+ """Overlapping layer assignments should raise ValueError."""
+ with pytest.raises(ValueError, match="fp8_layers and fp4_layers cannot have overlapping"):
+ resolve_layer_precision(
+ num_layers=6, fp8_enabled=True, fp4_enabled=True, fp8_layers=[1, 2, 3], fp4_layers=[3, 4, 5]
+ )
+
+
+def test_disabled_ignores_layers():
+ """When a format is disabled, its layers should be ignored."""
+ result = resolve_layer_precision(
+ num_layers=6, fp8_enabled=False, fp4_enabled=False, fp8_layers=[1, 2, 3], fp4_layers=[4, 5, 6]
+ )
+ assert result == [None, None, None, None, None, None]
+
+
+def test_both_disabled():
+ """Both disabled with no layers should return all None."""
+ result = resolve_layer_precision(
+ num_layers=6, fp8_enabled=False, fp4_enabled=False, fp8_layers=None, fp4_layers=None
+ )
+ assert result == [None, None, None, None, None, None]
+
+
+def test_large_model_defaults_all():
+ """Auto-population should work correctly for larger models (e.g. 36 layers)."""
+ result = resolve_layer_precision(
+ num_layers=36, fp8_enabled=True, fp4_enabled=False, fp8_layers=None, fp4_layers=None
+ )
+ assert result == ["fp8"] * 36
+
+
+def test_fp8_enabled_empty_list():
+ """An explicit empty list should remain empty (not default to all)."""
+ result = resolve_layer_precision(num_layers=6, fp8_enabled=True, fp4_enabled=False, fp8_layers=[], fp4_layers=None)
+ assert result == [None, None, None, None, None, None]
+
+
+def test_both_enabled_fp8_specified_fp4_defaults_to_remaining():
+ """When both enabled, FP8 has explicit layers, FP4 should default to the remaining layers."""
+ result = resolve_layer_precision(
+ num_layers=6, fp8_enabled=True, fp4_enabled=True, fp8_layers=[1, 2, 3], fp4_layers=None
+ )
+ assert result == ["fp8", "fp8", "fp8", "fp4", "fp4", "fp4"]
+
+
+def test_both_enabled_fp4_specified_fp8_defaults_to_remaining():
+ """When both enabled, FP4 has explicit layers, FP8 should default to the remaining layers."""
+ result = resolve_layer_precision(
+ num_layers=6, fp8_enabled=True, fp4_enabled=True, fp8_layers=None, fp4_layers=[4, 5, 6]
+ )
+ assert result == ["fp8", "fp8", "fp8", "fp4", "fp4", "fp4"]
+
+
+def test_returns_correct_length():
+ """Result list length should always equal num_layers."""
+ for n in [1, 6, 48]:
+ result = resolve_layer_precision(
+ num_layers=n, fp8_enabled=False, fp4_enabled=False, fp8_layers=None, fp4_layers=None
+ )
+ assert len(result) == n
+
+
+# -- generate_layer_regex --
+
+
+def test_single_layer():
+ """Single layer should produce a simple regex."""
+ regex = generate_layer_regex([3])
+ assert re.search(regex, "model.model.layers.3.self_attention.layernorm_qkv")
+ assert not re.search(regex, "model.model.layers.2.self_attention.layernorm_qkv")
+
+
+def test_multiple_layers():
+ """Multiple layers should match any of them."""
+ regex = generate_layer_regex([1, 2, 3])
+ assert re.search(regex, "model.model.layers.1.self_attention.layernorm_qkv")
+ assert re.search(regex, "model.model.layers.2.layernorm_mlp.fc1")
+ assert re.search(regex, "model.model.layers.3.layernorm_mlp.fc2")
+ assert not re.search(regex, "model.model.layers.4.self_attention.proj")
+
+
+def test_matches_correct_sublayers():
+ """Regex should only match layernorm_qkv, proj, fc1, fc2."""
+ regex = generate_layer_regex([1])
+ assert re.search(regex, "model.model.layers.1.self_attention.layernorm_qkv_something")
+ assert re.search(regex, "model.model.layers.1.self_attention.proj_something")
+ assert re.search(regex, "model.model.layers.1.layernorm_mlp.fc1_something")
+ assert re.search(regex, "model.model.layers.1.layernorm_mlp.fc2_something")
+ # Should not match unrelated sublayer names
+ assert not re.search(regex, "model.model.layers.1.self_attention.some_other_thing")
+
+
+def test_none_returns_disabled_pattern():
+ """None should return a pattern that matches nothing."""
+ regex = generate_layer_regex(None)
+ assert "DISABLED" in regex
+ assert not re.search(regex, "model.model.layers.1.self_attention.layernorm_qkv")
+
+
+def test_empty_list_returns_disabled_pattern():
+ """Empty list should return a pattern that matches nothing."""
+ regex = generate_layer_regex([])
+ assert "DISABLED" in regex
+
+
+def test_1indexed_layer_names():
+ """Regex should use 1-indexed layer numbers (matching debug API naming)."""
+ regex = generate_layer_regex([1])
+ # Should match layers.1 (1-indexed first layer)
+ assert re.search(regex, "model.model.layers.1.self_attention.layernorm_qkv")
+ # Should NOT match layers.0 (0-indexed first layer)
+ assert not re.search(regex, "model.model.layers.0.self_attention.layernorm_qkv")
+
+
+# -- update_quant_stats_config --
+
+
+@pytest.fixture
+def fp8_only_config(tmp_path):
+ """Create an FP8-only stats config file."""
+ config = {
+ "example_fp8_tensor_stat_collection": {
+ "enabled": True,
+ "layers": {
+ "layer_name_regex_pattern": "PLACEHOLDER",
+ },
+ "transformer_engine": {
+ "LogFp8TensorStats": {
+ "enabled": True,
+ "tensors_struct": [{"tensor": "activation", "stats": ["underflows%"], "freq": 10}],
+ }
+ },
+ }
+ }
+ config_path = tmp_path / "fp8_stats.yaml"
+ with open(config_path, "w") as f:
+ yaml.dump(config, f)
+ return str(config_path)
+
+
+@pytest.fixture
+def fp4_fp8_config(tmp_path):
+ """Create a combined FP4+FP8 stats config file."""
+ config = {
+ "example_fp4_tensor_stat_collection": {
+ "enabled": True,
+ "layers": {
+ "layer_name_regex_pattern": "PLACEHOLDER",
+ },
+ "transformer_engine": {
+ "LogNvfp4TensorStats": {"enabled": True},
+ },
+ },
+ "example_fp8_tensor_stat_collection": {
+ "enabled": True,
+ "layers": {
+ "layer_name_regex_pattern": "PLACEHOLDER",
+ },
+ "transformer_engine": {
+ "LogFp8TensorStats": {"enabled": True},
+ },
+ },
+ }
+ config_path = tmp_path / "fp4_fp8_stats.yaml"
+ with open(config_path, "w") as f:
+ yaml.dump(config, f)
+ return str(config_path)
+
+
+def test_fp8_layers_updates_regex(fp8_only_config):
+ """FP8 layer list should update the regex in the output config."""
+ output_path = update_quant_stats_config(config_file=fp8_only_config, fp4_layers=None, fp8_layers=[1, 2, 3])
+ with open(output_path) as f:
+ result = yaml.safe_load(f)
+ regex = result["example_fp8_tensor_stat_collection"]["layers"]["layer_name_regex_pattern"]
+ assert re.search(regex, "model.model.layers.1.self_attention.layernorm_qkv")
+ assert re.search(regex, "model.model.layers.3.layernorm_mlp.fc2")
+ assert not re.search(regex, "model.model.layers.4.self_attention.proj")
+
+
+def test_none_layers_disables_matching(fp8_only_config):
+ """None layers should set regex to match nothing."""
+ output_path = update_quant_stats_config(config_file=fp8_only_config, fp4_layers=None, fp8_layers=None)
+ with open(output_path) as f:
+ result = yaml.safe_load(f)
+ regex = result["example_fp8_tensor_stat_collection"]["layers"]["layer_name_regex_pattern"]
+ assert "DISABLED" in regex
+
+
+def test_fp4_section_disabled_fp8_still_updated(fp4_fp8_config):
+ """FP4 stats section should be disabled (not yet supported), FP8 should still be updated."""
+ output_path = update_quant_stats_config(config_file=fp4_fp8_config, fp4_layers=[1, 2, 3], fp8_layers=[4, 5, 6])
+ with open(output_path) as f:
+ result = yaml.safe_load(f)
+
+ # FP4 section should be disabled
+ assert result["example_fp4_tensor_stat_collection"]["enabled"] is False
+
+ # FP8 regex should still match layers 4-6
+ fp8_regex = result["example_fp8_tensor_stat_collection"]["layers"]["layer_name_regex_pattern"]
+ assert re.search(fp8_regex, "model.model.layers.5.self_attention.proj")
+ assert not re.search(fp8_regex, "model.model.layers.2.self_attention.proj")
+
+
+def test_original_file_not_modified(fp8_only_config):
+ """update_quant_stats_config should write to a temp file, not modify the original."""
+ with open(fp8_only_config) as f:
+ original_content = f.read()
+
+ output_path = update_quant_stats_config(config_file=fp8_only_config, fp4_layers=None, fp8_layers=[1, 2])
+
+ assert output_path != fp8_only_config
+ with open(fp8_only_config) as f:
+ assert f.read() == original_content
+
+
+def test_preserves_other_config_fields(fp8_only_config):
+ """Non-layer fields in the config should be preserved."""
+ output_path = update_quant_stats_config(config_file=fp8_only_config, fp4_layers=None, fp8_layers=[1])
+ with open(output_path) as f:
+ result = yaml.safe_load(f)
+ # The transformer_engine section should still be there
+ assert result["example_fp8_tensor_stat_collection"]["transformer_engine"]["LogFp8TensorStats"]["enabled"] is True
+
+
+def test_missing_section_is_skipped(fp8_only_config):
+ """If fp4 section doesn't exist in config, it should be silently skipped."""
+ # fp8_only_config has no fp4 section -- passing fp4_layers should not error
+ output_path = update_quant_stats_config(config_file=fp8_only_config, fp4_layers=[1, 2], fp8_layers=[3, 4])
+ with open(output_path) as f:
+ result = yaml.safe_load(f)
+ # Only FP8 section should exist and be updated
+ assert "example_fp4_tensor_stat_collection" not in result
+ regex = result["example_fp8_tensor_stat_collection"]["layers"]["layer_name_regex_pattern"]
+ assert re.search(regex, "model.model.layers.3.self_attention.layernorm_qkv")
+
+
+def test_with_real_fp4_config():
+ """Test with the actual fp4_debugging_stats.yaml file."""
+ config_path = Path(__file__).parent.parent / "fp4_debugging_stats.yaml"
+ if not config_path.exists():
+ pytest.skip("fp4_debugging_stats.yaml not found")
+
+ output_path = update_quant_stats_config(config_file=str(config_path), fp4_layers=[1, 2, 3], fp8_layers=[4, 5, 6])
+ with open(output_path) as f:
+ result = yaml.safe_load(f)
+
+ # FP4 section should be disabled (not yet supported in current TE release)
+ assert result["example_fp4_tensor_stat_collection"]["enabled"] is False
+
+ # FP8 section should still be updated and working
+ fp8_regex = result["example_fp8_tensor_stat_collection"]["layers"]["layer_name_regex_pattern"]
+ assert re.search(fp8_regex, "model.model.layers.5.self_attention.proj")
+ assert not re.search(fp8_regex, "model.model.layers.2.self_attention.proj")
diff --git a/bionemo-recipes/recipes/llama3_native_te/tests/test_train.py b/bionemo-recipes/recipes/llama3_native_te/tests/test_train.py
index 89e85068de..be8bd48fe3 100644
--- a/bionemo-recipes/recipes/llama3_native_te/tests/test_train.py
+++ b/bionemo-recipes/recipes/llama3_native_te/tests/test_train.py
@@ -452,8 +452,8 @@ def test_sanity_ddp_fp8_stats_logging(tmp_path, recipe_path):
f"checkpoint.ckpt_dir={tmp_path}",
"+dataset.pad_sequences_to_be_divisible_by=16",
"fp8_config.enabled=true",
- "fp8_stats_config.enabled=true",
- f"fp8_stats_config.fp8_log_dir={fp8_log_dir}",
+ "quant_stats_config.enabled=true",
+ f"quant_stats_config.quant_log_dir={fp8_log_dir}",
"num_train_steps=4",
],
)
@@ -493,8 +493,8 @@ def test_sanity_fsdp2_fp8_stats_logging(tmp_path, recipe_path):
f"checkpoint.ckpt_dir={tmp_path}",
"fp8_config.enabled=true",
"+dataset.pad_sequences_to_be_divisible_by=16",
- "fp8_stats_config.enabled=true",
- f"fp8_stats_config.fp8_log_dir={fp8_log_dir}",
+ "quant_stats_config.enabled=true",
+ f"quant_stats_config.quant_log_dir={fp8_log_dir}",
"num_train_steps=4",
],
)
@@ -507,6 +507,65 @@ def test_sanity_fsdp2_fp8_stats_logging(tmp_path, recipe_path):
assert (fp8_log_dir / "rank_0" / "nvdlfw_inspect_statistics_logs" / "nvdlfw_inspect_globalrank-0.log").exists()
+@requires_fp8
+def test_sanity_ddp_fp8_partial_layers_stats_logging(tmp_path, recipe_path):
+ """Test DDP training with layer-wise FP8 stats (layers 1-3 only)."""
+ quant_log_dir = tmp_path / "quant_stats_logs"
+
+ with initialize_config_dir(config_dir=str(recipe_path / "hydra_config"), version_base="1.2"):
+ sanity_config = compose(
+ config_name="L0_sanity",
+ overrides=[
+ f"+wandb_init_args.dir={tmp_path}",
+ f"checkpoint.ckpt_dir={tmp_path}",
+ "+dataset.pad_sequences_to_be_divisible_by=16",
+ "fp8_config.enabled=true",
+ "fp8_layers=[1,2,3]",
+ "quant_stats_config.enabled=true",
+ f"quant_stats_config.quant_log_dir={quant_log_dir}",
+ "num_train_steps=4",
+ ],
+ )
+
+ main_ddp(sanity_config)
+
+ # Verify the log directory structure was created
+ assert quant_log_dir.exists(), "Quant log directory was not created"
+ assert (quant_log_dir / "rank_0").exists(), "rank_0 directory was not created"
+ assert (quant_log_dir / "rank_0" / "nvdlfw_inspect_logs").exists(), "nvdlfw_inspect_logs directory was not created"
+ assert (quant_log_dir / "rank_0" / "nvdlfw_inspect_statistics_logs").exists(), (
+ "nvdlfw_inspect_statistics_logs directory was not created"
+ )
+
+
+@requires_fp8
+def test_sanity_fsdp2_fp8_partial_layers_stats_logging(tmp_path, recipe_path):
+ """Test FSDP2 training with layer-wise FP8 stats (layers 1-3 only)."""
+ quant_log_dir = tmp_path / "quant_stats_logs"
+
+ with initialize_config_dir(config_dir=str(recipe_path / "hydra_config"), version_base="1.2"):
+ sanity_config = compose(
+ config_name="L0_sanity",
+ overrides=[
+ f"+wandb_init_args.dir={tmp_path}",
+ f"checkpoint.ckpt_dir={tmp_path}",
+ "+dataset.pad_sequences_to_be_divisible_by=16",
+ "fp8_config.enabled=true",
+ "fp8_layers=[1,2,3]",
+ "quant_stats_config.enabled=true",
+ f"quant_stats_config.quant_log_dir={quant_log_dir}",
+ "num_train_steps=4",
+ ],
+ )
+
+ main_fsdp2(sanity_config)
+
+ # Verify log structure
+ assert quant_log_dir.exists()
+ assert (quant_log_dir / "rank_0" / "nvdlfw_inspect_logs" / "nvdlfw_inspect_globalrank-0.log").exists()
+ assert (quant_log_dir / "rank_0" / "nvdlfw_inspect_statistics_logs" / "nvdlfw_inspect_globalrank-0.log").exists()
+
+
def run_train_cmd(cmd, recipe_path):
"""Run a training command and check for errors.
diff --git a/bionemo-recipes/recipes/llama3_native_te/train_ddp.py b/bionemo-recipes/recipes/llama3_native_te/train_ddp.py
index 0a25c02940..4ae9c751aa 100644
--- a/bionemo-recipes/recipes/llama3_native_te/train_ddp.py
+++ b/bionemo-recipes/recipes/llama3_native_te/train_ddp.py
@@ -43,9 +43,9 @@
from checkpoint import load_checkpoint_ddp, save_checkpoint_ddp, save_final_model_ddp, should_save_checkpoint
from dataset import create_bshd_dataloader, create_thd_dataloader
from distributed_config import DistributedConfig
-from fp8_debugging import initialize_fp8_debugging
from modeling_llama_te import NVLlamaConfig, NVLlamaForCausalLM
from perf_logger import PerfLogger
+from quantization import initialize_quant_stats_logging, resolve_layer_precision
from scheduler import get_cosine_annealing_schedule_with_warmup
@@ -67,18 +67,12 @@ def main(args: DictConfig) -> float | None:
torch.distributed.init_process_group(backend="nccl", device_id=device)
torch.cuda.set_device(dist_config.local_rank)
- # TE Debug feature logging
- if args.fp8_stats_config.enabled:
- initialize_fp8_debugging(dist_config, **args.fp8_stats_config, fp8_enabled=args.fp8_config.enabled)
+ if args.use_fp32_master_weights:
+ raise ValueError("FP32 master weights are not supported with DDP. Use train_fsdp2.py instead.")
# Create a device mesh for DDP. While this isn't strictly necessary, it mirrors the device mesh we create for FSDP2.
device_mesh = init_device_mesh("cuda", mesh_shape=(dist_config.world_size,), mesh_dim_names=("dp",))
- # --- Model Configuration ---
- fp8_recipe = hydra.utils.get_class(args.fp8_config.fp8_recipe)(
- fp8_format=Format[args.fp8_config.fp8_format], **args.fp8_config.fp8_recipe_kwargs
- )
-
if args.use_te:
config_class = NVLlamaConfig
model_class = NVLlamaForCausalLM
@@ -86,9 +80,40 @@ def main(args: DictConfig) -> float | None:
config_class = LlamaConfig
model_class = LlamaForCausalLM
- # --- Model Initialization ---
+ # --- Model Configuration ---
config = config_class.from_pretrained(args.config_name_or_path, dtype=torch.bfloat16, **args.config_kwargs)
+ # Resolve layer-wise quantization assignments and store on config.
+ layer_precision = resolve_layer_precision(
+ num_layers=config.num_hidden_layers,
+ fp8_enabled=args.fp8_config.enabled,
+ fp4_enabled=args.fp4_config.enabled,
+ fp8_layers=OmegaConf.to_container(args.fp8_layers, resolve=True) if args.fp8_layers is not None else None,
+ fp4_layers=OmegaConf.to_container(args.fp4_layers, resolve=True) if args.fp4_layers is not None else None,
+ )
+ config.layer_precision = layer_precision
+ # TE Debug feature logging (layer-wise quant stats via nvdlfw_inspect).
+ if args.quant_stats_config.enabled:
+ initialize_quant_stats_logging(
+ quant_stats_file=args.quant_stats_config.quant_stats_file,
+ quant_log_dir=args.quant_stats_config.quant_log_dir,
+ rank=dist_config.rank,
+ layer_precision=layer_precision,
+ )
+
+ # Create quantization recipes -- these are only used if FP8/FP4 is enabled in the config.
+ fp8_recipe = None
+ fp4_recipe = None
+ if args.fp8_config.enabled:
+ fp8_recipe = hydra.utils.get_class(args.fp8_config.fp8_recipe)(
+ fp8_format=Format[args.fp8_config.fp8_format], **args.fp8_config.fp8_recipe_kwargs
+ )
+ if args.fp4_config.enabled:
+ fp4_recipe = hydra.utils.get_class(args.fp4_config.fp4_recipe)(
+ fp4_format=Format[args.fp4_config.fp4_format], **args.fp4_config.fp4_recipe_kwargs
+ )
+
+ # --- Model Initialization ---
# Optionally use transformer engine to initialize only fp8 versions of weights by setting
# `fp8_config.quantized_model_init_kwargs.enabled` to `True`, as opposed to using the default where both bfloat16
# and fp8 versions of weights are kept.
@@ -99,8 +124,12 @@ def main(args: DictConfig) -> float | None:
logger.info("Initialized Model:\n%s", model)
+ # Attach quantization recipes to the model (layer precision is already on config).
+ if isinstance(model, NVLlamaForCausalLM):
+ model.model.set_recipes(fp8_recipe=fp8_recipe, fp4_recipe=fp4_recipe)
+
# --- Distributed Wrapping (DDP) ---
- if args.fp8_stats_config.enabled:
+ if args.quant_stats_config.enabled:
debug_api.infer_and_assign_layer_names(model)
model = model.to(device=device)
@@ -161,9 +190,8 @@ def main(args: DictConfig) -> float | None:
micro_step += 1
# DDP requires no_sync to skip all-reduce until the last microbatch in the accumulation window.
with model.no_sync() if micro_step % args.grad_acc_steps != 0 else nullcontext():
- # Forward pass with mixed precision.
- with transformer_engine.pytorch.autocast(enabled=args.fp8_config.enabled, recipe=fp8_recipe):
- outputs = model(**batch)
+ # Forward pass: per-layer quantization autocast is applied inside the model (recipes attached via set_recipes()).
+ outputs = model(**batch)
# Backward pass - scale loss by grad_acc_steps for proper gradient averaging
loss = outputs.loss / args.grad_acc_steps
diff --git a/bionemo-recipes/recipes/llama3_native_te/train_fsdp2.py b/bionemo-recipes/recipes/llama3_native_te/train_fsdp2.py
index 4d88f2e0c0..3588c207f1 100644
--- a/bionemo-recipes/recipes/llama3_native_te/train_fsdp2.py
+++ b/bionemo-recipes/recipes/llama3_native_te/train_fsdp2.py
@@ -34,7 +34,7 @@
import transformer_engine.pytorch
from omegaconf import DictConfig, OmegaConf
from torch.distributed.device_mesh import init_device_mesh
-from torch.distributed.fsdp import fully_shard
+from torch.distributed.fsdp import MixedPrecisionPolicy, fully_shard
from torch.optim import AdamW
from transformer_engine.common.recipe import Format
from transformers.models.llama.configuration_llama import LlamaConfig
@@ -49,9 +49,9 @@
)
from dataset import create_bshd_dataloader, create_thd_dataloader
from distributed_config import DistributedConfig
-from fp8_debugging import initialize_fp8_debugging
from modeling_llama_te import NVLlamaConfig, NVLlamaForCausalLM
from perf_logger import PerfLogger
+from quantization import initialize_quant_stats_logging, resolve_layer_precision
from scheduler import get_cosine_annealing_schedule_with_warmup
@@ -73,17 +73,8 @@ def main(args: DictConfig) -> float | None:
torch.distributed.init_process_group(backend="cpu:gloo,cuda:nccl", device_id=device)
torch.cuda.set_device(dist_config.local_rank)
- # TE Debug feature logging - MUST be done BEFORE FSDP wrapping
- if args.fp8_stats_config.enabled:
- initialize_fp8_debugging(dist_config, **args.fp8_stats_config, fp8_enabled=args.fp8_config.enabled)
-
device_mesh = init_device_mesh("cuda", mesh_shape=(dist_config.world_size,), mesh_dim_names=("dp",))
- # --- Model Configuration ---
- fp8_recipe = hydra.utils.get_class(args.fp8_config.fp8_recipe)(
- fp8_format=Format[args.fp8_config.fp8_format], **args.fp8_config.fp8_recipe_kwargs
- )
-
if args.use_te:
config_class = NVLlamaConfig
model_class = NVLlamaForCausalLM
@@ -91,9 +82,44 @@ def main(args: DictConfig) -> float | None:
config_class = LlamaConfig
model_class = LlamaForCausalLM
- # --- Model Initialization ---
- config = config_class.from_pretrained(args.config_name_or_path, dtype=torch.bfloat16, **args.config_kwargs)
+ # --- Model Configuration ---
+ config = config_class.from_pretrained(
+ args.config_name_or_path,
+ dtype=torch.float32 if args.use_fp32_master_weights else torch.bfloat16,
+ **args.config_kwargs,
+ )
+
+ # Resolve layer-wise quantization assignments and store on config.
+ layer_precision = resolve_layer_precision(
+ num_layers=config.num_hidden_layers,
+ fp8_enabled=args.fp8_config.enabled,
+ fp4_enabled=args.fp4_config.enabled,
+ fp8_layers=OmegaConf.to_container(args.fp8_layers, resolve=True) if args.fp8_layers is not None else None,
+ fp4_layers=OmegaConf.to_container(args.fp4_layers, resolve=True) if args.fp4_layers is not None else None,
+ )
+ config.layer_precision = layer_precision
+ # TE Debug feature logging - MUST be initialized BEFORE FSDP wrapping.
+ if args.quant_stats_config.enabled:
+ initialize_quant_stats_logging(
+ quant_stats_file=args.quant_stats_config.quant_stats_file,
+ quant_log_dir=args.quant_stats_config.quant_log_dir,
+ rank=dist_config.rank,
+ layer_precision=layer_precision,
+ )
+ # Create quantization recipes -- these are only used if FP8/FP4 is enabled in the config.
+ fp8_recipe = None
+ fp4_recipe = None
+ if args.fp8_config.enabled:
+ fp8_recipe = hydra.utils.get_class(args.fp8_config.fp8_recipe)(
+ fp8_format=Format[args.fp8_config.fp8_format], **args.fp8_config.fp8_recipe_kwargs
+ )
+ if args.fp4_config.enabled:
+ fp4_recipe = hydra.utils.get_class(args.fp4_config.fp4_recipe)(
+ fp4_format=Format[args.fp4_config.fp4_format], **args.fp4_config.fp4_recipe_kwargs
+ )
+
+ # --- Model Initialization ---
# Optionally use transformer engine to initialize only fp8 versions of weights by setting
# `fp8_config.quantized_model_init_kwargs.enabled` to `True`, as opposed to using the default where both bfloat16
# and fp8 versions of weights are kept.
@@ -108,10 +134,24 @@ def main(args: DictConfig) -> float | None:
logger.info("Initialized Model:\n%s", model)
# --- Distributed Wrapping (FSDP2) ---
+ if args.use_fp32_master_weights:
+ mp_policy = MixedPrecisionPolicy(
+ param_dtype=torch.bfloat16,
+ reduce_dtype=torch.float32,
+ output_dtype=torch.bfloat16,
+ cast_forward_inputs=False,
+ )
+ else:
+ mp_policy = MixedPrecisionPolicy()
+
# Each decoder layer should be individually sharded before sharding the full model.
for layer in model.model.layers:
- fully_shard(layer, mesh=device_mesh["dp"])
- fully_shard(model, mesh=device_mesh["dp"])
+ fully_shard(layer, mesh=device_mesh["dp"], mp_policy=mp_policy)
+ fully_shard(model, mesh=device_mesh["dp"], mp_policy=mp_policy)
+
+ # Attach quantization recipes to the model (layer precision is already on config).
+ if isinstance(model, NVLlamaForCausalLM):
+ model.model.set_recipes(fp8_recipe=fp8_recipe, fp4_recipe=fp4_recipe)
# If we're using meta device, we need to move sharded weights to the cuda device and initialize the parameters.
if args.use_meta_device and isinstance(model, NVLlamaForCausalLM):
@@ -123,7 +163,7 @@ def main(args: DictConfig) -> float | None:
model.apply(model._init_weights)
# Assign names to layers so debug API can identify them
- if args.fp8_stats_config.enabled:
+ if args.quant_stats_config.enabled:
debug_api.infer_and_assign_layer_names(model)
# --- Optimizer & Scheduler ---
@@ -175,9 +215,8 @@ def main(args: DictConfig) -> float | None:
micro_step += 1
- # Forward pass with mixed precision.
- with transformer_engine.pytorch.autocast(enabled=args.fp8_config.enabled, recipe=fp8_recipe):
- outputs = model(**batch)
+ # Forward pass: per-layer quantization autocast is applied inside the model (recipes attached via set_recipes()).
+ outputs = model(**batch)
# Backward pass - scale loss by grad_acc_steps for proper gradient averaging
loss = outputs.loss / args.grad_acc_steps
diff --git a/bionemo-recipes/recipes/llama3_native_te/train_fsdp2_cp.py b/bionemo-recipes/recipes/llama3_native_te/train_fsdp2_cp.py
index 06fb6630ba..d67a0b05cb 100644
--- a/bionemo-recipes/recipes/llama3_native_te/train_fsdp2_cp.py
+++ b/bionemo-recipes/recipes/llama3_native_te/train_fsdp2_cp.py
@@ -29,12 +29,13 @@
from pathlib import Path
import hydra
+import nvdlfw_inspect.api as debug_api
import nvtx
import torch
import transformer_engine.pytorch
from omegaconf import DictConfig, OmegaConf
from torch.distributed.device_mesh import init_device_mesh
-from torch.distributed.fsdp import fully_shard
+from torch.distributed.fsdp import MixedPrecisionPolicy, fully_shard
from torch.optim import AdamW
from transformer_engine.common.recipe import Format
@@ -50,6 +51,7 @@
from distributed_config import DistributedConfig
from modeling_llama_te import NVLlamaConfig, NVLlamaForCausalLM
from perf_logger import PerfLogger
+from quantization import initialize_quant_stats_logging, resolve_layer_precision
from scheduler import get_cosine_annealing_schedule_with_warmup
@@ -79,13 +81,43 @@ def main(args: DictConfig) -> float | None:
logger.info("Created device mesh: %s", device_mesh)
# --- Model Configuration ---
- fp8_recipe = hydra.utils.get_class(args.fp8_config.fp8_recipe)(
- fp8_format=Format[args.fp8_config.fp8_format], **args.fp8_config.fp8_recipe_kwargs
+ config = NVLlamaConfig.from_pretrained(
+ args.config_name_or_path,
+ dtype=torch.float32 if args.use_fp32_master_weights else torch.bfloat16,
+ **args.config_kwargs,
)
- # --- Model Initialization ---
- config = NVLlamaConfig.from_pretrained(args.config_name_or_path, dtype=torch.bfloat16, **args.config_kwargs)
+ # Resolve layer-wise quantization assignments and store on config.
+ layer_precision = resolve_layer_precision(
+ num_layers=config.num_hidden_layers,
+ fp8_enabled=args.fp8_config.enabled,
+ fp4_enabled=args.fp4_config.enabled,
+ fp8_layers=OmegaConf.to_container(args.fp8_layers, resolve=True) if args.fp8_layers is not None else None,
+ fp4_layers=OmegaConf.to_container(args.fp4_layers, resolve=True) if args.fp4_layers is not None else None,
+ )
+ config.layer_precision = layer_precision
+ # TE Debug feature logging - must be initialized BEFORE FSDP wrapping.
+ if args.quant_stats_config.enabled:
+ initialize_quant_stats_logging(
+ quant_stats_file=args.quant_stats_config.quant_stats_file,
+ quant_log_dir=args.quant_stats_config.quant_log_dir,
+ rank=dist_config.rank,
+ layer_precision=layer_precision,
+ )
+
+ # Create quantization recipes -- these are only used if FP8/FP4 is enabled in the config.
+ fp8_recipe = None
+ fp4_recipe = None
+ if args.fp8_config.enabled:
+ fp8_recipe = hydra.utils.get_class(args.fp8_config.fp8_recipe)(
+ fp8_format=Format[args.fp8_config.fp8_format], **args.fp8_config.fp8_recipe_kwargs
+ )
+ if args.fp4_config.enabled:
+ fp4_recipe = hydra.utils.get_class(args.fp4_config.fp4_recipe)(
+ fp4_format=Format[args.fp4_config.fp4_format], **args.fp4_config.fp4_recipe_kwargs
+ )
+ # --- Model Initialization ---
# Optionally use transformer engine to initialize only fp8 versions of weights by setting
# `fp8_config.quantized_model_init_kwargs.enabled` to `True`, as opposed to using the default where both bfloat16
# and fp8 versions of weights are kept.
@@ -102,11 +134,21 @@ def main(args: DictConfig) -> float | None:
# --- Distributed Wrapping (FSDP2 + CP) ---
cp_dp_mesh = device_mesh["dp", "cp"]._flatten(mesh_dim_name="dp_shard_cp")
+ if args.use_fp32_master_weights:
+ mp_policy = MixedPrecisionPolicy(
+ param_dtype=torch.bfloat16,
+ reduce_dtype=torch.float32,
+ output_dtype=torch.bfloat16,
+ cast_forward_inputs=False,
+ )
+ else:
+ mp_policy = MixedPrecisionPolicy()
+
# Shard the transformer layers with FSDP. For Llama3, the transformer stack is in model.model.layers.
# Each decoder layer should be individually sharded before sharding the full model.
for layer in model.model.layers:
- fully_shard(layer, mesh=cp_dp_mesh)
- fully_shard(model, mesh=cp_dp_mesh)
+ fully_shard(layer, mesh=cp_dp_mesh, mp_policy=mp_policy)
+ fully_shard(model, mesh=cp_dp_mesh, mp_policy=mp_policy)
# Attach the CP group to the model.
for layer in model.model.layers:
@@ -116,10 +158,17 @@ def main(args: DictConfig) -> float | None:
torch.cuda.Stream(),
)
+ # Attach quantization recipes to the model (layer precision is already on config).
+ model.model.set_recipes(fp8_recipe=fp8_recipe, fp4_recipe=fp4_recipe)
+
if args.use_meta_device:
# TE layers require special handling to initialize the weights from the meta device.
model.init_empty_weights()
+ # Assign names to layers so debug API can identify them
+ if args.quant_stats_config.enabled:
+ debug_api.infer_and_assign_layer_names(model)
+
# --- Optimizer & Scheduler ---
# Convert OmegaConf to regular dict to avoid serialization issues (BIONEMO-2873).
optimizer = AdamW(model.parameters(), **OmegaConf.to_container(args.adamw_kwargs, resolve=True)) # type: ignore
@@ -193,10 +242,9 @@ def main(args: DictConfig) -> float | None:
micro_step += 1
- # Forward pass with mixed precision.
+ # Forward pass: per-layer quantization autocast is applied inside the model (recipes attached via set_recipes()).
with nvtx.annotate("Forward pass", color="green"):
- with transformer_engine.pytorch.autocast(enabled=args.fp8_config.enabled, recipe=fp8_recipe):
- outputs = model(**batch)
+ outputs = model(**batch)
# Backward pass - scale loss by grad_acc_steps for proper gradient averaging
loss = outputs.loss / args.grad_acc_steps
diff --git a/docs/docs/assets/images/recipes/lingua-1b-loss-curve.png b/docs/docs/assets/images/llama3/lingua-1b-loss-curve.png
similarity index 100%
rename from docs/docs/assets/images/recipes/lingua-1b-loss-curve.png
rename to docs/docs/assets/images/llama3/lingua-1b-loss-curve.png
diff --git a/docs/docs/assets/images/recipes/lingua-1b-step-time.png b/docs/docs/assets/images/llama3/lingua-1b-step-time.png
similarity index 100%
rename from docs/docs/assets/images/recipes/lingua-1b-step-time.png
rename to docs/docs/assets/images/llama3/lingua-1b-step-time.png
diff --git a/docs/docs/assets/images/llama3/llama3_1b_fsdp2_tflops.png b/docs/docs/assets/images/llama3/llama3_1b_fsdp2_tflops.png
new file mode 100644
index 0000000000..daed0ce0ba
Binary files /dev/null and b/docs/docs/assets/images/llama3/llama3_1b_fsdp2_tflops.png differ
diff --git a/docs/docs/assets/images/llama3/llama3_8gpu_tflops.png b/docs/docs/assets/images/llama3/llama3_8gpu_tflops.png
new file mode 100644
index 0000000000..a304fa29b7
Binary files /dev/null and b/docs/docs/assets/images/llama3/llama3_8gpu_tflops.png differ