
@yashaswikarnati commented Nov 3, 2025

Qwen3VL Verification

Dense Model (8B)

Model: Qwen/Qwen3-VL-8B-Instruct

HF Logits Matching

Megatron Top 5:

[('\n\n', 25.625), ('?\n\n', 24.125), ('?', 22.0), ('?\n', 21.5), ('\n', 21.25)]

HF Top 5:

[('\n\n', 26.5), ('?\n\n', 25.25), ('?', 23.125), ('?\n', 22.75), ('\n', 22.25)]
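
For reference, the top-5 rows above are just the five highest last-position logits decoded back to strings. A minimal sketch of that readout is shown below; it assumes you already have a logits tensor and a tokenizer from whichever implementation is being checked, and it is an illustration rather than the compare script itself.

import torch

def top5_next_tokens(logits: torch.Tensor, tokenizer):
    """Return [(decoded_token, logit), ...] for the 5 highest last-position logits."""
    last = logits[0, -1] if logits.dim() == 3 else logits  # [vocab]
    values, indices = torch.topk(last.float(), k=5)
    return [(tokenizer.decode([int(i)]), round(float(v), 3)) for i, v in zip(indices, values)]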

Fine-tuning on cord-v2 Dataset

Train vs validation loss curves
WandB Link: Qwen3VL 8B Fine-tune Run

[image: train vs validation loss curves]

Inference on Sample of cord-v2 Dataset

Command:

uv run python -m torch.distributed.run --nproc_per_node=8 \
examples/conversion/hf_to_megatron_generate_vlm.py \
--hf_model_path="Qwen/Qwen3-VL-8B-Instruct" \
--image_path=./examples/recipes/qwen_vl/image.png \
--prompt="Describe this items and process in this image." \
--megatron_model_path ./logs/checkpoints/qwen3vl8b/ \
--max_new_tokens 150

Before Fine-tune:

<|im_start|>assistant
This is a receipt from a receipt. The receipt.
<|im_end|>

After Fine-tune:

<s_total><s_total_price>302,016</s_total_price></s_total>
<s_sub_total>
  <s_tax_price>52,416</s_tax_price>
  <s_discount_price>19,000</s_discount_price>
  <s_subtotal_price>259,000</s_subtotal_price>
</s_sub_total>
<s_menu>
  <s_price>59,000</s_price>
  <s_nm>Bintang Bremer</s_nm>
  <s_cnt>1</s_cnt>
  <sep/>
  <s_price>190,000</s_price>
  <s_nm>Chicken H-H</s_nm>
  <s_cnt>1</s_cnt>
</s_menu>
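
The tag-style target above follows the Donut-style linearization commonly used for cord-v2, where the dataset's JSON ground truth is flattened into <s_key>...</s_key> tags. The sketch below shows one way such a target string can be built; the field names assume the naver-clova-ix/cord-v2 layout on Hugging Face, and this is an illustration, not necessarily the PR's data pipeline.

import json
from datasets import load_dataset

def json2token(obj):
    """Linearize a nested gt_parse dict/list into <s_key>...</s_key> tags."""
    if isinstance(obj, dict):
        return "".join(f"<s_{k}>{json2token(v)}</s_{k}>" for k, v in obj.items())
    if isinstance(obj, list):
        return "<sep/>".join(json2token(item) for item in obj)
    return str(obj)

# Assumed dataset layout: "ground_truth" is a JSON string with a "gt_parse" key.
sample = load_dataset("naver-clova-ix/cord-v2", split="train[:1]")[0]
target = json2token(json.loads(sample["ground_truth"])["gt_parse"])
print(target[:200])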

MoE Model (30B-A3B)

Model: Qwen/Qwen3-VL-30B-A3B-Instruct
WandB Link: Qwen3VL MoE Fine-tune Run

HF Logits Matching

Command:

uv run python -m torch.distributed.run --nproc_per_node=8 \
examples/conversion/compare_hf_and_megatron/compare.py \
--model_class Qwen3VLMoeForConditionalGeneration \
--hf_model_path="Qwen/Qwen3-VL-30B-A3B-Instruct" \
--prompt="What is the capital of California" \
--ep 8

Megatron Top 5:

[('\n\n', 23.375), ('?\n\n', 21.75), ('\n', 21.375), ('?', 20.75), (' The', 20.625)]

HF Top 5:

[('?\n\n', 22.875), ('\n\n', 22.5), ('?\n', 22.375), ('?', 21.875), ('\n', 21.75)]

Cosine Similarity: 0.973547
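
The cosine similarity is a single scalar summarizing how closely the Megatron and HF logits agree. A hedged illustration of the metric is below; how compare.py actually aggregates across positions may differ, and megatron_logits / hf_logits are assumed to be same-shaped tensors.

import torch.nn.functional as F

def logits_cosine_similarity(megatron_logits, hf_logits):
    """Cosine similarity between two same-shaped logits tensors, flattened."""
    return F.cosine_similarity(
        megatron_logits.float().flatten(), hf_logits.float().flatten(), dim=0
    ).item()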

Fine-tuning on cord-v2 Dataset

Train vs validation loss curves
[image: train vs validation loss curves]

Inference on Sample of cord-v2 Dataset

Command:

uv run python -m torch.distributed.run --nproc_per_node=8 \
examples/conversion/hf_to_megatron_generate_vlm.py \
--hf_model_path="Qwen/Qwen3-VL-30B-A3B-Instruct" \
--image_path=/path/to/image \
--prompt="Describe this image." \
--ep 8 \
--megatron_model_path=/path/to/ckpt

Before Fine-tune:

<|im_start|>assistant
Of course
<|im_end|>

After Fine-tune:

<|im_start|>assistant
<s_total><s_total_price>302,016</s_total_price></s_total>
<s_sub_total>
  <s_tax_price>52,416</s_tax_price>
  <s_subtotal_price>259,000</s_subtotal_price>
  <s_service_price>9,600</s_service_price>
  <s_discount_price>19,000</s_discount_price>
</s_sub_total>
<s_menu>
  <s_price>59,000</s_price>
  <s_nm>Bintang Bremer</s_nm>
  <s_cnt>1</s_cnt>
  <sep/>
  <s_price>190,000</s_price>
  <s_nm>Chicken H-H</s_nm>
  <s_cnt>1</s_cnt>
  <sep/>
  <s_price>10,000</s_price>
  <s_nm>Ades</s_nm>
  <s_cnt>1</s_cnt>
</s_menu>

@yashaswikarnati (Author)

/ok to test c061222


def process_inputs(tokenizer, processor, image_path: Optional[str], prompt: str, is_vl_model: bool):
def pad_input_ids_to_tp_multiple(input_ids, tp_size: int, pad_token_id: int = 0):
"""Pad input_ids so sequence length is divisible by tp_size.
Contributor:

Add a note that this is required when sequence parallelism is on.

Author:

done
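
For context, here is a minimal sketch of what pad_input_ids_to_tp_multiple can look like, based only on the signature and docstring shown above (the PR's actual implementation may differ): it right-pads the sequence so its length divides evenly by the tensor-parallel size, which matters when sequence parallelism splits the sequence across TP ranks.

import torch

def pad_input_ids_to_tp_multiple(input_ids: torch.Tensor, tp_size: int, pad_token_id: int = 0):
    """Pad input_ids so sequence length is divisible by tp_size."""
    seq_len = input_ids.shape[-1]
    remainder = seq_len % tp_size
    if remainder == 0:
        return input_ids
    pad_len = tp_size - remainder
    padding = torch.full(
        (*input_ids.shape[:-1], pad_len), pad_token_id,
        dtype=input_ids.dtype, device=input_ids.device,
    )
    return torch.cat([input_ids, padding], dim=-1)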

Loading pretrained weights (recommended for finetune):
1) Import HF checkpoint to Megatron format:
$ python examples/conversion/convert_checkpoints.py import \
$ torchrun --nproc_per_node=1 examples/conversion/convert_checkpoints.py import \
Contributor:

no need to change this

Author:

The process hangs if I don't use torchrun, so I updated it.

@@ -0,0 +1,213 @@
#!/usr/bin/env python3
Contributor:

Can we merge the two fine-tune VL scripts? You can rename one of them and remove the other.

Author:

Done, there is just one common script now.

@@ -0,0 +1,181 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
Contributor:

duplicated w/ qwen3_vl bridge?

Author:

removed duplicate

)

# rebuild the transformer block
self.decoder = Qwen3VLTransformerBlock(
Contributor:

There shouldn't be a rebuild; you should just update the layer spec.

Author:

https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/models/gpt/gpt_model.py#L202

I think the layer spec override is for TransformerLayer, not for TransformerBlock. We might still need to override?

)

# rebuild the transformer block
self.decoder = Qwen3VLTransformerBlock(
Contributor:

Shouldn't need to rebuild the blocks. It's extra overhead; just update the layer spec to use Qwen3VLTransformerBlock.

Author:

same comment as above

visual_pos_masks: Optional[torch.Tensor] = None,
deepstack_visual_embeds: Optional[list[torch.Tensor]] = None,
) -> Tensor:
"""Forward function of the GPT Model This function passes the input tensors
Contributor:

Add a comment on why this needs to be overridden (deepstack_visual_embeds).

Author:

added
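
To make the override's purpose concrete: Qwen3-VL's deepstack path feeds vision features captured at several ViT depths into the hidden states of the first few decoder layers at visual-token positions, which is why the forward signature carries visual_pos_masks and deepstack_visual_embeds. The snippet below is a conceptual plain-PyTorch sketch of that merge, not the PR's actual model code.

import torch

def apply_deepstack(hidden_states, visual_pos_masks, deepstack_visual_embeds, layer_idx):
    """Add the layer_idx-th deepstack visual embedding at visual-token positions.

    Conceptual sketch: hidden_states is [batch, seq, hidden], visual_pos_masks is a
    [batch, seq] bool mask, deepstack_visual_embeds is a list of [num_visual_tokens, hidden]
    tensors (one per early decoder layer). Shapes/layout in Megatron may differ.
    """
    if deepstack_visual_embeds is None or layer_idx >= len(deepstack_visual_embeds):
        return hidden_states
    embeds = deepstack_visual_embeds[layer_idx].to(hidden_states.dtype)
    hidden_states[visual_pos_masks] = hidden_states[visual_pos_masks] + embeds
    return hidden_states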

@@ -0,0 +1,179 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
Contributor:

Can you move this file into model.py as well? A single file seems easier to understand.

Author:

I prefer shorter, self-contained files for easier maintenance rather than one long file; I can change it if you have a strong preference.

@@ -0,0 +1,154 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
Contributor:

duplicated with qwen3vl_provider?

Author:

removed the duplicate

pass


def extract_expert_number_from_param(param_name: str) -> int:
Contributor:

import this from src/megatron/bridge/utils/common_utils.py

Author (Nov 5, 2025):

done
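
For readers unfamiliar with the helper being discussed, a hypothetical sketch is below. The review resolves this by importing the existing implementation from src/megatron/bridge/utils/common_utils.py instead, and the parameter-naming pattern assumed here may not match the repo exactly.

import re

def extract_expert_number_from_param(param_name: str) -> int:
    """Return the expert index embedded in an MoE parameter name.

    Assumes names like "...mlp.experts.local_experts.3.linear_fc1.weight" (assumption).
    """
    match = re.search(r"experts\.(?:local_experts\.)?(\d+)\.", param_name)
    if match is None:
        raise ValueError(f"No expert index found in parameter name: {param_name}")
    return int(match.group(1))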

@yashaswikarnati (Author)

/ok to test 3e7f918

@yashaswikarnati (Author)

@yaoyu-33 addressed all the comments, PTAL when you get a chance.

Signed-off-by: ykarnati <[email protected]>
@yashaswikarnati (Author)

/ok to test 69233a5
