[Nvidia][Gluon] Refactor convolution fprop kernel, add wgrad and dgrad kernels #10030
Open
bingyizh233 wants to merge 8 commits into triton-lang:main from
Conversation
Contributor (Author): Merge main. There are unrelated failures in the CI.
Contributor (Author): Failed last time due to a CI issue. Merging main again.
ThomasRaoux approved these changes on Apr 16, 2026.
Summary
This PR cleans up the Gluon convolution examples and makes the three kernels easier to understand as a family:
- `02-conv-fprop.py` implements forward convolution (fprop)
- `02-conv-wgrad.py` implements weight-gradient computation (wgrad)
- `02-conv-dgrad.py` implements input-gradient computation (dgrad)
- `02-conv-common.py` holds the shared runtime utilities used by all three kernels

All three kernels use the same high-level execution pattern: `load`, `mma`, and `epilogue` partitions.

Kernel Inputs And Outputs
Forward (`conv2d_fprop` / `conv2d_fprop_fixed`)

- `conv2d_fprop` is the kernel with the config autotune
- `conv2d_fprop_fixed` is for CI with a fixed config

Inputs

- `input_tensor`: NHWC activation tensor with shape `[N, H, W, Ci]`
- `weight_tensor`: OHWI filter tensor with shape `[Co, R, S, Ci]`
- `stride`: int or `(stride_h, stride_w)`
- `padding`: int or `(pad_h, pad_w)`

Output

- `[N, out_h, out_w, Co]`

Logical GEMM
- M = N * out_h * out_w
- N = Co
- K = R * S * Ci
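This mapping can be illustrated with a small plain-NumPy im2col sketch. This is a reference implementation, not the Gluon kernel; the `im2col_nhwc` helper name and shapes here are made up for illustration:

```python
import numpy as np

def im2col_nhwc(x, R, S, stride, pad):
    """Gather an [M, R*S*Ci] patch matrix from an NHWC tensor (illustrative helper)."""
    N, H, W, Ci = x.shape
    out_h = (H + 2 * pad - R) // stride + 1
    out_w = (W + 2 * pad - S) // stride + 1
    xp = np.pad(x, ((0, 0), (pad, pad), (pad, pad), (0, 0)))
    cols = np.empty((N, out_h, out_w, R, S, Ci), dtype=x.dtype)
    for r in range(R):
        for s in range(S):
            cols[:, :, :, r, s, :] = xp[
                :, r : r + out_h * stride : stride, s : s + out_w * stride : stride, :
            ]
    # M = N * out_h * out_w, K = R * S * Ci
    return cols.reshape(N * out_h * out_w, R * S * Ci), out_h, out_w

# fprop as the logical GEMM: [M, K] @ [K, Co] -> [M, Co]
rng = np.random.default_rng(0)
x = rng.standard_normal((2, 5, 5, 3))   # [N, H, W, Ci]
w = rng.standard_normal((4, 3, 3, 3))   # [Co, R, S, Ci], OHWI
cols, out_h, out_w = im2col_nhwc(x, R=3, S=3, stride=1, pad=1)
y = cols @ w.reshape(4, -1).T           # [M, Co]
y = y.reshape(2, out_h, out_w, 4)       # [N, out_h, out_w, Co]
print(y.shape)                          # (2, 5, 5, 4)
```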
Weight Gradient (`conv2d_wgrad` / `conv2d_wgrad_fixed`)

- `conv2d_wgrad` is the kernel with the config autotune
- `conv2d_wgrad_fixed` is for CI with a fixed config

Inputs
- `input_nhwc`: forward input activation tensor with shape `[N, H, W, Ci]`
- `grad_output_nhwc`: output gradient tensor with shape `[N, out_h, out_w, Co]`
- `R`, `S`: filter height / width
- `stride`: int or `(stride_h, stride_w)`
- `padding`: int or `(pad_h, pad_w)`

Output
- `[Co, R, S, Ci]`

Logical GEMM
- grad_W[Co, R*S*Ci] = grad_out[M, Co]^T @ im2col(input)[M, R*S*Ci]
- M = N * out_h * out_w
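As a sanity check of this formula, here is a minimal plain-NumPy sketch (illustrative only, not the kernel; the helper name is made up, and stride 1 / no padding are assumed for brevity):

```python
import numpy as np

def im2col_nhwc(x, R, S):
    """[M, R*S*Ci] patch matrix from an NHWC input, stride=1, pad=0 (illustrative)."""
    N, H, W, Ci = x.shape
    out_h, out_w = H - R + 1, W - S + 1
    cols = np.empty((N, out_h, out_w, R, S, Ci), dtype=x.dtype)
    for r in range(R):
        for s in range(S):
            cols[:, :, :, r, s, :] = x[:, r : r + out_h, s : s + out_w, :]
    return cols.reshape(N * out_h * out_w, R * S * Ci)

rng = np.random.default_rng(1)
N, H, W, Ci, Co, R, S = 2, 5, 5, 3, 4, 3, 3
x = rng.standard_normal((N, H, W, Ci))                   # forward input
g = rng.standard_normal((N, H - R + 1, W - S + 1, Co))   # grad_output

# wgrad as the logical GEMM: grad_out^T @ im2col(input)
M = N * (H - R + 1) * (W - S + 1)
grad_w = g.reshape(M, Co).T @ im2col_nhwc(x, R, S)       # [Co, R*S*Ci]
grad_w = grad_w.reshape(Co, R, S, Ci)                    # back to the OHWI weight layout
print(grad_w.shape)                                      # (4, 3, 3, 3)
```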
Input Gradient (`conv2d_dgrad` / `conv2d_dgrad_fixed`)

- `conv2d_dgrad` is the kernel with the config autotune
- `conv2d_dgrad_fixed` is for CI with a fixed config

Inputs
- `grad_output_nhwc`: output gradient tensor with shape `[N, out_h, out_w, Co]`
- `weight_nhwc`: forward weight tensor with shape `[Co, R, S, Ci]`
- `H_in`, `W_in`: spatial shape of the original input tensor
- `stride`: int or `(stride_h, stride_w)`
- `padding`: int or `(pad_h, pad_w)`

Output
- `[N, H_in, W_in, Ci]`

Logical GEMM
- Convolution of `grad_output` using rotated weights

Block Diagrams
End-to-End Forward / Backward View
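The diagram is an image attachment; as a stand-in, here is a tiny NumPy sketch of the same end-to-end view (stride 1, no padding, illustrative only; the dgrad step uses the rotated-weights formulation, and the naive `corr` helper is made up for this sketch):

```python
import numpy as np

rng = np.random.default_rng(2)
N, H, W, Ci, Co, R, S = 1, 6, 6, 2, 3, 3, 3
x = rng.standard_normal((N, H, W, Ci))        # forward input, NHWC
w = rng.standard_normal((Co, R, S, Ci))       # weights, OHWI
out_h, out_w = H - R + 1, W - S + 1           # stride=1, pad=0 for brevity

def corr(inp, ker):
    """Naive stride-1 cross-correlation: NHWC input x OHWI kernel -> NHWC output."""
    n, h, ww, _ = inp.shape
    o, r, s, _ = ker.shape
    out = np.zeros((n, h - r + 1, ww - s + 1, o))
    for i in range(h - r + 1):
        for j in range(ww - s + 1):
            out[:, i, j, :] = np.einsum(
                'nrsc,orsc->no', inp[:, i:i + r, j:j + s, :], ker)
    return out

# fprop: y = x (*) W
y = corr(x, w)                                # [N, out_h, out_w, Co]
g = rng.standard_normal(y.shape)              # upstream gradient dL/dy

# wgrad: grad_W[o,r,s,c] = sum_{n,i,j} g[n,i,j,o] * x[n,i+r,j+s,c]
grad_w = np.zeros_like(w)
for r in range(R):
    for s in range(S):
        grad_w[:, r, s, :] = np.einsum(
            'nijo,nijc->oc', g, x[:, r:r + out_h, s:s + out_w, :])

# dgrad: full correlation of g with 180-degree-rotated, Co<->Ci-swapped weights
gp = np.pad(g, ((0, 0), (R - 1, R - 1), (S - 1, S - 1), (0, 0)))
w_rot = w[:, ::-1, ::-1, :].transpose(3, 1, 2, 0)   # [Ci, R, S, Co]
grad_x = corr(gp, w_rot)                            # [N, H, W, Ci]
print(y.shape, grad_w.shape, grad_x.shape)
```

Both gradients can be verified against central finite differences of the loss `(y * g).sum()`.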
Performance
conv2d_fprop
conv2d_wgrad
conv2d_dgrad
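The benchmark figures are image attachments. For interpreting such numbers, the logical GEMM sizes (and hence FLOP counts) of the three kernels follow directly from the shape mappings above; a small helper sketch (function name and example shapes are mine, not from the PR):

```python
def conv_gemm_dims(N, H, W, Ci, Co, R, S, stride=1, pad=0):
    """Logical GEMM sizes (M, N, K) per kernel for an NHWC conv (illustrative helper)."""
    out_h = (H + 2 * pad - R) // stride + 1
    out_w = (W + 2 * pad - S) // stride + 1
    M_out = N * out_h * out_w     # one GEMM row per output pixel (fprop / wgrad)
    M_in = N * H * W              # one GEMM row per input pixel (dgrad)
    return {
        "fprop": (M_out, Co, R * S * Ci),   # im2col(input) @ W^T
        "wgrad": (Co, R * S * Ci, M_out),   # grad_out^T @ im2col(input)
        "dgrad": (M_in, Ci, R * S * Co),    # im2col(grad_out) @ rotated W
    }

dims = conv_gemm_dims(N=8, H=56, W=56, Ci=64, Co=64, R=3, S=3, stride=1, pad=1)
for name, (m, n, k) in dims.items():
    print(f"{name}: M={m} N={n} K={k}  ~{2 * m * n * k / 1e9:.1f} GFLOPs")
```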
Acknowledgement
This PR is a collaboration with @shangz-ai and gpt-5.4-high.
@shangz-ai helped a lot with forming the ideas and provided constructive feedback on functionality and performance.
gpt-5.4-high helped a lot with the engineering work.