Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion src/liger_kernel/ops/group_norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,17 @@ def group_norm_backward(dY, X, W, B, Mean, RSTD, num_channels, num_groups):
)
DW = torch.zeros((num_channels), dtype=W.dtype, device=W.device)
DB = torch.zeros((num_channels), dtype=B.dtype, device=B.device)
triton_dtype = tl.float32 if X.dtype == torch.float32 else tl.bfloat16
# `triton_dtype` is the dtype that the dW/dB accumulators are cast to before
# being atomically added into the DW/DB buffers, so it should match the dtype
# of those buffers (W.dtype / B.dtype). Previously every non-fp32 dtype was
# mapped to bfloat16, so for fp16 inputs a bfloat16-rounded gradient was
# atomically added into an fp16 buffer, losing precision unnecessarily.
# fp16 now maps to tl.float16. Note that the selection below is keyed on
# X.dtype, which assumes X has the same dtype as W and B.
if X.dtype == torch.float32:
triton_dtype = tl.float32
elif X.dtype == torch.float16:
triton_dtype = tl.float16
else:
triton_dtype = tl.bfloat16

BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(hidden_size))
_group_norm_backward_kernel[(batch_size, num_groups)](
Expand Down