Skip to content

Commit d0a738c

Browse files
committed
starting on dcp tests
Signed-off-by: Peter St. John <pstjohn@nvidia.com>
1 parent 470e10d commit d0a738c

File tree

11 files changed

+950
-9
lines changed

11 files changed

+950
-9
lines changed

.devcontainer/recipes/Dockerfile

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,12 @@ FROM nvcr.io/nvidia/pytorch:26.02-py3
66
# Remove once bug has been addressed in the nvidia/pytorch container.
77
RUN rm -f /usr/local/lib/python*/dist-packages/transformer_engine-*.dist-info/direct_url.json
88

9+
RUN --mount=type=cache,target=/var/cache/apt \
10+
--mount=type=cache,target=/var/lib/apt \
11+
apt-get update && \
12+
DEBIAN_FRONTEND=noninteractive apt-get install -y tmux && \
13+
rm -rf /var/lib/apt/lists/*
14+
915
RUN --mount=type=cache,target=/root/.cache/pip \
1016
--mount=type=bind,source=requirements.txt,target=/workspace/requirements.txt \
1117
PIP_CONSTRAINT= pip install -r /workspace/requirements.txt

bionemo-recipes/models/esm2/tests/common/fixtures.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,26 @@ def fp8_recipe(request):
102102
return request.param
103103

104104

105+
# Map from a CLI-passable recipe name (the class name) to a zero-argument
# factory that builds a fresh recipe instance.  Used by recipe_from_name so a
# recipe can be round-tripped through a subprocess command line as a string.
# NOTE(review): NVFP4BlockScaling disables RHT and stochastic rounding —
# presumably to make repeated forward passes deterministic in tests; confirm.
RECIPE_NAME_TO_FACTORY = {
    "DelayedScaling": recipe_module.DelayedScaling,
    "Float8CurrentScaling": recipe_module.Float8CurrentScaling,
    "Float8BlockScaling": recipe_module.Float8BlockScaling,
    "MXFP8BlockScaling": recipe_module.MXFP8BlockScaling,
    "NVFP4BlockScaling": lambda: recipe_module.NVFP4BlockScaling(disable_rht=True, disable_stochastic_rounding=True),
}
112+
113+
114+
def recipe_to_name(recipe):
    """Convert a recipe instance to its CLI-passable string name.

    The name is simply the recipe's class name, which matches the keys of
    RECIPE_NAME_TO_FACTORY.
    """
    recipe_cls = type(recipe)
    return recipe_cls.__name__
117+
118+
119+
def recipe_from_name(name):
    """Reconstruct a recipe instance from its CLI-passable string name.

    Args:
        name: A key of ``RECIPE_NAME_TO_FACTORY`` (e.g. "DelayedScaling").

    Returns:
        A freshly-constructed recipe instance.

    Raises:
        KeyError: If ``name`` is not a known recipe name.  The message lists
            the valid names, instead of the bare key the original lookup gave.
    """
    try:
        factory = RECIPE_NAME_TO_FACTORY[name]
    except KeyError:
        # Same exception type as a plain dict lookup (backward compatible),
        # but with an actionable message for worker-subprocess failures.
        raise KeyError(
            f"Unknown FP8 recipe name {name!r}; expected one of {sorted(RECIPE_NAME_TO_FACTORY)}"
        ) from None
    return factory()
123+
124+
105125
@pytest.fixture(params=["bshd", "thd"])
106126
def input_format(request):
107127
"""Fixture to parametrize the input format."""
Lines changed: 216 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,216 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: LicenseRef-Apache2
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
"""Worker script for distributed DCP (Distributed Checkpoint) tests.
17+
18+
Launched by torchrun from BaseModelTest.test_dcp_output_parity / test_dcp_output_parity_fp8_init.
19+
Verifies that a model sharded with FSDP2 produces identical outputs after a DCP save/load round-trip.
20+
"""
21+
22+
import argparse
23+
import importlib.util
24+
import os
25+
import shutil
26+
import sys
27+
import tempfile
28+
from pathlib import Path
29+
30+
import torch
31+
import torch.distributed as dist
32+
import torch.distributed.checkpoint as dcp
33+
import transformer_engine.pytorch
34+
from torch.distributed.device_mesh import init_device_mesh
35+
from torch.distributed.fsdp import fully_shard
36+
from transformers import set_seed
37+
38+
39+
def _setup_sys_path():
40+
"""Add model root and tests directory to sys.path so model/test imports work."""
41+
script_dir = Path(__file__).resolve().parent # tests/common/
42+
tests_dir = script_dir.parent # tests/
43+
model_root = tests_dir.parent # model root (e.g., models/esm2/)
44+
for p in [str(model_root), str(tests_dir)]:
45+
if p not in sys.path:
46+
sys.path.insert(0, p)
47+
48+
49+
def _load_tester_class(tester_file, class_name):
50+
"""Dynamically load a tester class from a file path."""
51+
# Ensure the tester file's directory tree is importable
52+
tester_dir = str(Path(tester_file).parent)
53+
tester_parent = str(Path(tester_file).parent.parent)
54+
for p in [tester_parent, tester_dir]:
55+
if p not in sys.path:
56+
sys.path.insert(0, p)
57+
58+
spec = importlib.util.spec_from_file_location("_dcp_tester_module", tester_file)
59+
module = importlib.util.module_from_spec(spec)
60+
spec.loader.exec_module(module)
61+
return getattr(module, class_name)
62+
63+
64+
def _build_and_shard_model(tester, config, recipe, device, device_mesh):
    """Build a model (optionally under FP8 quantized_model_init), shard with FSDP2, and move to device."""
    model_cls = tester.get_model_class()

    if recipe is None:
        model = model_cls(config)
    else:
        # quantized_model_init allocates FP8 weights at construction time.
        with transformer_engine.pytorch.quantized_model_init(recipe=recipe):
            model = model_cls(config)

    # Shard every transformer layer first, then the root module.
    for transformer_block in tester.get_layer_path(model):
        fully_shard(transformer_block, mesh=device_mesh)
    fully_shard(model, mesh=device_mesh)

    model.to(device)
    return model
81+
82+
83+
def _forward(model, input_data, recipe):
84+
"""Run a forward pass and return the model outputs."""
85+
if recipe is not None:
86+
# torch.autocast is needed when model was built with quantized_model_init
87+
# (weights are FP8, non-quantized ops need bf16 casting)
88+
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
89+
with transformer_engine.pytorch.autocast(recipe=recipe):
90+
return model(**input_data)
91+
else:
92+
return model(**input_data)
93+
94+
95+
def _train_one_step(model, input_data, recipe, lr=1e-4):
    """Run one training step (forward + backward + optimizer step) and return detached logits.

    A fresh Adam optimizer is built per call; this step only exists to perturb
    the weights away from their initialization before checkpointing.
    """
    model.train()
    opt = torch.optim.Adam(model.parameters(), lr=lr)
    opt.zero_grad()

    logits = _forward(model, input_data, recipe).logits
    # Sum of logits as a trivial scalar loss — we only need gradients to flow.
    logits.sum().backward()
    opt.step()

    return logits.detach().clone()
107+
108+
109+
def _run_eval_forward(model, input_data, recipe):
    """Run an eval-mode forward pass under no_grad and return detached logits."""
    model.eval()
    with torch.no_grad():
        eval_outputs = _forward(model, input_data, recipe)
        return eval_outputs.logits.detach().clone()
115+
116+
117+
def run_dcp_output_parity(tester, fp8_recipe_name=None, seed=42):
    """Core DCP round-trip test: build → train → save → rebuild → load → eval → compare.

    Args:
        tester: Tester instance providing create_test_config / get_test_input_data /
            get_model_class / get_layer_path / get_tolerances.
        fp8_recipe_name: Optional CLI-passable recipe name; None disables FP8.
        seed: Random seed applied before config creation and each model build so
            both models start from identical initialization.
    """
    # Deferred import: _setup_sys_path() must have run before this resolves.
    from tests.common.fixtures import recipe_from_name

    rank = dist.get_rank()
    local_rank = int(os.environ["LOCAL_RANK"])
    world_size = dist.get_world_size()
    device = f"cuda:{local_rank}"
    torch.cuda.set_device(local_rank)

    # 1-D mesh across all ranks for FSDP2 sharding.
    device_mesh = init_device_mesh("cuda", mesh_shape=(world_size,))

    # Resolve FP8 recipe
    recipe = recipe_from_name(fp8_recipe_name) if fp8_recipe_name else None

    # Build config
    set_seed(seed)
    config = tester.create_test_config(dtype=torch.bfloat16, attn_input_format="bshd")

    # Prepare input data
    input_data = tester.get_test_input_data("bshd", pad_to_multiple_of=32)

    # --- Model A: build, shard, train one step, then eval ---
    # The training step perturbs weights so the checkpoint is non-trivial.
    set_seed(seed)
    model_a = _build_and_shard_model(tester, config, recipe, device, device_mesh)
    _train_one_step(model_a, input_data, recipe)
    logits_a = _run_eval_forward(model_a, input_data, recipe)

    # --- DCP Save ---
    # Rank 0 creates temp dir, broadcast path to all ranks
    # (all ranks must write shards into the same directory).
    if rank == 0:
        tmp_dir = tempfile.mkdtemp(prefix="dcp_test_")
    else:
        tmp_dir = None
    tmp_dir_list = [tmp_dir]
    dist.broadcast_object_list(tmp_dir_list, src=0)
    tmp_dir = tmp_dir_list[0]

    checkpoint_path = os.path.join(tmp_dir, "checkpoint")

    state_dict_a = {"model": model_a.state_dict()}
    dcp.save(state_dict_a, checkpoint_id=checkpoint_path)

    # Ensure every rank finished writing before anyone reads the checkpoint.
    dist.barrier()

    # Free model_a
    del model_a, state_dict_a
    torch.cuda.empty_cache()

    # --- Model B: build fresh, shard, load, eval ---
    # Re-seed so model B's initialization matches model A's pre-training state;
    # the loaded checkpoint should then fully determine its weights.
    set_seed(seed)
    model_b = _build_and_shard_model(tester, config, recipe, device, device_mesh)

    state_dict_b = {"model": model_b.state_dict()}
    dcp.load(state_dict_b, checkpoint_id=checkpoint_path)
    # NOTE(review): strict=False tolerates missing/unexpected keys — presumably
    # for FP8 scale buffers absent from the checkpoint; confirm nothing
    # essential is silently skipped.
    model_b.load_state_dict(state_dict_b["model"], strict=False)

    logits_b = _run_eval_forward(model_b, input_data, recipe)

    # --- Compare ---
    tolerances = tester.get_tolerances()
    torch.testing.assert_close(
        logits_a,
        logits_b,
        atol=tolerances.dcp_logits_atol,
        rtol=tolerances.dcp_logits_rtol,
        msg=lambda x: f"DCP round-trip logits mismatch: {x}",
    )

    # Cleanup
    del model_b, state_dict_b
    torch.cuda.empty_cache()
    # Barrier before rank 0 deletes the directory, so no rank is still reading.
    dist.barrier()

    if rank == 0:
        shutil.rmtree(tmp_dir, ignore_errors=True)

    print(f"[Rank {rank}] DCP output parity test PASSED (fp8_recipe={fp8_recipe_name})")
195+
196+
197+
if __name__ == "__main__":
    _setup_sys_path()

    # CLI mirrors what BaseModelTest._run_dcp_worker passes via torchrun.
    arg_parser = argparse.ArgumentParser(description="DCP distributed test worker")
    arg_parser.add_argument(
        "--tester-file", required=True, help="Absolute path to the test file containing the tester class"
    )
    arg_parser.add_argument("--tester-class", required=True, help="Name of the tester class (e.g., TestESM2Model)")
    arg_parser.add_argument("--fp8-recipe", default=None, help="FP8 recipe name (e.g., DelayedScaling)")
    arg_parser.add_argument("--seed", type=int, default=42, help="Random seed")
    cli_args = arg_parser.parse_args()

    # NCCL backend: one process per GPU, launched by torchrun.
    dist.init_process_group(backend="nccl")

    try:
        loaded_cls = _load_tester_class(cli_args.tester_file, cli_args.tester_class)
        run_dcp_output_parity(loaded_cls(), fp8_recipe_name=cli_args.fp8_recipe, seed=cli_args.seed)
    finally:
        # Always tear down the process group, even if the test body raised.
        dist.destroy_process_group()

bionemo-recipes/models/esm2/tests/common/test_modeling_common.py

Lines changed: 78 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717

1818
import fnmatch
1919
import gc
20+
import os
21+
import subprocess
2022
from abc import ABC, abstractmethod
2123
from dataclasses import dataclass
2224
from pathlib import Path
@@ -39,6 +41,12 @@
3941
HAS_DATA_CENTER_GPU = False
4042

4143

44+
# Skip marker for tests that launch a 2-process torchrun job; applied to the
# multi-GPU DCP tests so they no-op on single-GPU CI machines.
_requires_multi_gpu = pytest.mark.skipif(
    not torch.cuda.is_available() or torch.cuda.device_count() < 2,
    reason="Test requires at least 2 GPUs",
)
48+
49+
4250
@dataclass
4351
class TestTolerances:
4452
"""Model-specific test tolerances for numerical comparisons."""
@@ -65,6 +73,10 @@ class TestTolerances:
6573
fp8_logits_atol: float = 5.0
6674
fp8_logits_rtol: float = 0.1
6775

76+
# DCP (distributed checkpoint) round-trip tolerances
77+
dcp_logits_atol: float = 0.0
78+
dcp_logits_rtol: float = 0.0
79+
6880
# Meta device initialization tolerances
6981
init_mean_atol: float = 1e-3
7082
init_mean_rtol: float = 1e-4
@@ -1099,4 +1111,69 @@ def test_generate_with_cache_beam_search(self):
10991111
assert output_ids.shape[0] == 2
11001112
assert output_ids.shape[1] > inputs["input_ids"].shape[1]
11011113

1102-
# TODO: add multi-GPU tests, e.g., meta-device init after fully_shard, cp tests, etc.
1114+
# ==================== Distributed Checkpoint (DCP) Tests ====================
1115+
1116+
def _get_dcp_worker_script_path(self) -> str:
1117+
"""Return the absolute path to the run_distributed_dcp.py worker script."""
1118+
return str(Path(__file__).resolve().parent / "run_distributed_dcp.py")
1119+
1120+
def _get_tester_file_and_class(self):
1121+
"""Return (file_path, class_name) for dynamic loading in the worker subprocess."""
1122+
import inspect
1123+
1124+
return os.path.abspath(inspect.getfile(type(self))), type(self).__name__
1125+
1126+
def _run_dcp_worker(self, unused_tcp_port, fp8_recipe_name=None, nproc_per_node=2):
    """Launch the DCP worker script via torchrun and fail the test if it exits nonzero.

    The worker's stdout/stderr are captured and only echoed on failure, so
    passing runs stay quiet in the pytest output.
    """
    tester_file, class_name = self._get_tester_file_and_class()
    worker_script = self._get_dcp_worker_script_path()

    cmd = [
        "torchrun",
        f"--nproc_per_node={nproc_per_node}",
        "--rdzv-backend=c10d",
        f"--rdzv-endpoint=localhost:{unused_tcp_port}",
        worker_script,
        "--tester-file",
        tester_file,
        "--tester-class",
        class_name,
    ]
    if fp8_recipe_name is not None:
        cmd += ["--fp8-recipe", fp8_recipe_name]

    # capture_output=True is the documented shorthand for piping both streams.
    result = subprocess.run(cmd, check=False, text=True, capture_output=True, timeout=300)
    if result.returncode != 0:
        print(f"STDOUT:\n{result.stdout}")
        print(f"STDERR:\n{result.stderr}")
        pytest.fail(f"DCP worker failed with exit code {result.returncode}")
1158+
1159+
def test_dcp_output_parity_single_gpu(self, unused_tcp_port):
    """FSDP2 + DCP save/load round-trip on a single GPU must preserve outputs."""
    self._run_dcp_worker(unused_tcp_port, fp8_recipe_name=None, nproc_per_node=1)
1162+
1163+
def test_dcp_output_parity_fp8_init_single_gpu(self, fp8_recipe, unused_tcp_port):
    """FSDP2 + DCP save/load round-trip with FP8 quantized_model_init on a single GPU."""
    from .fixtures import recipe_to_name

    recipe_name = recipe_to_name(fp8_recipe)
    self._run_dcp_worker(unused_tcp_port, fp8_recipe_name=recipe_name, nproc_per_node=1)
1168+
1169+
@_requires_multi_gpu
def test_dcp_output_parity(self, unused_tcp_port):
    """A model sharded with FSDP2 across two ranks must produce identical outputs after DCP save/load."""
    self._run_dcp_worker(unused_tcp_port, fp8_recipe_name=None, nproc_per_node=2)
1173+
1174+
@_requires_multi_gpu
def test_dcp_output_parity_fp8_init(self, fp8_recipe, unused_tcp_port):
    """DCP save/load round-trip with FP8 quantized_model_init across two ranks."""
    from .fixtures import recipe_to_name

    recipe_name = recipe_to_name(fp8_recipe)
    self._run_dcp_worker(unused_tcp_port, fp8_recipe_name=recipe_name, nproc_per_node=2)

bionemo-recipes/models/esm2/tests/test_distributed_fp8.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -34,9 +34,7 @@ def requires_fp8(func):
3434
)
3535

3636

37-
@pytest.mark.parametrize(
38-
"strategy", ["ddp", "fsdp2", pytest.param("mfsdp", marks=pytest.mark.xfail(reason="BIONEMO-2999"))]
39-
)
37+
@pytest.mark.parametrize("strategy", ["ddp", "fsdp2", "mfsdp"])
4038
@requires_fp8
4139
def test_single_process_attaches_correct_fp8_recipe(strategy, unused_tcp_port):
4240
cmd = [
@@ -63,9 +61,7 @@ def test_single_process_attaches_correct_fp8_recipe(strategy, unused_tcp_port):
6361
pytest.fail(f"Command failed with exit code {result.returncode}")
6462

6563

66-
@pytest.mark.parametrize(
67-
"strategy", ["ddp", "fsdp2", pytest.param("mfsdp", marks=pytest.mark.xfail(reason="BIONEMO-2999"))]
68-
)
64+
@pytest.mark.parametrize("strategy", ["ddp", "fsdp2", "mfsdp"])
6965
@requires_fp8
7066
@requires_multi_gpu
7167
def test_multi_process_fp8_recipes_are_synced(strategy, unused_tcp_port):

0 commit comments

Comments
 (0)