@@ -497,6 +497,60 @@ def native_layer_norm_backward(grad_out: Tensor, input: Tensor, normalized_shape
    return (d_input, d_weight, d_bias)

+@register_decomposition(aten.native_batch_norm_backward)
+def native_batch_norm_backward(
+    grad_out: Tensor,
+    input: Tensor,
+    weight: Optional[Tensor],
+    running_mean: Optional[Tensor],
+    running_var: Optional[Tensor],
+    save_mean: Optional[Tensor],
+    save_invstd: Optional[Tensor],
+    train: bool,
+    eps: float,
+    output_mask: List[bool],
+) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]:
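+    # Decomposed backward of native_batch_norm: computes the gradients
+    # w.r.t. the input, weight (gamma), and bias (beta) from the saved
+    # (or running) statistics, with the channel dimension fixed at axis 1.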
+    input_shape = input.shape
+    input_rank = input.dim()
+    assert input_rank >= 2, "rank of the input must be at least 2"
+
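+    # num_features is the per-channel reduction count (N * H * W for NCHW
+    # input); `prod` is assumed to be the shape-product helper already used
+    # by the other decompositions in this file.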
+    axis = 1
+    num_features = prod(input_shape) / input_shape[axis]
+    mean = save_mean
+    invstd = save_invstd
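+    # Training mode uses the statistics saved by the forward pass; eval mode
+    # falls back to the running statistics and recomputes invstd from
+    # running_var.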
+    if train:
+        assert save_mean is not None and save_invstd is not None, "when train=True, save_mean and save_invstd are required"
+    else:
+        assert running_mean is not None and running_var is not None, "when train=False, running_mean and running_var are required"
+        mean = running_mean
+        invstd = torch.rsqrt(running_var + eps)
+
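+    # broadcast_mask reshapes the per-channel statistics so that they
+    # broadcast against the full input, e.g. (C,) -> (1, C, 1, 1) for NCHW.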
+    broadcast_mask = [1] * input_rank
+    broadcast_mask[axis] = input_shape[axis]
+
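+    # Reduce over every dimension except the channel axis (batch + spatial).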
+    reduction_axes = []
+    for i in range(input_rank):
+        if i != axis:
+            reduction_axes.append(i)
+
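+    # grad_output_sum and dot_p are per-channel reductions of the upstream
+    # gradient; they feed grad_bias, grad_weight, and the input-gradient
+    # terms below.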
+    mean = torch.reshape(mean, broadcast_mask)
+    norm = 1.0 / num_features
+    grad_output_sum = torch.sum(grad_out, reduction_axes)
+    dot_p = torch.sum(grad_out * (input - mean), reduction_axes)
+
+    grad_mean = torch.reshape(grad_output_sum * norm, broadcast_mask)
+    proj_scale = torch.reshape(torch.mul(dot_p * norm, invstd * invstd), broadcast_mask)
+
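+    # grad_scale folds invstd with the affine weight when one is present; the
+    # `* 1.0` in the unweighted branch presumably forces a fresh tensor
+    # rather than returning a reshaped view.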
+    if weight is None:
+        grad_scale = torch.reshape(invstd, broadcast_mask) * 1.0
+    else:
+        grad_scale = torch.reshape(invstd * weight, broadcast_mask)
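+    # In training mode the batch statistics depend on the input, so the
+    # projection onto the centered input and the mean-gradient term are
+    # subtracted before scaling; in eval mode the statistics are constants
+    # and the gradient is a plain rescaling of grad_out.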
+    if train:
+        proj = (input - mean) * proj_scale
+        grad_input = ((grad_out - proj) - grad_mean) * grad_scale
+    else:
+        grad_input = grad_out * grad_scale
+
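+    # output_mask[1] / output_mask[2] indicate whether the caller needs the
+    # weight and bias gradients; skipped outputs are returned as None.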
+    grad_weight = None
+    if output_mask[1]:
+        grad_weight = dot_p * invstd
+
+    grad_bias = None
+    if output_mask[2]:
+        grad_bias = grad_output_sum
+
+    return (grad_input, grad_weight, grad_bias)
+
+
@register_decomposition(aten.clamp_min)
def clamp_min(self: Tensor, min: float):
    return torch.clamp(self, min=min)