Feat: Eagle3 HF Online - support nemotron models #463
base: main
Conversation
Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually. Contributors can view more details about this message here.
    device = self.model.layers[-1].self_attn.q_proj.weight.device
elif hasattr(self.model.layers[-1].self_attn, "qkv_proj"):
    device = self.model.layers[-1].self_attn.qkv_proj.weight.device
self.eagle_module.to(self.dtype).to(device)
TODO: confirm this device detection with @yeyu-nvidia
Signed-off-by: h-guo18 <[email protected]>
8eb6abf to a85d473 (Compare)
    
Codecov Report
✅ All modified and coverable lines are covered by tests.

@@           Coverage Diff           @@
##             main     #463   +/-   ##
=======================================
  Coverage   73.38%   73.38%
=======================================
  Files         180      180
  Lines       18110    18110
=======================================
  Hits        13290    13290
  Misses       4820     4820

View full report in Codecov by Sentry.
Signed-off-by: h-guo18 <[email protected]>
input_ids = output.input_ids[0]
attention_mask = output.attention_mask[0]
loss_mask = torch.ones_like(input_ids)
labels = torch.full_like(input_ids, IGNORE_TOKEN_ID)
So all labels are IGNORE_TOKEN_ID?
        return ret


class OfflineSupervisedDataset(Dataset):
Does this support VLM data?
if wandb and is_master():
    wandb.init()


def on_log(self, args, state, control, **kwargs):
Can you explain how you estimate AR? I'm not sure it's a good idea to expose "estimated AR" as it may mislead users.
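For context, one common way to estimate AR during training (not necessarily what this PR does) is to fold per-draft-step acceptance/top-1 accuracies into an expected accepted length; `step_accs` below is a hypothetical list of per-step accuracies logged by the trainer:

```python
def estimate_ar(step_accs: list[float]) -> float:
    """Rough AR estimate: expected tokens emitted per base-model forward pass,
    assuming draft step i is accepted with probability step_accs[i] and that
    acceptance stops at the first rejection."""
    ar, p = 1.0, 1.0  # the base model always contributes one token
    for acc in step_accs:
        p *= acc   # probability that every draft step up to this one was accepted
        ar += p
    return ar


# e.g. estimate_ar([0.8, 0.6, 0.4]) == 1 + 0.8 + 0.48 + 0.192 ≈ 2.47
```

Because training-time accuracies are teacher-forced, an estimate like this tends to be optimistic relative to deployment AR, which may be the concern behind this comment.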
    metadata={"help": "Path to the d2t cache directory."},
)
vlm_img_dir: str = field(default=None, metadata={"help": "Path to the VLM image directory."})
vlm_processor: str = field(default=None, metadata={"help": "Path to the VLM processor."})
what is VLM processor?
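If `vlm_processor` refers to a Hugging Face processor path (an assumption; the diff only stores the string), it would typically be loaded and used roughly like this:

```python
from PIL import Image
from transformers import AutoProcessor

# Hypothetical use of the path passed via --vlm_processor: an AutoProcessor bundles
# the tokenizer with the image preprocessor, so text and images are encoded together.
processor = AutoProcessor.from_pretrained("<hf_model_path>")
image = Image.open("example.jpg")
inputs = processor(text="Describe the image.", images=image, return_tensors="pt")
```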
for param in self.model.embed_tokens.parameters():
# find base model, lm head, and embeddings paths
self._find_base_model_parts()
self.eagle_module.to(self._base_model.dtype).to(self._base_model_lm_head.weight.device)
Need to check if ptq/inference fails. We want to make sure eagle_module.device is the same as last base model decoder layer, but this is not necessarily the same as lm_head.device.
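One possible way to do that, sketched under the assumption that the base model exposes its decoder stack as `.layers` (attribute paths vary across architectures):

```python
# Sketch: place eagle_module on the device of the last base decoder layer rather
# than on lm_head's device; the two can differ under pipeline/model parallelism.
last_decoder_layer = self._base_model.layers[-1]        # assumed attribute path
device = next(last_decoder_layer.parameters()).device   # device of its weights
self.eagle_module.to(self._base_model.dtype).to(device)
```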
dtypemin = torch.finfo(self._base_llm_config.dtype).min
q_len = seq_length
kv_len = seq_length * (2 + ttt_step)
Why 2 + ttt_step?
What does this PR do?
Type of change: New feature
Overview:
- Find `base model`, `lm_head`, and `embeddings` paths to adapt to different base model naming structures;
- Fall back to `sdpa` in case `flex_attn` doesn't work (a sketch of this fallback follows the list).
- Use `BlockMask` for flex_attn or tensor masks for regular attn.
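A minimal sketch of the flex_attn-to-sdpa fallback described above, using PyTorch's `flex_attention` API; the function name and the simple causal mask are illustrative assumptions, not code from this PR:

```python
import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention


def attention_with_fallback(q, k, v, seq_len):
    """Try flex_attention with a BlockMask; fall back to SDPA with a tensor mask."""

    def causal(b, h, q_idx, kv_idx):
        # Simple causal rule for illustration; the real mask_mod depends on the TTT layout.
        return q_idx >= kv_idx

    try:
        block_mask = create_block_mask(
            causal, B=None, H=None, Q_LEN=seq_len, KV_LEN=seq_len, device=q.device
        )
        return flex_attention(q, k, v, block_mask=block_mask)
    except Exception:
        # Regular-attention path: dense boolean mask instead of a BlockMask.
        mask = torch.ones(seq_len, seq_len, dtype=torch.bool, device=q.device).tril()
        return torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask)
```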
Usage

For a VLM as the base model, pass the extra arguments `--vlm_processor <hf_model_path> --vlm_img_dir <path to images>` in the original launching commands. Other usage is unchanged. E.g.:
Testing
Tested short training with HF Online training on the following models:

- `llama-3.2-1b`, data: daring-anteater

Saw loss decreasing and AR > 1.
Before your PR is "Ready for review"
Additional Information