@@ -60,64 +60,89 @@ def prepare_for_loss(
         self,
         logits: torch.Tensor,
         labels: torch.Tensor,
-        loss_weight: torch.Tensor,
+        soft_masked: torch.Tensor,
     ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
-        """Prepare logits, labels, and weights for loss computation.
+        """Prepare logits, labels, and soft_masked for loss computation.
 
         Override in subclasses to implement MLM vs CLM-specific slicing/filtering.
 
         Args:
             logits: Logits of shape (batch, seq_len, vocab_size)
             labels: Target labels of shape (batch, seq_len)
-            loss_weight: Loss weights of shape (batch, seq_len)
+            soft_masked: Boolean mask of shape (batch, seq_len)
 
         Returns:
-            Tuple of (logits, labels, loss_weight) ready for loss computation
+            Tuple of (logits, labels, soft_masked) ready for loss computation
         """
         raise NotImplementedError("Subclasses must implement prepare_for_loss")
 
     def compute_loss(
         self,
         logits: torch.Tensor,
         labels: torch.Tensor,
-        loss_weight: torch.Tensor,
-    ) -> torch.Tensor:
-        """Compute weighted cross-entropy loss.
+        soft_masked: torch.Tensor,
+        soft_masked_weight: float,
+    ) -> dict[str, torch.Tensor]:
+        """Compute weighted cross-entropy loss with three variants.
 
-        Shared loss computation logic for MLM and CLM.
+        Computes three loss values:
+        1. loss_full: All tokens weighted equally (baseline)
+        2. loss_non_soft_masked: Only non-soft-masked tokens
+        3. loss: Training loss with soft_masked_weight applied
 
         Args:
             logits: Logits (1D or 2D)
             labels: Target labels (1D)
-            loss_weight: Loss weights (1D)
+            soft_masked: Boolean mask (1D), True for soft-masked positions
+            soft_masked_weight: Weight for soft-masked positions in training loss
 
         Returns:
-            Scalar loss value
+            Dictionary with keys: loss, loss_full, loss_non_soft_masked
         """
-        loss = F.cross_entropy(logits, labels, reduction="none")
-        loss = (loss * loss_weight / loss_weight.sum()).sum()
-        return loss
+        # Single cross-entropy computation (efficient)
+        loss_per_token = F.cross_entropy(logits, labels, reduction="none")
+
+        # Create three weight masks
+        weight_full = torch.ones_like(loss_per_token)
+        weight_non_soft_masked = (~soft_masked).float()
+        weight_training = torch.where(soft_masked, soft_masked_weight, 1.0)
+
+        # Compute normalized losses
+        def normalize_and_sum(loss: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
+            weight_sum = weight.sum()
+            if weight_sum > 0:
+                return (loss * weight / weight_sum).sum()
+            else:
+                # Handle edge case: no tokens with weight
+                return torch.tensor(0.0, device=loss.device, dtype=loss.dtype)
+
+        return {
+            "loss": normalize_and_sum(loss_per_token, weight_training),
+            "loss_full": normalize_and_sum(loss_per_token, weight_full),
+            "loss_non_soft_masked": normalize_and_sum(loss_per_token, weight_non_soft_masked),
+        }
 
     def forward(
         self,
         input_ids: torch.Tensor,
         labels: torch.Tensor,
-        loss_weight: torch.Tensor,
-    ) -> torch.Tensor:
+        soft_masked: torch.Tensor,
+        soft_masked_weight: float,
+    ) -> dict[str, torch.Tensor]:
         """Forward pass with loss calculation.
 
         Args:
             input_ids: Input token IDs of shape (batch, seq_len), int8 or long
             labels: True token IDs of shape (batch, seq_len)
-            loss_weight: Per-token loss weights of shape (batch, seq_len)
+            soft_masked: Boolean mask of shape (batch, seq_len), True for soft-masked positions
+            soft_masked_weight: Weight for soft-masked positions in training loss
 
         Returns:
-            Weighted cross-entropy loss (scalar)
+            Dictionary with loss components (loss, loss_full, loss_non_soft_masked)
         """
         logits = self.get_logits(input_ids)
-        logits, labels, loss_weight = self.prepare_for_loss(logits, labels, loss_weight)
-        loss = self.compute_loss(logits, labels, loss_weight)
-        return loss
+        logits, labels, soft_masked = self.prepare_for_loss(logits, labels, soft_masked)
+        return self.compute_loss(logits, labels, soft_masked, soft_masked_weight)
 
 
 class MLM(LM):
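
The three variants differ only in the weight vector applied to a single per-token cross-entropy pass. A minimal standalone sketch of that relationship, not part of the diff, with invented tensor values and vocab size:

```python
import torch
import torch.nn.functional as F

logits = torch.randn(6, 32)                       # 6 tokens, toy vocab of 32
labels = torch.randint(0, 32, (6,))
soft_masked = torch.tensor([0, 0, 1, 1, 0, 1], dtype=torch.bool)
soft_masked_weight = 0.1                          # illustrative down-weighting factor

per_token = F.cross_entropy(logits, labels, reduction="none")

# Training loss: soft-masked positions contribute with reduced weight
w_train = torch.where(soft_masked, soft_masked_weight, 1.0)
loss = (per_token * w_train / w_train.sum()).sum()

loss_full = per_token.mean()                      # all tokens weighted equally
loss_non_soft_masked = per_token[~soft_masked].mean()  # soft-masked excluded
```

Setting `soft_masked_weight=1.0` makes `loss` equal `loss_full`, and `0.0` makes it equal `loss_non_soft_masked`; the `weight_sum > 0` guard in `normalize_and_sum` covers the degenerate case where every position is soft-masked with weight zero.
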
@@ -130,30 +155,30 @@ def prepare_for_loss(
         self,
         logits: torch.Tensor,
         labels: torch.Tensor,
-        loss_weight: torch.Tensor,
+        soft_masked: torch.Tensor,
     ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         """Filter to masked positions only.
 
         Args:
             logits: Logits of shape (batch, seq_len, vocab_size)
             labels: Target labels of shape (batch, seq_len), -100 for ignored positions
-            loss_weight: Loss weights of shape (batch, seq_len)
+            soft_masked: Boolean mask of shape (batch, seq_len)
 
         Returns:
-            Filtered (logits, labels, loss_weight) for masked positions only
+            Filtered (logits, labels, soft_masked) for masked positions only
         """
         # Reshape to 1D
         logits = logits.view(-1, logits.size(-1))
         labels = labels.view(-1).long()
-        loss_weight = loss_weight.view(-1)
+        soft_masked = soft_masked.view(-1)
 
         # Filter to masked positions (labels != -100)
         mask = labels != -100
         logits = logits[mask]
         labels = labels[mask]
-        loss_weight = loss_weight[mask]
+        soft_masked = soft_masked[mask]
 
-        return logits, labels, loss_weight
+        return logits, labels, soft_masked
 
 
 class CLM(LM):
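
For MLM, everything hinges on the `-100` sentinel: `soft_masked` must be flattened and filtered in lockstep with `labels` so the two stay aligned after positions are dropped. A toy illustration (values invented):

```python
import torch

labels = torch.tensor([[-100, 7, -100, 3]])         # batch=1, seq_len=4
soft_masked = torch.tensor([[False, True, False, False]])
logits = torch.randn(1, 4, 32)                      # toy vocab of 32

# Flatten, then keep only positions with a real label
flat_logits = logits.view(-1, logits.size(-1))
flat_labels = labels.view(-1).long()
flat_soft_masked = soft_masked.view(-1)

keep = flat_labels != -100
print(flat_labels[keep])        # tensor([7, 3])
print(flat_soft_masked[keep])   # tensor([ True, False]) -- still aligned
```
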
@@ -182,21 +207,21 @@ def prepare_for_loss(
         self,
         logits: torch.Tensor,
         labels: torch.Tensor,
-        loss_weight: torch.Tensor,
+        soft_masked: torch.Tensor,
     ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         """Slice for next-token prediction.
 
         Args:
             logits: Logits of shape (batch, seq_len, vocab_size)
             labels: Target labels of shape (batch, seq_len) (same as input_ids)
-            loss_weight: Loss weights of shape (batch, seq_len)
+            soft_masked: Boolean mask of shape (batch, seq_len)
 
         Returns:
-            Sliced (logits, labels, loss_weight) for next-token prediction
+            Sliced (logits, labels, soft_masked) for next-token prediction
         """
         # Slice: logits[:, :-1] predicts labels[:, 1:]
         logits = logits[:, :-1].reshape(-1, logits.size(-1))
         labels = labels[:, 1:].reshape(-1).long()
-        loss_weight = loss_weight[:, 1:].reshape(-1)
+        soft_masked = soft_masked[:, 1:].reshape(-1)
 
-        return logits, labels, loss_weight
+        return logits, labels, soft_masked
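
For CLM, the key detail is that `soft_masked` is shifted along with the labels (dropping position 0), not with the inputs, so each surviving mask entry describes the token being predicted rather than the token being conditioned on. A toy illustration (values invented):

```python
import torch

labels = torch.tensor([[5, 9, 2, 4]])               # same ids as input_ids
soft_masked = torch.tensor([[False, False, True, False]])
logits = torch.randn(1, 4, 32)                      # toy vocab of 32

# Position t's logits predict token t+1
pred_logits = logits[:, :-1].reshape(-1, logits.size(-1))
next_labels = labels[:, 1:].reshape(-1).long()      # tensor([9, 2, 4])
next_soft_masked = soft_masked[:, 1:].reshape(-1)   # tensor([False,  True, False])
```
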