High-performance GPU implementation of fused matrix multiplication + sampling using Triton. This package provides an efficient kernel for sampling from categorical distributions where logits are computed on the fly from a matrix multiplication, avoiding materialization of the full logit tensor in GPU main memory (GMEM). The key insight is that in LLM decode workloads, both the matmul and the sampling are memory-bound (the matmul collapses to a matrix-vector product). By fusing the two operations, we avoid round-trips to GMEM and speed up sampling.
- Bandwidth-Efficient: Fuses matrix multiplication and sampling into a single Triton kernel, avoiding materialization of intermediate logit tensors, and preventing round-trips to GMEM.
- Exact: Uses Gumbel-max trick for efficient categorical sampling. No approximations.
- Flexible: Supports temperature scaling and multiple samples per hidden state vector.
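To make the fusion concrete, here is the unfused computation that the kernel replaces, sketched in NumPy (a hypothetical `naive_mm_sample` for illustration, not part of this package): it materializes the full logit matrix, applies temperature, adds Gumbel noise, and takes the argmax.

```python
import numpy as np

def naive_mm_sample(weights, hidden_states, num_samples=1, temperature=1.0, seed=None):
    """Unfused reference: materializes the full [n, vocab_size] logit tensor."""
    rng = np.random.default_rng(seed)
    logits = hidden_states @ weights.T              # [n_hidden_states, vocab_size]
    scaled = logits / temperature
    # Gumbel-max trick: argmax(logits + Gumbel noise) ~ Categorical(softmax(logits))
    n, vocab = scaled.shape
    gumbel = rng.gumbel(size=(n, num_samples, vocab))
    return np.argmax(scaled[:, None, :] + gumbel, axis=-1)   # [n, num_samples]

vocab_size, hidden_size, n = 1000, 64, 4
rng = np.random.default_rng(0)
W = rng.standard_normal((vocab_size, hidden_size)).astype(np.float32)
h = rng.standard_normal((n, hidden_size)).astype(np.float32)
samples = naive_mm_sample(W, h, num_samples=2, seed=42)
print(samples.shape)  # (4, 2)
```

The fused kernel produces samples with the same distribution, but never writes the `[n, vocab_size]` logit tensor to GMEM.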
```shell
# Clone the repository
git clone https://github.com/tomasruizt/fused-mm-sample.git
cd fused-mm-sample

# Install the package (assumes you're in a virtual environment)
uv pip install -e ".[dev]"

# Verify installation
python examples/basic_usage.py
```

For a complete working example, see `examples/basic_usage.py`.
The basic usage pattern:

```python
from fused_mm_sampling import fused_mm_sample_triton

samples = fused_mm_sample_triton(
    weights=weights,              # [vocab_size, hidden_size]
    hidden_states=hidden_states,  # [n_hidden_states, hidden_size]
    num_samples=1,
    temperature=torch.tensor(1.0, device="cuda"),  # scalar (0-d) CUDA tensor
    seed=42,  # Optional: for reproducibility
)
# Returns: [n_hidden_states, num_samples]
```

Parameters:
- `weights` (Tensor): Weight matrix of shape `[vocab_size, hidden_size]`
- `hidden_states` (Tensor): Hidden states of shape `[n_hidden_states, hidden_size]`
- `num_samples` (int): Number of samples to draw per sequence position
- `temperature` (Tensor): Scalar (0-d) CUDA tensor for temperature scaling (higher = more random)
- `seed` (int, optional): Random seed for reproducibility

Returns: Tensor of shape `[n_hidden_states, num_samples]` containing sampled indices
Check out the `feature/fmms-sampler` branch from my vLLM fork and install vLLM from local sources. The code diff is minimal.
```shell
git clone https://github.com/tomasruizt/vllm.git
cd vllm
git checkout feature/fmms-sampler
VLLM_USE_PRECOMPILED=1 uv pip install -e .
```

Then launch any model. Use the flags below to activate the FMMS sampler:

```shell
VLLM_USE_FMMS_SAMPLER=1 VLLM_FMMS_PROVIDER=fused-triton vllm serve Qwen/Qwen3-1.7B
```

The FMMS kernel implements the Gumbel-max trick for categorical sampling:
- Matrix Multiplication: Compute a tile of `logits = hidden_states @ weights.T` in SRAM
- Temperature Scaling: Scale logits by temperature
- Gumbel Noise: Add Gumbel noise to the scaled logit tile
- Argmax: Take argmax within the tile to get samples
The FMMS kernel computes these steps in blocks without materializing the full logit tensor, avoiding round-trips to GMEM and relieving the memory-bandwidth bottleneck.
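The blockwise idea can be sketched as a running argmax over vocabulary tiles. This NumPy illustration is not the Triton kernel (the real kernel generates Gumbel noise on the fly from the seed, and `tiled_gumbel_argmax` is a name invented here), but it shows why only one logit tile needs to exist at a time:

```python
import numpy as np

def tiled_gumbel_argmax(weights, h, gumbel_noise, temperature=1.0, tile=128):
    """Streaming Gumbel-max: fold each logit tile into a running argmax,
    so the full [vocab_size] logit vector is never materialized at once."""
    vocab_size = weights.shape[0]
    best_val, best_idx = -np.inf, -1
    for start in range(0, vocab_size, tile):
        w_tile = weights[start:start + tile]             # [tile, hidden_size]
        logits = (w_tile @ h) / temperature              # one logit tile ("SRAM")
        noisy = logits + gumbel_noise[start:start + tile]
        i = int(np.argmax(noisy))
        if noisy[i] > best_val:                          # running (max, argmax)
            best_val, best_idx = noisy[i], start + i
    return best_idx

rng = np.random.default_rng(0)
vocab_size, hidden_size = 1000, 64
W = rng.standard_normal((vocab_size, hidden_size))
h = rng.standard_normal(hidden_size)
noise = rng.gumbel(size=vocab_size)

tiled = tiled_gumbel_argmax(W, h, noise)
full = int(np.argmax((W @ h) + noise))  # unfused reference with the same noise
assert tiled == full
```

Because max/argmax is associative, the tilewise result is identical to the argmax over the full noisy logit vector.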
Kernel microbenchmarks across B300, B200, H200, and H100 GPUs, roofline analysis, and end-to-end vLLM integration results are in the blog post.
```shell
# Benchmark all implementations
python speed_test.py

# Compare performance over many batch sizes
make triton-benchmark

# Run all microbenchmarks on Modal (B300, B200, H200, H100)
make modal-triton-benchmark-all-gpus
```

All profiling scripts are located in the `benchmarking/` directory.
```shell
cd benchmarking
make profile-mem
```

This will generate a memory snapshot and HTML visualization in `benchmarking/memory/`.
```shell
cd benchmarking

# Profile fused Triton kernel
make ncu-profile-fused-triton

# Profile naive compiled implementation
make ncu-profile-naive-compiled
```

```shell
cd benchmarking

# Profile fused Triton kernel
make nsight-profile-fused-triton

# Profile naive compiled implementation
make nsight-profile-naive-compiled
```

The dev dependencies are required to run the scripts in the `benchmarking/` directory. To install them, run:
```shell
uv pip install -e ".[dev]"
```

The experiments involving many different GPUs were run on Modal. To install and log in to Modal:

```shell
uv pip install modal
modal setup
```

Run the speed test on Modal:

```shell
make modal-speed-test
```

MIT License - see LICENSE file for details
Contributions are welcome! Please feel free to create an issue or submit a pull request.
