@@ -497,6 +497,66 @@ def native_layer_norm_backward(grad_out: Tensor, input: Tensor, normalized_shape
     return (d_input, d_weight, d_bias)
 
 
+@register_decomposition(aten.native_batch_norm_backward)
+def native_batch_norm_backward(grad_out: Tensor, input: Tensor, weight: Optional[Tensor], running_mean: Optional[Tensor], running_var: Optional[Tensor], save_mean: Optional[Tensor], save_invstd: Optional[Tensor], train: bool, eps: float, output_mask: List[bool]) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]:
+    input_shape = input.shape
+    input_rank = input.dim()
+    assert input_rank >= 2, "rank of the input must be at least 2"
+
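+    # Batch norm statistics are per-channel (axis 1); num_features is the
+    # number of elements reduced per channel.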
+    axis = 1
+    num_features = prod(input_shape) / input_shape[axis]
+    mean = save_mean
+    invstd = save_invstd
+    if train:
+        assert save_mean is not None and save_invstd is not None, "when train=True, save_mean and save_invstd are required"
+    else:
+        assert running_mean is not None and running_var is not None
+        mean = running_mean
+        invstd = torch.rsqrt(running_var + eps)
+
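+    # Mask of shape [1, C, 1, ...] so the per-channel stats broadcast
+    # against the input.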
+    broadcast_mask = [1] * input_rank
+    broadcast_mask[axis] = input_shape[axis]
+
+    reduction_axes = []
+    for i in range(input_rank):
+        if i != axis:
+            reduction_axes.append(i)
+
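+    # Reduce grad_out over all non-channel dims; dot_p is the per-channel
+    # dot product of grad_out with (input - mean).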
+    mean = torch.reshape(mean, broadcast_mask)
+    norm = 1.0 / num_features
+    grad_output_sum = torch.sum(grad_out, reduction_axes)
+    dot_p = torch.sum(grad_out * (input - mean), reduction_axes)
+
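+    # grad_mean and proj_scale form the projection term used in the
+    # training-mode input gradient below.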
+    grad_mean = torch.reshape(grad_output_sum * norm, broadcast_mask)
+    proj_scale = torch.reshape(torch.mul(dot_p * norm, invstd * invstd), broadcast_mask)
+
+    grad_scale = None
+    if weight is None:
+        grad_scale = torch.reshape(invstd, broadcast_mask) * 1.0
+    else:
+        grad_scale = torch.reshape(invstd * weight, broadcast_mask)
+    grad_input = None
+    if train:
+        proj = (input - mean) * proj_scale
+        grad_input = ((grad_out - proj) - grad_mean) * grad_scale
+    else:
+        grad_input = grad_out * grad_scale
+
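+    # dot_p * invstd equals sum(grad_out * normalized_input) per channel,
+    # i.e. the weight gradient.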
+    grad_weight = None
+    if output_mask[1]:
+        grad_weight = dot_p * invstd
+
+    grad_bias = None
+    if output_mask[2]:
+        grad_bias = grad_output_sum
+    return (grad_input, grad_weight, grad_bias)
+
+
 @register_decomposition(aten.clamp_min)
 def clamp_min(self: Tensor, min: float):
     return torch.clamp(self, min=min)
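A hypothetical sanity check for the new decomposition, not part of this commit: it compares the Python function above against the eager aten kernel in training mode, assuming prod and the other helpers from this file are in scope and that register_decomposition returns the function unchanged so it can be called directly.

    # Hypothetical parity check on a random NCHW input, training mode.
    import torch

    x = torch.randn(4, 3, 8, 8, dtype=torch.double)
    weight = torch.randn(3, dtype=torch.double)
    grad_out = torch.randn_like(x)
    # save_mean / save_invstd as the forward batch norm would compute them
    # (biased variance).
    save_mean = x.mean(dim=(0, 2, 3))
    save_invstd = torch.rsqrt(x.var(dim=(0, 2, 3), unbiased=False) + 1e-5)
    args = (grad_out, x, weight, None, None, save_mean, save_invstd, True, 1e-5, [True, True, True])

    expected = torch.ops.aten.native_batch_norm_backward(*args)
    actual = native_batch_norm_backward(*args)
    for e, a in zip(expected, actual):
        torch.testing.assert_close(e, a)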