
[Feat][Conv2d] Add conv2d op with 1x1 optimization and benchmarks #464

Draft
RMLYC wants to merge 2 commits into tile-ai:main from RMLYC:feat/conv2d-op-434

Conversation

RMLYC (Collaborator) commented Mar 12, 2026

Closes #434

Summary

  • add the Conv2d op with dedicated 1x1 handling, benchmark coverage, and correctness tests
  • route 1x1, stride=1, padding=0 through a GemmKernel fast path while keeping non-1x1 on the existing im2col + gemm path
  • add a dedicated conv2d benchmark workdir and benchmark ordering/reporting updates for performance analysis
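The three-way routing described above can be sketched as follows. This is an illustrative reconstruction with hypothetical function and return names; the actual `Conv2dOp` dispatch logic in `tileops/ops/conv2d.py` may differ in detail:

```python
def select_kernel(kernel_size, stride, padding):
    """Illustrative dispatch: 1x1/s1/p0 -> GEMM fast path, other 1x1 -> pointwise,
    everything else -> im2col + GEMM."""
    kh, kw = kernel_size
    if (kh, kw) == (1, 1):
        if stride == (1, 1) and padding == (0, 0):
            # A 1x1 conv with stride 1 and no padding is exactly a matmul:
            # (N*H*W, C_in) @ (C_in, C_out), so GemmKernel applies directly.
            return "GemmKernel"
        # Strided or padded 1x1 still avoids im2col via a pointwise kernel.
        return "PointwiseConvKernel"
    # General case: materialize patches, then GEMM.
    return "Conv2dIm2ColKernel"
```

The key observation behind the fast path is that a pointwise convolution with unit stride and zero padding touches every pixel exactly once, so the spatial dimensions can be folded into the GEMM M dimension with no data duplication.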

Test plan

  • pre-commit passed during commit creation
  • python -m compileall tileops/ops/conv2d.py tileops/kernels/conv2d/pointwise.py tests/ops/test_conv2d.py tileops/kernels/conv2d/__init__.py tileops/kernels/__init__.py
  • PYTHONPATH="$PWD" python -m pytest -v tests/ops/test_conv2d.py -k 'dispatches_to_pointwise_kernel or dispatches_to_gemm_kernel or 1x1 and test_conv2d'
  • BENCHMARK_REPORT_PATH=workdir_conv2d/profile_run.log python -m pytest -q benchmarks/ops/bench_conv2d.py

Benchmark

Measured on NVIDIA H200, torch 2.9.0+cu128, CUDA 12.8.

Case | TileOps | Baseline | Summary
--- | --- | --- | ---
32x28x28 -> 64, k=1x1, s=2, p=1, bias, fp16 | 0.03 ms, 0.03 TF, 0.00 TB/s | 0.05 ms, 0.02 TF, 0.00 TB/s | TileOps faster
64x56x56 -> 256, k=1x1, s=1, p=0, fp16 | 0.02 ms, 4.72 TF, 0.09 TB/s | 0.00 ms, 24.10 TF, 0.48 TB/s | Baseline still better
64x56x56 -> 256, k=1x1, s=2, p=0, fp16 | 0.04 ms, 0.65 TF, 0.02 TB/s | 0.02 ms, 1.24 TF, 0.04 TB/s | TileOps slower
256x56x56 -> 512, k=1x1, s=1, p=0, bias, fp16 | 0.04 ms, 18.57 TF, 0.11 TB/s | 0.04 ms, 19.52 TF, 0.12 TB/s | Nearly tied
128x112x112 -> 512, k=1x1, s=1, p=0, bf16 | 0.08 ms, 21.38 TF, 0.21 TB/s | 0.03 ms, 48.75 TF, 0.48 TB/s | TileOps slower
64x56x56 -> 64, k=3x3, s=1, p=1, fp16 | 0.02 ms, 9.43 TF, 0.04 TB/s | 0.04 ms, 5.38 TF, 0.02 TB/s | TileOps faster
512x56x56 -> 512, k=3x3, s=1, p=1, bf16 | 0.12 ms, 119.15 TF, 0.09 TB/s | 0.23 ms, 64.17 TF, 0.05 TB/s | TileOps faster
128x112x112 -> 256, k=3x3, s=2, p=1, bf16 | 0.04 ms, 43.01 TF, 0.13 TB/s | 0.08 ms, 24.59 TF, 0.07 TB/s | TileOps faster
64x224x224 -> 128, k=5x5, s=2, p=2, fp16 | 0.09 ms, 60.23 TF, 0.12 TB/s | 0.16 ms, 32.21 TF, 0.06 TB/s | TileOps faster
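The TF column appears to follow the standard dense-convolution FLOP count, matching the benchmark's `calculate_flops` shown in the review below; a minimal sketch (parameter names are illustrative, and batch sizes for the table rows are not labeled, so the exact table values are not reproduced here):

```python
def conv2d_flops(n, c_in, h, w, c_out, kh, kw, sh=1, sw=1, ph=0, pw=0):
    """2 * N * C_out * H_out * W_out * C_in * Kh * Kw
    (each output element accumulates C_in*Kh*Kw multiply-adds, counted as 2 FLOPs)."""
    out_h = (h + 2 * ph - kh) // sh + 1
    out_w = (w + 2 * pw - kw) // sw + 1
    return 2.0 * n * c_out * out_h * out_w * c_in * kh * kw
```

For example, the 64x56x56 -> 64, k=3x3, s=1, p=1 case keeps the 56x56 spatial size and costs about 0.23 GFLOP per image, which is why its absolute TF numbers stay low despite a 2x speedup.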

Additional context

  • benchmark artifacts and NCU notes are stored under workdir_conv2d/
  • the sandbox session later hit CUDA initialization error 304 when rerunning GPU validation; the benchmark table above reflects the last successful CUDA-visible run
  • scripts/validate.sh is not present in this checkout, so the repository PR validation script could not be run here

@github-actions github-actions bot added the feature New feature or new operator label Mar 12, 2026
gemini-code-assist bot commented

Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly enhances the tileops library by introducing a new Conv2d operator. The primary goal is to improve performance for convolutional neural networks, particularly by optimizing 1x1 convolutions through a dedicated GEMM-based fast path. For more general convolution types, it leverages an im2col approach. The changes include new kernel implementations, comprehensive testing, and a robust benchmarking framework to validate performance gains and ensure numerical correctness.

Highlights

  • New Conv2d Operation: Introduced a Conv2dOp with specialized handling for 1x1 convolutions.
  • Optimized 1x1 Path: Implemented a fast path using GemmKernel for 1x1 convolutions with stride 1 and no padding.
  • General Conv2d Path: Non-1x1 convolutions are routed through the existing im2col + gemm approach.
  • Comprehensive Benchmarking: Added dedicated benchmarks for Conv2d to analyze performance and compare against baselines.
  • Correctness Testing: Included extensive correctness tests to validate the new Conv2d operation and its kernel dispatch logic.


Changelog
  • benchmarks/conftest.py
    • Imported os and pathlib.Path.
    • Modified pytest_sessionfinish to dynamically set the benchmark report path using an environment variable, creating parent directories if needed.
  • benchmarks/ops/bench_conv2d.py
    • Added a new file for Conv2d benchmarks.
    • Defined Conv2dBenchmarkFixture with various convolution parameters (1x1, 3x3, 5x5, different strides, paddings, dtypes).
    • Implemented Conv2dBenchmark to calculate FLOPs and memory usage for Conv2d operations.
    • Included test_conv2d_bench to profile both the TileOps Conv2dOp and a PyTorch baseline.
  • tests/ops/test_conv2d.py
    • Added a new file for Conv2d correctness tests.
    • Defined Conv2dFixture with various convolution parameters for testing.
    • Implemented Conv2dTest to generate inputs and define a reference PyTorch conv2d program.
    • Included test_conv2d to check numerical correctness of Conv2dOp against the reference.
    • Added specific tests (test_conv2d_1x1_dispatches_to_pointwise_kernel, test_conv2d_1x1_s1_p0_dispatches_to_gemm_kernel, test_conv2d_non_1x1_dispatches_to_im2col_kernel) to verify kernel dispatch logic.
  • tileops/kernels/__init__.py
    • Imported Conv2dIm2ColKernel and PointwiseConvKernel from tileops.kernels.conv2d.
    • Added these new kernels to the __all__ export list.
  • tileops/kernels/conv2d/__init__.py
    • Added a new directory and __init__.py file.
    • Imported Conv2dIm2ColKernel and PointwiseConvKernel to expose them.
  • tileops/kernels/conv2d/im2col.py
    • Added a new file implementing Conv2dIm2ColKernel.
    • Contains _conv2d_im2col_kernel using tilelang.jit for the im2col transformation.
    • Includes _conv2d_im2col_wrapped_kernel as a custom PyTorch operator and its fake registration.
    • Defines Conv2dIm2ColKernel class with configuration and autotuning options.
  • tileops/kernels/conv2d/pointwise.py
    • Added a new file implementing PointwiseConvKernel.
    • Contains _pointwise_conv_kernel using tilelang.jit for optimized 1x1 pointwise convolution.
    • Includes _pointwise_conv_wrapped_kernel as a custom PyTorch operator and its fake registration.
    • Defines PointwiseConvKernel class with configuration and autotuning options.
  • tileops/ops/__init__.py
    • Imported Conv2dOp from tileops.ops.conv2d.
    • Added Conv2dOp to the __all__ export list.
  • tileops/ops/conv2d.py
    • Added a new file defining the Conv2dOp class.
    • Includes helper function _pair for kernel, stride, and padding sizes.
    • Implements __init__ to determine convolution parameters and dispatch to appropriate kernels (GemmKernel, PointwiseConvKernel, Conv2dIm2ColKernel).
    • Defines default_kernel_map to associate kernel names with their classes.
    • Implements the forward method, handling input validation, data reshaping, and calling the selected kernel.
Activity
  • The author, RMLYC, has implemented a new Conv2d operator with specific optimizations for 1x1 convolutions.
  • The pull request includes new benchmark files (benchmarks/ops/bench_conv2d.py) and test files (tests/ops/test_conv2d.py) to ensure the functionality and performance of the new operator.
  • Benchmark results on NVIDIA H200 are provided, demonstrating performance comparisons between TileOps and a baseline for various convolution configurations.
  • The author noted a CUDA initialization error during re-validation but confirmed the provided benchmark table reflects successful runs.
  • The PR addresses and closes issue [New Op] Conv2d #434.

gemini-code-assist bot left a comment

Code Review

This pull request introduces a new Conv2dOp with an optimized path for 1x1 convolutions, along with corresponding correctness tests and performance benchmarks. The implementation correctly dispatches to a faster GEMM-based kernel for 1x1 convolutions with stride 1 and no padding, while using a dedicated implicit-GEMM kernel for other 1x1 cases and an im2col approach for general convolutions. The changes are well-structured and include thorough testing. My feedback includes a suggestion to refactor some duplicated code in the new benchmark file to improve maintainability.

Comment on lines +72 to +101
def calculate_flops(self) -> Optional[float]:
    test = self.test
    kernel_h, kernel_w = test.kernel_size if isinstance(test.kernel_size, tuple) else (
        test.kernel_size, test.kernel_size)
    stride_h, stride_w = test.stride if isinstance(test.stride, tuple) else (
        test.stride, test.stride)
    pad_h, pad_w = test.padding if isinstance(test.padding, tuple) else (
        test.padding, test.padding)
    out_h = (test.h + 2 * pad_h - kernel_h) // stride_h + 1
    out_w = (test.w + 2 * pad_w - kernel_w) // stride_w + 1
    return 2.0 * test.n * test.c_out * out_h * out_w * test.c_in * kernel_h * kernel_w

def calculate_memory(self) -> Optional[float]:
    test = self.test
    kernel_h, kernel_w = test.kernel_size if isinstance(test.kernel_size, tuple) else (
        test.kernel_size, test.kernel_size)
    stride_h, stride_w = test.stride if isinstance(test.stride, tuple) else (
        test.stride, test.stride)
    pad_h, pad_w = test.padding if isinstance(test.padding, tuple) else (
        test.padding, test.padding)
    out_h = (test.h + 2 * pad_h - kernel_h) // stride_h + 1
    out_w = (test.w + 2 * pad_w - kernel_w) // stride_w + 1
    bias_elems = test.c_out if test.bias else 0
    total_elems = (
        test.n * test.c_in * test.h * test.w
        + test.c_out * test.c_in * kernel_h * kernel_w
        + bias_elems
        + test.n * test.c_out * out_h * out_w
    )
    return total_elems * test.dtype.itemsize
Severity: medium

There's duplicated logic for calculating convolution parameters like kernel_h, kernel_w, out_h, and out_w in both calculate_flops and calculate_memory. This can be refactored into a helper method to improve maintainability and reduce redundancy.

    def _get_derived_params(self) -> Tuple[int, int, int, int]:
        test = self.test
        kernel_h, kernel_w = test.kernel_size if isinstance(test.kernel_size, tuple) else (
            test.kernel_size, test.kernel_size)
        stride_h, stride_w = test.stride if isinstance(test.stride, tuple) else (
            test.stride, test.stride)
        pad_h, pad_w = test.padding if isinstance(test.padding, tuple) else (
            test.padding, test.padding)
        out_h = (test.h + 2 * pad_h - kernel_h) // stride_h + 1
        out_w = (test.w + 2 * pad_w - kernel_w) // stride_w + 1
        return kernel_h, kernel_w, out_h, out_w

    def calculate_flops(self) -> Optional[float]:
        test = self.test
        kernel_h, kernel_w, out_h, out_w = self._get_derived_params()
        return 2.0 * test.n * test.c_out * out_h * out_w * test.c_in * kernel_h * kernel_w

    def calculate_memory(self) -> Optional[float]:
        test = self.test
        kernel_h, kernel_w, out_h, out_w = self._get_derived_params()
        bias_elems = test.c_out if test.bias else 0
        total_elems = (
            test.n * test.c_in * test.h * test.w
            + test.c_out * test.c_in * kernel_h * kernel_w
            + bias_elems
            + test.n * test.c_out * out_h * out_w
        )
        return total_elems * test.dtype.itemsize

@RMLYC RMLYC added the all-ai-powered (Produced entirely by automated contributors) and perf (Performance improvements) labels Mar 12, 2026