
Conversation

h-guo18 (Contributor) commented Oct 25, 2025

What does this PR do?

Type of change: New feature

Overview:

  • Support nano and nano-VL in Eagle3 online mode:
    • Added submodule path detection for the base model, lm_head, and embeddings to adapt to different base-model naming structures (see the sketch after this list);
    • Refactored data loading/preprocessing to support VLMs.
  • Attention backend improvements:
    • Added an SDPA option as a fallback for cases where flex_attn does not work.
    • Added a unified TTT mask function that produces either a BlockMask for flex_attn or tensor masks for regular attention.
  • Logging improvements:
    • Added estimated AR validation during training, available for both online and offline modes.
    • Plot estimated AR and training accuracy to wandb for better training visualization.
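
A minimal Python sketch of what the submodule path detection mentioned above could look like. The helper name _find_base_model_parts appears in the diff below, but the candidate attribute paths probed here are illustrative assumptions, not the PR's exact implementation.

def find_base_model_parts(model):
    """Locate the decoder stack, lm_head, and input embeddings across
    plain-LLM ("model.layers") and VLM ("language_model.model.layers") layouts."""

    def resolve(root, dotted):
        obj = root
        for attr in dotted.split("."):
            obj = getattr(obj, attr, None)
            if obj is None:
                return None
        return obj

    base = None
    for path in ("model", "language_model.model", "language_model"):  # assumed candidates
        candidate = resolve(model, path)
        if candidate is not None and hasattr(candidate, "layers"):
            base = candidate
            break
    lm_head = resolve(model, "lm_head") or resolve(model, "language_model.lm_head")
    embeddings = (
        model.get_input_embeddings()
        if hasattr(model, "get_input_embeddings")
        else getattr(base, "embed_tokens", None)
    )
    return base, lm_head, embeddings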

Usage

When using a VLM as the base model, pass the extra arguments --vlm_processor <hf_model_path> --vlm_img_dir <path to images> on top of the original launch command; other usage is unchanged.
E.g.

./launch_train.sh --model $MODEL \
            --output_dir $OUTPUT_DIR \
            --data $DATA \
            --num_gpu 1 \
            --num_epochs 2 \
            --train_bs 2 \
            --lr 3e-5 \
            --eagle_config eagle_config.json \
            --training_seq_len 4096 \
            --vlm_processor $MODEL \
            --vlm_img_dir  <path to images>

Testing

Tested short training runs with HF Online training on the following models:

  • llama-3.2-1b - data: daring-anteater
  • The new nano (Hybrid LLM) - data: daring-anteater
  • The nano-VL - data: Llama-Nemotron-VLM-Dataset-v1/ocr_1

Observed decreasing loss and AR > 1.

Before your PR is "Ready for review"

  • Make sure you read and follow Contributor guidelines and your commits are signed.
  • Is this change backward compatible?: Yes/No
  • Did you write any new necessary tests?: Yes/No
  • Did you add or update any necessary documentation?: Yes/No
  • Did you update Changelog?: Yes/No

Additional Information

@h-guo18 h-guo18 self-assigned this Oct 25, 2025
copy-pr-bot bot commented Oct 25, 2025

Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually.


# Note: the leading "if" branch is implied by the "elif" in this diff excerpt.
if hasattr(self.model.layers[-1].self_attn, "q_proj"):
    device = self.model.layers[-1].self_attn.q_proj.weight.device
elif hasattr(self.model.layers[-1].self_attn, "qkv_proj"):
    device = self.model.layers[-1].self_attn.qkv_proj.weight.device
self.eagle_module.to(self.dtype).to(device)
h-guo18 (Contributor, Author):

TODO: confirm this device detection with @yeyu-nvidia

@h-guo18 h-guo18 force-pushed the haoguo/support-nano branch from 8eb6abf to a85d473 on October 27, 2025 23:40
@h-guo18 h-guo18 changed the title from "Feat: eagle3 support for nanov3" to "Feat: eagle3 support for nano2-vlm and nano3" Oct 27, 2025
codecov bot commented Oct 27, 2025

Codecov Report

✅ All modified and coverable lines are covered by tests.
✅ Project coverage is 73.38%. Comparing base (41de55f) to head (9c791d9).

Additional details and impacted files
@@           Coverage Diff           @@
##             main     #463   +/-   ##
=======================================
  Coverage   73.38%   73.38%           
=======================================
  Files         180      180           
  Lines       18110    18110           
=======================================
  Hits        13290    13290           
  Misses       4820     4820           

☔ View full report in Codecov by Sentry.

@h-guo18 h-guo18 changed the title from "Feat: eagle3 support for nano2-vlm and nano3" to "Feat: Eagle3 HF Online - support nano2-vlm and nano3" Oct 27, 2025
@h-guo18 h-guo18 marked this pull request as ready for review October 27, 2025 23:56
@h-guo18 h-guo18 requested a review from a team as a code owner October 27, 2025 23:56
@h-guo18 h-guo18 requested a review from yeyu-nvidia October 27, 2025 23:56
@h-guo18 h-guo18 changed the title from "Feat: Eagle3 HF Online - support nano2-vlm and nano3" to "Feat: Eagle3 HF Online - support nemotron models" Oct 28, 2025
input_ids = output.input_ids[0]
attention_mask = output.attention_mask[0]
loss_mask = torch.ones_like(input_ids)
labels = torch.full_like(input_ids, IGNORE_TOKEN_ID)
Contributor:

So all labels are IGNORE_TOKEN_ID?
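
For context, a common SFT preprocessing pattern (an assumption about intent, not necessarily what this code does downstream) is to initialize labels to IGNORE_TOKEN_ID and later unmask only the assistant tokens via the loss mask:

import torch

IGNORE_TOKEN_ID = -100  # assumption: the standard HF ignore index

input_ids = torch.tensor([1, 2, 3, 4, 5])
loss_mask = torch.tensor([0, 0, 1, 1, 1])  # 1 = tokens that should contribute to the loss

labels = torch.full_like(input_ids, IGNORE_TOKEN_ID)
labels[loss_mask.bool()] = input_ids[loss_mask.bool()]
# labels -> tensor([-100, -100, 3, 4, 5])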

return ret


class OfflineSupervisedDataset(Dataset):
Contributor:

Does this support VLM data?

if wandb and is_master():
    wandb.init()

def on_log(self, args, state, control, **kwargs):
Contributor:

Can you explain how you estimate AR? I'm not sure it's a good idea to expose "estimated AR" as it may mislead users.
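
For reference, one common way to turn per-step draft accuracies into an estimated acceptance rate, assuming acceptances at successive speculation steps are independent; this is an illustrative sketch, not necessarily the formula this PR uses:

def estimate_ar(step_accs):
    """Expected tokens emitted per base-model forward: 1 (the verified token)
    plus the probability of accepting 1, 2, ... draft tokens in a row."""
    ar, chain = 1.0, 1.0
    for acc in step_accs:
        chain *= acc
        ar += chain
    return ar

print(estimate_ar([0.8, 0.7, 0.6]))  # 1 + 0.8 + 0.56 + 0.336 = 2.696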

metadata={"help": "Path to the d2t cache directory."},
)
vlm_img_dir: str = field(default=None, metadata={"help": "Path to the VLM image directory."})
vlm_processor: str = field(default=None, metadata={"help": "Path to the VLM processor."})
Contributor:

what is VLM processor?
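
For context: in Hugging Face terms, a processor bundles the tokenizer with the image preprocessor for multimodal models. Presumably --vlm_processor is the model path handed to AutoProcessor.from_pretrained; the snippet below is an assumed usage based on the flag name, and the model/image paths are placeholders:

from PIL import Image
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("<hf_vlm_model_path>")   # placeholder path
image = Image.open("<path to images>/example.jpg")                 # placeholder path
inputs = processor(text="Describe the image.", images=image, return_tensors="pt")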

for param in self.model.embed_tokens.parameters():
    ...  # (loop body elided in this diff excerpt)

# find base model, lm head, and embeddings paths
self._find_base_model_parts()
self.eagle_module.to(self._base_model.dtype).to(self._base_model_lm_head.weight.device)
Contributor:

Need to check whether PTQ/inference fails. We want to make sure eagle_module.device matches the device of the last base-model decoder layer, but that is not necessarily the same as lm_head.device.
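
A minimal sketch of the placement this comment suggests, assuming the base decoder stack exposes .layers; illustrative only, not the PR's code:

# Pin the eagle module to the device of the last base decoder layer, which
# may differ from lm_head's device under device_map / pipeline sharding.
last_layer = self._base_model.layers[-1]
last_layer_device = next(last_layer.parameters()).device
self.eagle_module.to(self._base_model.dtype).to(last_layer_device)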


dtypemin = torch.finfo(self._base_llm_config.dtype).min
q_len = seq_length
kv_len = seq_length * (2 + ttt_step)
Contributor:

Why 2 + ttt_step?
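
One plausible reading of the factor, stated here as an assumption rather than a confirmed explanation: the draft queries attend to the base model's KV (1 x seq_length), the eagle module's own prefill pass (1 x seq_length), and one additional seq_length-wide segment per completed TTT step, which gives kv_len = seq_length * (2 + ttt_step). For example:

seq_length = 4096
for ttt_step in range(3):
    kv_len = seq_length * (2 + ttt_step)
    print(ttt_step, kv_len)  # 0 -> 8192, 1 -> 12288, 2 -> 16384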
