@@ -173,91 +173,6 @@ def fused_linear_cross_entropy_forward_megatron_chunked(
 
     return loss, None, grad_input, grad_weight, grad_bias
 
-def fused_linear_cross_entropy_forward_megatron(
-    _input,
-    weight,
-    target,
-    bias=None,
-    reduction="none",
-):
-    device = _input.device
-    BT, H = _input.shape
-    V = weight.shape[0]
-
-    grad_weight = torch.zeros_like(weight, device=device) if weight.requires_grad else None
-    grad_input = torch.zeros_like(_input, device=device)
-    grad_bias = torch.zeros_like(bias, device=device) if bias is not None else None
-    # we use fp32 for loss accumulator
-    loss_1d = torch.zeros(BT, dtype=torch.float32, device=device)
-
-    # TODO: evaluate how CUDA synchronization caused by .item() affects the speed
-    rank = get_tensor_model_parallel_rank()
-    world_size = get_tensor_model_parallel_world_size()
-    vocab_start, vocab_end = VocabUtility.vocab_range_from_per_partition_vocab_size(V, rank, world_size)
-
-    target_mask = (target < vocab_start) | (target >= vocab_end)
-    adjusted_target = target.clone() - vocab_start  # relative id
-    adjusted_target[target_mask] = 0
-    adjusted_target_1d = adjusted_target.view(-1)
-
-    # input
-    # when doing matmul, use the original precision
-    logits = (_input @ weight.t()).float()  # chunk_size x V
-    if bias is not None:
-        logits = logits + bias
-
-    # # ensure _input and target are contiguous
-    # logits_chunk = logits_chunk.contiguous()  # [chunk_size, vocab_size]
-    # target_chunk = target_chunk.contiguous()  # [chunk_size]
-
-    max_logits = torch.max(logits, dim=-1)[0]
-    torch.distributed.all_reduce(max_logits, op=torch.distributed.ReduceOp.MAX, group=get_tensor_model_parallel_group(), async_op=False)
-    logits = logits - max_logits.unsqueeze(-1)
-
-    sum_exp_logits = torch.sum(torch.exp(logits), dim=-1)
-    torch.distributed.all_reduce(sum_exp_logits, op=torch.distributed.ReduceOp.SUM, group=get_tensor_model_parallel_group(), async_op=False)
-
-
-    predicted_logits = logits[torch.arange(BT, device=logits.device), adjusted_target_1d]
-    predicted_logits[target_mask] = 0.0
-    handle_predicted_logits = torch.distributed.all_reduce(predicted_logits, op=torch.distributed.ReduceOp.SUM, group=get_tensor_model_parallel_group(), async_op=True)
-
-    # Compute gradient
-    grad_logits = torch.exp(logits).div_(sum_exp_logits.unsqueeze(-1))
-    grad_logits[torch.arange(BT, device=grad_logits.device), adjusted_target_1d] -= 1.0 - target_mask.float()  # chunk_size x V
-    grad_input = grad_logits.to(dtype=torch.half) @ weight
-    torch.distributed.all_reduce(grad_input, group=get_tensor_model_parallel_group(), async_op=False)
-
-    if grad_weight is not None:
-        torch.addmm(
-            input=grad_weight,
-            mat1=grad_logits.t().to(
-                _input.dtype
-            ),  # In an autocast scenario without bias, differing logits_chunk data types will cause an addmm operation error.
-            mat2=_input,
-            out=grad_weight,
-            alpha=1.0,
-            beta=1.0,
-        )
-    if bias is not None:
-        torch.add(
-            input=grad_bias,
-            other=grad_logits.sum(dim=0),
-            out=grad_bias,
-            alpha=1.0,
-        )
-    handle_predicted_logits.wait()
-    loss_chunk = torch.log(sum_exp_logits) - predicted_logits
-    loss_1d = loss_chunk
-
-    if reduction == "none":
-        loss = loss_1d
-    else:
-        loss = torch.sum(loss_1d)
-
-    return loss, None, grad_input, grad_weight, grad_bias
-
-
 def fused_linear_cross_entropy_backward(grad_output, grad_input, grad_weight, grad_bias):
     # If cross entropy is the last layer, grad_output is 1.0. Skip the mul to save time
     if not torch.equal(grad_output, torch.tensor(1.0, device=grad_output.device)):
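For reference, the removed forward pass assembles the loss from vocabulary-sharded logits with three reductions over the tensor-parallel group: a MAX over each rank's max logits, a SUM over each rank's exp-sums, and a SUM over the masked target logits. The single-process sketch below is not from this PR; it simulates the shards as plain tensor slices (the name vocab_parallel_ce_reference and the two-shard split are illustrative) so the same math can be checked against torch.nn.functional.cross_entropy without torch.distributed.

import torch

def vocab_parallel_ce_reference(logit_shards, target):
    # logit_shards: list of [BT, V_i] tensors whose concatenation along the last
    # dim is the full logits matrix; each entry plays the role of one rank's shard.
    BT = target.shape[0]
    starts, total = [], 0
    for s in logit_shards:
        starts.append(total)
        total += s.shape[1]

    # stands in for all_reduce(MAX): global max over shards, for numerical stability
    max_logits = torch.stack([s.max(dim=-1).values for s in logit_shards]).max(dim=0).values

    sum_exp = torch.zeros(BT)
    predicted = torch.zeros(BT)
    for s, start in zip(logit_shards, starts):
        shifted = s - max_logits.unsqueeze(-1)
        sum_exp += torch.exp(shifted).sum(dim=-1)        # stands in for all_reduce(SUM)
        # a shard contributes the target logit only when the target id falls in
        # its vocab partition; mask-and-accumulate mirrors the masked all_reduce(SUM)
        mask = (target >= start) & (target < start + s.shape[1])
        idx = (target - start).clamp(0, s.shape[1] - 1)  # masked-out rows use a dummy index
        local = shifted[torch.arange(BT), idx]
        predicted += torch.where(mask, local, torch.zeros_like(local))

    return torch.log(sum_exp) - predicted  # per-token loss, reduction="none"

torch.manual_seed(0)
full = torch.randn(4, 10)
target = torch.randint(0, 10, (4,))
loss = vocab_parallel_ce_reference([full[:, :6], full[:, 6:]], target)
ref = torch.nn.functional.cross_entropy(full, target, reduction="none")
assert torch.allclose(loss, ref, atol=1e-5)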
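The removed code also materializes the logits gradient inside the forward pass: grad_logits is softmax(logits) with 1.0 subtracted at the in-partition target index, which is the exact cross-entropy gradient under sum reduction. A minimal autograd check of that identity (again illustrative, not taken from the PR):

import torch

logits = torch.randn(4, 10, requires_grad=True)
target = torch.randint(0, 10, (4,))
torch.nn.functional.cross_entropy(logits, target, reduction="sum").backward()

manual = torch.softmax(logits.detach(), dim=-1)
manual[torch.arange(4), target] -= 1.0  # subtract the one-hot at the target id
assert torch.allclose(logits.grad, manual, atol=1e-6)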