Commit 5d06c10

Add decomposition for aten.native_batch_norm_backward op
This commit adds a decomposition for the `aten.native_batch_norm_backward` op.

Signed-off-by: Gaurav Shukla <[email protected]>
1 parent 6a52d17 commit 5d06c10
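
For reference, a short summary of the math (not part of the commit): writing $g$ for `grad_out`, $\sigma^{-1}$ for the saved inverse standard deviation `invstd`, $\hat{x} = (x - \mu)\,\sigma^{-1}$ for the normalized input, and $N$ for the number of elements per channel, the training-mode gradients computed by the decomposition below are

$$
\frac{\partial L}{\partial \beta} = \sum g, \qquad
\frac{\partial L}{\partial \gamma} = \sum g\,\hat{x}, \qquad
\frac{\partial L}{\partial x} = \gamma\,\sigma^{-1}\Big(g - \frac{1}{N}\sum g - \frac{\hat{x}}{N}\sum g\,\hat{x}\Big),
$$

where each sum runs over every axis except the channel axis. In the code, `dot_p` is $\sum g\,(x - \mu)$, so `grad_weight = dot_p * invstd` equals $\sum g\,\hat{x}$, while `grad_mean` and `proj` supply the second and third terms inside the parentheses.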

File tree

1 file changed: +54 -0 lines
functorch/_src/decompositions.py

@@ -497,6 +497,60 @@ def native_layer_norm_backward(grad_out: Tensor, input: Tensor, normalized_shape
     return (d_input, d_weight, d_bias)
 
 
+@register_decomposition(aten.native_batch_norm_backward)
+def native_batch_norm_backward(grad_out: Tensor, input: Tensor, weight: Optional[Tensor], running_mean: Optional[Tensor], running_var: Optional[Tensor], save_mean: Optional[Tensor], save_invstd: Optional[Tensor], train: bool, eps: float, output_mask: List[bool]) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]:
+    input_shape = input.shape
+    input_rank = input.dim()
+    assert input_rank >= 2, "rank of the input must be at least 2"
+
+    axis = 1
+    num_features = prod(input_shape) / input_shape[axis]
+    mean = save_mean
+    invstd = save_invstd
+    if train:
+        assert save_mean is not None and save_invstd is not None, "when train=True, save_mean and save_invstd are required"
+    else:
+        mean = running_mean
+        invstd = torch.rsqrt(running_var + eps)
+
+    broadcast_mask = [1] * input_rank
+    broadcast_mask[axis] = input_shape[axis]
+
+    reduction_axes = []
+    for i in range(input_rank):
+        if i != axis:
+            reduction_axes.append(i)
+
+    mean = torch.reshape(mean, broadcast_mask)
+    norm = 1.0 / num_features
+    grad_output_sum = torch.sum(grad_out, reduction_axes)
+    dot_p = torch.sum(grad_out * (input - mean), reduction_axes)
+
+    grad_mean = torch.reshape(grad_output_sum * norm, broadcast_mask)
+    proj_scale = torch.reshape(torch.mul(dot_p * norm, invstd * invstd), broadcast_mask)
+
+    grad_scale = None
+    if weight is None:
+        grad_scale = torch.reshape(invstd, broadcast_mask) * 1.0
+    else:
+        grad_scale = torch.reshape(invstd * weight, broadcast_mask)
+    grad_input = None
+    if train:
+        proj = (input - mean) * proj_scale
+        grad_input = ((grad_out - proj) - grad_mean) * grad_scale
+    else:
+        grad_input = grad_out * grad_scale
+
+    grad_weight = None
+    if output_mask[1]:
+        grad_weight = dot_p * invstd
+
+    grad_bias = None
+    if output_mask[2]:
+        grad_bias = grad_output_sum
+    return (grad_input, grad_weight, grad_bias)
+
+
 @register_decomposition(aten.clamp_min)
 def clamp_min(self: Tensor, min: float):
     return torch.clamp(self, min=min)
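
As a quick sanity check (a hypothetical sketch, not part of the commit), the decomposition can be compared against gradients obtained via autograd through `torch.native_batch_norm`; the shapes, seed, and tolerances below are illustrative assumptions:

import torch
# The function added by this commit, imported from the file above.
from functorch._src.decompositions import native_batch_norm_backward

torch.manual_seed(0)
x = torch.randn(4, 3, 8, 8, requires_grad=True)
weight = torch.randn(3, requires_grad=True)
bias = torch.randn(3, requires_grad=True)
eps = 1e-5

# Forward pass in training mode; running stats are not needed here.
out, save_mean, save_invstd = torch.native_batch_norm(
    x, weight, bias, None, None, True, 0.1, eps)
grad_out = torch.randn_like(out)

# Reference gradients via autograd.
ref_dx, ref_dw, ref_db = torch.autograd.grad(out, (x, weight, bias), grad_out)

# Gradients from the decomposition (train=True, all outputs requested).
dx, dw, db = native_batch_norm_backward(
    grad_out, x.detach(), weight.detach(), None, None,
    save_mean, save_invstd, True, eps, [True, True, True])

assert torch.allclose(dx, ref_dx, atol=1e-5)
assert torch.allclose(dw, ref_dw, atol=1e-5)
assert torch.allclose(db, ref_db, atol=1e-5)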
