
[Nvidia][Gluon] Refactor convolution fprop kernel, add wgrad and dgrad kernels #10030

Open
bingyizh233 wants to merge 8 commits into triton-lang:main from bingyizh233:conv-backward

Conversation

@bingyizh233
Contributor

Summary

This PR cleans up the Gluon convolution examples and makes the three kernels easier to understand as a family:

  • 02-conv-fprop.py implements forward convolution (fprop)
  • 02-conv-wgrad.py implements weight-gradient computation (wgrad)
  • 02-conv-dgrad.py implements input-gradient computation (dgrad)
  • 02-conv-common.py holds the shared runtime utilities used by all three kernels

All three kernels use the same high-level execution pattern:

  • NHWC / OHWI tensor layouts on the host side
  • TMA-based tiled loads, including im2col descriptors where needed
  • warp-specialized load, mma, and epilogue partitions
  • persistent tile scheduling
  • TMEM-backed fp32 accumulation
  • optional padding / stride fixes to satisfy TMA alignment requirements
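
As a minimal illustration of the persistent-scheduling item above (the function name and signature here are mine, not the PR's): rather than launching one program per output tile, a fixed pool of persistent programs walks the flattened tile space with a grid-sized stride.

```python
# Sketch of persistent tile scheduling (illustrative, not the PR's code):
# `num_sms` persistent programs cover `num_tiles` tiles; program `pid`
# owns every num_sms-th tile starting at its own index.
def persistent_tiles(num_tiles, num_sms, pid):
    """Yield the tile indices owned by persistent program `pid`."""
    for tile in range(pid, num_tiles, num_sms):
        yield tile
```

Each tile index would then be decoded into a (batch, spatial, channel) tile coordinate for the convolution.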

Kernel Inputs And Outputs

Forward (conv2d_fprop / conv2d_fprop_fixed)

  • conv2d_fprop is the autotuned kernel
  • conv2d_fprop_fixed uses a fixed config for CI

Inputs

  • input_tensor: NHWC activation tensor with shape [N, H, W, Ci]
  • weight_tensor: OHWI filter tensor with shape [Co, R, S, Ci]
  • stride: int or (stride_h, stride_w)
  • padding: int or (pad_h, pad_w)

Output

  • output activation tensor with shape [N, out_h, out_w, Co]

Logical GEMM

  • M = N * out_h * out_w
  • N = Co
  • K = R * S * Ci
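
The fprop-as-GEMM mapping above can be checked numerically with a small NumPy sketch (stride 1, zero padding; the sizes and helper code here are illustrative, not the PR's kernel):

```python
# Sketch: fprop as im2col(input)[M, K] @ weight[K, Co], with
# M = n*out_h*out_w, K = R*S*Ci, and NHWC / OHWI layouts.
import numpy as np

n, H, W, Ci, Co, R, S, pad = 1, 5, 5, 2, 3, 3, 3, 1
out_h = H + 2 * pad - R + 1
out_w = W + 2 * pad - S + 1
rng = np.random.default_rng(0)
x = rng.standard_normal((n, H, W, Ci))    # NHWC activations
w = rng.standard_normal((Co, R, S, Ci))   # OHWI filter

# im2col: gather [M, R*S*Ci] patches from the zero-padded input
xp = np.pad(x, ((0, 0), (pad, pad), (pad, pad), (0, 0)))
cols = np.stack([xp[:, r:r + out_h, s:s + out_w, :]
                 for r in range(R) for s in range(S)], axis=3)
cols = cols.reshape(n * out_h * out_w, R * S * Ci)

# Logical GEMM: [M, K] @ [K, Co], reshaped back to NHWC output
y = (cols @ w.reshape(Co, -1).T).reshape(n, out_h, out_w, Co)

# Naive reference convolution
y_ref = np.zeros((n, out_h, out_w, Co))
for i in range(out_h):
    for j in range(out_w):
        y_ref[:, i, j, :] = np.einsum('nrsc,orsc->no',
                                      xp[:, i:i + R, j:j + S, :], w)
assert np.allclose(y, y_ref)
```

In the kernel itself this gather is done by the hardware via TMA im2col descriptors rather than materialized in memory.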

Weight Gradient (conv2d_wgrad / conv2d_wgrad_fixed)

  • conv2d_wgrad is the autotuned kernel
  • conv2d_wgrad_fixed uses a fixed config for CI

Inputs

  • input_nhwc: forward input activation tensor with shape [N, H, W, Ci]
  • grad_output_nhwc: output gradient tensor with shape [N, out_h, out_w, Co]
  • R, S: filter height / width
  • stride: int or (stride_h, stride_w)
  • padding: int or (pad_h, pad_w)

Output

  • weight gradient tensor with shape [Co, R, S, Ci]

Logical GEMM

  • grad_W[Co, R*S*Ci] = grad_out[M, Co]^T @ im2col(input)[M, R*S*Ci]
  • M = N * out_h * out_w
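
The wgrad GEMM formulation above can be verified numerically with a small NumPy sketch (stride 1, zero padding; sizes and helper code are illustrative, not the PR's kernel):

```python
# Sketch: grad_W[Co, R*S*Ci] = grad_out[M, Co]^T @ im2col(input)[M, R*S*Ci]
import numpy as np

N, H, W, Ci, Co, R, S, pad = 1, 5, 5, 2, 3, 3, 3, 1
out_h = H + 2 * pad - R + 1
out_w = W + 2 * pad - S + 1
rng = np.random.default_rng(0)
x = rng.standard_normal((N, H, W, Ci))           # forward NHWC input
gy = rng.standard_normal((N, out_h, out_w, Co))  # NHWC output gradient

# im2col(input): [M, R*S*Ci] with M = N*out_h*out_w
xp = np.pad(x, ((0, 0), (pad, pad), (pad, pad), (0, 0)))
cols = np.stack([xp[:, r:r + out_h, s:s + out_w, :]
                 for r in range(R) for s in range(S)], axis=3)
cols = cols.reshape(N * out_h * out_w, R * S * Ci)

# The logical GEMM from the bullet above
gw_gemm = gy.reshape(-1, Co).T @ cols

# Naive reference: accumulate filter gradient over every output position
gw_ref = np.zeros((Co, R, S, Ci))
for n in range(N):
    for i in range(out_h):
        for j in range(out_w):
            patch = xp[n, i:i + R, j:j + S, :]   # [R, S, Ci]
            gw_ref += gy[n, i, j][:, None, None, None] * patch
assert np.allclose(gw_gemm.reshape(Co, R, S, Ci), gw_ref)
```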

Input Gradient (conv2d_dgrad / conv2d_dgrad_fixed)

  • conv2d_dgrad is the autotuned kernel
  • conv2d_dgrad_fixed uses a fixed config for CI

Inputs

  • grad_output_nhwc: output gradient tensor with shape [N, out_h, out_w, Co]
  • weight_nhwc: forward weight tensor with shape [Co, R, S, Ci]
  • H_in, W_in: spatial shape of the original input tensor
  • stride: int or (stride_h, stride_w)
  • padding: int or (pad_h, pad_w)

Output

  • input gradient tensor with shape [N, H_in, W_in, Ci]

Logical GEMM

  • dgrad is expressed as a convolution / GEMM over grad_output using rotated weights
  • for stride > 1, the host decomposes the problem into stride-specific subproblems
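
For the stride-1 case, the rotated-weight formulation can be checked numerically with a small NumPy sketch (zero padding; sizes and helper code are illustrative, not the PR's kernel): grad_input is a correlation of grad_output with the filter flipped in both spatial dimensions and with input/output channels swapped.

```python
# Sketch: dgrad as a convolution over grad_output with rotated weights
import numpy as np

N, H, W, Ci, Co, R, S, pad = 1, 5, 5, 2, 3, 3, 3, 1
out_h = H + 2 * pad - R + 1
out_w = W + 2 * pad - S + 1
rng = np.random.default_rng(1)
w = rng.standard_normal((Co, R, S, Ci))          # OHWI forward filter
gy = rng.standard_normal((N, out_h, out_w, Co))  # NHWC output gradient

# Reference: scatter each output gradient back onto the padded input
dxp = np.zeros((N, H + 2 * pad, W + 2 * pad, Ci))
for n in range(N):
    for i in range(out_h):
        for j in range(out_w):
            dxp[n, i:i + R, j:j + S, :] += np.einsum('o,orsc->rsc',
                                                     gy[n, i, j], w)
dx_ref = dxp[:, pad:pad + H, pad:pad + W, :]

# Rotated-weight form: pad grad_output by R-1-pad / S-1-pad and correlate
# with the 180-degree-rotated filter, channels transposed to [Ci, R, S, Co]
w_rot = w[:, ::-1, ::-1, :].transpose(3, 1, 2, 0)
gyp = np.pad(gy, ((0, 0), (R - 1 - pad,) * 2, (S - 1 - pad,) * 2, (0, 0)))
dx = np.zeros((N, H, W, Ci))
for i in range(H):
    for j in range(W):
        patch = gyp[:, i:i + R, j:j + S, :]      # [N, R, S, Co]
        dx[:, i, j, :] = np.einsum('nrso,crso->nc', patch, w_rot)
assert np.allclose(dx, dx_ref)
```

With stride > 1 the output gradients land on a dilated grid, which is why the host splits the problem into stride-specific subproblems, each of which is again a stride-1 convolution of this form.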

Block Diagrams

End-to-End Forward / Backward View


Performance

conv2d_fprop

Conv2d N=128 Ci=384 Co=384 H=64 W=64 R=3 S=3 stride=1 pad=1:

| kernel    | Gluon (autotuned) (TFLOPS) | PyTorch (TFLOPS) |
|-----------|----------------------------|------------------|
| autotuned | 1062.950991                | 991.128692       |

conv2d_wgrad

Wgrad N=128 Ci=384 Co=384 H=64 W=64 R=3 S=3 stride=1 pad=1:

| kernel    | Gluon (autotuned) (TFLOPS) | PyTorch (TFLOPS) |
|-----------|----------------------------|------------------|
| autotuned | 1016.68141                 | 806.371836       |

conv2d_dgrad

Dgrad N=128 Ci=384 Co=384 H=64 W=64 R=3 S=3 stride=1 pad=1:

| kernel    | Gluon (autotuned) (TFLOPS) | PyTorch (TFLOPS) |
|-----------|----------------------------|------------------|
| autotuned | 1015.414674                | 867.544942       |

Acknowledgement

This PR is a collaboration with @shangz-ai and gpt-5.4-high.
@shangz-ai helped shape the ideas and provided constructive feedback on functionality and performance.
gpt-5.4-high helped with much of the engineering work.

@bingyizh233 bingyizh233 requested a review from ptillet as a code owner April 14, 2026 21:48
@bingyizh233
Contributor Author

Merged main. The CI failures are unrelated to this PR.

@bingyizh233
Contributor Author

The last run failed due to a CI issue; merged main again.
