Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 13 additions & 3 deletions examples/conversion/compare_hf_and_megatron/compare.py
Original file line number Diff line number Diff line change
Expand Up @@ -851,7 +851,8 @@ def compare_models_one_step(args) -> None:
torch.distributed.broadcast(megatron_next_token, get_last_rank())


if __name__ == "__main__":
def build_parser() -> argparse.ArgumentParser:
"""Build the CLI parser."""
parser = argparse.ArgumentParser(description="Compare HuggingFace and Megatron models")
parser.add_argument(
"--hf_model_path",
Expand Down Expand Up @@ -898,9 +899,18 @@ def compare_models_one_step(args) -> None:
default=None,
help="Directory where the exported HF model will be saved during round-trip. Defaults to current directory.",
)
parser.add_argument("--trust_remote_code", action="store_true", help="if trust_remote_code")
parser.add_argument(
"--trust_remote_code",
"--trust-remote-code",
dest="trust_remote_code",
action="store_true",
help="Allow custom model code execution.",
)
return parser


args = parser.parse_args()
if __name__ == "__main__":
args = build_parser().parse_args()

compare_models_one_step(args)

Expand Down
1 change: 0 additions & 1 deletion examples/conversion/convert_checkpoints_multi_gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,6 @@ def export_megatron_to_hf(
mp_overrides=mp_overrides,
wrap_with_ddp=False,
)
megatron_model = [m.cuda() for m in megatron_model]

