
[Perf][GEMV] Fix coalescing, add cp.async pipeline, expand test coverage #244

Merged

superAngGao merged 1 commit into tile-ai:main from superAngGao:gemv_perf
Feb 28, 2026

Conversation


@superAngGao superAngGao commented Feb 28, 2026

Summary

  • Fix warp-level memory coalescing for GEMV B matrix access (O1)
  • Add cp.async pipeline via T.Pipelined for B tile loads (O3)
  • Remove invalid num_stages=0 sentinel (Fix A); eliminate Python closure overhead in forward() (Fix B)
  • Expand test suite with 4 LLM production-scale shapes (Llama-3 70B, DeepSeek-V3 MoE)
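As a rough illustration of the O1 fix (plain Python index math, not the actual TileLang kernel): with the original thread binding, consecutive lanes in a warp step through B with a stride of k elements, so a warp issues 32 separate transactions; swapping the thread dims to (reduce_threads, block_n) makes lanes 0..31 read 32 consecutive elements of a row, which coalesces into a single transaction.

```python
# Illustrative sketch of the O1 coalescing fix (hypothetical index
# math in plain Python, not the actual TileLang kernel).
K = 1024  # reduction dimension; B is row-major, B[n][k]

def addr_uncoalesced(lane, block_n=8, reduce_threads=32):
    # Original binding: lane id advances along the n dimension first,
    # so consecutive lanes touch elements K apart -> 32 transactions.
    n_idx = lane % block_n
    k_idx = lane // block_n
    return n_idx * K + k_idx

def addr_coalesced(lane, block_n=8, reduce_threads=32):
    # Fixed binding: lane id advances along the k dimension first, so
    # lanes 0..31 read 32 consecutive elements -> one transaction.
    k_idx = lane % reduce_threads
    n_idx = lane // reduce_threads
    return n_idx * K + k_idx

strides_bad = [addr_uncoalesced(i + 1) - addr_uncoalesced(i) for i in range(7)]
strides_good = [addr_coalesced(i + 1) - addr_coalesced(i) for i in range(31)]
print(strides_bad[0], strides_good[0])  # 1024 1
```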

Closes #232

Benchmark Results (H200 SXM, 4.8 TB/s peak HBM)

fp16

| Shape (n, k) | Original TileOps | H200 util | Optimized TileOps | H200 util | Torch baseline | H200 util | vs original | vs baseline |
|---|---|---|---|---|---|---|---|---|
| (1024, 1024) | 0.27 TB/s | 6% | 0.54 TB/s | 11% | 0.26 TB/s | 5% | +100% | +108% |
| (7168, 16384) | 2.24 TB/s | 47% | 3.47 TB/s | 72% | 3.34 TB/s | 70% | +55% | +4% |
| (18432, 7168) | 2.06 TB/s | 43% | 3.80 TB/s | 79% | 3.34 TB/s | 70% | +84% | +14% |
| (28672, 8192) | 2.35 TB/s | 49% | 4.02 TB/s | 84% | 3.80 TB/s | 79% | +71% | +6% |
| (57344, 7168) | 2.62 TB/s | 55% | 4.26 TB/s | 89% | 3.92 TB/s | 82% | +63% | +9% |

bf16

| Shape (n, k) | Original TileOps | H200 util | Optimized TileOps | H200 util | Torch baseline | H200 util | vs original | vs baseline |
|---|---|---|---|---|---|---|---|---|
| (7168, 16384) | 2.53 TB/s | 53% | 3.61 TB/s | 75% | 3.35 TB/s | 70% | +43% | +8% |
| (18432, 7168) | 1.76 TB/s | 37% | 3.75 TB/s | 78% | 3.34 TB/s | 70% | +113% | +12% |
| (28672, 8192) | 2.55 TB/s | 53% | 4.08 TB/s | 85% | 3.81 TB/s | 79% | +60% | +7% |
| (57344, 7168) | 2.63 TB/s | 55% | 4.30 TB/s | 90% | 3.93 TB/s | 82% | +63% | +9% |
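The H200 utilization columns in both tables are simply achieved bandwidth over the stated 4.8 TB/s peak, rounded to whole percent. A quick sanity check of that arithmetic against a few rows:

```python
# Sanity check on the H200 utilization columns: util = achieved / peak,
# with 4.8 TB/s peak HBM bandwidth as stated in the heading above.
PEAK_TBS = 4.8

def h200_util(achieved_tbs: float) -> int:
    """Utilization as a whole percentage of peak bandwidth."""
    return round(achieved_tbs / PEAK_TBS * 100)

# Spot-check rows from the tables above.
print(h200_util(4.26))  # fp16 (57344, 7168) optimized -> 89
print(h200_util(2.62))  # fp16 (57344, 7168) original  -> 55
print(h200_util(3.92))  # fp16 torch baseline          -> 82
```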

Autotune best configs (SM90 / H200)

| Shape (n, k) | dtype | Best config |
|---|---|---|
| (7168, 16384) | fp16/bf16 | block_n=1, reduce_threads=32, num_stages=3 |
| (18432, 7168) | fp16 | block_n=1, reduce_threads=32, num_stages=3 |
| (18432, 7168) | bf16 | block_n=2, reduce_threads=32, num_stages=3 |
| (28672, 8192) | fp16/bf16 | block_n=2, reduce_threads=32, num_stages=3 |
| (57344, 7168) | fp16/bf16 | block_n=2, reduce_threads=32, num_stages=3 |
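The search space behind these winners (block_n in [1, 2, 4, 8, 16], num_stages in [1, 2, 3], reduce_threads fixed at 32, per the commit message) is small enough to sweep exhaustively. A sketch of the enumeration, with an illustrative config layout rather than the actual TileOps API:

```python
# Enumerate the GEMV autotune space described in the commit message.
# The dict layout is hypothetical; TileOps' real config type may differ.
from itertools import product

def gemv_autotune_configs():
    return [
        {"block_n": bn, "reduce_threads": 32, "num_stages": ns}
        for bn, ns in product([1, 2, 4, 8, 16], [1, 2, 3])
    ]

configs = gemv_autotune_configs()
print(len(configs))  # 15 candidate configs per (shape, dtype)
```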

Test plan

  • `python -m pytest tests/ops/test_gemv.py -vvs` — all 10 cases pass
  • `python -m pytest benchmarks/ops/bench_gemv.py -vvs` — benchmark results match the tables above

Authors

Co-authored-by: @RMLYC (https://github.com/RMLYC)

🤖 Generated with Claude Code

- O1: Swap thread dims to (reduce_threads, block_n) for coalesced B access
- O3: Pipeline B tile loads via T.Pipelined + cp.async (disable_tma=True);
  autotune space: block_n in [1,2,4,8,16], num_stages in [1,2,3]
- Fix A: Remove num_stages=0 sentinel; T.Pipelined requires num_stages>=1
- Fix B: Call self.kernel directly in forward() to eliminate Python closure overhead
- Default config (SM90): block_n=8, reduce_threads=32, num_stages=2
- Tests: add 4 LLM production-scale shapes (Llama-3 70B, DeepSeek-V3 MoE)
- Skills: add tune, tune-multiplication, kernel-debug skill docs
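Fix B is a host-side change: the dispatch pattern looks roughly like the sketch below, where the class and attribute names are illustrative, not the actual TileOps code. The point is that building an inner function per call adds measurable Python overhead when the kernel itself runs in microseconds.

```python
# Sketch of Fix B: bind the JIT-compiled kernel once and call it
# directly in forward(). All names here are hypothetical.
class GemvOpSketch:
    def __init__(self, kernel):
        self.kernel = kernel  # JIT-compiled callable, built once

    def forward_with_closure(self, a, b):
        # Before: a closure is allocated on every call, adding
        # per-launch Python overhead.
        def run():
            return self.kernel(a, b)
        return run()

    def forward(self, a, b):
        # After: call the compiled kernel directly.
        return self.kernel(a, b)

# Stand-in "kernel": a plain dot product.
op = GemvOpSketch(lambda a, b: sum(x * y for x, y in zip(a, b)))
print(op.forward([1, 2, 3], [4, 5, 6]))  # 32
```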

Co-authored-by: RMLYC <RMLYC@users.noreply.github.com>
Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
@gemini-code-assist
Contributor

Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request delivers substantial performance enhancements for GEMV operations by addressing critical memory access patterns and leveraging advanced GPU features like asynchronous data transfers. It also refines the kernel's internal logic and Python integration for efficiency. Furthermore, the changes significantly broaden the testing scope to cover real-world LLM scenarios and introduce detailed documentation to guide future kernel development and optimization efforts.

Highlights

  • Performance Optimization for GEMV: Implemented significant performance improvements for the General Matrix-Vector (GEMV) multiplication kernel, achieving up to 113% speedup over original TileOps and up to 14% over PyTorch baseline for various shapes on H200 SXM GPUs.
  • Memory Coalescing Fix: Corrected warp-level memory coalescing for the B matrix access in GEMV, by adjusting thread binding to ensure consecutive warp threads access consecutive memory addresses, which was a major bottleneck (P1 in tuning guide).
  • Asynchronous Pipelined Loads: Integrated a cp.async pipeline via T.Pipelined for B tile loads, utilizing shared memory to hide HBM3e latency and improve memory utilization (O3 optimization).
  • Refactored Pipeline Sentinel and Python Overhead: Eliminated the semantically invalid num_stages=0 sentinel for T.Pipelined and removed Python closure overhead in the forward() method by directly calling the JIT-compiled kernel, improving wall-clock performance (P2 and P5 in tuning guide).
  • Expanded Test Coverage: Extended the GEMV test suite to include large-scale LLM production shapes (Llama-3 70B, DeepSeek-V3 MoE) to validate performance and correctness on realistic workloads.
  • New Documentation for Kernel Debugging and Tuning: Added comprehensive guides for TileLang GPU kernel debugging, matrix multiplication kernel tuning, and general GPU kernel profiling methodology, documenting lessons learned and best practices.
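The cp.async pipelining described above can be pictured with a toy software pipeline in plain Python. The real kernel uses T.Pipelined to prefetch B tiles into shared memory while the previous tile is being reduced; the buffering logic below is only an analogy under that reading.

```python
# Toy model of a num_stages-deep load/compute pipeline for GEMV:
# while tile i is being reduced, the next tiles are "prefetched",
# mimicking how cp.async hides HBM latency behind compute.
from collections import deque

def pipelined_dot(a, b, tile=4, num_stages=2):
    assert num_stages >= 1  # Fix A: num_stages=0 is not a valid pipeline
    tiles = [(a[i:i + tile], b[i:i + tile]) for i in range(0, len(a), tile)]
    inflight = deque()  # models tiles resident in shared-memory buffers
    acc, nxt = 0, 0
    # Prologue: fill the pipeline with the first num_stages loads.
    while nxt < len(tiles) and len(inflight) < num_stages:
        inflight.append(tiles[nxt]); nxt += 1
    # Steady state: compute the oldest resident tile, issue the next load.
    while inflight:
        ta, tb = inflight.popleft()
        acc += sum(x * y for x, y in zip(ta, tb))
        if nxt < len(tiles):
            inflight.append(tiles[nxt]); nxt += 1
    return acc

a, b = list(range(8)), [1] * 8
print(pipelined_dot(a, b))  # 28
```

Regardless of tile size or pipeline depth, the result must match the unpipelined dot product; only the overlap of loads and compute changes.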


Changelog
  • .claude/skills/kernel-debug/skill.md
    • Added a new skill document detailing debugging patterns and pitfalls for TileLang GPU kernels, including correctness triage, JIT cache issues, and pipeline considerations.
  • .claude/skills/tune-multiplication/skill.md
    • Added a new skill document outlining optimization patterns and lessons learned for matrix-related operators in TileOps, with a detailed case study on GEMV tuning.
  • .claude/skills/tune/skill.md
    • Added a new skill document describing the GPU kernel profiling and performance tuning methodology for TileOps, focusing on benchmarking, nsys analysis, and bandwidth calculation.
  • benchmarks/ops/bench_gemv.py
    • Updated imports to include torch, pytest, and GemvOp.
    • Added test_gemv_bench function to run benchmarks for GEMV operations, profiling both TileOps and baseline implementations.
  • tests/ops/test_gemv.py
    • Expanded the GemvFixture with new float16 and bfloat16 test cases for LLM production-scale shapes (Llama-3 70B, DeepSeek-V3 MoE).
  • tileops/kernels/gemv/gemv.py
    • Removed the itertools import.
    • Modified the _gemv_func signature to include a num_stages parameter.
    • Refactored _gemv_main to use T.Pipelined for B matrix loads with cp.async and shared memory (b_shared), and adjusted thread binding for improved memory coalescing.
    • Updated the _gemv_wrapped_kernel signature to accept the new num_stages parameter.
    • Adjusted the GemvOp's default_config and autotune_configs to incorporate num_stages and reflect the new tuning strategies for SM90.
    • Modified the GemvOp.forward method to directly call the JIT-compiled kernel, bypassing Python closure overhead.

@superAngGao superAngGao requested a review from zhen8838 February 28, 2026 07:51
@superAngGao superAngGao added the enhancement (Improvement to existing functionality), documentation (Documentation improvements), and all ai powered (Issue or Pull request all finished by AI) labels Feb 28, 2026
@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request delivers substantial performance enhancements for the GEMV kernel by addressing memory coalescing issues, introducing a cp.async pipeline, and optimizing host-side dispatch overhead. The changes are well-supported by an expanded test suite with relevant large-scale shapes and a more effective autotuning configuration. A standout feature of this PR is the excellent new documentation that captures the debugging and tuning journey, which will be highly valuable for the team. My review is positive, with only one minor suggestion to correct a unit typo (GB/s instead of TB/s) in some of the new documentation's benchmark tables to ensure clarity.

zhen8838 previously approved these changes Feb 28, 2026
@superAngGao superAngGao merged commit fd3af14 into tile-ai:main Feb 28, 2026
25 of 30 checks passed

Labels

enhancement (Improvement to existing functionality), documentation (Documentation improvements), all ai powered (Issue or Pull request all finished by AI)


Development

Successfully merging this pull request may close these issues.

[PERF][GEMV] fix uncoalesced memory access and add H200-specific optimizations

3 participants