4141from torch .utils .data import Dataset
4242from transformers .trainer_pt_utils import LabelSmoother
4343
44+ from modelopt .torch .utils .loss_mask import get_loss_mask_recovery
45+
4446IGNORE_TOKEN_ID = LabelSmoother .ignore_index
4547
4648
@@ -96,20 +98,27 @@ class OfflineSupervisedDataset(Dataset):
9698 dumped_files (list): A list of file paths to the dumped .pt files.
9799 answer_only_loss (bool): If True, use the ``loss_mask`` stored in each .pt
98100 file so that only assistant-produced tokens contribute to the loss.
99- Raises ``ValueError`` on ``__getitem__`` if the file lacks ``loss_mask``.
101+ If a file lacks ``loss_mask`` and ``tokenizer`` has a registered
102+ model-specific recovery (see ``modelopt.torch.utils.loss_mask``), the
103+ mask is rebuilt from ``input_ids``; otherwise ``__getitem__`` raises
104+ ``ValueError``.
100105 If False (default), a uniform all-ones mask is used regardless of what
101106 is stored in the file (backward compatible).
107+ tokenizer: Optional tokenizer used to recover the assistant mask for dumps
108+ that lack a stored ``loss_mask``.
102109 """
103110
104111 def __init__ (
105112 self ,
106113 dumped_files ,
107114 answer_only_loss : bool = False ,
115+ tokenizer = None ,
108116 ):
109117 """Initialize with a list of .pt file paths."""
110118 super ().__init__ ()
111119 self .dumped_files = dumped_files
112120 self .answer_only_loss = answer_only_loss
121+ self .tokenizer = tokenizer
113122
114123 def __len__ (self ):
115124 return len (self .dumped_files )
@@ -121,13 +130,22 @@ def __getitem__(self, i) -> dict[str, torch.Tensor]:
121130 labels [..., :- 1 ] = offline_data ["input_ids" ][..., 1 :]
122131
123132 if self .answer_only_loss :
124- if "loss_mask" not in offline_data :
133+ recovery = get_loss_mask_recovery (self .tokenizer ) if self .tokenizer else None
134+ if "loss_mask" in offline_data :
135+ loss_mask = offline_data ["loss_mask" ].to (offline_data ["input_ids" ].dtype )
136+ elif recovery is not None :
137+ # Dumps from tokenizers that cannot emit assistant masks carry no
138+ # loss_mask; rebuild it from the token ids.
139+ loss_mask = recovery .compute (self .tokenizer , offline_data ["input_ids" ]).to (
140+ offline_data ["input_ids" ].dtype
141+ )
142+ else :
125143 raise ValueError (
126144 f"answer_only_loss=True requires a 'loss_mask' entry in the offline "
127145 f".pt file, but { self .dumped_files [i ]} does not have one. Re-dump "
128- f"with --answer-only-loss in compute_hidden_states_*.py."
146+ f"with --answer-only-loss in compute_hidden_states_*.py, or pass a "
147+ f"tokenizer with a registered loss-mask recovery."
129148 )
130- loss_mask = offline_data ["loss_mask" ].to (offline_data ["input_ids" ].dtype )
131149 else :
132150 loss_mask = torch .ones_like (offline_data ["input_ids" ])
133151
0 commit comments