Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,10 @@
import torch
from omegaconf import OmegaConf

from megatron.bridge.recipes.nemotron_vl.nemotron_nano_v2_vl import (
nemotron_nano_v2_vl_12b_peft_config,
nemotron_nano_v2_vl_12b_sft_config,
)
from megatron.bridge.training.config import ConfigContainer
from megatron.bridge.training.finetune import finetune
from megatron.bridge.training.llava_step import forward_step
Expand Down Expand Up @@ -105,16 +109,20 @@ def main() -> None:
datefmt="%Y-%m-%d %H:%M:%S",
)

from megatron.bridge.recipes.nemotron_vl.nemotron_nano_v2_vl import nemotron_nano_v2_vl_12b_finetune_config

cfg: ConfigContainer = nemotron_nano_v2_vl_12b_finetune_config(
hf_model_path=args.hf_model_path,
pretrained_checkpoint=args.pretrained_checkpoint,
lora_on_language_model=args.lora_on_language_model,
lora_on_vision_model=args.lora_on_vision_model,
)

logger.info("Loaded base configuration for finetuning")
config_kwargs = {
"hf_model_path": args.hf_model_path,
"pretrained_checkpoint": args.pretrained_checkpoint,
}
if args.lora_on_language_model or args.lora_on_vision_model:
cfg: ConfigContainer = nemotron_nano_v2_vl_12b_peft_config(
**config_kwargs,
lora_on_language_model=args.lora_on_language_model,
lora_on_vision_model=args.lora_on_vision_model,
)
logger.info("Loaded base configuration for PEFT")
else:
cfg = nemotron_nano_v2_vl_12b_sft_config(**config_kwargs)
logger.info("Loaded base configuration for SFT")

if get_rank_safe() == 0:
cfg.print_yaml()
Expand Down
119 changes: 107 additions & 12 deletions src/megatron/bridge/recipes/nemotron_vl/nemotron_nano_v2_vl.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

"""Nemotron Nano V2 VL finetuning recipes with parameterless API.
"""Nemotron Nano V2 VL finetuning recipes with parameterless defaults.

This module provides SFT and PEFT configurations for Nemotron Nano V2 VL 12B.
"""
Expand All @@ -21,27 +21,106 @@

from megatron.bridge import AutoBridge
from megatron.bridge.peft.base import PEFT
from megatron.bridge.peft.dora import DoRA
from megatron.bridge.peft.lora import VLMLoRA
from megatron.bridge.recipes.common import _peft_common_vlm, _sft_common_vlm
from megatron.bridge.recipes.utils.optimizer_utils import distributed_fused_adam_with_cosine_annealing
from megatron.bridge.training.config import ConfigContainer


_DEFAULT_HF_MODEL_PATH = "nvidia/NVIDIA-Nemotron-Nano-12B-v2-VL-BF16"
_ALL_COMPONENT_LORA_TARGET_MODULES = ["linear_qkv", "linear_proj", "linear_fc1", "linear_fc2"]
_LANGUAGE_LORA_TARGET_MODULES = [
"*language_model*.linear_qkv",
"*language_model*.linear_proj",
"*language_model*.linear_fc1",
"*language_model*.linear_fc2",
]
_VISION_LORA_TARGET_MODULES = [
"*vision_model*.linear_qkv",
"*vision_model*.linear_proj",
"*vision_model*.linear_fc1",
"*vision_model*.linear_fc2",
"*vision_projection*.linear_fc1",
"*vision_projection*.linear_fc2",
]


def _nemotron_vl_target_modules(
*,
lora_on_language_model: bool = True,
lora_on_vision_model: bool = True,
) -> list[str]:
"""Return adapter target modules for the selected Nemotron VL components."""
if not lora_on_language_model and not lora_on_vision_model:
raise ValueError("At least one of lora_on_language_model or lora_on_vision_model must be True.")

if lora_on_language_model and lora_on_vision_model:
return _ALL_COMPONENT_LORA_TARGET_MODULES.copy()

if lora_on_language_model:
return _LANGUAGE_LORA_TARGET_MODULES.copy()

return _VISION_LORA_TARGET_MODULES.copy()


def _nemotron_vl_lora_config(
*,
lora_on_language_model: bool = True,
lora_on_vision_model: bool = True,
) -> VLMLoRA:
"""Build a Nemotron VL LoRA config that respects component selection flags."""
target_modules = _nemotron_vl_target_modules(
lora_on_language_model=lora_on_language_model,
lora_on_vision_model=lora_on_vision_model,
)

if lora_on_language_model and lora_on_vision_model:
return VLMLoRA(
target_modules=target_modules,
dim=16,
alpha=32,
)

if lora_on_language_model:
return VLMLoRA(
target_modules=target_modules,
dim=16,
alpha=32,
freeze_vision_model=False,
freeze_vision_projection=False,
)

return VLMLoRA(
target_modules=target_modules,
dim=16,
alpha=32,
freeze_language_model=False,
)


# =============================================================================
# Nemotron Nano V2 VL 12B SFT Configuration
# =============================================================================
def nemotron_nano_v2_vl_12b_sft_config() -> ConfigContainer:
def nemotron_nano_v2_vl_12b_sft_config(
hf_model_path: str = _DEFAULT_HF_MODEL_PATH,
pretrained_checkpoint: str | None = None,
) -> ConfigContainer:
"""Return a full SFT config for Nemotron Nano V2 VL 12B.

