
Commit cbeaf7a
Add decomposition for aten.native_batch_norm_backward op
This commit adds a decomposition for the `aten.native_batch_norm_backward` op. Signed-off-by: Gaurav Shukla <[email protected]>
1 parent: fb6f749 · commit: cbeaf7a
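For reference, in training mode the decomposition below implements the standard batch-norm backward formulas. Writing g = grad_out, mu = save_mean, 1/sigma = save_invstd, gamma = weight (treated as 1 when weight is None), and m for the number of elements reduced per channel (all axes except axis 1), the returned gradients are:

\[
\frac{\partial L}{\partial x} = \gamma\,\sigma^{-1}\Big(g - \tfrac{1}{m}\textstyle\sum g - (x-\mu)\,\sigma^{-2}\,\tfrac{1}{m}\textstyle\sum g\,(x-\mu)\Big),\qquad
\frac{\partial L}{\partial \gamma} = \sigma^{-1}\textstyle\sum g\,(x-\mu),\qquad
\frac{\partial L}{\partial \beta} = \textstyle\sum g,
\]

where each sum runs over the non-channel axes. In eval mode (train=False), mean and invstd come from running_mean and running_var instead, and the input gradient reduces to g * gamma * sigma^{-1}.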

File tree: 1 file changed, +55 −0 lines changed


functorch/_src/decompositions.py (+55 lines)
@@ -497,6 +497,61 @@ def native_layer_norm_backward(grad_out: Tensor, input: Tensor, normalized_shape
     return (d_input, d_weight, d_bias)
 
 
+@register_decomposition(aten.native_batch_norm_backward)
+def native_batch_norm_backward(grad_out: Tensor, input: Tensor, weight: Optional[Tensor], running_mean: Optional[Tensor], running_var: Optional[Tensor], save_mean: Optional[Tensor], save_invstd: Optional[Tensor], train: bool, eps: float, output_mask: List[bool]) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]:
+    input_shape = input.shape
+    input_rank = input.dim()
+    assert input_rank >= 2, "rank of the input must be at least 2"
+
+    axis = 1
+    num_features = prod(input_shape) / input_shape[axis]
+    mean = save_mean
+    invstd = save_invstd
+    if train:
+        assert save_mean is not None and save_invstd is not None, "when train=True, save_mean and save_invstd are required"
+    else:
+        assert running_mean is not None and running_var is not None
+        mean = running_mean
+        invstd = torch.rsqrt(running_var + eps)
+
+    broadcast_mask = [1] * input_rank
+    broadcast_mask[axis] = input_shape[axis]
+
+    reduction_axes = []
+    for i in range(input_rank):
+        if i != axis:
+            reduction_axes.append(i)
+
+    mean = torch.reshape(mean, broadcast_mask)
+    norm = 1.0 / num_features
+    grad_output_sum = torch.sum(grad_out, reduction_axes)
+    dot_p = torch.sum(grad_out * (input - mean), reduction_axes)
+
+    grad_mean = torch.reshape(grad_output_sum * norm, broadcast_mask)
+    proj_scale = torch.reshape(torch.mul(dot_p * norm, invstd * invstd), broadcast_mask)
+
+    grad_scale = None
+    if weight is None:
+        grad_scale = torch.reshape(invstd, broadcast_mask) * 1.0
+    else:
+        grad_scale = torch.reshape(invstd * weight, broadcast_mask)
+    grad_input = None
+    if train:
+        proj = (input - mean) * proj_scale
+        grad_input = ((grad_out - proj) - grad_mean) * grad_scale
+    else:
+        grad_input = grad_out * grad_scale
+
+    grad_weight = None
+    if output_mask[1]:
+        grad_weight = dot_p * invstd
+
+    grad_bias = None
+    if output_mask[2]:
+        grad_bias = grad_output_sum
+    return (grad_input, grad_weight, grad_bias)
+
+
 @register_decomposition(aten.clamp_min)
 def clamp_min(self: Tensor, min: float):
     return torch.clamp(self, min=min)
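Not part of the commit: a minimal sketch of how one might sanity-check the decomposition against PyTorch's native kernel. It assumes the module is importable as functorch._src.decompositions (the file path in this commit); the shapes, momentum, and tolerance below are illustrative choices.

    import torch
    from functorch._src.decompositions import native_batch_norm_backward

    # Illustrative NCHW input in float64 so numerical noise stays small.
    N, C, H, W = 4, 3, 5, 5
    x = torch.randn(N, C, H, W, dtype=torch.float64)
    weight = torch.randn(C, dtype=torch.float64)
    bias = torch.randn(C, dtype=torch.float64)
    running_mean = torch.zeros(C, dtype=torch.float64)
    running_var = torch.ones(C, dtype=torch.float64)
    eps = 1e-5

    # Run the native forward in training mode to obtain save_mean/save_invstd.
    out, save_mean, save_invstd = torch.ops.aten.native_batch_norm(
        x, weight, bias, running_mean, running_var, True, 0.1, eps)
    grad_out = torch.randn_like(out)

    # Compare the native backward against the decomposition on the same inputs.
    args = (grad_out, x, weight, running_mean, running_var,
            save_mean, save_invstd, True, eps, [True, True, True])
    ref = torch.ops.aten.native_batch_norm_backward(*args)
    dec = native_batch_norm_backward(*args)
    for r, d in zip(ref, dec):
        assert torch.allclose(r, d, atol=1e-9)

Flipping train to False (with the saved stats ignored in favor of running_mean/running_var) exercises the eval-mode branch, where grad_input is simply grad_out scaled by gamma * invstd.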