print_rank_0(f"Saving HuggingFace model to: {hf_path}")
bridge.save_hf_pretrained(
Expand Down
121 changes: 93 additions & 28 deletions examples/models/deepseek_v4/conversion.sh
Original file line number Diff line number Diff line change
Expand Up @@ -24,16 +24,24 @@
# MODEL_VARIANT: one of DeepSeek-V4-Flash, DeepSeek-V4-Flash-Base,
# DeepSeek-V4-Pro, DeepSeek-V4-Pro-Base
# (default: DeepSeek-V4-Flash-Base)
# HF_MODEL_ID: HuggingFace model ID or local path (default: deepseek-ai/${MODEL_VARIANT})
# EP: expert-parallel size (default: 4 for Flash, 8 for Pro)
# PP: pipeline-parallel size (default: 1 for Flash, 4 for Pro)
# NNODES, NPROC_PER_NODE, NODE_RANK, MASTER_ADDR, MASTER_PORT: torchrun launch overrides
# UV_RUN_ARGS: extra arguments passed after `uv run` (for example: "--active --no-sync")
# RUN_COMPARE: set to 1 to run the HF/Megatron logits comparison (default: 0)
# RUN_ROUNDTRIP: set to 1 to run the second import/export round-trip (default: 0)
#
# Defaults below are for GB200 (192 GB). For H100 (80 GB) configs, see README.md.

set -xeuo pipefail

WORKSPACE=${WORKSPACE:-/workspace}
MODEL_VARIANT=${MODEL_VARIANT:-DeepSeek-V4-Flash-Base}
HF_MODEL_ID="deepseek-ai/${MODEL_VARIANT}"
HF_MODEL_ID=${HF_MODEL_ID:-deepseek-ai/${MODEL_VARIANT}}
RUN_COMPARE=${RUN_COMPARE:-0}
RUN_ROUNDTRIP=${RUN_ROUNDTRIP:-0}
read -r -a UV_RUN_ARGS_ARRAY <<< "${UV_RUN_ARGS:-}"

if [[ -z "${EP:-}" ]]; then
case "${MODEL_VARIANT}" in
Expand All @@ -48,13 +56,66 @@ if [[ -z "${PP:-}" ]]; then
esac
fi
TP=1
WORLD_SIZE=$((TP * PP * EP))

# Print the first hostname described by a SLURM nodelist expression.
#
# Arguments:
#   $1 - nodelist, e.g. "node7", "a01,a02", "node[001-004,010]" or
#        "login1,node[01-04]".
# Outputs:
#   The first hostname, written to stdout.
#
# Only the leading entry is inspected. A bracketed leading entry is reduced to
# its prefix plus the first value of the first range inside the brackets; any
# text following the closing bracket is not handled.
_first_slurm_host() {
    local nodelist=$1
    local head prefix entries first

    # Text before the first "[" (the whole string when there is no bracket).
    head="${nodelist%%[*}"

    # When there is no bracket at all, or a comma shows up before the first
    # bracket, the leading entry is a plain hostname. Without the comma check,
    # "login1,node[01-04]" would be turned into "login1,node01".
    if [[ "${nodelist}" != *"["* || "${head}" == *","* ]]; then
        echo "${nodelist%%,*}"
        return
    fi

    prefix="${head}"
    # Keep only the contents of the first bracket pair ...
    entries="${nodelist#*[}"
    entries="${entries%%]*}"
    # ... then its first comma-separated item, and the low end of that range.
    first="${entries%%,*}"
    first="${first%%-*}"
    echo "${prefix}${first}"
}

# Resolve the launch geometry. An explicit NNODES wins, then the SLURM job's
# node count, and finally a single node is assumed.
NNODES=${NNODES:-${SLURM_JOB_NUM_NODES:-1}}
if [[ -z "${NPROC_PER_NODE:-}" ]]; then
# SLURM_GPUS_ON_NODE is only used when it is a plain run of digits.
if [[ -n "${SLURM_GPUS_ON_NODE:-}" && "${SLURM_GPUS_ON_NODE}" =~ ^[0-9]+$ ]]; then
NPROC_PER_NODE=${SLURM_GPUS_ON_NODE}
elif (( NNODES > 1 )); then
# Integer division; a remainder is caught by the consistency check below.
NPROC_PER_NODE=$((WORLD_SIZE / NNODES))
else
NPROC_PER_NODE=${WORLD_SIZE}
fi
fi

# Abort unless the processes launched across all nodes add up to exactly
# WORLD_SIZE.
if (( NPROC_PER_NODE * NNODES != WORLD_SIZE )); then
echo "NPROC_PER_NODE (${NPROC_PER_NODE}) * NNODES (${NNODES}) must equal TP*PP*EP (${WORLD_SIZE})." >&2
exit 1
fi

MASTER_PORT=${MASTER_PORT:-29500}
# The launcher prefix is kept as an array so every argument stays a separate
# word when it is expanded in front of each script invocation.
# NOTE(review): with `set -u` active, expanding an empty UV_RUN_ARGS_ARRAY is
# reported to fail on bash older than 4.4 -- confirm the minimum supported bash.
TORCHRUN=(uv run "${UV_RUN_ARGS_ARRAY[@]}" python -m torch.distributed.run --nproc_per_node "${NPROC_PER_NODE}")
# Rendezvous flags are appended only for multi-node launches.
if (( NNODES > 1 )); then
NODE_RANK=${NODE_RANK:-${SLURM_NODEID:-0}}
if [[ -z "${MASTER_ADDR:-}" ]]; then
# Without an explicit MASTER_ADDR, fall back to the first host of the SLURM
# nodelist; outside SLURM there is nothing to derive it from, so abort.
if [[ -n "${SLURM_NODELIST:-}" ]]; then
MASTER_ADDR=$(_first_slurm_host "${SLURM_NODELIST}")
else
echo "MASTER_ADDR must be set when NNODES=${NNODES}." >&2
exit 1
fi
fi
TORCHRUN+=(--nnodes "${NNODES}" --node_rank "${NODE_RANK}" --master_addr "${MASTER_ADDR}" --master_port "${MASTER_PORT}")
fi

echo "DeepSeek-V4 conversion: MODEL_VARIANT=${MODEL_VARIANT} HF_MODEL_ID=${HF_MODEL_ID}"
echo "Parallelism: TP=${TP} PP=${PP} EP=${EP} WORLD_SIZE=${WORLD_SIZE}"
echo "Launch: NNODES=${NNODES} NPROC_PER_NODE=${NPROC_PER_NODE} NODE_RANK=${NODE_RANK:-0} MASTER_ADDR=${MASTER_ADDR:-local} MASTER_PORT=${MASTER_PORT}"

MEGATRON_DIR="${WORKSPACE}/models/${MODEL_VARIANT}"
EXPORT_DIR="${WORKSPACE}/models/${MODEL_VARIANT}-hf-export"
ITER=iter_0000000

# 1) Import HF -> Megatron (FP8 / MXFP4 dequantised to bfloat16 in-flight)
uv run python -m torch.distributed.run --nproc_per_node=$((PP * EP)) \
"${TORCHRUN[@]}" \
examples/conversion/convert_checkpoints_multi_gpu.py import \
--hf-model "${HF_MODEL_ID}" \
--megatron-path "${MEGATRON_DIR}" \
Expand All @@ -63,16 +124,18 @@ uv run python -m torch.distributed.run --nproc_per_node=$((PP * EP)) \
--trust-remote-code

# 2) Compare HF and Megatron logits on a short prompt
uv run python -m torch.distributed.run --nproc_per_node=$((PP * EP)) \
examples/conversion/compare_hf_and_megatron/compare.py \
--hf_model_path "${HF_MODEL_ID}" \
--megatron_model_path "${MEGATRON_DIR}" \
--prompt "Hello, how are you?" \
--tp ${TP} --pp ${PP} --ep ${EP} \
--trust-remote-code
if [[ "${RUN_COMPARE}" == "1" ]]; then
"${TORCHRUN[@]}" \
examples/conversion/compare_hf_and_megatron/compare.py \
--hf_model_path "${HF_MODEL_ID}" \
--megatron_model_path "${MEGATRON_DIR}" \
--prompt "Hello, how are you?" \
--tp ${TP} --pp ${PP} --ep ${EP} \
--trust-remote-code
fi

# 3) Export Megatron -> HF (round-trip)
uv run python -m torch.distributed.run --nproc_per_node=$((PP * EP)) \
"${TORCHRUN[@]}" \
examples/conversion/convert_checkpoints_multi_gpu.py export \
--hf-model "${HF_MODEL_ID}" \
--megatron-path "${MEGATRON_DIR}/${ITER}" \
Expand All @@ -86,22 +149,24 @@ uv run python -m torch.distributed.run --nproc_per_node=$((PP * EP)) \
# DSv4 HF weights are quantized (FP8/MXFP4), so the first import dequantises
# to bfloat16. A true lossless roundtrip re-imports the exported bf16 checkpoint
# and compares against the first export.
ROUNDTRIP_DIR="${WORKSPACE}/models/${MODEL_VARIANT}-roundtrip"
uv run python -m torch.distributed.run --nproc_per_node=$((PP * EP)) \
examples/conversion/convert_checkpoints_multi_gpu.py import \
--hf-model "${EXPORT_DIR}" \
--megatron-path "${ROUNDTRIP_DIR}" \
--tp ${TP} --pp ${PP} --ep ${EP} \
--torch-dtype bfloat16 \
--trust-remote-code
if [[ "${RUN_ROUNDTRIP}" == "1" ]]; then
ROUNDTRIP_DIR="${WORKSPACE}/models/${MODEL_VARIANT}-roundtrip"
"${TORCHRUN[@]}" \
examples/conversion/convert_checkpoints_multi_gpu.py import \
--hf-model "${EXPORT_DIR}" \
--megatron-path "${ROUNDTRIP_DIR}" \
--tp ${TP} --pp ${PP} --ep ${EP} \
--torch-dtype bfloat16 \
--trust-remote-code

ROUNDTRIP_EXPORT_DIR="${WORKSPACE}/models/${MODEL_VARIANT}-roundtrip-export"
uv run python -m torch.distributed.run --nproc_per_node=$((PP * EP)) \
examples/conversion/convert_checkpoints_multi_gpu.py export \
--hf-model "${EXPORT_DIR}" \
--megatron-path "${ROUNDTRIP_DIR}" \
--hf-path "${ROUNDTRIP_EXPORT_DIR}" \
--tp ${TP} --pp ${PP} --ep ${EP} \
--torch-dtype bfloat16 \
--distributed-save \
--trust-remote-code
ROUNDTRIP_EXPORT_DIR="${WORKSPACE}/models/${MODEL_VARIANT}-roundtrip-export"
"${TORCHRUN[@]}" \
examples/conversion/convert_checkpoints_multi_gpu.py export \
--hf-model "${EXPORT_DIR}" \
--megatron-path "${ROUNDTRIP_DIR}" \
--hf-path "${ROUNDTRIP_EXPORT_DIR}" \
--tp ${TP} --pp ${PP} --ep ${EP} \
--torch-dtype bfloat16 \
--distributed-save \
--trust-remote-code
fi
161 changes: 161 additions & 0 deletions tests/unit_tests/examples/test_convert_checkpoints_multi_gpu.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Unit tests for examples/conversion/convert_checkpoints_multi_gpu.py."""

from __future__ import annotations

import importlib.util
import pathlib
import sys

import pytest


_REPO_ROOT = pathlib.Path(__file__).resolve().parents[3]
_CLI_PATH = _REPO_ROOT / "examples" / "conversion" / "convert_checkpoints_multi_gpu.py"


@pytest.fixture(scope="module")
def cli():
    """Yield the conversion script loaded as a module from ``_CLI_PATH``.

    The module is registered in ``sys.modules`` under a fixed test-only name
    while the tests of this file run, and the entry is removed again on
    teardown, including when executing the script raises.
    """
    module_spec = importlib.util.spec_from_file_location("convert_checkpoints_multi_gpu_under_test", _CLI_PATH)
    script_module = importlib.util.module_from_spec(module_spec)
    registered_name = module_spec.name

    # Register first so the entry is already in place while the script body runs.
    sys.modules[registered_name] = script_module
    try:
        module_spec.loader.exec_module(script_module)
        yield script_module
    finally:
        sys.modules.pop(registered_name, None)


class _FakeProvider:
def __init__(self, calls):
self.calls = calls
self.pipeline_model_parallel_layout = None

def finalize(self):
self.calls.append(("finalize", (), {}))

def initialize_model_parallel(self, *args, **kwargs):
self.calls.append(("initialize_model_parallel", args, kwargs))

def provide_distributed_model(self, *args, **kwargs):
self.calls.append(("provide_distributed_model", args, kwargs))
return ["megatron-model"]


class _FakeModelBridge:
def get_hf_tokenizer_kwargs(self):
return {"padding_side": "left"}


class _FakeHfPretrained:
config = type("Config", (), {"num_hidden_layers": 1, "num_nextn_predict_layers": 0})()


class TestImportHfToMegatron:
    """Tests for the script's ``import_hf_to_megatron`` entry point."""

    def test_import_saves_megatron_checkpoint_with_tokenizer_metadata(self, cli, monkeypatch):
        """The import path saves the model along with tokenizer path and kwargs.

        Distributed initialisation, the repo-safety check and
        ``AutoBridge.from_hf_pretrained`` are replaced with local fakes, and the
        function is invoked through ``__wrapped__``.
        """
        recorded = []

        class FakeBridge:
            hf_pretrained = _FakeHfPretrained()
            _model_bridge = _FakeModelBridge()

            def to_megatron_provider(self, *args, **kwargs):
                recorded.append(("to_megatron_provider", args, kwargs))
                return _FakeProvider(recorded)

            def save_megatron_model(self, *args, **kwargs):
                recorded.append(("save_megatron_model", args, kwargs))

        def fake_from_hf_pretrained(*args, **kwargs):
            recorded.append(("from_hf_pretrained", args, kwargs))
            return FakeBridge()

        def skip_distributed_init(timeout_minutes):
            return None

        def fake_is_safe_repo(*, trust_remote_code, hf_path):
            return trust_remote_code

        monkeypatch.setattr(cli, "_ensure_distributed_initialized", skip_distributed_init)
        monkeypatch.setattr(cli, "is_safe_repo", fake_is_safe_repo)
        monkeypatch.setattr(cli.AutoBridge, "from_hf_pretrained", fake_from_hf_pretrained)

        import_kwargs = {
            "hf_model": "hf",
            "megatron_path": "/ckpt",
            "tp": 1,
            "pp": 1,
            "ep": 2,
            "etp": 1,
            "torch_dtype": "bfloat16",
            "trust_remote_code": True,
        }
        cli.import_hf_to_megatron.__wrapped__(**import_kwargs)

        _, save_args, save_kwargs = next(entry for entry in recorded if entry[0] == "save_megatron_model")
        assert save_args == (["megatron-model"], "/ckpt")
        assert "low_memory_save" not in save_kwargs
        assert save_kwargs["hf_tokenizer_path"] == "hf"
        assert save_kwargs["hf_tokenizer_kwargs"] == {"padding_side": "left", "trust_remote_code": True}


class TestExportMegatronToHf:
    """Tests for the script's ``export_megatron_to_hf`` entry point."""

    def test_export_does_not_move_loaded_model_to_cuda(self, cli, monkeypatch):
        """Export hands the loaded shards to the saver without calling ``.cuda()``.

        The fake shard raises from ``cuda()``, so any attempt to move it fails the
        test. Distributed initialisation, the repo-safety check and
        ``AutoBridge.from_hf_pretrained`` are replaced with local fakes.
        """
        recorded = []

        class FakeModelShard:
            def cuda(self):
                raise AssertionError("export should not force loaded checkpoint shards to CUDA")

        loaded_shards = [FakeModelShard()]

        class FakeBridge:
            hf_pretrained = _FakeHfPretrained()
            _model_bridge = object()

            def to_megatron_provider(self, *args, **kwargs):
                recorded.append(("to_megatron_provider", args, kwargs))
                return _FakeProvider(recorded)

            def load_megatron_model(self, *args, **kwargs):
                recorded.append(("load_megatron_model", args, kwargs))
                return loaded_shards

            def save_hf_pretrained(self, *args, **kwargs):
                recorded.append(("save_hf_pretrained", args, kwargs))

        def fake_from_hf_pretrained(*args, **kwargs):
            recorded.append(("from_hf_pretrained", args, kwargs))
            return FakeBridge()

        def skip_distributed_init(timeout_minutes):
            return None

        def fake_is_safe_repo(*, trust_remote_code, hf_path):
            return trust_remote_code

        monkeypatch.setattr(cli, "_ensure_distributed_initialized", skip_distributed_init)
        monkeypatch.setattr(cli, "is_safe_repo", fake_is_safe_repo)
        monkeypatch.setattr(cli.AutoBridge, "from_hf_pretrained", fake_from_hf_pretrained)

        export_kwargs = {
            "hf_model": "hf",
            "megatron_path": "/ckpt/iter_0000000",
            "hf_path": "/hf-export",
            "tp": 1,
            "pp": 1,
            "ep": 2,
            "etp": 1,
            "torch_dtype": "bfloat16",
            "trust_remote_code": True,
            "distributed_save": True,
        }
        cli.export_megatron_to_hf.__wrapped__(**export_kwargs)

        _, load_args, load_kwargs = next(entry for entry in recorded if entry[0] == "load_megatron_model")
        assert load_args == ("/ckpt/iter_0000000",)
        assert load_kwargs["mp_overrides"]["expert_model_parallel_size"] == 2

        _, save_args, save_kwargs = next(entry for entry in recorded if entry[0] == "save_hf_pretrained")
        assert save_args == (loaded_shards, "/hf-export")
        assert save_kwargs["distributed_save"] is True
15 changes: 15 additions & 0 deletions tests/unit_tests/test_compare_mask_handling.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,3 +139,18 @@ def test_hf_path_receives_ones_like_attention_mask(self):
assert call_kwargs["attention_mask"].dtype == torch.bool
assert call_kwargs["attention_mask"].shape == input_ids.shape
assert torch.equal(call_kwargs["attention_mask"], expected_mask)

@pytest.mark.parametrize("flag", ["--trust_remote_code", "--trust-remote-code"])
def test_trust_remote_code_accepts_underscore_and_hyphen_flags(self, flag):
    """Check that either spelling of the flag sets ``trust_remote_code``."""
    argv = ["--hf_model_path", "hf", "--prompt", "Hello", flag]
    parsed = compare.build_parser().parse_args(argv)

    assert parsed.trust_remote_code is True
Loading