Default configuration: 1 node, 8 GPUs
- TP=4, PP=1
- LR=1e-5 (finetune default)
- Sequence length: 4096

Args:
hf_model_path: Hugging Face model path used for provider and processor setup.
pretrained_checkpoint: Optional checkpoint path to load before finetuning.
"""
cfg = _sft_common_vlm()

# Model configuration
hf_path = "nvidia/NVIDIA-Nemotron-Nano-12B-v2-VL-BF16"
hf_path = hf_model_path
cfg.model = AutoBridge.from_hf_pretrained(hf_path, trust_remote_code=True).to_megatron_provider(load_weights=False)
cfg.model.seq_length = 4096

Expand Down Expand Up @@ -121,6 +200,8 @@ def nemotron_nano_v2_vl_12b_sft_config() -> ConfigContainer:

# Checkpoint config - override save_interval from common
cfg.checkpoint.save_interval = 200
if pretrained_checkpoint is not None:
cfg.checkpoint.pretrained_checkpoint = pretrained_checkpoint

# FP8 and MXFP8 settings (disabled by default)
cfg.mixed_precision = "bf16_mixed"
Expand All @@ -141,7 +222,14 @@ def nemotron_nano_v2_vl_12b_sft_config() -> ConfigContainer:
# =============================================================================
# Nemotron Nano V2 VL 12B PEFT Configuration
# =============================================================================
def nemotron_nano_v2_vl_12b_peft_config(peft_scheme: str | PEFT = "lora") -> ConfigContainer:
def nemotron_nano_v2_vl_12b_peft_config(
peft_scheme: str | PEFT = "lora",
*,
hf_model_path: str = _DEFAULT_HF_MODEL_PATH,
pretrained_checkpoint: str | None = None,
lora_on_language_model: bool = True,
lora_on_vision_model: bool = True,
) -> ConfigContainer:
"""Return a PEFT config for Nemotron Nano V2 VL 12B.

Default configuration: 1 node, 8 GPUs
Expand All @@ -152,28 +240,33 @@ def nemotron_nano_v2_vl_12b_peft_config(peft_scheme: str | PEFT = "lora") -> Con
Args:
peft_scheme: PEFT scheme - "lora", "dora", or a custom PEFT instance.
Note: Default uses VLMLoRA targeting all model components.
hf_model_path: Hugging Face model path used for provider and processor setup.
pretrained_checkpoint: Optional checkpoint path to load before finetuning.
lora_on_language_model: Whether LoRA targets the language model when PEFT is "lora" or "dora".
lora_on_vision_model: Whether LoRA targets the vision model when PEFT is "lora" or "dora".
"""
cfg = _peft_common_vlm()

# PEFT scheme - Nemotron uses VLMLoRA by default
if isinstance(peft_scheme, str) and peft_scheme.lower() == "lora":
cfg.peft = VLMLoRA(
target_modules=["linear_qkv", "linear_proj", "linear_fc1", "linear_fc2"],
dim=16,
alpha=32,
cfg.peft = _nemotron_vl_lora_config(
lora_on_language_model=lora_on_language_model,
lora_on_vision_model=lora_on_vision_model,
)
elif isinstance(peft_scheme, str) and peft_scheme.lower() == "dora":
cfg.peft = VLMLoRA(
target_modules=["linear_qkv", "linear_proj", "linear_fc1", "linear_fc2"],
cfg.peft = DoRA(
target_modules=_nemotron_vl_target_modules(
lora_on_language_model=lora_on_language_model,
lora_on_vision_model=lora_on_vision_model,
),
dim=16,
alpha=32,
dora=True,
)
else:
cfg.peft = peft_scheme

# Model configuration
hf_path = "nvidia/NVIDIA-Nemotron-Nano-12B-v2-VL-BF16"
hf_path = hf_model_path
cfg.model = AutoBridge.from_hf_pretrained(hf_path, trust_remote_code=True).to_megatron_provider(load_weights=False)
cfg.model.seq_length = 4096

Expand Down Expand Up @@ -253,6 +346,8 @@ def nemotron_nano_v2_vl_12b_peft_config(peft_scheme: str | PEFT = "lora") -> Con

# Checkpoint config - override save_interval from common
cfg.checkpoint.save_interval = 200
if pretrained_checkpoint is not None:
cfg.checkpoint.pretrained_checkpoint = pretrained_checkpoint

# FP8 and MXFP8 settings (disabled by default)
cfg.mixed_precision = "bf16_mixed"
Expand Down
36 changes: 36 additions & 0 deletions tests/unit_tests/examples/test_nemotron_vl_finetune.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Import smoke test for the Nemotron Nano V2 VL finetune example."""

from __future__ import annotations

import importlib.util
import pathlib
import sys


_REPO_ROOT = pathlib.Path(__file__).resolve().parents[3]
_SCRIPT_PATH = _REPO_ROOT / "examples" / "models" / "nemotron" / "nemotron_vl" / "finetune_nemotron_nano_v2_vl.py"


def test_nemotron_nano_v2_vl_finetune_example_imports():
"""Test that the finetune example does not import removed recipe symbols."""
spec = importlib.util.spec_from_file_location("nemotron_vl_finetune_under_test", _SCRIPT_PATH)
module = importlib.util.module_from_spec(spec)
sys.modules[spec.name] = module
try:
spec.loader.exec_module(module)
finally:
sys.modules.pop(spec.name, None)
Loading
Loading