
vLLM Kernel Profiler

A kernel-level profiling and microbenchmarking suite for diagnosing decode latency, KV-cache memory pressure, and GPU bottlenecks in vLLM-style LLM inference.


Rishi More

Quickstart · Results · Nsight Workflow · Architecture


TL;DR

A kernel-level profiling suite that pinpoints where decode latency comes from in vLLM-style inference and when the bottleneck shifts between compute, memory bandwidth, KV-cache, sampling, and scheduling. Custom Triton kernels (RMSNorm, fused residual+RMSNorm, sampling, decode attention) reach 14–49× speedups over PyTorch baselines, and a bottleneck classifier labels each workload as memory_bound, compute_bound, or scheduler_bound. All figures below are measured on an RTX 4070 Laptop GPU with Qwen2.5-0.5B-Instruct — local results, not universal claims.


Overview

Modern LLM serving has two dominant runtime phases:

  • Prefill — processes prompt tokens in bulk; often compute-heavy.
  • Decode — generates one token at a time; often memory-bound due to repeated KV-cache reads.
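To make the memory-bound tendency of decode concrete, here is a rough sketch of how many KV-cache bytes one decode step has to read. The model dimensions are illustrative placeholders, not the configuration of any particular model, and the estimate ignores weights, activations, and paging overhead.

```python
def kv_cache_bytes_per_decode_step(
    batch_size: int,
    context_len: int,
    num_layers: int,
    num_kv_heads: int,
    head_dim: int,
    bytes_per_element: int = 2,  # assumes fp16 / bf16 cache entries
) -> int:
    """Bytes of cached keys and values read to produce one new token per sequence."""
    # Each layer keeps one key vector and one value vector per cached token per KV head.
    per_token = 2 * num_layers * num_kv_heads * head_dim * bytes_per_element
    return batch_size * context_len * per_token


if __name__ == "__main__":
    # Hypothetical model shape, chosen only for illustration.
    for context_len in (512, 2048, 8192):
        n = kv_cache_bytes_per_decode_step(
            batch_size=4, context_len=context_len,
            num_layers=24, num_kv_heads=2, head_dim=64,
        )
        print(f"context={context_len:5d}  KV read per step = {n / 2**20:.1f} MiB")
```

The read volume grows linearly with context length and batch size while the arithmetic per token stays small, which is why decode tends to be limited by memory bandwidth rather than compute.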

This project is built to answer a concrete systems question:

During LLM inference, where does latency come from, and when does the bottleneck shift between compute, memory bandwidth, KV-cache, sampling, and scheduling?

It pairs an end-to-end inference benchmarking harness (Hugging Face + optional vLLM backends) with kernel-level microbenchmarks, Triton autotuning, Nsight integration, and an automated bottleneck classifier that turns raw timings into a readable diagnosis.


Core Capabilities

  • Workload benchmarking: short_chat, long_prefill, decode_heavy, mixed_length, shared_prefix
  • Backends — Hugging Face (transformers + PyTorch) fallback, plus an optional vLLM backend with a graceful install hint
  • Triton kernels — RMSNorm, fused residual + RMSNorm, sampling (temperature + top-k), and a simplified decode-attention kernel for KV-cache access pressure, each benchmarked against a PyTorch baseline
  • Autotuning — Triton launch-parameter sweep
  • Profiling — Nsight diagnostics + launch wrappers (ncu, nsys) with tool auto-detection
  • Analysis — bottleneck classifier (memory_bound, compute_bound, scheduler_bound, …)
  • Reporting — Markdown + HTML report generation with plot artifacts
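For readers unfamiliar with the operations behind the kernels listed above, the following pure-Python sketch shows the math for RMSNorm, fused residual + RMSNorm, and temperature + top-k sampling. It is a reference for the semantics only: it is not the repository's Triton or PyTorch code, and the function names, argument order, and default values are assumptions made for illustration.

```python
import math
import random


def rms_norm(x, weight, eps=1e-6):
    """y_i = x_i / sqrt(mean(x^2) + eps) * weight_i for one hidden vector."""
    mean_sq = sum(v * v for v in x) / len(x)
    inv_rms = 1.0 / math.sqrt(mean_sq + eps)
    return [v * inv_rms * w for v, w in zip(x, weight)]


def fused_residual_rms_norm(x, residual, weight, eps=1e-6):
    """Add the residual first, then normalize; returns (normed, updated_residual)."""
    summed = [a + b for a, b in zip(x, residual)]
    return rms_norm(summed, weight, eps), summed


def top_k_sample(logits, temperature=1.0, top_k=50, rng=random):
    """Sample a token id from the top_k largest logits after temperature scaling."""
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    if top_k < 1:
        raise ValueError("top_k must be at least 1")
    k = min(top_k, len(logits))
    # Keep the k largest logits (the stable sort breaks ties by lower index).
    candidates = sorted(range(len(logits)), key=lambda i: -logits[i])[:k]
    scaled = [logits[i] / temperature for i in candidates]
    # Numerically stable softmax weights over the surviving candidates.
    m = max(scaled)
    weights = [math.exp(s - m) for s in scaled]
    return rng.choices(candidates, weights=weights, k=1)[0]
```

Fusing the residual add with the normalization matters on a GPU because the summed vector is read and written once instead of twice; the Python version only shows what the fused result should equal.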

Architecture

End-to-end flow from a workload spec to a diagnosis:

   workload spec                backend run                  microkernels
 ┌───────────────┐         ┌───────────────────┐         ┌──────────────────┐
 │ short_chat    │         │  HF (transformers │         │ rmsnorm          │
 │ long_prefill  │  ─────▶ │  + PyTorch)       │  ─────▶ │ fused_rmsnorm    │
 │ decode_heavy  │         │  or vLLM backend  │         │ sampling         │
 │ mixed_length  │         └─────────┬─────────┘         │ decode_attention │
 │ shared_prefix │                   │                   └────────┬─────────┘
 └───────────────┘                   │                            │
                                      ▼                            ▼
                            ┌───────────────────┐       ┌──────────────────┐
                            │ Nsight profiling  │       │ Triton autotune  │
                            │ (ncu / nsys)      │       │ launch-param     │
                            └─────────┬─────────┘       │ sweep            │
                                      │                 └────────┬─────────┘
                                      ▼                          │
                            ┌───────────────────────────────────▼─────────┐
                            │ analysis: bottleneck classifier              │
                            │ memory_bound / compute_bound / scheduler_... │
                            └───────────────────────┬──────────────────────┘
                                                     ▼
                                        Markdown + HTML report + plots
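The last stage of the flow above can be pictured with a roofline-style rule. The sketch below is an assumption about how such a classifier could work, not the repository's implementation: the dataclass, the utilization threshold, and the reading of scheduler_bound as "neither compute nor memory is close to saturated" are all hypothetical.

```python
from dataclasses import dataclass


@dataclass
class KernelMeasurement:
    flops: float        # floating-point operations performed
    bytes_moved: float  # bytes read from and written to device memory
    elapsed_s: float    # measured wall-clock time in seconds


def classify(m, peak_flops, peak_bandwidth, min_utilization=0.2):
    """Label a measurement by which hardware limit it sits closest to."""
    compute_util = (m.flops / m.elapsed_s) / peak_flops
    memory_util = (m.bytes_moved / m.elapsed_s) / peak_bandwidth
    # If neither resource is busy, overhead such as launch latency or
    # scheduling is the more likely limiter (assumed interpretation).
    if max(compute_util, memory_util) < min_utilization:
        return "scheduler_bound"
    return "compute_bound" if compute_util >= memory_util else "memory_bound"


if __name__ == "__main__":
    # Hypothetical peak numbers; substitute the figures for your own GPU.
    sample = KernelMeasurement(flops=1e9, bytes_moved=1e9, elapsed_s=0.005)
    print(classify(sample, peak_flops=1e13, peak_bandwidth=2.5e11))
```

A single threshold is crude; a real classifier would likely combine several signals, such as profiler counters and queueing time, before assigning a label.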

Package layout:

vllm_kernel_profiler/
  workloads/      # workload schemas + generators
  inference/      # hf and optional vllm runners
  kernels/        # microkernel implementations + benchmarks
  autotune/       # Triton parameter sweep
  analysis/       # classifier, plots, report builder
  profiling/      # Nsight tool detection + launch wrappers
  experiments/    # full-sweep orchestration
  utils/          # device/env/io/logging helpers

Getting Started

Prerequisites

  • Python 3.10+
  • An NVIDIA GPU with a working CUDA toolchain (results below use CUDA 12.4 / PyTorch 2.6)
  • (Optional) vLLM for the vLLM backend — the harness prints an install hint and falls back to Hugging Face if it is missing
  • (Optional) NVIDIA Nsight Compute (ncu) and Nsight Systems (nsys) for deep profiling

Installation

git clone https://github.com/rishi-more-2003/vllm-kernel-profiler.git
cd vllm-kernel-profiler

python -m venv .venv
source .venv/bin/activate      # Linux / macOS
# .venv\Scripts\activate       # Windows

pip install -e .

WSL vLLM Setup (Recommended on Windows)

For vLLM runs on Windows hosts, use WSL Ubuntu. The repository path below is the author's local checkout; adjust it to yours:

wsl -d Ubuntu -- bash -lc "
  cd /mnt/c/Data/GithubRepository/vllm-kernel-profiler &&
  uv python install 3.12 &&
  uv venv --python 3.12 .venv-wsl &&
  source .venv-wsl/bin/activate &&
  uv pip install --index-url https://download.pytorch.org/whl/cu124 torch==2.6.0 &&
  uv pip install -e . &&
  uv pip install vllm
"

Quickstart

# Microbenchmark a single kernel
python -m vllm_kernel_profiler.kernels.benchmark --kernel rmsnorm

# Run an inference workload on the Hugging Face backend
python -m vllm_kernel_profiler.run --backend hf --workload decode_heavy

# Build a report from raw benchmark logs
python -m vllm_kernel_profiler.report --input results/raw/benchmark_runs.jsonl

One-command full sweep:

python -m vllm_kernel_profiler.full_profile --model Qwen/Qwen2.5-0.5B-Instruct

Nsight environment diagnostics:

python scripts/nsight_doctor.py

Usage

CLI reference
# Inference sweeps
python -m vllm_kernel_profiler.run --backend hf --workload decode_heavy \
    --model Qwen/Qwen2.5-0.5B-Instruct --batch-size 4 --prompt-len 2048 --decode-len 256

python -m vllm_kernel_profiler.run --backend vllm --workload mixed_length \
    --model Qwen/Qwen2.5-0.5B-Instruct --batch-size 8 \
    --vllm-gpu-memory-utilization 0.70 --vllm-max-model-len 4096

# WSL-safe vLLM run (disables FlashInfer sampler JIT path)
VLLM_USE_FLASHINFER_SAMPLER=0 python -m vllm_kernel_profiler.run --backend vllm \
    --workload decode_heavy --model Qwen/Qwen2.5-0.5B-Instruct --batch-size 1 \
    --prompt-len 256 --decode-len 64 --vllm-gpu-memory-utilization 0.70 \
    --vllm-max-model-len 4096 --output results/raw/benchmark_runs_wsl_vllm.jsonl

# Kernel microbenchmarks
python -m vllm_kernel_profiler.kernels.benchmark --kernel rmsnorm \
    --batch-size 32 --hidden-size 4096

# Triton launch-parameter autotuning
python -m vllm_kernel_profiler.autotune --kernel rmsnorm

# Report generation
python -m vllm_kernel_profiler.report --input results/raw/benchmark_runs.jsonl \
    --output results/reports/summary.md

Results (This Machine)

Runtime

GPU: NVIDIA GeForce RTX 4070 Laptop GPU
Torch: 2.6.0+cu124
Model: Qwen/Qwen2.5-0.5B-Instruct

Inference sweep

| workload | batch | prompt_len | decode_len | total_latency_ms | ttft_ms | gen_tok/s |
| --- | --- | --- | --- | --- | --- | --- |
| short_chat | 2 | 64 | 32 | 2199.93 | 706.51 | 41.67 |
| long_prefill | 2 | 2048 | 32 | 1628.90 | 528.27 | 56.20 |
| decode_heavy | 2 | 512 | 256 | 9642.71 | 391.65 | 55.12 |
| mixed_length | 3 | 512 | 128 | 5634.55 | 451.53 | 73.51 |
| shared_prefix | 3 | 768 | 128 | 5287.72 | 470.29 | 78.71 |
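The source does not state how gen_tok/s is derived. One reading that matches the mixed_length row to two decimals is decode-only throughput: tokens generated after the first one, divided by the time remaining after TTFT. The sketch below works under that assumption; the other rows come out close but not identical, so the actual definition may involve averaging over runs or differ in some other way.

```python
def decode_tokens_per_second(batch, decode_len, total_latency_ms, ttft_ms):
    """Decode-only throughput: tokens after the first, over time after TTFT."""
    decode_tokens = batch * (decode_len - 1)
    decode_seconds = (total_latency_ms - ttft_ms) / 1000.0
    return decode_tokens / decode_seconds


if __name__ == "__main__":
    # mixed_length row from the inference sweep table.
    rate = decode_tokens_per_second(
        batch=3, decode_len=128, total_latency_ms=5634.55, ttft_ms=451.53
    )
    print(f"{rate:.2f}")  # prints 73.51
```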

vLLM backend runs (WSL)

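| workload | backend | batch | prompt_len | decode_len | total_latency_ms | ttft_ms | gen_tok/s |
| --- | --- | --- | --- | --- | --- | --- | --- |
| decode_heavy | vllm | 1 | 256 | 64 | 3214.25 | 1335.92 | 33.19 |
| mixed_length | vllm | 2 | 256 | 64 | 3288.28 | 1341.00 | 64.88 |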

Kernel microbenchmarks (Triton vs. PyTorch)

| kernel | latency_ms (torch) | latency_ms (triton) | speedup |
| --- | --- | --- | --- |
| rmsnorm | 2.0526 | 0.1284 | 15.99× |
| fused_rmsnorm | 2.2295 | 0.1618 | 13.78× |
| sampling | 6.7986 | 0.4332 | 15.69× |
| decode_attention | 13.6281 | 0.2806 | 48.57× |

Note. These are measured local results on a single laptop GPU, not universal claims. The decode_attention kernel is a simplified benchmark of KV-cache access pressure in isolation, so its speedup reflects the microbenchmark, not end-to-end serving throughput.
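Latency pairs like the ones in the kernel table are typically gathered with a warmup-then-measure loop. The sketch below is a generic stdlib harness, not the repository's benchmark code; the function names, iteration counts, and the optional `sync` hook are assumptions. On a CUDA device the hook matters because kernel launches are asynchronous, so something like `torch.cuda.synchronize` would be passed in to wait for queued GPU work before the clock stops.

```python
import statistics
import time


def benchmark(fn, warmup=10, iters=50, sync=None):
    """Return the median latency of fn() in milliseconds."""
    for _ in range(warmup):
        fn()  # warm caches and trigger any lazy compilation
    if sync is not None:
        sync()
    samples = []
    for _ in range(iters):
        start = time.perf_counter()
        fn()
        if sync is not None:
            sync()  # wait for queued device work before stopping the clock
        samples.append((time.perf_counter() - start) * 1000.0)
    return statistics.median(samples)


def speedup(baseline_ms, candidate_ms):
    """Ratio of baseline latency to candidate latency."""
    return baseline_ms / candidate_ms


if __name__ == "__main__":
    # Recompute the rmsnorm speedup from the table above.
    print(f"{speedup(2.0526, 0.1284):.2f}x")
```

The median is used rather than the mean so that a few slow outliers, for example from clock ramp-up on a laptop GPU, do not skew the reported latency.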


Nsight Workflow

See docs/profiling_with_nsight.md for detailed guidance. Typical commands:

ncu --set full python scripts/run_kernel_bench.py --kernel rmsnorm
nsys profile python scripts/run_inference_bench.py --backend vllm --workload decode_heavy

Also supported:

  • python scripts/nsight_doctor.py — tool detection + remediation hints
  • auto-resolved absolute paths for Nsight tools on Windows installs
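Tool detection of the kind described above can be approximated with the standard library. This is a minimal sketch, not the contents of scripts/nsight_doctor.py; the function name is hypothetical, and it only checks PATH rather than scanning Windows install directories.

```python
import shutil


def find_nsight_tools(names=("ncu", "nsys")):
    """Map each tool name to its resolved path, or None if it is not on PATH."""
    return {name: shutil.which(name) for name in names}


if __name__ == "__main__":
    for name, path in find_nsight_tools().items():
        print(f"{name}: {path or 'not found on PATH'}")
```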

Reproducibility Artifacts

| Artifact | Path |
| --- | --- |
| Inference benchmark log | results/raw/benchmark_runs.jsonl |
| Kernel benchmark log | results/raw/kernel_runs.jsonl |
| Nsight Compute report | results/raw/ncu_rmsnorm.ncu-rep |
| Nsight Systems report | results/raw/nsys_decode_heavy.nsys-rep |
| Autotune configs | results/autotune/best_configs.json |
| Markdown report | results/reports/summary.md |
| HTML report | results/reports/summary.html |
| Plots | results/plots/ |
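The JSONL logs can be inspected without the report tool. The sketch below averages latency per workload; the field names `workload` and `total_latency_ms` are assumed from the column names in the results tables and may not match the actual log schema.

```python
import json
from collections import defaultdict
from statistics import mean


def mean_latency_by_workload(lines):
    """Average total latency per workload from an iterable of JSONL lines."""
    groups = defaultdict(list)
    for line in lines:
        line = line.strip()
        if not line:
            continue  # tolerate blank lines between records
        record = json.loads(line)
        # Field names are assumptions; adjust them to the real schema.
        groups[record["workload"]].append(float(record["total_latency_ms"]))
    return {name: mean(values) for name, values in groups.items()}


if __name__ == "__main__":
    with open("results/raw/benchmark_runs.jsonl", encoding="utf-8") as f:
        for workload, latency in mean_latency_by_workload(f).items():
            print(f"{workload}: {latency:.2f} ms")
```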

Project Structure

Full tree
vllm-kernel-profiler/
├── vllm_kernel_profiler/
│   ├── __init__.py
│   ├── run.py                 # inference benchmark entry point
│   ├── full_profile.py        # one-command full sweep
│   ├── workloads/             # workload schemas + generators
│   ├── inference/             # hf and optional vllm runners
│   ├── kernels/               # microkernel implementations + benchmarks
│   ├── autotune/              # Triton parameter sweep
│   ├── analysis/              # classifier, plots, report builder
│   ├── profiling/             # Nsight tool detection + launch wrappers
│   ├── experiments/           # full-sweep orchestration
│   └── utils/                 # device/env/io/logging helpers
│
├── scripts/
│   ├── run_inference_bench.py
│   ├── run_kernel_bench.py
│   ├── run_full_profile.py
│   ├── generate_report.py
│   └── nsight_doctor.py
│
├── configs/
│   ├── default.yaml
│   ├── workloads.yaml
│   ├── kernels.yaml
│   └── experiments.yaml
│
├── docs/
├── tests/
└── results/

Citation

If you find this work useful, please cite:

@misc{more2026vllmkernelprofiler,
  title  = {vLLM Kernel Profiler: Kernel-Level Profiling for Decode Latency,
            KV-Cache Pressure, and GPU Bottlenecks in LLM Inference},
  author = {More, Rishi},
  year   = {2026},
  url    = {https://github.com/rishi-more-2003/vllm-kernel-profiler}
}

Acknowledgements

Built with Triton, PyTorch, Hugging Face Transformers, vLLM, and NVIDIA Nsight Compute / Systems. Benchmarks were run on an NVIDIA GeForce RTX 4070 Laptop GPU.

Released under the MIT License.

