diff --git a/.claude-plugin/marketplace.json b/.claude-plugin/marketplace.json index 646ce2a..4d74b2c 100644 --- a/.claude-plugin/marketplace.json +++ b/.claude-plugin/marketplace.json @@ -8,36 +8,12 @@ "version": "0.1.0" }, "plugins": [ - { - "name": "aiter-reflection", - "source": "./skills/aiter-reflection", - "skills": "./", - "description": "This skill should be used when optimizing AMD GPU kernels on MI300 using the aiter project, including running op tests, benchmarking, iterating on kernel changes, and recording results in the kernel experiment database." - }, { "name": "apu-memory-tuner", "source": "./skills/apu-memory-tuner", "skills": "./", "description": "Inspect and tune the shared-vs-dedicated memory split (GTT / UMA Frame Buffer) on AMD Ryzen APUs so larger LLMs and image models fit on the iGPU." }, - { - "name": "gpu-architecture-fundamentals", - "source": "./skills/gpu-architecture-fundamentals", - "skills": "./", - "description": "This skill should be used when reasoning about GPU architecture fundamentals to guide kernel optimization choices such as memory hierarchy usage, execution model mapping, block sizing, and latency-aware tuning across HIP, Triton, and PyTorch." - }, - { - "name": "hip-kernel-optimization", - "source": "./skills/hip-kernel-optimization", - "skills": "./", - "description": "This skill should be used when writing or tuning HIP kernels on AMD/NVIDIA GPUs, covering memory coalescing, shared-memory tiling, bank conflict avoidance, warp primitives, occupancy, vectorization, async ops, loop unrolling, and profiling." - }, - { - "name": "kernel-exp-history", - "source": "./skills/kernel-exp-history", - "skills": "./", - "description": "This skill should be used when optimizing kernels in this repo and needing to consult past optimization experiments, or when recording the current optimization iteration back into the kernel experiment database." - }, { "name": "local-ai-app-integration", "source": "./skills/local-ai-app-integration", @@ -56,47 +32,11 @@ "skills": "./", "description": "Performs GPU kernel correctness and performance evaluation and LLM inference benchmarking with Magpie. Analyzes single or multiple kernels (HIP/CUDA/PyTorch), compares kernel implementations, runs vLLM/SGLang benchmarks with profiling and TraceLens, and runs gap analysis on torch traces." }, - { - "name": "mi300-hip-programming-insights", - "source": "./skills/mi300-hip-programming-insights", - "skills": "./", - "description": "CDNA3/MI300 HIP programming insights—chiplet/cache model, Infinity Cache, memory coherency, matrix cores, sparsity, and best practices." - }, - { - "name": "pytorch-kernel-optimization", - "source": "./skills/pytorch-kernel-optimization", - "skills": "./", - "description": "This skill should be used when optimizing PyTorch models and kernels, including efficient tensor operations, torch.compile, custom autograd/CUDA/Triton extensions, mixed precision, memory and data pipeline tuning, model optimization techniques, CUDA graphs, and profiling." - }, { "name": "rocm-doctor", "source": "./skills/rocm-doctor", "skills": "./", "description": "Diagnose why ROCm, PyTorch, or llama.cpp isn't working on an AMD GPU. Matches the symptom against a fixed list of twelve known misconfigurations and proposes the next step." - }, - { - "name": "rocprof-compute", - "source": "./skills/rocprof-compute", - "skills": "./", - "description": "This skill should be used when profiling AMD GPU kernels with rocprof-compute to collect metrics, roofline data, and analyze bottlenecks for HIP kernels." - }, - { - "name": "triton-hip-reference-kernel-search", - "source": "./skills/triton-hip-reference-kernel-search", - "skills": "./", - "description": "Search and adapt Triton/HIP kernel patterns from a corpus to optimize AMD GPUs; use to find similar ops and reuse tiling/occupancy strategies." - }, - { - "name": "triton-kernel-optimization", - "source": "./skills/triton-kernel-optimization", - "skills": "./", - "description": "This skill should be used when writing or tuning Triton GPU kernels, including autotuning block sizes, coalesced accesses, tiled matmul, fused ops, reductions, flash-attention style kernels, quantization, custom gradients, and profiling." - }, - { - "name": "triton-kernel-reflection-prompts", - "source": "./skills/triton-kernel-reflection-prompts", - "skills": "./", - "description": "Reflection/self-critique prompts for reviewing and fixing AMD-targeted Triton kernels after generation or test failures." } ] } diff --git a/README.md b/README.md index 0a03cb9..714e79c 100644 --- a/README.md +++ b/README.md @@ -57,7 +57,7 @@ Skills earn their keep on repeated, opinionated workflows, exactly where the AMD > > **Target: ready for testing by June 12.** Until then, treat anything below as a preview. -The initial catalog is organized into five focus areas. +The initial catalog is organized into four focus areas. ### Application integration @@ -80,22 +80,6 @@ Diagnose, configure, and ready AMD systems for AI workloads: drivers, BIOS, memo | `gfx-target-chooser` | Pick the right `gfx942` / `gfx90a` / `gfx1100` target and matching compiler flags. | _planned_ | | `pytorch-rocm-setup` | Get a known-good PyTorch + ROCm stack running on a target node, end to end. | _planned_ | -### Kernel engineering - -Author, tune, and reason about GPU kernels for AMD targets. - -| Skill | What it does | Source | -| --- | --- | --- | -| [`aiter-reflection`](skills/aiter-reflection/SKILL.md) | Optimize AMD GPU kernels on MI300 using the aiter project: op tests, benchmarks, iteration, experiment database. | [Apex](https://github.com/AMD-AGI/Apex) | -| [`gpu-architecture-fundamentals`](skills/gpu-architecture-fundamentals/SKILL.md) | Reason about memory hierarchy, execution model, block sizing, and latency across HIP, Triton, and PyTorch. | [Apex](https://github.com/AMD-AGI/Apex) | -| [`hip-kernel-optimization`](skills/hip-kernel-optimization/SKILL.md) | Write and tune HIP kernels: coalescing, shared-memory tiling, bank conflicts, warp primitives, occupancy, vectorization. | [Apex](https://github.com/AMD-AGI/Apex) | -| [`kernel-exp-history`](skills/kernel-exp-history/SKILL.md) | Consult past kernel optimization experiments and record the current iteration back into the experiment database. | [Apex](https://github.com/AMD-AGI/Apex) | -| [`mi300-hip-programming-insights`](skills/mi300-hip-programming-insights/SKILL.md) | CDNA3 / MI300 HIP programming insights: chiplet and cache model, Infinity Cache, coherency, matrix cores, sparsity. | [Apex](https://github.com/AMD-AGI/Apex) | -| [`pytorch-kernel-optimization`](skills/pytorch-kernel-optimization/SKILL.md) | Optimize PyTorch models and kernels: `torch.compile`, custom extensions, mixed precision, CUDA graphs, profiling. | [Apex](https://github.com/AMD-AGI/Apex) | -| [`triton-hip-reference-kernel-search`](skills/triton-hip-reference-kernel-search/SKILL.md) | Search and adapt Triton / HIP kernel patterns from a corpus to reuse tiling and occupancy strategies. | [Apex](https://github.com/AMD-AGI/Apex) | -| [`triton-kernel-optimization`](skills/triton-kernel-optimization/SKILL.md) | Write and tune Triton kernels: autotune block sizes, tiled matmul, fused ops, reductions, flash-attention, quantization. | [Apex](https://github.com/AMD-AGI/Apex) | -| [`triton-kernel-reflection-prompts`](skills/triton-kernel-reflection-prompts/SKILL.md) | Reflection / self-critique prompts for reviewing and fixing AMD-targeted Triton kernels. | [Apex](https://github.com/AMD-AGI/Apex) | - ### Cross-stack porting Bring existing workloads onto AMD. @@ -113,7 +97,7 @@ Close the loop from trace to fix to ship. | Skill | What it does | Source | | --- | --- | --- | | [`magpie`](skills/magpie/SKILL.md) | Evaluate GPU kernel correctness and performance, compare kernel implementations, and benchmark vLLM / SGLang inference with profiling, TraceLens, and torch-trace gap analysis. | [Magpie](https://github.com/AMD-AGI/Magpie) | -| [`rocprof-compute`](skills/rocprof-compute/SKILL.md) | Profile AMD GPU kernels with `rocprof-compute` to collect metrics, roofline data, and bottleneck analysis. | [Apex](https://github.com/AMD-AGI/Apex) | +| `hyperloom` | Autonomously optimizes LLM inference on AMD GPUs. | _planned_ | | `omniperf-tune` | Run `omniperf`, locate the bottleneck, and suggest the fix. | _planned_ | | `quark-quantize` | Quantize PyTorch / ONNX models with [AMD Quark](https://github.com/amd/Quark) and export for AMD deployment. | _planned_ | diff --git a/scripts/sources.yml b/scripts/sources.yml index b5fbe5b..3d88a09 100644 --- a/scripts/sources.yml +++ b/scripts/sources.yml @@ -23,25 +23,6 @@ # the resulting changes for human review. sources: - - name: amd-agi-apex - repo: AMD-AGI/Apex - ref: main - path: tools/skills - license: MIT - # `skill-creator` is intentionally excluded; this catalog already has - # its own `create-skill` story via CONTRIBUTING.md. - skills: - - aiter-reflection - - gpu-architecture-fundamentals - - hip-kernel-optimization - - kernel-exp-history - - mi300-hip-programming-insights - - pytorch-kernel-optimization - - rocprof-compute - - triton-hip-reference-kernel-search - - triton-kernel-optimization - - triton-kernel-reflection-prompts - - name: amd-agi-magpie repo: AMD-AGI/Magpie ref: main diff --git a/skills/aiter-reflection/.federated.json b/skills/aiter-reflection/.federated.json deleted file mode 100644 index e6def93..0000000 --- a/skills/aiter-reflection/.federated.json +++ /dev/null @@ -1,9 +0,0 @@ -{ - "source": "amd-agi-apex", - "repo": "AMD-AGI/Apex", - "ref": "main", - "commit": "6ee40e7e6f2d03902cce503cddf873f2ab75f05c", - "path": "tools/skills/aiter-reflection", - "license": "MIT", - "imported_at": "2026-05-28T21:37:40Z" -} diff --git a/skills/aiter-reflection/SKILL.md b/skills/aiter-reflection/SKILL.md deleted file mode 100644 index 3322a85..0000000 --- a/skills/aiter-reflection/SKILL.md +++ /dev/null @@ -1,72 +0,0 @@ ---- -name: aiter-reflection -description: This skill should be used when optimizing AMD GPU kernels on MI300 using the aiter project, including running op tests, benchmarking, iterating on kernel changes, and recording results in the kernel experiment database. ---- - -# Aiter Reflection - -## Overview - -Optimize AMD MI300 GPU kernels for correctness and performance using the aiter workflow, then record each iteration to the kernel experiment database. - -## Workflow - -### 1) Locate targets and understand tests - -- Use the provided context to identify target kernel files, kernels, and their op tests. -- Run the op tests once to understand output format and verify correctness expectations. (Attention: Stucked background op test processes and lock files under jit folder may cause the op tests running failed; Op tests require JIT compiling, please be prepared to wait for a long time) - -### 2) Build a benchmark shell script -- Come up with a new name for this iteration and create a folder logs/. Put the shell script under this folder -- Reuse the existing op_test python script -- Covers common shapes: 128, 256, 512, 1024, 2048, 4096 if applies -- Repeats each op test multiple times and reports the correctness and the average time consuming. - - Use at least 100 iterations per configuration for reliable results - - Include 10-20 warmup iterations to handle JIT compilation overhead - - Add torch.cuda.synchronize() after each kernel call - - Use fixed random seed for reproducibility - - Use high-precision timing (time.perf_counter()) -- Implements a robust timeout to avoid hangs. -- Outputs structured timing per shape. - -### 3) Establish a baseline - -- **Before testing**: Check for background GPU processes that may interfere - - Use `rocm-smi` or `ps aux | grep python` to identify GPU tasks - - Stop any unrelated GPU workloads -- Clear JIT compilation cache to ensure clean state -- Run the benchmark script using the `.venv` Python environment -- Save results under logs/ folder with timestamp - - -### 4) Iterate on kernel optimization (one iteration) - -- Read the kernel source, identify bottlenecks, and call `rocprof-compute` at least once to deepen bottleneck analysis. -- Use `kernel-exp-history` to review related optimization history and extract ideas. -- Modify the kernel file to improve performance for multiple shapes allowed. -- Save the changes: (git diff > logs//iter_diff.patch) -- Reinstall aiter and clear cache: - - `python -m pip install -e . --no-build-isolation --no-deps --force-reinstall` - - `rm -f aiter/jit/*.so && rm -rf aiter/jit/build ~/.aiter` -- Re-run the benchmark to measure the new performance. -- **If results seem suspicious** (unexpected regressions): - - Verify no background processes are running - - Re-test baseline with same methodology - - Check if JIT compilation overhead affected measurements - - -### 5) Record the iteration - -- **Document the results**: - - Save detailed analysis in logs//iter_analysis.md - - Include performance comparison table - - Document any issues encountered (false regressions, test methodology problems) - -- Use `kernel-exp-history` to store in database -- **Verify result quality**: If showing unexpected regression, investigate before recording -- Restore the repo code to the `main` branch state after finishing the iteration - - -### 6) Repeat iterations - -- Repeat step 4 for ten iterations (no stop), each time measuring and recording results. diff --git a/skills/gpu-architecture-fundamentals/.federated.json b/skills/gpu-architecture-fundamentals/.federated.json deleted file mode 100644 index 4bab37c..0000000 --- a/skills/gpu-architecture-fundamentals/.federated.json +++ /dev/null @@ -1,9 +0,0 @@ -{ - "source": "amd-agi-apex", - "repo": "AMD-AGI/Apex", - "ref": "main", - "commit": "6ee40e7e6f2d03902cce503cddf873f2ab75f05c", - "path": "tools/skills/gpu-architecture-fundamentals", - "license": "MIT", - "imported_at": "2026-05-28T21:37:40Z" -} diff --git a/skills/gpu-architecture-fundamentals/SKILL.md b/skills/gpu-architecture-fundamentals/SKILL.md deleted file mode 100644 index 4f439f6..0000000 --- a/skills/gpu-architecture-fundamentals/SKILL.md +++ /dev/null @@ -1,36 +0,0 @@ ---- -name: gpu-architecture-fundamentals -description: This skill should be used when reasoning about GPU architecture fundamentals to guide kernel optimization choices such as memory hierarchy usage, execution model mapping, block sizing, and latency-aware tuning across HIP, Triton, and PyTorch. ---- - -# GPU Architecture Fundamentals - -## Purpose -- Reference core GPU concepts (memory hierarchy, execution model) and typical bandwidth/latency numbers to ground optimization choices. -- Provide block size heuristics and ready-to-use checklists before writing or tuning kernels. -- Map common optimization patterns across HIP, Triton, and PyTorch to pick framework-specific tactics quickly. - -## When to Use -- Planning or reviewing kernel designs where occupancy, memory bandwidth, or latency hiding are concerns. -- Selecting grid/block shapes, deciding on shared memory usage, or checking for coalesced accesses. -- Comparing optimization levers across frameworks when porting kernels. - -## How to Use -- Recall memory hierarchy: prefer registers > shared/L1 > L2 > HBM; treat HBM as ~400–800 cycle latency, registers ~0, shared ~20–30 cycles. -- Anchor bandwidth sense-checks with table values (e.g., MI300X HBM3 ~5.3 TB/s, A100 HBM2e ~2.0 TB/s). -- Choose block sizes by operation: element-wise 256–1024 threads, reduction 256–512, matmul tiles 128x128 or 256x128, conv 32x32 or 64x64. -- Apply execution model mapping: thread ↔ element/partial tile, warp/wavefront ↔ contiguous data segments, block/workgroup ↔ tiles sharing shared memory, grid ↔ full problem coverage. -- Run the optimization checklist before finalizing kernels: - - Ensure coalesced and vectorized memory access; avoid shared memory bank conflicts. - - Target occupancy >50%; watch register pressure and shared memory usage to avoid spilling. - - Fuse operations where possible; leverage mixed precision when valid. - - Overlap transfers with compute; tune block/grid dimensions; unroll small loops. -- Use pattern summaries to pick tactics per framework: - - Memory: HIP manual strides/shared, Triton `tl.arange`/implicit tiling, PyTorch `.contiguous()`/compiler. - - Compute: HIP manual fusion/unroll, Triton `@triton.jit` + `tl.constexpr`, PyTorch `torch.compile`/FlashAttention. - - Parallelism: HIP block/grid + occupancy APIs, Triton autotune + constexpr block sizes, PyTorch compiler/automatic launch config. - -## Quick Checks -- If performance regresses, compare achieved block size and occupancy to table heuristics. -- If L2/HBM traffic is high, add tiling or fusion; if shared memory stalls, check bank conflicts and tile padding. -- When switching hardware, re-evaluate bandwidth and latency assumptions and retune block sizes accordingly. diff --git a/skills/hip-kernel-optimization/.federated.json b/skills/hip-kernel-optimization/.federated.json deleted file mode 100644 index ef7eec1..0000000 --- a/skills/hip-kernel-optimization/.federated.json +++ /dev/null @@ -1,9 +0,0 @@ -{ - "source": "amd-agi-apex", - "repo": "AMD-AGI/Apex", - "ref": "main", - "commit": "6ee40e7e6f2d03902cce503cddf873f2ab75f05c", - "path": "tools/skills/hip-kernel-optimization", - "license": "MIT", - "imported_at": "2026-05-28T21:37:40Z" -} diff --git a/skills/hip-kernel-optimization/SKILL.md b/skills/hip-kernel-optimization/SKILL.md deleted file mode 100644 index 353ea26..0000000 --- a/skills/hip-kernel-optimization/SKILL.md +++ /dev/null @@ -1,255 +0,0 @@ ---- -name: hip-kernel-optimization -description: This skill should be used when writing or tuning HIP kernels on AMD/NVIDIA GPUs, covering memory coalescing, shared-memory tiling, bank conflict avoidance, warp primitives, occupancy, vectorization, async ops, loop unrolling, and profiling. ---- - -# HIP Kernel Optimization - -## Purpose -Provide ready patterns for efficient HIP kernels and guide diagnosis of memory throughput, occupancy, and synchronization bottlenecks. - -## When to Use -- Implementing or reviewing HIP kernels for AMD MI/CDNA architectures or CUDA-portable code -- Porting CUDA code to HIP while retaining performance -- Preparing profiling runs with `rocprof` - -## Optimization Priority - -**Phase 1: Low-hanging fruit** (try first, low risk) -1. `#pragma unroll` on hot loops with small, fixed trip counts -2. Enable `-ffast-math` compiler flag for floating-point kernels -3. Use 32B vectorized loads/stores instead of 16B -4. Add `__launch_bounds__(maxThreads, minBlocks)` to guarantee occupancy -5. Add `const` qualifiers on read-only pointers -6. Verify memory coalescing (consecutive threads → consecutive addresses) - -**Phase 2: Targeted improvements** (profile first) -7. Profile with `rocprof` to confirm bottleneck -8. If memory-bound: CK-Tile buffer views with vectorization -9. If compute-bound: Shared memory tiling -10. Dynamically calculate block size based on problem dimensions -11. Replace large 2D shared arrays with atomicAdd for sparse patterns -12. Provide multiple block size configurations to avoid register spill -13. Add explicit rounding mode control for numerical correctness -14. Pre-compute workspace size to avoid dynamic allocation -15. Implement CSV-based tuning cache for repeated GEMM shapes - -**Phase 3: Complex transformations** (high effort) -16. Algorithm changes (e.g., Top-K-only softmax) -17. gfx950: Use 16x16x32 MFMA instead of 2x 16x16x16 -18. Kernel fusion (multi-op in single kernel) -19. Persistent kernels for repeatedly executed operations -20. Shape-based heuristic dispatching - -**Anti-patterns**: -- Optimizing everything at once -- Manual loop unrolling (use `#pragma unroll` instead) -- Over-unrolling (factor > 8) -- Premature vectorization without alignment check -- Unnecessary buffer coherence flags (e.g., `glc`) - -## Core Optimization Patterns - -### 1. Memory Access -- **Coalescing**: Map consecutive threads to consecutive addresses; prefer SoA over AoS -- **Vectorization**: Use CK-Tile buffer views for efficient I/O; prefer 32B loads over 16B -- **Boundary handling**: Separate fast vectorized path from slow boundary path - ```cpp - if(idx + VEC_SIZE <= d) { - vec_o out_vec; - #pragma unroll - for(size_t j = 0; j < VEC_SIZE; j++) { - out_vec[j] = compute(x[j], y[j]); - } - buffer_out.template set(idx, 0, true, out_vec); // Fast path - } else { - for(size_t j = 0; j < VEC_SIZE; j++) { // Boundary path - if(idx + j < d) ptr_out[idx + j] = compute(...); - } - } - ``` - -### 2. Shared Memory -- **Tiling**: Load tiles once, reuse; balance TILE_SIZE vs occupancy -- **Bank conflicts**: Pad shared arrays (e.g., `[32][33]`) or rotate access -- **Sparse patterns**: Use atomicAdd to 1D counters (O(N)) instead of 2D arrays (O(N²)) - ```cpp - // Three-pass pattern for sparse bucketing - // Pass 1: Count items per category - for(int i = start_idx; i < end_idx; ++i) { - int32_t category_id = input_ids[i]; - atomicAdd(&category_counts[category_id], 1); - } - __syncthreads(); - - // Pass 2: Compute prefix sum for offsets - if(threadIdx.x == 0) { - for(int i = 0; i < num_categories; ++i) - cumsum[i+1] = cumsum[i] + category_counts[i]; - } - __syncthreads(); - - // Pass 3: Assign items using atomic write positions - for(int i = start_idx; i < end_idx; ++i) { - int32_t position = atomicAdd(&write_positions[input_ids[i]], 1); - sorted_output[position] = i; - } - ``` - -### 3. Warp/Wavefront Primitives -- Use `__shfl_*`, ballots, and warp reductions to reduce shared memory -- Pattern for warp-level argmax: - ```cpp - auto arg_max = [](const kvp& a, const kvp& b) { - return (a.value > b.value || (a.value == b.value && a.key < b.key)) ? a : b; - }; - kvp thread_kvp = {item_id, max_val}; - thread_kvp = warp_reduce(thread_kvp, arg_max, WARP_SIZE); - ``` - -### 4. Occupancy Tuning -- **Dynamic block sizing**: Calculate based on problem dimensions - ```cpp - int vec_size = nextPow2(d / 64); - vec_size = min(vec_size, max_vec_size); - int num_wave = min(nextPow2(d / 64 / vec_size), max_wave_num); - dim3 block(max(num_wave, 1) * 64); - ``` - -- **Guaranteed occupancy**: Use `__launch_bounds__` for predictable performance - ```cpp - __launch_bounds__(256, 8) __global__ // 256 threads, min 8 blocks per CU - void kernel(scalar_t* __restrict__ output, ...) { } - ``` - -- **Register spill prevention**: Provide multiple block size options - ```cpp - if (MPerBlock == 64) - gemm_kernel<..., 64, ...>(...); - else if (MPerBlock == 128) - gemm_kernel<..., 128, ...>(...); - else if (MPerBlock == 256) - gemm_kernel<..., 256, ...>(...); - ``` - -- **Adaptive grid sizing**: Don't use fixed grid for variable problem sizes; adapt to small dimensions - -### 5. Loop Unrolling -- Apply `#pragma unroll` for small, fixed trip counts -- Unroll vector processing: `#pragma unroll` before `for(size_t j = 0; j < VEC_SIZE; j++)` - -### 6. Async Memory Operations -- Overlap H2D/D2H with compute using multiple streams and `hipMemcpyAsync` - -## AMD-Specific Optimizations - -### 7. MFMA Instructions (gfx940/942/950) -- **gfx950**: Use single 16x16x32 MFMA instead of 2x 16x16x16 - ```cpp - #if defined(__gfx950__) - dout = gcn_mfma16x16x32_instr(K, Q, dout); - #else - for(int i = 0; i < 2; i++) { - dout = gcn_mfma16x16x16_instr(K.xy[i], Q.xy[i], dout); - } - #endif - ``` -- Use `__builtin_shufflevector` to reorganize data for larger MFMA variants - -### 8. Inline Assembly for Packed Operations -- **v_pk_mul_f32**: Process two floats at once - ```cpp - float2 result; - asm volatile("v_pk_mul_f32 %0, %1, %2\n\t" - "v_pk_mul_f32 %0, %0, %3" - : "=v"(result) : "v"(act_vals), "v"(y_vals), "v"(scale_vals)); - ``` - -### 9. Compiler Flags -- **-ffast-math**: Enables aggressive FP optimizations -- **Avoid unnecessary coherence**: Don't use `ck_tile::amd_buffer_coherence_enum::glc` unless required - -## Advanced Optimization Strategies - -### 10. Algorithm-Level Optimizations -- **Top-K-only softmax**: Only compute exp on top-K values, not entire row - ```cpp - float thread_max = find_max_in_row(); - for(int k_idx = 0; k_idx < k; ++k_idx) { - kvp top = find_argmax_in_remaining(); - output[k_idx] = expf(top.value - thread_max); - renorm_value += output[k_idx]; - row_chunk[top.index] = -INFINITY; - } - float row_sum_rest = compute_sum_of_remaining_exp(thread_max); - normalize_top_k(renorm_value + row_sum_rest); - ``` - -- **Kernel fusion**: Combine operations to reduce launches (e.g., norm+RoPE+cache+quant) -- **Persistent kernels**: Keep kernels resident on GPU for repeated operations - -### 11. Numerical Precision -- **Explicit rounding**: Add rounding mode parameters for attention kernels -- **FP8 descale**: Apply descaling during computation to avoid separate kernel - -### 12. Kernel Selection and Dispatching -- **CSV tuning cache**: Cache optimal configs to eliminate repeated tuning - ```cpp - int get_algoIdx_from_csv(const std::string filename, ...) { - // Parse CSV and match (trans_a, trans_b, m, n, k, dtypes) - for each line: - if (all_params_match) return algo_index; - return -1; // Not found - } - ``` - -- **Shape-based dispatch**: Use heuristics for kernel selection - ```cpp - Kernel select_kernel(int M, int N, int K) { - if (M < 128) return gemm_small_m<...>; - else return gemm_large_m<...>; - } - ``` - -- **Workspace pre-calculation**: Compute exact size before allocation - ```cpp - int64_t ws_size = topkValue * (sizeof(T) + sizeof(IdxT)) * numRows; - auto workspace = allocate_device_memory(ws_size); - ``` - -## Quick Reference -- Kernel launch: `hipLaunchKernelGGL(kernel, dim3(grid), dim3(block), sharedMem, stream, args...)` -- Memory: `hipMalloc`, `hipMemcpy`, `hipFree` -- Sync: `__syncthreads()`, `hipDeviceSynchronize()` -- Atomics: `atomicAdd`, `atomicCAS` -- CK-Tile: `ck_tile::make_buffer_view(ptr, oob)` - -## Profiling -- Summary: `rocprof --stats program` -- Detailed: `rocprof --hip-trace --hsa-trace program` -- Metrics: `rocprof -i metrics.txt program` - -## Validation Checklist -- [ ] Coalesced loads/stores; bank conflicts minimized -- [ ] Vectorized I/O aligned and beneficial -- [ ] Occupancy >50%; no register spilling -- [ ] Shared memory: atomicAdd for sparse patterns (O(N) not O(N²)) -- [ ] Loops unrolled for small fixed trips -- [ ] `-ffast-math` enabled for FP kernels -- [ ] No unnecessary coherence flags -- [ ] gfx950: Using 16x16x32 MFMA - -## Performance Impact (Production-Validated) - -| Optimization | Use Case | Typical Impact | -|-------------|----------|----------------| -| `#pragma unroll` | Memory kernels | +3-5% | -| AtomicAdd sparse | MOE, sorting | +15-20%, O(N²)→O(N) | -| 32B vectors | Memory-bound | Better throughput | -| `-ffast-math` | Math-heavy | +5-10% FP | -| Top-K softmax | Gating | Reduce exp by 50-90% | -| 16x16x32 MFMA | Attention | 2x→1x calls | -| `__launch_bounds__` | Position encoding | Guaranteed occupancy | -| Multiple MPerBlock | GEMM stages | Fix register spill | -| Persistent kernels | Paged attention | -50-80% launch overhead | -| CSV cache | GEMM tuning | Eliminate repeat tuning | diff --git a/skills/kernel-exp-history/.federated.json b/skills/kernel-exp-history/.federated.json deleted file mode 100644 index ce47b3b..0000000 --- a/skills/kernel-exp-history/.federated.json +++ /dev/null @@ -1,9 +0,0 @@ -{ - "source": "amd-agi-apex", - "repo": "AMD-AGI/Apex", - "ref": "main", - "commit": "6ee40e7e6f2d03902cce503cddf873f2ab75f05c", - "path": "tools/skills/kernel-exp-history", - "license": "MIT", - "imported_at": "2026-05-28T21:37:40Z" -} diff --git a/skills/kernel-exp-history/SKILL.md b/skills/kernel-exp-history/SKILL.md deleted file mode 100644 index 6145858..0000000 --- a/skills/kernel-exp-history/SKILL.md +++ /dev/null @@ -1,163 +0,0 @@ ---- -name: kernel-exp-history -description: This skill should be used when optimizing kernels in this repo and needing to consult past optimization experiments, or when recording the current optimization iteration back into the kernel experiment database. ---- - -# Kernel Experiment History - -## Overview - -Use the local kernel experiment database to look up prior optimization attempts and record new results after an optimization iteration completes. - -## Workflow - -### 1) Find prior experiments for inspiration - -- Read `references/kernel_exp_dataclass.py` to understand the database helpers and schema. -- Start with `top_experiments(max_results=20)` to get a score-sorted list of high-impact experiments. -- If more context is needed, load full entries using `get_experiment(exp_id)` or `list_experiments()` and filter by `operator_sig`, `dtype_sig`, `env`, or `base_commit`. -- Summarize the most relevant patterns (block sizes, memory changes, profiling signals, etc.) before proposing new optimizations. - -#### Query Examples - -**Example 1: Find similar kernel optimizations** -```python -# Search for cache kernel optimizations -from kernel_exp_dataclass import list_experiments - -experiments = list_experiments() -cache_exps = [e for e in experiments if 'cache' in e.operator_sig.lower()] - -# Sort by score -cache_exps_sorted = sorted(cache_exps, key=lambda x: x.score, reverse=True) - -print("Top cache kernel optimizations:") -for exp in cache_exps_sorted[:5]: - print(f" {exp.score:.4f}x - {exp.change_summary}") -``` - -**Example 2: Find best unroll factor** -```python -# Compare different unroll factors -unroll_exps = [e for e in experiments if 'unroll' in e.change_summary.lower()] - -for exp in unroll_exps: - factor = 'unknown' - if 'unroll 4' in exp.detailed_description.lower(): - factor = '4' - elif 'unroll 8' in exp.detailed_description.lower(): - factor = '8' - print(f"Unroll {factor}: {exp.score:.4f}x - {exp.operator_sig[:50]}") -``` - -**Example 3: Learn from failures** -```python -# Find what NOT to do -failures = [e for e in experiments if e.score < 0.98 or e.is_buggy] - -print("Failed optimizations (learn from these!):") -for exp in failures: - print(f" ❌ {exp.change_summary}") - print(f" Why: {exp.detailed_description[:100]}...") -``` - -### 2) Record the current optimization iteration - -- After finishing the optimization iteration, write a concise summary of the changes and results. -- Populate all required fields on `KernelExperiment`, including: - - `change_summary`, `detailed_description`, `raw_result`, `score` - - `operator_sig`, `dtype_sig`, `env`, `base_commit`, `profiling_info` - - `is_buggy`, `error_message`, `status` - - `pid` if this iteration builds on a parent experiment (set manually) -- Call `create_experiment()` to append the entry to the database. - -#### Field-by-Field Best Practices - -**change_summary** (1 line, <80 chars): -- ✅ Good: "Applied #pragma unroll 4 to flash kernel - best result at +1.90%" -- ❌ Bad: "Made some changes to the kernel" -- Format: ` - ` or ` - ` - -**detailed_description** (multiple paragraphs): -Structure: -``` -**Approach**: [What you tried] -- Specific technical details -- Why you thought it would work - -**Result**: [What happened] -- Quantitative results -- Qualitative observations - -**Why it worked/failed**: [Root cause analysis] -- Technical explanation -- Compare to similar attempts - -**Key insight**: [Takeaway for future] -- What this taught you -- How to apply the lesson -``` - -**raw_result** (structured text): -``` -Iteration N Results - [SUCCESS/REGRESSION/CRASH]: - -**Overall**: X.XXXXx speedup = Y.YY% [IMPROVEMENT/REGRESSION] - -**Per-kernel breakdown**: -- kernel_1: X.XXXXx (+Y.YY%) -- kernel_2: X.XXXXx (+Y.YY%) -... - -**Summary**: X improvements, Y neutral, Z regressions - -**Key finding**: [One-line takeaway] -``` - -**profiling_info** (even if not profiled): -- If profiled: Include key metrics (occupancy, bandwidth, bottleneck type) -- If NOT profiled: Explain why not, and what benchmarks showed - -### 3) Update existing experiments (if needed) - -- If you discover errors in previous recordings (e.g., false regression due to testing issues): - - Use `update_experiment(exp_id, raw_result=..., score=..., detailed_description=...)` - - Update the score to reflect corrected performance - - Document the correction reason in detailed_description -- Common update scenarios: - - Test methodology errors discovered - - Performance re-measurement with better methodology - - Bug fixes affecting correctness - - -## Score Guidelines - -- Score = speedup ratio (e.g., 1.18 for 18% improvement) -- For regressions: score < 1.0 (e.g., 0.70 for 30% slower) -- Average across all tested configurations if performance varies - -## The Value of Recording Failures - -**Critical**: Record ALL iterations, especially failures! - -**Why record failures?** -1. 🚫 **Prevent repetition**: Future you won't try the same failed approach -2. 📚 **Build institutional knowledge**: Team learns what doesn't work -3. 🔍 **Pattern recognition**: Multiple failures reveal deeper issues -4. 💡 **Negative results are results**: "X doesn't work" is valuable information - -**Failure categories to track**: -- **Buggy** (`is_buggy=True`): Crashes, correctness errors -- **Regressive** (score < 1.0): Made things slower -- **Marginal** (0.99 < score < 1.01): No meaningful impact -- **Interference** (combined optimization worse than separate): Resource conflicts - -**Example from cache kernel optimization**: -- Iteration 1 (-1.81%): Disrupted coalescing → Learned: preserve memory patterns -- Iteration 5 (CRASH): Manual unrolling bug → Learned: use pragmas, not manual -- Iteration 7 (-0.63%): Combined optimizations → Learned: resource interference real - -## Notes - -- Use `top_experiments()` first; fall back to full queries only when additional details are needed. -- Keep summaries short but specific enough to guide future optimization decisions. diff --git a/skills/kernel-exp-history/references/kernel_exp_dataclass.py b/skills/kernel-exp-history/references/kernel_exp_dataclass.py deleted file mode 100644 index 2ee6973..0000000 --- a/skills/kernel-exp-history/references/kernel_exp_dataclass.py +++ /dev/null @@ -1,186 +0,0 @@ -# Copyright (c) 2025 Advanced Micro Devices, Inc. -# SPDX-License-Identifier: MIT -from __future__ import annotations - -import json -import os -import tempfile -import time -import fcntl -from contextlib import contextmanager -from dataclasses import dataclass, field, asdict -from datetime import datetime, timezone -from pathlib import Path -from typing import Dict, List, Literal, Optional -import uuid - -# Global JSON "database" location -DB_PATH = Path("kernel_experiments_db.json") - -Status = Literal["new", "running", "done", "failed", "timeout"] - - -def _uuid64() -> str: - """Generate a 64-bit hex uuid string.""" - return f"{uuid.uuid4().int & ((1 << 64) - 1):016x}" - - -@dataclass -class KernelExperiment: - score: float # avg speedup (1.0 = no speedup, 2.0 = 2x faster) - raw_result: str # per-shape speedups or notes - dtype_sig: str # fp16, bf16, fp32, bf8, etc. - env: str # GPU model, ROCm version, etc. - is_buggy: bool - error_message: str # error type + message when is_buggy is True - change_summary: str - detailed_description: str - code_change: str # diff patch string - base_commit: str # upstream commit id (not local) - operator_sig: str # which files/kernels are affected - profiling_info: str - status: Status - id: str = field(default_factory=_uuid64) - pid: str = field( - default="", - metadata={"comment": "Parent experiment id; set manually, do not auto-generate."}, - ) # Parent experiment id; set manually when linking lineage. - created_at: str = field( - default_factory=lambda: datetime.now(timezone.utc).isoformat().replace("+00:00", "Z") - ) - - def to_dict(self) -> Dict: - return asdict(self) - - @staticmethod - def from_dict(data: Dict) -> "KernelExperiment": - return KernelExperiment(**data) - - -def _ensure_db_exists() -> None: - DB_PATH.parent.mkdir(parents=True, exist_ok=True) - if not DB_PATH.exists(): - DB_PATH.write_text("{}", encoding="utf-8") - - -@contextmanager -def _locked_db(exclusive: bool): - mode = "a+" # ensure file exists and is open for locking - with DB_PATH.open(mode, encoding="utf-8") as f: - fcntl.flock(f.fileno(), fcntl.LOCK_EX if exclusive else fcntl.LOCK_SH) - try: - yield - finally: - fcntl.flock(f.fileno(), fcntl.LOCK_UN) - - -def _atomic_write_json(path: Path, content: Dict[str, Dict]) -> None: - with tempfile.NamedTemporaryFile( - "w", dir=path.parent, delete=False, encoding="utf-8" - ) as tmp: - json.dump(content, tmp, indent=2, sort_keys=True) - tmp.flush() - os.fsync(tmp.fileno()) - tmp_path = Path(tmp.name) - os.replace(tmp_path, path) - - -def _load_db() -> Dict[str, Dict]: - _ensure_db_exists() - with _locked_db(exclusive=False): - with DB_PATH.open("r", encoding="utf-8") as f: - return json.load(f) - - -def _save_db(db: Dict[str, Dict]) -> None: - with _locked_db(exclusive=True): - _atomic_write_json(DB_PATH, db) - - -def create_experiment(exp: KernelExperiment) -> None: - db = _load_db() - if exp.id in db: - raise ValueError(f"Experiment with id '{exp.id}' already exists") - db[exp.id] = exp.to_dict() - _save_db(db) - - -def get_experiment(exp_id: str) -> Optional[KernelExperiment]: - db = _load_db() - if exp_id not in db: - return None - return KernelExperiment.from_dict(db[exp_id]) - - -def list_experiments() -> List[KernelExperiment]: - db = _load_db() - return [KernelExperiment.from_dict(v) for v in db.values()] - - -def top_experiments(max_results: int = 20) -> List[Dict[str, object]]: - """ - Return experiments sorted by score desc, containing only key fields. - """ - experiments = list_experiments() - filtered = [exp for exp in experiments] - filtered.sort(key=lambda e: e.score, reverse=True) - top_n = filtered[: max(0, max_results)] - keys = [ - "base_commit", - "change_summary", - "detailed_description", - "dtype_sig", - "env", - "id", - "operator_sig", - "profiling_info", - "raw_result", - "score", - ] - return [{k: getattr(exp, k) for k in keys} for exp in top_n] - - -def update_experiment(exp_id: str, **changes) -> KernelExperiment: - db = _load_db() - if exp_id not in db: - raise KeyError(f"Experiment with Id '{exp_id}' not found") - current = db[exp_id] - current.update(changes) - db[exp_id] = current - _save_db(db) - return KernelExperiment.from_dict(current) - - -def delete_experiment(exp_id: str) -> None: - db = _load_db() - if exp_id not in db: - raise KeyError(f"Experiment with Id '{exp_id}' not found") - del db[exp_id] - _save_db(db) - - -def test_insert_example() -> KernelExperiment: - """Insert a sample experiment entry for quick sanity checks.""" - sample = KernelExperiment( - pid="(Parent experiment id)", - score=1.25, - raw_result="shape=128x128 speedup=1.3; shape=256x256 speedup=1.2", - dtype_sig="fp16", - env="MI300X, ROCm 7.0.0", - is_buggy=False, - error_message="", - change_summary="tuned block size and vectorized loads", - detailed_description="Adjusted kernel launch for better wave occupancy on MI300X.", - code_change="(diff patch here)", - base_commit="abcdef1234567890", - operator_sig="attention_ragged.cu: paged_attention_ll4mi", - profiling_info="SQ busy 75%, TCP 65%, TCC 55%", - status="new", - ) - create_experiment(sample) - return sample - - -if __name__ == "__main__": - exp = test_insert_example() - print(f"Inserted sample experiment with id: {exp.id}") diff --git a/skills/mi300-hip-programming-insights/.federated.json b/skills/mi300-hip-programming-insights/.federated.json deleted file mode 100644 index 2d0e469..0000000 --- a/skills/mi300-hip-programming-insights/.federated.json +++ /dev/null @@ -1,9 +0,0 @@ -{ - "source": "amd-agi-apex", - "repo": "AMD-AGI/Apex", - "ref": "main", - "commit": "6ee40e7e6f2d03902cce503cddf873f2ab75f05c", - "path": "tools/skills/mi300-hip-programming-insights", - "license": "MIT", - "imported_at": "2026-05-28T21:37:40Z" -} diff --git a/skills/mi300-hip-programming-insights/SKILL.md b/skills/mi300-hip-programming-insights/SKILL.md deleted file mode 100644 index 6d7e1e6..0000000 --- a/skills/mi300-hip-programming-insights/SKILL.md +++ /dev/null @@ -1,19 +0,0 @@ ---- -name: mi300-hip-programming-insights -description: CDNA3/MI300 HIP programming insights—chiplet/cache model, Infinity Cache, memory coherency, matrix cores, sparsity, and best practices. ---- - -# MI300 HIP Programming Insights - -Use when tuning HIP kernels with CDNA3 architectural context (chiplets, caches, matrix cores). - -Highlights: -- Memory hierarchy: 128B cache lines; leverage 256MB Infinity Cache (temporal locality); explicit sync across XCDs (relaxed coherency). -- Workgroups: size for 4 ACEs per XCD; balance across 38 CUs; exploit shared I-cache locality; LDS 64KB per CU. -- Matrix cores: align data; overlap matrix + vector + memory; choose FP8/TF32 for throughput vs precision; schedule for concurrency. -- Sparsity: 2:4 structured sparsity (INT8/FP8/FP16/BF16); weigh reordering overhead vs gains; good for attention/conv. -- Cross-platform: HIP differences vs CUDA—explicit fences, data-type fallbacks, platform-specific tuning. -- Debug/profiling: use ROCm tools to analyze cache misses, bandwidth, sync overhead; focus on memory-side cache behavior. - -References: -- `references/AMD MI300 HIP Kernel Programming Guide_ CDNA3 Architecture Insights.md` diff --git a/skills/mi300-hip-programming-insights/references/AMD MI300 HIP Kernel Programming Guide_ CDNA3 Architecture Insights.md b/skills/mi300-hip-programming-insights/references/AMD MI300 HIP Kernel Programming Guide_ CDNA3 Architecture Insights.md deleted file mode 100644 index 35ed08a..0000000 --- a/skills/mi300-hip-programming-insights/references/AMD MI300 HIP Kernel Programming Guide_ CDNA3 Architecture Insights.md +++ /dev/null @@ -1,331 +0,0 @@ -# AMD MI300 HIP Kernel Programming Guide: CDNA3 Architecture Insights - - -## Executive Summary - -The AMD CDNA3 architecture, embodied in the MI300 series accelerators, represents a paradigmatic shift in GPU design philosophy that fundamentally impacts how high-performance HIP kernels should be written and optimized. Unlike traditional monolithic GPU designs, CDNA3 embraces a heterogeneous chiplet architecture that introduces unique programming considerations, memory hierarchy optimizations, and performance characteristics that differ significantly from NVIDIA's AI accelerators. - -This guide synthesizes critical architectural insights from the AMD CDNA3 white paper to provide large language models and developers with the specialized knowledge necessary to generate high-quality HIP kernels optimized for MI300 hardware. The focus is on architectural features that are either unique to AMD or implemented differently from NVIDIA solutions, as general GPU programming concepts are assumed to be well-understood. - -The MI300 series introduces revolutionary concepts including memory-side caching through AMD Infinity Cache, 2:4 structured sparsity support, novel data types like TF32 and OCP-compliant FP8, and a relaxed memory coherency model that requires explicit synchronization. These features, combined with the chiplet-based design and enhanced matrix processing capabilities, create both opportunities and challenges for kernel optimization that are distinct from CUDA programming paradigms. - - - - -## 1. CDNA3 Architecture Overview: Chiplet-Based Design Implications - -The AMD CDNA3 architecture fundamentally departs from traditional monolithic GPU designs by implementing a heterogeneous chiplet approach that has profound implications for kernel programming and optimization strategies. Understanding this architectural foundation is crucial for writing efficient HIP kernels that can fully exploit the hardware capabilities. - -### 1.1 Heterogeneous Chiplet Organization - -The MI300 series processors are constructed using up to 8 Accelerator Complex Dies (XCDs) and 4 I/O Dies (IODs), each fabricated on different process nodes and optimized for specific functions. The XCDs, manufactured on TSMC's 5nm process, contain the computational elements and lower-level cache hierarchy, while the IODs, built on TSMC's 6nm process, house the memory controllers, AMD Infinity Cache, and system interconnects. This separation allows for specialized optimization of each component while enabling vertical 3D stacking through advanced packaging technologies. - -Each XCD contains exactly 40 Compute Units (CUs), with 38 active units and 2 disabled for yield management purposes. This yields a total of 304 active CUs across the full MI300X configuration, representing approximately 40% more computational resources than the previous generation MI250X. The consistent 38-CU configuration per XCD creates predictable resource allocation patterns that kernel developers can exploit for load balancing and work distribution strategies. - -The chiplet design introduces unique considerations for memory access patterns and inter-CU communication. Unlike monolithic designs where all CUs share uniform access to memory controllers, the CDNA3 architecture creates a hierarchical access pattern where CUs within the same XCD have lower latency access to the local L2 cache, while cross-XCD communication must traverse the AMD Infinity Fabric network. This architectural characteristic suggests that kernel designs should prioritize data locality within XCD boundaries when possible, and carefully consider the cost of cross-XCD data sharing. - -### 1.2 Asynchronous Compute Engine Architecture - -Each XCD incorporates 4 Asynchronous Compute Engines (ACEs) that serve as the primary work distribution mechanism for compute shader workgroups. Each ACE is nominally associated with 40 CUs, though the actual active count is 38 due to yield management. This 4-ACE configuration provides fine-grained control over work distribution and enables sophisticated load balancing strategies that can adapt to varying computational workloads. - -The ACE architecture differs significantly from NVIDIA's GigaThread Engine approach by providing multiple independent scheduling domains within each XCD. This design enables better isolation between concurrent kernels and can reduce scheduling overhead for workloads that can be effectively partitioned across the available ACEs. Kernel developers should consider designing workgroup distributions that align with the 4-ACE structure to minimize scheduling conflicts and maximize throughput. - -The hardware scheduler (HWS) coordinates work distribution across all ACEs and manages the hardware queues (HQDO-7) that feed work to the compute accelerators. Understanding this scheduling hierarchy is important for optimizing kernel launch patterns and minimizing dispatch overhead, particularly for workloads that involve frequent kernel launches or complex dependency chains. - -### 1.3 Compute Unit Internal Architecture - -The CDNA3 Compute Units represent a comprehensive redesign that doubles or quadruples performance per CU for vector and matrix workloads compared to the previous generation. Each CU functions as a complete, highly threaded parallel processor core that includes instruction fetching and scheduling, execution units for scalar, vector, and matrix operations, and load/store pipelines with integrated L1 cache and Local Data Share (LDS). - -A critical architectural innovation is the shared 64KB instruction cache between pairs of CUs, which doubles the capacity from the previous generation while maintaining nearly constant die area. This design exploits the common pattern where adjacent CUs execute identical instruction streams, effectively increasing the cacheable instruction window and improving hit rates. Kernel developers should be aware that instruction cache efficiency is maximized when neighboring CUs execute similar code paths, suggesting that workgroup assignment strategies should consider instruction locality alongside data locality. - -The enhanced source caching mechanism provides improved register reuse and bandwidth amplification, allowing each vector register read to support multiple downstream vector or matrix operations. This architectural feature rewards kernel designs that maximize register reuse and minimize redundant memory accesses, particularly for computationally intensive operations where the same data elements are used across multiple computational stages. - - -## 2. Memory Hierarchy and Caching Strategy: The Infinity Cache Revolution - -The CDNA3 memory hierarchy represents one of the most significant departures from conventional GPU memory systems and introduces programming considerations that are fundamentally different from NVIDIA architectures. Understanding these differences is crucial for optimizing memory access patterns and achieving peak performance in HIP kernels. - -### 2.1 Three-Tier Cache Hierarchy with Memory-Side Caching - -The CDNA3 architecture implements a unique three-tier cache hierarchy consisting of L1 vector data cache, L2 cache, and the revolutionary AMD Infinity Cache. This design differs markedly from traditional two-tier GPU cache hierarchies and introduces novel optimization opportunities that kernel developers must understand to achieve optimal performance. - -The L1 vector data cache has been substantially enhanced with a doubled cache line size of 128 bytes and doubled capacity to 32KB per CU. This larger cache line size is particularly beneficial for streaming workloads and vectorized operations that access contiguous memory regions. The increased line size also doubles the bandwidth between the L1 cache and the core, providing improved data delivery rates for bandwidth-intensive kernels. However, the larger cache lines also mean that memory access patterns with poor spatial locality may suffer from increased cache pollution, making careful attention to data layout and access patterns even more critical. - -The L2 cache serves as a 4MB, 16-way set-associative cache shared by all 38 CUs within an XCD. The L2 is organized into 16 parallel channels of 256KB each, enabling massive parallelism with the ability to sustain four requests from different CUs per cycle. This design provides a combined throughput of 2KB per clock per XCD, with aggregate read bandwidth across all XCDs reaching up to 34.4 TB/s. The L2 cache plays a critical role as the lowest level where hardware coherency is automatically maintained, making it the boundary between coherent and non-coherent memory operations. - -### 2.2 AMD Infinity Cache: Memory-Side Cache Innovation - -The AMD Infinity Cache represents a paradigm shift in GPU cache design, implementing a memory-side cache architecture that fundamentally differs from traditional cache hierarchies. Unlike conventional caches that can hold dirty data evicted from lower levels, the Infinity Cache is designed as a shared memory-side cache that exclusively caches the contents of memory and cannot hold dirty data. - -This design choice provides two significant advantages that impact kernel programming strategies. First, the Infinity Cache does not participate in coherency protocols and does not need to handle snoop traffic, which significantly improves efficiency and reduces latency for coherency operations from lower-level caches. Second, the cache can hold nominally uncacheable memory such as I/O buffers, providing performance benefits for kernels that work with mixed data types or perform I/O operations alongside computation. - -The Infinity Cache is organized around 128 parallel channels across 8 HBM stacks, with each channel being 64 bytes wide and connected to 2MB of data arrays. The total capacity of 256MB provides substantial caching capability, while the peak bandwidth of 17.2 TB/s approaches the aggregate bandwidth of previous generation L2 caches. This massive bandwidth makes the Infinity Cache particularly effective for workloads with good temporal locality but poor spatial locality, as it can efficiently serve repeated accesses to scattered memory locations. - -### 2.3 Relaxed Coherency Model and Synchronization Requirements - -A critical difference from NVIDIA architectures is the CDNA3's relaxed coherency model, which requires explicit synchronization to provide strong coherency and ordering guarantees. The L1 vector data cache operates with very relaxed coherency semantics, meaning that kernel developers must explicitly manage cache coherency through appropriate synchronization primitives and memory fence operations. - -This relaxed coherency model provides performance benefits by eliminating the overhead of automatic coherency maintenance, but it places additional responsibility on kernel developers to ensure correct memory ordering. Kernels that share data between workgroups or that require specific memory ordering semantics must use explicit synchronization operations such as memory fences, atomic operations, or barrier synchronization to ensure correctness. - -The coherency boundary at the L2 cache level means that operations within a single XCD can rely on hardware-maintained coherency, while operations that span multiple XCDs require explicit synchronization. This architectural characteristic suggests that kernel designs should minimize cross-XCD data sharing when possible, or carefully structure such sharing to use appropriate synchronization mechanisms. - -### 2.4 HBM3/HBM3E Memory Interface Optimization - -The CDNA3 architecture upgrades to HBM3 for MI300X and MI300A products, and HBM3E for MI325X, providing substantial memory capacity and bandwidth improvements. The MI300X provides 192GB of HBM3 memory with 5.3 TB/s peak bandwidth, while the MI325X offers 256GB of HBM3E with 6.0 TB/s peak bandwidth. These specifications represent significant improvements over previous generations and enable new classes of memory-intensive applications. - -The memory controllers are distributed across the IODs and operate at 5.2 Gbps for HBM3 and 6.0 Gbps for HBM3E. Each IOD manages two HBM stacks, creating a distributed memory architecture that can provide excellent bandwidth utilization when memory accesses are properly distributed across all stacks. Kernel developers should consider memory access patterns that can effectively utilize all available memory controllers to achieve peak bandwidth utilization. - -The channel-based organization extends from the L2 cache through the Infinity Cache to the HBM interface, with each HBM stack associated with 16 parallel channels. This consistent channel organization provides predictable performance characteristics and enables sophisticated memory access optimization strategies that can align data placement with the underlying hardware organization. - - -## 3. Matrix Core Technology and Advanced Data Type Support - -The CDNA3 Matrix Cores represent a substantial evolution in specialized compute capabilities, introducing new data types and computational paradigms that are specifically optimized for modern AI and machine learning workloads. Understanding these capabilities and their optimal usage patterns is essential for developing high-performance HIP kernels for AI applications. - -### 3.1 Enhanced Matrix Core Architecture - -The Matrix Cores in CDNA3 have been comprehensively redesigned to provide dramatic performance improvements across all supported data types. The architecture delivers generational improvements ranging from 1.7x for FP64 operations to 6.8x for INT8 operations compared to the previous CDNA2 generation. These improvements are achieved through a combination of increased parallelism, enhanced data path widths, and optimized instruction scheduling. - -Each Compute Unit contains integrated Matrix Core functionality that can execute matrix operations in parallel with vector operations, enabling sophisticated kernel designs that can overlap different types of computation. The Matrix Cores support a wide range of data types with varying throughput characteristics, allowing kernel developers to choose the optimal precision for their specific workload requirements while maximizing computational throughput. - -The peak theoretical performance for matrix operations reaches impressive levels: 163.4 TFLOP/s for FP32 matrix operations, 1,307.4 TFLOP/s for FP16/BF16 operations, and an extraordinary 2,614.9 TFLOP/s for FP8 operations on the MI300X. These performance levels represent substantial improvements over previous generations and enable new classes of computationally intensive applications that were previously impractical. - -### 3.2 Novel Data Type Support: TF32 and FP8 - -The CDNA3 architecture introduces support for two critical new data types that are becoming increasingly important in modern AI workloads: TF32 and FP8. These data types provide different trade-offs between precision, performance, and memory efficiency, enabling kernel developers to optimize for specific application requirements. - -TF32 is a 19-bit hybrid data format that combines the 10-bit mantissa precision of FP16 with the 8-bit exponent range of BF16, plus a sign bit. Despite its name suggesting a 32-bit format, TF32 is actually more compact while providing a precision and range combination that can effectively replace FP32 in most machine learning applications without accuracy degradation. The Matrix Cores provide full-rate support for TF32 operations at 1,024 FLOPS per clock per CU, offering a compelling balance between performance and precision for training workloads that require higher precision than FP16 but don't need full FP32 precision. - -FP8 support follows the OCP 8-bit Floating Point Specification, providing two variants optimized for different use cases. The E5M2 variant, with a 5-bit exponent and 2-bit mantissa, is optimized for training workloads where the extended range is more important than mantissa precision. The E4M3 variant, with a 4-bit exponent and 3-bit mantissa, is optimized for inference workloads where mantissa precision is more critical than extended range. The Matrix Cores can achieve 4,096 operations per clock per CU for FP8 operations, representing 16x the throughput of FP32 operations while using only 1/4 the memory bandwidth. - -### 3.3 Structured Sparsity Support and 2:4 Sparse Operations - -One of the most innovative features of the CDNA3 Matrix Cores is native support for structured sparsity, specifically the 2:4 sparse pattern where at least two values within every group of four input values are zero. This sparsity support is available for matrix operations using INT8, FP8, FP16, and BF16 data types, enabling up to double the computational throughput for workloads that can exploit this sparsity pattern. - -The sparse matrix support is implemented through a compact representation where non-zero data is stored in dense form with additional metadata tracking the locations of zero values. This approach allows the dense representation to fit directly into the Matrix Core pipeline while enabling the hardware to skip computations involving zero values. When the sparsity requirements are met, the Matrix Cores can achieve up to 8,000 operations per clock per CU, representing a substantial performance improvement for compatible workloads. - -The 2:4 sparsity pattern is particularly well-suited to many neural network architectures, especially attention mechanisms in transformer-based models and convolution-based networks. Kernel developers working with these types of models should consider whether their data can be structured to exploit this sparsity support, as the performance benefits can be substantial. However, it's important to note that the sparsity must be structured in the specific 2:4 pattern to be exploitable by the hardware. - -### 3.4 Matrix Core Programming Considerations - -Effective utilization of the Matrix Cores requires careful attention to data layout, operation scheduling, and memory access patterns. The Matrix Cores are designed to work most efficiently with data that is properly aligned and organized to match the hardware's internal data paths. Kernel developers should ensure that matrix data is laid out in memory with appropriate alignment and that matrix dimensions are chosen to maximize hardware utilization. - -The integration of Matrix Cores within the Compute Units enables sophisticated kernel designs that can overlap matrix operations with vector operations and memory accesses. This capability allows for the development of fused kernels that can perform complex operations without intermediate memory round-trips, potentially providing significant performance improvements for workloads that can exploit this parallelism. - -Memory bandwidth considerations are particularly important when working with the Matrix Cores, as the high computational throughput can quickly become memory-bound if data access patterns are not optimized. The enhanced cache hierarchy, including the Infinity Cache, can help mitigate memory bandwidth limitations for workloads with good temporal locality, but kernel developers must still carefully consider data reuse patterns and memory access optimization. - -### 3.5 Performance Optimization Strategies - -Achieving optimal performance with the Matrix Cores requires a holistic approach that considers data types, sparsity patterns, memory access patterns, and operation scheduling. Kernel developers should start by selecting the most appropriate data type for their precision requirements, considering the substantial performance benefits available with lower-precision formats when accuracy requirements permit. - -For workloads that can exploit sparsity, restructuring data to match the 2:4 sparse pattern can provide dramatic performance improvements. This may require preprocessing steps to identify and reorganize sparse data, but the computational benefits can justify this overhead for many applications. The sparse support is particularly valuable for inference workloads where the sparsity patterns can be determined offline and optimized for the specific hardware capabilities. - -Memory access optimization becomes even more critical when working with the high-throughput Matrix Cores. Kernel designs should prioritize data reuse, minimize memory round-trips, and structure memory accesses to take advantage of the cache hierarchy. The large cache line sizes and substantial cache capacities in CDNA3 can provide significant benefits for workloads that can maintain good spatial and temporal locality. - - -## 4. Key Differences from NVIDIA AI Accelerators - -Understanding the fundamental differences between AMD CDNA3 and NVIDIA AI accelerators is crucial for developers transitioning between platforms or optimizing kernels for cross-platform compatibility. These differences span architectural philosophy, memory systems, programming models, and performance characteristics. - -### 4.1 Architectural Philosophy: Chiplets vs. Monolithic Design - -The most fundamental difference between CDNA3 and NVIDIA architectures lies in the basic design philosophy. NVIDIA's H100 and A100 accelerators follow a monolithic die approach where all computational and memory control functions are integrated onto a single large die. This design provides uniform access patterns and simplified programming models but is limited by the maximum practical die size and manufacturing yield considerations. - -In contrast, CDNA3 embraces a heterogeneous chiplet architecture that separates computational functions (XCDs) from memory and I/O functions (IODs). This approach enables specialized optimization of each chiplet type and allows for more flexible scaling through the addition of more chiplets. However, it also introduces hierarchical access patterns and requires more sophisticated programming strategies to achieve optimal performance. - -The chiplet approach provides several advantages that impact kernel programming. The ability to disable individual CUs for yield management (2 per XCD) provides more predictable performance characteristics compared to monolithic designs where yield issues might affect larger functional blocks. The separation of compute and memory functions also enables independent optimization of each subsystem, potentially providing better performance for specific workload types. - -### 4.2 Memory Hierarchy Differences: Memory-Side Cache vs. Traditional Caching - -The memory hierarchy represents one of the most significant differences between CDNA3 and NVIDIA architectures. NVIDIA accelerators typically implement a traditional two-level cache hierarchy (L1 and L2) with write-through L1 caches and hardware-managed coherency. This approach provides predictable behavior and simplified programming models but may not be optimal for all workload types. - -CDNA3's three-tier hierarchy with the memory-side Infinity Cache introduces novel optimization opportunities that don't exist in NVIDIA architectures. The memory-side cache design means that the Infinity Cache can hold data that would be uncacheable in traditional architectures, such as I/O buffers or streaming data. This capability can provide significant performance benefits for kernels that work with mixed data types or perform complex memory access patterns. - -The relaxed coherency model in CDNA3 contrasts sharply with NVIDIA's hardware-managed coherency. While NVIDIA's approach simplifies programming by automatically maintaining cache coherency, it also introduces overhead that may not be necessary for all workloads. CDNA3's explicit synchronization requirements provide more control over coherency operations but require more sophisticated programming to ensure correctness. - -### 4.3 Compute Unit Organization and Scheduling Differences - -The organization of computational resources differs significantly between the two architectures. NVIDIA's Streaming Multiprocessors (SMs) typically contain 64-128 CUDA cores along with specialized Tensor Cores, with a single GigaThread Engine managing work distribution across all SMs. This centralized scheduling approach provides good load balancing but may introduce bottlenecks for certain workload types. - -CDNA3's approach with 4 Asynchronous Compute Engines per XCD provides more distributed scheduling and can offer better isolation between concurrent workloads. Each ACE manages a subset of the available CUs, enabling more fine-grained control over work distribution and potentially reducing scheduling overhead for workloads that can be effectively partitioned. - -The shared instruction cache between pairs of CUs in CDNA3 is another unique feature that doesn't have a direct equivalent in NVIDIA architectures. This design can provide significant benefits for workloads where adjacent CUs execute similar instruction streams, but it also requires careful consideration of workgroup assignment strategies to maximize cache efficiency. - -### 4.4 Data Type and Precision Support Variations - -While both architectures support a range of data types for AI workloads, there are important differences in implementation and performance characteristics. NVIDIA's Tensor Cores have evolved through multiple generations with different capabilities, and the specific data types and operations supported can vary significantly between different GPU models. - -CDNA3's support for TF32 as a native data type represents a unique approach to balancing precision and performance. While NVIDIA accelerators can perform TF32 operations, the implementation details and performance characteristics may differ. The OCP-compliant FP8 support in CDNA3 also follows industry standards that may not be directly compatible with NVIDIA's FP8 implementations. - -The structured sparsity support in CDNA3 follows the 2:4 pattern that is also supported by NVIDIA architectures, but the implementation details and performance characteristics can differ significantly. Kernel developers need to understand these differences to optimize sparsity exploitation for each platform. - -### 4.5 Programming Model and Software Stack Differences - -The programming model differences between HIP and CUDA represent both opportunities and challenges for kernel developers. HIP is designed to provide CUDA-like syntax while enabling cross-platform compatibility, but there are subtle differences in semantics and capabilities that can impact kernel performance and correctness. - -The ROCm software stack's open-source nature provides greater visibility into the underlying implementation compared to NVIDIA's closed-source approach. This transparency can enable more sophisticated optimization strategies but also requires developers to have a deeper understanding of the software stack internals. - -Memory management approaches also differ between the platforms. NVIDIA's Unified Memory system provides automatic data migration between CPU and GPU memory spaces, while AMD's approach typically requires more explicit memory management. The MI300A APU variant provides true unified memory that eliminates the need for data copies, but this capability is unique to the APU configuration. - -### 4.6 Virtualization and Multi-Tenancy Approaches - -The virtualization capabilities of CDNA3 and NVIDIA architectures follow different philosophies that impact how kernels can be deployed in multi-tenant environments. NVIDIA's Multi-Instance GPU (MIG) technology provides fixed partition sizes with strong isolation guarantees, but limited flexibility in partition configuration. - -CDNA3's spatial partitioning approach based on XCDs provides more flexible partition sizes and can be combined with NUMA memory partitioning for sophisticated resource allocation strategies. The SR-IOV support also provides hardware-level isolation that can be valuable for certain deployment scenarios. - -These virtualization differences can impact kernel design strategies, particularly for applications that need to run in multi-tenant environments or that require specific resource allocation patterns. Understanding the capabilities and limitations of each approach is important for developing kernels that can effectively utilize the available hardware resources. - -### 4.7 Interconnect and Scaling Characteristics - -The interconnect technologies used for multi-GPU scaling also differ between the platforms. NVIDIA's NVLink technology has evolved through multiple generations with varying bandwidth and topology capabilities, while AMD's Infinity Fabric provides a different approach to inter-GPU communication. - -The fully connected 8-GPU topologies enabled by CDNA3's Infinity Fabric can provide advantages for certain communication patterns, particularly all-reduce and all-gather operations that are common in distributed machine learning workloads. However, the specific performance characteristics and optimal usage patterns can differ from NVIDIA's NVLink-based solutions. - -Understanding these interconnect differences is crucial for developing kernels that will be used in multi-GPU configurations, as the optimal communication strategies and data distribution patterns can vary significantly between platforms. - - -## 5. HIP Kernel Programming Best Practices for CDNA3 - -Developing high-performance HIP kernels for CDNA3 requires understanding the unique architectural characteristics and optimizing for the specific capabilities and constraints of the platform. This section provides concrete guidance for kernel developers to achieve optimal performance on MI300 hardware. - -### 5.1 Memory Access Pattern Optimization - -The CDNA3 memory hierarchy with its three-tier cache system and relaxed coherency model requires careful attention to memory access patterns. The doubled cache line size of 128 bytes means that kernels should be designed to maximize spatial locality within these larger cache lines. Sequential memory accesses that can fill entire cache lines will achieve better bandwidth utilization than scattered access patterns. - -The memory-side Infinity Cache provides unique optimization opportunities that don't exist in traditional GPU architectures. Kernels that can maintain good temporal locality across large working sets can benefit significantly from the 256MB cache capacity and 17.2 TB/s bandwidth. This is particularly valuable for iterative algorithms or kernels that process the same data multiple times with different operations. - -The relaxed coherency model requires explicit synchronization for cross-workgroup communication or when specific memory ordering is required. Kernel developers should use appropriate memory fence operations, atomic operations, or barrier synchronization to ensure correctness. The coherency boundary at the L2 cache level means that operations within a single XCD can rely on hardware coherency, while cross-XCD operations require explicit synchronization. - -### 5.2 Workgroup and Thread Block Organization - -The 4-ACE architecture within each XCD suggests that workgroup organization should consider the scheduling hierarchy to minimize conflicts and maximize throughput. Workgroups should be sized and distributed to enable effective utilization of all available ACEs while maintaining good load balance across the 38 active CUs per XCD. - -The shared instruction cache between pairs of CUs rewards kernel designs where adjacent CUs execute similar instruction streams. This suggests that workgroup assignment strategies should consider instruction locality alongside data locality. Kernels with divergent control flow should be structured to minimize the impact on instruction cache efficiency. - -The Local Data Share (LDS) remains at 64KB per CU, consistent with previous generations. Effective utilization of LDS for data sharing between threads within a workgroup can reduce memory traffic and improve performance. The enhanced L1 cache capacity and bandwidth can also reduce the pressure on LDS for certain access patterns. - -### 5.3 Matrix Core Utilization Strategies - -Achieving optimal performance with the Matrix Cores requires careful attention to data layout, operation scheduling, and precision selection. Matrix data should be organized in memory with appropriate alignment to match the hardware's internal data paths. The specific alignment requirements may vary depending on the data type and operation being performed. - -The integration of Matrix Cores within the Compute Units enables sophisticated kernel designs that can overlap matrix operations with vector operations and memory accesses. Kernels should be structured to take advantage of this parallelism by organizing computations to minimize dependencies and enable concurrent execution of different operation types. - -Data type selection can have dramatic performance implications. FP8 operations can achieve 16x the throughput of FP32 operations while using only 1/4 the memory bandwidth. TF32 provides a good balance between precision and performance for many applications. Kernel developers should carefully evaluate their precision requirements and select the most appropriate data type to maximize performance. - -### 5.4 Sparsity Exploitation Techniques - -The 2:4 structured sparsity support in the Matrix Cores can provide up to 2x performance improvements for compatible workloads. However, exploiting this capability requires that data be structured in the specific 2:4 pattern where at least two values in every group of four are zero. This may require preprocessing steps to identify and reorganize sparse data. - -Kernels that work with naturally sparse data, such as attention mechanisms in transformer models or certain types of convolution operations, should be evaluated for sparsity exploitation potential. The performance benefits can be substantial, but the overhead of data reorganization must be considered in the overall performance analysis. - -The sparse support is available for INT8, FP8, FP16, and BF16 data types, providing flexibility in precision selection while maintaining sparsity benefits. Kernel developers should consider whether lower precision formats can be used to enable both sparsity and precision optimizations simultaneously. - -### 5.5 Cross-Platform Compatibility Considerations - -When developing kernels that need to run on both AMD and NVIDIA platforms, careful attention to programming model differences is essential. While HIP provides CUDA-like syntax, there are semantic differences that can impact performance and correctness. Memory management approaches, synchronization semantics, and performance characteristics can all differ between platforms. - -The relaxed coherency model in CDNA3 may require additional synchronization compared to NVIDIA platforms with hardware-managed coherency. Kernels should be designed with explicit synchronization that ensures correctness on both platforms, even if some synchronization operations may be redundant on certain platforms. - -Data type support and performance characteristics can vary significantly between platforms. Kernels should be designed with fallback strategies for data types or features that may not be available on all target platforms. Performance tuning may need to be platform-specific to achieve optimal results on each architecture. - -### 5.6 Debugging and Profiling Strategies - -The ROCm software stack provides comprehensive debugging and profiling tools that can help identify performance bottlenecks and correctness issues. The open-source nature of the stack provides greater visibility into the underlying implementation compared to closed-source alternatives, enabling more sophisticated debugging strategies. - -Memory access pattern analysis is particularly important for CDNA3 kernels due to the complex cache hierarchy and relaxed coherency model. Profiling tools can help identify cache miss patterns, memory bandwidth utilization, and synchronization overhead that may not be apparent from source code analysis alone. - -The chiplet architecture can introduce performance variations that may not be present in monolithic designs. Profiling should consider the distribution of work across XCDs and the impact of cross-XCD communication on overall performance. Load balancing strategies may need to be adjusted based on profiling results to achieve optimal performance. - -### 5.7 Performance Tuning and Optimization Workflow - -Developing high-performance CDNA3 kernels requires an iterative optimization workflow that considers the unique architectural characteristics. Initial kernel development should focus on correctness and basic functionality, followed by systematic optimization of memory access patterns, compute utilization, and synchronization overhead. - -Memory hierarchy optimization should be prioritized early in the development process, as the three-tier cache system can have significant impact on performance. Cache-friendly data layouts and access patterns should be established before focusing on computational optimizations. - -Matrix Core utilization should be evaluated for any kernels that perform matrix or tensor operations. The substantial performance benefits available through optimal Matrix Core usage can justify significant restructuring of computational algorithms to take advantage of these capabilities. - -The iterative nature of performance optimization means that profiling and measurement should be integrated throughout the development process. Performance characteristics can change significantly as kernels are optimized, and continuous measurement ensures that optimizations are providing the expected benefits. - - -## 6. Technical Specifications and Performance Characteristics - -### 6.1 MI300 Series Specifications Comparison - -| Specification | MI300A APU | MI300X GPU | MI325X GPU | -|---------------|------------|------------|------------| -| **Architecture** | AMD CDNA 3 | AMD CDNA 3 | AMD CDNA 3 | -| **Accelerator Complex Dies (XCD)** | 6 | 8 | 8 | -| **Active Compute Units** | 228 | 304 | 304 | -| **Stream Processors** | 14,592 | 19,456 | 19,456 | -| **Matrix Cores** | 912 | 1,216 | 1,216 | -| **Max Engine Clock** | 2,100 MHz | 2,100 MHz | 2,100 MHz | -| **CPU Cores (Zen 4)** | 24 | N/A | N/A | -| **Memory Capacity** | 128GB HBM3 | 192GB HBM3 | 256GB HBM3E | -| **Memory Bandwidth** | 5.3 TB/s | 5.3 TB/s | 6.0 TB/s | -| **Memory Interface** | 1024-bit x 8 | 1024-bit x 8 | 1024-bit x 8 | -| **L1 Cache per CU** | 32KB | 32KB | 32KB | -| **L2 Cache per XCD** | 4MB | 4MB | 4MB | -| **Infinity Cache Total** | 256MB | 256MB | 256MB | - -### 6.2 Matrix Core Performance Characteristics - -| Data Type | Operations per Clock per CU | MI300X Peak Performance | MI325X Peak Performance | Generational Improvement | -|-----------|----------------------------|------------------------|------------------------|-------------------------| -| **FP64 Matrix** | 256 | 163.4 TFLOP/s | 163.4 TFLOP/s | 1.7x | -| **FP32 Matrix** | 256 | 163.4 TFLOP/s | 163.4 TFLOP/s | 1.7x | -| **TF32 Matrix** | 1,024 | 653.7 TFLOP/s | 653.7 TFLOP/s | New | -| **FP16 Matrix** | 2,048 | 1,307.4 TFLOP/s | 1,307.4 TFLOP/s | 3.4x | -| **BF16 Matrix** | 2,048 | 1,307.4 TFLOP/s | 1,307.4 TFLOP/s | 3.4x | -| **FP8 Matrix** | 4,096 | 2,614.9 TFLOP/s | 2,614.9 TFLOP/s | New | -| **INT8 Matrix** | 4,096 | 2,614.9 TOPs | 2,614.9 TOPs | 6.8x | -| **Sparse (2:4) Performance** | Up to 8,192 | Up to 5,229.8 TFLOP/s | Up to 5,229.8 TFLOP/s | 2x with sparsity | - -### 6.3 Memory Hierarchy Performance Characteristics - -| Memory Level | Capacity | Bandwidth | Latency Characteristics | Key Features | -|--------------|----------|-----------|------------------------|--------------| -| **L1 Vector Cache** | 32KB per CU | 2KB/clock per CU | Lowest latency | 128-byte cache lines, relaxed coherency | -| **L2 Cache** | 4MB per XCD | 2KB/clock per XCD | Low latency | 16-way associative, coherency boundary | -| **Infinity Cache** | 256MB total | 17.2 TB/s aggregate | Medium latency | Memory-side cache, no dirty data | -| **HBM3/HBM3E** | 192-256GB | 5.3-6.0 TB/s | Highest latency | 8 stacks, 128 channels total | - -## 7. Conclusion and Future Considerations - -The AMD CDNA3 architecture represents a fundamental shift in GPU design philosophy that introduces both opportunities and challenges for HIP kernel developers. The heterogeneous chiplet approach, revolutionary memory hierarchy with Infinity Cache, and advanced Matrix Core capabilities provide substantial performance potential for applications that can effectively exploit these architectural innovations. - -### 7.1 Key Takeaways for Kernel Developers - -The most critical insight for kernel developers is that CDNA3 requires a different optimization mindset compared to traditional GPU architectures. The memory-side Infinity Cache, relaxed coherency model, and chiplet-based organization create optimization opportunities that don't exist in monolithic designs, but they also require more sophisticated programming strategies to achieve optimal performance. - -The Matrix Core enhancements, particularly the support for TF32 and FP8 data types along with structured sparsity, provide dramatic performance improvements for AI workloads. However, achieving these benefits requires careful attention to data layout, precision selection, and sparsity structuring that may require significant algorithmic modifications. - -The three-tier cache hierarchy with its unique characteristics demands careful consideration of memory access patterns and explicit synchronization strategies. Kernel developers must understand the coherency boundaries and design their algorithms to work effectively within the relaxed coherency model while taking advantage of the substantial cache bandwidth and capacity. - -### 7.2 Architectural Advantages and Unique Capabilities - -The CDNA3 architecture provides several unique advantages that distinguish it from competing solutions. The memory-side Infinity Cache design enables caching of data types that would be uncacheable in traditional architectures, potentially providing performance benefits for complex workloads with mixed data types. The chiplet approach enables more flexible scaling and specialized optimization of different functional units. - -The unified memory capability in the MI300A APU represents a particularly compelling advantage for certain workload types, eliminating the overhead of host-device data transfers and enabling new programming paradigms that can exploit true CPU-GPU memory sharing. This capability is unique in the current market and provides opportunities for innovative algorithm designs. - -The open-source ROCm software stack provides transparency and customization opportunities that are not available with closed-source alternatives. This openness enables more sophisticated optimization strategies and provides developers with greater control over the software stack behavior. - -### 7.3 Challenges and Considerations - -The complexity of the CDNA3 architecture also introduces challenges that kernel developers must navigate. The relaxed coherency model requires more explicit synchronization management, which can increase development complexity and the potential for subtle correctness issues. The chiplet-based design creates hierarchical access patterns that must be understood and optimized for optimal performance. - -Cross-platform compatibility considerations become more complex when targeting both AMD and NVIDIA platforms, as the architectural differences require platform-specific optimization strategies. Kernel developers must balance the benefits of platform-specific optimizations against the complexity of maintaining multiple code paths. - -### 7.4 Future Evolution and Ecosystem Development - -The CDNA3 architecture represents a significant step forward in GPU design, but it also establishes a foundation for future evolution. The chiplet approach provides a scalable framework for adding new capabilities and increasing computational resources in future generations. The software ecosystem around ROCm and HIP continues to mature, providing increasingly sophisticated tools and libraries for kernel development. - -The industry trend toward lower precision data types and structured sparsity is well-supported by CDNA3's capabilities, positioning it well for future AI workload evolution. The architectural innovations in memory hierarchy and compute organization provide a foundation for continued performance improvements as manufacturing processes and packaging technologies advance. - -Understanding and effectively utilizing the CDNA3 architecture requires a comprehensive approach that considers the unique architectural characteristics, programming model differences, and optimization opportunities. Kernel developers who invest in understanding these aspects will be well-positioned to achieve exceptional performance on MI300 hardware and contribute to the continued evolution of the AMD GPU computing ecosystem. - -The architectural innovations in CDNA3 represent more than incremental improvements; they constitute a new paradigm for GPU design that will likely influence future developments across the industry. Kernel developers who master these concepts will be prepared not only for current MI300 optimization but also for the continued evolution of heterogeneous computing architectures. - ---- - -*This guide represents a comprehensive analysis of the AMD CDNA3 architecture based on official documentation and technical specifications. Kernel developers should consult the latest ROCm documentation and AMD developer resources for the most current programming guidelines and optimization recommendations.* - diff --git a/skills/pytorch-kernel-optimization/.federated.json b/skills/pytorch-kernel-optimization/.federated.json deleted file mode 100644 index 37cbf6f..0000000 --- a/skills/pytorch-kernel-optimization/.federated.json +++ /dev/null @@ -1,9 +0,0 @@ -{ - "source": "amd-agi-apex", - "repo": "AMD-AGI/Apex", - "ref": "main", - "commit": "6ee40e7e6f2d03902cce503cddf873f2ab75f05c", - "path": "tools/skills/pytorch-kernel-optimization", - "license": "MIT", - "imported_at": "2026-05-28T21:37:40Z" -} diff --git a/skills/pytorch-kernel-optimization/SKILL.md b/skills/pytorch-kernel-optimization/SKILL.md deleted file mode 100644 index 980d5d9..0000000 --- a/skills/pytorch-kernel-optimization/SKILL.md +++ /dev/null @@ -1,39 +0,0 @@ ---- -name: pytorch-kernel-optimization -description: This skill should be used when optimizing PyTorch models and kernels, including efficient tensor operations, torch.compile, custom autograd/CUDA/Triton extensions, mixed precision, memory and data pipeline tuning, model optimization techniques, CUDA graphs, and profiling. ---- - -# PyTorch Kernel Optimization - -## Purpose -- Equip PyTorch workflows with concrete optimization patterns from high-level APIs to custom kernels. -- Provide practical snippets for compilation, extensions, mixed precision, memory efficiency, and profiling. - -## When to Use -- Tuning PyTorch models for throughput/latency on GPU. -- Deciding between compiler-level optimizations and custom kernels (C++/CUDA/Triton). -- Profiling and addressing bottlenecks in compute or input pipelines. - -## How to Use -- **Efficient tensor ops**: favor contiguous layouts (`.contiguous()` when needed); use `channels_last` for convs; replace Python loops with vectorized ops; prefer in-place ops (`add_`, `mul_`, `out=`) when autograd-safe. -- **torch.compile**: wrap functions or models with `@torch.compile`; choose modes: - - `"default"` balanced, `"reduce-overhead"` for small batches/CUDA graphs, `"max-autotune"` for peak perf, `"max-autotune-no-cudagraphs"` when graphs undesirable. - - Use `fullgraph=True` for whole-graph capture; set `dynamic=False` when shapes are static. -- **Custom autograd**: implement `torch.autograd.Function` saving minimal tensors; recompute in backward when memory-bound (e.g., checkpointed attention); use custom backward formulas for fused ops (e.g., SiLU). -- **CUDA extensions**: build with `CUDAExtension` (`-O3`, `--use_fast_math`, `-arch=sm_80`); enforce input checks in C++ bindings; expose kernels via `PYBIND11_MODULE`. -- **Mixed precision**: train with `torch.cuda.amp` + `GradScaler`; mix dtypes per op if needed; leverage `bfloat16` when supported. -- **Memory optimization**: apply gradient checkpointing (`checkpoint`, `checkpoint_sequential`); use memory-efficient attention via `scaled_dot_product_attention`; consider activation offloading (CPU swap) when memory-bound. -- **Data loading**: configure `DataLoader` with `num_workers`, `pin_memory`, `prefetch_factor`, `persistent_workers`, `drop_last`; implement fast collate; prefetch to GPU with custom loader using streams and non-blocking copies. -- **Model optimization**: fuse Conv+BN (`fuse_conv_bn`), apply quantization (`quant.fuse_modules`, `prepare`, `convert`), prune weights via `torch.nn.utils.prune`; ensure evaluation mode during quantization calibration. -- **CUDA graphs**: capture steady workloads via `torch.cuda.CUDAGraph`; warm up then capture forward/backward; reuse static input/output buffers; note `torch.compile(mode=\"reduce-overhead\")` can leverage graphs automatically. -- **Profiling**: - - Use `torch.profiler.profile` with CPU/CUDA activities, schedules, and `tensorboard_trace_handler`; enable `record_shapes`, `profile_memory`, `with_stack`. - - Review `prof.key_averages().table(sort_by=\"cuda_time_total\")`; iterate on hotspots. - -## Validation Checklist -- Tensor layouts contiguous/channels_last as appropriate; Python loops eliminated; in-place ops safe for autograd. -- `torch.compile` mode chosen for workload; warmup complete; performance measured post-compilation. -- Custom ops (autograd or CUDA) validate device/contiguity; register usage and block sizes tuned for kernels. -- AMP scaling stable (no inf/nan); dtype choices align with numerical sensitivity. -- Data loader keeps GPU fed (no data starvation); streams overlap transfers where applicable. -- Profiling reviewed after each major change; bottlenecks addressed or noted. diff --git a/skills/rocprof-compute/.federated.json b/skills/rocprof-compute/.federated.json deleted file mode 100644 index 135a980..0000000 --- a/skills/rocprof-compute/.federated.json +++ /dev/null @@ -1,9 +0,0 @@ -{ - "source": "amd-agi-apex", - "repo": "AMD-AGI/Apex", - "ref": "main", - "commit": "6ee40e7e6f2d03902cce503cddf873f2ab75f05c", - "path": "tools/skills/rocprof-compute", - "license": "MIT", - "imported_at": "2026-05-28T21:37:40Z" -} diff --git a/skills/rocprof-compute/SKILL.md b/skills/rocprof-compute/SKILL.md deleted file mode 100644 index 489a283..0000000 --- a/skills/rocprof-compute/SKILL.md +++ /dev/null @@ -1,39 +0,0 @@ ---- -name: rocprof-compute -description: This skill should be used when profiling AMD GPU kernels with rocprof-compute to collect metrics, roofline data, and analyze bottlenecks for HIP kernels. ---- - -# rocprof-compute Profiling - -## Purpose -- Capture AMD GPU kernel metrics, roofline data, and traces with rocprof-compute. -- Analyze collected workloads to identify bottlenecks (compute vs memory, cache/TCP/TCC/SQ utilization). - -## When to Use -- Need kernel-level performance diagnostics on AMD GPUs (MI200/MI300 family). -- Comparing different kernel implementations or launch configs. -- Triaging stalls/low occupancy indicated by runtime benchmarks. - -## How to Use -- Activate project venv and install rocprof-compute Python deps (once per environment): - - `source .venv/bin/activate` - - `python -m pip install -r /opt/rocm-7.0.0/libexec/rocprofiler-compute/requirements.txt` -- Profile a workload: - - `source .venv/bin/activate && rocprof-compute profile -n --path --join-type kernel -b SQ -b TCP -b TCC -- ` - - Example (paged attention ragged test): - `rocprof-compute profile -n kernelgen --path rocprof_compute_profile --no-roof --join-type kernel -b SQ -b TCP -b TCC -- .venv/bin/python -O op_tests/test_pa_ragged.py -p Shomy -q none -c 128` - - Prefer `--join-type kernel` for comparing same kernel across grids; switch to `grid` if grid-sensitive. - - Add/adjust `-b` blocks to target specific hardware units; use `--list-metrics ` if unsure. - - Use `--no-roof` to skip roofline if only counters are needed; remove it to gather roofline data. -- Analyze collected data: - - `rocprof-compute analyze --path -b ` or `--list-stats` / `--list-metrics ` to discover ids. - - Example: `rocprof-compute analyze --path rocprof_compute_profile -b 2` - - For interactive review, use `--gui` (default port 8050 or `--random-port`) or `--tui`. -- Typical workflow checklist: - - Pick a short, reproducible workload and seed; pin `--name` + `--path` per experiment. - - Collect counters (SQ/TCP/TCC) and optionally roofline in one run; avoid mixing many kernels in a single profile when isolating a hotspot. - - After analyze, inspect top stats, occupancy, LDS/HBM bandwidth, and hotspot kernels; rerun with filtered `--kernel` or `--dispatch` if needed. - -## References -- Load `references/rocprof_compute_profile_help.txt` for full `rocprof-compute profile --help`. -- Load `references/rocprof_compute_analyze_help.txt` for full `rocprof-compute analyze --help`. diff --git a/skills/rocprof-compute/references/rocprof_compute_analyze_help.txt b/skills/rocprof-compute/references/rocprof_compute_analyze_help.txt deleted file mode 100644 index c748c29..0000000 --- a/skills/rocprof-compute/references/rocprof_compute_analyze_help.txt +++ /dev/null @@ -1,79 +0,0 @@ -rocprof-compute analyze --help -usage: -rocprof-compute analyze --path [analyze options] - ------------------------------------------------------------------------------------ -Examples: - rocprof-compute analyze -p workloads/vcopy/mi200/ --list-metrics gfx90a - rocprof-compute analyze -p workloads/mixbench/mi200/ --dispatch 12 34 --decimal 3 - rocprof-compute analyze -p workloads/mixbench/mi200/ --gui ------------------------------------------------------------------------------------ - - -Help: - -h, --help show this help message and exit - -General Options: - -v, --version show program's version number and exit - -V, --verbose Increase output verbosity (use multiple times for higher levels) - -q, --quiet Reduce output and run quietly. - -s, --specs Print system specs and exit. - -Analyze Options: - -p [ ...], --path [ ...] Specify the raw data root dirs or desired results directory. - --list-stats List all detected kernels and kernel dispatches. - --list-metrics List all available metrics for analysis on specified arch: - gfx908 - gfx90a - gfx940 - gfx941 - gfx942 - gfx950 - -k [ ...], --kernel [ ...] Specify kernel id(s) from --list-stats for filtering. - -d [ ...], --dispatch [ ...] Specify dispatch id(s) for filtering. - -b [ ...], --block [ ...] Specify metric id(s) from --list-metrics for filtering. - --gpu-id [ ...] Specify GPU id(s) for filtering. - --spatial-multiplexing Mode of spatial multiplexing. - -o , --output Specify an output file to save analysis results. - --gui [GUI] Activate a GUI to interate with rocprofiler-compute metrics. - Optionally, specify port to launch application (DEFAULT: 8050) - --tui Activate a Textual User Interface (TUI) to interact with rocprofiler-compute metrics. - -R [ ...], --roofline-data-type [ ...] - Choose datatypes to view roofline PDFs for: (DEFAULT: FP32) - FP4 - FP6 - FP8 - FP16 - BF16 - FP32 - FP64 - I8 - I32 - I64 - - --pc-sampling-sorting-type Set the sorting type of pc sampling: offset or count (DEFAULT: offset). - -Advanced Options: - --random-port Randomly generate a port to launch GUI application. - Registered Ports range inclusive (1024-49151). - --max-stat-num Specify the maximum number of stats shown in "Top Stats" tables (DEFAULT: 10) - -n , --normal-unit Specify the normalization unit: (DEFAULT: per_kernel) - per_wave - per_cycle - per_second - per_kernel - -t , --time-unit Specify display time unit in kernel top stats: (DEFAULT: ns) - s - ms - us - ns - --decimal Specify desired decimal precision of analysis results. (DEFAULT: 2) - --config-dir Specify the directory of customized configs. - --save-dfs Specify the dirctory to save analysis dataframe csv files. - --cols [ ...] Specify column indices to display. - -g Debug single metric. - --dependency List the installation dependency. - --kernel-verbose Specify Kernel Name verbose level 1-5. Lower the level, shorter the kernel name. (DEFAULT: 5) (DISABLE: 5) - --specs-correction Specify the specs to correct. e.g. --specs-correction='specname1:specvalue1,specname2:specvalue2' - --list-nodes Multi-node option: list all node names. - --nodes [ ...] Multi-node option: filter with node names. Enable it without node names means ALL. diff --git a/skills/rocprof-compute/references/rocprof_compute_profile_help.txt b/skills/rocprof-compute/references/rocprof_compute_profile_help.txt deleted file mode 100644 index f7e55b5..0000000 --- a/skills/rocprof-compute/references/rocprof_compute_profile_help.txt +++ /dev/null @@ -1,93 +0,0 @@ -rocprof-compute profile --help -usage: - -rocprof-compute profile --name [profile options] [roofline options] -- - ---------------------------------------------------------------------------------- -Examples: - rocprof-compute profile -n vcopy_all -- ./vcopy -n 1048576 -b 256 - rocprof-compute profile -n vcopy_SPI_TCC -b SQ TCC -- ./vcopy -n 1048576 -b 256 - rocprof-compute profile -n vcopy_kernel -k vecCopy -- ./vcopy -n 1048576 -b 256 - rocprof-compute profile -n vcopy_disp -d 0 -- ./vcopy -n 1048576 -b 256 - rocprof-compute profile -n vcopy_roof --roof-only -- ./vcopy -n 1048576 -b 256 ---------------------------------------------------------------------------------- - - -Help: - -h, --help show this help message and exit - -General Options: - -v, --version show program's version number and exit - -V, --verbose Increase output verbosity (use multiple times for higher levels) - -q, --quiet Reduce output and run quietly. - -s, --specs Print system specs and exit. - -Profile Options: - -n , --name Assign a name to workload. - -p , --path Specify path to save workload. - (DEFAULT: /root/aiter/workloads/) - --subpath Specify the type of subpath to save workload: node_name, gpu_model. - --hip-trace HIP trace, execturion trace for the entire application at the HIP level. - -k [ ...], --kernel [ ...] Kernel filtering. - -d [ ...], --dispatch [ ...] Dispatch ID filtering. - -b [ ...], --block [ ...] Specify metric id(s) from --list-metrics for filtering (e.g. 10, 4, 4.3). - Can provide multiple space separated arguments. - Can also accept Hardware blocks. - Hardware block filtering (to be deprecated soon): - SQ - SQC - TA - TD - TCP - TCC - SPI - CPC - CPF - --list-metrics [] List all available metrics for analysis on specified arch: - gfx908 - gfx90a - gfx940 - gfx941 - gfx942 - gfx950 - --config-dir Specify the directory of customized report section configs. - --join-type Choose how to join rocprof runs: (DEFAULT: grid) - kernel (i.e. By unique kernel name dispatches) - grid (i.e. By unique kernel name + grid size dispatches) - --no-roof Profile without collecting roofline data. - -- [ ...] Provide command for profiling after double dash. - --spatial-multiplexing [ ...] Provide Node ID and GPU number per node. - --format-rocprof-output Set the format of output file of rocprof. - --pc-sampling-method Set the method of pc sampling, stochastic or host_trap. Support stochastic only >= MI300 - --pc-sampling-interval Set the interval of pc sampling. - For stochastic sampling, the interval is in cycles. - For host_trap sampling, the interval is in microsecond (DEFAULT: 1048576). - --rocprofiler-sdk-library-path ROCPROFILER_SDK_LIBRARY_PATH - Set the path to rocprofiler SDK library. - -Standalone Roofline Options: - --roof-only Profile roofline data only. - --sort Overlay top kernels or top dispatches: (DEFAULT: kernels) - kernels - dispatches - -m [ ...], --mem-level [ ...] Filter by memory level: (DEFAULT: ALL) - HBM - L2 - vL1D - LDS - --device Target GPU device ID. (DEFAULT: ALL) - --kernel-names Include kernel names in roofline plot. - -R [ ...], --roofline-data-type [ ...] - Choose datatypes to view roofline PDFs for: (DEFAULT: FP32) - FP4 - FP6 - FP8 - FP16 - BF16 - FP32 - FP64 - I8 - I32 - I64 - - diff --git a/skills/triton-hip-reference-kernel-search/.federated.json b/skills/triton-hip-reference-kernel-search/.federated.json deleted file mode 100644 index f80b618..0000000 --- a/skills/triton-hip-reference-kernel-search/.federated.json +++ /dev/null @@ -1,9 +0,0 @@ -{ - "source": "amd-agi-apex", - "repo": "AMD-AGI/Apex", - "ref": "main", - "commit": "6ee40e7e6f2d03902cce503cddf873f2ab75f05c", - "path": "tools/skills/triton-hip-reference-kernel-search", - "license": "MIT", - "imported_at": "2026-05-28T21:37:40Z" -} diff --git a/skills/triton-hip-reference-kernel-search/SKILL.md b/skills/triton-hip-reference-kernel-search/SKILL.md deleted file mode 100644 index 2a4337b..0000000 --- a/skills/triton-hip-reference-kernel-search/SKILL.md +++ /dev/null @@ -1,17 +0,0 @@ ---- -name: triton-hip-reference-kernel-search -description: Search and adapt Triton/HIP kernel patterns from a corpus to optimize AMD GPUs; use to find similar ops and reuse tiling/occupancy strategies. ---- - -# AMD Kernel Patterns - -- Use when you need real kernel templates (attention, layernorm, matmul, activations) to adapt for AMD/ROCm. -- Do not load the entire corpus; grep targeted snippets instead. - -## How to use -- Search `references/train_crawl.json` with ripgrep for relevant ops; keep context tight. -- Extract only needed code and descriptions; rewrite for wave64 occupancy, LDS tiling, vectorized/coalesced access, and bank-conflict avoidance. -- Cite source file and lines; pair with reflection prompts to validate correctness and performance. - -## References -- `references/SEARCH.md`: Grep commands and tips for slicing snippets efficiently. diff --git a/skills/triton-hip-reference-kernel-search/references/SEARCH.md b/skills/triton-hip-reference-kernel-search/references/SEARCH.md deleted file mode 100644 index 1105da6..0000000 --- a/skills/triton-hip-reference-kernel-search/references/SEARCH.md +++ /dev/null @@ -1,12 +0,0 @@ -Search the kernel corpus for reusable Triton/HIP patterns without loading the full file. - -- Corpus file: `skills/amd-kernel-patterns/references/train_crawl.json` (~24k lines, copied locally). -- Quick grep examples: - - `rg -n "attention|flash" skills/amd-kernel-patterns/references/train_crawl.json` - - `rg -n "layer[_-]?norm" ...` - - `rg -n "activation" ...` - - `rg -n "triton" ...` - - `rg -n "hip" ...` -- After finding a hit, slice a small window with `sed -n 'start,endp'` to extract code + descriptions. -- Adapt to AMD: wave64 occupancy, LDS tiling, vectorized loads/stores, avoid bank conflicts, coalesced global access. -- Cite file and line numbers when reusing snippets; trim to only what you need. diff --git a/skills/triton-hip-reference-kernel-search/references/train_crawl.json b/skills/triton-hip-reference-kernel-search/references/train_crawl.json deleted file mode 100644 index 087c653..0000000 --- a/skills/triton-hip-reference-kernel-search/references/train_crawl.json +++ /dev/null @@ -1,24146 +0,0 @@ -[ - { - "code": "import triton\nimport triton.language as tl\n\n# triton kernel\n@triton.jit\ndef kernel(X, stride_xm,\n Z, stride_zn,\n BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):\n off_m = tl.arange(0, BLOCK_M)\n off_n = tl.arange(0, BLOCK_N)\n Xs = X + off_m[:, None] * stride_xm + off_n[None, :] * 1\n Zs = Z + off_m[:, None] * 1 + off_n[None, :] * stride_zn\n tl.store(Zs, tl.load(Xs))\n\n\nret = triton.compile(kernel, signature=\"*fp32,i32,*fp32,i32\", constants={\"BLOCK_M\": 64, \"BLOCK_N\": 64}, output=\"ttgir\")\n\nprint(ret)\n", - "description_1": "Use triton language to define a kernel that copies a 2D block of data from one location to another. The kernel takes in four parameters: X (the source tensor), stride_xm (the stride for the X tensor), Z (the destination tensor), and stride_zn (the stride for the Z tensor). It also utilizes two constexpr parameters BLOCK_M and BLOCK_N to determine the size of the 2D block to copy. The kernel computes offsets using tl.arange and performs element-wise loading from the source tensor and storing into the destination tensor.", - "description_2": "Use triton language to define and compile a kernel that performs a 2D block data copy between tensors, using specific strides and block size parameters.", - "difficulty": 2 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef _layer_norm_fwd_fused(\n Out,\n A,\n Weight,\n Bias,\n Mean, Rstd,\n stride, N, eps,\n BLOCK_SIZE: tl.constexpr,\n):\n # position of elements processed by this program\n row = tl.program_id(0)\n Out += row * stride\n A += row * stride\n # compute mean\n mean = 0\n _mean = tl.zeros([BLOCK_SIZE], dtype=tl.float32)\n for off in range(0, N, BLOCK_SIZE):\n cols = off + tl.arange(0, BLOCK_SIZE)\n a = tl.load(A + cols, mask=cols < N, other=0., eviction_policy=\"evict_last\").to(tl.float32)\n _mean += a\n mean = tl.sum(_mean, axis=0) / N\n # compute variance\n _var = tl.zeros([BLOCK_SIZE], dtype=tl.float32)\n for off in range(0, N, BLOCK_SIZE):\n cols = off + tl.arange(0, BLOCK_SIZE)\n a = tl.load(A + cols, mask=cols < N, other=0., eviction_policy=\"evict_last\").to(tl.float32)\n a = tl.where(cols < N, a - mean, 0.)\n _var += a * a\n var = tl.sum(_var, axis=0) / N\n rstd = 1 / tl.sqrt(var + eps)\n # write-back mean/rstd\n tl.store(Mean + row, mean)\n tl.store(Rstd + row, rstd)\n # multiply by weight and add bias\n for off in range(0, N, BLOCK_SIZE):\n cols = off + tl.arange(0, BLOCK_SIZE)\n mask = cols < N\n weight = tl.load(Weight + cols, mask=mask)\n bias = tl.load(Bias + cols, mask=mask)\n a = tl.load(A + cols, mask=mask, other=0., eviction_policy=\"evict_first\").to(tl.float32)\n a_hat = (a - mean) * rstd\n out = a_hat * weight + bias\n # # write-back\n tl.store(Out + cols, out, mask=mask)\n\n@triton.jit\ndef _layer_norm_bwd_dx_fused(\n _DA,\n _DOut,\n _A,\n Weight,\n Mean, Rstd,\n stride, NumRows, NumCols, eps,\n BLOCK_SIZE_N: tl.constexpr,\n):\n # position of elements processed by this program\n pid = tl.program_id(0)\n row = pid\n A = _A + row * stride\n DOut = _DOut + row * stride\n DA = _DA + row * stride\n mean = tl.load(Mean + row)\n rstd = tl.load(Rstd + row)\n # load data to SRAM\n _mean1 = tl.zeros([BLOCK_SIZE_N], dtype=tl.float32)\n _mean2 = tl.zeros([BLOCK_SIZE_N], dtype=tl.float32)\n for off in range(0, NumCols, BLOCK_SIZE_N):\n cols = off + tl.arange(0, BLOCK_SIZE_N)\n mask = cols < NumCols\n a = tl.load(A + cols, mask=mask, other=0).to(tl.float32)\n dout = tl.load(DOut + cols, mask=mask, other=0).to(tl.float32)\n weight = tl.load(Weight + cols, mask=mask, other=0).to(tl.float32)\n a_hat = (a - mean) * rstd\n wdout = weight * dout\n _mean1 += a_hat * wdout\n _mean2 += wdout\n mean1 = tl.sum(_mean1, axis=0) / NumCols\n mean2 = 0.\n mean2 = tl.sum(_mean2, axis=0) / NumCols\n for off in range(0, NumCols, BLOCK_SIZE_N):\n cols = off + tl.arange(0, BLOCK_SIZE_N)\n mask = cols < NumCols\n a = tl.load(A + cols, mask=mask, other=0).to(tl.float32)\n dout = tl.load(DOut + cols, mask=mask, other=0).to(tl.float32)\n weight = tl.load(Weight + cols, mask=mask, other=0).to(tl.float32)\n a_hat = (a - mean) * rstd\n wdout = weight * dout\n da = (wdout - (a_hat * mean1 + mean2)) * rstd\n # write-back dx\n tl.store(DA + cols, da, mask=mask)\n\n@triton.jit\ndef _layer_norm_bwd_dwdb(\n A, DOut,\n Mean, Var,\n DW,\n DB,\n M, N,\n BLOCK_SIZE_M: tl.constexpr,\n BLOCK_SIZE_N: tl.constexpr,\n):\n pid = tl.program_id(0)\n cols = pid * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n dw = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n db = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n UNROLL: tl.constexpr = 4\n for i in range(0, M, BLOCK_SIZE_M * UNROLL):\n for j in range(UNROLL):\n rows = i + j * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n mask = (rows[:, None] < M) & (cols[None, :] < N)\n offs = rows[:, None] * N + cols[None, :]\n a = tl.load(A + offs, mask=mask, other=0.).to(tl.float32)\n dout = tl.load(DOut + offs, mask=mask, other=0.).to(tl.float32)\n mean = tl.load(Mean + rows, mask=rows < M, other=0.)\n rstd = tl.load(Var + rows, mask=rows < M, other=0.)\n a_hat = (a - mean[:, None]) * rstd[:, None]\n dw += dout * a_hat\n db += dout\n sum_dw = tl.sum(dw, axis=0)\n sum_db = tl.sum(db, axis=0)\n tl.store(DW + cols, sum_dw, mask=cols < N)\n tl.store(DB + cols, sum_db, mask=cols < N)\n\nclass LayerNorm(torch.autograd.Function):\n @staticmethod\n def forward(ctx, a, normalized_shape, weight, bias, eps):\n # allocate output\n out = torch.empty_like(a)\n # reshape input data into 2D tensor\n a_arg = a.reshape(-1, a.shape[-1])\n M, N = a_arg.shape\n mean = torch.empty((M,), dtype=torch.float32, device=\"cuda\")\n rstd = torch.empty((M,), dtype=torch.float32, device=\"cuda\")\n # Less than 64KB per feature: enqueue fused kernel\n MAX_FUSED_SIZE = 65536 // a.element_size()\n BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))\n BLOCK_SIZE = max(BLOCK_SIZE, 128)\n BLOCK_SIZE = min(BLOCK_SIZE, 4096)\n # heuristics for number of warps\n num_warps = min(max(BLOCK_SIZE // 256, 1), 8)\n _layer_norm_fwd_fused[(M,)](\n out,\n a_arg,\n weight,\n bias,\n mean, rstd,\n a_arg.stride(0), N, eps,\n BLOCK_SIZE=BLOCK_SIZE,\n num_warps=num_warps,\n )\n ctx.save_for_backward(\n a, weight, bias, mean, rstd,\n )\n ctx.BLOCK_SIZE = BLOCK_SIZE\n ctx.num_warps = num_warps\n ctx.eps = eps\n if hasattr(bias, \"config\"):\n assert bias.config.grad_scale_name == weight.config.grad_scale_name\n grad_scale_name = bias.config.grad_scale_name\n else:\n grad_scale_name = None\n ctx.grad_scale_gain_bias_name = grad_scale_name\n return out\n\n @staticmethod\n def backward(ctx, dout):\n assert dout.is_contiguous()\n a, weight, bias, mean, var = ctx.saved_tensors\n # heuristics for amount of parallel reduction stream for DG/DB\n N = weight.shape[0]\n # allocate output\n da = torch.empty_like(dout)\n # enqueue kernel using forward pass heuristics\n # also compute partial sums for DW and DB\n x_arg = a.reshape(-1, a.shape[-1])\n M, N = x_arg.shape\n dweight = torch.empty((weight.shape[0],), dtype=weight.dtype, device=weight.device)\n dbias = torch.empty((weight.shape[0],), dtype=weight.dtype, device=weight.device)\n _layer_norm_bwd_dx_fused[(M,)](\n da,\n dout,\n a,\n weight,\n mean, var,\n x_arg.stride(0), M, N,\n ctx.eps,\n BLOCK_SIZE_N=ctx.BLOCK_SIZE,\n num_warps=ctx.num_warps,\n )\n if N > 10240:\n BLOCK_SIZE_N = 128\n BLOCK_SIZE_M = 32\n num_warps = 4\n else:\n # maximize occupancy for small N\n BLOCK_SIZE_N = 16\n BLOCK_SIZE_M = 16\n num_warps = 8\n grid = lambda meta: [triton.cdiv(N, meta[\"BLOCK_SIZE_N\"])]\n _layer_norm_bwd_dwdb[grid](\n a, dout,\n mean, var,\n dweight,\n dbias,\n M,\n N,\n BLOCK_SIZE_M=BLOCK_SIZE_M,\n BLOCK_SIZE_N=BLOCK_SIZE_N,\n num_warps=num_warps\n )\n return (da, None, dweight, dbias, None)\n\ndef layer_norm(a, normalized_shape, weight, bias, eps):\n return LayerNorm.apply(a, normalized_shape, weight, bias, eps)\n", - "description_1": "Use triton language to implement a layer normalization operation with three kernels: one for the forward pass, one for the backward pass computing gradients with respect to the input, and one for computing gradients with respect to the weights and biases. The forward kernel takes 9 parameters: output tensor, input tensor, weight, bias, mean, rstd, stride, number of elements, and epsilon. The backward kernel for input gradients takes 10 parameters: gradient of input, gradient of output, input tensor, weight, mean, rstd, stride, number of rows, number of columns, and epsilon. The backward kernel for weight and bias gradients takes 9 parameters: input tensor, gradient of output, mean, variance, gradient of weight, gradient of bias, number of rows, number of columns, and block sizes for rows and columns.", - "description_2": "Use triton language to create a layer normalization function with forward and backward passes, optimizing for GPU execution by using block sizes and warps.", - "difficulty": 3 - }, - { - "code": "import triton\nimport triton.language as tl\n\n@triton.jit\ndef add_kernel(\n x_ptr, # *Pointer* to first input vector\n y_ptr, # *Pointer* to second input vector\n output_ptr, # *Pointer* to output vector\n n_elements, # Size of the vector\n BLOCK_SIZE: tl.constexpr # Number of elements each program should process\n # NOTE: `constexpr` so it can be used as a shape value\n):\n # There are multiple 'program's processing different data. We identify which program\n # we are here\n pid = tl.program_id(axis=0) # We use a 1D launch grid so axis is 0\n # This program will process inputs that are offset from the initial data.\n # for instance, if you had a vector of length 256 and block_size of 64, the programs\n # would each access the elements [0:64, 64:128, 128:192, 192:256].\n # Note that offsets is a list of pointers\n block_start = pid * BLOCK_SIZE\n offsets = block_start + tl.arange(0, BLOCK_SIZE)\n # Create a mask to guard memory operations against out-of-bounds accesses\n mask = offsets < n_elements\n # Load x and y from DRAM, masking out any extra elements in case the input is not a\n # multiple of the block size\n x = tl.load(x_ptr + offsets, mask=mask)\n y = tl.load(y_ptr + offsets, mask=mask)\n output = x + y\n # Write x + y back to DRAM\n tl.store(output_ptr + offsets, output, mask=mask)\n", - "description_1": "Use triton language to implement a kernel function 'add_kernel' that takes five parameters: x_ptr (pointer to the first input vector), y_ptr (pointer to the second input vector), output_ptr (pointer to the output vector), n_elements (size of the vector), and BLOCK_SIZE (number of elements each program should process). The kernel computes the element-wise sum of two input vectors and stores the result in the output vector, using a 1D launch grid and masking to handle out-of-bounds accesses.", - "description_2": "Use triton language to create a kernel that adds two vectors element-wise, handling out-of-bounds with masking.", - "difficulty": 3 - }, - { - "code": "import triton\nimport triton.language as tl\n\n# triton kernel\n@triton.jit\ndef kernel(X, stride_xm,\n Z, stride_zn,\n BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):\n off_m = tl.arange(0, BLOCK_M)\n off_n = tl.arange(0, BLOCK_N)\n Xs = X + off_m[:, None] * stride_xm + off_n[None, :] * 1\n Zs = Z + off_m[:, None] * 1 + off_n[None, :] * stride_zn\n tl.store(Zs, tl.load(Xs))\n\n\nret = triton.compile(kernel, signature=\"*fp32,i32,*fp32,i32\", constants={\"BLOCK_M\": 64, \"BLOCK_N\": 64}, output=\"ttgir\")\n\nprint(ret)\n", - "description_1": "Use triton language to define a kernel that loads data from a source matrix X and stores it into a destination matrix Z. The kernel has 5 parameters: 1) X (pointer to the source matrix); 2) stride_xm (int, stride for the m dimension of X); 3) Z (pointer to the destination matrix); 4) stride_zn (int, stride for the n dimension of Z); 5) BLOCK_M and BLOCK_N (constexpr int, dimensions of each block). The data is loaded from X using calculated offsets and stored into Z at corresponding offsets, where block size is controlled by BLOCK_M and BLOCK_N. The kernel is compiled with given signatures and constants and the output is printed.", - "description_2": "Use triton language to create a kernel for copying blocks of data from matrix X to matrix Z, using customizable block sizes and strides.", - "difficulty": 2 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef _layer_norm_fwd_fused(\n Out,\n A,\n Weight,\n Bias,\n Mean, Rstd,\n stride, N, eps,\n BLOCK_SIZE: tl.constexpr,\n):\n row = tl.program_id(0)\n Out += row * stride\n A += row * stride\n mean = 0\n _mean = tl.zeros([BLOCK_SIZE], dtype=tl.float32)\n for off in range(0, N, BLOCK_SIZE):\n cols = off + tl.arange(0, BLOCK_SIZE)\n a = tl.load(A + cols, mask=cols < N, other=0., eviction_policy=\"evict_last\").to(tl.float32)\n _mean += a\n mean = tl.sum(_mean, axis=0) / N\n _var = tl.zeros([BLOCK_SIZE], dtype=tl.float32)\n for off in range(0, N, BLOCK_SIZE):\n cols = off + tl.arange(0, BLOCK_SIZE)\n a = tl.load(A + cols, mask=cols < N, other=0., eviction_policy=\"evict_last\").to(tl.float32)\n a = tl.where(cols < N, a - mean, 0.)\n _var += a * a\n var = tl.sum(_var, axis=0) / N\n rstd = 1 / tl.sqrt(var + eps)\n tl.store(Mean + row, mean)\n tl.store(Rstd + row, rstd)\n for off in range(0, N, BLOCK_SIZE):\n cols = off + tl.arange(0, BLOCK_SIZE)\n mask = cols < N\n weight = tl.load(Weight + cols, mask=mask)\n bias = tl.load(Bias + cols, mask=mask)\n a = tl.load(A + cols, mask=mask, other=0., eviction_policy=\"evict_first\").to(tl.float32)\n a_hat = (a - mean) * rstd\n out = a_hat * weight + bias\n tl.store(Out + cols, out, mask=mask)\n\n@triton.jit\ndef _layer_norm_bwd_dx_fused(\n _DA,\n _DOut,\n _A,\n Weight,\n Mean, Rstd,\n stride, NumRows, NumCols, eps,\n BLOCK_SIZE_N: tl.constexpr,\n):\n pid = tl.program_id(0)\n row = pid\n A = _A + row * stride\n DOut = _DOut + row * stride\n DA = _DA + row * stride\n mean = tl.load(Mean + row)\n rstd = tl.load(Rstd + row)\n _mean1 = tl.zeros([BLOCK_SIZE_N], dtype=tl.float32)\n _mean2 = tl.zeros([BLOCK_SIZE_N], dtype=tl.float32)\n for off in range(0, NumCols, BLOCK_SIZE_N):\n cols = off + tl.arange(0, BLOCK_SIZE_N)\n mask = cols < NumCols\n a = tl.load(A + cols, mask=mask, other=0).to(tl.float32)\n dout = tl.load(DOut + cols, mask=mask, other=0).to(tl.float32)\n weight = tl.load(Weight + cols, mask=mask, other=0).to(tl.float32)\n a_hat = (a - mean) * rstd\n wdout = weight * dout\n _mean1 += a_hat * wdout\n _mean2 += wdout\n mean1 = tl.sum(_mean1, axis=0) / NumCols\n mean2 = 0.\n mean2 = tl.sum(_mean2, axis=0) / NumCols\n for off in range(0, NumCols, BLOCK_SIZE_N):\n cols = off + tl.arange(0, BLOCK_SIZE_N)\n mask = cols < NumCols\n a = tl.load(A + cols, mask=mask, other=0).to(tl.float32)\n dout = tl.load(DOut + cols, mask=mask, other=0).to(tl.float32)\n weight = tl.load(Weight + cols, mask=mask, other=0).to(tl.float32)\n a_hat = (a - mean) * rstd\n wdout = weight * dout\n da = (wdout - (a_hat * mean1 + mean2)) * rstd\n tl.store(DA + cols, da, mask=mask)\n\n@triton.jit\ndef _layer_norm_bwd_dwdb(\n A, DOut,\n Mean, Var,\n DW,\n DB,\n M, N,\n BLOCK_SIZE_M: tl.constexpr,\n BLOCK_SIZE_N: tl.constexpr,\n):\n pid = tl.program_id(0)\n cols = pid * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n dw = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n db = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n UNROLL: tl.constexpr = 4\n for i in range(0, M, BLOCK_SIZE_M * UNROLL):\n for j in range(UNROLL):\n rows = i + j * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n mask = (rows[:, None] < M) & (cols[None, :] < N)\n offs = rows[:, None] * N + cols[None, :]\n a = tl.load(A + offs, mask=mask, other=0.).to(tl.float32)\n dout = tl.load(DOut + offs, mask=mask, other=0.).to(tl.float32)\n mean = tl.load(Mean + rows, mask=rows < M, other=0.)\n rstd = tl.load(Var + rows, mask=rows < M, other=0.)\n a_hat = (a - mean[:, None]) * rstd[:, None]\n dw += dout * a_hat\n db += dout\n sum_dw = tl.sum(dw, axis=0)\n sum_db = tl.sum(db, axis=0)\n tl.store(DW + cols, sum_dw, mask=cols < N)\n tl.store(DB + cols, sum_db, mask=cols < N)\n\nclass LayerNorm(torch.autograd.Function):\n @staticmethod\n def forward(ctx, a, normalized_shape, weight, bias, eps):\n out = torch.empty_like(a)\n a_arg = a.reshape(-1, a.shape[-1])\n M, N = a_arg.shape\n mean = torch.empty((M,), dtype=torch.float32, device=\"cuda\")\n rstd = torch.empty((M,), dtype=torch.float32, device=\"cuda\")\n MAX_FUSED_SIZE = 65536 // a.element_size()\n BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))\n BLOCK_SIZE = max(BLOCK_SIZE, 128)\n BLOCK_SIZE = min(BLOCK_SIZE, 4096)\n num_warps = min(max(BLOCK_SIZE // 256, 1), 8)\n _layer_norm_fwd_fused[(M,)](\n out,\n a_arg,\n weight,\n bias,\n mean, rstd,\n a_arg.stride(0), N, eps,\n BLOCK_SIZE=BLOCK_SIZE,\n num_warps=num_warps,\n )\n ctx.save_for_backward(\n a, weight, bias, mean, rstd,\n )\n ctx.BLOCK_SIZE = BLOCK_SIZE\n ctx.num_warps = num_warps\n ctx.eps = eps\n if hasattr(bias, \"config\"):\n assert bias.config.grad_scale_name == weight.config.grad_scale_name\n grad_scale_name = bias.config.grad_scale_name\n else:\n grad_scale_name = None\n ctx.grad_scale_gain_bias_name = grad_scale_name\n return out\n\n @staticmethod\n def backward(ctx, dout):\n assert dout.is_contiguous()\n a, weight, bias, mean, var = ctx.saved_tensors\n N = weight.shape[0]\n da = torch.empty_like(dout)\n x_arg = a.reshape(-1, a.shape[-1])\n M, N = x_arg.shape\n dweight = torch.empty((weight.shape[0],), dtype=weight.dtype, device=weight.device)\n dbias = torch.empty((weight.shape[0],), dtype=weight.dtype, device=weight.device)\n _layer_norm_bwd_dx_fused[(M,)](\n da,\n dout,\n a,\n weight,\n mean, var,\n x_arg.stride(0), M, N,\n ctx.eps,\n BLOCK_SIZE_N=ctx.BLOCK_SIZE,\n num_warps=ctx.num_warps,\n )\n if N > 10240:\n BLOCK_SIZE_N = 128\n BLOCK_SIZE_M = 32\n num_warps = 4\n else:\n BLOCK_SIZE_N = 16\n BLOCK_SIZE_M = 16\n num_warps = 8\n grid = lambda meta: [triton.cdiv(N, meta[\"BLOCK_SIZE_N\"])]\n _layer_norm_bwd_dwdb[grid](\n a, dout,\n mean, var,\n dweight,\n dbias,\n M,\n N,\n BLOCK_SIZE_M=BLOCK_SIZE_M,\n BLOCK_SIZE_N=BLOCK_SIZE_N,\n num_warps=num_warps\n )\n return (da, None, dweight, dbias, None)\n\ndef layer_norm(a, normalized_shape, weight, bias, eps):\n return LayerNorm.apply(a, normalized_shape, weight, bias, eps)\n", - "description_1": "Use triton language to implement the layer normalization forward and backward passes. The forward kernel '_layer_norm_fwd_fused' computes the output of the layer normalization by normalizing the input tensor, applying the weight and bias, and storing the results. It requires 9 parameters: Out (output tensor), A (input tensor), Weight (weight tensor), Bias (bias tensor), Mean (mean tensor), Rstd (reciprocal standard deviation tensor), stride, N (number of columns), and eps (epsilon for numerical stability). The backward kernels '_layer_norm_bwd_dx_fused' and '_layer_norm_bwd_dwdb' compute the gradients with respect to the input, weights, and biases. '_layer_norm_bwd_dx_fused' requires 10 parameters: _DA (gradient tensor for input), _DOut (gradient tensor for output), _A (input tensor), Weight (weight tensor), Mean (mean tensor), Rstd (reciprocal standard deviation tensor), stride, NumRows, NumCols, and eps. '_layer_norm_bwd_dwdb' requires 8 parameters: A (input tensor), DOut (gradient tensor for output), Mean (mean tensor), Var (variance tensor), DW (gradient tensor for weights), DB (gradient tensor for biases), M (number of rows), and N (number of columns).", - "description_2": "Use triton language to perform layer normalization by implementing forward and backward kernels. The forward kernel normalizes the input tensor and applies weight and bias transformations, while the backward kernels calculate gradients for the input, weights, and biases.", - "difficulty": 4 - }, - { - "code": "import triton\nimport triton.language as tl\n\n@triton.jit\ndef add_kernel(\n x_ptr, # *Pointer* to first input vector\n y_ptr, # *Pointer* to second input vector\n output_ptr, # *Pointer* to output vector\n n_elements, # Size of the vector\n BLOCK_SIZE: tl.constexpr # Number of elements each program should process\n # NOTE: `constexpr` so it can be used as a shape value\n):\n \"\"\"\n This is a test kernel. Testing some stuff here\n New line\n would this look good?\n \"\"\"\n # There are multiple 'program's processing different data. We identify which program\n # we are here\n pid = tl.program_id(axis=0) # We use a 1D launch grid so axis is 0\n # This program will process inputs that are offset from the initial data.\n # for instance, if you had a vector of length 256 and block_size of 64, the programs\n # would each access the elements [0:64, 64:128, 128:192, 192:256].\n # Note that offsets is a list of pointers\n block_start = pid * BLOCK_SIZE\n offsets = block_start + tl.arange(0, BLOCK_SIZE)\n # Create a mask to guard memory operations against out-of-bounds accesses\n mask = offsets < n_elements\n # Load x and y from DRAM, masking out any extra elements in case the input is not a\n # multiple of the block size\n x = tl.load(x_ptr + offsets, mask=mask)\n y = tl.load(y_ptr + offsets, mask=mask)\n output = x + y\n # Write x + y back to DRAM\n tl.store(output_ptr + offsets, output, mask=mask)\n", - "description_1": "Use triton language to define a kernel function named `add_kernel` that takes five parameters: x_ptr, y_ptr, output_ptr, n_elements, and BLOCK_SIZE. This kernel performs element-wise addition of two input vectors (pointed by x_ptr and y_ptr) and stores the result in the output vector (pointed by output_ptr). The computation is performed in blocks of size BLOCK_SIZE, and it includes a mask to prevent out-of-bounds memory access if the number of elements (n_elements) is not a multiple of BLOCK_SIZE.", - "description_2": "Use triton language to create a kernel for element-wise addition of vectors using pointers and block processing.", - "difficulty": 2 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n# Kernel for element-wise addition\n@triton.jit\ndef _add(x_ptr, y_ptr, output_ptr, n_elements,\n BLOCK_SIZE: tl.constexpr):\n pid = tl.program_id(axis=0)\n block_start = pid * BLOCK_SIZE\n offsets = block_start + tl.arange(0, BLOCK_SIZE)\n mask = offsets < n_elements\n x = tl.load(x_ptr + offsets, mask=mask)\n y = tl.load(y_ptr + offsets, mask=mask)\n output = x + y\n tl.store(output_ptr + offsets, output, mask=mask)\n\n# Test function for element-wise addition\ndef test_elementwise(N):\n torch.manual_seed(0)\n z = torch.empty((N, ), dtype=torch.float16, device='cuda')\n x = torch.randn_like(z)\n y = torch.randn_like(z)\n grid = lambda args: (triton.cdiv(N, args['BLOCK_SIZE']), )\n fn = lambda: _add[grid](x, y, z, N, BLOCK_SIZE=1024)\n ms = triton.testing.do_bench(fn, percentiles=None, warmup=25, rep=250)\n", - "description_1": "Use triton language to create a kernel for element-wise addition of two vectors, where the kernel reads elements from two input pointers, adds them, and writes the results to an output pointer. The kernel is executed using a grid of blocks, with each block processing a subset of elements defined by BLOCK_SIZE. The test function initializes input data, sets up the execution grid, and benchmarks the kernel performance.", - "description_2": "Use triton language to implement and benchmark an element-wise vector addition kernel with configurable block size and input data on GPU.", - "difficulty": 2 - }, - { - "code": "import torch\nfrom torch.testing import assert_close\n\nimport triton\nimport triton.language as tl\n\ndef get_tensor(shape, data_type):\n x = torch.arange(0, shape[0], dtype=torch.float16 if data_type == \"float16\" else torch.int8, device='cuda')\n return x\n\ndef printf(data_type):\n @triton.jit\n def kernel(X, Y, BLOCK: tl.constexpr):\n x = tl.load(X + tl.arange(0, BLOCK))\n tl.printf(\"\", x)\n tl.store(Y + tl.arange(0, BLOCK), x)\n\n shape = (128, )\n x = get_tensor(shape, data_type)\n y = torch.zeros(shape, dtype=x.dtype, device=\"cuda\")\n kernel[(1,)](x, y, BLOCK=shape[0])\n assert_close(y, x)\n\nprintf(\"float16\")\nprintf(\"int8\")\n", - "description_1": "Use triton language to implement a kernel with 3 parameters: X (input tensor), Y (output tensor), BLOCK (block size). The kernel loads data from X, prints it, and stores it in Y. Call this kernel with tensors of shape (128,) and check if the output matches the input.", - "description_2": "Use triton language to create a kernel that loads from, prints, and stores tensor data, with input and output verification.", - "difficulty": 2 - }, - { - "code": "import numpy as np\nimport torch\nimport triton\nimport triton.language as tl\nfrom numpy.random import RandomState\n\n# Description of the triton kernel with @triton.jit decorators\n\n# Kernel 1: empty kernel\n@triton.jit\ndef empty_kernel(X, SIZE: tl.constexpr):\n # Parameters:\n # X: tensor for input data\n # SIZE: compile-time constant for the size of the data\n\n pass\n\n# Kernel 2: Unary operation kernel\n@triton.jit\ndef unary_op_kernel(Z, X, SIZE: tl.constexpr):\n # Parameters:\n # Z: tensor for storing the result\n # X: tensor for input data\n # SIZE: compile-time constant for the size of the data\n\n off = tl.arange(0, SIZE)\n x = tl.load(X + off)\n z = GENERATE_TEST_HERE # Replace this with actual expression\n tl.store(Z + off, z)\n\n# Kernel 3: Binary operation kernel\n@triton.jit\ndef binary_op_kernel(Z, X, Y, SIZE: tl.constexpr):\n # Parameters:\n # Z: tensor for storing the result\n # X: tensor for first input data\n # Y: tensor for second input data\n # SIZE: compile-time constant for the size of the data\n\n off = tl.arange(0, SIZE)\n x = tl.load(X + off)\n y = tl.load(Y + off)\n z = GENERATE_TEST_HERE # Replace this with actual expression\n tl.store(Z + off, z)\n\n# Example test function using binary_op_kernel\ndef test_bin_op(dtype_x, dtype_y, op, device='cuda'):\n expr = f' x {op} y'\n x = numpy_random(128, dtype_str=dtype_x)\n y = numpy_random(128, dtype_str=dtype_y)\n z_ref = eval(expr)\n x_tri = torch.tensor(x, device=device)\n y_tri = torch.tensor(y, device=device)\n z_tri = torch.empty_like(x_tri)\n binary_op_kernel[(1,)](z_tri, x_tri, y_tri, SIZE=128)\n np.testing.assert_allclose(z_ref, z_tri.cpu().numpy(), rtol=0.01)\n", - "description_1": "Use triton language to define multiple kernels using @triton.jit, including an empty kernel and kernels for unary and binary operations. Each kernel should perform operations on input tensors and store results, requiring compile-time constants for size. Example functions should demonstrate binary operations using these kernels, utilizing numpy for reference results and PyTorch for tensor handling.", - "description_2": "Implement triton kernels for unary and binary tensor operations; ensure kernel execution via test functions that compare triton results with numpy references.", - "difficulty": 2 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\nimport numpy as np\nimport scipy.stats\n\nBLOCK = 1024\n\n# Kernel for generating random uint32\n@triton.jit\ndef kernel_randint(X, N, seed):\n offset = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)\n rand = tl.randint(seed, offset)\n tl.store(X + offset, rand, mask=offset < N)\n\n# Kernel for generating uniform random numbers\n@triton.jit\ndef kernel_rand(X, N, seed):\n offset = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)\n rand = tl.rand(seed, offset)\n tl.store(X + offset, rand, mask=offset < N)\n\n# Kernel for generating normal random numbers\n@triton.jit\ndef kernel_randn(X, N, seed):\n offset = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)\n rand = tl.randn(seed, offset)\n tl.store(X + offset, rand, mask=offset < N)\n\n# Kernel to test rand limits\n@triton.jit\ndef kernel_rand_limits(input, output, n: tl.constexpr):\n idx = tl.arange(0, n)\n x = tl.load(input + idx)\n y = tl.random.uint32_to_uniform_float(x)\n tl.store(output + idx, y)\n\n# Function to test random uint32 generation\ndef test_randint(size, seed, device='cuda'):\n size = list(map(int, size.split(',')))\n x = torch.empty(size, dtype=torch.int32, device=device)\n N = x.numel()\n grid = (triton.cdiv(N, BLOCK),)\n kernel_randint[grid](x, N, seed)\n out_tri = x.cpu().numpy().astype(np.uint32).flatten().tolist()\n gen = CustomPhilox4x(seed, config=PHILOX_32)\n out_ref = [gen.random_raw()[0] for _ in out_tri]\n assert out_tri == out_ref\n\n# Function to test uniform PRNG\ndef test_rand(size, seed, device='cuda'):\n x = torch.empty(size, dtype=torch.float32, device=device)\n N = x.numel()\n grid = (triton.cdiv(N, BLOCK),)\n kernel_rand[grid](x, N, seed)\n assert all((x >= 0) & (x <= 1))\n assert scipy.stats.kstest(x.tolist(), 'uniform', args=(0, 1)).statistic < 0.01\n\n# Function to test normal PRNG\ndef test_randn(size, seed, device='cuda'):\n x = torch.empty(size, dtype=torch.float32, device=device)\n N = x.numel()\n grid = (triton.cdiv(N, BLOCK),)\n kernel_randn[grid](x, N, seed)\n assert abs(x.mean()) < 1e-2\n assert abs(x.std() - 1) < 1e-2\n\n# Function to test rand limits\ndef test_rand_limits():\n min_max_int32 = torch.tensor([\n torch.iinfo(torch.int32).min,\n torch.iinfo(torch.int32).max,\n ], dtype=torch.int32, device='cuda')\n output = torch.empty(2, dtype=torch.float32, device='cuda')\n kernel_rand_limits[(1,)](min_max_int32, output, 2)\n assert output[0] == output[1]\n assert 1.0 - torch.finfo(torch.float32).eps <= output[0].item() < 1.0\n", - "description_1": "Use triton language to implement kernels for generating random numbers. The kernel_randint function takes three parameters: X (output tensor), N (number of elements), and seed (random seed). It generates random uint32 numbers. The kernel_rand function also takes three parameters: X (output tensor), N (number of elements), and seed (random seed). It generates uniform random numbers. The kernel_randn function takes the same parameters and generates normal random numbers. The kernel_rand_limits function takes three parameters: input (input tensor), output (output tensor), and n (number of elements). It tests the limits of random number generation.", - "description_2": "Use triton language to create kernels for generating random uint32, uniform, and normal numbers, and to test the limits of random number generation.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n# Triton kernel function that increments an integer and stores it\n@triton.jit\ndef function_1(i):\n i = i + 1\n i = function_2(i)\n return i\n\n# Triton kernel function that increments an integer\n@triton.jit\ndef function_2(i):\n i = i + 1\n return i\n\n# Triton kernel that uses function_1 and stores the result\n@triton.jit\ndef kernel(X, i, BLOCK: tl.constexpr):\n i = i + 1\n i = function_1(i)\n tl.store(X, i)\n\n# Triton kernel with no specialization\n@triton.jit(do_not_specialize=[\"i\"])\ndef kernel_nospec(X, i, BLOCK: tl.constexpr):\n i = i + 1\n i = function_1(i)\n tl.store(X, i)\n\n# Test function to check kernel reuse\ndef test_reuse():\n counter = 0\n\n def inc_counter(*args, **kwargs):\n nonlocal counter\n counter += 1\n triton.runtime.jit.JITFunction.cache_hook = inc_counter\n reset_tmp_dir()\n x = torch.empty(1, dtype=torch.int32, device='cuda')\n for i in range(10):\n kernel[(1,)](x, 1, BLOCK=1024)\n assert counter == 1\n\n# Test function to check specialization\n@pytest.mark.parametrize('mode', ['enable', 'disable'])\ndef test_specialize(mode):\n counter = 0\n\n def inc_counter(*args, **kwargs):\n nonlocal counter\n counter += 1\n triton.runtime.jit.JITFunction.cache_hook = inc_counter\n reset_tmp_dir()\n x = torch.empty(1, dtype=torch.int32, device='cuda')\n function = {'enable': kernel, 'disable': kernel_nospec}[mode]\n target = {'enable': 3, 'disable': 1}[mode]\n for i in [1, 2, 4, 8, 16, 32]:\n function[(1,)](x, i, BLOCK=512)\n assert counter == target\n", - "description_1": "Use triton language to define a series of kernels: function_1 and function_2 increment an integer; kernel uses function_1 to increment and store a value; kernel_nospec is a non-specialized version of kernel. Test functions ensure kernel reuse and specialization behavior.", - "description_2": "Use triton language to create kernels for integer increment and storage, and test their reuse and specialization.", - "difficulty": 2 - }, - { - "code": "import triton\nimport torch\n\n# Example Triton kernel\n@triton.jit\ndef example_kernel(X, Y, Z, BLOCK_SIZE: tl.constexpr):\n # Triton kernel code\n pass\n\n# Function to call the Triton kernel\ndef call_example_kernel(x, y, z, block_size):\n # Call the Triton kernel\n example_kernel[(1,)](x, y, z, BLOCK_SIZE=block_size)\n", - "description_1": "Use triton language to define a kernel 'example_kernel' with four parameters: X, Y, Z, and BLOCK_SIZE. The kernel performs operations on input tensors X, Y, and Z with a specified block size. A separate function 'call_example_kernel' is used to invoke this kernel with the given parameters.", - "description_2": "Use triton language to define a kernel and a function to invoke it, processing input tensors with a specified block size.", - "difficulty": 1 - }, - { - "code": "import triton\nimport triton.language\n\n# Kernel to compute the absolute value of a tensor\n@triton.jit\ndef abs(x):\n return triton.language.where(x >= 0, x, -x)\n\n# Kernel to compute the ceiling division of two tensors\n@triton.jit\ndef cdiv(x, div):\n return (x + div - 1) // div\n\n# Kernel to compute the element-wise minimum of two tensors\n@triton.jit\ndef minimum(x, y):\n return triton.language.where(x < y, x, y)\n\n# Kernel to compute the element-wise maximum of two tensors\n@triton.jit\ndef maximum(x, y):\n return triton.language.where(x > y, x, y)\n\n# Kernel to compute the sigmoid function of a tensor\n@triton.jit\ndef sigmoid(x):\n return 1 / (1 + triton.language.exp(-x))\n\n# Kernel to compute the softmax function of a tensor\n@triton.jit\ndef softmax(x, ieee_rounding=False):\n z = x - triton.language.max(x, 0)\n num = triton.language.exp(z)\n den = triton.language.sum(num, 0)\n return triton.language.fdiv(num, den, ieee_rounding)\n\n# Kernel to flatten a tensor\n@triton.jit\ndef ravel(x):\n return triton.language.view(x, [x.numel])\n\n# Kernel to transform indices of a matrix\n@triton.jit\ndef swizzle2d(i, j, size_i, size_j, size_g):\n ij = i * size_j + j\n size_gj = size_g * size_j\n group_id = ij // size_gj\n off_i = group_id * size_g\n size_g = minimum(size_i - off_i, size_g)\n new_i = off_i + (ij % size_g)\n new_j = (ij % size_gj) // size_g\n return new_i, new_j\n\n# Kernel to create a tensor filled with zeros\n@triton.jit\ndef zeros(shape, dtype):\n return triton.language.full(shape, 0, dtype)\n\n# Kernel to create a tensor filled with zeros like another tensor\n@triton.jit\ndef zeros_like(input):\n return zeros(input.shape, input.dtype)\n", - "description_1": "Use triton language to define kernels for computing absolute values, ceiling division, element-wise minimum and maximum, sigmoid, softmax, flattening a tensor, transforming matrix indices, and creating zero-filled tensors.", - "description_2": "Use triton language to implement mathematical operations and tensor manipulations such as abs, cdiv, minimum, maximum, sigmoid, softmax, ravel, swizzle2d, zeros, and zeros_like.", - "difficulty": 3 - }, - { - "code": "import triton\nimport triton.language as tl\n\nPHILOX_KEY_A: tl.constexpr = 0x9E3779B9\nPHILOX_KEY_B: tl.constexpr = 0xBB67AE85\nPHILOX_ROUND_A: tl.constexpr = 0xD2511F53\nPHILOX_ROUND_B: tl.constexpr = 0xCD9E8D57\nN_ROUNDS_DEFAULT = 10 # Default number of rounds for philox\n\n@triton.jit\ndef philox_impl(c0, c1, c2, c3, k0, k1, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):\n \"\"\"\n Run `n_rounds` rounds of Philox for state (c0, c1, c2, c3) and key (k0, k1).\n \"\"\"\n for _ in range(n_rounds):\n A = PHILOX_ROUND_A\n B = PHILOX_ROUND_B\n _c0, _c2 = c0, c2\n c0 = tl.umulhi(B, _c2) ^ c1 ^ k0\n c2 = tl.umulhi(A, _c0) ^ c3 ^ k1\n c1 = B * _c2\n c3 = A * _c0\n k0 = k0 + PHILOX_KEY_A\n k1 = k1 + PHILOX_KEY_B\n return c0, c1, c2, c3\n\n@triton.jit\ndef philox(seed, c0, c1, c2, c3, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):\n seed = seed.to(tl.uint64)\n seed_hi = ((seed >> 32) & 0xffffffff).to(tl.uint32)\n seed_lo = (seed & 0xffffffff).to(tl.uint32)\n return philox_impl(c0, c1, c2, c3, seed_lo, seed_hi, n_rounds)\n\n@triton.jit\ndef randint(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):\n \"\"\"\n Given a `seed` scalar and an `offset` block, returns a single block of random `int32`.\n \"\"\"\n ret, _, _, _ = randint4x(seed, offset, n_rounds)\n return ret\n\n@triton.jit\ndef randint4x(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):\n \"\"\"\n Given a `seed` scalar and an `offset` block, returns four blocks of random `int32`.\n \"\"\"\n _0 = offset * 0\n return philox(seed, offset, _0, _0, _0, n_rounds)\n\n@triton.jit\ndef uint32_to_uniform_float(x):\n \"\"\"\n Numerically stable function to convert a random uint32 into a random float uniformly sampled in [0, 1).\n \"\"\"\n x = x.to(tl.int32, bitcast=True)\n scale = 4.6566127342e-10\n x = tl.where(x < 0, -x - 1, x)\n return x * scale\n\n@triton.jit\ndef rand(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):\n \"\"\"\n Given a `seed` scalar and an `offset` block, returns a block of random `float32` in U(0, 1).\n \"\"\"\n offset = offset.to(tl.uint32, bitcast=True)\n source = randint(seed, offset, n_rounds)\n return uint32_to_uniform_float(source)\n\n@triton.jit\ndef rand4x(seed, offsets, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):\n \"\"\"\n Given a `seed` scalar and an `offsets` block, returns 4 blocks of random `float32` in U(0, 1).\n \"\"\"\n offsets = offsets.to(tl.uint32, bitcast=True)\n i1, i2, i3, i4 = randint4x(seed, offsets, n_rounds)\n u1 = uint32_to_uniform_float(i1)\n u2 = uint32_to_uniform_float(i2)\n u3 = uint32_to_uniform_float(i3)\n u4 = uint32_to_uniform_float(i4)\n return u1, u2, u3, u4\n\n@triton.jit\ndef pair_uniform_to_normal(u1, u2):\n \"\"\"Box-Muller transform\"\"\"\n u1 = tl.maximum(1.0e-7, u1)\n th = 6.283185307179586 * u2\n r = tl.sqrt(-2.0 * tl.log(u1))\n return r * tl.cos(th), r * tl.sin(th)\n\n@triton.jit\ndef randn(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):\n \"\"\"\n Given a `seed` scalar and an `offset` block, returns a block of random `float32` in N(0, 1).\n \"\"\"\n i1, i2, _, _ = randint4x(seed, offset, n_rounds)\n u1 = uint32_to_uniform_float(i1)\n u2 = uint32_to_uniform_float(i2)\n n1, _ = pair_uniform_to_normal(u1, u2)\n return n1\n\n@triton.jit\ndef randn4x(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):\n \"\"\"\n Given a `seed` scalar and an `offset` block, returns 4 blocks of random `float32` in N(0, 1).\n \"\"\"\n u1, u2, u3, u4 = rand4x(seed, offset, n_rounds)\n n1, n2 = pair_uniform_to_normal(u1, u2)\n n3, n4 = pair_uniform_to_normal(u3, u4)\n return n1, n2, n3, n4\n", - "description_1": "Use triton language to implement several random number generation functions using the Philox algorithm. `philox_impl` and `philox` functions generate pseudo-random numbers given initial states and seeds. `randint` and `randint4x` generate one and four blocks of random integers, respectively. `uint32_to_uniform_float` converts random integers to floats in the range [0, 1). `rand` and `rand4x` return blocks of random floats in U(0, 1), while `randn` and `randn4x` return random floats in N(0, 1) using the Box-Muller transform.", - "description_2": "Use triton language to create random number generators utilizing the Philox algorithm for generating both integer and floating-point random values. Implement conversions to uniform and normal distributions.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n# ********************************************************\n# --------------------------------------------------------\n# Sparse = Dense x Dense (SDD)\n# --------------------------------------------------------\n# ********************************************************\n\n@triton.jit\ndef _sdd_kernel(\n A, B, C,\n stride_za, stride_ha, stride_ma, stride_ak,\n stride_zb, stride_hb, stride_bk, stride_nb,\n stride_zc, stride_hc, stride_mc, stride_nc,\n K, grid_offset, lut,\n TILE_M: tl.constexpr, TILE_N: tl.constexpr, TILE_K: tl.constexpr,\n BLOCK: tl.constexpr, EVEN_K: tl.constexpr\n):\n block_id = tl.program_id(1) + grid_offset\n lut += block_id * 3\n off_z = tl.program_id(2)\n off_h = tl.load(lut + 0)\n\n start_am = tl.load(lut + 1)\n offs_am = start_am * BLOCK + (tl.arange(0, TILE_M) % BLOCK)\n offs_ak = tl.arange(0, TILE_K)\n a_ptrs = A \\\n + off_z * stride_za \\\n + off_h * stride_ha \\\n + offs_am[:, None] * stride_ma \\\n + offs_ak[None, :] * stride_ak\n\n start_bn = tl.load(lut + 2)\n offs_bn = start_bn * BLOCK + (tl.arange(0, TILE_N) % BLOCK)\n offs_bk = tl.arange(0, TILE_K)\n b_ptrs = B \\\n + off_z * stride_zb \\\n + off_h * stride_hb \\\n + offs_bn[None, :] * stride_nb \\\n + offs_bk[:, None] * stride_bk\n\n acc = tl.zeros((TILE_M, TILE_N), dtype=tl.float32)\n for k in range(K, 0, -TILE_K):\n if EVEN_K:\n a = tl.load(a_ptrs)\n b = tl.load(b_ptrs)\n else:\n a = tl.load(a_ptrs, mask=offs_ak[None, :] < k, other=0.)\n b = tl.load(b_ptrs, mask=offs_bk[:, None] < k, other=0.)\n acc += tl.dot(a, b)\n a_ptrs += TILE_K * stride_ak\n b_ptrs += TILE_K * stride_bk\n c = acc.to(C.dtype.element_ty)\n\n offs_cm = tl.arange(0, TILE_M) % BLOCK\n offs_cn = tl.arange(0, TILE_N) % BLOCK\n pc = C \\\n + off_z * stride_zc \\\n + block_id * stride_hc \\\n + offs_cm[:, None] * stride_mc \\\n + offs_cn[None, :] * stride_nc\n tl.store(pc, c, mask=True)\n\ndef sdd_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, widths, out=None):\n if a.stride(2) != 1 and a.stride(3) != 1:\n a = a.contiguous()\n if b.stride(2) != 1 and b.stride(3) != 1:\n b = b.contiguous()\n if trans_c:\n a, b = b, a\n trans_a, trans_b = not trans_b, not trans_a\n a_dim = -2 if trans_a else -1\n b_dim = -1 if trans_b else -2\n Ka, Kb = a.shape[a_dim], b.shape[b_dim]\n if Ka != Kb:\n raise ValueError(f\"Inner dimension mismatch (A: {Ka} vs B: {Kb})\")\n if out is None:\n c = torch.empty((a.shape[0], lut.shape[0], block, block), dtype=a.dtype, device=a.device)\n else:\n assert out.shape == (a.shape[0], lut.shape[0], block, block)\n c = out\n grid = [1, c.shape[1], c.shape[0]]\n _sdd_kernel[grid](\n a, b, c,\n a.stride(0), a.stride(1), a.stride(3 if trans_a else 2), a.stride(2 if trans_a else 3),\n b.stride(0), b.stride(1), b.stride(3 if trans_b else 2), b.stride(2 if trans_b else 3),\n c.stride(0), c.stride(1), c.stride(2), c.stride(3),\n Ka, 0, lut,\n TILE_M=block, TILE_N=block, TILE_K=32, BLOCK=block, num_stages=4,\n num_warps=4,\n )\n return c\n\n@triton.jit\ndef _dsd_kernel(\n A, B, C,\n stride_az, stride_ha, stride_am, stride_ak,\n stride_zb, stride_hb, stride_bk, stride_bn,\n stride_zc, stride_hc, stride_cm, stride_cn,\n DS0, DS1, lut,\n TILE_M: tl.constexpr, TILE_N: tl.constexpr, TILE_K: tl.constexpr,\n GROUP_SIZE_M: tl.constexpr, BLOCK: tl.constexpr\n):\n pid_m = tl.program_id(0)\n pid_n = tl.program_id(1)\n num_pid_m = tl.num_programs(0)\n num_pid_n = tl.num_programs(1)\n pid_n, pid_m = tl.swizzle2d(pid_n, pid_m, num_pid_n, num_pid_m, GROUP_SIZE_M)\n pidz = tl.program_id(2)\n header = lut + pid_n * 4\n offset = tl.load(header + 0)\n K = tl.load(header + 1)\n column = tl.load(header + 2)\n off_h = tl.load(header + 3)\n pinc = lut + offset\n block_id = tl.load(pinc + 1)\n block_id = tl.multiple_of(block_id, 8)\n offs_am = tl.arange(0, TILE_M)\n offs_ak = tl.arange(0, TILE_K)\n pa = A + pidz * stride_az \\\n + block_id * stride_ha \\\n + offs_am[:, None] * stride_am \\\n + offs_ak[None, :] * stride_ak\n offs_bn = pid_m * TILE_N + tl.arange(0, TILE_N)\n offs_bn = tl.max_contiguous(tl.multiple_of(offs_bn % DS0, TILE_N), TILE_N)\n start_bk = tl.load(pinc)\n start_bk = tl.multiple_of(start_bk, 8)\n offs_bk = start_bk + tl.arange(0, TILE_K)\n pb = B + pidz * stride_zb \\\n + off_h * stride_hb \\\n + offs_bn[None, :] * stride_bn \\\n + offs_bk[:, None] * stride_bk\n acc = tl.zeros((TILE_M, TILE_N), dtype=tl.float32)\n pinc += 2\n inc_a = tl.load(pinc + 1)\n inc_a = tl.multiple_of(inc_a, 8)\n inc_b = tl.load(pinc)\n inc_b = tl.multiple_of(inc_b, 8)\n for k in range(K, 0, -TILE_K):\n a = tl.load(pa, mask=True)\n b = tl.load(pb, mask=offs_bn[None, :] < DS0)\n acc += tl.dot(a, b)\n pa += inc_a\n pb += inc_b * stride_bk\n pinc += 2\n inc_a = tl.load(pinc + 1)\n inc_a = tl.multiple_of(inc_a, 8)\n inc_b = tl.load(pinc)\n inc_b = tl.multiple_of(inc_b, 8)\n c = acc.to(C.dtype.element_ty)\n offs_cm = column * TILE_M + tl.arange(0, TILE_M)\n offs_cn = pid_m * TILE_N + tl.arange(0, TILE_N)\n pc = C \\\n + off_h * stride_hc \\\n + pidz * stride_zc \\\n + offs_cm[:, None] * stride_cm \\\n + offs_cn[None, :] * stride_cn\n tl.store(pc, c, mask=offs_cn[None, :] < DS0)\n\ndef dsd_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, width, out=None):\n if a.stride(2) != 1 and a.stride(3) != 1:\n a = a.contiguous()\n if b.stride(2) != 1 and b.stride(3) != 1:\n b = b.contiguous()\n AS1 = block * spdims[2 if trans_a else 1]\n BS0 = b.size(0)\n BS1 = b.size(1)\n BS3 = b.size(2 if trans_b else 3)\n dtype = a.dtype\n CS0 = BS0\n CS1 = BS1\n CS2 = BS3 if trans_c else AS1\n CS3 = AS1 if trans_c else BS3\n if out is None:\n c = torch.empty((CS0, CS1, CS2, CS3), dtype=dtype, device=a.device)\n else:\n assert out.shape == (CS0, CS1, CS2, CS3)\n c = out\n TILE_N = 128\n grid = lambda meta: [triton.cdiv(BS3, meta['TILE_N']), width, BS0]\n _dsd_kernel[grid](\n a, b, c,\n a.stride(0), a.stride(1), a.stride(3 if trans_a else 2), a.stride(2 if trans_a else 3),\n b.stride(0), b.stride(1), b.stride(3 if trans_b else 2), b.stride(2 if trans_b else 3),\n c.stride(0), c.stride(1), c.stride(3 if trans_c else 2), c.stride(2 if trans_c else 3),\n BS3, AS1, lut,\n TILE_M=block, TILE_N=TILE_N, TILE_K=min(block, 32), BLOCK=block, num_stages=4,\n num_warps=4, GROUP_SIZE_M=4,\n )\n return c\n", - "description_1": "Use triton language to implement two matrix multiplication kernels: _sdd_kernel for multiplying sparse matrix A with dense matrices B and C, and _dsd_kernel for multiplying dense matrix A with sparse matrix B. Both functions require several parameters, including tensor data, strides, dimensions, and layout information in the form of a lookup table (lut).", - "description_2": "Use triton language to implement sparse-dense and dense-sparse matrix multiplications with lookup tables for efficient computation.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef _blocksparse_softmax_fwd(\n Out, A, stride_xz, LUT,\n R, extent, stride_zr, stride_hr, # relative attention\n scale, is_causal,\n ROW_SIZE: tl.constexpr,\n BLOCK_SIZE: tl.constexpr,\n IS_DENSE: tl.constexpr,\n):\n h = tl.program_id(0)\n m = tl.program_id(1)\n z = tl.program_id(2)\n # create index ranges\n hm = h * tl.num_programs(1) + m\n lane_n = tl.arange(0, ROW_SIZE) % BLOCK_SIZE\n block_n = tl.arange(0, ROW_SIZE) // BLOCK_SIZE\n # extract information from LUT\n header = LUT + (hm // BLOCK_SIZE) * 2\n size = tl.load(header + 0)\n offset = tl.load(header + 1)\n # pointer offset\n off_a = z * stride_xz\n off_a += (offset + block_n) * BLOCK_SIZE * BLOCK_SIZE # block indx\n off_a += (m % BLOCK_SIZE) * BLOCK_SIZE # row indx\n # do not need to read column indices in the dense case\n if IS_DENSE:\n ns = tl.arange(0, ROW_SIZE)\n else:\n off_lut = offset + 2 * tl.num_programs(0) * tl.num_programs(1) // BLOCK_SIZE\n start_n = tl.load(LUT + off_lut + block_n, mask=block_n < size, other=0)\n ns = start_n * BLOCK_SIZE + lane_n\n # load X\n mask = block_n < size\n a = tl.load(A + off_a + lane_n, mask=mask, other=-float(\"inf\"))\n a = a.to(tl.float32)\n # compute\n out = a\n out *= scale\n # apply relative attention\n if R is not None:\n R += z * stride_zr\n R += h * stride_hr\n off_lo = (extent - m - 1) + ns\n mask_lo = (off_lo >= 0) & (off_lo < extent)\n rel_logits = tl.load(R + m * extent + off_lo, mask=mask_lo, other=0.0)\n out += rel_logits\n out = out.to(tl.float32)\n # apply causal mask\n out = tl.where((ns > m) & is_causal, -float(\"inf\"), out)\n # computation\n out = tl.softmax(out)\n # write-back\n tl.store(Out + off_a + lane_n, out, mask=mask)\n\n@triton.jit\ndef _blocksparse_softmax_bwd(\n DA, stride_zdx,\n DOut, stride_zdout,\n Out, stride_zout,\n scale,\n LUT,\n DR, extent, stride_zr, stride_hr, stride_er,\n is_causal,\n ROW_SIZE: tl.constexpr,\n BLOCK_SIZE: tl.constexpr,\n IS_DENSE: tl.constexpr,\n):\n h = tl.program_id(0)\n m = tl.program_id(1)\n z = tl.program_id(2)\n # create index ranges\n hm = h * tl.num_programs(1) + m\n lane_n = tl.arange(0, ROW_SIZE) % BLOCK_SIZE\n block_n = tl.arange(0, ROW_SIZE) // BLOCK_SIZE\n # extract information from LUT\n header = LUT + (hm // BLOCK_SIZE) * 2\n size = tl.load(header + 0)\n offset = tl.load(header + 1)\n # row-col offset\n off_mn = (offset + block_n) * BLOCK_SIZE * BLOCK_SIZE\n off_mn += (m % BLOCK_SIZE) * BLOCK_SIZE\n mask = block_n < size\n # pointers\n As = Out + z * stride_zout + off_mn\n DOuts = DOut + z * stride_zdout + off_mn\n # do not need to read column indices in the dense case\n if IS_DENSE:\n ns = tl.arange(0, ROW_SIZE)\n else:\n off_lut = offset + 2 * tl.num_programs(0) * tl.num_programs(1) // BLOCK_SIZE\n start_n = tl.load(LUT + off_lut + block_n, mask=mask, other=0)\n ns = start_n * BLOCK_SIZE + lane_n\n # load data\n a = tl.load(As + lane_n, mask=mask, other=0.0)\n a = a.to(tl.float32)\n dout = tl.load(DOuts + lane_n, mask=mask, other=0.0)\n dout = dout.to(tl.float32)\n # compute\n a = tl.where((ns > m) & is_causal & (a == a), 0., a)\n da = a * (dout - tl.sum(a * dout, 0))\n # apply relative attention\n if DR is not None:\n DR += z * stride_zr\n DR += h * stride_hr\n off_lo = (extent - m - 1) + ns\n mask_lo = (off_lo >= 0) & (off_lo < extent) & mask\n tl.store(DR + m * extent + off_lo, da, mask=mask_lo)\n da = da * scale\n # convert da\n # write-back\n DAs = DA + z * stride_zdx + off_mn\n tl.store(DAs + lane_n, da, mask=mask)\n\nclass _softmax(torch.autograd.Function):\n @staticmethod\n def forward(\n ctx, a, scale, rel_logits, is_causal,\n spdims, block, lut, maxlut, is_dense\n ):\n if scale is not None and isinstance(scale, torch.Tensor):\n assert scale.device.type == \"cpu\"\n scale = scale.item()\n M = a.shape[0]\n grid = [spdims[0], spdims[1] * block, M]\n rel_shape = (1, 1, 1, 1) if rel_logits is None else rel_logits.shape\n rel_strides = (1, 1, 1, 1) if rel_logits is None else rel_logits.stride()\n # enqueue kernel\n out = torch.empty_like(a)\n _blocksparse_softmax_fwd[grid](\n out, a, a.stride(0), lut,\n rel_logits, rel_shape[-1], rel_strides[0], rel_strides[1], # relative attn\n scale,\n is_causal,\n BLOCK_SIZE=block,\n ROW_SIZE=triton.next_power_of_2(maxlut),\n IS_DENSE=is_dense,\n num_warps=num_warps(maxlut)\n )\n # save to context\n ctx.save_for_backward(out, lut)\n ctx.spdims = spdims\n ctx.block = block\n ctx.maxlut = maxlut\n ctx.scale = scale\n ctx.rel_shape = rel_shape\n ctx.rel_strides = rel_strides\n ctx.rel_dtype = a.dtype\n ctx.is_dense = is_dense\n ctx.is_causal = is_causal\n return out\n\n @staticmethod\n def backward(ctx, dout):\n # retrieve from context\n out, lut = ctx.saved_tensors\n # relative logits gradients\n dr = None\n if ctx.needs_input_grad[3]:\n dr = torch.zeros(ctx.rel_shape, dtype=ctx.rel_dtype, device=out.device)\n # run kernel\n M = out.shape[0]\n grid = (ctx.spdims[0], ctx.spdims[1] * ctx.block, M)\n da = torch.empty_like(dout)\n _blocksparse_softmax_bwd[grid](\n da, da.stride(0),\n dout, dout.stride(0),\n out, out.stride(0),\n ctx.scale,\n lut,\n dr, ctx.rel_shape[-1], ctx.rel_strides[0], ctx.rel_strides[1], ctx.rel_strides[2],\n ctx.is_causal,\n BLOCK_SIZE=ctx.block,\n ROW_SIZE=triton.next_power_of_2(ctx.maxlut),\n IS_DENSE=ctx.is_dense,\n num_warps=num_warps(ctx.maxlut)\n )\n return (da, None, None, dr, None,\n None, None, None, None, None,\n None,\n None, None, None,\n None,\n None, None, None\n )\n\nclass softmax:\n def __init__(self, layout, block, device, is_dense=False):\n self.spdims = layout.shape\n self.layout = layout\n self.block = block\n self.lut, self.maxlut = _softmax.make_lut(self.layout, self.block, device)\n self.is_dense = is_dense\n\n def __call__(self, a, *, scale=1.0, rel_logits=None, is_causal=False):\n if rel_logits is not None and rel_logits.dtype != a.dtype:\n raise ValueError(\"relative position embedding must be %s\" % a.dtype)\n a = _softmax.apply(\n a, scale, rel_logits, is_causal,\n self.spdims, self.block, self.lut, self.maxlut, self.is_dense,\n )\n return a\n", - "description_1": "Use triton language to implement a block-sparse softmax forward and backward kernel. The forward kernel '_blocksparse_softmax_fwd' takes 12 parameters: Out (output tensor), A (input tensor), stride_xz (stride for input tensor), LUT (lookup table), R (relative attention tensor), extent (extent of relative attention), stride_zr (stride for relative attention), stride_hr (stride for relative attention), scale (scaling factor), is_causal (causal flag), ROW_SIZE (row size), BLOCK_SIZE (block size), and IS_DENSE (density flag). The backward kernel '_blocksparse_softmax_bwd' takes 16 parameters: DA (gradient of input tensor), stride_zdx (stride for DA), DOut (gradient of output tensor), stride_zdout (stride for DOut), Out (output tensor), stride_zout (stride for Out), scale (scaling factor), LUT (lookup table), DR (gradient of relative attention), extent (extent of relative attention), stride_zr (stride for relative attention), stride_hr (stride for relative attention), stride_er (stride for relative attention), is_causal (causal flag), ROW_SIZE (row size), BLOCK_SIZE (block size), and IS_DENSE (density flag). The '_softmax' class is a PyTorch autograd function that uses these kernels for forward and backward passes, and the 'softmax' class is a wrapper for using this functionality.", - "description_2": "Use triton language to create block-sparse softmax operations with forward and backward kernels, handling relative attention and causal masking.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\ndef next_power_of_2(n):\n n -= 1\n n |= n >> 1\n n |= n >> 2\n n |= n >> 4\n n |= n >> 8\n n |= n >> 16\n n += 1\n return n\n\ndef num_warps(N):\n if N < 2048:\n return 4\n elif N < 8192:\n return 8\n return 16\n\n@triton.heuristics({'num_warps': lambda nargs: num_warps(nargs['N'])})\n@triton.heuristics({'BLOCK': lambda nargs: next_power_of_2(nargs['N'])})\n@triton.jit\ndef _forward(LOGITS, PROBS, IDX, LOSS, N, BLOCK: tl.constexpr):\n row = tl.program_id(0)\n cols = tl.arange(0, BLOCK)\n idx = tl.load(IDX + row)\n # pointers to logit and probs\n LOGITS = LOGITS + row * N + cols\n WRIT_PROBS = PROBS + row * N + cols\n READ_PROBS = PROBS + row * N + idx\n # write-back negative log-probs\n logits = tl.load(LOGITS, mask=cols < N, other=-float('inf'))\n logits = logits.to(tl.float32)\n logits = logits - tl.max(logits, 0)\n probs = tl.log(tl.sum(tl.exp(logits), 0)) - logits\n tl.store(WRIT_PROBS, probs, mask=cols < N)\n # There is a bug in the compiler, which fails to insert a barrier here.\n # We add it explicitly for now. Will be fixed soon.\n tl.debug_barrier()\n # write-back loss\n probs = tl.load(READ_PROBS)\n tl.store(LOSS + row, probs)\n\n@triton.heuristics({'num_warps': lambda nargs: num_warps(nargs['N'])})\n@triton.heuristics({'BLOCK': lambda nargs: next_power_of_2(nargs['N'])})\n@triton.jit\ndef _backward(PROBS, IDX, DPROBS, N, BLOCK: tl.constexpr):\n row = tl.program_id(0)\n cols = tl.arange(0, BLOCK)\n idx = tl.load(IDX + row)\n # pointers to probs\n PROBS = PROBS + row * N + cols\n # We know d(-log(p[i])/dlogit[k] = -id_mat[i,k] + p[k]\n # and we have -log(p[k]) stored in PROBS, so this is easy\n probs = -tl.load(PROBS, mask=cols < N, other=float('inf'))\n probs = tl.exp(probs.to(tl.float32))\n delta = cols == idx\n # write result in-place in PROBS\n dout = tl.load(DPROBS + row)\n din = (probs - delta) * dout\n tl.store(PROBS, din.to(PROBS.dtype.element_ty), mask=cols < N)\n\nclass _cross_entropy(torch.autograd.Function):\n @classmethod\n def forward(cls, ctx, logits, indices):\n # make sure we can use triton\n assert (indices.dtype == torch.int64), \"Indices are expected to be of type long.\"\n # make kernel\n device, dtype = logits.device, logits.dtype\n n_cols = logits.shape[-1]\n # run the kernel\n result = torch.empty_like(indices, dtype=dtype, device=device)\n neg_logprobs = torch.empty_like(logits, dtype=dtype, device=device)\n grid = lambda opt: (logits.numel() // n_cols, )\n _forward[grid](logits, neg_logprobs, indices, result, n_cols)\n # save for backward\n ctx.save_for_backward(neg_logprobs, indices)\n return result\n\n @classmethod\n def backward(cls, ctx, dneg_logprobs):\n \"\"\"We know d(-log(p[i])/dlogit[k] = -id_mat[i,k] + p[k]\n so we initialize the gradient as neg_logprobs, so we can just exponentiate\n to get p[k], which is most of what we need... neg_logprobs will be\n modified in place to become the gradient we want\n \"\"\"\n # load saved tensors\n neg_logprobs, indices = ctx.saved_tensors\n # run the kernel\n # neg_logprobs will be modified in place to become our gradient:\n n_cols = neg_logprobs.shape[-1]\n grid = lambda opt: (neg_logprobs.numel() // n_cols, )\n _backward[grid](neg_logprobs, indices, dneg_logprobs, n_cols)\n return neg_logprobs, None\n\ncross_entropy = _cross_entropy.apply\n", - "description_1": "Use triton language to implement a cross-entropy loss function with two kernels: _forward and _backward. The _forward kernel computes negative log-probabilities and loss, taking 6 parameters: LOGITS (input logits), PROBS (output probabilities), IDX (indices), LOSS (output loss), N (number of columns), and BLOCK (block size). The _backward kernel computes gradients, taking 5 parameters: PROBS (input/output probabilities), IDX (indices), DPROBS (input gradients), N (number of columns), and BLOCK (block size). The _cross_entropy class wraps these kernels for use in PyTorch's autograd system.", - "description_2": "Use triton language to create a cross-entropy loss function with forward and backward kernels, handling logits, probabilities, indices, and gradients.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.heuristics({\n 'EVEN_K': lambda args: args['K'] % (args['BLOCK_K'] * args['SPLIT_K']) == 0,\n})\n@triton.jit\ndef _kernel(A, B, C, M, N, K,\n stride_am, stride_ak,\n stride_bk, stride_bn,\n stride_cm, stride_cn,\n BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,\n GROUP_M: tl.constexpr, SPLIT_K: tl.constexpr, EVEN_K: tl.constexpr,\n ACC_TYPE: tl.constexpr\n ):\n # matrix multiplication\n pid = tl.program_id(0)\n pid_z = tl.program_id(1)\n grid_m = (M + BLOCK_M - 1) // BLOCK_M\n grid_n = (N + BLOCK_N - 1) // BLOCK_N\n # re-order program ID for better L2 performance\n width = GROUP_M * grid_n\n group_id = pid // width\n group_size = min(grid_m - group_id * GROUP_M, GROUP_M)\n pid_m = group_id * GROUP_M + (pid % group_size)\n pid_n = (pid % width) // (group_size)\n # do matrix multiplication\n rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)\n rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)\n ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)\n rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)\n rk = pid_z * BLOCK_K + tl.arange(0, BLOCK_K)\n # pointers\n A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)\n B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)\n acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)\n for k in range(K, 0, -BLOCK_K * SPLIT_K):\n if EVEN_K:\n a = tl.load(A)\n b = tl.load(B)\n else:\n a = tl.load(A, mask=rk[None, :] < k, other=0.)\n b = tl.load(B, mask=rk[:, None] < k, other=0.)\n acc += tl.dot(a, b)\n A += BLOCK_K * SPLIT_K * stride_ak\n B += BLOCK_K * SPLIT_K * stride_bk\n acc = acc.to(C.dtype.element_ty)\n # rematerialize rm and rn to save registers\n rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)\n rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)\n C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)\n mask = (rm < M)[:, None] & (rn < N)[None, :]\n # handles write-back with reduction-splitting\n if SPLIT_K == 1:\n tl.store(C, acc, mask=mask)\n else:\n tl.atomic_add(C, acc, mask=mask)\n\n\nclass _matmul(torch.autograd.Function):\n kernel = _kernel\n\n _locks = dict()\n\n @staticmethod\n def _call(a, b):\n device = a.device\n # handle non-contiguous inputs if necessary\n if a.stride(0) > 1 and a.stride(1) > 1:\n a = a.contiguous()\n if b.stride(0) > 1 and b.stride(1) > 1:\n b = b.contiguous()\n # checks constraints\n assert a.shape[1] == b.shape[0], \"incompatible dimensions\"\n M, K = a.shape\n _, N = b.shape\n # allocates output\n c = torch.empty((M, N), device=device, dtype=a.dtype)\n # accumulator types\n ACC_TYPE = tl.float32 if a.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32\n # launch kernel\n grid = lambda META: (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), META['SPLIT_K'])\n _kernel[grid](a, b, c, M, N, K,\n a.stride(0), a.stride(1),\n b.stride(0), b.stride(1),\n c.stride(0), c.stride(1),\n GROUP_M=8, ACC_TYPE=ACC_TYPE)\n return c\n\n @staticmethod\n def forward(ctx, a, b):\n return _matmul._call(a, b)\n\n\nmatmul = _matmul.apply\n", - "description_1": "Use triton language to create a matrix multiplication kernel with parameters for input matrices A, B, and output matrix C, dimensions M, N, K, strides for A, B, C, block sizes, group size, split factor, even K heuristic, and accumulation type. A wrapper function handles non-contiguous inputs and launches the kernel, checking constraints and setting the accumulator type based on input data types.", - "description_2": "Use triton language to implement a kernel for matrix multiplication with specific configurations and a wrapper to handle non-contiguous inputs and launch the kernel with necessary constraints.", - "difficulty": 4 - }, - { - "code": "import triton\nimport torch\n\n# Define the Triton kernel\n@triton.jit\ndef kernel(x_ptr, x_size, **META):\n BLOCK_SIZE = META['BLOCK_SIZE']\n # Kernel code would go here\n\n# Define a function to run the kernel\ndef run_kernel(x_ptr, x_size):\n kernel[(1,)](x_ptr, x_size, META={'BLOCK_SIZE': 128})\n\n# Example code to call the kernel\nx_size = torch.tensor(1024)\nx_ptr = torch.empty(x_size)\nrun_kernel(x_ptr, x_size)\n", - "description_1": "Use triton language to define a kernel function 'kernel' that takes two arguments: 'x_ptr' (a pointer) and 'x_size' (an integer). It also accepts additional meta-parameters through '**META'. The function uses a meta-parameter 'BLOCK_SIZE' within its body. Additionally, implement a function 'run_kernel' that executes this Triton kernel with specified values.", - "description_2": "Use triton language to define a kernel with pointer and size inputs and run it with metadata.", - "difficulty": 1 - }, - { - "code": "import torch\nimport triton\n\n@triton.jit\ndef example_kernel(X, Y, Z):\n # Triton kernel that performs element-wise addition of two tensors\n idx = triton.program_id(0)\n X[idx] = Y[idx] + Z[idx]\n\ndef call_example_kernel(X, Y, Z):\n # Function to call the Triton kernel\n grid = (X.numel(),)\n example_kernel[grid](X, Y, Z)\n", - "description_1": "Use triton language to define a kernel 'example_kernel' that takes three parameters: X, Y, and Z. The kernel performs element-wise addition of tensors Y and Z, storing the result in X. The kernel is called using 'call_example_kernel', which sets up the grid size based on the number of elements in X and invokes the kernel.", - "description_2": "Use triton language to create a kernel for element-wise addition of two tensors and a function to call this kernel.", - "difficulty": 2 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef add_kernel(\n x_ptr, # *Pointer* to first input vector.\n y_ptr, # *Pointer* to second input vector.\n output_ptr, # *Pointer* to output vector.\n n_elements, # Size of the vector.\n BLOCK_SIZE: tl.constexpr, # Number of elements each program should process.\n):\n pid = tl.program_id(axis=0) # We use a 1D launch grid so axis is 0.\n block_start = pid * BLOCK_SIZE\n offsets = block_start + tl.arange(0, BLOCK_SIZE)\n mask = offsets < n_elements\n x = tl.load(x_ptr + offsets, mask=mask)\n y = tl.load(y_ptr + offsets, mask=mask)\n output = x + y\n tl.store(output_ptr + offsets, output, mask=mask)\n\ndef add(x: torch.Tensor, y: torch.Tensor):\n output = torch.empty_like(x)\n assert x.is_cuda and y.is_cuda and output.is_cuda\n n_elements = output.numel()\n grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)\n add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024)\n return output\n\ntorch.manual_seed(0)\nsize = 98432\nx = torch.rand(size, device='cuda')\ny = torch.rand(size, device='cuda')\noutput_torch = x + y\noutput_triton = add(x, y)\nprint(output_torch)\nprint(output_triton)\nprint(\n f'The maximum difference between torch and triton is '\n f'{torch.max(torch.abs(output_torch - output_triton))}'\n)\n", - "description_1": "Use triton language to implement a vector addition kernel. The kernel 'add_kernel' takes five parameters: pointers to the input vectors x and y, a pointer to the output vector, the number of elements in the vectors, and a block size as a compile-time constant. The kernel computes the element-wise sum of x and y, storing the result in the output vector. The 'add' function is a wrapper that prepares the output tensor, sets up the grid for kernel execution, and calls the kernel with the appropriate parameters.", - "description_2": "Use triton language to perform element-wise addition of two vectors on the GPU, utilizing a custom kernel and a wrapper function for execution.", - "difficulty": 2 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.autotune(\n configs=[\n triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8),\n ],\n key=['M', 'N', 'K'],\n)\n@triton.jit\ndef matmul_kernel(\n a_ptr, b_ptr, c_ptr,\n M, N, K,\n stride_am, stride_ak,\n stride_bk, stride_bn,\n stride_cm, stride_cn,\n BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,\n GROUP_SIZE_M: tl.constexpr,\n ACTIVATION: tl.constexpr,\n):\n pid = tl.program_id(axis=0)\n num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)\n num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)\n num_pid_in_group = GROUP_SIZE_M * num_pid_n\n group_id = pid // num_pid_in_group\n first_pid_m = group_id * GROUP_SIZE_M\n group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)\n pid_m = first_pid_m + (pid % group_size_m)\n pid_n = (pid % num_pid_in_group) // group_size_m\n\n offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n offs_k = tl.arange(0, BLOCK_SIZE_K)\n a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)\n b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)\n\n accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n for k in range(0, K, BLOCK_SIZE_K):\n a = tl.load(a_ptrs)\n b = tl.load(b_ptrs)\n accumulator += tl.dot(a, b)\n a_ptrs += BLOCK_SIZE_K * stride_ak\n b_ptrs += BLOCK_SIZE_K * stride_bk\n if ACTIVATION:\n accumulator = ACTIVATION(accumulator)\n c = accumulator.to(tl.float16)\n\n offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]\n c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)\n tl.store(c_ptrs, c, mask=c_mask)\n\n@triton.jit\ndef leaky_relu(x):\n return tl.where(x >= 0, x, 0.01 * x)\n\ndef matmul(a, b, activation=None):\n assert a.shape[1] == b.shape[0], \"incompatible dimensions\"\n assert a.is_contiguous(), \"matrix A must be contiguous\"\n assert b.is_contiguous(), \"matrix B must be contiguous\"\n M, K = a.shape\n K, N = b.shape\n assert (\n K % 32 == 0\n ), \"We don't check memory-out-of-bounds with K so K must be divisible by BLOCK_SIZE_K\"\n c = torch.empty((M, N), device=a.device, dtype=a.dtype)\n grid = lambda META: (\n triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']),\n )\n matmul_kernel[grid](\n a, b, c,\n M, N, K,\n a.stride(0), a.stride(1),\n b.stride(0), b.stride(1),\n c.stride(0), c.stride(1),\n ACTIVATION=activation,\n )\n return c\n", - "description_1": "Use triton language to create a matrix multiplication kernel `matmul_kernel` which multiplies matrix A (shape MxK) with B (shape KxN) using specified block sizes for M, N, K, and an optional activation function. The wrapper function `matmul` checks constraints, prepares data, and invokes the kernel with appropriate grid configuration.", - "description_2": "Use triton language to implement high-performance matrix multiplication with configurable block sizes and optional activation, checking input constraints and launching the computation on the grid.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef _dropout(\n x_ptr, # pointer to the input\n x_keep_ptr, # pointer to a mask of 0s and 1s\n output_ptr, # pointer to the output\n n_elements, # number of elements in the `x` tensor\n p, # probability that an element of `x` is changed to zero\n BLOCK_SIZE: tl.constexpr,\n):\n pid = tl.program_id(axis=0)\n block_start = pid * BLOCK_SIZE\n offsets = block_start + tl.arange(0, BLOCK_SIZE)\n mask = offsets < n_elements\n # Load data\n x = tl.load(x_ptr + offsets, mask=mask)\n x_keep = tl.load(x_keep_ptr + offsets, mask=mask)\n # The line below is the crucial part, described in the paragraph above!\n output = tl.where(x_keep, x / (1 - p), 0.0)\n # Write-back output\n tl.store(output_ptr + offsets, output, mask=mask)\n\n\ndef dropout(x, x_keep, p):\n output = torch.empty_like(x)\n assert x.is_contiguous()\n n_elements = x.numel()\n grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)\n _dropout[grid](x, x_keep, output, n_elements, p, BLOCK_SIZE=1024)\n return output\n\n\n@triton.jit\ndef _seeded_dropout(\n x_ptr,\n output_ptr,\n n_elements,\n p,\n seed,\n BLOCK_SIZE: tl.constexpr,\n):\n # compute memory offsets of elements handled by this instance\n pid = tl.program_id(axis=0)\n block_start = pid * BLOCK_SIZE\n offsets = block_start + tl.arange(0, BLOCK_SIZE)\n # load data from x\n mask = offsets < n_elements\n x = tl.load(x_ptr + offsets, mask=mask)\n # randomly prune it\n random = tl.rand(seed, offsets)\n x_keep = random > p\n # write-back\n output = tl.where(x_keep, x / (1 - p), 0.0)\n tl.store(output_ptr + offsets, output, mask=mask)\n\n\ndef seeded_dropout(x, p, seed):\n output = torch.empty_like(x)\n assert x.is_contiguous()\n n_elements = x.numel()\n grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)\n _seeded_dropout[grid](x, output, n_elements, p, seed, BLOCK_SIZE=1024)\n return output\n\n\nx = torch.randn(size=(10,)).cuda()\n# Dropout mask\np = 0.5\nx_keep = (torch.rand(size=(10,)) > p).to(torch.int32).cuda()\noutput = dropout(x, x_keep=x_keep, p=p)\n\noutput = seeded_dropout(x, p=0.5, seed=123)\noutput2 = seeded_dropout(x, p=0.5, seed=123)\noutput3 = seeded_dropout(x, p=0.5, seed=512)\n", - "description_1": "Use triton language to implement two dropout functions. The first function, _dropout, takes six parameters: pointers to input, mask, and output tensors, the number of elements, dropout probability, and block size. It applies dropout using a precomputed mask. The second function, _seeded_dropout, takes six parameters: pointers to input and output tensors, the number of elements, dropout probability, a seed for random number generation, and block size. It applies dropout using a generated random mask based on the seed.", - "description_2": "Use triton language to implement dropout with a precomputed mask and seeded random mask.", - "difficulty": 2 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef _layer_norm_fwd_fused(\n A, Out, Weight, Bias, Mean, Rstd, stride, N, eps,\n BLOCK_SIZE: tl.constexpr,\n):\n # position of elements processed by this program\n row = tl.program_id(0)\n Out += row * stride\n A += row * stride\n # compute mean\n mean = 0\n _mean = tl.zeros([BLOCK_SIZE], dtype=tl.float32)\n for off in range(0, N, BLOCK_SIZE):\n cols = off + tl.arange(0, BLOCK_SIZE)\n a = tl.load(A + cols, mask=cols < N, other=0.).to(tl.float32)\n _mean += a\n mean = tl.sum(_mean, axis=0) / N\n # compute variance\n _var = tl.zeros([BLOCK_SIZE], dtype=tl.float32)\n for off in range(0, N, BLOCK_SIZE):\n cols = off + tl.arange(0, BLOCK_SIZE)\n a = tl.load(A + cols, mask=cols < N, other=0.).to(tl.float32)\n a = tl.where(cols < N, a - mean, 0.)\n _var += a * a\n var = tl.sum(_var, axis=0) / N\n rstd = 1 / tl.sqrt(var + eps)\n # write-back mean/rstd\n tl.store(Mean + row, mean)\n tl.store(Rstd + row, rstd)\n # multiply by weight and add bias\n for off in range(0, N, BLOCK_SIZE):\n cols = off + tl.arange(0, BLOCK_SIZE)\n mask = cols < N\n weight = tl.load(Weight + cols, mask=mask)\n bias = tl.load(Bias + cols, mask=mask)\n a = tl.load(A + cols, mask=mask, other=0.).to(tl.float32)\n a_hat = (a - mean) * rstd\n out = a_hat * weight + bias\n # # write-back\n tl.store(Out + cols, out, mask=mask)\n\n\n@triton.jit\ndef _layer_norm_bwd_dx_fused(DX, DY, DW, DB, X, W, B, M, V, Lock, stride, N, eps,\n GROUP_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr):\n # position of elements processed by this program\n row = tl.program_id(0)\n cols = tl.arange(0, BLOCK_SIZE_N)\n mask = cols < N\n # offset data pointers to start at the row of interest\n X += row * stride\n DY += row * stride\n DX += row * stride\n lock_id = row % GROUP_SIZE_M\n Lock += lock_id\n Count = Lock + GROUP_SIZE_M\n DW = DW + lock_id * N + cols\n DB = DB + lock_id * N + cols\n # load data to SRAM\n x = tl.load(X + cols, mask=mask, other=0).to(tl.float32)\n dy = tl.load(DY + cols, mask=mask, other=0).to(tl.float32)\n w = tl.load(W + cols, mask=mask).to(tl.float32)\n mean = tl.load(M + row)\n rstd = tl.load(V + row)\n # compute dx\n xhat = (x - mean) * rstd\n wdy = w * dy\n xhat = tl.where(mask, xhat, 0.)\n wdy = tl.where(mask, wdy, 0.)\n mean1 = tl.sum(xhat * wdy, axis=0) / N\n mean2 = tl.sum(wdy, axis=0) / N\n dx = (wdy - (xhat * mean1 + mean2)) * rstd\n # write-back dx\n tl.store(DX + cols, dx, mask=mask)\n # accumulate partial sums for dw/db\n partial_dw = (dy * xhat).to(w.dtype)\n partial_db = (dy).to(w.dtype)\n while tl.atomic_cas(Lock, 0, 1) == 1:\n pass\n count = tl.load(Count)\n # first store doesn't accumulate\n if count == 0:\n tl.atomic_xchg(Count, 1)\n else:\n partial_dw += tl.load(DW, mask=mask)\n partial_db += tl.load(DB, mask=mask)\n tl.store(DW, partial_dw, mask=mask)\n tl.store(DB, partial_db, mask=mask)\n # release lock\n tl.atomic_xchg(Lock, 0)\n\n\n@triton.jit\ndef _layer_norm_bwd_dwdb(DW, DB, FINAL_DW, FINAL_DB, M, N,\n BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr):\n pid = tl.program_id(0)\n cols = pid * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n dw = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n db = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n for i in range(0, M, BLOCK_SIZE_M):\n rows = i + tl.arange(0, BLOCK_SIZE_M)\n mask = (rows[:, None] < M) & (cols[None, :] < N)\n offs = rows[:, None] * N + cols[None, :]\n dw += tl.load(DW + offs, mask=mask, other=0.)\n db += tl.load(DB + offs, mask=mask, other=0.)\n sum_dw = tl.sum(dw, axis=0)\n sum_db = tl.sum(db, axis=0)\n tl.store(FINAL_DW + cols, sum_dw, mask=cols < N)\n tl.store(FINAL_DB + cols, sum_db, mask=cols < N)\n\n\nclass LayerNorm(torch.autograd.Function):\n\n @staticmethod\n def forward(ctx, x, normalized_shape, weight, bias, eps):\n # allocate output\n y = torch.empty_like(x)\n # reshape input data into 2D tensor\n x_arg = x.reshape(-1, x.shape[-1])\n M, N = x_arg.shape\n mean = torch.empty((M, ), dtype=torch.float32, device='cuda')\n rstd = torch.empty((M, ), dtype=torch.float32, device='cuda')\n # Less than 64KB per feature: enqueue fused kernel\n MAX_FUSED_SIZE = 65536 // x.element_size()\n BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))\n if N > BLOCK_SIZE:\n raise RuntimeError(\"This layer norm doesn't support feature dim >= 64KB.\")\n # heuristics for number of warps\n num_warps = min(max(BLOCK_SIZE // 256, 1), 8)\n # enqueue kernel\n _layer_norm_fwd_fused[(M,)](x_arg, y, weight, bias, mean, rstd,\n x_arg.stride(0), N, eps,\n BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps)\n ctx.save_for_backward(x, weight, bias, mean, rstd)\n ctx.BLOCK_SIZE = BLOCK_SIZE\n ctx.num_warps = num_warps\n ctx.eps = eps\n return y\n\n @staticmethod\n def backward(ctx, dy):\n x, w, b, m, v = ctx.saved_tensors\n # heuristics for amount of parallel reduction stream for DG/DB\n N = w.shape[0]\n GROUP_SIZE_M = 64\n if N <= 8192: GROUP_SIZE_M = 96\n if N <= 4096: GROUP_SIZE_M = 128\n if N <= 1024: GROUP_SIZE_M = 256\n # allocate output\n locks = torch.zeros(2 * GROUP_SIZE_M, dtype=torch.int32, device='cuda')\n _dw = torch.empty((GROUP_SIZE_M, w.shape[0]), dtype=x.dtype, device=w.device)\n _db = torch.empty((GROUP_SIZE_M, w.shape[0]), dtype=x.dtype, device=w.device)\n dw = torch.empty((w.shape[0],), dtype=w.dtype, device=w.device)\n db = torch.empty((w.shape[0],), dtype=w.dtype, device=w.device)\n dx = torch.empty_like(dy)\n # enqueue kernel using forward pass heuristics\n # also compute partial sums for DW and DB\n x_arg = x.reshape(-1, x.shape[-1])\n M, N = x_arg.shape\n _layer_norm_bwd_dx_fused[(M,)](dx, dy, _dw, _db, x, w, b, m, v, locks,\n x_arg.stride(0), N, ctx.eps,\n BLOCK_SIZE_N=ctx.BLOCK_SIZE,\n GROUP_SIZE_M=GROUP_SIZE_M,\n num_warps=ctx.num_warps)\n grid = lambda meta: [triton.cdiv(N, meta['BLOCK_SIZE_N'])]\n # accumulate partial sums in separate kernel\n _layer_norm_bwd_dwdb[grid](_dw, _db, dw, db, GROUP_SIZE_M, N,\n BLOCK_SIZE_M=32,\n BLOCK_SIZE_N=128)\n return dx, None, dw, db, None\n\n\nlayer_norm = LayerNorm.apply\n\ndef test_layer_norm(M, N, dtype, eps=1e-5, device='cuda'):\n # create data\n x_shape = (M, N)\n w_shape = (x_shape[-1], )\n weight = torch.rand(w_shape, dtype=dtype, device='cuda', requires_grad=True)\n bias = torch.rand(w_shape, dtype=dtype, device='cuda', requires_grad=True)\n x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device='cuda')\n dy = .1 * torch.randn_like(x)\n x.requires_grad_(True)\n # forward pass\n y_tri = layer_norm(x, w_shape, weight, bias, eps)\n y_ref = torch.nn.functional.layer_norm(x, w_shape, weight, bias, eps).to(dtype)\n # backward pass (triton)\n y_tri.backward(dy, retain_graph=True)\n dx_tri, dw_tri, db_tri = [_.grad.clone() for _ in [x, weight, bias]]\n x.grad, weight.grad, bias.grad = None, None, None\n # backward pass (torch)\n y_ref.backward(dy, retain_graph=True)\n dx_ref, dw_ref, db_ref = [_.grad.clone() for _ in [x, weight, bias]]\n # compare\n triton.testing.assert_almost_equal(y_tri, y_ref)\n triton.testing.assert_almost_equal(dx_tri, dx_ref)\n triton.testing.assert_almost_equal(db_tri, db_ref, decimal=1)\n triton.testing.assert_almost_equal(dw_tri, dw_ref, decimal=1)\n\n\ntest_layer_norm(1151, 8192, torch.float16)\n", - "description_1": "Use triton language to implement a fused forward and backward layer normalization. The forward function computes the mean and variance of the input data, normalizes it, and applies a linear transformation. The backward function computes the gradients for the input data and the linear transformation parameters.", - "description_2": "Use triton language to implement a layer normalization forward and backward operation with input normalization and parameter gradient computation.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef _fwd_kernel(\n Q, K, V, sm_scale,\n L, M,\n Out,\n stride_qz, stride_qh, stride_qm, stride_qk,\n stride_kz, stride_kh, stride_kn, stride_kk,\n stride_vz, stride_vh, stride_vk, stride_vn,\n stride_oz, stride_oh, stride_om, stride_on,\n Z, H, N_CTX,\n BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,\n BLOCK_N: tl.constexpr,\n):\n start_m = tl.program_id(0)\n off_hz = tl.program_id(1)\n offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)\n offs_n = tl.arange(0, BLOCK_N)\n offs_d = tl.arange(0, BLOCK_DMODEL)\n off_q = off_hz * stride_qh + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qk\n off_k = off_hz * stride_qh + offs_n[None, :] * stride_kn + offs_d[:, None] * stride_kk\n off_v = off_hz * stride_qh + offs_n[:, None] * stride_qm + offs_d[None, :] * stride_qk\n q_ptrs = Q + off_q\n k_ptrs = K + off_k\n v_ptrs = V + off_v\n m_prev = tl.zeros([BLOCK_M], dtype=tl.float32) - float(\"inf\")\n l_prev = tl.zeros([BLOCK_M], dtype=tl.float32)\n acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)\n q = tl.load(q_ptrs)\n for start_n in range(0, (start_m + 1) * BLOCK_M, BLOCK_N):\n k = tl.load(k_ptrs)\n qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)\n qk += tl.dot(q, k)\n qk *= sm_scale\n qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float(\"-inf\"))\n m_curr = tl.maximum(tl.max(qk, 1), m_prev)\n l_prev *= tl.exp(m_prev - m_curr)\n p = tl.exp(qk - m_curr[:, None])\n l_curr = tl.sum(p, 1) + l_prev\n l_rcp = 1. / l_curr\n p *= l_rcp\n acc *= (l_prev * l_rcp)[:, None]\n p = p.to(tl.float16)\n v = tl.load(v_ptrs)\n acc += tl.dot(p, v)\n l_prev = l_curr\n m_prev = m_curr\n k_ptrs += BLOCK_N * stride_kn\n v_ptrs += BLOCK_N * stride_vk\n start_m = tl.program_id(0)\n offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)\n l_ptrs = L + off_hz * N_CTX + offs_m\n m_ptrs = M + off_hz * N_CTX + offs_m\n tl.store(l_ptrs, l_prev)\n tl.store(m_ptrs, m_prev)\n offs_n = tl.arange(0, BLOCK_DMODEL)\n off_o = off_hz * stride_oh + offs_m[:, None] * stride_om + offs_n[None, :] * stride_on\n out_ptrs = Out + off_o\n tl.store(out_ptrs, acc)\n\n@triton.jit\ndef _bwd_preprocess(\n Out, DO, L,\n NewDO, Delta,\n BLOCK_M: tl.constexpr, D_HEAD: tl.constexpr,\n):\n off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)\n off_n = tl.arange(0, D_HEAD)\n o = tl.load(Out + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32)\n do = tl.load(DO + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32)\n denom = tl.load(L + off_m).to(tl.float32)\n do = do / denom[:, None]\n delta = tl.sum(o * do, axis=1)\n tl.store(NewDO + off_m[:, None] * D_HEAD + off_n[None, :], do)\n tl.store(Delta + off_m, delta)\n\n@triton.jit\ndef _bwd_kernel(\n Q, K, V, sm_scale, Out, DO,\n DQ, DK, DV,\n L, M,\n D,\n stride_qz, stride_qh, stride_qm, stride_qk,\n stride_kz, stride_kh, stride_kn, stride_kk,\n stride_vz, stride_vh, stride_vk, stride_vn,\n Z, H, N_CTX,\n num_block,\n BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,\n BLOCK_N: tl.constexpr,\n):\n off_hz = tl.program_id(0)\n off_z = off_hz // H\n off_h = off_hz % H\n Q += off_z * stride_qz + off_h * stride_qh\n K += off_z * stride_qz + off_h * stride_qh\n V += off_z * stride_qz + off_h * stride_qh\n DO += off_z * stride_qz + off_h * stride_qh\n DQ += off_z * stride_qz + off_h * stride_qh\n DK += off_z * stride_qz + off_h * stride_qh\n DV += off_z * stride_qz + off_h * stride_qh\n for start_n in range(0, num_block):\n lo = start_n * BLOCK_M\n offs_qm = lo + tl.arange(0, BLOCK_M)\n offs_n = start_n * BLOCK_M + tl.arange(0, BLOCK_M)\n offs_m = tl.arange(0, BLOCK_N)\n offs_k = tl.arange(0, BLOCK_DMODEL)\n q_ptrs = Q + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk)\n k_ptrs = K + (offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk)\n v_ptrs = V + (offs_n[:, None] * stride_qm + offs_k[None, :] * stride_qk)\n do_ptrs = DO + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk)\n dq_ptrs = DQ + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk)\n D_ptrs = D + off_hz * N_CTX\n m_ptrs = M + off_hz * N_CTX\n dv = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)\n dk = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)\n k = tl.load(k_ptrs)\n v = tl.load(v_ptrs)\n for start_m in range(lo, num_block * BLOCK_M, BLOCK_M):\n offs_m_curr = start_m + offs_m\n q = tl.load(q_ptrs)\n qk = tl.dot(q, tl.trans(k))\n qk = tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), qk, float(\"-inf\"))\n m = tl.load(m_ptrs + offs_m_curr)\n p = tl.exp(qk * sm_scale - m[:, None])\n do = tl.load(do_ptrs)\n dv += tl.dot(tl.trans(p.to(tl.float16)), do)\n Di = tl.load(D_ptrs + offs_m_curr)\n dp = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - Di[:, None]\n dp += tl.dot(do, tl.trans(v))\n ds = p * dp * sm_scale\n dk += tl.dot(tl.trans(ds.to(tl.float16)), q)\n dq = tl.load(dq_ptrs)\n dq += tl.dot(ds.to(tl.float16), k)\n tl.store(dq_ptrs, dq)\n dq_ptrs += BLOCK_M * stride_qm\n q_ptrs += BLOCK_M * stride_qm\n do_ptrs += BLOCK_M * stride_qm\n dv_ptrs = DV + (offs_n[:, None] * stride_qm + offs_k[None, :] * stride_qk)\n dk_ptrs = DK + (offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk)\n tl.store(dv_ptrs, dv)\n tl.store(dk_ptrs, dk)\n\nclass _attention(torch.autograd.Function):\n\n @staticmethod\n def forward(ctx, q, k, v, sm_scale):\n BLOCK = 128\n Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]\n assert Lq == Lk and Lk == Lv\n assert Lk in {16, 32, 64, 128}\n o = torch.empty_like(q)\n grid = (triton.cdiv(q.shape[2], BLOCK), q.shape[0] * q.shape[1], 1)\n L = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)\n m = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)\n num_warps = 4 if Lk <= 64 else 8\n\n _fwd_kernel[grid](\n q, k, v, sm_scale,\n L, m,\n o,\n q.stride(0), q.stride(1), q.stride(2), q.stride(3),\n k.stride(0), k.stride(1), k.stride(2), k.stride(3),\n v.stride(0), v.stride(1), v.stride(2), v.stride(3),\n o.stride(0), o.stride(1), o.stride(2), o.stride(3),\n q.shape[0], q.shape[1], q.shape[2],\n BLOCK_M=BLOCK, BLOCK_N=BLOCK,\n BLOCK_DMODEL=Lk, num_warps=num_warps,\n num_stages=2,\n )\n\n ctx.save_for_backward(q, k, v, o, L, m)\n ctx.grid = grid\n ctx.sm_scale = sm_scale\n ctx.BLOCK_DMODEL = Lk\n return o\n\n @staticmethod\n def backward(ctx, do):\n BLOCK = 128\n q, k, v, o, l, m = ctx.saved_tensors\n do = do.contiguous()\n dq = torch.zeros_like(q, dtype=torch.float32)\n dk = torch.empty_like(k)\n dv = torch.empty_like(v)\n do_scaled = torch.empty_like(do)\n delta = torch.empty_like(l)\n _bwd_preprocess[(ctx.grid[0] * ctx.grid[1], )](\n o, do, l,\n do_scaled, delta,\n BLOCK_M=BLOCK, D_HEAD=ctx.BLOCK_DMODEL,\n )\n _bwd_kernel[(ctx.grid[1],)](\n q, k, v, ctx.sm_scale,\n o, do_scaled,\n dq, dk, dv,\n l, m,\n delta,\n q.stride(0), q.stride(1), q.stride(2), q.stride(3),\n k.stride(0), k.stride(1), k.stride(2), k.stride(3),\n v.stride(0), v.stride(1), v.stride(2), v.stride(3),\n q.shape[0], q.shape[1], q.shape[2],\n ctx.grid[0],\n BLOCK_M=BLOCK, BLOCK_N=BLOCK,\n BLOCK_DMODEL=ctx.BLOCK_DMODEL, num_warps=8,\n num_stages=1,\n )\n return dq, dk, dv, None\n\nattention = _attention.apply\n", - "description_1": "Use triton language to implement forward and backward kernels for the Flash Attention mechanism. The forward kernel computes the matrix multiplications and softmax scaling using blocks of data, maintaining accumulators for the outputs. The backward kernel computes gradients with respect to inputs using the chain rule and processes data in blocks. Each function requires input tensors and stride parameters for proper memory access, along with several constants defining block sizes for computation.", - "description_2": "Use triton language to implement and apply a fused attention operator with kernels handling both forward and backward passes, utilizing block-based computation for efficient GPU execution.", - "difficulty": 3 - }, - { - "code": "import triton\nimport triton.language as tl\n\n# triton kernel\n@triton.jit\ndef kernel(X, stride_xm,\n Z, stride_zn,\n BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):\n off_m = tl.arange(0, BLOCK_M)\n off_n = tl.arange(0, BLOCK_N)\n Xs = X + off_m[:, None] * stride_xm + off_n[None, :] * 1\n Zs = Z + off_m[:, None] * 1 + off_n[None, :] * stride_zn\n tl.store(Zs, tl.load(Xs))\n\nret = triton.compile(kernel, signature=\"*fp32,i32,*fp32,i32\", constants={\"BLOCK_M\": 64, \"BLOCK_N\": 64})\nprint(ret.asm[\"ttgir\"])\n", - "description_1": "Use triton language to define a kernel function 'kernel' with 5 parameters: X (input tensor), stride_xm (stride for X), Z (output tensor), stride_zn (stride for Z), and two constexpr parameters BLOCK_M and BLOCK_N. The kernel computes offsets for a block of size BLOCK_M x BLOCK_N, loads data from X using these offsets, and stores the result in Z.", - "description_2": "Use triton language to define a kernel that loads data from an input tensor and stores it in an output tensor using block-wise offsets.", - "difficulty": 2 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\nimport importlib\n\n@triton.jit\ndef kernel(in_ptr0, out_ptr0, xnumel, XBLOCK: tl.constexpr):\n # Kernel function to load data from in_ptr0 and store it to out_ptr0\n xnumel = 10\n xoffset = tl.program_id(0) * XBLOCK\n xindex = xoffset + tl.arange(0, XBLOCK)[:]\n xmask = xindex < xnumel\n x0 = xindex\n tmp0 = tl.load(in_ptr0 + (x0), xmask)\n tl.store(out_ptr0 + (x0 + tl.zeros([XBLOCK], tl.int32)), tmp0, xmask)\n\ninp = torch.randn(10)\nout = torch.randn(10)\nkernel[(10,)](inp, out, 10, XBLOCK=16)\nspec = importlib.util.spec_from_file_location(\"__triton_launcher\", ExtensionBackend.stub_so_path)\nmod = importlib.util.module_from_spec(spec)\nspec.loader.exec_module(mod)\nlaunch_counter = getattr(mod, \"launch_counter\")\n\nfor _ in range(100):\n kernel[(10,)](inp, out, 10, XBLOCK=16)\n\nassert launch_counter() > 0\n", - "description_1": "Use triton language to define a kernel function that takes four parameters: in_ptr0 (input pointer), out_ptr0 (output pointer), xnumel (number of elements), and XBLOCK (block size). The kernel loads data from in_ptr0 and stores it to out_ptr0 using a mask to ensure indices are within bounds. The kernel is launched with a grid size of 10 and block size of 16.", - "description_2": "Use triton language to create a kernel that transfers data from an input pointer to an output pointer with bounds checking, and execute it with specified grid and block sizes.", - "difficulty": 1 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit()\ndef kernel(x_ptr, y_ptr, out_ptr):\n # Triton kernel to add two vectors element-wise\n pid = tl.program_id(axis=0)\n x = tl.load(x_ptr + pid)\n y = tl.load(y_ptr + pid)\n out = x + y\n tl.store(out_ptr + pid, out)\n\ndef test_xpu_backend(cmdopt):\n if cmdopt == \"xpu\":\n has_ipex = False\n try:\n # Import IPEX to provide Intel GPU runtime\n import intel_extension_for_pytorch # type: ignore # noqa: F401\n has_ipex = True if hasattr(torch, \"xpu\") else False\n except Exception:\n has_ipex = False\n\n if has_ipex:\n for _ in range(1000):\n x = torch.randn((65536,), device=\"xpu\", dtype=torch.float32)\n y = torch.randn((65536,), device=\"xpu\", dtype=torch.float32)\n z = torch.zeros((65536,), device=\"xpu\", dtype=torch.float32)\n # Kernel call: perform element-wise addition\n kernel[(65536,)](x, y, z, num_warps=32)\n assert torch.all(x + y == z)\n else:\n return\n", - "description_1": "Use triton language to implement a kernel that performs element-wise addition of two vectors. The kernel function 'kernel' has three parameters: 'x_ptr', 'y_ptr', and 'out_ptr', which are pointers to input and output vectors. The function loads elements from 'x_ptr' and 'y_ptr', adds them, and stores the result in 'out_ptr'. The function 'test_xpu_backend' is used to test this kernel on an Intel GPU (if available), creating random input vectors and verifying the output for correctness. It has one parameter 'cmdopt', a string indicating whether the backend is 'xpu'.", - "description_2": "Use triton language to implement and test an element-wise addition kernel for vectors using a conditional Intel GPU backend.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\nfrom numpy.random import RandomState\n\n# Kernel and its invocation for the 'chained_matmul_kernel'\n@triton.jit\ndef chained_matmul_kernel(\n A, # shape: (m, k)\n B, # shape: (n, k)\n C, # shape: (n, k)\n out, # shape: (m, k)\n m, n, k: tl.constexpr,\n block_m: tl.constexpr,\n block_n: tl.constexpr,\n block_k: tl.constexpr):\n\n tl.static_assert(block_k == k,\n f\"expected block_k == k but got {block_k} != {k}\")\n\n block_ix = tl.program_id(0)\n a_tile = (block_ix * block_m + tl.arange(0, block_m))[:, None] * block_k \\\n + tl.arange(0, block_k)[None, :]\n\n a = tl.load(A + a_tile, mask=a_tile < m * k, other=0.0)\n\n acc = tl.zeros([block_m, block_k], dtype=tl.float32)\n\n for loop_block_start in range(0, n, block_n):\n bc_tile = (loop_block_start + tl.arange(0, block_n))[:, None] * block_k \\\n + tl.arange(0, block_k)[None, :]\n b = tl.load(B + bc_tile, mask=bc_tile < n * k, other=0.0)\n\n intermediate = tl.dot(a, tl.trans(b))\n intermediate_mask = ((loop_block_start + tl.arange(0, block_n)) < n)[None, :] \\\n * (tl.arange(0, block_m) < m)[:, None]\n\n intermediate = tl.where(intermediate_mask, intermediate, 0.0)\n\n c = tl.load(C + bc_tile, mask=bc_tile < n * k)\n\n acc += tl.dot(intermediate.to(A.dtype.element_ty), c)\n\n tl.store(out + a_tile, acc.to(A.dtype.element_ty), mask=a_tile < m * k)\n\nm, n, k = 32, 64, 128\nblock_m, block_n, block_k = 16, 32, k\n\ngrid = (triton.cdiv(m, block_m),)\na = torch.randint(low=0, high=2, size=(m, k), dtype=torch.float16, device='cuda')\nb = torch.randint(low=0, high=2, size=(n, k), dtype=torch.float16, device='cuda')\nc = torch.randint_like(b, low=0, high=2)\ntriton_result = torch.zeros_like(a)\n\nchained_matmul_kernel[grid](a, b, c, triton_result, m, n, k, block_m=block_m, block_n=block_n, block_k=block_k)\n\n# Kernel and its invocation for 'batched_vecmat'\n@triton.jit\ndef batched_vecmat(\n # inputs\n A, # shape: [dim_m, dim_k]\n B, # shape: [dim_m, dim_n, dim_k]\n # dimensions\n dim_m, dim_n, dim_k,\n # outputs\n output,\n # block information\n block_m: tl.constexpr, block_n: tl.constexpr, block_k: tl.constexpr\n):\n m_index = tl.program_id(0)\n n_index = tl.program_id(1)\n # Output tile\n output_tile = (m_index * block_m + tl.arange(0, block_m))[:, None] * dim_n \\\n + (n_index * block_n + tl.arange(0, block_n))[None, :]\n\n vecmat = tl.zeros([block_m, block_n], dtype=A.dtype.element_ty)\n k_blocks = dim_k // block_k\n for k_index in range(k_blocks):\n # Load A tile\n a_tile = (m_index * block_m + tl.arange(0, block_m))[:, None] * dim_k \\\n + (k_index * block_k + tl.arange(0, block_k))[None, :]\n a = tl.load(A + a_tile)\n\n # Load B tile, transposed to [n, m, k] in order to broadcast A on a\n # leading dimension.\n b_tile = (m_index * block_m + tl.arange(0, block_m))[None, :, None] * dim_n * dim_k \\\n + (n_index * block_n + tl.arange(0, block_n))[:, None, None] * dim_k \\\n + (k_index * block_k + tl.arange(0, block_k))[None, None, :]\n b = tl.load(B + b_tile)\n\n expanded_a, _ = tl.broadcast(a, b)\n vecmat += tl.trans(tl.sum(expanded_a * b, axis=2))\n\n tl.store(output + output_tile, vecmat)\n\nM, N, K = 128, 128, 128\nblock_m, block_n, block_k = 16, 32, 64\n\nrs = RandomState(17)\nA_vec = rs.randint(0, 4, (M, K)).astype('float32')\nB_vec = rs.randint(0, 4, (M, N, K)).astype('float32')\nA = A_vec\nB = B_vec\n\nA_tri = torch.tensor(A, device='cuda')\nB_tri = torch.tensor(B, device='cuda')\nC_tri = torch.zeros((M, N), dtype=torch.float32, device='cuda')\n\ngrid = (M // block_m, N // block_n)\n\nbatched_vecmat[grid](A_tri, B_tri, M, N, K, C_tri,\n block_m=block_m, block_n=block_n, block_k=block_k,\n num_warps=4, num_stages=1)\n\n# Kernel and its invocation for 'kernel'\n@triton.jit\ndef kernel(\n a_ptr, b_ptr, c_ptr,\n M, N, K,\n stride_am, stride_ak,\n stride_bk, stride_bn,\n stride_cm, stride_cn,\n BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,\n type: tl.constexpr\n):\n pid = tl.program_id(axis=0)\n num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)\n pid_m = pid // num_pid_n\n pid_n = pid % num_pid_n\n\n offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M\n offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N\n offs_k = tl.arange(0, BLOCK_SIZE_K)\n a_ptr = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)\n b_ptr = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)\n a_ptrs = a_ptr\n b_ptrs = b_ptr\n if type == \"post_load_two_iters\":\n a_ptrs_next = a_ptr + BLOCK_SIZE_K * stride_ak\n b_ptrs_next = b_ptr + BLOCK_SIZE_K * stride_bk\n elif type == \"post_load_three_iters\":\n a_ptrs_next = a_ptr + BLOCK_SIZE_K * stride_ak\n b_ptrs_next = b_ptr + BLOCK_SIZE_K * stride_bk\n a_ptrs_next_next = a_ptr + 2 * BLOCK_SIZE_K * stride_ak\n b_ptrs_next_next = b_ptr + 2 * BLOCK_SIZE_K * stride_bk\n\n accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):\n if type == \"pre_load\":\n a_ptrs = a_ptr + k * BLOCK_SIZE_K * stride_ak\n b_ptrs = b_ptr + k * BLOCK_SIZE_K * stride_bk\n elif type == \"post_pre_mixed\":\n a_ptrs = a_ptr + k * BLOCK_SIZE_K * stride_ak\n a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)\n b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)\n accumulator += tl.dot(a, b)\n if type == \"post_load\":\n a_ptrs = a_ptr + (k + 1) * BLOCK_SIZE_K * stride_ak\n b_ptrs = b_ptr + (k + 1) * BLOCK_SIZE_K * stride_bk\n elif type == \"post_pre_mixed\":\n b_ptrs = b_ptr + (k + 1) * BLOCK_SIZE_K * stride_bk\n elif type == \"post_load_two_iters\":\n a_ptrs = a_ptrs_next\n b_ptrs = b_ptrs_next\n a_ptrs_next = a_ptr + (k + 2) * BLOCK_SIZE_K * stride_ak\n b_ptrs_next = b_ptr + (k + 2) * BLOCK_SIZE_K * stride_bk\n elif type == \"post_load_three_iters\":\n a_ptrs = a_ptrs_next\n b_ptrs = b_ptrs_next\n a_ptrs_next = a_ptrs_next_next\n b_ptrs_next = b_ptrs_next_next\n a_ptrs_next_next = a_ptr + (k + 3) * BLOCK_SIZE_K * stride_ak\n b_ptrs_next_next = b_ptr + (k + 3) * BLOCK_SIZE_K * stride_bk\n c = accumulator.to(tl.float16)\n\n offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]\n c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)\n tl.store(c_ptrs, c, mask=c_mask)\n\nM = 256\nK = 256\nN = 256\nBLOCK_SIZE_K = 32\nBLOCK_SIZE_N = 32\nBLOCK_SIZE_M = 32\n\na = torch.rand((M, K), device='cuda')\nb = torch.rand((K, N), device='cuda')\n\ntorch_output = torch.mm(a, b)\ntriton_output = torch.empty_like(torch_output, device=torch_output.device)\n\ndef grid(META):\n return (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']),)\n\nnum_stages = 4\nkernel[grid](a, b, triton_output, M, N, K, a.stride(0), a.stride(1),\n b.stride(0), b.stride(1), triton_output.stride(0), triton_output.stride(1),\n BLOCK_SIZE_M=BLOCK_SIZE_M, BLOCK_SIZE_N=BLOCK_SIZE_N, BLOCK_SIZE_K=BLOCK_SIZE_K,\n type=\"post_load_three_iters\", num_stages=num_stages)\n", - "description_1": "Use triton language to implement three different matrix multiplication kernels: 1. 'chained_matmul_kernel' performs a chained matrix multiplication by loading and processing blocks of data from input matrices and storing the result in an output matrix. Parameters: 10 (A, B, C, out, m, n, k, block_m, block_n, block_k). 2. 'batched_vecmat' computes a batch of vector-matrix multiplications with input and output matrices specified. Parameters: 8 (A, B, dim_m, dim_n, dim_k, output, block_m, block_n, block_k). 3. 'kernel' is used for matrix multiplication allowing for various strategies ('pre_load', 'post_load', etc.) by adjusting pointer calculations for A and B matrices. Parameters: 13 (a_ptr, b_ptr, c_ptr, M, N, K, stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn, type).", - "description_2": "Use triton language to create matrix multiplication kernels for chained operations, batched vector-matrix calculations, and customizable iteration strategies, each with specific parameters for data layout and execution configuration.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\nimport pytest\n\n@triton.jit\ndef _add(x_ptr, y_ptr, output_ptr, n_elements,\n BLOCK_SIZE: tl.constexpr):\n pid = tl.program_id(axis=0)\n block_start = pid * BLOCK_SIZE\n offsets = block_start + tl.arange(0, BLOCK_SIZE)\n mask = offsets < n_elements\n x = tl.load(x_ptr + offsets, mask=mask)\n y = tl.load(y_ptr + offsets, mask=mask)\n output = x + y\n tl.store(output_ptr + offsets, output, mask=mask)\n\n@pytest.mark.parametrize('N', [1024 * 16, 1024 * 64, 1024 * 256, 1024 * 1024, 1024 * 16384, 1024 * 65536, 1020 * 100, 10003 * 7007])\n@pytest.mark.parametrize(\"dtype_str\", ['float16', 'bfloat16', 'float32'])\ndef test_elementwise(N, dtype_str):\n stream = torch.cuda.Stream()\n torch.cuda.set_stream(stream)\n torch.manual_seed(0)\n if dtype_str in ['bfloat16'] and DEVICE_NAME != 'a100':\n pytest.skip('Only test bfloat16 on a100')\n dtype = {'float16': torch.float16, 'bfloat16': torch.bfloat16, 'float32': torch.float32}[dtype_str]\n z = torch.empty((N, ), dtype=dtype, device='cuda')\n x = torch.randn_like(z)\n y = torch.randn_like(z)\n grid = lambda args: (triton.cdiv(N, args['BLOCK_SIZE']), )\n fn = lambda: _add[grid](x, y, z, N, BLOCK_SIZE=1024)\n ms = triton.testing.do_bench_cudagraph(fn)\n cur_gpu_perf = 3. * N * z.element_size() / ms * 1e-6\n cur_gpu_util = cur_gpu_perf / get_dram_gbps()\n print_perf(ms, cur_gpu_util, elementwise_data[DEVICE_NAME][N][dtype_str])\n triton.testing.assert_close(cur_gpu_util, elementwise_data[DEVICE_NAME][N][dtype_str], atol=0.02, rtol=0.01)\n", - "description_1": "Use triton language to implement an element-wise addition kernel. The kernel '_add' takes five parameters: x_ptr, y_ptr, output_ptr, n_elements, and BLOCK_SIZE. It computes the element-wise sum of two input arrays 'x' and 'y' and stores the result in 'output'. The 'test_elementwise' function benchmarks this kernel for different data sizes and types, ensuring performance close to reference values.", - "description_2": "Use triton language to create an element-wise addition kernel and benchmark its performance for various data sizes and types.", - "difficulty": 2 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n# Kernel to perform element-wise addition of two vectors\n@triton.jit\ndef add_kernel(\n x_ptr,\n y_ptr,\n output_ptr,\n n_elements,\n BLOCK_SIZE: tl.constexpr,\n):\n pid = tl.program_id(axis=0)\n block_start = pid * BLOCK_SIZE\n offsets = block_start + tl.arange(0, BLOCK_SIZE)\n mask = offsets < n_elements\n x = tl.load(x_ptr + offsets, mask=mask)\n y = tl.load(y_ptr + offsets, mask=mask)\n output = x + y\n tl.store(output_ptr + offsets, output, mask=mask)\n\n# Function to test the add_kernel\ndef test_addition():\n a = torch.rand((128,), device=\"cuda\")\n b = torch.rand((128,), device=\"cuda\")\n expected = a + b\n output = torch.empty((128,), device=\"cuda\")\n\n def grid(meta):\n return (triton.cdiv(128, meta[\"BLOCK_SIZE\"]),)\n\n add_kernel[grid](a, b, output, 128, BLOCK_SIZE=32)\n\n assert torch.allclose(expected, output, atol=1e-2, rtol=0)\n\n# Kernel to perform atomic operations on a vector\n@triton.jit\ndef atomic(\n x_ptr,\n):\n pid = tl.program_id(axis=0)\n tl.atomic_add(x_ptr + pid, 1)\n t = tl.atomic_xchg(x_ptr + pid, 3)\n t += 1 # 2\n tl.atomic_cas(x_ptr + pid, 3, t) # match\n tl.atomic_cas(x_ptr + pid, 40, 9) # no match\n\n# Function to test the atomic kernel\ndef test_atomic():\n nb_dim = 16\n a = torch.zeros((nb_dim, ), dtype=torch.int32, device=\"cuda\")\n\n atomic[(nb_dim, )](a)\n assert torch.allclose(a, torch.full_like(a, 2))\n", - "description_1": "Use triton language to implement two kernels: one for element-wise addition of two vectors and another for performing atomic operations on a vector. The addition kernel takes five parameters: pointers to the input vectors, a pointer to the output vector, the number of elements, and a block size. The atomic kernel takes one parameter: a pointer to the vector on which atomic operations are performed.", - "description_2": "Use triton language to create a kernel for vector addition and another for atomic operations on a vector.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\nfrom torch.testing import assert_close\n\n@triton.jit\ndef kernel_device_assert(X, Y, BLOCK: tl.constexpr):\n x = tl.load(X + tl.arange(0, BLOCK))\n tl.device_assert(x == 0, \"x != 0\")\n tl.store(Y + tl.arange(0, BLOCK), x)\n\n@triton.jit\ndef kernel_device_assert_scalar(X, Y, BLOCK: tl.constexpr):\n x = tl.load(X + tl.arange(0, BLOCK))\n tl.device_assert(0 == 0, \"x != 0\")\n tl.store(Y + tl.arange(0, BLOCK), x)\n\n@triton.jit(debug=False)\ndef kernel_device_assert_no_debug(X, Y, BLOCK: tl.constexpr):\n x = tl.load(X + tl.arange(0, BLOCK))\n tl.device_assert(x == 0, \"x != 0\")\n tl.store(Y + tl.arange(0, BLOCK), x)\n\n@triton.jit\ndef kernel_assert(X, Y, BLOCK: tl.constexpr):\n x = tl.load(X + tl.arange(0, BLOCK))\n assert x == 0, \"x != 0\"\n tl.store(Y + tl.arange(0, BLOCK), x)\n\n@triton.jit\ndef kernel_static_assert(X, Y, BLOCK: tl.constexpr):\n x = tl.load(X + tl.arange(0, BLOCK))\n tl.static_assert(BLOCK == 128, \"BLOCK != 128\")\n tl.store(Y + tl.arange(0, BLOCK), x)\n\ndef test_assert(func: str):\n shape = (128, )\n x = torch.arange(0, shape[0], dtype=torch.int32, device='cuda')\n y = torch.zeros(shape, dtype=x.dtype, device=\"cuda\")\n if func == \"device_assert\":\n kernel_device_assert[(1,)](x, y, BLOCK=shape[0])\n kernel_device_assert_scalar[(1,)](x, y, BLOCK=shape[0])\n elif func == \"no_debug\":\n kernel_device_assert_no_debug[(1,)](x, y, BLOCK=shape[0])\n elif func == \"assert\":\n kernel_assert[(1,)](x, y, BLOCK=shape[0])\n elif func == \"static_assert\":\n kernel_static_assert[(1,)](x, y, BLOCK=shape[0])\n assert_close(y, x)\n\n@triton.jit\ndef jit_device_assert_none(x):\n tl.device_assert(x == 0, \"x != 0\")\n\n@triton.jit(debug=True)\ndef jit_device_assert_true(x):\n tl.device_assert(x == 0, \"x != 0\")\n\n@triton.jit(debug=False)\ndef jit_device_assert_false(x):\n tl.device_assert(x == 0, \"x != 0\")\n\n@triton.jit\ndef kernel_device_assert_nested(X, Y, BLOCK: tl.constexpr, jit_debug: tl.constexpr):\n x = tl.load(X + tl.arange(0, BLOCK))\n if jit_debug == \"true\":\n jit_device_assert_true(x)\n elif jit_debug == \"false\":\n jit_device_assert_false(x)\n else:\n jit_device_assert_none(x)\n tl.store(Y + tl.arange(0, BLOCK), x)\n\n@triton.jit(debug=True)\ndef kernel_device_assert_nested_true(X, Y, BLOCK: tl.constexpr, jit_debug: tl.constexpr):\n x = tl.load(X + tl.arange(0, BLOCK))\n if jit_debug == \"true\":\n jit_device_assert_true(x)\n elif jit_debug == \"false\":\n jit_device_assert_false(x)\n else:\n jit_device_assert_none(x)\n tl.store(Y + tl.arange(0, BLOCK), x)\n\n@triton.jit(debug=False)\ndef kernel_device_assert_nested_false(X, Y, BLOCK: tl.constexpr, jit_debug: tl.constexpr):\n x = tl.load(X + tl.arange(0, BLOCK))\n if jit_debug == \"true\":\n jit_device_assert_true(x)\n elif jit_debug == \"false\":\n jit_device_assert_false(x)\n else:\n jit_device_assert_none(x)\n tl.store(Y + tl.arange(0, BLOCK), x)\n\ndef test_assert_nested(caller: str, callee: str):\n shape = (128, )\n x = torch.arange(0, shape[0], dtype=torch.int32, device='cuda')\n y = torch.zeros(shape, dtype=x.dtype, device=\"cuda\")\n if caller == \"none\":\n kernel_device_assert_nested[(1,)](x, y, BLOCK=shape[0], jit_debug=callee)\n elif caller == \"true\":\n kernel_device_assert_nested_true[(1,)](x, y, BLOCK=shape[0], jit_debug=callee)\n elif caller == \"false\":\n kernel_device_assert_nested_false[(1,)](x, y, BLOCK=shape[0], jit_debug=callee)\n assert_close(y, x)\n", - "description_1": "Use triton language to define several kernels: 'kernel_device_assert', 'kernel_device_assert_scalar', 'kernel_device_assert_no_debug', 'kernel_assert', 'kernel_static_assert', and 'kernel_device_assert_nested'. Each kernel loads data from tensor X, asserts certain conditions using 'tl.device_assert', and stores results into tensor Y. The kernels are invoked by the 'test_assert' and 'test_assert_nested' functions, which set up input tensors and validate the results.", - "description_2": "Use triton language to implement kernels for asserting conditions on input tensors and performing data transfer between global memory and registers.", - "difficulty": 2 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\nfrom torch.testing import assert_close\n\n# Kernel that uses tl.device_print to print values from device\n@triton.jit\ndef kernel_device_print(X, Y, BLOCK: tl.constexpr):\n x = tl.load(X + tl.arange(0, BLOCK))\n tl.device_print(\"\", x)\n tl.store(Y + tl.arange(0, BLOCK), x)\n\n# Kernel that uses Python's print function to print values\n@triton.jit\ndef kernel_print(X, Y, BLOCK: tl.constexpr):\n x = tl.load(X + tl.arange(0, BLOCK))\n print(\"\", x)\n tl.store(Y + tl.arange(0, BLOCK), x)\n\n# Kernel that uses tl.static_print to print values\n@triton.jit\ndef kernel_static_print(X, Y, BLOCK: tl.constexpr):\n x = tl.load(X + tl.arange(0, BLOCK))\n tl.static_print(x)\n tl.store(Y + tl.arange(0, BLOCK), x)\n\n# Function to test the print kernels\ndef test_print(func: str, data_type: str):\n shape = (128, )\n x = torch.arange(0, shape[0], dtype=torch.int32, device='cuda').to(getattr(torch, data_type))\n y = torch.zeros(shape, dtype=x.dtype, device=\"cuda\")\n if func == \"device_print\":\n kernel_device_print[(1,)](x, y, BLOCK=shape[0])\n elif func == \"print\":\n kernel_print[(1,)](x, y, BLOCK=shape[0])\n elif func == \"static_print\":\n kernel_static_print[(1,)](x, y, BLOCK=shape[0])\n assert_close(y, x)\n", - "description_1": "Use triton language to define three kernels: kernel_device_print, kernel_print, and kernel_static_print. Each kernel takes three parameters: X (input tensor), Y (output tensor), and BLOCK (a compile-time constant representing the block size). The kernels load a block of data from X, print it using different methods (tl.device_print, Python's print, and tl.static_print), and store the result back to Y. The test_print function calls these kernels based on the provided function name and data type.", - "description_2": "Use triton language to create kernels that load data, print it using different methods, and store the result. Implement a test function to execute these kernels based on input parameters.", - "difficulty": 1 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n# Triton kernel that takes a tensor X, an integer N, and a block size BLOCK_SIZE\n@triton.jit\ndef _kernel(X: torch.Tensor, N: int, BLOCK_SIZE: tl.constexpr):\n pass\n\n# Function to test the kernel\ndef test_annotations(device):\n x = torch.empty(1, device=device)\n # Launch the kernel with the tensor x, its size, and a block size of 32\n _kernel[(1,)](x, x.shape[0], 32)\n try:\n # Attempt to launch the kernel with incorrect arguments to trigger an exception\n _kernel[(1,)](x.shape[0], x.shape[0], 32)\n except AttributeError:\n pass\n", - "description_1": "Use triton language to define a kernel that takes three parameters: a tensor X, an integer N, and a block size BLOCK_SIZE. The kernel is launched with a tensor and its size, and a block size of 32. The function also includes a test to handle incorrect argument types.", - "description_2": "Use triton language to create a kernel with a tensor, an integer, and a block size, and test it with correct and incorrect arguments.", - "difficulty": 1 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n# Kernel to copy blocks of data with padding options\n@triton.jit\ndef block_copy_kernel(a_ptr, b_ptr, N, BLOCK_SIZE: tl.constexpr, padding_option: tl.constexpr):\n pid = tl.program_id(0)\n # We only copy half of the data to see if the padding works\n a_block_ptr = tl.make_block_ptr(base=a_ptr, shape=(N // 2, ), strides=(1, ), offsets=(pid * BLOCK_SIZE, ),\n block_shape=(BLOCK_SIZE, ), order=(0, ))\n b_block_ptr = tl.make_block_ptr(base=b_ptr, shape=(N, ), strides=(1, ), offsets=(pid * BLOCK_SIZE, ),\n block_shape=(BLOCK_SIZE, ), order=(0, ))\n a = tl.load(a_block_ptr, boundary_check=(0, ), padding_option=padding_option)\n tl.store(b_block_ptr, a, boundary_check=(0, ))\n\n# Function to test block copy kernel\ndef test_block_copy(dtype_str, n, padding_option):\n dtype = getattr(torch, dtype_str)\n if dtype_str in (\"bool\", \"int16\"):\n if padding_option == \"nan\":\n return\n a = torch.randint(0, 2, (n, ), device=\"cuda\", dtype=dtype)\n else:\n a = torch.randn((n, ), device=\"cuda\", dtype=dtype)\n b = torch.zeros((n, ), device=\"cuda\", dtype=dtype)\n\n grid = lambda meta: (triton.cdiv(n, meta[\"BLOCK_SIZE\"]),)\n block_copy_kernel[grid](a_ptr=a, b_ptr=b, N=n, BLOCK_SIZE=64, padding_option=padding_option)\n\n# Kernel for matrix multiplication with block pointers and advance API\n@triton.jit\ndef matmul_no_scf_with_advance_kernel(\n a_ptr, b_ptr, c_ptr,\n M, N, K,\n stride_am, stride_ak,\n stride_bk, stride_bn,\n stride_cm, stride_cn,\n BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr\n):\n offs_m = tl.arange(0, BLOCK_M)\n offs_n = tl.arange(0, BLOCK_N)\n a_block_ptr = tl.make_block_ptr(base=a_ptr, shape=(M, K), strides=(stride_am, stride_ak),\n offsets=(0, 0), block_shape=(BLOCK_M, BLOCK_K), order=(1, 0))\n b_block_ptr = tl.make_block_ptr(base=b_ptr, shape=(K, N), strides=(stride_bk, stride_bn),\n offsets=(0, 0), block_shape=(BLOCK_K, BLOCK_N), order=(1, 0))\n # Below two lines are just for testing negative offsets for the `advance` API, which could be removed\n a_block_ptr = tl.advance(a_block_ptr, (BLOCK_M, -BLOCK_K))\n a_block_ptr = tl.advance(a_block_ptr, (-BLOCK_M, BLOCK_K))\n a = tl.load(a_block_ptr, boundary_check=(1, ), padding_option=\"zero\")\n b = tl.load(b_block_ptr, boundary_check=(0, ), padding_option=\"zero\")\n\n c = tl.dot(a, b)\n c_ptrs = c_ptr + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn\n tl.store(c_ptrs, c)\n\n# Function to test matrix multiplication kernel\ndef test_block_ptr_matmul_no_scf(shape, num_warps):\n m, n, k = shape\n a = torch.randn((m, k), device=\"cuda\", dtype=torch.float16)\n b = torch.randn((k, n), device=\"cuda\", dtype=torch.float16)\n c = torch.empty((m, n), device=\"cuda\", dtype=torch.float32)\n\n grid = lambda META: (1, )\n matmul_no_scf_with_advance_kernel[grid](a_ptr=a, b_ptr=b, c_ptr=c,\n M=m, N=n, K=k,\n stride_am=a.stride(0), stride_ak=a.stride(1),\n stride_bk=b.stride(0), stride_bn=b.stride(1),\n stride_cm=c.stride(0), stride_cn=c.stride(1),\n BLOCK_M=m, BLOCK_N=n, BLOCK_K=k,\n num_warps=num_warps)\n", - "description_1": "Use triton language to implement two kernels: one for copying blocks of data with padding options, and another for matrix multiplication using block pointers and the advance API. The block copy kernel takes 5 parameters: a_ptr (source pointer), b_ptr (destination pointer), N (total elements), BLOCK_SIZE (size of each block), and padding_option (padding strategy). The matrix multiplication kernel takes 13 parameters: a_ptr, b_ptr, c_ptr (pointers to matrices), M, N, K (dimensions), stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn (strides for matrices), and BLOCK_M, BLOCK_N, BLOCK_K (block sizes).", - "description_2": "Use triton language to create a block copy kernel with padding options and a matrix multiplication kernel using block pointers and the advance API.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n# Kernel with no operations, just a placeholder\n@triton.jit\ndef kernel(X, SIZE: tl.constexpr):\n pass\n\n# Function to test the empty kernel\ndef test_empty_kernel(dtype_x, device):\n SIZE = 128\n check_type_supported(dtype_x, device)\n x = to_triton(numpy_random(SIZE, dtype_str=dtype_x), device=device, dst_type=dtype_x)\n kernel[(1, )](x, SIZE=SIZE, num_warps=4)\n\n# Kernel for unary operations\n@triton.jit\ndef kernel_unary(Z, X, SIZE: tl.constexpr):\n off = tl.arange(0, SIZE)\n x = tl.load(X + off)\n z = GENERATE_TEST_HERE\n tl.store(Z + off, z)\n\n# Function to test unary operations\ndef _test_unary(dtype_x, expr, numpy_expr=None, device='cuda'):\n check_type_supported(dtype_x, device)\n SIZE = 128\n kernel = patch_kernel(kernel_unary, {'GENERATE_TEST_HERE': expr})\n x = numpy_random(SIZE, dtype_str=dtype_x)\n if 'log' in expr:\n x = np.abs(x) + 0.01\n z_ref = eval(expr if numpy_expr is None else numpy_expr)\n x_tri = to_triton(x, device=device, dst_type=dtype_x)\n z_tri = to_triton(np.empty_like(z_ref), device=device, dst_type=dtype_x)\n kernel[(1, )](z_tri, x_tri, SIZE=SIZE, num_warps=4)\n np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01)\n\n# Kernel for binary operations\n@triton.jit\ndef kernel_binary(Z, X, Y, SIZE: tl.constexpr):\n off = tl.arange(0, SIZE)\n x = tl.load(X + off)\n y = tl.load(Y + off)\n z = GENERATE_TEST_HERE\n tl.store(Z + off, z)\n\n# Function to test binary operations\ndef _test_binary(dtype_x, dtype_y, expr, numpy_expr=None, mode_x='real', mode_y='real', device='cuda', y_low=None, y_high=None):\n check_type_supported(dtype_x, device)\n check_type_supported(dtype_y, device)\n SIZE = 128\n kernel = patch_kernel(kernel_binary, {'GENERATE_TEST_HERE': expr})\n rs = RandomState(17)\n x = numpy_random(SIZE, dtype_str=dtype_x, rs=rs)\n y = numpy_random(SIZE, dtype_str=dtype_y, rs=rs, low=y_low, high=y_high)\n if mode_x == 'nan':\n x[:] = float('nan')\n if mode_y == 'nan':\n y[:] = float('nan')\n z_ref = eval(expr if numpy_expr is None else numpy_expr)\n dtype_z = _binary_op_dtype_override(dtype_x, dtype_y)\n if dtype_z is not None:\n z_ref = z_ref.astype(dtype_z)\n x_tri = to_triton(x, device=device, dst_type=dtype_x)\n y_tri = to_triton(y, device=device, dst_type=dtype_y)\n z_tri = to_triton(np.empty(SIZE, dtype=z_ref.dtype), device=device)\n kernel[(1, )](z_tri, x_tri, y_tri, SIZE=SIZE, num_warps=4)\n np.testing.assert_allclose(z_ref, to_numpy(z_tri), err_msg=expr, rtol=0.01)\n\n# Kernel for broadcasting\n@triton.jit\ndef broadcast_kernel(x_ptr, y_ptr, y_broadcasted_ptr, M: tl.constexpr, N: tl.constexpr):\n offset1 = tl.arange(0, M)\n offset2 = tl.arange(0, N)\n x = tl.load(x_ptr + N * offset1[:, None] + offset2[None, :])\n y = tl.load(y_ptr + offset2)\n _, y_broadcasted = tl.broadcast(x, y)\n tl.store(y_broadcasted_ptr + N * offset1[:, None] + offset2[None, :], y_broadcasted)\n\n# Function to test broadcasting\ndef test_broadcast(dtype, device):\n M = 32\n N = 64\n rs = RandomState(17)\n x = numpy_random((M, N), dtype_str=dtype, rs=rs)\n y = numpy_random(N, dtype_str=dtype, rs=rs)\n _, y_broadcasted_np = np.broadcast_arrays(x, y)\n x_tri = to_triton(x, device=device, dst_type=dtype)\n y_tri = to_triton(y, device=device, dst_type=dtype)\n y_broadcasted_tri = to_triton(np.empty((M, N), dtype=y_broadcasted_np.dtype), device=device, dst_type=dtype)\n broadcast_kernel[(1,)](x_tri, y_tri, y_broadcasted_tri, M=M, N=N)\n assert (y_broadcasted_np == to_numpy(y_broadcasted_tri)).all()\n", - "description_1": "Use triton language to implement kernels for unary and binary operations, broadcasting, and an empty kernel for testing. The kernels should handle data loading, computation, and storing results. The test functions should validate the kernels by comparing Triton results with NumPy results.", - "description_2": "Use triton language to implement and test kernels for unary and binary operations, broadcasting, and an empty kernel. Ensure correctness by comparing with NumPy results.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n# Kernel that loads data from X and stores it in Y\n@triton.jit\ndef kernel_single(X, Y, BLOCK: tl.constexpr):\n x = tl.load(X + tl.arange(0, BLOCK))\n tl.store(Y + tl.arange(0, BLOCK), x)\n\n# Kernel that calls an inline device function\n@triton.jit\ndef kernel_call(X, Y, BLOCK: tl.constexpr):\n x = tl.load(X + tl.arange(0, BLOCK))\n y = device_inline(x)\n tl.store(Y + tl.arange(0, BLOCK), y)\n\n# Inline device function\n@triton.jit\ndef device_inline(x):\n return x + x\n\n# Kernel that calls a noinline device function\n@triton.jit\ndef kernel_call_noinline(X, Y, BLOCK: tl.constexpr):\n device_noinline(X, Y, BLOCK)\n\n# Noinline device function\n@triton.jit(noinline=True)\ndef device_noinline(X, Y, BLOCK: tl.constexpr):\n x = tl.load(X + tl.arange(0, BLOCK))\n y = x + x\n tl.store(Y + tl.arange(0, BLOCK), y)\n\n# Kernel that applies softmax to data from X and stores it in Y\n@triton.jit\ndef kernel_multi_files(X, Y, BLOCK: tl.constexpr):\n x = tl.load(X + tl.arange(0, BLOCK))\n y = tl.softmax(x)\n tl.store(Y + tl.arange(0, BLOCK), y)\n\n# Test function to execute kernels\ndef test_line_info(func: str):\n shape = (128, )\n x = torch.arange(0, shape[0], dtype=torch.float32, device='cuda')\n y = torch.zeros(shape, dtype=x.dtype, device=\"cuda\")\n if func == \"single\":\n kernel_single[(1,)](x, y, BLOCK=shape[0])\n elif func == \"call\":\n kernel_call[(1,)](x, y, BLOCK=shape[0])\n elif func == \"call_noinline\":\n kernel_call_noinline[(1,)](x, y, BLOCK=shape[0])\n elif func == \"multi_files\":\n kernel_multi_files[(1,)](x, y, BLOCK=shape[0])\n", - "description_1": "Use triton language to define multiple kernels: 'kernel_single' with 3 parameters (X, Y, BLOCK) to load and store data; 'kernel_call' with 3 parameters (X, Y, BLOCK) to load data, process it with an inline function 'device_inline', and store it; 'device_inline' with 1 parameter (x) to double the input; 'kernel_call_noinline' with 3 parameters (X, Y, BLOCK) to call a noinline function 'device_noinline'; 'device_noinline' with 3 parameters (X, Y, BLOCK) to load data, double it, and store it; 'kernel_multi_files' with 3 parameters (X, Y, BLOCK) to apply softmax to loaded data and store it. Test these kernels using 'test_line_info' function with 1 parameter (func) to select the kernel to execute.", - "description_2": "Use triton language to define kernels for data manipulation and processing, including inline and noinline function calls, and test their execution.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\nimport numpy as np\nimport scipy.stats\n\nBLOCK = 1024\n\n# Kernel for generating random uint32\n@triton.jit\ndef kernel_randint(X, N, seed):\n offset = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)\n rand = tl.randint(seed, offset)\n tl.store(X + offset, rand, mask=offset < N)\n\n# Kernel for generating uniform random numbers\n@triton.jit\ndef kernel_rand(X, N, seed):\n offset = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)\n rand = tl.rand(seed, offset)\n tl.store(X + offset, rand, mask=offset < N)\n\n# Kernel for generating normal random numbers\n@triton.jit\ndef kernel_randn(X, N, seed):\n offset = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)\n rand = tl.randn(seed, offset)\n tl.store(X + offset, rand, mask=offset < N)\n\n# Kernel for testing rand limits\n@triton.jit\ndef kernel_rand_limits(input, output, n: tl.constexpr):\n idx = tl.arange(0, n)\n x = tl.load(input + idx)\n y = tl.random.uint32_to_uniform_float(x)\n tl.store(output + idx, y)\n\n# Function to test random uint32 generation\ndef test_randint(size, seed, device):\n size = list(map(int, size.split(',')))\n x = torch.empty(size, dtype=torch.int32, device=device)\n N = x.numel()\n grid = (triton.cdiv(N, BLOCK),)\n kernel_randint[grid](x, N, seed)\n out_tri = x.cpu().numpy().astype(np.uint32).flatten().tolist()\n gen = CustomPhilox4x(seed, config=PHILOX_32)\n out_ref = [gen.random_raw()[0] for _ in out_tri]\n assert out_tri == out_ref\n\n# Function to test uniform PRNG\ndef test_rand(size, seed, device):\n x = torch.empty(size, dtype=torch.float32, device=device)\n N = x.numel()\n grid = (triton.cdiv(N, BLOCK),)\n kernel_rand[grid](x, N, seed)\n assert all((x >= 0) & (x <= 1))\n assert scipy.stats.kstest(x.tolist(), 'uniform', args=(0, 1)).statistic < 0.01\n\n# Function to test normal PRNG\ndef test_randn(size, seed, device):\n x = torch.empty(size, dtype=torch.float32, device=device)\n N = x.numel()\n grid = (triton.cdiv(N, BLOCK),)\n kernel_randn[grid](x, N, seed)\n assert abs(x.mean()) < 1e-2\n assert abs(x.std() - 1) < 1e-2\n\n# Function to test rand limits\ndef test_rand_limits(device):\n min_max_int32 = torch.tensor([\n torch.iinfo(torch.int32).min,\n torch.iinfo(torch.int32).max,\n ], dtype=torch.int32, device=device)\n output = torch.empty(2, dtype=torch.float32, device=device)\n kernel_rand_limits[(1,)](min_max_int32, output, 2)\n assert output[0] == output[1]\n assert 1.0 - torch.finfo(torch.float32).eps <= output[0].item() < 1.0\n", - "description_1": "Use triton language to implement kernels for generating random numbers. The kernel_randint function takes three parameters: X (output tensor), N (number of elements), and seed (random seed). It generates random uint32 numbers and stores them in X. The kernel_rand function also takes three parameters: X (output tensor), N (number of elements), and seed (random seed). It generates uniform random numbers between 0 and 1 and stores them in X. The kernel_randn function takes the same parameters and generates normal random numbers with mean 0 and standard deviation 1, storing them in X. The kernel_rand_limits function takes three parameters: input (input tensor), output (output tensor), and n (number of elements, constexpr). It converts uint32 to uniform float and stores the result in output.", - "description_2": "Use triton language to create kernels for generating random uint32, uniform, and normal numbers, and to test the limits of uniform random number generation.", - "difficulty": 2 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n# Kernel for normalization with rematerialization\n@triton.jit\ndef triton_normalization(in_out_ptr0, in_out_ptr1, in_ptr0, in_ptr1, in_ptr2, in_ptr3, xnumel, rnumel, XBLOCK: tl.constexpr, RBLOCK: tl.constexpr):\n xnumel = 512\n rnumel = 4096\n xoffset = tl.program_id(0) * XBLOCK\n xindex = xoffset + tl.arange(0, XBLOCK)[:, None]\n xmask = xindex < xnumel\n rbase = tl.arange(0, RBLOCK)[None, :]\n x3 = xindex\n x0 = xindex % 64\n tmp1 = tl.load(in_ptr0 + (x0), xmask)\n tmp3 = tl.load(in_ptr1 + (x0), xmask)\n tmp11 = tl.load(in_ptr2 + (x0), xmask)\n tmp13 = tl.load(in_ptr3 + (x0), xmask)\n _tmp17 = tl.zeros([XBLOCK, RBLOCK], tl.float32) + 0\n for roffset in range(0, rnumel, RBLOCK):\n rindex = roffset + rbase\n rmask = rindex < rnumel\n r2 = rindex\n tmp0 = tl.load(in_out_ptr0 + (r2 + (4096 * x3)), rmask & xmask, eviction_policy='evict_last', other=0)\n tmp2 = tmp0 - tmp1\n tmp4 = 1e-05\n tmp5 = tmp3 + tmp4\n tmp6 = tl.sqrt(tmp5)\n tmp7 = 1 / tmp6\n tmp8 = 1.0\n tmp9 = tmp7 * tmp8\n tmp10 = tmp2 * tmp9\n tmp12 = tmp10 * tmp11\n tmp14 = tmp12 + tmp13\n _tmp17 = tl.where(rmask & xmask, _tmp17 + tmp14, _tmp17)\n tl.store(in_out_ptr0 + (r2 + (4096 * x3) + tl.zeros([XBLOCK, RBLOCK], tl.int32)), tmp14, rmask & xmask)\n tmp17 = tl.sum(_tmp17, 1)[:, None]\n tmp18 = 4096.0\n tmp19 = tmp17 / tmp18\n tl.store(in_out_ptr1 + (x3 + tl.zeros([XBLOCK, 1], tl.int32)), tmp19, xmask)\n\n# Kernel for average pooling backward\n@triton.jit\ndef triton_avg_pool_bw(in_ptr0, out_ptr0, XBLOCK: tl.constexpr):\n xoffset = tl.program_id(0) * XBLOCK\n xindex = xoffset + tl.arange(0, XBLOCK)[:]\n x1 = (xindex // 8) % 8\n x0 = xindex % 8\n x2 = (xindex // 64)\n x5 = xindex\n tmp0 = (-1) + x1\n tmp1 = (-1) + x0\n tmp2 = 2 + x1\n tmp3 = 2 + x0\n tmp4 = 0\n tmp5 = tl.where(tmp0 != tmp0, tmp0, tl.where(tmp0 > tmp4, tmp0, tmp4))\n tmp6 = tl.where(tmp1 != tmp1, tmp1, tl.where(tmp1 > tmp4, tmp1, tmp4))\n tmp7 = 8\n tmp8 = tl.where(tmp2 != tmp2, tmp2, tl.where(tmp2 < tmp7, tmp2, tmp7))\n tmp9 = tl.where(tmp3 != tmp3, tmp3, tl.where(tmp3 < tmp7, tmp3, tmp7))\n tmp10 = tmp5 + tmp4\n tmp11 = tmp6 + tmp4\n tmp12 = 1\n tmp13 = tmp8 - tmp12\n tmp14 = tl.where(tmp10 != tmp10, tmp10, tl.where(tmp10 < tmp13, tmp10, tmp13))\n tmp15 = tmp9 - tmp12\n tmp16 = tl.where(tmp11 != tmp11, tmp11, tl.where(tmp11 < tmp15, tmp11, tmp15))\n tmp17 = tl.load(in_ptr0 + (tmp16 + (8 * tmp14) + (64 * x2)), None).to(tl.float32)\n tmp18 = tmp17 / 9\n tmp19 = tmp10 < tmp8\n tmp20 = tmp11 < tmp9\n tmp21 = tmp19 & tmp20\n tmp22 = 0.0\n tmp23 = tl.where(tmp21, tmp18, tmp22)\n tmp24 = tmp6 + tmp12\n tmp25 = tl.where(tmp24 != tmp24, tmp24, tl.where(tmp24 < tmp15, tmp24, tmp15))\n tmp26 = tl.load(in_ptr0 + (tmp25 + (8 * tmp14) + (64 * x2)), None).to(tl.float32)\n tmp27 = tmp26 / 9\n tmp28 = tmp24 < tmp9\n tmp29 = tmp19 & tmp28\n tmp30 = tmp23 + tmp27\n tmp31 = tl.where(tmp29, tmp30, tmp23)\n tmp32 = 2\n tmp33 = tmp6 + tmp32\n tmp34 = tl.where(tmp33 != tmp33, tmp33, tl.where(tmp33 < tmp15, tmp33, tmp15))\n tmp35 = tl.load(in_ptr0 + (tmp34 + (8 * tmp14) + (64 * x2)), None).to(tl.float32)\n tmp36 = tmp35 / 9\n tmp37 = tmp33 < tmp9\n tmp38 = tmp19 & tmp37\n tmp39 = tmp31 + tmp36\n tmp40 = tl.where(tmp38, tmp39, tmp31)\n tmp41 = tmp5 + tmp12\n tmp42 = tl.where(tmp41 != tmp41, tmp41, tl.where(tmp41 < tmp13, tmp41, tmp13))\n tmp43 = tl.load(in_ptr0 + (tmp16 + (8 * tmp42) + (64 * x2)), None).to(tl.float32)\n tmp44 = tmp43 / 9\n tmp45 = tmp41 < tmp8\n tmp46 = tmp45 & tmp20\n tmp47 = tmp40 + tmp44\n tmp48 = tl.where(tmp46, tmp47, tmp40)\n tmp49 = tl.load(in_ptr0 + (tmp25 + (8 * tmp42) + (64 * x2)), None).to(tl.float32)\n tmp50 = tmp49 / 9\n tmp51 = tmp45 & tmp28\n tmp52 = tmp48 + tmp50\n tmp53 = tl.where(tmp51, tmp52, tmp48)\n tmp54 = tl.load(in_ptr0 + (tmp34 + (8 * tmp42) + (64 * x2)), None).to(tl.float32)\n tmp55 = tmp54 / 9\n tmp56 = tmp45 & tmp37\n tmp57 = tmp53 + tmp55\n tmp58 = tl.where(tmp56, tmp57, tmp53)\n tmp59 = tmp5 + tmp32\n tmp60 = tl.where(tmp59 != tmp59, tmp59, tl.where(tmp59 < tmp13, tmp59, tmp13))\n tmp61 = tl.load(in_ptr0 + (tmp16 + (8 * tmp60) + (64 * x2)), None).to(tl.float32)\n tmp62 = tmp61 / 9\n tmp63 = tmp59 < tmp8\n tmp64 = tmp63 & tmp20\n tmp65 = tmp58 + tmp62\n tmp66 = tl.where(tmp64, tmp65, tmp58)\n tmp67 = tl.load(in_ptr0 + (tmp25 + (8 * tmp60) + (64 * x2)), None).to(tl.float32)\n tmp68 = tmp67 / 9\n tmp69 = tmp63 & tmp28\n tmp70 = tmp66 + tmp68\n tmp71 = tl.where(tmp69, tmp70, tmp66)\n tmp72 = tl.load(in_ptr0 + (tmp34 + (8 * tmp60) + (64 * x2)), None).to(tl.float32)\n tmp73 = tmp72 / 9\n tmp74 = tmp63 & tmp37\n tmp75 = tmp71 + tmp73\n tmp76 = tl.where(tmp74, tmp75, tmp71)\n tl.store(out_ptr0 + (x5 + tl.zeros([XBLOCK], tl.int32)), tmp76, None)\n\n# Call the normalization kernel\ndef call_triton_normalization():\n torch.manual_seed(123)\n buf14 = torch.rand(8, 64, 64, 64, device=\"cuda\")\n buf16 = torch.rand(8, 1, 64, device=\"cuda\")\n arg114_1 = torch.rand(64, device=\"cuda\")\n arg115_1 = torch.rand(64, device=\"cuda\")\n arg8_1 = torch.rand(64, device=\"cuda\")\n arg9_1 = torch.rand(64, device=\"cuda\")\n triton_normalization[(512,)](buf14, buf16, arg114_1, arg115_1, arg8_1, arg9_1, 512, 4096, 1, 2048)\n torch.testing.assert_allclose(buf16.mean().item(), buf14.mean().item(), atol=1e-7, rtol=0)\n\n# Call the average pooling backward kernel\ndef call_triton_avg_pool_bw():\n inp = torch.ones(8, 2048, 8, 8, device=\"cuda\", dtype=torch.half)\n out = torch.ones_like(inp) * 3\n numel = inp.numel()\n triton_avg_pool_bw[(numel // 1024,)](inp, out, 1024)\n out_ref = torch.ones_like(inp)\n out_ref[:, :, 1:7, 0::7] = 2 / 3\n out_ref[:, :, 0::7, 1:7] = 2 / 3\n out_ref[:, :, 0::7, 0::7] = 4 / 9\n torch.testing.assert_allclose(out, out_ref)\n", - "description_1": "Use triton language to implement two kernels: one for normalization with rematerialization and another for average pooling backward. The normalization kernel takes 10 parameters: two output pointers, four input pointers, two integers for element counts, and two block size constants. It performs element-wise operations and stores results. The average pooling backward kernel takes three parameters: an input pointer, an output pointer, and a block size constant. It computes average pooling gradients and stores results.", - "description_2": "Use triton language to create a normalization kernel with rematerialization and an average pooling backward kernel, each with specific input/output pointers and block size parameters.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n# Triton kernel for converting float8 to float16\n@triton.jit\ndef kernel(Y, X, N, BLOCK_SIZE: tl.constexpr):\n pid = tl.program_id(0)\n offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)\n mask = offs < N\n x = tl.load(X + offs, mask=mask)\n tl.store(Y + offs, x, mask=mask)\n\n# Function to call the Triton kernel\ndef f8_to_f16(x, dtype):\n ret = torch.empty(x.shape, dtype=torch.float16, device=x.device)\n grid = lambda META: (triton.cdiv(x.numel(), META['BLOCK_SIZE']),)\n dtype = getattr(tl, dtype)\n kernel[grid](ret, triton.reinterpret(x, dtype), ret.numel(), BLOCK_SIZE=1024)\n return ret\n", - "description_1": "Use triton language to implement a kernel that converts float8 data to float16. The kernel takes four parameters: Y (output tensor), X (input tensor), N (number of elements), and BLOCK_SIZE (block size for parallel processing). The function f8_to_f16 calls this kernel, preparing the output tensor and setting up the grid for execution.", - "description_2": "Use triton language to create a kernel for float8 to float16 conversion and a function to execute it.", - "difficulty": 2 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n# Define the kernel using triton.jit\n@triton.jit\ndef _kernel(dst, src, N, BLOCK_SIZE: tl.constexpr):\n # Compute the offsets for the block\n offsets = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)\n # Load elements from src with a mask\n x = tl.load(src + offsets, mask=offsets < N)\n # Store the elements into dst with a mask\n tl.store(dst + offsets, x, mask=offsets < N)\n\n# Define a function to call the kernel\ndef call_kernel():\n N = 1024\n src = torch.empty(N, device='cuda')\n dst = torch.empty(N, device='cuda')\n \n # Configuration for the kernel with different BLOCK_SIZE\n configs = [triton.Config(kwargs={'BLOCK_SIZE': 32}), triton.Config(kwargs={'BLOCK_SIZE': 128})]\n \n # Define the grid lambda\n grid = lambda META: (triton.cdiv(N, META['BLOCK_SIZE']),)\n \n # Call the kernel with two different configurations\n _kernel[grid](dst, src, N)\n _kernel[grid](dst=dst, src=src, N=N)\n", - "description_1": "Use triton language to implement a kernel for copying elements from a source tensor to a destination tensor. The kernel is parameterized by block size, and uses a grid to launch multiple blocks in parallel. The kernel loads elements from the source tensor, applies a mask for bounds checking, and then stores the elements into the destination tensor. It requires 4 parameters: dst (destination tensor), src (source tensor), N (number of elements), and BLOCK_SIZE (block size for processing).", - "description_2": "Use triton language to create a parallel element copying kernel with customizable block size.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n# Triton kernel function that increments an integer and stores it\n@triton.jit\ndef function_1(i):\n i = i + 1\n i = function_2(i)\n return i\n\n# Triton kernel function that increments an integer\n@triton.jit\ndef function_2(i):\n i = i + 1\n return i\n\n# Triton kernel that uses function_1 and stores the result\n@triton.jit\ndef kernel(X, i, BLOCK: tl.constexpr):\n i = i + 1\n i = function_1(i)\n tl.store(X, i)\n\n# Triton kernel with no specialization that uses function_1 and stores the result\n@triton.jit(do_not_specialize=[\"i\"])\ndef kernel_nospec(X, i, BLOCK: tl.constexpr):\n i = i + 1\n i = function_1(i)\n tl.store(X, i)\n\n# Test function to check cache reuse\ndef test_reuse():\n counter = 0\n\n def inc_counter(*args, **kwargs):\n nonlocal counter\n counter += 1\n JITFunction.cache_hook = inc_counter\n reset_tmp_dir()\n x = torch.empty(1, dtype=torch.int32, device='cuda')\n for i in range(10):\n kernel[(1,)](x, 1, BLOCK=1024)\n assert counter == 1\n\n# Test function to check specialization\n@pytest.mark.parametrize('mode', ['enable', 'disable'])\ndef test_specialize(mode):\n counter = 0\n\n def inc_counter(*args, **kwargs):\n nonlocal counter\n counter += 1\n JITFunction.cache_hook = inc_counter\n reset_tmp_dir()\n x = torch.empty(1, dtype=torch.int32, device='cuda')\n function = {'enable': kernel, 'disable': kernel_nospec}[mode]\n target = {'enable': 3, 'disable': 1}[mode]\n for i in [1, 2, 4, 8, 16, 32]:\n function[(1,)](x, i, BLOCK=512)\n assert counter == target\n\n# Triton kernel for adding two arrays\n@triton.jit\ndef kernel_add(a, b, o, N: tl.constexpr):\n idx = tl.arange(0, N)\n tl.store(o + idx, tl.load(a + idx) + tl.load(b + idx))\n\n# Triton kernel for adding two arrays with device-specific operations\n@triton.jit\ndef kernel_add_device(a, b, o, N: tl.constexpr):\n add_fn(a, b, o, N)\n\n# Triton kernel for memory operations\n@triton.jit\ndef kernel(in_ptr0, out_ptr0, xnumel, XBLOCK: tl.constexpr):\n xnumel = 10\n xoffset = tl.program_id(0) * XBLOCK\n xindex = xoffset + tl.arange(0, XBLOCK)[:]\n xmask = xindex < xnumel\n x0 = xindex\n tmp0 = tl.load(in_ptr0 + (x0), xmask)\n tl.store(out_ptr0 + (x0 + tl.zeros([XBLOCK], tl.int32)), tmp0, xmask)\n", - "description_1": "Use triton language to define several kernels: function_1 and function_2 increment an integer; kernel uses function_1 to increment and store a value; kernel_nospec is a non-specialized version of kernel; kernel_add adds two arrays; kernel_add_device uses add_fn to add arrays; kernel performs memory operations with masks.", - "description_2": "Use triton language to create kernels for integer increment, array addition, and masked memory operations.", - "difficulty": 2 - }, - { - "code": "import tracemalloc\nimport torch\nimport triton\nimport triton.language as tl\nimport gc\n\ndef test_memory_leak() -> None:\n\n @triton.jit\n def kernel(in_ptr0, out_ptr0, xnumel, XBLOCK: tl.constexpr):\n # Define the number of elements to process\n xnumel = 10\n # Calculate the offset for the current program ID\n xoffset = tl.program_id(0) * XBLOCK\n # Calculate the index for each element in the block\n xindex = xoffset + tl.arange(0, XBLOCK)[:]\n # Create a mask to ensure we don't go out of bounds\n xmask = xindex < xnumel\n # Load input data with the mask\n x0 = xindex\n tmp0 = tl.load(in_ptr0 + (x0), xmask)\n # Store the result back to the output pointer\n tl.store(out_ptr0 + (x0 + tl.zeros([XBLOCK], tl.int32)), tmp0, xmask)\n\n tracemalloc.start()\n try:\n # Initialize input and output tensors\n inp = torch.randn(10, device='cuda')\n out = torch.randn(10, device='cuda')\n # Launch the kernel\n kernel[(10,)](inp, out, 10, XBLOCK=16)\n gc.collect()\n begin, _ = tracemalloc.get_traced_memory()\n # Run the kernel multiple times to check for memory leaks\n for _ in range(100):\n kernel[(10,)](inp, out, 10, XBLOCK=16)\n gc.collect()\n end, _ = tracemalloc.get_traced_memory()\n # Assert that the memory usage has not increased significantly\n assert end - begin < 5000\n finally:\n tracemalloc.stop()\n", - "description_1": "Use triton language to define a kernel function 'kernel' with four parameters: in_ptr0 (input pointer), out_ptr0 (output pointer), xnumel (number of elements), and XBLOCK (block size as a compile-time constant). The kernel calculates an offset and index for each element in a block, applies a mask to ensure indices are within bounds, loads input data using the mask, and stores the result back to the output pointer. The kernel is called in a function 'test_memory_leak' which initializes input and output tensors, launches the kernel, and checks for memory leaks by running the kernel multiple times and comparing memory usage before and after.", - "description_2": "Use triton language to define a kernel that processes elements in blocks, applies bounds checking, and performs masked load/store operations. Implement a function to test the kernel for memory leaks by running it multiple times and monitoring memory usage.", - "difficulty": 2 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\nimport multiprocessing\nfrom collections import namedtuple\n\ninstance_descriptor = namedtuple(\"instance_descriptor\", [\"divisible_by_16\", \"equal_to_1\"])\n\ndef compile_fn(config, cc):\n @triton.jit\n def kernel_sub(a, b, o, N: tl.constexpr):\n # Kernel to perform element-wise subtraction and multiplication\n idx = tl.arange(0, N)\n tl.store(o + idx, tl.load(a + idx) - tl.load(b + idx) * 777)\n triton.compile(\n fn=kernel_sub,\n signature={0: \"*fp32\", 1: \"*fp32\", 2: \"*fp32\"},\n device=0,\n constants={3: 32},\n configs=[config],\n warm_cache_only=True,\n cc=cc,\n )\n\ndef test_compile_in_subproc() -> None:\n # Test function to compile kernel_sub in a subprocess\n major, minor = torch.cuda.get_device_capability(0)\n cc = major * 10 + minor\n config = instance_descriptor(tuple(range(4)), ())\n\n multiprocessing.set_start_method('fork')\n proc = multiprocessing.Process(\n target=compile_fn,\n args=(config, cc))\n proc.start()\n proc.join()\n assert proc.exitcode == 0\n\ndef compile_fn_dot(config, cc):\n @triton.jit\n def kernel_dot(Z):\n # Kernel to perform dot product on a 16x16 block\n offs = tl.arange(0, 16)[:, None] * 16 + tl.arange(0, 16)[None, :]\n z = tl.load(Z + offs)\n z = tl.dot(z, z)\n tl.store(Z + offs, z)\n\n triton.compile(\n fn=kernel_dot,\n signature={0: \"*fp32\"},\n device=0,\n configs=[config],\n warm_cache_only=True,\n cc=cc,\n )\n\ndef test_compile_in_forked_subproc() -> None:\n # Test function to compile kernel_dot in a subprocess\n reset_tmp_dir()\n major, minor = torch.cuda.get_device_capability(0)\n cc = major * 10 + minor\n config = instance_descriptor(tuple(range(1)), ())\n\n assert multiprocessing.get_start_method() == 'fork'\n proc = multiprocessing.Process(\n target=compile_fn_dot,\n args=(config, cc))\n proc.start()\n proc.join()\n assert proc.exitcode == 0\n", - "description_1": "Use triton language to define two kernels: 'kernel_sub' which performs element-wise subtraction and multiplication on input arrays 'a' and 'b', storing the result in 'o'. It takes 4 parameters: 'a', 'b', 'o' (all pointers to float32 arrays), and 'N' (a constant expression for the range). 'kernel_dot' performs a dot product on a 16x16 block of the input array 'Z'. It takes 1 parameter: 'Z' (a pointer to a float32 array). Both kernels are compiled with specific configurations and device capabilities.", - "description_2": "Use triton language to define and compile kernels for element-wise operations and block-wise dot products on GPU arrays.", - "difficulty": 2 - }, - { - "code": "import triton\nimport triton.language as tl\n\n@triton.jit\ndef kernel(C, A, B,\n stride_cm, stride_cn,\n stride_am, stride_ak,\n stride_bk, stride_bn,\n BLOCK_M: tl.constexpr,\n BLOCK_N: tl.constexpr,\n BLOCK_K: tl.constexpr):\n # Define the range of indices for each block\n ms = tl.arange(0, BLOCK_M)\n ns = tl.arange(0, BLOCK_N)\n ks = tl.arange(0, BLOCK_K)\n \n # Load blocks of A and B matrices\n a = tl.load(A + ms[:, None] * stride_am + ks[None, :] * stride_ak)\n b = tl.load(B + ks[:, None] * stride_bk + ns[None, :] * stride_bn)\n \n # Compute the dot product\n c = tl.dot(a, b)\n \n # Square the result using a utility function\n c = kernel_utils.mul(c, c)\n \n # Store the result in matrix C\n tl.store(C + ms[:, None] * stride_cm + ns[None, :] * stride_cn, c)\n", - "description_1": "Use triton language to implement a matrix multiplication kernel that computes the dot product of sub-blocks of matrices A and B, squares the result, and stores it in matrix C. The kernel takes 11 parameters: C, A, B (pointers to matrices), stride_cm, stride_cn, stride_am, stride_ak, stride_bk, stride_bn (stride values for accessing matrix elements), and BLOCK_M, BLOCK_N, BLOCK_K (block sizes for the computation).", - "description_2": "Use triton language to create a kernel for matrix multiplication with block-wise computation and result squaring.", - "difficulty": 2 - }, - { - "code": "import triton\n\n@triton.jit\ndef _argmax_combine_tie_break_left(value1, index1, value2, index2):\n tie = value1 == value2 and index1 < index2\n gt = value1 > value2 or tie\n v_ret = triton.language.where(gt, value1, value2)\n i_ret = triton.language.where(gt, index1, index2)\n return v_ret, i_ret\n\n@triton.jit\ndef _argmax_combine_tie_break_fast(value1, index1, value2, index2):\n tie = False\n gt = value1 > value2 or tie\n v_ret = triton.language.where(gt, value1, value2)\n i_ret = triton.language.where(gt, index1, index2)\n return v_ret, i_ret\n\n@triton.jit\ndef _fast_max(x, y):\n return triton.language.math.max(x, y)\n\n@triton.jit\ndef max(input, axis=None, return_indices=False, return_indices_tie_break_left=True):\n input = triton.language._promote_reduction_input(input)\n if return_indices:\n if return_indices_tie_break_left:\n return triton.language._reduce_with_indices(input, axis, _argmax_combine_tie_break_left)\n else:\n return triton.language._reduce_with_indices(input, axis, _argmax_combine_tie_break_fast)\n else:\n if triton.language.constexpr(input.dtype.primitive_bitwidth) < 32:\n if triton.language.constexpr(input.dtype.is_floating()):\n input = input.to(triton.language.float32)\n else:\n assert input.dtype.is_integer_type()\n input = input.to(triton.language.int32)\n return triton.language.reduce(input, axis, _fast_max)\n\n@triton.jit\ndef argmax(input, axis, tie_break_left=True):\n (_, ret) = max(input, axis, return_indices=True, return_indices_tie_break_left=tie_break_left)\n return ret\n", - "description_1": "Use triton language to define kernels and functions for computing the maximum value and its index along a specified axis. The kernels `_argmax_combine_tie_break_left`, `_argmax_combine_tie_break_fast`, `_fast_max`, `max`, and `argmax` handle the combination of values and indices, computing the maximum value with optional index retrieval and tie-breaking strategies.", - "description_2": "Use triton language to implement reduction kernels for maximum value computation with index tracking and tie-breaking options.", - "difficulty": 3 - }, - { - "code": "import triton\n\n# Example kernel decorated with @triton.jit\n@triton.jit\ndef kernel(x_ptr, x_size, **META):\n BLOCK_SIZE = META['BLOCK_SIZE']\n # Kernel computation here...\n\n# The function to call the kernel\ndef call_kernel(x_ptr, x_size):\n # Define meta-parameters, e.g., BLOCK_SIZE\n META = {'BLOCK_SIZE': 128}\n # Call the kernel\n kernel[(1,)](x_ptr, x_size, **META)\n", - "description_1": "Use triton language to define a kernel with @triton.jit decorator that takes x_ptr (pointer), x_size (integer), and META (keyword arguments) as parameters. The kernel uses META['BLOCK_SIZE'] to perform computations. Define a function to call this kernel with predefined meta-parameters and execute it.", - "description_2": "Use triton language to create a kernel with pointer and integer inputs, utilizing a meta-parameter for block size, and define a function to execute this kernel.", - "difficulty": 2 - }, - { - "code": "import triton\nimport torch\n\n# Triton kernel function\n@triton.jit\ndef add_kernel(x_ptr, y_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr):\n pid = tl.program_id(0)\n block_start = pid * BLOCK_SIZE\n offsets = block_start + tl.arange(0, BLOCK_SIZE)\n mask = offsets < n_elements\n x = tl.load(x_ptr + offsets, mask=mask)\n y = tl.load(y_ptr + offsets, mask=mask)\n output = x + y\n tl.store(output_ptr + offsets, output, mask=mask)\n\n# Function to call the Triton kernel\ndef add(x: torch.Tensor, y: torch.Tensor):\n assert x.is_cuda and y.is_cuda\n assert x.shape == y.shape\n output = torch.empty_like(x)\n n_elements = x.numel()\n grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)\n add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024)\n return output\n", - "description_1": "Use triton language to define a kernel function 'add_kernel' that takes pointers to two input tensors 'x_ptr' and 'y_ptr', a pointer to an output tensor 'output_ptr', the number of elements 'n_elements', and a block size 'BLOCK_SIZE'. The kernel computes the element-wise sum of 'x' and 'y' and stores the result in 'output'. The function 'add' calls this kernel, ensuring the input tensors are on CUDA, have the same shape, and prepares an output tensor. It calculates the grid size based on the number of elements and block size, and launches the kernel.", - "description_2": "Use triton language to create a kernel for element-wise addition of two CUDA tensors and a function to launch this kernel.", - "difficulty": 1 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef add_kernel(\n x_ptr, # *Pointer* to first input vector.\n y_ptr, # *Pointer* to second input vector.\n output_ptr, # *Pointer* to output vector.\n n_elements, # Size of the vector.\n BLOCK_SIZE: tl.constexpr, # Number of elements each program should process.\n):\n pid = tl.program_id(axis=0) # We use a 1D launch grid so axis is 0.\n block_start = pid * BLOCK_SIZE\n offsets = block_start + tl.arange(0, BLOCK_SIZE)\n mask = offsets < n_elements\n x = tl.load(x_ptr + offsets, mask=mask)\n y = tl.load(y_ptr + offsets, mask=mask)\n output = x + y\n tl.store(output_ptr + offsets, output, mask=mask)\n\ndef add(x: torch.Tensor, y: torch.Tensor):\n output = torch.empty_like(x)\n assert x.is_cuda and y.is_cuda and output.is_cuda\n n_elements = output.numel()\n grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)\n add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024)\n return output\n\ntorch.manual_seed(0)\nsize = 98432\nx = torch.rand(size, device='cuda')\ny = torch.rand(size, device='cuda')\noutput_torch = x + y\noutput_triton = add(x, y)\nprint(output_torch)\nprint(output_triton)\nprint(\n f'The maximum difference between torch and triton is '\n f'{torch.max(torch.abs(output_torch - output_triton))}'\n)\n", - "description_1": "Use triton language to implement a vector addition kernel. The kernel function 'add_kernel' takes five parameters: pointers to the input vectors x and y, a pointer to the output vector, the number of elements in the vectors, and a block size as a compile-time constant. The kernel computes the element-wise sum of x and y, storing the result in the output vector. The 'add' function is a wrapper that prepares the output tensor, sets up the grid for kernel execution, and launches the kernel with the specified block size.", - "description_2": "Use triton language to create a kernel for element-wise vector addition and a wrapper function to execute it on CUDA tensors.", - "difficulty": 1 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.autotune(\n configs=[\n triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8),\n triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2),\n triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2),\n ],\n key=['M', 'N', 'K'],\n)\n@triton.jit\ndef matmul_kernel(\n a_ptr, b_ptr, c_ptr,\n M, N, K,\n stride_am, stride_ak,\n stride_bk, stride_bn,\n stride_cm, stride_cn,\n BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,\n GROUP_SIZE_M: tl.constexpr,\n ACTIVATION: tl.constexpr,\n):\n \"\"\"Kernel for computing the matmul C = A x B.\n A has shape (M, K), B has shape (K, N) and C has shape (M, N)\n \"\"\"\n pid = tl.program_id(axis=0)\n num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)\n num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)\n num_pid_in_group = GROUP_SIZE_M * num_pid_n\n group_id = pid // num_pid_in_group\n first_pid_m = group_id * GROUP_SIZE_M\n group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)\n pid_m = first_pid_m + (pid % group_size_m)\n pid_n = (pid % num_pid_in_group) // group_size_m\n\n offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M\n offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N\n offs_k = tl.arange(0, BLOCK_SIZE_K)\n a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)\n b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)\n\n accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):\n a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)\n b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)\n accumulator += tl.dot(a, b)\n a_ptrs += BLOCK_SIZE_K * stride_ak\n b_ptrs += BLOCK_SIZE_K * stride_bk\n if ACTIVATION == \"leaky_relu\":\n accumulator = leaky_relu(accumulator)\n c = accumulator.to(tl.float16)\n\n offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]\n c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)\n tl.store(c_ptrs, c, mask=c_mask)\n\n@triton.jit\ndef leaky_relu(x):\n x = x + 1\n return tl.where(x >= 0, x, 0.01 * x)\n\ndef matmul(a, b, activation=\"\"):\n assert a.shape[1] == b.shape[0], \"Incompatible dimensions\"\n assert a.is_contiguous(), \"Matrix A must be contiguous\"\n assert b.is_contiguous(), \"Matrix B must be contiguous\"\n M, K = a.shape\n K, N = b.shape\n c = torch.empty((M, N), device=a.device, dtype=a.dtype)\n grid = lambda META: (\n triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']),\n )\n matmul_kernel[grid](\n a, b, c,\n M, N, K,\n a.stride(0), a.stride(1),\n b.stride(0), b.stride(1),\n c.stride(0), c.stride(1),\n ACTIVATION=activation\n )\n return c\n", - "description_1": "Use triton language to implement a matrix multiplication kernel with auto-tuning capabilities. The kernel takes pointers to matrices A, B, and C, their dimensions M, N, K, and stride information for each matrix. It also accepts meta-parameters for block sizes and group size, and an optional activation function. The kernel computes the product of matrices A and B, storing the result in C, with optional activation applied.", - "description_2": "Use triton language to create a matrix multiplication kernel with configurable block sizes and optional activation function, optimized for L2 cache reuse.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef _dropout(\n x_ptr, # pointer to the input\n x_keep_ptr, # pointer to a mask of 0s and 1s\n output_ptr, # pointer to the output\n n_elements, # number of elements in the `x` tensor\n p, # probability that an element of `x` is changed to zero\n BLOCK_SIZE: tl.constexpr,\n):\n pid = tl.program_id(axis=0)\n block_start = pid * BLOCK_SIZE\n offsets = block_start + tl.arange(0, BLOCK_SIZE)\n mask = offsets < n_elements\n # Load data\n x = tl.load(x_ptr + offsets, mask=mask)\n x_keep = tl.load(x_keep_ptr + offsets, mask=mask)\n # The line below is the crucial part, described in the paragraph above!\n output = tl.where(x_keep, x / (1 - p), 0.0)\n # Write-back output\n tl.store(output_ptr + offsets, output, mask=mask)\n\ndef dropout(x, x_keep, p):\n output = torch.empty_like(x)\n assert x.is_contiguous()\n n_elements = x.numel()\n grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)\n _dropout[grid](x, x_keep, output, n_elements, p, BLOCK_SIZE=1024)\n return output\n\n@triton.jit\ndef _seeded_dropout(\n x_ptr,\n output_ptr,\n n_elements,\n p,\n seed,\n BLOCK_SIZE: tl.constexpr,\n):\n # compute memory offsets of elements handled by this instance\n pid = tl.program_id(axis=0)\n block_start = pid * BLOCK_SIZE\n offsets = block_start + tl.arange(0, BLOCK_SIZE)\n # load data from x\n mask = offsets < n_elements\n x = tl.load(x_ptr + offsets, mask=mask)\n # randomly prune it\n random = tl.rand(seed, offsets)\n x_keep = random > p\n # write-back\n output = tl.where(x_keep, x / (1 - p), 0.0)\n tl.store(output_ptr + offsets, output, mask=mask)\n\ndef seeded_dropout(x, p, seed):\n output = torch.empty_like(x)\n assert x.is_contiguous()\n n_elements = x.numel()\n grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)\n _seeded_dropout[grid](x, output, n_elements, p, seed, BLOCK_SIZE=1024)\n return output\n\nx = torch.randn(size=(10,)).cuda()\n# Dropout mask\np = 0.5\nx_keep = (torch.rand(size=(10,)) > p).to(torch.int32).cuda()\noutput = dropout(x, x_keep=x_keep, p=p)\n\noutput = seeded_dropout(x, p=0.5, seed=123)\noutput2 = seeded_dropout(x, p=0.5, seed=123)\noutput3 = seeded_dropout(x, p=0.5, seed=512)\n", - "description_1": "Use triton language to implement two dropout kernels. The first kernel, _dropout, takes six parameters: pointers to input, mask, and output tensors, the number of elements, dropout probability, and block size. It applies dropout using a precomputed mask. The second kernel, _seeded_dropout, takes six parameters: pointers to input and output tensors, the number of elements, dropout probability, a random seed, and block size. It applies dropout using a generated random mask based on the seed. Both kernels are called by their respective wrapper functions, dropout and seeded_dropout, which handle tensor preparation and kernel invocation.", - "description_2": "Use triton language to create a dropout kernel with a precomputed mask and another with a random seed-based mask.", - "difficulty": 2 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n\n@triton.jit\ndef _layer_norm_fwd_fused(\n X, Y, W, B, Mean, Rstd, stride, N, eps, BLOCK_SIZE: tl.constexpr,\n):\n # Map the program id to the row of X and Y it should compute.\n row = tl.program_id(0)\n Y += row * stride\n X += row * stride\n # Compute mean\n mean = 0\n _mean = tl.zeros([BLOCK_SIZE], dtype=tl.float32)\n for off in range(0, N, BLOCK_SIZE):\n cols = off + tl.arange(0, BLOCK_SIZE)\n a = tl.load(X + cols, mask=cols < N, other=0.).to(tl.float32)\n _mean += a\n mean = tl.sum(_mean, axis=0) / N\n # Compute variance\n _var = tl.zeros([BLOCK_SIZE], dtype=tl.float32)\n for off in range(0, N, BLOCK_SIZE):\n cols = off + tl.arange(0, BLOCK_SIZE)\n x = tl.load(X + cols, mask=cols < N, other=0.).to(tl.float32)\n x = tl.where(cols < N, x - mean, 0.)\n _var += x * x\n var = tl.sum(_var, axis=0) / N\n rstd = 1 / tl.sqrt(var + eps)\n # Write mean / rstd\n tl.store(Mean + row, mean)\n tl.store(Rstd + row, rstd)\n # Normalize and apply linear transformation\n for off in range(0, N, BLOCK_SIZE):\n cols = off + tl.arange(0, BLOCK_SIZE)\n mask = cols < N\n w = tl.load(W + cols, mask=mask)\n b = tl.load(B + cols, mask=mask)\n x = tl.load(X + cols, mask=mask, other=0.).to(tl.float32)\n x_hat = (x - mean) * rstd\n y = x_hat * w + b\n # Write output\n tl.store(Y + cols, y, mask=mask)\n\n\n@triton.jit\ndef _layer_norm_bwd_dx_fused(\n DX, DY, DW, DB, X, W, B, Mean, Rstd, Lock, stride, N, eps, GROUP_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr\n):\n # Map the program id to the elements of X, DX, and DY it should compute.\n row = tl.program_id(0)\n cols = tl.arange(0, BLOCK_SIZE_N)\n mask = cols < N\n X += row * stride\n DY += row * stride\n DX += row * stride\n # Offset locks and weights/biases gradient pointer for parallel reduction\n lock_id = row % GROUP_SIZE_M\n Lock += lock_id\n Count = Lock + GROUP_SIZE_M\n DW = DW + lock_id * N + cols\n DB = DB + lock_id * N + cols\n # Load data to SRAM\n x = tl.load(X + cols, mask=mask, other=0).to(tl.float32)\n dy = tl.load(DY + cols, mask=mask, other=0).to(tl.float32)\n w = tl.load(W + cols, mask=mask).to(tl.float32)\n mean = tl.load(Mean + row)\n rstd = tl.load(Rstd + row)\n # Compute dx\n xhat = (x - mean) * rstd\n wdy = w * dy\n xhat = tl.where(mask, xhat, 0.)\n wdy = tl.where(mask, wdy, 0.)\n c1 = tl.sum(xhat * wdy, axis=0) / N\n c2 = tl.sum(wdy, axis=0) / N\n dx = (wdy - (xhat * c1 + c2)) * rstd\n # Write dx\n tl.store(DX + cols, dx, mask=mask)\n # Accumulate partial sums for dw/db\n partial_dw = (dy * xhat).to(w.dtype)\n partial_db = (dy).to(w.dtype)\n while tl.atomic_cas(Lock, 0, 1) == 1:\n pass\n count = tl.load(Count)\n # First store doesn't accumulate\n if count == 0:\n tl.atomic_xchg(Count, 1)\n else:\n partial_dw += tl.load(DW, mask=mask)\n partial_db += tl.load(DB, mask=mask)\n tl.store(DW, partial_dw, mask=mask)\n tl.store(DB, partial_db, mask=mask)\n # Release the lock\n tl.atomic_xchg(Lock, 0)\n\n\n@triton.jit\ndef _layer_norm_bwd_dwdb(\n DW, DB, FINAL_DW, FINAL_DB, M, N, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr\n):\n # Map the program id to the elements of DW and DB it should compute.\n pid = tl.program_id(0)\n cols = pid * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n dw = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n db = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n # Iterate through the rows of DW and DB to sum the partial sums.\n for i in range(0, M, BLOCK_SIZE_M):\n rows = i + tl.arange(0, BLOCK_SIZE_M)\n mask = (rows[:, None] < M) & (cols[None, :] < N)\n offs = rows[:, None] * N + cols[None, :]\n dw += tl.load(DW + offs, mask=mask, other=0.)\n db += tl.load(DB + offs, mask=mask, other=0.)\n # Write the final sum to the output.\n sum_dw = tl.sum(dw, axis=0)\n sum_db = tl.sum(db, axis=0)\n tl.store(FINAL_DW + cols, sum_dw, mask=cols < N)\n tl.store(FINAL_DB + cols, sum_db, mask=cols < N)\n\n\nclass LayerNorm(torch.autograd.Function):\n\n @staticmethod\n def forward(ctx, x, normalized_shape, weight, bias, eps):\n # allocate output\n y = torch.empty_like(x)\n # reshape input data into 2D tensor\n x_arg = x.reshape(-1, x.shape[-1])\n M, N = x_arg.shape\n mean = torch.empty((M, ), dtype=torch.float32, device='cuda')\n rstd = torch.empty((M, ), dtype=torch.float32, device='cuda')\n # Less than 64KB per feature: enqueue fused kernel\n MAX_FUSED_SIZE = 65536 // x.element_size()\n BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))\n if N > BLOCK_SIZE:\n raise RuntimeError(\"This layer norm doesn't support feature dim >= 64KB.\")\n # heuristics for number of warps\n num_warps = min(max(BLOCK_SIZE // 256, 1), 8)\n # enqueue kernel\n _layer_norm_fwd_fused[(M,)](x_arg, y, weight, bias, mean, rstd,\n x_arg.stride(0), N, eps,\n BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps)\n ctx.save_for_backward(x, weight, bias, mean, rstd)\n ctx.BLOCK_SIZE = BLOCK_SIZE\n ctx.num_warps = num_warps\n ctx.eps = eps\n return y\n\n @staticmethod\n def backward(ctx, dy):\n x, w, b, m, v = ctx.saved_tensors\n # heuristics for amount of parallel reduction stream for DW/DB\n N = w.shape[0]\n GROUP_SIZE_M = 64\n if N <= 8192: GROUP_SIZE_M = 96\n if N <= 4096: GROUP_SIZE_M = 128\n if N <= 1024: GROUP_SIZE_M = 256\n # allocate output\n locks = torch.zeros(2 * GROUP_SIZE_M, dtype=torch.int32, device='cuda')\n _dw = torch.empty((GROUP_SIZE_M, w.shape[0]), dtype=x.dtype, device=w.device)\n _db = torch.empty((GROUP_SIZE_M, w.shape[0]), dtype=x.dtype, device=w.device)\n dw = torch.empty((w.shape[0],), dtype=w.dtype, device=w.device)\n db = torch.empty((w.shape[0],), dtype=w.dtype, device=w.device)\n dx = torch.empty_like(dy)\n # enqueue kernel using forward pass heuristics\n # also compute partial sums for DW and DB\n x_arg = x.reshape(-1, x.shape[-1])\n M, N = x_arg.shape\n _layer_norm_bwd_dx_fused[(M,)](dx, dy, _dw, _db, x, w, b, m, v, locks,\n x_arg.stride(0), N, ctx.eps,\n BLOCK_SIZE_N=ctx.BLOCK_SIZE,\n GROUP_SIZE_M=GROUP_SIZE_M,\n num_warps=ctx.num_warps)\n grid = lambda meta: [triton.cdiv(N, meta['BLOCK_SIZE_N'])]\n # accumulate partial sums in separate kernel\n _layer_norm_bwd_dwdb[grid](_dw, _db, dw, db, GROUP_SIZE_M, N,\n BLOCK_SIZE_M=32,\n BLOCK_SIZE_N=128)\n return dx, None, dw, db, None\n\n\nlayer_norm = LayerNorm.apply\n", - "description_1": "Use triton language to implement a layer normalization function with three kernels: '_layer_norm_fwd_fused', '_layer_norm_bwd_dx_fused', and '_layer_norm_bwd_dwdb'. The '_layer_norm_fwd_fused' kernel normalizes the input tensor and applies a linear transformation using weights and biases, computing the mean and variance for normalization. The '_layer_norm_bwd_dx_fused' kernel computes gradients with respect to the input and accumulates partial gradients for the weights and biases using parallel reduction. The '_layer_norm_bwd_dwdb' kernel sums the partial gradients across different program instances. The primary function 'LayerNorm' encapsulates both forward and backward passes, enabling efficient layer normalization suitable for GPUs.", - "description_2": "Use triton language to create a high-performance layer normalization function with kernels for forward and backward passes, leveraging parallel reduction for gradient computation.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef _fwd_kernel(\n Q, K, V, sm_scale,\n L,\n Out,\n stride_qz, stride_qh, stride_qm, stride_qk,\n stride_kz, stride_kh, stride_kn, stride_kk,\n stride_vz, stride_vh, stride_vk, stride_vn,\n stride_oz, stride_oh, stride_om, stride_on,\n Z, H, N_CTX,\n BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,\n BLOCK_N: tl.constexpr,\n IS_CAUSAL: tl.constexpr,\n):\n # Kernel function code...\n start_m = tl.program_id(0)\n off_hz = tl.program_id(1)\n qvk_offset = off_hz * stride_qh\n Q_block_ptr = tl.make_block_ptr(\n base=Q + qvk_offset,\n shape=(N_CTX, BLOCK_DMODEL),\n strides=(stride_qm, stride_qk),\n offsets=(start_m * BLOCK_M, 0),\n block_shape=(BLOCK_M, BLOCK_DMODEL),\n order=(1, 0)\n )\n K_block_ptr = tl.make_block_ptr(\n base=K + qvk_offset,\n shape=(BLOCK_DMODEL, N_CTX),\n strides=(stride_kk, stride_kn),\n offsets=(0, 0),\n block_shape=(BLOCK_DMODEL, BLOCK_N),\n order=(0, 1)\n )\n V_block_ptr = tl.make_block_ptr(\n base=V + qvk_offset,\n shape=(N_CTX, BLOCK_DMODEL),\n strides=(stride_vk, stride_vn),\n offsets=(0, 0),\n block_shape=(BLOCK_N, BLOCK_DMODEL),\n order=(1, 0)\n )\n # initialize offsets\n offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)\n offs_n = tl.arange(0, BLOCK_N)\n # initialize pointer to m and l\n m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float(\"inf\")\n l_i = tl.zeros([BLOCK_M], dtype=tl.float32)\n acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)\n # scale sm_scale by log_2(e) and use\n # 2^x instead of exp in the loop because CSE and LICM\n # don't work as expected with `exp` in the loop\n qk_scale = sm_scale * 1.44269504\n # load q: it will stay in SRAM throughout\n q = tl.load(Q_block_ptr)\n q = (q * qk_scale).to(tl.float16)\n # loop over k, v and update accumulator\n lo = 0\n hi = (start_m + 1) * BLOCK_M if IS_CAUSAL else N_CTX\n for start_n in range(lo, hi, BLOCK_N):\n # -- load k, v --\n k = tl.load(K_block_ptr)\n v = tl.load(V_block_ptr)\n # -- compute qk ---\n qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)\n if IS_CAUSAL:\n qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float(\"-inf\"))\n qk += tl.dot(q, k)\n # -- compute scaling constant ---\n m_i_new = tl.maximum(m_i, tl.max(qk, 1))\n alpha = tl.math.exp2(m_i - m_i_new)\n p = tl.math.exp2(qk - m_i_new[:, None])\n # -- scale and update acc --\n acc_scale = l_i * 0 + alpha # workaround some compiler bug\n acc *= acc_scale[:, None]\n acc += tl.dot(p.to(tl.float16), v)\n # -- update m_i and l_i --\n l_i = l_i * alpha + tl.sum(p, 1)\n m_i = m_i_new\n # update pointers\n K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N))\n V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0))\n # write back l and m\n acc = acc / l_i[:, None]\n l_ptrs = L + off_hz * N_CTX + offs_m\n tl.store(l_ptrs, m_i + tl.math.log2(l_i))\n # write back O\n O_block_ptr = tl.make_block_ptr(\n base=Out + qvk_offset,\n shape=(N_CTX, BLOCK_DMODEL),\n strides=(stride_om, stride_on),\n offsets=(start_m * BLOCK_M, 0),\n block_shape=(BLOCK_M, BLOCK_DMODEL),\n order=(1, 0)\n )\n tl.store(O_block_ptr, acc.to(tl.float16))\n\n\n@triton.jit\ndef _bwd_preprocess(\n Out, DO,\n Delta,\n BLOCK_M: tl.constexpr, D_HEAD: tl.constexpr,\n):\n # Kernel function code...\n off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)\n off_n = tl.arange(0, D_HEAD)\n # load\n o = tl.load(Out + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32)\n do = tl.load(DO + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32)\n # compute\n delta = tl.sum(o * do, axis=1)\n # write-back\n tl.store(Delta + off_m, delta)\n\n\n@triton.jit\ndef _bwd_kernel(\n Q, K, V, sm_scale, Out, DO,\n DQ, DK, DV,\n L,\n D,\n stride_qz, stride_qh, stride_qm, stride_qk,\n stride_kz, stride_kh, stride_kn, stride_kk,\n stride_vz, stride_vh, stride_vk, stride_vn,\n Z, H, N_CTX,\n num_block,\n BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,\n BLOCK_N: tl.constexpr,\n CAUSAL: tl.constexpr,\n):\n # Kernel function code...\n off_hz = tl.program_id(0)\n off_z = off_hz // H\n off_h = off_hz % H\n qk_scale = sm_scale * 1.44269504\n # offset pointers for batch/head\n Q += off_z * stride_qz + off_h * stride_qh\n K += off_z * stride_qz + off_h * stride_qh\n V += off_z * stride_qz + off_h * stride_qh\n DO += off_z * stride_qz + off_h * stride_qh\n DQ += off_z * stride_qz + off_h * stride_qh\n DK += off_z * stride_qz + off_h * stride_qh\n DV += off_z * stride_qz + off_h * stride_qh\n for start_n in range(0, num_block):\n if CAUSAL:\n lo = start_n * BLOCK_M\n else:\n lo = 0\n # initialize row/col offsets\n offs_qm = lo + tl.arange(0, BLOCK_M)\n offs_n = start_n * BLOCK_M + tl.arange(0, BLOCK_M)\n offs_m = tl.arange(0, BLOCK_N)\n offs_k = tl.arange(0, BLOCK_DMODEL)\n # initialize pointers to value-like data\n q_ptrs = Q + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk)\n k_ptrs = K + (offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk)\n v_ptrs = V + (offs_n[:, None] * stride_qm + offs_k[None, :] * stride_qk)\n do_ptrs = DO + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk)\n dq_ptrs = DQ + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk)\n # pointer to row-wise quantities in value-like data\n D_ptrs = D + off_hz * N_CTX\n l_ptrs = L + off_hz * N_CTX\n # initialize dv amd dk\n dv = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)\n dk = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)\n # k and v stay in SRAM throughout\n k = tl.load(k_ptrs)\n v = tl.load(v_ptrs)\n # loop over rows\n for start_m in range(lo, num_block * BLOCK_M, BLOCK_M):\n offs_m_curr = start_m + offs_m\n # load q, k, v, do on-chip\n q = tl.load(q_ptrs)\n # recompute p = softmax(qk, dim=-1).T\n if CAUSAL:\n qk = tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), float(0.), float(\"-inf\"))\n else:\n qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)\n qk += tl.dot(q, tl.trans(k))\n qk *= qk_scale\n l_i = tl.load(l_ptrs + offs_m_curr)\n p = tl.math.exp2(qk - l_i[:, None])\n # compute dv\n do = tl.load(do_ptrs)\n dv += tl.dot(tl.trans(p.to(Q.dtype.element_ty)), do)\n # compute dp = dot(v, do)\n Di = tl.load(D_ptrs + offs_m_curr)\n dp = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - Di[:, None]\n dp += tl.dot(do, tl.trans(v))\n # compute ds = p * (dp - delta[:, None])\n ds = p * dp * sm_scale\n # compute dk = dot(ds.T, q)\n dk += tl.dot(tl.trans(ds.to(Q.dtype.element_ty)), q)\n # compute dq\n dq = tl.load(dq_ptrs)\n dq += tl.dot(ds.to(Q.dtype.element_ty), k)\n tl.store(dq_ptrs, dq)\n # increment pointers\n dq_ptrs += BLOCK_M * stride_qm\n q_ptrs += BLOCK_M * stride_qm\n do_ptrs += BLOCK_M * stride_qm\n # write-back\n dv_ptrs = DV + (offs_n[:, None] * stride_qm + offs_k[None, :] * stride_qk)\n dk_ptrs = DK + (offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk)\n tl.store(dv_ptrs, dv)\n tl.store(dk_ptrs, dk)\n\n\nclass _attention(torch.autograd.Function):\n\n @staticmethod\n def forward(ctx, q, k, v, causal, sm_scale):\n # shape constraints\n Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]\n assert Lq == Lk and Lk == Lv\n assert Lk in {16, 32, 64, 128}\n o = torch.empty_like(q)\n BLOCK_M = 128\n BLOCK_N = 64\n grid = (triton.cdiv(q.shape[2], BLOCK_M), q.shape[0] * q.shape[1], 1)\n L = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)\n\n num_warps = 4 if Lk <= 64 else 8\n _fwd_kernel[grid](\n q, k, v, sm_scale,\n L,\n o,\n q.stride(0), q.stride(1), q.stride(2), q.stride(3),\n k.stride(0), k.stride(1), k.stride(2), k.stride(3),\n v.stride(0), v.stride(1), v.stride(2), v.stride(3),\n o.stride(0), o.stride(1), o.stride(2), o.stride(3),\n q.shape[0], q.shape[1], q.shape[2],\n BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_DMODEL=Lk,\n IS_CAUSAL=causal,\n num_warps=num_warps,\n num_stages=4)\n\n ctx.save_for_backward(q, k, v, o, L)\n ctx.grid = grid\n ctx.sm_scale = sm_scale\n ctx.BLOCK_DMODEL = Lk\n ctx.causal = causal\n return o\n\n @staticmethod\n def backward(ctx, do):\n BLOCK = 128\n q, k, v, o, L = ctx.saved_tensors\n do = do.contiguous()\n dq = torch.zeros_like(q, dtype=torch.float32)\n dk = torch.empty_like(k)\n dv = torch.empty_like(v)\n delta = torch.empty_like(L)\n _bwd_preprocess[(ctx.grid[0] * ctx.grid[1], )](\n o, do,\n delta,\n BLOCK_M=BLOCK, D_HEAD=ctx.BLOCK_DMODEL,\n )\n _bwd_kernel[(ctx.grid[1],)](\n q, k, v, ctx.sm_scale,\n o, do,\n dq, dk, dv,\n L, delta,\n q.stride(0), q.stride(1), q.stride(2), q.stride(3),\n k.stride(0), k.stride(1), k.stride(2), k.stride(3),\n v.stride(0), v.stride(1), v.stride(2), v.stride(3),\n q.shape[0], q.shape[1], q.shape[2],\n ctx.grid[0],\n BLOCK_M=BLOCK, BLOCK_N=BLOCK,\n BLOCK_DMODEL=ctx.BLOCK_DMODEL, num_warps=8,\n CAUSAL=ctx.causal,\n num_stages=1,\n )\n return dq, dk, dv, None, None\n\n\nattention = _attention.apply\n", - "description_1": "Use triton language to implement Flash Attention algorithm with three kernel functions. `_fwd_kernel` takes 27 parameters and performs forward pass to compute attention outputs. `_bwd_preprocess` uses 4 parameters to perform operations needed before backward pass. `_bwd_kernel` with 29 parameters computes gradients for inputs in backward pass. The `_attention` class utilizes these kernels to execute forward and backward operations.", - "description_2": "Use triton language to implement forward and backward kernels for attention mechanism, optimizing operations like matrix multiplications and softmax computations, while handling gradient calculations efficiently.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef asin_kernel(\n x_ptr,\n y_ptr,\n n_elements,\n BLOCK_SIZE: tl.constexpr,\n):\n # Calculate program ID and offsets for each block\n pid = tl.program_id(axis=0)\n block_start = pid * BLOCK_SIZE\n offsets = block_start + tl.arange(0, BLOCK_SIZE)\n mask = offsets < n_elements\n # Load elements from x_ptr with masking\n x = tl.load(x_ptr + offsets, mask=mask)\n # Apply the arc sine function using triton's libdevice support\n x = tl.math.asin(x)\n # Store the results back into y_ptr\n tl.store(y_ptr + offsets, x, mask=mask)\n\ntorch.manual_seed(0)\nsize = 98432\nx = torch.rand(size, device='cuda')\noutput_triton = torch.zeros(size, device='cuda')\noutput_torch = torch.asin(x)\nassert x.is_cuda and output_triton.is_cuda\nn_elements = output_torch.numel()\ngrid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)\n# Invoke the Triton kernel for arc sine calculation\nasin_kernel[grid](x, output_triton, n_elements, BLOCK_SIZE=1024)\nprint(output_torch)\nprint(output_triton)\nprint(\n f'The maximum difference between torch and triton is '\n f'{torch.max(torch.abs(output_torch - output_triton))}'\n)\n\noutput_triton = torch.empty_like(x)\n# Invoke the Triton kernel with custom libdevice path\nasin_kernel[grid](x, output_triton, n_elements, BLOCK_SIZE=1024,\n extern_libs={'libdevice': '/usr/local/cuda/nvvm/libdevice/libdevice.10.bc'})\nprint(output_torch)\nprint(output_triton)\nprint(\n f'The maximum difference between torch and triton is '\n f'{torch.max(torch.abs(output_torch - output_triton))}'\n)\n", - "description_1": "Use triton language to define and execute a kernel that computes the arc sine of elements from an input tensor and stores the results in an output tensor. The kernel uses a BLOCK_SIZE parameter to determine execution configuration and masks loads and stores to respect tensor boundaries.", - "description_2": "Use triton language to create a kernel that applies the arc sine function on a CUDA tensor using libdevice's asin function and stores the result in another CUDA tensor, with support for custom libdevice paths.", - "difficulty": 2 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.autotune(\n configs=[\n triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2),\n triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2),\n ],\n key=['M', 'N', 'K'],\n)\n@triton.jit\ndef matmul_kernel_with_block_pointers(\n a_ptr, b_ptr, c_ptr,\n M, N, K,\n stride_am, stride_ak,\n stride_bk, stride_bn,\n stride_cm, stride_cn,\n BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,\n GROUP_SIZE_M: tl.constexpr\n):\n \"\"\"Kernel for computing the matmul C = A x B.\n A has shape (M, K), B has shape (K, N) and C has shape (M, N)\n \"\"\"\n pid = tl.program_id(axis=0)\n num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)\n num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)\n num_pid_in_group = GROUP_SIZE_M * num_pid_n\n group_id = pid // num_pid_in_group\n first_pid_m = group_id * GROUP_SIZE_M\n group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)\n pid_m = first_pid_m + (pid % group_size_m)\n pid_n = (pid % num_pid_in_group) // group_size_m\n\n a_block_ptr = tl.make_block_ptr(base=a_ptr, shape=(M, K), strides=(stride_am, stride_ak),\n offsets=(pid_m * BLOCK_SIZE_M, 0), block_shape=(BLOCK_SIZE_M, BLOCK_SIZE_K),\n order=(1, 0))\n b_block_ptr = tl.make_block_ptr(base=b_ptr, shape=(K, N), strides=(stride_bk, stride_bn),\n offsets=(0, pid_n * BLOCK_SIZE_N), block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_N),\n order=(1, 0))\n\n accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n for k in range(0, K, BLOCK_SIZE_K):\n a = tl.load(a_block_ptr, boundary_check=(0, 1))\n b = tl.load(b_block_ptr, boundary_check=(0, 1))\n accumulator += tl.dot(a, b)\n a_block_ptr = tl.advance(a_block_ptr, (0, BLOCK_SIZE_K))\n b_block_ptr = tl.advance(b_block_ptr, (BLOCK_SIZE_K, 0))\n c = accumulator.to(tl.float16)\n\n c_block_ptr = tl.make_block_ptr(base=c_ptr, shape=(M, N), strides=(stride_cm, stride_cn),\n offsets=(pid_m * BLOCK_SIZE_M, pid_n * BLOCK_SIZE_N),\n block_shape=(BLOCK_SIZE_M, BLOCK_SIZE_N), order=(1, 0))\n tl.store(c_block_ptr, c, boundary_check=(0, 1))\n\n\ndef matmul(a, b):\n assert a.shape[1] == b.shape[0], \"Incompatible dimensions\"\n assert a.is_contiguous(), \"Matrix A must be contiguous\"\n assert b.is_contiguous(), \"Matrix B must be contiguous\"\n M, K = a.shape\n K, N = b.shape\n c = torch.empty((M, N), device=a.device, dtype=a.dtype)\n grid = lambda META: (\n triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']),\n )\n matmul_kernel_with_block_pointers[grid](\n a, b, c,\n M, N, K,\n a.stride(0), a.stride(1),\n b.stride(0), b.stride(1),\n c.stride(0), c.stride(1),\n )\n return c\n\n\ntorch.manual_seed(0)\na = torch.randn((512, 512), device='cuda', dtype=torch.float16)\nb = torch.randn((512, 512), device='cuda', dtype=torch.float16)\ntriton_output = matmul(a, b)\ntorch_output = torch.matmul(a, b)\nprint(f\"triton_output={triton_output}\")\nprint(f\"torch_output={torch_output}\")\nif torch.allclose(triton_output, torch_output, atol=1e-2, rtol=0):\n print(\"✅ Triton and Torch match\")\nelse:\n print(\"❌ Triton and Torch differ\")\n", - "description_1": "Use triton language to implement a matrix multiplication kernel using block pointers. The kernel function 'matmul_kernel_with_block_pointers' has 17 parameters: three pointers to matrices (a_ptr, b_ptr, c_ptr), three matrix dimensions (M, N, K), six stride variables representing memory strides of the input matrices, and four meta-parameters that define block sizes and group size. The kernel uses block pointers to load blocks of matrices A and B, computes their dot product, and stores the result in matrix C. The function 'matmul' serves as a wrapper for the kernel, ensuring input matrices' shape constraints, allocating output, and launching the kernel.", - "description_2": "Use triton language to create a block-pointer-based matrix multiplication kernel for enhanced memory access patterns, handling matrix dimensions, and strides as inputs.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n# Triton kernel for forward pass of RMS normalization\n@triton.jit\ndef _rms_norm_fwd_fused(\n X, # pointer to the input\n Y, # pointer to the output\n W, # pointer to the weights\n Rstd, # pointer to the 1/std\n stride, # how much to increase the pointer when moving by 1 row\n N, # number of columns in X\n eps, # epsilon to avoid division by zero\n BLOCK_SIZE: tl.constexpr,\n):\n # Map the program id to the row of X and Y it should compute.\n row = tl.program_id(0)\n Y += row * stride\n X += row * stride\n # Compute variance\n _var = tl.zeros([BLOCK_SIZE], dtype=tl.float32)\n for off in range(0, N, BLOCK_SIZE):\n cols = off + tl.arange(0, BLOCK_SIZE)\n x = tl.load(X + cols, mask=cols < N, other=0.).to(tl.float32)\n _var += x * x\n var = tl.sum(_var, axis=0) / N\n rstd = 1 / tl.sqrt(var + eps)\n # Write rstd\n tl.store(Rstd + row, rstd)\n # Normalize and apply linear transformation\n for off in range(0, N, BLOCK_SIZE):\n cols = off + tl.arange(0, BLOCK_SIZE)\n mask = cols < N\n w = tl.load(W + cols, mask=mask)\n x = tl.load(X + cols, mask=mask, other=0.).to(tl.float32)\n x_hat = x * rstd\n y = x_hat * w\n # Write output\n tl.store(Y + cols, y, mask=mask)\n\n\n# Triton kernel for backward pass computing dx\n@triton.jit\ndef _rms_norm_bwd_dx_fused(\n DX, # pointer to the input gradient\n DY, # pointer to the output gradient\n DW, # pointer to the partial sum of weights gradient\n X, # pointer to the input\n W, # pointer to the weights\n Rstd, # pointer to the 1/std\n Lock, # pointer to the lock\n stride, # how much to increase the pointer when moving by 1 row\n N, # number of columns in X\n eps, # epsilon to avoid division by zero\n GROUP_SIZE_M: tl.constexpr,\n BLOCK_SIZE_N: tl.constexpr):\n # Map the program id to the elements of X, DX, and DY it should compute.\n row = tl.program_id(0)\n cols = tl.arange(0, BLOCK_SIZE_N)\n mask = cols < N\n X += row * stride\n DY += row * stride\n DX += row * stride\n # Offset locks and weights/biases gradient pointer for parallel reduction\n lock_id = row % GROUP_SIZE_M\n Lock += lock_id\n Count = Lock + GROUP_SIZE_M\n DW = DW + lock_id * N + cols\n # Load data to SRAM\n x = tl.load(X + cols, mask=mask, other=0).to(tl.float32)\n dy = tl.load(DY + cols, mask=mask, other=0).to(tl.float32)\n w = tl.load(W + cols, mask=mask).to(tl.float32)\n rstd = tl.load(Rstd + row)\n # Compute dx\n xhat = x * rstd\n wdy = w * dy\n xhat = tl.where(mask, xhat, 0.)\n wdy = tl.where(mask, wdy, 0.)\n c1 = tl.sum(xhat * wdy, axis=0) / N\n dx = (wdy - (xhat * c1)) * rstd\n # Write dx\n tl.store(DX + cols, dx, mask=mask)\n # Accumulate partial sums for dw/db\n partial_dw = (dy * xhat).to(w.dtype)\n while tl.atomic_cas(Lock, 0, 1) == 1:\n pass\n count = tl.load(Count)\n # First store doesn't accumulate\n if count == 0:\n tl.atomic_xchg(Count, 1)\n else:\n partial_dw += tl.load(DW, mask=mask)\n tl.store(DW, partial_dw, mask=mask)\n # Release the lock\n tl.atomic_xchg(Lock, 0)\n\n\n# Triton kernel for accumulating partial weight gradients\n@triton.jit\ndef _rms_norm_bwd_dwdb(\n DW, # pointer to the partial sum of weights gradient\n FINAL_DW, # pointer to the weights gradient\n M, # GROUP_SIZE_M\n N, # number of columns\n BLOCK_SIZE_M: tl.constexpr,\n BLOCK_SIZE_N: tl.constexpr):\n # Map the program id to the elements of DW and DB it should compute.\n pid = tl.program_id(0)\n cols = pid * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n dw = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n # Iterate through the rows of DW and DB to sum the partial sums.\n for i in range(0, M, BLOCK_SIZE_M):\n rows = i + tl.arange(0, BLOCK_SIZE_M)\n mask = (rows[:, None] < M) & (cols[None, :] < N)\n offs = rows[:, None] * N + cols[None, :]\n dw += tl.load(DW + offs, mask=mask, other=0.)\n # Write the final sum to the output.\n sum_dw = tl.sum(dw, axis=0)\n tl.store(FINAL_DW + cols, sum_dw, mask=cols < N)\n\n\nclass RMSNorm(torch.autograd.Function):\n\n @staticmethod\n def forward(ctx, x, weight, eps):\n # allocate output\n y = torch.empty_like(x)\n # reshape input data into 2D tensor\n x_arg = x.reshape(-1, x.shape[-1])\n M, N = x_arg.shape\n rstd = torch.empty((M, ), dtype=torch.float32, device='cuda')\n # Less than 64KB per feature: enqueue fused kernel\n MAX_FUSED_SIZE = 65536 // x.element_size()\n BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))\n if N > BLOCK_SIZE:\n raise RuntimeError(\n \"This rms norm doesn't support feature dim >= 64KB.\")\n # heuristics for number of warps\n num_warps = min(max(BLOCK_SIZE // 256, 1), 8)\n # enqueue kernel\n _rms_norm_fwd_fused[(M, )](\n x_arg,\n y,\n weight,\n rstd,\n x_arg.stride(0),\n N,\n eps,\n BLOCK_SIZE=BLOCK_SIZE,\n num_warps=num_warps,\n )\n ctx.save_for_backward(x, weight, rstd)\n ctx.BLOCK_SIZE = BLOCK_SIZE\n ctx.num_warps = num_warps\n ctx.eps = eps\n return y\n\n @staticmethod\n def backward(ctx, dy):\n x, w, v = ctx.saved_tensors\n # heuristics for amount of parallel reduction stream for DW/DB\n N = w.shape[0]\n GROUP_SIZE_M = 64\n if N <= 8192:\n GROUP_SIZE_M = 96\n if N <= 4096:\n GROUP_SIZE_M = 128\n if N <= 1024:\n GROUP_SIZE_M = 256\n # allocate output\n locks = torch.zeros(2 * GROUP_SIZE_M, dtype=torch.int32, device='cuda')\n _dw = torch.empty((GROUP_SIZE_M, w.shape[0]),\n dtype=x.dtype,\n device=w.device)\n dw = torch.empty((w.shape[0], ), dtype=w.dtype, device=w.device)\n dx = torch.empty_like(dy)\n # enqueue kernel using forward pass heuristics\n # also compute partial sums for DW and DB\n x_arg = x.reshape(-1, x.shape[-1])\n M, N = x_arg.shape\n _rms_norm_bwd_dx_fused[(M, )](\n dx,\n dy,\n _dw,\n x,\n w,\n v,\n locks,\n x_arg.stride(0),\n N,\n ctx.eps,\n BLOCK_SIZE_N=ctx.BLOCK_SIZE,\n GROUP_SIZE_M=GROUP_SIZE_M,\n num_warps=ctx.num_warps)\n\n def grid(meta):\n return [triton.cdiv(N, meta['BLOCK_SIZE_N'])]\n\n # accumulate partial sums in separate kernel\n _rms_norm_bwd_dwdb[grid](\n _dw,\n dw,\n GROUP_SIZE_M,\n N,\n BLOCK_SIZE_M=32,\n BLOCK_SIZE_N=128,\n )\n return dx, dw, None\n\n\nrms_norm = RMSNorm.apply\n\n\ndef rms_norm_forward(self, hidden_states):\n if (hidden_states.device == torch.device('cpu')\n or self.weight.device == torch.device('cpu')):\n raise RuntimeError(\n 'Can not use triton kernels on cpu. Please set `USE_TRITON_KERNEL`'\n ' environment variable to 0 before training.')\n return rms_norm(hidden_states, self.weight, self.variance_epsilon)\n", - "description_1": "Use triton language to implement RMS normalization. The first kernel '_rms_norm_fwd_fused' normalizes and scales inputs using a given weight and computes the reciprocal of the standard deviation. The second kernel '_rms_norm_bwd_dx_fused' computes the gradient with respect to the input and accumulates partial weight gradients. The third kernel '_rms_norm_bwd_dwdb' sums the partial weight gradients. The 'RMSNorm' class utilizes these kernels in a custom PyTorch autograd function for both the forward and backward passes.", - "description_2": "Use triton language to perform forward and backward passes of RMS normalization by implementing custom kernels for normalization and gradient computation.", - "difficulty": 4 - }, - { - "code": "import triton\nimport torch\n\n# Kernel function with @triton.jit decorator\n@triton.jit\ndef kernel(x_ptr, x_size, **META):\n BLOCK_SIZE = META['BLOCK_SIZE']\n # Kernel implementation here\n\n# Function to call the kernel\ndef call_kernel(x_ptr, x_size):\n # Example of calling the kernel\n kernel[(1,)](x_ptr, x_size, BLOCK_SIZE=128)\n\n# Example usage\nx = torch.tensor([1, 2, 3, 4], dtype=torch.float32)\ncall_kernel(x.data_ptr(), x.size(0))\n", - "description_1": "Use triton language to define a kernel function 'kernel' with 2 parameters: x_ptr (pointer to data) and x_size (size of data). The kernel uses a meta-parameter BLOCK_SIZE. A function 'call_kernel' is used to invoke this kernel with specific arguments.", - "description_2": "Use triton language to define a kernel with parameters for data pointer and size, and a meta-parameter for block size. Implement a function to call this kernel.", - "difficulty": 2 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef matmul_248_kernel(a_ptr, b_ptr, c_ptr,\n scales_ptr, zeros_ptr, g_ptr,\n M, N, K, bits, maxq,\n stride_am, stride_ak,\n stride_bk, stride_bn,\n stride_cm, stride_cn,\n stride_scales, stride_zeros,\n BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,\n GROUP_SIZE_M: tl.constexpr):\n \"\"\"\n Compute the matrix multiplication C = A x B.\n A is of shape (M, K) float16\n B is of shape (K//8, N) int32\n C is of shape (M, N) float16\n scales is of shape (G, N) float16\n zeros is of shape (G, N) float16\n g_ptr is of shape (K) int32\n \"\"\"\n infearure_per_bits = 32 // bits\n\n pid = tl.program_id(axis=0)\n num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)\n num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)\n num_pid_k = tl.cdiv(K, BLOCK_SIZE_K)\n num_pid_in_group = GROUP_SIZE_M * num_pid_n\n group_id = pid // num_pid_in_group\n first_pid_m = group_id * GROUP_SIZE_M\n group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)\n pid_m = first_pid_m + (pid % group_size_m)\n pid_n = (pid % num_pid_in_group) // group_size_m\n\n offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n offs_k = tl.arange(0, BLOCK_SIZE_K)\n a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) # (BLOCK_SIZE_M, BLOCK_SIZE_K)\n a_mask = (offs_am[:, None] < M)\n # b_ptrs is set up such that it repeats elements along the K axis 8 times\n b_ptrs = b_ptr + ((offs_k[:, None] // infearure_per_bits) * stride_bk + offs_bn[None,\n :] * stride_bn) # (BLOCK_SIZE_K, BLOCK_SIZE_N)\n g_ptrs = g_ptr + offs_k\n # shifter is used to extract the N bits of each element in the 32-bit word from B\n scales_ptrs = scales_ptr + offs_bn[None, :]\n zeros_ptrs = zeros_ptr + (offs_bn[None, :] // infearure_per_bits)\n\n shifter = (offs_k % infearure_per_bits) * bits\n zeros_shifter = (offs_bn % infearure_per_bits) * bits\n accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n\n for k in range(0, num_pid_k):\n g_idx = tl.load(g_ptrs)\n\n # Fetch scales and zeros; these are per-outfeature and thus reused in the inner loop\n scales = tl.load(scales_ptrs + g_idx[:, None] * stride_scales) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)\n zeros = tl.load(zeros_ptrs + g_idx[:, None] * stride_zeros) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)\n\n zeros = (zeros >> zeros_shifter[None, :]) & maxq\n zeros = (zeros + 1)\n\n a = tl.load(a_ptrs, mask=a_mask, other=0.) # (BLOCK_SIZE_M, BLOCK_SIZE_K)\n b = tl.load(b_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated\n\n # Now we need to unpack b (which is N-bit values) into 32-bit values\n b = (b >> shifter[:, None]) & maxq # Extract the N-bit values\n b = (b - zeros) * scales # Scale and shift\n\n accumulator += tl.dot(a, b)\n a_ptrs += BLOCK_SIZE_K\n b_ptrs += (BLOCK_SIZE_K // infearure_per_bits) * stride_bk\n g_ptrs += BLOCK_SIZE_K\n\n c = accumulator.to(tl.float16)\n c_ptrs = c_ptr + stride_cm * offs_am[:, None] + stride_cn * offs_bn[None, :]\n c_mask = (offs_am[:, None] < M) & (offs_bn[None, :] < N)\n tl.store(c_ptrs, accumulator, mask=c_mask)\n\n@triton.jit\ndef trans_matmul_248_kernel(a_ptr, b_ptr, c_ptr,\n scales_ptr, zeros_ptr, g_ptr,\n M, N, K, bits, maxq,\n stride_am, stride_ak,\n stride_bk, stride_bn,\n stride_cm, stride_cn,\n stride_scales, stride_zeros,\n BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,\n GROUP_SIZE_M: tl.constexpr):\n \"\"\"\n Compute the matrix multiplication C = A x B.\n A is of shape (M, N) float16\n B is of shape (K//8, N) int32\n C is of shape (M, K) float16\n scales is of shape (G, N) float16\n zeros is of shape (G, N) float16\n g_ptr is of shape (K) int32\n \"\"\"\n infearure_per_bits = 32 // bits\n\n pid = tl.program_id(axis=0)\n num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)\n num_pid_k = tl.cdiv(K, BLOCK_SIZE_K)\n num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)\n num_pid_in_group = GROUP_SIZE_M * num_pid_k\n group_id = pid // num_pid_in_group\n first_pid_m = group_id * GROUP_SIZE_M\n group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)\n pid_m = first_pid_m + (pid % group_size_m)\n pid_k = (pid % num_pid_in_group) // group_size_m\n\n offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_bk = pid_k * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)\n offs_n = tl.arange(0, BLOCK_SIZE_N)\n a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_n[None, :] * stride_ak) # (BLOCK_SIZE_M, BLOCK_SIZE_N)\n a_mask = (offs_am[:, None] < M)\n # b_ptrs is set up such that it repeats elements along the K axis 8 times\n b_ptrs = b_ptr + ((offs_bk[:, None] // infearure_per_bits) * stride_bk + offs_n[None,\n :] * stride_bn) # (BLOCK_SIZE_K, BLOCK_SIZE_N)\n g_ptrs = g_ptr + offs_bk\n g_idx = tl.load(g_ptrs)\n\n # shifter is used to extract the N bits of each element in the 32-bit word from B\n scales_ptrs = scales_ptr + offs_n[None, :] + g_idx[:, None] * stride_scales\n zeros_ptrs = zeros_ptr + (offs_n[None, :] // infearure_per_bits) + g_idx[:, None] * stride_zeros\n\n shifter = (offs_bk % infearure_per_bits) * bits\n zeros_shifter = (offs_n % infearure_per_bits) * bits\n accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_K), dtype=tl.float32)\n\n for k in range(0, num_pid_n):\n # Fetch scales and zeros; these are per-outfeature and thus reused in the inner loop\n scales = tl.load(scales_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)\n zeros = tl.load(zeros_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)\n\n zeros = (zeros >> zeros_shifter[None, :]) & maxq\n zeros = (zeros + 1)\n\n a = tl.load(a_ptrs, mask=a_mask, other=0.) # (BLOCK_SIZE_M, BLOCK_SIZE_N)\n b = tl.load(b_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated\n\n # Now we need to unpack b (which is N-bit values) into 32-bit values\n b = (b >> shifter[:, None]) & maxq # Extract the N-bit values\n b = (b - zeros) * scales # Scale and shift\n b = tl.trans(b)\n\n accumulator += tl.dot(a, b)\n a_ptrs += BLOCK_SIZE_N\n b_ptrs += BLOCK_SIZE_N\n scales_ptrs += BLOCK_SIZE_N\n zeros_ptrs += (BLOCK_SIZE_N // infearure_per_bits)\n\n c = accumulator.to(tl.float16)\n c_ptrs = c_ptr + stride_cm * offs_am[:, None] + stride_cn * offs_bk[None, :]\n c_mask = (offs_am[:, None] < M) & (offs_bk[None, :] < K)\n tl.store(c_ptrs, accumulator, mask=c_mask)\n\ndef matmul248(input, qweight, scales, qzeros, g_idx, bits, maxq):\n output = torch.empty((input.shape[0], qweight.shape[1]), device='cuda', dtype=torch.float16)\n grid = lambda META: (\n triton.cdiv(input.shape[0], META['BLOCK_SIZE_M']) * triton.cdiv(qweight.shape[1], META['BLOCK_SIZE_N']),)\n matmul_248_kernel[grid](input, qweight, output,\n scales, qzeros, g_idx,\n input.shape[0], qweight.shape[1], input.shape[1], bits, maxq,\n input.stride(0), input.stride(1),\n qweight.stride(0), qweight.stride(1),\n output.stride(0), output.stride(1),\n scales.stride(0), qzeros.stride(0))\n return output\n\ndef transpose_matmul248(input, qweight, scales, qzeros, g_idx, bits, maxq):\n output_dim = (qweight.shape[0] * 32) // bits\n output = torch.empty((input.shape[0], output_dim), device='cuda', dtype=torch.float16)\n grid = lambda META: (\n triton.cdiv(input.shape[0], META['BLOCK_SIZE_M']) * triton.cdiv(output_dim, META['BLOCK_SIZE_K']),)\n trans_matmul_248_kernel[grid](input, qweight, output,\n scales, qzeros, g_idx,\n input.shape[0], qweight.shape[1], output_dim, bits, maxq,\n input.stride(0), input.stride(1),\n qweight.stride(0), qweight.stride(1),\n output.stride(0), output.stride(1),\n scales.stride(0), qzeros.stride(0))\n return output\n", - "description_1": "Use triton language to implement two matrix multiplication kernels: 'matmul_248_kernel' and 'trans_matmul_248_kernel'. The first kernel computes C = A x B where A is a float16 matrix of shape (M, K), B is an int32 matrix of shape (K//8, N), and C is a float16 matrix of shape (M, N). The second kernel computes C = A x B where A is a float16 matrix of shape (M, N), B is an int32 matrix of shape (K//8, N), and C is a float16 matrix of shape (M, K). Both kernels use scales and zeros for quantization and dequantization, and they handle bit-packing of B. The kernels are called by 'matmul248' and 'transpose_matmul248' functions respectively, which prepare the output tensor and grid configuration for the kernel execution.", - "description_2": "Use triton language to implement matrix multiplication kernels with quantization support, handling bit-packing and dequantization of input matrices.", - "difficulty": 4 - }, - { - "code": "import triton\nimport triton.language as tl\n\n# Kernel function with @triton.jit decorator\n@triton.jit\ndef kernel(x_ptr, x_size, **META):\n BLOCK_SIZE = META['BLOCK_SIZE']\n # Kernel implementation here\n\n# Function to call the kernel\ndef call_kernel(x_ptr, x_size):\n # Example of calling the kernel\n kernel[(1,)](x_ptr, x_size, BLOCK_SIZE=128)\n", - "description_1": "Use triton language to define a kernel function 'kernel' with two parameters: 'x_ptr' (pointer to data) and 'x_size' (size of the data). The kernel uses a meta-parameter 'BLOCK_SIZE' to control block size. A separate function 'call_kernel' is used to invoke this kernel with specific arguments and a block size of 128.", - "description_2": "Use triton language to create a kernel that processes data with a specified block size, and provide a function to execute this kernel.", - "difficulty": 3 - }, - { - "code": "import math\nimport torch\nimport triton\nimport triton.language as tl\n\n@triton.heuristics(\n {\n \"EVEN_M\": lambda args: args[\"seqlen_q\"] % args[\"BLOCK_M\"] == 0,\n \"EVEN_N\": lambda args: args[\"seqlen_k\"] % args[\"BLOCK_N\"] == 0,\n \"EVEN_HEADDIM\": lambda args: args[\"headdim\"] == args[\"BLOCK_HEADDIM\"],\n }\n)\n@triton.jit\ndef _fwd_kernel(\n Q, K, V, Bias, Out,\n Lse, TMP,\n softmax_scale,\n stride_qb, stride_qh, stride_qm,\n stride_kb, stride_kh, stride_kn,\n stride_vb, stride_vh, stride_vn,\n stride_bb, stride_bh, stride_bm,\n stride_ob, stride_oh, stride_om,\n nheads, seqlen_q, seqlen_k, seqlen_q_rounded, headdim,\n CACHE_KEY_SEQLEN_Q, CACHE_KEY_SEQLEN_K,\n BIAS_TYPE: tl.constexpr,\n IS_CAUSAL: tl.constexpr,\n BLOCK_HEADDIM: tl.constexpr,\n EVEN_M: tl.constexpr, EVEN_N: tl.constexpr, EVEN_HEADDIM: tl.constexpr,\n BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,\n):\n start_m = tl.program_id(0)\n off_hb = tl.program_id(1)\n off_b = off_hb // nheads\n off_h = off_hb % nheads\n offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)\n offs_n = tl.arange(0, BLOCK_N)\n offs_d = tl.arange(0, BLOCK_HEADDIM)\n q_ptrs = Q + off_b * stride_qb + off_h * stride_qh + (offs_m[:, None] * stride_qm + offs_d[None, :])\n k_ptrs = K + off_b * stride_kb + off_h * stride_kh + (offs_n[:, None] * stride_kn + offs_d[None, :])\n v_ptrs = V + off_b * stride_vb + off_h * stride_vh + (offs_n[:, None] * stride_vn + offs_d[None, :])\n if BIAS_TYPE == 'vector':\n b_ptrs = Bias + off_b * stride_bb + off_h * stride_bh + offs_n\n elif BIAS_TYPE == 'matrix':\n b_ptrs = Bias + off_b * stride_bb + off_h * stride_bh + (offs_m[:, None] * stride_bm + offs_n[None, :])\n t_ptrs = TMP + off_hb * seqlen_q_rounded + offs_m\n lse_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float(\"inf\")\n m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float(\"inf\")\n acc_o = tl.zeros([BLOCK_M, BLOCK_HEADDIM], dtype=tl.float32)\n if EVEN_M & EVEN_N:\n if EVEN_HEADDIM:\n q = tl.load(q_ptrs)\n else:\n q = tl.load(q_ptrs, mask=offs_d[None, :] < headdim, other=0.0)\n else:\n if EVEN_HEADDIM:\n q = tl.load(q_ptrs, mask=offs_m[:, None] < seqlen_q, other=0.0)\n else:\n q = tl.load(q_ptrs, mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim),\n other=0.0)\n end_n = seqlen_k if not IS_CAUSAL else tl.minimum((start_m + 1) * BLOCK_M, seqlen_k)\n for start_n in range(0, end_n, BLOCK_N):\n start_n = tl.multiple_of(start_n, BLOCK_N)\n if EVEN_N & EVEN_M:\n if EVEN_HEADDIM:\n k = tl.load(k_ptrs + start_n * stride_kn)\n else:\n k = tl.load(k_ptrs + start_n * stride_kn, mask=offs_d[None, :] < headdim, other=0.0)\n else:\n if EVEN_HEADDIM:\n k = tl.load(k_ptrs + start_n * stride_kn, mask=(start_n + offs_n)[:, None] < seqlen_k,\n other=0.0)\n else:\n k = tl.load(k_ptrs + start_n * stride_kn,\n mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim),\n other=0.0)\n qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)\n qk += tl.dot(q, k, trans_b=True)\n if not EVEN_N:\n qk += tl.where((start_n + offs_n)[None, :] < seqlen_k, 0, float(\"-inf\"))\n if IS_CAUSAL:\n qk += tl.where(offs_m[:, None] >= (start_n + offs_n)[None, :], 0, float(\"-inf\"))\n if BIAS_TYPE != 'none':\n if BIAS_TYPE == 'vector':\n if EVEN_N:\n bias = tl.load(b_ptrs + start_n).to(tl.float32)\n else:\n bias = tl.load(b_ptrs + start_n, mask=(start_n + offs_n) < seqlen_k, other=0.0).to(tl.float32)\n bias = bias[None, :]\n elif BIAS_TYPE == 'matrix':\n if EVEN_M & EVEN_N:\n bias = tl.load(b_ptrs + start_n).to(tl.float32)\n else:\n bias = tl.load(b_ptrs + start_n,\n mask=(offs_m[:, None] < seqlen_q)\n & ((start_n + offs_n)[None, :] < seqlen_k),\n other=0.0).to(tl.float32)\n qk = qk * softmax_scale + bias\n m_ij = tl.maximum(tl.max(qk, 1), lse_i)\n p = tl.exp(qk - m_ij[:, None])\n else:\n m_ij = tl.maximum(tl.max(qk, 1) * softmax_scale, lse_i)\n p = tl.exp(qk * softmax_scale - m_ij[:, None])\n l_ij = tl.sum(p, 1)\n acc_o_scale = tl.exp(m_i - m_ij)\n tl.store(t_ptrs, acc_o_scale)\n acc_o_scale = tl.load(t_ptrs)\n acc_o = acc_o * acc_o_scale[:, None]\n if EVEN_N & EVEN_M:\n if EVEN_HEADDIM:\n v = tl.load(v_ptrs + start_n * stride_vn)\n else:\n v = tl.load(v_ptrs + start_n * stride_vn, mask=offs_d[None, :] < headdim, other=0.0)\n else:\n if EVEN_HEADDIM:\n v = tl.load(v_ptrs + start_n * stride_vn, mask=(start_n + offs_n)[:, None] < seqlen_k,\n other=0.0)\n else:\n v = tl.load(v_ptrs + start_n * stride_vn,\n mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim),\n other=0.0)\n p = p.to(v.dtype)\n acc_o += tl.dot(p, v)\n m_i = m_ij\n l_i_new = tl.exp(lse_i - m_ij) + l_ij\n lse_i = m_ij + tl.log(l_i_new)\n o_scale = tl.exp(m_i - lse_i)\n tl.store(t_ptrs, o_scale)\n o_scale = tl.load(t_ptrs)\n acc_o = acc_o * o_scale[:, None]\n start_m = tl.program_id(0)\n offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)\n lse_ptrs = Lse + off_hb * seqlen_q_rounded + offs_m\n tl.store(lse_ptrs, lse_i)\n offs_d = tl.arange(0, BLOCK_HEADDIM)\n out_ptrs = Out + off_b * stride_ob + off_h * stride_oh + (offs_m[:, None] * stride_om + offs_d[None, :])\n if EVEN_M:\n if EVEN_HEADDIM:\n tl.store(out_ptrs, acc_o)\n else:\n tl.store(out_ptrs, acc_o, mask=offs_d[None, :] < headdim)\n else:\n if EVEN_HEADDIM:\n tl.store(out_ptrs, acc_o, mask=offs_m[:, None] < seqlen_q)\n else:\n tl.store(out_ptrs, acc_o,\n mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim))\n\ndef _flash_attn_forward(q, k, v, bias=None, causal=False, softmax_scale=None):\n batch, seqlen_q, nheads, d = q.shape\n _, seqlen_k, _, _ = k.shape\n assert k.shape == (batch, seqlen_k, nheads, d)\n assert v.shape == (batch, seqlen_k, nheads, d)\n assert d <= 128, 'FlashAttention only support head dimensions up to 128'\n assert q.dtype == k.dtype == v.dtype, 'All tensors must have the same type'\n assert q.dtype in [torch.float16, torch.bfloat16], 'Only support fp16 and bf16'\n assert q.is_cuda and k.is_cuda and v.is_cuda\n softmax_scale = softmax_scale or 1.0 / math.sqrt(d)\n has_bias = bias is not None\n bias_type = 'none'\n if has_bias:\n assert bias.dtype in [q.dtype, torch.float]\n assert bias.is_cuda\n assert bias.dim() == 4\n if bias.stride(-1) != 1:\n bias = bias.contiguous()\n if bias.shape[2:] == (1, seqlen_k):\n bias_type = 'vector'\n elif bias.shape[2:] == (seqlen_q, seqlen_k):\n bias_type = 'matrix'\n else:\n raise RuntimeError('Last 2 dimensions of bias must be (1, seqlen_k)'\n ' or (seqlen_q, seqlen_k)')\n bias = bias.expand(batch, nheads, seqlen_q, seqlen_k)\n bias_strides = (bias.stride(0), bias.stride(1), bias.stride(2)) if has_bias else (0, 0, 0)\n seqlen_q_rounded = math.ceil(seqlen_q / 128) * 128\n lse = torch.empty((batch, nheads, seqlen_q_rounded), device=q.device, dtype=torch.float32)\n tmp = torch.empty((batch, nheads, seqlen_q_rounded), device=q.device, dtype=torch.float32)\n o = torch.empty_like(q)\n BLOCK_HEADDIM = max(triton.next_power_of_2(d), 16)\n BLOCK = 128\n num_warps = 4 if d <= 64 else 8\n grid = lambda META: (triton.cdiv(seqlen_q, META[\"BLOCK_M\"]), batch * nheads)\n _fwd_kernel[grid](\n q, k, v, bias, o,\n lse, tmp,\n softmax_scale,\n q.stride(0), q.stride(2), q.stride(1),\n k.stride(0), k.stride(2), k.stride(1),\n v.stride(0), v.stride(2), v.stride(1),\n *bias_strides,\n o.stride(0), o.stride(2), o.stride(1),\n nheads, seqlen_q, seqlen_k, seqlen_q_rounded, d,\n seqlen_q // 32, seqlen_k // 32,\n bias_type, causal, BLOCK_HEADDIM,\n BLOCK_M=BLOCK, BLOCK_N=BLOCK,\n num_warps=num_warps,\n num_stages=1,\n )\n return o, lse, softmax_scale\n", - "description_1": "Use triton language to define a flash attention forward kernel (_fwd_kernel) with parameters for Q, K, V, Bias, Out, Lse, TMP, softmax_scale, various stride and size parameters, cache keys, bias type, causality, block head dimension, even flags, and block sizes. The function implements a forward pass for flash attention by computing QK products, applying biases, and accumulating values for the output. The function _flash_attn_forward is a wrapper that prepares the input parameters, asserts conditions, and launches the kernel with specific grid and block configurations for execution.", - "description_2": "Use triton language to implement a flash attention forward function that calculates QK products, applies biases, and accumulates results based on input queries (Q), keys (K), values (V), and optional biases, with parameters for block sizes, causality, and scaling, executed as a Triton kernel.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef matmul_248_kernel(a_ptr, b_ptr, c_ptr,\n scales_ptr, zeros_ptr, g_ptr,\n M, N, K, bits, maxq,\n stride_am, stride_ak,\n stride_bk, stride_bn,\n stride_cm, stride_cn,\n stride_scales, stride_zeros,\n BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,\n GROUP_SIZE_M: tl.constexpr):\n \"\"\"\n Compute the matrix multiplication C = A x B.\n A is of shape (M, K) float16\n B is of shape (K//8, N) int32\n C is of shape (M, N) float16\n scales is of shape (G, N) float16\n zeros is of shape (G, N) float16\n g_ptr is of shape (K) int32\n \"\"\"\n infearure_per_bits = 32 // bits\n\n pid = tl.program_id(axis=0)\n num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)\n num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)\n num_pid_k = tl.cdiv(K, BLOCK_SIZE_K)\n num_pid_in_group = GROUP_SIZE_M * num_pid_n\n group_id = pid // num_pid_in_group\n first_pid_m = group_id * GROUP_SIZE_M\n group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)\n pid_m = first_pid_m + (pid % group_size_m)\n pid_n = (pid % num_pid_in_group) // group_size_m\n\n offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n offs_k = tl.arange(0, BLOCK_SIZE_K)\n a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) # (BLOCK_SIZE_M, BLOCK_SIZE_K)\n a_mask = (offs_am[:, None] < M)\n # b_ptrs is set up such that it repeats elements along the K axis 8 times\n b_ptrs = b_ptr + ((offs_k[:, None] // infearure_per_bits) * stride_bk + offs_bn[None,\n :] * stride_bn) # (BLOCK_SIZE_K, BLOCK_SIZE_N)\n g_ptrs = g_ptr + offs_k\n # shifter is used to extract the N bits of each element in the 32-bit word from B\n scales_ptrs = scales_ptr + offs_bn[None, :]\n zeros_ptrs = zeros_ptr + (offs_bn[None, :] // infearure_per_bits)\n\n shifter = (offs_k % infearure_per_bits) * bits\n zeros_shifter = (offs_bn % infearure_per_bits) * bits\n accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n\n for k in range(0, num_pid_k):\n g_idx = tl.load(g_ptrs)\n\n # Fetch scales and zeros; these are per-outfeature and thus reused in the inner loop\n scales = tl.load(scales_ptrs + g_idx[:, None] * stride_scales) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)\n zeros = tl.load(zeros_ptrs + g_idx[:, None] * stride_zeros) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)\n\n zeros = (zeros >> zeros_shifter[None, :]) & maxq\n zeros = (zeros + 1)\n\n a = tl.load(a_ptrs, mask=a_mask, other=0.0) # (BLOCK_SIZE_M, BLOCK_SIZE_K)\n b = tl.load(b_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated\n\n # Now we need to unpack b (which is N-bit values) into 32-bit values\n b = (b >> shifter[:, None]) & maxq # Extract the N-bit values\n b = (b - zeros) * scales # Scale and shift\n # ! Convert to fp16\n b = b.to(tl.float16)\n a = a.to(tl.float16)\n\n accumulator += tl.dot(a, b)\n a_ptrs += BLOCK_SIZE_K\n b_ptrs += (BLOCK_SIZE_K // infearure_per_bits) * stride_bk\n g_ptrs += BLOCK_SIZE_K\n\n c = accumulator.to(tl.float16)\n c_ptrs = c_ptr + stride_cm * offs_am[:, None] + stride_cn * offs_bn[None, :]\n c_mask = (offs_am[:, None] < M) & (offs_bn[None, :] < N)\n tl.store(c_ptrs, c, mask=c_mask)\n\n\n@triton.jit\ndef trans_matmul_248_kernel(a_ptr, b_ptr, c_ptr,\n scales_ptr, zeros_ptr, g_ptr,\n M, N, K, bits, maxq,\n stride_am, stride_ak,\n stride_bk, stride_bn,\n stride_cm, stride_cn,\n stride_scales, stride_zeros,\n BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,\n GROUP_SIZE_M: tl.constexpr):\n \"\"\"\n Compute the matrix multiplication C = A x B.\n A is of shape (M, N) float16\n B is of shape (K//8, N) int32\n C is of shape (M, K) float16\n scales is of shape (G, N) float16\n zeros is of shape (G, N) float16\n g_ptr is of shape (K) int32\n \"\"\"\n infearure_per_bits = 32 // bits\n\n pid = tl.program_id(axis=0)\n num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)\n num_pid_k = tl.cdiv(K, BLOCK_SIZE_K)\n num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)\n num_pid_in_group = GROUP_SIZE_M * num_pid_k\n group_id = pid // num_pid_in_group\n first_pid_m = group_id * GROUP_SIZE_M\n group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)\n pid_m = first_pid_m + (pid % group_size_m)\n pid_k = (pid % num_pid_in_group) // group_size_m\n\n offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_bk = pid_k * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)\n offs_n = tl.arange(0, BLOCK_SIZE_N)\n a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_n[None, :] * stride_ak) # (BLOCK_SIZE_M, BLOCK_SIZE_N)\n a_mask = (offs_am[:, None] < M)\n # b_ptrs is set up such that it repeats elements along the K axis 8 times\n b_ptrs = b_ptr + ((offs_bk[:, None] // infearure_per_bits) * stride_bk + offs_n[None,\n :] * stride_bn) # (BLOCK_SIZE_K, BLOCK_SIZE_N)\n g_ptrs = g_ptr + offs_bk\n g_idx = tl.load(g_ptrs)\n\n # shifter is used to extract the N bits of each element in the 32-bit word from B\n scales_ptrs = scales_ptr + offs_n[None, :] + g_idx[:, None] * stride_scales\n zeros_ptrs = zeros_ptr + (offs_n[None, :] // infearure_per_bits) + g_idx[:, None] * stride_zeros\n\n shifter = (offs_bk % infearure_per_bits) * bits\n zeros_shifter = (offs_n % infearure_per_bits) * bits\n accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_K), dtype=tl.float32)\n\n for k in range(0, num_pid_n):\n # Fetch scales and zeros; these are per-outfeature and thus reused in the inner loop\n scales = tl.load(scales_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)\n zeros = tl.load(zeros_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)\n\n zeros = (zeros >> zeros_shifter[None, :]) & maxq\n zeros = (zeros + 1)\n\n a = tl.load(a_ptrs, mask=a_mask, other=0.) # (BLOCK_SIZE_M, BLOCK_SIZE_N)\n b = tl.load(b_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated\n\n # Now we need to unpack b (which is N-bit values) into 32-bit values\n b = (b >> shifter[:, None]) & maxq # Extract the N-bit values\n b = (b - zeros) * scales # Scale and shift\n b = tl.trans(b)\n # ! Convert to fp16\n b = b.to(tl.float16)\n a = a.to(tl.float16)\n\n accumulator += tl.dot(a, b)\n a_ptrs += BLOCK_SIZE_N\n b_ptrs += BLOCK_SIZE_N\n scales_ptrs += BLOCK_SIZE_N\n zeros_ptrs += (BLOCK_SIZE_N // infearure_per_bits)\n\n c = accumulator.to(tl.float16)\n c_ptrs = c_ptr + stride_cm * offs_am[:, None] + stride_cn * offs_bk[None, :]\n c_mask = (offs_am[:, None] < M) & (offs_bk[None, :] < K)\n tl.store(c_ptrs, c, mask=c_mask)\n\n\ndef triton_matmul(input, qweight, scales, qzeros, g_idx, bits, maxq):\n assert input.shape[-1] == qweight.shape[0] * 32 // bits\n outshape = input.shape[:-1] + (qweight.shape[1],)\n input = input.reshape(-1, input.shape[-1])\n output = torch.empty((input.shape[0], qweight.shape[1]), device=scales.device, dtype=torch.float16)\n grid = lambda META: (\n triton.cdiv(input.shape[0], META['BLOCK_SIZE_M']) * triton.cdiv(qweight.shape[1], META['BLOCK_SIZE_N']),)\n matmul_248_kernel[grid](input, qweight, output,\n scales, qzeros, g_idx,\n input.shape[0], qweight.shape[1], input.shape[1], bits, maxq,\n input.stride(0), input.stride(1),\n qweight.stride(0), qweight.stride(1),\n output.stride(0), output.stride(1),\n scales.stride(0), qzeros.stride(0))\n output = output.reshape(outshape)\n return output\n\n\ndef triton_matmul_transpose(input, qweight, scales, qzeros, g_idx, bits, maxq):\n assert input.shape[-1] == qweight.shape[1]\n out_dim = qweight.shape[0] * 32 // bits\n outshape = input.shape[:-1] + (out_dim,)\n input = input.reshape(-1, input.shape[-1])\n output_shape_mid = (input.shape[0], out_dim)\n output = torch.empty((output_shape_mid[0], output_shape_mid[1]), device=scales.device, dtype=torch.float16)\n grid = lambda META: (\n triton.cdiv(input.shape[0], META['BLOCK_SIZE_M']) * triton.cdiv(output_shape_mid[1], META['BLOCK_SIZE_K']),)\n trans_matmul_248_kernel[grid](input, qweight, output,\n scales, qzeros, g_idx,\n input.shape[0], qweight.shape[1], output_shape_mid[1], bits, maxq,\n input.stride(0), input.stride(1),\n qweight.stride(0), qweight.stride(1),\n output.stride(0), output.stride(1),\n scales.stride(0), qzeros.stride(0))\n output = output.reshape(outshape)\n return output\n", - "description_1": "Use triton language to implement matrix multiplication kernels for two variations: normal and transpose. Each kernel has a specific configuration and implements detailed processes like unpacking 32-bit values, fetching scales and zeros, shifting, scaling, converting data types, and using dot products to compute results. The matrix multiplication involves multiple parameters including input matrices, scales, zeros, configuration constants (like block size), and strides for each dimension. The kernels are called in Python functions which reshape inputs, set up grids for kernel execution, and store outputs.", - "description_2": "Use triton language to implement and invoke kernels for performing efficient matrix multiplication and transposed matrix multiplication with specific configurations and processing steps for precision control.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\nimport torch.distributed as dist\n\n@triton.jit\ndef _fwd_kernel(\n Q, K, V, sm_scale,\n L, O,\n MAX, DENOM,\n stride_q_bs, stride_q_head, stride_q_seqlen, stride_q_dim,\n stride_k_bs, stride_k_head, stride_k_seqlen, stride_k_dim,\n stride_v_bs, stride_v_head, stride_v_seqlen, stride_v_dim,\n BS, HEAD, SEQLEN,\n DIM: tl.constexpr,\n IS_CAUSAL: tl.constexpr,\n BLOCK_M: tl.constexpr,\n BLOCK_N: tl.constexpr,\n):\n start_m = tl.program_id(0)\n off_bs_head = tl.program_id(1)\n\n qkv_base_offset = off_bs_head * stride_q_head\n Q_block_ptr = tl.make_block_ptr(\n base=Q + qkv_base_offset,\n shape=(SEQLEN, DIM),\n strides=(stride_q_seqlen, stride_q_dim),\n offsets=(start_m * BLOCK_M, 0),\n block_shape=(BLOCK_M, DIM),\n order=(1, 0),\n )\n K_block_ptr = tl.make_block_ptr(\n base=K + qkv_base_offset,\n shape=(DIM, SEQLEN),\n strides=(stride_k_dim, stride_k_seqlen),\n offsets=(0, 0),\n block_shape=(DIM, BLOCK_N),\n order=(0, 1),\n )\n V_block_ptr = tl.make_block_ptr(\n base=V + qkv_base_offset,\n shape=(SEQLEN, DIM),\n strides=(stride_k_seqlen, stride_v_dim),\n offsets=(0, 0),\n block_shape=(BLOCK_N, DIM),\n order=(1, 0),\n )\n offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)\n offs_n = tl.arange(0, BLOCK_N)\n\n max_ptr = MAX + off_bs_head * SEQLEN + offs_m\n max = tl.load(max_ptr)\n denom_ptr = DENOM + off_bs_head * SEQLEN + offs_m\n denom = tl.load(denom_ptr)\n\n O_block_ptr = tl.make_block_ptr(\n base=O + qkv_base_offset,\n shape=(SEQLEN, DIM),\n strides=(stride_q_seqlen, stride_q_dim),\n offsets=(start_m * BLOCK_M, 0),\n block_shape=(BLOCK_M, DIM),\n order=(1, 0),\n )\n out_buffer = tl.load(O_block_ptr)\n out_buffer = out_buffer.to(tl.float32)\n\n qk_scale = sm_scale * 1.44269504\n q = tl.load(Q_block_ptr)\n q = (q * qk_scale).to(tl.float16)\n lo = 0\n hi = (start_m + 1) * BLOCK_M if IS_CAUSAL else SEQLEN\n\n for start_n in range(lo, hi, BLOCK_N):\n k = tl.load(K_block_ptr)\n v = tl.load(V_block_ptr)\n\n qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)\n if IS_CAUSAL:\n qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float(\"-inf\"))\n qk += tl.dot(q, k)\n\n max_new = tl.maximum(max, tl.max(qk, 1))\n alpha = tl.math.exp2(max - max_new)\n nume = tl.math.exp2(qk - max_new[:, None])\n out_scale = denom * 0 + alpha\n out_buffer *= out_scale[:, None]\n out_buffer += tl.dot(nume.to(tl.float16), v)\n denom = denom * alpha + tl.sum(nume, 1)\n max = max_new\n K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N))\n V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0))\n\n tl.store(max_ptr, max)\n tl.store(denom_ptr, denom)\n tl.store(O_block_ptr, out_buffer.to(tl.float16))\n\n@triton.jit\ndef _rescale(\n L, O,\n DENOM,\n stride_o_bs, stride_o_head, stride_o_seqlen, stride_o_dim,\n BS, HEAD, SEQLEN,\n DIM: tl.constexpr,\n BLOCK_M: tl.constexpr,\n BLOCK_N: tl.constexpr,\n):\n start_m = tl.program_id(0)\n off_bs_head = tl.program_id(1)\n\n qkv_base_offset = off_bs_head * stride_o_head\n offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)\n offs_n = tl.arange(0, BLOCK_N)\n\n denom_ptr = DENOM + off_bs_head * SEQLEN + offs_m\n denom = tl.load(denom_ptr)\n\n O_block_ptr = tl.make_block_ptr(\n base=O + qkv_base_offset,\n shape=(SEQLEN, DIM),\n strides=(stride_o_seqlen, stride_o_dim),\n offsets=(start_m * BLOCK_M, 0),\n block_shape=(BLOCK_M, DIM),\n order=(1, 0)\n )\n out_buffer = tl.load(O_block_ptr)\n out_buffer = out_buffer.to(tl.float16)\n\n out_buffer = out_buffer / denom[:, None]\n tl.store(O_block_ptr, out_buffer.to(tl.float16))\n\ndef ring_attention(q, k, v, causal=True, sm_scale=1):\n rank = dist.get_rank()\n world_size = dist.get_world_size()\n\n bs, head, seqlen, dim = q.shape\n Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]\n assert Lq == Lk and Lk == Lv\n assert Lk in {16, 32, 64, 128}\n\n buffer_k, buffer_v = prepare_kv_double_buffer(k, v)\n\n max = torch.full((bs, head, seqlen), fill_value=-float(\"inf\"), device=q.device, dtype=torch.float32).contiguous()\n denom = torch.zeros((bs, head, seqlen), device=q.device, dtype=torch.float32).contiguous()\n local_o = torch.empty_like(q)\n\n BLOCK_M = 128\n BLOCK_N = 64\n\n group_size = triton.cdiv(seqlen, BLOCK_M)\n grid = (group_size, q.shape[0] * q.shape[1], 1)\n\n L = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)\n num_warps = 4 if Lk <= 64 else 8\n for time_step in range(world_size):\n buf_id1 = time_step % 2\n buf_id2 = (time_step - 1) % 2\n\n local_q = q\n local_k = buffer_k[buf_id1]\n local_v = buffer_v[buf_id1]\n\n _fwd_kernel[grid](\n local_q, local_k, local_v, sm_scale,\n L, local_o,\n max, denom,\n local_q.stride(0), local_q.stride(1), local_q.stride(2), local_q.stride(3),\n local_k.stride(0), local_k.stride(1), local_k.stride(2), local_k.stride(3),\n local_v.stride(0), local_v.stride(1), local_v.stride(2), local_v.stride(3),\n bs, head, seqlen,\n DIM=Lk,\n IS_CAUSAL=causal,\n BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, \n num_warps=num_warps,\n num_stages=4)\n\n torch.cuda.synchronize()\n step_kv_send_recv(buffer_k[buf_id1], buffer_k[buf_id2], buffer_v[buf_id1], buffer_v[buf_id2])\n\n _rescale[grid](\n L, local_o,\n denom,\n local_o.stride(0), local_o.stride(1), local_o.stride(2), local_o.stride(3),\n bs, head, seqlen,\n DIM=Lk,\n BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N)\n torch.cuda.synchronize()\n\n res_o = [torch.empty_like(q, dtype=local_o.dtype) for _ in range(world_size)]\n dist.all_gather(res_o, local_o)\n res_o = torch.cat(res_o, dim=-2)\n\n return res_o\n", - "description_1": "Use triton language to implement a forward kernel (_fwd_kernel) and a rescale kernel (_rescale) for attention mechanism. The forward kernel takes 22 parameters: Q, K, V, sm_scale, L, O, MAX, DENOM, stride_q_bs, stride_q_head, stride_q_seqlen, stride_q_dim, stride_k_bs, stride_k_head, stride_k_seqlen, stride_k_dim, stride_v_bs, stride_v_head, stride_v_seqlen, stride_v_dim, BS, HEAD, SEQLEN, and 5 constexpr parameters: DIM, IS_CAUSAL, BLOCK_M, BLOCK_N. It computes the scaled dot-product attention and updates the output tensor O. The rescale kernel takes 11 parameters: L, O, DENOM, stride_o_bs, stride_o_head, stride_o_seqlen, stride_o_dim, BS, HEAD, SEQLEN, and 3 constexpr parameters: DIM, BLOCK_M, BLOCK_N. It rescales the output tensor O by dividing it by the denominator tensor DENOM.", - "description_2": "Use triton language to implement a ring attention function (ring_attention) that orchestrates the execution of the forward and rescale kernels. The function takes 4 parameters: q, k, v, causal, and sm_scale. It prepares double buffers for k and v, initializes max and denom tensors, and iteratively calls the forward kernel to compute attention scores. After processing all time steps, it calls the rescale kernel to finalize the output tensor. The function returns the gathered output tensor across all distributed ranks.", - "difficulty": 4 - }, - { - "code": "import triton\nimport triton.language as tl\nimport torch\nfrom torch._inductor.triton_heuristics import reduction\nfrom torch._inductor import triton_helpers\n\n# Kernel: Reduction operation on a tensor using Triton\n@reduction(size_hints=[512, 64], reduction_hint=ReductionHint.INNER, filename=__file__, meta={'signature': {0: ('pointer', 'float32')}, 'device': 0, 'device_type': 'cuda', 'constants': {}, 'mutated_arg_names': [], 'autotune_hints': set(), 'kernel_name': 'triton_sum_kernel'})\n@triton.jit\ndef triton_sum_kernel(out_ptr, numel, XBLOCK: tl.constexpr, RBLOCK: tl.constexpr):\n pid = tl.program_id(0)\n block_start = pid * RBLOCK\n offset = block_start + tl.arange(0, RBLOCK)\n mask = offset < numel\n ptrs = out_ptr + offset\n x = tl.load(ptrs, mask=mask, other=0)\n \n acc = tl.zeros([1], dtype=tl.float32)\n acc += x\n \n acc = tl.sum(acc, 0)\n if mask[0]:\n tl.store(out_ptr + block_start, acc)\n\n# Function: Execute the Triton kernel\ndef sum_triton(x):\n assert x.is_cuda, \"Input must be a CUDA tensor\"\n numel = x.numel()\n output = torch.empty_like(x)\n triton_sum_kernel[(numel,)](output, numel, XBLOCK=512, RBLOCK=64)\n return output\n\n# Example Usage\nx = torch.randn(1024, device='cuda', dtype=torch.float32)\nsum_result = sum_triton(x)\n", - "description_1": "Use triton language to define a kernel 'triton_sum_kernel' that performs a reduction operation (sum) on a tensor. The kernel is executed with a function 'sum_triton' which takes a CUDA tensor, calculates the sum of its elements using the Triton kernel, and returns the result.", - "description_2": "Use triton language to define a kernel for reduction, and execute it to compute the sum of a CUDA tensor's elements.", - "difficulty": 3 - }, - { - "code": "import triton\nimport triton.language as tl\nfrom torch._inductor.triton_heuristics import foreach\n\nclass ForeachKernel:\n # Other class members and functions\n\n def jit_line(self):\n # Return a string that serves as a decorator line for the kernel\n can_use_32bit = all(k.index_dtype == \"tl.int32\" for k in self.sub_kernels)\n index_dtype = \"tl.int32\" if can_use_32bit else \"tl.int64\"\n _, _, signature = self.args.python_argdefs()\n triton_meta = {\n \"signature\": signature_to_meta(signature, size_dtype=can_use_32bit),\n \"device\": V.graph.scheduler.current_device.index,\n \"device_type\": V.graph.scheduler.current_device.type,\n \"constants\": {},\n }\n triton_meta[\"configs\"] = [config_of(signature)]\n return (\n f\"@foreach(num_warps={self.num_warps}, meta={triton_meta!r})\\n\"\n + \"@triton.jit\"\n )\n\n def codegen_kernel(self, name=None):\n # Generate the kernel code\n code = IndentedBuffer()\n code.splice(\n \"\"\"\n import triton\n import triton.language as tl\n from torch._inductor.triton_heuristics import foreach\n from torch._inductor.utils import instance_descriptor\n from torch._inductor import triton_helpers\n \"\"\"\n )\n argdefs, _, _ = self.args.python_argdefs()\n code.writeline(self.jit_line())\n code.writeline(f\"def {name or 'KERNEL_NAME'}({', '.join(argdefs)}):\")\n \n with code.indent():\n code.splice(\"xpid = tl.program_id(0)\")\n if self.blocking_2d:\n code.splice(\"ypid = tl.program_id(1)\")\n code.splice(f\"XBLOCK: tl.constexpr = {self.block_size_2d}\")\n code.splice(f\"YBLOCK: tl.constexpr = {self.block_size_2d}\")\n else:\n code.splice(f\"XBLOCK: tl.constexpr = {self.block_size_1d}\")\n\n for sub_kernel in self.sub_kernels:\n assert len(sub_kernel.numels) <= 3\n numel_ind = 0 if not self.blocking_2d else 1\n self.codegen_pid_range(code, int(sub_kernel.numels[numel_ind]))\n with code.indent():\n if self.blocking_2d:\n code.splice(f\"ynumel = {sub_kernel.numels[0]}\")\n code.splice(f\"xnumel = {sub_kernel.numels[1]}\")\n else:\n code.splice(f\"xnumel = {sub_kernel.numels[0]}\")\n\n sub_kernel.codegen_body()\n code.splice(sub_kernel.body)\n\n code.splice(\"else:\")\n with code.indent():\n code.splice(\"pass\")\n\n return code.getvalue()\n\n def call_kernel(self, code, name: str):\n # Call the generated kernel\n _, call_args, _ = self.args.python_argdefs()\n for i in range(len(call_args)):\n if V.graph.is_unspec_arg(call_args[i]):\n call_args[i] = call_args[i] + \".item()\"\n if V.graph.cpp_wrapper:\n V.graph.wrapper_code.generate_kernel_call(\n name, call_args, device_index=V.graph.scheduler.current_device.index\n )\n else:\n call_args_str = \", \".join(call_args)\n stream_name = code.write_get_cuda_stream(\n V.graph.scheduler.current_device.index\n )\n code.writeline(\n f\"{name}.run({call_args_str}, grid=({self.grid()}), stream={stream_name})\"\n )\n", - "description_1": "Use triton language to define a kernel within the ForeachKernel class that supports the generation and execution of a Triton kernel with specific configurations such as block sizes, warp counts, and indexing dtype, allowing dynamic shape and type management.", - "description_2": "Use triton language to create and call a kernel with configurations for block size and device type, supporting both 1D and 2D blocking with dynamic execution contexts.", - "difficulty": 3 - }, - { - "code": "import triton\nimport triton.language as tl\n\n# Dummy code to indicate a Triton kernel with @triton.jit\n@triton.jit\ndef example_kernel(input1, input2, output):\n # Triton kernel code\n pass\n\n# Call to Triton kernel\ndef call_example_kernel():\n input1 = ... # Initialize input\n input2 = ... # Initialize input\n output = ... # Initialize output\n example_kernel[(1,)](input1, input2, output)\n", - "description_1": "Use triton language to define a kernel `example_kernel` with three parameters: input1, input2, and output. The kernel performs computations on these inputs and writes the result to the output. Then, define a function `call_example_kernel` to set up inputs and outputs and call the `example_kernel` with a specified grid.", - "description_2": "Use triton language to define a basic kernel and a function to call it, processing two input tensors and storing the result in an output tensor.", - "difficulty": 1 - }, - { - "code": "import triton\nimport triton.language as tl\n\n@triton.jit\ndef promote_to_tensor(x):\n # Addition promotes to tensor for us\n return x + tl.zeros((1,), tl.int1)\n\n@triton.jit\ndef is_floating(x):\n return promote_to_tensor(x).dtype.is_floating()\n\n@triton.jit\ndef _prod_accumulate(a, b):\n return a * b\n\n@triton.jit\ndef prod(input, axis):\n return tl.reduce(input, axis, _prod_accumulate)\n\n@triton.jit\ndef minimum(a, b):\n mask = a < b\n if is_floating(a):\n mask |= a != a\n return tl.where(mask, a, b)\n\n@triton.jit\ndef maximum(a, b):\n mask = a > b\n if is_floating(a):\n mask |= a != a\n return tl.where(mask, a, b)\n\n@triton.jit\ndef min2(a, dim):\n return tl.reduce(a, dim, minimum)\n\n@triton.jit\ndef max2(a, dim):\n return tl.reduce(a, dim, maximum)\n\n@triton.jit\ndef minimum_with_index(a_value, a_index, b_value, b_index):\n mask = a_value < b_value\n equal = a_value == b_value\n if is_floating(a_value):\n a_isnan = a_value != a_value\n b_isnan = b_value != b_value\n mask |= a_isnan and not b_isnan\n # Consider NaNs as equal\n equal |= a_isnan and b_isnan\n\n # Prefer lowest index if values are equal\n mask |= equal & (a_index < b_index)\n return tl.where(mask, a_value, b_value), tl.where(mask, a_index, b_index)\n\n@triton.jit\ndef maximum_with_index(a_value, a_index, b_value, b_index):\n mask = a_value > b_value\n equal = a_value == b_value\n if is_floating(a_value):\n a_isnan = a_value != a_value\n b_isnan = b_value != b_value\n mask |= a_isnan and not b_isnan\n # Consider NaNs as equal\n equal |= a_isnan and b_isnan\n\n # Prefer lowest index if values are equal\n mask |= equal & (a_index < b_index)\n return tl.where(mask, a_value, b_value), tl.where(mask, a_index, b_index)\n\n@triton.jit\ndef min_with_index(value, index, dim):\n return tl.reduce((value, index), dim, minimum_with_index)\n\n@triton.jit\ndef max_with_index(value, index, dim):\n return tl.reduce((value, index), dim, maximum_with_index)\n\n@triton.jit\ndef welford_reduce(value, mean, m2, weight):\n delta = value - mean\n new_weight = weight + 1\n new_mean = mean + delta / new_weight\n return (\n new_mean,\n m2 + delta * (value - new_mean),\n new_weight,\n )\n\n@triton.jit\ndef welford_combine(mean_1, m2_1, weight_1, mean_2, m2_2, weight_2):\n delta = mean_2 - mean_1\n new_weight = weight_1 + weight_2\n w2_over_w = tl.where(new_weight == 0.0, 0.0, weight_2 / new_weight)\n return (\n mean_1 + delta * w2_over_w,\n m2_1 + m2_2 + delta * delta * weight_1 * w2_over_w,\n new_weight,\n )\n\n@triton.jit\ndef welford(mean, m2, weight, dim):\n return tl.reduce((mean, m2, weight), dim, welford_combine)\n\n@triton.jit\ndef device_assert_then(cond, msg, r):\n tl.device_assert(cond, msg)\n return r\n\n@triton.jit\ndef randint64(seed, offset, low, high):\n r0, r1, r2, r3 = tl.randint4x(seed, offset)\n r0 = r0.to(tl.uint64)\n r1 = r1.to(tl.uint64)\n result = r0 | (r1 << 32)\n size = high - low\n result = result % size.to(tl.uint64)\n result = result.to(tl.int64) + low\n return result\n\n@triton.jit\ndef _any_combine(a, b):\n return a | b\n\n@triton.jit\ndef any(a, dim):\n return tl.reduce(a, dim, _any_combine)\n\n@triton.jit\ndef bucketize_binary_search(\n values, # 1D tensor\n offsets_ptr,\n indexing_dtype,\n right, # bool: if true, use intervals closed on the left; see [Note: Inductor bucketize op]\n OFFSETS_SIZE: int,\n BLOCK_SHAPE, # tuple/list of block shape\n):\n \"\"\"\n See [Note: Inductor bucketize op]\n \"\"\"\n low = tl.zeros(BLOCK_SHAPE, dtype=indexing_dtype)\n high = tl.full(BLOCK_SHAPE, OFFSETS_SIZE, dtype=indexing_dtype)\n\n full_range = OFFSETS_SIZE + 1\n while full_range > 1:\n mid = (high + low) // 2\n mask = mid < OFFSETS_SIZE\n bucket_upper_bound = tl.load(offsets_ptr + mid, mask=mask)\n if right:\n is_above = values >= bucket_upper_bound\n else:\n is_above = values > bucket_upper_bound\n\n low = tl.where(is_above & mask, mid + 1, low)\n high = tl.where(is_above, high, mid)\n\n full_range = (full_range + 1) // 2\n\n return low\n", - "description_1": "Use triton language to implement various kernels for tensor operations, including tensor promotion, floating point check, product accumulation, minimum and maximum calculations, reduction with indices, Welford's method for variance calculation, device assertions, random integer generation, logical 'any' reduction, and bucketization using binary search. The kernels use triton's primitives such as reduction, type promotion, and logical operations. Parameters for each function typically involve tensors for input data, dimensions for operations, and auxiliary data for indexing and type information.", - "description_2": "Use triton language to perform tensor operations such as reductions, element-wise calculations, and statistical methods. Implement kernels that leverage triton primitives to conduct efficient data manipulation on tensors.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\nfrom typing import Optional, Tuple\n\n# Triton Kernel 1: _sampled_addmm_kernel\n@triton.jit\ndef _sampled_addmm_kernel(\n alpha,\n beta,\n IS_BETA_ZERO: tl.constexpr,\n BLOCKSIZE_ROW: tl.constexpr,\n BLOCKSIZE_COL: tl.constexpr,\n k,\n TILE_K: tl.constexpr,\n values_ptr,\n values_batch_stride,\n values_nnz_stride,\n values_row_block_stride,\n values_col_block_stride,\n crow_indices_ptr,\n crow_indices_batch_stride,\n crow_indices_stride,\n col_indices_ptr,\n col_indices_batch_stride,\n col_indices_stride,\n mat1_ptr,\n mat1_batch_stride,\n mat1_tiled_row_stride,\n mat1_tiled_col_stride,\n mat1_row_block_stride,\n mat1_col_block_stride,\n mat2_ptr,\n mat2_batch_stride,\n mat2_tiled_row_stride,\n mat2_tiled_col_stride,\n mat2_row_block_stride,\n mat2_col_block_stride,\n acc_dtype: tl.constexpr,\n allow_tf32: tl.constexpr,\n):\n batch_pid = tl.program_id(axis=1)\n row_block_pid = tl.program_id(axis=0)\n\n crow_indices_offset_ptr = (\n crow_indices_ptr\n + crow_indices_batch_stride * batch_pid\n + crow_indices_stride * row_block_pid\n )\n nnz_offset = tl.load(crow_indices_offset_ptr)\n nnz_offset_next = tl.load(crow_indices_offset_ptr + crow_indices_stride)\n\n row_nnz = nnz_offset_next - nnz_offset\n if row_nnz == 0:\n return\n\n row_block_arange = tl.arange(0, BLOCKSIZE_ROW)\n col_block_arange = tl.arange(0, BLOCKSIZE_COL)\n\n values_block_ptrs = (\n values_ptr\n + values_batch_stride * batch_pid\n + values_nnz_stride * nnz_offset\n + values_row_block_stride * row_block_arange[:, None]\n + values_col_block_stride * col_block_arange[None, :]\n )\n\n col_index_nnz_ptr = (\n col_indices_ptr\n + col_indices_batch_stride * batch_pid\n + col_indices_stride * nnz_offset\n )\n\n mat1_block_ptrs = (\n mat1_ptr\n + mat1_batch_stride * batch_pid\n + mat1_tiled_row_stride * row_block_pid\n + mat1_row_block_stride * row_block_arange[:, None]\n )\n\n mat2_block_ptrs = (\n mat2_ptr\n + mat2_batch_stride * batch_pid\n + mat2_col_block_stride * col_block_arange[None, :]\n )\n\n k_tile_arange = tl.arange(0, TILE_K)\n for _ in range(row_nnz):\n acc_block = tl.zeros((BLOCKSIZE_ROW, BLOCKSIZE_COL), dtype=acc_dtype)\n\n col_block = tl.load(col_index_nnz_ptr)\n\n for k_tile in range(0, k, TILE_K):\n k_offsets = k_tile + k_tile_arange\n mask_k = k_offsets < k\n\n mat1_block = tl.load(\n mat1_block_ptrs\n + mat1_col_block_stride * k_offsets[None, :],\n mask=mask_k[None, :], other=0.0\n )\n\n mat2_block = tl.load(\n mat2_block_ptrs\n + mat2_tiled_col_stride * col_block\n + mat2_row_block_stride * k_offsets[:, None],\n mask=mask_k[:, None], other=0.0\n )\n\n acc_block += tl.dot(mat1_block, mat2_block, allow_tf32=allow_tf32)\n\n if IS_BETA_ZERO:\n acc_block *= alpha\n else:\n acc_block = alpha * acc_block + beta * tl.load(values_block_ptrs)\n\n tl.store(values_block_ptrs, acc_block.to(values_ptr.dtype.element_ty))\n\n values_block_ptrs += values_nnz_stride\n col_index_nnz_ptr += col_indices_stride\n\n# Triton Kernel 2: _bsr_strided_dense_rowspace_kernel\n@triton.jit\ndef _bsr_strided_dense_rowspace_kernel(\n BLOCKSIZE_ROW: tl.constexpr,\n BLOCKSIZE_COL: tl.constexpr,\n values_ptr,\n values_batch_stride,\n values_nnz_stride,\n values_row_block_stride,\n values_col_block_stride,\n crow_indices_ptr,\n crow_indices_batch_stride,\n crow_indices_stride,\n col_indices_ptr,\n col_indices_batch_stride,\n col_indices_stride,\n dense_ptr,\n dense_batch_stride,\n dense_tiled_row_stride,\n dense_tiled_col_stride,\n dense_row_block_stride,\n dense_col_block_stride,\n output_ptr,\n output_batch_stride,\n output_tiled_row_stride,\n output_tiled_col_stride,\n output_row_block_stride,\n output_col_block_stride,\n acc_dtype: tl.constexpr,\n allow_tf32: tl.constexpr,\n GROUP_SIZE_ROW: tl.constexpr,\n):\n batch_pid = tl.program_id(axis=2)\n row_block_pid = tl.program_id(axis=0)\n col_block_pid = tl.program_id(axis=1)\n n_block_rows = tl.num_programs(axis=0)\n n_block_cols = tl.num_programs(axis=1)\n\n row_block_pid, col_block_pid = tl.swizzle2d(\n row_block_pid, col_block_pid, n_block_rows, n_block_cols, GROUP_SIZE_ROW\n )\n\n crow_indices_offset_ptr = (\n crow_indices_ptr\n + crow_indices_batch_stride * batch_pid\n + crow_indices_stride * row_block_pid\n )\n nnz_offset = tl.load(crow_indices_offset_ptr)\n nnz_offset_next = tl.load(crow_indices_offset_ptr + crow_indices_stride)\n\n row_nnz = nnz_offset_next - nnz_offset\n if row_nnz == 0:\n return\n\n row_block_arange = tl.arange(0, BLOCKSIZE_ROW)\n col_block_arange = tl.arange(0, BLOCKSIZE_COL)\n\n values_block_ptrs = (\n values_ptr\n + values_batch_stride * batch_pid\n + values_nnz_stride * nnz_offset\n + values_row_block_stride * row_block_arange[:, None]\n + values_col_block_stride * col_block_arange[None, :]\n )\n\n dense_block_ptrs = (\n dense_ptr\n + dense_batch_stride * batch_pid\n + dense_tiled_col_stride * col_block_pid\n + dense_row_block_stride * col_block_arange[:, None]\n + dense_col_block_stride * row_block_arange[None, :]\n )\n\n output_ptrs = (\n output_ptr\n + output_batch_stride * batch_pid\n + output_tiled_row_stride * row_block_pid\n + output_tiled_col_stride * col_block_pid\n + output_row_block_stride * row_block_arange[:, None]\n + output_col_block_stride * row_block_arange[None, :]\n )\n\n col_index_nnz_ptr = (\n col_indices_ptr\n + col_indices_batch_stride * batch_pid\n + col_indices_stride * nnz_offset\n )\n\n output_acc_block = tl.zeros((BLOCKSIZE_ROW, BLOCKSIZE_ROW), dtype=acc_dtype)\n for _ in range(row_nnz):\n values_block = tl.load(values_block_ptrs)\n dense_row_idx = tl.load(col_index_nnz_ptr)\n dense_block = tl.load(dense_block_ptrs + dense_tiled_row_stride * dense_row_idx)\n\n output_acc_block += tl.dot(values_block, dense_block, allow_tf32=allow_tf32)\n\n values_block_ptrs += values_nnz_stride\n col_index_nnz_ptr += col_indices_stride\n\n tl.store(output_ptrs, output_acc_block.to(output_ptr.dtype.element_ty))\n\n# Triton Kernel 3: _bsr_softmax_kernel\n@triton.jit\ndef _bsr_softmax_kernel(\n crow_indices_ptr,\n crow_indices_batch_stride,\n crow_indices_stride,\n values_ptr,\n values_batch_stride,\n values_row_block_stride,\n values_nnz_col_block_stride,\n row_block, col_block,\n MAX_ROW_NNZ: tl.constexpr,\n TILE: tl.constexpr\n):\n batch_pid = tl.program_id(axis=2)\n row_block_offset_pid = tl.program_id(axis=1)\n row_block_pid = tl.program_id(axis=0)\n\n crow_indices_offset_ptr = (\n crow_indices_ptr\n + crow_indices_batch_stride * batch_pid\n + crow_indices_stride * row_block_pid\n )\n nnz_offset = tl.load(crow_indices_offset_ptr)\n nnz_offset_next = tl.load(crow_indices_offset_ptr + crow_indices_stride)\n\n row_nnz = nnz_offset_next - nnz_offset\n if row_nnz == 0:\n return\n\n row_arange = tl.arange(0, TILE)\n mask = row_arange < row_nnz * col_block\n\n curr_row_values_ptrs = (\n values_ptr\n + values_batch_stride * batch_pid\n + values_row_block_stride * row_block_offset_pid\n + nnz_offset * col_block\n )\n\n row_tile = tl.load(curr_row_values_ptrs + row_arange, mask=mask, other=-float('inf')).to(tl.float32)\n max_row_value = tl.max(row_tile, axis=0)\n for _ in range(TILE, MAX_ROW_NNZ, TILE):\n row_arange += TILE\n mask = row_arange < row_nnz * col_block\n row_tile = tl.load(curr_row_values_ptrs + row_arange, mask=mask, other=-float('inf')).to(tl.float32)\n curr_max_row_value = tl.max(row_tile, axis=0)\n max_row_value = tl.where(max_row_value > curr_max_row_value, max_row_value, curr_max_row_value)\n\n num = tl.exp(row_tile - max_row_value)\n denom = tl.sum(num, axis=0)\n for _ in range(TILE, MAX_ROW_NNZ, TILE):\n row_arange -= TILE\n mask = row_arange < row_nnz * col_block\n row_tile = tl.load(curr_row_values_ptrs + row_arange, mask=mask, other=-float('inf')).to(tl.float32)\n num = tl.exp(row_tile - max_row_value)\n denom += tl.sum(num, axis=0)\n\n tl.store(curr_row_values_ptrs + row_arange, (num / denom).to(values_ptr.dtype.element_ty), mask=mask)\n for _ in range(TILE, MAX_ROW_NNZ, TILE):\n row_arange += TILE\n mask = row_arange < row_nnz * col_block\n row_tile = tl.load(curr_row_values_ptrs + row_arange, mask=mask, other=-float('inf')).to(tl.float32)\n num = tl.exp(row_tile - max_row_value)\n tl.store(curr_row_values_ptrs + row_arange, (num / denom).to(values_ptr.dtype.element_ty), mask=mask)\n\n\n# Function to run the _sampled_addmm_kernel\ndef _run_sampled_addmm_kernel(\n alpha, beta, is_beta_zero,\n blocksize, k, tile_k,\n values, crow_indices, col_indices,\n mat1, mat2,\n max_grid\n):\n n_batches = values.size(0)\n n_block_rows = crow_indices.size(-1) - 1\n\n full_grid = (n_batches, n_block_rows)\n if max_grid is not None:\n grid_blocks = tuple(max_grid[:2][::-1]) + (None,) * (2 - len(max_grid[:2]))\n else:\n grid_blocks = None\n tensor_dims_map = {\n values: (0, None),\n crow_indices: (0, -1),\n col_indices: (0, None),\n mat1: (0, -4),\n mat2: (0, None),\n }\n if values.dtype in (torch.half, torch.bfloat16):\n acc_dtype = tl.float32\n allow_tf32 = True\n else:\n acc_dtype = tl.float64\n allow_tf32 = False\n\n def kernel(grid, *sliced_tensors):\n _sampled_addmm_kernel[grid](\n alpha, beta, is_beta_zero,\n *blocksize, k, tile_k,\n *ptr_stride_extractor(*sliced_tensors),\n acc_dtype=acc_dtype,\n allow_tf32=allow_tf32,\n num_stages=1,\n num_warps=4\n )\n\n launch_kernel(kernel, tensor_dims_map, full_grid, grid_blocks)\n\n# Function to run the _bsr_strided_dense_rowspace_kernel\ndef _run_dense_rowspace_kernel(\n blocksize, values, crow_indices, col_indices, dense, output, max_grid\n):\n n_batches = dense.size(0)\n n_block_rows = crow_indices.size(-1) - 1\n n_block_cols = dense.size(-3)\n\n full_grid = (n_batches, n_block_cols, n_block_rows)\n if max_grid is not None:\n grid_blocks = tuple(max_grid[:3][::-1]) + (None,) * (3 - len(max_grid[:3]))\n else:\n grid_blocks = None\n tensor_dims_map = {\n values: (0, None, None),\n crow_indices: (0, None, -1),\n col_indices: (0, None, None),\n dense: (0, -3, None),\n output: (0, -3, -4)\n }\n if values.dtype in (torch.half, torch.bfloat16):\n acc_dtype = tl.float32\n allow_tf32 = True\n else:\n acc_dtype = tl.float64\n allow_tf32 = False\n\n def kernel(grid, *sliced_tensors):\n _bsr_strided_dense_rowspace_kernel[grid](\n *blocksize,\n *ptr_stride_extractor(*sliced_tensors),\n acc_dtype=acc_dtype,\n allow_tf32=allow_tf32,\n GROUP_SIZE_ROW=4,\n num_stages=1,\n num_warps=4\n )\n\n launch_kernel(kernel, tensor_dims_map, full_grid, grid_blocks)\n\n# Function to run the _bsr_softmax_kernel\ndef bsr_softmax(input, max_row_nnz=None):\n f_name = \"bsr_softmax\"\n\n check_bsr_layout(f_name, input)\n check_dtype(f_name, input, input.dtype)\n\n if input._nnz() == 0 or input.numel() == 0:\n return input.clone()\n\n m, n = input.shape[-2:]\n nnz = input._nnz()\n row_block, col_block = input.values().shape[-2:]\n\n if max_row_nnz is None:\n max_row_nnz = triton.next_power_of_2(n)\n else:\n max_row_nnz = triton.next_power_of_2(max_row_nnz)\n\n crow_indices = input.crow_indices().unsqueeze(0).flatten(0, -2)\n if input.values().transpose(-3, -2).is_contiguous():\n values = input.values().clone()\n else:\n values = input.values()\n values = values.transpose(-3, -2).contiguous().unsqueeze(0).flatten(0, -4).reshape(-1, row_block, nnz * col_block)\n full_grid = (values.shape[0], row_block, m // row_block)\n grid_blocks = None\n tensor_dims_map = {\n crow_indices[..., :-1]: (0, None, -1),\n values: (0, None, None),\n }\n\n def kernel(grid, *sliced_tensors):\n _bsr_softmax_kernel[grid](\n *ptr_stride_extractor(*sliced_tensors),\n row_block, col_block,\n max_row_nnz,\n min(2 ** 17, max_row_nnz)\n )\n\n launch_kernel(kernel, tensor_dims_map, full_grid, grid_blocks)\n\n values = values.reshape(-1, row_block, nnz, col_block).transpose(-3, -2).reshape(*input.values().shape)\n\n return torch.sparse_compressed_tensor(\n input.crow_indices().clone(),\n input.col_indices().clone(),\n values,\n size=input.shape,\n layout=input.layout\n )\n\n# Wrapper function: sampled_addmm\ndef sampled_addmm(\n input: torch.Tensor,\n mat1: torch.Tensor,\n mat2: torch.Tensor,\n *,\n beta=1.0,\n alpha=1.0,\n out: Optional[torch.Tensor] = None,\n skip_checks: bool = False,\n max_grid: Optional[Tuple[Optional[int], Optional[int], Optional[int]]] = None,\n):\n f_name = \"sampled_addmm\"\n\n check_bsr_layout(f_name, input)\n input_broadcasted = broadcast_batch_dims_bsr(f_name, input, mat1, mat2)\n\n if not skip_checks:\n check_device(f_name, mat1, input.device)\n check_device(f_name, mat2, input.device)\n if beta != 0.0 and input.dtype is torch.bool:\n check(\n False,\n f\"{f_name}(): having beta == {beta} not equal to 0.0 with boolean mask is not allowed.\"\n )\n if input.dtype is not torch.bool:\n check_dtype(f_name, mat1, input.dtype)\n check_dtype(f_name, mat2, input.dtype)\n else:\n check_dtype(f_name, mat1, mat2.dtype)\n check_mm_compatible_shapes(f_name, mat1, mat2)\n if out is not None:\n check_bsr_layout(f_name, out)\n check_device(f_name, out, mat1.device)\n check_dtype(f_name, out, input.dtype)\n check(\n out.shape == input_broadcasted.shape\n and out._nnz() == input._nnz(),\n f\"{f_name}(): Expects `out` to be of shape {input_broadcasted.shape} \"\n f\"and with nnz equal to {input_broadcasted._nnz()} \"\n f\"but got out.shape = {out.shape} and out.nnz = {out._nnz()}\"\n )\n\n if out is None:\n out = input_broadcasted.to(mat1.dtype, copy=True)\n else:\n out.copy_(input_broadcasted)\n\n if out.numel() == 0 or out._nnz() == 0:\n return out\n\n blocksize = out.values().shape[-2:]\n m = mat1.size(-2)\n n = mat2.size(-1)\n k = mat1.size(-1)\n\n if alpha == 0.0 or k == 0:\n out.values().mul_(beta)\n return out\n\n out_backup = out\n crow_indices, col_indices, values, mat1, mat2 = prepare_inputs(out, mat1, mat2)\n\n mat1 = tile_to_blocksize(mat1, (blocksize[0], k))\n mat2 = tile_to_blocksize(mat2, (k, blocksize[1]))\n tile_k = max(*blocksize)\n\n _run_sampled_addmm_kernel(\n alpha, beta, beta == 0.0,\n blocksize, k, tile_k,\n values, crow_indices, col_indices,\n mat1, mat2,\n max_grid\n )\n\n if out_backup.values().stride()[-3:] != values.stride()[-3:]:\n out_backup.values().copy_(values.reshape(out_backup.values().shape))\n return out_backup\n\n# Wrapper function: bsr_dense_mm\ndef bsr_dense_mm(\n bsr: torch.Tensor,\n dense: torch.Tensor,\n *,\n out: Optional[torch.Tensor] = None,\n skip_checks: bool = False,\n max_grid: Optional[Tuple[Optional[int], Optional[int], Optional[int]]] = None,\n):\n f_name = \"bsr_dense_mm\"\n if not skip_checks:\n check_bsr_layout(f_name, bsr)\n check_device(f_name, bsr, dense.device)\n check_dtype(f_name, bsr, dense.dtype)\n check_mm_compatible_shapes(f_name, bsr, dense)\n\n m = bsr.size(-2)\n n = dense.size(-1)\n row_block, col_block = bsr.values().shape[-2:]\n check(\n not n % row_block,\n f\"bsr_dense_mm(): dense.size(-1) == {n} should be divisible by \"\n f\"blocksize[0] == {row_block}.\",\n )\n check_blocksize(f_name, (row_block, col_block))\n else:\n m, kl = bsr.shape[-2:]\n kr, n = dense.shape[-2:]\n\n original_batch_dims_broadcasted = broadcast_batch_dims(f_name, bsr, dense)\n\n if out is not None and not skip_checks:\n expected_out_shape = original_batch_dims_broadcasted + (m, n)\n check(\n out.shape == expected_out_shape,\n \"bsr_dense_mm(): `out` argument has wrong shape, \"\n f\"expected {expected_out_shape}, but got {out.shape}.\",\n )\n check(\n out.is_contiguous() or out.transpose(-2, -1).is_contiguous(),\n \"bsr_dense_mm(): only row-major/col-major `out` arguments are supported, \"\n \"i.e. (out.is_contiguous() or out.transpose(-2, -1).is_contiguous()) \"\n \"should be True.\",\n )\n\n if out is None:\n out = dense.new_empty(original_batch_dims_broadcasted + (m, n))\n\n if bsr._nnz() == 0:\n return out.zero_()\n\n blocksize = bsr.values().shape[-2:]\n out_backup = out\n\n crow_indices, col_indices, values, dense, out = prepare_inputs(bsr, dense, out)\n\n dense = tile_to_blocksize(dense, blocksize[::-1])\n out = tile_to_blocksize(out, (blocksize[0], blocksize[0]))\n\n _run_dense_rowspace_kernel(blocksize, values, crow_indices, col_indices, dense, out, max_grid)\n\n return out_backup\n\n# Wrapper function: _scaled_dot_product_attention\ndef _scaled_dot_product_attention(\n query: torch.Tensor,\n key: torch.Tensor,\n value: torch.Tensor,\n attn_mask: Optional[torch.Tensor],\n dropout_p: float = 0.0,\n is_causal: bool = False,\n scale: Optional[float] = None\n):\n f_name = \"_scaled_dot_product_attention\"\n check(\n not is_causal,\n f\"{f_name}(): is_causal == True is not supported.\"\n )\n check(\n attn_mask is not None,\n f\"{f_name}(): attn_mask == None is not supported.\"\n )\n assert attn_mask is not None\n\n check(\n attn_mask.layout == torch.sparse_bsr,\n f\"{f_name}(): \"\n f\"attn_mask.layout must be {torch.sparse_bsr}, but got \"\n f\"attn_mask.layout == {attn_mask.layout}.\"\n )\n\n check_device(f_name, key, query.device)\n check_device(f_name, value, query.device)\n check_device(f_name, attn_mask, query.device)\n\n check_dtype(f_name, key, query.dtype)\n check_dtype(f_name, value, query.dtype)\n if attn_mask.dtype is not torch.bool:\n check_dtype(f_name, attn_mask, query.dtype)\n\n sdpa = sampled_addmm(attn_mask, query, key.transpose(-2, -1), beta=0.0, skip_checks=False)\n if scale is None and query.size(-1) == 0 or scale == 0.0:\n check(\n False,\n f\"{f_name}(): current value of scale == {scale} \"\n \"results in division by zero.\"\n )\n scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale\n sdpa.values().mul_(scale_factor)\n sdpa = bsr_softmax(sdpa)\n torch.nn.functional.dropout(sdpa.values(), p=dropout_p, inplace=True)\n sdpa = bsr_dense_mm(sdpa, value)\n return sdpa\n", - "description_1": "Use triton language to define kernels for matrix operations such as sampled matrix multiplication, dense matrix multiplication, and sparse softmax. These kernels handle operations on sparse block-sparse row (BSR) matrices, allowing efficient execution on GPUs. The main operations include sampled_addmm_kernel for sampled add-matrix multiplication, bsr_strided_dense_rowspace_kernel for multiplying BSR matrices with dense matrices, and bsr_softmax_kernel for computing softmax over BSR matrices.", - "description_2": "Use triton language to implement kernels for matrix operations with block-sparse matrices, supporting operations like sampled add-matrix multiplication, matrix multiplication, and softmax computation for optimized GPU performance.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\nfrom fla.utils import contiguous\n\n@triton.autotune(\n configs=[\n triton.Config({'BT': 16}, num_warps=2),\n triton.Config({'BT': 16}, num_warps=4),\n triton.Config({'BT': 16}, num_warps=8),\n triton.Config({'BT': 32}, num_warps=2),\n triton.Config({'BT': 32}, num_warps=4),\n triton.Config({'BT': 32}, num_warps=8),\n triton.Config({'BT': 64}, num_warps=2),\n triton.Config({'BT': 64}, num_warps=4),\n triton.Config({'BT': 64}, num_warps=8),\n triton.Config({'BT': 128}, num_warps=2),\n triton.Config({'BT': 128}, num_warps=4),\n triton.Config({'BT': 128}, num_warps=8),\n triton.Config({'BT': 256}, num_warps=2),\n triton.Config({'BT': 256}, num_warps=4),\n triton.Config({'BT': 256}, num_warps=8)\n ],\n key=['D']\n)\n@triton.jit\ndef logsigmoid_fwd_kernel(\n x,\n y,\n T: tl.constexpr,\n D: tl.constexpr,\n BT: tl.constexpr\n):\n i = tl.program_id(0)\n o_i = i * BT + tl.arange(0, BT)\n\n p_x = x + o_i\n p_y = y + o_i\n mask = o_i < T\n\n # [D,]\n b_x = tl.load(p_x, mask=mask, other=0.).to(tl.float32)\n b_m = tl.minimum(0., b_x)\n b_z = 1. + tl.exp(-tl.abs(b_x))\n b_y = b_m - tl.log(b_z)\n tl.store(p_y, b_y.to(p_y.dtype.element_ty), mask=mask)\n\n\n@triton.autotune(\n configs=[\n triton.Config({'BT': 16}, num_warps=2),\n triton.Config({'BT': 16}, num_warps=4),\n triton.Config({'BT': 16}, num_warps=8),\n triton.Config({'BT': 32}, num_warps=2),\n triton.Config({'BT': 32}, num_warps=4),\n triton.Config({'BT': 32}, num_warps=8),\n triton.Config({'BT': 64}, num_warps=2),\n triton.Config({'BT': 64}, num_warps=4),\n triton.Config({'BT': 64}, num_warps=8),\n triton.Config({'BT': 128}, num_warps=2),\n triton.Config({'BT': 128}, num_warps=4),\n triton.Config({'BT': 128}, num_warps=8),\n triton.Config({'BT': 256}, num_warps=2),\n triton.Config({'BT': 256}, num_warps=4),\n triton.Config({'BT': 256}, num_warps=8)\n ],\n key=['D']\n)\n@triton.jit\ndef logsigmoid_bwd_kernel(\n x,\n dx,\n dy,\n T: tl.constexpr,\n D: tl.constexpr,\n BT: tl.constexpr\n):\n i = tl.program_id(0)\n o_i = i * BT + tl.arange(0, BT)\n\n p_x = x + o_i\n p_dx = dx + o_i\n p_dy = dy + o_i\n mask = o_i < T\n\n # [D,]\n b_x = tl.load(p_x, mask=mask, other=0.).to(tl.float32)\n b_dy = tl.load(p_dy, mask=mask, other=0.).to(tl.float32)\n b_dx = b_dy * (1. - tl.sigmoid(b_x))\n tl.store(p_dx, b_dx.to(p_dx.dtype.element_ty), mask=mask)\n\n\nclass LogSigmoidFunction(torch.autograd.Function):\n\n @staticmethod\n @contiguous\n def forward(ctx, x):\n T, D = x.numel(), x.shape[-1]\n y = torch.empty_like(x)\n logsigmoid_fwd_kernel[lambda meta: (triton.cdiv(meta['T'], meta['D']),)](x, y, T=T, D=D)\n ctx.save_for_backward(x,)\n return y\n\n @staticmethod\n @contiguous\n def backward(ctx, dy):\n x, = ctx.saved_tensors\n T, D = x.numel(), x.shape[-1]\n dx = torch.empty_like(x)\n logsigmoid_bwd_kernel[lambda meta: (triton.cdiv(meta['T'], meta['D']),)](x, dx, dy, T=T, D=D)\n return dx\n\n\nlogsigmoid = LogSigmoidFunction.apply\n", - "description_1": "Use triton language to implement a forward and backward kernel for the logsigmoid function. The forward kernel takes 5 parameters: x (input tensor), y (output tensor), T (total number of elements), D (dimension size), and BT (block size). It computes the logsigmoid of the input tensor and stores the result in the output tensor. The backward kernel takes 6 parameters: x (input tensor), dx (gradient of input), dy (gradient of output), T (total number of elements), D (dimension size), and BT (block size). It computes the gradient of the logsigmoid function with respect to the input tensor.", - "description_2": "Use triton language to create a logsigmoid function with forward and backward passes, where the forward pass computes the logsigmoid of an input tensor and the backward pass computes the gradient with respect to the input.", - "difficulty": 2 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.autotune(\n configs=[\n triton.Config({}, num_warps=1),\n triton.Config({}, num_warps=2),\n triton.Config({}, num_warps=4),\n triton.Config({}, num_warps=8),\n triton.Config({}, num_warps=16),\n triton.Config({}, num_warps=32),\n ],\n key=[\"N\", \"HAS_RESIDUAL\", \"STORE_RESIDUAL_OUT\", \"IS_RMS_NORM\", \"HAS_BIAS\"],\n)\n@triton.jit\ndef _layer_norm_fwd_1pass_kernel(\n X, O, Y, W, B, RESIDUAL, RESIDUAL_OUT, Mean, Rstd,\n stride_x_row, stride_y_row, stride_res_row, stride_res_out_row,\n N, eps, IS_RMS_NORM: tl.constexpr, BLOCK_N: tl.constexpr,\n HAS_RESIDUAL: tl.constexpr, STORE_RESIDUAL_OUT: tl.constexpr,\n HAS_WEIGHT: tl.constexpr, HAS_BIAS: tl.constexpr\n):\n row = tl.program_id(0)\n X += row * stride_x_row\n Y += row * stride_y_row\n O += row * stride_x_row\n if HAS_RESIDUAL:\n RESIDUAL += row * stride_res_row\n if STORE_RESIDUAL_OUT:\n RESIDUAL_OUT += row * stride_res_out_row\n cols = tl.arange(0, BLOCK_N)\n x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)\n if HAS_RESIDUAL:\n residual = tl.load(RESIDUAL + cols, mask=cols < N, other=0.0).to(tl.float32)\n x += residual\n if STORE_RESIDUAL_OUT:\n tl.store(RESIDUAL_OUT + cols, x, mask=cols < N)\n if not IS_RMS_NORM:\n mean = tl.sum(x, axis=0) / N\n tl.store(Mean + row, mean)\n xbar = tl.where(cols < N, x - mean, 0.0)\n var = tl.sum(xbar * xbar, axis=0) / N\n else:\n xbar = tl.where(cols < N, x, 0.0)\n var = tl.sum(xbar * xbar, axis=0) / N\n rstd = 1 / tl.sqrt(var + eps)\n tl.store(Rstd + row, rstd)\n mask = cols < N\n if HAS_WEIGHT:\n w = tl.load(W + cols, mask=mask).to(tl.float32)\n if HAS_BIAS:\n b = tl.load(B + cols, mask=mask).to(tl.float32)\n x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd\n y = x_hat * w if HAS_WEIGHT else x_hat\n if HAS_BIAS:\n y = y + b\n o = tl.load(O + cols, mask=cols < N, other=0.0).to(tl.float32)\n y = y * o * tl.sigmoid(o)\n tl.store(Y + cols, y, mask=mask)\n\ndef _layer_norm_fwd(\n x, o, weight, bias, eps, residual=None, out_dtype=None, residual_dtype=None, is_rms_norm=False\n):\n if residual is not None:\n residual_dtype = residual.dtype\n M, N = x.shape\n assert x.stride(-1) == 1\n if residual is not None:\n assert residual.stride(-1) == 1\n assert residual.shape == (M, N)\n if weight is not None:\n assert weight.shape == (N,)\n assert weight.stride(-1) == 1\n if bias is not None:\n assert bias.stride(-1) == 1\n assert bias.shape == (N,)\n y = torch.empty_like(x, dtype=x.dtype if out_dtype is None else out_dtype)\n assert y.stride(-1) == 1\n if residual is not None or (residual_dtype is not None and residual_dtype != x.dtype):\n residual_out = torch.empty(M, N, device=x.device, dtype=residual_dtype)\n assert residual_out.stride(-1) == 1\n else:\n residual_out = None\n mean = torch.empty((M,), dtype=torch.float32, device=\"cuda\") if not is_rms_norm else None\n rstd = torch.empty((M,), dtype=torch.float32, device=\"cuda\")\n MAX_FUSED_SIZE = 65536 // x.element_size()\n BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))\n if N > BLOCK_N:\n raise RuntimeError(\"This layer norm doesn't support feature dim >= 64KB.\")\n with torch.cuda.device(x.device.index):\n _layer_norm_fwd_1pass_kernel[(M,)](\n x, o, y, weight, bias, residual, residual_out, mean, rstd,\n x.stride(0), y.stride(0), residual.stride(0) if residual is not None else 0,\n residual_out.stride(0) if residual_out is not None else 0,\n N, eps, is_rms_norm, BLOCK_N, residual is not None, residual_out is not None,\n weight is not None, bias is not None,\n )\n return y, mean, rstd, residual_out if residual_out is not None else x\n", - "description_1": "Use triton language to implement a forward pass kernel for layer normalization with optional residuals, weights, and biases. The kernel computes the mean and variance for normalization, applies a linear transformation, and includes a Swish activation function. The function _layer_norm_fwd is a wrapper that prepares inputs and calls the kernel.", - "description_2": "Use triton language to create a layer normalization kernel with Swish activation, supporting optional residuals, weights, and biases.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\nimport math\n\n@triton.autotune(\n configs=[\n triton.Config({}, num_warps=1),\n triton.Config({}, num_warps=2),\n triton.Config({}, num_warps=4),\n triton.Config({}, num_warps=8),\n triton.Config({}, num_warps=16),\n triton.Config({}, num_warps=32),\n ],\n key=[\"N\", \"HAS_RESIDUAL\", \"STORE_RESIDUAL_OUT\", \"IS_RMS_NORM\", \"HAS_BIAS\"],\n)\n@triton.jit\ndef _layer_norm_fwd_1pass_kernel(\n X, # pointer to the input\n Y, # pointer to the output\n W, # pointer to the weights\n B, # pointer to the biases\n RESIDUAL, # pointer to the residual\n RESIDUAL_OUT, # pointer to the residual\n Mean, # pointer to the mean\n Rstd, # pointer to the 1/std\n stride_x_row, # how much to increase the pointer when moving by 1 row\n stride_y_row,\n stride_res_row,\n stride_res_out_row,\n N, # number of columns in X\n eps, # epsilon to avoid division by zero\n IS_RMS_NORM: tl.constexpr,\n BLOCK_N: tl.constexpr,\n HAS_RESIDUAL: tl.constexpr,\n STORE_RESIDUAL_OUT: tl.constexpr,\n HAS_WEIGHT: tl.constexpr,\n HAS_BIAS: tl.constexpr\n):\n row = tl.program_id(0)\n X += row * stride_x_row\n Y += row * stride_y_row\n if HAS_RESIDUAL:\n RESIDUAL += row * stride_res_row\n if STORE_RESIDUAL_OUT:\n RESIDUAL_OUT += row * stride_res_out_row\n cols = tl.arange(0, BLOCK_N)\n x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)\n if HAS_RESIDUAL:\n residual = tl.load(RESIDUAL + cols, mask=cols < N, other=0.0).to(tl.float32)\n x += residual\n if STORE_RESIDUAL_OUT:\n tl.store(RESIDUAL_OUT + cols, x, mask=cols < N)\n if not IS_RMS_NORM:\n mean = tl.sum(x, axis=0) / N\n tl.store(Mean + row, mean)\n xbar = tl.where(cols < N, x - mean, 0.0)\n var = tl.sum(xbar * xbar, axis=0) / N\n else:\n xbar = tl.where(cols < N, x, 0.0)\n var = tl.sum(xbar * xbar, axis=0) / N\n rstd = 1 / tl.sqrt(var + eps)\n tl.store(Rstd + row, rstd)\n mask = cols < N\n if HAS_WEIGHT:\n w = tl.load(W + cols, mask=mask).to(tl.float32)\n if HAS_BIAS:\n b = tl.load(B + cols, mask=mask).to(tl.float32)\n x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd\n\n y = x_hat * w if HAS_WEIGHT else x_hat\n if HAS_BIAS:\n y = y + b\n tl.store(Y + cols, y, mask=mask)\n\ndef _layer_norm_fwd(\n x, weight, bias, eps, residual=None, out_dtype=None, residual_dtype=None, is_rms_norm=False\n):\n if residual is not None:\n residual_dtype = residual.dtype\n M, N = x.shape\n assert x.stride(-1) == 1\n if residual is not None:\n assert residual.stride(-1) == 1\n assert residual.shape == (M, N)\n if weight is not None:\n assert weight.shape == (N,)\n assert weight.stride(-1) == 1\n if bias is not None:\n assert bias.stride(-1) == 1\n assert bias.shape == (N,)\n y = torch.empty_like(x, dtype=x.dtype if out_dtype is None else out_dtype)\n assert y.stride(-1) == 1\n if residual is not None or (residual_dtype is not None and residual_dtype != x.dtype):\n residual_out = torch.empty(M, N, device=x.device, dtype=residual_dtype)\n assert residual_out.stride(-1) == 1\n else:\n residual_out = None\n mean = torch.empty((M,), dtype=torch.float32, device=\"cuda\") if not is_rms_norm else None\n rstd = torch.empty((M,), dtype=torch.float32, device=\"cuda\")\n MAX_FUSED_SIZE = 65536 // x.element_size()\n BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))\n if N > BLOCK_N:\n raise RuntimeError(\n \"This layer norm doesn't support feature dim >= 64KB.\")\n with torch.cuda.device(x.device.index):\n _layer_norm_fwd_1pass_kernel[(M,)](\n x,\n y,\n weight,\n bias,\n residual,\n residual_out,\n mean,\n rstd,\n x.stride(0),\n y.stride(0),\n residual.stride(0) if residual is not None else 0,\n residual_out.stride(0) if residual_out is not None else 0,\n N,\n eps,\n is_rms_norm,\n BLOCK_N,\n residual is not None,\n residual_out is not None,\n weight is not None,\n bias is not None,\n )\n return y, mean, rstd, residual_out if residual_out is not None else x\n", - "description_1": "Use triton language to define a kernel _layer_norm_fwd_1pass_kernel that performs a layer normalization operation on input matrix X with shape (M, N) where M is the number of rows and N is the number of columns. The kernel takes 20 arguments including pointers to input (X), output (Y), weights (W), biases (B), residual (RESIDUAL), residual output (RESIDUAL_OUT), mean (Mean), and reciprocal of standard deviation (Rstd). Other arguments include strides for each of these pointers and several boolean constexpr arguments to specify the behavior of the kernel. The kernel computes the mean and variance of each row, normalizes the row values, applies optional weights and biases, and stores the result in Y. If a residual is provided, it's added to the input before normalization. Additionally, an optional second function _layer_norm_fwd is provided which handles setting up the inputs for the kernel, ensuring stride conditions and allocating memory for outputs.", - "description_2": "Use triton language to define a layer normalization kernel, including mean and variance computation, normalization, and application of weights and biases with potential residuals. Handle input validation and output allocation outside the kernel.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef chunk_abc_fwd_kernel_h(\n k, v, z, h, h0, ht,\n s_k_h, s_k_t, s_k_d,\n s_v_h, s_v_t, s_v_d,\n s_h_h, s_h_t, s_h_d,\n T: tl.constexpr, K: tl.constexpr, V: tl.constexpr,\n BT: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr,\n NT: tl.constexpr, NORMK: tl.constexpr, USE_INITIAL_STATE: tl.constexpr, STORE_FINAL_STATE: tl.constexpr\n):\n i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n b_h = tl.zeros([BK, BV], dtype=tl.float32)\n if USE_INITIAL_STATE:\n p_h = tl.make_block_ptr(h0 + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))\n b_h += tl.load(p_h, boundary_check=(0, 1)).to(tl.float32)\n if NORMK:\n p_z0 = tl.make_block_ptr(z + i_bh * s_k_h, (T * K,), (s_k_d,), (i_k * BK,), (BK,), (0,))\n else:\n p_z0 = tl.make_block_ptr(z + i_bh * s_v_h, (T * V,), (s_v_d,), (i_v * BV,), (BV,), (0,))\n b_zp = tl.load(p_z0).to(tl.float32)\n for i_t in range(NT):\n p_k = tl.make_block_ptr(k + i_bh * s_k_h, (K, T), (s_k_d, s_k_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1))\n p_v = tl.make_block_ptr(v + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))\n p_h = tl.make_block_ptr(h + i_bh * s_h_h + i_t * K * V, (K, V), (s_h_t, s_h_d), (i_k * BK, i_v * BV), (BK, BV), (1, 0))\n tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1))\n b_k = tl.load(p_k, boundary_check=(0, 1))\n b_v = tl.load(p_v, boundary_check=(0, 1))\n if NORMK:\n p_zc = tl.make_block_ptr(z + i_bh * s_k_h, (T * K,), (s_k_d,), ((i_t * BT + BT - 1) * K + i_k * BK,), (BK,), (0,))\n b_zc = tl.load(p_zc, boundary_check=(0,))\n b_r, b_zp = tl.exp(b_zp - b_zc), b_zc\n b_h = b_h * b_r[:, None]\n b_k = tl.exp(b_k - b_zc[:, None]).to(b_k.dtype)\n else:\n p_zc = tl.make_block_ptr(z + i_bh * s_v_h, (T * V,), (s_v_d,), ((i_t * BT + BT - 1) * V + i_v * BV,), (BV,), (0,))\n b_zc = tl.load(p_zc, boundary_check=(0,))\n b_r, b_zp = tl.exp(b_zp - b_zc), b_zc\n b_h = b_h * b_r[None, :]\n b_v = tl.exp(b_v - b_zc[None, :]).to(b_v.dtype)\n b_h += tl.dot(b_k, b_v, allow_tf32=False)\n if STORE_FINAL_STATE:\n p_h = tl.make_block_ptr(ht + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))\n tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1))\n\n\n@triton.jit\ndef chunk_abc_fwd_kernel_K(\n q, k, z, h, o, A,\n s_k_h, s_k_t, s_k_d,\n s_v_h, s_v_t, s_v_d,\n s_h_h, s_h_t, s_h_d,\n scale, T: tl.constexpr, K: tl.constexpr, V: tl.constexpr,\n BT: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr\n):\n i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n i_p = tl.maximum(i_t * BT - 1, 0)\n o_i = tl.arange(0, BT)\n m_s = o_i[:, None] >= o_i[None, :]\n b_o = tl.zeros([BT, BV], dtype=tl.float32)\n b_A = tl.zeros([BT, BT], dtype=tl.float32)\n for i_k in range(tl.cdiv(K, BK)):\n p_q = tl.make_block_ptr(q + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))\n p_k = tl.make_block_ptr(k + i_bh * s_k_h, (K, T), (s_k_d, s_k_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1))\n p_h = tl.make_block_ptr(h + i_bh * s_h_h + i_t * K * V, (K, V), (s_h_t, s_h_d), (i_k * BK, i_v * BV), (BK, BV), (1, 0))\n b_q = tl.load(p_q, boundary_check=(0, 1))\n b_q = (b_q * scale).to(b_q.dtype)\n b_k = tl.load(p_k, boundary_check=(0, 1))\n b_h = tl.load(p_h, boundary_check=(0, 1))\n b_o += tl.dot(b_q, b_h, allow_tf32=False)\n b_A += tl.dot(b_q, b_k, allow_tf32=False)\n p_z = tl.make_block_ptr(z + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))\n p_o = tl.make_block_ptr(o + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))\n b_z = tl.load(p_z, boundary_check=(0, 1))\n p_zp = tl.make_block_ptr(z + i_bh * s_v_h, (T * V,), (s_v_d,), (i_p * V + i_v * BV,), (BV,), (0,))\n b_zp = tl.load(p_zp, boundary_check=(0,))\n b_o = b_o * tl.exp(b_zp[None, :] - b_z)\n tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))\n p_A = tl.make_block_ptr(A + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))\n b_A = tl.where(m_s, b_A, 0.)\n if i_v == 0:\n tl.store(p_A, b_A.to(p_A.dtype.element_ty), boundary_check=(0, 1))\n\n\n@triton.jit\ndef chunk_abc_fwd_kernel_intra_K(\n v, z, o, A,\n s_v_h, s_v_t, s_v_d,\n T: tl.constexpr, V: tl.constexpr,\n BT: tl.constexpr, BC: tl.constexpr, BV: tl.constexpr, NC: tl.constexpr\n):\n i_v, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n i_t, i_i = i_c // NC, i_c % NC\n p_z = tl.make_block_ptr(z + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0))\n p_zn = tl.make_block_ptr(z + i_bh * s_v_h, (T * V,), (s_v_d,), ((i_t * BT + i_i * BC) * V + i_v * BV,), (BV,), (0,))\n b_zn = tl.load(p_zn, boundary_check=(0,))\n b_o = tl.zeros([BC, BV], dtype=tl.float32)\n for i_j in range(0, i_i):\n p_A = tl.make_block_ptr(A + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0))\n p_v = tl.make_block_ptr(v + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT + i_j * BC, i_v * BV), (BC, BV), (1, 0))\n b_v = tl.load(p_v, boundary_check=(0, 1))\n b_A = tl.load(p_A, boundary_check=(0, 1))\n b_o += tl.dot(b_A, tl.exp(b_v - b_zn[None, :]).to(b_v.dtype), allow_tf32=False)\n b_z = tl.load(p_z, boundary_check=(0, 1))\n b_o *= tl.exp(b_zn[None, :] - b_z)\n o_i = tl.arange(0, BC)\n o_A = i_bh * T * BT + (i_t * BT + i_i * BC + tl.arange(0, BC)) * BT + i_i * BC\n m_A = (i_t * BT + i_i * BC + tl.arange(0, BC)) < T\n for j in range(0, BC):\n p_v = tl.make_block_ptr(v + i_bh * s_v_h, (T * V,), (1,), ((i_t * BT + i_i * BC + j) * V + i_v * BV,), (BV,), (0,))\n b_A = tl.load(A + o_A + j, mask=m_A, other=0)\n b_v = tl.load(p_v, boundary_check=(0,)).to(tl.float32)\n m_i = o_i[:, None] >= j\n b_o += tl.where(m_i, b_A[:, None] * tl.exp(b_v[None, :] - b_z), 0)\n p_o = tl.make_block_ptr(o + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0))\n tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))\n\n\n@triton.jit\ndef chunk_abc_fwd_kernel_V(\n q, v, z, h, o, A,\n s_k_h, s_k_t, s_k_d,\n s_v_h, s_v_t, s_v_d,\n s_h_h, s_h_t, s_h_d,\n scale, T: tl.constexpr, K: tl.constexpr, V: tl.constexpr,\n BT: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr\n):\n i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n i_p = tl.maximum(i_t * BT - 1, 0)\n b_o = tl.zeros([BT, BV], dtype=tl.float32)\n for i_k in range(tl.cdiv(K, BK)):\n p_q = tl.make_block_ptr(q + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))\n p_z = tl.make_block_ptr(z + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))\n p_h = tl.make_block_ptr(h + i_bh * s_h_h + i_t * K * V, (K, V), (s_h_t, s_h_d), (i_k * BK, i_v * BV), (BK, BV), (1, 0))\n p_zp = tl.make_block_ptr(z + i_bh * s_k_h, (T * K,), (s_k_d,), (i_p * K + i_k * BK,), (BK,), (0,))\n b_q = tl.load(p_q, boundary_check=(0, 1))\n b_q = (b_q * scale).to(b_q.dtype)\n b_z = tl.load(p_z, boundary_check=(0, 1))\n b_zp = tl.load(p_zp, boundary_check=(0,))\n b_q = (b_q * tl.exp(b_zp[None, :] - b_z)).to(b_q.dtype)\n b_h = tl.load(p_h, boundary_check=(0, 1))\n if i_k >= 0:\n b_o += tl.dot(b_q, b_h, allow_tf32=False)\n p_v = tl.make_block_ptr(v + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))\n p_o = tl.make_block_ptr(o + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))\n p_A = tl.make_block_ptr(A + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))\n b_v = tl.load(p_v, boundary_check=(0, 1))\n b_A = tl.load(p_A, boundary_check=(0, 1))\n b_o += tl.dot(b_A, b_v, allow_tf32=False)\n tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))\n\n\ndef chunk_abc(\n q: torch.Tensor,\n k: torch.Tensor,\n v: torch.Tensor,\n s: torch.Tensor,\n initial_state: Optional[Tuple[torch.Tensor]] = None,\n output_final_state: Optional[bool] = False\n) -> Tuple[torch.Tensor, Tuple[torch.Tensor]]:\n if initial_state is not None:\n initial_state = tuple(i.detach() for i in initial_state)\n ov, final_state = ChunkABCFunction.apply(q, k, v, s, initial_state, output_final_state)\n return ov, final_state\n", - "description_1": "Use triton language to implement chunked attention forward pass kernels with input tensors q, k, v, s, and optional initial and final states. The kernels handle key-value interactions, intra-block interactions, and value processing in a tensor-based neural network model.", - "description_2": "Use triton language to execute chunked attention mechanism on GPU, handling key-query interactions, intra-block processing, and managing state tensors for neural network forward passes.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.autotune(\n configs=[\n triton.Config({'BS': 16}, num_warps=2),\n triton.Config({'BS': 16}, num_warps=4),\n triton.Config({'BS': 16}, num_warps=8),\n triton.Config({'BS': 32}, num_warps=2),\n triton.Config({'BS': 32}, num_warps=4),\n triton.Config({'BS': 32}, num_warps=8),\n triton.Config({'BS': 64}, num_warps=2),\n triton.Config({'BS': 64}, num_warps=4),\n triton.Config({'BS': 64}, num_warps=8),\n ],\n key=['S']\n)\n@triton.jit\ndef chunk_gated_abc_fwd_kernel_cum(\n s,\n o,\n s_s_h,\n s_s_t,\n s_s_d,\n T: tl.constexpr,\n S: tl.constexpr,\n BT: tl.constexpr,\n BS: tl.constexpr,\n):\n i_s, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n o_i = tl.arange(0, BT)\n m_s = tl.where(o_i[:, None] >= o_i[None, :], 1., 0.).to(tl.float32)\n\n p_s = tl.make_block_ptr(s + i_bh * s_s_h, (T, S), (s_s_t, s_s_d), (i_t * BT, i_s * BS), (BT, BS), (1, 0))\n p_o = tl.make_block_ptr(o + i_bh * s_s_h, (T, S), (s_s_t, s_s_d), (i_t * BT, i_s * BS), (BT, BS), (1, 0))\n # [BT, BS]\n b_s = tl.load(p_s, boundary_check=(0, 1)).to(tl.float32)\n b_o = tl.dot(m_s, b_s, allow_tf32=False)\n tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))\n\n\n@triton.jit\ndef chunk_gated_abc_fwd_kernel_h(\n k,\n v,\n g,\n h,\n h0,\n ht,\n s_k_h,\n s_k_t,\n s_k_d,\n s_v_h,\n s_v_t,\n s_v_d,\n s_h_h,\n s_h_t,\n s_h_d,\n T: tl.constexpr,\n K: tl.constexpr,\n V: tl.constexpr,\n BT: tl.constexpr,\n BK: tl.constexpr,\n BV: tl.constexpr,\n NT: tl.constexpr,\n GATEK: tl.constexpr,\n USE_INITIAL_STATE: tl.constexpr,\n STORE_FINAL_STATE: tl.constexpr\n):\n i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n\n b_h = tl.zeros([BK, BV], dtype=tl.float32)\n if USE_INITIAL_STATE:\n p_h = tl.make_block_ptr(h0 + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))\n b_h += tl.load(p_h, boundary_check=(0, 1)).to(tl.float32)\n for i_t in range(NT):\n p_k = tl.make_block_ptr(k + i_bh * s_k_h, (K, T), (s_k_d, s_k_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1))\n p_v = tl.make_block_ptr(v + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))\n p_h = tl.make_block_ptr(h + i_bh * s_h_h + i_t * K * V, (K, V), (s_h_t, s_h_d), (i_k * BK, i_v * BV), (BK, BV), (1, 0))\n\n tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1))\n # [BK, BT]\n b_k = tl.load(p_k, boundary_check=(0, 1))\n # [BT, BV]\n b_v = tl.load(p_v, boundary_check=(0, 1))\n if GATEK:\n p_g = tl.make_block_ptr(g + i_bh * s_k_h, (K, T), (s_k_d, s_k_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1))\n p_gn = tl.make_block_ptr(g + i_bh * s_k_h, (T * K,), (s_k_d,), ((i_t * BT + BT - 1) * K + i_k * BK,), (BK,), (0,))\n # [BK,]\n b_gn = tl.load(p_gn, boundary_check=(0,))\n # [BK, BV]\n b_h *= tl.exp(b_gn)[:, None]\n # [BK, BT]\n b_g = tl.load(p_g, boundary_check=(0, 1))\n b_k = (b_k * tl.exp(b_gn[:, None] - b_g)).to(b_k.dtype)\n else:\n p_g = tl.make_block_ptr(g + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))\n p_gn = tl.make_block_ptr(g + i_bh * s_v_h, (T * V,), (s_v_d,), ((i_t * BT + BT - 1) * V + i_v * BV,), (BV,), (0,))\n # [BV,]\n b_gn = tl.load(p_gn, boundary_check=(0,))\n # [BK, BV]\n b_h *= tl.exp(b_gn)[None, :]\n # [BT, BV]\n b_g = tl.load(p_g, boundary_check=(0, 1))\n b_v = (b_v * tl.exp(b_gn[None, :] - b_g)).to(b_v.dtype)\n # [BK, BV]\n b_h += tl.dot(b_k, b_v, allow_tf32=False)\n\n if STORE_FINAL_STATE:\n p_h = tl.make_block_ptr(ht + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))\n tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1))\n\n\n@triton.jit\ndef chunk_gated_abc_fwd_kernel_V(\n q,\n v,\n g,\n h,\n o,\n A,\n s_k_h,\n s_k_t,\n s_k_d,\n s_v_h,\n s_v_t,\n s_v_d,\n s_h_h,\n s_h_t,\n s_h_d,\n scale,\n T: tl.constexpr,\n K: tl.constexpr,\n V: tl.constexpr,\n BT: tl.constexpr,\n BK: tl.constexpr,\n BV: tl.constexpr\n):\n i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n\n b_o = tl.zeros([BT, BV], dtype=tl.float32)\n for i_k in range(tl.cdiv(K, BK)):\n p_q = tl.make_block_ptr(q + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))\n p_g = tl.make_block_ptr(g + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))\n p_h = tl.make_block_ptr(h + i_bh * s_h_h + i_t * K * V, (K, V), (s_h_t, s_h_d), (i_k * BK, i_v * BV), (BK, BV), (1, 0))\n\n # [BT, BK]\n b_q = tl.load(p_q, boundary_check=(0, 1))\n b_q = (b_q * scale).to(b_q.dtype)\n # [BT, BK]\n b_g = tl.load(p_g, boundary_check=(0, 1))\n # [BT, BK]\n b_qg = (b_q * tl.exp(b_g)).to(b_q.dtype)\n # [BK, BV]\n b_h = tl.load(p_h, boundary_check=(0, 1))\n # works but dkw, owing to divine benevolence\n # [BT, BV]\n if i_k >= 0:\n b_o += tl.dot(b_qg, b_h, allow_tf32=False)\n p_v = tl.make_block_ptr(v + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))\n p_o = tl.make_block_ptr(o + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))\n p_A = tl.make_block_ptr(A + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))\n # [BT, BV]\n b_v = tl.load(p_v, boundary_check=(0, 1))\n # [BT, BT]\n b_A = tl.load(p_A, boundary_check=(0, 1))\n b_o += tl.dot(b_A, b_v, allow_tf32=False)\n tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))\n\n\ndef fwd_pre(g, B, H, T, S, BT):\n NT = triton.cdiv(T, BT)\n g_org, g = g, torch.empty_like(g, dtype=torch.float)\n def grid(meta): return (triton.cdiv(meta['S'], meta['BS']), NT, B * H)\n # keep cummulative normalizer in fp32\n # this kernel is equivalent to\n # g = g.view(B, H, NT, BT, -1).cumsum(-2).view(B, H, T, -1)\n chunk_gated_abc_fwd_kernel_cum[grid](\n g_org, g,\n g.stride(1), g.stride(2), g.stride(3),\n T=T, S=S, BT=BT\n )\n return g\n\n\ndef fwd_inner(q, k, v, g, B, H, T, K, V, BT, BK, BV, gatek=False, h0=None, ht=None):\n NT = triton.cdiv(T, BT)\n NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)\n num_warps = 4 if BK == 64 else 2\n num_stages = 1\n\n h = q.new_empty(B, H, NT * K, V)\n grid = (NV, NK, B * H)\n chunk_gated_abc_fwd_kernel_h[grid](\n k, v, g, h, h0, ht,\n k.stride(1), k.stride(2), k.stride(3),\n v.stride(1), v.stride(2), v.stride(3),\n h.stride(1), h.stride(2), h.stride(3),\n T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT,\n GATEK=gatek,\n USE_INITIAL_STATE=h0 is not None,\n STORE_FINAL_STATE=ht is not None,\n num_warps=num_warps,\n num_stages=num_stages\n )\n return h\n\n\ndef fwd_v(q, k, v, g, B, H, T, K, V, BT, BK, BV, BC, h0=None, ht=None, scale=1.):\n NT = triton.cdiv(T, BT)\n NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)\n NC = triton.cdiv(BT, BC)\n num_warps = 4 if BK == 64 else 2\n num_stages = 1\n\n h = fwd_inner(\n q=q, k=k, v=v, g=g,\n B=B, H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV,\n gatek=True,\n h0=h0,\n ht=ht\n )\n A = q.new_zeros(NK, B, H, T, BT)\n o = torch.empty_like(v)\n grid = (NV, NT, B * H)\n chunk_gated_abc_fwd_kernel_V[grid](\n q, v, g, h, o, A,\n k.stride(1), k.stride(2), k.stride(3),\n v.stride(1), v.stride(2), v.stride(3),\n h.stride(1), h.stride(2), h.stride(3),\n scale,\n T=T, K=K, V=V, BT=BT, BK=BK, BV=BV,\n num_warps=num_warps,\n num_stages=num_stages\n )\n return o, h, A\n", - "description_1": "Use triton language to define kernels and functions for chunked gated attention using multiple kernels. Implement forward cumulative sum, intermediate hidden states calculation, and final output aggregation using queries, keys, values, and gate tensors. Ensure memory alignment and efficient execution with block pointers and grid specifications.", - "description_2": "Use triton language to perform chunked gated attention computations with kernels for cumulative sum, intermediate processing, and output aggregation.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef fused_chunk_based_fwd_kernel(\n q, k, v, o, z, s_qk_h, s_qk_t, s_qk_d, s_vo_h, s_vo_t, s_vo_d,\n B, H, T, scale, BT: tl.constexpr, BK: tl.constexpr,\n BV: tl.constexpr, DK: tl.constexpr, DV: tl.constexpr,\n):\n i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n o_i = tl.arange(0, BT)\n m_s = o_i[:, None] >= o_i[None, :]\n b_h_0o = tl.zeros([BV], dtype=tl.float32)\n b_h_1o = tl.zeros([BK, BV], dtype=tl.float32)\n b_h_2o = tl.zeros([BK*BK, BV], dtype=tl.float32)\n p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (0, i_k * BK), (BT, BK), (1, 0))\n p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (DK, T), (s_qk_d, s_qk_t), (i_k * BK, 0), (BK, BT), (0, 1))\n p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (0, i_v * BV), (BT, BV), (1, 0))\n p_o = tl.make_block_ptr(o + (i_bh + i_k*B*H) * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (0, i_v * BV), (BT, BV), (1, 0))\n p_z = z + (i_bh + i_k * B * H) * T + tl.arange(0, BT)\n k_2o = tl.zeros([1, BK * BK], dtype=tl.float32)\n k_1o = tl.zeros([1, BK], dtype=tl.float32)\n k_0o = 0\n\n for i in range(0, tl.cdiv(T, BT)):\n b_k = tl.load(p_k, boundary_check=(0, 1))\n b_k_2o = b_k[:, None, :] * b_k[None, :, :]\n b_k_2o = tl.reshape(b_k_2o, [BK * BK, BT]).to(b_k.dtype)\n b_v = tl.load(p_v, boundary_check=(0, 1))\n b_q = (tl.load(p_q, boundary_check=(0, 1)) * scale).to(b_k.dtype)\n b_o = tl.zeros([BT, BV], dtype=tl.float32)\n b_z = tl.zeros([BT], dtype=tl.float32)\n\n b_o += b_h_0o\n b_z += k_0o\n b_o += tl.dot(b_q, b_h_1o.to(b_q.dtype), allow_tf32=False)\n b_z += tl.sum(b_q * k_1o, axis=1)\n b_q_2o = b_q[:, :, None] * b_q[:, None, :]\n b_q_2o = tl.reshape(b_q_2o, [BT, BK * BK]).to(b_k.dtype)\n b_o += tl.dot(b_q_2o, b_h_2o.to(b_q_2o.dtype), allow_tf32=False) * 0.5\n b_z += tl.sum(b_q_2o * k_2o, axis=1) * 0.5\n\n k_1o += tl.sum(b_k, axis=1)[None, :]\n k_2o += tl.sum(b_k_2o, axis=1)[None, :]\n k_0o += BT\n\n b_s = tl.dot(b_q, b_k, allow_tf32=False)\n b_s = 1 + b_s + 0.5 * b_s * b_s\n b_s = tl.where(m_s, b_s, 0)\n b_z += tl.sum(b_s, axis=1)\n b_o += tl.dot(b_s.to(b_q.dtype), b_v, allow_tf32=False)\n\n tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))\n tl.store(p_z, b_z.to(p_z.dtype.element_ty), mask=(i * BT + tl.arange(0, BT)) < T)\n\n b_h_2o = b_h_2o + tl.dot(b_k_2o.to(b_v.dtype), b_v, allow_tf32=False)\n b_h_1o = b_h_1o + tl.dot(b_k, b_v, allow_tf32=False)\n b_h_0o = b_h_0o + tl.sum(b_v, axis=0)\n\n p_q = tl.advance(p_q, (BT, 0))\n p_k = tl.advance(p_k, (0, BT))\n p_v = tl.advance(p_v, (BT, 0))\n p_o = tl.advance(p_o, (BT, 0))\n p_z += BT\n\n@triton.jit\ndef fused_chunk_based_bwd_kernel(\n q, k, v, do, dz, dq, dk, dv, s_qk_h, s_qk_t, s_qk_d, s_vo_h,\n s_vo_t, s_vo_d, B, H, T, scale, BT: tl.constexpr,\n BK: tl.constexpr, BV: tl.constexpr, DK: tl.constexpr, DV: tl.constexpr,\n):\n i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n o_i = tl.arange(0, BT)\n m_s = o_i[:, None] >= o_i[None, :]\n b_h_1o = tl.zeros([BV, BK], dtype=tl.float32)\n b_h_2o = tl.zeros([BV, BK*BK], dtype=tl.float32)\n k_1o = tl.zeros([1, BK], dtype=tl.float32)\n k_2o = tl.zeros([1, BK * BK], dtype=tl.float32)\n\n for i in range(0, tl.cdiv(T, BT)):\n p_q = tl.make_block_ptr(\n q + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (i * BT, i_k * BK), (BT, BK), (1, 0))\n p_k = tl.make_block_ptr(\n k + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (i * BT, i_k * BK), (BT, BK), (1, 0))\n p_v = tl.make_block_ptr(\n v + i_bh * s_vo_h, (DV, T), (s_vo_d, s_vo_t), (i_v * BV, i * BT), (BV, BT), (0, 1))\n p_do = tl.make_block_ptr(\n do + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (i * BT, i_v * BV), (BT, BV), (1, 0))\n p_dq = tl.make_block_ptr(dq + (i_bh + i_v*B*H) * s_qk_h,\n (T, DK), (s_qk_t, s_qk_d), (i*BT, i_k*BK), (BT, BK), (1, 0))\n p_dz = dz + (i_bh) * T + tl.arange(0, BT) + i * BT\n b_dq = tl.zeros([BT, BK], dtype=tl.float32)\n\n b_dz = tl.load(p_dz, mask=(tl.arange(0, BT) + i * BT) < T)\n b_q = tl.load(p_q, boundary_check=(0, 1))\n b_q = (b_q * scale).to(b_q.dtype)\n b_do = tl.load(p_do, boundary_check=(0, 1)).to(b_q.dtype)\n b_k = tl.load(p_k, boundary_check=(0, 1))\n b_v = tl.load(p_v, boundary_check=(0, 1))\n\n b_dq += tl.dot(b_do, (b_h_1o).to(b_do.dtype), allow_tf32=False)\n if i_v == 0:\n b_dq += b_dz[:, None] * k_1o\n b_dq_2o = tl.dot(b_do, (b_h_2o).to(b_do.dtype), allow_tf32=False) * 0.5\n if i_v == 0:\n b_dq_2o += (b_dz[:, None] * k_2o) * 0.5\n b_dq_2o = tl.reshape(b_dq_2o, [BT, BK, BK])\n b_dq += tl.sum(b_dq_2o * b_q[:, :, None], axis=1)\n b_dq += tl.sum(b_dq_2o * b_q[:, None, :], axis=2)\n b_dq *= scale\n\n b_ds = tl.dot(b_do, b_v, allow_tf32=False)\n if i_v == 0:\n b_ds += b_dz[:, None]\n b_ds = tl.where(m_s, b_ds, 0) * scale\n b_s = tl.dot(b_q, tl.trans(b_k), allow_tf32=False)\n b_s = tl.where(m_s, b_s, 0)\n b_dq += tl.dot((b_ds * (1 + b_s)).to(b_q.dtype), b_k, allow_tf32=False)\n\n tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))\n\n b_k_2o = b_k[:, :, None] * b_k[:, None, :]\n b_k_2o = tl.reshape(b_k_2o, [BT, BK * BK]).to(b_k.dtype)\n b_h_2o = b_h_2o + tl.dot(b_v, b_k_2o.to(b_v.dtype), allow_tf32=False)\n b_h_1o = b_h_1o + tl.dot(b_v, b_k, allow_tf32=False)\n\n if i_v == 0:\n k_1o += tl.sum(b_k, axis=0)[None, :]\n k_2o += tl.sum(b_k_2o, axis=0)[None, :]\n\n tl.debug_barrier()\n b_h_1o = None\n b_h_2o = None\n\n b_dh_1o = tl.zeros([BK, BV], dtype=tl.float32)\n b_dh_2o = tl.zeros([BK*BK, BV], dtype=tl.float32)\n b_dh_0o = tl.zeros([BV], dtype=tl.float32)\n m_s = tl.arange(0, BT)[:, None] <= tl.arange(0, BT)[None, :]\n\n dq_1o = tl.zeros([1, BK], dtype=tl.float32)\n dq_2o = tl.zeros([BK * BK, 1], dtype=tl.float32)\n\n for i in range(tl.cdiv(T, BT) * BT - BT, -BT, -BT):\n p_q = tl.make_block_ptr(\n q + i_bh * s_qk_h, (DK, T), (s_qk_d, s_qk_t), (i_k * BK, i), (BK, BT), (0, 1))\n p_k = tl.make_block_ptr(\n k + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (i, i_k * BK), (BT, BK), (1, 0))\n p_v = tl.make_block_ptr(\n v + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (i, i_v * BV), (BT, BV), (1, 0))\n p_do = tl.make_block_ptr(\n do + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (i, i_v * BV), (BT, BV), (1, 0))\n p_dk = tl.make_block_ptr(dk + (i_bh+i_v*B*H) * s_qk_h, (T, DK),\n (s_qk_t, s_qk_d), (i, i_k*BK), (BT, BK), (1, 0))\n p_dv = tl.make_block_ptr(dv + (i_bh+i_k*B*H) * s_vo_h, (T, DV),\n (s_vo_t, s_vo_d), (i, i_v*BV), (BT, BV), (1, 0))\n p_dz = dz + (i_bh) * T + tl.arange(0, BT) + i\n\n b_dk = tl.zeros([BT, BK], dtype=tl.float32)\n b_dv = tl.zeros([BT, BV], dtype=tl.float32)\n\n b_dz = tl.load(p_dz, mask=(tl.arange(0, BT)+i) < T)\n b_q = tl.load(p_q, boundary_check=(0, 1))\n b_k = tl.load(p_k, boundary_check=(0, 1))\n b_v = tl.load(p_v, boundary_check=(0, 1))\n b_do = tl.load(p_do, boundary_check=(0, 1)).to(b_q.dtype)\n b_q = (b_q * scale).to(b_k.dtype)\n\n b_ds = tl.dot(b_v, tl.trans(b_do), allow_tf32=False)\n if i_v == 0:\n b_ds += b_dz[None, :]\n b_ds = tl.where(m_s, b_ds, 0)\n b_s = tl.dot(b_k, b_q, allow_tf32=False)\n b_s2 = 1 + b_s + 0.5 * b_s * b_s\n b_s = tl.where(m_s, b_s, 0)\n b_s2 = tl.where(m_s, b_s2, 0)\n b_ds *= (1+b_s)\n\n b_dk += tl.dot(b_ds.to(b_k.dtype), tl.trans(b_q), allow_tf32=False)\n b_dv += tl.dot(b_s2.to(b_do.dtype), b_do, allow_tf32=False)\n\n b_k_2o = b_k[:, :, None] * b_k[:, None, :]\n b_k_2o = tl.reshape(b_k_2o, [BT, BK * BK]).to(b_k.dtype)\n\n b_dv += tl.dot(b_k, b_dh_1o.to(b_k.dtype), allow_tf32=False)\n b_dv += tl.dot(b_k_2o, b_dh_2o.to(b_k.dtype), allow_tf32=False)\n b_dv += b_dh_0o\n\n b_dk += tl.dot(b_v, tl.trans(b_dh_1o).to(b_k.dtype), allow_tf32=False)\n\n if i_v == 0:\n b_dk += dq_1o\n\n b_dk_2o = tl.dot(b_dh_2o.to(b_k.dtype),\n tl.trans(b_v), allow_tf32=False)\n if i_v == 0:\n b_dk_2o += dq_2o\n b_dk_2o = tl.reshape(b_dk_2o, [BK, BK, BT])\n b_k_fp32 = tl.trans(b_k.to(tl.float32))\n b_dk2 = tl.sum(b_dk_2o * b_k_fp32[:, None, :], axis=0)\n b_dk2 += tl.sum(b_dk_2o * b_k_fp32[None, :, :], axis=1)\n b_dk += tl.trans(b_dk2)\n\n b_dh_0o += tl.sum(b_do, axis=0)\n b_dh_1o = b_dh_1o + tl.dot(b_q, b_do, allow_tf32=False)\n b_q_2o = b_q[None, :, :] * b_q[:, None, :]\n b_q_2o = tl.reshape(b_q_2o, [BK * BK, BT]).to(b_k.dtype)\n b_dh_2o = b_dh_2o + tl.dot(b_q_2o, b_do, allow_tf32=False) * 0.5\n\n if i_v == 0:\n dq_1o += (tl.sum(b_dz[None, :] * b_q, axis=1))[None, :]\n dq_2o += (tl.sum(b_dz[None, :] * b_q_2o, axis=1) * 0.5)[:, None]\n\n tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))\n tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))\n\n\nclass FusedChunkBasedFunction(torch.autograd.Function):\n\n @staticmethod\n def forward(ctx, q, k, v, scale=1):\n batch_size, n_heads, seq_len, d_head_qk = q.shape\n d_head_v = v.shape[-1]\n\n scale = scale\n BT = 16\n BK, BV = min(d_head_qk, 16), min(d_head_v, 32)\n BK, BV = max(BK, 16), max(BV, 16)\n NK, NV = triton.cdiv(d_head_qk, BK), triton.cdiv(d_head_v, BV)\n\n num_warps = 4\n\n o = q.new_empty(NK, batch_size, n_heads, seq_len,\n d_head_v, dtype=torch.float32)\n z = q.new_empty(NK, batch_size, n_heads, seq_len, dtype=torch.float32)\n\n grid = (NV, NK, batch_size * n_heads)\n fused_chunk_based_fwd_kernel[grid](\n q, k, v, o, z,\n q.stride(1), q.stride(2), q.stride(3),\n v.stride(1), v.stride(2), v.stride(3),\n batch_size, n_heads, seq_len, scale,\n BT=BT, DK=d_head_qk, DV=d_head_v, BK=BK, BV=BV,\n num_warps=num_warps,\n )\n o = o.sum(0)\n z = z.sum(0)\n ctx.save_for_backward(q, k, v)\n ctx.scale = scale\n return o.to(q.dtype), z.to(z.dtype)\n\n @staticmethod\n def backward(ctx, do, dz):\n q, k, v = ctx.saved_tensors\n batch_size, n_heads, seq_len, d_head_qk = q.shape\n d_head_v = v.shape[-1]\n scale = ctx.scale\n\n BT = 16\n BK, BV = min(d_head_qk, 16), min(d_head_v, 32)\n BK, BV = max(BK, 16), max(BV, 16)\n NK, NV = triton.cdiv(d_head_qk, BK), triton.cdiv(d_head_v, BV)\n num_stages = 1\n num_warps = 4\n\n dq = q.new_empty(NV, batch_size, n_heads, seq_len, d_head_qk)\n dk = q.new_empty(NV, batch_size, n_heads, seq_len, d_head_qk)\n dv = q.new_empty(NK, batch_size, n_heads, seq_len, d_head_v)\n grid = (NV, NK, batch_size * n_heads)\n\n fused_chunk_based_bwd_kernel[grid](\n q, k, v, do, dz, dq, dk, dv,\n q.stride(1), q.stride(2), q.stride(3),\n v.stride(1), v.stride(2), v.stride(3),\n batch_size, n_heads, seq_len, scale,\n BT=BT, DK=d_head_qk, DV=d_head_v, BK=BK, BV=BV,\n num_warps=num_warps,\n num_stages=num_stages\n )\n dq = dq.sum(0)\n dk = dk.sum(0)\n dv = dv.sum(0)\n return dq.to(q.dtype), dk.to(k.dtype), dv.to(v.dtype), None\n\n\ntriton_fused_chunk_based = FusedChunkBasedFunction.apply\n\n\ndef fused_chunk_based(q, k, v, use_scale=True, use_normalize=True):\n assert q.shape[-1] <= 16, 'only support feature dimension up to 16.'\n if use_scale:\n scale = q.shape[-1] ** -0.5\n else:\n scale = 1\n o, z = triton_fused_chunk_based(q, k, v, scale)\n if use_normalize:\n o = o / (z[..., None] + 1e-6)\n else:\n o = o\n\n return o.to(q.dtype)\n", - "description_1": "Use triton language to implement forward and backward kernels for fused chunk-based attention. The forward kernel computes attention outputs and normalization factors for each query, key, and value tensor in chunks, utilizing Taylor expansion for optimization. It takes query, key, value tensors, and various parameters, returning computed outputs and normalizers. The backward kernel calculates gradients for query, key, and value tensors, also using Taylor expansion for efficiency. This kernel is optimized for specific block sizes along different dimensions.", - "description_2": "Use triton language to implement fused attention mechanism with efficient forward and backward kernels utilizing Taylor expansion and block processing.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef parallel_based_fwd_kernel(\n q, k, v, o, z, s_qk_h, s_qk_t, s_qk_d, s_vo_h, s_vo_t, s_vo_d, B, H, T, scale,\n BTL: tl.constexpr, BTS: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr,\n DK: tl.constexpr, DV: tl.constexpr,\n):\n i_kv, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n NV = tl.cdiv(DV, BV)\n i_k = i_kv // NV\n i_v = i_kv % NV\n p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (i_c * BTL, i_k * BK), (BTL, BK), (1, 0))\n p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (DK, T), (s_qk_d, s_qk_t), (i_k * BK, 0), (BK, BTS), (0, 1))\n p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (0, i_v * BV), (BTS, BV), (1, 0))\n\n b_q = tl.load(p_q, boundary_check=(0, 1))\n b_q = (b_q * scale).to(b_q.dtype)\n b_o = tl.zeros([BTL, BV], dtype=tl.float32)\n b_z = tl.zeros([BTL], dtype=tl.float32)\n\n for _ in range(0, i_c * BTL, BTS):\n b_k = tl.load(p_k, boundary_check=(0, 1))\n b_v = tl.load(p_v, boundary_check=(0, 1))\n b_s = tl.dot(b_q, (b_k), allow_tf32=False)\n b_s = 1 + b_s + 0.5 * b_s * b_s\n b_z += tl.sum(b_s, axis=1)\n b_o = b_o + tl.dot(b_s.to(b_v.dtype), b_v, allow_tf32=False)\n p_k = tl.advance(p_k, (0, BTS))\n p_v = tl.advance(p_v, (BTS, 0))\n\n tl.debug_barrier()\n o_q = tl.arange(0, BTL)\n o_k = tl.arange(0, BTS)\n p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (DK, T), (s_qk_d, s_qk_t), (i_k * BK, i_c * BTL), (BK, BTS), (0, 1))\n p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (i_c * BTL, i_v * BV), (BTS, BV), (1, 0))\n\n for _ in range(i_c * BTL, (i_c + 1) * BTL, BTS):\n b_k = tl.load(p_k, boundary_check=(0, 1))\n b_v = tl.load(p_v, boundary_check=(0, 1))\n m_s = o_q[:, None] >= o_k[None, :]\n b_s = tl.dot(b_q, b_k, allow_tf32=False)\n b_s = 1 + b_s + 0.5 * b_s * b_s\n b_s = tl.where(m_s, b_s, 0)\n b_z += tl.sum(b_s, axis=1)\n b_o += tl.dot(b_s.to(b_q.dtype), b_v, allow_tf32=False)\n p_k = tl.advance(p_k, (0, BTS))\n p_v = tl.advance(p_v, (BTS, 0))\n o_k += BTS\n\n p_o = tl.make_block_ptr(o + (i_bh + B * H * i_k) * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (i_c*BTL, i_v*BV), (BTL, BV), (1, 0))\n p_z = z + (i_bh + B * H * i_k) * T + i_c * BTL + tl.arange(0, BTL)\n tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))\n tl.store(p_z, b_z.to(p_z.dtype.element_ty), mask=((i_c * BTL + tl.arange(0, BTL)) < T))\n\n\n@triton.jit\ndef parallel_based_bwd_kernel(\n q, k, v, do, dz, dq, dk, dv, s_qk_h, s_qk_t, s_qk_d, s_vo_h, s_vo_t, s_vo_d, B, H, T, scale,\n BTL: tl.constexpr, BTS: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr,\n DK: tl.constexpr, DV: tl.constexpr,\n):\n i_kv, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n NV = tl.cdiv(DV, BV)\n i_k = i_kv // NV\n i_v = i_kv % NV\n i_h = i_bh % H\n _parallel_based_bwd_dq(\n i_bh, i_c, i_k, i_v, i_h,\n q, k, v, do, dz, dq, s_qk_h, s_qk_t, s_qk_d, s_vo_h,\n s_vo_t, s_vo_d, B, H, T, scale, BTL=BTL, BTS=BTS, BK=BK, BV=BV, DK=DK, DV=DV\n )\n tl.debug_barrier()\n _parallel_based_bwd_dkv(\n i_bh, i_c, i_k, i_v, i_h,\n q, k, v, do, dz, dk, dv, s_qk_h, s_qk_t, s_qk_d, s_vo_h,\n s_vo_t, s_vo_d, B, H, T, scale, BTL, BTS, BK, BV, DK, DV\n )\n\n\nclass ParallelBasedFunction(torch.autograd.Function):\n @staticmethod\n @custom_fwd\n def forward(ctx, q, k, v, scale):\n BTL, BTS = 128, 32\n BK = min(128, triton.next_power_of_2(k.shape[-1]))\n BV = min(128, triton.next_power_of_2(v.shape[-1]))\n BK, BV = max(BK, 16), max(BV, 16)\n batch_size, n_heads, seq_len, d_head_qk = q.shape\n d_head_v = v.shape[-1]\n num_stages = 2\n num_warps = 4\n NK = triton.cdiv(d_head_qk, BK)\n NV = triton.cdiv(d_head_v, BV)\n grid = (NK * NV, triton.cdiv(seq_len, BTL), batch_size * n_heads)\n\n o = torch.empty(NK, batch_size, n_heads, seq_len, d_head_v, device=q.device)\n z = torch.empty(NK, batch_size, n_heads, seq_len, device=q.device)\n parallel_based_fwd_kernel[grid](\n q, k, v, o, z,\n q.stride(1), q.stride(2), q.stride(3),\n v.stride(1), v.stride(2), v.stride(3),\n batch_size, n_heads, seq_len, scale,\n BTL=BTL, BTS=BTS, BK=BK, BV=BV, DK=d_head_qk, DV=d_head_v,\n num_warps=num_warps,\n num_stages=num_stages\n )\n ctx.save_for_backward(q, k, v)\n ctx.scale = scale\n return o.sum(0).to(q.dtype), z.sum(0).to(q.dtype)\n\n @staticmethod\n @custom_bwd\n def backward(ctx, do, dz):\n q, k, v = ctx.saved_tensors\n scale = ctx.scale\n BTL, BTS = 64, 32\n BK = min(128, triton.next_power_of_2(k.shape[-1]))\n BV = min(128, triton.next_power_of_2(v.shape[-1]))\n BK, BV = max(BK, 16), max(BV, 16)\n batch_size, n_heads, seq_len, d_head_qk = q.shape\n d_head_v = v.shape[-1]\n num_stages = 2\n num_warps = 4\n NK = triton.cdiv(d_head_qk, BK)\n NV = triton.cdiv(d_head_v, BV)\n grid = (NK * NV, triton.cdiv(seq_len, BTL), batch_size * n_heads)\n\n dq = torch.empty(NV, batch_size, n_heads, seq_len, d_head_qk, dtype=q.dtype, device=q.device)\n dk = torch.empty(NV, batch_size, n_heads, seq_len, d_head_qk, dtype=q.dtype, device=q.device)\n dv = torch.empty(NK, batch_size, n_heads, seq_len, d_head_v, dtype=q.dtype, device=q.device)\n\n parallel_based_bwd_kernel[grid](\n q, k, v, do, dz, dq, dk, dv,\n q.stride(1), q.stride(2), q.stride(3),\n v.stride(1), v.stride(2), v.stride(3),\n batch_size, n_heads, seq_len, scale,\n BTL=BTL, BTS=BTS, BK=BK, BV=BV, DK=d_head_qk, DV=d_head_v,\n num_warps=num_warps,\n num_stages=num_stages\n )\n\n return dq.sum(0).to(q.dtype), dk.sum(0).to(k.dtype), dv.sum(0).to(v.dtype), None\n\n\ntriton_parallel_based = ParallelBasedFunction.apply\n\n\ndef parallel_based(q, k, v, use_scale=True, use_normalize=True, return_both=False):\n if use_scale:\n scale = q.shape[-1] ** -0.5\n else:\n scale = 1\n o, z = triton_parallel_based(q, k, v, scale)\n if return_both:\n return o, z\n if use_normalize:\n o = o / (z[..., None] + 1e-6)\n return o.to(q.dtype)\n", - "description_1": "Use triton language to implement forward and backward kernels for a parallel-based sequence mixer, taking query, key, value tensors, and a scale as inputs and producing output and normalization tensors. The forward kernel processes data blocks and handles overlaps in the sequence, while the backward kernel computes gradients with respect to the query, key, and value tensors.", - "description_2": "Use triton language to design forward and backward kernels for a sequence mixer, processing query, key, value, and scale inputs to produce result and normalization outputs, and calculate gradients.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n\n@triton.autotune(\n configs=[\n triton.Config({}, num_warps=1),\n triton.Config({}, num_warps=2),\n triton.Config({}, num_warps=4),\n triton.Config({}, num_warps=8),\n triton.Config({}, num_warps=16),\n triton.Config({}, num_warps=32),\n ],\n key=[\"BT\", \"BK\", \"BV\"],\n)\n@triton.jit\ndef fwd_prepare_dv_kernel(\n q,\n k,\n do,\n dv,\n s_qk_h,\n s_qk_t,\n s_qk_d,\n s_vo_h,\n s_vo_t,\n s_vo_d,\n T,\n K,\n V,\n scale,\n BT: tl.constexpr,\n BK: tl.constexpr,\n BV: tl.constexpr\n):\n i_t, i_bh = tl.program_id(0), tl.program_id(1)\n\n b_A = tl.zeros([BT, BT], dtype=tl.float32)\n\n for i_k in range(tl.cdiv(K, BK)):\n p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1))\n p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))\n b_k = tl.load(p_k, boundary_check=(0, 1))\n b_q = tl.load(p_q, boundary_check=(0, 1))\n b_q = (b_q * scale).to(b_k.dtype)\n b_A += tl.dot(b_k, b_q, allow_tf32=False)\n\n b_A = tl.where(tl.arange(0, BT)[:, None] <= tl.arange(0, BT)[None, :], b_A, 0).to(do.dtype.element_ty)\n\n for i_v in range(tl.cdiv(V, BV)):\n p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))\n b_do = tl.load(p_do, boundary_check=(0, 1))\n p_dv = tl.make_block_ptr(dv + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))\n b_dv = tl.dot(b_A, b_do, allow_tf32=False)\n tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))\n\n\ndef fwd_prepare_dv(q, k, do, BT):\n dv = torch.empty_like(do)\n B, H, T, K, V = *k.shape, do.shape[-1]\n NT = triton.cdiv(T, BT)\n BK = min(triton.next_power_of_2(K), 64)\n BV = min(triton.next_power_of_2(V), 64)\n fwd_prepare_dv_kernel[(NT, B*H)](\n q, k, do, dv,\n k.stride(1), k.stride(2), k.stride(3),\n do.stride(1), do.stride(2), do.stride(3),\n T, K, V, K**-0.5, BT, BK, BV\n )\n return dv\n\n", - "description_1": "Use triton language to implement a kernel fwd_prepare_dv_kernel which computes the dot product between key and query matrices in a block-wise fashion to update an intermediate buffer for a subsequent computation. The fwd_prepare_dv_kernel takes 16 arguments: 4 tensors (q, k, do, dv) which represent the query, key, gradient, and output respectively, 6 integers representing strides (s_qk_h, s_qk_t, s_qk_d, s_vo_h, s_vo_t, s_vo_d), 3 integers (T, K, V) for dimensions, and a float scale. Three integer constants (BT, BK, BV) define block sizes for processing.", - "description_2": "Use triton language to compute the forward preparation of dv using a block-wise dot product computation between tensors q, k, and do with specific strides and scaling, and store the result in dv.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.autotune(\n configs=[\n triton.Config({}, num_warps=1),\n triton.Config({}, num_warps=2),\n triton.Config({}, num_warps=4),\n triton.Config({}, num_warps=8)\n ],\n key=[\"BT\", \"BK\"],\n)\n@triton.jit\ndef fused_chunk_delta_rule_fwd_kernel(\n q, k, v, v_new, d, o, initial_state, final_state,\n s_qk_h, s_qk_t, s_qk_d, s_vo_h, s_vo_t, s_vo_d,\n B, H, T, scale,\n BT: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr,\n DK: tl.constexpr, DV: tl.constexpr,\n USE_INITIAL_STATE: tl.constexpr,\n STORE_FINAL_STATE: tl.constexpr,\n CHECK: tl.constexpr\n):\n i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n o_i = tl.arange(0, BT)\n m_s = o_i[:, None] >= o_i[None, :]\n b_h = tl.zeros([BK, BV], dtype=tl.float32)\n p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (0, i_k * BK), (BT, BK), (1, 0))\n p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (DK, T), (s_qk_d, s_qk_t), (i_k * BK, 0), (BK, BT), (0, 1))\n p_d = tl.make_block_ptr(d + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (0, i_k * BK), (BT, BK), (1, 0))\n p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (0, i_v * BV), (BT, BV), (1, 0))\n p_o = tl.make_block_ptr(o + (i_bh+i_k*B*H) * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (0, i_v * BV), (BT, BV), (1, 0))\n p_v_new = tl.make_block_ptr(v_new + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (0, i_v * BV), (BT, BV), (1, 0))\n\n if USE_INITIAL_STATE:\n p_h = tl.make_block_ptr(initial_state + i_bh * DK * DV, (DK, DV), (DV, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))\n b_h = tl.load(p_h, boundary_check=(0, 1)).to(tl.float32)\n\n for i in range(0, tl.cdiv(T, BT)):\n b_k = tl.load(p_k, boundary_check=(0, 1))\n b_v = tl.load(p_v, boundary_check=(0, 1))\n b_q = tl.load(p_q, boundary_check=(0, 1))\n b_d = tl.load(p_d, boundary_check=(0, 1))\n b_q = (b_q * scale).to(b_k.dtype)\n b_s = tl.dot(b_q, b_k, allow_tf32=False)\n b_s = tl.where(m_s, b_s, 0)\n b_v_prime = tl.dot(b_d, b_h.to(b_q.dtype), allow_tf32=False)\n b_v = b_v - b_v_prime\n tl.store(p_v_new, b_v.to(p_v.dtype.element_ty), boundary_check=(0, 1))\n b_o = tl.dot(b_s.to(b_q.dtype), b_v.to(b_q.dtype), allow_tf32=False)\n if CHECK and i == 0:\n b_o += tl.dot(b_q, b_h.to(b_q.dtype), allow_tf32=False)\n b_h = b_h + tl.dot(b_k, b_v.to(b_k.dtype), allow_tf32=False)\n else:\n b_o += tl.dot(b_q, b_h.to(b_q.dtype), allow_tf32=False)\n b_h = b_h + tl.dot(b_k, b_v.to(b_k.dtype), allow_tf32=False)\n tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))\n p_q = tl.advance(p_q, (BT, 0))\n p_k = tl.advance(p_k, (0, BT))\n p_v = tl.advance(p_v, (BT, 0))\n p_v_new = tl.advance(p_v_new, (BT, 0))\n p_o = tl.advance(p_o, (BT, 0))\n p_d = tl.advance(p_d, (BT, 0))\n\n if STORE_FINAL_STATE:\n p_final = tl.make_block_ptr(final_state + i_bh * DK * DV, (DK, DV), (DV, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))\n tl.store(p_final, b_h.to(p_final.dtype.element_ty), boundary_check=(0, 1))\n\n\n@triton.autotune(\n configs=[\n triton.Config({}, num_warps=1),\n triton.Config({}, num_warps=2),\n triton.Config({}, num_warps=4),\n triton.Config({}, num_warps=8),\n triton.Config({}, num_warps=16),\n triton.Config({}, num_warps=32),\n ],\n key=[\"BT\", \"BK\", \"BV\"],\n)\n@triton.jit\ndef fused_chunk_delta_rule_bwd_kernel(\n q, k, v, d, do, dq, dk, dv, dd, initial_state,\n s_qk_h, s_qk_t, s_qk_d, s_vo_h, s_vo_t, s_vo_d,\n B, H, T, scale,\n BT: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr,\n DK: tl.constexpr, DV: tl.constexpr,\n USE_INITIAL_STATE: tl.constexpr,\n CHECK: tl.constexpr\n):\n i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n o_i = tl.arange(0, BT)\n b_dh = tl.zeros([BK, BV], dtype=tl.float32)\n m_s = o_i[:, None] <= o_i[None, :]\n for i in range(1, tl.cdiv(T, BT) + 1):\n p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (DK, T), (s_qk_d, s_qk_t), (i_k * BK, T - i * BT), (BK, BT), (0, 1))\n p_d = tl.make_block_ptr(d + i_bh * s_qk_h, (DK, T), (s_qk_d, s_qk_t), (i_k * BK, T - i * BT), (BK, BT), (0, 1))\n p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (T - i * BT, i_k * BK), (BT, BK), (1, 0))\n p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (T - i * BT, i_v * BV), (BT, BV), (1, 0))\n p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (T - i * BT, i_v * BV), (BT, BV), (1, 0))\n p_dk = tl.make_block_ptr(dk + (i_bh+i_v*B*H) * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (T - i*BT, i_k*BK), (BT, BK), (1, 0))\n p_dv = tl.make_block_ptr(dv + (i_bh+i_k*B*H) * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (T - i*BT, i_v*BV), (BT, BV), (1, 0))\n b_q = tl.load(p_q, boundary_check=(0, 1))\n b_q = (b_q * scale).to(b_q.dtype)\n b_k = tl.load(p_k, boundary_check=(0, 1))\n b_v = tl.load(p_v, boundary_check=(0, 1))\n b_do = tl.load(p_do, boundary_check=(0, 1))\n b_ds = tl.dot(b_v, tl.trans(b_do), allow_tf32=False)\n b_ds = tl.where(m_s, b_ds, 0).to(b_q.dtype)\n b_s = tl.dot(b_k, b_q, allow_tf32=False)\n b_s = tl.where(m_s, b_s, 0).to(b_q.dtype)\n b_dk = tl.dot(b_ds, tl.trans(b_q), allow_tf32=False)\n b_dv = tl.dot(b_s, b_do, allow_tf32=False)\n b_d = tl.load(p_d, boundary_check=(0, 1))\n if CHECK and i == 1:\n b_dk += tl.dot(b_v, tl.trans(b_dh).to(b_v.dtype), allow_tf32=False)\n b_dv += tl.dot(b_k, b_dh.to(b_k.dtype), allow_tf32=False)\n b_dh += tl.dot(b_q, b_do, allow_tf32=False)\n b_dh -= tl.dot(b_d, b_dv.to(b_d.dtype), allow_tf32=False)\n else:\n b_dk += tl.dot(b_v, tl.trans(b_dh).to(b_v.dtype), allow_tf32=False)\n b_dv += tl.dot(b_k, b_dh.to(b_k.dtype), allow_tf32=False)\n b_dh += tl.dot(b_q, b_do, allow_tf32=False)\n b_dh -= tl.dot(b_d, b_dv.to(b_d.dtype), allow_tf32=False)\n\n tl.store(p_dk, (b_dk).to(p_dk.dtype.element_ty), boundary_check=(0, 1))\n tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))\n\n b_h = None\n tl.debug_barrier()\n m_s = o_i[:, None] >= o_i[None, :]\n b_h = tl.zeros([BV, BK], dtype=tl.float32)\n if USE_INITIAL_STATE:\n p_h = tl.make_block_ptr(initial_state + i_bh * DK * DV, (DV, DK), (1, DV), (i_v * BV, i_k * BK), (BV, BK), (0, 1))\n b_h = tl.load(p_h, boundary_check=(0, 1)).to(tl.float32)\n NT = tl.cdiv(T, BT)\n for i in range(0, NT):\n p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (i * BT, i_k * BK), (BT, BK), (1, 0))\n p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (DV, T), (s_vo_d, s_vo_t), (i_v * BV, i * BT), (BV, BT), (0, 1))\n p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (i * BT, i_v * BV), (BT, BV), (1, 0))\n p_dq = tl.make_block_ptr(dq + (i_bh + i_v*B*H) * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (i*BT, i_k*BK), (BT, BK), (1, 0))\n b_k = tl.load(p_k, boundary_check=(0, 1))\n b_v = tl.load(p_v, boundary_check=(0, 1))\n b_do = tl.load(p_do, boundary_check=(0, 1))\n b_ds = tl.dot(b_do, b_v, allow_tf32=False)\n b_ds = tl.where(m_s, b_ds, 0)\n b_dq = tl.dot(b_ds.to(b_k.dtype), b_k, allow_tf32=False)\n if CHECK and i == 0:\n b_dq += tl.dot(b_do, b_h.to(b_do.dtype), allow_tf32=False)\n b_h = b_h + tl.dot(b_v, b_k, allow_tf32=False)\n else:\n b_dq += tl.dot(b_do, b_h.to(b_do.dtype), allow_tf32=False)\n b_h = b_h + tl.dot(b_v, b_k, allow_tf32=False)\n b_dq *= scale\n tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))\n\n if i < (NT - 1):\n p_dv = tl.make_block_ptr(dv + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), ((i + 1) * BT, i_v * BV), (BT, BV), (1, 0))\n b_dv = tl.load(p_dv, boundary_check=(0, 1))\n b_dd = tl.dot(b_dv.to(b_k.dtype), b_h.to(b_k.dtype), allow_tf32=False)\n p_dd = tl.make_block_ptr(dd + (i_bh + i_v*B*H) * s_qk_h, (T, DK), (s_qk_t, s_qk_d),\n ((i+1) * BT, i_k * BK), (BT, BK), (1, 0))\n tl.store(p_dd, -b_dd.to(p_dd.dtype.element_ty), boundary_check=(0, 1))\n\n\ndef fused_chunk_delta_rule_fwd(q, k, v, d, BT, initial_state, output_final_state):\n batch_size, n_heads, seq_len, d_head_qk = q.shape\n d_head_v = v.shape[-1]\n scale = d_head_qk ** -0.5\n BT = BT\n BK, BV = triton.next_power_of_2(d_head_qk), min(triton.next_power_of_2(d_head_v), 32)\n NK, NV = triton.cdiv(d_head_qk, BK), triton.cdiv(d_head_v, BV)\n assert NK == 1, 'NK should be 1'\n o = q.new_empty(batch_size, n_heads, seq_len, d_head_v)\n if output_final_state:\n final_state = q.new_empty(batch_size, n_heads, d_head_qk, d_head_v, dtype=torch.float32, requires_grad=False)\n else:\n final_state = None\n CHECK = True\n grid = (NV, NK, batch_size * n_heads)\n v_new = torch.empty_like(v)\n fused_chunk_delta_rule_fwd_kernel[grid](\n q, k, v, v_new, d, o, initial_state, final_state,\n q.stride(1), q.stride(2), q.stride(3),\n v.stride(1), v.stride(2), v.stride(3),\n batch_size, n_heads, seq_len, scale,\n BT=BT, DK=d_head_qk, DV=d_head_v, BK=BK, BV=BV,\n USE_INITIAL_STATE=initial_state is not None,\n STORE_FINAL_STATE=output_final_state,\n CHECK=CHECK,\n )\n return o, v_new, CHECK, final_state\n\n\ndef fused_chunk_delta_rule_bwd(q, k, v, d, do, BT, CHECK, initial_state):\n batch_size, n_heads, seq_len, d_head_qk = q.shape\n d_head_v = v.shape[-1]\n scale = d_head_qk ** -0.5\n BK, BV = triton.next_power_of_2(d_head_qk), min(triton.next_power_of_2(d_head_v), 32)\n NK, NV = triton.cdiv(d_head_qk, BK), triton.cdiv(d_head_v, BV)\n assert NK == 1\n dq = q.new_empty(NV, batch_size, n_heads, seq_len, d_head_qk)\n dk = q.new_empty(NV, batch_size, n_heads, seq_len, d_head_qk)\n dd = q.new_empty(NV, batch_size, n_heads, seq_len, d_head_qk)\n dv = q.new_empty(NK, batch_size, n_heads, seq_len, d_head_v)\n grid = (NV, NK, batch_size * n_heads)\n fused_chunk_delta_rule_bwd_kernel[grid](\n q, k, v, d, do, dq, dk, dv, dd, initial_state,\n q.stride(1), q.stride(2), q.stride(3),\n v.stride(1), v.stride(2), v.stride(3),\n batch_size, n_heads, seq_len, scale,\n BT=BT, DK=d_head_qk, DV=d_head_v, BK=BK, BV=BV,\n USE_INITIAL_STATE=initial_state is not None,\n CHECK=CHECK,\n )\n dq = dq.sum(0)\n dk = dk.sum(0)\n dv = dv.sum(0)\n dd = dd.sum(0)\n dd[:, :, 0:BT] = 0\n return dq, dk, dv, dd\n", - "description_1": "Use triton language to implement two kernels: fused_chunk_delta_rule_fwd_kernel and fused_chunk_delta_rule_bwd_kernel. The forward kernel takes 24 parameters: q, k, v, v_new, d, o, initial_state, final_state, s_qk_h, s_qk_t, s_qk_d, s_vo_h, s_vo_t, s_vo_d, B, H, T, scale, BT, BK, BV, DK, DV, USE_INITIAL_STATE, STORE_FINAL_STATE, CHECK. It computes the forward pass of a fused chunk delta rule operation. The backward kernel takes 23 parameters: q, k, v, d, do, dq, dk, dv, dd, initial_state, s_qk_h, s_qk_t, s_qk_d, s_vo_h, s_vo_t, s_vo_d, B, H, T, scale, BT, BK, BV, DK, DV, USE_INITIAL_STATE, CHECK. It computes the backward pass of the same operation.", - "description_2": "Use triton language to create a forward kernel for a fused chunk delta rule operation with 24 parameters and a backward kernel with 23 parameters, both utilizing triton's block pointers and dot products for efficient computation.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n# Triton kernel for fused recurrent forward pass\n@triton.jit\ndef fused_recurrent_fwd_kernel(\n q, k, v, beta, o, initial_state, final_state,\n s_qk_h, s_qk_t, s_qk_d,\n s_vo_h, s_vo_t, s_vo_d,\n B, H, T, scale,\n BK: tl.constexpr, BV: tl.constexpr,\n DK: tl.constexpr, DV: tl.constexpr,\n USE_INITIAL_STATE: tl.constexpr, STORE_FINAL_STATE: tl.constexpr,\n):\n i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n p_q = q + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK)\n p_k = k + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK)\n p_v = v + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV)\n p_beta = beta + i_bh * T\n p_o = o + (i_bh + i_k * B * H) * s_vo_h + i_v * BV + tl.arange(0, BV)\n\n mask_bk = (i_k * BK + tl.arange(0, BK)) < DK\n mask_bv = (i_v * BV + tl.arange(0, BV)) < DV\n mask_kv = mask_bk[None, :] & mask_bv[:, None]\n\n h = tl.zeros([BV, BK], dtype=tl.float32)\n\n if USE_INITIAL_STATE:\n p_init_s = initial_state + i_bh * DK * DV + \\\n (i_k * BK + tl.arange(0, BK)[None, :]) * \\\n DV + (i_v * BV + tl.arange(0, BV)[:, None])\n h += tl.load(p_init_s, mask=mask_kv, other=0).to(tl.float32)\n\n for _ in range(0, T):\n _k = tl.load(p_k, mask=mask_bk, other=0).to(tl.float32)\n _v = tl.load(p_v, mask=mask_bv, other=0).to(tl.float32)\n _q = tl.load(p_q, mask=mask_bk, other=0).to(tl.float32) * scale\n _v_minus = tl.sum(h * _k[None, :], axis=1)\n _v -= _v_minus\n _beta = tl.load(p_beta).to(tl.float32)\n # in-place overwrite\n tl.store(p_v, _v.to(p_v.dtype.element_ty), mask=mask_bv)\n _v *= _beta\n h += _k[None, :] * _v[:, None]\n _o = h * _q[None, :]\n _o = tl.sum(_o, axis=1)\n tl.store(p_o, _o.to(p_o.dtype.element_ty), mask=mask_bv)\n\n p_q += DK\n p_k += DK\n p_o += DV\n p_v += DV\n p_beta += 1\n\n if STORE_FINAL_STATE:\n p_final_s = final_state + i_bh * DK * DV + \\\n (i_k * BK + tl.arange(0, BK)[None, :]) * \\\n DV + (i_v * BV + tl.arange(0, BV)[:, None])\n tl.store(p_final_s, h.to(p_final_s.dtype.element_ty), mask=mask_kv)\n\n# Triton kernel for fused recurrent backward pass\n@triton.jit\ndef fused_recurrent_bwd_kernel(\n q, k, v, beta, do, dq, dk, dv, dbeta, initial_state,\n s_qk_h, s_qk_t, s_qk_d,\n s_vo_h, s_vo_t, s_vo_d,\n B, H, T, scale,\n BK: tl.constexpr, BV: tl.constexpr,\n DK: tl.constexpr, DV: tl.constexpr,\n USE_INITIAL_STATE: tl.constexpr,\n):\n i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n mask_bk = i_k * BK + tl.arange(0, BK) < DK\n mask_bv = i_v * BV + tl.arange(0, BV) < DV\n\n p_q = q + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (T - 1) * DK\n p_k = k + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (T - 1) * DK\n p_do = do + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV) + (T - 1) * DV\n p_v = v + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV) + (T - 1) * DV\n p_beta = beta + i_bh * T + T - 1\n p_dbeta = dbeta + (i_bh + i_v * B * H) * T + T - 1\n\n p_dk = dk + (i_bh + i_v * B * H) * s_qk_h + i_k * \\\n BK + tl.arange(0, BK) + (T - 1) * DK\n p_dv = dv + (i_bh + i_k * B * H) * s_vo_h + i_v * \\\n BV + tl.arange(0, BV) + (T - 1) * DV\n d_h = tl.zeros([BK, BV], dtype=tl.float32)\n\n for _ in range(T):\n _do = tl.load(p_do, mask=mask_bv, other=0).to(tl.float32)\n _q = tl.load(p_q, mask=mask_bk, other=0).to(tl.float32) * scale\n _k = tl.load(p_k, mask=mask_bk, other=0).to(tl.float32)\n _v = tl.load(p_v, mask=mask_bv, other=0).to(tl.float32)\n _beta = tl.load(p_beta).to(tl.float32)\n d_h += _q[:, None] * _do[None, :]\n d_k = tl.sum(d_h * _v[None, :] * _beta, axis=1)\n d_v = tl.sum(d_h * _k[:, None], axis=0)\n\n d_beta = tl.sum(d_v * _v)\n d_v = d_v * _beta\n\n tl.store(p_dk, d_k.to(p_dk.dtype.element_ty), mask=mask_bk)\n tl.store(p_dv, d_v.to(p_dv.dtype.element_ty), mask=mask_bv)\n tl.store(p_dbeta, d_beta.to(p_dbeta.dtype.element_ty))\n\n d_h -= _k[:, None] * d_v[None, :]\n\n p_do -= DV\n p_q -= DK\n p_k -= DK\n p_v -= DV\n p_dk -= DK\n p_dv -= DV\n p_dbeta -= 1\n p_beta -= 1\n\n tl.debug_barrier()\n\n h = tl.zeros([BK, BV], dtype=tl.float32)\n\n p_q = q + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK)\n p_k = k + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK)\n p_v = v + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV)\n p_beta = beta + i_bh * T\n p_do = do + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV)\n p_dq = dq + (i_bh + i_v * B * H) * s_qk_h + i_k * BK + tl.arange(0, BK)\n p_dv = dv + (i_bh + i_k * B * H) * s_vo_h + i_v * BV + tl.arange(0, BV) + DV\n p_dk = dk + (i_bh + i_v * B * H) * s_qk_h + i_k * BK + tl.arange(0, BK) + DK\n\n if USE_INITIAL_STATE:\n mask_kv = mask_bk[:, None] & mask_bv[None, :]\n p_init_s = initial_state + i_bh * DK * DV + \\\n (i_k * BK + tl.arange(0, BK)[:, None]) * \\\n DV + (i_v * BV + tl.arange(0, BV)[None, :])\n h += tl.load(p_init_s, mask=mask_kv, other=0).to(tl.float32)\n\n for i in range(0, T):\n _k = tl.load(p_k, mask=mask_bk, other=0).to(tl.float32)\n _v = tl.load(p_v, mask=mask_bv, other=0).to(tl.float32)\n _do = tl.load(p_do, mask=mask_bv, other=0).to(tl.float32)\n _beta = tl.load(p_beta).to(tl.float32)\n _v *= _beta\n\n h += _k[:, None] * _v[None, :]\n _d_q = h * _do[None, :]\n d_q = tl.sum(_d_q, axis=1) * scale\n tl.store(p_dq, d_q.to(p_dq.dtype.element_ty), mask=mask_bk)\n\n if i < T - 1:\n d_k = tl.load(p_dk, mask=mask_bk, other=0).to(tl.float32)\n d_v = tl.load(p_dv, mask=mask_bv, other=0).to(tl.float32)\n d_k -= tl.sum(d_v[None, :] * h, axis=1)\n tl.store(p_dk, d_k.to(p_dk.dtype.element_ty), mask=mask_bk)\n\n p_k += DK\n p_do += DV\n p_v += DV\n p_dk += DK\n p_dv += DV\n p_dq += DK\n p_beta += 1\n\n\nclass FusedRecurrentFunction(torch.autograd.Function):\n\n @staticmethod\n @torch.utils._python_dispatch.tracing_only(contiguous)\n def forward(ctx, q, k, v, beta, initial_state=None, output_final_state=False):\n batch_size, n_heads, seq_len, d_head_qk = q.shape\n d_head_v = v.shape[-1]\n\n scale = d_head_qk ** -0.5\n BK, BV = triton.next_power_of_2(d_head_qk), min(triton.next_power_of_2(d_head_v), 8)\n NK, NV = triton.cdiv(d_head_qk, BK), triton.cdiv(d_head_v, BV)\n num_stages = 1\n num_warps = 1\n assert NK == 1, \"NK > 1 is not supported yet\"\n o = q.new_empty(NK, batch_size, n_heads, seq_len, d_head_v)\n\n if output_final_state:\n final_state = q.new_empty(batch_size, n_heads, d_head_qk, d_head_v)\n else:\n final_state = None\n\n grid = (NV, NK, batch_size * n_heads)\n fused_recurrent_fwd_kernel[grid](\n q, k, v, beta, o, initial_state, final_state,\n q.stride(1), q.stride(2), q.stride(3),\n v.stride(1), v.stride(2), v.stride(3),\n batch_size, n_heads, seq_len, scale,\n DK=d_head_qk, DV=d_head_v, BK=BK, BV=BV,\n num_warps=num_warps,\n num_stages=num_stages,\n USE_INITIAL_STATE=initial_state is not None,\n STORE_FINAL_STATE=final_state is not None\n )\n o = o.sum(0)\n ctx.save_for_backward(q, k, v, beta, initial_state)\n return o, final_state\n\n @staticmethod\n @torch.utils._python_dispatch.tracing_only(contiguous)\n def backward(ctx, do, d_final_state=None):\n q, k, v, beta, initial_state = ctx.saved_tensors\n batch_size, n_heads, seq_len, d_head_qk = q.shape\n d_head_v = v.shape[-1]\n scale = d_head_qk ** -0.5\n BK, BV = triton.next_power_of_2(d_head_qk), min(triton.next_power_of_2(d_head_v), 32)\n NK, NV = triton.cdiv(d_head_qk, BK), triton.cdiv(d_head_v, BV)\n assert NK == 1, \"NK > 1 is not supported yet\"\n num_stages = 1\n num_warps = 2\n\n dq = q.new_empty(NV, batch_size, n_heads, seq_len, d_head_qk)\n dk = q.new_empty(NV, batch_size, n_heads, seq_len, d_head_qk)\n dv = q.new_empty(NK, batch_size, n_heads, seq_len, d_head_v)\n grid = (NV, NK, batch_size * n_heads)\n dbeta = q.new_empty(NV, batch_size, n_heads, seq_len)\n\n fused_recurrent_bwd_kernel[grid](\n q, k, v, beta, do, dq, dk, dv, dbeta, initial_state,\n q.stride(1), q.stride(2), q.stride(3),\n v.stride(1), v.stride(2), v.stride(3),\n batch_size, n_heads, seq_len, scale,\n DK=d_head_qk, DV=d_head_v, BK=BK, BV=BV,\n num_warps=num_warps,\n num_stages=num_stages,\n USE_INITIAL_STATE=initial_state is not None\n )\n dq = dq.sum(0)\n dk = dk.sum(0)\n dv = dv.sum(0)\n dbeta = dbeta.sum(0)\n return dq.to(q), dk.to(k), dv.to(v), dbeta.to(beta), None, None\n\n\ndef fused_recurrent_linear_attn_delta_rule(\n q: torch.Tensor,\n k: torch.Tensor,\n v: torch.Tensor,\n beta: torch.Tensor = None,\n initial_state: torch.Tensor = None,\n output_final_state: bool = False,\n normalize: bool = False\n) -> Tuple[torch.Tensor, torch.Tensor]:\n if initial_state is not None:\n initial_state = initial_state.detach()\n if beta is None:\n beta = torch.ones_like(q[..., 0])\n o, final_state = FusedRecurrentFunction.apply(q, k, v, beta, initial_state, output_final_state)\n return o, final_state\n", - "description_1": "Use triton language to implement a fused recurrent kernel function for both forward and backward passes of a linear attention mechanism. The forward function requires 20 arguments, including queries, keys, values, beta scaling factors, output tensors, and various strides and dimensions, along with several constexpr parameters. The backward kernel involves 25 arguments, additionally requiring gradient tensors. A helper class FusedRecurrentFunction wraps these kernels for automatic differentiation, with a method implementing the forward pass and another for the backward pass. The fused_recurrent_linear_attn_delta_rule function applies these kernels, managing initial states and normalization.", - "description_2": "Use triton language to design a recurrent fused attention mechanism with forward and backward triton kernel functions. Implement a PyTorch autograd-compatible interface using a custom Function class, which efficiently executes these kernels and computes gradients automatically.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.autotune(\n configs=[\n triton.Config({}, num_warps=1),\n triton.Config({}, num_warps=2),\n triton.Config({}, num_warps=4),\n triton.Config({}, num_warps=8),\n triton.Config({}, num_warps=16),\n triton.Config({}, num_warps=32),\n ],\n key=[\"BT\", \"BK\", \"BV\"],\n)\n@triton.jit\ndef fwd_prepare_wy_repr_kernel(\n k,\n v,\n beta,\n o,\n o2,\n T,\n K,\n V,\n BT: tl.constexpr,\n BK: tl.constexpr,\n BV: tl.constexpr\n):\n i_t, i_bh = tl.program_id(0), tl.program_id(1)\n\n p_k = k + i_bh * T * K + (i_t * BT + tl.arange(0, BT)[:, None]) * K + tl.arange(0, BK)[None, :]\n p_v = v + i_bh * T * V + (i_t * BT + tl.arange(0, BT)[:, None]) * V + tl.arange(0, BV)[None, :]\n p_beta = beta + i_bh * T + i_t * BT + tl.arange(0, BT)\n mask_bt = (tl.arange(0, BT) + i_t * BT) < T\n mask_bk = tl.arange(0, BK) < K\n mask_bv = tl.arange(0, BV) < V\n mask_bk = mask_bk[None, :] & mask_bt[:, None]\n mask_bv = mask_bv[None, :] & mask_bt[:, None]\n # [BT, BK]\n b_k = tl.load(p_k, mask=mask_bk, other=0)\n # [BT,]\n b_beta = tl.load(p_beta, mask=mask_bt, other=0).to(tl.float32)\n # [BT, BV]\n b_v = tl.load(p_v, mask=mask_bv, other=0)\n b_v = (b_v * b_beta[:, None]).to(b_v.dtype)\n # [BT, BK]\n b_kb = (b_k * b_beta[:, None]).to(b_k.dtype)\n # [BT, BT]\n b_A = tl.dot(b_kb, tl.trans(b_k), allow_tf32=False)\n b_A = -tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], b_A, 0)\n\n for i in range(BT):\n mask = tl.arange(0, BT) == i\n b_a = tl.sum(tl.where(mask[:, None], b_A, 0), 0)\n b_a = b_a + tl.sum(b_a[:, None] * b_A, 0) * (tl.arange(0, BT) < i)\n b_A = tl.where(mask[:, None], b_a, b_A)\n b_A += tl.arange(0, BT)[:, None] == tl.arange(0, BT)[None, :]\n b_A = b_A.to(b_k.dtype)\n b_w = tl.dot(b_A, b_kb, allow_tf32=False)\n b_u = tl.dot(b_A, b_v, allow_tf32=False)\n\n p_o = o + i_bh * T * K + (i_t * BT + tl.arange(0, BT)[:, None]) * K + tl.arange(0, BK)[None, :]\n tl.store(p_o, b_w.to(p_o.dtype.element_ty), mask=mask_bk)\n p_o2 = o2 + i_bh * T * V + (i_t * BT + tl.arange(0, BT)[:, None]) * V + tl.arange(0, BV)[None, :]\n tl.store(p_o2, b_u.to(p_o2.dtype.element_ty), mask=mask_bv)\n\n\n@triton.autotune(\n configs=[\n triton.Config({}, num_warps=1),\n triton.Config({}, num_warps=2),\n triton.Config({}, num_warps=4),\n triton.Config({}, num_warps=8),\n triton.Config({}, num_warps=16),\n triton.Config({}, num_warps=32),\n ],\n key=[\"BT\", \"BK\", \"BV\"],\n)\n@triton.jit\ndef bwd_prepare_wy_repr_kernel(\n k, v, beta,\n o, o2, do, do2,\n dk, dv, dbeta,\n NT, K, V, T,\n BT: tl.constexpr,\n BK: tl.constexpr,\n BV: tl.constexpr,\n):\n i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n p_k = k + i_bh * T * K + (i_t * BT + tl.arange(0, BT)[:, None]) * K + tl.arange(0, BK)[None, :]\n p_do = do + i_bh * T * K + (i_t * BT + tl.arange(0, BT)[:, None]) * K + tl.arange(0, BK)[None, :]\n p_do2 = do2 + i_bh * T * V + (i_t * BT + tl.arange(0, BT)[:, None]) * V + tl.arange(0, BV)[None, :]\n\n p_beta = beta + i_bh * T + i_t * BT + tl.arange(0, BT)\n mask_bt = (tl.arange(0, BT) + i_t * BT) < T\n mask_bk = (tl.arange(0, BK) < K)[None, :] & mask_bt[:, None]\n mask_bv = (tl.arange(0, BV) < V)[None, :] & mask_bt[:, None]\n b_k, b_beta = tl.load(p_k, mask=mask_bk), tl.load(p_beta, mask=mask_bt)\n\n b_beta = b_beta.to(tl.float32)\n A = tl.dot(b_k, tl.trans(b_k), allow_tf32=False) * b_beta[:, None]\n A = tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], A, 0)\n b_do = tl.load(p_do, mask=mask_bk).to(tl.float32)\n b_dv = tl.load(p_do2, mask=mask_bv).to(tl.float32)\n dA = tl.zeros([BT, BT], dtype=tl.float32)\n b_dk = tl.zeros([BT, BK], dtype=tl.float32)\n for i in range(BT-1, -1, -1):\n mask = tl.arange(0, BT) == i\n attn = tl.sum(tl.where(mask[:, None], A, 0), axis=0)\n do_ = tl.sum(tl.where(mask[:, None], b_do, 0), axis=0)\n dv_ = tl.sum(tl.where(mask[:, None], b_dv, 0), axis=0)\n b_do = b_do - attn[:, None] * do_[None, :]\n b_dv = b_dv - attn[:, None] * dv_[None, :]\n tl.debug_barrier()\n p_v = v + i_bh * T * V + (i_t * BT + tl.arange(0, BT)[:, None]) * V + tl.arange(0, BV)[None, :]\n b_v = tl.load(p_v, mask=mask_bv)\n b_dk += b_do * b_beta[:, None]\n b_dbeta = tl.sum(b_do * b_k, axis=1)\n b_dbeta += tl.sum(b_dv * b_v, axis=1)\n b_v = None\n\n p_o = o + i_bh * T * K + (i_t * BT + tl.arange(0, BT)[:, None]) * K + tl.arange(0, BK)[None, :]\n p_o2 = o2 + i_bh * T * V + (i_t * BT + tl.arange(0, BT)[:, None]) * V + tl.arange(0, BV)[None, :]\n b_o = tl.load(p_o, mask=mask_bk)\n b_o2 = tl.load(p_o2, mask=mask_bv)\n\n dA = -tl.dot(b_do.to(b_o.dtype), tl.trans(b_o), allow_tf32=False)\n dA -= tl.dot(b_dv.to(b_o2.dtype), tl.trans(b_o2).to(b_o.dtype),\n allow_tf32=False)\n dA = tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], dA, 0)\n b_dv *= b_beta[:, None]\n p_dv = dv + i_bh * T * V + (i_t * BT + tl.arange(0, BT)[:, None]) * V + tl.arange(0, BV)[None, :]\n tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), mask=mask_bv)\n\n b_dbeta += tl.sum(dA * tl.dot(b_k, tl.trans(b_k), allow_tf32=False), axis=1)\n dA = dA * b_beta[:, None]\n b_dk += tl.dot(tl.trans(dA.to(b_k.dtype)), b_k, allow_tf32=False)\n b_dk += tl.dot(dA.to(b_k.dtype), b_k, allow_tf32=False)\n p_dk = dk + i_bh * T * K + (i_t * BT + tl.arange(0, BT)[:, None]) * K + tl.arange(0, BK)[None, :]\n tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), mask=mask_bk)\n p_dbeta = dbeta + i_bh * T + i_t * BT + tl.arange(0, BT)\n tl.store(p_dbeta, b_dbeta.to(p_dbeta.dtype.element_ty), mask=mask_bt)\n\n\ndef fwd_prepare_wy_repr(k, v, beta, chunk_size):\n B, H, T, K, V = *k.shape, v.shape[-1]\n v_new = torch.empty_like(v)\n o_cumdecay = torch.empty_like(k)\n BT = chunk_size\n NT = triton.cdiv(T, BT)\n BK = triton.next_power_of_2(K)\n BV = triton.next_power_of_2(V)\n fwd_prepare_wy_repr_kernel[(NT, B*H)](\n k, v, beta, o_cumdecay, v_new,\n T, K, V, BT, BK, BV\n )\n return o_cumdecay, v_new\n\n\ndef bwd_prepare_wy_repr(k, v, beta, o_cumdecay, v_new, do, do2, chunk_size):\n b, h, l, d_k = do.shape\n d_v = v.shape[-1]\n BK = triton.next_power_of_2(d_k)\n BV = triton.next_power_of_2(d_v)\n c = chunk_size\n BK = d_k\n NT = triton.cdiv(l, c)\n dk = torch.empty_like(k)\n dv = torch.empty_like(v)\n dbeta = torch.zeros_like(beta)\n bwd_prepare_wy_repr_kernel[(NT, b*h)](\n k, v, beta,\n o_cumdecay, v_new, do, do2,\n dk, dv, dbeta,\n NT, d_k, d_v, l, chunk_size, BK, BV\n )\n return dk, dv, dbeta\n\n", - "description_1": "Use triton language to implement two kernels for forward and backward passes in a custom operator for preparing WY representation. The forward kernel `fwd_prepare_wy_repr_kernel` has 9 parameters: k, v, beta, o, o2, T, K, V, and BT (a compile-time constant). It calculates transformations using dot products and stores results back to memory. The backward kernel `bwd_prepare_wy_repr_kernel` has 16 parameters: k, v, beta, o, o2, do, do2, dk, dv, dbeta, NT, K, V, T, and three compile-time constants BT, BK, BV. It computes gradients of the inputs given the gradients of the outputs, employing triton's matrix operations. Both kernels leverage Triton's parallelization features by using program ids to distribute computations over blocks.", - "description_2": "Use triton language to write a forward kernel for computing WY representation transformations and a backward kernel for computing gradients of WY representation transformations, leveraging program ids for parallel computation distribution.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.autotune(\n configs=[\n triton.Config({}, num_warps=1),\n triton.Config({}, num_warps=2),\n triton.Config({}, num_warps=4),\n triton.Config({}, num_warps=8),\n triton.Config({}, num_warps=16),\n triton.Config({}, num_warps=32),\n ],\n key=[\"BT\", \"BK\", \"BV\"], \n)\n@triton.jit\ndef fwd_prepare_wy_repr_kernel(\n k, v, beta, w, u, A, \n s_qk_h, s_qk_t, s_qk_d, \n s_vo_h, s_vo_t, s_vo_d, \n T, K, V, \n BT: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr\n):\n i_t, i_bh = tl.program_id(0), tl.program_id(1)\n \n b_A = tl.zeros([BT, BT], dtype=tl.float32)\n p_beta = tl.make_block_ptr(beta + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,))\n b_beta = tl.load(p_beta, boundary_check=(0,))\n\n for i_k in range(tl.cdiv(K, BK)):\n p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))\n b_k = tl.load(p_k, boundary_check=(0, 1))\n b_kb = (b_k * b_beta[:, None]).to(b_k.dtype)\n b_A += tl.dot(b_kb, tl.trans(b_k), allow_tf32=False)\n\n b_A = -tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], b_A, 0)\n\n for i in range(1, BT):\n mask = tl.arange(0, BT) == i\n b_a = tl.sum(tl.where(mask[:, None], b_A, 0), 0)\n b_a = b_a + tl.sum(b_a[:, None] * b_A, 0) * (tl.arange(0, BT) < i)\n b_A = tl.where(mask[:, None], b_a, b_A)\n\n b_A += tl.arange(0, BT)[:, None] == tl.arange(0, BT)[None, :]\n\n p_A = tl.make_block_ptr(A + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))\n tl.store(p_A, (b_A).to(p_A.dtype.element_ty), boundary_check=(0, 1))\n b_A = b_A.to(k.dtype.element_ty)\n\n for i_v in range(tl.cdiv(V, BV)):\n p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))\n b_v = tl.load(p_v, boundary_check=(0, 1))\n b_vb = (b_v * b_beta[:, None]).to(b_v.dtype)\n b_u = tl.dot(b_A, b_vb, allow_tf32=False)\n p_u = tl.make_block_ptr(u + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))\n tl.store(p_u, (b_u).to(p_u.dtype.element_ty), boundary_check=(0, 1))\n\n for i_k in range(tl.cdiv(K, BK)):\n p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))\n b_k = tl.load(p_k, boundary_check=(0, 1))\n b_kb = (b_k * b_beta[:, None]).to(b_k.dtype)\n b_w = tl.dot(b_A, b_kb, allow_tf32=False)\n p_w = tl.make_block_ptr(w + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))\n tl.store(p_w, b_w.to(p_w.dtype.element_ty), boundary_check=(0, 1))\n\ndef fwd_prepare_wy_repr(k, v, beta, BT):\n B, H, T, K, V = *k.shape, v.shape[-1]\n u = torch.empty_like(v)\n w = torch.empty_like(k)\n NT = triton.cdiv(T, BT)\n BK = min(triton.next_power_of_2(K), 64)\n BV = min(triton.next_power_of_2(V), 64)\n A = torch.empty(B, H, T, BT, device=k.device, dtype=k.dtype)\n fwd_prepare_wy_repr_kernel[(NT, B*H)](\n k, v, beta, w, u, A,\n k.stride(1), k.stride(2), k.stride(3), \n v.stride(1), v.stride(2), v.stride(3),\n T, K, V, BT, BK, BV\n )\n return w, u, A\n\n@triton.autotune(\n configs=[\n triton.Config({}, num_warps=1),\n triton.Config({}, num_warps=2),\n triton.Config({}, num_warps=4),\n triton.Config({}, num_warps=8),\n triton.Config({}, num_warps=16),\n triton.Config({}, num_warps=32),\n ],\n key=[\"BT\", \"BK\", \"BV\"], \n)\n@triton.jit\ndef fwd_recompute_w_u_kernel(\n k, v, beta, w, u, A, \n s_qk_h, s_qk_t, s_qk_d, \n s_vo_h, s_vo_t, s_vo_d, \n T, K, V, \n BT: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr\n):\n i_t, i_bh = tl.program_id(0), tl.program_id(1)\n \n p_beta = tl.make_block_ptr(beta + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,))\n b_beta = tl.load(p_beta, boundary_check=(0,))\n\n p_A = tl.make_block_ptr(A + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))\n b_A = tl.load(p_A, boundary_check=(0, 1)).to(k.dtype.element_ty)\n\n for i_v in range(tl.cdiv(V, BV)):\n p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))\n b_v = tl.load(p_v, boundary_check=(0, 1))\n b_vb = (b_v * b_beta[:, None]).to(b_v.dtype)\n b_u = tl.dot(b_A, b_vb, allow_tf32=False)\n p_u = tl.make_block_ptr(u + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))\n tl.store(p_u, (b_u).to(p_u.dtype.element_ty), boundary_check=(0, 1))\n\n for i_k in range(tl.cdiv(K, BK)):\n p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))\n b_k = tl.load(p_k, boundary_check=(0, 1))\n b_kb = (b_k * b_beta[:, None]).to(b_k.dtype)\n b_w = tl.dot(b_A, b_kb, allow_tf32=False)\n p_w = tl.make_block_ptr(w + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))\n tl.store(p_w, b_w.to(p_w.dtype.element_ty), boundary_check=(0, 1))\n\ndef fwd_recompute_w_u(k, v, beta, A, BT):\n B, H, T, K, V = *k.shape, v.shape[-1]\n u = torch.empty_like(v)\n w = torch.empty_like(k)\n NT = triton.cdiv(T, BT)\n BK = min(triton.next_power_of_2(K), 64)\n BV = min(triton.next_power_of_2(V), 64)\n fwd_recompute_w_u_kernel[(NT, B*H)](\n k, v, beta, w, u, A,\n k.stride(1), k.stride(2), k.stride(3), \n v.stride(1), v.stride(2), v.stride(3),\n T, K, V, BT, BK, BV\n )\n return w, u\n\n@triton.autotune(\n configs=[\n triton.Config({}, num_warps=1),\n triton.Config({}, num_warps=2),\n triton.Config({}, num_warps=4),\n triton.Config({}, num_warps=8),\n triton.Config({}, num_warps=16),\n triton.Config({}, num_warps=32),\n ],\n key=[\"BT\", \"BK\", \"BV\"],\n)\n@triton.jit\ndef bwd_prepare_wy_repr_kernel(\n k, v, beta, A, dw, du, dk, dv, dbeta, \n s_qk_h, s_qk_t, s_qk_d, \n s_vo_h, s_vo_t, s_vo_d, \n T, K, V, \n BT: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr\n):\n i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n p_A = tl.make_block_ptr(A + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))\n b_A = tl.load(p_A, boundary_check=(0, 1)).to(k.dtype.element_ty)\n\n b_dbeta = tl.zeros([BT], dtype=tl.float32)\n b_dA = tl.zeros([BT, BT], dtype=tl.float32)\n p_beta = tl.make_block_ptr(beta + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,))\n b_beta = tl.load(p_beta, boundary_check=(0,))\n\n for i_v in range(tl.cdiv(V, BV)):\n p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))\n p_du = tl.make_block_ptr(du + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))\n b_v = tl.load(p_v, boundary_check=(0, 1))\n b_v_beta = (b_v * b_beta[:, None]).to(b_v.dtype)\n b_du = tl.load(p_du, boundary_check=(0, 1))\n b_dA += tl.dot(b_du, tl.trans(b_v_beta), allow_tf32=False)\n b_dv_beta = tl.dot(tl.trans(b_A), b_du, allow_tf32=False)\n b_dv = b_dv_beta * b_beta[:, None]\n b_dbeta += tl.sum(b_dv_beta * b_v, 1)\n # store\n p_dv = tl.make_block_ptr(dv + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))\n tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))\n\n tl.debug_barrier() \n b_A2 = tl.zeros([BT, BT], dtype=tl.float32)\n for i_k in range(tl.cdiv(K, BK)):\n p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))\n p_dw = tl.make_block_ptr(dw + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))\n b_k = tl.load(p_k, boundary_check=(0, 1)) \n b_k_beta = (b_k * b_beta[:, None]).to(b_k.dtype)\n b_dw = tl.load(p_dw, boundary_check=(0, 1))\n b_dA += tl.dot(b_dw, tl.trans(b_k_beta), allow_tf32=False) \n b_A2 += tl.dot(b_k_beta, tl.trans(b_k), allow_tf32=False)\n b_dk_beta = tl.dot(tl.trans(b_A), b_dw, allow_tf32=False)\n b_dk = b_dk_beta * b_beta[:, None]\n b_dbeta += tl.sum(b_dk_beta * b_k, 1)\n # store \n p_dk = tl.make_block_ptr(dk + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))\n tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))\n\n b_A -= (tl.arange(0, BT)[:, None] == tl.arange(0, BT)[None, :])\n b_A2 = tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], -b_A2, 0)\n b_dA = tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], b_dA, 0)\n tl.debug_barrier()\n\n for i in range(BT-1, 0, -1):\n mask = tl.arange(0, BT) == i\n b_da = tl.sum(tl.where(mask[:, None], b_dA, 0), 0) \n b_a = tl.sum(tl.where(mask[:, None], b_A2, 0), 0) \n b_da2 = b_da + tl.sum(b_da[None, :] * b_A, 1) \n b_dA = tl.where(mask[:, None], b_da2, b_dA)\n b_dA += b_da[None, :] * b_a[:, None]\n\n b_dA = tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], -b_dA, 0).to(k.dtype.element_ty)\n tl.debug_barrier()\n\n for i_k in range(tl.cdiv(K, BK)):\n p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))\n p_dk = tl.make_block_ptr(dk + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))\n b_k = tl.load(p_k, boundary_check=(0, 1)) \n b_dk = tl.load(p_dk, boundary_check=(0, 1))\n b_k_beta = (b_k * b_beta[:, None]).to(b_k.dtype)\n\n b_dk_beta = tl.dot(b_dA, b_k, allow_tf32=False)\n b_dbeta += tl.sum(b_dk_beta * b_k, 1)\n b_dk += tl.dot(tl.trans(b_dA), b_k_beta, allow_tf32=False) \n b_dk += b_dk_beta * b_beta[:, None] \n tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))\n \n p_dbeta = tl.make_block_ptr(dbeta + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,))\n tl.store(p_dbeta, b_dbeta.to(p_dbeta.dtype.element_ty),boundary_check=(0, ))\n\ndef bwd_prepare_wy_repr(k, v, beta, A, dw, du, BT):\n B, H, T, K, V = *k.shape, v.shape[-1]\n\n NT = triton.cdiv(T, BT)\n BK = min(triton.next_power_of_2(K), 64)\n BV = min(triton.next_power_of_2(V), 64)\n NT = triton.cdiv(T, BT)\n dk = torch.empty_like(k)\n dv = torch.empty_like(v).contiguous()\n dbeta = torch.zeros_like(beta)\n\n bwd_prepare_wy_repr_kernel[(NT, B*H)](\n k, v, beta, A,\n dw, du, \n dk, dv, dbeta,\n k.stride(1), k.stride(2), k.stride(3), \n v.stride(1), v.stride(2), v.stride(3),\n T, K, V, BT, BK, BV\n )\n return dk, dv, dbeta\n", - "description_1": "Use triton language to implement forward and backward operations for preparing the WY representation of Householder matrices. The code includes three kernels: fwd_prepare_wy_repr_kernel, fwd_recompute_w_u_kernel, and bwd_prepare_wy_repr_kernel. Each kernel computes or uses block matrix operations, dot products, and transformations over block pointers with boundary checks. The kernels are decorated with @triton.jit, allowing JIT compilation, and @triton.autotune to optimize performance across different configurations. The forward functions (fwd_prepare_wy_repr and fwd_recompute_w_u) and backward function (bwd_prepare_wy_repr) encapsulate the kernel calls for performing tensor operations on the GPU with parameters controlling the chunk/block sizes. ", - "description_2": "Use triton language to implement kernels for WY representation preparation with forward and backward computation, utilizing block matrix operations and dot products.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef chunk_gla_fwd_kernel_cum(\n s,\n o,\n s_s_h,\n s_s_t,\n s_s_d,\n T: tl.constexpr,\n S: tl.constexpr,\n BT: tl.constexpr,\n BS: tl.constexpr\n):\n i_s, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n o_i = tl.arange(0, BT)\n m_s = tl.where(o_i[:, None] >= o_i[None, :], 1., 0.)\n p_s = tl.make_block_ptr(s + i_bh * s_s_h, (T, S), (s_s_t, s_s_d), (i_t * BT, i_s * BS), (BT, BS), (1, 0))\n p_o = tl.make_block_ptr(o + i_bh * s_s_h, (T, S), (s_s_t, s_s_d), (i_t * BT, i_s * BS), (BT, BS), (1, 0))\n b_s = tl.load(p_s, boundary_check=(0, 1)).to(tl.float32)\n b_o = tl.dot(m_s, b_s, allow_tf32=False)\n tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))\n\n@triton.jit\ndef chunk_gla_fwd_kernel_h(\n k,\n v,\n g,\n h,\n h0,\n ht,\n s_k_h,\n s_k_t,\n s_k_d,\n s_v_h,\n s_v_t,\n s_v_d,\n s_h_h,\n s_h_t,\n s_h_d,\n T: tl.constexpr,\n K: tl.constexpr,\n V: tl.constexpr,\n BT: tl.constexpr,\n BK: tl.constexpr,\n BV: tl.constexpr,\n NT: tl.constexpr,\n USE_INITIAL_STATE: tl.constexpr,\n STORE_FINAL_STATE: tl.constexpr\n):\n i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n b_h = tl.zeros([BK, BV], dtype=tl.float32)\n if USE_INITIAL_STATE:\n p_h = tl.make_block_ptr(h0 + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))\n b_h += tl.load(p_h, boundary_check=(0, 1)).to(tl.float32)\n for i_t in range(NT):\n p_k = tl.make_block_ptr(k + i_bh * s_k_h, (K, T), (s_k_d, s_k_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1))\n p_v = tl.make_block_ptr(v + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))\n p_h = tl.make_block_ptr(h + i_bh * s_h_h + i_t * K * V, (K, V), (s_h_t, s_h_d), (i_k * BK, i_v * BV), (BK, BV), (1, 0))\n p_g = tl.make_block_ptr(g + i_bh * s_k_h, (K, T), (s_k_d, s_k_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1))\n p_gn = tl.make_block_ptr(g + i_bh * s_k_h, (T * K,), (s_k_d,), ((i_t * BT + BT - 1) * K + i_k * BK,), (BK,), (0,))\n tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1))\n b_k = tl.load(p_k, boundary_check=(0, 1))\n b_v = tl.load(p_v, boundary_check=(0, 1))\n b_g = tl.load(p_g, boundary_check=(0, 1))\n if i_t < NT - 1:\n b_gn = tl.load(p_gn, boundary_check=(0,))\n else:\n b_gn = tl.min(b_g, axis=1)\n b_h *= tl.exp(b_gn)[:, None]\n b_k = (b_k * tl.exp(b_gn[:, None] - b_g)).to(b_k.dtype)\n b_h += tl.dot(b_k, b_v, allow_tf32=False)\n if STORE_FINAL_STATE:\n p_h = tl.make_block_ptr(ht + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))\n tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1))\n\n@triton.jit\ndef chunk_gla_fwd_kernel_intra(\n q,\n k,\n g,\n A,\n s_k_h,\n s_k_t,\n s_k_d,\n scale,\n T: tl.constexpr,\n K: tl.constexpr,\n BT: tl.constexpr,\n BC: tl.constexpr,\n BK: tl.constexpr,\n NC: tl.constexpr\n):\n i_k, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n i_t, i_i, i_j = i_c // (NC * NC), (i_c % (NC * NC)) // NC, (i_c % (NC * NC)) % NC\n n_bh = tl.num_programs(2)\n if i_i > i_j:\n p_q = tl.make_block_ptr(q + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))\n p_g = tl.make_block_ptr(g + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))\n p_k = tl.make_block_ptr(k + i_bh * s_k_h, (K, T), (s_k_d, s_k_t), (i_k * BK, i_t * BT + i_j * BC), (BK, BC), (0, 1))\n p_gk = tl.make_block_ptr(g + i_bh * s_k_h, (K, T), (s_k_d, s_k_t), (i_k * BK, i_t * BT + i_j * BC), (BK, BC), (0, 1))\n p_gn = tl.make_block_ptr(g + i_bh * s_k_h, (T * K,), (s_k_d,), ((i_t * BT + i_i * BC) * K + i_k * BK,), (BK,), (0,))\n p_A = tl.make_block_ptr(A + (i_k*n_bh+i_bh)*T*BT, (T, BT), (BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0))\n b_gn = tl.load(p_gn, boundary_check=(0,))\n b_q = tl.load(p_q, boundary_check=(0, 1))\n b_g = tl.load(p_g, boundary_check=(0, 1))\n b_qg = (b_q * tl.exp(b_g - b_gn[None, :]) * scale).to(b_q.dtype)\n b_k = tl.load(p_k, boundary_check=(0, 1))\n b_gk = tl.load(p_gk, boundary_check=(0, 1))\n b_kg = (b_k * tl.exp(b_gn[:, None] - b_gk)).to(b_k.dtype)\n b_A = tl.dot(b_qg, b_kg, allow_tf32=False)\n tl.store(p_A, b_A.to(A.dtype.element_ty), boundary_check=(0, 1))\n elif i_i == i_j:\n p_q = tl.make_block_ptr(q + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))\n p_g = tl.make_block_ptr(g + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))\n p_k = tl.make_block_ptr(k + i_bh * s_k_h, (T * K,), (s_k_d,), ((i_t * BT + i_j * BC) * K + i_k * BK,), (BK,), (0,))\n p_gk = tl.make_block_ptr(g + i_bh * s_k_h, (T * K,), (s_k_d,), ((i_t * BT + i_j * BC) * K + i_k * BK,), (BK,), (0,))\n b_q = tl.load(p_q, boundary_check=(0, 1))\n b_g = tl.load(p_g, boundary_check=(0, 1))\n o_i = tl.arange(0, BC)\n o_A = (i_bh + i_k * n_bh) * T * BT + (i_t * BT + i_i * BC + tl.arange(0, BC)) * BT + i_j * BC\n m_A = (i_t * BT + i_i * BC + tl.arange(0, BC)) < T\n for j in range(0, BC):\n b_k = tl.load(p_k, boundary_check=(0,)).to(tl.float32)\n b_gk = tl.load(p_gk, boundary_check=(0,)).to(tl.float32)\n b_A = tl.sum(b_q * b_k[None, :] * tl.exp(b_g - b_gk[None, :]) * scale, 1)\n b_A = tl.where(o_i >= j, b_A, 0.)\n tl.store(A + o_A + j, b_A.to(b_q.dtype), mask=m_A)\n p_k = tl.advance(p_k, (K,))\n p_gk = tl.advance(p_gk, (K,))\n\n@triton.jit\ndef chunk_gla_fwd_kernel_inter(\n q,\n v,\n g,\n h,\n o,\n A,\n s_k_h,\n s_k_t,\n s_k_d,\n s_v_h,\n s_v_t,\n s_v_d,\n s_h_h,\n s_h_t,\n s_h_d,\n scale,\n T: tl.constexpr,\n K: tl.constexpr,\n V: tl.constexpr,\n BT: tl.constexpr,\n BK: tl.constexpr,\n BV: tl.constexpr\n):\n i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n b_o = tl.zeros([BT, BV], dtype=tl.float32)\n for i_k in range(tl.cdiv(K, BK)):\n p_q = tl.make_block_ptr(q + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))\n p_g = tl.make_block_ptr(g + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))\n p_h = tl.make_block_ptr(h + i_bh * s_h_h + i_t * K * V, (K, V), (s_h_t, s_h_d), (i_k * BK, i_v * BV), (BK, BV), (1, 0))\n b_q = tl.load(p_q, boundary_check=(0, 1))\n b_q = (b_q * scale).to(b_q.dtype)\n b_g = tl.load(p_g, boundary_check=(0, 1))\n b_qg = (b_q * tl.exp(b_g)).to(b_q.dtype)\n b_h = tl.load(p_h, boundary_check=(0, 1))\n if i_k >= 0:\n b_o += tl.dot(b_qg, b_h, allow_tf32=False)\n p_v = tl.make_block_ptr(v + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))\n p_o = tl.make_block_ptr(o + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))\n p_A = tl.make_block_ptr(A + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))\n b_v = tl.load(p_v, boundary_check=(0, 1))\n b_A = tl.load(p_A, boundary_check=(0, 1))\n b_o += tl.dot(b_A, b_v, allow_tf32=False)\n tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))\n\ndef chunk_gla(\n q: torch.Tensor,\n k: torch.Tensor,\n v: torch.Tensor,\n g: torch.Tensor,\n scale: Optional[int] = None,\n initial_state: torch.Tensor = None,\n output_final_state: bool = False,\n checkpoint_level: Optional[int] = 2\n) -> Tuple[torch.Tensor, torch.Tensor]:\n B, H, T, K, V = *q.shape, v.shape[-1]\n BT, BC = 64, 16\n BK = min(64, triton.next_power_of_2(K))\n BV = min(64, triton.next_power_of_2(V))\n NT, NC = triton.cdiv(T, BT), triton.cdiv(BT, BC)\n NK = triton.cdiv(K, BK)\n NV = triton.cdiv(V, BV)\n num_warps = 4 if BK == 64 else 2\n num_stages = 1\n\n def fwd_inner(q, k, v, g, B, H, T, K, V, BT, BK, BV, NT, h0=None, ht=None):\n NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)\n h = q.new_empty(B, H, NT * K, V)\n grid = (NV, NK, B * H)\n chunk_gla_fwd_kernel_h[grid](\n k, v, g, h, h0, ht,\n k.stride(1), k.stride(2), k.stride(3),\n v.stride(1), v.stride(2), v.stride(3),\n h.stride(1), h.stride(2), h.stride(3),\n T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT,\n USE_INITIAL_STATE=h0 is not None,\n STORE_FINAL_STATE=ht is not None,\n num_warps=num_warps,\n num_stages=num_stages\n )\n return h\n\n final_state = None\n if output_final_state:\n final_state = q.new_empty(B, H, K, V, dtype=torch.float)\n\n g_org, g = g, torch.empty_like(g, dtype=torch.float)\n def grid(meta): return ((triton.cdiv(meta['S'], meta['BS']), NT, B * H))\n chunk_gla_fwd_kernel_cum[grid](\n g_org, g,\n g.stride(1), g.stride(2), g.stride(3),\n T=T, S=K, BT=BT\n )\n h = fwd_inner(\n q=q, k=k, v=v, g=g,\n B=B, H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT,\n h0=initial_state if initial_state is not None else None,\n ht=final_state if final_state is not None else None\n )\n A = q.new_zeros(NK, B, H, T, BT)\n grid = (NK, NT * NC * NC, B * H)\n chunk_gla_fwd_kernel_intra[grid](\n q, k, g, A,\n k.stride(1), k.stride(2), k.stride(3),\n scale,\n T=T, K=K, BT=BT, BC=BC, BK=BK, NC=NC,\n num_warps=num_warps,\n num_stages=num_stages\n )\n A = A.sum(0, dtype=A.dtype)\n o = torch.empty_like(v)\n grid = (NV, NT, B * H)\n chunk_gla_fwd_kernel_inter[grid](\n q, v, g, h, o, A,\n k.stride(1), k.stride(2), k.stride(3),\n v.stride(1), v.stride(2), v.stride(3),\n h.stride(1), h.stride(2), h.stride(3),\n scale,\n T=T, K=K, V=V, BT=BT, BK=BK, BV=BV,\n num_warps=num_warps,\n num_stages=num_stages\n )\n if checkpoint_level >= 1:\n del g\n g = g_org\n if checkpoint_level > 1:\n del h\n h, initial_state = None, None\n\n return o, final_state\n", - "description_1": "Use triton language to implement a series of kernels and functions for processing tensors q, k, v, and g, with functions including cumulative sum, hidden state forward pass, and intra/inter processing, each with multiple parameters for tensor strides, shapes, scales, and conditional flags for storing states.", - "description_2": "Use triton language to define kernels and orchestrate their execution for optimized tensor operations with specific grid configurations and parameters.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\nfrom einops import rearrange\nfrom packaging import version\n\n@triton.jit\ndef fused_chunk_gla_fwd_kernel(\n q, k, v, g, o, initial_state, final_state,\n s_qk_h, s_qk_t, s_qk_d, s_vo_h, s_vo_t, s_vo_d,\n B, H, T, scale,\n BT: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr, DK: tl.constexpr, DV: tl.constexpr,\n USE_INITIAL_STATE: tl.constexpr, STORE_FINAL_STATE: tl.constexpr, CHECK: tl.constexpr\n):\n i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n b_h = tl.zeros([BK, BV], dtype=tl.float32)\n p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (0, i_k * BK), (BT, BK), (1, 0))\n p_db = g + i_bh * s_qk_h + (BT - 1) * s_qk_t + i_k * BK + tl.arange(0, BK)\n p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (DK, T), (s_qk_d, s_qk_t), (i_k * BK, 0), (BK, BT), (0, 1))\n p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (0, i_v * BV), (BT, BV), (1, 0))\n p_o = tl.make_block_ptr(o + (i_bh + i_k * B * H) * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (0, i_v * BV), (BT, BV), (1, 0))\n\n if USE_INITIAL_STATE:\n p_h = tl.make_block_ptr(initial_state + i_bh * DK * DV, (DK, DV), (DV, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))\n b_h += tl.load(p_h, boundary_check=(0, 1)).to(tl.float32)\n\n mask = (i_k * BK + tl.arange(0, BK)) < DK\n\n for i in range(0, tl.cdiv(T, BT)):\n b_k = tl.load(p_k, boundary_check=(0, 1))\n b_o = tl.zeros([BT, BV], dtype=tl.float32)\n b_v = tl.load(p_v, boundary_check=(0, 1))\n b_q = tl.load(p_q, boundary_check=(0, 1))\n d_b = tl.load(p_db, mask=mask, other=0).to(tl.float32)\n if CHECK and i == 0:\n b_o = tl.dot(b_q.to(b_v.dtype), b_h.to(b_v.dtype), allow_tf32=False)\n b_h = b_h * tl.math.exp2(d_b)[:, None] + tl.dot(b_k.to(b_v.dtype), b_v, allow_tf32=False)\n else:\n b_o = tl.dot(b_q.to(b_v.dtype), b_h.to(b_v.dtype), allow_tf32=False)\n b_h = b_h * tl.math.exp2(d_b)[:, None] + tl.dot(b_k.to(b_v.dtype), b_v, allow_tf32=False)\n\n tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))\n p_q = tl.advance(p_q, (BT, 0))\n p_k = tl.advance(p_k, (0, BT))\n p_v = tl.advance(p_v, (BT, 0))\n p_o = tl.advance(p_o, (BT, 0))\n p_db += BT * DK\n\n if STORE_FINAL_STATE:\n p_final = tl.make_block_ptr(final_state + i_bh * DK * DV, (DK, DV), (DV, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))\n tl.store(p_final, b_h.to(p_final.dtype.element_ty), boundary_check=(0, 1))\n\n@triton.jit\ndef fused_chunk_gla_bwd_kernel(\n q, k, v, g, do, dq, dk, dv, initial_state,\n s_qk_h, s_qk_t, s_qk_d, s_vo_h, s_vo_t, s_vo_d,\n B, H, T, scale,\n BT: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr, DK: tl.constexpr, DV: tl.constexpr,\n USE_INITIAL_STATE: tl.constexpr, CHECK: tl.constexpr\n):\n i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n b_h = tl.zeros([BV, BK], dtype=tl.float32)\n\n if USE_INITIAL_STATE:\n p_h = tl.make_block_ptr(initial_state + i_bh * DK * DV, (DV, DK), (1, DV), (i_v * BV, i_k * BK), (BV, BK), (0, 1))\n b_h += tl.load(p_h, boundary_check=(0, 1)).to(tl.float32)\n\n mask = (i_k * BK + tl.arange(0, BK)) < DK \n for i in range(0, tl.cdiv(T, BT)):\n p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (i * BT, i_k * BK), (BT, BK), (1, 0))\n p_db = g + i_bh * s_qk_h + ((i+1) * BT - 1) * s_qk_t + i_k * BK + tl.arange(0, BK)\n p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (DV, T), (s_vo_d, s_vo_t), (i_v * BV, i * BT), (BV, BT), (0, 1))\n p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (i * BT, i_v * BV), (BT, BV), (1, 0))\n p_dq = tl.make_block_ptr(dq + (i_bh+i_v*B*H)*s_qk_h, (T, DK), (s_qk_t, s_qk_d), (i * BT, i_k * BK), (BT, BK), (1, 0))\n b_dq = tl.zeros([BT, BK], dtype=tl.float32)\n b_k = tl.load(p_k, boundary_check=(0, 1))\n d_b = tl.load(p_db, mask=mask, other=0).to(tl.float32)\n\n b_v = tl.load(p_v, boundary_check=(0, 1))\n b_do = tl.load(p_do, boundary_check=(0, 1))\n if CHECK and i == 0:\n b_dq += tl.dot(b_do, b_h.to(b_do.dtype), allow_tf32=False)\n b_h = b_h * tl.math.exp2(d_b)[None, :] + tl.dot(b_v, b_k.to(b_v.dtype), allow_tf32=False)\n else:\n b_dq += tl.dot(b_do, b_h.to(b_do.dtype), allow_tf32=False)\n b_h = b_h * tl.math.exp2(d_b)[None, :] + tl.dot(b_v, b_k.to(b_v.dtype), allow_tf32=False)\n b_dq *= scale\n tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))\n\n b_h = None\n tl.debug_barrier()\n b_dh = tl.zeros([BK, BV], dtype=tl.float32)\n\n for i in range(1, tl.cdiv(T, BT) + 1):\n p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (DK, T), (s_qk_d, s_qk_t), (i_k * BK, T - i * BT), (BK, BT), (0, 1))\n p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (T - i * BT, i_k * BK), (BT, BK), (1, 0))\n p_db = g + i_bh * s_qk_h + (T - (i-1) * BT - 1) * s_qk_t + i_k * BK + tl.arange(0, BK)\n p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (T - i * BT, i_v * BV), (BT, BV), (1, 0))\n p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (T - i * BT, i_v * BV), (BT, BV), (1, 0))\n p_dk = tl.make_block_ptr(dk + (i_bh + i_v * B * H) * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (T - i * BT, i_k * BK), (BT, BK), (1, 0))\n p_dv = tl.make_block_ptr(dv + (i_bh + i_k * B * H) * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (T - i * BT, i_v * BV), (BT, BV), (1, 0))\n b_q = tl.load(p_q, boundary_check=(0, 1))\n b_k = tl.load(p_k, boundary_check=(0, 1))\n b_v = tl.load(p_v, boundary_check=(0, 1))\n b_do = tl.load(p_do, boundary_check=(0, 1))\n b_db = tl.load(p_db, mask=mask, other=0).to(tl.float32)\n\n if CHECK and i == 1:\n b_dk = tl.trans(tl.dot(b_dh.to(b_v.dtype), tl.trans(b_v), allow_tf32=False))\n b_dv = tl.dot((b_k).to(b_v.dtype), b_dh.to(b_v.dtype), allow_tf32=False)\n b_dh = b_dh * tl.math.exp2(b_db)[:, None] + tl.dot(b_q.to(b_do.dtype), b_do, allow_tf32=False)\n else:\n b_dk = tl.trans(tl.dot(b_dh.to(b_v.dtype), tl.trans(b_v), allow_tf32=False))\n b_dv = tl.dot((b_k).to(b_v.dtype), b_dh.to(b_v.dtype), allow_tf32=False)\n b_dh = b_dh * tl.math.exp2(b_db)[:, None] + tl.dot(b_q.to(b_do.dtype), b_do, allow_tf32=False)\n\n tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))\n tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))\n\nclass FusedChunkGLAFunction(torch.autograd.Function):\n\n @staticmethod\n def forward(ctx, q, k, v, g, scale, initial_state, output_final_state):\n ctx.g_dtype = g.dtype\n g_original = g\n g = torch.empty_like(g, dtype=torch.float32)\n batch_size, n_heads, seq_len, d_head_qk = q.shape\n d_head_v = v.shape[-1]\n ctx.scale = scale\n\n BT = 16\n BK, BV = min(d_head_qk, 64), min(d_head_v, 64)\n num_stages = 1\n num_warps = 2\n\n NK, NV = triton.cdiv(d_head_qk, BK), triton.cdiv(d_head_v, BV)\n o = q.new_empty(NK, batch_size, n_heads, seq_len, d_head_v)\n q_g = torch.empty_like(q)\n k_g = torch.empty_like(k)\n grid = (NK, triton.cdiv(seq_len, BT), batch_size * n_heads)\n\n fwd_decay_cumsum[grid](\n g_original,\n g,\n q.stride(1), q.stride(2), q.stride(3),\n batch_size, n_heads, seq_len, scale,\n BT=BT, BK=BK, DK=d_head_qk, num_warps=1\n )\n prepare_qg_kg[grid](\n q, k, g, q_g, k_g,\n q.stride(1), q.stride(2), q.stride(3),\n batch_size, n_heads, seq_len, scale,\n BT=BT, BK=BK, DK=d_head_qk, num_warps=1\n )\n\n if output_final_state:\n final_state = q.new_empty(batch_size, n_heads, d_head_qk, d_head_v, dtype=torch.float, requires_grad=False)\n else:\n final_state = None\n CHECK = True\n if version.parse(triton.__version__) < version.parse('2.2.0'):\n import warnings\n warnings.warn(\n \"Triton<2.2.0 detected for running this kernel, \"\n \"which is known to have some weird compiler issues (refer to https://github.com/openai/triton/issues/2852) \"\n \"that lead to significant precision loss. \"\n \"We've add some initial condition checks to resolve this, sadly at the sacrifice of the speed. \"\n \"For optimal performance, it is recommended to install Triton>=2.2.0 (if possible).\"\n )\n CHECK = True\n\n grid = (NV, NK, batch_size * n_heads)\n fused_chunk_gla_fwd_kernel[grid](\n q_g, k_g, v, g, o, initial_state, final_state,\n q.stride(1), q.stride(2), q.stride(3),\n v.stride(1), v.stride(2), v.stride(3),\n batch_size, n_heads, seq_len, scale,\n BT=BT, DK=d_head_qk, DV=d_head_v, BK=BK, BV=BV,\n USE_INITIAL_STATE=initial_state is not None,\n STORE_FINAL_STATE=output_final_state,\n CHECK=CHECK,\n num_warps=num_warps,\n num_stages=num_stages\n )\n\n o = o.sum(0)\n\n chunk_size = 16\n num_chunk = seq_len // chunk_size\n v2 = rearrange(v, 'b h (n c) d -> b h n c d', n=num_chunk)\n BK = min(d_head_qk, 64)\n NK = triton.cdiv(d_head_qk, BK)\n A = q.new_empty(NK, batch_size, n_heads, triton.cdiv(seq_len, BT), BT, BT)\n grid = (NK, triton.cdiv(seq_len, BT), batch_size * n_heads)\n fwd_inner_chunk[grid](\n q, k, g, A,\n q.stride(1), q.stride(2), q.stride(3),\n batch_size, n_heads, seq_len, scale, BT=BT, BK=BK, DK=d_head_qk, num_stages=3,\n num_warps=4\n )\n A = A.sum(0)\n o2 = A @ v2\n o2 = rearrange(o2, 'b h n c d -> b h (n c) d')\n o.add_(o2)\n ctx.save_for_backward(q, k, v, g_original, A, initial_state)\n ctx.CHECK = CHECK\n return o.to(v), final_state\n\n @staticmethod\n def backward(ctx, do, d_final_state=None):\n q, k, v, g_origin, A, initial_state = ctx.saved_tensors\n batch_size, n_heads, seq_len, d_head_qk = q.shape\n d_head_v = v.shape[-1]\n scale = ctx.scale\n\n BT = 16\n g = torch.empty_like(g_origin, dtype=torch.float32)\n BK, BV = min(d_head_qk, 64), min(d_head_v, 64)\n NK, NV = triton.cdiv(d_head_qk, BK), triton.cdiv(d_head_v, BV)\n q_g = torch.empty_like(q)\n k_g = torch.empty_like(k)\n grid = (NK, triton.cdiv(seq_len, BT), batch_size * n_heads)\n fwd_decay_cumsum[grid](\n g_origin,\n g,\n q.stride(1), q.stride(2), q.stride(3),\n batch_size, n_heads, seq_len, scale,\n BT=BT, BK=BK, DK=d_head_qk, num_warps=1\n )\n prepare_qg_kg[grid](\n q, k, g, q_g, k_g,\n q.stride(1), q.stride(2), q.stride(3),\n batch_size, n_heads, seq_len, scale,\n BT=BT, BK=BK, DK=d_head_qk, num_warps=1\n )\n\n BT = 16\n BK, BV = min(triton.next_power_of_2(d_head_qk), 64), min(triton.next_power_of_2(d_head_v), 64)\n NK, NV = triton.cdiv(d_head_qk, BK), triton.cdiv(d_head_v, BV)\n num_stages = 1\n num_warps = 2\n dq = q.new_empty(NV, batch_size, n_heads, seq_len, d_head_qk)\n dk = q.new_empty(NV, batch_size, n_heads, seq_len, d_head_qk)\n dv = q.new_empty(NK, batch_size, n_heads, seq_len, d_head_v)\n\n grid = (NV, NK, batch_size * n_heads)\n\n fused_chunk_gla_bwd_kernel[grid](\n q_g, k_g, v, g, do, dq, dk, dv, initial_state,\n q.stride(1), q.stride(2), q.stride(3),\n v.stride(1), v.stride(2), v.stride(3),\n batch_size, n_heads, seq_len, scale,\n BT=BT, DK=d_head_qk, DV=d_head_v, BK=BK, BV=BV,\n USE_INITIAL_STATE=initial_state is not None,\n CHECK=ctx.CHECK,\n num_warps=num_warps,\n num_stages=num_stages,\n )\n dq = dq.sum(0)\n dk = dk.sum(0)\n dv = dv.sum(0)\n\n num_chunk = seq_len // BT\n v2 = rearrange(v, 'b h (n c) d -> b h n c d', n=num_chunk)\n do2 = rearrange(do, 'b h (n c) d -> b h n c d', n=num_chunk)\n dA2 = (do2 @ v2.transpose(-2, -1)) * scale\n dv2 = A.transpose(-1, -2) @ do2\n dv2 = rearrange(dv2, 'b h n c d -> b h (n c) d', n=num_chunk)\n\n BK = min(triton.next_power_of_2(d_head_qk), 16)\n NK = triton.cdiv(d_head_qk, BK)\n dk2 = torch.empty_like(k)\n dq2 = torch.empty_like(q)\n\n grid = (NK, triton.cdiv(seq_len, BT), batch_size * n_heads)\n bwd_inner_chunk[grid](\n q, k, g,\n dA2, dq2, dk2,\n q.stride(1), q.stride(2), q.stride(3),\n batch_size, n_heads, seq_len, scale,\n BT=BT, DK=d_head_qk, BK=BK,\n num_warps=1,\n num_stages=3\n )\n\n BK = min(triton.next_power_of_2(d_head_qk), 32)\n NK = triton.cdiv(d_head_qk, BK)\n dg = torch.empty_like(g, dtype=torch.float32)\n grid = (NK, triton.cdiv(seq_len, BT), batch_size * n_heads)\n bwd_decay_global_cumsum[grid](\n dq2, dq, dk2, dk, q, k, g, dg,\n q.stride(1), q.stride(2), q.stride(3),\n batch_size, n_heads, seq_len, scale,\n BT=BT, DK=d_head_qk, BK=BK,\n num_warps=1,\n num_stages=1\n )\n dg = rearrange(dg, 'b h (n c) d -> b h n c d', c=BT)\n\n def rev_cumsum_exclusive(x):\n cumsum_x = x.cumsum(-2)\n rev_cumsum_x = cumsum_x[..., -1, None, :] - cumsum_x\n return rev_cumsum_x\n\n rev_cumsum_dg = rev_cumsum_exclusive(dg[..., 0, :])\n dg.add_(rev_cumsum_dg.unsqueeze(-2))\n dv.add_(dv2)\n dg = rearrange(dg, 'b h n c d -> b h (n c) d')\n\n return dq.to(q), dk.to(k), dv.to(v), dg.to(ctx.g_dtype), None, None, None\n\ndef fused_chunk_gla(\n q: torch.Tensor,\n k: torch.Tensor,\n v: torch.Tensor,\n g: torch.Tensor,\n scale: int = -1,\n initial_state: torch.Tensor = None,\n output_final_state: bool = False\n) -> Tuple[torch.Tensor, torch.Tensor]:\n if scale == -1:\n scale = q.shape[-1] ** -0.5\n if initial_state is not None:\n initial_state = initial_state.detach()\n seq_len = q.shape[-2]\n q, k, v, g = map(lambda x: pad(x), [q, k, v, g])\n o, final_state = FusedChunkGLAFunction.apply(\n q, k, v, g, scale, initial_state, output_final_state)\n o = o[..., :seq_len, :]\n return o, final_state\n", - "description_1": "Use triton language to implement a fused chunked generalized linear attention kernel and its corresponding backward pass. The forward kernel processes tensors q, k, v, and g with dimensions corresponding to batch size, number of heads, sequence length, and head dimensions. It takes into account initial and final states and uses a block-wise approach along the sequence, key, and value dimensions. The backward kernel computes the gradients of q, k, v, and the cumulative sum of g, using the same block-wise approach.", - "description_2": "Use triton language to develop a fused attention mechanism with backward support, using block pointers and handling batch size, head counts, and sequence lengths in a kernel function.", - "difficulty": 4 - }, - { - "code": "import triton\nimport triton.language as tl\n\ninv_ln2 = 1.44269504\n\n@triton.jit\ndef fwd_decay_cumsum(\n g,\n g_o, \n s_qk_h,\n s_qk_t,\n s_qk_d,\n B,\n H,\n T,\n scale,\n BT: tl.constexpr,\n BK: tl.constexpr,\n DK: tl.constexpr\n):\n # Triton kernel for forward decay cumulative sum\n i_k, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n p_g = g + i_bh * s_qk_h + i_c * BT * DK + i_k * BK + tl.arange(0, BK)\n p_go = g_o + i_bh * s_qk_h + i_c * BT * DK + i_k * BK + tl.arange(0, BK)\n cum_decay = tl.zeros([BK], dtype=tl.float32)\n mask = (i_k * BK + tl.arange(0, BK)) < DK\n\n for i in range(BT):\n _g = tl.load(p_g, mask=mask, other=0).to(tl.float32)\n cum_decay += _g * inv_ln2\n tl.store(p_go, cum_decay.to(p_go.dtype.element_ty), mask=mask)\n p_g += DK\n p_go += DK\n\n@triton.jit\ndef prepare_qg_kg(\n q,\n k,\n g,\n qg,\n kg,\n s_qk_h,\n s_qk_t,\n s_qk_d,\n B,\n H,\n T,\n scale,\n BT: tl.constexpr,\n BK: tl.constexpr,\n DK: tl.constexpr\n):\n # Triton kernel for preparing qg and kg\n i_k, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n p_q = q + i_bh * s_qk_h + i_c * BT * DK + i_k * BK + tl.arange(0, BK)\n p_g = g + i_bh * s_qk_h + i_c * BT * DK + i_k * BK + tl.arange(0, BK)\n p_k = k + i_bh * s_qk_h + i_c * BT * DK + i_k * BK + tl.arange(0, BK)\n p_qg = qg + i_bh * s_qk_h + i_c * BT * DK + i_k * BK + tl.arange(0, BK)\n p_kg = kg + i_bh * s_qk_h + i_c * BT * DK + i_k * BK + tl.arange(0, BK)\n \n mask = (i_k * BK + tl.arange(0, BK)) < DK\n\n last_decay = tl.load(g + i_bh * s_qk_h + (i_c * BT + BT - 1) * DK + i_k * BK + tl.arange(0, BK))\n\n for i in range(BT):\n _q = tl.load(p_q, mask=mask, other=0)\n _k = tl.load(p_k, mask=mask, other=0)\n _g = tl.load(p_g, mask=mask, other=0).to(tl.float32)\n _q *= tl.math.exp2(_g) * scale\n _k *= tl.math.exp2(last_decay - _g)\n tl.store(p_kg, _k.to(p_kg.dtype.element_ty), mask=mask)\n tl.store(p_qg, _q.to(p_qg.dtype.element_ty), mask=mask)\n p_q += DK\n p_g += DK\n p_k += DK\n p_kg += DK\n p_qg += DK\n\n\n@triton.jit\ndef bwd_decay_global_cumsum(\n dq_inner,\n dq_inter,\n dk_inner,\n dk_inter,\n q, k, g, dg,\n s_qk_h,\n s_qk_t,\n s_qk_d,\n B,\n H,\n T,\n scale,\n BT: tl.constexpr,\n BK: tl.constexpr,\n DK: tl.constexpr\n):\n # Triton kernel for backward decay global cumulative sum\n i_k, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n p_q = q + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (i_c * BT + BT - 1) * DK\n p_k = k + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (i_c * BT + BT - 1) * DK\n p_g = g + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (i_c * BT + BT - 1) * DK\n p_dg = dg + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (i_c * BT + BT - 1) * DK\n p_dq_inner = dq_inner + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (i_c * BT + BT - 1) * DK\n p_dk_inner = dk_inner + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (i_c * BT + BT - 1) * DK\n p_dq_inter = dq_inter + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (i_c * BT + BT - 1) * DK\n p_dk_inter = dk_inter + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (i_c * BT + BT - 1) * DK\n cum_grad_dg = tl.zeros([BK], dtype=tl.float32)\n mask = (i_k * BK + tl.arange(0, BK)) < DK\n last_g = tl.zeros([BK], dtype=tl.float32)\n for j in range(BT-1, -1, -1):\n _g = tl.load(p_g, mask=mask, other=0).to(tl.float32)\n if j == (BT-1):\n last_g = _g\n _dq1 = tl.load(p_dq_inner, mask=mask, other=0)\n _dq2 = tl.load(p_dq_inter, mask=mask, other=0)\n _dq2 *= tl.math.exp2(_g)\n _dq = _dq1 + _dq2\n tl.store(p_dq_inter, _dq, mask=mask)\n _dk1 = tl.load(p_dk_inner, mask=mask, other=0)\n _dk2 = tl.load(p_dk_inter, mask=mask, other=0)\n _dk2 *= tl.math.exp2(last_g - _g)\n _dk = _dk1 + _dk2\n tl.store(p_dk_inter, _dk, mask=mask)\n _q = tl.load(p_q, mask=mask, other=0)\n _k = tl.load(p_k, mask=mask, other=0)\n _dg = _dq * _q - _dk * _k\n cum_grad_dg += _dg\n tl.store(p_dg, cum_grad_dg.to(p_dg.dtype.element_ty), mask=mask)\n p_g -= DK\n p_k -= DK\n p_q -= DK\n p_dq_inner -= DK\n p_dk_inner -= DK\n p_dq_inter -= DK\n p_dk_inter -= DK\n p_dg -= DK\n", - "description_1": "Use triton language to define three kernel functions: 'fwd_decay_cumsum', 'prepare_qg_kg', and 'bwd_decay_global_cumsum'. Each function performs specific matrix operations based on triton's block and thread structure. Parameters involve pointers to input/output data and tiling constants.", - "description_2": "Use triton language to create kernels for forward decay cumulative sum, prepare transformations on Q and K tensors, and backward decay cumulative gradient calculations.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\nfrom torch.cuda.amp import custom_bwd, custom_fwd\nfrom fla.utils import contiguous\nfrom typing import Tuple\n\n@triton.jit\ndef fused_recurrent_gla_fwd_kernel(\n q, k, v, gk, gv, o, initial_state, final_state,\n s_qk_h, s_qk_t, s_qk_d,\n s_vo_h, s_vo_t, s_vo_d,\n B, H, T, scale,\n BK: tl.constexpr, BV: tl.constexpr, DK: tl.constexpr, DV: tl.constexpr,\n USE_INITIAL_STATE: tl.constexpr, STORE_FINAL_STATE: tl.constexpr, \n REVERSE: tl.constexpr, USE_GK: tl.constexpr, USE_GV: tl.constexpr\n):\n i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n p_q = q + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + ((T-1) * DK if REVERSE else 0)\n p_k = k + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + ((T-1) * DK if REVERSE else 0)\n p_v = v + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV) + ((T-1) * DV if REVERSE else 0)\n p_o = o + (i_bh + i_k * B * H) * s_vo_h + i_v * BV + tl.arange(0, BV) + ((T-1) * DV if REVERSE else 0)\n\n if USE_GK:\n p_gk = gk + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + ((T-1) * DK if REVERSE else 0)\n if USE_GV:\n p_gv = gv + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV) + ((T-1) * DV if REVERSE else 0)\n\n mask_bk = (i_k * BK + tl.arange(0, BK)) < DK\n mask_bv = (i_v * BV + tl.arange(0, BV)) < DV\n h = tl.zeros([BV, BK], dtype=tl.float32)\n mask_kv = mask_bk[None, :] & mask_bv[:, None]\n\n if USE_INITIAL_STATE:\n p_init_s = initial_state + i_bh * DK * DV + (i_k * BK + tl.arange(0, BK)[None, :]) * DV + (i_v * BV + tl.arange(0, BV)[:, None])\n h += tl.load(p_init_s, mask=mask_kv, other=0).to(tl.float32)\n\n for _ in range(0, T):\n _k = tl.load(p_k, mask=mask_bk, other=0).to(tl.float32)\n _v = tl.load(p_v, mask=mask_bv, other=0).to(tl.float32)\n _q = tl.load(p_q, mask=mask_bk, other=0).to(tl.float32) * scale\n if USE_GK:\n _gk = tl.load(p_gk, mask=mask_bk, other=0).to(tl.float32)\n h = h * _gk[None, :]\n if USE_GV:\n _gv = tl.load(p_gv, mask=mask_bv, other=0).to(tl.float32)\n h = h * _gv[:, None]\n h += _k[None, :] * _v[:, None]\n _o = h * _q[None, :]\n _o = tl.sum(_o, axis=1)\n tl.store(p_o, _o.to(p_o.dtype.element_ty), mask=mask_bv)\n p_q += -DK if REVERSE else DK\n p_k += -DK if REVERSE else DK\n p_o += -DV if REVERSE else DV\n p_v += -DV if REVERSE else DV\n if USE_GK:\n p_gk += -DK if REVERSE else DK\n if USE_GV:\n p_gv += -DV if REVERSE else DV\n\n if STORE_FINAL_STATE:\n p_final_s = final_state + i_bh * DK * DV + (i_k * BK + tl.arange(0, BK)[None, :]) * DV + (i_v * BV + tl.arange(0, BV)[:, None])\n tl.store(p_final_s, h.to(p_final_s.dtype.element_ty), mask=mask_kv)\n\n@triton.jit\ndef fused_recurrent_gla_bwd_kernel(\n q, k, v, gk, gv, do, dq, dk, dv, initial_state,\n s_qk_h, s_qk_t, s_qk_d, s_vo_h, s_vo_t, s_vo_d,\n B, H, T, scale, BK: tl.constexpr, BV: tl.constexpr,\n DK: tl.constexpr, DV: tl.constexpr, USE_INITIAL_STATE: tl.constexpr,\n REVERSE: tl.constexpr, USE_GK: tl.constexpr, USE_GV: tl.constexpr\n):\n i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n p_q = q + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + ((T-1) * DK if REVERSE else 0)\n p_k = k + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + ((T-1) * DK if REVERSE else 0)\n p_v = v + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV) + ((T-1) * DV if REVERSE else 0)\n p_do = do + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV) + ((T-1) * DV if REVERSE else 0)\n p_dq = dq + (i_bh + i_v * B * H) * s_qk_h + i_k * BK + tl.arange(0, BK) + ((T-1) * DK if REVERSE else 0)\n if USE_GK:\n p_gk = gk + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + ((T-1) * DK if REVERSE else 0)\n if USE_GV:\n p_gv = gv + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV) + ((T-1) * DV if REVERSE else 0)\n mask_bk = i_k * BK + tl.arange(0, BK) < DK\n mask_bv = i_v * BV + tl.arange(0, BV) < DV\n mask_kv = mask_bk[:, None] & mask_bv[None, :]\n h = tl.zeros([BK, BV], dtype=tl.float32)\n\n if USE_INITIAL_STATE:\n p_init_s = initial_state + i_bh * DK * DV + (i_k * BK + tl.arange(0, BK)[:, None]) * DV + (i_v * BV + tl.arange(0, BV)[None, :])\n h += tl.load(p_init_s, mask=mask_kv, other=0).to(tl.float32)\n\n for i in range(0, T):\n _k = tl.load(p_k, mask=mask_bk, other=0).to(tl.float32)\n _v = tl.load(p_v, mask=mask_bv, other=0).to(tl.float32)\n _do = tl.load(p_do, mask=mask_bv, other=0).to(tl.float32)\n if USE_GK:\n _gk = tl.load(p_gk, mask=mask_bk, other=0).to(tl.float32)\n h = h * _gk[:, None]\n if USE_GV:\n _gv = tl.load(p_gv, mask=mask_bv, other=0).to(tl.float32)\n h = h * _gv[None, :]\n h += _k[:, None] * _v[None, :]\n _d_q = h * _do[None, :]\n d_q = tl.sum(_d_q, axis=1) * scale\n tl.store(p_dq, d_q.to(p_dq.dtype.element_ty), mask=mask_bk)\n\n p_k += -DK if REVERSE else DK\n p_v += -DV if REVERSE else DV\n p_q += -DK if REVERSE else DK\n p_do += -DV if REVERSE else DV\n p_dq += -DK if REVERSE else DK\n if USE_GK:\n p_gk += -DK if REVERSE else DK\n if USE_GV:\n p_gv += -DV if REVERSE else DV\n\n tl.debug_barrier()\n\n p_q = q + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + ((T - 1) * DK if not REVERSE else 0)\n p_k = k + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + ((T - 1) * DK if not REVERSE else 0)\n p_do = do + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV) + ((T - 1) * DV if not REVERSE else 0)\n p_v = v + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV) + ((T - 1) * DV if not REVERSE else 0)\n p_dk = dk + (i_bh + i_v * B * H) * s_qk_h + i_k * BK + tl.arange(0, BK) + ((T - 1) * DK if not REVERSE else 0)\n p_dv = dv + (i_bh + i_k * B * H) * s_vo_h + i_v * BV + tl.arange(0, BV) + ((T - 1) * DV if not REVERSE else 0)\n if USE_GK:\n p_gk = gk + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + ((T - 1) * DK if not REVERSE else 0)\n if USE_GV:\n p_gv = gv + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV) + ((T - 1) * DV if not REVERSE else 0)\n\n d_h = tl.zeros([BK, BV], dtype=tl.float32)\n\n for _ in range(T):\n _do = tl.load(p_do, mask=mask_bv, other=0).to(tl.float32)\n _q = tl.load(p_q, mask=mask_bk, other=0).to(tl.float32) * scale\n _k = tl.load(p_k, mask=mask_bk, other=0).to(tl.float32)\n _v = tl.load(p_v, mask=mask_bv, other=0).to(tl.float32)\n d_h += _q[:, None] * _do[None, :]\n d_k = tl.sum(d_h * _v[None, :], axis=1)\n d_v = tl.sum(d_h * _k[:, None], axis=0)\n if USE_GK:\n _gk = tl.load(p_gk, mask=mask_bk, other=0).to(tl.float32)\n d_h *= _gk[:, None]\n if USE_GV:\n _gv = tl.load(p_gv, mask=mask_bv, other=0).to(tl.float32)\n d_h *= _gv[None, :]\n tl.store(p_dk, d_k.to(p_dk.dtype.element_ty), mask=mask_bk)\n tl.store(p_dv, d_v.to(p_dv.dtype.element_ty), mask=mask_bv)\n\n p_do += DV if REVERSE else -DV\n p_q += DK if REVERSE else -DK\n p_k += DK if REVERSE else -DK\n p_v += DV if REVERSE else -DV\n p_dk += DK if REVERSE else -DK\n p_dv += DV if REVERSE else -DV\n if USE_GK:\n p_gk += DK if REVERSE else -DK\n if USE_GV:\n p_gv += DV if REVERSE else -DV\n\nclass FusedRecurrentGLAFunction(torch.autograd.Function):\n\n @staticmethod\n @contiguous\n @custom_fwd\n def forward(ctx, q, k, v, gk, gv, scale=None, initial_state=None, output_final_state=False, reverse=False):\n batch_size, n_heads, seq_len, d_head_qk = q.shape\n d_head_v = v.shape[-1]\n if scale is None:\n scale = d_head_qk ** -0.5\n if gk is not None:\n gk = gk.float().exp()\n if gv is not None:\n gv = gv.float().exp()\n\n BK, BV = min(d_head_qk, 32), min(d_head_v, 32)\n NK, NV = triton.cdiv(d_head_qk, BK), triton.cdiv(d_head_v, BV)\n num_stages = 1\n num_warps = 1\n\n o = q.new_empty(NK, batch_size, n_heads, seq_len, d_head_v, dtype=torch.float32)\n\n if output_final_state:\n final_state = q.new_empty(batch_size, n_heads, d_head_qk, d_head_v)\n else:\n final_state = None\n\n grid = (NV, NK, batch_size * n_heads)\n fused_recurrent_gla_fwd_kernel[grid](\n q, k, v, gk, gv, o, initial_state, final_state,\n q.stride(1), q.stride(2), q.stride(3),\n v.stride(1), v.stride(2), v.stride(3),\n batch_size, n_heads, seq_len, scale,\n DK=d_head_qk, DV=d_head_v, BK=BK, BV=BV,\n USE_INITIAL_STATE=initial_state is not None,\n STORE_FINAL_STATE=final_state is not None,\n USE_GK=gk is not None,\n USE_GV=gv is not None,\n REVERSE=reverse,\n num_warps=num_warps,\n num_stages=num_stages\n )\n\n o = o.sum(0)\n ctx.save_for_backward(q, k, v, gk, gv, initial_state, o)\n ctx.scale = scale\n ctx.reverse = reverse\n if final_state is not None:\n final_state = final_state.detach()\n return o.to(q.dtype), final_state\n\n @staticmethod\n @contiguous\n @custom_bwd\n def backward(ctx, do, d_final_state=None):\n q, k, v, gk, gv, initial_state, o = ctx.saved_tensors\n batch_size, n_heads, seq_len, d_head_qk = q.shape\n d_head_v = v.shape[-1]\n scale = ctx.scale\n\n BK, BV = min(d_head_qk, 32), min(d_head_v, 32)\n NK, NV = triton.cdiv(d_head_qk, BK), triton.cdiv(d_head_v, BV)\n num_stages = 1\n num_warps = 1\n\n dq = q.new_empty(NV, batch_size, n_heads, seq_len, d_head_qk, dtype=torch.float32)\n dk = q.new_empty(NV, batch_size, n_heads, seq_len, d_head_qk, dtype=torch.float32)\n dv = q.new_empty(NK, batch_size, n_heads, seq_len, d_head_v, dtype=torch.float32)\n grid = (NV, NK, batch_size * n_heads)\n\n fused_recurrent_gla_bwd_kernel[grid](\n q, k, v, gk, gv, do, dq, dk, dv, initial_state,\n q.stride(1), q.stride(2), q.stride(3),\n v.stride(1), v.stride(2), v.stride(3),\n batch_size, n_heads, seq_len, scale,\n DK=d_head_qk, DV=d_head_v, BK=BK, BV=BV,\n num_warps=num_warps,\n num_stages=num_stages,\n USE_INITIAL_STATE=initial_state is not None,\n REVERSE=ctx.reverse,\n USE_GK=gk is not None,\n USE_GV=gv is not None\n )\n dq = dq.sum(0)\n dk = dk.sum(0)\n dv = dv.sum(0)\n if gk is not None:\n _dgk = dq * q.float() - dk * k.float()\n if ctx.reverse:\n dgk = _dgk.cumsum(-2)\n else:\n _dgk_cumsum = _dgk.cumsum(-2)\n dgk = _dgk + _dgk_cumsum[:, :, -1, None] - _dgk_cumsum\n else:\n dgk = None\n\n if gv is not None:\n _dgv = do.float() * o.float() - dv * v.float()\n if ctx.reverse:\n dgv = _dgv.cumsum(-2)\n else:\n _dgv_cumsum = _dgv.cumsum(-2)\n dgv = _dgv + _dgv_cumsum[:, :, -1, None] - _dgv_cumsum\n else:\n dgv = None\n\n return dq.to(q.dtype), dk.to(k.dtype), dv.to(v.dtype), dgk, dgv, None, None, None, None\n\ndef fused_recurrent_gla(\n q: torch.Tensor,\n k: torch.Tensor,\n v: torch.Tensor,\n gk: torch.Tensor = None,\n gv: torch.Tensor = None,\n scale: int = -1,\n initial_state: torch.Tensor = None,\n output_final_state: bool = False,\n causal: bool = True\n) -> Tuple[torch.Tensor, torch.Tensor]:\n if scale == -1:\n scale = q.shape[-1] ** -0.5\n if initial_state is not None:\n initial_state = initial_state.detach()\n if causal:\n o, final_state = FusedRecurrentGLAFunction.apply(q, k, v, gk, gv, scale, initial_state, output_final_state)\n return o, final_state\n else:\n assert initial_state is None\n assert output_final_state is False\n o, final_state = FusedRecurrentGLAFunction.apply(\n q, k, v, gk, gv, scale, initial_state, output_final_state, False)\n o_reversed, final_state = FusedRecurrentGLAFunction.apply(\n q, k, v, gk, gv, scale, initial_state, output_final_state, True)\n return [o, o_reversed]\n", - "description_1": "Use triton language to implement two kernel functions for fused recurrent Gated Linear Attention (GLA) in forward and backward passes. The forward kernel takes 21 parameters including query, key, value, and various configuration constants. It computes the GLA operation with optional initial and final states and direction control. The backward kernel takes 22 parameters including gradients and computes the gradient of the inputs for backpropagation. Both use optional gate tensors to scale the input and output.", - "description_2": "Use triton language to create a forward and backward fused recurrent GLA kernel, with support for initial/final states and gating.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\nfrom typing import Tuple\n\n@triton.jit\ndef chunk_hgrn_fwd_kernel_h(\n x,\n g,\n gc,\n o,\n h0,\n T: tl.constexpr,\n D: tl.constexpr,\n BT: tl.constexpr,\n BD: tl.constexpr,\n USE_INITIAL_STATE: tl.constexpr\n):\n i_d, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n o_d = i_d * BD + tl.arange(0, BD)\n mask = o_d < D\n\n p_x = x + i_bh * T * D + i_t * BT * D + o_d\n p_g = g + i_bh * T * D + i_t * BT * D + o_d\n p_gc = gc + i_bh * T * D + i_t * BT * D + o_d\n p_o = o + i_bh * T * D + i_t * BT * D + o_d\n\n b_h = tl.zeros([BD], dtype=tl.float32)\n b_gc = tl.zeros([BD], dtype=tl.float32)\n if USE_INITIAL_STATE:\n if i_t == 0:\n b_h += tl.load(h0 + i_bh * D + o_d, mask=mask, other=0).to(tl.float32)\n for i in range(0, BT):\n mask_t = mask & ((i_t * BT + i) < T)\n b_x = tl.load(p_x, mask=mask_t, other=0).to(tl.float32)\n b_g = tl.load(p_g, mask=mask_t, other=0).to(tl.float32)\n b_h = tl.exp(b_g) * b_h + b_x\n b_gc = b_gc + b_g\n tl.store(p_gc, b_gc.to(p_o.dtype.element_ty), mask=mask_t)\n tl.store(p_o, b_h.to(p_o.dtype.element_ty), mask=mask_t)\n\n p_x += D\n p_g += D\n p_gc += D\n p_o += D\n\n\n@triton.jit\ndef chunk_hgrn_fwd_kernel_o(\n gc,\n o,\n s_h,\n s_t,\n s_d,\n T: tl.constexpr,\n D: tl.constexpr,\n BT: tl.constexpr,\n BD: tl.constexpr\n):\n i_d, i_bh = tl.program_id(0), tl.program_id(1)\n o_d = i_d * BD + tl.arange(0, BD)\n mask = o_d < D\n\n for i_t in range(1, tl.cdiv(T, BT)):\n p_gc = tl.make_block_ptr(gc + i_bh * s_h, (T, D), (s_t, s_d), (i_t * BT, i_d * BD), (BT, BD), (1, 0))\n p_o = tl.make_block_ptr(o + i_bh * s_h, (T, D), (s_t, s_d), (i_t * BT, i_d * BD), (BT, BD), (1, 0))\n\n b_h0 = tl.load(o + i_bh * T * D + i_t * BT * D - D + o_d, mask=mask, other=0).to(tl.float32)\n b_gc = tl.load(p_gc, boundary_check=(0, 1)).to(tl.float32)\n b_o = tl.load(p_o, boundary_check=(0, 1)).to(tl.float32)\n b_o = b_o + tl.exp(b_gc) * b_h0[None, :]\n tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))\n\n\n@triton.jit\ndef chunk_hgrn_bwd_kernel_h(\n g,\n gc,\n dx,\n do,\n T: tl.constexpr,\n D: tl.constexpr,\n BT: tl.constexpr,\n BD: tl.constexpr\n):\n i_d, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n o_d = i_d * BD + tl.arange(0, BD)\n mask = o_d < D\n BC = min(BT, T - i_t * BT)\n NT = tl.num_programs(1)\n\n p_g = g + (i_bh * T + i_t * BT + BC - 1) * D + o_d\n p_gc = gc + (i_bh * T + i_t * BT + BC - 1) * D + o_d\n p_dx = dx + (i_bh * T + i_t * BT + BC - 1) * D + o_d\n p_do = do + (i_bh * T + i_t * BT + BC - 1) * D + o_d\n\n if i_t == NT - 1:\n b_gc = tl.zeros([BD], dtype=tl.float32)\n else:\n b_gc = tl.load(g + (i_bh * T + i_t * BT + BT) * D + o_d, mask=mask, other=0).to(tl.float32)\n b_dh = tl.zeros([BD], dtype=tl.float32)\n for _ in range(BC - 1, -1, -1):\n tl.store(p_gc, b_gc.to(p_gc.dtype.element_ty), mask=mask)\n\n b_g = tl.load(p_g, mask=mask, other=0).to(tl.float32)\n b_do = tl.load(p_do, mask=mask, other=0).to(tl.float32)\n\n b_gc = b_gc + b_g\n b_dh = b_dh + b_do\n b_dx = b_dh\n b_dh = b_dh * tl.exp(b_g)\n\n tl.store(p_dx, b_dx.to(p_dx.dtype.element_ty), mask=mask)\n\n p_g -= D\n p_gc -= D\n p_dx -= D\n p_do -= D\n\n\n@triton.jit\ndef chunk_hgrn_bwd_kernel_o(\n g,\n gc,\n o,\n dx,\n dg,\n s_h,\n s_t,\n s_d,\n T: tl.constexpr,\n D: tl.constexpr,\n BT: tl.constexpr,\n BD: tl.constexpr\n):\n i_d, i_bh = tl.program_id(0), tl.program_id(1)\n o_d = i_d * BD + tl.arange(0, BD)\n mask = o_d < D\n\n for i_t in range(tl.cdiv(T, BT) - 1, -1, -1):\n p_g = tl.make_block_ptr(g + i_bh * s_h, (T, D), (s_t, s_d), (i_t * BT, i_d * BD), (BT, BD), (1, 0))\n p_gc = tl.make_block_ptr(gc + i_bh * s_h, (T, D), (s_t, s_d), (i_t * BT, i_d * BD), (BT, BD), (1, 0))\n p_o = tl.make_block_ptr(o + i_bh * s_h, (T, D), (s_t, s_d), (i_t * BT - 1, i_d * BD), (BT, BD), (1, 0))\n p_dx = tl.make_block_ptr(dx + i_bh * s_h, (T, D), (s_t, s_d), (i_t * BT, i_d * BD), (BT, BD), (1, 0))\n p_dg = tl.make_block_ptr(dg + i_bh * s_h, (T, D), (s_t, s_d), (i_t * BT, i_d * BD), (BT, BD), (1, 0))\n\n mask_t = mask & ((i_t + 1) * BT < T)\n b_ht = tl.load(dx + i_bh * T * D + (i_t + 1) * BT * D + o_d, mask=mask_t, other=0).to(tl.float32)\n b_g = tl.load(p_g, boundary_check=(0, 1)).to(tl.float32)\n b_gc = tl.load(p_gc, boundary_check=(0, 1)).to(tl.float32)\n b_o = tl.load(p_o, boundary_check=(0, 1)).to(tl.float32)\n b_dx = tl.load(p_dx, boundary_check=(0, 1)).to(tl.float32)\n b_dg = tl.load(p_dg, boundary_check=(0, 1)).to(tl.float32)\n b_dx = b_dx + tl.exp(b_gc) * b_ht[None, :]\n b_dg = b_o * b_dx * tl.exp(b_g)\n tl.store(p_dx, b_dx.to(p_dx.dtype.element_ty), boundary_check=(0, 1))\n tl.store(p_dg, b_dg.to(p_dg.dtype.element_ty), boundary_check=(0, 1))\n\n\nclass ChunkHGRNFunction(torch.autograd.Function):\n\n @staticmethod\n def forward(ctx, x, g, initial_state=None, output_final_state=False):\n B, H, T, D = x.shape\n BT, BD = 128, min(64, triton.next_power_of_2(D))\n num_warps = 8 if BD == 64 else 4\n\n gc = torch.empty_like(g, dtype=torch.float)\n o = torch.empty_like(x, dtype=torch.float)\n def grid(meta): return (triton.cdiv(D, meta['BD']), triton.cdiv(T, meta['BT']), B * H)\n chunk_hgrn_fwd_kernel_h[grid](\n x, g, gc, o, initial_state,\n T, D,\n BT=BT,\n USE_INITIAL_STATE=initial_state is not None\n )\n def grid(meta): return (triton.cdiv(D, meta['BD']), B * H)\n chunk_hgrn_fwd_kernel_o[grid](\n gc, o,\n o.stride(1), o.stride(2), o.stride(3),\n T, D,\n BT=BT, BD=BD,\n num_warps=num_warps\n )\n final_state = None\n if output_final_state:\n final_state = o[:, :, -1].clone()\n o = o.to(x.dtype)\n ctx.save_for_backward(g, o, initial_state)\n return o, final_state\n\n @staticmethod\n def backward(ctx, do, dht=None):\n g, o, initial_state = ctx.saved_tensors\n B, H, T, D = do.shape\n BT, BD = 128, min(64, triton.next_power_of_2(D))\n num_warps = 8 if BD == 64 else 4\n\n gc = torch.empty_like(g, dtype=torch.float)\n dx = torch.empty_like(o)\n dg = torch.empty_like(g)\n def grid(meta): return (triton.cdiv(D, meta['BD']), triton.cdiv(T, meta['BT']), B * H)\n chunk_hgrn_bwd_kernel_h[grid](\n g, gc, dx, do,\n T, D,\n BT=BT\n )\n def grid(meta): return (triton.cdiv(D, meta['BD']), B * H)\n chunk_hgrn_bwd_kernel_o[grid](\n g, gc, o, dx, dg,\n o.stride(1), o.stride(2), o.stride(3),\n T, D,\n BT=BT, BD=BD,\n num_warps=num_warps\n )\n if initial_state is not None:\n dg[:, :, 0] = initial_state * dx[:, :, 0] * g[:, :, 0].exp()\n\n return dx, dg, None, None\n\n\ndef chunk_hgrn(\n x: torch.Tensor,\n g: torch.Tensor,\n initial_state: torch.Tensor = None,\n output_final_state: bool = False\n) -> Tuple[torch.Tensor, torch.Tensor]:\n if initial_state is not None:\n initial_state = initial_state.detach()\n o, final_state = ChunkHGRNFunction.apply(x, g, initial_state, output_final_state)\n return o, final_state\n\n", - "description_1": "Use triton language to implement chunk-wise HGRN forward and backward kernels. The kernel `chunk_hgrn_fwd_kernel_h` requires 9 parameters: x, g, gc, o, h0, T, D, BT, and BD. It processes the input x and g, computes the chunk-wise forward pass, stores intermediate results in gc and o, and uses an initial state h0 if provided. The kernel `chunk_hgrn_fwd_kernel_o` also requires 9 parameters: gc, o, s_h, s_t, s_d, T, D, BT, and BD. It updates the output tensor o based on previously computed gc values. For backward propagation, `chunk_hgrn_bwd_kernel_h` requires 7 parameters: g, gc, dx, do, T, D, BT, and BD to compute the gradient of x, while `chunk_hgrn_bwd_kernel_o` takes 9 parameters: g, gc, o, dx, dg, s_h, s_t, s_d, T, D, BT, and BD to compute the gradient of g.", - "description_2": "Use triton language to implement a neural network operator involving forward and backward pass kernels for a specific recurrent structure, known as HGRN, which operates in a chunk-wise manner to improve efficiency.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef fused_recurrent_hgrn_fwd_kernel(\n x,\n g,\n o,\n h0,\n ht,\n T: tl.constexpr,\n D: tl.constexpr,\n BD: tl.constexpr,\n USE_INITIAL_STATE: tl.constexpr,\n STORE_FINAL_STATE: tl.constexpr\n):\n i_d, i_bh = tl.program_id(0), tl.program_id(1)\n o_d = i_d * BD + tl.arange(0, BD)\n mask = o_d < D\n\n p_x = x + i_bh * T * D + o_d\n p_g = g + i_bh * T * D + o_d\n p_o = o + i_bh * T * D + o_d\n\n b_h = tl.zeros([BD], dtype=tl.float32)\n if USE_INITIAL_STATE:\n p_h0 = h0 + i_bh * D + o_d\n b_h += tl.load(p_h0, mask=mask, other=0).to(tl.float32)\n for _ in range(0, T):\n b_x = tl.load(p_x, mask=mask, other=0).to(tl.float32)\n b_g = tl.load(p_g, mask=mask, other=0).to(tl.float32)\n b_h = tl.exp(b_g) * b_h + b_x\n tl.store(p_o, b_h.to(p_o.dtype.element_ty), mask=mask)\n\n p_x += D\n p_g += D\n p_o += D\n\n if STORE_FINAL_STATE:\n p_ht = ht + i_bh * D + o_d\n tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask)\n\n@triton.jit\ndef fused_recurrent_hgrn_bwd_kernel(\n g,\n o,\n dx,\n dg,\n do,\n h0,\n T: tl.constexpr,\n D: tl.constexpr,\n BD: tl.constexpr,\n USE_INITIAL_STATE: tl.constexpr\n):\n i_d, i_bh = tl.program_id(0), tl.program_id(1)\n o_d = i_d * BD + tl.arange(0, BD)\n mask = o_d < D\n\n p_g = g + (i_bh * T + T - 1) * D + o_d\n p_o = o + (i_bh * T + T - 2) * D + o_d\n p_dx = dx + (i_bh * T + T - 1) * D + o_d\n p_dg = dg + (i_bh * T + T - 1) * D + o_d\n p_do = do + (i_bh * T + T - 1) * D + o_d\n\n b_dh = tl.zeros([BD], dtype=tl.float32)\n for i in range(T - 1, -1, -1):\n b_g = tl.load(p_g, mask=mask, other=0).to(tl.float32)\n b_do = tl.load(p_do, mask=mask, other=0).to(tl.float32)\n if i > 0:\n b_o = tl.load(p_o, mask=mask, other=0).to(tl.float32)\n elif USE_INITIAL_STATE:\n b_o = tl.load(h0 + i_bh * D + o_d, mask=mask, other=0).to(tl.float32)\n else:\n b_o = tl.zeros([BD], dtype=tl.float32)\n\n b_dh = b_dh + b_do\n b_dx = b_dh\n b_dh = b_dh * tl.exp(b_g)\n b_dg = b_dh * b_o\n tl.store(p_dx, b_dx.to(p_dx.dtype.element_ty), mask=mask)\n tl.store(p_dg, b_dg.to(p_dg.dtype.element_ty), mask=mask)\n\n p_g -= D\n p_o -= D\n p_dx -= D\n p_dg -= D\n p_do -= D\n\nclass FusedRecurrentHGRNFunction(torch.autograd.Function):\n\n @staticmethod\n def forward(ctx, x, g, initial_state=None, output_final_state=False):\n B, H, T, D = x.shape\n\n final_state = None\n if output_final_state:\n final_state = x.new_empty(B, H, D)\n\n o = torch.empty_like(x)\n def grid(meta): return (triton.cdiv(D, meta['BD']), B * H)\n fused_recurrent_hgrn_fwd_kernel[grid](\n x, g, o, initial_state, final_state,\n T, D,\n USE_INITIAL_STATE=initial_state is not None,\n STORE_FINAL_STATE=final_state is not None\n )\n ctx.save_for_backward(g, o, initial_state)\n return o, final_state\n\n @staticmethod\n def backward(ctx, do, dht=None):\n g, o, initial_state = ctx.saved_tensors\n B, H, T, D = do.shape\n\n dx = torch.empty_like(o)\n dg = torch.empty_like(g)\n def grid(meta): return (triton.cdiv(D, meta['BD']), B * H)\n fused_recurrent_hgrn_bwd_kernel[grid](\n g, o, dx, dg, do, initial_state,\n T, D,\n USE_INITIAL_STATE=initial_state is not None,\n )\n\n return dx, dg, None, None\n\ndef fused_recurrent_hgrn(\n x: torch.Tensor,\n g: torch.Tensor,\n initial_state: torch.Tensor = None,\n output_final_state: bool = False\n) -> Tuple[torch.Tensor, torch.Tensor]:\n if initial_state is not None:\n initial_state = initial_state.detach()\n o, final_state = FusedRecurrentHGRNFunction.apply(x, g, initial_state, output_final_state)\n return o, final_state\n", - "description_1": "Use triton language to implement a fused recurrent neural network forward and backward kernel. The forward kernel takes 10 parameters: x (input tensor), g (gate tensor), o (output tensor), h0 (initial hidden state), ht (final hidden state), T (time steps), D (dimension), BD (block dimension), USE_INITIAL_STATE (flag for initial state usage), and STORE_FINAL_STATE (flag for storing final state). The backward kernel takes 9 parameters: g (gate tensor), o (output tensor), dx (gradient of x), dg (gradient of g), do (gradient of output), h0 (initial hidden state), T (time steps), D (dimension), BD (block dimension), and USE_INITIAL_STATE (flag for initial state usage).", - "description_2": "Use triton language to create a fused recurrent neural network function with forward and backward operations, handling input, gate, and state tensors, and supporting optional initial and final state management.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef chunk_linear_attn_fwd_kernel_h(\n k, v, h, initial_state, final_state,\n s_qk_h, s_qk_t, s_qk_d, s_vo_h, s_vo_t, s_vo_d, s_h_h, s_h_t,\n H: tl.constexpr, T: tl.constexpr, K: tl.constexpr, V: tl.constexpr,\n BT: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr, NT: tl.constexpr,\n USE_INITIAL_STATE: tl.constexpr, STORE_FINAL_STATE: tl.constexpr\n):\n # kernel code omitted for brevity\n\n@triton.jit\ndef chunk_linear_attn_fwd_kernel_o(\n q, k, v, h, o, s_qk_h, s_qk_t, s_qk_d, s_vo_h, s_vo_t, s_vo_d, s_h_h, s_h_t, scale,\n H: tl.constexpr, T: tl.constexpr, K: tl.constexpr, V: tl.constexpr, BT: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr\n):\n # kernel code omitted for brevity\n\n@triton.jit\ndef chunk_linear_attn_bwd_kernel_dh(\n q, do, dh, s_qk_h, s_qk_t, s_qk_d, s_vo_h, s_vo_t, s_vo_d, s_h_h, s_h_t, scale,\n H: tl.constexpr, T: tl.constexpr, K: tl.constexpr, V: tl.constexpr, BT: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr, NT: tl.constexpr\n):\n # kernel code omitted for brevity\n\n@triton.jit\ndef chunk_linear_attn_bwd_kernel_dqkv(\n q, k, v, h, do, dh, dq, dk, dv, s_qk_h, s_qk_t, s_qk_d, s_vo_h, s_vo_t, s_vo_d, s_h_h, s_h_t, scale,\n H: tl.constexpr, T: tl.constexpr, K: tl.constexpr, V: tl.constexpr, BT: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr, NT: tl.constexpr\n):\n # kernel code omitted for brevity\n\nclass ChunkLinearAttentionFunction(torch.autograd.Function):\n @staticmethod\n def forward(ctx, q, k, v, scale, initial_state, output_final_state):\n # function code omitted for brevity\n\n @staticmethod\n def backward(ctx, do, d_ht=None):\n # function code omitted for brevity\n\ndef chunk_linear_attn(\n q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,\n scale: float = -1, initial_state: torch.Tensor = None,\n output_final_state: bool = False, normalize: bool = True\n) -> Tuple[torch.Tensor, torch.Tensor]:\n # function code omitted for brevity\n", - "description_1": "Use triton language to define multiple kernels and a function for performing chunk-based linear attention computations in both forward and backward passes. The kernels require parameters like query, key, value matrices (q, k, v), scaling factor, initial and final states, strides for each tensor, block sizes, tensor dimensions, and other constexpr parameters. The forward kernel functions `chunk_linear_attn_fwd_kernel_h` and `chunk_linear_attn_fwd_kernel_o` compute intermediate and final attention outputs. The backward kernel functions `chunk_linear_attn_bwd_kernel_dh` and `chunk_linear_attn_bwd_kernel_dqkv` compute gradients. The main Python function `chunk_linear_attn` invokes these kernels using grid definitions based on input sizes and configurations to obtain the final output and optional final state.", - "description_2": "Use triton language to implement kernels for chunk-based linear attention with forward and backward passes, handling input queries, keys, values, scaling, and tensor strides.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\nfrom typing import Tuple\n\n@triton.jit\ndef fused_chunk_linear_attn_fwd_kernel(\n q, k, v, o, initial_state, final_state,\n s_qk_h, s_qk_t, s_qk_d, s_vo_h, s_vo_t, s_vo_d,\n B, H, T, scale, BT: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr, \n DK: tl.constexpr, DV: tl.constexpr, USE_INITIAL_STATE: tl.constexpr, \n STORE_FINAL_STATE: tl.constexpr, CHECK: tl.constexpr\n):\n i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n o_i = tl.arange(0, BT)\n m_s = o_i[:, None] >= o_i[None, :]\n b_h = tl.zeros([BK, BV], dtype=tl.float32)\n\n p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (0, i_k * BK), (BT, BK), (1, 0))\n p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (DK, T), (s_qk_d, s_qk_t), (i_k * BK, 0), (BK, BT), (0, 1))\n p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (0, i_v * BV), (BT, BV), (1, 0))\n p_o = tl.make_block_ptr(o + (i_bh+i_k*B*H) * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (0, i_v * BV), (BT, BV), (1, 0))\n\n if USE_INITIAL_STATE:\n p_h = tl.make_block_ptr(initial_state + i_bh * DK * DV, (DK, DV), (DV, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))\n b_h = tl.load(p_h, boundary_check=(0, 1)).to(tl.float32)\n\n for i in range(0, tl.cdiv(T, BT)):\n b_k = tl.load(p_k, boundary_check=(0, 1))\n b_v = tl.load(p_v, boundary_check=(0, 1))\n b_q = tl.load(p_q, boundary_check=(0, 1))\n b_q = (b_q * scale).to(b_k.dtype)\n\n b_s = tl.dot(b_q, b_k, allow_tf32=False)\n b_s = tl.where(m_s, b_s, 0)\n b_o = tl.dot(b_s.to(b_q.dtype), b_v, allow_tf32=False)\n if CHECK and i == 0:\n b_o += tl.dot(b_q, b_h.to(b_q.dtype), allow_tf32=False)\n b_h = b_h + tl.dot(b_k, b_v, allow_tf32=False)\n else:\n b_o += tl.dot(b_q, b_h.to(b_q.dtype), allow_tf32=False)\n b_h = b_h + tl.dot(b_k, b_v, allow_tf32=False)\n tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))\n p_q = tl.advance(p_q, (BT, 0))\n p_k = tl.advance(p_k, (0, BT))\n p_v = tl.advance(p_v, (BT, 0))\n p_o = tl.advance(p_o, (BT, 0))\n\n if STORE_FINAL_STATE:\n p_final = tl.make_block_ptr(final_state + i_bh * DK * DV, (DK, DV), (DV, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))\n tl.store(p_final, b_h.to(p_final.dtype.element_ty), boundary_check=(0, 1))\n\n\n@triton.jit\ndef fused_chunk_linear_attn_bwd_kernel(\n q, k, v, do, dq, dk, dv, initial_state,\n s_qk_h, s_qk_t, s_qk_d, s_vo_h, s_vo_t, s_vo_d,\n B, H, T, scale, BT: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr, \n DK: tl.constexpr, DV: tl.constexpr, USE_INITIAL_STATE: tl.constexpr, CHECK: tl.constexpr\n):\n i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n o_i = tl.arange(0, BT)\n m_s = o_i[:, None] >= o_i[None, :]\n b_h = tl.zeros([BV, BK], dtype=tl.float32)\n if USE_INITIAL_STATE:\n p_h = tl.make_block_ptr(initial_state + i_bh * DK * DV, (DV, DK), (1, DV), (i_v * BV, i_k * BK), (BV, BK), (0, 1))\n b_h = tl.load(p_h, boundary_check=(0, 1)).to(tl.float32)\n\n for i in range(0, tl.cdiv(T, BT)):\n p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (i * BT, i_k * BK), (BT, BK), (1, 0))\n p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (DV, T), (s_vo_d, s_vo_t), (i_v * BV, i * BT), (BV, BT), (0, 1))\n p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (i * BT, i_v * BV), (BT, BV), (1, 0))\n p_dq = tl.make_block_ptr(dq + (i_bh + i_v*B*H) * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (i*BT, i_k*BK), (BT, BK), (1, 0))\n\n b_k = tl.load(p_k, boundary_check=(0, 1))\n b_v = tl.load(p_v, boundary_check=(0, 1))\n b_do = tl.load(p_do, boundary_check=(0, 1))\n\n b_ds = tl.dot(b_do, b_v, allow_tf32=False)\n b_ds = tl.where(m_s, b_ds, 0)\n b_dq = tl.dot(b_ds.to(b_k.dtype), b_k, allow_tf32=False)\n if CHECK and i == 0:\n b_dq += tl.dot(b_do, b_h.to(b_do.dtype), allow_tf32=False)\n b_h = b_h + tl.dot(b_v, b_k, allow_tf32=False)\n else:\n b_dq += tl.dot(b_do, b_h.to(b_do.dtype), allow_tf32=False)\n b_h = b_h + tl.dot(b_v, b_k, allow_tf32=False)\n b_dq *= scale\n tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))\n\n b_h = None\n tl.debug_barrier()\n b_dh = tl.zeros([BK, BV], dtype=tl.float32)\n m_s = o_i[:, None] <= o_i[None, :]\n for i in range(1, tl.cdiv(T, BT) + 1):\n p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (DK, T), (s_qk_d, s_qk_t), (i_k * BK, T - i * BT), (BK, BT), (0, 1))\n p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (T - i * BT, i_k * BK), (BT, BK), (1, 0))\n p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (T - i * BT, i_v * BV), (BT, BV), (1, 0))\n p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (T - i * BT, i_v * BV), (BT, BV), (1, 0))\n p_dk = tl.make_block_ptr(dk + (i_bh+i_v*B*H) * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (T - i*BT, i_k*BK), (BT, BK), (1, 0))\n p_dv = tl.make_block_ptr(dv + (i_bh+i_k*B*H) * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (T - i*BT, i_v*BV), (BT, BV), (1, 0))\n b_q = tl.load(p_q, boundary_check=(0, 1))\n b_k = tl.load(p_k, boundary_check=(0, 1))\n b_v = tl.load(p_v, boundary_check=(0, 1))\n b_do = tl.load(p_do, boundary_check=(0, 1))\n\n b_ds = tl.dot(b_v, tl.trans(b_do), allow_tf32=False)\n b_ds = tl.where(m_s, b_ds, 0).to(b_q.dtype)\n b_s = tl.dot(b_k, b_q, allow_tf32=False) * scale\n b_s = tl.where(m_s, b_s, 0).to(b_q.dtype)\n b_dk = tl.dot(b_ds, tl.trans(b_q), allow_tf32=False)\n b_dv = tl.dot(b_s, b_do, allow_tf32=False)\n if CHECK and i == 1:\n b_dk += tl.dot(b_v, tl.trans(b_dh).to(b_v.dtype), allow_tf32=False)\n b_dv += tl.dot(b_k, b_dh.to(b_k.dtype), allow_tf32=False)\n b_dh += tl.dot(b_q, b_do, allow_tf32=False)\n else:\n b_dk += tl.dot(b_v, tl.trans(b_dh).to(b_v.dtype), allow_tf32=False)\n b_dv += tl.dot(b_k, b_dh.to(b_k.dtype), allow_tf32=False)\n b_dh += tl.dot(b_q, b_do, allow_tf32=False)\n\n tl.store(p_dk, (b_dk * scale).to(p_dk.dtype.element_ty), boundary_check=(0, 1))\n tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))\n\n\nclass FusedChunkLinearAttentionFunction(torch.autograd.Function):\n @staticmethod\n def forward(ctx, q, k, v, scale, initial_state, output_final_state):\n batch_size, n_heads, seq_len, d_head_qk = q.shape\n d_head_v = v.shape[-1]\n ctx.scale = scale\n BT = 64\n BK, BV = min(triton.next_power_of_2(d_head_qk), 64), min(triton.next_power_of_2(d_head_v), 64)\n NK, NV = triton.cdiv(d_head_qk, BK), triton.cdiv(d_head_v, BV)\n num_stages = 1\n num_warps = 4\n o = q.new_empty(NK, batch_size, n_heads, seq_len, d_head_v)\n if output_final_state:\n final_state = q.new_empty(batch_size, n_heads, d_head_qk, d_head_v, dtype=torch.float32, requires_grad=False)\n else:\n final_state = None\n CHECK = True\n grid = (NV, NK, batch_size * n_heads)\n fused_chunk_linear_attn_fwd_kernel[grid](\n q, k, v, o, initial_state, final_state,\n q.stride(1), q.stride(2), q.stride(3),\n v.stride(1), v.stride(2), v.stride(3),\n batch_size, n_heads, seq_len, scale,\n BT=BT, DK=d_head_qk, DV=d_head_v, BK=BK, BV=BV,\n USE_INITIAL_STATE=initial_state is not None,\n STORE_FINAL_STATE=output_final_state,\n CHECK=CHECK,\n num_warps=num_warps,\n num_stages=num_stages\n )\n\n o = o.sum(0)\n ctx.save_for_backward(q, k, v, initial_state)\n ctx.CHECK = CHECK\n return o.to(q.dtype), final_state\n\n @staticmethod\n def backward(ctx, do, d_final_state=None):\n q, k, v, initial_state = ctx.saved_tensors\n batch_size, n_heads, seq_len, d_head_qk = q.shape\n d_head_v = v.shape[-1]\n scale = ctx.scale\n\n BT = 64\n BK, BV = min(triton.next_power_of_2(d_head_qk), 64), min(triton.next_power_of_2(d_head_v), 64)\n NK, NV = triton.cdiv(d_head_qk, BK), triton.cdiv(d_head_v, BV)\n num_stages = 1\n num_warps = 4\n\n dq = q.new_empty(NV, batch_size, n_heads, seq_len, d_head_qk)\n dk = q.new_empty(NV, batch_size, n_heads, seq_len, d_head_qk)\n dv = q.new_empty(NK, batch_size, n_heads, seq_len, d_head_v)\n grid = (NV, NK, batch_size * n_heads)\n\n fused_chunk_linear_attn_bwd_kernel[grid](\n q, k, v, do, dq, dk, dv, initial_state,\n q.stride(1), q.stride(2), q.stride(3),\n v.stride(1), v.stride(2), v.stride(3),\n batch_size, n_heads, seq_len, scale,\n BT=BT, DK=d_head_qk, DV=d_head_v, BK=BK, BV=BV,\n USE_INITIAL_STATE=initial_state is not None,\n CHECK=ctx.CHECK,\n num_warps=num_warps,\n num_stages=num_stages\n )\n dq = dq.sum(0)\n dk = dk.sum(0)\n dv = dv.sum(0)\n return dq.to(q.dtype), dk.to(k.dtype), dv.to(v.dtype), None, None, None\n\n\ndef fused_chunk_linear_attn(\n q: torch.Tensor,\n k: torch.Tensor,\n v: torch.Tensor,\n scale: float = -1,\n initial_state: torch.Tensor = None,\n output_final_state: bool = False,\n normalize: bool = True\n) -> Tuple[torch.Tensor, torch.Tensor]:\n if initial_state is not None:\n initial_state = initial_state.detach()\n if scale == -1:\n scale = q.shape[-1] ** -0.5\n o, final_state = FusedChunkLinearAttentionFunction.apply(q, k, v, scale, initial_state, output_final_state)\n return o, final_state\n", - "description_1": "Use triton language to implement two kernel functions for a fused chunk linear attention mechanism. The first function, `fused_chunk_linear_attn_fwd_kernel`, is the forward pass of the attention mechanism, and takes 18 inputs including q, k, v tensors (query, key, and value), output tensors, initial and final state, stride sizes, batch size, number of heads, sequence length, and scaling factor among other constants. The second function, `fused_chunk_linear_attn_bwd_kernel`, handles the backward pass, taking the same number of inputs but working with gradients of query, key, and value tensors instead. Each function makes use of triton's block pointers, boundary checking, and allows control flow via constants.", - "description_2": "Use triton language to create a fused chunk linear attention mechanism consisting of a forward and backward pass, each implemented as kernel functions. These kernels manage data with specific triton operations for efficient processing and parallel execution.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef parallel_rebased_fwd_kernel(\n q, k, v, o, z,\n s_qk_h, s_qk_t, s_qk_d,\n s_vo_h, s_vo_t, s_vo_d,\n B, H, T, scale,\n BTL: tl.constexpr, BTS: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr,\n DK: tl.constexpr, DV: tl.constexpr,\n):\n i_kv, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n NV = tl.cdiv(DV, BV)\n i_k = i_kv // (NV)\n i_v = i_kv % (NV)\n\n p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (T, DK),\n (s_qk_t, s_qk_d), (i_c * BTL, i_k * BK), (BTL, BK), (1, 0))\n p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (DK, T),\n (s_qk_d, s_qk_t), (i_k * BK, 0), (BK, BTS), (0, 1))\n p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, DV),\n (s_vo_t, s_vo_d), (0, i_v * BV), (BTS, BV), (1, 0))\n b_q = tl.load(p_q, boundary_check=(0, 1))\n b_q = (b_q * scale).to(b_q.dtype)\n b_o = tl.zeros([BTL, BV], dtype=tl.float32)\n b_z = tl.zeros([BTL], dtype=tl.float32)\n\n for _ in range(0, i_c * BTL, BTS):\n b_k = tl.load(p_k, boundary_check=(0, 1))\n b_v = tl.load(p_v, boundary_check=(0, 1))\n b_s = tl.dot(b_q, (b_k), allow_tf32=False)\n b_s = b_s * b_s\n b_z += tl.sum(b_s, axis=1)\n b_o = b_o + tl.dot(b_s.to(b_v.dtype), b_v, allow_tf32=False)\n p_k = tl.advance(p_k, (0, BTS))\n p_v = tl.advance(p_v, (BTS, 0))\n\n tl.debug_barrier()\n o_q = tl.arange(0, BTL)\n o_k = tl.arange(0, BTS)\n p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (DK, T),\n (s_qk_d, s_qk_t), (i_k * BK, i_c * BTL), (BK, BTS), (0, 1))\n p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, DV),\n (s_vo_t, s_vo_d), (i_c * BTL, i_v * BV), (BTS, BV), (1, 0))\n for _ in range(i_c * BTL, (i_c + 1) * BTL, BTS):\n b_k = tl.load(p_k, boundary_check=(0, 1))\n b_v = tl.load(p_v, boundary_check=(0, 1))\n m_s = o_q[:, None] >= o_k[None, :]\n b_s = tl.dot(b_q, b_k, allow_tf32=False)\n b_s = b_s * b_s\n b_s = tl.where(m_s, b_s, 0)\n b_z += tl.sum(b_s, axis=1)\n b_o += tl.dot(b_s.to(b_q.dtype), b_v, allow_tf32=False)\n p_k = tl.advance(p_k, (0, BTS))\n p_v = tl.advance(p_v, (BTS, 0))\n o_k += BTS\n\n p_o = tl.make_block_ptr(o + (i_bh + B * H * i_k) * s_vo_h, (T, DV),\n (s_vo_t, s_vo_d), (i_c*BTL, i_v*BV), (BTL, BV), (1, 0))\n p_z = z + (i_bh + B * H * i_k) * T + i_c * BTL + tl.arange(0, BTL)\n tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))\n tl.store(p_z, b_z.to(p_z.dtype.element_ty),\n mask=((i_c * BTL + tl.arange(0, BTL)) < T))\n\n\n@triton.jit\ndef _parallel_rebased_bwd_dq(\n i_bh, i_c, i_k, i_v, i_h,\n q, k, v, do, dz, dq, s_qk_h, s_qk_t, s_qk_d, s_vo_h,\n s_vo_t, s_vo_d, B, H, T, scale,\n BTL: tl.constexpr, BTS: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr,\n DK: tl.constexpr, DV: tl.constexpr,\n):\n p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d),\n (i_c * BTL, i_v * BV), (BTL, BV), (1, 0))\n p_q = tl.make_block_ptr(q + (i_bh) * s_qk_h, (T, DK),\n (s_qk_t, s_qk_d), (i_c*BTL, i_k*BK), (BTL, BK), (1, 0))\n b_q = tl.load(p_q, boundary_check=(0, 1))\n b_do = tl.load(p_do, boundary_check=(0, 1)).to(b_q.dtype)\n b_q = (b_q * scale).to(b_q.dtype)\n b_dq = tl.zeros([BTL, BK], dtype=tl.float32)\n p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, DK),\n (s_qk_t, s_qk_d), (0, i_k * BK), (BTS, BK), (1, 0))\n p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (DV, T),\n (s_vo_d, s_vo_t), (i_v * BV, 0), (BV, BTS), (0, 1))\n p_dz = dz + i_bh * T + i_c * BTL + tl.arange(0, BTL)\n b_dz = tl.load(p_dz, mask=(i_c * BTL + tl.arange(0, BTL)) < T)\n\n for _ in range(0, i_c * BTL, BTS):\n b_k = tl.load(p_k, boundary_check=(0, 1))\n b_v = tl.load(p_v, boundary_check=(0, 1))\n b_ds = tl.dot(b_do, b_v, allow_tf32=False)\n if i_v == 0:\n b_ds += b_dz[:, None]\n else:\n b_ds = b_ds\n b_s = tl.dot(b_q, tl.trans(b_k), allow_tf32=False)\n b_dq += tl.dot((2 * b_ds * b_s).to(b_v.dtype), b_k, allow_tf32=False)\n p_k = tl.advance(p_k, (BTS, 0))\n p_v = tl.advance(p_v, (0, BTS))\n\n b_dq *= scale\n o_q = tl.arange(0, BTL)\n o_k = tl.arange(0, BTS)\n p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, DK),\n (s_qk_d, s_qk_t), (i_c * BTL, i_k * BK), (BTS, BK), (1, 0))\n p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (DV, T),\n (s_vo_d, s_vo_t), (i_v * BV, i_c * BTL), (BV, BTS), (0, 1))\n for _ in range(i_c * BTL, (i_c + 1) * BTL, BTS):\n b_k = tl.load(p_k, boundary_check=(0, 1))\n b_v = tl.load(p_v, boundary_check=(0, 1))\n m_s = o_q[:, None] >= o_k[None, :]\n b_ds = tl.dot(b_do, b_v, allow_tf32=False)\n if i_v == 0:\n b_ds += b_dz[:, None]\n else:\n b_ds = b_ds\n b_ds = tl.where(m_s, b_ds, 0) * scale\n b_s = tl.dot(b_q, tl.trans(b_k), allow_tf32=False)\n b_s = tl.where(m_s, b_s, 0)\n b_dq += tl.dot((2 * b_ds * b_s).to(b_k.dtype),\n b_k, allow_tf32=False)\n p_k = tl.advance(p_k, (BTS, 0))\n p_v = tl.advance(p_v, (0, BTS))\n o_k += BTS\n p_dq = tl.make_block_ptr(dq + (i_bh + B * H * i_v) * s_qk_h, (T, DK),\n (s_qk_t, s_qk_d), (i_c*BTL, i_k*BK), (BTL, BK), (1, 0))\n tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))\n return\n\n\n@triton.jit\ndef _parallel_rebased_bwd_dkv(\n i_bh, i_c, i_k, i_v, i_h,\n q, k, v, do, dz, dk, dv, s_qk_h, s_qk_t, s_qk_d, s_vo_h,\n s_vo_t, s_vo_d, B, H, T, scale,\n BTL: tl.constexpr, BTS: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr,\n DK: tl.constexpr, DV: tl.constexpr,\n):\n p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d),\n (i_c * BTL, i_k * BK), (BTL, BK), (1, 0))\n p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d),\n (i_c * BTL, i_v * BV), (BTL, BV), (1, 0))\n b_k, b_v = tl.load(p_k, boundary_check=(0, 1)), tl.load(\n p_v, boundary_check=(0, 1))\n b_dk, b_dv = tl.zeros([BTL, BK], dtype=tl.float32), tl.zeros(\n [BTL, BV], dtype=tl.float32)\n\n for i in range((tl.cdiv(T, BTS) * BTS)-BTS, (i_c + 1) * BTL - BTS, -BTS):\n p_q = tl.make_block_ptr(\n q + i_bh * s_qk_h, (DK, T), (s_qk_d, s_qk_t), (i_k * BK, i), (BK, BTS), (0, 1))\n p_do = tl.make_block_ptr(\n do + i_bh * s_vo_h, (DV, T), (s_vo_d, s_vo_t), (i_v * BV, i), (BV, BTS), (0, 1))\n p_dz = dz + i_bh * T + i + tl.arange(0, BTS)\n b_q = tl.load(p_q, boundary_check=(0, 1))\n b_do = tl.load(p_do, boundary_check=(0, 1)).to(b_q.dtype)\n b_dz = tl.load(p_dz, mask=(i + tl.arange(0, BTS)) < T)\n b_s = tl.dot(b_k.to(b_q.dtype), b_q, allow_tf32=False) * scale\n b_s2 = b_s * b_s\n b_dv += tl.dot(b_s2.to(b_q.dtype), tl.trans(b_do), allow_tf32=False)\n b_ds = tl.dot(b_v, b_do, allow_tf32=False) * scale\n if i_v == 0:\n b_ds += b_dz[None, :] * scale\n else:\n b_ds = b_ds\n b_dk += tl.dot((2 * b_ds * b_s).to(b_q.dtype),\n tl.trans(b_q), allow_tf32=False)\n\n tl.debug_barrier()\n o_q, o_k = tl.arange(0, BTS), tl.arange(0, BTL)\n for i in range(i_c*BTL, (i_c+1)*BTL, BTS):\n p_q = tl.make_block_ptr(\n q + i_bh * s_qk_h, (DK, T), (s_qk_d, s_qk_t), (i_k * BK, i), (BK, BTS), (0, 1))\n p_do = tl.make_block_ptr(\n do + i_bh * s_vo_h, (DV, T), (s_vo_d, s_vo_t), (i_v * BV, i), (BV, BTS), (0, 1))\n p_dz = dz + i_bh * T + i + tl.arange(0, BTS)\n b_q = tl.load(p_q, boundary_check=(0, 1))\n b_do = tl.load(p_do, boundary_check=(0, 1)).to(b_q.dtype)\n b_dz = tl.load(p_dz, mask=(i + tl.arange(0, BTS)) < T)\n m_s = o_k[:, None] <= o_q[None, :]\n b_s = tl.dot(b_k, b_q, allow_tf32=False) * scale\n b_s2 = b_s * b_s\n b_s = tl.where(m_s, b_s, 0)\n b_s2 = tl.where(m_s, b_s2, 0)\n\n b_ds = tl.dot(b_v, b_do, allow_tf32=False)\n if i_v == 0:\n b_ds += b_dz[None, :]\n else:\n b_ds = b_ds\n b_ds = tl.where(m_s, b_ds, 0) * scale\n b_dv += tl.dot(b_s2.to(b_q.dtype), tl.trans(b_do), allow_tf32=False)\n b_dk += tl.dot((2 * b_ds * b_s).to(b_q.dtype),\n tl.trans(b_q), allow_tf32=False)\n o_q += BTS\n\n p_dk = tl.make_block_ptr(dk + (i_bh + B * H * i_v) * s_qk_h,\n (T, DK), (s_qk_t, s_qk_d), (i_c*BTL, i_k*BK), (BTL, BK), (1, 0))\n p_dv = tl.make_block_ptr(dv + (i_bh + B * H * i_k) * s_vo_h,\n (T, DV), (s_vo_t, s_vo_d), (i_c*BTL, i_v*BV), (BTL, BV), (1, 0))\n tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))\n tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))\n return\n\n\n@triton.jit\ndef parallel_rebased_bwd_kernel(\n q, k, v, do, dz, dq, dk, dv, s_qk_h, s_qk_t, s_qk_d, s_vo_h,\n s_vo_t, s_vo_d, B, H, T, scale,\n BTL: tl.constexpr, BTS: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr,\n DK: tl.constexpr, DV: tl.constexpr,\n):\n i_kv, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n NV = tl.cdiv(DV, BV)\n i_k = i_kv // (NV)\n i_v = i_kv % (NV)\n i_h = i_bh % H\n _parallel_rebased_bwd_dq(\n i_bh, i_c, i_k, i_v, i_h,\n q, k, v, do, dz, dq, s_qk_h, s_qk_t, s_qk_d, s_vo_h,\n s_vo_t, s_vo_d, B, H, T, scale, BTL=BTL, BTS=BTS, BK=BK, BV=BV, DK=DK, DV=DV\n )\n tl.debug_barrier()\n _parallel_rebased_bwd_dkv(\n i_bh, i_c, i_k, i_v, i_h,\n q, k, v, do, dz, dk, dv, s_qk_h, s_qk_t, s_qk_d, s_vo_h,\n s_vo_t, s_vo_d, B, H, T, scale, BTL, BTS, BK, BV, DK, DV\n )\n\n\nclass ParallelBasedFunction(torch.autograd.Function):\n @staticmethod\n def forward(ctx, q, k, v, scale):\n BTL, BTS = 128, 32\n assert BTL % BTS == 0\n BK = min(128, triton.next_power_of_2(k.shape[-1]))\n BV = min(128, triton.next_power_of_2(v.shape[-1]))\n BK, BV = max(BK, 16), max(BV, 16)\n batch_size, n_heads, seq_len, d_head_qk = q.shape\n d_head_v = v.shape[-1]\n num_stages = 2\n num_warps = 4\n NK = triton.cdiv(d_head_qk, BK)\n NV = triton.cdiv(d_head_v, BV)\n grid = (NK * NV, triton.cdiv(seq_len, BTL), batch_size * n_heads)\n\n assert NK == 1, \"will encounter some synchronization issue if not.\"\n\n o = torch.empty(NK, batch_size, n_heads, seq_len,\n d_head_v, device=q.device)\n z = torch.empty(NK, batch_size, n_heads, seq_len,\n device=q.device)\n parallel_rebased_fwd_kernel[grid](\n q, k, v, o, z,\n q.stride(1), q.stride(2), q.stride(3),\n v.stride(1), v.stride(2), v.stride(3),\n batch_size, n_heads, seq_len, scale,\n BTL=BTL, BTS=BTS, BK=BK, BV=BV, DK=d_head_qk, DV=d_head_v,\n num_warps=num_warps,\n num_stages=num_stages\n )\n ctx.save_for_backward(q, k, v)\n ctx.scale = scale\n return o.sum(0).to(q.dtype), z.sum(0).to(q.dtype)\n\n @staticmethod\n def backward(ctx, do, dz):\n q, k, v = ctx.saved_tensors\n scale = ctx.scale\n BTL, BTS = 64, 32\n assert BTL % BTS == 0\n BK = min(128, triton.next_power_of_2(k.shape[-1]))\n BV = min(128, triton.next_power_of_2(v.shape[-1]))\n BK, BV = max(BK, 16), max(BV, 16)\n batch_size, n_heads, seq_len, d_head_qk = q.shape\n d_head_v = v.shape[-1]\n num_stages = 2\n num_warps = 4\n NK = triton.cdiv(d_head_qk, BK)\n NV = triton.cdiv(d_head_v, BV)\n grid = (NK * NV, triton.cdiv(seq_len, BTL), batch_size * n_heads)\n\n assert NK == 1, \"will encounter some synchronization issue if not\"\n\n dq = torch.empty(NV, batch_size, n_heads, seq_len,\n d_head_qk, dtype=q.dtype, device=q.device)\n dk = torch.empty(NV, batch_size, n_heads, seq_len,\n d_head_qk, dtype=q.dtype, device=q.device)\n dv = torch.empty(NK, batch_size, n_heads, seq_len,\n d_head_v, dtype=q.dtype, device=q.device)\n\n parallel_rebased_bwd_kernel[grid](\n q, k, v, do, dz, dq, dk, dv,\n q.stride(1), q.stride(2), q.stride(3),\n v.stride(1), v.stride(2), v.stride(3),\n batch_size, n_heads, seq_len, scale,\n BTL=BTL, BTS=BTS, BK=BK, BV=BV, DK=d_head_qk, DV=d_head_v,\n num_warps=num_warps,\n num_stages=num_stages\n )\n\n return dq.sum(0).to(q.dtype), dk.sum(0).to(k.dtype), dv.sum(0).to(v.dtype), None\n\n\ntriton_parallel_based = ParallelBasedFunction.apply\n\n\ndef parallel_rebased(q, k, v, eps=1e-5, use_scale=True, use_normalize=True, return_both=False):\n assert q.shape[-1] <= 128, \"only support feature dim up to 128\"\n if use_scale:\n scale = q.shape[-1] ** -0.5\n else:\n scale = 1\n o, z = triton_parallel_based(q, k, v, scale)\n if return_both:\n return o, z\n if use_normalize:\n o = o / (z[..., None] + eps)\n else:\n o = o\n return o.to(q.dtype)\n", - "description_1": "Use triton language to implement a parallel rebased forward kernel and backward kernel for a custom linear transformer function. The forward kernel handles inputs q, k, v, computes intermediate outputs o and z, and takes multiple strides and constants like B, H, T, scale as inputs. The backward kernel computes gradients dq, dk, dv using the saved tensors q, k, v from the forward pass. Both kernels require block sizes (BTL, BTS, BK, BV) and constants (DK, DV).", - "description_2": "Use triton language to create forward and backward kernels for parallel rebased transformer operations that efficiently compute outputs and gradients for input tensors q, k, v, along with necessary strides and constants.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\nfrom packaging import version\n\n@triton.jit\ndef fused_chunk_retention_fwd_kernel(\n q, k, v, o, initial_state, final_state, s_qk_h, s_qk_t, s_qk_d,\n s_vo_h, s_vo_t, s_vo_d, B, H, T, scale, BT: tl.constexpr, BK: tl.constexpr,\n BV: tl.constexpr, DK: tl.constexpr, DV: tl.constexpr, USE_INITIAL_STATE: tl.constexpr,\n STORE_FINAL_STATE: tl.constexpr, CHECK: tl.constexpr\n):\n # indices\n i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n i_h = i_bh % H\n\n o_i = tl.arange(0, BT)\n # decay rate given the head index\n b_b = tl.math.log2(1 - tl.math.pow(2, -5 - i_h * 1.0))\n\n # d_b: overall decay for the entire chunk\n # d_o: cumulative decay from the start of the chunk\n # d_h: cumulative decay from the end of the chunk\n d_b, d_o, d_h = tl.math.exp2(BT * b_b), tl.math.exp2((o_i + 1) * b_b), tl.math.exp2((BT - o_i - 1) * b_b)\n\n # [BT, BT]\n m_s = o_i[:, None] >= o_i[None, :]\n d_s = tl.where(m_s, tl.math.exp2((o_i[:, None] - o_i[None, :]) * b_b), 0)\n # [BK, BV]\n b_h = tl.zeros([BK, BV], dtype=tl.float32)\n\n # make block pointers\n p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (0, i_k * BK), (BT, BK), (1, 0))\n p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (DK, T), (s_qk_d, s_qk_t), (i_k * BK, 0), (BK, BT), (0, 1))\n p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (0, i_v * BV), (BT, BV), (1, 0))\n p_o = tl.make_block_ptr(o + (i_bh+i_k*B*H) * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (0, i_v * BV), (BT, BV), (1, 0))\n\n if USE_INITIAL_STATE:\n p_h = tl.make_block_ptr(initial_state + i_bh * DK * DV, (DK, DV), (DV, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))\n b_h = tl.load(p_h, boundary_check=(0, 1)).to(tl.float32)\n \n NT = tl.cdiv(T, BT)\n for i in range(0, NT):\n # [BK, BT]\n b_k = tl.load(p_k, boundary_check=(0, 1))\n # [BT, BV]\n b_v = tl.load(p_v, boundary_check=(0, 1))\n # [BT, BK]\n b_q = tl.load(p_q, boundary_check=(0, 1))\n b_q = (b_q * scale).to(b_k.dtype)\n\n # [BT, BT]\n b_s = tl.dot(b_q, b_k, allow_tf32=False) * d_s\n # [BT, BV]\n b_o = tl.dot(b_s.to(b_q.dtype), b_v, allow_tf32=False)\n if CHECK and i == 0:\n b_o += tl.dot(b_q, b_h.to(b_q.dtype), allow_tf32=False) * d_o[:, None]\n b_h = d_b * b_h + tl.dot(b_k, (b_v * d_h[:, None]).to(b_k.dtype), allow_tf32=False)\n else:\n b_o += tl.dot(b_q, b_h.to(b_q.dtype), allow_tf32=False) * d_o[:, None]\n if i == NT - 1 and (T % BT) != 0:\n d_b = tl.math.exp2((T % BT) * b_b)\n d_h = tl.math.exp2(((T % BT) - o_i - 1) * b_b)\n b_h = d_b * b_h + tl.dot(b_k, (b_v * d_h[:, None]).to(b_k.dtype), allow_tf32=False)\n tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))\n\n p_q = tl.advance(p_q, (BT, 0))\n p_k = tl.advance(p_k, (0, BT))\n p_v = tl.advance(p_v, (BT, 0))\n p_o = tl.advance(p_o, (BT, 0))\n\n if STORE_FINAL_STATE:\n p_final = tl.make_block_ptr(final_state + i_bh * DK * DV, (DK, DV), (DV, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))\n tl.store(p_final, b_h.to(p_final.dtype.element_ty), boundary_check=(0, 1))\n\n@triton.jit\ndef fused_chunk_retention_bwd_kernel(\n q, k, v, do, dq, dk, dv, initial_state, s_qk_h, s_qk_t, s_qk_d, s_vo_h,\n s_vo_t, s_vo_d, B, H, T, scale, BT: tl.constexpr, BK: tl.constexpr,\n BV: tl.constexpr, DK: tl.constexpr, DV: tl.constexpr,\n USE_INITIAL_STATE: tl.constexpr, CHECK: tl.constexpr\n):\n i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n i_h = i_bh % H\n\n o_i = tl.arange(0, BT)\n b_b = tl.math.log2(1 - tl.math.pow(2, -5 - i_h * 1.0))\n d_q, d_k = tl.math.exp2((o_i+1) * b_b) * scale, tl.math.exp2((BT - o_i - 1) * b_b)\n d_b = tl.math.exp2(BT * b_b)\n\n m_s = o_i[:, None] >= o_i[None, :]\n d_s = tl.where(m_s, tl.math.exp2((o_i[:, None] - o_i[None, :]) * b_b), 0) * scale\n # [BV, BK]\n b_h = tl.zeros([BV, BK], dtype=tl.float32)\n if USE_INITIAL_STATE:\n p_h = tl.make_block_ptr(initial_state + i_bh * DK * DV, (DV, DK), (1, DV), (i_v * BV, i_k * BK), (BV, BK), (0, 1))\n b_h = tl.load(p_h, boundary_check=(0, 1)).to(tl.float32)\n\n for i in range(0, tl.cdiv(T, BT)):\n p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (i * BT, i_k * BK), (BT, BK), (1, 0))\n p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (DV, T), (s_vo_d, s_vo_t), (i_v * BV, i * BT), (BV, BT), (0, 1))\n p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (i * BT, i_v * BV), (BT, BV), (1, 0))\n p_dq = tl.make_block_ptr(dq + (i_bh + i_v*B*H) * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (i*BT, i_k*BK), (BT, BK), (1, 0))\n\n # [BT, DK]\n b_k = tl.load(p_k, boundary_check=(0, 1))\n # [DV, BT]\n b_v = tl.load(p_v, boundary_check=(0, 1))\n # [BT, DV]\n b_do = tl.load(p_do, boundary_check=(0, 1))\n b_dd = (b_do * d_q[:, None]).to(b_do.dtype)\n\n # [BT, BT]\n b_ds = tl.dot(b_do, b_v, allow_tf32=False)\n b_ds = (b_ds * d_s).to(b_k.dtype)\n # [BT, DK]\n b_dq = tl.dot(b_ds, b_k, allow_tf32=False)\n # [DV, DK]\n if CHECK and i == 0:\n b_dq += tl.dot(b_dd, b_h.to(b_k.dtype), allow_tf32=False)\n b_h = d_b * b_h + tl.dot((b_v * d_k[None, :]).to(b_k.dtype), b_k, allow_tf32=False)\n else:\n b_dq += tl.dot(b_dd, b_h.to(b_k.dtype), allow_tf32=False)\n b_h = d_b * b_h + tl.dot((b_v * d_k[None, :]).to(b_k.dtype), b_k, allow_tf32=False)\n\n tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))\n\n # sync threads\n b_h = None\n tl.debug_barrier()\n d_s = tl.trans(d_s)\n # [BK, BV]\n b_dh = tl.zeros([BK, BV], dtype=tl.float32)\n for i in range(1, tl.cdiv(T, BT) + 1):\n p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (DK, T), (s_qk_d, s_qk_t), (i_k * BK, T - i * BT), (BK, BT), (0, 1))\n p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (T - i * BT, i_k * BK), (BT, BK), (1, 0))\n p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (T - i * BT, i_v * BV), (BT, BV), (1, 0))\n p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (T - i * BT, i_v * BV), (BT, BV), (1, 0))\n p_dk = tl.make_block_ptr(dk + (i_bh+i_v*B*H) * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (T - i*BT, i_k*BK), (BT, BK), (1, 0))\n p_dv = tl.make_block_ptr(dv + (i_bh+i_k*B*H) * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (T - i*BT, i_v*BV), (BT, BV), (1, 0))\n # [DK, BT]\n b_q = tl.load(p_q, boundary_check=(0, 1))\n # [BT, DK]\n b_k = tl.load(p_k, boundary_check=(0, 1))\n # [BT, DV]\n b_v = tl.load(p_v, boundary_check=(0, 1))\n b_do = tl.load(p_do, boundary_check=(0, 1))\n b_dd = (b_do * d_q[:, None]).to(b_do.dtype)\n\n # [BT, BT]\n b_ds = tl.dot(b_v, tl.trans(b_do), allow_tf32=False)\n b_ds = (b_ds * d_s).to(b_k.dtype)\n\n # [BT, BT]\n b_s = tl.dot(b_k, b_q, allow_tf32=False) * d_s\n # [BT, DK]\n b_dk = tl.dot(b_ds, tl.trans(b_q), allow_tf32=False)\n # [BT, DV]\n b_dv = tl.dot(b_s.to(b_q.dtype), b_do, allow_tf32=False)\n if CHECK and i == 1:\n b_dk += tl.dot(b_v, tl.trans(b_dh).to(b_v.dtype), allow_tf32=False) * d_k[:, None]\n b_dv += tl.dot(b_k, b_dh.to(b_k.dtype), allow_tf32=False) * d_k[:, None]\n b_dh = d_b * b_dh + tl.dot(b_q, b_dd, allow_tf32=False)\n else:\n b_dk += tl.dot(b_v, tl.trans(b_dh).to(b_v.dtype), allow_tf32=False) * d_k[:, None]\n b_dv += tl.dot(b_k, b_dh.to(b_k.dtype), allow_tf32=False) * d_k[:, None]\n b_dh = d_b * b_dh + tl.dot(b_q, b_dd, allow_tf32=False)\n\n tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))\n tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))\n\n\nclass FusedChunkRetentionFunction(torch.autograd.Function):\n\n @staticmethod\n def forward(ctx, q, k, v, initial_state, output_final_state):\n batch_size, n_heads, seq_len, d_head_qk = q.shape\n d_head_v = v.shape[-1]\n\n scale = d_head_qk ** -0.5\n BT = 64\n BK, BV = min(triton.next_power_of_2(d_head_qk), 64), min(triton.next_power_of_2(d_head_v), 64)\n NK, NV = triton.cdiv(d_head_qk, BK), triton.cdiv(d_head_v, BV)\n num_stages = 1\n num_warps = 4\n\n o = q.new_empty(NK, batch_size, n_heads, seq_len, d_head_v)\n\n if output_final_state:\n final_state = q.new_empty(batch_size, n_heads, d_head_qk, d_head_v, dtype=torch.float32, requires_grad=False)\n else:\n final_state = None\n CHECK = True\n if version.parse(triton.__version__) < version.parse('2.2.0'):\n import warnings\n warnings.warn(\n \"Triton<2.2.0 detected for running this kernel, \"\n \"which is known to have some weird compiler issues (refer to https://github.com/openai/triton/issues/2852) \"\n \"that lead to significant precision loss. \"\n \"We've add some initial condition checks to resolve this, sadly at the sacrifice of the speed. \"\n \"For optimal performance, it is recommended to install Triton>=2.2.0 (if possible).\"\n )\n CHECK = True\n\n grid = (NV, NK, batch_size * n_heads)\n fused_chunk_retention_fwd_kernel[grid](\n q, k, v, o, initial_state, final_state,\n q.stride(1), q.stride(2), q.stride(3),\n v.stride(1), v.stride(2), v.stride(3),\n batch_size, n_heads, seq_len, scale,\n BT=BT, DK=d_head_qk, DV=d_head_v, BK=BK, BV=BV,\n USE_INITIAL_STATE=initial_state is not None,\n STORE_FINAL_STATE=output_final_state,\n CHECK=CHECK,\n num_warps=num_warps,\n num_stages=num_stages\n )\n\n o = o.sum(0)\n ctx.save_for_backward(q, k, v, initial_state)\n ctx.CHECK = CHECK\n return o.to(q.dtype), final_state\n\n @staticmethod\n def backward(ctx, do, d_final_state=None):\n q, k, v, initial_state = ctx.saved_tensors\n batch_size, n_heads, seq_len, d_head_qk = q.shape\n d_head_v = v.shape[-1]\n scale = d_head_qk ** -0.5\n\n BT = 64\n BK, BV = min(triton.next_power_of_2(d_head_qk), 64), min(triton.next_power_of_2(d_head_v), 64)\n NK, NV = triton.cdiv(d_head_qk, BK), triton.cdiv(d_head_v, BV)\n num_stages = 1\n num_warps = 4\n\n dq = q.new_empty(NV, batch_size, n_heads, seq_len, d_head_qk)\n dk = q.new_empty(NV, batch_size, n_heads, seq_len, d_head_qk)\n dv = q.new_empty(NK, batch_size, n_heads, seq_len, d_head_v)\n grid = (NV, NK, batch_size * n_heads)\n\n fused_chunk_retention_bwd_kernel[grid](\n q, k, v, do, dq, dk, dv, initial_state,\n q.stride(1), q.stride(2), q.stride(3),\n v.stride(1), v.stride(2), v.stride(3),\n batch_size, n_heads, seq_len, scale,\n BT=BT, DK=d_head_qk, DV=d_head_v, BK=BK, BV=BV,\n USE_INITIAL_STATE=initial_state is not None,\n CHECK=ctx.CHECK,\n num_warps=num_warps,\n num_stages=num_stages\n )\n dq = dq.sum(0)\n dk = dk.sum(0)\n dv = dv.sum(0)\n return dq.to(q.dtype), dk.to(k.dtype), dv.to(v.dtype), None, None\n\n\ndef fused_chunk_retention(\n q: torch.Tensor,\n k: torch.Tensor,\n v: torch.Tensor,\n initial_state: torch.Tensor = None,\n output_final_state: bool = False\n) -> Tuple[torch.Tensor, torch.Tensor]:\n if initial_state is not None:\n initial_state = initial_state.detach()\n o, final_state = FusedChunkRetentionFunction.apply(q, k, v, initial_state, output_final_state)\n return o, final_state\n", - "description_1": "Use triton language to implement a fused chunk retention kernel with forward and backward operations. The forward kernel computes retention over queries, keys, and values, applies decay, and optionally uses initial and final states. It requires 21 parameters: query, key, value, output, initial state, final state, stride sizes for query/key, and value/output, batch size, number of heads, sequence length, scaling factor, block sizes along sequence, key, and value dimensions, dimension sizes for query/key and value, and flags for using initial states, storing final states, and enabling checks. The backward kernel computes gradients with respect to queries, keys, and values using similar parameters, requiring 22 in total.", - "description_2": "Use triton language to create a fused chunk retention kernel that processes sequences with optional initial and final states, supporting both forward and backward propagation.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\nfrom torch.cuda.amp import custom_bwd, custom_fwd\n\n@triton.jit\ndef parallel_retention_fwd_kernel(\n q, k, v, o,\n s_qk_h, s_qk_t, s_qk_d,\n s_vo_h, s_vo_t, s_vo_d,\n B, H, T, scale,\n BTL: tl.constexpr, BTS: tl.constexpr,\n BK: tl.constexpr, BV: tl.constexpr,\n DK: tl.constexpr, DV: tl.constexpr,\n):\n # Triton forward kernel implementation for parallel retention\n\n@triton.jit\ndef _parallel_retention_bwd_dq(\n i_bh, i_c, i_k, i_v, i_h,\n k, v, do, dq, s_qk_h, s_qk_t, s_qk_d, s_vo_h,\n s_vo_t, s_vo_d, B, H, T, scale,\n BTL: tl.constexpr, BTS: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr,\n DK: tl.constexpr, DV: tl.constexpr,\n):\n # Triton backward kernel implementation for dq in parallel retention\n\n@triton.jit\ndef _parallel_retention_bwd_dkv(\n i_bh, i_c, i_k, i_v, i_h,\n q, k, v, do, dk, dv, s_qk_h, s_qk_t, s_qk_d, s_vo_h,\n s_vo_t, s_vo_d, B, H, T, scale,\n BTL: tl.constexpr, BTS: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr,\n DK: tl.constexpr, DV: tl.constexpr,\n):\n # Triton backward kernel implementation for dk and dv in parallel retention\n\n@triton.jit\ndef parallel_retention_bwd_kernel(\n q, k, v, do, dq, dk, dv, s_qk_h, s_qk_t, s_qk_d, s_vo_h,\n s_vo_t, s_vo_d, B, H, T, scale,\n BTL: tl.constexpr, BTS: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr,\n DK: tl.constexpr, DV: tl.constexpr,\n):\n i_kv, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n NV = tl.cdiv(DV, BV)\n i_k = i_kv // (NV)\n i_v = i_kv % (NV)\n i_h = i_bh % H\n _parallel_retention_bwd_dq(\n i_bh, i_c, i_k, i_v, i_h,\n k, v, do, dq, s_qk_h, s_qk_t, s_qk_d, s_vo_h,\n s_vo_t, s_vo_d, B, H, T, scale, BTL=BTL, BTS=BTS, BK=BK, BV=BV, DK=DK, DV=DV\n )\n tl.debug_barrier()\n _parallel_retention_bwd_dkv(\n i_bh, i_c, i_k, i_v, i_h,\n q, k, v, do, dk, dv, s_qk_h, s_qk_t, s_qk_d, s_vo_h,\n s_vo_t, s_vo_d, B, H, T, scale, BTL, BTS, BK, BV, DK, DV\n )\n\nclass ParallelRetentionFunction(torch.autograd.Function):\n @staticmethod\n @custom_fwd\n def forward(ctx, q, k, v):\n BTL, BTS = 128, 32\n assert BTL % BTS == 0\n BK = min(128, triton.next_power_of_2(k.shape[-1]))\n BV = min(128, triton.next_power_of_2(v.shape[-1]))\n batch_size, n_heads, seq_len, d_head_qk = q.shape\n d_head_v = v.shape[-1]\n num_stages = 3 if d_head_qk <= 64 else 2\n num_warps = 4\n NK = triton.cdiv(d_head_qk, BK)\n NV = triton.cdiv(d_head_v, BV)\n\n grid = (NK * NV, triton.cdiv(seq_len, BTL), batch_size * n_heads)\n scale = d_head_qk ** -0.5\n o = torch.empty(NK, batch_size, n_heads, seq_len,\n d_head_v, dtype=q.dtype, device=q.device)\n parallel_retention_fwd_kernel[grid](\n q, k, v, o,\n q.stride(1), q.stride(2), q.stride(3),\n v.stride(1), v.stride(2), v.stride(3),\n batch_size, n_heads, seq_len, scale,\n BTL=BTL, BTS=BTS, BK=BK, BV=BV, DK=d_head_qk, DV=d_head_v,\n num_warps=num_warps,\n num_stages=num_stages\n )\n ctx.save_for_backward(q, k, v)\n return o.sum(0).to(q.dtype)\n\n @staticmethod\n @custom_bwd\n def backward(ctx, do):\n q, k, v = ctx.saved_tensors\n BTL, BTS = 64, 32\n assert BTL % BTS == 0\n BK = min(128, triton.next_power_of_2(k.shape[-1]))\n BV = min(128, triton.next_power_of_2(v.shape[-1]))\n batch_size, n_heads, seq_len, d_head_qk = q.shape\n d_head_v = v.shape[-1]\n num_stages = 3 if d_head_qk <= 64 else 2\n num_warps = 4\n NK = triton.cdiv(d_head_qk, BK)\n NV = triton.cdiv(d_head_v, BV)\n grid = (NK * NV, triton.cdiv(seq_len, BTL), batch_size * n_heads)\n scale = d_head_qk ** -0.5\n\n dq = torch.empty(NV, batch_size, n_heads, seq_len,\n d_head_qk, dtype=q.dtype, device=q.device)\n dk = torch.empty(NV, batch_size, n_heads, seq_len,\n d_head_qk, dtype=q.dtype, device=q.device)\n dv = torch.empty(NK, batch_size, n_heads, seq_len,\n d_head_v, dtype=q.dtype, device=q.device)\n\n parallel_retention_bwd_kernel[grid](\n q, k, v, do, dq, dk, dv,\n q.stride(1), q.stride(2), q.stride(3),\n v.stride(1), v.stride(2), v.stride(3),\n batch_size, n_heads, seq_len, scale,\n BTL=BTL, BTS=BTS, BK=BK, BV=BV, DK=d_head_qk, DV=d_head_v,\n num_warps=num_warps,\n num_stages=num_stages\n )\n\n return dq.sum(0).to(q.dtype), dk.sum(0).to(k.dtype), dv.sum(0).to(v.dtype)\n\nparallel_retention = ParallelRetentionFunction.apply\n", - "description_1": "Use triton language to implement a parallel retention mechanism with forward and backward passes. The forward pass involves calculating the output using queries, keys, and values with given strides, batch size, number of heads, sequence length, and scale. The backward pass computes gradients for queries, keys, and values using similar input parameters. The forward function has 3 input parameters: q (query), k (key), and v (value). The backward function has 1 input parameter: do (gradient of the output).", - "description_2": "Use triton language to implement a parallel retention forward kernel that computes the output from query, key, and value tensors. Use triton language to implement a backward kernel to compute gradients for input tensors.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\nfrom torch import Tensor\n\n@triton.jit\ndef fused_recurrent_rwkv4_forward_kernel(\n w_ptr, w_s_c, u_ptr, u_s_c, k_ptr, k_s_b, k_s_t, k_s_c, v_ptr, v_s_b, v_s_t, v_s_c,\n state_ptr, state_s_b, state_s_abe, state_s_c, wkv_ptr, wkv_s_b, wkv_s_t, wkv_s_c,\n state_out_ptr, state_out_s_b, state_out_s_abe, state_out_s_t, state_out_s_c,\n chans, tsz, BLOCK_SIZE_C: tl.constexpr,\n):\n b_idx = tl.program_id(0)\n c_idx = tl.program_id(1)\n\n cs = (c_idx * BLOCK_SIZE_C) + tl.arange(0, BLOCK_SIZE_C)\n cmask = cs < chans\n\n k_ptr = k_ptr + b_idx * k_s_b\n v_ptr = v_ptr + b_idx * v_s_b\n alpha_ptr = state_ptr + b_idx * state_s_b\n beta_ptr = state_ptr + b_idx * state_s_b + state_s_abe\n eps_ptr = state_ptr + b_idx * state_s_b + 2 * state_s_abe\n\n wkv_ptr = wkv_ptr + b_idx * wkv_s_b\n alpha_out_ptr = state_out_ptr + b_idx * state_out_s_b\n beta_out_ptr = state_out_ptr + b_idx * state_out_s_b + state_out_s_abe\n eps_out_ptr = state_out_ptr + b_idx * state_out_s_b + 2 * state_out_s_abe\n\n alpha = tl.load(alpha_ptr + cs * state_s_c, mask=cmask).to(tl.float32)\n beta = tl.load(beta_ptr + cs * state_s_c, mask=cmask).to(tl.float32)\n eps = tl.load(eps_ptr + cs * state_s_c, mask=cmask).to(tl.float32)\n w = tl.load(w_ptr + cs * w_s_c, mask=cmask).to(tl.float32)\n u = tl.load(u_ptr + cs * u_s_c, mask=cmask).to(tl.float32)\n\n for t in range(tsz):\n kt = tl.load(k_ptr + t * k_s_t + cs * k_s_c, mask=cmask).to(tl.float32)\n vt = tl.load(v_ptr + t * v_s_t + cs * v_s_c, mask=cmask).to(tl.float32)\n\n ukt = u + kt\n tau = tl.maximum(ukt, eps)\n e1a = tl.exp(eps - tau)\n e2a = tl.exp(ukt - tau)\n wkv = (e1a * alpha + e2a * vt) / (e1a * beta + e2a)\n tl.store(wkv_ptr + t * wkv_s_t + cs * wkv_s_c, wkv, mask=cmask)\n\n w_eps = w + eps\n eps = tl.maximum(w_eps, kt)\n e1b = tl.exp(w_eps - eps)\n e2b = tl.exp(kt - eps)\n alpha = e1b * alpha + e2b * vt\n beta = e1b * beta + e2b\n tl.store(alpha_out_ptr + t * state_out_s_t + cs * state_out_s_c, alpha, mask=cmask)\n tl.store(beta_out_ptr + t * state_out_s_t + cs * state_out_s_c, beta, mask=cmask)\n tl.store(eps_out_ptr + t * state_out_s_t + cs * state_out_s_c, eps, mask=cmask)\n\ndef fused_recurrent_rwkv4_forward(\n w: Tensor, u: Tensor, k: Tensor, v: Tensor, state: Tensor\n) -> tuple[Tensor, Tensor]:\n (bsz, tsz, chans) = k.shape\n\n wkvs = k.new_empty(bsz, tsz, chans)\n state_out = k.new_empty(bsz, 3, tsz, chans)\n\n block_size_c = get_block_size_c(chans)\n\n def grid(meta: dict[str, Any]) -> tuple[int, ...]:\n return (bsz, triton.cdiv(chans, meta[\"BLOCK_SIZE_C\"]))\n\n fused_recurrent_rwkv4_forward_kernel[grid](\n w, w.stride(0), u, u.stride(0), k, k.stride(0), k.stride(1), k.stride(2),\n v, v.stride(0), v.stride(1), v.stride(2), state, state.stride(0), state.stride(1),\n state.stride(3), wkvs, wkvs.stride(0), wkvs.stride(1), wkvs.stride(2),\n state_out, state_out.stride(0), state_out.stride(1), state_out.stride(2),\n state_out.stride(3), chans, tsz, BLOCK_SIZE_C=block_size_c,\n )\n\n state_out = torch.cat((state, state_out), dim=2)\n\n return wkvs, state_out\n\n@triton.jit\ndef fused_recurrent_rwkv4_backward_kernel(\n w_ptr, w_s_c, u_ptr, u_s_c, k_ptr, k_s_b, k_s_t, k_s_c, v_ptr, v_s_b, v_s_t, v_s_c,\n state_ptr, state_s_b, state_s_abe, state_s_t, state_s_c, gwkv_ptr, gwkv_s_b, gwkv_s_t,\n gwkv_s_c, gstate_out_ptr, gstate_out_s_b, gstate_out_s_abe, gstate_out_s_c, gw_ptr,\n gw_s_c, gu_ptr, gu_s_c, gk_ptr, gk_s_b, gk_s_t, gk_s_c, gv_ptr, gv_s_b, gv_s_t, gv_s_c,\n gstate_ptr, gstate_s_b, gstate_s_abe, gstate_s_c, tsz, chans, BLOCK_SIZE_C: tl.constexpr,\n):\n b_idx = tl.program_id(0)\n c_idx = tl.program_id(1)\n\n cs = (c_idx * BLOCK_SIZE_C) + tl.arange(0, BLOCK_SIZE_C)\n cmask = cs < chans\n\n k_ptr = k_ptr + b_idx * k_s_b\n v_ptr = v_ptr + b_idx * v_s_b\n alpha_ptr = state_ptr + b_idx * state_s_b\n beta_ptr = state_ptr + b_idx * state_s_b + state_s_abe\n eps_ptr = state_ptr + b_idx * state_s_b + 2 * state_s_abe\n\n gk_ptr = gk_ptr + b_idx * gk_s_b\n gv_ptr = gv_ptr + b_idx * gv_s_b\n\n gwkv_ptr = gwkv_ptr + b_idx * gwkv_s_b\n galpha_out_ptr = gstate_out_ptr + b_idx * gstate_out_s_b\n gbeta_out_ptr = gstate_out_ptr + b_idx * gstate_out_s_b + gstate_out_s_abe\n geps_out_ptr = gstate_out_ptr + b_idx * gstate_out_s_b + 2 * gstate_out_s_abe\n\n galpha = tl.load(galpha_out_ptr + gstate_out_s_c * cs, mask=cmask).to(tl.float32)\n gbeta = tl.load(gbeta_out_ptr + gstate_out_s_c * cs, mask=cmask).to(tl.float32)\n geps = tl.load(geps_out_ptr + gstate_out_s_c * cs, mask=cmask).to(tl.float32)\n w = tl.load(w_ptr + w_s_c * cs, mask=cmask).to(tl.float32)\n u = tl.load(u_ptr + u_s_c * cs, mask=cmask).to(tl.float32)\n\n gw = tl.zeros_like(w)\n gu = tl.zeros_like(u)\n\n alpha_prev = tl.load(alpha_ptr + tsz * state_s_t + state_s_c * cs, mask=cmask).to(tl.float32)\n beta_prev = tl.load(beta_ptr + tsz * state_s_t + state_s_c * cs, mask=cmask).to(tl.float32)\n eps_prev = tl.load(eps_ptr + tsz * state_s_t + state_s_c * cs, mask=cmask).to(tl.float32)\n\n for t in range(tsz):\n tc = tsz - t - 1\n\n kt = tl.load(k_ptr + tc * k_s_t + k_s_c * cs, mask=cmask).to(tl.float32)\n vt = tl.load(v_ptr + tc * v_s_t + v_s_c * cs, mask=cmask).to(tl.float32)\n\n alpha_curr = alpha_prev\n beta_curr = beta_prev\n eps_curr = eps_prev\n\n alpha_prev = tl.load(alpha_ptr + tc * state_s_t + state_s_c * cs, mask=cmask).to(tl.float32)\n beta_prev = tl.load(beta_ptr + tc * state_s_t + state_s_c * cs, mask=cmask).to(tl.float32)\n eps_prev = tl.load(eps_ptr + tc * state_s_t + state_s_c * cs, mask=cmask).to(tl.float32)\n\n ukt = u + kt\n tau = tl.maximum(ukt, eps_prev)\n e1 = tl.exp(eps_prev - tau)\n e2 = tl.exp(ukt - tau)\n\n euke = tl.exp(ukt + eps_prev - 2 * tau)\n\n denom = e1 * beta_prev + e2\n denom_sq = denom * denom\n\n gwkvt = tl.load(gwkv_ptr + tc * gwkv_s_t + gwkv_s_c * cs, mask=cmask).to(tl.float32)\n\n guk = gwkvt * e2 * (e1 * beta_prev * vt - e1 * alpha_prev) / denom_sq\n gu += guk\n gk = guk\n gv = gwkvt * e2 / denom\n\n galpha_wkv = gwkvt * e1 / denom\n gbeta_wkv = -gwkvt * e1 * (e2 * vt + e1 * alpha_prev) / denom_sq\n geps_wkv_denom = e1 * beta_prev + e2\n geps_wkv = gwkvt * euke * (alpha_prev - vt * beta_prev) / (geps_wkv_denom * geps_wkv_denom)\n\n e1 = tl.exp(w + eps_prev - eps_curr)\n e2 = tl.exp(kt - eps_curr)\n\n galpha_we = galpha * e1 * alpha_prev\n gw += galpha_we\n gk += galpha * e2 * vt\n gv += galpha * e2\n geps += galpha * -alpha_curr\n\n gbeta_we = gbeta * e1 * beta_prev\n gw += gbeta_we\n gk += gbeta * e2\n geps += gbeta * -beta_curr\n\n geps_mask = w + eps_prev > kt\n geps_we = tl.where(geps_mask, geps, tl.zeros_like(geps))\n gw += geps_we\n gk += tl.where(geps_mask, tl.zeros_like(geps), geps)\n\n tl.store(gk_ptr + tc * gk_s_t + gk_s_c * cs, gk, mask=cmask)\n tl.store(gv_ptr + tc * gv_s_t + gv_s_c * cs, gv, mask=cmask)\n\n galpha = galpha * e1 + galpha_wkv\n gbeta = gbeta * e1 + gbeta_wkv\n geps = galpha_we + gbeta_we + geps_we + geps_wkv\n\n galpha_ptr = gstate_ptr + b_idx * gstate_s_b\n gbeta_ptr = gstate_ptr + b_idx * gstate_s_b + gstate_s_abe\n geps_ptr = gstate_ptr + b_idx * gstate_s_b + 2 * gstate_s_abe\n tl.store(galpha_ptr + gstate_s_c * cs, galpha, mask=cmask)\n tl.store(gbeta_ptr + gstate_s_c * cs, gbeta, mask=cmask)\n tl.store(geps_ptr + gstate_s_c * cs, geps, mask=cmask)\n\n gw_temp = tl.load(gw_ptr + gw_s_c * cs, mask=cmask).to(tl.float32)\n gw_temp += gw\n tl.store(gw_ptr + gw_s_c * cs, gw_temp, mask=cmask)\n gu_temp = tl.load(gu_ptr + gu_s_c * cs, mask=cmask).to(tl.float32)\n gu_temp += gu\n tl.store(gu_ptr + gu_s_c * cs, gu_temp, mask=cmask)\n\ndef fused_recurrent_rwkv4_backward(\n w: Tensor, u: Tensor, k: Tensor, v: Tensor, state: Tensor, grad_wkv: Tensor, grad_state: Tensor\n) -> tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:\n bsz, tsz, chans = k.shape\n\n gw = torch.zeros_like(w)\n gu = torch.zeros_like(u)\n gk = torch.empty_like(k)\n gv = torch.empty_like(v)\n gstate = k.new_empty(bsz, 3, 1, chans)\n\n block_size_c = get_block_size_c(chans)\n\n def grid(meta: dict[str, Any]) -> tuple[int, ...]:\n return (bsz, triton.cdiv(chans, meta[\"BLOCK_SIZE_C\"]))\n\n fused_recurrent_rwkv4_backward_kernel[grid](\n w, w.stride(0), u, u.stride(0), k, k.stride(0), k.stride(1), k.stride(2),\n v, v.stride(0), v.stride(1), v.stride(2), state, state.stride(0), state.stride(1),\n state.stride(2), state.stride(3), grad_wkv, grad_wkv.stride(0), grad_wkv.stride(1),\n grad_wkv.stride(2), grad_state, grad_state.stride(0), grad_state.stride(1),\n grad_state.stride(3), gw, gw.stride(0), gu, gu.stride(0), gk, gk.stride(0),\n gk.stride(1), gk.stride(2), gv, gv.stride(0), gv.stride(1), gv.stride(2),\n gstate, gstate.stride(0), gstate.stride(1), gstate.stride(3), tsz, chans,\n BLOCK_SIZE_C=block_size_c,\n )\n\n return gw, gu, gk, gv, gstate\n", - "description_1": "Use triton language to implement a fused recurrent RWKV forward and backward kernel. The forward kernel takes 26 parameters: pointers to input tensors (w, u, k, v, state), strides for these tensors, pointers to output tensors (wkv, state_out), strides for output tensors, and constants (chans, tsz, BLOCK_SIZE_C). It computes the RWKV forward pass by iterating over the time dimension and updating the state. The backward kernel takes 41 parameters: pointers to input tensors, strides, pointers to gradient tensors, strides, and constants. It computes the gradients for the RWKV backward pass by iterating in reverse over the time dimension.", - "description_2": "Use triton language to create a fused recurrent RWKV kernel for forward and backward passes, handling input and output tensor pointers, strides, and constants for efficient computation.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\nfrom fla.ops.utils import chunk_reversed_cumsum_fwd\nfrom fla.utils import contiguous\n\n\n@triton.autotune(\n configs=[\n triton.Config({'BS': 16}, num_warps=2),\n triton.Config({'BS': 16}, num_warps=4),\n triton.Config({'BS': 16}, num_warps=8),\n triton.Config({'BS': 32}, num_warps=2),\n triton.Config({'BS': 32}, num_warps=4),\n triton.Config({'BS': 32}, num_warps=8),\n triton.Config({'BS': 64}, num_warps=2),\n triton.Config({'BS': 64}, num_warps=4),\n triton.Config({'BS': 64}, num_warps=8),\n ],\n key=['S']\n)\n@triton.jit\ndef chunk_rwkv6_fwd_kernel_cum(\n s,\n o,\n o_minus_s,\n s_s_h,\n s_s_t,\n s_s_d,\n T: tl.constexpr,\n S: tl.constexpr,\n BT: tl.constexpr,\n BS: tl.constexpr\n):\n i_s, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n o_i = tl.arange(0, BT)\n m_s = tl.where(o_i[:, None] >= o_i[None, :], 1., 0.)\n\n p_s = tl.make_block_ptr(s + i_bh * s_s_h, (T, S), (s_s_t, s_s_d), (i_t * BT, i_s * BS), (BT, BS), (1, 0))\n p_o = tl.make_block_ptr(o + i_bh * s_s_h, (T, S), (s_s_t, s_s_d), (i_t * BT, i_s * BS), (BT, BS), (1, 0))\n p_o_minus_s = tl.make_block_ptr(o_minus_s + i_bh * s_s_h, (T, S), (s_s_t, s_s_d), (i_t * BT, i_s * BS), (BT, BS), (1, 0))\n # [BT, BS]\n b_s = tl.load(p_s, boundary_check=(0, 1)).to(tl.float32)\n b_o = tl.dot(m_s, b_s, allow_tf32=False)\n tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))\n tl.store(p_o_minus_s, (b_o - b_s).to(p_o_minus_s.dtype.element_ty), boundary_check=(0, 1))\n\n\n@triton.jit\ndef post_process_grad(\n q,\n k,\n v,\n u,\n do,\n dk,\n dq,\n du,\n scale,\n s_k_h,\n s_k_t,\n s_k_d,\n s_v_h,\n s_v_t,\n s_v_d,\n H,\n T: tl.constexpr,\n BT: tl.constexpr,\n K: tl.constexpr,\n V: tl.constexpr,\n BK: tl.constexpr,\n BV: tl.constexpr,\n):\n i_t, i_bh = tl.program_id(0), tl.program_id(1)\n i_h = i_bh % H\n\n # Note that BK = tl.next_power_of_2(K), BV = tl.next_power_of_2(V)\n p_q = tl.make_block_ptr(q + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, 0), (BT, BK), (1, 0))\n p_dq = tl.make_block_ptr(dq + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, 0), (BT, BK), (1, 0))\n p_k = tl.make_block_ptr(k + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, 0), (BT, BK), (1, 0))\n p_dk = tl.make_block_ptr(dk + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, 0), (BT, BK), (1, 0))\n p_du = tl.make_block_ptr(du + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, 0), (BT, BK), (1, 0))\n p_v = tl.make_block_ptr(v + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, 0), (BT, BV), (1, 0))\n p_do = tl.make_block_ptr(do + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, 0), (BT, BV), (1, 0))\n p_u = tl.make_block_ptr(u + i_h * K, (K,), (1,), (0,), (BK,), (0,))\n\n b_q = tl.load(p_q, boundary_check=(0, 1))\n b_k = tl.load(p_k, boundary_check=(0, 1))\n b_v = tl.load(p_v, boundary_check=(0, 1))\n b_do = tl.load(p_do, boundary_check=(0, 1))\n b_u = tl.load(p_u, boundary_check=(0,))\n\n b_vdo = tl.sum(b_v * b_do, axis=1)\n b_du = b_vdo[:, None] * b_k * b_q * scale\n b_dq = b_vdo[:, None] * b_k * b_u[None, :] * scale\n b_dk = b_vdo[:, None] * b_q * b_u[None, :] * scale\n\n b_dq += tl.load(p_dq, boundary_check=(0, 1))\n tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))\n\n b_dk += tl.load(p_dk, boundary_check=(0, 1))\n tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))\n\n tl.store(p_du, b_du.to(p_du.dtype.element_ty), boundary_check=(0, 1))\n\n\nclass ChunkRWKV6Function(torch.autograd.Function):\n\n @staticmethod\n @contiguous\n def forward(ctx, r, k, v, g, u, scale, initial_state, output_final_state, checkpoint_level):\n q = r # alias\n B, H, T, K, V = *q.shape, v.shape[-1]\n BT, BC = 64, 16\n BK = min(64, triton.next_power_of_2(K))\n BV = min(64, triton.next_power_of_2(V))\n NT, NC = triton.cdiv(T, BT), triton.cdiv(BT, BC)\n NK = triton.cdiv(K, BK)\n NV = triton.cdiv(V, BV)\n num_warps = 4 if BK == 64 else 2\n num_stages = 1\n\n def fwd_inner(q, k, v, g, B, H, T, K, V, BT, BK, BV, NT, h0=None, ht=None):\n NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)\n h = q.new_empty(B, H, NT * K, V)\n grid = (NV, NK, B * H)\n chunk_rwkv6_fwd_kernel_h[grid](\n k, v, g, h, h0, ht,\n k.stride(1), k.stride(2), k.stride(3),\n v.stride(1), v.stride(2), v.stride(3),\n h.stride(1), h.stride(2), h.stride(3),\n T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT,\n USE_INITIAL_STATE=h0 is not None,\n STORE_FINAL_STATE=ht is not None,\n num_warps=num_warps,\n num_stages=num_stages\n )\n return h\n\n final_state = None\n if output_final_state:\n final_state = q.new_empty(B, H, K, V, dtype=torch.float)\n\n g_org, g, gs = g, torch.empty_like(g, dtype=torch.float), torch.empty_like(g, dtype=torch.float)\n def grid(meta): return ((triton.cdiv(meta['S'], meta['BS']), NT, B * H))\n # keep cummulative normalizer in fp32\n # this kernel is equivalent to\n # g_org = g_org.view(B, H, NT, BT, -1)\n # g = g_org.cumsum(-2).view(B, H, T, -1)\n # gs = g - g_org\n chunk_rwkv6_fwd_kernel_cum[grid](\n g_org, g, gs,\n g.stride(1), g.stride(2), g.stride(3),\n T=T, S=K, BT=BT\n )\n h = fwd_inner(\n q=q, k=k, v=v, g=g,\n B=B, H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT,\n h0=initial_state if initial_state is not None else None,\n ht=final_state if final_state is not None else None\n )\n A = q.new_zeros(NK, B, H, T, BT)\n grid = (NK, NT * NC * NC, B * H)\n chunk_rwkv6_fwd_kernel_intra[grid](\n q, k, g, gs, u, A,\n k.stride(1), k.stride(2), k.stride(3),\n scale,\n H=H, T=T, K=K, BT=BT, BC=BC, BK=BK, NC=NC, DK=K,\n num_warps=num_warps,\n num_stages=num_stages\n )\n A = A.sum(0, dtype=A.dtype)\n o = torch.empty_like(v)\n\n grid = (NV, NT, B * H)\n chunk_rwkv6_fwd_kernel_inter[grid](\n q, v, gs, h, o, A,\n k.stride(1), k.stride(2), k.stride(3),\n v.stride(1), v.stride(2), v.stride(3),\n h.stride(1), h.stride(2), h.stride(3),\n scale,\n T=T, K=K, V=V, BT=BT, BK=BK, BV=BV,\n num_warps=num_warps,\n num_stages=num_stages\n )\n\n if checkpoint_level > 1:\n del h\n h, initial_state = None, None\n del g, gs\n ctx.save_for_backward(q, k, v, g_org, u, h, initial_state, A)\n ctx.BT = BT\n ctx.scale = scale\n ctx.checkpoint_level = checkpoint_level\n return o, final_state\n\n @staticmethod\n @contiguous\n def backward(ctx, do, dht=None):\n q, k, v, g, u, h, initial_state, A = ctx.saved_tensors\n B, H, T, K, V = *q.shape, v.shape[-1]\n BT, BC = ctx.BT, 16\n BK = min(64, triton.next_power_of_2(K))\n BV = min(64, triton.next_power_of_2(V))\n NT, NC = triton.cdiv(T, BT), triton.cdiv(BT, BC)\n NK = triton.cdiv(K, BK)\n num_warps = 4 if BK == 64 else 2\n num_stages = 1\n\n def fwd_inner(q, k, v, g, B, H, T, K, V, BT, BK, BV, NT, h0=None, ht=None):\n NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)\n h = q.new_empty(B, H, NT * K, V)\n grid = (NV, NK, B * H)\n chunk_rwkv6_fwd_kernel_h[grid](\n k, v, g, h, h0, ht,\n k.stride(1), k.stride(2), k.stride(3),\n v.stride(1), v.stride(2), v.stride(3),\n h.stride(1), h.stride(2), h.stride(3),\n T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT,\n USE_INITIAL_STATE=h0 is not None,\n STORE_FINAL_STATE=ht is not None,\n num_warps=num_warps,\n num_stages=num_stages\n )\n return h\n\n def bwd_inner(q, g, gs, h0, do, B, H, T, K, V, BT, BK, BV, NT, scale):\n NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)\n dh = q.new_empty(B, H, NT * K, V)\n dh0 = torch.empty_like(h0) if h0 is not None else None\n grid = (NK, NV, B * H)\n chunk_rwkv6_bwd_kernel_dh[grid](\n q, g, gs, do, dh, dh0,\n q.stride(1), q.stride(2), q.stride(3),\n do.stride(1), do.stride(2), do.stride(3),\n dh.stride(1), dh.stride(2), dh.stride(3),\n scale,\n T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT,\n USE_INITIAL_STATE=h0 is not None,\n num_warps=num_warps,\n num_stages=num_stages\n )\n return dh, dh0\n\n # recompute cumulative log decays.\n g_org, g, gs = g, torch.empty_like(g, dtype=torch.float), torch.empty_like(g, dtype=torch.float)\n def grid(meta): return ((triton.cdiv(meta['S'], meta['BS']), NT, B * H))\n # keep cummulative normalizer in fp32\n # this kernel is equivalent to\n # g = g.view(B, H, NT, BT, -1).cumsum(-2).view(B, H, T, -1)\n chunk_rwkv6_fwd_kernel_cum[grid](\n g_org, g, gs,\n g.stride(1), g.stride(2), g.stride(3),\n T=T, S=K, BT=BT\n )\n\n # rerun the forward pass to get h if checkpoint_level >= 1\n if ctx.checkpoint_level == 1:\n h = fwd_inner(\n q=q, k=k, v=v, g=g,\n B=B, H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT,\n h0=initial_state if initial_state is not None else None,\n ht=None\n )\n\n scale = ctx.scale\n dh, dh0 = bwd_inner(\n q, g, gs, initial_state, do,\n B=B, H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT,\n scale=scale\n )\n dq = torch.empty_like(q, dtype=torch.float)\n dk = torch.empty_like(k, dtype=torch.float)\n dv = v.new_empty(NK, *v.shape)\n dA = q.new_zeros(B, H, T, BT)\n grid = (NK, NT, B * H)\n chunk_rwkv6_bwd_kernel_inter[grid](\n k, v, h, g, gs, A, do, dh, dq, dk, dv, dA,\n k.stride(1), k.stride(2), k.stride(3),\n v.stride(1), v.stride(2), v.stride(3),\n h.stride(1), h.stride(2), h.stride(3),\n scale,\n T=T, K=K, V=V, BT=BT, BK=BK, BV=BV,\n num_warps=num_warps,\n num_stages=num_stages\n )\n dv = dv.sum(0, dtype=dv.dtype)\n grid = (NK, NT * NC, B * H)\n chunk_rwkv6_bwd_kernel_intra[grid](\n q, k, g, gs, dA, dq, dk,\n k.stride(1), k.stride(2), k.stride(3),\n T=T, K=K, BT=BT, BC=BC, BK=BK, NC=NC,\n num_warps=num_warps,\n num_stages=num_stages\n )\n\n # TODO: fuse?\n dg = (dq * q)[:, :, 1:] - (dk * k)[:, :, 0:-1]\n dg = torch.nn.functional.pad(dg, (0, 0, 0, 1, 0, 0, 0, 0), value=0)\n dg = chunk_reversed_cumsum_fwd(dg).to(g)\n # equivalent to the following pytorch code.\n # du = ((do * v).sum(-1)[..., None] * k * q * scale).sum(-2).to(u)\n # dq += ((do * v).sum(-1)[..., None] * k * scale * u[:, :, None, :])\n # dk += ((do * v).sum(-1)[..., None] * q * scale * u[:, :, None, :])\n BT = 64\n grid = (triton.cdiv(T, BT), B * H)\n du = torch.empty_like(g, dtype=torch.float)\n post_process_grad[grid](\n q, k, v, u, do, dk, dq, du, scale,\n q.stride(1), q.stride(2), q.stride(3),\n v.stride(1), v.stride(2), v.stride(3), H=H,\n T=T, BT=BT, K=K, V=V, BK=triton.next_power_of_2(K), BV=triton.next_power_of_2(V),\n num_warps=4\n )\n du = du.sum([0, 2])\n return dq.to(q), dk.to(k), dv.to(v), dg.to(g), du.to(u), None, dh0, None, None\n\n\ndef chunk_rwkv6(\n r: torch.Tensor,\n k: torch.Tensor,\n v: torch.Tensor,\n g: torch.Tensor,\n u: torch.Tensor,\n scale: Optional[int] = None,\n initial_state: torch.Tensor = None,\n output_final_state: bool = False,\n checkpoint_level: Optional[int] = 0\n) -> Tuple[torch.Tensor, torch.Tensor]:\n r\"\"\"\n Args:\n r (torch.Tensor):\n reception of shape `(B, H, T, K)`. Alias: q, query in linear attention.\n k (torch.Tensor):\n keys of shape `(B, H, T, K)`\n v (torch.Tensor):\n values of shape `(B, H, T, V)`\n w (torch.Tensor):\n data-dependent decays of shape `(B, H, T, K)` in log space! Alias: g.\n u (torch.Tensor):\n bonus of shape `(H, K)`\n scale (Optional[int]):\n Scale factor for the RWKV6 attention scores.\n If not provided, it will default to `1 / sqrt(K)`. Default: `None`.\n initial_state (Optional[torch.Tensor]):\n Initial state of shape `(B, H, K, V)`. Default: `None`.\n output_final_state (Optional[bool]):\n Whether to output the final state of shape `(B, H, K, V)`. Default: `False`.\n checkpoint_level (Optional[int]):\n Checkpointing level; higher values will save more memories and do more recomputations during backward.\n Default: `0`:\n - Level `0`: store forward hidden states for backprop.\n - Level `1`: recompute the forward hidden states during backward.\n \"\"\"\n assert checkpoint_level in [0, 1]\n if scale is None:\n scale = r.shape[-1] ** -0.5\n o, final_state = ChunkRWKV6Function.apply(r, k, v, g, u, scale, initial_state, output_final_state, checkpoint_level)\n return o, final_state\n", - "description_1": "Use triton language to implement a series of kernels for efficient forward and backward computation in a RWKV model. The kernels are tailored for specific tensor shapes and leverage the GPU's parallelism. They handle cumulative summation, gradient post-processing, and intra/inter-block operations. The function `chunk_rwkv6` acts as a wrapper for these kernels, allowing for both forward and backward passes with optional checkpointing for memory efficiency.", - "description_2": "Use triton language to create efficient forward and backward kernels for a RWKV model. Implement cumulative summation, gradient post-processing, and intra/inter-block operations for specific tensor shapes, optimizing GPU parallelism.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\nfrom torch.cuda.amp import custom_bwd, custom_fwd\nfrom fla.ops.utils import chunk_reversed_cumsum_fwd\nfrom fla.utils import contiguous\n\n@triton.jit\ndef fused_recurrent_rwkv6_fwd_kernel(\n q, k, v, w, u, o, h0, ht, s_k_h, s_v_h, scale,\n B: tl.constexpr, H: tl.constexpr, T: tl.constexpr, K: tl.constexpr, V: tl.constexpr,\n BK: tl.constexpr, BV: tl.constexpr, USE_INITIAL_STATE: tl.constexpr,\n STORE_FINAL_STATE: tl.constexpr, REVERSE: tl.constexpr\n):\n # Triton kernel logic...\n pass\n\n@triton.jit\ndef fused_recurrent_rwkv6_bwd_kernel_dq(\n k, v, w, u, do, dq, dq_aux, h0, s_k_h, s_v_h, scale,\n B: tl.constexpr, H: tl.constexpr, T: tl.constexpr, BK: tl.constexpr,\n BV: tl.constexpr, K: tl.constexpr, V: tl.constexpr, USE_INITIAL_STATE: tl.constexpr,\n REVERSE: tl.constexpr\n):\n # Triton kernel logic...\n pass\n\n@triton.jit\ndef fused_recurrent_rwkv6_bwd_kernel_dkv(\n q, k, v, w, u, do, dk, dk_aux, dv, dh0, s_k_h, s_v_h, scale,\n B, H, T, BK: tl.constexpr, BV: tl.constexpr, K: tl.constexpr, V: tl.constexpr,\n USE_INITIAL_STATE: tl.constexpr, REVERSE: tl.constexpr\n):\n # Triton kernel logic...\n pass\n\nclass FusedRecurrentRWKV6Function(torch.autograd.Function):\n\n @staticmethod\n @contiguous\n @custom_fwd\n def forward(ctx, r, k, v, w, u, scale=None, initial_state=None, output_final_state=False, reverse=False):\n # alias\n q = r\n B, H, T, K, V = *q.shape, v.shape[-1]\n\n BK, BV = min(triton.next_power_of_2(K), 32), min(triton.next_power_of_2(V), 32)\n NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)\n num_stages = 1\n num_warps = 1\n\n if output_final_state:\n final_state = q.new_empty(B, H, K, V)\n else:\n final_state = None\n\n o = q.new_empty(NK, B, H, T, V, dtype=torch.float32)\n grid = (NV, NK, B * H)\n fused_recurrent_rwkv6_fwd_kernel[grid](\n q, k, v, w, u, o, initial_state, final_state,\n k.stride(1),\n v.stride(1),\n scale,\n B=B, H=H, T=T, K=K, V=V, BK=BK, BV=BV,\n USE_INITIAL_STATE=initial_state is not None,\n STORE_FINAL_STATE=final_state is not None,\n REVERSE=reverse,\n num_warps=num_warps,\n num_stages=num_stages\n )\n\n o = o.sum(0)\n ctx.save_for_backward(q, k, v, w, u, initial_state, o)\n ctx.scale = scale\n ctx.reverse = reverse\n # we do not need the gradient of the final state from the next chunk\n # similiar to Trunctated BPTT\n if final_state is not None:\n final_state = final_state.detach()\n return o.to(q.dtype), final_state\n\n @staticmethod\n @contiguous\n @custom_bwd\n def backward(ctx, do, d_final_state=None):\n q, k, v, w, u, initial_state, o = ctx.saved_tensors\n B, H, T, K, V = *q.shape, v.shape[-1]\n scale = ctx.scale\n\n BK, BV = min(triton.next_power_of_2(K), 16), min(triton.next_power_of_2(V), 64)\n NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)\n num_stages = 1\n num_warps = 1\n dq = q.new_empty(NV, B, H, T, K, dtype=torch.float32)\n dq_aux = torch.empty_like(dq)\n grid = (NV, NK, B * H)\n\n fused_recurrent_rwkv6_bwd_kernel_dq[grid](\n k, v, w, u, do, dq, dq_aux, initial_state,\n q.stride(1),\n v.stride(1),\n scale,\n B=B, H=H, T=T, K=K, V=V, BK=BK, BV=BV,\n num_warps=num_warps,\n num_stages=num_stages,\n USE_INITIAL_STATE=initial_state is not None,\n REVERSE=ctx.reverse,\n )\n dq = dq.sum(0).to(q)\n dq_aux = dq_aux.sum(0)\n\n BK, BV = min(triton.next_power_of_2(K), 32), min(triton.next_power_of_2(V), 32)\n NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)\n\n dk = q.new_empty(NV, B, H, T, K, dtype=torch.float32)\n dk_aux = q.new_empty(NV, B, H, T, K, dtype=torch.float32)\n dv = q.new_empty(NK, B, H, T, V, dtype=torch.float32)\n dh0 = initial_state.new_empty(B, H, K, V) if initial_state is not None else None\n grid = (NV, NK, B * H)\n fused_recurrent_rwkv6_bwd_kernel_dkv[grid](\n q, k, v, w, u, do, dk, dk_aux, dv, dh0,\n q.stride(1),\n v.stride(1),\n scale,\n B=B, H=H, T=T, K=K, V=V, BK=BK, BV=BV,\n num_warps=num_warps,\n num_stages=num_stages,\n USE_INITIAL_STATE=initial_state is not None,\n REVERSE=ctx.reverse,\n )\n dk = dk.sum(0).to(k)\n dv = dv.sum(0).to(v)\n dk_aux = dk_aux.sum(0)\n\n dw = (dq_aux * q * scale)[:, :, 1:] - (dk_aux * k)[:, :, 0:-1]\n dw = torch.nn.functional.pad(dw, (0, 0, 0, 1, 0, 0, 0, 0), value=0)\n dw = chunk_reversed_cumsum_fwd(dw).to(w)\n\n du = ((do * v).sum(-1)[..., None] * k * q * scale).sum([0, -2]).to(u)\n return dq, dk, dv, dw, du, None, dh0, None, None\n\ndef fused_recurrent_rwkv6(\n r: torch.Tensor,\n k: torch.Tensor,\n v: torch.Tensor,\n w: torch.Tensor,\n u: torch.Tensor,\n scale: int = -1,\n initial_state: torch.Tensor = None,\n output_final_state: bool = False,\n causal: bool = True\n) -> Tuple[torch.Tensor, torch.Tensor]:\n r\"\"\"\n Args:\n r (torch.Tensor): reception of shape `(B, H, T, K)`. Alias: q, query in linear attention.\n k (torch.Tensor): keys of shape `(B, H, T, K)`\n v (torch.Tensor): values of shape `(B, H, T, V)`\n w (torch.Tensor): data-dependent decays of shape `(B, H, T, K)` in log space! Alias: g.\n u (torch.Tensor): bonus of shape `(H, K)`\n scale (Optional[int]): Scale factor for the RWKV6 attention scores. If not provided, it will default to `1 / sqrt(K)`. Default: `None`.\n initial_state (Optional[torch.Tensor]): Initial state of shape `(B, H, K, V)`. Default: `None`.\n output_final_state (Optional[bool]): Whether to output the final state of shape `(B, H, K, V)`. Default: `False`.\n \"\"\"\n if scale == -1:\n scale = r.shape[-1] ** -0.5\n o, final_state = FusedRecurrentRWKV6Function.apply(r, k, v, w, u, scale, initial_state, output_final_state)\n return o, final_state\n", - "description_1": "Use triton language to create forward and backward kernels for a custom recurrent attention mechanism with Triton, implementing both forward and backward passes. The forward kernel calculates output tensors by performing operations on input tensors like query, key, value, and others based on specified dimensions and control flags. The backward kernels compute gradients for the inputs based on gradients of the outputs, using stored states and control flags. The code includes an autograd function for PyTorch, wrapping the kernels for gradient computation.", - "description_2": "Use triton language to implement a custom recurrent attention mechanism, providing forward and backward operations for efficient computation on GPUs using Triton kernels, interfaced through PyTorch autograd.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\nfrom torch.cuda.amp import custom_bwd, custom_fwd\n\n@triton.jit\ndef chunk_simple_gla_fwd_kernel_h(\n k, v, h, g, initial_state, final_state,\n s_qk_h, s_qk_t, s_qk_d, s_vo_h, s_vo_t, s_vo_d, s_h_h, s_h_t,\n H: tl.constexpr, T: tl.constexpr, K: tl.constexpr, V: tl.constexpr, BT: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr, NT: tl.constexpr,\n USE_INITIAL_STATE: tl.constexpr, STORE_FINAL_STATE: tl.constexpr\n):\n i_k, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n b_h = tl.zeros([BK, BV], dtype=tl.float32)\n\n if USE_INITIAL_STATE:\n p_h0 = tl.make_block_ptr(initial_state + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))\n b_h = tl.load(p_h0, boundary_check=(0, 1)).to(tl.float32)\n\n for i_t in range(NT):\n p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1))\n p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))\n p_h = tl.make_block_ptr(h + i_bh * s_h_h + i_t * K * V, (K, V), (s_h_t, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))\n \n tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1))\n b_k = tl.load(p_k, boundary_check=(0, 1))\n b_v = tl.load(p_v, boundary_check=(0, 1))\n b_g_last = tl.load(g + i_bh * T + i_t * BT + BT - 1)\n b_h *= tl.math.exp2(b_g_last)\n b_g = tl.load(g + i_bh * T + i_t * BT + tl.arange(0, BT))\n b_h += tl.dot(b_k, (b_v * tl.math.exp2(b_g_last - b_g)[:, None]).to(b_k.dtype), allow_tf32=False)\n\n if STORE_FINAL_STATE:\n p_ht = tl.make_block_ptr(final_state + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))\n tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), boundary_check=(0, 1))\n\n@triton.jit\ndef chunk_simple_gla_fwd_kernel_o(\n q, k, v, h, g, o, s_qk_h, s_qk_t, s_qk_d, s_vo_h, s_vo_t, s_vo_d, s_h_h, s_h_t, scale,\n H: tl.constexpr, T: tl.constexpr, K: tl.constexpr, V: tl.constexpr, BT: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr\n):\n i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n o_i = tl.arange(0, BT)\n m_s = o_i[:, None] >= o_i[None, :]\n\n b_o = tl.zeros([BT, BV], dtype=tl.float32)\n b_s = tl.zeros([BT, BT], dtype=tl.float32)\n for i_k in range(tl.cdiv(K, BK)):\n p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))\n p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1))\n p_h = tl.make_block_ptr(h + i_bh * s_h_h + i_t * K * V, (K, V), (s_h_t, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))\n\n b_q = tl.load(p_q, boundary_check=(0, 1))\n b_k = tl.load(p_k, boundary_check=(0, 1))\n b_h = tl.load(p_h, boundary_check=(0, 1))\n b_o += tl.dot(b_q, b_h, allow_tf32=False)\n b_s += tl.dot(b_q, b_k, allow_tf32=False)\n\n p_g = g + i_bh * T + i_t * BT + tl.arange(0, BT)\n b_g = tl.load(p_g)\n b_o = b_o * tl.math.exp2(b_g)[:, None]\n b_s = b_s * tl.math.exp2(b_g[:, None] - b_g[None, :])\n b_s = tl.where(m_s, b_s, 0)\n\n p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))\n b_v = tl.load(p_v, boundary_check=(0, 1))\n b_o = (b_o + tl.dot(b_s.to(b_v.dtype), b_v, allow_tf32=False)) * scale\n p_o = tl.make_block_ptr(o + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))\n tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))\n\n@triton.jit\ndef chunk_simple_gla_bwd_kernel_dh(\n q, g, do, dh, s_qk_h, s_qk_t, s_qk_d, s_vo_h, s_vo_t, s_vo_d, s_h_h, s_h_t, scale,\n H: tl.constexpr, T: tl.constexpr, K: tl.constexpr, V: tl.constexpr, BT: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr, NT: tl.constexpr\n):\n i_k, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n b_dh = tl.zeros([BK, BV], dtype=tl.float32)\n for i_t in range(NT - 1, -1, -1):\n p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1))\n p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))\n p_dh = tl.make_block_ptr(dh + i_bh * s_h_h + i_t * K * V, (K, V), (s_h_t, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))\n\n tl.store(p_dh, b_dh.to(p_dh.dtype.element_ty), boundary_check=(0, 1))\n b_q = tl.load(p_q, boundary_check=(0, 1))\n b_q = (b_q * scale * tl.math.exp2(tl.load(g + i_bh * T + i_t * BT + tl.arange(0, BT)))[None, :]).to(b_q.dtype)\n b_do = tl.load(p_do, boundary_check=(0, 1))\n b_dh *= tl.math.exp2(tl.load(g + i_bh * T + i_t * BT + BT - 1))\n b_dh += tl.dot(b_q, b_do.to(b_q.dtype), allow_tf32=False)\n\n@triton.jit\ndef chunk_simple_gla_bwd_kernel_dqkv(\n q, k, v, h, g, do, dh, dq, dk, dv, s_qk_h, s_qk_t, s_qk_d, s_vo_h, s_vo_t, s_vo_d, s_h_h, s_h_t, scale,\n B: tl.constexpr, H: tl.constexpr, T: tl.constexpr, K: tl.constexpr, V: tl.constexpr, BT: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr, NT: tl.constexpr\n):\n i_k, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n n_bh = tl.num_programs(2)\n o_i = tl.arange(0, BT)\n\n p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1))\n p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))\n\n b_q = tl.load(p_q, boundary_check=(0, 1))\n b_k = tl.load(p_k, boundary_check=(0, 1))\n b_s = tl.dot(b_k, b_q, allow_tf32=False)\n p_g = g + i_bh * T + i_t * BT + tl.arange(0, BT)\n b_g = tl.load(p_g)\n b_g_last = tl.load(g + i_bh * T + i_t * BT + BT - 1)\n mask = tl.math.exp2(b_g[None, :] - b_g[:, None])\n mask = tl.where(o_i[:, None] <= o_i[None, :], mask * scale, 0)\n b_s = b_s * mask\n\n b_dq = tl.zeros([BT, BK], dtype=tl.float32)\n b_dk = tl.zeros([BT, BK], dtype=tl.float32)\n b_ds = tl.zeros([BT, BT], dtype=tl.float32)\n for i_v in range(tl.cdiv(V, BV)):\n p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))\n p_h = tl.make_block_ptr(h + i_bh * s_h_h, (V, NT * K), (1, s_h_t), (i_v * BV, i_t * K + i_k * BK), (BV, BK), (0, 1))\n p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))\n p_dh = tl.make_block_ptr(dh + i_bh * s_h_h, (NT * K, V), (s_h_t, 1), (i_t * K + i_k * BK, i_v * BV), (BK, BV), (1, 0))\n p_dv = tl.make_block_ptr(dv + (i_k*n_bh+i_bh)*s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))\n \n b_v = tl.load(p_v, boundary_check=(0, 1))\n b_do = tl.load(p_do, boundary_check=(0, 1))\n b_h = tl.load(p_h, boundary_check=(0, 1))\n b_dh = tl.load(p_dh, boundary_check=(0, 1))\n b_ds += tl.dot(b_do, tl.trans(b_v), allow_tf32=False)\n b_dq += tl.dot(b_do, b_h, allow_tf32=False) * scale\n b_dk += tl.dot(b_v, tl.trans(b_dh), allow_tf32=False)\n \n b_dv = tl.dot(b_k, b_dh, allow_tf32=False) * tl.math.exp2(-b_g + b_g_last)[:, None] + tl.dot(b_s.to(b_q.dtype), b_do, allow_tf32=False)\n tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))\n\n b_dq = b_dq * tl.math.exp2(b_g)[:, None]\n b_dk = b_dk * tl.math.exp2(-b_g + b_g_last)[:, None]\n b_ds = b_ds * tl.trans(mask)\n b_ds = b_ds.to(b_k.dtype)\n b_dq += tl.dot(b_ds, b_k, allow_tf32=False)\n b_dk += tl.trans(tl.dot(b_q, b_ds, allow_tf32=False))\n p_dq = tl.make_block_ptr(dq + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))\n p_dk = tl.make_block_ptr(dk + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))\n tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))\n tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))\n\nclass SimpleGLAFunction(torch.autograd.Function):\n\n @staticmethod\n @custom_fwd\n def forward(ctx, q, k, v, g, initial_state, output_final_state):\n B, H, T, K, V = *q.shape, v.shape[-1]\n BT = 64\n BK, BV = min(64, triton.next_power_of_2(K)), min(64, triton.next_power_of_2(V))\n NT, NK, NV = triton.cdiv(T, BT), triton.cdiv(K, BK), triton.cdiv(V, BV)\n num_stages = 1\n num_warps = 4 if BK == 64 else 2\n scale = K ** -0.5\n\n assert T % BT == 0, 'sequence length must be divisible by BT'\n g = g.reshape(B, H, -1, BT)\n g = g.cumsum(-1) * 1.44269504\n g = g.reshape(B, H, -1)\n\n final_state = None\n if output_final_state:\n final_state = q.new_empty(B, H, K, V, dtype=torch.float32, requires_grad=False)\n\n h = q.new_empty(B, H, NT * K, V)\n grid = (NK, NV, B * H)\n chunk_simple_gla_fwd_kernel_h[grid](\n k, v, h, g, initial_state, final_state,\n q.stride(1), q.stride(2), q.stride(3),\n v.stride(1), v.stride(2), v.stride(3),\n h.stride(1), h.stride(2),\n H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT,\n USE_INITIAL_STATE=initial_state is not None,\n STORE_FINAL_STATE=output_final_state,\n num_warps=num_warps,\n num_stages=num_stages\n )\n grid = (NV, NT, B * H)\n o = torch.empty_like(v)\n chunk_simple_gla_fwd_kernel_o[grid](\n q, k, v, h, g, o,\n q.stride(1), q.stride(2), q.stride(3),\n v.stride(1), v.stride(2), v.stride(3),\n h.stride(1), h.stride(2),\n scale,\n H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV,\n num_warps=num_warps,\n num_stages=num_stages\n )\n\n ctx.save_for_backward(q, k, v, h, g)\n return o.to(q.dtype), final_state\n\n @staticmethod\n @custom_bwd\n def backward(ctx, do, d_ht=None):\n q, k, v, h, g = ctx.saved_tensors\n\n B, H, T, K, V = *q.shape, v.shape[-1]\n BT = 64\n BK, BV = min(32 if q.dtype == torch.float32 else 64, triton.next_power_of_2(K)), min(32 if q.dtype == torch.float32 else 64, triton.next_power_of_2(V))\n NT, NK, NV = triton.cdiv(T, BT), triton.cdiv(K, BK), triton.cdiv(V, BV)\n num_stages = 1\n num_warps = 4 if BK == 64 else 2\n scale = K ** -0.5\n\n dh = q.new_empty(B, H, NT * K, V)\n grid = (NK, NV, B * H)\n chunk_simple_gla_bwd_kernel_dh[grid](\n q, g, do, dh,\n q.stride(1), q.stride(2), q.stride(3),\n v.stride(1), v.stride(2), v.stride(3),\n dh.stride(1), dh.stride(2),\n scale,\n H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT,\n num_warps=num_warps,\n num_stages=num_stages\n )\n grid = (NK, NT, B * H)\n dq = torch.empty_like(q)\n dk = torch.empty_like(k)\n dv = v.new_empty(NK, *v.shape)\n num_stages = 1\n num_warps = 4 if BK == 64 else 2\n chunk_simple_gla_bwd_kernel_dqkv[grid](\n q, k, v, h, g, do, dh, dq, dk, dv,\n q.stride(1), q.stride(2), q.stride(3),\n v.stride(1), v.stride(2), v.stride(3),\n dh.stride(1), dh.stride(2),\n scale,\n B=B, H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT,\n num_warps=num_warps,\n num_stages=num_stages\n )\n dv = dv.sum(0)\n dg = (dq * q - dk * k).sum(-1)\n\n def rev_cumsum(x):\n cumsum_x = x.cumsum(-1)\n rev_cumsum_x = cumsum_x[..., -1, None] - cumsum_x\n return rev_cumsum_x + x\n dg = rev_cumsum(dg)\n return dq.to(q.dtype), dk.to(k.dtype), dv.to(v.dtype), dg.to(g.dtype), None, None\n\n\ndef chunk_simple_gla(\n q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, g: torch.Tensor, \n initial_state: torch.Tensor = None, output_final_state: bool = False\n) -> Tuple[torch.Tensor, torch.Tensor]:\n if initial_state is not None:\n initial_state = initial_state.detach()\n g = g.float()\n o, final_state = SimpleGLAFunction.apply(q, k, v, g, initial_state, output_final_state)\n return o, final_state\n", - "description_1": "Use triton language to implement kernels for forward and backward passes of a chunk-based generalized linear attention mechanism, which computes attention scores and updates states over chunks of input tensors q, k, v, g, utilizing specific block sizes and strides to efficiently process multi-dimensional data in parallel.", - "description_2": "Use triton language to write kernels that handle the forward and backward calculations for a chunk-based attention mechanism, efficiently managing tensor data with specified constraints on block sizes and memory strides.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\nfrom typing import Optional\n\n@triton.autotune(\n configs=[\n triton.Config({'BT': 16}, num_warps=2),\n triton.Config({'BT': 16}, num_warps=4),\n triton.Config({'BT': 16}, num_warps=8),\n triton.Config({'BT': 32}, num_warps=2),\n triton.Config({'BT': 32}, num_warps=4),\n triton.Config({'BT': 32}, num_warps=8),\n triton.Config({'BT': 64}, num_warps=2),\n triton.Config({'BT': 64}, num_warps=4),\n triton.Config({'BT': 64}, num_warps=8),\n ],\n key=['S']\n)\n@triton.jit\ndef chunk_cumsum_fwd_kernel(\n s,\n z,\n s_s_h,\n s_s_t,\n s_s_d,\n T: tl.constexpr,\n S: tl.constexpr,\n BT: tl.constexpr,\n BS: tl.constexpr\n):\n i_s, i_bh = tl.program_id(0), tl.program_id(1)\n o_i = tl.arange(0, BT)\n m_s = tl.where(o_i[:, None] >= o_i[None, :], 1., 0.)\n\n b_z = tl.zeros([BS], dtype=tl.float32)\n for i_t in range(tl.cdiv(T, BT)):\n p_s = tl.make_block_ptr(s + i_bh * s_s_h, (T, S), (s_s_t, s_s_d), (i_t * BT, i_s * BS), (BT, BS), (1, 0))\n p_z = tl.make_block_ptr(z + i_bh * s_s_h, (T, S), (s_s_t, s_s_d), (i_t * BT, i_s * BS), (BT, BS), (1, 0))\n # [BT, BS]\n b_s = tl.load(p_s, boundary_check=(0, 1)).to(tl.float32)\n b_c = b_z[None, :] + tl.dot(m_s, b_s, allow_tf32=False)\n tl.store(p_z, b_c.to(p_z.dtype.element_ty), boundary_check=(0, 1))\n\n if i_t >= 0:\n b_z += tl.sum(b_s, 0)\n\n@triton.autotune(\n configs=[\n triton.Config({'BT': 16}, num_warps=2),\n triton.Config({'BT': 16}, num_warps=4),\n triton.Config({'BT': 16}, num_warps=8),\n triton.Config({'BT': 32}, num_warps=2),\n triton.Config({'BT': 32}, num_warps=4),\n triton.Config({'BT': 32}, num_warps=8),\n triton.Config({'BT': 64}, num_warps=2),\n triton.Config({'BT': 64}, num_warps=4),\n triton.Config({'BT': 64}, num_warps=8),\n ],\n key=['S']\n)\n@triton.jit\ndef chunk_cumsum_bwd_kernel(\n ds,\n dz,\n s_s_h,\n s_s_t,\n s_s_d,\n T: tl.constexpr,\n S: tl.constexpr,\n BT: tl.constexpr,\n BS: tl.constexpr\n):\n i_s, i_bh = tl.program_id(0), tl.program_id(1)\n o_i = tl.arange(0, BT)\n m_s = tl.where(o_i[:, None] <= o_i[None, :], 1., 0.)\n\n b_ds = tl.zeros([BS], dtype=tl.float32)\n for i_t in range(tl.cdiv(T, BT) - 1, -1, -1):\n p_ds = tl.make_block_ptr(ds + i_bh * s_s_h, (T, S), (s_s_t, s_s_d), (i_t * BT, i_s * BS), (BT, BS), (1, 0))\n p_dz = tl.make_block_ptr(dz + i_bh * s_s_h, (T, S), (s_s_t, s_s_d), (i_t * BT, i_s * BS), (BT, BS), (1, 0))\n # [BT, BS]\n b_dz = tl.load(p_dz, boundary_check=(0, 1)).to(tl.float32)\n b_c = b_ds[None, :] + tl.dot(m_s, b_dz, allow_tf32=False)\n tl.store(p_ds, b_c.to(p_ds.dtype.element_ty), boundary_check=(0, 1))\n\n if i_t >= 0:\n b_ds += tl.sum(b_dz, 0)\n\ndef chunk_cumsum_fwd(\n s: torch.Tensor,\n dtype: Optional[torch.dtype] = None,\n) -> torch.Tensor:\n B, H, T, S = s.shape\n BS = 32\n\n dtype = dtype or s.dtype\n grid = (triton.cdiv(S, BS), B * H)\n z = torch.empty_like(s, dtype=dtype)\n chunk_cumsum_fwd_kernel[grid](\n s, z,\n s.stride(1), s.stride(2), s.stride(3),\n T=T, S=S, BS=BS\n )\n return z\n\ndef chunk_cumsum_bwd(\n dz: torch.Tensor,\n dtype: Optional[torch.dtype] = None,\n) -> torch.Tensor:\n B, H, T, S = dz.shape\n BS = 32\n\n dtype = dtype or dz.dtype\n grid = (triton.cdiv(S, BS), B * H)\n ds = torch.empty_like(dz, dtype=dtype)\n chunk_cumsum_bwd_kernel[grid](\n ds, dz,\n ds.stride(1), ds.stride(2), ds.stride(3),\n T=T, S=S, BS=BS\n )\n return ds\n", - "description_1": "Use triton language to implement a forward and backward cumulative sum operation on a 4D tensor. The forward kernel 'chunk_cumsum_fwd_kernel' takes 8 parameters: input tensor 's', output tensor 'z', strides 's_s_h', 's_s_t', 's_s_d', and constants 'T', 'S', 'BT', 'BS'. It computes the cumulative sum along the last dimension in chunks. The backward kernel 'chunk_cumsum_bwd_kernel' takes the same parameters but computes the gradient of the cumulative sum. The functions 'chunk_cumsum_fwd' and 'chunk_cumsum_bwd' are Python wrappers that set up the grid and call the respective kernels.", - "description_2": "Use triton language to create a forward and backward cumulative sum operation on a 4D tensor with chunk processing, utilizing kernels for efficient computation.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n# Triton kernel for forward attention computation.\n@triton.jit\ndef _attn_fwd_inner(\n acc, l_i, m_i, q, K_block_ptr, V_block_ptr, start_m, qk_scale, \n BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, \n STAGE: tl.constexpr, offs_m: tl.constexpr, offs_n: tl.constexpr,\n):\n if STAGE == 1:\n lo, hi = 0, start_m * BLOCK_M\n else:\n lo, hi = start_m * BLOCK_M, (start_m + 1) * BLOCK_M\n lo = tl.multiple_of(lo, BLOCK_M)\n K_block_ptr = tl.advance(K_block_ptr, (0, lo))\n V_block_ptr = tl.advance(V_block_ptr, (lo, 0))\n for start_n in range(lo, hi, BLOCK_N):\n start_n = tl.multiple_of(start_n, BLOCK_N)\n k = tl.load(K_block_ptr)\n qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)\n qk += tl.dot(q, k)\n if STAGE == 2:\n mask = offs_m[:, None] >= (start_n + offs_n[None, :])\n qk = qk * qk_scale + tl.where(mask, 0, -1.0e6)\n m_ij = tl.maximum(m_i, tl.max(qk, 1))\n qk -= m_ij[:, None]\n else:\n m_ij = tl.maximum(m_i, tl.max(qk, 1) * qk_scale)\n qk = qk * qk_scale - m_ij[:, None]\n p = tl.math.exp2(qk)\n l_ij = tl.sum(p, 1)\n alpha = tl.math.exp2(m_i - m_ij)\n l_i = l_i * alpha + l_ij\n acc = acc * alpha[:, None]\n v = tl.load(V_block_ptr)\n acc += tl.dot(p.to(tl.float16), v)\n m_i = m_ij\n V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0))\n K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N))\n return acc, l_i, m_i\n\n# Triton kernel for forward attention computation.\n@triton.jit\ndef _attn_fwd(\n Q, K, V, sm_scale, M, Out, stride_qz, stride_qh, stride_qm, stride_qk, \n stride_kz, stride_kh, stride_kn, stride_kk, stride_vz, stride_vh, stride_vk, \n stride_vn, stride_oz, stride_oh, stride_om, stride_on, Z, H, N_CTX: tl.constexpr, \n BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, \n STAGE: tl.constexpr,\n):\n start_m = tl.program_id(0)\n off_hz = tl.program_id(1)\n off_z = off_hz // H\n off_h = off_hz % H\n qvk_offset = off_z.to(tl.int64) * stride_qz + off_h.to(tl.int64) * stride_qh\n Q_block_ptr = tl.make_block_ptr(\n base=Q + qvk_offset,\n shape=(N_CTX, BLOCK_DMODEL),\n strides=(stride_qm, stride_qk),\n offsets=(start_m * BLOCK_M, 0),\n block_shape=(BLOCK_M, BLOCK_DMODEL),\n order=(1, 0),\n )\n V_block_ptr = tl.make_block_ptr(\n base=V + qvk_offset,\n shape=(N_CTX, BLOCK_DMODEL),\n strides=(stride_vk, stride_vn),\n offsets=(0, 0),\n block_shape=(BLOCK_N, BLOCK_DMODEL),\n order=(1, 0),\n )\n K_block_ptr = tl.make_block_ptr(\n base=K + qvk_offset,\n shape=(BLOCK_DMODEL, N_CTX),\n strides=(stride_kk, stride_kn),\n offsets=(0, 0),\n block_shape=(BLOCK_DMODEL, BLOCK_N),\n order=(0, 1),\n )\n O_block_ptr = tl.make_block_ptr(\n base=Out + qvk_offset,\n shape=(N_CTX, BLOCK_DMODEL),\n strides=(stride_om, stride_on),\n offsets=(start_m * BLOCK_M, 0),\n block_shape=(BLOCK_M, BLOCK_DMODEL),\n order=(1, 0),\n )\n offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)\n offs_n = tl.arange(0, BLOCK_N)\n m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float(\"inf\")\n l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + 1.0\n acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)\n qk_scale = sm_scale\n qk_scale *= 1.44269504\n q = tl.load(Q_block_ptr)\n if STAGE & 1:\n acc, l_i, m_i = _attn_fwd_inner(\n acc, l_i, m_i, q, K_block_ptr, V_block_ptr, start_m, qk_scale, \n BLOCK_M, BLOCK_DMODEL, BLOCK_N, 1, offs_m, offs_n,\n )\n tl.debug_barrier()\n if STAGE & 2:\n acc, l_i, m_i = _attn_fwd_inner(\n acc, l_i, m_i, q, K_block_ptr, V_block_ptr, start_m, qk_scale, \n BLOCK_M, BLOCK_DMODEL, BLOCK_N, 2, offs_m, offs_n,\n )\n m_i += tl.math.log2(l_i)\n acc = acc / l_i[:, None]\n m_ptrs = M + off_hz * N_CTX + offs_m\n tl.store(m_ptrs, m_i)\n tl.store(O_block_ptr, acc.to(Out.type.element_ty))\n\n# PyTorch function encapsulating Triton kernels for attention.\nclass _attention(torch.autograd.Function):\n @staticmethod\n def forward(ctx, q, k, v, causal, sm_scale):\n Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]\n assert Lq == Lk and Lk == Lv\n assert Lk in {16, 32, 64, 128}\n o = torch.empty_like(q)\n BLOCK_M = 128\n BLOCK_N = 64 if Lk <= 64 else 32\n num_stages = 4 if Lk <= 64 else 3\n num_warps = 4\n grid = (triton.cdiv(q.shape[2], BLOCK_M), q.shape[0] * q.shape[1], 1)\n M = torch.empty(\n (q.shape[0], q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32\n )\n _attn_fwd[grid](\n q, k, v, sm_scale, M, o, q.stride(0), q.stride(1), q.stride(2), q.stride(3), \n k.stride(0), k.stride(1), k.stride(2), k.stride(3), v.stride(0), v.stride(1), \n v.stride(2), v.stride(3), o.stride(0), o.stride(1), o.stride(2), o.stride(3), \n q.shape[0], q.shape[1], N_CTX=q.shape[2], BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, \n BLOCK_DMODEL=Lk, STAGE=3, num_warps=num_warps, num_stages=num_stages,\n )\n ctx.save_for_backward(q, k, v, o, M)\n ctx.grid = grid\n ctx.sm_scale = sm_scale\n ctx.BLOCK_DMODEL = Lk\n ctx.causal = causal\n return o\n\n @staticmethod\n def backward(ctx, do):\n q, k, v, o, M = ctx.saved_tensors\n assert do.is_contiguous()\n assert q.stride() == k.stride() == v.stride() == o.stride() == do.stride()\n dq = torch.empty_like(q)\n dk = torch.empty_like(k)\n dv = torch.empty_like(v)\n BATCH, N_HEAD, N_CTX = q.shape[:3]\n PRE_BLOCK = 128\n NUM_WARPS, NUM_STAGES = 4, 1\n BLOCK_M1, BLOCK_N1, BLOCK_M2, BLOCK_N2 = 32, 128, 128, 32\n BLK_SLICE_FACTOR = 2\n RCP_LN2 = 1.4426950408889634\n arg_k = k\n arg_k = arg_k * (ctx.sm_scale * RCP_LN2)\n PRE_BLOCK = 128\n assert N_CTX % PRE_BLOCK == 0\n pre_grid = (N_CTX // PRE_BLOCK, BATCH * N_HEAD)\n delta = torch.empty_like(M)\n _attn_bwd_preprocess[pre_grid](\n o, do, delta, BATCH, N_HEAD, N_CTX, BLOCK_M=PRE_BLOCK, D_HEAD=ctx.BLOCK_DMODEL,\n )\n grid = (N_CTX // BLOCK_N1, 2, BATCH * N_HEAD)\n _attn_bwd[grid](\n q, arg_k, v, ctx.sm_scale, do, dq, dk, dv, M, delta,\n q.stride(0), q.stride(1), q.stride(2), q.stride(3), N_HEAD, N_CTX,\n BLOCK_M1=BLOCK_M1, BLOCK_N1=BLOCK_N1, BLOCK_M2=BLOCK_M2, BLOCK_N2=BLOCK_N2,\n BLK_SLICE_FACTOR=BLK_SLICE_FACTOR, BLOCK_DMODEL=ctx.BLOCK_DMODEL,\n num_warps=NUM_WARPS, num_stages=NUM_STAGES,\n )\n return dq, dk, dv, None, None\n\nattention = _attention.apply\n", - "description_1": "Use triton language to implement attention forward and backward computations for a neural network. The Triton kernels involve computing attention scores, masking, and performing matrix multiplication between queries, keys, and values. The forward computation kernel (_attn_fwd) takes inputs for queries (Q), keys (K), and values (V), along with various stride and block size parameters. It then computes the attention scores and outputs the resulting context vectors. The backward computation (_attn_bwd) is responsible for calculating the gradients for queries, keys, and values using the saved context from the forward pass.", - "description_2": "Use triton language to implement a neural network attention mechanism, with separate kernels for forward and backward computations.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef rotary_kernel(\n OUT, # Pointers to matrices\n X,\n COS,\n SIN,\n CU_SEQLENS,\n SEQLEN_OFFSETS, # this could be int or a pointer\n # Matrix dimensions\n seqlen,\n nheads,\n rotary_dim,\n seqlen_ro,\n CACHE_KEY_SEQLEN,\n # strides\n stride_out_batch,\n stride_out_seqlen,\n stride_out_nheads,\n stride_out_headdim,\n stride_x_batch,\n stride_x_seqlen,\n stride_x_nheads,\n stride_x_headdim,\n # Meta-parameters\n BLOCK_K: tl.constexpr,\n IS_SEQLEN_OFFSETS_TENSOR: tl.constexpr,\n IS_VARLEN: tl.constexpr,\n INTERLEAVED: tl.constexpr,\n CONJUGATE: tl.constexpr,\n BLOCK_M: tl.constexpr,\n):\n pid_m = tl.program_id(axis=0)\n pid_batch = tl.program_id(axis=1)\n pid_head = tl.program_id(axis=2)\n rotary_dim_half = rotary_dim // 2\n\n if not IS_VARLEN:\n X = X + pid_batch * stride_x_batch + pid_head * stride_x_nheads\n OUT = OUT + pid_batch * stride_out_batch + pid_head * stride_out_nheads\n else:\n start_idx = tl.load(CU_SEQLENS + pid_batch)\n seqlen = tl.load(CU_SEQLENS + pid_batch + 1) - start_idx\n X = X + start_idx * stride_x_seqlen + pid_head * stride_x_nheads\n OUT = OUT + start_idx * stride_out_seqlen + pid_head * stride_out_nheads\n\n if pid_m * BLOCK_M >= seqlen:\n return\n rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)\n if not IS_SEQLEN_OFFSETS_TENSOR:\n rm_cs = rm + SEQLEN_OFFSETS\n else:\n rm_cs = rm + tl.load(SEQLEN_OFFSETS + pid_batch)\n rk = tl.arange(0, BLOCK_K)\n rk_half = tl.arange(0, BLOCK_K // 2)\n\n if not INTERLEAVED:\n # Load the 1st and 2nd halves of X, do calculation, then store to 1st and 2nd halves of OUT\n X = X + (rm[:, None] * stride_x_seqlen + rk_half[None, :] * stride_x_headdim)\n COS = COS + (rm_cs[:, None] * rotary_dim_half + rk_half[None, :])\n SIN = SIN + (rm_cs[:, None] * rotary_dim_half + rk_half[None, :])\n cos = tl.load(\n COS,\n mask=(rm_cs[:, None] < seqlen_ro) & (rk_half[None, :] < rotary_dim_half),\n other=1.0,\n ).to(tl.float32)\n sin = tl.load(\n SIN,\n mask=(rm_cs[:, None] < seqlen_ro) & (rk_half[None, :] < rotary_dim_half),\n other=0.0,\n ).to(tl.float32)\n x0 = tl.load(\n X,\n mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half),\n other=0.0,\n ).to(tl.float32)\n x1 = tl.load(\n X + rotary_dim_half * stride_x_headdim,\n mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half),\n other=0.0,\n ).to(tl.float32)\n if CONJUGATE:\n sin = -sin\n o0 = x0 * cos - x1 * sin\n o1 = x0 * sin + x1 * cos\n # write back result\n OUT = OUT + (\n rm[:, None] * stride_out_seqlen + rk_half[None, :] * stride_out_headdim\n )\n tl.store(\n OUT, o0, mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half)\n )\n tl.store(\n OUT + rotary_dim_half * stride_out_headdim,\n o1,\n mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half),\n )\n else:\n # We don't want to load X[0, 2, 4, ...] and X[1, 3, 5, ...] separately since both are slow.\n # Instead, we load x0 = X[0, 1, 2, 3, ...] and x1 = X[1, 0, 3, 2, ...].\n # Loading x0 will be fast but x1 will be slow.\n # Then we load cos = COS[0, 0, 1, 1, ...] and sin = SIN[0, 0, 1, 1, ...].\n # Then we do the calculation and use tl.where to pick put the right outputs for the even\n # and for the odd indices.\n rk_swap = rk + ((rk + 1) % 2) * 2 - 1 # 1, 0, 3, 2, 5, 4, ...\n rk_repeat = tl.arange(0, BLOCK_K) // 2\n X0 = X + (rm[:, None] * stride_x_seqlen + rk[None, :] * stride_x_headdim)\n X1 = X + (rm[:, None] * stride_x_seqlen + rk_swap[None, :] * stride_x_headdim)\n COS = COS + (rm_cs[:, None] * rotary_dim_half + rk_repeat[None, :])\n SIN = SIN + (rm_cs[:, None] * rotary_dim_half + rk_repeat[None, :])\n cos = tl.load(\n COS,\n mask=(rm_cs[:, None] < seqlen_ro) & (rk_repeat[None, :] < rotary_dim_half),\n other=1.0,\n ).to(tl.float32)\n sin = tl.load(\n SIN,\n mask=(rm_cs[:, None] < seqlen_ro) & (rk_repeat[None, :] < rotary_dim_half),\n other=0.0,\n ).to(tl.float32)\n x0 = tl.load(\n X0, mask=(rm[:, None] < seqlen) & (rk[None, :] < rotary_dim), other=0.0\n ).to(tl.float32)\n x1 = tl.load(\n X1, mask=(rm[:, None] < seqlen) & (rk_swap[None, :] < rotary_dim), other=0.0\n ).to(tl.float32)\n if CONJUGATE:\n sin = -sin\n x0_cos = x0 * cos\n x1_sin = x1 * sin\n out = tl.where(rk[None, :] % 2 == 0, x0_cos - x1_sin, x0_cos + x1_sin)\n OUT = OUT + (rm[:, None] * stride_out_seqlen + rk[None, :] * stride_out_headdim)\n tl.store(OUT, out, mask=(rm[:, None] < seqlen) & (rk[None, :] < rotary_dim))\n\n\ndef apply_rotary(\n x: torch.Tensor,\n cos: torch.Tensor,\n sin: torch.Tensor,\n seqlen_offsets: Union[int, torch.Tensor] = 0,\n cu_seqlens: Optional[torch.Tensor] = None,\n max_seqlen: Optional[int] = None,\n interleaved=False,\n inplace=False,\n conjugate=False,\n) -> torch.Tensor:\n \"\"\"\n Arguments:\n x: (batch, seqlen, nheads, headdim) if cu_seqlens is None\n else (total_seqlen, nheads, headdim).\n cos: (seqlen_ro, rotary_dim / 2)\n sin: (seqlen_ro, rotary_dim / 2)\n seqlen_offsets: integer or integer tensor of size (batch,)\n cu_seqlens: (batch + 1,) or None\n max_seqlen: int\n Returns:\n y: (batch, seqlen, nheads, headdim)\n \"\"\"\n is_varlen = cu_seqlens is not None\n if not is_varlen:\n batch, seqlen, nheads, headdim = x.shape\n else:\n assert (\n max_seqlen is not None\n ), \"If cu_seqlens is passed in, then max_seqlen must be passed\"\n total_seqlen, nheads, headdim = x.shape\n batch_p_1 = cu_seqlens.shape[0]\n batch = batch_p_1 - 1\n seqlen = max_seqlen\n seqlen_ro, rotary_dim = cos.shape\n assert sin.shape == cos.shape\n rotary_dim *= 2\n assert rotary_dim <= headdim, \"rotary_dim must be <= headdim\"\n assert headdim <= 256, \"Only support headdim <= 256\"\n assert seqlen_ro >= seqlen, \"seqlen_ro must be >= seqlen\"\n\n assert (\n cos.dtype == sin.dtype\n ), f\"cos and sin must have the same dtype, got {cos.dtype} and {sin.dtype}\"\n assert (\n x.dtype == cos.dtype\n ), f\"Input and cos/sin must have the same dtype, got {x.dtype} and {cos.dtype}\"\n\n cos, sin = cos.contiguous(), sin.contiguous()\n if isinstance(seqlen_offsets, torch.Tensor):\n assert seqlen_offsets.shape == (batch,)\n assert seqlen_offsets.dtype in [torch.int32, torch.int64]\n seqlen_offsets = seqlen_offsets.contiguous()\n else:\n assert seqlen_offsets + seqlen <= seqlen_ro\n\n output = torch.empty_like(x) if not inplace else x\n if rotary_dim < headdim and not inplace:\n output[..., rotary_dim:].copy_(x[..., rotary_dim:])\n\n BLOCK_K = (\n 32\n if rotary_dim <= 32\n else (64 if rotary_dim <= 64 else (128 if rotary_dim <= 128 else 256))\n )\n grid = lambda META: (triton.cdiv(seqlen, META[\"BLOCK_M\"]), batch, nheads) # noqa\n BLOCK_M = 4 if interleaved else (8 if rotary_dim <= 64 else 4)\n\n # Need this, otherwise Triton tries to launch from cuda:0 and we get\n # ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?)\n\n with torch.cuda.device(x.device.index):\n rotary_kernel[grid](\n output, # data ptrs\n x,\n cos,\n sin,\n cu_seqlens,\n seqlen_offsets,\n seqlen, # shapes\n nheads,\n rotary_dim,\n seqlen_ro,\n seqlen // 128, # key for triton cache (limit number of compilations)\n output.stride(0)\n if not is_varlen\n else 0, # batch_strides if not varlen else 0\n output.stride(-3), # seqlen_stride or total_seqlen_stride\n output.stride(-2), # nheads_stride\n output.stride(-1), # headdim_stride\n x.stride(0) if not is_varlen else 0, # batch_strides if not varlen else 0\n x.stride(-3), # seqlen stride or total_seqlen_stride\n x.stride(-2), # nheads stride\n x.stride(-1), # headdim stride\n BLOCK_K,\n isinstance(seqlen_offsets, torch.Tensor),\n is_varlen,\n interleaved,\n conjugate,\n BLOCK_M,\n )\n return output\n", - "description_1": "Use triton language to define a kernel called 'rotary_kernel' which computes a rotation transformation on input matrices. It takes 29 parameters: 10 tensor pointers, 4 integers for dimensions, 4 integers for strides, and 10 compile-time constants. Another function 'apply_rotary' calls this kernel. It accepts 9 arguments: 3 tensors (x, cos, sin), an integer or tensor for sequence length offsets, an optional tensor for cumulative sequence lengths, two optional integers (max_seqlen, CACHE_KEY_SEQLEN), and 3 boolean flags. The rotary_kernel applies cosine and sine transformations to input matrices based on parameters and outputs transformed data.", - "description_2": "Use triton language to implement a kernel for rotating matrix transformations based on cosine and sine inputs, called by a function that handles batch and sequence details.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef _layer_norm_fwd_fused(\n X, # pointer to the input\n Y, # pointer to the output\n W, # pointer to the weights\n B, # pointer to the biases\n Mean, # pointer to the mean\n Rstd, # pointer to the 1/std\n stride, # how much to increase the pointer when moving by 1 row\n N, # number of columns in X\n eps, # epsilon to avoid division by zero\n BLOCK_SIZE: tl.constexpr,\n):\n # Map the program id to the row of X and Y it should compute.\n row = tl.program_id(0)\n Y += row * stride\n X += row * stride\n # Compute mean\n mean = 0\n _mean = tl.zeros([BLOCK_SIZE], dtype=tl.float32)\n for off in range(0, N, BLOCK_SIZE):\n cols = off + tl.arange(0, BLOCK_SIZE)\n a = tl.load(X + cols, mask=cols < N, other=0.).to(tl.float32)\n _mean += a\n mean = tl.sum(_mean, axis=0) / N\n # Compute variance\n _var = tl.zeros([BLOCK_SIZE], dtype=tl.float32)\n for off in range(0, N, BLOCK_SIZE):\n cols = off + tl.arange(0, BLOCK_SIZE)\n x = tl.load(X + cols, mask=cols < N, other=0.).to(tl.float32)\n x = tl.where(cols < N, x - mean, 0.)\n _var += x * x\n var = tl.sum(_var, axis=0) / N\n rstd = 1 / tl.sqrt(var + eps)\n # Write mean / rstd\n tl.store(Mean + row, mean)\n tl.store(Rstd + row, rstd)\n # Normalize and apply linear transformation\n for off in range(0, N, BLOCK_SIZE):\n cols = off + tl.arange(0, BLOCK_SIZE)\n mask = cols < N\n w = tl.load(W + cols, mask=mask)\n b = tl.load(B + cols, mask=mask)\n x = tl.load(X + cols, mask=mask, other=0.).to(tl.float32)\n x_hat = (x - mean) * rstd\n y = x_hat * w + b\n # Write output\n tl.store(Y + cols, y, mask=mask)\n\nclass LayerNorm(torch.autograd.Function):\n\n @staticmethod\n def forward(ctx, x, normalized_shape, weight, bias, eps):\n # allocate output\n y = torch.empty_like(x)\n # reshape input data into 2D tensor\n x_arg = x.reshape(-1, x.shape[-1])\n M, N = x_arg.shape\n mean = torch.empty((M, ), dtype=torch.float32, device='cuda')\n rstd = torch.empty((M, ), dtype=torch.float32, device='cuda')\n # Less than 64KB per feature: enqueue fused kernel\n MAX_FUSED_SIZE = 65536 // x.element_size()\n BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))\n if N > BLOCK_SIZE:\n raise RuntimeError(\"This layer norm doesn't support feature dim >= 64KB.\")\n # heuristics for number of warps\n num_warps = min(max(BLOCK_SIZE // 256, 1), 8)\n # enqueue kernel\n _layer_norm_fwd_fused[(M,)](x_arg, y, weight, bias, mean, rstd,\n x_arg.stride(0), N, eps,\n BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps, num_ctas=1)\n ctx.save_for_backward(x, weight, bias, mean, rstd)\n ctx.BLOCK_SIZE = BLOCK_SIZE\n ctx.num_warps = num_warps\n ctx.eps = eps\n return y\n \nlayer_norm = LayerNorm.apply\n\ndef test_layer_norm(M, N, dtype, eps=1e-5, device='cuda'):\n # create data\n x_shape = (M, N)\n w_shape = (x_shape[-1], )\n weight = torch.rand(w_shape, dtype=dtype, device='cuda', requires_grad=True)\n bias = torch.rand(w_shape, dtype=dtype, device='cuda', requires_grad=True)\n x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device='cuda')\n dy = .1 * torch.randn_like(x)\n x.requires_grad_(True)\n # forward pass\n y_tri = layer_norm(x, w_shape, weight, bias, eps)\n y_ref = torch.nn.functional.layer_norm(x, w_shape, weight, bias, eps).to(dtype)\n\n # compare\n assert torch.allclose(y_tri, y_ref, atol=1e-2, rtol=0)\n\ntest_layer_norm(1151, 8192, torch.float16)\n", - "description_1": "Use triton language to implement a fused forward layer normalization kernel. The kernel (_layer_norm_fwd_fused) takes in 10 parameters: pointers to the input tensor X, output tensor Y, weights W, biases B, mean, and reciprocal of standard deviation Rstd. It also takes stride, the number of columns in X (N), a small epsilon to prevent division by zero, and a block size constant. The kernel computes mean and variance for each row, normalizes the input, and applies linear transformations with the provided weights and biases. The kernel is then called in the LayerNorm forward method with additional context saving.", - "description_2": "Use triton language to implement a layer normalization kernel that computes mean and variance per row, normalizes data, and applies linear transformation.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n\n@triton.jit\ndef _layer_norm_fwd_fused(\n X, # pointer to the input\n Y, # pointer to the output\n W, # pointer to the weights\n B, # pointer to the biases\n Mean, # pointer to the mean\n Rstd, # pointer to the 1/std\n stride, # how much to increase the pointer when moving by 1 row\n N, # number of columns in X\n eps, # epsilon to avoid division by zero\n BLOCK_SIZE: tl.constexpr,\n):\n # Map the program id to the row of X and Y it should compute.\n row = tl.program_id(0)\n Y += row * stride\n X += row * stride\n # Compute mean\n mean = 0\n _mean = tl.zeros([BLOCK_SIZE], dtype=tl.float32)\n for off in range(0, N, BLOCK_SIZE):\n cols = off + tl.arange(0, BLOCK_SIZE)\n a = tl.load(X + cols, mask=cols < N, other=0.).to(tl.float32)\n _mean += a\n mean = tl.sum(_mean, axis=0) / N\n # Compute variance\n _var = tl.zeros([BLOCK_SIZE], dtype=tl.float32)\n for off in range(0, N, BLOCK_SIZE):\n cols = off + tl.arange(0, BLOCK_SIZE)\n x = tl.load(X + cols, mask=cols < N, other=0.).to(tl.float32)\n x = tl.where(cols < N, x - mean, 0.)\n _var += x * x\n var = tl.sum(_var, axis=0) / N\n rstd = 1 / tl.sqrt(var + eps)\n # Write mean / rstd\n tl.store(Mean + row, mean)\n tl.store(Rstd + row, rstd)\n # Normalize and apply linear transformation\n for off in range(0, N, BLOCK_SIZE):\n cols = off + tl.arange(0, BLOCK_SIZE)\n mask = cols < N\n w = tl.load(W + cols, mask=mask)\n b = tl.load(B + cols, mask=mask)\n x = tl.load(X + cols, mask=mask, other=0.).to(tl.float32)\n x_hat = (x - mean) * rstd\n y = x_hat * w + b\n # Write output\n tl.store(Y + cols, y, mask=mask)\n \nclass LayerNorm(torch.autograd.Function):\n\n @staticmethod\n def forward(ctx, x, normalized_shape, weight, bias, eps):\n # allocate output\n y = torch.empty_like(x)\n # reshape input data into 2D tensor\n x_arg = x.reshape(-1, x.shape[-1])\n M, N = x_arg.shape\n mean = torch.empty((M, ), dtype=torch.float32, device='cuda')\n rstd = torch.empty((M, ), dtype=torch.float32, device='cuda')\n # Less than 64KB per feature: enqueue fused kernel\n MAX_FUSED_SIZE = 65536 // x.element_size()\n BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))\n if N > BLOCK_SIZE:\n raise RuntimeError(\"This layer norm doesn't support feature dim >= 64KB.\")\n # heuristics for number of warps\n num_warps = min(max(BLOCK_SIZE // 256, 1), 8)\n # enqueue kernel\n _layer_norm_fwd_fused[(M,)](x_arg, y, weight, bias, mean, rstd,\n x_arg.stride(0), N, eps,\n BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps, num_ctas=1)\n ctx.save_for_backward(x, weight, bias, mean, rstd)\n ctx.BLOCK_SIZE = BLOCK_SIZE\n ctx.num_warps = num_warps\n ctx.eps = eps\n return y\n \nlayer_norm = LayerNorm.apply\n", - "description_1": "Use triton language to implement a layer normalization operation for 2D tensors, with inputs for the tensor itself, weights, biases, mean, rstd, stride, number of columns, and an epsilon value for numerical stability. The operation is optimized for memory and performance constraints, utilizing block size heuristics, and aims to produce an output tensor with normalized values.", - "description_2": "Use triton language to optimize layer normalization with 2D tensor input and block size heuristics for performance.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef _rms_norm_fwd_fused(\n X, # pointer to the input\n Y, # pointer to the output\n W, # pointer to the weights\n stride, # how much to increase the pointer when moving by 1 row\n N, # number of columns in X\n eps, # epsilon to avoid division by zero\n BLOCK_SIZE: tl.constexpr,\n):\n # Map the program id to the row of X and Y it should compute.\n row = tl.program_id(0)\n Y += row * stride\n X += row * stride\n # Compute variance\n _var = tl.zeros([BLOCK_SIZE], dtype=tl.float32)\n for off in range(0, N, BLOCK_SIZE):\n cols = off + tl.arange(0, BLOCK_SIZE)\n x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)\n _var += x * x\n var = tl.sum(_var, axis=0) / N\n rstd = tl.math.rsqrt(var + eps)\n # Normalize and apply linear transformation\n for off in range(0, N, BLOCK_SIZE):\n cols = off + tl.arange(0, BLOCK_SIZE)\n mask = cols < N\n w = tl.load(W + cols, mask=mask)\n x = tl.load(X + cols, mask=mask, other=0.0).to(tl.float32)\n x_hat = x * rstd\n y = x_hat * w\n # Write output\n tl.store(Y + cols, y, mask=mask)\n\n\nclass RmsNormFunction(torch.autograd.Function):\n @staticmethod\n def forward(ctx, x, normalized_shape, weight, eps):\n # allocate output\n y = torch.empty_like(x)\n # reshape input data into 2D tensor\n x_arg = x.reshape(-1, x.shape[-1])\n M, N = x_arg.shape\n mean = torch.empty((M,), dtype=torch.float32, device=\"cuda\")\n rstd = torch.empty((M,), dtype=torch.float32, device=\"cuda\")\n # Less than 64KB per feature: enqueue fused kernel\n MAX_FUSED_SIZE = 65536 // x.element_size()\n BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))\n if N > BLOCK_SIZE:\n raise RuntimeError(\"This layer norm doesn't support feature dim >= 64KB.\")\n # heuristics for number of warps\n num_warps = min(max(BLOCK_SIZE // 256, 1), 8)\n # enqueue kernel\n _rms_norm_fwd_fused[(M,)](\n x_arg,\n y,\n weight,\n x_arg.stride(0),\n N,\n eps,\n BLOCK_SIZE=BLOCK_SIZE,\n num_warps=num_warps,\n num_ctas=1,\n )\n ctx.save_for_backward(x, weight, mean, rstd)\n ctx.BLOCK_SIZE = BLOCK_SIZE\n ctx.num_warps = num_warps\n ctx.eps = eps\n return y\n\n\nclass RMSNorm(torch.nn.Module):\n def __init__(self, dim: int, eps: float = 1e-6):\n super().__init__()\n self.eps = eps\n self.dim = dim\n self.weight = nn.Parameter(torch.ones(dim))\n self.rms_norm = RmsNormFunction.apply\n\n def forward(self, x):\n return self.rms_norm(x, self.dim, self.weight, self.eps)\n\n", - "description_1": "Use Triton language to implement a fused RMS normalization forward pass. The kernel computes the row-wise variance, applies normalization, and multiplies the result by a weight for each row. The kernel uses the program ID to map the computation to rows in the input tensor. It supports efficient parallelization using Triton's block-based execution model. The input tensor is normalized row-wise, and the output tensor is computed with normalization and scaling by the weights.", - "description_2": "Use Triton language to perform row-wise RMS normalization and scaling by weight in a parallelized kernel, supporting efficient memory and computation management.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n\n@triton.jit\ndef split_pad_kernel(\n input_ptr,\n output_ptr,\n start_ptr,\n len_ptr,\n hidden_dim,\n stride_i0,\n stride_o0,\n stride_o1,\n BLOCK_SIZE: tl.constexpr,\n):\n pid = tl.program_id(axis=0)\n bid = tl.program_id(axis=1)\n\n i_start = tl.load(start_ptr + bid)\n len = tl.load(len_ptr + bid)\n\n off = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)\n mask = off < len * hidden_dim\n vec = tl.load(input_ptr + i_start * stride_i0 + off, mask=mask)\n\n off1 = off // hidden_dim\n off2 = off % hidden_dim\n\n # mask = off1 < len\n tl.store(output_ptr + bid * stride_o0 + off1 * stride_o1 + off2, vec, mask=mask)\n\n\ndef split_and_pad(input: torch.Tensor, batch_info_set) -> torch.Tensor:\n if type(batch_info_set) == int:\n return input\n assert input.ndim == 2\n assert input.is_contiguous()\n\n batch_info, batch_size, hidden_dim, max_len, start, output = batch_info_set\n split_pad_kernel[\n lambda meta: (triton.cdiv(hidden_dim * max_len, meta[\"BLOCK_SIZE\"]), batch_size)\n ](\n input_ptr=input,\n output_ptr=output,\n start_ptr=start,\n len_ptr=batch_info,\n hidden_dim=hidden_dim,\n stride_i0=input.stride(0),\n stride_o0=output.stride(0),\n stride_o1=output.stride(1),\n BLOCK_SIZE=2048,\n )\n\n return output\n", - "description_1": "Use triton language to define a kernel function 'split_pad_kernel' that splits and pads input tensors based on given batch information. The kernel takes 9 arguments: input_ptr (pointer to the input tensor), output_ptr (pointer to the output tensor), start_ptr (pointer to the start indices for each batch), len_ptr (pointer to the lengths of each batch), hidden_dim (dimension of the hidden layer), stride_i0 (stride of the input tensor along the 0th dimension), stride_o0 (stride of the output tensor along the 0th dimension), stride_o1 (stride of the output tensor along the 1st dimension), and BLOCK_SIZE (block size for triton). A higher-level function 'split_and_pad' is used to set up the kernel execution with the appropriate arguments.", - "description_2": "Use triton language to implement a custom kernel that efficiently performs split and pad operations on tensors by utilizing triton.jit for just-in-time compilation.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\nfrom fla.utils import contiguous\n\n@triton.autotune(\n configs=[\n triton.Config({'BT': 16}, num_warps=2),\n triton.Config({'BT': 16}, num_warps=4),\n triton.Config({'BT': 16}, num_warps=8),\n triton.Config({'BT': 32}, num_warps=2),\n triton.Config({'BT': 32}, num_warps=4),\n triton.Config({'BT': 32}, num_warps=8),\n triton.Config({'BT': 64}, num_warps=2),\n triton.Config({'BT': 64}, num_warps=4),\n triton.Config({'BT': 64}, num_warps=8),\n triton.Config({'BT': 128}, num_warps=2),\n triton.Config({'BT': 128}, num_warps=4),\n triton.Config({'BT': 128}, num_warps=8),\n triton.Config({'BT': 256}, num_warps=2),\n triton.Config({'BT': 256}, num_warps=4),\n triton.Config({'BT': 256}, num_warps=8)\n ],\n key=['D']\n)\n@triton.jit\ndef logsigmoid_fwd_kernel(\n x,\n y,\n T: tl.constexpr,\n D: tl.constexpr,\n BT: tl.constexpr\n):\n i = tl.program_id(0)\n o_i = i * BT + tl.arange(0, BT)\n\n p_x = x + o_i\n p_y = y + o_i\n mask = o_i < T\n\n # [D,]\n b_x = tl.load(p_x, mask=mask, other=0.).to(tl.float32)\n b_m = tl.minimum(0., b_x)\n b_z = 1. + tl.exp(-tl.abs(b_x))\n b_y = b_m - tl.log(b_z)\n tl.store(p_y, b_y.to(p_y.dtype.element_ty), mask=mask)\n\n\n@triton.autotune(\n configs=[\n triton.Config({'BT': 16}, num_warps=2),\n triton.Config({'BT': 16}, num_warps=4),\n triton.Config({'BT': 16}, num_warps=8),\n triton.Config({'BT': 32}, num_warps=2),\n triton.Config({'BT': 32}, num_warps=4),\n triton.Config({'BT': 32}, num_warps=8),\n triton.Config({'BT': 64}, num_warps=2),\n triton.Config({'BT': 64}, num_warps=4),\n triton.Config({'BT': 64}, num_warps=8),\n triton.Config({'BT': 128}, num_warps=2),\n triton.Config({'BT': 128}, num_warps=4),\n triton.Config({'BT': 128}, num_warps=8),\n triton.Config({'BT': 256}, num_warps=2),\n triton.Config({'BT': 256}, num_warps=4),\n triton.Config({'BT': 256}, num_warps=8)\n ],\n key=['D']\n)\n@triton.jit\ndef logsigmoid_bwd_kernel(\n x,\n dx,\n dy,\n T: tl.constexpr,\n D: tl.constexpr,\n BT: tl.constexpr\n):\n i = tl.program_id(0)\n o_i = i * BT + tl.arange(0, BT)\n\n p_x = x + o_i\n p_dx = dx + o_i\n p_dy = dy + o_i\n mask = o_i < T\n\n # [D,]\n b_x = tl.load(p_x, mask=mask, other=0.).to(tl.float32)\n b_dy = tl.load(p_dy, mask=mask, other=0.).to(tl.float32)\n b_dx = b_dy * (1. - tl.sigmoid(b_x))\n tl.store(p_dx, b_dx.to(p_dx.dtype.element_ty), mask=mask)\n\n\nclass LogSigmoidFunction(torch.autograd.Function):\n\n @staticmethod\n @contiguous\n def forward(ctx, x):\n T, D = x.numel(), x.shape[-1]\n y = torch.empty_like(x)\n logsigmoid_fwd_kernel[lambda meta: (triton.cdiv(meta['T'], meta['D']),)](x, y, T=T, D=D)\n ctx.save_for_backward(x,)\n return y\n\n @staticmethod\n @contiguous\n def backward(ctx, dy):\n x, = ctx.saved_tensors\n T, D = x.numel(), x.shape[-1]\n dx = torch.empty_like(x)\n logsigmoid_bwd_kernel[lambda meta: (triton.cdiv(meta['T'], meta['D']),)](x, dx, dy, T=T, D=D)\n return dx\n\n\nlogsigmoid = LogSigmoidFunction.apply\n", - "description_1": "Use triton language to implement a logarithmic sigmoid forward kernel and its backward pass. The forward kernel computes the logarithmic sigmoid of an input tensor 'x' and stores the result in 'y'. It utilizes the triton language capabilities for parallel execution on GPUs. The backward kernel computes the gradient of the input tensor 'x' based on the output gradient 'dy'. Both kernels use parameters T (total elements in x), D (dimension of x), and BT (block size for tiling). A LogSigmoidFunction class encapsulates the use of these kernels for forward and backward operations in an autograd-friendly manner.", - "description_2": "Use triton language to create logarithmic sigmoid forward and backward GPU kernels.", - "difficulty": 2 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n\n@triton.autotune(\n configs=[\n triton.Config({}, num_warps=1),\n triton.Config({}, num_warps=2),\n triton.Config({}, num_warps=4),\n triton.Config({}, num_warps=8),\n triton.Config({}, num_warps=16),\n triton.Config({}, num_warps=32),\n ],\n key=[\"N\", \"HAS_RESIDUAL\", \"STORE_RESIDUAL_OUT\", \"IS_RMS_NORM\", \"HAS_BIAS\"],\n)\n@triton.jit\ndef _layer_norm_fwd_1pass_kernel(\n X, # pointer to the input\n O, # pointer to the gate\n Y, # pointer to the output\n W, # pointer to the weights\n B, # pointer to the biases\n RESIDUAL, # pointer to the residual\n RESIDUAL_OUT, # pointer to the residual out\n Mean, # pointer to the mean\n Rstd, # pointer to the 1/std\n stride_x_row, # how much to increase the pointer when moving by 1 row\n stride_y_row,\n stride_res_row,\n stride_res_out_row,\n N, # number of columns in X\n eps, # epsilon to avoid division by zero\n IS_RMS_NORM: tl.constexpr,\n BLOCK_N: tl.constexpr,\n HAS_RESIDUAL: tl.constexpr,\n STORE_RESIDUAL_OUT: tl.constexpr,\n HAS_WEIGHT: tl.constexpr,\n HAS_BIAS: tl.constexpr\n):\n # Map the program id to the row of X and Y it should compute.\n row = tl.program_id(0)\n X += row * stride_x_row\n Y += row * stride_y_row\n O += row * stride_x_row\n if HAS_RESIDUAL:\n RESIDUAL += row * stride_res_row\n if STORE_RESIDUAL_OUT:\n RESIDUAL_OUT += row * stride_res_out_row\n # Compute mean and variance\n cols = tl.arange(0, BLOCK_N)\n x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)\n if HAS_RESIDUAL:\n residual = tl.load(RESIDUAL + cols, mask=cols <\n N, other=0.0).to(tl.float32)\n x += residual\n if STORE_RESIDUAL_OUT:\n tl.store(RESIDUAL_OUT + cols, x, mask=cols < N)\n if not IS_RMS_NORM:\n mean = tl.sum(x, axis=0) / N\n tl.store(Mean + row, mean)\n xbar = tl.where(cols < N, x - mean, 0.0)\n var = tl.sum(xbar * xbar, axis=0) / N\n else:\n xbar = tl.where(cols < N, x, 0.0)\n var = tl.sum(xbar * xbar, axis=0) / N\n rstd = 1 / tl.sqrt(var + eps)\n tl.store(Rstd + row, rstd)\n # Normalize and apply linear transformation\n mask = cols < N\n if HAS_WEIGHT:\n w = tl.load(W + cols, mask=mask).to(tl.float32)\n if HAS_BIAS:\n b = tl.load(B + cols, mask=mask).to(tl.float32)\n x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd\n y = x_hat * w if HAS_WEIGHT else x_hat\n if HAS_BIAS:\n y = y + b\n\n # Swish output gate\n o = tl.load(O + cols, mask=cols < N, other=0.0).to(tl.float32)\n y = y * o * tl.sigmoid(o)\n\n # Write output\n tl.store(Y + cols, y, mask=mask)\n\n\ndef _layer_norm_fwd(\n x, o, weight, bias, eps, residual=None, out_dtype=None, residual_dtype=None, is_rms_norm=False\n):\n if residual is not None:\n residual_dtype = residual.dtype\n M, N = x.shape\n assert x.stride(-1) == 1\n if residual is not None:\n assert residual.stride(-1) == 1\n assert residual.shape == (M, N)\n if weight is not None:\n assert weight.shape == (N,)\n assert weight.stride(-1) == 1\n if bias is not None:\n assert bias.stride(-1) == 1\n assert bias.shape == (N,)\n # allocate output\n y = torch.empty_like(x, dtype=x.dtype if out_dtype is None else out_dtype)\n assert y.stride(-1) == 1\n if residual is not None or (residual_dtype is not None and residual_dtype != x.dtype):\n residual_out = torch.empty(M, N, device=x.device, dtype=residual_dtype)\n assert residual_out.stride(-1) == 1\n else:\n residual_out = None\n mean = torch.empty((M,), dtype=torch.float32,\n device=\"cuda\") if not is_rms_norm else None\n rstd = torch.empty((M,), dtype=torch.float32, device=\"cuda\")\n # Less than 64KB per feature: enqueue fused kernel\n MAX_FUSED_SIZE = 65536 // x.element_size()\n BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))\n if N > BLOCK_N:\n raise RuntimeError(\n \"This layer norm doesn't support feature dim >= 64KB.\")\n # heuristics for number of warps\n with torch.cuda.device(x.device.index):\n _layer_norm_fwd_1pass_kernel[(M,)](\n x,\n o,\n y,\n weight,\n bias,\n residual,\n residual_out,\n mean,\n rstd,\n x.stride(0),\n y.stride(0),\n residual.stride(0) if residual is not None else 0,\n residual_out.stride(0) if residual_out is not None else 0,\n N,\n eps,\n is_rms_norm,\n BLOCK_N,\n residual is not None,\n residual_out is not None,\n weight is not None,\n bias is not None,\n )\n # residual_out is None if residual is None and residual_dtype == input_dtype\n return y, mean, rstd, residual_out if residual_out is not None else x\n\n\n@triton.autotune(\n configs=[\n triton.Config({}, num_warps=1),\n triton.Config({}, num_warps=2),\n triton.Config({}, num_warps=4),\n triton.Config({}, num_warps=8),\n triton.Config({}, num_warps=16),\n triton.Config({}, num_warps=32),\n ],\n key=[\"N\", \"HAS_DRESIDUAL\", \"STORE_DRESIDUAL\", \"IS_RMS_NORM\", \"HAS_BIAS\"],\n)\n@triton.heuristics({\"RECOMPUTE_OUTPUT\": lambda args: args[\"Y\"] is not None})\n@triton.jit\ndef _layer_norm_bwd_kernel(\n X, # pointer to the input\n O, # pointer to the gate\n W, # pointer to the weights\n B, # pointer to the biases\n Y, # pointer to the output to be recomputed\n DY, # pointer to the output gradient\n DX, # pointer to the input gradient\n DO, # pointer to the gate gradient\n DW, # pointer to the partial sum of weights gradient\n DB, # pointer to the partial sum of biases gradient\n DRESIDUAL,\n DRESIDUAL_IN,\n Mean, # pointer to the mean\n Rstd, # pointer to the 1/std\n stride_x_row, # how much to increase the pointer when moving by 1 row\n stride_y_row,\n stride_dy_row,\n stride_dx_row,\n stride_dres_row,\n stride_dres_in_row,\n M, # number of rows in X\n N, # number of columns in X\n eps, # epsilon to avoid division by zero\n rows_per_program,\n IS_RMS_NORM: tl.constexpr,\n BLOCK_N: tl.constexpr,\n HAS_DRESIDUAL: tl.constexpr,\n STORE_DRESIDUAL: tl.constexpr,\n HAS_WEIGHT: tl.constexpr,\n HAS_BIAS: tl.constexpr,\n RECOMPUTE_OUTPUT: tl.constexpr,\n):\n # Map the program id to the elements of X, DX, and DY it should compute.\n row_block_id = tl.program_id(0)\n row_start = row_block_id * rows_per_program\n cols = tl.arange(0, BLOCK_N)\n mask = cols < N\n X += row_start * stride_x_row\n O += row_start * stride_x_row\n if HAS_DRESIDUAL:\n DRESIDUAL += row_start * stride_dres_row\n if STORE_DRESIDUAL:\n DRESIDUAL_IN += row_start * stride_dres_in_row\n DY += row_start * stride_dy_row\n DX += row_start * stride_dx_row\n DO += row_start * stride_dx_row\n if RECOMPUTE_OUTPUT:\n Y += row_start * stride_y_row\n if HAS_WEIGHT:\n w = tl.load(W + cols, mask=mask).to(tl.float32)\n dw = tl.zeros((BLOCK_N,), dtype=tl.float32)\n if RECOMPUTE_OUTPUT and HAS_BIAS:\n b = tl.load(B + cols, mask=mask, other=0.0).to(tl.float32)\n if HAS_BIAS:\n db = tl.zeros((BLOCK_N,), dtype=tl.float32)\n row_end = min((row_block_id + 1) * rows_per_program, M)\n for row in range(row_start, row_end):\n # Load data to SRAM\n x = tl.load(X + cols, mask=mask, other=0).to(tl.float32)\n o = tl.load(O + cols, mask=mask, other=0).to(tl.float32)\n dy = tl.load(DY + cols, mask=mask, other=0).to(tl.float32)\n\n if not IS_RMS_NORM:\n mean = tl.load(Mean + row)\n rstd = tl.load(Rstd + row)\n # Compute dx\n xhat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd\n xhat = tl.where(mask, xhat, 0.0)\n\n y = xhat * w if HAS_WEIGHT else xhat\n if HAS_BIAS:\n y = y + b\n if RECOMPUTE_OUTPUT:\n tl.store(Y + cols, y, mask=mask)\n\n sigmoid_o = tl.sigmoid(o)\n do = dy * y * (sigmoid_o + o * sigmoid_o * (1 - sigmoid_o))\n dy = dy * o * sigmoid_o\n wdy = dy\n if HAS_WEIGHT:\n wdy = dy * w\n dw += dy * xhat\n if HAS_BIAS:\n db += dy\n if not IS_RMS_NORM:\n c1 = tl.sum(xhat * wdy, axis=0) / N\n c2 = tl.sum(wdy, axis=0) / N\n dx = (wdy - (xhat * c1 + c2)) * rstd\n else:\n c1 = tl.sum(xhat * wdy, axis=0) / N\n dx = (wdy - xhat * c1) * rstd\n if HAS_DRESIDUAL:\n dres = tl.load(DRESIDUAL + cols, mask=mask, other=0).to(tl.float32)\n dx += dres\n # Write dx\n if STORE_DRESIDUAL:\n tl.store(DRESIDUAL_IN + cols, dx, mask=mask)\n tl.store(DX + cols, dx, mask=mask)\n tl.store(DO + cols, do, mask=mask)\n\n X += stride_x_row\n O += stride_x_row\n if HAS_DRESIDUAL:\n DRESIDUAL += stride_dres_row\n if STORE_DRESIDUAL:\n DRESIDUAL_IN += stride_dres_in_row\n if RECOMPUTE_OUTPUT:\n Y += stride_y_row\n DY += stride_dy_row\n DX += stride_dx_row\n DO += stride_dx_row\n if HAS_WEIGHT:\n tl.store(DW + row_block_id * N + cols, dw, mask=mask)\n if HAS_BIAS:\n tl.store(DB + row_block_id * N + cols, db, mask=mask)\n\n\ndef _layer_norm_bwd(\n dy,\n x,\n o,\n weight,\n bias,\n eps,\n mean,\n rstd,\n dresidual=None,\n has_residual=False,\n is_rms_norm=False,\n x_dtype=None,\n recompute_output=False,\n):\n M, N = x.shape\n assert x.stride(-1) == 1\n assert dy.stride(-1) == 1\n assert dy.shape == (M, N)\n if dresidual is not None:\n assert dresidual.stride(-1) == 1\n assert dresidual.shape == (M, N)\n if weight is not None:\n assert weight.shape == (N,)\n assert weight.stride(-1) == 1\n if bias is not None:\n assert bias.stride(-1) == 1\n assert bias.shape == (N,)\n # allocate output\n dx = (\n torch.empty_like(x)\n if x_dtype is None\n else torch.empty(M, N, dtype=x_dtype, device=x.device)\n )\n do = (\n torch.empty_like(o)\n if x_dtype is None\n else torch.empty(M, N, dtype=x_dtype, device=x.device)\n )\n dresidual_in = torch.empty_like(x) if has_residual and dx.dtype != x.dtype else None\n y = torch.empty(M, N, dtype=dy.dtype, device=dy.device) if recompute_output else None\n\n # Less than 64KB per feature: enqueue fused kernel\n MAX_FUSED_SIZE = 65536 // x.element_size()\n BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))\n if N > BLOCK_N:\n raise RuntimeError(\"This layer norm doesn't support feature dim >= 64KB.\")\n sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count\n _dw = (\n torch.empty((sm_count, N), dtype=torch.float32, device=weight.device)\n if weight is not None\n else None\n )\n _db = (\n torch.empty((sm_count, N), dtype=torch.float32, device=bias.device)\n if bias is not None\n else None\n )\n rows_per_program = math.ceil(M / sm_count)\n grid = (sm_count,)\n with torch.cuda.device(x.device.index):\n _layer_norm_bwd_kernel[grid](\n x,\n o,\n weight,\n bias,\n y,\n dy,\n dx,\n do,\n _dw,\n _db,\n dresidual,\n dresidual_in,\n mean,\n rstd,\n x.stride(0),\n 0 if not recompute_output else y.stride(0),\n dy.stride(0),\n dx.stride(0),\n dresidual.stride(0) if dresidual is not None else 0,\n dresidual_in.stride(0) if dresidual_in is not None else 0,\n M,\n N,\n eps,\n rows_per_program,\n is_rms_norm,\n BLOCK_N,\n dresidual is not None,\n dresidual_in is not None,\n weight is not None,\n bias is not None,\n )\n dw = _dw.sum(0).to(weight.dtype) if weight is not None else None\n db = _db.sum(0).to(bias.dtype) if bias is not None else None\n # Don't need to compute dresidual_in separately in this case\n if has_residual and dx.dtype == x.dtype:\n dresidual_in = dx\n return (dx, do, dw, db, dresidual_in) if not recompute_output else (dx, do, dw, db, dresidual_in, y)\n\n\nclass LayerNormSwishGateFn(torch.autograd.Function):\n\n @staticmethod\n @contiguous\n def forward(\n ctx,\n x,\n o,\n weight,\n bias,\n residual=None,\n eps=1e-6,\n prenorm=False,\n residual_in_fp32=False,\n is_rms_norm=False,\n ):\n x_shape_og = x.shape\n o_shape_og = o.shape\n # reshape input data into 2D tensor\n x = x.reshape(-1, x.shape[-1])\n o = o.reshape(-1, o.shape[-1])\n if residual is not None:\n assert residual.shape == x_shape_og\n residual = residual.reshape(-1, residual.shape[-1])\n residual_dtype = (\n residual.dtype\n if residual is not None\n else (torch.float32 if residual_in_fp32 else None)\n )\n y, mean, rstd, residual_out = _layer_norm_fwd(\n x, o, weight, bias, eps, residual, residual_dtype=residual_dtype, is_rms_norm=is_rms_norm\n )\n ctx.save_for_backward(residual_out, o, weight, bias, mean, rstd)\n ctx.x_shape_og = x_shape_og\n ctx.o_shape_og = o_shape_og\n ctx.eps = eps\n ctx.is_rms_norm = is_rms_norm\n ctx.has_residual = residual is not None\n ctx.prenorm = prenorm\n ctx.x_dtype = x.dtype\n y = y.reshape(x_shape_og)\n return y if not prenorm else (y, residual_out.reshape(x_shape_og))\n\n @staticmethod\n @contiguous\n def backward(ctx, dy, *args):\n x, o, weight, bias, mean, rstd = ctx.saved_tensors\n dy = dy.reshape(-1, dy.shape[-1])\n assert dy.shape == x.shape\n if ctx.prenorm:\n dresidual = args[0]\n dresidual = dresidual.reshape(-1, dresidual.shape[-1])\n assert dresidual.shape == x.shape\n else:\n dresidual = None\n dx, do, dw, db, dresidual_in = _layer_norm_bwd(\n dy,\n x,\n o,\n weight,\n bias,\n ctx.eps,\n mean,\n rstd,\n dresidual,\n ctx.has_residual,\n ctx.is_rms_norm,\n x_dtype=ctx.x_dtype,\n )\n return (\n dx.reshape(ctx.x_shape_og),\n do.reshape(ctx.o_shape_og),\n dw,\n db,\n dresidual_in.reshape(ctx.x_shape_og) if ctx.has_residual else None,\n None,\n None,\n None,\n None,\n )\n\ndef layer_norm_swish_gate_fn(\n x,\n o,\n weight,\n bias,\n residual=None,\n prenorm=False,\n residual_in_fp32=False,\n eps=1e-6\n):\n return LayerNormSwishGateFn.apply(\n x,\n o,\n weight,\n bias,\n residual,\n eps,\n prenorm,\n residual_in_fp32,\n False\n )\n", - "description_1": "Use triton language to implement a layer normalization forward and backward kernel with Swish gate function. The forward pass kernel (_layer_norm_fwd_1pass_kernel) takes 19 inputs including pointers to input, gate, output, weights, biases, residuals, mean and rstd, strides, feature size, epsilon, and compile-time constants for conditions. The forward operation computes mean and variance, normalizes the input, applies linear transformations, Swish gating, and stores the output. The backward kernel (_layer_norm_bwd_kernel) has 28 inputs similar to the forward pass, with additional pointers and computations for gradients and applying Swish gating during backpropagation. The forward function for the PyTorch autograd (LayerNormSwishGateFn) utilizes these kernels, reshaping inputs and saving necessary variables for backward computations.", - "description_2": "Use triton language to create a forward kernel for layer normalization with Swish gate, and implement its backward pass for autograd. Use these kernels in a PyTorch custom autograd function.", - "difficulty": 5 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\nimport math\n\n\n@triton.autotune(\n configs=[\n triton.Config({}, num_warps=1),\n triton.Config({}, num_warps=2),\n triton.Config({}, num_warps=4),\n triton.Config({}, num_warps=8),\n triton.Config({}, num_warps=16),\n triton.Config({}, num_warps=32),\n ],\n key=[\"N\", \"HAS_RESIDUAL\", \"STORE_RESIDUAL_OUT\", \"IS_RMS_NORM\", \"HAS_BIAS\"],\n)\n@triton.jit\ndef _layer_norm_fwd_1pass_kernel(\n X, # pointer to the input\n Y, # pointer to the output\n W, # pointer to the weights\n B, # pointer to the biases\n RESIDUAL, # pointer to the residual\n RESIDUAL_OUT, # pointer to the residual\n Mean, # pointer to the mean\n Rstd, # pointer to the 1/std\n stride_x_row, # how much to increase the pointer when moving by 1 row\n stride_y_row,\n stride_res_row,\n stride_res_out_row,\n N, # number of columns in X\n eps, # epsilon to avoid division by zero\n IS_RMS_NORM: tl.constexpr,\n BLOCK_N: tl.constexpr,\n HAS_RESIDUAL: tl.constexpr,\n STORE_RESIDUAL_OUT: tl.constexpr,\n HAS_WEIGHT: tl.constexpr,\n HAS_BIAS: tl.constexpr\n):\n row = tl.program_id(0)\n X += row * stride_x_row\n Y += row * stride_y_row\n if HAS_RESIDUAL:\n RESIDUAL += row * stride_res_row\n if STORE_RESIDUAL_OUT:\n RESIDUAL_OUT += row * stride_res_out_row\n\n cols = tl.arange(0, BLOCK_N)\n x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)\n if HAS_RESIDUAL:\n residual = tl.load(RESIDUAL + cols, mask=cols <\n N, other=0.0).to(tl.float32)\n x += residual\n if STORE_RESIDUAL_OUT:\n tl.store(RESIDUAL_OUT + cols, x, mask=cols < N)\n if not IS_RMS_NORM:\n mean = tl.sum(x, axis=0) / N\n tl.store(Mean + row, mean)\n xbar = tl.where(cols < N, x - mean, 0.0)\n var = tl.sum(xbar * xbar, axis=0) / N\n else:\n xbar = tl.where(cols < N, x, 0.0)\n var = tl.sum(xbar * xbar, axis=0) / N\n rstd = 1 / tl.sqrt(var + eps)\n tl.store(Rstd + row, rstd)\n\n mask = cols < N\n if HAS_WEIGHT:\n w = tl.load(W + cols, mask=mask).to(tl.float32)\n if HAS_BIAS:\n b = tl.load(B + cols, mask=mask).to(tl.float32)\n x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd\n\n y = x_hat * w if HAS_WEIGHT else x_hat\n if HAS_BIAS:\n y = y + b\n\n tl.store(Y + cols, y, mask=mask)\n\n\ndef _layer_norm_fwd(\n x, weight, bias, eps, residual=None, out_dtype=None, residual_dtype=None, is_rms_norm=False\n):\n if residual is not None:\n residual_dtype = residual.dtype\n M, N = x.shape\n assert x.stride(-1) == 1\n if residual is not None:\n assert residual.stride(-1) == 1\n assert residual.shape == (M, N)\n if weight is not None:\n assert weight.shape == (N,)\n assert weight.stride(-1) == 1\n if bias is not None:\n assert bias.stride(-1) == 1\n assert bias.shape == (N,)\n y = torch.empty_like(x, dtype=x.dtype if out_dtype is None else out_dtype)\n assert y.stride(-1) == 1\n if residual is not None or (residual_dtype is not None and residual_dtype != x.dtype):\n residual_out = torch.empty(M, N, device=x.device, dtype=residual_dtype)\n assert residual_out.stride(-1) == 1\n else:\n residual_out = None\n mean = torch.empty((M,), dtype=torch.float32,\n device=\"cuda\") if not is_rms_norm else None\n rstd = torch.empty((M,), dtype=torch.float32, device=\"cuda\")\n\n MAX_FUSED_SIZE = 65536 // x.element_size()\n BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))\n if N > BLOCK_N:\n raise RuntimeError(\n \"This layer norm doesn't support feature dim >= 64KB.\")\n\n with torch.cuda.device(x.device.index):\n _layer_norm_fwd_1pass_kernel[(M,)](\n x,\n y,\n weight,\n bias,\n residual,\n residual_out,\n mean,\n rstd,\n x.stride(0),\n y.stride(0),\n residual.stride(0) if residual is not None else 0,\n residual_out.stride(0) if residual_out is not None else 0,\n N,\n eps,\n is_rms_norm,\n BLOCK_N,\n residual is not None,\n residual_out is not None,\n weight is not None,\n bias is not None,\n )\n return y, mean, rstd, residual_out if residual_out is not None else x\n\n\n@triton.autotune(\n configs=[\n triton.Config({}, num_warps=1),\n triton.Config({}, num_warps=2),\n triton.Config({}, num_warps=4),\n triton.Config({}, num_warps=8),\n triton.Config({}, num_warps=16),\n triton.Config({}, num_warps=32),\n ],\n key=[\"N\", \"HAS_DRESIDUAL\", \"STORE_DRESIDUAL\", \"IS_RMS_NORM\", \"HAS_BIAS\"],\n)\n@triton.heuristics({\"RECOMPUTE_OUTPUT\": lambda args: args[\"Y\"] is not None})\n@triton.jit\ndef _layer_norm_bwd_kernel(\n X, # pointer to the input\n W, # pointer to the weights\n B, # pointer to the biases\n Y, # pointer to the output to be recomputed\n DY, # pointer to the output gradient\n DX, # pointer to the input gradient\n DW, # pointer to the partial sum of weights gradient\n DB, # pointer to the partial sum of biases gradient\n DRESIDUAL,\n DRESIDUAL_IN,\n Mean, # pointer to the mean\n Rstd, # pointer to the 1/std\n stride_x_row, # how much to increase the pointer when moving by 1 row\n stride_y_row,\n stride_dy_row,\n stride_dx_row,\n stride_dres_row,\n stride_dres_in_row,\n M, # number of rows in X\n N, # number of columns in X\n eps, # epsilon to avoid division by zero\n rows_per_program,\n IS_RMS_NORM: tl.constexpr,\n BLOCK_N: tl.constexpr,\n HAS_DRESIDUAL: tl.constexpr,\n STORE_DRESIDUAL: tl.constexpr,\n HAS_WEIGHT: tl.constexpr,\n HAS_BIAS: tl.constexpr,\n RECOMPUTE_OUTPUT: tl.constexpr,\n):\n row_block_id = tl.program_id(0)\n row_start = row_block_id * rows_per_program\n cols = tl.arange(0, BLOCK_N)\n mask = cols < N\n X += row_start * stride_x_row\n if HAS_DRESIDUAL:\n DRESIDUAL += row_start * stride_dres_row\n if STORE_DRESIDUAL:\n DRESIDUAL_IN += row_start * stride_dres_in_row\n DY += row_start * stride_dy_row\n DX += row_start * stride_dx_row\n if RECOMPUTE_OUTPUT:\n Y += row_start * stride_y_row\n if HAS_WEIGHT:\n w = tl.load(W + cols, mask=mask).to(tl.float32)\n dw = tl.zeros((BLOCK_N,), dtype=tl.float32)\n if RECOMPUTE_OUTPUT and HAS_BIAS:\n b = tl.load(B + cols, mask=mask, other=0.0).to(tl.float32)\n if HAS_BIAS:\n db = tl.zeros((BLOCK_N,), dtype=tl.float32)\n row_end = min((row_block_id + 1) * rows_per_program, M)\n for row in range(row_start, row_end):\n x = tl.load(X + cols, mask=mask, other=0).to(tl.float32)\n dy = tl.load(DY + cols, mask=mask, other=0).to(tl.float32)\n if not IS_RMS_NORM:\n mean = tl.load(Mean + row)\n rstd = tl.load(Rstd + row)\n xhat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd\n xhat = tl.where(mask, xhat, 0.0)\n if RECOMPUTE_OUTPUT:\n y = xhat * w if HAS_WEIGHT else xhat\n if HAS_BIAS:\n y = y + b\n tl.store(Y + cols, y, mask=mask)\n wdy = dy\n if HAS_WEIGHT:\n wdy = dy * w\n dw += dy * xhat\n if HAS_BIAS:\n db += dy\n if not IS_RMS_NORM:\n c1 = tl.sum(xhat * wdy, axis=0) / N\n c2 = tl.sum(wdy, axis=0) / N\n dx = (wdy - (xhat * c1 + c2)) * rstd\n else:\n c1 = tl.sum(xhat * wdy, axis=0) / N\n dx = (wdy - xhat * c1) * rstd\n if HAS_DRESIDUAL:\n dres = tl.load(DRESIDUAL + cols, mask=mask, other=0).to(tl.float32)\n dx += dres\n if STORE_DRESIDUAL:\n tl.store(DRESIDUAL_IN + cols, dx, mask=mask)\n tl.store(DX + cols, dx, mask=mask)\n\n X += stride_x_row\n if HAS_DRESIDUAL:\n DRESIDUAL += stride_dres_row\n if STORE_DRESIDUAL:\n DRESIDUAL_IN += stride_dres_in_row\n if RECOMPUTE_OUTPUT:\n Y += stride_y_row\n DY += stride_dy_row\n DX += stride_dx_row\n if HAS_WEIGHT:\n tl.store(DW + row_block_id * N + cols, dw, mask=mask)\n if HAS_BIAS:\n tl.store(DB + row_block_id * N + cols, db, mask=mask)\n\n\ndef _layer_norm_bwd(\n dy,\n x,\n weight,\n bias,\n eps,\n mean,\n rstd,\n dresidual=None,\n has_residual=False,\n is_rms_norm=False,\n x_dtype=None,\n recompute_output=False,\n):\n M, N = x.shape\n assert x.stride(-1) == 1\n assert dy.stride(-1) == 1\n assert dy.shape == (M, N)\n if dresidual is not None:\n assert dresidual.stride(-1) == 1\n assert dresidual.shape == (M, N)\n if weight is not None:\n assert weight.shape == (N,)\n assert weight.stride(-1) == 1\n if bias is not None:\n assert bias.stride(-1) == 1\n assert bias.shape == (N,)\n dx = (\n torch.empty_like(x)\n if x_dtype is None\n else torch.empty(M, N, dtype=x_dtype, device=x.device)\n )\n dresidual_in = torch.empty_like(\n x) if has_residual and dx.dtype != x.dtype else None\n y = torch.empty(M, N, dtype=dy.dtype,\n device=dy.device) if recompute_output else None\n\n MAX_FUSED_SIZE = 65536 // x.element_size()\n BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))\n if N > BLOCK_N:\n raise RuntimeError(\n \"This layer norm doesn't support feature dim >= 64KB.\")\n sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count\n _dw = (\n torch.empty((sm_count, N), dtype=torch.float32, device=weight.device)\n if weight is not None\n else None\n )\n _db = (\n torch.empty((sm_count, N), dtype=torch.float32, device=bias.device)\n if bias is not None\n else None\n )\n rows_per_program = math.ceil(M / sm_count)\n grid = (sm_count,)\n with torch.cuda.device(x.device.index):\n _layer_norm_bwd_kernel[grid](\n x,\n weight,\n bias,\n y,\n dy,\n dx,\n _dw,\n _db,\n dresidual,\n dresidual_in,\n mean,\n rstd,\n x.stride(0),\n 0 if not recompute_output else y.stride(0),\n dy.stride(0),\n dx.stride(0),\n dresidual.stride(0) if dresidual is not None else 0,\n dresidual_in.stride(0) if dresidual_in is not None else 0,\n M,\n N,\n eps,\n rows_per_program,\n is_rms_norm,\n BLOCK_N,\n dresidual is not None,\n dresidual_in is not None,\n weight is not None,\n bias is not None,\n )\n dw = _dw.sum(0).to(weight.dtype) if weight is not None else None\n db = _db.sum(0).to(bias.dtype) if bias is not None else None\n if has_residual and dx.dtype == x.dtype:\n dresidual_in = dx\n return (dx, dw, db, dresidual_in) if not recompute_output else (dx, dw, db, dresidual_in, y)\n", - "description_1": "Use triton language to implement a fused forward and backward kernel for layer normalization, with support for optional residual connections and the choice between standard and RMS normalization. The forward kernel takes inputs, weights, biases, and computes the normalized output and intermediate statistics (mean, variance or inverse std). The backward kernel computes gradients for the inputs, weights, and biases using the output gradients and stored statistics, with optional recomputation of output to save memory.", - "description_2": "Use triton language to implement fused kernels for efficient layer normalization operations, with optional residuals and RMS norm support.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef chunk_abc_fwd_kernel_h(\n k,\n v,\n z,\n h,\n h0,\n ht,\n s_k_h,\n s_k_t,\n s_k_d,\n s_v_h,\n s_v_t,\n s_v_d,\n s_h_h,\n s_h_t,\n s_h_d,\n T: tl.constexpr,\n K: tl.constexpr,\n V: tl.constexpr,\n BT: tl.constexpr,\n BK: tl.constexpr,\n BV: tl.constexpr,\n NT: tl.constexpr,\n NORMK: tl.constexpr,\n USE_INITIAL_STATE: tl.constexpr,\n STORE_FINAL_STATE: tl.constexpr\n):\n i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n\n b_h = tl.zeros([BK, BV], dtype=tl.float32)\n if USE_INITIAL_STATE:\n p_h = tl.make_block_ptr(h0 + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))\n b_h += tl.load(p_h, boundary_check=(0, 1)).to(tl.float32)\n if NORMK:\n p_z0 = tl.make_block_ptr(z + i_bh * s_k_h, (T * K,), (s_k_d,), (i_k * BK,), (BK,), (0,))\n else:\n p_z0 = tl.make_block_ptr(z + i_bh * s_v_h, (T * V,), (s_v_d,), (i_v * BV,), (BV,), (0,))\n b_zp = tl.load(p_z0).to(tl.float32)\n for i_t in range(NT):\n p_k = tl.make_block_ptr(k + i_bh * s_k_h, (K, T), (s_k_d, s_k_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1))\n p_v = tl.make_block_ptr(v + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))\n p_h = tl.make_block_ptr(h + i_bh * s_h_h + i_t * K * V, (K, V), (s_h_t, s_h_d), (i_k * BK, i_v * BV), (BK, BV), (1, 0))\n\n tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1))\n # [BK, BT]\n b_k = tl.load(p_k, boundary_check=(0, 1))\n # [BT, BV]\n b_v = tl.load(p_v, boundary_check=(0, 1))\n if NORMK:\n p_zc = tl.make_block_ptr(z + i_bh * s_k_h, (T * K,), (s_k_d,), ((i_t * BT + BT - 1) * K + i_k * BK,), (BK,), (0,))\n # [BK,]\n b_zc = tl.load(p_zc, boundary_check=(0,))\n b_r, b_zp = tl.exp(b_zp - b_zc), b_zc\n # [BK, BV]\n b_h = b_h * b_r[:, None]\n b_k = tl.exp(b_k - b_zc[:, None]).to(b_k.dtype)\n else:\n p_zc = tl.make_block_ptr(z + i_bh * s_v_h, (T * V,), (s_v_d,), ((i_t * BT + BT - 1) * V + i_v * BV,), (BV,), (0,))\n # [BV,]\n b_zc = tl.load(p_zc, boundary_check=(0,))\n b_r, b_zp = tl.exp(b_zp - b_zc), b_zc\n # [BK, BV]\n b_h = b_h * b_r[None, :]\n b_v = tl.exp(b_v - b_zc[None, :]).to(b_v.dtype)\n # [BK, BV]\n b_h += tl.dot(b_k, b_v, allow_tf32=False)\n\n if STORE_FINAL_STATE:\n p_h = tl.make_block_ptr(ht + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))\n tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1))\n\n\n@triton.jit\ndef chunk_abc_fwd_kernel_K(\n q,\n k,\n z,\n h,\n o,\n A,\n s_k_h,\n s_k_t,\n s_k_d,\n s_v_h,\n s_v_t,\n s_v_d,\n s_h_h,\n s_h_t,\n s_h_d,\n scale,\n T: tl.constexpr,\n K: tl.constexpr,\n V: tl.constexpr,\n BT: tl.constexpr,\n BK: tl.constexpr,\n BV: tl.constexpr\n):\n i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n i_p = tl.maximum(i_t * BT - 1, 0)\n\n o_i = tl.arange(0, BT)\n m_s = o_i[:, None] >= o_i[None, :]\n\n b_o = tl.zeros([BT, BV], dtype=tl.float32)\n b_A = tl.zeros([BT, BT], dtype=tl.float32)\n for i_k in range(tl.cdiv(K, BK)):\n p_q = tl.make_block_ptr(q + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))\n p_k = tl.make_block_ptr(k + i_bh * s_k_h, (K, T), (s_k_d, s_k_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1))\n p_h = tl.make_block_ptr(h + i_bh * s_h_h + i_t * K * V, (K, V), (s_h_t, s_h_d), (i_k * BK, i_v * BV), (BK, BV), (1, 0))\n\n # [BT, BK]\n b_q = tl.load(p_q, boundary_check=(0, 1))\n b_q = (b_q * scale).to(b_q.dtype)\n # [BK, BT]\n b_k = tl.load(p_k, boundary_check=(0, 1))\n # [BK, BV]\n b_h = tl.load(p_h, boundary_check=(0, 1))\n # [BT, BV]\n b_o += tl.dot(b_q, b_h, allow_tf32=False)\n # [BT, BT]\n b_A += tl.dot(b_q, b_k, allow_tf32=False)\n p_z = tl.make_block_ptr(z + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))\n p_o = tl.make_block_ptr(o + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))\n # [BT, BV]\n b_z = tl.load(p_z, boundary_check=(0, 1))\n # [BT, BV]\n p_zp = tl.make_block_ptr(z + i_bh * s_v_h, (T * V,), (s_v_d,), (i_p * V + i_v * BV,), (BV,), (0,))\n b_zp = tl.load(p_zp, boundary_check=(0,))\n b_o = b_o * tl.exp(b_zp[None, :] - b_z)\n tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))\n\n p_A = tl.make_block_ptr(A + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))\n # [BT, BT]\n b_A = tl.where(m_s, b_A, 0.)\n if i_v == 0:\n tl.store(p_A, b_A.to(p_A.dtype.element_ty), boundary_check=(0, 1))\n\n\ndef chunk_abc(\n q: torch.Tensor,\n k: torch.Tensor,\n v: torch.Tensor,\n s: torch.Tensor,\n initial_state: Optional[Tuple[torch.Tensor]] = None,\n output_final_state: Optional[bool] = False\n) -> Tuple[torch.Tensor, Tuple[torch.Tensor]]:\n if initial_state is not None:\n initial_state = tuple(i.detach() for i in initial_state)\n ov, final_state = ChunkABCFunction.apply(q, k, v, s, initial_state, output_final_state)\n return ov, final_state\n", - "description_1": "Use triton language to define several kernel functions for a custom operation involving tensors q, k, v, and s. Implement forward kernel chunk_abc_fwd_kernel_h with 22 parameters, handling operations like loading and storing tensor blocks, handling initial and final state, calculating norms, etc. Implement kernel chunk_abc_fwd_kernel_K with 22 parameters for calculating matrix product and transformation involving q, k, z, h, and other tensors.", - "description_2": "Use triton language to create kernels for performing tensor computations related to chunked attention, handling initial and final states, and applying softmax operations within the kernels.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n\n@triton.jit\ndef chunk_gated_abc_fwd_kernel_cum(\n s, o, s_s_h, s_s_t, s_s_d,\n T: tl.constexpr, S: tl.constexpr, BT: tl.constexpr, BS: tl.constexpr\n):\n i_s, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n o_i = tl.arange(0, BT)\n m_s = tl.where(o_i[:, None] >= o_i[None, :], 1., 0.).to(tl.float32)\n\n p_s = tl.make_block_ptr(s + i_bh * s_s_h, (T, S), (s_s_t, s_s_d), (i_t * BT, i_s * BS), (BT, BS), (1, 0))\n p_o = tl.make_block_ptr(o + i_bh * s_s_h, (T, S), (s_s_t, s_s_d), (i_t * BT, i_s * BS), (BT, BS), (1, 0))\n b_s = tl.load(p_s, boundary_check=(0, 1)).to(tl.float32)\n b_o = tl.dot(m_s, b_s, allow_tf32=False)\n tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))\n\n\ndef fwd_pre(g, B, H, T, S, BT):\n NT = triton.cdiv(T, BT)\n g_org, g = g, torch.empty_like(g, dtype=torch.float)\n def grid(meta): return (triton.cdiv(meta['S'], meta['BS']), NT, B * H)\n chunk_gated_abc_fwd_kernel_cum[grid](\n g_org, g,\n g.stride(1), g.stride(2), g.stride(3),\n T=T, S=S, BT=BT\n )\n return g\n", - "description_1": "Use triton language to implement a forward kernel for chunk gated operations, where the kernel handles cumulative operations. The kernel takes 8 parameters: s (input tensor), o (output tensor), s_s_h (stride in the first dimension), s_s_t (stride in the second dimension), s_s_d (stride in the third dimension), and three compile-time constants: T, S, and BT (block sizes). The kernel uses triton's make_block_ptr, load, and store functions to manage memory blocks and perform matrix multiplication operations using triton's dot product function.", - "description_2": "Use triton language to create a kernel that computes cumulative operations for chunk gated operations in a tensor, utilizing matrix multiplication and block memory pointers for efficient computation.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\nfrom torch.cuda.amp import custom_bwd, custom_fwd\n\n@triton.jit\ndef fused_recurrent_gated_abc_fwd_kernel(\n q, k, v, gk, gv, o, h0, ht, s_k_h, s_v_h, scale,\n B: tl.constexpr, H: tl.constexpr, T: tl.constexpr, K: tl.constexpr,\n V: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr, USE_INITIAL_STATE: tl.constexpr,\n STORE_FINAL_STATE: tl.constexpr, REVERSE: tl.constexpr, USE_GK: tl.constexpr,\n USE_GV: tl.constexpr\n):\n # indices\n i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n\n p_q = q + i_bh * s_k_h + i_k * BK + tl.arange(0, BK) + ((T-1) * K if REVERSE else 0)\n p_k = k + i_bh * s_k_h + i_k * BK + tl.arange(0, BK) + ((T-1) * K if REVERSE else 0)\n p_v = v + i_bh * s_v_h + i_v * BV + tl.arange(0, BV) + ((T-1) * V if REVERSE else 0)\n p_o = o + (i_bh + i_k * B * H) * s_v_h + i_v * BV + tl.arange(0, BV) + ((T-1) * V if REVERSE else 0)\n\n if USE_GK:\n p_gk = gk + i_bh * s_k_h + i_k * BK + tl.arange(0, BK) + ((T-1) * K if REVERSE else 0)\n if USE_GV:\n p_gv = gv + i_bh * s_v_h + i_v * BV + tl.arange(0, BV) + ((T-1) * V if REVERSE else 0)\n\n mask_bk = (i_k * BK + tl.arange(0, BK)) < K\n mask_bv = (i_v * BV + tl.arange(0, BV)) < V\n\n h = tl.zeros([BV, BK], dtype=tl.float32)\n mask_kv = mask_bk[None, :] & mask_bv[:, None]\n\n if USE_INITIAL_STATE:\n p_h0 = h0 + i_bh * K * V + (i_k * BK + tl.arange(0, BK)[None, :]) * V + (i_v * BV + tl.arange(0, BV)[:, None])\n h += tl.load(p_h0, mask=mask_kv, other=0).to(tl.float32)\n\n for _ in range(0, T):\n b_q = tl.load(p_q, mask=mask_bk, other=0).to(tl.float32) * scale\n b_k = tl.load(p_k, mask=mask_bk, other=0).to(tl.float32)\n b_v = tl.load(p_v, mask=mask_bv, other=0).to(tl.float32)\n if USE_GK:\n b_gk = tl.load(p_gk, mask=mask_bk, other=0).to(tl.float32)\n h = h * b_gk[None, :]\n if USE_GV:\n b_gv = tl.load(p_gv, mask=mask_bv, other=0).to(tl.float32)\n h = h * b_gv[:, None]\n h += b_k[None, :] * b_v[:, None]\n b_o = h * b_q[None, :]\n b_o = tl.sum(b_o, axis=1)\n tl.store(p_o, b_o.to(p_o.dtype.element_ty), mask=mask_bv)\n p_q += -K if REVERSE else K\n p_k += -K if REVERSE else K\n p_o += -V if REVERSE else V\n p_v += -V if REVERSE else V\n if USE_GK:\n p_gk += -K if REVERSE else K\n if USE_GV:\n p_gv += -V if REVERSE else V\n\n if STORE_FINAL_STATE:\n p_ht = ht + i_bh * K * V + (i_k * BK + tl.arange(0, BK)[None, :]) * V + (i_v * BV + tl.arange(0, BV)[:, None])\n tl.store(p_ht, h.to(p_ht.dtype.element_ty), mask=mask_kv)\n\n\n@triton.jit\ndef fused_recurrent_gated_abc_bwd_kernel(\n q, k, v, gk, gv, do, dq, dk, dv, h0, s_k_h, s_v_h, scale,\n B: tl.constexpr, H: tl.constexpr, T: tl.constexpr, K: tl.constexpr,\n V: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr, USE_INITIAL_STATE: tl.constexpr,\n REVERSE: tl.constexpr, USE_GK: tl.constexpr, USE_GV: tl.constexpr\n):\n i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n\n p_q = q + i_bh * s_k_h + i_k * BK + tl.arange(0, BK) + ((T-1) * K if REVERSE else 0)\n p_k = k + i_bh * s_k_h + i_k * BK + tl.arange(0, BK) + ((T-1) * K if REVERSE else 0)\n p_v = v + i_bh * s_v_h + i_v * BV + tl.arange(0, BV) + ((T-1) * V if REVERSE else 0)\n p_do = do + i_bh * s_v_h + i_v * BV + tl.arange(0, BV) + ((T-1) * V if REVERSE else 0)\n p_dq = dq + (i_bh + i_v * B * H) * s_k_h + i_k * BK + tl.arange(0, BK) + ((T-1) * K if REVERSE else 0)\n if USE_GK:\n p_gk = gk + i_bh * s_k_h + i_k * BK + tl.arange(0, BK) + ((T-1) * K if REVERSE else 0)\n if USE_GV:\n p_gv = gv + i_bh * s_v_h + i_v * BV + tl.arange(0, BV) + ((T-1) * V if REVERSE else 0)\n mask_bk = i_k * BK + tl.arange(0, BK) < K\n mask_bv = i_v * BV + tl.arange(0, BV) < V\n mask_kv = mask_bk[:, None] & mask_bv[None, :]\n h = tl.zeros([BK, BV], dtype=tl.float32)\n\n if USE_INITIAL_STATE:\n p_h0 = h0 + i_bh * K * V + (i_k * BK + tl.arange(0, BK)[:, None]) * V + (i_v * BV + tl.arange(0, BV)[None, :])\n h += tl.load(p_h0, mask=mask_kv, other=0).to(tl.float32)\n\n for _ in range(0, T):\n b_k = tl.load(p_k, mask=mask_bk, other=0).to(tl.float32)\n b_v = tl.load(p_v, mask=mask_bv, other=0).to(tl.float32)\n b_do = tl.load(p_do, mask=mask_bv, other=0).to(tl.float32)\n if USE_GK:\n b_gk = tl.load(p_gk, mask=mask_bk, other=0).to(tl.float32)\n h = h * b_gk[:, None]\n if USE_GV:\n b_gv = tl.load(p_gv, mask=mask_bv, other=0).to(tl.float32)\n h = h * b_gv[None, :]\n h += b_k[:, None] * b_v[None, :]\n b_dq = tl.sum(h * b_do[None, :], axis=1) * scale\n tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), mask=mask_bk)\n\n p_k += -K if REVERSE else K\n p_v += -V if REVERSE else V\n p_q += -K if REVERSE else K\n p_do += -V if REVERSE else V\n p_dq += -K if REVERSE else K\n if USE_GK:\n p_gk += -K if REVERSE else K\n if USE_GV:\n p_gv += -V if REVERSE else V\n\n # sync threads\n tl.debug_barrier()\n\n p_q = q + i_bh * s_k_h + i_k * BK + tl.arange(0, BK) + ((T - 1) * K if not REVERSE else 0)\n p_k = k + i_bh * s_k_h + i_k * BK + tl.arange(0, BK) + ((T - 1) * K if not REVERSE else 0)\n p_v = v + i_bh * s_v_h + i_v * BV + tl.arange(0, BV) + ((T - 1) * V if not REVERSE else 0)\n p_do = do + i_bh * s_v_h + i_v * BV + tl.arange(0, BV) + ((T - 1) * V if not REVERSE else 0)\n p_dk = dk + (i_bh + i_v * B * H) * s_k_h + i_k * BK + tl.arange(0, BK) + ((T - 1) * K if not REVERSE else 0)\n p_dv = dv + (i_bh + i_k * B * H) * s_v_h + i_v * BV + tl.arange(0, BV) + ((T - 1) * V if not REVERSE else 0)\n if USE_GK:\n p_gk = gk + i_bh * s_k_h + i_k * BK + tl.arange(0, BK) + ((T - 1) * K if not REVERSE else 0)\n if USE_GV:\n p_gv = gv + i_bh * s_v_h + i_v * BV + tl.arange(0, BV) + ((T - 1) * V if not REVERSE else 0)\n\n b_dh = tl.zeros([BK, BV], dtype=tl.float32)\n for _ in range(T):\n b_q = tl.load(p_q, mask=mask_bk, other=0).to(tl.float32) * scale\n b_k = tl.load(p_k, mask=mask_bk, other=0).to(tl.float32)\n b_v = tl.load(p_v, mask=mask_bv, other=0).to(tl.float32)\n b_do = tl.load(p_do, mask=mask_bv, other=0).to(tl.float32)\n b_dh += b_q[:, None] * b_do[None, :]\n b_dk = tl.sum(b_dh * b_v[None, :], axis=1)\n b_dv = tl.sum(b_dh * b_k[:, None], axis=0)\n if USE_GK:\n b_gk = tl.load(p_gk, mask=mask_bk, other=0).to(tl.float32)\n b_dh *= b_gk[:, None]\n if USE_GV:\n b_gv = tl.load(p_gv, mask=mask_bv, other=0).to(tl.float32)\n b_dh *= b_gv[None, :]\n tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), mask=mask_bk)\n tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), mask=mask_bv)\n\n p_q += K if REVERSE else -K\n p_k += K if REVERSE else -K\n p_v += V if REVERSE else -V\n p_do += V if REVERSE else -V\n p_dk += K if REVERSE else -K\n p_dv += V if REVERSE else -V\n if USE_GK:\n p_gk += K if REVERSE else -K\n if USE_GV:\n p_gv += V if REVERSE else -V\n\n\nclass FusedRecurrentGatedABCFunction(torch.autograd.Function):\n\n @staticmethod\n @custom_fwd\n def forward(ctx, q, k, v, s, g, scale=None, initial_state=None, output_final_state=False, reverse=False):\n B, H, T, K, V, M = *q.shape, v.shape[-1], s.shape[-1]\n if scale is None:\n scale = K ** -0.5\n\n BK, BV, BM = min(K, 32), min(V, 32), min(M, 32)\n NK, NV, NM = triton.cdiv(K, BK), triton.cdiv(V, BV), triton.cdiv(M, BM)\n num_stages = 1\n num_warps = 1\n\n g = g.float().exp()\n\n final_state = (None, None)\n if output_final_state:\n final_state = (q.new_empty(B, H, K, M), q.new_empty(B, H, M, V))\n\n ok = q.new_empty(NK, B, H, T, M, dtype=torch.float)\n gk, gv = None, g\n grid = (NM, NK, B * H)\n fused_recurrent_gated_abc_fwd_kernel[grid](\n q, k, s, gk, gv, ok, initial_state[0], final_state[0],\n k.stride(1),\n s.stride(1),\n scale=scale,\n B=B, H=H, T=T, K=K, V=M, BK=BK, BV=BM,\n USE_INITIAL_STATE=initial_state[0] is not None,\n STORE_FINAL_STATE=final_state[0] is not None,\n USE_GK=False,\n USE_GV=True,\n REVERSE=reverse,\n num_warps=num_warps,\n num_stages=num_stages\n )\n ok = ok.sum(0)\n\n qv = ok.softmax(-1, dtype=torch.float)\n ov = q.new_empty(NM, B, H, T, V, dtype=torch.float)\n gk, gv = g, None\n grid = (NV, NM, B * H)\n fused_recurrent_gated_abc_fwd_kernel[grid](\n qv, s, v, gk, gv, ov, initial_state[1], final_state[1],\n s.stride(1),\n v.stride(1),\n scale=1.,\n B=B, H=H, T=T, K=M, V=V, BK=BM, BV=BV,\n USE_INITIAL_STATE=initial_state[0] is not None,\n STORE_FINAL_STATE=final_state[0] is not None,\n USE_GK=True,\n USE_GV=False,\n REVERSE=reverse,\n num_warps=num_warps,\n num_stages=num_stages\n )\n ov = ov.sum(0)\n\n ctx.save_for_backward(q, k, v, s, g, qv, *initial_state, ok)\n ctx.scale = scale\n ctx.reverse = reverse\n if final_state is not None:\n final_state = tuple(i.detach() for i in final_state)\n return ov.to(q.dtype), final_state\n\n @staticmethod\n @custom_bwd\n def backward(ctx, do, dht=None):\n q, k, v, s, g, qv, *initial_state, ok = ctx.saved_tensors\n B, H, T, K, V, M = *q.shape, v.shape[-1], s.shape[-1]\n V = v.shape[-1]\n scale = ctx.scale\n\n BK, BV, BM = min(K, 32), min(V, 32), min(M, 32)\n NK, NV, NM = triton.cdiv(K, BK), triton.cdiv(V, BV), triton.cdiv(M, BM)\n num_stages = 1\n num_warps = 1\n\n dqv = q.new_empty(NV, B, H, T, M, dtype=torch.float)\n dsv = q.new_empty(NV, B, H, T, M, dtype=torch.float)\n dv = q.new_empty(NM, B, H, T, V, dtype=torch.float)\n gk, gv = g, None\n grid = (NV, NM, B * H)\n fused_recurrent_gated_abc_bwd_kernel[grid](\n qv, s, v, gk, gv, do, dqv, dsv, dv, initial_state[1],\n s.stride(1),\n v.stride(1),\n scale=1.,\n B=B, H=H, T=T, K=M, V=V, BK=BM, BV=BV,\n num_warps=num_warps,\n num_stages=num_stages,\n USE_INITIAL_STATE=initial_state[1] is not None,\n REVERSE=ctx.reverse,\n USE_GK=gk is not None,\n USE_GV=gv is not None\n )\n dqv = dqv.sum(0)\n dsv = dsv.sum(0)\n dv = dv.sum(0)\n dgk = dqv * qv.float() - dsv * s.float()\n dgk_cumsum = dgk.cumsum(-2)\n dgk = dgk + dgk_cumsum[:, :, -1, None] - dgk_cumsum\n\n dok = qv * (dqv - (qv * dqv).sum(-1, True))\n dq = q.new_empty(NM, B, H, T, K, dtype=torch.float)\n dk = q.new_empty(NM, B, H, T, K, dtype=torch.float)\n dsk = q.new_empty(NK, B, H, T, M, dtype=torch.float)\n gk, gv = None, g\n grid = (NM, NK, B * H)\n fused_recurrent_gated_abc_bwd_kernel[grid](\n q, k, s, gk, gv, dok, dq, dk, dsk, initial_state[0],\n q.stride(1),\n s.stride(1),\n scale=scale,\n B=B, H=H, T=T, K=K, V=M, BK=BK, BV=BM,\n num_warps=num_warps,\n num_stages=num_stages,\n USE_INITIAL_STATE=initial_state[0] is not None,\n REVERSE=ctx.reverse,\n USE_GK=gk is not None,\n USE_GV=gv is not None\n )\n dq = dq.sum(0)\n dk = dk.sum(0)\n dsk = dsk.sum(0)\n\n dgv = dok.float() * ok.float() - dsk * s.float()\n dgv_cumsum = dgv.cumsum(-2)\n dgv = dgv + dgv_cumsum[:, :, -1, None] - dgv_cumsum\n\n ds = dsk.add_(dsv)\n dg = dgk.add_(dgv)\n\n return dq.to(q), dk.to(k), dv.to(v), ds.to(s), dg.to(g), None, None, None, None\n\n\ndef fused_recurrent_gated_abc(\n q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, s: torch.Tensor,\n g: Optional[torch.Tensor] = None, scale: Optional[int] = None,\n initial_state: Optional[Tuple[torch.Tensor]] = None,\n output_final_state: Optional[bool] = False\n) -> Tuple[torch.Tensor, torch.Tensor]:\n if initial_state is not None:\n initial_state = tuple(i.detach() for i in initial_state)\n if g is None:\n z = s.float().logcumsumexp(2)\n g = torch.cat((z[:, :, :1], z[:, :, :-1]), 2) - z\n s = torch.exp(s - z).to(k.dtype)\n if scale is None:\n scale = q.shape[-1] ** -0.5\n ov, final_state = FusedRecurrentGatedABCFunction.apply(q, k, v, s, g, scale, initial_state, output_final_state)\n return ov, final_state\n", - "description_1": "Use triton language to implement two kernels, one for forward and one for backward pass of a fused recurrent gated operation on multi-dimensional tensors with a specified set of parameters including queries, keys, values, forget gates, scale, initial states, etc.", - "description_2": "Use triton language to perform forward and backward passes of a gated recurrent operation on tensors with parameters for queries, keys, values, forget gates, and other configurations.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef fused_chunk_based_fwd_kernel(\n q, k, v, o, z,\n s_qk_h, s_qk_t, s_qk_d,\n s_vo_h, s_vo_t, s_vo_d,\n B, H, T, scale,\n BT: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr,\n DK: tl.constexpr, DV: tl.constexpr\n):\n i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n o_i = tl.arange(0, BT)\n m_s = o_i[:, None] >= o_i[None, :]\n b_h_0o = tl.zeros([BV], dtype=tl.float32)\n b_h_1o = tl.zeros([BK, BV], dtype=tl.float32)\n b_h_2o = tl.zeros([BK*BK, BV], dtype=tl.float32)\n p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (0, i_k * BK), (BT, BK), (1, 0))\n p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (DK, T), (s_qk_d, s_qk_t), (i_k * BK, 0), (BK, BT), (0, 1))\n p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (0, i_v * BV), (BT, BV), (1, 0))\n p_o = tl.make_block_ptr(o + (i_bh + i_k*B*H) * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (0, i_v * BV), (BT, BV), (1, 0))\n p_z = z + (i_bh + i_k * B * H) * T + tl.arange(0, BT)\n k_2o = tl.zeros([1, BK * BK], dtype=tl.float32)\n k_1o = tl.zeros([1, BK], dtype=tl.float32)\n k_0o = 0\n\n for i in range(0, tl.cdiv(T, BT)):\n b_k = tl.load(p_k, boundary_check=(0, 1))\n b_k_2o = b_k[:, None, :] * b_k[None, :, :]\n b_k_2o = tl.reshape(b_k_2o, [BK * BK, BT]).to(b_k.dtype)\n b_v = tl.load(p_v, boundary_check=(0, 1))\n b_q = (tl.load(p_q, boundary_check=(0, 1)) * scale).to(b_k.dtype)\n b_o = tl.zeros([BT, BV], dtype=tl.float32)\n b_z = tl.zeros([BT], dtype=tl.float32)\n\n b_o += b_h_0o\n b_z += k_0o\n b_o += tl.dot(b_q, b_h_1o.to(b_q.dtype), allow_tf32=False)\n b_z += tl.sum(b_q * k_1o, axis=1)\n b_q_2o = b_q[:, :, None] * b_q[:, None, :]\n b_q_2o = tl.reshape(b_q_2o, [BT, BK * BK]).to(b_k.dtype)\n b_o += tl.dot(b_q_2o, b_h_2o.to(b_q_2o.dtype), allow_tf32=False) * 0.5\n b_z += tl.sum(b_q_2o * k_2o, axis=1) * 0.5\n\n k_1o += tl.sum(b_k, axis=1)[None, :]\n k_2o += tl.sum(b_k_2o, axis=1)[None, :]\n k_0o += BT\n\n b_s = tl.dot(b_q, b_k, allow_tf32=False)\n b_s = 1 + b_s + 0.5 * b_s * b_s\n b_s = tl.where(m_s, b_s, 0)\n b_z += tl.sum(b_s, axis=1)\n b_o += tl.dot(b_s.to(b_q.dtype), b_v, allow_tf32=False)\n\n tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))\n tl.store(p_z, b_z.to(p_z.dtype.element_ty), mask=(i * BT + tl.arange(0, BT)) < T)\n\n b_h_2o = b_h_2o + tl.dot(b_k_2o.to(b_v.dtype), b_v, allow_tf32=False)\n b_h_1o = b_h_1o + tl.dot(b_k, b_v, allow_tf32=False)\n b_h_0o = b_h_0o + tl.sum(b_v, axis=0)\n\n p_q = tl.advance(p_q, (BT, 0))\n p_k = tl.advance(p_k, (0, BT))\n p_v = tl.advance(p_v, (BT, 0))\n p_o = tl.advance(p_o, (BT, 0))\n p_z += BT\n\n@triton.jit\ndef fused_chunk_based_bwd_kernel(\n q, k, v, do, dz, dq, dk, dv,\n s_qk_h, s_qk_t, s_qk_d,\n s_vo_h, s_vo_t, s_vo_d,\n B, H, T, scale,\n BT: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr,\n DK: tl.constexpr, DV: tl.constexpr\n):\n i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n o_i = tl.arange(0, BT)\n m_s = o_i[:, None] >= o_i[None, :]\n b_h_1o = tl.zeros([BV, BK], dtype=tl.float32)\n b_h_2o = tl.zeros([BV, BK*BK], dtype=tl.float32)\n\n k_1o = tl.zeros([1, BK], dtype=tl.float32)\n k_2o = tl.zeros([1, BK * BK], dtype=tl.float32)\n\n for i in range(0, tl.cdiv(T, BT)):\n p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (i * BT, i_k * BK), (BT, BK), (1, 0))\n p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (i * BT, i_k * BK), (BT, BK), (1, 0))\n p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (DV, T), (s_vo_d, s_vo_t), (i_v * BV, i * BT), (BV, BT), (0, 1))\n p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (i * BT, i_v * BV), (BT, BV), (1, 0))\n p_dq = tl.make_block_ptr(dq + (i_bh + i_v*B*H) * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (i*BT, i_k*BK), (BT, BK), (1, 0))\n p_dz = dz + (i_bh) * T + tl.arange(0, BT) + i * BT\n b_dq = tl.zeros([BT, BK], dtype=tl.float32)\n\n b_dz = tl.load(p_dz, mask=(tl.arange(0, BT) + i * BT) < T)\n b_q = tl.load(p_q, boundary_check=(0, 1))\n b_q = (b_q * scale).to(b_q.dtype)\n b_do = tl.load(p_do, boundary_check=(0, 1)).to(b_q.dtype)\n b_k = tl.load(p_k, boundary_check=(0, 1))\n b_v = tl.load(p_v, boundary_check=(0, 1))\n\n b_dq += tl.dot(b_do, (b_h_1o).to(b_do.dtype), allow_tf32=False)\n if i_v == 0:\n b_dq += b_dz[:, None] * k_1o\n b_dq_2o = tl.dot(b_do, (b_h_2o).to(b_do.dtype), allow_tf32=False) * 0.5\n if i_v == 0:\n b_dq_2o += (b_dz[:, None] * k_2o) * 0.5\n b_dq_2o = tl.reshape(b_dq_2o, [BT, BK, BK])\n b_dq += tl.sum(b_dq_2o * b_q[:, :, None], axis=1)\n b_dq += tl.sum(b_dq_2o * b_q[:, None, :], axis=2)\n b_dq *= scale\n\n b_ds = tl.dot(b_do, b_v, allow_tf32=False)\n if i_v == 0:\n b_ds += b_dz[:, None]\n b_ds = tl.where(m_s, b_ds, 0) * scale\n b_s = tl.dot(b_q, tl.trans(b_k), allow_tf32=False)\n b_s = tl.where(m_s, b_s, 0)\n b_dq += tl.dot((b_ds * (1 + b_s)).to(b_q.dtype), b_k, allow_tf32=False)\n\n tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))\n\n b_k_2o = b_k[:, :, None] * b_k[:, None, :]\n b_k_2o = tl.reshape(b_k_2o, [BT, BK * BK]).to(b_k.dtype)\n b_h_2o = b_h_2o + tl.dot(b_v, b_k_2o.to(b_v.dtype), allow_tf32=False)\n b_h_1o = b_h_1o + tl.dot(b_v, b_k, allow_tf32=False)\n\n if i_v == 0:\n k_1o += tl.sum(b_k, axis=0)[None, :]\n k_2o += tl.sum(b_k_2o, axis=0)[None, :]\n\n tl.debug_barrier()\n b_h_1o = None\n b_h_2o = None\n\n b_dh_1o = tl.zeros([BK, BV], dtype=tl.float32)\n b_dh_2o = tl.zeros([BK*BK, BV], dtype=tl.float32)\n b_dh_0o = tl.zeros([BV], dtype=tl.float32)\n m_s = tl.arange(0, BT)[:, None] <= tl.arange(0, BT)[None, :]\n\n dq_1o = tl.zeros([1, BK], dtype=tl.float32)\n dq_2o = tl.zeros([BK * BK, 1], dtype=tl.float32)\n\n for i in range(tl.cdiv(T, BT) * BT - BT, -BT, -BT):\n p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (DK, T), (s_qk_d, s_qk_t), (i_k * BK, i), (BK, BT), (0, 1))\n p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (i, i_k * BK), (BT, BK), (1, 0))\n p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (i, i_v * BV), (BT, BV), (1, 0))\n p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (i, i_v * BV), (BT, BV), (1, 0))\n p_dk = tl.make_block_ptr(dk + (i_bh+i_v*B*H) * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (i, i_k*BK), (BT, BK), (1, 0))\n p_dv = tl.make_block_ptr(dv + (i_bh+i_k*B*H) * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (i, i_v*BV), (BT, BV), (1, 0))\n p_dz = dz + (i_bh) * T + tl.arange(0, BT) + i\n\n b_dk = tl.zeros([BT, BK], dtype=tl.float32)\n b_dv = tl.zeros([BT, BV], dtype=tl.float32)\n\n b_dz = tl.load(p_dz, mask=(tl.arange(0, BT)+i) < T)\n b_q = tl.load(p_q, boundary_check=(0, 1))\n b_k = tl.load(p_k, boundary_check=(0, 1))\n b_v = tl.load(p_v, boundary_check=(0, 1))\n b_do = tl.load(p_do, boundary_check=(0, 1)).to(b_q.dtype)\n b_q = (b_q * scale).to(b_k.dtype)\n\n b_ds = tl.dot(b_v, tl.trans(b_do), allow_tf32=False)\n if i_v == 0:\n b_ds += b_dz[None, :]\n b_ds = tl.where(m_s, b_ds, 0)\n b_s = tl.dot(b_k, b_q, allow_tf32=False)\n b_s2 = 1 + b_s + 0.5 * b_s * b_s\n b_s = tl.where(m_s, b_s, 0)\n b_s2 = tl.where(m_s, b_s2, 0)\n b_ds *= (1+b_s)\n\n b_dk += tl.dot(b_ds.to(b_k.dtype), tl.trans(b_q), allow_tf32=False)\n b_dv += tl.dot(b_s2.to(b_do.dtype), b_do, allow_tf32=False)\n\n b_k_2o = b_k[:, :, None] * b_k[:, None, :]\n b_k_2o = tl.reshape(b_k_2o, [BT, BK * BK]).to(b_k.dtype)\n\n b_dv += tl.dot(b_k, b_dh_1o.to(b_k.dtype), allow_tf32=False)\n b_dv += tl.dot(b_k_2o, b_dh_2o.to(b_k.dtype), allow_tf32=False)\n b_dv += b_dh_0o\n\n b_dk += tl.dot(b_v, tl.trans(b_dh_1o).to(b_k.dtype), allow_tf32=False)\n\n if i_v == 0:\n b_dk += dq_1o\n\n b_dk_2o = tl.dot(b_dh_2o.to(b_k.dtype), tl.trans(b_v), allow_tf32=False)\n if i_v == 0:\n b_dk_2o += dq_2o\n b_dk_2o = tl.reshape(b_dk_2o, [BK, BK, BT])\n b_k_fp32 = tl.trans(b_k.to(tl.float32))\n b_dk2 = tl.sum(b_dk_2o * b_k_fp32[:, None, :], axis=0)\n b_dk2 += tl.sum(b_dk_2o * b_k_fp32[None, :, :], axis=1)\n b_dk += tl.trans(b_dk2)\n\n b_dh_0o += tl.sum(b_do, axis=0)\n b_dh_1o = b_dh_1o + tl.dot(b_q, b_do, allow_tf32=False)\n b_q_2o = b_q[None, :, :] * b_q[:, None, :]\n b_q_2o = tl.reshape(b_q_2o, [BK * BK, BT]).to(b_k.dtype)\n b_dh_2o = b_dh_2o + tl.dot(b_q_2o, b_do, allow_tf32=False) * 0.5\n\n if i_v == 0:\n dq_1o += (tl.sum(b_dz[None, :] * b_q, axis=1))[None, :]\n dq_2o += (tl.sum(b_dz[None, :] * b_q_2o, axis=1) * 0.5)[:, None]\n\n tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))\n tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))\n\nclass FusedChunkBasedFunction(torch.autograd.Function):\n\n @staticmethod\n def forward(ctx, q, k, v, scale=1):\n batch_size, n_heads, seq_len, d_head_qk = q.shape\n d_head_v = v.shape[-1]\n scale = scale\n BT = 16\n BK, BV = min(d_head_qk, 16), min(d_head_v, 32)\n BK, BV = max(BK, 16), max(BV, 16)\n NK, NV = triton.cdiv(d_head_qk, BK), triton.cdiv(d_head_v, BV)\n\n num_warps = 4\n\n o = q.new_empty(NK, batch_size, n_heads, seq_len, d_head_v, dtype=torch.float32)\n z = q.new_empty(NK, batch_size, n_heads, seq_len, dtype=torch.float32)\n\n grid = (NV, NK, batch_size * n_heads)\n fused_chunk_based_fwd_kernel[grid](\n q, k, v, o, z,\n q.stride(1), q.stride(2), q.stride(3),\n v.stride(1), v.stride(2), v.stride(3),\n batch_size, n_heads, seq_len, scale,\n BT=BT, DK=d_head_qk, DV=d_head_v, BK=BK, BV=BV,\n num_warps=num_warps,\n )\n o = o.sum(0)\n z = z.sum(0)\n ctx.save_for_backward(q, k, v)\n ctx.scale = scale\n return o.to(q.dtype), z.to(z.dtype)\n\n @staticmethod\n def backward(ctx, do, dz):\n q, k, v = ctx.saved_tensors\n batch_size, n_heads, seq_len, d_head_qk = q.shape\n d_head_v = v.shape[-1]\n scale = ctx.scale\n\n BT = 16\n BK, BV = min(d_head_qk, 16), min(d_head_v, 32)\n BK, BV = max(BK, 16), max(BV, 16)\n NK, NV = triton.cdiv(d_head_qk, BK), triton.cdiv(d_head_v, BV)\n num_stages = 1\n num_warps = 4\n\n dq = q.new_empty(NV, batch_size, n_heads, seq_len, d_head_qk)\n dk = q.new_empty(NV, batch_size, n_heads, seq_len, d_head_qk)\n dv = q.new_empty(NK, batch_size, n_heads, seq_len, d_head_v)\n grid = (NV, NK, batch_size * n_heads)\n\n fused_chunk_based_bwd_kernel[grid](\n q, k, v, do, dz, dq, dk, dv,\n q.stride(1), q.stride(2), q.stride(3),\n v.stride(1), v.stride(2), v.stride(3),\n batch_size, n_heads, seq_len, scale,\n BT=BT, DK=d_head_qk, DV=d_head_v, BK=BK, BV=BV,\n num_warps=num_warps,\n num_stages=num_stages\n )\n dq = dq.sum(0)\n dk = dk.sum(0)\n dv = dv.sum(0)\n return dq.to(q.dtype), dk.to(k.dtype), dv.to(v.dtype), None\n\ntriton_fused_chunk_based = FusedChunkBasedFunction.apply\n\n\ndef fused_chunk_based(q, k, v, use_scale=True, use_normalize=True):\n assert q.shape[-1] <= 16, 'only support feature dimension up to 16.'\n if use_scale:\n scale = q.shape[-1] ** -0.5\n else:\n scale = 1\n o, z = triton_fused_chunk_based(q, k, v, scale)\n if use_normalize:\n o = o / (z[..., None] + 1e-6)\n else:\n o = o\n\n return o.to(q.dtype)\n", - "description_1": "Use triton language to define a fused forward kernel (fused_chunk_based_fwd_kernel) with 20 parameters for computing matrix operations related to query, key, and value tensors (q, k, v) with specific strides and block sizes. Implement a corresponding backward kernel (fused_chunk_based_bwd_kernel) with 24 parameters to compute gradients of these tensors. Create an autograd function FusedChunkBasedFunction with 4 parameters that applies these kernels in its forward and backward methods. Finally, encapsulate the function in a callable triton_fused_chunk_based and use it in fused_chunk_based to execute with scaling and normalization options.", - "description_2": "Use triton language to create a forward kernel and a backward kernel for efficient computation of attention-like operations over given query, key, and value tensors. Implement the kernels in an autograd function to leverage PyTorch's automatic differentiation for training deep learning models.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\nfrom torch.cuda.amp import custom_bwd, custom_fwd\nfrom fla.utils import contiguous\n\n\n@triton.jit\ndef parallel_based_fwd_kernel(\n q, k, v, o, z,\n s_qk_h, s_qk_t, s_qk_d,\n s_vo_h, s_vo_t, s_vo_d,\n B, H, T, scale,\n BTL: tl.constexpr, BTS: tl.constexpr,\n BK: tl.constexpr, BV: tl.constexpr,\n DK: tl.constexpr, DV: tl.constexpr,\n):\n i_kv, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n NV = tl.cdiv(DV, BV)\n i_k = i_kv // (NV)\n i_v = i_kv % (NV)\n\n p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (T, DK),\n (s_qk_t, s_qk_d), (i_c * BTL, i_k * BK), (BTL, BK), (1, 0))\n p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (DK, T),\n (s_qk_d, s_qk_t), (i_k * BK, 0), (BK, BTS), (0, 1))\n p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, DV),\n (s_vo_t, s_vo_d), (0, i_v * BV), (BTS, BV), (1, 0))\n\n b_q = tl.load(p_q, boundary_check=(0, 1))\n b_q = (b_q * scale).to(b_q.dtype)\n b_o = tl.zeros([BTL, BV], dtype=tl.float32)\n b_z = tl.zeros([BTL], dtype=tl.float32)\n\n for _ in range(0, i_c * BTL, BTS):\n b_k = tl.load(p_k, boundary_check=(0, 1))\n b_v = tl.load(p_v, boundary_check=(0, 1))\n b_s = tl.dot(b_q, (b_k), allow_tf32=False)\n b_s = 1 + b_s + 0.5 * b_s * b_s\n b_z += tl.sum(b_s, axis=1)\n b_o = b_o + tl.dot(b_s.to(b_v.dtype), b_v, allow_tf32=False)\n p_k = tl.advance(p_k, (0, BTS))\n p_v = tl.advance(p_v, (BTS, 0))\n\n tl.debug_barrier()\n o_q = tl.arange(0, BTL)\n o_k = tl.arange(0, BTS)\n p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (DK, T),\n (s_qk_d, s_qk_t), (i_k * BK, i_c * BTL), (BK, BTS), (0, 1))\n p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, DV),\n (s_vo_t, s_vo_d), (i_c * BTL, i_v * BV), (BTS, BV), (1, 0))\n\n for _ in range(i_c * BTL, (i_c + 1) * BTL, BTS):\n b_k = tl.load(p_k, boundary_check=(0, 1))\n b_v = tl.load(p_v, boundary_check=(0, 1))\n m_s = o_q[:, None] >= o_k[None, :]\n b_s = tl.dot(b_q, b_k, allow_tf32=False)\n b_s = 1 + b_s + 0.5 * b_s * b_s\n b_s = tl.where(m_s, b_s, 0)\n b_z += tl.sum(b_s, axis=1)\n b_o += tl.dot(b_s.to(b_q.dtype), b_v, allow_tf32=False)\n\n p_k = tl.advance(p_k, (0, BTS))\n p_v = tl.advance(p_v, (BTS, 0))\n o_k += BTS\n\n p_o = tl.make_block_ptr(o + (i_bh + B * H * i_k) * s_vo_h, (T, DV),\n (s_vo_t, s_vo_d), (i_c*BTL, i_v*BV), (BTL, BV), (1, 0))\n p_z = z + (i_bh + B * H * i_k) * T + i_c * BTL + tl.arange(0, BTL)\n tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))\n tl.store(p_z, b_z.to(p_z.dtype.element_ty),\n mask=((i_c * BTL + tl.arange(0, BTL)) < T))\n\n\n@triton.jit\ndef parallel_based_bwd_kernel(\n q, k, v, do, dz, dq, dk, dv, s_qk_h, s_qk_t, s_qk_d, s_vo_h,\n s_vo_t, s_vo_d, B, H, T, scale,\n BTL: tl.constexpr, BTS: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr,\n DK: tl.constexpr, DV: tl.constexpr,\n):\n i_kv, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n NV = tl.cdiv(DV, BV)\n i_k = i_kv // (NV)\n i_v = i_kv % (NV)\n i_h = i_bh % H\n\n _parallel_based_bwd_dq(\n i_bh, i_c, i_k, i_v, i_h,\n q, k, v, do, dz, dq, s_qk_h, s_qk_t, s_qk_d, s_vo_h,\n s_vo_t, s_vo_d, B, H, T, scale, BTL=BTL, BTS=BTS, BK=BK, BV=BV, DK=DK, DV=DV\n )\n tl.debug_barrier()\n _parallel_based_bwd_dkv(\n i_bh, i_c, i_k, i_v, i_h,\n q, k, v, do, dz, dk, dv, s_qk_h, s_qk_t, s_qk_d, s_vo_h,\n s_vo_t, s_vo_d, B, H, T, scale, BTL, BTS, BK, BV, DK, DV\n )\n\n\nclass ParallelBasedFunction(torch.autograd.Function):\n @staticmethod\n @contiguous\n @custom_fwd\n def forward(ctx, q, k, v, scale):\n BTL, BTS = 128, 32\n assert BTL % BTS == 0\n BK = min(128, triton.next_power_of_2(k.shape[-1]))\n BV = min(128, triton.next_power_of_2(v.shape[-1]))\n BK, BV = max(BK, 16), max(BV, 16)\n batch_size, n_heads, seq_len, d_head_qk = q.shape\n d_head_v = v.shape[-1]\n num_stages = 2\n num_warps = 4\n NK = triton.cdiv(d_head_qk, BK)\n NV = triton.cdiv(d_head_v, BV)\n grid = (NK * NV, triton.cdiv(seq_len, BTL), batch_size * n_heads)\n\n assert NK == 1, \"will encounter some synchronization issue if not.\"\n\n o = torch.empty(NK, batch_size, n_heads, seq_len,\n d_head_v, device=q.device)\n z = torch.empty(NK, batch_size, n_heads, seq_len,\n device=q.device)\n parallel_based_fwd_kernel[grid](\n q, k, v, o, z,\n q.stride(1), q.stride(2), q.stride(3),\n v.stride(1), v.stride(2), v.stride(3),\n batch_size, n_heads, seq_len, scale,\n BTL=BTL, BTS=BTS, BK=BK, BV=BV, DK=d_head_qk, DV=d_head_v,\n num_warps=num_warps,\n num_stages=num_stages\n )\n ctx.save_for_backward(q, k, v)\n ctx.scale = scale\n return o.sum(0).to(q.dtype), z.sum(0).to(q.dtype)\n\n @staticmethod\n @custom_bwd\n @contiguous\n def backward(ctx, do, dz):\n q, k, v = ctx.saved_tensors\n scale = ctx.scale\n BTL, BTS = 64, 32\n assert BTL % BTS == 0\n BK = min(128, triton.next_power_of_2(k.shape[-1]))\n BV = min(128, triton.next_power_of_2(v.shape[-1]))\n BK, BV = max(BK, 16), max(BV, 16)\n batch_size, n_heads, seq_len, d_head_qk = q.shape\n d_head_v = v.shape[-1]\n num_stages = 2\n num_warps = 4\n NK = triton.cdiv(d_head_qk, BK)\n NV = triton.cdiv(d_head_v, BV)\n grid = (NK * NV, triton.cdiv(seq_len, BTL), batch_size * n_heads)\n\n assert NK == 1, \"will encounter some synchronization issue if not\"\n\n dq = torch.empty(NV, batch_size, n_heads, seq_len,\n d_head_qk, dtype=q.dtype, device=q.device)\n dk = torch.empty(NV, batch_size, n_heads, seq_len,\n d_head_qk, dtype=q.dtype, device=q.device)\n dv = torch.empty(NK, batch_size, n_heads, seq_len,\n d_head_v, dtype=q.dtype, device=q.device)\n\n parallel_based_bwd_kernel[grid](\n q, k, v, do, dz, dq, dk, dv,\n q.stride(1), q.stride(2), q.stride(3),\n v.stride(1), v.stride(2), v.stride(3),\n batch_size, n_heads, seq_len, scale,\n BTL=BTL, BTS=BTS, BK=BK, BV=BV, DK=d_head_qk, DV=d_head_v,\n num_warps=num_warps,\n num_stages=num_stages\n )\n\n return dq.sum(0).to(q.dtype), dk.sum(0).to(k.dtype), dv.sum(0).to(v.dtype), None\n\n\ntriton_parallel_based = ParallelBasedFunction.apply\n\n\ndef parallel_based(q, k, v, use_scale=True, use_normalize=True, return_both=False):\n assert q.shape[-1] <= 128, \"only support feature dim up to 128\"\n if use_scale:\n scale = q.shape[-1] ** -0.5\n else:\n scale = 1\n o, z = triton_parallel_based(q, k, v, scale)\n if return_both:\n return o, z\n if use_normalize:\n o = o / (z[..., None] + 1e-6)\n else:\n o = o\n return o.to(q.dtype)\n", - "description_1": "Use triton language to implement a parallel-based forward and backward kernel for a sequence mixer. The forward kernel takes in 18 parameters: q, k, v, o, z, s_qk_h, s_qk_t, s_qk_d, s_vo_h, s_vo_t, s_vo_d, B, H, T, scale, BTL, BTS, BK, BV, DK, and DV. It computes output tensors 'o' and 'z'. The backward kernel is structured similarly but computes gradients for q, k, and v, using additional internal subfunctions.", - "description_2": "Use triton language to create a sequence mixer with forward and backward operations that handle tensor 'q', 'k', and 'v' with specific striding and scale configurations. Forward computes results stored in 'o' and 'z', while backward computes gradients for optimization.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.autotune(\n configs=[\n triton.Config({}, num_warps=1),\n triton.Config({}, num_warps=2),\n triton.Config({}, num_warps=4),\n triton.Config({}, num_warps=8),\n triton.Config({}, num_warps=16),\n triton.Config({}, num_warps=32),\n ],\n key=[\"BT\", \"BK\", \"BV\"], \n)\n@triton.jit\ndef fwd_prepare_dv_kernel(\n q, # query tensor\n k, # key tensor\n do, # delta output tensor\n dv, # delta value tensor\n s_qk_h, # stride for qk height\n s_qk_t, # stride for qk time\n s_qk_d, # stride for qk depth\n s_vo_h, # stride for value output height\n s_vo_t, # stride for value output time\n s_vo_d, # stride for value output depth\n T, # total time\n K, # total key\n V, # total value\n scale, # scaling factor\n BT: tl.constexpr, # block time size\n BK: tl.constexpr, # block key size\n BV: tl.constexpr # block value size\n):\n i_t, i_bh = tl.program_id(0), tl.program_id(1)\n \n b_A = tl.zeros([BT, BT], dtype=tl.float32)\n\n for i_k in range(tl.cdiv(K, BK)):\n p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1))\n p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))\n b_k = tl.load(p_k, boundary_check=(0, 1))\n b_q = tl.load(p_q, boundary_check=(0, 1)) \n b_q = (b_q * scale).to(b_k.dtype)\n b_A += tl.dot(b_k, b_q, allow_tf32=False)\n\n b_A = tl.where(tl.arange(0, BT)[:, None] <= tl.arange(0, BT)[None, :], b_A , 0).to(do.dtype.element_ty)\n\n for i_v in range(tl.cdiv(V, BV)):\n p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))\n b_do = tl.load(p_do, boundary_check=(0, 1))\n p_dv = tl.make_block_ptr(dv + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))\n b_dv = tl.dot(b_A, b_do, allow_tf32=False)\n tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))\n\n\ndef fwd_prepare_dv(q, k, do, BT):\n dv = torch.empty_like(do)\n B, H, T, K, V = *k.shape, do.shape[-1]\n NT = triton.cdiv(T, BT)\n BK = min(triton.next_power_of_2(K), 64)\n BV = min(triton.next_power_of_2(V), 64)\n fwd_prepare_dv_kernel[(NT, B*H)](\n q, k, do, dv,\n k.stride(1), k.stride(2), k.stride(3), \n do.stride(1), do.stride(2), do.stride(3),\n T, K, V, K**-0.5, BT, BK, BV\n )\n return dv\n\n", - "description_1": "Use triton language to implement a forward kernel 'fwd_prepare_dv_kernel' that calculates the delta value tensor 'dv' from input tensors 'q' (query), 'k' (key), and 'do' (delta output). It uses a block matrix multiplication approach. The function 'fwd_prepare_dv' prepares and calls this kernel with block size parameters BT, BK, and BV.", - "description_2": "Use triton language to implement a forward kernel that computes delta values from query, key, and delta output using block matrix multiplications.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.autotune(\n configs=[\n triton.Config({}, num_warps=1),\n triton.Config({}, num_warps=2),\n triton.Config({}, num_warps=4),\n triton.Config({}, num_warps=8)\n ],\n key=[\"BT\", \"BK\"],\n)\n@triton.jit\ndef fused_chunk_delta_rule_fwd_kernel(\n q, # query [B, H, L, D_head_K]\n k, # key [B, H, L, D_head_K]\n v, # value [B, H, L, D_head_V]\n v_new,\n d, # decay [B, H, L, D_head_K]\n o, # output [B, H, L, D_head_V]\n initial_state, # initial state of the chunk [B, H, D_head_K, D_head_V]\n final_state, # final state of the chunk [B, H, D_head_K, D_head_V]\n s_qk_h, # stride size: L * D_head_K\n s_qk_t, # stride size: D_head_K\n s_qk_d, # stride size: 1\n s_vo_h, # stride size: L * D_head_V\n s_vo_t, # stride size: D_head_V\n s_vo_d, # stride size: 1\n B, # batch size\n H, # n_heads\n T, # seq_len\n scale, # D_head_K ** -0.5\n BT: tl.constexpr, # BLOCK SIZE along the sequence dimension, a.k.a. chunk size\n BK: tl.constexpr, # BLOCK SIZE along the K dimension\n BV: tl.constexpr, # BLOCK SIZE along the V dimension\n DK: tl.constexpr, # D_head_K\n DV: tl.constexpr, # D_head_V\n USE_INITIAL_STATE: tl.constexpr,\n STORE_FINAL_STATE: tl.constexpr,\n CHECK: tl.constexpr\n):\n i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n\n o_i = tl.arange(0, BT)\n\n # [BT, BT]\n m_s = o_i[:, None] >= o_i[None, :]\n # [BK, BV]\n b_h = tl.zeros([BK, BV], dtype=tl.float32)\n\n p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (0, i_k * BK), (BT, BK), (1, 0))\n p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (DK, T), (s_qk_d, s_qk_t), (i_k * BK, 0), (BK, BT), (0, 1))\n p_d = tl.make_block_ptr(d + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (0, i_k * BK), (BT, BK), (1, 0))\n p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (0, i_v * BV), (BT, BV), (1, 0))\n p_o = tl.make_block_ptr(o + (i_bh+i_k*B*H) * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (0, i_v * BV), (BT, BV), (1, 0))\n p_v_new = tl.make_block_ptr(v_new + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (0, i_v * BV), (BT, BV), (1, 0))\n\n if USE_INITIAL_STATE:\n p_h = tl.make_block_ptr(initial_state + i_bh * DK * DV, (DK, DV), (DV, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))\n b_h = tl.load(p_h, boundary_check=(0, 1)).to(tl.float32)\n\n for i in range(0, tl.cdiv(T, BT)):\n b_k = tl.load(p_k, boundary_check=(0, 1))\n b_v = tl.load(p_v, boundary_check=(0, 1))\n b_q = tl.load(p_q, boundary_check=(0, 1))\n b_d = tl.load(p_d, boundary_check=(0, 1))\n b_q = (b_q * scale).to(b_k.dtype)\n\n b_s = tl.dot(b_q, b_k, allow_tf32=False)\n b_s = tl.where(m_s, b_s, 0)\n b_v_prime = tl.dot(b_d, b_h.to(b_q.dtype), allow_tf32=False)\n b_v = b_v - b_v_prime\n tl.store(p_v_new, b_v.to(p_v.dtype.element_ty), boundary_check=(0, 1))\n\n b_o = tl.dot(b_s.to(b_q.dtype), b_v.to(b_q.dtype), allow_tf32=False)\n if CHECK and i == 0:\n b_o += tl.dot(b_q, b_h.to(b_q.dtype), allow_tf32=False)\n b_h = b_h + tl.dot(b_k, b_v.to(b_k.dtype), allow_tf32=False)\n else:\n b_o += tl.dot(b_q, b_h.to(b_q.dtype), allow_tf32=False)\n b_h = b_h + tl.dot(b_k, b_v.to(b_k.dtype), allow_tf32=False)\n tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))\n p_q = tl.advance(p_q, (BT, 0))\n p_k = tl.advance(p_k, (0, BT))\n p_v = tl.advance(p_v, (BT, 0))\n p_v_new = tl.advance(p_v_new, (BT, 0))\n p_o = tl.advance(p_o, (BT, 0))\n p_d = tl.advance(p_d, (BT, 0))\n\n if STORE_FINAL_STATE:\n p_final = tl.make_block_ptr(final_state + i_bh * DK * DV, (DK, DV), (DV, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))\n tl.store(p_final, b_h.to(p_final.dtype.element_ty), boundary_check=(0, 1))\n\n\n@triton.autotune(\n configs=[\n triton.Config({}, num_warps=1),\n triton.Config({}, num_warps=2),\n triton.Config({}, num_warps=4),\n triton.Config({}, num_warps=8),\n triton.Config({}, num_warps=16),\n triton.Config({}, num_warps=32),\n ],\n key=[\"BT\", \"BK\", \"BV\"],\n)\n@triton.jit\ndef fused_chunk_delta_rule_bwd_kernel(\n q, # query [B, H, L, D_head_K]\n k, # key [B, H, L, D_head_V]\n v, # value [B, H, L, D_head_V]\n d, # decay [B, H, L, D_head_K]\n do, # gradient of output [B, H, L, D_head_V]\n dq, # gradient of query [NV, B, H, L, D_head_K]\n dk, # gradient of key [NV, B, H, L, D_head_K]\n dv, # gradient of value [NK, B, H, L, D_head_V]\n dd, # gradient of decay [NV, B, H, L, D_head_K]\n initial_state, # initial state of the chunk [B, H, D_head_K, D_head_V]\n s_qk_h, # stride size: L * D_head_K\n s_qk_t, # stride size: D_head_K\n s_qk_d, # stride size: 1\n s_vo_h, # stride size: L * D_head_V\n s_vo_t, # stride size: D_head_V\n s_vo_d, # stride size: 1\n B, # batch_size\n H, # n_heads\n T, # seq_len\n scale, # D_head_K ** -0.5\n BT: tl.constexpr, # BLOCK SIZE along the sequence dimension, a.k.a. chunk size\n BK: tl.constexpr, # BLOCK SIZE along the K dimension\n BV: tl.constexpr, # BLOCK SIZE along the V dimension\n DK: tl.constexpr, # D_head_K\n DV: tl.constexpr, # D_head_V\n USE_INITIAL_STATE: tl.constexpr,\n CHECK: tl.constexpr\n):\n i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n o_i = tl.arange(0, BT)\n\n # first reverse\n b_dh = tl.zeros([BK, BV], dtype=tl.float32)\n m_s = o_i[:, None] <= o_i[None, :]\n for i in range(1, tl.cdiv(T, BT) + 1):\n p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (DK, T), (s_qk_d, s_qk_t), (i_k * BK, T - i * BT), (BK, BT), (0, 1))\n p_d = tl.make_block_ptr(d + i_bh * s_qk_h, (DK, T), (s_qk_d, s_qk_t), (i_k * BK, T - i * BT), (BK, BT), (0, 1))\n p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (T - i * BT, i_k * BK), (BT, BK), (1, 0))\n\n p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (T - i * BT, i_v * BV), (BT, BV), (1, 0))\n p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (T - i * BT, i_v * BV), (BT, BV), (1, 0))\n p_dk = tl.make_block_ptr(dk + (i_bh+i_v*B*H) * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (T - i*BT, i_k*BK), (BT, BK), (1, 0))\n p_dv = tl.make_block_ptr(dv + (i_bh+i_k*B*H) * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (T - i*BT, i_v*BV), (BT, BV), (1, 0))\n\n b_q = tl.load(p_q, boundary_check=(0, 1))\n b_q = (b_q * scale).to(b_q.dtype)\n b_k = tl.load(p_k, boundary_check=(0, 1))\n b_v = tl.load(p_v, boundary_check=(0, 1))\n b_do = tl.load(p_do, boundary_check=(0, 1))\n\n b_ds = tl.dot(b_v, tl.trans(b_do), allow_tf32=False)\n b_ds = tl.where(m_s, b_ds, 0).to(b_q.dtype)\n b_s = tl.dot(b_k, b_q, allow_tf32=False)\n b_s = tl.where(m_s, b_s, 0).to(b_q.dtype)\n b_dk = tl.dot(b_ds, tl.trans(b_q), allow_tf32=False)\n b_dv = tl.dot(b_s, b_do, allow_tf32=False)\n b_d = tl.load(p_d, boundary_check=(0, 1))\n if CHECK and i == 1:\n b_dk += tl.dot(b_v, tl.trans(b_dh).to(b_v.dtype), allow_tf32=False)\n b_dv += tl.dot(b_k, b_dh.to(b_k.dtype), allow_tf32=False)\n b_dh += tl.dot(b_q, b_do, allow_tf32=False)\n b_dh -= tl.dot(b_d, b_dv.to(b_d.dtype), allow_tf32=False)\n else:\n b_dk += tl.dot(b_v, tl.trans(b_dh).to(b_v.dtype), allow_tf32=False)\n b_dv += tl.dot(b_k, b_dh.to(b_k.dtype), allow_tf32=False)\n b_dh += tl.dot(b_q, b_do, allow_tf32=False)\n b_dh -= tl.dot(b_d, b_dv.to(b_d.dtype), allow_tf32=False)\n\n tl.store(p_dk, (b_dk).to(p_dk.dtype.element_ty), boundary_check=(0, 1))\n tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))\n\n b_h = None\n tl.debug_barrier()\n m_s = o_i[:, None] >= o_i[None, :]\n b_h = tl.zeros([BV, BK], dtype=tl.float32)\n if USE_INITIAL_STATE:\n p_h = tl.make_block_ptr(initial_state + i_bh * DK * DV, (DV, DK), (1, DV), (i_v * BV, i_k * BK), (BV, BK), (0, 1))\n b_h = tl.load(p_h, boundary_check=(0, 1)).to(tl.float32)\n NT = tl.cdiv(T, BT)\n for i in range(0, NT):\n p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (i * BT, i_k * BK), (BT, BK), (1, 0))\n p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (DV, T), (s_vo_d, s_vo_t), (i_v * BV, i * BT), (BV, BT), (0, 1))\n p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (i * BT, i_v * BV), (BT, BV), (1, 0))\n p_dq = tl.make_block_ptr(dq + (i_bh + i_v*B*H) * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (i*BT, i_k*BK), (BT, BK), (1, 0))\n\n b_k = tl.load(p_k, boundary_check=(0, 1))\n b_v = tl.load(p_v, boundary_check=(0, 1))\n b_do = tl.load(p_do, boundary_check=(0, 1))\n\n b_ds = tl.dot(b_do, b_v, allow_tf32=False)\n b_ds = tl.where(m_s, b_ds, 0)\n b_dq = tl.dot(b_ds.to(b_k.dtype), b_k, allow_tf32=False)\n if CHECK and i == 0:\n b_dq += tl.dot(b_do, b_h.to(b_do.dtype), allow_tf32=False)\n b_h = b_h + tl.dot(b_v, b_k, allow_tf32=False)\n else:\n b_dq += tl.dot(b_do, b_h.to(b_do.dtype), allow_tf32=False)\n b_h = b_h + tl.dot(b_v, b_k, allow_tf32=False)\n b_dq *= scale\n tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))\n\n if i < (NT - 1):\n p_dv = tl.make_block_ptr(dv + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), ((i + 1) * BT, i_v * BV), (BT, BV), (1, 0))\n b_dv = tl.load(p_dv, boundary_check=(0, 1))\n b_dd = tl.dot(b_dv.to(b_k.dtype), b_h.to(b_k.dtype), allow_tf32=False)\n p_dd = tl.make_block_ptr(dd + (i_bh + i_v*B*H) * s_qk_h, (T, DK), (s_qk_t, s_qk_d),\n ((i+1) * BT, i_k * BK), (BT, BK), (1, 0))\n tl.store(p_dd, -b_dd.to(p_dd.dtype.element_ty), boundary_check=(0, 1))\n\n\ndef fused_chunk_delta_rule_fwd(q, k, v, d, BT, initial_state, output_final_state):\n batch_size, n_heads, seq_len, d_head_qk = q.shape\n d_head_v = v.shape[-1]\n scale = d_head_qk ** -0.5\n BT = BT\n BK, BV = triton.next_power_of_2(d_head_qk), min(triton.next_power_of_2(d_head_v), 32)\n NK, NV = triton.cdiv(d_head_qk, BK), triton.cdiv(d_head_v, BV)\n assert NK == 1, 'NK should be 1'\n o = q.new_empty(batch_size, n_heads, seq_len, d_head_v)\n if output_final_state:\n final_state = q.new_empty(batch_size, n_heads, d_head_qk, d_head_v, dtype=torch.float32, requires_grad=False)\n else:\n final_state = None\n CHECK = True\n grid = (NV, NK, batch_size * n_heads)\n v_new = torch.empty_like(v)\n fused_chunk_delta_rule_fwd_kernel[grid](\n q, k, v, v_new, d, o, initial_state, final_state,\n q.stride(1), q.stride(2), q.stride(3),\n v.stride(1), v.stride(2), v.stride(3),\n batch_size, n_heads, seq_len, scale,\n BT=BT, DK=d_head_qk, DV=d_head_v, BK=BK, BV=BV,\n USE_INITIAL_STATE=initial_state is not None,\n STORE_FINAL_STATE=output_final_state,\n CHECK=CHECK,\n )\n return o, v_new, CHECK, final_state\n\n\ndef fused_chunk_delta_rule_bwd(q, k, v, d, do, BT, CHECK, initial_state):\n batch_size, n_heads, seq_len, d_head_qk = q.shape\n d_head_v = v.shape[-1]\n scale = d_head_qk ** -0.5\n BK, BV = triton.next_power_of_2(d_head_qk), min(triton.next_power_of_2(d_head_v), 32)\n NK, NV = triton.cdiv(d_head_qk, BK), triton.cdiv(d_head_v, BV)\n assert NK == 1\n dq = q.new_empty(NV, batch_size, n_heads, seq_len, d_head_qk)\n dk = q.new_empty(NV, batch_size, n_heads, seq_len, d_head_qk)\n dd = q.new_empty(NV, batch_size, n_heads, seq_len, d_head_qk)\n dv = q.new_empty(NK, batch_size, n_heads, seq_len, d_head_v)\n grid = (NV, NK, batch_size * n_heads)\n fused_chunk_delta_rule_bwd_kernel[grid](\n q, k, v, d, do, dq, dk, dv, dd, initial_state,\n q.stride(1), q.stride(2), q.stride(3),\n v.stride(1), v.stride(2), v.stride(3),\n batch_size, n_heads, seq_len, scale,\n BT=BT, DK=d_head_qk, DV=d_head_v, BK=BK, BV=BV,\n USE_INITIAL_STATE=initial_state is not None,\n CHECK=CHECK,\n )\n dq = dq.sum(0)\n dk = dk.sum(0)\n dv = dv.sum(0)\n dd = dd.sum(0)\n dd[:, :, 0:BT] = 0\n return dq, dk, dv, dd\n", - "description_1": "Use triton language to implement the 'fused_chunk_delta_rule_fwd_kernel' for forward pass and 'fused_chunk_delta_rule_bwd_kernel' for backward pass of a fused delta rule operation. The forward kernel takes 22 parameters including query, key, value tensors with respective strides and dimensions, and returns the final states and value updates. The backward kernel uses 24 parameters including gradients and initial states to compute the backpropagation of deltas.", - "description_2": "Use triton language to implement fused delta rule operations with forward and backward kernels handling query, key, value tensor manipulations and state updates.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef fused_recurrent_fwd_kernel(\n q, # query [B, H, L, D_head_K]\n k, # key [B, H, L, D_head_V]\n v, # value [B, H, L, D_head_V].\n beta, # beta [B, H, L]\n o, # output [B, H, L, D_head_V]\n initial_state,\n final_state, # final hidden state [B, H, D_head_K, D_head_V]\n s_qk_h, # stride size: L * D_head_K\n s_qk_t, # stride size: D_head_K\n s_qk_d, # stride size: 1\n s_vo_h, # stride size: L * D_head_V\n s_vo_t, # stride size: D_head_V\n s_vo_d, # stride size: 1\n B, # batch size\n H, # n_heads\n T, # seq_len\n scale, # D_head_K ** -0.5\n BK: tl.constexpr, # BLOCK SIZE along the K dimension\n BV: tl.constexpr, # BLOCK SIZE along the V dimension\n DK: tl.constexpr, # D_head_K\n DV: tl.constexpr, # D_head_V\n USE_INITIAL_STATE: tl.constexpr, # whether to use initial state\n STORE_FINAL_STATE: tl.constexpr, # whether to store final state\n):\n i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n p_q = q + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK)\n p_k = k + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK)\n p_v = v + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV)\n p_beta = beta + i_bh * T\n p_o = o + (i_bh + i_k * B * H) * s_vo_h + i_v * BV + tl.arange(0, BV)\n mask_bk = (i_k * BK + tl.arange(0, BK)) < DK\n mask_bv = (i_v * BV + tl.arange(0, BV)) < DV\n mask_kv = mask_bk[None, :] & mask_bv[:, None]\n h = tl.zeros([BV, BK], dtype=tl.float32)\n if USE_INITIAL_STATE:\n p_init_s = initial_state + i_bh * DK * DV + (i_k * BK + tl.arange(0, BK)[None, :]) * DV + (i_v * BV + tl.arange(0, BV)[:, None])\n h += tl.load(p_init_s, mask=mask_kv, other=0).to(tl.float32)\n for _ in range(0, T):\n _k = tl.load(p_k, mask=mask_bk, other=0).to(tl.float32)\n _v = tl.load(p_v, mask=mask_bv, other=0).to(tl.float32)\n _q = tl.load(p_q, mask=mask_bk, other=0).to(tl.float32) * scale\n _v_minus = tl.sum(h * _k[None, :], axis=1)\n _v -= _v_minus\n _beta = tl.load(p_beta).to(tl.float32)\n tl.store(p_v, _v.to(p_v.dtype.element_ty), mask=mask_bv)\n _v *= _beta\n h += _k[None, :] * _v[:, None]\n _o = h * _q[None, :]\n _o = tl.sum(_o, axis=1)\n tl.store(p_o, _o.to(p_o.dtype.element_ty), mask=mask_bv)\n p_q += DK\n p_k += DK\n p_o += DV\n p_v += DV\n p_beta += 1\n if STORE_FINAL_STATE:\n p_final_s = final_state + i_bh * DK * DV + (i_k * BK + tl.arange(0, BK)[None, :]) * DV + (i_v * BV + tl.arange(0, BV)[:, None])\n tl.store(p_final_s, h.to(p_final_s.dtype.element_ty), mask=mask_kv)\n\n@triton.jit\ndef fused_recurrent_bwd_kernel(\n q, # query [B, H, L, D_head_K]\n k, # key [B, H, L, D_head_V]\n v, # value [B, H, L, D_head_V]\n beta, # beta [B, H, L]\n do, # gradient of output [B, H, L, D_head_V]\n dq, # gradient of query [NV, B, H, L, D_head_K]\n dk, # gradient of key [NV, B, H, L, D_head_K]\n dv, # gradient of value [NK, B, H, L, D_head_V]\n dbeta, # gradient of beta [B, H, L]\n initial_state, # initial hidden state initialization [B, H, D_head_K, D_head_V]\n s_qk_h, # stride size: L * D_head_K\n s_qk_t, # stride size: D_head_K\n s_qk_d, # stride size: 1\n s_vo_h, # stride size: L * D_head_V\n s_vo_t, # stride size: D_head_V\n s_vo_d, # stride size: 1\n B, # batch_size\n H, # n_heads\n T, # seq_len\n scale, # D_head_K ** -0.5\n BK: tl.constexpr, # BLOCK SIZE along the K dimension\n BV: tl.constexpr, # BLOCK SIZE along the V dimension\n DK: tl.constexpr, # D_head_K\n DV: tl.constexpr, # D_head_V\n USE_INITIAL_STATE: tl.constexpr, # whether to use initial state\n):\n i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n mask_bk = i_k * BK + tl.arange(0, BK) < DK\n mask_bv = i_v * BV + tl.arange(0, BV) < DV\n p_q = q + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (T - 1) * DK\n p_k = k + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (T - 1) * DK\n p_do = do + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV) + (T - 1) * DV\n p_v = v + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV) + (T - 1) * DV\n p_beta = beta + i_bh * T + T - 1\n p_dbeta = dbeta + (i_bh + i_v * B * H) * T + T - 1\n p_dk = dk + (i_bh + i_v * B * H) * s_qk_h + i_k * BK + tl.arange(0, BK) + (T - 1) * DK\n p_dv = dv + (i_bh + i_k * B * H) * s_vo_h + i_v * BV + tl.arange(0, BV) + (T - 1) * DV\n d_h = tl.zeros([BK, BV], dtype=tl.float32)\n for _ in range(T):\n _do = tl.load(p_do, mask=mask_bv, other=0).to(tl.float32)\n _q = tl.load(p_q, mask=mask_bk, other=0).to(tl.float32) * scale\n _k = tl.load(p_k, mask=mask_bk, other=0).to(tl.float32)\n _v = tl.load(p_v, mask=mask_bv, other=0).to(tl.float32)\n _beta = tl.load(p_beta).to(tl.float32)\n d_h += _q[:, None] * _do[None, :]\n d_k = tl.sum(d_h * _v[None, :] * _beta, axis=1)\n d_v = tl.sum(d_h * _k[:, None], axis=0)\n d_beta = tl.sum(d_v * _v)\n d_v = d_v * _beta\n tl.store(p_dk, d_k.to(p_dk.dtype.element_ty), mask=mask_bk)\n tl.store(p_dv, d_v.to(p_dv.dtype.element_ty), mask=mask_bv)\n tl.store(p_dbeta, d_beta.to(p_dbeta.dtype.element_ty))\n d_h -= _k[:, None] * d_v[None, :]\n p_do -= DV\n p_q -= DK\n p_k -= DK\n p_v -= DV\n p_dk -= DK\n p_dv -= DV\n p_dbeta -= 1\n p_beta -= 1\n tl.debug_barrier()\n h = tl.zeros([BK, BV], dtype=tl.float32)\n p_q = q + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK)\n p_k = k + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK)\n p_v = v + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV)\n p_beta = beta + i_bh * T\n p_do = do + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV)\n p_dq = dq + (i_bh + i_v * B * H) * s_qk_h + i_k * BK + tl.arange(0, BK)\n p_dv = dv + (i_bh + i_k * B * H) * s_vo_h + i_v * BV + tl.arange(0, BV) + DV\n p_dk = dk + (i_bh + i_v * B * H) * s_qk_h + i_k * BK + tl.arange(0, BK) + DK\n if USE_INITIAL_STATE:\n mask_kv = mask_bk[:, None] & mask_bv[None, :]\n p_init_s = initial_state + i_bh * DK * DV + (i_k * BK + tl.arange(0, BK)[:, None]) * DV + (i_v * BV + tl.arange(0, BV)[None, :])\n h += tl.load(p_init_s, mask=mask_kv, other=0).to(tl.float32)\n for i in range(0, T):\n _k = tl.load(p_k, mask=mask_bk, other=0).to(tl.float32)\n _v = tl.load(p_v, mask=mask_bv, other=0).to(tl.float32)\n _do = tl.load(p_do, mask=mask_bv, other=0).to(tl.float32)\n _beta = tl.load(p_beta).to(tl.float32)\n _v *= _beta\n h += _k[:, None] * _v[None, :]\n _d_q = h * _do[None, :]\n d_q = tl.sum(_d_q, axis=1) * scale\n tl.store(p_dq, d_q.to(p_dq.dtype.element_ty), mask=mask_bk)\n if i < T - 1:\n d_k = tl.load(p_dk, mask=mask_bk, other=0).to(tl.float32)\n d_v = tl.load(p_dv, mask=mask_bv, other=0).to(tl.float32)\n d_k -= tl.sum(d_v[None, :] * h, axis=1)\n tl.store(p_dk, d_k.to(p_dk.dtype.element_ty), mask=mask_bk)\n p_k += DK\n p_do += DV\n p_v += DV\n p_dk += DK\n p_dv += DV\n p_dq += DK\n p_beta += 1\n\nclass FusedRecurrentFunction(torch.autograd.Function):\n\n @staticmethod\n def forward(ctx, q, k, v, beta, initial_state=None, output_final_state=False):\n batch_size, n_heads, seq_len, d_head_qk = q.shape\n d_head_v = v.shape[-1]\n scale = d_head_qk ** -0.5\n BK, BV = triton.next_power_of_2(d_head_qk), min(triton.next_power_of_2(d_head_v), 8)\n NK, NV = triton.cdiv(d_head_qk, BK), triton.cdiv(d_head_v, BV)\n num_stages = 1\n num_warps = 1\n assert NK == 1, \"NK > 1 is not supported yet\"\n o = q.new_empty(NK, batch_size, n_heads, seq_len, d_head_v)\n if output_final_state:\n final_state = q.new_empty(batch_size, n_heads, d_head_qk, d_head_v)\n else:\n final_state = None\n grid = (NV, NK, batch_size * n_heads)\n fused_recurrent_fwd_kernel[grid](\n q, k, v, beta, o, initial_state, final_state,\n q.stride(1), q.stride(2), q.stride(3),\n v.stride(1), v.stride(2), v.stride(3),\n batch_size, n_heads, seq_len, scale,\n DK=d_head_qk, DV=d_head_v, BK=BK, BV=BV,\n num_warps=num_warps,\n num_stages=num_stages,\n USE_INITIAL_STATE=initial_state is not None,\n STORE_FINAL_STATE=final_state is not None\n )\n o = o.sum(0)\n ctx.save_for_backward(q, k, v, beta, initial_state)\n return o, final_state\n\n @staticmethod\n def backward(ctx, do, d_final_state=None):\n q, k, v, beta, initial_state = ctx.saved_tensors\n batch_size, n_heads, seq_len, d_head_qk = q.shape\n d_head_v = v.shape[-1]\n scale = d_head_qk ** -0.5\n BK, BV = triton.next_power_of_2(d_head_qk), min(triton.next_power_of_2(d_head_v), 32)\n NK, NV = triton.cdiv(d_head_qk, BK), triton.cdiv(d_head_v, BV)\n assert NK == 1, \"NK > 1 is not supported yet\"\n num_stages = 1\n num_warps = 2\n dq = q.new_empty(NV, batch_size, n_heads, seq_len, d_head_qk)\n dk = q.new_empty(NV, batch_size, n_heads, seq_len, d_head_qk)\n dv = q.new_empty(NK, batch_size, n_heads, seq_len, d_head_v)\n grid = (NV, NK, batch_size * n_heads)\n dbeta = q.new_empty(NV, batch_size, n_heads, seq_len)\n fused_recurrent_bwd_kernel[grid](\n q, k, v, beta, do, dq, dk, dv, dbeta, initial_state,\n q.stride(1), q.stride(2), q.stride(3),\n v.stride(1), v.stride(2), v.stride(3),\n batch_size, n_heads, seq_len, scale,\n DK=d_head_qk, DV=d_head_v, BK=BK, BV=BV,\n num_warps=num_warps,\n num_stages=num_stages,\n USE_INITIAL_STATE=initial_state is not None\n )\n dq = dq.sum(0)\n dk = dk.sum(0)\n dv = dv.sum(0)\n dbeta = dbeta.sum(0)\n return dq.to(q), dk.to(k), dv.to(v), dbeta.to(beta), None, None\n\ndef fused_recurrent_linear_attn_delta_rule(\n q: torch.Tensor,\n k: torch.Tensor,\n v: torch.Tensor,\n beta: torch.Tensor = None,\n initial_state: torch.Tensor = None,\n output_final_state: bool = False,\n normalize: bool = False\n) -> Tuple[torch.Tensor, torch.Tensor]:\n if initial_state is not None:\n initial_state = initial_state.detach()\n if beta is None:\n beta = torch.ones_like(q[..., 0])\n o, final_state = FusedRecurrentFunction.apply(q, k, v, beta, initial_state, output_final_state)\n return o, final_state\n", - "description_1": "Use triton language to define and execute forward and backward kernels for a fused recurrent neural network operation that handles attention-like computations in a sequence. It involves inputs for queries, keys, values, beta, with options for initial and final states, while scaling and managing data dimensions and strides for optimal GPU execution. The kernels handle sequence length, block dimensions, and enable gradient computations for backpropagation.", - "description_2": "Use triton language to implement fused recurrent attention operations for neural networks, managing sequence and gradient calculations efficiently on the GPU.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.autotune(\n configs=[\n triton.Config({}, num_warps=1),\n triton.Config({}, num_warps=2),\n triton.Config({}, num_warps=4),\n triton.Config({}, num_warps=8),\n triton.Config({}, num_warps=16),\n triton.Config({}, num_warps=32),\n ],\n key=[\"BT\", \"BK\", \"BV\"],\n)\n@triton.jit\ndef fwd_prepare_wy_repr_kernel(\n k,\n v,\n beta,\n o,\n o2,\n T,\n K,\n V,\n BT: tl.constexpr,\n BK: tl.constexpr,\n BV: tl.constexpr\n):\n i_t, i_bh = tl.program_id(0), tl.program_id(1)\n\n p_k = k + i_bh * T * K + (i_t * BT + tl.arange(0, BT)[:, None]) * K + tl.arange(0, BK)[None, :]\n p_v = v + i_bh * T * V + (i_t * BT + tl.arange(0, BT)[:, None]) * V + tl.arange(0, BV)[None, :]\n p_beta = beta + i_bh * T + i_t * BT + tl.arange(0, BT)\n mask_bt = (tl.arange(0, BT) + i_t * BT) < T\n mask_bk = tl.arange(0, BK) < K\n mask_bv = tl.arange(0, BV) < V\n mask_bk = mask_bk[None, :] & mask_bt[:, None]\n mask_bv = mask_bv[None, :] & mask_bt[:, None]\n # [BT, BK]\n b_k = tl.load(p_k, mask=mask_bk, other=0)\n # [BT,]\n b_beta = tl.load(p_beta, mask=mask_bt, other=0).to(tl.float32)\n # [BT, BV]\n b_v = tl.load(p_v, mask=mask_bv, other=0)\n b_v = (b_v * b_beta[:, None]).to(b_v.dtype)\n # [BT, BK]\n b_kb = (b_k * b_beta[:, None]).to(b_k.dtype)\n # [BT, BT]\n b_A = tl.dot(b_kb, tl.trans(b_k), allow_tf32=False)\n b_A = -tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], b_A, 0)\n\n for i in range(BT):\n mask = tl.arange(0, BT) == i\n b_a = tl.sum(tl.where(mask[:, None], b_A, 0), 0)\n b_a = b_a + tl.sum(b_a[:, None] * b_A, 0) * (tl.arange(0, BT) < i)\n b_A = tl.where(mask[:, None], b_a, b_A)\n b_A += tl.arange(0, BT)[:, None] == tl.arange(0, BT)[None, :]\n b_A = b_A.to(b_k.dtype)\n b_w = tl.dot(b_A, b_kb, allow_tf32=False)\n b_u = tl.dot(b_A, b_v, allow_tf32=False)\n\n p_o = o + i_bh * T * K + (i_t * BT + tl.arange(0, BT)[:, None]) * K + tl.arange(0, BK)[None, :]\n tl.store(p_o, b_w.to(p_o.dtype.element_ty), mask=mask_bk)\n p_o2 = o2 + i_bh * T * V + (i_t * BT + tl.arange(0, BT)[:, None]) * V + tl.arange(0, BV)[None, :]\n tl.store(p_o2, b_u.to(p_o2.dtype.element_ty), mask=mask_bv)\n\n\n@triton.autotune(\n configs=[\n triton.Config({}, num_warps=1),\n triton.Config({}, num_warps=2),\n triton.Config({}, num_warps=4),\n triton.Config({}, num_warps=8),\n triton.Config({}, num_warps=16),\n triton.Config({}, num_warps=32),\n ],\n key=[\"BT\", \"BK\", \"BV\"],\n)\n@triton.jit\ndef bwd_prepare_wy_repr_kernel(\n k, v, beta,\n o, o2, do, do2,\n dk, dv, dbeta,\n NT, K, V, T,\n BT: tl.constexpr,\n BK: tl.constexpr,\n BV: tl.constexpr,\n):\n i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n p_k = k + i_bh * T * K + (i_t * BT + tl.arange(0, BT)[:, None]) * K + tl.arange(0, BK)[None, :]\n p_do = do + i_bh * T * K + (i_t * BT + tl.arange(0, BT)[:, None]) * K + tl.arange(0, BK)[None, :]\n p_do2 = do2 + i_bh * T * V + (i_t * BT + tl.arange(0, BT)[:, None]) * V + tl.arange(0, BV)[None, :]\n\n p_beta = beta + i_bh * T + i_t * BT + tl.arange(0, BT)\n mask_bt = (tl.arange(0, BT) + i_t * BT) < T\n mask_bk = (tl.arange(0, BK) < K)[None, :] & mask_bt[:, None]\n mask_bv = (tl.arange(0, BV) < V)[None, :] & mask_bt[:, None]\n b_k, b_beta = tl.load(p_k, mask=mask_bk), tl.load(p_beta, mask=mask_bt)\n\n b_beta = b_beta.to(tl.float32)\n A = tl.dot(b_k, tl.trans(b_k), allow_tf32=False) * b_beta[:, None]\n A = tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], A, 0)\n b_do = tl.load(p_do, mask=mask_bk).to(tl.float32)\n b_dv = tl.load(p_do2, mask=mask_bv).to(tl.float32)\n dA = tl.zeros([BT, BT], dtype=tl.float32)\n b_dk = tl.zeros([BT, BK], dtype=tl.float32)\n for i in range(BT-1, -1, -1):\n mask = tl.arange(0, BT) == i\n attn = tl.sum(tl.where(mask[:, None], A, 0), axis=0)\n do_ = tl.sum(tl.where(mask[:, None], b_do, 0), axis=0)\n dv_ = tl.sum(tl.where(mask[:, None], b_dv, 0), axis=0)\n b_do = b_do - attn[:, None] * do_[None, :]\n b_dv = b_dv - attn[:, None] * dv_[None, :]\n tl.debug_barrier()\n p_v = v + i_bh * T * V + (i_t * BT + tl.arange(0, BT)[:, None]) * V + tl.arange(0, BV)[None, :]\n b_v = tl.load(p_v, mask=mask_bv)\n b_dk += b_do * b_beta[:, None]\n b_dbeta = tl.sum(b_do * b_k, axis=1)\n b_dbeta += tl.sum(b_dv * b_v, axis=1)\n b_v = None\n\n p_o = o + i_bh * T * K + (i_t * BT + tl.arange(0, BT)[:, None]) * K + tl.arange(0, BK)[None, :]\n p_o2 = o2 + i_bh * T * V + (i_t * BT + tl.arange(0, BT)[:, None]) * V + tl.arange(0, BV)[None, :]\n b_o = tl.load(p_o, mask=mask_bk)\n b_o2 = tl.load(p_o2, mask=mask_bv)\n\n dA = -tl.dot(b_do.to(b_o.dtype), tl.trans(b_o), allow_tf32=False)\n dA -= tl.dot(b_dv.to(b_o2.dtype), tl.trans(b_o2).to(b_o.dtype),\n allow_tf32=False)\n dA = tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], dA, 0)\n b_dv *= b_beta[:, None]\n p_dv = dv + i_bh * T * V + (i_t * BT + tl.arange(0, BT)[:, None]) * V + tl.arange(0, BV)[None, :]\n tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), mask=mask_bv)\n\n b_dbeta += tl.sum(dA * tl.dot(b_k, tl.trans(b_k), allow_tf32=False), axis=1)\n dA = dA * b_beta[:, None]\n b_dk += tl.dot(tl.trans(dA.to(b_k.dtype)), b_k, allow_tf32=False)\n b_dk += tl.dot(dA.to(b_k.dtype), b_k, allow_tf32=False)\n p_dk = dk + i_bh * T * K + (i_t * BT + tl.arange(0, BT)[:, None]) * K + tl.arange(0, BK)[None, :]\n tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), mask=mask_bk)\n p_dbeta = dbeta + i_bh * T + i_t * BT + tl.arange(0, BT)\n tl.store(p_dbeta, b_dbeta.to(p_dbeta.dtype.element_ty), mask=mask_bt)\n\n\ndef fwd_prepare_wy_repr(k, v, beta, chunk_size):\n B, H, T, K, V = *k.shape, v.shape[-1]\n v_new = torch.empty_like(v)\n o_cumdecay = torch.empty_like(k)\n BT = chunk_size\n NT = triton.cdiv(T, BT)\n BK = triton.next_power_of_2(K)\n BV = triton.next_power_of_2(V)\n fwd_prepare_wy_repr_kernel[(NT, B*H)](\n k, v, beta, o_cumdecay, v_new,\n T, K, V, BT, BK, BV\n )\n return o_cumdecay, v_new\n\n\ndef bwd_prepare_wy_repr(k, v, beta, o_cumdecay, v_new, do, do2, chunk_size):\n b, h, l, d_k = do.shape\n d_v = v.shape[-1]\n BK = triton.next_power_of_2(d_k)\n BV = triton.next_power_of_2(d_v)\n c = chunk_size\n BK = d_k\n NT = triton.cdiv(l, c)\n dk = torch.empty_like(k)\n dv = torch.empty_like(v)\n dbeta = torch.zeros_like(beta)\n bwd_prepare_wy_repr_kernel[(NT, b*h)](\n k, v, beta,\n o_cumdecay, v_new, do, do2,\n dk, dv, dbeta,\n NT, d_k, d_v, l, chunk_size, BK, BV\n )\n return dk, dv, dbeta\n", - "description_1": "Use triton language to implement forward and backward kernels for preparing WY representation. The forward kernel takes 10 parameters: k, v, beta, o, o2, T, K, V, BT, BK, BV. It computes the WY representation using input matrices k and v, scaling by beta, and stores results in o and o2. The backward kernel takes 16 parameters: k, v, beta, o, o2, do, do2, dk, dv, dbeta, NT, K, V, T, BT, BK, BV. It computes gradients for k, v, and beta based on the forward pass outputs and their gradients.", - "description_2": "Use triton language to create kernels for forward and backward passes of WY representation preparation, handling input matrices and their gradients.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n\n@triton.autotune(\n configs=[\n triton.Config({}, num_warps=1),\n triton.Config({}, num_warps=2),\n triton.Config({}, num_warps=4),\n triton.Config({}, num_warps=8),\n triton.Config({}, num_warps=16),\n triton.Config({}, num_warps=32),\n ],\n key=[\"BT\", \"BK\", \"BV\"], \n)\n@triton.jit\ndef fwd_prepare_wy_repr_kernel(\n k,\n v,\n beta,\n w, \n u,\n A, \n s_qk_h,\n s_qk_t,\n s_qk_d,\n s_vo_h,\n s_vo_t,\n s_vo_d,\n T,\n K,\n V,\n BT: tl.constexpr,\n BK: tl.constexpr,\n BV: tl.constexpr\n):\n i_t, i_bh = tl.program_id(0), tl.program_id(1)\n \n b_A = tl.zeros([BT, BT], dtype=tl.float32)\n p_beta = tl.make_block_ptr(beta + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,))\n b_beta = tl.load(p_beta, boundary_check=(0,))\n\n for i_k in range(tl.cdiv(K, BK)):\n p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))\n b_k = tl.load(p_k, boundary_check=(0, 1))\n b_kb = (b_k * b_beta[:, None]).to(b_k.dtype)\n b_A += tl.dot(b_kb, tl.trans(b_k), allow_tf32=False)\n\n b_A = -tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], b_A, 0)\n\n for i in range(1, BT):\n mask = tl.arange(0, BT) == i\n b_a = tl.sum(tl.where(mask[:, None], b_A, 0), 0)\n b_a = b_a + tl.sum(b_a[:, None] * b_A, 0) * (tl.arange(0, BT) < i)\n b_A = tl.where(mask[:, None], b_a, b_A)\n\n b_A += tl.arange(0, BT)[:, None] == tl.arange(0, BT)[None, :]\n\n p_A = tl.make_block_ptr(A + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))\n tl.store(p_A, (b_A).to(p_A.dtype.element_ty), boundary_check=(0, 1))\n b_A = b_A.to(k.dtype.element_ty)\n\n for i_v in range(tl.cdiv(V, BV)):\n p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))\n b_v = tl.load(p_v, boundary_check=(0, 1))\n b_vb = (b_v * b_beta[:, None]).to(b_v.dtype)\n b_u = tl.dot(b_A, b_vb, allow_tf32=False)\n p_u = tl.make_block_ptr(u + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))\n tl.store(p_u, (b_u).to(p_u.dtype.element_ty), boundary_check=(0, 1))\n\n for i_k in range(tl.cdiv(K, BK)):\n p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))\n b_k = tl.load(p_k, boundary_check=(0, 1))\n b_kb = (b_k * b_beta[:, None]).to(b_k.dtype)\n b_w = tl.dot(b_A, b_kb, allow_tf32=False)\n p_w = tl.make_block_ptr(w + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))\n tl.store(p_w, b_w.to(p_w.dtype.element_ty), boundary_check=(0, 1))\n\n\ndef fwd_prepare_wy_repr(k, v, beta, BT):\n B, H, T, K, V = *k.shape, v.shape[-1]\n u = torch.empty_like(v)\n w = torch.empty_like(k)\n NT = triton.cdiv(T, BT)\n BK = min(triton.next_power_of_2(K), 64)\n BV = min(triton.next_power_of_2(V), 64)\n A = torch.empty(B, H, T, BT, device=k.device, dtype=k.dtype)\n fwd_prepare_wy_repr_kernel[(NT, B*H)](\n k, v, beta, w, u, A,\n k.stride(1), k.stride(2), k.stride(3), \n v.stride(1), v.stride(2), v.stride(3),\n T, K, V, BT, BK, BV\n )\n return w, u, A\n\n\n@triton.autotune(\n configs=[\n triton.Config({}, num_warps=1),\n triton.Config({}, num_warps=2),\n triton.Config({}, num_warps=4),\n triton.Config({}, num_warps=8),\n triton.Config({}, num_warps=16),\n triton.Config({}, num_warps=32),\n ],\n key=[\"BT\", \"BK\", \"BV\"], \n)\n@triton.jit\ndef fwd_recompute_w_u_kernel(\n k,\n v,\n beta,\n w, \n u,\n A, \n s_qk_h,\n s_qk_t,\n s_qk_d,\n s_vo_h,\n s_vo_t,\n s_vo_d,\n T,\n K,\n V,\n BT: tl.constexpr,\n BK: tl.constexpr,\n BV: tl.constexpr\n):\n i_t, i_bh = tl.program_id(0), tl.program_id(1)\n \n p_beta = tl.make_block_ptr(beta + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,))\n b_beta = tl.load(p_beta, boundary_check=(0,))\n\n p_A = tl.make_block_ptr(A + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))\n b_A = tl.load(p_A, boundary_check=(0, 1)).to(k.dtype.element_ty)\n\n for i_v in range(tl.cdiv(V, BV)):\n p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))\n b_v = tl.load(p_v, boundary_check=(0, 1))\n b_vb = (b_v * b_beta[:, None]).to(b_v.dtype)\n b_u = tl.dot(b_A, b_vb, allow_tf32=False)\n p_u = tl.make_block_ptr(u + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))\n tl.store(p_u, (b_u).to(p_u.dtype.element_ty), boundary_check=(0, 1))\n\n for i_k in range(tl.cdiv(K, BK)):\n p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))\n b_k = tl.load(p_k, boundary_check=(0, 1))\n b_kb = (b_k * b_beta[:, None]).to(b_k.dtype)\n b_w = tl.dot(b_A, b_kb, allow_tf32=False)\n p_w = tl.make_block_ptr(w + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))\n tl.store(p_w, b_w.to(p_w.dtype.element_ty), boundary_check=(0, 1))\n\n\ndef fwd_recompute_w_u(k, v, beta, A, BT):\n B, H, T, K, V = *k.shape, v.shape[-1]\n u = torch.empty_like(v)\n w = torch.empty_like(k)\n NT = triton.cdiv(T, BT)\n BK = min(triton.next_power_of_2(K), 64)\n BV = min(triton.next_power_of_2(V), 64)\n fwd_recompute_w_u_kernel[(NT, B*H)](\n k, v, beta, w, u, A,\n k.stride(1), k.stride(2), k.stride(3), \n v.stride(1), v.stride(2), v.stride(3),\n T, K, V, BT, BK, BV\n )\n return w, u\n\n\n@triton.autotune(\n configs=[\n triton.Config({}, num_warps=1),\n triton.Config({}, num_warps=2),\n triton.Config({}, num_warps=4),\n triton.Config({}, num_warps=8),\n triton.Config({}, num_warps=16),\n triton.Config({}, num_warps=32),\n ],\n key=[\"BT\", \"BK\", \"BV\"],\n)\n@triton.jit\ndef bwd_prepare_wy_repr_kernel(\n k, v, beta, A, \n dw, du,\n dk, dv, dbeta,\n s_qk_h,\n s_qk_t,\n s_qk_d,\n s_vo_h,\n s_vo_t,\n s_vo_d,\n T,\n K,\n V,\n BT: tl.constexpr,\n BK: tl.constexpr,\n BV: tl.constexpr\n):\n i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n p_A = tl.make_block_ptr(A + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))\n b_A = tl.load(p_A, boundary_check=(0, 1)).to(k.dtype.element_ty)\n\n b_dbeta = tl.zeros([BT], dtype=tl.float32)\n b_dA = tl.zeros([BT, BT], dtype=tl.float32)\n p_beta = tl.make_block_ptr(beta + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,))\n b_beta = tl.load(p_beta, boundary_check=(0,))\n\n for i_v in range(tl.cdiv(V, BV)):\n p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))\n p_du = tl.make_block_ptr(du + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))\n b_v = tl.load(p_v, boundary_check=(0, 1))\n b_v_beta = (b_v * b_beta[:, None]).to(b_v.dtype)\n b_du = tl.load(p_du, boundary_check=(0, 1))\n b_dA += tl.dot(b_du, tl.trans(b_v_beta), allow_tf32=False)\n b_dv_beta = tl.dot(tl.trans(b_A), b_du, allow_tf32=False)\n b_dv = b_dv_beta * b_beta[:, None]\n b_dbeta += tl.sum(b_dv_beta * b_v, 1)\n # store\n p_dv = tl.make_block_ptr(dv + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))\n tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))\n\n tl.debug_barrier() \n b_A2 = tl.zeros([BT, BT], dtype=tl.float32)\n for i_k in range(tl.cdiv(K, BK)):\n p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))\n p_dw = tl.make_block_ptr(dw + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))\n b_k = tl.load(p_k, boundary_check=(0, 1)) \n b_k_beta = (b_k * b_beta[:, None]).to(b_k.dtype)\n b_dw = tl.load(p_dw, boundary_check=(0, 1))\n b_dA += tl.dot(b_dw, tl.trans(b_k_beta), allow_tf32=False) \n b_A2 += tl.dot(b_k_beta, tl.trans(b_k), allow_tf32=False)\n b_dk_beta = tl.dot(tl.trans(b_A), b_dw, allow_tf32=False)\n b_dk = b_dk_beta * b_beta[:, None]\n b_dbeta += tl.sum(b_dk_beta * b_k, 1)\n # store \n p_dk = tl.make_block_ptr(dk + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))\n tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))\n\n b_A -= (tl.arange(0, BT)[:, None] == tl.arange(0, BT)[None, :])\n b_A2 = tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], -b_A2, 0)\n b_dA = tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], b_dA, 0)\n tl.debug_barrier()\n\n for i in range(BT-1, 0, -1):\n mask = tl.arange(0, BT) == i\n b_da = tl.sum(tl.where(mask[:, None], b_dA, 0), 0) \n b_a = tl.sum(tl.where(mask[:, None], b_A2, 0), 0) \n b_da2 = b_da + tl.sum(b_da[None, :] * b_A, 1) \n b_dA = tl.where(mask[:, None], b_da2, b_dA)\n b_dA += b_da[None, :] * b_a[:, None]\n\n b_dA = tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], -b_dA, 0).to(k.dtype.element_ty)\n tl.debug_barrier()\n\n for i_k in range(tl.cdiv(K, BK)):\n p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))\n p_dk = tl.make_block_ptr(dk + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))\n b_k = tl.load(p_k, boundary_check=(0, 1)) \n b_dk = tl.load(p_dk, boundary_check=(0, 1))\n b_k_beta = (b_k * b_beta[:, None]).to(b_k.dtype)\n\n b_dk_beta = tl.dot(b_dA, b_k, allow_tf32=False)\n b_dbeta += tl.sum(b_dk_beta * b_k, 1)\n b_dk += tl.dot(tl.trans(b_dA), b_k_beta, allow_tf32=False) \n b_dk += b_dk_beta * b_beta[:, None] \n tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))\n \n p_dbeta = tl.make_block_ptr(dbeta + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,))\n tl.store(p_dbeta, b_dbeta.to(p_dbeta.dtype.element_ty),boundary_check=(0, 1))\n\n\ndef bwd_prepare_wy_repr(k, v, beta, A, dw, du, BT):\n B, H, T, K, V = *k.shape, v.shape[-1]\n\n NT = triton.cdiv(T, BT)\n BK = min(triton.next_power_of_2(K), 64)\n BV = min(triton.next_power_of_2(V), 64)\n NT = triton.cdiv(T, BT)\n dk = torch.empty_like(k)\n dv = torch.empty_like(v).contiguous()\n dbeta = torch.zeros_like(beta)\n\n bwd_prepare_wy_repr_kernel[(NT, B*H)](\n k, v, beta, A,\n dw, du, \n dk, dv, dbeta,\n k.stride(1), k.stride(2), k.stride(3), \n v.stride(1), v.stride(2), v.stride(3),\n T, K, V, BT, BK, BV\n )\n return dk, dv, dbeta\n\n", - "description_1": "Use triton language to implement three kernels: fwd_prepare_wy_repr_kernel, fwd_recompute_w_u_kernel, and bwd_prepare_wy_repr_kernel. Each kernel is responsible for different stages of forward and backward computations involving block-wise matrix operations, dot products, and custom transformations to efficiently handle large matrices and tensors with variable dimensions. These kernels take a varying number of parameters including input matrices (k, v), beta, and various strides and size parameters to compute the required operations efficiently across multiple blocks.", - "description_2": "Use triton language to implement kernels for matrix transformations and dot products. These kernels are used in forward and backward pass calculations for efficient tensor operations.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\nfrom einops import rearrange\nfrom packaging import version\n\ninv_ln2 = 1.44269504\n\n@triton.jit\ndef fused_chunk_gla_fwd_kernel(\n q, k, v, g, o, initial_state, final_state,\n s_qk_h, s_qk_t, s_qk_d, s_vo_h, s_vo_t, s_vo_d,\n B, H, T, scale, BT: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr,\n DK: tl.constexpr, DV: tl.constexpr, USE_INITIAL_STATE: tl.constexpr,\n STORE_FINAL_STATE: tl.constexpr, CHECK: tl.constexpr\n):\n # Triton kernel for forward pass of the fused chunk GLA operation\n i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n\n b_h = tl.zeros([BK, BV], dtype=tl.float32)\n\n p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (0, i_k * BK), (BT, BK), (1, 0))\n p_db = g + i_bh * s_qk_h + (BT - 1) * s_qk_t + i_k * BK + tl.arange(0, BK)\n p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (DK, T), (s_qk_d, s_qk_t), (i_k * BK, 0), (BK, BT), (0, 1))\n p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (0, i_v * BV), (BT, BV), (1, 0))\n p_o = tl.make_block_ptr(o + (i_bh + i_k * B * H) * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (0, i_v * BV), (BT, BV), (1, 0))\n\n if USE_INITIAL_STATE:\n p_h = tl.make_block_ptr(initial_state + i_bh * DK * DV, (DK, DV), (DV, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))\n b_h += tl.load(p_h, boundary_check=(0, 1)).to(tl.float32)\n \n mask = (i_k * BK + tl.arange(0, BK)) < DK\n\n for i in range(0, tl.cdiv(T, BT)):\n b_k = tl.load(p_k, boundary_check=(0, 1))\n b_o = tl.zeros([BT, BV], dtype=tl.float32)\n b_v = tl.load(p_v, boundary_check=(0, 1))\n b_q = tl.load(p_q, boundary_check=(0, 1))\n d_b = tl.load(p_db, mask=mask, other=0).to(tl.float32)\n if CHECK and i == 0:\n b_o = tl.dot(b_q.to(b_v.dtype), b_h.to(b_v.dtype), allow_tf32=False)\n b_h = b_h * tl.math.exp2(d_b)[:, None] + tl.dot(b_k.to(b_v.dtype), b_v, allow_tf32=False)\n else:\n b_o = tl.dot(b_q.to(b_v.dtype), b_h.to(b_v.dtype), allow_tf32=False)\n b_h = b_h * tl.math.exp2(d_b)[:, None] + tl.dot(b_k.to(b_v.dtype), b_v, allow_tf32=False)\n\n tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))\n p_q = tl.advance(p_q, (BT, 0))\n p_k = tl.advance(p_k, (0, BT))\n p_v = tl.advance(p_v, (BT, 0))\n p_o = tl.advance(p_o, (BT, 0))\n p_db += BT * DK\n\n if STORE_FINAL_STATE:\n p_final = tl.make_block_ptr(final_state + i_bh * DK * DV, (DK, DV), (DV, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))\n tl.store(p_final, b_h.to(p_final.dtype.element_ty), boundary_check=(0, 1))\n\n@triton.jit\ndef fused_chunk_gla_bwd_kernel(\n q, k, v, g, do, dq, dk, dv, initial_state,\n s_qk_h, s_qk_t, s_qk_d, s_vo_h, s_vo_t, s_vo_d,\n B, H, T, scale, BT: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr,\n DK: tl.constexpr, DV: tl.constexpr, USE_INITIAL_STATE: tl.constexpr,\n CHECK: tl.constexpr\n):\n # Triton kernel for backward pass of the fused chunk GLA operation\n i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n b_h = tl.zeros([BV, BK], dtype=tl.float32)\n\n if USE_INITIAL_STATE:\n p_h = tl.make_block_ptr(initial_state + i_bh * DK * DV, (DV, DK), (1, DV), (i_v * BV, i_k * BK), (BV, BK), (0, 1))\n b_h += tl.load(p_h, boundary_check=(0, 1)).to(tl.float32)\n \n mask = (i_k * BK + tl.arange(0, BK)) < DK \n for i in range(0, tl.cdiv(T, BT)):\n p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (i * BT, i_k * BK), (BT, BK), (1, 0))\n p_db = g + i_bh * s_qk_h + ((i+1) * BT - 1) * s_qk_t + i_k * BK + tl.arange(0, BK)\n p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (DV, T), (s_vo_d, s_vo_t), (i_v * BV, i * BT), (BV, BT), (0, 1))\n p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (i * BT, i_v * BV), (BT, BV), (1, 0))\n p_dq = tl.make_block_ptr(dq + (i_bh+i_v*B*H)*s_qk_h, (T, DK), (s_qk_t, s_qk_d), (i * BT, i_k * BK), (BT, BK), (1, 0))\n b_dq = tl.zeros([BT, BK], dtype=tl.float32)\n b_k = tl.load(p_k, boundary_check=(0, 1))\n d_b = tl.load(p_db, mask=mask, other=0).to(tl.float32)\n b_v = tl.load(p_v, boundary_check=(0, 1))\n b_do = tl.load(p_do, boundary_check=(0, 1))\n\n if CHECK and i == 0:\n b_dq += tl.dot(b_do, b_h.to(b_do.dtype), allow_tf32=False)\n b_h = b_h * tl.math.exp2(d_b)[None, :] + tl.dot(b_v, b_k.to(b_v.dtype), allow_tf32=False)\n else:\n b_dq += tl.dot(b_do, b_h.to(b_do.dtype), allow_tf32=False)\n b_h = b_h * tl.math.exp2(d_b)[None, :] + tl.dot(b_v, b_k.to(b_v.dtype), allow_tf32=False)\n b_dq *= scale\n tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))\n\n b_h = None\n tl.debug_barrier()\n b_dh = tl.zeros([BK, BV], dtype=tl.float32)\n\n for i in range(1, tl.cdiv(T, BT) + 1):\n p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (DK, T), (s_qk_d, s_qk_t), (i_k * BK, T - i * BT), (BK, BT), (0, 1))\n p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (T - i * BT, i_k * BK), (BT, BK), (1, 0))\n p_db = g + i_bh * s_qk_h + (T - (i-1) * BT - 1) * s_qk_t + i_k * BK + tl.arange(0, BK)\n p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (T - i * BT, i_v * BV), (BT, BV), (1, 0))\n p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (T - i * BT, i_v * BV), (BT, BV), (1, 0))\n p_dk = tl.make_block_ptr(dk + (i_bh + i_v * B * H) * s_qk_h, (T, DK),\n (s_qk_t, s_qk_d), (T - i * BT, i_k * BK), (BT, BK), (1, 0))\n p_dv = tl.make_block_ptr(dv + (i_bh + i_k * B * H) * s_vo_h, (T, DV),\n (s_vo_t, s_vo_d), (T - i * BT, i_v * BV), (BT, BV), (1, 0))\n b_q = tl.load(p_q, boundary_check=(0, 1))\n b_k = tl.load(p_k, boundary_check=(0, 1))\n b_v = tl.load(p_v, boundary_check=(0, 1))\n b_do = tl.load(p_do, boundary_check=(0, 1))\n b_db = tl.load(p_db, mask=mask, other=0).to(tl.float32)\n\n if CHECK and i == 1:\n b_dk = tl.trans(tl.dot(b_dh.to(b_v.dtype), tl.trans(b_v), allow_tf32=False))\n b_dv = tl.dot((b_k).to(b_v.dtype), b_dh.to(b_v.dtype), allow_tf32=False)\n b_dh = b_dh * tl.math.exp2(b_db)[:, None] + tl.dot(b_q.to(b_do.dtype), b_do, allow_tf32=False)\n else:\n b_dk = tl.trans(tl.dot(b_dh.to(b_v.dtype), tl.trans(b_v), allow_tf32=False))\n b_dv = tl.dot((b_k).to(b_v.dtype), b_dh.to(b_v.dtype), allow_tf32=False)\n b_dh = b_dh * tl.math.exp2(b_db)[:, None] + tl.dot(b_q.to(b_do.dtype), b_do, allow_tf32=False)\n\n tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))\n tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))\n\ndef fused_chunk_gla(\n q: torch.Tensor,\n k: torch.Tensor,\n v: torch.Tensor,\n g: torch.Tensor,\n scale: int = -1,\n initial_state: torch.Tensor = None,\n output_final_state: bool = False\n) -> Tuple[torch.Tensor, torch.Tensor]:\n if scale == -1:\n scale = q.shape[-1] ** -0.5\n if initial_state is not None:\n initial_state = initial_state.detach()\n seq_len = q.shape[-2]\n q, k, v, g = map(lambda x: pad(x), [q, k, v, g])\n o, final_state = FusedChunkGLAFunction.apply(\n q, k, v, g, scale, initial_state, output_final_state)\n o = o[..., :seq_len, :]\n return o, final_state\n", - "description_1": "Use triton language to implement and perform fused_chunk_gla operation, which includes two kernels for forward and backward pass. The forward kernel processes query, key, and value tensors to produce an output tensor, optionally utilizing initial state and producing a final state. The backward kernel calculates gradients for the query, key, and value tensors, and optionally uses an initial state. The implementation also contains auxiliary functions and management of tensor strides for efficient computation.", - "description_2": "Use triton language to implement and execute fused chunk GLA operation, ensuring efficient tensor operations for forward and backward passes with optional state management.", - "difficulty": 4 - }, - { - "code": "import triton\nimport triton.language as tl\n\ninv_ln2 = 1.44269504\n\n@triton.jit\ndef fwd_decay_cumsum(\n g,\n g_o, \n s_qk_h,\n s_qk_t,\n s_qk_d,\n B,\n H,\n T,\n scale,\n BT: tl.constexpr,\n BK: tl.constexpr,\n DK: tl.constexpr\n):\n i_k, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n p_g = g + i_bh * s_qk_h + i_c * BT * DK + i_k * BK + tl.arange(0, BK)\n p_go = g_o + i_bh * s_qk_h + i_c * BT * DK + i_k * BK + tl.arange(0, BK)\n cum_decay = tl.zeros([BK], dtype=tl.float32)\n mask = (i_k * BK + tl.arange(0, BK)) < DK\n\n for i in range(BT):\n _g = tl.load(p_g, mask=mask, other=0).to(tl.float32)\n cum_decay += _g * inv_ln2\n tl.store(p_go, cum_decay.to(p_go.dtype.element_ty), mask=mask)\n p_g += DK\n p_go += DK\n\n@triton.jit\ndef prepare_qg_kg(\n q,\n k,\n g,\n qg,\n kg,\n s_qk_h,\n s_qk_t,\n s_qk_d,\n B,\n H,\n T,\n scale,\n BT: tl.constexpr,\n BK: tl.constexpr,\n DK: tl.constexpr\n):\n\n i_k, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n p_q = q + i_bh * s_qk_h + i_c * BT * DK + i_k * BK + tl.arange(0, BK)\n p_g = g + i_bh * s_qk_h + i_c * BT * DK + i_k * BK + tl.arange(0, BK)\n p_k = k + i_bh * s_qk_h + i_c * BT * DK + i_k * BK + tl.arange(0, BK)\n p_qg = qg + i_bh * s_qk_h + i_c * BT * DK + i_k * BK + tl.arange(0, BK)\n p_kg = kg + i_bh * s_qk_h + i_c * BT * DK + i_k * BK + tl.arange(0, BK)\n \n mask = (i_k * BK + tl.arange(0, BK)) < DK\n\n last_decay = tl.load(g + i_bh * s_qk_h + (i_c * BT + BT - 1) * DK + i_k * BK + tl.arange(0, BK))\n\n for i in range(BT):\n _q = tl.load(p_q, mask=mask, other=0)\n _k = tl.load(p_k, mask=mask, other=0)\n _g = tl.load(p_g, mask=mask, other=0).to(tl.float32)\n _q *= tl.math.exp2(_g) * scale\n _k *= tl.math.exp2(last_decay - _g)\n tl.store(p_kg, _k.to(p_kg.dtype.element_ty), mask=mask)\n tl.store(p_qg, _q.to(p_qg.dtype.element_ty), mask=mask)\n p_q += DK\n p_g += DK\n p_k += DK\n p_kg += DK\n p_qg += DK\n\n\n@triton.jit\ndef bwd_decay_global_cumsum(\n dq_inner,\n dq_inter,\n dk_inner,\n dk_inter,\n q, k, g, dg,\n s_qk_h,\n s_qk_t,\n s_qk_d,\n B,\n H,\n T,\n scale,\n BT: tl.constexpr,\n BK: tl.constexpr,\n DK: tl.constexpr\n):\n i_k, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n p_q = q + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (i_c * BT + BT - 1) * DK\n p_k = k + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (i_c * BT + BT - 1) * DK\n p_g = g + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (i_c * BT + BT - 1) * DK\n p_dg = dg + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (i_c * BT + BT - 1) * DK\n p_dq_inner = dq_inner + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (i_c * BT + BT - 1) * DK\n p_dk_inner = dk_inner + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (i_c * BT + BT - 1) * DK\n p_dq_inter = dq_inter + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (i_c * BT + BT - 1) * DK\n p_dk_inter = dk_inter + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (i_c * BT + BT - 1) * DK\n cum_grad_dg = tl.zeros([BK], dtype=tl.float32)\n mask = (i_k * BK + tl.arange(0, BK)) < DK\n last_g = tl.zeros([BK], dtype=tl.float32)\n for j in range(BT-1, -1, -1):\n _g = tl.load(p_g, mask=mask, other=0).to(tl.float32)\n if j == (BT-1):\n last_g = _g\n _dq1 = tl.load(p_dq_inner, mask=mask, other=0)\n _dq2 = tl.load(p_dq_inter, mask=mask, other=0)\n _dq2 *= tl.math.exp2(_g)\n _dq = _dq1 + _dq2\n tl.store(p_dq_inter, _dq, mask=mask)\n _dk1 = tl.load(p_dk_inner, mask=mask, other=0)\n _dk2 = tl.load(p_dk_inter, mask=mask, other=0)\n _dk2 *= tl.math.exp2(last_g - _g)\n _dk = _dk1 + _dk2\n tl.store(p_dk_inter, _dk, mask=mask)\n _q = tl.load(p_q, mask=mask, other=0)\n _k = tl.load(p_k, mask=mask, other=0)\n _dg = _dq * _q - _dk * _k\n cum_grad_dg += _dg\n tl.store(p_dg, cum_grad_dg.to(p_dg.dtype.element_ty), mask=mask)\n p_g -= DK\n p_k -= DK\n p_q -= DK\n p_dq_inner -= DK\n p_dk_inner -= DK\n p_dq_inter -= DK\n p_dk_inter -= DK\n p_dg -= DK\n\n", - "description_1": "Use triton language to define three kernels: fwd_decay_cumsum, prepare_qg_kg, and bwd_decay_global_cumsum. The fwd_decay_cumsum kernel calculates the cumulative decay for an input tensor and stores the results, using 13 parameters for configuration and addressing. The prepare_qg_kg kernel prepares query and key gradients using 14 parameters including input and output tensors and configuration constants. The bwd_decay_global_cumsum kernel computes the backward cumulative sum with 18 parameters, accounting for gradients and input tensors, for tensor updates in neural networks.", - "description_2": "Use triton language to implement forward cumulative sum decay kernel, query-key gradient preparation kernel, and backward cumulative sum kernel for neural network tensor computations.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\nfrom torch.cuda.amp import custom_bwd, custom_fwd\nfrom fla.utils import contiguous\nfrom typing import Tuple\n\n@triton.jit\ndef fused_recurrent_gla_fwd_kernel(\n q, k, v, gk, gv, o, initial_state, final_state,\n s_qk_h, s_qk_t, s_qk_d, s_vo_h, s_vo_t, s_vo_d,\n B, H, T, scale, BK: tl.constexpr, BV: tl.constexpr, DK: tl.constexpr, DV: tl.constexpr,\n USE_INITIAL_STATE: tl.constexpr, STORE_FINAL_STATE: tl.constexpr, REVERSE: tl.constexpr,\n USE_GK: tl.constexpr, USE_GV: tl.constexpr):\n \n i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n\n p_q = q + i_bh * s_qk_h + i_k * BK + \\\n tl.arange(0, BK) + ((T-1) * DK if REVERSE else 0)\n p_k = k + i_bh * s_qk_h + i_k * BK + \\\n tl.arange(0, BK) + ((T-1) * DK if REVERSE else 0)\n p_v = v + i_bh * s_vo_h + i_v * BV + \\\n tl.arange(0, BV) + ((T-1) * DV if REVERSE else 0)\n p_o = o + (i_bh + i_k * B * H) * s_vo_h + i_v * BV + \\\n tl.arange(0, BV) + ((T-1) * DV if REVERSE else 0)\n\n if USE_GK:\n p_gk = gk + i_bh * s_qk_h + i_k * BK + \\\n tl.arange(0, BK) + ((T-1) * DK if REVERSE else 0)\n if USE_GV:\n p_gv = gv + i_bh * s_vo_h + i_v * BV + \\\n tl.arange(0, BV) + ((T-1) * DV if REVERSE else 0)\n\n mask_bk = (i_k * BK + tl.arange(0, BK)) < DK\n mask_bv = (i_v * BV + tl.arange(0, BV)) < DV\n\n h = tl.zeros([BV, BK], dtype=tl.float32)\n\n mask_kv = mask_bk[None, :] & mask_bv[:, None]\n\n if USE_INITIAL_STATE:\n p_init_s = initial_state + i_bh * DK * DV + \\\n (i_k * BK + tl.arange(0, BK)[None, :]) * \\\n DV + (i_v * BV + tl.arange(0, BV)[:, None])\n h += tl.load(p_init_s, mask=mask_kv, other=0).to(tl.float32)\n\n for _ in range(0, T):\n _k = tl.load(p_k, mask=mask_bk, other=0).to(tl.float32)\n _v = tl.load(p_v, mask=mask_bv, other=0).to(tl.float32)\n _q = tl.load(p_q, mask=mask_bk, other=0).to(tl.float32) * scale\n if USE_GK:\n _gk = tl.load(p_gk, mask=mask_bk, other=0).to(tl.float32)\n h = h * _gk[None, :]\n if USE_GV:\n _gv = tl.load(p_gv, mask=mask_bv, other=0).to(tl.float32)\n h = h * _gv[:, None]\n h += _k[None, :] * _v[:, None]\n _o = h * _q[None, :]\n _o = tl.sum(_o, axis=1)\n tl.store(p_o, _o.to(p_o.dtype.element_ty), mask=mask_bv)\n p_q += -DK if REVERSE else DK\n p_k += -DK if REVERSE else DK\n p_o += -DV if REVERSE else DV\n p_v += -DV if REVERSE else DV\n if USE_GK:\n p_gk += -DK if REVERSE else DK\n if USE_GV:\n p_gv += -DV if REVERSE else DV\n\n if STORE_FINAL_STATE:\n p_final_s = final_state + i_bh * DK * DV + \\\n (i_k * BK + tl.arange(0, BK)[None, :]) * \\\n DV + (i_v * BV + tl.arange(0, BV)[:, None])\n tl.store(p_final_s, h.to(p_final_s.dtype.element_ty), mask=mask_kv)\n\n@triton.jit\ndef fused_recurrent_gla_bwd_kernel(\n q, k, v, gk, gv, do, dq, dk, dv, initial_state,\n s_qk_h, s_qk_t, s_qk_d, s_vo_h, s_vo_t, s_vo_d,\n B, H, T, scale, BK: tl.constexpr, BV: tl.constexpr, DK: tl.constexpr, DV: tl.constexpr,\n USE_INITIAL_STATE: tl.constexpr, REVERSE: tl.constexpr, USE_GK: tl.constexpr, USE_GV: tl.constexpr):\n \n i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n\n p_q = q + i_bh * s_qk_h + i_k * BK + \\\n tl.arange(0, BK) + ((T-1) * DK if REVERSE else 0)\n p_k = k + i_bh * s_qk_h + i_k * BK + \\\n tl.arange(0, BK) + ((T-1) * DK if REVERSE else 0)\n p_v = v + i_bh * s_vo_h + i_v * BV + \\\n tl.arange(0, BV) + ((T-1) * DV if REVERSE else 0)\n p_do = do + i_bh * s_vo_h + i_v * BV + \\\n tl.arange(0, BV) + ((T-1) * DV if REVERSE else 0)\n p_dq = dq + (i_bh + i_v * B * H) * s_qk_h + i_k * BK + \\\n tl.arange(0, BK) + ((T-1) * DK if REVERSE else 0)\n if USE_GK:\n p_gk = gk + i_bh * s_qk_h + i_k * BK + \\\n tl.arange(0, BK) + ((T-1) * DK if REVERSE else 0)\n if USE_GV:\n p_gv = gv + i_bh * s_vo_h + i_v * BV + \\\n tl.arange(0, BV) + ((T-1) * DV if REVERSE else 0)\n mask_bk = i_k * BK + tl.arange(0, BK) < DK\n mask_bv = i_v * BV + tl.arange(0, BV) < DV\n mask_kv = mask_bk[:, None] & mask_bv[None, :]\n h = tl.zeros([BK, BV], dtype=tl.float32)\n\n if USE_INITIAL_STATE:\n p_init_s = initial_state + i_bh * DK * DV + \\\n (i_k * BK + tl.arange(0, BK)[:, None]) * \\\n DV + (i_v * BV + tl.arange(0, BV)[None, :])\n h += tl.load(p_init_s, mask=mask_kv, other=0).to(tl.float32)\n\n for i in range(0, T):\n _k = tl.load(p_k, mask=mask_bk, other=0).to(tl.float32)\n _v = tl.load(p_v, mask=mask_bv, other=0).to(tl.float32)\n _do = tl.load(p_do, mask=mask_bv, other=0).to(tl.float32)\n if USE_GK:\n _gk = tl.load(p_gk, mask=mask_bk, other=0).to(tl.float32)\n h = h * _gk[:, None]\n if USE_GV:\n _gv = tl.load(p_gv, mask=mask_bv, other=0).to(tl.float32)\n h = h * _gv[None, :]\n h += _k[:, None] * _v[None, :]\n _d_q = h * _do[None, :]\n d_q = tl.sum(_d_q, axis=1) * scale\n tl.store(p_dq, d_q.to(p_dq.dtype.element_ty), mask=mask_bk)\n\n p_k += -DK if REVERSE else DK\n p_v += -DV if REVERSE else DV\n p_q += -DK if REVERSE else DK\n p_do += -DV if REVERSE else DV\n p_dq += -DK if REVERSE else DK\n if USE_GK:\n p_gk += -DK if REVERSE else DK\n if USE_GV:\n p_gv += -DV if REVERSE else DV\n\n tl.debug_barrier()\n\n p_q = q + i_bh * s_qk_h + i_k * BK + \\\n tl.arange(0, BK) + ((T - 1) * DK if not REVERSE else 0)\n p_k = k + i_bh * s_qk_h + i_k * BK + \\\n tl.arange(0, BK) + ((T - 1) * DK if not REVERSE else 0)\n p_do = do + i_bh * s_vo_h + i_v * BV + \\\n tl.arange(0, BV) + ((T - 1) * DV if not REVERSE else 0)\n p_v = v + i_bh * s_vo_h + i_v * BV + \\\n tl.arange(0, BV) + ((T - 1) * DV if not REVERSE else 0)\n p_dk = dk + (i_bh + i_v * B * H) * s_qk_h + i_k * \\\n BK + tl.arange(0, BK) + ((T - 1) * DK if not REVERSE else 0)\n p_dv = dv + (i_bh + i_k * B * H) * s_vo_h + i_v * \\\n BV + tl.arange(0, BV) + ((T - 1) * DV if not REVERSE else 0)\n if USE_GK:\n p_gk = gk + i_bh * s_qk_h + i_k * BK + \\\n tl.arange(0, BK) + ((T - 1) * DK if not REVERSE else 0)\n if USE_GV:\n p_gv = gv + i_bh * s_vo_h + i_v * BV + \\\n tl.arange(0, BV) + ((T - 1) * DV if not REVERSE else 0)\n\n d_h = tl.zeros([BK, BV], dtype=tl.float32)\n\n for _ in range(T):\n _do = tl.load(p_do, mask=mask_bv, other=0).to(tl.float32)\n _q = tl.load(p_q, mask=mask_bk, other=0).to(tl.float32) * scale\n _k = tl.load(p_k, mask=mask_bk, other=0).to(tl.float32)\n _v = tl.load(p_v, mask=mask_bv, other=0).to(tl.float32)\n d_h += _q[:, None] * _do[None, :]\n d_k = tl.sum(d_h * _v[None, :], axis=1)\n d_v = tl.sum(d_h * _k[:, None], axis=0)\n if USE_GK:\n _gk = tl.load(p_gk, mask=mask_bk, other=0).to(tl.float32)\n d_h *= _gk[:, None]\n if USE_GV:\n _gv = tl.load(p_gv, mask=mask_bv, other=0).to(tl.float32)\n d_h *= _gv[None, :]\n tl.store(p_dk, d_k.to(p_dk.dtype.element_ty), mask=mask_bk)\n tl.store(p_dv, d_v.to(p_dv.dtype.element_ty), mask=mask_bv)\n\n p_do += DV if REVERSE else -DV\n p_q += DK if REVERSE else -DK\n p_k += DK if REVERSE else -DK\n p_v += DV if REVERSE else -DV\n p_dk += DK if REVERSE else -DK\n p_dv += DV if REVERSE else -DV\n if USE_GK:\n p_gk += DK if REVERSE else -DK\n if USE_GV:\n p_gv += DV if REVERSE else -DV\n\n\nclass FusedRecurrentGLAFunction(torch.autograd.Function):\n\n @staticmethod\n @contiguous\n @custom_fwd\n def forward(ctx, q, k, v, gk, gv, scale=None, initial_state=None, output_final_state=False, reverse=False):\n batch_size, n_heads, seq_len, d_head_qk = q.shape\n d_head_v = v.shape[-1]\n if scale is None:\n scale = d_head_qk ** -0.5\n if gk is not None:\n gk = gk.float().exp()\n if gv is not None:\n gv = gv.float().exp()\n\n BK, BV = min(d_head_qk, 32), min(d_head_v, 32)\n NK, NV = triton.cdiv(d_head_qk, BK), triton.cdiv(d_head_v, BV)\n num_stages = 1\n num_warps = 1\n\n o = q.new_empty(NK, batch_size, n_heads, seq_len,\n d_head_v, dtype=torch.float32)\n\n if output_final_state:\n final_state = q.new_empty(batch_size, n_heads, d_head_qk, d_head_v)\n else:\n final_state = None\n\n grid = (NV, NK, batch_size * n_heads)\n fused_recurrent_gla_fwd_kernel[grid](\n q, k, v, gk, gv, o, initial_state, final_state,\n q.stride(1), q.stride(2), q.stride(3),\n v.stride(1), v.stride(2), v.stride(3),\n batch_size, n_heads, seq_len, scale,\n DK=d_head_qk, DV=d_head_v, BK=BK, BV=BV,\n USE_INITIAL_STATE=initial_state is not None,\n STORE_FINAL_STATE=final_state is not None,\n USE_GK=gk is not None,\n USE_GV=gv is not None,\n REVERSE=reverse,\n num_warps=num_warps,\n num_stages=num_stages\n )\n\n o = o.sum(0)\n ctx.save_for_backward(q, k, v, gk, gv, initial_state, o)\n ctx.scale = scale\n ctx.reverse = reverse\n if final_state is not None:\n final_state = final_state.detach()\n return o.to(q.dtype), final_state\n\n @staticmethod\n @contiguous\n @custom_bwd\n def backward(ctx, do, d_final_state=None):\n q, k, v, gk, gv, initial_state, o = ctx.saved_tensors\n batch_size, n_heads, seq_len, d_head_qk = q.shape\n d_head_v = v.shape[-1]\n scale = ctx.scale\n\n BK, BV = min(d_head_qk, 32), min(d_head_v, 32)\n NK, NV = triton.cdiv(d_head_qk, BK), triton.cdiv(d_head_v, BV)\n num_stages = 1\n num_warps = 1\n\n dq = q.new_empty(NV, batch_size, n_heads, seq_len,\n d_head_qk, dtype=torch.float32)\n dk = q.new_empty(NV, batch_size, n_heads, seq_len,\n d_head_qk, dtype=torch.float32)\n dv = q.new_empty(NK, batch_size, n_heads, seq_len,\n d_head_v, dtype=torch.float32)\n grid = (NV, NK, batch_size * n_heads)\n\n fused_recurrent_gla_bwd_kernel[grid](\n q, k, v, gk, gv, do, dq, dk, dv, initial_state,\n q.stride(1), q.stride(2), q.stride(3),\n v.stride(1), v.stride(2), v.stride(3),\n batch_size, n_heads, seq_len, scale,\n DK=d_head_qk, DV=d_head_v, BK=BK, BV=BV,\n num_warps=num_warps,\n num_stages=num_stages,\n USE_INITIAL_STATE=initial_state is not None,\n REVERSE=ctx.reverse,\n USE_GK=gk is not None,\n USE_GV=gv is not None\n )\n dq = dq.sum(0)\n dk = dk.sum(0)\n dv = dv.sum(0)\n if gk is not None:\n _dgk = dq * q.float() - dk * k.float()\n if ctx.reverse:\n dgk = _dgk.cumsum(-2)\n else:\n _dgk_cumsum = _dgk.cumsum(-2)\n dgk = _dgk + _dgk_cumsum[:, :, -1, None] - _dgk_cumsum\n else:\n dgk = None\n\n if gv is not None:\n _dgv = do.float() * o.float() - dv * v.float()\n if ctx.reverse:\n dgv = _dgv.cumsum(-2)\n else:\n _dgv_cumsum = _dgv.cumsum(-2)\n dgv = _dgv + _dgv_cumsum[:, :, -1, None] - _dgv_cumsum\n else:\n dgv = None\n\n return dq.to(q.dtype), dk.to(k.dtype), dv.to(v.dtype), dgk, dgv, None, None, None, None\n\n\ndef fused_recurrent_gla(\n q: torch.Tensor,\n k: torch.Tensor,\n v: torch.Tensor,\n gk: torch.Tensor = None,\n gv: torch.Tensor = None,\n scale: int = -1,\n initial_state: torch.Tensor = None,\n output_final_state: bool = False,\n causal: bool = True\n) -> Tuple[torch.Tensor, torch.Tensor]:\n if scale == -1:\n scale = q.shape[-1] ** -0.5\n if initial_state is not None:\n initial_state = initial_state.detach()\n if causal:\n o, final_state = FusedRecurrentGLAFunction.apply(q, k, v, gk, gv, scale, initial_state, output_final_state)\n return o, final_state\n else:\n assert initial_state is None\n assert output_final_state is False\n o, final_state = FusedRecurrentGLAFunction.apply(\n q, k, v, gk, gv, scale, initial_state, output_final_state, False)\n o_reversed, final_state = FusedRecurrentGLAFunction.apply(\n q, k, v, gk, gv, scale, initial_state, output_final_state, True)\n return [o, o_reversed]\n", - "description_1": "Use triton language to implement fused recurrent gated linear attention forward and backward kernel functions. The forward kernel computes the attention mechanism with queries (q), keys (k), values (v), gates for keys (gk), and gates for values (gv), while taking into account optional initial state and final state for sequences. The backward kernel computes gradients with respect to these inputs during backpropagation. Both kernels handle configurations such as block sizes, head dimensions, and use of initial/final states. Additional Python functions (forward and backward) manage the execution of these kernels within PyTorch's autograd framework.", - "description_2": "Use triton language to implement a fused recurrent gated linear attention mechanism with forward and backward passes, considering optional state and direction (causal vs non-causal) handling.", - "difficulty": 5 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n\n@triton.autotune(\n configs=[\n triton.Config({'BD': 32}, num_warps=1),\n triton.Config({'BD': 32}, num_warps=2),\n triton.Config({'BD': 32}, num_warps=4),\n triton.Config({'BD': 32}, num_warps=8),\n triton.Config({'BD': 64}, num_warps=1),\n triton.Config({'BD': 64}, num_warps=2),\n triton.Config({'BD': 64}, num_warps=4),\n triton.Config({'BD': 64}, num_warps=8),\n triton.Config({'BD': 128}, num_warps=1),\n triton.Config({'BD': 128}, num_warps=2),\n triton.Config({'BD': 128}, num_warps=4),\n triton.Config({'BD': 128}, num_warps=8),\n ],\n key=['D']\n)\n@triton.jit\ndef chunk_hgrn_fwd_kernel_h(\n x,\n g,\n gc,\n o,\n h0,\n T: tl.constexpr,\n D: tl.constexpr,\n BT: tl.constexpr,\n BD: tl.constexpr,\n USE_INITIAL_STATE: tl.constexpr\n):\n i_d, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n o_d = i_d * BD + tl.arange(0, BD)\n mask = o_d < D\n\n p_x = x + i_bh * T * D + i_t * BT * D + o_d\n p_g = g + i_bh * T * D + i_t * BT * D + o_d\n p_gc = gc + i_bh * T * D + i_t * BT * D + o_d\n p_o = o + i_bh * T * D + i_t * BT * D + o_d\n\n b_h = tl.zeros([BD], dtype=tl.float32)\n b_gc = tl.zeros([BD], dtype=tl.float32)\n if USE_INITIAL_STATE:\n if i_t == 0:\n b_h += tl.load(h0 + i_bh * D + o_d, mask=mask, other=0).to(tl.float32)\n for i in range(0, BT):\n mask_t = mask & ((i_t * BT + i) < T)\n b_x = tl.load(p_x, mask=mask_t, other=0).to(tl.float32)\n b_g = tl.load(p_g, mask=mask_t, other=0).to(tl.float32)\n b_h = tl.exp(b_g) * b_h + b_x\n b_gc = b_gc + b_g\n tl.store(p_gc, b_gc.to(p_o.dtype.element_ty), mask=mask_t)\n tl.store(p_o, b_h.to(p_o.dtype.element_ty), mask=mask_t)\n\n p_x += D\n p_g += D\n p_gc += D\n p_o += D\n\n\n@triton.jit\ndef chunk_hgrn_fwd_kernel_o(\n gc,\n o,\n s_h,\n s_t,\n s_d,\n T: tl.constexpr,\n D: tl.constexpr,\n BT: tl.constexpr,\n BD: tl.constexpr\n):\n i_d, i_bh = tl.program_id(0), tl.program_id(1)\n o_d = i_d * BD + tl.arange(0, BD)\n mask = o_d < D\n\n for i_t in range(1, tl.cdiv(T, BT)):\n p_gc = tl.make_block_ptr(gc + i_bh * s_h, (T, D), (s_t, s_d), (i_t * BT, i_d * BD), (BT, BD), (1, 0))\n p_o = tl.make_block_ptr(o + i_bh * s_h, (T, D), (s_t, s_d), (i_t * BT, i_d * BD), (BT, BD), (1, 0))\n\n b_h0 = tl.load(o + i_bh * T * D + i_t * BT * D - D + o_d, mask=mask, other=0).to(tl.float32)\n b_gc = tl.load(p_gc, boundary_check=(0, 1)).to(tl.float32)\n b_o = tl.load(p_o, boundary_check=(0, 1)).to(tl.float32)\n b_o = b_o + tl.exp(b_gc) * b_h0[None, :]\n tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))\n\n\n@triton.autotune(\n configs=[\n triton.Config({'BD': 32}, num_warps=1),\n triton.Config({'BD': 32}, num_warps=2),\n triton.Config({'BD': 32}, num_warps=4),\n triton.Config({'BD': 32}, num_warps=8),\n triton.Config({'BD': 64}, num_warps=1),\n triton.Config({'BD': 64}, num_warps=2),\n triton.Config({'BD': 64}, num_warps=4),\n triton.Config({'BD': 64}, num_warps=8),\n triton.Config({'BD': 128}, num_warps=1),\n triton.Config({'BD': 128}, num_warps=2),\n triton.Config({'BD': 128}, num_warps=4),\n triton.Config({'BD': 128}, num_warps=8),\n ],\n key=['D']\n)\n@triton.jit\ndef chunk_hgrn_bwd_kernel_h(\n g,\n gc,\n dx,\n do,\n T: tl.constexpr,\n D: tl.constexpr,\n BT: tl.constexpr,\n BD: tl.constexpr\n):\n i_d, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n o_d = i_d * BD + tl.arange(0, BD)\n mask = o_d < D\n BC = min(BT, T - i_t * BT)\n NT = tl.num_programs(1)\n\n p_g = g + (i_bh * T + i_t * BT + BC - 1) * D + o_d\n p_gc = gc + (i_bh * T + i_t * BT + BC - 1) * D + o_d\n p_dx = dx + (i_bh * T + i_t * BT + BC - 1) * D + o_d\n p_do = do + (i_bh * T + i_t * BT + BC - 1) * D + o_d\n\n if i_t == NT - 1:\n b_gc = tl.zeros([BD], dtype=tl.float32)\n else:\n b_gc = tl.load(g + (i_bh * T + i_t * BT + BT) * D + o_d, mask=mask, other=0).to(tl.float32)\n b_dh = tl.zeros([BD], dtype=tl.float32)\n for _ in range(BC - 1, -1, -1):\n tl.store(p_gc, b_gc.to(p_gc.dtype.element_ty), mask=mask)\n\n b_g = tl.load(p_g, mask=mask, other=0).to(tl.float32)\n b_do = tl.load(p_do, mask=mask, other=0).to(tl.float32)\n\n b_gc = b_gc + b_g\n b_dh = b_dh + b_do\n b_dx = b_dh\n b_dh = b_dh * tl.exp(b_g)\n\n tl.store(p_dx, b_dx.to(p_dx.dtype.element_ty), mask=mask)\n\n p_g -= D\n p_gc -= D\n p_dx -= D\n p_do -= D\n\n\n@triton.jit\ndef chunk_hgrn_bwd_kernel_o(\n g,\n gc,\n o,\n dx,\n dg,\n s_h,\n s_t,\n s_d,\n T: tl.constexpr,\n D: tl.constexpr,\n BT: tl.constexpr,\n BD: tl.constexpr\n):\n i_d, i_bh = tl.program_id(0), tl.program_id(1)\n o_d = i_d * BD + tl.arange(0, BD)\n mask = o_d < D\n\n for i_t in range(tl.cdiv(T, BT) - 1, -1, -1):\n p_g = tl.make_block_ptr(g + i_bh * s_h, (T, D), (s_t, s_d), (i_t * BT, i_d * BD), (BT, BD), (1, 0))\n p_gc = tl.make_block_ptr(gc + i_bh * s_h, (T, D), (s_t, s_d), (i_t * BT, i_d * BD), (BT, BD), (1, 0))\n p_o = tl.make_block_ptr(o + i_bh * s_h, (T, D), (s_t, s_d), (i_t * BT - 1, i_d * BD), (BT, BD), (1, 0))\n p_dx = tl.make_block_ptr(dx + i_bh * s_h, (T, D), (s_t, s_d), (i_t * BT, i_d * BD), (BT, BD), (1, 0))\n p_dg = tl.make_block_ptr(dg + i_bh * s_h, (T, D), (s_t, s_d), (i_t * BT, i_d * BD), (BT, BD), (1, 0))\n\n mask_t = mask & ((i_t + 1) * BT < T)\n b_ht = tl.load(dx + i_bh * T * D + (i_t + 1) * BT * D + o_d, mask=mask_t, other=0).to(tl.float32)\n b_g = tl.load(p_g, boundary_check=(0, 1)).to(tl.float32)\n b_gc = tl.load(p_gc, boundary_check=(0, 1)).to(tl.float32)\n b_o = tl.load(p_o, boundary_check=(0, 1)).to(tl.float32)\n b_dx = tl.load(p_dx, boundary_check=(0, 1)).to(tl.float32)\n b_dg = tl.load(p_dg, boundary_check=(0, 1)).to(tl.float32)\n b_dx = b_dx + tl.exp(b_gc) * b_ht[None, :]\n b_dg = b_o * b_dx * tl.exp(b_g)\n tl.store(p_dx, b_dx.to(p_dx.dtype.element_ty), boundary_check=(0, 1))\n tl.store(p_dg, b_dg.to(p_dg.dtype.element_ty), boundary_check=(0, 1))\n\n\nclass ChunkHGRNFunction(torch.autograd.Function):\n\n @staticmethod\n def forward(ctx, x, g, initial_state=None, output_final_state=False):\n B, H, T, D = x.shape\n BT, BD = 128, min(64, triton.next_power_of_2(D))\n num_warps = 8 if BD == 64 else 4\n\n gc = torch.empty_like(g, dtype=torch.float)\n o = torch.empty_like(x, dtype=torch.float)\n def grid(meta): return (triton.cdiv(D, meta['BD']), triton.cdiv(T, meta['BT']), B * H)\n chunk_hgrn_fwd_kernel_h[grid](\n x, g, gc, o, initial_state,\n T, D,\n BT=BT,\n USE_INITIAL_STATE=initial_state is not None\n )\n def grid(meta): return (triton.cdiv(D, meta['BD']), B * H)\n chunk_hgrn_fwd_kernel_o[grid](\n gc, o,\n o.stride(1), o.stride(2), o.stride(3),\n T, D,\n BT=BT, BD=BD,\n num_warps=num_warps\n )\n final_state = None\n if output_final_state:\n final_state = o[:, :, -1].clone()\n o = o.to(x.dtype)\n ctx.save_for_backward(g, o, initial_state)\n return o, final_state\n\n @staticmethod\n def backward(ctx, do, dht=None):\n g, o, initial_state = ctx.saved_tensors\n B, H, T, D = do.shape\n BT, BD = 128, min(64, triton.next_power_of_2(D))\n num_warps = 8 if BD == 64 else 4\n\n gc = torch.empty_like(g, dtype=torch.float)\n dx = torch.empty_like(o)\n dg = torch.empty_like(g)\n def grid(meta): return (triton.cdiv(D, meta['BD']), triton.cdiv(T, meta['BT']), B * H)\n chunk_hgrn_bwd_kernel_h[grid](\n g, gc, dx, do,\n T, D,\n BT=BT\n )\n def grid(meta): return (triton.cdiv(D, meta['BD']), B * H)\n chunk_hgrn_bwd_kernel_o[grid](\n g, gc, o, dx, dg,\n o.stride(1), o.stride(2), o.stride(3),\n T, D,\n BT=BT, BD=BD,\n num_warps=num_warps\n )\n if initial_state is not None:\n dg[:, :, 0] = initial_state * dx[:, :, 0] * g[:, :, 0].exp()\n\n return dx, dg, None, None\n\n\ndef chunk_hgrn(\n x: torch.Tensor,\n g: torch.Tensor,\n initial_state: torch.Tensor = None,\n output_final_state: bool = False\n) -> Tuple[torch.Tensor, torch.Tensor]:\n if initial_state is not None:\n initial_state = initial_state.detach()\n o, final_state = ChunkHGRNFunction.apply(x, g, initial_state, output_final_state)\n return o, final_state\n", - "description_1": "Use triton language to implement chunk-based forward and backward kernels for HGRN (Hierarchical Gated Recurrent Network). There are four Triton kernels defined: 'chunk_hgrn_fwd_kernel_h' with 8 parameters for the forward pass, 'chunk_hgrn_fwd_kernel_o' with 7 parameters for the forward pass optimization, 'chunk_hgrn_bwd_kernel_h' with 6 parameters for the backward pass, and 'chunk_hgrn_bwd_kernel_o' with 8 parameters for the backward pass optimization. The parameters consist of inputs/outputs (like x, g, gc, o) and constants defining dimensions or options (like T, D, BT, BD, USE_INITIAL_STATE). The 'ChunkHGRNFunction' class uses these kernels in its 'forward' and 'backward' static methods to compute outputs and gradients. An external Python function 'chunk_hgrn' calls this class to perform operations using the specified parameters.", - "description_2": "Use triton language to define HGRN forward and backward kernels, applying these in a custom autograd function to compute outputs and gradients efficiently.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\nfrom packaging import version\n\n@triton.jit\ndef fused_chunk_linear_attn_fwd_kernel(\n q, k, v, o, initial_state, final_state,\n s_qk_h, s_qk_t, s_qk_d, s_vo_h, s_vo_t, s_vo_d,\n B, H, T, scale,\n BT: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr,\n DK: tl.constexpr, DV: tl.constexpr,\n USE_INITIAL_STATE: tl.constexpr,\n STORE_FINAL_STATE: tl.constexpr,\n CHECK: tl.constexpr\n):\n i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n o_i = tl.arange(0, BT)\n m_s = o_i[:, None] >= o_i[None, :]\n b_h = tl.zeros([BK, BV], dtype=tl.float32)\n p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (0, i_k * BK), (BT, BK), (1, 0))\n p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (DK, T), (s_qk_d, s_qk_t), (i_k * BK, 0), (BK, BT), (0, 1))\n p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (0, i_v * BV), (BT, BV), (1, 0))\n p_o = tl.make_block_ptr(o + (i_bh+i_k*B*H) * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (0, i_v * BV), (BT, BV), (1, 0))\n\n if USE_INITIAL_STATE:\n p_h = tl.make_block_ptr(initial_state + i_bh * DK * DV, (DK, DV), (DV, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))\n b_h = tl.load(p_h, boundary_check=(0, 1)).to(tl.float32)\n\n for i in range(0, tl.cdiv(T, BT)):\n b_k = tl.load(p_k, boundary_check=(0, 1))\n b_v = tl.load(p_v, boundary_check=(0, 1))\n b_q = tl.load(p_q, boundary_check=(0, 1))\n b_q = (b_q * scale).to(b_k.dtype)\n\n b_s = tl.dot(b_q, b_k, allow_tf32=False)\n b_s = tl.where(m_s, b_s, 0)\n b_o = tl.dot(b_s.to(b_q.dtype), b_v, allow_tf32=False)\n if CHECK and i == 0:\n b_o += tl.dot(b_q, b_h.to(b_q.dtype), allow_tf32=False)\n b_h = b_h + tl.dot(b_k, b_v, allow_tf32=False)\n else:\n b_o += tl.dot(b_q, b_h.to(b_q.dtype), allow_tf32=False)\n b_h = b_h + tl.dot(b_k, b_v, allow_tf32=False)\n tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))\n p_q = tl.advance(p_q, (BT, 0))\n p_k = tl.advance(p_k, (0, BT))\n p_v = tl.advance(p_v, (BT, 0))\n p_o = tl.advance(p_o, (BT, 0))\n\n if STORE_FINAL_STATE:\n p_final = tl.make_block_ptr(final_state + i_bh * DK * DV, (DK, DV), (DV, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))\n tl.store(p_final, b_h.to(p_final.dtype.element_ty), boundary_check=(0, 1))\n\n@triton.jit\ndef fused_chunk_linear_attn_bwd_kernel(\n q, k, v, do, dq, dk, dv, initial_state,\n s_qk_h, s_qk_t, s_qk_d, s_vo_h, s_vo_t, s_vo_d,\n B, H, T, scale,\n BT: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr,\n DK: tl.constexpr, DV: tl.constexpr,\n USE_INITIAL_STATE: tl.constexpr,\n CHECK: tl.constexpr\n):\n i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n o_i = tl.arange(0, BT)\n m_s = o_i[:, None] >= o_i[None, :]\n b_h = tl.zeros([BV, BK], dtype=tl.float32)\n if USE_INITIAL_STATE:\n p_h = tl.make_block_ptr(initial_state + i_bh * DK * DV, (DV, DK), (1, DV), (i_v * BV, i_k * BK), (BV, BK), (0, 1))\n b_h = tl.load(p_h, boundary_check=(0, 1)).to(tl.float32)\n\n for i in range(0, tl.cdiv(T, BT)):\n p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (i * BT, i_k * BK), (BT, BK), (1, 0))\n p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (DV, T), (s_vo_d, s_vo_t), (i_v * BV, i * BT), (BV, BT), (0, 1))\n p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (i * BT, i_v * BV), (BT, BV), (1, 0))\n p_dq = tl.make_block_ptr(dq + (i_bh + i_v*B*H) * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (i*BT, i_k*BK), (BT, BK), (1, 0))\n\n b_k = tl.load(p_k, boundary_check=(0, 1))\n b_v = tl.load(p_v, boundary_check=(0, 1))\n b_do = tl.load(p_do, boundary_check=(0, 1))\n\n b_ds = tl.dot(b_do, b_v, allow_tf32=False)\n b_ds = tl.where(m_s, b_ds, 0)\n b_dq = tl.dot(b_ds.to(b_k.dtype), b_k, allow_tf32=False)\n if CHECK and i == 0:\n b_dq += tl.dot(b_do, b_h.to(b_do.dtype), allow_tf32=False)\n b_h = b_h + tl.dot(b_v, b_k, allow_tf32=False)\n else:\n b_dq += tl.dot(b_do, b_h.to(b_do.dtype), allow_tf32=False)\n b_h = b_h + tl.dot(b_v, b_k, allow_tf32=False)\n b_dq *= scale\n tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))\n\n b_h = None\n tl.debug_barrier()\n b_dh = tl.zeros([BK, BV], dtype=tl.float32)\n m_s = o_i[:, None] <= o_i[None, :]\n for i in range(1, tl.cdiv(T, BT) + 1):\n p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (DK, T), (s_qk_d, s_qk_t), (i_k * BK, T - i * BT), (BK, BT), (0, 1))\n p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (T - i * BT, i_k * BK), (BT, BK), (1, 0))\n p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (T - i * BT, i_v * BV), (BT, BV), (1, 0))\n p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (T - i * BT, i_v * BV), (BT, BV), (1, 0))\n p_dk = tl.make_block_ptr(dk + (i_bh+i_v*B*H) * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (T - i*BT, i_k*BK), (BT, BK), (1, 0))\n p_dv = tl.make_block_ptr(dv + (i_bh+i_k*B*H) * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (T - i*BT, i_v*BV), (BT, BV), (1, 0))\n\n b_q = tl.load(p_q, boundary_check=(0, 1))\n b_k = tl.load(p_k, boundary_check=(0, 1))\n b_v = tl.load(p_v, boundary_check=(0, 1))\n b_do = tl.load(p_do, boundary_check=(0, 1))\n\n b_ds = tl.dot(b_v, tl.trans(b_do), allow_tf32=False)\n b_ds = tl.where(m_s, b_ds, 0).to(b_q.dtype)\n b_s = tl.dot(b_k, b_q, allow_tf32=False) * scale\n b_s = tl.where(m_s, b_s, 0).to(b_q.dtype)\n b_dk = tl.dot(b_ds, tl.trans(b_q), allow_tf32=False)\n b_dv = tl.dot(b_s, b_do, allow_tf32=False)\n if CHECK and i == 1:\n b_dk += tl.dot(b_v, tl.trans(b_dh).to(b_v.dtype), allow_tf32=False)\n b_dv += tl.dot(b_k, b_dh.to(b_k.dtype), allow_tf32=False)\n b_dh += tl.dot(b_q, b_do, allow_tf32=False)\n else:\n b_dk += tl.dot(b_v, tl.trans(b_dh).to(b_v.dtype), allow_tf32=False)\n b_dv += tl.dot(b_k, b_dh.to(b_k.dtype), allow_tf32=False)\n b_dh += tl.dot(b_q, b_do, allow_tf32=False)\n\n tl.store(p_dk, (b_dk * scale).to(p_dk.dtype.element_ty), boundary_check=(0, 1))\n tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))\n\nclass FusedChunkLinearAttentionFunction(torch.autograd.Function):\n @staticmethod\n def forward(ctx, q, k, v, scale, initial_state, output_final_state):\n batch_size, n_heads, seq_len, d_head_qk = q.shape\n d_head_v = v.shape[-1]\n ctx.scale = scale\n BT = 64\n BK, BV = min(triton.next_power_of_2(d_head_qk), 64), min(triton.next_power_of_2(d_head_v), 64)\n NK, NV = triton.cdiv(d_head_qk, BK), triton.cdiv(d_head_v, BV)\n num_stages = 1\n num_warps = 4\n o = q.new_empty(NK, batch_size, n_heads, seq_len, d_head_v)\n if output_final_state:\n final_state = q.new_empty(batch_size, n_heads, d_head_qk, d_head_v, dtype=torch.float32, requires_grad=False)\n else:\n final_state = None\n CHECK = True\n if version.parse(triton.__version__) < version.parse('2.2.0'):\n import warnings\n warnings.warn(\n \"Triton<2.2.0 detected for running this kernel, \"\n \"which is known to have some weird compiler issues (refer to https://github.com/openai/triton/issues/2852) \"\n \"that lead to significant precision loss. \"\n \"We've add some initial condition checks to resolve this, sadly at the sacrifice of the speed. \"\n \"For optimal performance, it is recommended to install Triton>=2.2.0 (if possible).\"\n )\n CHECK = True\n\n grid = (NV, NK, batch_size * n_heads)\n fused_chunk_linear_attn_fwd_kernel[grid](\n q, k, v, o, initial_state, final_state,\n q.stride(1), q.stride(2), q.stride(3),\n v.stride(1), v.stride(2), v.stride(3),\n batch_size, n_heads, seq_len, scale,\n BT=BT, DK=d_head_qk, DV=d_head_v, BK=BK, BV=BV,\n USE_INITIAL_STATE=initial_state is not None,\n STORE_FINAL_STATE=output_final_state,\n CHECK=CHECK,\n num_warps=num_warps,\n num_stages=num_stages\n )\n\n o = o.sum(0)\n ctx.save_for_backward(q, k, v, initial_state)\n ctx.CHECK = CHECK\n return o.to(q.dtype), final_state\n\n @staticmethod\n def backward(ctx, do, d_final_state=None):\n q, k, v, initial_state = ctx.saved_tensors\n batch_size, n_heads, seq_len, d_head_qk = q.shape\n d_head_v = v.shape[-1]\n scale = ctx.scale\n\n BT = 64\n BK, BV = min(triton.next_power_of_2(d_head_qk), 64), min(triton.next_power_of_2(d_head_v), 64)\n NK, NV = triton.cdiv(d_head_qk, BK), triton.cdiv(d_head_v, BV)\n num_stages = 1\n num_warps = 4\n\n dq = q.new_empty(NV, batch_size, n_heads, seq_len, d_head_qk)\n dk = q.new_empty(NV, batch_size, n_heads, seq_len, d_head_qk)\n dv = q.new_empty(NK, batch_size, n_heads, seq_len, d_head_v)\n grid = (NV, NK, batch_size * n_heads)\n\n fused_chunk_linear_attn_bwd_kernel[grid](\n q, k, v, do, dq, dk, dv, initial_state,\n q.stride(1), q.stride(2), q.stride(3),\n v.stride(1), v.stride(2), v.stride(3),\n batch_size, n_heads, seq_len, scale,\n BT=BT, DK=d_head_qk, DV=d_head_v, BK=BK, BV=BV,\n USE_INITIAL_STATE=initial_state is not None,\n CHECK=ctx.CHECK,\n num_warps=num_warps,\n num_stages=num_stages\n )\n dq = dq.sum(0)\n dk = dk.sum(0)\n dv = dv.sum(0)\n return dq.to(q.dtype), dk.to(k.dtype), dv.to(v.dtype), None, None, None\n\ndef fused_chunk_linear_attn(\n q: torch.Tensor,\n k: torch.Tensor,\n v: torch.Tensor,\n scale: float = -1,\n initial_state: torch.Tensor = None,\n output_final_state: bool = False,\n normalize: bool = True\n) -> Tuple[torch.Tensor, torch.Tensor]:\n if initial_state is not None:\n initial_state = initial_state.detach()\n if scale == -1:\n scale = q.shape[-1] ** -0.5\n o, final_state = FusedChunkLinearAttentionFunction.apply(q, k, v, scale, initial_state, output_final_state)\n return o, final_state\n", - "description_1": "Use triton language to implement a fused chunk linear attention mechanism consisting of forward and backward kernels. The forward kernel takes 19 parameters including query (q), key (k), value (v) tensors, output (o) tensor, initial state, final state, stride sizes, batch size (B), number of heads (H), sequence length (T), scaling factor, and constant expressions for block sizes, dimensions and flags for using initial state, storing final state, and a check. The backward kernel also takes 24 parameters including gradients of output (do), query (dq), key (dk), value (dv) tensors along with similar parameters as the forward kernel. A wrapper function applies these kernels to compute output and optionally final state given input tensors q, k, v, and other parameters.", - "description_2": "Use triton language to create a linear attention operator with fused kernels for efficient forward and backward pass, supporting optional initial and final state handling.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef parallel_rebased_fwd_kernel(\n q, k, v, o, z, s_qk_h, s_qk_t, s_qk_d, s_vo_h, s_vo_t, s_vo_d, B, H, T, scale,\n BTL: tl.constexpr, BTS: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr,\n DK: tl.constexpr, DV: tl.constexpr,\n):\n i_kv, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n NV = tl.cdiv(DV, BV)\n i_k = i_kv // (NV)\n i_v = i_kv % (NV)\n p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (T, DK),\n (s_qk_t, s_qk_d), (i_c * BTL, i_k * BK), (BTL, BK), (1, 0))\n p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (DK, T),\n (s_qk_d, s_qk_t), (i_k * BK, 0), (BK, BTS), (0, 1))\n p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, DV),\n (s_vo_t, s_vo_d), (0, i_v * BV), (BTS, BV), (1, 0))\n b_q = tl.load(p_q, boundary_check=(0, 1))\n b_q = (b_q * scale).to(b_q.dtype)\n b_o = tl.zeros([BTL, BV], dtype=tl.float32)\n b_z = tl.zeros([BTL], dtype=tl.float32)\n for _ in range(0, i_c * BTL, BTS):\n b_k = tl.load(p_k, boundary_check=(0, 1))\n b_v = tl.load(p_v, boundary_check=(0, 1))\n b_s = tl.dot(b_q, (b_k), allow_tf32=False)\n b_s = b_s * b_s\n b_z += tl.sum(b_s, axis=1)\n b_o = b_o + tl.dot(b_s.to(b_v.dtype), b_v, allow_tf32=False)\n p_k = tl.advance(p_k, (0, BTS))\n p_v = tl.advance(p_v, (BTS, 0))\n tl.debug_barrier()\n o_q = tl.arange(0, BTL)\n o_k = tl.arange(0, BTS)\n p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (DK, T),\n (s_qk_d, s_qk_t), (i_k * BK, i_c * BTL), (BK, BTS), (0, 1))\n p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, DV),\n (s_vo_t, s_vo_d), (i_c * BTL, i_v * BV), (BTS, BV), (1, 0))\n for _ in range(i_c * BTL, (i_c + 1) * BTL, BTS):\n b_k = tl.load(p_k, boundary_check=(0, 1))\n b_v = tl.load(p_v, boundary_check=(0, 1))\n m_s = o_q[:, None] >= o_k[None, :]\n b_s = tl.dot(b_q, b_k, allow_tf32=False)\n b_s = b_s * b_s\n b_s = tl.where(m_s, b_s, 0)\n b_z += tl.sum(b_s, axis=1)\n b_o += tl.dot(b_s.to(b_q.dtype), b_v, allow_tf32=False)\n p_k = tl.advance(p_k, (0, BTS))\n p_v = tl.advance(p_v, (BTS, 0))\n o_k += BTS\n p_o = tl.make_block_ptr(o + (i_bh + B * H * i_k) * s_vo_h, (T, DV),\n (s_vo_t, s_vo_d), (i_c*BTL, i_v*BV), (BTL, BV), (1, 0))\n p_z = z + (i_bh + B * H * i_k) * T + i_c * BTL + tl.arange(0, BTL)\n tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))\n tl.store(p_z, b_z.to(p_z.dtype.element_ty),\n mask=((i_c * BTL + tl.arange(0, BTL)) < T))\n\n@triton.jit\ndef _parallel_rebased_bwd_dq(\n i_bh, i_c, i_k, i_v, i_h, q, k, v, do, dz, dq, s_qk_h, s_qk_t, s_qk_d, s_vo_h,\n s_vo_t, s_vo_d, B, H, T, scale, BTL: tl.constexpr, BTS: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr,\n DK: tl.constexpr, DV: tl.constexpr,\n):\n p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d),\n (i_c * BTL, i_v * BV), (BTL, BV), (1, 0))\n p_q = tl.make_block_ptr(q + (i_bh) * s_qk_h, (T, DK),\n (s_qk_t, s_qk_d), (i_c*BTL, i_k*BK), (BTL, BK), (1, 0))\n b_q = tl.load(p_q, boundary_check=(0, 1))\n b_do = tl.load(p_do, boundary_check=(0, 1)).to(b_q.dtype)\n b_q = (b_q * scale).to(b_q.dtype)\n b_dq = tl.zeros([BTL, BK], dtype=tl.float32)\n p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, DK),\n (s_qk_t, s_qk_d), (0, i_k * BK), (BTS, BK), (1, 0))\n p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (DV, T),\n (s_vo_d, s_vo_t), (i_v * BV, 0), (BV, BTS), (0, 1))\n p_dz = dz + i_bh * T + i_c * BTL + tl.arange(0, BTL)\n b_dz = tl.load(p_dz, mask=(i_c * BTL + tl.arange(0, BTL)) < T)\n for _ in range(0, i_c * BTL, BTS):\n b_k = tl.load(p_k, boundary_check=(0, 1))\n b_v = tl.load(p_v, boundary_check=(0, 1))\n b_ds = tl.dot(b_do, b_v, allow_tf32=False)\n if i_v == 0:\n b_ds += b_dz[:, None]\n else:\n b_ds = b_ds\n b_s = tl.dot(b_q, tl.trans(b_k), allow_tf32=False)\n b_dq += tl.dot((2 * b_ds * b_s).to(b_v.dtype), b_k, allow_tf32=False)\n p_k = tl.advance(p_k, (BTS, 0))\n p_v = tl.advance(p_v, (0, BTS))\n b_dq *= scale\n o_q = tl.arange(0, BTL)\n o_k = tl.arange(0, BTS)\n p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, DK),\n (s_qk_t, s_qk_d), (i_c * BTL, i_k * BK), (BTS, BK), (1, 0))\n p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (DV, T),\n (s_vo_d, s_vo_t), (i_v * BV, i_c * BTL), (BV, BTS), (0, 1))\n for _ in range(i_c * BTL, (i_c + 1) * BTL, BTS):\n b_k = tl.load(p_k, boundary_check=(0, 1))\n b_v = tl.load(p_v, boundary_check=(0, 1))\n m_s = o_q[:, None] >= o_k[None, :]\n b_ds = tl.dot(b_do, b_v, allow_tf32=False)\n if i_v == 0:\n b_ds += b_dz[:, None]\n else:\n b_ds = b_ds\n b_ds = tl.where(m_s, b_ds, 0) * scale\n b_s = tl.dot(b_q, tl.trans(b_k), allow_tf32=False)\n b_s = tl.where(m_s, b_s, 0)\n b_dq += tl.dot((2 * b_ds * b_s).to(b_k.dtype),\n b_k, allow_tf32=False)\n p_k = tl.advance(p_k, (BTS, 0))\n p_v = tl.advance(p_v, (0, BTS))\n o_k += BTS\n p_dq = tl.make_block_ptr(dq + (i_bh + B * H * i_v) * s_qk_h, (T, DK),\n (s_qk_t, s_qk_d), (i_c*BTL, i_k*BK), (BTL, BK), (1, 0))\n tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))\n return\n\n@triton.jit\ndef _parallel_rebased_bwd_dkv(\n i_bh, i_c, i_k, i_v, i_h, q, k, v, do, dz, dk, dv, s_qk_h, s_qk_t, s_qk_d, s_vo_h,\n s_vo_t, s_vo_d, B, H, T, scale, BTL: tl.constexpr, BTS: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr,\n DK: tl.constexpr, DV: tl.constexpr,\n):\n p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d),\n (i_c * BTL, i_k * BK), (BTL, BK), (1, 0))\n p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d),\n (i_c * BTL, i_v * BV), (BTL, BV), (1, 0))\n b_k, b_v = tl.load(p_k, boundary_check=(0, 1)), tl.load(\n p_v, boundary_check=(0, 1))\n b_dk, b_dv = tl.zeros([BTL, BK], dtype=tl.float32), tl.zeros(\n [BTL, BV], dtype=tl.float32)\n for i in range((tl.cdiv(T, BTS) * BTS)-BTS, (i_c + 1) * BTL - BTS, -BTS):\n p_q = tl.make_block_ptr(\n q + i_bh * s_qk_h, (DK, T), (s_qk_d, s_qk_t), (i_k * BK, i), (BK, BTS), (0, 1))\n p_do = tl.make_block_ptr(\n do + i_bh * s_vo_h, (DV, T), (s_vo_d, s_vo_t), (i_v * BV, i), (BV, BTS), (0, 1))\n p_dz = dz + i_bh * T + i + tl.arange(0, BTS)\n b_q = tl.load(p_q, boundary_check=(0, 1))\n b_do = tl.load(p_do, boundary_check=(0, 1)).to(b_q.dtype)\n b_dz = tl.load(p_dz, mask=(i + tl.arange(0, BTS)) < T)\n b_s = tl.dot(b_k.to(b_q.dtype), b_q, allow_tf32=False) * scale\n b_s2 = b_s * b_s\n b_dv += tl.dot(b_s2.to(b_q.dtype), tl.trans(b_do), allow_tf32=False)\n b_ds = tl.dot(b_v, b_do, allow_tf32=False) * scale\n if i_v == 0:\n b_ds += b_dz[None, :] * scale\n else:\n b_ds = b_ds\n b_dk += tl.dot((2 * b_ds * b_s).to(b_q.dtype),\n tl.trans(b_q), allow_tf32=False)\n tl.debug_barrier()\n o_q, o_k = tl.arange(0, BTS), tl.arange(0, BTL)\n for i in range(i_c*BTL, (i_c+1)*BTL, BTS):\n p_q = tl.make_block_ptr(\n q + i_bh * s_qk_h, (DK, T), (s_qk_d, s_qk_t), (i_k * BK, i), (BK, BTS), (0, 1))\n p_do = tl.make_block_ptr(\n do + i_bh * s_vo_h, (DV, T), (s_vo_d, s_vo_t), (i_v * BV, i), (BV, BTS), (0, 1))\n p_dz = dz + i_bh * T + i + tl.arange(0, BTS)\n b_q = tl.load(p_q, boundary_check=(0, 1))\n b_do = tl.load(p_do, boundary_check=(0, 1)).to(b_q.dtype)\n b_dz = tl.load(p_dz, mask=(i + tl.arange(0, BTS)) < T)\n m_s = o_k[:, None] <= o_q[None, :]\n b_s = tl.dot(b_k, b_q, allow_tf32=False) * scale\n b_s2 = b_s * b_s\n b_s = tl.where(m_s, b_s, 0)\n b_s2 = tl.where(m_s, b_s2, 0)\n b_ds = tl.dot(b_v, b_do, allow_tf32=False)\n if i_v == 0:\n b_ds += b_dz[None, :]\n else:\n b_ds = b_ds\n b_ds = tl.where(m_s, b_ds, 0) * scale\n b_dv += tl.dot(b_s2.to(b_q.dtype), tl.trans(b_do), allow_tf32=False)\n b_dk += tl.dot((2 * b_ds * b_s).to(b_q.dtype),\n tl.trans(b_q), allow_tf32=False)\n o_q += BTS\n p_dk = tl.make_block_ptr(dk + (i_bh + B * H * i_v) * s_qk_h,\n (T, DK), (s_qk_t, s_qk_d), (i_c*BTL, i_k*BK), (BTL, BK), (1, 0))\n p_dv = tl.make_block_ptr(dv + (i_bh + B * H * i_k) * s_vo_h,\n (T, DV), (s_vo_t, s_vo_d), (i_c*BTL, i_v*BV), (BTL, BV), (1, 0))\n tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))\n tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))\n return\n\n@triton.jit\ndef parallel_rebased_bwd_kernel(\n q, k, v, do, dz, dq, dk, dv, s_qk_h, s_qk_t, s_qk_d, s_vo_h,\n s_vo_t, s_vo_d, B, H, T, scale,\n BTL: tl.constexpr, BTS: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr,\n DK: tl.constexpr, DV: tl.constexpr,\n):\n i_kv, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n NV = tl.cdiv(DV, BV)\n i_k = i_kv // (NV)\n i_v = i_kv % (NV)\n i_h = i_bh % H\n _parallel_rebased_bwd_dq(\n i_bh, i_c, i_k, i_v, i_h,\n q, k, v, do, dz, dq, s_qk_h, s_qk_t, s_qk_d, s_vo_h,\n s_vo_t, s_vo_d, B, H, T, scale, BTL=BTL, BTS=BTS, BK=BK, BV=BV, DK=DK, DV=DV\n )\n tl.debug_barrier()\n _parallel_rebased_bwd_dkv(\n i_bh, i_c, i_k, i_v, i_h,\n q, k, v, do, dz, dk, dv, s_qk_h, s_qk_t, s_qk_d, s_vo_h,\n s_vo_t, s_vo_d, B, H, T, scale, BTL, BTS, BK, BV, DK, DV\n )\n\nclass ParallelBasedFunction(torch.autograd.Function):\n @staticmethod\n def forward(ctx, q, k, v, scale):\n BTL, BTS = 128, 32\n assert BTL % BTS == 0\n BK = min(128, triton.next_power_of_2(k.shape[-1]))\n BV = min(128, triton.next_power_of_2(v.shape[-1]))\n BK, BV = max(BK, 16), max(BV, 16)\n batch_size, n_heads, seq_len, d_head_qk = q.shape\n d_head_v = v.shape[-1]\n num_stages = 2\n num_warps = 4\n NK = triton.cdiv(d_head_qk, BK)\n NV = triton.cdiv(d_head_v, BV)\n grid = (NK * NV, triton.cdiv(seq_len, BTL), batch_size * n_heads)\n assert NK == 1, \"will encounter some synchronization issue if not.\"\n o = torch.empty(NK, batch_size, n_heads, seq_len,\n d_head_v, device=q.device)\n z = torch.empty(NK, batch_size, n_heads, seq_len,\n device=q.device)\n parallel_rebased_fwd_kernel[grid](\n q, k, v, o, z,\n q.stride(1), q.stride(2), q.stride(3),\n v.stride(1), v.stride(2), v.stride(3),\n batch_size, n_heads, seq_len, scale,\n BTL=BTL, BTS=BTS, BK=BK, BV=BV, DK=d_head_qk, DV=d_head_v,\n num_warps=num_warps,\n num_stages=num_stages\n )\n ctx.save_for_backward(q, k, v)\n ctx.scale = scale\n return o.sum(0).to(q.dtype), z.sum(0).to(q.dtype)\n\n @staticmethod\n def backward(ctx, do, dz):\n q, k, v = ctx.saved_tensors\n scale = ctx.scale\n BTL, BTS = 64, 32\n assert BTL % BTS == 0\n BK = min(128, triton.next_power_of_2(k.shape[-1]))\n BV = min(128, triton.next_power_of_2(v.shape[-1]))\n BK, BV = max(BK, 16), max(BV, 16)\n batch_size, n_heads, seq_len, d_head_qk = q.shape\n d_head_v = v.shape[-1]\n num_stages = 2\n num_warps = 4\n NK = triton.cdiv(d_head_qk, BK)\n NV = triton.cdiv(d_head_v, BV)\n grid = (NK * NV, triton.cdiv(seq_len, BTL), batch_size * n_heads)\n assert NK == 1, \"will encounter some synchronization issue if not\"\n dq = torch.empty(NV, batch_size, n_heads, seq_len,\n d_head_qk, dtype=q.dtype, device=q.device)\n dk = torch.empty(NV, batch_size, n_heads, seq_len,\n d_head_qk, dtype=q.dtype, device=q.device)\n dv = torch.empty(NK, batch_size, n_heads, seq_len,\n d_head_v, dtype=q.dtype, device=q.device)\n parallel_rebased_bwd_kernel[grid](\n q, k, v, do, dz, dq, dk, dv,\n q.stride(1), q.stride(2), q.stride(3),\n v.stride(1), v.stride(2), v.stride(3),\n batch_size, n_heads, seq_len, scale,\n BTL=BTL, BTS=BTS, BK=BK, BV=BV, DK=d_head_qk, DV=d_head_v,\n num_warps=num_warps,\n num_stages=num_stages\n )\n return dq.sum(0).to(q.dtype), dk.sum(0).to(k.dtype), dv.sum(0).to(v.dtype), None\n\ntriton_parallel_based = ParallelBasedFunction.apply\n\ndef parallel_rebased(q, k, v, eps=1e-5, use_scale=True, use_normalize=True, return_both=False):\n assert q.shape[-1] <= 128, \"only support feature dim up to 128\"\n if use_scale:\n scale = q.shape[-1] ** -0.5\n else:\n scale = 1\n o, z = triton_parallel_based(q, k, v, scale)\n if return_both:\n return o, z\n if use_normalize:\n o = o / (z[..., None] + eps)\n else:\n o = o\n return o.to(q.dtype)\n", - "description_1": "Use triton language to implement parallel forward and backward kernels for the rebased linear attention mechanism. The forward kernel function takes 20 arguments: q, k, v, o, z, s_qk_h, s_qk_t, s_qk_d, s_vo_h, s_vo_t, s_vo_d, B, H, T, scale, BTL, BTS, BK, BV, DK, DV, where q, k, v are the input tensors, o and z are the output tensors, the strides and dimensions are defined by s_qk_h, s_qk_t, s_qk_d, s_vo_h, s_vo_t, s_vo_d, B, H, T are batch size, number of heads, and sequence length respectively, and scale is a scaling factor. BTL, BTS, BK, BV, DK, DV are compile-time constants specifying block and dimension sizes. The backward kernel function uses similar arguments but includes tensors for the derivatives (do, dz, dq, dk, dv) and also involves two helper functions _parallel_rebased_bwd_dq and _parallel_rebased_bwd_dkv for calculating derivatives of q, k, and v separately. The triton_parallel_based function is then defined as the apply method of a torch.autograd.Function class, providing a functional interface for forward and backward passes.", - "description_2": "Use triton language to implement parallel computation for rebased linear attention mechanism's forward and backward passes, optimizing operations by specifying grid and block dimensions to enhance computational efficiency.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\nfrom packaging import version\n\n@triton.jit\ndef fused_chunk_retention_fwd_kernel(\n q, k, v, o, initial_state, final_state,\n s_qk_h, s_qk_t, s_qk_d, s_vo_h, s_vo_t, s_vo_d,\n B, H, T, scale,\n BT: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr,\n DK: tl.constexpr, DV: tl.constexpr,\n USE_INITIAL_STATE: tl.constexpr,\n STORE_FINAL_STATE: tl.constexpr,\n CHECK: tl.constexpr\n):\n i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n i_h = i_bh % H\n\n o_i = tl.arange(0, BT)\n b_b = tl.math.log2(1 - tl.math.pow(2, -5 - i_h * 1.0))\n\n d_b, d_o, d_h = tl.math.exp2(BT * b_b), tl.math.exp2((o_i + 1) * b_b), tl.math.exp2((BT - o_i - 1) * b_b)\n\n m_s = o_i[:, None] >= o_i[None, :]\n d_s = tl.where(m_s, tl.math.exp2((o_i[:, None] - o_i[None, :]) * b_b), 0)\n b_h = tl.zeros([BK, BV], dtype=tl.float32)\n\n p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (0, i_k * BK), (BT, BK), (1, 0))\n p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (DK, T), (s_qk_d, s_qk_t), (i_k * BK, 0), (BK, BT), (0, 1))\n p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (0, i_v * BV), (BT, BV), (1, 0))\n p_o = tl.make_block_ptr(o + (i_bh+i_k*B*H) * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (0, i_v * BV), (BT, BV), (1, 0))\n\n if USE_INITIAL_STATE:\n p_h = tl.make_block_ptr(initial_state + i_bh * DK * DV, (DK, DV), (DV, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))\n b_h = tl.load(p_h, boundary_check=(0, 1)).to(tl.float32)\n \n NT = tl.cdiv(T, BT)\n for i in range(0, NT):\n b_k = tl.load(p_k, boundary_check=(0, 1))\n b_v = tl.load(p_v, boundary_check=(0, 1))\n b_q = tl.load(p_q, boundary_check=(0, 1))\n b_q = (b_q * scale).to(b_k.dtype)\n\n b_s = tl.dot(b_q, b_k, allow_tf32=False) * d_s\n b_o = tl.dot(b_s.to(b_q.dtype), b_v, allow_tf32=False)\n if CHECK and i == 0:\n b_o += tl.dot(b_q, b_h.to(b_q.dtype), allow_tf32=False) * d_o[:, None]\n b_h = d_b * b_h + tl.dot(b_k, (b_v * d_h[:, None]).to(b_k.dtype), allow_tf32=False)\n else:\n b_o += tl.dot(b_q, b_h.to(b_q.dtype), allow_tf32=False) * d_o[:, None]\n if i == NT - 1 and (T % BT) != 0:\n d_b = tl.math.exp2((T % BT) * b_b)\n d_h = tl.math.exp2(((T % BT) - o_i - 1) * b_b)\n b_h = d_b * b_h + tl.dot(b_k, (b_v * d_h[:, None]).to(b_k.dtype), allow_tf32=False)\n tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))\n\n p_q = tl.advance(p_q, (BT, 0))\n p_k = tl.advance(p_k, (0, BT))\n p_v = tl.advance(p_v, (BT, 0))\n p_o = tl.advance(p_o, (BT, 0))\n\n if STORE_FINAL_STATE:\n p_final = tl.make_block_ptr(final_state + i_bh * DK * DV, (DK, DV), (DV, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))\n tl.store(p_final, b_h.to(p_final.dtype.element_ty), boundary_check=(0, 1))\n\n\n@triton.jit\ndef fused_chunk_retention_bwd_kernel(\n q, k, v, do, dq, dk, dv, initial_state,\n s_qk_h, s_qk_t, s_qk_d, s_vo_h, s_vo_t, s_vo_d,\n B, H, T, scale,\n BT: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr,\n DK: tl.constexpr, DV: tl.constexpr,\n USE_INITIAL_STATE: tl.constexpr,\n CHECK: tl.constexpr\n):\n i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n i_h = i_bh % H\n\n o_i = tl.arange(0, BT)\n b_b = tl.math.log2(1 - tl.math.pow(2, -5 - i_h * 1.0))\n d_q, d_k = tl.math.exp2((o_i+1) * b_b) * scale, tl.math.exp2((BT - o_i - 1) * b_b)\n d_b = tl.math.exp2(BT * b_b)\n\n m_s = o_i[:, None] >= o_i[None, :]\n d_s = tl.where(m_s, tl.math.exp2((o_i[:, None] - o_i[None, :]) * b_b), 0) * scale\n b_h = tl.zeros([BV, BK], dtype=tl.float32)\n if USE_INITIAL_STATE:\n p_h = tl.make_block_ptr(initial_state + i_bh * DK * DV, (DV, DK), (1, DV), (i_v * BV, i_k * BK), (BV, BK), (0, 1))\n b_h = tl.load(p_h, boundary_check=(0, 1)).to(tl.float32)\n\n for i in range(0, tl.cdiv(T, BT)):\n p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (i * BT, i_k * BK), (BT, BK), (1, 0))\n p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (DV, T), (s_vo_d, s_vo_t), (i_v * BV, i * BT), (BV, BT), (0, 1))\n p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (i * BT, i_v * BV), (BT, BV), (1, 0))\n p_dq = tl.make_block_ptr(dq + (i_bh + i_v*B*H) * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (i*BT, i_k*BK), (BT, BK), (1, 0))\n\n b_k = tl.load(p_k, boundary_check=(0, 1))\n b_v = tl.load(p_v, boundary_check=(0, 1))\n b_do = tl.load(p_do, boundary_check=(0, 1))\n b_dd = (b_do * d_q[:, None]).to(b_do.dtype)\n\n b_ds = tl.dot(b_do, b_v, allow_tf32=False)\n b_ds = (b_ds * d_s).to(b_k.dtype)\n b_dq = tl.dot(b_ds, b_k, allow_tf32=False)\n if CHECK and i == 0:\n b_dq += tl.dot(b_dd, b_h.to(b_k.dtype), allow_tf32=False)\n b_h = d_b * b_h + tl.dot((b_v * d_k[None, :]).to(b_k.dtype), b_k, allow_tf32=False)\n else:\n b_dq += tl.dot(b_dd, b_h.to(b_k.dtype), allow_tf32=False)\n b_h = d_b * b_h + tl.dot((b_v * d_k[None, :]).to(b_k.dtype), b_k, allow_tf32=False)\n\n tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))\n\n b_h = None\n tl.debug_barrier()\n d_s = tl.trans(d_s)\n b_dh = tl.zeros([BK, BV], dtype=tl.float32)\n for i in range(1, tl.cdiv(T, BT) + 1):\n p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (DK, T), (s_qk_d, s_qk_t), (i_k * BK, T - i * BT), (BK, BT), (0, 1))\n p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (T - i * BT, i_k * BK), (BT, BK), (1, 0))\n p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (T - i * BT, i_v * BV), (BT, BV), (1, 0))\n p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (T - i * BT, i_v * BV), (BT, BV), (1, 0))\n p_dk = tl.make_block_ptr(dk + (i_bh+i_v*B*H) * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (T - i*BT, i_k*BK), (BT, BK), (1, 0))\n p_dv = tl.make_block_ptr(dv + (i_bh+i_k*B*H) * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (T - i*BT, i_v*BV), (BT, BV), (1, 0))\n\n b_q = tl.load(p_q, boundary_check=(0, 1))\n b_k = tl.load(p_k, boundary_check=(0, 1))\n b_v = tl.load(p_v, boundary_check=(0, 1))\n b_do = tl.load(p_do, boundary_check=(0, 1))\n b_dd = (b_do * d_q[:, None]).to(b_do.dtype)\n\n b_ds = tl.dot(b_v, tl.trans(b_do), allow_tf32=False)\n b_ds = (b_ds * d_s).to(b_k.dtype)\n\n b_s = tl.dot(b_k, b_q, allow_tf32=False) * d_s\n b_dk = tl.dot(b_ds, tl.trans(b_q), allow_tf32=False)\n b_dv = tl.dot(b_s.to(b_q.dtype), b_do, allow_tf32=False)\n if CHECK and i == 1:\n b_dk += tl.dot(b_v, tl.trans(b_dh).to(b_v.dtype), allow_tf32=False) * d_k[:, None]\n b_dv += tl.dot(b_k, b_dh.to(b_k.dtype), allow_tf32=False) * d_k[:, None]\n b_dh = d_b * b_dh + tl.dot(b_q, b_dd, allow_tf32=False)\n else:\n b_dk += tl.dot(b_v, tl.trans(b_dh).to(b_v.dtype), allow_tf32=False) * d_k[:, None]\n b_dv += tl.dot(b_k, b_dh.to(b_k.dtype), allow_tf32=False) * d_k[:, None]\n b_dh = d_b * b_dh + tl.dot(b_q, b_dd, allow_tf32=False)\n\n tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))\n tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))\n\n\nclass FusedChunkRetentionFunction(torch.autograd.Function):\n\n @staticmethod\n def forward(ctx, q, k, v, initial_state, output_final_state):\n batch_size, n_heads, seq_len, d_head_qk = q.shape\n d_head_v = v.shape[-1]\n\n scale = d_head_qk ** -0.5\n BT = 64\n BK, BV = min(triton.next_power_of_2(d_head_qk), 64), min(triton.next_power_of_2(d_head_v), 64)\n NK, NV = triton.cdiv(d_head_qk, BK), triton.cdiv(d_head_v, BV)\n num_stages = 1\n num_warps = 4\n\n o = q.new_empty(NK, batch_size, n_heads, seq_len, d_head_v)\n\n if output_final_state:\n final_state = q.new_empty(batch_size, n_heads, d_head_qk, d_head_v, dtype=torch.float32, requires_grad=False)\n else:\n final_state = None\n\n CHECK = True\n if version.parse(triton.__version__) < version.parse('2.2.0'):\n import warnings\n warnings.warn(\n \"Triton<2.2.0 detected for running this kernel, \"\n \"which is known to have some weird compiler issues (refer to https://github.com/openai/triton/issues/2852) \"\n \"that lead to significant precision loss. \"\n \"We've add some initial condition checks to resolve this, sadly at the sacrifice of the speed. \"\n \"For optimal performance, it is recommended to install Triton>=2.2.0 (if possible).\"\n )\n CHECK = True\n\n grid = (NV, NK, batch_size * n_heads)\n fused_chunk_retention_fwd_kernel[grid](\n q, k, v, o, initial_state, final_state,\n q.stride(1), q.stride(2), q.stride(3),\n v.stride(1), v.stride(2), v.stride(3),\n batch_size, n_heads, seq_len, scale,\n BT=BT, DK=d_head_qk, DV=d_head_v, BK=BK, BV=BV,\n USE_INITIAL_STATE=initial_state is not None,\n STORE_FINAL_STATE=output_final_state,\n CHECK=CHECK,\n num_warps=num_warps,\n num_stages=num_stages\n )\n\n o = o.sum(0)\n ctx.save_for_backward(q, k, v, initial_state)\n ctx.CHECK = CHECK\n return o.to(q.dtype), final_state\n\n @staticmethod\n def backward(ctx, do, d_final_state=None):\n q, k, v, initial_state = ctx.saved_tensors\n batch_size, n_heads, seq_len, d_head_qk = q.shape\n d_head_v = v.shape[-1]\n scale = d_head_qk ** -0.5\n\n BT = 64\n BK, BV = min(triton.next_power_of_2(d_head_qk), 64), min(triton.next_power_of_2(d_head_v), 64)\n NK, NV = triton.cdiv(d_head_qk, BK), triton.cdiv(d_head_v, BV)\n num_stages = 1\n num_warps = 4\n\n dq = q.new_empty(NV, batch_size, n_heads, seq_len, d_head_qk)\n dk = q.new_empty(NV, batch_size, n_heads, seq_len, d_head_qk)\n dv = q.new_empty(NK, batch_size, n_heads, seq_len, d_head_v)\n grid = (NV, NK, batch_size * n_heads)\n\n fused_chunk_retention_bwd_kernel[grid](\n q, k, v, do, dq, dk, dv, initial_state,\n q.stride(1), q.stride(2), q.stride(3),\n v.stride(1), v.stride(2), v.stride(3),\n batch_size, n_heads, seq_len, scale,\n BT=BT, DK=d_head_qk, DV=d_head_v, BK=BK, BV=BV,\n USE_INITIAL_STATE=initial_state is not None,\n CHECK=ctx.CHECK,\n num_warps=num_warps,\n num_stages=num_stages\n )\n dq = dq.sum(0)\n dk = dk.sum(0)\n dv = dv.sum(0)\n return dq.to(q.dtype), dk.to(k.dtype), dv.to(v.dtype), None, None\n\n\ndef fused_chunk_retention(\n q: torch.Tensor,\n k: torch.Tensor,\n v: torch.Tensor,\n initial_state: torch.Tensor = None,\n output_final_state: bool = False\n) -> Tuple[torch.Tensor, torch.Tensor]:\n if initial_state is not None:\n initial_state = initial_state.detach()\n o, final_state = FusedChunkRetentionFunction.apply(q, k, v, initial_state, output_final_state)\n return o, final_state\n", - "description_1": "Use triton language to implement two kernels: a forward kernel `fused_chunk_retention_fwd_kernel` and a backward kernel `fused_chunk_retention_bwd_kernel`. The forward kernel computes the result of a block retention mechanism used in neural networks, taking 19 parameters including batch_size, n_heads, seq_len, and others for strides, dimensions, and control flags. It processes these parameters with constant expressions and logs using Triton operations like `make_block_ptr`, `load`, `store`, and `math` functions. The backward kernel takes 23 parameters, including additional parameters for gradients, and similarly processes using Triton operations. These kernels are wrapped in an autograd function `FusedChunkRetentionFunction` which manages forward and backward computations in the PyTorch framework. The kernel computation grids are defined by the parameter grid, based on dimensions of input tensors and the number of warps and stages.", - "description_2": "Use triton language to create forward and backward kernels for fused chunk retention in neural networks with specified grid, block sizes, and strides, embedded in a PyTorch autograd function.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\nfrom torch.cuda.amp import custom_bwd, custom_fwd\n\n@triton.jit\ndef parallel_retention_fwd_kernel(\n q, k, v, o,\n s_qk_h, s_qk_t, s_qk_d,\n s_vo_h, s_vo_t, s_vo_d,\n B, H, T, scale,\n BTL: tl.constexpr, BTS: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr,\n DK: tl.constexpr, DV: tl.constexpr,\n):\n i_kv, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n NV = tl.cdiv(DV, BV)\n i_k = i_kv // (NV)\n i_v = i_kv % (NV)\n i_h = i_bh % H\n b_b = tl.math.log2(1 - tl.math.pow(2, -5 - i_h * 1.0))\n o_k = tl.arange(0, BTS)\n d_h = tl.math.exp2((BTS - o_k) * b_b)\n\n p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (T, DK),\n (s_qk_t, s_qk_d), (i_c * BTL, i_k * BK), (BTL, BK), (1, 0))\n p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (DK, T),\n (s_qk_d, s_qk_t), (i_k * BK, 0), (BK, BTS), (0, 1))\n p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, DV),\n (s_vo_t, s_vo_d), (0, i_v * BV), (BTS, BV), (1, 0))\n\n b_q = tl.load(p_q, boundary_check=(0, 1))\n b_q = (b_q * scale).to(b_q.dtype)\n b_o = tl.zeros([BTL, BV], dtype=tl.float32)\n\n for _ in range(0, i_c * BTL, BTS):\n b_k = tl.load(p_k, boundary_check=(0, 1))\n b_v = tl.load(p_v, boundary_check=(0, 1))\n b_s = tl.dot(b_q, (b_k), allow_tf32=False) * d_h[None, :]\n b_o = b_o * tl.math.exp2(b_b * BTS)\n b_o = b_o + tl.dot(b_s.to(b_v.dtype), b_v, allow_tf32=False)\n p_k = tl.advance(p_k, (0, BTS))\n p_v = tl.advance(p_v, (BTS, 0))\n\n tl.debug_barrier()\n o_q = tl.arange(0, BTL)\n d_q = tl.math.exp2(tl.arange(0, BTL) * b_b)\n b_o *= d_q[:, None]\n\n o_k = tl.arange(0, BTS)\n p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (DK, T),\n (s_qk_d, s_qk_t), (i_k * BK, i_c * BTL), (BK, BTS), (0, 1))\n p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, DV),\n (s_vo_t, s_vo_d), (i_c * BTL, i_v * BV), (BTS, BV), (1, 0))\n for _ in range(i_c * BTL, (i_c + 1) * BTL, BTS):\n b_k = tl.load(p_k, boundary_check=(0, 1))\n b_v = tl.load(p_v, boundary_check=(0, 1))\n m_s = o_q[:, None] >= o_k[None, :]\n d_s = tl.where(m_s, tl.math.exp2(\n (o_q[:, None] - o_k[None, :]) * b_b), 0)\n b_s = tl.dot(b_q, b_k, allow_tf32=False) * d_s\n b_o += tl.dot(b_s.to(b_q.dtype), b_v, allow_tf32=False)\n\n p_k = tl.advance(p_k, (0, BTS))\n p_v = tl.advance(p_v, (BTS, 0))\n o_k += BTS\n\n p_o = tl.make_block_ptr(o + (i_bh + B * H * i_k) * s_vo_h, (T, DV),\n (s_vo_t, s_vo_d), (i_c*BTL, i_v*BV), (BTL, BV), (1, 0))\n tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))\n\n@triton.jit\ndef _parallel_retention_bwd_dq(\n i_bh, i_c, i_k, i_v, i_h,\n k, v, do, dq, s_qk_h, s_qk_t, s_qk_d, s_vo_h,\n s_vo_t, s_vo_d, B, H, T, scale,\n BTL: tl.constexpr, BTS: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr,\n DK: tl.constexpr, DV: tl.constexpr,\n):\n p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d),\n (i_c * BTL, i_v * BV), (BTL, BV), (1, 0))\n b_do = tl.load(p_do, boundary_check=(0, 1))\n b_dq = tl.zeros([BTL, BK], dtype=tl.float32)\n p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, DK),\n (s_qk_t, s_qk_d), (0, i_k * BK), (BTS, BK), (1, 0))\n p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (DV, T),\n (s_vo_d, s_vo_t), (i_v * BV, 0), (BV, BTS), (0, 1))\n b_b = tl.math.log2(1 - tl.math.pow(2, -5 - i_h * 1.0))\n d_b = tl.math.exp2(b_b * BTS)\n d_h = tl.math.exp2((BTS - tl.arange(0, BTS)) * b_b)\n for _ in range(0, i_c * BTL, BTS):\n b_k = tl.load(p_k, boundary_check=(0, 1))\n b_v = tl.load(p_v, boundary_check=(0, 1))\n b_ds = tl.dot(b_do, b_v, allow_tf32=False) * d_h[None, :]\n b_dq *= d_b\n b_dq += tl.dot(b_ds.to(b_v.dtype), b_k, allow_tf32=False)\n p_k = tl.advance(p_k, (BTS, 0))\n p_v = tl.advance(p_v, (0, BTS))\n b_dq *= tl.math.exp2(tl.arange(0, BTL) * b_b)[:, None] * scale\n o_q = tl.arange(0, BTL)\n o_k = tl.arange(0, BTS)\n p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, DK),\n (s_qk_t, s_qk_d), (i_c * BTL, i_k * BK), (BTS, BK), (1, 0))\n p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (DV, T),\n (s_vo_d, s_vo_t), (i_v * BV, i_c * BTL), (BV, BTS), (0, 1))\n for _ in range(i_c * BTL, (i_c + 1) * BTL, BTS):\n b_k = tl.load(p_k, boundary_check=(0, 1))\n b_v = tl.load(p_v, boundary_check=(0, 1))\n m_s = o_q[:, None] >= o_k[None, :]\n d_s = tl.where(m_s, tl.math.exp2(\n (o_q[:, None] - o_k[None, :]) * b_b), 0)\n b_ds = tl.dot(b_do, b_v, allow_tf32=False) * d_s * scale\n b_dq += tl.dot(b_ds.to(b_k.dtype), b_k, allow_tf32=False)\n p_k = tl.advance(p_k, (BTS, 0))\n p_v = tl.advance(p_v, (0, BTS))\n o_k += BTS\n p_dq = tl.make_block_ptr(dq + (i_bh + B * H * i_v) * s_qk_h, (T, DK),\n (s_qk_t, s_qk_d), (i_c*BTL, i_k*BK), (BTL, BK), (1, 0))\n tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))\n return\n\n@triton.jit\ndef _parallel_retention_bwd_dkv(\n i_bh, i_c, i_k, i_v, i_h,\n q, k, v, do, dk, dv, s_qk_h, s_qk_t, s_qk_d, s_vo_h,\n s_vo_t, s_vo_d, B, H, T, scale,\n BTL: tl.constexpr, BTS: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr,\n DK: tl.constexpr, DV: tl.constexpr,\n):\n b_b = tl.math.log2(1 - tl.math.pow(2, -5 - i_h * 1.0))\n d_b = tl.math.exp2(b_b * BTS)\n p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d),\n (i_c * BTL, i_k * BK), (BTL, BK), (1, 0))\n p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d),\n (i_c * BTL, i_v * BV), (BTL, BV), (1, 0))\n b_k, b_v = tl.load(p_k, boundary_check=(0, 1)), tl.load(\n p_v, boundary_check=(0, 1))\n b_dk, b_dv = tl.zeros([BTL, BK], dtype=tl.float32), tl.zeros(\n [BTL, BV], dtype=tl.float32)\n d_h = tl.math.exp2((BTL - tl.arange(0, BTL)) * b_b)\n b_kd = (b_k * d_h[:, None]).to(b_k.dtype)\n d_q = tl.math.exp2(tl.arange(0, BTS) * b_b)\n for i in range((tl.cdiv(T, BTS) * BTS)-BTS, (i_c + 1) * BTL - BTS, -BTS):\n p_q = tl.make_block_ptr(\n q + i_bh * s_qk_h, (DK, T), (s_qk_d, s_qk_t), (i_k * BK, i), (BK, BTS), (0, 1))\n p_do = tl.make_block_ptr(\n do + i_bh * s_vo_h, (DV, T), (s_vo_d, s_vo_t), (i_v * BV, i), (BV, BTS), (0, 1))\n b_q = tl.load(p_q, boundary_check=(0, 1))\n b_do = tl.load(p_do, boundary_check=(0, 1))\n b_do = (b_do * d_q[None, :]).to(b_do.dtype)\n\n b_dv *= d_b\n b_s = tl.dot(b_kd.to(b_q.dtype), b_q, allow_tf32=False)\n b_dv += tl.dot(b_s.to(b_q.dtype), tl.trans(b_do), allow_tf32=False)\n\n b_dk *= d_b\n b_ds = tl.dot(b_v, b_do, allow_tf32=False)\n b_dk += tl.dot(b_ds.to(b_q.dtype), tl.trans(b_q), allow_tf32=False)\n b_dk *= d_h[:, None] * scale\n b_dv *= scale\n tl.debug_barrier()\n o_q, o_k = tl.arange(0, BTS), tl.arange(0, BTL)\n for i in range(i_c*BTL, (i_c+1)*BTL, BTS):\n p_q = tl.make_block_ptr(\n q + i_bh * s_qk_h, (DK, T), (s_qk_d, s_qk_t), (i_k * BK, i), (BK, BTS), (0, 1))\n p_do = tl.make_block_ptr(\n do + i_bh * s_vo_h, (DV, T), (s_vo_d, s_vo_t), (i_v * BV, i), (BV, BTS), (0, 1))\n b_q = tl.load(p_q, boundary_check=(0, 1))\n b_do = tl.load(p_do, boundary_check=(0, 1))\n m_s = o_k[:, None] <= o_q[None, :]\n d_s = tl.where(m_s, tl.math.exp2(\n (-o_k[:, None] + o_q[None, :]) * b_b.to(tl.float32)), 0) * scale\n b_s = tl.dot(b_k, b_q, allow_tf32=False) * d_s\n b_ds = tl.dot(b_v, b_do, allow_tf32=False) * d_s\n b_dk += tl.dot(b_ds.to(b_q.dtype), tl.trans(b_q), allow_tf32=False)\n b_dv += tl.dot(b_s.to(b_q.dtype), tl.trans(b_do), allow_tf32=False)\n o_q += BTS\n p_dk = tl.make_block_ptr(dk + (i_bh + B * H * i_v) * s_qk_h,\n (T, DK), (s_qk_t, s_qk_d), (i_c*BTL, i_k*BK), (BTL, BK), (1, 0))\n p_dv = tl.make_block_ptr(dv + (i_bh + B * H * i_k) * s_vo_h,\n (T, DV), (s_vo_t, s_vo_d), (i_c*BTL, i_v*BV), (BTL, BV), (1, 0))\n tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))\n tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))\n return\n\n@triton.jit\ndef parallel_retention_bwd_kernel(\n q, k, v, do, dq, dk, dv, s_qk_h, s_qk_t, s_qk_d, s_vo_h,\n s_vo_t, s_vo_d, B, H, T, scale,\n BTL: tl.constexpr, BTS: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr,\n DK: tl.constexpr, DV: tl.constexpr,\n):\n i_kv, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n NV = tl.cdiv(DV, BV)\n i_k = i_kv // (NV)\n i_v = i_kv % (NV)\n i_h = i_bh % H\n _parallel_retention_bwd_dq(\n i_bh, i_c, i_k, i_v, i_h,\n k, v, do, dq, s_qk_h, s_qk_t, s_qk_d, s_vo_h,\n s_vo_t, s_vo_d, B, H, T, scale, BTL=BTL, BTS=BTS, BK=BK, BV=BV, DK=DK, DV=DV\n )\n tl.debug_barrier()\n _parallel_retention_bwd_dkv(\n i_bh, i_c, i_k, i_v, i_h,\n q, k, v, do, dk, dv, s_qk_h, s_qk_t, s_qk_d, s_vo_h,\n s_vo_t, s_vo_d, B, H, T, scale, BTL, BTS, BK, BV, DK, DV\n )\n\nclass ParallelRetentionFunction(torch.autograd.Function):\n @staticmethod\n @custom_fwd\n def forward(ctx, q, k, v):\n BTL, BTS = 128, 32\n assert BTL % BTS == 0\n BK = min(128, triton.next_power_of_2(k.shape[-1]))\n BV = min(128, triton.next_power_of_2(v.shape[-1]))\n batch_size, n_heads, seq_len, d_head_qk = q.shape\n d_head_v = v.shape[-1]\n num_stages = 3 if d_head_qk <= 64 else 2\n num_warps = 4\n NK = triton.cdiv(d_head_qk, BK)\n NV = triton.cdiv(d_head_v, BV)\n\n grid = (NK * NV, triton.cdiv(seq_len, BTL), batch_size * n_heads)\n scale = d_head_qk ** -0.5\n o = torch.empty(NK, batch_size, n_heads, seq_len,\n d_head_v, dtype=q.dtype, device=q.device)\n parallel_retention_fwd_kernel[grid](\n q, k, v, o,\n q.stride(1), q.stride(2), q.stride(3),\n v.stride(1), v.stride(2), v.stride(3),\n batch_size, n_heads, seq_len, scale,\n BTL=BTL, BTS=BTS, BK=BK, BV=BV, DK=d_head_qk, DV=d_head_v,\n num_warps=num_warps,\n num_stages=num_stages\n )\n ctx.save_for_backward(q, k, v)\n return o.sum(0).to(q.dtype)\n\n @staticmethod\n @custom_bwd\n def backward(ctx, do):\n q, k, v = ctx.saved_tensors\n BTL, BTS = 64, 32\n assert BTL % BTS == 0\n BK = min(128, triton.next_power_of_2(k.shape[-1]))\n BV = min(128, triton.next_power_of_2(v.shape[-1]))\n batch_size, n_heads, seq_len, d_head_qk = q.shape\n d_head_v = v.shape[-1]\n num_stages = 3 if d_head_qk <= 64 else 2\n num_warps = 4\n NK = triton.cdiv(d_head_qk, BK)\n NV = triton.cdiv(d_head_v, BV)\n grid = (NK * NV, triton.cdiv(seq_len, BTL), batch_size * n_heads)\n scale = d_head_qk ** -0.5\n\n dq = torch.empty(NV, batch_size, n_heads, seq_len,\n d_head_qk, dtype=q.dtype, device=q.device)\n dk = torch.empty(NV, batch_size, n_heads, seq_len,\n d_head_qk, dtype=q.dtype, device=q.device)\n dv = torch.empty(NK, batch_size, n_heads, seq_len,\n d_head_v, dtype=q.dtype, device=q.device)\n\n parallel_retention_bwd_kernel[grid](\n q, k, v, do, dq, dk, dv,\n q.stride(1), q.stride(2), q.stride(3),\n v.stride(1), v.stride(2), v.stride(3),\n batch_size, n_heads, seq_len, scale,\n BTL=BTL, BTS=BTS, BK=BK, BV=BV, DK=d_head_qk, DV=d_head_v,\n num_warps=num_warps,\n num_stages=num_stages\n )\n\n return dq.sum(0).to(q.dtype), dk.sum(0).to(k.dtype), dv.sum(0).to(v.dtype)\n\nparallel_retention = ParallelRetentionFunction.apply\n", - "description_1": "Use triton language to implement parallel retention forward and backward kernels for batch processing of queries, keys, and values in a Transformer-like architecture. The kernels perform operations with 17-20 input arguments including matrices q, k, v, their strides, batch size, number of heads, sequence length, scaling factor, and constant block sizes. The forward kernel processes the input to produce output o, while the backward kernels compute gradients dq, dk, and dv.", - "description_2": "Use triton language to implement a parallel processing mechanism for queries, keys, and values in deep learning architectures with efficient gradient computation.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef rotary_kernel(\n OUT, X, COS, SIN, CU_SEQLENS, SEQLEN_OFFSETS, \n seqlen, nheads, rotary_dim, seqlen_ro, CACHE_KEY_SEQLEN, \n stride_out_batch, stride_out_seqlen, stride_out_nheads, stride_out_headdim,\n stride_x_batch, stride_x_seqlen, stride_x_nheads, stride_x_headdim,\n BLOCK_K: tl.constexpr, IS_SEQLEN_OFFSETS_TENSOR: tl.constexpr, \n IS_VARLEN: tl.constexpr, INTERLEAVED: tl.constexpr, \n CONJUGATE: tl.constexpr, BLOCK_M: tl.constexpr,\n):\n # Triton kernel code to compute rotary embedding\n pid_m = tl.program_id(axis=0)\n pid_batch = tl.program_id(axis=1)\n pid_head = tl.program_id(axis=2)\n rotary_dim_half = rotary_dim // 2\n\n if not IS_VARLEN:\n X = X + pid_batch * stride_x_batch + pid_head * stride_x_nheads\n OUT = OUT + pid_batch * stride_out_batch + pid_head * stride_out_nheads\n else:\n start_idx = tl.load(CU_SEQLENS + pid_batch)\n seqlen = tl.load(CU_SEQLENS + pid_batch + 1) - start_idx\n X = X + start_idx * stride_x_seqlen + pid_head * stride_x_nheads\n OUT = OUT + start_idx * stride_out_seqlen + pid_head * stride_out_nheads\n\n if pid_m * BLOCK_M >= seqlen:\n return\n rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)\n if not IS_SEQLEN_OFFSETS_TENSOR:\n rm_cs = rm + SEQLEN_OFFSETS\n else:\n rm_cs = rm + tl.load(SEQLEN_OFFSETS + pid_batch)\n rk = tl.arange(0, BLOCK_K)\n rk_half = tl.arange(0, BLOCK_K // 2)\n\n if not INTERLEAVED:\n X = X + (rm[:, None] * stride_x_seqlen +\n rk_half[None, :] * stride_x_headdim)\n COS = COS + (rm_cs[:, None] * rotary_dim_half + rk_half[None, :])\n SIN = SIN + (rm_cs[:, None] * rotary_dim_half + rk_half[None, :])\n cos = tl.load(\n COS, mask=(rm_cs[:, None] < seqlen_ro) & (rk_half[None, :] < rotary_dim_half), other=1.0\n ).to(tl.float32)\n sin = tl.load(\n SIN, mask=(rm_cs[:, None] < seqlen_ro) & (rk_half[None, :] < rotary_dim_half), other=0.0\n ).to(tl.float32)\n x0 = tl.load(\n X, mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half), other=0.0\n ).to(tl.float32)\n x1 = tl.load(\n X + rotary_dim_half * stride_x_headdim,\n mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half),\n other=0.0,\n ).to(tl.float32)\n if CONJUGATE:\n sin = -sin\n o0 = x0 * cos - x1 * sin\n o1 = x0 * sin + x1 * cos\n OUT = OUT + (rm[:, None] * stride_out_seqlen +\n rk_half[None, :] * stride_out_headdim)\n tl.store(OUT, o0, mask=(rm[:, None] < seqlen)\n & (rk_half[None, :] < rotary_dim_half))\n tl.store(\n OUT + rotary_dim_half * stride_out_headdim,\n o1,\n mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half),\n )\n else:\n rk_swap = rk + ((rk + 1) % 2) * 2 - 1\n rk_repeat = tl.arange(0, BLOCK_K) // 2\n X0 = X + (rm[:, None] * stride_x_seqlen +\n rk[None, :] * stride_x_headdim)\n X1 = X + (rm[:, None] * stride_x_seqlen +\n rk_swap[None, :] * stride_x_headdim)\n COS = COS + (rm_cs[:, None] * rotary_dim_half + rk_repeat[None, :])\n SIN = SIN + (rm_cs[:, None] * rotary_dim_half + rk_repeat[None, :])\n cos = tl.load(\n COS,\n mask=(rm_cs[:, None] < seqlen_ro) & (\n rk_repeat[None, :] < rotary_dim_half),\n other=1.0,\n ).to(tl.float32)\n sin = tl.load(\n SIN,\n mask=(rm_cs[:, None] < seqlen_ro) & (\n rk_repeat[None, :] < rotary_dim_half),\n other=0.0,\n ).to(tl.float32)\n x0 = tl.load(X0, mask=(rm[:, None] < seqlen) & (rk[None, :] < rotary_dim), other=0.0).to(\n tl.float32\n )\n x1 = tl.load(\n X1, mask=(rm[:, None] < seqlen) & (rk_swap[None, :] < rotary_dim), other=0.0\n ).to(tl.float32)\n if CONJUGATE:\n sin = -sin\n x0_cos = x0 * cos\n x1_sin = x1 * sin\n out = tl.where(rk[None, :] % 2 == 0, x0_cos - x1_sin, x0_cos + x1_sin)\n OUT = OUT + (rm[:, None] * stride_out_seqlen +\n rk[None, :] * stride_out_headdim)\n tl.store(OUT, out, mask=(rm[:, None] < seqlen)\n & (rk[None, :] < rotary_dim))\n\n\ndef apply_rotary(\n x: torch.Tensor,\n cos: torch.Tensor,\n sin: torch.Tensor,\n seqlen_offsets: Union[int, torch.Tensor] = 0,\n cu_seqlens: Optional[torch.Tensor] = None,\n max_seqlen: Optional[int] = None,\n interleaved=False,\n inplace=False,\n conjugate=False,\n) -> torch.Tensor:\n \"\"\"\n Apply rotary embeddings to the input tensor `x`.\n \"\"\"\n is_varlen = cu_seqlens is not None\n if not is_varlen:\n batch, seqlen, nheads, headdim = x.shape\n else:\n assert max_seqlen is not None, \"If cu_seqlens is passed in, then max_seqlen must be passed\"\n total_seqlen, nheads, headdim = x.shape\n batch_p_1 = cu_seqlens.shape[0]\n batch = batch_p_1 - 1\n seqlen = max_seqlen\n seqlen_ro, rotary_dim = cos.shape\n assert sin.shape == cos.shape\n rotary_dim *= 2\n assert rotary_dim <= headdim, \"rotary_dim must be <= headdim\"\n assert headdim <= 256, \"Only support headdim <= 256\"\n assert seqlen_ro >= seqlen, \"seqlen_ro must be >= seqlen\"\n\n assert (\n cos.dtype == sin.dtype\n ), f\"cos and sin must have the same dtype, got {cos.dtype} and {sin.dtype}\"\n assert (\n x.dtype == cos.dtype\n ), f\"Input and cos/sin must have the same dtype, got {x.dtype} and {cos.dtype}\"\n\n cos, sin = cos.contiguous(), sin.contiguous()\n if isinstance(seqlen_offsets, torch.Tensor):\n assert seqlen_offsets.shape == (batch,)\n assert seqlen_offsets.dtype in [torch.int32, torch.int64]\n seqlen_offsets = seqlen_offsets.contiguous()\n else:\n assert seqlen_offsets + seqlen <= seqlen_ro\n\n output = torch.empty_like(x) if not inplace else x\n if rotary_dim < headdim and not inplace:\n output[..., rotary_dim:].copy_(x[..., rotary_dim:])\n\n BLOCK_K = (\n 32\n if rotary_dim <= 32\n else (64 if rotary_dim <= 64 else (128 if rotary_dim <= 128 else 256))\n )\n def grid(META): return (triton.cdiv(seqlen, META[\"BLOCK_M\"]), batch, nheads) # noqa\n BLOCK_M = 4 if interleaved else (8 if rotary_dim <= 64 else 4)\n\n with torch.cuda.device(x.device.index):\n rotary_kernel[grid](\n output, # data ptrs\n x,\n cos,\n sin,\n cu_seqlens,\n seqlen_offsets,\n seqlen, # shapes\n nheads,\n rotary_dim,\n seqlen_ro,\n seqlen // 128,\n output.stride(0) if not is_varlen else 0,\n output.stride(-3), # seqlen_stride or total_seqlen_stride\n output.stride(-2), # nheads_stride\n output.stride(-1), # headdim_stride\n x.stride(0) if not is_varlen else 0,\n x.stride(-3), # seqlen stride or total_seqlen_stride\n x.stride(-2), # nheads stride\n x.stride(-1), # headdim stride\n BLOCK_K,\n isinstance(seqlen_offsets, torch.Tensor),\n is_varlen,\n interleaved,\n conjugate,\n BLOCK_M,\n )\n return output\n", - "description_1": "Use triton language to implement a rotary embedding kernel (`rotary_kernel`) with 28 parameters, processing matrices and handling sequence lengths, and a calling function (`apply_rotary`) with 9 parameters to manage tensor data, embedding parameters, and configuration flags to apply the kernel.", - "description_2": "Use triton language to create a rotary embedding and apply it efficiently using a kernel for GPU computation with configurable parameters.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\nfrom torch import Tensor\n\n@triton.jit\ndef fused_recurrent_rwkv4_forward_kernel(\n w_ptr, w_s_c, u_ptr, u_s_c, k_ptr, k_s_b, k_s_t, k_s_c, v_ptr, v_s_b, v_s_t, v_s_c,\n state_ptr, state_s_b, state_s_abe, state_s_c, wkv_ptr, wkv_s_b, wkv_s_t, wkv_s_c,\n state_out_ptr, state_out_s_b, state_out_s_abe, state_out_s_t, state_out_s_c,\n chans, tsz, BLOCK_SIZE_C: tl.constexpr,\n):\n b_idx = tl.program_id(0)\n c_idx = tl.program_id(1)\n cs = (c_idx * BLOCK_SIZE_C) + tl.arange(0, BLOCK_SIZE_C)\n cmask = cs < chans\n k_ptr = k_ptr + b_idx * k_s_b\n v_ptr = v_ptr + b_idx * v_s_b\n alpha_ptr = state_ptr + b_idx * state_s_b\n beta_ptr = state_ptr + b_idx * state_s_b + state_s_abe\n eps_ptr = state_ptr + b_idx * state_s_b + 2 * state_s_abe\n wkv_ptr = wkv_ptr + b_idx * wkv_s_b\n alpha_out_ptr = state_out_ptr + b_idx * state_out_s_b\n beta_out_ptr = state_out_ptr + b_idx * state_out_s_b + state_out_s_abe\n eps_out_ptr = state_out_ptr + b_idx * state_out_s_b + 2 * state_out_s_abe\n alpha = tl.load(alpha_ptr + cs * state_s_c, mask=cmask).to(tl.float32)\n beta = tl.load(beta_ptr + cs * state_s_c, mask=cmask).to(tl.float32)\n eps = tl.load(eps_ptr + cs * state_s_c, mask=cmask).to(tl.float32)\n w = tl.load(w_ptr + cs * w_s_c, mask=cmask).to(tl.float32)\n u = tl.load(u_ptr + cs * u_s_c, mask=cmask).to(tl.float32)\n\n for t in range(tsz):\n kt = tl.load(k_ptr + t * k_s_t + cs * k_s_c, mask=cmask).to(tl.float32)\n vt = tl.load(v_ptr + t * v_s_t + cs * v_s_c, mask=cmask).to(tl.float32)\n ukt = u + kt\n tau = tl.maximum(ukt, eps)\n e1a = tl.exp(eps - tau)\n e2a = tl.exp(ukt - tau)\n wkv = (e1a * alpha + e2a * vt) / (e1a * beta + e2a)\n tl.store(wkv_ptr + t * wkv_s_t + cs * wkv_s_c, wkv, mask=cmask)\n w_eps = w + eps\n eps = tl.maximum(w_eps, kt)\n e1b = tl.exp(w_eps - eps)\n e2b = tl.exp(kt - eps)\n alpha = e1b * alpha + e2b * vt\n beta = e1b * beta + e2b\n tl.store(alpha_out_ptr + t * state_out_s_t + cs * state_out_s_c, alpha, mask=cmask)\n tl.store(beta_out_ptr + t * state_out_s_t + cs * state_out_s_c, beta, mask=cmask)\n tl.store(eps_out_ptr + t * state_out_s_t + cs * state_out_s_c, eps, mask=cmask)\n\ndef fused_recurrent_rwkv4_forward(w: Tensor, u: Tensor, k: Tensor, v: Tensor, state: Tensor) -> tuple[Tensor, Tensor]:\n (bsz, tsz, chans) = k.shape\n wkvs = k.new_empty(bsz, tsz, chans)\n state_out = k.new_empty(bsz, 3, tsz, chans)\n block_size_c = get_block_size_c(chans)\n\n def grid(meta: dict[str, Any]) -> tuple[int, ...]:\n return (bsz, triton.cdiv(chans, meta[\"BLOCK_SIZE_C\"]))\n\n fused_recurrent_rwkv4_forward_kernel[grid](\n w, w.stride(0), u, u.stride(0), k, k.stride(0), k.stride(1), k.stride(2),\n v, v.stride(0), v.stride(1), v.stride(2), state, state.stride(0), state.stride(1), state.stride(3),\n wkvs, wkvs.stride(0), wkvs.stride(1), wkvs.stride(2), state_out, state_out.stride(0),\n state_out.stride(1), state_out.stride(2), state_out.stride(3), chans, tsz, BLOCK_SIZE_C=block_size_c,\n )\n\n state_out = torch.cat((state, state_out), dim=2)\n return wkvs, state_out\n\n@triton.jit\ndef fused_recurrent_rwkv4_backward_kernel(\n w_ptr, w_s_c, u_ptr, u_s_c, k_ptr, k_s_b, k_s_t, k_s_c, v_ptr, v_s_b, v_s_t, v_s_c,\n state_ptr, state_s_b, state_s_abe, state_s_t, state_s_c, gwkv_ptr, gwkv_s_b, gwkv_s_t, gwkv_s_c,\n gstate_out_ptr, gstate_out_s_b, gstate_out_s_abe, gstate_out_s_c, gw_ptr, gw_s_c, gu_ptr, gu_s_c,\n gk_ptr, gk_s_b, gk_s_t, gk_s_c, gv_ptr, gv_s_b, gv_s_t, gv_s_c, gstate_ptr, gstate_s_b, gstate_s_abe, gstate_s_c,\n tsz, chans, BLOCK_SIZE_C: tl.constexpr,\n):\n b_idx = tl.program_id(0)\n c_idx = tl.program_id(1)\n cs = (c_idx * BLOCK_SIZE_C) + tl.arange(0, BLOCK_SIZE_C)\n cmask = cs < chans\n k_ptr = k_ptr + b_idx * k_s_b\n v_ptr = v_ptr + b_idx * v_s_b\n alpha_ptr = state_ptr + b_idx * state_s_b\n beta_ptr = state_ptr + b_idx * state_s_b + state_s_abe\n eps_ptr = state_ptr + b_idx * state_s_b + 2 * state_s_abe\n gk_ptr = gk_ptr + b_idx * gk_s_b\n gv_ptr = gv_ptr + b_idx * gv_s_b\n gwkv_ptr = gwkv_ptr + b_idx * gwkv_s_b\n galpha_out_ptr = gstate_out_ptr + b_idx * gstate_out_s_b\n gbeta_out_ptr = gstate_out_ptr + b_idx * gstate_out_s_b + gstate_out_s_abe\n geps_out_ptr = gstate_out_ptr + b_idx * gstate_out_s_b + 2 * gstate_out_s_abe\n galpha = tl.load(galpha_out_ptr + gstate_out_s_c * cs, mask=cmask).to(tl.float32)\n gbeta = tl.load(gbeta_out_ptr + gstate_out_s_c * cs, mask=cmask).to(tl.float32)\n geps = tl.load(geps_out_ptr + gstate_out_s_c * cs, mask=cmask).to(tl.float32)\n w = tl.load(w_ptr + w_s_c * cs, mask=cmask).to(tl.float32)\n u = tl.load(u_ptr + u_s_c * cs, mask=cmask).to(tl.float32)\n gw = tl.zeros_like(w)\n gu = tl.zeros_like(u)\n alpha_prev = tl.load(alpha_ptr + tsz * state_s_t + state_s_c * cs, mask=cmask).to(tl.float32)\n beta_prev = tl.load(beta_ptr + tsz * state_s_t + state_s_c * cs, mask=cmask).to(tl.float32)\n eps_prev = tl.load(eps_ptr + tsz * state_s_t + state_s_c * cs, mask=cmask).to(tl.float32)\n\n for t in range(tsz):\n tc = tsz - t - 1\n kt = tl.load(k_ptr + tc * k_s_t + k_s_c * cs, mask=cmask).to(tl.float32)\n vt = tl.load(v_ptr + tc * v_s_t + v_s_c * cs, mask=cmask).to(tl.float32)\n alpha_curr = alpha_prev\n beta_curr = beta_prev\n eps_curr = eps_prev\n alpha_prev = tl.load(alpha_ptr + tc * state_s_t + state_s_c * cs, mask=cmask).to(tl.float32)\n beta_prev = tl.load(beta_ptr + tc * state_s_t + state_s_c * cs, mask=cmask).to(tl.float32)\n eps_prev = tl.load(eps_ptr + tc * state_s_t + state_s_c * cs, mask=cmask).to(tl.float32)\n ukt = u + kt\n tau = tl.maximum(ukt, eps_prev)\n e1 = tl.exp(eps_prev - tau)\n e2 = tl.exp(ukt - tau)\n euke = tl.exp(ukt + eps_prev - 2 * tau)\n denom = e1 * beta_prev + e2\n denom_sq = denom * denom\n gwkvt = tl.load(gwkv_ptr + tc * gwkv_s_t + gwkv_s_c * cs, mask=cmask).to(tl.float32)\n guk = gwkvt * e2 * (e1 * beta_prev * vt - e1 * alpha_prev) / denom_sq\n gu += guk\n gk = guk\n gv = gwkvt * e2 / denom\n galpha_wkv = gwkvt * e1 / denom\n gbeta_wkv = -gwkvt * e1 * (e2 * vt + e1 * alpha_prev) / denom_sq\n geps_wkv_denom = e1 * beta_prev + e2\n geps_wkv = gwkvt * euke * (alpha_prev - vt * beta_prev) / (geps_wkv_denom * geps_wkv_denom)\n e1 = tl.exp(w + eps_prev - eps_curr)\n e2 = tl.exp(kt - eps_curr)\n galpha_we = galpha * e1 * alpha_prev\n gw += galpha_we\n gk += galpha * e2 * vt\n gv += galpha * e2\n geps += galpha * -alpha_curr\n gbeta_we = gbeta * e1 * beta_prev\n gw += gbeta_we\n gk += gbeta * e2\n geps += gbeta * -beta_curr\n geps_mask = w + eps_prev > kt\n geps_we = tl.where(geps_mask, geps, tl.zeros_like(geps))\n gw += geps_we\n gk += tl.where(geps_mask, tl.zeros_like(geps), geps)\n tl.store(gk_ptr + tc * gk_s_t + gk_s_c * cs, gk, mask=cmask)\n tl.store(gv_ptr + tc * gv_s_t + gv_s_c * cs, gv, mask=cmask)\n galpha = galpha * e1 + galpha_wkv\n gbeta = gbeta * e1 + gbeta_wkv\n geps = galpha_we + gbeta_we + geps_we + geps_wkv\n\n galpha_ptr = gstate_ptr + b_idx * gstate_s_b\n gbeta_ptr = gstate_ptr + b_idx * gstate_s_b + gstate_s_abe\n geps_ptr = gstate_ptr + b_idx * gstate_s_b + 2 * gstate_s_abe\n tl.store(galpha_ptr + gstate_s_c * cs, galpha, mask=cmask)\n tl.store(gbeta_ptr + gstate_s_c * cs, gbeta, mask=cmask)\n tl.store(geps_ptr + gstate_s_c * cs, geps, mask=cmask)\n gw_temp = tl.load(gw_ptr + gw_s_c * cs, mask=cmask).to(tl.float32)\n gw_temp += gw\n tl.store(gw_ptr + gw_s_c * cs, gw_temp, mask=cmask)\n gu_temp = tl.load(gu_ptr + gu_s_c * cs, mask=cmask).to(tl.float32)\n gu_temp += gu\n tl.store(gu_ptr + gu_s_c * cs, gu_temp, mask=cmask)\n\ndef fused_recurrent_rwkv4_backward(\n w: Tensor, u: Tensor, k: Tensor, v: Tensor, state: Tensor, grad_wkv: Tensor, grad_state: Tensor\n) -> tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:\n bsz, tsz, chans = k.shape\n gw = torch.zeros_like(w)\n gu = torch.zeros_like(u)\n gk = torch.empty_like(k)\n gv = torch.empty_like(v)\n gstate = k.new_empty(bsz, 3, 1, chans)\n block_size_c = get_block_size_c(chans)\n\n def grid(meta: dict[str, Any]) -> tuple[int, ...]:\n return (bsz, triton.cdiv(chans, meta[\"BLOCK_SIZE_C\"]))\n\n fused_recurrent_rwkv4_backward_kernel[grid](\n w, w.stride(0), u, u.stride(0), k, k.stride(0), k.stride(1), k.stride(2),\n v, v.stride(0), v.stride(1), v.stride(2), state, state.stride(0), state.stride(1), state.stride(2), state.stride(3),\n grad_wkv, grad_wkv.stride(0), grad_wkv.stride(1), grad_wkv.stride(2), grad_state, grad_state.stride(0),\n grad_state.stride(1), grad_state.stride(3), gw, gw.stride(0), gu, gu.stride(0), gk, gk.stride(0), gk.stride(1),\n gk.stride(2), gv, gv.stride(0), gv.stride(1), gv.stride(2), gstate, gstate.stride(0), gstate.stride(1),\n gstate.stride(3), tsz, chans, BLOCK_SIZE_C=block_size_c,\n )\n\n return gw, gu, gk, gv, gstate\n", - "description_1": "Use triton language to implement a forward and backward kernel for a fused recurrent RWKV model. The forward kernel takes 27 parameters: pointers to input tensors (w, u, k, v, state), strides for these tensors, pointers for output tensors (wkv, state_out), and parameters for channels, time size, and block size. It computes the RWKV forward pass by iterating over time steps and updating the state. The backward kernel takes 40 parameters: pointers to input tensors, strides, pointers for gradients, and parameters for time size, channels, and block size. It computes the gradients for the RWKV model by iterating backward over time steps.", - "description_2": "Use triton language to create kernels for the forward and backward passes of a fused recurrent RWKV model, handling input and output tensor pointers, strides, and model parameters.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef chunk_rwkv6_fwd_kernel_cum(\n s,\n o,\n o_minus_s,\n s_s_h,\n s_s_t,\n s_s_d,\n T: tl.constexpr,\n S: tl.constexpr,\n BT: tl.constexpr,\n BS: tl.constexpr\n):\n i_s, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n o_i = tl.arange(0, BT)\n m_s = tl.where(o_i[:, None] >= o_i[None, :], 1., 0.)\n\n p_s = tl.make_block_ptr(s + i_bh * s_s_h, (T, S), (s_s_t, s_s_d), (i_t * BT, i_s * BS), (BT, BS), (1, 0))\n p_o = tl.make_block_ptr(o + i_bh * s_s_h, (T, S), (s_s_t, s_s_d), (i_t * BT, i_s * BS), (BT, BS), (1, 0))\n p_o_minus_s = tl.make_block_ptr(o_minus_s + i_bh * s_s_h, (T, S), (s_s_t, s_s_d), (i_t * BT, i_s * BS), (BT, BS), (1, 0))\n b_s = tl.load(p_s, boundary_check=(0, 1)).to(tl.float32)\n b_o = tl.dot(m_s, b_s, allow_tf32=False)\n tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))\n tl.store(p_o_minus_s, (b_o - b_s).to(p_o_minus_s.dtype.element_ty), boundary_check=(0, 1))\n\n\n@triton.jit\ndef chunk_rwkv6_fwd_kernel_h(\n k,\n v,\n g,\n h,\n h0,\n ht,\n s_k_h,\n s_k_t,\n s_k_d,\n s_v_h,\n s_v_t,\n s_v_d,\n s_h_h,\n s_h_t,\n s_h_d,\n T: tl.constexpr,\n K: tl.constexpr,\n V: tl.constexpr,\n BT: tl.constexpr,\n BK: tl.constexpr,\n BV: tl.constexpr,\n NT: tl.constexpr,\n USE_INITIAL_STATE: tl.constexpr,\n STORE_FINAL_STATE: tl.constexpr\n):\n i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n b_h = tl.zeros([BK, BV], dtype=tl.float32)\n if USE_INITIAL_STATE:\n p_h = tl.make_block_ptr(h0 + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))\n b_h += tl.load(p_h, boundary_check=(0, 1)).to(tl.float32)\n for i_t in range(NT):\n p_k = tl.make_block_ptr(k + i_bh * s_k_h, (K, T), (s_k_d, s_k_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1))\n p_v = tl.make_block_ptr(v + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))\n p_h = tl.make_block_ptr(h + i_bh * s_h_h + i_t * K * V, (K, V), (s_h_t, s_h_d), (i_k * BK, i_v * BV), (BK, BV), (1, 0))\n p_g = tl.make_block_ptr(g + i_bh * s_k_h, (K, T), (s_k_d, s_k_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1))\n p_gn = tl.make_block_ptr(g + i_bh * s_k_h, (T * K,), (s_k_d,), ((i_t * BT + BT - 1) * K + i_k * BK,), (BK,), (0,))\n\n tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1))\n b_k = tl.load(p_k, boundary_check=(0, 1))\n b_v = tl.load(p_v, boundary_check=(0, 1))\n b_g = tl.load(p_g, boundary_check=(0, 1))\n if i_t < NT - 1:\n b_gn = tl.load(p_gn, boundary_check=(0,))\n else:\n b_gn = tl.min(b_g, axis=1)\n b_h *= tl.exp(b_gn)[:, None]\n b_k = (b_k * tl.exp(b_gn[:, None] - b_g)).to(b_k.dtype)\n b_h += tl.dot(b_k, b_v, allow_tf32=False)\n\n if STORE_FINAL_STATE:\n p_h = tl.make_block_ptr(ht + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))\n tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1))\n\n\n@triton.jit\ndef chunk_rwkv6_fwd_kernel_intra(\n q,\n k,\n g,\n gs,\n u,\n A,\n s_k_h,\n s_k_t,\n s_k_d,\n scale,\n H,\n T: tl.constexpr,\n K: tl.constexpr,\n BT: tl.constexpr,\n BC: tl.constexpr,\n BK: tl.constexpr,\n NC: tl.constexpr,\n DK: tl.constexpr\n):\n i_k, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n i_t, i_i, i_j = i_c // (NC * NC), (i_c % (NC * NC)) // NC, (i_c % (NC * NC)) % NC\n i_h = i_bh % H\n n_bh = tl.num_programs(2)\n\n o_k = i_k * BK + tl.arange(0, BK)\n o_q = i_t * BT + i_i * BC\n m_k = o_k < K\n\n if i_i > i_j:\n p_q = tl.make_block_ptr(q + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))\n p_k = tl.make_block_ptr(k + i_bh * s_k_h, (K, T), (s_k_d, s_k_t), (i_k * BK, i_t * BT + i_j * BC), (BK, BC), (0, 1))\n p_gs = tl.make_block_ptr(gs + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))\n p_gk = tl.make_block_ptr(g + i_bh * s_k_h, (K, T), (s_k_d, s_k_t), (i_k * BK, i_t * BT + i_j * BC), (BK, BC), (0, 1))\n p_A = tl.make_block_ptr(A + (i_k*n_bh+i_bh)*T*BT, (T, BT), (BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0))\n b_gn = tl.load(g + i_bh * T * K + (o_q - 1) * K + o_k, mask=(m_k & (i_i > 0) & (o_q <= T)), other=0)\n b_q = tl.load(p_q, boundary_check=(0, 1))\n b_gs = tl.load(p_gs, boundary_check=(0, 1))\n b_qg = (b_q * tl.exp(b_gs - b_gn[None, :]) * scale).to(b_q.dtype)\n b_k = tl.load(p_k, boundary_check=(0, 1))\n b_gk = tl.load(p_gk, boundary_check=(0, 1))\n b_kg = (b_k * tl.exp(b_gn[:, None] - b_gk)).to(b_k.dtype)\n b_A = tl.dot(b_qg, b_kg, allow_tf32=False)\n tl.store(p_A, b_A.to(A.dtype.element_ty), boundary_check=(0, 1))\n elif i_i == i_j:\n p_q = tl.make_block_ptr(q + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))\n p_gs = tl.make_block_ptr(gs + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))\n p_k = tl.make_block_ptr(k + i_bh * s_k_h, (T * K,), (s_k_d,), ((i_t * BT + i_j * BC) * K + i_k * BK,), (BK,), (0,))\n p_q_self = tl.make_block_ptr(q + i_bh * s_k_h, (T*K,), (s_k_d,), ((i_t * BT + i_j * BC) * K + i_k * BK,), (BK,), (0,))\n\n b_q = tl.load(p_q, boundary_check=(0, 1))\n b_gs = tl.load(p_gs, boundary_check=(0, 1))\n o_i = tl.arange(0, BC)\n o_g = i_bh * T * K + (i_t * BT + i_j * BC) * K + o_k\n o_A = (i_bh + i_k * n_bh) * T * BT + (i_t * BT + i_i * BC + tl.arange(0, BC)) * BT + i_j * BC\n m_A = (i_t * BT + i_i * BC + tl.arange(0, BC)) < T\n p_u = tl.make_block_ptr(u + i_h * DK, (DK,), (1,), (i_k * BK), (BK,), (0,))\n b_u = tl.load(p_u, boundary_check=(0,))\n for j in range(0, BC):\n b_k = tl.load(p_k, boundary_check=(0,)).to(tl.float32)\n b_gk = tl.load(g + o_g + j * K, mask=(m_k & ((i_t * BT + i_j * BC + j) < T)), other=0).to(tl.float32)\n b_A = tl.sum(b_q * b_k[None, :] * tl.exp(b_gs - b_gk[None, :]) * scale, 1)\n b_A = tl.where(o_i > j, b_A, 0.)\n b_q_self = tl.load(p_q_self, boundary_check=(0,)).to(tl.float32)\n A_self = tl.sum(b_q_self * b_k * b_u * scale, axis=0)\n m_self = tl.arange(0, BC) == j\n b_A = tl.where(m_self, A_self[None], b_A)\n tl.store(A + o_A + j, b_A.to(A.dtype.element_ty), mask=m_A)\n p_k = tl.advance(p_k, (K,))\n p_q_self = tl.advance(p_q_self, (K,))\n\n\n@triton.jit\ndef chunk_rwkv6_fwd_kernel_inter(\n q,\n v,\n gs,\n h,\n o,\n A,\n s_k_h,\n s_k_t,\n s_k_d,\n s_v_h,\n s_v_t,\n s_v_d,\n s_h_h,\n s_h_t,\n s_h_d,\n scale,\n T: tl.constexpr,\n K: tl.constexpr,\n V: tl.constexpr,\n BT: tl.constexpr,\n BK: tl.constexpr,\n BV: tl.constexpr\n):\n i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n\n b_o = tl.zeros([BT, BV], dtype=tl.float32)\n for i_k in range(tl.cdiv(K, BK)):\n p_q = tl.make_block_ptr(q + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))\n p_gs = tl.make_block_ptr(gs + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))\n p_h = tl.make_block_ptr(h + i_bh * s_h_h + i_t * K * V, (K, V), (s_h_t, s_h_d), (i_k * BK, i_v * BV), (BK, BV), (1, 0))\n\n b_q = tl.load(p_q, boundary_check=(0, 1))\n b_q = (b_q * scale).to(b_q.dtype)\n b_gs = tl.load(p_gs, boundary_check=(0, 1))\n b_qg = (b_q * tl.exp(b_gs)).to(b_q.dtype)\n b_h = tl.load(p_h, boundary_check=(0, 1))\n if i_k >= 0:\n b_o += tl.dot(b_qg, b_h, allow_tf32=False)\n p_v = tl.make_block_ptr(v + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))\n p_o = tl.make_block_ptr(o + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))\n p_A = tl.make_block_ptr(A + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))\n b_v = tl.load(p_v, boundary_check=(0, 1))\n b_A = tl.load(p_A, boundary_check=(0, 1))\n b_o += tl.dot(b_A, b_v, allow_tf32=False)\n tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))\n\n\ndef chunk_rwkv6(\n r: torch.Tensor,\n k: torch.Tensor,\n v: torch.Tensor,\n g: torch.Tensor,\n u: torch.Tensor,\n scale: Optional[int] = None,\n initial_state: torch.Tensor = None,\n output_final_state: bool = False,\n checkpoint_level: Optional[int] = 0\n) -> Tuple[torch.Tensor, torch.Tensor]:\n assert checkpoint_level in [0, 1]\n if scale is None:\n scale = r.shape[-1] ** -0.5\n o, final_state = ChunkRWKV6Function.apply(r, k, v, g, u, scale, initial_state, output_final_state, checkpoint_level)\n return o, final_state\n", - "description_1": "Use triton language to implement a forward kernel for RWKV attention, handling cumulative sum, intra-segment attention, and inter-segment attention with boundary checks and efficient memory access.", - "description_2": "Use triton language to implement multiple forward kernels for RWKV attention, ensuring correct cumulative sum and efficient memory access for intra and inter-segment operations.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef fused_recurrent_rwkv6_fwd_kernel(\n q, k, v, w, u, o, h0, ht, s_k_h, s_v_h, scale,\n B: tl.constexpr, H: tl.constexpr, T: tl.constexpr, K: tl.constexpr, V: tl.constexpr,\n BK: tl.constexpr, BV: tl.constexpr, USE_INITIAL_STATE: tl.constexpr,\n STORE_FINAL_STATE: tl.constexpr, REVERSE: tl.constexpr\n):\n # Kernel implementation...\n\n@triton.jit\ndef fused_recurrent_rwkv6_bwd_kernel_dq(\n k, v, w, u, do, dq, dq_aux, h0, s_k_h, s_v_h, scale,\n B: tl.constexpr, H: tl.constexpr, T: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr,\n K: tl.constexpr, V: tl.constexpr, USE_INITIAL_STATE: tl.constexpr, REVERSE: tl.constexpr\n):\n # Kernel implementation...\n\n@triton.jit\ndef fused_recurrent_rwkv6_bwd_kernel_dkv(\n q, k, v, w, u, do, dk, dk_aux, dv, dh0, s_k_h, s_v_h, scale,\n B, H, T, BK: tl.constexpr, BV: tl.constexpr, K: tl.constexpr, V: tl.constexpr,\n USE_INITIAL_STATE: tl.constexpr, REVERSE: tl.constexpr\n):\n # Kernel implementation...\n\nclass FusedRecurrentRWKV6Function(torch.autograd.Function):\n\n @staticmethod\n def forward(ctx, r, k, v, w, u, scale=None, initial_state=None, output_final_state=False, reverse=False):\n # Forward implementation...\n\n @staticmethod\n def backward(ctx, do, d_final_state=None):\n # Backward implementation...\n\ndef fused_recurrent_rwkv6(\n r: torch.Tensor, k: torch.Tensor, v: torch.Tensor, w: torch.Tensor, u: torch.Tensor,\n scale: int = -1, initial_state: torch.Tensor = None, output_final_state: bool = False, causal: bool = True\n) -> Tuple[torch.Tensor, torch.Tensor]:\n # Function implementation...\n", - "description_1": "Use triton language to implement a fused recurrent kernel for the RWKV6 model. The forward kernel 'fused_recurrent_rwkv6_fwd_kernel' takes tensors q (query), k (key), v (value), w (log gate), u (bonus), and other parameters like strides and constants to compute the output tensor o and optionally updates initial/final states. The backward kernel 'fused_recurrent_rwkv6_bwd_kernel_dq' computes gradients with respect to the query, and 'fused_recurrent_rwkv6_bwd_kernel_dkv' computes gradients with respect to the key and value. The torch.autograd.Function class encapsulates these kernels to define a custom autograd function with optional parameters for initial states and directions.", - "description_2": "Use triton language to create a recurrent computation kernel optimized for RWKV6 model forward and backward passes, efficiently computing outputs and gradients with optional state management.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\nfrom torch.cuda.amp import custom_bwd, custom_fwd\n\n\n@triton.jit\ndef chunk_simple_gla_fwd_kernel_h(\n k,\n v,\n h,\n g,\n initial_state, # initial state of the chunk [B, H, D_head_K, D_head_V]\n final_state, # final state of the chunk [B, H, D_head_K, D_head_V]\n s_qk_h,\n s_qk_t,\n s_qk_d,\n s_vo_h,\n s_vo_t,\n s_vo_d,\n s_h_h,\n s_h_t,\n H: tl.constexpr,\n T: tl.constexpr,\n K: tl.constexpr,\n V: tl.constexpr,\n BT: tl.constexpr,\n BK: tl.constexpr,\n BV: tl.constexpr,\n NT: tl.constexpr,\n USE_INITIAL_STATE: tl.constexpr,\n STORE_FINAL_STATE: tl.constexpr\n):\n i_k, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n\n # [BK, BV]\n b_h = tl.zeros([BK, BV], dtype=tl.float32)\n\n if USE_INITIAL_STATE:\n p_h0 = tl.make_block_ptr(initial_state + i_bh * K * V,\n (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))\n b_h = tl.load(p_h0, boundary_check=(0, 1)).to(tl.float32)\n\n for i_t in range(NT):\n p_k = tl.make_block_ptr(\n k + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1))\n p_v = tl.make_block_ptr(\n v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))\n p_h = tl.make_block_ptr(h + i_bh * s_h_h + i_t * K * V,\n (K, V), (s_h_t, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))\n\n tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1))\n # [BK, BT]\n b_k = tl.load(p_k, boundary_check=(0, 1))\n # [BT, BV]\n b_v = tl.load(p_v, boundary_check=(0, 1))\n # [BK, BV]\n b_g_last = tl.load(g + i_bh * T + i_t * BT + BT - 1)\n b_h *= tl.math.exp2(b_g_last)\n b_g = tl.load(g + i_bh * T + i_t * BT + tl.arange(0, BT))\n b_h += tl.dot(b_k, (b_v * tl.math.exp2(b_g_last - b_g)[:, None]).to(b_k.dtype), allow_tf32=False)\n\n if STORE_FINAL_STATE:\n p_ht = tl.make_block_ptr(\n final_state + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))\n tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), boundary_check=(0, 1))\n\n\n@triton.jit\ndef chunk_simple_gla_fwd_kernel_o(\n q,\n k,\n v,\n h,\n g,\n o,\n s_qk_h,\n s_qk_t,\n s_qk_d,\n s_vo_h,\n s_vo_t,\n s_vo_d,\n s_h_h,\n s_h_t,\n scale,\n H: tl.constexpr,\n T: tl.constexpr,\n K: tl.constexpr,\n V: tl.constexpr,\n BT: tl.constexpr,\n BK: tl.constexpr,\n BV: tl.constexpr\n):\n i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n\n o_i = tl.arange(0, BT)\n m_s = o_i[:, None] >= o_i[None, :]\n\n b_o = tl.zeros([BT, BV], dtype=tl.float32)\n b_s = tl.zeros([BT, BT], dtype=tl.float32)\n for i_k in range(tl.cdiv(K, BK)):\n p_q = tl.make_block_ptr(\n q + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))\n p_k = tl.make_block_ptr(\n k + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1))\n p_h = tl.make_block_ptr(h + i_bh * s_h_h + i_t * K * V,\n (K, V), (s_h_t, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))\n\n # [BT, BK]\n b_q = tl.load(p_q, boundary_check=(0, 1))\n # [BK, BT]\n b_k = tl.load(p_k, boundary_check=(0, 1))\n # [BT]\n\n # [BK, BV]\n b_h = tl.load(p_h, boundary_check=(0, 1))\n b_o += tl.dot(b_q, b_h, allow_tf32=False)\n b_s += tl.dot(b_q, b_k, allow_tf32=False)\n\n p_g = g + i_bh * T + i_t * BT + tl.arange(0, BT)\n b_g = tl.load(p_g)\n b_o = b_o * tl.math.exp2(b_g)[:, None]\n b_s = b_s * tl.math.exp2(b_g[:, None] - b_g[None, :])\n b_s = tl.where(m_s, b_s, 0)\n\n p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V),\n (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))\n b_v = tl.load(p_v, boundary_check=(0, 1))\n b_o = (b_o + tl.dot(b_s.to(b_v.dtype), b_v, allow_tf32=False)) * scale\n p_o = tl.make_block_ptr(o + i_bh * s_vo_h, (T, V),\n (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))\n tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))\n\n\n@triton.jit\ndef chunk_simple_gla_bwd_kernel_dh(\n q,\n g,\n do,\n dh,\n s_qk_h,\n s_qk_t,\n s_qk_d,\n s_vo_h,\n s_vo_t,\n s_vo_d,\n s_h_h,\n s_h_t,\n scale,\n H: tl.constexpr,\n T: tl.constexpr,\n K: tl.constexpr,\n V: tl.constexpr,\n BT: tl.constexpr,\n BK: tl.constexpr,\n BV: tl.constexpr,\n NT: tl.constexpr\n):\n i_k, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n\n # [BK, BV]\n b_dh = tl.zeros([BK, BV], dtype=tl.float32)\n for i_t in range(NT - 1, -1, -1):\n p_q = tl.make_block_ptr(\n q + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1))\n p_do = tl.make_block_ptr(\n do + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))\n p_dh = tl.make_block_ptr(dh + i_bh * s_h_h + i_t * K * V,\n (K, V), (s_h_t, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))\n\n tl.store(p_dh, b_dh.to(p_dh.dtype.element_ty), boundary_check=(0, 1))\n # [BK, BT]\n b_q = tl.load(p_q, boundary_check=(0, 1))\n b_q = (b_q * scale * tl.math.exp2(tl.load(g + i_bh * T +\n i_t * BT + tl.arange(0, BT)))[None, :]).to(b_q.dtype)\n # [BT, V]\n b_do = tl.load(p_do, boundary_check=(0, 1))\n # [BK, BV]\n b_dh *= tl.math.exp2(tl.load(g + i_bh * T + i_t * BT + BT - 1))\n b_dh += tl.dot(b_q, b_do.to(b_q.dtype), allow_tf32=False)\n\n\n@triton.jit\ndef chunk_simple_gla_bwd_kernel_dqkv(\n q,\n k,\n v,\n h,\n g,\n do,\n dh,\n dq,\n dk,\n dv,\n s_qk_h,\n s_qk_t,\n s_qk_d,\n s_vo_h,\n s_vo_t,\n s_vo_d,\n s_h_h,\n s_h_t,\n scale,\n B: tl.constexpr,\n H: tl.constexpr,\n T: tl.constexpr,\n K: tl.constexpr,\n V: tl.constexpr,\n BT: tl.constexpr,\n BK: tl.constexpr,\n BV: tl.constexpr,\n NT: tl.constexpr\n):\n i_k, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n n_bh = tl.num_programs(2)\n o_i = tl.arange(0, BT)\n\n p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (K, T),\n (s_qk_d, s_qk_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1))\n p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K),\n (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))\n\n b_q = tl.load(p_q, boundary_check=(0, 1))\n b_k = tl.load(p_k, boundary_check=(0, 1))\n b_s = tl.dot(b_k, b_q, allow_tf32=False)\n p_g = g + i_bh * T + i_t * BT + tl.arange(0, BT)\n b_g = tl.load(p_g)\n b_g_last = tl.load(g + i_bh * T + i_t * BT + BT - 1)\n mask = tl.math.exp2(b_g[None, :] - b_g[:, None])\n mask = tl.where(o_i[:, None] <= o_i[None, :], mask * scale, 0)\n b_s = b_s * mask\n\n b_dq = tl.zeros([BT, BK], dtype=tl.float32)\n b_dk = tl.zeros([BT, BK], dtype=tl.float32)\n b_ds = tl.zeros([BT, BT], dtype=tl.float32)\n for i_v in range(tl.cdiv(V, BV)):\n p_v = tl.make_block_ptr(\n v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))\n p_h = tl.make_block_ptr(h + i_bh * s_h_h, (V, NT * K), (1, s_h_t),\n (i_v * BV, i_t * K + i_k * BK), (BV, BK), (0, 1))\n p_do = tl.make_block_ptr(\n do + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))\n p_dh = tl.make_block_ptr(dh + i_bh * s_h_h, (NT * K, V),\n (s_h_t, 1), (i_t * K + i_k * BK, i_v * BV), (BK, BV), (1, 0))\n p_dv = tl.make_block_ptr(dv + (i_k*n_bh+i_bh)*s_vo_h, (T, V),\n (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))\n # [BT, BV]\n b_v = tl.load(p_v, boundary_check=(0, 1))\n b_do = tl.load(p_do, boundary_check=(0, 1))\n # [BV, BK]\n b_h = tl.load(p_h, boundary_check=(0, 1))\n # [BK, BV]\n b_dh = tl.load(p_dh, boundary_check=(0, 1))\n # [BT, BT]\n b_ds += tl.dot(b_do, tl.trans(b_v), allow_tf32=False)\n # [BT, BK]\n b_dq += tl.dot(b_do, b_h, allow_tf32=False) * scale\n b_dk += tl.dot(b_v, tl.trans(b_dh), allow_tf32=False)\n # [BT, BV]\n b_dv = tl.dot(b_k, b_dh, allow_tf32=False) * tl.math.exp2(-b_g + b_g_last)[:, None] + \\\n tl.dot(b_s.to(b_q.dtype), b_do, allow_tf32=False)\n tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))\n\n b_dq = b_dq * tl.math.exp2(b_g)[:, None]\n b_dk = b_dk * tl.math.exp2(-b_g + b_g_last)[:, None]\n b_ds = b_ds * tl.trans(mask)\n b_ds = b_ds.to(b_k.dtype)\n # [BT, BK]\n b_dq += tl.dot(b_ds, b_k, allow_tf32=False)\n b_dk += tl.trans(tl.dot(b_q, b_ds, allow_tf32=False))\n p_dq = tl.make_block_ptr(dq + i_bh * s_qk_h, (T, K),\n (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))\n p_dk = tl.make_block_ptr(dk + i_bh * s_qk_h, (T, K),\n (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))\n tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))\n tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))\n\n\nclass SimpleGLAFunction(torch.autograd.Function):\n\n @staticmethod\n @custom_fwd\n def forward(ctx, q, k, v, g, initial_state, output_final_state):\n B, H, T, K, V = *q.shape, v.shape[-1]\n BT = 64\n BK, BV = min(64, triton.next_power_of_2(K)), min(\n 64, triton.next_power_of_2(V))\n NT, NK, NV = triton.cdiv(T, BT), triton.cdiv(K, BK), triton.cdiv(V, BV)\n num_stages = 1\n num_warps = 4 if BK == 64 else 2\n scale = K ** -0.5\n\n BT = 64\n assert T % BT == 0, 'sequence length must be divisible by BT'\n g = g.reshape(B, H, -1, BT)\n g = g.cumsum(-1) * 1.44269504\n g = g.reshape(B, H, -1)\n\n final_state = None\n if output_final_state:\n final_state = q.new_empty(B, H, K, V, dtype=torch.float32, requires_grad=False)\n\n h = q.new_empty(B, H, NT * K, V)\n grid = (NK, NV, B * H)\n chunk_simple_gla_fwd_kernel_h[grid](\n k, v, h, g, initial_state, final_state,\n q.stride(1), q.stride(2), q.stride(3),\n v.stride(1), v.stride(2), v.stride(3),\n h.stride(1), h.stride(2),\n H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT,\n USE_INITIAL_STATE=initial_state is not None,\n STORE_FINAL_STATE=output_final_state,\n num_warps=num_warps,\n num_stages=num_stages\n )\n grid = (NV, NT, B * H)\n o = torch.empty_like(v)\n chunk_simple_gla_fwd_kernel_o[grid](\n q, k, v, h, g, o,\n q.stride(1), q.stride(2), q.stride(3),\n v.stride(1), v.stride(2), v.stride(3),\n h.stride(1), h.stride(2),\n scale,\n H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV,\n num_warps=num_warps,\n num_stages=num_stages\n )\n\n ctx.save_for_backward(q, k, v, h, g)\n return o.to(q.dtype), final_state\n\n @staticmethod\n @custom_bwd\n def backward(ctx, do, d_ht=None):\n q, k, v, h, g = ctx.saved_tensors\n\n B, H, T, K, V = *q.shape, v.shape[-1]\n BT = 64\n BK, BV = min(32 if q.dtype == torch.float32 else 64, triton.next_power_of_2(K)), min(\n 32 if q.dtype == torch.float32 else 64, triton.next_power_of_2(V))\n NT, NK, NV = triton.cdiv(T, BT), triton.cdiv(K, BK), triton.cdiv(V, BV)\n num_stages = 1\n num_warps = 4 if BK == 64 else 2\n scale = K ** -0.5\n\n dh = q.new_empty(B, H, NT * K, V)\n grid = (NK, NV, B * H)\n chunk_simple_gla_bwd_kernel_dh[grid](\n q, g, do, dh,\n q.stride(1), q.stride(2), q.stride(3),\n v.stride(1), v.stride(2), v.stride(3),\n dh.stride(1), dh.stride(2),\n scale,\n H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT,\n num_warps=num_warps,\n num_stages=num_stages\n )\n grid = (NK, NT, B * H)\n dq = torch.empty_like(q)\n dk = torch.empty_like(k)\n dv = v.new_empty(NK, *v.shape)\n num_stages = 1\n num_warps = 4 if BK == 64 else 2\n chunk_simple_gla_bwd_kernel_dqkv[grid](\n q, k, v, h, g, do, dh, dq, dk, dv,\n q.stride(1), q.stride(2), q.stride(3),\n v.stride(1), v.stride(2), v.stride(3),\n dh.stride(1), dh.stride(2),\n scale,\n B=B, H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT,\n num_warps=num_warps,\n num_stages=num_stages\n )\n dv = dv.sum(0)\n dg = (dq * q - dk * k).sum(-1)\n\n def rev_cumsum(x):\n cumsum_x = x.cumsum(-1)\n rev_cumsum_x = cumsum_x[..., -1, None] - cumsum_x\n return rev_cumsum_x + x\n dg = rev_cumsum(dg)\n return dq.to(q.dtype), dk.to(k.dtype), dv.to(v.dtype), dg.to(g.dtype), None, None\n\n\ndef chunk_simple_gla(\n q: torch.Tensor,\n k: torch.Tensor,\n v: torch.Tensor,\n g: torch.Tensor, # log decay\n initial_state: torch.Tensor = None,\n output_final_state: bool = False\n) -> Tuple[torch.Tensor, torch.Tensor]:\n if initial_state is not None:\n initial_state = initial_state.detach()\n g = g.float()\n o, final_state = SimpleGLAFunction.apply(q, k, v, g, initial_state, output_final_state)\n return o, final_state\n", - "description_1": "Use triton language to implement forward and backward kernels for chunk-wise generalized linear attention (GLA) using four main kernel functions: 'chunk_simple_gla_fwd_kernel_h', 'chunk_simple_gla_fwd_kernel_o', 'chunk_simple_gla_bwd_kernel_dh', and 'chunk_simple_gla_bwd_kernel_dqkv'. These kernels utilize parameters such as input tensors (e.g., q, k, v, g), strides (e.g., s_qk_h, s_vo_t), and block sizes (e.g., BK, BV) to perform tensor loading, computation, and storing with conditional execution based on constants (e.g., USE_INITIAL_STATE, STORE_FINAL_STATE) to achieve efficient memory and computational performance.", - "description_2": "Use triton language to create custom autograd function 'SimpleGLAFunction' incorporating the forward and backward pass kernels, facilitating end-to-end computation and gradient flow for chunk-wise GLA in PyTorch, with input handling and result aggregation.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\nfrom typing import Optional\n\n@triton.autotune(\n configs=[\n triton.Config({'BT': 16}, num_warps=2),\n triton.Config({'BT': 16}, num_warps=4),\n triton.Config({'BT': 16}, num_warps=8),\n triton.Config({'BT': 32}, num_warps=2),\n triton.Config({'BT': 32}, num_warps=4),\n triton.Config({'BT': 32}, num_warps=8),\n triton.Config({'BT': 64}, num_warps=2),\n triton.Config({'BT': 64}, num_warps=4),\n triton.Config({'BT': 64}, num_warps=8),\n ],\n key=['S']\n)\n@triton.jit\ndef chunk_cumsum_fwd_kernel(\n s,\n z,\n s_s_h,\n s_s_t,\n s_s_d,\n T: tl.constexpr,\n S: tl.constexpr,\n BT: tl.constexpr,\n BS: tl.constexpr\n):\n i_s, i_bh = tl.program_id(0), tl.program_id(1)\n o_i = tl.arange(0, BT)\n m_s = tl.where(o_i[:, None] >= o_i[None, :], 1., 0.)\n\n b_z = tl.zeros([BS], dtype=tl.float32)\n for i_t in range(tl.cdiv(T, BT)):\n p_s = tl.make_block_ptr(s + i_bh * s_s_h, (T, S), (s_s_t, s_s_d), (i_t * BT, i_s * BS), (BT, BS), (1, 0))\n p_z = tl.make_block_ptr(z + i_bh * s_s_h, (T, S), (s_s_t, s_s_d), (i_t * BT, i_s * BS), (BT, BS), (1, 0))\n # [BT, BS]\n b_s = tl.load(p_s, boundary_check=(0, 1)).to(tl.float32)\n b_c = b_z[None, :] + tl.dot(m_s, b_s, allow_tf32=False)\n tl.store(p_z, b_c.to(p_z.dtype.element_ty), boundary_check=(0, 1))\n\n if i_t >= 0:\n b_z += tl.sum(b_s, 0)\n\n@triton.autotune(\n configs=[\n triton.Config({'BT': 16}, num_warps=2),\n triton.Config({'BT': 16}, num_warps=4),\n triton.Config({'BT': 16}, num_warps=8),\n triton.Config({'BT': 32}, num_warps=2),\n triton.Config({'BT': 32}, num_warps=4),\n triton.Config({'BT': 32}, num_warps=8),\n triton.Config({'BT': 64}, num_warps=2),\n triton.Config({'BT': 64}, num_warps=4),\n triton.Config({'BT': 64}, num_warps=8),\n ],\n key=['S']\n)\n@triton.jit\ndef chunk_cumsum_bwd_kernel(\n ds,\n dz,\n s_s_h,\n s_s_t,\n s_s_d,\n T: tl.constexpr,\n S: tl.constexpr,\n BT: tl.constexpr,\n BS: tl.constexpr\n):\n i_s, i_bh = tl.program_id(0), tl.program_id(1)\n o_i = tl.arange(0, BT)\n m_s = tl.where(o_i[:, None] <= o_i[None, :], 1., 0.)\n\n b_ds = tl.zeros([BS], dtype=tl.float32)\n for i_t in range(tl.cdiv(T, BT) - 1, -1, -1):\n p_ds = tl.make_block_ptr(ds + i_bh * s_s_h, (T, S), (s_s_t, s_s_d), (i_t * BT, i_s * BS), (BT, BS), (1, 0))\n p_dz = tl.make_block_ptr(dz + i_bh * s_s_h, (T, S), (s_s_t, s_s_d), (i_t * BT, i_s * BS), (BT, BS), (1, 0))\n # [BT, BS]\n b_dz = tl.load(p_dz, boundary_check=(0, 1)).to(tl.float32)\n b_c = b_ds[None, :] + tl.dot(m_s, b_dz, allow_tf32=False)\n tl.store(p_ds, b_c.to(p_ds.dtype.element_ty), boundary_check=(0, 1))\n\n if i_t >= 0:\n b_ds += tl.sum(b_dz, 0)\n\ndef chunk_cumsum_fwd(\n s: torch.Tensor,\n dtype: Optional[torch.dtype] = None,\n) -> torch.Tensor:\n B, H, T, S = s.shape\n BS = 32\n\n dtype = dtype or s.dtype\n grid = (triton.cdiv(S, BS), B * H)\n z = torch.empty_like(s, dtype=dtype)\n chunk_cumsum_fwd_kernel[grid](\n s, z,\n s.stride(1), s.stride(2), s.stride(3),\n T=T, S=S, BS=BS\n )\n return z\n\ndef chunk_cumsum_bwd(\n dz: torch.Tensor,\n dtype: Optional[torch.dtype] = None,\n) -> torch.Tensor:\n B, H, T, S = dz.shape\n BS = 32\n\n dtype = dtype or dz.dtype\n grid = (triton.cdiv(S, BS), B * H)\n ds = torch.empty_like(dz, dtype=dtype)\n chunk_cumsum_bwd_kernel[grid](\n ds, dz,\n ds.stride(1), ds.stride(2), ds.stride(3),\n T=T, S=S, BS=BS\n )\n return ds\n", - "description_1": "Use triton language to implement a forward and backward cumulative sum kernel on chunks of data in a tensor. The forward kernel 'chunk_cumsum_fwd_kernel' takes 8 parameters: input tensor 's', output tensor 'z', strides 's_s_h', 's_s_t', and 's_s_d', and three constant expressions 'T', 'S', 'BT', and 'BS'. It iterates over chunks of data, computes the cumulative sum, and stores the result in 'z'. The backward kernel 'chunk_cumsum_bwd_kernel' takes similar parameters and computes the gradient (cumulative sum of gradients) with respect to the input, storing it in 'ds'.", - "description_2": "Use triton language to implement kernels for performing forward and backward cumulative sum operations on chunked tensor data, suitable for parallel execution.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\nfrom triton import compile\n\n# Triton kernel function that performs matrix multiplication and stores results in an output tensor.\n@triton.jit\ndef test(at, bt, ct, k):\n # midx, kidx, nidx are the indices for the matrix multiplication.\n midx = tl.arange(0, 32)\n kidx = tl.arange(0, 32)\n nidx = tl.arange(0, 32)\n\n # Calculate indices for accessing the input matrices a and b, and the output matrix c.\n aidx = midx[:, None] * 32 + kidx[None, :]\n bidx = kidx[:, None] * 32 + nidx[None, :]\n cidx = midx[:, None] * 32 + nidx[None, :]\n\n # Pointer arithmetic to get the correct elements in each tensor.\n a_ptrs = at + aidx\n b_ptrs = bt + bidx\n c_ptrs = ct + cidx\n\n # Loop through the k-dimension to perform the matrix multiplication.\n for i in range(k):\n a = tl.load(a_ptrs)\n b = tl.load(b_ptrs)\n x = tl.dot(a, b)\n tl.atomic_add(c_ptrs, x)\n a_ptrs += 32\n b_ptrs += 32\n c_ptrs += 32\n\n# Compile the kernel and get the assembly code for the generated kernel.\nkernel = compile(test, signature='*fp32,*fp32,*fp32,i32')\nprint(kernel.asm['amdgcn'])\n", - "description_1": "Use triton language to create a kernel that performs matrix multiplication with triton's `tl.dot` and stores the result using `tl.atomic_add` in the output matrix, for matrices stored in float32 format. The function accepts pointers to the input matrices `at` and `bt`, a pointer to the output matrix `ct`, and an integer `k` representing the size of the matrices along the k-dimension.", - "description_2": "Use triton language to create a matrix multiplication kernel utilizing `tl.dot` and `tl.atomic_add` for float32 matrices, operating with 4 parameters: two input matrix pointers, one output matrix pointer, and an integer k for k-dimension.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef _fwd_kernel(\n Q, K, V, sm_scale, TMP, L, M, Out,\n stride_qz, stride_qh, stride_qm, stride_qk,\n stride_kz, stride_kh, stride_kn, stride_kk,\n stride_vz, stride_vh, stride_vk, stride_vn,\n stride_oz, stride_oh, stride_om, stride_on,\n Z, H, N_CTX,\n BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr,\n):\n # [Kernel implementation details omitted for brevity]\n pass\n\n@triton.jit\ndef _bwd_preprocess(\n Out, DO, L, NewDO, Delta,\n BLOCK_M: tl.constexpr, D_HEAD: tl.constexpr,\n):\n # [Kernel implementation details omitted for brevity]\n pass\n\n@triton.jit\ndef _bwd_kernel(\n Q, K, V, sm_scale, Out, DO, DQ, DK, DV, L, M, D,\n stride_qz, stride_qh, stride_qm, stride_qk,\n stride_kz, stride_kh, stride_kn, stride_kk,\n stride_vz, stride_vh, stride_vk, stride_vn,\n Z, H, N_CTX, num_block,\n BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr,\n):\n # [Kernel implementation details omitted for brevity]\n pass\n\nclass _attention(torch.autograd.Function):\n @staticmethod\n def forward(ctx, q, k, v, sm_scale):\n BLOCK = 128\n Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]\n assert Lq == Lk and Lk == Lv\n assert Lk in {16, 32, 64, 128}\n o = torch.empty_like(q)\n grid = (triton.cdiv(q.shape[2], BLOCK), q.shape[0] * q.shape[1])\n tmp = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)\n L = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)\n m = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)\n num_warps = 4 if Lk <= 64 else 8\n\n _fwd_kernel[grid](\n q, k, v, sm_scale, tmp, L, m, o,\n q.stride(0), q.stride(1), q.stride(2), q.stride(3),\n k.stride(0), k.stride(1), k.stride(2), k.stride(3),\n v.stride(0), v.stride(1), v.stride(2), v.stride(3),\n o.stride(0), o.stride(1), o.stride(2), o.stride(3),\n q.shape[0], q.shape[1], q.shape[2],\n BLOCK_M=BLOCK, BLOCK_N=BLOCK, BLOCK_DMODEL=Lk,\n num_warps=num_warps, num_stages=1,\n )\n ctx.save_for_backward(q, k, v, o, L, m)\n ctx.BLOCK = BLOCK\n ctx.grid = grid\n ctx.sm_scale = sm_scale\n ctx.BLOCK_DMODEL = Lk\n return o\n\n @staticmethod\n def backward(ctx, do):\n q, k, v, o, l, m = ctx.saved_tensors\n do = do.contiguous()\n dq = torch.zeros_like(q, dtype=torch.float32)\n dk = torch.empty_like(k)\n dv = torch.empty_like(v)\n do_scaled = torch.empty_like(do)\n delta = torch.empty_like(l)\n _bwd_preprocess[(ctx.grid[0] * ctx.grid[1],)](\n o, do, l, do_scaled, delta,\n BLOCK_M=ctx.BLOCK, D_HEAD=ctx.BLOCK_DMODEL,\n )\n\n num_warps = 8\n _bwd_kernel[(ctx.grid[1],)](\n q, k, v, ctx.sm_scale, o, do_scaled, dq, dk, dv, l, m, delta,\n q.stride(0), q.stride(1), q.stride(2), q.stride(3),\n k.stride(0), k.stride(1), k.stride(2), k.stride(3),\n v.stride(0), v.stride(1), v.stride(2), v.stride(3),\n q.shape[0], q.shape[1], q.shape[2], ctx.grid[0],\n BLOCK_M=ctx.BLOCK, BLOCK_N=ctx.BLOCK, BLOCK_DMODEL=ctx.BLOCK_DMODEL,\n num_warps=num_warps, num_stages=1,\n )\n return dq.to(q.dtype), dk, dv, None\n\nattention = _attention.apply\n", - "description_1": "Use triton language to implement fused attention mechanisms with forward and backward passes, utilizing three kernels. Each kernel performs specific operations for forward pass computation, backward pass pre-processing, and backward pass computation respectively. The fused attention function accepts 4 parameters: q, k, v, and sm_scale, representing query, key, value tensors, and the scale for softmax, and executes operations using triton kernels for efficient attention computation on GPUs.", - "description_2": "Use triton language to create a fused attention operator leveraging multiple triton kernels for both forward and backward computations, optimized for GPU execution, with input tensors q, k, v, and a scaling factor.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.heuristics(\n {\n \"HAS_SMOOTHING\": lambda args: args[\"smoothing\"] > 0.0,\n }\n)\n@triton.jit\ndef cross_entropy_fwd_kernel(\n loss_ptr, # data ptrs\n lse_ptr,\n logits_ptr,\n labels_ptr,\n smoothing,\n logit_scale,\n lse_square_scale,\n ignored_index,\n total_classes,\n class_start_idx, # Useful for tensor parallel when each rank only has a subset of classes\n n_cols, # shapes\n n_rows,\n logits_row_stride, # strides\n BLOCK_SIZE: tl.constexpr,\n HAS_SMOOTHING: tl.constexpr,\n # if SPLIT (e.g. tensor parallel), don't include the LSE in the loss since it's not the final LSE\n SPLIT: tl.constexpr,\n):\n row_idx = tl.program_id(0)\n col_block_idx = tl.program_id(1)\n logits_ptr = logits_ptr + row_idx * logits_row_stride.to(tl.int64)\n col_offsets = col_block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)\n label_idx = tl.load(labels_ptr + row_idx)\n logits = tl.load(logits_ptr + col_offsets, mask=col_offsets < n_cols, other=-float(\"inf\")).to(\n tl.float32\n ) * logit_scale\n max_logits = tl.max(logits, 0)\n if HAS_SMOOTHING:\n sum_logits = tl.sum(tl.where(col_offsets < n_cols, logits, 0.0), 0)\n lse = tl.log(tl.sum(tl.exp(logits - max_logits), 0)) + max_logits\n tl.store(lse_ptr + col_block_idx * n_rows + row_idx, lse)\n if label_idx == ignored_index:\n loss = 0.0\n else:\n label_idx -= class_start_idx\n if label_idx >= col_block_idx * BLOCK_SIZE and label_idx < min(\n n_cols, (col_block_idx + 1) * BLOCK_SIZE\n ):\n logits_label = tl.load(logits_ptr + label_idx) * logit_scale\n if HAS_SMOOTHING:\n loss = (\n (lse if not SPLIT else 0.0)\n - smoothing * sum_logits / total_classes\n - (1 - smoothing) * logits_label\n )\n else:\n loss = (lse if not SPLIT else 0.0) - logits_label\n else:\n # If label is out of bounds, we set the CE loss to 0.0. But we still want the smoothing loss\n if HAS_SMOOTHING:\n loss = smoothing * ((lse if not SPLIT else 0.0) - sum_logits / total_classes)\n else:\n loss = 0.0\n if not SPLIT:\n loss += lse_square_scale * lse * lse\n tl.store(loss_ptr + col_block_idx * n_rows + row_idx, loss)\n\n\n@triton.heuristics(\n {\n \"HAS_SMOOTHING\": lambda args: args[\"smoothing\"] > 0.0,\n }\n)\n@triton.jit\ndef cross_entropy_bwd_kernel(\n dlogits_ptr, # data ptrs\n dloss_ptr,\n logits_ptr,\n lse_ptr,\n labels_ptr,\n smoothing,\n logit_scale,\n lse_square_scale,\n ignored_index,\n total_classes,\n class_start_idx, # Useful for tensor parallel when each rank only has a subset of classes\n n_cols, # shapes\n logits_row_stride, # strides\n dlogits_row_stride,\n dloss_row_stride,\n BLOCK_SIZE: tl.constexpr,\n HAS_SMOOTHING: tl.constexpr,\n):\n row_idx = tl.program_id(0)\n col_block_idx = tl.program_id(1)\n logits_ptr = logits_ptr + row_idx * logits_row_stride.to(tl.int64)\n dlogits_ptr = dlogits_ptr + row_idx * dlogits_row_stride.to(tl.int64)\n col_offsets = col_block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)\n label_idx = tl.load(labels_ptr + row_idx)\n if label_idx != ignored_index:\n dloss = tl.load(dloss_ptr + row_idx * dloss_row_stride)\n else:\n dloss = 0.0\n logits = tl.load(logits_ptr + col_offsets, mask=col_offsets < n_cols, other=-float(\"inf\")).to(\n tl.float32\n ) * logit_scale\n lse = tl.load(lse_ptr + row_idx)\n probs = tl.exp(logits - lse)\n probs += 2.0 * lse_square_scale * lse * probs\n label_idx -= class_start_idx\n if HAS_SMOOTHING:\n smooth_positive = 1.0 - smoothing\n smooth_negative = smoothing / total_classes\n probs = tl.where(col_offsets == label_idx, probs - (1 - smoothing), probs) - smooth_negative\n else:\n probs = tl.where(col_offsets == label_idx, probs - 1.0, probs)\n tl.store(dlogits_ptr + col_offsets, (dloss * logit_scale) * probs, mask=col_offsets < n_cols)\n\n\nclass CrossEntropyLoss(torch.autograd.Function):\n\n @staticmethod\n def forward(\n ctx,\n logits,\n labels,\n smoothing=0.0,\n logit_scale=1.0,\n lse_square_scale=0.0,\n ignored_index=-100,\n inplace_backward=False,\n process_group=None,\n ):\n n_rows, n_cols = logits.shape\n assert labels.shape == (n_rows,)\n world_size = 1 if process_group is None else torch.distributed.get_world_size(process_group)\n total_classes = world_size * n_cols\n rank = 0 if process_group is None else torch.distributed.get_rank(process_group)\n class_start_idx = rank * n_cols\n\n if logits.stride(-1) != 1:\n logits = logits.contiguous()\n # Set these similar to https://github.com/openai/triton/blob/main/python/tutorials/02-fused-softmax.py\n MAX_BLOCK_SIZE = 64 * 1024\n BLOCK_SIZE = min(triton.next_power_of_2(n_cols), MAX_BLOCK_SIZE)\n num_warps = (\n 4\n if BLOCK_SIZE < 2048\n else (8 if BLOCK_SIZE < 8192 else (16 if BLOCK_SIZE < 128 * 1024 else 32))\n )\n # We may split the lse computation across multiple blocks, then do a reduction\n # lse(local_lse) to get the final LSE. This is faster for large n_cols (e.g., > 64k)\n # where having just one thread block processing more than 64k elements is slow.\n split = world_size > 1 or n_cols > MAX_BLOCK_SIZE\n n_splits = (n_cols + BLOCK_SIZE - 1) // BLOCK_SIZE\n loss_shape = (n_splits, n_rows) if n_splits > 1 else (n_rows,)\n losses = torch.empty(*loss_shape, dtype=torch.float, device=logits.device)\n lse = torch.empty(*loss_shape, dtype=torch.float, device=logits.device)\n # Need this, otherwise Triton tries to launch from cuda:0 and we get\n # ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?)\n with torch.cuda.device(logits.device.index):\n cross_entropy_fwd_kernel[(n_rows, n_splits)](\n losses, # data ptrs\n lse,\n logits,\n labels,\n smoothing,\n logit_scale,\n lse_square_scale,\n ignored_index,\n total_classes,\n class_start_idx,\n n_cols, # shapes\n n_rows,\n logits.stride(0), # strides\n BLOCK_SIZE=BLOCK_SIZE, # constants\n num_warps=num_warps,\n SPLIT=split,\n )\n\n if split:\n # If there's no smoothing, if labels are in the vocab of this partition, losses contains\n # - predicted logit, and 0 otherwise.\n # If there's smoothing=0.1, for labels in the vocab of this partition, losses contains\n # -0.9 * predicted logit - 0.1 * sum logit / total_classes.\n # For labels not in the vocab of this partition, losses contains\n # -0.1 * sum logit / total_classes.\n if n_splits > 1:\n lse = torch.logsumexp(lse, dim=0)\n losses = losses.sum(dim=0)\n if world_size > 1:\n lse_allgather = torch.empty(world_size, n_rows, dtype=lse.dtype, device=lse.device)\n torch.distributed.all_gather_into_tensor(lse_allgather, lse, group=process_group)\n handle_losses = torch.distributed.all_reduce(\n losses, op=torch.distributed.ReduceOp.SUM, group=process_group, async_op=True\n )\n lse = torch.logsumexp(lse_allgather, dim=0)\n handle_losses.wait()\n # After the allreduce, if there's no smoothing, the total losses are - predicted_logit,\n # we just have to add the (global) lse.\n # If there's smoothing=0.1, the total losses are\n # -0.9 * predicted_logit - 0.1 * sum logit / total_classes.\n # Again, we just have to add the (global) lse.\n losses += lse\n if lse_square_scale != 0.0:\n losses += lse_square_scale * lse.square()\n losses.masked_fill_(labels == ignored_index, 0.0)\n\n ctx.save_for_backward(logits, lse, labels)\n ctx.smoothing = smoothing\n ctx.logit_scale = logit_scale\n ctx.lse_square_scale = lse_square_scale\n ctx.ignored_index = ignored_index\n ctx.total_classes = total_classes\n ctx.class_start_idx = class_start_idx\n ctx.inplace_backward = inplace_backward\n return losses\n\n @staticmethod\n def backward(ctx, grad_losses):\n logits, lse, labels = ctx.saved_tensors\n dlogits = logits if ctx.inplace_backward else torch.empty_like(logits)\n n_rows, n_cols = logits.shape\n BLOCK_SIZE = min(triton.next_power_of_2(n_cols), 4 * 1024)\n num_warps = 4 if BLOCK_SIZE < 2048 else (8 if BLOCK_SIZE < 8192 else 16)\n grid = lambda META: (n_rows, triton.cdiv(n_cols, META[\"BLOCK_SIZE\"])) # noqa\n # Need this, otherwise Triton tries to launch from cuda:0 and we get\n # ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?)\n with torch.cuda.device(logits.device.index):\n cross_entropy_bwd_kernel[grid](\n dlogits, # data ptrs\n grad_losses,\n logits,\n lse,\n labels,\n ctx.smoothing,\n ctx.logit_scale,\n ctx.lse_square_scale,\n ctx.ignored_index,\n ctx.total_classes,\n ctx.class_start_idx,\n n_cols, # shapes\n logits.stride(0), # strides\n dlogits.stride(0),\n grad_losses.stride(0),\n BLOCK_SIZE=BLOCK_SIZE, # constants\n num_warps=num_warps,\n )\n return dlogits, None, None, None, None, None, None, None\n\n\ndef cross_entropy_loss(\n logits: torch.Tensor,\n labels: torch.Tensor,\n label_smoothing: float = 0.0,\n logit_scale: float = 1.0,\n lse_square_scale: float = 0.0,\n ignored_index=-100,\n inplace_backward: bool = False,\n process_group=None,\n) -> Tuple[torch.Tensor, torch.Tensor]:\n \"\"\"\n Arguments:\n logits: (batch, vocab_size)\n labels: (batch,)\n label_smoothing: float\n logit_scale: float. Multiply logits by this scale before calculating the loss.\n lse_square_scale: float. If > 0, we add lse_square_scale * lse(logits) ^ 2 to the loss.\n This is also referred to as \"z-loss\".\n ignored_index: int. If labels == ignored_index, the loss is set to 0.0.\n inplace_backward: bool. If True, we do the backward pass in-place by modifying the logits.\n This saves memory.\n process_group: if not None, we're doing Tensor Parallel: each process is responsible for\n one part of the vocab. The loss will be aggregated across processes.\n Returns:\n losses: (batch,), float\n \"\"\"\n return CrossEntropyLoss.apply(\n logits,\n labels,\n label_smoothing,\n logit_scale,\n lse_square_scale,\n ignored_index,\n inplace_backward,\n process_group,\n )\n", - "description_1": "Use triton language to implement a cross-entropy loss function with forward and backward kernels. The forward kernel computes the loss and log-sum-exp (LSE) for each row of logits, considering label smoothing and tensor parallelism. The backward kernel computes the gradient of the loss with respect to the logits. The function supports optional label smoothing, logit scaling, and LSE square scaling, and can handle ignored indices and tensor parallelism.", - "description_2": "Use triton language to create a cross-entropy loss function with forward and backward passes, supporting label smoothing and tensor parallelism.", - "difficulty": 2 - }, - { - "code": "import triton\nimport triton.language as tl\nimport math\nfrom enum import Enum\nfrom typing import Optional\n\n_sqrt2pi = math.sqrt(2.0 / math.pi)\n_sqrt1_2 = math.sqrt(1.0 / 2)\n_gaussian_pdf_normalization = 1.0 / math.sqrt(2 * math.pi)\n\n\nclass Activation(str, Enum):\n SquaredReLU = \"squared_relu\"\n GeLU = \"gelu\"\n GeLUApprox = \"gelu_approx\"\n LeakyReLU = \"leaky_relu\"\n ReLU = \"relu\"\n\n\ndef get_triton_activation_kernel(activation: Optional[Activation]):\n return (\n {\n Activation.ReLU: relu,\n Activation.LeakyReLU: leaky_relu,\n Activation.GeLU: gelu,\n Activation.GeLUApprox: gelu_approx,\n Activation.SquaredReLU: squared_relu,\n }[activation]\n if activation\n else None\n )\n\n\ndef get_triton_activation_bwd_kernel(activation: Optional[Activation]):\n return (\n {\n Activation.ReLU: relu_grad,\n Activation.LeakyReLU: leaky_relu_grad,\n Activation.GeLU: gelu_grad,\n Activation.GeLUApprox: gelu_approx_grad,\n Activation.SquaredReLU: squared_relu_grad,\n }[activation]\n if activation\n else None\n )\n\n\n@triton.jit\ndef tanh(x):\n # Tanh is just a scaled sigmoid\n return 2 * tl.sigmoid(2 * x) - 1\n\n\n@triton.jit\ndef cosh(x):\n exp_x = tl.exp(x)\n return (exp_x + 1.0 / exp_x) * 0.5\n\n\n# ReLU\n@triton.jit\ndef relu(x):\n \"\"\"\n ReLU_ activation function\n\n .. _ReLU: https://pytorch.org/docs/stable/generated/torch.nn.ReLU.html\n \"\"\"\n zero = 0.0\n return tl.where(x >= 0, x, zero.to(x.dtype))\n\n\n@triton.jit\ndef relu_grad(x):\n # ReLU is different from other activations\n # in that it does not require the input to retrospectively compute its gradient\n # here the input is the downstream gradient, and we return the upstream gradient directly\n zero = 0.0\n one = 1.0\n return tl.where(x >= 0, one.to(x.dtype), zero.to(x.dtype))\n\n\n@triton.jit\ndef squared_relu(x):\n \"\"\"\n Squared ReLU activation, as proposed in the Primer_ paper.\n\n .. _Primer: https://arxiv.org/abs/2109.08668\n \"\"\"\n x_ = relu(x)\n return (x_ * x_).to(x.dtype)\n\n\n@triton.jit\ndef squared_relu_grad(x):\n return tl.where(x >= 0, 2.0 * x, 0.0)\n\n\n# Leaky ReLU\n@triton.jit\ndef leaky_relu(x):\n \"\"\"\n LeakyReLU_ activation\n\n .. _LeakyReLU: https://pytorch.org/docs/stable/generated/torch.nn.LeakyReLU.html\n \"\"\"\n scale = 0.01 + 0.0\n scale = scale.to(x.dtype)\n return tl.where(x >= 0, x, scale * x)\n\n\n@triton.jit\ndef leaky_relu_grad(x):\n min_grad = 0.01\n max_grad = 1\n\n min_grad = min_grad.to(x.dtype)\n max_grad = max_grad.to(x.dtype)\n\n return tl.where(x >= 0, max_grad, min_grad)\n\n\n@triton.jit\ndef gelu(x):\n \"\"\"Gaussian Error Linear Unit (GELU)\"\"\"\n return x * 0.5 * (1.0 + tl.libdevice.erf(x * _sqrt1_2))\n\n\n@triton.jit\ndef gelu_grad(x):\n cdf = 0.5 * (1.0 + tl.libdevice.erf(x * _sqrt1_2))\n pdf = tl.exp(-0.5 * x * x) * _gaussian_pdf_normalization\n return cdf + x * pdf\n\n\n@triton.jit\ndef gelu_approx(x):\n \"\"\"\n GeLU_ activation - Gaussian error linear unit, with tanh approximation\n\n .. _GeLU: https://arxiv.org/pdf/1606.08415.pdf\n \"\"\"\n return 0.5 * x * (1.0 + tanh(_sqrt2pi * x * (1.0 + 0.044715 * x * x)))\n\n\n@triton.jit\ndef gelu_approx_grad(x):\n # CREDITS: Fast implementation proposed in\n # https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/fused_bias_gelu.py#L30\n tanh_out = tanh(0.79788456 * x * (1 + 0.044715 * x * x))\n return 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (\n 1 + tanh_out\n )\n", - "description_1": "Use triton language to implement various activation functions and their gradients. Each function takes a single parameter 'x', which is a tensor. The functions include: 'tanh' (computes the hyperbolic tangent), 'cosh' (computes the hyperbolic cosine), 'relu' (computes the Rectified Linear Unit activation), 'relu_grad' (computes the gradient of ReLU), 'squared_relu' (computes the squared ReLU activation), 'squared_relu_grad' (computes the gradient of squared ReLU), 'leaky_relu' (computes the Leaky ReLU activation), 'leaky_relu_grad' (computes the gradient of Leaky ReLU), 'gelu' (computes the Gaussian Error Linear Unit activation), 'gelu_grad' (computes the gradient of GELU), 'gelu_approx' (computes an approximate GELU using tanh), and 'gelu_approx_grad' (computes the gradient of the approximate GELU).", - "description_2": "Use triton language to implement activation functions like ReLU, GELU, and their gradients. Each function processes a tensor 'x'.", - "difficulty": 2 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.autotune(\n configs=[\n triton.Config({}, num_warps=1),\n triton.Config({}, num_warps=2),\n triton.Config({}, num_warps=4),\n triton.Config({}, num_warps=8),\n triton.Config({}, num_warps=16),\n triton.Config({}, num_warps=32),\n ],\n key=[\"N\", \"HAS_RESIDUAL\", \"STORE_RESIDUAL_OUT\", \"IS_RMS_NORM\", \"HAS_BIAS\"],\n)\n@triton.jit\ndef _layer_norm_fwd_1pass_kernel(\n X, Y, W, B, RESIDUAL, RESIDUAL_OUT, SEEDS, DROPOUT_MASK, Mean, Rstd,\n stride_x_row, stride_y_row, stride_res_row, stride_res_out_row, N, eps, dropout_p,\n IS_RMS_NORM: tl.constexpr, BLOCK_N: tl.constexpr, HAS_RESIDUAL: tl.constexpr,\n STORE_RESIDUAL_OUT: tl.constexpr, HAS_BIAS: tl.constexpr, HAS_DROPOUT: tl.constexpr,\n STORE_DROPOUT_MASK: tl.constexpr,\n):\n row = tl.program_id(0)\n X += row * stride_x_row\n Y += row * stride_y_row\n if HAS_RESIDUAL:\n RESIDUAL += row * stride_res_row\n if STORE_RESIDUAL_OUT:\n RESIDUAL_OUT += row * stride_res_out_row\n cols = tl.arange(0, BLOCK_N)\n x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)\n if HAS_DROPOUT:\n keep_mask = tl.rand(tl.load(SEEDS + row).to(tl.uint32), cols, n_rounds=7) > dropout_p\n x = tl.where(keep_mask, x / (1.0 - dropout_p), 0.0)\n if STORE_DROPOUT_MASK:\n tl.store(DROPOUT_MASK + row * N + cols, keep_mask, mask=cols < N)\n if HAS_RESIDUAL:\n residual = tl.load(RESIDUAL + cols, mask=cols < N, other=0.0).to(tl.float32)\n x += residual\n if STORE_RESIDUAL_OUT:\n tl.store(RESIDUAL_OUT + cols, x, mask=cols < N)\n if not IS_RMS_NORM:\n mean = tl.sum(x, axis=0) / N\n tl.store(Mean + row, mean)\n xbar = tl.where(cols < N, x - mean, 0.0)\n var = tl.sum(xbar * xbar, axis=0) / N\n else:\n xbar = tl.where(cols < N, x, 0.0)\n var = tl.sum(xbar * xbar, axis=0) / N\n rstd = 1 / tl.sqrt(var + eps)\n tl.store(Rstd + row, rstd)\n mask = cols < N\n w = tl.load(W + cols, mask=mask).to(tl.float32)\n if HAS_BIAS:\n b = tl.load(B + cols, mask=mask).to(tl.float32)\n x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd\n y = x_hat * w + b if HAS_BIAS else x_hat * w\n tl.store(Y + cols, y, mask=mask)\n\ndef _layer_norm_fwd(\n x, weight, bias, eps, residual=None, dropout_p=0.0, out_dtype=None, residual_dtype=None,\n is_rms_norm=False, return_dropout_mask=False,\n):\n if residual is not None:\n residual_dtype = residual.dtype\n M, N = x.shape\n assert x.stride(-1) == 1\n if residual is not None:\n assert residual.stride(-1) == 1\n assert residual.shape == (M, N)\n assert weight.shape == (N,)\n assert weight.stride(-1) == 1\n if bias is not None:\n assert bias.stride(-1) == 1\n assert bias.shape == (N,)\n y = torch.empty_like(x, dtype=x.dtype if out_dtype is None else out_dtype)\n assert y.stride(-1) == 1\n if (\n residual is not None\n or (residual_dtype is not None and residual_dtype != x.dtype)\n or dropout_p > 0.0\n ):\n residual_out = torch.empty(\n M, N, device=x.device, dtype=residual_dtype if residual_dtype is not None else x.dtype\n )\n assert residual_out.stride(-1) == 1\n else:\n residual_out = None\n mean = torch.empty((M,), dtype=torch.float32, device=\"cuda\") if not is_rms_norm else None\n rstd = torch.empty((M,), dtype=torch.float32, device=\"cuda\")\n if dropout_p > 0.0:\n seeds = torch.randint(2**32, (M,), device=x.device, dtype=torch.int64)\n else:\n seeds = None\n if return_dropout_mask and dropout_p > 0.0:\n dropout_mask = torch.empty_like(x, dtype=torch.bool)\n else:\n dropout_mask = None\n MAX_FUSED_SIZE = 65536 // x.element_size()\n BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))\n if N > BLOCK_N:\n raise RuntimeError(\"This layer norm doesn't support feature dim >= 64KB.\")\n with torch.cuda.device(x.device.index):\n _layer_norm_fwd_1pass_kernel[(M,)](\n x, y, weight, bias, residual, residual_out, seeds, dropout_mask, mean, rstd,\n x.stride(0), y.stride(0), residual.stride(0) if residual is not None else 0,\n residual_out.stride(0) if residual_out is not None else 0, N, eps, dropout_p,\n is_rms_norm, BLOCK_N, residual is not None, residual_out is not None, bias is not None,\n dropout_p > 0.0, dropout_mask is not None,\n )\n return y, mean, rstd, residual_out if residual_out is not None else x, seeds, dropout_mask\n", - "description_1": "Use triton language to implement a layer normalization forward pass kernel with support for dropout and residual connections. The kernel takes pointers to input, output, weights, biases, residuals, dropout seeds, and other parameters, and computes the normalized output with optional dropout and residual addition. The forward function sets up the necessary data structures and calls the kernel with appropriate configurations.", - "description_2": "Use triton language to implement a layer normalization forward pass kernel with dropout and residual support, and a function to configure and launch this kernel.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\nfrom flash_attn.ops.triton.k_activations import (\n gelu,\n gelu_approx,\n squared_relu,\n)\n\n@triton.autotune(\n configs=[\n triton.Config(\n {\"BLOCK_M\": 128, \"BLOCK_N\": 256, \"BLOCK_K\": 32, \"SPLIT_K\": 1}, num_stages=3, num_warps=8\n ),\n triton.Config(\n {\"BLOCK_M\": 256, \"BLOCK_N\": 128, \"BLOCK_K\": 32, \"SPLIT_K\": 1}, num_stages=3, num_warps=8\n ),\n triton.Config(\n {\"BLOCK_M\": 256, \"BLOCK_N\": 64, \"BLOCK_K\": 32, \"SPLIT_K\": 1}, num_stages=4, num_warps=4\n ),\n triton.Config(\n {\"BLOCK_M\": 64, \"BLOCK_N\": 256, \"BLOCK_K\": 32, \"SPLIT_K\": 1}, num_stages=4, num_warps=4\n ),\n triton.Config(\n {\"BLOCK_M\": 128, \"BLOCK_N\": 128, \"BLOCK_K\": 32, \"SPLIT_K\": 1}, num_stages=4, num_warps=4\n ),\n triton.Config(\n {\"BLOCK_M\": 128, \"BLOCK_N\": 64, \"BLOCK_K\": 32, \"SPLIT_K\": 1}, num_stages=4, num_warps=4\n ),\n triton.Config(\n {\"BLOCK_M\": 64, \"BLOCK_N\": 128, \"BLOCK_K\": 32, \"SPLIT_K\": 1}, num_stages=4, num_warps=4\n ),\n triton.Config(\n {\"BLOCK_M\": 128, \"BLOCK_N\": 32, \"BLOCK_K\": 32, \"SPLIT_K\": 1}, num_stages=4, num_warps=4\n ),\n triton.Config(\n {\"BLOCK_M\": 64, \"BLOCK_N\": 32, \"BLOCK_K\": 32, \"SPLIT_K\": 1}, num_stages=5, num_warps=2\n ),\n # good for int8\n triton.Config(\n {\"BLOCK_M\": 128, \"BLOCK_N\": 256, \"BLOCK_K\": 128, \"SPLIT_K\": 1},\n num_stages=3,\n num_warps=8,\n ),\n triton.Config(\n {\"BLOCK_M\": 256, \"BLOCK_N\": 128, \"BLOCK_K\": 128, \"SPLIT_K\": 1},\n num_stages=3,\n num_warps=8,\n ),\n triton.Config(\n {\"BLOCK_M\": 256, \"BLOCK_N\": 64, \"BLOCK_K\": 128, \"SPLIT_K\": 1}, num_stages=4, num_warps=4\n ),\n triton.Config(\n {\"BLOCK_M\": 64, \"BLOCK_N\": 256, \"BLOCK_K\": 128, \"SPLIT_K\": 1}, num_stages=4, num_warps=4\n ),\n triton.Config(\n {\"BLOCK_M\": 128, \"BLOCK_N\": 128, \"BLOCK_K\": 128, \"SPLIT_K\": 1},\n num_stages=4,\n num_warps=4,\n ),\n triton.Config(\n {\"BLOCK_M\": 128, \"BLOCK_N\": 64, \"BLOCK_K\": 64, \"SPLIT_K\": 1}, num_stages=4, num_warps=4\n ),\n triton.Config(\n {\"BLOCK_M\": 64, \"BLOCK_N\": 128, \"BLOCK_K\": 64, \"SPLIT_K\": 1}, num_stages=4, num_warps=4\n ),\n triton.Config(\n {\"BLOCK_M\": 128, \"BLOCK_N\": 32, \"BLOCK_K\": 64, \"SPLIT_K\": 1}, num_stages=4, num_warps=4\n ),\n triton.Config(\n {\"BLOCK_M\": 64, \"BLOCK_N\": 32, \"BLOCK_K\": 64, \"SPLIT_K\": 1}, num_stages=5, num_warps=2\n ),\n ],\n key=[\"CACHE_KEY_M\", \"CACHE_KEY_N\", \"CACHE_KEY_K\"],\n)\n@triton.heuristics(\n {\n \"EVEN_K\": lambda args: args[\"K\"] % (args[\"BLOCK_K\"] * args[\"SPLIT_K\"]) == 0,\n }\n)\n@triton.jit\ndef kernel_fwd(\n C, # Pointers to matrices\n ACT_INPUT,\n A,\n B,\n bias,\n # Matrix dimensions\n M,\n N,\n K,\n CACHE_KEY_M,\n CACHE_KEY_N,\n CACHE_KEY_K,\n # The stride variables represent how much to increase the ptr by when moving by 1\n # element in a particular dimension. E.g. stride_am is how much to increase a_ptr\n # by to get the element one row down (A has M rows)\n stride_cm,\n stride_am,\n stride_ak,\n stride_bn,\n stride_bk,\n # Meta-parameters\n BLOCK_M: tl.constexpr,\n GROUP_M: tl.constexpr,\n BLOCK_N: tl.constexpr,\n BLOCK_K: tl.constexpr,\n SPLIT_K: tl.constexpr,\n EVEN_K: tl.constexpr,\n A_ROWMAJOR: tl.constexpr,\n B_COLMAJOR: tl.constexpr,\n BIAS: tl.constexpr,\n SAVE_ACT_INPUT: tl.constexpr,\n ACTIVATION: tl.constexpr,\n):\n\n \"\"\"\n Kernel for computing Out = activation(A x W + C)\n - Input has shape (M, K)\n - Weight has shape (K, N)\n - Bias has shape (N,)\n - Output has shape (M, N)\n - ActInputs (optional) has shape (M, N)\n 'ActInputs' optionally saves the A x W + C intermediate for backward computations\n This kernel will consolidate over K\n \"\"\"\n\n pid = tl.program_id(axis=0)\n\n grid_m = (M + BLOCK_M - 1) // BLOCK_M\n grid_n = (N + BLOCK_N - 1) // BLOCK_N\n # re-order program ID for better L2 performance\n width = GROUP_M * grid_n\n group_id = pid // width\n group_size = min(grid_m - group_id * GROUP_M, GROUP_M)\n pid_m = group_id * GROUP_M + (pid % group_size)\n pid_n = (pid % width) // (group_size)\n\n # now compute the block that each program will go through\n # rm (resp. rn) denotes a range of indices\n # for rows (resp. col) of C\n rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)\n rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)\n # trick to avoid masking on M and N axis\n ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)\n rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)\n rk = tl.arange(0, BLOCK_K)\n\n if A_ROWMAJOR:\n A = A + (ram[:, None] * stride_am + rk[None, :])\n else:\n A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)\n if B_COLMAJOR:\n B = B + (rk[:, None] + rbn[None, :] * stride_bn)\n else:\n B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)\n\n acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)\n\n for k in range(K, 0, -BLOCK_K):\n if EVEN_K:\n a = tl.load(A)\n b = tl.load(B)\n else:\n a = tl.load(A, mask=rk[None, :] < k, other=0.0)\n b = tl.load(B, mask=rk[:, None] < k, other=0.0)\n acc += tl.dot(a, b)\n\n if A_ROWMAJOR:\n A += BLOCK_K\n else:\n A += BLOCK_K * stride_ak\n if B_COLMAJOR:\n B += BLOCK_K\n else:\n B += BLOCK_K * stride_bk\n\n # Putting bias after the matmul (instead of before) is faster, idk why\n if BIAS:\n bias = tl.load(bias + rn, mask=rn < N, other=0.0).to(tl.float32)\n acc += bias[None, :]\n\n # optional: save the activation inputs\n if SAVE_ACT_INPUT:\n act_in_ptrs = ACT_INPUT + ram[:, None] * stride_cm + rbn[None, :]\n tl.store(act_in_ptrs, acc)\n\n # optional: fused activation (while the data is in shared memory)\n if ACTIVATION == \"gelu\":\n acc = gelu(acc)\n elif ACTIVATION == \"gelu_approx\":\n acc = gelu_approx(acc)\n elif ACTIVATION == \"squared_relu\":\n acc = squared_relu(acc)\n # rematerialize rm and rn to save registers\n rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)\n rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)\n\n # write back result\n C = C + rm[:, None] * stride_cm + rn[None, :]\n mask = (rm < M)[:, None] & (rn < N)[None, :]\n tl.store(C, acc)\n\n\ndef triton_linear_act(\n x: torch.Tensor,\n weight: torch.Tensor,\n bias: Optional[torch.Tensor] = None,\n activation: str = \"id\",\n save_act_input: bool = False,\n) -> torch.Tensor:\n \"\"\"\n Compute e = activation(x @ weight.T + bias).\n This wrapper kicks the `kernel_fwd` Triton kernel\n :param x: input tensor\n :param weight: weight matrix\n :param bias: an optional bias tensor\n :param activation: Activation name. Needs to be a Triton kernel.\n :param act_input: an optional tensor to save the activation inputs (for backward)\n :return: result tensor\n \"\"\"\n assert activation in [\"id\", \"gelu\", \"gelu_approx\", \"squared_relu\"]\n\n batch_shape, n = x.shape[:-1], x.shape[-1]\n batch_dim = batch_shape.numel()\n x_reshaped = x.reshape(batch_dim, n)\n\n if x_reshaped.stride(0) > 1 and x_reshaped.stride(1) > 1:\n x_reshaped = x_reshaped.contiguous()\n if weight.stride(0) > 1 and weight.stride(1) > 1:\n weight = weight.contiguous()\n bias = bias.contiguous() if bias is not None else None\n\n assert (\n x.dtype == weight.dtype\n ), f\"Input and weight must have the same dtype, got {x.dtype} and {weight.dtype}\"\n if bias is not None:\n assert (\n x.dtype == bias.dtype\n ), f\"Input and bias must have the same dtype, got {x.dtype} and {bias.dtype}\"\n assert (\n x_reshaped.shape[1] == weight.shape[1]\n ), f\"Incompatible dimensions: {x_reshaped.shape} - {weight.shape}\"\n\n assert (\n bias is None or bias.shape[0] == weight.shape[0]\n ), \"Incompatible dimensions in between weight and bias\"\n\n M, K = x_reshaped.shape\n N, K = weight.shape\n\n output = torch.empty((M, N), device=x.device, dtype=x.dtype)\n act_input = torch.empty_like(output) if save_act_input else None\n\n # 1D launch kernel where each block gets its own program.\n grid = lambda META: (triton.cdiv(M, META[\"BLOCK_M\"]) * triton.cdiv(N, META[\"BLOCK_N\"]),) # noqa\n\n kernel_fwd[grid](\n output,\n act_input,\n x_reshaped,\n weight, # data ptrs\n bias if bias is not None else x, # auto skip bias if not present\n M, # shapes\n N,\n K,\n M // 32, # key for triton cache (limit number of compilations)\n N // 32,\n K // 32,\n stride_cm=output.stride(0), # strides\n stride_am=x_reshaped.stride(0),\n stride_ak=x_reshaped.stride(1),\n stride_bk=weight.stride(1),\n stride_bn=weight.stride(0),\n BIAS=bias is not None, # optional fused bias\n SAVE_ACT_INPUT=save_act_input, # optional save activation inputs\n ACTIVATION=activation, # optional fused activation\n A_ROWMAJOR=x_reshaped.stride(1) == 1,\n B_COLMAJOR=weight.stride(1) == 1,\n GROUP_M=8, # speed optimization: group the programs\n )\n\n if not save_act_input:\n return output.reshape(*batch_shape, output.shape[-1])\n else:\n return (\n output.reshape(*batch_shape, output.shape[-1]),\n act_input.reshape(*batch_shape, act_input.shape[-1]),\n )\n\n@triton.autotune(\n configs=[\n triton.Config(\n {\"BLOCK_M\": 128, \"BLOCK_N\": 256, \"BLOCK_K\": 32, \"SPLIT_K\": 1}, num_stages=3, num_warps=8\n ),\n triton.Config(\n {\"BLOCK_M\": 256, \"BLOCK_N\": 128, \"BLOCK_K\": 32, \"SPLIT_K\": 1}, num_stages=3, num_warps=8\n ),\n triton.Config(\n {\"BLOCK_M\": 256, \"BLOCK_N\": 64, \"BLOCK_K\": 32, \"SPLIT_K\": 1}, num_stages=4, num_warps=4\n ),\n triton.Config(\n {\"BLOCK_M\": 64, \"BLOCK_N\": 256, \"BLOCK_K\": 32, \"SPLIT_K\": 1}, num_stages=4, num_warps=4\n ),\n triton.Config(\n {\"BLOCK_M\": 128, \"BLOCK_N\": 128, \"BLOCK_K\": 32, \"SPLIT_K\": 1}, num_stages=4, num_warps=4\n ),\n triton.Config(\n {\"BLOCK_M\": 128, \"BLOCK_N\": 64, \"BLOCK_K\": 32, \"SPLIT_K\": 1}, num_stages=4, num_warps=4\n ),\n triton.Config(\n {\"BLOCK_M\": 64, \"BLOCK_N\": 128, \"BLOCK_K\": 32, \"SPLIT_K\": 1}, num_stages=4, num_warps=4\n ),\n triton.Config(\n {\"BLOCK_M\": 128, \"BLOCK_N\": 32, \"BLOCK_K\": 32, \"SPLIT_K\": 1}, num_stages=4, num_warps=4\n ),\n triton.Config(\n {\"BLOCK_M\": 64, \"BLOCK_N\": 32, \"BLOCK_K\": 32, \"SPLIT_K\": 1}, num_stages=5, num_warps=2\n ),\n # good for int8\n triton.Config(\n {\"BLOCK_M\": 128, \"BLOCK_N\": 256, \"BLOCK_K\": 128, \"SPLIT_K\": 1},\n num_stages=3,\n num_warps=8,\n ),\n triton.Config(\n {\"BLOCK_M\": 256, \"BLOCK_N\": 128, \"BLOCK_K\": 128, \"SPLIT_K\": 1},\n num_stages=3,\n num_warps=8,\n ),\n triton.Config(\n {\"BLOCK_M\": 256, \"BLOCK_N\": 64, \"BLOCK_K\": 128, \"SPLIT_K\": 1}, num_stages=4, num_warps=4\n ),\n triton.Config(\n {\"BLOCK_M\": 64, \"BLOCK_N\": 256, \"BLOCK_K\": 128, \"SPLIT_K\": 1}, num_stages=4, num_warps=4\n ),\n triton.Config(\n {\"BLOCK_M\": 128, \"BLOCK_N\": 128, \"BLOCK_K\": 128, \"SPLIT_K\": 1},\n num_stages=4,\n num_warps=4,\n ),\n triton.Config(\n {\"BLOCK_M\": 128, \"BLOCK_N\": 64, \"BLOCK_K\": 64, \"SPLIT_K\": 1}, num_stages=4, num_warps=4\n ),\n triton.Config(\n {\"BLOCK_M\": 64, \"BLOCK_N\": 128, \"BLOCK_K\": 64, \"SPLIT_K\": 1}, num_stages=4, num_warps=4\n ),\n triton.Config(\n {\"BLOCK_M\": 128, \"BLOCK_N\": 32, \"BLOCK_K\": 64, \"SPLIT_K\": 1}, num_stages=4, num_warps=4\n ),\n triton.Config(\n {\"BLOCK_M\": 64, \"BLOCK_N\": 32, \"BLOCK_K\": 64, \"SPLIT_K\": 1}, num_stages=5, num_warps=2\n ),\n ],\n key=[\"CACHE_KEY_M\", \"CACHE_KEY_N\", \"CACHE_KEY_K\"],\n)\n@triton.heuristics(\n {\n \"EVEN_K\": lambda args: args[\"K\"] % (args[\"BLOCK_K\"] * args[\"SPLIT_K\"]) == 0,\n }\n)\n@triton.jit\ndef kernel_bwd(\n C, # Pointers to matrices\n ACT_INPUT,\n A,\n B,\n # Matrix dimensions\n M,\n N,\n K,\n CACHE_KEY_M,\n CACHE_KEY_N,\n CACHE_KEY_K,\n # The stride variables represent how much to increase the ptr by when moving by 1\n # element in a particular dimension. E.g. stride_am is how much to increase a_ptr\n # by to get the element one row down (A has M rows)\n stride_cm,\n stride_am,\n stride_ak,\n stride_bk,\n stride_bn,\n # Meta-parameters\n BLOCK_M: tl.constexpr,\n GROUP_M: tl.constexpr,\n BLOCK_N: tl.constexpr,\n BLOCK_K: tl.constexpr,\n SPLIT_K: tl.constexpr,\n EVEN_K: tl.constexpr,\n ACTIVATION: tl.constexpr,\n):\n\n \"\"\"\n Kernel for computing Out = activation(A x W + C)\n - Input has shape (M, K)\n - Weight has shape (K, N)\n - Output has shape (M, N)\n - ActInputs (optional) has shape (M, N)\n 'ActInputs' optionally saves the A x W + C intermediate for backward computations\n This kernel will consolidate over K\n \"\"\"\n\n pid = tl.program_id(axis=0)\n\n grid_m = (M + BLOCK_M - 1) // BLOCK_M\n grid_n = (N + BLOCK_N - 1) // BLOCK_N\n # re-order program ID for better L2 performance\n width = GROUP_M * grid_n\n group_id = pid // width\n group_size = min(grid_m - group_id * GROUP_M, GROUP_M)\n pid_m = group_id * GROUP_M + (pid % group_size)\n pid_n = (pid % width) // (group_size)\n\n # now compute the block that each program will go through\n # rm (resp. rn) denotes a range of indices\n # for rows (resp. col) of C\n rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)\n rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)\n # trick to avoid masking on M and N axis\n ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)\n rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)\n rk = tl.arange(0, BLOCK_K)\n\n A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)\n B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)\n\n acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)\n\n for k in range(K, 0, -BLOCK_K):\n if EVEN_K:\n a = tl.load(A)\n b = tl.load(B)\n else:\n a = tl.load(A, mask=rk[None, :] < k, other=0.0)\n b = tl.load(B, mask=rk[:, None] < k, other=0.0)\n acc += tl.dot(a, b)\n\n A += BLOCK_K * stride_ak\n B += BLOCK_K * stride_bk\n\n # optional: fused activation (while the data is in shared memory)\n if ACTIVATION != \"id\":\n act_in_ptrs = ACT_INPUT + ram[:, None] * stride_cm + rbn[None, :]\n act_input = tl.load(act_in_ptrs).to(acc.dtype)\n if ACTIVATION == \"gelu\":\n acc *= gelu_grad(act_input)\n elif ACTIVATION == \"gelu_approx\":\n acc *= gelu_approx_grad(act_input)\n elif ACTIVATION == \"squared_relu\":\n acc *= squared_relu_grad(act_input)\n\n # rematerialize rm and rn to save registers\n rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)\n rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)\n\n # write back result\n C = C + rm[:, None] * stride_cm + rn[None, :]\n mask = (rm < M)[:, None] & (rn < N)[None, :]\n tl.store(C, acc, mask=mask)\n\n\ndef triton_dgrad_act(\n grad_output: torch.Tensor,\n weight: torch.Tensor,\n activation: str = \"id\",\n act_input: Optional[torch.Tensor] = None,\n) -> torch.Tensor:\n \"\"\"\n Compute e = activation(grad_output @ weight + bias).\n This wrapper kicks the `kernel_bwd` Triton kernel\n :param grad_output: input tensor\n :param weight: weight matrix\n :param activation: Activation name. Needs to be a Triton kernel.\n :param act_input: an optional tensor to save the activation inputs (for backward)\n :return: result tensor\n \"\"\"\n assert activation in [\"id\", \"gelu\", \"gelu_approx\", \"squared_relu\"]\n\n batch_shape, n = grad_output.shape[:-1], grad_output.shape[-1]\n batch_dim = batch_shape.numel()\n grad_output_reshaped = grad_output.reshape(batch_dim, n)\n\n if grad_output_reshaped.stride(0) > 1 and grad_output_reshaped.stride(1) > 1:\n grad_output_reshaped = grad_output_reshaped.contiguous()\n if weight.stride(0) > 1 and weight.stride(1) > 1:\n weight = weight.contiguous()\n\n assert (\n grad_output.dtype == weight.dtype\n ), f\"grad_output and weight must have the same dtype, got {grad_output.dtype} and {weight.dtype}\"\n assert (\n grad_output_reshaped.shape[1] == weight.shape[0]\n ), f\"Incompatible dimensions: {grad_output_reshaped.shape} - {weight.shape}\"\n if activation != \"id\":\n assert act_input is not None, f\"act_input is required for activation {activation}\"\n\n # M, N, K in bwd are different from M, N, K in fwd\n M, K = grad_output_reshaped.shape\n K, N = weight.shape\n\n grad_input = torch.empty((M, N), device=grad_output.device, dtype=grad_output.dtype)\n\n # 1D launch kernel where each block gets its own program.\n grid = lambda META: (triton.cdiv(M, META[\"BLOCK_M\"]) * triton.cdiv(N, META[\"BLOCK_N\"]),) # noqa\n\n kernel_bwd[grid](\n grad_input,\n act_input,\n grad_output_reshaped,\n weight, # data ptrs\n M, # shapes\n N,\n K,\n M // 32, # key for triton cache (limit number of compilations)\n N // 32,\n K // 32,\n stride_cm=grad_input.stride(0), # strides\n stride_am=grad_output_reshaped.stride(0),\n stride_ak=grad_output_reshaped.stride(1),\n stride_bk=weight.stride(0),\n stride_bn=weight.stride(1),\n ACTIVATION=activation, # optional fused activation\n GROUP_M=8, # speed optimization: group the programs\n )\n\n return grad_input.reshape(*batch_shape, grad_input.shape[-1])\n", - "description_1": "Use triton language to implement a forward pass kernel and a backward pass kernel for linear layers with optional activation functions. The forward kernel, kernel_fwd, takes 24 parameters including tensors for inputs, weights, biases, matrix dimensions, and activation configuration. It computes the matrix multiplication of input and weight, adds the bias, applies activation, and stores the result. The backward kernel, kernel_bwd, takes 20 parameters including tensors for gradient input, weights, matrix dimensions, and activation configuration. It computes the gradient of the input by performing matrix multiplication with the weight and applies the gradient of the activation function.", - "description_2": "Use triton language to create kernels for matrix multiplication with optional activation, supporting both forward and backward passes, handling tensors and dimensions efficiently.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef rotary_kernel(\n OUT, # Pointers to matrices\n X,\n COS,\n SIN,\n CU_SEQLENS,\n SEQLEN_OFFSETS, # this could be int or a pointer\n # Matrix dimensions\n seqlen,\n nheads,\n rotary_dim,\n seqlen_ro,\n CACHE_KEY_SEQLEN,\n # strides\n stride_out_batch,\n stride_out_seqlen,\n stride_out_nheads,\n stride_out_headdim,\n stride_x_batch,\n stride_x_seqlen,\n stride_x_nheads,\n stride_x_headdim,\n # Meta-parameters\n BLOCK_K: tl.constexpr,\n IS_SEQLEN_OFFSETS_TENSOR: tl.constexpr,\n IS_VARLEN: tl.constexpr,\n INTERLEAVED: tl.constexpr,\n CONJUGATE: tl.constexpr,\n BLOCK_M: tl.constexpr,\n):\n pid_m = tl.program_id(axis=0)\n pid_batch = tl.program_id(axis=1)\n pid_head = tl.program_id(axis=2)\n rotary_dim_half = rotary_dim // 2\n\n if not IS_VARLEN:\n X = X + pid_batch * stride_x_batch + pid_head * stride_x_nheads\n OUT = OUT + pid_batch * stride_out_batch + pid_head * stride_out_nheads\n else:\n start_idx = tl.load(CU_SEQLENS + pid_batch)\n seqlen = tl.load(CU_SEQLENS + pid_batch + 1) - start_idx\n X = X + start_idx * stride_x_seqlen + pid_head * stride_x_nheads\n OUT = OUT + start_idx * stride_out_seqlen + pid_head * stride_out_nheads\n\n if pid_m * BLOCK_M >= seqlen:\n return\n rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)\n if not IS_SEQLEN_OFFSETS_TENSOR:\n rm_cs = rm + SEQLEN_OFFSETS\n else:\n rm_cs = rm + tl.load(SEQLEN_OFFSETS + pid_batch)\n rk = tl.arange(0, BLOCK_K)\n rk_half = tl.arange(0, BLOCK_K // 2)\n\n if not INTERLEAVED:\n # Load the 1st and 2nd halves of X, do calculation, then store to 1st and 2nd halves of OUT\n X = X + (rm[:, None] * stride_x_seqlen + rk_half[None, :] * stride_x_headdim)\n COS = COS + (rm_cs[:, None] * rotary_dim_half + rk_half[None, :])\n SIN = SIN + (rm_cs[:, None] * rotary_dim_half + rk_half[None, :])\n cos = tl.load(\n COS, mask=(rm_cs[:, None] < seqlen_ro) & (rk_half[None, :] < rotary_dim_half), other=1.0\n ).to(tl.float32)\n sin = tl.load(\n SIN, mask=(rm_cs[:, None] < seqlen_ro) & (rk_half[None, :] < rotary_dim_half), other=0.0\n ).to(tl.float32)\n x0 = tl.load(\n X, mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half), other=0.0\n ).to(tl.float32)\n x1 = tl.load(\n X + rotary_dim_half * stride_x_headdim,\n mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half),\n other=0.0,\n ).to(tl.float32)\n if CONJUGATE:\n sin = -sin\n o0 = x0 * cos - x1 * sin\n o1 = x0 * sin + x1 * cos\n # write back result\n OUT = OUT + (rm[:, None] * stride_out_seqlen + rk_half[None, :] * stride_out_headdim)\n tl.store(OUT, o0, mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half))\n tl.store(\n OUT + rotary_dim_half * stride_out_headdim,\n o1,\n mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half),\n )\n else:\n # We don't want to load X[0, 2, 4, ...] and X[1, 3, 5, ...] separately since both are slow.\n # Instead, we load x0 = X[0, 1, 2, 3, ...] and x1 = X[1, 0, 3, 2, ...].\n # Loading x0 will be fast but x1 will be slow.\n # Then we load cos = COS[0, 0, 1, 1, ...] and sin = SIN[0, 0, 1, 1, ...].\n # Then we do the calculation and use tl.where to pick put the right outputs for the even\n # and for the odd indices.\n rk_swap = rk + ((rk + 1) % 2) * 2 - 1 # 1, 0, 3, 2, 5, 4, ...\n rk_repeat = tl.arange(0, BLOCK_K) // 2\n X0 = X + (rm[:, None] * stride_x_seqlen + rk[None, :] * stride_x_headdim)\n X1 = X + (rm[:, None] * stride_x_seqlen + rk_swap[None, :] * stride_x_headdim)\n COS = COS + (rm_cs[:, None] * rotary_dim_half + rk_repeat[None, :])\n SIN = SIN + (rm_cs[:, None] * rotary_dim_half + rk_repeat[None, :])\n cos = tl.load(\n COS,\n mask=(rm_cs[:, None] < seqlen_ro) & (rk_repeat[None, :] < rotary_dim_half),\n other=1.0,\n ).to(tl.float32)\n sin = tl.load(\n SIN,\n mask=(rm_cs[:, None] < seqlen_ro) & (rk_repeat[None, :] < rotary_dim_half),\n other=0.0,\n ).to(tl.float32)\n x0 = tl.load(X0, mask=(rm[:, None] < seqlen) & (rk[None, :] < rotary_dim), other=0.0).to(\n tl.float32\n )\n x1 = tl.load(\n X1, mask=(rm[:, None] < seqlen) & (rk_swap[None, :] < rotary_dim), other=0.0\n ).to(tl.float32)\n if CONJUGATE:\n sin = -sin\n x0_cos = x0 * cos\n x1_sin = x1 * sin\n out = tl.where(rk[None, :] % 2 == 0, x0_cos - x1_sin, x0_cos + x1_sin)\n OUT = OUT + (rm[:, None] * stride_out_seqlen + rk[None, :] * stride_out_headdim)\n tl.store(OUT, out, mask=(rm[:, None] < seqlen) & (rk[None, :] < rotary_dim))\n\ndef apply_rotary(\n x: torch.Tensor,\n cos: torch.Tensor,\n sin: torch.Tensor,\n seqlen_offsets: Union[int, torch.Tensor] = 0,\n cu_seqlens: Optional[torch.Tensor] = None,\n max_seqlen: Optional[int] = None,\n interleaved=False,\n inplace=False,\n conjugate=False,\n) -> torch.Tensor:\n \"\"\"\n Arguments:\n x: (batch, seqlen, nheads, headdim) if cu_seqlens is None\n else (total_seqlen, nheads, headdim).\n cos: (seqlen_ro, rotary_dim / 2)\n sin: (seqlen_ro, rotary_dim / 2)\n seqlen_offsets: integer or integer tensor of size (batch,)\n cu_seqlens: (batch + 1,) or None\n max_seqlen: int\n Returns:\n y: (batch, seqlen, nheads, headdim)\n \"\"\"\n is_varlen = cu_seqlens is not None\n if not is_varlen:\n batch, seqlen, nheads, headdim = x.shape\n else:\n assert max_seqlen is not None, \"If cu_seqlens is passed in, then max_seqlen must be passed\"\n total_seqlen, nheads, headdim = x.shape\n batch_p_1 = cu_seqlens.shape[0]\n batch = batch_p_1 - 1\n seqlen = max_seqlen\n seqlen_ro, rotary_dim = cos.shape\n assert sin.shape == cos.shape\n rotary_dim *= 2\n assert rotary_dim <= headdim, \"rotary_dim must be <= headdim\"\n assert headdim <= 256, \"Only support headdim <= 256\"\n assert seqlen_ro >= seqlen, \"seqlen_ro must be >= seqlen\"\n\n assert (\n cos.dtype == sin.dtype\n ), f\"cos and sin must have the same dtype, got {cos.dtype} and {sin.dtype}\"\n assert (\n x.dtype == cos.dtype\n ), f\"Input and cos/sin must have the same dtype, got {x.dtype} and {cos.dtype}\"\n\n cos, sin = cos.contiguous(), sin.contiguous()\n if isinstance(seqlen_offsets, torch.Tensor):\n assert seqlen_offsets.shape == (batch,)\n assert seqlen_offsets.dtype in [torch.int32, torch.int64]\n seqlen_offsets = seqlen_offsets.contiguous()\n else:\n assert seqlen_offsets + seqlen <= seqlen_ro\n\n output = torch.empty_like(x) if not inplace else x\n if rotary_dim < headdim and not inplace:\n output[..., rotary_dim:].copy_(x[..., rotary_dim:])\n\n BLOCK_K = (\n 32\n if rotary_dim <= 32\n else (64 if rotary_dim <= 64 else (128 if rotary_dim <= 128 else 256))\n )\n grid = lambda META: (triton.cdiv(seqlen, META[\"BLOCK_M\"]), batch, nheads) # noqa\n BLOCK_M = 4 if interleaved else (8 if rotary_dim <= 64 else 4)\n\n # Need this, otherwise Triton tries to launch from cuda:0 and we get\n # ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?)\n with torch.cuda.device(x.device.index):\n rotary_kernel[grid](\n output, # data ptrs\n x,\n cos,\n sin,\n cu_seqlens,\n seqlen_offsets,\n seqlen, # shapes\n nheads,\n rotary_dim,\n seqlen_ro,\n seqlen // 128, # key for triton cache (limit number of compilations)\n output.stride(0) if not is_varlen else 0, # batch_strides if not varlen else 0\n output.stride(-3), # seqlen_stride or total_seqlen_stride\n output.stride(-2), # nheads_stride\n output.stride(-1), # headdim_stride\n x.stride(0) if not is_varlen else 0, # batch_strides if not varlen else 0\n x.stride(-3), # seqlen stride or total_seqlen_stride\n x.stride(-2), # nheads stride\n x.stride(-1), # headdim stride\n BLOCK_K,\n isinstance(seqlen_offsets, torch.Tensor),\n is_varlen,\n interleaved,\n conjugate,\n BLOCK_M,\n )\n return output\n", - "description_1": "Use triton language to implement a rotary kernel for matrix computations on GPUs, accepting 24 parameters: OUT, X, COS, SIN, CU_SEQLENS, SEQLEN_OFFSETS, seqlen, nheads, rotary_dim, seqlen_ro, CACHE_KEY_SEQLEN, stride_out_batch, stride_out_seqlen, stride_out_nheads, stride_out_headdim, stride_x_batch, stride_x_seqlen, stride_x_nheads, stride_x_headdim, BLOCK_K, IS_SEQLEN_OFFSETS_TENSOR, IS_VARLEN, INTERLEAVED, CONJUGATE, BLOCK_M. The kernel performs computations based on these inputs and stores results in OUT.", - "description_2": "Use triton language to call the rotary kernel with input tensors (X, COS, SIN) and additional parameters for shape, memory strides, and meta-parameters to compute rotary transformations on GPU.", - "difficulty": 4 - }, - { - "code": "import triton\nimport triton.language as tl\nimport torch\n\n@triton.jit\ndef matmul_248_kernel(a_ptr, b_ptr, c_ptr, scales_ptr, zeros_ptr, g_ptr, M, N, K, bits, maxq, stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn, stride_scales, stride_zeros,\n BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr):\n \"\"\"\n Compute the matrix multiplication C = A x B.\n A is of shape (M, K) float16\n B is of shape (K//8, N) int32\n C is of shape (M, N) float16\n scales is of shape (G, N) float16\n zeros is of shape (G, N) float16\n g_ptr is of shape (K) int32 \n \"\"\"\n infearure_per_bits = 32 // bits\n\n pid = tl.program_id(axis=0)\n num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)\n num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)\n num_pid_k = tl.cdiv(K, BLOCK_SIZE_K)\n num_pid_in_group = GROUP_SIZE_M * num_pid_n\n group_id = pid // num_pid_in_group\n first_pid_m = group_id * GROUP_SIZE_M\n group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)\n pid_m = first_pid_m + (pid % group_size_m)\n pid_n = (pid % num_pid_in_group) // group_size_m\n\n offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n offs_k = tl.arange(0, BLOCK_SIZE_K)\n a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) # (BLOCK_SIZE_M, BLOCK_SIZE_K)\n a_mask = (offs_am[:, None] < M)\n # b_ptrs is set up such that it repeats elements along the K axis 8 times\n b_ptrs = b_ptr + ((offs_k[:, None] // infearure_per_bits) * stride_bk + offs_bn[None, :] * stride_bn) # (BLOCK_SIZE_K, BLOCK_SIZE_N)\n g_ptrs = g_ptr + offs_k\n # shifter is used to extract the N bits of each element in the 32-bit word from B\n scales_ptrs = scales_ptr + offs_bn[None, :]\n zeros_ptrs = zeros_ptr + (offs_bn[None, :] // infearure_per_bits)\n\n shifter = (offs_k % infearure_per_bits) * bits\n zeros_shifter = (offs_bn % infearure_per_bits) * bits\n accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n\n for k in range(0, num_pid_k):\n g_idx = tl.load(g_ptrs)\n\n # Fetch scales and zeros; these are per-outfeature and thus reused in the inner loop\n scales = tl.load(scales_ptrs + g_idx[:, None] * stride_scales) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)\n zeros = tl.load(zeros_ptrs + g_idx[:, None] * stride_zeros) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)\n\n zeros = (zeros >> zeros_shifter[None, :]) & maxq\n zeros = (zeros + 1)\n\n a = tl.load(a_ptrs, mask=a_mask, other=0.) # (BLOCK_SIZE_M, BLOCK_SIZE_K)\n b = tl.load(b_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated\n\n # Now we need to unpack b (which is N-bit values) into 32-bit values\n b = (b >> shifter[:, None]) & maxq # Extract the N-bit values\n b = (b - zeros) * scales # Scale and shift\n\n accumulator += tl.dot(a, b)\n a_ptrs += BLOCK_SIZE_K\n b_ptrs += (BLOCK_SIZE_K // infearure_per_bits) * stride_bk\n g_ptrs += BLOCK_SIZE_K\n\n c_ptrs = c_ptr + stride_cm * offs_am[:, None] + stride_cn * offs_bn[None, :]\n c_mask = (offs_am[:, None] < M) & (offs_bn[None, :] < N)\n tl.store(c_ptrs, accumulator, mask=c_mask)\n\n\n@triton.jit\ndef transpose_matmul_248_kernel(a_ptr, b_ptr, c_ptr, scales_ptr, zeros_ptr, g_ptr, M, N, K, bits, maxq, stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn, stride_scales,\n stride_zeros, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr):\n \"\"\"\n Compute the matrix multiplication C = A x B.\n A is of shape (M, N) float16\n B is of shape (K//8, N) int32\n C is of shape (M, K) float16\n scales is of shape (G, N) float16\n zeros is of shape (G, N) float16\n g_ptr is of shape (K) int32 \n \"\"\"\n infearure_per_bits = 32 // bits\n\n pid = tl.program_id(axis=0)\n num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)\n num_pid_k = tl.cdiv(K, BLOCK_SIZE_K)\n num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)\n num_pid_in_group = GROUP_SIZE_M * num_pid_k\n group_id = pid // num_pid_in_group\n first_pid_m = group_id * GROUP_SIZE_M\n group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)\n pid_m = first_pid_m + (pid % group_size_m)\n pid_k = (pid % num_pid_in_group) // group_size_m\n\n offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_bk = pid_k * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)\n offs_n = tl.arange(0, BLOCK_SIZE_N)\n a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_n[None, :] * stride_ak) # (BLOCK_SIZE_M, BLOCK_SIZE_N)\n a_mask = (offs_am[:, None] < M)\n # b_ptrs is set up such that it repeats elements along the K axis 8 times\n b_ptrs = b_ptr + ((offs_bk[:, None] // infearure_per_bits) * stride_bk + offs_n[None, :] * stride_bn) # (BLOCK_SIZE_K, BLOCK_SIZE_N)\n g_ptrs = g_ptr + offs_bk\n g_idx = tl.load(g_ptrs)\n\n # shifter is used to extract the N bits of each element in the 32-bit word from B\n scales_ptrs = scales_ptr + offs_n[None, :] + g_idx[:, None] * stride_scales\n zeros_ptrs = zeros_ptr + (offs_n[None, :] // infearure_per_bits) + g_idx[:, None] * stride_zeros\n\n shifter = (offs_bk % infearure_per_bits) * bits\n zeros_shifter = (offs_n % infearure_per_bits) * bits\n accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_K), dtype=tl.float32)\n\n for n in range(0, num_pid_n):\n # Fetch scales and zeros; these are per-outfeature and thus reused in the inner loop\n scales = tl.load(scales_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)\n zeros = tl.load(zeros_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)\n\n zeros = (zeros >> zeros_shifter[None, :]) & maxq\n zeros = (zeros + 1)\n\n a = tl.load(a_ptrs, mask=a_mask, other=0.) # (BLOCK_SIZE_M, BLOCK_SIZE_N)\n b = tl.load(b_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated\n\n # Now we need to unpack b (which is N-bit values) into 32-bit values\n b = (b >> shifter[:, None]) & maxq # Extract the N-bit values\n b = (b - zeros) * scales # Scale and shift\n b = tl.trans(b)\n\n accumulator += tl.dot(a, b)\n a_ptrs += BLOCK_SIZE_N\n b_ptrs += BLOCK_SIZE_N\n scales_ptrs += BLOCK_SIZE_N\n zeros_ptrs += (BLOCK_SIZE_N // infearure_per_bits)\n\n c_ptrs = c_ptr + stride_cm * offs_am[:, None] + stride_cn * offs_bk[None, :]\n c_mask = (offs_am[:, None] < M) & (offs_bk[None, :] < K)\n tl.store(c_ptrs, accumulator, mask=c_mask)\n\n\ndef matmul248(input, qweight, scales, qzeros, g_idx, bits, maxq):\n with torch.cuda.device(input.device):\n output = torch.empty((input.shape[0], qweight.shape[1]), device=input.device, dtype=torch.float16)\n grid = lambda META: (triton.cdiv(input.shape[0], META['BLOCK_SIZE_M']) * triton.cdiv(qweight.shape[1], META['BLOCK_SIZE_N']), )\n matmul_248_kernel[grid](input, qweight, output, scales, qzeros, g_idx, input.shape[0], qweight.shape[1], input.shape[1], bits, maxq, input.stride(0), input.stride(1), qweight.stride(0),\n qweight.stride(1), output.stride(0), output.stride(1), scales.stride(0), qzeros.stride(0))\n return output\n\n\ndef transpose_matmul248(input, qweight, scales, qzeros, g_idx, bits, maxq):\n with torch.cuda.device(input.device):\n output_dim = (qweight.shape[0] * 32) // bits\n output = torch.empty((input.shape[0], output_dim), device=input.device, dtype=torch.float16)\n grid = lambda META: (triton.cdiv(input.shape[0], META['BLOCK_SIZE_M']) * triton.cdiv(output_dim, META['BLOCK_SIZE_K']), )\n transpose_matmul_248_kernel[grid](input, qweight, output, scales, qzeros, g_idx, input.shape[0], qweight.shape[1], output_dim, bits, maxq, input.stride(0), input.stride(1), qweight.stride(0),\n qweight.stride(1), output.stride(0), output.stride(1), scales.stride(0), qzeros.stride(0))\n return output\n", - "description_1": "Use triton language to create two kernels: `matmul_248_kernel` and `transpose_matmul_248_kernel`. `matmul_248_kernel` performs matrix multiplication where matrix A is of shape (M, K) with float16 data type, matrix B is of shape (K//8, N) with int32 data type, and matrix C is the resulting matrix of shape (M, N) with float16 data type. The function takes several additional parameters including pointers to scales, zeros, a group index, matrix dimensions M, N, K, number of bits, maximum quantization value, and various stride values. Similarly, `transpose_matmul_248_kernel` performs matrix multiplication where A is of shape (M, N), B of shape (K//8, N) and C of shape (M, K), under the same data type conditions, with a similar set of parameters. Both kernels involve bit manipulations and dot products to perform the operations efficiently on a GPU.", - "description_2": "Use triton language to implement `matmul248` function that utilizes `matmul_248_kernel` to execute optimized matrix multiplication on GPU. Another function, `transpose_matmul248`, employs `transpose_matmul_248_kernel` to perform transposed matrix multiplication, supporting different configurations through parameters.", - "difficulty": 4 - }, - { - "code": "import triton\nimport triton.language as tl\nimport torch\n\n@triton.jit\ndef triton_kernel(input_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr):\n # Compute the program ID\n pid = tl.program_id(0)\n # Compute the start index for this program\n start = pid * BLOCK_SIZE\n # Create a range of indices for this program\n offsets = start + tl.arange(0, BLOCK_SIZE)\n # Load input data\n input_data = tl.load(input_ptr + offsets, mask=offsets < n_elements, other=0.0)\n # Perform computation (e.g., element-wise addition)\n output_data = input_data + 1.0\n # Store the result\n tl.store(output_ptr + offsets, output_data, mask=offsets < n_elements)\n\ndef call_triton_kernel(input_tensor, output_tensor):\n # Define the block size\n BLOCK_SIZE = 1024\n # Get the number of elements\n n_elements = input_tensor.numel()\n # Launch the Triton kernel\n grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)\n triton_kernel[grid](input_tensor, output_tensor, n_elements, BLOCK_SIZE)\n\n# Example usage\ninput_tensor = torch.randn(10240, device='cuda')\noutput_tensor = torch.empty_like(input_tensor)\ncall_triton_kernel(input_tensor, output_tensor)\n", - "description_1": "Use triton language to define a kernel that performs element-wise addition on an input tensor. The kernel is launched with a specified block size and computes the result for each block of data. The kernel takes three parameters: input_ptr (pointer to input data), output_ptr (pointer to output data), and n_elements (number of elements to process). The block size is defined as a constexpr parameter.", - "description_2": "Use triton language to define a kernel that performs element-wise addition on an input tensor. Launch the kernel with a specified block size and compute the result for each block of data.", - "difficulty": 2 - }, - { - "code": "import triton\nimport triton.language as tl\nfrom torch._inductor.ir import ReductionHint\nfrom torch._inductor.triton_heuristics import reduction\nfrom torch._inductor import triton_helpers\n\n\n@triton.jit\ndef triton_reduce(x_ptr, y_ptr, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):\n xindex = tl.arange(0, XBLOCK).to(tl.int32)\n xoffset = tl.program_id(0) * XBLOCK\n x = xoffset + xindex\n rbase = tl.arange(0, RBLOCK).to(tl.int32)\n for roffset in range(0, rnumel, RBLOCK):\n r = rbase + roffset\n rmask = r < rnumel\n xmask = x < xnumel\n\n acc = tl.load(x_ptr + (x * rnumel + r), xmask & rmask, eviction_policy='evict_last')\n\n for r_ in range(1, rnumel // RBLOCK):\n r = rbase + (roffset + r_ * RBLOCK)\n rmask = r < rnumel\n xmask = x < xnumel\n a = tl.load(x_ptr + (x * rnumel + r), xmask & rmask, eviction_policy='evict_last')\n acc += a\n\n y = roffset // RBLOCK\n ymask = (y < rnumel)\n y_ptr[x] = tl.where(ymask, acc, 0.0)\n\n\ndef call_reduce(x_ptr, y_ptr, xnumel, rnumel):\n XBLOCK = 128\n RBLOCK = 32\n grid = (xnumel + XBLOCK - 1) // XBLOCK\n triton_reduce[(grid,)](x_ptr, y_ptr, xnumel, rnumel, XBLOCK=XBLOCK, RBLOCK=RBLOCK)\n", - "description_1": "Use triton language to create a kernel function `triton_reduce` that performs a reduction operation. It accepts four input arguments: `x_ptr`, `y_ptr`, `xnumel`, and `rnumel`. The kernel divides the work into blocks using `XBLOCK` and `RBLOCK` as block sizes. The data from the `x_ptr` is iteratively reduced across a specified dimension and the results are stored into `y_ptr`. Additionally, a call function `call_reduce` is provided to execute the kernel with appropriate grid and block settings.", - "description_2": "Use triton language to perform a reduction operation on a multi-dimensional tensor, iterating over a specified dimension and storing the results.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef dequant_kernel_248(\n g_idx_ptr,\n scales_ptr,\n qweight_ptr,\n qzeros_ptr,\n out_ptr,\n numels,\n maxq: tl.constexpr,\n bits: tl.constexpr,\n outfeatures: tl.constexpr,\n num_groups: tl.constexpr,\n X_BLOCK: tl.constexpr,\n):\n # Block indexing\n xoffset = tl.program_id(0) * X_BLOCK\n x_index = xoffset + tl.arange(0, X_BLOCK)\n xmask = x_index < numels\n row_idx = x_index // outfeatures\n col_idx = x_index % outfeatures\n\n elements_per_feature: tl.constexpr = 32 // bits\n\n # Load parameters\n g_idx = tl.load(g_idx_ptr + (row_idx), None, eviction_policy=\"evict_last\")\n qweights = tl.load(\n qweight_ptr + (col_idx + (outfeatures * (row_idx // elements_per_feature))),\n None,\n )\n\n wf_weights = (row_idx % elements_per_feature) * bits\n\n wf_zeros = (col_idx % elements_per_feature) * bits\n\n tmp1 = g_idx + num_groups\n tmp2 = g_idx < 0\n tl.device_assert(g_idx >= 0, \"index out of bounds: 0 <= tmp0 < 0\")\n groups = tl.where(tmp2, tmp1, g_idx) # tmp3 are g_idx\n\n scales = tl.load(scales_ptr + (col_idx + (outfeatures * groups)), None).to(\n tl.float32\n )\n\n # Unpack weights\n weights = qweights >> wf_weights # bit shift qweight\n\n weights = weights & maxq\n\n # Unpack zeros\n qzero_ncols: tl.constexpr = outfeatures // elements_per_feature\n qzeros = tl.load(\n qzeros_ptr + ((qzero_ncols * groups) + (col_idx // elements_per_feature)),\n None,\n eviction_policy=\"evict_last\",\n )\n zeros = qzeros >> wf_zeros\n zeros = zeros & maxq\n\n # Dequantize\n zeros = zeros + 1\n weights = weights - zeros\n weights = weights.to(tl.float32)\n weights = scales * weights\n\n tl.store(out_ptr + (x_index), weights, mask=xmask)\n\n\ndef dequant248(qweight, scales, qzeros, g_idx, bits, maxq=None):\n \"\"\"\n Launcher for triton dequant kernel. Only valid for bits = 2, 4, 8\n \"\"\"\n\n num_groups = scales.shape[0]\n outfeatures = scales.shape[1]\n infeatures = g_idx.shape[0]\n\n out = torch.empty((infeatures, outfeatures), device=\"cuda\", dtype=torch.float16)\n numels = out.numel()\n maxq = 2**bits - 1 if maxq is None else maxq\n grid = lambda meta: (triton.cdiv(numels, meta[\"X_BLOCK\"]),) # noqa: E731\n\n dequant_kernel_248[grid](\n g_idx,\n scales,\n qweight,\n qzeros,\n out,\n numels,\n maxq=maxq,\n bits=bits,\n outfeatures=outfeatures,\n num_groups=num_groups,\n )\n return out\n", - "description_1": "Use triton language to implement a dequantization kernel that processes quantized weights, scales, and zero points to produce dequantized weights. The kernel takes 11 parameters: pointers to group indices, scales, quantized weights, zero points, and output; the number of elements; maximum quantization value; bit width; number of output features; number of groups; and block size. The dequant248 function launches this kernel with 7 parameters: quantized weights, scales, zero points, group indices, bit width, maximum quantization value, and returns the dequantized output.", - "description_2": "Use triton language to create a kernel for dequantizing weights with given scales and zero points, and a function to launch this kernel.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef gemm_kernel(a_ptr, b_ptr, c_ptr, M, N, K, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr):\n mid = tl.program_id(0)\n nid = tl.program_id(1)\n # Starting row + BLOCK_SIZE_M more rows\n\n a_rows = mid * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n # Starting col + BLOCK_SIZE_N more columns\n b_cols = nid * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n\n a_ptrs = a_ptr + a_rows[:, None] * K + tl.arange(0, BLOCK_SIZE_K)[None, :]\n b_ptrs = b_ptr + tl.arange(0, BLOCK_SIZE_K)[:, None] * N + b_cols[None, :]\n\n c = tl.zeros([BLOCK_SIZE_M, BLOCK_SIZE_N], dtype=tl.float32)\n for k in range(K//BLOCK_SIZE_K):\n a = tl.load(a_ptrs)\n b = tl.load(b_ptrs)\n c += tl.dot(a, b)\n a_ptrs += BLOCK_SIZE_K\n b_ptrs += BLOCK_SIZE_K * N\n\n c = c.to(tl.float16)\n\n # C's block's offsets\n c_ptrs = a_rows[:, None] * N + b_cols[None, :]\n tl.store(c_ptr+ c_ptrs, c)\n\ndef gemm(a, b):\n c = torch.empty([M, N], device=a.device, dtype=a.dtype)\n grid = lambda META: (\n triton.cdiv(M, META['BLOCK_SIZE_M']), triton.cdiv(N, META['BLOCK_SIZE_N']),\n )\n gemm_kernel[grid](a, b, c, M, N, K)\n return c\n\n@triton.jit\ndef _zp_dequant_kernel(\n Q, Out,\n scales_ptr, zeros_ptr,\n stride_qk, stride_qn,\n stride_ok, stride_on,\n stride_scales_g, stride_scales_n,\n stride_zeros_g, stride_zeros_n,\n groupsize,\n BLOCK_SIZE_N: tl.constexpr,\n):\n \"\"\"\n Dequant qweight to output matrix.\n Q is of shape (K//8, N) int32\n Out is of shape (K, N) float16\n scales is of shape (G, N) float16, where G is K // groupsize\n zeros is of shape (G, N//8) int32\n \"\"\"\n pid_k = tl.program_id(axis=0)\n pid_n = tl.program_id(axis=1)\n gid = pid_k // groupsize\n\n offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n\n # pointers\n offs_q = (pid_k // 8) * stride_qk + offs_n * stride_qn\n offs_scales = gid * stride_scales_g + offs_n * stride_scales_n\n offs_zeros = gid * stride_zeros_g + (offs_n // 8) * stride_zeros_n\n\n # shifter\n shifter = (pid_k % 8) * 4\n zeros_shifter = (offs_n % 8) * 4\n\n # load\n weight = tl.load(Q + offs_q)\n scales = tl.load(scales_ptr + offs_scales)\n zeros = tl.load(zeros_ptr + offs_zeros).to(dtype=tl.int32)\n\n # unpack weight and zeros\n weight = (weight >> shifter) & 0xF\n zeros = (zeros >> zeros_shifter) & 0xF\n zeros = (zeros + 1)\n\n # dequant weight\n weight = (weight - zeros) * scales\n\n # store the result\n offs_o = pid_k * stride_ok + offs_n * stride_on\n tl.store(Out + offs_o, weight)\n\ndef w4a16_matmul(x, w, qweight, scales, qzeros, group_size):\n block_size_n=128\n K = x.shape[1]\n N = qweight.shape[1]\n\n # shape constraints\n assert x.shape[-1] == (qweight.shape[0] * 8), \"Incompatible dimensions\"\n assert x.shape[-1] == w.shape[0], \"Incompatible dimensions\"\n assert w.shape[-1] == qweight.shape[-1], \"Incompatible dimensions\"\n assert K % group_size == 0, \"K must be a multiple of group size\"\n assert N % block_size_n == 0, \"N must be a multiple of block_size_n\"\n\n grid = (K, N // block_size_n)\n\n # dequant qweight to w\n\n _zp_dequant_kernel[grid](\n qweight, w,\n scales, qzeros,\n qweight.stride(0), qweight.stride(1),\n w.stride(0), w.stride(1),\n scales.stride(0), scales.stride(1),\n qzeros.stride(0), qzeros.stride(1),\n group_size,\n block_size_n,\n num_warps=2, num_stages=4,\n )\n c = torch.matmul(x, w)\n return c\n\ndef triton_matmul(a,ref_weight, qweight, scales, qzeros, group_size, stream1, stream2):\n block_size_n = 128\n K = a.shape[1]\n N = qweight.shape[1]\n\n # shape constraints\n assert a.shape[-1] == (qweight.shape[0] * 8), \"Incompatible dimensions\"\n assert a.shape[-1] == ref_weight.shape[0], \"Incompatible dimensions\"\n assert ref_weight.shape[-1] == qweight.shape[-1], \"Incompatible dimensions\"\n assert K % group_size == 0, \"K must be a multiple of group size\"\n assert N % block_size_n == 0, \"N must be a multiple of block_size_n\"\n\n grid = (K, N // block_size_n)\n\n with torch.cuda.stream(stream1):\n _zp_dequant_kernel[grid](\n qweight, ref_weight,\n scales, qzeros,\n qweight.stride(0), qweight.stride(1),\n ref_weight.stride(0), ref_weight.stride(1),\n scales.stride(0), scales.stride(1),\n qzeros.stride(0), qzeros.stride(1),\n group_size,\n block_size_n,\n num_warps=2, num_stages=4\n )\n\n with torch.cuda.stream(stream2):\n torch.matmul(a, ref_weight)\n stream2.wait_stream(stream1)\n", - "description_1": "Use triton language to implement two main functions: `gemm_kernel` for matrix multiplication, handling inputs a, b, and outputs c with triton grid and size management; `_zp_dequant_kernel` for zero-point dequantization, processing quantized input matrices and scales, zeros for dequantized output. These kernels are wrapped in Python functions for matrix operations, including stream-based computations.", - "description_2": "Use triton language to create kernels for matrix multiplication and zero-point dequantization, interfaced with Python functions for matrix computations.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\nDEFAULT_DEQUANT_CONFIGS = [\n triton.Config({\"X_BLOCK\": bs}, num_warps=ws)\n for bs, ws in itertools.product([128, 256, 512, 1024], [4, 8])\n]\n\n@triton.autotune(DEFAULT_DEQUANT_CONFIGS, key=[\"numels\"])\n@triton.jit\ndef dequant_kernel_248(\n g_idx_ptr,\n scales_ptr,\n qweight_ptr,\n qzeros_ptr,\n out_ptr,\n numels,\n maxq: tl.constexpr,\n bits: tl.constexpr,\n outfeatures: tl.constexpr,\n num_groups: tl.constexpr,\n X_BLOCK: tl.constexpr,\n):\n # Triton kernel for dequantization\n xoffset = tl.program_id(0) * X_BLOCK\n x_index = xoffset + tl.arange(0, X_BLOCK)\n xmask = x_index < numels\n row_idx = x_index // outfeatures\n col_idx = x_index % outfeatures\n\n elements_per_feature: tl.constexpr = 32 // bits\n\n # Load parameters\n g_idx = tl.load(g_idx_ptr + (row_idx), None, eviction_policy=\"evict_last\")\n qweights = tl.load(\n qweight_ptr + (col_idx + (outfeatures * (row_idx // elements_per_feature))),\n None,\n )\n\n wf_weights = (row_idx % elements_per_feature) * bits\n\n wf_zeros = (col_idx % elements_per_feature) * bits\n\n tmp1 = g_idx + num_groups\n tmp2 = g_idx < 0\n tl.device_assert(g_idx >= 0, \"index out of bounds: 0 <= tmp0 < 0\")\n groups = tl.where(tmp2, tmp1, g_idx) # tmp3 are g_idx\n\n scales = tl.load(scales_ptr + (col_idx + (outfeatures * groups)), None).to(\n tl.float32\n )\n\n # Unpack weights\n weights = qweights >> wf_weights # bit shift qweight\n\n weights = weights & maxq\n\n # Unpack zeros\n qzero_ncols: tl.constexpr = outfeatures // elements_per_feature\n qzeros = tl.load(\n qzeros_ptr + ((qzero_ncols * groups) + (col_idx // elements_per_feature)),\n None,\n eviction_policy=\"evict_last\",\n )\n zeros = qzeros >> wf_zeros\n zeros = zeros & maxq\n\n # Dequantize\n zeros = zeros + 1\n weights = weights - zeros\n weights = weights.to(tl.float32)\n weights = scales * weights\n\n tl.store(out_ptr + (x_index), weights, mask=xmask)\n\n\ndef dequant248(qweight, scales, qzeros, g_idx, bits, maxq=None):\n \"\"\"\n Launcher for triton dequant kernel. Only valid for bits = 2, 4, 8\n \"\"\"\n num_groups = scales.shape[0]\n outfeatures = scales.shape[1]\n infeatures = g_idx.shape[0]\n\n out = torch.empty((infeatures, outfeatures), device=\"cuda\", dtype=torch.float16)\n numels = out.numel()\n maxq = 2**bits - 1 if maxq is None else maxq\n grid = lambda meta: (triton.cdiv(numels, meta[\"X_BLOCK\"]),)\n\n dequant_kernel_248[grid](\n g_idx,\n scales,\n qweight,\n qzeros,\n out,\n numels,\n maxq=maxq,\n bits=bits,\n outfeatures=outfeatures,\n num_groups=num_groups,\n )\n return out\n", - "description_1": "Use triton language to create a dequantization kernel, dequant_kernel_248, which takes 10 parameters: g_idx_ptr (global memory pointer to group indices), scales_ptr (global memory pointer to scale factors), qweight_ptr (global memory pointer to quantized weights), qzeros_ptr (global memory pointer to zero points), out_ptr (global memory pointer to output), numels (number of elements to process), maxq (constant, maximum quantized value), bits (constant, number of bits used in quantization), outfeatures (constant, number of output features), num_groups (constant, number of groups), and X_BLOCK (constant, size of each block of elements to process). The function dequant248 is used to launch this kernel and performs dequantization on given quantized data using these parameters. It computes the grid configuration and calls the Triton kernel with necessary parameters for dequantization.", - "description_2": "Use triton language to implement a kernel for dequantizing quantized weights with group indices and scales, and provide a launcher function to configure and run this kernel.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef _zp_dequant_kernel(\n Q, Out,\n scales_ptr, zeros_ptr,\n stride_qk, stride_qn,\n stride_ok, stride_on,\n stride_scales_g, stride_scales_n,\n stride_zeros_g, stride_zeros_n,\n groupsize,\n BLOCK_SIZE_N: tl.constexpr,\n):\n \"\"\"\n Dequant qweight to output matrix.\n Q is of shape (K//8, N) int32\n Out is of shape (K, N) float16\n scales is of shape (G, N) float16, where G is K // groupsize\n zeros is of shape (G, N//8) int32\n \"\"\"\n pid_k = tl.program_id(axis=0)\n pid_n = tl.program_id(axis=1)\n gid = pid_k // groupsize\n\n offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n\n # pointers\n offs_q = (pid_k // 8) * stride_qk + offs_n * stride_qn\n offs_scales = gid * stride_scales_g + offs_n * stride_scales_n\n offs_zeros = gid * stride_zeros_g + (offs_n // 8) * stride_zeros_n\n\n # shifter\n shifter = (pid_k % 8) * 4\n zeros_shifter = (offs_n % 8) * 4\n\n # load\n weight = tl.load(Q + offs_q)\n scales = tl.load(scales_ptr + offs_scales)\n zeros = tl.load(zeros_ptr + offs_zeros)\n\n # unpack weight and zeros\n weight = (weight >> shifter) & 0xF\n zeros = (zeros >> zeros_shifter) & 0xF\n zeros = (zeros + 1)\n\n # dequant weight\n weight = (weight - zeros) * scales\n\n # store the result\n offs_o = pid_k * stride_ok + offs_n * stride_on\n tl.store(Out + offs_o, weight)\n\n@triton.jit\ndef _sym_dequant_kernel(\n Q, Out,\n scales_ptr,\n ZERO,\n stride_qk, stride_qn,\n stride_ok, stride_on,\n stride_scales_g, stride_scales_n,\n groupsize,\n BLOCK_SIZE_N: tl.constexpr,\n):\n \"\"\"\n Dequant qweight to output matrix.\n Q is of shape (K//8, N) int32\n Out is of shape (K, N) float16\n scales is of shape (G, N) float16, where G is K // groupsize\n ZERO is 8, where 2 ** (bits-1) = 8\n \"\"\"\n pid_k = tl.program_id(axis=0)\n pid_n = tl.program_id(axis=1)\n gid = pid_k // groupsize\n\n offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n\n # pointers\n offs_q = (pid_k // 8) * stride_qk + offs_n * stride_qn\n offs_scales = gid * stride_scales_g + offs_n * stride_scales_n\n\n # shifter\n shifter = (pid_k % 8) * 4\n\n # load\n weight = tl.load(Q + offs_q)\n scales = tl.load(scales_ptr + offs_scales)\n\n # unpack weight and zeros\n weight = (weight >> shifter) & 0xF\n\n # dequant weight\n weight = (weight - ZERO) * scales\n\n # store the result\n offs_o = pid_k * stride_ok + offs_n * stride_on\n tl.store(Out + offs_o, weight)\n\n@triton.autotune(\n configs=[\n triton.Config({\n 'BLOCK_SIZE_M': 128,\n 'BLOCK_SIZE_N': 128,\n 'BLOCK_SIZE_K': 32,\n 'NUM_SM': 84,\n }),\n triton.Config({\n 'BLOCK_SIZE_M': 128,\n 'BLOCK_SIZE_N': 128,\n 'BLOCK_SIZE_K': 32,\n 'NUM_SM': 128,\n }),\n triton.Config({\n 'BLOCK_SIZE_M': 64,\n 'BLOCK_SIZE_N': 64,\n 'BLOCK_SIZE_K': 32,\n 'NUM_SM': 84,\n }),\n triton.Config({\n 'BLOCK_SIZE_M': 64,\n 'BLOCK_SIZE_N': 64,\n 'BLOCK_SIZE_K': 32,\n 'NUM_SM': 128,\n }),\n ],\n key=['group_size'],\n)\n@triton.jit\ndef grouped_matmul_kernel(\n # device tensor of matrices pointers\n group_a_ptrs,\n group_b_ptrs,\n group_c_ptrs,\n # device tensor of gemm sizes. its shape is [group_size, 3]\n # dim 0 is group_size, dim 1 is the values of of each gemm\n group_gemm_sizes,\n # device tensor of leading dimension sizes. its shape is [group_size, 3]\n # dim 0 is group_size, dim 1 is the values of of each gemm\n g_lds,\n # number of gemms\n group_size,\n # number of virtual SM\n NUM_SM: tl.constexpr,\n # tile sizes\n BLOCK_SIZE_M: tl.constexpr,\n BLOCK_SIZE_N: tl.constexpr,\n BLOCK_SIZE_K: tl.constexpr,\n):\n tile_idx = tl.program_id(0)\n last_problem_end = 0\n for g in range(group_size):\n # get the gemm size of the current problem\n gm = tl.load(group_gemm_sizes + g * 3)\n gn = tl.load(group_gemm_sizes + g * 3 + 1)\n gk = tl.load(group_gemm_sizes + g * 3 + 2)\n num_m_tiles = tl.cdiv(gm, BLOCK_SIZE_M)\n num_n_tiles = tl.cdiv(gn, BLOCK_SIZE_N)\n num_tiles = num_m_tiles * num_n_tiles\n # iterate through the tiles in the current gemm problem\n while (tile_idx >= last_problem_end and tile_idx < last_problem_end + num_tiles):\n # pick up a tile from the current gemm problem\n k = gk\n lda = tl.load(g_lds + g * 3)\n ldb = tl.load(g_lds + g * 3 + 1)\n ldc = tl.load(g_lds + g * 3 + 2)\n a_ptr = tl.load(group_a_ptrs + g).to(tl.pointer_type(tl.float16))\n b_ptr = tl.load(group_b_ptrs + g).to(tl.pointer_type(tl.float16))\n c_ptr = tl.load(group_c_ptrs + g).to(tl.pointer_type(tl.float16))\n # figure out tile coordinates\n tile_idx_in_gemm = tile_idx - last_problem_end\n tile_m_idx = tile_idx_in_gemm // num_n_tiles\n tile_n_idx = tile_idx_in_gemm % num_n_tiles\n\n # do regular gemm here\n offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n offs_k = tl.arange(0, BLOCK_SIZE_K)\n a_ptrs = a_ptr + offs_am[:, None] * lda + offs_k[None, :]\n b_ptrs = b_ptr + offs_k[:, None] * ldb + offs_bn[None, :]\n accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n for kk in range(0, tl.cdiv(k, BLOCK_SIZE_K)):\n # hint to Triton compiler to do proper loop pipelining\n tl.multiple_of(a_ptrs, [16, 16])\n tl.multiple_of(b_ptrs, [16, 16])\n # assume full tile for now\n a = tl.load(a_ptrs)\n b = tl.load(b_ptrs)\n accumulator += tl.dot(a, b)\n a_ptrs += BLOCK_SIZE_K\n b_ptrs += BLOCK_SIZE_K * ldb\n c = accumulator.to(tl.float16)\n\n offs_cm = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_cn = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n c_ptrs = c_ptr + ldc * offs_cm[:, None] + offs_cn[None, :]\n\n # assumes full tile for now\n tl.store(c_ptrs, c)\n\n # go to the next tile by advancing NUM_SM\n tile_idx += NUM_SM\n\n # get ready to go to the next gemm problem\n last_problem_end = last_problem_end + num_tiles\n\n\ndef group_gemm_fn(group_A, group_B):\n device = torch.device('cuda')\n assert len(group_A) == len(group_B)\n group_size = len(group_A)\n\n A_addrs = []\n B_addrs = []\n C_addrs = []\n g_sizes = []\n g_lds = []\n group_C = []\n for i in range(group_size):\n A = group_A[i]\n B = group_B[i]\n assert A.shape[1] == B.shape[0]\n M, K = A.shape\n K, N = B.shape\n C = torch.empty((M, N), device=device, dtype=A.dtype)\n group_C.append(C)\n A_addrs.append(A.data_ptr())\n B_addrs.append(B.data_ptr())\n C_addrs.append(C.data_ptr())\n g_sizes += [M, N, K]\n g_lds += [A.stride(0), B.stride(0), C.stride(0)]\n\n # note these are device tensors\n d_a_ptrs = torch.tensor(A_addrs, device=device)\n d_b_ptrs = torch.tensor(B_addrs, device=device)\n d_c_ptrs = torch.tensor(C_addrs, device=device)\n d_g_sizes = torch.tensor(g_sizes, dtype=torch.int32, device=device)\n d_g_lds = torch.tensor(g_lds, dtype=torch.int32, device=device)\n # we use a fixed number of CTA, and it's auto-tunable\n grid = lambda META: (META['NUM_SM'], )\n grouped_matmul_kernel[grid](\n d_a_ptrs,\n d_b_ptrs,\n d_c_ptrs,\n d_g_sizes,\n d_g_lds,\n group_size,\n )\n\n return group_C\n", - "description_1": "Use triton language to implement three kernels: _zp_dequant_kernel, _sym_dequant_kernel, and grouped_matmul_kernel. The _zp_dequant_kernel takes 13 parameters including Q, Out, scales_ptr, zeros_ptr, and others to dequantize a quantized weight matrix to an output matrix. The _sym_dequant_kernel takes 12 parameters including Q, Out, scales_ptr, ZERO, and others to perform symmetric dequantization of a quantized weight matrix. The grouped_matmul_kernel takes 10 parameters including group_a_ptrs, group_b_ptrs, group_c_ptrs, group_gemm_sizes, g_lds, group_size, and others to perform grouped matrix multiplication with autotuning capabilities.", - "description_2": "Use triton language to create kernels for dequantizing quantized matrices and performing grouped matrix multiplication with autotuning.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef gemm_kernel(a_ptr, b_ptr, c_ptr, M, N, K, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_M: tl.constexpr,\n BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr):\n mid = tl.program_id(0)\n nid = tl.program_id(1)\n # Starting row + BLOCK_SIZE_M more rows\n\n a_rows = mid * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n # Starting col + BLOCK_SIZE_N more columns\n b_cols = nid * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n\n a_ptrs = a_ptr + a_rows[:, None] * K + tl.arange(0, BLOCK_SIZE_K)[None, :]\n b_ptrs = b_ptr + tl.arange(0, BLOCK_SIZE_K)[:, None] * N + b_cols[None, :]\n\n c = tl.zeros([BLOCK_SIZE_M, BLOCK_SIZE_N], dtype=tl.float32)\n for k in range(K // BLOCK_SIZE_K):\n a = tl.load(a_ptrs)\n b = tl.load(b_ptrs)\n c += tl.dot(a, b)\n a_ptrs += BLOCK_SIZE_K\n b_ptrs += BLOCK_SIZE_K * N\n\n c = c.to(tl.float16)\n\n # C's block's offsets\n c_ptrs = a_rows[:, None] * N + b_cols[None, :]\n tl.store(c_ptr + c_ptrs, c)\n\n\ndef gemm(a, b):\n c = torch.empty([M, N], device=a.device, dtype=a.dtype)\n grid = lambda META: (\n triton.cdiv(M, META['BLOCK_SIZE_M']), triton.cdiv(N, META['BLOCK_SIZE_N']),\n )\n gemm_kernel[grid](a, b, c, M, N, K)\n return c\n\n\n@triton.jit\ndef _zp_dequant_kernel(\n Q, Out,\n scales_ptr, zeros_ptr,\n stride_qk, stride_qn,\n stride_ok, stride_on,\n stride_scales_g, stride_scales_n,\n stride_zeros_g, stride_zeros_n,\n groupsize,\n BLOCK_SIZE_N: tl.constexpr,\n):\n \"\"\"\n Dequant qweight to output matrix.\n Q is of shape (K//8, N) int32\n Out is of shape (K, N) float16\n scales is of shape (G, N) float16, where G is K // groupsize\n zeros is of shape (G, N//8) int32\n \"\"\"\n pid_k = tl.program_id(axis=0)\n pid_n = tl.program_id(axis=1)\n gid = pid_k // groupsize\n\n offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n\n # pointers\n offs_q = (pid_k // 8) * stride_qk + offs_n * stride_qn\n offs_scales = gid * stride_scales_g + offs_n * stride_scales_n\n offs_zeros = gid * stride_zeros_g + (offs_n // 8) * stride_zeros_n\n\n # shifter\n shifter = (pid_k % 8) * 4\n zeros_shifter = (offs_n % 8) * 4\n\n # load\n weight = tl.load(Q + offs_q)\n scales = tl.load(scales_ptr + offs_scales)\n zeros = tl.load(zeros_ptr + offs_zeros).to(dtype=tl.int32)\n\n # unpack weight and zeros\n weight = (weight >> shifter) & 0xF\n zeros = (zeros >> zeros_shifter) & 0xF\n zeros = (zeros + 1)\n\n # dequant weight\n weight = (weight - zeros) * scales\n\n # store the result\n offs_o = pid_k * stride_ok + offs_n * stride_on\n tl.store(Out + offs_o, weight)\n\n\ndef w4a16_matmul(x, w, qweight, scales, qzeros, group_size):\n block_size_n = 128\n K = x.shape[1]\n N = qweight.shape[1]\n\n # shape constraints\n assert x.shape[-1] == (qweight.shape[0] * 8), \"Incompatible dimensions\"\n assert x.shape[-1] == w.shape[0], \"Incompatible dimensions\"\n assert w.shape[-1] == qweight.shape[-1], \"Incompatible dimensions\"\n assert K % group_size == 0, \"K must be a multiple of group size\"\n assert N % block_size_n == 0, \"N must be a multiple of block_size_n\"\n\n grid = (K, N // block_size_n)\n\n # dequant qweight to w\n _zp_dequant_kernel[grid](\n qweight, w,\n scales, qzeros,\n qweight.stride(0), qweight.stride(1),\n w.stride(0), w.stride(1),\n scales.stride(0), scales.stride(1),\n qzeros.stride(0), qzeros.stride(1),\n group_size,\n block_size_n,\n num_warps=2, num_stages=4,\n )\n c = torch.matmul(x, w)\n\n return c\n\n\ndef triton_matmul(a, ref_weight, qweight, scales, qzeros, group_size, stream1, stream2):\n block_size_n = 128\n K = a.shape[1]\n N = qweight.shape[1]\n\n # shape constraints\n assert a.shape[-1] == (qweight.shape[0] * 8), \"Incompatible dimensions\"\n assert a.shape[-1] == ref_weight.shape[0], \"Incompatible dimensions\"\n assert ref_weight.shape[-1] == qweight.shape[-1], \"Incompatible dimensions\"\n assert K % group_size == 0, \"K must be a multiple of group size\"\n assert N % block_size_n == 0, \"N must be a multiple of block_size_n\"\n\n grid = (K, N // block_size_n)\n\n with torch.cuda.stream(stream1):\n _zp_dequant_kernel[grid](\n qweight, ref_weight,\n scales, qzeros,\n qweight.stride(0), qweight.stride(1),\n ref_weight.stride(0), ref_weight.stride(1),\n scales.stride(0), scales.stride(1),\n qzeros.stride(0), qzeros.stride(1),\n group_size,\n block_size_n,\n num_warps=2, num_stages=4\n )\n\n with torch.cuda.stream(stream2):\n torch.matmul(a, ref_weight)\n stream2.wait_stream(stream1)\n", - "description_1": "Use triton language to implement a GEMM kernel with parameters for pointers to matrices A, B, C, dimensions M, N, K, and block sizes. Implement a dequantization kernel for quantized weights with parameters for quantized matrix Q, output matrix, scales, zeros, strides, group size, and block size. Provide functions to call these kernels and perform matrix multiplication with dequantization.", - "description_2": "Use triton language to implement a GEMM kernel and a dequantization kernel for quantized weights, and provide functions to perform matrix multiplication with these kernels.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef gemm_kernel(a_ptr, b_ptr, c_ptr, M, N, K, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr):\n mid = tl.program_id(0)\n nid = tl.program_id(1)\n # Starting row + BLOCK_SIZE_M more rows\n\n a_rows = mid * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n # Starting col + BLOCK_SIZE_N more columns\n b_cols = nid * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n\n a_ptrs = a_ptr + a_rows[:, None] * K + tl.arange(0, BLOCK_SIZE_K)[None, :]\n b_ptrs = b_ptr + tl.arange(0, BLOCK_SIZE_K)[:, None] * N + b_cols[None, :]\n\n c = tl.zeros([BLOCK_SIZE_M, BLOCK_SIZE_N], dtype=tl.float32)\n for k in range(K//BLOCK_SIZE_K):\n a = tl.load(a_ptrs)\n b = tl.load(b_ptrs)\n c += tl.dot(a, b)\n a_ptrs += BLOCK_SIZE_K\n b_ptrs += BLOCK_SIZE_K * N\n\n c = c.to(tl.float16)\n\n # C's block's offsets\n c_ptrs = a_rows[:, None] * N + b_cols[None, :]\n tl.store(c_ptr+ c_ptrs, c)\n\n@triton.jit\ndef _zp_dequant_kernel(\n Q, Out,\n scales_ptr, zeros_ptr,\n stride_qk, stride_qn,\n stride_ok, stride_on,\n stride_scales_g, stride_scales_n,\n stride_zeros_g, stride_zeros_n,\n groupsize,\n BLOCK_SIZE_N: tl.constexpr,\n):\n \"\"\"\n Dequant qweight to output matrix.\n Q is of shape (K//8, N) int32\n Out is of shape (K, N) float16\n scales is of shape (G, N) float16, where G is K // groupsize\n zeros is of shape (G, N//8) int32\n \"\"\"\n pid_k = tl.program_id(axis=0)\n pid_n = tl.program_id(axis=1)\n gid = pid_k // groupsize\n\n offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n\n # pointers\n offs_q = (pid_k // 8) * stride_qk + offs_n * stride_qn\n offs_scales = gid * stride_scales_g + offs_n * stride_scales_n\n offs_zeros = gid * stride_zeros_g + (offs_n // 8) * stride_zeros_n\n\n # shifter\n shifter = (pid_k % 8) * 4\n zeros_shifter = (offs_n % 8) * 4\n\n # load\n weight = tl.load(Q + offs_q)\n scales = tl.load(scales_ptr + offs_scales)\n zeros = tl.load(zeros_ptr + offs_zeros).to(dtype=tl.int32)\n\n # unpack weight and zeros\n weight = (weight >> shifter) & 0xF\n zeros = (zeros >> zeros_shifter) & 0xF\n zeros = (zeros + 1)\n\n # dequant weight\n weight = (weight - zeros) * scales\n\n # store the result\n offs_o = pid_k * stride_ok + offs_n * stride_on\n tl.store(Out + offs_o, weight)\n\ndef w4a16_matmul(x, w, qweight, scales, qzeros, group_size):\n block_size_n=128\n K = x.shape[1]\n N = qweight.shape[1]\n\n # shape constraints\n assert x.shape[-1] == (qweight.shape[0] * 8), \"Incompatible dimensions\"\n assert x.shape[-1] == w.shape[0], \"Incompatible dimensions\"\n assert w.shape[-1] == qweight.shape[-1], \"Incompatible dimensions\"\n assert K % group_size == 0, \"K must be a multiple of group size\"\n assert N % block_size_n == 0, \"N must be a multiple of block_size_n\"\n\n grid = (K, N // block_size_n)\n\n # dequant qweight to w\n\n _zp_dequant_kernel[grid](\n qweight, w,\n scales, qzeros,\n qweight.stride(0), qweight.stride(1),\n w.stride(0), w.stride(1),\n scales.stride(0), scales.stride(1),\n qzeros.stride(0), qzeros.stride(1),\n group_size,\n block_size_n,\n num_warps=2, num_stages=4,\n )\n c = torch.matmul(x, w)\n return c\n\n@evaluate_kernel(inputs=inputs)\ndef triton_matmul(a,ref_weight, qweight, scales, qzeros, group_size, stream1, stream2):\n block_size_n = 128\n K = a.shape[1]\n N = qweight.shape[1]\n\n # shape constraints\n assert a.shape[-1] == (qweight.shape[0] * 8), \"Incompatible dimensions\"\n assert a.shape[-1] == ref_weight.shape[0], \"Incompatible dimensions\"\n assert ref_weight.shape[-1] == qweight.shape[-1], \"Incompatible dimensions\"\n assert K % group_size == 0, \"K must be a multiple of group size\"\n assert N % block_size_n == 0, \"N must be a multiple of block_size_n\"\n\n grid = (K, N // block_size_n)\n\n with torch.cuda.stream(stream1):\n _zp_dequant_kernel[grid](\n qweight, ref_weight,\n scales, qzeros,\n qweight.stride(0), qweight.stride(1),\n ref_weight.stride(0), ref_weight.stride(1),\n scales.stride(0), scales.stride(1),\n qzeros.stride(0), qzeros.stride(1),\n group_size,\n block_size_n,\n num_warps=2, num_stages=4\n )\n\n with torch.cuda.stream(stream2):\n torch.matmul(a, ref_weight)\n stream2.wait_stream(stream1)\n torch.cuda.synchronize()\n", - "description_1": "Use triton language to implement two kernels: 'gemm_kernel' for general matrix multiplication with parameters for pointers to matrices, dimensions, and block sizes; '_zp_dequant_kernel' for dequantizing a quantized weight matrix with parameters for pointers, strides, group size, and block size. Additionally, implement functions to call these kernels and perform matrix multiplication.", - "description_2": "Use triton language to create a GEMM kernel for matrix multiplication and a dequantization kernel for processing quantized weights, with appropriate function calls for execution.", - "difficulty": 3 - }, - { - "code": "import triton\n\n# Kernel function\n@triton.jit\ndef kernel_function(arg1, arg2):\n # Code logic for the kernel\n pass\n\n# Function calling the Triton kernel\ndef call_kernel():\n # Assuming appropriate grid and stream setup\n kernel_function.run(arg1, arg2, grid=(1,), stream=None)\n", - "description_1": "Use triton language to define a kernel with two arguments, performing some operations. Then, create a function to execute the kernel with those arguments.", - "description_2": "Use triton language to implement and run a simple kernel with two parameters.", - "difficulty": 1 - }, - { - "code": "import triton\nimport triton.language as tl\nimport torch\n\n@triton.jit\ndef add_kernel(X, Y, Z, N):\n pid = tl.program_id(0)\n block_start = pid * 1024\n offsets = block_start + tl.arange(0, 1024)\n mask = offsets < N\n x = tl.load(X + offsets, mask=mask)\n y = tl.load(Y + offsets, mask=mask)\n z = x + y\n tl.store(Z + offsets, z, mask=mask)\n\ndef call_add_kernel(X, Y, Z, N):\n grid = lambda meta: (triton.cdiv(N, 1024),)\n add_kernel[grid](X, Y, Z, N)\n\n# Example usage\nX = torch.randn(1024, device='cuda')\nY = torch.randn(1024, device='cuda')\nZ = torch.empty(1024, device='cuda')\nN = X.numel()\ncall_add_kernel(X, Y, Z, N)\n", - "description_1": "Use triton language to define a kernel 'add_kernel' that takes four arguments: X, Y, Z, and N. The kernel performs element-wise addition of two input tensors X and Y, storing the result in tensor Z. The computation is done in blocks of 1024 elements, and the kernel is launched with a grid size determined by the number of elements N. The function 'call_add_kernel' is a wrapper that sets up the grid and calls the kernel.", - "description_2": "Use triton language to perform element-wise addition of two tensors on the GPU using a custom kernel.", - "difficulty": 2 - }, - { - "code": "import triton\nimport torch\n\n# Example Triton kernel\n@triton.jit\ndef example_kernel(X, Y, Z, BLOCK_SIZE: tl.constexpr):\n # Triton kernel code\n pass\n\n# Function to call the Triton kernel\ndef call_example_kernel(x, y, z, block_size):\n # Call the Triton kernel\n example_kernel[(1,)](x, y, z, BLOCK_SIZE=block_size)\n", - "description_1": "Use triton language to define a kernel 'example_kernel' with 4 parameters: X, Y, Z, and BLOCK_SIZE. The kernel is called using 'call_example_kernel' function which takes 4 arguments: x, y, z, and block_size.", - "description_2": "Use triton language to define a kernel and a function to call it with specified parameters.", - "difficulty": 1 - }, - { - "code": "import triton\nimport triton.language as tl\n\n# Kernel to promote a scalar to a tensor\n@triton.jit\ndef promote_to_tensor(x):\n # Addition promotes to tensor for us\n return x + tl.zeros((1,), tl.int1)\n\n# Kernel to check if a tensor is of floating type\n@triton.jit\ndef is_floating(x):\n return promote_to_tensor(x).dtype.is_floating()\n\n# Kernel to accumulate product\n@triton.jit\ndef _prod_accumulate(a, b):\n return a * b\n\n# Kernel to compute product along a given axis\n@triton.jit\ndef prod(input, axis):\n return tl.reduce(input, axis, _prod_accumulate)\n\n# Kernel to compute minimum of two tensors\n@triton.jit\ndef minimum(a, b):\n mask = a < b\n if is_floating(a):\n mask |= a != a\n return tl.where(mask, a, b)\n\n# Kernel to compute maximum of two tensors\n@triton.jit\ndef maximum(a, b):\n mask = a > b\n if is_floating(a):\n mask |= a != a\n return tl.where(mask, a, b)\n\n# Kernel to compute minimum along a given dimension\n@triton.jit\ndef min2(a, dim):\n return tl.reduce(a, dim, minimum)\n\n# Kernel to compute maximum along a given dimension\n@triton.jit\ndef max2(a, dim):\n return tl.reduce(a, dim, maximum)\n\n# Kernel to compute minimum with index\n@triton.jit\ndef minimum_with_index(a_value, a_index, b_value, b_index):\n mask = a_value < b_value\n equal = a_value == b_value\n if is_floating(a_value):\n a_isnan = a_value != a_value\n b_isnan = b_value != b_value\n mask |= a_isnan and not b_isnan\n equal |= a_isnan and b_isnan\n mask |= equal & (a_index < b_index)\n return tl.where(mask, a_value, b_value), tl.where(mask, a_index, b_index)\n\n# Kernel to compute maximum with index\n@triton.jit\ndef maximum_with_index(a_value, a_index, b_value, b_index):\n mask = a_value > b_value\n equal = a_value == b_value\n if is_floating(a_value):\n a_isnan = a_value != a_value\n b_isnan = b_value != b_value\n mask |= a_isnan and not b_isnan\n equal |= a_isnan and b_isnan\n mask |= equal & (a_index < b_index)\n return tl.where(mask, a_value, b_value), tl.where(mask, a_index, b_index)\n\n# Kernel to compute minimum with index along a given dimension\n@triton.jit\ndef min_with_index(value, index, dim):\n return tl.reduce((value, index), dim, minimum_with_index)\n\n# Kernel to compute maximum with index along a given dimension\n@triton.jit\ndef max_with_index(value, index, dim):\n return tl.reduce((value, index), dim, maximum_with_index)\n\n# Kernel for Welford reduction\n@triton.jit\ndef welford_reduce(value, mean, m2, weight, first_iteration):\n if first_iteration:\n new_weight = tl.full(weight.shape, 1, weight.dtype)\n new_mean = value\n new_m2 = tl.zeros_like(m2)\n else:\n delta = value - mean\n new_weight = weight + 1\n new_mean = mean + delta / new_weight\n new_m2 = m2 + delta * (value - new_mean)\n return new_mean, new_m2, new_weight\n\n# Kernel to combine Welford results\n@triton.jit\ndef welford_combine(mean_1, m2_1, weight_1, mean_2, m2_2, weight_2):\n delta = mean_2 - mean_1\n new_weight = weight_1 + weight_2\n w2_over_w = tl.where(new_weight == 0.0, 0.0, weight_2 / new_weight)\n return (\n mean_1 + delta * w2_over_w,\n m2_1 + m2_2 + delta * delta * weight_1 * w2_over_w,\n new_weight,\n )\n\n# Kernel for Welford reduction along a given dimension\n@triton.jit\ndef welford(mean, m2, weight, dim):\n return tl.reduce((mean, m2, weight), dim, welford_combine)\n\n# Kernel to assert a condition on the device\n@triton.jit\ndef device_assert_then(cond, msg, r):\n tl.device_assert(cond, msg)\n return r\n\n# Kernel to generate a random 64-bit integer\n@triton.jit\ndef randint64(seed, offset, low, high):\n r0, r1, r2, r3 = tl.randint4x(seed, offset)\n r0 = r0.to(tl.uint64)\n r1 = r1.to(tl.uint64)\n result = r0 | (r1 << 32)\n size = high - low\n result = result % size.to(tl.uint64)\n result = result.to(tl.int64) + low\n return result\n\n# Kernel to combine any operation\n@triton.jit\ndef _any_combine(a, b):\n return a | b\n\n# Kernel to compute any operation along a given dimension\n@triton.jit\ndef any(a, dim):\n return tl.reduce(a, dim, _any_combine)\n\n# Kernel for bucketize binary search\n@triton.jit\ndef bucketize_binary_search(\n values, # 1D tensor\n offsets_ptr,\n indexing_dtype,\n right, # bool: if true, use intervals closed on the left; see [Note: Inductor bucketize op]\n OFFSETS_SIZE: int,\n BLOCK_SHAPE, # tuple/list of block shape\n):\n \"\"\"\n See [Note: Inductor bucketize op]\n \"\"\"\n low = tl.zeros(BLOCK_SHAPE, dtype=indexing_dtype)\n high = tl.full(BLOCK_SHAPE, OFFSETS_SIZE, dtype=indexing_dtype)\n full_range = OFFSETS_SIZE + 1\n while full_range > 1:\n mid = (high + low) // 2\n mask = mid < OFFSETS_SIZE\n bucket_upper_bound = tl.load(offsets_ptr + mid, mask=mask, other=0.0)\n if right:\n is_above = values >= bucket_upper_bound\n else:\n is_above = values > bucket_upper_bound\n low = tl.where(is_above & mask, mid + 1, low)\n high = tl.where(is_above, high, mid)\n full_range = (full_range + 1) // 2\n return low\n\n# Kernel to pack value and flag\n@triton.jit\ndef pack_value_flag(\n value,\n flag,\n DTYPE_VALUE_AS_UINT: tl.constexpr,\n DTYPE_PACK: tl.constexpr,\n):\n # Workaround for triton bug, tensor.to doesn't unwrap constexpr values\n DTYPE_VALUE_AS_UINT = tl.core._constexpr_to_value(DTYPE_VALUE_AS_UINT)\n bitwidth = DTYPE_VALUE_AS_UINT.primitive_bitwidth\n uv = value.to(DTYPE_VALUE_AS_UINT, bitcast=True).to(DTYPE_PACK)\n return flag.to(DTYPE_PACK) | (uv << bitwidth)\n\n# Kernel to unpack value\n@triton.jit\ndef unpack_value(\n pack,\n DTYPE_VALUE,\n DTYPE_VALUE_AS_UINT,\n):\n # Workaround for triton bug, tensor.to doesn't unwrap constexpr values\n DTYPE_VALUE = tl.core._constexpr_to_value(DTYPE_VALUE)\n DTYPE_VALUE_AS_UINT = tl.core._constexpr_to_value(DTYPE_VALUE_AS_UINT)\n bitwidth = DTYPE_VALUE_AS_UINT.primitive_bitwidth\n value_uint = (pack >> bitwidth).to(DTYPE_VALUE_AS_UINT)\n return value_uint.to(DTYPE_VALUE, bitcast=True)\n\n# Kernel to unpack flag\n@triton.jit\ndef unpack_flag(pack, DTYPE_FLAG):\n return pack.to(DTYPE_FLAG)\n\n# Kernel for exclusive scan with decoupled lookback\n@triton.jit\ndef exclusive_scan_decoupled_lookback(\n scratch_base,\n block_value,\n index,\n combine_fn,\n DTYPE_VALUE_AS_UINT: tl.constexpr,\n DTYPE_PACK: tl.constexpr,\n):\n \"\"\"Compute exclusive scan of a scalar value between blocks\n\n Ref: https://research.nvidia.com/publication/2016-03_single-pass-parallel-prefix-scan-decoupled-look-back\n\n scratch_base: Pointer to scratch space in global memory\n block_value: Scalar value for this block\n index: Scalar index of this block relative to the current scan\n combine_fn: Function ``(value, value) -> value`` which is scanned over\n DTYPE_VALUE_AS_UINT: A tl.uint{n} type equal in size to ``block_value``\n DTYPE_PACK: Unsigned type twice the width of block_value\n\n NOTE: This function is limited to values which are 32-bits or less because\n we need to pack (value, flag) into a single unsigned int.\n \"\"\"\n # Publish block sum so subsequent blocks don't get stuck waiting for us\n DTYPE_VALUE = block_value.dtype\n pack = pack_value_flag(\n block_value,\n tl.full(block_value.shape, 1, DTYPE_VALUE_AS_UINT),\n DTYPE_VALUE_AS_UINT,\n DTYPE_PACK,\n )\n if index > 0:\n tl.atomic_xchg(scratch_base + index, pack, sem=\"relaxed\")\n\n # Calculate exclusive prefix scan\n exclusive_prefix = tl.zeros([], DTYPE_VALUE)\n prefix_valid = False\n test_target = index - 1\n while test_target >= 0:\n # tl.atomic_load\n flag = tl.full([], 0, DTYPE_VALUE_AS_UINT)\n while flag == 0:\n pack = tl.atomic_add(scratch_base + test_target, 0, sem=\"relaxed\")\n flag = unpack_flag(pack, DTYPE_VALUE_AS_UINT)\n\n value = unpack_value(pack, DTYPE_VALUE, DTYPE_VALUE_AS_UINT)\n if prefix_valid:\n exclusive_prefix = combine_fn(value, exclusive_prefix)\n else:\n exclusive_prefix = value\n prefix_valid = True\n\n if flag == 2:\n test_target = -1\n else:\n test_target = test_target - 1\n\n # Make inclusive block sum visible to other blocks\n if prefix_valid:\n inclusive_prefix = combine_fn(exclusive_prefix, block_value)\n else:\n inclusive_prefix = block_value\n pack = pack_value_flag(\n inclusive_prefix,\n tl.full([], 2, DTYPE_VALUE_AS_UINT),\n DTYPE_VALUE_AS_UINT,\n DTYPE_PACK,\n )\n tl.atomic_xchg(scratch_base + index, pack, sem=\"relaxed\")\n return exclusive_prefix\n\n# Kernel for exclusive scan with decoupled lookback for 64-bit values\n@triton.jit\ndef exclusive_scan_decoupled_lookback_64(scratch_base, block_value, index, combine_fn):\n \"\"\"Compute exclusive scan of a scalar value between blocks\n\n Ref: https://research.nvidia.com/publication/2016-03_single-pass-parallel-prefix-scan-decoupled-look-back\n\n scratch_base: Pointer to scratch space in global memory\n block_value: Scalar value for this block, must be 64-bits wide\n index: Scalar index of this block relative to the current scan\n combine_fn: Function ``(value, value) -> value`` which is scanned over\n init: Scalar value equal to the identiy of combine_fn\n \"\"\"\n # Publish block sum so subsequent blocks don't get stuck waiting for us\n if index > 0:\n block_value_u64 = block_value.to(tl.uint64, bitcast=True)\n tl.store(scratch_base + 3 * index + 1, block_value_u64)\n tl.debug_barrier()\n flag_one = tl.full([], 1, tl.uint64)\n tl.atomic_xchg(scratch_base + 3 * index + 0, flag_one, sem=\"release\")\n\n # Calculate exclusive prefix scan\n exclusive_prefix = tl.zeros([], block_value.dtype)\n prefix_valid = False\n test_target = index - 1\n while test_target >= 0:\n flag = tl.full([], 0, tl.uint64)\n while flag == 0:\n flag = tl.atomic_add(scratch_base + 3 * test_target + 0, 0, sem=\"acquire\")\n\n value_u64 = tl.load(scratch_base + 3 * test_target + flag.to(tl.int32))\n value = value_u64.to(block_value.dtype, bitcast=True)\n if prefix_valid:\n exclusive_prefix = combine_fn(value, exclusive_prefix)\n else:\n exclusive_prefix = value\n prefix_valid = True\n\n if flag == 2:\n test_target = -1\n else:\n test_target = test_target - 1\n\n # Make inclusive block sum visible to other blocks\n if prefix_valid:\n inclusive_prefix = combine_fn(exclusive_prefix, block_value)\n else:\n inclusive_prefix = block_value\n inclusive_prefix_u64 = inclusive_prefix.to(tl.uint64, bitcast=True)\n tl.store(scratch_base + 3 * index + 2, inclusive_prefix_u64)\n tl.debug_barrier()\n flag_two = tl.full([], 2, tl.uint64)\n tl.atomic_xchg(scratch_base + 3 * index + 0, flag_two, sem=\"release\")\n\n return exclusive_prefix\n\n# Kernel to compute mantissa and exponent of a floating-point number\n@triton.jit\ndef frexp(x):\n # TODO: use inline_asm_elementwise here\n y = libdevice.ilogb(x) + 1\n exponent = tl.where(x == 0, 0, y)\n mantissa = tl.where(x == 0, 0, libdevice.ldexp(x, -y))\n return mantissa, exponent\n", - "description_1": "Use triton language to implement various mathematical operations and reductions, including tensor promotion, floating-point checks, product accumulation, minimum and maximum calculations, Welford reduction, random integer generation, and exclusive scans with decoupled lookback.", - "description_2": "Use triton language to create kernels for mathematical operations and reductions, such as tensor promotion, floating-point checks, and exclusive scans.", - "difficulty": 4 - }, - { - "code": "import triton\nimport torch\n\n# Triton kernel for element-wise addition\n@triton.jit\ndef add_kernel(X, Y, Z, N):\n pid = triton.program_id(0)\n block_size = 1024\n offset = pid * block_size + triton.arange(0, block_size)\n mask = offset < N\n x = triton.load(X + offset, mask=mask)\n y = triton.load(Y + offset, mask=mask)\n z = x + y\n triton.store(Z + offset, z, mask=mask)\n\n# Function to call the Triton kernel\ndef add(X, Y):\n assert X.shape == Y.shape\n Z = torch.empty_like(X)\n N = X.numel()\n grid = lambda meta: (triton.cdiv(N, meta['block_size']),)\n add_kernel[grid](X, Y, Z, N)\n return Z\n", - "description_1": "Use triton language to implement an element-wise addition kernel. The kernel 'add_kernel' takes four parameters: X, Y, Z, and N. X and Y are input tensors, Z is the output tensor, and N is the total number of elements. The kernel computes the sum of X and Y element-wise and stores the result in Z. The function 'add' calls this kernel, ensuring that the input tensors X and Y have the same shape, and returns the result tensor Z.", - "description_2": "Use triton language to implement a kernel for element-wise addition of two tensors and a function to call this kernel.", - "difficulty": 2 - }, - { - "code": "import triton\nimport triton.language as tl\nimport torch\n\n@triton.jit\ndef _sampled_addmm_kernel(\n alpha,\n beta,\n IS_BETA_ZERO: tl.constexpr,\n BLOCKSIZE_ROW: tl.constexpr,\n BLOCKSIZE_COL: tl.constexpr,\n k,\n TILE_K: tl.constexpr,\n values_ptr,\n values_batch_stride,\n values_nnz_stride,\n values_row_block_stride,\n values_col_block_stride,\n crow_indices_ptr,\n crow_indices_batch_stride,\n crow_indices_stride,\n col_indices_ptr,\n col_indices_batch_stride,\n col_indices_stride,\n mat1_ptr,\n mat1_batch_stride,\n mat1_tiled_row_stride,\n mat1_tiled_col_stride,\n mat1_row_block_stride,\n mat1_col_block_stride,\n mat2_ptr,\n mat2_batch_stride,\n mat2_tiled_row_stride,\n mat2_tiled_col_stride,\n mat2_row_block_stride,\n mat2_col_block_stride,\n acc_dtype: tl.constexpr,\n allow_tf32: tl.constexpr,\n):\n batch_pid = tl.program_id(axis=1)\n row_block_pid = tl.program_id(axis=0)\n\n crow_indices_offset_ptr = (\n crow_indices_ptr\n + crow_indices_batch_stride * batch_pid\n + crow_indices_stride * row_block_pid\n )\n nnz_offset = tl.load(crow_indices_offset_ptr)\n nnz_offset_next = tl.load(crow_indices_offset_ptr + crow_indices_stride)\n\n # Compute nnz for the row with number row_block_pid.\n # If it is zero, skip the row.\n row_nnz = nnz_offset_next - nnz_offset\n if row_nnz == 0:\n return\n\n row_block_arange = tl.arange(0, BLOCKSIZE_ROW)\n col_block_arange = tl.arange(0, BLOCKSIZE_COL)\n\n # Pointers are set to the first block of the current row.\n values_block_ptrs = (\n values_ptr\n + values_batch_stride * batch_pid\n + values_nnz_stride * nnz_offset\n + values_row_block_stride * row_block_arange[:, None]\n + values_col_block_stride * col_block_arange[None, :]\n )\n\n col_index_nnz_ptr = (\n col_indices_ptr\n + col_indices_batch_stride * batch_pid\n + col_indices_stride * nnz_offset\n )\n\n # Advance mat1 to the current tiled row, ignore columns.\n mat1_block_ptrs = (\n mat1_ptr\n + mat1_batch_stride * batch_pid\n + mat1_tiled_row_stride * row_block_pid\n + mat1_row_block_stride * row_block_arange[:, None]\n )\n\n # Advance mat2 in batch and block col dimension.\n mat2_block_ptrs = (\n mat2_ptr\n + mat2_batch_stride * batch_pid\n + mat2_col_block_stride * col_block_arange[None, :]\n )\n\n k_tile_arange = tl.arange(0, TILE_K)\n for _ in range(row_nnz):\n acc_block = tl.zeros((BLOCKSIZE_ROW, BLOCKSIZE_COL), dtype=acc_dtype)\n\n # find column block index\n col_block = tl.load(col_index_nnz_ptr)\n\n for k_tile in range(0, k, TILE_K):\n k_offsets = k_tile + k_tile_arange\n mask_k = k_offsets < k\n\n mat1_block = tl.load(\n mat1_block_ptrs\n + mat1_col_block_stride * k_offsets[None, :],\n mask=mask_k[None, :], other=0.0\n )\n\n mat2_block = tl.load(\n mat2_block_ptrs\n + mat2_tiled_col_stride * col_block\n + mat2_row_block_stride * k_offsets[:, None],\n mask=mask_k[:, None], other=0.0\n )\n\n acc_block += tl.dot(mat1_block, mat2_block, allow_tf32=allow_tf32, out_dtype=acc_dtype)\n\n if IS_BETA_ZERO:\n acc_block *= alpha\n else:\n acc_block = alpha * acc_block + beta * tl.load(values_block_ptrs)\n\n # write result\n tl.store(values_block_ptrs, acc_block.to(values_ptr.dtype.element_ty))\n\n # advance val/col_index ptrs to the next block in the row.\n values_block_ptrs += values_nnz_stride\n col_index_nnz_ptr += col_indices_stride\n\ndef sampled_addmm(\n input: torch.Tensor,\n mat1: torch.Tensor,\n mat2: torch.Tensor,\n *,\n beta=1.0,\n alpha=1.0,\n out: Optional[torch.Tensor] = None,\n skip_checks: bool = False,\n max_grid: Optional[Tuple[Optional[int], Optional[int], Optional[int]]] = None,\n):\n f_name = \"sampled_addmm\"\n\n check_bsr_layout(f_name, input)\n input_broadcasted = broadcast_batch_dims_bsr(f_name, input, mat1, mat2)\n\n if not skip_checks:\n check_device(f_name, mat1, input.device)\n check_device(f_name, mat2, input.device)\n if beta != 0.0 and input.dtype is torch.bool:\n check(\n False,\n f\"{f_name}(): having beta == {beta} not equal to 0.0 with boolean mask is not allowed.\"\n )\n if input.dtype is not torch.bool:\n check_dtype(f_name, mat1, input.dtype)\n check_dtype(f_name, mat2, input.dtype)\n else:\n check_dtype(f_name, mat1, mat2.dtype)\n check_mm_compatible_shapes(f_name, mat1, mat2)\n if out is not None:\n check_bsr_layout(f_name, out)\n check_device(f_name, out, mat1.device)\n check_dtype(f_name, out, input.dtype)\n check(\n out.shape == input_broadcasted.shape\n and out._nnz() == input_broadcasted._nnz(),\n f\"{f_name}(): Expects `out` to be of shape {input_broadcasted.shape} \"\n f\"and with nnz equal to {input_broadcasted._nnz()} \"\n f\"but got out.shape = {out.shape} and out.nnz = {out._nnz()}\"\n )\n\n if out is None:\n out = input_broadcasted.to(mat1.dtype, copy=True)\n else:\n out.copy_(input_broadcasted)\n\n if out.numel() == 0 or out._nnz() == 0:\n return out\n\n blocksize = out.values().shape[-2:]\n m = mat1.size(-2)\n n = mat2.size(-1)\n k = mat1.size(-1)\n\n # NOTE: (m, 0) @ (0, n) == zeros(m, n)\n if alpha == 0.0 or k == 0:\n out.values().mul_(beta)\n return out\n\n # prepare inputs by reshaping them to be kernel-compatible\n out_backup = out\n crow_indices, col_indices, values, mat1, mat2 = prepare_inputs(out, mat1, mat2)\n\n mat1 = tile_to_blocksize(mat1, (blocksize[0], k))\n mat2 = tile_to_blocksize(mat2, (k, blocksize[1]))\n tile_k = max(*blocksize)\n\n _run_sampled_addmm_kernel(\n alpha, beta, beta == 0.0,\n blocksize, k, tile_k,\n values, crow_indices, col_indices,\n mat1, mat2,\n max_grid\n )\n\n # If nnz x block strides are not the same in out_backup.values and values,\n # it means that out_backup.values and values are not the views of each other,\n # so we have to copy.\n if out_backup.values().stride()[-3:] != values.stride()[-3:]:\n out_backup.values().copy_(values.reshape(out_backup.values().shape))\n return out_backup\n", - "description_1": "Use triton language to implement a sampled matrix multiplication kernel with parameters for alpha, beta, block sizes, and strides. The kernel computes the product of two matrices with optional scaling and addition, storing the result in a specified output tensor.", - "description_2": "Use triton language to implement a sparse matrix multiplication kernel with customizable parameters for scaling and block sizes.", - "difficulty": 4 - }, - { - "code": "import triton\nfrom triton import language as tl\nfrom triton.language import load, store\n\n# Triton Kernel: add_kernel\n@triton.jit\ndef add_kernel(\n in_ptr0,\n in_ptr1,\n out_ptr,\n n_elements,\n BLOCK_SIZE: \"tl.constexpr\",\n):\n # Each program calculates a block of the output\n pid = tl.program_id(axis=0)\n block_start = pid * BLOCK_SIZE\n offsets = block_start + tl.arange(0, BLOCK_SIZE)\n mask = offsets < n_elements\n x = tl.load(in_ptr0 + offsets, mask=mask)\n y = tl.load(in_ptr1 + offsets, mask=mask)\n output = x + y\n tl.store(out_ptr + offsets, output, mask=mask)\n\n# Triton Kernel: add_kernel_with_optional_param\n@triton.jit\ndef add_kernel_with_optional_param(\n in_ptr0,\n in_ptr1,\n out_ptr,\n n_elements,\n ARGS_PASSED: \"tl.constexpr\",\n BLOCK_SIZE: \"tl.constexpr\",\n):\n # Each program calculates a block of the output with optional addition\n pid = tl.program_id(axis=0)\n block_start = pid * BLOCK_SIZE\n offsets = block_start + tl.arange(0, BLOCK_SIZE)\n mask = offsets < n_elements\n x = tl.load(in_ptr0 + offsets, mask=mask)\n if ARGS_PASSED == \"two\":\n y = tl.load(in_ptr1 + offsets, mask=mask)\n output = x + y\n else:\n output = x\n tl.store(out_ptr + offsets, output, mask=mask)\n\n# Triton Kernel: add_kernel_autotuned\n@triton.autotune(\n configs=[\n triton.Config({\"BLOCK_SIZE\": 128}, num_stages=3, num_warps=8),\n triton.Config({\"BLOCK_SIZE\": 128}, num_stages=4, num_warps=4),\n triton.Config({\"BLOCK_SIZE\": 64}, num_stages=3, num_warps=8),\n triton.Config({\"BLOCK_SIZE\": 64}, num_stages=4, num_warps=4),\n ],\n key=[],\n)\n@triton.jit\ndef add_kernel_autotuned(\n in_ptr0,\n in_ptr1,\n out_ptr,\n n_elements,\n BLOCK_SIZE: \"tl.constexpr\",\n):\n # Each program calculates a block of the output with autotuning\n pid = tl.program_id(axis=0)\n block_start = pid * BLOCK_SIZE\n offsets = block_start + tl.arange(0, BLOCK_SIZE)\n mask = offsets < n_elements\n x = tl.load(in_ptr0 + offsets, mask=mask)\n y = tl.load(in_ptr1 + offsets, mask=mask)\n output = x + y\n tl.store(out_ptr + offsets, output, mask=mask)\n\n# Triton Kernel: add_kernel_2d_autotuned\n@triton.autotune(\n configs=[\n triton.Config(\n {\"BLOCK_SIZE_X\": 128, \"BLOCK_SIZE_Y\": 128}, num_stages=3, num_warps=8\n ),\n triton.Config(\n {\"BLOCK_SIZE_X\": 128, \"BLOCK_SIZE_Y\": 128}, num_stages=4, num_warps=4\n ),\n triton.Config(\n {\"BLOCK_SIZE_X\": 64, \"BLOCK_SIZE_Y\": 64}, num_stages=3, num_warps=8\n ),\n triton.Config(\n {\"BLOCK_SIZE_X\": 64, \"BLOCK_SIZE_Y\": 64}, num_stages=4, num_warps=4\n ),\n ],\n key=[],\n)\n@triton.jit\ndef add_kernel_2d_autotuned(\n in_ptr0,\n in_ptr1,\n out_ptr,\n x_elements,\n y_elements,\n BLOCK_SIZE_X: \"tl.constexpr\",\n BLOCK_SIZE_Y: \"tl.constexpr\",\n):\n # Each program calculates a 2D block of the output with autotuning\n xoffset = tl.program_id(0) * BLOCK_SIZE_X\n xindex = xoffset + tl.arange(0, BLOCK_SIZE_X)[:, None]\n xmask = xindex < x_elements\n yoffset = tl.program_id(1) * BLOCK_SIZE_Y\n yindex = yoffset + tl.arange(0, BLOCK_SIZE_Y)[None, :]\n ymask = yindex < y_elements\n x1 = xindex\n y0 = yindex\n tmp0 = tl.load(in_ptr0 + (x1 + (x_elements * y0)), xmask & ymask)\n tmp1 = tl.load(in_ptr0 + (y0 + (y_elements * x1)), xmask & ymask)\n tmp2 = tmp0 + tmp1\n tl.store(out_ptr + (x1 + (x_elements * y0)), tmp2, xmask & ymask)\n\n# Triton Kernel: add_kernel_with_scaling\n@triton.jit\ndef add_kernel_with_scaling(\n in_ptr0,\n in_ptr1,\n out_ptr,\n n_elements,\n scaling_factor,\n BLOCK_SIZE: \"tl.constexpr\",\n):\n # Each program calculates a block of the output with a scaling factor\n pid = tl.program_id(axis=0)\n block_start = pid * BLOCK_SIZE\n offsets = block_start + tl.arange(0, BLOCK_SIZE)\n mask = offsets < n_elements\n x = tl.load(in_ptr0 + offsets, mask=mask)\n y = tl.load(in_ptr1 + offsets, mask=mask)\n output = (x + y) * scaling_factor\n tl.store(out_ptr + offsets, output, mask=mask)\n\n# Triton Kernel: mul2_kernel\n@triton.jit\ndef mul2_kernel(\n in_ptr0,\n out_ptr,\n n_elements,\n BLOCK_SIZE: \"tl.constexpr\",\n):\n # Each program multiplies each element by 2\n pid = tl.program_id(axis=0)\n block_start = pid * BLOCK_SIZE\n offsets = block_start + tl.arange(0, BLOCK_SIZE)\n mask = offsets < n_elements\n x = tl.load(in_ptr0 + offsets, mask=mask)\n output = 2 * x\n tl.store(out_ptr + offsets, output, mask=mask)\n\n# Triton Kernel: mul2_inplace_kernel\n@triton.jit\ndef mul2_inplace_kernel(\n ptr,\n n_elements,\n BLOCK_SIZE: \"tl.constexpr\",\n):\n # Each program multiplies each element by 2 in-place\n pid = tl.program_id(axis=0)\n block_start = pid * BLOCK_SIZE\n offsets = block_start + tl.arange(0, BLOCK_SIZE)\n mask = offsets < n_elements\n x = tl.load(ptr + offsets, mask=mask)\n output = 2 * x\n tl.store(ptr + offsets, output, mask=mask)\n\n# Triton Kernel: zero_negs\n@triton.jit\ndef zero_negs(x):\n # Replace negative numbers with zero\n return tl.where(x >= 0, x, 0)\n\n# Triton Kernel: indirection_kernel\n@triton.jit\ndef indirection_kernel(\n in_ptr0,\n out_ptr,\n n_elements,\n BLOCK_SIZE: \"tl.constexpr\",\n ACTIVATION: \"tl.constexpr\",\n):\n # Each program applies a specified activation function\n pid = tl.program_id(axis=0)\n block_start = pid * BLOCK_SIZE\n offsets = block_start + tl.arange(0, BLOCK_SIZE)\n mask = offsets < n_elements\n if ACTIVATION == \"mul2_inplace_kernel\":\n mul2_inplace_kernel(in_ptr0, n_elements, BLOCK_SIZE=BLOCK_SIZE)\n elif ACTIVATION == \"add_kernel\":\n add_kernel(in_ptr0, in_ptr0, out_ptr, n_elements, BLOCK_SIZE=BLOCK_SIZE)\n x = tl.load(in_ptr0 + offsets, mask=mask)\n tl.store(out_ptr + offsets, x, mask=mask)\n\n# Triton Kernel: double_strided_kernel\n@triton.jit\ndef double_strided_kernel(\n in_ptr,\n out_ptr,\n in_y_stride,\n out_y_stride,\n X_BLOCK_SIZE: \"tl.constexpr\",\n Y_BLOCK_SIZE: \"tl.constexpr\",\n):\n # Each program doubles the values considering striding\n xid = tl.program_id(axis=0)\n yid = tl.program_id(axis=1)\n x_start = xid * X_BLOCK_SIZE\n y_start = yid * Y_BLOCK_SIZE\n x_offsets = x_start + tl.arange(0, X_BLOCK_SIZE)\n y_offsets = y_start + tl.arange(0, Y_BLOCK_SIZE)\n src_offsets = y_offsets[:, None] * in_y_stride + x_offsets[None, :]\n dst_offsets = y_offsets[:, None] * out_y_stride + x_offsets[None, :]\n src = tl.load(in_ptr + src_offsets)\n tl.store(out_ptr + dst_offsets, src * 2.0)\n\n# Triton Kernel: inline_asm_kernel\n@triton.jit\ndef inline_asm_kernel(X, Y, Z, n: \"tl.constexpr\", BLOCK: \"tl.constexpr\"):\n # Each program executes inline assembly\n x = tl.load(X + tl.arange(0, BLOCK))\n y = tl.load(Y + tl.arange(0, BLOCK))\n s = tl.full([BLOCK], n, tl.int32)\n z = tl.inline_asm_elementwise(\n \"shf.l.wrap.b32 $0, $1, $2, $3;\",\n \"=r,r, r, r\",\n [x, y, s],\n dtype=tl.int32,\n is_pure=True,\n pack=1,\n )\n tl.store(Z + tl.arange(0, BLOCK), z)\n\n# Triton Kernel: add_kernel_with_block_ptr\n@triton.jit\ndef add_kernel_with_block_ptr(\n x_ptr,\n y_ptr,\n output_ptr,\n n_elements,\n BLOCK_SIZE: tl.constexpr,\n):\n # Each program calculates a block of the output using block pointers\n pid = tl.program_id(axis=0)\n block_start = pid * BLOCK_SIZE\n x = tl.load(\n tl.make_block_ptr(\n base=x_ptr,\n shape=[n_elements],\n strides=[1],\n offsets=[block_start],\n block_shape=[BLOCK_SIZE],\n order=[0],\n ),\n boundary_check=[0],\n )\n y = tl.load(\n tl.make_block_ptr(\n base=y_ptr,\n shape=[n_elements],\n strides=[1],\n offsets=[block_start],\n block_shape=[BLOCK_SIZE],\n order=[0],\n ),\n boundary_check=[0],\n )\n output = x + y\n tl.store(\n tl.make_block_ptr(\n base=output_ptr,\n shape=[n_elements],\n strides=[1],\n offsets=[block_start],\n block_shape=[BLOCK_SIZE],\n order=[0],\n ),\n output,\n boundary_check=[0],\n )\n\n# Triton Kernel: kernel_with_block_ptr_2d\n@triton.jit\ndef kernel_with_block_ptr_2d(\n x_ptr,\n output_ptr,\n n_elements,\n BLOCK_SIZE: tl.constexpr,\n):\n # Each program calculates a 2D block of the output using block pointers\n pid = tl.program_id(axis=0)\n block_start = pid * BLOCK_SIZE\n x = tl.load(\n tl.make_block_ptr(\n base=x_ptr,\n shape=[n_elements, 1],\n strides=[1, 1],\n offsets=[block_start, 0],\n block_shape=[BLOCK_SIZE, 1],\n order=[1, 0],\n ),\n boundary_check=[0],\n )\n output = x\n tl.store(\n tl.make_block_ptr(\n base=output_ptr,\n shape=[n_elements, 1],\n strides=[1, 1],\n offsets=[block_start, 0],\n block_shape=[BLOCK_SIZE, 1],\n order=[1, 0],\n ),\n output,\n boundary_check=[0],\n )\n\n# Triton Kernel: add_kernel_with_import\n@triton.jit\ndef add_kernel_with_import(\n in_ptr0,\n in_ptr1,\n out_ptr,\n n_elements,\n BLOCK_SIZE: \"tl.constexpr\",\n):\n # Each program calculates a block of the output using imported functions\n pid = tl.program_id(axis=0)\n block_start = pid * BLOCK_SIZE\n offsets = block_start + tl.arange(0, BLOCK_SIZE)\n mask = offsets < n_elements\n x = load(in_ptr0 + offsets, mask=mask)\n y = load(in_ptr1 + offsets, mask=mask)\n output = x + y\n store(out_ptr + offsets, output, mask=mask)\n\n# Triton Kernel: cond_op_kernel\n@triton.jit\ndef cond_op_kernel(\n in_ptr0,\n in_ptr1,\n out_ptr,\n n_elements,\n BLOCK_SIZE: \"tl.constexpr\",\n):\n # Each program calculates a block of the output conditionally\n pid = tl.program_id(axis=0)\n block_start = pid * BLOCK_SIZE\n offsets = block_start + tl.arange(0, BLOCK_SIZE)\n mask = offsets < n_elements\n x = tl.load(in_ptr0 + offsets, mask=mask)\n y = tl.load(in_ptr1 + offsets, mask=mask)\n if tl.program_id(0) == 0:\n output = x + y\n else:\n output = x * y\n tl.store(out_ptr + offsets, output, mask=mask)\n\n# Triton Kernel: atomic_add_kernel\n@triton.jit\ndef atomic_add_kernel(\n in_ptr0,\n in_ptr1,\n out_ptr,\n n_elements,\n BLOCK_SIZE: \"tl.constexpr\",\n):\n # Each program calculates a block of the output using atomic addition\n pid = tl.program_id(axis=0)\n block_start = pid * BLOCK_SIZE\n offsets = block_start + tl.arange(0, BLOCK_SIZE)\n mask = offsets < n_elements\n x = tl.load(in_ptr0 + offsets, mask=mask)\n y = tl.load(in_ptr1 + offsets, mask=mask)\n output = x + y\n tl.atomic_add(out_ptr + offsets, output, mask=mask)\n\n# Triton Kernel: add_4_times_kernel\n@triton.jit\ndef add_4_times_kernel(\n in_ptr0,\n in_ptr1,\n out_ptr,\n n_elements,\n BLOCK_SIZE: \"tl.constexpr\",\n):\n # Each program calculates a block of the output and stores it four times\n pid = tl.program_id(axis=0)\n block_start = pid * BLOCK_SIZE\n offsets = block_start + tl.arange(0, BLOCK_SIZE)\n mask = offsets < n_elements\n x = tl.load(in_ptr0 + offsets, mask=mask)\n y = tl.load(in_ptr1 + offsets, mask=mask)\n for i in range(2):\n output = x + y\n tl.store(out_ptr + offsets, output, mask=mask)\n i = 2\n while i > 0:\n i -= 1\n output = x + y\n tl.store(out_ptr + offsets, output, mask=mask)\n\n# Triton Kernel: add_kernel_out_of_order_fn2\n@triton.jit\ndef add_kernel_out_of_order_fn2(\n in_ptr0,\n in_ptr1,\n n_elements,\n out_ptr,\n BLOCK_SIZE: \"tl.constexpr\",\n):\n # Each program calculates a block of the output with parameters out of order\n pid = tl.program_id(axis=0)\n block_start = pid * BLOCK_SIZE\n offsets = block_start + tl.arange(0, BLOCK_SIZE)\n mask = offsets < n_elements\n x = tl.load(in_ptr0 + offsets, mask=mask)\n y = tl.load(in_ptr1 + offsets, mask=mask)\n output = x + y\n tl.store(out_ptr + offsets, output, mask=mask)\n", - "description_1": "Use triton language to implement various element-wise operations over arrays. The kernels cover operations like addition, conditional operations, scaling, atomic addition, and more. They are highly parallelized using Triton and sometimes autotuned for optimal performance.", - "description_2": "Use triton language to create kernels that perform parallel element-wise addition and scaling with possible autotuning for optimal performance.", - "difficulty": 3 - }, - { - "code": "import triton\nimport torch\n\n# Triton kernel for element-wise addition\n@triton.jit\ndef add_kernel(X, Y, Z, N):\n # Compute the index of the element to process\n idx = triton.program_id(0)\n if idx < N:\n # Perform the addition\n Z[idx] = X[idx] + Y[idx]\n\n# Function to launch the Triton kernel\ndef add_tensors(X, Y, Z, N):\n # Launch the kernel with a grid size of N\n add_kernel[(N,)](X, Y, Z, N)\n\n# Example usage\nX = torch.tensor([1.0, 2.0, 3.0], device='cuda')\nY = torch.tensor([4.0, 5.0, 6.0], device='cuda')\nZ = torch.empty_like(X)\nN = X.numel()\n\n# Call the function to perform addition\nadd_tensors(X, Y, Z, N)\n", - "description_1": "Use triton language to define a kernel 'add_kernel' that performs element-wise addition of two input tensors X and Y, storing the result in tensor Z. The kernel takes four parameters: X, Y, Z (all tensors), and N (an integer representing the number of elements). The function 'add_tensors' is used to launch this kernel with a grid size of N.", - "description_2": "Use triton language to create a kernel for element-wise addition of two tensors and a function to launch this kernel.", - "difficulty": 1 - }, - { - "code": "import triton\nimport triton.language as tl\n\n@triton.jit\ndef _abs_max(val1, val2):\n # Calculate the absolute maximum of two values.\n val1_abs = tl.abs(val1)\n val2_abs = tl.abs(val2)\n if val1_abs >= val2_abs:\n return val1_abs\n else:\n return val2_abs\n\n@triton.autotune(configs=_get_autotune_configs(), key=[\"M\", \"N\"])\n@triton.jit\ndef _triton_dynamic_quantize_kernel(\n output_ptr, # Pointer to output tensor\n input_ptr, # Pointer to input tensor\n scale_ptr, # Pointer to scale tensor\n stride_outputm, # Stride for output in dimension m\n stride_outputn, # Stride for output in dimension n\n stride_inputm, # Stride for input in dimension m\n stride_inputn, # Stride for input in dimension n\n n_elements, # Number of elements to process\n M: tl.constexpr, # Number of rows (tokens)\n N: tl.constexpr, # Number of columns (hidden size)\n):\n # Dynamic quantization kernel\n pid = tl.program_id(axis=0)\n offsets = tl.arange(0, N)\n mask = offsets < n_elements\n input_ptrs = input_ptr + pid * stride_inputm + offsets\n input_vals = tl.load(input_ptrs, mask=mask, other=1e-6)\n abs_max_f = tl.reduce(input_vals, 0, _abs_max)\n dynamic_per_token_scale = 127.0 / abs_max_f\n precison_mask = tl.where(input_vals > 0, 0.5, -0.5)\n output_vals = (input_vals * dynamic_per_token_scale + precison_mask).to(tl.int8)\n output_ptrs = output_ptr + pid * stride_outputm + offsets\n tl.store(output_ptrs, output_vals, mask=mask)\n tl.store(scale_ptr + pid, abs_max_f / 127.0)\n\n\ndef triton_dynamic_quantize(out, input, scale):\n # Function to initiate the dynamic quantization process\n assert input.is_contiguous(), \"input must be contiguous\"\n num_tokens = input.size(0)\n hidden_size = input.size(1)\n block_size = 1024\n # Ensure hidden_size is a power-of-two for tl.reduce\n if hidden_size & (hidden_size - 1) == 0 and hidden_size > 0:\n block_size = min(hidden_size / 2, block_size)\n else:\n hidden_size = triton.next_power_of_2(int(hidden_size))\n block_size = min(hidden_size / 2, block_size)\n # num_warps = int(max(block_size / THREADS_PER_WARP, 1))\n _triton_dynamic_quantize_kernel[(num_tokens,)](\n out,\n input,\n scale,\n out.stride(0),\n out.stride(1),\n input.stride(0),\n input.stride(1),\n n_elements=input.size(1),\n M=num_tokens,\n N=hidden_size,\n )\n", - "description_1": "Use triton language to implement dynamic quantization with two kernels: (1) a helper kernel '_abs_max' that takes two values and returns the maximum absolute value between them; (2) a main kernel '_triton_dynamic_quantize_kernel' that processes input data pointers, applies a dynamic scale, and stores quantized results. It receives pointers for input, output, and scale tensors, strides, number of elements, and dimensions M and N. It computes the absolute maximum and dynamic scale per token, performs quantization, and stores the results. The main function 'triton_dynamic_quantize' prepares the input sizes and launches the kernel for processing.", - "description_2": "Use triton language to create a quantization kernel that takes input/output tensor pointers, applies a dynamic scale to quantize the input, computes the maximum absolute value, and stores scaled results in the output. Implement a helper kernel for computing absolute maximum values and ensure proper input tensor configuration before processing.", - "difficulty": 2 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\nfrom flashnn.kernel_backend import get_autotune_triton_kernels\n\n@triton.jit\ndef _attn_fwd_inner(\n acc,\n l_i,\n m_i,\n q,\n K_block_ptr,\n V_block_ptr,\n start_m,\n qk_scale,\n BLOCK_M: tl.constexpr,\n BLOCK_DMODEL: tl.constexpr,\n BLOCK_N: tl.constexpr,\n STAGE: tl.constexpr,\n offs_m: tl.constexpr,\n offs_n: tl.constexpr,\n N_CTX: tl.constexpr,\n):\n # range of values handled by this stage\n if STAGE == 1:\n lo, hi = 0, start_m * BLOCK_M\n elif STAGE == 2:\n lo, hi = start_m * BLOCK_M, (start_m + 1) * BLOCK_M\n lo = tl.multiple_of(lo, BLOCK_M)\n else:\n lo, hi = 0, N_CTX\n K_block_ptr = tl.advance(K_block_ptr, (0, lo))\n V_block_ptr = tl.advance(V_block_ptr, (lo, 0))\n for start_n in range(lo, hi, BLOCK_N):\n start_n = tl.multiple_of(start_n, BLOCK_N)\n if STAGE != 1:\n k = tl.load(K_block_ptr, boundary_check=(0, 1))\n else:\n k = tl.load(K_block_ptr)\n qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)\n if STAGE != 1:\n n_ctx_mask = tl.where(\n (offs_m[:, None] < N_CTX) & ((start_n + offs_n[None, :]) < N_CTX),\n 0,\n float(\"-inf\"),\n )\n qk += n_ctx_mask\n qk += tl.dot(q, k)\n if STAGE == 2:\n mask = offs_m[:, None] >= (start_n + offs_n[None, :])\n qk = qk * qk_scale + tl.where(mask, 0, float(\"-inf\"))\n m_ij = tl.maximum(m_i, tl.max(qk, 1))\n qk -= m_ij[:, None]\n else:\n m_ij = tl.maximum(m_i, tl.max(qk, 1) * qk_scale)\n qk = qk * qk_scale - m_ij[:, None]\n p = tl.math.exp2(qk)\n l_ij = tl.sum(p, 1)\n alpha = tl.math.exp2(m_i - m_ij)\n l_i = l_i * alpha + l_ij\n acc = acc * alpha[:, None]\n if STAGE != 1:\n v = tl.load(V_block_ptr, boundary_check=(0, 1))\n else:\n v = tl.load(V_block_ptr)\n acc = tl.dot(p.to(tl.float16), v, acc)\n m_i = m_ij\n V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0))\n K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N))\n return acc, l_i, m_i\n\n@triton.jit\ndef _triton_attn_fwd(\n Q,\n K,\n V,\n sm_scale,\n Out,\n stride_qz,\n stride_qh,\n stride_qm,\n stride_qk,\n stride_kz,\n stride_kh,\n stride_km,\n stride_kk,\n stride_vz,\n stride_vh,\n stride_vm,\n stride_vk,\n stride_oz,\n stride_oh,\n stride_om,\n stride_ok,\n Z,\n H,\n N_CTX,\n POWER_OF_2_N_CTX: tl.constexpr,\n BLOCK_M: tl.constexpr,\n BLOCK_DMODEL: tl.constexpr,\n BLOCK_N: tl.constexpr,\n STAGE: tl.constexpr,\n GROUPS: tl.constexpr,\n ORDER_12: tl.constexpr,\n):\n start_m = tl.program_id(0)\n off_hz = tl.program_id(1)\n off_z = off_hz // H\n off_h = off_hz % H\n off_g = tl.program_id(2)\n q_offset = (\n off_z.to(tl.int64) * stride_qz\n + (off_h * GROUPS + off_g).to(tl.int64) * stride_qh\n )\n k_offset = off_z.to(tl.int64) * stride_kz + off_h.to(tl.int64) * stride_kh\n v_offset = off_z.to(tl.int64) * stride_vz + off_h.to(tl.int64) * stride_vh\n o_offset = (\n off_z.to(tl.int64) * stride_oz\n + (off_h * GROUPS + off_g).to(tl.int64) * stride_oh\n )\n Q_block_ptr = tl.make_block_ptr(\n base=Q + q_offset,\n shape=(N_CTX, BLOCK_DMODEL),\n strides=(stride_qm, stride_qk),\n offsets=(start_m * BLOCK_M, 0),\n block_shape=(BLOCK_M, BLOCK_DMODEL),\n order=(1, 0),\n )\n V_block_ptr = tl.make_block_ptr(\n base=V + v_offset,\n shape=(N_CTX, BLOCK_DMODEL),\n strides=(stride_vm, stride_vk),\n offsets=(0, 0),\n block_shape=(BLOCK_N, BLOCK_DMODEL),\n order=(1, 0),\n )\n K_block_ptr = tl.make_block_ptr(\n base=K + k_offset,\n shape=(BLOCK_DMODEL, N_CTX),\n strides=(stride_kk, stride_km),\n offsets=(0, 0),\n block_shape=(BLOCK_DMODEL, BLOCK_N),\n order=(0, 1),\n )\n O_block_ptr = tl.make_block_ptr(\n base=Out + o_offset,\n shape=(N_CTX, BLOCK_DMODEL),\n strides=(stride_om, stride_ok),\n offsets=(start_m * BLOCK_M, 0),\n block_shape=(BLOCK_M, BLOCK_DMODEL),\n order=(1, 0),\n )\n offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)\n offs_n = tl.arange(0, BLOCK_N)\n m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float(\"inf\")\n l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + 1.0\n acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)\n qk_scale = sm_scale\n qk_scale *= 1.44269504\n q = tl.load(Q_block_ptr, boundary_check=(0, 1))\n if ORDER_12:\n if STAGE & 1:\n acc, l_i, m_i = _attn_fwd_inner(\n acc,\n l_i,\n m_i,\n q,\n K_block_ptr,\n V_block_ptr,\n start_m,\n qk_scale,\n BLOCK_M,\n BLOCK_DMODEL,\n BLOCK_N,\n 4 - STAGE,\n offs_m,\n offs_n,\n N_CTX,\n )\n if STAGE & 2:\n acc, l_i, m_i = _attn_fwd_inner(\n acc,\n l_i,\n m_i,\n q,\n K_block_ptr,\n V_block_ptr,\n start_m,\n qk_scale,\n BLOCK_M,\n BLOCK_DMODEL,\n BLOCK_N,\n 2,\n offs_m,\n offs_n,\n N_CTX,\n )\n else:\n if STAGE & 2:\n acc, l_i, m_i = _attn_fwd_inner(\n acc,\n l_i,\n m_i,\n q,\n K_block_ptr,\n V_block_ptr,\n start_m,\n qk_scale,\n BLOCK_M,\n BLOCK_DMODEL,\n BLOCK_N,\n 2,\n offs_m,\n offs_n,\n N_CTX,\n )\n if STAGE & 1:\n acc, l_i, m_i = _attn_fwd_inner(\n acc,\n l_i,\n m_i,\n q,\n K_block_ptr,\n V_block_ptr,\n start_m,\n qk_scale,\n BLOCK_M,\n BLOCK_DMODEL,\n BLOCK_N,\n 4 - STAGE,\n offs_m,\n offs_n,\n N_CTX,\n )\n acc = acc / l_i[:, None]\n tl.store(O_block_ptr, acc.to(Out.type.element_ty), boundary_check=(0, 1))\n\ndef triton_flash_attention_forward(q, k, v, causal, sm_scale=None, ORDER_12=False):\n q_dim, k_dim, v_dim = q.dim(), k.dim(), v.dim()\n assert q_dim == 4 and q_dim == k_dim and q_dim == v_dim\n Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]\n assert Lq == Lk and Lk == Lv\n assert Lk in {16, 32, 64, 128}\n num_heads_q = q.shape[2]\n num_heads_k = k.shape[2]\n num_heads_v = v.shape[2]\n assert num_heads_k == num_heads_v\n assert num_heads_q % num_heads_k == 0\n groups = num_heads_q // num_heads_k\n\n o = torch.empty_like(q)\n BLOCK_M = 64\n BLOCK_N = 64\n num_stages = 2\n num_warps = 8\n stage = 3 if causal else 1\n\n batch_size = q.shape[0]\n seq_len = q.shape[1]\n head_dims = q.shape[-1]\n\n sm_scale = 1.0 / Lk**0.5 if sm_scale is None else sm_scale\n\n kwargs = [\n q,\n k,\n v,\n sm_scale,\n o,\n q.stride(0),\n q.stride(-2),\n q.stride(1),\n q.stride(-1),\n k.stride(0),\n k.stride(-2),\n k.stride(1),\n k.stride(-1),\n v.stride(0),\n v.stride(-2),\n v.stride(1),\n v.stride(-1),\n o.stride(0),\n o.stride(-2),\n o.stride(1),\n o.stride(-1),\n batch_size,\n num_heads_k,\n seq_len,\n ]\n POWER_OF_2_N_CTX = triton.next_power_of_2(seq_len)\n const_kwargs = {\n \"POWER_OF_2_N_CTX\": POWER_OF_2_N_CTX,\n \"BLOCK_DMODEL\": Lk,\n \"STAGE\": stage,\n \"GROUPS\": groups,\n \"ORDER_12\": ORDER_12,\n }\n\n if get_autotune_triton_kernels():\n def grid(META):\n return (\n triton.cdiv(seq_len, META[\"BLOCK_M\"]),\n batch_size * num_heads_k,\n groups,\n )\n\n def keep(conf):\n BLOCK_M = conf.kwargs[\"BLOCK_M\"]\n BLOCK_N = conf.kwargs[\"BLOCK_N\"]\n if BLOCK_M * BLOCK_N < 128 * 128 and conf.num_warps == 8:\n return False\n return True\n\n flash_attn = triton.autotune(\n configs=list(filter(keep, _get_flash_attn_autotune_configs())),\n key=[\"POWER_OF_2_N_CTX\"],\n )(_triton_attn_fwd)\n else:\n base_config = {\n \"BLOCK_M\": BLOCK_M,\n \"BLOCK_N\": BLOCK_N,\n \"num_stages\": num_stages,\n \"num_warps\": num_warps,\n }\n grid = (\n triton.cdiv(seq_len, base_config[\"BLOCK_M\"]),\n batch_size * num_heads_k,\n groups,\n )\n const_kwargs.update(base_config)\n flash_attn = _triton_attn_fwd\n flash_attn[grid](*kwargs, **const_kwargs)\n return o\n", - "description_1": "Use triton language to implement a flash attention mechanism. This mechanism includes two Triton kernels: `_attn_fwd_inner` and `_triton_attn_fwd`. The `_attn_fwd_inner` kernel performs attention accumulation operations within specified blocks of the Q, K, and V matrices, adjusting for specific stages and computing contextually appropriate transformations. The `_triton_attn_fwd` kernel uses these operations to compute attention outputs over larger tensor structures, incorporating various parameter strides, shapes, and computation stages. The main call function `triton_flash_attention_forward` facilitates setting up these computations, addressing tensor dimensions and the autotuning of kernels for performance optimization. It has 6 parameters: `q`, `k`, `v`, `causal`, `sm_scale`, `ORDER_12` representing the input tensors, computation mode, scaling factors and processing orders.", - "description_2": "Use triton language to implement an attention mechanism using custom kernels for optimized tensor computations. Define kernels to handle the accumulation and computation of tensor products in blocks, and configure a main function to execute these kernels while adjusting for performance and input constraints.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\nfrom flashnn.kernel_backend import get_autotune_triton_kernels\nfrom flashnn.triton_kernels.triton_utils import compile_and_cache_kernels\n\n@triton.jit\ndef _fused_moe_kernel_a16w4_perchannel(\n A, B, C, scale_b_ptr, zero_points_ptr, topk_weights_ptr, sorted_token_ids_ptr, expert_ids_ptr, num_tokens_post_padded_ptr,\n N, K, EM, num_valid_tokens,\n stride_am, stride_ak, stride_be, stride_bn, stride_bk, stride_cm, stride_cn, stride_scale_be, stride_scale_bn, stride_scale_bk, stride_zero_points_e, stride_zero_points_n, stride_zero_points_k,\n BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr, MUL_ROUTED_WEIGHT: tl.constexpr, top_k: tl.constexpr, add_zero_points: tl.constexpr\n):\n pid = tl.program_id(axis=0)\n num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M)\n num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)\n num_pid_in_group = GROUP_SIZE_M * num_pid_n\n group_id = pid // num_pid_in_group\n first_pid_m = group_id * GROUP_SIZE_M\n group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)\n pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)\n pid_n = (pid % num_pid_in_group) // group_size_m\n\n num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr)\n if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded:\n return\n offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_token = tl.load(sorted_token_ids_ptr + offs_token_id)\n token_mask = offs_token < num_valid_tokens\n\n offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N * 2) // 2) % N\n offs_k = tl.arange(0, BLOCK_SIZE_K)\n a_ptrs = A + (offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak)\n\n off_experts = tl.load(expert_ids_ptr + pid_m)\n b_ptrs = B + off_experts * stride_be + (offs_k[None, :] * stride_bk + offs_bn[:, None] * stride_bn)\n\n if add_zero_points:\n offs_zero_points = pid_n * BLOCK_SIZE_N * 2 + tl.arange(0, 2 * BLOCK_SIZE_N)\n zero_points_ptrs = zero_points_ptr + off_experts * stride_zero_points_e + offs_zero_points\n _ZERO_POINT0 = tl.zeros([1], dtype=zero_points_ptr.dtype.element_ty)\n zero_points_vals = tl.load(zero_points_ptrs, mask=offs_zero_points < 2 * N, other=_ZERO_POINT0)\n\n _A0 = tl.zeros([1, 1], dtype=A.dtype.element_ty)\n _B0 = tl.zeros([1, 1], dtype=B.dtype.element_ty)\n accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N * 2), dtype=tl.float32)\n l_shifter = (1 - tl.arange(0, BLOCK_SIZE_N * 2) % 2) * 4\n for k in range(tl.cdiv(K, BLOCK_SIZE_K)):\n a = tl.load(a_ptrs, mask=token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K), other=_A0)\n b = tl.load(b_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=_B0)\n b = (b << l_shifter[:, None]).to(tl.int8).__rshift__(4)\n if add_zero_points:\n b -= zero_points_vals[:, None]\n b = tl.trans(b)\n b = b.to(a_ptrs.dtype.element_ty)\n accumulator += tl.dot(a, b, out_dtype=tl.float32)\n a_ptrs += BLOCK_SIZE_K * stride_ak\n b_ptrs += BLOCK_SIZE_K * stride_bk\n\n offs_scale = pid_n * BLOCK_SIZE_N * 2 + tl.arange(0, BLOCK_SIZE_N * 2)\n scale_ptrs = scale_b_ptr + off_experts * stride_scale_be + offs_scale * stride_scale_bn\n _SCALE0 = tl.zeros([1], dtype=scale_b_ptr.dtype.element_ty)\n scales = tl.load(scale_ptrs, mask=offs_scale < 2 * N, other=_SCALE0)\n accumulator *= scales[None, :]\n\n if MUL_ROUTED_WEIGHT:\n moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0.0)\n accumulator = accumulator * moe_weight[:, None]\n\n accumulator = accumulator.to(A.dtype.element_ty)\n\n offs_cn = pid_n * BLOCK_SIZE_N * 2 + tl.arange(0, BLOCK_SIZE_N * 2)\n c_ptrs = C + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :]\n c_mask = token_mask[:, None] & (offs_cn[None, :] < N * 2)\n tl.store(c_ptrs, accumulator, mask=c_mask)\n\n\n@triton.jit\ndef _fused_moe_kernel_a16w4_subchannel(\n A, B, C, scale_b_ptr, zero_points_ptr, topk_weights_ptr, sorted_token_ids_ptr, expert_ids_ptr, num_tokens_post_padded_ptr,\n N, K, EM, num_valid_tokens,\n stride_am, stride_ak, stride_be, stride_bn, stride_bk, stride_cm, stride_cn, stride_scale_be, stride_scale_bn, stride_scale_bk, stride_zero_points_e, stride_zero_points_n, stride_zero_points_k,\n BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr, MUL_ROUTED_WEIGHT: tl.constexpr, top_k: tl.constexpr, add_zero_points: tl.constexpr\n):\n pid = tl.program_id(axis=0)\n num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M)\n num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)\n num_pid_in_group = GROUP_SIZE_M * num_pid_n\n group_id = pid // num_pid_in_group\n first_pid_m = group_id * GROUP_SIZE_M\n group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)\n pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)\n pid_n = (pid % num_pid_in_group) // group_size_m\n\n num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr)\n if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded:\n return\n offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_token = tl.load(sorted_token_ids_ptr + offs_token_id)\n token_mask = offs_token < num_valid_tokens\n\n offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N * 2) // 2) % N\n offs_k = tl.arange(0, BLOCK_SIZE_K)\n a_ptrs = A + (offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak)\n\n off_experts = tl.load(expert_ids_ptr + pid_m)\n b_ptrs = B + off_experts * stride_be + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)\n\n if add_zero_points:\n offs_zp_n = (pid_n * BLOCK_SIZE_N * 2 + tl.arange(0, 2 * BLOCK_SIZE_N)) % (2 * N)\n _ZERO_POINT0 = tl.zeros([1], dtype=zero_points_ptr.dtype.element_ty)\n\n _A0 = tl.zeros([1, 1], dtype=A.dtype.element_ty)\n _B0 = tl.zeros([1, 1], dtype=B.dtype.element_ty)\n _SCALE0 = tl.zeros([1], dtype=scale_b_ptr.dtype.element_ty)\n accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N * 2), dtype=tl.float32)\n l_shifter = (1 - tl.arange(0, BLOCK_SIZE_N * 2) % 2) * 4\n for k in range(tl.cdiv(K, BLOCK_SIZE_K)):\n a = tl.load(a_ptrs, mask=token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K), other=_A0)\n b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=_B0)\n b = (b << l_shifter[None, :]).to(tl.int8).__rshift__(4)\n if add_zero_points:\n zp_ptrs = zero_points_ptr + off_experts * stride_zero_points_e + offs_zp_n * stride_zero_points_n + k\n zero_points_vals = tl.load(zp_ptrs)\n b = b - zero_points_vals[None, :]\n offs_scale_n = pid_n * BLOCK_SIZE_N * 2 + tl.arange(0, 2 * BLOCK_SIZE_N)\n scale_b_ptrs = scale_b_ptr + off_experts * stride_scale_be + offs_scale_n * stride_scale_bn + k\n scales_val = tl.load(scale_b_ptrs, mask=offs_scale_n < 2 * N, other=_SCALE0)\n b = b * scales_val[None, :]\n accumulator += tl.dot(a, b, out_dtype=tl.float32)\n a_ptrs += BLOCK_SIZE_K * stride_ak\n b_ptrs += BLOCK_SIZE_K * stride_bk\n\n if MUL_ROUTED_WEIGHT:\n moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0.0)\n accumulator = accumulator * moe_weight[:, None]\n\n accumulator = accumulator.to(A.dtype.element_ty)\n\n offs_cn = pid_n * BLOCK_SIZE_N * 2 + tl.arange(0, BLOCK_SIZE_N * 2)\n c_ptrs = C + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :]\n c_mask = token_mask[:, None] & (offs_cn[None, :] < N * 2)\n tl.store(c_ptrs, accumulator, mask=c_mask)\n\n\ndef fused_moe_a16w4_forward(\n A: torch.Tensor, B: torch.Tensor, C: torch.Tensor, scale_b: torch.Tensor, zero_points: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, sorted_token_ids: torch.Tensor, expert_ids: torch.Tensor, num_tokens_post_padded: torch.Tensor,\n mul_routed_weight: bool, top_k: int, BM: int\n):\n assert topk_weights.stride(1) == 1\n assert sorted_token_ids.stride(0) == 1\n assert B.shape[1] % 16 == 0 and B.shape[2] % 16 == 0\n N, K, EM, num_valid_tokens = B.shape[1], B.shape[2], sorted_token_ids.shape[0], topk_ids.numel()\n\n add_zero_points = True if zero_points is not None else False\n is_perchannel = scale_b.dim() == 2\n\n kwargs = [\n A, B, C, scale_b, zero_points, topk_weights, sorted_token_ids, expert_ids, num_tokens_post_padded,\n N, K, EM, num_valid_tokens,\n A.stride(0), A.stride(1), B.stride(0), B.stride(1), B.stride(2), C.stride(1), C.stride(2), scale_b.stride(0), scale_b.stride(1), scale_b.stride(-1)\n ]\n kwargs += [1, 1, 1] if not add_zero_points else [zero_points.stride(0), zero_points.stride(1), zero_points.stride(-1)]\n\n const_kwargs = {\"MUL_ROUTED_WEIGHT\": mul_routed_weight, \"top_k\": top_k}\n const_kwargs.update({\"add_zero_points\": add_zero_points})\n if not is_perchannel:\n k_per_scale = B.shape[-1] // scale_b.shape[-1]\n const_kwargs.update({\"BLOCK_SIZE_K\": k_per_scale})\n\n method_name = \"fuse_moe_a16w4_\" + \"_\".join(str(value) for value in const_kwargs.values())\n method_name += \"_\"\n method_name += \"_\".join(str(value) for value in [BM, N, K, triton.next_power_of_2(EM)])\n method_name += \"_perchannel\" if is_perchannel else \"_subchannel\"\n\n moe_kernel = _fused_moe_kernel_a16w4_perchannel if is_perchannel else _fused_moe_kernel_a16w4_subchannel\n\n if get_autotune_triton_kernels():\n def grid(META):\n return triton.cdiv(sorted_token_ids.shape[0], META[\"BLOCK_SIZE_M\"]) * triton.cdiv(B.shape[1], META[\"BLOCK_SIZE_N\"]), 1, 1\n\n moe_kernel = triton.autotune(\n configs=_get_a16w4_configs(BM, is_perchannel=is_perchannel),\n key=[\"N\", \"K\", \"EM\"],\n )(moe_kernel)\n else:\n BN, BK, GM, stages, num_warps = 32, 64, 8, 2, 4\n base_config = {\"BLOCK_SIZE_M\": BM, \"BLOCK_SIZE_N\": BN, \"GROUP_SIZE_M\": GM, \"num_stages\": stages, \"num_warps\": num_warps}\n if is_perchannel:\n base_config.update({\"BLOCK_SIZE_K\": BK})\n grid = triton.cdiv(sorted_token_ids.shape[0], base_config[\"BLOCK_SIZE_M\"]) * triton.cdiv(B.shape[1], base_config[\"BLOCK_SIZE_N\"]), 1, 1\n const_kwargs.update(base_config)\n\n compile_and_cache_kernels(\n moe_kernel,\n method_name,\n grid,\n kwargs,\n const_kwargs=const_kwargs,\n )\n", - "description_1": "Use triton language to implement a kernel function for fused Mixture of Experts (MoE) with A16W4 using token and expert matrices. This involves two kernels, _fused_moe_kernel_a16w4_perchannel and _fused_moe_kernel_a16w4_subchannel, each decorated with @triton.jit. The kernels process the input tensors A, B, and C along with additional parameters like scale_b_ptr, zero_points_ptr, and topk_weights_ptr, among others. The kernels perform operations such as loading tensor blocks, performing matrix multiplications, and applying scaling. The function fused_moe_a16w4_forward acts as the main wrapper, preparing the necessary arguments and invoking the appropriate Triton kernel.", - "description_2": "Use triton language to create kernels that handle fused computation for Mixture of Experts with A16W4. These kernels process input tensors through block-wise operations and specialized matrix multiplications.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef _fused_moe_a8w8_kernel(\n A,\n B,\n C,\n alpha_row_ptr,\n alpha_col_ptr,\n topk_weights_ptr,\n sorted_token_ids_ptr,\n expert_ids_ptr,\n num_tokens_post_padded_ptr,\n N,\n K,\n EM,\n num_valid_tokens,\n stride_am,\n stride_ak,\n stride_be,\n stride_bn,\n stride_bk,\n stride_cm,\n stride_cn,\n stride_scale_be,\n stride_scale_bn,\n BLOCK_SIZE_M: tl.constexpr,\n BLOCK_SIZE_N: tl.constexpr,\n BLOCK_SIZE_K: tl.constexpr,\n GROUP_SIZE_M: tl.constexpr,\n MUL_ROUTED_WEIGHT: tl.constexpr,\n top_k: tl.constexpr,\n):\n pid = tl.program_id(axis=0)\n num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M)\n num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)\n num_pid_in_group = GROUP_SIZE_M * num_pid_n\n group_id = pid // num_pid_in_group\n first_pid_m = group_id * GROUP_SIZE_M\n group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)\n pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)\n pid_n = (pid % num_pid_in_group) // group_size_m\n\n num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr)\n if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded:\n return\n offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_token = tl.load(sorted_token_ids_ptr + offs_token_id)\n token_mask = offs_token < num_valid_tokens\n\n offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N\n offs_k = tl.arange(0, BLOCK_SIZE_K)\n a_ptrs = A + (\n offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak\n )\n\n off_experts = tl.load(expert_ids_ptr + pid_m)\n b_ptrs = (\n B\n + off_experts * stride_be\n + (offs_bn[None, :] * stride_bn + offs_k[:, None] * stride_bk)\n )\n\n accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.int32)\n _A0 = tl.zeros([1, 1], dtype=a_ptrs.dtype.element_ty)\n _B0 = tl.zeros([1, 1], dtype=b_ptrs.dtype.element_ty)\n lo = 0\n hi = tl.cdiv(K, BLOCK_SIZE_K)\n for k in range(lo, hi - 1):\n a = tl.load(\n a_ptrs,\n mask=token_mask[:, None],\n other=_A0,\n )\n b = tl.load(b_ptrs)\n accumulator += tl.dot(a, b)\n a_ptrs += BLOCK_SIZE_K * stride_ak\n b_ptrs += BLOCK_SIZE_K * stride_bk\n for k in range(hi - 1, hi):\n a = tl.load(\n a_ptrs,\n mask=token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K),\n other=_A0,\n )\n b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=_B0)\n accumulator += tl.dot(a, b)\n a_ptrs += BLOCK_SIZE_K * stride_ak\n b_ptrs += BLOCK_SIZE_K * stride_bk\n\n offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_token = tl.load(sorted_token_ids_ptr + offs_token_id)\n token_mask = offs_token < num_valid_tokens\n offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n alpha_row_ptrs = alpha_row_ptr + offs_token // top_k\n alpha_col_ptrs = alpha_col_ptr + off_experts * stride_scale_be + offs_cn\n _ALPHA0 = tl.zeros([1], dtype=alpha_row_ptr.dtype.element_ty)\n alpha_row = tl.load(alpha_row_ptrs, mask=token_mask, other=_ALPHA0).to(tl.float32)\n alpha_col = tl.load(alpha_col_ptrs, mask=offs_cn < N, other=_ALPHA0).to(tl.float32)\n accumulator = accumulator * alpha_row[:, None]\n accumulator = accumulator * alpha_col[None, :]\n\n if MUL_ROUTED_WEIGHT:\n moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0)\n accumulator = accumulator * moe_weight[:, None]\n\n accumulator = accumulator.to(tl.float16)\n c_ptrs = C + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :]\n c_mask = token_mask[:, None] & (offs_cn[None, :] < N)\n tl.store(c_ptrs, accumulator, mask=c_mask)\n\n\ndef fused_moe_a8w8_forward(\n A: torch.Tensor,\n B: torch.Tensor,\n C: torch.Tensor,\n alpha_row_ptr: torch.Tensor,\n alpha_col_ptr: torch.Tensor,\n topk_weights: torch.Tensor,\n topk_ids: torch.Tensor,\n sorted_token_ids: torch.Tensor,\n expert_ids: torch.Tensor,\n num_tokens_post_padded: torch.Tensor,\n mul_routed_weight: bool,\n top_k: int,\n BM: int,\n):\n N, K, EM, num_valid_tokens = (\n B.shape[1],\n B.shape[2],\n sorted_token_ids.shape[0],\n topk_ids.numel(),\n )\n kwargs = [\n A,\n B,\n C,\n alpha_row_ptr,\n alpha_col_ptr,\n topk_weights,\n sorted_token_ids,\n expert_ids,\n num_tokens_post_padded,\n N,\n K,\n EM,\n num_valid_tokens,\n A.stride(0),\n A.stride(1),\n B.stride(0),\n B.stride(1),\n B.stride(2),\n C.stride(1),\n C.stride(2),\n alpha_col_ptr.stride(0),\n alpha_col_ptr.stride(1),\n ]\n\n const_kwargs = {\n \"MUL_ROUTED_WEIGHT\": mul_routed_weight,\n \"top_k\": top_k,\n }\n\n method_name = \"fuse_moe_a8w8_\" + \"_\".join(\n str(value) for value in const_kwargs.values()\n )\n method_name += \"_\"\n method_name += \"_\".join(str(value) for value in [BM, N, K, triton.next_power_of_2(EM)])\n moe_kernel = _fused_moe_a8w8_kernel\n\n base_config = {\n \"BLOCK_SIZE_M\": 32,\n \"BLOCK_SIZE_N\": 64,\n \"BLOCK_SIZE_K\": 64,\n \"GROUP_SIZE_M\": 8,\n \"num_stages\": 2,\n \"num_warps\": 4,\n }\n grid = (\n triton.cdiv(sorted_token_ids.shape[0], base_config[\"BLOCK_SIZE_M\"])\n * triton.cdiv(B.shape[1], base_config[\"BLOCK_SIZE_N\"]),\n 1,\n 1,\n )\n const_kwargs.update(base_config)\n\n compile_and_cache_kernels(\n moe_kernel,\n method_name,\n grid,\n kwargs,\n const_kwargs=const_kwargs,\n )\n", - "description_1": "Use triton language to implement a fused Mixture of Experts (MOE) kernel. The kernel takes 24 parameters: pointers to matrices A, B, C, alpha_row_ptr, alpha_col_ptr, topk_weights_ptr, sorted_token_ids_ptr, expert_ids_ptr, num_tokens_post_padded_ptr, and integers N, K, EM, num_valid_tokens, stride_am, stride_ak, stride_be, stride_bn, stride_bk, stride_cm, stride_cn, stride_scale_be, stride_scale_bn, and constexpr values BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K, GROUP_SIZE_M, MUL_ROUTED_WEIGHT, top_k. The kernel computes a block of the C matrix by iterating over blocks of A and B, applying masks, and accumulating results. It also applies scaling factors and optional routing weights before storing the result.", - "description_2": "Use triton language to implement a forward function for the fused MOE kernel. The function takes 13 parameters: tensors A, B, C, alpha_row_ptr, alpha_col_ptr, topk_weights, topk_ids, sorted_token_ids, expert_ids, num_tokens_post_padded, a boolean mul_routed_weight, an integer top_k, and an integer BM. It prepares the necessary arguments and configurations for the kernel, including calculating dimensions and strides, and then compiles and caches the kernel for execution.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef _fused_moe_kernel(\n A,\n B,\n C,\n topk_weights_ptr,\n sorted_token_ids_ptr,\n expert_ids_ptr,\n num_tokens_post_padded_ptr,\n N,\n K,\n EM,\n num_valid_tokens,\n stride_am,\n stride_ak,\n stride_be,\n stride_bn,\n stride_bk,\n stride_cm,\n stride_cn,\n BLOCK_SIZE_M: tl.constexpr,\n BLOCK_SIZE_N: tl.constexpr,\n BLOCK_SIZE_K: tl.constexpr,\n GROUP_SIZE_M: tl.constexpr,\n MUL_ROUTED_WEIGHT: tl.constexpr,\n top_k: tl.constexpr,\n):\n pid = tl.program_id(axis=0)\n num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M)\n num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)\n num_pid_in_group = GROUP_SIZE_M * num_pid_n\n group_id = pid // num_pid_in_group\n first_pid_m = group_id * GROUP_SIZE_M\n group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)\n pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)\n pid_n = (pid % num_pid_in_group) // group_size_m\n\n num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr)\n if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded:\n return\n offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_token = tl.load(sorted_token_ids_ptr + offs_token_id)\n token_mask = offs_token < num_valid_tokens\n\n offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N\n offs_k = tl.arange(0, BLOCK_SIZE_K)\n a_ptrs = A + (\n offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak\n )\n\n off_experts = tl.load(expert_ids_ptr + pid_m)\n b_ptrs = (\n B\n + off_experts * stride_be\n + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)\n )\n\n accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n _A0 = tl.zeros([1, 1], dtype=a_ptrs.dtype.element_ty)\n _B0 = tl.zeros([1, 1], dtype=b_ptrs.dtype.element_ty)\n for k in range(tl.cdiv(K, BLOCK_SIZE_K)):\n a = tl.load(\n a_ptrs,\n mask=token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K),\n other=_A0,\n )\n b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=_B0)\n accumulator += tl.dot(a, b)\n a_ptrs += BLOCK_SIZE_K * stride_ak\n b_ptrs += BLOCK_SIZE_K * stride_bk\n\n if MUL_ROUTED_WEIGHT:\n moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0)\n accumulator = accumulator * moe_weight[:, None]\n\n accumulator = accumulator.to(A.dtype.element_ty)\n offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n c_ptrs = C + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :]\n c_mask = token_mask[:, None] & (offs_cn[None, :] < N)\n tl.store(c_ptrs, accumulator, mask=c_mask)\n\n\ndef fused_moe_fp16_forward(\n A: torch.Tensor,\n B: torch.Tensor,\n C: torch.Tensor,\n topk_weights: torch.Tensor,\n topk_ids: torch.Tensor,\n sorted_token_ids: torch.Tensor,\n expert_ids: torch.Tensor,\n num_tokens_post_padded: torch.Tensor,\n mul_routed_weight: bool,\n top_k: int,\n BM: int,\n):\n assert topk_weights.stride(1) == 1\n assert sorted_token_ids.stride(0) == 1\n assert B.shape[1] % 16 == 0 and B.shape[2] % 16 == 0\n\n N, K, EM, num_valid_tokens = (\n B.shape[1],\n B.shape[2],\n sorted_token_ids.shape[0],\n topk_ids.numel(),\n )\n kwargs = [\n A,\n B,\n C,\n topk_weights,\n sorted_token_ids,\n expert_ids,\n num_tokens_post_padded,\n N,\n K,\n EM,\n num_valid_tokens,\n A.stride(0),\n A.stride(1),\n B.stride(0),\n B.stride(1),\n B.stride(2),\n C.stride(1),\n C.stride(2),\n ]\n\n const_kwargs = {\n \"MUL_ROUTED_WEIGHT\": mul_routed_weight,\n \"top_k\": top_k,\n }\n\n method_name = \"fuse_moe_a16w16_\" + \"_\".join(\n str(value) for value in const_kwargs.values()\n )\n method_name += \"_\"\n method_name += \"_\".join(str(value) for value in [BM, N, K, triton.next_power_of_2(EM)])\n moe_kernel = _fused_moe_kernel\n\n grid = (\n triton.cdiv(sorted_token_ids.shape[0], BM)\n * triton.cdiv(B.shape[1], 32),\n 1,\n 1,\n )\n const_kwargs.update({\n \"BLOCK_SIZE_M\": BM,\n \"BLOCK_SIZE_N\": 32,\n \"BLOCK_SIZE_K\": 64,\n \"GROUP_SIZE_M\": 8,\n \"num_stages\": 2,\n \"num_warps\": 4,\n })\n\n compile_and_cache_kernels(\n moe_kernel,\n method_name,\n grid,\n kwargs,\n const_kwargs=const_kwargs,\n )\n", - "description_1": "Use triton language to implement a fused Mixture of Experts (MOE) kernel. The kernel takes pointers to matrices A, B, C, and additional parameters for matrix dimensions and strides. It computes a block of the C matrix by iterating over the K dimension and accumulating results. The kernel supports optional multiplication by routed weights and writes back the computed block to the output matrix C.", - "description_2": "Use triton language to implement a fused MOE forward function. This function prepares the necessary parameters and configurations for the MOE kernel, including grid size and constant kernel arguments. It then compiles and caches the kernel for execution.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef _triton_gemm_a16w4_per_channel_kernel(\n A, B, C, scale_b, bias, zero_points, M, N, K,\n rescale_m, rescale_n, rescale_k, stride_am, stride_ak,\n stride_bn, stride_bk, stride_cm, stride_cn, stride_zpk,\n stride_zpn, stride_scalek, stride_scalen, add_bias: tl.constexpr,\n add_zero_points: tl.constexpr, BLOCK_M: tl.constexpr,\n BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, GROUP_M: tl.constexpr,\n SPLIT_K: tl.constexpr,\n):\n pid = tl.program_id(0)\n pid_z = tl.program_id(1)\n grid_m = tl.cdiv(M, BLOCK_M)\n grid_n = tl.cdiv(N, BLOCK_N)\n width = GROUP_M * grid_n\n group_id = pid // width\n group_size = min(grid_m - group_id * GROUP_M, GROUP_M)\n pid_m = group_id * GROUP_M + (pid % group_size)\n pid_n = (pid % width) // (group_size)\n rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)\n rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)\n ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)\n rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)\n rk = pid_z * BLOCK_K + tl.arange(0, BLOCK_K)\n A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)\n B = B + (rbn[:, None] * stride_bn + rk[None, :] * stride_bk)\n acc_l = tl.zeros((BLOCK_N, BLOCK_M), dtype=tl.float32)\n acc_h = tl.zeros((BLOCK_N, BLOCK_M), dtype=tl.float32)\n _A0 = tl.zeros((1, 1), dtype=A.dtype.element_ty)\n _B0 = tl.zeros((1, 1), dtype=B.dtype.element_ty)\n if add_zero_points:\n offs_zero_points = pid_n * BLOCK_N * 2 + tl.arange(0, 2 * BLOCK_N)\n zero_points_ptrs = zero_points + offs_zero_points\n _ZERO_POINT0 = tl.zeros([1], dtype=zero_points.dtype.element_ty)\n zero_points_vals = tl.load(\n zero_points_ptrs, mask=offs_zero_points < 2 * N, other=_ZERO_POINT0\n )\n zero_points_vals = tl.reshape(zero_points_vals, (BLOCK_N, 2))\n (zp_l, zp_h) = tl.split(zero_points_vals)\n offs_scale = pid_n * BLOCK_N * 2 + tl.arange(0, 2 * BLOCK_N)\n scale_ptrs = scale_b + offs_scale\n _SCALE0 = tl.zeros([1], dtype=scale_b.dtype.element_ty)\n scales = tl.load(scale_ptrs, mask=offs_scale < 2 * N, other=_SCALE0)\n for k in range(0, tl.cdiv(K, BLOCK_K * SPLIT_K)):\n k_remaining = K - k * (BLOCK_K * SPLIT_K)\n b_int4_two = tl.load(B, mask=rk[None, :] < k_remaining, other=_B0)\n b_int4_l = (\n b_int4_two.__lshift__(4).to(tl.int8).__rshift__(4).to(A.dtype.element_ty)\n )\n b_int4_h = b_int4_two.__rshift__(4).to(A.dtype.element_ty)\n a = tl.load(A, mask=rk[None, :] < k_remaining, other=_A0)\n a = tl.trans(a)\n if add_zero_points:\n b_int4_l -= zp_l[:, None]\n b_int4_h -= zp_h[:, None]\n acc_l += tl.dot(b_int4_l, a, out_dtype=tl.float32, allow_tf32=True)\n acc_h += tl.dot(b_int4_h, a, out_dtype=tl.float32, allow_tf32=True)\n A += BLOCK_K * SPLIT_K * stride_ak\n B += BLOCK_K * SPLIT_K * stride_bk\n acc_l = tl.trans(acc_l)\n acc_h = tl.trans(acc_h)\n acc = tl.interleave(acc_l, acc_h)\n offs_scale = pid_n * BLOCK_N * 2 + tl.arange(0, 2 * BLOCK_N)\n scale_ptrs = scale_b + offs_scale\n _SCALE0 = tl.zeros([1], dtype=scale_b.dtype.element_ty)\n scales = tl.load(scale_ptrs, mask=offs_scale < 2 * N, other=_SCALE0)\n acc *= scales[None, :]\n acc = acc.to(C.dtype.element_ty)\n rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)\n rn = pid_n * BLOCK_N * 2 + tl.arange(0, 2 * BLOCK_N)\n mask = (rm < M)[:, None] & (rn < 2 * N)[None, :]\n if add_bias:\n offs_bias = pid_n * BLOCK_N * 2 + tl.arange(0, 2 * BLOCK_N)\n bias_ptrs = bias + offs_bias\n _BIAS0 = tl.zeros([1], dtype=bias.dtype.element_ty)\n bias_vals = tl.load(bias_ptrs, mask=offs_bias < 2 * N, other=_BIAS0)\n if pid_z == 0:\n acc += bias_vals[None, :]\n if SPLIT_K == 1:\n tl.store(C + rm[:, None] * stride_cm + rn[None, :], acc, mask=mask)\n else:\n tl.atomic_add(C + rm[:, None] * stride_cm + rn[None, :], acc, mask=mask)\n\n@triton.jit\ndef _triton_gemm_a16w4_sub_channel_kernel(\n A, B, C, scale_b, bias, zero_points, M, N, K,\n rescale_m, rescale_n, rescale_k, stride_am, stride_ak,\n stride_bn, stride_bk, stride_cm, stride_cn, stride_zpk,\n stride_zpn, stride_scalek, stride_scalen, add_bias: tl.constexpr,\n add_zero_points: tl.constexpr, BLOCK_M: tl.constexpr,\n BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, GROUP_M: tl.constexpr,\n SPLIT_K: tl.constexpr,\n):\n pid = tl.program_id(0)\n pid_z = tl.program_id(1)\n grid_m = tl.cdiv(M, BLOCK_M)\n grid_n = tl.cdiv(N, BLOCK_N)\n width = GROUP_M * grid_n\n group_id = pid // width\n group_size = min(grid_m - group_id * GROUP_M, GROUP_M)\n pid_m = group_id * GROUP_M + (pid % group_size)\n pid_n = (pid % width) // (group_size)\n rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)\n rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)\n ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)\n rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)\n rk = pid_z * BLOCK_K + tl.arange(0, BLOCK_K)\n A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)\n B = B + (rbn[:, None] * stride_bn + rk[None, :] * stride_bk)\n acc_l = tl.zeros((BLOCK_N, BLOCK_M), dtype=tl.float32)\n acc_h = tl.zeros((BLOCK_N, BLOCK_M), dtype=tl.float32)\n _A0 = tl.zeros((1, 1), dtype=A.dtype.element_ty)\n _B0 = tl.zeros((1, 1), dtype=B.dtype.element_ty)\n if add_zero_points:\n zero_points_offs = pid_n * BLOCK_N * 2 + tl.arange(0, 2 * BLOCK_N)\n _ZERO_POINT0 = tl.zeros([1], dtype=zero_points.dtype.element_ty)\n scale_offs = pid_n * BLOCK_N * 2 + tl.arange(0, 2 * BLOCK_N)\n _SCALE0 = tl.zeros([1], dtype=scale_b.dtype.element_ty)\n for k in range(0, tl.cdiv(K, BLOCK_K * SPLIT_K)):\n k_remaining = K - k * (BLOCK_K * SPLIT_K)\n b_int4_two = tl.load(B, mask=rk[None, :] < k_remaining, other=_B0)\n b_int4_l = b_int4_two.__lshift__(4).to(tl.int8).__rshift__(4)\n b_int4_h = b_int4_two.__rshift__(4)\n if add_zero_points:\n zero_points_ptrs = (\n zero_points\n + k * SPLIT_K * stride_zpk\n + pid_z * stride_zpk\n + zero_points_offs\n )\n zero_points_vals = tl.load(\n zero_points_ptrs, mask=zero_points_offs < 2 * N, other=_ZERO_POINT0\n )\n zero_points_vals = tl.reshape(zero_points_vals, (BLOCK_N, 2))\n (zp_l, zp_h) = tl.split(zero_points_vals)\n b_int4_l -= zp_l[:, None]\n b_int4_h -= zp_h[:, None]\n scales_val = tl.load(\n scale_b + k * SPLIT_K * stride_scalek + pid_z * stride_scalek + scale_offs,\n mask=scale_offs < 2 * N,\n other=_SCALE0,\n )\n scales_val = tl.reshape(scales_val, (BLOCK_N, 2))\n (scale_l, scale_h) = tl.split(scales_val)\n b_int4_l = b_int4_l * scale_l[:, None]\n b_int4_h = b_int4_h * scale_h[:, None]\n a = tl.load(A, mask=rk[None, :] < k_remaining, other=_A0)\n a = tl.trans(a)\n acc_l += tl.dot(b_int4_l, a, out_dtype=tl.float32, allow_tf32=True)\n acc_h += tl.dot(b_int4_h, a, out_dtype=tl.float32, allow_tf32=True)\n A += BLOCK_K * SPLIT_K * stride_ak\n B += BLOCK_K * SPLIT_K * stride_bk\n acc_l = tl.trans(acc_l)\n acc_h = tl.trans(acc_h)\n acc = tl.interleave(acc_l, acc_h)\n rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)\n rn = pid_n * BLOCK_N * 2 + tl.arange(0, 2 * BLOCK_N)\n mask = (rm < M)[:, None] & (rn < 2 * N)[None, :]\n if add_bias:\n offs_bias = pid_n * BLOCK_N * 2 + tl.arange(0, 2 * BLOCK_N)\n bias_ptrs = bias + offs_bias\n _BIAS0 = tl.zeros([1], dtype=bias.dtype.element_ty)\n bias_vals = tl.load(bias_ptrs, mask=offs_bias < 2 * N, other=_BIAS0)\n if pid_z == 0:\n acc += bias_vals[None, :]\n if SPLIT_K == 1:\n tl.store(C + rm[:, None] * stride_cm + rn[None, :], acc, mask=mask)\n else:\n tl.atomic_add(C + rm[:, None] * stride_cm + rn[None, :], acc, mask=mask)\n\ndef triton_gemm_a16w4_forward(out, act, quant_w, scale_w, bias=None, zero_points=None):\n assert quant_w.dtype == torch.int8, \"Weight must be int8 type\"\n assert act.is_contiguous(), \"Activation must be contiguous\"\n assert quant_w.is_contiguous(), \"Weight must be contiguous\"\n assert act.shape[1] == quant_w.shape[1], \"Matrix B must be transposed\"\n\n scale_w = scale_w.squeeze()\n\n M, K = act.shape\n N, K = quant_w.shape\n\n add_bias = True if bias is not None else False\n add_zero_points = True if zero_points is not None else False\n is_perchannel = scale_w.dim() == 1\n\n rescale_m = M // 16\n rescale_n = N // 512\n rescale_k = K // 512\n\n def grid(META):\n return (\n triton.cdiv(M, META[\"BLOCK_M\"]) * triton.cdiv(N, META[\"BLOCK_N\"]),\n META[\"SPLIT_K\"],\n )\n\n kwargs = {\n \"A\": act,\n \"B\": quant_w,\n \"C\": out,\n \"scale_b\": scale_w,\n \"bias\": bias,\n \"zero_points\": zero_points,\n \"M\": M,\n \"N\": N,\n \"K\": K,\n \"rescale_m\": rescale_m,\n \"rescale_n\": rescale_n,\n \"rescale_k\": rescale_k,\n \"stride_am\": act.stride(0),\n \"stride_ak\": act.stride(1),\n \"stride_bn\": quant_w.stride(0),\n \"stride_bk\": quant_w.stride(1),\n \"stride_cm\": out.stride(0),\n \"stride_cn\": out.stride(1),\n \"stride_zpk\": zero_points.stride(0) if add_zero_points else 0,\n \"stride_zpn\": zero_points.stride(1)\n if add_zero_points and not is_perchannel\n else 0,\n \"stride_scalek\": 0 if is_perchannel else scale_w.stride(0),\n \"stride_scalen\": 0 if is_perchannel else scale_w.stride(1),\n \"add_bias\": add_bias,\n \"add_zero_points\": add_zero_points,\n }\n if scale_w.dim() == 1:\n triton_gemm_a16w4_per_channel = triton.autotune(\n configs=_get_autotune_configs(is_perchannel),\n key=[\"M\", \"N\", \"K\"],\n )(_triton_gemm_a16w4_per_channel_kernel)\n triton_gemm_a16w4_per_channel[grid](**kwargs)\n else:\n k_per_scale = int(act.shape[1] / scale_w.shape[0])\n assert k_per_scale > 0, \"k_per_scale should greater than 0\"\n triton_gemm_a16w4_sub_channel = triton.autotune(\n configs=_get_autotune_configs(is_perchannel),\n key=[\"M\", \"N\", \"K\"],\n )(_triton_gemm_a16w4_sub_channel_kernel)\n triton_gemm_a16w4_sub_channel[grid](BLOCK_K=k_per_scale, **kwargs)\n\n return out\n", - "description_1": "Use triton language to define and invoke kernels for matrix multiplication with per-channel and sub-channel quantization, supporting bias and zero-point adjustments.", - "description_2": "Use triton language to create kernels for GEMM operations with quantized weights.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef _triton_gemm_a16w8_per_channel_kernel(\n A, B, C, scale_b, bias, zero_points, M, N, K,\n stride_am, stride_ak, stride_bn, stride_bk,\n stride_cm, stride_cn, stride_zpk, stride_zpn,\n stride_scalek, stride_scalen, add_bias: tl.constexpr,\n add_zero_points: tl.constexpr, BLOCK_M: tl.constexpr,\n BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,\n GROUP_M: tl.constexpr, SPLIT_K: tl.constexpr):\n pid = tl.program_id(0)\n # for split k\n pid_z = tl.program_id(1)\n grid_m = tl.cdiv(M, BLOCK_M)\n grid_n = tl.cdiv(N, BLOCK_N)\n # re-order program ID for better L2 performance\n width = GROUP_M * grid_n\n group_id = pid // width\n group_size = min(grid_m - group_id * GROUP_M, GROUP_M)\n pid_m = group_id * GROUP_M + (pid % group_size)\n pid_n = (pid % width) // (group_size)\n # do matrix multiplication\n rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)\n rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)\n ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)\n rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)\n rk = pid_z * BLOCK_K + tl.arange(0, BLOCK_K)\n A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)\n B = B + (rbn[:, None] * stride_bn + rk[None, :] * stride_bk)\n acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)\n if add_zero_points:\n offs_zero_points = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)\n zero_points_ptrs = zero_points + offs_zero_points\n _ZERO_POINT0 = tl.zeros([1], dtype=zero_points.dtype.element_ty)\n zero_points_vals = tl.load(\n zero_points_ptrs, mask=offs_zero_points < N, other=_ZERO_POINT0\n )\n for k in range(0, tl.cdiv(K, BLOCK_K * SPLIT_K)):\n k_remaining = K - k * (BLOCK_K * SPLIT_K)\n _A0 = tl.zeros((1, 1), dtype=A.dtype.element_ty)\n a = tl.load(A, mask=rk[None, :] < k_remaining, other=_A0)\n _B0 = tl.zeros((1, 1), dtype=B.dtype.element_ty)\n b = tl.load(B, mask=rk[None, :] < k_remaining, other=_B0)\n\n if add_zero_points:\n b = b - zero_points_vals[:, None]\n\n b_fp = b.to(A.dtype.element_ty)\n b_fp = tl.trans(b_fp)\n acc += tl.dot(a, b_fp, out_dtype=tl.float32, allow_tf32=True)\n A += BLOCK_K * SPLIT_K * stride_ak\n B += BLOCK_K * SPLIT_K * stride_bk\n offs_scale = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)\n scale_ptrs = scale_b + offs_scale\n _SCALE0 = tl.zeros([1], dtype=scale_b.dtype.element_ty)\n scales = tl.load(scale_ptrs, mask=offs_scale < N, other=_SCALE0)\n acc *= scales\n acc = acc.to(C.dtype.element_ty)\n # rematerialize rm and rn to save registers\n rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)\n rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)\n C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)\n mask = (rm < M)[:, None] & (rn < N)[None, :]\n if add_bias:\n offs_bias = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)\n bias_ptrs = bias + offs_bias\n _BIAS0 = tl.zeros([1], dtype=bias.dtype.element_ty)\n bias_vals = tl.load(bias_ptrs, mask=offs_bias < N, other=_BIAS0)\n if pid_z == 0:\n acc += bias_vals[None, :]\n # Handles write-back with reduction-splitting.\n if SPLIT_K == 1:\n tl.store(C, acc, mask=mask)\n else:\n tl.atomic_add(C, acc, mask=mask)\n\n\n@triton.jit\ndef _triton_gemm_a16w8_sub_channel_kernel(\n A, B, C, scale_b, bias, zero_points, M, N, K,\n stride_am, stride_ak, stride_bn, stride_bk,\n stride_cm, stride_cn, stride_zpk, stride_zpn,\n stride_scalek, stride_scalen, add_bias: tl.constexpr,\n add_zero_points: tl.constexpr, BLOCK_M: tl.constexpr,\n BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,\n GROUP_M: tl.constexpr, SPLIT_K: tl.constexpr):\n pid = tl.program_id(0)\n # for split k\n pid_z = tl.program_id(1)\n grid_m = tl.cdiv(M, BLOCK_M)\n grid_n = tl.cdiv(N, BLOCK_N)\n # re-order program ID for better L2 performance\n width = GROUP_M * grid_n\n group_id = pid // width\n group_size = min(grid_m - group_id * GROUP_M, GROUP_M)\n pid_m = group_id * GROUP_M + (pid % group_size)\n pid_n = (pid % width) // (group_size)\n # do matrix multiplication\n rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)\n rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)\n ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)\n rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)\n rk = pid_z * BLOCK_K + tl.arange(0, BLOCK_K)\n A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)\n B = B + (rbn[:, None] * stride_bn + rk[None, :] * stride_bk)\n acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)\n scale_w_offs = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)\n _SCALE0 = tl.zeros([1], dtype=scale_b.dtype.element_ty)\n for k in range(0, tl.cdiv(K, BLOCK_K * SPLIT_K)):\n k_remaining = K - k * (BLOCK_K * SPLIT_K)\n _A0 = tl.zeros((1, 1), dtype=A.dtype.element_ty)\n a = tl.load(A, mask=rk[None, :] < k_remaining, other=_A0)\n _B0 = tl.zeros((1, 1), dtype=B.dtype.element_ty)\n b = tl.load(B, mask=rk[None, :] < k_remaining, other=_B0)\n if add_zero_points:\n _ZERO_POINT0 = tl.zeros([1], dtype=zero_points.dtype.element_ty)\n zero_points_offs = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)\n zero_points_ptrs = (\n zero_points + (k * SPLIT_K + pid_z) * stride_zpk + zero_points_offs\n )\n zero_points_vals = tl.load(\n zero_points_ptrs, mask=zero_points_offs < N, other=_ZERO_POINT0\n )\n b = b - zero_points_vals[:, None]\n scale_ptrs = (\n scale_b + k * SPLIT_K * stride_scalek + pid_z * stride_scalek + scale_w_offs\n )\n scales = tl.load(scale_ptrs, mask=scale_w_offs < N, other=_SCALE0)\n b_fp = b * scales[:, None]\n b_fp = tl.trans(b_fp)\n acc += tl.dot(a, b_fp, out_dtype=tl.float32, allow_tf32=True)\n A += BLOCK_K * SPLIT_K * stride_ak\n B += BLOCK_K * SPLIT_K * stride_bk\n acc = acc.to(C.dtype.element_ty)\n # rematerialize rm and rn to save registers\n rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)\n rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)\n C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)\n mask = (rm < M)[:, None] & (rn < N)[None, :]\n if add_bias:\n offs_bias = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)\n bias_ptrs = bias + offs_bias\n _BIAS0 = tl.zeros([1], dtype=bias.dtype.element_ty)\n bias_vals = tl.load(bias_ptrs, mask=offs_bias < N, other=_BIAS0)\n if pid_z == 0:\n acc += bias_vals[None, :]\n # Handles write-back with reduction-splitting.\n if SPLIT_K == 1:\n tl.store(C, acc, mask=mask)\n else:\n tl.atomic_add(C, acc, mask=mask)\n\n\ndef triton_gemm_a16w8_forward(out, act, quant_w, scale_w, bias=None, zero_points=None):\n assert quant_w.dtype == torch.int8, \"Weight must be int8 type\"\n assert act.is_contiguous(), \"Activation must be contiguous\"\n assert quant_w.is_contiguous(), \"Weight must be contiguous\"\n assert act.shape[1] == quant_w.shape[1], \"Matrix B must be transposed\"\n\n scale_w = scale_w.squeeze()\n\n M, K = act.shape\n N, K = quant_w.shape\n\n add_bias = True if bias is not None else False\n add_zero_points = True if zero_points is not None else False\n is_perchannel = scale_w.dim() == 1\n\n def grid(META):\n return (\n triton.cdiv(M, META[\"BLOCK_M\"]) * triton.cdiv(N, META[\"BLOCK_N\"]),\n META[\"SPLIT_K\"],\n )\n\n kwargs = {\n \"A\": act,\n \"B\": quant_w,\n \"C\": out,\n \"scale_b\": scale_w,\n \"bias\": bias,\n \"zero_points\": zero_points,\n \"M\": M,\n \"N\": N,\n \"K\": K,\n \"stride_am\": act.stride(0),\n \"stride_ak\": act.stride(1),\n \"stride_bn\": quant_w.stride(0),\n \"stride_bk\": quant_w.stride(1),\n \"stride_cm\": out.stride(0),\n \"stride_cn\": out.stride(1),\n \"stride_zpk\": zero_points.stride(0) if add_zero_points else 0,\n \"stride_zpn\": zero_points.stride(1)\n if add_zero_points and not is_perchannel\n else 0,\n \"stride_scalek\": 0 if is_perchannel else scale_w.stride(0),\n \"stride_scalen\": 0 if is_perchannel else scale_w.stride(1),\n \"add_bias\": add_bias,\n \"add_zero_points\": add_zero_points,\n }\n # per channel a16w8\n if scale_w.dim() == 1:\n triton_gemm_a16w8_per_channel = triton.autotune(\n configs=_get_autotune_configs(is_perchannel=True),\n key=[\"M\", \"N\", \"K\"],\n )(_triton_gemm_a16w8_per_channel_kernel)\n triton_gemm_a16w8_per_channel[grid](**kwargs)\n # sub channel a16w8\n else:\n k_per_scale = int(act.shape[1] / scale_w.shape[0])\n assert k_per_scale > 0, \"k_per_scale should greater than 0\"\n triton_gemm_a16w8_sub_channel = triton.autotune(\n configs=_get_autotune_configs(is_perchannel=False),\n key=[\"M\", \"N\", \"K\"],\n )(_triton_gemm_a16w8_sub_channel_kernel)\n triton_gemm_a16w8_sub_channel[grid](BLOCK_K=k_per_scale, **kwargs)\n\n return out\n", - "description_1": "Use triton language to create a GEMM kernel with 21 parameters for the per-channel case and 21 for the sub-channel case. Implement a function 'triton_gemm_a16w8_forward' with 6 parameters, using the kernels for matrix multiplication with or without bias and zero points.", - "description_2": "Use triton language to create two GEMM kernels and a forward function for matrix multiplication.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.autotune(configs=_get_autotune_configs(), key=[\"M\", \"N\", \"K\"])\n@triton.heuristics({\n 'EVEN_K': lambda args: args['K'] % (args['BLOCK_K']) == 0,\n})\n@triton.jit\ndef _triton_gemm_a8w8_kernel(\n A, B, C, alpha_row_ptr, alpha_col_ptr, M, N, K,\n stride_am, stride_ak, stride_bn, stride_bk, stride_cm, stride_cn,\n BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,\n GROUP_SIZE_M: tl.constexpr, EVEN_K: tl.constexpr,\n):\n \"\"\"Kernel for computing the matmul\n out <- ((int8)A[m, k] * (int8)B[n, k]) *\n ((fp16)scale_row[m, 1] * (fp16)scale_col[1, n])\n A has shape (M, K), B has shape (N, K) and C has shape (M, N)\n \"\"\"\n pid = tl.program_id(axis=0)\n num_pid_m = tl.cdiv(M, BLOCK_M)\n num_pid_n = tl.cdiv(N, BLOCK_N)\n num_pid_in_group = GROUP_SIZE_M * num_pid_n\n group_id = pid // num_pid_in_group\n first_pid_m = group_id * GROUP_SIZE_M\n group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)\n pid_m = first_pid_m + (pid % group_size_m)\n pid_n = (pid % num_pid_in_group) // group_size_m\n\n rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)\n rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)\n ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)\n rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)\n rk = tl.arange(0, BLOCK_K)\n a_ptrs = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)\n b_ptrs = B + (rbn[None, :] * stride_bn + rk[:, None] * stride_bk)\n\n acc_type = tl.int32 if A.dtype.element_ty == tl.int8 else tl.float32\n accumulator = tl.zeros([BLOCK_M, BLOCK_N], dtype=acc_type)\n loop_k = tl.cdiv(K, BLOCK_K)\n if not EVEN_K:\n loop_k -= 1\n\n for _ in range(0, loop_k):\n a = tl.load(a_ptrs)\n b = tl.load(b_ptrs)\n accumulator += tl.dot(a, b)\n a_ptrs += BLOCK_K * stride_ak\n b_ptrs += BLOCK_K * stride_bk\n\n if not EVEN_K:\n k = loop_k\n offs_k = k * BLOCK_K + tl.arange(0, BLOCK_K)\n a_ptrs = A + (ram[:, None] * stride_am + offs_k[None, :] * stride_ak)\n b_ptrs = B + (rbn[None, :] * stride_bn + offs_k[:, None] * stride_bk)\n a = tl.load(a_ptrs, mask=offs_k[None, :] < K, other=0.)\n b = tl.load(b_ptrs, mask=offs_k[:, None] < K, other=0.)\n accumulator += tl.dot(a, b)\n\n offs_cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)\n offs_cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)\n alpha_row_ptrs = alpha_row_ptr + offs_cm\n alpha_col_ptrs = alpha_col_ptr + offs_cn\n alpha_row = tl.load(alpha_row_ptrs, mask=offs_cm < M, other=0.).to(tl.float32)\n alpha_col = tl.load(alpha_col_ptrs, mask=offs_cn < N, other=0.).to(tl.float32)\n accumulator = accumulator * alpha_row[:, None]\n accumulator = accumulator * alpha_col[None, :]\n c = accumulator.to(C.dtype.element_ty)\n\n c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)\n c_ptrs = C + stride_cm * offs_cm[:, None] + offs_cn[None, :]\n tl.store(c_ptrs, c, mask=c_mask)\n\ndef triton_gemm_a8w8_forward(out, a, b, alpha_row, alpha_col):\n assert (\n a.dtype == torch.int8 and b.dtype == torch.int8\n ), \"Matrix A/B must be int8 type\"\n assert a.is_contiguous(), \"Matrix A must be contiguous\"\n assert (\n out.dtype == torch.float16 or out.dtype == torch.bfloat16\n ), \"Output type must be float16 or bfloat16\"\n assert (\n out.dtype == alpha_row.dtype and out.dtype == alpha_col.dtype\n ), \"Output type must match scale type\"\n assert a.shape[1] == b.shape[1], \"Matrix B must be transposed\"\n M, K = a.shape\n N, K = b.shape\n\n method_name = \"gemm_a8w8_\" + str(M) + \"_\" + str(N) + \"_\" + str(K)\n kwargs = [\n a,\n b,\n out,\n torch.squeeze(alpha_row),\n torch.squeeze(alpha_col),\n M,\n N,\n K,\n a.stride(0),\n a.stride(1),\n b.stride(0),\n b.stride(1),\n out.stride(0),\n out.stride(1),\n ]\n\n def grid(META):\n return (triton.cdiv(M, META[\"BLOCK_M\"]) * triton.cdiv(N, META[\"BLOCK_N\"]), 1, 1)\n\n _triton_gemm_a8w8_kernel[grid](*kwargs)\n", - "description_1": "Use triton language to implement a matrix multiplication kernel (_triton_gemm_a8w8_kernel) with 19 parameters: A, B, C (pointers to matrices), alpha_row_ptr, alpha_col_ptr (pointers to scaling factors), M, N, K (matrix dimensions), stride_am, stride_ak, stride_bn, stride_bk, stride_cm, stride_cn (strides for matrices), BLOCK_M, BLOCK_N, BLOCK_K, GROUP_SIZE_M, EVEN_K (meta-parameters). The kernel computes the product of int8 matrices A and B, scales the result with fp16 scaling factors, and stores the result in matrix C. The function triton_gemm_a8w8_forward calls this kernel with 5 parameters: out, a, b, alpha_row, alpha_col, ensuring input matrices are of correct types and dimensions.", - "description_2": "Use triton language to create a matrix multiplication kernel for int8 matrices with scaling, and a function to call this kernel ensuring input constraints.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef _layer_norm_kernel(\n X, # pointer to the input\n Y, # pointer to the output\n W, # pointer to the weight\n B, # pointer to the bias\n Mean, # pointer to the mean\n Rstd, # pointer to the 1/std\n stride, # how much to increase the pointer when moving by 1 row\n N, # number of columns in X\n eps, # epsilon to avoid division by zero\n BLOCK_SIZE: tl.constexpr,\n):\n row = tl.program_id(0)\n Y += row * stride\n X += row * stride\n\n _mean = tl.zeros([BLOCK_SIZE], dtype=tl.float32)\n for off in range(0, N, BLOCK_SIZE):\n cols = off + tl.arange(0, BLOCK_SIZE)\n a = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)\n _mean += a\n mean = tl.sum(_mean, axis=0) / N\n\n _var = tl.zeros([BLOCK_SIZE], dtype=tl.float32)\n for off in range(0, N, BLOCK_SIZE):\n cols = off + tl.arange(0, BLOCK_SIZE)\n x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)\n x = tl.where(cols < N, x - mean, 0.0)\n _var += x * x\n var = tl.sum(_var, axis=0) / N\n rstd = 1 / tl.sqrt(var + eps)\n\n tl.store(Mean + row, mean)\n tl.store(Rstd + row, rstd)\n\n for off in range(0, N, BLOCK_SIZE):\n cols = off + tl.arange(0, BLOCK_SIZE)\n mask = cols < N\n w = tl.load(W + cols, mask=mask)\n b = tl.load(B + cols, mask=mask)\n x = tl.load(X + cols, mask=mask, other=0.0).to(tl.float32)\n x_hat = (x - mean) * rstd\n y = x_hat * w + b\n tl.store(Y + cols, y, mask=mask)\n\n\ndef triton_layer_norm_forward(x, weight, bias, eps):\n # allocate output\n y = torch.empty_like(x)\n # reshape input data into 2D tensor\n x_arg = x.view(-1, x.shape[-1])\n M, N = x_arg.shape\n # construct mean and rstd\n mean = torch.empty(M, dtype=torch.float32, device=\"cuda\")\n rstd = torch.empty(M, dtype=torch.float32, device=\"cuda\")\n # launch kernel\n method_name = \"layer_norm_\" + str(N)\n kwargs = [x_arg, y, weight, bias, mean, rstd, x_arg.stride(0), N, eps]\n layer_norm = triton.autotune(configs=_get_autotune_configs(), key=[\"N\"])(\n _layer_norm_kernel\n )\n grid = (M, 1, 1)\n layer_norm[(M,)](*kwargs)\n return y\n", - "description_1": "Use triton language to implement a layer normalization kernel. The kernel function '_layer_norm_kernel' takes 10 parameters: X (input pointer), Y (output pointer), W (weight pointer), B (bias pointer), Mean (mean pointer), Rstd (1/std pointer), stride (row stride), N (number of columns), eps (epsilon for numerical stability), and BLOCK_SIZE (block size for computation). The function computes the mean and variance of the input, normalizes it, and applies the weight and bias. The 'triton_layer_norm_forward' function prepares the input, output, and auxiliary data, and launches the kernel with appropriate configurations.", - "description_2": "Use triton language to create a layer normalization operation with a kernel that computes mean and variance, normalizes input, and applies weight and bias. Implement a forward function to set up and execute the kernel.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n\n@triton.jit\ndef _layer_norm_dquant_kernel(\n X, # pointer to the input\n Y, # pointer to the normed output\n W, # pointer to the weight\n B, # pointer to the bias\n out, # pointer to the output\n scale, # pointer to the scale\n stride, # how much to increase the pointer when moving by 1 row\n N, # number of columns in X\n eps, # epsilon to avoid division by zero\n BLOCK_SIZE: tl.constexpr,\n):\n row = tl.program_id(0)\n Y += row * stride\n X += row * stride\n out += row * stride\n\n _mean = tl.zeros([BLOCK_SIZE], dtype=tl.float32)\n for off in range(0, N, BLOCK_SIZE):\n cols = off + tl.arange(0, BLOCK_SIZE)\n a = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)\n _mean += a\n mean = tl.sum(_mean, axis=0) / N\n\n _var = tl.zeros([BLOCK_SIZE], dtype=tl.float32)\n for off in range(0, N, BLOCK_SIZE):\n cols = off + tl.arange(0, BLOCK_SIZE)\n x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)\n x = tl.where(cols < N, x - mean, 0.0)\n _var += x * x\n var = tl.sum(_var, axis=0) / N\n rstd = 1 / tl.sqrt(var + eps)\n\n _max_x = 0.0\n for off in range(0, N, BLOCK_SIZE):\n cols = off + tl.arange(0, BLOCK_SIZE)\n mask = cols < N\n w = tl.load(W + cols, mask=mask)\n b = tl.load(B + cols, mask=mask)\n x = tl.load(X + cols, mask=mask, other=0.0).to(tl.float32)\n _norm = (x - mean) * rstd * w + b\n tl.store(out + cols, _norm, mask=mask)\n _max_x = tl.maximum(_max_x, tl.max(tl.abs(_norm), axis=0))\n scale_x = _max_x / 127.0\n tl.store(scale + row, scale_x)\n\n for off in range(0, N, BLOCK_SIZE):\n cols = off + tl.arange(0, BLOCK_SIZE)\n mask = cols < N\n _norm = tl.load(out + cols, mask=mask, other=0.0)\n _norm = _norm / scale_x + 0.5\n tl.store(Y + cols, _norm.to(tl.int8), mask=mask)\n\n\ndef triton_layer_norm_dquant_forward(x, weight, bias, eps):\n x_arg = x.reshape(-1, x.shape[-1])\n M, N = x_arg.shape\n scale = torch.empty((M,), dtype=x.dtype, device=x.device)\n out = torch.empty_like(x)\n y = torch.empty(x.shape, dtype=torch.int8, device=x.device)\n # launch kernel\n kwargs = [x_arg, y, weight, bias, out, scale, x_arg.stride(0), N, eps]\n layer_norm_dquant = triton.autotune(configs=_get_autotune_configs(), key=[\"N\"])(\n _layer_norm_dquant_kernel\n )\n grid = (M, 1, 1)\n layer_norm_dquant[(M,)](*kwargs)\n\n return out, y, scale\n", - "description_1": "Use triton language to implement a layer normalization kernel that supports dequantization, taking pointers to input, output, weights, biases, and scale. It processes data in blocks of configurable size and applies mean and variance calculations to normalize the input data, scales the normalized data, and stores the quantized output. The kernel is launched from a Python function with reshaped inputs and necessary triton configurations.", - "description_2": "Use triton language to create a configurable block-based layer normalization and dequantization kernel, managing input/output, weights, biases, and scale processing in a Python function.", - "difficulty": 3 - }, - { - "code": "import triton\nimport triton.language as tl\n\n# Triton kernel for processing logits with penalties\n@triton.jit\ndef _triton_logits_processor_kernel(\n scores, # [num_tokens, vocab_size]\n penalty, # [num_tokens]\n input_ids_ptr, # [num_tokens]\n input_ids_length, # [num_tokens]\n num_tokens: tl.constexpr,\n vocab_size: tl.constexpr,\n max_ids_length: tl.constexpr,\n power_2_of_vocab_size: tl.constexpr,\n power_2_of_max_ids_length: tl.constexpr,\n penalty_ty: tl.constexpr,\n):\n token_id = tl.program_id(0)\n penalty_val = tl.load(penalty + token_id)\n if tl.abs(penalty_val - 1.0) > 1e-9:\n input_ids_address = tl.load(input_ids_ptr + token_id).to(\n tl.pointer_type(tl.int64)\n )\n current_input_ids_length = tl.load(input_ids_length + token_id)\n ids_offs = tl.arange(0, power_2_of_max_ids_length)\n ids = tl.load(\n input_ids_address + ids_offs,\n mask=ids_offs < current_input_ids_length,\n other=vocab_size,\n )\n ori_scores = tl.load(\n scores + token_id * vocab_size + ids[None, :],\n mask=ids[None, :] < vocab_size,\n other=0.0,\n )\n tl.debug_barrier()\n if penalty_ty == \"REPETITION\":\n new_scores = tl.where(\n ori_scores <= 0, ori_scores * penalty_val, ori_scores / penalty_val\n )\n elif penalty_ty == \"PRESENCE\":\n new_scores = ori_scores - penalty_val\n tl.store(\n scores + token_id * vocab_size + ids[None, :],\n new_scores,\n mask=ids[None, :] < vocab_size,\n )\n\n# Function to invoke the Triton kernel\ndef triton_logits_processor_forward(\n scores, penalty, input_ids_ptr, input_ids_length, max_ids_length, penalty_ty\n):\n assert penalty_ty in [\"REPETITION\", \"PRESENCE\"]\n num_tokens, vocab_size = scores.shape\n power_2_of_vocab_size = triton.next_power_of_2(vocab_size)\n power_2_of_max_ids_length = triton.next_power_of_2(max_ids_length)\n _triton_logits_processor_kernel[(num_tokens,)](\n scores,\n penalty,\n input_ids_ptr,\n input_ids_length,\n num_tokens,\n vocab_size,\n max_ids_length,\n power_2_of_vocab_size,\n power_2_of_max_ids_length,\n penalty_ty,\n num_warps=8,\n )\n", - "description_1": "Use triton language to implement a kernel that processes logits with penalties. The kernel takes 10 parameters: scores (2D tensor of shape [num_tokens, vocab_size]), penalty (1D tensor of shape [num_tokens]), input_ids_ptr (1D tensor of shape [num_tokens]), input_ids_length (1D tensor of shape [num_tokens]), num_tokens (constexpr), vocab_size (constexpr), max_ids_length (constexpr), power_2_of_vocab_size (constexpr), power_2_of_max_ids_length (constexpr), and penalty_ty (constexpr). The kernel applies a penalty to the scores based on the penalty type ('REPETITION' or 'PRESENCE'). The forward function prepares the parameters and launches the kernel.", - "description_2": "Use triton language to create a kernel for adjusting logits with penalties, and a function to set up and call this kernel. The kernel modifies scores based on penalty values and types, handling multiple tokens and vocab sizes.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.autotune(\n configs=[\n triton.Config({}, num_stages=stages, num_warps=warps)\n for stages in [0, 1, 3, 4]\n for warps in [4, 8, 16]\n ],\n key=[\"QUERY_GROUP_SIZE\", \"HEAD_SIZE\", \"KV_BLOCK_SIZE\"],\n)\n@triton.jit\ndef _paged_attn_w_mma_kernel(\n m_i_ptr, # [num_seqs, NUM_KV_HEADS, max_num_partitions, QUERY_GROUP_SIZE]\n l_i_ptr, # [num_seqs, NUM_KV_HEADS, max_num_partitions, QUERY_GROUP_SIZE]\n out_ptr, # [num_seqs, NUM_KV_HEADS, max_num_partitions, QUERY_GROUP_SIZE, HEAD_SIZE]\n q_ptr, # [num_seqs, NUM_KV_HEADS * QUERY_GROUP_SIZE, HEAD_SIZE]\n k_cache_ptr, # [num_blocks, NUM_KV_HEADS, KV_BLOCK_SIZE, HEAD_SIZE]\n v_cache_ptr, # [num_blocks, NUM_KV_HEADS, KV_BLOCK_SIZE, HEAD_SIZE]\n context_lens_ptr, # [num_seqs]\n block_tables_ptr, # [num_seqs, max_num_blocks_per_seq]\n attn_scale,\n stride_bt0,\n stride_bt1,\n stride_q0,\n stride_q1,\n stride_q2,\n stride_kv0,\n stride_kv1,\n stride_kv2,\n stride_kv3,\n stride_o0,\n stride_o1,\n stride_o2,\n stride_o3,\n stride_o4,\n HEAD_SIZE: tl.constexpr,\n QUERY_GROUP_SIZE: tl.constexpr,\n PADDED_QUERY_GROUP_SIZE: tl.constexpr,\n NUM_KV_HEADS: tl.constexpr,\n KV_BLOCK_SIZE: tl.constexpr,\n PARTITION_SIZE: tl.constexpr,\n):\n seq_idx = tl.program_id(0)\n kv_head_idx = tl.program_id(1)\n part_idx = tl.program_id(2)\n max_num_partitions = tl.num_programs(2)\n\n log2e: tl.constexpr = 1.4426950408889634\n\n USE_PARTITIONING = PARTITION_SIZE > 0\n context_len = tl.load(context_lens_ptr + seq_idx)\n if USE_PARTITIONING:\n context_start_idx = part_idx * PARTITION_SIZE\n if context_start_idx >= context_len:\n return\n context_end_idx = tl.minimum(context_start_idx + PARTITION_SIZE, context_len)\n num_blocks = tl.cdiv(context_end_idx - context_start_idx, KV_BLOCK_SIZE)\n else:\n num_blocks = tl.cdiv(context_len, KV_BLOCK_SIZE)\n\n block_offset = tl.arange(0, KV_BLOCK_SIZE)\n head_offset = tl.arange(0, HEAD_SIZE)\n padding_group_offset = tl.arange(0, PADDED_QUERY_GROUP_SIZE)\n\n kv_offset = (\n kv_head_idx * stride_kv1\n + block_offset[:, None] * stride_kv2\n + head_offset[None, :] * stride_kv3\n )\n\n q_offset = (\n seq_idx * stride_q0\n + (kv_head_idx * QUERY_GROUP_SIZE + padding_group_offset[:, None]) * stride_q1\n + head_offset[None, :] * stride_q2\n )\n group_mask = padding_group_offset[:, None] < QUERY_GROUP_SIZE\n q = tl.load(q_ptr + q_offset, mask=group_mask, other=0.0)\n q = (q * attn_scale).to(q_ptr.dtype.element_ty)\n\n m_i = tl.zeros([PADDED_QUERY_GROUP_SIZE], dtype=tl.float32) - float(\"inf\")\n l_i = tl.zeros([PADDED_QUERY_GROUP_SIZE], dtype=tl.float32)\n acc = tl.zeros([PADDED_QUERY_GROUP_SIZE, HEAD_SIZE], dtype=tl.float32)\n\n num_prev_blocks = part_idx * (PARTITION_SIZE // KV_BLOCK_SIZE)\n for i in range(num_blocks):\n block_idx = num_prev_blocks + i\n block_number = tl.load(\n block_tables_ptr + seq_idx * stride_bt0 + block_idx * stride_bt1\n )\n\n kv_block_offset = block_number * stride_kv0 + kv_offset\n mask_offset = block_idx * KV_BLOCK_SIZE + block_offset\n kv_mask = mask_offset[:, None] < context_len\n\n k = tl.load(k_cache_ptr + kv_block_offset, mask=kv_mask, other=0.0)\n\n if PADDED_QUERY_GROUP_SIZE == 1:\n qk = tl.sum(q[:, None, :] * k[None, :, :], axis=2)\n else:\n qk = tl.dot(q, k.T, out_dtype=tl.float32)\n\n qk = tl.where(mask_offset < context_len, qk, float(\"-inf\"))\n\n m_i_new = tl.maximum(m_i, tl.max(qk, axis=1))\n\n p = tl.math.exp2((qk - m_i_new[:, None]) * log2e)\n alpha = tl.math.exp2((m_i - m_i_new) * log2e)\n acc *= alpha[:, None]\n\n v = tl.load(v_cache_ptr + kv_block_offset, mask=kv_mask, other=0.0)\n\n if PADDED_QUERY_GROUP_SIZE == 1:\n acc += tl.sum(p.T[:, :, None] * v[:, None, :], axis=0)\n else:\n p = p.to(v.dtype)\n acc += tl.dot(p, v, out_dtype=tl.float32)\n\n l_i = l_i * alpha + tl.sum(p, axis=1)\n m_i = m_i_new\n acc = acc / l_i[:, None]\n\n if USE_PARTITIONING:\n part_offset = (\n (seq_idx * NUM_KV_HEADS + kv_head_idx)\n * max_num_partitions\n * QUERY_GROUP_SIZE\n + part_idx * QUERY_GROUP_SIZE\n + padding_group_offset\n )\n mask = padding_group_offset < QUERY_GROUP_SIZE\n tl.store(m_i_ptr + part_offset, m_i, mask=mask)\n tl.store(l_i_ptr + part_offset, l_i, mask=mask)\n\n out_offset = seq_idx * stride_o0\n if USE_PARTITIONING:\n out_offset += kv_head_idx * stride_o1\n else:\n out_offset += kv_head_idx * QUERY_GROUP_SIZE * stride_o1\n out_offset += (\n part_idx * stride_o2\n + padding_group_offset[:, None] * stride_o3\n + head_offset[None, :] * stride_o4\n )\n\n group_mask = padding_group_offset[:, None] < QUERY_GROUP_SIZE\n tl.store(out_ptr + out_offset, acc, mask=group_mask)\n\n\n@triton.autotune(\n configs=[triton.Config({\"UNROLL_FACTOR\": uf}) for uf in [1, 2, 4, 8]],\n key=[\n \"POWER_OF_2_MAX_SEQ_LEN\",\n \"QUERY_GROUP_SIZE\",\n \"USE_PARTITIONING\",\n \"BLOCK_SIZE\",\n \"HEAD_SIZE\",\n \"PARTITION_SIZE\",\n ],\n)\n@triton.jit\ndef _paged_attn_wo_mma_kernel(\n exp_sums, # [num_seqs, q_heads, max_num_partitions]\n max_logits, # [num_seqs, q_heads, max_num_partitions]\n out, # [num_seqs, q_heads, max_num_partitions, head_size]\n q, # [num_seqs, q_heads, head_size]\n k_cache, # [num_blocks, kv_heads, block_size, head_size]\n v_cache, # [num_blocks, kv_heads, block_size, head_size]\n scale,\n block_tables, # [num_seqs, max_num_blocks_per_seq]\n seq_lens, # [num_seqs]\n max_num_blocks_per_seq,\n alibi_slopes, # [q_heads]\n stride_qm,\n stride_qn,\n stride_om,\n stride_on,\n stride_ok,\n stride_km,\n stride_kn,\n stride_kk,\n stride_exp_m,\n stride_exp_n,\n BLOCK_SIZE: tl.constexpr,\n HEAD_SIZE: tl.constexpr,\n QUERY_GROUP_SIZE: tl.constexpr,\n PARTITION_SIZE: tl.constexpr,\n POWER_OF_2_MAX_SEQ_LEN: tl.constexpr,\n USE_PARTITIONING: tl.constexpr,\n UNROLL_FACTOR: tl.constexpr,\n):\n head_idx = tl.program_id(axis=0)\n kv_head_idx = head_idx // QUERY_GROUP_SIZE\n seq_idx = tl.program_id(axis=1)\n par_idx = tl.program_id(axis=2)\n seq_len = tl.load(seq_lens + seq_idx)\n\n if par_idx * PARTITION_SIZE >= seq_len:\n return\n\n num_context_blocks = tl.cdiv(seq_len, BLOCK_SIZE)\n if USE_PARTITIONING:\n num_blocks_per_par = PARTITION_SIZE // BLOCK_SIZE\n start_block_idx = par_idx * num_blocks_per_par\n end_block_idx = tl.minimum(\n start_block_idx + num_blocks_per_par, num_context_blocks\n )\n else:\n start_block_idx = 0\n end_block_idx = num_context_blocks\n\n if alibi_slopes is None:\n alibi_slope = 0.0\n else:\n alibi_slope = tl.load(alibi_slopes + head_idx)\n\n block_offs = tl.arange(0, BLOCK_SIZE)\n head_size_offs = tl.arange(0, HEAD_SIZE)\n q = tl.load(q + seq_idx * stride_qm + head_idx * stride_qn + head_size_offs)\n q = (q * scale).to(tl.float16)\n\n qkv = tl.zeros([BLOCK_SIZE, HEAD_SIZE], dtype=tl.float32)\n qk_max = float(\"-inf\")\n exp_sum = 0.0\n fp16_0 = tl.zeros([1, 1], dtype=k_cache.dtype.element_ty)\n base_offs_kv = (\n kv_head_idx * stride_kn\n + block_offs[:, None] * stride_kk\n + head_size_offs[None, :]\n )\n block_base_ptrs = block_tables + seq_idx * max_num_blocks_per_seq\n\n hi_unroll = ((end_block_idx - 1) // UNROLL_FACTOR) * UNROLL_FACTOR\n if UNROLL_FACTOR == 1:\n qkv, qk_max, exp_sum = _inner_paged_attn_unroll_0_kernel(\n q,\n k_cache,\n v_cache,\n stride_km,\n block_base_ptrs,\n base_offs_kv,\n alibi_slope,\n block_offs,\n seq_len,\n qkv,\n qk_max,\n exp_sum,\n BLOCK_SIZE,\n start_block_idx,\n hi_unroll,\n )\n elif UNROLL_FACTOR == 2:\n qkv, qk_max, exp_sum = _inner_paged_attn_unroll_2_kernel(\n q,\n k_cache,\n v_cache,\n stride_km,\n block_base_ptrs,\n base_offs_kv,\n alibi_slope,\n block_offs,\n seq_len,\n qkv,\n qk_max,\n exp_sum,\n BLOCK_SIZE,\n start_block_idx,\n hi_unroll,\n )\n elif UNROLL_FACTOR == 4:\n qkv, qk_max, exp_sum = _inner_paged_attn_unroll_4_kernel(\n q,\n k_cache,\n v_cache,\n stride_km,\n block_base_ptrs,\n base_offs_kv,\n alibi_slope,\n block_offs,\n seq_len,\n qkv,\n qk_max,\n exp_sum,\n BLOCK_SIZE,\n start_block_idx,\n hi_unroll,\n )\n elif UNROLL_FACTOR == 8:\n qkv, qk_max, exp_sum = _inner_paged_attn_unroll_8_kernel(\n q,\n k_cache,\n v_cache,\n stride_km,\n block_base_ptrs,\n base_offs_kv,\n alibi_slope,\n block_offs,\n seq_len,\n qkv,\n qk_max,\n exp_sum,\n BLOCK_SIZE,\n start_block_idx,\n hi_unroll,\n )\n tl.debug_barrier()\n for block_idx in range(hi_unroll, end_block_idx):\n physical_block_idx = tl.load(\n block_tables + seq_idx * max_num_blocks_per_seq + block_idx\n )\n mask = block_offs[:, None] < (seq_len - block_idx * BLOCK_SIZE)\n offs_kv = physical_block_idx * stride_km + base_offs_kv\n\n k = tl.load(k_cache + offs_kv, mask=mask, other=fp16_0)\n v = tl.load(v_cache + offs_kv, mask=mask, other=fp16_0)\n\n _qk = tl.sum((q[None, :] * k).to(tl.float32), axis=1)\n _qk = tl.where(\n block_offs < (seq_len - block_idx * BLOCK_SIZE), _qk, float(\"-inf\")\n )\n _qk += alibi_slope * (block_idx * BLOCK_SIZE + block_offs - seq_len + 1)\n _qk_max = tl.maximum(tl.max(_qk, axis=0), qk_max)\n\n _exp_sum = exp_sum * tl.exp(qk_max - _qk_max) + tl.sum(\n tl.exp(_qk - _qk_max), axis=0\n )\n qkv = (\n qkv * (exp_sum * tl.exp(qk_max - _qk_max))\n + (tl.exp(_qk[:, None] - _qk_max)) * v\n )\n qkv = qkv / _exp_sum\n qk_max = _qk_max\n exp_sum = _exp_sum\n\n if USE_PARTITIONING:\n offs_exp = seq_idx * stride_exp_m + head_idx * stride_exp_n + par_idx\n tl.store(exp_sums + offs_exp, exp_sum)\n tl.store(max_logits + offs_exp, qk_max)\n\n offs_out = (\n seq_idx * stride_om\n + head_idx * stride_on\n + par_idx * stride_ok\n + head_size_offs\n )\n tl.store(out + offs_out, tl.sum(qkv, axis=0))\n", - "description_1": "Use triton language to implement two kernels, `_paged_attn_w_mma_kernel` and `_paged_attn_wo_mma_kernel`. `_paged_attn_w_mma_kernel` processes attention using a tiled approach where inputs include query, key, and value tensors along with additional parameters for dimensions, strides, and constants. It utilizes partitioning to handle large sequences and performs operations in float32 precision to avoid overflow. The second kernel, `_paged_attn_wo_mma_kernel`, is similar but without matrix multiplication accelerator (MMA) optimizations, handling sequences with an alternative approach. These kernels optimize attention computation using triton's multi-threaded parallelism.", - "description_2": "Use triton language to optimize attention computation for transformer models by implementing kernels with and without MMA, employing a tiled and partitioned approach.", - "difficulty": 4 - }, - { - "code": "import triton\nimport triton.language as tl\n\n@triton.jit\ndef _single_query_cached_kv_attention_v1(\n out, # [num_tokens, num_heads, head_size]\n q, # [num_tokens, num_heads, head_size]\n k_cache, # [num_blocks, num_heads, block_size, head_size]\n v_cache, # [num_blocks, num_heads, block_size, head_size]\n head_mapping,\n scale, # float\n block_tables, # [num_tokens, max_num_blocks_per_seq]\n seq_lens,\n max_num_blocks_per_seq,\n stride_qm,\n stride_qn,\n stride_om,\n stride_on,\n stride_km,\n stride_kn,\n stride_kk,\n SLOT_SIZE: tl.constexpr,\n HEAD_SIZE: tl.constexpr,\n):\n head_idx = tl.program_id(axis=0)\n token_idx = tl.program_id(axis=1)\n kv_head_idx = tl.load(head_mapping + head_idx)\n\n offs_q = token_idx * stride_qm + head_idx * stride_qn + tl.arange(0, HEAD_SIZE)\n q = tl.load(q + offs_q)\n q = (q * scale).to(tl.float16)\n seq_len = tl.load(seq_lens + token_idx)\n qkv = tl.zeros([SLOT_SIZE, HEAD_SIZE], dtype=tl.float32)\n m_prev = tl.zeros([1, 1], tl.float32) - float(\"inf\")\n d_prev = tl.zeros([1, 1], tl.float32)\n slot_offs = tl.arange(0, SLOT_SIZE)\n head_size_offs = tl.arange(0, HEAD_SIZE)\n block_base_ptrs = block_tables + token_idx * max_num_blocks_per_seq\n kv_base_offs = (\n kv_head_idx * stride_kn\n + slot_offs[:, None] * stride_kk\n + head_size_offs[None, :]\n )\n for i in range(0, tl.cdiv(seq_len, SLOT_SIZE)):\n block_idx = tl.load(block_base_ptrs + i)\n mask = (slot_offs[:, None] < (seq_len - i * SLOT_SIZE)) & (\n head_size_offs[None, :] < HEAD_SIZE\n )\n kv_offs = block_idx * stride_km + kv_base_offs\n k = tl.load(k_cache + kv_offs, mask=mask, other=0.0)\n v = tl.load(v_cache + kv_offs, mask=mask, other=0.0)\n x_i = tl.sum(q[None, :] * k, axis=1)[:, None]\n x_i = tl.where(\n slot_offs[:, None] < (seq_len - i * SLOT_SIZE), x_i, float(\"-inf\")\n )\n m_i = tl.maximum(m_prev, tl.max(x_i, axis=0))\n d_i = d_prev * tl.exp(m_prev - m_i) + tl.sum(tl.exp(x_i - m_i), axis=0)\n qkv = (\n qkv * (d_prev * tl.exp(m_prev - m_i) / d_i) + (tl.exp(x_i - m_i) / d_i) * v\n )\n m_prev = m_i\n d_prev = d_i\n offs_q = token_idx * stride_om + head_idx * stride_on + tl.arange(0, HEAD_SIZE)\n tl.store(out + offs_q, tl.sum(qkv, axis=0))\n\ndef triton_paged_attention_v1(\n output, # [num_tokens, num_heads, head_size]\n query, # [num_tokens, num_heads, head_size]\n key_cache, # [num_blocks, num_heads, block_size, head_size]\n value_cache, # [num_blocks, num_heads, block_size, head_size]\n head_mapping, # [num_heads]\n scale,\n block_tables, # [num_tokens, max_num_blocks_per_seq]\n context_lens, # [num_tokens]\n):\n num_heads = value_cache.shape[1]\n head_size = value_cache.shape[-1]\n block_size = value_cache.shape[-2]\n num_tokens = query.shape[0]\n\n assert (\n key_cache.is_contiguous() and value_cache.is_contiguous()\n ), \"kv cache must be contiguous\"\n grid = (num_heads, num_tokens, 1)\n _single_query_cached_kv_attention_v1[grid](\n output,\n query,\n key_cache,\n value_cache,\n head_mapping,\n scale,\n block_tables,\n context_lens,\n block_tables.shape[1],\n query.stride(0),\n query.stride(1),\n output.stride(0),\n output.stride(1),\n key_cache.stride(0),\n key_cache.stride(1),\n key_cache.stride(2),\n SLOT_SIZE=block_size,\n HEAD_SIZE=head_size,\n num_warps=triton.cdiv(head_size, 32),\n )\n", - "description_1": "Use triton language to implement a kernel function for single query cached key-value attention, which involves matrix operations and iterative block processing. It takes 16 parameters including input/output tensors, caches, scales, and constexpr sizes, and a wrapping function to manage grid settings and invoke the kernel.", - "description_2": "Use triton language to implement key-value attention kernel handling cached data and a function to configure execution grid and call the kernel.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\nPARTITION_SIZE = 512\n\n@triton.jit\ndef _single_query_cached_kv_attention_v2_unroll4(\n exp_sums, max_logits, out, q, k_cache, v_cache, head_mapping, scale, block_tables, \n seq_lens, partiton_size, max_num_blocks_per_seq, alibi_slopes, stride_qm, stride_qn, \n stride_om, stride_on, stride_ok, stride_km, stride_kn, stride_kk, stride_exp_m, \n stride_exp_n, BLOCK_SIZE: tl.constexpr, HEAD_SIZE: tl.constexpr\n):\n seq_idx = tl.program_id(axis=1)\n par_idx = tl.program_id(axis=2)\n seq_len = tl.load(seq_lens + seq_idx)\n\n if par_idx * partiton_size >= seq_len:\n return\n\n num_context_blocks = tl.cdiv(seq_len, BLOCK_SIZE)\n num_blocks_per_par = partiton_size // BLOCK_SIZE\n\n start_block_idx = par_idx * num_blocks_per_par\n end_block_idx = tl.minimum(start_block_idx + num_blocks_per_par, num_context_blocks)\n\n head_idx = tl.program_id(axis=0)\n kv_head_idx = tl.load(head_mapping + head_idx)\n\n if alibi_slopes is None:\n alibi_slope = 0.0\n else:\n alibi_slope = tl.load(alibi_slopes + head_idx)\n\n block_offs = tl.arange(0, BLOCK_SIZE)\n head_size_offs = tl.arange(0, HEAD_SIZE)\n q = tl.load(q + seq_idx * stride_qm + head_idx * stride_qn + head_size_offs)\n q = (q * scale).to(tl.float16)\n\n qkv = tl.zeros([BLOCK_SIZE, HEAD_SIZE], dtype=tl.float32)\n qk_max = float(\"-inf\")\n exp_sum = 0.0\n fp16_0 = tl.zeros([1, 1], dtype=k_cache.dtype.element_ty)\n base_offs_kv = (\n kv_head_idx * stride_kn\n + block_offs[:, None] * stride_kk\n + head_size_offs[None, :]\n )\n block_base_ptrs = block_tables + seq_idx * max_num_blocks_per_seq\n\n for block_idx in range(start_block_idx, end_block_idx, 4):\n mask_0 = block_offs[:, None] < (seq_len - (block_idx + 0) * BLOCK_SIZE)\n mask_1 = block_offs[:, None] < (seq_len - (block_idx + 1) * BLOCK_SIZE)\n mask_2 = block_offs[:, None] < (seq_len - (block_idx + 2) * BLOCK_SIZE)\n mask_3 = block_offs[:, None] < (seq_len - (block_idx + 3) * BLOCK_SIZE)\n offs_kv_0 = tl.load(block_base_ptrs + block_idx + 0) * stride_km + base_offs_kv\n offs_kv_1 = tl.load(block_base_ptrs + block_idx + 1) * stride_km + base_offs_kv\n offs_kv_2 = tl.load(block_base_ptrs + block_idx + 2) * stride_km + base_offs_kv\n offs_kv_3 = tl.load(block_base_ptrs + block_idx + 3) * stride_km + base_offs_kv\n\n k_0 = tl.load(k_cache + offs_kv_0, mask=mask_0, other=fp16_0)\n k_1 = tl.load(k_cache + offs_kv_1, mask=mask_1, other=fp16_0)\n k_2 = tl.load(k_cache + offs_kv_2, mask=mask_2, other=fp16_0)\n k_3 = tl.load(k_cache + offs_kv_3, mask=mask_3, other=fp16_0)\n\n v_0 = tl.load(v_cache + offs_kv_0, mask=mask_0, other=fp16_0)\n v_1 = tl.load(v_cache + offs_kv_1, mask=mask_1, other=fp16_0)\n v_2 = tl.load(v_cache + offs_kv_2, mask=mask_2, other=fp16_0)\n v_3 = tl.load(v_cache + offs_kv_3, mask=mask_3, other=fp16_0)\n\n _qk_0 = tl.sum((q[None, :] * k_0).to(tl.float32), axis=1)\n _qk_1 = tl.sum((q[None, :] * k_1).to(tl.float32), axis=1)\n _qk_2 = tl.sum((q[None, :] * k_2).to(tl.float32), axis=1)\n _qk_3 = tl.sum((q[None, :] * k_3).to(tl.float32), axis=1)\n\n _qk_0 += alibi_slope * ((block_idx + 0) * BLOCK_SIZE + block_offs - seq_len + 1)\n _qk_1 += alibi_slope * ((block_idx + 1) * BLOCK_SIZE + block_offs - seq_len + 1)\n _qk_2 += alibi_slope * ((block_idx + 2) * BLOCK_SIZE + block_offs - seq_len + 1)\n _qk_3 += alibi_slope * ((block_idx + 3) * BLOCK_SIZE + block_offs - seq_len + 1)\n\n _qk_max = tl.maximum(tl.max(_qk_0, axis=0), qk_max)\n _qk_max = tl.maximum(tl.max(_qk_1, axis=0), _qk_max)\n _qk_max = tl.maximum(tl.max(_qk_2, axis=0), _qk_max)\n _qk_max = tl.maximum(tl.max(_qk_3, axis=0), _qk_max)\n\n qk_0 = tl.where(mask_0, _qk_0[:, None], float(\"-inf\"))\n qk_1 = tl.where(mask_1, _qk_1[:, None], float(\"-inf\"))\n qk_2 = tl.where(mask_2, _qk_2[:, None], float(\"-inf\"))\n qk_3 = tl.where(mask_3, _qk_3[:, None], float(\"-inf\"))\n\n _exp_sum = (\n exp_sum * tl.exp(qk_max - _qk_max)\n + tl.sum(tl.exp(_qk_0 - _qk_max), axis=0)\n + tl.sum(tl.exp(_qk_1 - _qk_max), axis=0)\n + tl.sum(tl.exp(_qk_2 - _qk_max), axis=0)\n + tl.sum(tl.exp(_qk_3 - _qk_max), axis=0)\n )\n qkv = (\n qkv * (exp_sum * tl.exp(qk_max - _qk_max) / _exp_sum)\n + (tl.exp(qk_0 - _qk_max) / _exp_sum) * v_0\n + (tl.exp(qk_1 - _qk_max) / _exp_sum) * v_1\n + (tl.exp(qk_2 - _qk_max) / _exp_sum) * v_2\n + (tl.exp(qk_3 - _qk_max) / _exp_sum) * v_3\n )\n qk_max = _qk_max\n exp_sum = _exp_sum\n\n offs_exp = seq_idx * stride_exp_m + head_idx * stride_exp_n + par_idx\n tl.store(exp_sums + offs_exp, exp_sum)\n tl.store(max_logits + offs_exp, qk_max)\n\n offs_out = (\n seq_idx * stride_om\n + head_idx * stride_on\n + par_idx * stride_ok\n + head_size_offs\n )\n tl.store(out + offs_out, tl.sum(qkv, axis=0))\n\n\n@triton.jit\ndef _paged_attention_v2_reduce(\n out, exp_sums, max_logits, tmp_out, context_lens, stride_exp_m,\n stride_exp_n, stride_out_m, stride_out_n, stride_tmp_m, stride_tmp_n,\n stride_tmp_k, HEAD_SIZE: tl.constexpr, NUM_PARTITIONS: tl.constexpr\n):\n seq_idx = tl.program_id(axis=1)\n head_idx = tl.program_id(axis=0)\n context_len = tl.load(context_lens + seq_idx)\n\n num_partitions = tl.cdiv(context_len, PARTITION_SIZE)\n\n exp_sum = 0.0\n max_logit = float(\"-inf\")\n offs_logit = seq_idx * stride_exp_m + head_idx * stride_exp_n\n\n head_size_offs = tl.arange(0, HEAD_SIZE)\n tmp_out_ptr = seq_idx * stride_tmp_m + head_idx * stride_tmp_n\n out_ptr = seq_idx * stride_out_m + head_idx * stride_out_n + head_size_offs\n\n acc = tl.zeros([HEAD_SIZE], dtype=tl.float32)\n global_exp_sum = tl.zeros([1], dtype=tl.float32)\n\n logits = tl.load(\n max_logits + offs_logit + tl.arange(0, NUM_PARTITIONS),\n mask=tl.arange(0, NUM_PARTITIONS) < num_partitions,\n other=float(\"-inf\"),\n )\n max_logit = tl.max(logits, axis=0)\n\n exp_sum = tl.load(\n exp_sums + offs_logit + tl.arange(0, NUM_PARTITIONS),\n mask=tl.arange(0, NUM_PARTITIONS) < num_partitions,\n other=0.0,\n )\n rescaled_exp_sum = exp_sum * tl.exp(logits - max_logit)\n global_exp_sum += tl.sum(rescaled_exp_sum, axis=0)\n\n tmp = tl.load(\n tmp_out\n + tmp_out_ptr\n + tl.arange(0, NUM_PARTITIONS)[:, None] * stride_tmp_k\n + head_size_offs\n )\n acc += tl.sum(tmp * rescaled_exp_sum[:, None], axis=0)\n\n inv_sum = 1.0 / (global_exp_sum + 1e-6)\n tl.store(out + out_ptr, acc * inv_sum)\n\n\ndef triton_paged_attention_v2(\n out, query, key_cache, value_cache, head_mapping, scale,\n block_tables, context_lens, max_context_len, alibi_slopes=None\n):\n num_heads = value_cache.shape[1]\n head_size = value_cache.shape[-1]\n block_size = value_cache.shape[-2]\n num_seqs = query.shape[0]\n\n max_num_partitions = triton.cdiv(max_context_len, PARTITION_SIZE)\n\n exp_sums = torch.empty(\n (num_seqs, num_heads, max_num_partitions), dtype=torch.float32, device=\"cuda\"\n )\n max_logits = torch.empty(\n (num_seqs, num_heads, max_num_partitions), dtype=torch.float16, device=\"cuda\"\n )\n tmp_out = torch.empty(\n (num_seqs, num_heads, max_num_partitions, head_size),\n dtype=torch.float32,\n device=\"cuda\",\n )\n\n # online softmax with unroll4\n kwargs = [\n exp_sums,\n max_logits,\n tmp_out,\n query,\n key_cache,\n value_cache,\n head_mapping,\n scale,\n block_tables,\n context_lens,\n PARTITION_SIZE,\n block_tables.shape[1],\n alibi_slopes,\n query.stride(0),\n query.stride(1),\n tmp_out.stride(0),\n tmp_out.stride(1),\n tmp_out.stride(2),\n key_cache.stride(0),\n key_cache.stride(1),\n key_cache.stride(2),\n exp_sums.stride(0),\n exp_sums.stride(1),\n ]\n grid = (num_heads, num_seqs, max_num_partitions)\n const_kwargs = {\"BLOCK_SIZE\": block_size, \"HEAD_SIZE\": head_size}\n _single_query_cached_kv_attention_v2_unroll4[grid](*kwargs, **const_kwargs)\n\n # reduction across partitions\n num_partitions = triton.next_power_of_2(max_num_partitions)\n kwargs = [\n out,\n exp_sums,\n max_logits,\n tmp_out,\n context_lens,\n exp_sums.stride(0),\n exp_sums.stride(1),\n out.stride(0),\n out.stride(1),\n tmp_out.stride(0),\n tmp_out.stride(1),\n tmp_out.stride(2),\n ]\n grid = (num_heads, num_seqs, 1)\n const_kwargs = {\n \"HEAD_SIZE\": head_size,\n \"NUM_PARTITIONS\": num_partitions,\n \"num_warps\": triton.cdiv(head_size, 32),\n }\n _paged_attention_v2_reduce[grid](*kwargs, **const_kwargs)\n", - "description_1": "Use triton language to define three functions. `_single_query_cached_kv_attention_v2_unroll4` handles softmax computation for attention with parameters for managing sequences, heads, and blocks. `_paged_attention_v2_reduce` manages the reduction across partitions and computes final attention output. `triton_paged_attention_v2` orchestrates the process by setting up grid configurations and invoking the other functions.", - "description_2": "Use triton language to implement attention mechanisms using efficient block processing with grid configurations for sequences, heads, and blocks.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\ndef _get_autotune_configs():\n configs = [\n triton.Config({\"BLOCK_SIZE\": 64}, num_warps=2),\n triton.Config({\"BLOCK_SIZE\": 128}, num_warps=2),\n triton.Config({\"BLOCK_SIZE\": 128}, num_warps=4),\n triton.Config({\"BLOCK_SIZE\": 256}, num_warps=2),\n triton.Config({\"BLOCK_SIZE\": 256}, num_warps=4),\n triton.Config({\"BLOCK_SIZE\": 256}, num_warps=8),\n triton.Config({\"BLOCK_SIZE\": 512}, num_warps=2),\n triton.Config({\"BLOCK_SIZE\": 512}, num_warps=4),\n triton.Config({\"BLOCK_SIZE\": 512}, num_warps=8),\n triton.Config({\"BLOCK_SIZE\": 512}, num_warps=16),\n triton.Config({\"BLOCK_SIZE\": 1024}, num_warps=2),\n triton.Config({\"BLOCK_SIZE\": 1024}, num_warps=4),\n triton.Config({\"BLOCK_SIZE\": 1024}, num_warps=8),\n triton.Config({\"BLOCK_SIZE\": 1024}, num_warps=16),\n ]\n return configs\n\n@triton.jit\ndef _rms_norm_kernel(\n X, # pointer to the input\n Y, # pointer to the output\n W, # pointer to the weights\n stride, # how much to increase the pointer when moving by 1 row\n N, # number of columns in X\n eps, # epsilon to avoid division by zero\n BLOCK_SIZE: tl.constexpr,\n):\n row = tl.program_id(0)\n Y += row * stride\n X += row * stride\n\n _var = tl.zeros([BLOCK_SIZE], dtype=tl.float32)\n for off in range(0, N, BLOCK_SIZE):\n cols = off + tl.arange(0, BLOCK_SIZE)\n x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)\n _var += x * x\n var = tl.sum(_var, axis=0) / N\n rstd = 1 / tl.sqrt(var + eps)\n\n for off in range(0, N, BLOCK_SIZE):\n cols = off + tl.arange(0, BLOCK_SIZE)\n mask = cols < N\n w = tl.load(W + cols, mask=mask)\n x = tl.load(X + cols, mask=mask, other=0.0).to(tl.float32)\n x_hat = x * rstd\n y = x_hat * w\n tl.store(Y + cols, y, mask=mask)\n\ndef triton_rmsnorm_forward(x, weight, eps):\n # allocate output\n y = torch.empty_like(x)\n # reshape input data into 2D tensor\n x_arg = x.view(-1, x.shape[-1])\n M, N = x_arg.shape\n\n kwargs = [x_arg, y, weight, x_arg.stride(0), N, eps]\n rms_norm = triton.autotune(configs=_get_autotune_configs(), key=[\"N\"])(\n _rms_norm_kernel\n )\n grid = (M, 1, 1)\n rms_norm[(M,)](*kwargs)\n\n return y\n", - "description_1": "Use triton language to implement a root mean square normalization kernel. The kernel function '_rms_norm_kernel' takes 7 parameters: X (input pointer), Y (output pointer), W (weights pointer), stride (row stride), N (number of columns), eps (epsilon for numerical stability), and BLOCK_SIZE (block size for computation). The function computes the variance and inverse square root of the variance, then normalizes and scales the input data. The 'triton_rmsnorm_forward' function prepares the input data, sets up the kernel execution grid, and calls the kernel with the appropriate parameters.", - "description_2": "Use triton language to create a kernel for RMS normalization with configurable block size and warps. Implement a forward function to prepare data and execute the kernel.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef _rms_norm_dquant_kernel(\n X, # pointer to the input\n Y, # pointer to the output\n W, # pointer to the weights\n scale, # pointer to the output scale\n stride, # how much to increase the pointer when moving by 1 row\n N, # number of columns in X\n eps, # epsilon to avoid division by zero\n BLOCK_SIZE: tl.constexpr, # block size\n):\n row = tl.program_id(0)\n Y += row * stride\n X += row * stride\n\n _var = tl.zeros([BLOCK_SIZE], dtype=tl.float32)\n for off in range(0, N, BLOCK_SIZE):\n cols = off + tl.arange(0, BLOCK_SIZE)\n x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)\n _var += x * x\n var = tl.sum(_var, axis=0) / N\n rstd = 1 / tl.sqrt(var + eps)\n\n _max_x = 0.0\n for off in range(0, N, BLOCK_SIZE):\n cols = off + tl.arange(0, BLOCK_SIZE)\n mask = cols < N\n x = tl.load(X + cols, mask=mask, other=0.0).to(tl.float32)\n w = tl.load(W + cols, mask=mask)\n norm = x * rstd * w\n _max_x = tl.maximum(_max_x, tl.max(tl.abs(norm), axis=0))\n scale_x = _max_x / 127.0\n tl.store(scale + row, scale_x)\n\n for off in range(0, N, BLOCK_SIZE):\n cols = off + tl.arange(0, BLOCK_SIZE)\n mask = cols < N\n x = tl.load(X + cols, mask=mask, other=0.0).to(tl.float32)\n w = tl.load(W + cols, mask=mask)\n norm = x * rstd * w\n norm = norm / scale_x\n # rounding to nearest even\n norm = tl.where(norm > 0, norm + 0.5, norm - 0.5)\n tl.store(Y + cols, norm.to(tl.int8), mask=mask)\n\ndef triton_rmsnorm_dquant_forward(x, weight, eps):\n # allocate output\n y = torch.empty(x.shape, dtype=torch.int8, device=x.device)\n # reshape input data into 2D tensor\n x_arg = x.view(-1, x.shape[-1])\n M, N = x_arg.shape\n scale = torch.empty((M,), dtype=x.dtype, device=x.device)\n # enqueue kernel\n kwargs = [x_arg, y, weight, scale, x_arg.stride(0), N, eps]\n grid = (M, 1, 1)\n rmsnorm_dquant = triton.autotune(configs=_get_autotune_configs(), key=[\"N\"])(\n _rms_norm_dquant_kernel\n )\n rmsnorm_dquant[grid](*kwargs)\n\n scale = scale.reshape(x.shape[:-1])\n return y, scale\n", - "description_1": "Use triton language to implement a kernel function '_rms_norm_dquant_kernel' with 8 parameters: X (input pointer), Y (output pointer), W (weights pointer), scale (output scale pointer), stride (row stride), N (number of columns), eps (epsilon for numerical stability), and BLOCK_SIZE (block size). The kernel normalizes input data, computes a scale factor, and stores quantized results. The function 'triton_rmsnorm_dquant_forward' calls this kernel with 3 parameters: x (input tensor), weight (weights tensor), and eps (epsilon), preparing data and managing kernel execution.", - "description_2": "Use triton language to create a kernel for RMS normalization and quantization, and a function to execute this kernel with input tensors and parameters.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef _rotary_embedding_kernel(\n q_rot_ptr,\n k_rot_ptr,\n q_ptr,\n k_ptr,\n cos_ptr,\n sin_ptr,\n seq_len,\n batch_size,\n num_heads,\n num_kv,\n hidden_size,\n q_strides,\n q_strideb,\n q_strideh,\n q_strided,\n k_strides,\n k_strideb,\n k_stridekv,\n k_strided,\n seq_offset,\n BLOCK_SIZE_SEQ: tl.constexpr,\n BLOCK_SIZE_BH: tl.constexpr,\n BLOCK_SIZE_D: tl.constexpr,\n):\n pid = tl.program_id(axis=0)\n num_bh_blocks = tl.cdiv(batch_size * num_heads, BLOCK_SIZE_BH)\n num_d_blocks = tl.cdiv(hidden_size // 2, BLOCK_SIZE_D)\n bh_id = pid % num_bh_blocks\n d_id = pid // num_bh_blocks % num_d_blocks\n seq_block_id = pid // num_bh_blocks // num_d_blocks\n\n seq_offs = seq_offset + seq_block_id * BLOCK_SIZE_SEQ + tl.arange(0, BLOCK_SIZE_SEQ)\n\n bh_offs = bh_id * BLOCK_SIZE_BH + tl.arange(0, BLOCK_SIZE_BH)\n q_common_offs = (\n seq_offs[:, None, None] * q_strides + bh_offs[None, :, None] * q_strideh\n )\n k_common_offs = (\n seq_offs[:, None, None] * k_strides\n + bh_offs[None, :, None] // (num_heads // num_kv) * k_stridekv\n )\n q_base_offs, qo_base_offs = q_ptr + q_common_offs, q_rot_ptr + q_common_offs\n k_base_offs, ko_base_offs = k_ptr + k_common_offs, k_rot_ptr + k_common_offs\n c_base_offs = cos_ptr + seq_offs[:, None] * hidden_size\n s_base_offs = sin_ptr + seq_offs[:, None] * hidden_size\n\n hidden_block_range = tl.arange(0, BLOCK_SIZE_D)\n\n hidden_offs_l = d_id * BLOCK_SIZE_D + hidden_block_range\n hidden_offs_r = hidden_size // 2 + hidden_offs_l\n mask_l, mask_r = hidden_offs_l < hidden_size // 2, hidden_offs_r < hidden_size\n mask_bh = bh_offs < batch_size * num_heads\n mask_seq = seq_offs < seq_len\n mask_bh_seq = mask_bh[None, :, None] & mask_seq[:, None, None]\n\n q_l, k_l = tl.load(\n q_base_offs + hidden_offs_l[None, None, :] * q_strided,\n mask=mask_l[None, None, :] & mask_bh_seq,\n other=0,\n ), tl.load(\n k_base_offs + hidden_offs_l[None, None, :] * k_strided,\n mask=mask_l[None, None, :] & mask_bh_seq,\n other=0,\n )\n q_r, k_r = tl.load(\n q_base_offs + hidden_offs_r[None, None, :] * q_strided,\n mask=mask_r[None, None, :] & mask_bh_seq,\n other=0,\n ), tl.load(\n k_base_offs + hidden_offs_r[None, None, :] * k_strided,\n mask=mask_r[None, None, :] & mask_bh_seq,\n other=0,\n )\n cos_l, cos_r = (\n tl.load(c_base_offs + hidden_offs_l[None, :], mask=mask_l[None, :], other=0)[\n :, None, :\n ],\n tl.load(c_base_offs + hidden_offs_r[None, :], mask=mask_r[None, :], other=0)[\n :, None, :\n ],\n )\n sin_l, sin_r = (\n tl.load(s_base_offs + hidden_offs_l[None, :], mask=mask_l[None, :], other=0)[\n :, None, :\n ],\n tl.load(s_base_offs + hidden_offs_r[None, :], mask=mask_r[None, :], other=0)[\n :, None, :\n ],\n )\n\n qo_l = q_l * cos_l - q_r * sin_l\n tl.store(\n qo_base_offs + hidden_offs_l, qo_l, mask=mask_l[None, None, :] & mask_bh_seq\n )\n qo_r = q_r * cos_r + q_l * sin_r\n tl.store(\n qo_base_offs + hidden_offs_r, qo_r, mask=mask_r[None, None, :] & mask_bh_seq\n )\n ko_l = k_l * cos_l - k_r * sin_l\n tl.store(\n ko_base_offs + hidden_offs_l, ko_l, mask=mask_l[None, None, :] & mask_bh_seq\n )\n ko_r = k_r * cos_r + k_l * sin_r\n tl.store(\n ko_base_offs + hidden_offs_r, ko_r, mask=mask_r[None, None, :] & mask_bh_seq\n )\n\ndef triton_rotary_embd_forward(\n q, k, cos_ptr, sin_ptr, offset=0, max_seq_len=None, seq_dim=0\n):\n if max_seq_len is None:\n max_seq_len = k.shape[seq_dim]\n max_seq_len += offset\n query_rot = torch.empty_like(q)\n key_rot = torch.empty_like(k)\n _, B, H, D = q.shape\n _, _, nKV, _ = k.shape\n\n kwargs = [\n query_rot,\n key_rot,\n q,\n k,\n cos_ptr,\n sin_ptr,\n max_seq_len,\n B,\n H,\n nKV,\n D,\n q.stride(0),\n q.stride(1),\n q.stride(2),\n q.stride(3),\n k.stride(0),\n k.stride(1),\n k.stride(2),\n k.stride(3),\n offset,\n ]\n\n def grid(META):\n return (\n max(\n 1,\n (\n triton.cdiv(\n max_seq_len * B * H,\n META[\"BLOCK_SIZE_SEQ\"] * META[\"BLOCK_SIZE_BH\"],\n )\n * triton.cdiv(D // 2, META[\"BLOCK_SIZE_D\"])\n ),\n ),\n 1,\n 1,\n )\n\n rotary_embedding = triton.autotune(\n configs=_get_autotune_configs(),\n key=[\"seq_len\", \"batch_size\", \"num_heads\", \"num_kv\", \"hidden_size\"],\n )(_rotary_embedding_kernel)\n\n rotary_embedding[grid](*kwargs)\n return query_rot, key_rot\n", - "description_1": "Use triton language to implement a rotary embedding kernel with 22 parameters: q_rot_ptr, k_rot_ptr, q_ptr, k_ptr, cos_ptr, sin_ptr, seq_len, batch_size, num_heads, num_kv, hidden_size, q_strides, q_strideb, q_strideh, q_strided, k_strides, k_strideb, k_stridekv, k_strided, seq_offset, BLOCK_SIZE_SEQ, BLOCK_SIZE_BH, BLOCK_SIZE_D. The kernel performs rotary embedding on input queries and keys using cosine and sine values, storing the results in q_rot_ptr and k_rot_ptr.", - "description_2": "Use triton language to create a function triton_rotary_embd_forward with 7 parameters: q, k, cos_ptr, sin_ptr, offset, max_seq_len, seq_dim. This function prepares data and calls the rotary embedding kernel to compute the rotary embeddings for the input queries and keys.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.autotune(\n configs=[\n triton.Config({\n 'BLOCK_SIZE_M': 32,\n 'BLOCK_SIZE_N': 64,\n 'BLOCK_SIZE_K': 32,\n 'NUM_SM': 128,\n }, num_warps=2, num_stages=5),\n ],\n key=['group_size'],\n)\n@triton.jit\ndef grouped_matmul_kernel(\n fused_input_ptr,\n cum_input_group_range,\n fused_b_ptr,\n fused_output_ptr,\n group_size,\n n,\n k,\n lda,\n ldb,\n ldc,\n NUM_SM: tl.constexpr,\n BLOCK_SIZE_M: tl.constexpr,\n BLOCK_SIZE_N: tl.constexpr,\n BLOCK_SIZE_K: tl.constexpr,\n ACTIVATION: tl.constexpr,\n):\n tile_idx = tl.program_id(0)\n last_problem_end = 0\n for g in range(group_size):\n a_offset = tl.load(cum_input_group_range + g)\n gm = tl.load(cum_input_group_range + g + 1) - a_offset\n gn = n\n gk = k\n num_m_tiles = tl.cdiv(gm, BLOCK_SIZE_M)\n num_n_tiles = tl.cdiv(gn, BLOCK_SIZE_N)\n num_tiles = num_m_tiles * num_n_tiles\n while (tile_idx >= last_problem_end\n and tile_idx < last_problem_end + num_tiles):\n\n k = gk\n a_ptr = fused_input_ptr + a_offset * lda\n b_ptr = fused_b_ptr + g * k * n\n c_ptr = fused_output_ptr + a_offset * ldc\n tile_idx_in_gemm = tile_idx - last_problem_end\n tile_m_idx = tile_idx_in_gemm // num_n_tiles\n tile_n_idx = tile_idx_in_gemm % num_n_tiles\n\n offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n offs_k = tl.arange(0, BLOCK_SIZE_K)\n a_ptrs = a_ptr + offs_am[:, None] * lda + offs_k[None, :]\n b_ptrs = b_ptr + offs_k[:, None] * ldb + offs_bn[None, :]\n accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N),\n dtype=tl.float32)\n for kk in range(0, tl.cdiv(k, BLOCK_SIZE_K)):\n tl.multiple_of(a_ptrs, [16, 16])\n tl.multiple_of(b_ptrs, [16, 16])\n\n a = tl.load(a_ptrs,\n mask=offs_k[None, :] < k - kk * BLOCK_SIZE_K,\n other=0.0)\n b = tl.load(b_ptrs,\n mask=offs_k[:, None] < k - kk * BLOCK_SIZE_K,\n other=0.0)\n accumulator += tl.dot(a, b)\n a_ptrs += BLOCK_SIZE_K\n b_ptrs += BLOCK_SIZE_K * ldb\n\n if ACTIVATION == \"silu\":\n accumulator = silu(accumulator)\n c = accumulator.to(tl.float16)\n\n offs_cm = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_cn = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n c_ptrs = c_ptr + ldc * offs_cm[:, None] + offs_cn[None, :]\n c_mask = (offs_cm[:, None] < gm) & (offs_cn[None, :] < gn)\n\n tl.store(c_ptrs, c, mask=c_mask)\n\n tile_idx += NUM_SM\n\n last_problem_end = last_problem_end + num_tiles\n\n\n@triton.jit\ndef silu(x):\n return x * tl.sigmoid(x)\n\n\ndef grouped_matmul(fused_input: torch.Tensor,\n cum_group_range: torch.Tensor,\n fused_group_b: torch.Tensor,\n activation: str = \"\"):\n device = torch.device('cuda')\n assert cum_group_range.shape[0] == fused_group_b.shape[0] + 1\n group_size = cum_group_range.shape[0] - 1\n output = torch.zeros(fused_input.shape[0],\n fused_group_b.shape[2],\n device=device,\n dtype=fused_input.dtype)\n\n grid = lambda META: (META['NUM_SM'], )\n grouped_matmul_kernel[grid](\n fused_input,\n cum_group_range,\n fused_group_b,\n output,\n group_size,\n n=fused_group_b.shape[2],\n k=fused_group_b.shape[1],\n lda=fused_input.stride(0),\n ldb=fused_group_b.stride(1),\n ldc=output.stride(0),\n ACTIVATION=activation,\n )\n\n return output\n", - "description_1": "Use triton language to implement a grouped matrix multiplication kernel. The kernel has parameters for device tensors, tile sizes, and activation function. It uses a loop to process matrix tiles, performing matrix multiplications and applying optional activation (silu). The kernel writes the results back to the output tensor. Additionally, the grouped_matmul function, which calls the kernel, manages tensor dimensions, allocations, and invocation of the kernel with required parameters.", - "description_2": "Use triton language to implement a grouped matrix multiplication kernel with configurable tile sizes and activation function. The kernel is called from a wrapper function managing tensor allocations and dimensions.", - "difficulty": 4 - }, - { - "code": "import math\nimport torch\nimport triton_pre_mlir as triton\nimport triton_pre_mlir.language as tl\n\n\n@triton.heuristics(\n {\n \"EVEN_M\": lambda args: args[\"seqlen_q\"] % args[\"BLOCK_M\"] == 0,\n \"EVEN_N\": lambda args: args[\"seqlen_k\"] % args[\"BLOCK_N\"] == 0,\n \"EVEN_HEADDIM\": lambda args: args[\"headdim\"] == args[\"BLOCK_HEADDIM\"],\n }\n)\n@triton.jit\ndef _fwd_kernel(\n Q, K, V, Bias, Out, Lse, TMP, softmax_scale, stride_qb, stride_qh, stride_qm,\n stride_kb, stride_kh, stride_kn, stride_vb, stride_vh, stride_vn, stride_bb,\n stride_bh, stride_bm, stride_ob, stride_oh, stride_om, nheads, seqlen_q, seqlen_k,\n seqlen_q_rounded, headdim, CACHE_KEY_SEQLEN_Q, CACHE_KEY_SEQLEN_K,\n BIAS_TYPE: tl.constexpr, IS_CAUSAL: tl.constexpr, BLOCK_HEADDIM: tl.constexpr,\n EVEN_M: tl.constexpr, EVEN_N: tl.constexpr, EVEN_HEADDIM: tl.constexpr,\n BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,\n):\n # Compute QK^T and apply softmax\n # Detailed implementation omitted for brevity\n pass\n\n\n@triton.jit\ndef _bwd_preprocess_do_o_dot(\n Out, DO, Delta, stride_ob, stride_oh, stride_om, stride_dob, stride_doh, stride_dom,\n nheads, seqlen_q, seqlen_q_rounded, headdim, BLOCK_M: tl.constexpr,\n BLOCK_HEADDIM: tl.constexpr,\n):\n # Compute O*DO^T\n # Detailed implementation omitted for brevity\n pass\n\n\n@triton.jit\ndef _bwd_store_dk_dv(\n dk_ptrs, dv_ptrs, dk, dv, offs_n, offs_d, seqlen_k, headdim,\n EVEN_M: tl.constexpr, EVEN_N: tl.constexpr, EVEN_HEADDIM: tl.constexpr,\n):\n # Store gradients dK and dV\n # Detailed implementation omitted for brevity\n pass\n\n\n@triton.jit\ndef _bwd_kernel_one_col_block(\n start_n, Q, K, V, Bias, DO, DQ, DK, DV, LSE, D, softmax_scale, stride_qm,\n stride_kn, stride_vn, stride_bm, stride_dom, stride_dqm, stride_dkn, stride_dvn,\n seqlen_q, seqlen_k, headdim, ATOMIC_ADD: tl.constexpr, BIAS_TYPE: tl.constexpr,\n IS_CAUSAL: tl.constexpr, BLOCK_HEADDIM: tl.constexpr, EVEN_M: tl.constexpr,\n EVEN_N: tl.constexpr, EVEN_HEADDIM: tl.constexpr, BLOCK_M: tl.constexpr,\n BLOCK_N: tl.constexpr,\n):\n # Compute gradients for one column block\n # Detailed implementation omitted for brevity\n pass\n\n\n@triton.autotune(\n configs=[\n triton.Config(\n {\"BLOCK_M\": 128, \"BLOCK_N\": 128, \"SEQUENCE_PARALLEL\": False},\n num_warps=8,\n num_stages=1,\n pre_hook=lambda nargs: nargs[\"DQ\"].zero_(),\n ),\n triton.Config(\n {\"BLOCK_M\": 128, \"BLOCK_N\": 128, \"SEQUENCE_PARALLEL\": True},\n num_warps=8,\n num_stages=1,\n pre_hook=lambda nargs: nargs[\"DQ\"].zero_(),\n ),\n ],\n key=[\n \"CACHE_KEY_SEQLEN_Q\", \"CACHE_KEY_SEQLEN_K\", \"BIAS_TYPE\", \"IS_CAUSAL\",\n \"BLOCK_HEADDIM\",\n ],\n)\n@triton.heuristics(\n {\n \"EVEN_M\": lambda args: args[\"seqlen_q\"] % args[\"BLOCK_M\"] == 0,\n \"EVEN_N\": lambda args: args[\"seqlen_k\"] % args[\"BLOCK_N\"] == 0,\n \"EVEN_HEADDIM\": lambda args: args[\"headdim\"] == args[\"BLOCK_HEADDIM\"],\n }\n)\n@triton.jit\ndef _bwd_kernel(\n Q, K, V, Bias, DO, DQ, DK, DV, LSE, D, softmax_scale, stride_qb, stride_qh,\n stride_qm, stride_kb, stride_kh, stride_kn, stride_vb, stride_vh, stride_vn,\n stride_bb, stride_bh, stride_bm, stride_dob, stride_doh, stride_dom, stride_dqb,\n stride_dqh, stride_dqm, stride_dkb, stride_dkh, stride_dkn, stride_dvb,\n stride_dvh, stride_dvn, nheads, seqlen_q, seqlen_k, seqlen_q_rounded, headdim,\n CACHE_KEY_SEQLEN_Q, CACHE_KEY_SEQLEN_K, BIAS_TYPE: tl.constexpr, IS_CAUSAL: tl.constexpr,\n BLOCK_HEADDIM: tl.constexpr, SEQUENCE_PARALLEL: tl.constexpr, EVEN_M: tl.constexpr,\n EVEN_N: tl.constexpr, EVEN_HEADDIM: tl.constexpr, BLOCK_M: tl.constexpr,\n BLOCK_N: tl.constexpr,\n):\n # Compute backward pass\n # Detailed implementation omitted for brevity\n pass\n\n\ndef _flash_attn_forward(q, k, v, bias=None, causal=False, softmax_scale=None):\n # Forward wrapper for _fwd_kernel\n pass\n\n\ndef _flash_attn_backward(\n do, q, k, v, o, lse, dq, dk, dv, bias=None, causal=False, softmax_scale=None\n):\n # Backward wrapper for _bwd_kernel\n pass\n\n\nclass FlashAttnQKVPackedFunc(torch.autograd.Function):\n @staticmethod\n def forward(ctx, qkv, bias=None, causal=False, softmax_scale=None):\n (o, lse, ctx.softmax_scale) = _flash_attn_forward(\n qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2], bias=bias, causal=causal,\n softmax_scale=softmax_scale,\n )\n ctx.save_for_backward(qkv, o, lse, bias)\n ctx.causal = causal\n return o\n\n @staticmethod\n def backward(ctx, do):\n (qkv, o, lse, bias) = ctx.saved_tensors\n dqkv = torch.empty_like(qkv)\n _flash_attn_backward(\n do, qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2], o, lse, dqkv[:, :, 0],\n dqkv[:, :, 1], dqkv[:, :, 2], bias=bias, causal=ctx.causal,\n softmax_scale=ctx.softmax_scale,\n )\n return (dqkv, None, None, None)\n\n\nflash_attn_qkvpacked_func = FlashAttnQKVPackedFunc.apply\n\n\nclass FlashAttnKVPackedFunc(torch.autograd.Function):\n @staticmethod\n def forward(ctx, q, kv, bias=None, causal=False, softmax_scale=None):\n (o, lse, ctx.softmax_scale) = _flash_attn_forward(\n q, kv[:, :, 0], kv[:, :, 1], bias=bias, causal=causal,\n softmax_scale=softmax_scale,\n )\n ctx.save_for_backward(q, kv, o, lse, bias)\n ctx.causal = causal\n return o\n\n @staticmethod\n def backward(ctx, do):\n (q, kv, o, lse, bias) = ctx.saved_tensors\n dq = torch.empty_like(q)\n dkv = torch.empty_like(kv)\n _flash_attn_backward(\n do, q, kv[:, :, 0], kv[:, :, 1], o, lse, dq, dkv[:, :, 0], dkv[:, :, 1],\n bias=bias, causal=ctx.causal, softmax_scale=ctx.softmax_scale,\n )\n return (dq, dkv, None, None, None)\n\n\nflash_attn_kvpacked_func = FlashAttnKVPackedFunc.apply\n\n\nclass FlashAttnFunc(torch.autograd.Function):\n @staticmethod\n def forward(ctx, q, k, v, bias=None, causal=False, softmax_scale=None):\n (o, lse, ctx.softmax_scale) = _flash_attn_forward(\n q, k, v, bias=bias, causal=causal, softmax_scale=softmax_scale\n )\n ctx.save_for_backward(q, k, v, o, lse, bias)\n ctx.causal = causal\n return o\n\n @staticmethod\n def backward(ctx, do):\n (q, k, v, o, lse, bias) = ctx.saved_tensors\n dq = torch.empty_like(q)\n dk = torch.empty_like(k)\n dv = torch.empty_like(v)\n _flash_attn_backward(\n do, q, k, v, o, lse, dq, dk, dv, bias=bias, causal=ctx.causal,\n softmax_scale=ctx.softmax_scale,\n )\n return (dq, dk, dv, None, None, None)\n\n\nflash_attn_func = FlashAttnFunc.apply\n", - "description_1": "Use triton language to implement efficient forward and backward kernels for FlashAttention with triton.jit, supporting causal and non-causal attention, optimized with heuristics and autotuning, ensuring compatibility with different head dimensions and sequence lengths.", - "description_2": "Use triton language to develop advanced attention mechanisms with forward and backward pass kernels, supporting flexible dimensions and computational optimizations.", - "difficulty": 4 - }, - { - "code": "import triton\nimport triton.language as tl\nfrom torch._inductor.triton_heuristics import foreach\nfrom torch._inductor.utils import instance_descriptor\nfrom torch._inductor import triton_helpers\n\nclass ForeachKernel(Kernel):\n def jit_line(self):\n can_use_32bit = all(k.index_dtype == \"tl.int32\" for k in self.sub_kernels)\n size_dtype = \"tl.int32\" if can_use_32bit else \"tl.int64\"\n _, _, signature = self.args.python_argdefs()\n triton_meta = {\n \"signature\": signature_to_meta(signature, size_dtype=size_dtype),\n \"device\": V.graph.scheduler.current_device.index,\n \"device_type\": V.graph.scheduler.current_device.type,\n \"constants\": {},\n }\n triton_meta[\"configs\"] = [config_of(signature)]\n inductor_meta = {\"kernel_name\": str(Placeholder.DESCRIPTIVE_NAME)}\n return (\n f\"@foreach(num_warps={self.num_warps}, triton_meta={triton_meta!r}, inductor_meta={inductor_meta!r})\\n\"\n + \"@triton.jit\"\n )\n\n def codegen_kernel(self, name=None):\n code = IndentedBuffer()\n code.splice(\n \"\"\"\n import triton\n import triton.language as tl\n from torch._inductor.triton_heuristics import foreach\n from torch._inductor.utils import instance_descriptor\n from torch._inductor import triton_helpers\n \"\"\"\n )\n argdefs, _, _ = self.args.python_argdefs()\n code.writeline(self.jit_line())\n code.writeline(\n f\"def {name or str(Placeholder.KERNEL_NAME)}({', '.join(argdefs)}):\"\n )\n with code.indent():\n code.splice(\"xpid = tl.program_id(0)\")\n if self.blocking_2d:\n code.splice(\"ypid = tl.program_id(1)\")\n code.splice(f\"XBLOCK: tl.constexpr = {self.block_size_2d}\")\n code.splice(f\"YBLOCK: tl.constexpr = {self.block_size_2d}\")\n else:\n code.splice(f\"XBLOCK: tl.constexpr = {self.block_size_1d}\")\n\n for sub_kernel in self.sub_kernels:\n assert len(sub_kernel.numels) <= 3\n numel_ind = 0 if not self.blocking_2d else 1\n self.codegen_pid_range(code, int(sub_kernel.numels[numel_ind]))\n with code.indent():\n if self.blocking_2d:\n code.splice(f\"ynumel = {sub_kernel.numels[0]}\")\n code.splice(f\"xnumel = {sub_kernel.numels[1]}\")\n else:\n code.splice(f\"xnumel = {sub_kernel.numels[0]}\")\n\n sub_kernel.codegen_body()\n code.splice(sub_kernel.body)\n\n code.splice(\"else:\")\n with code.indent():\n code.splice(\"pass\")\n\n return code.getvalue()\n\n def call_kernel(self, code, name: str):\n _, call_args, _ = self.args.python_argdefs()\n for i in range(len(call_args)):\n if V.graph.is_unspec_arg(call_args[i]):\n call_args[i] = call_args[i] + \".item()\"\n if V.graph.cpp_wrapper:\n V.graph.wrapper_code.generate_kernel_call(\n name,\n call_args,\n device_index=V.graph.scheduler.current_device.index,\n grid=self.grid(),\n )\n else:\n call_args_str = \", \".join(call_args)\n stream_name = code.write_get_raw_stream(\n V.graph.scheduler.current_device.index\n )\n code.writeline(\n f\"{name}.run({call_args_str}, grid=({self.grid()}), stream={stream_name})\"\n )\n", - "description_1": "Use triton language to define and execute a generic kernel using the @triton.jit decorator in a ForeachKernel class. The code defines the kernel and executes it over a specified grid, incorporating meta-information, kernel parameters, and blocking strategies.", - "description_2": "Use triton language to define a kernel with @triton.jit and execute it with specific meta settings.", - "difficulty": 4 - }, - { - "code": "import triton\nimport triton.language as tl\nimport torch\n\n@triton.jit\ndef add_kernel(X, Y, Z, N):\n # Triton kernel to add two vectors\n pid = tl.program_id(0)\n block_start = pid * 1024\n offsets = block_start + tl.arange(0, 1024)\n mask = offsets < N\n x = tl.load(X + offsets, mask=mask)\n y = tl.load(Y + offsets, mask=mask)\n z = x + y\n tl.store(Z + offsets, z, mask=mask)\n\ndef add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:\n # Function to call the Triton kernel\n assert x.is_cuda and y.is_cuda\n assert x.shape == y.shape\n z = torch.empty_like(x)\n N = x.numel()\n grid = lambda meta: (triton.cdiv(N, 1024),)\n add_kernel[grid](x, y, z, N)\n return z\n", - "description_1": "Use triton language to define a kernel 'add_kernel' that adds two vectors X and Y, storing the result in Z. The kernel takes four arguments: X, Y, Z (all pointers to the data) and N (the number of elements). The kernel uses a block size of 1024 and computes the sum of elements in X and Y, storing the result in Z. The function 'add' is a wrapper that calls this kernel, ensuring the input tensors are on CUDA and have the same shape, and returns the result tensor.", - "description_2": "Use triton language to define a kernel for element-wise addition of two vectors, and a wrapper function to execute this kernel on CUDA tensors.", - "difficulty": 2 - }, - { - "code": "import triton\nimport triton.language as tl\n\n@triton.jit\ndef example_kernel(X, stride_xm, stride_xk, stride_xn, size_m, size_n, size_k, Y):\n m = tl.program_id(0)\n n = tl.program_id(1)\n k = tl.program_id(2)\n\n x = tl.load(X + m * stride_xm + k * stride_xk)\n tl.store(Y + m * stride_xm + n * stride_xn, x)\n\ndef example_call(X, Y):\n grid = lambda meta: (X.shape[0], X.shape[1], 1)\n example_kernel[grid](X, Y)\n", - "description_1": "Use triton language to implement a kernel that performs a simple memory load and store operation. The function `example_kernel` takes 8 parameters: X, stride_xm, stride_xk, stride_xn, size_m, size_n, size_k, and Y. It reads a value from a 3D matrix X at a position determined by the strides and the program ids, and stores it into a corresponding position in a 3D matrix Y. The `example_call` function is a wrapper that launches the kernel on a grid defined by the dimensions of X.", - "description_2": "Use triton language to implement a 3D grid kernel that loads data from one matrix and stores it in another, based on calculated indices from program ids and provided strides.", - "difficulty": 3 - }, - { - "code": "import triton\nimport triton.language as tl\n\n\n@triton.jit\ndef promote_to_tensor(x):\n # Addition promotes to tensor for us\n return x + tl.zeros((1,), tl.int1)\n\n\n@triton.jit\ndef is_floating(x):\n return promote_to_tensor(x).dtype.is_floating()\n\n\n@triton.jit\ndef _prod_accumulate(a, b):\n return a * b\n\n\n@triton.jit\ndef prod(input, axis):\n return tl.reduce(input, axis, _prod_accumulate)\n\n\n@triton.jit\ndef minimum(a, b):\n mask = a < b\n if is_floating(a):\n mask |= a != a\n return tl.where(mask, a, b)\n\n\n@triton.jit\ndef maximum(a, b):\n mask = a > b\n if is_floating(a):\n mask |= a != a\n return tl.where(mask, a, b)\n\n\n@triton.jit\ndef min2(a, dim):\n return tl.reduce(a, dim, minimum)\n\n\n@triton.jit\ndef max2(a, dim):\n return tl.reduce(a, dim, maximum)\n\n\n@triton.jit\ndef minimum_with_index(a_value, a_index, b_value, b_index):\n mask = a_value < b_value\n equal = a_value == b_value\n if is_floating(a_value):\n a_isnan = a_value != a_value\n b_isnan = b_value != b_value\n mask |= a_isnan and not b_isnan\n # Consider NaNs as equal\n equal |= a_isnan and b_isnan\n\n # Prefer lowest index if values are equal\n mask |= equal & (a_index < b_index)\n return tl.where(mask, a_value, b_value), tl.where(mask, a_index, b_index)\n\n\n@triton.jit\ndef maximum_with_index(a_value, a_index, b_value, b_index):\n mask = a_value > b_value\n equal = a_value == b_value\n if is_floating(a_value):\n a_isnan = a_value != a_value\n b_isnan = b_value != b_value\n mask |= a_isnan and not b_isnan\n # Consider NaNs as equal\n equal |= a_isnan and b_isnan\n\n # Prefer lowest index if values are equal\n mask |= equal & (a_index < b_index)\n return tl.where(mask, a_value, b_value), tl.where(mask, a_index, b_index)\n\n\n@triton.jit\ndef min_with_index(value, index, dim):\n return tl.reduce((value, index), dim, minimum_with_index)\n\n\n@triton.jit\ndef max_with_index(value, index, dim):\n return tl.reduce((value, index), dim, maximum_with_index)\n\n\n@triton.jit\ndef welford_reduce(value, mean, m2, weight):\n delta = value - mean\n new_weight = weight + 1\n new_mean = mean + delta / new_weight\n return (\n new_mean,\n m2 + delta * (value - new_mean),\n new_weight,\n )\n\n\n@triton.jit\ndef welford_combine(mean_1, m2_1, weight_1, mean_2, m2_2, weight_2):\n delta = mean_2 - mean_1\n new_weight = weight_1 + weight_2\n w2_over_w = tl.where(new_weight == 0.0, 0.0, weight_2 / new_weight)\n return (\n mean_1 + delta * w2_over_w,\n m2_1 + m2_2 + delta * delta * weight_1 * w2_over_w,\n new_weight,\n )\n\n\n@triton.jit\ndef welford(mean, m2, weight, dim):\n return tl.reduce((mean, m2, weight), dim, welford_combine)\n\n\n@triton.jit\ndef device_assert_then(cond, msg, r):\n tl.device_assert(cond, msg)\n return r\n\n\n@triton.jit\ndef randint64(seed, offset, low, high):\n r0, r1, r2, r3 = tl.randint4x(seed, offset)\n r0 = r0.to(tl.uint64)\n r1 = r1.to(tl.uint64)\n result = r0 | (r1 << 32)\n size = high - low\n result = result % size.to(tl.uint64)\n result = result.to(tl.int64) + low\n return result\n\n\n@triton.jit\ndef _any_combine(a, b):\n return a | b\n\n\n@triton.jit\ndef any(a, dim):\n return tl.reduce(a, dim, _any_combine)\n\n\n@triton.jit\ndef bucketize_binary_search(\n values, # 1D tensor\n offsets_ptr,\n indexing_dtype,\n right, # bool: if true, use intervals closed on the left; see [Note: Inductor bucketize op]\n OFFSETS_SIZE: int,\n BLOCK_SHAPE, # tuple/list of block shape\n):\n \"\"\"\n See [Note: Inductor bucketize op]\n \"\"\"\n\n low = tl.zeros(BLOCK_SHAPE, dtype=indexing_dtype)\n high = tl.full(BLOCK_SHAPE, OFFSETS_SIZE, dtype=indexing_dtype)\n\n full_range = OFFSETS_SIZE + 1\n while full_range > 1:\n mid = (high + low) // 2\n mask = mid < OFFSETS_SIZE\n bucket_upper_bound = tl.load(offsets_ptr + mid, mask=mask)\n if right:\n is_above = values >= bucket_upper_bound\n else:\n is_above = values > bucket_upper_bound\n\n low = tl.where(is_above & mask, mid + 1, low)\n high = tl.where(is_above, high, mid)\n\n full_range = (full_range + 1) // 2\n\n return low\n", - "description_1": "Use triton language to implement various arithmetic and reduction operations including promotion to tensor, checking if a value is floating, product accumulation, product reduction, minimum, maximum, min and max with index, Welford reduction and combination, device assertions, random integer generation, and a bucketize operation. These operations often involve comparison and selection using masks, handling NaNs, and employing triton's built-in reduction and random number generation functions.", - "description_2": "Use triton language to implement various arithmetic and reduction operations including min, max, and random number generation.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef _sampled_addmm_kernel(\n alpha,\n beta,\n IS_BETA_ZERO: tl.constexpr,\n BLOCKSIZE_ROW: tl.constexpr,\n BLOCKSIZE_COL: tl.constexpr,\n k,\n TILE_K: tl.constexpr,\n values_ptr,\n values_batch_stride,\n values_nnz_stride,\n values_row_block_stride,\n values_col_block_stride,\n crow_indices_ptr,\n crow_indices_batch_stride,\n crow_indices_stride,\n col_indices_ptr,\n col_indices_batch_stride,\n col_indices_stride,\n mat1_ptr,\n mat1_batch_stride,\n mat1_tiled_row_stride,\n mat1_tiled_col_stride,\n mat1_row_block_stride,\n mat1_col_block_stride,\n mat2_ptr,\n mat2_batch_stride,\n mat2_tiled_row_stride,\n mat2_tiled_col_stride,\n mat2_row_block_stride,\n mat2_col_block_stride,\n acc_dtype: tl.constexpr,\n allow_tf32: tl.constexpr,\n):\n batch_pid = tl.program_id(axis=1)\n row_block_pid = tl.program_id(axis=0)\n\n crow_indices_offset_ptr = (\n crow_indices_ptr\n + crow_indices_batch_stride * batch_pid\n + crow_indices_stride * row_block_pid\n )\n nnz_offset = tl.load(crow_indices_offset_ptr)\n nnz_offset_next = tl.load(crow_indices_offset_ptr + crow_indices_stride)\n\n row_nnz = nnz_offset_next - nnz_offset\n if row_nnz == 0:\n return\n\n row_block_arange = tl.arange(0, BLOCKSIZE_ROW)\n col_block_arange = tl.arange(0, BLOCKSIZE_COL)\n\n values_block_ptrs = (\n values_ptr\n + values_batch_stride * batch_pid\n + values_nnz_stride * nnz_offset\n + values_row_block_stride * row_block_arange[:, None]\n + values_col_block_stride * col_block_arange[None, :]\n )\n\n col_index_nnz_ptr = (\n col_indices_ptr\n + col_indices_batch_stride * batch_pid\n + col_indices_stride * nnz_offset\n )\n\n mat1_block_ptrs = (\n mat1_ptr\n + mat1_batch_stride * batch_pid\n + mat1_tiled_row_stride * row_block_pid\n + mat1_row_block_stride * row_block_arange[:, None]\n )\n\n mat2_block_ptrs = (\n mat2_ptr\n + mat2_batch_stride * batch_pid\n + mat2_col_block_stride * col_block_arange[None, :]\n )\n\n k_tile_arange = tl.arange(0, TILE_K)\n for _ in range(row_nnz):\n acc_block = tl.zeros((BLOCKSIZE_ROW, BLOCKSIZE_COL), dtype=acc_dtype)\n\n col_block = tl.load(col_index_nnz_ptr)\n\n for k_tile in range(0, k, TILE_K):\n k_offsets = k_tile + k_tile_arange\n mask_k = k_offsets < k\n\n mat1_block = tl.load(\n mat1_block_ptrs\n + mat1_col_block_stride * k_offsets[None, :],\n mask=mask_k[None, :], other=0.0\n )\n\n mat2_block = tl.load(\n mat2_block_ptrs\n + mat2_tiled_col_stride * col_block\n + mat2_row_block_stride * k_offsets[:, None],\n mask=mask_k[:, None], other=0.0\n )\n\n acc_block += tl.dot(mat1_block, mat2_block, allow_tf32=allow_tf32, out_dtype=acc_dtype)\n\n if IS_BETA_ZERO:\n acc_block *= alpha\n else:\n acc_block = alpha * acc_block + beta * tl.load(values_block_ptrs)\n\n tl.store(values_block_ptrs, acc_block.to(values_ptr.dtype.element_ty))\n\n values_block_ptrs += values_nnz_stride\n col_index_nnz_ptr += col_indices_stride\n\n\ndef sampled_addmm(\n input: torch.Tensor,\n mat1: torch.Tensor,\n mat2: torch.Tensor,\n *,\n beta=1.0,\n alpha=1.0,\n out: Optional[torch.Tensor] = None,\n skip_checks: bool = False,\n max_grid: Optional[Tuple[Optional[int], Optional[int], Optional[int]]] = None,\n):\n f_name = \"sampled_addmm\"\n\n check_bsr_layout(f_name, input)\n input_broadcasted = broadcast_batch_dims_bsr(f_name, input, mat1, mat2)\n\n if not skip_checks:\n check_device(f_name, mat1, input.device)\n check_device(f_name, mat2, input.device)\n if beta != 0.0 and input.dtype is torch.bool:\n check(\n False,\n f\"{f_name}(): having beta == {beta} not equal to 0.0 with boolean mask is not allowed.\"\n )\n if input.dtype is not torch.bool:\n check_dtype(f_name, mat1, input.dtype)\n check_dtype(f_name, mat2, input.dtype)\n else:\n check_dtype(f_name, mat1, mat2.dtype)\n check_mm_compatible_shapes(f_name, mat1, mat2)\n if out is not None:\n check_bsr_layout(f_name, out)\n check_device(f_name, out, mat1.device)\n check_dtype(f_name, out, input.dtype)\n check(\n out.shape == input_broadcasted.shape\n and out._nnz() == input._nnz(),\n f\"{f_name}(): Expects `out` to be of shape {input_broadcasted.shape} \"\n f\"and with nnz equal to {input_broadcasted._nnz()} \"\n f\"but got out.shape = {out.shape} and out.nnz = {out._nnz()}\"\n )\n\n if out is None:\n out = input_broadcasted.to(mat1.dtype, copy=True)\n else:\n out.copy_(input_broadcasted)\n\n if out.numel() == 0 or out._nnz() == 0:\n return out\n\n blocksize = out.values().shape[-2:]\n m = mat1.size(-2)\n n = mat2.size(-1)\n k = mat1.size(-1)\n\n if alpha == 0.0 or k == 0:\n out.values().mul_(beta)\n return out\n\n out_backup = out\n crow_indices, col_indices, values, mat1, mat2 = prepare_inputs(out, mat1, mat2)\n\n mat1 = tile_to_blocksize(mat1, (blocksize[0], k))\n mat2 = tile_to_blocksize(mat2, (k, blocksize[1]))\n tile_k = max(*blocksize)\n\n def grid(META):\n return (crow_indices.size(0) - 1,)\n\n dot_out_dtype = {torch.float16: tl.float32,\n torch.bfloat16: tl.float32,\n torch.float32: tl.float64,\n torch.float64: tl.float64}[out.dtype]\n if 'allow_tf32' not in meta:\n meta.update(allow_tf32=dot_out_dtype == tl.float32)\n\n _sampled_addmm_kernel[grid](\n alpha, beta, beta == 0.0,\n *blocksize, k, tile_k,\n values, crow_indices, col_indices,\n mat1, mat2,\n acc_dtype=dot_out_dtype,\n allow_tf32=dot_out_dtype == tl.float32\n )\n\n if out_backup.values().stride()[-3:] != values.stride()[-3:]:\n out_backup.values().copy_(values.reshape(out_backup.values().shape))\n return out_backup\n", - "description_1": "Use triton language to implement a sparse sampled matrix multiplication with scaling and addition using kernel `_sampled_addmm_kernel`.", - "description_2": "Implement a function in triton that performs matrix multiplication on sparse data with an optional addition term, following the kernel `_sampled_addmm_kernel` pattern.", - "difficulty": 4 - }, - { - "code": "import triton\nfrom triton import language as tl\n\n# Kernel to add two arrays element-wise\n@triton.jit\ndef add_kernel(\n in_ptr0,\n in_ptr1,\n out_ptr,\n n_elements,\n BLOCK_SIZE: \"tl.constexpr\",\n):\n pid = tl.program_id(axis=0)\n block_start = pid * BLOCK_SIZE\n offsets = block_start + tl.arange(0, BLOCK_SIZE)\n mask = offsets < n_elements\n x = tl.load(in_ptr0 + offsets, mask=mask)\n y = tl.load(in_ptr1 + offsets, mask=mask)\n output = x + y\n tl.store(out_ptr + offsets, output, mask=mask)\n\n# Kernel to add two arrays with an optional parameter\n@triton.jit\ndef add_kernel_with_optional_param(\n in_ptr0,\n in_ptr1,\n out_ptr,\n n_elements,\n ARGS_PASSED: \"tl.constexpr\",\n BLOCK_SIZE: \"tl.constexpr\",\n):\n pid = tl.program_id(axis=0)\n block_start = pid * BLOCK_SIZE\n offsets = block_start + tl.arange(0, BLOCK_SIZE)\n mask = offsets < n_elements\n x = tl.load(in_ptr0 + offsets, mask=mask)\n if ARGS_PASSED == \"two\":\n y = tl.load(in_ptr1 + offsets, mask=mask)\n output = x + y\n else:\n output = x\n tl.store(out_ptr + offsets, output, mask=mask)\n\n# Autotuned kernel to add two arrays element-wise\n@triton.autotune(\n configs=[\n triton.Config({\"BLOCK_SIZE\": 128}, num_stages=3, num_warps=8),\n triton.Config({\"BLOCK_SIZE\": 64}, num_stages=3, num_warps=8),\n ],\n key=[],\n)\n@triton.jit\ndef add_kernel_autotuned(\n in_ptr0,\n in_ptr1,\n out_ptr,\n n_elements,\n BLOCK_SIZE: \"tl.constexpr\",\n):\n pid = tl.program_id(axis=0)\n block_start = pid * BLOCK_SIZE\n offsets = block_start + tl.arange(0, BLOCK_SIZE)\n mask = offsets < n_elements\n x = tl.load(in_ptr0 + offsets, mask=mask)\n y = tl.load(in_ptr1 + offsets, mask=mask)\n output = x + y\n tl.store(out_ptr + offsets, output, mask=mask)\n\n# Autotuned kernel to add two 2D arrays element-wise\n@triton.autotune(\n configs=[\n triton.Config(\n {\"BLOCK_SIZE_X\": 128, \"BLOCK_SIZE_Y\": 128}, num_stages=3, num_warps=8\n ),\n triton.Config(\n {\"BLOCK_SIZE_X\": 64, \"BLOCK_SIZE_Y\": 64}, num_stages=3, num_warps=8\n ),\n ],\n key=[],\n)\n@triton.jit\ndef add_kernel_2d_autotuned(\n in_ptr0,\n in_ptr1,\n out_ptr,\n x_elements,\n y_elements,\n BLOCK_SIZE_X: \"tl.constexpr\",\n BLOCK_SIZE_Y: \"tl.constexpr\",\n):\n xoffset = tl.program_id(0) * BLOCK_SIZE_X\n xindex = xoffset + tl.arange(0, BLOCK_SIZE_X)[:, None]\n xmask = xindex < x_elements\n yoffset = tl.program_id(1) * BLOCK_SIZE_Y\n yindex = yoffset + tl.arange(0, BLOCK_SIZE_Y)[None, :]\n ymask = yindex < y_elements\n x1 = xindex\n y0 = yindex\n tmp0 = tl.load(in_ptr0 + (x1 + (x_elements * y0)), xmask & ymask)\n tmp1 = tl.load(in_ptr0 + (y0 + (y_elements * x1)), xmask & ymask)\n tmp2 = tmp0 + tmp1\n tl.store(out_ptr + (x1 + (x_elements * y0)), tmp2, xmask & ymask)\n\n# Kernel to multiply an array by 2\n@triton.jit\ndef mul2_kernel(\n in_ptr0,\n out_ptr,\n n_elements,\n BLOCK_SIZE: \"tl.constexpr\",\n):\n pid = tl.program_id(axis=0)\n block_start = pid * BLOCK_SIZE\n offsets = block_start + tl.arange(0, BLOCK_SIZE)\n mask = offsets < n_elements\n x = tl.load(in_ptr0 + offsets, mask=mask)\n output = 2 * x\n tl.store(out_ptr + offsets, output, mask=mask)\n\n# Kernel to multiply an array by 2 in place\n@triton.jit\ndef mul2_inplace_kernel(\n ptr,\n n_elements,\n BLOCK_SIZE: \"tl.constexpr\",\n):\n pid = tl.program_id(axis=0)\n block_start = pid * BLOCK_SIZE\n offsets = block_start + tl.arange(0, BLOCK_SIZE)\n mask = offsets < n_elements\n x = tl.load(ptr + offsets, mask=mask)\n output = 2 * x\n tl.store(ptr + offsets, output, mask=mask)\n\n# Kernel with indirection and activation\n@triton.jit\ndef indirection_kernel(\n in_ptr0,\n out_ptr,\n n_elements,\n BLOCK_SIZE: \"tl.constexpr\",\n ACTIVATION: \"tl.constexpr\",\n):\n pid = tl.program_id(axis=0)\n block_start = pid * BLOCK_SIZE\n offsets = block_start + tl.arange(0, BLOCK_SIZE)\n mask = offsets < n_elements\n if ACTIVATION == \"mul2_inplace_kernel\":\n mul2_inplace_kernel(in_ptr0, n_elements, BLOCK_SIZE=BLOCK_SIZE)\n x = tl.load(in_ptr0 + offsets, mask=mask)\n tl.store(out_ptr + offsets, x, mask=mask)\n", - "description_1": "Use triton language to implement various kernels: add_kernel for element-wise addition of two arrays, add_kernel_with_optional_param for addition with an optional parameter, add_kernel_autotuned for autotuned element-wise addition, add_kernel_2d_autotuned for autotuned 2D array addition, mul2_kernel for multiplying an array by 2, mul2_inplace_kernel for in-place multiplication by 2, and indirection_kernel for applying an activation function with indirection.", - "description_2": "Use triton language to create kernels for element-wise operations and autotuning, including addition, multiplication, and activation functions with optional parameters and indirection.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n\n@triton.jit\ndef _kernel(A,\n B,\n C,\n stride_za,\n stride_ha,\n stride_ma,\n stride_ka,\n stride_zb,\n stride_hb,\n stride_kb,\n stride_nb,\n stride_zc,\n stride_hc,\n stride_mc,\n stride_nc,\n DS0,\n DS1,\n SDD_K,\n SDD_off_width,\n lut,\n locks,\n nlocks,\n **meta):\n TM = meta['TM']\n TN = meta['TN']\n TK = meta['TK']\n TZ = meta['TZ']\n BLOCK = meta['BLOCK']\n # Initialize variables and load data as per meta configuration\n pid0 = tl.program_id(0)\n pid1 = tl.program_id(1)\n pidz = tl.program_id(2)\n if meta['SDD']:\n pid1 = pid1 + SDD_off_width\n blockidm = tl.arange(0, TM) // BLOCK\n blockidn = tl.arange(0, TN) // BLOCK\n offlutm = blockidm * (TN // BLOCK) * 4\n offlutn = blockidn * 4\n header = lut + pid1 * (TM // BLOCK) * (TN // BLOCK) * 4\n z = tl.load(header + 0)\n i = tl.load(header + 1 + offlutm)\n j = tl.load(header + 2 + offlutn)\n AS1 = SDD_K // TZ\n lockid = tl.where(TZ > 1, 1, 0)\n offka = pid0 * AS1\n offkb = pid0 * AS1\n offmc = 0\n offnc = 0\n offpa = 0\n offpb = 0\n maxid = TZ\n offhc = 0\n offha = z\n offhb = z\n ram = i * BLOCK + (tl.arange(0, TM) % BLOCK)\n rbn = j * BLOCK + (tl.arange(0, TN) % BLOCK)\n else:\n header = lut + pid0 * 6\n offset = tl.load(header + 0)\n AS1 = tl.load(header + 1)\n column = tl.load(header + 2)\n depth = tl.load(header + 3)\n lockid = tl.load(header + 4)\n maxid = tl.load(header + 5)\n pinc = lut + offset\n offhc = depth\n if meta['DSD']:\n offnc = pid1 * TN\n offmc = column * TM\n offpc = 0\n offnb = pid1 * TN\n offkb = tl.load(pinc)\n offkb = tl.multiple_of(offkb, 8)\n offpb = 0\n offma = 0\n offka = 0\n offpa = tl.load(pinc + 1)\n offpa = tl.multiple_of(offpa, 8)\n offpa = offpa * BLOCK * BLOCK\n offha = 0\n offhb = depth\n else:\n offmc = pid1 * TM\n offnc = column * TN\n offpc = 0\n offma = pid1 * TM\n offka = tl.load(pinc)\n offka = tl.multiple_of(offka, 8)\n offpa = 0\n offnb = 0\n offkb = 0\n offpb = tl.load(pinc + 1)\n offpb = tl.multiple_of(offpb, 8)\n offpb = offpb * BLOCK * BLOCK\n offha = depth\n offhb = 0\n ram = offma + tl.arange(0, TM)\n rbn = offnb + tl.arange(0, TN)\n\n # Initialize pointers and load data\n rka = offka + tl.arange(0, TK)\n rkb = offkb + tl.arange(0, TK)\n pa = A + pidz * stride_za + offha * stride_ha + offpa + ram[:, None] * stride_ma + rka[None, :] * stride_ka\n pb = B + pidz * stride_zb + offhb * stride_hb + offpb + rbn[None, :] * stride_nb + rkb[:, None] * stride_kb\n if meta['DDS']:\n checkam = ram[:, None] < DS0\n else:\n checkam = AS1 > 0\n if meta['DSD']:\n checkbn = rbn[None, :] < DS0\n else:\n checkbn = AS1 > 0\n a = tl.load(pa, mask=checkam, other=0.)\n b = tl.load(pb, mask=checkbn, other=0.)\n\n # Inner loop for accumulation\n acc = tl.zeros((TM, TN), dtype=tl.float32)\n for k in range(AS1, 0, -TK):\n acc += tl.dot(a, b)\n if meta['SDD']:\n inc_a = TK * stride_ka\n inc_b = TK * stride_kb\n else:\n pinc += 2\n if meta['DSD']:\n inc_b = tl.load(pinc)\n inc_a = tl.load(pinc + 1)\n inc_b = tl.multiple_of(inc_b, 8)\n inc_a = tl.multiple_of(inc_a, 8)\n inc_b = inc_b * stride_kb\n if meta['DDS']:\n inc_a = tl.load(pinc)\n inc_b = tl.load(pinc + 1)\n inc_a = tl.multiple_of(inc_a, 8)\n inc_b = tl.multiple_of(inc_b, 8)\n inc_a = inc_a * stride_ka\n pa += inc_a\n pb += inc_b\n # Prefetch\n checkak = k > TK\n checkbk = k > TK\n checka = checkam & checkak\n checkb = checkbn & checkbk\n a = tl.load(pa, mask=checka)\n b = tl.load(pb, mask=checkb)\n c = acc.to(C.dtype.element_ty)\n\n if meta['SDD']:\n checkc = True\n rr_blockidm = tl.arange(0, TM) // BLOCK\n rr_blockidn = tl.arange(0, TN) // BLOCK\n rr_offlutm = rr_blockidm * (TN // BLOCK) * 4\n rr_offlutn = rr_blockidn * 4\n off_bkid = 3 + rr_offlutm[:, None] + rr_offlutn[None, :]\n bkid = tl.load(header + off_bkid)\n offpc = bkid * BLOCK * BLOCK\n rcm = tl.arange(0, TM) % BLOCK\n rcn = tl.arange(0, TN) % BLOCK\n else:\n rcm = offmc + tl.arange(0, TM)\n rcn = offnc + tl.arange(0, TN)\n if meta['DSD']:\n checkc = rcn[None, :] < DS0\n if meta['DDS']:\n checkc = rcm[:, None] < DS0\n\n # Store the result\n pc = C + offpc + offhc * stride_hc + pidz * stride_zc + rcm[:, None] * stride_mc + rcn[None, :] * stride_nc\n if lockid == 0:\n tl.store(pc, c, mask=checkc)\n else:\n plock = locks + tl.program_id(2) * nlocks * tl.num_programs(1) + tl.program_id(\n 1) * nlocks + lockid - 1\n pcount = plock + tl.num_programs(2) * tl.num_programs(1) * nlocks\n while tl.atomic_cas(plock, 0, 1) == 1:\n pass\n count = tl.load(pcount)\n if count == 0:\n tl.store(pc, c, mask=checkc)\n else:\n d = tl.load(pc, mask=checkc)\n tl.store(pc, d + c, mask=checkc)\n tl.atomic_xchg(pcount, (count + 1) % maxid)\n tl.atomic_xchg(plock, 0)\n\n\nclass _sparse_matmul(torch.autograd.Function):\n\n @staticmethod\n def _sdd_matmul(a,\n b,\n trans_a,\n trans_b,\n trans_c,\n spdims,\n block,\n luts,\n num_locks,\n widths,\n packs,\n bench,\n time):\n if trans_c:\n a, b = b, a\n trans_a, trans_b = not trans_b, not trans_a\n AS0 = a.size(0)\n a_dim = -2 if trans_a else -1\n b_dim = -1 if trans_b else -2\n a_inner, b_inner = a.shape[a_dim], b.shape[b_dim]\n if a_inner != b_inner:\n raise ValueError(\n f\"Size of tensor A along the {a_dim} dim ({a_inner}) must match size \"\n f\"of tensor B along the {b_dim} dim ({b_inner})\")\n if a_inner % 16 != 0:\n raise ValueError('Reduction size for SDD must be a multiple of 16')\n\n batch_size = a.size(0)\n a_outer = a.size(3 if trans_a else 2)\n dtype = a.dtype\n is_16_multiple = a_inner % 16 == 0\n is_32_multiple = a_inner % 32 == 0\n is_64_multiple = a_inner % 64 == 0\n if not is_16_multiple:\n raise ValueError('Reduction size for SDD must be a multiple of 16')\n device = a.device\n total_width = sum([width * pack * pack for width, pack in zip(widths, packs)])\n c = torch.empty((batch_size,\n total_width,\n block,\n block),\n dtype=dtype,\n device=a.device)\n for lut, width, pack in zip(luts, widths, packs):\n F32TK = [8, 16]\n F16TK = [16]\n F16TK += [32] if is_32_multiple else []\n F16TK += [64] if is_64_multiple else []\n TK = {torch.float32: F32TK, torch.float16: F16TK}[dtype]\n num_lock = 1\n meta = {\n 'TM': block * pack,\n 'TN': block * pack,\n 'BLOCK': block,\n 'TK': TK[0],\n 'TZ': 1,\n 'SDD': True,\n 'DSD': False,\n 'DDS': False\n }\n locks = _sparse_matmul.get_locks(2 * width * AS0 * num_lock, a.device)\n max_width = 49152\n total = 0 if bench else None\n for off_width in range(0, width, max_width):\n grid = lambda meta: [\n meta['TZ'],\n min(max_width,\n width - off_width),\n batch_size\n ]\n _kernel[grid](a,\n b,\n c,\n a.stride(0),\n a.stride(1),\n a.stride(3 if trans_a else 2),\n a.stride(2 if trans_a else 3),\n b.stride(0),\n b.stride(1),\n b.stride(3 if trans_b else 2),\n b.stride(2 if trans_b else 3),\n c.stride(0),\n c.stride(0),\n c.stride(2),\n c.stride(3),\n a_outer,\n a_outer,\n a_inner,\n off_width,\n lut,\n locks,\n num_lock,\n num_warps=4,\n **meta)\n return c\n", - "description_1": "Use triton language to implement a sparse matrix multiplication kernel (_kernel) and a function (_sdd_matmul) to execute sparse = dense x dense matrix multiplication. The kernel handles various strides and offsets, locks for synchronization, and meta configuration for different cases. It performs matrix multiplication using dot products within a block-grid setup with optional spin-locks for accumulation. The function _sdd_matmul organizes the setup for kernel execution, handling dimensions and ensuring configuration compatibility with given meta parameters and device.", - "description_2": "Use triton language to implement a sparse matrix multiplication kernel and execute sparse = dense x dense matrix multiplication using block-grid setup with optional spin-locks.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef _forward(X,\n scale,\n LUT,\n RPE,\n KP_M,\n ATTN_M,\n sizemax,\n stride_zx,\n stride_zrpe,\n stride_hrpe,\n stride_srpe,\n stride_zkpm,\n stride_zattnm,\n **meta):\n TN = meta['TN']\n BLOCK = meta['BLOCK']\n pidhm = tl.program_id(0)\n pidz = tl.program_id(1)\n # create index ranges\n rxm = pidhm % BLOCK\n rbm = pidhm // BLOCK\n rxn = tl.arange(0, TN) % BLOCK\n rbn = tl.arange(0, TN) // BLOCK\n # extract information from LUT\n header = LUT + rbm * 2\n size = tl.load(header + 0)\n offset = tl.load(header + 1)\n check = rbn < size\n rbmn = tl.where(check, rbn, size - 1)\n # block id and column id\n blockid = tl.load(LUT + offset + rbmn * 4 + 0)\n columnid = tl.load(LUT + offset + rbmn * 4 + 1)\n rowid = tl.load(LUT + offset + rbmn * 4 + 2)\n headid = tl.load(LUT + offset + rbmn * 4 + 3)\n # pointers to X\n px = X + pidz * stride_zx + blockid * BLOCK * BLOCK + rxm * BLOCK + rxn\n x = tl.load(px, mask=check, other=-float('inf'))\n x = x.to(tl.float32)\n # apply scale\n if meta['APPLY_SCALE']:\n x = x * scale\n # apply RPE\n if meta['APPLY_RPE']:\n prpe = RPE + pidz * stride_zrpe + headid * stride_hrpe + columnid * BLOCK + rowid * BLOCK * stride_srpe + rxm * stride_srpe + rxn\n rpe = tl.load(prpe, mask=check, other=0)\n x = x + rpe\n # apply key-padding mask\n if meta['APPLY_KP_MASK']:\n pkp_m = KP_M + pidz * stride_zkpm + columnid * BLOCK + rxn\n kp_m = tl.load(pkp_m, mask=check, other=-float('inf'))\n if meta['KP_MASK_MUL']:\n kp_m = tl.where(kp_m == 0, -float('inf'), 0.)\n x = x + kp_m\n # apply attention mask\n if meta['APPLY_ATTN_MASK']:\n pattn_m = ATTN_M + columnid * BLOCK + rowid * BLOCK * stride_zattnm + rxm * stride_zattnm + rxn\n attn_m = tl.load(pattn_m, mask=check, other=-float('inf'))\n if meta['ATTN_MASK_MUL']:\n attn_m = tl.where(attn_m == 0, -float('inf'), 0.)\n x = x + attn_m\n # computation\n x = tl.softmax(x)\n tl.store(px, x, mask=check)\n\n@triton.jit\ndef _backward(X, scale, DX, LUT, sizemax, stride_zx, stride_zdx, **meta):\n pidhm = tl.program_id(0)\n pidz = tl.program_id(1)\n TN = meta['TN']\n BLOCK = meta['BLOCK']\n # create index ranges\n rxm = pidhm % BLOCK\n rbm = pidhm // BLOCK\n rxn = tl.arange(0, TN) % BLOCK\n rbn = tl.arange(0, TN) // BLOCK\n # extract information from look-up table\n header = LUT + rbm * 2\n size = tl.load(header + 0)\n offset = tl.load(header + 1)\n # bounds checking on lut\n check = rbn < size\n rbmn = tl.where(check, rbn, size - 1)\n # initialize pointers to block-sparse input\n blockid = tl.load(LUT + offset + rbmn * 4)\n X = X + pidz * stride_zx + blockid * BLOCK * BLOCK + rxm * BLOCK + rxn\n DX = DX + pidz * stride_zdx + blockid * BLOCK * BLOCK + rxm * BLOCK + rxn\n # compute fused softmax backward\n x = tl.load(X, mask=check, other=0)\n dx = tl.load(DX, mask=check, other=0)\n x = x.to(tl.float32)\n dx = dx.to(tl.float32)\n y = x * (dx - tl.sum(x * dx, 0)) * scale\n tl.store(DX, y, mask=check)\n\nclass _sparse_softmax(torch.autograd.Function):\n\n @staticmethod\n def forward(ctx,\n x,\n scale,\n rpe,\n key_padding_mask,\n attn_mask,\n kp_mask_mode,\n attn_mask_mode,\n spdims,\n block,\n lut,\n num_blocks,\n maxlut,\n bench,\n time):\n\n apply_scale = False if scale == 1.0 else True\n\n # handle None rpe\n if rpe is None:\n apply_rpe = False\n stride_zrpe, stride_hrpe, stride_srpe = 0, 0, 0\n rpe = torch.empty(0, dtype=x.dtype, device=x.device)\n else:\n apply_rpe = True\n stride_zrpe, stride_hrpe, stride_srpe = rpe.stride(0), rpe.stride(1), rpe.stride(2)\n\n # handle None key_padding_mask\n if key_padding_mask is None:\n apply_kp_mask = False\n stride_zkpm = 0\n key_padding_mask = torch.empty(0, dtype=x.dtype, device=x.device)\n else:\n apply_kp_mask = True\n stride_zkpm = key_padding_mask.stride(0)\n\n # handle None attention_mask\n if attn_mask is None:\n apply_attn_mask = False\n stride_zattnm = 0\n attn_mask = torch.empty(0, dtype=x.dtype, device=x.device)\n else:\n apply_attn_mask = True\n stride_zattnm = attn_mask.stride(0)\n\n # run kernel\n M = x.shape[0]\n meta = {\n 'BLOCK': block,\n 'APPLY_SCALE': apply_scale,\n 'APPLY_RPE': apply_rpe,\n 'APPLY_KP_MASK': apply_kp_mask,\n 'APPLY_ATTN_MASK': apply_attn_mask,\n 'KP_MASK_MUL': kp_mask_mode == 'mul',\n 'ATTN_MASK_MUL': attn_mask_mode == 'mul',\n }\n grid = lambda opt: [spdims[0] * spdims[1] * block, M]\n _forward[grid](x, scale, lut, rpe, key_padding_mask, attn_mask, maxlut, x.stride(0),\\\n stride_zrpe, stride_hrpe, stride_srpe, stride_zkpm, stride_zattnm, **meta)\n\n # save to context\n ctx.mark_dirty(x)\n ctx.save_for_backward(x, lut)\n ctx.spdims = spdims\n ctx.block = block\n ctx.maxlut = maxlut\n ctx.scale = scale\n ctx.apply_scale = apply_scale\n ctx.apply_rpe = apply_rpe\n ctx.apply_kp_mask = apply_kp_mask\n ctx.apply_attn_mask = apply_attn_mask\n ctx.kp_mask_mode = kp_mask_mode\n ctx.attn_mask_mode = attn_mask_mode\n return x\n\n @staticmethod\n def backward(ctx, dx):\n\n # retrieve from context\n x, lut = ctx.saved_tensors\n # run kernel\n M = x.shape[0]\n grid = lambda opt: [ctx.spdims[0] * ctx.spdims[1] * ctx.block, M]\n _backward[grid](x,\n ctx.scale,\n dx,\n lut,\n ctx.maxlut,\n x.stride(0),\n dx.stride(0),\n BLOCK=ctx.block)\n return dx, None, None, None, None, None, None, None, None, None, None, None, None, None, None\n", - "description_1": "Use triton language to implement a block-sparse softmax operation with optional scaling, relative position embedding, key padding mask, and attention mask. The forward kernel (_forward) takes 13 parameters: X (input tensor), scale (scaling factor), LUT (look-up table), RPE (relative position embedding), KP_M (key padding mask), ATTN_M (attention mask), sizemax (maximum size), stride_zx (stride for X), stride_zrpe (stride for RPE), stride_hrpe (stride for RPE head), stride_srpe (stride for RPE sequence), stride_zkpm (stride for key padding mask), and stride_zattnm (stride for attention mask). The backward kernel (_backward) takes 7 parameters: X (input tensor), scale (scaling factor), DX (gradient tensor), LUT (look-up table), sizemax (maximum size), stride_zx (stride for X), and stride_zdx (stride for DX).", - "description_2": "Use triton language to create a block-sparse softmax function with forward and backward passes, supporting scaling and various masks.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef _fwd_kernel(\n Q, K, V, sm_scale, TMP, Out, stride_qz, stride_qh, stride_qm, stride_qk, \n stride_kz, stride_kh, stride_kn, stride_kk, stride_vz, stride_vh, stride_vk, \n stride_vn, stride_oz, stride_oh, stride_om, stride_on, Z, H, N_CTX, \n BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr\n):\n start_m = tl.program_id(0)\n off_hz = tl.program_id(1)\n offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)\n offs_n = tl.arange(0, BLOCK_N)\n offs_d = tl.arange(0, BLOCK_DMODEL)\n off_q = off_hz * stride_qh + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qk\n off_k = off_hz * stride_kh + offs_n[:, None] * stride_kn + offs_d[None, :] * stride_kk\n off_v = off_hz * stride_vh + offs_n[:, None] * stride_qm + offs_d[None, :] * stride_qk\n q_ptrs = Q + off_q\n k_ptrs = K + off_k\n v_ptrs = V + off_v\n t_ptrs = TMP + off_hz * N_CTX + offs_m\n m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float(\"inf\")\n l_i = tl.zeros([BLOCK_M], dtype=tl.float32)\n acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)\n q = tl.load(q_ptrs)\n for start_n in range(0, N_CTX, BLOCK_N):\n start_n = tl.multiple_of(start_n, BLOCK_N)\n k = tl.load(k_ptrs + start_n * stride_kn)\n qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)\n qk += tl.dot(q, k, trans_b=True)\n qk *= sm_scale\n m_ij = tl.max(qk, 1)\n p = tl.exp(qk - m_ij[:, None])\n l_ij = tl.sum(p, 1)\n m_i_new = tl.maximum(m_i, m_ij)\n alpha = tl.exp(m_i - m_i_new)\n beta = tl.exp(m_ij - m_i_new)\n l_i_new = alpha * l_i + beta * l_ij\n p_scale = beta / l_i_new\n p = p * p_scale[:, None]\n acc_scale = l_i / l_i_new * alpha\n tl.store(t_ptrs, acc_scale)\n acc_scale = tl.load(t_ptrs)\n acc = acc * acc_scale[:, None]\n v = tl.load(v_ptrs + start_n * stride_vk)\n p = p.to(tl.float16)\n acc += tl.dot(p, v)\n l_i = l_i_new\n m_i = m_i_new\n offs_n = tl.arange(0, BLOCK_DMODEL)\n off_o = off_hz * stride_oh + offs_m[:, None] * stride_om + offs_n[None, :] * stride_on\n out_ptrs = Out + off_o\n tl.store(out_ptrs, acc)\n\n\nclass triton_flash_attn(torch.nn.Module):\n def __init__(self, ):\n super(triton_flash_attn, self).__init__()\n\n def forward(self, q, k, v, sm_scale, block_128=True):\n BLOCK = 128 if block_128 else 64\n Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]\n o = torch.empty_like(q)\n grid = (triton.cdiv(q.shape[2], BLOCK), q.shape[0] * q.shape[1])\n tmp = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)\n num_warps = 4 if Lk <= 64 else 8\n\n _fwd_kernel[grid](\n q, k, v, sm_scale, tmp, o, q.stride(0), q.stride(1), q.stride(2), q.stride(3),\n k.stride(0), k.stride(1), k.stride(2), k.stride(3),\n v.stride(0), v.stride(1), v.stride(2), v.stride(3),\n o.stride(0), o.stride(1), o.stride(2), o.stride(3),\n k.shape[0], k.shape[1], k.shape[2],\n BLOCK_M=BLOCK, BLOCK_N=BLOCK, BLOCK_DMODEL=Lk, num_warps=num_warps, num_stages=1,\n )\n return o\n", - "description_1": "Use triton language to implement a fused attention kernel (_fwd_kernel) with 27 parameters for matrix multiplication and softmax scaling. The wrapper class triton_flash_attn has a forward method accepting 5 parameters to set up the computation grid and call the kernel.", - "description_2": "Use triton language to create a fused attention mechanism with a specialized kernel and invoke it within a PyTorch module.", - "difficulty": 4 - }, - { - "code": "import math\nimport torch\nimport triton_pre_mlir as triton\nimport triton_pre_mlir.language as tl\n\n# Forward pass kernel\n@triton.heuristics({'EVEN_M': lambda args: args['seqlen_q'] % args['BLOCK_M'] == 0, 'EVEN_N': lambda args: args['seqlen_k'] % args['BLOCK_N'] == 0, 'EVEN_HEADDIM': lambda args: args['headdim'] == args['BLOCK_HEADDIM']})\n@triton.jit\ndef _fwd_kernel(Q, K, V, Bias, Out, Lse, TMP, softmax_scale, stride_qb, stride_qh, stride_qm, stride_kb, stride_kh, stride_kn, stride_vb, stride_vh, stride_vn, stride_bb, stride_bh, stride_bm, stride_ob, stride_oh, stride_om, nheads, seqlen_q, seqlen_k, seqlen_q_rounded, headdim, CACHE_KEY_SEQLEN_Q, CACHE_KEY_SEQLEN_K, BIAS_TYPE: tl.constexpr, IS_CAUSAL: tl.constexpr, BLOCK_HEADDIM: tl.constexpr, EVEN_M: tl.constexpr, EVEN_N: tl.constexpr, EVEN_HEADDIM: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):\n # Kernel implementation\n pass\n\n# Backward pass kernel\n@triton.autotune(configs=[triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'SEQUENCE_PARALLEL': False}, num_warps=8, num_stages=1, pre_hook=lambda nargs: nargs['DQ'].zero_()), triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'SEQUENCE_PARALLEL': True}, num_warps=8, num_stages=1, pre_hook=lambda nargs: nargs['DQ'].zero_())], key=['CACHE_KEY_SEQLEN_Q', 'CACHE_KEY_SEQLEN_K', 'BIAS_TYPE', 'IS_CAUSAL', 'BLOCK_HEADDIM'])\n@triton.heuristics({'EVEN_M': lambda args: args['seqlen_q'] % args['BLOCK_M'] == 0, 'EVEN_N': lambda args: args['seqlen_k'] % args['BLOCK_N'] == 0, 'EVEN_HEADDIM': lambda args: args['headdim'] == args['BLOCK_HEADDIM']})\n@triton.jit\ndef _bwd_kernel(Q, K, V, Bias, DO, DQ, DK, DV, LSE, D, softmax_scale, stride_qb, stride_qh, stride_qm, stride_kb, stride_kh, stride_kn, stride_vb, stride_vh, stride_vn, stride_bb, stride_bh, stride_bm, stride_dob, stride_doh, stride_dom, stride_dqb, stride_dqh, stride_dqm, stride_dkb, stride_dkh, stride_dkn, stride_dvb, stride_dvh, stride_dvn, nheads, seqlen_q, seqlen_k, seqlen_q_rounded, headdim, CACHE_KEY_SEQLEN_Q, CACHE_KEY_SEQLEN_K, BIAS_TYPE: tl.constexpr, IS_CAUSAL: tl.constexpr, BLOCK_HEADDIM: tl.constexpr, SEQUENCE_PARALLEL: tl.constexpr, EVEN_M: tl.constexpr, EVEN_N: tl.constexpr, EVEN_HEADDIM: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):\n # Kernel implementation\n pass\n\n# Flash attention forward function\ndef _flash_attn_forward(q, k, v, bias=None, causal=False, softmax_scale=None):\n # Function implementation\n pass\n\n# Flash attention backward function\ndef _flash_attn_backward(do, q, k, v, o, lse, dq, dk, dv, bias=None, causal=False, softmax_scale=None):\n # Function implementation\n pass\n", - "description_1": "Use triton language to implement forward and backward kernels for FlashAttention, supporting operations on tensors with shapes related to batch, sequence length, heads, and head dimensions, including features for causal masking and bias adjustment. The forward function handles the computation of output and log-sum-exp values, while the backward function computes gradients for the input tensors.", - "description_2": "Use triton language to implement FlashAttention's forward and backward passes, managing tensor shapes for batching and multiple heads, and including causal and bias considerations.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef _softmax_fwd_kernel(\n output_ptr, \n stride_output_row, \n input_ptr,\n stride_input_row,\n num_cols,\n block_size: tl.constexpr,\n):\n row_index = tl.program_id(0)\n row_start_prt = input_ptr + (row_index * stride_input_row)\n col_offsets = tl.arange(0, block_size)\n input_pointers = row_start_prt + col_offsets\n row_mask = col_offsets < num_cols\n row = tl.load(input_pointers, mask=row_mask, other=float(\"-inf\"))\n safe_row = row - tl.max(row, axis=0)\n numerator = tl.exp(safe_row)\n denominator = tl.sum(numerator, axis=0)\n softmax_out = numerator / denominator\n output_ptr_row = output_ptr + (row_index * stride_output_row)\n output_pointers = output_ptr_row + col_offsets\n tl.store(output_pointers, softmax_out, mask=row_mask)\n\ndef softmax(x: torch.Tensor) -> torch.Tensor:\n rows, cols = x.shape\n assert x.dim() == 2, f\"only accepts 2D tensors for now\"\n block_size = triton.next_power_of_2(cols)\n num_warps = 4\n if block_size > 2047:\n num_warps = 8\n if block_size > 4095:\n num_warps = 16\n\n grid = (rows,)\n softmax_out = torch.empty_like(x)\n\n _softmax_fwd_kernel[grid](\n softmax_out,\n softmax_out.stride(0),\n x,\n x.stride(0),\n cols,\n block_size=block_size,\n num_warps=num_warps,\n )\n\n return softmax_out\n\nsample = torch.tensor([[1, 2, 3, 4, 5], [5, 4, 3, 2, 1]], dtype=torch.float32, device='cuda')\ntriton_out = softmax(sample)\nprint(f\"{triton_out=}\")\n", - "description_1": "Use triton language to implement a forward pass of the softmax operation. The _softmax_fwd_kernel function takes 6 arguments: 1) output_ptr: a pointer to the output memory location, 2) stride_output_row: the stride between rows in the output, 3) input_ptr: a pointer to the input memory location, 4) stride_input_row: the stride between rows in the input, 5) num_cols: the number of columns in the input, and 6) block_size: a compile-time constant indicating the block size for processing. The softmax function prepares input dimensions and block size, allocates output memory, and calls _softmax_fwd_kernel to compute softmax over each row of the input tensor.", - "description_2": "Use triton language to compute the softmax operation over a 2D input tensor by implementing a kernel function with memory pointers and strides, and calling it from a softmax function.", - "difficulty": 2 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\nfrom torch import Tensor\n\n@triton.jit\ndef softmax_kernel(input_ptr, output_ptr, \n input_row_stride, output_row_stride, \n n_cols, BLOCK_SIZE: tl.constexpr):\n # Get the batch index\n batch_idx = tl.program_id(0)\n\n # Compute the start pointer for the batch\n batch_start_ptr = input_ptr + batch_idx * input_row_stride\n col_offsets = tl.arange(0, BLOCK_SIZE)\n input_ptrs = batch_start_ptr + col_offsets\n \n # Load the row with masking\n row = tl.load(input_ptrs, mask=col_offsets < n_cols, other=-float('inf')) \n row_minus_max = row - tl.max(row, axis=0)\n numerator = tl.exp(row_minus_max)\n denominator = tl.sum(numerator, axis=0)\n softmax_output = numerator / denominator\n\n # Write back the result to DRAM\n output_row_start_ptr = output_ptr + batch_idx * output_row_stride\n output_ptrs = output_row_start_ptr + col_offsets\n\n tl.store(output_ptrs, softmax_output, mask=col_offsets < n_cols)\n\ndef softmax(input: Tensor) -> Tensor:\n # Reshape input for processing\n reshaped_input = input.unsqueeze(0) if input.ndim == 1 else input\n reshaped_input = reshaped_input.flatten(0, -2)\n \n batch_dim, feat_dim = reshaped_input.shape\n BLOCK_SIZE = triton.next_power_of_2(feat_dim)\n num_warps = 8\n\n output = torch.empty_like(reshaped_input)\n\n # Launch the Triton kernel\n softmax_kernel[(batch_dim, )](reshaped_input, output, reshaped_input.stride(0), output.stride(0),\n feat_dim, num_warps=num_warps, BLOCK_SIZE=BLOCK_SIZE)\n\n return output.view_as(input)\n", - "description_1": "Use triton language to implement a softmax operation on a 2D tensor. The kernel 'softmax_kernel' takes 6 parameters: input_ptr (pointer to input data), output_ptr (pointer to output data), input_row_stride (stride of input rows), output_row_stride (stride of output rows), n_cols (number of columns in the input), and BLOCK_SIZE (block size for processing). The function 'softmax' prepares the input tensor, calculates the block size, and launches the kernel with appropriate parameters.", - "description_2": "Use triton language to implement a softmax operation on a 2D tensor using a kernel that processes data in blocks and writes the result back to memory.", - "difficulty": 2 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef softmax_kernel(output_ptr, input_ptr, input_row_stride, output_row_stride, n_rows, n_cols,\n BLOCK_SIZE: tl.constexpr, num_stages: tl.constexpr):\n # starting row of the program\n row_start = tl.program_id(0)\n row_step = tl.num_programs(0)\n for row_idx in tl.range(row_start, n_rows, row_step, num_stages=num_stages):\n row_start_ptr = input_ptr + row_idx * input_row_stride\n col_offsets = tl.arange(0, BLOCK_SIZE)\n input_ptrs = row_start_ptr + col_offsets\n mask = col_offsets < n_cols\n row = tl.load(input_ptrs, mask=mask, other=-float('inf'))\n row_minus_max = row - tl.max(row, axis=0)\n numerator = tl.exp(row_minus_max)\n denominator = tl.sum(numerator, axis=0)\n softmax_output = numerator / denominator\n output_row_start_ptr = output_ptr + row_idx * output_row_stride\n output_ptrs = output_row_start_ptr + col_offsets\n tl.store(output_ptrs, softmax_output, mask=mask)\n\ndef softmax(input):\n reshaped_input = input.unsqueeze(0) if input.ndim == 1 else input\n reshaped_input = input.flatten(0, -2)\n n_rows, n_cols = reshaped_input.shape\n\n BLOCK_SIZE = triton.next_power_of_2(n_cols)\n num_warps = 4\n num_stages = 4\n output = torch.empty_like(reshaped_input)\n\n kernel, num_programs = kernels.get(BLOCK_SIZE, (None, 0))\n if kernel is None:\n kernel = softmax_kernel.warmup(output, reshaped_input,\n reshaped_input.stride(0), output.stride(0),\n n_rows, n_cols,\n BLOCK_SIZE=BLOCK_SIZE,\n num_stages=num_stages,\n num_warps=num_warps,\n grid=(1,))\n kernel._init_handles()\n kernels[BLOCK_SIZE] = (kernel, 1)\n\n num_programs = min(1, n_rows)\n\n kernel[(num_programs, 1, 1)](\n output,\n reshaped_input,\n reshaped_input.stride(0),\n output.stride(0),\n n_rows,\n n_cols,\n )\n return output.view_as(input)\n", - "description_1": "Use triton language to implement a softmax kernel function and a corresponding softmax function. The kernel function 'softmax_kernel' has seven parameters: output_ptr, input_ptr, input_row_stride, output_row_stride, n_rows, n_cols, and BLOCK_SIZE, where output_ptr and input_ptr are pointers to the output and input data respectively, input_row_stride and output_row_stride are strides for accessing rows, n_rows and n_cols are dimensions of the input data, and BLOCK_SIZE is a compile-time constant. The softmax function manages the reshaping of input data, computes necessary constants like BLOCK_SIZE and num_warps, allocates output space, and launches the kernel function with appropriate parameters.", - "description_2": "Use triton language to create a softmax operation with parallel processing capability, define the kernel computation pattern, manage inputs, and execute the operation across GPU threads.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef add_kernel(x_ptr, # ptr to 1st input vector\n y_ptr, # ptr to 2nd input vector\n output_ptr, # ptr to output vector\n n_elements, # size of the vector\n BLOCK_SIZE: tl.constexpr, # # of elements each program should process\n ):\n pid = tl.program_id(axis=0) # 1D Launch grid so axis is 0\n block_start = pid * BLOCK_SIZE\n offsets = block_start + tl.arange(0, BLOCK_SIZE)\n mask = offsets < n_elements\n x = tl.load(x_ptr + offsets, mask=mask)\n y = tl.load(y_ptr + offsets, mask=mask)\n output = x + y\n tl.store(output_ptr + offsets, output, mask=mask)\n\ndef add(x: torch.Tensor, y: torch.Tensor):\n output = torch.empty_like(x)\n assert x.is_cuda and y.is_cuda and output.is_cuda\n n_elements = output.numel()\n grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']), )\n add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024)\n return output\n\ntorch.manual_seed(0)\nsize = 98432\nx = torch.rand(size, device='cuda')\ny = torch.rand(size, device='cuda')\noutput_torch = x + y\noutput_triton = add(x, y)\nprint(output_torch)\nprint(output_triton)\nprint(f'The maximum difference between torch and triton is '\n f'{torch.max(torch.abs(output_torch - output_triton))}')\n", - "description_1": "Use triton language to implement a vector addition kernel. The kernel 'add_kernel' takes 5 parameters: x_ptr (pointer to the first input vector), y_ptr (pointer to the second input vector), output_ptr (pointer to the output vector), n_elements (size of the vector), and BLOCK_SIZE (number of elements each program should process). The kernel computes the element-wise sum of two input vectors and stores the result in the output vector. The 'add' function is a wrapper that prepares the output tensor, sets up the grid for kernel execution, and calls the kernel with the appropriate parameters.", - "description_2": "Use triton language to create a kernel for element-wise vector addition and a wrapper function to execute it on CUDA tensors.", - "difficulty": 2 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef add_kernel(x_ptr, # *Pointer* to first input vector.\n y_ptr, # *Pointer* to second input vector.\n output_ptr, # *Pointer* to output vector.\n n_elements, # Size of the vector.\n BLOCK_SIZE: tl.constexpr, # Number of elements each program should process.\n # NOTE: `constexpr` so it can be used as a shape value.\n ):\n # There are multiple 'programs' processing different data. We identify which program\n # we are here:\n pid = tl.program_id(axis=0) # We use a 1D launch grid so axis is 0.\n # This program will process inputs that are offset from the initial data.\n # For instance, if you had a vector of length 256 and block_size of 64, the programs\n # would each access the elements [0:64, 64:128, 128:192, 192:256].\n # Note that offsets is a list of pointers:\n block_start = pid * BLOCK_SIZE\n offsets = block_start + tl.arange(0, BLOCK_SIZE)\n # Create a mask to guard memory operations against out-of-bounds accesses.\n mask = offsets < n_elements\n # Load x and y from DRAM, masking out any extra elements in case the input is not a\n # multiple of the block size.\n x = tl.load(x_ptr + offsets, mask=mask)\n y = tl.load(y_ptr + offsets, mask=mask)\n output = x + y\n # Write x + y back to DRAM.\n tl.store(output_ptr + offsets, output, mask=mask)\n\ndef add(x: torch.Tensor, y: torch.Tensor):\n # We need to preallocate the output.\n output = torch.empty_like(x)\n assert x.is_cuda and y.is_cuda and output.is_cuda\n n_elements = output.numel()\n # The SPMD launch grid denotes the number of kernel instances that run in parallel.\n # It is analogous to CUDA launch grids. It can be either Tuple[int], or Callable(metaparameters) -> Tuple[int].\n # In this case, we use a 1D grid where the size is the number of blocks:\n grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']), )\n # NOTE:\n # - Each torch.tensor object is implicitly converted into a pointer to its first element.\n # - `triton.jit`'ed functions can be indexed with a launch grid to obtain a callable GPU kernel.\n # - Don't forget to pass meta-parameters as keywords arguments.\n add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024)\n # We return a handle to z but, since `torch.cuda.synchronize()` hasn't been called, the kernel is still\n # running asynchronously at this point.\n return output\n", - "description_1": "Use triton language to implement a vector addition kernel. The kernel function 'add_kernel' takes five parameters: pointers to the input vectors x and y, a pointer to the output vector, the number of elements in the vectors, and a block size as a compile-time constant. The kernel computes the element-wise sum of x and y, storing the result in the output vector. The 'add' function is a wrapper that prepares the input tensors, sets up the execution grid, and launches the kernel.", - "description_2": "Use triton language to create a kernel for element-wise vector addition and a wrapper function to execute it on CUDA tensors.", - "difficulty": 2 - }, - { - "code": "import triton\nimport triton.language as tl\nimport torch\n\n@triton.jit\ndef kernel_vector_addition(a_ptr, b_ptr, out_ptr, num_elems: tl.constexpr, block_size: tl.constexpr):\n pid = tl.program_id(axis=0)\n block_start = pid * block_size\n thread_offsets = block_start + tl.arange(0, block_size)\n mask = thread_offsets < num_elems\n a_pointers = tl.load(a_ptr + thread_offsets, mask=mask)\n b_pointers = tl.load(b_ptr + thread_offsets, mask=mask)\n result = a_pointers + b_pointers\n tl.store(out_ptr + thread_offsets, result, mask=mask)\n\ndef ceil_div(x: int, y: int) -> int:\n return ((x + y - 1) // y)\n\ndef vector_addition(a: torch.tensor, b: torch.tensor) -> torch.tensor:\n output_buffer = torch.empty_like(a)\n assert a.is_cuda and b.is_cuda\n num_elems = a.numel()\n assert num_elems == b.numel()\n block_size = 1024\n grid_size = ceil_div(num_elems, block_size)\n grid = (grid_size,)\n num_warps = 8\n kernel_vector_addition[grid](a, b, output_buffer, num_elems, block_size, num_warps=num_warps)\n return output_buffer\n", - "description_1": "Use triton language to implement a vector addition kernel. The kernel 'kernel_vector_addition' takes 5 parameters: a_ptr, b_ptr, out_ptr (pointers to input and output data), num_elems (total number of elements to process), and block_size (size of each block of threads). It computes the sum of two vectors element-wise and stores the result. The function 'vector_addition' is a wrapper that prepares the data and calls the kernel. It takes two torch tensors 'a' and 'b', ensures they are on CUDA, and have the same number of elements. It then sets up the grid and block size, and calls the kernel to perform the addition.", - "description_2": "Use triton language to create a kernel for element-wise vector addition, and a wrapper function to handle data preparation and kernel invocation.", - "difficulty": 2 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n\n@triton.autotune(\n configs=[\n triton.Config({'BLOCK_N': 32}),\n triton.Config({'BLOCK_N': 64}),\n triton.Config({'BLOCK_N': 128}),\n triton.Config({'BLOCK_N': 256}),\n triton.Config({'BLOCK_N': 512}),\n triton.Config({'BLOCK_N': 1024}),\n ],\n key=['ncols'],\n)\n@triton.jit\ndef _swiglu_fwd_kernel(\n X, Y, OUT, stride_x_row, stride_y_row, stride_out_row, ncols, BLOCK_N: tl.constexpr\n):\n row = tl.program_id(0)\n start_col = tl.program_id(1) * BLOCK_N\n X += row * stride_x_row\n Y += row * stride_y_row\n OUT += row * stride_out_row\n cols = start_col + tl.arange(0, BLOCK_N)\n x = tl.load(X + cols, mask=cols < ncols, other=0.).to(tl.float32)\n y = tl.load(Y + cols, mask=cols < ncols, other=0.).to(tl.float32)\n out = x * tl.sigmoid(x) * y\n tl.store(OUT + cols, out, mask=cols < ncols)\n\n\ndef _swiglu_fwd(xy, out=None):\n if xy.stride(-1) != 1:\n xy = xy.contiguous()\n batch_shape = xy.shape[:-1]\n xy = xy.reshape(-1, xy.shape[-1])\n x, y = xy.chunk(2, dim=-1)\n if out is None:\n out = torch.empty_like(x)\n else:\n out = out.reshape(-1, out.shape[-1])\n assert out.shape == x.shape\n assert out.stride(-1) == 1\n M, N = x.shape\n grid = lambda META: (M, triton.cdiv(N, META['BLOCK_N']))\n with torch.cuda.device(x.device.index):\n _swiglu_fwd_kernel[grid](x, y, out, x.stride(0), y.stride(0), out.stride(0), N)\n return out.reshape(*batch_shape, out.shape[-1])\n\n\n@triton.autotune(\n configs=[\n triton.Config({'BLOCK_N': 32}),\n triton.Config({'BLOCK_N': 64}),\n triton.Config({'BLOCK_N': 128}),\n triton.Config({'BLOCK_N': 256}),\n triton.Config({'BLOCK_N': 512}),\n triton.Config({'BLOCK_N': 1024}),\n ],\n key=['ncols'],\n)\n@triton.heuristics({\"RECOMPUTE_OUTPUT\": lambda args: args[\"OUT\"] is not None})\n@triton.jit\ndef _swiglu_bwd_kernel(\n X, Y, DOUT, OUT, DX, DY, stride_x_row, stride_y_row, stride_dout_row, \n stride_out_row, stride_dx_row, stride_dy_row, ncols, BLOCK_N: tl.constexpr, \n RECOMPUTE_OUTPUT: tl.constexpr,\n):\n row = tl.program_id(0)\n start_col = tl.program_id(1) * BLOCK_N\n X += row * stride_x_row\n Y += row * stride_y_row\n DOUT += row * stride_dout_row\n if RECOMPUTE_OUTPUT:\n OUT += row * stride_out_row\n DX += row * stride_dx_row\n DY += row * stride_dy_row\n cols = start_col + tl.arange(0, BLOCK_N)\n x = tl.load(X + cols, mask=cols < ncols, other=0.).to(tl.float32)\n y = tl.load(Y + cols, mask=cols < ncols, other=0.).to(tl.float32)\n dout = tl.load(DOUT + cols, mask=cols < ncols, other=0.).to(tl.float32)\n x_sigmoid = tl.sigmoid(x)\n dx = x_sigmoid * (1 + x * (1 - x_sigmoid)) * y * dout\n dy = x * x_sigmoid * dout\n tl.store(DX + cols, dx, mask=cols < ncols)\n tl.store(DY + cols, dy, mask=cols < ncols)\n if RECOMPUTE_OUTPUT:\n out = x * x_sigmoid * y\n tl.store(OUT + cols, out, mask=cols < ncols)\n\n\ndef _swiglu_bwd(xy, dout, dxy=None, recompute_output=False, out=None):\n if xy.stride(-1) != 1:\n xy = xy.contiguous()\n if dout.stride(-1) != 1:\n dout = dout.contiguous()\n batch_shape = xy.shape[:-1]\n xy = xy.reshape(-1, xy.shape[-1])\n x, y = xy.chunk(2, dim=-1)\n dout = dout.reshape(-1, dout.shape[-1])\n assert dout.shape == x.shape\n if dxy is None:\n dxy = torch.empty_like(xy)\n else:\n dxy = dxy.reshape(-1, dxy.shape[-1])\n assert dxy.shape == xy.shape\n dx, dy = dxy.chunk(2, dim=-1)\n assert dx.stride(-1) == 1\n assert dy.stride(-1) == 1\n if recompute_output:\n if out is None:\n out = torch.empty_like(x)\n else:\n out = out.reshape(-1, out.shape[-1])\n assert out.shape == x.shape\n assert out.stride(-1) == 1\n M, N = x.shape\n grid = lambda META: (M, triton.cdiv(N, META['BLOCK_N']))\n with torch.cuda.device(x.device.index):\n _swiglu_bwd_kernel[grid](x, y, dout, out if recompute_output else None, dx, dy,\n x.stride(0), y.stride(0), dout.stride(0),\n out.stride(0) if recompute_output else 0,\n dx.stride(0), dy.stride(0),\n N)\n if not recompute_output:\n return dxy.reshape(*batch_shape, dxy.shape[-1])\n else:\n return dxy.reshape(*batch_shape, dxy.shape[-1]), out.reshape(*batch_shape, out.shape[-1])\n", - "description_1": "Use triton language to implement the SWIGLU forward and backward functions. The forward function _swiglu_fwd_kernel takes 7 arguments: X, Y, OUT, stride_x_row, stride_y_row, stride_out_row, and ncols, where X and Y are input matrices, OUT is the output matrix, and strides are for accessing rows. It computes the element-wise product of X, sigmoid(X), and Y, storing the result in OUT. The backward function _swiglu_bwd_kernel takes 14 arguments: X, Y, DOUT, OUT, DX, DY, stride_x_row, stride_y_row, stride_dout_row, stride_out_row, stride_dx_row, stride_dy_row, ncols, and BLOCK_N. It calculates gradients DX and DY, using the derivative of the SWIGLU activation function. Optional recomputation of OUT based on the argument RECOMPUTE_OUTPUT is also included.", - "description_2": "Use triton language to perform SWIGLU activation in the forward pass by computing element-wise product of input matrices with sigmoid and implement the backward pass to calculate the gradients for input matrices with optional recomputation of the output matrix.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n\nconfigs_autotune = [\n triton.Config({}, num_warps=1),\n triton.Config({}, num_warps=2),\n triton.Config({}, num_warps=4),\n triton.Config({}, num_warps=8),\n triton.Config({}, num_warps=16),\n triton.Config({}, num_warps=32),\n]\n\ndef config_prune(configs):\n warp_size = 32 # default warp size\n max_block_sz = 1024\n max_num_warps = max_block_sz // warp_size\n pruned_configs = [config for config in configs if config.num_warps <= max_num_warps]\n return pruned_configs\n\npruned_configs_autotune = config_prune(configs_autotune)\n\n@triton.autotune(\n configs = pruned_configs_autotune,\n key=[\"N\", \"HAS_RESIDUAL\", \"STORE_RESIDUAL_OUT\", \"IS_RMS_NORM\", \"HAS_BIAS\"],\n)\n@triton.heuristics({\"HAS_X1\": lambda args: args[\"X1\"] is not None})\n@triton.heuristics({\"HAS_W1\": lambda args: args[\"W1\"] is not None})\n@triton.heuristics({\"HAS_B1\": lambda args: args[\"B1\"] is not None})\n@triton.jit\ndef _layer_norm_fwd_1pass_kernel(\n X, # pointer to the input\n Y, # pointer to the output\n W, # pointer to the weights\n B, # pointer to the biases\n RESIDUAL, # pointer to the residual\n X1,\n W1,\n B1,\n Y1,\n RESIDUAL_OUT, # pointer to the residual\n ROWSCALE,\n SEEDS, # Dropout seeds for each row\n DROPOUT_MASK,\n Mean, # pointer to the mean\n Rstd, # pointer to the 1/std\n stride_x_row, # how much to increase the pointer when moving by 1 row\n stride_y_row,\n stride_res_row,\n stride_res_out_row,\n stride_x1_row,\n stride_y1_row,\n M, # number of rows in X\n N, # number of columns in X\n eps, # epsilon to avoid division by zero\n dropout_p, # Dropout probability\n IS_RMS_NORM: tl.constexpr,\n BLOCK_N: tl.constexpr,\n HAS_RESIDUAL: tl.constexpr,\n STORE_RESIDUAL_OUT: tl.constexpr,\n HAS_BIAS: tl.constexpr,\n HAS_DROPOUT: tl.constexpr,\n STORE_DROPOUT_MASK: tl.constexpr,\n HAS_ROWSCALE: tl.constexpr,\n HAS_X1: tl.constexpr,\n HAS_W1: tl.constexpr,\n HAS_B1: tl.constexpr,\n):\n row = tl.program_id(0)\n X += row * stride_x_row\n Y += row * stride_y_row\n if HAS_RESIDUAL:\n RESIDUAL += row * stride_res_row\n if STORE_RESIDUAL_OUT:\n RESIDUAL_OUT += row * stride_res_out_row\n if HAS_X1:\n X1 += row * stride_x1_row\n if HAS_W1:\n Y1 += row * stride_y1_row\n\n cols = tl.arange(0, BLOCK_N)\n x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)\n if HAS_ROWSCALE:\n rowscale = tl.load(ROWSCALE + row).to(tl.float32)\n x *= rowscale\n if HAS_DROPOUT:\n keep_mask = tl.rand(tl.load(SEEDS + row).to(tl.uint32), cols, n_rounds=7) > dropout_p\n x = tl.where(keep_mask, x / (1.0 - dropout_p), 0.0)\n if STORE_DROPOUT_MASK:\n tl.store(DROPOUT_MASK + row * N + cols, keep_mask, mask=cols < N)\n if HAS_X1:\n x1 = tl.load(X1 + cols, mask=cols < N, other=0.0).to(tl.float32)\n if HAS_ROWSCALE:\n rowscale = tl.load(ROWSCALE + M + row).to(tl.float32)\n x1 *= rowscale\n if HAS_DROPOUT:\n keep_mask = (\n tl.rand(tl.load(SEEDS + M + row).to(tl.uint32), cols, n_rounds=7) > dropout_p\n )\n x1 = tl.where(keep_mask, x1 / (1.0 - dropout_p), 0.0)\n if STORE_DROPOUT_MASK:\n tl.store(DROPOUT_MASK + (M + row) * N + cols, keep_mask, mask=cols < N)\n x += x1\n if HAS_RESIDUAL:\n residual = tl.load(RESIDUAL + cols, mask=cols < N, other=0.0).to(tl.float32)\n x += residual\n if STORE_RESIDUAL_OUT:\n tl.store(RESIDUAL_OUT + cols, x, mask=cols < N)\n if not IS_RMS_NORM:\n mean = tl.sum(x, axis=0) / N\n tl.store(Mean + row, mean)\n xbar = tl.where(cols < N, x - mean, 0.0)\n var = tl.sum(xbar * xbar, axis=0) / N\n else:\n xbar = tl.where(cols < N, x, 0.0)\n var = tl.sum(xbar * xbar, axis=0) / N\n rstd = 1 / tl.sqrt(var + eps)\n tl.store(Rstd + row, rstd)\n\n mask = cols < N\n w = tl.load(W + cols, mask=mask).to(tl.float32)\n if HAS_BIAS:\n b = tl.load(B + cols, mask=mask).to(tl.float32)\n x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd\n y = x_hat * w + b if HAS_BIAS else x_hat * w\n\n tl.store(Y + cols, y, mask=mask)\n if HAS_W1:\n w1 = tl.load(W1 + cols, mask=mask).to(tl.float32)\n if HAS_B1:\n b1 = tl.load(B1 + cols, mask=mask).to(tl.float32)\n y1 = x_hat * w1 + b1 if HAS_B1 else x_hat * w1\n tl.store(Y1 + cols, y1, mask=mask)\n\n\ndef _layer_norm_fwd(\n x,\n weight,\n bias,\n eps,\n residual=None,\n x1=None,\n weight1=None,\n bias1=None,\n dropout_p=0.0,\n rowscale=None,\n out_dtype=None,\n residual_dtype=None,\n is_rms_norm=False,\n return_dropout_mask=False,\n):\n if residual is not None:\n residual_dtype = residual.dtype\n M, N = x.shape\n y = torch.empty_like(x, dtype=x.dtype if out_dtype is None else out_dtype)\n if weight1 is not None:\n y1 = torch.empty_like(y)\n else:\n y1 = None\n if (\n residual is not None\n or (residual_dtype is not None and residual_dtype != x.dtype)\n or dropout_p > 0.0\n or rowscale is not None\n or x1 is not None\n ):\n residual_out = torch.empty(\n M, N, device=x.device, dtype=residual_dtype if residual_dtype is not None else x.dtype\n )\n else:\n residual_out = None\n mean = torch.empty((M,), dtype=torch.float32, device=x.device) if not is_rms_norm else None\n rstd = torch.empty((M,), dtype=torch.float32, device=x.device)\n if dropout_p > 0.0:\n seeds = torch.randint(\n 2**32, (M if x1 is None else 2 * M,), device=x.device, dtype=torch.int64\n )\n else:\n seeds = None\n if return_dropout_mask and dropout_p > 0.0:\n dropout_mask = torch.empty(M if x1 is None else 2 * M, N, device=x.device, dtype=torch.bool)\n else:\n dropout_mask = None\n\n MAX_FUSED_SIZE = 65536 // x.element_size()\n BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))\n if N > BLOCK_N:\n raise RuntimeError(\"This layer norm doesn't support feature dim >= 64KB.\")\n with torch.cuda.device(x.device.index):\n _layer_norm_fwd_1pass_kernel[(M,)](\n x,\n y,\n weight,\n bias,\n residual,\n x1,\n weight1,\n bias1,\n y1,\n residual_out,\n rowscale,\n seeds,\n dropout_mask,\n mean,\n rstd,\n x.stride(0),\n y.stride(0),\n residual.stride(0) if residual is not None else 0,\n residual_out.stride(0) if residual_out is not None else 0,\n x1.stride(0) if x1 is not None else 0,\n y1.stride(0) if y1 is not None else 0,\n M,\n N,\n eps,\n dropout_p,\n is_rms_norm,\n BLOCK_N,\n residual is not None,\n residual_out is not None,\n bias is not None,\n dropout_p > 0.0,\n dropout_mask is not None,\n rowscale is not None,\n )\n if dropout_mask is not None and x1 is not None:\n dropout_mask, dropout_mask1 = dropout_mask.tensor_split(2, dim=0)\n else:\n dropout_mask1 = None\n return (\n y,\n y1,\n mean,\n rstd,\n residual_out if residual_out is not None else x,\n seeds,\n dropout_mask,\n dropout_mask1,\n )\n", - "description_1": "Use triton language to implement a forward pass for layer normalization with optional dropout, bias, residual connections, and row scaling. This kernel processes input matrices in a row-wise fashion, normalizes each row by calculating mean and variance (or RMS), and optionally applies dropout, biases, and residuals before producing the output.", - "description_2": "Use triton language to perform layer normalization forward pass with support for dropout, bias, and residual connections.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.heuristics({\"HAS_BIAS\": lambda args: args[\"B\"] is not None})\n@triton.heuristics({\"HAS_Z\": lambda args: args[\"Z\"] is not None})\n@triton.jit\ndef _layer_norm_fwd_1pass_kernel(\n X, # pointer to the input\n Y, # pointer to the output\n W, # pointer to the weights\n B, # pointer to the biases\n Z, # pointer to the other branch\n Mean, # pointer to the mean\n Rstd, # pointer to the 1/std\n stride_x_row, # how much to increase the pointer when moving by 1 row\n stride_y_row,\n stride_z_row,\n M, # number of rows in X\n N, # number of columns in X\n eps, # epsilon to avoid division by zero\n BLOCK_N: tl.constexpr,\n HAS_BIAS: tl.constexpr,\n HAS_Z: tl.constexpr,\n NORM_BEFORE_GATE: tl.constexpr,\n IS_RMS_NORM: tl.constexpr,\n):\n # Map the program id to the row of X and Y it should compute.\n row = tl.program_id(0)\n group = tl.program_id(1)\n X += row * stride_x_row + group * N\n Y += row * stride_y_row + group * N\n if HAS_Z:\n Z += row * stride_z_row + group * N\n if not IS_RMS_NORM:\n Mean += group * M\n Rstd += group * M\n W += group * N\n if HAS_BIAS:\n B += group * N\n # Compute mean and variance\n cols = tl.arange(0, BLOCK_N)\n x = tl.load(X + cols, mask=cols < N, other=0.).to(tl.float32)\n if HAS_Z and not NORM_BEFORE_GATE:\n z = tl.load(Z + cols, mask=cols < N).to(tl.float32)\n x *= z * tl.sigmoid(z)\n if not IS_RMS_NORM:\n mean = tl.sum(x, axis=0) / N\n tl.store(Mean + row, mean)\n xbar = tl.where(cols < N, x - mean, 0.)\n var = tl.sum(xbar * xbar, axis=0) / N\n else:\n xbar = tl.where(cols < N, x, 0.)\n var = tl.sum(xbar * xbar, axis=0) / N\n rstd = 1 / tl.sqrt(var + eps)\n tl.store(Rstd + row, rstd)\n # Normalize and apply linear transformation\n mask = cols < N\n w = tl.load(W + cols, mask=mask).to(tl.float32)\n if HAS_BIAS:\n b = tl.load(B + cols, mask=mask).to(tl.float32)\n x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd\n y = x_hat * w + b if HAS_BIAS else x_hat * w\n if HAS_Z and NORM_BEFORE_GATE:\n z = tl.load(Z + cols, mask=mask).to(tl.float32)\n y *= z * tl.sigmoid(z)\n # Write output\n tl.store(Y + cols, y, mask=mask)\n\n\ndef _layer_norm_fwd(x, weight, bias, eps, z=None, out=None, group_size=None, norm_before_gate=True, is_rms_norm=False):\n M, N = x.shape\n if group_size is None:\n group_size = N\n assert N % group_size == 0\n ngroups = N // group_size\n assert x.stride(-1) == 1\n if z is not None:\n assert z.stride(-1) == 1\n assert z.shape == (M, N)\n assert weight.shape == (N,)\n assert weight.stride(-1) == 1\n if bias is not None:\n assert bias.stride(-1) == 1\n assert bias.shape == (N,)\n # allocate output\n if out is not None:\n assert out.shape == x.shape\n else:\n out = torch.empty_like(x)\n assert out.stride(-1) == 1\n mean = torch.empty((ngroups * M, ), dtype=torch.float32, device=x.device) if not is_rms_norm else None\n rstd = torch.empty((ngroups * M, ), dtype=torch.float32, device=x.device)\n # Less than 64KB per feature: enqueue fused kernel\n MAX_FUSED_SIZE = 65536 // x.element_size()\n BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(group_size))\n if group_size > BLOCK_N:\n raise RuntimeError(\"This layer norm doesn't support feature dim >= 64KB.\")\n # heuristics for number of warps\n num_warps = min(max(BLOCK_N // 256, 1), 8)\n grid = (M, ngroups)\n with torch.cuda.device(x.device.index):\n _layer_norm_fwd_1pass_kernel[grid](x, out, weight, bias, z, mean, rstd,\n x.stride(0), out.stride(0), z.stride(0) if z is not None else 0,\n M, group_size, eps,\n BLOCK_N=BLOCK_N,\n NORM_BEFORE_GATE=norm_before_gate,\n IS_RMS_NORM=is_rms_norm,\n num_warps=num_warps)\n return out, mean, rstd\n", - "description_1": "Use triton language to implement a forward pass of a layer normalization operation. The kernel '_layer_norm_fwd_1pass_kernel' takes 17 parameters: pointers to input (X), output (Y), weights (W), biases (B), another branch (Z), mean (Mean), and 1/std (Rstd), strides for X, Y, and Z, number of rows (M) and columns (N) in X, epsilon (eps) to avoid division by zero, and several compile-time constants (BLOCK_N, HAS_BIAS, HAS_Z, NORM_BEFORE_GATE, IS_RMS_NORM). The function '_layer_norm_fwd' is a wrapper that prepares the input data and calls the kernel with appropriate grid and block sizes.", - "description_2": "Use triton language to implement a forward pass of a layer normalization operation with support for optional bias and additional branch input, handling both RMS and standard normalization.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\nfrom mamba_ssm.ops.triton.softplus import softplus\n\n@triton.jit\ndef _selective_scan_update_kernel(\n # Pointers to matrices\n state_ptr, x_ptr, dt_ptr, dt_bias_ptr, A_ptr, B_ptr, C_ptr, D_ptr, z_ptr, out_ptr,\n # Matrix dimensions\n batch, nheads, dim, dstate, nheads_ngroups_ratio,\n # Strides\n stride_state_batch, stride_state_head, stride_state_dim, stride_state_dstate,\n stride_x_batch, stride_x_head, stride_x_dim,\n stride_dt_batch, stride_dt_head, stride_dt_dim,\n stride_dt_bias_head, stride_dt_bias_dim,\n stride_A_head, stride_A_dim, stride_A_dstate,\n stride_B_batch, stride_B_group, stride_B_dstate,\n stride_C_batch, stride_C_group, stride_C_dstate,\n stride_D_head, stride_D_dim,\n stride_z_batch, stride_z_head, stride_z_dim,\n stride_out_batch, stride_out_head, stride_out_dim,\n # Meta-parameters\n DT_SOFTPLUS: tl.constexpr,\n TIE_HDIM: tl.constexpr,\n BLOCK_SIZE_M: tl.constexpr,\n HAS_DT_BIAS: tl.constexpr,\n HAS_D: tl.constexpr,\n HAS_Z: tl.constexpr,\n BLOCK_SIZE_DSTATE: tl.constexpr,\n):\n pid_m = tl.program_id(axis=0)\n pid_b = tl.program_id(axis=1)\n pid_h = tl.program_id(axis=2)\n state_ptr += pid_b * stride_state_batch + pid_h * stride_state_head\n x_ptr += pid_b * stride_x_batch + pid_h * stride_x_head\n dt_ptr += pid_b * stride_dt_batch + pid_h * stride_dt_head\n if HAS_DT_BIAS:\n dt_bias_ptr += pid_h * stride_dt_bias_head\n A_ptr += pid_h * stride_A_head\n B_ptr += pid_b * stride_B_batch + (pid_h // nheads_ngroups_ratio) * stride_B_group\n C_ptr += pid_b * stride_C_batch + (pid_h // nheads_ngroups_ratio) * stride_C_group\n if HAS_Z:\n z_ptr += pid_b * stride_z_batch + pid_h * stride_z_head\n out_ptr += pid_b * stride_out_batch + pid_h * stride_out_head\n\n offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_n = tl.arange(0, BLOCK_SIZE_DSTATE)\n state_ptrs = state_ptr + (offs_m[:, None] * stride_state_dim + offs_n[None, :] * stride_state_dstate)\n x_ptrs = x_ptr + offs_m * stride_x_dim\n dt_ptrs = dt_ptr + offs_m * stride_dt_dim\n if HAS_DT_BIAS:\n dt_bias_ptrs = dt_bias_ptr + offs_m * stride_dt_bias_dim\n if HAS_D:\n D_ptr += pid_h * stride_D_head\n A_ptrs = A_ptr + (offs_m[:, None] * stride_A_dim + offs_n[None, :] * stride_A_dstate)\n B_ptrs = B_ptr + offs_n * stride_B_dstate\n C_ptrs = C_ptr + offs_n * stride_C_dstate\n if HAS_D:\n D_ptrs = D_ptr + offs_m * stride_D_dim\n if HAS_Z:\n z_ptrs = z_ptr + offs_m * stride_z_dim\n out_ptrs = out_ptr + offs_m * stride_out_dim\n\n state = tl.load(state_ptrs, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate), other=0.0)\n x = tl.load(x_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)\n if not TIE_HDIM:\n dt = tl.load(dt_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)\n if HAS_DT_BIAS:\n dt += tl.load(dt_bias_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)\n if DT_SOFTPLUS:\n dt = softplus(dt)\n A = tl.load(A_ptrs, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate), other=0.0).to(tl.float32)\n dA = tl.exp(A * dt[:, None])\n else:\n dt = tl.load(dt_ptr).to(tl.float32)\n if HAS_DT_BIAS:\n dt += tl.load(dt_bias_ptr).to(tl.float32)\n if DT_SOFTPLUS:\n dt = softplus(dt)\n A = tl.load(A_ptr).to(tl.float32)\n dA = tl.exp(A * dt) # scalar, not a matrix\n\n B = tl.load(B_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32)\n C = tl.load(C_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32)\n if HAS_D:\n D = tl.load(D_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)\n if HAS_Z:\n z = tl.load(z_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)\n\n if not TIE_HDIM:\n dB = B[None, :] * dt[:, None]\n else:\n dB = B * dt # vector of size (dstate,)\n state = state * dA + dB * x[:, None]\n tl.store(state_ptrs, state, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate))\n out = tl.sum(state * C[None, :], axis=1)\n if HAS_D:\n out += x * D\n if HAS_Z:\n out *= z * tl.sigmoid(z)\n tl.store(out_ptrs, out, mask=offs_m < dim)\n\n\ndef selective_state_update(state, x, dt, A, B, C, D=None, z=None, dt_bias=None, dt_softplus=False):\n \"\"\"\n Argument:\n state: (batch, dim, dstate) or (batch, nheads, dim, dstate)\n x: (batch, dim) or (batch, nheads, dim)\n dt: (batch, dim) or (batch, nheads, dim)\n A: (dim, dstate) or (nheads, dim, dstate)\n B: (batch, dstate) or (batch, ngroups, dstate)\n C: (batch, dstate) or (batch, ngroups, dstate)\n D: (dim,) or (nheads, dim)\n z: (batch, dim) or (batch, nheads, dim)\n dt_bias: (dim,) or (nheads, dim)\n Return:\n out: (batch, dim) or (batch, nheads, dim)\n \"\"\"\n has_heads = state.dim() > 3\n if state.dim() == 3:\n state = state.unsqueeze(1)\n if x.dim() == 2:\n x = x.unsqueeze(1)\n if dt.dim() == 2:\n dt = dt.unsqueeze(1)\n if A.dim() == 2:\n A = A.unsqueeze(0)\n if B.dim() == 2:\n B = B.unsqueeze(1)\n if C.dim() == 2:\n C = C.unsqueeze(1)\n if D is not None and D.dim() == 1:\n D = D.unsqueeze(0)\n if z is not None and z.dim() == 2:\n z = z.unsqueeze(1)\n if dt_bias is not None and dt_bias.dim() == 1:\n dt_bias = dt_bias.unsqueeze(0)\n batch, nheads, dim, dstate = state.shape\n assert x.shape == (batch, nheads, dim)\n assert dt.shape == x.shape\n assert A.shape == (nheads, dim, dstate)\n ngroups = B.shape[1]\n assert nheads % ngroups == 0, \"nheads must be divisible by ngroups\"\n assert B.shape == (batch, ngroups, dstate)\n assert C.shape == B.shape\n if D is not None:\n assert D.shape == (nheads, dim)\n if z is not None:\n assert z.shape == x.shape\n if dt_bias is not None:\n assert dt_bias.shape == (nheads, dim)\n out = torch.empty_like(x)\n grid = lambda META: (triton.cdiv(dim, META['BLOCK_SIZE_M']), batch, nheads)\n z_strides = ((z.stride(0), z.stride(1), z.stride(2)) if z is not None else (0, 0, 0))\n # We don't want autotune since it will overwrite the state\n # We instead tune by hand.\n BLOCK_SIZE_M, num_warps = ((32, 4) if dstate <= 16\n else ((16, 4) if dstate <= 32 else\n ((8, 4) if dstate <= 64 else\n ((4, 4) if dstate <= 128 else\n ((4, 8))))))\n tie_hdim = A.stride(-1) == 0 and A.stride(-2) == 0 and dt.stride(-1) == 0 and dt_bias.stride(-1) == 0\n with torch.cuda.device(x.device.index):\n _selective_scan_update_kernel[grid](\n state, x, dt, dt_bias, A, B, C, D, z, out,\n batch, nheads, dim, dstate, nheads // ngroups,\n state.stride(0), state.stride(1), state.stride(2), state.stride(3),\n x.stride(0), x.stride(1), x.stride(2),\n dt.stride(0), dt.stride(1), dt.stride(2),\n *(dt_bias.stride(0), dt_bias.stride(1)) if dt_bias is not None else 0,\n A.stride(0), A.stride(1), A.stride(2),\n B.stride(0), B.stride(1), B.stride(2),\n C.stride(0), C.stride(1), C.stride(2),\n *(D.stride(0), D.stride(1)) if D is not None else 0,\n z_strides[0], z_strides[1], z_strides[2],\n out.stride(0), out.stride(1), out.stride(2),\n dt_softplus,\n tie_hdim,\n BLOCK_SIZE_M,\n num_warps=num_warps,\n )\n if not has_heads:\n out = out.squeeze(1)\n return out\n", - "description_1": "Use triton language to implement a kernel function '_selective_scan_update_kernel' with 43 parameters for matrix operations and a wrapper function 'selective_state_update' with 9 parameters to manage tensor dimensions and call the kernel.", - "description_2": "Use triton language to create a kernel for selective state update with matrix operations and a wrapper to handle tensor dimensions and kernel invocation.", - "difficulty": 3 - }, - { - "code": "import triton\nimport triton.language as tl\nfrom packaging import version\n\nTRITON3 = version.parse(triton.__version__) >= version.parse(\"3.0.0\")\n\n# Triton kernel to compute the softplus function element-wise\nif TRITON3:\n @triton.jit\n def softplus(dt):\n dt = tl.where(dt <= 20.0, tl.math.log(tl.math.exp(dt) + 1), dt)\n return dt\nelse:\n @triton.jit\n def softplus(dt):\n dt = tl.where(dt <= 20.0, tl.math.log1p(tl.exp(dt)), dt)\n return dt\n", - "description_1": "Use triton language to define a kernel function called softplus which takes one parameter `dt` representing input tensor data. Depending on the Triton version, the function applies an element-wise softplus transformation where for elements less than or equal to 20, the softplus is computed using either log(exp(dt) + 1) or log1p(exp(dt)). Elements greater than 20 are left unchanged.", - "description_2": "Use triton language to define a kernel function for the softplus transformation on tensor data with version-dependent behavior.", - "difficulty": 2 - }, - { - "code": "import math\nimport torch\nimport triton\nimport triton.language as tl\n\n@triton.autotune(\n configs=[\n triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64}, num_stages=3, num_warps=8),\n triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2),\n triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2),\n triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=2),\n ],\n key=['chunk_size', 'K', 'IS_CAUSAL'],\n)\n@triton.jit\ndef _bmm_chunk_fwd_kernel(\n a_ptr, b_ptr, out_ptr, seq_idx_ptr,\n seqlen, chunk_size, K, ngroups,\n stride_a_batch, stride_a_seqlen, stride_a_head, stride_ak,\n stride_b_batch, stride_b_seqlen, stride_b_head, stride_bk,\n stride_out_batch, stride_out_chunk, stride_out_head, stride_outm, stride_outn,\n stride_seq_idx_batch, stride_seq_idx_seqlen,\n IS_CAUSAL: tl.constexpr,\n dot_dtype: tl.constexpr,\n HAS_SEQ_IDX: tl.constexpr,\n BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,\n):\n pid_b = tl.program_id(axis=1)\n pid_ch = tl.program_id(axis=2)\n pid_c = pid_ch // ngroups\n pid_h = pid_ch - pid_c * ngroups\n num_pid_n = tl.cdiv(chunk_size, BLOCK_SIZE_N)\n pid_m = tl.program_id(axis=0) // num_pid_n\n pid_n = tl.program_id(axis=0) % num_pid_n\n if IS_CAUSAL:\n if pid_n * BLOCK_SIZE_N >= (pid_m + 1) * BLOCK_SIZE_M:\n return\n a_ptr += pid_b * stride_a_batch + pid_c * chunk_size * stride_a_seqlen + pid_h * stride_a_head\n b_ptr += pid_b * stride_b_batch + pid_c * chunk_size * stride_b_seqlen + pid_h * stride_b_head\n if HAS_SEQ_IDX:\n seq_idx_ptr += pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen\n\n offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n offs_k = tl.arange(0, BLOCK_SIZE_K)\n a_ptrs = a_ptr + (offs_m[:, None] * stride_a_seqlen + offs_k[None, :] * stride_ak)\n b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_n[None, :] * stride_b_seqlen)\n chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)\n\n acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):\n a = tl.load(a_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < K - k * BLOCK_SIZE_K), other=0.0).to(dot_dtype)\n b = tl.load(b_ptrs, mask=(offs_k[:, None] < K - k * BLOCK_SIZE_K) & (offs_n[None, :] < chunk_size_limit), other=0.0).to(dot_dtype)\n acc += tl.dot(a, b)\n a_ptrs += BLOCK_SIZE_K * stride_ak\n b_ptrs += BLOCK_SIZE_K * stride_bk\n\n offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n if HAS_SEQ_IDX:\n chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)\n seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen, mask=offs_m < chunk_size_limit, other=-1)\n seq_idx_n = tl.load(seq_idx_ptr + offs_n * stride_seq_idx_seqlen, mask=offs_n < chunk_size_limit, other=-2)\n acc = tl.where(seq_idx_m[:, None] == seq_idx_n[None, :], acc, 0.0)\n out = acc.to(out_ptr.dtype.element_ty)\n\n out_ptr += pid_b * stride_out_batch + pid_c * stride_out_chunk + pid_h * stride_out_head\n out_ptrs = out_ptr + (stride_outm * offs_m[:, None] + offs_n[None, :] * stride_outn)\n tl.store(out_ptrs, out, mask=(offs_m[:, None] < chunk_size) & (offs_n[None, :] < chunk_size))\n\n\n@triton.autotune(\n configs=[\n triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_CS': 64}, num_stages=3, num_warps=8),\n triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_CS': 32}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_CS': 32}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_CS': 32}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_CS': 32}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_CS': 32}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_CS': 32}, num_stages=5, num_warps=2),\n triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_CS': 32}, num_stages=5, num_warps=2),\n triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_CS': 32}, num_stages=4, num_warps=2),\n ],\n key=['chunk_size', 'K'],\n)\n@triton.jit\ndef _bmm_chunk_bwd_kernel(\n a_ptr, dout_ptr, db_ptr, res_ptr,\n seqlen, chunk_size, K, ngroups,\n stride_a_batch, stride_a_seqlen, stride_a_head, stride_ak,\n stride_dout_batch, stride_dout_chunk, stride_dout_head, stride_dout_csize_m, stride_dout_csize_n,\n stride_db_batch, stride_db_seqlen, stride_db_head, stride_db_k,\n stride_res_batch, stride_res_seqlen, stride_res_head, stride_res_k,\n dot_dtype: tl.constexpr,\n HAS_RESIDUAL: tl.constexpr,\n BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_CS: tl.constexpr,\n):\n pid_b = tl.program_id(axis=1)\n pid_ch = tl.program_id(axis=2)\n pid_c = pid_ch // ngroups\n pid_h = pid_ch - pid_c * ngroups\n num_pid_n = tl.cdiv(K, BLOCK_SIZE_N)\n pid_m = tl.program_id(axis=0) // num_pid_n\n pid_n = tl.program_id(axis=0) % num_pid_n\n\n a_ptr += pid_b * stride_a_batch + pid_c * chunk_size * stride_a_seqlen + pid_h * stride_a_head\n dout_ptr += pid_b * stride_dout_batch + pid_c * stride_dout_chunk + pid_h * stride_dout_head\n\n offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n offs_cs = tl.arange(0, BLOCK_SIZE_CS)\n dout_ptrs = dout_ptr + (offs_m[:, None] * stride_dout_csize_n + offs_cs[None, :] * stride_dout_csize_m)\n a_ptrs = a_ptr + (offs_cs[:, None] * stride_a_seqlen + offs_n[None, :] * stride_ak)\n chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)\n\n acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n for cs in range(0, tl.cdiv(chunk_size_limit, BLOCK_SIZE_CS)):\n dout = tl.load(dout_ptrs, mask=(offs_m[:, None] < chunk_size) & (offs_cs[None, :] < chunk_size_limit - cs * BLOCK_SIZE_CS), other=0.0).to(dot_dtype)\n a = tl.load(a_ptrs, mask=(offs_cs[:, None] < chunk_size_limit - cs * BLOCK_SIZE_CS) & (offs_n[None, :] < K), other=0.0).to(dot_dtype)\n acc += tl.dot(dout, a)\n dout_ptrs += BLOCK_SIZE_CS * stride_dout_csize_m\n a_ptrs += BLOCK_SIZE_CS * stride_a_seqlen\n\n offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n if HAS_RESIDUAL:\n res_ptr += pid_b * stride_res_batch + pid_c * chunk_size * stride_res_seqlen + pid_h * stride_res_head\n res_ptrs = res_ptr + (offs_m[:, None] * stride_res_seqlen + offs_n[None, :] * stride_res_k)\n res = tl.load(res_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < K)).to(tl.float32)\n acc += res\n db = acc.to(db_ptr.dtype.element_ty)\n\n db_ptr += pid_b * stride_db_batch + pid_c * chunk_size * stride_db_seqlen + pid_h * stride_db_head\n db_ptrs = db_ptr + (offs_m[:, None] * stride_db_seqlen + offs_n[None, :] * stride_db_k)\n tl.store(db_ptrs, db, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < K))\n\n\ndef _bmm_chunk_fwd(a, b, chunk_size, seq_idx=None, causal=False, output_dtype=None):\n has_groups = a.dim() == 4\n if not has_groups:\n batch, seqlen, k = a.shape\n else:\n batch, seqlen, ngroups, k = a.shape\n assert b.shape == a.shape\n if seq_idx is not None:\n assert seq_idx.shape == (batch, seqlen)\n if a.stride(-1) != 1 and a.stride(1) != 1:\n a = a.contiguous()\n if b.stride(-1) != 1 and b.stride(1) != 1:\n b = b.contiguous()\n nchunks = math.ceil(seqlen / chunk_size)\n out_dtype = a.dtype if output_dtype is None else output_dtype\n out = torch.empty((batch, nchunks, chunk_size, chunk_size) if not has_groups else (batch, nchunks, ngroups, chunk_size, chunk_size),\n device=a.device, dtype=out_dtype)\n dot_dtype = (tl.bfloat16 if a.dtype == torch.bfloat16 or b.dtype == torch.bfloat16 else\n (tl.float16 if a.dtype == torch.float16 or b.dtype == torch.float16 else tl.float32))\n grid = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv(chunk_size, META['BLOCK_SIZE_N']),\n batch, nchunks if not has_groups else nchunks * ngroups)\n with torch.cuda.device(a.device.index):\n _bmm_chunk_fwd_kernel[grid](\n a, b, out, seq_idx,\n seqlen, chunk_size, k, ngroups if has_groups else 1,\n a.stride(0), a.stride(1), 0 if not has_groups else a.stride(2), a.stride(-1),\n b.stride(0), b.stride(1), 0 if not has_groups else b.stride(2), b.stride(-1),\n out.stride(0), out.stride(1), 0 if not has_groups else out.stride(2), out.stride(-2), out.stride(-1),\n *((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)),\n causal,\n dot_dtype,\n HAS_SEQ_IDX=seq_idx is not None,\n )\n return out\n\n\ndef _bmm_chunk_bwd(a, dout, residual=None, out=None):\n has_groups = a.dim() == 4\n if not has_groups:\n batch, seqlen, k = a.shape\n else:\n batch, seqlen, ngroups, k = a.shape\n nchunks, chunk_size = dout.shape[1], dout.shape[-1]\n if a.stride(-1) != 1 and a.stride(-2) != 1:\n a = a.contiguous()\n if dout.stride(-1) != 1 and dout.stride(-2) != 1:\n dout = dout.contiguous()\n if residual is not None:\n assert residual.shape == (batch, seqlen, k) if not has_groups else (batch, seqlen, ngroups, k)\n if residual.stride(-1) != 1 and residual.stride(1) != 1:\n residual = residual.contiguous()\n if out is not None:\n assert out.shape == a.shape\n assert out.stride(-1) == 1 or out.stride(1) == 1\n else:\n out = torch.empty_like(a)\n dot_dtype = (tl.bfloat16 if a.dtype == torch.bfloat16 or dout.dtype == torch.bfloat16 else\n (tl.float16 if a.dtype == torch.float16 or dout.dtype == torch.float16 else tl.float32))\n grid = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv(k, META['BLOCK_SIZE_N']), batch,\n nchunks if not has_groups else nchunks * ngroups)\n residual_strides = ((residual.stride(0), residual.stride(1), 0 if not has_groups else residual.stride(2),\n residual.stride(-1))\n if residual is not None else (0, 0, 0, 0))\n with torch.cuda.device(a.device.index):\n _bmm_chunk_bwd_kernel[grid](\n a, dout, out, residual,\n seqlen, chunk_size, k, ngroups if has_groups else 1,\n a.stride(0), a.stride(1), 0 if not has_groups else a.stride(2), a.stride(-1),\n dout.stride(0), dout.stride(1), 0 if not has_groups else dout.stride(2), dout.stride(-2), dout.stride(-1),\n out.stride(0), out.stride(1), 0 if not has_groups else out.stride(2), out.stride(-1),\n residual_strides[0], residual_strides[1], residual_strides[2], residual_strides[3],\n dot_dtype,\n HAS_RESIDUAL=residual is not None,\n )\n return out\n", - "description_1": "Use triton language to implement two kernels: _bmm_chunk_fwd_kernel and _bmm_chunk_bwd_kernel. The _bmm_chunk_fwd_kernel performs a batched matrix multiplication with optional sequence index masking and causal masking. It takes 24 parameters including pointers to input matrices, matrix dimensions, strides, and meta-parameters for configuration. The _bmm_chunk_bwd_kernel computes the gradient of the batched matrix multiplication with respect to one of the input matrices. It takes 22 parameters including pointers to input matrices, matrix dimensions, strides, and meta-parameters for configuration.", - "description_2": "Use triton language to implement forward and backward kernels for batched matrix multiplication with optional sequence index and causal masking.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n\n@triton.autotune(\n configs=[\n triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64}, num_stages=3, num_warps=8),\n triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2),\n triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2),\n triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=2),\n ],\n key=['chunk_size', 'hdim', 'dstate', 'IS_CAUSAL'],\n)\n@triton.jit\ndef _chunk_scan_fwd_kernel(\n cb_ptr, x_ptr, z_ptr, out_ptr, out_x_ptr, dt_ptr, dA_cumsum_ptr, seq_idx_ptr, C_ptr, prev_states_ptr, D_ptr,\n chunk_size, hdim, dstate,\n batch, seqlen, nheads_ngroups_ratio,\n stride_cb_batch, stride_cb_chunk, stride_cb_head, stride_cb_csize_m, stride_cb_csize_k,\n stride_x_batch, stride_x_seqlen, stride_x_head, stride_x_hdim,\n stride_z_batch, stride_z_seqlen, stride_z_head, stride_z_hdim,\n stride_out_batch, stride_out_seqlen, stride_out_head, stride_out_hdim,\n stride_dt_batch, stride_dt_chunk, stride_dt_head, stride_dt_csize,\n stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize,\n stride_seq_idx_batch, stride_seq_idx_seqlen,\n stride_C_batch, stride_C_seqlen, stride_C_head, stride_C_dstate,\n stride_states_batch, stride_states_chunk, stride_states_head, stride_states_hdim, stride_states_dstate,\n stride_D_head,\n IS_CAUSAL: tl.constexpr,\n HAS_D: tl.constexpr,\n D_HAS_HDIM: tl.constexpr,\n HAS_Z: tl.constexpr,\n HAS_SEQ_IDX: tl.constexpr,\n BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,\n BLOCK_SIZE_DSTATE: tl.constexpr,\n IS_TRITON_22: tl.constexpr,\n):\n pass\n\ndef _chunk_scan_fwd(cb, x, dt, dA_cumsum, C, states, D=None, z=None, seq_idx=None):\n batch, seqlen, nheads, headdim = x.shape\n _, _, nchunks, chunk_size = dt.shape\n _, _, ngroups, dstate = C.shape\n assert nheads % ngroups == 0\n assert C.shape == (batch, seqlen, ngroups, dstate)\n assert cb.shape == (batch, nchunks, ngroups, chunk_size, chunk_size)\n if z is not None:\n assert z.shape == x.shape\n if D is not None:\n assert D.shape == (nheads, headdim) or D.shape == (nheads,)\n assert dt.shape == (batch, nheads, nchunks, chunk_size)\n assert dA_cumsum.shape == (batch, nheads, nchunks, chunk_size)\n assert states.shape == (batch, nchunks, nheads, headdim, dstate)\n if seq_idx is not None:\n assert seq_idx.shape == (batch, seqlen)\n out = torch.empty(batch, seqlen, nheads, headdim, device=x.device, dtype=x.dtype)\n if z is not None:\n out_x = torch.empty(batch, seqlen, nheads, headdim, device=x.device, dtype=x.dtype)\n assert out_x.stride() == out.stride()\n else:\n out_x = None\n grid = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv(headdim, META['BLOCK_SIZE_N']),\n batch * nchunks, nheads)\n z_strides = ((z.stride(0), z.stride(1), z.stride(2), z.stride(3))\n if z is not None else (0, 0, 0, 0))\n _chunk_scan_fwd_kernel[grid](\n cb, x, z, out, out_x, dt, dA_cumsum, seq_idx, C, states, D,\n chunk_size, headdim, dstate,\n batch, seqlen, nheads // ngroups,\n cb.stride(0), cb.stride(1), cb.stride(2), cb.stride(3), cb.stride(4),\n x.stride(0), x.stride(1), x.stride(2), x.stride(3),\n z_strides[0], z_strides[1], z_strides[2], z_strides[3],\n out.stride(0), out.stride(1), out.stride(2), out.stride(3),\n dt.stride(0), dt.stride(2), dt.stride(1), dt.stride(3),\n dA_cumsum.stride(0), dA_cumsum.stride(2), dA_cumsum.stride(1), dA_cumsum.stride(3),\n *((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)),\n C.stride(0), C.stride(1), C.stride(2), C.stride(3),\n states.stride(0), states.stride(1), states.stride(2), states.stride(3), states.stride(4),\n D.stride(0) if D is not None else 0,\n True,\n D is not None,\n D.dim() == 2 if D is not None else True,\n BLOCK_SIZE_DSTATE=max(triton.next_power_of_2(dstate), 16),\n HAS_Z=z is not None,\n HAS_SEQ_IDX=seq_idx is not None,\n IS_TRITON_22=TRITON_22,\n )\n return out, out_x\n", - "description_1": "Use triton language to implement a forward chunk scan kernel that processes input matrices and transformations to produce an output tensor, with optional parameters for additional transformations and states.", - "description_2": "Use triton language to process matrix chunks and compute forward scans over them, handling various meta-parameters for optimization.", - "difficulty": 4 - }, - { - "code": "import math\nimport torch\nimport triton\nimport triton.language as tl\nfrom mamba_ssm.ops.triton.softplus import softplus\n\n@triton.jit\ndef _chunk_cumsum_fwd_kernel(\n dt_ptr, A_ptr, dt_bias_ptr, dt_out_ptr, dA_cumsum_ptr,\n batch, seqlen, nheads, chunk_size,\n dt_min, dt_max,\n stride_dt_batch, stride_dt_seqlen, stride_dt_head,\n stride_A_head,\n stride_dt_bias_head,\n stride_dt_out_batch, stride_dt_out_chunk, stride_dt_out_head, stride_dt_out_csize,\n stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize,\n DT_SOFTPLUS: tl.constexpr,\n HAS_DT_BIAS: tl.constexpr,\n BLOCK_SIZE_H: tl.constexpr, BLOCK_SIZE_CHUNK: tl.constexpr,\n):\n pid_b = tl.program_id(axis=0)\n pid_c = tl.program_id(axis=1)\n pid_h = tl.program_id(axis=2)\n dt_ptr += pid_b * stride_dt_batch + pid_c * chunk_size * stride_dt_seqlen\n dt_out_ptr += pid_b * stride_dt_out_batch + pid_c * stride_dt_out_chunk\n dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk\n\n offs_h = pid_h * BLOCK_SIZE_H + tl.arange(0, BLOCK_SIZE_H)\n offs_c = tl.arange(0, BLOCK_SIZE_CHUNK)\n dt_ptrs = dt_ptr + (offs_h[:, None] * stride_dt_head + offs_c[None, :] * stride_dt_seqlen)\n A_ptrs = A_ptr + offs_h * stride_A_head\n dt_out_ptrs = dt_out_ptr + (offs_h[:, None] * stride_dt_out_head + offs_c[None, :] * stride_dt_out_csize)\n dA_cs_ptrs = dA_cumsum_ptr + (offs_h[:, None] * stride_dA_cs_head + offs_c[None, :] * stride_dA_cs_csize)\n chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)\n\n dt = tl.load(dt_ptrs, mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit), other=0.0).to(tl.float32)\n if HAS_DT_BIAS:\n dt_bias = tl.load(dt_bias_ptr + offs_h * stride_dt_bias_head, mask=offs_h < nheads, other=0.0).to(tl.float32)\n dt += dt_bias[:, None]\n if DT_SOFTPLUS:\n dt = softplus(dt)\n dt = tl.minimum(tl.maximum(dt, dt_min), dt_max)\n dt = tl.where((offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit), dt, 0.0)\n tl.store(dt_out_ptrs, dt, mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size))\n A = tl.load(A_ptrs, mask=offs_h < nheads, other=0.0).to(tl.float32)\n dA = dt * A[:, None]\n dA_cs = tl.cumsum(dA, axis=1)\n tl.store(dA_cs_ptrs, dA_cs, mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size))\n\n@triton.jit\ndef _chunk_cumsum_bwd_kernel(\n ddA_ptr, ddt_out_ptr, dt_ptr, A_ptr, dt_bias_ptr,\n ddt_ptr, dA_ptr, ddt_bias_ptr,\n batch, seqlen, nheads, chunk_size,\n dt_min, dt_max,\n stride_ddA_batch, stride_ddA_chunk, stride_ddA_head, stride_ddA_csize,\n stride_ddt_out_batch, stride_ddt_out_chunk, stride_ddt_out_head, stride_ddt_out_csize,\n stride_dt_batch, stride_dt_seqlen, stride_dt_head,\n stride_A_head,\n stride_dt_bias_head,\n stride_ddt_batch, stride_ddt_seqlen, stride_ddt_head,\n stride_dA_head,\n stride_ddt_bias_head,\n DT_SOFTPLUS: tl.constexpr,\n HAS_DT_BIAS: tl.constexpr,\n BLOCK_SIZE_H: tl.constexpr, BLOCK_SIZE_CHUNK: tl.constexpr,\n):\n pid_b = tl.program_id(axis=0)\n pid_c = tl.program_id(axis=1)\n pid_h = tl.program_id(axis=2)\n ddt_out_ptr += pid_b * stride_ddt_out_batch + pid_c * stride_ddt_out_chunk\n ddA_ptr += pid_b * stride_ddA_batch + pid_c * stride_ddA_chunk\n dt_ptr += pid_b * stride_dt_batch + pid_c * chunk_size * stride_dt_seqlen\n ddt_ptr += pid_b * stride_ddt_batch + pid_c * chunk_size * stride_ddt_seqlen\n\n offs_h = pid_h * BLOCK_SIZE_H + tl.arange(0, BLOCK_SIZE_H)\n offs_c = tl.arange(0, BLOCK_SIZE_CHUNK)\n ddt_out_ptrs = ddt_out_ptr + (offs_h[:, None] * stride_ddt_out_head + offs_c[None, :] * stride_ddt_out_csize)\n ddA_ptrs = ddA_ptr + (offs_h[:, None] * stride_ddA_head + offs_c[None, :] * stride_ddA_csize)\n dt_ptrs = dt_ptr + (offs_h[:, None] * stride_dt_head + offs_c[None, :] * stride_dt_seqlen)\n ddt_ptrs = ddt_ptr + (offs_h[:, None] * stride_ddt_head + offs_c[None, :] * stride_ddt_seqlen)\n A_ptrs = A_ptr + offs_h * stride_A_head\n chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)\n\n ddA = tl.load(ddA_ptrs, mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit), other=0.0).to(tl.float32)\n ddt_out = tl.load(ddt_out_ptrs, mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit), other=0.0).to(tl.float32)\n A = tl.load(A_ptrs, mask=offs_h < nheads, other=0.0).to(tl.float32)\n ddt = ddA * A[:, None] + ddt_out\n dt = tl.load(dt_ptrs, mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit), other=0.0).to(tl.float32)\n if HAS_DT_BIAS:\n dt_bias = tl.load(dt_bias_ptr + offs_h * stride_dt_bias_head, mask=offs_h < nheads, other=0.0).to(tl.float32)\n dt += dt_bias[:, None]\n if DT_SOFTPLUS:\n dt_presoftplus = dt\n dt = softplus(dt)\n clamp_mask = (dt < dt_min) | (dt > dt_max)\n dt = tl.minimum(tl.maximum(dt, dt_min), dt_max)\n dt = tl.where((offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit), dt, 0.0)\n ddt = tl.where((offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit), ddt, 0.0)\n ddt = tl.where(clamp_mask, 0.0, ddt)\n if DT_SOFTPLUS:\n ddt = tl.where(dt_presoftplus <= 20.0, ddt * tl.sigmoid(dt_presoftplus), ddt)\n tl.store(ddt_ptrs, ddt, mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit))\n dA = tl.sum(ddA * dt, axis=1)\n tl.atomic_add(dA_ptr + offs_h * stride_dA_head, dA, mask=offs_h < nheads)\n if HAS_DT_BIAS:\n ddt_bias = tl.sum(ddt, axis=1)\n tl.atomic_add(ddt_bias_ptr + offs_h * stride_ddt_bias_head, ddt_bias, mask=offs_h < nheads)\n\ndef _chunk_cumsum_fwd(dt, A, chunk_size, dt_bias=None, dt_softplus=False, dt_limit=(0.0, float(\"inf\"))):\n batch, seqlen, nheads = dt.shape\n assert A.shape == (nheads,)\n if dt_bias is not None:\n assert dt_bias.shape == (nheads,)\n nchunks = math.ceil(seqlen / chunk_size)\n dt_out = torch.empty(batch, nheads, nchunks, chunk_size, device=dt.device, dtype=torch.float32)\n dA_cumsum = torch.empty(batch, nheads, nchunks, chunk_size, device=dt.device, dtype=torch.float32)\n grid_chunk_cs = lambda META: (batch, nchunks, triton.cdiv(nheads, META['BLOCK_SIZE_H']))\n with torch.cuda.device(dt.device.index):\n _chunk_cumsum_fwd_kernel[grid_chunk_cs](\n dt, A, dt_bias, dt_out, dA_cumsum,\n batch, seqlen, nheads, chunk_size,\n dt_limit[0], dt_limit[1],\n dt.stride(0), dt.stride(1), dt.stride(2),\n A.stride(0),\n dt_bias.stride(0) if dt_bias is not None else 0,\n dt_out.stride(0), dt_out.stride(2), dt_out.stride(1), dt_out.stride(3),\n dA_cumsum.stride(0), dA_cumsum.stride(2), dA_cumsum.stride(1), dA_cumsum.stride(3),\n dt_softplus,\n HAS_DT_BIAS=dt_bias is not None,\n BLOCK_SIZE_CHUNK=triton.next_power_of_2(chunk_size),\n )\n return dA_cumsum, dt_out\n\ndef _chunk_cumsum_bwd(ddA, ddt_out, dt, A, dt_bias=None, dt_softplus=False, dt_limit=(0.0, float(\"inf\")), ddt=None):\n batch, seqlen, nheads = dt.shape\n _, _, nchunks, chunk_size = ddA.shape\n assert ddA.shape == (batch, nheads, nchunks, chunk_size)\n assert ddt_out.shape == (batch, nheads, nchunks, chunk_size)\n assert A.shape == (nheads,)\n if dt_bias is not None:\n assert dt_bias.shape == (nheads,)\n ddt_bias = torch.empty_like(dt_bias, dtype=torch.float32)\n else:\n ddt_bias = None\n if ddt is not None:\n assert ddt.shape == dt.shape\n else:\n ddt = torch.empty_like(dt)\n dA = torch.empty_like(A, dtype=torch.float32)\n grid_chunk_cs = lambda META: (batch, nchunks, triton.cdiv(nheads, META['BLOCK_SIZE_H']))\n with torch.cuda.device(dt.device.index):\n _chunk_cumsum_bwd_kernel[grid_chunk_cs](\n ddA, ddt_out, dt, A, dt_bias, ddt, dA, ddt_bias,\n batch, seqlen, nheads, chunk_size,\n dt_limit[0], dt_limit[1],\n ddA.stride(0), ddA.stride(2), ddA.stride(1), ddA.stride(3),\n ddt_out.stride(0), ddt_out.stride(2), ddt_out.stride(1), ddt_out.stride(3),\n dt.stride(0), dt.stride(1), dt.stride(2),\n A.stride(0),\n dt_bias.stride(0) if dt_bias is not None else 0,\n ddt.stride(0), ddt.stride(1), ddt.stride(2),\n dA.stride(0),\n ddt_bias.stride(0) if ddt_bias is not None else 0,\n dt_softplus,\n HAS_DT_BIAS=dt_bias is not None,\n BLOCK_SIZE_CHUNK=triton.next_power_of_2(chunk_size),\n )\n return ddt, dA, ddt_bias\n", - "description_1": "Use triton language to implement a kernel function for forward and backward cumulative summation with optional softplus transformation. The forward kernel processes inputs with pointers to matrices dt, A, optional dt_bias, and produces dt_out and cumulative sums dA_cumsum, with specific matrix dimensions, strides, and meta-parameters indicating the presence of softplus transformation and bias. The backward kernel receives gradients with respect to A (ddA), outputs (ddt_out), and similar parameters to compute the gradients with respect to dt, A, and optional dt_bias.", - "description_2": "Use triton language to implement cumulative summation operations with optional bias and softplus activation for both forward and backward passes in a neural network training loop.", - "difficulty": 4 - }, - { - "code": "import triton\nimport triton.language as tl\nimport torch\nfrom torch import Tensor\n\nTRITON_22 = version.parse(triton.__version__) >= version.parse('2.2.0')\n\n@triton.autotune(\n configs=[\n triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64}, num_stages=3, num_warps=8, pre_hook=init_to_zero([\"ddt_ptr\"])),\n triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero([\"ddt_ptr\"])),\n triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero([\"ddt_ptr\"])),\n triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero([\"ddt_ptr\"])),\n triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero([\"ddt_ptr\"])),\n triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero([\"ddt_ptr\"])),\n triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=4, pre_hook=init_to_zero([\"ddt_ptr\"])),\n triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=4, pre_hook=init_to_zero([\"ddt_ptr\"])),\n triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero([\"ddt_ptr\"])),\n ],\n key=['chunk_size', 'hdim', 'dstate'],\n)\n@triton.jit\ndef _chunk_scan_chunk_state_bwd_dx_kernel(\n # Pointers to matrices\n x_ptr, cb_ptr, dout_ptr, dt_ptr, dA_cumsum_ptr, seq_idx_ptr, D_ptr,\n b_ptr, dstates_ptr,\n dx_ptr, ddt_ptr, dD_ptr,\n # Matrix dimensions\n chunk_size, hdim, dstate,\n batch, seqlen, nheads_ngroups_ratio,\n # Strides\n stride_x_batch, stride_x_seqlen, stride_x_head, stride_x_hdim,\n stride_cb_batch, stride_cb_chunk, stride_cb_head, stride_cb_csize_m, stride_cb_csize_k,\n stride_dout_batch, stride_dout_seqlen, stride_dout_head, stride_dout_hdim,\n stride_dt_batch, stride_dt_chunk, stride_dt_head, stride_dt_csize,\n stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize,\n stride_seq_idx_batch, stride_seq_idx_seqlen,\n stride_D_head,\n stride_b_batch, stride_b_seqlen, stride_b_head, stride_b_dstate,\n stride_dstates_batch, stride_dstates_chunk, stride_dstates_head, stride_dstates_hdim, stride_dstates_dstate,\n stride_dx_batch, stride_dx_seqlen, stride_dx_head, stride_dx_hdim,\n stride_ddt_batch, stride_ddt_chunk, stride_ddt_head, stride_ddt_csize,\n stride_dD_batch, stride_dD_chunk, stride_dD_head, stride_dD_csize, stride_dD_hdim,\n # Meta-parameters\n HAS_D: tl.constexpr,\n D_HAS_HDIM: tl.constexpr,\n HAS_SEQ_IDX: tl.constexpr,\n BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,\n BLOCK_SIZE_DSTATE: tl.constexpr,\n IS_TRITON_22: tl.constexpr,\n):\n # Kernel code implementation\n\ndef _chunk_scan_chunk_state_bwd_dx(x, dt, dA_cumsum, B, CB, dout, dstates, D=None, seq_idx=None, dx=None):\n # Python wrapper function that calls the above Triton kernel\n batch, seqlen, nheads, headdim = x.shape\n _, _, nchunks, chunk_size = dt.shape\n _, _, ngroups, dstate = B.shape\n assert nheads % ngroups == 0\n assert B.shape == (batch, seqlen, ngroups, dstate)\n assert CB.shape == (batch, nchunks, ngroups, chunk_size, chunk_size)\n assert dt.shape == (batch, nheads, nchunks, chunk_size)\n assert dA_cumsum.shape == dt.shape\n assert dout.shape == x.shape\n assert dstates.shape == (batch, nchunks, nheads, headdim, dstate)\n if seq_idx is not None:\n assert seq_idx.shape == (batch, seqlen)\n if D is not None:\n assert D.shape == (nheads, headdim) or D.shape == (nheads,)\n assert D.stride(-1) == 1\n BLOCK_SIZE_min = 32\n dD = torch.empty(triton.cdiv(chunk_size, BLOCK_SIZE_min), batch, nchunks, nheads,\n headdim if D.dim() == 2 else 1, device=D.device, dtype=torch.float32)\n else:\n dD = None\n dD_strides = ((dD.stride(0), dD.stride(1), dD.stride(2), dD.stride(3), dD.stride(4))\n if D is not None else (0, 0, 0, 0, 0))\n if dx is None:\n dx = torch.empty_like(x)\n else:\n assert dx.shape == x.shape\n ddt = torch.empty(batch, nheads, nchunks, chunk_size, device=dout.device, dtype=torch.float32)\n grid_dx = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv(headdim, META['BLOCK_SIZE_N']),\n batch * nchunks, nheads)\n with torch.cuda.device(x.device.index):\n _chunk_scan_chunk_state_bwd_dx_kernel[grid_dx](\n x, CB, dout, dt, dA_cumsum, seq_idx, D, B, dstates, dx, ddt, dD,\n chunk_size, headdim, dstate,\n batch, seqlen, nheads // ngroups,\n x.stride(0), x.stride(1), x.stride(2), x.stride(3),\n CB.stride(0), CB.stride(1), CB.stride(2), CB.stride(-1), CB.stride(-2),\n dout.stride(0), dout.stride(1), dout.stride(2), dout.stride(3),\n dt.stride(0), dt.stride(2), dt.stride(1), dt.stride(3),\n dA_cumsum.stride(0), dA_cumsum.stride(2), dA_cumsum.stride(1), dA_cumsum.stride(3),\n *((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)),\n D.stride(0) if D is not None else 0,\n B.stride(0), B.stride(1), B.stride(2), B.stride(3),\n dstates.stride(0), dstates.stride(1), dstates.stride(2), dstates.stride(3), dstates.stride(4),\n dx.stride(0), dx.stride(1), dx.stride(2), dx.stride(3),\n ddt.stride(0), ddt.stride(2), ddt.stride(1), ddt.stride(3),\n dD_strides[1], dD_strides[2], dD_strides[3], dD_strides[0], dD_strides[4],\n D is not None,\n D.dim() == 2 if D is not None else True,\n HAS_SEQ_IDX=seq_idx is not None,\n BLOCK_SIZE_DSTATE=max(triton.next_power_of_2(dstate), 16),\n IS_TRITON_22=TRITON_22\n )\n if D is not None:\n BLOCK_SIZE_actual = _chunk_scan_chunk_state_bwd_dx_kernel.best_config.kwargs[\"BLOCK_SIZE_M\"]\n n_valid_blocks = (chunk_size + BLOCK_SIZE_actual - 1) // BLOCK_SIZE_actual\n dD = dD[:n_valid_blocks].sum(dim=(0, 1, 2)).to(dtype=D.dtype)\n if D.dim() == 1:\n dD = rearrange(dD, \"h 1 -> h\")\n return dx, ddt.to(dtype=dt.dtype), dD\n", - "description_1": "Use triton language to implement a kernel for backward pass computation over chunks. The kernel handles data points across different chunks and dimensions. This involves complex pointer manipulations and tensor computations using Triton language. It uses autotuning for optimal configurations and is responsible for calculating gradients for given tensors across blocks. The function _chunk_scan_chunk_state_bwd_dx acts as a wrapper to call this kernel by preparing necessary data and configurations.", - "description_2": "Use triton language to implement and optimize a backward pass kernel for chunk-wise tensor operations, handling multiple dimensions and using autotuning for efficiency.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n\n@triton.autotune(\n configs=[\n triton.Config({'BLOCK_SIZE': 64}),\n triton.Config({'BLOCK_SIZE': 128}),\n triton.Config({'BLOCK_SIZE': 256}),\n triton.Config({'BLOCK_SIZE': 512}),\n triton.Config({'BLOCK_SIZE': 1024}),\n triton.Config({'BLOCK_SIZE': 2048}),\n ],\n key=['dim'],\n)\n@triton.jit\ndef _state_passing_fwd_kernel(\n # Pointers to matrices\n states_ptr, out_ptr, final_states_ptr, dA_cs_ptr, initstates_ptr, seq_idx_ptr,\n # Matrix dimensions\n dim, nchunks, seqlen, chunk_size,\n # Strides\n stride_states_batch, stride_states_chunk, stride_states_head, stride_states_dim,\n stride_out_batch, stride_out_chunk, stride_out_head, stride_out_dim,\n stride_final_states_batch, stride_final_states_head, stride_final_states_dim,\n stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head,\n stride_initstates_batch, stride_initstates_head, stride_initstates_dim,\n stride_seq_idx_batch, stride_seq_idx_seqlen,\n # Meta-parameters\n HAS_INITSTATES: tl.constexpr,\n HAS_SEQ_IDX: tl.constexpr,\n BLOCK_SIZE: tl.constexpr,\n):\n pid_b = tl.program_id(axis=1)\n pid_h = tl.program_id(axis=2)\n pid_m = tl.program_id(axis=0)\n states_ptr += pid_b * stride_states_batch + pid_h * stride_states_head\n dA_cs_ptr += pid_b * stride_dA_cs_batch + pid_h * stride_dA_cs_head\n out_ptr += pid_b * stride_out_batch + pid_h * stride_out_head\n final_states_ptr += pid_b * stride_final_states_batch + pid_h * stride_final_states_head\n if HAS_INITSTATES:\n initstates_ptr += pid_b * stride_initstates_batch + pid_h * stride_initstates_head\n if HAS_SEQ_IDX:\n seq_idx_ptr += pid_b * stride_seq_idx_batch\n\n offs_m = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)\n states_ptrs = states_ptr + offs_m * stride_states_dim\n out_ptrs = out_ptr + offs_m * stride_out_dim\n final_states_ptrs = final_states_ptr + offs_m * stride_final_states_dim\n\n if not HAS_INITSTATES:\n states = tl.zeros((BLOCK_SIZE, ), dtype=tl.float32)\n else:\n initstates_ptrs = initstates_ptr + offs_m * stride_initstates_dim\n states = tl.load(initstates_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)\n tl.store(out_ptrs, states, mask=offs_m < dim)\n out_ptrs += stride_out_chunk\n seq_idx = 0\n for c in range(nchunks):\n new_states = tl.load(states_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)\n dA_cs = tl.load(dA_cs_ptr).to(tl.float32)\n scale = tl.exp(dA_cs)\n if HAS_SEQ_IDX:\n seq_idx_new = tl.load(seq_idx_ptr + (min((c + 1) * chunk_size, seqlen) - 1) * stride_seq_idx_seqlen)\n scale = tl.where(seq_idx_new == seq_idx, scale, 0.0)\n seq_idx = seq_idx_new\n states = scale * states + new_states\n if c < nchunks - 1:\n tl.store(out_ptrs, states, mask=offs_m < dim)\n else:\n tl.store(final_states_ptrs, states, mask=offs_m < dim)\n states_ptrs += stride_states_chunk\n dA_cs_ptr += stride_dA_cs_chunk\n out_ptrs += stride_out_chunk\n\n\n@triton.autotune(\n configs=[\n triton.Config({'BLOCK_SIZE': 64}),\n triton.Config({'BLOCK_SIZE': 128}),\n triton.Config({'BLOCK_SIZE': 256}),\n triton.Config({'BLOCK_SIZE': 512}),\n triton.Config({'BLOCK_SIZE': 1024}),\n triton.Config({'BLOCK_SIZE': 2048}),\n ],\n key=['dim'],\n)\n@triton.jit\ndef _state_passing_bwd_kernel(\n # Pointers to matrices\n dout_ptr, out_ptr, dA_cs_ptr, dfinal_states_ptr, seq_idx_ptr,\n dstates_ptr, ddA_cs_ptr, dinitstates_ptr, states_converted_ptr,\n # Matrix dimensions\n dim, nchunks, seqlen, chunk_size,\n # Strides\n stride_dout_batch, stride_dout_chunk, stride_dout_head, stride_dout_dim,\n stride_out_batch, stride_out_chunk, stride_out_head, stride_out_dim,\n stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head,\n stride_dfinal_states_batch, stride_dfinal_states_head, stride_dfinal_states_dim,\n stride_seq_idx_batch, stride_seq_idx_seqlen,\n stride_dstates_batch, stride_dstates_chunk, stride_dstates_head, stride_dstates_dim,\n stride_ddA_cs_batch, stride_ddA_cs_chunk, stride_ddA_cs_head,\n stride_dinitstates_batch, stride_dinitstates_head, stride_dinitstates_dim,\n # Meta-parameters\n CONVERT_STATES: tl.constexpr,\n HAS_DFINAL_STATES: tl.constexpr,\n HAS_DINITSTATES: tl.constexpr,\n HAS_SEQ_IDX: tl.constexpr,\n BLOCK_SIZE: tl.constexpr,\n):\n pid_b = tl.program_id(axis=1)\n pid_h = tl.program_id(axis=2)\n pid_m = tl.program_id(axis=0)\n dstates_ptr += pid_b * stride_dstates_batch + pid_h * stride_dstates_head + (nchunks - 1) * stride_dstates_chunk\n dA_cs_ptr += pid_b * stride_dA_cs_batch + pid_h * stride_dA_cs_head + (nchunks - 1) * stride_dA_cs_chunk\n ddA_cs_ptr += pid_b * stride_ddA_cs_batch + pid_h * stride_ddA_cs_head + (nchunks - 1) * stride_ddA_cs_chunk + pid_m\n out_ptr += pid_b * stride_out_batch + pid_h * stride_out_head + (nchunks - 1) * stride_out_chunk\n dout_ptr += pid_b * stride_dout_batch + pid_h * stride_dout_head + (nchunks - 1) * stride_dout_chunk\n if CONVERT_STATES:\n states_converted_ptr += pid_b * stride_out_batch + pid_h * stride_out_head + (nchunks - 1) * stride_out_chunk\n if HAS_DFINAL_STATES:\n dfinal_states_ptr += pid_b * stride_dfinal_states_batch + pid_h * stride_dfinal_states_head\n if HAS_DINITSTATES:\n dinitstates_ptr += pid_b * stride_dinitstates_batch + pid_h * stride_dinitstates_head\n if HAS_SEQ_IDX:\n seq_idx_ptr += pid_b * stride_seq_idx_batch\n\n offs_m = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)\n dstates_ptrs = dstates_ptr + offs_m * stride_dstates_dim\n out_ptrs = out_ptr + offs_m * stride_out_dim\n dout_ptrs = dout_ptr + offs_m * stride_dout_dim\n if CONVERT_STATES:\n states_converted_ptrs = states_converted_ptr + offs_m * stride_out_dim\n\n if HAS_DFINAL_STATES:\n dstates = tl.load(dfinal_states_ptr + offs_m * stride_dfinal_states_dim, mask=offs_m < dim, other=0.0).to(tl.float32)\n else:\n dstates = tl.zeros((BLOCK_SIZE, ), dtype=tl.float32)\n tl.store(dstates_ptrs, dstates, mask=offs_m < dim)\n if HAS_SEQ_IDX:\n seq_idx = tl.load(seq_idx_ptr + (seqlen - 1) * stride_seq_idx_seqlen)\n dstates_ptrs -= stride_dstates_chunk\n for c in range(nchunks - 1):\n dA_cs = tl.load(dA_cs_ptr).to(tl.float32)\n scale = tl.exp(dA_cs)\n if HAS_SEQ_IDX:\n seq_idx_new = tl.load(seq_idx_ptr + (((nchunks - c - 1) * chunk_size - 1) * stride_seq_idx_seqlen))\n scale = tl.where(seq_idx_new == seq_idx, scale, 0.0)\n seq_idx = seq_idx_new\n out = tl.load(out_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)\n if CONVERT_STATES:\n tl.store(states_converted_ptrs, out, mask=offs_m < dim)\n ddA = tl.sum(out * dstates) * scale\n tl.store(ddA_cs_ptr, ddA)\n dout = tl.load(dout_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)\n dstates = scale * dstates + dout\n tl.store(dstates_ptrs, dstates, mask=offs_m < dim)\n dout_ptrs -= stride_dout_chunk\n dstates_ptrs -= stride_dstates_chunk\n dA_cs_ptr -= stride_dA_cs_chunk\n ddA_cs_ptr -= stride_ddA_cs_chunk\n out_ptrs -= stride_out_chunk\n if CONVERT_STATES:\n states_converted_ptrs -= stride_out_chunk\n if CONVERT_STATES:\n out = tl.load(out_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)\n tl.store(states_converted_ptrs, out, mask=offs_m < dim)\n if not HAS_DINITSTATES:\n tl.store(ddA_cs_ptr, 0.0)\n else:\n dA_cs = tl.load(dA_cs_ptr).to(tl.float32)\n scale = tl.exp(dA_cs)\n if HAS_SEQ_IDX:\n scale = tl.where(seq_idx == 0, scale, 0.0)\n out = tl.load(out_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)\n ddA = tl.sum(out * dstates) * scale\n tl.store(ddA_cs_ptr, ddA)\n dout = tl.load(dout_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)\n dstates = scale * dstates + dout\n tl.store(dinitstates_ptr + offs_m * stride_dinitstates_dim, dstates, mask=offs_m < dim)\n\n\ndef _state_passing_fwd(states, dA_chunk_cumsum, initial_states=None, seq_idx=None, chunk_size=None,\n out_dtype=None):\n batch, nchunks, nheads, dim = states.shape\n assert dA_chunk_cumsum.shape == (batch, nheads, nchunks)\n if initial_states is not None:\n assert initial_states.shape == (batch, nheads, dim)\n if seq_idx is not None:\n assert chunk_size is not None\n seqlen = seq_idx.shape[-1]\n assert seq_idx.shape == (batch, seqlen)\n out_dtype = states.dtype if out_dtype is None else out_dtype\n out = torch.empty((batch, nchunks, nheads, dim), device=states.device, dtype=out_dtype)\n final_states = torch.empty((batch, nheads, dim), device=states.device, dtype=torch.float32)\n grid = lambda META: (triton.cdiv(dim, META['BLOCK_SIZE']), batch, nheads)\n with torch.cuda.device(states.device.index):\n _state_passing_fwd_kernel[grid](\n states, out, final_states, dA_chunk_cumsum, initial_states, seq_idx,\n dim, nchunks, seqlen if seq_idx is not None else 0, chunk_size if seq_idx is not None else 0,\n states.stride(0), states.stride(1), states.stride(2), states.stride(3),\n out.stride(0), out.stride(1), out.stride(2), out.stride(3),\n final_states.stride(0), final_states.stride(1), final_states.stride(2),\n dA_chunk_cumsum.stride(0), dA_chunk_cumsum.stride(2), dA_chunk_cumsum.stride(1),\n *((initial_states.stride(0), initial_states.stride(1), initial_states.stride(2))\n if initial_states is not None else (0, 0, 0)),\n *((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)),\n HAS_INITSTATES=initial_states is not None,\n HAS_SEQ_IDX=seq_idx is not None,\n )\n return out, final_states\n\n\ndef _state_passing_bwd(\n states, dA_chunk_cumsum, dout, dfinal_states=None, seq_idx=None, has_initial_states=None,\n dstates_dtype=None, states_dtype=None, chunk_size=None\n):\n \"\"\"\n states contains the initial_states at index 0. The final states are not included in states.\n \"\"\"\n batch, nchunks, nheads, dim = states.shape\n assert dA_chunk_cumsum.shape == (batch, nheads, nchunks)\n assert dout.shape == (batch, nchunks, nheads, dim)\n if seq_idx is not None:\n assert chunk_size is not None\n seqlen = seq_idx.shape[-1]\n assert seq_idx.shape == (batch, seqlen)\n dstates = torch.empty_like(dout, dtype=dstates_dtype if dstates_dtype is not None else dout.dtype)\n if states_dtype is not None and states_dtype != states.dtype:\n states_converted = torch.empty_like(states, dtype=dstates_dtype if dstates_dtype is not None else dout.dtype)\n assert states_converted.stride() == states.stride()\n else:\n states_converted = None\n if has_initial_states:\n dinitstates = torch.empty_like(dstates[:, 0])\n else:\n dinitstates = None\n if dfinal_states is not None:\n assert dfinal_states.shape == (batch, nheads, dim)\n BLOCK_SIZE_min = 64\n n_blocks = (dim + BLOCK_SIZE_min - 1) // BLOCK_SIZE_min\n ddA_chunk_cumsum = torch.empty(batch, nheads, nchunks, n_blocks,\n dtype=torch.float32, device=dA_chunk_cumsum.device)\n grid = lambda META: (triton.cdiv(dim, META['BLOCK_SIZE']), batch, nheads)\n with torch.cuda.device(dout.device.index):\n _state_passing_bwd_kernel[grid](\n dout, states, dA_chunk_cumsum, dfinal_states, seq_idx,\n dstates, ddA_chunk_cumsum, dinitstates, states_converted,\n dim, nchunks, seqlen if seq_idx is not None else 0, chunk_size if seq_idx is not None else 0,\n dout.stride(0), dout.stride(1), dout.stride(2), dout.stride(3),\n states.stride(0), states.stride(1), states.stride(2), states.stride(3),\n dA_chunk_cumsum.stride(0), dA_chunk_cumsum.stride(2), dA_chunk_cumsum.stride(1),\n *((dfinal_states.stride(0), dfinal_states.stride(1), dfinal_states.stride(2))\n if dfinal_states is not None else (0, 0, 0)),\n *((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)),\n dstates.stride(0), dstates.stride(1), dstates.stride(2), dstates.stride(3),\n ddA_chunk_cumsum.stride(0), ddA_chunk_cumsum.stride(2), ddA_chunk_cumsum.stride(1),\n *((dinitstates.stride(0), dinitstates.stride(1), dinitstates.stride(2))\n if dinitstates is not None else (0, 0, 0)),\n CONVERT_STATES=states_converted is not None,\n HAS_DFINAL_STATES=dfinal_states is not None,\n HAS_DINITSTATES=dinitstates is not None,\n HAS_SEQ_IDX=seq_idx is not None,\n )\n BLOCK_SIZE_actual = _state_passing_bwd_kernel.best_config.kwargs[\"BLOCK_SIZE\"]\n n_valid_blocks = (dim + BLOCK_SIZE_actual - 1) // BLOCK_SIZE_actual\n ddA_chunk_cumsum = ddA_chunk_cumsum[..., :n_valid_blocks].sum(dim=-1).to(dtype=dA_chunk_cumsum.dtype)\n if states_dtype is not None and states_dtype == states.dtype:\n states_converted = states\n return (dstates, ddA_chunk_cumsum, dinitstates) if states_dtype is None else (dstates, ddA_chunk_cumsum, dinitstates, states_converted)\n", - "description_1": "Use triton language to implement forward and backward kernels for state passing operations. The forward kernel has 33 parameters, handling input states, output pointers, initial states, sequence indices, dimensions, strides, and meta-parameters for initialization and sequence indexing. The backward kernel has 35 parameters for processing gradient information and similarly requires input, output, dimension, stride, and meta-parameter inputs. Both are highly parameterized to ensure efficient matrix operations and flexibility in handling optional inputs.", - "description_2": "Use triton language to create forward and backward state passing kernels. Forward kernel processes inputs with sequence handling, while backward kernel calculates gradients and supports optional state conversions. Both require precise configuration of dimensions, strides, and optional parameters for full functionality.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.autotune(\n configs=[\n triton.Config({}, num_warps=1),\n triton.Config({}, num_warps=2),\n triton.Config({}, num_warps=4),\n triton.Config({}, num_warps=8),\n triton.Config({}, num_warps=16),\n triton.Config({}, num_warps=32),\n ],\n key=[\"N\", \"HAS_RESIDUAL\", \"STORE_RESIDUAL_OUT\", \"IS_RMS_NORM\", \"HAS_BIAS\"],\n)\n@triton.jit\ndef _layer_norm_fwd_1pass_kernel(\n X, Y, W, B, RESIDUAL, RESIDUAL_OUT, Mean, Rstd,\n stride_x_row, stride_y_row, stride_res_row, stride_res_out_row,\n N, eps, IS_RMS_NORM: tl.constexpr, BLOCK_N: tl.constexpr,\n HAS_RESIDUAL: tl.constexpr, STORE_RESIDUAL_OUT: tl.constexpr, HAS_BIAS: tl.constexpr,\n):\n row = tl.program_id(0)\n X += row * stride_x_row\n Y += row * stride_y_row\n if HAS_RESIDUAL:\n RESIDUAL += row * stride_res_row\n if STORE_RESIDUAL_OUT:\n RESIDUAL_OUT += row * stride_res_out_row\n cols = tl.arange(0, BLOCK_N)\n x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)\n if HAS_RESIDUAL:\n residual = tl.load(RESIDUAL + cols, mask=cols < N, other=0.0).to(tl.float32)\n x += residual\n if STORE_RESIDUAL_OUT:\n tl.store(RESIDUAL_OUT + cols, x, mask=cols < N)\n if not IS_RMS_NORM:\n mean = tl.sum(x, axis=0) / N\n tl.store(Mean + row, mean)\n xbar = tl.where(cols < N, x - mean, 0.0)\n var = tl.sum(xbar * xbar, axis=0) / N\n else:\n xbar = tl.where(cols < N, x, 0.0)\n var = tl.sum(xbar * xbar, axis=0) / N\n rstd = 1 / tl.sqrt(var + eps)\n tl.store(Rstd + row, rstd)\n mask = cols < N\n w = tl.load(W + cols, mask=mask).to(tl.float32)\n if HAS_BIAS:\n b = tl.load(B + cols, mask=mask).to(tl.float32)\n x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd\n y = x_hat * w + b if HAS_BIAS else x_hat * w\n tl.store(Y + cols, y, mask=mask)\n\ndef _layer_norm_fwd(\n x, weight, bias, eps, residual=None, out_dtype=None, residual_dtype=None, is_rms_norm=False\n):\n if residual is not None:\n residual_dtype = residual.dtype\n M, N = x.shape\n assert x.stride(-1) == 1\n if residual is not None:\n assert residual.stride(-1) == 1\n assert residual.shape == (M, N)\n assert weight.shape == (N,)\n assert weight.stride(-1) == 1\n if bias is not None:\n assert bias.stride(-1) == 1\n assert bias.shape == (N,)\n y = torch.empty_like(x, dtype=x.dtype if out_dtype is None else out_dtype)\n assert y.stride(-1) == 1\n if residual is not None or (residual_dtype is not None and residual_dtype != x.dtype):\n residual_out = torch.empty(M, N, device=x.device, dtype=residual_dtype)\n assert residual_out.stride(-1) == 1\n else:\n residual_out = None\n mean = torch.empty((M,), dtype=torch.float32, device=x.device) if not is_rms_norm else None\n rstd = torch.empty((M,), dtype=torch.float32, device=x.device)\n MAX_FUSED_SIZE = 65536 // x.element_size()\n BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))\n if N > BLOCK_N:\n raise RuntimeError(\"This layer norm doesn't support feature dim >= 64KB.\")\n with torch.cuda.device(x.device.index):\n _layer_norm_fwd_1pass_kernel[(M,)](\n x,\n y,\n weight,\n bias,\n residual,\n residual_out,\n mean,\n rstd,\n x.stride(0),\n y.stride(0),\n residual.stride(0) if residual is not None else 0,\n residual_out.stride(0) if residual_out is not None else 0,\n N,\n eps,\n is_rms_norm,\n BLOCK_N,\n residual is not None,\n residual_out is not None,\n bias is not None,\n )\n return y, mean, rstd, residual_out if residual_out is not None else x\n", - "description_1": "Use triton language to implement a forward pass kernel for layer normalization with optional residuals and bias. The kernel computes the mean and variance of the input, normalizes it, and applies a linear transformation using weights and optional bias. The kernel is optimized with autotuning for different warp configurations.", - "description_2": "Use triton language to implement a forward pass for layer normalization with optional residuals and bias, optimized with autotuning.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef _selective_scan_update_kernel(\n # Pointers to matrices\n state_ptr, x_ptr, dt_ptr, dt_bias_ptr, A_ptr, B_ptr, C_ptr, D_ptr, z_ptr, out_ptr,\n # Matrix dimensions\n batch, dim, dstate,\n # Strides\n stride_state_batch, stride_state_dim, stride_state_dstate,\n stride_x_batch, stride_x_dim,\n stride_dt_batch, stride_dt_dim,\n stride_dt_bias_dim,\n stride_A_dim, stride_A_dstate,\n stride_B_batch, stride_B_dstate,\n stride_C_batch, stride_C_dstate,\n stride_D_dim,\n stride_z_batch, stride_z_dim,\n stride_out_batch, stride_out_dim,\n # Meta-parameters\n DT_SOFTPLUS: tl.constexpr,\n BLOCK_SIZE_M: tl.constexpr,\n HAS_DT_BIAS: tl.constexpr,\n HAS_D: tl.constexpr,\n HAS_Z: tl.constexpr,\n BLOCK_SIZE_DSTATE: tl.constexpr,\n):\n pid_m = tl.program_id(axis=0)\n pid_b = tl.program_id(axis=1)\n state_ptr += pid_b * stride_state_batch\n x_ptr += pid_b * stride_x_batch\n dt_ptr += pid_b * stride_dt_batch\n B_ptr += pid_b * stride_B_batch\n C_ptr += pid_b * stride_C_batch\n if HAS_Z:\n z_ptr += pid_b * stride_z_batch\n out_ptr += pid_b * stride_out_batch\n\n offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_n = tl.arange(0, BLOCK_SIZE_DSTATE)\n state_ptrs = state_ptr + (offs_m[:, None] * stride_state_dim + offs_n[None, :] * stride_state_dstate)\n x_ptrs = x_ptr + offs_m * stride_x_dim\n dt_ptrs = dt_ptr + offs_m * stride_dt_dim\n if HAS_DT_BIAS:\n dt_bias_ptrs = dt_bias_ptr + offs_m * stride_dt_bias_dim\n A_ptrs = A_ptr + (offs_m[:, None] * stride_A_dim + offs_n[None, :] * stride_A_dstate)\n B_ptrs = B_ptr + offs_n * stride_B_dstate\n C_ptrs = C_ptr + offs_n * stride_C_dstate\n if HAS_D:\n D_ptrs = D_ptr + offs_m * stride_D_dim\n if HAS_Z:\n z_ptrs = z_ptr + offs_m * stride_z_dim\n out_ptrs = out_ptr + offs_m * stride_out_dim\n\n state = tl.load(state_ptrs, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate), other=0.0)\n x = tl.load(x_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)\n dt = tl.load(dt_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)\n if HAS_DT_BIAS:\n dt += tl.load(dt_bias_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)\n if DT_SOFTPLUS:\n dt = tl.log(1.0 + tl.exp(dt))\n A = tl.load(A_ptrs, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate), other=0.0).to(tl.float32)\n dA = tl.exp(A * dt[:, None])\n B = tl.load(B_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32)\n C = tl.load(C_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32)\n if HAS_D:\n D = tl.load(D_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)\n if HAS_Z:\n z = tl.load(z_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)\n\n dB = B[None, :] * dt[:, None]\n state = state * dA + dB * x[:, None]\n tl.store(state_ptrs, state, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate))\n out = tl.sum(state * C[None, :], axis=1)\n if HAS_D:\n out += x * D\n if HAS_Z:\n out *= z * tl.sigmoid(z)\n tl.store(out_ptrs, out, mask=offs_m < dim)\n\n\ndef selective_state_update(state, x, dt, A, B, C, D=None, z=None, dt_bias=None, dt_softplus=False):\n \"\"\"\n Argument:\n state: (batch, dim, dstate)\n x: (batch, dim)\n dt: (batch, dim)\n A: (dim, dstate)\n B: (batch, dstate)\n C: (batch, dstate)\n D: (dim,)\n z: (batch, dim)\n dt_bias: (dim,)\n Return:\n out: (batch, dim)\n \"\"\"\n batch, dim, dstate = state.shape\n assert x.shape == (batch, dim)\n assert dt.shape == x.shape\n assert A.shape == (dim, dstate)\n assert B.shape == (batch, dstate)\n assert C.shape == B.shape\n if D is not None:\n assert D.shape == (dim,)\n if z is not None:\n assert z.shape == x.shape\n if dt_bias is not None:\n assert dt_bias.shape == (dim,)\n out = torch.empty_like(x)\n grid = lambda META: (triton.cdiv(dim, META['BLOCK_SIZE_M']), batch)\n z_strides = ((z.stride(0), z.stride(1)) if z is not None else (0, 0))\n # We don't want autotune since it will overwrite the state\n # We instead tune by hand.\n BLOCK_SIZE_M, num_warps = ((32, 4) if dstate <= 16\n else ((16, 4) if dstate <= 32 else\n ((8, 4) if dstate <= 64 else\n ((4, 4) if dstate <= 128 else\n ((4, 8))))))\n with torch.cuda.device(x.device.index):\n _selective_scan_update_kernel[grid](\n state, x, dt, dt_bias, A, B, C, D, z, out,\n batch, dim, dstate,\n state.stride(0), state.stride(1), state.stride(2),\n x.stride(0), x.stride(1),\n dt.stride(0), dt.stride(1),\n dt_bias.stride(0) if dt_bias is not None else 0,\n A.stride(0), A.stride(1),\n B.stride(0), B.stride(1),\n C.stride(0), C.stride(1),\n D.stride(0) if D is not None else 0,\n z_strides[0], z_strides[1],\n out.stride(0), out.stride(1),\n dt_softplus,\n BLOCK_SIZE_M,\n num_warps=num_warps,\n )\n return out\n", - "description_1": "Use triton language to implement a kernel function '_selective_scan_update_kernel' with 35 parameters for updating state matrices with optional bias and scaling, and a wrapper function 'selective_state_update' with 10 parameters to call the kernel with appropriate grid and block size configurations.", - "description_2": "Use triton language to create a kernel for selective state update with optional bias and scaling, and a wrapper to configure and launch the kernel.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.autotune(\n configs=[\n triton.Config({}, num_warps=1),\n triton.Config({}, num_warps=2),\n triton.Config({}, num_warps=4),\n triton.Config({}, num_warps=8),\n triton.Config({}, num_warps=16),\n triton.Config({}, num_warps=32),\n ],\n key=[\"N\", \"HAS_RESIDUAL\", \"STORE_RESIDUAL_OUT\", \"IS_RMS_NORM\", \"HAS_BIAS\"],\n)\n@triton.jit\ndef _layer_norm_fwd_1pass_kernel(\n X, # pointer to the input\n Y, # pointer to the output\n W, # pointer to the weights\n B, # pointer to the biases\n RESIDUAL, # pointer to the residual\n RESIDUAL_OUT, # pointer to the residual\n Mean, # pointer to the mean\n Rstd, # pointer to the 1/std\n stride_x_row, # how much to increase the pointer when moving by 1 row\n stride_y_row,\n stride_res_row,\n stride_res_out_row,\n N, # number of columns in X\n eps, # epsilon to avoid division by zero\n IS_RMS_NORM: tl.constexpr,\n BLOCK_N: tl.constexpr,\n HAS_RESIDUAL: tl.constexpr,\n STORE_RESIDUAL_OUT: tl.constexpr,\n HAS_BIAS: tl.constexpr,\n):\n # Map the program id to the row of X and Y it should compute.\n row = tl.program_id(0)\n X += row * stride_x_row\n Y += row * stride_y_row\n if HAS_RESIDUAL:\n RESIDUAL += row * stride_res_row\n if STORE_RESIDUAL_OUT:\n RESIDUAL_OUT += row * stride_res_out_row\n # Compute mean and variance\n cols = tl.arange(0, BLOCK_N)\n x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)\n if HAS_RESIDUAL:\n residual = tl.load(RESIDUAL + cols, mask=cols < N, other=0.0).to(tl.float32)\n x += residual\n if STORE_RESIDUAL_OUT:\n tl.store(RESIDUAL_OUT + cols, x, mask=cols < N)\n if not IS_RMS_NORM:\n mean = tl.sum(x, axis=0) / N\n tl.store(Mean + row, mean)\n xbar = tl.where(cols < N, x - mean, 0.0)\n var = tl.sum(xbar * xbar, axis=0) / N\n else:\n xbar = tl.where(cols < N, x, 0.0)\n var = tl.sum(xbar * xbar, axis=0) / N\n rstd = 1 / tl.sqrt(var + eps)\n tl.store(Rstd + row, rstd)\n # Normalize and apply linear transformation\n mask = cols < N\n w = tl.load(W + cols, mask=mask).to(tl.float32)\n if HAS_BIAS:\n b = tl.load(B + cols, mask=mask).to(tl.float32)\n x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd\n y = x_hat * w + b if HAS_BIAS else x_hat * w\n # Write output\n tl.store(Y + cols, y, mask=mask)\n\ndef _layer_norm_fwd(\n x, weight, bias, eps, residual=None, out_dtype=None, residual_dtype=None, is_rms_norm=False\n):\n if residual is not None:\n residual_dtype = residual.dtype\n M, N = x.shape\n assert x.stride(-1) == 1\n if residual is not None:\n assert residual.stride(-1) == 1\n assert residual.shape == (M, N)\n assert weight.shape == (N,)\n assert weight.stride(-1) == 1\n if bias is not None:\n assert bias.stride(-1) == 1\n assert bias.shape == (N,)\n # allocate output\n y = torch.empty_like(x, dtype=x.dtype if out_dtype is None else out_dtype)\n assert y.stride(-1) == 1\n if residual is not None or (residual_dtype is not None and residual_dtype != x.dtype):\n residual_out = torch.empty(M, N, device=x.device, dtype=residual_dtype)\n assert residual_out.stride(-1) == 1\n else:\n residual_out = None\n mean = torch.empty((M,), dtype=torch.float32, device=\"cuda\") if not is_rms_norm else None\n rstd = torch.empty((M,), dtype=torch.float32, device=\"cuda\")\n # Less than 64KB per feature: enqueue fused kernel\n MAX_FUSED_SIZE = 65536 // x.element_size()\n BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))\n if N > BLOCK_N:\n raise RuntimeError(\"This layer norm doesn't support feature dim >= 64KB.\")\n # heuristics for number of warps\n with torch.cuda.device(x.device.index):\n _layer_norm_fwd_1pass_kernel[(M,)](\n x,\n y,\n weight,\n bias,\n residual,\n residual_out,\n mean,\n rstd,\n x.stride(0),\n y.stride(0),\n residual.stride(0) if residual is not None else 0,\n residual_out.stride(0) if residual_out is not None else 0,\n N,\n eps,\n is_rms_norm,\n BLOCK_N,\n residual is not None,\n residual_out is not None,\n bias is not None,\n )\n # residual_out is None if residual is None and residual_dtype == input_dtype\n return y, mean, rstd, residual_out if residual_out is not None else x\n\n@triton.autotune(\n configs=[\n triton.Config({}, num_warps=1),\n triton.Config({}, num_warps=2),\n triton.Config({}, num_warps=4),\n triton.Config({}, num_warps=8),\n triton.Config({}, num_warps=16),\n triton.Config({}, num_warps=32),\n ],\n key=[\"N\", \"HAS_DRESIDUAL\", \"STORE_DRESIDUAL\", \"IS_RMS_NORM\", \"HAS_BIAS\"],\n)\n@triton.heuristics({\"RECOMPUTE_OUTPUT\": lambda args: args[\"Y\"] is not None})\n@triton.jit\ndef _layer_norm_bwd_kernel(\n X, # pointer to the input\n W, # pointer to the weights\n B, # pointer to the biases\n Y, # pointer to the output to be recomputed\n DY, # pointer to the output gradient\n DX, # pointer to the input gradient\n DW, # pointer to the partial sum of weights gradient\n DB, # pointer to the partial sum of biases gradient\n DRESIDUAL,\n DRESIDUAL_IN,\n Mean, # pointer to the mean\n Rstd, # pointer to the 1/std\n stride_x_row, # how much to increase the pointer when moving by 1 row\n stride_y_row,\n stride_dy_row,\n stride_dx_row,\n stride_dres_row,\n stride_dres_in_row,\n M, # number of rows in X\n N, # number of columns in X\n eps, # epsilon to avoid division by zero\n rows_per_program,\n IS_RMS_NORM: tl.constexpr,\n BLOCK_N: tl.constexpr,\n HAS_DRESIDUAL: tl.constexpr,\n STORE_DRESIDUAL: tl.constexpr,\n HAS_BIAS: tl.constexpr,\n RECOMPUTE_OUTPUT: tl.constexpr,\n):\n # Map the program id to the elements of X, DX, and DY it should compute.\n row_block_id = tl.program_id(0)\n row_start = row_block_id * rows_per_program\n cols = tl.arange(0, BLOCK_N)\n mask = cols < N\n X += row_start * stride_x_row\n if HAS_DRESIDUAL:\n DRESIDUAL += row_start * stride_dres_row\n if STORE_DRESIDUAL:\n DRESIDUAL_IN += row_start * stride_dres_in_row\n DY += row_start * stride_dy_row\n DX += row_start * stride_dx_row\n if RECOMPUTE_OUTPUT:\n Y += row_start * stride_y_row\n w = tl.load(W + cols, mask=mask).to(tl.float32)\n if RECOMPUTE_OUTPUT and HAS_BIAS:\n b = tl.load(B + cols, mask=mask, other=0.0).to(tl.float32)\n dw = tl.zeros((BLOCK_N,), dtype=tl.float32)\n if HAS_BIAS:\n db = tl.zeros((BLOCK_N,), dtype=tl.float32)\n row_end = min((row_block_id + 1) * rows_per_program, M)\n for row in range(row_start, row_end):\n # Load data to SRAM\n x = tl.load(X + cols, mask=mask, other=0).to(tl.float32)\n dy = tl.load(DY + cols, mask=mask, other=0).to(tl.float32)\n if not IS_RMS_NORM:\n mean = tl.load(Mean + row)\n rstd = tl.load(Rstd + row)\n # Compute dx\n xhat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd\n xhat = tl.where(mask, xhat, 0.0)\n if RECOMPUTE_OUTPUT:\n y = xhat * w + b if HAS_BIAS else xhat * w\n tl.store(Y + cols, y, mask=mask)\n wdy = w * dy\n dw += dy * xhat\n if HAS_BIAS:\n db += dy\n if not IS_RMS_NORM:\n c1 = tl.sum(xhat * wdy, axis=0) / N\n c2 = tl.sum(wdy, axis=0) / N\n dx = (wdy - (xhat * c1 + c2)) * rstd\n else:\n c1 = tl.sum(xhat * wdy, axis=0) / N\n dx = (wdy - xhat * c1) * rstd\n if HAS_DRESIDUAL:\n dres = tl.load(DRESIDUAL + cols, mask=mask, other=0).to(tl.float32)\n dx += dres\n # Write dx\n if STORE_DRESIDUAL:\n tl.store(DRESIDUAL_IN + cols, dx, mask=mask)\n tl.store(DX + cols, dx, mask=mask)\n\n X += stride_x_row\n if HAS_DRESIDUAL:\n DRESIDUAL += stride_dres_row\n if STORE_DRESIDUAL:\n DRESIDUAL_IN += stride_dres_in_row\n if RECOMPUTE_OUTPUT:\n Y += stride_y_row\n DY += stride_dy_row\n DX += stride_dx_row\n tl.store(DW + row_block_id * N + cols, dw, mask=mask)\n if HAS_BIAS:\n tl.store(DB + row_block_id * N + cols, db, mask=mask)\n\ndef _layer_norm_bwd(\n dy,\n x,\n weight,\n bias,\n eps,\n mean,\n rstd,\n dresidual=None,\n has_residual=False,\n is_rms_norm=False,\n x_dtype=None,\n recompute_output=False,\n):\n M, N = x.shape\n assert x.stride(-1) == 1\n assert dy.stride(-1) == 1\n assert dy.shape == (M, N)\n if dresidual is not None:\n assert dresidual.stride(-1) == 1\n assert dresidual.shape == (M, N)\n assert weight.shape == (N,)\n assert weight.stride(-1) == 1\n if bias is not None:\n assert bias.stride(-1) == 1\n assert bias.shape == (N,)\n # allocate output\n dx = (\n torch.empty_like(x)\n if x_dtype is None\n else torch.empty(M, N, dtype=x_dtype, device=x.device)\n )\n dresidual_in = torch.empty_like(x) if has_residual and dx.dtype != x.dtype else None\n y = torch.empty(M, N, dtype=dy.dtype, device=dy.device) if recompute_output else None\n\n # Less than 64KB per feature: enqueue fused kernel\n MAX_FUSED_SIZE = 65536 // x.element_size()\n BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))\n if N > BLOCK_N:\n raise RuntimeError(\"This layer norm doesn't support feature dim >= 64KB.\")\n sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count\n _dw = torch.empty((sm_count, N), dtype=torch.float32, device=weight.device)\n _db = (\n torch.empty((sm_count, N), dtype=torch.float32, device=bias.device)\n if bias is not None\n else None\n )\n rows_per_program = math.ceil(M / sm_count)\n grid = (sm_count,)\n with torch.cuda.device(x.device.index):\n _layer_norm_bwd_kernel[grid](\n x,\n weight,\n bias,\n y,\n dy,\n dx,\n _dw,\n _db,\n dresidual,\n dresidual_in,\n mean,\n rstd,\n x.stride(0),\n 0 if not recompute_output else y.stride(0),\n dy.stride(0),\n dx.stride(0),\n dresidual.stride(0) if dresidual is not None else 0,\n dresidual_in.stride(0) if dresidual_in is not None else 0,\n M,\n N,\n eps,\n rows_per_program,\n is_rms_norm,\n BLOCK_N,\n dresidual is not None,\n dresidual_in is not None,\n bias is not None,\n )\n dw = _dw.sum(0).to(weight.dtype)\n db = _db.sum(0).to(bias.dtype) if bias is not None else None\n # Don't need to compute dresidual_in separately in this case\n if has_residual and dx.dtype == x.dtype:\n dresidual_in = dx\n return (dx, dw, db, dresidual_in) if not recompute_output else (dx, dw, db, dresidual_in, y)\n", - "description_1": "Use triton language to define two kernels: _layer_norm_fwd_1pass_kernel and _layer_norm_bwd_kernel. The _layer_norm_fwd_1pass_kernel takes 18 parameters, including pointers to input/output tensors and constants for normalization, and performs a forward pass of layer normalization. The _layer_norm_bwd_kernel takes 26 parameters, including pointers to input/output tensors and gradients, and performs a backward pass, computing gradients for input, weights, and bias.", - "description_2": "Use triton language to perform efficient forward and backward passes for layer normalization using kernels with multiple configurations for automatic tuning.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef _selective_scan_update_kernel(\n # Pointers to matrices\n state_ptr, x_ptr, dt_ptr, dt_bias_ptr, A_ptr, B_ptr, C_ptr, D_ptr, z_ptr, out_ptr,\n # Matrix dimensions\n batch, dim, dstate,\n # Strides\n stride_state_batch, stride_state_dim, stride_state_dstate,\n stride_x_batch, stride_x_dim,\n stride_dt_batch, stride_dt_dim,\n stride_dt_bias_dim,\n stride_A_dim, stride_A_dstate,\n stride_B_batch, stride_B_dstate,\n stride_C_batch, stride_C_dstate,\n stride_D_dim,\n stride_z_batch, stride_z_dim,\n stride_out_batch, stride_out_dim,\n # Meta-parameters\n DT_SOFTPLUS: tl.constexpr,\n BLOCK_SIZE_M: tl.constexpr,\n HAS_DT_BIAS: tl.constexpr,\n HAS_D: tl.constexpr,\n HAS_Z: tl.constexpr,\n BLOCK_SIZE_DSTATE: tl.constexpr,\n):\n pid_m = tl.program_id(axis=0)\n pid_b = tl.program_id(axis=1)\n state_ptr += pid_b * stride_state_batch\n x_ptr += pid_b * stride_x_batch\n dt_ptr += pid_b * stride_dt_batch\n B_ptr += pid_b * stride_B_batch\n C_ptr += pid_b * stride_C_batch\n if HAS_Z:\n z_ptr += pid_b * stride_z_batch\n out_ptr += pid_b * stride_out_batch\n\n offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_n = tl.arange(0, BLOCK_SIZE_DSTATE)\n state_ptrs = state_ptr + (offs_m[:, None] * stride_state_dim + offs_n[None, :] * stride_state_dstate)\n x_ptrs = x_ptr + offs_m * stride_x_dim\n dt_ptrs = dt_ptr + offs_m * stride_dt_dim\n if HAS_DT_BIAS:\n dt_bias_ptrs = dt_bias_ptr + offs_m * stride_dt_bias_dim\n A_ptrs = A_ptr + (offs_m[:, None] * stride_A_dim + offs_n[None, :] * stride_A_dstate)\n B_ptrs = B_ptr + offs_n * stride_B_dstate\n C_ptrs = C_ptr + offs_n * stride_C_dstate\n if HAS_D:\n D_ptrs = D_ptr + offs_m * stride_D_dim\n if HAS_Z:\n z_ptrs = z_ptr + offs_m * stride_z_dim\n out_ptrs = out_ptr + offs_m * stride_out_dim\n\n state = tl.load(state_ptrs, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate), other=0.0)\n x = tl.load(x_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)\n dt = tl.load(dt_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)\n if HAS_DT_BIAS:\n dt += tl.load(dt_bias_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)\n if DT_SOFTPLUS:\n dt = tl.log(1.0 + tl.exp(dt))\n A = tl.load(A_ptrs, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate), other=0.0).to(tl.float32)\n dA = tl.exp(A * dt[:, None])\n B = tl.load(B_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32)\n C = tl.load(C_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32)\n if HAS_D:\n D = tl.load(D_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)\n if HAS_Z:\n z = tl.load(z_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)\n\n dB = B[None, :] * dt[:, None]\n state = state * dA + dB * x[:, None]\n tl.store(state_ptrs, state, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate))\n out = tl.sum(state * C[None, :], axis=1)\n if HAS_D:\n out += x * D\n if HAS_Z:\n out *= z * tl.sigmoid(z)\n tl.store(out_ptrs, out, mask=offs_m < dim)\n\n\ndef selective_state_update(state, x, dt, A, B, C, D=None, z=None, dt_bias=None, dt_softplus=False):\n \"\"\"\n Argument:\n state: (batch, dim, dstate)\n x: (batch, dim)\n dt: (batch, dim)\n A: (dim, dstate)\n B: (batch, dstate)\n C: (batch, dstate)\n D: (dim,)\n z: (batch, dim)\n dt_bias: (dim,)\n Return:\n out: (batch, dim)\n \"\"\"\n batch, dim, dstate = state.shape\n assert x.shape == (batch, dim)\n assert dt.shape == x.shape\n assert A.shape == (dim, dstate)\n assert B.shape == (batch, dstate)\n assert C.shape == B.shape\n if D is not None:\n assert D.shape == (dim,)\n if z is not None:\n assert z.shape == x.shape\n if dt_bias is not None:\n assert dt_bias.shape == (dim,)\n out = torch.empty_like(x)\n grid = lambda META: (triton.cdiv(dim, META['BLOCK_SIZE_M']), batch)\n z_strides = ((z.stride(0), z.stride(1)) if z is not None else (0, 0))\n # We don't want autotune since it will overwrite the state\n # We instead tune by hand.\n BLOCK_SIZE_M, num_warps = ((32, 4) if dstate <= 16\n else ((16, 4) if dstate <= 32 else\n ((8, 4) if dstate <= 64 else\n ((4, 4) if dstate <= 128 else\n ((4, 8))))))\n with torch.cuda.device(x.device.index):\n _selective_scan_update_kernel[grid](\n state, x, dt, dt_bias, A, B, C, D, z, out,\n batch, dim, dstate,\n state.stride(0), state.stride(1), state.stride(2),\n x.stride(0), x.stride(1),\n dt.stride(0), dt.stride(1),\n dt_bias.stride(0) if dt_bias is not None else 0,\n A.stride(0), A.stride(1),\n B.stride(0), B.stride(1),\n C.stride(0), C.stride(1),\n D.stride(0) if D is not None else 0,\n z_strides[0], z_strides[1],\n out.stride(0), out.stride(1),\n dt_softplus,\n BLOCK_SIZE_M,\n num_warps=num_warps,\n )\n return out\n", - "description_1": "Use triton language to implement a kernel function '_selective_scan_update_kernel' with 35 parameters for updating state matrices with optional bias and scaling, and a wrapper function 'selective_state_update' with 10 parameters to manage input/output tensors and launch the kernel.", - "description_2": "Use triton language to create a kernel for selective state update with optional bias and scaling, and a wrapper to handle tensor operations and kernel execution.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.autotune(\n configs=[\n triton.Config({'BLOCK_N': 32}),\n triton.Config({'BLOCK_N': 64}),\n triton.Config({'BLOCK_N': 128}),\n triton.Config({'BLOCK_N': 256}),\n triton.Config({'BLOCK_N': 512}),\n triton.Config({'BLOCK_N': 1024}),\n ],\n key=['ncols'],\n)\n@triton.jit\ndef _swiglu_fwd_kernel(\n X,\n Y,\n OUT,\n stride_x_row, # how much to increase the pointer when moving by 1 row\n stride_y_row,\n stride_out_row,\n ncols,\n BLOCK_N: tl.constexpr,\n):\n # Map the program id to the row of X and Y it should compute.\n row = tl.program_id(0)\n start_col = tl.program_id(1) * BLOCK_N\n X += row * stride_x_row\n Y += row * stride_y_row\n OUT += row * stride_out_row\n cols = start_col + tl.arange(0, BLOCK_N)\n x = tl.load(X + cols, mask=cols < ncols, other=0.).to(tl.float32)\n y = tl.load(Y + cols, mask=cols < ncols, other=0.).to(tl.float32)\n out = x * tl.sigmoid(x) * y\n tl.store(OUT + cols, out, mask=cols < ncols)\n\ndef _swiglu_fwd(xy, out=None):\n if xy.stride(-1) != 1:\n xy = xy.contiguous()\n batch_shape = xy.shape[:-1]\n xy = xy.reshape(-1, xy.shape[-1])\n x, y = xy.chunk(2, dim=-1)\n if out is None:\n out = torch.empty_like(x)\n else:\n out = out.reshape(-1, out.shape[-1])\n assert out.shape == x.shape\n assert out.stride(-1) == 1\n M, N = x.shape\n grid = lambda META: (M, triton.cdiv(N, META['BLOCK_N']))\n with torch.cuda.device(x.device.index):\n _swiglu_fwd_kernel[grid](x, y, out, x.stride(0), y.stride(0), out.stride(0), N)\n return out.reshape(*batch_shape, out.shape[-1])\n\n@triton.autotune(\n configs=[\n triton.Config({'BLOCK_N': 32}),\n triton.Config({'BLOCK_N': 64}),\n triton.Config({'BLOCK_N': 128}),\n triton.Config({'BLOCK_N': 256}),\n triton.Config({'BLOCK_N': 512}),\n triton.Config({'BLOCK_N': 1024}),\n ],\n key=['ncols'],\n)\n@triton.heuristics({\"RECOMPUTE_OUTPUT\": lambda args: args[\"OUT\"] is not None})\n@triton.jit\ndef _swiglu_bwd_kernel(\n X,\n Y,\n DOUT,\n OUT,\n DX,\n DY,\n stride_x_row, # how much to increase the pointer when moving by 1 row\n stride_y_row,\n stride_dout_row,\n stride_out_row,\n stride_dx_row,\n stride_dy_row,\n ncols,\n BLOCK_N: tl.constexpr,\n RECOMPUTE_OUTPUT: tl.constexpr,\n):\n # Map the program id to the row of X and Y it should compute.\n row = tl.program_id(0)\n start_col = tl.program_id(1) * BLOCK_N\n X += row * stride_x_row\n Y += row * stride_y_row\n DOUT += row * stride_dout_row\n if RECOMPUTE_OUTPUT:\n OUT += row * stride_out_row\n DX += row * stride_dx_row\n DY += row * stride_dy_row\n cols = start_col + tl.arange(0, BLOCK_N)\n x = tl.load(X + cols, mask=cols < ncols, other=0.).to(tl.float32)\n y = tl.load(Y + cols, mask=cols < ncols, other=0.).to(tl.float32)\n dout = tl.load(DOUT + cols, mask=cols < ncols, other=0.).to(tl.float32)\n x_sigmoid = tl.sigmoid(x)\n dx = x_sigmoid * (1 + x * (1 - x_sigmoid)) * y * dout\n dy = x * x_sigmoid * dout\n tl.store(DX + cols, dx, mask=cols < ncols)\n tl.store(DY + cols, dy, mask=cols < ncols)\n if RECOMPUTE_OUTPUT:\n out = x * x_sigmoid * y\n tl.store(OUT + cols, out, mask=cols < ncols)\n\ndef _swiglu_bwd(xy, dout, dxy=None, recompute_output=False, out=None):\n if xy.stride(-1) != 1:\n xy = xy.contiguous()\n if dout.stride(-1) != 1:\n dout = dout.contiguous()\n batch_shape = xy.shape[:-1]\n xy = xy.reshape(-1, xy.shape[-1])\n x, y = xy.chunk(2, dim=-1)\n dout = dout.reshape(-1, dout.shape[-1])\n assert dout.shape == x.shape\n if dxy is None:\n dxy = torch.empty_like(xy)\n else:\n dxy = dxy.reshape(-1, dxy.shape[-1])\n assert dxy.shape == xy.shape\n dx, dy = dxy.chunk(2, dim=-1)\n assert dx.stride(-1) == 1\n assert dy.stride(-1) == 1\n if recompute_output:\n if out is None:\n out = torch.empty_like(x)\n else:\n out = out.reshape(-1, out.shape[-1])\n assert out.shape == x.shape\n assert out.stride(-1) == 1\n M, N = x.shape\n grid = lambda META: (M, triton.cdiv(N, META['BLOCK_N']))\n with torch.cuda.device(x.device.index):\n _swiglu_bwd_kernel[grid](x, y, dout, out if recompute_output else None, dx, dy,\n x.stride(0), y.stride(0), dout.stride(0),\n out.stride(0) if recompute_output else 0,\n dx.stride(0), dy.stride(0),\n N)\n if not recompute_output:\n return dxy.reshape(*batch_shape, dxy.shape[-1])\n else:\n return dxy.reshape(*batch_shape, dxy.shape[-1]), out.reshape(*batch_shape, out.shape[-1])\n", - "description_1": "Use triton language to implement two kernels: _swiglu_fwd_kernel and _swiglu_bwd_kernel. The _swiglu_fwd_kernel takes 7 parameters: X, Y, OUT, stride_x_row, stride_y_row, stride_out_row, and ncols. It computes the forward pass of the SwiGLU activation function using Triton. The _swiglu_bwd_kernel takes 14 parameters: X, Y, DOUT, OUT, DX, DY, stride_x_row, stride_y_row, stride_dout_row, stride_out_row, stride_dx_row, stride_dy_row, ncols, and RECOMPUTE_OUTPUT. It computes the backward pass of the SwiGLU activation function, optionally recomputing the output if RECOMPUTE_OUTPUT is true.", - "description_2": "Use triton language to implement forward and backward kernels for the SwiGLU activation function, handling input and output strides and optional recomputation of outputs.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.heuristics({\"HAS_BIAS\": lambda args: args[\"B\"] is not None})\n@triton.heuristics({\"HAS_Z\": lambda args: args[\"Z\"] is not None})\n@triton.jit\ndef _layer_norm_fwd_1pass_kernel(\n X, # pointer to the input\n Y, # pointer to the output\n W, # pointer to the weights\n B, # pointer to the biases\n Z, # pointer to the other branch\n Mean, # pointer to the mean\n Rstd, # pointer to the 1/std\n stride_x_row, # how much to increase the pointer when moving by 1 row\n stride_y_row,\n stride_z_row,\n M, # number of rows in X\n N, # number of columns in X\n eps, # epsilon to avoid division by zero\n BLOCK_N: tl.constexpr,\n HAS_BIAS: tl.constexpr,\n HAS_Z: tl.constexpr,\n NORM_BEFORE_GATE: tl.constexpr,\n IS_RMS_NORM: tl.constexpr,\n):\n # Map the program id to the row of X and Y it should compute.\n row = tl.program_id(0)\n group = tl.program_id(1)\n X += row * stride_x_row + group * N\n Y += row * stride_y_row + group * N\n if HAS_Z:\n Z += row * stride_z_row + group * N\n if not IS_RMS_NORM:\n Mean += group * M\n Rstd += group * M\n W += group * N\n if HAS_BIAS:\n B += group * N\n # Compute mean and variance\n cols = tl.arange(0, BLOCK_N)\n x = tl.load(X + cols, mask=cols < N, other=0.).to(tl.float32)\n if HAS_Z and not NORM_BEFORE_GATE:\n z = tl.load(Z + cols, mask=cols < N).to(tl.float32)\n x *= z * tl.sigmoid(z)\n if not IS_RMS_NORM:\n mean = tl.sum(x, axis=0) / N\n tl.store(Mean + row, mean)\n xbar = tl.where(cols < N, x - mean, 0.)\n var = tl.sum(xbar * xbar, axis=0) / N\n else:\n xbar = tl.where(cols < N, x, 0.)\n var = tl.sum(xbar * xbar, axis=0) / N\n rstd = 1 / tl.sqrt(var + eps)\n tl.store(Rstd + row, rstd)\n # Normalize and apply linear transformation\n mask = cols < N\n w = tl.load(W + cols, mask=mask).to(tl.float32)\n if HAS_BIAS:\n b = tl.load(B + cols, mask=mask).to(tl.float32)\n x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd\n y = x_hat * w + b if HAS_BIAS else x_hat * w\n if HAS_Z and NORM_BEFORE_GATE:\n z = tl.load(Z + cols, mask=mask).to(tl.float32)\n y *= z * tl.sigmoid(z)\n # Write output\n tl.store(Y + cols, y, mask=mask)\n\n\ndef _layer_norm_fwd(x, weight, bias, eps, z=None, out=None, group_size=None, norm_before_gate=True, is_rms_norm=False):\n M, N = x.shape\n if group_size is None:\n group_size = N\n assert N % group_size == 0\n ngroups = N // group_size\n assert x.stride(-1) == 1\n if z is not None:\n assert z.stride(-1) == 1\n assert z.shape == (M, N)\n assert weight.shape == (N,)\n assert weight.stride(-1) == 1\n if bias is not None:\n assert bias.stride(-1) == 1\n assert bias.shape == (N,)\n # allocate output\n if out is not None:\n assert out.shape == x.shape\n else:\n out = torch.empty_like(x)\n assert out.stride(-1) == 1\n mean = torch.empty((ngroups * M, ), dtype=torch.float32, device=x.device) if not is_rms_norm else None\n rstd = torch.empty((ngroups * M, ), dtype=torch.float32, device=x.device)\n # Less than 64KB per feature: enqueue fused kernel\n MAX_FUSED_SIZE = 65536 // x.element_size()\n BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(group_size))\n if group_size > BLOCK_N:\n raise RuntimeError(\"This layer norm doesn't support feature dim >= 64KB.\")\n # heuristics for number of warps\n num_warps = min(max(BLOCK_N // 256, 1), 8)\n grid = (M, ngroups)\n with torch.cuda.device(x.device.index):\n _layer_norm_fwd_1pass_kernel[grid](x, out, weight, bias, z, mean, rstd,\n x.stride(0), out.stride(0), z.stride(0) if z is not None else 0,\n M, group_size, eps,\n BLOCK_N=BLOCK_N,\n NORM_BEFORE_GATE=norm_before_gate,\n IS_RMS_NORM=is_rms_norm,\n num_warps=num_warps)\n return out, mean, rstd\n", - "description_1": "Use triton language to implement a forward pass of layer normalization with optional bias and gating mechanism. The kernel '_layer_norm_fwd_1pass_kernel' takes 17 parameters: pointers to input, output, weights, biases, optional other branch, mean, and 1/std, strides for input, output, and optional other branch, number of rows and columns in input, epsilon for numerical stability, and several compile-time constants. The function '_layer_norm_fwd' prepares data and launches the kernel with appropriate grid and block sizes.", - "description_2": "Use triton language to implement a forward pass of layer normalization with optional bias and gating mechanism. The kernel '_layer_norm_fwd_1pass_kernel' takes 17 parameters: pointers to input, output, weights, biases, optional other branch, mean, and 1/std, strides for input, output, and optional other branch, number of rows and columns in input, epsilon for numerical stability, and several compile-time constants.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\nfrom mamba_ssm.ops.triton.softplus import softplus\n\n\n@triton.jit\ndef _selective_scan_update_kernel(\n # Pointers to matrices\n state_ptr, x_ptr, dt_ptr, dt_bias_ptr, A_ptr, B_ptr, C_ptr, D_ptr, z_ptr, out_ptr,\n # Matrix dimensions\n batch, nheads, dim, dstate, nheads_ngroups_ratio,\n # Strides\n stride_state_batch, stride_state_head, stride_state_dim, stride_state_dstate,\n stride_x_batch, stride_x_head, stride_x_dim,\n stride_dt_batch, stride_dt_head, stride_dt_dim,\n stride_dt_bias_head, stride_dt_bias_dim,\n stride_A_head, stride_A_dim, stride_A_dstate,\n stride_B_batch, stride_B_group, stride_B_dstate,\n stride_C_batch, stride_C_group, stride_C_dstate,\n stride_D_head, stride_D_dim,\n stride_z_batch, stride_z_head, stride_z_dim,\n stride_out_batch, stride_out_head, stride_out_dim,\n # Meta-parameters\n DT_SOFTPLUS: tl.constexpr,\n TIE_HDIM: tl.constexpr,\n BLOCK_SIZE_M: tl.constexpr,\n HAS_DT_BIAS: tl.constexpr,\n HAS_D: tl.constexpr,\n HAS_Z: tl.constexpr,\n BLOCK_SIZE_DSTATE: tl.constexpr,\n):\n pid_m = tl.program_id(axis=0)\n pid_b = tl.program_id(axis=1)\n pid_h = tl.program_id(axis=2)\n state_ptr += pid_b * stride_state_batch + pid_h * stride_state_head\n x_ptr += pid_b * stride_x_batch + pid_h * stride_x_head\n dt_ptr += pid_b * stride_dt_batch + pid_h * stride_dt_head\n if HAS_DT_BIAS:\n dt_bias_ptr += pid_h * stride_dt_bias_head\n A_ptr += pid_h * stride_A_head\n B_ptr += pid_b * stride_B_batch + (pid_h // nheads_ngroups_ratio) * stride_B_group\n C_ptr += pid_b * stride_C_batch + (pid_h // nheads_ngroups_ratio) * stride_C_group\n if HAS_Z:\n z_ptr += pid_b * stride_z_batch + pid_h * stride_z_head\n out_ptr += pid_b * stride_out_batch + pid_h * stride_out_head\n\n offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_n = tl.arange(0, BLOCK_SIZE_DSTATE)\n state_ptrs = state_ptr + (offs_m[:, None] * stride_state_dim + offs_n[None, :] * stride_state_dstate)\n x_ptrs = x_ptr + offs_m * stride_x_dim\n dt_ptrs = dt_ptr + offs_m * stride_dt_dim\n if HAS_DT_BIAS:\n dt_bias_ptrs = dt_bias_ptr + offs_m * stride_dt_bias_dim\n if HAS_D:\n D_ptr += pid_h * stride_D_head\n A_ptrs = A_ptr + (offs_m[:, None] * stride_A_dim + offs_n[None, :] * stride_A_dstate)\n B_ptrs = B_ptr + offs_n * stride_B_dstate\n C_ptrs = C_ptr + offs_n * stride_C_dstate\n if HAS_D:\n D_ptrs = D_ptr + offs_m * stride_D_dim\n if HAS_Z:\n z_ptrs = z_ptr + offs_m * stride_z_dim\n out_ptrs = out_ptr + offs_m * stride_out_dim\n\n state = tl.load(state_ptrs, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate), other=0.0)\n x = tl.load(x_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)\n if not TIE_HDIM:\n dt = tl.load(dt_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)\n if HAS_DT_BIAS:\n dt += tl.load(dt_bias_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)\n if DT_SOFTPLUS:\n dt = softplus(dt)\n A = tl.load(A_ptrs, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate), other=0.0).to(tl.float32)\n dA = tl.exp(A * dt[:, None])\n else:\n dt = tl.load(dt_ptr).to(tl.float32)\n if HAS_DT_BIAS:\n dt += tl.load(dt_bias_ptr).to(tl.float32)\n if DT_SOFTPLUS:\n dt = softplus(dt)\n A = tl.load(A_ptr).to(tl.float32)\n dA = tl.exp(A * dt)\n\n B = tl.load(B_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32)\n C = tl.load(C_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32)\n if HAS_D:\n D = tl.load(D_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)\n if HAS_Z:\n z = tl.load(z_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)\n\n if not TIE_HDIM:\n dB = B[None, :] * dt[:, None]\n else:\n dB = B * dt\n state = state * dA + dB * x[:, None]\n tl.store(state_ptrs, state, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate))\n out = tl.sum(state * C[None, :], axis=1)\n if HAS_D:\n out += x * D\n if HAS_Z:\n out *= z * tl.sigmoid(z)\n tl.store(out_ptrs, out, mask=offs_m < dim)\n\n\ndef selective_state_update(state, x, dt, A, B, C, D=None, z=None, dt_bias=None, dt_softplus=False):\n has_heads = state.dim() > 3\n if state.dim() == 3:\n state = state.unsqueeze(1)\n if x.dim() == 2:\n x = x.unsqueeze(1)\n if dt.dim() == 2:\n dt = dt.unsqueeze(1)\n if A.dim() == 2:\n A = A.unsqueeze(0)\n if B.dim() == 2:\n B = B.unsqueeze(1)\n if C.dim() == 2:\n C = C.unsqueeze(1)\n if D is not None and D.dim() == 1:\n D = D.unsqueeze(0)\n if z is not None and z.dim() == 2:\n z = z.unsqueeze(1)\n if dt_bias is not None and dt_bias.dim() == 1:\n dt_bias = dt_bias.unsqueeze(0)\n batch, nheads, dim, dstate = state.shape\n assert x.shape == (batch, nheads, dim)\n assert dt.shape == x.shape\n assert A.shape == (nheads, dim, dstate)\n ngroups = B.shape[1]\n assert nheads % ngroups == 0, \"nheads must be divisible by ngroups\"\n assert B.shape == (batch, ngroups, dstate)\n assert C.shape == B.shape\n if D is not None:\n assert D.shape == (nheads, dim)\n if z is not None:\n assert z.shape == x.shape\n if dt_bias is not None:\n assert dt_bias.shape == (nheads, dim)\n out = torch.empty_like(x)\n grid = lambda META: (triton.cdiv(dim, META['BLOCK_SIZE_M']), batch, nheads)\n z_strides = ((z.stride(0), z.stride(1), z.stride(2)) if z is not None else (0, 0, 0))\n BLOCK_SIZE_M, num_warps = ((32, 4) if dstate <= 16\n else ((16, 4) if dstate <= 32 else\n ((8, 4) if dstate <= 64 else\n ((4, 4) if dstate <= 128 else\n ((4, 8))))))\n tie_hdim = A.stride(-1) == 0 and A.stride(-2) == 0 and dt.stride(-1) == 0 and dt_bias.stride(-1) == 0\n with torch.cuda.device(x.device.index):\n _selective_scan_update_kernel[grid](\n state, x, dt, dt_bias, A, B, C, D, z, out,\n batch, nheads, dim, dstate, nheads // ngroups,\n state.stride(0), state.stride(1), state.stride(2), state.stride(3),\n x.stride(0), x.stride(1), x.stride(2),\n dt.stride(0), dt.stride(1), dt.stride(2),\n *(dt_bias.stride(0), dt_bias.stride(1)) if dt_bias is not None else 0,\n A.stride(0), A.stride(1), A.stride(2),\n B.stride(0), B.stride(1), B.stride(2),\n C.stride(0), C.stride(1), C.stride(2),\n *(D.stride(0), D.stride(1)) if D is not None else 0,\n z_strides[0], z_strides[1], z_strides[2],\n out.stride(0), out.stride(1), out.stride(2),\n dt_softplus,\n tie_hdim,\n BLOCK_SIZE_M,\n num_warps=num_warps,\n )\n if not has_heads:\n out = out.squeeze(1)\n return out\n", - "description_1": "Use triton language to implement a kernel called '_selective_scan_update_kernel' that updates the state matrix using several input matrices and parameters. The kernel accepts 27 pointers for matrices and 15 other parameters including matrix dimensions, strides, and meta-parameters. It computes an updated state and stores results based on conditions defined by meta-parameters like HAS_DT_BIAS, HAS_D, and HAS_Z. The calling function, 'selective_state_update', adapts input dimensions and calls this kernel with appropriate grid and configuration.", - "description_2": "Use triton language to implement a kernel that updates a state matrix using inputs including matrices and parameters. It handles matrix dimensions, strides, and computes updates under specified conditions, then stores the results. A wrapper function prepares inputs and calls the kernel with configured grid.", - "difficulty": 4 - }, - { - "code": "import triton\nimport triton.language as tl\nfrom packaging import version\n\nTRITON3 = version.parse(triton.__version__) >= version.parse(\"3.0.0\")\n\nif TRITON3:\n @triton.jit\n def softplus(dt):\n # Apply the softplus function using Triton 3.0.0 or newer\n dt = tl.where(dt <= 20.0, tl.math.log(tl.math.exp(dt) + 1), dt)\n return dt\nelse:\n @triton.jit\n def softplus(dt):\n # Apply the softplus function using Triton versions older than 3.0.0\n dt = tl.where(dt <= 20.0, tl.math.log1p(tl.exp(dt)), dt)\n return dt\n", - "description_1": "Use triton language to implement a softplus function kernel that takes one parameter 'dt'. The kernel applies the softplus function using different implementations based on the Triton version. For Triton 3.0.0 or newer, it uses 'tl.math.log(tl.math.exp(dt) + 1)', and for older versions, it uses 'tl.math.log1p(tl.exp(dt))'.", - "description_2": "Use triton language to implement a version-dependent softplus function kernel with one parameter.", - "difficulty": 2 - }, - { - "code": "import math\nimport torch\nimport triton\nimport triton.language as tl\n\n@triton.autotune(\n configs=[\n triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64}, num_stages=3, num_warps=8),\n triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2),\n triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2),\n triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=2),\n ],\n key=['chunk_size', 'K', 'IS_CAUSAL'],\n)\n@triton.jit\ndef _bmm_chunk_fwd_kernel(\n a_ptr, b_ptr, out_ptr, seq_idx_ptr,\n seqlen, chunk_size, K, ngroups,\n stride_a_batch, stride_a_seqlen, stride_a_head, stride_ak,\n stride_b_batch, stride_b_seqlen, stride_b_head, stride_bk,\n stride_out_batch, stride_out_chunk, stride_out_head, stride_outm, stride_outn,\n stride_seq_idx_batch, stride_seq_idx_seqlen,\n IS_CAUSAL: tl.constexpr,\n dot_dtype: tl.constexpr,\n HAS_SEQ_IDX: tl.constexpr,\n BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,\n):\n # Implementation omitted for brevity\n\n@triton.autotune(\n configs=[\n triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_CS': 64}, num_stages=3, num_warps=8),\n triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_CS': 32}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_CS': 32}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_CS': 32}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_CS': 32}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_CS': 32}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_CS': 32}, num_stages=5, num_warps=2),\n triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_CS': 32}, num_stages=5, num_warps=2),\n triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_CS': 32}, num_stages=4, num_warps=2),\n ],\n key=['chunk_size', 'K'],\n)\n@triton.jit\ndef _bmm_chunk_bwd_kernel(\n a_ptr, dout_ptr, db_ptr, res_ptr,\n seqlen, chunk_size, K, ngroups,\n stride_a_batch, stride_a_seqlen, stride_a_head, stride_ak,\n stride_dout_batch, stride_dout_chunk, stride_dout_head, stride_dout_csize_m, stride_dout_csize_n,\n stride_db_batch, stride_db_seqlen, stride_db_head, stride_db_k,\n stride_res_batch, stride_res_seqlen, stride_res_head, stride_res_k,\n dot_dtype: tl.constexpr,\n HAS_RESIDUAL: tl.constexpr,\n BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_CS: tl.constexpr,\n):\n # Implementation omitted for brevity\n\ndef _bmm_chunk_fwd(a, b, chunk_size, seq_idx=None, causal=False, output_dtype=None):\n has_groups = a.dim() == 4\n if not has_groups:\n batch, seqlen, k = a.shape\n else:\n batch, seqlen, ngroups, k = a.shape\n assert b.shape == a.shape\n if seq_idx is not None:\n assert seq_idx.shape == (batch, seqlen)\n if a.stride(-1) != 1 and a.stride(1) != 1:\n a = a.contiguous()\n if b.stride(-1) != 1 and b.stride(1) != 1:\n b = b.contiguous()\n nchunks = math.ceil(seqlen / chunk_size)\n out_dtype = a.dtype if output_dtype is None else output_dtype\n out = torch.empty((batch, nchunks, chunk_size, chunk_size) if not has_groups else (batch, nchunks, ngroups, chunk_size, chunk_size),\n device=a.device, dtype=out_dtype)\n dot_dtype = (tl.bfloat16 if a.dtype == torch.bfloat16 or b.dtype == torch.bfloat16 else\n (tl.float16 if a.dtype == torch.float16 or b.dtype == torch.float16 else tl.float32))\n grid = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv(chunk_size, META['BLOCK_SIZE_N']),\n batch, nchunks if not has_groups else nchunks * ngroups)\n with torch.cuda.device(a.device.index):\n _bmm_chunk_fwd_kernel[grid](\n a, b, out, seq_idx,\n seqlen, chunk_size, k, ngroups if has_groups else 1,\n a.stride(0), a.stride(1), 0 if not has_groups else a.stride(2), a.stride(-1),\n b.stride(0), b.stride(1), 0 if not has_groups else b.stride(2), b.stride(-1),\n out.stride(0), out.stride(1), 0 if not has_groups else out.stride(2), out.stride(-2), out.stride(-1),\n *((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)),\n causal,\n dot_dtype,\n HAS_SEQ_IDX=seq_idx is not None,\n )\n return out\n\ndef _bmm_chunk_bwd(a, dout, residual=None, out=None):\n has_groups = a.dim() == 4\n if not has_groups:\n batch, seqlen, k = a.shape\n else:\n batch, seqlen, ngroups, k = a.shape\n nchunks, chunk_size = dout.shape[1], dout.shape[-1]\n if a.stride(-1) != 1 and a.stride(-2) != 1:\n a = a.contiguous()\n if dout.stride(-1) != 1 and dout.stride(-2) != 1:\n dout = dout.contiguous()\n if residual is not None:\n assert residual.shape == (batch, seqlen, k) if not has_groups else (batch, seqlen, ngroups, k)\n if residual.stride(-1) != 1 and residual.stride(1) != 1:\n residual = residual.contiguous()\n if out is not None:\n assert out.shape == a.shape\n assert out.stride(-1) == 1 or out.stride(1) == 1\n else:\n out = torch.empty_like(a)\n dot_dtype = (tl.bfloat16 if a.dtype == torch.bfloat16 or dout.dtype == torch.bfloat16 else\n (tl.float16 if a.dtype == torch.float16 or dout.dtype == torch.float16 else tl.float32))\n grid = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv(k, META['BLOCK_SIZE_N']), batch,\n nchunks if not has_groups else nchunks * ngroups)\n residual_strides = ((residual.stride(0), residual.stride(1), 0 if not has_groups else residual.stride(2),\n residual.stride(-1))\n if residual is not None else (0, 0, 0, 0))\n with torch.cuda.device(a.device.index):\n _bmm_chunk_bwd_kernel[grid](\n a, dout, out, residual,\n seqlen, chunk_size, k, ngroups if has_groups else 1,\n a.stride(0), a.stride(1), 0 if not has_groups else a.stride(2), a.stride(-1),\n dout.stride(0), dout.stride(1), 0 if not has_groups else dout.stride(2), dout.stride(-2), dout.stride(-1),\n out.stride(0), out.stride(1), 0 if not has_groups else out.stride(2), out.stride(-1),\n residual_strides[0], residual_strides[1], residual_strides[2], residual_strides[3],\n dot_dtype,\n HAS_RESIDUAL=residual is not None,\n )\n return out\n", - "description_1": "Use triton language to implement a block matrix multiplication forward and backward kernel. The forward kernel (_bmm_chunk_fwd_kernel) takes pointers to input matrices a, b, and an output matrix. It performs block matrix multiplication based on the provided strides and sequence indices, with optional causal masking. The backward kernel (_bmm_chunk_bwd_kernel) computes the gradient with respect to the input matrices using the gradient of the output and an optional residual matrix.", - "description_2": "Use triton language to perform block matrix multiplication with optional causal and sequence index masking, and compute the backward gradients.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n\n@triton.autotune(\n configs=[\n triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64}, num_stages=3, num_warps=8),\n triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2),\n triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2),\n triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=2),\n ],\n key=['chunk_size', 'hdim', 'dstate', 'IS_CAUSAL'],\n)\n@triton.jit\ndef _chunk_scan_fwd_kernel(\n cb_ptr, x_ptr, z_ptr, out_ptr, out_x_ptr, dt_ptr, dA_cumsum_ptr, seq_idx_ptr, C_ptr, prev_states_ptr, D_ptr,\n chunk_size, hdim, dstate,\n batch, seqlen, nheads_ngroups_ratio,\n stride_cb_batch, stride_cb_chunk, stride_cb_head, stride_cb_csize_m, stride_cb_csize_k,\n stride_x_batch, stride_x_seqlen, stride_x_head, stride_x_hdim,\n stride_z_batch, stride_z_seqlen, stride_z_head, stride_z_hdim,\n stride_out_batch, stride_out_seqlen, stride_out_head, stride_out_hdim,\n stride_dt_batch, stride_dt_chunk, stride_dt_head, stride_dt_csize,\n stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize,\n stride_seq_idx_batch, stride_seq_idx_seqlen,\n stride_C_batch, stride_C_seqlen, stride_C_head, stride_C_dstate,\n stride_states_batch, stride_states_chunk, stride_states_head, stride_states_hdim, stride_states_dstate,\n stride_D_head,\n IS_CAUSAL: tl.constexpr,\n HAS_D: tl.constexpr,\n D_HAS_HDIM: tl.constexpr,\n HAS_Z: tl.constexpr,\n HAS_SEQ_IDX: tl.constexpr,\n BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,\n BLOCK_SIZE_DSTATE: tl.constexpr,\n IS_TRITON_22: tl.constexpr,\n):\n pass\n\n\ndef _chunk_scan_fwd(cb, x, dt, dA_cumsum, C, states, D=None, z=None, seq_idx=None):\n batch, seqlen, nheads, headdim = x.shape\n _, _, nchunks, chunk_size = dt.shape\n _, _, ngroups, dstate = C.shape\n assert nheads % ngroups == 0\n assert C.shape == (batch, seqlen, ngroups, dstate)\n assert cb.shape == (batch, nchunks, ngroups, chunk_size, chunk_size)\n if z is not None:\n assert z.shape == x.shape\n if D is not None:\n assert D.shape == (nheads, headdim) or D.shape == (nheads,)\n assert dt.shape == (batch, nheads, nchunks, chunk_size)\n assert dA_cumsum.shape == (batch, nheads, nchunks, chunk_size)\n assert states.shape == (batch, nchunks, nheads, headdim, dstate)\n if seq_idx is not None:\n assert seq_idx.shape == (batch, seqlen)\n out = torch.empty(batch, seqlen, nheads, headdim, device=x.device, dtype=x.dtype)\n if z is not None:\n out_x = torch.empty(batch, seqlen, nheads, headdim, device=x.device, dtype=x.dtype)\n assert out_x.stride() == out.stride()\n else:\n out_x = None\n grid = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv(headdim, META['BLOCK_SIZE_N']),\n batch * nchunks, nheads)\n z_strides = ((z.stride(0), z.stride(1), z.stride(2), z.stride(3))\n if z is not None else (0, 0, 0, 0))\n _chunk_scan_fwd_kernel[grid](\n cb, x, z, out, out_x, dt, dA_cumsum, seq_idx, C, states, D,\n chunk_size, headdim, dstate,\n batch, seqlen, nheads // ngroups,\n cb.stride(0), cb.stride(1), cb.stride(2), cb.stride(3), cb.stride(4),\n x.stride(0), x.stride(1), x.stride(2), x.stride(3),\n z_strides[0], z_strides[1], z_strides[2], z_strides[3],\n out.stride(0), out.stride(1), out.stride(2), out.stride(3),\n dt.stride(0), dt.stride(2), dt.stride(1), dt.stride(3),\n dA_cumsum.stride(0), dA_cumsum.stride(2), dA_cumsum.stride(1), dA_cumsum.stride(3),\n *((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)),\n C.stride(0), C.stride(1), C.stride(2), C.stride(3),\n states.stride(0), states.stride(1), states.stride(2), states.stride(3), states.stride(4),\n D.stride(0) if D is not None else 0,\n True,\n D is not None,\n D.dim() == 2 if D is not None else True,\n BLOCK_SIZE_DSTATE=max(triton.next_power_of_2(dstate), 16),\n HAS_Z=z is not None,\n HAS_SEQ_IDX=seq_idx is not None,\n IS_TRITON_22=TRITON_22,\n )\n return out, out_x\n", - "description_1": "Use triton language to implement a kernel function decorated with @triton.jit named _chunk_scan_fwd_kernel which involves multiple parameters including pointers, dimensions, strides, and meta-parameters used for processing matrix and tensor operations. Additionally, implement a Python function _chunk_scan_fwd to prepare and call the Triton kernel with correctly calculated grid dimensions and stride parameters based on input tensors for operations on matrices representing batched multi-dimensional data.", - "description_2": "Use triton language to create a kernel that processes complex multi-dimensional tensor data with optimized configurations, and create a wrapper function to facilitate its execution.", - "difficulty": 5 - }, - { - "code": "import math\nimport torch\nimport triton\nimport triton.language as tl\nfrom mamba_ssm.ops.triton.softplus import softplus\n\n@triton.autotune(\n configs=[\n triton.Config({'BLOCK_SIZE_H': 1}),\n triton.Config({'BLOCK_SIZE_H': 2}),\n triton.Config({'BLOCK_SIZE_H': 4}),\n triton.Config({'BLOCK_SIZE_H': 8}),\n triton.Config({'BLOCK_SIZE_H': 16}),\n triton.Config({'BLOCK_SIZE_H': 32}),\n triton.Config({'BLOCK_SIZE_H': 64}),\n ],\n key=['chunk_size', 'nheads'],\n)\n@triton.jit\ndef _chunk_cumsum_fwd_kernel(\n dt_ptr, A_ptr, dt_bias_ptr, dt_out_ptr, dA_cumsum_ptr,\n batch, seqlen, nheads, chunk_size,\n dt_min, dt_max,\n stride_dt_batch, stride_dt_seqlen, stride_dt_head,\n stride_A_head,\n stride_dt_bias_head,\n stride_dt_out_batch, stride_dt_out_chunk, stride_dt_out_head, stride_dt_out_csize,\n stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize,\n DT_SOFTPLUS: tl.constexpr,\n HAS_DT_BIAS: tl.constexpr,\n BLOCK_SIZE_H: tl.constexpr, BLOCK_SIZE_CHUNK: tl.constexpr,\n):\n pid_b = tl.program_id(axis=0)\n pid_c = tl.program_id(axis=1)\n pid_h = tl.program_id(axis=2)\n dt_ptr += pid_b * stride_dt_batch + pid_c * chunk_size * stride_dt_seqlen\n dt_out_ptr += pid_b * stride_dt_out_batch + pid_c * stride_dt_out_chunk\n dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk\n\n offs_h = pid_h * BLOCK_SIZE_H + tl.arange(0, BLOCK_SIZE_H)\n offs_c = tl.arange(0, BLOCK_SIZE_CHUNK)\n dt_ptrs = dt_ptr + (offs_h[:, None] * stride_dt_head + offs_c[None, :] * stride_dt_seqlen)\n A_ptrs = A_ptr + offs_h * stride_A_head\n dt_out_ptrs = dt_out_ptr + (offs_h[:, None] * stride_dt_out_head + offs_c[None, :] * stride_dt_out_csize)\n dA_cs_ptrs = dA_cumsum_ptr + (offs_h[:, None] * stride_dA_cs_head + offs_c[None, :] * stride_dA_cs_csize)\n chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)\n\n dt = tl.load(dt_ptrs, mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit), other=0.0).to(tl.float32)\n if HAS_DT_BIAS:\n dt_bias = tl.load(dt_bias_ptr + offs_h * stride_dt_bias_head, mask=offs_h < nheads, other=0.0).to(tl.float32)\n dt += dt_bias[:, None]\n if DT_SOFTPLUS:\n dt = softplus(dt)\n dt = tl.minimum(tl.maximum(dt, dt_min), dt_max)\n dt = tl.where((offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit), dt, 0.0)\n tl.store(dt_out_ptrs, dt, mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size))\n A = tl.load(A_ptrs, mask=offs_h < nheads, other=0.0).to(tl.float32)\n dA = dt * A[:, None]\n dA_cs = tl.cumsum(dA, axis=1)\n tl.store(dA_cs_ptrs, dA_cs, mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size))\n\n\n@triton.autotune(\n configs=[\n triton.Config({'BLOCK_SIZE_H': 1}, pre_hook=init_to_zero([\"dA_ptr\", \"ddt_bias_ptr\"])),\n triton.Config({'BLOCK_SIZE_H': 2}, pre_hook=init_to_zero([\"dA_ptr\", \"ddt_bias_ptr\"])),\n triton.Config({'BLOCK_SIZE_H': 4}, pre_hook=init_to_zero([\"dA_ptr\", \"ddt_bias_ptr\"])),\n triton.Config({'BLOCK_SIZE_H': 8}, pre_hook=init_to_zero([\"dA_ptr\", \"ddt_bias_ptr\"])),\n triton.Config({'BLOCK_SIZE_H': 16}, pre_hook=init_to_zero([\"dA_ptr\", \"ddt_bias_ptr\"])),\n triton.Config({'BLOCK_SIZE_H': 32}, pre_hook=init_to_zero([\"dA_ptr\", \"ddt_bias_ptr\"])),\n triton.Config({'BLOCK_SIZE_H': 64}, pre_hook=init_to_zero([\"dA_ptr\", \"ddt_bias_ptr\"])),\n ],\n key=['chunk_size', 'nheads'],\n)\n@triton.jit\ndef _chunk_cumsum_bwd_kernel(\n ddA_ptr, ddt_out_ptr, dt_ptr, A_ptr, dt_bias_ptr,\n ddt_ptr, dA_ptr, ddt_bias_ptr,\n batch, seqlen, nheads, chunk_size,\n dt_min, dt_max,\n stride_ddA_batch, stride_ddA_chunk, stride_ddA_head, stride_ddA_csize,\n stride_ddt_out_batch, stride_ddt_out_chunk, stride_ddt_out_head, stride_ddt_out_csize,\n stride_dt_batch, stride_dt_seqlen, stride_dt_head,\n stride_A_head,\n stride_dt_bias_head,\n stride_ddt_batch, stride_ddt_seqlen, stride_ddt_head,\n stride_dA_head,\n stride_ddt_bias_head,\n DT_SOFTPLUS: tl.constexpr,\n HAS_DT_BIAS: tl.constexpr,\n BLOCK_SIZE_H: tl.constexpr, BLOCK_SIZE_CHUNK: tl.constexpr,\n):\n pid_b = tl.program_id(axis=0)\n pid_c = tl.program_id(axis=1)\n pid_h = tl.program_id(axis=2)\n ddt_out_ptr += pid_b * stride_ddt_out_batch + pid_c * stride_ddt_out_chunk\n ddA_ptr += pid_b * stride_ddA_batch + pid_c * stride_ddA_chunk\n dt_ptr += pid_b * stride_dt_batch + pid_c * chunk_size * stride_dt_seqlen\n ddt_ptr += pid_b * stride_ddt_batch + pid_c * chunk_size * stride_ddt_seqlen\n\n offs_h = pid_h * BLOCK_SIZE_H + tl.arange(0, BLOCK_SIZE_H)\n offs_c = tl.arange(0, BLOCK_SIZE_CHUNK)\n ddt_out_ptrs = ddt_out_ptr + (offs_h[:, None] * stride_ddt_out_head + offs_c[None, :] * stride_ddt_out_csize)\n ddA_ptrs = ddA_ptr + (offs_h[:, None] * stride_ddA_head + offs_c[None, :] * stride_ddA_csize)\n dt_ptrs = dt_ptr + (offs_h[:, None] * stride_dt_head + offs_c[None, :] * stride_dt_seqlen)\n ddt_ptrs = ddt_ptr + (offs_h[:, None] * stride_ddt_head + offs_c[None, :] * stride_ddt_seqlen)\n A_ptrs = A_ptr + offs_h * stride_A_head\n chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)\n\n ddA = tl.load(ddA_ptrs, mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit), other=0.0).to(tl.float32)\n ddt_out = tl.load(ddt_out_ptrs, mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit), other=0.0).to(tl.float32)\n A = tl.load(A_ptrs, mask=offs_h < nheads, other=0.0).to(tl.float32)\n ddt = ddA * A[:, None] + ddt_out\n dt = tl.load(dt_ptrs, mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit), other=0.0).to(tl.float32)\n if HAS_DT_BIAS:\n dt_bias = tl.load(dt_bias_ptr + offs_h * stride_dt_bias_head, mask=offs_h < nheads, other=0.0).to(tl.float32)\n dt += dt_bias[:, None]\n if DT_SOFTPLUS:\n dt_presoftplus = dt\n dt = softplus(dt)\n clamp_mask = (dt < dt_min) | (dt > dt_max)\n dt = tl.minimum(tl.maximum(dt, dt_min), dt_max)\n dt = tl.where((offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit), dt, 0.0)\n ddt = tl.where((offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit), ddt, 0.0)\n ddt = tl.where(clamp_mask, 0.0, ddt)\n if DT_SOFTPLUS:\n ddt = tl.where(dt_presoftplus <= 20.0, ddt * tl.sigmoid(dt_presoftplus), ddt)\n tl.store(ddt_ptrs, ddt, mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit))\n dA = tl.sum(ddA * dt, axis=1)\n tl.atomic_add(dA_ptr + offs_h * stride_dA_head, dA, mask=offs_h < nheads)\n if HAS_DT_BIAS:\n ddt_bias = tl.sum(ddt, axis=1)\n tl.atomic_add(ddt_bias_ptr + offs_h * stride_ddt_bias_head, ddt_bias, mask=offs_h < nheads)\n\n\ndef _chunk_cumsum_fwd(dt, A, chunk_size, dt_bias=None, dt_softplus=False, dt_limit=(0.0, float(\"inf\"))):\n batch, seqlen, nheads = dt.shape\n assert A.shape == (nheads,)\n if dt_bias is not None:\n assert dt_bias.shape == (nheads,)\n nchunks = math.ceil(seqlen / chunk_size)\n dt_out = torch.empty(batch, nheads, nchunks, chunk_size, device=dt.device, dtype=torch.float32)\n dA_cumsum = torch.empty(batch, nheads, nchunks, chunk_size, device=dt.device, dtype=torch.float32)\n grid_chunk_cs = lambda META: (batch, nchunks, triton.cdiv(nheads, META['BLOCK_SIZE_H']))\n with torch.cuda.device(dt.device.index):\n _chunk_cumsum_fwd_kernel[grid_chunk_cs](\n dt, A, dt_bias, dt_out, dA_cumsum,\n batch, seqlen, nheads, chunk_size,\n dt_limit[0], dt_limit[1],\n dt.stride(0), dt.stride(1), dt.stride(2),\n A.stride(0),\n dt_bias.stride(0) if dt_bias is not None else 0,\n dt_out.stride(0), dt_out.stride(2), dt_out.stride(1), dt_out.stride(3),\n dA_cumsum.stride(0), dA_cumsum.stride(2), dA_cumsum.stride(1), dA_cumsum.stride(3),\n dt_softplus,\n HAS_DT_BIAS=dt_bias is not None,\n BLOCK_SIZE_CHUNK=triton.next_power_of_2(chunk_size),\n )\n return dA_cumsum, dt_out\n\n\ndef _chunk_cumsum_bwd(ddA, ddt_out, dt, A, dt_bias=None, dt_softplus=False, dt_limit=(0.0, float(\"inf\")), ddt=None):\n batch, seqlen, nheads = dt.shape\n _, _, nchunks, chunk_size = ddA.shape\n assert ddA.shape == (batch, nheads, nchunks, chunk_size)\n assert ddt_out.shape == (batch, nheads, nchunks, chunk_size)\n assert A.shape == (nheads,)\n if dt_bias is not None:\n assert dt_bias.shape == (nheads,)\n ddt_bias = torch.empty_like(dt_bias, dtype=torch.float32)\n else:\n ddt_bias = None\n if ddt is not None:\n assert ddt.shape == dt.shape\n else:\n ddt = torch.empty_like(dt)\n dA = torch.empty_like(A, dtype=torch.float32)\n grid_chunk_cs = lambda META: (batch, nchunks, triton.cdiv(nheads, META['BLOCK_SIZE_H']))\n with torch.cuda.device(dt.device.index):\n _chunk_cumsum_bwd_kernel[grid_chunk_cs](\n ddA, ddt_out, dt, A, dt_bias, ddt, dA, ddt_bias,\n batch, seqlen, nheads, chunk_size,\n dt_limit[0], dt_limit[1],\n ddA.stride(0), ddA.stride(2), ddA.stride(1), ddA.stride(3),\n ddt_out.stride(0), ddt_out.stride(2), ddt_out.stride(1), ddt_out.stride(3),\n dt.stride(0), dt.stride(1), dt.stride(2),\n A.stride(0),\n dt_bias.stride(0) if dt_bias is not None else 0,\n ddt.stride(0), ddt.stride(1), ddt.stride(2),\n dA.stride(0),\n ddt_bias.stride(0) if ddt_bias is not None else 0,\n dt_softplus,\n HAS_DT_BIAS=dt_bias is not None,\n BLOCK_SIZE_CHUNK=triton.next_power_of_2(chunk_size),\n )\n return ddt, dA, ddt_bias\n\n", - "description_1": "Use triton language to create forward and backward kernel functions for cumulative sum operations over chunks of input data. The forward kernel _chunk_cumsum_fwd_kernel calculates the cumulative sum for each chunk, with options for bias and softplus transformations. The backward kernel _chunk_cumsum_bwd_kernel computes gradients for the inputs and bias using the derivatives of the cumulative sum. Both kernels take multiple parameters including pointers to input/output tensors, matrix dimensions, and stride values for efficient memory access.", - "description_2": "Use triton language to define a chunk-wise cumulative sum forward function that computes cumulative sums with optional bias and activation, and a backward function that computes gradients for input and bias, both of which operate efficiently on matrices in a batched manner.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n# Triton kernel function\n@triton.jit\ndef _chunk_scan_chunk_state_bwd_dx_kernel(\n x_ptr, cb_ptr, dout_ptr, dt_ptr, dA_cumsum_ptr, seq_idx_ptr, D_ptr,\n b_ptr, dstates_ptr,\n dx_ptr, ddt_ptr, dD_ptr,\n chunk_size, hdim, dstate,\n batch, seqlen, nheads_ngroups_ratio,\n stride_x_batch, stride_x_seqlen, stride_x_head, stride_x_hdim,\n stride_cb_batch, stride_cb_chunk, stride_cb_head, stride_cb_csize_m, stride_cb_csize_k,\n stride_dout_batch, stride_dout_seqlen, stride_dout_head, stride_dout_hdim,\n stride_dt_batch, stride_dt_chunk, stride_dt_head, stride_dt_csize,\n stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize,\n stride_seq_idx_batch, stride_seq_idx_seqlen,\n stride_D_head,\n stride_b_batch, stride_b_seqlen, stride_b_head, stride_b_dstate,\n stride_dstates_batch, stride_dstates_chunk, stride_dstates_head, stride_dstates_hdim, stride_dstates_dstate,\n stride_dx_batch, stride_dx_seqlen, stride_dx_head, stride_dx_hdim,\n stride_ddt_batch, stride_ddt_chunk, stride_ddt_head, stride_ddt_csize,\n stride_dD_batch, stride_dD_chunk, stride_dD_head, stride_dD_csize, stride_dD_hdim,\n HAS_D: tl.constexpr,\n D_HAS_HDIM: tl.constexpr,\n HAS_SEQ_IDX: tl.constexpr,\n BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,\n BLOCK_SIZE_DSTATE: tl.constexpr,\n IS_TRITON_22: tl.constexpr,\n):\n pid_bc = tl.program_id(axis=1)\n pid_c = pid_bc // batch\n pid_b = pid_bc - pid_c * batch\n pid_h = tl.program_id(axis=2)\n num_pid_n = tl.cdiv(hdim, BLOCK_SIZE_N)\n pid_m = tl.program_id(axis=0) // num_pid_n\n pid_n = tl.program_id(axis=0) % num_pid_n\n x_ptr += pid_b * stride_x_batch + pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head\n cb_ptr += pid_b * stride_cb_batch + pid_c * stride_cb_chunk + (pid_h // nheads_ngroups_ratio) * stride_cb_head\n dout_ptr += pid_b * stride_dout_batch + pid_c * chunk_size * stride_dout_seqlen + pid_h * stride_dout_head\n dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head\n ddt_ptr += pid_b * stride_ddt_batch + pid_c * stride_ddt_chunk + pid_h * stride_ddt_head\n dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head\n b_ptr += pid_b * stride_b_batch + pid_c * chunk_size * stride_b_seqlen + (pid_h // nheads_ngroups_ratio) * stride_b_head\n dstates_ptr += pid_b * stride_dstates_batch + pid_c * stride_dstates_chunk + pid_h * stride_dstates_head\n if HAS_SEQ_IDX:\n seq_idx_ptr += pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen\n\n offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n\n chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)\n\n acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n\n dA_cs_m = tl.load(dA_cumsum_ptr + offs_m * stride_dA_cs_csize, mask=offs_m < chunk_size_limit, other=0.0).to(tl.float32)\n\n dA_cs_last = tl.load(dA_cumsum_ptr + (chunk_size - 1) * stride_dA_cs_csize).to(tl.float32)\n if not HAS_SEQ_IDX:\n scale = tl.exp(dA_cs_last - dA_cs_m)\n else:\n seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen, mask=offs_m < chunk_size_limit, other=-1)\n seq_idx_last = tl.load(seq_idx_ptr + (chunk_size_limit - 1) * stride_seq_idx_seqlen)\n scale = tl.where(seq_idx_m == seq_idx_last, tl.exp(dA_cs_last - dA_cs_m), 0.0)\n\n offs_dstate = tl.arange(0, BLOCK_SIZE_DSTATE if IS_TRITON_22 and BLOCK_SIZE_DSTATE <= 128 else BLOCK_SIZE_K)\n b_ptrs = b_ptr + (offs_m[:, None] * stride_b_seqlen + offs_dstate[None, :] * stride_b_dstate)\n dstates_ptrs = dstates_ptr + (offs_n[None, :] * stride_dstates_hdim + offs_dstate[:, None] * stride_dstates_dstate)\n if IS_TRITON_22 and BLOCK_SIZE_DSTATE <= 128:\n b = tl.load(b_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_dstate[None, :] < dstate), other=0.0)\n dstates = tl.load(dstates_ptrs, mask=(offs_dstate[:, None] < dstate) & (offs_n[None, :] < hdim), other=0.0)\n dstates = dstates.to(b_ptr.dtype.element_ty)\n acc = tl.dot(b, dstates) * scale[:, None]\n else:\n for k in range(0, dstate, BLOCK_SIZE_K):\n b = tl.load(b_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_dstate[None, :] < dstate - k), other=0.0)\n dstates = tl.load(dstates_ptrs, mask=(offs_dstate[:, None] < dstate - k) & (offs_n[None, :] < hdim), other=0.0)\n dstates = dstates.to(b_ptr.dtype.element_ty)\n acc += tl.dot(b, dstates)\n b_ptrs += BLOCK_SIZE_K * stride_b_dstate\n dstates_ptrs += BLOCK_SIZE_K * stride_dstates_dstate\n acc *= scale[:, None]\n\n offs_k = tl.arange(0, BLOCK_SIZE_K)\n cb_ptrs = cb_ptr + (offs_m[:, None] * stride_cb_csize_m + offs_k[None, :] * stride_cb_csize_k)\n dout_ptrs = dout_ptr + (offs_k[:, None] * stride_dout_seqlen + offs_n[None, :] * stride_dout_hdim)\n dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize\n K_MAX = chunk_size_limit\n K_MIN = pid_m * BLOCK_SIZE_M\n cb_ptrs += K_MIN * stride_cb_csize_k\n dout_ptrs += K_MIN * stride_dout_seqlen\n dA_cumsum_ptrs += K_MIN * stride_dA_cs_csize\n for k in range(K_MIN, K_MAX, BLOCK_SIZE_K):\n k = tl.multiple_of(k, BLOCK_SIZE_K)\n cb = tl.load(cb_ptrs, mask=(offs_m[:, None] < chunk_size) & (offs_k[None, :] < K_MAX - k), other=0.0)\n dout = tl.load(dout_ptrs, mask=(offs_k[:, None] < K_MAX - k) & (offs_n[None, :] < hdim), other=0.0)\n dA_cs_k = tl.load(dA_cumsum_ptrs, mask=offs_k < K_MAX - k, other=0.0).to(tl.float32)\n cb *= tl.exp(dA_cs_k[None, :] - dA_cs_m[:, None])\n mask = (k + offs_k[None, :] >= offs_m[:, None]) & (k + offs_k[None, :] < K_MAX)\n cb = tl.where(mask, cb, 0.0)\n cb = cb.to(dout_ptr.dtype.element_ty)\n acc += tl.dot(cb, dout)\n cb_ptrs += BLOCK_SIZE_K * stride_cb_csize_k\n dout_ptrs += BLOCK_SIZE_K * stride_dout_seqlen\n dA_cumsum_ptrs += BLOCK_SIZE_K * stride_dA_cs_csize\n\n offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n dt_ptrs = dt_ptr + offs_m * stride_dt_csize\n dt_m = tl.load(dt_ptrs, mask=offs_m < chunk_size_limit, other=0.0).to(tl.float32)\n dx = acc * dt_m[:, None]\n dx_ptr += pid_b * stride_dx_batch + pid_c * chunk_size * stride_dx_seqlen + pid_h * stride_dx_head\n dx_ptrs = dx_ptr + (offs_m[:, None] * stride_dx_seqlen + offs_n[None, :] * stride_dx_hdim)\n if HAS_D:\n dout_res_ptrs = dout_ptr + (offs_m[:, None] * stride_dout_seqlen + offs_n[None, :] * stride_dout_hdim)\n dout_res = tl.load(dout_res_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0).to(tl.float32)\n if D_HAS_HDIM:\n D = tl.load(D_ptr + pid_h * stride_D_head + offs_n, mask=offs_n < hdim, other=0.0).to(tl.float32)\n else:\n D = tl.load(D_ptr + pid_h * stride_D_head).to(tl.float32)\n dx += dout_res * D\n tl.store(dx_ptrs, dx, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim))\n\n x_ptrs = x_ptr + (offs_m[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim)\n x = tl.load(x_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0).to(tl.float32)\n if HAS_D:\n dD_ptr += pid_b * stride_dD_batch + pid_c * stride_dD_chunk + pid_h * stride_dD_head + pid_m * stride_dD_csize\n if D_HAS_HDIM:\n dD_ptrs = dD_ptr + offs_n * stride_dD_hdim\n dD = tl.sum(dout_res * x, axis=0)\n tl.store(dD_ptrs, dD, mask=offs_n < hdim)\n else:\n dD = tl.sum(dout_res * x)\n tl.store(dD_ptr, dD)\n ddt = tl.sum(acc * x, axis=1)\n ddt_ptrs = ddt_ptr + offs_m * stride_ddt_csize\n tl.atomic_add(ddt_ptrs, ddt, mask=offs_m < chunk_size)\n\n# Function to call the Triton kernel\ndef _chunk_scan_chunk_state_bwd_dx(x, dt, dA_cumsum, B, CB, dout, dstates, D=None, seq_idx=None, dx=None):\n batch, seqlen, nheads, headdim = x.shape\n _, _, nchunks, chunk_size = dt.shape\n _, _, ngroups, dstate = B.shape\n assert nheads % ngroups == 0\n assert B.shape == (batch, seqlen, ngroups, dstate)\n assert CB.shape == (batch, nchunks, ngroups, chunk_size, chunk_size)\n assert dt.shape == (batch, nheads, nchunks, chunk_size)\n assert dA_cumsum.shape == dt.shape\n assert dout.shape == x.shape\n assert dstates.shape == (batch, nchunks, nheads, headdim, dstate)\n if seq_idx is not None:\n assert seq_idx.shape == (batch, seqlen)\n if D is not None:\n assert D.shape == (nheads, headdim) or D.shape == (nheads,)\n assert D.stride(-1) == 1\n BLOCK_SIZE_min = 32\n dD = torch.empty(triton.cdiv(chunk_size, BLOCK_SIZE_min), batch, nchunks, nheads,\n headdim if D.dim() == 2 else 1, device=D.device, dtype=torch.float32)\n else:\n dD = None\n dD_strides = ((dD.stride(0), dD.stride(1), dD.stride(2), dD.stride(3), dD.stride(4))\n if D is not None else (0, 0, 0, 0, 0))\n if dx is None:\n dx = torch.empty_like(x)\n else:\n assert dx.shape == x.shape\n ddt = torch.empty(batch, nheads, nchunks, chunk_size, device=dout.device, dtype=torch.float32)\n grid_dx = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv(headdim, META['BLOCK_SIZE_N']),\n batch * nchunks, nheads)\n with torch.cuda.device(x.device.index):\n _chunk_scan_chunk_state_bwd_dx_kernel[grid_dx](\n x, CB, dout, dt, dA_cumsum, seq_idx, D, B, dstates, dx, ddt, dD,\n chunk_size, headdim, dstate,\n batch, seqlen, nheads // ngroups,\n x.stride(0), x.stride(1), x.stride(2), x.stride(3),\n CB.stride(0), CB.stride(1), CB.stride(2), CB.stride(-1), CB.stride(-2),\n dout.stride(0), dout.stride(1), dout.stride(2), dout.stride(3),\n dt.stride(0), dt.stride(2), dt.stride(1), dt.stride(3),\n dA_cumsum.stride(0), dA_cumsum.stride(2), dA_cumsum.stride(1), dA_cumsum.stride(3),\n *((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)),\n D.stride(0) if D is not None else 0,\n B.stride(0), B.stride(1), B.stride(2), B.stride(3),\n dstates.stride(0), dstates.stride(1), dstates.stride(2), dstates.stride(3), dstates.stride(4),\n dx.stride(0), dx.stride(1), dx.stride(2), dx.stride(3),\n ddt.stride(0), ddt.stride(2), ddt.stride(1), ddt.stride(3),\n dD_strides[1], dD_strides[2], dD_strides[3], dD_strides[0], dD_strides[4],\n D is not None,\n D.dim() == 2 if D is not None else True,\n HAS_SEQ_IDX=seq_idx is not None,\n BLOCK_SIZE_DSTATE=max(triton.next_power_of_2(dstate), 16),\n IS_TRITON_22=TRITON_22\n )\n if D is not None:\n BLOCK_SIZE_actual = _chunk_scan_chunk_state_bwd_dx_kernel.best_config.kwargs[\"BLOCK_SIZE_M\"]\n n_valid_blocks = (chunk_size + BLOCK_SIZE_actual - 1) // BLOCK_SIZE_actual\n dD = dD[:n_valid_blocks].sum(dim=(0, 1, 2)).to(dtype=D.dtype)\n if D.dim() == 1:\n dD = rearrange(dD, \"h 1 -> h\")\n return dx, ddt.to(dtype=dt.dtype), dD\n", - "description_1": "Use triton language to implement a backward kernel for chunk scan with 45 parameters including pointers, dimensions, strides, and meta-parameters, and a wrapper to invoke this kernel.", - "description_2": "Use triton language to compute backward gradients for chunk scan operations using a kernel with detailed data management and execution parameters.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef _state_passing_fwd_kernel(\n # Pointers to matrices\n states_ptr, out_ptr, final_states_ptr, dA_cs_ptr, initstates_ptr, seq_idx_ptr,\n # Matrix dimensions\n dim, nchunks, seqlen, chunk_size,\n # Strides\n stride_states_batch, stride_states_chunk, stride_states_head, stride_states_dim,\n stride_out_batch, stride_out_chunk, stride_out_head, stride_out_dim,\n stride_final_states_batch, stride_final_states_head, stride_final_states_dim,\n stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head,\n stride_initstates_batch, stride_initstates_head, stride_initstates_dim,\n stride_seq_idx_batch, stride_seq_idx_seqlen,\n # Meta-parameters\n HAS_INITSTATES: tl.constexpr,\n HAS_SEQ_IDX: tl.constexpr,\n BLOCK_SIZE: tl.constexpr,\n):\n pid_b = tl.program_id(axis=1)\n pid_h = tl.program_id(axis=2)\n pid_m = tl.program_id(axis=0)\n states_ptr += pid_b * stride_states_batch + pid_h * stride_states_head\n dA_cs_ptr += pid_b * stride_dA_cs_batch + pid_h * stride_dA_cs_head\n out_ptr += pid_b * stride_out_batch + pid_h * stride_out_head\n final_states_ptr += pid_b * stride_final_states_batch + pid_h * stride_final_states_head\n if HAS_INITSTATES:\n initstates_ptr += pid_b * stride_initstates_batch + pid_h * stride_initstates_head\n if HAS_SEQ_IDX:\n seq_idx_ptr += pid_b * stride_seq_idx_batch\n\n offs_m = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)\n states_ptrs = states_ptr + offs_m * stride_states_dim\n out_ptrs = out_ptr + offs_m * stride_out_dim\n final_states_ptrs = final_states_ptr + offs_m * stride_final_states_dim\n\n if not HAS_INITSTATES:\n states = tl.zeros((BLOCK_SIZE, ), dtype=tl.float32)\n else:\n initstates_ptrs = initstates_ptr + offs_m * stride_initstates_dim\n states = tl.load(initstates_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)\n tl.store(out_ptrs, states, mask=offs_m < dim)\n out_ptrs += stride_out_chunk\n seq_idx = 0\n for c in range(nchunks):\n new_states = tl.load(states_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)\n dA_cs = tl.load(dA_cs_ptr).to(tl.float32)\n scale = tl.exp(dA_cs)\n if HAS_SEQ_IDX:\n seq_idx_new = tl.load(seq_idx_ptr + (min((c + 1) * chunk_size, seqlen) - 1) * stride_seq_idx_seqlen)\n scale = tl.where(seq_idx_new == seq_idx, scale, 0.0)\n seq_idx = seq_idx_new\n states = scale * states + new_states\n if c < nchunks - 1:\n tl.store(out_ptrs, states, mask=offs_m < dim)\n else:\n tl.store(final_states_ptrs, states, mask=offs_m < dim)\n states_ptrs += stride_states_chunk\n dA_cs_ptr += stride_dA_cs_chunk\n out_ptrs += stride_out_chunk\n\n@triton.jit\ndef _state_passing_bwd_kernel(\n # Pointers to matrices\n dout_ptr, out_ptr, dA_cs_ptr, dfinal_states_ptr, seq_idx_ptr,\n dstates_ptr, ddA_cs_ptr, dinitstates_ptr, states_converted_ptr,\n # Matrix dimensions\n dim, nchunks, seqlen, chunk_size,\n # Strides\n stride_dout_batch, stride_dout_chunk, stride_dout_head, stride_dout_dim,\n stride_out_batch, stride_out_chunk, stride_out_head, stride_out_dim,\n stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head,\n stride_dfinal_states_batch, stride_dfinal_states_head, stride_dfinal_states_dim,\n stride_seq_idx_batch, stride_seq_idx_seqlen,\n stride_dstates_batch, stride_dstates_chunk, stride_dstates_head, stride_dstates_dim,\n stride_ddA_cs_batch, stride_ddA_cs_chunk, stride_ddA_cs_head,\n stride_dinitstates_batch, stride_dinitstates_head, stride_dinitstates_dim,\n # Meta-parameters\n CONVERT_STATES: tl.constexpr,\n HAS_DFINAL_STATES: tl.constexpr,\n HAS_DINITSTATES: tl.constexpr,\n HAS_SEQ_IDX: tl.constexpr,\n BLOCK_SIZE: tl.constexpr,\n):\n pid_b = tl.program_id(axis=1)\n pid_h = tl.program_id(axis=2)\n pid_m = tl.program_id(axis=0)\n dstates_ptr += pid_b * stride_dstates_batch + pid_h * stride_dstates_head + (nchunks - 1) * stride_dstates_chunk\n dA_cs_ptr += pid_b * stride_dA_cs_batch + pid_h * stride_dA_cs_head + (nchunks - 1) * stride_dA_cs_chunk\n ddA_cs_ptr += pid_b * stride_ddA_cs_batch + pid_h * stride_ddA_cs_head + (nchunks - 1) * stride_ddA_cs_chunk + pid_m\n out_ptr += pid_b * stride_out_batch + pid_h * stride_out_head + (nchunks - 1) * stride_out_chunk\n dout_ptr += pid_b * stride_dout_batch + pid_h * stride_dout_head + (nchunks - 1) * stride_dout_chunk\n if CONVERT_STATES:\n states_converted_ptr += pid_b * stride_out_batch + pid_h * stride_out_head + (nchunks - 1) * stride_out_chunk\n if HAS_DFINAL_STATES:\n dfinal_states_ptr += pid_b * stride_dfinal_states_batch + pid_h * stride_dfinal_states_head\n if HAS_DINITSTATES:\n dinitstates_ptr += pid_b * stride_dinitstates_batch + pid_h * stride_dinitstates_head\n if HAS_SEQ_IDX:\n seq_idx_ptr += pid_b * stride_seq_idx_batch\n\n offs_m = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)\n dstates_ptrs = dstates_ptr + offs_m * stride_dstates_dim\n out_ptrs = out_ptr + offs_m * stride_out_dim\n dout_ptrs = dout_ptr + offs_m * stride_dout_dim\n if CONVERT_STATES:\n states_converted_ptrs = states_converted_ptr + offs_m * stride_out_dim\n\n if HAS_DFINAL_STATES:\n dstates = tl.load(dfinal_states_ptr + offs_m * stride_dfinal_states_dim, mask=offs_m < dim, other=0.0).to(tl.float32)\n else:\n dstates = tl.zeros((BLOCK_SIZE, ), dtype=tl.float32)\n tl.store(dstates_ptrs, dstates, mask=offs_m < dim)\n if HAS_SEQ_IDX:\n seq_idx = tl.load(seq_idx_ptr + (seqlen - 1) * stride_seq_idx_seqlen)\n dstates_ptrs -= stride_dstates_chunk\n for c in range(nchunks - 1):\n dA_cs = tl.load(dA_cs_ptr).to(tl.float32)\n scale = tl.exp(dA_cs)\n if HAS_SEQ_IDX:\n seq_idx_new = tl.load(seq_idx_ptr + (((nchunks - c - 1) * chunk_size - 1) * stride_seq_idx_seqlen))\n scale = tl.where(seq_idx_new == seq_idx, scale, 0.0)\n seq_idx = seq_idx_new\n out = tl.load(out_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)\n if CONVERT_STATES:\n tl.store(states_converted_ptrs, out, mask=offs_m < dim)\n ddA = tl.sum(out * dstates) * scale\n tl.store(ddA_cs_ptr, ddA)\n dout = tl.load(dout_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)\n dstates = scale * dstates + dout\n tl.store(dstates_ptrs, dstates, mask=offs_m < dim)\n dout_ptrs -= stride_dout_chunk\n dstates_ptrs -= stride_dstates_chunk\n dA_cs_ptr -= stride_dA_cs_chunk\n ddA_cs_ptr -= stride_ddA_cs_chunk\n out_ptrs -= stride_out_chunk\n if CONVERT_STATES:\n states_converted_ptrs -= stride_out_chunk\n if CONVERT_STATES:\n out = tl.load(out_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)\n tl.store(states_converted_ptrs, out, mask=offs_m < dim)\n if not HAS_DINITSTATES:\n tl.store(ddA_cs_ptr, 0.0)\n else:\n dA_cs = tl.load(dA_cs_ptr).to(tl.float32)\n scale = tl.exp(dA_cs)\n if HAS_SEQ_IDX:\n scale = tl.where(seq_idx == 0, scale, 0.0)\n out = tl.load(out_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)\n ddA = tl.sum(out * dstates) * scale\n tl.store(ddA_cs_ptr, ddA)\n dout = tl.load(dout_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)\n dstates = scale * dstates + dout\n tl.store(dinitstates_ptr + offs_m * stride_dinitstates_dim, dstates, mask=offs_m < dim)\n\n\ndef _state_passing_fwd(states, dA_chunk_cumsum, initial_states=None, seq_idx=None, chunk_size=None,\n out_dtype=None):\n batch, nchunks, nheads, dim = states.shape\n assert dA_chunk_cumsum.shape == (batch, nheads, nchunks)\n if initial_states is not None:\n assert initial_states.shape == (batch, nheads, dim)\n if seq_idx is not None:\n assert chunk_size is not None\n seqlen = seq_idx.shape[-1]\n assert seq_idx.shape == (batch, seqlen)\n out_dtype = states.dtype if out_dtype is None else out_dtype\n out = torch.empty((batch, nchunks, nheads, dim), device=states.device, dtype=out_dtype)\n final_states = torch.empty((batch, nheads, dim), device=states.device, dtype=torch.float32)\n grid = lambda META: (triton.cdiv(dim, META['BLOCK_SIZE']), batch, nheads)\n with torch.cuda.device(states.device.index):\n _state_passing_fwd_kernel[grid](\n states, out, final_states, dA_chunk_cumsum, initial_states, seq_idx,\n dim, nchunks, seqlen if seq_idx is not None else 0, chunk_size if seq_idx is not None else 0,\n states.stride(0), states.stride(1), states.stride(2), states.stride(3),\n out.stride(0), out.stride(1), out.stride(2), out.stride(3),\n final_states.stride(0), final_states.stride(1), final_states.stride(2),\n dA_chunk_cumsum.stride(0), dA_chunk_cumsum.stride(2), dA_chunk_cumsum.stride(1),\n *((initial_states.stride(0), initial_states.stride(1), initial_states.stride(2))\n if initial_states is not None else (0, 0, 0)),\n *((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)),\n HAS_INITSTATES=initial_states is not None,\n HAS_SEQ_IDX=seq_idx is not None,\n )\n return out, final_states\n\n\ndef _state_passing_bwd(\n states, dA_chunk_cumsum, dout, dfinal_states=None, seq_idx=None, has_initial_states=None,\n dstates_dtype=None, states_dtype=None, chunk_size=None\n):\n \"\"\"\n states contains the initial_states at index 0. The final states are not included in states.\n \"\"\"\n batch, nchunks, nheads, dim = states.shape\n assert dA_chunk_cumsum.shape == (batch, nheads, nchunks)\n assert dout.shape == (batch, nchunks, nheads, dim)\n if seq_idx is not None:\n assert chunk_size is not None\n seqlen = seq_idx.shape[-1]\n assert seq_idx.shape == (batch, seqlen)\n dstates = torch.empty_like(dout, dtype=dstates_dtype if dstates_dtype is not None else dout.dtype)\n if states_dtype is not None and states_dtype != states.dtype:\n states_converted = torch.empty_like(states, dtype=dstates_dtype if dstates_dtype is not None else dout.dtype)\n assert states_converted.stride() == states.stride()\n else:\n states_converted = None\n if has_initial_states:\n dinitstates = torch.empty_like(dstates[:, 0])\n else:\n dinitstates = None\n if dfinal_states is not None:\n assert dfinal_states.shape == (batch, nheads, dim)\n BLOCK_SIZE_min = 64\n n_blocks = (dim + BLOCK_SIZE_min - 1) // BLOCK_SIZE_min\n ddA_chunk_cumsum = torch.empty(batch, nheads, nchunks, n_blocks,\n dtype=torch.float32, device=dA_chunk_cumsum.device)\n grid = lambda META: (triton.cdiv(dim, META['BLOCK_SIZE']), batch, nheads)\n with torch.cuda.device(dout.device.index):\n _state_passing_bwd_kernel[grid](\n dout, states, dA_chunk_cumsum, dfinal_states, seq_idx,\n dstates, ddA_chunk_cumsum, dinitstates, states_converted,\n dim, nchunks, seqlen if seq_idx is not None else 0, chunk_size if seq_idx is not None else 0,\n dout.stride(0), dout.stride(1), dout.stride(2), dout.stride(3),\n states.stride(0), states.stride(1), states.stride(2), states.stride(3),\n dA_chunk_cumsum.stride(0), dA_chunk_cumsum.stride(2), dA_chunk_cumsum.stride(1),\n *((dfinal_states.stride(0), dfinal_states.stride(1), dfinal_states.stride(2))\n if dfinal_states is not None else (0, 0, 0)),\n *((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)),\n dstates.stride(0), dstates.stride(1), dstates.stride(2), dstates.stride(3),\n ddA_chunk_cumsum.stride(0), ddA_chunk_cumsum.stride(2), ddA_chunk_cumsum.stride(1),\n *((dinitstates.stride(0), dinitstates.stride(1), dinitstates.stride(2))\n if dinitstates is not None else (0, 0, 0)),\n CONVERT_STATES=states_converted is not None,\n HAS_DFINAL_STATES=dfinal_states is not None,\n HAS_DINITSTATES=dinitstates is not None,\n HAS_SEQ_IDX=seq_idx is not None,\n )\n BLOCK_SIZE_actual = _state_passing_bwd_kernel.best_config.kwargs[\"BLOCK_SIZE\"]\n n_valid_blocks = (dim + BLOCK_SIZE_actual - 1) // BLOCK_SIZE_actual\n ddA_chunk_cumsum = ddA_chunk_cumsum[..., :n_valid_blocks].sum(dim=-1).to(dtype=dA_chunk_cumsum.dtype)\n if states_dtype is not None and states_dtype == states.dtype:\n states_converted = states\n return (dstates, ddA_chunk_cumsum, dinitstates) if states_dtype is None else (dstates, ddA_chunk_cumsum, dinitstates, states_converted)\n", - "description_1": "Use triton language to implement two kernels: `_state_passing_fwd_kernel` and `_state_passing_bwd_kernel`. `_state_passing_fwd_kernel` handles the forward pass with parameters for pointers to input/output matrices, matrix dimensions, strides, and meta-parameters like whether initial states or sequence indices are used. `_state_passing_bwd_kernel` handles the backward pass with similar parameters plus additional meta-parameters like whether to convert states or include final states.", - "description_2": "Use triton language to create forward and backward kernels for state passing operations, handling pointer calculations and conditional logic based on matrix properties and meta-parameters.", - "difficulty": 4 - }, - { - "code": "import triton\nimport triton.language as tl\n\n@triton.jit\ndef my_kernel(X, stride_xm, stride_xn, Y, stride_ym, stride_yn, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):\n # Program ID\n pid_m = tl.program_id(0)\n pid_n = tl.program_id(1)\n\n # Create offsets for memory access\n offsets_xm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)\n offsets_xn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)\n offsets_ym = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)\n offsets_yn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)\n\n # Load data from X and Y\n x = tl.load(X + offsets_xm[:, None] * stride_xm + offsets_xn[None, :] * stride_xn)\n y = tl.load(Y + offsets_ym[:, None] * stride_ym + offsets_yn[None, :] * stride_yn)\n\n # Perform elementwise add operation\n result = x + y\n\n # Store result\n tl.store(X + offsets_xm[:, None] * stride_xm + offsets_xn[None, :] * stride_xn, result)\n\ndef call_my_kernel(x, y, stride_xm, stride_xn, stride_ym, stride_yn):\n # Define constants\n BLOCK_M = 128\n BLOCK_N = 128\n\n # Calculate grid dimensions\n grid_m = (x.shape[0] + BLOCK_M - 1) // BLOCK_M\n grid_n = (x.shape[1] + BLOCK_N - 1) // BLOCK_N\n\n # Launch the Triton kernel\n my_kernel[(grid_m, grid_n)](x, stride_xm, stride_xn, y, stride_ym, stride_yn, BLOCK_M, BLOCK_N)\n", - "description_1": "Use triton language to create a kernel my_kernel that performs an elementwise addition of two input matrices X and Y. The kernel uses block-wise memory operations to optimize the process. The function call_my_kernel launches the kernel on GPU using the calculated grid dimensions for blocks.", - "description_2": "Use triton language to create and call a kernel for elementwise matrix addition with block-wise memory access.", - "difficulty": 3 - }, - { - "code": "import triton\n\n@triton.jit\ndef kernel_example(arg1, arg2):\n # Your kernel code here\n pass\n\ndef call_kernel_example():\n # Prepare arguments\n arg1 = ...\n arg2 = ...\n # Call the Triton kernel\n kernel_example[(grid_x,)](arg1, arg2)\n\n# Assuming (grid_x,) is defined appropriately for the kernel launch\n", - "description_1": "Use triton language to define a kernel with two arguments, arg1 and arg2, and implement your custom functionality inside the kernel. A separate function prepares the arguments and calls this Triton kernel using an appropriate grid size, for example, (grid_x,).", - "description_2": "Use triton language to define a kernel with specific arguments and call this kernel with a grid configuration.", - "difficulty": 3 - }, - { - "code": "# mypy: allow-untyped-defs\nimport torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef my_kernel(INPUT: tl.tensor, OUTPUT: tl.tensor, N: tl.constexpr):\n pid = tl.program_id(0)\n offset = pid * N\n in_ptrs = INPUT + offset + tl.arange(0, N)\n out_ptrs = OUTPUT + offset + tl.arange(0, N)\n x = tl.load(in_ptrs)\n y = x * x\n tl.store(out_ptrs, y)\n\ndef call_my_kernel(input_tensor, output_tensor, n_elements):\n grid = lambda meta: (triton.cdiv(input_tensor.size(0), meta['N']),)\n my_kernel[grid](input_tensor, output_tensor, N=n_elements)\n\n# Example usage\ninput_tensor = torch.randn(1024, device='cuda')\noutput_tensor = torch.empty_like(input_tensor)\ncall_my_kernel(input_tensor, output_tensor, 128)\n", - "description_1": "Use triton language to define a kernel my_kernel with three parameters: INPUT (tensor), OUTPUT (tensor), and N (constexpr). The kernel computes the square of each element from INPUT and stores the result in OUTPUT. Launch this kernel using call_my_kernel function, which takes three parameters: input_tensor, output_tensor, and n_elements.", - "description_2": "Use triton language to square elements of a tensor using a kernel and store results in another tensor, then call this kernel with appropriate tensor parameters.", - "difficulty": 2 - }, - { - "code": "import triton\nimport torch\n\n# Example Triton kernel\n@triton.jit\ndef example_kernel(X, Y, Z, BLOCK_SIZE: tl.constexpr):\n # Triton kernel code\n pass\n\n# Function to call the Triton kernel\ndef call_example_kernel(x, y, z, block_size):\n # Call the Triton kernel\n example_kernel[(1,)](x, y, z, BLOCK_SIZE=block_size)\n", - "description_1": "Use triton language to define a kernel 'example_kernel' with 4 parameters: X, Y, Z, and BLOCK_SIZE. The kernel is called using 'call_example_kernel' function with parameters x, y, z, and block_size.", - "description_2": "Use triton language to define a kernel and call it with specified parameters.", - "difficulty": 2 - }, - { - "code": "import triton\nimport triton.language as tl\n\n# Kernel function to promote input to tensor\n@triton.jit\ndef promote_to_tensor(x):\n # Addition promotes to tensor for us\n return x + tl.zeros((1,), tl.int1)\n\n# Kernel function to check if input is floating\n@triton.jit\ndef is_floating(x):\n return promote_to_tensor(x).dtype.is_floating()\n\n# Kernel function to accumulate product\n@triton.jit\ndef _prod_accumulate(a, b):\n return a * b\n\n# Kernel function to compute product along specified axis\n@triton.jit\ndef prod(input, axis):\n return tl.reduce(input, axis, _prod_accumulate)\n\n# Kernel function to compute the minimum of two inputs\n@triton.jit\ndef minimum(a, b):\n mask = a < b\n if is_floating(a):\n mask |= a != a\n return tl.where(mask, a, b)\n\n# Kernel function to compute the maximum of two inputs\n@triton.jit\ndef maximum(a, b):\n mask = a > b\n if is_floating(a):\n mask |= a != a\n return tl.where(mask, a, b)\n\n# Kernel function to compute the minimum along specified dimension\n@triton.jit\ndef min2(a, dim):\n return tl.reduce(a, dim, minimum)\n\n# Kernel function to compute the maximum along specified dimension\n@triton.jit\ndef max2(a, dim):\n return tl.reduce(a, dim, maximum)\n\n# Kernel function to compute the minimum value with index\n@triton.jit\ndef minimum_with_index(a_value, a_index, b_value, b_index):\n mask = a_value < b_value\n equal = a_value == b_value\n if is_floating(a_value):\n a_isnan = a_value != a_value\n b_isnan = b_value != b_value\n mask |= a_isnan and not b_isnan\n # Consider NaNs as equal\n equal |= a_isnan and b_isnan\n\n # Prefer lowest index if values are equal\n mask |= equal & (a_index < b_index)\n return tl.where(mask, a_value, b_value), tl.where(mask, a_index, b_index)\n\n# Kernel function to compute the maximum value with index\n@triton.jit\ndef maximum_with_index(a_value, a_index, b_value, b_index):\n mask = a_value > b_value\n equal = a_value == b_value\n if is_floating(a_value):\n a_isnan = a_value != a_value\n b_isnan = b_value != b_value\n mask |= a_isnan and not b_isnan\n # Consider NaNs as equal\n equal |= a_isnan and b_isnan\n\n # Prefer lowest index if values are equal\n mask |= equal & (a_index < b_index)\n return tl.where(mask, a_value, b_value), tl.where(mask, a_index, b_index)\n\n# Kernel function to compute minimum with index along specified dimension\n@triton.jit\ndef min_with_index(value, index, dim):\n return tl.reduce((value, index), dim, minimum_with_index)\n\n# Kernel function to compute maximum with index along specified dimension\n@triton.jit\ndef max_with_index(value, index, dim):\n return tl.reduce((value, index), dim, maximum_with_index)\n\n# Kernel function to perform Welford's algorithm for online variance\n@triton.jit\ndef welford_reduce(value, mean, m2, weight, first_iteration):\n if first_iteration:\n new_weight = tl.full(weight.shape, 1, weight.dtype)\n new_mean = value\n new_m2 = tl.zeros_like(m2)\n else:\n delta = value - mean\n new_weight = weight + 1\n new_mean = mean + delta / new_weight\n new_m2 = m2 + delta * (value - new_mean)\n return new_mean, new_m2, new_weight\n\n# Kernel function to combine Welford statistics\n@triton.jit\ndef welford_combine(mean_1, m2_1, weight_1, mean_2, m2_2, weight_2):\n delta = mean_2 - mean_1\n new_weight = weight_1 + weight_2\n w2_over_w = tl.where(new_weight == 0.0, 0.0, weight_2 / new_weight)\n return (\n mean_1 + delta * w2_over_w,\n m2_1 + m2_2 + delta * delta * weight_1 * w2_over_w,\n new_weight,\n )\n\n# Kernel function to perform Welford's reduction along specified dimension\n@triton.jit\ndef welford(mean, m2, weight, dim):\n return tl.reduce((mean, m2, weight), dim, welford_combine)\n\n# Kernel function to assert condition and return value\n@triton.jit\ndef device_assert_then(cond, msg, r):\n tl.device_assert(cond, msg)\n return r\n\n# Kernel function to generate random integer in range\n@triton.jit\ndef randint64(seed, offset, low, high):\n r0, r1, r2, r3 = tl.randint4x(seed, offset)\n r0 = r0.to(tl.uint64)\n r1 = r1.to(tl.uint64)\n result = r0 | (r1 << 32)\n size = high - low\n result = result % size.to(tl.uint64)\n result = result.to(tl.int64) + low\n return result\n\n# Kernel function to combine with bitwise OR\n@triton.jit\ndef _any_combine(a, b):\n return a | b\n\n# Kernel function to perform reduction with logical OR along specified dimension\n@triton.jit\ndef any(a, dim):\n return tl.reduce(a, dim, _any_combine)\n\n# Kernel function to perform binary search in buckets\n@triton.jit\ndef bucketize_binary_search(\n values, # 1D tensor\n offsets_ptr,\n indexing_dtype,\n right, # bool: if true, use intervals closed on the left; see [Note: Inductor bucketize op]\n OFFSETS_SIZE: int,\n BLOCK_SHAPE, # tuple/list of block shape\n):\n \"\"\"\n See [Note: Inductor bucketize op]\n \"\"\"\n low = tl.zeros(BLOCK_SHAPE, dtype=indexing_dtype)\n high = tl.full(BLOCK_SHAPE, OFFSETS_SIZE, dtype=indexing_dtype)\n full_range = OFFSETS_SIZE + 1\n while full_range > 1:\n mid = (high + low) // 2\n mask = mid < OFFSETS_SIZE\n bucket_upper_bound = tl.load(offsets_ptr + mid, mask=mask, other=0.0)\n if right:\n is_above = values >= bucket_upper_bound\n else:\n is_above = values > bucket_upper_bound\n low = tl.where(is_above & mask, mid + 1, low)\n high = tl.where(is_above, high, mid)\n full_range = (full_range + 1) // 2\n return low\n\n# Kernel function to pack value and flag\n@triton.jit\ndef pack_value_flag(\n value,\n flag,\n DTYPE_VALUE_AS_UINT: tl.constexpr,\n DTYPE_PACK: tl.constexpr,\n):\n # Workaround for triton bug, tensor.to doesn't unwrap constexpr values\n DTYPE_VALUE_AS_UINT = tl.core._constexpr_to_value(DTYPE_VALUE_AS_UINT)\n bitwidth = DTYPE_VALUE_AS_UINT.primitive_bitwidth\n uv = value.to(DTYPE_VALUE_AS_UINT, bitcast=True).to(DTYPE_PACK)\n return flag.to(DTYPE_PACK) | (uv << bitwidth)\n\n# Kernel function to unpack value from packed data\n@triton.jit\ndef unpack_value(\n pack,\n DTYPE_VALUE,\n DTYPE_VALUE_AS_UINT,\n):\n # Workaround for triton bug, tensor.to doesn't unwrap constexpr values\n DTYPE_VALUE = tl.core._constexpr_to_value(DTYPE_VALUE)\n DTYPE_VALUE_AS_UINT = tl.core._constexpr_to_value(DTYPE_VALUE_AS_UINT)\n bitwidth = DTYPE_VALUE_AS_UINT.primitive_bitwidth\n value_uint = (pack >> bitwidth).to(DTYPE_VALUE_AS_UINT)\n return value_uint.to(DTYPE_VALUE, bitcast=True)\n\n# Kernel function to unpack flag from packed data\n@triton.jit\ndef unpack_flag(pack, DTYPE_FLAG):\n return pack.to(DTYPE_FLAG)\n\n# Kernel function to compute exclusive scan with decoupled lookback\n@triton.jit\ndef exclusive_scan_decoupled_lookback(\n scratch_base,\n block_value,\n index,\n combine_fn,\n DTYPE_VALUE_AS_UINT: tl.constexpr,\n DTYPE_PACK: tl.constexpr,\n):\n \"\"\"Compute exclusive scan of a scalar value between blocks\n\n Ref: https://research.nvidia.com/publication/2016-03_single-pass-parallel-prefix-scan-decoupled-look-back\n\n scratch_base: Pointer to scratch space in global memory\n block_value: Scalar value for this block\n index: Scalar index of this block relative to the current scan\n combine_fn: Function ``(value, value) -> value`` which is scanned over\n DTYPE_VALUE_AS_UINT: A tl.uint{n} type equal in size to ``block_value``\n DTYPE_PACK: Unsigned type twice the width of block_value\n\n NOTE: This function is limited to values which are 32-bits or less because\n we need to pack (value, flag) into a single unsigned int.\n \"\"\"\n # Publish block sum so subsequent blocks don't get stuck waiting for us\n DTYPE_VALUE = block_value.dtype\n pack = pack_value_flag(\n block_value,\n tl.full(block_value.shape, 1, DTYPE_VALUE_AS_UINT),\n DTYPE_VALUE_AS_UINT,\n DTYPE_PACK,\n )\n if index > 0:\n tl.atomic_xchg(scratch_base + index, pack, sem=\"relaxed\")\n\n # Calculate exclusive prefix scan\n exclusive_prefix = tl.zeros([], DTYPE_VALUE)\n prefix_valid = False\n test_target = index - 1\n while test_target >= 0:\n # tl.atomic_load\n flag = tl.full([], 0, DTYPE_VALUE_AS_UINT)\n while flag == 0:\n pack = tl.atomic_add(scratch_base + test_target, 0, sem=\"relaxed\")\n flag = unpack_flag(pack, DTYPE_VALUE_AS_UINT)\n\n value = unpack_value(pack, DTYPE_VALUE, DTYPE_VALUE_AS_UINT)\n if prefix_valid:\n exclusive_prefix = combine_fn(value, exclusive_prefix)\n else:\n exclusive_prefix = value\n prefix_valid = True\n\n if flag == 2:\n test_target = -1\n else:\n test_target = test_target - 1\n\n # Make inclusive block sum visible to other blocks\n if prefix_valid:\n inclusive_prefix = combine_fn(exclusive_prefix, block_value)\n else:\n inclusive_prefix = block_value\n pack = pack_value_flag(\n inclusive_prefix,\n tl.full([], 2, DTYPE_VALUE_AS_UINT),\n DTYPE_VALUE_AS_UINT,\n DTYPE_PACK,\n )\n tl.atomic_xchg(scratch_base + index, pack, sem=\"relaxed\")\n return exclusive_prefix\n\n# Kernel function to compute exclusive scan with decoupled lookback for 64-bit blocks\n@triton.jit\ndef exclusive_scan_decoupled_lookback_64(scratch_base, block_value, index, combine_fn):\n \"\"\"Compute exclusive scan of a scalar value between blocks\n\n Ref: https://research.nvidia.com/publication/2016-03_single-pass-parallel-prefix-scan-decoupled-look-back\n\n scratch_base: Pointer to scratch space in global memory\n block_value: Scalar value for this block, must be 64-bits wide\n index: Scalar index of this block relative to the current scan\n combine_fn: Function ``(value, value) -> value`` which is scanned over\n init: Scalar value equal to the identiy of combine_fn\n \"\"\"\n # Publish block sum so subsequent blocks don't get stuck waiting for us\n if index > 0:\n block_value_u64 = block_value.to(tl.uint64, bitcast=True)\n tl.store(scratch_base + 3 * index + 1, block_value_u64)\n tl.debug_barrier()\n flag_one = tl.full([], 1, tl.uint64)\n tl.atomic_xchg(scratch_base + 3 * index + 0, flag_one, sem=\"release\")\n\n # Calculate exclusive prefix scan\n exclusive_prefix = tl.zeros([], block_value.dtype)\n prefix_valid = False\n test_target = index - 1\n while test_target >= 0:\n flag = tl.full([], 0, tl.uint64)\n while flag == 0:\n flag = tl.atomic_add(scratch_base + 3 * test_target + 0, 0, sem=\"acquire\")\n\n value_u64 = tl.load(scratch_base + 3 * test_target + flag.to(tl.int32))\n value = value_u64.to(block_value.dtype, bitcast=True)\n if prefix_valid:\n exclusive_prefix = combine_fn(value, exclusive_prefix)\n else:\n exclusive_prefix = value\n prefix_valid = True\n\n if flag == 2:\n test_target = -1\n else:\n test_target = test_target - 1\n\n # Make inclusive block sum visible to other blocks\n if prefix_valid:\n inclusive_prefix = combine_fn(exclusive_prefix, block_value)\n else:\n inclusive_prefix = block_value\n inclusive_prefix_u64 = inclusive_prefix.to(tl.uint64, bitcast=True)\n tl.store(scratch_base + 3 * index + 2, inclusive_prefix_u64)\n tl.debug_barrier()\n flag_two = tl.full([], 2, tl.uint64)\n tl.atomic_xchg(scratch_base + 3 * index + 0, flag_two, sem=\"release\")\n return exclusive_prefix\n\n# Kernel function to compute mantissa and exponent\n@triton.jit\ndef frexp(x):\n # TODO: use inline_asm_elementwise here\n y = libdevice.ilogb(x) + 1\n exponent = tl.where(x == 0, 0, y)\n mantissa = tl.where(x == 0, 0, libdevice.ldexp(x, -y))\n return mantissa, exponent\n", - "description_1": "Use triton language to implement a series of kernel functions, each designed for specific arithmetic and logical operations. The functions include operations like promotion to tensor, checking if a number is floating-point, accumulating products, computing minimum and maximum values (with and without indexes), performing Welford's reduction for variance calculation, asserting device conditions, generating random integers, reducing values using logical operations, performing binary search within buckets, packing and unpacking values with flags, and computing exclusive scan with decoupled lookback (for both generic and 64-bit data). Additionally, it includes computing mantissa and exponent for given values.", - "description_2": "Use triton language to create a set of specialized kernels for arithmetic, reduction, and random operations, along with index-based and conditional computations.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\n\n# Use triton.jit to define a basic kernel that demonstrates simple vector addition.\n@triton.jit\ndef vector_add_kernel(x_ptr, y_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr):\n pid = tl.program_id(axis=0)\n block_start = pid * BLOCK_SIZE\n offsets = block_start + tl.arange(0, BLOCK_SIZE)\n mask = offsets < n_elements\n \n x = tl.load(x_ptr + offsets, mask=mask)\n y = tl.load(y_ptr + offsets, mask=mask)\n result = x + y\n \n tl.store(output_ptr + offsets, result, mask=mask)\n\n\ndef vector_add(x: torch.Tensor, y: torch.Tensor, output: torch.Tensor):\n assert x.shape == y.shape == output.shape, \"Input and output tensors must have the same shape\"\n \n n_elements = x.numel()\n block_size = 1024\n grid = lambda opt: (triton.cdiv(n_elements, opt.block_size),)\n vector_add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=block_size)\n\n\n# Example usage of the vector_add function\nx = torch.rand(10000, device='cuda')\ny = torch.rand(10000, device='cuda')\noutput = torch.empty_like(x)\nvector_add(x, y, output)\n", - "description_1": "Use triton language to create a kernel named vector_add_kernel that performs element-wise addition of two input vectors x_ptr and y_ptr, storing the result in output_ptr. The kernel operates on BLOCK_SIZE elements per thread block, ensuring that accesses beyond n_elements are masked out. The kernel is launched by the vector_add function, which calculates the grid size needed for execution and asserts the input and output tensor shapes are compatible. The grid is determined by triton.cdiv(n_elements, block_size), and the block_size is set as a compile-time constant.", - "description_2": "Use triton language to perform element-wise addition of two vectors using a custom kernel with block-based parallel execution.", - "difficulty": 3 - }, - { - "code": "import math\nimport torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef _sampled_addmm_kernel(\n alpha,\n beta,\n IS_BETA_ZERO: tl.constexpr,\n BLOCKSIZE_ROW: tl.constexpr,\n BLOCKSIZE_COL: tl.constexpr,\n k,\n TILE_K: tl.constexpr,\n values_ptr,\n values_batch_stride,\n values_nnz_stride,\n values_row_block_stride,\n values_col_block_stride,\n crow_indices_ptr,\n crow_indices_batch_stride,\n crow_indices_stride,\n col_indices_ptr,\n col_indices_batch_stride,\n col_indices_stride,\n mat1_ptr,\n mat1_batch_stride,\n mat1_tiled_row_stride,\n mat1_tiled_col_stride,\n mat1_row_block_stride,\n mat1_col_block_stride,\n mat2_ptr,\n mat2_batch_stride,\n mat2_tiled_row_stride,\n mat2_tiled_col_stride,\n mat2_row_block_stride,\n mat2_col_block_stride,\n acc_dtype: tl.constexpr,\n allow_tf32: tl.constexpr,\n):\n batch_pid = tl.program_id(axis=1)\n row_block_pid = tl.program_id(axis=0)\n\n crow_indices_offset_ptr = (\n crow_indices_ptr\n + crow_indices_batch_stride * batch_pid\n + crow_indices_stride * row_block_pid\n )\n nnz_offset = tl.load(crow_indices_offset_ptr)\n nnz_offset_next = tl.load(crow_indices_offset_ptr + crow_indices_stride)\n\n # Compute nnz for the row with number row_block_pid.\n # If it is zero, skip the row.\n row_nnz = nnz_offset_next - nnz_offset\n if row_nnz == 0:\n return\n\n row_block_arange = tl.arange(0, BLOCKSIZE_ROW)\n col_block_arange = tl.arange(0, BLOCKSIZE_COL)\n\n # Pointers are set to the first block of the current row.\n values_block_ptrs = (\n values_ptr\n + values_batch_stride * batch_pid\n + values_nnz_stride * nnz_offset\n + values_row_block_stride * row_block_arange[:, None]\n + values_col_block_stride * col_block_arange[None, :]\n )\n\n col_index_nnz_ptr = (\n col_indices_ptr\n + col_indices_batch_stride * batch_pid\n + col_indices_stride * nnz_offset\n )\n\n # Advance mat1 to the current tiled row, ignore columns.\n mat1_block_ptrs = (\n mat1_ptr\n + mat1_batch_stride * batch_pid\n + mat1_tiled_row_stride * row_block_pid\n + mat1_row_block_stride * row_block_arange[:, None]\n )\n\n # Advance mat2 in batch and block col dimension.\n mat2_block_ptrs = (\n mat2_ptr\n + mat2_batch_stride * batch_pid\n + mat2_col_block_stride * col_block_arange[None, :]\n )\n\n k_tile_arange = tl.arange(0, TILE_K)\n for _ in range(row_nnz):\n acc_block = tl.zeros((BLOCKSIZE_ROW, BLOCKSIZE_COL), dtype=acc_dtype)\n\n # find column block index\n col_block = tl.load(col_index_nnz_ptr)\n\n for k_tile in range(0, k, TILE_K):\n k_offsets = k_tile + k_tile_arange\n mask_k = k_offsets < k\n\n mat1_block = tl.load(\n mat1_block_ptrs\n + mat1_col_block_stride * k_offsets[None, :],\n mask=mask_k[None, :], other=0.0\n )\n\n mat2_block = tl.load(\n mat2_block_ptrs\n + mat2_tiled_col_stride * col_block\n + mat2_row_block_stride * k_offsets[:, None],\n mask=mask_k[:, None], other=0.0\n )\n\n acc_block += tl.dot(mat1_block, mat2_block, allow_tf32=allow_tf32, out_dtype=acc_dtype)\n\n if IS_BETA_ZERO:\n acc_block *= alpha\n else:\n acc_block = alpha * acc_block + beta * tl.load(values_block_ptrs)\n\n # write result\n tl.store(values_block_ptrs, acc_block.to(values_ptr.dtype.element_ty))\n\n # advance val/col_index ptrs to the next block in the row.\n values_block_ptrs += values_nnz_stride\n col_index_nnz_ptr += col_indices_stride\n\n\ndef sampled_addmm(\n input: torch.Tensor,\n mat1: torch.Tensor,\n mat2: torch.Tensor,\n *,\n beta=1.0,\n alpha=1.0,\n out: Optional[torch.Tensor] = None,\n skip_checks: bool = False,\n max_grid: Optional[Tuple[Optional[int], Optional[int], Optional[int]]] = None,\n):\n f_name = \"sampled_addmm\"\n\n check_bsr_layout(f_name, input)\n input_broadcasted = broadcast_batch_dims_bsr(f_name, input, mat1, mat2)\n\n if not skip_checks:\n check_device(f_name, mat1, input.device)\n check_device(f_name, mat2, input.device)\n if beta != 0.0 and input.dtype is torch.bool:\n check(\n False,\n f\"{f_name}(): having beta == {beta} not equal to 0.0 with boolean mask is not allowed.\"\n )\n if input.dtype is not torch.bool:\n check_dtype(f_name, mat1, input.dtype)\n check_dtype(f_name, mat2, input.dtype)\n else:\n check_dtype(f_name, mat1, mat2.dtype)\n check_mm_compatible_shapes(f_name, mat1, mat2)\n if out is not None:\n check_bsr_layout(f_name, out)\n check_device(f_name, out, mat1.device)\n check_dtype(f_name, out, input.dtype)\n check(\n out.shape == input_broadcasted.shape\n and out._nnz() == input_broadcasted._nnz(),\n f\"{f_name}(): Expects `out` to be of shape {input_broadcasted.shape} \"\n f\"and with nnz equal to {input_broadcasted._nnz()} \"\n f\"but got out.shape = {out.shape} and out.nnz = {out._nnz()}\"\n )\n\n if out is None:\n out = input_broadcasted.to(mat1.dtype, copy=True)\n else:\n out.copy_(input_broadcasted)\n\n if out.numel() == 0 or out._nnz() == 0:\n return out\n\n blocksize = out.values().shape[-2:]\n m = mat1.size(-2)\n n = mat2.size(-1)\n k = mat1.size(-1)\n\n # NOTE: (m, 0) @ (0, n) == zeros(m, n)\n if alpha == 0.0 or k == 0:\n out.values().mul_(beta)\n return out\n\n # prepare inputs by reshaping them to be kernel-compatible\n out_backup = out\n crow_indices, col_indices, values, mat1, mat2 = prepare_inputs(out, mat1, mat2)\n\n mat1 = tile_to_blocksize(mat1, (blocksize[0], k))\n mat2 = tile_to_blocksize(mat2, (k, blocksize[1]))\n tile_k = max(*blocksize)\n\n _run_sampled_addmm_kernel(\n alpha, beta, beta == 0.0,\n blocksize, k, tile_k,\n values, crow_indices, col_indices,\n mat1, mat2,\n max_grid\n )\n\n # If nnz x block strides are not the same in out_backup.values and values,\n # it means that out_backup.values and values are not the views of each other,\n # so we have to copy.\n if out_backup.values().stride()[-3:] != values.stride()[-3:]:\n out_backup.values().copy_(values.reshape(out_backup.values().shape))\n return out_backup\n\ndef _scaled_dot_product_attention(\n query: torch.Tensor,\n key: torch.Tensor,\n value: torch.Tensor,\n attn_mask: Optional[torch.Tensor],\n dropout_p: float = 0.0,\n is_causal: bool = False,\n scale: Optional[float] = None\n):\n f_name = \"_scaled_dot_product_attention\"\n check(\n not is_causal,\n f\"{f_name}(): is_causal == True is not supported.\"\n )\n check(\n attn_mask is not None,\n f\"{f_name}(): attn_mask == None is not supported.\"\n )\n assert attn_mask is not None\n\n check(\n attn_mask.layout == torch.sparse_bsr,\n f\"{f_name}(): \"\n f\"attn_mask.layout must be {torch.sparse_bsr}, but got \"\n f\"attn_mask.layout == {attn_mask.layout}.\"\n )\n\n check_device(f_name, key, query.device)\n check_device(f_name, value, query.device)\n check_device(f_name, attn_mask, query.device)\n\n check_dtype(f_name, key, query.dtype)\n check_dtype(f_name, value, query.dtype)\n if attn_mask.dtype is not torch.bool:\n check_dtype(f_name, attn_mask, query.dtype)\n\n sdpa = sampled_addmm(attn_mask, query, key.transpose(-2, -1), beta=0.0, skip_checks=False)\n if scale is None and query.size(-1) == 0 or scale == 0.0:\n check(\n False,\n f\"{f_name}(): current value of scale == {scale} \"\n \"results in division by zero.\"\n )\n scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale\n sdpa.values().mul_(scale_factor)\n sdpa = bsr_softmax(sdpa)\n torch.nn.functional.dropout(sdpa.values(), p=dropout_p, inplace=True)\n sdpa = bsr_dense_mm(sdpa, value)\n return sdpa\n", - "description_1": "Use triton language to implement a sparse matrix multiplication kernel for matrices stored in a block sparse row (BSR) format, with support for broadcasting batch dimensions and different tensor layouts. The kernel supports calculating matrix products with specified block sizes and handles zero-matrix scenarios.", - "description_2": "Use triton language to implement scaled dot product attention using a BSR sparse mask with dropout support.", - "difficulty": 4 - }, - { - "code": "import triton\nfrom triton import language as tl\n\n@triton.jit\ndef add_kernel(\n in_ptr0,\n in_ptr1,\n out_ptr,\n n_elements,\n BLOCK_SIZE: \"tl.constexpr\",\n):\n pid = tl.program_id(axis=0)\n block_start = pid * BLOCK_SIZE\n offsets = block_start + tl.arange(0, BLOCK_SIZE)\n mask = offsets < n_elements\n x = tl.load(in_ptr0 + offsets, mask=mask)\n y = tl.load(in_ptr1 + offsets, mask=mask)\n output = x + y\n tl.store(out_ptr + offsets, output, mask=mask)\n\n@triton.jit\ndef add_kernel_with_optional_param(\n in_ptr0,\n in_ptr1,\n out_ptr,\n n_elements,\n ARGS_PASSED: \"tl.constexpr\",\n BLOCK_SIZE: \"tl.constexpr\",\n):\n pid = tl.program_id(axis=0)\n block_start = pid * BLOCK_SIZE\n offsets = block_start + tl.arange(0, BLOCK_SIZE)\n mask = offsets < n_elements\n x = tl.load(in_ptr0 + offsets, mask=mask)\n if ARGS_PASSED == \"two\":\n y = tl.load(in_ptr1 + offsets, mask=mask)\n output = x + y\n else:\n output = x\n tl.store(out_ptr + offsets, output, mask=mask)\n\n@triton.autotune(\n configs=[\n triton.Config({\"BLOCK_SIZE\": 128}, num_stages=3, num_warps=8),\n triton.Config({\"BLOCK_SIZE\": 128}, num_stages=4, num_warps=4),\n triton.Config({\"BLOCK_SIZE\": 64}, num_stages=3, num_warps=8),\n triton.Config({\"BLOCK_SIZE\": 64}, num_stages=4, num_warps=4),\n ],\n key=[],\n)\n@triton.jit\ndef add_kernel_autotuned(\n in_ptr0,\n in_ptr1,\n out_ptr,\n n_elements,\n BLOCK_SIZE: \"tl.constexpr\",\n):\n pid = tl.program_id(axis=0)\n block_start = pid * BLOCK_SIZE\n offsets = block_start + tl.arange(0, BLOCK_SIZE)\n mask = offsets < n_elements\n x = tl.load(in_ptr0 + offsets, mask=mask)\n y = tl.load(in_ptr1 + offsets, mask=mask)\n output = x + y\n tl.store(out_ptr + offsets, output, mask=mask)\n\n@triton.autotune(\n configs=[\n triton.Config(\n {\"BLOCK_SIZE_X\": 128, \"BLOCK_SIZE_Y\": 128}, num_stages=3, num_warps=8\n ),\n triton.Config(\n {\"BLOCK_SIZE_X\": 128, \"BLOCK_SIZE_Y\": 128}, num_stages=4, num_warps=4\n ),\n triton.Config(\n {\"BLOCK_SIZE_X\": 64, \"BLOCK_SIZE_Y\": 64}, num_stages=3, num_warps=8\n ),\n triton.Config(\n {\"BLOCK_SIZE_X\": 64, \"BLOCK_SIZE_Y\": 64}, num_stages=4, num_warps=4\n ),\n ],\n key=[],\n)\n@triton.jit\ndef add_kernel_2d_autotuned(\n in_ptr0,\n in_ptr1,\n out_ptr,\n x_elements,\n y_elements,\n BLOCK_SIZE_X: \"tl.constexpr\",\n BLOCK_SIZE_Y: \"tl.constexpr\",\n):\n xoffset = tl.program_id(0) * BLOCK_SIZE_X\n xindex = xoffset + tl.arange(0, BLOCK_SIZE_X)[:, None]\n xmask = xindex < x_elements\n yoffset = tl.program_id(1) * BLOCK_SIZE_Y\n yindex = yoffset + tl.arange(0, BLOCK_SIZE_Y)[None, :]\n ymask = yindex < y_elements\n x1 = xindex\n y0 = yindex\n tmp0 = tl.load(in_ptr0 + (x1 + (x_elements * y0)), xmask & ymask)\n tmp1 = tl.load(in_ptr0 + (y0 + (y_elements * x1)), xmask & ymask)\n tmp2 = tmp0 + tmp1\n tl.store(out_ptr + (x1 + (x_elements * y0)), tmp2, xmask & ymask)\n\n@triton.jit\ndef add_kernel_with_scaling(\n in_ptr0,\n in_ptr1,\n out_ptr,\n n_elements,\n scaling_factor,\n BLOCK_SIZE: \"tl.constexpr\",\n):\n pid = tl.program_id(axis=0)\n block_start = pid * BLOCK_SIZE\n offsets = block_start + tl.arange(0, BLOCK_SIZE)\n mask = offsets < n_elements\n x = tl.load(in_ptr0 + offsets, mask=mask)\n y = tl.load(in_ptr1 + offsets, mask=mask)\n output = (x + y) * scaling_factor\n tl.store(out_ptr + offsets, output, mask=mask)\n\n@triton.jit\ndef mul2_kernel(\n in_ptr0,\n out_ptr,\n n_elements,\n BLOCK_SIZE: \"tl.constexpr\",\n):\n pid = tl.program_id(axis=0)\n block_start = pid * BLOCK_SIZE\n offsets = block_start + tl.arange(0, BLOCK_SIZE)\n mask = offsets < n_elements\n x = tl.load(in_ptr0 + offsets, mask=mask)\n output = 2 * x\n tl.store(out_ptr + offsets, output, mask=mask)\n\n@triton.jit\ndef mul2_inplace_kernel(\n ptr,\n n_elements,\n BLOCK_SIZE: \"tl.constexpr\",\n):\n pid = tl.program_id(axis=0)\n block_start = pid * BLOCK_SIZE\n offsets = block_start + tl.arange(0, BLOCK_SIZE)\n mask = offsets < n_elements\n x = tl.load(ptr + offsets, mask=mask)\n output = 2 * x\n tl.store(ptr + offsets, output, mask=mask)\n\n@triton.jit\ndef zero_negs(x):\n return tl.where(x >= 0, x, 0)\n\n@triton.jit\ndef indirection_kernel(\n in_ptr0,\n out_ptr,\n n_elements,\n BLOCK_SIZE: \"tl.constexpr\",\n ACTIVATION: \"tl.constexpr\",\n):\n pid = tl.program_id(axis=0)\n block_start = pid * BLOCK_SIZE\n offsets = block_start + tl.arange(0, BLOCK_SIZE)\n mask = offsets < n_elements\n if ACTIVATION == \"mul2_inplace_kernel\":\n mul2_inplace_kernel(in_ptr0, n_elements, BLOCK_SIZE=BLOCK_SIZE)\n elif ACTIVATION == \"add_kernel\":\n add_kernel(in_ptr0, in_ptr0, out_ptr, n_elements, BLOCK_SIZE=BLOCK_SIZE)\n x = tl.load(in_ptr0 + offsets, mask=mask)\n tl.store(out_ptr + offsets, x, mask=mask)\n\n@triton.jit\ndef double_strided_kernel(\n in_ptr,\n out_ptr,\n in_y_stride,\n out_y_stride,\n X_BLOCK_SIZE: \"tl.constexpr\",\n Y_BLOCK_SIZE: \"tl.constexpr\",\n):\n xid = tl.program_id(axis=0)\n yid = tl.program_id(axis=1)\n x_start = xid * X_BLOCK_SIZE\n y_start = yid * Y_BLOCK_SIZE\n x_offsets = x_start + tl.arange(0, X_BLOCK_SIZE)\n y_offsets = y_start + tl.arange(0, Y_BLOCK_SIZE)\n src_offsets = y_offsets[:, None] * in_y_stride + x_offsets[None, :]\n dst_offsets = y_offsets[:, None] * out_y_stride + x_offsets[None, :]\n src = tl.load(in_ptr + src_offsets)\n tl.store(out_ptr + dst_offsets, src * 2.0)\n\n@triton.jit\ndef inline_asm_kernel(X, Y, Z, n: \"tl.constexpr\", BLOCK: \"tl.constexpr\"):\n x = tl.load(X + tl.arange(0, BLOCK))\n y = tl.load(Y + tl.arange(0, BLOCK))\n s = tl.full([BLOCK], n, tl.int32)\n z = tl.inline_asm_elementwise(\n \"shf.l.wrap.b32 $0, $1, $2, $3;\",\n \"=r,r, r, r\",\n [x, y, s],\n dtype=tl.int32,\n is_pure=True,\n pack=1,\n )\n tl.store(Z + tl.arange(0, BLOCK), z)\n\n@triton.jit\ndef add_kernel_with_block_ptr(\n x_ptr,\n y_ptr,\n output_ptr,\n n_elements,\n BLOCK_SIZE: tl.constexpr,\n):\n pid = tl.program_id(axis=0)\n block_start = pid * BLOCK_SIZE\n x = tl.load(\n tl.make_block_ptr(\n base=x_ptr,\n shape=[n_elements],\n strides=[1],\n offsets=[block_start],\n block_shape=[BLOCK_SIZE],\n order=[0],\n ),\n boundary_check=[0],\n )\n y = tl.load(\n tl.make_block_ptr(\n base=y_ptr,\n shape=[n_elements],\n strides=[1],\n offsets=[block_start],\n block_shape=[BLOCK_SIZE],\n order=[0],\n ),\n boundary_check=[0],\n )\n output = x + y\n tl.store(\n tl.make_block_ptr(\n base=output_ptr,\n shape=[n_elements],\n strides=[1],\n offsets=[block_start],\n block_shape=[BLOCK_SIZE],\n order=[0],\n ),\n output,\n boundary_check=[0],\n )\n\n@triton.jit\ndef kernel_with_block_ptr_2d(\n x_ptr,\n output_ptr,\n n_elements,\n BLOCK_SIZE: tl.constexpr,\n):\n pid = tl.program_id(axis=0)\n block_start = pid * BLOCK_SIZE\n x = tl.load(\n tl.make_block_ptr(\n base=x_ptr,\n shape=[n_elements, 1],\n strides=[1, 1],\n offsets=[block_start, 0],\n block_shape=[BLOCK_SIZE, 1],\n order=[1, 0],\n ),\n boundary_check=[0],\n )\n output = x\n tl.store(\n tl.make_block_ptr(\n base=output_ptr,\n shape=[n_elements, 1],\n strides=[1, 1],\n offsets=[block_start, 0],\n block_shape=[BLOCK_SIZE, 1],\n order=[1, 0],\n ),\n output,\n boundary_check=[0],\n )\n\nfrom triton.language import load, store\n\n@triton.jit\ndef add_kernel_with_import(\n in_ptr0,\n in_ptr1,\n out_ptr,\n n_elements,\n BLOCK_SIZE: \"tl.constexpr\",\n):\n pid = tl.program_id(axis=0)\n block_start = pid * BLOCK_SIZE\n offsets = block_start + tl.arange(0, BLOCK_SIZE)\n mask = offsets < n_elements\n x = load(in_ptr0 + offsets, mask=mask)\n y = load(in_ptr1 + offsets, mask=mask)\n output = x + y\n store(out_ptr + offsets, output, mask=mask)\n\n@triton.jit\ndef cond_op_kernel(\n in_ptr0,\n in_ptr1,\n out_ptr,\n n_elements,\n BLOCK_SIZE: \"tl.constexpr\",\n):\n pid = tl.program_id(axis=0)\n block_start = pid * BLOCK_SIZE\n offsets = block_start + tl.arange(0, BLOCK_SIZE)\n mask = offsets < n_elements\n x = tl.load(in_ptr0 + offsets, mask=mask)\n y = tl.load(in_ptr1 + offsets, mask=mask)\n if tl.program_id(0) == 0:\n output = x + y\n else:\n output = x * y\n tl.store(out_ptr + offsets, output, mask=mask)\n\n@triton.jit\ndef atomic_add_kernel(\n in_ptr0,\n in_ptr1,\n out_ptr,\n n_elements,\n BLOCK_SIZE: \"tl.constexpr\",\n):\n pid = tl.program_id(axis=0)\n block_start = pid * BLOCK_SIZE\n offsets = block_start + tl.arange(0, BLOCK_SIZE)\n mask = offsets < n_elements\n x = tl.load(in_ptr0 + offsets, mask=mask)\n y = tl.load(in_ptr1 + offsets, mask=mask)\n output = x + y\n tl.atomic_add(out_ptr + offsets, output, mask=mask)\n\n@triton.jit\ndef add_4_times_kernel(\n in_ptr0,\n in_ptr1,\n out_ptr,\n n_elements,\n BLOCK_SIZE: \"tl.constexpr\",\n):\n pid = tl.program_id(axis=0)\n block_start = pid * BLOCK_SIZE\n offsets = block_start + tl.arange(0, BLOCK_SIZE)\n mask = offsets < n_elements\n x = tl.load(in_ptr0 + offsets, mask=mask)\n y = tl.load(in_ptr1 + offsets, mask=mask)\n for i in range(2):\n output = x + y\n tl.store(out_ptr + offsets, output, mask=mask)\n i = 2\n while i > 0:\n i -= 1\n output = x + y\n tl.store(out_ptr + offsets, output, mask=mask)\n\n@triton.jit\ndef add_kernel_out_of_order_fn2(\n in_ptr0,\n in_ptr1,\n n_elements,\n out_ptr,\n BLOCK_SIZE: \"tl.constexpr\",\n):\n pid = tl.program_id(axis=0)\n block_start = pid * BLOCK_SIZE\n offsets = block_start + tl.arange(0, BLOCK_SIZE)\n mask = offsets < n_elements\n x = tl.load(in_ptr0 + offsets, mask=mask)\n y = tl.load(in_ptr1 + offsets, mask=mask)\n output = x + y\n tl.store(out_ptr + offsets, output, mask=mask)\n", - "description_1": "Use triton language to create several kernel functions including add_kernel, add_kernel_with_optional_param, add_kernel_autotuned, add_kernel_2d_autotuned, and others. Each kernel takes various pointers, elements count, block sizes, and performs arithmetic operations like addition and multiplication across blocks. Functions employ triton's load/store mechanisms, conditional logic, loop structures, and various triton-specific optimizations like autotune.", - "description_2": "Use triton language to implement multiple kernels that perform vectorized arithmetic operations with optional parameters, conditional logic, and 2D tiling, optimized with autotuning.", - "difficulty": 4 - }, - { - "code": "import triton\nimport torch\n\n# A simple Triton kernel that adds two numbers.\n@triton.jit\ndef add_kernel(X, Y, OUT, N):\n pid = triton.program_id(0)\n if pid < N:\n OUT[pid] = X[pid] + Y[pid]\n\n# Function to call the above kernel\ndef call_add_kernel(X, Y, OUT, N):\n grid = (N,)\n add_kernel[grid](X, Y, OUT, N)\n\n# Example of how to call the kernel\nX = torch.tensor([1.0, 2.0, 3.0])\nY = torch.tensor([4.0, 5.0, 6.0])\nOUT = torch.empty_like(X)\nN = X.size(0)\ncall_add_kernel(X, Y, OUT, N)\nprint(OUT) # Expected: tensor([5.0, 7.0, 9.0])\n", - "description_1": "Use triton language to implement a simple addition kernel that adds two vectors X and Y element-wise and stores the result in OUT. The kernel has four parameters: X, Y, OUT, and N, where N is the number of elements to process. Use the call_add_kernel function to execute the kernel by passing the input tensors and their size.", - "description_2": "Use triton language to create and execute a vector addition kernel.", - "difficulty": 1 - }, - { - "code": "import triton\nimport triton.language as tl\nimport torch\n\n@triton.jit\ndef example_kernel(X, Y, Z, BLOCK_SIZE: tl.constexpr):\n pid = tl.program_id(0)\n block_start = pid * BLOCK_SIZE\n offsets = block_start + tl.arange(0, BLOCK_SIZE)\n mask = offsets < X.shape[0]\n x = tl.load(X + offsets, mask=mask)\n y = tl.load(Y + offsets, mask=mask)\n z = x + y\n tl.store(Z + offsets, z, mask=mask)\n\ndef call_example_kernel(X, Y, Z):\n BLOCK_SIZE = 1024\n grid = lambda meta: (triton.cdiv(X.shape[0], meta['BLOCK_SIZE']),)\n example_kernel[grid](X, Y, Z, BLOCK_SIZE=BLOCK_SIZE)\n\n# Example usage\nX = torch.randn(10240, device='cuda')\nY = torch.randn(10240, device='cuda')\nZ = torch.empty(10240, device='cuda')\ncall_example_kernel(X, Y, Z)\n", - "description_1": "Use triton language to define a kernel that adds two vectors X and Y, storing the result in Z. The kernel is launched with a grid size determined by the size of the input vectors and a block size of 1024. The kernel uses triton's program_id to determine the block of data each thread should process, and uses triton's load and store functions to read from and write to global memory with masking to handle out-of-bounds accesses.", - "description_2": "Use triton language to define a kernel that performs element-wise addition of two vectors with masking for out-of-bounds accesses.", - "difficulty": 1 - }, - { - "code": "import triton\n\n# Kernel definition using the triton language with JIT compilation.\n@triton.jit\ndef example_kernel(X, Y, BLOCK_SIZE: tl.constexpr):\n xpid = tl.program_id(0)\n xnumel = X.shape[0]\n for i in range(xpid, xnumel, BLOCK_SIZE):\n X[i] = X[i] + Y[i]\n\n# A function that calls the Triton kernel.\ndef call_example_kernel(X, Y):\n BLOCK_SIZE = 1024\n example_kernel[(ceildiv(len(X), BLOCK_SIZE),)](X, Y, BLOCK_SIZE)\n\n", - "description_1": "Use triton language to define a kernel 'example_kernel' which takes three arguments: two tensors X, Y and a BLOCK_SIZE constant. It increments each element in X by the corresponding element in Y. The function 'call_example_kernel' sets the block size and calls the Triton kernel with specified grid dimensions.", - "description_2": "Use triton language to define a kernel that increments elements of one tensor by another and a function to configure and invoke this kernel.", - "difficulty": 2 - }, - { - "code": "import triton\nimport triton.language as tl\nimport torch\n\n@triton.jit\ndef add_kernel(X, Y, Z, N):\n pid = tl.program_id(0)\n block_start = pid * 1024\n offsets = block_start + tl.arange(0, 1024)\n mask = offsets < N\n x = tl.load(X + offsets, mask=mask)\n y = tl.load(Y + offsets, mask=mask)\n z = x + y\n tl.store(Z + offsets, z, mask=mask)\n\ndef call_add_kernel(X, Y, Z, N):\n grid = lambda meta: (triton.cdiv(N, 1024),)\n add_kernel[grid](X, Y, Z, N)\n\n# Example usage\nX = torch.randn(1024, device='cuda')\nY = torch.randn(1024, device='cuda')\nZ = torch.empty(1024, device='cuda')\nN = X.numel()\ncall_add_kernel(X, Y, Z, N)\n", - "description_1": "Use triton language to define a kernel 'add_kernel' that takes four arguments: X, Y, Z, and N. The kernel performs element-wise addition of two input tensors X and Y, storing the result in tensor Z. The computation is done in parallel using Triton's program_id and block size of 1024. The kernel is launched with a grid size calculated based on the input size N.", - "description_2": "Use triton language to perform element-wise addition of two tensors on GPU using a custom kernel.", - "difficulty": 1 - }, - { - "code": "import triton\nimport torch\n\n# Example Triton kernel\n@triton.jit\ndef example_kernel(X, Y, Z, BLOCK_SIZE: tl.constexpr):\n # Triton kernel code\n pass\n\n# Function to call the Triton kernel\ndef call_example_kernel(x, y, z, block_size):\n # Call the Triton kernel\n example_kernel[(1,)](x, y, z, BLOCK_SIZE=block_size)\n", - "description_1": "Use triton language to define a kernel 'example_kernel' with 4 parameters: X, Y, Z, and BLOCK_SIZE. The kernel is called using 'call_example_kernel' function with parameters x, y, z, and block_size.", - "description_2": "Use triton language to define a kernel and a function to call it with specified parameters.", - "difficulty": 2 - }, - { - "code": "import triton\nimport triton.language as tl\n\n@triton.jit\ndef promote_to_tensor(x):\n # Addition promotes to tensor for us\n return x + tl.zeros((1,), tl.int1)\n\n@triton.jit\ndef is_floating(x):\n return promote_to_tensor(x).dtype.is_floating()\n\n@triton.jit\ndef _prod_accumulate(a, b):\n return a * b\n\n@triton.jit\ndef prod(input, axis):\n return tl.reduce(input, axis, _prod_accumulate)\n\n@triton.jit\ndef minimum(a, b):\n mask = a < b\n if is_floating(a):\n mask |= a != a\n return tl.where(mask, a, b)\n\n@triton.jit\ndef maximum(a, b):\n mask = a > b\n if is_floating(a):\n mask |= a != a\n return tl.where(mask, a, b)\n\n@triton.jit\ndef min2(a, dim):\n return tl.reduce(a, dim, minimum)\n\n@triton.jit\ndef max2(a, dim):\n return tl.reduce(a, dim, maximum)\n\n@triton.jit\ndef minimum_with_index(a_value, a_index, b_value, b_index):\n mask = a_value < b_value\n equal = a_value == b_value\n if is_floating(a_value):\n a_isnan = a_value != a_value\n b_isnan = b_value != b_value\n mask |= a_isnan and not b_isnan\n # Consider NaNs as equal\n equal |= a_isnan and b_isnan\n # Prefer lowest index if values are equal\n mask |= equal & (a_index < b_index)\n return tl.where(mask, a_value, b_value), tl.where(mask, a_index, b_index)\n\n@triton.jit\ndef maximum_with_index(a_value, a_index, b_value, b_index):\n mask = a_value > b_value\n equal = a_value == b_value\n if is_floating(a_value):\n a_isnan = a_value != a_value\n b_isnan = b_value != b_value\n mask |= a_isnan and not b_isnan\n # Consider NaNs as equal\n equal |= a_isnan and b_isnan\n # Prefer lowest index if values are equal\n mask |= equal & (a_index < b_index)\n return tl.where(mask, a_value, b_value), tl.where(mask, a_index, b_index)\n\n@triton.jit\ndef min_with_index(value, index, dim):\n return tl.reduce((value, index), dim, minimum_with_index)\n\n@triton.jit\ndef max_with_index(value, index, dim):\n return tl.reduce((value, index), dim, maximum_with_index)\n\n@triton.jit\ndef welford_reduce(value, mean, m2, weight, first_iteration):\n if first_iteration:\n new_weight = tl.full(weight.shape, 1, weight.dtype)\n new_mean = value\n new_m2 = tl.zeros_like(m2)\n else:\n delta = value - mean\n new_weight = weight + 1\n new_mean = mean + delta / new_weight\n new_m2 = m2 + delta * (value - new_mean)\n return new_mean, new_m2, new_weight\n\n@triton.jit\ndef welford_combine(mean_1, m2_1, weight_1, mean_2, m2_2, weight_2):\n delta = mean_2 - mean_1\n new_weight = weight_1 + weight_2\n w2_over_w = tl.where(new_weight == 0.0, 0.0, weight_2 / new_weight)\n return (\n mean_1 + delta * w2_over_w,\n m2_1 + m2_2 + delta * delta * weight_1 * w2_over_w,\n new_weight,\n )\n\n@triton.jit\ndef welford(mean, m2, weight, dim):\n return tl.reduce((mean, m2, weight), dim, welford_combine)\n\n@triton.jit\ndef device_assert_then(cond, msg, r):\n tl.device_assert(cond, msg)\n return r\n\n@triton.jit\ndef randint64(seed, offset, low, high):\n r0, r1, r2, r3 = tl.randint4x(seed, offset)\n r0 = r0.to(tl.uint64)\n r1 = r1.to(tl.uint64)\n result = r0 | (r1 << 32)\n size = high - low\n result = result % size.to(tl.uint64)\n result = result.to(tl.int64) + low\n return result\n\n@triton.jit\ndef _any_combine(a, b):\n return a | b\n\n@triton.jit\ndef any(a, dim):\n return tl.reduce(a, dim, _any_combine)\n\n@triton.jit\ndef bucketize_binary_search(\n values, # 1D tensor\n offsets_ptr,\n indexing_dtype,\n right, # bool: if true, use intervals closed on the left; see [Note: Inductor bucketize op]\n OFFSETS_SIZE: int,\n BLOCK_SHAPE, # tuple/list of block shape\n):\n \"\"\"\n See [Note: Inductor bucketize op]\n \"\"\"\n low = tl.zeros(BLOCK_SHAPE, dtype=indexing_dtype)\n high = tl.full(BLOCK_SHAPE, OFFSETS_SIZE, dtype=indexing_dtype)\n full_range = OFFSETS_SIZE + 1\n while full_range > 1:\n mid = (high + low) // 2\n mask = mid < OFFSETS_SIZE\n bucket_upper_bound = tl.load(offsets_ptr + mid, mask=mask, other=0.0)\n if right:\n is_above = values >= bucket_upper_bound\n else:\n is_above = values > bucket_upper_bound\n low = tl.where(is_above & mask, mid + 1, low)\n high = tl.where(is_above, high, mid)\n full_range = (full_range + 1) // 2\n return low\n\n@triton.jit\ndef pack_value_flag(\n value,\n flag,\n DTYPE_VALUE_AS_UINT: tl.constexpr,\n DTYPE_PACK: tl.constexpr,\n):\n # Workaround for triton bug, tensor.to doesn't unwrap constexpr values\n DTYPE_VALUE_AS_UINT = tl.core._constexpr_to_value(DTYPE_VALUE_AS_UINT)\n bitwidth = DTYPE_VALUE_AS_UINT.primitive_bitwidth\n uv = value.to(DTYPE_VALUE_AS_UINT, bitcast=True).to(DTYPE_PACK)\n return flag.to(DTYPE_PACK) | (uv << bitwidth)\n\n@triton.jit\ndef unpack_value(\n pack,\n DTYPE_VALUE,\n DTYPE_VALUE_AS_UINT,\n):\n # Workaround for triton bug, tensor.to doesn't unwrap constexpr values\n DTYPE_VALUE = tl.core._constexpr_to_value(DTYPE_VALUE)\n DTYPE_VALUE_AS_UINT = tl.core._constexpr_to_value(DTYPE_VALUE_AS_UINT)\n bitwidth = DTYPE_VALUE_AS_UINT.primitive_bitwidth\n value_uint = (pack >> bitwidth).to(DTYPE_VALUE_AS_UINT)\n return value_uint.to(DTYPE_VALUE, bitcast=True)\n\n@triton.jit\ndef unpack_flag(pack, DTYPE_FLAG):\n return pack.to(DTYPE_FLAG)\n\n@triton.jit\ndef exclusive_scan_decoupled_lookback(\n scratch_base,\n block_value,\n index,\n combine_fn,\n DTYPE_VALUE_AS_UINT: tl.constexpr,\n DTYPE_PACK: tl.constexpr,\n):\n \"\"\"Compute exclusive scan of a scalar value between blocks\n Ref: https://research.nvidia.com/publication/2016-03_single-pass-parallel-prefix-scan-decoupled-look-back\n scratch_base: Pointer to scratch space in global memory\n block_value: Scalar value for this block\n index: Scalar index of this block relative to the current scan\n combine_fn: Function ``(value, value) -> value`` which is scanned over\n DTYPE_VALUE_AS_UINT: A tl.uint{n} type equal in size to ``block_value``\n DTYPE_PACK: Unsigned type twice the width of block_value\n NOTE: This function is limited to values which are 32-bits or less because\n we need to pack (value, flag) into a single unsigned int.\n \"\"\"\n # Publish block sum so subsequent blocks don't get stuck waiting for us\n DTYPE_VALUE = block_value.dtype\n pack = pack_value_flag(\n block_value,\n tl.full(block_value.shape, 1, DTYPE_VALUE_AS_UINT),\n DTYPE_VALUE_AS_UINT,\n DTYPE_PACK,\n )\n if index > 0:\n tl.atomic_xchg(scratch_base + index, pack, sem=\"relaxed\")\n # Calculate exclusive prefix scan\n exclusive_prefix = tl.zeros([], DTYPE_VALUE)\n prefix_valid = False\n test_target = index - 1\n while test_target >= 0:\n # tl.atomic_load\n flag = tl.full([], 0, DTYPE_VALUE_AS_UINT)\n while flag == 0:\n pack = tl.atomic_add(scratch_base + test_target, 0, sem=\"relaxed\")\n flag = unpack_flag(pack, DTYPE_VALUE_AS_UINT)\n value = unpack_value(pack, DTYPE_VALUE, DTYPE_VALUE_AS_UINT)\n if prefix_valid:\n exclusive_prefix = combine_fn(value, exclusive_prefix)\n else:\n exclusive_prefix = value\n prefix_valid = True\n if flag == 2:\n test_target = -1\n else:\n test_target = test_target - 1\n # Make inclusive block sum visible to other blocks\n if prefix_valid:\n inclusive_prefix = combine_fn(exclusive_prefix, block_value)\n else:\n inclusive_prefix = block_value\n pack = pack_value_flag(\n inclusive_prefix,\n tl.full([], 2, DTYPE_VALUE_AS_UINT),\n DTYPE_VALUE_AS_UINT,\n DTYPE_PACK,\n )\n tl.atomic_xchg(scratch_base + index, pack, sem=\"relaxed\")\n return exclusive_prefix\n\n@triton.jit\ndef exclusive_scan_decoupled_lookback_64(scratch_base, block_value, index, combine_fn):\n \"\"\"Compute exclusive scan of a scalar value between blocks\n Ref: https://research.nvidia.com/publication/2016-03_single-pass-parallel-prefix-scan-decoupled-look-back\n scratch_base: Pointer to scratch space in global memory\n block_value: Scalar value for this block, must be 64-bits wide\n index: Scalar index of this block relative to the current scan\n combine_fn: Function ``(value, value) -> value`` which is scanned over\n init: Scalar value equal to the identiy of combine_fn\n \"\"\"\n # Publish block sum so subsequent blocks don't get stuck waiting for us\n if index > 0:\n block_value_u64 = block_value.to(tl.uint64, bitcast=True)\n tl.store(scratch_base + 3 * index + 1, block_value_u64)\n tl.debug_barrier()\n flag_one = tl.full([], 1, tl.uint64)\n tl.atomic_xchg(scratch_base + 3 * index + 0, flag_one, sem=\"release\")\n # Calculate exclusive prefix scan\n exclusive_prefix = tl.zeros([], block_value.dtype)\n prefix_valid = False\n test_target = index - 1\n while test_target >= 0:\n flag = tl.full([], 0, tl.uint64)\n while flag == 0:\n flag = tl.atomic_add(scratch_base + 3 * test_target + 0, 0, sem=\"acquire\")\n value_u64 = tl.load(scratch_base + 3 * test_target + flag.to(tl.int32))\n value = value_u64.to(block_value.dtype, bitcast=True)\n if prefix_valid:\n exclusive_prefix = combine_fn(value, exclusive_prefix)\n else:\n exclusive_prefix = value\n prefix_valid = True\n if flag == 2:\n test_target = -1\n else:\n test_target = test_target - 1\n # Make inclusive block sum visible to other blocks\n if prefix_valid:\n inclusive_prefix = combine_fn(exclusive_prefix, block_value)\n else:\n inclusive_prefix = block_value\n inclusive_prefix_u64 = inclusive_prefix.to(tl.uint64, bitcast=True)\n tl.store(scratch_base + 3 * index + 2, inclusive_prefix_u64)\n tl.debug_barrier()\n flag_two = tl.full([], 2, tl.uint64)\n tl.atomic_xchg(scratch_base + 3 * index + 0, flag_two, sem=\"release\")\n return exclusive_prefix\n\n@triton.jit\ndef frexp(x):\n # TODO: use inline_asm_elementwise here\n y = libdevice.ilogb(x) + 1\n exponent = tl.where(x == 0, 0, y)\n mantissa = tl.where(x == 0, 0, libdevice.ldexp(x, -y))\n return mantissa, exponent\n", - "description_1": "Use triton language to define various kernels for tensor operations, reductions, and scanning, including functions like promote_to_tensor, is_floating, prod, minimum, maximum, welford_reduce, randint64, and others. These kernels perform tasks such as tensor promotion, floating-point checks, product accumulation, minimum/maximum comparisons, reduction with indexes, random integer generation, and more complex operations like exclusive scans.", - "description_2": "Use triton language to implement kernels for tensor reductions and scanning operations, facilitating complex arithmetic and logical computations on GPU efficiently.", - "difficulty": 4 - }, - { - "code": "import triton\nimport triton.language as tl\nimport torch\n\n@triton.jit\ndef _sampled_addmm_kernel(\n alpha,\n beta,\n IS_BETA_ZERO: tl.constexpr,\n BLOCKSIZE_ROW: tl.constexpr,\n BLOCKSIZE_COL: tl.constexpr,\n k,\n TILE_K: tl.constexpr,\n values_ptr,\n values_batch_stride,\n values_nnz_stride,\n values_row_block_stride,\n values_col_block_stride,\n crow_indices_ptr,\n crow_indices_batch_stride,\n crow_indices_stride,\n col_indices_ptr,\n col_indices_batch_stride,\n col_indices_stride,\n mat1_ptr,\n mat1_batch_stride,\n mat1_tiled_row_stride,\n mat1_tiled_col_stride,\n mat1_row_block_stride,\n mat1_col_block_stride,\n mat2_ptr,\n mat2_batch_stride,\n mat2_tiled_row_stride,\n mat2_tiled_col_stride,\n mat2_row_block_stride,\n mat2_col_block_stride,\n acc_dtype: tl.constexpr,\n allow_tf32: tl.constexpr,\n):\n batch_pid = tl.program_id(axis=1)\n row_block_pid = tl.program_id(axis=0)\n\n crow_indices_offset_ptr = (\n crow_indices_ptr\n + crow_indices_batch_stride * batch_pid\n + crow_indices_stride * row_block_pid\n )\n nnz_offset = tl.load(crow_indices_offset_ptr)\n nnz_offset_next = tl.load(crow_indices_offset_ptr + crow_indices_stride)\n\n # Compute nnz for the row with number row_block_pid.\n # If it is zero, skip the row.\n row_nnz = nnz_offset_next - nnz_offset\n if row_nnz == 0:\n return\n\n row_block_arange = tl.arange(0, BLOCKSIZE_ROW)\n col_block_arange = tl.arange(0, BLOCKSIZE_COL)\n\n # Pointers are set to the first block of the current row.\n values_block_ptrs = (\n values_ptr\n + values_batch_stride * batch_pid\n + values_nnz_stride * nnz_offset\n + values_row_block_stride * row_block_arange[:, None]\n + values_col_block_stride * col_block_arange[None, :]\n )\n\n col_index_nnz_ptr = (\n col_indices_ptr\n + col_indices_batch_stride * batch_pid\n + col_indices_stride * nnz_offset\n )\n\n # Advance mat1 to the current tiled row, ignore columns.\n mat1_block_ptrs = (\n mat1_ptr\n + mat1_batch_stride * batch_pid\n + mat1_tiled_row_stride * row_block_pid\n + mat1_row_block_stride * row_block_arange[:, None]\n )\n\n # Advance mat2 in batch and block col dimension.\n mat2_block_ptrs = (\n mat2_ptr\n + mat2_batch_stride * batch_pid\n + mat2_col_block_stride * col_block_arange[None, :]\n )\n\n k_tile_arange = tl.arange(0, TILE_K)\n for _ in range(row_nnz):\n acc_block = tl.zeros((BLOCKSIZE_ROW, BLOCKSIZE_COL), dtype=acc_dtype)\n\n # find column block index\n col_block = tl.load(col_index_nnz_ptr)\n\n for k_tile in range(0, k, TILE_K):\n k_offsets = k_tile + k_tile_arange\n mask_k = k_offsets < k\n\n mat1_block = tl.load(\n mat1_block_ptrs\n + mat1_col_block_stride * k_offsets[None, :],\n mask=mask_k[None, :], other=0.0\n )\n\n mat2_block = tl.load(\n mat2_block_ptrs\n + mat2_tiled_col_stride * col_block\n + mat2_row_block_stride * k_offsets[:, None],\n mask=mask_k[:, None], other=0.0\n )\n\n acc_block += tl.dot(mat1_block, mat2_block, allow_tf32=allow_tf32, out_dtype=acc_dtype)\n\n if IS_BETA_ZERO:\n acc_block *= alpha\n else:\n acc_block = alpha * acc_block + beta * tl.load(values_block_ptrs)\n\n # write result\n tl.store(values_block_ptrs, acc_block.to(values_ptr.dtype.element_ty))\n\n # advance val/col_index ptrs to the next block in the row.\n values_block_ptrs += values_nnz_stride\n col_index_nnz_ptr += col_indices_stride\n\ndef sampled_addmm(\n input: torch.Tensor,\n mat1: torch.Tensor,\n mat2: torch.Tensor,\n *,\n beta=1.0,\n alpha=1.0,\n out: Optional[torch.Tensor] = None,\n skip_checks: bool = False,\n max_grid: Optional[Tuple[Optional[int], Optional[int], Optional[int]]] = None,\n):\n f_name = \"sampled_addmm\"\n\n check_bsr_layout(f_name, input)\n input_broadcasted = broadcast_batch_dims_bsr(f_name, input, mat1, mat2)\n\n if not skip_checks:\n check_device(f_name, mat1, input.device)\n check_device(f_name, mat2, input.device)\n if beta != 0.0 and input.dtype is torch.bool:\n check(\n False,\n f\"{f_name}(): having beta == {beta} not equal to 0.0 with boolean mask is not allowed.\"\n )\n if input.dtype is not torch.bool:\n check_dtype(f_name, mat1, input.dtype)\n check_dtype(f_name, mat2, input.dtype)\n else:\n check_dtype(f_name, mat1, mat2.dtype)\n check_mm_compatible_shapes(f_name, mat1, mat2)\n if out is not None:\n check_bsr_layout(f_name, out)\n check_device(f_name, out, mat1.device)\n check_dtype(f_name, out, input.dtype)\n check(\n out.shape == input_broadcasted.shape\n and out._nnz() == input._nnz(),\n f\"{f_name}(): Expects `out` to be of shape {input_broadcasted.shape} \"\n f\"and with nnz equal to {input_broadcasted._nnz()} \"\n f\"but got out.shape = {out.shape} and out.nnz = {out._nnz()}\"\n )\n\n if out is None:\n out = input_broadcasted.to(mat1.dtype, copy=True)\n else:\n out.copy_(input_broadcasted)\n\n if out.numel() == 0 or out._nnz() == 0:\n return out\n\n blocksize = out.values().shape[-2:]\n m = mat1.size(-2)\n n = mat2.size(-1)\n k = mat1.size(-1)\n\n # NOTE: (m, 0) @ (0, n) == zeros(m, n)\n if alpha == 0.0 or k == 0:\n out.values().mul_(beta)\n return out\n\n # prepare inputs by reshaping them to be kernel-compatible\n out_backup = out\n crow_indices, col_indices, values, mat1, mat2 = prepare_inputs(out, mat1, mat2)\n\n mat1 = tile_to_blocksize(mat1, (blocksize[0], k))\n mat2 = tile_to_blocksize(mat2, (k, blocksize[1]))\n tile_k = max(*blocksize)\n\n _run_sampled_addmm_kernel(\n alpha, beta, beta == 0.0,\n blocksize, k, tile_k,\n values, crow_indices, col_indices,\n mat1, mat2,\n max_grid\n )\n\n # If nnz x block strides are not the same in out_backup.values and values,\n # it means that out_backup.values and values are not the views of each other,\n # so we have to copy.\n if out_backup.values().stride()[-3:] != values.stride()[-3:]:\n out_backup.values().copy_(values.reshape(out_backup.values().shape))\n return out_backup\n", - "description_1": "Use triton language to implement a sampled matrix multiplication kernel with parameters for alpha, beta, block sizes, and strides for input matrices and indices.", - "description_2": "Use triton language to perform sparse matrix multiplication with customizable block sizes and strides.", - "difficulty": 4 - }, - { - "code": "import triton\nfrom triton import language as tl\n\n@triton.jit\ndef add_kernel(\n in_ptr0,\n in_ptr1,\n out_ptr,\n n_elements,\n BLOCK_SIZE: \"tl.constexpr\",\n):\n pid = tl.program_id(axis=0)\n block_start = pid * BLOCK_SIZE\n offsets = block_start + tl.arange(0, BLOCK_SIZE)\n mask = offsets < n_elements\n x = tl.load(in_ptr0 + offsets, mask=mask)\n y = tl.load(in_ptr1 + offsets, mask=mask)\n output = x + y\n tl.store(out_ptr + offsets, output, mask=mask)\n\n@triton.jit\ndef add_kernel_with_optional_param(\n in_ptr0,\n in_ptr1,\n out_ptr,\n n_elements,\n ARGS_PASSED: \"tl.constexpr\",\n BLOCK_SIZE: \"tl.constexpr\",\n):\n pid = tl.program_id(axis=0)\n block_start = pid * BLOCK_SIZE\n offsets = block_start + tl.arange(0, BLOCK_SIZE)\n mask = offsets < n_elements\n x = tl.load(in_ptr0 + offsets, mask=mask)\n if ARGS_PASSED == \"two\":\n y = tl.load(in_ptr1 + offsets, mask=mask)\n output = x + y\n else:\n output = x\n tl.store(out_ptr + offsets, output, mask=mask)\n\n@triton.autotune(\n configs=[\n triton.Config({\"BLOCK_SIZE\": 128}, num_stages=3, num_warps=8),\n triton.Config({\"BLOCK_SIZE\": 128}, num_stages=4, num_warps=4),\n triton.Config({\"BLOCK_SIZE\": 64}, num_stages=3, num_warps=8),\n triton.Config({\"BLOCK_SIZE\": 64}, num_stages=4, num_warps=4),\n ],\n key=[],\n)\n@triton.jit\ndef add_kernel_autotuned(\n in_ptr0,\n in_ptr1,\n out_ptr,\n n_elements,\n BLOCK_SIZE: \"tl.constexpr\",\n):\n pid = tl.program_id(axis=0)\n block_start = pid * BLOCK_SIZE\n offsets = block_start + tl.arange(0, BLOCK_SIZE)\n mask = offsets < n_elements\n x = tl.load(in_ptr0 + offsets, mask=mask)\n y = tl.load(in_ptr1 + offsets, mask=mask)\n output = x + y\n tl.store(out_ptr + offsets, output, mask=mask)\n\n@triton.autotune(\n configs=[\n triton.Config(\n {\"BLOCK_SIZE_X\": 128, \"BLOCK_SIZE_Y\": 128}, num_stages=3, num_warps=8\n ),\n triton.Config(\n {\"BLOCK_SIZE_X\": 128, \"BLOCK_SIZE_Y\": 128}, num_stages=4, num_warps=4\n ),\n triton.Config(\n {\"BLOCK_SIZE_X\": 64, \"BLOCK_SIZE_Y\": 64}, num_stages=3, num_warps=8\n ),\n triton.Config(\n {\"BLOCK_SIZE_X\": 64, \"BLOCK_SIZE_Y\": 64}, num_stages=4, num_warps=4\n ),\n ],\n key=[],\n)\n@triton.jit\ndef add_kernel_2d_autotuned(\n in_ptr0,\n in_ptr1,\n out_ptr,\n x_elements,\n y_elements,\n BLOCK_SIZE_X: \"tl.constexpr\",\n BLOCK_SIZE_Y: \"tl.constexpr\",\n):\n xoffset = tl.program_id(0) * BLOCK_SIZE_X\n xindex = xoffset + tl.arange(0, BLOCK_SIZE_X)[:, None]\n xmask = xindex < x_elements\n yoffset = tl.program_id(1) * BLOCK_SIZE_Y\n yindex = yoffset + tl.arange(0, BLOCK_SIZE_Y)[None, :]\n ymask = yindex < y_elements\n x1 = xindex\n y0 = yindex\n tmp0 = tl.load(in_ptr0 + (x1 + (x_elements * y0)), xmask & ymask)\n tmp1 = tl.load(in_ptr0 + (y0 + (y_elements * x1)), xmask & ymask)\n tmp2 = tmp0 + tmp1\n tl.store(out_ptr + (x1 + (x_elements * y0)), tmp2, xmask & ymask)\n\n@triton.jit\ndef add_kernel_with_scaling(\n in_ptr0,\n in_ptr1,\n out_ptr,\n n_elements,\n scaling_factor,\n BLOCK_SIZE: \"tl.constexpr\",\n):\n pid = tl.program_id(axis=0)\n block_start = pid * BLOCK_SIZE\n offsets = block_start + tl.arange(0, BLOCK_SIZE)\n mask = offsets < n_elements\n x = tl.load(in_ptr0 + offsets, mask=mask)\n y = tl.load(in_ptr1 + offsets, mask=mask)\n output = (x + y) * scaling_factor\n tl.store(out_ptr + offsets, output, mask=mask)\n\n@triton.jit\ndef mul2_kernel(\n in_ptr0,\n out_ptr,\n n_elements,\n BLOCK_SIZE: \"tl.constexpr\",\n):\n pid = tl.program_id(axis=0)\n block_start = pid * BLOCK_SIZE\n offsets = block_start + tl.arange(0, BLOCK_SIZE)\n mask = offsets < n_elements\n x = tl.load(in_ptr0 + offsets, mask=mask)\n output = 2 * x\n tl.store(out_ptr + offsets, output, mask=mask)\n\n@triton.jit\ndef mul2_inplace_kernel(\n ptr,\n n_elements,\n BLOCK_SIZE: \"tl.constexpr\",\n):\n pid = tl.program_id(axis=0)\n block_start = pid * BLOCK_SIZE\n offsets = block_start + tl.arange(0, BLOCK_SIZE)\n mask = offsets < n_elements\n x = tl.load(ptr + offsets, mask=mask)\n output = 2 * x\n tl.store(ptr + offsets, output, mask=mask)\n\n@triton.jit\ndef zero_negs(x):\n return tl.where(x >= 0, x, 0)\n\n@triton.jit\ndef indirection_kernel(\n in_ptr0,\n out_ptr,\n n_elements,\n BLOCK_SIZE: \"tl.constexpr\",\n ACTIVATION: \"tl.constexpr\",\n):\n pid = tl.program_id(axis=0)\n block_start = pid * BLOCK_SIZE\n offsets = block_start + tl.arange(0, BLOCK_SIZE)\n mask = offsets < n_elements\n if ACTIVATION == \"mul2_inplace_kernel\":\n mul2_inplace_kernel(in_ptr0, n_elements, BLOCK_SIZE=BLOCK_SIZE)\n elif ACTIVATION == \"add_kernel\":\n add_kernel(in_ptr0, in_ptr0, out_ptr, n_elements, BLOCK_SIZE=BLOCK_SIZE)\n x = tl.load(in_ptr0 + offsets, mask=mask)\n tl.store(out_ptr + offsets, x, mask=mask)\n\n@triton.jit\ndef double_strided_kernel(\n in_ptr,\n out_ptr,\n in_y_stride,\n out_y_stride,\n X_BLOCK_SIZE: \"tl.constexpr\",\n Y_BLOCK_SIZE: \"tl.constexpr\",\n):\n xid = tl.program_id(axis=0)\n yid = tl.program_id(axis=1)\n x_start = xid * X_BLOCK_SIZE\n y_start = yid * Y_BLOCK_SIZE\n x_offsets = x_start + tl.arange(0, X_BLOCK_SIZE)\n y_offsets = y_start + tl.arange(0, Y_BLOCK_SIZE)\n src_offsets = y_offsets[:, None] * in_y_stride + x_offsets[None, :]\n dst_offsets = y_offsets[:, None] * out_y_stride + x_offsets[None, :]\n src = tl.load(in_ptr + src_offsets)\n tl.store(out_ptr + dst_offsets, src * 2.0)\n\n@triton.jit\ndef inline_asm_kernel(X, Y, Z, n: \"tl.constexpr\", BLOCK: \"tl.constexpr\"):\n x = tl.load(X + tl.arange(0, BLOCK))\n y = tl.load(Y + tl.arange(0, BLOCK))\n s = tl.full([BLOCK], n, tl.int32)\n z = tl.inline_asm_elementwise(\n \"shf.l.wrap.b32 $0, $1, $2, $3;\",\n \"=r,r, r, r\",\n [x, y, s],\n dtype=tl.int32,\n is_pure=True,\n pack=1,\n )\n tl.store(Z + tl.arange(0, BLOCK), z)\n\n@triton.jit\ndef add_kernel_with_block_ptr(\n x_ptr,\n y_ptr,\n output_ptr,\n n_elements,\n BLOCK_SIZE: tl.constexpr,\n):\n pid = tl.program_id(axis=0)\n block_start = pid * BLOCK_SIZE\n x = tl.load(\n tl.make_block_ptr(\n base=x_ptr,\n shape=[n_elements],\n strides=[1],\n offsets=[block_start],\n block_shape=[BLOCK_SIZE],\n order=[0],\n ),\n boundary_check=[0],\n )\n y = tl.load(\n tl.make_block_ptr(\n base=y_ptr,\n shape=[n_elements],\n strides=[1],\n offsets=[block_start],\n block_shape=[BLOCK_SIZE],\n order=[0],\n ),\n boundary_check=[0],\n )\n output = x + y\n tl.store(\n tl.make_block_ptr(\n base=output_ptr,\n shape=[n_elements],\n strides=[1],\n offsets=[block_start],\n block_shape=[BLOCK_SIZE],\n order=[0],\n ),\n output,\n boundary_check=[0],\n )\n\n@triton.jit\ndef kernel_with_block_ptr_2d(\n x_ptr,\n output_ptr,\n n_elements,\n BLOCK_SIZE: tl.constexpr,\n):\n pid = tl.program_id(axis=0)\n block_start = pid * BLOCK_SIZE\n x = tl.load(\n tl.make_block_ptr(\n base=x_ptr,\n shape=[n_elements, 1],\n strides=[1, 1],\n offsets=[block_start, 0],\n block_shape=[BLOCK_SIZE, 1],\n order=[1, 0],\n ),\n boundary_check=[0],\n )\n output = x\n tl.store(\n tl.make_block_ptr(\n base=output_ptr,\n shape=[n_elements, 1],\n strides=[1, 1],\n offsets=[block_start, 0],\n block_shape=[BLOCK_SIZE, 1],\n order=[1, 0],\n ),\n output,\n boundary_check=[0],\n )\n\nfrom triton.language import load, store\n\n@triton.jit\ndef add_kernel_with_import(\n in_ptr0,\n in_ptr1,\n out_ptr,\n n_elements,\n BLOCK_SIZE: \"tl.constexpr\",\n):\n pid = tl.program_id(axis=0)\n block_start = pid * BLOCK_SIZE\n offsets = block_start + tl.arange(0, BLOCK_SIZE)\n mask = offsets < n_elements\n x = load(in_ptr0 + offsets, mask=mask)\n y = load(in_ptr1 + offsets, mask=mask)\n output = x + y\n store(out_ptr + offsets, output, mask=mask)\n\n@triton.jit\ndef cond_op_kernel(\n in_ptr0,\n in_ptr1,\n out_ptr,\n n_elements,\n BLOCK_SIZE: \"tl.constexpr\",\n):\n pid = tl.program_id(axis=0)\n block_start = pid * BLOCK_SIZE\n offsets = block_start + tl.arange(0, BLOCK_SIZE)\n mask = offsets < n_elements\n x = tl.load(in_ptr0 + offsets, mask=mask)\n y = tl.load(in_ptr1 + offsets, mask=mask)\n if tl.program_id(0) == 0:\n output = x + y\n else:\n output = x * y\n tl.store(out_ptr + offsets, output, mask=mask)\n\n@triton.jit\ndef atomic_add_kernel(\n in_ptr0,\n in_ptr1,\n out_ptr,\n n_elements,\n BLOCK_SIZE: \"tl.constexpr\",\n):\n pid = tl.program_id(axis=0)\n block_start = pid * BLOCK_SIZE\n offsets = block_start + tl.arange(0, BLOCK_SIZE)\n mask = offsets < n_elements\n x = tl.load(in_ptr0 + offsets, mask=mask)\n y = tl.load(in_ptr1 + offsets, mask=mask)\n output = x + y\n tl.atomic_add(out_ptr + offsets, output, mask=mask)\n\n@triton.jit\ndef add_4_times_kernel(\n in_ptr0,\n in_ptr1,\n out_ptr,\n n_elements,\n BLOCK_SIZE: \"tl.constexpr\",\n):\n pid = tl.program_id(axis=0)\n block_start = pid * BLOCK_SIZE\n offsets = block_start + tl.arange(0, BLOCK_SIZE)\n mask = offsets < n_elements\n x = tl.load(in_ptr0 + offsets, mask=mask)\n y = tl.load(in_ptr1 + offsets, mask=mask)\n for i in range(2):\n output = x + y\n tl.store(out_ptr + offsets, output, mask=mask)\n i = 2\n while i > 0:\n i -= 1\n output = x + y\n tl.store(out_ptr + offsets, output, mask=mask)\n\n@triton.jit\ndef add_kernel_out_of_order_fn2(\n in_ptr0,\n in_ptr1,\n n_elements,\n out_ptr,\n BLOCK_SIZE: \"tl.constexpr\",\n):\n pid = tl.program_id(axis=0)\n block_start = pid * BLOCK_SIZE\n offsets = block_start + tl.arange(0, BLOCK_SIZE)\n mask = offsets < n_elements\n x = tl.load(in_ptr0 + offsets, mask=mask)\n y = tl.load(in_ptr1 + offsets, mask=mask)\n output = x + y\n tl.store(out_ptr + offsets, output, mask=mask)\n", - "description_1": "Use triton language to define various kernels. Each kernel has its own functionality and requires different parameters: add_kernel (5 params) performs element-wise addition; add_kernel_with_optional_param (6 params) adds elements with an optional parameter; add_kernel_autotuned (5 params) is an autotuned version for optimized performance; add_kernel_2d_autotuned (7 params) works on 2D data with autotuning; add_kernel_with_scaling (6 params) adds elements with scaling; mul2_kernel (4 params) doubles input elements; mul2_inplace_kernel (3 params) modifies input by doubling; zero_negs (1 param) zeroes out negative values; indirection_kernel (5 params) applies another kernel based on activation; double_strided_kernel (6 params) for strided data doubling; inline_asm_kernel (5 params) performs inline assembly operations; add_kernel_with_block_ptr (5 params) uses block pointers for addition; kernel_with_block_ptr_2d (4 params) uses block pointers for 2D operations; add_kernel_with_import (5 params) performs addition using imported load/store functions; cond_op_kernel (5 params) performs conditional operations; atomic_add_kernel (5 params) adds elements atomically; add_4_times_kernel (5 params) adds four times using loop; add_kernel_out_of_order_fn2 (5 params) performs out-of-order addition.", - "description_2": "Use triton language to implement addition and multiplication kernels with support for autotuning, block pointers, inline assembly, conditional execution, atomic operations, and operations on strided and 2D data.", - "difficulty": 4 - }, - { - "code": "import triton\nimport torch\n\n# Triton kernel to perform an operation (e.g., matrix multiplication)\n@triton.jit\ndef triton_kernel(A, B, C, M, N, K):\n pid = tl.program_id(axis=0)\n # Triton logic goes here\n # Example: Compute element-wise multiplication and store in C\n # Implement the required algorithm using Triton intrinsics\n\n# Function to call the Triton kernel\ndef call_triton_kernel(A, B, M, N, K):\n grid = lambda META: (M, )\n triton_kernel[grid](A, B, C, M, N, K)\n\n# Example function showing the use of the kernel\ndef example_usage():\n A = torch.randn(1024, 1024, device='cuda')\n B = torch.randn(1024, 1024, device='cuda')\n C = torch.empty((1024, 1024), device='cuda')\n M, N, K = A.size(0), B.size(1), A.size(1)\n call_triton_kernel(A, B, M, N, K)\n", - "description_1": "Use triton language to define a kernel `triton_kernel` that takes six parameters: A, B, C are tensors; M, N, K are dimensions. This kernel performs element-wise multiplication of A and B, storing the result in C. The function `call_triton_kernel` calls this kernel with specified grid dimensions, and `example_usage` demonstrates using random tensors with the kernel.", - "description_2": "Use triton language to implement and call a matrix multiplication kernel taking tensors A, B, C and dimensions M, N, K.", - "difficulty": 3 - }, - { - "code": "import math\nimport torch\nimport triton\nimport triton.language as tl\n\n\n@triton.heuristics(\n {\n \"EVEN_M\": lambda args: args[\"seqlen_q\"] % args[\"BLOCK_M\"] == 0,\n \"EVEN_N\": lambda args: args[\"seqlen_k\"] % args[\"BLOCK_N\"] == 0,\n \"EVEN_HEADDIM\": lambda args: args[\"headdim\"] == args[\"BLOCK_HEADDIM\"],\n }\n)\n@triton.jit\ndef _fwd_kernel(\n Q, K, V, Bias, Out,\n Lse, TMP, # NOTE: TMP is a scratchpad buffer to workaround a compiler bug\n softmax_scale,\n stride_qb, stride_qh, stride_qm,\n stride_kb, stride_kh, stride_kn,\n stride_vb, stride_vh, stride_vn,\n stride_bb, stride_bh, stride_bm,\n stride_ob, stride_oh, stride_om,\n nheads, seqlen_q, seqlen_k, seqlen_q_rounded, headdim,\n CACHE_KEY_SEQLEN_Q, CACHE_KEY_SEQLEN_K,\n BIAS_TYPE: tl.constexpr,\n IS_CAUSAL: tl.constexpr,\n BLOCK_HEADDIM: tl.constexpr,\n EVEN_M: tl.constexpr, EVEN_N: tl.constexpr, EVEN_HEADDIM: tl.constexpr,\n BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,\n):\n # Implementation details omitted for brevity\n pass\n\n\n@triton.jit\ndef _bwd_preprocess_do_o_dot(\n Out, DO, Delta,\n stride_ob, stride_oh, stride_om,\n stride_dob, stride_doh, stride_dom,\n nheads, seqlen_q, seqlen_q_rounded, headdim,\n BLOCK_M: tl.constexpr, BLOCK_HEADDIM: tl.constexpr,\n):\n # Implementation details omitted for brevity\n pass\n\n\n@triton.jit\ndef _bwd_kernel_one_col_block(\n start_n,\n Q, K, V, Bias,\n DO, DQ, DK, DV,\n LSE, D,\n softmax_scale,\n stride_qm, stride_kn, stride_vn, stride_bm,\n stride_dom, stride_dqm, stride_dkn, stride_dvn,\n seqlen_q, seqlen_k, headdim,\n ATOMIC_ADD: tl.constexpr,\n BIAS_TYPE: tl.constexpr,\n IS_CAUSAL: tl.constexpr,\n BLOCK_HEADDIM: tl.constexpr,\n EVEN_M: tl.constexpr, EVEN_N: tl.constexpr, EVEN_HEADDIM: tl.constexpr,\n BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,\n):\n # Implementation details omitted for brevity\n pass\n\n\ndef _flash_attn_forward(q, k, v, bias=None, causal=False, softmax_scale=None):\n batch, seqlen_q, nheads, d = q.shape\n _, seqlen_k, _, _ = k.shape\n assert d <= 128, 'FlashAttention only support head dimensions up to 128'\n softmax_scale = softmax_scale or 1.0 / math.sqrt(d)\n\n has_bias = bias is not None\n bias_type = 'none'\n if has_bias:\n assert bias.dim() == 4\n if bias.shape[2:] == (1, seqlen_k):\n bias_type = 'vector'\n elif bias.shape[2:] == (seqlen_q, seqlen_k):\n bias_type = 'matrix'\n else:\n raise RuntimeError('Last 2 dimensions of bias must be (1, seqlen_k) or (seqlen_q, seqlen_k)')\n bias = bias.expand(batch, nheads, seqlen_q, seqlen_k)\n\n seqlen_q_rounded = math.ceil(seqlen_q / 128) * 128\n lse = torch.empty((batch, nheads, seqlen_q_rounded), device=q.device, dtype=torch.float32)\n tmp = torch.empty((batch, nheads, seqlen_q_rounded), device=q.device, dtype=torch.float32)\n o = torch.empty_like(q)\n\n BLOCK_HEADDIM = max(triton.next_power_of_2(d), 16)\n BLOCK = 128\n num_warps = 4 if d <= 64 else 8\n grid = lambda META: (triton.cdiv(seqlen_q, META[\"BLOCK_M\"]), batch * nheads)\n _fwd_kernel[grid](\n q, k, v, bias, o,\n lse, tmp,\n softmax_scale,\n q.stride(0), q.stride(2), q.stride(1),\n k.stride(0), k.stride(2), k.stride(1),\n v.stride(0), v.stride(2), v.stride(1),\n bias.stride(0), bias.stride(1), bias.stride(2) if has_bias else (0, 0, 0),\n o.stride(0), o.stride(2), o.stride(1),\n nheads, seqlen_q, seqlen_k, seqlen_q_rounded, d,\n seqlen_q // 32, seqlen_k // 32,\n bias_type, causal, BLOCK_HEADDIM,\n BLOCK_M=BLOCK, BLOCK_N=BLOCK,\n num_warps=num_warps,\n num_stages=1,\n )\n return o, lse, softmax_scale\n\n\ndef _flash_attn_backward(do, q, k, v, o, lse, dq, dk, dv, bias=None, causal=False, softmax_scale=None):\n # Ensure contiguity\n if do.stride(-1) != 1:\n do = do.contiguous()\n batch, seqlen_q, nheads, d = q.shape\n _, seqlen_k, _, _ = k.shape\n seqlen_q_rounded = math.ceil(seqlen_q / 128) * 128\n softmax_scale = softmax_scale or 1.0 / math.sqrt(d)\n\n dq_accum = torch.empty_like(q, dtype=torch.float32)\n delta = torch.empty_like(lse)\n\n BLOCK_HEADDIM = max(triton.next_power_of_2(d), 16)\n grid = lambda META: (triton.cdiv(seqlen_q, META[\"BLOCK_M\"]), batch * nheads)\n _bwd_preprocess_do_o_dot[grid](\n o, do, delta,\n o.stride(0), o.stride(2), o.stride(1),\n do.stride(0), do.stride(2), do.stride(1),\n nheads, seqlen_q, seqlen_q_rounded, d,\n BLOCK_M=128, BLOCK_HEADDIM=BLOCK_HEADDIM,\n )\n\n has_bias = bias is not None\n bias_type = 'none'\n if has_bias:\n assert bias.dim() == 4\n if bias.shape[2:] == (1, seqlen_k):\n bias_type = 'vector'\n elif bias.shape[2:] == (seqlen_q, seqlen_k):\n bias_type = 'matrix'\n else:\n raise RuntimeError('Last 2 dimensions of bias must be (1, seqlen_k) or (seqlen_q, seqlen_k)')\n bias = bias.expand(batch, nheads, seqlen_q, seqlen_k)\n\n grid = lambda META: (triton.cdiv(seqlen_k, META[\"BLOCK_N\"]) if META[\"SEQUENCE_PARALLEL\"] else 1, batch * nheads)\n _bwd_kernel[grid](\n q, k, v, bias,\n do, dq_accum, dk, dv,\n lse, delta,\n softmax_scale,\n q.stride(0), q.stride(2), q.stride(1),\n k.stride(0), k.stride(2), k.stride(1),\n v.stride(0), v.stride(2), v.stride(1),\n bias.stride(0), bias.stride(1), bias.stride(2) if has_bias else (0, 0, 0),\n do.stride(0), do.stride(2), do.stride(1),\n dq_accum.stride(0), dq_accum.stride(2), dq_accum.stride(1),\n dk.stride(0), dk.stride(2), dk.stride(1),\n dv.stride(0), dv.stride(2), dv.stride(1),\n nheads, seqlen_q, seqlen_k, seqlen_q_rounded, d,\n seqlen_q // 32, seqlen_k // 32,\n bias_type, causal, BLOCK_HEADDIM,\n )\n dq.copy_(dq_accum)\n", - "description_1": "Use triton language to implement forward and backward kernels for FlashAttention. The forward kernel processes query, key, and value matrices (Q, K, V) with optional bias to compute attention outputs. The backward kernel computes gradients for Q, K, and V based on the output gradients. Both kernels support parameters like softmax scaling, causal masking, and multiple head dimensions up to 128.", - "description_2": "Use triton language to define kernels for FlashAttention, supporting gradient computation and optional bias handling.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n\n@triton.jit\ndef _fwd_kernel(\n Q, K, V, sm_scale,\n TMP, L, M,\n Out,\n stride_qz, stride_qh, stride_qm, stride_qk,\n stride_kz, stride_kh, stride_kn, stride_kk,\n stride_vz, stride_vh, stride_vk, stride_vn,\n stride_oz, stride_oh, stride_om, stride_on,\n Z, H, N_CTX,\n BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,\n BLOCK_N: tl.constexpr,\n):\n start_m = tl.program_id(0)\n off_hz = tl.program_id(1)\n offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)\n offs_n = tl.arange(0, BLOCK_N)\n offs_d = tl.arange(0, BLOCK_DMODEL)\n off_q = off_hz * stride_qh + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qk\n off_k = off_hz * stride_qh + offs_n[:, None] * stride_kn + offs_d[None, :] * stride_kk\n off_v = off_hz * stride_qh + offs_n[:, None] * stride_qm + offs_d[None, :] * stride_qk\n q_ptrs = Q + off_q\n k_ptrs = K + off_k\n v_ptrs = V + off_v\n t_ptrs = TMP + off_hz * N_CTX + offs_m\n m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float(\"inf\")\n l_i = tl.zeros([BLOCK_M], dtype=tl.float32)\n acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)\n q = tl.load(q_ptrs)\n for start_n in range(0, (start_m + 1) * BLOCK_M, BLOCK_N):\n start_n = tl.multiple_of(start_n, BLOCK_N)\n k = tl.load(k_ptrs + start_n * stride_kn)\n qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)\n qk += tl.dot(q, k, trans_b=True)\n qk *= sm_scale\n qk += tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), 0, float(\"-inf\"))\n m_ij = tl.max(qk, 1)\n p = tl.exp(qk - m_ij[:, None])\n l_ij = tl.sum(p, 1)\n m_i_new = tl.maximum(m_i, m_ij)\n alpha = tl.exp(m_i - m_i_new)\n beta = tl.exp(m_ij - m_i_new)\n l_i_new = alpha * l_i + beta * l_ij\n p_scale = beta / l_i_new\n p = p * p_scale[:, None]\n acc_scale = l_i / l_i_new * alpha\n tl.store(t_ptrs, acc_scale)\n acc_scale = tl.load(t_ptrs)\n acc = acc * acc_scale[:, None]\n v = tl.load(v_ptrs + start_n * stride_vk)\n p = p.to(v.dtype)\n acc += tl.dot(p, v)\n l_i = l_i_new\n m_i = m_i_new\n start_m = tl.program_id(0)\n offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)\n l_ptrs = L + off_hz * N_CTX + offs_m\n m_ptrs = M + off_hz * N_CTX + offs_m\n tl.store(l_ptrs, l_i)\n tl.store(m_ptrs, m_i)\n offs_n = tl.arange(0, BLOCK_DMODEL)\n off_o = off_hz * stride_oh + offs_m[:, None] * stride_om + offs_n[None, :] * stride_on\n out_ptrs = Out + off_o\n tl.store(out_ptrs, acc)\n\n\n@triton.jit\ndef _bwd_preprocess(\n Out, DO, L,\n NewDO, Delta,\n BLOCK_M: tl.constexpr, D_HEAD: tl.constexpr,\n):\n off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)\n off_n = tl.arange(0, D_HEAD)\n o = tl.load(Out + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32)\n do = tl.load(DO + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32)\n denom = tl.load(L + off_m).to(tl.float32)\n do = do / denom[:, None]\n delta = tl.sum(o * do, axis=1)\n tl.store(NewDO + off_m[:, None] * D_HEAD + off_n[None, :], do)\n tl.store(Delta + off_m, delta)\n\n\n@triton.jit\ndef _bwd_kernel(\n Q, K, V, sm_scale, Out, DO,\n DQ, DK, DV,\n L, M,\n D,\n stride_qz, stride_qh, stride_qm, stride_qk,\n stride_kz, stride_kh, stride_kn, stride_kk,\n stride_vz, stride_vh, stride_vk, stride_vn,\n Z, H, N_CTX,\n num_block,\n BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,\n BLOCK_N: tl.constexpr,\n):\n off_hz = tl.program_id(0)\n off_z = off_hz // H\n off_h = off_hz % H\n Q += off_z * stride_qz + off_h * stride_qh\n K += off_z * stride_qz + off_h * stride_qh\n V += off_z * stride_qz + off_h * stride_qh\n DO += off_z * stride_qz + off_h * stride_qh\n DQ += off_z * stride_qz + off_h * stride_qh\n DK += off_z * stride_qz + off_h * stride_qh\n DV += off_z * stride_qz + off_h * stride_qh\n for start_n in range(0, num_block):\n lo = start_n * BLOCK_M\n offs_qm = lo + tl.arange(0, BLOCK_M)\n offs_n = start_n * BLOCK_M + tl.arange(0, BLOCK_M)\n offs_m = tl.arange(0, BLOCK_N)\n offs_k = tl.arange(0, BLOCK_DMODEL)\n q_ptrs = Q + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk)\n k_ptrs = K + (offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk)\n v_ptrs = V + (offs_n[:, None] * stride_qm + offs_k[None, :] * stride_qk)\n do_ptrs = DO + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk)\n dq_ptrs = DQ + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk)\n D_ptrs = D + off_hz * N_CTX\n m_ptrs = M + off_hz * N_CTX\n dv = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)\n dk = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)\n k = tl.load(k_ptrs)\n v = tl.load(v_ptrs)\n for start_m in range(lo, num_block * BLOCK_M, BLOCK_M):\n offs_m_curr = start_m + offs_m\n q = tl.load(q_ptrs)\n qk = tl.dot(q, k, trans_b=True)\n qk = tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), qk, float(\"-inf\"))\n m = tl.load(m_ptrs + offs_m_curr)\n p = tl.exp(qk * sm_scale - m[:, None])\n do = tl.load(do_ptrs)\n dv += tl.dot(p.to(do.dtype), do, trans_a=True)\n Di = tl.load(D_ptrs + offs_m_curr)\n dp = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - Di[:, None]\n dp += tl.dot(do, v, trans_b=True)\n ds = p * dp * sm_scale\n dk += tl.dot(ds.to(q.dtype), q, trans_a=True)\n dq = tl.load(dq_ptrs, eviction_policy=\"evict_last\")\n dq += tl.dot(ds.to(k.dtype), k)\n tl.store(dq_ptrs, dq, eviction_policy=\"evict_last\")\n dq_ptrs += BLOCK_M * stride_qm\n q_ptrs += BLOCK_M * stride_qm\n do_ptrs += BLOCK_M * stride_qm\n dv_ptrs = DV + (offs_n[:, None] * stride_qm + offs_k[None, :] * stride_qk)\n dk_ptrs = DK + (offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk)\n tl.store(dv_ptrs, dv)\n tl.store(dk_ptrs, dk)\n\n\nclass _attention(torch.autograd.Function):\n\n @staticmethod\n def forward(ctx, q, k, v, sm_scale):\n BLOCK = 128\n Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]\n assert Lq == Lk and Lk == Lv\n assert Lk in {16, 32, 64, 128}\n o = torch.empty_like(q)\n grid = (triton.cdiv(q.shape[2], BLOCK), q.shape[0] * q.shape[1])\n tmp = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)\n L = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)\n m = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)\n num_warps = 4 if Lk <= 64 else 8\n\n _fwd_kernel[grid](\n q, k, v, sm_scale,\n tmp, L, m,\n o,\n q.stride(0), q.stride(1), q.stride(2), q.stride(3),\n k.stride(0), k.stride(1), k.stride(2), k.stride(3),\n v.stride(0), v.stride(1), v.stride(2), v.stride(3),\n o.stride(0), o.stride(1), o.stride(2), o.stride(3),\n q.shape[0], q.shape[1], q.shape[2],\n BLOCK_M=BLOCK, BLOCK_N=BLOCK,\n BLOCK_DMODEL=Lk, num_warps=num_warps,\n num_stages=1,\n )\n ctx.save_for_backward(q, k, v, o, L, m)\n ctx.BLOCK = BLOCK\n ctx.grid = grid\n ctx.sm_scale = sm_scale\n ctx.BLOCK_DMODEL = Lk\n return o\n\n @staticmethod\n def backward(ctx, do):\n q, k, v, o, l, m = ctx.saved_tensors\n do = do.contiguous()\n dq = torch.zeros_like(q, dtype=torch.float32)\n dk = torch.empty_like(k)\n dv = torch.empty_like(v)\n do_scaled = torch.empty_like(do)\n delta = torch.empty_like(l)\n _bwd_preprocess[(ctx.grid[0] * ctx.grid[1], )](\n o, do, l,\n do_scaled, delta,\n BLOCK_M=ctx.BLOCK, D_HEAD=ctx.BLOCK_DMODEL,\n )\n\n num_warps = 8\n _bwd_kernel[(ctx.grid[1],)](\n q, k, v, ctx.sm_scale,\n o, do_scaled,\n dq, dk, dv,\n l, m,\n delta,\n q.stride(0), q.stride(1), q.stride(2), q.stride(3),\n k.stride(0), k.stride(1), k.stride(2), k.stride(3),\n v.stride(0), v.stride(1), v.stride(2), v.stride(3),\n q.shape[0], q.shape[1], q.shape[2],\n ctx.grid[0],\n BLOCK_M=ctx.BLOCK, BLOCK_N=ctx.BLOCK,\n BLOCK_DMODEL=ctx.BLOCK_DMODEL, num_warps=num_warps,\n num_stages=1,\n )\n return dq.to(q.dtype), dk, dv, None\n\n\nattention = _attention.apply\n", - "description_1": "Use triton language to implement a fused attention operator with three kernels: `_fwd_kernel` for forward pass, `_bwd_preprocess` for preprocessing before backward pass, and `_bwd_kernel` for backward pass. `_fwd_kernel` requires 25 parameters to load Q, K, V tensors and calculate attention output with specific block sizes and strides. `_bwd_preprocess` takes 6 parameters to adjust gradient values before passing to `_bwd_kernel`, which needs 35 parameters to compute gradients with respect to input tensors based on block sizes and tensor strides.", - "description_2": "Use triton language to create a fused attention operation with forward and backward computation, specifying block sizes, tensor strides, and using multiple warps.", - "difficulty": 4 - }, - { - "code": "import triton\nimport triton.language as tl\n\n# Tanh kernel\n@triton.jit\ndef tanh(x):\n # Tanh is just a scaled sigmoid\n return 2 * tl.sigmoid(2 * x) - 1\n\n# Cosh kernel\n@triton.jit\ndef cosh(x):\n exp_x = tl.exp(x)\n return (exp_x + 1.0 / exp_x) * 0.5\n\n# ReLU kernel\n@triton.jit\ndef relu(x):\n \"\"\"\n ReLU_ activation function\n \"\"\"\n zero = 0.0\n return tl.where(x >= 0, x, zero.to(x.dtype))\n\n# ReLU gradient kernel\n@triton.jit\ndef relu_grad(x):\n # ReLU is different from other activations\n # in that it does not require the input to retrospectively compute its gradient\n # here the input is the downstream gradient, and we return the upstream gradient directly\n zero = 0.0\n one = 1.0\n return tl.where(x >= 0, one.to(x.dtype), zero.to(x.dtype))\n\n# Squared ReLU kernel\n@triton.jit\ndef squared_relu(x):\n \"\"\"\n Squared ReLU activation, as proposed in the Primer_ paper.\n \"\"\"\n x_ = relu(x)\n return (x_ * x_).to(x.dtype)\n\n# Squared ReLU gradient kernel\n@triton.jit\ndef squared_relu_grad(x):\n return tl.where(x >= 0, 2.0 * x, 0.0)\n\n# Leaky ReLU kernel\n@triton.jit\ndef leaky_relu(x):\n \"\"\"\n LeakyReLU_ activation\n \"\"\"\n scale = 0.01 + 0.0\n scale = scale.to(x.dtype)\n return tl.where(x >= 0, x, scale * x)\n\n# Leaky ReLU gradient kernel\n@triton.jit\ndef leaky_relu_grad(x):\n min_grad = 0.01\n max_grad = 1\n\n min_grad = min_grad.to(x.dtype)\n max_grad = max_grad.to(x.dtype)\n\n return tl.where(x >= 0, max_grad, min_grad)\n\n# GELU kernel\n@triton.jit\ndef gelu(x):\n \"\"\"Gaussian Error Linear Unit (GELU)\"\"\"\n return x * 0.5 * (1.0 + tl.libdevice.erf(x * _sqrt1_2))\n\n# GELU gradient kernel\n@triton.jit\ndef gelu_grad(x):\n cdf = 0.5 * (1.0 + tl.libdevice.erf(x * _sqrt1_2))\n pdf = tl.exp(-0.5 * x * x) * _gaussian_pdf_normalization\n return cdf + x * pdf\n\n# GELU Approx kernel\n@triton.jit\ndef gelu_approx(x):\n \"\"\"\n GeLU_ activation - Gaussian error linear unit, with tanh approximation\n \"\"\"\n return 0.5 * x * (1.0 + tanh(_sqrt2pi * x * (1.0 + 0.044715 * x * x)))\n\n# GELU Approx gradient kernel\n@triton.jit\ndef gelu_approx_grad(x):\n # CREDITS: Fast implementation proposed in\n # https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/fused_bias_gelu.py#L30\n tanh_out = tanh(0.79788456 * x * (1 + 0.044715 * x * x))\n return 0.5 * x * (\n (1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)\n ) + 0.5 * (1 + tanh_out)\n", - "description_1": "Use triton language to implement various activation functions (ReLU, LeakyReLU, GELU, SquaredReLU, GELU Approx, and their gradients) and related functions such as tanh and cosh for optimized computation on GPUs. Each kernel handles element-wise computation of activations and gradients using Triton's language for high performance.", - "description_2": "Use triton language to implement ReLU, LeakyReLU, GELU, SquaredReLU, GELU Approx, their gradients, and related functions such as tanh and cosh for GPU acceleration.", - "difficulty": 2 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\nfrom flash_attn.ops.triton.k_activations import gelu, gelu_approx, squared_relu\nfrom flash_attn.ops.triton.k_activations import gelu_grad, gelu_approx_grad, squared_relu_grad\nfrom triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time\n\n\ndef get_configs_io_bound():\n configs = []\n for num_stages in [2, 3, 4, 5, 6]:\n for block_m in [16, 32]:\n for block_k in [32, 64]:\n for block_n in [32, 64, 128, 256]:\n num_warps = 2 if block_n <= 64 else 4\n configs.append(\n triton.Config(\n {\"BLOCK_M\": block_m, \"BLOCK_N\": block_n, \"BLOCK_K\": block_k, \"SPLIT_K\": 1},\n num_stages=num_stages,\n num_warps=num_warps,\n )\n )\n return configs\n\n\n@triton.autotune(\n configs=[\n triton.Config({\"BLOCK_M\": 128, \"BLOCK_N\": 256, \"BLOCK_K\": 32, \"SPLIT_K\": 1}, num_stages=3, num_warps=8),\n triton.Config({\"BLOCK_M\": 256, \"BLOCK_N\": 128, \"BLOCK_K\": 32, \"SPLIT_K\": 1}, num_stages=3, num_warps=8),\n triton.Config({\"BLOCK_M\": 256, \"BLOCK_N\": 64, \"BLOCK_K\": 32, \"SPLIT_K\": 1}, num_stages=4, num_warps=4),\n triton.Config({\"BLOCK_M\": 64, \"BLOCK_N\": 256, \"BLOCK_K\": 32, \"SPLIT_K\": 1}, num_stages=4, num_warps=4),\n triton.Config({\"BLOCK_M\": 128, \"BLOCK_N\": 128, \"BLOCK_K\": 32, \"SPLIT_K\": 1}, num_stages=4, num_warps=4),\n triton.Config({\"BLOCK_M\": 128, \"BLOCK_N\": 64, \"BLOCK_K\": 32, \"SPLIT_K\": 1}, num_stages=4, num_warps=4),\n triton.Config({\"BLOCK_M\": 64, \"BLOCK_N\": 128, \"BLOCK_K\": 32, \"SPLIT_K\": 1}, num_stages=4, num_warps=4),\n triton.Config({\"BLOCK_M\": 128, \"BLOCK_N\": 32, \"BLOCK_K\": 32, \"SPLIT_K\": 1}, num_stages=4, num_warps=4),\n triton.Config({\"BLOCK_M\": 64, \"BLOCK_N\": 32, \"BLOCK_K\": 32, \"SPLIT_K\": 1}, num_stages=5, num_warps=2),\n triton.Config({\"BLOCK_M\": 128, \"BLOCK_N\": 256, \"BLOCK_K\": 128, \"SPLIT_K\": 1}, num_stages=3, num_warps=8),\n triton.Config({\"BLOCK_M\": 256, \"BLOCK_N\": 128, \"BLOCK_K\": 128, \"SPLIT_K\": 1}, num_stages=3, num_warps=8),\n triton.Config({\"BLOCK_M\": 256, \"BLOCK_N\": 64, \"BLOCK_K\": 128, \"SPLIT_K\": 1}, num_stages=4, num_warps=4),\n triton.Config({\"BLOCK_M\": 64, \"BLOCK_N\": 256, \"BLOCK_K\": 128, \"SPLIT_K\": 1}, num_stages=4, num_warps=4),\n triton.Config({\"BLOCK_M\": 128, \"BLOCK_N\": 128, \"BLOCK_K\": 128, \"SPLIT_K\": 1}, num_stages=4, num_warps=4),\n triton.Config({\"BLOCK_M\": 128, \"BLOCK_N\": 64, \"BLOCK_K\": 64, \"SPLIT_K\": 1}, num_stages=4, num_warps=4),\n triton.Config({\"BLOCK_M\": 64, \"BLOCK_N\": 128, \"BLOCK_K\": 64, \"SPLIT_K\": 1}, num_stages=4, num_warps=4),\n triton.Config({\"BLOCK_M\": 128, \"BLOCK_N\": 32, \"BLOCK_K\": 64, \"SPLIT_K\": 1}, num_stages=4, num_warps=4),\n triton.Config({\"BLOCK_M\": 64, \"BLOCK_N\": 32, \"BLOCK_K\": 64, \"SPLIT_K\": 1}, num_stages=5, num_warps=2),\n ]\n + get_configs_io_bound(),\n key=[\"CACHE_KEY_M\", \"CACHE_KEY_N\", \"CACHE_KEY_K\"],\n prune_configs_by={\"early_config_prune\": early_config_prune, \"perf_model\": estimate_matmul_time, \"top_k\": 10},\n)\n@triton.heuristics(\n {\n \"EVEN_K\": lambda args: args[\"K\"] % (args[\"BLOCK_K\"] * args[\"SPLIT_K\"]) == 0,\n }\n)\n@triton.jit\ndef kernel_fwd(\n C, # Pointers to matrices\n ACT_INPUT,\n A,\n B,\n bias,\n # Matrix dimensions\n M,\n N,\n K,\n CACHE_KEY_M,\n CACHE_KEY_N,\n CACHE_KEY_K,\n # The stride variables represent how much to increase the ptr by when moving by 1\n # element in a particular dimension. E.g. stride_am is how much to increase a_ptr\n # by to get the element one row down (A has M rows)\n stride_cm,\n # stride_cn, # Assume that stride_cn == 1\n stride_am,\n stride_ak,\n stride_bn,\n stride_bk,\n # Meta-parameters\n BLOCK_M: tl.constexpr,\n GROUP_M: tl.constexpr,\n BLOCK_N: tl.constexpr,\n BLOCK_K: tl.constexpr,\n SPLIT_K: tl.constexpr,\n EVEN_K: tl.constexpr,\n A_ROWMAJOR: tl.constexpr,\n B_COLMAJOR: tl.constexpr,\n BIAS: tl.constexpr,\n SAVE_ACT_INPUT: tl.constexpr,\n ACTIVATION: tl.constexpr,\n):\n\n \"\"\"\n Kernel for computing Out = activation(A x W + C)\n - Input has shape (M, K)\n - Weight has shape (K, N)\n - Bias has shape (N,)\n - Output has shape (M, N)\n - ActInputs (optional) has shape (M, N)\n 'ActInputs' optionally saves the A x W + C intermediate for backward computations\n This kernel will consolidate over K\n \"\"\"\n\n pid = tl.program_id(axis=0)\n\n grid_m = (M + BLOCK_M - 1) // BLOCK_M\n grid_n = (N + BLOCK_N - 1) // BLOCK_N\n # re-order program ID for better L2 performance\n width = GROUP_M * grid_n\n group_id = pid // width\n group_size = min(grid_m - group_id * GROUP_M, GROUP_M)\n pid_m = group_id * GROUP_M + (pid % group_size)\n pid_n = (pid % width) // (group_size)\n\n # now compute the block that each program will go through\n rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)\n rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)\n ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)\n rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)\n rk = tl.arange(0, BLOCK_K)\n\n if A_ROWMAJOR:\n A = A + (ram[:, None] * stride_am + rk[None, :])\n else:\n A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)\n if B_COLMAJOR:\n B = B + (rk[:, None] + rbn[None, :] * stride_bn)\n else:\n B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)\n\n acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)\n\n for k in range(K, 0, -BLOCK_K):\n if EVEN_K:\n a = tl.load(A)\n b = tl.load(B)\n else:\n a = tl.load(A, mask=rk[None, :] < k, other=0.0)\n b = tl.load(B, mask=rk[:, None] < k, other=0.0)\n acc += tl.dot(a, b)\n\n if A_ROWMAJOR:\n A += BLOCK_K\n else:\n A += BLOCK_K * stride_ak\n if B_COLMAJOR:\n B += BLOCK_K\n else:\n B += BLOCK_K * stride_bk\n\n if BIAS:\n bias = tl.load(bias + rn, mask=rn < N, other=0.0).to(tl.float32)\n acc += bias[None, :]\n\n if SAVE_ACT_INPUT:\n act_in_ptrs = ACT_INPUT + ram[:, None] * stride_cm + rbn[None, :]\n tl.store(act_in_ptrs, acc)\n\n if ACTIVATION == \"gelu\":\n acc = gelu(acc)\n elif ACTIVATION == \"gelu_approx\":\n acc = gelu_approx(acc)\n elif ACTIVATION == \"squared_relu\":\n acc = squared_relu(acc)\n\n rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)\n rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)\n\n C = C + rm[:, None] * stride_cm + rn[None, :]\n mask = (rm < M)[:, None] & (rn < N)[None, :]\n tl.store(C, acc)\n\n\ndef triton_linear_act(\n x: torch.Tensor,\n weight: torch.Tensor,\n bias: Optional[torch.Tensor] = None,\n activation: str = 'id',\n save_act_input: bool = False,\n) -> torch.Tensor:\n \"\"\"\n Compute e = activation(x @ weight.T + bias).\n This wrapper kicks the `kernel_fwd` Triton kernel\n :param x: input tensor\n :param weight: weight matrix\n :param bias: an optional bias tensor\n :param activation: Activation name. Needs to be a Triton kernel.\n :param act_input: an optional tensor to save the activation inputs (for backward)\n :return: result tensor\n \"\"\"\n\n assert activation in ['id', 'gelu', 'gelu_approx', 'squared_relu']\n\n batch_shape, n = x.shape[:-1], x.shape[-1]\n batch_dim = batch_shape.numel()\n x_reshaped = x.reshape(batch_dim, n)\n\n if x_reshaped.stride(0) > 1 and x_reshaped.stride(1) > 1:\n x_reshaped = x_reshaped.contiguous()\n if weight.stride(0) > 1 and weight.stride(1) > 1:\n weight = weight.contiguous()\n bias = bias.contiguous() if bias is not None else None\n\n assert x.dtype == weight.dtype, f\"Input and weight must have the same dtype, got {x.dtype} and {weight.dtype}\"\n if bias is not None:\n assert x.dtype == bias.dtype, f\"Input and bias must have the same dtype, got {x.dtype} and {bias.dtype}\"\n assert x_reshaped.shape[1] == weight.shape[1], f\"Incompatible dimensions: {x_reshaped.shape} - {weight.shape}\"\n\n assert bias is None or bias.shape[0] == weight.shape[0], \"Incompatible dimensions in between weight and bias\"\n\n M, K = x_reshaped.shape\n N, K = weight.shape\n\n output = torch.empty((M, N), device=x.device, dtype=x.dtype)\n act_input = torch.empty_like(output) if save_act_input else None\n\n grid = lambda META: (triton.cdiv(M, META[\"BLOCK_M\"]) * triton.cdiv(N, META[\"BLOCK_N\"]),) # noqa\n\n kernel_fwd[grid](\n output,\n act_input,\n x_reshaped,\n weight, # data ptrs\n bias if bias is not None else x, # auto skip bias if not present\n M, # shapes\n N,\n K,\n M // 32,\n N // 32,\n K // 32,\n stride_cm=output.stride(0),\n stride_am=x_reshaped.stride(0),\n stride_ak=x_reshaped.stride(1),\n stride_bk=weight.stride(1),\n stride_bn=weight.stride(0),\n BIAS=bias is not None,\n SAVE_ACT_INPUT=save_act_input,\n ACTIVATION=activation,\n A_ROWMAJOR=x_reshaped.stride(1) == 1,\n B_COLMAJOR=weight.stride(1) == 1,\n GROUP_M=8,\n )\n\n if not save_act_input:\n return output.reshape(*batch_shape, output.shape[-1])\n else:\n return (output.reshape(*batch_shape, output.shape[-1]),\n act_input.reshape(*batch_shape, act_input.shape[-1]))\n\n\n@triton.autotune(\n configs=[\n triton.Config({\"BLOCK_M\": 128, \"BLOCK_N\": 256, \"BLOCK_K\": 32, \"SPLIT_K\": 1}, num_stages=3, num_warps=8),\n triton.Config({\"BLOCK_M\": 256, \"BLOCK_N\": 128, \"BLOCK_K\": 32, \"SPLIT_K\": 1}, num_stages=3, num_warps=8),\n triton.Config({\"BLOCK_M\": 256, \"BLOCK_N\": 64, \"BLOCK_K\": 32, \"SPLIT_K\": 1}, num_stages=4, num_warps=4),\n triton.Config({\"BLOCK_M\": 64, \"BLOCK_N\": 256, \"BLOCK_K\": 32, \"SPLIT_K\": 1}, num_stages=4, num_warps=4),\n triton.Config({\"BLOCK_M\": 128, \"BLOCK_N\": 128, \"BLOCK_K\": 32, \"SPLIT_K\": 1}, num_stages=4, num_warps=4),\n triton.Config({\"BLOCK_M\": 128, \"BLOCK_N\": 64, \"BLOCK_K\": 32, \"SPLIT_K\": 1}, num_stages=4, num_warps=4),\n triton.Config({\"BLOCK_M\": 64, \"BLOCK_N\": 128, \"BLOCK_K\": 32, \"SPLIT_K\": 1}, num_stages=4, num_warps=4),\n triton.Config({\"BLOCK_M\": 128, \"BLOCK_N\": 32, \"BLOCK_K\": 32, \"SPLIT_K\": 1}, num_stages=4, num_warps=4),\n triton.Config({\"BLOCK_M\": 64, \"BLOCK_N\": 32, \"BLOCK_K\": 32, \"SPLIT_K\": 1}, num_stages=5, num_warps=2),\n triton.Config({\"BLOCK_M\": 128, \"BLOCK_N\": 256, \"BLOCK_K\": 128, \"SPLIT_K\": 1}, num_stages=3, num_warps=8),\n triton.Config({\"BLOCK_M\": 256, \"BLOCK_N\": 128, \"BLOCK_K\": 128, \"SPLIT_K\": 1}, num_stages=3, num_warps=8),\n triton.Config({\"BLOCK_M\": 256, \"BLOCK_N\": 64, \"BLOCK_K\": 128, \"SPLIT_K\": 1}, num_stages=4, num_warps=4),\n triton.Config({\"BLOCK_M\": 64, \"BLOCK_N\": 256, \"BLOCK_K\": 128, \"SPLIT_K\": 1}, num_stages=4, num_warps=4),\n triton.Config({\"BLOCK_M\": 128, \"BLOCK_N\": 128, \"BLOCK_K\": 128, \"SPLIT_K\": 1}, num_stages=4, num_warps=4),\n triton.Config({\"BLOCK_M\": 128, \"BLOCK_N\": 64, \"BLOCK_K\": 64, \"SPLIT_K\": 1}, num_stages=4, num_warps=4),\n triton.Config({\"BLOCK_M\": 64, \"BLOCK_N\": 128, \"BLOCK_K\": 64, \"SPLIT_K\": 1}, num_stages=4, num_warps=4),\n triton.Config({\"BLOCK_M\": 128, \"BLOCK_N\": 32, \"BLOCK_K\": 64, \"SPLIT_K\": 1}, num_stages=4, num_warps=4),\n triton.Config({\"BLOCK_M\": 64, \"BLOCK_N\": 32, \"BLOCK_K\": 64, \"SPLIT_K\": 1}, num_stages=5, num_warps=2),\n ]\n + get_configs_io_bound(),\n key=[\"CACHE_KEY_M\", \"CACHE_KEY_N\", \"CACHE_KEY_K\"],\n prune_configs_by={\"early_config_prune\": early_config_prune, \"perf_model\": estimate_matmul_time, \"top_k\": 10},\n)\n@triton.heuristics(\n {\n \"EVEN_K\": lambda args: args[\"K\"] % (args[\"BLOCK_K\"] * args[\"SPLIT_K\"]) == 0,\n }\n)\n@triton.jit\ndef kernel_bwd(\n C, # Pointers to matrices\n ACT_INPUT,\n A,\n B,\n # Matrix dimensions\n M,\n N,\n K,\n CACHE_KEY_M,\n CACHE_KEY_N,\n CACHE_KEY_K,\n # The stride variables represent how much to increase the ptr by when moving by 1\n # element in a particular dimension. E.g. stride_am is how much to increase a_ptr\n # by to get the element one row down (A has M rows)\n stride_cm,\n # stride_cn, # Assume that stride_cn == 1\n stride_am,\n stride_ak,\n stride_bk,\n stride_bn,\n # Meta-parameters\n BLOCK_M: tl.constexpr,\n GROUP_M: tl.constexpr,\n BLOCK_N: tl.constexpr,\n BLOCK_K: tl.constexpr,\n SPLIT_K: tl.constexpr,\n EVEN_K: tl.constexpr,\n ACTIVATION: tl.constexpr,\n):\n\n \"\"\"\n Kernel for computing Out = activation(A x W + C)\n - Input has shape (M, K)\n - Weight has shape (K, N)\n - Output has shape (M, N)\n - ActInputs (optional) has shape (M, N)\n 'ActInputs' optionally saves the A x W + C intermediate for backward computations\n This kernel will consolidate over K\n \"\"\"\n\n pid = tl.program_id(axis=0)\n\n grid_m = (M + BLOCK_M - 1) // BLOCK_M\n grid_n = (N + BLOCK_N - 1) // BLOCK_N\n # re-order program ID for better L2 performance\n width = GROUP_M * grid_n\n group_id = pid // width\n group_size = min(grid_m - group_id * GROUP_M, GROUP_M)\n pid_m = group_id * GROUP_M + (pid % group_size)\n pid_n = (pid % width) // (group_size)\n\n rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)\n rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)\n ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)\n rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)\n rk = tl.arange(0, BLOCK_K)\n\n A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)\n B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)\n\n acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)\n\n for k in range(K, 0, -BLOCK_K):\n if EVEN_K:\n a = tl.load(A)\n b = tl.load(B)\n else:\n a = tl.load(A, mask=rk[None, :] < k, other=0.0)\n b = tl.load(B, mask=rk[:, None] < k, other=0.0)\n acc += tl.dot(a, b)\n\n A += BLOCK_K * stride_ak\n B += BLOCK_K * stride_bk\n\n if ACTIVATION != 'id':\n act_in_ptrs = ACT_INPUT + ram[:, None] * stride_cm + rbn[None, :]\n act_input = tl.load(act_in_ptrs).to(acc.dtype)\n if ACTIVATION == \"gelu\":\n acc *= gelu_grad(act_input)\n elif ACTIVATION == \"gelu_approx\":\n acc *= gelu_approx_grad(act_input)\n elif ACTIVATION == \"squared_relu\":\n acc *= squared_relu_grad(act_input)\n\n rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)\n rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)\n\n C = C + rm[:, None] * stride_cm + rn[None, :]\n mask = (rm < M)[:, None] & (rn < N)[None, :]\n tl.store(C, acc, mask=mask)\n\n\ndef triton_dgrad_act(\n grad_output: torch.Tensor,\n weight: torch.Tensor,\n activation: str = 'id',\n act_input: Optional[torch.Tensor] = None,\n) -> torch.Tensor:\n \"\"\"\n Compute e = activation(grad_output @ weight + bias).\n This wrapper kicks the `kernel_bwd` Triton kernel\n :param grad_output: input tensor\n :param weight: weight matrix\n :param activation: Activation name. Needs to be a Triton kernel.\n :param act_input: an optional tensor to save the activation inputs (for backward)\n :return: result tensor\n \"\"\"\n assert activation in ['id', 'gelu', 'gelu_approx', 'squared_relu']\n\n batch_shape, n = grad_output.shape[:-1], grad_output.shape[-1]\n batch_dim = batch_shape.numel()\n grad_output_reshaped = grad_output.reshape(batch_dim, n)\n\n if grad_output_reshaped.stride(0) > 1 and grad_output_reshaped.stride(1) > 1:\n grad_output_reshaped = grad_output_reshaped.contiguous()\n if weight.stride(0) > 1 and weight.stride(1) > 1:\n weight = weight.contiguous()\n\n assert grad_output.dtype == weight.dtype, f\"grad_output and weight must have the same dtype, got {grad_output.dtype} and {weight.dtype}\"\n assert grad_output_reshaped.shape[1] == weight.shape[0], f\"Incompatible dimensions: {grad_output_reshaped.shape} - {weight.shape}\"\n if activation != 'id':\n assert act_input is not None, f'act_input is required for activation {activation}'\n\n M, K = grad_output_reshaped.shape\n K, N = weight.shape\n\n grad_input = torch.empty((M, N), device=grad_output.device, dtype=grad_output.dtype)\n\n grid = lambda META: (triton.cdiv(M, META[\"BLOCK_M\"]) * triton.cdiv(N, META[\"BLOCK_N\"]),) # noqa\n\n kernel_bwd[grid](\n grad_input,\n act_input,\n grad_output_reshaped,\n weight, # data ptrs\n M, # shapes\n N,\n K,\n M // 32,\n N // 32,\n K // 32,\n stride_cm=grad_input.stride(0),\n stride_am=grad_output_reshaped.stride(0),\n stride_ak=grad_output_reshaped.stride(1),\n stride_bk=weight.stride(0),\n stride_bn=weight.stride(1),\n ACTIVATION=activation,\n GROUP_M=8,\n )\n\n return grad_input.reshape(*batch_shape, grad_input.shape[-1])\n", - "description_1": "Use triton language to implement a fused linear layer with activation functions like gelu, gelu_approx, or squared_relu. The forward kernel 'kernel_fwd' computes the output of a matrix multiplication with optional bias addition and activation, and the backward kernel 'kernel_bwd' computes gradients using the provided activation function. The Triton kernels handle matrix dimension parameters, block sizes, and performance tuning configurations.", - "description_2": "Use triton language to create a matrix multiplication kernel with optional activation functions and bias, along with a gradient computation kernel for the same operations, leveraging Triton's optimization features.", - "difficulty": 4 - }, - { - "code": "import triton\nimport triton.language as tl\n\n@triton.jit\ndef sigmoid(input):\n return (1 / (1 + tl.exp(-input)))\n\n@triton.jit\ndef sigmoid_grad(input):\n output_sigmoid = sigmoid(input)\n return output_sigmoid * (1 - output_sigmoid)\n\n@triton.jit\ndef tanh(input):\n return 2 * sigmoid(2 * input) - 1\n\n@triton.jit\ndef tanh_grad(input):\n output_tanh = tanh(input)\n return 1 - output_tanh * output_tanh\n\n@triton.jit\ndef relu(input):\n return tl.maximum(0, input)\n\n@triton.jit\ndef relu_grad(input):\n return tl.where(input <= 0, 0, 1)\n\n@triton.jit\ndef gelu(input):\n cdf = 0.5 * (1 + tl.math.erf(0.707106781 * input))\n return cdf * input\n\n@triton.jit\ndef gelu_grad(input):\n cdf = 0.5 * (1 + tl.math.erf(0.707106781 * input))\n cdf_grad = 0.39894228 * tl.exp(-0.5 * input * input)\n return (cdf_grad * input + cdf)\n\n@triton.jit\ndef silu(input):\n return (input * sigmoid(input))\n\n@triton.jit\ndef silu_grad(input):\n output_sigmoid = sigmoid(input)\n return (output_sigmoid * (input * (1 - output_sigmoid) + 1))\n\n@triton.jit\ndef relu6(input):\n return tl.minimum(relu(input), 6)\n\n@triton.jit\ndef relu6_grad(input):\n return tl.where((0 < input) & (input < 6), 1, 0)\n\n@triton.jit\ndef hardsigmoid(input):\n return tl.maximum(0, tl.minimum(1, input / 6 + 0.5))\n\n@triton.jit\ndef hardsigmoid_grad(input):\n return tl.where((-3 < input) & (input < 3), 1 / 6, 0)\n\n@triton.jit\ndef hardswish(input):\n return input * relu6(input + 3) / 6\n\n@triton.jit\ndef hardswish_grad(input):\n return (relu6(input + 3) + input * relu6_grad(input + 3)) / 6\n\n@triton.jit\ndef selu(input):\n scale = 1.0507009873554804934193349852946\n alpha = 1.6732632423543772848170429916717\n return scale * (tl.maximum(0, input) +\n tl.minimum(0, alpha * (tl.exp(input) - 1)))\n\n@triton.jit\ndef selu_grad(input):\n scale = 1.0507009873554804934193349852946\n alpha = 1.6732632423543772848170429916717\n return scale * tl.where(input <= 0, alpha * tl.exp(input), 1)\n\n@triton.jit\ndef mish(input):\n return input * tanh(tl.log(1 + tl.exp(input)))\n\n@triton.jit\ndef mish_grad(input):\n exp = tl.exp(input)\n delta = exp * (exp + 2) + 2\n return (exp * (exp * ((4 * input + 6) + exp * (exp + 4)) + 4 * (input + 1)) /\n (delta * delta))\n\n@triton.jit\ndef leaky_relu(input, negative_slope):\n return relu(input) + negative_slope * tl.minimum(0, input)\n\n@triton.jit\ndef leaky_relu_grad(input, negative_slope):\n return tl.where(input <= 0, negative_slope, 1)\n\n@triton.jit\ndef apply_act_func(input, drop_p, seed, offset, param,\n act_func: tl.constexpr, dropout: tl.constexpr):\n if act_func == 'sigmoid':\n input = input.to(tl.float32)\n output = sigmoid(input)\n elif act_func == 'tanh':\n input = input.to(tl.float32)\n output = tanh(input)\n elif act_func == 'relu':\n output = relu(input)\n elif act_func == 'gelu':\n input = input.to(tl.float32)\n output = gelu(input)\n elif act_func == 'silu':\n input = input.to(tl.float32)\n output = silu(input)\n elif act_func == 'relu6':\n output = relu6(input)\n elif act_func == 'hardsigmoid':\n output = hardsigmoid(input)\n elif act_func == 'hardswish':\n output = hardswish(input)\n elif act_func == 'selu':\n input = input.to(tl.float32)\n output = selu(input)\n elif act_func == 'mish':\n input = input.to(tl.float32)\n output = mish(input)\n elif act_func == 'leaky_relu':\n output = leaky_relu(input, param)\n if dropout:\n output = apply_dropout(output, drop_p, seed, offset)\n return output\n\n@triton.jit\ndef apply_act_func_grad(output_grad, input, drop_p, seed, offset, param,\n act_func: tl.constexpr, dropout: tl.constexpr):\n if act_func == 'sigmoid':\n input = input.to(tl.float32)\n output = sigmoid_grad(input)\n elif act_func == 'tanh':\n input = input.to(tl.float32)\n output = tanh_grad(input)\n elif act_func == 'relu':\n output = relu_grad(input)\n elif act_func == 'gelu':\n input = input.to(tl.float32)\n output = gelu_grad(input)\n elif act_func == 'silu':\n input = input.to(tl.float32)\n output = silu_grad(input)\n elif act_func == 'relu6':\n output = relu6_grad(input)\n elif act_func == 'hardsigmoid':\n output = hardsigmoid_grad(input)\n elif act_func == 'hardswish':\n output = hardswish_grad(input)\n elif act_func == 'selu':\n input = input.to(tl.float32)\n output = selu_grad(input)\n elif act_func == 'mish':\n input = input.to(tl.float32)\n output = mish_grad(input)\n elif act_func == 'leaky_relu':\n output = leaky_relu_grad(input, param)\n if dropout:\n output_grad = apply_dropout_grad(output_grad, drop_p, seed, offset)\n return output_grad * output\n\n@triton.autotune(\n configs=element_wise_kernel_configs(),\n key=['size'],\n)\n@triton.jit\ndef act_func_forward_kernel(\n input_pointer, output_pointer, size,\n drop_p, seed, param,\n act_func: tl.constexpr, dropout: tl.constexpr,\n BLOCK_SIZE: tl.constexpr,\n ):\n pid = tl.program_id(axis=0)\n offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)\n mask = offset < size\n input = tl.load(input_pointer + offset, mask=mask)\n tl.store(output_pointer + offset,\n apply_act_func(input, drop_p, seed, offset,\n param, act_func, dropout),\n mask=mask)\n\n@triton.autotune(\n configs=element_wise_kernel_configs(),\n key=['size'],\n)\n@triton.jit\ndef act_func_backward_kernel(\n output_grad_pointer, input_pointer, input_grad_pointer, size,\n drop_p, seed, param,\n act_func: tl.constexpr, dropout: tl.constexpr,\n BLOCK_SIZE: tl.constexpr,\n ):\n pid = tl.program_id(axis=0)\n offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)\n mask = offset < size\n output_grad = tl.load(output_grad_pointer + offset, mask=mask)\n input = tl.load(input_pointer + offset, mask=mask)\n tl.store(input_grad_pointer + offset,\n apply_act_func_grad(output_grad, input, drop_p, seed,\n offset, param, act_func, dropout),\n mask=mask)\n", - "description_1": "Use triton language to define a series of activation functions and their gradients (sigmoid, tanh, relu, gelu, silu, relu6, hardsigmoid, hardswish, selu, mish, leaky_relu) and implement a kernel for applying these activation functions and optionally fused dropout, as well as a kernel for computing gradients of these activations.", - "description_2": "Use triton language to create activation functions with dropout support and compute gradients.", - "difficulty": 3 - }, - { - "code": "import triton\nimport triton.language as tl\nfrom triton import next_power_of_2\n\nfrom .act_kernels import apply_act_func\nfrom .utils import warps_kernel_configs\n\n\ndef BLOCK_SIZE_SPATIAL_heuristic(args):\n BLOCK_SIZE_BATCH = next_power_of_2(args['batch_dim'])\n BLOCK_SIZE_SPATIAL = next_power_of_2(args['spatial_dim'])\n return min(BLOCK_SIZE_SPATIAL, max(1, 2 ** 14 // BLOCK_SIZE_BATCH))\n\n\n@triton.autotune(\n configs=warps_kernel_configs(),\n key=['batch_dim', 'spatial_dim'],\n restore_value=['running_mean_pointer', 'running_var_pointer']\n)\n@triton.heuristics({'BLOCK_SIZE_BATCH': lambda args: next_power_of_2(args['batch_dim']),\n 'BLOCK_SIZE_SPATIAL': BLOCK_SIZE_SPATIAL_heuristic})\n@triton.jit\ndef batch_norm_forward_kernel(\n input_pointer, weight_pointer, bias_pointer,\n mean_pointer, inv_std_pointer,\n pre_act_add_pointer, pre_act_pointer, output_pointer,\n running_mean_pointer, running_var_pointer,\n batch_dim, spatial_dim,\n input_batch_stride, input_feat_stride, input_spatial_stride,\n pre_act_add_batch_stride, pre_act_add_feat_stride, pre_act_add_spatial_stride,\n pre_act_batch_stride, pre_act_feat_stride, pre_act_spatial_stride,\n output_batch_stride, output_feat_stride, output_spatial_stride,\n momentum, eps, param,\n affine: tl.constexpr, save_stats: tl.constexpr,\n track_running_stats: tl.constexpr, is_train: tl.constexpr,\n add_pre_act: tl.constexpr, act_func: tl.constexpr, save_pre_act: tl.constexpr,\n BLOCK_SIZE_BATCH: tl.constexpr, BLOCK_SIZE_SPATIAL: tl.constexpr,\n ):\n feat_pid = tl.program_id(axis=0)\n batch_offset = tl.arange(0, BLOCK_SIZE_BATCH)\n batch_mask = batch_offset < batch_dim\n\n if is_train or not track_running_stats:\n count = 0\n mean = 0.0\n var = 0.0\n\n for block_ind in range(0, tl.cdiv(spatial_dim, BLOCK_SIZE_SPATIAL)):\n spatial_offset = (block_ind * BLOCK_SIZE_SPATIAL +\n tl.arange(0, BLOCK_SIZE_SPATIAL))\n spatial_mask = spatial_offset < spatial_dim\n\n curr_input_pointer = (input_pointer +\n input_feat_stride * feat_pid +\n input_batch_stride * batch_offset[:, None] +\n input_spatial_stride * spatial_offset[None, :])\n curr_input = tl.load(curr_input_pointer,\n mask=batch_mask[:, None] & spatial_mask[None, :]).to(tl.float32)\n\n spatial_count = min(BLOCK_SIZE_SPATIAL, spatial_dim - block_ind * BLOCK_SIZE_SPATIAL)\n curr_count = spatial_count * batch_dim\n count += curr_count\n\n prev_mean = mean\n mean += (tl.sum(curr_input) - curr_count * mean) / count\n deltas = tl.where(batch_mask[:, None] & spatial_mask[None, :],\n (curr_input - mean) * (curr_input - prev_mean), 0.)\n var += tl.sum(deltas)\n\n var /= count\n inv_std = tl.rsqrt(var + eps)\n\n if save_stats:\n tl.store(feat_pid + mean_pointer, mean)\n tl.store(feat_pid + inv_std_pointer, inv_std)\n\n if track_running_stats:\n running_mean_pointer += feat_pid\n running_var_pointer += feat_pid\n\n running_mean = tl.load(running_mean_pointer)\n running_var = tl.load(running_var_pointer)\n\n n = batch_dim * spatial_dim\n tl.store(running_mean_pointer,\n (1 - momentum) * running_mean + momentum * mean)\n tl.store(running_var_pointer,\n (1 - momentum) * running_var + momentum * var * n / (n - 1))\n\n else:\n mean = tl.load(feat_pid + running_mean_pointer)\n inv_std = tl.rsqrt(tl.load(feat_pid + running_var_pointer) + eps)\n\n if affine:\n weight = tl.load(feat_pid + weight_pointer)\n bias = tl.load(feat_pid + bias_pointer)\n\n else:\n weight = 1.\n bias = 0.\n\n for block_ind in range(0, tl.cdiv(spatial_dim, BLOCK_SIZE_SPATIAL)):\n spatial_offset = (block_ind * BLOCK_SIZE_SPATIAL +\n tl.arange(0, BLOCK_SIZE_SPATIAL))\n spatial_mask = spatial_offset < spatial_dim\n\n curr_input_pointer = (input_pointer +\n input_feat_stride * feat_pid +\n input_batch_stride * batch_offset[:, None] +\n input_spatial_stride * spatial_offset[None, :])\n curr_output_pointer = (output_pointer +\n output_feat_stride * feat_pid +\n output_batch_stride * batch_offset[:, None] +\n output_spatial_stride * spatial_offset[None, :])\n\n curr_input = tl.load(curr_input_pointer,\n mask=batch_mask[:, None] & spatial_mask[None, :]).to(tl.float32)\n output = weight * (curr_input - mean) * inv_std + bias\n\n if add_pre_act:\n curr_pre_act_add_pointer = (pre_act_add_pointer +\n pre_act_add_feat_stride * feat_pid +\n pre_act_add_batch_stride * batch_offset[:, None] +\n pre_act_add_spatial_stride * spatial_offset[None, :])\n curr_pre_act_add = tl.load(curr_pre_act_add_pointer,\n mask=batch_mask[:, None] & spatial_mask[None, :])\n output += curr_pre_act_add\n\n if act_func is not None:\n if save_pre_act:\n curr_pre_act_pointer = (pre_act_pointer +\n pre_act_feat_stride * feat_pid +\n pre_act_batch_stride * batch_offset[:, None] +\n pre_act_spatial_stride * spatial_offset[None, :])\n tl.store(curr_pre_act_pointer, output,\n mask=batch_mask[:, None] & spatial_mask[None, :])\n\n output = apply_act_func(output, None, None, None, param, act_func, False)\n\n tl.store(curr_output_pointer, output,\n mask=batch_mask[:, None] & spatial_mask[None, :])\n\n\n@triton.autotune(\n configs=warps_kernel_configs(),\n key=['batch_dim', 'spatial_dim'],\n)\n@triton.heuristics({'BLOCK_SIZE_BATCH': lambda args: next_power_of_2(args['batch_dim']),\n 'BLOCK_SIZE_SPATIAL': BLOCK_SIZE_SPATIAL_heuristic})\n@triton.jit\ndef batch_norm_backward_kernel(\n output_grad_pointer, input_pointer, mean_pointer, inv_std_pointer, weight_pointer,\n input_grad_pointer, weight_grad_pointer, bias_grad_pointer,\n batch_dim, spatial_dim,\n output_grad_batch_stride, output_grad_feat_stride, output_grad_spatial_stride,\n input_batch_stride, input_feat_stride, input_spatial_stride,\n input_grad_batch_stride, input_grad_feat_stride, input_grad_spatial_stride,\n affine: tl.constexpr,\n BLOCK_SIZE_BATCH: tl.constexpr, BLOCK_SIZE_SPATIAL: tl.constexpr,\n ):\n feat_pid = tl.program_id(axis=0)\n batch_offset = tl.arange(0, BLOCK_SIZE_BATCH)\n batch_mask = batch_offset < batch_dim\n\n mean = tl.load(feat_pid + mean_pointer)\n inv_std = tl.load(feat_pid + inv_std_pointer)\n\n term1 = 0.0\n term2 = 0.0\n\n for block_ind in range(0, tl.cdiv(spatial_dim, BLOCK_SIZE_SPATIAL)):\n spatial_offset = (block_ind * BLOCK_SIZE_SPATIAL +\n tl.arange(0, BLOCK_SIZE_SPATIAL))\n spatial_mask = spatial_offset < spatial_dim\n\n curr_output_grad_pointer = (output_grad_pointer +\n output_grad_feat_stride * feat_pid +\n output_grad_batch_stride * batch_offset[:, None] +\n output_grad_spatial_stride * spatial_offset[None, :])\n curr_input_pointer = (input_pointer +\n input_feat_stride * feat_pid +\n input_batch_stride * batch_offset[:, None] +\n input_spatial_stride * spatial_offset[None, :])\n\n curr_input = tl.load(curr_input_pointer,\n mask=batch_mask[:, None] & spatial_mask[None, :]).to(tl.float32)\n curr_pre_lin = (curr_input - mean) * inv_std\n curr_output_grad = tl.load(curr_output_grad_pointer,\n mask=batch_mask[:, None] & spatial_mask[None, :]).to(tl.float32)\n\n term1 += tl.sum(curr_pre_lin * curr_output_grad)\n term2 += tl.sum(curr_output_grad)\n\n if affine:\n weight = tl.load(feat_pid + weight_pointer)\n weight_grad = 0.0\n bias_grad = 0.0\n\n else:\n weight = 1.\n\n count = batch_dim * spatial_dim\n term1 *= weight / count\n term2 *= weight / count\n\n for block_ind in range(0, tl.cdiv(spatial_dim, BLOCK_SIZE_SPATIAL)):\n spatial_offset = (block_ind * BLOCK_SIZE_SPATIAL +\n tl.arange(0, BLOCK_SIZE_SPATIAL))\n spatial_mask = spatial_offset < spatial_dim\n\n curr_output_grad_pointer = (output_grad_pointer +\n output_grad_feat_stride * feat_pid +\n output_grad_batch_stride * batch_offset[:, None] +\n output_grad_spatial_stride * spatial_offset[None, :])\n curr_input_pointer = (input_pointer +\n input_feat_stride * feat_pid +\n input_batch_stride * batch_offset[:, None] +\n input_spatial_stride * spatial_offset[None, :])\n curr_input_grad_pointer = (input_grad_pointer +\n input_grad_feat_stride * feat_pid +\n input_grad_batch_stride * batch_offset[:, None] +\n input_grad_spatial_stride * spatial_offset[None, :])\n\n curr_input = tl.load(curr_input_pointer,\n mask=batch_mask[:, None] & spatial_mask[None, :]).to(tl.float32)\n curr_pre_lin = (curr_input - mean) * inv_std\n curr_output_grad = tl.load(curr_output_grad_pointer,\n mask=batch_mask[:, None] & spatial_mask[None, :]).to(tl.float32)\n curr_input_grad = inv_std * (weight * curr_output_grad - (term1 * curr_pre_lin + term2))\n tl.store(curr_input_grad_pointer, curr_input_grad,\n mask=batch_mask[:, None] & spatial_mask[None, :])\n\n if affine:\n weight_grad += tl.sum(curr_pre_lin * curr_output_grad)\n bias_grad += tl.sum(curr_output_grad)\n\n if affine:\n tl.store(feat_pid + weight_grad_pointer, weight_grad)\n tl.store(feat_pid + bias_grad_pointer, bias_grad)\n", - "description_1": "Use triton language to implement a batch normalization with an optional residual addition and fused activation function. The forward kernel normalizes the input tensor and optionally applies weights, bias, and an activation function. The backward kernel computes the gradients for the input, weights, and bias based on the gradients from the subsequent layer. The forward function accepts 36 parameters and the backward function accepts 26 parameters, mostly pointers and dimensions for input, weights, biases, and gradients.", - "description_2": "Use triton language to implement batch normalization forward and backward kernels. The forward kernel takes 36 parameters for input normalization and optional activation, while the backward kernel takes 26 parameters to compute input gradients and gradients for weights and biases.", - "difficulty": 4 - }, - { - "code": "import triton\nimport triton.language as tl\n\n@triton.jit\ndef conv2d_forward_kernel(\n input_pointer, weight_pointer, output_pointer,\n batch_dim, in_feat_dim, in_height, in_width,\n out_feat_dim, out_height, out_width,\n input_batch_stride, input_in_feat_stride, input_height_stride, input_width_stride,\n weight_out_feat_stride, weight_in_feat_stride, weight_height_stride, weight_width_stride,\n output_batch_stride, output_out_feat_stride, output_height_stride, output_width_stride,\n kernel_height: tl.constexpr, kernel_width: tl.constexpr,\n stride_height: tl.constexpr, stride_width: tl.constexpr,\n padding_height: tl.constexpr, padding_width: tl.constexpr,\n groups: tl.constexpr, fp16: tl.constexpr, tf32: tl.constexpr,\n BLOCK_SIZE_BATCH_HEIGHT_WIDTH: tl.constexpr, BLOCK_SIZE_IN_FEAT: tl.constexpr,\n BLOCK_SIZE_OUT_FEAT: tl.constexpr,\n ):\n \"\"\"\n 2D-convolves over the input using weights.\n\n Args:\n input_pointer: Pointer to the input to convolve over.\n The input must be of shape [batch_dim, in_feat_dim, in_height, in_width].\n weight_pointer: Pointer to the weights input is convolved over by.\n The weights must be of shape [out_feat_dim, in_feat_dim, kernel_height, kernel_width].\n output_pointer: Pointer to a container the result is written to.\n The container must be of shape [batch_dim, out_feat_dim, out_height, out_width].\n batch_dim: Batch dimension of the input and output.\n in_feat_dim: Dimensionality of the input features.\n in_height: Input height.\n in_width: Input width.\n out_feat_dim: Dimensionality of the output features.\n out_height: Output height.\n out_width: Output width.\n input_batch_stride: Stride necessary to jump one element along the\n input's batch dimension.\n input_in_feat_stride: Stride necessary to jump one element along the\n input's feature dimension.\n input_height_stride: Stride necessary to jump one element along the\n input's height dimension.\n input_width_stride: Stride necessary to jump one element along the\n input's width dimension.\n weight_out_feat_stride: Stride necessary to jump one element along the\n weights' output feature dimension.\n weight_in_feat_stride: Stride necessary to jump one element along the\n weights' input feature dimension.\n weight_height_stride: Stride necessary to jump one element along the\n weights' height dimension.\n weight_width_stride: Stride necessary to jump one element along the\n weights' width dimension.\n output_batch_stride: Stride necessary to jump one element along the\n output's batch dimension.\n output_out_feat_stride: Stride necessary to jump one element along the\n output's feature dimension.\n output_height_stride: Stride necessary to jump one element along the\n output's height dimension.\n output_width_stride: Stride necessary to jump one element along the\n output's width dimension.\n kernel_height: Kernel height.\n kernel_width: Kernel width.\n stride_height: Stride of kernel across the height dimension.\n stride_width: Stride of kernel across the width dimension.\n padding_height: Padding applied to the input across the height dimension.\n padding_width: Padding applied to the input across the width dimension.\n groups: Number of groups for the convolution.\n fp16: Flag for loading the input and weights in FP16.\n tf32: Flag for performing matrix products in TF32.\n BLOCK_SIZE_BATCH_HEIGHT_WIDTH: Block size across the batch, height, and\n width dimensions.\n BLOCK_SIZE_IN_FEAT: Block size across the input feature dimension.\n BLOCK_SIZE_OUT_FEAT: Block size across the output feature dimension.\n \"\"\"\n batch_height_width_pid = tl.program_id(0)\n out_feat_pid = tl.program_id(1)\n group_pid = tl.program_id(2)\n\n in_group_dim = in_feat_dim // groups\n out_group_dim = out_feat_dim // groups\n\n batch_height_width_offset = (batch_height_width_pid * BLOCK_SIZE_BATCH_HEIGHT_WIDTH +\n tl.arange(0, BLOCK_SIZE_BATCH_HEIGHT_WIDTH))\n batch_height_offset = batch_height_width_offset // out_width\n batch_offset = batch_height_offset // out_height\n\n output_feat_offset = (out_feat_pid * BLOCK_SIZE_OUT_FEAT +\n tl.arange(0, BLOCK_SIZE_OUT_FEAT))\n output_height_offset = batch_height_offset % out_height\n output_width_offset = batch_height_width_offset % out_width\n\n input_pointer += (input_batch_stride * batch_offset +\n input_in_feat_stride * group_pid * in_group_dim)[:, None]\n weight_pointer += (weight_out_feat_stride * output_feat_offset +\n weight_out_feat_stride * group_pid * out_group_dim)[None, :]\n\n accum = tl.zeros((BLOCK_SIZE_BATCH_HEIGHT_WIDTH, BLOCK_SIZE_OUT_FEAT),\n dtype=tl.float32)\n\n for h in range(kernel_height):\n for w in range(kernel_width):\n for c in range(0, in_group_dim, BLOCK_SIZE_IN_FEAT):\n input_feat_offset = c + tl.arange(0, BLOCK_SIZE_IN_FEAT)\n input_height_offset = (h - padding_height +\n stride_height * output_height_offset)\n input_width_offset = (w - padding_width +\n stride_width * output_width_offset)\n\n curr_input_pointer = (input_pointer +\n (input_in_feat_stride * input_feat_offset)[None, :] +\n (input_height_stride * input_height_offset)[:, None] +\n (input_width_stride * input_width_offset)[:, None])\n curr_weight_pointer = (weight_pointer +\n (weight_in_feat_stride * input_feat_offset)[:, None] +\n (weight_height_stride * h) +\n (weight_width_stride * w))\n\n input_mask = ((batch_offset < batch_dim)[:, None] &\n (input_feat_offset < in_group_dim)[None, :] &\n (0 <= input_height_offset)[:, None] &\n (input_height_offset < in_height)[:, None] &\n (0 <= input_width_offset)[:, None] &\n (input_width_offset < in_width)[:, None])\n weight_mask = ((input_feat_offset < in_group_dim)[:, None] &\n (output_feat_offset < out_group_dim)[None, :])\n\n input_block = tl.load(curr_input_pointer, mask=input_mask)\n weight_block = tl.load(curr_weight_pointer, mask=weight_mask)\n\n if fp16:\n input_block = input_block.to(tl.float16)\n weight_block = weight_block.to(tl.float16)\n\n accum += tl.dot(input_block, weight_block, allow_tf32=tf32)\n\n output_pointer += ((output_batch_stride * batch_offset)[:, None] +\n (output_out_feat_stride * (group_pid * out_group_dim + output_feat_offset))[None, :] +\n (output_height_stride * output_height_offset)[:, None] +\n (output_width_stride * output_width_offset)[:, None])\n output_mask = ((batch_offset < batch_dim)[:, None] &\n (output_feat_offset < out_group_dim)[None, :] &\n (output_height_offset < out_height)[:, None] &\n (output_width_offset < out_width)[:, None])\n\n tl.store(output_pointer, accum, mask=output_mask)\n", - "description_1": "Use triton language to implement a 2D convolution kernel with parameters for input, weight, and output pointers, dimensions, strides, kernel size, stride, padding, groups, precision flags, and block sizes.", - "description_2": "Use triton language to perform 2D convolution with configurable parameters and precision options.", - "difficulty": 3 - }, - { - "code": "import triton\nimport triton.language as tl\nfrom triton import next_power_of_2\n\nfrom .softmax_kernels import BLOCK_SIZE_BATCH_heuristic\nfrom .utils import warps_kernel_configs\n\n@triton.autotune(\n configs=warps_kernel_configs(),\n key=['batch_dim', 'feat_dim'],\n)\n@triton.heuristics({'BLOCK_SIZE_BATCH': BLOCK_SIZE_BATCH_heuristic,\n 'BLOCK_SIZE_FEAT': lambda args: next_power_of_2(args['feat_dim'])})\n@triton.jit\ndef cross_entropy_loss_forward_kernel(\n input_pointer, target_pointer, weight_pointer,\n sum_weights_pointer, output_pointer,\n batch_dim, feat_dim,\n input_batch_stride, input_feat_stride,\n weighted: tl.constexpr,\n BLOCK_SIZE_BATCH: tl.constexpr, BLOCK_SIZE_FEAT: tl.constexpr,\n ):\n \"\"\"\n Measures the mean cross entropy loss between the input and target,\n with optional reweighing of each class.\n\n Args:\n input_pointer: Pointer to the input.\n The input must be of shape [batch_dim, feat_dim].\n target_pointer: Pointer to the target.\n The target must be of shape [batch_dim].\n weight_pointer: Pointer to an optional class weight vector.\n The class weight vector, if provided, must be of shape [feat_dim].\n sum_weights_pointer: Pointer to a container the sum of the class weights is written to.\n The container must be of shape [batch_dim/BLOCK_SIZE_BATCH].\n output_pointer: Pointer to a container the loss is written to.\n The container must be of shape [batch_dim/BLOCK_SIZE_BATCH].\n batch_dim: Batch dimension.\n feat_dim: Dimensionality of the features.\n input_batch_stride: Stride necessary to jump one element along the\n input's batch dimension.\n input_feat_stride: Stride necessary to jump one element along the\n input's feature dimension.\n weighted: Flag for weighing each class.\n BLOCK_SIZE_BATCH: Block size across the batch dimension.\n BLOCK_SIZE_FEAT: Block size across the feature dimension.\n \"\"\"\n # This program processes BLOCK_SIZE_BATCH rows and BLOCK_SIZE_FEAT columns.\n batch_pid = tl.program_id(axis=0)\n\n batch_offset = batch_pid * BLOCK_SIZE_BATCH + tl.arange(0, BLOCK_SIZE_BATCH)\n feat_offset = tl.arange(0, BLOCK_SIZE_FEAT)\n\n batch_mask = batch_offset < batch_dim\n feat_mask = feat_offset < feat_dim\n\n target = tl.load(target_pointer + batch_offset, mask=batch_mask)\n\n pred_pointer = (input_pointer +\n input_feat_stride * target +\n input_batch_stride * batch_offset)\n input_pointer += (input_batch_stride * batch_offset[:, None] +\n input_feat_stride * feat_offset[None, :])\n\n input = tl.load(input_pointer, mask=batch_mask[:, None] & feat_mask[None, :],\n other=-float('inf')).to(tl.float32)\n pred = tl.load(pred_pointer, mask=batch_mask).to(tl.float32)\n mx = tl.max(input, axis=1)\n input -= mx[:, None]\n loss = tl.log(tl.sum(tl.exp(input), axis=1)) - pred + mx\n\n if weighted:\n weight = tl.load(weight_pointer + target, mask=batch_mask).to(tl.float32)\n loss *= weight\n tl.store(sum_weights_pointer + batch_pid, tl.sum(weight))\n\n else:\n loss /= batch_dim\n\n tl.store(output_pointer + batch_pid, tl.sum(loss))\n\n\n@triton.autotune(\n configs=warps_kernel_configs(),\n key=['batch_dim', 'feat_dim'],\n)\n@triton.heuristics({'BLOCK_SIZE_BATCH': BLOCK_SIZE_BATCH_heuristic,\n 'BLOCK_SIZE_FEAT': lambda args: next_power_of_2(args['feat_dim'])})\n@triton.jit\ndef cross_entropy_loss_backward_kernel(\n output_grad_pointer, target_pointer, input_pointer, weight_pointer,\n sum_weights_pointer, input_grad_pointer,\n batch_dim, feat_dim,\n input_batch_stride, input_feat_stride,\n input_grad_batch_stride, input_grad_feat_stride,\n weighted: tl.constexpr,\n BLOCK_SIZE_BATCH: tl.constexpr, BLOCK_SIZE_FEAT: tl.constexpr,\n ):\n \"\"\"\n Calculates the input gradient of cross entropy loss.\n\n Args:\n output_grad_pointer: Pointer to the loss's output gradients.\n The output gradient must be a scalar.\n target_pointer: Pointer to the target.\n The target must be of shape [batch_dim].\n input_pointer: Pointer to the input.\n The input must be of shape [batch_dim, feat_dim].\n weight_pointer: Pointer to an optional class weight vector.\n The class weight vector, if provided, must be of shape [feat_dim].\n sum_weights_pointer: Pointer to the sum of the class weights if the classes were weighed.\n The sum of weights must be a scalar.\n input_grad_pointer: Pointer to a container the input's gradients are written to.\n The container must be of shape [batch_dim, feat_dim].\n batch_dim: Batch dimension.\n feat_dim: Dimensionality of the features.\n input_batch_stride: Stride necessary to jump one element along the\n input's batch dimension.\n input_feat_stride: Stride necessary to jump one element along the\n input's feature dimension.\n input_grad_batch_stride: Stride necessary to jump one element along the\n input gradient container's batch dimension.\n input_grad_feat_stride: Stride necessary to jump one element along the\n input gradient container's feature dimension.\n weighted: Flag for weighing each class.\n BLOCK_SIZE_BATCH: Block size across the batch dimension.\n BLOCK_SIZE_FEAT: Block size across the feature dimension.\n \"\"\"\n # This program processes BLOCK_SIZE_BATCH rows and BLOCK_SIZE_FEAT columns.\n batch_pid = tl.program_id(axis=0)\n\n batch_offset = batch_pid * BLOCK_SIZE_BATCH + tl.arange(0, BLOCK_SIZE_BATCH)\n feat_offset = tl.arange(0, BLOCK_SIZE_FEAT)\n\n batch_mask = batch_offset < batch_dim\n feat_mask = feat_offset < feat_dim\n\n input_pointer += (input_batch_stride * batch_offset[:, None] +\n input_feat_stride * feat_offset[None, :])\n input_grad_pointer += (input_grad_batch_stride * batch_offset[:, None] +\n input_grad_feat_stride * feat_offset[None, :])\n\n input = tl.load(input_pointer, mask=batch_mask[:, None] & feat_mask[None, :],\n other=-float('inf')).to(tl.float32)\n input -= tl.max(input, axis=1)[:, None]\n numerator = tl.exp(input)\n softmax = numerator / tl.sum(numerator, axis=1)[:, None]\n\n output_grad = tl.load(output_grad_pointer).to(tl.float32)\n target = tl.load(target_pointer + batch_offset, mask=batch_mask)\n broadcasted_feat_offset = tl.broadcast_to(feat_offset[None, :],\n (BLOCK_SIZE_BATCH, BLOCK_SIZE_FEAT))\n broadcasted_target = tl.broadcast_to(target[:, None],\n (BLOCK_SIZE_BATCH, BLOCK_SIZE_FEAT))\n input_grad = output_grad * (softmax - (broadcasted_feat_offset == broadcasted_target))\n\n if weighted:\n weight = tl.load(weight_pointer + target, mask=batch_mask).to(tl.float32)\n sum_weights = tl.load(sum_weights_pointer)\n input_grad *= weight[:, None] / sum_weights\n\n else:\n input_grad /= batch_dim\n\n tl.store(input_grad_pointer, input_grad,\n mask=batch_mask[:, None] & feat_mask[None, :])\n", - "description_1": "Use triton language to implement two kernels: one for forward pass and one for backward pass of cross entropy loss. The forward kernel calculates the mean cross entropy loss between input and target, optionally reweighing each class. It takes 13 parameters: input_pointer, target_pointer, weight_pointer, sum_weights_pointer, output_pointer, batch_dim, feat_dim, input_batch_stride, input_feat_stride, weighted, BLOCK_SIZE_BATCH, BLOCK_SIZE_FEAT. The backward kernel calculates the input gradient of cross entropy loss. It takes 15 parameters: output_grad_pointer, target_pointer, input_pointer, weight_pointer, sum_weights_pointer, input_grad_pointer, batch_dim, feat_dim, input_batch_stride, input_feat_stride, input_grad_batch_stride, input_grad_feat_stride, weighted, BLOCK_SIZE_BATCH, BLOCK_SIZE_FEAT.", - "description_2": "Use triton language to create kernels for forward and backward computation of cross entropy loss with optional class weighting. The forward kernel computes the loss, and the backward kernel computes the gradient with respect to the input.", - "difficulty": 4 - }, - { - "code": "import triton\nimport triton.language as tl\n\n@triton.jit\ndef apply_dropout(input, drop_p, seed, offset):\n \"\"\"\n Randomly zeroes elements in the input.\n\n Args:\n input: Input. The input must be loaded and cannot be a pointer.\n drop_p: Probability of dropping an element.\n seed: Seed for generating the dropout mask.\n offset: Offset to generate the mask for.\n\n Returns:\n Input with elements randomly zeroed out.\n \"\"\"\n random = tl.rand(seed, offset)\n return tl.where(random < drop_p, 0, input / (1 - drop_p))\n\n@triton.jit\ndef apply_dropout_grad(output_grad, drop_p, seed, offset):\n \"\"\"\n Calculates the input gradient of dropout.\n\n Args:\n output_grad: Output gradients. The output gradients must be\n loaded and cannot be a pointer.\n drop_p: Probability of dropping an element.\n seed: Seed for generating the dropout mask.\n offset: Offset to generate the mask for.\n\n Returns:\n Gradient of dropout.\n \"\"\"\n random = tl.rand(seed, offset)\n return tl.where(random < drop_p, 0, output_grad / (1 - drop_p))\n\n@triton.autotune(\n configs=element_wise_kernel_configs(),\n key=['size'],\n)\n@triton.jit\ndef dropout_forward_kernel(\n input_pointer, output_pointer, size,\n drop_p, seed,\n BLOCK_SIZE: tl.constexpr,\n ):\n \"\"\"\n Randomly zeroes elements in the input.\n\n Args:\n input_pointer: Pointer to the input to perform dropout on.\n The input must be of shape [size].\n output_pointer: Pointer to a container the result is written to.\n The container must be of shape [size].\n size: Number of elements in the input.\n drop_p: Probability of dropping an element.\n seed: Seed for generating the dropout mask.\n BLOCK_SIZE: Block size.\n \"\"\"\n # This program processes BLOCK_SIZE rows.\n pid = tl.program_id(axis=0)\n offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)\n mask = offset < size\n\n input = tl.load(input_pointer + offset, mask=mask)\n output = apply_dropout(input, drop_p, seed, offset)\n tl.store(output_pointer + offset, output, mask=mask)\n\n@triton.autotune(\n configs=element_wise_kernel_configs(),\n key=['size'],\n)\n@triton.jit\ndef dropout_backward_kernel(\n output_grad_pointer, input_grad_pointer, size,\n drop_p, seed,\n BLOCK_SIZE: tl.constexpr,\n ):\n \"\"\"\n Calculates the input gradient of dropout.\n\n Args:\n output_grad_pointer: Pointer to dropout's output gradients.\n The output gradients must be of shape [size].\n input_grad_pointer: Pointer to a container the input's gradients are written to.\n The container must be of shape [size].\n size: Number of elements in the input.\n drop_p: Probability of dropping an element used in dropout.\n seed: Seed for generating the dropout mask.\n BLOCK_SIZE: Block size.\n \"\"\"\n # This program processes BLOCK_SIZE rows.\n pid = tl.program_id(axis=0)\n offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)\n mask = offset < size\n\n output_grad = tl.load(output_grad_pointer + offset, mask=mask)\n input_grad = apply_dropout_grad(output_grad, drop_p, seed, offset)\n tl.store(input_grad_pointer + offset, input_grad, mask=mask)\n", - "description_1": "Use triton language to implement dropout operations. The `apply_dropout` kernel takes 4 parameters: input (the data to apply dropout on), drop_p (probability of dropping an element), seed (for random number generation), and offset (to generate the mask). It returns the input with elements randomly zeroed out. The `apply_dropout_grad` kernel also takes 4 parameters: output_grad (gradients of the output), drop_p, seed, and offset, and returns the gradient of dropout. The `dropout_forward_kernel` and `dropout_backward_kernel` are triton kernels that perform the forward and backward passes of dropout, respectively. They take pointers to input/output data, size of the data, drop probability, seed, and block size as parameters.", - "description_2": "Use triton language to create kernels for dropout forward and backward operations, handling input/output pointers, dropout probability, and random seed.", - "difficulty": 3 - }, - { - "code": "import triton\nimport triton.language as tl\nfrom .act_kernels import apply_act_func, apply_act_func_grad\nfrom .utils import element_wise_kernel_configs\n\n@triton.autotune(\n configs=element_wise_kernel_configs(),\n key=['size'],\n)\n@triton.jit\ndef glu_forward_kernel(\n input1_pointer, input2_pointer, output_pointer, size, param,\n act_func: tl.constexpr,\n BLOCK_SIZE: tl.constexpr,\n ):\n \"\"\"\n Applies the gated linear unit with an arbitrary activation function\n to the input.\n\n Args:\n input1_pointer: Pointer to the first half of the input to gate.\n The first half must be contiguous and contain size elements.\n input2_pointer: Pointer to the second half of the input to gate.\n The second half must be contiguous and contain size elements.\n output_pointer: Pointer to a container the result is written to.\n The container must be contiguous and contain size elements.\n size: Number of elements in each half of the input.\n param: Parameter in the case of parameterized activation functions.\n act_func: Name of activation function to apply.\n Options are 'sigmoid', 'tanh', 'relu', 'gelu', 'silu',\n 'relu6', 'hardsigmoid', 'hardswish', 'selu', 'mish', and 'leaky_relu'.\n BLOCK_SIZE: Block size.\n \"\"\"\n # This program processes BLOCK_SIZE elements.\n pid = tl.program_id(axis=0)\n offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)\n mask = offset < size\n\n input1 = tl.load(input1_pointer + offset, mask=mask)\n input2 = tl.load(input2_pointer + offset, mask=mask)\n\n output = input1 * apply_act_func(input2, None, None, None, param,\n act_func, False)\n tl.store(output_pointer + offset, output, mask=mask)\n\n\n@triton.autotune(\n configs=element_wise_kernel_configs(),\n key=['size'],\n)\n@triton.jit\ndef glu_backward_kernel(\n output_grad_pointer, input1_pointer, input2_pointer,\n input1_grad_pointer, input2_grad_pointer, size, param,\n act_func: tl.constexpr,\n BLOCK_SIZE: tl.constexpr,\n ):\n \"\"\"\n Calculates the input gradient of the gated linear unit.\n\n Args:\n output_grad_pointer: Pointer to the unit's output gradients.\n The output gradients must be contiguous and contain size elements.\n input1_pointer: Pointer to the first half of the input that was gated.\n The first half must be contiguous and contain size elements.\n input2_pointer: Pointer to the second half of the input that was gated.\n The second half must be contiguous and contain size elements.\n input1_grad_pointer: Pointer to a container the first half's gradients are written to.\n The container must be contiguous and contain size elements.\n input2_grad_pointer: Pointer to a container the second half's gradients are written to.\n The container must be contiguous and contain size elements.\n size: Number of elements in each half of the input.\n param: Parameter in the case of parameterized activation functions.\n act_func: Name of activation function to apply.\n Options are 'sigmoid', 'tanh', 'relu', 'gelu', 'silu',\n 'relu6', 'hardsigmoid', 'hardswish', 'selu', 'mish', and 'leaky_relu'.\n BLOCK_SIZE: Block size.\n \"\"\"\n # This program processes BLOCK_SIZE elements.\n pid = tl.program_id(axis=0)\n offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)\n mask = offset < size\n\n output_grad = tl.load(output_grad_pointer + offset, mask=mask)\n input1 = tl.load(input1_pointer + offset, mask=mask)\n input2 = tl.load(input2_pointer + offset, mask=mask)\n\n input1_grad = output_grad * apply_act_func(input2, None, None, None, param,\n act_func, False)\n input2_grad = output_grad * input1 * apply_act_func_grad(1, input2,\n None, None, None,\n param, act_func,\n False)\n\n tl.store(input1_grad_pointer + offset, input1_grad, mask=mask)\n tl.store(input2_grad_pointer + offset, input2_grad, mask=mask)\n", - "description_1": "Use triton language to implement two kernels: glu_forward_kernel and glu_backward_kernel. The glu_forward_kernel takes 7 parameters: input1_pointer, input2_pointer, output_pointer, size, param, act_func, and BLOCK_SIZE. It applies a gated linear unit with an arbitrary activation function to the input. The glu_backward_kernel takes 9 parameters: output_grad_pointer, input1_pointer, input2_pointer, input1_grad_pointer, input2_grad_pointer, size, param, act_func, and BLOCK_SIZE. It calculates the input gradient of the gated linear unit.", - "description_2": "Use triton language to create kernels for forward and backward passes of a gated linear unit with customizable activation functions.", - "difficulty": 2 - }, - { - "code": "import triton\nimport triton.language as tl\nfrom triton import next_power_of_2\n\n@triton.autotune(\n configs=warps_kernel_configs(),\n key=['batch_dim', 'feat_dim'],\n)\n@triton.heuristics({'BLOCK_SIZE_BATCH': BLOCK_SIZE_BATCH_heuristic,\n 'BLOCK_SIZE_FEAT': lambda args: next_power_of_2(args['feat_dim'])})\n@triton.jit\ndef layer_norm_forward_kernel(\n input_pointer, weight_pointer, bias_pointer,\n mean_pointer, inv_std_pointer, output_pointer,\n batch_dim, feat_dim,\n input_batch_stride, input_feat_stride,\n output_batch_stride, output_feat_stride,\n eps,\n scale_by_weight: tl.constexpr, add_bias: tl.constexpr, save_stats: tl.constexpr,\n BLOCK_SIZE_BATCH: tl.constexpr, BLOCK_SIZE_FEAT: tl.constexpr,\n ):\n \"\"\"\n Layer-normalizes the input.\n\n Args:\n input_pointer: Pointer to the input to layer-normalize.\n The input must be of shape [batch_dim, feat_dim].\n weight_pointer: Pointer to optional weights for affine transform.\n The weights, if provided, must be of shape [feat_dim].\n bias_pointer: Pointer to an optional bias vector for affine transform.\n The bias vector, if provided, must be of shape [feat_dim].\n mean_pointer: Pointer to an optional container the input's mean\n is written to if save_stats is True.\n The container, if provided, must be of shape [batch_dim].\n inv_std_pointer: Pointer to an optional container the input's inverse\n standard deviation is written to if save_stats is True.\n The container, if provided, must be of shape [batch_dim].\n output_pointer: Pointer to a container the result is written to.\n The container must be of shape [batch_dim, feat_dim].\n batch_dim: Batch dimension.\n feat_dim: Dimensionality of the features.\n input_batch_stride: Stride necessary to jump one element along the\n input's batch dimension.\n input_feat_stride: Stride necessary to jump one element along the\n input's feature dimension.\n output_batch_stride: Stride necessary to jump one element along the\n output container's batch dimension.\n output_feat_stride: Stride necessary to jump one element along the\n output container's feature dimension.\n eps: Epsilon added in the square root in the denominator\n to avoid division by zero.\n scale_by_weight: Flag for scaling the normalized output by weights.\n add_bias: Flag for adding a bias vector to the normalized output\n if scale_by_weight is True.\n save_stats: Flag for saving the mean and standard deviation.\n BLOCK_SIZE_BATCH: Block size across the batch dimension.\n BLOCK_SIZE_FEAT: Block size across the feature dimension.\n \"\"\"\n # This program processes BLOCK_SIZE_BATCH rows and BLOCK_SIZE_FEAT columns.\n batch_pid = tl.program_id(axis=0)\n\n batch_offset = batch_pid * BLOCK_SIZE_BATCH + tl.arange(0, BLOCK_SIZE_BATCH)\n feat_offset = tl.arange(0, BLOCK_SIZE_FEAT)\n\n batch_mask = batch_offset < batch_dim\n feat_mask = feat_offset < feat_dim\n\n input_pointer += (input_batch_stride * batch_offset[:, None] +\n input_feat_stride * feat_offset[None, :])\n output_pointer += (output_batch_stride * batch_offset[:, None] +\n output_feat_stride * feat_offset[None, :])\n\n input = tl.load(input_pointer,\n mask=batch_mask[:, None] & feat_mask[None, :]).to(tl.float32)\n mean = tl.sum(input, axis=1) / feat_dim\n diff = tl.where(feat_mask[None, :], input - mean[:, None], 0)\n inv_std = tl.rsqrt(tl.sum(diff * diff, axis=1) / feat_dim + eps)\n\n if save_stats:\n tl.store(mean_pointer + batch_offset, mean, mask=batch_mask)\n tl.store(inv_std_pointer + batch_offset, inv_std, mask=batch_mask)\n\n output = diff * inv_std[:, None]\n if scale_by_weight:\n weight = tl.load(weight_pointer + feat_offset, mask=feat_mask)\n output *= weight\n if add_bias:\n bias = tl.load(bias_pointer + feat_offset, mask=feat_mask)\n output += bias\n\n tl.store(output_pointer, output,\n mask=batch_mask[:, None] & feat_mask[None, :])\n\n\n@triton.autotune(\n configs=warps_kernel_configs(),\n key=['batch_dim', 'feat_dim'],\n)\n@triton.heuristics({'BLOCK_SIZE_BATCH': BLOCK_SIZE_BATCH_heuristic,\n 'BLOCK_SIZE_FEAT': lambda args: next_power_of_2(args['feat_dim'])})\n@triton.jit\ndef layer_norm_backward_kernel(\n output_grad_pointer, input_pointer, mean_pointer, inv_std_pointer, weight_pointer,\n input_grad_pointer, weight_grad_pointer, bias_grad_pointer,\n batch_dim, feat_dim,\n output_grad_batch_stride, output_grad_feat_stride,\n input_batch_stride, input_feat_stride,\n input_grad_batch_stride, input_grad_feat_stride,\n weight_grad_batch_stride, weight_grad_feat_stride,\n bias_grad_batch_stride, bias_grad_feat_stride,\n scale_by_weight: tl.constexpr, add_bias: tl.constexpr,\n BLOCK_SIZE_BATCH: tl.constexpr, BLOCK_SIZE_FEAT: tl.constexpr,\n ):\n \"\"\"\n Calculates the input gradient of layer normalization.\n\n Args:\n output_grad_pointer: Pointer to layer normalization's output gradients.\n The output gradients must be of shape [batch_dim, feat_dim].\n input_pointer: Pointer to the input.\n The input must be of shape [batch_dim, feat_dim].\n mean_pointer: Pointer to the input's mean.\n The mean should be of shape [batch_dim].\n inv_std_pointer: Pointer to the input's inverse standard deviation.\n The inverse standard deviation should be of shape [batch_dim].\n weight_pointer: Pointer to optional weights if affine transform occurred.\n The weights, if provided, must be of shape [feat_dim].\n input_grad_pointer: Pointer to a container the input's gradients are written to.\n The container must be of shape [batch_dim, feat_dim].\n weight_grad_pointer: Pointer to an optional container the weights' row-wise gradients\n are written to if scale_by_weight is True, which should later be summed.\n The container, if provided, must be of shape [batch_dim/BLOCK_SIZE_BATCH, feat_dim].\n bias_grad_pointer: Pointer to an optional container the bias vector's row-wise gradients\n are written to if scale_by_weight and add_bias are True, which should later be summed.\n The container, if provided, must be of shape [batch_dim/BLOCK_SIZE_BATCH, feat_dim].\n batch_dim: Batch dimension.\n feat_dim: Dimensionality of the features.\n output_grad_batch_stride: Stride necessary to jump one element along the\n output gradients' batch dimension.\n output_grad_feat_stride: Stride necessary to jump one element along the\n output gradients' feature dimension.\n input_batch_stride: Stride necessary to jump one element along the\n input's batch dimension.\n input_feat_stride: Stride necessary to jump one element along the\n input's feature dimension.\n input_grad_batch_stride: Stride necessary to jump one element along the\n input gradient container's batch dimension.\n input_grad_feat_stride: Stride necessary to jump one element along the\n input gradient container's feature dimension.\n weight_grad_batch_stride: Stride necessary to jump one element along the\n weight gradient container's batch dimension.\n weight_grad_feat_stride: Stride necessary to jump one element along the\n weight gradient container's feature dimension.\n bias_grad_batch_stride: Stride necessary to jump one element along the\n weight gradient container's batch dimension.\n bias_grad_feat_stride: Stride necessary to jump one element along the\n weight gradient container's feature dimension.\n scale_by_weight: Flag for scaling the normalized output by weights.\n add_bias: Flag for adding a bias vector to the normalized output\n if scale_by_weight is True.\n BLOCK_SIZE_BATCH: Block size across the batch dimension.\n BLOCK_SIZE_FEAT: Block size across the feature dimension.\n \"\"\"\n # This program processes a single row and BLOCK_SIZE_FEAT columns.\n batch_pid = tl.program_id(axis=0)\n\n batch_offset = batch_pid * BLOCK_SIZE_BATCH + tl.arange(0, BLOCK_SIZE_BATCH)\n feat_offset = tl.arange(0, BLOCK_SIZE_FEAT)\n\n batch_mask = batch_offset < batch_dim\n feat_mask = feat_offset < feat_dim\n\n output_grad_pointer += (output_grad_batch_stride * batch_offset[:, None] +\n output_grad_feat_stride * feat_offset[None, :])\n input_pointer += (input_batch_stride * batch_offset[:, None] +\n input_feat_stride * feat_offset[None, :])\n input_grad_pointer += (input_grad_batch_stride * batch_offset[:, None] +\n input_grad_feat_stride * feat_offset[None, :])\n\n output_grad = tl.load(output_grad_pointer,\n mask=batch_mask[:, None] & feat_mask[None, :]).to(tl.float32)\n input = tl.load(input_pointer, mask=batch_mask[:, None] & feat_mask[None, :]).to(tl.float32)\n mean = tl.load(mean_pointer + batch_offset, mask=batch_mask)\n inv_std = tl.load(inv_std_pointer + batch_offset, mask=batch_mask)\n pre_lin = (input - mean[:, None]) * inv_std[:, None]\n\n if scale_by_weight:\n weight = tl.load(weight_pointer + feat_offset, mask=feat_mask)\n weight_output_grad_prod = weight * output_grad\n\n else:\n weight_output_grad_prod = output_grad\n\n term1 = tl.sum(pre_lin * weight_output_grad_prod, axis=1) / feat_dim\n term1 = pre_lin * term1[:, None]\n term2 = tl.sum(weight_output_grad_prod, axis=1) / feat_dim\n input_grad = (inv_std[:, None] *\n (weight_output_grad_prod - (term1 + term2[:, None])))\n\n tl.store(input_grad_pointer, input_grad,\n mask=batch_mask[:, None] & feat_mask[None, :])\n\n if scale_by_weight:\n weight_grad_pointer += (weight_grad_batch_stride * batch_pid +\n weight_grad_feat_stride * feat_offset)\n tl.store(weight_grad_pointer,\n tl.sum(output_grad * pre_lin, axis=0),\n mask=feat_mask)\n\n if add_bias:\n bias_grad_pointer += (bias_grad_batch_stride * batch_pid +\n bias_grad_feat_stride * feat_offset)\n tl.store(bias_grad_pointer,\n tl.sum(output_grad, axis=0),\n mask=feat_mask)\n", - "description_1": "Use triton language to implement a forward and backward kernel for layer normalization. The forward kernel normalizes input data with optional weights and bias, storing results and optionally storing means and inverse standard deviations. It involves handling batch and feature dimensions, with specific strides for memory access, and uses epsilon to prevent division by zero. The backward kernel calculates gradients with respect to input data, weights, and bias, using pre-computed means and inverse standard deviations, and accommodates optional affine transformations.", - "description_2": "Use triton language to implement layer normalization kernels for forward and backward passes, supporting optional affine transformations and statistics saving.", - "difficulty": 4 - }, - { - "code": "import triton\nimport triton.language as tl\nfrom .act_kernels import apply_act_func\n\ndef linear_forward_config(\n BLOCK_SIZE_BATCH: int,\n BLOCK_SIZE_IN_FEAT: int,\n BLOCK_SIZE_OUT_FEAT: int,\n GROUP_SIZE_BATCH: int = 8,\n n_warps: int = 4,\n n_stages: int = 2,\n ) -> triton.Config:\n return triton.Config({'BLOCK_SIZE_BATCH': BLOCK_SIZE_BATCH,\n 'BLOCK_SIZE_IN_FEAT': BLOCK_SIZE_IN_FEAT,\n 'BLOCK_SIZE_OUT_FEAT': BLOCK_SIZE_OUT_FEAT,\n 'GROUP_SIZE_BATCH': GROUP_SIZE_BATCH},\n num_warps=n_warps, num_stages=n_stages)\n\n@triton.autotune(\n configs=[\n linear_forward_config(32, 32, 32, n_warps=2, n_stages=2),\n linear_forward_config(64, 32, 32, n_warps=2, n_stages=5),\n linear_forward_config(64, 32, 128, n_warps=4, n_stages=4),\n linear_forward_config(64, 32, 256, n_warps=4, n_stages=4),\n linear_forward_config(128, 32, 32, n_warps=4, n_stages=4),\n linear_forward_config(128, 32, 64, n_warps=4, n_stages=4),\n linear_forward_config(128, 32, 128, n_warps=4, n_stages=4),\n linear_forward_config(128, 64, 256, n_warps=8, n_stages=3),\n ],\n key=['batch_dim', 'in_feat_dim', 'out_feat_dim', 'fp16'],\n)\n@triton.heuristics({'tf32': lambda _: True})\n@triton.jit\ndef linear_forward_kernel(\n input_pointer, weight_pointer, bias_pointer, pre_act_pointer, output_pointer,\n batch_dim, in_feat_dim, out_feat_dim,\n input_batch_stride, input_in_feat_stride,\n weight_in_feat_stride, weight_out_feat_stride,\n pre_act_batch_stride, pre_act_out_feat_stride,\n output_batch_stride, output_out_feat_stride, param,\n add_bias: tl.constexpr, act_func: tl.constexpr, save_pre_act: tl.constexpr,\n fp16: tl.constexpr, tf32: tl.constexpr,\n BLOCK_SIZE_BATCH: tl.constexpr, BLOCK_SIZE_IN_FEAT: tl.constexpr,\n BLOCK_SIZE_OUT_FEAT: tl.constexpr, GROUP_SIZE_BATCH: tl.constexpr,\n ):\n pid = tl.program_id(axis=0)\n n_batch_pids = tl.cdiv(batch_dim, BLOCK_SIZE_BATCH)\n n_out_feat_pids = tl.cdiv(out_feat_dim, BLOCK_SIZE_OUT_FEAT)\n pids_per_group = GROUP_SIZE_BATCH * n_out_feat_pids\n group_id = pid // pids_per_group\n first_batch_pid = group_id * GROUP_SIZE_BATCH\n GROUP_SIZE_BATCH = min(n_batch_pids - first_batch_pid, GROUP_SIZE_BATCH)\n batch_pid = first_batch_pid + (pid % GROUP_SIZE_BATCH)\n out_feat_pid = (pid % pids_per_group) // GROUP_SIZE_BATCH\n\n batch_offset = (batch_pid * BLOCK_SIZE_BATCH +\n tl.arange(0, BLOCK_SIZE_BATCH))\n out_feat_offset = (out_feat_pid * BLOCK_SIZE_OUT_FEAT +\n tl.arange(0, BLOCK_SIZE_OUT_FEAT))\n\n batch_mask = batch_offset < batch_dim\n out_feat_mask = out_feat_offset < out_feat_dim\n\n input_pointer += input_batch_stride * batch_offset[:, None]\n weight_pointer += weight_out_feat_stride * out_feat_offset[None, :]\n\n accum = tl.zeros((BLOCK_SIZE_BATCH, BLOCK_SIZE_OUT_FEAT),\n dtype=tl.float32)\n\n for block_ind in range(0, tl.cdiv(in_feat_dim, BLOCK_SIZE_IN_FEAT)):\n in_feat_offset = (block_ind * BLOCK_SIZE_IN_FEAT +\n tl.arange(0, BLOCK_SIZE_IN_FEAT))\n in_feat_mask = in_feat_offset < in_feat_dim\n\n curr_input_pointer = (input_pointer +\n input_in_feat_stride * in_feat_offset[None, :])\n curr_weight_pointer = (weight_pointer +\n weight_in_feat_stride * in_feat_offset[:, None])\n\n input_block = tl.load(curr_input_pointer,\n mask=batch_mask[:, None] & in_feat_mask[None, :])\n weight_block = tl.load(curr_weight_pointer,\n mask=out_feat_mask[None, :] & in_feat_mask[:, None])\n\n if fp16:\n input_block = input_block.to(tl.float16)\n weight_block = weight_block.to(tl.float16)\n\n accum += tl.dot(input_block, weight_block, allow_tf32=tf32)\n\n if add_bias:\n bias = tl.load(bias_pointer + out_feat_offset,\n mask=out_feat_mask)\n\n if fp16:\n bias = bias.to(tl.float16)\n\n accum += bias[None, :]\n\n if act_func is not None:\n if save_pre_act:\n pre_act_pointer += (pre_act_batch_stride * batch_offset[:, None] +\n pre_act_out_feat_stride * out_feat_offset[None, :])\n tl.store(pre_act_pointer, accum,\n mask=batch_mask[:, None] & out_feat_mask[None, :])\n\n accum = apply_act_func(accum, None, None, None, param, act_func, False)\n\n output_pointer += (output_batch_stride * batch_offset[:, None] +\n output_out_feat_stride * out_feat_offset[None, :])\n tl.store(output_pointer, accum,\n mask=batch_mask[:, None] & out_feat_mask[None, :])\n", - "description_1": "Use triton language to implement a kernel that performs linear transformation on input data with optional bias addition and activation function. The kernel takes several configuration parameters like BLOCK_SIZE_BATCH, BLOCK_SIZE_IN_FEAT, BLOCK_SIZE_OUT_FEAT, and GROUP_SIZE_BATCH. It supports both FP16 and TF32 data types and allows pre-activation output saving.", - "description_2": "Use triton language to perform batched matrix multiplication with optional bias and activation, configurable via BLOCK_SIZE and support for FP16/TF32.", - "difficulty": 4 - }, - { - "code": "import triton\nimport triton.language as tl\nfrom .act_kernels import apply_act_func\n\n@triton.jit\ndef accum_linear(accum, input1, input2, fp16: tl.constexpr, tf32: tl.constexpr):\n \"\"\"\n Accumulates matrix multiplications of input tensors for linear functions.\n\n Args:\n accum: Accumulator holding aggregation of matrix multiplications.\n The accumulator must be of shape [BLOCK_SIZE1, BLOCK_SIZE3].\n input1: First operand of matrix multiplication.\n The operand must be of shape [BLOCK_SIZE1, BLOCK_SIZE2].\n input2: Second operand of matrix multiplication.\n The operand must be of shape [BLOCK_SIZE2, BLOCK_SIZE3].\n fp16: Flag for converting operands to FP16.\n tf32: Flag for performing matrix multiplication in TF32.\n\n Returns:\n Accumulator with the result of the new matrix multiplication added to it.\n \"\"\"\n if fp16:\n input1 = input1.to(tl.float16)\n input2 = input2.to(tl.float16)\n\n return accum + tl.dot(input1, input2, allow_tf32=tf32)\n\n@triton.jit\ndef glu(input1, input2, param, act_func: tl.constexpr):\n \"\"\"\n Applies the gated linear unit with an arbitrary activation function\n to the input.\n\n Args:\n input1: First half of input to gate.\n The first half must be of the same shape as the second half.\n input2: Second half of input to gate.\n The second half must be of the same shape as the first half.\n param: Parameter in the case of parameterized activation functions.\n act_func: Name of activation function to apply.\n Options are 'sigmoid', 'tanh', 'relu', 'gelu', 'silu',\n 'relu6', 'hardsigmoid', 'hardswish', 'selu', 'mish', and 'leaky_relu'.\n\n Returns:\n Input transformed by the gated linear unit\n with an arbitrary activation function.\n \"\"\"\n return input1 * apply_act_func(input2, None, None, None, param, act_func, False)\n\n@triton.jit\ndef softmax(input, log: tl.constexpr):\n \"\"\"\n Normalizes the input using softmax along the last dimension.\n\n Args:\n input: Input to normalize.\n The input must be of shape [BLOCK_SIZE1, BLOCK_SIZE2].\n log: Flag for indicating if the log of softmax should be taken.\n\n Returns:\n Input normalized by softmax.\n \"\"\"\n input = input.to(tl.float32)\n\n input = input - tl.max(input, axis=1)[:, None]\n numerator = tl.exp(input)\n denominator = tl.sum(numerator, axis=1)[:, None]\n\n if log:\n output = input - tl.log(denominator)\n\n else:\n output = numerator / denominator\n\n return output\n\n@triton.jit\ndef calc_mean_and_inv_std(input, last_dim, eps, last_dim_mask: tl.constexpr):\n \"\"\"\n Calculates the mean and inverse standard deviation of the input\n along the last dimension.\n\n Args:\n input: Input whose mean and inverse standard deviation are calculated.\n The input must be of shape [BLOCK_SIZE1, BLOCK_SIZE2].\n last_dim: Size of the last dimension of input.\n eps: Epsilon added in the square root in the denominator\n to avoid division by zero.\n last_dim_mask: Mask for the last dimension indicating\n which elements should be included in the calculations.\n The mask must be of shape [BLOCK_SIZE2].\n\n Returns:\n Mean and inverse standard deviation of the input.\n \"\"\"\n input = input.to(tl.float32)\n\n mean = tl.sum(input, axis=1) / last_dim\n diff = tl.where(last_dim_mask[None, :], input - mean[:, None], 0)\n inv_std = tl.rsqrt(tl.sum(diff * diff, axis=1) / last_dim + eps)\n\n return mean, inv_std\n\n@triton.jit\ndef update_welford(input, prev_count, prev_mean, prev_var, curr_count, mask: tl.constexpr):\n \"\"\"\n Updates count, mean, and variance (M2) statistics for Welford's algorithm.\n\n Args:\n input: Input used to update statistics.\n The input must be of the same shape as the mask.\n prev_count: Previous count statistic to update.\n prev_mean: Previous mean statistic to update.\n prev_var: Previous variance (M2) statistic to update.\n curr_count: Count of elements in current input.\n mask: Mask indicating which elements should be included in the calculations.\n The mask must be of the same shape as the input.\n\n Returns:\n Updated count, mean, and variance (M2) statistics\n \"\"\"\n input = input.to(tl.float32)\n\n count = prev_count + curr_count\n mean = (tl.sum(input) - curr_count * prev_mean) / count\n deltas = tl.where(mask, (input - mean) * (input - prev_mean), 0.)\n var = prev_var + tl.sum(deltas)\n\n return count, mean, var\n\n@triton.jit\ndef update_ema(prev_ema, new_val, momentum):\n \"\"\"\n Updates exponential moving average.\n\n Args:\n prev_ema: Previous exponential moving average.\n new_val: Value used to update the exponential moving average.\n momentum: Momentum.\n\n Returns:\n Updated running statistic.\n \"\"\"\n return (1 - momentum) * prev_ema + momentum * new_val\n\n@triton.jit\ndef standardize(input, mean, inv_std, weight, bias):\n \"\"\"\n Standardizes the input given its mean and inverse standard deviation,\n multiplies the result by weights, and adds a bias vector.\n\n Args:\n input: Input to standardize.\n mean: Mean of input.\n inv_std: Inverse standard deviation of input.\n weight: Weight multiplied by the standardized input.\n bias: Bias added to the result of the weight multiplication.\n\n Returns:\n Standardized input.\n \"\"\"\n return weight * inv_std * (input - mean) + bias\n\n@triton.jit\ndef calc_p_loss(input, target, size, p_loss: tl.constexpr, reduction: tl.constexpr):\n \"\"\"\n Measures the L1 or squared L2 norm of the difference between the input\n and target (i.e., mean absolute error or mean squared error).\n\n Args:\n input: Input.\n The input must be of shape [BLOCK_SIZE].\n target: Target.\n The target must be of shape [BLOCK_SIZE].\n size: Number of elements in the input and target.\n This value is used only if reduction is 'mean'.\n p_loss: p-norm used to compute the error.\n Options are 1 for MAE and 2 for MSE.\n reduction: Reduction strategy for the output.\n Options are 'none' for no reduction, 'mean' for averaging the error\n across all entries, and 'sum' for summing the error across all entries.\n\n Returns:\n Error.\n \"\"\"\n input = input.to(tl.float32)\n target = target.to(tl.float32)\n\n diff = input - target\n\n if p_loss == 1:\n error = tl.abs(diff)\n\n elif p_loss == 2:\n error = diff * diff\n\n if reduction == 'none':\n output = error\n\n elif reduction == 'mean':\n output = tl.sum(error) / size\n\n elif reduction == 'sum':\n output = tl.sum(error)\n\n return output\n\n@triton.jit\ndef nll_loss(input, size, reduction: tl.constexpr):\n \"\"\"\n Measures the negative log likelihood loss given log-probabilities of target class.\n\n Args:\n input: Input containing predicted log-probabilities corresponding to target class.\n The input can have arbitrary shape.\n size: Number of elements in the input.\n This value is used only if reduction is 'mean'.\n reduction: Reduction strategy for the output.\n Options are 'none' for no reduction, 'mean' for averaging the loss\n across all entries, and 'sum' for summing the loss across all entries.\n\n Returns:\n Loss.\n \"\"\"\n input = input.to(tl.float32)\n\n if reduction == 'none':\n output = -input\n\n elif reduction == 'mean':\n output = -tl.sum(input) / size\n\n elif reduction == 'sum':\n output = -tl.sum(input)\n\n return output\n\n@triton.jit\ndef cross_entropy_loss(input, pred):\n \"\"\"\n Measures the per-row cross entropy loss given\n input and predicted logits corresponding to target class.\n\n Args:\n input: Input.\n The input must be of shape [BLOCK_SIZE1, BLOCK_SIZE2].\n pred: Predicted logits corresponding to target class.\n The predictions must be of shape [BLOCK_SIZE1].\n\n Returns:\n Loss.\n \"\"\"\n input = input.to(tl.float32)\n pred = pred.to(tl.float32)\n\n mx = tl.max(input, axis=1)\n input -= mx[:, None]\n loss = tl.log(tl.sum(tl.exp(input), axis=1)) - pred + mx\n\n return loss\n", - "description_1": "Use triton language to implement various mathematical operations on tensors, including matrix multiplication accumulation, gated linear unit application, softmax normalization, mean and inverse standard deviation calculation, Welford's algorithm for statistics update, exponential moving average update, input standardization, L1/L2 norm loss calculation, negative log likelihood loss, and cross entropy loss.", - "description_2": "Use triton language to perform tensor operations such as matrix multiplication, softmax, and loss calculations.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\nfrom torch.amp import custom_bwd, custom_fwd\n\n\ndef is_hip():\n return triton.runtime.driver.active.get_current_target().backend == \"hip\"\n\n\n@triton.jit\ndef _fwd_kernel(Q, K, V, sm_scale, #\n L, #\n Out, #\n stride_qz, stride_qh, stride_qm, stride_qk, #\n stride_kz, stride_kh, stride_kn, stride_kk, #\n stride_vz, stride_vh, stride_vn, stride_vk, #\n stride_oz, stride_oh, stride_om, stride_on, #\n Z, H, N_CTX, #\n Z_H_N_CTX, #\n BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, #\n BLOCK_N: tl.constexpr, #\n IS_CAUSAL: tl.constexpr #\n ):\n start_m = tl.program_id(0)\n off_hz = tl.program_id(1)\n qvk_offset = off_hz * stride_qh\n vk_offset = qvk_offset // stride_qm\n\n K_block_ptr = tl.make_block_ptr(\n base=K,\n shape=(BLOCK_DMODEL, Z_H_N_CTX),\n strides=(stride_kk, stride_kn),\n offsets=(0, vk_offset),\n block_shape=(BLOCK_DMODEL, BLOCK_N),\n order=(0, 1),\n )\n V_block_ptr = tl.make_block_ptr(\n base=V,\n shape=(Z_H_N_CTX, BLOCK_DMODEL),\n strides=(stride_vn, stride_vk),\n offsets=(vk_offset, 0),\n block_shape=(BLOCK_N, BLOCK_DMODEL),\n order=(1, 0),\n )\n # initialize offsets\n offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)\n offs_n = tl.arange(0, BLOCK_N)\n # initialize pointer to m and l\n m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float(\"inf\")\n l_i = tl.zeros([BLOCK_M], dtype=tl.float32)\n acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)\n # credits to: Adam P. Goucher (https://github.com/apgoucher):\n # scale sm_scale by 1/log_2(e) and use\n # 2^x instead of exp in the loop because CSE and LICM\n # don't work as expected with `exp` in the loop\n qk_scale = sm_scale * 1.44269504\n # load q: it will stay in SRAM throughout\n\n offs_k = tl.arange(0, BLOCK_DMODEL)\n Q_ptrs = Q + qvk_offset + offs_m[:, None] * stride_qm + offs_k[None, :] * stride_qk\n q = tl.load(Q_ptrs)\n\n q = (q * qk_scale).to(K.dtype.element_ty)\n lo = 0\n hi = (start_m + 1) * BLOCK_M if IS_CAUSAL else N_CTX\n for start_n in range(lo, hi, BLOCK_N):\n # -- load k, v --\n k = tl.load(K_block_ptr)\n v = tl.load(V_block_ptr)\n # -- compute qk ---\n qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)\n if IS_CAUSAL:\n qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float(\"-inf\"))\n qk += tl.dot(q, k)\n # -- compute scaling constant ---\n m_i_new = tl.maximum(m_i, tl.max(qk, 1))\n alpha = tl.math.exp2(m_i - m_i_new)\n p = tl.math.exp2(qk - m_i_new[:, None])\n # -- scale and update acc --\n acc *= alpha[:, None]\n acc += tl.dot(p.to(V.dtype.element_ty), v)\n # -- update m_i and l_i --\n l_i = l_i * alpha + tl.sum(p, 1)\n m_i = m_i_new\n # update pointers\n K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N))\n V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0))\n # write back l and m\n acc = acc / l_i[:, None]\n l_ptrs = L + off_hz * N_CTX + offs_m\n tl.store(l_ptrs, m_i + tl.math.log2(l_i))\n # write back O\n O_block_ptr = tl.make_block_ptr(\n base=Out,\n shape=(Z_H_N_CTX, BLOCK_DMODEL),\n strides=(stride_om, stride_on),\n offsets=(vk_offset + start_m * BLOCK_M, 0),\n block_shape=(BLOCK_M, BLOCK_DMODEL),\n order=(1, 0),\n )\n # O_ptrs = Out + qvk_offset + offs_m[:, None] * stride_qm + offs_k[None, :] * stride_qk\n tl.store(O_block_ptr, acc.to(K.dtype.element_ty))\n\n\n@triton.jit\ndef _bwd_preprocess(\n Out,\n DO,\n Delta,\n BLOCK_M: tl.constexpr,\n D_HEAD: tl.constexpr,\n):\n off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)\n off_n = tl.arange(0, D_HEAD)\n # load\n o = tl.load(Out + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32)\n do = tl.load(DO + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32)\n # compute\n delta = tl.sum(o * do, axis=1)\n # write-back\n tl.store(Delta + off_m, delta)\n\n\n@triton.jit\ndef _bwd_kernel_one_col_block(Q, K, V, sm_scale, qk_scale, #\n Out, DO, #\n DQ, DK, DV, #\n L, #\n D, #\n Q_block_ptr, K_block_ptr, V_block_ptr, #\n DO_block_ptr, DQ_block_ptr, DK_block_ptr, DV_block_ptr, #\n stride_dqa, stride_qz, stride_qh, stride_qm, stride_qk, #\n stride_kz, stride_kh, stride_kn, stride_kk, #\n stride_vz, stride_vh, stride_vn, stride_vk, #\n Z, H, N_CTX, #\n off_h, off_z, off_hz, start_n, num_block, #\n BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, #\n BLOCK_N: tl.constexpr, #\n SEQUENCE_PARALLEL: tl.constexpr, #\n CAUSAL: tl.constexpr, #\n MMA_V3: tl.constexpr #\n ):\n if CAUSAL:\n lo = start_n * BLOCK_M\n else:\n lo = 0\n\n Q_offset = (off_z * stride_qz + off_h * stride_qh) // stride_qm\n DQ_offset = off_z * stride_qz + off_h * stride_qh\n K_offset = (off_z * stride_kz + off_h * stride_kh) // stride_kn\n V_offset = (off_z * stride_vz + off_h * stride_vh) // stride_vn\n if SEQUENCE_PARALLEL:\n DQ_offset += stride_dqa * start_n\n DQ_offset = DQ_offset // stride_qm\n\n Q_block_ptr = tl.advance(Q_block_ptr, (lo + Q_offset, 0))\n K_block_ptr = tl.advance(K_block_ptr, (start_n * BLOCK_M + K_offset, 0))\n V_block_ptr = tl.advance(V_block_ptr, (start_n * BLOCK_M + V_offset, 0))\n DO_block_ptr = tl.advance(DO_block_ptr, (lo + Q_offset, 0))\n DQ_block_ptr = tl.advance(DQ_block_ptr, (lo + DQ_offset, 0))\n DK_block_ptr = tl.advance(DK_block_ptr, (start_n * BLOCK_M + K_offset, 0))\n DV_block_ptr = tl.advance(DV_block_ptr, (start_n * BLOCK_M + V_offset, 0))\n\n # initialize row/col offsets\n offs_n = start_n * BLOCK_M + tl.arange(0, BLOCK_M)\n offs_m = tl.arange(0, BLOCK_N)\n # pointer to row-wise quantities in value-like data\n D_ptrs = D + off_hz * N_CTX\n l_ptrs = L + off_hz * N_CTX\n # initialize dv amd dk\n dv = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)\n dk = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)\n # k and v stay in SRAM throughout\n k = tl.load(K_block_ptr)\n v = tl.load(V_block_ptr)\n # loop over rows\n for start_m in range(lo, num_block * BLOCK_M, BLOCK_M):\n offs_m_curr = start_m + offs_m\n # load q, k, v, do on-chip\n q = tl.load(Q_block_ptr)\n # recompute p = softmax(qk, dim=-1).T\n # NOTE: `do` is pre-divided by `l`; no normalization here\n if CAUSAL:\n qk = tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), float(0.0), float(\"-inf\"))\n else:\n qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)\n qk += tl.dot(q, tl.trans(k))\n qk *= qk_scale\n l_i = tl.load(l_ptrs + offs_m_curr)\n p = tl.math.exp2(qk - l_i[:, None])\n # compute dv\n do = tl.load(DO_block_ptr)\n dv += tl.dot(tl.trans(p.to(Q.dtype.element_ty)), do)\n # compute dp = dot(v, do)\n Di = tl.load(D_ptrs + offs_m_curr)\n # dp = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - Di[:, None]\n dp = tl.dot(do, tl.trans(v))\n # compute ds = p * (dp - delta[:, None])\n ds = (p * (dp - Di[:, None]) * sm_scale).to(Q.dtype.element_ty)\n # compute dk = dot(ds.T, q)\n dk += tl.dot(tl.trans(ds), q)\n # compute dq\n if not SEQUENCE_PARALLEL:\n dq = tl.load(DQ_block_ptr)\n dq += tl.dot(ds, k)\n tl.store(DQ_block_ptr, dq.to(Q.dtype.element_ty))\n elif SEQUENCE_PARALLEL:\n if MMA_V3:\n dq = tl.dot(ds, k)\n else:\n # not work with mma v3, because M % 64 != 0\n dq = tl.trans(tl.dot(tl.trans(k), tl.trans(ds)))\n tl.store(DQ_block_ptr, dq.to(Q.dtype.element_ty))\n\n # increment pointers\n DQ_block_ptr = tl.advance(DQ_block_ptr, (BLOCK_M, 0))\n Q_block_ptr = tl.advance(Q_block_ptr, (BLOCK_M, 0))\n DO_block_ptr = tl.advance(DO_block_ptr, (BLOCK_M, 0))\n # write-back\n tl.store(DV_block_ptr, dv.to(V.dtype.element_ty))\n tl.store(DK_block_ptr, dk.to(K.dtype.element_ty))\n\n\n@triton.jit\ndef _bwd_kernel(Q, K, V, sm_scale, #\n Out, DO, #\n DQ, DK, DV, #\n L, #\n D, #\n stride_dqa, stride_qz, stride_qh, stride_qm, stride_qk, #\n stride_kz, stride_kh, stride_kn, stride_kk, #\n stride_vz, stride_vh, stride_vn, stride_vk, #\n Z, H, N_CTX, #\n Z_H_N_CTX, #\n SQ_Z_H_N_CTX, #\n BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, #\n BLOCK_N: tl.constexpr, #\n SEQUENCE_PARALLEL: tl.constexpr, #\n CAUSAL: tl.constexpr, #\n MMA_V3: tl.constexpr #\n ):\n qk_scale = sm_scale * 1.44269504\n off_hz = tl.program_id(0)\n off_z = off_hz // H\n off_h = off_hz % H\n\n Q_block_ptr = tl.make_block_ptr(\n base=Q,\n shape=(Z_H_N_CTX, BLOCK_DMODEL),\n strides=(stride_qm, stride_qk),\n offsets=(0, 0),\n block_shape=(BLOCK_M, BLOCK_DMODEL),\n order=(1, 0),\n )\n K_block_ptr = tl.make_block_ptr(\n base=K,\n shape=(Z_H_N_CTX, BLOCK_DMODEL),\n strides=(stride_kn, stride_kk),\n offsets=(0, 0),\n block_shape=(BLOCK_M, BLOCK_DMODEL),\n order=(1, 0),\n )\n V_block_ptr = tl.make_block_ptr(\n base=V,\n shape=(Z_H_N_CTX, BLOCK_DMODEL),\n strides=(stride_vn, stride_vk),\n offsets=(0, 0),\n block_shape=(BLOCK_M, BLOCK_DMODEL),\n order=(1, 0),\n )\n DO_block_ptr = tl.make_block_ptr(\n base=DO,\n shape=(Z_H_N_CTX, BLOCK_DMODEL),\n strides=(stride_qm, stride_qk),\n offsets=(0, 0),\n block_shape=(BLOCK_M, BLOCK_DMODEL),\n order=(1, 0),\n )\n if SEQUENCE_PARALLEL:\n DQ_block_ptr = tl.make_block_ptr(\n base=DQ,\n shape=(SQ_Z_H_N_CTX, BLOCK_DMODEL),\n strides=(stride_qm, stride_qk),\n offsets=(0, 0),\n block_shape=(BLOCK_M, BLOCK_DMODEL),\n order=(1, 0),\n )\n else:\n DQ_block_ptr = tl.make_block_ptr(\n base=DQ,\n shape=(Z_H_N_CTX, BLOCK_DMODEL),\n strides=(stride_qm, stride_qk),\n offsets=(0, 0),\n block_shape=(BLOCK_M, BLOCK_DMODEL),\n order=(1, 0),\n )\n\n DK_block_ptr = tl.make_block_ptr(\n base=DK,\n shape=(Z_H_N_CTX, BLOCK_DMODEL),\n strides=(stride_kn, stride_kk),\n offsets=(0, 0),\n block_shape=(BLOCK_M, BLOCK_DMODEL),\n order=(1, 0),\n )\n DV_block_ptr = tl.make_block_ptr(\n base=DV,\n shape=(Z_H_N_CTX, BLOCK_DMODEL),\n strides=(stride_vn, stride_vk),\n offsets=(0, 0),\n block_shape=(BLOCK_M, BLOCK_DMODEL),\n order=(1, 0),\n )\n\n num_block_n = tl.cdiv(N_CTX, BLOCK_N)\n if not SEQUENCE_PARALLEL:\n for start_n in range(0, num_block_n):\n _bwd_kernel_one_col_block(Q, K, V, sm_scale, qk_scale, Out, DO, #\n DQ, DK, DV, #\n L, #\n D, #\n Q_block_ptr, K_block_ptr, V_block_ptr, #\n DO_block_ptr, DQ_block_ptr, DK_block_ptr, DV_block_ptr, #\n stride_dqa, stride_qz, stride_qh, stride_qm, stride_qk, #\n stride_kz, stride_kh, stride_kn, stride_kk, #\n stride_vz, stride_vh, stride_vn, stride_vk, #\n Z, H, N_CTX, #\n off_h, off_z, off_hz, start_n, num_block_n, #\n BLOCK_M=BLOCK_M, BLOCK_DMODEL=BLOCK_DMODEL, #\n BLOCK_N=BLOCK_N, #\n SEQUENCE_PARALLEL=SEQUENCE_PARALLEL, #\n CAUSAL=CAUSAL, #\n MMA_V3=MMA_V3 #\n )\n else:\n start_n = tl.program_id(1)\n _bwd_kernel_one_col_block(Q, K, V, sm_scale, qk_scale, Out, DO, #\n DQ, DK, DV, #\n L, #\n D, #\n Q_block_ptr, K_block_ptr, V_block_ptr, #\n DO_block_ptr, DQ_block_ptr, DK_block_ptr, DV_block_ptr, #\n stride_dqa, stride_qz, stride_qh, stride_qm, stride_qk, #\n stride_kz, stride_kh, stride_kn, stride_kk, #\n stride_vz, stride_vh, stride_vn, stride_vk, #\n Z, H, N_CTX, #\n off_h, off_z, off_hz, start_n, num_block_n, #\n BLOCK_M=BLOCK_M, BLOCK_DMODEL=BLOCK_DMODEL, #\n BLOCK_N=BLOCK_N, #\n SEQUENCE_PARALLEL=SEQUENCE_PARALLEL, #\n CAUSAL=CAUSAL, #\n MMA_V3=MMA_V3 #\n )\n\n\nclass _attention(torch.autograd.Function):\n @staticmethod\n @custom_fwd(device_type='cuda')\n def forward(ctx, q, k, v, causal, sm_scale, sequence_parallel=False):\n # only support for Ampere now\n capability = torch.cuda.get_device_capability()\n if capability[0] < 8:\n raise RuntimeError(\"Flash attention currently only supported for compute capability >= 80\")\n BLOCK_M = 128\n BLOCK_N = 64\n # shape constraints\n Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]\n assert Lq == Lk and Lk == Lv\n assert Lk in {16, 32, 64, 128}\n o = torch.empty_like(q)\n grid = (triton.cdiv(q.shape[2], BLOCK_M), q.shape[0] * q.shape[1], 1)\n L = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)\n num_warps = 4 if Lk <= 64 else 8\n _fwd_kernel[grid](\n q, k, v, sm_scale, #\n L, #\n o, #\n q.stride(0), q.stride(1), q.stride(2), q.stride(3), #\n k.stride(0), k.stride(1), k.stride(2), k.stride(3), #\n v.stride(0), v.stride(1), v.stride(2), v.stride(3), #\n o.stride(0), o.stride(1), o.stride(2), o.stride(3), #\n q.shape[0], q.shape[1], q.shape[2], #\n q.shape[0] * q.shape[1] * q.shape[2], #\n BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_DMODEL=Lk, #\n IS_CAUSAL=causal, #\n num_warps=num_warps, #\n num_stages=4 #\n )\n\n ctx.save_for_backward(q, k, v, o, L)\n ctx.grid = grid\n ctx.sm_scale = sm_scale\n ctx.BLOCK_DMODEL = Lk\n ctx.causal = causal\n ctx.sequence_parallel = sequence_parallel\n return o\n\n @staticmethod\n @custom_bwd(device_type='cuda')\n def backward(ctx, do):\n capability = torch.cuda.get_device_capability()\n MMA_V3 = capability[0] >= 9\n BLOCK = 128\n\n if is_hip():\n # Bwd pass runs out of shared memory on HIP with larger block size.\n BLOCK = 64\n\n q, k, v, o, L = ctx.saved_tensors\n sequence_parallel = ctx.sequence_parallel\n seq_len_kv = k.shape[2]\n do = do.contiguous()\n if sequence_parallel:\n replicas = triton.cdiv(seq_len_kv, BLOCK)\n new_dq_shape = (replicas, ) + q.shape\n dq = torch.zeros(new_dq_shape, device=q.device, dtype=q.dtype)\n else:\n dq = torch.zeros_like(q, dtype=q.dtype)\n dk = torch.empty_like(k)\n dv = torch.empty_like(v)\n delta = torch.empty_like(L)\n _bwd_preprocess[(triton.cdiv(q.shape[2], BLOCK) * ctx.grid[1], )](\n o,\n do,\n delta,\n BLOCK_M=BLOCK,\n D_HEAD=ctx.BLOCK_DMODEL,\n )\n _bwd_kernel[(ctx.grid[1], triton.cdiv(seq_len_kv, BLOCK) if sequence_parallel else 1)](\n q, k, v, ctx.sm_scale, #\n o, do, #\n dq, dk, dv, #\n L, #\n delta, #\n o.numel(), q.stride(0), q.stride(1), q.stride(2), q.stride(3), #\n k.stride(0), k.stride(1), k.stride(2), k.stride(3), #\n v.stride(0), v.stride(1), v.stride(2), v.stride(3), #\n q.shape[0], q.shape[1], q.shape[2], #\n q.shape[0] * q.shape[1] * q.shape[2], #\n triton.cdiv(seq_len_kv, BLOCK) * q.shape[0] * q.shape[1] * q.shape[2], #\n BLOCK_M=BLOCK, BLOCK_N=BLOCK, #\n BLOCK_DMODEL=ctx.BLOCK_DMODEL, #\n SEQUENCE_PARALLEL=sequence_parallel, #\n CAUSAL=ctx.causal, #\n MMA_V3=MMA_V3, #\n num_warps=8, #\n num_stages=1 #\n )\n\n if len(dq.shape) == 5:\n dq = dq.sum(dim=0)\n return dq, dk, dv, None, None, None\n", - "description_1": "Use triton language to implement multi-headed attention kernels, including forward and backward kernels. The forward kernel '_fwd_kernel' takes 30 parameters and performs block-wise operations on input tensors Q, K, V, and computes the attention output. The backward preprocessing kernel '_bwd_preprocess' takes 4 parameters and calculates delta values for gradient computations. The backward kernel '_bwd_kernel' and its auxiliary function '_bwd_kernel_one_col_block' perform backpropagation to compute gradients for input tensors, taking 36 and 38 parameters respectively. The function '_attention', a subclass of torch.autograd.Function, manages the forward and backward passes by invoking these Triton kernels.", - "description_2": "Use triton language to perform flash attention, with kernels for forward and backward passes, optimized for GPUs with compute capability >= 80.", - "difficulty": 4 - }, - { - "code": "import triton\nimport triton.language as tl\nfrom triton import next_power_of_2\n\n@triton.autotune(\n configs=warps_kernel_configs(),\n key=['batch_dim', 'spatial_dim'],\n)\n@triton.heuristics({'BLOCK_SIZE_BATCH': BLOCK_SIZE_BATCH_heuristic,\n 'BLOCK_SIZE_SPATIAL': lambda args: next_power_of_2(args['spatial_dim'])})\n@triton.jit\ndef nll_loss_forward_kernel(\n input_pointer, target_pointer, weight_pointer,\n sum_weights_pointer, output_pointer,\n batch_dim, spatial_dim,\n input_batch_stride, input_feat_stride, input_spatial_stride,\n target_batch_stride, target_spatial_stride,\n output_batch_stride, output_spatial_stride,\n reduction: tl.constexpr, weighted: tl.constexpr,\n BLOCK_SIZE_BATCH: tl.constexpr, BLOCK_SIZE_SPATIAL: tl.constexpr,\n ):\n \"\"\"\n Measures the negative log likelihood loss between the input and target,\n with optional reweighing of each class.\n\n Args:\n input_pointer: Pointer to the input.\n The input must be of shape [batch_dim, feat_dim, spatial_dim].\n target_pointer: Pointer to the target.\n The target must be of shape [batch_dim, spatial_dim].\n weight_pointer: Pointer to an optional class weight vector.\n The class weight vector, if provided, must be of shape [feat_dim].\n sum_weights_pointer: Pointer to a container the sum of the class weights is written to.\n The container must be of shape [batch_dim/BLOCK_SIZE_BATCH].\n output_pointer: Pointer to a container the loss is written to.\n The container must be of shape [batch_dim, spatial_dim] if reduction is 'none',\n and otherwise of shape [batch_dim/BLOCK_SIZE].\n batch_dim: Batch dimension.\n spatial_dim: Spatial dimension.\n input_batch_stride: Stride necessary to jump one element along the\n input's batch dimension.\n input_feat_stride: Stride necessary to jump one element along the\n input's feature dimension.\n input_spatial_stride: Stride necessary to jump one element along the\n input's spatial dimension.\n target_batch_stride: Stride necessary to jump one element along the\n target's batch dimension.\n target_spatial_stride: Stride necessary to jump one element along the\n target's spatial dimension.\n output_batch_stride: Stride necessary to jump one element along the\n output container's batch dimension.\n output_spatial_stride: Stride necessary to jump one element along the\n output container's spatial dimension.\n reduction: Reduction strategy for the output.\n Options are 'none' for no reduction, 'mean' for averaging the loss\n across all entries, and 'sum' for summing the loss across all entries.\n If a reduction method is specified, the reduced result of each\n program is written to a separate index in the summed weights and\n output container, which should later be summed.\n weighted: Flag for weighing each class.\n BLOCK_SIZE_BATCH: Block size across the batch dimension.\n BLOCK_SIZE_SPATIAL: Block size across the spatial dimension.\n \"\"\"\n # This program processes BLOCK_SIZE_BATCH rows and\n # BLOCK_SIZE_SPATIAL spatial elements.\n batch_pid = tl.program_id(axis=0)\n\n batch_offset = batch_pid * BLOCK_SIZE_BATCH + tl.arange(0, BLOCK_SIZE_BATCH)\n spatial_offset = tl.arange(0, BLOCK_SIZE_SPATIAL)\n\n batch_mask = batch_offset < batch_dim\n spatial_mask = spatial_offset < spatial_dim\n\n target_pointer += (target_batch_stride * batch_offset[:, None] +\n target_spatial_stride * spatial_offset[None, :])\n target = tl.load(target_pointer,\n mask=batch_mask[:, None] & spatial_mask[None, :])\n\n input_pointer += (input_feat_stride * target +\n input_batch_stride * batch_offset[:, None] +\n input_spatial_stride * spatial_offset[None, :])\n input = tl.load(input_pointer,\n mask=batch_mask[:, None] & spatial_mask[None, :]).to(tl.float32)\n\n output = -input\n if weighted:\n weight = tl.load(weight_pointer + target,\n mask=batch_mask[:, None] & spatial_mask[None, :]).to(tl.float32)\n output *= weight\n\n if reduction == 'none':\n output_pointer += (output_batch_stride * batch_offset[:, None] +\n output_spatial_stride * spatial_offset[None, :])\n tl.store(output_pointer, output,\n mask=batch_mask[:, None] & spatial_mask[None, :])\n\n elif reduction == 'mean':\n if weighted:\n tl.store(sum_weights_pointer + batch_pid, tl.sum(weight))\n tl.store(output_pointer + batch_pid, tl.sum(output))\n\n else:\n tl.store(output_pointer + batch_pid,\n tl.sum(output) / (batch_dim * spatial_dim))\n\n elif reduction == 'sum':\n tl.store(output_pointer + batch_pid, tl.sum(output))\n\n\n@triton.autotune(\n configs=warps_kernel_configs(),\n key=['batch_dim', 'spatial_dim'],\n)\n@triton.heuristics({'BLOCK_SIZE_BATCH': BLOCK_SIZE_BATCH_heuristic,\n 'BLOCK_SIZE_SPATIAL': lambda args: next_power_of_2(args['spatial_dim'])})\n@triton.jit\ndef nll_loss_backward_kernel(\n output_grad_pointer, target_pointer, weight_pointer,\n sum_weights_pointer, input_grad_pointer,\n batch_dim, spatial_dim,\n output_grad_batch_stride, output_grad_feat_stride,\n target_batch_stride, target_spatial_stride,\n input_grad_batch_stride, input_grad_feat_stride, input_grad_spatial_stride,\n reduction: tl.constexpr, weighted: tl.constexpr,\n BLOCK_SIZE_BATCH: tl.constexpr, BLOCK_SIZE_SPATIAL: tl.constexpr,\n ):\n \"\"\"\n Calculates the input gradient of negative log likelihood loss.\n\n Args:\n output_grad_pointer: Pointer to the loss's output gradients.\n The output gradients must be of shape [batch_dim, spatial_dim]\n if reduction is 'none', and otherwise [batch_dim/BLOCK_SIZE_BATCH].\n target_pointer: Pointer to the target.\n The target must be of shape [batch_dim, spatial_dim].\n weight_pointer: Pointer to an optional class weight vector.\n The class weight vector, if provided, must be of shape [feat_dim].\n sum_weights_pointer: Pointer to the sum of the class weights if the classes were weighed.\n The sum of weights must be a scalar.\n input_grad_pointer: Pointer to a container the input's gradients are written to.\n The container must be of shape [batch_dim, feat_dim, spatial_dim] and zeroed.\n batch_dim: Batch dimension.\n spatial_dim: Spatial dimension.\n output_grad_batch_stride: Stride necessary to jump one element along the\n output gradients' batch dimension.\n output_grad_feat_stride: Stride necessary to jump one element along the\n output gradients' feature dimension.\n input_spatial_stride: Stride necessary to jump one element along the\n input's spatial dimension.\n target_batch_stride: Stride necessary to jump one element along the\n target's batch dimension.\n target_spatial_stride: Stride necessary to jump one element along the\n target's spatial dimension.\n input_grad_batch_stride: Stride necessary to jump one element along the\n input gradient container's batch dimension.\n input_grad_feat_stride: Stride necessary to jump one element along the\n input gradient container's feature dimension.\n input_grad_spatial_stride: Stride necessary to jump one element along the\n input gradient container's spatial dimension.\n reduction: Reduction strategy for the output whose gradient is calculated.\n Options are 'none' for no reduction, 'mean' for averaging the loss\n across all entries, and 'sum' for summing the loss across all entries.\n weighted: Flag for weighing each class.\n BLOCK_SIZE_BATCH: Block size across the batch dimension.\n BLOCK_SIZE_SPATIAL: Block size across the spatial dimension.\n \"\"\"\n # This program processes BLOCK_SIZE_BATCH rows and\n # BLOCK_SIZE_SPATIAL spatial elements.\n batch_pid = tl.program_id(axis=0)\n\n batch_offset = batch_pid * BLOCK_SIZE_BATCH + tl.arange(0, BLOCK_SIZE_BATCH)\n spatial_offset = tl.arange(0, BLOCK_SIZE_SPATIAL)\n\n batch_mask = batch_offset < batch_dim\n spatial_mask = spatial_offset < spatial_dim\n\n output_grad_mask = None\n if reduction == 'none':\n output_grad_pointer += (output_grad_batch_stride * batch_offset[:, None] +\n output_grad_feat_stride * spatial_offset[None, :])\n output_grad_mask = batch_mask[:, None] & spatial_mask[None, :]\n\n output_grad = tl.load(output_grad_pointer, mask=output_grad_mask).to(tl.float32)\n input_grad = -output_grad\n\n target_pointer += (target_batch_stride * batch_offset[:, None] +\n target_spatial_stride * spatial_offset[None, :])\n target = tl.load(target_pointer,\n mask=batch_mask[:, None] & spatial_mask[None, :])\n\n if weighted:\n weight = tl.load(weight_pointer + target,\n mask=batch_mask[:, None] & spatial_mask[None, :]).to(tl.float32)\n input_grad *= weight\n\n if reduction == 'mean':\n input_grad /= tl.load(sum_weights_pointer)\n\n elif reduction == 'mean':\n input_grad /= batch_dim * spatial_dim\n\n input_grad_pointer += (input_grad_feat_stride * target +\n input_grad_batch_stride * batch_offset[:, None] +\n input_grad_spatial_stride * spatial_offset[None, :])\n tl.store(input_grad_pointer, input_grad,\n mask=batch_mask[:, None] & spatial_mask[None, :])\n", - "description_1": "Use triton language to implement two kernels: nll_loss_forward_kernel and nll_loss_backward_kernel. The forward kernel computes the negative log likelihood loss between input and target with optional class weighting and reduction strategies. It takes 18 parameters including pointers to input, target, weight, sum_weights, and output, dimensions, strides, reduction strategy, weighting flag, and block sizes. The backward kernel calculates the gradient of the input for the negative log likelihood loss, taking 19 parameters including pointers to output gradients, target, weight, sum_weights, input gradients, dimensions, strides, reduction strategy, weighting flag, and block sizes.", - "description_2": "Use triton language to create kernels for computing negative log likelihood loss and its gradient, supporting class weighting and reduction strategies.", - "difficulty": 4 - }, - { - "code": "import triton\nimport triton.language as tl\nfrom .utils import element_wise_kernel_configs\n\n@triton.autotune(\n configs=element_wise_kernel_configs(),\n key=['size'],\n)\n@triton.jit\ndef p_loss_forward_kernel(\n input_pointer, target_pointer, output_pointer,\n size, p_loss: tl.constexpr, reduction: tl.constexpr,\n BLOCK_SIZE: tl.constexpr,\n ):\n \"\"\"\n Measures the L1 or squared L2 norm of the difference between the input\n and target (i.e., mean absolute error or mean squared error).\n\n Args:\n input_pointer: Pointer to the input.\n The input must be of shape [size].\n target_pointer: Pointer to the target.\n The target must be of shape [size].\n output_pointer: Pointer to a container the error is written to.\n The container must be of shape [size] if reduction is 'none',\n and otherwise of shape [size/BLOCK_SIZE].\n size: Number of elements in the input and target.\n p_loss: p-norm used to compute the error.\n Options are 1 for MAE and 2 for MSE.\n reduction: Reduction strategy for the output.\n Options are 'none' for no reduction, 'mean' for averaging the error\n across all entries, and 'sum' for summing the error across all entries.\n If a reduction method is specified, the reduced result of each\n program is written to a separate index in the output container,\n which should later be summed.\n BLOCK_SIZE: Block size.\n \"\"\"\n # This program processes BLOCK_SIZE rows.\n pid = tl.program_id(axis=0)\n offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)\n mask = offset < size\n\n input = tl.load(input_pointer + offset, mask=mask).to(tl.float32)\n target = tl.load(target_pointer + offset, mask=mask).to(tl.float32)\n diff = input - target\n\n if p_loss == 1:\n error = tl.abs(diff)\n\n elif p_loss == 2:\n error = diff * diff\n\n if reduction == 'none':\n tl.store(output_pointer + offset, error, mask=mask)\n\n elif reduction == 'mean':\n tl.store(output_pointer + pid, tl.sum(error) / size)\n\n elif reduction == 'sum':\n tl.store(output_pointer + pid, tl.sum(error))\n\n\n@triton.autotune(\n configs=element_wise_kernel_configs(),\n key=['size'],\n)\n@triton.jit\ndef p_loss_backward_kernel(\n output_grad_pointer, input_pointer, target_pointer,\n input_grad_pointer, target_grad_pointer, size,\n p_loss: tl.constexpr, reduction: tl.constexpr,\n BLOCK_SIZE: tl.constexpr,\n ):\n \"\"\"\n Calculates the input gradient of the mean absolute error or\n mean squared error.\n\n Args:\n output_grad_pointer: Pointer to the error's output gradients.\n The output gradients must be a scalar or of shape [size].\n input_pointer: Pointer to the input.\n The input must be of shape [size].\n target_pointer: Pointer to the target.\n The target must be of shape [size].\n input_grad_pointer: Pointer to a container the input's gradients are written to.\n The container must be of shape [size].\n target_grad_pointer: Pointer to a container the target's gradients are written to.\n The container must be of shape [size].\n size: Number of elements in the input and target.\n p_loss: p-norm used to compute the error whose gradient is calculated.\n Options are 1 for MAE and 2 for MSE.\n reduction: Reduction strategy for the output whose gradient is calculated.\n Options are 'none' for no reduction, 'mean' for averaging the error\n across all entries, and 'sum' for summing the error across all entries.\n BLOCK_SIZE: Block size.\n \"\"\"\n # This program processes BLOCK_SIZE rows.\n pid = tl.program_id(axis=0)\n offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)\n mask = offset < size\n\n output_grad_mask = None\n if reduction == 'none':\n output_grad_pointer += offset\n output_grad_mask = mask\n\n input = tl.load(input_pointer + offset, mask=mask).to(tl.float32)\n target = tl.load(target_pointer + offset, mask=mask).to(tl.float32)\n output_grad = tl.load(output_grad_pointer, mask=output_grad_mask).to(tl.float32)\n\n if p_loss == 1:\n input_grad = tl.where(target <= input, 1, -1)\n\n elif p_loss == 2:\n input_grad = 2 * (input - target)\n\n if reduction == 'mean':\n input_grad /= size\n\n input_grad *= output_grad\n tl.store(input_grad_pointer + offset, input_grad, mask=mask)\n tl.store(target_grad_pointer + offset, -input_grad, mask=mask)\n", - "description_1": "Use triton language to implement two kernels: p_loss_forward_kernel and p_loss_backward_kernel. The p_loss_forward_kernel computes the L1 or squared L2 norm of the difference between input and target, with options for reduction ('none', 'mean', 'sum'). It takes 7 parameters: input_pointer, target_pointer, output_pointer, size, p_loss, reduction, and BLOCK_SIZE. The p_loss_backward_kernel calculates the gradient of the input for the mean absolute error or mean squared error, with similar reduction options. It takes 9 parameters: output_grad_pointer, input_pointer, target_pointer, input_grad_pointer, target_grad_pointer, size, p_loss, reduction, and BLOCK_SIZE.", - "description_2": "Use triton language to create kernels for computing p-norm-induced losses and their gradients, supporting L1 and L2 norms with various reduction strategies.", - "difficulty": 3 - }, - { - "code": "import triton\nimport triton.language as tl\nfrom triton import next_power_of_2\n\n@triton.autotune(\n configs=warps_kernel_configs(),\n key=['batch_dim', 'feat_dim'],\n)\n@triton.heuristics({'BLOCK_SIZE_BATCH': BLOCK_SIZE_BATCH_heuristic,\n 'BLOCK_SIZE_FEAT': lambda args: next_power_of_2(args['feat_dim'])})\n@triton.jit\ndef rms_norm_forward_kernel(\n input_pointer, weight_pointer,\n inv_rms_pointer, output_pointer,\n batch_dim, feat_dim,\n input_batch_stride, input_feat_stride,\n output_batch_stride, output_feat_stride,\n eps,\n scale_by_weight: tl.constexpr, save_stats: tl.constexpr,\n BLOCK_SIZE_BATCH: tl.constexpr, BLOCK_SIZE_FEAT: tl.constexpr,\n ):\n \"\"\"\n Root-mean-square-normalizes the input.\n\n Args:\n input_pointer: Pointer to the input to root-mean-square-normalize.\n The input must be of shape [batch_dim, feat_dim].\n weight_pointer: Pointer to optional weights for linear transform.\n The weights, if provided, must be of shape [feat_dim].\n inv_rms_pointer: Pointer to an optional container the input's inverse\n root mean square is written to if save_stats is True.\n The container, if provided, must be of shape [batch_dim].\n output_pointer: Pointer to a container the result is written to.\n The container must be of shape [batch_dim, feat_dim].\n batch_dim: Batch dimension.\n feat_dim: Dimensionality of the features.\n input_batch_stride: Stride necessary to jump one element along the\n input's batch dimension.\n input_feat_stride: Stride necessary to jump one element along the\n input's feature dimension.\n output_batch_stride: Stride necessary to jump one element along the\n output container's batch dimension.\n output_feat_stride: Stride necessary to jump one element along the\n output container's feature dimension.\n eps: Epsilon added in the square root in the denominator\n to avoid division by zero.\n scale_by_weight: Flag for scaling the normalized output by weights.\n save_stats: Flag for saving the root mean square.\n BLOCK_SIZE_BATCH: Block size across the batch dimension.\n BLOCK_SIZE_FEAT: Block size across the feature dimension.\n \"\"\"\n # This program processes BLOCK_SIZE_BATCH rows and BLOCK_SIZE_FEAT columns.\n batch_pid = tl.program_id(axis=0)\n\n batch_offset = batch_pid * BLOCK_SIZE_BATCH + tl.arange(0, BLOCK_SIZE_BATCH)\n feat_offset = tl.arange(0, BLOCK_SIZE_FEAT)\n\n batch_mask = batch_offset < batch_dim\n feat_mask = feat_offset < feat_dim\n\n input_pointer += (input_batch_stride * batch_offset[:, None] +\n input_feat_stride * feat_offset[None, :])\n output_pointer += (output_batch_stride * batch_offset[:, None] +\n output_feat_stride * feat_offset[None, :])\n\n input = tl.load(input_pointer,\n mask=batch_mask[:, None] & feat_mask[None, :]).to(tl.float32)\n inv_rms = tl.rsqrt(tl.sum(input * input, axis=1) / feat_dim + eps)\n output = input * inv_rms[:, None]\n\n if save_stats:\n tl.store(inv_rms_pointer + batch_offset, inv_rms, mask=batch_mask)\n\n if scale_by_weight:\n weight = tl.load(weight_pointer + feat_offset, mask=feat_mask)\n output *= weight\n\n tl.store(output_pointer, output,\n mask=batch_mask[:, None] & feat_mask[None, :])\n\n\n@triton.autotune(\n configs=warps_kernel_configs(),\n key=['batch_dim', 'feat_dim'],\n)\n@triton.heuristics({'BLOCK_SIZE_BATCH': BLOCK_SIZE_BATCH_heuristic,\n 'BLOCK_SIZE_FEAT': lambda args: next_power_of_2(args['feat_dim'])})\n@triton.jit\ndef rms_norm_backward_kernel(\n output_grad_pointer, input_pointer, inv_rms_pointer, weight_pointer,\n input_grad_pointer, weight_grad_pointer,\n batch_dim, feat_dim,\n output_grad_batch_stride, output_grad_feat_stride,\n input_batch_stride, input_feat_stride,\n input_grad_batch_stride, input_grad_feat_stride,\n weight_grad_batch_stride, weight_grad_feat_stride,\n scale_by_weight: tl.constexpr,\n BLOCK_SIZE_BATCH: tl.constexpr, BLOCK_SIZE_FEAT: tl.constexpr,\n ):\n \"\"\"\n Calculates the input gradient of root mean square normalization.\n\n Args:\n output_grad_pointer: Pointer to root mean square normalization's output gradients.\n The output gradients must be of shape [batch_dim, feat_dim].\n input_pointer: Pointer to the input.\n The input must be of shape [batch_dim, feat_dim].\n inv_rms_pointer: Pointer to the input's inverse root mean square.\n The inverse root mean square should be of shape [batch_dim].\n weight_pointer: Pointer to optional weights if affine transform occurred.\n The weights, if provided, must be of shape [feat_dim].\n input_grad_pointer: Pointer to a container the input's gradients are written to.\n The container must be of shape [batch_dim, feat_dim].\n weight_grad_pointer: Pointer to an optional container the weights' row-wise gradients\n are written to if scale_by_weight is True, which should later be summed.\n The container, if provided, must be of shape [batch_dim/BLOCK_SIZE_BATCH, feat_dim].\n bias_grad_pointer: Pointer to an optional container the bias vector's row-wise gradients\n are written to if scale_by_weight and add_bias are True, which should later be summed.\n The container, if provided, must be of shape [batch_dim/BLOCK_SIZE_BATCH, feat_dim].\n batch_dim: Batch dimension.\n feat_dim: Dimensionality of the features.\n output_grad_batch_stride: Stride necessary to jump one element along the\n output gradients' batch dimension.\n output_grad_feat_stride: Stride necessary to jump one element along the\n output gradients' feature dimension.\n input_batch_stride: Stride necessary to jump one element along the\n input's batch dimension.\n input_feat_stride: Stride necessary to jump one element along the\n input's feature dimension.\n input_grad_batch_stride: Stride necessary to jump one element along the\n input gradient container's batch dimension.\n input_grad_feat_stride: Stride necessary to jump one element along the\n input gradient container's feature dimension.\n weight_grad_batch_stride: Stride necessary to jump one element along the\n weight gradient container's batch dimension.\n weight_grad_feat_stride: Stride necessary to jump one element along the\n weight gradient container's feature dimension.\n scale_by_weight: Flag for scaling the normalized output by weights.\n BLOCK_SIZE_BATCH: Block size across the batch dimension.\n BLOCK_SIZE_FEAT: Block size across the feature dimension.\n \"\"\"\n # This program processes a single row and BLOCK_SIZE_FEAT columns.\n batch_pid = tl.program_id(axis=0)\n\n batch_offset = batch_pid * BLOCK_SIZE_BATCH + tl.arange(0, BLOCK_SIZE_BATCH)\n feat_offset = tl.arange(0, BLOCK_SIZE_FEAT)\n\n batch_mask = batch_offset < batch_dim\n feat_mask = feat_offset < feat_dim\n\n output_grad_pointer += (output_grad_batch_stride * batch_offset[:, None] +\n output_grad_feat_stride * feat_offset[None, :])\n input_pointer += (input_batch_stride * batch_offset[:, None] +\n input_feat_stride * feat_offset[None, :])\n input_grad_pointer += (input_grad_batch_stride * batch_offset[:, None] +\n input_grad_feat_stride * feat_offset[None, :])\n\n output_grad = tl.load(output_grad_pointer,\n mask=batch_mask[:, None] & feat_mask[None, :]).to(tl.float32)\n input = tl.load(input_pointer, mask=batch_mask[:, None] & feat_mask[None, :]).to(tl.float32)\n inv_rms = tl.load(inv_rms_pointer + batch_offset, mask=batch_mask)\n pre_lin = input * inv_rms[:, None]\n\n if scale_by_weight:\n weight = tl.load(weight_pointer + feat_offset, mask=feat_mask)\n weight_output_grad_prod = weight * output_grad\n\n else:\n weight_output_grad_prod = output_grad\n\n term1 = input * tl.sum(input * weight_output_grad_prod, axis=1)\n term2 = inv_rms[:, None] * inv_rms[:, None]\n input_grad = (inv_rms[:, None] *\n (weight_output_grad_prod - term1 * term2 / feat_dim))\n\n tl.store(input_grad_pointer, input_grad,\n mask=batch_mask[:, None] & feat_mask[None, :])\n\n if scale_by_weight:\n weight_grad_pointer += (weight_grad_batch_stride * batch_pid +\n weight_grad_feat_stride * feat_offset)\n tl.store(weight_grad_pointer,\n tl.sum(output_grad * pre_lin, axis=0),\n mask=feat_mask)\n", - "description_1": "Use triton language to define and execute forward and backward root mean square normalization kernels. The forward kernel takes pointers to input data, weights, inverse RMS, and output containers, along with dimensions, strides, epsilon for numerical stability, and flags for scaling and saving stats. It normalizes the input data across specified dimensions. The backward kernel similarly takes pointers, dimensions, strides, and a scaling flag, and computes gradients for input data and optionally for weights. Both kernels optimize execution with block size configurations and handle masks for out-of-bound elements.", - "description_2": "Use triton language to create RMS normalization kernels for both forward and backward passes, with support for optional weight scaling and statistical saving, using block size optimizations.", - "difficulty": 4 - }, - { - "code": "import triton\nimport triton.language as tl\nfrom triton import next_power_of_2\nfrom .utils import warps_kernel_configs\n\n@triton.autotune(\n configs=warps_kernel_configs(),\n key=['batch_dim', 'feat_dim'],\n)\n@triton.heuristics({'BLOCK_SIZE_BATCH': lambda args: (min(max(1, next_power_of_2(args['batch_dim'] // 2 ** 10)), 128) if args['feat_dim'] < 64 else 1),\n 'BLOCK_SIZE_FEAT': lambda args: next_power_of_2(args['feat_dim'])})\n@triton.jit\ndef softmax_forward_kernel(\n input_pointer, output_pointer,\n batch_dim, feat_dim,\n input_batch_stride, input_feat_stride,\n output_batch_stride, output_feat_stride,\n log: tl.constexpr,\n BLOCK_SIZE_BATCH: tl.constexpr, BLOCK_SIZE_FEAT: tl.constexpr,\n ):\n \"\"\"\n Normalizes the input using softmax.\n\n Args:\n input_pointer: Pointer to the input to normalize.\n The input must be of shape [batch_dim, feat_dim].\n output_pointer: Pointer to a container the result is written to.\n The container must be of shape [batch_dim, feat_dim].\n batch_dim: Batch dimension.\n feat_dim: Dimensionality of the features.\n input_batch_stride: Stride necessary to jump one element along the\n input's batch dimension.\n input_feat_stride: Stride necessary to jump one element along the\n input's feature dimension.\n output_batch_stride: Stride necessary to jump one element along the\n output container's batch dimension.\n output_feat_stride: Stride necessary to jump one element along the\n output container's feature dimension.\n log: Flag for indicating if the log of softmax should be taken.\n BLOCK_SIZE_BATCH: Block size across the batch dimension.\n BLOCK_SIZE_FEAT: Block size across the feature dimension.\n \"\"\"\n # This program processes BLOCK_SIZE_BATCH rows and BLOCK_SIZE_FEAT columns.\n batch_pid = tl.program_id(axis=0)\n\n batch_offset = batch_pid * BLOCK_SIZE_BATCH + tl.arange(0, BLOCK_SIZE_BATCH)\n feat_offset = tl.arange(0, BLOCK_SIZE_FEAT)\n\n batch_mask = batch_offset < batch_dim\n feat_mask = feat_offset < feat_dim\n\n input_pointer += (input_batch_stride * batch_offset[:, None] +\n input_feat_stride * feat_offset[None, :])\n output_pointer += (output_batch_stride * batch_offset[:, None] +\n output_feat_stride * feat_offset[None, :])\n\n input = tl.load(input_pointer, mask=batch_mask[:, None] & feat_mask[None, :],\n other=-float('inf')).to(tl.float32)\n input -= tl.max(input, axis=1)[:, None]\n numerator = tl.exp(input)\n denominator = tl.sum(numerator, axis=1)[:, None]\n\n if log:\n output = input - tl.log(denominator)\n\n else:\n output = numerator / denominator\n\n tl.store(output_pointer, output, mask=batch_mask[:, None] & feat_mask[None, :])\n\n\n@triton.autotune(\n configs=warps_kernel_configs(),\n key=['batch_dim', 'feat_dim'],\n)\n@triton.heuristics({'BLOCK_SIZE_BATCH': lambda args: (min(max(1, next_power_of_2(args['batch_dim'] // 2 ** 10)), 128) if args['feat_dim'] < 64 else 1),\n 'BLOCK_SIZE_FEAT': lambda args: next_power_of_2(args['feat_dim'])})\n@triton.jit\ndef softmax_backward_kernel(\n output_grad_pointer, output_pointer, input_grad_pointer,\n batch_dim, feat_dim,\n output_grad_batch_stride, output_grad_feat_stride,\n output_batch_stride, output_feat_stride,\n input_grad_batch_stride, input_grad_feat_stride,\n log: tl.constexpr,\n BLOCK_SIZE_BATCH: tl.constexpr, BLOCK_SIZE_FEAT: tl.constexpr,\n ):\n \"\"\"\n Calculates the input gradient of softmax.\n\n Args:\n output_grad_pointer: Pointer to softmax's output gradients.\n The output gradients must be of shape [batch_dim, feat_dim].\n output_pointer: Pointer to softmax's output.\n The output must be of shape [batch_dim, feat_dim].\n input_grad_pointer: Pointer to a container the input's gradients are written to.\n The container must be of shape [batch_dim, feat_dim].\n batch_dim: Batch dimension.\n feat_dim: Dimensionality of the features.\n output_grad_batch_stride: Stride necessary to jump one element along the\n output gradients' batch dimension.\n output_grad_feat_stride: Stride necessary to jump one element along the\n output gradients' feature dimension.\n output_batch_stride: Stride necessary to jump one element along the\n output's batch dimension.\n output_feat_stride: Stride necessary to jump one element along the\n output's feature dimension.\n input_grad_batch_stride: Stride necessary to jump one element along the\n input gradient container's batch dimension.\n input_grad_feat_stride: Stride necessary to jump one element along the\n input gradient container's feature dimension.\n log: Flag indicating if log of softmax was taken.\n BLOCK_SIZE_BATCH: Block size across the batch dimension.\n BLOCK_SIZE_FEAT: Block size across the feature dimension.\n \"\"\"\n # This program processes a single row and BLOCK_SIZE_FEAT columns.\n batch_pid = tl.program_id(axis=0)\n\n batch_offset = batch_pid * BLOCK_SIZE_BATCH + tl.arange(0, BLOCK_SIZE_BATCH)\n feat_offset = tl.arange(0, BLOCK_SIZE_FEAT)\n\n batch_mask = batch_offset < batch_dim\n feat_mask = feat_offset < feat_dim\n\n output_grad_pointer += (output_grad_batch_stride * batch_offset[:, None] +\n output_grad_feat_stride * feat_offset[None, :])\n output_pointer += (output_batch_stride * batch_offset[:, None] +\n output_feat_stride * feat_offset[None, :])\n input_grad_pointer += (input_grad_batch_stride * batch_offset[:, None] +\n input_grad_feat_stride * feat_offset[None, :])\n\n output_grad = tl.load(output_grad_pointer,\n mask=batch_mask[:, None] & feat_mask[None, :]).to(tl.float32)\n output = tl.load(output_pointer,\n mask=batch_mask[:, None] & feat_mask[None, :]).to(tl.float32)\n\n if log:\n input_grad = (output_grad -\n tl.exp(output) * tl.sum(output_grad, axis=1)[:, None])\n\n else:\n input_grad = output * (output_grad -\n tl.sum(output_grad * output, axis=1)[:, None])\n\n tl.store(input_grad_pointer, input_grad,\n mask=batch_mask[:, None] & feat_mask[None, :])\n", - "description_1": "Use triton language to implement two kernels: softmax_forward_kernel and softmax_backward_kernel. The softmax_forward_kernel takes 10 parameters: input_pointer, output_pointer, batch_dim, feat_dim, input_batch_stride, input_feat_stride, output_batch_stride, output_feat_stride, log, BLOCK_SIZE_BATCH, and BLOCK_SIZE_FEAT. It normalizes the input using softmax and writes the result to the output_pointer. The softmax_backward_kernel takes 13 parameters: output_grad_pointer, output_pointer, input_grad_pointer, batch_dim, feat_dim, output_grad_batch_stride, output_grad_feat_stride, output_batch_stride, output_feat_stride, input_grad_batch_stride, input_grad_feat_stride, log, BLOCK_SIZE_BATCH, and BLOCK_SIZE_FEAT. It calculates the input gradient of softmax and writes it to the input_grad_pointer.", - "description_2": "Use triton language to create a softmax forward kernel that normalizes input data and a backward kernel that computes gradients for backpropagation.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n\n@triton.autotune(\n configs=[\n triton.Config({}, num_warps=1),\n triton.Config({}, num_warps=2),\n triton.Config({}, num_warps=4),\n triton.Config({}, num_warps=8),\n triton.Config({}, num_warps=16),\n triton.Config({}, num_warps=32),\n ],\n key=[\"N\", \"HAS_RESIDUAL\", \"STORE_RESIDUAL_OUT\", \"IS_RMS_NORM\", \"HAS_BIAS\"],\n)\n@triton.jit\ndef _layer_norm_fwd_1pass_kernel(\n X, Y, W, B, RESIDUAL, RESIDUAL_OUT, Mean, Rstd, stride_x_row, stride_y_row,\n stride_res_row, stride_res_out_row, N, eps, IS_RMS_NORM: tl.constexpr,\n BLOCK_N: tl.constexpr, HAS_RESIDUAL: tl.constexpr, STORE_RESIDUAL_OUT: tl.constexpr,\n HAS_BIAS: tl.constexpr,\n):\n row = tl.program_id(0)\n X += row * stride_x_row\n Y += row * stride_y_row\n if HAS_RESIDUAL:\n RESIDUAL += row * stride_res_row\n if STORE_RESIDUAL_OUT:\n RESIDUAL_OUT += row * stride_res_out_row\n cols = tl.arange(0, BLOCK_N)\n x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)\n if HAS_RESIDUAL:\n residual = tl.load(RESIDUAL + cols, mask=cols < N, other=0.0).to(tl.float32)\n x += residual\n if STORE_RESIDUAL_OUT:\n tl.store(RESIDUAL_OUT + cols, x, mask=cols < N)\n if not IS_RMS_NORM:\n mean = tl.sum(x, axis=0) / N\n tl.store(Mean + row, mean)\n xbar = tl.where(cols < N, x - mean, 0.0)\n var = tl.sum(xbar * xbar, axis=0) / N\n else:\n xbar = tl.where(cols < N, x, 0.0)\n var = tl.sum(xbar * xbar, axis=0) / N\n rstd = 1 / tl.sqrt(var + eps)\n tl.store(Rstd + row, rstd)\n mask = cols < N\n w = tl.load(W + cols, mask=mask).to(tl.float32)\n if HAS_BIAS:\n b = tl.load(B + cols, mask=mask).to(tl.float32)\n x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd\n y = x_hat * w + b if HAS_BIAS else x_hat * w\n tl.store(Y + cols, y, mask=mask)\n\n\ndef _layer_norm_fwd(x, weight, bias, eps, residual=None, out_dtype=None, residual_dtype=None, is_rms_norm=False):\n if residual is not None:\n residual_dtype = residual.dtype\n M, N = x.shape\n assert x.stride(-1) == 1\n if residual is not None:\n assert residual.stride(-1) == 1\n assert residual.shape == (M, N)\n assert weight.shape == (N,)\n assert weight.stride(-1) == 1\n if bias is not None:\n assert bias.stride(-1) == 1\n assert bias.shape == (N,)\n y = torch.empty_like(x, dtype=x.dtype if out_dtype is None else out_dtype)\n assert y.stride(-1) == 1\n if residual is not None or (residual_dtype is not None and residual_dtype != x.dtype):\n residual_out = torch.empty(M, N, device=x.device, dtype=residual_dtype)\n assert residual_out.stride(-1) == 1\n else:\n residual_out = None\n mean = torch.empty((M,), dtype=torch.float32, device=\"cuda\") if not is_rms_norm else None\n rstd = torch.empty((M,), dtype=torch.float32, device=\"cuda\")\n MAX_FUSED_SIZE = 65536 // x.element_size()\n BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))\n if N > BLOCK_N:\n raise RuntimeError(\"This layer norm doesn't support feature dim >= 64KB.\")\n with torch.cuda.device(x.device.index):\n _layer_norm_fwd_1pass_kernel[(M,)](\n x,\n y,\n weight,\n bias,\n residual,\n residual_out,\n mean,\n rstd,\n x.stride(0),\n y.stride(0),\n residual.stride(0) if residual is not None else 0,\n residual_out.stride(0) if residual_out is not None else 0,\n N,\n eps,\n is_rms_norm,\n BLOCK_N,\n residual is not None,\n residual_out is not None,\n bias is not None,\n )\n return y, mean, rstd, residual_out if residual_out is not None else x\n\n\n@triton.autotune(\n configs=[\n triton.Config({}, num_warps=1),\n triton.Config({}, num_warps=2),\n triton.Config({}, num_warps=4),\n triton.Config({}, num_warps=8),\n triton.Config({}, num_warps=16),\n triton.Config({}, num_warps=32),\n ],\n key=[\"N\", \"HAS_DRESIDUAL\", \"STORE_DRESIDUAL\", \"IS_RMS_NORM\", \"HAS_BIAS\"],\n)\n@triton.jit\ndef _layer_norm_bwd_kernel(\n X, W, B, Y, DY, DX, DW, DB, DRESIDUAL, DRESIDUAL_IN, Mean, Rstd, stride_x_row, \n stride_y_row, stride_dy_row, stride_dx_row, stride_dres_row, stride_dres_in_row,\n M, N, eps, rows_per_program, IS_RMS_NORM: tl.constexpr, BLOCK_N: tl.constexpr, \n HAS_DRESIDUAL: tl.constexpr, STORE_DRESIDUAL: tl.constexpr, HAS_BIAS: tl.constexpr, \n RECOMPUTE_OUTPUT: tl.constexpr,\n):\n row_block_id = tl.program_id(0)\n row_start = row_block_id * rows_per_program\n cols = tl.arange(0, BLOCK_N)\n mask = cols < N\n X += row_start * stride_x_row\n if HAS_DRESIDUAL:\n DRESIDUAL += row_start * stride_dres_row\n if STORE_DRESIDUAL:\n DRESIDUAL_IN += row_start * stride_dres_in_row\n DY += row_start * stride_dy_row\n DX += row_start * stride_dx_row\n if RECOMPUTE_OUTPUT:\n Y += row_start * stride_y_row\n w = tl.load(W + cols, mask=mask).to(tl.float32)\n if RECOMPUTE_OUTPUT and HAS_BIAS:\n b = tl.load(B + cols, mask=mask, other=0.0).to(tl.float32)\n dw = tl.zeros((BLOCK_N,), dtype=tl.float32)\n if HAS_BIAS:\n db = tl.zeros((BLOCK_N,), dtype=tl.float32)\n row_end = min((row_block_id + 1) * rows_per_program, M)\n for row in range(row_start, row_end):\n x = tl.load(X + cols, mask=mask, other=0).to(tl.float32)\n dy = tl.load(DY + cols, mask=mask, other=0).to(tl.float32)\n if not IS_RMS_NORM:\n mean = tl.load(Mean + row)\n rstd = tl.load(Rstd + row)\n xhat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd\n xhat = tl.where(mask, xhat, 0.0)\n if RECOMPUTE_OUTPUT:\n y = xhat * w + b if HAS_BIAS else xhat * w\n tl.store(Y + cols, y, mask=mask)\n wdy = w * dy\n dw += dy * xhat\n if HAS_BIAS:\n db += dy\n if not IS_RMS_NORM:\n c1 = tl.sum(xhat * wdy, axis=0) / N\n c2 = tl.sum(wdy, axis=0) / N\n dx = (wdy - (xhat * c1 + c2)) * rstd\n else:\n c1 = tl.sum(xhat * wdy, axis=0) / N\n dx = (wdy - xhat * c1) * rstd\n if HAS_DRESIDUAL:\n dres = tl.load(DRESIDUAL + cols, mask=mask, other=0).to(tl.float32)\n dx += dres\n if STORE_DRESIDUAL:\n tl.store(DRESIDUAL_IN + cols, dx, mask=mask)\n tl.store(DX + cols, dx, mask=mask)\n\n X += stride_x_row\n if HAS_DRESIDUAL:\n DRESIDUAL += stride_dres_row\n if STORE_DRESIDUAL:\n DRESIDUAL_IN += stride_dres_in_row\n if RECOMPUTE_OUTPUT:\n Y += stride_y_row\n DY += stride_dy_row\n DX += stride_dx_row\n tl.store(DW + row_block_id * N + cols, dw, mask=mask)\n if HAS_BIAS:\n tl.store(DB + row_block_id * N + cols, db, mask=mask)\n\n\ndef _layer_norm_bwd(\n dy, x, weight, bias, eps, mean, rstd, dresidual=None, has_residual=False, is_rms_norm=False, \n x_dtype=None, recompute_output=False,\n):\n M, N = x.shape\n assert x.stride(-1) == 1\n assert dy.stride(-1) == 1\n assert dy.shape == (M, N)\n if dresidual is not None:\n assert dresidual.stride(-1) == 1\n assert dresidual.shape == (M, N)\n assert weight.shape == (N,)\n assert weight.stride(-1) == 1\n if bias is not None:\n assert bias.stride(-1) == 1\n assert bias.shape == (N,)\n dx = (\n torch.empty_like(x)\n if x_dtype is None\n else torch.empty(M, N, dtype=x_dtype, device=x.device)\n )\n dresidual_in = torch.empty_like(x) if has_residual and dx.dtype != x.dtype else None\n y = torch.empty(M, N, dtype=dy.dtype, device=dy.device) if recompute_output else None\n\n MAX_FUSED_SIZE = 65536 // x.element_size()\n BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))\n if N > BLOCK_N:\n raise RuntimeError(\"This layer norm doesn't support feature dim >= 64KB.\")\n sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count\n _dw = torch.empty((sm_count, N), dtype=torch.float32, device=weight.device)\n _db = (\n torch.empty((sm_count, N), dtype=torch.float32, device=bias.device)\n if bias is not None\n else None\n )\n rows_per_program = math.ceil(M / sm_count)\n grid = (sm_count,)\n with torch.cuda.device(x.device.index):\n _layer_norm_bwd_kernel[grid](\n x,\n weight,\n bias,\n y,\n dy,\n dx,\n _dw,\n _db,\n dresidual,\n dresidual_in,\n mean,\n rstd,\n x.stride(0),\n 0 if not recompute_output else y.stride(0),\n dy.stride(0),\n dx.stride(0),\n dresidual.stride(0) if dresidual is not None else 0,\n dresidual_in.stride(0) if dresidual_in is not None else 0,\n M,\n N,\n eps,\n rows_per_program,\n is_rms_norm,\n BLOCK_N,\n dresidual is not None,\n dresidual_in is not None,\n bias is not None,\n )\n dw = _dw.sum(0).to(weight.dtype)\n db = _db.sum(0).to(bias.dtype) if bias is not None else None\n if has_residual and dx.dtype == x.dtype:\n dresidual_in = dx\n return (dx, dw, db, dresidual_in) if not recompute_output else (dx, dw, db, dresidual_in, y)\n", - "description_1": "Use triton language to implement a kernel for layer normalization with support for residual connections and RMS norm. The kernel '_layer_norm_fwd_1pass_kernel' accepts 21 parameters: input data pointer X, output pointer Y, weights pointer W, biases pointer B, residual pointer RESIDUAL, output residual pointer RESIDUAL_OUT, mean pointer Mean, rstd pointer Rstd, strides for input/output/residual as stride_x_row, stride_y_row, stride_res_row, stride_res_out_row, number of columns N, epsilon eps for numerical stability, boolean constexpr for RMS norm IS_RMS_NORM, block size BLOCK_N, boolean flags HAS_RESIDUAL, STORE_RESIDUAL_OUT, HAS_BIAS. It normalizes input data using the mean and variance computed along the last axis and applies a linear transformation using weights and biases.", - "description_2": "Use triton language to implement a backward kernel for layer normalization. The kernel '_layer_norm_bwd_kernel' accepts 28 parameters: input pointer X, weights pointer W, biases pointer B, output pointer Y, gradient of output DY, gradient of input DX, partial weight gradient DW, partial bias gradient DB, gradient of residual DRESIDUAL, input gradient for residual DRESIDUAL_IN, mean pointer Mean, rstd pointer Rstd, strides for input/output gradients as stride_x_row, stride_y_row, stride_dy_row, stride_dx_row, strides for residual gradients as stride_dres_row, stride_dres_in_row, number of rows M, columns N, epsilon eps, rows per program rows_per_program, and boolean constexpr flags for RMS norm IS_RMS_NORM, block size BLOCK_N, flags HAS_DRESIDUAL, STORE_DRESIDUAL, HAS_BIAS, RECOMPUTE_OUTPUT. It computes the gradient of the input, weights, and biases using the chain rule, accommodating possible recomputation of the output.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef _selective_scan_update_kernel(\n # Pointers to matrices\n state_ptr, x_ptr, dt_ptr, dt_bias_ptr, A_ptr, B_ptr, C_ptr, D_ptr, z_ptr, out_ptr,\n # Matrix dimensions\n batch, dim, dstate,\n # Strides\n stride_state_batch, stride_state_dim, stride_state_dstate,\n stride_x_batch, stride_x_dim,\n stride_dt_batch, stride_dt_dim,\n stride_dt_bias_dim,\n stride_A_dim, stride_A_dstate,\n stride_B_batch, stride_B_dstate,\n stride_C_batch, stride_C_dstate,\n stride_D_dim,\n stride_z_batch, stride_z_dim,\n stride_out_batch, stride_out_dim,\n # Meta-parameters\n DT_SOFTPLUS: tl.constexpr,\n BLOCK_SIZE_M: tl.constexpr,\n HAS_DT_BIAS: tl.constexpr,\n HAS_D: tl.constexpr,\n HAS_Z: tl.constexpr,\n BLOCK_SIZE_DSTATE: tl.constexpr,\n):\n pid_m = tl.program_id(axis=0)\n pid_b = tl.program_id(axis=1)\n state_ptr += pid_b * stride_state_batch\n x_ptr += pid_b * stride_x_batch\n dt_ptr += pid_b * stride_dt_batch\n B_ptr += pid_b * stride_B_batch\n C_ptr += pid_b * stride_C_batch\n if HAS_Z:\n z_ptr += pid_b * stride_z_batch\n out_ptr += pid_b * stride_out_batch\n\n offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_n = tl.arange(0, BLOCK_SIZE_DSTATE)\n state_ptrs = state_ptr + (offs_m[:, None] * stride_state_dim + offs_n[None, :] * stride_state_dstate)\n x_ptrs = x_ptr + offs_m * stride_x_dim\n dt_ptrs = dt_ptr + offs_m * stride_dt_dim\n if HAS_DT_BIAS:\n dt_bias_ptrs = dt_bias_ptr + offs_m * stride_dt_bias_dim\n A_ptrs = A_ptr + (offs_m[:, None] * stride_A_dim + offs_n[None, :] * stride_A_dstate)\n B_ptrs = B_ptr + offs_n * stride_B_dstate\n C_ptrs = C_ptr + offs_n * stride_C_dstate\n if HAS_D:\n D_ptrs = D_ptr + offs_m * stride_D_dim\n if HAS_Z:\n z_ptrs = z_ptr + offs_m * stride_z_dim\n out_ptrs = out_ptr + offs_m * stride_out_dim\n\n state = tl.load(state_ptrs, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate), other=0.0)\n x = tl.load(x_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)\n dt = tl.load(dt_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)\n if HAS_DT_BIAS:\n dt += tl.load(dt_bias_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)\n if DT_SOFTPLUS:\n dt = tl.log(1.0 + tl.exp(dt))\n A = tl.load(A_ptrs, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate), other=0.0).to(tl.float32)\n dA = tl.exp(A * dt[:, None])\n B = tl.load(B_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32)\n C = tl.load(C_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32)\n if HAS_D:\n D = tl.load(D_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)\n if HAS_Z:\n z = tl.load(z_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)\n\n dB = B[None, :] * dt[:, None]\n state = state * dA + dB * x[:, None]\n tl.store(state_ptrs, state, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate))\n out = tl.sum(state * C[None, :], axis=1)\n if HAS_D:\n out += x * D\n if HAS_Z:\n out *= z * tl.sigmoid(z)\n tl.store(out_ptrs, out, mask=offs_m < dim)\n\n\ndef selective_state_update(state, x, dt, A, B, C, D=None, z=None, dt_bias=None, dt_softplus=False):\n \"\"\"\n Argument:\n state: (batch, dim, dstate)\n x: (batch, dim)\n dt: (batch, dim)\n A: (dim, dstate)\n B: (batch, dstate)\n C: (batch, dstate)\n D: (dim,)\n z: (batch, dim)\n dt_bias: (dim,)\n Return:\n out: (batch, dim)\n \"\"\"\n batch, dim, dstate = state.shape\n assert x.shape == (batch, dim)\n assert dt.shape == x.shape\n assert A.shape == (dim, dstate)\n assert B.shape == (batch, dstate)\n assert C.shape == B.shape\n if D is not None:\n assert D.shape == (dim,)\n if z is not None:\n assert z.shape == x.shape\n if dt_bias is not None:\n assert dt_bias.shape == (dim,)\n out = torch.empty_like(x)\n grid = lambda META: (triton.cdiv(dim, META['BLOCK_SIZE_M']), batch)\n z_strides = ((z.stride(0), z.stride(1)) if z is not None else (0, 0))\n # We don't want autotune since it will overwrite the state\n # We instead tune by hand.\n BLOCK_SIZE_M, num_warps = ((32, 4) if dstate <= 16\n else ((16, 4) if dstate <= 32 else\n ((8, 4) if dstate <= 64 else\n ((4, 4) if dstate <= 128 else\n ((4, 8))))))\n with torch.cuda.device(x.device.index):\n _selective_scan_update_kernel[grid](\n state, x, dt, dt_bias, A, B, C, D, z, out,\n batch, dim, dstate,\n state.stride(0), state.stride(1), state.stride(2),\n x.stride(0), x.stride(1),\n dt.stride(0), dt.stride(1),\n dt_bias.stride(0) if dt_bias is not None else 0,\n A.stride(0), A.stride(1),\n B.stride(0), B.stride(1),\n C.stride(0), C.stride(1),\n D.stride(0) if D is not None else 0,\n z_strides[0], z_strides[1],\n out.stride(0), out.stride(1),\n dt_softplus,\n BLOCK_SIZE_M,\n num_warps=num_warps,\n )\n return out\n", - "description_1": "Use triton language to implement a kernel function '_selective_scan_update_kernel' with 35 parameters for updating state matrices with optional bias and scaling, and a wrapper function 'selective_state_update' with 10 parameters to call the kernel with appropriate grid and block size configurations.", - "description_2": "Use triton language to implement a state update kernel with optional bias and scaling, and a wrapper function to configure and call the kernel.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef _update_step(\n # Pointers to matrices\n kv_state_ptr, v_ptr, k_ptr, q_ptr, out_ptr,\n # Matrix dimensions\n dim, dstate,\n # Strides\n stride_kv_state_batch, stride_kv_state_dim, stride_kv_state_dstate,\n stride_v_batch, stride_v_dim,\n stride_k_batch, stride_k_dstate,\n stride_q_batch, stride_q_dstate,\n stride_out_batch, stride_out_dim,\n # Meta-parameters\n BLOCK_SIZE_M: tl.constexpr,\n BLOCK_SIZE_DSTATE: tl.constexpr,\n):\n pid_m = tl.program_id(axis=0)\n pid_b = tl.program_id(axis=1)\n kv_state_ptr += pid_b * stride_kv_state_batch\n v_ptr += pid_b * stride_v_batch\n k_ptr += pid_b * stride_k_batch\n q_ptr += pid_b * stride_q_batch\n\n out_ptr += pid_b * stride_out_batch\n\n offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_n = tl.arange(0, BLOCK_SIZE_DSTATE)\n kv_state_ptrs = kv_state_ptr + (offs_m[:, None] * stride_kv_state_dim + offs_n[None, :] * stride_kv_state_dstate)\n v_ptrs = v_ptr + offs_m * stride_v_dim\n k_ptrs = k_ptr + offs_n * stride_k_dstate\n q_ptrs = q_ptr + offs_n * stride_q_dstate\n out_ptrs = out_ptr + offs_m * stride_out_dim\n\n kv_state = tl.load(kv_state_ptrs, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate), other=0.0)\n V = tl.load(v_ptrs, mask=offs_m < dim, other=0.0)\n K = tl.load(k_ptrs, mask=offs_n < dstate, other=0.0)\n Q = tl.load(q_ptrs, mask=offs_n < dstate, other=0.0)\n\n kv_state = kv_state + K[None, :] * V[:, None]\n tl.store(kv_state_ptrs, kv_state, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate))\n num = tl.sum(kv_state * Q[None, :], axis=1)\n tl.store(out_ptrs, num, mask=offs_m < dim)\n\n\ndef lin_attn_step(\n kv_state, \n v, k, q\n):\n \"\"\"\n Argument:\n kv state: (batch, dim, dstate)\n v: (batch, dim)\n k: (batch, dstate)\n q: (batch, dstate)\n Return:\n out: (batch, dim)\n \"\"\"\n batch, dim, dstate = kv_state.shape\n assert v.shape == (batch, dim)\n assert k.shape == (batch, dstate)\n assert q.shape == k.shape\n\n out = torch.empty_like(v)\n grid = lambda META: (triton.cdiv(dim, META['BLOCK_SIZE_M']), batch)\n BLOCK_SIZE_M, num_warps = (4, 8)\n\n with torch.cuda.device(v.device.index):\n _update_step[grid](\n kv_state, v, k, q, out,\n dim, dstate,\n kv_state.stride(0), kv_state.stride(1), kv_state.stride(2),\n v.stride(0), v.stride(1),\n k.stride(0), k.stride(1),\n q.stride(0), q.stride(1),\n out.stride(0), out.stride(1),\n BLOCK_SIZE_M,\n num_warps=num_warps,\n )\n return out\n", - "description_1": "Use triton language to implement a kernel (_update_step) that performs matrix updates and accumulations. The kernel has 5 pointer parameters (kv_state_ptr, v_ptr, k_ptr, q_ptr, out_ptr) for matrices, 2 integer parameters (dim, dstate) representing dimensions, and multiple stride parameters for memory access. The kernel utilizes meta-parameters BLOCK_SIZE_M and BLOCK_SIZE_DSTATE for block processing. A function (lin_attn_step) wraps this kernel to operate on torch tensors, setting the grid for execution and preparing inputs.", - "description_2": "Use triton language to perform batched matrix updates and attentions, with kernel configurations for block sizes and memory strides.", - "difficulty": 4 - }, - { - "code": "import triton\nimport triton.language as tl\nimport torch\n\n@triton.jit\ndef _cross_entropy_forward(\n logits_ptr, logits_row_stride,\n loss_ptr,\n logsumexp_ptr,\n labels_ptr,\n VOCAB_SIZE : tl.constexpr,\n BLOCK_SIZE : tl.constexpr,\n):\n \"\"\"\n Triton kernel for forward cross-entropy computation with logits normalization.\n Arguments:\n - logits_ptr: Pointer to logits data (device pointer).\n - logits_row_stride: Stride for each row in logits (int).\n - loss_ptr: Pointer to the loss output (device pointer).\n - logsumexp_ptr: Pointer to logsumexp output (device pointer).\n - labels_ptr: Pointer to label data (device pointer).\n - VOCAB_SIZE: Size of vocabulary (constexpr).\n - BLOCK_SIZE: Size of each block for Triton execution (constexpr).\n \"\"\"\n row_idx = tl.program_id(0)\n logits_ptr += row_idx * logits_row_stride.to(tl.int64)\n loss_ptr += row_idx\n logsumexp_ptr += row_idx\n labels_ptr += row_idx\n\n col_offsets = tl.arange(0, BLOCK_SIZE)\n mask = col_offsets < VOCAB_SIZE\n\n label_idx = tl.load(labels_ptr).to(tl.int32)\n logits = tl.load(logits_ptr + col_offsets, mask = mask, other = -float(\"inf\")).to(tl.float32)\n c = tl.max(logits, 0)\n logsumexp = c + tl.log(tl.sum(tl.exp(logits - c), 0))\n\n if label_idx != -100:\n x = tl.load(logits_ptr + label_idx).to(tl.float32)\n loss = logsumexp - x\n else:\n loss = 0.0\n tl.store(logsumexp_ptr, logsumexp)\n tl.store(loss_ptr, loss)\npass\n\n@triton.jit\ndef _chunked_cross_entropy_forward(\n logits_ptr, logits_row_stride,\n loss_ptr,\n logsumexp_ptr,\n labels_ptr,\n VOCAB_SIZE : tl.constexpr,\n N_CHUNKS : tl.constexpr,\n BLOCK_SIZE : tl.constexpr,\n):\n \"\"\"\n Triton kernel for forward cross-entropy computation on chunked logits.\n Arguments:\n - logits_ptr: Pointer to logits data (device pointer).\n - logits_row_stride: Stride for each row in logits (int).\n - loss_ptr: Pointer to the loss output (device pointer).\n - logsumexp_ptr: Pointer to logsumexp output (device pointer).\n - labels_ptr: Pointer to label data (device pointer).\n - VOCAB_SIZE: Size of vocabulary (constexpr).\n - N_CHUNKS: Number of chunks to divide the vocabulary (constexpr).\n - BLOCK_SIZE: Size of each block for Triton execution (constexpr).\n \"\"\"\n row_idx = tl.program_id(0)\n chunk_idx = tl.program_id(1)\n logits_ptr += row_idx * logits_row_stride.to(tl.int64)\n loss_ptr += row_idx\n logsumexp_ptr += row_idx * N_CHUNKS + chunk_idx\n labels_ptr += row_idx\n\n col_offsets = chunk_idx*BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)\n mask = col_offsets < VOCAB_SIZE\n\n label_idx = tl.load(labels_ptr).to(tl.int32)\n logits = tl.load(logits_ptr + col_offsets, mask = mask, other = -float(\"inf\")).to(tl.float32)\n c = tl.max(logits, 0)\n logsumexp = c + tl.log(tl.sum(tl.exp(logits - c), 0))\n\n if chunk_idx == 0:\n if label_idx != -100:\n x = tl.load(logits_ptr + label_idx).to(tl.float32)\n loss = -1.0 * x\n else:\n loss = 0.0\n tl.store(loss_ptr, loss)\n pass\n tl.store(logsumexp_ptr, logsumexp)\npass\n\n@triton.jit\ndef _cross_entropy_backward(\n logits_ptr, logits_row_stride,\n dloss_ptr, dloss_row_stride,\n logsumexp_ptr,\n labels_ptr,\n VOCAB_SIZE : tl.constexpr,\n BLOCK_SIZE : tl.constexpr,\n):\n \"\"\"\n Triton kernel for backward pass of cross-entropy computation.\n Arguments:\n - logits_ptr: Pointer to logits data (device pointer).\n - logits_row_stride: Stride for each row in logits (int).\n - dloss_ptr: Pointer to gradient loss data (device pointer).\n - dloss_row_stride: Stride for each row in dloss (int).\n - logsumexp_ptr: Pointer to logsumexp output (device pointer).\n - labels_ptr: Pointer to label data (device pointer).\n - VOCAB_SIZE: Size of vocabulary (constexpr).\n - BLOCK_SIZE: Size of each block for Triton execution (constexpr).\n \"\"\"\n row_idx = tl.program_id(0)\n block_idx = tl.program_id(1)\n\n logits_ptr += row_idx * logits_row_stride.to(tl.int64)\n dloss_ptr += row_idx * dloss_row_stride\n col_offsets = block_idx*BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)\n mask = col_offsets < VOCAB_SIZE\n label_idx = tl.load(labels_ptr + row_idx).to(tl.int32)\n\n if label_idx != -100:\n dloss = tl.load(dloss_ptr)\n else:\n dloss = 0.0\n\n x = tl.load(logits_ptr + col_offsets, mask = mask, other = -float(\"inf\")).to(tl.float32)\n logsumexp = tl.load(logsumexp_ptr + row_idx)\n y = tl.exp(x - logsumexp)\n y = tl.where(\n col_offsets == label_idx,\n y - 1.0,\n y,\n )\n tl.store(logits_ptr + col_offsets, dloss * y, mask = mask)\npass\n\nclass Fast_CrossEntropyLoss(torch.autograd.Function):\n @staticmethod\n def forward(ctx, logits, labels):\n \"\"\"\n Perform forward pass for fast cross-entropy computation using Triton kernels.\n Arguments:\n - logits: Input logits tensor (torch.Tensor).\n - labels: Target labels tensor (torch.Tensor).\n Returns:\n - losses: Computed loss tensor (torch.Tensor).\n \"\"\"\n n_rows, vocab_size = logits.shape\n\n div, mod = divmod(vocab_size, MAX_FUSED_SIZE)\n n_chunks = div + (mod != 0)\n losses = torch.empty(n_rows, dtype = torch.float32, device = \"cuda\")\n\n if n_chunks == 1:\n BLOCK_SIZE, num_warps = calculate_settings(vocab_size)\n logsumexp = torch.empty(n_rows, dtype = torch.float32, device = \"cuda\")\n\n _cross_entropy_forward[(n_rows,)](\n logits, logits.stride(0),\n losses,\n logsumexp,\n labels,\n VOCAB_SIZE = vocab_size,\n BLOCK_SIZE = BLOCK_SIZE,\n num_warps = num_warps,\n )\n else:\n logsumexp = torch.empty((n_rows, n_chunks,), dtype = torch.float32, device = \"cuda\")\n\n _chunked_cross_entropy_forward[(n_rows, n_chunks,)](\n logits, logits.stride(0),\n losses,\n logsumexp,\n labels,\n VOCAB_SIZE = vocab_size,\n N_CHUNKS = n_chunks,\n BLOCK_SIZE = MAX_FUSED_SIZE,\n num_warps = 32,\n )\n logsumexp = torch.logsumexp(logsumexp, dim = 1)\n losses += logsumexp\n losses.masked_fill_(labels == -100, 0)\n pass\n\n ctx.save_for_backward(logits, logsumexp, labels)\n return losses\n pass\n\n @staticmethod\n def backward(ctx, dlosses):\n \"\"\"\n Perform backward pass for fast cross-entropy computation using Triton kernels.\n Arguments:\n - dlosses: Gradient of the loss tensor (torch.Tensor).\n Returns:\n - gradients with respect to input logits.\n \"\"\"\n logits, logsumexp, labels = ctx.saved_tensors\n n_rows, vocab_size = logits.shape\n\n BLOCK_SIZE = 4096\n div, mod = divmod(vocab_size, BLOCK_SIZE)\n n_blocks = div + (mod != 0)\n\n _cross_entropy_backward[(n_rows, n_blocks,)](\n logits, logits.stride(0),\n dlosses, dlosses.stride(0),\n logsumexp,\n labels,\n VOCAB_SIZE = vocab_size,\n BLOCK_SIZE = BLOCK_SIZE,\n num_warps = 8,\n )\n return logits, None, None,\n pass\npass\n\ndef fast_cross_entropy_loss(logits, labels):\n \"\"\"\n Wrapper function for computing fast cross-entropy loss using the Fast_CrossEntropyLoss class.\n Arguments:\n - logits: Input logits tensor with shape (batch, seq_len, vocab_size).\n - labels: Target labels tensor with shape (batch, seq_len).\n Returns:\n - Normalized cross-entropy loss value (float).\n \"\"\"\n batch, seq_len, d = logits.shape\n assert(labels.shape == (batch, seq_len))\n\n loss = Fast_CrossEntropyLoss.apply(\n logits.view(batch*seq_len, d),\n labels.view(-1),\n )\n n_items = torch.count_nonzero(labels != -100)\n return loss.sum() / n_items\npass\n", - "description_1": "Use triton language to implement cross-entropy forward and backward kernels for computing normalized cross-entropy loss. The forward kernel, `_cross_entropy_forward`, takes pointers to logits, loss, logsumexp, and labels, and constexpr values for VOCAB_SIZE and BLOCK_SIZE. It computes the logsumexp for normalization and stores the computed loss. The `_chunked_cross_entropy_forward` kernel handles chunked computation for larger vocabularies. The backward kernel, `_cross_entropy_backward`, computes the gradient of the loss with respect to the logits and stores the results. A PyTorch custom autograd function class, `Fast_CrossEntropyLoss`, wraps these kernels, providing efficient forward and backward operations for fast cross-entropy loss computation. The `fast_cross_entropy_loss` function provides a convenient interface for users to compute loss, supporting both small and large vocabularies.", - "description_2": "Use triton language to implement optimized kernels for cross-entropy loss computation on GPU with chunked processing for large vocabularies. The kernels perform efficient forward normalization and backward gradient computation integrated into a PyTorch autograd function.", - "difficulty": 4 - }, - { - "code": "import triton\nimport triton.language as tl\nimport torch\n\n@triton.jit\ndef _exact_forward_kernel(e, g, h, n_elements, BLOCK_SIZE: tl.constexpr):\n block_idx = tl.program_id(0)\n offsets = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)\n mask = offsets < n_elements\n\n # f = 1/2 * e * (1 + erf(1/sqrt(2) * e))\n # h = f * up\n e_row = tl.load(e + offsets, mask=mask, other=0).to(tl.float32)\n g_row = tl.load(g + offsets, mask=mask, other=0)\n\n f_row = 0.5 * e_row * (tl.math.erf(tl.math.rsqrt(2.0) * e_row) + 1.0)\n f_row = f_row.to(g_row.dtype) # Exact copy from HF\n h_row = f_row * g_row\n\n # Store h\n tl.store(h + offsets, h_row, mask=mask)\n\ndef geglu_exact_forward_kernel(gate, up):\n batch, seq_len, hd = gate.shape\n n_elements = gate.numel()\n out = torch.empty((batch, seq_len, hd), dtype=gate.dtype, device=\"cuda\")\n grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)\n _exact_forward_kernel[grid](gate, up, out, n_elements, BLOCK_SIZE=1024)\n return out\n\n@triton.jit\ndef _exact_backward_kernel(DW, e, g, n_elements, BLOCK_SIZE: tl.constexpr):\n block_idx = tl.program_id(0)\n offsets = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)\n mask = offsets < n_elements\n\n DW_row = tl.load(DW + offsets, mask=mask, other=0)\n e_row = tl.load(e + offsets, mask=mask, other=0).to(tl.float32)\n g_row = tl.load(g + offsets, mask=mask, other=0)\n\n f_partial_row = 0.5 * (tl.math.erf(tl.math.rsqrt(2.0) * e_row) + 1.0)\n f_row = f_partial_row * e_row\n f_row = f_row.to(DW_row.dtype)\n h_row = f_row * g_row\n df_row = DW_row * f_row\n dg_row = DW_row * g_row\n\n t = 0.3989422804014327 # 1/sqrt(2*pi)\n df_de = f_partial_row + t * e_row * tl.exp(-0.5 * e_row * e_row)\n\n de_row = dg_row.to(tl.float32) * df_de\n de_row = de_row.to(DW_row.dtype)\n\n tl.store(DW + offsets, h_row, mask=mask)\n tl.store(e + offsets, df_row, mask=mask)\n tl.store(g + offsets, de_row, mask=mask)\n\ndef geglu_exact_backward_kernel(DW, e, g):\n batch_seq_len, hd = e.shape\n n_elements = e.numel()\n grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)\n _exact_backward_kernel[grid](DW, e, g, n_elements, BLOCK_SIZE=1024)\n return DW, e, g\n\n@triton.jit\ndef _approx_forward_kernel(e, g, h, n_elements, BLOCK_SIZE: tl.constexpr):\n block_idx = tl.program_id(0)\n offsets = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)\n mask = offsets < n_elements\n\n s = 0.7978845608028654 # math.sqrt(2 / math.pi)\n e_row = tl.load(e + offsets, mask=mask, other=0).to(tl.float32)\n g_row = tl.load(g + offsets, mask=mask, other=0)\n\n f_row = 0.5 * e_row * (\n tl.math.tanh(s * e_row * (1.0 + 0.044715 * e_row * e_row)) + 1.0\n )\n f_row = f_row.to(g_row.dtype) # Exact copy from HF\n h_row = f_row * g_row\n\n tl.store(h + offsets, h_row, mask=mask)\n\ndef geglu_approx_forward_kernel(gate, up):\n batch, seq_len, hd = gate.shape\n n_elements = gate.numel()\n out = torch.empty((batch, seq_len, hd), dtype=gate.dtype, device=\"cuda\")\n grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)\n _approx_forward_kernel[grid](gate, up, out, n_elements, BLOCK_SIZE=1024)\n return out\n\n@triton.jit\ndef _approx_backward_kernel(DW, e, g, n_elements, BLOCK_SIZE: tl.constexpr):\n block_idx = tl.program_id(0)\n offsets = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)\n mask = offsets < n_elements\n\n DW_row = tl.load(DW + offsets, mask=mask, other=0)\n e_row = tl.load(e + offsets, mask=mask, other=0).to(tl.float32)\n g_row = tl.load(g + offsets, mask=mask, other=0)\n\n s = 0.7978845608028654 # math.sqrt(2 / math.pi)\n a = s * e_row\n b = a * 0.044715 * e_row * e_row\n T = 1.0 + tl.math.tanh(a + b)\n T2 = 0.5 * T\n Q2 = -T2 * (T - 2.0) * (a + 3.0 * b)\n df_de = T2 + Q2\n\n f_row = T2 * e_row\n f_row = f_row.to(DW_row.dtype)\n h_row = f_row * g_row\n df_row = DW_row * f_row\n dg_row = DW_row * g_row\n\n de_row = dg_row.to(tl.float32) * df_de\n de_row = de_row.to(DW_row.dtype)\n\n tl.store(DW + offsets, h_row, mask=mask)\n tl.store(e + offsets, df_row, mask=mask)\n tl.store(g + offsets, de_row, mask=mask)\n\ndef geglu_approx_backward_kernel(DW, e, g):\n batch_seq_len, hd = e.shape\n n_elements = e.numel()\n grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)\n _approx_backward_kernel[grid](DW, e, g, n_elements, BLOCK_SIZE=1024)\n return DW, e, g\n", - "description_1": "Use triton language to implement four kernels: _exact_forward_kernel, _exact_backward_kernel, _approx_forward_kernel, and _approx_backward_kernel. Each kernel takes five parameters: e, g, h (or DW), n_elements, and BLOCK_SIZE. The kernels perform element-wise operations on input tensors using Triton's parallel programming model. The forward kernels compute a transformation of the input tensor e, using either an exact or approximate method, and store the result in h. The backward kernels compute gradients for the input tensors e and g, using either an exact or approximate method, and store the results in DW, e, and g.", - "description_2": "Use triton language to implement exact and approximate forward and backward kernels for element-wise tensor operations with parallel execution.", - "difficulty": 3 - }, - { - "code": "import triton\nimport triton.language as tl\nimport torch\nfrom .utils import calculate_settings\n\n@triton.jit\ndef _rms_layernorm_forward(\n Y, Y_row_stride,\n X, X_row_stride,\n W, W_row_stride,\n r, r_row_stride,\n n_cols, eps,\n BLOCK_SIZE : tl.constexpr\n):\n \"\"\"\n Fast RMS Layernorm kernel\n Inspiration from a Triton tutorial:\n https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html\n \"\"\"\n row_idx = tl.program_id(0)\n col_offsets = tl.arange(0, BLOCK_SIZE)\n mask = col_offsets < n_cols\n\n Y += row_idx * Y_row_stride\n X += row_idx * X_row_stride\n r += row_idx * r_row_stride\n\n X_row = tl.load(X + col_offsets, mask = mask, other = 0).to(tl.float32)\n W_row = tl.load(W + col_offsets, mask = mask, other = 0)\n\n row_var = tl.sum(X_row * X_row, axis = 0) / n_cols\n inv_var = tl.math.rsqrt(row_var + eps)\n tl.store(r, inv_var)\n normed = X_row * inv_var\n normed = normed.to(W_row.dtype)\n output = normed * W_row\n tl.store(Y + col_offsets, output, mask = mask)\n\n@triton.jit\ndef _gemma_rms_layernorm_forward(\n Y, Y_row_stride,\n X, X_row_stride,\n W, W_row_stride,\n r, r_row_stride,\n n_cols, eps,\n BLOCK_SIZE : tl.constexpr,\n):\n # Copies https://github.com/google-deepmind/gemma/blob/main/gemma/layers.py#L31\n # and https://github.com/keras-team/keras-nlp/blob/v0.8.2/keras_nlp/models/gemma/rms_normalization.py#L33\n # exactly. Essentially all in float32!\n row_idx = tl.program_id(0)\n col_offsets = tl.arange(0, BLOCK_SIZE)\n mask = col_offsets < n_cols\n\n Y += row_idx * Y_row_stride\n X += row_idx * X_row_stride\n r += row_idx * r_row_stride\n\n X_row = tl.load(X + col_offsets, mask = mask, other = 0).to(tl.float32)\n W_row = tl.load(W + col_offsets, mask = mask, other = 0).to(tl.float32)\n\n row_var = tl.sum(X_row * X_row, axis = 0) / n_cols\n inv_var = 1.0 / tl.sqrt(row_var + eps)\n tl.store(r, inv_var)\n normed = X_row * inv_var\n output = normed * (W_row + 1.0)\n\n tl.store(Y + col_offsets, output, mask = mask)\n\n@triton.heuristics({\"GEMMA\": lambda args: args[\"GEMMA\"],})\n@triton.jit\ndef _rms_layernorm_backward(\n dY, dY_row_stride,\n X, X_row_stride,\n W, W_row_stride,\n r, r_row_stride,\n dW, dW_row_stride,\n n_cols, eps,\n GEMMA : tl.constexpr,\n BLOCK_SIZE : tl.constexpr,\n):\n \"\"\"\n Fast RMS Layernorm kernel for the backward pass\n Inspiration from a Triton tutorial:\n https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html\n \"\"\"\n row_idx = tl.program_id(0)\n col_offsets = tl.arange(0, BLOCK_SIZE)\n mask = col_offsets < n_cols\n\n dY += row_idx * dY_row_stride\n X += row_idx * X_row_stride\n r += row_idx * r_row_stride\n\n dY_row = tl.load(dY + col_offsets, mask = mask, other = 0).to(tl.float32)\n X_row = tl.load(X + col_offsets, mask = mask, other = 0).to(tl.float32)\n W_row = tl.load(W + col_offsets, mask = mask, other = 0).to(tl.float32)\n\n inv_var = tl.load(r).to(tl.float32)\n normed = X_row * inv_var\n\n if GEMMA: dY_W = dY_row * (W_row + 1.0)\n else: dY_W = dY_row * W_row\n\n rowsum_dY_normed = tl.sum(dY_W * normed, axis = 0)\n output = inv_var/n_cols * (n_cols*dY_W - normed*rowsum_dY_normed)\n tl.store(dY + col_offsets, output, mask = mask)\n\nclass Fast_RMS_Layernorm(torch.autograd.Function):\n @staticmethod\n def forward(ctx, X, W, eps, gemma = False):\n shape = X.shape\n dim = shape[-1]\n X = X.view(-1, dim)\n n_rows, n_cols = X.shape\n BLOCK_SIZE, num_warps = calculate_settings(n_cols)\n\n Y = torch.empty((n_rows, n_cols), dtype = X.dtype, device = \"cuda\")\n r = torch.empty(n_rows, dtype = torch.float32, device = \"cuda\")\n\n fx = _gemma_rms_layernorm_forward if gemma else _rms_layernorm_forward\n fx[(n_rows,)](\n Y, Y.stride(0),\n X, X.stride(0),\n W, W.stride(0),\n r, r.stride(0),\n n_cols, eps,\n BLOCK_SIZE = BLOCK_SIZE,\n num_warps = num_warps,\n )\n ctx.eps = eps\n ctx.BLOCK_SIZE = BLOCK_SIZE\n ctx.num_warps = num_warps\n ctx.GEMMA = gemma\n ctx.save_for_backward(X, W, r)\n return Y.view(*shape)\n\n @staticmethod\n def backward(ctx, dY):\n shape = dY.shape\n dim = shape[-1]\n dY = dY.view(-1, dim)\n X, W, r = ctx.saved_tensors\n n_rows, n_cols = dY.shape\n dW = X\n\n _rms_layernorm_backward[(n_rows,)](\n dY, dY.stride(0),\n X, X .stride(0),\n W, W .stride(0),\n r, r .stride(0),\n dW, dW.stride(0),\n n_cols, ctx.eps,\n GEMMA = ctx.GEMMA,\n BLOCK_SIZE = ctx.BLOCK_SIZE,\n num_warps = ctx.num_warps,\n )\n dX = dY.view(*shape)\n return dX, None, None, None\n\ndef fast_rms_layernorm(layernorm, X, gemma = False):\n W = layernorm.weight\n eps = layernorm.variance_epsilon\n out = Fast_RMS_Layernorm.apply(X, W, eps, gemma)\n return out\n", - "description_1": "Use triton language to implement a fast RMS Layernorm kernel and its backward pass. The forward kernel (_rms_layernorm_forward) takes 10 parameters: output tensor Y, its row stride, input tensor X, its row stride, weight tensor W, its row stride, variance tensor r, its row stride, number of columns n_cols, and epsilon eps. It computes the layer normalization using block size BLOCK_SIZE. The backward kernel (_rms_layernorm_backward) takes 12 parameters: gradient tensor dY, its row stride, input tensor X, its row stride, weight tensor W, its row stride, variance tensor r, its row stride, gradient weight tensor dW, its row stride, number of columns n_cols, epsilon eps, GEMMA flag, and block size BLOCK_SIZE. It computes the gradient of the layer normalization. The Fast_RMS_Layernorm class provides a forward and backward method to apply these kernels, and the fast_rms_layernorm function is a utility to apply the Fast_RMS_Layernorm class.", - "description_2": "Use triton language to implement a fast RMS Layernorm kernel with forward and backward passes, and provide a utility function to apply it.", - "difficulty": 3 - }, - { - "code": "import triton\nimport triton.language as tl\nimport torch\n\n@triton.jit\ndef _fg_kernel(e, g, h, n_elements, BLOCK_SIZE : tl.constexpr):\n block_idx = tl.program_id(0)\n offsets = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)\n mask = offsets < n_elements\n\n e_row = tl.load(e + offsets, mask=mask, other=0).to(tl.float32)\n g_row = tl.load(g + offsets, mask=mask, other=0)\n\n # f = e * sigmoid(e)\n f_row = e_row * tl.sigmoid(e_row)\n f_row = f_row.to(g_row.dtype)\n # h = f * g\n h_row = f_row * g_row\n\n # Store h\n tl.store(h + offsets, h_row, mask=mask)\n\ndef swiglu_fg_kernel(e, g):\n batch, seq_len, hd = e.shape\n n_elements = e.numel()\n h = torch.empty((batch, seq_len, hd), dtype=e.dtype, device=\"cuda\")\n grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)\n _fg_kernel[grid](e, g, h, n_elements, BLOCK_SIZE=1024)\n return h\n\n@triton.jit\ndef _DWf_DW_dfg_kernel(DW, e, g, n_elements, BLOCK_SIZE: tl.constexpr):\n block_idx = tl.program_id(0)\n offsets = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)\n mask = offsets < n_elements\n\n DW_row = tl.load(DW + offsets, mask=mask, other=0)\n e_row = tl.load(e + offsets, mask=mask, other=0).to(tl.float32)\n g_row = tl.load(g + offsets, mask=mask, other=0)\n\n # e = e.float()\n # se = 1.0 / (1.0 + torch.exp(-e))\n se_row = tl.sigmoid(e_row)\n # f = (se * e).to(dtype)\n f_row = se_row * e_row\n f_row = f_row.to(DW_row.dtype)\n # h = f * g\n h_row = f_row * g_row\n # df = DW * f\n df_row = DW_row * f_row\n # dg = DW * g\n dg_row = DW_row * g_row\n # de = (dg.float() * se * (1.0 + e * (1.0 - se))).to(dtype)\n de_row = dg_row.to(tl.float32) * se_row * (1.0 + e_row * (1.0 - se_row))\n de_row = de_row.to(DW_row.dtype)\n\n # Store derivatives in buffers\n tl.store(DW + offsets, h_row, mask=mask) # h = f * g\n tl.store(e + offsets, df_row, mask=mask) # df = DW * f\n tl.store(g + offsets, de_row, mask=mask) # de\n\ndef swiglu_DWf_DW_dfg_kernel(DW, e, g):\n batch_seq_len, hd = e.shape\n n_elements = e.numel()\n grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)\n _DWf_DW_dfg_kernel[grid](DW, e, g, n_elements, BLOCK_SIZE=1024)\n return DW, e, g\n", - "description_1": "Use triton language to define two kernels and corresponding invocation functions. The first kernel '_fg_kernel' computes a transformation involving element-wise operations on inputs 'e' and 'g', storing the result in 'h'. The function 'swiglu_fg_kernel' calls this kernel, computing the grid based on input shape. The second kernel '_DWf_DW_dfg_kernel' performs derivative computations for inputs 'DW', 'e', 'g', and updates them in place. The function 'swiglu_DWf_DW_dfg_kernel' invokes this kernel similarly.", - "description_2": "Use triton language to perform element-wise transformations and derivative computations on CUDA tensors using defined kernels and their invocations.", - "difficulty": 2 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef leaky_relu(x):\n x = x + 1\n return tl.where(x >= 0, x, 0.01 * x)\n\n@triton.autotune(\n configs=[\n triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8),\n triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2),\n triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2),\n ],\n key=['M', 'N', 'K'],\n)\n@triton.jit\ndef matmul_kernel(\n a_ptr, b_ptr, c_ptr,\n M, N, K,\n stride_am, stride_ak,\n stride_bk, stride_bn,\n stride_cm, stride_cn,\n BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,\n GROUP_SIZE_M: tl.constexpr,\n ACTIVATION: tl.constexpr\n):\n pid = tl.program_id(axis=0)\n num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)\n num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)\n num_pid_in_group = GROUP_SIZE_M * num_pid_n\n group_id = pid // num_pid_in_group\n first_pid_m = group_id * GROUP_SIZE_M\n group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)\n pid_m = first_pid_m + (pid % group_size_m)\n pid_n = (pid % num_pid_in_group) // group_size_m\n\n offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M\n offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N\n offs_k = tl.arange(0, BLOCK_SIZE_K)\n a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)\n b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)\n\n accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):\n a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)\n b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)\n accumulator += tl.dot(a, b)\n a_ptrs += BLOCK_SIZE_K * stride_ak\n b_ptrs += BLOCK_SIZE_K * stride_bk\n if ACTIVATION == \"leaky_relu\":\n accumulator = leaky_relu(accumulator)\n c = accumulator.to(tl.float16)\n\n offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]\n c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)\n tl.store(c_ptrs, c, mask=c_mask)\n\ndef matmul(a, b, activation=\"\"):\n assert a.shape[1] == b.shape[0], \"Incompatible dimensions\"\n assert a.is_contiguous(), \"Matrix A must be contiguous\"\n assert b.is_contiguous(), \"Matrix B must be contiguous\"\n M, K = a.shape\n K, N = b.shape\n c = torch.empty((M, N), device=a.device, dtype=a.dtype)\n grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), )\n matmul_kernel[grid](\n a, b, c,\n M, N, K,\n a.stride(0), a.stride(1),\n b.stride(0), b.stride(1),\n c.stride(0), c.stride(1),\n ACTIVATION=activation\n )\n return c\n\ntorch.manual_seed(0)\na = torch.randn((512, 512), device='cuda', dtype=torch.float16)\nb = torch.randn((512, 512), device='cuda', dtype=torch.float16)\ntriton_output = matmul(a, b)\ntorch_output = torch.matmul(a, b)\nprint(f\"triton_output={triton_output}\")\nprint(f\"torch_output={torch_output}\")\nif torch.allclose(triton_output, torch_output, atol=1e-2, rtol=0):\n print(\"✅ Triton and Torch match\")\nelse:\n print(\"❌ Triton and Torch differ\")\n", - "description_1": "Use triton language to create a matrix multiplication kernel `matmul_kernel` that computes C = A x B, with optional leaky_relu activation. The kernel accepts pointers to matrices A, B, and C, matrix dimensions M, N, K, stride values for each matrix, and several block size meta-parameters. A function `matmul` is provided to call this kernel with appropriate grid dimensions, checking input matrix compatibility.", - "description_2": "Use triton language to implement a matrix multiplication with leaky_relu support, leveraging optimized block and grid strategies.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.autotune(\n configs=[\n triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8),\n triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2),\n triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2),\n ],\n key=['M', 'N', 'K'],\n)\n@triton.jit\ndef matmul_kernel(\n a_ptr, b_ptr, c_ptr, d_ptr,\n M, N, K,\n stride_am, stride_ak,\n stride_bk, stride_bn,\n stride_cm, stride_cn,\n stride_dm, stride_dn,\n BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,\n GROUP_SIZE_M: tl.constexpr,\n ACTIVATION: tl.constexpr\n):\n pid = tl.program_id(axis=0)\n num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)\n num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)\n num_pid_in_group = GROUP_SIZE_M * num_pid_n\n group_id = pid // num_pid_in_group\n first_pid_m = group_id * GROUP_SIZE_M\n group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)\n pid_m = first_pid_m + (pid % group_size_m)\n pid_n = (pid % num_pid_in_group) // group_size_m\n\n offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M\n offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N\n offs_k = tl.arange(0, BLOCK_SIZE_K)\n a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)\n b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)\n\n accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):\n a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)\n b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)\n accumulator += tl.dot(a, b)\n a_ptrs += BLOCK_SIZE_K * stride_ak\n b_ptrs += BLOCK_SIZE_K * stride_bk\n \n offs_dm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_dn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n d_ptrs = d_ptr + stride_dm * offs_dm[:, None] + stride_dn * offs_dn[None, :]\n d_mask = (offs_dm[:, None] < M) & (offs_dn[None, :] < N)\n d = tl.load(d_ptrs, mask=d_mask, other=0.0)\n\n c = accumulator\n c += d\n\n if ACTIVATION == \"relu\":\n c = relu(c)\n\n c = c.to(tl.float16)\n\n offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]\n c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)\n tl.store(c_ptrs, c, mask=c_mask)\n\n@triton.jit\ndef relu(x):\n return tl.where(x >= 0, x, 0.0)\n\ndef triton_addmm(a, b, d, activation=\"None\"):\n assert a.shape[1] == b.shape[0], \"Incompatible dimensions\"\n assert a.is_contiguous(), \"Matrix A must be contiguous\"\n assert d.is_contiguous(), \"Matrix D must be contiguous\"\n\n M, K = a.shape\n K, N = b.shape\n\n if len(d.shape) == 1:\n d_stride_0 = 0\n d_stride_1 = d.stride(0)\n elif len(d.shape) == 2:\n d_stride_0 = d.stride(0)\n d_stride_1 = d.stride(1)\n\n c = torch.empty((M, N), device=a.device, dtype=a.dtype)\n\n grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), )\n matmul_kernel[grid](\n a, b, c, d,\n M, N, K,\n a.stride(0), a.stride(1),\n b.stride(0), b.stride(1),\n c.stride(0), c.stride(1),\n d_stride_0, d_stride_1,\n ACTIVATION=activation\n )\n return c\n", - "description_1": "Use triton language to define a matrix multiplication kernel (matmul_kernel) and a ReLU function (relu). The matmul_kernel function takes 21 arguments: a_ptr, b_ptr, c_ptr, d_ptr (pointers to matrices A, B, C, D), M, N, K (dimensions of matrices), stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn, stride_dm, stride_dn (stride variables for memory access), and BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K, GROUP_SIZE_M, ACTIVATION (compile-time constants). The relu function takes 1 argument: x (a tensor) and performs ReLU operation. Use triton_addmm function to compute matrix multiplication of a (MxK) and b (KxN), adding d (MxN) with optional activation; it takes 4 arguments: a, b, d (torch tensors), and activation (string).", - "description_2": "Use triton language to implement matrix multiplication with support for optional ReLU activation, optimized for performance using Triton's autotuning and block-level parallelism.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef addmm_sigmoid_kernel(\n a_ptr, b_ptr, c_ptr, sigmoid_ptr, \n M, N, K,\n stride_am, stride_ak,\n stride_bk, stride_bn,\n stride_cm, stride_cn, stride_sigmoidm, stride_sigmoidn,\n BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, \n GROUP_SIZE_M: tl.constexpr,\n SIGMOID_TYPE: tl.constexpr\n):\n \"\"\"Kernel for computing the matmul C = A x B.\n A has shape (M, K), B has shape (K, N) and C has shape (M, N)\n \"\"\"\n pid = tl.program_id(axis=0)\n num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)\n num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)\n num_pid_in_group = GROUP_SIZE_M * num_pid_n\n group_id = pid // num_pid_in_group\n first_pid_m = group_id * GROUP_SIZE_M\n group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)\n pid_m = first_pid_m + (pid % group_size_m)\n pid_n = (pid % num_pid_in_group) // group_size_m\n\n offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M\n offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N\n offs_k = tl.arange(0, BLOCK_SIZE_K)\n a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)\n b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)\n\n accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):\n a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)\n b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)\n accumulator += tl.dot(a, b)\n a_ptrs += BLOCK_SIZE_K * stride_ak\n b_ptrs += BLOCK_SIZE_K * stride_bk\n\n offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]\n c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)\n c = tl.load(c_ptrs, mask=c_mask, other=0.0)\n\n add = accumulator + c\n sigmoid = tl.sigmoid(add)\n\n if SIGMOID_TYPE == \"float16\":\n sigmoid = sigmoid.to(tl.float16)\n offs_sigmoidm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_sigmoidn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n sigmoid_ptrs = sigmoid_ptr + stride_sigmoidm * offs_sigmoidm[:, None] + stride_sigmoidn * offs_sigmoidn[None, :]\n sigmoid_mask = (offs_sigmoidm[:, None] < M) & (offs_sigmoidn[None, :] < N)\n tl.store(sigmoid_ptrs, sigmoid, mask=sigmoid_mask)\n\n\ndef triton_addmm_sigmoid(a, b, c):\n assert a.shape[1] == b.shape[0], \"Incompatible dimensions\"\n M, K = a.shape\n K, N = b.shape\n\n c_stride_0, c_stride_1 = c.stride(0), c.stride(1) if len(c.shape) == 2 else (0, c.stride(0))\n sigmoid = torch.empty((M, N), device=a.device, dtype=a.dtype)\n \n grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), )\n addmm_sigmoid_kernel[grid](\n a, b, c, sigmoid,\n M, N, K,\n a.stride(0), a.stride(1),\n b.stride(0), b.stride(1),\n c_stride_0, c_stride_1,\n sigmoid.stride(0), sigmoid.stride(1),\n SIGMOID_TYPE=str(a.dtype).split('.')[-1]\n )\n return sigmoid\n", - "description_1": "Use triton language to implement a kernel function 'addmm_sigmoid_kernel' that performs matrix multiplication on matrices A and B, adds matrix C, and then applies a sigmoid operation. The kernel accepts pointers to matrices and their dimensions, strides, and meta-parameters for block sizes and group size. It calculates each block of matrix C by mapping program IDs to specific computation blocks, iteratively loading blocks of A and B, performing dot product accumulation, and storing the result with applied sigmoid operation. The wrapper function 'triton_addmm_sigmoid' manages memory allocations and kernel execution with appropriate grid configuration.", - "description_2": "Use triton language to implement a matrix multiplication followed by addition and sigmoid using 'addmm_sigmoid_kernel'. Manage kernel execution and memory with 'triton_addmm_sigmoid'.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.autotune(\n configs=[\n triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8),\n triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2),\n triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2),\n ],\n key=['M', 'N', 'K'],\n)\n@triton.jit\ndef mm_dp_relu_bp_kernel(\n a_ptr, b_ptr, c_ptr, d_ptr, mul_2_ptr, \n M, N, K,\n stride_am, stride_ak,\n stride_bk, stride_bn,\n stride_cm, stride_cn,\n stride_dm, stride_dn, stride_mul_2m, stride_mul_2n,\n BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,\n GROUP_SIZE_M: tl.constexpr,\n MUL_2_TYPE: tl.constexpr\n):\n \"\"\"Kernel for computing the matmul C = A x B.\n A has shape (M, K), B has shape (K, N) and C has shape (M, N)\n \"\"\"\n pid = tl.program_id(axis=0)\n num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)\n num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)\n num_pid_in_group = GROUP_SIZE_M * num_pid_n\n group_id = pid // num_pid_in_group\n first_pid_m = group_id * GROUP_SIZE_M\n group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)\n pid_m = first_pid_m + (pid % group_size_m)\n pid_n = (pid % num_pid_in_group) // group_size_m\n\n offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M\n offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N\n offs_k = tl.arange(0, BLOCK_SIZE_K)\n a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)\n b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)\n\n accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):\n a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)\n b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)\n accumulator += tl.dot(a, b)\n a_ptrs += BLOCK_SIZE_K * stride_ak\n b_ptrs += BLOCK_SIZE_K * stride_bk\n\n offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]\n c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)\n c = tl.load(c_ptrs, mask=c_mask, other=0.0)\n\n offs_dm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_dn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n d_ptrs = d_ptr + stride_dm * offs_dm[:, None] + stride_dn * offs_dn[None, :]\n d_mask = (offs_dm[:, None] < M) & (offs_dn[None, :] < N)\n d = tl.load(d_ptrs, mask=d_mask, other=0.0)\n\n mul = accumulator * c\n mul_1 = mul * 1.0\n ne = d != 0\n mul_2 = mul_1 * ne\n\n if MUL_2_TYPE == \"float16\":\n mul_2 = mul_2.to(tl.float16)\n offs_mul_2m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_mul_2n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n mul_2_ptrs = mul_2_ptr + stride_mul_2m * offs_mul_2m[:, None] + stride_mul_2n * offs_mul_2n[None, :]\n mul_2_mask = (offs_mul_2m[:, None] < M) & (offs_mul_2n[None, :] < N)\n tl.store(mul_2_ptrs, mul_2, mask=mul_2_mask)\n\n\ndef triton_mm_dp_relu_bp(a, b, c, d):\n assert a.shape[1] == b.shape[0], \"Incompatible dimensions\"\n M, K = a.shape\n K, N = b.shape\n\n c_stride_0, c_stride_1 = (c.stride(0), c.stride(1)) if len(c.shape) == 2 else (0, c.stride(0)) if len(c.shape) == 1 else (0, 0)\n d_stride_0, d_stride_1 = (d.stride(0), d.stride(1)) if len(d.shape) == 2 else (0, d.stride(0)) if len(d.shape) == 1 else (0, 0)\n mul_2 = torch.empty((M, N), device=a.device, dtype=a.dtype)\n \n grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), )\n mm_dp_relu_bp_kernel[grid](\n a, b, c, d, mul_2,\n M, N, K,\n a.stride(0), a.stride(1),\n b.stride(0), b.stride(1),\n c_stride_0, c_stride_1,\n d_stride_0, d_stride_1,\n mul_2.stride(0), mul_2.stride(1),\n MUL_2_TYPE=str(a.dtype).split('.')[-1]\n )\n return mul_2\n", - "description_1": "Use triton language to implement a matrix multiplication kernel with additional ReLU-like backpropagation features. The kernel has 18 parameters: 5 pointers to input and output matrices, 3 integers representing matrix dimensions (M, N, K), 10 integers for stride information of each matrix, and 4 meta-parameters defining block and group sizes and type configuration. The wrapper function 'triton_mm_dp_relu_bp' prepares and calls this kernel using PyTorch tensors as inputs.", - "description_2": "Use triton language to implement a matrix multiplication with ReLU backpropagation in a CUDA environment, optimizing memory access patterns and computational efficiency.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef xmlcnn_loss_kernel(\n # Pointers to matrices\n a_ptr, b_ptr, c_ptr, d_ptr, e_ptr, mul_2_ptr, \n # Matrix dimensions\n M, N, K,\n # Stride variables\n stride_am, stride_ak,\n stride_bk, stride_bn,\n stride_cm, stride_cn,\n stride_dm, stride_dn,\n stride_em, stride_en,\n stride_mul_2m, stride_mul_2n,\n # Meta-parameters\n BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, \n GROUP_SIZE_M: tl.constexpr,\n MUL_2_TYPE: tl.constexpr\n):\n \"\"\"Kernel for computing the matmul C = A x B.\n A has shape (M, K), B has shape (K, N) and C has shape (M, N)\n \"\"\"\n pid = tl.program_id(axis=0)\n num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)\n num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)\n num_pid_in_group = GROUP_SIZE_M * num_pid_n\n group_id = pid // num_pid_in_group\n first_pid_m = group_id * GROUP_SIZE_M\n group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)\n pid_m = first_pid_m + (pid % group_size_m)\n pid_n = (pid % num_pid_in_group) // group_size_m\n\n offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M\n offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N\n offs_k = tl.arange(0, BLOCK_SIZE_K)\n a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)\n b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)\n\n accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):\n a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)\n b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)\n accumulator += tl.dot(a, b)\n a_ptrs += BLOCK_SIZE_K * stride_ak\n b_ptrs += BLOCK_SIZE_K * stride_bk\n\n offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]\n c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)\n c = tl.load(c_ptrs, mask=c_mask, other=0.0)\n\n offs_dm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_dn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n d_ptrs = d_ptr + stride_dm * offs_dm[:, None] + stride_dn * offs_dn[None, :]\n d_mask = (offs_dm[:, None] < M) & (offs_dn[None, :] < N)\n d = tl.load(d_ptrs, mask=d_mask, other=0.0)\n\n offs_em = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_en = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n e_ptrs = e_ptr + stride_em * offs_em[:, None] + stride_en * offs_en[None, :]\n e_mask = (offs_em[:, None] < M) & (offs_en[None, :] < N)\n e = tl.load(e_ptrs, mask=e_mask, other=0.0)\n\n add = accumulator + c\n sigmoid = tl.sigmoid(add)\n mul = d * -1\n add_1 = sigmoid + mul\n mul_1 = e * 0.0078125\n mul_2 = add_1 * mul_1\n\n if MUL_2_TYPE == \"float16\":\n mul_2 = mul_2.to(tl.float16)\n offs_mul_2m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_mul_2n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n mul_2_ptrs = mul_2_ptr + stride_mul_2m * offs_mul_2m[:, None] + stride_mul_2n * offs_mul_2n[None, :]\n mul_2_mask = (offs_mul_2m[:, None] < M) & (offs_mul_2n[None, :] < N)\n tl.store(mul_2_ptrs, mul_2, mask=mul_2_mask)\n\n\ndef triton_xmlcnn_loss(a, b, c, d, e):\n assert a.shape[1] == b.shape[0], \"Incompatible dimensions\"\n M, K = a.shape\n K, N = b.shape\n\n c_stride_0, c_stride_1 = (c.stride(0), c.stride(1)) if len(c.shape) == 2 else (0, c.stride(0)) if len(c.shape) == 1 else (0, 0)\n d_stride_0, d_stride_1 = (d.stride(0), d.stride(1)) if len(d.shape) == 2 else (0, d.stride(0)) if len(d.shape) == 1 else (0, 0)\n e_stride_0, e_stride_1 = (e.stride(0), e.stride(1)) if len(e.shape) == 2 else (0, e.stride(0)) if len(e.shape) == 1 else (0, 0)\n mul_2 = torch.empty((M, N), device=a.device, dtype=a.dtype)\n \n grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), )\n xmlcnn_loss_kernel[grid](\n a, b, c, d, e, mul_2,\n M, N, K,\n a.stride(0), a.stride(1),\n b.stride(0), b.stride(1),\n c_stride_0, c_stride_1,\n d_stride_0, d_stride_1,\n e_stride_0, e_stride_1,\n mul_2.stride(0), mul_2.stride(1),\n MUL_2_TYPE=str(a.dtype).split('.')[-1]\n )\n return mul_2\n", - "description_1": "Use triton language to create a kernel (xmlcnn_loss_kernel) and its calling function (triton_xmlcnn_loss) for performing a matrix multiplication operation followed by additional element-wise operations on result matrices. The kernel function has 17 parameters: pointers to input and output matrices, matrix dimensions, stride values for each matrix, and meta-parameters for block and group sizes in the grid. The calling function (triton_xmlcnn_loss) initializes grid dimensions, prepares arguments, and invokes the kernel for execution.", - "description_2": "Use triton language to perform matrix multiplication and element-wise operations using a kernel with 17 parameters for matrix pointers, dimensions, strides, and meta-values, and a calling function to execute the kernel.", - "difficulty": 3 - }, - { - "code": "import triton\nimport triton.language as tl\n\n@triton.jit\ndef _fwd_kernel_inner(\n acc,\n l_i,\n m_i,\n q,\n Q,\n k_block_col_idx,\n layout_col_ptr,\n layout_col_stride_h,\n layout_col_stride_m,\n k_ptrs,\n v_ptrs,\n off_h,\n offs_m,\n offs_n,\n offs_d,\n stride_kt,\n stride_vt,\n sm_scale,\n k_seqlen,\n past_len,\n LAST_K_BLOCK: tl.constexpr,\n BLOCK_M_LOADING: tl.constexpr,\n BLOCK_N: tl.constexpr,\n D_HEAD: tl.constexpr,\n EVEN_D: tl.constexpr,\n M_LT_N: tl.constexpr,\n):\n k_block_id = tl.load(layout_col_ptr + off_h * layout_col_stride_h +\n k_block_col_idx * layout_col_stride_m).to(tl.int32)\n start_n = k_block_id * BLOCK_N\n if LAST_K_BLOCK:\n if EVEN_D:\n k = tl.load(\n k_ptrs + start_n * stride_kt,\n mask=offs_n[None, :] + start_n < k_seqlen,\n )\n else:\n k = tl.load(\n k_ptrs + start_n * stride_kt,\n mask=(offs_n[None, :] + start_n < k_seqlen) &\n (offs_d[:, None] < D_HEAD),\n )\n else:\n if EVEN_D:\n k = tl.load(k_ptrs + start_n * stride_kt)\n else:\n k = tl.load(k_ptrs + start_n * stride_kt,\n mask=offs_d[:, None] < D_HEAD)\n\n qk = tl.zeros([BLOCK_M_LOADING, BLOCK_N], dtype=tl.float32)\n qk += tl.dot(q, k)\n qk *= sm_scale\n\n if LAST_K_BLOCK | M_LT_N:\n qk += tl.where(\n offs_m[:, None] + past_len >= (start_n + offs_n[None, :]),\n 0,\n float(\"-inf\"),\n )\n\n m_ij = tl.maximum(m_i, tl.max(qk, 1))\n p = tl.math.exp2(qk - m_ij[:, None])\n l_ij = tl.sum(p, 1)\n alpha = tl.math.exp2(m_i - m_ij)\n acc = acc * alpha[:, None]\n m_i = m_ij\n l_i = l_i * alpha + l_ij\n\n p = p.to(Q.dtype.element_ty)\n if LAST_K_BLOCK:\n if EVEN_D:\n v = tl.load(\n v_ptrs + start_n * stride_vt,\n mask=offs_n[:, None] + start_n < k_seqlen,\n )\n else:\n v = tl.load(\n v_ptrs + start_n * stride_vt,\n mask=(offs_n[:, None] + start_n < k_seqlen) &\n (offs_d[None, :] < D_HEAD),\n )\n else:\n if EVEN_D:\n v = tl.load(v_ptrs + start_n * stride_vt)\n else:\n v = tl.load(v_ptrs + start_n * stride_vt,\n mask=offs_d[None, :] < D_HEAD)\n\n acc += tl.dot(p, v)\n\n return acc, l_i, m_i\n\n\n@triton.heuristics({\n \"M_LT_N\":\n lambda kwargs: kwargs[\"BLOCK_M\"] < kwargs[\"BLOCK_N\"],\n})\n@triton.jit\ndef _fwd_kernel_batch_inference(\n Q,\n K,\n V,\n Out,\n sm_scale,\n q_batch_starts,\n q_batch_ends,\n k_batch_starts,\n k_batch_ends,\n q_batch_ids,\n q_start_sids,\n stride_qb,\n stride_qt,\n stride_qh,\n stride_qd,\n stride_kb,\n stride_kt,\n stride_kh,\n stride_kd,\n stride_vb,\n stride_vt,\n stride_vh,\n stride_vd,\n stride_ob,\n stride_ot,\n stride_oh,\n stride_od,\n layout_crow_ptr,\n layout_col_ptr,\n layout_crow_stride_h,\n layout_crow_stride_m,\n layout_col_stride_h,\n layout_col_stride_m,\n q_k_ratio,\n HAS_BATCH_DIM: tl.constexpr,\n D_HEAD: tl.constexpr,\n BLOCK_M: tl.constexpr,\n BLOCK_N: tl.constexpr,\n BLOCK_D: tl.constexpr,\n BLOCK_M_LOADING: tl.constexpr,\n EVEN_D: tl.constexpr,\n M_LT_N: tl.constexpr,\n):\n off_zm = tl.program_id(0)\n off_h = tl.program_id(1)\n\n off_h_for_kv = off_h // q_k_ratio\n\n if HAS_BATCH_DIM:\n off_z = tl.program_id(2)\n Q += off_z * stride_qb\n K += off_z * stride_kb\n V += off_z * stride_vb\n Out += off_z * stride_ob\n start_m = off_zm\n q_start_sid = start_m * BLOCK_M\n else:\n off_z = tl.load(q_batch_ids + off_zm).to(tl.int32)\n q_start_sid = tl.load(q_start_sids + off_zm)\n start_m = q_start_sid // BLOCK_M\n\n offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M_LOADING)\n offs_n = tl.arange(0, BLOCK_N)\n offs_d = tl.arange(0, BLOCK_D)\n\n q_cu_start = tl.load(q_batch_starts + off_z).to(tl.int32)\n q_seqlen = tl.load(q_batch_ends + off_z).to(tl.int32) - q_cu_start\n k_cu_start = tl.load(k_batch_starts + off_z).to(tl.int32)\n k_seqlen = tl.load(k_batch_ends + off_z).to(tl.int32) - k_cu_start\n past_len = k_seqlen - q_seqlen\n\n Q += q_cu_start * stride_qt + off_h * stride_qh\n K += k_cu_start * stride_kt + off_h_for_kv * stride_kh\n V += k_cu_start * stride_vt + off_h_for_kv * stride_vh\n Out += q_cu_start * stride_ot + off_h * stride_oh\n\n q_pbid = (past_len + q_start_sid) // BLOCK_M\n\n if EVEN_D:\n q = tl.load(\n Q + offs_m[:, None] * stride_qt + offs_d[None, :] * stride_qd,\n mask=offs_m[:, None] < q_seqlen,\n )\n else:\n q = tl.load(\n Q + offs_m[:, None] * stride_qt + offs_d[None, :] * stride_qd,\n mask=(offs_m[:, None] < q_seqlen) & (offs_d[None, :] < D_HEAD),\n other=0,\n )\n\n sparse_crow_ptr = (layout_crow_ptr + off_h * layout_crow_stride_h +\n q_pbid * layout_crow_stride_m)\n\n k_block_start = tl.load(sparse_crow_ptr).to(tl.int32)\n k_block_end = tl.load(sparse_crow_ptr + 1).to(tl.int32)\n\n m_i = tl.zeros([BLOCK_M_LOADING], dtype=tl.float32) - float(\"inf\")\n l_i = tl.zeros([BLOCK_M_LOADING], dtype=tl.float32)\n acc = tl.zeros([BLOCK_M_LOADING, BLOCK_D], dtype=tl.float32)\n\n k_ptrs = K + offs_n[None, :] * stride_kt + offs_d[:, None] * stride_kd\n v_ptrs = V + offs_n[:, None] * stride_vt + offs_d[None, :] * stride_vd\n\n sm_scale *= (\n 1.44269504 \n )\n\n for k_block_col_idx in range(k_block_start, k_block_end - 1):\n acc, l_i, m_i = _fwd_kernel_inner(\n acc,\n l_i,\n m_i,\n q,\n Q,\n k_block_col_idx,\n layout_col_ptr,\n layout_col_stride_h,\n layout_col_stride_m,\n k_ptrs,\n v_ptrs,\n off_h,\n offs_m,\n offs_n,\n offs_d,\n stride_kt,\n stride_vt,\n sm_scale,\n k_seqlen,\n past_len,\n False,\n BLOCK_M_LOADING,\n BLOCK_N,\n D_HEAD,\n EVEN_D,\n M_LT_N,\n )\n\n acc, l_i, m_i = _fwd_kernel_inner(\n acc,\n l_i,\n m_i,\n q,\n Q,\n k_block_end - 1,\n layout_col_ptr,\n layout_col_stride_h,\n layout_col_stride_m,\n k_ptrs,\n v_ptrs,\n off_h,\n offs_m,\n offs_n,\n offs_d,\n stride_kt,\n stride_vt,\n sm_scale,\n k_seqlen,\n past_len,\n True,\n BLOCK_M_LOADING,\n BLOCK_N,\n D_HEAD,\n EVEN_D,\n M_LT_N,\n )\n\n m_i += tl.math.log2(l_i)\n acc = acc / l_i[:, None]\n\n if EVEN_D:\n tl.store(\n Out + offs_m[:, None] * stride_ot + offs_d[None, :] * stride_od,\n acc,\n mask=offs_m[:, None] < q_seqlen,\n )\n else:\n tl.store(\n Out + offs_m[:, None] * stride_ot + offs_d[None, :] * stride_od,\n acc,\n mask=(offs_m[:, None] < q_seqlen) & (offs_d[None, :] < D_HEAD),\n )\n", - "description_1": "Use triton language to implement two kernels: '_fwd_kernel_inner' with 21 parameters for performing the core attention computation in a single block and '_fwd_kernel_batch_inference' with 54 parameters for processing a batch of queries and keys/values, taking care of loading necessary data and handling block-level operations.", - "description_2": "Use triton language to implement an attention mechanism with two kernels to handle core attention block computation and batch processing for queries and keys/values.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\nfrom vllm.platforms import current_platform\n\nif triton.__version__ >= \"2.1.0\":\n\n # Triton kernel for forward attention with no padding\n @triton.jit\n def _fwd_kernel(\n Q, K, V, K_cache, V_cache, B_Loc, sm_scale, k_scale, v_scale, B_Start_Loc, B_Seqlen, B_Ctxlen, block_size, x, Out,\n stride_b_loc_b, stride_b_loc_s, stride_qbs, stride_qh, stride_qd, stride_kbs, stride_kh, stride_kd, stride_vbs,\n stride_vh, stride_vd, stride_obs, stride_oh, stride_od, stride_k_cache_bs, stride_k_cache_h, stride_k_cache_d,\n stride_k_cache_bl, stride_k_cache_x, stride_v_cache_bs, stride_v_cache_h, stride_v_cache_d, stride_v_cache_bl,\n num_queries_per_kv: int, BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_DMODEL_PADDED: tl.constexpr,\n BLOCK_N: tl.constexpr, SLIDING_WINDOW: tl.constexpr):\n # Kernel implementation details are omitted for brevity...\n\n # Triton kernel for forward attention with alibi bias\n @triton.jit\n def _fwd_kernel_alibi(\n Q, K, V, K_cache, V_cache, B_Loc, sm_scale, k_scale, v_scale, B_Start_Loc, B_Seqlen, B_Ctxlen, Alibi_slopes,\n block_size, x, Out, stride_b_loc_b, stride_b_loc_s, stride_qbs, stride_qh, stride_qd, stride_kbs, stride_kh,\n stride_kd, stride_vbs, stride_vh, stride_vd, stride_obs, stride_oh, stride_od, stride_k_cache_bs, stride_k_cache_h,\n stride_k_cache_d, stride_k_cache_bl, stride_k_cache_x, stride_v_cache_bs, stride_v_cache_h, stride_v_cache_d,\n stride_v_cache_bl, num_queries_per_kv: int, BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,\n BLOCK_DMODEL_PADDED: tl.constexpr, BLOCK_N: tl.constexpr):\n # Kernel implementation details are omitted for brevity...\n\n # Function to execute context attention forward pass\n @torch.inference_mode()\n def context_attention_fwd(\n q, k, v, o, kv_cache_dtype: str, k_cache, v_cache, b_loc, b_start_loc, b_seq_len, b_ctx_len, max_input_len,\n k_scale: float = 1.0, v_scale: float = 1.0, alibi_slopes=None, sliding_window=None):\n # Determine device capability and configure block size and number of warps\n cap = current_platform.get_device_capability()\n BLOCK = 128 if cap[0] >= 8 else 64\n NUM_WARPS = 8\n if q.dtype is torch.float32:\n BLOCK = BLOCK // 2\n\n # Handling of FP8 tensor conversion\n if \"fp8\" in kv_cache_dtype:\n target_dtype = torch.float8_e4m3fn if kv_cache_dtype in (\"fp8\", \"fp8_e4m3\") else torch.float8_e5m2\n k_cache = k_cache.view(target_dtype)\n v_cache = v_cache.view(target_dtype)\n\n # Define scales and prepare for grid launch\n sm_scale = 1.0 / (q.shape[-1]**0.5)\n batch, head = b_seq_len.shape[0], q.shape[1]\n num_queries_per_kv = q.shape[1] // k.shape[1]\n grid = (batch, head, triton.cdiv(max_input_len, BLOCK))\n\n # Launch appropriate Triton kernel\n if alibi_slopes is not None:\n _fwd_kernel_alibi[grid](\n q, k, v, k_cache, v_cache, b_loc, sm_scale, k_scale, v_scale, b_start_loc, b_seq_len, b_ctx_len,\n alibi_slopes, v_cache.shape[3], k_cache.shape[4], o, b_loc.stride(0), b_loc.stride(1), q.stride(0),\n q.stride(1), q.stride(2), k.stride(0), k.stride(1), k.stride(2), v.stride(0), v.stride(1), v.stride(2),\n o.stride(0), o.stride(1), o.stride(2), k_cache.stride(0), k_cache.stride(1), k_cache.stride(2),\n k_cache.stride(3), k_cache.stride(4), v_cache.stride(0), v_cache.stride(1), v_cache.stride(2),\n v_cache.stride(3), num_queries_per_kv=num_queries_per_kv, BLOCK_M=BLOCK, BLOCK_DMODEL=k.shape[-1],\n BLOCK_DMODEL_PADDED=triton.next_power_of_2(k.shape[-1]), BLOCK_N=BLOCK, num_warps=NUM_WARPS, num_stages=1)\n else:\n _fwd_kernel[grid](\n q, k, v, k_cache, v_cache, b_loc, sm_scale, k_scale, v_scale, b_start_loc, b_seq_len, b_ctx_len,\n v_cache.shape[3], k_cache.shape[4], o, b_loc.stride(0), b_loc.stride(1), q.stride(0), q.stride(1),\n q.stride(2), k.stride(0), k.stride(1), k.stride(2), v.stride(0), v.stride(1), v.stride(2), o.stride(0),\n o.stride(1), o.stride(2), k_cache.stride(0), k_cache.stride(1), k_cache.stride(2), k_cache.stride(3),\n k_cache.stride(4), v_cache.stride(0), v_cache.stride(1), v_cache.stride(2), v_cache.stride(3),\n num_queries_per_kv=num_queries_per_kv, BLOCK_M=BLOCK, BLOCK_DMODEL=k.shape[-1],\n BLOCK_DMODEL_PADDED=triton.next_power_of_2(k.shape[-1]), BLOCK_N=BLOCK, SLIDING_WINDOW=sliding_window,\n num_warps=NUM_WARPS, num_stages=1)\n return\n", - "description_1": "Use triton language to implement forward attention kernels with optional alibi bias. The kernels process input queries, keys, and values to compute output based on scaled dot-product attention. They support sliding window masking and can handle FP8 tensor formats. The kernel launch configuration is determined based on device capabilities.", - "description_2": "Use triton language to create a forward attention mechanism with optional alibi bias, supporting sliding window and FP8 data format.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\ntorch_dtype: tl.constexpr = torch.float16\n\n@triton.jit\ndef cdiv_fn(x, y):\n return (x + y - 1) // y\n\n@triton.jit\ndef load_fn(block_ptr, first, second, pad):\n if first and second:\n tensor = tl.load(block_ptr, boundary_check=(0, 1), padding_option=pad)\n elif first:\n tensor = tl.load(block_ptr, boundary_check=(0, ), padding_option=pad)\n elif second:\n tensor = tl.load(block_ptr, boundary_check=(1, ), padding_option=pad)\n else:\n tensor = tl.load(block_ptr)\n return tensor\n\n@triton.jit\ndef _attn_fwd_inner(\n acc,\n l_i,\n m_i,\n q,\n K_block_ptr,\n V_block_ptr,\n start_m,\n actual_seqlen_k,\n dropout_p,\n philox_seed,\n batch_philox_offset,\n encoded_softmax_block_ptr,\n block_min,\n block_max,\n offs_n_causal,\n masked_blocks,\n n_extra_tokens,\n bias_ptr,\n IS_CAUSAL: tl.constexpr,\n BLOCK_M: tl.constexpr,\n BLOCK_DMODEL: tl.constexpr,\n BLOCK_N: tl.constexpr,\n OFFS_M: tl.constexpr,\n OFFS_N: tl.constexpr,\n PRE_LOAD_V: tl.constexpr,\n MASK_STEPS: tl.constexpr,\n ENABLE_DROPOUT: tl.constexpr,\n RETURN_ENCODED_SOFTMAX: tl.constexpr,\n PADDED_HEAD: tl.constexpr,\n):\n for start_n in range(block_min, block_max, BLOCK_N):\n k = load_fn(\n K_block_ptr,\n PADDED_HEAD,\n MASK_STEPS and (n_extra_tokens != 0),\n \"zero\",\n )\n if PRE_LOAD_V:\n v = load_fn(\n V_block_ptr,\n MASK_STEPS and (n_extra_tokens != 0),\n PADDED_HEAD,\n \"zero\",\n )\n qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)\n if MASK_STEPS:\n if (start_n + BLOCK_N == block_max) and (n_extra_tokens != 0):\n boundary_m = tl.full([BLOCK_M],\n actual_seqlen_k,\n dtype=tl.int32)\n size_n = start_n + OFFS_N[None, :]\n mask = size_n < boundary_m[:, None]\n qk = tl.where(mask, qk, float(\"-inf\"))\n if IS_CAUSAL:\n causal_boundary = start_n + offs_n_causal\n causal_mask = OFFS_M[:, None] >= causal_boundary[None, :]\n qk = tl.where(causal_mask, qk, float(\"-inf\"))\n qk += tl.dot(q, k)\n if bias_ptr is not None:\n bias = load_fn(bias_ptr, False, MASK_STEPS\n and (n_extra_tokens != 0), \"zero\")\n qk += bias * 1.44269504089\n m_ij = tl.maximum(m_i, tl.max(qk, 1))\n qk = qk - m_ij[:, None]\n p = tl.math.exp2(qk)\n l_ij = tl.sum(p, 1)\n if ENABLE_DROPOUT:\n philox_offset = (batch_philox_offset +\n start_m * BLOCK_M * actual_seqlen_k + start_n -\n BLOCK_N)\n keep = dropout_mask(\n philox_seed,\n philox_offset,\n dropout_p,\n BLOCK_M,\n BLOCK_N,\n actual_seqlen_k,\n )\n if RETURN_ENCODED_SOFTMAX:\n tl.store(\n encoded_softmax_block_ptr,\n tl.where(keep, p,\n -p).to(encoded_softmax_block_ptr.type.element_ty),\n )\n p = tl.where(keep, p, 0.0)\n elif RETURN_ENCODED_SOFTMAX:\n tl.store(\n encoded_softmax_block_ptr,\n p.to(encoded_softmax_block_ptr.type.element_ty),\n )\n alpha = tl.math.exp2(m_i - m_ij)\n acc = acc * alpha[:, None]\n if not PRE_LOAD_V:\n v = load_fn(\n V_block_ptr,\n MASK_STEPS and (n_extra_tokens != 0),\n PADDED_HEAD,\n \"zero\",\n )\n l_i = l_i * alpha + l_ij\n m_i = m_ij\n acc += tl.dot(p.to(V_block_ptr.type.element_ty), v)\n V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0))\n K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N))\n if bias_ptr is not None:\n bias_ptr = tl.advance(bias_ptr, (0, BLOCK_N))\n if RETURN_ENCODED_SOFTMAX:\n encoded_softmax_block_ptr = tl.advance(encoded_softmax_block_ptr,\n (0, BLOCK_N))\n return acc, l_i, m_i\n\n@triton.autotune(\n configs=[\n triton.Config(\n {\n \"BLOCK_M\": 256,\n \"BLOCK_N\": 64,\n \"waves_per_eu\": 2,\n \"PRE_LOAD_V\": False,\n },\n num_stages=1,\n num_warps=8,\n ),\n triton.Config(\n {\n \"BLOCK_M\": 128,\n \"BLOCK_N\": 128,\n \"waves_per_eu\": 2,\n \"PRE_LOAD_V\": False,\n },\n num_stages=1,\n num_warps=4,\n ),\n triton.Config(\n {\n \"BLOCK_M\": 256,\n \"BLOCK_N\": 128,\n \"waves_per_eu\": 2,\n \"PRE_LOAD_V\": False,\n },\n num_stages=1,\n num_warps=8,\n ),\n triton.Config(\n {\n \"BLOCK_M\": 128,\n \"BLOCK_N\": 64,\n \"waves_per_eu\": 1,\n \"PRE_LOAD_V\": False,\n },\n num_stages=1,\n num_warps=4,\n ),\n triton.Config(\n {\n \"BLOCK_M\": 128,\n \"BLOCK_N\": 64,\n \"waves_per_eu\": 3,\n \"PRE_LOAD_V\": True,\n },\n num_stages=1,\n num_warps=4,\n ),\n triton.Config(\n {\n \"BLOCK_M\": 128,\n \"BLOCK_N\": 64,\n \"waves_per_eu\": 3,\n \"PRE_LOAD_V\": False,\n },\n num_stages=1,\n num_warps=4,\n ),\n triton.Config(\n {\n \"BLOCK_M\": 64,\n \"BLOCK_N\": 64,\n \"waves_per_eu\": 4,\n \"PRE_LOAD_V\": False,\n },\n num_stages=1,\n num_warps=8,\n ),\n triton.Config(\n {\n \"BLOCK_M\": 32,\n \"BLOCK_N\": 32,\n \"waves_per_eu\": 4,\n \"PRE_LOAD_V\": False,\n },\n num_stages=1,\n num_warps=8,\n ),\n triton.Config(\n {\n \"BLOCK_M\": 16,\n \"BLOCK_N\": 16,\n \"waves_per_eu\": 1,\n \"PRE_LOAD_V\": False,\n },\n num_stages=1,\n num_warps=4,\n ),\n ],\n key=['IS_CAUSAL', 'dropout_p', 'BLOCK_DMODEL'],\n)\n@triton.jit\ndef attn_fwd(\n Q,\n K,\n V,\n bias,\n sm_scale,\n L,\n Out,\n stride_qz,\n stride_qh,\n stride_qm,\n stride_qk,\n stride_kz,\n stride_kh,\n stride_kn,\n stride_kk,\n stride_vz,\n stride_vh,\n stride_vk,\n stride_vn,\n stride_oz,\n stride_oh,\n stride_om,\n stride_on,\n stride_bz,\n stride_bh,\n stride_bm,\n stride_bn,\n cu_seqlens_q,\n cu_seqlens_k,\n dropout_p,\n philox_seed,\n philox_offset_base,\n encoded_softmax,\n HQ: tl.constexpr,\n HK: tl.constexpr,\n ACTUAL_BLOCK_DMODEL: tl.constexpr,\n MAX_SEQLENS_Q: tl.constexpr,\n MAX_SEQLENS_K: tl.constexpr,\n VARLEN: tl.constexpr,\n IS_CAUSAL: tl.constexpr,\n BLOCK_M: tl.constexpr,\n BLOCK_DMODEL: tl.constexpr,\n BLOCK_N: tl.constexpr,\n PRE_LOAD_V: tl.constexpr,\n BIAS_TYPE: tl.constexpr,\n ENABLE_DROPOUT: tl.constexpr,\n RETURN_ENCODED_SOFTMAX: tl.constexpr,\n):\n start_m = tl.program_id(0)\n off_h_q = tl.program_id(1)\n off_z = tl.program_id(2)\n offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)\n offs_n = tl.arange(0, BLOCK_N)\n if VARLEN:\n cu_seqlens_q_start = tl.load(cu_seqlens_q + off_z)\n cu_seqlens_q_end = tl.load(cu_seqlens_q + off_z + 1)\n seqlen_q = cu_seqlens_q_end - cu_seqlens_q_start\n if start_m * BLOCK_M > seqlen_q:\n return\n cu_seqlens_k_start = tl.load(cu_seqlens_k + off_z)\n cu_seqlens_k_end = tl.load(cu_seqlens_k + off_z + 1)\n seqlen_k = cu_seqlens_k_end - cu_seqlens_k_start\n else:\n cu_seqlens_q_start = 0\n cu_seqlens_k_start = 0\n seqlen_q = MAX_SEQLENS_Q\n seqlen_k = MAX_SEQLENS_K\n\n n_blocks = cdiv_fn(seqlen_k, BLOCK_N)\n if IS_CAUSAL:\n n_blocks_seqlen = cdiv_fn(\n (start_m + 1) * BLOCK_M + seqlen_k - seqlen_q, BLOCK_N)\n n_blocks = min(n_blocks, n_blocks_seqlen)\n if n_blocks <= 0:\n o_offset = (off_z * stride_oz + cu_seqlens_q_start * stride_om +\n off_h_q * stride_oh)\n O_block_ptr = tl.make_block_ptr(\n base=Out + o_offset,\n shape=(seqlen_q, BLOCK_DMODEL),\n strides=(stride_om, stride_on),\n offsets=(start_m * BLOCK_M, 0),\n block_shape=(BLOCK_M, BLOCK_DMODEL),\n order=(1, 0),\n )\n acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=Out.type.element_ty)\n return\n\n GROUP_SIZE: tl.constexpr = HQ // HK\n off_h_k = off_h_q // GROUP_SIZE if GROUP_SIZE != 1 else off_h_q\n\n n_extra_tokens = 0\n if seqlen_k < BLOCK_N:\n n_extra_tokens = BLOCK_N - seqlen_k\n elif seqlen_k % BLOCK_N:\n n_extra_tokens = seqlen_k % BLOCK_N\n padded_head = ACTUAL_BLOCK_DMODEL != BLOCK_DMODEL\n\n q_offset = (off_z * stride_qz + off_h_q * stride_qh +\n cu_seqlens_q_start * stride_qm)\n Q_block_ptr = tl.make_block_ptr(\n base=Q + q_offset,\n shape=(seqlen_q, ACTUAL_BLOCK_DMODEL),\n strides=(stride_qm, stride_qk),\n offsets=(start_m * BLOCK_M, 0),\n block_shape=(BLOCK_M, BLOCK_DMODEL),\n order=(1, 0),\n )\n k_offset = (off_z * stride_kz + off_h_k * stride_kh +\n cu_seqlens_k_start * stride_kn)\n K_block_ptr = tl.make_block_ptr(\n base=K + k_offset,\n shape=(ACTUAL_BLOCK_DMODEL, seqlen_k),\n strides=(stride_kk, stride_kn),\n offsets=(0, 0),\n block_shape=(BLOCK_DMODEL, BLOCK_N),\n order=(0, 1),\n )\n v_offset = (off_z * stride_vz + off_h_k * stride_vh +\n cu_seqlens_k_start * stride_vk)\n V_block_ptr = tl.make_block_ptr(\n base=V + v_offset,\n shape=(seqlen_k, ACTUAL_BLOCK_DMODEL),\n strides=(stride_vk, stride_vn),\n offsets=(0, 0),\n block_shape=(BLOCK_N, BLOCK_DMODEL),\n order=(1, 0),\n )\n if BIAS_TYPE != 0:\n bias_ptr = tl.make_block_ptr(\n base=bias + off_h_q * stride_bh,\n shape=(seqlen_q, seqlen_k),\n strides=(stride_bm, stride_bn),\n offsets=(start_m * BLOCK_M, 0),\n block_shape=(BLOCK_M, BLOCK_N),\n order=(1, 0),\n )\n else:\n bias_ptr = None\n if ENABLE_DROPOUT:\n batch_philox_offset = philox_offset_base \\\n + (off_z * HQ + off_h_q) \\\n * seqlen_q * seqlen_k\n else:\n batch_philox_offset = 0\n if RETURN_ENCODED_SOFTMAX:\n encoded_softmax_block_ptr = tl.make_block_ptr(\n base=encoded_softmax + off_h_q * seqlen_q * seqlen_k,\n shape=(seqlen_q, seqlen_k),\n strides=(seqlen_k, 1),\n offsets=(start_m * BLOCK_M, 0),\n block_shape=(BLOCK_M, BLOCK_N),\n order=(1, 0),\n )\n else:\n encoded_softmax_block_ptr = 0\n m_i = tl.full([BLOCK_M], float(\"-inf\"), dtype=tl.float32)\n l_i = tl.full([BLOCK_M], 1.0, dtype=tl.float32)\n acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)\n qk_scale = sm_scale * 1.44269504089\n q = load_fn(Q_block_ptr, True, padded_head, \"zero\")\n q = (q * qk_scale).to(Q_block_ptr.type.element_ty)\n\n padded_block_k = n_extra_tokens != 0\n is_modulo_mn = not padded_block_k and (seqlen_q % BLOCK_M == 0)\n if IS_CAUSAL:\n masked_blocks = BLOCK_M // BLOCK_N + (not is_modulo_mn)\n else:\n masked_blocks = padded_block_k\n masked_blocks = min(masked_blocks, n_blocks)\n n_full_blocks = n_blocks - masked_blocks\n block_min = 0\n block_max = n_blocks * BLOCK_N\n if n_full_blocks > 0:\n block_max = (n_blocks - masked_blocks) * BLOCK_N\n acc, l_i, m_i = _attn_fwd_inner(\n acc,\n l_i,\n m_i,\n q,\n K_block_ptr,\n V_block_ptr,\n start_m,\n seqlen_k,\n dropout_p,\n philox_seed,\n batch_philox_offset,\n encoded_softmax_block_ptr,\n block_min,\n block_max,\n 0,\n 0,\n 0,\n bias_ptr,\n False,\n BLOCK_M,\n BLOCK_DMODEL,\n BLOCK_N,\n offs_m,\n offs_n,\n PRE_LOAD_V,\n False,\n ENABLE_DROPOUT,\n RETURN_ENCODED_SOFTMAX,\n padded_head,\n )\n block_min = block_max\n block_max = n_blocks * BLOCK_N\n\n tl.debug_barrier()\n if masked_blocks > 0:\n offs_n_causal = offs_n + (seqlen_q - seqlen_k) if IS_CAUSAL else 0\n K_block_ptr = tl.advance(K_block_ptr, (0, n_full_blocks * BLOCK_N))\n V_block_ptr = tl.advance(V_block_ptr, (n_full_blocks * BLOCK_N, 0))\n if bias_ptr is not None:\n bias_ptr = tl.advance(bias_ptr, (0, n_full_blocks * BLOCK_N))\n if RETURN_ENCODED_SOFTMAX:\n encoded_softmax_block_ptr = tl.advance(encoded_softmax_block_ptr,\n (0, n_full_blocks))\n acc, l_i, m_i = _attn_fwd_inner(\n acc,\n l_i,\n m_i,\n q,\n K_block_ptr,\n V_block_ptr,\n start_m,\n seqlen_k,\n dropout_p,\n philox_seed,\n batch_philox_offset,\n encoded_softmax_block_ptr,\n block_min,\n block_max,\n offs_n_causal,\n masked_blocks,\n n_extra_tokens,\n bias_ptr,\n IS_CAUSAL,\n BLOCK_M,\n BLOCK_DMODEL,\n BLOCK_N,\n offs_m,\n offs_n,\n PRE_LOAD_V,\n True,\n ENABLE_DROPOUT,\n RETURN_ENCODED_SOFTMAX,\n padded_head,\n )\n acc = acc / l_i[:, None]\n if ENABLE_DROPOUT:\n acc = acc / (1 - dropout_p)\n end_m_idx = (start_m + 1) * BLOCK_M\n start_m_idx = start_m * BLOCK_M\n causal_start_idx = seqlen_q - seqlen_k\n acc = acc.to(Out.type.element_ty)\n if IS_CAUSAL:\n if causal_start_idx > start_m_idx and causal_start_idx < end_m_idx:\n out_mask_boundary = tl.full((BLOCK_DMODEL, ),\n causal_start_idx,\n dtype=tl.int32)\n mask_m_offsets = start_m_idx + tl.arange(0, BLOCK_M)\n out_ptrs_mask = (mask_m_offsets[:, None] >=\n out_mask_boundary[None, :])\n z = 0.0\n acc = tl.where(out_ptrs_mask, acc, z.to(acc.type.element_ty))\n o_offset = (off_z * stride_oz + cu_seqlens_q_start * stride_om +\n off_h_q * stride_oh)\n O_block_ptr = tl.make_block_ptr(\n base=Out + o_offset,\n shape=(seqlen_q, ACTUAL_BLOCK_DMODEL),\n strides=(stride_om, stride_on),\n offsets=(start_m * BLOCK_M, 0),\n block_shape=(BLOCK_M, BLOCK_DMODEL),\n order=(1, 0),\n )\n tl.store(O_block_ptr, acc, boundary_check=(0, 1))\n\nclass _attention(torch.autograd.Function):\n\n @staticmethod\n def forward(\n ctx,\n q,\n k,\n v,\n o,\n cu_seqlens_q,\n cu_seqlens_k,\n max_seqlens_q,\n max_seqlens_k,\n causal=False,\n sm_scale=1.0,\n bias=None,\n ):\n if o is None:\n o = torch.empty_like(q, dtype=v.dtype)\n\n check_args(\n q,\n k,\n v,\n o,\n varlen=True,\n cu_seqlens_q=cu_seqlens_q,\n cu_seqlens_k=cu_seqlens_k,\n )\n if True:\n total_q, nheads_q, head_size = q.shape\n total_k, nheads_k, _ = k.shape\n batch = len(cu_seqlens_q) - 1\n q_strides = (0, q.stride(1), q.stride(0), q.stride(2))\n k_strides = (0, k.stride(1), k.stride(0), k.stride(2))\n v_strides = (0, v.stride(1), v.stride(0), v.stride(2))\n o_strides = (0, o.stride(1), o.stride(0), o.stride(2))\n else:\n batch, seqlen_q, nheads_q, head_size = q.shape\n _, seqlen_k, nheads_k, _ = k.shape\n q_strides = (q.stride(0), q.stride(2), q.stride(1), q.stride(3))\n k_strides = (k.stride(0), k.stride(2), k.stride(1), k.stride(3))\n v_strides = (v.stride(0), v.stride(2), v.stride(1), v.stride(3))\n o_strides = (o.stride(0), o.stride(2), o.stride(1), o.stride(3))\n\n unpadded_head_dims = {32, 64, 128, 256}\n if head_size not in unpadded_head_dims:\n padded_d_model = None\n for i in unpadded_head_dims:\n if i > head_size:\n padded_d_model = i\n break\n assert padded_d_model is not None\n else:\n padded_d_model = head_size\n\n grid = lambda META: (\n triton.cdiv(max_seqlens_q, META[\"BLOCK_M\"]),\n nheads_q,\n batch,\n )\n\n encoded_softmax = None\n\n philox_seed = 0x1BF52\n philox_offset = 0x1D4B42\n\n if bias is not None:\n bias_strides = (\n bias.stride(0),\n bias.stride(1),\n bias.stride(2),\n bias.stride(3),\n )\n else:\n bias_strides = (0, 0, 0, 0)\n\n attn_fwd[grid](\n q,\n k,\n v,\n bias,\n sm_scale,\n None,\n o,\n *q_strides,\n *k_strides,\n *v_strides,\n *o_strides,\n *bias_strides,\n cu_seqlens_q,\n cu_seqlens_k,\n dropout_p=0.0,\n philox_seed=philox_seed,\n philox_offset_base=philox_offset,\n encoded_softmax=encoded_softmax,\n HQ=nheads_q,\n HK=nheads_k,\n ACTUAL_BLOCK_DMODEL=head_size,\n MAX_SEQLENS_Q=max_seqlens_q,\n MAX_SEQLENS_K=max_seqlens_k,\n IS_CAUSAL=causal,\n VARLEN=True,\n BLOCK_DMODEL=padded_d_model,\n BIAS_TYPE=0 if bias is None else 1,\n ENABLE_DROPOUT=False,\n RETURN_ENCODED_SOFTMAX=False,\n )\n\n ctx.grid = grid\n ctx.sm_scale = sm_scale\n ctx.BLOCK_DMODEL = head_size\n ctx.causal = causal\n ctx.dropout_p = 0.0\n ctx.philox_seed = philox_seed\n ctx.philox_offset = philox_offset\n ctx.encoded_softmax = encoded_softmax\n ctx.return_encoded_softmax = False\n return o, encoded_softmax\n\ntriton_attention = _attention.apply\n", - "description_1": "Use triton language to implement a fused attention kernel for handling multiple sequence heads. It supports variable sequence lengths, dropout, and bias for attention calculations. This involves using block-wise operations to optimize memory access and compute efficiency.", - "description_2": "Use triton language to implement attention with dropout and bias support in a fused kernel.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef _uniform_to_exponential_kernel(input, output, n: tl.constexpr):\n idx = tl.arange(0, n)\n x = tl.load(input + idx)\n y = _uniform_to_exponential(x)\n tl.store(output + idx, y)\n\ndef test_uniform_to_exponential():\n \"\"\"Test that we can convert uniform to exponential without div by 0.\"\"\"\n input = torch.tensor([0.0, 1.0 - torch.finfo(torch.float32).eps],\n dtype=torch.float32,\n device=\"cuda\")\n output = torch.zeros(input.shape, dtype=torch.float32, device=\"cuda\")\n _uniform_to_exponential_kernel[(1, )](input, output, 2)\n assert torch.all(torch.isfinite(output))\n assert torch.all(output > 0)\n assert torch.all(torch.isfinite(torch.full_like(output, 1.0) / output))\n", - "description_1": "Use triton language to implement a kernel that converts uniform random numbers to exponential random numbers. The kernel '_uniform_to_exponential_kernel' takes three parameters: 'input' (a pointer to the input tensor), 'output' (a pointer to the output tensor), and 'n' (a constant expression representing the number of elements to process). The kernel uses Triton's parallel programming model to load data from the input tensor, apply the '_uniform_to_exponential' transformation, and store the results in the output tensor. The function 'test_uniform_to_exponential' tests this kernel by creating a tensor of uniform random numbers, applying the kernel, and verifying that the output is finite and greater than zero.", - "description_2": "Use triton language to create a kernel that transforms uniform random numbers to exponential random numbers and test its correctness.", - "difficulty": 2 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef _bgmv_shrink_kernel(\n input_ptr,\n lora_ptr,\n out_ptr,\n N,\n K,\n lora_indices,\n scaling,\n xm_stride,\n xk_stride,\n l0_stride,\n lora_k_stride,\n lora_n_stride,\n cm_stride,\n cn_stride,\n BLOCK_N: tl.constexpr,\n BLOCK_K: tl.constexpr,\n SPLIT_K: tl.constexpr,\n):\n \"\"\"\n GroupGEMV, additionally, introducing SPLIT-K can improve large hidden_size's\n performance\n \"\"\"\n pid_sk = tl.program_id(axis=0)\n cur_batch = tl.program_id(axis=1)\n lora_index = tl.load(lora_indices + cur_batch)\n if lora_index == -1:\n return\n\n offset_n = tl.arange(0, BLOCK_N)\n offset_k = tl.arange(0, BLOCK_K) + pid_sk * BLOCK_K\n a_ptr = input_ptr + cur_batch * xm_stride\n b_ptr = lora_ptr + l0_stride * lora_index\n accumulator = tl.zeros((BLOCK_N, ), dtype=tl.float32)\n for k in range(0, K, BLOCK_K * SPLIT_K):\n current_k = k + offset_k\n current_k_c = tl.max_contiguous(current_k, BLOCK_K)\n tiled_a = tl.load(\n a_ptr + current_k_c,\n mask=current_k < K,\n other=0.0,\n ) # [BLOCK_K]\n b_ptr_mask = (offset_n[:, None] < N) & (current_k[None, :] < K)\n\n tiled_b = tl.load(\n b_ptr + offset_n[:, None] * lora_k_stride +\n current_k[None, :] * lora_n_stride,\n mask=b_ptr_mask,\n other=0.0,\n ) # [BLOCK_N,BLOCK_K]\n\n accumulator += tl.sum(tiled_a * tiled_b, 1)\n accumulator *= scaling\n offset_cn = tl.arange(0, BLOCK_N)\n c_ptr = out_ptr + cur_batch * cm_stride + offset_cn * cn_stride\n c_mask = offset_cn < N\n if SPLIT_K == 1:\n tl.store(c_ptr, accumulator, mask=c_mask)\n else:\n tl.atomic_add(c_ptr, accumulator, mask=c_mask)\n\n\n@torch.inference_mode()\ndef _bgmv_shrink(\n inputs: torch.Tensor,\n lora_a_weights: torch.Tensor,\n output_tensor: torch.Tensor,\n lora_indices_tensor: torch.Tensor,\n scaling: float = 1.0,\n) -> None:\n \"\"\"\n Args:\n inputs (torch.Tensor): input tensor\n lora_a_weights (torch.Tensor): lora'a weight\n output_tensor (torch.Tensor): output tensor\n lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index\n corresponding to each batch. An index of -1 means no lora should be\n applied.\n batches (int): batch size\n scaling (float): Scaling factor.\n \"\"\"\n assert inputs.dtype == lora_a_weights.dtype\n assert inputs.dtype in [torch.float16, torch.bfloat16]\n assert lora_a_weights.dtype in [\n torch.float16,\n torch.bfloat16,\n ]\n assert inputs.size(1) == lora_a_weights.size(-1)\n assert inputs.is_contiguous()\n\n if lora_a_weights.ndim == 4: # shape:(lora_num,1,rank, size)\n assert lora_a_weights.size(1) == 1\n lora_a_weights = lora_a_weights.squeeze(dim=1)\n else:\n assert lora_a_weights.ndim == 3 # shape:(lora_num,rank, size)\n assert lora_a_weights.is_contiguous()\n assert output_tensor.is_contiguous()\n # TODO tuning this config\n batches = lora_indices_tensor.size(0)\n N, K = lora_a_weights.shape[-2:] # K=hidden_size,N=rank\n BLOCK_N = triton.next_power_of_2(N)\n # First try to load optimal config from the file\n config = get_lora_op_configs(\"bgmv_shrink\", batches, K)\n\n grid = lambda META: (\n META[\"SPLIT_K\"],\n batches,\n )\n _bgmv_shrink_kernel[grid](\n inputs,\n lora_a_weights,\n output_tensor,\n N,\n K,\n lora_indices_tensor,\n scaling,\n inputs.stride(0),\n inputs.stride(1),\n lora_a_weights.stride(0),\n lora_a_weights.stride(1),\n lora_a_weights.stride(2),\n output_tensor.stride(0),\n output_tensor.stride(1),\n BLOCK_N=BLOCK_N,\n **config,\n )\n return\n", - "description_1": "Use triton language to implement a kernel (_bgmv_shrink_kernel) with 15 parameters for efficiently handling matrix-vector operations with LoRA. It uses parallel processing, vectorization, and custom memory access patterns to optimize performance. The function _bgmv_shrink wraps this kernel, managing tensor inputs, asserting conditions, calculating grid dimensions, and executing the kernel with required arguments. It has 5 parameters for the input tensors, LoRA weights, output tensor, index tensor, and scaling factor.", - "description_2": "Use triton language to write a kernel that performs matrix operations with parallel processing. Utilize a wrapper function to handle inputs and execute the kernel.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef _sgmv_expand_kernel(\n input_ptr,\n lora_ptr,\n out_ptr,\n N,\n K,\n b_seq_start_loc,\n seq_lens,\n lora_indices,\n xm_stride,\n xk_stride, # 1\n l0_stride, # hidden_size*max_rank\n lora_k_stride,\n lora_n_stride,\n cm_stride,\n cn_stride,\n BLOCK_M: tl.constexpr,\n BLOCK_N: tl.constexpr,\n BLOCK_K: tl.constexpr,\n EVEN_K: tl.constexpr,\n ADD_INPUTS: tl.constexpr,\n CAST_TYPE: tl.constexpr,\n):\n \"\"\"\n The sgmv's expand triton kernel is based on GroupGEMM.\n \"\"\"\n pid = tl.program_id(axis=0)\n cur_batch = tl.program_id(axis=1)\n cta_n_num = tl.cdiv(N, BLOCK_N)\n pid_m = pid // cta_n_num\n pid_n = pid % cta_n_num\n M = tl.load(seq_lens + cur_batch)\n if pid_m * BLOCK_M > M:\n return\n lora_index = tl.load(lora_indices + cur_batch)\n if lora_index == -1:\n return\n cur_seq_start = tl.load(b_seq_start_loc + cur_batch)\n offset_m = tl.arange(0, BLOCK_M) + pid_m * BLOCK_M\n offset_n = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N\n offset_k = tl.arange(0, BLOCK_K)\n ram = tl.max_contiguous(tl.multiple_of(offset_m % M, BLOCK_M), BLOCK_M)\n rbn = tl.max_contiguous(tl.multiple_of(offset_n % N, BLOCK_N), BLOCK_N)\n\n a_ptr = (input_ptr + cur_seq_start * xm_stride + ram[:, None] * xm_stride +\n offset_k[None, :] * xk_stride, )\n b_ptr = (lora_ptr + l0_stride * lora_index +\n offset_k[:, None] * lora_n_stride + rbn[None, :] * lora_k_stride)\n accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)\n for k in range(tl.cdiv(K, BLOCK_K)):\n if EVEN_K:\n tiled_a = tl.load(a_ptr)\n tiled_b = tl.load(b_ptr)\n else:\n tiled_a = tl.load(a_ptr,\n mask=offset_k[None, :] < K - k * BLOCK_K,\n other=0)\n tiled_b = tl.load(b_ptr,\n mask=offset_k[:, None] < K - k * BLOCK_K,\n other=0)\n if CAST_TYPE:\n tiled_a = tiled_a.to(lora_ptr.dtype.element_ty)\n accumulator += tl.dot(\n tiled_a,\n tiled_b,\n )\n a_ptr += BLOCK_K * xk_stride\n b_ptr += BLOCK_K * lora_n_stride\n tiled_c = accumulator.to(lora_ptr.dtype.element_ty)\n offset_cm = cur_seq_start + tl.arange(0, BLOCK_M) + pid_m * BLOCK_M\n offset_cn = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N\n c_ptr = (out_ptr + offset_cm[:, None] * cm_stride +\n offset_cn[None, :] * cn_stride)\n M = tl.load(seq_lens + cur_batch)\n c_mask = (offset_cm[:, None] <\n (cur_seq_start + M)) & (offset_cn[None, :] < N)\n if ADD_INPUTS:\n tiled_out = tl.load(c_ptr, mask=c_mask)\n tiled_c += tiled_out\n tl.store(c_ptr, tiled_c, mask=c_mask)\n\n\n@torch.inference_mode()\ndef _sgmv_expand(\n inputs: torch.Tensor,\n lora_b_weights: torch.Tensor,\n output_tensor: torch.Tensor,\n b_seq_start_loc: torch.Tensor,\n seq_len_tensor: torch.Tensor,\n lora_indices_tensor: torch.Tensor,\n batches: int,\n max_seq_length: int,\n add_inputs: bool = False,\n) -> None:\n \"\"\"\n Args:\n inputs (torch.Tensor): input tensor\n lora_b_weights (torch.Tensor): lora'a weight\n output_tensor (torch.Tensor): output tensor\n b_seq_start_loc (torch.Tensor): (batch_size,). The cumulative\n sequence lengths of the sequences in the batch, used to index\n into sequence. E.g.,if the sequence length is [4, 6], it is\n [0, 4, 10].\n seq_len_tensor (torch.Tensor): (batch_size,). record the sequence\n length of the sequences in the batch\n lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index\n corresponding to each batch. An index of -1 means no lora should be\n applied.\n batches (int): batch size\n max_seq_length (int): The max sequence lengths of the sequences\n in the batch\n add_inputs (bool, optional): Defaults to False. adds the final lora \n results to the output.\n \"\"\"\n\n assert inputs.dtype in [torch.float16, torch.bfloat16, torch.float32]\n assert lora_b_weights.dtype in [\n torch.float16,\n torch.bfloat16,\n ]\n assert inputs.size(1) == lora_b_weights.size(-1)\n assert b_seq_start_loc.size(0) == batches\n assert lora_indices_tensor.size(0) == batches\n assert inputs.is_contiguous()\n assert output_tensor.is_contiguous()\n\n if lora_b_weights.ndim == 4: # shape:(lora_num,1,size,rank)\n assert lora_b_weights.size(1) == 1\n lora_b_weights = lora_b_weights.squeeze(dim=1)\n else:\n assert lora_b_weights.ndim == 3 # shape:(lora_num,size,rank)\n\n assert lora_b_weights.is_contiguous()\n\n # TODO tuning this config\n\n N, K = lora_b_weights.shape[-2:] # K= rank,N=hidden_size\n BLOCK_M = 32\n BLOCK_N = 32\n BLOCK_K = 16\n EVEN_K = K % BLOCK_K == 0\n ADD_INPUTS = add_inputs\n CAST_TYPE = False\n if inputs.dtype == torch.float32 and lora_b_weights.dtype in [\n torch.float16,\n torch.bfloat16,\n ]:\n CAST_TYPE = True\n grid = (\n triton.cdiv(max_seq_length, BLOCK_M) * triton.cdiv(N, BLOCK_N),\n batches,\n )\n _sgmv_expand_kernel[grid](\n inputs,\n lora_b_weights,\n output_tensor,\n N,\n K,\n b_seq_start_loc,\n seq_len_tensor,\n lora_indices_tensor,\n inputs.stride(0),\n inputs.stride(1),\n lora_b_weights.stride(0),\n lora_b_weights.stride(1),\n lora_b_weights.stride(2),\n output_tensor.stride(0),\n output_tensor.stride(1),\n BLOCK_M,\n BLOCK_N,\n BLOCK_K,\n EVEN_K,\n ADD_INPUTS,\n CAST_TYPE,\n )\n return\n\ntry:\n sgmv_expand = torch.library.custom_op(\"lora::sgmv_expand\",\n _sgmv_expand,\n mutates_args=[\"output_tensor\"])\nexcept AttributeError:\n sgmv_expand = _sgmv_expand\n", - "description_1": "Use triton language to implement a sparse General Matrix-Vector (SGMV) expansion kernel function, '_sgmv_expand_kernel', that operates on input pointers and LoRA weights. The kernel requires 21 parameters, including pointers to input, LoRA weights, and output, along with various strides and constants like BLOCK sizes for optimal memory access. Additionally, implement a Python wrapper function '_sgmv_expand' to handle Torch tensors and manage GPU grid configurations, with parameters like input tensors, batch sizes, and flags for additional inputs.", - "description_2": "Use triton language to implement a sparse General Matrix-Vector multiplication kernel and a Python wrapper for Torch tensor operations.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef _sgmv_shrink_kernel(\n input_ptr,\n lora_ptr,\n out_ptr,\n N,\n K,\n b_seq_start_loc,\n seq_lens,\n lora_indices,\n scaling,\n xm_stride, # hidden_size\n xk_stride, # 1\n l0_stride, # hidden_size*max_rank\n lora_k_stride,\n lora_n_stride,\n cm_stride,\n cn_stride,\n BLOCK_M: tl.constexpr,\n BLOCK_N: tl.constexpr,\n BLOCK_K: tl.constexpr,\n EVEN_K: tl.constexpr,\n SPLIT_K: tl.constexpr,\n):\n \"\"\"\n The sgmv's shrink triton kernel is based on GroupGEMM+SPLIT-K.\n The GEMM of Multi-LoRA can be considered as GroupGEMM. Additionally,\n introducing SPLIT-K can improve performance\n \"\"\"\n pid = tl.program_id(axis=0)\n pid_sk = tl.program_id(axis=1)\n cur_batch = tl.program_id(axis=2)\n cta_n_num = tl.cdiv(N, BLOCK_N)\n pid_m = pid // cta_n_num\n pid_n = pid % cta_n_num\n\n M = tl.load(seq_lens + cur_batch)\n if pid_m * BLOCK_M > M:\n return\n lora_index = tl.load(lora_indices + cur_batch)\n if lora_index == -1:\n return\n cur_seq_start = tl.load(b_seq_start_loc + cur_batch)\n offset_m = tl.arange(0, BLOCK_M) + pid_m * BLOCK_M\n offset_n = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N\n offset_k = pid_sk * BLOCK_K + tl.arange(0, BLOCK_K)\n\n ram = tl.max_contiguous(tl.multiple_of(offset_m % M, BLOCK_M), BLOCK_M)\n rbn = tl.max_contiguous(tl.multiple_of(offset_n % N, BLOCK_N), BLOCK_N)\n\n a_ptr = (input_ptr + cur_seq_start * xm_stride + ram[:, None] * xm_stride +\n offset_k[None, :] * xk_stride)\n b_ptr = (lora_ptr + l0_stride * lora_index + rbn[None, :] * lora_k_stride +\n offset_k[:, None] * lora_n_stride)\n\n accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)\n for k in range(0, tl.cdiv(K, BLOCK_K * SPLIT_K)):\n if EVEN_K:\n tiled_a = tl.load(a_ptr)\n tiled_b = tl.load(b_ptr)\n else:\n k_remaining = K - k * (BLOCK_K * SPLIT_K)\n tiled_a = tl.load(a_ptr,\n mask=offset_k[None, :] < k_remaining,\n other=0.0)\n tiled_b = tl.load(b_ptr,\n mask=offset_k[:, None] < k_remaining,\n other=0.0)\n accumulator += tl.dot(tiled_a, tiled_b)\n\n a_ptr += BLOCK_K * SPLIT_K * xk_stride\n b_ptr += BLOCK_K * SPLIT_K * lora_n_stride\n offset_cm = cur_seq_start + tl.arange(0, BLOCK_M) + pid_m * BLOCK_M\n\n offset_cn = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N\n c_ptr = (out_ptr + offset_cm[:, None] * cm_stride +\n offset_cn[None, :] * cn_stride)\n c_mask = (offset_cm[:, None] <\n (cur_seq_start + M)) & (offset_cn[None, :] < N)\n accumulator *= scaling\n # handles write-back with reduction-splitting\n if SPLIT_K == 1:\n tl.store(c_ptr, accumulator, mask=c_mask)\n else:\n tl.atomic_add(c_ptr, accumulator, mask=c_mask)\n\n\n@torch.inference_mode()\ndef _sgmv_shrink(\n inputs: torch.Tensor,\n lora_a_weights: torch.Tensor,\n output_tensor: torch.Tensor,\n b_seq_start_loc: torch.Tensor,\n seq_len_tensor: torch.Tensor,\n lora_indices_tensor: torch.Tensor,\n batches: int,\n max_seq_length: int,\n scaling: float,\n) -> None:\n \"\"\"\n Args:\n inputs (torch.Tensor): input tensor\n lora_a_weights (torch.Tensor): lora'a weight\n output_tensor (torch.Tensor): output tensor\n b_seq_start_loc (torch.Tensor): (batch_size,). The cumulative\n sequence lengths of the sequences in the batch, used to index\n into sequence. E.g.,if the sequence length is [4, 6], it is\n [0, 4].\n seq_len_tensor (torch.Tensor): (batch_size,). record the sequence\n length of the sequences in the batch\n lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index\n corresponding to each batch. An index of -1 means no lora should be\n applied.\n batches (int): batch size\n max_seq_length (int): The max sequence lengths of the sequences\n in the batch\n scaling (float): Scaling factor.\n \"\"\"\n assert inputs.dtype == lora_a_weights.dtype\n assert inputs.dtype in [torch.float16, torch.bfloat16]\n assert lora_a_weights.dtype in [\n torch.float16,\n torch.bfloat16,\n ]\n assert inputs.size(1) == lora_a_weights.size(-1)\n assert b_seq_start_loc.size(0) == batches\n assert lora_indices_tensor.size(0) == batches\n assert inputs.is_contiguous()\n\n if lora_a_weights.ndim == 4: # shape:(lora_num,1,rank, size)\n assert lora_a_weights.size(1) == 1\n lora_a_weights = lora_a_weights.squeeze(dim=1)\n else:\n assert lora_a_weights.ndim == 3 # shape:(lora_num,rank, size)\n assert lora_a_weights.is_contiguous()\n assert output_tensor.is_contiguous()\n # TODO tuning this config\n N, K = lora_a_weights.shape[-2:] # K=hidden_size,N=rank\n BLOCK_M = 32\n BLOCK_N = 16\n BLOCK_K = 32\n SPLIT_K = 8\n EVEN_K = K % (BLOCK_K * SPLIT_K) == 0\n grid = (\n triton.cdiv(max_seq_length, BLOCK_M) * triton.cdiv(N, BLOCK_N),\n SPLIT_K,\n batches,\n )\n\n _sgmv_shrink_kernel[grid](\n inputs,\n lora_a_weights,\n output_tensor,\n N,\n K,\n b_seq_start_loc,\n seq_len_tensor,\n lora_indices_tensor,\n scaling,\n inputs.stride(0),\n inputs.stride(1),\n lora_a_weights.stride(0),\n lora_a_weights.stride(1),\n lora_a_weights.stride(2),\n output_tensor.stride(0),\n output_tensor.stride(1),\n BLOCK_M,\n BLOCK_N,\n BLOCK_K,\n EVEN_K,\n SPLIT_K,\n )\n return\n", - "description_1": "Use triton language to implement a kernel function '_sgmv_shrink_kernel' with 22 parameters for matrix operations with LoRA weights, and a wrapper function '_sgmv_shrink' with 9 parameters to prepare and invoke the kernel.", - "description_2": "Use triton language to implement a kernel for matrix operations with LoRA weights and a wrapper to invoke it.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef fused_moe_kernel(\n a_ptr, b_ptr, c_ptr, a_scale_ptr, b_scale_ptr, topk_weights_ptr,\n sorted_token_ids_ptr, expert_ids_ptr, num_tokens_post_padded_ptr,\n N, K, EM, num_valid_tokens, stride_am, stride_ak, stride_be, stride_bk,\n stride_bn, stride_cm, stride_cn, stride_bse, stride_bsn,\n BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr,\n BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr,\n MUL_ROUTED_WEIGHT: tl.constexpr, top_k: tl.constexpr,\n compute_type: tl.constexpr, use_fp8_w8a8: tl.constexpr,\n use_int8_w8a16: tl.constexpr):\n \"\"\"\n Implements the fused computation for a Mixture of Experts (MOE) using\n token and expert matrices.\n \"\"\"\n pid = tl.program_id(axis=0)\n num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M)\n num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)\n num_pid_in_group = GROUP_SIZE_M * num_pid_n\n group_id = pid // num_pid_in_group\n first_pid_m = group_id * GROUP_SIZE_M\n group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)\n pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)\n pid_n = (pid % num_pid_in_group) // group_size_m\n\n num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr)\n if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded:\n return\n offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_token = tl.load(sorted_token_ids_ptr + offs_token_id)\n token_mask = offs_token < num_valid_tokens\n\n offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N\n offs_k = tl.arange(0, BLOCK_SIZE_K)\n a_ptrs = a_ptr + (offs_token[:, None] // top_k * stride_am +\n offs_k[None, :] * stride_ak)\n\n off_experts = tl.load(expert_ids_ptr + pid_m)\n b_ptrs = b_ptr + off_experts * stride_be + (offs_k[:, None] * stride_bk +\n offs_bn[None, :] * stride_bn)\n if use_int8_w8a16:\n b_scale_ptrs = b_scale_ptr + off_experts * stride_bse + offs_bn[\n None, :] * stride_bsn\n b_scale = tl.load(b_scale_ptrs)\n\n if use_fp8_w8a8:\n a_scale = tl.load(a_scale_ptr)\n b_scale = tl.load(b_scale_ptr + off_experts)\n\n accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n\n for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):\n a = tl.load(a_ptrs,\n mask=token_mask[:, None] &\n (offs_k[None, :] < K - k * BLOCK_SIZE_K),\n other=0.0)\n b = tl.load(b_ptrs,\n mask=offs_k[:, None] < K - k * BLOCK_SIZE_K,\n other=0.0)\n if use_int8_w8a16:\n accumulator = tl.dot(a, b.to(compute_type), acc=accumulator)\n elif use_fp8_w8a8:\n accumulator = tl.dot(a, b, acc=accumulator)\n else:\n accumulator += tl.dot(a, b)\n a_ptrs += BLOCK_SIZE_K * stride_ak\n b_ptrs += BLOCK_SIZE_K * stride_bk\n\n if MUL_ROUTED_WEIGHT:\n moe_weight = tl.load(topk_weights_ptr + offs_token,\n mask=token_mask,\n other=0)\n accumulator = accumulator * moe_weight[:, None]\n if use_int8_w8a16:\n accumulator = (accumulator * b_scale).to(compute_type)\n elif use_fp8_w8a8:\n accumulator = (accumulator * a_scale * b_scale).to(compute_type)\n else:\n accumulator = accumulator.to(compute_type)\n\n offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[\n None, :]\n c_mask = token_mask[:, None] & (offs_cn[None, :] < N)\n tl.store(c_ptrs, accumulator, mask=c_mask)\n\n\ndef invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,\n A_scale: Optional[torch.Tensor],\n B_scale: Optional[torch.Tensor],\n topk_weights: torch.Tensor, topk_ids: torch.Tensor,\n sorted_token_ids: torch.Tensor,\n expert_ids: torch.Tensor,\n num_tokens_post_padded: torch.Tensor,\n mul_routed_weight: bool, top_k: int,\n config: Dict[str, Any], compute_type: tl.dtype,\n use_fp8_w8a8: bool, use_int8_w8a16: bool) -> None:\n assert topk_weights.stride(1) == 1\n assert sorted_token_ids.stride(0) == 1\n\n if use_fp8_w8a8:\n A, A_scale = ops.scaled_fp8_quant(A, A_scale)\n assert B_scale is not None\n elif use_int8_w8a16:\n assert B_scale is not None\n else:\n assert A_scale is None\n assert B_scale is None\n\n grid = lambda META: (triton.cdiv(sorted_token_ids.shape[0], META[\n 'BLOCK_SIZE_M']) * triton.cdiv(B.shape[1], META['BLOCK_SIZE_N']), )\n\n fused_moe_kernel[grid](\n A,\n B,\n C,\n A_scale,\n B_scale,\n topk_weights,\n sorted_token_ids,\n expert_ids,\n num_tokens_post_padded,\n B.shape[1],\n B.shape[2],\n sorted_token_ids.shape[0],\n topk_ids.numel(),\n A.stride(0),\n A.stride(1),\n B.stride(0),\n B.stride(2),\n B.stride(1),\n C.stride(1),\n C.stride(2),\n B_scale.stride(0) if B_scale is not None and use_int8_w8a16 else 0,\n B_scale.stride(1) if B_scale is not None and use_int8_w8a16 else 0,\n MUL_ROUTED_WEIGHT=mul_routed_weight,\n top_k=top_k,\n compute_type=compute_type,\n use_fp8_w8a8=use_fp8_w8a8,\n use_int8_w8a16=use_int8_w8a16,\n **config,\n )\n", - "description_1": "Use triton language to implement a fused Mixture of Experts (MoE) kernel. The kernel function 'fused_moe_kernel' takes 28 parameters including pointers to input matrices, matrix dimensions, stride variables, and meta-parameters. It performs block matrix multiplication using token and expert matrices, with optional scaling and routing weights. The 'invoke_fused_moe_kernel' function calls this kernel with 15 parameters, setting up the grid and handling optional quantization.", - "description_2": "Use triton language to implement a fused MoE kernel for block matrix multiplication with optional scaling and routing weights.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\nfrom packaging import version\n\nTRITON3 = version.parse(triton.__version__) >= version.parse(\"3.0.0\")\n\nif TRITON3:\n\n @triton.jit\n def softplus(dt):\n # Apply softplus to the input\n dt = tl.where(dt <= 20.0, tl.math.log(tl.math.exp(dt) + 1), dt)\n return dt\nelse:\n\n @triton.jit\n def softplus(dt):\n # Apply softplus to the input\n dt = tl.where(dt <= 20.0, tl.math.log1p(tl.exp(dt)), dt)\n return dt\n\n\n@triton.jit\ndef _selective_scan_update_kernel(\n state_ptr,\n x_ptr,\n dt_ptr,\n dt_bias_ptr,\n A_ptr,\n B_ptr,\n C_ptr,\n D_ptr,\n z_ptr,\n out_ptr,\n batch,\n nheads,\n dim,\n dstate,\n nheads_ngroups_ratio,\n stride_state_batch,\n stride_state_head,\n stride_state_dim,\n stride_state_dstate,\n stride_x_batch,\n stride_x_head,\n stride_x_dim,\n stride_dt_batch,\n stride_dt_head,\n stride_dt_dim,\n stride_dt_bias_head,\n stride_dt_bias_dim,\n stride_A_head,\n stride_A_dim,\n stride_A_dstate,\n stride_B_batch,\n stride_B_group,\n stride_B_dstate,\n stride_C_batch,\n stride_C_group,\n stride_C_dstate,\n stride_D_head,\n stride_D_dim,\n stride_z_batch,\n stride_z_head,\n stride_z_dim,\n stride_out_batch,\n stride_out_head,\n stride_out_dim,\n DT_SOFTPLUS: tl.constexpr,\n TIE_HDIM: tl.constexpr,\n BLOCK_SIZE_M: tl.constexpr,\n HAS_DT_BIAS: tl.constexpr,\n HAS_D: tl.constexpr,\n HAS_Z: tl.constexpr,\n BLOCK_SIZE_DSTATE: tl.constexpr,\n):\n # Kernel for selective state update\n pid_m = tl.program_id(axis=0)\n pid_b = tl.program_id(axis=1)\n pid_h = tl.program_id(axis=2)\n state_ptr += pid_b * stride_state_batch + pid_h * stride_state_head\n x_ptr += pid_b * stride_x_batch + pid_h * stride_x_head\n dt_ptr += pid_b * stride_dt_batch + pid_h * stride_dt_head\n if HAS_DT_BIAS:\n dt_bias_ptr += pid_h * stride_dt_bias_head\n A_ptr += pid_h * stride_A_head\n B_ptr += pid_b * stride_B_batch + (pid_h //\n nheads_ngroups_ratio) * stride_B_group\n C_ptr += pid_b * stride_C_batch + (pid_h //\n nheads_ngroups_ratio) * stride_C_group\n if HAS_Z:\n z_ptr += pid_b * stride_z_batch + pid_h * stride_z_head\n out_ptr += pid_b * stride_out_batch + pid_h * stride_out_head\n\n offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_n = tl.arange(0, BLOCK_SIZE_DSTATE)\n state_ptrs = state_ptr + (offs_m[:, None] * stride_state_dim +\n offs_n[None, :] * stride_state_dstate)\n x_ptrs = x_ptr + offs_m * stride_x_dim\n dt_ptrs = dt_ptr + offs_m * stride_dt_dim\n if HAS_DT_BIAS:\n dt_bias_ptrs = dt_bias_ptr + offs_m * stride_dt_bias_dim\n if HAS_D:\n D_ptr += pid_h * stride_D_head\n A_ptrs = A_ptr + (offs_m[:, None] * stride_A_dim +\n offs_n[None, :] * stride_A_dstate)\n B_ptrs = B_ptr + offs_n * stride_B_dstate\n C_ptrs = C_ptr + offs_n * stride_C_dstate\n if HAS_D:\n D_ptrs = D_ptr + offs_m * stride_D_dim\n if HAS_Z:\n z_ptrs = z_ptr + offs_m * stride_z_dim\n out_ptrs = out_ptr + offs_m * stride_out_dim\n\n state = tl.load(state_ptrs,\n mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate),\n other=0.0)\n x = tl.load(x_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)\n if not TIE_HDIM:\n dt = tl.load(dt_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)\n if HAS_DT_BIAS:\n dt += tl.load(dt_bias_ptrs, mask=offs_m < dim,\n other=0.0).to(tl.float32)\n if DT_SOFTPLUS:\n dt = softplus(dt)\n A = tl.load(A_ptrs,\n mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate),\n other=0.0).to(tl.float32)\n dA = tl.exp(A * dt[:, None])\n else:\n dt = tl.load(dt_ptr).to(tl.float32)\n if HAS_DT_BIAS:\n dt += tl.load(dt_bias_ptr).to(tl.float32)\n if DT_SOFTPLUS:\n dt = softplus(dt)\n A = tl.load(A_ptr).to(tl.float32)\n dA = tl.exp(A * dt) # scalar, not a matrix\n\n B = tl.load(B_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32)\n C = tl.load(C_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32)\n if HAS_D:\n D = tl.load(D_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)\n if HAS_Z:\n z = tl.load(z_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)\n\n dB = B[None, :] * dt[:, None] if not TIE_HDIM else B * dt\n state = state * dA + dB * x[:, None]\n tl.store(state_ptrs,\n state,\n mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate))\n out = tl.sum(state * C[None, :], axis=1)\n if HAS_D:\n out += x * D\n if HAS_Z:\n out *= z * tl.sigmoid(z)\n tl.store(out_ptrs, out, mask=offs_m < dim)\n\n\ndef selective_state_update(state,\n x,\n dt,\n A,\n B,\n C,\n D=None,\n z=None,\n dt_bias=None,\n dt_softplus=False):\n \"\"\"\n Argument:\n state: (batch, dim, dstate) or (batch, nheads, dim, dstate)\n x: (batch, dim) or (batch, nheads, dim)\n dt: (batch, dim) or (batch, nheads, dim)\n A: (dim, dstate) or (nheads, dim, dstate)\n B: (batch, dstate) or (batch, ngroups, dstate)\n C: (batch, dstate) or (batch, ngroups, dstate)\n D: (dim,) or (nheads, dim)\n z: (batch, dim) or (batch, nheads, dim)\n dt_bias: (dim,) or (nheads, dim)\n Return:\n out: (batch, dim) or (batch, nheads, dim)\n \"\"\"\n has_heads = state.dim() > 3\n if state.dim() == 3:\n state = state.unsqueeze(1)\n if x.dim() == 2:\n x = x.unsqueeze(1)\n if dt.dim() == 2:\n dt = dt.unsqueeze(1)\n if A.dim() == 2:\n A = A.unsqueeze(0)\n if B.dim() == 2:\n B = B.unsqueeze(1)\n if C.dim() == 2:\n C = C.unsqueeze(1)\n if D is not None and D.dim() == 1:\n D = D.unsqueeze(0)\n if z is not None and z.dim() == 2:\n z = z.unsqueeze(1)\n if dt_bias is not None and dt_bias.dim() == 1:\n dt_bias = dt_bias.unsqueeze(0)\n batch, nheads, dim, dstate = state.shape\n assert x.shape == (batch, nheads, dim)\n assert dt.shape == x.shape\n assert A.shape == (nheads, dim, dstate)\n ngroups = B.shape[1]\n assert nheads % ngroups == 0, \"nheads must be divisible by ngroups\"\n assert B.shape == (batch, ngroups, dstate)\n assert C.shape == B.shape\n if D is not None:\n assert D.shape == (nheads, dim)\n if z is not None:\n assert z.shape == x.shape\n if dt_bias is not None:\n assert dt_bias.shape == (nheads, dim)\n out = torch.empty_like(x)\n grid = lambda META: (triton.cdiv(dim, META['BLOCK_SIZE_M']), batch, nheads)\n z_strides = ((z.stride(0), z.stride(1), z.stride(2)) if z is not None else\n (0, 0, 0))\n # We don't want autotune since it will overwrite the state\n # We instead tune by hand.\n BLOCK_SIZE_M, num_warps = ((32, 4) if dstate <= 16 else\n ((16, 4) if dstate <= 32 else\n ((8, 4) if dstate <= 64 else\n ((4, 4) if dstate <= 128 else ((4, 8))))))\n tie_hdim = A.stride(-1) == 0 and A.stride(-2) == 0 and dt.stride(\n -1) == 0 and dt_bias.stride(-1) == 0\n with torch.cuda.device(x.device.index):\n _selective_scan_update_kernel[grid](\n state,\n x,\n dt,\n dt_bias,\n A,\n B,\n C,\n D,\n z,\n out,\n batch,\n nheads,\n dim,\n dstate,\n nheads // ngroups,\n state.stride(0),\n state.stride(1),\n state.stride(2),\n state.stride(3),\n x.stride(0),\n x.stride(1),\n x.stride(2),\n dt.stride(0),\n dt.stride(1),\n dt.stride(2),\n *(dt_bias.stride(0),\n dt_bias.stride(1)) if dt_bias is not None else 0,\n A.stride(0),\n A.stride(1),\n A.stride(2),\n B.stride(0),\n B.stride(1),\n B.stride(2),\n C.stride(0),\n C.stride(1),\n C.stride(2),\n *(D.stride(0), D.stride(1)) if D is not None else 0,\n z_strides[0],\n z_strides[1],\n z_strides[2],\n out.stride(0),\n out.stride(1),\n out.stride(2),\n dt_softplus,\n tie_hdim,\n BLOCK_SIZE_M,\n num_warps=num_warps,\n )\n if not has_heads:\n out = out.squeeze(1)\n return out\n", - "description_1": "Use triton language to define two kernels. The first kernel, `softplus(dt)`, applies the softplus function to a given input tensor `dt`, with different implementations based on Triton version. It takes one input argument: `dt`, representing the input tensor. The second kernel, `_selective_scan_update_kernel`, performs selective state updates based on multiple parameters and conditions. It requires 46 parameters in total, covering pointers to matrices (state_ptr, x_ptr, etc.), matrix dimensions (batch, nheads, dim, etc.), and meta-parameters (DT_SOFTPLUS, TIE_HDIM, etc.). This kernel utilizes the `softplus` function depending on a condition to compute updates for a given state, using various inputs like x, dt, A, B, C, D, and z.", - "description_2": "Use triton language to implement a selective state update kernel and a softplus function kernel, processing input data with conditions and matrix updates.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\ndef seeded_uniform(\n *size,\n seeds: torch.Tensor,\n out: Optional[torch.Tensor] = None,\n dtype: Optional[torch.dtype] = None,\n device: Optional[Union[torch.device, str]] = None,\n pin_memory: Optional[bool] = False,\n) -> torch.Tensor:\n \"\"\"Similar to torch.rand, but allows for seeds to be set per row.\"\"\"\n n_dims = len(size)\n\n if n_dims > 3:\n raise ValueError(\"seeded_uniform only supports up to 3D tensors\")\n\n if out is None:\n out = torch.empty(*size,\n dtype=dtype,\n device=device,\n pin_memory=pin_memory)\n elif out.shape != size:\n raise ValueError(\"shape of out and size must be the same\")\n\n if n_dims == 3:\n n_rows, n_3d, n_cols = out.shape\n stride_row = out.stride(0)\n stride_3d = out.stride(1)\n elif n_dims == 2:\n n_rows, n_cols = out.shape\n n_3d = 1\n stride_row = out.stride(0)\n stride_3d = 1\n else:\n n_cols = out.shape[0]\n n_rows = 1\n n_3d = 1\n stride_row = 1\n stride_3d = 1\n\n if seeds.ndim != 1:\n raise ValueError(\"seeds must be a 1D tensor\")\n\n if seeds.numel() != n_rows:\n raise ValueError(\n \"seeds must have the same number of elements as out has rows\")\n\n full_block_size = triton.next_power_of_2(n_cols)\n philox_block_size = max(full_block_size // 4, 1)\n n_slices = full_block_size // philox_block_size\n num_warps = 4\n if philox_block_size >= 8192:\n num_warps = 32\n elif philox_block_size >= 4096:\n num_warps = 16\n elif philox_block_size >= 2048:\n num_warps = 8\n\n _seeded_uniform_triton[(n_rows, n_3d)](\n out,\n seeds,\n stride_row,\n stride_3d,\n seeds.stride(0),\n n_rows,\n n_3d,\n n_cols,\n n_slices=n_slices,\n num_warps=num_warps,\n block_size=philox_block_size,\n )\n return out\n\n@triton.jit\ndef _seeded_uniform_triton(\n out_ptr: torch.Tensor,\n seed_ptr: torch.Tensor,\n out_row_stride: int,\n out_3d_stride: int,\n seed_row_stride: int,\n n_rows: int,\n n_3d: int,\n n_cols: int,\n n_slices: tl.constexpr,\n block_size: tl.constexpr,\n):\n \"\"\"\n Generate a random float32 number in [0, 1) for each element in the output\n tensor. The random numbers in a row generated using the seed for that row.\n \"\"\"\n tl.static_assert(n_slices > 0 and n_slices <= 4, \"0 < n_slices <= 4\")\n\n row_idx = tl.program_id(axis=0)\n three_d_idx = tl.program_id(axis=1)\n\n philox_offsets = tl.arange(0, block_size)\n seed = tl.load(seed_ptr + row_idx * seed_row_stride)\n if three_d_idx > 0:\n seed ^= three_d_idx\n out1, out2, out3, out4 = tl.rand4x(seed, philox_offsets)\n\n output_row_start_ptr = (out_ptr + row_idx * out_row_stride +\n three_d_idx * out_3d_stride)\n out1_offsets = philox_offsets\n tl.store(output_row_start_ptr + out1_offsets,\n out1,\n mask=out1_offsets < n_cols)\n if n_slices > 1:\n out2_offsets = tl.arange(block_size, block_size * 2)\n tl.store(output_row_start_ptr + out2_offsets,\n out2,\n mask=out2_offsets < n_cols)\n if n_slices > 2:\n out3_offsets = tl.arange(block_size * 2, block_size * 3)\n tl.store(output_row_start_ptr + out3_offsets,\n out3,\n mask=out3_offsets < n_cols)\n if n_slices > 3:\n out4_offsets = tl.arange(block_size * 3, block_size * 4)\n tl.store(output_row_start_ptr + out4_offsets,\n out4,\n mask=out4_offsets < n_cols)\n", - "description_1": "Use triton language to implement a seeded uniform random number generator. The function `seeded_uniform` takes parameters: size (dimensions of the output tensor), seeds (1D tensor for per-row seeds), out (optional output tensor), dtype (optional data type), device (optional device), and pin_memory (optional boolean for pinned memory). It calculates the necessary strides and block sizes, then calls the Triton kernel `_seeded_uniform_triton`. The kernel generates random float32 numbers in [0, 1) for each element in the output tensor using the provided seeds. It takes parameters: out_ptr (output tensor), seed_ptr (seed tensor), out_row_stride (stride between rows), out_3d_stride (stride between 3D slices), seed_row_stride (stride between seed rows), n_rows (number of rows), n_3d (size of second dimension if 3D), n_cols (number of columns), n_slices (number of philox outputs), and block_size (size of each block).", - "description_2": "Use triton language to create a random number generator that produces float32 numbers in [0, 1) for each element in a tensor, with seeds set per row.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n_EPS: tl.constexpr = 1e-6\n\n@triton.jit\ndef _uniform_to_exponential(uniform_noise):\n \"\"\"Convert uniform samples to exponential samples.\"\"\"\n lb = tl.full(uniform_noise.shape, _EPS, uniform_noise.dtype)\n uniform_noise = tl.maximum(uniform_noise, lb)\n exponential_noise = -tl.log(uniform_noise)\n return exponential_noise\n\n@triton.jit\ndef _sample_triton(\n sample_indices_ptr: torch.Tensor, output_ptr: torch.Tensor,\n output_logprobs_ptr: torch.Tensor,\n output_modified_probs_ptr: torch.Tensor, probs_ptr: torch.Tensor,\n logprobs_ptr: torch.Tensor, seeds_ptr: torch.Tensor,\n uniform_noise_ptr: torch.Tensor, output_row_stride: int,\n probs_row_stride: int, uniform_noise_row_stride: int,\n uniform_noise_best_stride: int, n_samples: int, n_cols: int,\n n_best: int, block_size: tl.constexpr,\n modify_greedy_probs: tl.constexpr, save_logprobs: tl.constexpr,\n save_modified_probs: tl.constexpr):\n sample_idx = tl.program_id(0)\n best_idx = tl.program_id(1)\n row_idx = tl.load(sample_indices_ptr + sample_idx)\n seed = tl.load(seeds_ptr + sample_idx)\n uses_random_sampling = seed != 0\n row_start_ptr = probs_ptr + row_idx * probs_row_stride\n col_offsets = tl.arange(0, block_size)\n row = tl.load(row_start_ptr + col_offsets,\n mask=col_offsets < n_cols,\n other=float(\"-inf\"))\n if uses_random_sampling:\n uniform_noise_start_ptr = (uniform_noise_ptr +\n sample_idx * uniform_noise_row_stride +\n best_idx * uniform_noise_best_stride)\n uniform_noise = tl.load(uniform_noise_start_ptr + col_offsets,\n mask=col_offsets < n_cols,\n other=0.5)\n exponential_noise = _uniform_to_exponential(uniform_noise)\n row /= exponential_noise\n sampled_value, sampled_token = tl.max(row, axis=0, return_indices=True)\n if sampled_token >= n_cols:\n sampled_token = n_cols - 1\n output_row_start_ptr = (output_ptr + sample_idx * output_row_stride +\n best_idx)\n tl.store(output_row_start_ptr, sampled_token)\n if modify_greedy_probs:\n if not uses_random_sampling:\n row = tl.where(col_offsets == sampled_token, 1.0, 0.0)\n tl.store(row_start_ptr + col_offsets,\n row,\n mask=col_offsets < n_cols)\n if save_modified_probs:\n output_row_start_ptr = (output_modified_probs_ptr +\n sample_idx * output_row_stride + best_idx)\n tl.store(output_row_start_ptr, sampled_value)\n if save_logprobs:\n sampled_logprob = tl.load(logprobs_ptr + row_idx * probs_row_stride +\n sampled_token)\n output_row_start_ptr = (output_logprobs_ptr +\n sample_idx * output_row_stride + best_idx)\n tl.store(output_row_start_ptr, sampled_logprob)\n", - "description_1": "Use triton language to implement a kernel that samples tokens from a probability distribution. The kernel takes 18 parameters: sample_indices_ptr (tensor of sample indices), output_ptr (tensor to store sampled tokens), output_logprobs_ptr (tensor to store log probabilities of sampled tokens), output_modified_probs_ptr (tensor to store modified probabilities), probs_ptr (tensor of probabilities), logprobs_ptr (tensor of log probabilities), seeds_ptr (tensor of seeds for sampling), uniform_noise_ptr (tensor of uniform noise), output_row_stride (stride for output tensor), probs_row_stride (stride for probability tensor), uniform_noise_row_stride (stride for uniform noise tensor), uniform_noise_best_stride (stride for best uniform noise), n_samples (number of samples), n_cols (number of columns in probability tensor), n_best (number of best samples), block_size (block size for loading data), modify_greedy_probs (flag to modify greedy probabilities), save_logprobs (flag to save log probabilities), and save_modified_probs (flag to save modified probabilities). The kernel loads probability data, applies noise if needed, finds the maximum probability, and stores the result. It can also modify probabilities for greedy sampling and save log probabilities and modified probabilities.", - "description_2": "Use triton language to implement a kernel that converts uniform noise to exponential noise. The kernel takes 1 parameter: uniform_noise (tensor of uniform noise). It clamps the noise to avoid division by zero, applies the inversion method to convert uniform samples to exponential samples, and returns the result.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\nAWQ_TRITON_SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]\n\n@triton.jit\ndef awq_dequantize_kernel(\n qweight_ptr, # quantized matrix\n scales_ptr, # scales, per group\n zeros_ptr, # zeros, per group\n group_size, # Should always be one of the supported group sizes\n result_ptr, # Output matrix\n num_cols, # input num cols in qweight\n num_rows, # input num rows in qweight\n BLOCK_SIZE_X: tl.constexpr,\n BLOCK_SIZE_Y: tl.constexpr):\n pid_x = tl.program_id(axis=0)\n pid_y = tl.program_id(axis=1)\n offsets_y = pid_y * BLOCK_SIZE_Y + tl.arange(0, BLOCK_SIZE_Y)\n offsets_x = pid_x * BLOCK_SIZE_X + tl.arange(0, BLOCK_SIZE_X)\n offsets = num_cols * offsets_y[:, None] + offsets_x[None, :]\n masks_y = offsets_y < num_rows\n masks_x = offsets_x < num_cols\n masks = masks_y[:, None] & masks_x[None, :]\n result_offsets_y = pid_y * BLOCK_SIZE_Y + tl.arange(0, BLOCK_SIZE_Y)\n result_offsets_x = pid_x * BLOCK_SIZE_X * 8 + tl.arange(\n 0, BLOCK_SIZE_X * 8)\n result_offsets = (8 * num_cols * result_offsets_y[:, None] +\n result_offsets_x[None, :])\n result_masks_y = result_offsets_y < num_rows\n result_masks_x = result_offsets_x < num_cols * 8\n result_masks = result_masks_y[:, None] & result_masks_x[None, :]\n iweights = tl.load(qweight_ptr + offsets, masks)\n iweights = tl.interleave(iweights, iweights)\n iweights = tl.interleave(iweights, iweights)\n iweights = tl.interleave(iweights, iweights)\n reverse_awq_order_tensor = ((tl.arange(0, 2) * 4)[None, :] +\n tl.arange(0, 4)[:, None]).reshape(8)\n shifts = reverse_awq_order_tensor * 4\n shifts = tl.broadcast_to(shifts[None, :], (BLOCK_SIZE_Y * BLOCK_SIZE_X, 8))\n shifts = tl.reshape(shifts, (BLOCK_SIZE_Y, BLOCK_SIZE_X * 8))\n iweights = (iweights >> shifts) & 0xF\n zero_offsets_y = pid_y * BLOCK_SIZE_Y // group_size + tl.arange(0, 1)\n zero_offsets_x = pid_x * BLOCK_SIZE_X + tl.arange(0, BLOCK_SIZE_X)\n zero_offsets = num_cols * zero_offsets_y[:, None] + zero_offsets_x[None, :]\n zero_masks_y = zero_offsets_y < num_rows // group_size\n zero_masks_x = zero_offsets_x < num_cols\n zero_masks = zero_masks_y[:, None] & zero_masks_x[None, :]\n zeros = tl.load(zeros_ptr + zero_offsets, zero_masks)\n zeros = tl.interleave(zeros, zeros)\n zeros = tl.interleave(zeros, zeros)\n zeros = tl.interleave(zeros, zeros)\n zeros = tl.broadcast_to(zeros, (BLOCK_SIZE_Y, BLOCK_SIZE_X * 8))\n zeros = (zeros >> shifts) & 0xF\n scale_offsets_y = pid_y * BLOCK_SIZE_Y // group_size + tl.arange(0, 1)\n scale_offsets_x = (pid_x * BLOCK_SIZE_X * 8 +\n tl.arange(0, BLOCK_SIZE_X * 8))\n scale_offsets = (num_cols * 8 * scale_offsets_y[:, None] +\n scale_offsets_x[None, :])\n scale_masks_y = scale_offsets_y < num_rows // group_size\n scale_masks_x = scale_offsets_x < num_cols * 8\n scale_masks = scale_masks_y[:, None] & scale_masks_x[None, :]\n scales = tl.load(scales_ptr + scale_offsets, scale_masks)\n scales = tl.broadcast_to(scales, (BLOCK_SIZE_Y, BLOCK_SIZE_X * 8))\n iweights = (iweights - zeros) * scales\n iweights = iweights.to(result_ptr.type.element_ty)\n tl.store(result_ptr + result_offsets, iweights, result_masks)\n\n@triton.jit\ndef awq_gemm_kernel(a_ptr, b_ptr, c_ptr, zeros_ptr, scales_ptr, M, N, K,\n group_size, BLOCK_SIZE_M: tl.constexpr,\n BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,\n SPLIT_K: tl.constexpr):\n pid = tl.program_id(axis=0)\n pid_z = tl.program_id(1)\n num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)\n pid_m = pid // num_pid_n\n pid_n = pid % num_pid_n\n accumulator_dtype = c_ptr.type.element_ty\n accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N),\n dtype=accumulator_dtype)\n reverse_awq_order_tensor = ((tl.arange(0, 2) * 4)[None, :] +\n tl.arange(0, 4)[:, None]).reshape(8)\n shifts = reverse_awq_order_tensor * 4\n shifts = tl.broadcast_to(shifts[None, :],\n (BLOCK_SIZE_K * (BLOCK_SIZE_N // 8), 8))\n shifts = tl.reshape(shifts, (BLOCK_SIZE_K, BLOCK_SIZE_N))\n offsets_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n masks_am = offsets_am < M\n offsets_bn = pid_n * (BLOCK_SIZE_N // 8) + tl.arange(0, BLOCK_SIZE_N // 8)\n masks_bn = offsets_bn < N // 8\n offsets_zn = pid_n * (BLOCK_SIZE_N // 8) + tl.arange(0, BLOCK_SIZE_N // 8)\n masks_zn = offsets_zn < N // 8\n offsets_sn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n masks_sn = offsets_sn < N\n offsets_k = pid_z * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)\n offsets_a = K * offsets_am[:, None] + offsets_k[None, :]\n offsets_b = (N // 8) * offsets_k[:, None] + offsets_bn[None, :]\n a_ptrs = a_ptr + offsets_a\n b_ptrs = b_ptr + offsets_b\n for k in range(0, tl.cdiv(K, BLOCK_SIZE_K * SPLIT_K)):\n masks_k = offsets_k < K\n masks_a = masks_am[:, None] & masks_k[None, :]\n a = tl.load(a_ptrs, mask=masks_a)\n masks_b = masks_k[:, None] & masks_bn[None, :]\n b = tl.load(b_ptrs, mask=masks_b)\n b = tl.interleave(b, b)\n b = tl.interleave(b, b)\n b = tl.interleave(b, b)\n offsets_szk = (\n (BLOCK_SIZE_K * SPLIT_K * k + pid_z * BLOCK_SIZE_K) // group_size +\n tl.arange(0, 1))\n offsets_z = (N // 8) * offsets_szk[:, None] + offsets_zn[None, :]\n masks_zk = offsets_szk < K // group_size\n masks_z = masks_zk[:, None] & masks_zn[None, :]\n zeros_ptrs = zeros_ptr + offsets_z\n zeros = tl.load(zeros_ptrs, mask=masks_z)\n zeros = tl.interleave(zeros, zeros)\n zeros = tl.interleave(zeros, zeros)\n zeros = tl.interleave(zeros, zeros)\n zeros = tl.broadcast_to(zeros, (BLOCK_SIZE_K, BLOCK_SIZE_N))\n offsets_s = N * offsets_szk[:, None] + offsets_sn[None, :]\n masks_sk = offsets_szk < K // group_size\n masks_s = masks_sk[:, None] & masks_sn[None, :]\n scales_ptrs = scales_ptr + offsets_s\n scales = tl.load(scales_ptrs, mask=masks_s)\n scales = tl.broadcast_to(scales, (BLOCK_SIZE_K, BLOCK_SIZE_N))\n b = (b >> shifts) & 0xF\n zeros = (zeros >> shifts) & 0xF\n b = (b - zeros) * scales\n b = b.to(c_ptr.type.element_ty)\n accumulator = tl.dot(a, b, accumulator, out_dtype=accumulator_dtype)\n offsets_k += BLOCK_SIZE_K * SPLIT_K\n a_ptrs += BLOCK_SIZE_K * SPLIT_K\n b_ptrs += BLOCK_SIZE_K * SPLIT_K * (N // 8)\n c = accumulator.to(c_ptr.type.element_ty)\n offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n c_ptrs = c_ptr + N * offs_cm[:, None] + offs_cn[None, :]\n c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)\n if SPLIT_K == 1:\n tl.store(c_ptrs, c, mask=c_mask)\n else:\n tl.atomic_add(c_ptrs, c, mask=c_mask)\n\ndef awq_dequantize_triton(qweight: torch.Tensor,\n scales: torch.Tensor,\n zeros: torch.Tensor,\n block_size_x: int = 32,\n block_size_y: int = 32) -> torch.Tensor:\n K = qweight.shape[0]\n M = scales.shape[1]\n group_size = qweight.shape[0] // scales.shape[0]\n assert K > 0 and M > 0\n assert scales.shape[0] == K // group_size and scales.shape[1] == M\n assert zeros.shape[0] == K // group_size and zeros.shape[1] == M // 8\n assert group_size <= K\n assert group_size in AWQ_TRITON_SUPPORTED_GROUP_SIZES or group_size == K\n result = torch.empty(qweight.shape[0],\n qweight.shape[1] * 8,\n device=qweight.device,\n dtype=scales.dtype)\n Y = qweight.shape[0]\n X = qweight.shape[1]\n grid = lambda META: (\n triton.cdiv(X, META['BLOCK_SIZE_X']),\n triton.cdiv(Y, META['BLOCK_SIZE_Y']),\n )\n awq_dequantize_kernel[grid](qweight,\n scales,\n zeros,\n group_size,\n result,\n X,\n Y,\n BLOCK_SIZE_X=block_size_x,\n BLOCK_SIZE_Y=block_size_y)\n return result\n\ndef awq_gemm_triton(input: torch.Tensor,\n qweight: torch.Tensor,\n scales: torch.Tensor,\n qzeros: torch.Tensor,\n split_k_iters: int,\n block_size_m: int = 32,\n block_size_n: int = 32,\n block_size_k: int = 32) -> torch.Tensor:\n M, K = input.shape\n N = qweight.shape[1] * 8\n group_size = qweight.shape[0] // qzeros.shape[0]\n assert N > 0 and K > 0 and M > 0\n assert qweight.shape[0] == K and qweight.shape[1] == N // 8\n assert qzeros.shape[0] == K // group_size and qzeros.shape[1] == N // 8\n assert scales.shape[0] == K // group_size and scales.shape[1] == N\n assert split_k_iters & (split_k_iters - 1) == 0 and split_k_iters != 0\n assert split_k_iters <= 32\n assert group_size <= K\n assert group_size in AWQ_TRITON_SUPPORTED_GROUP_SIZES or group_size == K\n grid = lambda META: (\n triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(\n N, META['BLOCK_SIZE_N']),\n split_k_iters,\n )\n result = torch.zeros((M, N), dtype=scales.dtype, device=input.device)\n awq_gemm_kernel[grid](input,\n qweight,\n result,\n qzeros,\n scales,\n M,\n N,\n K,\n group_size,\n BLOCK_SIZE_M=block_size_m,\n BLOCK_SIZE_N=block_size_n,\n BLOCK_SIZE_K=block_size_k,\n SPLIT_K=split_k_iters)\n return result\n", - "description_1": "Use triton language to implement two kernels: awq_dequantize_kernel and awq_gemm_kernel. The awq_dequantize_kernel takes 8 arguments including pointers to quantized weights, scales, and zeros, with group size, result pointer, number of columns and rows, and block sizes for computation, and it dequantizes a quantized matrix using these parameters. The awq_gemm_kernel takes 12 arguments including pointers to input matrices, zero and scale pointers, dimensions M, N, K, group size, and block sizes for matrix multiplication, it performs a quantized GEMM operation with dequantization during the computation.", - "description_2": "Use triton language to create two functions: awq_dequantize_triton and awq_gemm_triton. The awq_dequantize_triton function calls awq_dequantize_kernel to dequantize a matrix given quantized weights, scales, and zeros, along with block sizes. The awq_gemm_triton function performs matrix multiplication using awq_gemm_kernel, requiring quantized input matrices, scales, zeros, split_k_iters, and block sizes to output a dequantized matrix result.", - "difficulty": 4 - }, - { - "code": "import triton\nimport triton.language as tl\n\n\n@triton.jit\ndef group_norm_kernel(\n input_ptr,\n output_ptr,\n gamma_ptr,\n beta_ptr,\n img_size,\n c,\n c_per_group,\n eps,\n BLOCK_SIZE: tl.constexpr,\n HW_SIZE: tl.constexpr,\n ACTIVATION_SWISH: tl.constexpr,\n):\n row_x = tl.program_id(0)\n row_y = tl.program_id(1)\n stride = img_size * c\n input_ptr += row_x * stride + row_y * c_per_group\n output_ptr += row_x * stride + row_y * c_per_group\n gamma_ptr += row_y * c_per_group\n beta_ptr += row_y * c_per_group\n\n cols = tl.arange(0, BLOCK_SIZE)\n hw = tl.arange(0, HW_SIZE)\n offsets = hw[:, None] * c + cols[None, :]\n mask = (cols < c_per_group)[None, :]\n\n # Calculate mean and variance\n _sum = tl.zeros([HW_SIZE, BLOCK_SIZE], dtype=tl.float32)\n _square_sum = tl.zeros([HW_SIZE, BLOCK_SIZE], dtype=tl.float32)\n for i in range(tl.cdiv(img_size, HW_SIZE)):\n x_ptr = input_ptr + i * HW_SIZE * c\n a = tl.load(x_ptr + offsets, mask=mask, other=0.0).to(tl.float32)\n _sum += a\n _square_sum += a * a\n\n # Set axis=None (or leave it unspecified) to reduce all axes.\n group_mean = tl.sum(_sum, axis=None) / (img_size * c_per_group)\n group_var = tl.sum(_square_sum, axis=None) / (img_size * c_per_group) - group_mean * group_mean\n\n rstd = 1 / tl.sqrt(group_var + eps)\n\n # Normalize and apply linear transformation\n gamma = tl.load(gamma_ptr + cols, mask=cols < c_per_group).to(tl.float32)\n beta = tl.load(beta_ptr + cols, mask=cols < c_per_group).to(tl.float32)\n for i in range(tl.cdiv(img_size, HW_SIZE)):\n x_ptr = input_ptr + i * HW_SIZE * c\n y_ptr = output_ptr + i * HW_SIZE * c\n x = tl.load(x_ptr + offsets, mask=mask, other=0.0).to(tl.float32)\n x_hat = (x - group_mean) * rstd\n y = x_hat * gamma + beta\n if ACTIVATION_SWISH:\n y *= tl.sigmoid(y)\n tl.store(y_ptr + offsets, y, mask=mask)\n\n\ndef get_function_table():\n func_table = []\n from itertools import product\n\n with_swish = [True, False]\n dtypes = [\"fp32\", \"fp16\"]\n blocks = [16, 32, 64, 128]\n hw_sizes = [8, 16, 32, 64, 128, 256]\n warps = [1, 2, 4, 8, 16]\n name_pattern = \"GroupNormTriton_{}_{}_b{}_hw{}_w{}\"\n sig_pattern = \"*{},*{},*fp32,*fp32,i32,i32,i32,fp32\"\n group_pattern = \"GroupNormTriton_{}_{}\"\n\n for swish, dtype, hw_size, warp, b in product(with_swish, dtypes, hw_sizes, warps, blocks):\n swish_suffix = \"Swish\" if swish else \"Pass\"\n name = name_pattern.format(swish_suffix, dtype, b, hw_size, warp)\n group = group_pattern.format(swish_suffix, dtype)\n sig = sig_pattern.format(dtype, dtype)\n kwargs = {\n \"num_warps\": warp,\n \"constants\": {\"BLOCK_SIZE\": b, \"HW_SIZE\": hw_size, \"ACTIVATION_SWISH\": int(swish)},\n }\n func_desc = {\"name\": name, \"group\": group, \"func\": group_norm_kernel, \"sig\": sig, \"kwargs\": kwargs}\n func_table.append(func_desc)\n return func_table\n\n\nif __name__ == \"__main__\":\n func_table = get_function_table()\n for func_desc in func_table:\n print(func_desc)\n", - "description_1": "Use triton language to implement a group normalization kernel `group_norm_kernel` that takes 11 parameters: input and output pointers, gamma and beta pointers, image size, channels, channels per group, epsilon for numerical stability, BLOCK_SIZE, HW_SIZE, and ACTIVATION_SWISH. This kernel computes group normalization for a batch of inputs with the option to apply swish activation. The calling function `get_function_table` generates multiple configurations of this kernel with varying parameters like block size, hardware sizes, and activation choices.", - "description_2": "Use triton language to create a group normalization kernel with optional swish activation and generate multiple kernel configurations for different hardware and execution parameters.", - "difficulty": 3 - }, - { - "code": "import triton\nimport triton.language as tl\n\n@triton.jit\ndef softmax_kernel(output_ptr, input_ptr, input_row_stride, output_row_stride, n_cols, BLOCK_SIZE: tl.constexpr):\n # The rows of the softmax are independent, so we parallelize across those\n row_idx = tl.program_id(0)\n # The stride represents how much we need to increase the pointer to advance 1 row\n row_start_ptr = input_ptr + row_idx * input_row_stride\n # The block size is the next power of two greater than n_cols, so we can fit each\n # row in a single block\n col_offsets = tl.arange(0, BLOCK_SIZE)\n input_ptrs = row_start_ptr + col_offsets\n # Load the row into SRAM, using a mask since BLOCK_SIZE may be > than n_cols\n row = tl.load(input_ptrs, mask=col_offsets < n_cols, other=-float(\"inf\"))\n row_f32 = row.to(tl.float32)\n # Subtract maximum for numerical stability\n row_minus_max = row_f32 - tl.max(row_f32, axis=0)\n # Note that exponentials in Triton are fast but approximate (i.e., think __expf in CUDA)\n numerator = tl.exp(row_minus_max)\n denominator = tl.sum(numerator, axis=0)\n softmax_output = numerator / denominator\n # Write back output to DRAM\n output_row_start_ptr = output_ptr + row_idx * output_row_stride\n output_ptrs = output_row_start_ptr + col_offsets\n tl.store(output_ptrs, softmax_output.to(row.dtype), mask=col_offsets < n_cols)\n\ndtypes = [\"fp32\", \"fp16\"]\nblocks = [1024, 2048, 4096, 8192, 16384]\nname_pattern = \"softmax_{}_{}\"\nsig_pattern = \"*{},*{},i32,i32,i32\"\ngroup_pattern = \"softmax_{}\"\n\ndef get_function_table():\n func_table = []\n\n def get_num_warps(block_size):\n num_warps = 4\n if block_size >= 2048:\n num_warps = 8\n if block_size >= 4096:\n num_warps = 16\n return num_warps\n\n for dtype in dtypes:\n for b in blocks:\n name = name_pattern.format(dtype, b)\n group = group_pattern.format(dtype)\n sig = sig_pattern.format(dtype, dtype)\n num_warps = get_num_warps(b)\n kwargs = {\"num_warps\": num_warps, \"constants\": {\"BLOCK_SIZE\": b}}\n func_desc = {\"name\": name, \"group\": group, \"func\": softmax_kernel, \"sig\": sig, \"kwargs\": kwargs}\n func_table.append(func_desc)\n\n return func_table\n", - "description_1": "Use triton language to implement a softmax kernel that operates over rows of a matrix stored in DRAM. The kernel takes 6 parameters: output_ptr (pointer to output matrix in DRAM), input_ptr (pointer to input matrix in DRAM), input_row_stride (stride to move between rows of the input matrix), output_row_stride (stride to move between rows of the output matrix), n_cols (number of columns in the matrix), and BLOCK_SIZE (block size for parallelization). It performs the softmax computation for each row independently, leveraging parallelism across rows.", - "description_2": "Use triton language to define a function table creation that organizes softmax kernel functions for different data types and block sizes, by generating appropriate function descriptions with varying numbers of warps and constants.", - "difficulty": 4 - }, - { - "code": "import triton\nimport triton.language as tl\nimport torch\n\n\n@triton.autotune(\n configs=[\n triton.Config({\"XBLOCK\": 128}, num_warps=4),\n triton.Config({\"XBLOCK\": 256}, num_warps=8),\n ],\n key=[\"xnumel\"],\n)\n@triton.jit\ndef elementwise_kernel(x_input, x_output, xnumel, XBLOCK: tl.constexpr):\n xnumel = xnumel\n xoffset = tl.program_id(0) * XBLOCK\n xindex = xoffset + tl.arange(0, XBLOCK)\n xmask = xindex < xnumel\n\n x0 = xindex // 8\n x1 = xindex % 8\n\n x_input_val = tl.load(x_input + x0 * 8 + x1, xmask, other=0.0)\n x_output_val = tl.exp(x_input_val)\n tl.store(x_output + x0 * 8 + x1, x_output_val, xmask)\n\n\ndef launch_elementwise_kernel(x_input, x_output, xnumel):\n grid = lambda meta: (triton.cdiv(xnumel, meta[\"XBLOCK\"]),)\n elementwise_kernel[grid](x_input, x_output, xnumel, XBLOCK=128)\n", - "description_1": "Use triton language to implement an elementwise operation on input tensor, applying exponential function to each element. The kernel has parameters for input tensor, output tensor, and number of elements to process. Use grid strategy based on input element size.", - "description_2": "Implement an elementwise exponential function using Triton for input tensors, utilizing grid-based parallel execution for performance.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.autotune(\n configs=[\n # Add autotune configurations here\n ],\n key=[\"M\", \"N\", \"K\"],\n)\n@triton.jit\ndef kernel_mm(\n A, B, OUT, M, N, K, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, GROUP_M: tl.constexpr\n):\n # Triton kernel for matrix multiplication\n pid = tl.program_id(0)\n grid_m = (M + BLOCK_M - 1) // BLOCK_M\n grid_n = (N + BLOCK_N - 1) // BLOCK_N\n\n width = GROUP_M * grid_n\n group_id = pid // width\n group_size = min(grid_m - group_id * GROUP_M, GROUP_M)\n pid_m = group_id * GROUP_M + (pid % group_size)\n pid_n = (pid % width) // (group_size)\n\n rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)\n rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)\n ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)\n rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)\n rk = tl.arange(0, BLOCK_K)\n A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)\n B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)\n\n acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)\n for k in range(K, 0, -BLOCK_K):\n a = tl.load(A)\n b = tl.load(B)\n acc += tl.dot(a, b, allow_tf32=True)\n A += BLOCK_K * stride_ak\n B += BLOCK_K * stride_bk\n\n rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)\n rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)\n idx_m = rm[:, None]\n idx_n = rn[None, :]\n mask = (idx_m < M) & (idx_n < N)\n OUT = OUT + (idx_m * N + idx_n)\n tl.store(OUT, acc, mask=mask)\n\ndef mm_func(a, b, out):\n # Function to call the Triton kernel\n grid = lambda META: (triton.cdiv(M, META[\"BLOCK_M\"]) * triton.cdiv(N, META[\"BLOCK_N\"]),)\n kernel_mm[grid](a, b, out, M, N, K)\n", - "description_1": "Use triton language to implement a matrix multiplication kernel with parameters for input matrices A and B, output matrix OUT, dimensions M, N, K, and block sizes BLOCK_M, BLOCK_N, BLOCK_K. The kernel reorders program IDs for better performance and uses a loop to accumulate results in acc, which is then stored in OUT.", - "description_2": "Use triton language to create a matrix multiplication kernel and a function to execute it, handling input matrices and output storage.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n# Triton kernel for slice log softmax\n@triton.jit\ndef _triton_slice_log_softmax(log_prob, logit, d: tl.constexpr, c: tl.constexpr, RBLOCK: tl.constexpr):\n xoffset = tl.program_id(0)\n logit_xoffset = (xoffset // d * (d + 1) + xoffset % d) * c\n rbase = tl.arange(0, RBLOCK)\n logit_max_row = tl.zeros([RBLOCK], tl.float32) + float(\"-inf\")\n for roffset in range(0, c, RBLOCK):\n rindex = rbase + roffset\n rmask = rindex < c\n logit_row = tl.load(logit + logit_xoffset + rindex, mask=rmask, other=0.0).to(tl.float32)\n logit_max_row = tl.where(rmask & (logit_max_row < logit_row), logit_row, logit_max_row)\n logit_max_reduced = tl.max(logit_max_row, axis=0)\n exp_sum_row = tl.zeros([RBLOCK], tl.float32)\n for roffset in range(0, c, RBLOCK):\n rindex = rbase + roffset\n rmask = rindex < c\n logit_row = tl.load(logit + logit_xoffset + rindex, mask=rmask, other=0.0).to(tl.float32)\n exp_sum_row = tl.where(rmask, exp_sum_row + tl.exp(logit_row - logit_max_reduced), exp_sum_row)\n reduced_log_sum = tl.log(tl.sum(exp_sum_row, axis=0)) + logit_max_reduced\n for roffset in range(0, c, RBLOCK):\n rindex = rbase + roffset\n rmask = rindex < c\n logit_row = tl.load(logit + logit_xoffset + rindex, mask=rmask, other=0.0).to(tl.float32)\n output_row = logit_row - reduced_log_sum\n tl.store(log_prob + xoffset * c + rindex, output_row, mask=rmask)\n\n# Triton kernel for slice softmax cross-entropy loss\n@triton.jit\ndef _triton_slice_scel(\n loss,\n factor,\n log_prob,\n label,\n ignore_index,\n d: tl.constexpr,\n c: tl.constexpr,\n n_cols: tl.constexpr,\n RBLOCK: tl.constexpr,\n):\n rbase = tl.arange(0, RBLOCK)\n neg_sum_row = tl.zeros([RBLOCK], tl.float32)\n factor_row = tl.zeros([RBLOCK], tl.float32)\n for roffset in range(0, n_cols, RBLOCK):\n rindex = rbase + roffset\n rmask = rindex < n_cols\n label_row = tl.load(label + (rindex // d) * (d + 1) + rindex % d + 1, mask=rmask, other=0.0).to(tl.int32)\n mask = rmask & (label_row != ignore_index)\n log_prob_row = tl.load(log_prob + rindex * c + label_row, mask=mask, other=0.0)\n neg_sum_row = tl.where(mask, neg_sum_row - log_prob_row, neg_sum_row)\n factor_row = tl.where(mask, factor_row + 1.0, factor_row)\n reduced_neg_sum = tl.sum(neg_sum_row, axis=0)\n reduced_factor = tl.sum(factor_row, axis=0)\n loss_value = reduced_neg_sum / reduced_factor\n tl.store(loss, loss_value)\n tl.store(factor, reduced_factor)\n\n# Function to compute slice softmax cross-entropy loss\ndef slice_scel(logit, label, ignore_index):\n ignore_index_value = ignore_index.item()\n c = logit.shape[-1]\n logit_d = logit.shape[-2]\n d = logit_d - 1\n n = logit.numel() // (logit_d * c)\n log_prob_shape = list(logit.shape)[:-2] + [d, c]\n log_prob = torch.empty(log_prob_shape, dtype=torch.float, device=logit.device)\n rblock = 4096 if c > 4096 else triton.next_power_of_2(c)\n num_warps = 16 if rblock >= 4096 else (8 if rblock >= 2048 else 4)\n _triton_slice_log_softmax[(n * d,)](log_prob, logit, d, c, num_warps=num_warps, RBLOCK=rblock)\n loss = torch.empty([], dtype=logit.dtype, device=logit.device)\n factor = torch.empty([], dtype=torch.float, device=logit.device)\n n_cols = n * d\n rblock = 1024 if n_cols > 1024 else triton.next_power_of_2(n_cols)\n _triton_slice_scel[(1,)](loss, factor, log_prob, label, ignore_index_value, d, c, n_cols, RBLOCK=rblock)\n return loss, log_prob, factor\n\n# Triton kernel for slice softmax cross-entropy loss backward\n@triton.jit\ndef _triton_slice_scel_backward(\n dlogit,\n dloss,\n log_prob,\n label,\n factor,\n d: tl.constexpr,\n c: tl.constexpr,\n n_elements: tl.constexpr,\n XBLOCK: tl.constexpr,\n):\n xoffset = tl.program_id(0) * XBLOCK\n xindex = xoffset + tl.arange(0, XBLOCK)\n xmask = xindex < n_elements\n nd_index = xindex // c\n dlogit_nd_index = (nd_index // d) * (d + 1) + nd_index % d\n label_nd_index = dlogit_nd_index + 1\n c_index = xindex % c\n dloss_value = tl.load(dloss).to(tl.float32)\n log_prob_row = tl.load(log_prob + xindex, mask=xmask, other=0.0)\n label_row = tl.load(label + label_nd_index, mask=xmask, other=0.0).to(tl.int32)\n factor_value = tl.load(factor)\n dlogit_row = dloss_value * (tl.exp(log_prob_row) - tl.where(c_index == label_row, 1.0, 0.0)) / factor_value\n tl.store(dlogit + dlogit_nd_index * c + c_index, dlogit_row, mask=xmask)\n\n# Triton kernel for slice softmax cross-entropy loss backward with bias\n@triton.jit\ndef _triton_slice_scel_bias_backward(\n dlogit,\n dloss,\n log_prob,\n label,\n factor,\n bias,\n dlogit_d: tl.constexpr,\n c: tl.constexpr,\n n_elements: tl.constexpr,\n XBLOCK: tl.constexpr,\n):\n xoffset = tl.program_id(0) * XBLOCK\n xindex = xoffset + tl.arange(0, XBLOCK)\n xmask = xindex < n_elements\n dlogit_nd_index = xindex // c\n dlogit_n_index = dlogit_nd_index // dlogit_d\n dlogit_d_index = dlogit_nd_index % dlogit_d\n nd_index = dlogit_n_index * (dlogit_d - 1) + dlogit_d_index\n nd_mask = xmask & (dlogit_d_index != dlogit_d - 1)\n c_index = xindex % c\n dloss_value = tl.load(dloss).to(tl.float32)\n log_prob_row = tl.load(log_prob + nd_index * c + c_index, mask=nd_mask, other=0.0)\n label_row = tl.load(label + dlogit_nd_index + 1, mask=nd_mask, other=0.0).to(tl.int32)\n factor_value = tl.load(factor)\n bias_row = tl.load(bias + xindex, mask=xmask, other=0.0).to(tl.float32)\n dlogit_row = dloss_value * (tl.exp(log_prob_row) - tl.where(c_index == label_row, 1.0, 0.0)) / factor_value\n dlogit_row = tl.where(nd_mask, dlogit_row, 0.0) + bias_row\n tl.store(dlogit + xindex, dlogit_row, mask=xmask)\n\n# Function to compute slice softmax cross-entropy loss backward\ndef slice_scel_backward(dloss, log_prob, label, factor, bias):\n c = log_prob.shape[-1]\n d = log_prob.shape[-2]\n dlogit_d = d + 1\n dlogit_shape = list(log_prob.shape)[:-2] + [dlogit_d, c]\n dlogit = (\n torch.empty(dlogit_shape, dtype=dloss.dtype, device=dloss.device)\n if bias is not None\n else torch.zeros(dlogit_shape, dtype=dloss.dtype, device=dloss.device)\n )\n n_elements = dlogit.numel() if bias is not None else log_prob.numel()\n xblock = 1024 if n_elements > 1024 else triton.next_power_of_2(n_elements)\n\n def grid(meta):\n return (triton.cdiv(n_elements, meta[\"XBLOCK\"]),)\n\n if bias is not None:\n _triton_slice_scel_bias_backward[grid](\n dlogit, dloss, log_prob, label, factor, bias, dlogit_d, c, n_elements, XBLOCK=xblock\n )\n else:\n _triton_slice_scel_backward[grid](dlogit, dloss, log_prob, label, factor, d, c, n_elements, XBLOCK=xblock)\n return dlogit\n", - "description_1": "Use triton language to implement slice log softmax and slice softmax cross-entropy loss with backward pass. The kernels handle operations on tensors with specific dimensions and compute gradients efficiently.", - "description_2": "Use triton language to implement slice log softmax and slice softmax cross-entropy loss with backward pass.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\nfrom triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time\n\n\ndef init_to_zero(name):\n return lambda nargs: nargs[name].zero_()\n\ndef get_configs_io_bound():\n configs = []\n for num_stages in [2, 3, 4, 5, 6]:\n for block_m in [16, 32]:\n for block_k in [32, 64]:\n for block_n in [32, 64, 128, 256]:\n num_warps = 2 if block_n <= 64 else 4\n configs.append(\n triton.Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': 1},\n num_stages=num_stages, num_warps=num_warps))\n for split_k in [2, 4, 8, 16]:\n configs.append(triton.Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': split_k},\n num_stages=num_stages, num_warps=num_warps, pre_hook=init_to_zero('C')))\n return configs\n\n@triton.autotune(\n configs=[\n triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8),\n triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8),\n triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=5, num_warps=2),\n triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=3, num_warps=8),\n triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=3, num_warps=8),\n triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=5, num_warps=2),\n ] + get_configs_io_bound(),\n key=['M', 'N', 'K'],\n prune_configs_by={\n 'early_config_prune': early_config_prune,\n 'perf_model': estimate_matmul_time,\n 'top_k': 10\n },\n)\n@triton.heuristics({\n 'EVEN_K': lambda args: args['K'] % (args['BLOCK_K'] * args['SPLIT_K']) == 0,\n})\n@triton.jit\ndef _int8_matmul_mixed_dequantize(A, B, C, bias, state_x_ptr, state_w_ptr, M, N, K, divfactor: tl.constexpr, has_bias: tl.constexpr,\n stride_am, stride_ak,\n stride_bk, stride_bn,\n stride_cm, stride_cn,\n BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,\n GROUP_M: tl.constexpr, SPLIT_K: tl.constexpr, EVEN_K: tl.constexpr,\n ACC_TYPE: tl.constexpr\n ):\n pid = tl.program_id(0)\n pid_z = tl.program_id(1)\n grid_m = tl.cdiv(M, BLOCK_M)\n grid_n = tl.cdiv(N, BLOCK_N)\n width = GROUP_M * grid_n\n group_id = pid // width\n group_size = min(grid_m - group_id * GROUP_M, GROUP_M)\n pid_m = group_id * GROUP_M + (pid % group_size)\n pid_n = (pid % width) // group_size\n rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)\n rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)\n ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)\n rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)\n rk = pid_z * BLOCK_K + tl.arange(0, BLOCK_K)\n A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)\n B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)\n rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)\n rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)\n w_factor = tl.load(state_w_ptr)\n x_factor = tl.load(state_x_ptr + ram)[:, None]\n acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.int32)\n for k in range(0, tl.cdiv(K, BLOCK_K * SPLIT_K)):\n if EVEN_K:\n a = tl.load(A)\n b = tl.load(B)\n else:\n k_remaining = K - k * (BLOCK_K * SPLIT_K)\n a = tl.load(A, mask=rk[None, :] < k_remaining, other=0.)\n b = tl.load(B, mask=rk[:, None] < k_remaining, other=0.)\n acc += tl.dot(a, b)\n A += BLOCK_K * SPLIT_K * stride_ak\n B += BLOCK_K * SPLIT_K * stride_bk\n acc = (w_factor * (x_factor * (acc * divfactor)))\n acc = acc.to(C.dtype.element_ty)\n if has_bias:\n bias = tl.load(bias + rn).to(C.dtype.element_ty)\n acc = acc + bias[None, :]\n C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)\n mask = (rm < M)[:, None] & (rn < N)[None, :]\n if SPLIT_K == 1:\n tl.store(C, acc, mask=mask)\n else:\n tl.atomic_add(C, acc, mask=mask)\n\n\ndef int8_matmul_mixed_dequanitze(a, b, state_x, state_w, bias):\n device = a.device\n divfactor = 1. / (127. * 127.)\n has_bias = 0 if bias is None else 1\n if a.stride(0) > 1 and a.stride(1) > 1:\n a = a.contiguous()\n if b.stride(0) > 1 and b.stride(1) > 1:\n b = b.contiguous()\n assert a.shape[1] == b.shape[0], \"incompatible dimensions\"\n M, K = a.shape\n _, N = b.shape\n c = torch.empty((M, N), device=device, dtype=torch.float16)\n ACC_TYPE = tl.float32\n grid = lambda META: (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), META['SPLIT_K'])\n _int8_matmul_mixed_dequantize[grid](a, b, c, bias, state_x, state_w, M, N, K, divfactor, has_bias,\n a.stride(0), a.stride(1),\n b.stride(0), b.stride(1),\n c.stride(0), c.stride(1),\n GROUP_M=8, ACC_TYPE=ACC_TYPE)\n return c\n", - "description_1": "Use triton language to implement an int8 matrix multiplication and dequantization kernel, supporting row-wise quantized input and global quantized weight with optional bias. The kernel takes two matrices 'a' and 'b', and additional parameters including quantization states and bias. It handles varying block sizes, splits for K dimension, and optimizes for performance using autotuning and heuristics.", - "description_2": "Use triton language to create a matrix multiplication and dequantization kernel for int8 inputs with optimizations for different block sizes and splits.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\nfrom triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time\n\n@triton.autotune(\n configs=[\n triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8),\n triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8),\n triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=5, num_warps=2),\n triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=3, num_warps=8),\n triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=3, num_warps=8),\n triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=5, num_warps=2),\n ],\n key=['M', 'N', 'K'],\n prune_configs_by={\n 'early_config_prune': early_config_prune,\n 'perf_model': estimate_matmul_time,\n 'top_k': 10\n },\n)\n@triton.heuristics({\n 'EVEN_K': lambda args: args['K'] % (args['BLOCK_K'] * args['SPLIT_K']) == 0,\n})\n@triton.jit\ndef _int8_matmul_rowwise_dequantize(A, B, C, bias, state_x_ptr, state_w_ptr, M, N, K, divfactor, has_bias : tl.constexpr,\n stride_am, stride_ak,\n stride_bk, stride_bn,\n stride_cm, stride_cn,\n BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,\n GROUP_M: tl.constexpr, SPLIT_K: tl.constexpr, EVEN_K: tl.constexpr,\n ACC_TYPE: tl.constexpr\n ):\n pid = tl.program_id(0)\n pid_z = tl.program_id(1)\n grid_m = tl.cdiv(M, BLOCK_M)\n grid_n = tl.cdiv(N, BLOCK_N)\n width = GROUP_M * grid_n\n group_id = pid // width\n group_size = min(grid_m - group_id * GROUP_M, GROUP_M)\n pid_m = group_id * GROUP_M + (pid % group_size)\n pid_n = (pid % width) // (group_size)\n rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)\n rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)\n ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)\n rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)\n rk = pid_z * BLOCK_K + tl.arange(0, BLOCK_K)\n A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)\n B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)\n\n rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)\n rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)\n\n w_factor = tl.load(state_w_ptr + rbn)[None, :]\n x_factor = tl.load(state_x_ptr + ram)[:, None]\n\n acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.int32)\n for k in range(0, tl.cdiv(K, BLOCK_K * SPLIT_K)):\n if EVEN_K:\n a = tl.load(A)\n b = tl.load(B)\n else:\n k_remaining = K - k * (BLOCK_K * SPLIT_K)\n a = tl.load(A, mask=rk[None, :] < k_remaining, other=0.)\n b = tl.load(B, mask=rk[:, None] < k_remaining, other=0.)\n acc += tl.dot(a, b)\n A += BLOCK_K * SPLIT_K * stride_ak\n B += BLOCK_K * SPLIT_K * stride_bk\n \n acc = (w_factor * (x_factor * (acc * divfactor)))\n acc = acc.to(C.dtype.element_ty)\n\n if has_bias:\n bias = tl.load(bias + rn).to(C.dtype.element_ty)\n acc = acc + bias[None, :]\n\n C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)\n mask = (rm < M)[:, None] & (rn < N)[None, :]\n if SPLIT_K == 1:\n tl.store(C, acc, mask=mask)\n else:\n tl.atomic_add(C, acc, mask=mask)\n\n\ndef int8_matmul_rowwise_dequantize(a, b, state_x, state_w, bias):\n divfactor = 1. / (127. * 127.)\n\n has_bias = 0 if bias is None else 1\n\n device = a.device\n if a.stride(0) > 1 and a.stride(1) > 1:\n a = a.contiguous()\n if b.stride(0) > 1 and b.stride(1) > 1:\n b = b.contiguous()\n assert a.shape[1] == b.shape[0], \"incompatible dimensions\"\n M, K = a.shape\n _, N = b.shape\n c = torch.empty((M, N), device=device, dtype=torch.float16)\n ACC_TYPE = tl.float32\n grid = lambda META: (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), META['SPLIT_K'])\n _int8_matmul_rowwise_dequantize[grid](a, b, c, bias, state_x, state_w, M, N, K, divfactor, has_bias,\n a.stride(0), a.stride(1),\n b.stride(0), b.stride(1),\n c.stride(0), c.stride(1),\n GROUP_M=8, ACC_TYPE=ACC_TYPE)\n return c\n", - "description_1": "Use triton language to implement a kernel for int8 matrix multiplication with rowwise dequantization. The kernel function '_int8_matmul_rowwise_dequantize' takes 22 arguments: input matrices A and B, output matrix C, bias, scaling factors 'state_x_ptr' and 'state_w_ptr', matrix dimensions M, N, K, a dequantization factor 'divfactor', a flag 'has_bias' to indicate if bias should be added, and various stride and block size parameters for efficient computation. The function performs a matrix multiplication on quantized inputs, scales the results using pre-computed scaling factors, adds bias if applicable, and writes the results back to the output matrix C.", - "description_2": "Use triton language to implement a fused int8 matrix multiplication with dequantization, supporting bias addition and optimized using autotuning.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\nimport math\n\n# This kernel does fused columnwise quantization and transpose.\n@triton.jit\ndef _quantize_columnwise_and_transpose(\n x_ptr,\n output_ptr,\n output_maxs,\n n_elements,\n M: tl.constexpr, N: tl.constexpr,\n BLOCK_SIZE: tl.constexpr,\n P2: tl.constexpr,\n):\n pid = tl.program_id(axis=0)\n block_start = pid\n p2_arange = tl.arange(0, P2)\n p2_arange_mask = p2_arange < M\n arange = p2_arange * N\n offsets = block_start + arange\n x = tl.load(x_ptr + offsets, mask=p2_arange_mask)\n abs_x = tl.abs(x)\n max_val = tl.max(tl.where(p2_arange_mask, abs_x, 0), axis=0)\n output = tl.libdevice.llrint(127. * (x / max_val))\n\n new_start = pid * M\n new_offsets = new_start + p2_arange\n tl.store(output_ptr + new_offsets, output, mask=p2_arange_mask)\n tl.store(output_maxs + pid, max_val)\n\ndef quantize_columnwise_and_transpose(x: torch.Tensor):\n M, N = x.shape\n output = torch.empty(N, M, device=x.device, dtype=torch.int8)\n output_maxs = torch.empty(x.shape[1], device=x.device, dtype=torch.float16)\n\n P2 = int(2 ** (math.ceil(math.log2(M))))\n\n assert x.is_cuda and output.is_cuda\n n_elements = output.numel()\n grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)\n _quantize_columnwise_and_transpose[grid](x, output, output_maxs, n_elements, M, N, BLOCK_SIZE=M, P2=P2)\n return output, output_maxs\n", - "description_1": "Use triton language to create a kernel function _quantize_columnwise_and_transpose that performs a fused columnwise quantization and transpose operation on an input tensor x. This function takes eight parameters: x_ptr (pointer to input tensor), output_ptr (pointer to output tensor for quantized values), output_maxs (pointer to store max values per column), n_elements (number of elements to process), M (constant, number of rows in input), N (constant, number of columns in input), BLOCK_SIZE (constant, size of blocks to process), and P2 (constant, power of two greater than or equal to M). It loads values from x, computes absolute values and maximum per column, scales and quantizes them, stores results in output_ptr and maximum values in output_maxs. The function quantize_columnwise_and_transpose is a wrapper that prepares parameters and launches the kernel on the CUDA device.", - "description_2": "Use triton language to perform columnwise quantization and transpose of a CUDA tensor, storing quantized values and column maximums.", - "difficulty": 3 - }, - { - "code": "import triton\nimport triton.language as tl\nimport torch\n\n# Triton kernel for global quantization\n@triton.autotune(\n configs=[\n triton.Config({'BLOCK_SIZE': 1024,}, num_warps=4),\n triton.Config({'BLOCK_SIZE': 2048,}, num_stages=1),\n ],\n key=['n_elements']\n)\n@triton.jit\ndef _quantize_global(\n x_ptr,\n absmax_inv_ptr,\n output_ptr,\n n_elements,\n BLOCK_SIZE: tl.constexpr,\n):\n pid = tl.program_id(axis=0)\n block_start = pid * BLOCK_SIZE\n offsets = block_start + tl.arange(0, BLOCK_SIZE)\n mask = offsets < n_elements\n x = tl.load(x_ptr + offsets, mask=mask)\n absmax_inv = tl.load(absmax_inv_ptr)\n output = tl.libdevice.llrint(127. * (x * absmax_inv))\n tl.store(output_ptr + offsets, output, mask=mask)\n\ndef quantize_global(x: torch.Tensor):\n absmax = x.abs().max().unsqueeze(0)\n absmax_inv = 1./ absmax\n output = torch.empty(*x.shape, device='cuda', dtype=torch.int8)\n assert x.is_cuda and output.is_cuda\n n_elements = output.numel()\n grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)\n _quantize_global[grid](x, absmax_inv, output, n_elements)\n return output, absmax\n\n# Triton kernel for global quantization and transpose\n@triton.autotune(\n configs=[\n triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'GROUP_M': 8}, num_warps=4),\n triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'GROUP_M': 8}, num_warps=4),\n ],\n key=['M', 'N']\n)\n@triton.jit\ndef _quantize_global_transpose(A, absmax_inv_ptr, B, stride_am, stride_an, stride_bn, stride_bm, M, N, \n BLOCK_M : tl.constexpr, \n BLOCK_N : tl.constexpr, \n GROUP_M : tl.constexpr):\n pid = tl.program_id(0)\n grid_m = (M + BLOCK_M - 1) // BLOCK_M\n grid_n = (N + BLOCK_N - 1) // BLOCK_N\n \n width = GROUP_M * grid_n\n group_id = pid // width\n group_size = min(grid_m - group_id * GROUP_M, GROUP_M)\n pid_m = group_id * GROUP_M + (pid % group_size)\n pid_n = (pid % width) // group_size\n \n rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)\n rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)\n A = A + (rm[:, None] * stride_am + rn[None, :] * stride_an)\n mask = (rm < M)[:, None] & (rn < N)[None, :]\n a = tl.load(A, mask=mask)\n absmax_inv = tl.load(absmax_inv_ptr)\n \n # rematerialize to save registers\n rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)\n rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)\n B = B + (rm[:, None] * stride_bm + rn[None, :] * stride_bn)\n mask = (rm < M)[:, None] & (rn < N)[None, :]\n\n output = tl.libdevice.llrint(127. * (a * absmax_inv))\n\n tl.store(B, output, mask=mask)\n\ndef quantize_global_transpose(input):\n absmax = input.abs().max().unsqueeze(0)\n absmax_inv = 1./ absmax\n M, N = input.shape\n out = torch.empty(N, M, device='cuda', dtype=torch.int8)\n \n assert out.size(0) == N and out.size(1) == M\n assert input.stride(0) == 1 or input.stride(1) == 1\n assert out.stride(0) == 1 or out.stride(1) == 1\n \n grid = lambda META: (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']),)\n _quantize_global_transpose[grid](input, absmax_inv, out, input.stride(0), input.stride(1), out.stride(0), out.stride(1), M, N)\n return out, absmax\n", - "description_1": "Use triton language to implement two kernels: one for quantizing a tensor globally and another for quantizing and transposing a tensor. The first kernel (_quantize_global) takes a tensor pointer, inverse of max absolute value, output pointer, and number of elements; applies scaling and stores quantized results. The second kernel (_quantize_global_transpose) handles quantization and transposition by reading an input matrix and writing transposed, quantized data to the output. Auxiliary Python functions wrap these kernels and manage memory and CUDA configurations.", - "description_2": "Use triton language to create two operators: one for tensor global quantization and another for quantization with transposition. Both utilize memory pointers and scaling for efficient processing on GPU.", - "difficulty": 4 - }, - { - "code": "import math\nimport torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef _quantize_rowwise(\n x_ptr,\n output_ptr,\n output_maxs,\n n_elements,\n BLOCK_SIZE: tl.constexpr,\n P2: tl.constexpr,\n):\n pid = tl.program_id(axis=0)\n block_start = pid * BLOCK_SIZE\n arange = tl.arange(0, P2)\n offsets = block_start + arange\n row_mask = arange < BLOCK_SIZE\n x = tl.load(x_ptr + offsets, mask=row_mask)\n \n abs_x = tl.abs(x)\n max_val = tl.max(tl.where(row_mask, abs_x, 0), axis=0)\n output = tl.libdevice.llrint(127. * (x / max_val))\n tl.store(output_ptr + offsets, output, mask=row_mask)\n tl.store(output_maxs + pid, max_val)\n\ndef quantize_rowwise(x: torch.Tensor):\n output = torch.empty(*x.shape, device=x.device, dtype=torch.int8)\n output_maxs = torch.empty(x.shape[0], device=x.device, dtype=torch.float16)\n\n P2 = int(2 ** (math.ceil(math.log2(x.shape[1]))))\n\n assert x.is_cuda and output.is_cuda\n n_elements = output.numel()\n grid = lambda meta: (x.shape[0],)\n _quantize_rowwise[grid](x, output, output_maxs, n_elements, BLOCK_SIZE=x.shape[1], P2=P2)\n return output, output_maxs\n", - "description_1": "Use triton language to create a function '_quantize_rowwise' that quantizes a row of a given tensor. It takes six arguments: (1) 'x_ptr': pointer to input tensor data, (2) 'output_ptr': pointer to store quantized output, (3) 'output_maxs': pointer to store the maximum value per row, (4) 'n_elements': number of elements to process, (5) 'BLOCK_SIZE': size of block to process, and (6) 'P2': nearest power of 2 greater than or equal to the row size. Use 'quantize_rowwise' function to handle Torch tensors and launch the Triton kernel with calculated configurations.", - "description_2": "Use triton language to implement a kernel for row-wise tensor quantization and provide a PyTorch function to handle the data and launch the kernel.", - "difficulty": 2 - }, - { - "code": "import math\nimport torch\nimport triton_pre_mlir as triton\nimport triton_pre_mlir.language as tl\n\n@triton.heuristics({'EVEN_M': lambda args: args['seqlen_q'] % args['BLOCK_M'] == 0, \n 'EVEN_N': lambda args: args['seqlen_k'] % args['BLOCK_N'] == 0, \n 'EVEN_HEADDIM': lambda args: args['headdim'] == args['BLOCK_HEADDIM']})\n@triton.jit\ndef _fwd_kernel(Q, K, V, Bias, Out, Lse, TMP, softmax_scale, stride_qb, stride_qh, stride_qm, stride_kb, \n stride_kh, stride_kn, stride_vb, stride_vh, stride_vn, stride_bb, stride_bh, stride_bm, \n stride_ob, stride_oh, stride_om, nheads, seqlen_q, seqlen_k, seqlen_q_rounded, headdim, \n CACHE_KEY_SEQLEN_Q, CACHE_KEY_SEQLEN_K, BIAS_TYPE: tl.constexpr, IS_CAUSAL: tl.constexpr, \n BLOCK_HEADDIM: tl.constexpr, EVEN_M: tl.constexpr, EVEN_N: tl.constexpr, \n EVEN_HEADDIM: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):\n start_m = tl.program_id(0)\n off_hb = tl.program_id(1)\n off_b = off_hb // nheads\n off_h = off_hb % nheads\n offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)\n offs_n = tl.arange(0, BLOCK_N)\n offs_d = tl.arange(0, BLOCK_HEADDIM)\n\n # Pointer arithmetic for inputs and outputs\n q_ptrs = Q + off_b * stride_qb + off_h * stride_qh + (offs_m[:, None] * stride_qm + offs_d[None, :])\n k_ptrs = K + off_b * stride_kb + off_h * stride_kh + (offs_n[:, None] * stride_kn + offs_d[None, :])\n v_ptrs = V + off_b * stride_vb + off_h * stride_vh + (offs_n[:, None] * stride_vn + offs_d[None, :])\n if BIAS_TYPE == 'vector':\n b_ptrs = Bias + off_b * stride_bb + off_h * stride_bh + offs_n\n elif BIAS_TYPE == 'matrix':\n b_ptrs = Bias + off_b * stride_bb + off_h * stride_bh + (offs_m[:, None] * stride_bm + offs_n[None, :])\n \n # Initialization\n t_ptrs = TMP + off_hb * seqlen_q_rounded + offs_m\n lse_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float('inf')\n m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float('inf')\n acc_o = tl.zeros([BLOCK_M, BLOCK_HEADDIM], dtype=tl.float32)\n\n # Load Q\n if EVEN_M & EVEN_N:\n if EVEN_HEADDIM:\n q = tl.load(q_ptrs)\n else:\n q = tl.load(q_ptrs, mask=offs_d[None, :] < headdim, other=0.0)\n elif EVEN_HEADDIM:\n q = tl.load(q_ptrs, mask=offs_m[:, None] < seqlen_q, other=0.0)\n else:\n q = tl.load(q_ptrs, mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim), other=0.0)\n\n # End range\n end_n = seqlen_k if not IS_CAUSAL else tl.minimum((start_m + 1) * BLOCK_M, seqlen_k)\n\n # Main loop over K, V\n for start_n in range(0, end_n, BLOCK_N):\n start_n = tl.multiple_of(start_n, BLOCK_N)\n if EVEN_N & EVEN_M:\n if EVEN_HEADDIM:\n k = tl.load(k_ptrs + start_n * stride_kn)\n else:\n k = tl.load(k_ptrs + start_n * stride_kn, mask=offs_d[None, :] < headdim, other=0.0)\n elif EVEN_HEADDIM:\n k = tl.load(k_ptrs + start_n * stride_kn, mask=(start_n + offs_n)[:, None] < seqlen_k, other=0.0)\n else:\n k = tl.load(k_ptrs + start_n * stride_kn, mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim), other=0.0)\n\n qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)\n qk += tl.dot(q, k, trans_b=True)\n if not EVEN_N:\n qk += tl.where((start_n + offs_n)[None, :] < seqlen_k, 0, float('-inf'))\n if IS_CAUSAL:\n qk += tl.where(offs_m[:, None] >= (start_n + offs_n)[None, :], 0, float('-inf'))\n if BIAS_TYPE != 'none':\n if BIAS_TYPE == 'vector':\n if EVEN_N:\n bias = tl.load(b_ptrs + start_n).to(tl.float32)\n else:\n bias = tl.load(b_ptrs + start_n, mask=start_n + offs_n < seqlen_k, other=0.0).to(tl.float32)\n bias = bias[None, :]\n elif BIAS_TYPE == 'matrix':\n if EVEN_M & EVEN_N:\n bias = tl.load(b_ptrs + start_n).to(tl.float32)\n else:\n bias = tl.load(b_ptrs + start_n, mask=(offs_m[:, None] < seqlen_q) & ((start_n + offs_n)[None, :] < seqlen_k), other=0.0).to(tl.float32)\n qk = qk * softmax_scale + bias\n m_ij = tl.maximum(tl.max(qk, 1), lse_i)\n p = tl.exp(qk - m_ij[:, None])\n else:\n m_ij = tl.maximum(tl.max(qk, 1) * softmax_scale, lse_i)\n p = tl.exp(qk * softmax_scale - m_ij[:, None])\n l_ij = tl.sum(p, 1)\n acc_o_scale = tl.exp(m_i - m_ij)\n tl.store(t_ptrs, acc_o_scale)\n acc_o_scale = tl.load(t_ptrs)\n acc_o = acc_o * acc_o_scale[:, None]\n if EVEN_N & EVEN_M:\n if EVEN_HEADDIM:\n v = tl.load(v_ptrs + start_n * stride_vn)\n else:\n v = tl.load(v_ptrs + start_n * stride_vn, mask=offs_d[None, :] < headdim, other=0.0)\n elif EVEN_HEADDIM:\n v = tl.load(v_ptrs + start_n * stride_vn, mask=(start_n + offs_n)[:, None] < seqlen_k, other=0.0)\n else:\n v = tl.load(v_ptrs + start_n * stride_vn, mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim), other=0.0)\n p = p.to(v.dtype)\n acc_o += tl.dot(p, v)\n m_i = m_ij\n l_i_new = tl.exp(lse_i - m_ij) + l_ij\n lse_i = m_ij + tl.log(l_i_new)\n\n # Scale acc_o\n o_scale = tl.exp(m_i - lse_i)\n tl.store(t_ptrs, o_scale)\n o_scale = tl.load(t_ptrs)\n acc_o = acc_o * o_scale[:, None]\n\n # Store LSE\n lse_ptrs = Lse + off_hb * seqlen_q_rounded + offs_m\n tl.store(lse_ptrs, lse_i)\n\n # Store result\n out_ptrs = Out + off_b * stride_ob + off_h * stride_oh + (offs_m[:, None] * stride_om + offs_d[None, :])\n if EVEN_M:\n if EVEN_HEADDIM:\n tl.store(out_ptrs, acc_o)\n else:\n tl.store(out_ptrs, acc_o, mask=offs_d[None, :] < headdim)\n elif EVEN_HEADDIM:\n tl.store(out_ptrs, acc_o, mask=offs_m[:, None] < seqlen_q)\n else:\n tl.store(out_ptrs, acc_o, mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim))\n\ndef _flash_attn_forward(q, k, v, bias=None, causal=False, softmax_scale=None):\n (batch, seqlen_q, nheads, d) = q.shape\n (_, seqlen_k, _, _) = k.shape\n assert k.shape == (batch, seqlen_k, nheads, d)\n assert v.shape == (batch, seqlen_k, nheads, d)\n assert d <= 128, 'FlashAttention only support head dimensions up to 128'\n assert q.dtype == k.dtype == v.dtype, 'All tensors must have the same type'\n assert q.dtype in [torch.float16, torch.bfloat16], 'Only support fp16 and bf16'\n assert q.is_cuda and k.is_cuda and v.is_cuda\n\n softmax_scale = softmax_scale or 1.0 / math.sqrt(d)\n has_bias = bias is not None\n bias_type = 'none'\n if has_bias:\n assert bias.dtype in [q.dtype, torch.float]\n assert bias.is_cuda\n assert bias.dim() == 4\n if bias.stride(-1) != 1:\n bias = bias.contiguous()\n if bias.shape[2:] == (1, seqlen_k):\n bias_type = 'vector'\n elif bias.shape[2:] == (seqlen_q, seqlen_k):\n bias_type = 'matrix'\n else:\n raise RuntimeError('Last 2 dimensions of bias must be (1, seqlen_k) or (seqlen_q, seqlen_k)')\n bias = bias.expand(batch, nheads, seqlen_q, seqlen_k)\n\n bias_strides = (bias.stride(0), bias.stride(1), bias.stride(2)) if has_bias else (0, 0, 0)\n\n seqlen_q_rounded = math.ceil(seqlen_q / 128) * 128\n lse = torch.empty((batch, nheads, seqlen_q_rounded), device=q.device, dtype=torch.float32)\n tmp = torch.empty((batch, nheads, seqlen_q_rounded), device=q.device, dtype=torch.float32)\n o = torch.empty_like(q)\n\n BLOCK_HEADDIM = max(triton.next_power_of_2(d), 16)\n BLOCK = 128\n num_warps = 4 if d <= 64 else 8\n grid = lambda META: (triton.cdiv(seqlen_q, META['BLOCK_M']), batch * nheads)\n\n _fwd_kernel[grid](q, k, v, bias, o, lse, tmp, softmax_scale, q.stride(0), q.stride(2), q.stride(1), \n k.stride(0), k.stride(2), k.stride(1), v.stride(0), v.stride(2), v.stride(1), \n *bias_strides, o.stride(0), o.stride(2), o.stride(1), nheads, seqlen_q, seqlen_k, \n seqlen_q_rounded, d, seqlen_q // 32, seqlen_k // 32, bias_type, causal, BLOCK_HEADDIM, \n BLOCK_M=BLOCK, BLOCK_N=BLOCK, num_warps=num_warps, num_stages=1)\n return (o, lse, softmax_scale)\n", - "description_1": "Use triton language to implement a FlashAttention forward pass kernel. This kernel processes Q, K, V matrices with optional bias and causal masking. The kernel supports varying head dimensions and computes the attention output alongside a log-sum-exp tensor. The forward function (_flash_attn_forward) initializes tensors, calculates strides, and invokes the triton kernel with appropriately defined grid and block configurations.", - "description_2": "Use triton language to create a FlashAttention forward pass that processes Q, K, V matrices and optional bias, handling varying head dimensions and causal settings, computing attention outputs.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef _smelu_kernel_forward(\n input_pointer,\n beta: float,\n output_pointer,\n n_elements: int,\n BLOCK_SIZE: tl.constexpr,\n):\n \"\"\" Triton kernel SmeLU forward \"\"\"\n pid = tl.program_id(axis=0)\n block_start = pid * BLOCK_SIZE\n offsets = block_start + tl.arange(0, BLOCK_SIZE)\n mask = offsets < n_elements\n # Load data\n x = tl.load(input_pointer + offsets, mask=mask)\n output = tl.where(x >= beta, x, 0.)\n output = tl.where(tl.abs(x) <= beta, ((x + beta) * (x + beta)) / (4. * beta), output)\n # Write-back output\n tl.store(output_pointer + offsets, output, mask=mask)\n\ndef _smelu_triton_forward(\n input: torch.Tensor,\n beta: float = 2.\n) -> torch.Tensor:\n \"\"\"\n Wrapper function for SmeLU forward triton kernel\n :param input (torch.Tensor): Input tensor of any shape\n :param beta (float): Beta value of SmeLU\n :return (torch.Tensor): Activation of SmeLU\n \"\"\"\n # Init output tensor\n output: torch.Tensor = torch.empty_like(input)\n # Make input contiguous if needed\n if not input.is_contiguous():\n input = input.contiguous()\n # Get number of elements in input\n number_of_elements: int = input.numel()\n # Call triton kernel\n grid = lambda meta: (triton.cdiv(number_of_elements, meta['BLOCK_SIZE']),)\n _smelu_kernel_forward[grid](input, beta, output, number_of_elements, BLOCK_SIZE=1024)\n return output\n\n@triton.jit\ndef _smelu_kernel_backward(\n input_pointer,\n beta: float,\n output_pointer,\n n_elements: int,\n BLOCK_SIZE: tl.constexpr,\n):\n \"\"\" Triton kernel SmeLU backward \"\"\"\n pid = tl.program_id(axis=0)\n block_start = pid * BLOCK_SIZE\n offsets = block_start + tl.arange(0, BLOCK_SIZE)\n mask = offsets < n_elements\n # Load data\n x = tl.load(input_pointer + offsets, mask=mask)\n gradient = tl.where(x >= beta, 1., 0.)\n gradient = tl.where(tl.abs(x) <= beta, 0.5 * (x + beta) / beta, gradient)\n # Write-back output\n tl.store(output_pointer + offsets, gradient, mask=mask)\n\ndef _smelu_triton_backward(\n input: torch.Tensor,\n beta: float = 2.\n) -> torch.Tensor:\n \"\"\"\n Wrapper function for SmeLU backward triton kernel\n :param input (torch.Tensor): Input tensor of any shape\n :param beta (float): Beta value of SmeLU\n :return (torch.Tensor): Gradient of SmeLU\n \"\"\"\n # Init output tensor\n output: torch.Tensor = torch.empty_like(input)\n # Make input contiguous if needed\n if not input.is_contiguous():\n input = input.contiguous()\n # Get number of elements in input\n number_of_elements: int = input.numel()\n # Call triton kernel\n grid = lambda meta: (triton.cdiv(number_of_elements, meta['BLOCK_SIZE']),)\n _smelu_kernel_backward[grid](input, beta, output, number_of_elements, BLOCK_SIZE=1024)\n return output\n", - "description_1": "Use triton language to implement SmeLU activation function with two kernels: one for forward pass and one for backward pass. The forward kernel takes 5 parameters: input_pointer (pointer to input data), beta (float value for SmeLU), output_pointer (pointer to output data), n_elements (number of elements to process), and BLOCK_SIZE (block size for parallel execution). It computes the SmeLU activation and stores the result. The backward kernel also takes 5 parameters: input_pointer, beta, output_pointer, n_elements, and BLOCK_SIZE. It computes the gradient of the SmeLU activation and stores the result. Wrapper functions _smelu_triton_forward and _smelu_triton_backward are provided to call these kernels with PyTorch tensors.", - "description_2": "Use triton language to create a forward and backward kernel for the SmeLU activation function, handling input and output pointers, beta parameter, and element count, with a specified block size for execution.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\ndef blocksparse_flash_attn_varlen_fwd(\n q,\n k,\n v, # (#tokens, n_heads, head_size)\n cu_seqlens_k,\n cu_seqlens_q,\n sm_scale,\n sparse_layout,\n *,\n block_size=64,\n q_block_size=None,\n max_seqlen=None):\n # split q to blocks\n assert isinstance(sparse_layout, (list, tuple))\n _, n_heads, head_size = q.shape\n batch_size = cu_seqlens_k.size(0) - 1\n q_block_size = q_block_size or block_size\n\n assert q.dim() == k.dim() == v.dim() == 3\n assert q.size(1) % k.size(1) == 0\n assert q.size(2) == k.size(2)\n assert k.shape == v.shape\n assert cu_seqlens_k.dim() == 1\n\n q_k_ratio = q.size(1) // k.size(1)\n\n if cu_seqlens_q is None:\n if q.size(0) == batch_size: # decoding only\n cu_seqlens_q = torch.arange(\n 0,\n batch_size + 1,\n dtype=cu_seqlens_k.dtype,\n device=cu_seqlens_k.device,\n )\n elif q.size(0) == k.size(0):\n cu_seqlens_q = cu_seqlens_k\n else:\n raise ValueError(\"cu_seqlens_q must be specified\\\n if it mix of prefilling and decoding.\")\n else:\n assert cu_seqlens_k.size(0) == cu_seqlens_q.size(0)\n\n q_lens = (cu_seqlens_q[1:] - cu_seqlens_q[:-1]).cpu()\n k_lens = (cu_seqlens_k[1:] - cu_seqlens_k[:-1]).cpu()\n\n assert torch.logical_or(q_lens == 1, k_lens == q_lens).all(), (\n \"length of q should either be 1 (decoding) or same as k (prefilling).\")\n\n if max_seqlen:\n assert k_lens.max() <= max_seqlen\n\n n_blocks = (q_lens + q_block_size - 1) // q_block_size\n\n q_batch_ids = torch.tensor(\n [i for i, n in enumerate(n_blocks) for _ in range(n)],\n dtype=cu_seqlens_q.dtype,\n device=cu_seqlens_q.device,\n )\n q_start_sids = torch.tensor(\n [i * q_block_size for n in n_blocks for i in range(n)],\n dtype=cu_seqlens_q.dtype,\n device=cu_seqlens_q.device,\n )\n\n out = q.new_empty(q.shape)\n cu_seqlens_q = cu_seqlens_q.contiguous()\n cu_seqlens_k = cu_seqlens_k.contiguous()\n\n layout_crow_indices, layout_col_indices = sparse_layout\n block_d = triton.next_power_of_2(head_size)\n\n decoding_only = (q_lens == 1).all().item()\n grid = (len(q_start_sids), n_heads, 1)\n\n _fwd_kernel_batch_inference[grid](\n q,\n k,\n v,\n out,\n sm_scale,\n cu_seqlens_q[:-1],\n cu_seqlens_q[1:],\n cu_seqlens_k[:-1],\n cu_seqlens_k[1:],\n q_batch_ids,\n q_start_sids,\n 0,\n *q.stride(),\n 0,\n *k.stride(),\n 0,\n *v.stride(),\n 0,\n *out.stride(),\n layout_crow_indices,\n layout_col_indices,\n *layout_crow_indices.stride(),\n *layout_col_indices.stride(),\n q_k_ratio,\n HAS_BATCH_DIM=False,\n D_HEAD=head_size,\n BLOCK_M=q_block_size,\n BLOCK_N=block_size,\n BLOCK_D=block_d,\n BLOCK_M_LOADING=(16 if decoding_only else\n q_block_size), # smaller for decoding\n EVEN_D=block_d == head_size,\n num_warps=1 if decoding_only else 4,\n num_stages=3)\n\n return out\n\n@triton.jit\ndef _fwd_kernel_inner(\n acc,\n l_i,\n m_i,\n q,\n Q,\n k_block_col_idx,\n layout_col_ptr,\n layout_col_stride_h,\n layout_col_stride_m,\n k_ptrs,\n v_ptrs,\n off_h,\n offs_m,\n offs_n,\n offs_d,\n stride_kt,\n stride_vt,\n sm_scale,\n k_seqlen,\n past_len,\n LAST_K_BLOCK: tl.constexpr,\n BLOCK_M_LOADING: tl.constexpr,\n BLOCK_N: tl.constexpr,\n D_HEAD: tl.constexpr,\n EVEN_D: tl.constexpr,\n M_LT_N: tl.constexpr,\n):\n k_block_id = tl.load(layout_col_ptr + off_h * layout_col_stride_h +\n k_block_col_idx * layout_col_stride_m).to(tl.int32)\n start_n = k_block_id * BLOCK_N\n if LAST_K_BLOCK:\n if EVEN_D:\n k = tl.load(\n k_ptrs + start_n * stride_kt,\n mask=offs_n[None, :] + start_n < k_seqlen,\n )\n else:\n k = tl.load(\n k_ptrs + start_n * stride_kt,\n mask=(offs_n[None, :] + start_n < k_seqlen) &\n (offs_d[:, None] < D_HEAD),\n )\n else:\n if EVEN_D:\n k = tl.load(k_ptrs + start_n * stride_kt)\n else:\n k = tl.load(k_ptrs + start_n * stride_kt,\n mask=offs_d[:, None] < D_HEAD)\n\n qk = tl.zeros([BLOCK_M_LOADING, BLOCK_N], dtype=tl.float32)\n qk += tl.dot(q, k)\n qk *= sm_scale\n\n if LAST_K_BLOCK | M_LT_N:\n qk += tl.where(\n offs_m[:, None] + past_len >= (start_n + offs_n[None, :]),\n 0,\n float(\"-inf\"),\n )\n\n m_ij = tl.maximum(m_i, tl.max(qk, 1))\n p = tl.math.exp2(qk - m_ij[:, None])\n l_ij = tl.sum(p, 1)\n alpha = tl.math.exp2(m_i - m_ij)\n acc = acc * alpha[:, None]\n m_i = m_ij\n l_i = l_i * alpha + l_ij\n\n p = p.to(Q.dtype.element_ty)\n if LAST_K_BLOCK:\n if EVEN_D:\n v = tl.load(\n v_ptrs + start_n * stride_vt,\n mask=offs_n[:, None] + start_n < k_seqlen,\n )\n else:\n v = tl.load(\n v_ptrs + start_n * stride_vt,\n mask=(offs_n[:, None] + start_n < k_seqlen) &\n (offs_d[None, :] < D_HEAD),\n )\n else:\n if EVEN_D:\n v = tl.load(v_ptrs + start_n * stride_vt)\n else:\n v = tl.load(v_ptrs + start_n * stride_vt,\n mask=offs_d[None, :] < D_HEAD)\n\n acc += tl.dot(p, v)\n\n return acc, l_i, m_i\n\n@triton.heuristics({\n \"M_LT_N\":\n lambda kwargs: kwargs[\"BLOCK_M\"] < kwargs[\"BLOCK_N\"],\n})\n@triton.jit\ndef _fwd_kernel_batch_inference(\n Q,\n K,\n V,\n Out,\n sm_scale,\n q_batch_starts,\n q_batch_ends,\n k_batch_starts,\n k_batch_ends,\n q_batch_ids,\n q_start_sids,\n stride_qb,\n stride_qt,\n stride_qh,\n stride_qd,\n stride_kb,\n stride_kt,\n stride_kh,\n stride_kd,\n stride_vb,\n stride_vt,\n stride_vh,\n stride_vd,\n stride_ob,\n stride_ot,\n stride_oh,\n stride_od,\n layout_crow_ptr,\n layout_col_ptr,\n layout_crow_stride_h,\n layout_crow_stride_m,\n layout_col_stride_h,\n layout_col_stride_m,\n q_k_ratio,\n HAS_BATCH_DIM: tl.constexpr,\n D_HEAD: tl.constexpr,\n BLOCK_M: tl.constexpr,\n BLOCK_N: tl.constexpr,\n BLOCK_D: tl.constexpr,\n BLOCK_M_LOADING: tl.constexpr,\n EVEN_D: tl.constexpr,\n M_LT_N: tl.constexpr,\n):\n off_zm = tl.program_id(0)\n off_h = tl.program_id(1)\n\n off_h_for_kv = off_h // q_k_ratio\n\n if HAS_BATCH_DIM:\n off_z = tl.program_id(2)\n Q += off_z * stride_qb\n K += off_z * stride_kb\n V += off_z * stride_vb\n Out += off_z * stride_ob\n start_m = off_zm\n q_start_sid = start_m * BLOCK_M # always 0 for decoding\n else:\n off_z = tl.load(q_batch_ids + off_zm).to(tl.int32)\n q_start_sid = tl.load(q_start_sids + off_zm)\n start_m = q_start_sid // BLOCK_M\n\n offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M_LOADING)\n offs_n = tl.arange(0, BLOCK_N)\n offs_d = tl.arange(0, BLOCK_D)\n\n q_cu_start = tl.load(q_batch_starts + off_z).to(tl.int32)\n q_seqlen = tl.load(q_batch_ends + off_z).to(tl.int32) - q_cu_start\n k_cu_start = tl.load(k_batch_starts + off_z).to(tl.int32)\n k_seqlen = tl.load(k_batch_ends + off_z).to(tl.int32) - k_cu_start\n past_len = k_seqlen - q_seqlen\n\n Q += q_cu_start * stride_qt + off_h * stride_qh\n K += k_cu_start * stride_kt + off_h_for_kv * stride_kh\n V += k_cu_start * stride_vt + off_h_for_kv * stride_vh\n Out += q_cu_start * stride_ot + off_h * stride_oh\n\n q_pbid = (past_len + q_start_sid) // BLOCK_M\n\n if EVEN_D:\n q = tl.load(\n Q + offs_m[:, None] * stride_qt + offs_d[None, :] * stride_qd,\n mask=offs_m[:, None] < q_seqlen,\n )\n else:\n q = tl.load(\n Q + offs_m[:, None] * stride_qt + offs_d[None, :] * stride_qd,\n mask=(offs_m[:, None] < q_seqlen) & (offs_d[None, :] < D_HEAD),\n other=0,\n )\n\n sparse_crow_ptr = (layout_crow_ptr + off_h * layout_crow_stride_h +\n q_pbid * layout_crow_stride_m)\n\n k_block_start = tl.load(sparse_crow_ptr).to(tl.int32)\n k_block_end = tl.load(sparse_crow_ptr + 1).to(tl.int32)\n\n m_i = tl.zeros([BLOCK_M_LOADING], dtype=tl.float32) - float(\"inf\")\n l_i = tl.zeros([BLOCK_M_LOADING], dtype=tl.float32)\n acc = tl.zeros([BLOCK_M_LOADING, BLOCK_D], dtype=tl.float32)\n\n k_ptrs = K + offs_n[None, :] * stride_kt + offs_d[:, None] * stride_kd\n v_ptrs = V + offs_n[:, None] * stride_vt + offs_d[None, :] * stride_vd\n\n sm_scale *= (\n 1.44269504 # 1/log2 as we use base2 for exponential and logarithm\n )\n\n for k_block_col_idx in range(k_block_start, k_block_end - 1):\n acc, l_i, m_i = _fwd_kernel_inner(\n acc,\n l_i,\n m_i,\n q,\n Q,\n k_block_col_idx,\n layout_col_ptr,\n layout_col_stride_h,\n layout_col_stride_m,\n k_ptrs,\n v_ptrs,\n off_h,\n offs_m,\n offs_n,\n offs_d,\n stride_kt,\n stride_vt,\n sm_scale,\n k_seqlen,\n past_len,\n False,\n BLOCK_M_LOADING,\n BLOCK_N,\n D_HEAD,\n EVEN_D,\n M_LT_N,\n )\n\n acc, l_i, m_i = _fwd_kernel_inner(\n acc,\n l_i,\n m_i,\n q,\n Q,\n k_block_end - 1,\n layout_col_ptr,\n layout_col_stride_h,\n layout_col_stride_m,\n k_ptrs,\n v_ptrs,\n off_h,\n offs_m,\n offs_n,\n offs_d,\n stride_kt,\n stride_vt,\n sm_scale,\n k_seqlen,\n past_len,\n True,\n BLOCK_M_LOADING,\n BLOCK_N,\n D_HEAD,\n EVEN_D,\n M_LT_N,\n )\n\n m_i += tl.math.log2(l_i)\n acc = acc / l_i[:, None]\n\n if EVEN_D:\n tl.store(\n Out + offs_m[:, None] * stride_ot + offs_d[None, :] * stride_od,\n acc,\n mask=offs_m[:, None] < q_seqlen,\n )\n else:\n tl.store(\n Out + offs_m[:, None] * stride_ot + offs_d[None, :] * stride_od,\n acc,\n mask=(offs_m[:, None] < q_seqlen) & (offs_d[None, :] < D_HEAD),\n )\n", - "description_1": "Use triton language to implement a blocksparse Flash Attention forward pass kernel and its supporting functions for batch inference. The kernel has multiple parameters such as query, key, value tensors (Q, K, V), scaling factors, and metadata for sparse layout. The kernel handles irregular sequence lengths and optimizes for specific batch dimensions and data layout.", - "description_2": "Use triton language to create a forward kernel for sparse batch Flash Attention, handling variable sequence lengths and leveraging hardware optimizations.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\nfrom vllm.platforms import current_platform\n\nif triton.__version__ >= \"2.1.0\":\n\n @triton.jit\n def _fwd_kernel(\n Q,\n K,\n V,\n K_cache,\n V_cache,\n B_Loc,\n sm_scale,\n k_scale,\n v_scale,\n B_Start_Loc,\n B_Seqlen,\n B_Ctxlen,\n block_size,\n x,\n Out,\n stride_b_loc_b,\n stride_b_loc_s,\n stride_qbs,\n stride_qh,\n stride_qd,\n stride_kbs,\n stride_kh,\n stride_kd,\n stride_vbs,\n stride_vh,\n stride_vd,\n stride_obs,\n stride_oh,\n stride_od,\n stride_k_cache_bs,\n stride_k_cache_h,\n stride_k_cache_d,\n stride_k_cache_bl,\n stride_k_cache_x,\n stride_v_cache_bs,\n stride_v_cache_h,\n stride_v_cache_d,\n stride_v_cache_bl,\n num_queries_per_kv: int,\n BLOCK_M: tl.constexpr,\n BLOCK_DMODEL: tl.constexpr,\n BLOCK_DMODEL_PADDED: tl.constexpr,\n BLOCK_N: tl.constexpr,\n SLIDING_WINDOW: tl.constexpr,\n ):\n # Kernel implementation\n pass\n\n @triton.jit\n def _fwd_kernel_flash_attn_v2(\n Q,\n K,\n V,\n K_cache,\n V_cache,\n B_Loc,\n sm_scale,\n B_Start_Loc,\n B_Seqlen,\n B_Ctxlen,\n block_size,\n x,\n Out,\n stride_b_loc_b,\n stride_b_loc_s,\n stride_qbs,\n stride_qh,\n stride_qd,\n stride_kbs,\n stride_kh,\n stride_kd,\n stride_vbs,\n stride_vh,\n stride_vd,\n stride_obs,\n stride_oh,\n stride_od,\n stride_k_cache_bs,\n stride_k_cache_h,\n stride_k_cache_d,\n stride_k_cache_bl,\n stride_k_cache_x,\n stride_v_cache_bs,\n stride_v_cache_h,\n stride_v_cache_d,\n stride_v_cache_bl,\n num_queries_per_kv: int,\n BLOCK_M: tl.constexpr,\n BLOCK_DMODEL: tl.constexpr,\n BLOCK_N: tl.constexpr,\n ):\n # Kernel implementation\n pass\n\n @triton.jit\n def _fwd_kernel_alibi(\n Q,\n K,\n V,\n K_cache,\n V_cache,\n B_Loc,\n sm_scale,\n k_scale,\n v_scale,\n B_Start_Loc,\n B_Seqlen,\n B_Ctxlen,\n Alibi_slopes,\n block_size,\n x,\n Out,\n stride_b_loc_b,\n stride_b_loc_s,\n stride_qbs,\n stride_qh,\n stride_qd,\n stride_kbs,\n stride_kh,\n stride_kd,\n stride_vbs,\n stride_vh,\n stride_vd,\n stride_obs,\n stride_oh,\n stride_od,\n stride_k_cache_bs,\n stride_k_cache_h,\n stride_k_cache_d,\n stride_k_cache_bl,\n stride_k_cache_x,\n stride_v_cache_bs,\n stride_v_cache_h,\n stride_v_cache_d,\n stride_v_cache_bl,\n num_queries_per_kv: int,\n BLOCK_M: tl.constexpr,\n BLOCK_DMODEL: tl.constexpr,\n BLOCK_DMODEL_PADDED: tl.constexpr,\n BLOCK_N: tl.constexpr,\n ):\n # Kernel implementation\n pass\n\n @torch.inference_mode()\n def context_attention_fwd(q,\n k,\n v,\n o,\n kv_cache_dtype: str,\n k_cache,\n v_cache,\n b_loc,\n b_start_loc,\n b_seq_len,\n b_ctx_len,\n max_input_len,\n k_scale: float = 1.0,\n v_scale: float = 1.0,\n alibi_slopes=None,\n sliding_window=None):\n # Function implementation\n pass\n", - "description_1": "Use triton language to implement three kernels: _fwd_kernel, _fwd_kernel_flash_attn_v2, and _fwd_kernel_alibi, each with specific parameters for handling query, key, value tensors, cache, and other configurations. The context_attention_fwd function orchestrates these kernels based on input parameters, including optional alibi slopes and sliding window configurations.", - "description_2": "Use triton language to implement kernels for forward attention computation with optional alibi and sliding window support, and a function to manage these kernels based on input configurations.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef cdiv_fn(x, y):\n return (x + y - 1) // y\n\n@triton.jit\ndef load_fn(block_ptr, first, second, pad):\n if first and second:\n tensor = tl.load(block_ptr, boundary_check=(0, 1), padding_option=pad)\n elif first:\n tensor = tl.load(block_ptr, boundary_check=(0, ), padding_option=pad)\n elif second:\n tensor = tl.load(block_ptr, boundary_check=(1, ), padding_option=pad)\n else:\n tensor = tl.load(block_ptr)\n return tensor\n\n@triton.jit\ndef _attn_fwd_inner(\n acc, l_i, m_i, q, K_block_ptr, V_block_ptr, start_m, actual_seqlen_k, dropout_p,\n philox_seed, batch_philox_offset, encoded_softmax_block_ptr, block_min, block_max,\n offs_n_causal, masked_blocks, n_extra_tokens, bias_ptr,\n IS_CAUSAL: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,\n BLOCK_N: tl.constexpr, OFFS_M: tl.constexpr, OFFS_N: tl.constexpr, \n PRE_LOAD_V: tl.constexpr, MASK_STEPS: tl.constexpr, ENABLE_DROPOUT: tl.constexpr, \n RETURN_ENCODED_SOFTMAX: tl.constexpr, PADDED_HEAD: tl.constexpr\n):\n for start_n in range(block_min, block_max, BLOCK_N):\n k = load_fn(\n K_block_ptr,\n PADDED_HEAD,\n MASK_STEPS and (n_extra_tokens != 0),\n \"zero\",\n )\n if PRE_LOAD_V:\n v = load_fn(\n V_block_ptr,\n MASK_STEPS and (n_extra_tokens != 0),\n PADDED_HEAD,\n \"zero\",\n )\n qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)\n if MASK_STEPS:\n if (start_n + BLOCK_N == block_max) and (n_extra_tokens != 0):\n boundary_m = tl.full([BLOCK_M], actual_seqlen_k, dtype=tl.int32)\n size_n = start_n + OFFS_N[None, :]\n mask = size_n < boundary_m[:, None]\n qk = tl.where(mask, qk, float(\"-inf\"))\n if IS_CAUSAL:\n causal_boundary = start_n + offs_n_causal\n causal_mask = OFFS_M[:, None] >= causal_boundary[None, :]\n qk = tl.where(causal_mask, qk, float(\"-inf\"))\n qk += tl.dot(q, k)\n if bias_ptr is not None:\n bias = load_fn(bias_ptr, False, MASK_STEPS and (n_extra_tokens != 0), \"zero\")\n qk += bias * 1.44269504089\n m_ij = tl.maximum(m_i, tl.max(qk, 1))\n qk = qk - m_ij[:, None]\n p = tl.math.exp2(qk)\n l_ij = tl.sum(p, 1)\n if ENABLE_DROPOUT:\n philox_offset = (batch_philox_offset + start_m * BLOCK_M * actual_seqlen_k + start_n - BLOCK_N)\n keep = tl.rand(philox_seed, philox_offset) > dropout_p\n if RETURN_ENCODED_SOFTMAX:\n tl.store(encoded_softmax_block_ptr, tl.where(keep, p, -p).to(encoded_softmax_block_ptr.type.element_ty))\n p = tl.where(keep, p, 0.0)\n elif RETURN_ENCODED_SOFTMAX:\n tl.store(encoded_softmax_block_ptr, p.to(encoded_softmax_block_ptr.type.element_ty))\n alpha = tl.math.exp2(m_i - m_ij)\n acc = acc * alpha[:, None]\n if not PRE_LOAD_V:\n v = load_fn(\n V_block_ptr,\n MASK_STEPS and (n_extra_tokens != 0),\n PADDED_HEAD,\n \"zero\",\n )\n l_i = l_i * alpha + l_ij\n m_i = m_ij\n acc += tl.dot(p.to(V_block_ptr.type.element_ty), v)\n V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0))\n K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N))\n if bias_ptr is not None:\n bias_ptr = tl.advance(bias_ptr, (0, BLOCK_N))\n if RETURN_ENCODED_SOFTMAX:\n encoded_softmax_block_ptr = tl.advance(encoded_softmax_block_ptr, (0, BLOCK_N))\n return acc, l_i, m_i\n\n@triton.autotune(\n configs=[\n triton.Config({\"BLOCK_M\": 256, \"BLOCK_N\": 64, \"waves_per_eu\": 2, \"PRE_LOAD_V\": False}, num_stages=1, num_warps=8),\n triton.Config({\"BLOCK_M\": 128, \"BLOCK_N\": 128, \"waves_per_eu\": 2, \"PRE_LOAD_V\": False}, num_stages=1, num_warps=4),\n triton.Config({\"BLOCK_M\": 256, \"BLOCK_N\": 128, \"waves_per_eu\": 2, \"PRE_LOAD_V\": False}, num_stages=1, num_warps=8),\n triton.Config({\"BLOCK_M\": 128, \"BLOCK_N\": 64, \"waves_per_eu\": 1, \"PRE_LOAD_V\": False}, num_stages=1, num_warps=4),\n triton.Config({\"BLOCK_M\": 128, \"BLOCK_N\": 64, \"waves_per_eu\": 3, \"PRE_LOAD_V\": True}, num_stages=1, num_warps=4),\n triton.Config({\"BLOCK_M\": 128, \"BLOCK_N\": 64, \"waves_per_eu\": 3, \"PRE_LOAD_V\": False}, num_stages=1, num_warps=4),\n triton.Config({\"BLOCK_M\": 64, \"BLOCK_N\": 64, \"waves_per_eu\": 4, \"PRE_LOAD_V\": False}, num_stages=1, num_warps=8),\n triton.Config({\"BLOCK_M\": 32, \"BLOCK_N\": 32, \"waves_per_eu\": 4, \"PRE_LOAD_V\": False}, num_stages=1, num_warps=8),\n triton.Config({\"BLOCK_M\": 16, \"BLOCK_N\": 16, \"waves_per_eu\": 1, \"PRE_LOAD_V\": False}, num_stages=1, num_warps=4),\n ],\n key=['IS_CAUSAL', 'dropout_p', 'BLOCK_DMODEL'],\n)\n@triton.jit\ndef attn_fwd(\n Q, K, V, bias, sm_scale, L, Out, stride_qz, stride_qh, stride_qm, stride_qk,\n stride_kz, stride_kh, stride_kn, stride_kk, stride_vz, stride_vh, stride_vk, stride_vn,\n stride_oz, stride_oh, stride_om, stride_on, stride_bz, stride_bh, stride_bm, stride_bn,\n cu_seqlens_q, cu_seqlens_k, dropout_p, philox_seed, philox_offset_base, encoded_softmax,\n HQ: tl.constexpr, HK: tl.constexpr, ACTUAL_BLOCK_DMODEL: tl.constexpr, MAX_SEQLENS_Q: tl.constexpr,\n MAX_SEQLENS_K: tl.constexpr, VARLEN: tl.constexpr, IS_CAUSAL: tl.constexpr, BLOCK_M: tl.constexpr,\n BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, PRE_LOAD_V: tl.constexpr, BIAS_TYPE: tl.constexpr,\n ENABLE_DROPOUT: tl.constexpr, RETURN_ENCODED_SOFTMAX: tl.constexpr\n):\n start_m = tl.program_id(0)\n off_h_q = tl.program_id(1)\n off_z = tl.program_id(2)\n offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)\n offs_n = tl.arange(0, BLOCK_N)\n if VARLEN:\n cu_seqlens_q_start = tl.load(cu_seqlens_q + off_z)\n cu_seqlens_q_end = tl.load(cu_seqlens_q + off_z + 1)\n seqlen_q = cu_seqlens_q_end - cu_seqlens_q_start\n if start_m * BLOCK_M > seqlen_q:\n return\n cu_seqlens_k_start = tl.load(cu_seqlens_k + off_z)\n cu_seqlens_k_end = tl.load(cu_seqlens_k + off_z + 1)\n seqlen_k = cu_seqlens_k_end - cu_seqlens_k_start\n else:\n cu_seqlens_q_start = 0\n cu_seqlens_k_start = 0\n seqlen_q = MAX_SEQLENS_Q\n seqlen_k = MAX_SEQLENS_K\n\n n_blocks = cdiv_fn(seqlen_k, BLOCK_N)\n if IS_CAUSAL:\n n_blocks_seqlen = cdiv_fn((start_m + 1) * BLOCK_M + seqlen_k - seqlen_q, BLOCK_N)\n n_blocks = min(n_blocks, n_blocks_seqlen)\n if n_blocks <= 0:\n o_offset = (off_z * stride_oz + cu_seqlens_q_start * stride_om + off_h_q * stride_oh)\n O_block_ptr = tl.make_block_ptr(\n base=Out + o_offset,\n shape=(seqlen_q, BLOCK_DMODEL),\n strides=(stride_om, stride_on),\n offsets=(start_m * BLOCK_M, 0),\n block_shape=(BLOCK_M, BLOCK_DMODEL),\n order=(1, 0),\n )\n acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=Out.type.element_ty)\n return\n\n GROUP_SIZE: tl.constexpr = HQ // HK\n off_h_k = off_h_q // GROUP_SIZE if GROUP_SIZE != 1 else off_h_q\n\n n_extra_tokens = 0\n if seqlen_k < BLOCK_N:\n n_extra_tokens = BLOCK_N - seqlen_k\n elif seqlen_k % BLOCK_N:\n n_extra_tokens = seqlen_k % BLOCK_N\n padded_head = ACTUAL_BLOCK_DMODEL != BLOCK_DMODEL\n\n q_offset = (off_z * stride_qz + off_h_q * stride_qh + cu_seqlens_q_start * stride_qm)\n Q_block_ptr = tl.make_block_ptr(\n base=Q + q_offset,\n shape=(seqlen_q, ACTUAL_BLOCK_DMODEL),\n strides=(stride_qm, stride_qk),\n offsets=(start_m * BLOCK_M, 0),\n block_shape=(BLOCK_M, BLOCK_DMODEL),\n order=(1, 0),\n )\n k_offset = (off_z * stride_kz + off_h_k * stride_kh + cu_seqlens_k_start * stride_kn)\n K_block_ptr = tl.make_block_ptr(\n base=K + k_offset,\n shape=(ACTUAL_BLOCK_DMODEL, seqlen_k),\n strides=(stride_kk, stride_kn),\n offsets=(0, 0),\n block_shape=(BLOCK_DMODEL, BLOCK_N),\n order=(0, 1),\n )\n v_offset = (off_z * stride_vz + off_h_k * stride_vh + cu_seqlens_k_start * stride_vk)\n V_block_ptr = tl.make_block_ptr(\n base=V + v_offset,\n shape=(seqlen_k, ACTUAL_BLOCK_DMODEL),\n strides=(stride_vk, stride_vn),\n offsets=(0, 0),\n block_shape=(BLOCK_N, BLOCK_DMODEL),\n order=(1, 0),\n )\n if BIAS_TYPE != 0:\n bias_ptr = tl.make_block_ptr(\n base=bias + off_h_q * stride_bh,\n shape=(seqlen_q, seqlen_k),\n strides=(stride_bm, stride_bn),\n offsets=(start_m * BLOCK_M, 0),\n block_shape=(BLOCK_M, BLOCK_N),\n order=(1, 0),\n )\n else:\n bias_ptr = None\n if ENABLE_DROPOUT:\n batch_philox_offset = philox_offset_base + (off_z * HQ + off_h_q) * seqlen_q * seqlen_k\n else:\n batch_philox_offset = 0\n\n if RETURN_ENCODED_SOFTMAX:\n encoded_softmax_block_ptr = tl.make_block_ptr(\n base=encoded_softmax + off_h_q * seqlen_q * seqlen_k,\n shape=(seqlen_q, seqlen_k),\n strides=(seqlen_k, 1),\n offsets=(start_m * BLOCK_M, 0),\n block_shape=(BLOCK_M, BLOCK_N),\n order=(1, 0),\n )\n else:\n encoded_softmax_block_ptr = 0\n m_i = tl.full([BLOCK_M], float(\"-inf\"), dtype=tl.float32)\n l_i = tl.full([BLOCK_M], 1.0, dtype=tl.float32)\n acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)\n qk_scale = sm_scale * 1.44269504089\n q = load_fn(Q_block_ptr, True, padded_head, \"zero\")\n q = (q * qk_scale).to(Q_block_ptr.type.element_ty)\n\n padded_block_k = n_extra_tokens != 0\n is_modulo_mn = not padded_block_k and (seqlen_q % BLOCK_M == 0)\n if IS_CAUSAL:\n masked_blocks = BLOCK_M // BLOCK_N + (not is_modulo_mn)\n else:\n masked_blocks = padded_block_k\n masked_blocks = min(masked_blocks, n_blocks)\n n_full_blocks = n_blocks - masked_blocks\n block_min = 0\n block_max = n_blocks * BLOCK_N\n if n_full_blocks > 0:\n block_max = (n_blocks - masked_blocks) * BLOCK_N\n acc, l_i, m_i = _attn_fwd_inner(\n acc,\n l_i,\n m_i,\n q,\n K_block_ptr,\n V_block_ptr,\n start_m,\n seqlen_k,\n dropout_p,\n philox_seed,\n batch_philox_offset,\n encoded_softmax_block_ptr,\n block_min,\n block_max,\n 0,\n 0,\n 0,\n bias_ptr,\n False,\n BLOCK_M,\n BLOCK_DMODEL,\n BLOCK_N,\n offs_m,\n offs_n,\n PRE_LOAD_V,\n False,\n ENABLE_DROPOUT,\n RETURN_ENCODED_SOFTMAX,\n padded_head,\n )\n block_min = block_max\n block_max = n_blocks * BLOCK_N\n\n tl.debug_barrier()\n if masked_blocks > 0:\n offs_n_causal = offs_n + (seqlen_q - seqlen_k) if IS_CAUSAL else 0\n K_block_ptr = tl.advance(K_block_ptr, (0, n_full_blocks * BLOCK_N))\n V_block_ptr = tl.advance(V_block_ptr, (n_full_blocks * BLOCK_N, 0))\n if bias_ptr is not None:\n bias_ptr = tl.advance(bias_ptr, (0, n_full_blocks * BLOCK_N))\n if RETURN_ENCODED_SOFTMAX:\n encoded_softmax_block_ptr = tl.advance(encoded_softmax_block_ptr, (0, n_full_blocks))\n acc, l_i, m_i = _attn_fwd_inner(\n acc,\n l_i,\n m_i,\n q,\n K_block_ptr,\n V_block_ptr,\n start_m,\n seqlen_k,\n dropout_p,\n philox_seed,\n batch_philox_offset,\n encoded_softmax_block_ptr,\n block_min,\n block_max,\n offs_n_causal,\n masked_blocks,\n n_extra_tokens,\n bias_ptr,\n IS_CAUSAL,\n BLOCK_M,\n BLOCK_DMODEL,\n BLOCK_N,\n offs_m,\n offs_n,\n PRE_LOAD_V,\n True,\n ENABLE_DROPOUT,\n RETURN_ENCODED_SOFTMAX,\n padded_head,\n )\n acc = acc / l_i[:, None]\n if ENABLE_DROPOUT:\n acc = acc / (1 - dropout_p)\n end_m_idx = (start_m + 1) * BLOCK_M\n start_m_idx = start_m * BLOCK_M\n causal_start_idx = seqlen_q - seqlen_k\n acc = acc.to(Out.type.element_ty)\n if IS_CAUSAL:\n if causal_start_idx > start_m_idx and causal_start_idx < end_m_idx:\n out_mask_boundary = tl.full((BLOCK_DMODEL, ), causal_start_idx, dtype=tl.int32)\n mask_m_offsets = start_m_idx + tl.arange(0, BLOCK_M)\n out_ptrs_mask = (mask_m_offsets[:, None] >= out_mask_boundary[None, :])\n z = 0.0\n acc = tl.where(out_ptrs_mask, acc, z.to(acc.type.element_ty))\n o_offset = (off_z * stride_oz + cu_seqlens_q_start * stride_om + off_h_q * stride_oh)\n O_block_ptr = tl.make_block_ptr(\n base=Out + o_offset,\n shape=(seqlen_q, ACTUAL_BLOCK_DMODEL),\n strides=(stride_om, stride_on),\n offsets=(start_m * BLOCK_M, 0),\n block_shape=(BLOCK_M, BLOCK_DMODEL),\n order=(1, 0),\n )\n tl.store(O_block_ptr, acc, boundary_check=(0, 1))\n\n\nclass _attention(torch.autograd.Function):\n\n @staticmethod\n def forward(\n ctx, q, k, v, o, cu_seqlens_q, cu_seqlens_k, max_seqlens_q, max_seqlens_k,\n causal=False, sm_scale=1.0, bias=None,\n ):\n if o is None:\n o = torch.empty_like(q, dtype=v.dtype)\n\n total_q, nheads_q, head_size = q.shape\n total_k, nheads_k, _ = k.shape\n batch = len(cu_seqlens_q) - 1\n q_strides = (0, q.stride(1), q.stride(0), q.stride(2))\n k_strides = (0, k.stride(1), k.stride(0), k.stride(2))\n v_strides = (0, v.stride(1), v.stride(0), v.stride(2))\n o_strides = (0, o.stride(1), o.stride(0), o.stride(2))\n\n unpadded_head_dims = {32, 64, 128, 256}\n if head_size not in unpadded_head_dims:\n padded_d_model = None\n for i in unpadded_head_dims:\n if i > head_size:\n padded_d_model = i\n break\n assert padded_d_model is not None\n else:\n padded_d_model = head_size\n\n grid = lambda META: (\n triton.cdiv(max_seqlens_q, META[\"BLOCK_M\"]),\n nheads_q,\n batch,\n )\n\n encoded_softmax = None\n\n philox_seed = 0x1BF52\n philox_offset = 0x1D4B42\n\n if bias is not None:\n bias_strides = (\n bias.stride(0),\n bias.stride(1),\n bias.stride(2),\n bias.stride(3),\n )\n else:\n bias_strides = (0, 0, 0, 0)\n\n attn_fwd[grid](\n q,\n k,\n v,\n bias,\n sm_scale,\n None,\n o,\n *q_strides,\n *k_strides,\n *v_strides,\n *o_strides,\n *bias_strides,\n cu_seqlens_q,\n cu_seqlens_k,\n dropout_p=0.0,\n philox_seed=philox_seed,\n philox_offset_base=philox_offset,\n encoded_softmax=encoded_softmax,\n HQ=nheads_q,\n HK=nheads_k,\n ACTUAL_BLOCK_DMODEL=head_size,\n MAX_SEQLENS_Q=max_seqlens_q,\n MAX_SEQLENS_K=max_seqlens_k,\n IS_CAUSAL=causal,\n VARLEN=True,\n BLOCK_DMODEL=padded_d_model,\n BIAS_TYPE=0 if bias is None else 1,\n ENABLE_DROPOUT=False,\n RETURN_ENCODED_SOFTMAX=False,\n )\n\n ctx.grid = grid\n ctx.sm_scale = sm_scale\n ctx.BLOCK_DMODEL = head_size\n ctx.causal = causal\n ctx.dropout_p = 0.0\n ctx.philox_seed = philox_seed\n ctx.philox_offset = philox_offset\n ctx.encoded_softmax = encoded_softmax\n ctx.return_encoded_softmax = False\n return o, encoded_softmax\n\n\ntriton_attention = _attention.apply\n", - "description_1": "Use triton language to implement an attention forward function, where `attn_fwd` is a kernel handling matrix multiplications and operations for scaled dot-product attention with optional bias and dropout, and `_attention` is an autograd function for the forward pass in PyTorch. The kernel `attn_fwd` takes 43 parameters for operations like loading data, applying masks, scaling, and storing results, while `_attention` manages PyTorch tensors and sets up kernel configurations.", - "description_2": "Use triton language to define an attention mechanism kernel (`attn_fwd`) and integrate it into PyTorch's autograd system via a wrapper function (`_attention`).", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef _uniform_to_exponential_kernel(input, output, n: tl.constexpr):\n idx = tl.arange(0, n)\n x = tl.load(input + idx)\n y = _uniform_to_exponential(x)\n tl.store(output + idx, y)\n\ndef test_uniform_to_exponential():\n \"\"\"Test that we can convert uniform to exponential without div by 0.\"\"\"\n input = torch.tensor([0.0, 1.0 - torch.finfo(torch.float32).eps],\n dtype=torch.float32,\n device=\"cuda\")\n output = torch.zeros(input.shape, dtype=torch.float32, device=\"cuda\")\n _uniform_to_exponential_kernel[(1, )](input, output, 2)\n assert torch.all(torch.isfinite(output))\n assert torch.all(output > 0)\n assert torch.all(torch.isfinite(torch.full_like(output, 1.0) / output))\n", - "description_1": "Use triton language to implement a kernel that converts uniform random numbers to exponential random numbers. The kernel '_uniform_to_exponential_kernel' takes three parameters: 'input' (a pointer to the input tensor), 'output' (a pointer to the output tensor), and 'n' (a compile-time constant representing the number of elements to process). The kernel uses Triton's parallel programming model to load elements from the input tensor, apply the '_uniform_to_exponential' transformation, and store the results in the output tensor. The function 'test_uniform_to_exponential' is a test function that verifies the kernel's correctness by checking that the output values are finite and greater than zero.", - "description_2": "Use triton language to create a kernel for transforming uniform to exponential distribution and verify its correctness with a test.", - "difficulty": 2 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef _bgmv_shrink_kernel(\n input_ptr,\n lora_ptr,\n out_ptr,\n N,\n K,\n lora_indices,\n scaling,\n xm_stride,\n xk_stride,\n l0_stride,\n lora_k_stride,\n lora_n_stride,\n cm_stride,\n cn_stride,\n BLOCK_N: tl.constexpr,\n BLOCK_K: tl.constexpr,\n SPLIT_K: tl.constexpr,\n):\n \"\"\"\n GroupGEMV, additionally, introducing SPLIT-K can improve large hidden_size's\n performance\n \"\"\"\n pid_sk = tl.program_id(axis=0)\n cur_batch = tl.program_id(axis=1)\n lora_index = tl.load(lora_indices + cur_batch)\n if lora_index == -1:\n return\n\n offset_n = tl.arange(0, BLOCK_N)\n offset_k = tl.arange(0, BLOCK_K) + pid_sk * BLOCK_K\n a_ptr = input_ptr + cur_batch * xm_stride\n b_ptr = lora_ptr + l0_stride * lora_index\n accumulator = tl.zeros((BLOCK_N, ), dtype=tl.float32)\n for k in range(0, K, BLOCK_K * SPLIT_K):\n current_k = k + offset_k\n current_k_c = tl.max_contiguous(current_k, BLOCK_K)\n tiled_a = tl.load(\n a_ptr + current_k_c,\n mask=current_k < K,\n other=0.0,\n ) # [BLOCK_K]\n b_ptr_mask = (offset_n[:, None] < N) & (current_k[None, :] < K)\n\n tiled_b = tl.load(\n b_ptr + offset_n[:, None] * lora_k_stride +\n current_k[None, :] * lora_n_stride,\n mask=b_ptr_mask,\n other=0.0,\n ) # [BLOCK_N,BLOCK_K]\n\n accumulator += tl.sum(tiled_a * tiled_b, 1)\n accumulator *= scaling\n offset_cn = tl.arange(0, BLOCK_N)\n c_ptr = out_ptr + cur_batch * cm_stride + offset_cn * cn_stride\n c_mask = offset_cn < N\n if SPLIT_K == 1:\n tl.store(c_ptr, accumulator, mask=c_mask)\n else:\n tl.atomic_add(c_ptr, accumulator, mask=c_mask)\n\n\n@torch.inference_mode()\ndef _bgmv_shrink(\n inputs: torch.Tensor,\n lora_a_weights: torch.Tensor,\n output_tensor: torch.Tensor,\n lora_indices_tensor: torch.Tensor,\n scaling: float = 1.0,\n) -> None:\n \"\"\"\n Args:\n inputs (torch.Tensor): input tensor\n lora_a_weights (torch.Tensor): lora'a weight\n output_tensor (torch.Tensor): output tensor\n lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index\n corresponding to each batch. An index of -1 means no lora should be\n applied.\n batches (int): batch size\n scaling (float): Scaling factor.\n \"\"\"\n assert inputs.dtype == lora_a_weights.dtype\n assert inputs.dtype in [torch.float16, torch.bfloat16]\n assert lora_a_weights.dtype in [\n torch.float16,\n torch.bfloat16,\n ]\n assert inputs.size(1) == lora_a_weights.size(-1)\n assert inputs.is_contiguous()\n\n if lora_a_weights.ndim == 4: # shape:(lora_num,1,rank, size)\n assert lora_a_weights.size(1) == 1\n lora_a_weights = lora_a_weights.squeeze(dim=1)\n else:\n assert lora_a_weights.ndim == 3 # shape:(lora_num,rank, size)\n assert lora_a_weights.is_contiguous()\n assert output_tensor.is_contiguous()\n # TODO tuning this config\n batches = lora_indices_tensor.size(0)\n N, K = lora_a_weights.shape[-2:] # K=hidden_size,N=rank\n BLOCK_N = triton.next_power_of_2(N)\n # First try to load optimal config from the file\n config = get_lora_op_configs(\"bgmv_shrink\", batches, K)\n\n grid = lambda META: (\n META[\"SPLIT_K\"],\n batches,\n )\n _bgmv_shrink_kernel[grid](\n inputs,\n lora_a_weights,\n output_tensor,\n N,\n K,\n lora_indices_tensor,\n scaling,\n inputs.stride(0),\n inputs.stride(1),\n lora_a_weights.stride(0),\n lora_a_weights.stride(1),\n lora_a_weights.stride(2),\n output_tensor.stride(0),\n output_tensor.stride(1),\n BLOCK_N=BLOCK_N,\n **config,\n )\n return\n", - "description_1": "Use triton language to implement a kernel function '_bgmv_shrink_kernel' with 16 parameters for performing a batched generalized matrix-vector multiplication (GroupGEMV) with optional LoRA (Low-Rank Adaptation) weights. The kernel uses a split-K strategy to improve performance for large hidden sizes. The function '_bgmv_shrink' is a wrapper that prepares the input tensors and launches the Triton kernel with the appropriate grid configuration.", - "description_2": "Use triton language to implement a GroupGEMV kernel with split-K optimization and a wrapper function to handle input preparation and kernel launch.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef _sgmv_expand_kernel(\n input_ptr,\n lora_ptr,\n out_ptr,\n N,\n K,\n b_seq_start_loc,\n seq_lens,\n lora_indices,\n xm_stride,\n xk_stride, \n l0_stride, \n lora_k_stride,\n lora_n_stride,\n cm_stride,\n cn_stride,\n BLOCK_M: tl.constexpr,\n BLOCK_N: tl.constexpr,\n BLOCK_K: tl.constexpr,\n EVEN_K: tl.constexpr,\n ADD_INPUTS: tl.constexpr,\n CAST_TYPE: tl.constexpr,\n):\n \"\"\"\n The sgmv's expand triton kernel is based on GroupGEMM.\n \"\"\"\n pid = tl.program_id(axis=0)\n cur_batch = tl.program_id(axis=1)\n cta_n_num = tl.cdiv(N, BLOCK_N)\n pid_m = pid // cta_n_num\n pid_n = pid % cta_n_num\n M = tl.load(seq_lens + cur_batch)\n if pid_m * BLOCK_M > M:\n return\n lora_index = tl.load(lora_indices + cur_batch)\n if lora_index == -1:\n return\n cur_seq_start = tl.load(b_seq_start_loc + cur_batch)\n offset_m = tl.arange(0, BLOCK_M) + pid_m * BLOCK_M\n offset_n = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N\n offset_k = tl.arange(0, BLOCK_K)\n ram = tl.max_contiguous(tl.multiple_of(offset_m % M, BLOCK_M), BLOCK_M)\n rbn = tl.max_contiguous(tl.multiple_of(offset_n % N, BLOCK_N), BLOCK_N)\n\n a_ptr = (input_ptr + cur_seq_start * xm_stride + ram[:, None] * xm_stride +\n offset_k[None, :] * xk_stride, )\n b_ptr = (lora_ptr + l0_stride * lora_index +\n offset_k[:, None] * lora_n_stride + rbn[None, :] * lora_k_stride)\n accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)\n for k in range(tl.cdiv(K, BLOCK_K)):\n if EVEN_K:\n tiled_a = tl.load(a_ptr)\n tiled_b = tl.load(b_ptr)\n else:\n tiled_a = tl.load(a_ptr,\n mask=offset_k[None, :] < K - k * BLOCK_K,\n other=0)\n tiled_b = tl.load(b_ptr,\n mask=offset_k[:, None] < K - k * BLOCK_K,\n other=0)\n if CAST_TYPE:\n tiled_a = tiled_a.to(lora_ptr.dtype.element_ty)\n accumulator += tl.dot(\n tiled_a,\n tiled_b,\n )\n a_ptr += BLOCK_K * xk_stride\n b_ptr += BLOCK_K * lora_n_stride\n tiled_c = accumulator.to(lora_ptr.dtype.element_ty)\n offset_cm = cur_seq_start + tl.arange(0, BLOCK_M) + pid_m * BLOCK_M\n offset_cn = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N\n c_ptr = (out_ptr + offset_cm[:, None] * cm_stride +\n offset_cn[None, :] * cn_stride)\n M = tl.load(seq_lens + cur_batch)\n c_mask = (offset_cm[:, None] <\n (cur_seq_start + M)) & (offset_cn[None, :] < N)\n if ADD_INPUTS:\n tiled_out = tl.load(c_ptr, mask=c_mask)\n tiled_c += tiled_out\n tl.store(c_ptr, tiled_c, mask=c_mask)\n\n\n@torch.inference_mode()\ndef _sgmv_expand(\n inputs: torch.Tensor,\n lora_b_weights: torch.Tensor,\n output_tensor: torch.Tensor,\n b_seq_start_loc: torch.Tensor,\n seq_len_tensor: torch.Tensor,\n lora_indices_tensor: torch.Tensor,\n batches: int,\n max_seq_length: int,\n add_inputs: bool = False,\n) -> None:\n \"\"\"\n Args:\n inputs (torch.Tensor): input tensor\n lora_b_weights (torch.Tensor): lora'a weight\n output_tensor (torch.Tensor): output tensor\n b_seq_start_loc (torch.Tensor): (batch_size,). The cumulative\n sequence lengths of the sequences in the batch, used to index\n into sequence. E.g.,if the sequence length is [4, 6], it is\n [0, 4, 10].\n seq_len_tensor (torch.Tensor): (batch_size,). record the sequence\n length of the sequences in the batch\n lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index\n corresponding to each batch. An index of -1 means no lora should be\n applied.\n batches (int): batch size\n max_seq_length (int): The max sequence lengths of the sequences\n in the batch\n add_inputs (bool, optional): Defaults to False. adds the final lora \n results to the output.\n \"\"\"\n\n assert inputs.dtype in [torch.float16, torch.bfloat16, torch.float32]\n assert lora_b_weights.dtype in [\n torch.float16,\n torch.bfloat16,\n ]\n assert inputs.size(1) == lora_b_weights.size(-1)\n assert b_seq_start_loc.size(0) == batches\n assert lora_indices_tensor.size(0) == batches\n assert inputs.is_contiguous()\n assert output_tensor.is_contiguous()\n\n if lora_b_weights.ndim == 4: # shape:(lora_num,1,size,rank)\n assert lora_b_weights.size(1) == 1\n lora_b_weights = lora_b_weights.squeeze(dim=1)\n else:\n assert lora_b_weights.ndim == 3 # shape:(lora_num,size,rank)\n\n assert lora_b_weights.is_contiguous()\n\n N, K = lora_b_weights.shape[-2:] # K= rank,N=hidden_size\n BLOCK_M = 32\n BLOCK_N = 32\n BLOCK_K = 16\n EVEN_K = K % BLOCK_K == 0\n ADD_INPUTS = add_inputs\n CAST_TYPE = False\n if inputs.dtype == torch.float32 and lora_b_weights.dtype in [\n torch.float16,\n torch.bfloat16,\n ]:\n CAST_TYPE = True\n grid = (\n triton.cdiv(max_seq_length, BLOCK_M) * triton.cdiv(N, BLOCK_N),\n batches,\n )\n _sgmv_expand_kernel[grid](\n inputs,\n lora_b_weights,\n output_tensor,\n N,\n K,\n b_seq_start_loc,\n seq_len_tensor,\n lora_indices_tensor,\n inputs.stride(0),\n inputs.stride(1),\n lora_b_weights.stride(0),\n lora_b_weights.stride(1),\n lora_b_weights.stride(2),\n output_tensor.stride(0),\n output_tensor.stride(1),\n BLOCK_M,\n BLOCK_N,\n BLOCK_K,\n EVEN_K,\n ADD_INPUTS,\n CAST_TYPE,\n )\n return\n", - "description_1": "Use triton language to define a kernel '_sgmv_expand_kernel' with 22 parameters for sequence processing in GroupGEMM and a wrapper function '_sgmv_expand' with 9 parameters to manage input tensors and perform operations using this kernel.", - "description_2": "Use triton language to create a kernel for sequence-based matrix operations and use it through a wrapper function to handle tensor inputs efficiently.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef _sgmv_expand_slice_kernel(\n input_ptr,\n lora_ptr,\n out_ptr,\n N,\n K,\n b_seq_start_loc,\n seq_lens,\n lora_indices,\n xm_stride,\n xk_stride,\n l0_stride,\n lora_k_stride,\n lora_n_stride,\n cm_stride,\n cn_stride,\n slice_offset,\n BLOCK_M: tl.constexpr,\n BLOCK_N: tl.constexpr,\n BLOCK_K: tl.constexpr,\n EVEN_K: tl.constexpr,\n ADD_INPUTS: tl.constexpr,\n CAST_TYPE: tl.constexpr,\n):\n \"\"\"\n Similar to the 'sgmv_expand' operator, but with an added parameter \n 'slice_offset'. The reason for not reusing the 'sgmv_expand' operator \n might be that in the future, we could implement a fusion operator to \n achieve the current functionality instead of having to call it multiple \n times.\n \"\"\"\n pid = tl.program_id(axis=0)\n cur_batch = tl.program_id(axis=1)\n cta_n_num = tl.cdiv(N, BLOCK_N)\n pid_m = pid // cta_n_num\n pid_n = pid % cta_n_num\n M = tl.load(seq_lens + cur_batch)\n if pid_m * BLOCK_M > M:\n return\n lora_index = tl.load(lora_indices + cur_batch)\n if lora_index == -1:\n return\n cur_seq_start = tl.load(b_seq_start_loc + cur_batch)\n offset_m = tl.arange(0, BLOCK_M) + pid_m * BLOCK_M\n offset_n = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N\n offset_k = tl.arange(0, BLOCK_K)\n ram = tl.max_contiguous(tl.multiple_of(offset_m % M, BLOCK_M), BLOCK_M)\n rbn = tl.max_contiguous(tl.multiple_of(offset_n % N, BLOCK_N), BLOCK_N)\n\n a_ptr = (input_ptr + cur_seq_start * xm_stride + ram[:, None] * xm_stride +\n offset_k[None, :] * xk_stride, )\n b_ptr = (lora_ptr + l0_stride * lora_index +\n offset_k[:, None] * lora_n_stride + rbn[None, :] * lora_k_stride)\n accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)\n for k in range(tl.cdiv(K, BLOCK_K)):\n if EVEN_K:\n tiled_a = tl.load(a_ptr)\n tiled_b = tl.load(b_ptr)\n else:\n tiled_a = tl.load(a_ptr,\n mask=offset_k[None, :] < K - k * BLOCK_K,\n other=0)\n tiled_b = tl.load(b_ptr,\n mask=offset_k[:, None] < K - k * BLOCK_K,\n other=0)\n if CAST_TYPE:\n tiled_a = tiled_a.to(lora_ptr.dtype.element_ty)\n accumulator += tl.dot(\n tiled_a,\n tiled_b,\n )\n a_ptr += BLOCK_K * xk_stride\n b_ptr += BLOCK_K * lora_n_stride\n tiled_c = accumulator.to(lora_ptr.dtype.element_ty)\n offset_cm = cur_seq_start + tl.arange(0, BLOCK_M) + pid_m * BLOCK_M\n offset_cn = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N + slice_offset\n c_ptr = (out_ptr + offset_cm[:, None] * cm_stride +\n offset_cn[None, :] * cn_stride)\n M = tl.load(seq_lens + cur_batch)\n c_mask = (offset_cm[:, None] < (cur_seq_start + M)) & (offset_cn[None, :] <\n (slice_offset + N))\n if ADD_INPUTS:\n tiled_out = tl.load(c_ptr, mask=c_mask)\n tiled_c += tiled_out\n tl.store(c_ptr, tiled_c, mask=c_mask)\n\n\n@torch.inference_mode()\ndef _sgmv_expand_slice(\n inputs: torch.Tensor,\n lora_b_weights: torch.Tensor,\n output_tensor: torch.Tensor,\n b_seq_start_loc: torch.Tensor,\n seq_len_tensor: torch.Tensor,\n lora_indices_tensor: torch.Tensor,\n batches: int,\n max_seq_length: int,\n slice_offset: int,\n slice_size: int,\n add_inputs: bool = False,\n) -> None:\n \"\"\"_summary_\n\n Args:\n inputs (torch.Tensor): input tensor\n lora_b_weights (torch.Tensor): lora'a weight\n output_tensor (torch.Tensor): output tensor\n b_seq_start_loc (torch.Tensor): (batch_size,). The cumulative\n sequence lengths of the sequences in the batch, used to index\n into sequence. E.g.,if the sequence length is [4, 6], it is\n [0, 4, 10].\n seq_len_tensor (torch.Tensor): (batch_size,). record the sequence\n length of the sequences in the batch\n lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index\n corresponding to each batch. An index of -1 means no lora should be\n applied.\n batches (int): batch size\n max_seq_length (int): The max sequence lengths of the sequences\n in the batch\n slice_offst (int): output_tensor's offst\n slice_size (int): current output_tensor's size\n add_inputs (bool, optional): Defaults to False. adds the final lora \n results to the output..\n \"\"\"\n\n assert inputs.dtype in [torch.float16, torch.bfloat16, torch.float32]\n assert lora_b_weights.dtype in [\n torch.float16,\n torch.bfloat16,\n ]\n assert inputs.size(1) == lora_b_weights.size(-1)\n assert b_seq_start_loc.size(0) == batches\n assert lora_indices_tensor.size(0) == batches\n assert slice_size == lora_b_weights.size(-2)\n assert inputs.is_contiguous()\n assert output_tensor.is_contiguous()\n\n if lora_b_weights.ndim == 4: # shape:(lora_num,1,size,rank)\n assert lora_b_weights.size(1) == 1\n lora_b_weights = lora_b_weights.squeeze(dim=1)\n else:\n assert lora_b_weights.ndim == 3 # shape:(lora_num,size,rank)\n\n assert lora_b_weights.is_contiguous()\n\n # TODO tuning this config\n N, K = lora_b_weights.shape[-2:] # K= rank,N=hidden_size\n\n BLOCK_M = 32\n BLOCK_N = 32\n BLOCK_K = 16\n EVEN_K = K % BLOCK_K == 0\n ADD_INPUTS = add_inputs\n CAST_TYPE = False\n if inputs.dtype == torch.float32 and lora_b_weights.dtype in [\n torch.float16,\n torch.bfloat16,\n ]:\n CAST_TYPE = True\n grid = (\n triton.cdiv(max_seq_length, BLOCK_M) * triton.cdiv(N, BLOCK_N),\n batches,\n )\n _sgmv_expand_slice_kernel[grid](\n inputs,\n lora_b_weights,\n output_tensor,\n N,\n K,\n b_seq_start_loc,\n seq_len_tensor,\n lora_indices_tensor,\n inputs.stride(0),\n inputs.stride(1),\n lora_b_weights.stride(0),\n lora_b_weights.stride(1),\n lora_b_weights.stride(2),\n output_tensor.stride(0),\n output_tensor.stride(1),\n slice_offset,\n BLOCK_M,\n BLOCK_N,\n BLOCK_K,\n EVEN_K,\n ADD_INPUTS,\n CAST_TYPE,\n )\n return\n", - "description_1": "Use triton language to implement a kernel function '_sgmv_expand_slice_kernel' with 23 parameters for matrix operations with LoRA weights, and a wrapper function '_sgmv_expand_slice' with 11 parameters to prepare and launch the kernel.", - "description_2": "Use triton language to create a kernel for matrix operations with LoRA weights and a wrapper to manage inputs and launch the kernel.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef _sgmv_shrink_kernel(\n input_ptr,\n lora_ptr,\n out_ptr,\n N,\n K,\n b_seq_start_loc,\n seq_lens,\n lora_indices,\n scaling,\n xm_stride, # hidden_size\n xk_stride, # 1\n l0_stride, # hidden_size*max_rank\n lora_k_stride,\n lora_n_stride,\n cm_stride,\n cn_stride,\n BLOCK_M: tl.constexpr,\n BLOCK_N: tl.constexpr,\n BLOCK_K: tl.constexpr,\n EVEN_K: tl.constexpr,\n SPLIT_K: tl.constexpr,\n):\n \"\"\"\n The sgmv's shrink triton kernel is based on GroupGEMM+SPLIT-K.\n The GEMM of Multi-LoRA can be considered as GroupGEMM. Additionally,\n introducing SPLIT-K can improve performance\n \"\"\"\n pid = tl.program_id(axis=0)\n pid_sk = tl.program_id(axis=1)\n cur_batch = tl.program_id(axis=2)\n cta_n_num = tl.cdiv(N, BLOCK_N)\n pid_m = pid // cta_n_num\n pid_n = pid % cta_n_num\n\n M = tl.load(seq_lens + cur_batch)\n if pid_m * BLOCK_M > M:\n return\n lora_index = tl.load(lora_indices + cur_batch)\n if lora_index == -1:\n return\n cur_seq_start = tl.load(b_seq_start_loc + cur_batch)\n offset_m = tl.arange(0, BLOCK_M) + pid_m * BLOCK_M\n offset_n = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N\n offset_k = pid_sk * BLOCK_K + tl.arange(0, BLOCK_K)\n\n ram = tl.max_contiguous(tl.multiple_of(offset_m % M, BLOCK_M), BLOCK_M)\n rbn = tl.max_contiguous(tl.multiple_of(offset_n % N, BLOCK_N), BLOCK_N)\n\n a_ptr = (input_ptr + cur_seq_start * xm_stride + ram[:, None] * xm_stride +\n offset_k[None, :] * xk_stride)\n b_ptr = (lora_ptr + l0_stride * lora_index + rbn[None, :] * lora_k_stride +\n offset_k[:, None] * lora_n_stride)\n\n accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)\n for k in range(0, tl.cdiv(K, BLOCK_K * SPLIT_K)):\n if EVEN_K:\n tiled_a = tl.load(a_ptr)\n tiled_b = tl.load(b_ptr)\n else:\n k_remaining = K - k * (BLOCK_K * SPLIT_K)\n tiled_a = tl.load(a_ptr,\n mask=offset_k[None, :] < k_remaining,\n other=0.0)\n tiled_b = tl.load(b_ptr,\n mask=offset_k[:, None] < k_remaining,\n other=0.0)\n accumulator += tl.dot(tiled_a, tiled_b)\n\n a_ptr += BLOCK_K * SPLIT_K * xk_stride\n b_ptr += BLOCK_K * SPLIT_K * lora_n_stride\n offset_cm = cur_seq_start + tl.arange(0, BLOCK_M) + pid_m * BLOCK_M\n\n offset_cn = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N\n c_ptr = (out_ptr + offset_cm[:, None] * cm_stride +\n offset_cn[None, :] * cn_stride)\n c_mask = (offset_cm[:, None] <\n (cur_seq_start + M)) & (offset_cn[None, :] < N)\n accumulator *= scaling\n # handles write-back with reduction-splitting\n if SPLIT_K == 1:\n tl.store(c_ptr, accumulator, mask=c_mask)\n else:\n tl.atomic_add(c_ptr, accumulator, mask=c_mask)\n\n\n@torch.inference_mode()\ndef _sgmv_shrink(\n inputs: torch.Tensor,\n lora_a_weights: torch.Tensor,\n output_tensor: torch.Tensor,\n b_seq_start_loc: torch.Tensor,\n seq_len_tensor: torch.Tensor,\n lora_indices_tensor: torch.Tensor,\n batches: int,\n max_seq_length: int,\n scaling: float,\n) -> None:\n \"\"\"\n\n Args:\n inputs (torch.Tensor): input tensor\n lora_a_weights (torch.Tensor): lora'a weight\n output_tensor (torch.Tensor): output tensor\n b_seq_start_loc (torch.Tensor): (batch_size,). The cumulative\n sequence lengths of the sequences in the batch, used to index\n into sequence. E.g.,if the sequence length is [4, 6], it is\n [0, 4].\n seq_len_tensor (torch.Tensor): (batch_size,). record the sequence\n length of the sequences in the batch\n lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index\n corresponding to each batch. An index of -1 means no lora should be\n applied.\n batches (int): batch size\n max_seq_length (int): The max sequence lengths of the sequences\n in the batch\n scaling (float): Scaling factor.\n \"\"\"\n assert inputs.dtype == lora_a_weights.dtype\n assert inputs.dtype in [torch.float16, torch.bfloat16]\n assert lora_a_weights.dtype in [\n torch.float16,\n torch.bfloat16,\n ]\n assert inputs.size(1) == lora_a_weights.size(-1)\n assert b_seq_start_loc.size(0) == batches\n assert lora_indices_tensor.size(0) == batches\n assert inputs.is_contiguous()\n\n if lora_a_weights.ndim == 4: # shape:(lora_num,1,rank, size)\n assert lora_a_weights.size(1) == 1\n lora_a_weights = lora_a_weights.squeeze(dim=1)\n else:\n assert lora_a_weights.ndim == 3 # shape:(lora_num,rank, size)\n assert lora_a_weights.is_contiguous()\n assert output_tensor.is_contiguous()\n # TODO tuning this config\n N, K = lora_a_weights.shape[-2:] # K=hidden_size,N=rank\n BLOCK_M = 32\n BLOCK_N = 16\n BLOCK_K = 32\n SPLIT_K = 8\n EVEN_K = K % (BLOCK_K * SPLIT_K) == 0\n grid = (\n triton.cdiv(max_seq_length, BLOCK_M) * triton.cdiv(N, BLOCK_N),\n SPLIT_K,\n batches,\n )\n\n _sgmv_shrink_kernel[grid](\n inputs,\n lora_a_weights,\n output_tensor,\n N,\n K,\n b_seq_start_loc,\n seq_len_tensor,\n lora_indices_tensor,\n scaling,\n inputs.stride(0),\n inputs.stride(1),\n lora_a_weights.stride(0),\n lora_a_weights.stride(1),\n lora_a_weights.stride(2),\n output_tensor.stride(0),\n output_tensor.stride(1),\n BLOCK_M,\n BLOCK_N,\n BLOCK_K,\n EVEN_K,\n SPLIT_K,\n )\n return\n\n", - "description_1": "Use triton language to implement a kernel '_sgmv_shrink_kernel' with 24 parameters, performing a group matrix multiplication with split-K optimization for multiple LoRA weights. The function '_sgmv_shrink' with 9 parameters prepares and invokes this kernel for tensor operations using Triton's grid strategy.", - "description_2": "Use triton language to implement a matrix multiplication kernel for handling multiple LoRA weights with split-K optimization and provide a calling function to execute it efficiently.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\nfrom typing import Optional, Dict, Any, Tuple\n\n@triton.jit\ndef fused_moe_kernel(\n a_ptr, b_ptr, c_ptr, a_scale_ptr, b_scale_ptr, topk_weights_ptr,\n sorted_token_ids_ptr, expert_ids_ptr, num_tokens_post_padded_ptr,\n N, K, EM, num_valid_tokens, stride_am, stride_ak, stride_be, stride_bk,\n stride_bn, stride_cm, stride_cn, stride_bse, stride_bsn,\n BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr,\n BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr,\n MUL_ROUTED_WEIGHT: tl.constexpr, top_k: tl.constexpr,\n compute_type: tl.constexpr, use_fp8_w8a8: tl.constexpr,\n use_int8_w8a16: tl.constexpr):\n \"\"\"\n Implements the fused computation for a Mixture of Experts (MOE) using\n token and expert matrices.\n \"\"\"\n pid = tl.program_id(axis=0)\n num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M)\n num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)\n num_pid_in_group = GROUP_SIZE_M * num_pid_n\n group_id = pid // num_pid_in_group\n first_pid_m = group_id * GROUP_SIZE_M\n group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)\n pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)\n pid_n = (pid % num_pid_in_group) // group_size_m\n\n num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr)\n if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded:\n return\n offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_token = tl.load(sorted_token_ids_ptr + offs_token_id)\n token_mask = offs_token < num_valid_tokens\n\n offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N\n offs_k = tl.arange(0, BLOCK_SIZE_K)\n a_ptrs = a_ptr + (offs_token[:, None] // top_k * stride_am +\n offs_k[None, :] * stride_ak)\n\n off_experts = tl.load(expert_ids_ptr + pid_m)\n b_ptrs = b_ptr + off_experts * stride_be + (offs_k[:, None] * stride_bk +\n offs_bn[None, :] * stride_bn)\n if use_int8_w8a16:\n b_scale_ptrs = b_scale_ptr + off_experts * stride_bse + offs_bn[\n None, :] * stride_bsn\n b_scale = tl.load(b_scale_ptrs)\n\n if use_fp8_w8a8:\n a_scale = tl.load(a_scale_ptr)\n b_scale = tl.load(b_scale_ptr + off_experts)\n\n accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n\n for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):\n a = tl.load(a_ptrs,\n mask=token_mask[:, None] &\n (offs_k[None, :] < K - k * BLOCK_SIZE_K),\n other=0.0)\n b = tl.load(b_ptrs,\n mask=offs_k[:, None] < K - k * BLOCK_SIZE_K,\n other=0.0)\n if use_int8_w8a16:\n accumulator = tl.dot(a, b.to(compute_type), acc=accumulator)\n elif use_fp8_w8a8:\n accumulator = tl.dot(a, b, acc=accumulator)\n else:\n accumulator += tl.dot(a, b)\n a_ptrs += BLOCK_SIZE_K * stride_ak\n b_ptrs += BLOCK_SIZE_K * stride_bk\n\n if MUL_ROUTED_WEIGHT:\n moe_weight = tl.load(topk_weights_ptr + offs_token,\n mask=token_mask,\n other=0)\n accumulator = accumulator * moe_weight[:, None]\n if use_int8_w8a16:\n accumulator = (accumulator * b_scale).to(compute_type)\n elif use_fp8_w8a8:\n accumulator = (accumulator * a_scale * b_scale).to(compute_type)\n else:\n accumulator = accumulator.to(compute_type)\n\n offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[\n None, :]\n c_mask = token_mask[:, None] & (offs_cn[None, :] < N)\n tl.store(c_ptrs, accumulator, mask=c_mask)\n\n\ndef invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,\n A_scale: Optional[torch.Tensor],\n B_scale: Optional[torch.Tensor],\n topk_weights: torch.Tensor, topk_ids: torch.Tensor,\n sorted_token_ids: torch.Tensor,\n expert_ids: torch.Tensor,\n num_tokens_post_padded: torch.Tensor,\n mul_routed_weight: bool, top_k: int,\n config: Dict[str, Any], compute_type: tl.dtype,\n use_fp8_w8a8: bool, use_int8_w8a16: bool) -> None:\n assert topk_weights.stride(1) == 1\n assert sorted_token_ids.stride(0) == 1\n\n if use_fp8_w8a8:\n A, A_scale = ops.scaled_fp8_quant(A, A_scale)\n assert B_scale is not None\n elif use_int8_w8a16:\n assert B_scale is not None\n else:\n assert A_scale is None\n assert B_scale is None\n\n grid = lambda META: (triton.cdiv(sorted_token_ids.shape[0], META[\n 'BLOCK_SIZE_M']) * triton.cdiv(B.shape[1], META['BLOCK_SIZE_N']), )\n\n fused_moe_kernel[grid](\n A,\n B,\n C,\n A_scale,\n B_scale,\n topk_weights,\n sorted_token_ids,\n expert_ids,\n num_tokens_post_padded,\n B.shape[1],\n B.shape[2],\n sorted_token_ids.shape[0],\n topk_ids.numel(),\n A.stride(0),\n A.stride(1),\n B.stride(0),\n B.stride(2),\n B.stride(1),\n C.stride(1),\n C.stride(2),\n B_scale.stride(0) if B_scale is not None and use_int8_w8a16 else 0,\n B_scale.stride(1) if B_scale is not None and use_int8_w8a16 else 0,\n MUL_ROUTED_WEIGHT=mul_routed_weight,\n top_k=top_k,\n compute_type=compute_type,\n use_fp8_w8a8=use_fp8_w8a8,\n use_int8_w8a16=use_int8_w8a16,\n **config,\n )\n", - "description_1": "Use triton language to implement a fused Mixture of Experts (MoE) kernel. The kernel takes pointers to input matrices, scales, and other parameters to perform block matrix multiplication. It computes the product of a token matrix and an expert matrix, using parameters like block sizes and compute types. The kernel is invoked with a function that sets up the grid and passes the necessary parameters.", - "description_2": "Use triton language to create a kernel for block matrix multiplication in a Mixture of Experts model, and implement a function to invoke this kernel with appropriate parameters.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\ndef seeded_uniform(\n *size,\n seeds: torch.Tensor,\n out: Optional[torch.Tensor] = None,\n dtype: Optional[torch.dtype] = None,\n device: Optional[Union[torch.device, str]] = None,\n pin_memory: Optional[bool] = False,\n) -> torch.Tensor:\n \"\"\"Similar to torch.rand, but allows for seeds to be set per row.\"\"\"\n n_dims = len(size)\n\n if n_dims > 3:\n raise ValueError(\"seeded_uniform only supports up to 3D tensors\")\n\n if out is None:\n out = torch.empty(*size,\n dtype=dtype,\n device=device,\n pin_memory=pin_memory)\n elif out.shape != size:\n raise ValueError(\"shape of out and size must be the same\")\n\n if n_dims == 3:\n n_rows, n_3d, n_cols = out.shape\n stride_row = out.stride(0)\n stride_3d = out.stride(1)\n elif n_dims == 2:\n n_rows, n_cols = out.shape\n n_3d = 1\n stride_row = out.stride(0)\n stride_3d = 1\n else:\n n_cols = out.shape[0]\n n_rows = 1\n n_3d = 1\n stride_row = 1\n stride_3d = 1\n\n if seeds.ndim != 1:\n raise ValueError(\"seeds must be a 1D tensor\")\n\n if seeds.numel() != n_rows:\n raise ValueError(\n \"seeds must have the same number of elements as out has rows\")\n\n full_block_size = triton.next_power_of_2(n_cols)\n philox_block_size = max(full_block_size // 4, 1)\n n_slices = full_block_size // philox_block_size\n num_warps = 4\n if philox_block_size >= 8192:\n num_warps = 32\n elif philox_block_size >= 4096:\n num_warps = 16\n elif philox_block_size >= 2048:\n num_warps = 8\n\n _seeded_uniform_triton[(n_rows, n_3d)](\n out,\n seeds,\n stride_row,\n stride_3d,\n seeds.stride(0),\n n_rows,\n n_3d,\n n_cols,\n n_slices=n_slices,\n num_warps=num_warps,\n block_size=philox_block_size,\n )\n return out\n\n@triton.jit\ndef _seeded_uniform_triton(\n out_ptr: torch.Tensor,\n seed_ptr: torch.Tensor,\n out_row_stride: int,\n out_3d_stride: int,\n seed_row_stride: int,\n n_rows: int,\n n_3d: int,\n n_cols: int,\n n_slices: tl.constexpr,\n block_size: tl.constexpr,\n):\n \"\"\"\n Generate a random float32 number in [0, 1) for each element in the output\n tensor. The random numbers in a row generated using the seed for that row.\n \"\"\"\n tl.static_assert(n_slices > 0 and n_slices <= 4, \"0 < n_slices <= 4\")\n\n row_idx = tl.program_id(axis=0)\n three_d_idx = tl.program_id(axis=1)\n\n philox_offsets = tl.arange(0, block_size)\n seed = tl.load(seed_ptr + row_idx * seed_row_stride)\n if three_d_idx > 0:\n seed ^= three_d_idx\n out1, out2, out3, out4 = tl.rand4x(seed, philox_offsets)\n\n output_row_start_ptr = (out_ptr + row_idx * out_row_stride +\n three_d_idx * out_3d_stride)\n out1_offsets = philox_offsets\n tl.store(output_row_start_ptr + out1_offsets,\n out1,\n mask=out1_offsets < n_cols)\n if n_slices > 1:\n out2_offsets = tl.arange(block_size, block_size * 2)\n tl.store(output_row_start_ptr + out2_offsets,\n out2,\n mask=out2_offsets < n_cols)\n if n_slices > 2:\n out3_offsets = tl.arange(block_size * 2, block_size * 3)\n tl.store(output_row_start_ptr + out3_offsets,\n out3,\n mask=out3_offsets < n_cols)\n if n_slices > 3:\n out4_offsets = tl.arange(block_size * 3, block_size * 4)\n tl.store(output_row_start_ptr + out4_offsets,\n out4,\n mask=out4_offsets < n_cols)\n", - "description_1": "Use triton language to implement a seeded uniform random number generator. The function `seeded_uniform` takes parameters: size (dimensions of the output tensor), seeds (1D tensor for per-row seeds), out (optional output tensor), dtype (optional data type), device (optional device), and pin_memory (optional boolean for pinned memory). It calculates the necessary strides and block sizes, then calls the Triton kernel `_seeded_uniform_triton`. The kernel `_seeded_uniform_triton` takes parameters: out_ptr (output tensor), seed_ptr (seed tensor), out_row_stride (stride between rows), out_3d_stride (stride between 3D slices), seed_row_stride (stride between seed rows), n_rows (number of rows), n_3d (size of second dimension if 3D), n_cols (number of columns), n_slices (number of philox outputs), and block_size (block size for random number generation). It generates random numbers using the philox PRNG and stores them in the output tensor.", - "description_2": "Use triton language to create a random number generator that generates random numbers for each element in a tensor using per-row seeds. The generator should handle up to 3D tensors and use the philox PRNG for efficient random number generation.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n_EPS: tl.constexpr = 1e-6\n\n\n@triton.jit\ndef _uniform_to_exponential(uniform_noise):\n \"\"\"Convert uniform samples to exponential samples.\"\"\"\n lb = tl.full(uniform_noise.shape, _EPS, uniform_noise.dtype)\n uniform_noise = tl.maximum(uniform_noise, lb)\n exponential_noise = -tl.log(uniform_noise)\n return exponential_noise\n\n\n@triton.jit\ndef _sample_triton(\n sample_indices_ptr: torch.Tensor, output_ptr: torch.Tensor,\n output_logprobs_ptr: torch.Tensor,\n output_modified_probs_ptr: torch.Tensor, probs_ptr: torch.Tensor,\n logprobs_ptr: torch.Tensor, seeds_ptr: torch.Tensor,\n uniform_noise_ptr: torch.Tensor, output_row_stride: int,\n probs_row_stride: int, uniform_noise_row_stride: int,\n uniform_noise_best_stride: int, n_samples: int, n_cols: int,\n n_best: int, block_size: tl.constexpr,\n modify_greedy_probs: tl.constexpr, save_logprobs: tl.constexpr,\n save_modified_probs: tl.constexpr):\n # The rows are independent, so we parallelize across those\n sample_idx = tl.program_id(0)\n best_idx = tl.program_id(1)\n\n # Load the row index from DRAM\n row_idx = tl.load(sample_indices_ptr + sample_idx)\n seed = tl.load(seeds_ptr + sample_idx)\n uses_random_sampling = seed != 0\n\n # The stride represents how much we need to increase the\n # pointer to advance 1 row\n row_start_ptr = probs_ptr + row_idx * probs_row_stride\n\n # The block size is the next power of two greater than n_cols,\n # so we can fit each row in a single block\n col_offsets = tl.arange(0, block_size)\n\n # Load the row into SRAM, using a mask since block_size may be > than n_cols\n row = tl.load(row_start_ptr + col_offsets,\n mask=col_offsets < n_cols,\n other=float(\"-inf\"))\n\n if uses_random_sampling:\n uniform_noise_start_ptr = (uniform_noise_ptr +\n sample_idx * uniform_noise_row_stride +\n best_idx * uniform_noise_best_stride)\n uniform_noise = tl.load(uniform_noise_start_ptr + col_offsets,\n mask=col_offsets < n_cols,\n other=0.5)\n exponential_noise = _uniform_to_exponential(uniform_noise)\n row /= exponential_noise\n\n sampled_value, sampled_token = tl.max(row, axis=0, return_indices=True)\n # clamp sampled token to n_cols - 1\n if sampled_token >= n_cols:\n sampled_token = n_cols - 1\n # Write back output to DRAM\n output_row_start_ptr = (output_ptr + sample_idx * output_row_stride +\n best_idx)\n tl.store(output_row_start_ptr, sampled_token)\n\n if modify_greedy_probs:\n if not uses_random_sampling:\n # Set the probability of the sampled token to 1, all other\n # tokens to zero.\n row = tl.where(col_offsets == sampled_token, 1.0, 0.0)\n tl.store(row_start_ptr + col_offsets,\n row,\n mask=col_offsets < n_cols)\n\n if save_modified_probs:\n output_row_start_ptr = (output_modified_probs_ptr +\n sample_idx * output_row_stride + best_idx)\n tl.store(output_row_start_ptr, sampled_value)\n\n if save_logprobs:\n sampled_logprob = tl.load(logprobs_ptr + row_idx * probs_row_stride +\n sampled_token)\n output_row_start_ptr = (output_logprobs_ptr +\n sample_idx * output_row_stride + best_idx)\n tl.store(output_row_start_ptr, sampled_logprob)\n", - "description_1": "Use triton language to define two kernels: _uniform_to_exponential, which converts uniform noise to exponential noise, and _sample_triton, which samples tokens from a probability matrix considering uniform noise and optional modifications to greedy probabilities, storing results in output tensors.", - "description_2": "Use triton language to create kernels for noise conversion and token sampling from a probability distribution.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n# Define a custom tanh function using Triton\n@triton.jit\ndef tanh(x):\n # Tanh is just a scaled sigmoid\n return 2 * tl.sigmoid(2 * x) - 1\n\n# Stage 1 of forward kernel\n@triton.jit\ndef _fwd_kernel_stage1(\n Q,\n K_Buffer,\n sm_scale,\n Req_to_tokens,\n B_req_idx,\n B_Start_Loc,\n B_Seqlen,\n Att_Out,\n stride_req_to_tokens_b,\n stride_qbs,\n stride_qh,\n stride_buf_kbs,\n stride_buf_kh,\n att_stride_h,\n kv_group_num: tl.constexpr,\n BLOCK_DMODEL: tl.constexpr,\n BLOCK_N: tl.constexpr,\n logit_cap: tl.constexpr,\n):\n cur_batch = tl.program_id(0)\n cur_head = tl.program_id(1)\n start_n = tl.program_id(2)\n\n cur_kv_head = cur_head // kv_group_num\n\n offs_d = tl.arange(0, BLOCK_DMODEL)\n cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)\n cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)\n cur_batch_req_idx = tl.load(B_req_idx + cur_batch)\n\n cur_batch_start_index = 0\n cur_batch_end_index = cur_batch_seq_len\n\n off_q = cur_batch * stride_qbs + cur_head * stride_qh + offs_d\n\n offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N)\n\n block_stard_index = start_n * BLOCK_N\n block_mask = tl.where(block_stard_index < cur_batch_seq_len, 1, 0)\n\n for start_mark in range(0, block_mask, 1):\n q = tl.load(Q + off_q + start_mark).to(REDUCE_TRITON_TYPE)\n offs_n_new = cur_batch_start_index + offs_n\n k_loc = tl.load(\n Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + offs_n_new,\n mask=offs_n_new < cur_batch_end_index,\n other=0,\n )\n offs_buf_k = (\n k_loc[:, None] * stride_buf_kbs\n + cur_kv_head * stride_buf_kh\n + offs_d[None, :]\n )\n k = tl.load(\n K_Buffer + offs_buf_k,\n mask=offs_n_new[:, None] < cur_batch_end_index,\n other=0.0,\n ).to(REDUCE_TRITON_TYPE)\n att_value = tl.sum(q[None, :] * k, 1)\n att_value *= sm_scale\n\n if logit_cap > 0:\n att_value = logit_cap * tanh(att_value / logit_cap)\n\n off_o = cur_head * att_stride_h + (cur_batch_in_all_start_index + offs_n)\n tl.store(Att_Out + off_o, att_value, mask=offs_n_new < cur_batch_end_index)\n\n# Stage 2 of forward kernel\n@triton.jit\ndef _fwd_kernel_stage2(\n Logics,\n V_Buffer,\n Out,\n Req_to_tokens,\n B_req_idx,\n B_Start_Loc,\n B_Seqlen,\n stride_logic_h,\n stride_buf_vbs,\n stride_buf_vh,\n stride_obs,\n stride_oh,\n stride_req_to_token_b,\n kv_group_num: tl.constexpr,\n BLOCK_DMODEL: tl.constexpr,\n BLOCK_N: tl.constexpr,\n):\n cur_batch = tl.program_id(0)\n cur_head = tl.program_id(1)\n\n cur_kv_head = cur_head // kv_group_num\n\n cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)\n cur_batch_start_loc = tl.load(B_Start_Loc + cur_batch)\n cur_batch_req_idx = tl.load(B_req_idx + cur_batch)\n\n offs_n = tl.arange(0, BLOCK_N)\n offs_d = tl.arange(0, BLOCK_DMODEL)\n\n offs_buf_v = cur_kv_head * stride_buf_vh + offs_d[None, :]\n v_ptrs = V_Buffer + offs_buf_v\n\n e_max = float(\"-inf\")\n e_sum = 0.0\n acc = tl.zeros([BLOCK_DMODEL], dtype=tl.float32)\n\n for start_n in range(0, cur_batch_seq_len, BLOCK_N):\n start_n = tl.multiple_of(start_n, BLOCK_N)\n v_index = tl.load(\n Req_to_tokens\n + cur_batch_req_idx * stride_req_to_token_b\n + (start_n + offs_n),\n mask=(start_n + offs_n) < cur_batch_seq_len,\n other=0,\n )\n\n qk = tl.load(\n Logics\n + cur_head * stride_logic_h\n + (cur_batch_start_loc + start_n + offs_n),\n mask=start_n + offs_n < cur_batch_seq_len,\n other=float(\"-inf\"),\n )\n\n n_e_max = tl.maximum(tl.max(qk, 0), e_max)\n old_scale = tl.exp(e_max - n_e_max)\n p = tl.exp(qk - n_e_max)\n e_sum = e_sum * old_scale + tl.sum(p, 0)\n v = tl.load(v_ptrs + v_index[:, None] * stride_buf_vbs)\n acc = acc * old_scale + tl.sum(p[:, None] * v, 0)\n e_max = n_e_max\n\n acc = acc / e_sum\n off_o = cur_batch * stride_obs + cur_head * stride_oh + offs_d\n out_ptrs = Out + off_o\n tl.store(out_ptrs, acc)\n\n# Function to perform forward attention\ndef decode_attention_fwd(\n q,\n k_buffer,\n v_buffer,\n o,\n req_to_token,\n b_req_idx,\n b_start_loc,\n b_seq_len,\n max_len_in_batch,\n total_num_tokens,\n sm_scale,\n logit_cap=-1,\n att_m=None,\n):\n if att_m is None:\n att_m = torch.empty(\n (q.shape[-2], total_num_tokens), dtype=REDUCE_TORCH_TYPE, device=\"cuda\"\n )\n\n kv_group_num = q.shape[1] // v_buffer.shape[1]\n\n if kv_group_num == 1:\n # MHA\n _decode_att_m_fwd(\n q,\n k_buffer,\n att_m,\n req_to_token,\n b_req_idx,\n b_start_loc,\n b_seq_len,\n max_len_in_batch,\n sm_scale,\n logit_cap,\n )\n _decode_softmax_reducev_fwd(\n att_m,\n v_buffer,\n o,\n req_to_token,\n b_req_idx,\n b_start_loc,\n b_seq_len,\n )\n else:\n # GQA/MQA/MLA\n _decode_grouped_att_m_fwd(\n q,\n k_buffer,\n att_m,\n req_to_token,\n b_req_idx,\n b_start_loc,\n b_seq_len,\n max_len_in_batch,\n sm_scale,\n logit_cap,\n )\n _decode_grouped_softmax_reducev_fwd(\n att_m,\n v_buffer,\n o,\n req_to_token,\n b_req_idx,\n b_start_loc,\n b_seq_len,\n )\n", - "description_1": "Use triton language to implement an efficient memory attention mechanism for decoding. It includes three kernel functions: tanh function, forward kernel stage 1, and forward kernel stage 2. Tanh function has one parameter x (the input to the function). The forward kernel stage 1 has 15 parameters: Q (query), K_Buffer (key buffer), sm_scale (scaling factor), Req_to_tokens (request to token mapping), B_req_idx (request index), B_Start_Loc (start location), B_Seqlen (sequence length), Att_Out (attention output), stride_req_to_tokens_b (stride for request to token), stride_qbs (stride for query), stride_qh (stride for head), stride_buf_kbs (stride for buffer key), stride_buf_kh (stride for buffer head), att_stride_h (stride for attention), kv_group_num (group number, constant), BLOCK_DMODEL (block model size, constant), BLOCK_N (block size, constant), logit_cap (logit capacity, constant). The forward kernel stage 2 has 13 parameters: Logics (logical operation result), V_Buffer (value buffer), Out (output), Req_to_tokens (request to token mapping), B_req_idx (request index), B_Start_Loc (start location), B_Seqlen (sequence length), stride_logic_h (stride for logic), stride_buf_vbs (stride for value buffer), stride_buf_vh (stride for value head), stride_obs (stride for output), stride_oh (stride for output head), kv_group_num (group number, constant), BLOCK_DMODEL (block model size, constant), BLOCK_N (block size, constant). A decode_attention_fwd function is also implemented to drive these kernels with necessary parameters.", - "description_2": "Use triton language to create efficient memory attention kernels for decoding tasks, with forward processing handled in two stages.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n\n@triton.jit\ndef tanh(x):\n # Tanh is just a scaled sigmoid\n return 2 * tl.sigmoid(2 * x) - 1\n\n\n@triton.jit\ndef _fwd_kernel(\n Q_Extend,\n K_Extend,\n V_Extend,\n O_Extend,\n K_Buffer,\n V_Buffer,\n Req_to_tokens,\n B_req_idx,\n B_Seq_Len,\n B_Start_Loc_Extend,\n B_Seq_Len_Extend,\n sm_scale,\n kv_group_num,\n stride_qbs,\n stride_qh,\n stride_kbs,\n stride_kh,\n stride_vbs,\n stride_vh,\n stride_obs,\n stride_oh,\n stride_buf_kbs,\n stride_buf_kh,\n stride_buf_vbs,\n stride_buf_vh,\n stride_req_to_tokens_b,\n BLOCK_DMODEL: tl.constexpr,\n BLOCK_DPE: tl.constexpr,\n BLOCK_DV: tl.constexpr,\n BLOCK_M: tl.constexpr,\n BLOCK_N: tl.constexpr,\n logit_cap: tl.constexpr,\n):\n cur_seq = tl.program_id(0)\n cur_head = tl.program_id(1)\n cur_block_m = tl.program_id(2)\n cur_kv_head = cur_head // kv_group_num\n\n cur_seq_len = tl.load(B_Seq_Len + cur_seq)\n cur_seq_len_extend = tl.load(B_Seq_Len_Extend + cur_seq)\n cur_seq_len_prefix = cur_seq_len - cur_seq_len_extend\n\n cur_seq_prefix_start_in_loc = 0\n cur_seq_extend_start_contiguous = tl.load(B_Start_Loc_Extend + cur_seq)\n cur_batch_req_idx = tl.load(B_req_idx + cur_seq)\n\n offs_d = tl.arange(0, BLOCK_DMODEL)\n offs_dv = tl.arange(0, BLOCK_DV)\n offs_m = tl.arange(0, BLOCK_M)\n mask_m = (cur_block_m * BLOCK_M + offs_m) < cur_seq_len_extend\n\n offs_q = (\n (cur_seq_extend_start_contiguous + cur_block_m * BLOCK_M + offs_m[:, None])\n * stride_qbs\n + cur_head * stride_qh\n + offs_d[None, :]\n )\n q = tl.load(Q_Extend + offs_q, mask=mask_m[:, None], other=0.0)\n\n if BLOCK_DPE > 0:\n offs_dpe = BLOCK_DMODEL + tl.arange(0, BLOCK_DPE)\n offs_qpe = (\n (cur_seq_extend_start_contiguous + cur_block_m * BLOCK_M + offs_m[:, None])\n * stride_qbs\n + cur_head * stride_qh\n + offs_dpe[None, :]\n )\n qpe = tl.load(Q_Extend + offs_qpe, mask=mask_m[:, None], other=0.0)\n\n offs_n = tl.arange(0, BLOCK_N)\n\n acc = tl.zeros([BLOCK_M, BLOCK_DV], dtype=tl.float32)\n deno = tl.zeros([BLOCK_M], dtype=tl.float32)\n e_max = tl.zeros([BLOCK_M], dtype=tl.float32) - float(\"inf\")\n\n for start_n in range(0, cur_seq_len_prefix, BLOCK_N):\n start_n = tl.multiple_of(start_n, BLOCK_N)\n mask_n = (start_n + offs_n) < cur_seq_len_prefix\n offs_b_loc_prefix = cur_batch_req_idx * stride_req_to_tokens_b + (\n cur_seq_prefix_start_in_loc + start_n + offs_n\n )\n offs_kv_loc = tl.load(Req_to_tokens + offs_b_loc_prefix, mask=mask_n, other=0)\n\n offs_buf_k = (\n offs_kv_loc[None, :] * stride_buf_kbs\n + cur_kv_head * stride_buf_kh\n + offs_d[:, None]\n )\n k = tl.load(K_Buffer + offs_buf_k, mask=mask_n[None, :], other=0.0)\n\n qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)\n qk += tl.dot(q, k)\n if BLOCK_DPE > 0:\n offs_kpe = (\n offs_kv_loc[None, :] * stride_buf_kbs\n + cur_kv_head * stride_buf_kh\n + offs_dpe[:, None]\n )\n kpe = tl.load(\n K_Buffer + offs_kpe,\n mask=mask_n[None, :],\n other=0.0,\n )\n qk += tl.dot(qpe, kpe)\n qk *= sm_scale\n\n if logit_cap > 0:\n qk = logit_cap * tanh(qk / logit_cap)\n\n qk = tl.where(mask_m[:, None] & mask_n[None, :], qk, float(\"-inf\"))\n\n n_e_max = tl.maximum(tl.max(qk, 1), e_max)\n re_scale = tl.exp(e_max - n_e_max)\n p = tl.exp(qk - n_e_max[:, None])\n deno = deno * re_scale + tl.sum(p, 1)\n\n offs_buf_v = (\n offs_kv_loc[:, None] * stride_buf_vbs\n + cur_kv_head * stride_buf_vh\n + offs_dv[None, :]\n )\n v = tl.load(V_Buffer + offs_buf_v, mask=mask_n[:, None], other=0.0)\n p = p.to(v.dtype)\n acc = acc * re_scale[:, None] + tl.dot(p, v)\n\n e_max = n_e_max\n\n cur_block_m_end = tl.minimum(cur_seq_len_extend, (cur_block_m + 1) * BLOCK_M)\n for start_n in range(0, cur_block_m_end, BLOCK_N):\n start_n = tl.multiple_of(start_n, BLOCK_N)\n mask_n = (start_n + offs_n) < cur_block_m_end\n\n offs_k = (\n (cur_seq_extend_start_contiguous + start_n + offs_n[None, :]) * stride_kbs\n + cur_kv_head * stride_kh\n + offs_d[:, None]\n )\n k = tl.load(K_Extend + offs_k, mask=mask_n[None, :], other=0.0)\n\n qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)\n qk += tl.dot(q, k)\n\n if BLOCK_DPE > 0:\n offs_kpe = (\n (cur_seq_extend_start_contiguous + start_n + offs_n[None, :])\n * stride_kbs\n + cur_kv_head * stride_kh\n + offs_dpe[:, None]\n )\n kpe = tl.load(\n K_Extend + offs_kpe,\n mask=mask_n[None, :],\n other=0.0,\n )\n qk += tl.dot(qpe, kpe)\n\n qk *= sm_scale\n\n if logit_cap > 0:\n qk = logit_cap * tanh(qk / logit_cap)\n\n mask_causual = (cur_block_m * BLOCK_M + offs_m[:, None]) >= (\n start_n + offs_n[None, :]\n )\n mask_causual &= mask_m[:, None] & mask_n[None, :]\n qk = tl.where(mask_causual, qk, float(\"-inf\"))\n\n n_e_max = tl.maximum(tl.max(qk, 1), e_max)\n re_scale = tl.exp(e_max - n_e_max)\n p = tl.exp(qk - n_e_max[:, None])\n deno = deno * re_scale + tl.sum(p, 1)\n\n offs_v = (\n (cur_seq_extend_start_contiguous + start_n + offs_n[:, None]) * stride_vbs\n + cur_kv_head * stride_vh\n + offs_dv[None, :]\n )\n v = tl.load(V_Extend + offs_v, mask=mask_n[:, None], other=0.0)\n p = p.to(v.dtype)\n acc = acc * re_scale[:, None] + tl.dot(p, v)\n\n e_max = n_e_max\n\n offs_o = (\n (cur_seq_extend_start_contiguous + cur_block_m * BLOCK_M + offs_m[:, None])\n * stride_obs\n + cur_head * stride_oh\n + offs_dv[None, :]\n )\n tl.store(O_Extend + offs_o, acc / deno[:, None], mask=mask_m[:, None])\n\n\ndef extend_attention_fwd(\n q_extend,\n k_extend,\n v_extend,\n o_extend,\n k_buffer,\n v_buffer,\n req_to_tokens,\n b_req_idx,\n b_start_loc,\n b_seq_len,\n b_seq_len_prefix,\n b_start_loc_extend,\n b_seq_len_extend,\n max_len_in_batch,\n max_len_extend,\n sm_scale=None,\n logit_cap=-1,\n):\n \"\"\"\n q_extend, k_extend, v_extend, o_extend: contiguous tensors\n\n k_buffer, v_buffer: (prefix + extend) tensors in mem_manager\n \"\"\"\n Lq, Lk, Lv, Lo = (\n q_extend.shape[-1],\n k_extend.shape[-1],\n v_extend.shape[-1],\n o_extend.shape[-1],\n )\n\n assert Lq == Lk and Lv == Lo\n assert Lq in {16, 32, 64, 128, 256, 576}\n assert Lv in {16, 32, 64, 128, 256, 512}\n\n if Lq == 576:\n BLOCK_DMODEL = 512\n BLOCK_DPE = 64\n else:\n BLOCK_DMODEL = Lq\n BLOCK_DPE = 0\n BLOCK_DV = Lv\n\n if CUDA_CAPABILITY[0] >= 9:\n BLOCK_M, BLOCK_N = (128, 64)\n elif CUDA_CAPABILITY[0] >= 8:\n BLOCK_M, BLOCK_N = (128, 128) if Lq <= 128 else (64, 64)\n else:\n BLOCK_M, BLOCK_N = (64, 64) if Lq <= 128 else (32, 32)\n\n sm_scale = 1.0 / (Lq**0.5) if sm_scale is None else sm_scale\n batch_size, head_num = b_seq_len.shape[0], q_extend.shape[1]\n kv_group_num = q_extend.shape[1] // k_extend.shape[1]\n\n grid = (batch_size, head_num, triton.cdiv(max_len_extend, BLOCK_M))\n num_warps = 4 if Lk <= 64 else 8\n num_stages = 1\n\n _fwd_kernel[grid](\n q_extend,\n k_extend,\n v_extend,\n o_extend,\n k_buffer,\n v_buffer,\n req_to_tokens,\n b_req_idx,\n b_seq_len,\n b_start_loc_extend,\n b_seq_len_extend,\n sm_scale,\n kv_group_num,\n q_extend.stride(0),\n q_extend.stride(1),\n k_extend.stride(0),\n k_extend.stride(1),\n v_extend.stride(0),\n v_extend.stride(1),\n o_extend.stride(0),\n o_extend.stride(1),\n k_buffer.stride(0),\n k_buffer.stride(1),\n v_buffer.stride(0),\n v_buffer.stride(1),\n req_to_tokens.stride(0),\n BLOCK_DMODEL=BLOCK_DMODEL,\n BLOCK_DPE=BLOCK_DPE,\n BLOCK_DV=BLOCK_DV,\n BLOCK_M=BLOCK_M,\n BLOCK_N=BLOCK_N,\n num_warps=num_warps,\n num_stages=num_stages,\n logit_cap=logit_cap,\n )\n\n", - "description_1": "Use triton language to implement an efficient attention mechanism for forward pass with parameterized query (Q), key (K), and value (V) tensors, where BLOCK_DMODEL, BLOCK_DPE, BLOCK_DV are the dimensions of the model, positional encoding, and value vectors respectively. It also involves computation of scores with prefix and triangle parts and uses customized tensor slicing and accumulation logic to optimize memory and computation.", - "description_2": "Use triton language to create a custom forward attention kernel with tunable dimensions and perform tensor operations for efficient memory management during the calculation of attention scores and results.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\nfrom typing import Any, Dict, Optional, Tuple\n\n@triton.jit\ndef fused_moe_kernel(\n # Pointers to matrices\n a_ptr,\n b_ptr,\n c_ptr,\n a_scale_ptr,\n b_scale_ptr,\n topk_weights_ptr,\n sorted_token_ids_ptr,\n expert_ids_ptr,\n num_tokens_post_padded_ptr,\n # Matrix dimensions\n N,\n K,\n EM,\n num_valid_tokens,\n # The stride variables\n stride_am,\n stride_ak,\n stride_be,\n stride_bk,\n stride_bn,\n stride_cm,\n stride_cn,\n # Meta-parameters\n BLOCK_SIZE_M: tl.constexpr,\n BLOCK_SIZE_N: tl.constexpr,\n BLOCK_SIZE_K: tl.constexpr,\n GROUP_SIZE_M: tl.constexpr,\n MUL_ROUTED_WEIGHT: tl.constexpr,\n top_k: tl.constexpr,\n compute_type: tl.constexpr,\n use_fp8: tl.constexpr,\n):\n pid = tl.program_id(axis=0)\n num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M)\n num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)\n num_pid_in_group = GROUP_SIZE_M * num_pid_n\n group_id = pid // num_pid_in_group\n first_pid_m = group_id * GROUP_SIZE_M\n group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)\n pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)\n pid_n = (pid % num_pid_in_group) // group_size_m\n\n num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr)\n if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded:\n return\n offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_token = tl.load(sorted_token_ids_ptr + offs_token_id)\n token_mask = offs_token < num_valid_tokens\n\n offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N\n offs_k = tl.arange(0, BLOCK_SIZE_K)\n a_ptrs = a_ptr + (\n offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak\n )\n\n off_experts = tl.load(expert_ids_ptr + pid_m)\n b_ptrs = (\n b_ptr\n + off_experts * stride_be\n + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)\n )\n\n if use_fp8:\n a_scale = tl.load(a_scale_ptr)\n b_scale = tl.load(b_scale_ptr + off_experts)\n\n accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n\n for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):\n a = tl.load(\n a_ptrs,\n mask=token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K),\n other=0.0,\n )\n b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)\n if use_fp8:\n accumulator = tl.dot(a, b, acc=accumulator)\n else:\n accumulator += tl.dot(a, b)\n a_ptrs += BLOCK_SIZE_K * stride_ak\n b_ptrs += BLOCK_SIZE_K * stride_bk\n\n if MUL_ROUTED_WEIGHT:\n moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0)\n accumulator = accumulator * moe_weight[:, None]\n\n if use_fp8:\n accumulator = (accumulator * a_scale * b_scale).to(compute_type)\n else:\n accumulator = accumulator.to(compute_type)\n\n offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :]\n c_mask = token_mask[:, None] & (offs_cn[None, :] < N)\n tl.store(c_ptrs, accumulator, mask=c_mask)\n\n\ndef invoke_fused_moe_kernel(\n A: torch.Tensor,\n B: torch.Tensor,\n C: torch.Tensor,\n A_scale: Optional[torch.Tensor],\n B_scale: Optional[torch.Tensor],\n topk_weights: torch.Tensor,\n topk_ids: torch.Tensor,\n sorted_token_ids: torch.Tensor,\n expert_ids: torch.Tensor,\n num_tokens_post_padded: torch.Tensor,\n mul_routed_weight: bool,\n top_k: int,\n config: Dict[str, Any],\n compute_type: tl.dtype,\n use_fp8: bool,\n) -> None:\n assert topk_weights.stride(1) == 1\n assert sorted_token_ids.stride(0) == 1\n\n if not use_fp8:\n assert A_scale is None\n assert B_scale is None\n else:\n A, A_scale = ops.scaled_fp8_quant(A, A_scale)\n assert B_scale is not None\n\n grid = lambda META: (\n triton.cdiv(sorted_token_ids.shape[0], META[\"BLOCK_SIZE_M\"])\n * triton.cdiv(B.shape[1], META[\"BLOCK_SIZE_N\"]),\n )\n\n fused_moe_kernel[grid](\n A,\n B,\n C,\n A_scale,\n B_scale,\n topk_weights,\n sorted_token_ids,\n expert_ids,\n num_tokens_post_padded,\n B.shape[1],\n B.shape[2],\n sorted_token_ids.shape[0],\n topk_ids.numel(),\n A.stride(0),\n A.stride(1),\n B.stride(0),\n B.stride(2),\n B.stride(1),\n C.stride(1),\n C.stride(2),\n MUL_ROUTED_WEIGHT=mul_routed_weight,\n top_k=top_k,\n compute_type=compute_type,\n use_fp8=use_fp8,\n **config,\n )\n", - "description_1": "Use triton language to implement a fused Mixture of Experts (MoE) kernel, where the kernel, decorated with @triton.jit, takes pointers to input tensors A, B, C, and associated parameters like strides and dimensions, to perform block matrix multiplication based on expert assignment. This implementation efficiently handles expert-specific multiplication using token IDs, scaling factors, and padding to ensure block alignment, with additional support for fp8 arithmetic if specified. The 'invoke_fused_moe_kernel' function sets up the execution grid and manages the preparation and invocation of this kernel using TensorFlow (tl) constants and meta-parameters.", - "description_2": "Use triton language to create and execute a fused Mixture of Experts kernel that processes tokens and experts efficiently, supporting specialized arithmetic and block alignment for matrix operations.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef _fwd_kernel(\n Q,\n K,\n V,\n sm_scale,\n B_Start_Loc,\n B_Seqlen,\n Out,\n stride_qbs,\n stride_qh,\n stride_kbs,\n stride_kh,\n stride_vbs,\n stride_vh,\n stride_obs,\n stride_oh,\n kv_group_num: tl.constexpr,\n BLOCK_M: tl.constexpr,\n BLOCK_DMODEL: tl.constexpr,\n BLOCK_N: tl.constexpr,\n):\n cur_batch = tl.program_id(0)\n cur_head = tl.program_id(1)\n start_m = tl.program_id(2)\n\n cur_kv_head = cur_head // kv_group_num\n\n cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)\n cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)\n\n block_start_loc = BLOCK_M * start_m\n\n # initialize offsets\n offs_n = tl.arange(0, BLOCK_N)\n offs_d = tl.arange(0, BLOCK_DMODEL)\n offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)\n off_q = (\n (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs\n + cur_head * stride_qh\n + offs_d[None, :]\n )\n off_k = offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh + offs_d[:, None]\n off_v = offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh + offs_d[None, :]\n\n q = tl.load(Q + off_q, mask=offs_m[:, None] < cur_batch_seq_len, other=0.0)\n\n k_ptrs = K + off_k\n v_ptrs = V + off_v\n\n # initialize pointer to m and l\n m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float(\"inf\")\n l_i = tl.zeros([BLOCK_M], dtype=tl.float32)\n acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)\n\n block_mask = tl.where(block_start_loc < cur_batch_seq_len, 1, 0)\n\n for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N):\n start_n = tl.multiple_of(start_n, BLOCK_N)\n # -- compute qk ----\n k = tl.load(\n k_ptrs + (cur_batch_in_all_start_index + start_n) * stride_kbs,\n mask=(start_n + offs_n[None, :]) < cur_batch_seq_len,\n other=0.0,\n )\n # mask = tl.load(mask_ptrs + start_n, mask=start_n + offs_n < cur_batch_end_loc, other=0.0)\n\n qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)\n qk += tl.dot(q, k)\n qk *= sm_scale\n qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float(\"-inf\"))\n\n # -- compute m_ij, p, l_ij\n m_ij = tl.max(qk, 1)\n p = tl.exp(qk - m_ij[:, None])\n l_ij = tl.sum(p, 1)\n # -- update m_i and l_i\n m_i_new = tl.maximum(m_i, m_ij)\n alpha = tl.exp(m_i - m_i_new)\n beta = tl.exp(m_ij - m_i_new)\n l_i_new = alpha * l_i + beta * l_ij\n # -- update output accumulator --\n # scale p\n p_scale = beta / l_i_new\n p = p * p_scale[:, None]\n # scale acc\n acc_scale = l_i / l_i_new * alpha\n acc = acc * acc_scale[:, None]\n # update acc\n v = tl.load(\n v_ptrs + (cur_batch_in_all_start_index + start_n) * stride_vbs,\n mask=(start_n + offs_n[:, None]) < cur_batch_seq_len,\n other=0.0,\n )\n\n p = p.to(v.dtype)\n acc += tl.dot(p, v)\n # update m_i and l_i\n l_i = l_i_new\n m_i = m_i_new\n # initialize pointers to output\n off_o = (\n (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs\n + cur_head * stride_oh\n + offs_d[None, :]\n )\n out_ptrs = Out + off_o\n tl.store(out_ptrs, acc, mask=offs_m[:, None] < cur_batch_seq_len)\n\n\ndef context_attention_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len):\n if CUDA_CAPABILITY[0] >= 8:\n BLOCK = 128\n else:\n BLOCK = 64\n\n Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]\n assert Lq == Lk and Lk == Lv\n assert Lk in {16, 32, 64, 128, 256}\n\n sm_scale = 1.0 / (Lq**0.5)\n batch, head = b_seq_len.shape[0], q.shape[1]\n kv_group_num = q.shape[1] // k.shape[1]\n\n grid = (batch, head, triton.cdiv(max_input_len, BLOCK))\n num_warps = 4 if Lk <= 64 else 8\n\n _fwd_kernel[grid](\n q,\n k,\n v,\n sm_scale,\n b_start_loc,\n b_seq_len,\n o,\n q.stride(0),\n q.stride(1),\n k.stride(0),\n k.stride(1),\n v.stride(0),\n v.stride(1),\n o.stride(0),\n o.stride(1),\n kv_group_num=kv_group_num,\n BLOCK_M=BLOCK,\n BLOCK_DMODEL=Lk,\n BLOCK_N=BLOCK,\n num_warps=num_warps,\n num_stages=1,\n )\n", - "description_1": "Use triton language to implement a forward kernel for memory-efficient attention. The kernel takes 18 parameters: Q, K, V (query, key, value tensors), sm_scale (scale for softmax), B_Start_Loc, B_Seqlen (batch start location and sequence length), Out (output tensor), stride_qbs, stride_qh, stride_kbs, stride_kh, stride_vbs, stride_vh, stride_obs, stride_oh (strides for accessing tensors), kv_group_num (number of key-value groups), BLOCK_M, BLOCK_DMODEL, BLOCK_N (block sizes for matrix operations). The kernel computes the attention scores and updates the output tensor.", - "description_2": "Use triton language to implement a function 'context_attention_fwd' that sets up and launches the forward kernel for attention. It takes 7 parameters: q, k, v, o (query, key, value, output tensors), b_start_loc, b_seq_len (batch start location and sequence length), max_input_len (maximum input length). The function determines block size based on CUDA capability, calculates softmax scale, sets up grid and launches the kernel with appropriate parameters.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef attention_fwd_kernel(\n q, k, v, h, o, s_qh, s_qt, s_qd, s_hh, s_ht, T, scale,\n BT: tl.constexpr, BD: tl.constexpr, NT: tl.constexpr, \n STORE: tl.constexpr, IFCOND: tl.constexpr\n):\n i_bh = tl.program_id(0)\n # [BD, BD]\n b_h = tl.zeros([BD, BD], dtype=tl.float32)\n for i in range(0, tl.cdiv(T, BT)):\n p_q = tl.make_block_ptr(q + i_bh * s_qh, (T, BD), (s_qt, s_qd), (i * BT, 0), (BT, BD), (1, 0))\n p_k = tl.make_block_ptr(k + i_bh * s_qh, (BD, T), (s_qd, s_qt), (0, i * BT), (BD, BT), (0, 1))\n p_v = tl.make_block_ptr(v + i_bh * s_qh, (T, BD), (s_qt, s_qd), (i * BT, 0), (BT, BD), (1, 0))\n p_h = tl.make_block_ptr(h + i_bh * s_hh, (NT * BD, BD), (s_ht, s_qd), (i * BD, 0), (BD, BD), (1, 0))\n p_o = tl.make_block_ptr(o + i_bh * s_qh, (T, BD), (s_qt, s_qd), (i * BT, 0), (BT, BD), (1, 0))\n\n if STORE:\n tl.store(p_h, b_h.to(p_h.dtype.element_ty))\n # [BT, BD]\n b_q = tl.load(p_q)\n b_q = (b_q * scale).to(b_q.dtype)\n # [BD, BT]\n b_k = tl.load(p_k)\n # [BT, BD]\n b_v = tl.load(p_v)\n\n # [BT, BT]\n b_s = tl.dot(b_q, b_k, allow_tf32=False)\n # [BT, BD]\n b_o = tl.dot(b_s.to(b_q.dtype), b_v, allow_tf32=False)\n if IFCOND:\n if i == 0:\n b_h = tl.dot(b_k, b_v, allow_tf32=False)\n else:\n b_o += tl.dot(b_q, b_h.to(b_q.dtype), allow_tf32=False)\n b_h += tl.dot(b_k, b_v, allow_tf32=False)\n else:\n b_o += tl.dot(b_q, b_h.to(b_q.dtype), allow_tf32=False)\n b_h += tl.dot(b_k, b_v, allow_tf32=False)\n\n tl.store(p_o, b_o.to(p_o.dtype.element_ty))\n\n\nclass AttentionFunction(torch.autograd.Function):\n @staticmethod\n def forward(ctx, q, k, v, store=False, ifcond=False):\n batch_size, n_heads, seq_len, d_head = q.shape\n scale = d_head ** -0.5\n BD = q.shape[-1]\n BT = 32\n NT = triton.cdiv(seq_len, BT)\n num_stages = 3 if d_head <= 64 else 2\n num_warps = 4\n\n h = q.new_empty(batch_size, n_heads, NT * BD, BD)\n o = torch.empty_like(q)\n grid = (batch_size * n_heads,)\n attention_fwd_kernel[grid](\n q, k, v, h, o,\n q.stride(1), q.stride(2), q.stride(3), h.stride(1), h.stride(2),\n seq_len, scale,\n BT=BT, BD=BD, NT=NT, STORE=store, IFCOND=ifcond,\n num_warps=num_warps,\n num_stages=num_stages\n )\n return o\n\n\nif __name__ == '__main__':\n B, H, T, D = 2, 8, 1024, 128\n dtype = torch.float\n torch.manual_seed(42)\n # [batch_size, n_heads, seq_len, d_head]\n q = torch.randn((B, H, T, D), dtype=dtype, device='cuda')\n k = torch.randn((B, H, T, D), dtype=dtype, device='cuda')\n v = torch.randn((B, H, T, D), dtype=dtype, device='cuda')\n\n ref = AttentionFunction.apply(q, k, v)\n print(\"DTYPE\\t\\tSTORE\\tIFCOND\\tDIFF\")\n for dtype in (torch.float, torch.bfloat16):\n q, k, v = q.clone().to(dtype), k.clone().to(dtype), v.clone().to(dtype)\n for store in [False, True]:\n for ifcond in [False, True]:\n tri = AttentionFunction.apply(q, k, v, store, ifcond)\n print(f\"{q.dtype}\\t{store}\\t{ifcond}\\t{(ref - tri).abs().max()}\")\n", - "description_1": "Use triton language to implement an attention forward kernel. The kernel takes 15 parameters: q, k, v, h, o (tensors for queries, keys, values, intermediate results, and output), s_qh, s_qt, s_qd, s_hh, s_ht (stride values), T (sequence length), scale (scaling factor), and four compile-time constants BT, BD, NT, STORE, IFCOND. It computes scaled dot-product attention and stores results to the output tensor.", - "description_2": "Use triton language to create an optimized attention forward function with parameters for input tensors, strides, sequence length, scaling factor, and conditional computation toggles.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.autotune(\n configs=[\n triton.Config({}, num_warps=1),\n triton.Config({}, num_warps=2),\n triton.Config({}, num_warps=4),\n triton.Config({}, num_warps=8),\n triton.Config({}, num_warps=16),\n triton.Config({}, num_warps=32),\n ],\n key=[\"N\", \"HAS_RESIDUAL\", \"STORE_RESIDUAL_OUT\", \"IS_RMS_NORM\", \"HAS_BIAS\"],\n)\n@triton.jit\ndef _layer_norm_fwd_1pass_kernel(\n X, # pointer to the input\n O, # pointer to the gate\n Y, # pointer to the output\n W, # pointer to the weights\n B, # pointer to the biases\n RESIDUAL, # pointer to the residual\n RESIDUAL_OUT, # pointer to the residual\n Mean, # pointer to the mean\n Rstd, # pointer to the 1/std\n stride_x_row, # how much to increase the pointer when moving by 1 row\n stride_y_row,\n stride_res_row,\n stride_res_out_row,\n N, # number of columns in X\n eps, # epsilon to avoid division by zero\n IS_RMS_NORM: tl.constexpr,\n BLOCK_N: tl.constexpr,\n HAS_RESIDUAL: tl.constexpr,\n STORE_RESIDUAL_OUT: tl.constexpr,\n HAS_BIAS: tl.constexpr,\n):\n # Map the program id to the row of X and Y it should compute.\n row = tl.program_id(0)\n X += row * stride_x_row\n Y += row * stride_y_row\n O += row * stride_x_row\n if HAS_RESIDUAL:\n RESIDUAL += row * stride_res_row\n if STORE_RESIDUAL_OUT:\n RESIDUAL_OUT += row * stride_res_out_row\n # Compute mean and variance\n cols = tl.arange(0, BLOCK_N)\n x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)\n if HAS_RESIDUAL:\n residual = tl.load(RESIDUAL + cols, mask=cols < N, other=0.0).to(tl.float32)\n x += residual\n if STORE_RESIDUAL_OUT:\n tl.store(RESIDUAL_OUT + cols, x, mask=cols < N)\n if not IS_RMS_NORM:\n mean = tl.sum(x, axis=0) / N\n tl.store(Mean + row, mean)\n xbar = tl.where(cols < N, x - mean, 0.0)\n var = tl.sum(xbar * xbar, axis=0) / N\n else:\n xbar = tl.where(cols < N, x, 0.0)\n var = tl.sum(xbar * xbar, axis=0) / N\n rstd = 1 / tl.sqrt(var + eps)\n tl.store(Rstd + row, rstd)\n # Normalize and apply linear transformation\n mask = cols < N\n w = tl.load(W + cols, mask=mask).to(tl.float32)\n if HAS_BIAS:\n b = tl.load(B + cols, mask=mask).to(tl.float32)\n x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd\n y = x_hat * w + b if HAS_BIAS else x_hat * w\n\n # Swish output gate\n o = tl.load(O + cols, mask=cols < N, other=0.0).to(tl.float32)\n y = y * o * tl.sigmoid(o)\n \n # Write output\n tl.store(Y + cols, y, mask=mask)\n\n\ndef _layer_norm_fwd(\n x, o, weight, bias, eps, residual=None, out_dtype=None, residual_dtype=None, is_rms_norm=False\n):\n if residual is not None:\n residual_dtype = residual.dtype\n M, N = x.shape\n assert x.stride(-1) == 1\n if residual is not None:\n assert residual.stride(-1) == 1\n assert residual.shape == (M, N)\n assert weight.shape == (N,)\n assert weight.stride(-1) == 1\n if bias is not None:\n assert bias.stride(-1) == 1\n assert bias.shape == (N,)\n # allocate output\n y = torch.empty_like(x, dtype=x.dtype if out_dtype is None else out_dtype)\n assert y.stride(-1) == 1\n if residual is not None or (residual_dtype is not None and residual_dtype != x.dtype):\n residual_out = torch.empty(M, N, device=x.device, dtype=residual_dtype)\n assert residual_out.stride(-1) == 1\n else:\n residual_out = None\n mean = torch.empty((M,), dtype=torch.float32, device=\"cuda\") if not is_rms_norm else None\n rstd = torch.empty((M,), dtype=torch.float32, device=\"cuda\")\n # Less than 64KB per feature: enqueue fused kernel\n MAX_FUSED_SIZE = 65536 // x.element_size()\n BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))\n if N > BLOCK_N:\n raise RuntimeError(\"This layer norm doesn't support feature dim >= 64KB.\")\n # heuristics for number of warps\n with torch.cuda.device(x.device.index):\n _layer_norm_fwd_1pass_kernel[(M,)](\n x,\n o,\n y,\n weight,\n bias,\n residual,\n residual_out,\n mean,\n rstd,\n x.stride(0),\n y.stride(0),\n residual.stride(0) if residual is not None else 0,\n residual_out.stride(0) if residual_out is not None else 0,\n N,\n eps,\n is_rms_norm,\n BLOCK_N,\n residual is not None,\n residual_out is not None,\n bias is not None,\n )\n # residual_out is None if residual is None and residual_dtype == input_dtype\n return y, mean, rstd, residual_out if residual_out is not None else x\n", - "description_1": "Use triton language to implement a forward pass kernel for layer normalization with optional residuals and bias. The kernel takes 20 parameters: pointers to input, gate, output, weights, biases, residuals, mean, and 1/std, strides for input, output, and residuals, number of columns, epsilon for numerical stability, and several compile-time constants for configuration.", - "description_2": "Use triton language to implement a forward pass function for layer normalization that prepares data and calls the kernel. The function takes 9 parameters: input, gate, weights, biases, epsilon, optional residuals, output data type, residual data type, and a flag for RMS normalization.", - "difficulty": 2 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef chunk_abc_fwd_kernel_h(\n k,\n v,\n z,\n h,\n h0,\n ht,\n s_k_h,\n s_k_t,\n s_k_d,\n s_v_h,\n s_v_t,\n s_v_d,\n s_h_h,\n s_h_t,\n s_h_d,\n T: tl.constexpr,\n K: tl.constexpr,\n V: tl.constexpr,\n BT: tl.constexpr,\n BK: tl.constexpr,\n BV: tl.constexpr,\n NT: tl.constexpr,\n NORMK: tl.constexpr,\n USE_INITIAL_STATE: tl.constexpr,\n STORE_FINAL_STATE: tl.constexpr\n):\n i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n\n b_h = tl.zeros([BK, BV], dtype=tl.float32)\n if USE_INITIAL_STATE:\n p_h = tl.make_block_ptr(h0 + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))\n b_h += tl.load(p_h, boundary_check=(0, 1)).to(tl.float32)\n b_zp = tl.zeros([BK if NORMK else BV], dtype=tl.float32)\n for i_t in range(NT):\n p_k = tl.make_block_ptr(k + i_bh * s_k_h, (K, T), (s_k_d, s_k_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1))\n p_v = tl.make_block_ptr(v + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))\n p_h = tl.make_block_ptr(h + i_bh * s_h_h + i_t * K * V, (K, V), (s_h_t, s_h_d), (i_k * BK, i_v * BV), (BK, BV), (1, 0))\n\n tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1))\n b_k = tl.load(p_k, boundary_check=(0, 1))\n b_v = tl.load(p_v, boundary_check=(0, 1))\n if NORMK:\n p_zc = tl.make_block_ptr(z + i_bh * s_k_h, (T * K,), (s_k_d,), ((i_t * BT + BT - 1) * K + i_k * BK,), (BK,), (0,))\n b_zc = tl.load(p_zc, boundary_check=(0,))\n if i_t == 0:\n b_zp = b_zc\n b_r, b_zp = tl.exp(b_zp - b_zc), b_zc\n b_h = b_h * b_r[:, None]\n b_k = tl.exp(b_k - b_zc[:, None]).to(b_k.dtype)\n else:\n p_zc = tl.make_block_ptr(z + i_bh * s_v_h, (T * V,), (s_v_d,), ((i_t * BT + BT - 1) * V + i_v * BV,), (BV,), (0,))\n b_zc = tl.load(p_zc, boundary_check=(0,))\n if i_t == 0:\n b_zp = b_zc\n b_r, b_zp = tl.exp(b_zp - b_zc), b_zc\n b_h = b_h * b_r[None, :]\n b_v = tl.exp(b_v - b_zc[None, :]).to(b_v.dtype)\n b_h += tl.dot(b_k, b_v, allow_tf32=False)\n\n if STORE_FINAL_STATE:\n p_h = tl.make_block_ptr(ht + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))\n tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1))\n\n\ndef chunk_abc(\n q: torch.Tensor,\n k: torch.Tensor,\n v: torch.Tensor,\n s: torch.Tensor,\n initial_state: Optional[torch.Tensor] = None,\n output_final_state: Optional[bool] = False\n) -> torch.Tensor:\n if initial_state is not None:\n initial_state = initial_state.detach()\n if output_final_state:\n output_final_state = False\n warnings.warn(\"output_final_state is not supported in ABC, setting it to `False`.\")\n ov, _ = ChunkABCFunction.apply(q, k, v, s, initial_state, output_final_state)\n return ov\n", - "description_1": "Use triton language to define multiple kernel functions for forward and backward operations in an advanced chunk-based computation. Each kernel accepts specific tensor inputs, strides, constants, and performs mathematical operations such as matrix multiplication and element-wise operations. These operations facilitate efficient parallel computing using triton.", - "description_2": "Use triton language to define optimized kernel functions for parallel matrix computations on GPUs, handling input tensors with specific memory strides and performing complex operations like matrix multiplication in a block-wise manner.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\nimport warnings\nfrom typing import Optional\n\n\n@triton.jit\ndef chunk_abc_fwd_kernel_cum(\n s,\n o,\n s_s_h,\n s_s_t,\n s_s_d,\n T: tl.constexpr,\n S: tl.constexpr,\n BT: tl.constexpr,\n BS: tl.constexpr,\n):\n i_s, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n o_i = tl.arange(0, BT)\n m_s = tl.where(o_i[:, None] >= o_i[None, :], 1., 0.).to(tl.float32)\n\n p_s = tl.make_block_ptr(s + i_bh * s_s_h, (T, S), (s_s_t, s_s_d), (i_t * BT, i_s * BS), (BT, BS), (1, 0))\n p_o = tl.make_block_ptr(o + i_bh * s_s_h, (T, S), (s_s_t, s_s_d), (i_t * BT, i_s * BS), (BT, BS), (1, 0))\n # [BT, BS]\n b_s = tl.load(p_s, boundary_check=(0, 1))\n b_o = tl.dot(m_s, b_s, allow_tf32=False)\n tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))\n\n\n@triton.jit\ndef chunk_abc_fwd_kernel_h(\n k,\n v,\n g,\n h,\n h0,\n ht,\n s_k_h,\n s_k_t,\n s_k_d,\n s_v_h,\n s_v_t,\n s_v_d,\n s_h_h,\n s_h_t,\n s_h_d,\n T: tl.constexpr,\n K: tl.constexpr,\n V: tl.constexpr,\n BT: tl.constexpr,\n BK: tl.constexpr,\n BV: tl.constexpr,\n NT: tl.constexpr,\n NORMK: tl.constexpr,\n USE_INITIAL_STATE: tl.constexpr,\n STORE_FINAL_STATE: tl.constexpr\n):\n i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n\n b_h = tl.zeros([BK, BV], dtype=tl.float32)\n if USE_INITIAL_STATE:\n p_h = tl.make_block_ptr(h0 + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))\n b_h += tl.load(p_h, boundary_check=(0, 1)).to(tl.float32)\n for i_t in range(NT):\n p_k = tl.make_block_ptr(k + i_bh * s_k_h, (K, T), (s_k_d, s_k_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1))\n p_v = tl.make_block_ptr(v + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))\n p_h = tl.make_block_ptr(h + i_bh * s_h_h + i_t * K * V, (K, V), (s_h_t, s_h_d), (i_k * BK, i_v * BV), (BK, BV), (1, 0))\n\n tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1))\n # [BK, BT]\n b_k = tl.load(p_k, boundary_check=(0, 1))\n # [BT, BV]\n b_v = tl.load(p_v, boundary_check=(0, 1))\n if NORMK:\n p_g = tl.make_block_ptr(g + i_bh * s_k_h, (K, T), (s_k_d, s_k_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1))\n p_gn = tl.make_block_ptr(g + i_bh * s_k_h, (T * K,), (s_k_d,), ((i_t * BT + BT - 1) * K + i_k * BK,), (BK,), (0,))\n # [BK,]\n b_gn = tl.load(p_gn, boundary_check=(0,))\n # [BK, BV]\n b_h *= tl.exp(b_gn)[:, None]\n # [BK, BT]\n b_g = tl.load(p_g, boundary_check=(0, 1))\n b_k = (b_k * tl.exp(b_gn[:, None] - b_g)).to(b_k.dtype)\n else:\n p_g = tl.make_block_ptr(g + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))\n p_gn = tl.make_block_ptr(g + i_bh * s_v_h, (T * V,), (s_v_d,), ((i_t * BT + BT - 1) * V + i_v * BV,), (BV,), (0,))\n # [BV,]\n b_gn = tl.load(p_gn, boundary_check=(0,))\n # [BK, BV]\n b_h *= tl.exp(b_gn)[None, :]\n # [BT, BV]\n b_g = tl.load(p_g, boundary_check=(0, 1))\n b_v = (b_v * tl.exp(b_gn[None, :] - b_g)).to(b_v.dtype)\n # [BK, BV]\n b_h += tl.dot(b_k, b_v, allow_tf32=False)\n\n if STORE_FINAL_STATE:\n p_h = tl.make_block_ptr(ht + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))\n tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1))\n\n\n@triton.jit\ndef chunk_abc_fwd_kernel_K(\n q,\n k,\n h,\n g,\n o,\n A,\n s_k_h,\n s_k_t,\n s_k_d,\n s_v_h,\n s_v_t,\n s_v_d,\n s_h_h,\n s_h_t,\n s_h_d,\n scale,\n T: tl.constexpr,\n K: tl.constexpr,\n V: tl.constexpr,\n BT: tl.constexpr,\n BK: tl.constexpr,\n BV: tl.constexpr\n):\n i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n\n o_i = tl.arange(0, BT)\n m_s = o_i[:, None] >= o_i[None, :]\n\n b_o = tl.zeros([BT, BV], dtype=tl.float32)\n b_A = tl.zeros([BT, BT], dtype=tl.float32)\n for i_k in range(tl.cdiv(K, BK)):\n p_q = tl.make_block_ptr(q + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))\n p_k = tl.make_block_ptr(k + i_bh * s_k_h, (K, T), (s_k_d, s_k_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1))\n p_h = tl.make_block_ptr(h + i_bh * s_h_h + i_t * K * V, (K, V), (s_h_t, s_h_d), (i_k * BK, i_v * BV), (BK, BV), (1, 0))\n\n # [BT, BK]\n b_q = tl.load(p_q, boundary_check=(0, 1))\n b_q = (b_q * scale).to(b_q.dtype)\n # [BK, BT]\n b_k = tl.load(p_k, boundary_check=(0, 1))\n # [BK, BV]\n b_h = tl.load(p_h, boundary_check=(0, 1))\n # [BT, BV]\n b_o += tl.dot(b_q, b_h, allow_tf32=False)\n # [BT, BT]\n b_A += tl.dot(b_q, b_k, allow_tf32=False)\n p_g = tl.make_block_ptr(g + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))\n p_o = tl.make_block_ptr(o + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))\n # [BT, BV]\n b_g = tl.load(p_g, boundary_check=(0, 1))\n b_o = b_o * tl.exp(b_g)\n tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))\n\n p_A = tl.make_block_ptr(A + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))\n # [BT, BT]\n b_A = tl.where(m_s, b_A, 0.)\n if i_v == 0:\n tl.store(p_A, b_A.to(p_A.dtype.element_ty), boundary_check=(0, 1))\n\n\nclass ChunkABCFunction(torch.autograd.Function):\n\n @staticmethod\n def forward(ctx, q, k, v, s, g, initial_state, output_final_state):\n B, H, T, K, V, M = *q.shape, v.shape[-1], s.shape[-1]\n BT, BC = 64, 16\n BK = min(64, triton.next_power_of_2(K))\n BV = min(64, triton.next_power_of_2(V))\n BM = min(64, triton.next_power_of_2(M))\n NT, NC = triton.cdiv(T, BT), triton.cdiv(BT, BC)\n NV, NM = triton.cdiv(V, BV), triton.cdiv(M, BM)\n num_warps = 4 if BK == 64 else 2\n num_stages = 1\n assert not output_final_state\n assert M % 64 == 0, \"For efficiency, M must be a multiple of 64.\"\n\n def fwd_inner(q, k, v, g, B, H, T, K, V, BT, BK, BV, NT, normk=False, h0=None, ht=None):\n NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)\n h = q.new_empty(B, H, NT * K, V)\n grid = (NV, NK, B * H)\n chunk_abc_fwd_kernel_h[grid](\n k, v, g, h, h0, ht,\n k.stride(1), k.stride(2), k.stride(3),\n v.stride(1), v.stride(2), v.stride(3),\n h.stride(1), h.stride(2), h.stride(3),\n T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT,\n NORMK=normk,\n USE_INITIAL_STATE=h0 is not None,\n STORE_FINAL_STATE=ht is not None,\n num_warps=num_warps,\n num_stages=num_stages\n )\n return h\n\n gc = torch.empty_like(g, dtype=torch.float)\n grid = (NM, NT, B * H)\n # keep cumulative normalizer in fp32\n # this kernel is equivalent to\n # g = g.view(B, H, NT, BT, -1).cumsum(-2).view(B, H, T, -1)\n chunk_abc_fwd_kernel_cum[grid](\n g, gc,\n g.stride(1), g.stride(2), g.stride(3),\n T=T, S=M, BT=BT, BS=BM,\n num_warps=num_warps,\n num_stages=num_stages\n )\n g = gc\n\n scale = K ** -0.5\n hk = fwd_inner(\n q=q, k=k, v=s, g=g,\n B=B, H=H, T=T, K=K, V=M, BT=BT, BK=BK, BV=BM, NT=NT,\n normk=False,\n h0=initial_state\n )\n ok1 = torch.empty_like(s)\n Ak = q.new_empty(B, H, T, BT)\n grid = (NM, NT, B * H)\n chunk_abc_fwd_kernel_K[grid](\n q, k, hk, g, ok1, Ak,\n k.stride(1), k.stride(2), k.stride(3),\n s.stride(1), s.stride(2), s.stride(3),\n hk.stride(1), hk.stride(2), hk.stride(3),\n scale,\n T=T, K=K, V=M, BT=BT, BK=BK, BV=BM,\n num_warps=num_warps,\n num_stages=num_stages\n )\n return ok1, None\n\ndef gated_chunk_abc(\n q: torch.Tensor,\n k: torch.Tensor,\n v: torch.Tensor,\n s: torch.Tensor,\n initial_state: Optional[torch.Tensor] = None,\n output_final_state: Optional[bool] = False\n) -> torch.Tensor:\n if initial_state is not None:\n initial_state = initial_state.detach()\n if output_final_state:\n output_final_state = False\n warnings.warn(\"output_final_state is not supported in ABC, setting it to `False`.\")\n z = s.float().logcumsumexp(2)\n g = torch.cat((z[:, :, :1], z[:, :, :-1]), 2) - z\n s = torch.exp(s - z).to(k.dtype)\n ov, _ = ChunkABCFunction.apply(q, k, v, s, g, initial_state, output_final_state)\n return ov\n", - "description_1": "Use triton language to implement kernel functions for forward pass of an ABC attention mechanism. Each function has specific roles such as cumulative sum computation (chunk_abc_fwd_kernel_cum), hidden state updates (chunk_abc_fwd_kernel_h), and main kernel operations (chunk_abc_fwd_kernel_K). The forward function in ChunkABCFunction uses these kernels to compute output tensors based on input q, k, v, s, g, initial_state, and output_final_state.", - "description_2": "Use triton language to compute ABC attention with kernels for cumulative sum, hidden state update, and main computation using input tensors.", - "difficulty": 4 - }, - { - "code": "import triton\nimport triton.language as tl\n\n# Kernel for forward pass of logcumsumexp\n@triton.jit\ndef logcumsumexp_fwd_kernel(\n s, z, s_s_h, s_s_t, s_s_d,\n T: tl.constexpr, S: tl.constexpr, BT: tl.constexpr, BS: tl.constexpr, NT: tl.constexpr\n):\n i_s, i_bh = tl.program_id(0), tl.program_id(1)\n o_i = tl.arange(0, BT)\n m_s = tl.where(o_i[:, None] >= o_i[None, :], 1., 0.)\n\n b_mp = tl.full([BS,], float('-inf'), dtype=tl.float32)\n b_zp = tl.zeros([BS,], dtype=tl.float32)\n for i_t in range(NT):\n p_s = tl.make_block_ptr(s + i_bh * s_s_h, (T, S), (s_s_t, s_s_d), (i_t * BT, i_s * BS), (BT, BS), (1, 0))\n p_z = tl.make_block_ptr(z + i_bh * s_s_h, (T, S), (s_s_t, s_s_d), (i_t * BT, i_s * BS), (BT, BS), (1, 0))\n\n b_s = tl.load(p_s, boundary_check=(0, 1)).to(tl.float32)\n b_mc = tl.max(b_s, 0)\n if i_t > 0:\n b_mc = tl.maximum(b_mp, b_mc)\n b_zp = b_zp * tl.exp(b_mp - b_mc)\n b_s = tl.exp(b_s - b_mc)\n b_z = tl.dot(m_s, b_s) + b_zp\n b_zc = tl.max(b_z, 0)\n b_mp = b_mc\n b_zp = b_zc\n b_z = tl.log(tl.where(b_z != 0, b_z, 1e-20)) + b_mc\n tl.store(p_z, b_z.to(p_z.dtype.element_ty), boundary_check=(0, 1))\n\n# Kernel for forward pass of softmax\n@triton.jit\ndef softmax_fwd_kernel(\n s, p, s_s_h, s_s_t, s_s_d,\n T: tl.constexpr, S: tl.constexpr, BT: tl.constexpr, BS: tl.constexpr\n):\n i_s, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n\n b_z = tl.zeros([BT,], dtype=tl.float32)\n b_m = tl.zeros([BT,], dtype=tl.float32)\n for i in range(tl.cdiv(S, BS)):\n p_s = tl.make_block_ptr(s + i_bh * s_s_h, (T, S), (s_s_t, s_s_d), (i_t * BT, i * BS), (BT, BS), (1, 0))\n b_s = tl.load(p_s, boundary_check=(0, 1)).to(tl.float32)\n b_mc = tl.max(b_s, 1)\n b_mc = tl.maximum(b_m, b_mc)\n if i > 0:\n b_z = b_z * tl.exp(b_m - b_mc)\n b_z += tl.sum(tl.exp(b_s - b_mc[:, None]), 1)\n b_m = b_mc\n\n p_s = tl.make_block_ptr(s + i_bh * s_s_h, (T, S), (s_s_t, s_s_d), (i_t * BT, i_s * BS), (BT, BS), (1, 0))\n p_p = tl.make_block_ptr(p + i_bh * s_s_h, (T, S), (s_s_t, s_s_d), (i_t * BT, i_s * BS), (BT, BS), (1, 0))\n b_s = tl.load(p_s, boundary_check=(0, 1)).to(tl.float32)\n b_s = tl.exp(b_s - b_m[:, None])\n b_p = tl.where(b_s != 0, b_s / b_z[:, None], 0.)\n tl.store(p_p, b_p.to(p_p.dtype.element_ty), boundary_check=(0, 1))\n\n# Kernel for backward pass of softmax\n@triton.jit\ndef softmax_bwd_kernel(\n p, dp, ds, s_s_h, s_s_t, s_s_d,\n T: tl.constexpr, S: tl.constexpr, BT: tl.constexpr, BS: tl.constexpr\n):\n i_s, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n\n b_pp = tl.zeros([BT,], dtype=tl.float32)\n for i in range(tl.cdiv(S, BS)):\n p_p = tl.make_block_ptr(p + i_bh * s_s_h, (T, S), (s_s_t, s_s_d), (i_t * BT, i * BS), (BT, BS), (1, 0))\n p_dp = tl.make_block_ptr(dp + i_bh * s_s_h, (T, S), (s_s_t, s_s_d), (i_t * BT, i * BS), (BT, BS), (1, 0))\n b_p = tl.load(p_p, boundary_check=(0, 1)).to(tl.float32)\n b_dp = tl.load(p_dp, boundary_check=(0, 1)).to(tl.float32)\n b_pp += tl.sum(b_p * b_dp, 1)\n p_p = tl.make_block_ptr(p + i_bh * s_s_h, (T, S), (s_s_t, s_s_d), (i_t * BT, i_s * BS), (BT, BS), (1, 0))\n p_dp = tl.make_block_ptr(dp + i_bh * s_s_h, (T, S), (s_s_t, s_s_d), (i_t * BT, i_s * BS), (BT, BS), (1, 0))\n p_ds = tl.make_block_ptr(ds + i_bh * s_s_h, (T, S), (s_s_t, s_s_d), (i_t * BT, i_s * BS), (BT, BS), (1, 0))\n b_p = tl.load(p_p, boundary_check=(0, 1)).to(tl.float32)\n b_dp = tl.load(p_dp, boundary_check=(0, 1)).to(tl.float32)\n b_ds = b_p * b_dp - b_p * b_pp[:, None]\n tl.store(p_ds, b_ds.to(p_ds.dtype.element_ty), boundary_check=(0, 1))\n", - "description_1": "Use triton language to implement three kernels: logcumsumexp_fwd_kernel, softmax_fwd_kernel, and softmax_bwd_kernel. The logcumsumexp_fwd_kernel takes 10 parameters: s, z, s_s_h, s_s_t, s_s_d, T, S, BT, BS, NT, and performs a forward pass of the logcumsumexp operation. The softmax_fwd_kernel takes 9 parameters: s, p, s_s_h, s_s_t, s_s_d, T, S, BT, BS, and performs a forward pass of the softmax operation. The softmax_bwd_kernel takes 9 parameters: p, dp, ds, s_s_h, s_s_t, s_s_d, T, S, BT, BS, and performs a backward pass of the softmax operation.", - "description_2": "Use triton language to implement kernels for forward and backward passes of logcumsumexp and softmax operations, handling block pointers and boundary checks.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\nfrom torch.cuda.amp import custom_bwd, custom_fwd\n\n@triton.jit\ndef fused_chunk_based_fwd_kernel(\n q, # query [B, H, L, D_head_K]\n k, # key [B, H, L, D_head_V]\n v, # value [B, H, L, D_head_V]\n o, # output [B, H, L, D_head_V]\n z, # normalizer [B, H, L, 1]\n s_qk_h, # stride size: L * D_head_K\n s_qk_t, # stride size: D_head_K\n s_qk_d, # stride size: 1\n s_vo_h, # stride size: L * D_head_V\n s_vo_t, # stride size: D_head_V\n s_vo_d, # stride size: 1\n B, # batch size\n H, # n_heads\n T, # seq_len\n scale, # D_head_K ** -0.5\n BT: tl.constexpr, # BLOCK SIZE along the sequence dimension, a.k.a. chunk size\n BK: tl.constexpr, # BLOCK SIZE along the K dimension\n BV: tl.constexpr, # BLOCK SIZE along the V dimension\n DK: tl.constexpr, # D_head_K\n DV: tl.constexpr, # D_head_V\n):\n # kernel code...\n pass\n\n@triton.jit\ndef fused_chunk_based_bwd_kernel(\n q, # query [B, H, L, D_head_K]\n k, # key [B, H, L, D_head_V]\n v, # value [B, H, L, D_head_V]\n do, # gradient of output [B, H, L, D_head_V]\n dz, # gradient of normalizer [B, H, L]\n dq, # gradient of query [NV, B, H, L, D_head_K]\n dk, # gradient of key [NV, B, H, L, D_head_K]\n dv, # gradient of value [NK, B, H, L, D_head_V]\n s_qk_h, # stride size: L * D_head_K\n s_qk_t, # stride size: D_head_K\n s_qk_d, # stride size: 1\n s_vo_h, # stride size: L * D_head_V\n s_vo_t, # stride size: D_head_V\n s_vo_d, # stride size: 1\n B, # batch_size\n H, # n_heads\n T, # seq_len\n scale, # D_head_K ** -0.5\n BT: tl.constexpr, # BLOCK SIZE along the sequence dimension, a.k.a. chunk size\n BK: tl.constexpr, # BLOCK SIZE along the K dimension\n BV: tl.constexpr, # BLOCK SIZE along the V dimension\n DK: tl.constexpr, # D_head_K\n DV: tl.constexpr, # D_head_V\n):\n # kernel code...\n pass\n\nclass FusedChunkBasedFunction(torch.autograd.Function):\n\n @staticmethod\n @custom_fwd\n def forward(ctx, q, k, v, scale=1):\n batch_size, n_heads, seq_len, d_head_qk = q.shape\n d_head_v = v.shape[-1]\n\n scale = scale\n BT = 16\n BK, BV = min(d_head_qk, 16), min(d_head_v, 32)\n BK, BV = max(BK, 16), max(BV, 16)\n NK, NV = triton.cdiv(d_head_qk, BK), triton.cdiv(d_head_v, BV)\n\n num_warps = 4\n\n o = q.new_empty(NK, batch_size, n_heads, seq_len,\n d_head_v, dtype=torch.float32)\n z = q.new_empty(NK, batch_size, n_heads, seq_len, dtype=torch.float32)\n\n grid = (NV, NK, batch_size * n_heads)\n fused_chunk_based_fwd_kernel[grid](\n q, k, v, o, z,\n q.stride(1), q.stride(2), q.stride(3),\n v.stride(1), v.stride(2), v.stride(3),\n batch_size, n_heads, seq_len, scale,\n BT=BT, DK=d_head_qk, DV=d_head_v, BK=BK, BV=BV,\n num_warps=num_warps,\n )\n o = o.sum(0)\n z = z.sum(0)\n ctx.save_for_backward(q, k, v)\n ctx.scale = scale\n return o.to(q.dtype), z.to(z.dtype)\n\n @staticmethod\n @custom_bwd\n def backward(ctx, do, dz):\n q, k, v = ctx.saved_tensors\n batch_size, n_heads, seq_len, d_head_qk = q.shape\n d_head_v = v.shape[-1]\n scale = ctx.scale\n\n BT = 16\n BK, BV = min(d_head_qk, 16), min(d_head_v, 32)\n BK, BV = max(BK, 16), max(BV, 16)\n NK, NV = triton.cdiv(d_head_qk, BK), triton.cdiv(d_head_v, BV)\n num_stages = 1\n num_warps = 4\n\n dq = q.new_empty(NV, batch_size, n_heads, seq_len, d_head_qk)\n dk = q.new_empty(NV, batch_size, n_heads, seq_len, d_head_qk)\n dv = q.new_empty(NK, batch_size, n_heads, seq_len, d_head_v)\n grid = (NV, NK, batch_size * n_heads)\n\n fused_chunk_based_bwd_kernel[grid](\n q, k, v, do, dz, dq, dk, dv,\n q.stride(1), q.stride(2), q.stride(3),\n v.stride(1), v.stride(2), v.stride(3),\n batch_size, n_heads, seq_len, scale,\n BT=BT, DK=d_head_qk, DV=d_head_v, BK=BK, BV=BV,\n num_warps=num_warps,\n num_stages=num_stages\n )\n dq = dq.sum(0)\n dk = dk.sum(0)\n dv = dv.sum(0)\n return dq.to(q.dtype), dk.to(k.dtype), dv.to(v.dtype), None\n\n\ntriton_fused_chunk_based = FusedChunkBasedFunction.apply\n\n\ndef fused_chunk_based(q, k, v, use_scale=True, use_normalize=True):\n assert q.shape[-1] <= 16, 'only support feature dimension up to 16.'\n if use_scale:\n scale = q.shape[-1] ** -0.5\n else:\n scale = 1\n o, z = triton_fused_chunk_based(q, k, v, scale)\n if use_normalize:\n o = o / (z[..., None] + 1e-6)\n else:\n o = o\n\n return o.to(q.dtype)\n", - "description_1": "Use triton language to implement a fused forward and backward kernel for chunk-based operations, taking queries, keys, and values as inputs and using them to compute outputs and normalizers, and to compute gradients for backward propagation. The forward kernel `fused_chunk_based_fwd_kernel` has 19 parameters (including tensors for queries, keys, values, outputs, normalizers, and stride sizes, as well as constants for batch size, number of heads, sequence length, scale, block sizes, and dimensional heads). The backward kernel `fused_chunk_based_bwd_kernel` has 24 parameters (including tensors for queries, keys, values, gradients of output and normalizer, and computed gradients, along with stride sizes, batch size, number of heads, sequence length, scale, block sizes, and dimensional heads).", - "description_2": "Use triton language to create a function for applying fused chunk-based forward and backward kernels for deep learning models with efficiency in memory and computation.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\nfrom torch.cuda.amp import custom_bwd, custom_fwd\n\n@triton.jit\ndef parallel_based_fwd_kernel(\n q, k, v, o, z, s_qk_h, s_qk_t, s_qk_d, s_vo_h, s_vo_t, s_vo_d,\n B, H, T, scale, BTL: tl.constexpr, BTS: tl.constexpr,\n BK: tl.constexpr, BV: tl.constexpr, DK: tl.constexpr, DV: tl.constexpr,\n):\n i_kv, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n NV = tl.cdiv(DV, BV)\n i_k = i_kv // NV\n i_v = i_kv % NV\n\n p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (T, DK),\n (s_qk_t, s_qk_d), (i_c * BTL, i_k * BK), (BTL, BK), (1, 0))\n p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (DK, T),\n (s_qk_d, s_qk_t), (i_k * BK, 0), (BK, BTS), (0, 1))\n p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, DV),\n (s_vo_t, s_vo_d), (0, i_v * BV), (BTS, BV), (1, 0))\n\n b_q = tl.load(p_q, boundary_check=(0, 1))\n b_q = (b_q * scale).to(b_q.dtype)\n b_o = tl.zeros([BTL, BV], dtype=tl.float32)\n b_z = tl.zeros([BTL], dtype=tl.float32)\n\n for _ in range(0, i_c * BTL, BTS):\n b_k = tl.load(p_k, boundary_check=(0, 1))\n b_v = tl.load(p_v, boundary_check=(0, 1))\n b_s = tl.dot(b_q, b_k, allow_tf32=False)\n b_s = 1 + b_s + 0.5 * b_s * b_s\n b_z += tl.sum(b_s, axis=1)\n b_o += tl.dot(b_s.to(b_v.dtype), b_v, allow_tf32=False)\n p_k = tl.advance(p_k, (0, BTS))\n p_v = tl.advance(p_v, (BTS, 0))\n\n tl.debug_barrier()\n o_q = tl.arange(0, BTL)\n o_k = tl.arange(0, BTS)\n p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (DK, T),\n (s_qk_d, s_qk_t), (i_k * BK, i_c * BTL), (BK, BTS), (0, 1))\n p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, DV),\n (s_vo_t, s_vo_d), (i_c * BTL, i_v * BV), (BTS, BV), (1, 0))\n for _ in range(i_c * BTL, (i_c + 1) * BTL, BTS):\n b_k = tl.load(p_k, boundary_check=(0, 1))\n b_v = tl.load(p_v, boundary_check=(0, 1))\n m_s = o_q[:, None] >= o_k[None, :]\n b_s = tl.dot(b_q, b_k, allow_tf32=False)\n b_s = 1 + b_s + 0.5 * b_s * b_s\n b_s = tl.where(m_s, b_s, 0)\n b_z += tl.sum(b_s, axis=1)\n b_o += tl.dot(b_s.to(b_q.dtype), b_v, allow_tf32=False)\n\n p_k = tl.advance(p_k, (0, BTS))\n p_v = tl.advance(p_v, (BTS, 0))\n o_k += BTS\n\n p_o = tl.make_block_ptr(o + (i_bh + B * H * i_k) * s_vo_h, (T, DV),\n (s_vo_t, s_vo_d), (i_c*BTL, i_v*BV), (BTL, BV), (1, 0))\n p_z = z + (i_bh + B * H * i_k) * T + i_c * BTL + tl.arange(0, BTL)\n tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))\n tl.store(p_z, b_z.to(p_z.dtype.element_ty),\n mask=((i_c * BTL + tl.arange(0, BTL)) < T))\n\n\n@triton.jit\ndef _parallel_based_bwd_dq(\n i_bh, i_c, i_k, i_v, i_h,\n q, k, v, do, dz, dq, s_qk_h, s_qk_t, s_qk_d, s_vo_h,\n s_vo_t, s_vo_d, B, H, T, scale,\n BTL: tl.constexpr, BTS: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr,\n DK: tl.constexpr, DV: tl.constexpr,\n):\n p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d),\n (i_c * BTL, i_v * BV), (BTL, BV), (1, 0))\n p_q = tl.make_block_ptr(q + (i_bh) * s_qk_h, (T, DK),\n (s_qk_t, s_qk_d), (i_c*BTL, i_k*BK), (BTL, BK), (1, 0))\n b_q = tl.load(p_q, boundary_check=(0, 1))\n b_do = tl.load(p_do, boundary_check=(0, 1)).to(b_q.dtype)\n b_q = (b_q * scale).to(b_q.dtype)\n b_dq = tl.zeros([BTL, BK], dtype=tl.float32)\n p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, DK),\n (s_qk_t, s_qk_d), (0, i_k * BK), (BTS, BK), (1, 0))\n p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (DV, T),\n (s_vo_d, s_vo_t), (i_v * BV, 0), (BV, BTS), (0, 1))\n p_dz = dz + i_bh * T + i_c * BTL + tl.arange(0, BTL)\n b_dz = tl.load(p_dz, mask=(i_c * BTL + tl.arange(0, BTL)) < T)\n\n for _ in range(0, i_c * BTL, BTS):\n b_k = tl.load(p_k, boundary_check=(0, 1))\n b_v = tl.load(p_v, boundary_check=(0, 1))\n b_ds = tl.dot(b_do, b_v, allow_tf32=False)\n if i_v == 0:\n b_ds += b_dz[:, None]\n b_s = tl.dot(b_q, tl.trans(b_k), allow_tf32=False)\n b_dq += tl.dot((b_ds * (1 + b_s)).to(b_v.dtype), b_k, allow_tf32=False)\n p_k = tl.advance(p_k, (BTS, 0))\n p_v = tl.advance(p_v, (0, BTS))\n\n b_dq *= scale\n o_q = tl.arange(0, BTL)\n o_k = tl.arange(0, BTS)\n p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, DK),\n (s_qk_t, s_qk_d), (i_c * BTL, i_k * BK), (BTS, BK), (1, 0))\n p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (DV, T),\n (s_vo_d, s_vo_t), (i_v * BV, i_c * BTL), (BV, BTS), (0, 1))\n for _ in range(i_c * BTL, (i_c + 1) * BTL, BTS):\n b_k = tl.load(p_k, boundary_check=(0, 1))\n b_v = tl.load(p_v, boundary_check=(0, 1))\n m_s = o_q[:, None] >= o_k[None, :]\n b_ds = tl.dot(b_do, b_v, allow_tf32=False)\n if i_v == 0:\n b_ds += b_dz[:, None]\n b_ds = tl.where(m_s, b_ds, 0) * scale\n b_s = tl.dot(b_q, tl.trans(b_k), allow_tf32=False)\n b_s = tl.where(m_s, b_s, 0)\n b_dq += tl.dot((b_ds + b_ds * b_s).to(b_k.dtype), b_k, allow_tf32=False)\n p_k = tl.advance(p_k, (BTS, 0))\n p_v = tl.advance(p_v, (0, BTS))\n o_k += BTS\n p_dq = tl.make_block_ptr(dq + (i_bh + B * H * i_v) * s_qk_h, (T, DK),\n (s_qk_t, s_qk_d), (i_c*BTL, i_k*BK), (BTL, BK), (1, 0))\n tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))\n return\n\n\n@triton.jit\ndef _parallel_based_bwd_dkv(\n i_bh, i_c, i_k, i_v, i_h,\n q, k, v, do, dz, dk, dv, s_qk_h, s_qk_t, s_qk_d, s_vo_h,\n s_vo_t, s_vo_d, B, H, T, scale,\n BTL: tl.constexpr, BTS: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr,\n DK: tl.constexpr, DV: tl.constexpr,\n):\n p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d),\n (i_c * BTL, i_k * BK), (BTL, BK), (1, 0))\n p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d),\n (i_c * BTL, i_v * BV), (BTL, BV), (1, 0))\n b_k, b_v = tl.load(p_k, boundary_check=(0, 1)), tl.load(\n p_v, boundary_check=(0, 1))\n b_dk, b_dv = tl.zeros([BTL, BK], dtype=tl.float32), tl.zeros(\n [BTL, BV], dtype=tl.float32)\n\n for i in range((tl.cdiv(T, BTS) * BTS)-BTS, (i_c + 1) * BTL - BTS, -BTS):\n p_q = tl.make_block_ptr(\n q + i_bh * s_qk_h, (DK, T), (s_qk_d, s_qk_t), (i_k * BK, i), (BK, BTS), (0, 1))\n p_do = tl.make_block_ptr(\n do + i_bh * s_vo_h, (DV, T), (s_vo_d, s_vo_t), (i_v * BV, i), (BV, BTS), (0, 1))\n p_dz = dz + i_bh * T + i + tl.arange(0, BTS)\n b_q = tl.load(p_q, boundary_check=(0, 1))\n b_do = tl.load(p_do, boundary_check=(0, 1)).to(b_q.dtype)\n b_dz = tl.load(p_dz, mask=(i + tl.arange(0, BTS)) < T)\n b_s = tl.dot(b_k.to(b_q.dtype), b_q, allow_tf32=False) * scale\n b_s2 = 1 + b_s + 0.5 * b_s * b_s\n b_dv += tl.dot(b_s2.to(b_q.dtype), tl.trans(b_do), allow_tf32=False)\n b_ds = tl.dot(b_v, b_do, allow_tf32=False) * scale\n if i_v == 0:\n b_ds += b_dz[None, :] * scale\n b_dk += tl.dot((b_ds + b_ds * b_s).to(b_q.dtype),\n tl.trans(b_q), allow_tf32=False)\n\n tl.debug_barrier()\n o_q, o_k = tl.arange(0, BTS), tl.arange(0, BTL)\n for i in range(i_c*BTL, (i_c+1)*BTL, BTS):\n p_q = tl.make_block_ptr(\n q + i_bh * s_qk_h, (DK, T), (s_qk_d, s_qk_t), (i_k * BK, i), (BK, BTS), (0, 1))\n p_do = tl.make_block_ptr(\n do + i_bh * s_vo_h, (DV, T), (s_vo_d, s_vo_t), (i_v * BV, i), (BV, BTS), (0, 1))\n p_dz = dz + i_bh * T + i + tl.arange(0, BTS)\n b_q = tl.load(p_q, boundary_check=(0, 1))\n b_do = tl.load(p_do, boundary_check=(0, 1)).to(b_q.dtype)\n b_dz = tl.load(p_dz, mask=(i + tl.arange(0, BTS)) < T)\n m_s = o_k[:, None] <= o_q[None, :]\n b_s = tl.dot(b_k, b_q, allow_tf32=False) * scale\n b_s2 = 1 + b_s + 0.5 * b_s * b_s\n b_s = tl.where(m_s, b_s, 0)\n b_s2 = tl.where(m_s, b_s2, 0)\n\n b_ds = tl.dot(b_v, b_do, allow_tf32=False)\n if i_v == 0:\n b_ds += b_dz[None, :]\n b_ds = tl.where(m_s, b_ds, 0) * scale\n b_dv += tl.dot(b_s2.to(b_q.dtype), tl.trans(b_do), allow_tf32=False)\n b_dk += tl.dot((b_ds + b_ds * b_s).to(b_q.dtype),\n tl.trans(b_q), allow_tf32=False)\n o_q += BTS\n\n p_dk = tl.make_block_ptr(dk + (i_bh + B * H * i_v) * s_qk_h,\n (T, DK), (s_qk_t, s_qk_d), (i_c*BTL, i_k*BK), (BTL, BK), (1, 0))\n p_dv = tl.make_block_ptr(dv + (i_bh + B * H * i_k) * s_vo_h,\n (T, DV), (s_vo_t, s_vo_d), (i_c*BTL, i_v*BV), (BTL, BV), (1, 0))\n tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))\n tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))\n return\n\n\n@triton.jit\ndef parallel_based_bwd_kernel(\n q, k, v, do, dz, dq, dk, dv, s_qk_h, s_qk_t, s_qk_d, s_vo_h,\n s_vo_t, s_vo_d, B, H, T, scale,\n BTL: tl.constexpr, BTS: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr,\n DK: tl.constexpr, DV: tl.constexpr,\n):\n i_kv, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n NV = tl.cdiv(DV, BV)\n i_k = i_kv // NV\n i_v = i_kv % NV\n i_h = i_bh % H\n _parallel_based_bwd_dq(\n i_bh, i_c, i_k, i_v, i_h,\n q, k, v, do, dz, dq, s_qk_h, s_qk_t, s_qk_d, s_vo_h,\n s_vo_t, s_vo_d, B, H, T, scale, BTL=BTL, BTS=BTS, BK=BK, BV=BV, DK=DK, DV=DV\n )\n tl.debug_barrier()\n _parallel_based_bwd_dkv(\n i_bh, i_c, i_k, i_v, i_h,\n q, k, v, do, dz, dk, dv, s_qk_h, s_qk_t, s_qk_d, s_vo_h,\n s_vo_t, s_vo_d, B, H, T, scale, BTL, BTS, BK, BV, DK, DV\n )\n\n\nclass ParallelBasedFunction(torch.autograd.Function):\n @staticmethod\n @custom_fwd\n def forward(ctx, q, k, v, scale):\n BTL, BTS = 128, 32\n assert BTL % BTS == 0\n BK = min(128, triton.next_power_of_2(k.shape[-1]))\n BV = min(128, triton.next_power_of_2(v.shape[-1]))\n BK, BV = max(BK, 16), max(BV, 16)\n batch_size, n_heads, seq_len, d_head_qk = q.shape\n d_head_v = v.shape[-1]\n num_stages = 2\n num_warps = 4\n NK = triton.cdiv(d_head_qk, BK)\n NV = triton.cdiv(d_head_v, BV)\n grid = (NK * NV, triton.cdiv(seq_len, BTL), batch_size * n_heads)\n\n assert NK == 1, \"will encounter some synchronization issue if not.\"\n\n o = torch.empty(NK, batch_size, n_heads, seq_len,\n d_head_v, device=q.device)\n z = torch.empty(NK, batch_size, n_heads, seq_len,\n device=q.device)\n parallel_based_fwd_kernel[grid](\n q, k, v, o, z,\n q.stride(1), q.stride(2), q.stride(3),\n v.stride(1), v.stride(2), v.stride(3),\n batch_size, n_heads, seq_len, scale,\n BTL=BTL, BTS=BTS, BK=BK, BV=BV, DK=d_head_qk, DV=d_head_v,\n num_warps=num_warps,\n num_stages=num_stages\n )\n ctx.save_for_backward(q, k, v)\n ctx.scale = scale\n return o.sum(0).to(q.dtype), z.sum(0).to(q.dtype)\n\n @staticmethod\n @custom_bwd\n def backward(ctx, do, dz):\n q, k, v = ctx.saved_tensors\n scale = ctx.scale\n BTL, BTS = 64, 32\n assert BTL % BTS == 0\n BK = min(128, triton.next_power_of_2(k.shape[-1]))\n BV = min(128, triton.next_power_of_2(v.shape[-1]))\n BK, BV = max(BK, 16), max(BV, 16)\n batch_size, n_heads, seq_len, d_head_qk = q.shape\n d_head_v = v.shape[-1]\n num_stages = 2\n num_warps = 4\n NK = triton.cdiv(d_head_qk, BK)\n NV = triton.cdiv(d_head_v, BV)\n grid = (NK * NV, triton.cdiv(seq_len, BTL), batch_size * n_heads)\n\n assert NK == 1, \"will encounter some synchronization issue if not\"\n\n dq = torch.empty(NV, batch_size, n_heads, seq_len,\n d_head_qk, dtype=q.dtype, device=q.device)\n dk = torch.empty(NV, batch_size, n_heads, seq_len,\n d_head_qk, dtype=q.dtype, device=q.device)\n dv = torch.empty(NK, batch_size, n_heads, seq_len,\n d_head_v, dtype=q.dtype, device=q.device)\n\n parallel_based_bwd_kernel[grid](\n q, k, v, do, dz, dq, dk, dv,\n q.stride(1), q.stride(2), q.stride(3),\n v.stride(1), v.stride(2), v.stride(3),\n batch_size, n_heads, seq_len, scale,\n BTL=BTL, BTS=BTS, BK=BK, BV=BV, DK=d_head_qk, DV=d_head_v,\n num_warps=num_warps,\n num_stages=num_stages\n )\n\n return dq.sum(0).to(q.dtype), dk.sum(0).to(k.dtype), dv.sum(0).to(v.dtype), None\n\n\ntriton_parallel_based = ParallelBasedFunction.apply\n\ndef parallel_based(q, k, v, use_scale=True, use_normalize=True, return_both=False):\n assert q.shape[-1] <= 128, \"only support feature dim up to 128\"\n if use_scale:\n scale = q.shape[-1] ** -0.5\n else:\n scale = 1\n o, z = triton_parallel_based(q, k, v, scale)\n if return_both:\n return o, z\n if use_normalize:\n o = o / (z[..., None] + 1e-6)\n return o.to(q.dtype)\n", - "description_1": "Use triton language to implement a parallel-based forward and backward pass for a sequence mixer, utilizing kernels with various block sizes for different dimensions, and using scaled dot-product attention. Parameters: q, k, v tensors for query, key, value, output and normalization tensors, stride sizes, batch size B, number of heads H, sequence length T, scaling factor, block sizes along sequence, K and V dimensions, and dimension sizes D_head_K and D_head_V.", - "description_2": "Use triton language to create a function that applies sequence mixing with attention, providing both forward and backward computations, and allowing optional normalization and scaling.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef fused_recurrent_fwd_kernel(\n q, k, v, beta, o, initial_state, final_state,\n s_qk_h, s_qk_t, s_qk_d,\n s_vo_h, s_vo_t, s_vo_d,\n B, H, T, scale,\n BK: tl.constexpr, BV: tl.constexpr, DK: tl.constexpr, DV: tl.constexpr,\n USE_INITIAL_STATE: tl.constexpr, STORE_FINAL_STATE: tl.constexpr,\n):\n # Kernel implementation is here\n pass # Placeholder for the kernel body\n\n@triton.jit\ndef fused_recurrent_bwd_kernel(\n q, k, v, beta, do, dq, dk, dv, dbeta, initial_state,\n s_qk_h, s_qk_t, s_qk_d,\n s_vo_h, s_vo_t, s_vo_d,\n B, H, T, scale,\n BK: tl.constexpr, BV: tl.constexpr, DK: tl.constexpr, DV: tl.constexpr,\n USE_INITIAL_STATE: tl.constexpr,\n):\n # Kernel implementation is here\n pass # Placeholder for the kernel body\n\nclass FusedRecurrentFunction(torch.autograd.Function):\n\n @staticmethod\n def forward(ctx, q, k, v, beta, initial_state=None, output_final_state=False):\n batch_size, n_heads, seq_len, d_head_qk = q.shape\n d_head_v = v.shape[-1]\n scale = d_head_qk ** -0.5\n BK, BV = d_head_qk, min(d_head_v, 8)\n NK, NV = triton.cdiv(d_head_qk, BK), triton.cdiv(d_head_v, BV)\n num_stages = 1\n num_warps = 1\n assert NK == 1, \"NK > 1 is not supported yet\"\n o = q.new_empty(NK, batch_size, n_heads, seq_len, d_head_v)\n\n if output_final_state:\n final_state = q.new_empty(batch_size, n_heads, d_head_qk, d_head_v)\n else:\n final_state = None\n\n grid = (NV, NK, batch_size * n_heads)\n fused_recurrent_fwd_kernel[grid](\n q, k, v, beta, o, initial_state, final_state,\n q.stride(1), q.stride(2), q.stride(3),\n v.stride(1), v.stride(2), v.stride(3),\n batch_size, n_heads, seq_len, scale,\n DK=d_head_qk, DV=d_head_v, BK=BK, BV=BV,\n num_warps=num_warps,\n num_stages=num_stages,\n USE_INITIAL_STATE=initial_state is not None,\n STORE_FINAL_STATE=final_state is not None\n )\n o = o.sum(0)\n ctx.save_for_backward(q, k, v, beta, initial_state)\n return o, final_state\n\n @staticmethod\n def backward(ctx, do, d_final_state=None):\n q, k, v, beta, initial_state = ctx.saved_tensors\n batch_size, n_heads, seq_len, d_head_qk = q.shape\n d_head_v = v.shape[-1]\n scale = d_head_qk ** -0.5\n\n BK, BV = min(d_head_qk, 32), min(d_head_v, 32)\n NK, NV = triton.cdiv(d_head_qk, BK), triton.cdiv(d_head_v, BV)\n num_stages = 1\n num_warps = 1\n\n dq = q.new_empty(NV, batch_size, n_heads, seq_len, d_head_qk)\n dk = q.new_empty(NV, batch_size, n_heads, seq_len, d_head_qk)\n dv = q.new_empty(NK, batch_size, n_heads, seq_len, d_head_v)\n grid = (NV, NK, batch_size * n_heads)\n dbeta = torch.empty_like(beta)\n\n fused_recurrent_bwd_kernel[grid](\n q, k, v, beta, do, dq, dk, dv, dbeta, initial_state,\n q.stride(1), q.stride(2), q.stride(3),\n v.stride(1), v.stride(2), v.stride(3),\n batch_size, n_heads, seq_len, scale,\n DK=d_head_qk, DV=d_head_v, BK=BK, BV=BV,\n num_warps=num_warps,\n num_stages=num_stages,\n USE_INITIAL_STATE=initial_state is not None\n )\n dq = dq.sum(0)\n dk = dk.sum(0)\n dv = dv.sum(0)\n return dq.to(q), dk.to(k), dv.to(v), dbeta.to(beta), None, None\n\ndef fused_recurrent_linear_attn_delta_rule(q, k, v, beta, initial_state=None, output_final_state=False, normalize=False):\n if initial_state is not None:\n initial_state = initial_state.detach()\n o, final_state = FusedRecurrentFunction.apply(\n q, k, v, beta, initial_state, output_final_state)\n if output_final_state:\n return o, final_state\n else:\n return o\n", - "description_1": "Use triton language to implement a fused recurrent forward and backward kernel for attention mechanisms with batch processing. The forward kernel takes in parameters: q, k, v, beta, o, initial_state, final_state, stride sizes, batch size, number of heads, sequence length, scaling factor, block sizes, and constants for using initial state and storing final state. The backward kernel takes in similar parameters but also includes gradients: do, dq, dk, dv, and dbeta. These kernels are called within an autograd function's forward and backward methods, which manage the memory and execute the kernels with the required grid size.", - "description_2": "Use triton language to create fused kernels for forward and backward operations in a recurrent attention network, handling input tensors, gradients, batch size, sequence length, and other parameters, integrated within a custom autograd function.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\nfrom einops import rearrange\n\n@triton.jit\ndef chunk_gla_fwd_kernel(\n k, v, g, h, initial_state, final_state,\n s_qk_h, s_qk_t, s_qk_d,\n s_vo_h, s_vo_t, s_vo_d,\n s_hh, s_ht,\n B, H, T, scale,\n BT: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr,\n DK: tl.constexpr, DV: tl.constexpr,\n USE_INITIAL_STATE: tl.constexpr, STORE_FINAL_STATE: tl.constexpr,\n):\n i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n b_h = tl.zeros([BK, BV], dtype=tl.float32)\n p_db = g + i_bh * s_qk_h + (BT - 1) * s_qk_t + i_k * BK + tl.arange(0, BK)\n p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (DK, T), (s_qk_d, s_qk_t), (i_k * BK, 0), (BK, BT), (0, 1))\n p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (0, i_v * BV), (BT, BV), (1, 0))\n if USE_INITIAL_STATE:\n p_h = tl.make_block_ptr(initial_state + i_bh * DK * DV, (DK, DV), (DV, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))\n b_h += tl.load(p_h, boundary_check=(0, 1)).to(tl.float32)\n for i in range(0, tl.cdiv(T, BT)):\n b_k = tl.load(p_k, boundary_check=(0, 1))\n b_v = tl.load(p_v, boundary_check=(0, 1))\n d_b = tl.load(p_db).to(tl.float32)\n p_h = tl.make_block_ptr(h + i_bh * s_hh, ((i+1)*DK, DV), (s_ht, 1), (i*DK+i_k*BK, i_v * BV), (BK, BV), (1, 0))\n tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1))\n b_h *= tl.math.exp2(d_b)[:, None]\n b_h += tl.dot(b_k.to(b_v.dtype), b_v, allow_tf32=False)\n p_k = tl.advance(p_k, (0, BT))\n p_v = tl.advance(p_v, (BT, 0))\n p_db += BT * DK\n if STORE_FINAL_STATE:\n p_final = tl.make_block_ptr(final_state + i_bh * DK * DV, (DK, DV), (DV, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))\n tl.store(p_final, b_h.to(p_final.dtype.element_ty), boundary_check=(0, 1))\n\n@triton.jit\ndef chunk_gla_bwd_kernel(\n q, g, do, dh,\n s_qk_h, s_qk_t, s_qk_d,\n s_vo_h, s_vo_t, s_vo_d,\n s_hh, s_ht,\n B, H, T, TDK, scale,\n BT: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr,\n DK: tl.constexpr, DV: tl.constexpr,\n):\n i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n b_dh = tl.zeros([BK, BV], dtype=tl.float32)\n mask = (i_k * BK + tl.arange(0, BK)[:, None] < DK) & (i_v * BV + tl.arange(0, BV)[None, :] < DV)\n p_dh = dh + i_bh * s_hh + (TDK - DK + i_k * BK + tl.arange(0, BK)[:, None]) * DV + i_v * BV + tl.arange(0, BV)[None, :]\n for i in range((tl.cdiv(T, BT) - 1) * BT, -BT, -BT):\n p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (DK, T), (s_qk_d, s_qk_t), (i_k * BK, i), (BK, BT), (0, 1))\n p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (i, i_v * BV), (BT, BV), (1, 0))\n p_db = g + i_bh * s_qk_h + (i + BT - 1) * s_qk_t + i_k * BK + tl.arange(0, BK)\n d_b = tl.math.exp2(tl.load(p_db).to(tl.float32))\n tl.store(p_dh, b_dh.to(p_dh.dtype.element_ty), mask=mask)\n b_q = tl.load(p_q, boundary_check=(0, 1))\n b_do = tl.load(p_do, boundary_check=(0, 1))\n b_dh = d_b[:, None] * b_dh + tl.dot(b_q, b_do, allow_tf32=False)\n p_dh -= DK * DV\n\nclass ChunkGLAFunction(torch.autograd.Function):\n @staticmethod\n def forward(ctx, q, k, v, g, scale, initial_state, output_final_state):\n ctx.g_dtype = g.dtype\n g = torch.empty_like(g, dtype=torch.float32)\n ctx.scale = scale\n B, H, T, DK, DV = *q.shape, v.shape[-1]\n BT = 64\n BK, BV = min(DK, 64), min(DV, 64)\n NT, NK, NV = triton.cdiv(T, BT), triton.cdiv(DK, BK), triton.cdiv(DV, BV)\n q_g = torch.empty_like(q)\n k_g = torch.empty_like(k)\n grid = (NK, NT, B * H)\n fwd_decay_cumsum[grid](g, q_g, q.stride(1), q.stride(2), q.stride(3), B, H, T, scale, BT=BT, BK=BK, DK=DK, num_warps=1)\n prepare_qg_kg[grid](q, k, g, q_g, k_g, q.stride(1), q.stride(2), q.stride(3), B, H, T, scale, BT=BT, BK=BK, DK=DK, num_warps=1)\n if output_final_state:\n final_state = q.new_empty(B, H, DK, DV, dtype=torch.float32, requires_grad=False)\n else:\n final_state = None\n grid = (NV, NK, B * H)\n h = q.new_empty(B, H, NT * DK, DV)\n chunk_gla_fwd_kernel[grid](\n k_g, v, g, h, initial_state, final_state,\n q.stride(1), q.stride(2), q.stride(3),\n v.stride(1), v.stride(2), v.stride(3),\n h.stride(1), h.stride(2),\n B, H, T, scale,\n BT=BT, DK=DK, DV=DV, BK=BK, BV=BV,\n USE_INITIAL_STATE=initial_state is not None,\n STORE_FINAL_STATE=output_final_state,\n num_warps=4, num_stages=3\n )\n o = rearrange(q_g, 'b h (n c) d -> b h n c d', c=BT) @ rearrange(h, 'b h (n c) d -> b h n c d', c=DK)\n o = rearrange(o, 'b h n c d -> b h (n c) d')\n v2 = rearrange(v, 'b h (n c) d -> b h n c d', n=NT)\n BK = min(DK, 128)\n NK = triton.cdiv(DK, BK)\n A = q.new_zeros(NK, B, H, NT, BT, BT)\n BC = 16\n NC = BT // BC\n grid = (NK, NT * NC * NC, B * H)\n fwd_inner_chunk[grid](\n q, k, g, A,\n q.stride(1), q.stride(2), q.stride(3),\n A.stride(2), A.stride(3), A.stride(4),\n scale,\n T=T, DK=DK, BT=BT, BC=BC, BK=BK, NC=NC,\n num_warps=4, num_stages=3\n )\n A = A.sum(0)\n o2 = A @ v2\n o2 = rearrange(o2, 'b h n c d -> b h (n c) d')\n o2 += o\n ctx.save_for_backward(q, k, v, g, A, initial_state, h)\n return o2.to(v), final_state\n\n @staticmethod\n def backward(ctx, do, d_final_state=None):\n q, k, v, g, A, initial_state, h = ctx.saved_tensors\n batch_size, n_heads, seq_len, d_head_qk = q.shape\n d_head_v = v.shape[-1]\n scale = ctx.scale\n BT = 64\n g = torch.empty_like(g, dtype=torch.float32)\n BK, BV = min(d_head_qk, 64), min(d_head_v, 64)\n NK, NV = triton.cdiv(d_head_qk, BK), triton.cdiv(d_head_v, BV)\n q_g = torch.empty_like(q)\n k_g = torch.empty_like(k)\n grid = (NK, triton.cdiv(seq_len, BT), batch_size * n_heads)\n fwd_decay_cumsum[grid](g, q_g, q.stride(1), q.stride(2), q.stride(3), batch_size, n_heads, seq_len, scale, BT=BT, BK=BK, DK=d_head_qk, num_warps=1)\n prepare_qg_kg[grid](q, k, g, q_g, k_g, q.stride(1), q.stride(2), q.stride(3), batch_size, n_heads, seq_len, scale, BT=BT, BK=BK, DK=d_head_qk, num_warps=1)\n dq = rearrange_back(rearrange_chunk(do, BT) @ rearrange_chunk(h, d_head_qk).transpose(-1, -2)) * scale\n grid = (NV, NK, batch_size * n_heads)\n dh = torch.empty_like(h)\n chunk_gla_bwd_kernel[grid](\n q_g, g, do, dh,\n q.stride(1), q.stride(2), q.stride(3),\n v.stride(1), v.stride(2), v.stride(3),\n h.stride(1), h.stride(2),\n batch_size, n_heads, seq_len, dh.shape[-2], scale,\n BT=BT, DK=d_head_qk, DV=d_head_v, BK=BK, BV=BV,\n num_warps=4, num_stages=1\n )\n dh = rearrange_chunk(dh, d_head_qk)\n dk = rearrange_back(torch.einsum('b h n k v, b h n c v -> b h n c k', dh, rearrange_chunk(v, BT)))\n dv = rearrange_back(torch.einsum('b h n k v, b h n c k -> b h n c v', dh, rearrange_chunk(k_g, BT)))\n num_chunk = seq_len // BT\n v2 = rearrange(v, 'b h (n c) d -> b h n c d', n=num_chunk)\n do2 = rearrange(do, 'b h (n c) d -> b h n c d', n=num_chunk)\n dA2 = (do2 @ v2.transpose(-2, -1)) * scale\n dv2 = A.transpose(-1, -2) @ do2\n dv2 = rearrange(dv2, 'b h n c d -> b h (n c) d', n=num_chunk)\n BK = min(d_head_qk, 64)\n NK = triton.cdiv(d_head_qk, BK)\n dk2 = torch.empty_like(k)\n dq2 = torch.empty_like(q)\n BC = 16\n grid = (BT // BC, triton.cdiv(seq_len, BT), batch_size * n_heads)\n bwd_inner_chunk[grid](\n q, k, g, dA2,\n dq2, dk2,\n q.stride(1), q.stride(2), q.stride(3),\n A.stride(1), A.stride(2), A.stride(3),\n batch_size, n_heads, seq_len, scale,\n BT=BT, BK=BK, BC=BC, DK=d_head_qk,\n num_stages=4, num_warps=4\n )\n dg = torch.empty_like(g, dtype=torch.float32)\n grid = (NK, triton.cdiv(seq_len, BT), batch_size * n_heads)\n bwd_decay_global_cumsum[grid](dq2, dq, dk2, dk, q, k, g, dg, q.stride(1), q.stride(2), q.stride(3), batch_size, n_heads, seq_len, scale, BT=BT, DK=d_head_qk, BK=BK, num_warps=1, num_stages=1)\n dg = rearrange(dg, 'b h (n c) d -> b h n c d', c=BT)\n def rev_cumsum_exclusive(x):\n cumsum_x = x.cumsum(-2)\n rev_cumsum_x = cumsum_x[..., -1, None, :] - cumsum_x\n return rev_cumsum_x\n rev_cumsum_dg = rev_cumsum_exclusive(dg[..., 0, :])\n dg.add_(rev_cumsum_dg.unsqueeze(-2))\n dv.add_(dv2)\n dg = rearrange(dg, 'b h n c d -> b h (n c) d')\n return dq.to(q), dk.to(k), dv.to(v), dg.to(ctx.g_dtype), None, None, None\n\ndef chunk_gla(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, g: torch.Tensor, scale: int = -1, initial_state: torch.Tensor = None, output_final_state: bool = False):\n if scale == -1:\n scale = q.shape[-1] ** -0.5\n if initial_state is not None:\n initial_state = initial_state.detach()\n seq_len = v.shape[-2]\n d_head_v = v.shape[-1]\n q, k, v, g = map(lambda x: pad(x), [q, k, v, g])\n o, final_state = ChunkGLAFunction.apply(q, k, v, g, scale, initial_state, output_final_state)\n o = o[..., :seq_len, :d_head_v]\n if output_final_state:\n return o, final_state\n return o\n", - "description_1": "Use triton language to implement kernels for forward and backward passes of a gated linear attention mechanism, with parameters: k, v, g, h, q, scales, strides, dimensions and block sizes. The forward kernel computes the cumulative decay and performs matrix multiplications, storing results in h and optionally final_state. The backward kernel computes gradients by iterating through blocks in reverse order.", - "description_2": "Use triton language to create a forward kernel for a gated linear attention that performs sequential decay and accumulates matrix products, and a backward kernel that computes gradients for q, k, v, and g with efficient memory handling.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\nfrom einops import rearrange\nfrom packaging import version\n\n@triton.jit\ndef fused_chunk_gla_fwd_kernel(\n q, k, v, g, o,\n initial_state, final_state,\n s_qk_h, s_qk_t, s_qk_d,\n s_vo_h, s_vo_t, s_vo_d,\n B, H, T, scale,\n BT: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr,\n DK: tl.constexpr, DV: tl.constexpr,\n USE_INITIAL_STATE: tl.constexpr,\n STORE_FINAL_STATE: tl.constexpr,\n CHECK: tl.constexpr\n):\n i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n\n b_h = tl.zeros([BK, BV], dtype=tl.float32)\n\n p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (0, i_k * BK), (BT, BK), (1, 0))\n p_db = g + i_bh * s_qk_h + (BT - 1) * s_qk_t + i_k * BK + tl.arange(0, BK)\n p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (DK, T), (s_qk_d, s_qk_t), (i_k * BK, 0), (BK, BT), (0, 1))\n p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (0, i_v * BV), (BT, BV), (1, 0))\n p_o = tl.make_block_ptr(o + (i_bh + i_k * B * H) * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (0, i_v * BV), (BT, BV), (1, 0))\n\n if USE_INITIAL_STATE:\n p_h = tl.make_block_ptr(initial_state + i_bh * DK * DV, (DK, DV), (DV, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))\n b_h += tl.load(p_h, boundary_check=(0, 1)).to(tl.float32)\n\n for i in range(0, tl.cdiv(T, BT)):\n b_k = tl.load(p_k, boundary_check=(0, 1))\n b_o = tl.zeros([BT, BV], dtype=tl.float32)\n b_v = tl.load(p_v, boundary_check=(0, 1))\n b_q = tl.load(p_q, boundary_check=(0, 1))\n\n d_b = tl.load(p_db).to(tl.float32)\n if CHECK and i == 0:\n b_o = tl.dot(b_q.to(b_v.dtype), b_h.to(b_v.dtype), allow_tf32=False)\n b_h = b_h * tl.math.exp2(d_b)[:, None] + tl.dot(b_k.to(b_v.dtype), b_v, allow_tf32=False)\n else:\n b_o = tl.dot(b_q.to(b_v.dtype), b_h.to(b_v.dtype), allow_tf32=False)\n b_h = b_h * tl.math.exp2(d_b)[:, None] + tl.dot(b_k.to(b_v.dtype), b_v, allow_tf32=False)\n\n tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))\n p_q = tl.advance(p_q, (BT, 0))\n p_k = tl.advance(p_k, (0, BT))\n p_v = tl.advance(p_v, (BT, 0))\n p_o = tl.advance(p_o, (BT, 0))\n p_db += BT * DK\n\n if STORE_FINAL_STATE:\n p_final = tl.make_block_ptr(final_state + i_bh * DK * DV, (DK, DV), (DV, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))\n tl.store(p_final, b_h.to(p_final.dtype.element_ty), boundary_check=(0, 1))\n\n@triton.jit\ndef fused_chunk_gla_bwd_kernel(\n q, k, v, g,\n do, dq, dk, dv,\n initial_state,\n s_qk_h, s_qk_t, s_qk_d,\n s_vo_h, s_vo_t, s_vo_d,\n B, H, T, scale,\n BT: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr,\n DK: tl.constexpr, DV: tl.constexpr,\n USE_INITIAL_STATE: tl.constexpr,\n CHECK: tl.constexpr\n):\n i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n b_h = tl.zeros([BV, BK], dtype=tl.float32)\n\n if USE_INITIAL_STATE:\n p_h = tl.make_block_ptr(initial_state + i_bh * DK * DV, (DV, DK), (1, DV), (i_v * BV, i_k * BK), (BV, BK), (0, 1))\n b_h += tl.load(p_h, boundary_check=(0, 1)).to(tl.float32)\n\n for i in range(0, tl.cdiv(T, BT)):\n p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (i * BT, i_k * BK), (BT, BK), (1, 0))\n p_db = g + i_bh * s_qk_h + ((i+1) * BT - 1) * s_qk_t + i_k * BK + tl.arange(0, BK)\n p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (DV, T), (s_vo_d, s_vo_t), (i_v * BV, i * BT), (BV, BT), (0, 1))\n p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (i * BT, i_v * BV), (BT, BV), (1, 0))\n p_dq = tl.make_block_ptr(dq + (i_bh+i_v*B*H)*s_qk_h, (T, DK), (s_qk_t, s_qk_d), (i * BT, i_k * BK), (BT, BK), (1, 0))\n b_dq = tl.zeros([BT, BK], dtype=tl.float32)\n b_k = tl.load(p_k, boundary_check=(0, 1))\n d_b = tl.load(p_db).to(tl.float32)\n\n b_v = tl.load(p_v, boundary_check=(0, 1))\n b_do = tl.load(p_do, boundary_check=(0, 1))\n if CHECK and i == 0:\n b_dq += tl.dot(b_do, b_h.to(b_do.dtype), allow_tf32=False)\n b_h = b_h * tl.math.exp2(d_b)[None, :] + tl.dot(b_v, b_k.to(b_v.dtype), allow_tf32=False)\n else:\n b_dq += tl.dot(b_do, b_h.to(b_do.dtype), allow_tf32=False)\n b_h = b_h * tl.math.exp2(d_b)[None, :] + tl.dot(b_v, b_k.to(b_v.dtype), allow_tf32=False)\n b_dq *= scale\n tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))\n\n b_h = None\n tl.debug_barrier()\n b_dh = tl.zeros([BK, BV], dtype=tl.float32)\n\n for i in range(1, tl.cdiv(T, BT) + 1):\n p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (DK, T), (s_qk_d, s_qk_t), (i_k * BK, T - i * BT), (BK, BT), (0, 1))\n p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (T - i * BT, i_k * BK), (BT, BK), (1, 0))\n p_db = g + i_bh * s_qk_h + (T - (i-1) * BT - 1) * s_qk_t + i_k * BK + tl.arange(0, BK)\n p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (T - i * BT, i_v * BV), (BT, BV), (1, 0))\n p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (T - i * BT, i_v * BV), (BT, BV), (1, 0))\n p_dk = tl.make_block_ptr(dk + (i_bh + i_v * B * H) * s_qk_h, (T, DK),\n (s_qk_t, s_qk_d), (T - i * BT, i_k * BK), (BT, BK), (1, 0))\n p_dv = tl.make_block_ptr(dv + (i_bh + i_k * B * H) * s_vo_h, (T, DV),\n (s_vo_t, s_vo_d), (T - i * BT, i_v * BV), (BT, BV), (1, 0))\n b_q = tl.load(p_q, boundary_check=(0, 1))\n b_k = tl.load(p_k, boundary_check=(0, 1))\n b_v = tl.load(p_v, boundary_check=(0, 1))\n b_do = tl.load(p_do, boundary_check=(0, 1))\n b_db = tl.load(p_db).to(tl.float32)\n\n if CHECK and i == 1:\n b_dk = tl.trans(tl.dot(b_dh.to(b_v.dtype), tl.trans(b_v), allow_tf32=False))\n b_dv = tl.dot((b_k).to(b_v.dtype), b_dh.to(b_v.dtype), allow_tf32=False)\n b_dh = b_dh * tl.math.exp2(b_db)[:, None] + tl.dot(b_q.to(b_do.dtype), b_do, allow_tf32=False)\n else:\n b_dk = tl.trans(tl.dot(b_dh.to(b_v.dtype), tl.trans(b_v), allow_tf32=False))\n b_dv = tl.dot((b_k).to(b_v.dtype), b_dh.to(b_v.dtype), allow_tf32=False)\n b_dh = b_dh * tl.math.exp2(b_db)[:, None] + tl.dot(b_q.to(b_do.dtype), b_do, allow_tf32=False)\n\n tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))\n tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))\n\ndef fused_chunk_gla(\n q: torch.Tensor,\n k: torch.Tensor,\n v: torch.Tensor,\n g: torch.Tensor,\n scale: int = -1,\n initial_state: torch.Tensor = None,\n output_final_state: bool = False\n):\n if scale == -1:\n scale = q.shape[-1] ** -0.5\n if initial_state is not None:\n initial_state = initial_state.detach()\n seq_len = v.shape[-2]\n d_head_v = v.shape[-1]\n q, k, v, g = map(lambda x: pad(x), [q, k, v, g])\n o, final_state = FusedChunkGLAFunction.apply(\n q, k, v, g, scale, initial_state, output_final_state)\n o = o[..., :seq_len, :d_head_v]\n if output_final_state:\n return o, final_state\n return o\n", - "description_1": "Use triton language to implement two kernels: fused_chunk_gla_fwd_kernel and fused_chunk_gla_bwd_kernel. The forward kernel computes a fused forward pass for a Gated Linear Attention mechanism across multiple batches, heads, and sequence lengths. It takes input queries, keys, values, cumulative sums, and initial states, and outputs an attention-modulated output and final states. The backward kernel computes the gradient of the forward pass, taking gradients of outputs and returning gradients for queries, keys, values, and cumulative sums. Both kernels use triton's advanced block pointer and boundary-checking operations to efficiently handle large matrix computations. Each kernel function has 26 parameters: the main tensor inputs/outputs, strides for accessing tensors, batch, head, and sequence dimensions, scaling factor, block sizes (chunks along sequence, key, and value dimensions), dimensional sizes for key and value heads, and boolean flags indicating whether to use initial state, store final state, and perform boundary checks.", - "description_2": "Use triton language to implement fused forward and backward kernels for Gated Linear Attention mechanism with tensor inputs and boundary checks.", - "difficulty": 3 - }, - { - "code": "import triton\nimport triton.language as tl\n\ninv_ln2 = 1.44269504\n\n# Kernel to compute forward decay cumulative sum\n@triton.jit\ndef fwd_decay_cumsum(\n g,\n g_o, \n s_qk_h,\n s_qk_t,\n s_qk_d,\n B,\n H,\n T,\n scale,\n BT: tl.constexpr,\n BK: tl.constexpr,\n DK: tl.constexpr\n):\n i_k, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n p_g = g + i_bh * s_qk_h + i_c * BT * DK + i_k * BK + tl.arange(0, BK)\n p_go = g_o + i_bh * s_qk_h + i_c * BT * DK + i_k * BK + tl.arange(0, BK)\n cum_decay = tl.zeros([BK], dtype=tl.float32)\n mask = (i_k * BK + tl.arange(0, BK)) < DK\n\n for i in range(BT):\n _g = tl.load(p_g, mask=mask, other=0).to(tl.float32)\n cum_decay += _g * inv_ln2\n tl.store(p_go, cum_decay.to(p_go.dtype.element_ty), mask=mask)\n p_g += DK\n p_go += DK\n\n# Kernel to prepare qg and kg\n@triton.jit\ndef prepare_qg_kg(\n q,\n k,\n g,\n qg,\n kg,\n s_qk_h,\n s_qk_t,\n s_qk_d,\n B,\n H,\n T,\n scale,\n BT: tl.constexpr,\n BK: tl.constexpr,\n DK: tl.constexpr\n):\n i_k, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n p_q = q + i_bh * s_qk_h + i_c * BT * DK + i_k * BK + tl.arange(0, BK)\n p_g = g + i_bh * s_qk_h + i_c * BT * DK + i_k * BK + tl.arange(0, BK)\n p_k = k + i_bh * s_qk_h + i_c * BT * DK + i_k * BK + tl.arange(0, BK)\n p_qg = qg + i_bh * s_qk_h + i_c * BT * DK + i_k * BK + tl.arange(0, BK)\n p_kg = kg + i_bh * s_qk_h + i_c * BT * DK + i_k * BK + tl.arange(0, BK)\n \n mask = (i_k * BK + tl.arange(0, BK)) < DK\n\n last_decay = tl.load(g + i_bh * s_qk_h + (i_c * BT + BT - 1) * DK + i_k * BK + tl.arange(0, BK))\n\n for i in range(BT):\n _q = tl.load(p_q, mask=mask, other=0)\n _k = tl.load(p_k, mask=mask, other=0)\n _g = tl.load(p_g, mask=mask, other=0).to(tl.float32)\n _q *= tl.math.exp2(_g) * scale\n _k *= tl.math.exp2(last_decay - _g)\n tl.store(p_kg, _k.to(p_kg.dtype.element_ty), mask=mask)\n tl.store(p_qg, _q.to(p_qg.dtype.element_ty), mask=mask)\n p_q += DK\n p_g += DK\n p_k += DK\n p_kg += DK\n p_qg += DK\n\n# Kernel to compute backward decay global cumulative sum\n@triton.jit\ndef bwd_decay_global_cumsum(\n dq_inner,\n dq_inter,\n dk_inner,\n dk_inter,\n q, k, g, dg,\n s_qk_h,\n s_qk_t,\n s_qk_d,\n B,\n H,\n T,\n scale,\n BT: tl.constexpr,\n BK: tl.constexpr,\n DK: tl.constexpr\n):\n i_k, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n p_q = q + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (i_c * BT + BT - 1) * DK\n p_k = k + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (i_c * BT + BT - 1) * DK\n p_g = g + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (i_c * BT + BT - 1) * DK\n p_dg = dg + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (i_c * BT + BT - 1) * DK\n p_dq_inner = dq_inner + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (i_c * BT + BT - 1) * DK\n p_dk_inner = dk_inner + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (i_c * BT + BT - 1) * DK\n p_dq_inter = dq_inter + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (i_c * BT + BT - 1) * DK\n p_dk_inter = dk_inter + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (i_c * BT + BT - 1) * DK\n cum_grad_dg = tl.zeros([BK], dtype=tl.float32)\n mask = (i_k * BK + tl.arange(0, BK)) < DK\n last_g = tl.zeros([BK], dtype=tl.float32)\n for j in range(BT-1, -1, -1):\n _g = tl.load(p_g, mask=mask, other=0).to(tl.float32)\n if j == (BT-1):\n last_g = _g\n _dq1 = tl.load(p_dq_inner, mask=mask, other=0)\n _dq2 = tl.load(p_dq_inter, mask=mask, other=0)\n _dq2 *= tl.math.exp2(_g)\n _dq = _dq1 + _dq2\n tl.store(p_dq_inter, _dq)\n _dk1 = tl.load(p_dk_inner, mask=mask, other=0)\n _dk2 = tl.load(p_dk_inter, mask=mask, other=0)\n _dk2 *= tl.math.exp2(last_g - _g)\n _dk = _dk1 + _dk2\n tl.store(p_dk_inter, _dk)\n _q = tl.load(p_q, mask=mask, other=0)\n _k = tl.load(p_k, mask=mask, other=0)\n _dg = _dq * _q - _dk * _k\n cum_grad_dg += _dg\n tl.store(p_dg, cum_grad_dg.to(p_dg.dtype.element_ty), mask=mask)\n p_g -= DK\n p_k -= DK\n p_q -= DK\n p_dq_inner -= DK\n p_dk_inner -= DK\n p_dq_inter -= DK\n p_dk_inter -= DK\n p_dg -= DK\n", - "description_1": "Use triton language to implement three kernels: fwd_decay_cumsum, prepare_qg_kg, and bwd_decay_global_cumsum. The fwd_decay_cumsum kernel computes a forward decay cumulative sum with 12 parameters, including input tensors, strides, batch size, head size, time steps, scale, and block sizes. The prepare_qg_kg kernel prepares qg and kg tensors with 13 parameters, including input tensors, output tensors, strides, batch size, head size, time steps, scale, and block sizes. The bwd_decay_global_cumsum kernel computes a backward decay global cumulative sum with 16 parameters, including input tensors, output tensors, strides, batch size, head size, time steps, scale, and block sizes.", - "description_2": "Use triton language to implement kernels for forward decay cumulative sum, preparation of qg and kg tensors, and backward decay global cumulative sum with specified parameters.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\nfrom torch.cuda.amp import custom_bwd, custom_fwd\n\n@triton.jit\ndef fused_recurrent_gla_fwd_kernel(\n q, k, v, gk, gv, o, initial_state, final_state,\n s_qk_h, s_qk_t, s_qk_d,\n s_vo_h, s_vo_t, s_vo_d,\n B, H, T, scale,\n BK: tl.constexpr, BV: tl.constexpr, DK: tl.constexpr, DV: tl.constexpr,\n USE_INITIAL_STATE: tl.constexpr, STORE_FINAL_STATE: tl.constexpr,\n REVERSE: tl.constexpr, USE_GK: tl.constexpr, USE_GV: tl.constexpr,\n):\n i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n\n p_q = q + i_bh * s_qk_h + i_k * BK + \\\n tl.arange(0, BK) + ((T-1) * DK if REVERSE else 0)\n p_k = k + i_bh * s_qk_h + i_k * BK + \\\n tl.arange(0, BK) + ((T-1) * DK if REVERSE else 0)\n p_v = v + i_bh * s_vo_h + i_v * BV + \\\n tl.arange(0, BV) + ((T-1) * DV if REVERSE else 0)\n p_o = o + (i_bh + i_k * B * H) * s_vo_h + i_v * BV + \\\n tl.arange(0, BV) + ((T-1) * DV if REVERSE else 0)\n\n if USE_GK:\n p_gk = gk + i_bh * s_qk_h + i_k * BK + \\\n tl.arange(0, BK) + ((T-1) * DK if REVERSE else 0)\n if USE_GV:\n p_gv = gv + i_bh * s_vo_h + i_v * BV + \\\n tl.arange(0, BV) + ((T-1) * DV if REVERSE else 0)\n\n mask_bk = (i_k * BK + tl.arange(0, BK)) < DK\n mask_bv = (i_v * BV + tl.arange(0, BV)) < DV\n\n h = tl.zeros([BV, BK], dtype=tl.float32)\n\n mask_kv = mask_bk[None, :] & mask_bv[:, None]\n\n if USE_INITIAL_STATE:\n p_init_s = initial_state + i_bh * DK * DV + \\\n (i_k * BK + tl.arange(0, BK)[None, :]) * \\\n DV + (i_v * BV + tl.arange(0, BV)[:, None])\n h += tl.load(p_init_s, mask=mask_kv, other=0).to(tl.float32)\n\n for _ in range(0, T):\n _k = tl.load(p_k, mask=mask_bk, other=0).to(tl.float32)\n _v = tl.load(p_v, mask=mask_bv, other=0).to(tl.float32)\n _q = tl.load(p_q, mask=mask_bk, other=0).to(tl.float32) * scale\n if USE_GK:\n _gk = tl.load(p_gk, mask=mask_bk, other=0).to(tl.float32)\n h = h * _gk[None, :]\n if USE_GV:\n _gv = tl.load(p_gv, mask=mask_bv, other=0).to(tl.float32)\n h = h * _gv[:, None]\n h += _k[None, :] * _v[:, None]\n _o = h * _q[None, :]\n _o = tl.sum(_o, axis=1)\n tl.store(p_o, _o.to(p_o.dtype.element_ty), mask=mask_bv)\n p_q += -DK if REVERSE else DK\n p_k += -DK if REVERSE else DK\n p_o += -DV if REVERSE else DV\n p_v += -DV if REVERSE else DV\n if USE_GK:\n p_gk += -DK if REVERSE else DK\n if USE_GV:\n p_gv += -DV if REVERSE else DV\n\n if STORE_FINAL_STATE:\n p_final_s = final_state + i_bh * DK * DV + \\\n (i_k * BK + tl.arange(0, BK)[None, :]) * \\\n DV + (i_v * BV + tl.arange(0, BV)[:, None])\n tl.store(p_final_s, h.to(p_final_s.dtype.element_ty), mask=mask_kv)\n\n\n@triton.jit\ndef fused_recurrent_gla_bwd_kernel(\n q, k, v, gk, gv, do, dq, dk, dv, initial_state,\n s_qk_h, s_qk_t, s_qk_d,\n s_vo_h, s_vo_t, s_vo_d,\n B, H, T, scale,\n BK: tl.constexpr, BV: tl.constexpr, DK: tl.constexpr, DV: tl.constexpr,\n USE_INITIAL_STATE: tl.constexpr, REVERSE: tl.constexpr,\n USE_GK: tl.constexpr, USE_GV: tl.constexpr,\n):\n i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n\n p_q = q + i_bh * s_qk_h + i_k * BK + \\\n tl.arange(0, BK) + ((T-1) * DK if REVERSE else 0)\n p_k = k + i_bh * s_qk_h + i_k * BK + \\\n tl.arange(0, BK) + ((T-1) * DK if REVERSE else 0)\n p_v = v + i_bh * s_vo_h + i_v * BV + \\\n tl.arange(0, BV) + ((T-1) * DV if REVERSE else 0)\n p_do = do + i_bh * s_vo_h + i_v * BV + \\\n tl.arange(0, BV) + ((T-1) * DV if REVERSE else 0)\n p_dq = dq + (i_bh + i_v * B * H) * s_qk_h + i_k * BK + \\\n tl.arange(0, BK) + ((T-1) * DK if REVERSE else 0)\n if USE_GK:\n p_gk = gk + i_bh * s_qk_h + i_k * BK + \\\n tl.arange(0, BK) + ((T-1) * DK if REVERSE else 0)\n if USE_GV:\n p_gv = gv + i_bh * s_vo_h + i_v * BV + \\\n tl.arange(0, BV) + ((T-1) * DV if REVERSE else 0)\n mask_bk = i_k * BK + tl.arange(0, BK) < DK\n mask_bv = i_v * BV + tl.arange(0, BV) < DV\n mask_kv = mask_bk[:, None] & mask_bv[None, :]\n h = tl.zeros([BK, BV], dtype=tl.float32)\n\n if USE_INITIAL_STATE:\n p_init_s = initial_state + i_bh * DK * DV + \\\n (i_k * BK + tl.arange(0, BK)[:, None]) * \\\n DV + (i_v * BV + tl.arange(0, BV)[None, :])\n h += tl.load(p_init_s, mask=mask_kv, other=0).to(tl.float32)\n\n for i in range(0, T):\n _k = tl.load(p_k, mask=mask_bk, other=0).to(tl.float32)\n _v = tl.load(p_v, mask=mask_bv, other=0).to(tl.float32)\n _do = tl.load(p_do, mask=mask_bv, other=0).to(tl.float32)\n if USE_GK:\n _gk = tl.load(p_gk, mask=mask_bk, other=0).to(tl.float32)\n h = h * _gk[:, None]\n if USE_GV:\n _gv = tl.load(p_gv, mask=mask_bv, other=0).to(tl.float32)\n h = h * _gv[None, :]\n h += _k[:, None] * _v[None, :]\n _d_q = h * _do[None, :]\n d_q = tl.sum(_d_q, axis=1) * scale\n tl.store(p_dq, d_q.to(p_dq.dtype.element_ty), mask=mask_bk)\n\n p_k += -DK if REVERSE else DK\n p_v += -DV if REVERSE else DV\n p_q += -DK if REVERSE else DK\n p_do += -DV if REVERSE else DV\n p_dq += -DK if REVERSE else DK\n if USE_GK:\n p_gk += -DK if REVERSE else DK\n if USE_GV:\n p_gv += -DV if REVERSE else DV\n\n tl.debug_barrier()\n\n p_q = q + i_bh * s_qk_h + i_k * BK + \\\n tl.arange(0, BK) + ((T - 1) * DK if not REVERSE else 0)\n p_k = k + i_bh * s_qk_h + i_k * BK + \\\n tl.arange(0, BK) + ((T - 1) * DK if not REVERSE else 0)\n p_do = do + i_bh * s_vo_h + i_v * BV + \\\n tl.arange(0, BV) + ((T - 1) * DV if not REVERSE else 0)\n p_v = v + i_bh * s_vo_h + i_v * BV + \\\n tl.arange(0, BV) + ((T - 1) * DV if not REVERSE else 0)\n p_dk = dk + (i_bh + i_v * B * H) * s_qk_h + i_k * \\\n BK + tl.arange(0, BK) + ((T - 1) * DK if not REVERSE else 0)\n p_dv = dv + (i_bh + i_k * B * H) * s_vo_h + i_v * \\\n BV + tl.arange(0, BV) + ((T - 1) * DV if not REVERSE else 0)\n if USE_GK:\n p_gk = gk + i_bh * s_qk_h + i_k * BK + \\\n tl.arange(0, BK) + ((T - 1) * DK if not REVERSE else 0)\n if USE_GV:\n p_gv = gv + i_bh * s_vo_h + i_v * BV + \\\n tl.arange(0, BV) + ((T - 1) * DV if not REVERSE else 0)\n\n d_h = tl.zeros([BK, BV], dtype=tl.float32)\n\n for _ in range(T):\n _do = tl.load(p_do, mask=mask_bv, other=0).to(tl.float32)\n _q = tl.load(p_q, mask=mask_bk, other=0).to(tl.float32) * scale\n _k = tl.load(p_k, mask=mask_bk, other=0).to(tl.float32)\n _v = tl.load(p_v, mask=mask_bv, other=0).to(tl.float32)\n d_h += _q[:, None] * _do[None, :]\n d_k = tl.sum(d_h * _v[None, :], axis=1)\n d_v = tl.sum(d_h * _k[:, None], axis=0)\n if USE_GK:\n _gk = tl.load(p_gk, mask=mask_bk, other=0).to(tl.float32)\n d_h *= _gk[:, None]\n if USE_GV:\n _gv = tl.load(p_gv, mask=mask_bv, other=0).to(tl.float32)\n d_h *= _gv[None, :]\n tl.store(p_dk, d_k.to(p_dk.dtype.element_ty), mask=mask_bk)\n tl.store(p_dv, d_v.to(p_dv.dtype.element_ty), mask=mask_bv)\n\n p_do += DV if REVERSE else -DV\n p_q += DK if REVERSE else -DK\n p_k += DK if REVERSE else -DK\n p_v += DV if REVERSE else -DV\n p_dk += DK if REVERSE else -DK\n p_dv += DV if REVERSE else -DV\n\nclass FusedRecurrentGLAFunction(torch.autograd.Function):\n\n @staticmethod\n @custom_fwd\n def forward(ctx, q, k, v, gk, gv, scale=None, initial_state=None, output_final_state=False, reverse=False):\n batch_size, n_heads, seq_len, d_head_qk = q.shape\n d_head_v = v.shape[-1]\n if scale is None:\n scale = d_head_qk ** -0.5\n if gk is not None:\n gk = gk.float().exp()\n if gv is not None:\n gv = gv.float().exp()\n\n BK, BV = min(d_head_qk, 32), min(d_head_v, 32)\n NK, NV = triton.cdiv(d_head_qk, BK), triton.cdiv(d_head_v, BV)\n num_stages = 1\n num_warps = 1\n\n o = q.new_empty(NK, batch_size, n_heads, seq_len,\n d_head_v, dtype=torch.float32)\n\n if output_final_state:\n final_state = q.new_empty(batch_size, n_heads, d_head_qk, d_head_v)\n else:\n final_state = None\n\n grid = (NV, NK, batch_size * n_heads)\n fused_recurrent_gla_fwd_kernel[grid](\n q, k, v, gk, gv, o, initial_state, final_state,\n q.stride(1), q.stride(2), q.stride(3),\n v.stride(1), v.stride(2), v.stride(3),\n batch_size, n_heads, seq_len, scale,\n DK=d_head_qk, DV=d_head_v, BK=BK, BV=BV,\n USE_INITIAL_STATE=initial_state is not None,\n STORE_FINAL_STATE=final_state is not None,\n USE_GK=gk is not None,\n USE_GV=gv is not None,\n REVERSE=reverse,\n num_warps=num_warps,\n num_stages=num_stages\n )\n\n o = o.sum(0)\n ctx.save_for_backward(q, k, v, gk, gv, initial_state, o)\n ctx.scale = scale\n ctx.reverse = reverse\n if final_state is not None:\n final_state = final_state.detach()\n return o.to(q.dtype), final_state\n\n @staticmethod\n @custom_bwd\n def backward(ctx, do, d_final_state=None):\n q, k, v, gk, gv, initial_state, o = ctx.saved_tensors\n batch_size, n_heads, seq_len, d_head_qk = q.shape\n d_head_v = v.shape[-1]\n scale = ctx.scale\n\n BK, BV = min(d_head_qk, 32), min(d_head_v, 32)\n NK, NV = triton.cdiv(d_head_qk, BK), triton.cdiv(d_head_v, BV)\n num_stages = 1\n num_warps = 1\n\n dq = q.new_empty(NV, batch_size, n_heads, seq_len,\n d_head_qk, dtype=torch.float32)\n dk = q.new_empty(NV, batch_size, n_heads, seq_len,\n d_head_qk, dtype=torch.float32)\n dv = q.new_empty(NK, batch_size, n_heads, seq_len,\n d_head_v, dtype=torch.float32)\n grid = (NV, NK, batch_size * n_heads)\n\n fused_recurrent_gla_bwd_kernel[grid](\n q, k, v, gk, gv, do, dq, dk, dv, initial_state,\n q.stride(1), q.stride(2), q.stride(3),\n v.stride(1), v.stride(2), v.stride(3),\n batch_size, n_heads, seq_len, scale,\n DK=d_head_qk, DV=d_head_v, BK=BK, BV=BV,\n num_warps=num_warps,\n num_stages=num_stages,\n USE_INITIAL_STATE=initial_state is not None,\n REVERSE=ctx.reverse,\n USE_GK=gk is not None,\n USE_GV=gv is not None\n )\n dq = dq.sum(0)\n dk = dk.sum(0)\n dv = dv.sum(0)\n if gk is not None:\n _dgk = dq * q.float() - dk * k.float()\n if ctx.reverse:\n dgk = _dgk.cumsum(-2)\n else:\n _dgk_cumsum = _dgk.cumsum(-2)\n dgk = _dgk + _dgk_cumsum[:, :, -1, None] - _dgk_cumsum\n else:\n dgk = None\n\n if gv is not None:\n _dgv = do.float() * o.float() - dv * v.float()\n if ctx.reverse:\n dgv = _dgv.cumsum(-2)\n else:\n _dgv_cumsum = _dgv.cumsum(-2)\n dgv = _dgv + _dgv_cumsum[:, :, -1, None] - _dgv_cumsum\n else:\n dgv = None\n\n return dq.to(q.dtype), dk.to(k.dtype), dv.to(v.dtype), dgk, dgv, None, None, None, None\n", - "description_1": "Use triton language to implement a fused recurrent attention mechanism with both forward and backward kernels. The forward kernel takes 20 arguments: query (q), key (k), value (v), log gates (gk, gv), output (o), initial and final states, stride sizes (s_qk_h, s_qk_t, s_qk_d, s_vo_h, s_vo_t, s_vo_d), batch size (B), number of heads (H), sequence length (T), scale factor (scale), and various constexpr parameters (BK, BV, DK, DV, USE_INITIAL_STATE, STORE_FINAL_STATE, REVERSE, USE_GK, USE_GV). The backward kernel takes similar arguments but computes the gradients of q, k, v, etc. The FusedRecurrentGLAFunction in PyTorch, with custom forward and backward functions, orchestrates the kernel launches and data handling.", - "description_2": "Use triton language to create a fused recurrent attention mechanism by implementing forward and backward computation kernels, integrating them into a PyTorch autograd function to support efficient training of neural networks.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\nfrom packaging import version\n\n@triton.jit\ndef fused_chunk_linear_attn_fwd_kernel(\n q, k, v, o, initial_state, final_state,\n s_qk_h, s_qk_t, s_qk_d, s_vo_h, s_vo_t, s_vo_d,\n B, H, T, scale,\n BT: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr,\n DK: tl.constexpr, DV: tl.constexpr,\n USE_INITIAL_STATE: tl.constexpr,\n STORE_FINAL_STATE: tl.constexpr,\n CHECK: tl.constexpr\n):\n i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n o_i = tl.arange(0, BT)\n\n m_s = o_i[:, None] >= o_i[None, :]\n b_h = tl.zeros([BK, BV], dtype=tl.float32)\n\n p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (0, i_k * BK), (BT, BK), (1, 0))\n p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (DK, T), (s_qk_d, s_qk_t), (i_k * BK, 0), (BK, BT), (0, 1))\n p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (0, i_v * BV), (BT, BV), (1, 0))\n p_o = tl.make_block_ptr(o + (i_bh+i_k*B*H) * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (0, i_v * BV), (BT, BV), (1, 0))\n\n if USE_INITIAL_STATE:\n p_h = tl.make_block_ptr(initial_state + i_bh * DK * DV, (DK, DV), (DV, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))\n b_h = tl.load(p_h, boundary_check=(0, 1)).to(tl.float32)\n\n for i in range(0, tl.cdiv(T, BT)):\n b_k = tl.load(p_k, boundary_check=(0, 1))\n b_v = tl.load(p_v, boundary_check=(0, 1))\n b_q = tl.load(p_q, boundary_check=(0, 1))\n b_q = (b_q * scale).to(b_k.dtype)\n\n b_s = tl.dot(b_q, b_k, allow_tf32=False)\n b_s = tl.where(m_s, b_s, 0)\n b_o = tl.dot(b_s.to(b_q.dtype), b_v, allow_tf32=False)\n if CHECK and i == 0:\n b_o += tl.dot(b_q, b_h.to(b_q.dtype), allow_tf32=False)\n b_h = b_h + tl.dot(b_k, b_v, allow_tf32=False)\n else:\n b_o += tl.dot(b_q, b_h.to(b_q.dtype), allow_tf32=False)\n b_h = b_h + tl.dot(b_k, b_v, allow_tf32=False)\n tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))\n p_q = tl.advance(p_q, (BT, 0))\n p_k = tl.advance(p_k, (0, BT))\n p_v = tl.advance(p_v, (BT, 0))\n p_o = tl.advance(p_o, (BT, 0))\n\n if STORE_FINAL_STATE:\n p_final = tl.make_block_ptr(final_state + i_bh * DK * DV, (DK, DV), (DV, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))\n tl.store(p_final, b_h.to(p_final.dtype.element_ty), boundary_check=(0, 1))\n\n\n@triton.jit\ndef fused_chunk_linear_attn_bwd_kernel(\n q, k, v, do, dq, dk, dv, initial_state,\n s_qk_h, s_qk_t, s_qk_d, s_vo_h, s_vo_t, s_vo_d,\n B, H, T, scale,\n BT: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr,\n DK: tl.constexpr, DV: tl.constexpr,\n USE_INITIAL_STATE: tl.constexpr,\n CHECK: tl.constexpr\n):\n i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n o_i = tl.arange(0, BT)\n\n m_s = o_i[:, None] >= o_i[None, :]\n b_h = tl.zeros([BV, BK], dtype=tl.float32)\n if USE_INITIAL_STATE:\n p_h = tl.make_block_ptr(initial_state + i_bh * DK * DV, (DV, DK), (1, DV), (i_v * BV, i_k * BK), (BV, BK), (0, 1))\n b_h = tl.load(p_h, boundary_check=(0, 1)).to(tl.float32)\n\n for i in range(0, tl.cdiv(T, BT)):\n p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (i * BT, i_k * BK), (BT, BK), (1, 0))\n p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (DV, T), (s_vo_d, s_vo_t), (i_v * BV, i * BT), (BV, BT), (0, 1))\n p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (i * BT, i_v * BV), (BT, BV), (1, 0))\n p_dq = tl.make_block_ptr(dq + (i_bh + i_v*B*H) * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (i*BT, i_k*BK), (BT, BK), (1, 0))\n\n b_k = tl.load(p_k, boundary_check=(0, 1))\n b_v = tl.load(p_v, boundary_check=(0, 1))\n b_do = tl.load(p_do, boundary_check=(0, 1)) \n \n b_ds = tl.dot(b_do, b_v, allow_tf32=False)\n b_ds = tl.where(m_s, b_ds, 0)\n b_dq = tl.dot(b_ds.to(b_k.dtype), b_k, allow_tf32=False)\n if CHECK and i == 0:\n b_dq += tl.dot(b_do, b_h.to(b_do.dtype), allow_tf32=False)\n b_h = b_h + tl.dot(b_v, b_k, allow_tf32=False)\n else:\n b_dq += tl.dot(b_do, b_h.to(b_do.dtype), allow_tf32=False)\n b_h = b_h + tl.dot(b_v, b_k, allow_tf32=False)\n b_dq *= scale\n tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))\n\n b_h = None\n tl.debug_barrier()\n b_dh = tl.zeros([BK, BV], dtype=tl.float32)\n m_s = o_i[:, None] <= o_i[None, :]\n for i in range(1, tl.cdiv(T, BT) + 1):\n p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (DK, T), (s_qk_d, s_qk_t), (i_k * BK, T - i * BT), (BK, BT), (0, 1))\n p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (T - i * BT, i_k * BK), (BT, BK), (1, 0))\n p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (T - i * BT, i_v * BV), (BT, BV), (1, 0))\n p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (T - i * BT, i_v * BV), (BT, BV), (1, 0))\n p_dk = tl.make_block_ptr(dk + (i_bh+i_v*B*H) * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (T - i*BT, i_k*BK), (BT, BK), (1, 0))\n p_dv = tl.make_block_ptr(dv + (i_bh+i_k*B*H) * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (T - i*BT, i_v*BV), (BT, BV), (1, 0))\n \n b_q = tl.load(p_q, boundary_check=(0, 1))\n b_k = tl.load(p_k, boundary_check=(0, 1))\n b_v = tl.load(p_v, boundary_check=(0, 1))\n b_do = tl.load(p_do, boundary_check=(0, 1))\n\n b_ds = tl.dot(b_v, tl.trans(b_do), allow_tf32=False)\n b_ds = tl.where(m_s, b_ds, 0).to(b_q.dtype)\n b_s = tl.dot(b_k, b_q, allow_tf32=False) * scale \n b_s = tl.where(m_s, b_s, 0).to(b_q.dtype)\n b_dk = tl.dot(b_ds, tl.trans(b_q), allow_tf32=False)\n b_dv = tl.dot(b_s, b_do, allow_tf32=False)\n if CHECK and i == 1:\n b_dk += tl.dot(b_v, tl.trans(b_dh).to(b_v.dtype), allow_tf32=False)\n b_dv += tl.dot(b_k, b_dh.to(b_k.dtype), allow_tf32=False)\n b_dh += tl.dot(b_q, b_do, allow_tf32=False)\n else:\n b_dk += tl.dot(b_v, tl.trans(b_dh).to(b_v.dtype), allow_tf32=False)\n b_dv += tl.dot(b_k, b_dh.to(b_k.dtype), allow_tf32=False)\n b_dh += tl.dot(b_q, b_do, allow_tf32=False)\n\n tl.store(p_dk, (b_dk * scale).to(p_dk.dtype.element_ty), boundary_check=(0, 1))\n tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))\n\n\nclass FusedChunkLinearAttentionFunction(torch.autograd.Function):\n @staticmethod\n def forward(ctx, q, k, v, scale, initial_state, output_final_state):\n batch_size, n_heads, seq_len, d_head_qk = q.shape\n d_head_v = v.shape[-1]\n ctx.scale = scale\n BT = 64\n BK, BV = min(triton.next_power_of_2(d_head_qk), 64), min(triton.next_power_of_2(d_head_v), 64)\n NK, NV = triton.cdiv(d_head_qk, BK), triton.cdiv(d_head_v, BV)\n num_stages = 1\n num_warps = 4\n o = q.new_empty(NK, batch_size, n_heads, seq_len, d_head_v)\n if output_final_state:\n final_state = q.new_empty(batch_size, n_heads, d_head_qk, d_head_v, dtype=torch.float32, requires_grad=False)\n else:\n final_state = None\n CHECK = False\n if version.parse(triton.__version__) < version.parse('2.2.0'):\n import warnings\n warnings.warn(\n \"Triton<2.2.0 detected for running this kernel, \"\n \"which is known to have some weird compiler issues (refer to https://github.com/openai/triton/issues/2852) \"\n \"that lead to significant precision loss. \"\n \"We've add some initial condition checks to resolve this, sadly at the sacrifice of the speed. \"\n \"For optimal performance, it is recommended to install Triton>=2.2.0 (if possible).\"\n )\n CHECK = True\n\n grid = (NV, NK, batch_size * n_heads)\n fused_chunk_linear_attn_fwd_kernel[grid](\n q, k, v, o, initial_state, final_state,\n q.stride(1), q.stride(2), q.stride(3),\n v.stride(1), v.stride(2), v.stride(3),\n batch_size, n_heads, seq_len, scale,\n BT=BT, DK=d_head_qk, DV=d_head_v, BK=BK, BV=BV,\n USE_INITIAL_STATE=initial_state is not None,\n STORE_FINAL_STATE=output_final_state,\n CHECK=CHECK,\n num_warps=num_warps,\n num_stages=num_stages\n )\n\n o = o.sum(0)\n ctx.save_for_backward(q, k, v, initial_state)\n ctx.CHECK = CHECK\n return o.to(q.dtype), final_state\n\n @staticmethod\n def backward(ctx, do, d_final_state=None):\n q, k, v, initial_state = ctx.saved_tensors\n batch_size, n_heads, seq_len, d_head_qk = q.shape\n d_head_v = v.shape[-1]\n scale = ctx.scale\n\n BT = 64\n BK, BV = min(triton.next_power_of_2(d_head_qk), 64), min(triton.next_power_of_2(d_head_v), 64)\n NK, NV = triton.cdiv(d_head_qk, BK), triton.cdiv(d_head_v, BV)\n num_stages = 1\n num_warps = 4\n\n dq = q.new_empty(NV, batch_size, n_heads, seq_len, d_head_qk)\n dk = q.new_empty(NV, batch_size, n_heads, seq_len, d_head_qk)\n dv = q.new_empty(NK, batch_size, n_heads, seq_len, d_head_v)\n grid = (NV, NK, batch_size * n_heads)\n\n fused_chunk_linear_attn_bwd_kernel[grid](\n q, k, v, do, dq, dk, dv, initial_state,\n q.stride(1), q.stride(2), q.stride(3),\n v.stride(1), v.stride(2), v.stride(3),\n batch_size, n_heads, seq_len, scale,\n BT=BT, DK=d_head_qk, DV=d_head_v, BK=BK, BV=BV,\n USE_INITIAL_STATE=initial_state is not None,\n CHECK=ctx.CHECK,\n num_warps=num_warps,\n num_stages=num_stages\n )\n dq = dq.sum(0)\n dk = dk.sum(0)\n dv = dv.sum(0)\n return dq.to(q.dtype), dk.to(k.dtype), dv.to(v.dtype), None, None, None\n\n\ndef fused_chunk_linear_attn(\n q: torch.Tensor,\n k: torch.Tensor,\n v: torch.Tensor,\n scale: float = -1,\n initial_state: torch.Tensor = None,\n output_final_state: bool = False,\n normalize: bool = True\n):\n if initial_state is not None:\n initial_state = initial_state.detach()\n if scale == -1:\n scale = q.shape[-1] ** -0.5\n o, final_state = FusedChunkLinearAttentionFunction.apply(q, k, v, scale, initial_state, output_final_state)\n return o, final_state\n", - "description_1": "Use triton language to implement a fused chunked linear attention kernel with backward pass. The forward kernel takes 22 parameters: query, key, value, output, initial state, final state, stride sizes, batch size, number of heads, sequence length, scale factor, block sizes for sequence, K, and V dimensions, head dimensions, use of initial state, store final state, and a check flag. The backward kernel also takes 22 parameters: query, key, value, gradient of output, gradients of query, key, and value, initial state, stride sizes, batch size, number of heads, sequence length, scale factor, block sizes for sequence, K, and V dimensions, head dimensions, use of initial state, and a check flag. The main function orchestrates the execution of these kernels using PyTorch's autograd functionality.", - "description_2": "Use triton language to implement a chunked linear attention mechanism with automatic differentiation support. The kernels handle tensor operations, block pointers, and conditional logic to perform efficient forward and backward passes for deep learning models.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef parallel_rebased_fwd_kernel(\n q, k, v, o, z, s_qk_h, s_qk_t, s_qk_d, s_vo_h, s_vo_t, s_vo_d, B, H, T, scale,\n BTL: tl.constexpr, BTS: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr, DK: tl.constexpr, DV: tl.constexpr,\n):\n i_kv, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n NV = tl.cdiv(DV, BV)\n i_k = i_kv // (NV)\n i_v = i_kv % (NV)\n\n p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (T, DK),\n (s_qk_t, s_qk_d), (i_c * BTL, i_k * BK), (BTL, BK), (1, 0))\n p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (DK, T),\n (s_qk_d, s_qk_t), (i_k * BK, 0), (BK, BTS), (0, 1))\n p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, DV),\n (s_vo_t, s_vo_d), (0, i_v * BV), (BTS, BV), (1, 0))\n\n b_q = tl.load(p_q, boundary_check=(0, 1))\n b_q = (b_q * scale).to(b_q.dtype)\n b_o = tl.zeros([BTL, BV], dtype=tl.float32)\n b_z = tl.zeros([BTL], dtype=tl.float32)\n\n for _ in range(0, i_c * BTL, BTS):\n b_k = tl.load(p_k, boundary_check=(0, 1))\n b_v = tl.load(p_v, boundary_check=(0, 1))\n b_s = tl.dot(b_q, (b_k), allow_tf32=False)\n b_s = b_s * b_s\n b_z += tl.sum(b_s, axis=1)\n b_o = b_o + tl.dot(b_s.to(b_v.dtype), b_v, allow_tf32=False)\n p_k = tl.advance(p_k, (0, BTS))\n p_v = tl.advance(p_v, (BTS, 0))\n\n tl.debug_barrier()\n o_q = tl.arange(0, BTL)\n o_k = tl.arange(0, BTS)\n p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (DK, T),\n (s_qk_d, s_qk_t), (i_k * BK, i_c * BTL), (BK, BTS), (0, 1))\n p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, DV),\n (s_vo_t, s_vo_d), (i_c * BTL, i_v * BV), (BTS, BV), (1, 0))\n for _ in range(i_c * BTL, (i_c + 1) * BTL, BTS):\n b_k = tl.load(p_k, boundary_check=(0, 1))\n b_v = tl.load(p_v, boundary_check=(0, 1))\n m_s = o_q[:, None] >= o_k[None, :]\n b_s = tl.dot(b_q, b_k, allow_tf32=False)\n b_s = b_s * b_s\n b_s = tl.where(m_s, b_s, 0)\n b_z += tl.sum(b_s, axis=1)\n b_o += tl.dot(b_s.to(b_q.dtype), b_v, allow_tf32=False)\n p_k = tl.advance(p_k, (0, BTS))\n p_v = tl.advance(p_v, (BTS, 0))\n o_k += BTS\n\n p_o = tl.make_block_ptr(o + (i_bh + B * H * i_k) * s_vo_h, (T, DV),\n (s_vo_t, s_vo_d), (i_c*BTL, i_v*BV), (BTL, BV), (1, 0))\n p_z = z + (i_bh + B * H * i_k) * T + i_c * BTL + tl.arange(0, BTL)\n tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))\n tl.store(p_z, b_z.to(p_z.dtype.element_ty),\n mask=((i_c * BTL + tl.arange(0, BTL)) < T))\n\n@triton.jit\ndef parallel_rebased_bwd_kernel(\n q, k, v, do, dz, dq, dk, dv, s_qk_h, s_qk_t, s_qk_d, s_vo_h,\n s_vo_t, s_vo_d, B, H, T, scale,\n BTL: tl.constexpr, BTS: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr,\n DK: tl.constexpr, DV: tl.constexpr,\n):\n i_kv, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n NV = tl.cdiv(DV, BV)\n i_k = i_kv // (NV)\n i_v = i_kv % (NV)\n i_h = i_bh % H\n _parallel_rebased_bwd_dq(\n i_bh, i_c, i_k, i_v, i_h,\n q, k, v, do, dz, dq, s_qk_h, s_qk_t, s_qk_d, s_vo_h,\n s_vo_t, s_vo_d, B, H, T, scale, BTL=BTL, BTS=BTS, BK=BK, BV=BV, DK=DK, DV=DV\n )\n tl.debug_barrier()\n _parallel_rebased_bwd_dkv(\n i_bh, i_c, i_k, i_v, i_h,\n q, k, v, do, dz, dk, dv, s_qk_h, s_qk_t, s_qk_d, s_vo_h,\n s_vo_t, s_vo_d, B, H, T, scale, BTL, BTS, BK, BV, DK, DV\n )\n\nclass ParallelBasedFunction(torch.autograd.Function):\n @staticmethod\n def forward(ctx, q, k, v, scale):\n BTL, BTS = 128, 32\n assert BTL % BTS == 0\n BK = min(128, triton.next_power_of_2(k.shape[-1]))\n BV = min(128, triton.next_power_of_2(v.shape[-1]))\n BK, BV = max(BK, 16), max(BV, 16)\n batch_size, n_heads, seq_len, d_head_qk = q.shape\n d_head_v = v.shape[-1]\n num_stages = 2\n num_warps = 4\n NK = triton.cdiv(d_head_qk, BK)\n NV = triton.cdiv(d_head_v, BV)\n grid = (NK * NV, triton.cdiv(seq_len, BTL), batch_size * n_heads)\n\n assert NK == 1, \"will encounter some synchronization issue if not.\"\n\n o = torch.empty(NK, batch_size, n_heads, seq_len,\n d_head_v, device=q.device)\n z = torch.empty(NK, batch_size, n_heads, seq_len,\n device=q.device)\n parallel_rebased_fwd_kernel[grid](\n q, k, v, o, z,\n q.stride(1), q.stride(2), q.stride(3),\n v.stride(1), v.stride(2), v.stride(3),\n batch_size, n_heads, seq_len, scale,\n BTL=BTL, BTS=BTS, BK=BK, BV=BV, DK=d_head_qk, DV=d_head_v,\n num_warps=num_warps,\n num_stages=num_stages\n )\n ctx.save_for_backward(q, k, v)\n ctx.scale = scale\n return o.sum(0).to(q.dtype), z.sum(0).to(q.dtype)\n\n @staticmethod\n def backward(ctx, do, dz):\n q, k, v = ctx.saved_tensors\n scale = ctx.scale\n BTL, BTS = 64, 32\n assert BTL % BTS == 0\n BK = min(128, triton.next_power_of_2(k.shape[-1]))\n BV = min(128, triton.next_power_of_2(v.shape[-1]))\n BK, BV = max(BK, 16), max(BV, 16)\n batch_size, n_heads, seq_len, d_head_qk = q.shape\n d_head_v = v.shape[-1]\n num_stages = 2\n num_warps = 4\n NK = triton.cdiv(d_head_qk, BK)\n NV = triton.cdiv(d_head_v, BV)\n grid = (NK * NV, triton.cdiv(seq_len, BTL), batch_size * n_heads)\n\n assert NK == 1, \"will encounter some synchronization issue if not\"\n\n dq = torch.empty(NV, batch_size, n_heads, seq_len,\n d_head_qk, dtype=q.dtype, device=q.device)\n dk = torch.empty(NV, batch_size, n_heads, seq_len,\n d_head_qk, dtype=q.dtype, device=q.device)\n dv = torch.empty(NK, batch_size, n_heads, seq_len,\n d_head_v, dtype=q.dtype, device=q.device)\n\n parallel_rebased_bwd_kernel[grid](\n q, k, v, do, dz, dq, dk, dv,\n q.stride(1), q.stride(2), q.stride(3),\n v.stride(1), v.stride(2), v.stride(3),\n batch_size, n_heads, seq_len, scale,\n BTL=BTL, BTS=BTS, BK=BK, BV=BV, DK=d_head_qk, DV=d_head_v,\n num_warps=num_warps,\n num_stages=num_stages\n )\n\n return dq.sum(0).to(q.dtype), dk.sum(0).to(k.dtype), dv.sum(0).to(v.dtype), None\n\ntriton_parallel_based = ParallelBasedFunction.apply\n\ndef parallel_rebased(q, k, v, eps=1e-5, use_scale=True, use_normalize=True, return_both=False):\n assert q.shape[-1] <= 128, \"only support feature dim up to 128\"\n if use_scale:\n scale = q.shape[-1] ** -0.5\n else:\n scale = 1\n o, z = triton_parallel_based(q, k, v, scale)\n if return_both:\n return o, z\n if use_normalize:\n o = o / (z[..., None] + eps)\n else:\n o = o\n return o.to(q.dtype)\n", - "description_1": "Use triton language to implement parallelized forward and backward pass kernels for linear attention. The forward kernel `parallel_rebased_fwd_kernel` takes 19 parameters where `q`, `k`, `v`, `o`, `z` are tensors representing query, key, value, output, and normalizer respectively, followed by stride sizes for qk and vo dimensions, batch size `B`, number of heads `H`, sequence length `T`, and a scale factor. It uses constant expression parameters for block sizes along different dimensions and performs matrix multiplication and reduction operations in parallel blocks. The backward kernel `parallel_rebased_bwd_kernel` mirrors this functionality for backpropagation, with additional inputs for gradients and helper tensors `dq`, `dk`, `dv` representing gradients of query, key, and value respectively.", - "description_2": "Use triton language to create triton kernels for efficient parallel computation of forward and backward passes in linear attention with support for tensor operations and gradient computation.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef chunk_retention_fwd_kernel_h(\n k, v, h, initial_state, final_state, s_qk_h, s_qk_t, s_qk_d, \n s_vo_h, s_vo_t, s_vo_d, s_h_h, s_h_t, H: tl.constexpr, T: tl.constexpr, \n K: tl.constexpr, V: tl.constexpr, BT: tl.constexpr, BK: tl.constexpr, \n BV: tl.constexpr, NT: tl.constexpr, USE_INITIAL_STATE: tl.constexpr, \n STORE_FINAL_STATE: tl.constexpr\n):\n # Kernel implementation...\n\n@triton.jit\ndef chunk_retention_fwd_kernel_o(\n q, k, v, h, o, s_qk_h, s_qk_t, s_qk_d, s_vo_h, s_vo_t, \n s_vo_d, s_h_h, s_h_t, scale, H: tl.constexpr, T: tl.constexpr, \n K: tl.constexpr, V: tl.constexpr, BT: tl.constexpr, BK: tl.constexpr, \n BV: tl.constexpr\n):\n # Kernel implementation...\n\n@triton.jit\ndef chunk_retention_bwd_kernel_dh(\n q, do, dh, s_qk_h, s_qk_t, s_qk_d, s_vo_h, s_vo_t, s_vo_d, \n s_h_h, s_h_t, scale, H: tl.constexpr, T: tl.constexpr, K: tl.constexpr, \n V: tl.constexpr, BT: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr, \n NT: tl.constexpr\n):\n # Kernel implementation...\n\n@triton.jit\ndef chunk_retention_bwd_kernel_dqkv(\n q, k, v, h, do, dh, dq, dk, dv, s_qk_h, s_qk_t, s_qk_d, \n s_vo_h, s_vo_t, s_vo_d, s_h_h, s_h_t, scale, H: tl.constexpr, T: tl.constexpr, \n K: tl.constexpr, V: tl.constexpr, BT: tl.constexpr, BK: tl.constexpr, \n BV: tl.constexpr, NT: tl.constexpr\n):\n # Kernel implementation...\n\nclass ChunkRetentionFunction(torch.autograd.Function):\n\n @staticmethod\n def forward(ctx, q, k, v, initial_state, output_final_state):\n B, H, T, K, V = *q.shape, v.shape[-1]\n BT = 64\n BK, BV = min(64, triton.next_power_of_2(K)), min(64, triton.next_power_of_2(V))\n NT, NK, NV = triton.cdiv(T, BT), triton.cdiv(K, BK), triton.cdiv(V, BV)\n num_stages = 1\n num_warps = 4 if BK == 64 else 2\n scale = K ** -0.5\n\n final_state = None\n if output_final_state:\n final_state = q.new_empty(B, H, K, V, dtype=torch.float32, requires_grad=False)\n\n h = q.new_empty(B, H, NT * K, V)\n grid = (NK, NV, B * H)\n chunk_retention_fwd_kernel_h[grid](\n k, v, h, initial_state, final_state,\n q.stride(1), q.stride(2), q.stride(3),\n v.stride(1), v.stride(2), v.stride(3),\n h.stride(1), h.stride(2),\n H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT,\n USE_INITIAL_STATE=initial_state is not None,\n STORE_FINAL_STATE=output_final_state,\n num_warps=num_warps,\n num_stages=num_stages\n )\n grid = (NV, NT, B * H)\n o = torch.empty_like(v)\n chunk_retention_fwd_kernel_o[grid](\n q, k, v, h, o,\n q.stride(1), q.stride(2), q.stride(3),\n v.stride(1), v.stride(2), v.stride(3),\n h.stride(1), h.stride(2),\n scale,\n H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV,\n num_warps=num_warps,\n num_stages=num_stages\n )\n\n ctx.save_for_backward(q, k, v, h)\n return o.to(q.dtype), final_state\n\n @staticmethod\n def backward(ctx, do, d_ht=None):\n q, k, v, h = ctx.saved_tensors\n\n B, H, T, K, V = *q.shape, v.shape[-1]\n BT = 64\n BK, BV = min(64, triton.next_power_of_2(K)), min(64, triton.next_power_of_2(V))\n NT, NK, NV = triton.cdiv(T, BT), triton.cdiv(K, BK), triton.cdiv(V, BV)\n num_stages = 1\n num_warps = 4 if BK == 64 else 2\n scale = K ** -0.5\n\n dh = q.new_empty(B, H, NT * K, V)\n grid = (NK, NV, B * H)\n chunk_retention_bwd_kernel_dh[grid](\n q, do, dh,\n q.stride(1), q.stride(2), q.stride(3),\n v.stride(1), v.stride(2), v.stride(3),\n dh.stride(1), dh.stride(2),\n scale,\n H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT,\n num_warps=num_warps,\n num_stages=num_stages\n )\n\n grid = (NK, NT, B * H)\n dq = torch.empty_like(q)\n dk = torch.empty_like(k)\n dv = v.new_empty(NK, *v.shape)\n num_stages = 1\n num_warps = 4 if BK == 64 else 2\n chunk_retention_bwd_kernel_dqkv[grid](\n q, k, v, h, do, dh, dq, dk, dv,\n q.stride(1), q.stride(2), q.stride(3),\n v.stride(1), v.stride(2), v.stride(3),\n dh.stride(1), dh.stride(2),\n scale,\n H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT,\n num_warps=num_warps,\n num_stages=num_stages\n )\n dv = dv.sum(0)\n return dq.to(q.dtype), dk.to(k.dtype), dv.to(v.dtype), None, None\n\n\ndef chunk_retention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, initial_state: torch.Tensor = None, output_final_state: bool = False):\n if initial_state is not None:\n initial_state = initial_state.detach()\n o, final_state = ChunkRetentionFunction.apply(q, k, v, initial_state, output_final_state)\n if output_final_state:\n return o, final_state\n else:\n return o\n", - "description_1": "Use triton language to define forward and backward kernels for a chunk retention operation. The forward kernel computes intermediate results using input tensors q, k, v, and optionally an initial state, and stores results in tensors h and o. The backward kernels compute gradients for q, k, v using the forward computation results and the gradient tensor do. All functions manage tensor shapes and strides, leveraging triton's block-level parallelism.", - "description_2": "Use triton language to implement a series of kernels for attention-like operations with customizable parameters and grid dimensions to perform high-performance tensor computations on GPU.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\nfrom torch.cuda.amp import custom_bwd, custom_fwd\nfrom packaging import version\n\n@triton.jit\ndef fused_chunk_retention_fwd_kernel(\n q, k, v, o, initial_state, final_state,\n s_qk_h, s_qk_t, s_qk_d, s_vo_h, s_vo_t, s_vo_d, B, H, T, scale,\n BT: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr,\n DK: tl.constexpr, DV: tl.constexpr,\n USE_INITIAL_STATE: tl.constexpr, STORE_FINAL_STATE: tl.constexpr, CHECK: tl.constexpr\n):\n i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n i_h = i_bh % H\n\n o_i = tl.arange(0, BT)\n b_b = tl.math.log2(1 - tl.math.pow(2, -5 - i_h * 1.0))\n\n d_b, d_o, d_h = tl.math.exp2(BT * b_b), tl.math.exp2((o_i + 1) * b_b), tl.math.exp2((BT - o_i - 1) * b_b)\n\n m_s = o_i[:, None] >= o_i[None, :]\n d_s = tl.where(m_s, tl.math.exp2((o_i[:, None] - o_i[None, :]) * b_b), 0)\n b_h = tl.zeros([BK, BV], dtype=tl.float32)\n\n p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (0, i_k * BK), (BT, BK), (1, 0))\n p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (DK, T), (s_qk_d, s_qk_t), (i_k * BK, 0), (BK, BT), (0, 1))\n p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (0, i_v * BV), (BT, BV), (1, 0))\n p_o = tl.make_block_ptr(o + (i_bh+i_k*B*H) * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (0, i_v * BV), (BT, BV), (1, 0))\n\n if USE_INITIAL_STATE:\n p_h = tl.make_block_ptr(initial_state + i_bh * DK * DV, (DK, DV), (DV, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))\n b_h = tl.load(p_h, boundary_check=(0, 1)).to(tl.float32)\n\n for i in range(0, tl.cdiv(T, BT)):\n b_k = tl.load(p_k, boundary_check=(0, 1))\n b_v = tl.load(p_v, boundary_check=(0, 1))\n b_q = tl.load(p_q, boundary_check=(0, 1))\n b_q = (b_q * scale).to(b_k.dtype)\n\n b_s = tl.dot(b_q, b_k, allow_tf32=False) * d_s\n b_o = tl.dot(b_s.to(b_q.dtype), b_v, allow_tf32=False)\n if CHECK and i == 0:\n b_o += tl.dot(b_q, b_h.to(b_q.dtype), allow_tf32=False) * d_o[:, None]\n b_h = d_b * b_h + tl.dot(b_k, (b_v * d_h[:, None]).to(b_k.dtype), allow_tf32=False)\n else:\n b_o += tl.dot(b_q, b_h.to(b_q.dtype), allow_tf32=False) * d_o[:, None]\n b_h = d_b * b_h + tl.dot(b_k, (b_v * d_h[:, None]).to(b_k.dtype), allow_tf32=False)\n tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))\n\n p_q = tl.advance(p_q, (BT, 0))\n p_k = tl.advance(p_k, (0, BT))\n p_v = tl.advance(p_v, (BT, 0))\n p_o = tl.advance(p_o, (BT, 0))\n\n if STORE_FINAL_STATE:\n p_final = tl.make_block_ptr(final_state + i_bh * DK * DV, (DK, DV), (DV, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))\n tl.store(p_final, b_h.to(p_final.dtype.element_ty), boundary_check=(0, 1))\n\n@triton.jit\ndef fused_chunk_retention_bwd_kernel(\n q, k, v, do, dq, dk, dv, initial_state,\n s_qk_h, s_qk_t, s_qk_d, s_vo_h, s_vo_t, s_vo_d, B, H, T, scale,\n BT: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr,\n DK: tl.constexpr, DV: tl.constexpr,\n USE_INITIAL_STATE: tl.constexpr, CHECK: tl.constexpr\n):\n i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n i_h = i_bh % H\n\n o_i = tl.arange(0, BT)\n b_b = tl.math.log2(1 - tl.math.pow(2, -5 - i_h * 1.0))\n d_q, d_k = tl.math.exp2((o_i+1) * b_b) * scale, tl.math.exp2((BT - o_i - 1) * b_b)\n d_b = tl.math.exp2(BT * b_b)\n\n m_s = o_i[:, None] >= o_i[None, :]\n d_s = tl.where(m_s, tl.math.exp2((o_i[:, None] - o_i[None, :]) * b_b), 0) * scale\n b_h = tl.zeros([BV, BK], dtype=tl.float32)\n if USE_INITIAL_STATE:\n p_h = tl.make_block_ptr(initial_state + i_bh * DK * DV, (DV, DK), (1, DV), (i_v * BV, i_k * BK), (BV, BK), (0, 1))\n b_h = tl.load(p_h, boundary_check=(0, 1)).to(tl.float32)\n\n for i in range(0, tl.cdiv(T, BT)):\n p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (i * BT, i_k * BK), (BT, BK), (1, 0))\n p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (DV, T), (s_vo_d, s_vo_t), (i_v * BV, i * BT), (BV, BT), (0, 1))\n p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (i * BT, i_v * BV), (BT, BV), (1, 0))\n p_dq = tl.make_block_ptr(dq + (i_bh + i_v*B*H) * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (i*BT, i_k*BK), (BT, BK), (1, 0))\n\n b_k = tl.load(p_k, boundary_check=(0, 1))\n b_v = tl.load(p_v, boundary_check=(0, 1))\n b_do = tl.load(p_do, boundary_check=(0, 1))\n b_dd = (b_do * d_q[:, None]).to(b_do.dtype)\n\n b_ds = tl.dot(b_do, b_v, allow_tf32=False)\n b_ds = (b_ds * d_s).to(b_k.dtype)\n b_dq = tl.dot(b_ds, b_k, allow_tf32=False)\n if CHECK and i == 0:\n b_dq += tl.dot(b_dd, b_h.to(b_k.dtype), allow_tf32=False)\n b_h = d_b * b_h + tl.dot((b_v * d_k[None, :]).to(b_k.dtype), b_k, allow_tf32=False)\n else:\n b_dq += tl.dot(b_dd, b_h.to(b_k.dtype), allow_tf32=False)\n b_h = d_b * b_h + tl.dot((b_v * d_k[None, :]).to(b_k.dtype), b_k, allow_tf32=False)\n\n tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))\n\n b_h = None\n tl.debug_barrier()\n d_s = tl.trans(d_s)\n b_dh = tl.zeros([BK, BV], dtype=tl.float32)\n for i in range(1, tl.cdiv(T, BT) + 1):\n p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (DK, T), (s_qk_d, s_qk_t), (i_k * BK, T - i * BT), (BK, BT), (0, 1))\n p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (T - i * BT, i_k * BK), (BT, BK), (1, 0))\n p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (T - i * BT, i_v * BV), (BT, BV), (1, 0))\n p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (T - i * BT, i_v * BV), (BT, BV), (1, 0))\n p_dk = tl.make_block_ptr(dk + (i_bh+i_v*B*H) * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (T - i*BT, i_k*BK), (BT, BK), (1, 0))\n p_dv = tl.make_block_ptr(dv + (i_bh+i_k*B*H) * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (T - i*BT, i_v*BV), (BT, BV), (1, 0))\n\n b_q = tl.load(p_q, boundary_check=(0, 1))\n b_k = tl.load(p_k, boundary_check=(0, 1))\n b_v = tl.load(p_v, boundary_check=(0, 1))\n b_do = tl.load(p_do, boundary_check=(0, 1))\n b_dd = (b_do * d_q[:, None]).to(b_do.dtype)\n\n b_ds = tl.dot(b_v, tl.trans(b_do), allow_tf32=False)\n b_ds = (b_ds * d_s).to(b_k.dtype)\n\n b_s = tl.dot(b_k, b_q, allow_tf32=False) * d_s\n b_dk = tl.dot(b_ds, tl.trans(b_q), allow_tf32=False)\n b_dv = tl.dot(b_s.to(b_q.dtype), b_do, allow_tf32=False)\n if CHECK and i == 1:\n b_dk += tl.dot(b_v, tl.trans(b_dh).to(b_v.dtype), allow_tf32=False) * d_k[:, None]\n b_dv += tl.dot(b_k, b_dh.to(b_k.dtype), allow_tf32=False) * d_k[:, None]\n b_dh = d_b * b_dh + tl.dot(b_q, b_dd, allow_tf32=False)\n else:\n b_dk += tl.dot(b_v, tl.trans(b_dh).to(b_v.dtype), allow_tf32=False) * d_k[:, None]\n b_dv += tl.dot(b_k, b_dh.to(b_k.dtype), allow_tf32=False) * d_k[:, None]\n b_dh = d_b * b_dh + tl.dot(b_q, b_dd, allow_tf32=False)\n\n tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))\n tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))\n\nclass FusedChunkRetentionFunction(torch.autograd.Function):\n\n @staticmethod\n @custom_fwd\n def forward(ctx, q, k, v, initial_state, output_final_state):\n batch_size, n_heads, seq_len, d_head_qk = q.shape\n d_head_v = v.shape[-1]\n\n scale = d_head_qk ** -0.5\n BT = 64\n BK, BV = min(triton.next_power_of_2(d_head_qk), 64), min(triton.next_power_of_2(d_head_v), 64)\n NK, NV = triton.cdiv(d_head_qk, BK), triton.cdiv(d_head_v, BV)\n num_stages = 1\n num_warps = 4\n\n o = q.new_empty(NK, batch_size, n_heads, seq_len, d_head_v)\n\n if output_final_state:\n final_state = q.new_empty(batch_size, n_heads, d_head_qk, d_head_v, dtype=torch.float32, requires_grad=False)\n else:\n final_state = None\n CHECK = False\n if version.parse(triton.__version__) < version.parse('2.2.0'):\n import warnings\n warnings.warn(\n \"Triton<2.2.0 detected for running this kernel, \"\n \"which is known to have some weird compiler issues (refer to https://github.com/openai/triton/issues/2852) \"\n \"that lead to significant precision loss. \"\n \"We've add some initial condition checks to resolve this, sadly at the sacrifice of the speed. \"\n \"For optimal performance, it is recommended to install Triton>=2.2.0 (if possible).\"\n )\n CHECK = True\n\n grid = (NV, NK, batch_size * n_heads)\n fused_chunk_retention_fwd_kernel[grid](\n q, k, v, o, initial_state, final_state,\n q.stride(1), q.stride(2), q.stride(3),\n v.stride(1), v.stride(2), v.stride(3),\n batch_size, n_heads, seq_len, scale,\n BT=BT, DK=d_head_qk, DV=d_head_v, BK=BK, BV=BV,\n USE_INITIAL_STATE=initial_state is not None,\n STORE_FINAL_STATE=output_final_state,\n CHECK=CHECK,\n num_warps=num_warps,\n num_stages=num_stages\n )\n\n o = o.sum(0)\n ctx.save_for_backward(q, k, v, initial_state)\n ctx.CHECK = CHECK\n return o.to(q.dtype), final_state\n\n @staticmethod\n @custom_bwd\n def backward(ctx, do, d_final_state=None):\n q, k, v, initial_state = ctx.saved_tensors\n batch_size, n_heads, seq_len, d_head_qk = q.shape\n d_head_v = v.shape[-1]\n scale = d_head_qk ** -0.5\n\n BT = 64\n BK, BV = min(triton.next_power_of_2(d_head_qk), 64), min(triton.next_power_of_2(d_head_v), 64)\n NK, NV = triton.cdiv(d_head_qk, BK), triton.cdiv(d_head_v, BV)\n num_stages = 1\n num_warps = 4\n\n dq = q.new_empty(NV, batch_size, n_heads, seq_len, d_head_qk)\n dk = q.new_empty(NV, batch_size, n_heads, seq_len, d_head_qk)\n dv = q.new_empty(NK, batch_size, n_heads, seq_len, d_head_v)\n grid = (NV, NK, batch_size * n_heads)\n\n fused_chunk_retention_bwd_kernel[grid](\n q, k, v, do, dq, dk, dv, initial_state,\n q.stride(1), q.stride(2), q.stride(3),\n v.stride(1), v.stride(2), v.stride(3),\n batch_size, n_heads, seq_len, scale,\n BT=BT, DK=d_head_qk, DV=d_head_v, BK=BK, BV=BV,\n USE_INITIAL_STATE=initial_state is not None,\n CHECK=ctx.CHECK,\n num_warps=num_warps,\n num_stages=num_stages\n )\n dq = dq.sum(0)\n dk = dk.sum(0)\n dv = dv.sum(0)\n return dq.to(q.dtype), dk.to(k.dtype), dv.to(v.dtype), None, None\n\ndef fused_chunk_retention(\n q: torch.Tensor,\n k: torch.Tensor,\n v: torch.Tensor,\n initial_state: torch.Tensor = None,\n output_final_state: bool = False\n):\n if initial_state is not None:\n initial_state = initial_state.detach()\n o, final_state = FusedChunkRetentionFunction.apply(q, k, v, initial_state, output_final_state)\n if output_final_state:\n return o, final_state\n else:\n return o\n", - "description_1": "Use triton language to implement a fused chunk retention mechanism with a forward kernel (`fused_chunk_retention_fwd_kernel`) and a backward kernel (`fused_chunk_retention_bwd_kernel`). The forward kernel takes 21 parameters: q, k, v, o, initial_state, final_state, s_qk_h, s_qk_t, s_qk_d, s_vo_h, s_vo_t, s_vo_d, B, H, T, scale, and block sizes BT, BK, BV, DK, DV. The backward kernel uses the same block sizes and similar parameters as the forward kernel but includes additional gradient parameters dq, dk, dv, and do. These kernels are used in a custom autograd function `FusedChunkRetentionFunction` with `forward` and `backward` methods, and encapsulated in a Python function `fused_chunk_retention` that applies the Triton kernels.", - "description_2": "Use triton language to develop a custom forward and backward function for fused chunk retention, employing Triton's just-in-time compilation to optimize matrix computations and gradient calculations.", - "difficulty": 5 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\nfrom torch.cuda.amp import custom_bwd, custom_fwd\n\n@triton.jit\ndef parallel_retention_fwd_kernel(\n q, k, v, o,\n s_qk_h, s_qk_t, s_qk_d,\n s_vo_h, s_vo_t, s_vo_d,\n B, H, T, scale,\n BTL: tl.constexpr, BTS: tl.constexpr, BK: tl.constexpr,\n BV: tl.constexpr, DK: tl.constexpr, DV: tl.constexpr\n):\n # Kernel implementation\n\n@triton.jit\ndef _parallel_retention_bwd_dq(\n i_bh, i_c, i_k, i_v, i_h,\n k, v, do, dq, s_qk_h, s_qk_t, s_qk_d, s_vo_h,\n s_vo_t, s_vo_d, B, H, T, scale,\n BTL: tl.constexpr, BTS: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr,\n DK: tl.constexpr, DV: tl.constexpr,\n):\n # Kernel implementation\n\n@triton.jit\ndef _parallel_retention_bwd_dkv(\n i_bh, i_c, i_k, i_v, i_h,\n q, k, v, do, dk, dv, s_qk_h, s_qk_t, s_qk_d, s_vo_h,\n s_vo_t, s_vo_d, B, H, T, scale,\n BTL: tl.constexpr, BTS: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr,\n DK: tl.constexpr, DV: tl.constexpr,\n):\n # Kernel implementation\n\n@triton.jit\ndef parallel_retention_bwd_kernel(\n q, k, v, do, dq, dk, dv, s_qk_h, s_qk_t, s_qk_d, s_vo_h,\n s_vo_t, s_vo_d, B, H, T, scale,\n BTL: tl.constexpr, BTS: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr,\n DK: tl.constexpr, DV: tl.constexpr,\n):\n # Kernel implementation\n\nclass ParallelRetentionFunction(torch.autograd.Function):\n @staticmethod\n @custom_fwd\n def forward(ctx, q, k, v):\n BTL, BTS = 128, 32\n assert BTL % BTS == 0\n BK = min(128, triton.next_power_of_2(k.shape[-1]))\n BV = min(128, triton.next_power_of_2(v.shape[-1]))\n batch_size, n_heads, seq_len, d_head_qk = q.shape\n d_head_v = v.shape[-1]\n num_stages = 3 if d_head_qk <= 64 else 2\n num_warps = 4\n NK = triton.cdiv(d_head_qk, BK)\n NV = triton.cdiv(d_head_v, BV)\n\n grid = (NK * NV, triton.cdiv(seq_len, BTL), batch_size * n_heads)\n scale = d_head_qk ** -0.5\n o = torch.empty(NK, batch_size, n_heads, seq_len,\n d_head_v, dtype=q.dtype, device=q.device)\n parallel_retention_fwd_kernel[grid](\n q, k, v, o,\n q.stride(1), q.stride(2), q.stride(3),\n v.stride(1), v.stride(2), v.stride(3),\n batch_size, n_heads, seq_len, scale,\n BTL=BTL, BTS=BTS, BK=BK, BV=BV, DK=d_head_qk, DV=d_head_v,\n num_warps=num_warps,\n num_stages=num_stages\n )\n ctx.save_for_backward(q, k, v)\n return o.sum(0).to(q.dtype)\n\n @staticmethod\n @custom_bwd\n def backward(ctx, do):\n q, k, v = ctx.saved_tensors\n BTL, BTS = 64, 32\n assert BTL % BTS == 0\n BK = min(128, triton.next_power_of_2(k.shape[-1]))\n BV = min(128, triton.next_power_of_2(v.shape[-1]))\n batch_size, n_heads, seq_len, d_head_qk = q.shape\n d_head_v = v.shape[-1]\n num_stages = 3 if d_head_qk <= 64 else 2\n num_warps = 4\n NK = triton.cdiv(d_head_qk, BK)\n NV = triton.cdiv(d_head_v, BV)\n grid = (NK * NV, triton.cdiv(seq_len, BTL), batch_size * n_heads)\n scale = d_head_qk ** -0.5\n\n dq = torch.empty(NV, batch_size, n_heads, seq_len,\n d_head_qk, dtype=q.dtype, device=q.device)\n dk = torch.empty(NV, batch_size, n_heads, seq_len,\n d_head_qk, dtype=q.dtype, device=q.device)\n dv = torch.empty(NK, batch_size, n_heads, seq_len,\n d_head_v, dtype=q.dtype, device=q.device)\n\n parallel_retention_bwd_kernel[grid](\n q, k, v, do, dq, dk, dv,\n q.stride(1), q.stride(2), q.stride(3),\n v.stride(1), v.stride(2), v.stride(3),\n batch_size, n_heads, seq_len, scale,\n BTL=BTL, BTS=BTS, BK=BK, BV=BV, DK=d_head_qk, DV=d_head_v,\n num_warps=num_warps,\n num_stages=num_stages\n )\n\n return dq.sum(0).to(q.dtype), dk.sum(0).to(k.dtype), dv.sum(0).to(v.dtype)\n\nparallel_retention = ParallelRetentionFunction.apply\n", - "description_1": "Use triton language to implement a forward and backward pass of a parallel retention mechanism for neural networks. The forward pass (`parallel_retention_fwd_kernel`) processes query (q), key (k), and value (v) tensors with various block sizes and stores the results in the output tensor (o). The backward pass (`parallel_retention_bwd_kernel`) computes gradients with respect to q, k, and v using helper functions `_parallel_retention_bwd_dq` and `_parallel_retention_bwd_dkv`.", - "description_2": "Use triton language to implement a neural network parallel retention operation with both forward and backward passes, calculating output and gradients.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef fused_recurrent_retention_fwd_kernel(\n q, # query [B, H, L, D_head_K]\n k, # key [B, H, L, D_head_V]\n v, # value [B, H, L, D_head_V]\n o, # output [B, H, L, D_head_V]\n initial_state,\n final_state, # final hidden state [B, H, D_head_K, D_head_V]\n s_qk_h, # stride size: L * D_head_K\n s_qk_t, # stride size: D_head_K\n s_qk_d, # stride size: 1\n s_vo_h, # stride size: L * D_head_V\n s_vo_t, # stride size: D_head_V\n s_vo_d, # stride size: 1\n B, # batch size\n H, # n_heads\n T, # seq_len\n scale, # D_head_K ** -0.5\n BK: tl.constexpr, # BLOCK SIZE along the K dimension\n BV: tl.constexpr, # BLOCK SIZE along the V dimension\n DK: tl.constexpr, # D_head_K\n DV: tl.constexpr, # D_head_V\n USE_INITIAL_STATE: tl.constexpr, # whether to use initial state\n STORE_FINAL_STATE: tl.constexpr, # whether to store final state\n):\n # kernel implementation...\n\n@triton.jit\ndef fused_recurrent_retention_bwd_kernel(\n q, # query [B, H, L, D_head_K]\n k, # key [B, H, L, D_head_V]\n v, # value [B, H, L, D_head_V]\n do, # gradient of output [B, H, L, D_head_V]\n dq, # gradient of query [NV, B, H, L, D_head_K]\n dk, # gradient of key [NV, B, H, L, D_head_K]\n dv, # gradient of value [NK, B, H, L, D_head_V]\n initial_state, # initial hidden state initialization [B, H, D_head_K, D_head_V]\n s_qk_h, # stride size: L * D_head_K\n s_qk_t, # stride size: D_head_K\n s_qk_d, # stride size: 1\n s_vo_h, # stride size: L * D_head_V\n s_vo_t, # stride size: D_head_V\n s_vo_d, # stride size: 1\n B, # batch_size\n H, # n_heads\n T, # seq_len\n scale, # D_head_K ** -0.5\n BK: tl.constexpr, # BLOCK SIZE along the K dimension\n BV: tl.constexpr, # BLOCK SIZE along the V dimension\n DK: tl.constexpr, # D_head_K\n DV: tl.constexpr, # D_head_V\n USE_INITIAL_STATE: tl.constexpr, # whether to use initial state\n):\n # kernel implementation...\n\nclass FusedRecurrentRetentionFunction(torch.autograd.Function):\n\n @staticmethod\n def forward(ctx, q, k, v, initial_state=None, output_final_state=False):\n # Prepare arguments and launch the forward kernel\n batch_size, n_heads, seq_len, d_head_qk = q.shape\n d_head_v = v.shape[-1]\n scale = d_head_qk ** -0.5\n BK, BV = min(d_head_qk, 32), min(d_head_v, 32)\n NK, NV = triton.cdiv(d_head_qk, BK), triton.cdiv(d_head_v, BV)\n num_stages = 1\n num_warps = 1\n\n o = q.new_empty(NK, batch_size, n_heads, seq_len, d_head_v)\n if output_final_state:\n final_state = q.new_empty(batch_size, n_heads, d_head_qk, d_head_v)\n else:\n final_state = None\n\n grid = (NV, NK, batch_size * n_heads)\n fused_recurrent_retention_fwd_kernel[grid](\n q, k, v, o, initial_state, final_state,\n q.stride(1), q.stride(2), q.stride(3),\n v.stride(1), v.stride(2), v.stride(3),\n batch_size, n_heads, seq_len, scale,\n DK=d_head_qk, DV=d_head_v, BK=BK, BV=BV,\n num_warps=num_warps,\n num_stages=num_stages,\n USE_INITIAL_STATE=initial_state is not None,\n STORE_FINAL_STATE=final_state is not None\n )\n\n o = o.sum(0)\n ctx.save_for_backward(q, k, v, initial_state)\n return o, final_state\n\n @staticmethod\n def backward(ctx, do, d_final_state=None):\n q, k, v, initial_state = ctx.saved_tensors\n batch_size, n_heads, seq_len, d_head_qk = q.shape\n d_head_v = v.shape[-1]\n scale = d_head_qk ** -0.5\n\n BK, BV = min(d_head_qk, 32), min(d_head_v, 32)\n NK, NV = triton.cdiv(d_head_qk, BK), triton.cdiv(d_head_v, BV)\n num_stages = 1\n num_warps = 1\n\n dq = q.new_empty(NV, batch_size, n_heads, seq_len, d_head_qk)\n dk = q.new_empty(NV, batch_size, n_heads, seq_len, d_head_qk)\n dv = q.new_empty(NK, batch_size, n_heads, seq_len, d_head_v)\n grid = (NV, NK, batch_size * n_heads)\n\n fused_recurrent_retention_bwd_kernel[grid](\n q, k, v, do, dq, dk, dv, initial_state,\n q.stride(1), q.stride(2), q.stride(3),\n v.stride(1), v.stride(2), v.stride(3),\n batch_size, n_heads, seq_len, scale,\n DK=d_head_qk, DV=d_head_v, BK=BK, BV=BV,\n num_warps=num_warps,\n num_stages=num_stages,\n USE_INITIAL_STATE=initial_state is not None\n )\n dq = dq.sum(0)\n dk = dk.sum(0)\n dv = dv.sum(0)\n return dq, dk, dv, None, None\n\ndef fused_recurrent_retention(q: torch.Tensor,\n k: torch.Tensor,\n v: torch.Tensor,\n initial_state: torch.Tensor = None,\n output_final_state: bool = False):\n if initial_state is not None:\n initial_state = initial_state.detach()\n o, final_state = FusedRecurrentRetentionFunction.apply(\n q, k, v, initial_state, output_final_state)\n if output_final_state:\n return o, final_state\n else:\n return o\n", - "description_1": "Use triton language to define two kernels, fused_recurrent_retention_fwd_kernel and fused_recurrent_retention_bwd_kernel. These kernels perform the forward and backward operations for a custom recurrent retention mechanism. The forward kernel computes the retention using queries, keys, and values, while maintaining a final hidden state. The backward kernel calculates gradients for the queries, keys, and values. These kernels handle various dimensions such as batch size, number of heads, sequence length, and dimensions of the key and value embeddings. The functions also support optional initial and final state handling.", - "description_2": "Use triton language to define forward and backward kernels for a recurrent retention mechanism handling queries, keys, and values, while managing states.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef rotary_kernel(\n OUT, # Pointers to matrices\n X,\n COS,\n SIN,\n CU_SEQLENS,\n SEQLEN_OFFSETS, # this could be int or a pointer\n # Matrix dimensions\n seqlen,\n nheads,\n rotary_dim,\n seqlen_ro,\n CACHE_KEY_SEQLEN,\n # strides\n stride_out_batch,\n stride_out_seqlen,\n stride_out_nheads,\n stride_out_headdim,\n stride_x_batch,\n stride_x_seqlen,\n stride_x_nheads,\n stride_x_headdim,\n # Meta-parameters\n BLOCK_K: tl.constexpr,\n IS_SEQLEN_OFFSETS_TENSOR: tl.constexpr,\n IS_VARLEN: tl.constexpr,\n INTERLEAVED: tl.constexpr,\n CONJUGATE: tl.constexpr,\n BLOCK_M: tl.constexpr,\n):\n pid_m = tl.program_id(axis=0)\n pid_batch = tl.program_id(axis=1)\n pid_head = tl.program_id(axis=2)\n rotary_dim_half = rotary_dim // 2\n\n if not IS_VARLEN:\n X = X + pid_batch * stride_x_batch + pid_head * stride_x_nheads\n OUT = OUT + pid_batch * stride_out_batch + pid_head * stride_out_nheads\n else:\n start_idx = tl.load(CU_SEQLENS + pid_batch)\n seqlen = tl.load(CU_SEQLENS + pid_batch + 1) - start_idx\n X = X + start_idx * stride_x_seqlen + pid_head * stride_x_nheads\n OUT = OUT + start_idx * stride_out_seqlen + pid_head * stride_out_nheads\n\n if pid_m * BLOCK_M >= seqlen:\n return\n rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)\n if not IS_SEQLEN_OFFSETS_TENSOR:\n rm_cs = rm + SEQLEN_OFFSETS\n else:\n rm_cs = rm + tl.load(SEQLEN_OFFSETS + pid_batch)\n rk = tl.arange(0, BLOCK_K)\n rk_half = tl.arange(0, BLOCK_K // 2)\n\n if not INTERLEAVED:\n # Load the 1st and 2nd halves of X, do calculation, then store to 1st and 2nd halves of OUT\n X = X + (rm[:, None] * stride_x_seqlen +\n rk_half[None, :] * stride_x_headdim)\n COS = COS + (rm_cs[:, None] * rotary_dim_half + rk_half[None, :])\n SIN = SIN + (rm_cs[:, None] * rotary_dim_half + rk_half[None, :])\n cos = tl.load(\n COS, mask=(rm_cs[:, None] < seqlen_ro) & (rk_half[None, :] < rotary_dim_half), other=1.0\n ).to(tl.float32)\n sin = tl.load(\n SIN, mask=(rm_cs[:, None] < seqlen_ro) & (rk_half[None, :] < rotary_dim_half), other=0.0\n ).to(tl.float32)\n x0 = tl.load(\n X, mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half), other=0.0\n ).to(tl.float32)\n x1 = tl.load(\n X + rotary_dim_half * stride_x_headdim,\n mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half),\n other=0.0,\n ).to(tl.float32)\n if CONJUGATE:\n sin = -sin\n o0 = x0 * cos - x1 * sin\n o1 = x0 * sin + x1 * cos\n # write back result\n OUT = OUT + (rm[:, None] * stride_out_seqlen +\n rk_half[None, :] * stride_out_headdim)\n tl.store(OUT, o0, mask=(rm[:, None] < seqlen)\n & (rk_half[None, :] < rotary_dim_half))\n tl.store(\n OUT + rotary_dim_half * stride_out_headdim,\n o1,\n mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half),\n )\n else:\n # We don't want to load X[0, 2, 4, ...] and X[1, 3, 5, ...] separately since both are slow.\n # Instead, we load x0 = X[0, 1, 2, 3, ...] and x1 = X[1, 0, 3, 2, ...].\n # Loading x0 will be fast but x1 will be slow.\n # Then we load cos = COS[0, 0, 1, 1, ...] and sin = SIN[0, 0, 1, 1, ...].\n # Then we do the calculation and use tl.where to pick put the right outputs for the even\n # and for the odd indices.\n rk_swap = rk + ((rk + 1) % 2) * 2 - 1 # 1, 0, 3, 2, 5, 4, ...\n rk_repeat = tl.arange(0, BLOCK_K) // 2\n X0 = X + (rm[:, None] * stride_x_seqlen +\n rk[None, :] * stride_x_headdim)\n X1 = X + (rm[:, None] * stride_x_seqlen +\n rk_swap[None, :] * stride_x_headdim)\n COS = COS + (rm_cs[:, None] * rotary_dim_half + rk_repeat[None, :])\n SIN = SIN + (rm_cs[:, None] * rotary_dim_half + rk_repeat[None, :])\n cos = tl.load(\n COS,\n mask=(rm_cs[:, None] < seqlen_ro) & (\n rk_repeat[None, :] < rotary_dim_half),\n other=1.0,\n ).to(tl.float32)\n sin = tl.load(\n SIN,\n mask=(rm_cs[:, None] < seqlen_ro) & (\n rk_repeat[None, :] < rotary_dim_half),\n other=0.0,\n ).to(tl.float32)\n x0 = tl.load(X0, mask=(rm[:, None] < seqlen) & (rk[None, :] < rotary_dim), other=0.0).to(\n tl.float32\n )\n x1 = tl.load(\n X1, mask=(rm[:, None] < seqlen) & (rk_swap[None, :] < rotary_dim), other=0.0\n ).to(tl.float32)\n if CONJUGATE:\n sin = -sin\n x0_cos = x0 * cos\n x1_sin = x1 * sin\n out = tl.where(rk[None, :] % 2 == 0, x0_cos - x1_sin, x0_cos + x1_sin)\n OUT = OUT + (rm[:, None] * stride_out_seqlen +\n rk[None, :] * stride_out_headdim)\n tl.store(OUT, out, mask=(rm[:, None] < seqlen)\n & (rk[None, :] < rotary_dim))\n\n\ndef apply_rotary(\n x: torch.Tensor,\n cos: torch.Tensor,\n sin: torch.Tensor,\n seqlen_offsets: Union[int, torch.Tensor] = 0,\n cu_seqlens: Optional[torch.Tensor] = None,\n max_seqlen: Optional[int] = None,\n interleaved=False,\n inplace=False,\n conjugate=False,\n) -> torch.Tensor:\n \"\"\"\n Arguments:\n x: (batch, seqlen, nheads, headdim) if cu_seqlens is None\n else (total_seqlen, nheads, headdim).\n cos: (seqlen_ro, rotary_dim / 2)\n sin: (seqlen_ro, rotary_dim / 2)\n seqlen_offsets: integer or integer tensor of size (batch,)\n cu_seqlens: (batch + 1,) or None\n max_seqlen: int\n Returns:\n y: (batch, seqlen, nheads, headdim)\n \"\"\"\n is_varlen = cu_seqlens is not None\n if not is_varlen:\n batch, seqlen, nheads, headdim = x.shape\n else:\n assert max_seqlen is not None, \"If cu_seqlens is passed in, then max_seqlen must be passed\"\n total_seqlen, nheads, headdim = x.shape\n batch_p_1 = cu_seqlens.shape[0]\n batch = batch_p_1 - 1\n seqlen = max_seqlen\n seqlen_ro, rotary_dim = cos.shape\n assert sin.shape == cos.shape\n rotary_dim *= 2\n assert rotary_dim <= headdim, \"rotary_dim must be <= headdim\"\n assert headdim <= 256, \"Only support headdim <= 256\"\n assert seqlen_ro >= seqlen, \"seqlen_ro must be >= seqlen\"\n\n assert (\n cos.dtype == sin.dtype\n ), f\"cos and sin must have the same dtype, got {cos.dtype} and {sin.dtype}\"\n assert (\n x.dtype == cos.dtype\n ), f\"Input and cos/sin must have the same dtype, got {x.dtype} and {cos.dtype}\"\n\n cos, sin = cos.contiguous(), sin.contiguous()\n if isinstance(seqlen_offsets, torch.Tensor):\n assert seqlen_offsets.shape == (batch,)\n assert seqlen_offsets.dtype in [torch.int32, torch.int64]\n seqlen_offsets = seqlen_offsets.contiguous()\n else:\n assert seqlen_offsets + seqlen <= seqlen_ro\n\n output = torch.empty_like(x) if not inplace else x\n if rotary_dim < headdim and not inplace:\n output[..., rotary_dim:].copy_(x[..., rotary_dim:])\n\n BLOCK_K = (\n 32\n if rotary_dim <= 32\n else (64 if rotary_dim <= 64 else (128 if rotary_dim <= 128 else 256))\n )\n def grid(META): return (triton.cdiv(seqlen, META[\"BLOCK_M\"]), batch, nheads) # noqa\n BLOCK_M = 4 if interleaved else (8 if rotary_dim <= 64 else 4)\n\n # Need this, otherwise Triton tries to launch from cuda:0 and we get\n # ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?)\n with torch.cuda.device(x.device.index):\n rotary_kernel[grid](\n output, # data ptrs\n x,\n cos,\n sin,\n cu_seqlens,\n seqlen_offsets,\n seqlen, # shapes\n nheads,\n rotary_dim,\n seqlen_ro,\n # key for triton cache (limit number of compilations)\n seqlen // 128,\n # batch_strides if not varlen else 0\n output.stride(0) if not is_varlen else 0,\n output.stride(-3), # seqlen_stride or total_seqlen_stride\n output.stride(-2), # nheads_stride\n output.stride(-1), # headdim_stride\n # batch_strides if not varlen else 0\n x.stride(0) if not is_varlen else 0,\n x.stride(-3), # seqlen stride or total_seqlen_stride\n x.stride(-2), # nheads stride\n x.stride(-1), # headdim stride\n BLOCK_K,\n isinstance(seqlen_offsets, torch.Tensor),\n is_varlen,\n interleaved,\n conjugate,\n BLOCK_M,\n )\n return output\n", - "description_1": "Use triton language to implement a kernel 'rotary_kernel' with 31 parameters. The kernel performs rotary positional encoding on input matrices with custom strides and meta-parameters. The associated function 'apply_rotary' takes 9 arguments and calls the kernel to execute this computation on a given set of tensors and parameters, returning the modified tensor.", - "description_2": "Use triton language to implement rotary positional encoding using a kernel with customizable dimensions and parameters, and call it using an associated Python function.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n\n@triton.jit\ndef chunk_simple_gla_fwd_kernel_h(\n k, v, h, g, initial_state, final_state,\n s_qk_h, s_qk_t, s_qk_d, s_vo_h, s_vo_t, s_vo_d, s_h_h, s_h_t,\n H: tl.constexpr, T: tl.constexpr, K: tl.constexpr, V: tl.constexpr, BT: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr, NT: tl.constexpr,\n USE_INITIAL_STATE: tl.constexpr, STORE_FINAL_STATE: tl.constexpr\n):\n i_k, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n b_h = tl.zeros([BK, BV], dtype=tl.float32)\n if USE_INITIAL_STATE:\n p_h0 = tl.make_block_ptr(initial_state + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))\n b_h = tl.load(p_h0, boundary_check=(0, 1)).to(tl.float32)\n for i_t in range(NT):\n p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1))\n p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))\n p_h = tl.make_block_ptr(h + i_bh * s_h_h + i_t * K * V, (K, V), (s_h_t, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))\n tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1))\n b_k = tl.load(p_k, boundary_check=(0, 1))\n b_v = tl.load(p_v, boundary_check=(0, 1))\n b_g_last = tl.load(g + i_bh * T + i_t * BT + BT - 1)\n b_h *= tl.math.exp2(b_g_last)\n b_g = tl.load(g + i_bh * T + i_t * BT + tl.arange(0, BT))\n b_h += tl.dot(b_k, (b_v * tl.math.exp2(b_g_last - b_g)[:, None]).to(b_k.dtype), allow_tf32=False)\n if STORE_FINAL_STATE:\n p_ht = tl.make_block_ptr(final_state + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))\n tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), boundary_check=(0, 1))\n\n\n@triton.jit\ndef chunk_simple_gla_fwd_kernel_o(\n q, k, v, h, g, o, s_qk_h, s_qk_t, s_qk_d, s_vo_h, s_vo_t, s_vo_d, s_h_h, s_h_t, scale,\n H: tl.constexpr, T: tl.constexpr, K: tl.constexpr, V: tl.constexpr, BT: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr\n):\n i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n o_i = tl.arange(0, BT)\n m_s = o_i[:, None] >= o_i[None, :]\n b_o = tl.zeros([BT, BV], dtype=tl.float32)\n b_s = tl.zeros([BT, BT], dtype=tl.float32)\n for i_k in range(tl.cdiv(K, BK)):\n p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))\n p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1))\n p_h = tl.make_block_ptr(h + i_bh * s_h_h + i_t * K * V, (K, V), (s_h_t, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))\n b_q = tl.load(p_q, boundary_check=(0, 1))\n b_k = tl.load(p_k, boundary_check=(0, 1))\n b_h = tl.load(p_h, boundary_check=(0, 1))\n b_o += tl.dot(b_q, b_h, allow_tf32=False)\n b_s += tl.dot(b_q, b_k, allow_tf32=False)\n p_g = g + i_bh * T + i_t * BT + tl.arange(0, BT)\n b_g = tl.load(p_g)\n b_o = b_o * tl.math.exp2(b_g)[:, None]\n b_s = b_s * tl.math.exp2(b_g[:, None] - b_g[None, :])\n b_s = tl.where(m_s, b_s, 0)\n p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))\n b_v = tl.load(p_v, boundary_check=(0, 1))\n b_o = (b_o + tl.dot(b_s.to(b_v.dtype), b_v, allow_tf32=False)) * scale\n p_o = tl.make_block_ptr(o + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))\n tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))\n\n\n@triton.jit\ndef chunk_simple_gla_bwd_kernel_dh(\n q, g, do, dh, s_qk_h, s_qk_t, s_qk_d, s_vo_h, s_vo_t, s_vo_d, s_h_h, s_h_t, scale,\n H: tl.constexpr, T: tl.constexpr, K: tl.constexpr, V: tl.constexpr, BT: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr, NT: tl.constexpr\n):\n i_k, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n b_dh = tl.zeros([BK, BV], dtype=tl.float32)\n for i_t in range(NT - 1, -1, -1):\n p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1))\n p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))\n p_dh = tl.make_block_ptr(dh + i_bh * s_h_h + i_t * K * V, (K, V), (s_h_t, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))\n tl.store(p_dh, b_dh.to(p_dh.dtype.element_ty), boundary_check=(0, 1))\n b_q = tl.load(p_q, boundary_check=(0, 1))\n b_q = (b_q * scale * tl.math.exp2(tl.load(g + i_bh * T + i_t * BT + tl.arange(0, BT)))[None, :]).to(b_q.dtype)\n b_do = tl.load(p_do, boundary_check=(0, 1))\n b_dh *= tl.math.exp2(tl.load(g + i_bh * T + i_t * BT + BT - 1))\n b_dh += tl.dot(b_q, b_do.to(b_q.dtype), allow_tf32=False)\n\n\n@triton.jit\ndef chunk_simple_gla_bwd_kernel_dqkv(\n q, k, v, h, g, do, dh, dq, dk, dv, s_qk_h, s_qk_t, s_qk_d, s_vo_h, s_vo_t, s_vo_d, s_h_h, s_h_t, scale,\n B: tl.constexpr, H: tl.constexpr, T: tl.constexpr, K: tl.constexpr, V: tl.constexpr, BT: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr, NT: tl.constexpr\n):\n i_k, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n n_bh = tl.num_programs(2)\n o_i = tl.arange(0, BT)\n p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1))\n p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))\n b_q = tl.load(p_q, boundary_check=(0, 1))\n b_k = tl.load(p_k, boundary_check=(0, 1))\n b_s = tl.dot(b_k, b_q, allow_tf32=False)\n p_g = g + i_bh * T + i_t * BT + tl.arange(0, BT)\n b_g = tl.load(p_g)\n b_g_last = tl.load(g + i_bh * T + i_t * BT + BT - 1)\n mask = tl.math.exp2(b_g[None, :] - b_g[:, None])\n mask = tl.where(o_i[:, None] <= o_i[None, :], mask * scale, 0)\n b_s = b_s * mask\n b_dq = tl.zeros([BT, BK], dtype=tl.float32)\n b_dk = tl.zeros([BT, BK], dtype=tl.float32)\n b_ds = tl.zeros([BT, BT], dtype=tl.float32)\n for i_v in range(tl.cdiv(V, BV)):\n p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))\n p_h = tl.make_block_ptr(h + i_bh * s_h_h, (V, NT * K), (1, s_h_t), (i_v * BV, i_t * K + i_k * BK), (BV, BK), (0, 1))\n p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))\n p_dh = tl.make_block_ptr(dh + i_bh * s_h_h, (NT * K, V), (s_h_t, 1), (i_t * K + i_k * BK, i_v * BV), (BK, BV), (1, 0))\n p_dv = tl.make_block_ptr(dv + (i_k*n_bh+i_bh)*s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))\n b_v = tl.load(p_v, boundary_check=(0, 1))\n b_do = tl.load(p_do, boundary_check=(0, 1))\n b_h = tl.load(p_h, boundary_check=(0, 1))\n b_dh = tl.load(p_dh, boundary_check=(0, 1))\n b_ds += tl.dot(b_do, tl.trans(b_v), allow_tf32=False)\n b_dq += tl.dot(b_do, b_h, allow_tf32=False) * scale\n b_dk += tl.dot(b_v, tl.trans(b_dh), allow_tf32=False)\n b_dv = tl.dot(b_k, b_dh, allow_tf32=False) + tl.dot(b_s.to(b_q.dtype), b_do, allow_tf32=False)\n tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))\n b_dq = b_dq * tl.math.exp2(b_g)[:, None]\n b_dk = b_dk * tl.math.exp2(-b_g + b_g_last)[:, None]\n b_ds = b_ds * tl.trans(mask)\n b_ds = b_ds.to(b_k.dtype)\n b_dq += tl.dot(b_ds, b_k, allow_tf32=False)\n b_dk += tl.trans(tl.dot(b_q, b_ds, allow_tf32=False))\n p_dq = tl.make_block_ptr(dq + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))\n p_dk = tl.make_block_ptr(dk + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))\n tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))\n tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))\n\n\nclass SimpleGLAFunction(torch.autograd.Function):\n @staticmethod\n def forward(ctx, q, k, v, g, initial_state, output_final_state):\n B, H, T, K, V = *q.shape, v.shape[-1]\n BT = 64\n BK, BV = min(64, triton.next_power_of_2(K)), min(64, triton.next_power_of_2(V))\n NT, NK, NV = triton.cdiv(T, BT), triton.cdiv(K, BK), triton.cdiv(V, BV)\n num_stages = 1\n num_warps = 4 if BK == 64 else 2\n scale = K ** -0.5\n\n BT = 64\n g = g.reshape(B, H, -1, BT)\n g = g.cumsum(-1) * 1.44269504\n g = g.reshape(B, H, -1)\n\n final_state = None\n if output_final_state:\n final_state = q.new_empty(B, H, K, V, dtype=torch.float32, requires_grad=False)\n\n h = q.new_empty(B, H, NT * K, V)\n grid = (NK, NV, B * H)\n chunk_simple_gla_fwd_kernel_h[grid](\n k, v, h, g, initial_state, final_state, q.stride(1), q.stride(2), q.stride(3),\n v.stride(1), v.stride(2), v.stride(3), h.stride(1), h.stride(2), H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT,\n USE_INITIAL_STATE=initial_state is not None, STORE_FINAL_STATE=output_final_state,\n num_warps=num_warps, num_stages=num_stages\n )\n grid = (NV, NT, B * H)\n o = torch.empty_like(v)\n chunk_simple_gla_fwd_kernel_o[grid](\n q, k, v, h, g, o, q.stride(1), q.stride(2), q.stride(3), v.stride(1), v.stride(2), v.stride(3),\n h.stride(1), h.stride(2), scale, H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, num_warps=num_warps, num_stages=num_stages\n )\n ctx.save_for_backward(q, k, v, h, g)\n return o.to(q.dtype), final_state\n\n @staticmethod\n def backward(ctx, do, d_ht=None):\n q, k, v, h, g = ctx.saved_tensors\n\n B, H, T, K, V = *q.shape, v.shape[-1]\n BT = 64\n BK, BV = min(32 if q.dtype == torch.float32 else 64, triton.next_power_of_2(K)), min(\n 32 if q.dtype == torch.float32 else 64, triton.next_power_of_2(V))\n NT, NK, NV = triton.cdiv(T, BT), triton.cdiv(K, BK), triton.cdiv(V, BV)\n num_stages = 1\n num_warps = 4 if BK == 64 else 2\n scale = K ** -0.5\n\n dh = q.new_empty(B, H, NT * K, V)\n grid = (NK, NV, B * H)\n chunk_simple_gla_bwd_kernel_dh[grid](\n q, g, do, dh, q.stride(1), q.stride(2), q.stride(3),\n v.stride(1), v.stride(2), v.stride(3), dh.stride(1), dh.stride(2),\n scale, H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT, num_warps=num_warps, num_stages=num_stages\n )\n grid = (NK, NT, B * H)\n dq = torch.empty_like(q)\n dk = torch.empty_like(k)\n dv = v.new_empty(NK, *v.shape)\n num_stages = 1\n num_warps = 4 if BK == 64 else 2\n chunk_simple_gla_bwd_kernel_dqkv[grid](\n q, k, v, h, g, do, dh, dq, dk, dv, q.stride(1), q.stride(2), q.stride(3),\n v.stride(1), v.stride(2), v.stride(3), dh.stride(1), dh.stride(2),\n scale, B=B, H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT, num_warps=num_warps, num_stages=num_stages\n )\n dv = dv.sum(0)\n dg = (dq * q - dk * k).sum(-1)\n\n def rev_cumsum(x):\n cumsum_x = x.cumsum(-1)\n rev_cumsum_x = cumsum_x[..., -1, None] - cumsum_x\n return rev_cumsum_x + x\n dg = rev_cumsum(dg)\n return dq.to(q.dtype), dk.to(k.dtype), dv.to(v.dtype), dg.to(g.dtype), None, None\n\n\ndef chunk_simple_gla(\n q: torch.Tensor,\n k: torch.Tensor,\n v: torch.Tensor,\n g: torch.Tensor, # log decay\n initial_state: torch.Tensor = None,\n output_final_state: bool = False\n):\n if initial_state is not None:\n initial_state = initial_state.detach()\n g = g.float()\n B, H, T = g.shape\n\n o, final_state = SimpleGLAFunction.apply(\n q, k, v, g, initial_state, output_final_state\n )\n\n if output_final_state:\n return o, final_state\n else:\n return o\n", - "description_1": "Use triton language to define kernels for forward and backward passes of a chunk-based simple GLA (Generalized Linear Algebra) operation. Forward pass kernels compute transformed outputs and optional final states, while backward pass kernels calculate gradients with respect to the inputs.", - "description_2": "Use triton language to implement kernels that perform tensor transformations and compute gradients for a specific GLA operation.", - "difficulty": 4 - }, - { - "code": "import triton\nimport triton.language as tl\nimport torch\n\n# Triton kernel to perform a multi-layer perceptron computation.\n@triton.jit\ndef mlp_kernel(X, W1, B1, W2, B2, Y, M: tl.constexpr, N: tl.constexpr, K: tl.constexpr, BLOCK_SIZE: tl.constexpr):\n pid = tl.program_id(0)\n block_start = pid * BLOCK_SIZE\n\n # Load input\n x = tl.load(X + block_start * K + tl.arange(0, BLOCK_SIZE)[:, None] * K + tl.arange(0, K)[None, :])\n\n # First layer: X @ W1 + B1\n w1 = tl.load(W1 + tl.arange(0, K)[:, None] * M + tl.arange(0, M)[None, :])\n b1 = tl.load(B1 + tl.arange(0, M))\n y1 = tl.dot(x, w1) + b1\n\n # Activation function (ReLU)\n y1 = tl.where(y1 > 0, y1, 0)\n\n # Second layer: y1 @ W2 + B2\n w2 = tl.load(W2 + tl.arange(0, M)[:, None] * N + tl.arange(0, N)[None, :])\n b2 = tl.load(B2 + tl.arange(0, N))\n y2 = tl.dot(y1, w2) + b2\n\n # Store output\n tl.store(Y + block_start * N + tl.arange(0, BLOCK_SIZE)[:, None] * N + tl.arange(0, N)[None, :], y2)\n\n# Function to launch the kernel\ndef mlp(X, W1, B1, W2, B2):\n Y = torch.empty((BLOCK_SIZE, N), dtype=torch.float32)\n\n # Launch kernel\n grid = lambda meta: (triton.cdiv(N, BLOCK_SIZE), )\n\n mlp_kernel[grid](X, W1, B1, W2, B2, Y, M, N, K, BLOCK_SIZE)\n", - "description_1": "Use triton language to implement a multi-layer perceptron (MLP) kernel, which consists of two matrix multiplications with ReLU activation. The mlp_kernel function has 10 parameters: X, W1, B1, W2, B2, Y are pointers for input/output tensors, and M, N, K, BLOCK_SIZE are constants defining tensor dimensions and execution parameters. The mlp function launches this kernel with these parameters.", - "description_2": "Use triton language to create a kernel that performs two matrix multiplications with an activation function. Utilize triton's parallel execution capabilities with specified block sizes to optimize performance.", - "difficulty": 3 - }, - { - "code": "import triton\nimport triton.language as tl\nimport torch\n\n@triton.jit\ndef softmax_kernel(x_ptr, out_ptr, N: tl.constexpr, block_size: tl.constexpr):\n # Get the index of the current thread\n pid = tl.program_id(0)\n block_start = pid * block_size\n offsets = block_start + tl.arange(0, block_size)\n mask = offsets < N\n\n # Load elements from global memory\n x = tl.load(x_ptr + offsets, mask=mask)\n\n # Compute linear layer\n exp_values = tl.exp(x - tl.max(x))\n probabilities = exp_values / tl.sum(exp_values)\n result = probabilities\n\n # Write result to global memory\n if pid == 0:\n tl.store(out_ptr + offsets, result, mask=mask)\n\ndef softmax(x):\n # Prepare output tensor\n out = torch.empty_like(x, dtype=torch.float32, device=x.device)\n N = out.numel()\n\n BLOCK_SIZE = 1024\n num_blocks = (N + BLOCK_SIZE - 1) // BLOCK_SIZE # Calculate the number of blocks needed\n \n # Launch Triton kernel\n grid = lambda meta: (triton.cdiv(N, BLOCK_SIZE), )\n\n softmax_kernel[grid](x, out, N, BLOCK_SIZE)\n \n return out\n", - "description_1": "Use triton language to implement a softmax operation. The kernel function 'softmax_kernel' takes four parameters: x_ptr (pointer to input tensor), out_ptr (pointer to output tensor), N (total number of elements), and block_size (size of each block). It computes the softmax of the input tensor in a block-wise manner. The 'softmax' function prepares the output tensor, calculates the number of blocks, and launches the Triton kernel.", - "description_2": "Use triton language to implement a block-wise softmax operation on a 1D tensor using a kernel function and a wrapper function to manage tensor preparation and kernel launch.", - "difficulty": 2 - }, - { - "code": "import triton\nimport triton.language as tl\nimport torch\n\n@triton.jit\ndef dot_product_kernel(x_ptr, y_ptr, out_ptr, N: tl.constexpr, block_size: tl.constexpr):\n # Get the index of the current thread\n pid = tl.program_id(0)\n block_start = pid * block_size\n offsets = block_start + tl.arange(0, block_size)\n mask = offsets < N\n\n # Load elements from global memory\n x = tl.load(x_ptr + offsets, mask=mask)\n y = tl.load(y_ptr + offsets, mask=mask)\n\n # Compute dot product\n result = tl.sum(x * y, axis=0)\n\n # Write result to global memory\n if pid == 0:\n tl.store(out_ptr, result)\n\ndef dot_product(x, y):\n # Ensure x and y are 1D tensors\n if x.dim() != 1 or y.dim() != 1:\n raise ValueError(\"Both input tensors must be 1-dimensional\")\n \n if x.size(0) != y.size(0):\n raise ValueError(\"Input tensors must be of the same size\")\n\n N = next_power_of_2(x.size(0))\n block_size = 1024\n\n # Prepare output tensor\n out = torch.empty((), dtype=torch.float32, device=x.device)\n \n # Launch Triton kernel\n grid = (1,)\n dot_product_kernel[grid](x, y, out, N, block_size)\n \n return out.item()\n", - "description_1": "Use triton language to implement a dot product kernel. The kernel function 'dot_product_kernel' takes five parameters: x_ptr (pointer to the first input tensor), y_ptr (pointer to the second input tensor), out_ptr (pointer to the output tensor), N (size of the input tensors, as a compile-time constant), and block_size (size of the block, as a compile-time constant). It computes the dot product of two 1D tensors and stores the result in the output tensor. The 'dot_product' function is a wrapper that checks input validity, prepares the output tensor, and launches the Triton kernel.", - "description_2": "Use triton language to create a kernel for computing the dot product of two 1D tensors. Implement a wrapper function to validate inputs and execute the kernel.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\nfrom triton.runtime import driver\n\n@triton.jit\ndef softmax_kernel(output_ptr, input_ptr, input_row_stride, output_row_stride, n_rows, n_cols, BLOCK_SIZE: tl.constexpr,\n num_stages: tl.constexpr, num_warps: tl.constexpr):\n # starting row of the program\n row_start = tl.program_id(0)\n row_step = tl.num_programs(0)\n for row_idx in tl.range(row_start, n_rows, row_step, num_stages=num_stages):\n # The stride represents how much we need to increase the pointer to advance 1 row\n row_start_ptr = input_ptr + row_idx * input_row_stride\n # The block size is the next power of two greater than n_cols, so we can fit each\n # row in a single block\n col_offsets = tl.arange(0, BLOCK_SIZE)\n input_ptrs = row_start_ptr + col_offsets\n # Load the row into SRAM, using a mask since BLOCK_SIZE may be > than n_cols\n mask = col_offsets < n_cols\n row = tl.load(input_ptrs, mask=mask, other=-float('inf'))\n # Subtract maximum for numerical stability\n row_minus_max = row - tl.max(row, axis=0)\n # Note that exponentiation in Triton is fast but approximate (i.e., think __expf in CUDA)\n numerator = tl.exp(row_minus_max)\n denominator = tl.sum(numerator, axis=0)\n softmax_output = numerator / denominator\n # Write back output to DRAM\n output_row_start_ptr = output_ptr + row_idx * output_row_stride\n output_ptrs = output_row_start_ptr + col_offsets\n tl.store(output_ptrs, softmax_output, mask=mask)\n\ndevice = torch.cuda.current_device()\nproperties = driver.active.utils.get_device_properties(device)\nNUM_SM = properties[\"multiprocessor_count\"]\nNUM_REGS = properties[\"max_num_regs\"]\nSIZE_SMEM = properties[\"max_shared_mem\"]\nWARP_SIZE = properties[\"warpSize\"]\ntarget = triton.runtime.driver.active.get_current_target()\nkernels = {}\n\ndef softmax(x):\n n_rows, n_cols = x.shape\n\n # The block size of each loop iteration is the smallest power of two greater than the number of columns in `x`\n BLOCK_SIZE = triton.next_power_of_2(n_cols)\n\n # Another trick we can use is to ask the compiler to use more threads per row by\n # increasing the number of warps (`num_warps`) over which each row is distributed.\n # You will see in the next tutorial how to auto-tune this value in a more natural\n # way so you don't have to come up with manual heuristics yourself.\n num_warps = 8\n\n # Number of software piepling stages.\n num_stages = 4 if SIZE_SMEM > 200000 else 2\n\n # Allocate output\n y = torch.empty_like(x)\n\n # pre-compile kernel to get register usage and compute thread occupancy.\n kernel, num_programs = kernels.get(BLOCK_SIZE, (None, 0))\n if kernel is None:\n grid=(1, )\n kernel = softmax_kernel[grid](y, x, x.stride(0), y.stride(0), n_rows, n_cols, BLOCK_SIZE, num_stages, num_warps)\n kernel._init_handles()\n n_regs = kernel.n_regs\n size_smem = kernel.metadata.shared\n if is_hip():\n # NUM_REGS represents the number of regular purpose registers. On CDNA architectures this is half of all registers available.\n # However, this is not always the case. In most cases all registers can be used as regular purpose registers.\n # ISA SECTION (3.6.4 for CDNA3)\n # VGPRs are allocated out of two pools: regular VGPRs and accumulation VGPRs. Accumulation VGPRs are used\n # with matrix VALU instructions, and can also be loaded directly from memory. A wave may have up to 512 total\n # VGPRs, 256 of each type. When a wave has fewer than 512 total VGPRs, the number of each type is flexible - it is\n # not required to be equal numbers of both types.\n if is_cdna():\n NUM_GPRS = NUM_REGS * 2\n\n # MAX_NUM_THREADS represents maximum number of resident threads per multi-processor.\n # When we divide this number with WARP_SIZE we get maximum number of waves that can\n # execute on a CU (multi-processor) in parallel.\n MAX_NUM_THREADS = properties[\"max_threads_per_sm\"]\n max_num_waves = MAX_NUM_THREADS // WARP_SIZE\n occupancy = min(NUM_GPRS // WARP_SIZE // n_regs, max_num_waves) // num_warps\n else:\n occupancy = NUM_REGS // (n_regs * WARP_SIZE * num_warps)\n occupancy = min(occupancy, SIZE_SMEM // size_smem)\n num_programs = NUM_SM * occupancy\n kernels[BLOCK_SIZE] = (kernel, num_programs)\n\n num_programs = min(num_programs, n_rows)\n\n # Create a number of persistent programs.\n kernel[(num_programs, 1, 1)](\n y,\n x,\n x.stride(0),\n y.stride(0),\n n_rows,\n n_cols,\n )\n return y\n", - "description_1": "Use triton language to implement a softmax operation on a 2D tensor. The kernel function 'softmax_kernel' takes 8 parameters: output_ptr (output tensor pointer), input_ptr (input tensor pointer), input_row_stride (stride of input rows), output_row_stride (stride of output rows), n_rows (number of rows), n_cols (number of columns), BLOCK_SIZE (block size for processing), num_stages (number of software pipeline stages), and num_warps (number of warps for parallel processing). The function normalizes each row of the input tensor and writes the result to the output tensor. The 'softmax' function prepares the input tensor, sets up kernel parameters, and launches the kernel.", - "description_2": "Use triton language to create a softmax kernel that normalizes rows of a 2D tensor using parallel processing. Implement a function to launch this kernel with appropriate parameters.", - "difficulty": 3 - }, - { - "code": "import triton\nimport triton.language as tl\nimport torch\n\n@triton.jit\ndef linear_kernel(x_ptr, y_ptr, bias_ptr, out_ptr, N: tl.constexpr, block_size: tl.constexpr):\n # Get the index of the current thread\n pid = tl.program_id(0)\n block_start = pid * block_size\n offsets = block_start + tl.arange(0, block_size)\n mask = offsets < N\n\n # Load elements from global memory\n x = tl.load(x_ptr + offsets, mask=mask)\n y = tl.load(y_ptr + offsets, mask=mask)\n bias = tl.load(bias_ptr)\n\n # Compute linear layer\n result = tl.sum(x * y, axis=0) + bias\n\n # Write result to global memory\n if pid == 0:\n tl.store(out_ptr, result)\n\ndef linear(x, y, bias):\n # Ensure x and y are 1D tensors\n if x.dim() != 1 or y.dim() != 1:\n raise ValueError(\"Both input tensors must be 1-dimensional\")\n \n if x.size(0) != y.size(0):\n raise ValueError(\"Input tensors must be of the same size\")\n\n N = next_power_of_2(x.size(0))\n block_size = 1024\n\n # Prepare output tensor\n out = torch.empty((), dtype=torch.float32, device=x.device)\n \n # Launch Triton kernel\n grid = (1,)\n\n linear_kernel[grid](x, y, bias, out, N, block_size)\n \n return out.item()\n", - "description_1": "Use triton language to implement a linear kernel function 'linear_kernel' with 6 parameters: x_ptr, y_ptr, bias_ptr, out_ptr, N, and block_size. The kernel computes the dot product of two 1D tensors x and y, adds a bias, and stores the result in out_ptr. The function 'linear' is a wrapper that prepares the input tensors, calculates the next power of 2 for the input size, and launches the kernel.", - "description_2": "Use triton language to compute the dot product of two 1D tensors with an added bias using a kernel function, and manage the kernel launch with a wrapper function.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef add_kernel(x_ptr, # *Pointer* to first input vector.\n y_ptr, # *Pointer* to second input vector.\n output_ptr, # *Pointer* to output vector.\n n_elements, # Size of the vector.\n BLOCK_SIZE: tl.constexpr, # Number of elements each program should process.\n ):\n pid = tl.program_id(axis=0) # We use a 1D launch grid so axis is 0.\n block_start = pid * BLOCK_SIZE\n offsets = block_start + tl.arange(0, BLOCK_SIZE)\n mask = offsets < n_elements\n x = tl.load(x_ptr + offsets, mask=mask)\n y = tl.load(y_ptr + offsets, mask=mask)\n output = x + y\n tl.store(output_ptr + offsets, output, mask=mask)\n\ndef add(x: torch.Tensor, y: torch.Tensor):\n output = torch.empty_like(x)\n n_elements = output.numel()\n grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']), )\n add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024)\n return output\n\ntorch.manual_seed(0)\nsize = 98411\nx = torch.rand(size)\ny = torch.rand(size)\noutput_torch = x + y\noutput_triton = add(x, y)\nprint(f'Triton result: {output_triton}')\nprint(f'Torch result: {output_torch}')\nprint(f'The maximum difference between torch and triton is '\n f'{torch.max(torch.abs(output_torch - output_triton))}')\n", - "description_1": "Use triton language to implement a vector addition kernel. The kernel 'add_kernel' takes 5 parameters: pointers to the input vectors x and y, a pointer to the output vector, the number of elements in the vectors, and a block size as a compile-time constant. The function 'add' wraps this kernel, taking two PyTorch tensors as input, preallocating an output tensor, and launching the kernel with a 1D grid. The grid size is determined by the number of elements divided by the block size.", - "description_2": "Use triton language to create a kernel for element-wise addition of two vectors, and a wrapper function to execute this kernel on PyTorch tensors.", - "difficulty": 2 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef _attn_fwd_inner(acc, l_i, m_i, q, K_block_ptr, V_block_ptr, start_m, qk_scale, BLOCK_M: tl.constexpr, HEAD_DIM: tl.constexpr, BLOCK_N: tl.constexpr, STAGE: tl.constexpr, offs_m: tl.constexpr, offs_n: tl.constexpr, N_CTX: tl.constexpr, fp8_v: tl.constexpr):\n if STAGE == 1:\n lo, hi = 0, start_m * BLOCK_M\n elif STAGE == 2:\n lo, hi = start_m * BLOCK_M, (start_m + 1) * BLOCK_M\n lo = tl.multiple_of(lo, BLOCK_M)\n else:\n lo, hi = 0, N_CTX\n K_block_ptr = tl.advance(K_block_ptr, (0, lo))\n V_block_ptr = tl.advance(V_block_ptr, (lo, 0))\n for start_n in range(lo, hi, BLOCK_N):\n start_n = tl.multiple_of(start_n, BLOCK_N)\n k = tl.load(K_block_ptr)\n qk = tl.dot(q, k)\n if STAGE == 2:\n mask = offs_m[:, None] >= (start_n + offs_n[None, :])\n qk = qk * qk_scale + tl.where(mask, 0, -1.0e6)\n m_ij = tl.maximum(m_i, tl.max(qk, 1))\n qk -= m_ij[:, None]\n else:\n m_ij = tl.maximum(m_i, tl.max(qk, 1) * qk_scale)\n qk = qk * qk_scale - m_ij[:, None]\n p = tl.math.exp2(qk)\n l_ij = tl.sum(p, 1)\n alpha = tl.math.exp2(m_i - m_ij)\n l_i = l_i * alpha + l_ij\n acc = acc * alpha[:, None]\n v = tl.load(V_block_ptr)\n if fp8_v:\n p = p.to(tl.float8e5)\n else:\n p = p.to(tl.float16)\n acc = tl.dot(p, v, acc)\n m_i = m_ij\n V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0))\n K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N))\n return acc, l_i, m_i\n\n@triton.jit\ndef _attn_fwd(Q, K, V, sm_scale, M, Out, stride_qz, stride_qh, stride_qm, stride_qk, stride_kz, stride_kh, stride_kn, stride_kk, stride_vz, stride_vh, stride_vk, stride_vn, stride_oz, stride_oh, stride_om, stride_on, Z, H, N_CTX, HEAD_DIM: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, STAGE: tl.constexpr):\n tl.static_assert(BLOCK_N <= HEAD_DIM)\n start_m = tl.program_id(0)\n off_hz = tl.program_id(1)\n off_z = off_hz // H\n off_h = off_hz % H\n qvk_offset = off_z.to(tl.int64) * stride_qz + off_h.to(tl.int64) * stride_qh\n\n Q_block_ptr = tl.make_block_ptr(base=Q + qvk_offset, shape=(N_CTX, HEAD_DIM), strides=(stride_qm, stride_qk), offsets=(start_m * BLOCK_M, 0), block_shape=(BLOCK_M, HEAD_DIM), order=(1, 0))\n v_order: tl.constexpr = (0, 1) if V.dtype.element_ty == tl.float8e5 else (1, 0)\n V_block_ptr = tl.make_block_ptr(base=V + qvk_offset, shape=(N_CTX, HEAD_DIM), strides=(stride_vk, stride_vn), offsets=(0, 0), block_shape=(BLOCK_N, HEAD_DIM), order=v_order)\n K_block_ptr = tl.make_block_ptr(base=K + qvk_offset, shape=(HEAD_DIM, N_CTX), strides=(stride_kk, stride_kn), offsets=(0, 0), block_shape=(HEAD_DIM, BLOCK_N), order=(0, 1))\n O_block_ptr = tl.make_block_ptr(base=Out + qvk_offset, shape=(N_CTX, HEAD_DIM), strides=(stride_om, stride_on), offsets=(start_m * BLOCK_M, 0), block_shape=(BLOCK_M, HEAD_DIM), order=(1, 0))\n\n offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)\n offs_n = tl.arange(0, BLOCK_N)\n\n m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float(\"inf\")\n l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + 1.0\n acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32)\n\n qk_scale = sm_scale\n qk_scale *= 1.44269504\n\n q = tl.load(Q_block_ptr)\n\n if STAGE & 1:\n acc, l_i, m_i = _attn_fwd_inner(acc, l_i, m_i, q, K_block_ptr, V_block_ptr, start_m, qk_scale, BLOCK_M, HEAD_DIM, BLOCK_N, 4 - STAGE, offs_m, offs_n, N_CTX, V.dtype.element_ty == tl.float8e5)\n if STAGE & 2:\n acc, l_i, m_i = _attn_fwd_inner(acc, l_i, m_i, q, K_block_ptr, V_block_ptr, start_m, qk_scale, BLOCK_M, HEAD_DIM, BLOCK_N, 2, offs_m, offs_n, N_CTX, V.dtype.element_ty == tl.float8e5)\n\n m_i += tl.math.log2(l_i)\n acc = acc / l_i[:, None]\n m_ptrs = M + off_hz * N_CTX + offs_m\n tl.store(m_ptrs, m_i)\n tl.store(O_block_ptr, acc.to(Out.type.element_ty))\n\nclass _attention(torch.autograd.Function):\n\n @staticmethod\n def forward(ctx, q, k, v, causal, sm_scale):\n HEAD_DIM_Q, HEAD_DIM_K = q.shape[-1], k.shape[-1]\n HEAD_DIM_V = v.shape[-1]\n assert HEAD_DIM_Q == HEAD_DIM_K and HEAD_DIM_K == HEAD_DIM_V\n assert HEAD_DIM_K in {16, 32, 64, 128, 256}\n o = torch.empty_like(q)\n stage = 3 if causal else 1\n extra_kern_args = {}\n\n grid = lambda args: (triton.cdiv(q.shape[2], args[\"BLOCK_M\"]), q.shape[0] * q.shape[1], 1)\n M = torch.empty((q.shape[0], q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)\n _attn_fwd[grid](q, k, v, sm_scale, M, o, q.stride(0), q.stride(1), q.stride(2), q.stride(3), k.stride(0), k.stride(1), k.stride(2), k.stride(3), v.stride(0), v.stride(1), v.stride(2), v.stride(3), o.stride(0), o.stride(1), o.stride(2), o.stride(3), q.shape[0], q.shape[1], N_CTX=q.shape[2], HEAD_DIM=HEAD_DIM_K, STAGE=stage, **extra_kern_args)\n\n ctx.save_for_backward(q, k, v, o, M)\n ctx.grid = grid\n ctx.sm_scale = sm_scale\n ctx.HEAD_DIM = HEAD_DIM_K\n ctx.causal = causal\n return o\n\n @staticmethod\n def backward(ctx, do):\n q, k, v, o, M = ctx.saved_tensors\n assert do.is_contiguous()\n assert q.stride() == k.stride() == v.stride() == o.stride() == do.stride()\n dq = torch.empty_like(q)\n dk = torch.empty_like(k)\n dv = torch.empty_like(v)\n BATCH, N_HEAD, N_CTX = q.shape[:3]\n PRE_BLOCK = 128\n NUM_WARPS, NUM_STAGES = 4, 5\n BLOCK_M1, BLOCK_N1, BLOCK_M2, BLOCK_N2 = 32, 128, 128, 32\n BLK_SLICE_FACTOR = 2\n RCP_LN2 = 1.4426950408889634\n arg_k = k\n arg_k = arg_k * (ctx.sm_scale * RCP_LN2)\n PRE_BLOCK = 128\n assert N_CTX % PRE_BLOCK == 0\n pre_grid = (N_CTX // PRE_BLOCK, BATCH * N_HEAD)\n delta = torch.empty_like(M)\n\n grid = (N_CTX // BLOCK_N1, 1, BATCH * N_HEAD)\n\n return dq, dk, dv, None, None\n\nattention = _attention.apply\n", - "description_1": "Use triton language to implement forward and backward attention kernels for a transformer model. The forward kernel '_attn_fwd_inner' calculates attention scores using query, key, and value tensors, and the main '_attn_fwd' function applies these kernels over blocks of data in parallel. The backward pass is handled in a class-based autograd function '_attention' that calculates gradients with respect to inputs for backpropagation. Inputs for forward include query, key, value tensors and scaling factors, while backward handles gradients for these tensors.", - "description_2": "Use triton language to implement attention mechanism forward and backward kernels for transformer models, with efficient block processing and support for gradient computation.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.autotune(\n configs=[\n triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8),\n triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2),\n triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2),\n triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8),\n triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8),\n triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4)\n ],\n key=['M', 'N', 'K'],\n)\n@triton.jit\ndef matmul_kernel(\n a_ptr, b_ptr, c_ptr,\n M, N, K,\n stride_am, stride_ak,\n stride_bk, stride_bn,\n stride_cm, stride_cn,\n BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,\n GROUP_SIZE_M: tl.constexpr,\n ACTIVATION: tl.constexpr\n):\n pid = tl.program_id(axis=0)\n num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)\n num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)\n num_pid_in_group = GROUP_SIZE_M * num_pid_n\n group_id = pid // num_pid_in_group\n first_pid_m = group_id * GROUP_SIZE_M\n group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)\n pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)\n pid_n = (pid % num_pid_in_group) // group_size_m\n\n offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M\n offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N\n offs_k = tl.arange(0, BLOCK_SIZE_K)\n a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)\n b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)\n\n accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):\n a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)\n b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)\n accumulator = tl.dot(a, b, accumulator)\n a_ptrs += BLOCK_SIZE_K * stride_ak\n b_ptrs += BLOCK_SIZE_K * stride_bk\n if ACTIVATION == \"leaky_relu\":\n accumulator = leaky_relu(accumulator)\n c = accumulator.to(tl.float16)\n\n offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]\n c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)\n tl.store(c_ptrs, c, mask=c_mask)\n\n@triton.jit\ndef leaky_relu(x):\n return tl.where(x >= 0, x, 0.01 * x)\n\ndef matmul(a, b, activation=\"\"):\n assert a.shape[1] == b.shape[0], \"Incompatible dimensions\"\n assert a.is_contiguous(), \"Matrix A must be contiguous\"\n M, K = a.shape\n K, N = b.shape\n c = torch.empty((M, N), device=a.device, dtype=torch.float16)\n grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), )\n matmul_kernel[grid](\n a, b, c,\n M, N, K,\n a.stride(0), a.stride(1),\n b.stride(0), b.stride(1),\n c.stride(0), c.stride(1),\n ACTIVATION=activation\n )\n return c\n", - "description_1": "Use triton language to implement a matrix multiplication kernel (matmul_kernel) with 15 parameters: pointers to matrices a, b, c; dimensions M, N, K; strides for a, b, c; block sizes BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K; group size GROUP_SIZE_M; and activation function ACTIVATION. The kernel computes the product of matrices A and B, storing the result in C, with optional leaky ReLU activation. The matmul function calls this kernel with 3 parameters: matrices a, b, and an optional activation function.", - "description_2": "Use triton language to create a matrix multiplication kernel with configurable block sizes and optional leaky ReLU activation, and a function to call this kernel for matrix multiplication.", - "difficulty": 4 - }, - { - "code": "import math\nimport torch\nimport triton\nimport triton.language as tl\n\n@triton.heuristics(\n {\n \"EVEN_M\": lambda args: args[\"seqlen_q\"] % args[\"BLOCK_M\"] == 0,\n \"EVEN_N\": lambda args: args[\"seqlen_k\"] % args[\"BLOCK_N\"] == 0,\n \"EVEN_HEADDIM\": lambda args: args[\"headdim\"] == args[\"BLOCK_HEADDIM\"],\n }\n)\n@triton.jit\ndef _fwd_kernel(\n Q,\n K,\n V,\n Bias,\n Out,\n Lse,\n TMP,\n softmax_scale,\n stride_qb,\n stride_qh,\n stride_qm,\n stride_kb,\n stride_kh,\n stride_kn,\n stride_vb,\n stride_vh,\n stride_vn,\n stride_bb,\n stride_bh,\n stride_bm,\n stride_ob,\n stride_oh,\n stride_om,\n nheads,\n seqlen_q,\n seqlen_k,\n seqlen_q_rounded,\n headdim,\n CACHE_KEY_SEQLEN_Q,\n CACHE_KEY_SEQLEN_K,\n BIAS_TYPE: tl.constexpr,\n IS_CAUSAL: tl.constexpr,\n BLOCK_HEADDIM: tl.constexpr,\n EVEN_M: tl.constexpr,\n EVEN_N: tl.constexpr,\n EVEN_HEADDIM: tl.constexpr,\n BLOCK_M: tl.constexpr,\n BLOCK_N: tl.constexpr,\n):\n start_m = tl.program_id(0)\n off_hb = tl.program_id(1)\n off_b = off_hb // nheads\n off_h = off_hb % nheads\n offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)\n offs_n = tl.arange(0, BLOCK_N)\n offs_d = tl.arange(0, BLOCK_HEADDIM)\n q_ptrs = (\n Q + off_b * stride_qb + off_h * stride_qh + (offs_m[:, None] * stride_qm + offs_d[None, :])\n )\n k_ptrs = (\n K + off_b * stride_kb + off_h * stride_kh + (offs_n[:, None] * stride_kn + offs_d[None, :])\n )\n v_ptrs = (\n V + off_b * stride_vb + off_h * stride_vh + (offs_n[:, None] * stride_vn + offs_d[None, :])\n )\n if BIAS_TYPE == \"vector\":\n b_ptrs = Bias + off_b * stride_bb + off_h * stride_bh + offs_n\n elif BIAS_TYPE == \"matrix\":\n b_ptrs = (\n Bias\n + off_b * stride_bb\n + off_h * stride_bh\n + (offs_m[:, None] * stride_bm + offs_n[None, :])\n )\n t_ptrs = TMP + off_hb * seqlen_q_rounded + offs_m\n lse_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float(\"inf\")\n m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float(\"inf\")\n acc_o = tl.zeros([BLOCK_M, BLOCK_HEADDIM], dtype=tl.float32)\n if EVEN_M & EVEN_N:\n if EVEN_HEADDIM:\n q = tl.load(q_ptrs)\n else:\n q = tl.load(q_ptrs, mask=offs_d[None, :] < headdim, other=0.0)\n else:\n if EVEN_HEADDIM:\n q = tl.load(q_ptrs, mask=offs_m[:, None] < seqlen_q, other=0.0)\n else:\n q = tl.load(\n q_ptrs, mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim), other=0.0\n )\n end_n = seqlen_k if not IS_CAUSAL else tl.minimum((start_m + 1) * BLOCK_M, seqlen_k)\n for start_n in range(0, end_n, BLOCK_N):\n start_n = tl.multiple_of(start_n, BLOCK_N)\n if EVEN_N & EVEN_M:\n if EVEN_HEADDIM:\n k = tl.load(k_ptrs + start_n * stride_kn)\n else:\n k = tl.load(k_ptrs + start_n * stride_kn, mask=offs_d[None, :] < headdim, other=0.0)\n else:\n if EVEN_HEADDIM:\n k = tl.load(\n k_ptrs + start_n * stride_kn,\n mask=(start_n + offs_n)[:, None] < seqlen_k,\n other=0.0,\n )\n else:\n k = tl.load(\n k_ptrs + start_n * stride_kn,\n mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim),\n other=0.0,\n )\n qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)\n qk += tl.dot(q, k, trans_b=True)\n if not EVEN_N:\n qk += tl.where((start_n + offs_n)[None, :] < seqlen_k, 0, float(\"-inf\"))\n if IS_CAUSAL:\n qk += tl.where(offs_m[:, None] >= (start_n + offs_n)[None, :], 0, float(\"-inf\"))\n if BIAS_TYPE != \"none\":\n if BIAS_TYPE == \"vector\":\n if EVEN_N:\n bias = tl.load(b_ptrs + start_n).to(tl.float32)\n else:\n bias = tl.load(\n b_ptrs + start_n, mask=(start_n + offs_n) < seqlen_k, other=0.0\n ).to(tl.float32)\n bias = bias[None, :]\n elif BIAS_TYPE == \"matrix\":\n if EVEN_M & EVEN_N:\n bias = tl.load(b_ptrs + start_n).to(tl.float32)\n else:\n bias = tl.load(\n b_ptrs + start_n,\n mask=(offs_m[:, None] < seqlen_q)\n & ((start_n + offs_n)[None, :] < seqlen_k),\n other=0.0,\n ).to(tl.float32)\n qk = qk * softmax_scale + bias\n m_ij = tl.maximum(tl.max(qk, 1), lse_i)\n p = tl.exp(qk - m_ij[:, None])\n else:\n m_ij = tl.maximum(tl.max(qk, 1) * softmax_scale, lse_i)\n p = tl.exp(qk * softmax_scale - m_ij[:, None])\n l_ij = tl.sum(p, 1)\n acc_o_scale = tl.exp(m_i - m_ij)\n tl.store(t_ptrs, acc_o_scale)\n acc_o_scale = tl.load(t_ptrs)\n acc_o = acc_o * acc_o_scale[:, None]\n if EVEN_N & EVEN_M:\n if EVEN_HEADDIM:\n v = tl.load(v_ptrs + start_n * stride_vn)\n else:\n v = tl.load(v_ptrs + start_n * stride_vn, mask=offs_d[None, :] < headdim, other=0.0)\n else:\n if EVEN_HEADDIM:\n v = tl.load(\n v_ptrs + start_n * stride_vn,\n mask=(start_n + offs_n)[:, None] < seqlen_k,\n other=0.0,\n )\n else:\n v = tl.load(\n v_ptrs + start_n * stride_vn,\n mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim),\n other=0.0,\n )\n p = p.to(v.dtype)\n acc_o += tl.dot(p, v)\n m_i = m_ij\n l_i_new = tl.exp(lse_i - m_ij) + l_ij\n lse_i = m_ij + tl.log(l_i_new)\n o_scale = tl.exp(m_i - lse_i)\n tl.store(t_ptrs, o_scale)\n o_scale = tl.load(t_ptrs)\n acc_o = acc_o * o_scale[:, None]\n start_m = tl.program_id(0)\n offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)\n lse_ptrs = Lse + off_hb * seqlen_q_rounded + offs_m\n tl.store(lse_ptrs, lse_i)\n offs_d = tl.arange(0, BLOCK_HEADDIM)\n out_ptrs = (\n Out\n + off_b * stride_ob\n + off_h * stride_oh\n + (offs_m[:, None] * stride_om + offs_d[None, :])\n )\n if EVEN_M:\n if EVEN_HEADDIM:\n tl.store(out_ptrs, acc_o)\n else:\n tl.store(out_ptrs, acc_o, mask=offs_d[None, :] < headdim)\n else:\n if EVEN_HEADDIM:\n tl.store(out_ptrs, acc_o, mask=offs_m[:, None] < seqlen_q)\n else:\n tl.store(\n out_ptrs, acc_o, mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim)\n )\n\n\ndef _flash_attn_forward(q, k, v, bias=None, causal=False, softmax_scale=None):\n batch, seqlen_q, nheads, d = q.shape\n _, seqlen_k, _, _ = k.shape\n assert k.shape == (batch, seqlen_k, nheads, d)\n assert v.shape == (batch, seqlen_k, nheads, d)\n assert d <= 128, \"FlashAttention only support head dimensions up to 128\"\n assert q.dtype == k.dtype == v.dtype, \"All tensors must have the same type\"\n assert q.dtype in [torch.float16, torch.bfloat16], \"Only support fp16 and bf16\"\n assert q.is_cuda and k.is_cuda and v.is_cuda\n softmax_scale = softmax_scale or 1.0 / math.sqrt(d)\n\n has_bias = bias is not None\n bias_type = \"none\"\n if has_bias:\n assert bias.dtype in [q.dtype, torch.float]\n assert bias.is_cuda\n assert bias.dim() == 4\n if bias.stride(-1) != 1:\n bias = bias.contiguous()\n if bias.shape[2:] == (1, seqlen_k):\n bias_type = \"vector\"\n elif bias.shape[2:] == (seqlen_q, seqlen_k):\n bias_type = \"matrix\"\n else:\n raise RuntimeError(\n \"Last 2 dimensions of bias must be (1, seqlen_k)\" \" or (seqlen_q, seqlen_k)\"\n )\n bias = bias.expand(batch, nheads, seqlen_q, seqlen_k)\n bias_strides = (bias.stride(0), bias.stride(1), bias.stride(2)) if has_bias else (0, 0, 0)\n\n seqlen_q_rounded = math.ceil(seqlen_q / 128) * 128\n lse = torch.empty((batch, nheads, seqlen_q_rounded), device=q.device, dtype=torch.float32)\n tmp = torch.empty((batch, nheads, seqlen_q_rounded), device=q.device, dtype=torch.float32)\n o = torch.empty_like(q)\n\n BLOCK_HEADDIM = max(triton.next_power_of_2(d), 16)\n BLOCK = 128\n num_warps = 4 if d <= 64 else 8\n grid = lambda META: (triton.cdiv(seqlen_q, META[\"BLOCK_M\"]), batch * nheads)\n _fwd_kernel[grid](\n q,\n k,\n v,\n bias,\n o,\n lse,\n tmp,\n softmax_scale,\n q.stride(0),\n q.stride(2),\n q.stride(1),\n k.stride(0),\n k.stride(2),\n k.stride(1),\n v.stride(0),\n v.stride(2),\n v.stride(1),\n *bias_strides,\n o.stride(0),\n o.stride(2),\n o.stride(1),\n nheads,\n seqlen_q,\n seqlen_k,\n seqlen_q_rounded,\n d,\n seqlen_q // 32,\n seqlen_k // 32,\n bias_type,\n causal,\n BLOCK_HEADDIM,\n BLOCK_M=BLOCK,\n BLOCK_N=BLOCK,\n num_warps=num_warps,\n num_stages=1,\n )\n return o, lse, softmax_scale\n\n", - "description_1": "Use triton language to implement a forward pass kernel for FlashAttention. The kernel _fwd_kernel has 34 parameters including Q, K, V for query, key, value matrices; Bias for the attention bias; Out for output; Lse and TMP for temporary storage; softmax_scale for scaling; various stride parameters to access the matrices, nheads for number of heads, seqlen_q and seqlen_k for sequence lengths, seqlen_q_rounded for rounded sequence length, headdim for head dimension, CACHE_KEY_SEQLEN_Q and CACHE_KEY_SEQLEN_K for cache keys, BIAS_TYPE and IS_CAUSAL for bias type and causality constant expressions, and BLOCK_M, BLOCK_N, BLOCK_HEADDIM, EVEN_M, EVEN_N, EVEN_HEADDIM as block and even constants. The function _flash_attn_forward calls _fwd_kernel with 26 parameters and has additional logic for setting up the data and bias.", - "description_2": "Use triton language to create a forward pass for FlashAttention in _fwd_kernel with specific parameters for matrix multiplication and attention computation, ensuring efficiency through the use of block and stride parameters, and apply this kernel in _flash_attn_forward to handle input tensors and manage biases, sequence lengths, and head dimensions.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef _fwd_kernel(\n Q,\n K,\n V,\n sm_scale,\n TMP,\n L,\n M,\n Out,\n stride_qz,\n stride_qh,\n stride_qm,\n stride_qk,\n stride_kz,\n stride_kh,\n stride_kn,\n stride_kk,\n stride_vz,\n stride_vh,\n stride_vk,\n stride_vn,\n stride_oz,\n stride_oh,\n stride_om,\n stride_on,\n Z,\n H,\n N_CTX,\n BLOCK_M: tl.constexpr,\n BLOCK_DMODEL: tl.constexpr,\n BLOCK_N: tl.constexpr,\n):\n start_m = tl.program_id(0)\n off_hz = tl.program_id(1)\n offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)\n offs_n = tl.arange(0, BLOCK_N)\n offs_d = tl.arange(0, BLOCK_DMODEL)\n off_q = off_hz * stride_qh + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qk\n off_k = off_hz * stride_qh + offs_n[:, None] * stride_kn + offs_d[None, :] * stride_kk\n off_v = off_hz * stride_qh + offs_n[:, None] * stride_qm + offs_d[None, :] * stride_qk\n q_ptrs = Q + off_q\n k_ptrs = K + off_k\n v_ptrs = V + off_v\n t_ptrs = TMP + off_hz * N_CTX + offs_m\n m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float(\"inf\")\n l_i = tl.zeros([BLOCK_M], dtype=tl.float32)\n acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)\n q = tl.load(q_ptrs)\n for start_n in range(0, (start_m + 1) * BLOCK_M, BLOCK_N):\n start_n = tl.multiple_of(start_n, BLOCK_N)\n k = tl.load(k_ptrs + start_n * stride_kn)\n qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)\n qk += tl.dot(q, k, trans_b=True)\n qk *= sm_scale\n qk += tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), 0, float(\"-inf\"))\n m_ij = tl.max(qk, 1)\n p = tl.exp(qk - m_ij[:, None])\n l_ij = tl.sum(p, 1)\n m_i_new = tl.maximum(m_i, m_ij)\n alpha = tl.exp(m_i - m_i_new)\n beta = tl.exp(m_ij - m_i_new)\n l_i_new = alpha * l_i + beta * l_ij\n p_scale = beta / l_i_new\n p = p * p_scale[:, None]\n acc_scale = l_i / l_i_new * alpha\n tl.store(t_ptrs, acc_scale)\n acc_scale = tl.load(t_ptrs)\n acc = acc * acc_scale[:, None]\n v = tl.load(v_ptrs + start_n * stride_vk)\n p = p.to(v.dtype)\n acc += tl.dot(p, v)\n l_i = l_i_new\n m_i = m_i_new\n start_m = tl.program_id(0)\n offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)\n l_ptrs = L + off_hz * N_CTX + offs_m\n m_ptrs = M + off_hz * N_CTX + offs_m\n tl.store(l_ptrs, l_i)\n tl.store(m_ptrs, m_i)\n offs_n = tl.arange(0, BLOCK_DMODEL)\n off_o = off_hz * stride_oh + offs_m[:, None] * stride_om + offs_n[None, :] * stride_on\n out_ptrs = Out + off_o\n tl.store(out_ptrs, acc)\n\n@triton.jit\ndef _bwd_preprocess(\n Out,\n DO,\n L,\n NewDO,\n Delta,\n BLOCK_M: tl.constexpr,\n D_HEAD: tl.constexpr,\n):\n off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)\n off_n = tl.arange(0, D_HEAD)\n o = tl.load(Out + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32)\n do = tl.load(DO + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32)\n denom = tl.load(L + off_m).to(tl.float32)\n do = do / denom[:, None]\n delta = tl.sum(o * do, axis=1)\n tl.store(NewDO + off_m[:, None] * D_HEAD + off_n[None, :], do)\n tl.store(Delta + off_m, delta)\n\n@triton.jit\ndef _bwd_kernel(\n Q,\n K,\n V,\n sm_scale,\n Out,\n DO,\n DQ,\n DK,\n DV,\n L,\n M,\n D,\n stride_qz,\n stride_qh,\n stride_qm,\n stride_qk,\n stride_kz,\n stride_kh,\n stride_kn,\n stride_kk,\n stride_vz,\n stride_vh,\n stride_vk,\n stride_vn,\n Z,\n H,\n N_CTX,\n num_block,\n BLOCK_M: tl.constexpr,\n BLOCK_DMODEL: tl.constexpr,\n BLOCK_N: tl.constexpr,\n):\n off_hz = tl.program_id(0)\n off_z = off_hz // H\n off_h = off_hz % H\n Q += off_z * stride_qz + off_h * stride_qh\n K += off_z * stride_qz + off_h * stride_qh\n V += off_z * stride_qz + off_h * stride_qh\n DO += off_z * stride_qz + off_h * stride_qh\n DQ += off_z * stride_qz + off_h * stride_qh\n DK += off_z * stride_qz + off_h * stride_qh\n DV += off_z * stride_qz + off_h * stride_qh\n for start_n in range(0, num_block):\n lo = start_n * BLOCK_M\n offs_qm = lo + tl.arange(0, BLOCK_M)\n offs_n = start_n * BLOCK_M + tl.arange(0, BLOCK_M)\n offs_m = tl.arange(0, BLOCK_N)\n offs_k = tl.arange(0, BLOCK_DMODEL)\n q_ptrs = Q + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk)\n k_ptrs = K + (offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk)\n v_ptrs = V + (offs_n[:, None] * stride_qm + offs_k[None, :] * stride_qk)\n do_ptrs = DO + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk)\n dq_ptrs = DQ + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk)\n D_ptrs = D + off_hz * N_CTX\n m_ptrs = M + off_hz * N_CTX\n dv = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)\n dk = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)\n k = tl.load(k_ptrs)\n v = tl.load(v_ptrs)\n for start_m in range(lo, num_block * BLOCK_M, BLOCK_M):\n offs_m_curr = start_m + offs_m\n q = tl.load(q_ptrs)\n qk = tl.dot(q, k, trans_b=True)\n qk = tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), qk, float(\"-inf\"))\n m = tl.load(m_ptrs + offs_m_curr)\n p = tl.exp(qk * sm_scale - m[:, None])\n do = tl.load(do_ptrs)\n dv += tl.dot(p.to(do.dtype), do, trans_a=True)\n Di = tl.load(D_ptrs + offs_m_curr)\n dp = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - Di[:, None]\n dp += tl.dot(do, v, trans_b=True)\n ds = p * dp * sm_scale\n dk += tl.dot(ds.to(q.dtype), q, trans_a=True)\n dq = tl.load(dq_ptrs, eviction_policy=\"evict_last\")\n dq += tl.dot(ds.to(k.dtype), k)\n tl.store(dq_ptrs, dq, eviction_policy=\"evict_last\")\n dq_ptrs += BLOCK_M * stride_qm\n q_ptrs += BLOCK_M * stride_qm\n do_ptrs += BLOCK_M * stride_qm\n dv_ptrs = DV + (offs_n[:, None] * stride_qm + offs_k[None, :] * stride_qk)\n dk_ptrs = DK + (offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk)\n tl.store(dv_ptrs, dv)\n tl.store(dk_ptrs, dk)\n\n\nclass _attention(torch.autograd.Function):\n @staticmethod\n def forward(ctx, q, k, v, sm_scale):\n BLOCK = 128\n Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]\n assert Lq == Lk and Lk == Lv\n assert Lk in {16, 32, 64, 128}\n o = torch.empty_like(q)\n grid = (triton.cdiv(q.shape[2], BLOCK), q.shape[0] * q.shape[1])\n tmp = torch.empty(\n (q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32\n )\n L = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)\n m = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)\n num_warps = 4 if Lk <= 64 else 8\n\n _fwd_kernel[grid](\n q,\n k,\n v,\n sm_scale,\n tmp,\n L,\n m,\n o,\n q.stride(0),\n q.stride(1),\n q.stride(2),\n q.stride(3),\n k.stride(0),\n k.stride(1),\n k.stride(2),\n k.stride(3),\n v.stride(0),\n v.stride(1),\n v.stride(2),\n v.stride(3),\n o.stride(0),\n o.stride(1),\n o.stride(2),\n o.stride(3),\n q.shape[0],\n q.shape[1],\n q.shape[2],\n BLOCK_M=BLOCK,\n BLOCK_N=BLOCK,\n BLOCK_DMODEL=Lk,\n num_warps=num_warps,\n num_stages=1,\n )\n ctx.save_for_backward(q, k, v, o, L, m)\n ctx.BLOCK = BLOCK\n ctx.grid = grid\n ctx.sm_scale = sm_scale\n ctx.BLOCK_DMODEL = Lk\n return o\n\n @staticmethod\n def backward(ctx, do):\n q, k, v, o, l, m = ctx.saved_tensors\n do = do.contiguous()\n dq = torch.zeros_like(q, dtype=torch.float32)\n dk = torch.empty_like(k)\n dv = torch.empty_like(v)\n do_scaled = torch.empty_like(do)\n delta = torch.empty_like(l)\n _bwd_preprocess[(ctx.grid[0] * ctx.grid[1],)](\n o,\n do,\n l,\n do_scaled,\n delta,\n BLOCK_M=ctx.BLOCK,\n D_HEAD=ctx.BLOCK_DMODEL,\n )\n\n num_warps = 8\n _bwd_kernel[(ctx.grid[1],)](\n q,\n k,\n v,\n ctx.sm_scale,\n o,\n do_scaled,\n dq,\n dk,\n dv,\n l,\n m,\n delta,\n q.stride(0),\n q.stride(1),\n q.stride(2),\n q.stride(3),\n k.stride(0),\n k.stride(1),\n k.stride(2),\n k.stride(3),\n v.stride(0),\n v.stride(1),\n v.stride(2),\n v.stride(3),\n q.shape[0],\n q.shape[1],\n q.shape[2],\n ctx.grid[0],\n BLOCK_M=ctx.BLOCK,\n BLOCK_N=ctx.BLOCK,\n BLOCK_DMODEL=ctx.BLOCK_DMODEL,\n num_warps=num_warps,\n num_stages=1,\n )\n return dq.to(q.dtype), dk, dv, None\n\n\nattention = _attention.apply\n", - "description_1": "Use triton language to implement fused attention using three main kernels: _fwd_kernel, _bwd_preprocess, and _bwd_kernel. The _fwd_kernel takes 22 parameters and computes the forward pass, involving matrix multiplications and accumulations for query, key, and value tensors with scale adjustments. The _bwd_preprocess kernel, with 6 parameters, prepares the gradients for backpropagation by normalizing the gradients and calculating a delta value. The _bwd_kernel, with 31 parameters, handles the computation of gradients for the inputs using the outputs from the forward pass, incorporating scaling factors and strides for addressing memory. These functions are wrapped by the _attention class, which provides the forward and backward methods for PyTorch's autograd functionality.", - "description_2": "Use triton language to create fused attention kernels that efficiently handle forward and backward passes using specialized GPU computations with triton.jit, leveraging memory management and compute optimizations for deep learning models.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.heuristics(\n {\n \"HAS_SMOOTHING\": lambda args: args[\"smoothing\"] > 0.0,\n }\n)\n@triton.jit\ndef cross_entropy_fwd_kernel(\n loss_ptr, # data ptrs\n lse_ptr,\n z_loss_ptr,\n logits_ptr,\n labels_ptr,\n smoothing,\n logit_scale,\n lse_square_scale,\n ignore_index,\n total_classes,\n class_start_idx, # Useful for tensor parallel when each rank only has a subset of classes\n n_cols, # shapes\n logits_row_stride, # strides\n BLOCK_SIZE: tl.constexpr,\n HAS_SMOOTHING: tl.constexpr,\n # if SPLIT (e.g. tensor parallel), don't include the LSE in the loss since it's not the final LSE\n SPLIT: tl.constexpr,\n PRECOMPUTED_LSE: tl.constexpr, # If LSE is already computed (also no smoothing and logit_scale == 1.0)\n):\n row_idx = tl.program_id(0)\n logits_ptr = logits_ptr + row_idx * logits_row_stride.to(tl.int64)\n sum_logits = 0.0 # For smoothing\n if not PRECOMPUTED_LSE:\n # Statistics for online softmax\n m_i = -float(\"inf\")\n l_i = 0.0\n for col_offset in range(0, n_cols, BLOCK_SIZE):\n cols = col_offset + tl.arange(0, BLOCK_SIZE)\n logits = tl.load(logits_ptr + cols, mask=cols < n_cols, other=-float(\"inf\")).to(\n tl.float32\n ) * logit_scale\n if HAS_SMOOTHING:\n sum_logits += tl.sum(tl.where(cols < n_cols, logits, 0.0))\n m_i_new = tl.maximum(m_i, tl.max(logits))\n l_i = tl.exp(m_i - m_i_new) * l_i + tl.sum(tl.exp(logits - m_i_new))\n m_i = m_i_new\n lse = tl.log(l_i) + m_i\n tl.store(lse_ptr + row_idx, lse)\n else:\n lse = tl.load(lse_ptr + row_idx)\n label_idx = tl.load(labels_ptr + row_idx)\n if label_idx == ignore_index:\n loss = 0.0\n z_loss = 0.0\n else:\n label_idx -= class_start_idx\n if label_idx >= 0 and label_idx < n_cols:\n logits_label = tl.load(logits_ptr + label_idx) * logit_scale\n if HAS_SMOOTHING:\n loss = (\n (lse if not SPLIT else 0.0)\n - smoothing * sum_logits / total_classes\n - (1 - smoothing) * logits_label\n )\n else:\n loss = (lse if not SPLIT else 0.0) - logits_label\n else:\n # If label is out of bounds, we set the CE loss to 0.0. But we still want the smoothing loss\n if HAS_SMOOTHING:\n loss = smoothing * ((lse if not SPLIT else 0.0) - sum_logits / total_classes)\n else:\n loss = 0.0\n if not SPLIT:\n z_loss = lse_square_scale * lse * lse\n loss += z_loss\n else:\n z_loss = 0.0\n tl.store(loss_ptr + row_idx, loss)\n if not SPLIT:\n tl.store(z_loss_ptr + row_idx, z_loss)\n\n\n@triton.heuristics(\n {\n \"HAS_SMOOTHING\": lambda args: args[\"smoothing\"] > 0.0,\n }\n)\n@triton.jit\ndef cross_entropy_bwd_kernel(\n dlogits_ptr, # data ptrs\n dloss_ptr,\n logits_ptr,\n lse_ptr,\n labels_ptr,\n smoothing,\n logit_scale,\n lse_square_scale,\n ignore_index,\n total_classes,\n class_start_idx, # Useful for tensor parallel when each rank only has a subset of classes\n n_cols, # shapes\n logits_row_stride, # strides\n dlogits_row_stride,\n dloss_row_stride,\n BLOCK_SIZE: tl.constexpr,\n HAS_SMOOTHING: tl.constexpr,\n):\n row_idx = tl.program_id(0)\n col_block_idx = tl.program_id(1)\n logits_ptr = logits_ptr + row_idx * logits_row_stride.to(tl.int64)\n dlogits_ptr = dlogits_ptr + row_idx * dlogits_row_stride.to(tl.int64)\n col_offsets = col_block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)\n label_idx = tl.load(labels_ptr + row_idx)\n if label_idx != ignore_index:\n dloss = tl.load(dloss_ptr + row_idx * dloss_row_stride)\n else:\n dloss = 0.0\n logits = tl.load(logits_ptr + col_offsets, mask=col_offsets < n_cols, other=-float(\"inf\")).to(\n tl.float32\n ) * logit_scale\n lse = tl.load(lse_ptr + row_idx)\n probs = tl.exp(logits - lse)\n probs += 2.0 * lse_square_scale * lse * probs\n label_idx -= class_start_idx\n if HAS_SMOOTHING:\n smooth_positive = 1.0 - smoothing\n smooth_negative = smoothing / total_classes\n probs = tl.where(col_offsets == label_idx, probs - smooth_positive, probs) - smooth_negative\n else:\n probs = tl.where(col_offsets == label_idx, probs - 1.0, probs)\n tl.store(dlogits_ptr + col_offsets, (dloss * logit_scale) * probs, mask=col_offsets < n_cols)\n\n\nclass CrossEntropyLoss(torch.autograd.Function):\n\n @staticmethod\n def forward(\n ctx,\n logits,\n labels,\n precomputed_lse=None,\n smoothing=0.0,\n logit_scale=1.0,\n lse_square_scale=0.0,\n ignore_index=-100,\n inplace_backward=False,\n process_group=None,\n ):\n if labels.dtype == torch.long and labels.data_ptr() % 16 != 0:\n labels = F.pad(labels, (0, 1))[..., :-1]\n assert labels.data_ptr() % 16 == 0\n n_rows, n_cols = logits.shape\n assert labels.shape == (n_rows,)\n world_size = 1 if process_group is None else torch.distributed.get_world_size(process_group)\n total_classes = world_size * n_cols\n rank = 0 if process_group is None else torch.distributed.get_rank(process_group)\n class_start_idx = rank * n_cols\n use_precomputed_lse = precomputed_lse is not None and logit_scale == 1.0 and smoothing == 0.0\n\n if logits.stride(-1) != 1:\n logits = logits.contiguous()\n MAX_BLOCK_SIZE = 16 * 1024\n BLOCK_SIZE = min(triton.next_power_of_2(n_cols), MAX_BLOCK_SIZE)\n num_warps = (\n 4\n if BLOCK_SIZE < 2048\n else (8 if BLOCK_SIZE < 8192 else (16 if BLOCK_SIZE < 128 * 1024 else 32))\n )\n losses = torch.empty(n_rows, dtype=torch.float, device=logits.device)\n if use_precomputed_lse:\n assert precomputed_lse.shape == (n_rows,)\n lse = precomputed_lse.contiguous()\n else:\n lse = torch.empty(n_rows, dtype=torch.float, device=logits.device)\n z_losses = torch.empty(n_rows, dtype=torch.float, device=logits.device)\n with torch.cuda.device(logits.device.index):\n cross_entropy_fwd_kernel[(n_rows,)](\n losses, # data ptrs\n lse,\n z_losses,\n logits,\n labels,\n smoothing,\n logit_scale,\n lse_square_scale,\n ignore_index,\n total_classes,\n class_start_idx,\n n_cols, # shapes\n logits.stride(0), # strides\n BLOCK_SIZE=BLOCK_SIZE, # constants\n SPLIT=world_size > 1,\n PRECOMPUTED_LSE=use_precomputed_lse,\n num_warps=num_warps,\n )\n\n if world_size > 1:\n if world_size > 1:\n lse_allgather = torch.empty(world_size, n_rows, dtype=lse.dtype, device=lse.device)\n torch.distributed.all_gather_into_tensor(lse_allgather, lse, group=process_group)\n handle_losses = torch.distributed.all_reduce(\n losses, op=torch.distributed.ReduceOp.SUM, group=process_group, async_op=True\n )\n lse = torch.logsumexp(lse_allgather, dim=0)\n handle_losses.wait()\n losses += lse\n if lse_square_scale != 0.0:\n z_losses = lse_square_scale * lse.square()\n z_losses.masked_fill_(labels == ignore_index, 0.0)\n losses += z_losses\n else:\n z_losses = torch.zeros_like(losses)\n losses.masked_fill_(labels == ignore_index, 0.0)\n\n ctx.save_for_backward(logits, lse, labels)\n ctx.mark_non_differentiable(z_losses)\n ctx.smoothing = smoothing\n ctx.logit_scale = logit_scale\n ctx.lse_square_scale = lse_square_scale\n ctx.ignore_index = ignore_index\n ctx.total_classes = total_classes\n ctx.class_start_idx = class_start_idx\n ctx.inplace_backward = inplace_backward\n return losses, z_losses\n\n @staticmethod\n def backward(ctx, grad_losses, grad_z_losses):\n del grad_z_losses # z_losses are only for logging.\n\n logits, lse, labels = ctx.saved_tensors\n dlogits = logits if ctx.inplace_backward else torch.empty_like(logits)\n n_rows, n_cols = logits.shape\n BLOCK_SIZE = min(triton.next_power_of_2(n_cols), 4 * 1024)\n num_warps = 4 if BLOCK_SIZE < 2048 else (8 if BLOCK_SIZE < 8192 else 16)\n grid = lambda META: (n_rows, triton.cdiv(n_cols, META[\"BLOCK_SIZE\"])) # noqa\n with torch.cuda.device(logits.device.index):\n cross_entropy_bwd_kernel[grid](\n dlogits, # data ptrs\n grad_losses,\n logits,\n lse,\n labels,\n ctx.smoothing,\n ctx.logit_scale,\n ctx.lse_square_scale,\n ctx.ignore_index,\n ctx.total_classes,\n ctx.class_start_idx,\n n_cols, # shapes\n logits.stride(0), # strides\n dlogits.stride(0),\n grad_losses.stride(0),\n BLOCK_SIZE=BLOCK_SIZE, # constants\n num_warps=num_warps,\n )\n return dlogits, None, None, None, None, None, None, None, None, None\n\n\ndef cross_entropy_loss(\n logits: torch.Tensor,\n labels: torch.Tensor,\n precomputed_lse: Optional[torch.Tensor] = None,\n label_smoothing: float = 0.0,\n logit_scale: float = 1.0,\n lse_square_scale: float = 0.0,\n ignore_index=-100,\n inplace_backward: bool = False,\n process_group=None,\n) -> Tuple[torch.Tensor, torch.Tensor]:\n \"\"\"\n Arguments:\n logits: (batch, vocab_size)\n labels: (batch,)\n label_smoothing: float\n logit_scale: float. Multiply logits by this scale before calculating the loss.\n lse_square_scale: float. If > 0, we add lse_square_scale * lse(logits) ^ 2 to the loss.\n This is also referred to as \"z-loss\".\n ignore_index: int. If labels == ignore_index, the loss is set to 0.0.\n inplace_backward: bool. If True, we do the backward pass in-place by modifying the logits.\n This saves memory.\n process_group: if not None, we're doing Tensor Parallel: each process is responsible for\n one part of the vocab. The loss will be aggregated across processes.\n Returns:\n losses: (batch,), float\n z_losses: (batch,), float\n \"\"\"\n return CrossEntropyLoss.apply(\n logits,\n labels,\n precomputed_lse,\n label_smoothing,\n logit_scale,\n lse_square_scale,\n ignore_index,\n inplace_backward,\n process_group,\n )\n", - "description_1": "Use triton language to implement cross-entropy loss calculation with optional label smoothing and z-loss support. There are two main kernels: `cross_entropy_fwd_kernel` which computes the forward pass of the cross-entropy loss, and `cross_entropy_bwd_kernel` which computes the gradient (backward pass) for the logits. The forward kernel requires 18 parameters including pointers to the input/output tensors, constants related to the kernel execution (like block size), and options for smoothing. The backward kernel needs 18 parameters as well, focused on gradient computation with similar pointers and execution constants. An additional PyTorch function `cross_entropy_loss` manages the kernel execution, distributing the forward and backward passes across rows of the input logits.", - "description_2": "Use triton language to implement cross-entropy loss computation with both forward and backward kernel execution. Forward kernel computes loss with optional label smoothing, backward kernel calculates gradients for backpropagation. Integrate with PyTorch for tensor operations.", - "difficulty": 4 - }, - { - "code": "import triton\nimport triton.language as tl\nimport math\nfrom enum import Enum\nfrom typing import Optional\n\n_sqrt2pi = math.sqrt(2.0 / math.pi)\n_sqrt1_2 = math.sqrt(1.0 / 2)\n_gaussian_pdf_normalization = 1.0 / math.sqrt(2 * math.pi)\n\n\nclass Activation(str, Enum):\n SquaredReLU = \"squared_relu\"\n GeLU = \"gelu\"\n GeLUApprox = \"gelu_approx\"\n LeakyReLU = \"leaky_relu\"\n ReLU = \"relu\"\n\n\ndef get_triton_activation_kernel(activation: Optional[Activation]):\n return (\n {\n Activation.ReLU: relu,\n Activation.LeakyReLU: leaky_relu,\n Activation.GeLU: gelu,\n Activation.GeLUApprox: gelu_approx,\n Activation.SquaredReLU: squared_relu,\n }[activation]\n if activation\n else None\n )\n\n\ndef get_triton_activation_bwd_kernel(activation: Optional[Activation]):\n return (\n {\n Activation.ReLU: relu_grad,\n Activation.LeakyReLU: leaky_relu_grad,\n Activation.GeLU: gelu_grad,\n Activation.GeLUApprox: gelu_approx_grad,\n Activation.SquaredReLU: squared_relu_grad,\n }[activation]\n if activation\n else None\n )\n\n\n@triton.jit\ndef tanh(x):\n # Tanh is just a scaled sigmoid\n return 2 * tl.sigmoid(2 * x) - 1\n\n\n@triton.jit\ndef cosh(x):\n exp_x = tl.exp(x)\n return (exp_x + 1.0 / exp_x) * 0.5\n\n\n# ReLU\n@triton.jit\ndef relu(x):\n \"\"\"\n ReLU_ activation function\n\n .. _ReLU: https://pytorch.org/docs/stable/generated/torch.nn.ReLU.html\n \"\"\"\n zero = 0.0\n return tl.where(x >= 0, x, zero.to(x.dtype))\n\n\n@triton.jit\ndef relu_grad(x):\n # ReLU is different from other activations\n # in that it does not require the input to retrospectively compute its gradient\n # here the input is the downstream gradient, and we return the upstream gradient directly\n zero = 0.0\n one = 1.0\n return tl.where(x >= 0, one.to(x.dtype), zero.to(x.dtype))\n\n\n@triton.jit\ndef squared_relu(x):\n \"\"\"\n Squared ReLU activation, as proposed in the Primer_ paper.\n\n .. _Primer: https://arxiv.org/abs/2109.08668\n \"\"\"\n x_ = relu(x)\n return (x_ * x_).to(x.dtype)\n\n\n@triton.jit\ndef squared_relu_grad(x):\n return tl.where(x >= 0, 2.0 * x, 0.0)\n\n\n# Leaky ReLU\n@triton.jit\ndef leaky_relu(x):\n \"\"\"\n LeakyReLU_ activation\n\n .. _LeakyReLU: https://pytorch.org/docs/stable/generated/torch.nn.LeakyReLU.html\n \"\"\"\n scale = 0.01 + 0.0\n scale = scale.to(x.dtype)\n return tl.where(x >= 0, x, scale * x)\n\n\n@triton.jit\ndef leaky_relu_grad(x):\n min_grad = 0.01\n max_grad = 1\n\n min_grad = min_grad.to(x.dtype)\n max_grad = max_grad.to(x.dtype)\n\n return tl.where(x >= 0, max_grad, min_grad)\n\n\n@triton.jit\ndef gelu(x):\n \"\"\"Gaussian Error Linear Unit (GELU)\"\"\"\n return x * 0.5 * (1.0 + tl.libdevice.erf(x * _sqrt1_2))\n\n\n@triton.jit\ndef gelu_grad(x):\n cdf = 0.5 * (1.0 + tl.libdevice.erf(x * _sqrt1_2))\n pdf = tl.exp(-0.5 * x * x) * _gaussian_pdf_normalization\n return cdf + x * pdf\n\n\n@triton.jit\ndef gelu_approx(x):\n \"\"\"\n GeLU_ activation - Gaussian error linear unit, with tanh approximation\n\n .. _GeLU: https://arxiv.org/pdf/1606.08415.pdf\n \"\"\"\n return 0.5 * x * (1.0 + tanh(_sqrt2pi * x * (1.0 + 0.044715 * x * x)))\n\n\n@triton.jit\ndef gelu_approx_grad(x):\n # CREDITS: Fast implementation proposed in\n # https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/fused_bias_gelu.py#L30\n tanh_out = tanh(0.79788456 * x * (1 + 0.044715 * x * x))\n return 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (\n 1 + tanh_out\n )\n", - "description_1": "Use triton language to implement various activation functions and their gradients, including ReLU, Leaky ReLU, Squared ReLU, GELU, and GELU approximation. Each function takes a single parameter 'x', which is a tensor, and applies the respective activation or gradient computation.", - "description_2": "Use triton language to create activation functions and their gradients for neural networks, such as ReLU and GELU.", - "difficulty": 2 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.autotune(\n configs=[\n triton.Config(\n {\"BLOCK_M\": 128, \"BLOCK_N\": 256, \"BLOCK_K\": 32, \"SPLIT_K\": 1}, num_stages=3, num_warps=8\n ),\n triton.Config(\n {\"BLOCK_M\": 256, \"BLOCK_N\": 128, \"BLOCK_K\": 32, \"SPLIT_K\": 1}, num_stages=3, num_warps=8\n ),\n triton.Config(\n {\"BLOCK_M\": 256, \"BLOCK_N\": 64, \"BLOCK_K\": 32, \"SPLIT_K\": 1}, num_stages=4, num_warps=4\n ),\n triton.Config(\n {\"BLOCK_M\": 64, \"BLOCK_N\": 256, \"BLOCK_K\": 32, \"SPLIT_K\": 1}, num_stages=4, num_warps=4\n ),\n triton.Config(\n {\"BLOCK_M\": 128, \"BLOCK_N\": 128, \"BLOCK_K\": 32, \"SPLIT_K\": 1}, num_stages=4, num_warps=4\n ),\n triton.Config(\n {\"BLOCK_M\": 128, \"BLOCK_N\": 64, \"BLOCK_K\": 32, \"SPLIT_K\": 1}, num_stages=4, num_warps=4\n ),\n triton.Config(\n {\"BLOCK_M\": 64, \"BLOCK_N\": 128, \"BLOCK_K\": 32, \"SPLIT_K\": 1}, num_stages=4, num_warps=4\n ),\n triton.Config(\n {\"BLOCK_M\": 128, \"BLOCK_N\": 32, \"BLOCK_K\": 32, \"SPLIT_K\": 1}, num_stages=4, num_warps=4\n ),\n triton.Config(\n {\"BLOCK_M\": 64, \"BLOCK_N\": 32, \"BLOCK_K\": 32, \"SPLIT_K\": 1}, num_stages=5, num_warps=2\n ),\n ]\n + get_configs_io_bound(),\n key=[\"CACHE_KEY_M\", \"CACHE_KEY_N\", \"CACHE_KEY_K\"],\n prune_configs_by={\n \"early_config_prune\": early_config_prune,\n \"perf_model\": estimate_matmul_time,\n \"top_k\": 10,\n },\n)\n@triton.heuristics(\n {\n \"EVEN_K\": lambda args: args[\"K\"] % (args[\"BLOCK_K\"] * args[\"SPLIT_K\"]) == 0,\n }\n)\n@triton.jit\ndef kernel_fwd(\n C, # Pointers to matrices\n ACT_INPUT,\n A,\n B,\n bias,\n # Matrix dimensions\n M,\n N,\n K,\n CACHE_KEY_M,\n CACHE_KEY_N,\n CACHE_KEY_K,\n stride_cm,\n stride_am,\n stride_ak,\n stride_bn,\n stride_bk,\n BLOCK_M: tl.constexpr,\n GROUP_M: tl.constexpr,\n BLOCK_N: tl.constexpr,\n BLOCK_K: tl.constexpr,\n SPLIT_K: tl.constexpr,\n EVEN_K: tl.constexpr,\n A_ROWMAJOR: tl.constexpr,\n B_COLMAJOR: tl.constexpr,\n BIAS: tl.constexpr,\n SAVE_ACT_INPUT: tl.constexpr,\n ACTIVATION: tl.constexpr,\n):\n\n pid = tl.program_id(axis=0)\n\n grid_m = (M + BLOCK_M - 1) // BLOCK_M\n grid_n = (N + BLOCK_N - 1) // BLOCK_N\n width = GROUP_M * grid_n\n group_id = pid // width\n group_size = min(grid_m - group_id * GROUP_M, GROUP_M)\n pid_m = group_id * GROUP_M + (pid % group_size)\n pid_n = (pid % width) // (group_size)\n\n rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)\n rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)\n ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)\n rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)\n rk = tl.arange(0, BLOCK_K)\n\n if A_ROWMAJOR:\n A = A + (ram[:, None] * stride_am + rk[None, :])\n else:\n A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)\n if B_COLMAJOR:\n B = B + (rk[:, None] + rbn[None, :] * stride_bn)\n else:\n B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)\n\n acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)\n\n for k in range(K, 0, -BLOCK_K):\n if EVEN_K:\n a = tl.load(A)\n b = tl.load(B)\n else:\n a = tl.load(A, mask=rk[None, :] < k, other=0.0)\n b = tl.load(B, mask=rk[:, None] < k, other=0.0)\n acc += tl.dot(a, b)\n\n if A_ROWMAJOR:\n A += BLOCK_K\n else:\n A += BLOCK_K * stride_ak\n if B_COLMAJOR:\n B += BLOCK_K\n else:\n B += BLOCK_K * stride_bk\n\n if BIAS:\n bias = tl.load(bias + rn, mask=rn < N, other=0.0).to(tl.float32)\n acc += bias[None, :]\n\n if SAVE_ACT_INPUT:\n act_in_ptrs = ACT_INPUT + ram[:, None] * stride_cm + rbn[None, :]\n tl.store(act_in_ptrs, acc)\n\n if ACTIVATION == \"gelu\":\n acc = gelu(acc)\n elif ACTIVATION == \"gelu_approx\":\n acc = gelu_approx(acc)\n elif ACTIVATION == \"squared_relu\":\n acc = squared_relu(acc)\n\n rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)\n rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)\n\n C = C + rm[:, None] * stride_cm + rn[None, :]\n mask = (rm < M)[:, None] & (rn < N)[None, :]\n tl.store(C, acc)\n\n\ndef triton_linear_act(\n x: torch.Tensor,\n weight: torch.Tensor,\n bias: Optional[torch.Tensor] = None,\n activation: str = \"id\",\n save_act_input: bool = False,\n) -> torch.Tensor:\n\n assert activation in [\"id\", \"gelu\", \"gelu_approx\", \"squared_relu\"]\n\n batch_shape, n = x.shape[:-1], x.shape[-1]\n batch_dim = batch_shape.numel()\n x_reshaped = x.reshape(batch_dim, n)\n\n if x_reshaped.stride(0) > 1 and x_reshaped.stride(1) > 1:\n x_reshaped = x_reshaped.contiguous()\n if weight.stride(0) > 1 and weight.stride(1) > 1:\n weight = weight.contiguous()\n bias = bias.contiguous() if bias is not None else None\n\n assert (\n x.dtype == weight.dtype\n ), f\"Input and weight must have the same dtype, got {x.dtype} and {weight.dtype}\"\n if bias is not None:\n assert (\n x.dtype == bias.dtype\n ), f\"Input and bias must have the same dtype, got {x.dtype} and {bias.dtype}\"\n assert (\n x_reshaped.shape[1] == weight.shape[1]\n ), f\"Incompatible dimensions: {x_reshaped.shape} - {weight.shape}\"\n\n assert (\n bias is None or bias.shape[0] == weight.shape[0]\n ), \"Incompatible dimensions in between weight and bias\"\n\n M, K = x_reshaped.shape\n N, K = weight.shape\n\n output = torch.empty((M, N), device=x.device, dtype=x.dtype)\n act_input = torch.empty_like(output) if save_act_input else None\n\n grid = lambda META: (triton.cdiv(M, META[\"BLOCK_M\"]) * triton.cdiv(N, META[\"BLOCK_N\"]),)\n\n kernel_fwd[grid](\n output,\n act_input,\n x_reshaped,\n weight,\n bias if bias is not None else x,\n M,\n N,\n K,\n M // 32,\n N // 32,\n K // 32,\n stride_cm=output.stride(0),\n stride_am=x_reshaped.stride(0),\n stride_ak=x_reshaped.stride(1),\n stride_bk=weight.stride(1),\n stride_bn=weight.stride(0),\n BIAS=bias is not None,\n SAVE_ACT_INPUT=save_act_input,\n ACTIVATION=activation,\n A_ROWMAJOR=x_reshaped.stride(1) == 1,\n B_COLMAJOR=weight.stride(1) == 1,\n GROUP_M=8,\n )\n\n if not save_act_input:\n return output.reshape(*batch_shape, output.shape[-1])\n else:\n return (\n output.reshape(*batch_shape, output.shape[-1]),\n act_input.reshape(*batch_shape, act_input.shape[-1]),\n )\n\n\n@triton.autotune(\n configs=[\n triton.Config(\n {\"BLOCK_M\": 128, \"BLOCK_N\": 256, \"BLOCK_K\": 32, \"SPLIT_K\": 1}, num_stages=3, num_warps=8\n ),\n triton.Config(\n {\"BLOCK_M\": 256, \"BLOCK_N\": 128, \"BLOCK_K\": 32, \"SPLIT_K\": 1}, num_stages=3, num_warps=8\n ),\n triton.Config(\n {\"BLOCK_M\": 256, \"BLOCK_N\": 64, \"BLOCK_K\": 32, \"SPLIT_K\": 1}, num_stages=4, num_warps=4\n ),\n triton.Config(\n {\"BLOCK_M\": 64, \"BLOCK_N\": 256, \"BLOCK_K\": 32, \"SPLIT_K\": 1}, num_stages=4, num_warps=4\n ),\n triton.Config(\n {\"BLOCK_M\": 128, \"BLOCK_N\": 128, \"BLOCK_K\": 32, \"SPLIT_K\": 1}, num_stages=4, num_warps=4\n ),\n triton.Config(\n {\"BLOCK_M\": 128, \"BLOCK_N\": 64, \"BLOCK_K\": 32, \"SPLIT_K\": 1}, num_stages=4, num_warps=4\n ),\n triton.Config(\n {\"BLOCK_M\": 64, \"BLOCK_N\": 128, \"BLOCK_K\": 32, \"SPLIT_K\": 1}, num_stages=4, num_warps=4\n ),\n triton.Config(\n {\"BLOCK_M\": 128, \"BLOCK_N\": 32, \"BLOCK_K\": 32, \"SPLIT_K\": 1}, num_stages=4, num_warps=4\n ),\n triton.Config(\n {\"BLOCK_M\": 64, \"BLOCK_N\": 32, \"BLOCK_K\": 32, \"SPLIT_K\": 1}, num_stages=5, num_warps=2\n ),\n ]\n + get_configs_io_bound(),\n key=[\"CACHE_KEY_M\", \"CACHE_KEY_N\", \"CACHE_KEY_K\"],\n prune_configs_by={\n \"early_config_prune\": early_config_prune,\n \"perf_model\": estimate_matmul_time,\n \"top_k\": 10,\n },\n)\n@triton.heuristics(\n {\n \"EVEN_K\": lambda args: args[\"K\"] % (args[\"BLOCK_K\"] * args[\"SPLIT_K\"]) == 0,\n }\n)\n@triton.jit\ndef kernel_bwd(\n C, # Pointers to matrices\n ACT_INPUT,\n A,\n B,\n M,\n N,\n K,\n CACHE_KEY_M,\n CACHE_KEY_N,\n CACHE_KEY_K,\n stride_cm,\n stride_am,\n stride_ak,\n stride_bk,\n stride_bn,\n BLOCK_M: tl.constexpr,\n GROUP_M: tl.constexpr,\n BLOCK_N: tl.constexpr,\n BLOCK_K: tl.constexpr,\n SPLIT_K: tl.constexpr,\n EVEN_K: tl.constexpr,\n ACTIVATION: tl.constexpr,\n):\n\n pid = tl.program_id(axis=0)\n\n grid_m = (M + BLOCK_M - 1) // BLOCK_M\n grid_n = (N + BLOCK_N - 1) // BLOCK_N\n width = GROUP_M * grid_n\n group_id = pid // width\n group_size = min(grid_m - group_id * GROUP_M, GROUP_M)\n pid_m = group_id * GROUP_M + (pid % group_size)\n pid_n = (pid % width) // (group_size)\n\n rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)\n rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)\n ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)\n rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)\n rk = tl.arange(0, BLOCK_K)\n\n A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)\n B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)\n\n acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)\n\n for k in range(K, 0, -BLOCK_K):\n if EVEN_K:\n a = tl.load(A)\n b = tl.load(B)\n else:\n a = tl.load(A, mask=rk[None, :] < k, other=0.0)\n b = tl.load(B, mask=rk[:, None] < k, other=0.0)\n acc += tl.dot(a, b)\n\n A += BLOCK_K * stride_ak\n B += BLOCK_K * stride_bk\n\n if ACTIVATION != \"id\":\n act_in_ptrs = ACT_INPUT + ram[:, None] * stride_cm + rbn[None, :]\n act_input = tl.load(act_in_ptrs).to(acc.dtype)\n if ACTIVATION == \"gelu\":\n acc *= gelu_grad(act_input)\n elif ACTIVATION == \"gelu_approx\":\n acc *= gelu_approx_grad(act_input)\n elif ACTIVATION == \"squared_relu\":\n acc *= squared_relu_grad(act_input)\n\n rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)\n rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)\n\n C = C + rm[:, None] * stride_cm + rn[None, :]\n mask = (rm < M)[:, None] & (rn < N)[None, :]\n tl.store(C, acc, mask=mask)\n\n\ndef triton_dgrad_act(\n grad_output: torch.Tensor,\n weight: torch.Tensor,\n activation: str = \"id\",\n act_input: Optional[torch.Tensor] = None,\n) -> torch.Tensor:\n\n assert activation in [\"id\", \"gelu\", \"gelu_approx\", \"squared_relu\"]\n\n batch_shape, n = grad_output.shape[:-1], grad_output.shape[-1]\n batch_dim = batch_shape.numel()\n grad_output_reshaped = grad_output.reshape(batch_dim, n)\n\n if grad_output_reshaped.stride(0) > 1 and grad_output_reshaped.stride(1) > 1:\n grad_output_reshaped = grad_output_reshaped.contiguous()\n if weight.stride(0) > 1 and weight.stride(1) > 1:\n weight = weight.contiguous()\n\n assert (\n grad_output.dtype == weight.dtype\n ), f\"grad_output and weight must have the same dtype, got {grad_output.dtype} and {weight.dtype}\"\n assert (\n grad_output_reshaped.shape[1] == weight.shape[0]\n ), f\"Incompatible dimensions: {grad_output_reshaped.shape} - {weight.shape}\"\n if activation != \"id\":\n assert act_input is not None, f\"act_input is required for activation {activation}\"\n\n M, K = grad_output_reshaped.shape\n K, N = weight.shape\n\n grad_input = torch.empty((M, N), device=grad_output.device, dtype=grad_output.dtype)\n\n grid = lambda META: (triton.cdiv(M, META[\"BLOCK_M\"]) * triton.cdiv(N, META[\"BLOCK_N\"]),)\n\n kernel_bwd[grid](\n grad_input,\n act_input,\n grad_output_reshaped,\n weight,\n M,\n N,\n K,\n M // 32,\n N // 32,\n K // 32,\n stride_cm=grad_input.stride(0),\n stride_am=grad_output_reshaped.stride(0),\n stride_ak=grad_output_reshaped.stride(1),\n stride_bk=weight.stride(0),\n stride_bn=weight.stride(1),\n ACTIVATION=activation,\n GROUP_M=8,\n )\n\n return grad_input.reshape(*batch_shape, grad_input.shape[-1])\n", - "description_1": "Use triton language to implement two kernels and their calling functions for linear layer with optional activation and backpropagation. The first kernel `kernel_fwd` takes 35 parameters, including pointers to matrices (C, ACT_INPUT, A, B), bias, dimensions (M, N, K), cache keys, strides, and several meta-parameters for controlling matrix blocking, activation handling, and saving activation input. It computes an output matrix by performing matrix multiplication of A and B, optionally adds bias, applies activation, and can save activation input. The second kernel `kernel_bwd` takes 27 parameters, including pointers to matrices (C, ACT_INPUT, A, B), dimensions (M, N, K), cache keys, strides, and several meta-parameters similar to `kernel_fwd`, excluding bias and activation saving parameters. It computes the gradient for the backpropagation. Both kernels leverage Triton for autotuning configurations and heuristics optimization.", - "description_2": "Use triton language to implement linear layer computations with optional activation using `kernel_fwd`, and its gradient computation for backpropagation with `kernel_bwd`. Include proper handling of matrix dimensions, strides, and optional operations such as bias addition and activation function application.", - "difficulty": 4 - }, - { - "code": "import triton\nimport triton.language as tl\n\n@triton.jit\ndef print_grid():\n # Get the process ID for each dimension\n x_pid = tl.program_id(0) # Process ID in the x dimension\n y_pid = tl.program_id(1) # Process ID in the y dimension\n z_pid = tl.program_id(2) # Process ID in the z dimension\n # Print the process IDs\n tl.device_print(\"x_pid: \", x_pid)\n tl.device_print(\"y_pid: \", y_pid)\n tl.device_print(\"z_pid: \", z_pid)\n\ndef grid(meta):\n \"\"\"\n Args: meta is the meta information that can be used to determine the grid\n \"\"\"\n return (4, 2)\n\n# Launch the kernel with the specified grid\nprint_grid[grid]()\n", - "description_1": "Use triton language to define a kernel 'print_grid' that prints the process IDs for x, y, and z dimensions using tl.device_print. The kernel is launched with a grid determined by the 'grid' function, which returns a tuple (4, 2) representing the grid dimensions.", - "description_2": "Use triton language to create a kernel that prints process IDs for three dimensions and launch it with a specified grid.", - "difficulty": 1 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n# Triton kernel for element-wise addition\n@triton.jit\ndef add_kernel(x_ptr, y_ptr, z_ptr, size, block_size: tl.constexpr):\n # Get program ID\n pid = tl.program_id(0)\n # Calculate offsets for the current block\n offsets = tl.arange(0, block_size) + pid * block_size \n # Create a mask for valid offsets\n mask = offsets < size\n \n # Load x and y values from memory\n x = tl.load(x_ptr + offsets, mask)\n y = tl.load(y_ptr + offsets, mask)\n \n # Perform element-wise addition\n z = x + y\n \n # Store the result back to memory\n tl.store(z_ptr + offsets, z, mask)\n \n# Function to call the Triton kernel\ndef add(x, y):\n # Create an empty tensor for the result\n z = torch.empty_like(x, device='cuda')\n size = z.numel()\n \n # Define the grid size for the kernel launch\n def grid(meta):\n return (triton.cdiv(size, meta[\"block_size\"]),)\n \n # Launch the Triton kernel\n add_kernel[grid](x, y, z, size, 1024)\n \n return z\n\n# Example usage\nsize = 2 ** 16\nx = torch.randn(size, device=\"cuda\")\ny = torch.randn(size, device=\"cuda\")\na = add(x, y)\nb = x + y\nassert torch.allclose(a, b)\n", - "description_1": "Use triton language to implement an element-wise addition kernel. The kernel 'add_kernel' takes 5 parameters: x_ptr, y_ptr, z_ptr (pointers to input and output tensors), size (total number of elements), and block_size (size of each block). The function 'add' calls this kernel, preparing the output tensor and determining the grid size for execution.", - "description_2": "Use triton language to perform element-wise addition of two tensors on the GPU using a custom kernel.", - "difficulty": 3 - }, - { - "code": "import triton\nimport triton.language as tl\n\n@triton.jit\ndef _reg_matmul(\n pid_n, type_id,\n start_off,\n input, other, output, N,\n stride_input_m, stride_input_k,\n stride_other_b, stride_other_k, stride_other_n,\n stride_output_m, stride_output_n,\n out_dtype: tl.constexpr,\n BLOCK_SIZE: tl.constexpr,\n EVEN_N: tl.constexpr,\n TILE_M: tl.constexpr,\n TILE_N: tl.constexpr,\n TILE_K: tl.constexpr\n):\n offs_m = start_off + tl.arange(0, TILE_M)\n offs_n = pid_n * TILE_N + tl.arange(0, TILE_N)\n offs_k = tl.arange(0, TILE_K)\n rn = tl.max_contiguous(tl.multiple_of(offs_n % N, TILE_N), TILE_N)\n other_ptrs = other + type_id * stride_other_b + \\\n (offs_k[:, None] * stride_other_k + rn[None, :] * stride_other_n)\n b = tl.load(other_ptrs)\n\n # [M, K] x [K, N] -> [M, N]\n input_ptrs = input + (offs_m[:, None] * stride_input_m + offs_k[None, :] * stride_input_k)\n output_ptrs = output + stride_output_m * offs_m[:, None] + stride_output_n * offs_n[None, :]\n for _ in range(0, BLOCK_SIZE):\n a = tl.load(input_ptrs)\n acc = tl.dot(a, b, out_dtype=out_dtype).to(output.dtype.element_ty)\n if EVEN_N:\n tl.store(output_ptrs, acc)\n else:\n mask_n = offs_n[None, :] < N\n tl.store(output_ptrs, acc, mask=mask_n)\n input_ptrs += TILE_M * stride_input_m\n output_ptrs += TILE_M * stride_output_m\n\n\n@triton.jit\ndef _general_matmul(\n pid_n,\n start_off, end_off,\n input, other, output,\n K, N,\n stride_input_m, stride_input_k,\n stride_other_k, stride_other_n,\n stride_output_m, stride_output_n,\n out_dtype: tl.constexpr,\n MASK_M: tl.constexpr,\n TILE_M: tl.constexpr,\n TILE_N: tl.constexpr,\n TILE_K: tl.constexpr,\n EVEN_N: tl.constexpr,\n EVEN_K: tl.constexpr\n):\n offs_m = start_off + tl.arange(0, TILE_M)\n offs_n = pid_n * TILE_N + tl.arange(0, TILE_N)\n offs_k = tl.arange(0, TILE_K)\n rn = tl.max_contiguous(tl.multiple_of(offs_n % N, TILE_N), TILE_N)\n\n # [M, K] x [K, N] -> [M, N]\n input_ptrs = input + (offs_m[:, None] * stride_input_m + offs_k[None, :] * stride_input_k)\n other_ptrs = other + \\\n (offs_k[:, None] * stride_other_k + rn[None, :] * stride_other_n)\n\n acc = tl.zeros((TILE_M, TILE_N), dtype=out_dtype)\n mask_m = offs_m[:, None] < end_off if MASK_M else True\n\n k_iter = K // TILE_K if EVEN_K else tl.cdiv(K, TILE_K)\n for k in range(0, k_iter):\n if EVEN_K:\n if MASK_M:\n a = tl.load(input_ptrs, mask=mask_m, other=0.0)\n b = tl.load(other_ptrs)\n else:\n a = tl.load(input_ptrs)\n b = tl.load(other_ptrs)\n else:\n if MASK_M:\n a = tl.load(input_ptrs, mask=mask_m & (offs_k[None, :] + k * TILE_K < K), other=0.0)\n b = tl.load(other_ptrs, mask=(offs_k[:, None] + k * TILE_K < K), other=0.0)\n else:\n a = tl.load(input_ptrs, mask=(offs_k[None, :] + k * TILE_K < K), other=0.0)\n b = tl.load(other_ptrs, mask=(offs_k[:, None] + k * TILE_K < K), other=0.0)\n acc += tl.dot(a, b, out_dtype=out_dtype)\n input_ptrs += TILE_K * stride_input_k\n other_ptrs += TILE_K * stride_other_k\n\n acc = acc.to(output.dtype.element_ty)\n c_ptrs = output + stride_output_m * \\\n offs_m[:, None] + stride_output_n * offs_n[None, :]\n if EVEN_N:\n if MASK_M:\n tl.store(c_ptrs, acc, mask=mask_m)\n else:\n tl.store(c_ptrs, acc)\n else:\n mask_n = offs_n[None, :] < N\n if MASK_M:\n tl.store(c_ptrs, acc, mask=mask_m & mask_n)\n else:\n tl.store(c_ptrs, acc, mask_n)\n\n\n@triton.jit\ndef _prefetch_matmul(\n pid_n, start_off, end_off,\n input, other, output,\n K, N,\n stride_input_m, stride_input_k,\n stride_other_k, stride_other_n,\n stride_output_m, stride_output_n,\n out_dtype: tl.constexpr,\n TILE_M: tl.constexpr,\n TILE_N: tl.constexpr,\n TILE_K: tl.constexpr,\n EVEN_N: tl.constexpr,\n EVEN_K: tl.constexpr,\n BLOCK_SIZE: tl.constexpr\n):\n offs_m = start_off + tl.arange(0, TILE_M)\n offs_n = pid_n * TILE_N + tl.arange(0, TILE_N)\n offs_k = tl.arange(0, TILE_K)\n rn = tl.max_contiguous(tl.multiple_of(offs_n % N, TILE_N), TILE_N)\n\n # [M, K] x [K, N] -> [M, N]\n input_ptrs = input + (offs_m[:, None] * stride_input_m + offs_k[None, :] * stride_input_k)\n other_ptrs = other + \\\n (offs_k[:, None] * stride_other_k + rn[None, :] * stride_other_n)\n output_ptrs = output + stride_output_m * offs_m[:, None] + stride_output_n * offs_n[None, :]\n original_input_ptrs = input_ptrs\n original_other_ptrs = other_ptrs\n\n acc = tl.zeros((TILE_M, TILE_N), dtype=out_dtype)\n mask_n = offs_n[None, :] < N\n\n k_iters = K // TILE_K if EVEN_K else tl.cdiv(K, TILE_K)\n for k in range(0, k_iters * BLOCK_SIZE):\n i = k % k_iters\n if EVEN_K:\n a = tl.load(input_ptrs)\n b = tl.load(other_ptrs)\n else:\n a = tl.load(input_ptrs, mask=offs_k[None, :] + i * TILE_K < K, other=0.0)\n b = tl.load(other_ptrs, mask=offs_k[:, None] + i * TILE_K < K, other=0.0)\n acc += tl.dot(a, b, out_dtype=out_dtype)\n if i == k_iters - 1:\n if EVEN_N:\n tl.store(output_ptrs, acc.to(output.dtype.element_ty))\n else:\n tl.store(output_ptrs, acc.to(output.dtype.element_ty), mask_n)\n output_ptrs += TILE_M * stride_output_m\n if i == k_iters - 1:\n acc = tl.zeros((TILE_M, TILE_N), dtype=out_dtype)\n original_input_ptrs += TILE_M * stride_input_m\n input_ptrs = original_input_ptrs\n other_ptrs = original_other_ptrs\n else:\n input_ptrs += TILE_K * stride_input_k\n other_ptrs += TILE_K * stride_other_k\n\n\n@triton.jit\ndef _dynamic_matmul(\n pid_k, pid_n, next_id,\n input, grad_output, grad_other, grad_other_tiles,\n stride_input_m, stride_input_k,\n stride_grad_output_m, stride_grad_output_n,\n stride_grad_other_b, stride_grad_other_k, stride_grad_other_n,\n K, N, M, length,\n out_dtype: tl.constexpr,\n BLOCK_LENGTH: tl.constexpr,\n TILE_K: tl.constexpr,\n TILE_N: tl.constexpr,\n TILE_M: tl.constexpr,\n EVEN_N: tl.constexpr,\n EVEN_K: tl.constexpr,\n EVEN_M: tl.constexpr,\n DETERMINISTIC: tl.constexpr\n):\n offs_k = pid_k * TILE_K + tl.arange(0, TILE_K)\n offs_n = pid_n * TILE_N + tl.arange(0, TILE_N)\n offs_m = tl.arange(0, TILE_M)\n acc = tl.zeros((TILE_K, TILE_N), dtype=out_dtype)\n mask_k = offs_k[:, None] < K if not EVEN_K else True\n mask_n = offs_n[None, :] < N if not EVEN_N else True\n\n # [M, K] -> [K, M]\n input_ptrs = input + (offs_m[None, :] * stride_input_m + offs_k[:, None] * stride_input_k)\n # [M, N]\n grad_output_ptrs = grad_output + (offs_m[:, None] * stride_grad_output_m + offs_n[None, :] * stride_grad_output_n)\n\n m_iter = length // TILE_M if EVEN_M else tl.cdiv(length, TILE_M)\n for m in range(0, m_iter):\n if EVEN_K:\n if EVEN_M:\n a = tl.load(input_ptrs)\n else:\n a = tl.load(input_ptrs, mask=(offs_m[None, :] + m * TILE_M < length), other=0.0)\n else:\n if EVEN_M:\n a = tl.load(input_ptrs, mask=mask_k, other=0.0)\n else:\n a = tl.load(input_ptrs, mask=mask_k & (offs_m[None, :] + m * TILE_M < length), other=0.0)\n if EVEN_N:\n if EVEN_M:\n b = tl.load(grad_output_ptrs)\n else:\n b = tl.load(grad_output_ptrs, mask=(offs_m[:, None] + m * TILE_M < length), other=0.0)\n else:\n if EVEN_M:\n b = tl.load(grad_output_ptrs, mask=mask_n)\n else:\n b = tl.load(grad_output_ptrs, mask=mask_n & (offs_m[:, None] + m * TILE_M < length), other=0.0)\n\n acc += tl.dot(a, b, out_dtype=out_dtype)\n input_ptrs += TILE_M * stride_input_m\n grad_output_ptrs += TILE_M * stride_grad_output_m\n\n acc = acc.to(grad_other.dtype.element_ty)\n\n if DETERMINISTIC:\n if M <= BLOCK_LENGTH:\n c_ptrs = grad_other + \\\n stride_grad_other_k * offs_k[:, None] + stride_grad_other_n * offs_n[None, :]\n if EVEN_N and EVEN_K:\n tl.store(c_ptrs, acc)\n else:\n c_mask = mask_k & mask_n\n tl.store(c_ptrs, acc, mask=c_mask)\n else:\n c_ptrs = grad_other_tiles + \\\n next_id * stride_grad_other_b + stride_grad_other_k * offs_k[:, None] + stride_grad_other_n * offs_n[None, :]\n if EVEN_N and EVEN_K:\n tl.store(c_ptrs, acc)\n else:\n c_mask = mask_k & mask_n\n tl.store(c_ptrs, acc, mask=c_mask)\n else:\n c_ptrs = grad_other + \\\n stride_grad_other_k * offs_k[:, None] + stride_grad_other_n * offs_n[None, :]\n if M <= BLOCK_LENGTH:\n if EVEN_N and EVEN_K:\n tl.store(c_ptrs, acc)\n else:\n c_mask = mask_k & mask_n\n tl.store(c_ptrs, acc, mask=c_mask)\n else:\n if EVEN_N and EVEN_K:\n tl.atomic_add(c_ptrs, acc)\n else:\n c_mask = mask_k & mask_n\n tl.atomic_add(c_ptrs, acc, mask=c_mask)\n", - "description_1": "Use triton language to define multiple matrix multiplication kernels with varying levels of optimization and functionality, including block-based, general, prefetch, and dynamic behavior for both regular and gradient calculations. These functions take a range of parameters, from pointers to input, output, and intermediate matrices, to kernel size and tiling configurations, to ensure efficient parallel computations for matrix operations.", - "description_2": "Use triton language to implement and execute efficient block-based and dynamic matrix multiplication operations.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\nimport functools\nfrom triton.runtime import driver\n\n@triton.jit(noinline=True)\ndef _dispatch(\n pid_n,\n start_off, end_off,\n input, other, output,\n K, N,\n stride_input_m, stride_input_k,\n stride_other_k, stride_other_n,\n stride_output_m, stride_output_n,\n out_dtype: tl.constexpr,\n MASK_M: tl.constexpr,\n TILE_M: tl.constexpr,\n TILE_N: tl.constexpr,\n TILE_K: tl.constexpr,\n EVEN_K: tl.constexpr,\n EVEN_N: tl.constexpr,\n DYNAMIC_TILING: tl.constexpr\n):\n TILE_M_16: tl.constexpr = 16\n TILE_M_32: tl.constexpr = 32\n TILE_M_64: tl.constexpr = 64\n\n if end_off - start_off <= TILE_M_16 and DYNAMIC_TILING:\n _general_matmul(\n pid_n,\n start_off, end_off,\n input, other, output,\n K, N,\n stride_input_m, stride_input_k,\n stride_other_k, stride_other_n,\n stride_output_m, stride_output_n,\n out_dtype=out_dtype,\n MASK_M=True,\n EVEN_K=EVEN_K,\n EVEN_N=EVEN_N,\n TILE_M=TILE_M_16,\n TILE_N=TILE_N,\n TILE_K=TILE_K\n )\n elif end_off - start_off <= TILE_M_32 and DYNAMIC_TILING:\n _general_matmul(\n pid_n,\n start_off, end_off,\n input, other, output,\n K, N,\n stride_input_m, stride_input_k,\n stride_other_k, stride_other_n,\n stride_output_m, stride_output_n,\n out_dtype=out_dtype,\n EVEN_K=EVEN_K,\n EVEN_N=EVEN_N,\n MASK_M=True,\n TILE_M=TILE_M_32,\n TILE_N=TILE_N,\n TILE_K=TILE_K\n )\n elif end_off - start_off <= TILE_M_64 and DYNAMIC_TILING:\n _general_matmul(\n pid_n,\n start_off, end_off,\n input, other, output,\n K, N,\n stride_input_m, stride_input_k,\n stride_other_k, stride_other_n,\n stride_output_m, stride_output_n,\n out_dtype=out_dtype,\n MASK_M=True,\n EVEN_K=EVEN_K,\n EVEN_N=EVEN_N,\n TILE_M=TILE_M_64,\n TILE_N=TILE_N,\n TILE_K=TILE_K\n )\n else:\n _general_matmul(\n pid_n,\n start_off, end_off,\n input, other, output,\n K, N,\n stride_input_m, stride_input_k,\n stride_other_k, stride_other_n,\n stride_output_m, stride_output_n,\n out_dtype=out_dtype,\n MASK_M=MASK_M,\n EVEN_K=EVEN_K,\n EVEN_N=EVEN_N,\n TILE_M=TILE_M,\n TILE_N=TILE_N,\n TILE_K=TILE_K\n )\n\n\n@triton.jit\ndef _noncontiguous_block(\n input_tiles,\n next_id, next_next_id, pid_n,\n input, other, output,\n K, N,\n stride_input_m, stride_input_k,\n stride_other_b, stride_other_k, stride_other_n,\n stride_output_m, stride_output_n,\n out_dtype: tl.constexpr,\n BLOCK_SIZE: tl.constexpr,\n NUM_TILES: tl.constexpr,\n TILE_M: tl.constexpr,\n TILE_N: tl.constexpr,\n TILE_K: tl.constexpr,\n EVEN_K: tl.constexpr,\n EVEN_N: tl.constexpr\n):\n for _ in range(0, BLOCK_SIZE):\n if next_id < NUM_TILES and next_id != -1:\n start_off = tl.load(input_tiles + 5 * next_id + 2)\n end_off = tl.load(input_tiles + 5 * next_id + 3)\n length = end_off - start_off\n\n if length > 0:\n type_id = tl.load(input_tiles + 5 * next_id + 1)\n for i in range(0, tl.cdiv(length, TILE_M)):\n _dispatch(\n pid_n,\n start_off + i * TILE_M, end_off,\n input, other + type_id * stride_other_b, output,\n K, N,\n stride_input_m, stride_input_k,\n stride_other_k, stride_other_n,\n stride_output_m, stride_output_n,\n out_dtype=out_dtype,\n MASK_M=True,\n EVEN_K=EVEN_K,\n EVEN_N=EVEN_N,\n TILE_M=TILE_M,\n TILE_N=TILE_N,\n TILE_K=TILE_K,\n DYNAMIC_TILING=True,\n )\n next_id = next_next_id\n next_next_id += 1\n\n\n@triton.jit\ndef _contiguous_block(\n input_tiles,\n next_id, pid_n,\n input, other, output,\n K, N,\n stride_input_m, stride_input_k,\n stride_other_b, stride_other_k, stride_other_n,\n stride_output_m, stride_output_n,\n out_dtype: tl.constexpr,\n BLOCK_SIZE: tl.constexpr,\n TILE_M: tl.constexpr,\n TILE_N: tl.constexpr,\n TILE_K: tl.constexpr,\n EVEN_K: tl.constexpr,\n EVEN_N: tl.constexpr,\n EQUAL_K: tl.constexpr,\n):\n start_off = tl.load(input_tiles + 5 * next_id + 2)\n type_id = tl.load(input_tiles + 5 * next_id + 1)\n if EQUAL_K:\n _reg_matmul(\n pid_n, type_id,\n start_off,\n input, other, output, N,\n stride_input_m, stride_input_k,\n stride_other_b, stride_other_k, stride_other_n,\n stride_output_m, stride_output_n,\n out_dtype=out_dtype,\n BLOCK_SIZE=BLOCK_SIZE,\n TILE_M=TILE_M,\n TILE_N=TILE_N,\n TILE_K=TILE_K,\n EVEN_N=EVEN_N,\n )\n else:\n for i in range(0, BLOCK_SIZE):\n _general_matmul(\n pid_n,\n start_off + i * TILE_M,\n start_off + (i + 1) * TILE_M,\n input, other + type_id * stride_other_b, output,\n K, N,\n stride_input_m, stride_input_k,\n stride_other_k, stride_other_n,\n stride_output_m, stride_output_n,\n out_dtype=out_dtype,\n MASK_M=True,\n EVEN_K=EVEN_K,\n EVEN_N=EVEN_N,\n TILE_M=TILE_M,\n TILE_N=TILE_N,\n TILE_K=TILE_K\n )\n\n\ndef _early_config_prune(configs: triton.Config, named_args: dict, is_weight: bool, **kwargs):\n if not GlobalConfig.with_autotune:\n return [configs[0]]\n pruned_configs = []\n element_size = named_args['input'].element_size()\n N = named_args['N']\n K = named_args['K']\n TILE_SIZE_M = kwargs['TILE_SIZE_M']\n BLOCK_SIZE = kwargs['BLOCK_SIZE']\n device = torch.cuda.current_device()\n min_tile_size_n = min([config.kwargs['TILE_SIZE_N'] for config in configs])\n min_tile_size_k = min([config.kwargs['TILE_SIZE_K'] for config in configs])\n max_shared_memory = driver.active.utils.get_device_properties(device)[\"max_shared_mem\"]\n for config in configs:\n kw = config.kwargs\n TILE_SIZE_N = kw['TILE_SIZE_N']\n TILE_SIZE_K = kw['TILE_SIZE_K']\n if is_weight:\n if ((TILE_SIZE_K > K and TILE_SIZE_K != min_tile_size_k) or (TILE_SIZE_N > N and TILE_SIZE_N != min_tile_size_n)):\n continue\n required_shared_memory = (TILE_SIZE_K + TILE_SIZE_N) * TILE_SIZE_M * config.num_stages * element_size\n if required_shared_memory > max_shared_memory:\n continue\n if TILE_SIZE_K >= 256 and TILE_SIZE_N >= 256 and config.num_warps == 4:\n continue\n if config.num_stages - 1 > BLOCK_SIZE:\n continue\n else:\n if TILE_SIZE_N > N and TILE_SIZE_N != min_tile_size_n:\n continue\n required_shared_memory = (TILE_SIZE_M + TILE_SIZE_N) * TILE_SIZE_K * config.num_stages * element_size\n if required_shared_memory > max_shared_memory:\n continue\n if TILE_SIZE_N >= 256 and TILE_SIZE_K >= 256 and config.num_warps == 4:\n continue\n if TILE_SIZE_K != K and (TILE_SIZE_K * (config.num_stages - 1) > K or TILE_SIZE_K * (config.num_stages + 1) < K):\n continue\n if TILE_SIZE_K == K and K >= 128:\n continue\n pruned_configs.append(config)\n if len(pruned_configs) == 0:\n pruned_configs.append(configs[0])\n if is_debug():\n print(f\"Number of configs pruned from {len(configs)} to {len(pruned_configs)}, is_weight={is_weight}\")\n return pruned_configs\n\n\n@triton.autotune(\n configs=_generate_configs(),\n key=['N', 'K', 'stddev_tile_size_m', 'avg_tile_size_m'],\n prune_configs_by={\n 'early_config_prune': functools.partial(_early_config_prune, is_weight=False),\n },\n rep=10,\n use_cuda_graph=True,\n)\n@triton.heuristics({\n 'EVEN_K': lambda args: args['K'] % args['TILE_SIZE_K'] == 0,\n 'EVEN_N': lambda args: args['N'] % args['TILE_SIZE_N'] == 0,\n 'EQUAL_K': lambda args: args['K'] == args['TILE_SIZE_K']\n})\n@triton.jit\ndef segment_matmul_kernel(\n input, input_tiles, other, output,\n K, N,\n stride_input_m, stride_input_k,\n stride_other_b, stride_other_k, stride_other_n,\n stride_output_m, stride_output_n,\n stddev_tile_size_m,\n avg_tile_size_m,\n out_dtype: tl.constexpr,\n NUM_TILES: tl.constexpr,\n NUM_BLOCKS: tl.constexpr,\n BLOCK_SIZE: tl.constexpr,\n TILE_SIZE_M: tl.constexpr,\n EVEN_K: tl.constexpr,\n EVEN_N: tl.constexpr,\n EQUAL_K: tl.constexpr,\n TILE_SIZE_N: tl.constexpr,\n TILE_SIZE_K: tl.constexpr\n):\n TILE_N: tl.constexpr = TILE_SIZE_N\n TILE_K: tl.constexpr = TILE_SIZE_K\n TILE_M: tl.constexpr = TILE_SIZE_M\n\n GROUP_M: tl.constexpr = 4\n\n pid = tl.program_id(axis=0)\n grid_n = tl.cdiv(N, TILE_N)\n width = GROUP_M * grid_n\n group_id = pid // width\n group_size = min(NUM_BLOCKS - group_id * GROUP_M, GROUP_M)\n pid_m = group_id * GROUP_M + (pid % group_size)\n pid_n = (pid % width) // (group_size)\n\n next_id = pid_m\n next_next_id = tl.load(input_tiles + 5 * next_id + 4)\n if next_next_id == 0:\n _contiguous_block(\n input_tiles,\n next_id, pid_n,\n input, other, output,\n K, N,\n stride_input_m, stride_input_k,\n stride_other_b, stride_other_k, stride_other_n,\n stride_output_m, stride_output_n,\n out_dtype=out_dtype,\n BLOCK_SIZE=BLOCK_SIZE,\n TILE_M=TILE_M,\n TILE_N=TILE_N,\n TILE_K=TILE_K,\n EVEN_K=EVEN_K,\n EVEN_N=EVEN_N,\n EQUAL_K=EQUAL_K,\n )\n else:\n _noncontiguous_block(\n input_tiles,\n next_id, next_next_id, pid_n,\n input, other, output,\n K, N,\n stride_input_m, stride_input_k,\n stride_other_b, stride_other_k, stride_other_n,\n stride_output_m, stride_output_n,\n out_dtype=out_dtype,\n BLOCK_SIZE=BLOCK_SIZE,\n NUM_TILES=NUM_TILES,\n TILE_M=TILE_M,\n TILE_N=TILE_N,\n TILE_K=TILE_K,\n EVEN_K=EVEN_K,\n EVEN_N=EVEN_N)\n\n\ndef segment_matmul_forward(input: torch.Tensor, other: torch.Tensor,\n input_tiles: torch.Tensor, input_slices: torch.Tensor,\n output: torch.Tensor = None,\n num_blocks: Optional[int] = None, block_size: int = 1,\n tile_size: int = 64, out_dtype: Optional[torch.dtype] = None,\n avg_tile_size: Optional[float] = None, stddev_tile_size: Optional[float] = None, **kwargs):\n assert input.size(1) == other.size(1)\n assert input_tiles.device == input_slices.device == input.device == other.device\n assert input.dim() == 2\n assert other.dim() == 3\n M: int = input.size(0)\n K: int = input.size(1)\n N: int = other.size(2)\n num_tiles = input_tiles.size(0)\n num_blocks = num_blocks or num_tiles\n if output is None:\n output = torch.empty(M, N, dtype=input.dtype, device=input.device)\n\n def grid(meta):\n return (num_blocks * triton.cdiv(N, meta['TILE_SIZE_N']),)\n\n out_dtype = torch_dtype_to_triton_dtype(out_dtype or input.dtype)\n segment_matmul_kernel[grid](\n input, input_tiles, other, output,\n K, N,\n input.stride(0), input.stride(1),\n other.stride(0), other.stride(1), other.stride(2),\n output.stride(0), output.stride(1),\n binning(stddev_tile_size, 32),\n binning(avg_tile_size, 16),\n NUM_TILES=num_tiles,\n NUM_BLOCKS=num_blocks,\n BLOCK_SIZE=block_size,\n out_dtype=out_dtype,\n TILE_SIZE_M=tile_size,\n )\n return output\n", - "description_1": "Use triton language to define and compile multiple Triton kernels. These kernels perform operations on tiled matrix multiplication, handling both contiguous and non-contiguous blocks. The main operations include dispatching tiles for multiplication, managing matrix strides, and handling edge cases for performance optimization. Input, output, and intermediate matrices are managed using specific tensor strides, and kernel performance is optimized through autotuning configurations such as tile size and number of warps.", - "description_2": "Use triton language to implement a matrix multiplication kernel optimized for different tile sizes and input configurations. Utilize autotuning to select the best configuration based on performance heuristics.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n# Triton kernel that computes (t * m) % P and stores the result in y_ptr\n@triton.jit\ndef add_kernel(\n y_ptr, # Pointer to output array in GPU memory\n n, # Size of the block (unused in kernel)\n t, # Multiplier for m\n m, # Large constant to be multiplied by t\n P, # Modulus\n BLOCK_SIZE: tl.constexpr # Size of each block\n):\n offsets = tl.arange(0, BLOCK_SIZE)\n tl.store(y_ptr + offsets, (t * m) % P)\n\n# Constants\nBLOCK_SIZE = 128\nP = 2038074743\nm = 4096 * 4096\n\n# Output tensor\ny = torch.zeros((BLOCK_SIZE,), device='cuda', dtype=torch.long)\n\n# Calculate (t * m) % P for t=1023\nt = 1023\nprint('Python: {} % {} = {}'.format(t * m, P, (t * m) % P))\nadd_kernel[(1,)](y, BLOCK_SIZE, t, m, P, BLOCK_SIZE)\nprint('Triton: {}'.format(y[0].item()))\n\n# Calculate (t * m) % P for t=3\nt = 3\nprint('Python: {} % {} = {}'.format(t * m, P, (t * m) % P))\nadd_kernel[(1,)](y, BLOCK_SIZE, t, m, P, BLOCK_SIZE)\nprint('Triton: {}'.format(y[0].item()))\n", - "description_1": "Use triton language to define a kernel that computes the modulo operation (t * m) % P for a given set of parameters and stores the result in a GPU tensor. The kernel function takes a pointer to the output tensor (y_ptr), block size (n), a multiplier (t), a large constant (m), a modulus (P), and a block size constant (BLOCK_SIZE). It computes the result of (t * m) % P using Triton's range of offsets and stores it in the provided tensor memory. The kernel is launched with a single block configuration to compute the result for different values of t.", - "description_2": "Use triton language to compute (t * m) % P and store the result in a GPU tensor by creating a kernel that executes the operation in a single block.", - "difficulty": 2 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef dot_kernel(x_ptr, y_ptr, z_ptr, BLOCK_SIZE: tl.constexpr):\n # Triton kernel for matrix multiplication\n r = tl.program_id(0) * BLOCK_SIZE\n c = tl.program_id(1) * BLOCK_SIZE\n b = tl.program_id(2)\n bid = b * 4 * BLOCK_SIZE * BLOCK_SIZE\n x_val = tl.load(\n x_ptr\n + bid\n + (r + tl.arange(0, BLOCK_SIZE)[:, None]) * 2 * BLOCK_SIZE\n + tl.arange(0, BLOCK_SIZE)[None, :]\n )\n y_val = tl.load(\n y_ptr\n + bid\n + tl.arange(0, BLOCK_SIZE)[:, None] * 2 * BLOCK_SIZE\n + tl.arange(0, BLOCK_SIZE)[None, :]\n + c\n )\n z = tl.dot(x_val, y_val)\n x_val = tl.load(\n x_ptr\n + bid\n + (r + tl.arange(0, BLOCK_SIZE)[:, None]) * 2 * BLOCK_SIZE\n + tl.arange(0, BLOCK_SIZE)[None, :]\n + BLOCK_SIZE\n )\n y_val = tl.load(\n y_ptr\n + bid\n + (BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)[:, None]) * 2 * BLOCK_SIZE\n + tl.arange(0, BLOCK_SIZE)[None, :]\n + c\n )\n z = z + tl.dot(x_val, y_val)\n tl.store(\n z_ptr\n + (b * (2 * BLOCK_SIZE) * (2 * BLOCK_SIZE - 10))\n + (r + tl.arange(0, BLOCK_SIZE)[:, None]) * (2 * BLOCK_SIZE - 10)\n + tl.arange(0, BLOCK_SIZE)[None, :]\n + c,\n z,\n mask=tl.arange(0, BLOCK_SIZE)[None, :] + c < 2 * BLOCK_SIZE - 10,\n )\n\ndef perform_dot(device, BLOCK_SIZE):\n # Function to set up matrices and invoke the Triton kernel\n x = torch.randn((2 * BLOCK_SIZE, 2 * BLOCK_SIZE), device=device)\n y = torch.randn((2 * BLOCK_SIZE, 2 * BLOCK_SIZE), device=device)\n z = torch.zeros((2 * BLOCK_SIZE, 2 * BLOCK_SIZE - 10), device=device)\n dot_kernel[(2, 2)](x, y, z, BLOCK_SIZE)\n return x, y, z\n", - "description_1": "Use triton language to perform matrix multiplication with a kernel called dot_kernel. This kernel has 4 parameters: x_ptr (input matrix 1), y_ptr (input matrix 2), z_ptr (output matrix), and BLOCK_SIZE (the size of each block in the matrix multiplication). The perform_dot function prepares data and launches this kernel.", - "description_2": "Use triton language to implement a matrix multiplication kernel and execute it using torch tensors.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef sum_kernel(\n x_ptr,\n y_ptr,\n STRIDE: tl.constexpr,\n CHANNEL_SIZE: tl.constexpr,\n BLOCK_SIZE: tl.constexpr,\n):\n # Load a block of data from x_ptr\n x_val = tl.load(\n x_ptr\n + tl.arange(0, BLOCK_SIZE)[:, None] * STRIDE\n + tl.arange(0, CHANNEL_SIZE)[None, :]\n )\n # Sum the loaded data along axis 1\n x_sum = tl.sum(x_val, axis=1)\n # Store the result in y_ptr\n tl.store(y_ptr + tl.arange(0, BLOCK_SIZE), x_sum)\n\ndef perform_sum(device, BLOCK_SIZE, CHANNEL_SIZE):\n # Initialize input and output tensors\n x = torch.ones((BLOCK_SIZE, CHANNEL_SIZE), device=device, dtype=torch.long)\n y = torch.zeros((BLOCK_SIZE), device=device, dtype=torch.long)\n # Launch the Triton kernel\n sum_kernel[(1,)](x, y, CHANNEL_SIZE, CHANNEL_SIZE, BLOCK_SIZE)\n return x, y\n", - "description_1": "Use triton language to define a kernel 'sum_kernel' that takes 5 parameters: x_ptr (pointer to input tensor), y_ptr (pointer to output tensor), STRIDE (constant stride value), CHANNEL_SIZE (constant channel size), and BLOCK_SIZE (constant block size). The kernel loads a block of data from x_ptr, sums it along axis 1, and stores the result in y_ptr. The 'perform_sum' function initializes input and output tensors and launches the 'sum_kernel' with the specified parameters.", - "description_2": "Use triton language to create a kernel that sums blocks of data from an input tensor and stores the results in an output tensor.", - "difficulty": 2 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef add_kernel(\n x_ptr, # *Pointer* to first input vector.\n y_ptr, # *Pointer* to second input vector.\n output_ptr, # *Pointer* to output vector.\n n_elements, # Size of the vector.\n BLOCK_SIZE: tl.constexpr, # Number of elements each program should process.\n):\n pid = tl.program_id(axis=0) # We use a 1D launch grid so axis is 0.\n block_start = pid * BLOCK_SIZE\n offsets = block_start + tl.arange(0, BLOCK_SIZE)\n mask = offsets < n_elements\n x = tl.load(x_ptr + offsets, mask=mask)\n y = tl.load(y_ptr + offsets, mask=mask)\n output = tl.zeros(x.shape, dtype=x.dtype)\n output = output + x + y\n tl.store(output_ptr + offsets, output, mask=mask)\n\ndef add(x: torch.Tensor, y: torch.Tensor, access_size: int, BLOCK_SIZE: int = 1024):\n output = torch.empty_like(x)\n grid = lambda meta: (triton.cdiv(access_size, meta[\"BLOCK_SIZE\"]),)\n add_kernel[grid](x, y, output, access_size, BLOCK_SIZE=BLOCK_SIZE)\n return output, grid\n\ndef perform_vec_add(device, size, access_size=None):\n torch.manual_seed(0)\n x = torch.rand(size, device=device)\n y = torch.rand(size, device=device)\n access_size = size if access_size is None else access_size\n output, _ = add(x, y, access_size=access_size)\n return x, y, output\n", - "description_1": "Use triton language to create a kernel 'add_kernel' with parameters: (x_ptr, y_ptr, output_ptr, n_elements, BLOCK_SIZE). The kernel adds elements of two input vectors and writes the result to an output vector using block-based processing. A mask guards against out-of-bounds memory accesses. The 'add' function launches this kernel on a 1D grid, determining the number of parallel instances based on input size and block size.", - "description_2": "Use triton language to develop a vector addition kernel with masked memory access. Implement the kernel launch using a grid size derived from input dimensions.", - "difficulty": 2 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef compute_attn_lp_loss_kernel(\n q, q_stride_n, q_stride_h, q_stride_t, q_stride_hdim,\n k, k_stride_n, k_stride_h, k_stride_t, k_stride_hdim,\n p: float,\n H: int, TDST: int, TSRC: int, HDIM: int,\n HDIM_MAX: tl.constexpr,\n KV_BLOCK_SIZE: tl.constexpr, Q_BLOCK_SIZE: tl.constexpr,\n IS_CAUSAL: tl.constexpr,\n attend_lengths, attend_lengths_stride_n, attend_lengths_stride_t,\n l, l_stride_n, l_stride_h, l_stride_t,\n m, m_stride_n, m_stride_h, m_stride_t,\n output, output_stride_n, output_stride_h, output_stride_t,\n):\n # Triton kernel for computing attention Lp loss\n batch_idx = tl.program_id(1)\n n_idx = batch_idx // H\n h_idx = batch_idx % H\n q_begin = tl.program_id(0) * Q_BLOCK_SIZE\n q_idx = tl.arange(0, Q_BLOCK_SIZE)\n kv_idx = tl.arange(0, KV_BLOCK_SIZE)\n d_idx = tl.arange(0, HDIM_MAX)\n\n q_chunk = tl.load(\n q +\n n_idx * q_stride_n +\n h_idx * q_stride_h +\n (q_begin + q_idx)[:, None] * q_stride_t +\n d_idx[None, :] * q_stride_hdim,\n mask=(\n (q_begin + q_idx < TDST)[:, None] &\n (d_idx < HDIM)[None, :]\n ),\n other=0\n ) # [q_blk, hd]\n\n attend_lengths_chunk = None\n if attend_lengths is not None:\n attend_lengths_chunk = tl.load(\n attend_lengths +\n n_idx * attend_lengths_stride_n +\n (q_begin + q_idx) * attend_lengths_stride_t,\n mask=(q_begin + q_idx < TDST),\n other=0\n ) # [q_blk]\n\n for kv_begin in range(0, TSRC, KV_BLOCK_SIZE):\n k_chunk = tl.load(\n k +\n n_idx * k_stride_n +\n h_idx * k_stride_h +\n (kv_begin + kv_idx)[None, :] * k_stride_t +\n d_idx[:, None] * k_stride_hdim,\n mask=(\n (kv_begin + kv_idx < TSRC)[None, :] &\n (d_idx < HDIM)[:, None]\n ),\n other=0\n ) # [hd, kv_blk]\n output_chunk = tl.load(\n output +\n n_idx * output_stride_n +\n h_idx * output_stride_h +\n (q_begin + q_idx)[:, None] * output_stride_t,\n mask=(\n (q_begin + q_idx < TDST)[:, None]\n ),\n other=0\n ) # [q_blk, 1]\n l_chunk = tl.load(\n l +\n n_idx * l_stride_n +\n h_idx * l_stride_h +\n (q_begin + q_idx)[:, None] * l_stride_t,\n mask=(\n (q_begin + q_idx < TDST)[:, None]\n ),\n other=0\n ) # [q_blk, 1]\n m_chunk = tl.load(\n m +\n n_idx * m_stride_n +\n h_idx * m_stride_h +\n (q_begin + q_idx)[:, None] * m_stride_t,\n mask=(\n (q_begin + q_idx < TDST)[:, None]\n ),\n other=-1e9\n ) # [q_blk, 1]\n attn_scores = tl.dot(q_chunk.to(tl.float16), k_chunk.to(tl.float16)).to(tl.float32) # [q_blk, kv_blk]\n if IS_CAUSAL:\n attn_scores = tl.where(\n (kv_begin + kv_idx)[None, :] > (q_begin + q_idx)[:, None],\n -1e9,\n attn_scores\n )\n if attend_lengths is not None:\n attn_scores = tl.where(\n (kv_begin + kv_idx)[None, :] >= attend_lengths_chunk[:, None],\n -1e9,\n attn_scores\n )\n m_tilde = tl.max(attn_scores, axis=1)[:, None] # [q_blk, 1]\n P_tilde = tl.exp(attn_scores - m_tilde) # [q_blk, kv_blk]\n l_tilde = tl.sum(P_tilde, axis=1)[:, None] # [q_blk, 1]\n m_new = tl.maximum(m_chunk, m_tilde) # [q_blk, 1]\n l_new = (\n tl.exp(m_chunk - m_new) * l_chunk +\n tl.exp(m_tilde - m_new) * l_tilde\n ) # [q_blk, 1]\n\n loss_new = tl.exp(tl.log(l_new) * -p) * (\n tl.exp(p * (tl.log(l_chunk) + m_chunk - m_new)) * output_chunk +\n tl.exp(p * (m_tilde - m_new)) * tl.sum(tl.exp((attn_scores - m_tilde) * p), axis=1)[:, None]\n ) # [q_blk, 1]\n tl.store(\n output +\n n_idx * output_stride_n +\n h_idx * output_stride_h +\n (q_begin + q_idx)[:, None] * output_stride_t,\n loss_new,\n mask=(q_begin + q_idx < TDST)[:, None]\n )\n tl.store(\n m +\n n_idx * m_stride_n +\n h_idx * m_stride_h +\n (q_begin + q_idx)[:, None] * m_stride_t,\n m_new,\n mask=(q_begin + q_idx < TDST)[:, None]\n )\n tl.store(\n l +\n n_idx * l_stride_n +\n h_idx * l_stride_h +\n (q_begin + q_idx)[:, None] * l_stride_t,\n l_new,\n mask=(q_begin + q_idx < TDST)[:, None]\n )\n\n\n@triton.jit\ndef compute_attn_lp_loss_kernel_backward(\n q, q_stride_n, q_stride_h, q_stride_t, q_stride_hdim,\n k, k_stride_n, k_stride_h, k_stride_t, k_stride_hdim,\n output, output_stride_n, output_stride_h, output_stride_t,\n grad_output, grad_output_stride_n, grad_output_stride_h, grad_output_stride_t,\n p: float,\n H: int, TDST: int, TSRC: int, HDIM: int,\n HDIM_MAX: tl.constexpr,\n KV_BLOCK_SIZE: tl.constexpr, Q_BLOCK_SIZE: tl.constexpr,\n IS_CAUSAL: tl.constexpr,\n attend_lengths, attend_lengths_stride_n, attend_lengths_stride_t,\n l, l_stride_n, l_stride_h, l_stride_t,\n m, m_stride_n, m_stride_h, m_stride_t,\n grad_q, grad_q_stride_n, grad_q_stride_h, grad_q_stride_t, grad_q_stride_hdim,\n grad_k, grad_k_stride_n, grad_k_stride_h, grad_k_stride_t, grad_k_stride_hdim,\n):\n # Triton kernel for computing the backward pass of attention Lp loss\n batch_idx = tl.program_id(1)\n n_idx = batch_idx // H\n h_idx = batch_idx % H\n q_begin = tl.program_id(0) * Q_BLOCK_SIZE\n q_idx = tl.arange(0, Q_BLOCK_SIZE)\n kv_idx = tl.arange(0, KV_BLOCK_SIZE)\n d_idx = tl.arange(0, HDIM_MAX)\n\n q_chunk = tl.load(\n q +\n n_idx * q_stride_n +\n h_idx * q_stride_h +\n (q_begin + q_idx)[None, :] * q_stride_t +\n d_idx[:, None] * q_stride_hdim,\n mask=(\n (q_begin + q_idx < TDST)[None, :] &\n (d_idx < HDIM)[:, None]\n ),\n other=0\n ).to(tl.float32) # [hd, q_blk]\n output_chunk = tl.load(\n output +\n n_idx * output_stride_n +\n h_idx * output_stride_h +\n (q_begin + q_idx)[None, :] * output_stride_t,\n mask=(\n (q_begin + q_idx < TDST)[None, :]\n ),\n other=0\n ).to(tl.float32) # [1, q_blk]\n grad_output_chunk = tl.load(\n grad_output +\n n_idx * grad_output_stride_n +\n h_idx * grad_output_stride_h +\n (q_begin + q_idx)[None, :] * grad_output_stride_t,\n mask=(\n (q_begin + q_idx < TDST)[None, :]\n ),\n other=0\n ).to(tl.float32) # [1, q_blk]\n l_chunk = tl.load(\n l +\n n_idx * l_stride_n +\n h_idx * l_stride_h +\n (q_begin + q_idx)[None, :] * l_stride_t,\n mask=(\n (q_begin + q_idx < TDST)[None, :]\n ),\n other=0\n ).to(tl.float32) # [1, q_blk]\n m_chunk = tl.load(\n m +\n n_idx * m_stride_n +\n h_idx * m_stride_h +\n (q_begin + q_idx)[None, :] * m_stride_t,\n mask=(\n (q_begin + q_idx < TDST)[None, :]\n ),\n other=-1e9\n ).to(tl.float32) # [1, q_blk]\n\n attend_lengths_chunk = None\n if attend_lengths is not None:\n attend_lengths_chunk = tl.load(\n attend_lengths +\n n_idx * attend_lengths_stride_n +\n (q_begin + q_idx) * attend_lengths_stride_t,\n mask=(q_begin + q_idx < TDST),\n other=0\n )\n\n for kv_begin in range(0, TSRC, KV_BLOCK_SIZE):\n k_chunk = tl.load(\n k +\n n_idx * k_stride_n +\n h_idx * k_stride_h +\n (kv_begin + kv_idx)[:, None] * k_stride_t +\n d_idx[None, :] * k_stride_hdim,\n mask=(\n (kv_begin + kv_idx < TSRC)[:, None] &\n (d_idx < HDIM)[None, :]\n ),\n other=0\n ).to(tl.float32) # [kv_blk, hd]\n\n attn_scores = tl.dot(k_chunk, q_chunk).to(tl.float32) # [kv_blk, q_blk]\n logP = attn_scores - m_chunk - tl.log(l_chunk) # [kv_blk, q_blk]\n grad_P = grad_output_chunk * p * tl.exp(logP * (p-1)) # [kv_blk, q_blk]\n\n D = grad_output_chunk * p * output_chunk # [1, q_blk]\n grad_S = tl.exp(logP) * (grad_P - D) # [kv_blk, q_blk]\n if IS_CAUSAL:\n grad_S = tl.where(\n (kv_begin + kv_idx)[:, None] > (q_begin + q_idx)[None, :],\n 0.0,\n grad_S\n )\n if attend_lengths is not None:\n grad_S = tl.where(\n (kv_begin + kv_idx)[:, None] >= attend_lengths_chunk[None, :],\n 0.0,\n grad_S\n )\n\n grad_q_new = tl.dot(tl.trans(grad_S), k_chunk).to(tl.float32) # [q_blk, hd]\n tl.atomic_add(\n grad_q +\n n_idx * grad_q_stride_n +\n h_idx * grad_q_stride_h +\n (q_begin + q_idx)[:, None] * grad_q_stride_t +\n d_idx[None, :] * grad_q_stride_hdim,\n grad_q_new,\n mask=(\n (q_begin + q_idx < TDST)[:, None] &\n (d_idx < HDIM)[None, :]\n )\n )\n grad_k_chunk = tl.dot(q_chunk, tl.trans(grad_S)).to(tl.float32) # [hd, kv_blk]\n tl.atomic_add(\n grad_k +\n n_idx * grad_k_stride_n +\n h_idx * grad_k_stride_h +\n (kv_begin + kv_idx)[None, :] * grad_k_stride_t +\n d_idx[:, None] * grad_k_stride_hdim,\n grad_k_chunk,\n mask=(\n (kv_begin + kv_idx < TSRC)[None, :] &\n (d_idx < HDIM)[:, None]\n ),\n )\n\n\nclass AttnLpLoss(torch.autograd.Function):\n\n @staticmethod\n def forward(ctx, # noqa\n q, k, N, H, TDST, TSRC, HDIM, p, is_causal, attend_lengths,\n KV_BLOCK_SIZE, Q_BLOCK_SIZE):\n # Forward pass for attention Lp loss\n assert q.ndim == 4\n assert k.ndim == 4\n assert attend_lengths.ndim == 2 if attend_lengths is not None else True\n l = torch.full((N, H, TDST), 0.0, device=q.device) # [bsz, num_heads, q_len]\n m = torch.full((N, H, TDST), -1e9, device=q.device) # [bsz, num_heads, q_len]\n result = torch.zeros((N, H, TDST), device=q.device)\n\n orig_device = torch.cuda.current_device()\n torch.cuda.set_device(q.device)\n compute_attn_lp_loss_kernel[(triton.cdiv(TDST, Q_BLOCK_SIZE), N * H)](\n q, *q.stride(),\n k, *k.stride(),\n p,\n H, TDST, TSRC, HDIM,\n triton.next_power_of_2(HDIM),\n KV_BLOCK_SIZE, Q_BLOCK_SIZE,\n is_causal,\n attend_lengths, *(attend_lengths.stride() if attend_lengths is not None else (None, None)),\n l, *l.stride(),\n m, *m.stride(),\n result, *result.stride(),\n )\n torch.cuda.set_device(orig_device)\n\n if attend_lengths is not None:\n ctx.save_for_backward(q, k, l, m, result, attend_lengths)\n else:\n ctx.save_for_backward(q, k, l, m, result)\n ctx.has_attend_lengths = attend_lengths is not None\n ctx.N, ctx.H, ctx.TDST, ctx.TSRC, ctx.HDIM = N, H, TDST, TSRC, HDIM\n ctx.p, ctx.is_causal = p, is_causal\n ctx.KV_BLOCK_SIZE, ctx.Q_BLOCK_SIZE = KV_BLOCK_SIZE, Q_BLOCK_SIZE\n\n result = result ** (1/p)\n return result\n\n @staticmethod\n def backward(ctx, grad_output: torch.Tensor): # noqa\n # Backward pass for attention Lp loss\n if ctx.has_attend_lengths:\n q, k, l, m, result, attend_lengths = ctx.saved_tensors\n else:\n q, k, l, m, result = ctx.saved_tensors\n attend_lengths = None\n N, H, TDST, TSRC, HDIM = ctx.N, ctx.H, ctx.TDST, ctx.TSRC, ctx.HDIM\n p, is_causal = ctx.p, ctx.is_causal\n KV_BLOCK_SIZE, Q_BLOCK_SIZE = ctx.KV_BLOCK_SIZE, ctx.Q_BLOCK_SIZE\n\n grad_output *= ((1/p) * result**(1/p - 1))\n\n grad_q = torch.full((N, H, TDST, HDIM), 0.0, device=q.device)\n grad_k = torch.full((N, H, TSRC, HDIM), 0.0, device=q.device)\n\n orig_device = torch.cuda.current_device()\n torch.cuda.set_device(q.device)\n compute_attn_lp_loss_kernel_backward[(triton.cdiv(TDST, Q_BLOCK_SIZE), N * H)](\n q, *q.stride(),\n k, *k.stride(),\n result, *result.stride(),\n grad_output, *grad_output.stride(),\n p,\n H, TDST, TSRC, HDIM,\n triton.next_power_of_2(HDIM),\n KV_BLOCK_SIZE, Q_BLOCK_SIZE,\n is_causal,\n attend_lengths, *(attend_lengths.stride() if attend_lengths is not None else (None, None)),\n l, *l.stride(),\n m, *m.stride(),\n grad_q, *grad_q.stride(),\n grad_k, *grad_k.stride(),\n )\n torch.cuda.set_device(orig_device)\n\n return (\n grad_q, grad_k,\n None, None, None, None, None, None, None, None,\n None, None\n )\n\n\ndef compute_attn_lp_loss_triton(q, k, p, is_causal=True, attend_lengths=None, do_average=True,\n KV_BLOCK_SIZE=64, Q_BLOCK_SIZE=64):\n # Wrapper function to compute attention Lp loss using Triton\n assert q.ndim == 4 and k.ndim == 4\n N, H, TDST, TSRC, HDIM = q.shape[0], q.shape[1], q.shape[2], k.shape[2], q.shape[3]\n result = AttnLpLoss.apply(\n q, k, N, H, TDST, TSRC, HDIM, p, is_causal, attend_lengths, KV_BLOCK_SIZE, Q_BLOCK_SIZE)\n if do_average:\n result = result.mean(dim=-1) # [bsz, num_heads]\n return result\n", - "description_1": "Use triton language to implement a kernel for computing attention Lp loss and its backward pass. The kernel takes in query and key tensors, strides, a float parameter p, dimensions H, TDST, TSRC, HDIM, and several constexpr parameters. It computes the attention scores, applies causal masking if needed, and calculates the Lp loss. The backward kernel computes gradients for the query and key tensors.", - "description_2": "Use triton language to create a kernel for attention Lp loss computation and its gradient calculation. The kernel handles causal masking and computes the loss and gradients based on input tensors and parameters.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\nfrom torch import Tensor\nfrom torch.autograd import Function\nfrom typing import Union\n\n@triton.jit\ndef _sdbmm_compute(\n INDICES, stride_indices_n, stride_indices_bdst, stride_indices_bk,\n KS, stride_ks_n, stride_ks_bdst, \n PROBS, stride_probs_n, stride_probs_tdst, stride_probs_k,\n VALUES, stride_values_n, stride_values_tsrc, stride_values_hid,\n CONTEXT, stride_context_n, stride_context_tdst, stride_context_hid,\n KV_REPEAT_INTERLEAVE, N, TSRC, TDST, HID, K, BK, BSRC, BDST,\n stride_values_vllm_num_blocks,\n stride_values_vllm_num_kv_heads,\n stride_values_vllm_head_size,\n stride_values_vllm_block_size,\n VLLM_NUM_BLOCKS,\n VLLM_NUM_KV_HEADS,\n VLLM_HEAD_SIZE,\n VLLM_BLOCK_SIZE,\n BLOCK_TABLES,\n stride_block_tables_num_seqs,\n stride_block_tables_max_num_blocks_per_seq,\n VALUE_CACHE_METHOD: tl.constexpr,\n BLOCK_SIZE_Q: tl.constexpr,\n BLOCK_SIZE_Q_PADDED: tl.constexpr,\n BLOCK_SIZE_K: tl.constexpr,\n BLOCK_SIZE_K_PADDED: tl.constexpr,\n BLOCK_HID: tl.constexpr,\n):\n idx_n = tl.program_id(0)\n idx_block_q = tl.arange(0, BLOCK_SIZE_Q_PADDED)\n mask_block_q = idx_block_q < BLOCK_SIZE_Q\n idx_block_k = tl.arange(0, BLOCK_SIZE_K_PADDED)\n mask_block_k = idx_block_k < BLOCK_SIZE_K\n idx_bdst = tl.program_id(1)\n idx_tdst = idx_bdst * BLOCK_SIZE_Q + idx_block_q\n mask_tdst = (idx_tdst < TDST) & mask_block_q\n pid_hid = tl.program_id(2)\n idx_hid = tl.arange(0, BLOCK_HID) + pid_hid * BLOCK_HID\n mask_hid = idx_hid < HID\n n_bk = tl.load(\n KS +\\\n idx_n * stride_ks_n+\\\n idx_bdst * stride_ks_bdst,\n )\n scores = tl.zeros((BLOCK_SIZE_Q_PADDED, BLOCK_HID), dtype=tl.float32)\n for idx_bk in range(BK):\n mask_bk = idx_bk < n_bk\n _idx_tsrc = tl.load(\n INDICES +\\\n idx_n * stride_indices_n +\\\n idx_bdst * stride_indices_bdst +\\\n idx_bk * stride_indices_bk,\n mask = mask_bk,\n ).to(tl.int64)\n idx_tsrc = _idx_tsrc + idx_block_k\n mask_tsrc = (idx_tsrc < TSRC) & mask_block_k & mask_bk\n idx_prob_k = (idx_bk * BLOCK_SIZE_K + idx_block_k)\n mask_prob_k = (idx_prob_k < K) & mask_block_k & mask_bk\n atten_probs = tl.load(\n PROBS +\\\n idx_n * stride_probs_n +\\\n idx_tdst[:, None] * stride_probs_tdst +\\\n idx_prob_k[None, :] * stride_probs_k,\n mask = \\\n mask_tdst[:, None] &\\\n mask_prob_k[None, :] &\\\n ((idx_tdst[:, None] + TSRC - TDST) >= idx_tsrc[None, :]) & \\\n mask_bk,\n other = 0,\n )\n if VALUE_CACHE_METHOD == 'cont':\n value = tl.load(\n VALUES +\\\n (idx_n // KV_REPEAT_INTERLEAVE).to(tl.int64) * stride_values_n +\\\n idx_tsrc[:, None].to(tl.int64) * stride_values_tsrc +\\\n idx_hid[None, :].to(tl.int64) * stride_values_hid,\n mask = mask_tsrc[:, None] & mask_hid[None, :] & mask_bk,\n other = 0,\n )\n elif VALUE_CACHE_METHOD == 'vllm':\n idx_batch = (idx_n // KV_REPEAT_INTERLEAVE) // VLLM_NUM_KV_HEADS\n idx_head = (idx_n // KV_REPEAT_INTERLEAVE) % VLLM_NUM_KV_HEADS\n idx_block = tl.load(\n BLOCK_TABLES +\\\n idx_batch * stride_block_tables_num_seqs +\\\n (idx_tsrc // VLLM_BLOCK_SIZE) * stride_block_tables_max_num_blocks_per_seq,\n mask = mask_tsrc & mask_bk,\n other = 0\n ).to(tl.int64)\n mask_block = (idx_tsrc // VLLM_BLOCK_SIZE) < tl.cdiv(TSRC, VLLM_BLOCK_SIZE)\n offset_block = idx_tsrc - ((idx_tsrc // VLLM_BLOCK_SIZE) * VLLM_BLOCK_SIZE)\n value = tl.load(\n VALUES +\\\n idx_block[:, None] * stride_values_vllm_num_blocks+\\\n idx_head * stride_values_vllm_num_kv_heads+\\\n idx_hid[None, :].to(tl.int64) * stride_values_vllm_head_size +\\\n offset_block[:, None] * stride_values_vllm_block_size,\n mask = mask_tsrc[:, None] & mask_hid[None, :] & mask_bk & mask_block[:, None],\n other = 0\n )\n else:\n raise Exception()\n if value.dtype == tl.uint8:\n value = value.to(tl.float8e5, bitcast=True).to(atten_probs.dtype)\n scores_mini = tl.dot(atten_probs, value)\n scores += scores_mini.to(scores.dtype)\n tl.store(\n CONTEXT +\\\n idx_n * stride_context_n +\\\n idx_tdst[:, None] * stride_context_tdst +\\\n idx_hid[None, :] * stride_context_hid,\n mask = mask_tdst[:, None] & mask_hid[None, :],\n value = scores\n )\n\nclass SparseAttentionAutoGradFn(Function):\n @staticmethod\n def forward(\n ctx, \n values: Union[Tensor, \"PagedValueCacheVllmCompat\"],\n indices: Tensor,\n ks: Tensor,\n probs: Tensor,\n KV_REPEAT_INTERLEAVE: int,\n BLOCK_SIZE_Q: int,\n BLOCK_SIZE_K: int,\n ):\n ctx.save_for_backward(values, indices, ks, probs)\n ctx.BLOCK_SIZE_Q = BLOCK_SIZE_Q\n ctx.BLOCK_SIZE_K = BLOCK_SIZE_K\n N, BDST, BK = indices.shape\n _N, TDST, K = probs.shape\n __N, TSRC, HID = values.shape\n assert N == _N\n assert N == (__N * KV_REPEAT_INTERLEAVE)\n assert ks.shape == (N, BDST)\n BSRC = triton.cdiv(TSRC, BLOCK_SIZE_K)\n context_dtype = values.dtype\n if context_dtype not in [torch.float16, torch.bfloat16, torch.float32]:\n context_dtype = probs.dtype\n assert context_dtype in [torch.float16, torch.bfloat16, torch.float32]\n context = torch.zeros((N, TDST, HID), dtype=context_dtype, device=values.device)\n BLOCK_SIZE_Q_PADDED = next_multiple_of(BLOCK_SIZE_Q, 16)\n BLOCK_SIZE_K_PADDED = next_multiple_of(BLOCK_SIZE_K, 16)\n BLOCK_HID = triton.next_power_of_2(HID)\n if isinstance(values, Tensor):\n VALUE_CACHE_METHOD = 'cont'\n block_tables = values\n block_tables_strides = (0, 0)\n VLLM_NUM_BLOCKS =\\\n VLLM_NUM_KV_HEADS =\\\n VLLM_HEAD_SIZE =\\\n VLLM_BLOCK_SIZE = 0\n vllm_values_strides = (0, 0, 0, 0)\n elif isinstance(values, PagedValueCacheVllmCompat):\n VALUE_CACHE_METHOD = 'vllm'\n block_tables = values.block_table\n block_tables_strides = block_tables.stride()\n (\n VLLM_NUM_BLOCKS,\n VLLM_NUM_KV_HEADS,\n VLLM_HEAD_SIZE,\n VLLM_BLOCK_SIZE\n ) = values.value_cache.shape\n vllm_values_strides = values.value_cache.stride()\n else:\n raise Exception()\n grid = (N, BDST, triton.cdiv(HID, BLOCK_HID))\n orig_device = torch.cuda.current_device()\n torch.cuda.set_device(indices.device)\n _sdbmm_compute[grid](\n indices, *indices.stride(),\n ks, *ks.stride(),\n probs, *probs.stride(),\n values, *values.stride(),\n context, *context.stride(),\n KV_REPEAT_INTERLEAVE, N, TSRC, TDST, HID, K, BK, BSRC, BDST,\n *vllm_values_strides,\n VLLM_NUM_BLOCKS,\n VLLM_NUM_KV_HEADS,\n VLLM_HEAD_SIZE,\n VLLM_BLOCK_SIZE,\n block_tables,\n *block_tables_strides,\n VALUE_CACHE_METHOD,\n BLOCK_SIZE_Q,\n BLOCK_SIZE_Q_PADDED,\n BLOCK_SIZE_K,\n BLOCK_SIZE_K_PADDED,\n BLOCK_HID,\n num_warps=BLOCK_HID//32,\n )\n torch.cuda.set_device(orig_device)\n return context\n \n @staticmethod\n def backward(ctx, grad_context):\n ENABLED_VALUES = True\n ENABLED_PROBS = True\n values, indices, ks, probs = ctx.saved_tensors\n BLOCK_SIZE_Q = ctx.BLOCK_SIZE_Q\n BLOCK_SIZE_K = ctx.BLOCK_SIZE_K\n grad_values = grad_probs = None\n N, T_SRC, HID = values.shape\n _, B_DST, BK = indices.shape\n _, T_DST, K = probs.shape\n assert ks.shape == (N, B_DST)\n assert probs.shape == (N, T_DST, K)\n assert indices.shape[0] == N\n if ctx.needs_input_grad[0]:\n grid = (N, B_DST, BK)\n BLOCK_HID = triton.next_power_of_2(HID)\n grad_values = torch.zeros(\n (N, T_SRC, HID), \n device=values.device, \n dtype=torch.float32,\n )\n if ENABLED_VALUES:\n orig_device = torch.cuda.current_device()\n torch.cuda.set_device(indices.device)\n _sdbmm_compute_bwd_values[grid](\n probs, probs.stride(0), probs.stride(1), probs.stride(2),\n indices, indices.stride(0), indices.stride(1), indices.stride(2),\n grad_context, grad_context.stride(0), grad_context.stride(1), grad_context.stride(2),\n grad_values, grad_values.stride(0), grad_values.stride(1), grad_values.stride(2),\n N, T_DST, T_SRC, HID, BK, K,\n BLOCK_SIZE_Q,\n next_multiple_of(BLOCK_SIZE_Q, 16),\n BLOCK_SIZE_K,\n next_multiple_of(BLOCK_SIZE_K, 16),\n BLOCK_HID,\n )\n torch.cuda.set_device(orig_device)\n grad_values = grad_values.to(values.dtype)\n if ctx.needs_input_grad[3]:\n grid = (N, triton.cdiv(T_DST, BLOCK_SIZE_Q), BK)\n BLOCK_HID = triton.next_power_of_2(HID)\n grad_probs = torch.zeros(\n (N, T_DST, K),\n device=probs.device,\n dtype=probs.dtype,\n )\n if ENABLED_PROBS:\n _sdbmm_compute_bwd_probs[grid](\n indices, indices.stride(0), indices.stride(1), indices.stride(2),\n values, values.stride(0), values.stride(1), values.stride(2), \n grad_context, grad_context.stride(0), grad_context.stride(1), grad_context.stride(2),\n grad_probs, grad_probs.stride(0), grad_probs.stride(1), grad_probs.stride(2),\n N, T_DST, T_SRC, HID, BK, K,\n BLOCK_SIZE_Q,\n next_multiple_of(BLOCK_SIZE_Q, 16),\n BLOCK_SIZE_K,\n next_multiple_of(BLOCK_SIZE_K, 16),\n BLOCK_HID,\n )\n return (\n grad_values, \n None, \n None, \n grad_probs, \n None,\n None,\n None,\n )\n\ndef sparse_attention(\n values: Tensor,\n indices: Tensor,\n ks: Tensor,\n probs: Tensor,\n KV_REPEAT_INTERLEAVE: int,\n BLOCK_SIZE_Q: int,\n BLOCK_SIZE_K: int,\n):\n context = SparseAttentionAutoGradFn.apply(\n values, indices, ks, probs, \n KV_REPEAT_INTERLEAVE, BLOCK_SIZE_Q, BLOCK_SIZE_K,\n )\n return context\n", - "description_1": "Use triton language to implement sparse matrix multiplication with a custom kernel function. Implement an autograd function in PyTorch for sparse attention that utilizes this custom kernel, supporting both forward and backward operations.", - "description_2": "Implement custom Triton kernels for efficient sparse matrix multiplication and integrate these kernels into PyTorch autograd for automatic differentiation.", - "difficulty": 4 - }, - { - "code": "import triton\nimport triton.language as tl\nimport torch\nfrom torch import Tensor\nimport math\nfrom typing import Optional, Union\n\n@triton.jit\ndef _calc_prob_return_context_acc_compute(\n K, stride_k_n, stride_k_tsrc, stride_k_hid,\n V, stride_v_n, stride_v_tsrc, stride_v_hid, \n CONTEXT_LENGTH, \n queries,\n queries_grouped,\n idx_n,\n idx_tsrc,\n mask_tsrc,\n idx_hid,\n mask_hid,\n idx_tdst,\n mask_tdst,\n context_length,\n acc,\n l_i,\n m_i,\n KV_REPEAT_INTERLEAVE,\n IS_CAUSAL,\n TDST,\n TSRC,\n HID,\n CACHE_METHOD,\n VLLM_NUM_KV_HEADS,\n VLLM_BLOCK_SIZE,\n VLLM_X,\n stride_k_vllm_num_blocks,\n stride_k_vllm_num_kv_heads,\n stride_k_vllm_head_size_x,\n stride_k_vllm_block_size,\n stride_k_vllm_x,\n stride_v_vllm_num_blocks,\n stride_v_vllm_num_kv_heads,\n stride_v_vllm_head_size,\n stride_v_vllm_block_size,\n BLOCK_TABLES,\n stride_block_tables_num_seqs,\n stride_block_tables_max_num_blocks_per_seq,\n ROPE_METHOD,\n ROPE_COS,\n stride_rope_cos_idx, \n stride_rope_cos_hid,\n ROPE_SIN,\n stride_rope_sin_idx, \n stride_rope_sin_hid,\n POSITION_IDS,\n stride_position_ids_n,\n stride_position_ids_tdst,\n SELF_EXTEND_SCALE,\n SELF_EXTEND_WINDOW,\n RETURN_SCORES,\n OUT_SCORES, \n stride_out_scores_n, \n stride_out_scores_tdst, \n stride_out_scores_k,\n idx_out_k,\n mask_out_k,\n):\n # Triton kernel for attention computation\n # Implementation details omitted for brevity.\n\n@triton.autotune(\n configs=[\n triton.Config(kwargs={}, num_warps=16, num_stages=1),\n triton.Config(kwargs={}, num_warps=8, num_stages=1),\n triton.Config(kwargs={}, num_warps=2, num_stages=1),\n triton.Config(kwargs={}, num_warps=1, num_stages=1),\n ],\n key=['BLOCK_HID', 'BLOCK_BK'],\n warmup=3,\n rep=50,\n)\n@triton.jit\ndef _calc_prob_return_context_compute(\n Q, stride_q_n, stride_q_tdst, stride_q_hid,\n Q_GROUPED,\n K, stride_k_n, stride_k_tsrc, stride_k_hid,\n V, stride_v_n, stride_v_tsrc, stride_v_hid,\n ATTEN_MASK, stride_atten_mask_n, stride_atten_mask_tsrc,\n INDICES, stride_indices_n, stride_indices_bdst, stride_indices_bk,\n KS, stride_ks_n, stride_ks_bdst,\n CONTEXT, stride_context_n, stride_context_tdst, stride_context_hid,\n KV_REPEAT_INTERLEAVE, N, TDST, TSRC, HID: tl.constexpr, BDST, BSRC, BK,\n stride_k_vllm_num_blocks, \n stride_k_vllm_num_kv_heads, \n stride_k_vllm_head_size_x, \n stride_k_vllm_block_size, \n stride_k_vllm_x,\n stride_v_vllm_num_blocks,\n stride_v_vllm_num_kv_heads,\n stride_v_vllm_head_size,\n stride_v_vllm_block_size,\n BLOCK_TABLES,\n stride_block_tables_num_seqs,\n stride_block_tables_max_num_blocks_per_seq,\n CONTEXT_LENGTH,\n stride_context_length_num_seqs,\n VLLM_NUM_BLOCKS,\n VLLM_NUM_KV_HEADS,\n VLLM_HEAD_SIZE_X,\n VLLM_BLOCK_SIZE: tl.constexpr,\n VLLM_X: tl.constexpr,\n VLLM_HEAD_SIZE,\n USING_SLIDING_WINDOW: tl.constexpr,\n SLIDING_WINDOW_SIZE: tl.constexpr,\n SLIDING_WINDOW_MASK,\n stride_sliding_window_mask_n,\n stride_sliding_window_mask_bdst,\n stride_sliding_window_mask_tsrc,\n ROPE_METHOD: tl.constexpr,\n ROPE_COS, stride_rope_cos_idx, stride_rope_cos_hid,\n ROPE_SIN, stride_rope_sin_idx, stride_rope_sin_hid,\n POSITION_IDS, stride_position_ids_n, stride_position_ids_tdst,\n SELF_EXTEND_SCALE,\n SELF_EXTEND_WINDOW,\n CACHE_METHOD: tl.constexpr,\n BLOCK_SIZE_Q,\n BLOCK_SIZE_Q_PADDED: tl.constexpr,\n BLOCK_SIZE_K: tl.constexpr,\n BLOCK_HID: tl.constexpr,\n BLOCK_BK: tl.constexpr,\n NUM_SINK: tl.constexpr,\n IS_CAUSAL: tl.constexpr,\n RETURN_SCORES: tl.constexpr,\n OUT_SCORES, stride_out_scores_n, stride_out_scores_tdst, stride_out_scores_k,\n):\n # Triton kernel for batched attention computation\n # Implementation details omitted for brevity.\n\ndef calc_prob_return_context(\n queries: Tensor, \n keys: Union[Tensor, \"PagedKeyCacheVllmCompat\"], \n values: Union[Tensor, \"PagedValueCacheVllmCompat\"], \n attention_mask: Optional[Tensor],\n indices: Tensor, ks: Tensor,\n KV_REPEAT_INTERLEAVE: int,\n BLOCK_SIZE_Q: int,\n BLOCK_SIZE_K: int,\n IS_CAUSAL: bool,\n USING_SLIDING_WINDOW: bool,\n SLIDING_WINDOW_SIZE: int,\n ROPE_METHOD: str = 'none',\n ROPE_COS: Optional[Tensor] = None,\n ROPE_SIN: Optional[Tensor] = None,\n POSITION_IDS: Optional[Tensor] = None,\n SELF_EXTEND_SCALE: int = 1,\n SELF_EXTEND_WINDOW: int = 1,\n RETURN_SCORES: bool = False,\n NUM_SINK: Optional[int] = None,\n):\n \"\"\"\n Python function calling Triton kernels.\n Computes attention using the custom kernels.\n \"\"\"\n N, TDST, HID = queries.shape\n _N, TSRC, HID = keys.shape\n assert keys.shape == values.shape\n assert attention_mask is None or attention_mask.shape == (N, TDST)\n \n BSRC = triton.cdiv(TSRC, BLOCK_SIZE_K)\n BDST = triton.cdiv(TDST, BLOCK_SIZE_Q)\n _, _, BK = indices.shape\n assert ks.shape == (N, BDST), f'{ks.shape}'\n \n BLOCK_BK = triton.cdiv(64 if queries.dtype == torch.float32 else 128, BLOCK_SIZE_K)\n if HID >= 256:\n BLOCK_BK = BLOCK_BK // math.ceil(HID / 128)\n BLOCK_HID = triton.next_power_of_2(HID)\n BLOCK_SIZE_Q_PADDED = next_multiple_of(BLOCK_SIZE_Q, 16)\n \n if ROPE_METHOD == 'self_extend':\n q_scale = 1 / math.sqrt(HID)\n \n queries_neighbor = apply_rotary_pos_emb(\n queries / q_scale, \n None, \n ROPE_COS, \n ROPE_SIN, \n POSITION_IDS,\n )[0] * q_scale\n queries_grouped = apply_rotary_pos_emb(\n queries / q_scale, \n None, \n ROPE_COS, \n ROPE_SIN, \n POSITION_IDS // SELF_EXTEND_SCALE + SELF_EXTEND_WINDOW - SELF_EXTEND_WINDOW // SELF_EXTEND_SCALE,\n )[0] * q_scale\n queries = queries_neighbor\n assert queries.stride() == queries_grouped.stride()\n else:\n queries_grouped = None\n \n assert values.dtype in [torch.float32, torch.float16, torch.bfloat16, torch.uint8]\n context = torch.zeros(\n (N, TDST, HID),\n dtype=queries.dtype,\n device=queries.device,\n )\n \n if isinstance(keys, Tensor) and isinstance(values, Tensor):\n CACHE_METHOD = 'cont'\n \n VLLM_NUM_BLOCKS =\\\n VLLM_NUM_KV_HEADS =\\\n VLLM_HEAD_SIZE_X =\\\n VLLM_BLOCK_SIZE =\\\n VLLM_X =\\\n VLLM_HEAD_SIZE = 0\n \n vllm_keys_strides = (0, 0, 0, 0, 0)\n vllm_values_strides = (0, 0, 0, 0)\n \n block_tables = keys\n block_tables_strides = (0, 0)\n \n context_length = None\n context_length_strides = (0, )\n elif isinstance(keys, PagedKeyCacheVllmCompat) and isinstance(values, PagedValueCacheVllmCompat):\n CACHE_METHOD = 'vllm'\n \n (\n VLLM_NUM_BLOCKS,\n VLLM_NUM_KV_HEADS, \n VLLM_HEAD_SIZE_X,\n VLLM_BLOCK_SIZE,\n VLLM_X,\n ) = keys.key_cache.shape\n VLLM_HEAD_SIZE = VLLM_HEAD_SIZE_X * VLLM_X\n \n block_tables = keys.block_table\n block_tables_strides = block_tables.stride()\n assert len(block_tables_strides) == 2\n \n context_length = keys.context_length\n context_length_strides = context_length.stride()\n assert len(context_length_strides) == 1\n \n vllm_keys_strides = keys.key_cache.stride()\n assert len(vllm_keys_strides) == 5\n \n vllm_values_strides = values.value_cache.stride()\n assert len(vllm_values_strides) == 4\n else:\n raise Exception(\"not supported\")\n \n if USING_SLIDING_WINDOW:\n sliding_window_mask = torch.zeros(\n (N, BDST, SLIDING_WINDOW_SIZE), \n dtype=torch.bool, \n device=queries.device\n )\n sliding_window_mask_strides = sliding_window_mask.stride()\n else:\n sliding_window_mask = None\n sliding_window_mask_strides = (0, 0, 0)\n assert len(sliding_window_mask_strides) == 3\n \n assert ROPE_METHOD in ['none', 'self_extend']\n if ROPE_METHOD in ['self_extend']:\n assert ROPE_SIN is not None\n assert POSITION_IDS is not None\n assert ROPE_COS.ndim == 2\n assert ROPE_SIN.ndim == 2\n assert POSITION_IDS.ndim == 2\n assert POSITION_IDS.shape == (N, TDST), POSITION_IDS.shape\n rope_cos_stride = ROPE_COS.stride()\n rope_sin_stride = ROPE_SIN.stride()\n position_ids_stride = POSITION_IDS.stride()\n else:\n rope_cos_stride = (0, 0)\n rope_sin_stride = (0, 0)\n position_ids_stride = (0, 0)\n \n NUM_SINK = triton.cdiv(32, BLOCK_SIZE_K) if NUM_SINK is None else NUM_SINK\n assert isinstance(NUM_SINK, int)\n \n if RETURN_SCORES:\n if USING_SLIDING_WINDOW:\n output_scores = torch.full(\n (\n N, TDST, \n indices.shape[-1] * BLOCK_SIZE_K + NUM_SINK * BLOCK_SIZE_K + SLIDING_WINDOW_SIZE\n ),\n fill_value=-32000.0,\n dtype=queries.dtype,\n device=queries.device,\n )\n else: \n output_scores = torch.full(\n (N, TDST, indices.shape[-1] * BLOCK_SIZE_K),\n fill_value=-32000.0,\n dtype=queries.dtype,\n device=queries.device,\n )\n output_scores_stride = output_scores.stride()\n else:\n output_scores = None\n output_scores_stride = (0, 0, 0)\n \n grid = (N * BDST, )\n \n assert attention_mask is None, \"attention mask is not supported yet\"\n assert queries.ndim == 3\n assert keys.ndim == 3\n assert values.ndim == 3\n assert attention_mask is None or attention_mask.ndim == 3\n assert indices.ndim == 3\n assert ks.ndim == 2\n assert context.ndim == 3\n\n orig_device = torch.cuda.current_device()\n torch.cuda.set_device(queries.device)\n _calc_prob_return_context_compute[grid](\n queries, *queries.stride(),\n queries_grouped,\n keys, *keys.stride(),\n values, *values.stride(),\n attention_mask, *((0, 0) if attention_mask is None else attention_mask.stride()),\n indices, *indices.stride(),\n ks, *ks.stride(),\n context, *context.stride(),\n KV_REPEAT_INTERLEAVE, \n N, \n TDST, \n TSRC, \n HID, \n BDST, \n BSRC, \n BK,\n *vllm_keys_strides,\n *vllm_values_strides,\n block_tables,\n *block_tables_strides,\n context_length,\n *context_length_strides,\n VLLM_NUM_BLOCKS,\n VLLM_NUM_KV_HEADS,\n VLLM_HEAD_SIZE_X,\n VLLM_BLOCK_SIZE,\n VLLM_X,\n VLLM_HEAD_SIZE,\n USING_SLIDING_WINDOW,\n SLIDING_WINDOW_SIZE,\n sliding_window_mask,\n *sliding_window_mask_strides,\n ROPE_METHOD,\n ROPE_COS, *rope_cos_stride,\n ROPE_SIN, *rope_sin_stride,\n POSITION_IDS, *position_ids_stride,\n SELF_EXTEND_SCALE,\n SELF_EXTEND_WINDOW,\n CACHE_METHOD,\n BLOCK_SIZE_Q,\n BLOCK_SIZE_Q_PADDED, \n BLOCK_SIZE_K,\n BLOCK_HID,\n BLOCK_BK,\n NUM_SINK,\n IS_CAUSAL,\n RETURN_SCORES,\n output_scores, *output_scores_stride\n )\n torch.cuda.set_device(orig_device)\n \n if RETURN_SCORES:\n return context, output_scores\n return context\n", - "description_1": "Use triton language to implement a custom kernel for flash attention computation with variable cache methods, and the ability to return intermediate scores if required. The kernel processes inputs such as queries, keys, and values, possibly using a sliding window approach and optional rotary position embedding.", - "description_2": "Use triton language to create optimized attention kernels for neural networks with configurable parameters, capable of handling context lengths and various cache strategies.", - "difficulty": 4 - }, - { - "code": "import triton\nimport triton.language as tl\nimport torch\nfrom torch import Tensor\nfrom torch.autograd import Function\n\n@triton.jit\ndef _calc_score_compute(\n QUERIES, stride_queries_n, stride_queries_tdst, stride_queries_hid,\n KEYS, stride_keys_n, stride_keys_tsrc, stride_keys_hid,\n ATTEN_MASK, stride_atten_mask_n, stride_atten_mask_tsrc,\n INDICES, stride_indices_n, stride_indices_bdst, stride_indices_bk,\n KS, stride_ks_n, stride_ks_bdst,\n SCORES, stride_scores_n, stride_scores_tdst, stride_scores_k,\n KV_REPEAT_INTERLEAVE, N, TDST, TSRC, HID, BK, K, BDST, BSRC, IS_CAUSAL,\n stride_keys_vllm_num_bocks, stride_keys_vllm_num_kv_heads, stride_keys_vllm_head_size_x,\n stride_keys_vllm_block_size, stride_keys_vllm_x,\n VLLM_NUM_BLOCKS, VLLM_NUM_KV_HEADS, VLLM_HEAD_SIZE_X, VLLM_BLOCK_SIZE, VLLM_X, VLLM_HEAD_SIZE,\n BLOCK_TABLES, stride_block_tables_num_seqs, stride_block_tables_max_num_blocks_per_seq,\n KEY_CACHE_METHOD: tl.constexpr, BLOCK_BK: tl.constexpr, BLOCK_SIZE_Q: tl.constexpr,\n BLOCK_SIZE_Q_PADDED: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, BLOCK_SIZE_K_PADDED: tl.constexpr,\n BLOCK_HID: tl.constexpr,\n):\n idx_n = tl.program_id(0).to(tl.int64)\n idx_bdst = tl.program_id(1).to(tl.int64)\n pid_bk = tl.program_id(2).to(tl.int64)\n \n ks = tl.load(\n KS +\n idx_n * stride_ks_n +\n idx_bdst * stride_ks_bdst,\n )\n \n idx_bk = tl.arange(0, BLOCK_BK) + pid_bk * BLOCK_BK\n mask_bk = idx_bk < ks\n \n idx_block_q = tl.arange(0, BLOCK_SIZE_Q_PADDED)\n mask_block_q = idx_block_q < BLOCK_SIZE_Q\n idx_block_k = tl.arange(0, BLOCK_SIZE_K_PADDED)\n mask_block_k = idx_block_k < BLOCK_SIZE_K\n \n idx_tsrc = tl.load(\n INDICES +\n idx_n * stride_indices_n +\n idx_bdst * stride_indices_bdst +\n idx_bk * stride_indices_bk,\n mask=mask_bk,\n )\n idx_tsrc = idx_tsrc[:, None] + idx_block_k[None, :]\n mask_tsrc = (idx_tsrc < TSRC) & mask_block_k[None, :] & mask_bk[:, None]\n \n if ATTEN_MASK is not None:\n key_mask = tl.load(\n ATTEN_MASK +\n idx_n * stride_atten_mask_n +\n idx_tsrc * stride_atten_mask_tsrc,\n mask=mask_tsrc,\n other=False,\n ).to(tl.int1)\n mask_tsrc = mask_tsrc & key_mask\n \n idx_tdst = idx_bdst * BLOCK_SIZE_Q + idx_block_q\n mask_tdst = (idx_tdst < TDST) & mask_block_q\n if ATTEN_MASK is not None:\n query_mask = tl.load(\n ATTEN_MASK +\n idx_n * stride_atten_mask_n +\n (idx_tdst + TSRC - TDST) * stride_atten_mask_tsrc,\n mask=mask_tdst,\n other=False,\n ).to(tl.int1)\n mask_tdst = mask_tdst & query_mask\n \n scores = tl.zeros((BLOCK_SIZE_Q_PADDED, BLOCK_BK, BLOCK_SIZE_K_PADDED), dtype=tl.float32)\n for pid_hid in range(tl.cdiv(HID, BLOCK_HID)):\n idx_hid = (tl.arange(0, BLOCK_HID) + pid_hid * BLOCK_HID).to(tl.int64)\n mask_hid = idx_hid < HID\n \n queries = tl.load(\n QUERIES +\n idx_n * stride_queries_n +\n idx_tdst[:, None] * stride_queries_tdst +\n idx_hid[None, :] * stride_queries_hid,\n mask=mask_tdst[:, None] & mask_hid[None, :],\n other=0\n )\n \n if KEY_CACHE_METHOD == 'cont':\n keys = tl.load(\n KEYS +\n (idx_n // KV_REPEAT_INTERLEAVE) * stride_keys_n +\n idx_tsrc[None, :, :] * stride_keys_tsrc +\n idx_hid[:, None, None] * stride_keys_hid,\n mask=mask_tsrc[None, :, :] & mask_hid[:, None, None],\n other=0\n )\n elif KEY_CACHE_METHOD == 'vllm':\n idx_batch = ((idx_n // KV_REPEAT_INTERLEAVE) // VLLM_NUM_KV_HEADS).to(tl.int64)\n idx_head = ((idx_n // KV_REPEAT_INTERLEAVE) % VLLM_NUM_KV_HEADS).to(tl.int64)\n idx_block = tl.load(\n BLOCK_TABLES +\n idx_batch * stride_block_tables_num_seqs +\n (idx_tsrc // VLLM_BLOCK_SIZE) * stride_block_tables_max_num_blocks_per_seq,\n mask=mask_tsrc,\n ).to(tl.int64)\n offset_block = (idx_tsrc - ((idx_tsrc // VLLM_BLOCK_SIZE) * VLLM_BLOCK_SIZE)).to(tl.int64)\n \n keys = tl.load(\n KEYS +\n idx_block[None, :, :] * stride_keys_vllm_num_bocks +\n idx_head * stride_keys_vllm_num_kv_heads +\n (idx_hid[:, None, None] // VLLM_X) * stride_keys_vllm_head_size_x +\n offset_block[None, :, :] * stride_keys_vllm_block_size +\n (idx_hid[:, None, None] % VLLM_X) * stride_keys_vllm_x,\n mask=mask_tsrc[None, :, :] & mask_hid[:, None, None],\n other=0,\n )\n else:\n raise Exception()\n keys = tl.reshape(keys, (BLOCK_HID, BLOCK_BK * BLOCK_SIZE_K_PADDED))\n \n if keys.dtype == tl.uint8:\n keys = keys.to(tl.float8e5, bitcast=True).to(queries.dtype)\n scores_mini = tl.dot(queries, keys)\n scores_mini = tl.reshape(scores_mini, (BLOCK_SIZE_Q_PADDED, BLOCK_BK, BLOCK_SIZE_K_PADDED))\n \n scores += scores_mini.to(scores.dtype)\n \n idx_scorek = (idx_bk[:, None] * BLOCK_SIZE_K + idx_block_k[None, :])\n mask_scorek = (idx_scorek < K) & mask_block_k[None, :] & mask_bk[:, None]\n \n scores_mask = (\n (mask_tdst[:, None, None] & mask_tsrc[None, :, :]) &\n mask_scorek[None, :] &\n True\n )\n \n if IS_CAUSAL:\n scores_mask = scores_mask & ((idx_tdst[:, None, None] + (TSRC - TDST)) >= idx_tsrc[None, :, :])\n \n tl.store(\n SCORES +\n idx_n * stride_scores_n +\n idx_tdst[:, None, None] * stride_scores_tdst +\n idx_scorek[None, :, :] * stride_scores_k,\n mask=scores_mask,\n value=scores,\n )\n\n@triton.jit\ndef _calc_score_compute_bwd_queries(\n KS, stride_ks_n, stride_ks_bdst,\n INDICES, stride_indices_n, stride_indices_bdst, stride_indices_bk,\n KEYS, stride_keys_n, stride_keys_tsrc, stride_keys_hid,\n GRAD_SCORES, stride_grad_scores_n, stride_grad_scores_tdst, stride_grad_scores_k,\n GRAD_QUERIES, stride_grad_queries_n, stride_grad_queries_tdst, stride_grad_queries_hid,\n N, TDST, TSRC, HID, BLOCK_K, K,\n BLOCK_SIZE_Q: tl.constexpr, BLOCK_SIZE_Q_PADDED: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,\n BLOCK_SIZE_K_PADDED: tl.constexpr, BLOCK_HID: tl.constexpr, IS_CAUSAL: tl.constexpr,\n):\n idx_n = tl.program_id(0)\n idx_query_block = tl.program_id(1)\n\n idx_block_q = tl.arange(0, BLOCK_SIZE_Q_PADDED)\n idx_block_k = tl.arange(0, BLOCK_SIZE_K_PADDED)\n idx_hid = tl.arange(0, BLOCK_HID)\n\n scalar_ks = tl.load(\n KS +\n idx_n.to(tl.int64) * stride_ks_n +\n idx_query_block.to(tl.int64) * stride_ks_bdst\n )\n\n accumulator = tl.zeros((BLOCK_SIZE_Q_PADDED, BLOCK_HID,), dtype=tl.float32)\n for idx_key_block in range(scalar_ks):\n idx_key_start = tl.load(\n INDICES +\n idx_n.to(tl.int64) * stride_indices_n +\n idx_query_block.to(tl.int64) * stride_indices_bdst +\n idx_key_block.to(tl.int64) * stride_indices_bk,\n )\n\n if IS_CAUSAL:\n causal_mask = ((idx_key_start + idx_block_k)[None, :] <= (idx_query_block * BLOCK_SIZE_Q + idx_block_q)[:, None])\n else:\n causal_mask = True\n\n grad_score = tl.load(\n GRAD_SCORES +\n idx_n.to(tl.int64) * stride_grad_scores_n +\n (idx_query_block * BLOCK_SIZE_Q + idx_block_q)[:, None].to(tl.int64) * stride_grad_scores_tdst +\n (idx_key_block * BLOCK_SIZE_K + idx_block_k)[None, :].to(tl.int64) * stride_grad_scores_k,\n mask=((idx_query_block * BLOCK_SIZE_Q + idx_block_q)[:, None] < TDST) &\n (idx_block_q[:, None] < BLOCK_SIZE_Q) &\n ((idx_key_block * BLOCK_SIZE_K + idx_block_k)[None, :] < K) &\n (idx_block_k[None, :] < BLOCK_SIZE_K) &\n causal_mask,\n other=0,\n )\n\n key = tl.load(\n KEYS +\n idx_n.to(tl.int64) * stride_keys_n +\n (idx_key_start + idx_block_k)[:, None].to(tl.int64) * stride_keys_tsrc +\n idx_hid[None, :].to(tl.int64) * stride_keys_hid,\n mask=((idx_key_start + idx_block_k)[:, None] < TSRC) &\n (idx_block_k[:, None] < BLOCK_SIZE_K) &\n (idx_hid[None, :] < HID),\n other=0,\n )\n\n accumulator += tl.dot(grad_score, key).to(accumulator.dtype)\n\n tl.store(\n GRAD_QUERIES +\n idx_n.to(tl.int64) * stride_grad_queries_n +\n (idx_query_block * BLOCK_SIZE_Q + idx_block_q)[:, None].to(tl.int64) * stride_grad_queries_tdst +\n idx_hid[None, :].to(tl.int64) * stride_grad_queries_hid,\n mask=((idx_query_block * BLOCK_SIZE_Q + idx_block_q)[:, None] < TDST) &\n (idx_block_q[:, None] < BLOCK_SIZE_Q) &\n (idx_hid[None, :] < HID),\n value=accumulator\n )\n\n\n@triton.jit\ndef _calc_score_compute_bwd_keys(\n ks, stride_ks_n, stride_ks_bdst,\n indices, stride_indices_n, stride_indices_bdst, stride_indices_bk,\n queries, stride_queries_n, stride_queries_tdst, stride_queries_hid,\n grad_scores, stride_grad_scores_n, stride_grad_scores_tdst, stride_grad_scores_k,\n grad_keys, stride_grad_keys_n, stride_grad_keys_tsrc, stride_grad_keys_hid,\n N, TDST, TSRC, HID, BK, K,\n BLOCK_SIZE_Q: tl.constexpr, BLOCK_SIZE_Q_PADDED: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,\n BLOCK_SIZE_K_PADDED: tl.constexpr, BLOCK_HID: tl.constexpr,\n):\n idx_n = tl.program_id(0)\n idx_bdst = tl.program_id(1)\n idx_bk = tl.program_id(2)\n \n scalar_ks = tl.load(\n ks +\n idx_n * stride_ks_n +\n idx_bdst * stride_ks_bdst,\n )\n if idx_bk >= scalar_ks: return\n \n idx_hid = tl.arange(0, BLOCK_HID)\n mask_hid = (idx_hid < HID)\n \n idx_block_q = tl.arange(0, BLOCK_SIZE_Q_PADDED)\n mask_block_q = idx_block_q < BLOCK_SIZE_Q\n idx_block_k = tl.arange(0, BLOCK_SIZE_K_PADDED)\n mask_block_k = idx_block_k < BLOCK_SIZE_K\n \n idx_tdst = idx_bdst * BLOCK_SIZE_Q + idx_block_q\n mask_tdst = (idx_tdst < TDST) & mask_block_q\n \n idx_k = idx_bk * BLOCK_SIZE_K + idx_block_k\n mask_k = (idx_k < K) & mask_block_k\n \n grad_score = tl.load(\n grad_scores +\n idx_n * stride_grad_scores_n +\n idx_tdst[None, :] * stride_grad_scores_tdst +\n idx_k[:, None] * stride_grad_scores_k,\n mask=mask_tdst[None, :] & mask_k[:, None],\n other=0\n )\n query = tl.load(\n queries +\n idx_n * stride_queries_n +\n idx_tdst[:, None] * stride_queries_tdst +\n idx_hid[None, :] * stride_queries_hid,\n mask=mask_tdst[:, None] & mask_hid[None, :],\n other=0,\n )\n scores = tl.dot(grad_score, query)\n \n idx_tsrc = tl.load(\n indices +\n idx_n * stride_indices_n +\n idx_bdst * stride_indices_bdst +\n idx_bk * stride_indices_bk,\n )\n idx_tsrc = idx_tsrc + idx_block_k\n mask_tsrc = (idx_tsrc < TSRC) & mask_block_k\n tl.atomic_add(\n grad_keys +\n idx_n * stride_grad_keys_n +\n idx_tsrc[:, None] * stride_grad_keys_tsrc +\n idx_hid[None, :] * stride_grad_keys_hid,\n val=scores,\n mask=mask_tsrc[:, None] & mask_hid[None, :]\n )\n\nclass CalcScoreAutoGradFn(Function):\n @staticmethod\n def forward(\n ctx, \n queries: Tensor, keys: Tensor, attention_mask: Tensor,\n indices: Tensor, ks: Tensor,\n KV_REPEAT_INTERLEAVE: int,\n BLOCK_SIZE_Q: int,\n BLOCK_SIZE_K: int,\n IS_CAUSAL: bool\n ):\n ctx.save_for_backward(queries, keys, indices, ks)\n ctx.BLOCK_SIZE_Q = BLOCK_SIZE_Q\n ctx.BLOCK_SIZE_K = BLOCK_SIZE_K\n ctx.IS_CAUSAL = IS_CAUSAL\n \n N, TDST, HID = queries.shape\n _N, TSRC, _ = keys.shape\n _, _, BK = indices.shape\n \n BDST = triton.cdiv(TDST, BLOCK_SIZE_Q)\n BSRC = triton.cdiv(TSRC, BLOCK_SIZE_K)\n \n assert keys.shape == (_N, TSRC, HID)\n assert indices.shape == (N, BDST, BK)\n assert ks.shape == (N, BDST)\n \n K = BK * BLOCK_SIZE_K\n scores = torch.full(\n (N, TDST, K), \n torch.finfo(queries.dtype).min,\n device=queries.device, \n dtype=queries.dtype\n )\n \n BLOCK_SIZE_Q_PADDED = next_multiple_of(BLOCK_SIZE_Q, 16)\n BLOCK_SIZE_K_PADDED = next_multiple_of(BLOCK_SIZE_K, 1)\n BLOCK_BK = next_multiple_of(128 // BLOCK_SIZE_K_PADDED, 1)\n BLOCK_HID = 32\n \n if isinstance(keys, Tensor):\n KEY_CACHE_METHOD = 'cont'\n \n VLLM_NUM_BLOCKS =\\\n VLLM_NUM_KV_HEADS =\\\n VLLM_HEAD_SIZE_X =\\\n VLLM_BLOCK_SIZE =\\\n VLLM_X =\\\n VLLM_HEAD_SIZE = 0\n \n vllm_keys_strides = (0, 0, 0, 0, 0)\n \n block_tables = keys\n block_tables_strides = (0, 0)\n else:\n KEY_CACHE_METHOD = 'vllm'\n \n (\n VLLM_NUM_BLOCKS,\n VLLM_NUM_KV_HEADS, \n VLLM_HEAD_SIZE_X,\n VLLM_BLOCK_SIZE,\n VLLM_X,\n ) = keys.key_cache.shape\n VLLM_HEAD_SIZE = VLLM_HEAD_SIZE_X * VLLM_X\n \n block_tables = keys.block_table\n block_tables_strides = block_tables.stride()\n assert len(block_tables_strides) == 2\n \n vllm_keys_strides = keys.key_cache.stride()\n assert len(vllm_keys_strides) == 5 \n \n grid = (N, BDST, triton.cdiv(BK, BLOCK_BK))\n \n with timer(\"_calc_score_compute\"):\n orig_device = torch.cuda.current_device()\n torch.cuda.set_device(queries.device)\n _calc_score_compute[grid](\n queries, *queries.stride(),\n keys, *keys.stride(),\n attention_mask, *(attention_mask.stride() if attention_mask is not None else (0, 0)),\n indices, *indices.stride(),\n ks, *ks.stride(),\n scores, *scores.stride(),\n KV_REPEAT_INTERLEAVE, \n N, \n TDST, \n TSRC, \n HID, \n BK, \n K, \n BDST, \n BSRC, \n IS_CAUSAL,\n *vllm_keys_strides,\n VLLM_NUM_BLOCKS,\n VLLM_NUM_KV_HEADS,\n VLLM_HEAD_SIZE_X,\n VLLM_BLOCK_SIZE,\n VLLM_X,\n VLLM_HEAD_SIZE,\n block_tables, *block_tables_strides,\n KEY_CACHE_METHOD,\n BLOCK_BK,\n BLOCK_SIZE_Q,\n BLOCK_SIZE_Q_PADDED,\n BLOCK_SIZE_K,\n BLOCK_SIZE_K_PADDED,\n BLOCK_HID,\n num_warps=4,\n num_stages=2,\n enable_warp_specialization=False,\n )\n torch.cuda.set_device(orig_device)\n \n return scores\n\n @staticmethod\n def backward(ctx, grad_scores):\n ENABLED = True\n \n queries, keys, indices, ks = ctx.saved_tensors\n BLOCK_SIZE_Q = ctx.BLOCK_SIZE_Q\n BLOCK_SIZE_K = ctx.BLOCK_SIZE_K\n grad_queries = grad_keys = None\n \n N, T_DST, HID = queries.shape\n _, T_SRC, _HID = keys.shape\n assert HID == _HID\n _, _, BK = indices.shape\n _, _, K = grad_scores.shape\n\n if ctx.needs_input_grad[0]:\n grid = (N, triton.cdiv(T_DST, BLOCK_SIZE_Q))\n BLOCK_HID = triton.next_power_of_2(HID)\n\n grad_queries = torch.zeros_like(queries)\n\n if ENABLED:\n _calc_score_compute_bwd_queries[grid](\n ks, ks.stride(0), ks.stride(1),\n indices, indices.stride(0), indices.stride(1), indices.stride(2), \n keys, keys.stride(0), keys.stride(1), keys.stride(2),\n grad_scores, grad_scores.stride(0), grad_scores.stride(1), grad_scores.stride(2),\n grad_queries, grad_queries.stride(0), grad_queries.stride(1), grad_queries.stride(2),\n N, T_DST, T_SRC, HID, BK, K,\n BLOCK_SIZE_Q,\n next_multiple_of(BLOCK_SIZE_Q, 16),\n BLOCK_SIZE_K,\n next_multiple_of(BLOCK_SIZE_K, 16),\n BLOCK_HID,\n ctx.IS_CAUSAL,\n )\n \n if ctx.needs_input_grad[1]:\n grid = (N, triton.cdiv(T_DST, BLOCK_SIZE_Q), BK)\n BLOCK_HID = triton.next_power_of_2(HID)\n \n grad_keys = torch.zeros_like(keys, dtype=torch.float32)\n \n if ENABLED:\n _calc_score_compute_bwd_keys[grid](\n ks, ks.stride(0), ks.stride(1),\n indices, indices.stride(0), indices.stride(1), indices.stride(2), \n queries, queries.stride(0), queries.stride(1), queries.stride(2),\n grad_scores, grad_scores.stride(0), grad_scores.stride(1), grad_scores.stride(2),\n grad_keys, grad_keys.stride(0), grad_keys.stride(1), grad_keys.stride(2),\n N, T_DST, T_SRC, HID, BK, K,\n BLOCK_SIZE_Q,\n next_multiple_of(BLOCK_SIZE_Q, 16),\n BLOCK_SIZE_K,\n next_multiple_of(BLOCK_SIZE_K, 16),\n BLOCK_HID,\n )\n\n grad_keys = grad_keys.to(keys.dtype)\n \n return (\n grad_queries, \n grad_keys, \n None,\n None, \n None, \n None,\n None,\n None,\n None,\n )\n\ndef calc_score_return_prob(\n queries: Tensor, keys: Tensor, attention_mask: Tensor,\n indices: Tensor, ks: Tensor,\n KV_REPEAT_INTERLEAVE: int,\n BLOCK_SIZE_Q: int,\n BLOCK_SIZE_K: int,\n IS_CAUSAL: bool,\n):\n scores = CalcScoreAutoGradFn.apply(\n queries, keys, attention_mask,\n indices, ks,\n KV_REPEAT_INTERLEAVE, BLOCK_SIZE_Q, BLOCK_SIZE_K, IS_CAUSAL\n ) # type: Tensor\n \n with timer(\"calc_score_return_prob.softmax\"):\n probs = scores.softmax(-1).to(scores.dtype)\n \n assert probs.dtype == queries.dtype\n \n N, TDST, K = scores.shape\n if attention_mask is not None:\n _, TSRC = attention_mask.shape\n if probs.requires_grad:\n probs = probs * attention_mask[:, TSRC-TDST:, None]\n else:\n probs.masked_fill_(~attention_mask[:, TSRC-TDST:, None], 0)\n \n assert scores.dtype == queries.dtype\n assert probs.dtype == queries.dtype\n \n return scores, probs\n", - "description_1": "Use triton language to implement a kernel for calculating scores for attention mechanisms. The kernel involves multiple triton.jit functions to handle forward and backward passes for query and key inputs. It operates on input matrices (queries, keys, attention mask), indices, and ks with several block and constant parameters.", - "description_2": "Use triton language to compute attention scores with triton.jit functions for forward and backward operations, processing input matrices and handling gradients for queries and keys.", - "difficulty": 4 - }, - { - "code": "import triton\nimport triton.language as tl\nimport torch\nfrom torch import Tensor\nfrom typing import Optional, Union, List\n\n@triton.jit\ndef _triton_kth_ascending(\n scores: tl.tensor, \n k: tl.tensor,\n BLOCK_SCORES: tl.constexpr,\n METHOD: tl.constexpr = 'sort',\n) -> tl.tensor:\n if METHOD == 'sort':\n sorted_score = tl.sort(scores)\n sorted_score_mask = tl.arange(0, BLOCK_SCORES) < k\n kth_ascending_value = tl.max(tl.where(sorted_score_mask, sorted_score, -32000.0))\n elif METHOD == 'search':\n kth_ascending_value = tl.min(scores)\n step_scale = tl.abs(kth_ascending_value)\n step_size = 0.5\n for i in range(5):\n smaller_count = tl.sum((scores < kth_ascending_value).to(tl.int32))\n if smaller_count > k:\n kth_ascending_value -= step_scale * step_size\n else:\n kth_ascending_value += step_scale * step_size\n step_size *= 0.8\n tl.debug_barrier()\n else:\n raise Exception()\n return kth_ascending_value\n\n@triton.jit\ndef _masking_iteration_topk(\n QUERIES, stride_queries_n, stride_queries_tdst, stride_queries_hid, \n QUERIES_GROUPED_ROPE,\n KEYS, stride_keys_n, stride_keys_tsrc, stride_keys_hid, \n MASK, stride_mask_n, stride_mask_bdst, stride_mask_src_grid, stride_mask_k,\n TMASK, stride_tmask_n, stride_tmask_bdst, stride_tmask_src_grid, stride_tmask_k,\n ATTEN_MASK, stride_atten_mask_n, stride_atten_mask_tsrc,\n SPARQ_INDICES, stride_sparq_indices_n, stride_sparq_indices_bdst, stride_sparq_indices_hid, \n BLOCK_TABLES, stride_block_tables_num_seqs, stride_block_tables_max_num_blocks_per_seq,\n SCORES, stride_scores_n, stride_scores_bdst, stride_scores_k, \n CONTEXT_LENGTH, \n idx_n,\n idx_bdst,\n idx_src_grid,\n idx_iteration,\n idx_block_q,\n mask_w,\n mask_block_q,\n k_old_mask,\n k_new, \n w_old,\n w_new,\n t_src,\n context_length,\n loc_idx_start_vec,\n loc_idx_start_origin,\n num_pixels_vec,\n num_pixels_scalar,\n dup_pixels_vec,\n dup_pixels_first,\n IS_CAUSAL,\n USING_SCORE_CACHE: tl.constexpr,\n N_ITERATION,\n T_DST,\n T_SRC,\n KEY_CACHE_METHOD,\n KV_REPEAT_INTERLEAVE, \n REDUCE_METHOD,\n SAMPLING_METHOD,\n GRID_SRC_STRIDE,\n GRID_K_STRIDE,\n USING_SLIDING_WINDOW,\n SLIDING_WINDOW_SIZE,\n HID, \n SPARQ, \n SPARQ_HID,\n BLOCK_MAX_DUP,\n BLOCK_SIZE_Q,\n BLOCK_SIZE_Q_PADDED,\n BLOCK_SIZE_K,\n BLOCK_MASK_K,\n BLOCK_MASK_K_PADDED,\n BLOCK_TMASK_K,\n BLOCK_TMASK_K_PADDED,\n BLOCK_HID, \n VLLM_NUM_KV_HEADS, \n VLLM_BLOCK_SIZE,\n VLLM_X, \n stride_keys_vllm_num_blocks, \n stride_keys_vllm_num_kv_heads, \n stride_keys_vllm_head_size_x, \n stride_keys_vllm_block_size, \n stride_keys_vllm_x, \n ROPE_METHOD,\n ROPE_COS, stride_rope_cos_idx, stride_rope_cos_hid,\n ROPE_SIN, stride_rope_sin_idx, stride_rope_sin_hid,\n POSITION_IDS, stride_position_ids_n, stride_position_ids_tdst,\n SELF_EXTEND_SCALE,\n SELF_EXTEND_WINDOW,\n):\n # Code logic with appropriate triton operations...\n pass\n\n@triton.jit\ndef _masking_iteration_compute(\n QUERIES, stride_queries_n, stride_queries_tdst, stride_queries_hid,\n QUERIES_GROUPED_ROPE,\n KEYS, stride_keys_n, stride_keys_tsrc, stride_keys_hid,\n ATTEN_MASK, stride_atten_mask_n, stride_atten_mask_tsrc,\n SPARQ_INDICES, stride_sparq_indices_n, stride_sparq_indices_bdst, stride_sparq_indices_hid,\n MASK, stride_mask_n, stride_mask_bdst, stride_mask_src_grid, stride_mask_k,\n TMASK, stride_tmask_n, stride_tmask_bdst, stride_tmask_src_grid, stride_tmask_k,\n WS, stride_ws_n, stride_ws_bdst,\n KS, stride_ks_n, stride_ks_bdst,\n WS_OUT, stride_ws_out_n, stride_ws_out_bdst,\n KS_OUT, stride_ks_out_n, stride_ks_out_bdst, stride_ks_out_src_grid,\n TSRCS, stride_tsrcs_n, stride_tsrcs_bdst,\n SCORES, stride_scores_n, stride_scores_bdst, stride_scores_k,\n SCALE_UP: tl.constexpr, \n N_PATCHES: tl.constexpr, \n MASK_K: tl.constexpr, \n TMASK_K: tl.constexpr, \n IS_CAUSAL: tl.constexpr,\n KV_REPEAT_INTERLEAVE: int,\n N: int, \n T_DST: int, \n T_SRC: int, \n B_DST: int, \n B_SRC: int, \n HID: tl.constexpr, \n SPARQ_HID: tl.constexpr,\n SPARQ_HID_HALF: tl.constexpr,\n N_COMPLETED: int,\n N_ITERATION: int,\n stride_keys_vllm_num_blcoks, \n stride_keys_vllm_num_kv_heads,\n stride_keys_vllm_head_size_x,\n stride_keys_vllm_block_size,\n stride_keys_vllm_x,\n VLLM_NUM_BLOCKS: int, \n VLLM_NUM_KV_HEADS: int,\n VLLM_HEAD_SIZE_X: int,\n VLLM_BLOCK_SIZE: tl.constexpr,\n VLLM_X: int, \n VLLM_HEAD_SIZE: int,\n BLOCK_TABLES, \n stride_block_tables_num_seqs, \n stride_block_tables_max_num_blocks_per_seq,\n CONTEXT_LENGTH,\n stride_context_length_num_seqs,\n ROPE_METHOD: tl.constexpr,\n ROPE_COS, stride_rope_cos_idx, stride_rope_cos_hid,\n ROPE_SIN, stride_rope_sin_idx, stride_rope_sin_hid,\n POSITION_IDS, stride_position_ids_n, stride_position_ids_tdst,\n SELF_EXTEND_SCALE,\n SELF_EXTEND_WINDOW,\n MAX_KS, stride_max_ks_n, stride_max_ks_bdst,\n SELECTED_MAX_KS: tl.constexpr,\n USING_SCORE_CACHE: tl.constexpr,\n KEY_CACHE_METHOD: tl.constexpr,\n SPARQ: tl.constexpr,\n REDUCE_METHOD: tl.constexpr,\n BLOCK_MASK_K: tl.constexpr, \n BLOCK_MASK_K_PADDED: tl.constexpr,\n BLOCK_TMASK_K: tl.constexpr, \n BLOCK_TMASK_K_PADDED: tl.constexpr,\n BLOCK_MASK_K_HALF: tl.constexpr, \n BLOCK_MASK_K_HALF_PADDED: tl.constexpr,\n BLOCK_TMASK_K_HALF: tl.constexpr, \n BLOCK_TMASK_K_HALF_PADDED: tl.constexpr,\n BLOCK_MAX_DUP: tl.constexpr,\n BLOCK_HID: tl.constexpr,\n BLOCK_SIZE_Q: tl.constexpr,\n BLOCK_SIZE_Q_PADDED: tl.constexpr,\n BLOCK_SIZE_K: tl.constexpr,\n BLOCK_SIZE_K_PADDED: tl.constexpr,\n REDUCE_STRDIE: tl.constexpr,\n SAMPLING_METHOD: tl.constexpr,\n GRID_SRC_STRIDE: tl.constexpr,\n GRID_K_STRIDE: tl.constexpr,\n USING_SLIDING_WINDOW: tl.constexpr,\n SLIDING_WINDOW_SIZE: tl.constexpr,\n):\n # Code logic with appropriate triton operations...\n pass\n\ndef masking_iteration(\n queries: Tensor, keys: Union[Tensor, \"PagedKeyCacheVllmCompat\"], attention_mask: Tensor,\n mask: Tensor, t_mask: Tensor, sparq_indices, sparq_indices_strides,\n ws: Tensor, ks: Tensor, t_srcs: Tensor, \n scale_up: float, n_patches: int, mask_k: int, is_causal: bool,\n i_iteration: int, n_iteration: int,\n ROPE_METHOD: str,\n ROPE_COS: Optional[Tensor],\n ROPE_SIN: Optional[Tensor],\n POSITION_IDS: Optional[Tensor],\n SELF_EXTEND_SCALE: int,\n SELF_EXTEND_WINDOW: int,\n maximum_ks: Optional[Tensor],\n maximum_ks_config: Optional[List[int]],\n KV_REPEAT_INTERLEAVE: int,\n N: int, \n T_DST: int, \n T_SRC: int, \n B_DST: int, \n B_SRC: int, \n HID: int, \n SPARQ: bool, \n SPARQ_HID: int,\n N_COMPLETED: int,\n BLOCK_SIZE_Q: int, \n BLOCK_SIZE_K: int, \n REDUCE_METHOD: str,\n REDUCE_STRIDE: int,\n SAMPLING_METHOD: str,\n GRID_SRC_STRIDE: int,\n GRID_K_STRIDE: int,\n USING_SLIDING_WINDOW: bool,\n SLIDING_WINDOW_SIZE: int,\n DEBUG: bool = False,\n):\n if DEBUG:\n print(\n 'masking_iteration', \n queries.shape, queries.data_ptr(), \n keys.shape, keys.data_ptr(), \n mask.shape, mask.data_ptr(),\n t_mask.shape, t_mask.data_ptr(),\n ws.shape, ws.data_ptr(),\n ks.shape, ks.data_ptr(),\n t_srcs.shape, t_srcs.data_ptr(),\n N, T_DST, T_SRC, B_DST, B_SRC, HID,\n BLOCK_SIZE_Q,\n BLOCK_SIZE_K,\n REDUCE_METHOD,\n GRID_SRC_STRIDE, \n GRID_K_STRIDE,\n )\n\n if ROPE_METHOD == 'self_extend':\n q_scale = 1 / math.sqrt(HID)\n queries_neighbor = apply_rotary_pos_emb(\n queries / q_scale, \n None, \n ROPE_COS, \n ROPE_SIN, \n POSITION_IDS\n )[0] * q_scale\n queries_grouped = apply_rotary_pos_emb(\n queries / q_scale, \n None, \n ROPE_COS, \n ROPE_SIN, \n POSITION_IDS // SELF_EXTEND_SCALE + SELF_EXTEND_WINDOW - SELF_EXTEND_WINDOW // SELF_EXTEND_SCALE\n )[0] * q_scale\n queries = queries_neighbor\n else:\n queries_grouped = None\n\n BLOCK_MASK_K = triton.next_power_of_2(mask.shape[-1])\n BLOCK_TMASK_K = triton.next_power_of_2(t_mask.shape[-1])\n\n BLOCK_HID = triton.next_power_of_2(HID)\n if SPARQ:\n BLOCK_HID = triton.next_power_of_2(max(16, SPARQ_HID))\n\n if isinstance(keys, Tensor):\n KEY_CACHE_METHOD = 'cont'\n stride_keys_vllm = (0, 0, 0, 0, 0)\n VLLM_NUM_BLOCKS = 0\n VLLM_NUM_KV_HEADS = 0\n VLLM_HEAD_SIZE_X = 0\n VLLM_BLOCK_SIZE = 0\n VLLM_X = 0\n VLLM_HEAD_SIZE = 0\n block_tables = keys\n block_tables_stride = (0, 0)\n context_length = None\n context_length_stride = (0,)\n elif isinstance(keys, PagedKeyCacheVllmCompat):\n KEY_CACHE_METHOD = 'vllm'\n stride_keys_vllm = keys.key_cache.stride()\n (\n VLLM_NUM_BLOCKS, \n VLLM_NUM_KV_HEADS, \n VLLM_HEAD_SIZE_X, \n VLLM_BLOCK_SIZE, \n VLLM_X\n ) = keys.key_cache.shape\n VLLM_HEAD_SIZE = VLLM_HEAD_SIZE_X * VLLM_X\n block_tables = keys.block_table\n block_tables_stride = block_tables.stride()\n context_length = keys.context_length\n context_length_stride = context_length.stride()\n else:\n raise Exception()\n\n USING_SCORE_CACHE = False\n if USING_SCORE_CACHE:\n scores = torch.full_like(mask, 32000.0, dtype=torch.float16)\n else:\n scores = None\n\n ws_out = torch.empty_like(ws)\n ks_out = torch.empty(\n (N, B_DST, GRID_SRC_STRIDE),\n dtype=torch.int64,\n device=queries.device,\n )\n\n if ROPE_METHOD in ['self_extend']:\n rope_cos_stride = ROPE_COS.stride()\n rope_sin_stride = ROPE_SIN.stride()\n position_ids_stride = POSITION_IDS.stride()\n else:\n rope_cos_stride = (0, 0)\n rope_sin_stride = (0, 0)\n position_ids_stride = (0, 0)\n \n grid = (GRID_SRC_STRIDE, B_DST - N_COMPLETED, N)\n\n assert REDUCE_METHOD in ['max', 'sum', 'first']\n\n assert queries.ndim == 3\n assert keys.ndim == 3\n if attention_mask is not None:\n assert attention_mask.ndim == 2\n assert mask.ndim == 4\n assert t_mask.ndim == 4\n assert ws.ndim == 2\n assert ws_out.ndim == 2\n assert ks.ndim == 2\n assert ks_out.ndim == 3\n assert t_srcs.ndim == 2\n\n if maximum_ks is not None:\n maximum_ks_stride = maximum_ks.stride()\n maximum_ks_config = list([math.ceil(x / (BLOCK_SIZE_K * GRID_SRC_STRIDE)) for x in maximum_ks_config])\n else:\n maximum_ks_stride = (0, 0)\n\n orig_device = torch.cuda.current_device()\n torch.cuda.set_device(queries.device)\n if maximum_ks is not None:\n calculated_maximum_ks_config = []\n for max_k in maximum_ks_config:\n calculated_maximum_ks_config.append((\n max_k,\n max(maximum_ks_config) // max_k,\n ))\n \n for selected_max_k, scale in calculated_maximum_ks_config:\n _BLOCK_MASK_K = BLOCK_MASK_K // scale\n _BLOCK_TMASK_K = BLOCK_TMASK_K // scale\n \n _masking_iteration_compute[grid](\n queries, *queries.stride(),\n queries_grouped,\n keys, *keys.stride(),\n attention_mask, *(attention_mask.stride() if attention_mask is not None else (0, 0)),\n sparq_indices, *sparq_indices_strides,\n mask, *mask.stride(),\n t_mask, *t_mask.stride(),\n ws, *ws.stride(),\n ks, *ks.stride(),\n ws_out, *ws_out.stride(),\n ks_out, *ks_out.stride(),\n t_srcs, *t_srcs.stride(),\n scores, *(scores.stride() if scores is not None else (0, 0, 0)),\n float(scale_up), \n int(triton.cdiv(n_patches, GRID_K_STRIDE)) // scale, \n int(mask.shape[-1]) // scale, \n int(t_mask.shape[-1]) // scale, \n is_causal,\n KV_REPEAT_INTERLEAVE, \n N, \n T_DST, \n T_SRC, \n int(B_DST), \n int(B_SRC), \n HID, \n SPARQ_HID, \n SPARQ_HID // 2 if SPARQ_HID > 16 else SPARQ_HID,\n N_COMPLETED,\n min(n_iteration, int(os.getenv('HIP_DEBUG_LIMIT_N_ITER', '99999999'))),\n *stride_keys_vllm,\n VLLM_NUM_BLOCKS,\n VLLM_NUM_KV_HEADS,\n VLLM_HEAD_SIZE_X,\n VLLM_BLOCK_SIZE,\n VLLM_X,\n VLLM_HEAD_SIZE,\n block_tables, *block_tables_stride,\n context_length, *context_length_stride,\n ROPE_METHOD,\n ROPE_COS, *rope_cos_stride,\n ROPE_SIN, *rope_sin_stride,\n POSITION_IDS, *position_ids_stride,\n SELF_EXTEND_SCALE,\n SELF_EXTEND_WINDOW,\n maximum_ks, *maximum_ks_stride,\n selected_max_k,\n USING_SCORE_CACHE,\n KEY_CACHE_METHOD,\n SPARQ,\n REDUCE_METHOD,\n _BLOCK_MASK_K,\n next_multiple_of(_BLOCK_MASK_K),\n _BLOCK_TMASK_K,\n next_multiple_of(_BLOCK_TMASK_K),\n _BLOCK_MASK_K // 2,\n next_multiple_of(_BLOCK_MASK_K // 2),\n _BLOCK_TMASK_K // 2,\n next_multiple_of(_BLOCK_TMASK_K // 2),\n triton.next_power_of_2(math.ceil(scale_up)),\n int(BLOCK_HID),\n int(BLOCK_SIZE_Q),\n next_multiple_of(triton.cdiv(BLOCK_SIZE_Q, REDUCE_STRIDE), 16),\n int(BLOCK_SIZE_K),\n next_multiple_of(BLOCK_SIZE_K, 1),\n REDUCE_STRIDE,\n SAMPLING_METHOD,\n GRID_SRC_STRIDE,\n GRID_K_STRIDE,\n USING_SLIDING_WINDOW,\n SLIDING_WINDOW_SIZE,\n num_warps=8,\n num_stages=2,\n )\n else:\n _masking_iteration_compute[grid](\n queries, *queries.stride(),\n queries_grouped,\n keys, *keys.stride(),\n attention_mask, *(attention_mask.stride() if attention_mask is not None else (0, 0)),\n sparq_indices, *sparq_indices_strides,\n mask, *mask.stride(),\n t_mask, *t_mask.stride(),\n ws, *ws.stride(),\n ks, *ks.stride(),\n ws_out, *ws_out.stride(),\n ks_out, *ks_out.stride(),\n t_srcs, *t_srcs.stride(),\n scores, *(scores.stride() if scores is not None else (0, 0, 0)),\n float(scale_up), \n int(triton.cdiv(n_patches, GRID_K_STRIDE)), \n int(mask.shape[-1]), \n int(t_mask.shape[-1]), \n is_causal,\n KV_REPEAT_INTERLEAVE, \n N, \n T_DST, \n T_SRC, \n int(B_DST), \n int(B_SRC), \n HID, \n SPARQ_HID, \n SPARQ_HID // 2 if SPARQ_HID > 16 else SPARQ_HID,\n N_COMPLETED,\n min(n_iteration, int(os.getenv('HIP_DEBUG_LIMIT_N_ITER', '99999999'))),\n *stride_keys_vllm,\n VLLM_NUM_BLOCKS,\n VLLM_NUM_KV_HEADS,\n VLLM_HEAD_SIZE_X,\n VLLM_BLOCK_SIZE,\n VLLM_X,\n VLLM_HEAD_SIZE,\n block_tables, *block_tables_stride,\n context_length, *context_length_stride,\n ROPE_METHOD,\n ROPE_COS, *rope_cos_stride,\n ROPE_SIN, *rope_sin_stride,\n POSITION_IDS, *position_ids_stride,\n SELF_EXTEND_SCALE,\n SELF_EXTEND_WINDOW,\n maximum_ks, *maximum_ks_stride,\n 0,\n USING_SCORE_CACHE,\n KEY_CACHE_METHOD,\n SPARQ,\n REDUCE_METHOD,\n BLOCK_MASK_K,\n next_multiple_of(BLOCK_MASK_K),\n BLOCK_TMASK_K,\n next_multiple_of(BLOCK_TMASK_K),\n BLOCK_MASK_K // 2,\n next_multiple_of(BLOCK_MASK_K // 2),\n BLOCK_TMASK_K // 2,\n next_multiple_of(BLOCK_TMASK_K // 2),\n triton.next_power_of_2(math.ceil(scale_up)),\n int(BLOCK_HID),\n int(BLOCK_SIZE_Q),\n next_multiple_of(triton.cdiv(BLOCK_SIZE_Q, REDUCE_STRIDE), 16),\n int(BLOCK_SIZE_K),\n next_multiple_of(BLOCK_SIZE_K, 1),\n REDUCE_STRIDE,\n SAMPLING_METHOD,\n GRID_SRC_STRIDE,\n GRID_K_STRIDE,\n USING_SLIDING_WINDOW,\n SLIDING_WINDOW_SIZE,\n num_warps=8,\n num_stages=2,\n )\n torch.cuda.set_device(orig_device)\n \n ks_out = ks_out.sum(-1)\n \n if GRID_SRC_STRIDE > 1:\n mask = mask.flatten(-2, -1)\n mask = mask.sort(dim=-1).values\n else:\n mask = mask.flatten(-2, -1)\n \n return mask, ws_out, ks_out\n", - "description_1": "Use triton language to implement kernels for computing ascending k-th values in a sorted list (_triton_kth_ascending) and for iterating through masked operations to efficiently compute scoring (_masking_iteration_topk) as well as iterate through the computations for a given context (_masking_iteration_compute).", - "description_2": "Use triton language to implement efficient top-k masking iteration and context-based computation for neural network operations.", - "difficulty": 4 - }, - { - "code": "import triton\nimport triton.language as tl\nimport torch\nfrom torch import Tensor\n\n@triton.jit\ndef _safe_indices_compute(\n MASK, stride_mask_n, stride_mask_tdst, stride_mask_k,\n WS, stride_ws_n, stride_ws_tdst, stride_ws_k,\n INDICES, stride_indices_n, stride_indices_tdst, stride_indices_k,\n N, TDST, K: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,\n ALLOW_COLLISION: tl.constexpr,\n BLOCK_N_TDST: tl.constexpr,\n BLOCK_K: tl.constexpr,\n COLLISION_METHOD: tl.constexpr = 'biased',\n):\n if not ALLOW_COLLISION:\n pids = tl.program_id(0) * BLOCK_N_TDST + tl.arange(0, BLOCK_N_TDST)\n idx_n = pids // TDST\n mask_n = idx_n < N\n idx_tdst = pids % TDST\n mask_tdst = idx_tdst < TDST\n mask = mask_n & mask_tdst\n \n if COLLISION_METHOD == 'biased':\n last_col = tl.zeros((BLOCK_N_TDST, ), dtype=tl.int64) - 1\n for _idx_k in range(K):\n mask_vec = tl.load(\n MASK +\\\n idx_n * stride_mask_n +\\\n idx_tdst * stride_mask_tdst +\\\n _idx_k * stride_mask_k,\n mask = mask,\n other = 0\n )\n ws_vec = tl.load(\n WS +\\\n idx_n * stride_ws_n +\\\n idx_tdst * stride_ws_tdst +\\\n _idx_k * stride_ws_k,\n mask = mask,\n other = 0\n )\n indices_float = mask_vec * ws_vec\n col = tl.math.ceil(indices_float / BLOCK_SIZE_K).to(tl.int32)\n col = tl.maximum(last_col + 1, col)\n last_col = col\n col = col * BLOCK_SIZE_K\n tl.store(\n INDICES +\\\n idx_n * stride_indices_n +\\\n idx_tdst * stride_indices_tdst +\\\n _idx_k * stride_indices_k,\n value = col,\n mask = mask\n )\n\ndef safe_indices(mask: Tensor, ws, block_size_k, allow_collision=False):\n N, TDST, K = mask.shape\n ws = ws.unsqueeze(-1).expand(N, TDST, K)\n indices = torch.empty((N, TDST, K), dtype=torch.int32, device=mask.device)\n BLOCK_N_TDST = 32\n BLOCK_K = 128\n\n if not allow_collision:\n grid = (triton.cdiv(N*TDST, BLOCK_N_TDST), )\n else:\n grid = (triton.cdiv(K, BLOCK_K), triton.cdiv(N*TDST, BLOCK_N_TDST), )\n\n orig_device = torch.cuda.current_device()\n torch.cuda.set_device(mask.device)\n _safe_indices_compute[grid](\n mask, *mask.stride(),\n ws, *ws.stride(),\n indices, *indices.stride(),\n N, TDST, K, block_size_k,\n allow_collision,\n BLOCK_N_TDST,\n BLOCK_K,\n num_warps=4 if allow_collision else 1,\n )\n torch.cuda.set_device(orig_device)\n return indices\n", - "description_1": "Use triton language to implement a kernel function '_safe_indices_compute' that calculates safe indices for tensors, avoiding collision if specified. It takes in multiple tensor strides and dimensions as parameters and outputs the computed indices. The wrapper function 'safe_indices' sets up tensor dimensions, grids for kernel execution, and invokes the kernel function with the appropriate configuration.", - "description_2": "Use triton language to create a kernel that computes safe indices with optional collision handling and a Python function to invoke this kernel, setting up necessary configurations.", - "difficulty": 4 - }, - { - "code": "import triton\nimport triton.language as tl\nimport math\nimport torch\nfrom torch import Tensor\nfrom torch.autograd import Function\n\n@triton.jit\ndef _triton_kth_large(\n scores: tl.tensor, k: tl.tensor,\n BLOCK_SCORES: tl.constexpr,\n) -> tl.tensor:\n sorted_score = tl.sort(scores)\n sorted_score_mask = tl.arange(0, BLOCK_SCORES) < k\n return tl.max(sorted_score * sorted_score_mask + (-32000.0) * (~sorted_score_mask))\n\n@triton.jit\ndef _masking_iteration_compute(\n queries, stride_queries_n, stride_queries_tdst, stride_queries_hid,\n keys, stride_keys_n, stride_keys_tsrc, stride_keys_hid,\n mask, stride_mask_n, stride_mask_tdst, stride_mask_k,\n tmask, stride_tmask_n, stride_tmask_tdst, stride_tmask_k,\n scores_out, stride_scores_out_n, stride_scores_out_tdst, stride_scores_out_k,\n ws, stride_ws_n, stride_ws_tdst,\n ks, stride_ks_n, stride_ks_tdst,\n tsrcs, stride_tsrcs_n, stride_tsrcs_tdst,\n scale_up: float, n_patches: int, mask_k: int,\n N, T_DST, T_SRC, HID,\n GROUP_N,\n GROUP_TDST,\n BLOCK_MASK_K: tl.constexpr, \n BLOCK_TMASK_K: tl.constexpr, \n BLOCK_MAX_DUP: tl.constexpr,\n BLOCK_HID: tl.constexpr,\n):\n pid_n = tl.program_id(0)\n for _idx_n in range(GROUP_N):\n idx_n = _idx_n + GROUP_N * pid_n\n if idx_n < N:\n pid_tdst = tl.program_id(1)\n for _idx_tdst in range(GROUP_TDST):\n idx_tdst = pid_tdst * GROUP_TDST + _idx_tdst\n if idx_tdst < T_DST:\n w_old = tl.load(\n ws + \\\n idx_n * stride_ws_n + \\\n idx_tdst * stride_ws_tdst,\n )\n t_src = tl.load(\n tsrcs + \\\n idx_n * stride_tsrcs_n + \\\n idx_tdst * stride_tsrcs_tdst,\n )\n w_new = tl.minimum(\n tl.math.round(w_old.to(tl.float32) * scale_up.to(tl.float32)).to(tl.float32), \n t_src\n ).to(tl.int64)\n if w_old != w_new:\n k_old = tl.load(\n ks + \\\n idx_n * stride_ks_n +\\\n idx_tdst * stride_ks_tdst,\n ).to(tl.int64)\n k_new = tl.maximum(\n n_patches, \n (\n tl.minimum(\n mask_k.to(tl.float32) / t_src.to(tl.float32), \n 1.0\n ) * w_new.to(tl.float32)\n ).to(tl.int64)\n )\n k_new = tl.minimum(t_src, tl.maximum(n_patches, k_new))\n k_old_range = tl.arange(0, BLOCK_MASK_K)\n k_old_mask = k_old_range < k_old\n loc_vec = tl.load(\n mask +\\\n idx_n * stride_mask_n +\\\n idx_tdst * stride_mask_tdst +\\\n k_old_range * stride_mask_k,\n mask = k_old_mask,\n other = 0\n )\n loc_idx_start_vec = (loc_vec * w_old).to(tl.int64)\n loc_idx_end_vec = loc_idx_start_vec + 1\n loc_idx_start_vec = (loc_idx_start_vec.to(tl.float32) / w_old.to(tl.float32) * w_new.to(tl.float32)).to(tl.int64)\n loc_idx_end_vec = (loc_idx_end_vec.to(tl.float32) / w_old.to(tl.float32) * w_new.to(tl.float32)).to(tl.int64)\n dup_pixels_vec = loc_idx_end_vec - loc_idx_start_vec\n dup_pixels_vec = dup_pixels_vec * k_old_mask\n num_pixels_vec = tl.cumsum(dup_pixels_vec)\n dup_pixels_first = tl.min(num_pixels_vec)\n num_pixels_scalar = tl.max(num_pixels_vec)\n dup_pixels_range = tl.arange(0, BLOCK_MAX_DUP)\n dup_pixels_mask = (dup_pixels_range[None, :] <= dup_pixels_vec[:, None]) & k_old_mask[:, None]\n tl.store(\n tmask + \\\n idx_n * stride_tmask_n +\\\n idx_tdst * stride_tmask_tdst +\\\n ((num_pixels_vec - dup_pixels_first)[:, None] + dup_pixels_range[None, :]) * stride_tmask_k,\n mask=dup_pixels_mask,\n value=(\n (loc_idx_start_vec[:, None] + tl.arange(0, BLOCK_MAX_DUP)[None, :]).to(tl.float32) / w_new.to(tl.float32)\n )\n )\n if k_new < num_pixels_scalar and True:\n scores = tl.zeros((BLOCK_TMASK_K,), dtype=tl.float32)\n for _idx_hid in range(tl.cdiv(HID, BLOCK_HID)):\n hid_range = tl.arange(0, BLOCK_HID) + _idx_hid * BLOCK_HID\n hid_mask = hid_range < HID\n vec_q = tl.load(\n queries +\\\n idx_n * stride_queries_n +\\\n idx_tdst * stride_queries_tdst +\\\n (hid_range[None, :] + tl.arange(0, 16)[:, None]) * stride_queries_hid,\n mask = (hid_mask[None, :] & (tl.arange(0, 16)[:, None] < 1)),\n other = 0,\n )\n num_pixels_range = tl.arange(0, BLOCK_TMASK_K)\n num_pixels_mask = num_pixels_range < num_pixels_scalar\n loc_k_vec = tl.load(\n tmask +\\\n idx_n * stride_tmask_n +\\\n idx_tdst * stride_tmask_tdst +\\\n num_pixels_range * stride_tmask_k,\n mask = num_pixels_mask,\n other = 0,\n )\n loc_k_vec = (loc_k_vec.to(tl.float32) * t_src.to(tl.float32)).to(tl.int64)\n vec_k_mask = num_pixels_mask[None, :] & hid_mask[:, None]\n vec_k = tl.load(\n keys +\\\n idx_n * stride_keys_n +\\\n loc_k_vec[None, :] * stride_keys_tsrc + \\\n hid_range[:, None] * stride_keys_hid,\n mask = vec_k_mask,\n other = 0,\n )\n scores_partial = -tl.dot(vec_q, vec_k, allow_tf32=True)\n scores_partial = tl.sum(scores_partial, axis=0)\n scores_partial = scores_partial + (~num_pixels_mask) * 32000.0\n scores += scores_partial.to(scores.dtype)\n masked_scores = scores\n scores_kth_large = _triton_kth_large(masked_scores, k_new, BLOCK_TMASK_K)\n topk_mask = masked_scores <= scores_kth_large\n topk_mask_cumsum = tl.cumsum(topk_mask.to(tl.int64))\n topk_range = tl.minimum((topk_mask_cumsum - 1) * topk_mask, k_new - 1)\n temp_range = tl.arange(0, BLOCK_TMASK_K)\n temp_mask = temp_range < num_pixels_scalar\n temp = tl.load(\n tmask +\\\n idx_n * stride_tmask_n +\\\n idx_tdst * stride_tmask_tdst +\\\n temp_range * stride_tmask_k,\n mask=temp_mask,\n other=0\n )\n tl.store(\n mask +\\\n idx_n * stride_mask_n +\\\n idx_tdst * stride_mask_tdst +\\\n topk_range * stride_mask_k,\n mask=topk_mask & temp_mask,\n value=temp,\n )\n else:\n temp1_range = tl.arange(0, BLOCK_MASK_K)\n temp1_mask = temp1_range < num_pixels_scalar\n temp1 = tl.load(\n tmask +\\\n idx_n * stride_tmask_n +\\\n idx_tdst * stride_tmask_tdst +\\\n temp1_range * stride_tmask_k,\n mask=temp1_mask,\n )\n tl.store(\n mask +\\\n idx_n * stride_mask_n +\\\n idx_tdst * stride_mask_tdst +\\\n temp1_range * stride_mask_k,\n mask=temp1_mask,\n value=temp1,\n )\n tl.store(\n ws +\\\n idx_n * stride_ws_n +\\\n idx_tdst * stride_ws_tdst,\n value = w_new\n )\n tl.store(\n ks +\\\n idx_n * stride_ks_n +\\\n idx_tdst * stride_ks_tdst,\n value = tl.minimum(k_new, num_pixels_scalar)\n )\n\ndef masking_iteration(\n queries: Tensor, keys: Tensor, mask: Tensor, t_mask: Tensor, scores: Tensor, \n ws: Tensor, ks: Tensor, t_srcs: Tensor, \n scale_up: float, n_patches: int, mask_k: int, \n N: int, T_DST: int, T_SRC: int, HID: int,\n):\n GROUP_N = 1\n GROUP_TDST = 4\n BLOCK_HID = 16\n grid = (triton.cdiv(N, GROUP_N), triton.cdiv(T_DST, GROUP_TDST))\n \n _masking_iteration_compute[grid](\n queries, queries.stride(0), queries.stride(1), queries.stride(2),\n keys, keys.stride(0), keys.stride(1), keys.stride(2),\n mask, mask.stride(0), mask.stride(1), mask.stride(2),\n t_mask, t_mask.stride(0), t_mask.stride(1), t_mask.stride(2),\n scores, scores.stride(0), scores.stride(1), scores.stride(2),\n ws, ws.stride(0), ws.stride(1),\n ks, ks.stride(0), ks.stride(1),\n t_srcs, t_srcs.stride(0), t_srcs.stride(1),\n float(scale_up), int(n_patches), int(mask_k),\n N, T_DST, T_SRC, HID,\n GROUP_N,\n GROUP_TDST,\n triton.next_power_of_2(mask.shape[-1]),\n triton.next_power_of_2(t_mask.shape[-1]),\n triton.next_power_of_2(math.ceil(scale_up)),\n BLOCK_HID,\n num_warps=4,\n num_stages=1,\n enable_warp_specialization=True,\n )\n", - "description_1": "Use triton language to perform efficient computations on tensors. The provided Triton kernels implement operations like sorting and masking of large tensors using parallel programming models. This includes finding the k-th largest value in a tensor and iteratively computing masked attention matrices. Each function is carefully constructed to handle specific data dimensions and parallel workloads efficiently, with the _masking_iteration_compute function focusing on masking and updating query scores based on input matrices and operational parameters.", - "description_2": "Use triton language to implement parallel masking iteration and kth largest value computation on tensors efficiently.", - "difficulty": 4 - }, - { - "code": "import triton\nimport triton.language as tl\nimport torch\nfrom torch import Tensor\nfrom typing import Optional\n\n@triton.jit\ndef masking_iteration_draft_cuda_initialize(\n # in\n INDICES_SEED, \n stride_indices_seed_b, \n stride_indices_seed_bdst, \n stride_indices_seed_bk,\n KS_SEED,\n stride_ks_seed_b,\n stride_ks_seed_bdst,\n POS, stride_pos_tdst,\n \n # out\n INDICES, stride_indices_b, stride_indices_bdst, stride_indices_bk,\n KS, stride_ks_b, stride_ks_bdst,\n GROUP_SIZE, stride_group_size_b, stride_group_size_bdst, stride_group_size_bk,\n \n # temp\n T_GROUP_SIZE, stride_t_group_size_b, stride_t_group_size_bdst,\n \n # param\n mask_k: int,\n block_size_q: tl.constexpr,\n block_size_k: tl.constexpr,\n \n sliding_window_size: int,\n \n G, MAX_TDST, MAX_TSRC, \n \n BLOCK_MASK_BLOCK_K: tl.constexpr,\n):\n idx_b = tl.program_id(0)\n idx_bdst = tl.program_id(1)\n idx_group = tl.program_id(2)\n idx_tdst = tl.arange(0, block_size_q) + idx_bdst * block_size_q\n mask_tdst = idx_tdst < MAX_TDST\n \n mask_block_k = tl.cdiv(mask_k, block_size_k)\n pos_tdst = tl.load(\n POS +\\\n idx_tdst * stride_pos_tdst,\n mask=mask_tdst,\n )\n TSRC = tl.max(pos_tdst)\n TSRC = tl.maximum(0, TSRC - sliding_window_size)\n BSRC = tl.cdiv(TSRC, block_size_k)\n MAX_BSRC = tl.cdiv(MAX_TSRC, block_size_k)\n \n if TSRC <= mask_k:\n idx_bk = tl.arange(0, BLOCK_MASK_BLOCK_K)\n mask_bk = idx_bk < BSRC\n tl.store(\n INDICES +\\\n idx_b * stride_indices_b +\\\n idx_bdst * stride_indices_bdst +\\\n (idx_group * BSRC + idx_bk) * stride_indices_bk,\n value = idx_group * MAX_BSRC + idx_bk,\n mask = mask_bk,\n )\n \n if idx_group == 0:\n tl.store(\n KS +\\\n idx_b * stride_ks_b +\\\n idx_bdst * stride_ks_bdst,\n value = BSRC * G\n )\n else:\n idx_bk = tl.arange(0, BLOCK_MASK_BLOCK_K)\n mask_bk = idx_bk < mask_block_k\n \n ks = 0\n if KS_SEED is not None:\n ks = tl.load(\n KS_SEED +\\\n idx_b * stride_ks_seed_b +\\\n idx_bdst * stride_ks_seed_bdst,\n )\n \n indices = (MAX_BSRC * idx_group + (BSRC / mask_block_k * idx_bk)).to(tl.int32)\n group_sizes = tl.minimum(\n BSRC, \n (\n BSRC / mask_block_k * (idx_bk + 1).to(tl.int32) -\\\n (BSRC / mask_block_k * idx_bk).to(tl.int32)\n )\n ).to(tl.int32)\n if INDICES_SEED is not None:\n if ks == (mask_block_k * G):\n indices = tl.load(\n INDICES_SEED +\\\n idx_b * stride_indices_seed_b +\\\n idx_bdst * stride_indices_seed_bdst +\\\n (idx_group * mask_block_k + idx_bk) * stride_indices_seed_bk,\n mask=mask_bk,\n other=idx_group * MAX_BSRC,\n )\n indices_next = tl.load(\n INDICES_SEED +\\\n idx_b * stride_indices_seed_b +\\\n idx_bdst * stride_indices_seed_bdst +\\\n (idx_group * mask_block_k + idx_bk + 1) * stride_indices_seed_bk,\n mask=(\n mask_bk &\n ((idx_group * mask_block_k + idx_bk + 1) < (BLOCK_MASK_BLOCK_K * G))\n ),\n other=G * MAX_BSRC,\n )\n indices_group_id = indices // MAX_BSRC\n indices_next_group_id = indices_next // MAX_BSRC\n group_sizes = tl.where(\n indices_group_id == indices_next_group_id,\n indices_next - indices,\n indices_group_id * MAX_BSRC + BSRC - indices,\n ).to(tl.int32)\n \n tl.store(\n INDICES +\\\n idx_b * stride_indices_b +\\\n idx_bdst * stride_indices_bdst +\\\n (idx_group * mask_block_k + idx_bk) * stride_indices_bk,\n value=indices,\n mask=mask_bk,\n )\n tl.store(\n GROUP_SIZE +\\\n idx_b * stride_group_size_b +\\\n idx_bdst * stride_group_size_bdst +\\\n (idx_group * mask_block_k + idx_bk) * stride_group_size_bk,\n value=group_sizes,\n mask=mask_bk,\n )\n \n tl.atomic_max(\n T_GROUP_SIZE +\\\n idx_b * stride_t_group_size_b +\\\n idx_bdst * stride_t_group_size_bdst,\n # val = tl.max(group_sizes)\n val = tl.minimum(tl.max(group_sizes), tl.cdiv(BSRC, mask_block_k))\n )\n tl.atomic_add(\n KS +\\\n idx_b * stride_ks_b +\\\n idx_bdst * stride_ks_bdst,\n val = mask_block_k\n )\n\ndef masking_iteration_draft( \n q: Tensor,\n k: Tensor,\n position_ids: Tensor,\n mask_k: int,\n block_size_q: int,\n block_stride_q: int,\n block_size_k: int,\n block_size_k_group: int,\n sliding_window_size: int,\n sink_token_size: int,\n using_extend: bool,\n rope_cos: Optional[Tensor],\n rope_sin: Optional[Tensor],\n self_extend_neighboor_window: int,\n self_extend_group_size: int,\n topk_head_group_size: int,\n sample_method: str,\n branch_method: str,\n score_head_group_size: int,\n sparq_ind: Optional[Tensor],\n \n indices_seed: Optional[Tensor] = None,\n ks_seed: Optional[Tensor] = None,\n scores_seed: Optional[Tensor] = None,\n group_size_seed: Optional[Tensor] = None,\n):\n assert q.device == k.device\n assert isinstance(q, Tensor)\n assert isinstance(k, Tensor)\n \n if rope_cos is not None:\n assert rope_cos.ndim == 2\n assert rope_cos.shape[-1] == q.shape[-1]\n assert isinstance(rope_cos, Tensor)\n \n if rope_sin is not None:\n assert rope_sin.ndim == 2\n assert rope_sin.shape[-1] == q.shape[-1]\n assert isinstance(rope_sin, Tensor)\n assert isinstance(rope_sin, Tensor)\n \n N, TDST, HID = q.shape\n _, TSRC, _ = k.shape\n BDST = (TDST + block_size_q - 1) // block_size_q\n BSRC = (TSRC + block_size_k - 1) // block_size_k\n \n assert (N % topk_head_group_size) == 0, 'batch * n_head should divisible by head group size'\n \n # split batch-head dim into head groups\n q = q.view(N // topk_head_group_size, topk_head_group_size, TDST, HID)\n k = k.view(N // topk_head_group_size, topk_head_group_size, TSRC, HID)\n \n B, G, TDST, HID = q.shape\n _, _, TSRC, _ = k.shape\n mask_block_k = (mask_k + block_size_k - 1) // block_size_k\n \n assert block_size_k_group == 1\n if block_size_k_group > 1:\n k_group = k.view(B, G, (TSRC + block_size_k_group - 1) // block_size_k_group, block_size_k_group, HID)\n k_group_min = torch.min(k_group, dim=-2)\n k_group_max = torch.max(k_group, dim=-2)\n k = torch.concat([k_group_min, k_group_max], dim=-1)\n del block_size_k_group\n \n indices = torch.full(\n (\n B,\n (TDST + block_size_q - 1) // block_size_q, \n # head group is merged as single sequence\n G * mask_block_k,\n ), \n fill_value=(BSRC + block_size_k + block_size_q) * G, \n dtype=torch.int32, \n device=q.device\n )\n \n ks = torch.zeros((\n B, \n (TDST + block_size_q - 1) // block_size_q,\n ), dtype=torch.int32, device=q.device)\n \n group_sizes = torch.empty_like(indices)\n t_group_sizes = torch.empty((B, BDST), dtype=torch.float32, device=q.device)\n \n if sparq_ind is None:\n using_sparq = False\n sparq_hid = 0\n else:\n using_sparq = True\n sparq_hid = sparq_ind.shape[-1]\n assert sparq_ind.ndim == 4\n \n assert len(q.stride()) == 4\n assert len(k.stride()) == 4\n assert len(indices.stride()) == 3\n assert len(ks.stride()) == 2\n assert len(group_sizes.stride()) == 3\n assert len(t_group_sizes.stride()) == 2\n if indices_seed is not None:\n assert len(indices_seed.stride()) == 3\n assert len(ks_seed.stride()) == 2\n assert indices_seed.shape == indices.shape\n assert ks_seed.shape == ks.shape\n indices_seed = indices_seed // block_size_k\n if rope_cos is not None:\n assert len(rope_cos.stride()) == 2\n assert len(rope_sin.stride()) == 2\n \n assert sample_method in ['first', 'last', 'random', 'oracle', 'center']\n assert position_ids.ndim == 1\n \n # launch kernels\n BLOCK_MASK_BLOCK_K = triton.next_power_of_2(mask_block_k)\n grid = (B, BDST, G)\n masking_iteration_draft_cuda_initialize[grid](\n indices_seed, *(indices_seed.stride() if indices_seed is not None else (0, 0, 0)),\n ks_seed, *(ks_seed.stride() if ks_seed is not None else (0, 0)),\n position_ids, *position_ids.stride(),\n \n indices, *indices.stride(),\n ks, *ks.stride(),\n group_sizes, *group_sizes.stride(),\n \n t_group_sizes, *t_group_sizes.stride(),\n \n mask_k,\n block_size_q, \n block_size_k, \n \n sliding_window_size,\n \n G, TDST, TSRC, \n \n BLOCK_MASK_BLOCK_K,\n \n # num_warps=min(max((BLOCK_MASK_BLOCK_K + 32 - 1) // 32, 1), 32),\n num_warps=1,\n num_stages=1,\n )\n \n return indices, ks, None, None, None\n", - "description_1": "Use triton language to initialize data structures for block-wise sequence processing, adjusting for masking, block sizes, and constraints on sequence length.", - "description_2": "Use triton language to efficiently initialize indices and group sizes in sequence data for masked block-wise operations.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef load_tokens(\n ptr, stride_ptr_n, stride_ptr_t, stride_ptr_hid, \n idx_n, idx_t, mask_t, HID: tl.constexpr\n):\n return tl.load(\n ptr +\\\n idx_n * stride_ptr_n +\\\n idx_t[:, None] * stride_ptr_t +\\\n tl.arange(0, HID)[None, :] * stride_ptr_hid,\n mask = mask_t[:, None]\n )\n\n@triton.jit\ndef attention_norm_cuda(\n Q, stride_q_n, stride_q_tdst, stride_q_hid,\n K, stride_k_n, stride_k_tsrc, stride_k_hid,\n \n NORM, stride_norm_n, stride_norm_tdst,\n \n TDST, TSRC,\n \n HID: tl.constexpr,\n BLOCK_SIZE_Q: tl.constexpr,\n BLOCK_SIZE_K: tl.constexpr,\n):\n idx_n = tl.program_id(0)\n idx_bdst = tl.program_id(1)\n idx_tdst = tl.arange(0, BLOCK_SIZE_Q) + idx_bdst * BLOCK_SIZE_Q\n mask_tdst = idx_tdst < TDST\n \n q = load_tokens(\n Q, stride_q_n, stride_q_tdst, stride_q_hid, \n idx_n, idx_tdst, mask_tdst, HID\n )\n \n score_max = tl.full((BLOCK_SIZE_Q, ), dtype=tl.float32, value=float('-inf'))\n for i_tsrc in range(0, TSRC, BLOCK_SIZE_K):\n idx_tsrc = i_tsrc + tl.arange(0, BLOCK_SIZE_K)\n mask_tsrc = idx_tsrc < TSRC\n \n k = load_tokens(\n K, stride_k_n, stride_k_tsrc, stride_k_hid,\n idx_n, idx_tsrc, mask_tsrc, HID,\n )\n \n qk = tl.dot(\n q, k.trans(1, 0),\n allow_tf32=True\n ).to(tl.float32)\n \n qk = tl.where(\n idx_tsrc[None, :] <= idx_tdst[:, None],\n qk, float('-inf')\n )\n \n score_max = tl.maximum(\n score_max,\n tl.max(qk, axis=-1)\n )\n \n exp_score_sum = tl.zeros((BLOCK_SIZE_Q, ), dtype=tl.float32)\n for i_tsrc in range(0, TSRC, BLOCK_SIZE_K):\n idx_tsrc = i_tsrc + tl.arange(0, BLOCK_SIZE_K)\n mask_tsrc = idx_tsrc < TSRC\n \n k = load_tokens(\n K, stride_k_n, stride_k_tsrc, stride_k_hid,\n idx_n, idx_tsrc, mask_tsrc, HID,\n )\n \n qk = tl.dot(\n q, k.trans(1, 0),\n allow_tf32=True\n ).to(tl.float32)\n \n qk = tl.where(\n idx_tsrc[None, :] <= idx_tdst[:, None],\n qk, float('-inf')\n )\n \n qk = qk - score_max[:, None]\n qk = tl.exp(qk)\n exp_score_sum += tl.sum(qk, axis=-1)\n \n norm_sum = tl.zeros((BLOCK_SIZE_Q, ), dtype=tl.float64)\n for i_tsrc in range(0, TSRC, BLOCK_SIZE_K):\n idx_tsrc = i_tsrc + tl.arange(0, BLOCK_SIZE_K)\n mask_tsrc = idx_tsrc < TSRC\n \n k = load_tokens(\n K, stride_k_n, stride_k_tsrc, stride_k_hid,\n idx_n, idx_tsrc, mask_tsrc, HID,\n )\n \n qk = tl.dot(\n q, k.trans(1, 0),\n allow_tf32=True\n ).to(tl.float32)\n \n qk = tl.where(\n idx_tsrc[None, :] <= idx_tdst[:, None],\n qk, float('-inf')\n )\n \n qk = qk - score_max[:, None]\n prob = tl.exp(qk) / tl.maximum(exp_score_sum[:, None], 1e-20)\n norm_sum += tl.sum(prob * prob, axis=-1)\n \n norm = tl.sqrt(norm_sum)\n \n tl.store(\n NORM +\\\n idx_n * stride_norm_n +\\\n idx_tdst * stride_norm_tdst,\n value=norm,\n mask=mask_tdst,\n )\n\ndef attention_norm(\n q: torch.Tensor,\n k: torch.Tensor,\n):\n \"\"\"\n q: fp*[N, TDST, HID]\n k: fp*[N, TSRC, HID]\n \n # return\n norm: fp32[N, TDST]\n \"\"\"\n assert q.ndim == 3\n assert q.shape == k.shape\n \n N, TDST, HID = q.shape\n _, TSRC, _ = k.shape\n \n norm = torch.zeros((N, TDST), dtype=torch.float32, device=q.device)\n \n BLOCK_SIZE_Q = 32\n BLOCK_SIZE_K = 64\n \n grid = (N, triton.cdiv(TDST, BLOCK_SIZE_Q))\n \n pre_device = torch.get_default_device()\n torch.set_default_device(q.device)\n attention_norm_cuda[grid](\n q, *q.stride(),\n k, *k.stride(),\n norm, *norm.stride(),\n \n TDST, TSRC,\n \n q.shape[-1],\n BLOCK_SIZE_Q, \n BLOCK_SIZE_K,\n \n num_warps=4,\n num_stages=2,\n )\n torch.set_default_device(pre_device)\n \n return norm\n", - "description_1": "Use triton language to define a kernel function `attention_norm_cuda` that computes the attention normalization of two input matrices, Q and K, where each element in Q and K is accessed using specific strides. The kernel includes computations for maximum scores, exponential sum of scores, and normalization sum in a block-wise manner. It also involves another kernel `load_tokens` to load matrix elements with masking. The results are stored in an output matrix NORM. The function `attention_norm` serves as a Python wrapper to set up parameters and launch the kernel on specific grid dimensions.", - "description_2": "Use triton language to implement attention normalization for given input matrices by loading data, computing dot products, applying softmax normalization, and storing results efficiently on GPUs.", - "difficulty": 3 - }, - { - "code": "import math\nimport torch\nimport triton\nimport triton.language as tl\n\n\n@triton.heuristics(\n {\n \"EVEN_M\": lambda args: args[\"seqlen_q\"] % args[\"BLOCK_M\"] == 0,\n \"EVEN_N\": lambda args: args[\"seqlen_k\"] % args[\"BLOCK_N\"] == 0,\n \"EVEN_HEADDIM\": lambda args: args[\"headdim\"] == args[\"BLOCK_HEADDIM\"],\n }\n)\n@triton.jit\ndef _fwd_kernel(\n Q, K, V, Bias, Out, Lse, TMP, softmax_scale,\n stride_qb, stride_qh, stride_qm, stride_kb, stride_kh, stride_kn,\n stride_vb, stride_vh, stride_vn, stride_bb, stride_bh, stride_bm,\n stride_ob, stride_oh, stride_om, nheads, seqlen_q, seqlen_k,\n seqlen_q_rounded, headdim, CACHE_KEY_SEQLEN_Q, CACHE_KEY_SEQLEN_K,\n BIAS_TYPE: tl.constexpr, IS_CAUSAL: tl.constexpr, BLOCK_HEADDIM: tl.constexpr,\n EVEN_M: tl.constexpr, EVEN_N: tl.constexpr, EVEN_HEADDIM: tl.constexpr,\n BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,\n):\n # Kernel code for forward pass in flash attention.\n start_m = tl.program_id(0)\n off_hb = tl.program_id(1)\n off_b = off_hb // nheads\n off_h = off_hb % nheads\n offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)\n offs_n = tl.arange(0, BLOCK_N)\n offs_d = tl.arange(0, BLOCK_HEADDIM)\n q_ptrs = Q + off_b * stride_qb + off_h * stride_qh + (offs_m[:, None] * stride_qm + offs_d[None, :])\n k_ptrs = K + off_b * stride_kb + off_h * stride_kh + (offs_n[:, None] * stride_kn + offs_d[None, :])\n v_ptrs = V + off_b * stride_vb + off_h * stride_vh + (offs_n[:, None] * stride_vn + offs_d[None, :])\n if BIAS_TYPE == \"vector\":\n b_ptrs = Bias + off_b * stride_bb + off_h * stride_bh + offs_n\n elif BIAS_TYPE == \"matrix\":\n b_ptrs = Bias + off_b * stride_bb + off_h * stride_bh + (offs_m[:, None] * stride_bm + offs_n[None, :])\n t_ptrs = TMP + off_hb * seqlen_q_rounded + offs_m\n lse_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float(\"inf\")\n m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float(\"inf\")\n acc_o = tl.zeros([BLOCK_M, BLOCK_HEADDIM], dtype=tl.float32)\n if EVEN_M & EVEN_N:\n if EVEN_HEADDIM:\n q = tl.load(q_ptrs)\n else:\n q = tl.load(q_ptrs, mask=offs_d[None, :] < headdim, other=0.0)\n else:\n if EVEN_HEADDIM:\n q = tl.load(q_ptrs, mask=offs_m[:, None] < seqlen_q, other=0.0)\n else:\n q = tl.load(q_ptrs, mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim), other=0.0)\n end_n = seqlen_k if not IS_CAUSAL else tl.minimum((start_m + 1) * BLOCK_M, seqlen_k)\n for start_n in range(0, end_n, BLOCK_N):\n start_n = tl.multiple_of(start_n, BLOCK_N)\n if EVEN_N & EVEN_M:\n if EVEN_HEADDIM:\n k = tl.load(k_ptrs + start_n * stride_kn)\n else:\n k = tl.load(k_ptrs + start_n * stride_kn, mask=offs_d[None, :] < headdim, other=0.0)\n else:\n if EVEN_HEADDIM:\n k = tl.load(k_ptrs + start_n * stride_kn, mask=(start_n + offs_n)[:, None] < seqlen_k, other=0.0)\n else:\n k = tl.load(k_ptrs + start_n * stride_kn, mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim), other=0.0)\n qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)\n qk += tl.dot(q, tl.trans(k))\n if not EVEN_N:\n qk += tl.where((start_n + offs_n)[None, :] < seqlen_k, 0, float(\"-inf\"))\n if IS_CAUSAL:\n qk += tl.where(offs_m[:, None] >= (start_n + offs_n)[None, :], 0, float(\"-inf\"))\n if BIAS_TYPE != \"none\":\n if BIAS_TYPE == \"vector\":\n if EVEN_N:\n bias = tl.load(b_ptrs + start_n).to(tl.float32)\n else:\n bias = tl.load(b_ptrs + start_n, mask=(start_n + offs_n) < seqlen_k, other=0.0).to(tl.float32)\n bias = bias[None, :]\n elif BIAS_TYPE == \"matrix\":\n if EVEN_M & EVEN_N:\n bias = tl.load(b_ptrs + start_n).to(tl.float32)\n else:\n bias = tl.load(b_ptrs + start_n, mask=(offs_m[:, None] < seqlen_q) & ((start_n + offs_n)[None, :] < seqlen_k), other=0.0).to(tl.float32)\n qk = qk * softmax_scale + bias\n m_ij = tl.maximum(tl.max(qk, 1), lse_i)\n p = tl.exp(qk - m_ij[:, None])\n else:\n m_ij = tl.maximum(tl.max(qk, 1) * softmax_scale, lse_i)\n p = tl.exp(qk * softmax_scale - m_ij[:, None])\n l_ij = tl.sum(p, 1)\n acc_o_scale = tl.exp(m_i - m_ij)\n tl.store(t_ptrs, acc_o_scale)\n acc_o_scale = tl.load(t_ptrs)\n acc_o = acc_o * acc_o_scale[:, None]\n if EVEN_N & EVEN_M:\n if EVEN_HEADDIM:\n v = tl.load(v_ptrs + start_n * stride_vn)\n else:\n v = tl.load(v_ptrs + start_n * stride_vn, mask=offs_d[None, :] < headdim, other=0.0)\n else:\n if EVEN_HEADDIM:\n v = tl.load(v_ptrs + start_n * stride_vn, mask=(start_n + offs_n)[:, None] < seqlen_k, other=0.0)\n else:\n v = tl.load(v_ptrs + start_n * stride_vn, mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim), other=0.0)\n p = p.to(v.dtype)\n acc_o += tl.dot(p, v)\n m_i = m_ij\n l_i_new = tl.exp(lse_i - m_ij) + l_ij\n lse_i = m_ij + tl.log(l_i_new)\n o_scale = tl.exp(m_i - lse_i)\n tl.store(t_ptrs, o_scale)\n o_scale = tl.load(t_ptrs)\n acc_o = acc_o * o_scale[:, None]\n start_m = tl.program_id(0)\n offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)\n lse_ptrs = Lse + off_hb * seqlen_q_rounded + offs_m\n tl.store(lse_ptrs, lse_i)\n offs_d = tl.arange(0, BLOCK_HEADDIM)\n out_ptrs = Out + off_b * stride_ob + off_h * stride_oh + (offs_m[:, None] * stride_om + offs_d[None, :])\n if EVEN_M:\n if EVEN_HEADDIM:\n tl.store(out_ptrs, acc_o)\n else:\n tl.store(out_ptrs, acc_o, mask=offs_d[None, :] < headdim)\n else:\n if EVEN_HEADDIM:\n tl.store(out_ptrs, acc_o, mask=offs_m[:, None] < seqlen_q)\n else:\n tl.store(out_ptrs, acc_o, mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim))\n\n\n@triton.jit\ndef _bwd_preprocess_do_o_dot(\n Out, DO, Delta, stride_ob, stride_oh, stride_om,\n stride_dob, stride_doh, stride_dom, nheads,\n seqlen_q, seqlen_q_rounded, headdim,\n BLOCK_M: tl.constexpr, BLOCK_HEADDIM: tl.constexpr,\n):\n # Kernel code for preprocessing in backward pass.\n start_m = tl.program_id(0)\n off_hb = tl.program_id(1)\n off_b = off_hb // nheads\n off_h = off_hb % nheads\n offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)\n offs_d = tl.arange(0, BLOCK_HEADDIM)\n o = tl.load(\n Out + off_b * stride_ob + off_h * stride_oh + offs_m[:, None] * stride_om + offs_d[None, :],\n mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim),\n other=0.0,\n ).to(tl.float32)\n do = tl.load(\n DO + off_b * stride_dob + off_h * stride_doh + offs_m[:, None] * stride_dom + offs_d[None, :],\n mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim),\n other=0.0,\n ).to(tl.float32)\n delta = tl.sum(o * do, axis=1)\n tl.store(Delta + off_hb * seqlen_q_rounded + offs_m, delta)\n\n\n@triton.jit\ndef _bwd_kernel(\n Q, K, V, Bias, DO, DQ, DK, DV, LSE, D, softmax_scale,\n stride_qb, stride_qh, stride_qm, stride_kb, stride_kh, stride_kn,\n stride_vb, stride_vh, stride_vn, stride_bb, stride_bh, stride_bm,\n stride_dob, stride_doh, stride_dom, stride_dqb, stride_dqh, stride_dqm,\n stride_dkb, stride_dkh, stride_dkn, stride_dvb, stride_dvh, stride_dvn,\n nheads, seqlen_q, seqlen_k, seqlen_q_rounded, headdim, CACHE_KEY_SEQLEN_Q, CACHE_KEY_SEQLEN_K,\n BIAS_TYPE: tl.constexpr, IS_CAUSAL: tl.constexpr, BLOCK_HEADDIM: tl.constexpr,\n SEQUENCE_PARALLEL: tl.constexpr, EVEN_M: tl.constexpr, EVEN_N: tl.constexpr,\n EVEN_HEADDIM: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,\n):\n # Kernel code for the backward pass in flash attention.\n off_hb = tl.program_id(1)\n off_b = off_hb // nheads\n off_h = off_hb % nheads\n Q += off_b * stride_qb + off_h * stride_qh\n K += off_b * stride_kb + off_h * stride_kh\n V += off_b * stride_vb + off_h * stride_vh\n DO += off_b * stride_dob + off_h * stride_doh\n DQ += off_b * stride_dqb + off_h * stride_dqh\n DK += off_b * stride_dkb + off_h * stride_dkh\n DV += off_b * stride_dvb + off_h * stride_dvh\n if BIAS_TYPE != \"none\":\n Bias += off_b * stride_bb + off_h * stride_bh\n D += off_hb * seqlen_q_rounded\n LSE += off_hb * seqlen_q_rounded\n if not SEQUENCE_PARALLEL:\n num_block_n = tl.cdiv(seqlen_k, BLOCK_N)\n for start_n in range(0, num_block_n):\n _bwd_kernel_one_col_block(\n start_n, Q, K, V, Bias, DO, DQ, DK, DV, LSE, D, softmax_scale,\n stride_qm, stride_kn, stride_vn, stride_bm, stride_dom,\n stride_dqm, stride_dkn, stride_dvn,\n seqlen_q, seqlen_k, headdim, ATOMIC_ADD=False, BIAS_TYPE=BIAS_TYPE,\n IS_CAUSAL=IS_CAUSAL, BLOCK_HEADDIM=BLOCK_HEADDIM, EVEN_M=EVEN_M,\n EVEN_N=EVEN_N, EVEN_HEADDIM=EVEN_HEADDIM, BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N,\n )\n else:\n start_n = tl.program_id(0)\n _bwd_kernel_one_col_block(\n start_n, Q, K, V, Bias, DO, DQ, DK, DV, LSE, D, softmax_scale,\n stride_qm, stride_kn, stride_vn, stride_bm, stride_dom,\n stride_dqm, stride_dkn, stride_dvn,\n seqlen_q, seqlen_k, headdim, ATOMIC_ADD=True, BIAS_TYPE=BIAS_TYPE,\n IS_CAUSAL=IS_CAUSAL, BLOCK_HEADDIM=BLOCK_HEADDIM, EVEN_M=EVEN_M,\n EVEN_N=EVEN_N, EVEN_HEADDIM=EVEN_HEADDIM, BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N,\n )\n\n\ndef _flash_attn_forward(q, k, v, bias=None, causal=False, softmax_scale=None):\n batch, seqlen_q, nheads, d = q.shape\n _, seqlen_k, _, _ = k.shape\n assert k.shape == (batch, seqlen_k, nheads, d)\n assert v.shape == (batch, seqlen_k, nheads, d)\n assert d <= 128, \"FlashAttention only support head dimensions up to 128\"\n assert q.dtype == k.dtype == v.dtype, \"All tensors must have the same type\"\n assert q.dtype in [torch.float16, torch.bfloat16], \"Only support fp16 and bf16\"\n assert q.is_cuda and k.is_cuda and v.is_cuda\n softmax_scale = softmax_scale or 1.0 / math.sqrt(d)\n has_bias = bias is not None\n bias_type = \"none\"\n if has_bias:\n assert bias.dtype in [q.dtype, torch.float]\n assert bias.is_cuda\n assert bias.dim() == 4\n if bias.stride(-1) != 1:\n bias = bias.contiguous()\n if bias.shape[2:] == (1, seqlen_k):\n bias_type = \"vector\"\n elif bias.shape[2:] == (seqlen_q, seqlen_k):\n bias_type = \"matrix\"\n else:\n raise RuntimeError(\"Last 2 dimensions of bias must be (1, seqlen_k) or (seqlen_q, seqlen_k)\")\n bias = bias.expand(batch, nheads, seqlen_q, seqlen_k)\n bias_strides = (bias.stride(0), bias.stride(1), bias.stride(2)) if has_bias else (0, 0, 0)\n seqlen_q_rounded = math.ceil(seqlen_q / 128) * 128\n lse = torch.empty((batch, nheads, seqlen_q_rounded), device=q.device, dtype=torch.float32)\n tmp = torch.empty((batch, nheads, seqlen_q_rounded), device=q.device, dtype=torch.float32)\n o = torch.empty_like(q)\n BLOCK_HEADDIM = max(triton.next_power_of_2(d), 16)\n BLOCK = 128\n num_warps = 4 if d <= 64 else 8\n grid = lambda META: (triton.cdiv(seqlen_q, META[\"BLOCK_M\"]), batch * nheads)\n _fwd_kernel[grid](\n q, k, v, bias, o, lse, tmp, softmax_scale,\n q.stride(0), q.stride(2), q.stride(1),\n k.stride(0), k.stride(2), k.stride(1),\n v.stride(0), v.stride(2), v.stride(1),\n *bias_strides,\n o.stride(0), o.stride(2), o.stride(1),\n nheads, seqlen_q, seqlen_k, seqlen_q_rounded, d,\n seqlen_q // 32, seqlen_k // 32,\n bias_type, causal, BLOCK_HEADDIM, BLOCK_M=BLOCK, BLOCK_N=BLOCK, num_warps=num_warps, num_stages=1,\n )\n return o, lse, softmax_scale\n\n\ndef _flash_attn_backward(do, q, k, v, o, lse, dq, dk, dv, bias=None, causal=False, softmax_scale=None):\n if do.stride(-1) != 1:\n do = do.contiguous()\n batch, seqlen_q, nheads, d = q.shape\n _, seqlen_k, _, _ = k.shape\n assert d <= 128\n seqlen_q_rounded = math.ceil(seqlen_q / 128) * 128\n assert lse.shape == (batch, nheads, seqlen_q_rounded)\n assert q.stride(-1) == k.stride(-1) == v.stride(-1) == o.stride(-1) == 1\n assert dq.stride(-1) == dk.stride(-1) == dv.stride(-1) == 1\n softmax_scale = softmax_scale or 1.0 / math.sqrt(d)\n dq_accum = torch.empty_like(q, dtype=torch.float32)\n delta = torch.empty_like(lse)\n BLOCK_HEADDIM = max(triton.next_power_of_2(d), 16)\n grid = lambda META: (triton.cdiv(seqlen_q, META[\"BLOCK_M\"]), batch * nheads)\n _bwd_preprocess_do_o_dot[grid](\n o, do, delta,\n o.stride(0), o.stride(2), o.stride(1),\n do.stride(0), do.stride(2), do.stride(1),\n nheads, seqlen_q, seqlen_q_rounded, d, BLOCK_M=128, BLOCK_HEADDIM=BLOCK_HEADDIM,\n )\n has_bias = bias is not None\n bias_type = \"none\"\n if has_bias:\n assert bias.dtype in [q.dtype, torch.float]\n assert bias.is_cuda\n assert bias.dim() == 4\n assert bias.stride(-1) == 1\n if bias.shape[2:] == (1, seqlen_k):\n bias_type = \"vector\"\n elif bias.shape[2:] == (seqlen_q, seqlen_k):\n bias_type = \"matrix\"\n else:\n raise RuntimeError(\"Last 2 dimensions of bias must be (1, seqlen_k) or (seqlen_q, seqlen_k)\")\n bias = bias.expand(batch, nheads, seqlen_q, seqlen_k)\n bias_strides = (bias.stride(0), bias.stride(1), bias.stride(2)) if has_bias else (0, 0, 0)\n grid = lambda META: (triton.cdiv(seqlen_k, META[\"BLOCK_N\"]) if META[\"SEQUENCE_PARALLEL\"] else 1, batch * nheads)\n _bwd_kernel[grid](\n q, k, v, bias, do, dq_accum, dk, dv, lse, delta, softmax_scale,\n q.stride(0), q.stride(2), q.stride(1),\n k.stride(0), k.stride(2), k.stride(1),\n v.stride(0), v.stride(2), v.stride(1),\n *bias_strides,\n do.stride(0), do.stride(2), do.stride(1),\n dq_accum.stride(0), dq_accum.stride(2), dq_accum.stride(1),\n dk.stride(0), dk.stride(2), dk.stride(1),\n dv.stride(0), dv.stride(2), dv.stride(1),\n nheads, seqlen_q, seqlen_k, seqlen_q_rounded, d,\n seqlen_q // 32, seqlen_k // 32,\n bias_type, causal, BLOCK_HEADDIM,\n )\n dq.copy_(dq_accum)\n\n\nclass FlashAttnFunc(torch.autograd.Function):\n @staticmethod\n def forward(ctx, q, k, v, bias=None, causal=False, softmax_scale=None):\n q, k, v = [x if x.stride(-1) == 1 else x.contiguous() for x in [q, k, v]]\n o, lse, ctx.softmax_scale = _flash_attn_forward(q, k, v, bias=bias, causal=causal, softmax_scale=softmax_scale)\n ctx.save_for_backward(q, k, v, o, lse, bias)\n ctx.causal = causal\n return o, lse\n\n @staticmethod\n def backward(ctx, do, dlse_use_needed=None):\n q, k, v, o, lse, bias = ctx.saved_tensors\n dq = torch.empty_like(q)\n dk = torch.empty_like(k)\n dv = torch.empty_like(v)\n _flash_attn_backward(\n do, q, k, v, o, lse, dq, dk, dv, bias=bias, causal=ctx.causal, softmax_scale=ctx.softmax_scale,\n )\n return dq, dk, dv, None, None, None\n\n\nflash_attn_func = FlashAttnFunc.apply\n", - "description_1": "Use triton language to implement FlashAttention kernels for forward and backward pass. The forward kernel processes inputs Q, K, V, and an optional bias to produce attention outputs and log-sum-exp values. The backward kernel computes gradients with respect to Q, K, V, and optional bias, given the gradient of the output. The kernels are designed to handle different bias types, sequence lengths, and support for causal attention. The function FlashAttnFunc ties these components together to provide autograd support, taking inputs q, k, v, bias, causal, and softmax_scale, returning output and log-sum-exp, and computing appropriate gradients during the backward pass.", - "description_2": "Use triton language to implement efficient FlashAttention kernels with support for different sequence lengths, head dimensions, and optional biases for both forward and backward passes in neural network attention modules.", - "difficulty": 4 - }, - { - "code": "import triton\nimport torch\nimport triton.language as tl\n\n@triton.jit\ndef _fwd_kernel(\n Q, K, V, sm_scale, \n Out,\n sqz, sqh, sqm, sqd, \n skz, skh, skn, skd, \n svz, svh, svn, svd, \n soz, soh, som, sod, \n L, M,\n Z, H, N_CTX_Q, N_CTX_KV, \n BLOCK: tl.constexpr, \n BLOCK_DMODEL: tl.constexpr, \n N_PREFIX_Q: tl.constexpr,\n):\n start_m = tl.program_id(0) \n off_hz = tl.program_id(1)\n\n BLOCK_M: tl.constexpr = BLOCK\n BLOCK_N: tl.constexpr = BLOCK\n\n offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)\n offs_m_real = (start_m + N_PREFIX_Q) * BLOCK_M + tl.arange(0, BLOCK_M)\n offs_m_real += tl.where(tl.arange(0, BLOCK_M) == BLOCK_M - 1, -1, 0)\n offs_n = tl.arange(0, BLOCK_N)\n offs_d = tl.arange(0, BLOCK_DMODEL)\n\n offs_q = off_hz * sqh + offs_m[:, None] * sqm + offs_d[None, :]\n offs_k = off_hz * skh + offs_n[None, :] * skn + offs_d[:, None] * skd\n offs_v = off_hz * svh + offs_n[:, None] * svn + offs_d[None, :]\n\n m_prev = tl.zeros([BLOCK_M], dtype=tl.float32) - float(\"inf\")\n l_prev = tl.zeros([BLOCK_M], dtype=tl.float32)\n acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)\n\n q_vals = tl.load(Q + offs_q, mask=offs_m[:, None] < N_CTX_Q, other=0) \n\n for start_n in range(0, (N_PREFIX_Q + start_m)):\n k_vals = tl.load(K + offs_k, mask=offs_n[None, :] < N_CTX_KV, other=0)\n qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=q_vals.dtype)\n qk += tl.dot(q_vals, k_vals, allow_tf32=False)\n qk *= sm_scale\n qk = tl.where(offs_m_real[:,None] >= offs_n[None,:], qk, float(\"-inf\"))\n landmark_qk = tl.max(tl.where(tl.arange(0, BLOCK_N)[None, :] == BLOCK_N - 1, qk, float(\"-inf\")), 1)\n normal_qk = tl.where(tl.arange(0, BLOCK_N)[None, :] == BLOCK_N - 1, float(\"-inf\"), qk)\n normal_m = tl.max(normal_qk, 1)\n normal_p = tl.exp(normal_qk - normal_m[:, None])\n normal_denom = tl.sum(normal_p, 1)\n\n m_curr = tl.maximum(landmark_qk, m_prev)\n m_curr_ = m_curr\n l_prev *= tl.exp(m_prev - m_curr_)\n landmark_p = tl.exp(landmark_qk - m_curr_)\n l_curr = landmark_p + l_prev \n l_rcp = 1. / l_curr\n landmark_p *= l_rcp\n\n acc *= (l_prev * l_rcp)[:, None]\n v_vals = tl.load(V + offs_v, mask=offs_n[:, None] < N_CTX_KV, other=0)\n acc += tl.dot((landmark_p[:, None] * normal_p / normal_denom[:, None]).to(Q.dtype.element_ty), v_vals, allow_tf32=False) \n\n l_prev = l_curr\n m_prev = m_curr\n\n offs_n += BLOCK_N\n offs_k += BLOCK_N * skn\n offs_v += BLOCK_N * svn\n\n k_vals = tl.load(K + offs_k, mask=offs_n[None, :] < N_CTX_KV, other=0)\n qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=q_vals.dtype)\n qk += tl.dot(q_vals, k_vals, allow_tf32=False)\n qk *= sm_scale\n qk = tl.where(offs_m_real[:,None] >= offs_n[None,:], qk, float(\"-inf\"))\n\n m_curr = tl.maximum(tl.max(qk, 1), m_prev)\n m_curr_ = m_curr\n\n l_prev *= tl.exp(m_prev - m_curr_)\n p = tl.exp(qk - m_curr_[:, None])\n l_curr = tl.sum(p, 1) + l_prev \n\n l_rcp = 1. / l_curr\n p *= l_rcp[:, None]\n acc *= (l_prev * l_rcp)[:, None]\n p = p.to(Q.dtype.element_ty)\n v_vals = tl.load(V + offs_v, mask=offs_n[:, None] < N_CTX_KV, other=0)\n acc += tl.dot(p, v_vals, allow_tf32=False) \n\n l_prev = l_curr\n m_prev = m_curr\n\n offs_L = off_hz * N_CTX_Q + offs_m\n offs_M = off_hz * N_CTX_Q + offs_m\n tl.store(L + offs_L, l_prev, mask=offs_m < N_CTX_Q)\n tl.store(M + offs_M, m_prev, mask=offs_m < N_CTX_Q)\n offs_o = off_hz * soh + offs_m[:, None] * som + offs_d[None, :]\n tl.store(Out + offs_o, acc, mask=offs_m[:, None] < N_CTX_Q)\n\n\n@triton.jit\ndef _bwd_preprocess(\n Out, soz, soh, som, sod,\n DO, L, slzh, slm,\n NewDO, Delta, N_CTX_Q,\n BLOCK_M: tl.constexpr, D_HEAD: tl.constexpr,\n):\n start_m = tl.program_id(0)\n off_hz = tl.program_id(1)\n\n off_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)\n off_d = tl.arange(0, D_HEAD)\n off_o = off_hz * soh + off_m[:, None] * som + off_d[None, :] * sod\n off_l = off_hz * slzh + off_m * slm\n o = tl.load(Out + off_o).to(tl.float32)\n do = tl.load(DO + off_o).to(tl.float32)\n denom = tl.load(L + off_l).to(tl.float32)\n do = do / denom[:, None]\n delta = tl.sum(o * do, axis=1)\n tl.store(NewDO + off_o, do)\n tl.store(Delta + off_l, delta)\n\n\n@triton.jit\ndef _bwd_kernel(\n Q, K, V, sm_scale, Out, DO,\n DQ, DK, DV,\n L, M,\n D,\n sqz, sqh, sqm, sqd,\n skz, skh, skn, skd,\n svz, svh, svn, svd,\n Z, H, N_CTX_Q, N_CTX_KV,\n BLOCK: tl.constexpr, BLOCK_DMODEL: tl.constexpr,\n N_PREFIX_Q: tl.constexpr,\n):\n off_hz = tl.program_id(0).to(tl.int64)\n off_z = off_hz // H\n off_h = off_hz % H\n\n BLOCK_M: tl.constexpr = BLOCK\n BLOCK_N: tl.constexpr = BLOCK\n\n Q += off_z * sqz + off_h * sqh\n K += off_z * skz + off_h * skh\n V += off_z * svz + off_h * svh\n DO += off_z * sqz + off_h * sqh\n DQ += off_z * sqz + off_h * sqh\n DK += off_z * skz + off_h * skh\n DV += off_z * svz + off_h * svh\n\n offs_d = tl.arange(0, BLOCK_DMODEL).to(tl.int64)\n \n D_ptrs = D + off_hz * N_CTX_Q \n m_ptrs = M + off_hz * N_CTX_Q \n\n for start_n in range(0, N_CTX_KV, BLOCK_N):\n start_n = tl.multiple_of(start_n, BLOCK_N).to(tl.int64)\n offs_n = start_n + tl.arange(0, BLOCK_N).to(tl.int64)\n k_ptrs = K + (offs_n[:, None] * skn + offs_d[None, :] * skd)\n v_ptrs = V + (offs_n[:, None] * svn + offs_d[None, :] * svd)\n\n dv = tl.zeros([BLOCK_N, BLOCK_DMODEL], dtype=tl.float32)\n dk = tl.zeros([BLOCK_N, BLOCK_DMODEL], dtype=tl.float32)\n\n k = tl.load(k_ptrs)\n v = tl.load(v_ptrs)\n\n if start_n < N_PREFIX_Q * BLOCK_M:\n start_q_index = 0\n elif N_CTX_Q <= start_n - N_PREFIX_Q * BLOCK_M:\n start_q_index = start_n - N_PREFIX_Q * BLOCK_M\n else:\n first_start_m = start_n - N_PREFIX_Q * BLOCK_M\n first_start_m = tl.multiple_of(first_start_m, BLOCK_M)\n offs_m = (first_start_m + tl.arange(0, BLOCK_M))\n offs_m_real = offs_m + N_PREFIX_Q * BLOCK_M \n offs_m_real += tl.where(tl.arange(0, BLOCK_M) == BLOCK_M - 1, -1, 0) \n\n q_ptrs = Q + (offs_m[:, None] * sqm + offs_d[None, :] * sqd)\n do_ptrs = DO + (offs_m[:, None] * sqm + offs_d[None, :] * sqd)\n dq_ptrs = DQ + (offs_m[:, None] * sqm + offs_d[None, :] * sqd)\n \n q = tl.load(q_ptrs) \n qk = tl.dot(q, tl.trans(k), allow_tf32=False)\n qk = tl.where(offs_m_real[:,None] >= (offs_n[None,:]), qk, float(\"-inf\"))\n\n m = tl.load(m_ptrs + offs_m) \n m_ = m \n\n last_p = tl.exp(qk * sm_scale - m_[:, None])\n\n do = tl.load(do_ptrs) \n dv += tl.dot(tl.trans(last_p.to(Q.dtype.element_ty)), do, allow_tf32=False)\n\n Di = tl.load(D_ptrs + offs_m) \n last_dp = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - Di[:, None]\n last_dp += tl.dot(do, tl.trans(v), allow_tf32=False)\n ds = last_p * last_dp * sm_scale\n dk += tl.dot(tl.trans(ds.to(Q.dtype.element_ty)), q, allow_tf32=False)\n\n dq = tl.load(dq_ptrs) \n dq += tl.dot(ds.to(Q.dtype.element_ty), k, allow_tf32=False)\n tl.store(dq_ptrs, dq) \n start_q_index = first_start_m + BLOCK_M\n\n for start_m in range(start_q_index, N_CTX_Q, BLOCK_M):\n start_m = tl.multiple_of(start_m, BLOCK_M).to(tl.int64)\n offs_m = (start_m + tl.arange(0, BLOCK_M)).to(tl.int64)\n\n q_ptrs = Q + (offs_m[:, None] * sqm + offs_d[None, :] * sqd)\n do_ptrs = DO + (offs_m[:, None] * sqm + offs_d[None, :] * sqd)\n dq_ptrs = DQ + (offs_m[:, None] * sqm + offs_d[None, :] * sqd)\n \n q = tl.load(q_ptrs) \n qk = tl.dot(q, tl.trans(k), allow_tf32=False)\n qk *= sm_scale\n\n landmark_qk = tl.max(tl.where(tl.arange(0, BLOCK_N)[None, :] == BLOCK_N - 1, qk, float(\"-inf\")), 1)\n normal_qk = tl.where(tl.arange(0, BLOCK_N)[None, :] == BLOCK_N - 1, float(\"-inf\"), qk)\n\n m = tl.load(m_ptrs + offs_m)\n m_ = m \n\n p = tl.exp(landmark_qk - m_) \n\n do = tl.load(do_ptrs)\n\n normal_m = tl.max(normal_qk, 1)\n normal_p = tl.exp(normal_qk - normal_m[:, None])\n normal_p_normalized = normal_p / tl.sum(normal_p, 1)[:, None]\n normal_kv = tl.dot(normal_p_normalized.to(Q.dtype.element_ty), v, allow_tf32=False)\n\n normal_D = tl.sum(do * normal_kv, 1)\n\n dv += tl.dot(tl.trans((p[:, None] * normal_p_normalized).to(Q.dtype.element_ty)), do, allow_tf32=False)\n\n Di = tl.load(D_ptrs + offs_m)\n dp = tl.zeros([BLOCK_M], dtype=tl.float32) - Di\n dp += normal_D \n landmark_ds = p * dp\n normal_dp = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - normal_D[:, None]\n normal_dp += tl.dot(do, tl.trans(v), allow_tf32=False)\n normal_ds = p[:, None] * normal_p_normalized * normal_dp \n ds = tl.where(tl.arange(0, BLOCK_N)[None, :] == BLOCK_N - 1, landmark_ds[:, None], normal_ds)\n ds *= sm_scale\n dk += tl.dot(tl.trans(ds.to(Q.dtype.element_ty)), q, allow_tf32=False)\n\n dq = tl.load(dq_ptrs)\n dq += tl.dot(ds.to(Q.dtype.element_ty), k, allow_tf32=False)\n tl.store(dq_ptrs, dq)\n \n dv_ptrs = DV + (offs_n[:, None] * svn + offs_d[None, :] * svd)\n dk_ptrs = DK + (offs_n[:, None] * skn + offs_d[None, :] * skd)\n tl.store(dv_ptrs, dv) \n tl.store(dk_ptrs, dk) \n\n\nclass FusedLandmarkAttention(torch.autograd.Function):\n\n @staticmethod\n def forward(ctx, q, k, v, n_prefix_q, sm_scale, block_size):\n q = q.contiguous()\n k = k.contiguous()\n v = v.contiguous()\n\n batch, nheads, seqlen_q, d = q.shape\n _, _, seqlen_k, _ = k.shape\n assert k.shape == (batch, nheads, seqlen_k, d)\n assert v.shape == (batch, nheads, seqlen_k, d)\n assert d <= 128, 'FlashAttention only support head dimensions up to 128'\n assert q.dtype == k.dtype == v.dtype, 'All tensors must have the same type'\n assert q.is_cuda and k.is_cuda and v.is_cuda\n \n BLOCK = block_size\n o = torch.empty_like(q)\n grid = (triton.cdiv(q.shape[2], BLOCK), q.shape[0] * q.shape[1], 1)\n L = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)\n m = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)\n num_warps = 4 if d <= 64 else 8\n\n _fwd_kernel[grid](\n q, k, v, sm_scale,\n o,\n q.stride(0), q.stride(1), q.stride(2), q.stride(3),\n k.stride(0), k.stride(1), k.stride(2), k.stride(3),\n v.stride(0), v.stride(1), v.stride(2), v.stride(3),\n o.stride(0), o.stride(1), o.stride(2), o.stride(3),\n L, m,\n q.shape[0], q.shape[1], q.shape[2], k.shape[2],\n BLOCK=BLOCK, BLOCK_DMODEL=d,\n N_PREFIX_Q=n_prefix_q,\n num_warps=num_warps, num_stages=2\n )\n\n ctx.save_for_backward(q, k, v, o, L, m)\n ctx.grid = grid\n ctx.sm_scale = sm_scale\n ctx.BLOCK_DMODEL = d\n ctx.N_PREFIX_Q = n_prefix_q\n ctx.BLOCK = BLOCK\n return o\n\n @staticmethod\n def backward(ctx, do):\n BLOCK = ctx.BLOCK\n q, k, v, o, l, m = ctx.saved_tensors\n assert q.shape[2] % BLOCK == 0, \"Backward supported only for full blocks\"\n assert k.shape[2] % BLOCK == 0, \"Backward supported only for full blocks\"\n\n do = do.contiguous()\n dq = torch.zeros_like(q, dtype=torch.float32)\n dk = torch.empty_like(k)\n dv = torch.empty_like(v)\n do_scaled = torch.empty_like(do)\n delta = torch.empty_like(l)\n _bwd_preprocess[(ctx.grid[0], ctx.grid[1])](\n o, o.stride(0), o.stride(1), o.stride(2), o.stride(3), do, l, l.stride(0), l.stride(1),\n do_scaled, delta, q.shape[2],\n BLOCK_M=BLOCK, D_HEAD=ctx.BLOCK_DMODEL,\n )\n _bwd_kernel[(ctx.grid[1],)](\n q, k, v, ctx.sm_scale,\n o, do_scaled,\n dq, dk, dv,\n l, m,\n delta,\n q.stride(0), q.stride(1), q.stride(2), q.stride(3),\n k.stride(0), k.stride(1), k.stride(2), k.stride(3),\n v.stride(0), v.stride(1), v.stride(2), v.stride(3),\n q.shape[0], q.shape[1], q.shape[2], k.shape[2],\n BLOCK=BLOCK,\n BLOCK_DMODEL=ctx.BLOCK_DMODEL, \n N_PREFIX_Q=ctx.N_PREFIX_Q,\n num_warps=8,\n num_stages=1,\n )\n return dq, dk, dv, None, None, None\n\ndef fused_landmark_attention(q, k, v, is_mem, sm_scale=None, block_size=64):\n expected_is_mem = torch.arange(0, is_mem.shape[-1], device=is_mem.device) % block_size == (block_size - 1)\n assert (is_mem == expected_is_mem).all()\n\n n_history_kv = k.shape[-2] - q.shape[-2]\n assert n_history_kv % block_size == 0\n n_history_blocks = n_history_kv // block_size\n\n if sm_scale is None:\n sm_scale = 1.0 / math.sqrt(q.size(-1))\n\n return FusedLandmarkAttention.apply(q, k, v, n_history_blocks, sm_scale, block_size)\n", - "description_1": "Use triton language to implement fused landmark self-attention operation, consisting of three kernels: forward kernel for computing attention scores, backward preprocessing for scaling output, and backward kernel for gradient computations. Forward kernel takes 21 input parameters (Q, K, V matrices; output and intermediate result tensors; scaling factor, etc.). Backward preprocessing takes 10 input parameters. Backward kernel takes 19 input parameters. Calling function 'fused_landmark_attention' takes five parameters.", - "description_2": "Use triton language to implement and invoke a fused self-attention operation with backward gradient computation. Ensure compatibility with CUDA devices and apply block-wise computations for scalability.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\nfrom torch import Tensor\nfrom torch.autograd import Function\n\n@triton.jit\ndef load_rotary_embedded_vector(\n QK, stride_qk_n, stride_qk_t, stride_qk_hid,\n COS, stride_cos_t, stride_cos_hid,\n SIN, stride_sin_t, stride_sin_hid,\n idx_n,\n idx_t_qk,\n idx_t_rope,\n HID,\n BLOCK_HID,\n):\n idx_hid = tl.arange(0, BLOCK_HID).to(tl.int64)\n mask_hid = idx_hid < HID\n \n idx_hid_rot = ((idx_hid + HID // 2) % HID).to(tl.int64)\n mask_hid_rot = mask_hid\n \n vec = tl.load(\n QK +\\\n idx_n.to(tl.int64) * stride_qk_n +\\\n idx_t_qk.to(tl.int64) * stride_qk_t +\\\n idx_hid.to(tl.int64) * stride_qk_hid,\n mask = mask_hid,\n other = 0,\n )\n \n vec_rot = tl.load(\n QK +\\\n idx_n.to(tl.int64) * stride_qk_n +\\\n idx_t_qk.to(tl.int64) * stride_qk_t +\\\n idx_hid_rot.to(tl.int64) * stride_qk_hid,\n mask = mask_hid_rot,\n other = 0,\n )\n vec_rot = tl.where(idx_hid < HID // 2, -vec_rot, vec_rot)\n \n cos = tl.load(\n COS +\\\n idx_t_rope.to(tl.int64) * stride_cos_t +\\\n idx_hid.to(tl.int64) * stride_cos_hid,\n mask=mask_hid,\n other=0,\n )\n sin = tl.load(\n SIN +\\\n idx_t_rope.to(tl.int64) * stride_sin_t +\\\n idx_hid.to(tl.int64) * stride_sin_hid,\n mask=mask_hid,\n other=0,\n )\n \n vec_rope = ((vec.to(tl.float32) * cos) + (vec_rot.to(tl.float32) * sin)).to(vec.dtype)\n \n return vec_rope, vec, vec_rot, cos, sin\n\n@triton.jit\ndef grad_rotary_embedded_vector(\n grad_vec_rope, vec_origin, vec_rot, cos, sin,\n HID, BLOCK_HID,\n):\n grad_vec_origin = grad_vec_rope * cos\n idx_vec_origin_hid = tl.arange(0, BLOCK_HID)\n \n grad_vec_rot = grad_vec_rope * sin\n grad_vec_rot = tl.where(idx_vec_origin_hid < HID // 2, -grad_vec_rot, grad_vec_rot)\n idx_vec_rot_hid = (idx_vec_origin_hid + HID // 2) % HID\n \n return grad_vec_origin, idx_vec_origin_hid, grad_vec_rot, idx_vec_rot_hid\n\n@triton.jit\ndef _attention_scores_compute(\n Q, stride_q_n, stride_q_tdst, stride_q_hid,\n K, stride_k_n, stride_k_tsrc, stride_k_hid,\n COS, stride_cos_t, stride_cos_hid,\n SIN, stride_sin_t, stride_sin_hid,\n INDICES, stride_indices_d, stride_indices_z,\n VALUES, stride_values_z,\n N, TDST, TSRC, HID,\n NUM_SINK,\n WINDOW_SIZE,\n BLOCK_HID: tl.constexpr,\n):\n idx_n = tl.program_id(0).to(tl.int64)\n idx_tdst = tl.program_id(1).to(tl.int64)\n idx_k = tl.program_id(2).to(tl.int64)\n \n tdst = idx_tdst + TSRC - TDST\n \n if idx_k < NUM_SINK:\n idx_tsrc = idx_k\n else:\n window_offset = idx_k - NUM_SINK\n t_tsrc = tdst - WINDOW_SIZE + 1 + window_offset\n idx_tsrc = tl.maximum(idx_k, t_tsrc)\n \n key, _, _, _, _ = load_rotary_embedded_vector(\n K, stride_k_n, stride_k_tsrc, stride_k_hid,\n COS, stride_cos_t, stride_cos_hid,\n SIN, stride_sin_t, stride_sin_hid,\n idx_n, idx_tsrc, idx_k,\n HID, BLOCK_HID,\n )\n \n query, _, _, _, _ = load_rotary_embedded_vector(\n Q, stride_q_n, stride_q_tdst, stride_q_hid,\n COS, stride_cos_t, stride_cos_hid,\n SIN, stride_sin_t, stride_sin_hid,\n idx_n, idx_tdst, tl.minimum(tdst, WINDOW_SIZE + NUM_SINK - 1),\n HID, BLOCK_HID,\n )\n \n score = tl.sum(query.to(tl.float32) * key.to(tl.float32))\n score = score * (1 / tl.sqrt(HID.to(tl.float32)))\n score = tl.where(idx_tsrc <= tdst, score, float('-inf'))\n \n idx_z = idx_n.to(tl.int64) * TDST * (WINDOW_SIZE + NUM_SINK) + idx_tdst.to(tl.int64) * (WINDOW_SIZE + NUM_SINK) + idx_k.to(tl.int64)\n tl.store(\n VALUES +\\\n idx_z.to(tl.int64) * stride_values_z,\n value = score\n )\n zero = tl.zeros((1,), dtype=tl.int64)\n one = zero + 1\n tl.store(\n INDICES +\\\n zero * stride_indices_d +\\\n idx_z.to(tl.int64) * stride_indices_z,\n value = idx_n\n )\n tl.store(\n INDICES +\\\n one * stride_indices_d +\\\n idx_z.to(tl.int64) * stride_indices_z,\n value = idx_tdst\n )\n tl.store(\n INDICES +\\\n (one * 2) * stride_indices_d +\\\n idx_z.to(tl.int64) * stride_indices_z,\n value = idx_tsrc\n )\n\n@triton.jit\ndef _attention_score_backward_compute(\n GRAD_VALUES, stride_grad_values_z,\n Q, stride_q_n, stride_q_tdst, stride_q_hid,\n K, stride_k_n, stride_k_tsrc, stride_k_hid,\n INDICES, stride_indices_d, stride_indices_z,\n COS, stride_cos_t, stride_cos_hid,\n SIN, stride_sin_t, stride_sin_hid,\n GRAD_Q, stride_grad_q_n, stride_grad_q_tdst, stride_grad_q_hid,\n GRAD_K, stride_grad_k_n, stride_grad_k_tsrc, stride_grad_k_hid,\n N, TDST, TSRC, HID, NNZ,\n NUM_SINK,\n WINDOW_SIZE,\n BLOCK_HID: tl.constexpr,\n):\n idx_z = tl.program_id(0)\n \n idx_n = tl.load(\n INDICES +\\\n 0 * stride_indices_d +\\\n idx_z * stride_indices_z\n ).to(tl.int64)\n idx_tdst = tl.load(\n INDICES +\\\n 1 * stride_indices_d +\\\n idx_z * stride_indices_z\n ).to(tl.int64)\n idx_tsrc = tl.load(\n INDICES +\\\n 2 * stride_indices_d +\\\n idx_z * stride_indices_z\n ).to(tl.int64)\n tdst = idx_tdst + TSRC - TDST\n \n idx_k = idx_z % (NUM_SINK + WINDOW_SIZE)\n \n key, key_origin, key_rot, cos_k, sin_k = load_rotary_embedded_vector(\n K, stride_k_n, stride_k_tsrc, stride_k_hid,\n COS, stride_cos_t, stride_cos_hid,\n SIN, stride_sin_t, stride_sin_hid,\n idx_n, idx_tsrc, idx_k,\n HID, BLOCK_HID\n )\n \n query, query_origin, query_rot, cos_q, sin_q = load_rotary_embedded_vector(\n Q, stride_q_n, stride_q_tdst, stride_q_hid,\n COS, stride_cos_t, stride_cos_hid,\n SIN, stride_sin_t, stride_sin_hid,\n idx_n, idx_tdst, tl.minimum(tdst, WINDOW_SIZE + NUM_SINK - 1),\n HID, BLOCK_HID,\n )\n \n grad_score = tl.load(\n GRAD_VALUES +\\\n idx_z * stride_grad_values_z,\n )\n \n grad_score = tl.where(idx_tsrc <= tdst, grad_score, 0)\n grad_score = grad_score * (1 / tl.sqrt(HID.to(tl.float32)))\n \n grad_key = grad_score * query\n grad_query = grad_score * key\n \n grad_key_origin, idx_key_origin_hid, grad_key_rot, idx_key_rot_hid = grad_rotary_embedded_vector(\n grad_key, key_origin, key_rot, cos_k, sin_k,\n HID, BLOCK_HID\n )\n grad_query_origin, idx_query_origin_hid, grad_query_rot, idx_query_rot_hid = grad_rotary_embedded_vector(\n grad_query, query_origin, query_rot, cos_q, sin_q,\n HID, BLOCK_HID\n )\n \n mask_hid = tl.arange(0, BLOCK_HID) < HID\n \n tl.atomic_add(\n GRAD_K +\\\n idx_n * stride_grad_k_n +\\\n idx_tsrc * stride_grad_k_tsrc +\\\n idx_key_origin_hid * stride_grad_k_hid,\n mask = mask_hid,\n val = grad_key_origin\n )\n tl.atomic_add(\n GRAD_K +\\\n idx_n * stride_grad_k_n +\\\n idx_tsrc * stride_grad_k_tsrc +\\\n idx_key_rot_hid * stride_grad_k_hid,\n mask = mask_hid,\n val = grad_key_rot\n )\n \n tl.atomic_add(\n GRAD_Q +\\\n idx_n * stride_grad_q_n +\\\n idx_tdst * stride_grad_q_tdst +\\\n idx_query_origin_hid * stride_grad_q_hid,\n mask = mask_hid,\n val = grad_query_origin\n )\n tl.atomic_add(\n GRAD_Q +\\\n idx_n * stride_grad_q_n +\\\n idx_tdst * stride_grad_q_tdst +\\\n idx_query_rot_hid * stride_grad_q_hid,\n mask = mask_hid,\n val = grad_query_rot\n )\n\nclass AttentionScoreFunc(Function):\n @staticmethod\n def forward(\n ctx,\n q: Tensor, \n k: Tensor,\n cos: Tensor,\n sin: Tensor,\n num_sink: int,\n window_size: int,\n ):\n q = q.contiguous()\n k = k.contiguous()\n \n assert q.ndim == 3\n assert k.ndim == 3\n assert cos.ndim == 2, cos.shape\n assert sin.ndim == 2, sin.shape\n N, TDST, HID = q.shape\n _, TSRC, _ = k.shape\n assert k.shape == (N, TSRC, HID)\n assert cos.shape[-1] == HID\n assert sin.shape[-1] == HID\n \n device = q.device\n if q.requires_grad or k.requires_grad:\n dtype = torch.float32\n else:\n dtype = q.dtype\n \n nnz = N * TDST * (num_sink + window_size)\n indices = torch.zeros((3, nnz), dtype=torch.int64, device=device)\n values = torch.zeros((nnz,), dtype=dtype, device=device)\n \n BLOCK_HID = triton.next_power_of_2(HID)\n \n grid = (N, TDST, num_sink + window_size)\n \n _device = torch.cuda.current_device()\n torch.cuda.set_device(q.device)\n try:\n _attention_scores_compute[grid](\n q, *q.stride(),\n k, *k.stride(),\n cos, *cos.stride(),\n sin, *sin.stride(),\n \n indices, *indices.stride(),\n values, *values.stride(),\n \n N, TDST, TSRC, HID,\n num_sink,\n window_size,\n \n BLOCK_HID,\n \n num_warps=2,\n num_stages=1,\n )\n except RuntimeError as ex:\n raise Exception() from ex\n torch.cuda.set_device(_device)\n \n ctx.save_for_backward(\n q, k, cos, sin, indices\n )\n ctx.num_sink = num_sink\n ctx.window_size = window_size\n \n return indices, values\n\n @staticmethod\n def backward(\n ctx, \n grad_indices: Tensor, \n grad_values: Tensor\n ):\n q, k, cos, sin, indices = ctx.saved_tensors\n num_sink = ctx.num_sink\n window_size = ctx.window_size\n \n N, TDST, HID = q.shape\n _, TSRC, _ = k.shape\n _, NNZ = indices.shape\n \n assert q.ndim == 3\n assert k.ndim == 3\n assert cos.ndim == 2\n assert sin.ndim == 2\n assert indices.ndim == 2\n assert grad_values.ndim == 1\n \n grad_q = torch.zeros_like(q, dtype=torch.float32)\n grad_k = torch.zeros_like(k, dtype=torch.float32)\n \n BLOCK_HID = triton.next_power_of_2(HID)\n \n grid = (NNZ,)\n \n _device = torch.cuda.current_device()\n torch.cuda.set_device(q.device)\n _attention_score_backward_compute[grid](\n grad_values, *grad_values.stride(),\n q, *q.stride(),\n k, *k.stride(),\n indices, *indices.stride(),\n cos, *cos.stride(),\n sin, *sin.stride(),\n grad_q, *grad_q.stride(),\n grad_k, *grad_k.stride(),\n \n N, TDST, TSRC, HID, NNZ, \n num_sink,\n window_size,\n \n BLOCK_HID,\n \n num_warps=1,\n num_stages=1,\n )\n torch.cuda.set_device(_device)\n \n return (\n grad_q,\n grad_k,\n None,\n None,\n None,\n None,\n )\n\ndef attention_scores(\n q: Tensor, \n k: Tensor,\n cos: Tensor,\n sin: Tensor,\n num_sink: int = 4,\n window_size: int = 512,\n):\n N, TDST, HID = q.shape\n _, TSRC, _ = k.shape\n \n window_size = min(window_size, TSRC - num_sink)\n \n indices, values = AttentionScoreFunc.apply(\n q, k, cos, sin, num_sink, window_size,\n )\n \n values = values\\\n .view(-1, num_sink + window_size)\\\n .softmax(-1)\\\n .view(-1)\\\n .contiguous()\n \n probs = torch.sparse_coo_tensor(\n indices=indices,\n values=values,\n size=(N, TDST, TSRC),\n requires_grad=q.requires_grad,\n dtype=values.dtype,\n device=values.device,\n check_invariants=False,\n )\n \n return probs\n\n@triton.jit\ndef _sparse_attention_compute(\n INDICES, stride_indices_d, stride_indices_z,\n VALUES, stride_values_z,\n V, stride_v_n, stride_v_tsrc, stride_v_hid,\n CONTEXT, stride_context_n, stride_context_tdst, stride_context_hid,\n N, TDST, TSRC, HID, BK,\n NUM_SINK,\n WINDOW_SIZE,\n BLOCK_HID: tl.constexpr,\n BLOCK_K: tl.constexpr,\n):\n zero = tl.zeros((1, ), dtype=tl.int64)\n one = zero + 1\n two = zero + 2\n \n idx_n = tl.program_id(0).to(tl.int64)\n idx_tdst = tl.program_id(1).to(tl.int64)\n \n idx_hid = tl.arange(0, BLOCK_HID).to(tl.int64)\n mask_hid = idx_hid < HID\n \n acc = tl.zeros((BLOCK_HID, ), dtype=tl.float32)\n \n for idx_bk in range(BK):\n CACHE_SIZE = NUM_SINK + WINDOW_SIZE\n idx_k = idx_bk.to(tl.int64) * BLOCK_K + tl.arange(0, BLOCK_K).to(tl.int64)\n mask_k = idx_k < CACHE_SIZE\n \n idx_z = idx_n * TDST * CACHE_SIZE + idx_tdst * CACHE_SIZE + idx_k\n mask_z = mask_k\n \n idx_tsrc = tl.load(\n INDICES +\\\n two * stride_indices_d +\\\n idx_z * stride_indices_z,\n mask = mask_z,\n other = 0\n )\n mask_tsrc = mask_z\n \n score = tl.load(\n VALUES +\\\n idx_z * stride_values_z,\n mask = mask_z,\n other = 0,\n )\n \n value = tl.load(\n V +\\\n idx_n * stride_v_n +\\\n idx_tsrc[:, None] * stride_v_tsrc +\\\n idx_hid[None, :] * stride_v_hid,\n mask = mask_tsrc[:, None] & mask_hid[None, :],\n other = 0,\n )\n \n context = tl.sum(score[:, None] * value, axis=0)\n acc += context.to(tl.float32)\n \n tl.store(\n CONTEXT +\\\n idx_n * stride_context_n +\\\n idx_tdst * stride_context_tdst +\\\n idx_hid * stride_context_hid,\n mask = mask_hid,\n value = acc\n )\n\ndef sparse_attention(\n probs: Tensor, v: Tensor, num_sink: int, window_size: int,\n):\n N, TDST, TSRC = probs.shape\n _, _, HID = v.shape\n \n window_size = min(window_size, TSRC - num_sink)\n \n values = probs._values()\n indices = probs._indices()\n \n context = torch.zeros((N, TDST, HID), dtype=v.dtype, device=v.device)\n \n BLOCK_HID = triton.next_power_of_2(HID)\n BLOCK_K = 128\n \n grid = (N, TDST)\n \n assert indices.ndim == 2\n assert values.ndim == 1\n assert v.ndim == 3\n assert context.ndim == 3\n _device = torch.cuda.current_device()\n torch.cuda.set_device(v.device)\n _sparse_attention_compute[grid](\n indices, *indices.stride(),\n values, *values.stride(),\n v, *v.stride(),\n \n context, *context.stride(),\n \n N, TDST, TSRC, HID, triton.cdiv(num_sink + window_size, BLOCK_K),\n num_sink,\n window_size,\n \n BLOCK_HID,\n BLOCK_K,\n )\n torch.cuda.set_device(_device)\n \n return context\n\ndef sink_attention(\n q: Tensor,\n k: Tensor,\n v: Tensor,\n cos: Tensor,\n sin: Tensor,\n num_sink: int = 4,\n window_size: int = 512,\n BENCHMARK: bool = False,\n): \n if BENCHMARK:\n event_scores_start = torch.cuda.Event(enable_timing=True)\n event_scores_end = torch.cuda.Event(enable_timing=True)\n event_bmm_start = torch.cuda.Event(enable_timing=True)\n event_bmm_end = torch.cuda.Event(enable_timing=True)\n event_scores_start.record()\n \n _dtype = v.dtype\n \n probs = attention_scores(\n q, k, cos, sin,\n num_sink=num_sink,\n window_size=window_size,\n )\n \n if BENCHMARK:\n event_scores_end.record()\n event_bmm_start.record()\n \n try:\n if q.requires_grad or k.requires_grad or v.requires_grad:\n if v.dtype in [torch.bfloat16, torch.float16]:\n v = v.to(torch.float32)\n context = torch.bmm(probs, v)\n else:\n context = sparse_attention(probs, v, num_sink, window_size)\n except torch.cuda.OutOfMemoryError as ex:\n raise Exception() from ex\n \n if context.dtype != _dtype:\n context = context.to(_dtype)\n \n if BENCHMARK:\n event_bmm_end.record()\n \n torch.cuda.synchronize()\n elapsed_scores = event_scores_start.elapsed_time(event_scores_end)\n elapsed_bmm = event_bmm_start.elapsed_time(event_bmm_end)\n \n print(elapsed_scores, elapsed_bmm)\n \n return context\n", - "description_1": "Use triton language to implement a set of kernels for computing and backpropagating attention scores, with support for rotary embeddings. It involves kernels for forward and backward attention score computations, gradient calculations, and sparse matrix multiplication. The primary functions include load_rotary_embedded_vector for loading rotary embedded vectors, grad_rotary_embedded_vector for computing the gradients of rotary embedded vectors, and AttentionScoreFunc, which encapsulates both the forward and backward passes using the defined kernels. It also includes helper functions to compute sparse attention context.", - "description_2": "Use triton language to implement attention score computation and gradient backpropagation for rotary-embedded vectors, supporting both forward and backward passes, including sparse attention context computations.", - "difficulty": 4 - }, - { - "code": "import triton\nimport triton.language as tl\nimport triton.language.core as core\nfrom triton.language.standard import _log2, sum, zeros_like\n\n@triton.jit\ndef _compare_and_swap(x, ids, flip, i: core.constexpr, n_dims: core.constexpr):\n n_outer: core.constexpr = x.numel >> n_dims\n shape: core.constexpr = [n_outer * 2**i, 2, 2**(n_dims - i - 1)]\n y = core.reshape(x, shape)\n mask = core.arange(0, 2)[None, :, None]\n left = core.broadcast_to(sum(y * (1 - mask), 1)[:, None, :], shape)\n right = core.broadcast_to(sum(y * mask, 1)[:, None, :], shape)\n left = core.reshape(left, x.shape)\n right = core.reshape(right, x.shape)\n y_idx = core.reshape(ids, shape)\n left_idx = core.broadcast_to(sum(y_idx * (1 - mask), 1)[:, None, :], shape)\n right_idx = core.broadcast_to(sum(y_idx * mask, 1)[:, None, :], shape)\n left_idx = core.reshape(left_idx, x.shape)\n right_idx = core.reshape(right_idx, x.shape)\n idtype = core.get_int_dtype(bitwidth=x.dtype.primitive_bitwidth, signed=True)\n ileft = left.to(idtype, bitcast=True)\n iright = right.to(idtype, bitcast=True)\n ix = x.to(idtype, bitcast=True)\n cond = (left > right) ^ flip\n ret = ix ^ core.where(cond, ileft ^ iright, zeros_like(ix))\n new_ids = ids ^ core.where(cond, left_idx ^ right_idx, zeros_like(ids))\n return ret.to(x.dtype, bitcast=True), new_ids\n\n@triton.jit\ndef _bitonic_merge(x, ids, stage: core.constexpr, order: core.constexpr, n_dims: core.constexpr):\n n_outer: core.constexpr = x.numel >> n_dims\n core.static_assert(stage <= n_dims)\n if order == 2:\n shape: core.constexpr = [n_outer * 2**(n_dims - 1 - stage), 2, 2**stage]\n flip = core.reshape(core.broadcast_to(core.arange(0, 2)[None, :, None], shape), x.shape)\n else:\n flip = order\n for i in core.static_range(stage):\n x, ids = _compare_and_swap(x, ids, flip, i + (n_dims - stage), n_dims)\n return x, ids\n\n@triton.jit\ndef argsort(x, ids, dim: core.constexpr = None, descending: core.constexpr = core.CONSTEXPR_0):\n _dim: core.constexpr = len(x.shape) - 1 if dim is None else dim\n core.static_assert(_dim == len(x.shape) - 1, \"only minor dimension is currently supported\")\n n_dims: core.constexpr = _log2(x.shape[_dim])\n for i in core.static_range(1, n_dims + 1):\n x, ids = _bitonic_merge(x, ids, i, 2 if i < n_dims else descending, n_dims)\n return x, ids\n", - "description_1": "Use triton language to implement bitonic sort kernels. _compare_and_swap takes 5 parameters: x (the array to sort), ids (indices array), flip (boolean array for flipping), i (current stage), n_dims (total dimensions). _bitonic_merge takes 5 parameters: x, ids, stage (current sorting stage), order (sorting order), and n_dims. argsort takes 4 parameters: x, ids, dim (dimension for sorting), and descending (sort order).", - "description_2": "Use triton language to implement bitonic sort, with kernels to handle compare-and-swap, merging stages, and final argsort operation.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n\ndef blocksparse_flash_attn_varlen_fwd(\n q,\n k,\n v, # (#tokens, n_heads, head_size)\n cu_seqlens_k,\n cu_seqlens_q,\n sm_scale,\n sparse_layout,\n *,\n block_size=64,\n q_block_size=None,\n max_seqlen=None):\n # split q to blocks\n\n assert isinstance(sparse_layout, (list, tuple))\n\n _, n_heads, head_size = q.shape\n batch_size = cu_seqlens_k.size(0) - 1\n q_block_size = q_block_size or block_size\n\n assert q.dim() == k.dim() == v.dim() == 3\n assert q.size(1) % k.size(1) == 0\n assert q.size(2) == k.size(2)\n # TODO: allow k, v to have different head_size\n assert k.shape == v.shape\n assert cu_seqlens_k.dim() == 1\n\n q_k_ratio = q.size(1) // k.size(1)\n\n if cu_seqlens_q is None:\n if q.size(0) == batch_size: # decoding only\n cu_seqlens_q = torch.arange(\n 0,\n batch_size + 1,\n dtype=cu_seqlens_k.dtype,\n device=cu_seqlens_k.device,\n )\n elif q.size(0) == k.size(0):\n cu_seqlens_q = cu_seqlens_k\n else:\n raise ValueError(\"cu_seqlens_q must be specified\\\n if it mix of prefilling and decoding.\")\n else:\n assert cu_seqlens_k.size(0) == cu_seqlens_q.size(0)\n\n # switch to use cpu to avoid too many kernel launches when iterated over\n q_lens = (cu_seqlens_q[1:] - cu_seqlens_q[:-1]).cpu()\n k_lens = (cu_seqlens_k[1:] - cu_seqlens_k[:-1]).cpu()\n\n assert torch.logical_or(q_lens == 1, k_lens == q_lens).all(), (\n \"length of q should either be 1 (decoding) or same as k (prefilling).\")\n\n if max_seqlen:\n assert k_lens.max() <= max_seqlen\n\n n_blocks = (q_lens + q_block_size - 1) // q_block_size\n\n q_batch_ids = torch.tensor(\n [i for i, n in enumerate(n_blocks) for _ in range(n)],\n dtype=cu_seqlens_q.dtype,\n device=cu_seqlens_q.device,\n )\n q_start_sids = torch.tensor(\n [i * q_block_size for n in n_blocks for i in range(n)],\n dtype=cu_seqlens_q.dtype,\n device=cu_seqlens_q.device,\n )\n\n out = q.new_empty(q.shape)\n cu_seqlens_q = cu_seqlens_q.contiguous()\n cu_seqlens_k = cu_seqlens_k.contiguous()\n\n layout_crow_indices, layout_col_indices = sparse_layout\n block_d = triton.next_power_of_2(head_size)\n\n decoding_only = (q_lens == 1).all().item()\n grid = (len(q_start_sids), n_heads, 1)\n\n _fwd_kernel_batch_inference[grid](\n q,\n k,\n v,\n out,\n sm_scale,\n cu_seqlens_q[:-1],\n cu_seqlens_q[1:],\n cu_seqlens_k[:-1],\n cu_seqlens_k[1:],\n q_batch_ids,\n q_start_sids,\n 0,\n *q.stride(),\n 0,\n *k.stride(),\n 0,\n *v.stride(),\n 0,\n *out.stride(),\n layout_crow_indices,\n layout_col_indices,\n *layout_crow_indices.stride(),\n *layout_col_indices.stride(),\n q_k_ratio,\n HAS_BATCH_DIM=False,\n D_HEAD=head_size,\n BLOCK_M=q_block_size,\n BLOCK_N=block_size,\n BLOCK_D=block_d,\n BLOCK_M_LOADING=(16 if decoding_only else\n q_block_size), # smaller for decoding\n EVEN_D=block_d == head_size,\n num_warps=1 if decoding_only else 4,\n num_stages=3)\n\n return out\n\n\n@triton.jit\ndef _fwd_kernel_inner(\n acc,\n l_i,\n m_i,\n q,\n Q,\n k_block_col_idx,\n layout_col_ptr,\n layout_col_stride_h,\n layout_col_stride_m,\n k_ptrs,\n v_ptrs,\n off_h,\n offs_m,\n offs_n,\n offs_d,\n stride_kt,\n stride_vt,\n sm_scale,\n k_seqlen,\n past_len,\n LAST_K_BLOCK: tl.constexpr,\n BLOCK_M_LOADING: tl.constexpr,\n BLOCK_N: tl.constexpr,\n D_HEAD: tl.constexpr,\n EVEN_D: tl.constexpr,\n M_LT_N: tl.constexpr,\n):\n k_block_id = tl.load(layout_col_ptr + off_h * layout_col_stride_h +\n k_block_col_idx * layout_col_stride_m).to(tl.int32)\n start_n = k_block_id * BLOCK_N\n if LAST_K_BLOCK:\n if EVEN_D:\n k = tl.load(\n k_ptrs + start_n * stride_kt,\n mask=offs_n[None, :] + start_n < k_seqlen,\n )\n else:\n k = tl.load(\n k_ptrs + start_n * stride_kt,\n mask=(offs_n[None, :] + start_n < k_seqlen) &\n (offs_d[:, None] < D_HEAD),\n )\n else:\n if EVEN_D:\n k = tl.load(k_ptrs + start_n * stride_kt)\n else:\n k = tl.load(k_ptrs + start_n * stride_kt,\n mask=offs_d[:, None] < D_HEAD)\n\n qk = tl.zeros([BLOCK_M_LOADING, BLOCK_N], dtype=tl.float32)\n qk += tl.dot(q, k)\n qk *= sm_scale\n\n # the following is needed only when LAST_K_BLOCK or BLOCK_M < BLOCK_N\n if LAST_K_BLOCK | M_LT_N:\n qk += tl.where(\n offs_m[:, None] + past_len >= (start_n + offs_n[None, :]),\n 0,\n float(\"-inf\"),\n )\n\n # flash-attn2\n m_ij = tl.maximum(m_i, tl.max(qk, 1))\n p = tl.math.exp2(qk - m_ij[:, None])\n l_ij = tl.sum(p, 1)\n alpha = tl.math.exp2(m_i - m_ij)\n acc = acc * alpha[:, None]\n # update m_i\n m_i = m_ij\n l_i = l_i * alpha + l_ij\n\n p = p.to(Q.dtype.element_ty)\n # update acc\n if LAST_K_BLOCK:\n if EVEN_D:\n v = tl.load(\n v_ptrs + start_n * stride_vt,\n mask=offs_n[:, None] + start_n < k_seqlen,\n )\n else:\n v = tl.load(\n v_ptrs + start_n * stride_vt,\n mask=(offs_n[:, None] + start_n < k_seqlen) &\n (offs_d[None, :] < D_HEAD),\n )\n else:\n if EVEN_D:\n v = tl.load(v_ptrs + start_n * stride_vt)\n else:\n v = tl.load(v_ptrs + start_n * stride_vt,\n mask=offs_d[None, :] < D_HEAD)\n\n acc += tl.dot(p, v)\n\n return acc, l_i, m_i\n\n\n@triton.heuristics({\n \"M_LT_N\":\n lambda kwargs: kwargs[\"BLOCK_M\"] < kwargs[\"BLOCK_N\"],\n})\n@triton.jit\ndef _fwd_kernel_batch_inference(\n Q,\n K,\n V,\n Out,\n sm_scale,\n q_batch_starts,\n q_batch_ends,\n k_batch_starts,\n k_batch_ends,\n q_batch_ids,\n q_start_sids,\n stride_qb,\n stride_qt,\n stride_qh,\n stride_qd,\n stride_kb,\n stride_kt,\n stride_kh,\n stride_kd,\n stride_vb,\n stride_vt,\n stride_vh,\n stride_vd,\n stride_ob,\n stride_ot,\n stride_oh,\n stride_od,\n layout_crow_ptr,\n layout_col_ptr,\n layout_crow_stride_h,\n layout_crow_stride_m,\n layout_col_stride_h,\n layout_col_stride_m,\n q_k_ratio,\n HAS_BATCH_DIM: tl.constexpr,\n D_HEAD: tl.constexpr,\n BLOCK_M: tl.constexpr,\n BLOCK_N: tl.constexpr,\n BLOCK_D: tl.constexpr,\n BLOCK_M_LOADING: tl.constexpr,\n EVEN_D: tl.constexpr,\n M_LT_N: tl.constexpr,\n):\n \"\"\"\n NOTATION:\n pid: position id\n sid: storage id\n sbid: storage block id\n pbid: position block id\n offs_m, offs_n: storage offsets of m-dim(q, row) and n-dim(k, col)\n\n TODO:\n Optimize grouped-attn\n \"\"\"\n off_zm = tl.program_id(0)\n off_h = tl.program_id(1)\n\n off_h_for_kv = off_h // q_k_ratio\n\n if HAS_BATCH_DIM:\n off_z = tl.program_id(2)\n Q += off_z * stride_qb\n K += off_z * stride_kb\n V += off_z * stride_vb\n Out += off_z * stride_ob\n start_m = off_zm\n q_start_sid = start_m * BLOCK_M # always 0 for decoding\n else:\n off_z = tl.load(q_batch_ids + off_zm).to(tl.int32) # [0, 0, 0, 1]\n q_start_sid = tl.load(q_start_sids + off_zm)\n start_m = q_start_sid // BLOCK_M # q_sbid\n\n offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M_LOADING)\n offs_n = tl.arange(0, BLOCK_N)\n offs_d = tl.arange(0, BLOCK_D)\n\n q_cu_start = tl.load(q_batch_starts + off_z).to(tl.int32)\n q_seqlen = tl.load(q_batch_ends + off_z).to(tl.int32) - q_cu_start\n k_cu_start = tl.load(k_batch_starts + off_z).to(tl.int32)\n k_seqlen = tl.load(k_batch_ends + off_z).to(tl.int32) - k_cu_start\n past_len = k_seqlen - q_seqlen\n\n Q += q_cu_start * stride_qt + off_h * stride_qh\n K += k_cu_start * stride_kt + off_h_for_kv * stride_kh\n V += k_cu_start * stride_vt + off_h_for_kv * stride_vh\n Out += q_cu_start * stride_ot + off_h * stride_oh\n\n q_pbid = (past_len + q_start_sid) // BLOCK_M\n\n if EVEN_D:\n q = tl.load(\n Q + offs_m[:, None] * stride_qt + offs_d[None, :] * stride_qd,\n mask=offs_m[:, None] < q_seqlen,\n )\n else:\n q = tl.load(\n Q + offs_m[:, None] * stride_qt + offs_d[None, :] * stride_qd,\n mask=(offs_m[:, None] < q_seqlen) & (offs_d[None, :] < D_HEAD),\n other=0,\n )\n\n sparse_crow_ptr = (layout_crow_ptr + off_h * layout_crow_stride_h +\n q_pbid * layout_crow_stride_m)\n\n # load at once, with any Triton version that supports `tl.split`\n k_block_start = tl.load(sparse_crow_ptr).to(tl.int32)\n k_block_end = tl.load(sparse_crow_ptr + 1).to(tl.int32)\n\n m_i = tl.zeros([BLOCK_M_LOADING], dtype=tl.float32) - float(\"inf\")\n l_i = tl.zeros([BLOCK_M_LOADING], dtype=tl.float32)\n acc = tl.zeros([BLOCK_M_LOADING, BLOCK_D], dtype=tl.float32)\n\n k_ptrs = K + offs_n[None, :] * stride_kt + offs_d[:, None] * stride_kd\n v_ptrs = V + offs_n[:, None] * stride_vt + offs_d[None, :] * stride_vd\n\n sm_scale *= (\n 1.44269504 # 1/log2 as we use base2 for exponential and logarithm\n )\n\n for k_block_col_idx in range(k_block_start, k_block_end - 1):\n acc, l_i, m_i = _fwd_kernel_inner(\n acc,\n l_i,\n m_i,\n q,\n Q,\n k_block_col_idx,\n layout_col_ptr,\n layout_col_stride_h,\n layout_col_stride_m,\n k_ptrs,\n v_ptrs,\n off_h,\n offs_m,\n offs_n,\n offs_d,\n stride_kt,\n stride_vt,\n sm_scale,\n k_seqlen,\n past_len,\n False,\n BLOCK_M_LOADING,\n BLOCK_N,\n D_HEAD,\n EVEN_D,\n M_LT_N,\n )\n\n acc, l_i, m_i = _fwd_kernel_inner(\n acc,\n l_i,\n m_i,\n q,\n Q,\n k_block_end - 1,\n layout_col_ptr,\n layout_col_stride_h,\n layout_col_stride_m,\n k_ptrs,\n v_ptrs,\n off_h,\n offs_m,\n offs_n,\n offs_d,\n stride_kt,\n stride_vt,\n sm_scale,\n k_seqlen,\n past_len,\n True,\n BLOCK_M_LOADING,\n BLOCK_N,\n D_HEAD,\n EVEN_D,\n M_LT_N,\n )\n\n # flash-attn 2\n m_i += tl.math.log2(l_i)\n acc = acc / l_i[:, None]\n\n # write output\n if EVEN_D:\n tl.store(\n Out + offs_m[:, None] * stride_ot + offs_d[None, :] * stride_od,\n acc,\n mask=offs_m[:, None] < q_seqlen,\n )\n else:\n tl.store(\n Out + offs_m[:, None] * stride_ot + offs_d[None, :] * stride_od,\n acc,\n mask=(offs_m[:, None] < q_seqlen) & (offs_d[None, :] < D_HEAD),\n )\n", - "description_1": "Use triton language to implement a forward pass of a blocksparse flash attention mechanism. The main components include `blocksparse_flash_attn_varlen_fwd` for setting up parameters and launching the kernel, `_fwd_kernel_inner` for performing inner product and attention computation, and `_fwd_kernel_batch_inference` for iterating over batches and handling the attention operation. The operations involve tensor manipulations, scaling, and reduction over blocks of data.", - "description_2": "Use triton language to implement a blocksparse flash attention mechanism with functions for parameter setup and kernel invocation, and inner computations for attention.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\nif triton.__version__ >= \"2.1.0\":\n\n @triton.jit\n def _fwd_kernel(\n Q,\n K,\n V,\n K_cache,\n V_cache,\n B_Loc,\n sm_scale,\n B_Start_Loc,\n B_Seqlen,\n B_Ctxlen,\n block_size,\n x,\n Out,\n stride_b_loc_b,\n stride_b_loc_s,\n stride_qbs,\n stride_qh,\n stride_qd,\n stride_kbs,\n stride_kh,\n stride_kd,\n stride_vbs,\n stride_vh,\n stride_vd,\n stride_obs,\n stride_oh,\n stride_od,\n stride_k_cache_bs,\n stride_k_cache_h,\n stride_k_cache_d,\n stride_k_cache_bl,\n stride_k_cache_x,\n stride_v_cache_bs,\n stride_v_cache_h,\n stride_v_cache_d,\n stride_v_cache_bl,\n num_queries_per_kv: int,\n BLOCK_M: tl.constexpr,\n BLOCK_DMODEL: tl.constexpr,\n BLOCK_DMODEL_PADDED: tl.constexpr,\n BLOCK_N: tl.constexpr,\n SLIDING_WINDOW: tl.constexpr,\n ):\n # Kernel implementation\n pass\n\n @triton.jit\n def _fwd_kernel_alibi(\n Q,\n K,\n V,\n K_cache,\n V_cache,\n B_Loc,\n sm_scale,\n B_Start_Loc,\n B_Seqlen,\n B_Ctxlen,\n Alibi_slopes,\n block_size,\n x,\n Out,\n stride_b_loc_b,\n stride_b_loc_s,\n stride_qbs,\n stride_qh,\n stride_qd,\n stride_kbs,\n stride_kh,\n stride_kd,\n stride_vbs,\n stride_vh,\n stride_vd,\n stride_obs,\n stride_oh,\n stride_od,\n stride_k_cache_bs,\n stride_k_cache_h,\n stride_k_cache_d,\n stride_k_cache_bl,\n stride_k_cache_x,\n stride_v_cache_bs,\n stride_v_cache_h,\n stride_v_cache_d,\n stride_v_cache_bl,\n num_queries_per_kv: int,\n BLOCK_M: tl.constexpr,\n BLOCK_DMODEL: tl.constexpr,\n BLOCK_DMODEL_PADDED: tl.constexpr,\n BLOCK_N: tl.constexpr,\n ):\n # Kernel implementation\n pass\n\n @torch.inference_mode()\n def context_attention_fwd(q,\n k,\n v,\n o,\n k_cache,\n v_cache,\n b_loc,\n b_start_loc,\n b_seq_len,\n b_ctx_len,\n max_input_len,\n alibi_slopes=None,\n sliding_window=None):\n # Function implementation\n pass\n", - "description_1": "Use triton language to implement forward kernels for context attention with optional alibi bias and sliding window mechanism. The kernels process input tensors Q, K, V, and their cached versions, along with batch location and sequence length information, to compute the output tensor. The kernels are parameterized by block sizes and strides for efficient memory access.", - "description_2": "Use triton language to implement context attention forward kernels with optional alibi bias and sliding window, processing input and cached tensors to compute output efficiently.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef cdiv_fn(x, y):\n return (x + y - 1) // y\n\n\n@triton.jit\ndef max_fn(x, y):\n return tl.math.max(x, y)\n\n\n@triton.jit\ndef dropout_offsets(philox_seed, philox_offset, dropout_p, m, n, stride):\n ms = tl.arange(0, m)\n ns = tl.arange(0, n)\n return philox_offset + ms[:, None] * stride + ns[None, :]\n\n\n@triton.jit\ndef dropout_rng(philox_seed, philox_offset, dropout_p, m, n, stride):\n rng_offsets = dropout_offsets(philox_seed, philox_offset, dropout_p, m, n,\n stride).to(tl.uint32)\n # TODO: use tl.randint for better performance\n return tl.rand(philox_seed, rng_offsets)\n\n\n@triton.jit\ndef dropout_mask(philox_seed, philox_offset, dropout_p, m, n, stride):\n rng_output = dropout_rng(philox_seed, philox_offset, dropout_p, m, n,\n stride)\n rng_keep = rng_output > dropout_p\n return rng_keep\n\n\n@triton.jit\ndef load_fn(block_ptr, first, second, pad):\n if first and second:\n tensor = tl.load(block_ptr, boundary_check=(0, 1), padding_option=pad)\n elif first:\n tensor = tl.load(block_ptr, boundary_check=(0, ), padding_option=pad)\n elif second:\n tensor = tl.load(block_ptr, boundary_check=(1, ), padding_option=pad)\n else:\n tensor = tl.load(block_ptr)\n return tensor\n\n\n@triton.jit\ndef _attn_fwd_inner(\n acc,\n l_i,\n m_i,\n q,\n K_block_ptr,\n V_block_ptr,\n start_m,\n actual_seqlen_k,\n dropout_p,\n philox_seed,\n batch_philox_offset,\n encoded_softmax_block_ptr,\n block_min,\n block_max,\n offs_n_causal,\n masked_blocks,\n n_extra_tokens,\n bias_ptr,\n IS_CAUSAL: tl.constexpr,\n BLOCK_M: tl.constexpr,\n BLOCK_DMODEL: tl.constexpr,\n BLOCK_N: tl.constexpr,\n OFFS_M: tl.constexpr,\n OFFS_N: tl.constexpr,\n PRE_LOAD_V: tl.constexpr,\n MASK_STEPS: tl.constexpr,\n ENABLE_DROPOUT: tl.constexpr,\n RETURN_ENCODED_SOFTMAX: tl.constexpr,\n PADDED_HEAD: tl.constexpr,\n):\n for start_n in range(block_min, block_max, BLOCK_N):\n k = load_fn(\n K_block_ptr,\n PADDED_HEAD,\n MASK_STEPS and (n_extra_tokens != 0),\n \"zero\",\n )\n if PRE_LOAD_V:\n v = load_fn(\n V_block_ptr,\n MASK_STEPS and (n_extra_tokens != 0),\n PADDED_HEAD,\n \"zero\",\n )\n qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)\n if MASK_STEPS:\n if (start_n + BLOCK_N == block_max) and (n_extra_tokens != 0):\n boundary_m = tl.full([BLOCK_M],\n actual_seqlen_k,\n dtype=tl.int32)\n size_n = start_n + OFFS_N[None, :]\n mask = size_n < boundary_m[:, None]\n qk = tl.where(mask, qk, float(\"-inf\"))\n if IS_CAUSAL:\n causal_boundary = start_n + offs_n_causal\n causal_mask = OFFS_M[:, None] >= causal_boundary[None, :]\n qk = tl.where(causal_mask, qk, float(\"-inf\"))\n qk += tl.dot(q, k)\n if bias_ptr is not None:\n bias = load_fn(bias_ptr, False, MASK_STEPS\n and (n_extra_tokens != 0), \"zero\")\n qk += bias * 1.44269504089\n m_ij = tl.maximum(m_i, tl.max(qk, 1))\n qk = qk - m_ij[:, None]\n p = tl.math.exp2(qk)\n\n l_ij = tl.sum(p, 1)\n if ENABLE_DROPOUT:\n philox_offset = (batch_philox_offset +\n start_m * BLOCK_M * actual_seqlen_k + start_n -\n BLOCK_N)\n keep = dropout_mask(\n philox_seed,\n philox_offset,\n dropout_p,\n BLOCK_M,\n BLOCK_N,\n actual_seqlen_k,\n )\n if RETURN_ENCODED_SOFTMAX:\n tl.store(\n encoded_softmax_block_ptr,\n tl.where(keep, p,\n -p).to(encoded_softmax_block_ptr.type.element_ty),\n )\n p = tl.where(keep, p, 0.0)\n elif RETURN_ENCODED_SOFTMAX:\n tl.store(\n encoded_softmax_block_ptr,\n p.to(encoded_softmax_block_ptr.type.element_ty),\n )\n alpha = tl.math.exp2(m_i - m_ij)\n acc = acc * alpha[:, None]\n if not PRE_LOAD_V:\n v = load_fn(\n V_block_ptr,\n MASK_STEPS and (n_extra_tokens != 0),\n PADDED_HEAD,\n \"zero\",\n )\n l_i = l_i * alpha + l_ij\n m_i = m_ij\n acc += tl.dot(p.to(V_block_ptr.type.element_ty), v)\n V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0))\n K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N))\n if bias_ptr is not None:\n bias_ptr = tl.advance(bias_ptr, (0, BLOCK_N))\n if RETURN_ENCODED_SOFTMAX:\n encoded_softmax_block_ptr = tl.advance(encoded_softmax_block_ptr,\n (0, BLOCK_N))\n return acc, l_i, m_i\n\n\n@triton.autotune(\n configs=[\n triton.Config(\n {\n \"BLOCK_M\": 256,\n \"BLOCK_N\": 64,\n \"waves_per_eu\": 2,\n \"PRE_LOAD_V\": False,\n },\n num_stages=1,\n num_warps=8,\n ),\n triton.Config(\n {\n \"BLOCK_M\": 128,\n \"BLOCK_N\": 128,\n \"waves_per_eu\": 2,\n \"PRE_LOAD_V\": False,\n },\n num_stages=1,\n num_warps=4,\n ),\n triton.Config(\n {\n \"BLOCK_M\": 256,\n \"BLOCK_N\": 128,\n \"waves_per_eu\": 2,\n \"PRE_LOAD_V\": False,\n },\n num_stages=1,\n num_warps=8,\n ),\n triton.Config(\n {\n \"BLOCK_M\": 128,\n \"BLOCK_N\": 64,\n \"waves_per_eu\": 1,\n \"PRE_LOAD_V\": False,\n },\n num_stages=1,\n num_warps=4,\n ),\n triton.Config(\n {\n \"BLOCK_M\": 128,\n \"BLOCK_N\": 64,\n \"waves_per_eu\": 3,\n \"PRE_LOAD_V\": True,\n },\n num_stages=1,\n num_warps=4,\n ),\n triton.Config(\n {\n \"BLOCK_M\": 128,\n \"BLOCK_N\": 64,\n \"waves_per_eu\": 3,\n \"PRE_LOAD_V\": False,\n },\n num_stages=1,\n num_warps=4,\n ),\n triton.Config(\n {\n \"BLOCK_M\": 64,\n \"BLOCK_N\": 64,\n \"waves_per_eu\": 4,\n \"PRE_LOAD_V\": False,\n },\n num_stages=1,\n num_warps=8,\n ),\n triton.Config(\n {\n \"BLOCK_M\": 32,\n \"BLOCK_N\": 32,\n \"waves_per_eu\": 4,\n \"PRE_LOAD_V\": False,\n },\n num_stages=1,\n num_warps=8,\n ),\n triton.Config(\n {\n \"BLOCK_M\": 16,\n \"BLOCK_N\": 16,\n \"waves_per_eu\": 1,\n \"PRE_LOAD_V\": False,\n },\n num_stages=1,\n num_warps=4,\n ),\n ],\n key=['IS_CAUSAL', 'dropout_p', 'BLOCK_DMODEL'],\n)\n@triton.jit\ndef attn_fwd(\n Q,\n K,\n V,\n bias,\n sm_scale,\n L,\n Out,\n stride_qz,\n stride_qh,\n stride_qm,\n stride_qk,\n stride_kz,\n stride_kh,\n stride_kn,\n stride_kk,\n stride_vz,\n stride_vh,\n stride_vk,\n stride_vn,\n stride_oz,\n stride_oh,\n stride_om,\n stride_on,\n stride_bz,\n stride_bh,\n stride_bm,\n stride_bn,\n cu_seqlens_q,\n cu_seqlens_k,\n dropout_p,\n philox_seed,\n philox_offset_base,\n encoded_softmax,\n HQ: tl.constexpr,\n HK: tl.constexpr,\n ACTUAL_BLOCK_DMODEL: tl.constexpr,\n MAX_SEQLENS_Q: tl.constexpr,\n MAX_SEQLENS_K: tl.constexpr,\n VARLEN: tl.constexpr,\n IS_CAUSAL: tl.constexpr,\n BLOCK_M: tl.constexpr,\n BLOCK_DMODEL: tl.constexpr,\n BLOCK_N: tl.constexpr,\n PRE_LOAD_V: tl.constexpr,\n BIAS_TYPE: tl.constexpr,\n ENABLE_DROPOUT: tl.constexpr,\n RETURN_ENCODED_SOFTMAX: tl.constexpr,\n):\n start_m = tl.program_id(0)\n off_h_q = tl.program_id(1)\n off_z = tl.program_id(2)\n offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)\n offs_n = tl.arange(0, BLOCK_N)\n if VARLEN:\n cu_seqlens_q_start = tl.load(cu_seqlens_q + off_z)\n cu_seqlens_q_end = tl.load(cu_seqlens_q + off_z + 1)\n seqlen_q = cu_seqlens_q_end - cu_seqlens_q_start\n if start_m * BLOCK_M > seqlen_q:\n return\n cu_seqlens_k_start = tl.load(cu_seqlens_k + off_z)\n cu_seqlens_k_end = tl.load(cu_seqlens_k + off_z + 1)\n seqlen_k = cu_seqlens_k_end - cu_seqlens_k_start\n else:\n cu_seqlens_q_start = 0\n cu_seqlens_k_start = 0\n seqlen_q = MAX_SEQLENS_Q\n seqlen_k = MAX_SEQLENS_K\n\n n_blocks = cdiv_fn(seqlen_k, BLOCK_N)\n if IS_CAUSAL:\n n_blocks_seqlen = cdiv_fn(\n (start_m + 1) * BLOCK_M + seqlen_k - seqlen_q, BLOCK_N)\n n_blocks = min(n_blocks, n_blocks_seqlen)\n if n_blocks <= 0:\n o_offset = (off_z * stride_oz + cu_seqlens_q_start * stride_om +\n off_h_q * stride_oh)\n O_block_ptr = tl.make_block_ptr(\n base=Out + o_offset,\n shape=(seqlen_q, BLOCK_DMODEL),\n strides=(stride_om, stride_on),\n offsets=(start_m * BLOCK_M, 0),\n block_shape=(BLOCK_M, BLOCK_DMODEL),\n order=(1, 0),\n )\n acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=Out.type.element_ty)\n return\n\n GROUP_SIZE: tl.constexpr = HQ // HK\n off_h_k = off_h_q // GROUP_SIZE if GROUP_SIZE != 1 else off_h_q\n\n n_extra_tokens = 0\n if seqlen_k < BLOCK_N:\n n_extra_tokens = BLOCK_N - seqlen_k\n elif seqlen_k % BLOCK_N:\n n_extra_tokens = seqlen_k % BLOCK_N\n padded_head = ACTUAL_BLOCK_DMODEL != BLOCK_DMODEL\n\n q_offset = (off_z * stride_qz + off_h_q * stride_qh +\n cu_seqlens_q_start * stride_qm)\n Q_block_ptr = tl.make_block_ptr(\n base=Q + q_offset,\n shape=(seqlen_q, ACTUAL_BLOCK_DMODEL),\n strides=(stride_qm, stride_qk),\n offsets=(start_m * BLOCK_M, 0),\n block_shape=(BLOCK_M, BLOCK_DMODEL),\n order=(1, 0),\n )\n k_offset = (off_z * stride_kz + off_h_k * stride_kh +\n cu_seqlens_k_start * stride_kn)\n K_block_ptr = tl.make_block_ptr(\n base=K + k_offset,\n shape=(ACTUAL_BLOCK_DMODEL, seqlen_k),\n strides=(stride_kk, stride_kn),\n offsets=(0, 0),\n block_shape=(BLOCK_DMODEL, BLOCK_N),\n order=(0, 1),\n )\n v_offset = (off_z * stride_vz + off_h_k * stride_vh +\n cu_seqlens_k_start * stride_vk)\n V_block_ptr = tl.make_block_ptr(\n base=V + v_offset,\n shape=(seqlen_k, ACTUAL_BLOCK_DMODEL),\n strides=(stride_vk, stride_vn),\n offsets=(0, 0),\n block_shape=(BLOCK_N, BLOCK_DMODEL),\n order=(1, 0),\n )\n if BIAS_TYPE != 0:\n bias_ptr = tl.make_block_ptr(\n base=bias + off_h_q * stride_bh,\n shape=(seqlen_q, seqlen_k),\n strides=(stride_bm, stride_bn),\n offsets=(start_m * BLOCK_M, 0),\n block_shape=(BLOCK_M, BLOCK_N),\n order=(1, 0),\n )\n else:\n bias_ptr = None\n if ENABLE_DROPOUT:\n batch_philox_offset = philox_offset_base \\\n + (off_z * HQ + off_h_q) \\\n * seqlen_q * seqlen_k\n else:\n batch_philox_offset = 0\n if RETURN_ENCODED_SOFTMAX:\n encoded_softmax_block_ptr = tl.make_block_ptr(\n base=encoded_softmax + off_h_q * seqlen_q * seqlen_k,\n shape=(seqlen_q, seqlen_k),\n strides=(seqlen_k, 1),\n offsets=(start_m * BLOCK_M, 0),\n block_shape=(BLOCK_M, BLOCK_N),\n order=(1, 0),\n )\n else:\n encoded_softmax_block_ptr = 0\n m_i = tl.full([BLOCK_M], float(\"-inf\"), dtype=tl.float32)\n l_i = tl.full([BLOCK_M], 1.0, dtype=tl.float32)\n acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)\n qk_scale = sm_scale * 1.44269504089\n q = load_fn(Q_block_ptr, True, padded_head, \"zero\")\n q = (q * qk_scale).to(Q_block_ptr.type.element_ty)\n\n padded_block_k = n_extra_tokens != 0\n is_modulo_mn = not padded_block_k and (seqlen_q % BLOCK_M == 0)\n if IS_CAUSAL:\n masked_blocks = BLOCK_M // BLOCK_N + (not is_modulo_mn)\n else:\n masked_blocks = padded_block_k\n masked_blocks = min(masked_blocks, n_blocks)\n n_full_blocks = n_blocks - masked_blocks\n block_min = 0\n block_max = n_blocks * BLOCK_N\n if n_full_blocks > 0:\n block_max = (n_blocks - masked_blocks) * BLOCK_N\n acc, l_i, m_i = _attn_fwd_inner(\n acc,\n l_i,\n m_i,\n q,\n K_block_ptr,\n V_block_ptr,\n start_m,\n seqlen_k,\n dropout_p,\n philox_seed,\n batch_philox_offset,\n encoded_softmax_block_ptr,\n block_min,\n block_max,\n 0,\n 0,\n 0,\n bias_ptr,\n False,\n BLOCK_M,\n BLOCK_DMODEL,\n BLOCK_N,\n offs_m,\n offs_n,\n PRE_LOAD_V,\n False,\n ENABLE_DROPOUT,\n RETURN_ENCODED_SOFTMAX,\n padded_head,\n )\n block_min = block_max\n block_max = n_blocks * BLOCK_N\n\n tl.debug_barrier()\n if masked_blocks > 0:\n offs_n_causal = offs_n + (seqlen_q - seqlen_k) if IS_CAUSAL else 0\n K_block_ptr = tl.advance(K_block_ptr, (0, n_full_blocks * BLOCK_N))\n V_block_ptr = tl.advance(V_block_ptr, (n_full_blocks * BLOCK_N, 0))\n if bias_ptr is not None:\n bias_ptr = tl.advance(bias_ptr, (0, n_full_blocks * BLOCK_N))\n if RETURN_ENCODED_SOFTMAX:\n encoded_softmax_block_ptr = tl.advance(encoded_softmax_block_ptr,\n (0, n_full_blocks))\n acc, l_i, m_i = _attn_fwd_inner(\n acc,\n l_i,\n m_i,\n q,\n K_block_ptr,\n V_block_ptr,\n start_m,\n seqlen_k,\n dropout_p,\n philox_seed,\n batch_philox_offset,\n encoded_softmax_block_ptr,\n block_min,\n block_max,\n offs_n_causal,\n masked_blocks,\n n_extra_tokens,\n bias_ptr,\n IS_CAUSAL,\n BLOCK_M,\n BLOCK_DMODEL,\n BLOCK_N,\n offs_m,\n offs_n,\n PRE_LOAD_V,\n True,\n ENABLE_DROPOUT,\n RETURN_ENCODED_SOFTMAX,\n padded_head,\n )\n acc = acc / l_i[:, None]\n if ENABLE_DROPOUT:\n acc = acc / (1 - dropout_p)\n end_m_idx = (start_m + 1) * BLOCK_M\n start_m_idx = start_m * BLOCK_M\n causal_start_idx = seqlen_q - seqlen_k\n acc = acc.to(Out.type.element_ty)\n if IS_CAUSAL:\n if causal_start_idx > start_m_idx and causal_start_idx < end_m_idx:\n out_mask_boundary = tl.full((BLOCK_DMODEL, ),\n causal_start_idx,\n dtype=tl.int32)\n mask_m_offsets = start_m_idx + tl.arange(0, BLOCK_M)\n out_ptrs_mask = (mask_m_offsets[:, None] >=\n out_mask_boundary[None, :])\n z = 0.0\n acc = tl.where(out_ptrs_mask, acc, z.to(acc.type.element_ty))\n o_offset = (off_z * stride_oz + cu_seqlens_q_start * stride_om +\n off_h_q * stride_oh)\n O_block_ptr = tl.make_block_ptr(\n base=Out + o_offset,\n shape=(seqlen_q, ACTUAL_BLOCK_DMODEL),\n strides=(stride_om, stride_on),\n offsets=(start_m * BLOCK_M, 0),\n block_shape=(BLOCK_M, BLOCK_DMODEL),\n order=(1, 0),\n )\n tl.store(O_block_ptr, acc, boundary_check=(0, 1))\n\n\nclass _attention(torch.autograd.Function):\n\n @staticmethod\n def forward(\n ctx,\n q,\n k,\n v,\n o,\n cu_seqlens_q,\n cu_seqlens_k,\n max_seqlens_q,\n max_seqlens_k,\n causal=False,\n sm_scale=1.0,\n bias=None,\n ):\n if o is None:\n o = torch.empty_like(q, dtype=v.dtype)\n\n check_args(\n q,\n k,\n v,\n o,\n varlen=True,\n cu_seqlens_q=cu_seqlens_q,\n cu_seqlens_k=cu_seqlens_k,\n )\n if True:\n total_q, nheads_q, head_size = q.shape\n total_k, nheads_k, _ = k.shape\n batch = len(cu_seqlens_q) - 1\n q_strides = (0, q.stride(1), q.stride(0), q.stride(2))\n k_strides = (0, k.stride(1), k.stride(0), k.stride(2))\n v_strides = (0, v.stride(1), v.stride(0), v.stride(2))\n o_strides = (0, o.stride(1), o.stride(0), o.stride(2))\n else:\n batch, seqlen_q, nheads_q, head_size = q.shape\n _, seqlen_k, nheads_k, _ = k.shape\n q_strides = (q.stride(0), q.stride(2), q.stride(1), q.stride(3))\n k_strides = (k.stride(0), k.stride(2), k.stride(1), k.stride(3))\n v_strides = (v.stride(0), v.stride(2), v.stride(1), v.stride(3))\n o_strides = (o.stride(0), o.stride(2), o.stride(1), o.stride(3))\n\n unpadded_head_dims = {32, 64, 128, 256}\n if head_size not in unpadded_head_dims:\n padded_d_model = None\n for i in unpadded_head_dims:\n if i > head_size:\n padded_d_model = i\n break\n assert padded_d_model is not None\n else:\n padded_d_model = head_size\n\n grid = lambda META: (\n triton.cdiv(max_seqlens_q, META[\"BLOCK_M\"]),\n nheads_q,\n batch,\n )\n\n encoded_softmax = None\n\n philox_seed = 0x1BF52\n philox_offset = 0x1D4B42\n\n if bias is not None:\n bias_strides = (\n bias.stride(0),\n bias.stride(1),\n bias.stride(2),\n bias.stride(3),\n )\n else:\n bias_strides = (0, 0, 0, 0)\n\n attn_fwd[grid](\n q,\n k,\n v,\n bias,\n sm_scale,\n None,\n o,\n *q_strides,\n *k_strides,\n *v_strides,\n *o_strides,\n *bias_strides,\n cu_seqlens_q,\n cu_seqlens_k,\n dropout_p=0.0,\n philox_seed=philox_seed,\n philox_offset_base=philox_offset,\n encoded_softmax=encoded_softmax,\n HQ=nheads_q,\n HK=nheads_k,\n ACTUAL_BLOCK_DMODEL=head_size,\n MAX_SEQLENS_Q=max_seqlens_q,\n MAX_SEQLENS_K=max_seqlens_k,\n IS_CAUSAL=causal,\n VARLEN=True,\n BLOCK_DMODEL=padded_d_model,\n BIAS_TYPE=0 if bias is None else 1,\n ENABLE_DROPOUT=False,\n RETURN_ENCODED_SOFTMAX=False,\n )\n\n ctx.grid = grid\n ctx.sm_scale = sm_scale\n ctx.BLOCK_DMODEL = head_size\n ctx.causal = causal\n ctx.dropout_p = 0.0\n ctx.philox_seed = philox_seed\n ctx.philox_offset = philox_offset\n ctx.encoded_softmax = encoded_softmax\n ctx.return_encoded_softmax = False\n return o, encoded_softmax\n\n\ntriton_attention = _attention.apply\n", - "description_1": "Use triton language to implement the Flash Attention v2 algorithm, defining kernels for calculating attention with support for causal masking, dropout, and configurable block sizes.", - "description_2": "Use triton language to create an optimized attention function with configurable parameters and block sizes for efficient computation.", - "difficulty": 5 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef _uniform_to_exponential_kernel(input, output, n: tl.constexpr):\n idx = tl.arange(0, n)\n x = tl.load(input + idx)\n y = _uniform_to_exponential(x)\n tl.store(output + idx, y)\n\ndef test_uniform_to_exponential():\n \"\"\"Test that we can convert uniform to exponential without div by 0.\"\"\"\n input = torch.tensor([0.0, 1.0 - torch.finfo(torch.float32).eps],\n dtype=torch.float32,\n device=\"cuda\")\n output = torch.zeros(input.shape, dtype=torch.float32, device=\"cuda\")\n _uniform_to_exponential_kernel[(1, )](input, output, 2)\n assert torch.all(torch.isfinite(output))\n assert torch.all(output > 0)\n assert torch.all(torch.isfinite(torch.full_like(output, 1.0) / output))\n", - "description_1": "Use triton language to implement a kernel that converts uniform random numbers to exponential random numbers. The kernel '_uniform_to_exponential_kernel' takes three parameters: 'input' (a pointer to the input tensor), 'output' (a pointer to the output tensor), and 'n' (a compile-time constant representing the number of elements to process). The kernel uses Triton's parallel programming model to load elements from the input tensor, apply the '_uniform_to_exponential' transformation, and store the results in the output tensor. The function 'test_uniform_to_exponential' is a test function that verifies the kernel's functionality by checking that the output values are finite and greater than zero.", - "description_2": "Use triton language to create a kernel for transforming uniform random numbers to exponential random numbers and verify its correctness with a test function.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef fused_moe_kernel(\n a_ptr, b_ptr, c_ptr, a_scale_ptr, b_scale_ptr, topk_weights_ptr,\n sorted_token_ids_ptr, expert_ids_ptr, num_tokens_post_padded_ptr,\n N, K, EM, num_valid_tokens, stride_am, stride_ak, stride_be, stride_bk,\n stride_bn, stride_cm, stride_cn,\n BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,\n GROUP_SIZE_M: tl.constexpr, MUL_ROUTED_WEIGHT: tl.constexpr, top_k: tl.constexpr,\n compute_type: tl.constexpr, use_fp8: tl.constexpr,\n):\n \"\"\"\n Implements the fused computation for a Mixture of Experts (MOE) using token and expert matrices.\n The kernel performs multiplication of a token by its corresponding expert matrix.\n \"\"\"\n pid = tl.program_id(axis=0)\n num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M)\n num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)\n num_pid_in_group = GROUP_SIZE_M * num_pid_n\n group_id = pid // num_pid_in_group\n first_pid_m = group_id * GROUP_SIZE_M\n group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)\n pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)\n pid_n = (pid % num_pid_in_group) // group_size_m\n\n num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr)\n if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded:\n return\n offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_token = tl.load(sorted_token_ids_ptr + offs_token_id)\n token_mask = offs_token < num_valid_tokens\n\n offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N\n offs_k = tl.arange(0, BLOCK_SIZE_K)\n a_ptrs = a_ptr + (offs_token[:, None] // top_k * stride_am +\n offs_k[None, :] * stride_ak)\n\n off_experts = tl.load(expert_ids_ptr + pid_m)\n b_ptrs = b_ptr + off_experts * stride_be + (offs_k[:, None] * stride_bk +\n offs_bn[None, :] * stride_bn)\n\n if use_fp8:\n a_scale = tl.load(a_scale_ptr)\n b_scale = tl.load(b_scale_ptr + off_experts)\n\n accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n\n for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):\n a = tl.load(a_ptrs,\n mask=token_mask[:, None] &\n (offs_k[None, :] < K - k * BLOCK_SIZE_K),\n other=0.0)\n b = tl.load(b_ptrs,\n mask=offs_k[:, None] < K - k * BLOCK_SIZE_K,\n other=0.0)\n if use_fp8:\n accumulator = tl.dot(a, b, acc=accumulator)\n else:\n accumulator += tl.dot(a, b)\n a_ptrs += BLOCK_SIZE_K * stride_ak\n b_ptrs += BLOCK_SIZE_K * stride_bk\n\n if MUL_ROUTED_WEIGHT:\n moe_weight = tl.load(topk_weights_ptr + offs_token,\n mask=token_mask,\n other=0)\n accumulator = accumulator * moe_weight[:, None]\n\n if use_fp8:\n accumulator = (accumulator * a_scale * b_scale).to(compute_type)\n else:\n accumulator = accumulator.to(compute_type)\n\n offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[\n None, :]\n c_mask = token_mask[:, None] & (offs_cn[None, :] < N)\n tl.store(c_ptrs, accumulator, mask=c_mask)\n\n\ndef invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,\n A_scale: Optional[torch.Tensor],\n B_scale: Optional[torch.Tensor],\n topk_weights: torch.Tensor, topk_ids: torch.Tensor,\n sorted_token_ids: torch.Tensor,\n expert_ids: torch.Tensor,\n num_tokens_post_padded: torch.Tensor,\n mul_routed_weight: bool, top_k: int,\n config: Dict[str, Any], compute_type: tl.dtype,\n use_fp8: bool) -> None:\n assert topk_weights.stride(1) == 1\n assert sorted_token_ids.stride(0) == 1\n\n if not use_fp8:\n assert A_scale is None\n assert B_scale is None\n else:\n A, A_scale = ops.scaled_fp8_quant(A, A_scale)\n assert B_scale is not None\n\n grid = lambda META: (triton.cdiv(sorted_token_ids.shape[0], META[\n 'BLOCK_SIZE_M']) * triton.cdiv(B.shape[1], META['BLOCK_SIZE_N']), )\n\n fused_moe_kernel[grid](\n A,\n B,\n C,\n A_scale,\n B_scale,\n topk_weights,\n sorted_token_ids,\n expert_ids,\n num_tokens_post_padded,\n B.shape[1],\n B.shape[2],\n sorted_token_ids.shape[0],\n topk_ids.numel(),\n A.stride(0),\n A.stride(1),\n B.stride(0),\n B.stride(2),\n B.stride(1),\n C.stride(1),\n C.stride(2),\n MUL_ROUTED_WEIGHT=mul_routed_weight,\n top_k=top_k,\n compute_type=compute_type,\n use_fp8=use_fp8,\n **config,\n )\n", - "description_1": "Use triton language to define a fused MoE kernel that implements fused computation for a Mixture of Experts using token and expert matrices. The kernel takes pointers to input matrices, scale pointers, token IDs, expert IDs, and matrix dimensions as input. It outputs computed blocks of matrix C by multiplying tokens with their respective expert matrices. The kernel is called with a function that takes care of grid settings and input validation.", - "description_2": "Use triton language to implement a Mixture of Experts kernel for multiplying input tokens with expert matrices and invoke the kernel with specific grid configurations and parameter checks.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n\ndef seeded_uniform(\n *size,\n seeds: torch.Tensor,\n out: Optional[torch.Tensor] = None,\n dtype: Optional[torch.dtype] = None,\n device: Optional[Union[torch.device, str]] = None,\n pin_memory: Optional[bool] = False,\n) -> torch.Tensor:\n \"\"\"Similar to torch.rand, but allows for seeds to be set per row.\"\"\"\n n_dims = len(size)\n\n if n_dims > 3:\n raise ValueError(\"seeded_uniform only supports up to 3D tensors\")\n\n if out is None:\n out = torch.empty(*size,\n dtype=dtype,\n device=device,\n pin_memory=pin_memory)\n elif out.shape != size:\n raise ValueError(\"shape of out and size must be the same\")\n\n if n_dims == 3:\n n_rows, n_3d, n_cols = out.shape\n stride_row = out.stride(0)\n stride_3d = out.stride(1)\n elif n_dims == 2:\n n_rows, n_cols = out.shape\n n_3d = 1\n stride_row = out.stride(0)\n stride_3d = 1\n else:\n n_cols = out.shape[0]\n n_rows = 1\n n_3d = 1\n stride_row = 1\n stride_3d = 1\n\n if seeds.ndim != 1:\n raise ValueError(\"seeds must be a 1D tensor\")\n\n if seeds.numel() != n_rows:\n raise ValueError(\n \"seeds must have the same number of elements as out has rows\")\n\n full_block_size = triton.next_power_of_2(n_cols)\n philox_block_size = max(full_block_size // 4, 1)\n n_slices = full_block_size // philox_block_size\n num_warps = 4\n\n if philox_block_size >= 8192:\n num_warps = 32\n elif philox_block_size >= 4096:\n num_warps = 16\n elif philox_block_size >= 2048:\n num_warps = 8\n\n _seeded_uniform_triton[(n_rows, n_3d)](\n out,\n seeds,\n stride_row,\n stride_3d,\n seeds.stride(0),\n n_rows,\n n_3d,\n n_cols,\n n_slices=n_slices,\n num_warps=num_warps,\n block_size=philox_block_size,\n )\n return out\n\n\n@triton.jit\ndef _seeded_uniform_triton(\n out_ptr: torch.Tensor,\n seed_ptr: torch.Tensor,\n out_row_stride: int,\n out_3d_stride: int,\n seed_row_stride: int,\n n_rows: int,\n n_3d: int,\n n_cols: int,\n n_slices: tl.constexpr,\n block_size: tl.constexpr,\n):\n \"\"\"\n Generate a random float32 number in [0, 1) for each element in the output tensor.\n \"\"\"\n tl.static_assert(n_slices > 0 and n_slices <= 4, \"0 < n_slices <= 4\")\n\n row_idx = tl.program_id(axis=0)\n three_d_idx = tl.program_id(axis=1)\n\n philox_offsets = tl.arange(0, block_size)\n\n seed = tl.load(seed_ptr + row_idx * seed_row_stride)\n if three_d_idx > 0:\n seed ^= three_d_idx\n\n out1, out2, out3, out4 = tl.rand4x(seed, philox_offsets)\n\n output_row_start_ptr = (out_ptr + row_idx * out_row_stride +\n three_d_idx * out_3d_stride)\n out1_offsets = philox_offsets\n tl.store(output_row_start_ptr + out1_offsets,\n out1,\n mask=out1_offsets < n_cols)\n if n_slices > 1:\n out2_offsets = tl.arange(block_size, block_size * 2)\n tl.store(output_row_start_ptr + out2_offsets,\n out2,\n mask=out2_offsets < n_cols)\n if n_slices > 2:\n out3_offsets = tl.arange(block_size * 2, block_size * 3)\n tl.store(output_row_start_ptr + out3_offsets,\n out3,\n mask=out3_offsets < n_cols)\n if n_slices > 3:\n out4_offsets = tl.arange(block_size * 3, block_size * 4)\n tl.store(output_row_start_ptr + out4_offsets,\n out4,\n mask=out4_offsets < n_cols)\n", - "description_1": "Use triton language to implement a random number generator that outputs a tensor with random float32 values in the range [0, 1). It involves a Triton kernel function '_seeded_uniform_triton' which accepts nine parameters. The first two are torch tensors for the output and seed, respectively. The next five parameters are integers representing strides and dimensions of the tensors. The last two parameters 'n_slices' and 'block_size' are Triton constant expressions to define block-level computations. The Triton function generates four random numbers at once using 'tl.rand4x', and stores them in slices of the tensor based on conditions defined by 'n_slices'. The main function 'seeded_uniform' sets up these parameters and calls the Triton kernel, determining the configuration based on input tensor dimensions and attributes.", - "description_2": "Use triton language to create a seeded random number generator for tensors, using customized per-row seeds and efficient block-level computations.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n_EPS = 1e-6\n\n@triton.jit\ndef _uniform_to_exponential(uniform_noise):\n \"\"\"Convert uniform samples to exponential samples.\"\"\"\n lb = tl.full(uniform_noise.shape, _EPS, uniform_noise.dtype)\n uniform_noise = tl.maximum(uniform_noise, lb)\n exponential_noise = -tl.log(uniform_noise)\n return exponential_noise\n\n@triton.jit\ndef _sample_triton(\n sample_indices_ptr: torch.Tensor, output_ptr: torch.Tensor,\n output_logprobs_ptr: torch.Tensor,\n output_modified_probs_ptr: torch.Tensor, probs_ptr: torch.Tensor,\n logprobs_ptr: torch.Tensor, seeds_ptr: torch.Tensor,\n uniform_noise_ptr: torch.Tensor, output_row_stride: int,\n probs_row_stride: int, uniform_noise_row_stride: int,\n uniform_noise_best_stride: int, n_samples: int, n_cols: int,\n n_best: int, block_size: tl.constexpr,\n modify_greedy_probs: tl.constexpr, save_logprobs: tl.constexpr,\n save_modified_probs: tl.constexpr):\n sample_idx = tl.program_id(0)\n best_idx = tl.program_id(1)\n row_idx = tl.load(sample_indices_ptr + sample_idx)\n seed = tl.load(seeds_ptr + sample_idx)\n uses_random_sampling = seed != 0\n row_start_ptr = probs_ptr + row_idx * probs_row_stride\n col_offsets = tl.arange(0, block_size)\n row = tl.load(row_start_ptr + col_offsets,\n mask=col_offsets < n_cols,\n other=float(\"-inf\"))\n if uses_random_sampling:\n uniform_noise_start_ptr = (uniform_noise_ptr +\n sample_idx * uniform_noise_row_stride +\n best_idx * uniform_noise_best_stride)\n uniform_noise = tl.load(uniform_noise_start_ptr + col_offsets,\n mask=col_offsets < n_cols,\n other=0.5)\n exponential_noise = _uniform_to_exponential(uniform_noise)\n row /= exponential_noise\n sampled_value, sampled_token = tl.max(row, axis=0, return_indices=True)\n if sampled_token >= n_cols:\n sampled_token = n_cols - 1\n output_row_start_ptr = (output_ptr + sample_idx * output_row_stride +\n best_idx)\n tl.store(output_row_start_ptr, sampled_token)\n if modify_greedy_probs:\n if not uses_random_sampling:\n row = tl.where(col_offsets == sampled_token, 1.0, 0.0)\n tl.store(row_start_ptr + col_offsets,\n row,\n mask=col_offsets < n_cols)\n if save_modified_probs:\n output_row_start_ptr = (output_modified_probs_ptr +\n sample_idx * output_row_stride + best_idx)\n tl.store(output_row_start_ptr, sampled_value)\n if save_logprobs:\n sampled_logprob = tl.load(logprobs_ptr + row_idx * probs_row_stride +\n sampled_token)\n output_row_start_ptr = (output_logprobs_ptr +\n sample_idx * output_row_stride + best_idx)\n tl.store(output_row_start_ptr, sampled_logprob)\n", - "description_1": "Use triton language to implement two kernels: _uniform_to_exponential and _sample_triton. The _uniform_to_exponential kernel takes one parameter, uniform_noise, and converts uniform samples to exponential samples using the inversion method. The _sample_triton kernel takes 18 parameters: sample_indices_ptr, output_ptr, output_logprobs_ptr, output_modified_probs_ptr, probs_ptr, logprobs_ptr, seeds_ptr, uniform_noise_ptr, output_row_stride, probs_row_stride, uniform_noise_row_stride, uniform_noise_best_stride, n_samples, n_cols, n_best, block_size, modify_greedy_probs, save_logprobs, and save_modified_probs. It samples tokens from a probability distribution, optionally modifies the distribution for greedy sampling, and saves log probabilities and modified probabilities if specified.", - "description_2": "Use triton language to create a kernel that converts uniform noise to exponential noise. Use triton language to create a kernel that samples tokens from a probability distribution with optional modifications and logging.", - "difficulty": 3 - }, - { - "code": "import triton\nimport triton.language as tl\n\nif triton.__version__ >= \"2.1.0\":\n\n @triton.jit\n def _fwd_kernel_flash_attn_v2(\n Q,\n K,\n V,\n K_cache,\n V_cache,\n B_Loc,\n sm_scale,\n B_Start_Loc,\n B_Seqlen,\n B_Ctxlen,\n block_size,\n x,\n Out,\n stride_b_loc_b,\n stride_b_loc_s,\n stride_qbs,\n stride_qh,\n stride_qd,\n stride_kbs,\n stride_kh,\n stride_kd,\n stride_vbs,\n stride_vh,\n stride_vd,\n stride_obs,\n stride_oh,\n stride_od,\n stride_k_cache_bs,\n stride_k_cache_h,\n stride_k_cache_d,\n stride_k_cache_bl,\n stride_k_cache_x,\n stride_v_cache_bs,\n stride_v_cache_h,\n stride_v_cache_d,\n stride_v_cache_bl,\n num_queries_per_kv: int,\n BLOCK_M: tl.constexpr,\n BLOCK_DMODEL: tl.constexpr,\n BLOCK_N: tl.constexpr,\n ):\n cur_batch = tl.program_id(0)\n cur_head = tl.program_id(1)\n start_m = tl.program_id(2)\n\n cur_kv_head = cur_head // num_queries_per_kv\n\n cur_batch_ctx_len = tl.load(B_Ctxlen + cur_batch)\n cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)\n cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)\n\n block_start_loc = BLOCK_M * start_m\n\n # initialize offsets\n offs_n = tl.arange(0, BLOCK_N)\n offs_d = tl.arange(0, BLOCK_DMODEL)\n offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)\n off_q = (\n (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs +\n cur_head * stride_qh + offs_d[None, :] * stride_qd)\n\n q = tl.load(\n Q + off_q,\n mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len,\n other=0.0)\n\n # initialize pointer to m and l\n m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float(\"inf\")\n l_i = tl.zeros([BLOCK_M], dtype=tl.float32)\n acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)\n\n for start_n in range(0, cur_batch_ctx_len, BLOCK_N):\n start_n = tl.multiple_of(start_n, BLOCK_N)\n # -- compute qk ----\n bn = tl.load(B_Loc + cur_batch * stride_b_loc_b +\n ((start_n + offs_n) // block_size) * stride_b_loc_s,\n mask=(start_n + offs_n) < cur_batch_ctx_len,\n other=0)\n off_k = (bn[None, :] * stride_k_cache_bs +\n cur_kv_head * stride_k_cache_h +\n (offs_d[:, None] // x) * stride_k_cache_d +\n ((start_n + offs_n[None, :]) % block_size) *\n stride_k_cache_bl +\n (offs_d[:, None] % x) * stride_k_cache_x)\n off_v = (\n bn[:, None] * stride_v_cache_bs +\n cur_kv_head * stride_v_cache_h +\n offs_d[None, :] * stride_v_cache_d +\n (start_n + offs_n[:, None]) % block_size * stride_v_cache_bl)\n k = tl.load(K_cache + off_k,\n mask=(start_n + offs_n[None, :]) < cur_batch_ctx_len,\n other=0.0)\n\n qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)\n qk += tl.dot(q, k)\n qk = tl.where((start_n + offs_n[None, :]) < cur_batch_ctx_len, qk,\n float(\"-inf\"))\n qk *= sm_scale\n\n # -- compute m_ij, p, l_ij\n m_ij = tl.max(qk, 1)\n m_i_new = tl.maximum(m_i, m_ij)\n p = tl.math.exp(qk - m_i_new[:, None])\n l_ij = tl.sum(p, 1)\n # -- update m_i and l_i\n\n alpha = tl.math.exp(m_i - m_i_new)\n l_i_new = alpha * l_i + l_ij\n # -- update output accumulator --\n # scale p\n # scale acc\n acc_scale = alpha\n # acc_scale = l_i / l_i_new * alpha\n acc = acc * acc_scale[:, None]\n # update acc\n v = tl.load(V_cache + off_v,\n mask=(start_n + offs_n[:, None]) < cur_batch_ctx_len,\n other=0.0)\n\n p = p.to(v.dtype)\n acc += tl.dot(p, v)\n # update m_i and l_i\n l_i = l_i_new\n m_i = m_i_new\n\n off_k = (offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh +\n offs_d[:, None] * stride_kd)\n off_v = (offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh +\n offs_d[None, :] * stride_vd)\n k_ptrs = K + off_k\n v_ptrs = V + off_v\n\n block_mask = tl.where(\n block_start_loc < cur_batch_seq_len - cur_batch_ctx_len, 1, 0)\n\n for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N):\n start_n = tl.multiple_of(start_n, BLOCK_N)\n # -- compute qk ----\n k = tl.load(k_ptrs +\n (cur_batch_in_all_start_index + start_n) * stride_kbs,\n mask=(start_n + offs_n[None, :]) <\n cur_batch_seq_len - cur_batch_ctx_len,\n other=0.0)\n\n qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)\n qk += tl.dot(q, k)\n qk *= sm_scale\n qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk,\n float(\"-inf\"))\n\n # -- compute m_ij, p, l_ij\n m_ij = tl.max(qk, 1)\n m_i_new = tl.maximum(m_i, m_ij)\n p = tl.math.exp(qk - m_i_new[:, None])\n l_ij = tl.sum(p, 1)\n # -- update m_i and l_i\n\n alpha = tl.math.exp(m_i - m_i_new)\n l_i_new = alpha * l_i + l_ij\n # -- update output accumulator --\n # scale p\n # scale acc\n acc_scale = alpha\n # acc_scale = l_i / l_i_new * alpha\n acc = acc * acc_scale[:, None]\n # update acc\n v = tl.load(v_ptrs +\n (cur_batch_in_all_start_index + start_n) * stride_vbs,\n mask=(start_n + offs_n[:, None]) <\n cur_batch_seq_len - cur_batch_ctx_len,\n other=0.0)\n\n p = p.to(v.dtype)\n acc += tl.dot(p, v)\n # update m_i and l_i\n l_i = l_i_new\n m_i = m_i_new\n\n # acc /= l_i[:, None]\n # initialize pointers to output\n off_o = (\n (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs +\n cur_head * stride_oh + offs_d[None, :] * stride_od)\n out_ptrs = Out + off_o\n tl.store(out_ptrs,\n acc,\n mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len)\n return\n", - "description_1": "Use triton language to implement a forward kernel for flash attention. This kernel performs batched matrix multiplications and reductions to compute the attention scores and outputs. It requires 46 parameters: Q (query), K (key), V (value), K_cache, V_cache, B_Loc (location), sm_scale (scale for softmax), B_Start_Loc, B_Seqlen, B_Ctxlen, block_size, x, Out (output), and various strides for accessing memory. It also includes num_queries_per_kv (integer) and three constexpr parameters BLOCK_M, BLOCK_DMODEL, BLOCK_N to define block sizes for computation.", - "description_2": "Use triton language to create a batched matrix multiplication and reduction kernel for flash attention with 46 input parameters including data pointers, memory strides, and block dimensions.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\nif triton.__version__ >= \"2.1.0\":\n\n @triton.jit\n def _fwd_kernel(\n Q,\n K,\n V,\n K_cache,\n V_cache,\n B_Loc,\n sm_scale,\n B_Start_Loc,\n B_Seqlen,\n B_Ctxlen,\n block_size,\n x,\n Out,\n stride_b_loc_b,\n stride_b_loc_s,\n stride_qbs,\n stride_qh,\n stride_qd,\n stride_kbs,\n stride_kh,\n stride_kd,\n stride_vbs,\n stride_vh,\n stride_vd,\n stride_obs,\n stride_oh,\n stride_od,\n stride_k_cache_bs,\n stride_k_cache_h,\n stride_k_cache_d,\n stride_k_cache_bl,\n stride_k_cache_x,\n stride_v_cache_bs,\n stride_v_cache_h,\n stride_v_cache_d,\n stride_v_cache_bl,\n num_queries_per_kv: int,\n BLOCK_M: tl.constexpr,\n BLOCK_DMODEL: tl.constexpr,\n BLOCK_DMODEL_PADDED: tl.constexpr,\n BLOCK_N: tl.constexpr,\n SLIDING_WINDOW: tl.constexpr,\n ):\n cur_batch = tl.program_id(0)\n cur_head = tl.program_id(1)\n start_m = tl.program_id(2)\n\n cur_kv_head = cur_head // num_queries_per_kv\n\n cur_batch_ctx_len = tl.load(B_Ctxlen + cur_batch)\n cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)\n cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)\n cur_batch_query_len = cur_batch_seq_len - cur_batch_ctx_len\n\n block_start_loc = BLOCK_M * start_m\n\n offs_n = tl.arange(0, BLOCK_N)\n offs_d = tl.arange(0, BLOCK_DMODEL_PADDED)\n offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)\n off_q = (\n (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs +\n cur_head * stride_qh + offs_d[None, :] * stride_qd)\n\n dim_mask = tl.where(\n tl.arange(0, BLOCK_DMODEL_PADDED) < BLOCK_DMODEL, 1,\n 0).to(tl.int1)\n\n q = tl.load(Q + off_q,\n mask=dim_mask[None, :] &\n (offs_m[:, None] < cur_batch_query_len),\n other=0.0)\n\n m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float(\"inf\")\n l_i = tl.zeros([BLOCK_M], dtype=tl.float32)\n acc = tl.zeros([BLOCK_M, BLOCK_DMODEL_PADDED],\n dtype=tl.float32)\n\n for start_n in range(0, cur_batch_ctx_len, BLOCK_N):\n start_n = tl.multiple_of(start_n, BLOCK_N)\n bn = tl.load(B_Loc + cur_batch * stride_b_loc_b +\n ((start_n + offs_n) // block_size) * stride_b_loc_s,\n mask=(start_n + offs_n) < cur_batch_ctx_len,\n other=0)\n off_k = (bn[None, :] * stride_k_cache_bs +\n cur_kv_head * stride_k_cache_h +\n (offs_d[:, None] // x) * stride_k_cache_d +\n ((start_n + offs_n[None, :]) % block_size) *\n stride_k_cache_bl +\n (offs_d[:, None] % x) * stride_k_cache_x)\n off_v = (\n bn[:, None] * stride_v_cache_bs +\n cur_kv_head * stride_v_cache_h +\n offs_d[None, :] * stride_v_cache_d +\n (start_n + offs_n[:, None]) % block_size * stride_v_cache_bl)\n k = tl.load(K_cache + off_k,\n mask=dim_mask[:, None] &\n ((start_n + offs_n[None, :]) < cur_batch_ctx_len),\n other=0.0)\n\n qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)\n qk += tl.dot(q, k)\n qk = tl.where((start_n + offs_n[None, :]) < cur_batch_ctx_len, qk,\n float(\"-inf\"))\n qk *= sm_scale\n if SLIDING_WINDOW > 0:\n qk = tl.where((cur_batch_ctx_len + offs_m[:, None]) -\n (start_n + offs_n[None, :]) < SLIDING_WINDOW, qk,\n -10000)\n\n m_ij = tl.max(qk, 1)\n p = tl.exp(qk - m_ij[:, None])\n l_ij = tl.sum(p, 1)\n m_i_new = tl.maximum(m_i, m_ij)\n alpha = tl.exp(m_i - m_i_new)\n beta = tl.exp(m_ij - m_i_new)\n l_i_new = alpha * l_i + beta * l_ij\n\n p_scale = beta / l_i_new\n p = p * p_scale[:, None]\n acc_scale = l_i / l_i_new * alpha\n acc = acc * acc_scale[:, None]\n v = tl.load(V_cache + off_v,\n mask=dim_mask[None, :] &\n ((start_n + offs_n[:, None]) < cur_batch_ctx_len),\n other=0.0)\n\n p = p.to(v.dtype)\n acc += tl.dot(p, v)\n l_i = l_i_new\n m_i = m_i_new\n\n off_k = (offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh +\n offs_d[:, None] * stride_kd)\n off_v = (offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh +\n offs_d[None, :] * stride_vd)\n k_ptrs = K + off_k\n v_ptrs = V + off_v\n\n block_mask = tl.where(block_start_loc < cur_batch_query_len, 1, 0)\n\n for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N):\n start_n = tl.multiple_of(start_n, BLOCK_N)\n k = tl.load(k_ptrs +\n (cur_batch_in_all_start_index + start_n) * stride_kbs,\n mask=dim_mask[:, None] &\n ((start_n + offs_n[None, :]) < cur_batch_query_len),\n other=0.0)\n\n qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)\n qk += tl.dot(q, k)\n qk *= sm_scale\n qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk,\n float(\"-inf\"))\n if SLIDING_WINDOW > 0:\n qk = tl.where(\n offs_m[:, None] -\n (start_n + offs_n[None, :]) < SLIDING_WINDOW, qk, -10000)\n\n m_ij = tl.max(qk, 1)\n p = tl.exp(qk - m_ij[:, None])\n l_ij = tl.sum(p, 1)\n m_i_new = tl.maximum(m_i, m_ij)\n alpha = tl.exp(m_i - m_i_new)\n beta = tl.exp(m_ij - m_i_new)\n l_i_new = alpha * l_i + beta * l_ij\n\n p_scale = beta / l_i_new\n p = p * p_scale[:, None]\n acc_scale = l_i / l_i_new * alpha\n acc = acc * acc_scale[:, None]\n v = tl.load(v_ptrs +\n (cur_batch_in_all_start_index + start_n) * stride_vbs,\n mask=dim_mask[None, :] &\n ((start_n + offs_n[:, None]) < cur_batch_query_len),\n other=0.0)\n\n p = p.to(v.dtype)\n acc += tl.dot(p, v)\n l_i = l_i_new\n m_i = m_i_new\n\n off_o = (\n (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs +\n cur_head * stride_oh + offs_d[None, :] * stride_od)\n out_ptrs = Out + off_o\n tl.store(out_ptrs,\n acc,\n mask=dim_mask[None, :] &\n (offs_m[:, None] < cur_batch_query_len))\n return\n\n @torch.inference_mode()\n def context_attention_fwd(q,\n k,\n v,\n o,\n k_cache,\n v_cache,\n b_loc,\n b_start_loc,\n b_seq_len,\n b_ctx_len,\n max_input_len,\n sliding_window=None):\n\n BLOCK = 128\n\n if q.dtype is torch.float32:\n BLOCK = BLOCK // 2\n\n Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]\n assert Lq == Lk and Lk == Lv\n Lk_padded = triton.next_power_of_2(Lk)\n\n sm_scale = 1.0 / (Lq**0.5)\n batch, head = b_seq_len.shape[0], q.shape[1]\n num_queries_per_kv = q.shape[1] // k.shape[1]\n\n grid = (batch, head, triton.cdiv(max_input_len, BLOCK))\n\n if sliding_window is None or sliding_window <= 0:\n sliding_window = 0\n\n num_warps = 8 if Lk <= 64 else 8\n _fwd_kernel[grid](\n q,\n k,\n v,\n k_cache,\n v_cache,\n b_loc,\n sm_scale,\n b_start_loc,\n b_seq_len,\n b_ctx_len,\n v_cache.shape[3],\n k_cache.shape[4],\n o,\n b_loc.stride(0),\n b_loc.stride(1),\n q.stride(0),\n q.stride(1),\n q.stride(2),\n k.stride(0),\n k.stride(1),\n k.stride(2),\n v.stride(0),\n v.stride(1),\n v.stride(2),\n o.stride(0),\n o.stride(1),\n o.stride(2),\n k_cache.stride(0),\n k_cache.stride(1),\n k_cache.stride(2),\n k_cache.stride(3),\n k_cache.stride(4),\n v_cache.stride(0),\n v_cache.stride(1),\n v_cache.stride(2),\n v_cache.stride(3),\n num_queries_per_kv=num_queries_per_kv,\n BLOCK_M=BLOCK,\n BLOCK_DMODEL=Lk,\n BLOCK_DMODEL_PADDED=Lk_padded,\n BLOCK_N=BLOCK,\n SLIDING_WINDOW=sliding_window,\n num_warps=num_warps,\n num_stages=1,\n )\n return\n", - "description_1": "Use triton language to implement a forward kernel for context attention. The kernel takes 43 parameters: Q, K, V, K_cache, V_cache, B_Loc, sm_scale, B_Start_Loc, B_Seqlen, B_Ctxlen, block_size, x, Out, and various strides and constants. It computes the attention scores and updates the output tensor. The context_attention_fwd function wraps this kernel, taking 12 parameters: q, k, v, o, k_cache, v_cache, b_loc, b_start_loc, b_seq_len, b_ctx_len, max_input_len, and sliding_window. It sets up the grid and block sizes, and calls the kernel.", - "description_2": "Use triton language to create a context attention forward kernel and a wrapper function. The kernel computes attention scores using 43 parameters, while the wrapper function sets up execution parameters and calls the kernel with 12 parameters.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef _fwd_kernel_alibi(\n Q,\n K,\n V,\n K_cache,\n V_cache,\n B_Loc,\n sm_scale,\n B_Start_Loc,\n B_Seqlen,\n B_Ctxlen,\n Alibi_slopes,\n block_size,\n x,\n Out,\n stride_b_loc_b,\n stride_b_loc_s,\n stride_qbs,\n stride_qh,\n stride_qd,\n stride_kbs,\n stride_kh,\n stride_kd,\n stride_vbs,\n stride_vh,\n stride_vd,\n stride_obs,\n stride_oh,\n stride_od,\n stride_k_cache_bs,\n stride_k_cache_h,\n stride_k_cache_d,\n stride_k_cache_bl,\n stride_k_cache_x,\n stride_v_cache_bs,\n stride_v_cache_h,\n stride_v_cache_d,\n stride_v_cache_bl,\n num_queries_per_kv: int,\n BLOCK_M: tl.constexpr,\n BLOCK_DMODEL: tl.constexpr,\n BLOCK_DMODEL_PADDED: tl.constexpr,\n BLOCK_N: tl.constexpr,\n):\n cur_batch = tl.program_id(0)\n cur_head = tl.program_id(1)\n start_m = tl.program_id(2)\n\n cur_kv_head = cur_head // num_queries_per_kv\n cur_batch_ctx_len = tl.load(B_Ctxlen + cur_batch)\n cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)\n cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)\n\n block_start_loc = BLOCK_M * start_m\n\n offs_n = tl.arange(0, BLOCK_N)\n offs_d = tl.arange(0, BLOCK_DMODEL_PADDED)\n offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)\n off_q = (\n (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs +\n cur_head * stride_qh + offs_d[None, :] * stride_qd)\n\n dim_mask = tl.where(\n tl.arange(0, BLOCK_DMODEL_PADDED) < BLOCK_DMODEL, 1, 0).to(tl.int1)\n\n q = tl.load(Q + off_q,\n mask=dim_mask[None, :] &\n (offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len),\n other=0.0)\n\n m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float(\"inf\")\n l_i = tl.zeros([BLOCK_M], dtype=tl.float32)\n acc = tl.zeros([BLOCK_M, BLOCK_DMODEL_PADDED], dtype=tl.float32)\n\n alibi_slope = tl.load(Alibi_slopes + cur_head)\n alibi_start_q = tl.arange(\n 0, BLOCK_M) + block_start_loc + cur_batch_ctx_len\n alibi_start_k = 0\n for start_n in range(0, cur_batch_ctx_len, BLOCK_N):\n start_n = tl.multiple_of(start_n, BLOCK_N)\n bn = tl.load(B_Loc + cur_batch * stride_b_loc_b +\n ((start_n + offs_n) // block_size) * stride_b_loc_s,\n mask=(start_n + offs_n) < cur_batch_ctx_len,\n other=0)\n off_k = (bn[None, :] * stride_k_cache_bs +\n cur_kv_head * stride_k_cache_h +\n (offs_d[:, None] // x) * stride_k_cache_d +\n ((start_n + offs_n[None, :]) % block_size) *\n stride_k_cache_bl +\n (offs_d[:, None] % x) * stride_k_cache_x)\n off_v = (\n bn[:, None] * stride_v_cache_bs +\n cur_kv_head * stride_v_cache_h +\n offs_d[None, :] * stride_v_cache_d +\n (start_n + offs_n[:, None]) % block_size * stride_v_cache_bl)\n k = tl.load(K_cache + off_k,\n mask=dim_mask[:, None] &\n ((start_n + offs_n[None, :]) < cur_batch_ctx_len),\n other=0.0)\n\n qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)\n qk += tl.dot(q, k)\n qk = tl.where((start_n + offs_n[None, :]) < cur_batch_ctx_len, qk,\n float(\"-inf\"))\n qk *= sm_scale\n\n alibi = (tl.arange(0, BLOCK_N)[None, :] + alibi_start_k -\n alibi_start_q[:, None]) * alibi_slope\n alibi = tl.where(\n (alibi <= 0) & (alibi_start_q[:, None] < cur_batch_seq_len),\n alibi, float(\"-inf\"))\n qk += alibi\n alibi_start_k += BLOCK_N\n\n m_ij = tl.max(qk, 1)\n m_i_new = tl.maximum(m_i, m_ij)\n p = tl.math.exp(qk - m_i_new[:, None])\n l_ij = tl.sum(p, 1)\n\n alpha = tl.math.exp(m_i - m_i_new)\n l_i_new = alpha * l_i + l_ij\n\n acc_scale = alpha\n acc = acc * acc_scale[:, None]\n\n v = tl.load(V_cache + off_v,\n mask=dim_mask[None, :] &\n ((start_n + offs_n[:, None]) < cur_batch_ctx_len),\n other=0.0)\n\n p = p.to(v.dtype)\n acc += tl.dot(p, v, allow_tf32=False)\n\n l_i = l_i_new\n m_i = m_i_new\n\n off_k = (offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh +\n offs_d[:, None] * stride_kd)\n off_v = (offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh +\n offs_d[None, :] * stride_vd)\n k_ptrs = K + off_k\n v_ptrs = V + off_v\n\n block_mask = tl.where(\n block_start_loc < cur_batch_seq_len - cur_batch_ctx_len, 1, 0)\n\n alibi_slope = tl.load(Alibi_slopes + cur_head)\n alibi_start_q = tl.arange(\n 0, BLOCK_M) + block_start_loc + cur_batch_ctx_len\n alibi_start_k = cur_batch_ctx_len\n for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N):\n start_n = tl.multiple_of(start_n, BLOCK_N)\n k = tl.load(k_ptrs +\n (cur_batch_in_all_start_index + start_n) * stride_kbs,\n mask=dim_mask[:, None] &\n ((start_n + offs_n[None, :]) <\n cur_batch_seq_len - cur_batch_ctx_len),\n other=0.0)\n\n qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)\n qk += tl.dot(q, k, allow_tf32=False)\n qk *= sm_scale\n qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk,\n float(\"-inf\"))\n\n alibi = (tl.arange(0, BLOCK_N)[None, :] + alibi_start_k -\n alibi_start_q[:, None]) * alibi_slope\n alibi = tl.where(\n (alibi <= 0) & (alibi_start_q[:, None] < cur_batch_seq_len),\n alibi, float(\"-inf\"))\n qk += alibi\n alibi_start_k += BLOCK_N\n\n m_ij = tl.max(qk, 1)\n m_i_new = tl.maximum(m_i, m_ij)\n p = tl.math.exp(qk - m_i_new[:, None])\n l_ij = tl.sum(p, 1)\n\n alpha = tl.math.exp(m_i - m_i_new)\n l_i_new = alpha * l_i + l_ij\n\n acc_scale = alpha\n acc = acc * acc_scale[:, None]\n\n v = tl.load(v_ptrs +\n (cur_batch_in_all_start_index + start_n) * stride_vbs,\n mask=dim_mask[None, :] &\n ((start_n + offs_n[:, None]) <\n cur_batch_seq_len - cur_batch_ctx_len),\n other=0.0)\n\n p = p.to(v.dtype)\n acc += tl.dot(p, v, allow_tf32=False)\n\n l_i = l_i_new\n m_i = m_i_new\n\n acc = acc / l_i[:, None]\n\n off_o = (\n (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs +\n cur_head * stride_oh + offs_d[None, :] * stride_od)\n out_ptrs = Out + off_o\n tl.store(out_ptrs,\n acc,\n mask=dim_mask[None, :] &\n (offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len))\n return\n\n@torch.inference_mode()\ndef context_attention_fwd_alibi(\n q,\n k,\n v,\n o,\n k_cache,\n v_cache,\n b_loc,\n b_start_loc,\n b_seq_len,\n b_ctx_len,\n max_input_len,\n alibi_slopes=None):\n\n BLOCK = 128 \n\n if q.dtype is torch.float32:\n BLOCK = BLOCK // 2\n\n Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]\n assert Lq == Lk and Lk == Lv\n Lk_padded = triton.next_power_of_2(Lk)\n\n sm_scale = 1.0 / (Lq**0.5)\n batch, head = b_seq_len.shape[0], q.shape[1]\n num_queries_per_kv = q.shape[1] // k.shape[1]\n\n grid = (batch, head, triton.cdiv(max_input_len, BLOCK))\n\n num_warps = 8 if Lk <= 64 else 8\n _fwd_kernel_alibi[grid](\n q,\n k,\n v,\n k_cache,\n v_cache,\n b_loc,\n sm_scale,\n b_start_loc,\n b_seq_len,\n b_ctx_len,\n alibi_slopes,\n v_cache.shape[3],\n k_cache.shape[4],\n o,\n b_loc.stride(0),\n b_loc.stride(1),\n q.stride(0),\n q.stride(1),\n q.stride(2),\n k.stride(0),\n k.stride(1),\n k.stride(2),\n v.stride(0),\n v.stride(1),\n v.stride(2),\n o.stride(0),\n o.stride(1),\n o.stride(2),\n k_cache.stride(0),\n k_cache.stride(1),\n k_cache.stride(2),\n k_cache.stride(3),\n k_cache.stride(4),\n v_cache.stride(0),\n v_cache.stride(1),\n v_cache.stride(2),\n v_cache.stride(3),\n num_queries_per_kv=num_queries_per_kv,\n BLOCK_M=BLOCK,\n BLOCK_DMODEL=Lk,\n BLOCK_DMODEL_PADDED=Lk_padded,\n BLOCK_N=BLOCK,\n num_warps=num_warps,\n num_stages=1,\n )\n return\n", - "description_1": "Use triton language to implement a forward kernel for attention with ALiBi (Attention with Linear Bias) scaling. The kernel (_fwd_kernel_alibi) takes 37 tensor arguments representing different inputs and memory cache along with several stride values and 4 constexpr parameters. The context_attention_fwd_alibi function prepares the inputs and launches the kernel with grid settings based on the input dimensions and types.", - "description_2": "Use triton language to implement and launch a Triton kernel for ALiBi attention computation with specific grid settings.", - "difficulty": 4 - }, - { - "code": "import math\nimport torch\nimport triton_pre_mlir as triton\nimport triton_pre_mlir.language as tl\n\n@triton.heuristics(\n {\n \"EVEN_M\": lambda args: args[\"seqlen_q\"] % args[\"BLOCK_M\"] == 0,\n \"EVEN_N\": lambda args: args[\"seqlen_k\"] % args[\"BLOCK_N\"] == 0,\n \"EVEN_HEADDIM\": lambda args: args[\"headdim\"] == args[\"BLOCK_HEADDIM\"],\n }\n)\n@triton.jit\ndef _fwd_kernel(\n Q, K, V, Bias, Out, Lse, TMP, softmax_scale,\n stride_qb, stride_qh, stride_qm, stride_kb, stride_kh, stride_kn,\n stride_vb, stride_vh, stride_vn, stride_bb, stride_bh, stride_bm,\n stride_ob, stride_oh, stride_om, nheads, seqlen_q, seqlen_k,\n seqlen_q_rounded, headdim, CACHE_KEY_SEQLEN_Q, CACHE_KEY_SEQLEN_K,\n BIAS_TYPE: tl.constexpr, IS_CAUSAL: tl.constexpr, BLOCK_HEADDIM: tl.constexpr,\n EVEN_M: tl.constexpr, EVEN_N: tl.constexpr, EVEN_HEADDIM: tl.constexpr,\n BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,\n):\n # Triton kernel implementation for forward pass of FlashAttention\n\n@triton.jit\ndef _bwd_preprocess_do_o_dot(\n Out, DO, Delta, stride_ob, stride_oh, stride_om,\n stride_dob, stride_doh, stride_dom, nheads, seqlen_q,\n seqlen_q_rounded, headdim, BLOCK_M: tl.constexpr, BLOCK_HEADDIM: tl.constexpr,\n):\n # Triton kernel for preprocessing in backward pass\n\n@triton.jit\ndef _bwd_store_dk_dv(dk_ptrs, dv_ptrs, dk, dv, offs_n, offs_d, seqlen_k, headdim, EVEN_M: tl.constexpr, EVEN_N: tl.constexpr, EVEN_HEADDIM: tl.constexpr):\n # Triton kernel for storing gradients of K and V\n\n@triton.jit\ndef _bwd_kernel_one_col_block(\n start_n, Q, K, V, Bias, DO, DQ, DK, DV, LSE, D,\n softmax_scale, stride_qm, stride_kn, stride_vn, stride_bm,\n stride_dom, stride_dqm, stride_dkn, stride_dvn, seqlen_q,\n seqlen_k, headdim, ATOMIC_ADD: tl.constexpr, BIAS_TYPE: tl.constexpr,\n IS_CAUSAL: tl.constexpr, BLOCK_HEADDIM: tl.constexpr, EVEN_M: tl.constexpr,\n EVEN_N: tl.constexpr, EVEN_HEADDIM: tl.constexpr, BLOCK_M: tl.constexpr,\n BLOCK_N: tl.constexpr,\n):\n # Triton kernel for backward pass processing one column block\n\n@triton.autotune(\n configs=[\n triton.Config({\"BLOCK_M\": 128, \"BLOCK_N\": 128, \"SEQUENCE_PARALLEL\": False}, num_warps=8, num_stages=1, pre_hook=init_to_zero(\"DQ\")),\n triton.Config({\"BLOCK_M\": 128, \"BLOCK_N\": 128, \"SEQUENCE_PARALLEL\": True}, num_warps=8, num_stages=1, pre_hook=init_to_zero(\"DQ\")),\n ],\n key=[\"CACHE_KEY_SEQLEN_Q\", \"CACHE_KEY_SEQLEN_K\", \"BIAS_TYPE\", \"IS_CAUSAL\", \"BLOCK_HEADDIM\"],\n)\n@triton.heuristics(\n {\n \"EVEN_M\": lambda args: args[\"seqlen_q\"] % args[\"BLOCK_M\"] == 0,\n \"EVEN_N\": lambda args: args[\"seqlen_k\"] % args[\"BLOCK_N\"] == 0,\n \"EVEN_HEADDIM\": lambda args: args[\"headdim\"] == args[\"BLOCK_HEADDIM\"],\n }\n)\n@triton.jit\ndef _bwd_kernel(\n Q, K, V, Bias, DO, DQ, DK, DV, LSE, D, softmax_scale,\n stride_qb, stride_qh, stride_qm, stride_kb, stride_kh, stride_kn,\n stride_vb, stride_vh, stride_vn, stride_bb, stride_bh, stride_bm,\n stride_dob, stride_doh, stride_dom, stride_dqb, stride_dqh, stride_dqm,\n stride_dkb, stride_dkh, stride_dkn, stride_dvb, stride_dvh, stride_dvn,\n nheads, seqlen_q, seqlen_k, seqlen_q_rounded, headdim,\n CACHE_KEY_SEQLEN_Q, CACHE_KEY_SEQLEN_K, BIAS_TYPE: tl.constexpr,\n IS_CAUSAL: tl.constexpr, BLOCK_HEADDIM: tl.constexpr, SEQUENCE_PARALLEL: tl.constexpr,\n EVEN_M: tl.constexpr, EVEN_N: tl.constexpr, EVEN_HEADDIM: tl.constexpr,\n BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,\n):\n # Triton kernel for backward pass of FlashAttention\n\ndef _flash_attn_forward(q, k, v, bias=None, causal=False, softmax_scale=None):\n # Function to call the forward Triton kernel\n\ndef _flash_attn_backward(do, q, k, v, o, lse, dq, dk, dv, bias=None, causal=False, softmax_scale=None):\n # Function to call the backward Triton kernel\n", - "description_1": "Use triton language to implement a FlashAttention mechanism with forward and backward passes. The forward kernel (_fwd_kernel) takes 28 parameters including Q, K, V matrices, bias, output, and other configurations. The backward kernel (_bwd_kernel) takes 42 parameters including gradients, input matrices, and configurations. The kernels handle both causal and non-causal attention, support attention bias, and optimize for different head dimensions and sequence lengths.", - "description_2": "Use triton language to implement a FlashAttention mechanism with forward and backward passes, supporting causal and non-causal attention, attention bias, and optimized for various head dimensions and sequence lengths.", - "difficulty": 4 - }, - { - "code": "import math\nimport torch\nimport triton_pre_mlir as triton\nimport triton_pre_mlir.language as tl\n\n\n@triton.jit\ndef _fwd_kernel(\n Q,\n K,\n V,\n Bias,\n Out,\n Lse,\n TMP,\n softmax_scale,\n stride_qb,\n stride_qh,\n stride_qm,\n stride_kb,\n stride_kh,\n stride_kn,\n stride_vb,\n stride_vh,\n stride_vn,\n stride_bb,\n stride_bh,\n stride_bm,\n stride_ob,\n stride_oh,\n stride_om,\n nheads,\n seqlen_q,\n seqlen_k,\n seqlen_q_rounded,\n headdim,\n CACHE_KEY_SEQLEN_Q,\n CACHE_KEY_SEQLEN_K,\n BIAS_TYPE: tl.constexpr,\n IS_CAUSAL: tl.constexpr,\n BLOCK_HEADDIM: tl.constexpr,\n EVEN_M: tl.constexpr,\n EVEN_N: tl.constexpr,\n EVEN_HEADDIM: tl.constexpr,\n BLOCK_M: tl.constexpr,\n BLOCK_N: tl.constexpr,\n):\n # Function implementation...\n pass\n\n\n@triton.jit\ndef _bwd_preprocess_do_o_dot(\n Out,\n DO,\n Delta,\n stride_ob,\n stride_oh,\n stride_om,\n stride_dob,\n stride_doh,\n stride_dom,\n nheads,\n seqlen_q,\n seqlen_q_rounded,\n headdim,\n BLOCK_M: tl.constexpr,\n BLOCK_HEADDIM: tl.constexpr,\n):\n # Function implementation...\n pass\n\n\n@triton.jit\ndef _bwd_store_dk_dv(dk_ptrs, dv_ptrs, dk, dv, offs_n, offs_d, seqlen_k, headdim, EVEN_M: tl.constexpr, EVEN_N: tl.constexpr, EVEN_HEADDIM: tl.constexpr):\n # Function implementation...\n pass\n\n\n@triton.jit\ndef _bwd_kernel_one_col_block(\n start_n,\n Q,\n K,\n V,\n Bias,\n DO,\n DQ,\n DK,\n DV,\n LSE,\n D,\n softmax_scale,\n stride_qm,\n stride_kn,\n stride_vn,\n stride_bm,\n stride_dom,\n stride_dqm,\n stride_dkn,\n stride_dvn,\n seqlen_q,\n seqlen_k,\n headdim,\n ATOMIC_ADD: tl.constexpr,\n BIAS_TYPE: tl.constexpr,\n IS_CAUSAL: tl.constexpr,\n BLOCK_HEADDIM: tl.constexpr,\n EVEN_M: tl.constexpr,\n EVEN_N: tl.constexpr,\n EVEN_HEADDIM: tl.constexpr,\n BLOCK_M: tl.constexpr,\n BLOCK_N: tl.constexpr,\n):\n # Function implementation...\n pass\n\n\n@triton.autotune(\n configs=[\n triton.Config({\"BLOCK_M\": 128, \"BLOCK_N\": 128, \"SEQUENCE_PARALLEL\": False}, num_warps=8, num_stages=1, pre_hook=init_to_zero(\"DQ\")),\n triton.Config({\"BLOCK_M\": 128, \"BLOCK_N\": 128, \"SEQUENCE_PARALLEL\": True}, num_warps=8, num_stages=1, pre_hook=init_to_zero(\"DQ\")),\n ],\n key=[\"CACHE_KEY_SEQLEN_Q\", \"CACHE_KEY_SEQLEN_K\", \"BIAS_TYPE\", \"IS_CAUSAL\", \"BLOCK_HEADDIM\"],\n)\n@triton.heuristics(\n {\n \"EVEN_M\": lambda args: args[\"seqlen_q\"] % args[\"BLOCK_M\"] == 0,\n \"EVEN_N\": lambda args: args[\"seqlen_k\"] % args[\"BLOCK_N\"] == 0,\n \"EVEN_HEADDIM\": lambda args: args[\"headdim\"] == args[\"BLOCK_HEADDIM\"],\n }\n)\n@triton.jit\ndef _bwd_kernel(\n Q,\n K,\n V,\n Bias,\n DO,\n DQ,\n DK,\n DV,\n LSE,\n D,\n softmax_scale,\n stride_qb,\n stride_qh,\n stride_qm,\n stride_kb,\n stride_kh,\n stride_kn,\n stride_vb,\n stride_vh,\n stride_vn,\n stride_bb,\n stride_bh,\n stride_bm,\n stride_dob,\n stride_doh,\n stride_dom,\n stride_dqb,\n stride_dqh,\n stride_dqm,\n stride_dkb,\n stride_dkh,\n stride_dkn,\n stride_dvb,\n stride_dvh,\n stride_dvn,\n nheads,\n seqlen_q,\n seqlen_k,\n seqlen_q_rounded,\n headdim,\n CACHE_KEY_SEQLEN_Q,\n CACHE_KEY_SEQLEN_K,\n BIAS_TYPE: tl.constexpr,\n IS_CAUSAL: tl.constexpr,\n BLOCK_HEADDIM: tl.constexpr,\n SEQUENCE_PARALLEL: tl.constexpr,\n EVEN_M: tl.constexpr,\n EVEN_N: tl.constexpr,\n EVEN_HEADDIM: tl.constexpr,\n BLOCK_M: tl.constexpr,\n BLOCK_N: tl.constexpr,\n):\n # Function implementation...\n pass\n\n\ndef _flash_attn_forward(q, k, v, bias=None, causal=False, softmax_scale=None):\n (batch, seqlen_q, nheads, d) = q.shape\n (_, seqlen_k, _, _) = k.shape\n assert k.shape == (batch, seqlen_k, nheads, d)\n assert v.shape == (batch, seqlen_k, nheads, d)\n assert d <= 128, \"FlashAttention only support head dimensions up to 128\"\n assert q.dtype == k.dtype == v.dtype, \"All tensors must have the same type\"\n assert q.dtype in [torch.float16, torch.bfloat16], \"Only support fp16 and bf16\"\n assert q.is_cuda and k.is_cuda and v.is_cuda\n softmax_scale = softmax_scale or 1.0 / math.sqrt(d)\n has_bias = bias is not None\n bias_type = \"none\"\n if has_bias:\n assert bias.dtype in [q.dtype, torch.float]\n assert bias.is_cuda\n assert bias.dim() == 4\n if bias.stride(-1) != 1:\n bias = bias.contiguous()\n if bias.shape[2:] == (1, seqlen_k):\n bias_type = \"vector\"\n elif bias.shape[2:] == (seqlen_q, seqlen_k):\n bias_type = \"matrix\"\n else:\n raise RuntimeError(\"Last 2 dimensions of bias must be (1, seqlen_k) or (seqlen_q, seqlen_k)\")\n bias = bias.expand(batch, nheads, seqlen_q, seqlen_k)\n bias_strides = (bias.stride(0), bias.stride(1), bias.stride(2)) if has_bias else (0, 0, 0)\n seqlen_q_rounded = math.ceil(seqlen_q / 128) * 128\n lse = torch.empty((batch, nheads, seqlen_q_rounded), device=q.device, dtype=torch.float32)\n tmp = torch.empty((batch, nheads, seqlen_q_rounded), device=q.device, dtype=torch.float32)\n o = torch.empty_like(q)\n BLOCK_HEADDIM = max(triton.next_power_of_2(d), 16)\n BLOCK = 128\n num_warps = 4 if d <= 64 else 8\n grid = lambda META: (triton.cdiv(seqlen_q, META[\"BLOCK_M\"]), batch * nheads)\n _fwd_kernel[grid](\n q,\n k,\n v,\n bias,\n o,\n lse,\n tmp,\n softmax_scale,\n q.stride(0),\n q.stride(2),\n q.stride(1),\n k.stride(0),\n k.stride(2),\n k.stride(1),\n v.stride(0),\n v.stride(2),\n v.stride(1),\n *bias_strides,\n o.stride(0),\n o.stride(2),\n o.stride(1),\n nheads,\n seqlen_q,\n seqlen_k,\n seqlen_q_rounded,\n d,\n seqlen_q // 32,\n seqlen_k // 32,\n bias_type,\n causal,\n BLOCK_HEADDIM,\n BLOCK_M=BLOCK,\n BLOCK_N=BLOCK,\n num_warps=num_warps,\n num_stages=1\n )\n return (o, lse, softmax_scale)\n\n\ndef _flash_attn_backward(do, q, k, v, o, lse, dq, dk, dv, bias=None, causal=False, softmax_scale=None):\n if do.stride(-1) != 1:\n do = do.contiguous()\n (batch, seqlen_q, nheads, d) = q.shape\n (_, seqlen_k, _, _) = k.shape\n assert d <= 128\n seqlen_q_rounded = math.ceil(seqlen_q / 128) * 128\n assert lse.shape == (batch, nheads, seqlen_q_rounded)\n assert q.stride(-1) == k.stride(-1) == v.stride(-1) == o.stride(-1) == 1\n assert dq.stride(-1) == dk.stride(-1) == dv.stride(-1) == 1\n softmax_scale = softmax_scale or 1.0 / math.sqrt(d)\n dq_accum = torch.empty_like(q, dtype=torch.float32)\n delta = torch.empty_like(lse)\n BLOCK_HEADDIM = max(triton.next_power_of_2(d), 16)\n grid = lambda META: (triton.cdiv(seqlen_q, META[\"BLOCK_M\"]), batch * nheads)\n _bwd_preprocess_do_o_dot[grid](\n o,\n do,\n delta,\n o.stride(0),\n o.stride(2),\n o.stride(1),\n do.stride(0),\n do.stride(2),\n do.stride(1),\n nheads,\n seqlen_q,\n seqlen_q_rounded,\n d,\n BLOCK_M=128,\n BLOCK_HEADDIM=BLOCK_HEADDIM,\n )\n has_bias = bias is not None\n bias_type = \"none\"\n if has_bias:\n assert bias.dtype in [q.dtype, torch.float]\n assert bias.is_cuda\n assert bias.dim() == 4\n assert bias.stride(-1) == 1\n if bias.shape[2:] == (1, seqlen_k):\n bias_type = \"vector\"\n elif bias.shape[2:] == (seqlen_q, seqlen_k):\n bias_type = \"matrix\"\n else:\n raise RuntimeError(\"Last 2 dimensions of bias must be (1, seqlen_k) or (seqlen_q, seqlen_k)\")\n bias = bias.expand(batch, nheads, seqlen_q, seqlen_k)\n bias_strides = (bias.stride(0), bias.stride(1), bias.stride(2)) if has_bias else (0, 0, 0)\n grid = lambda META: (triton.cdiv(seqlen_k, META[\"BLOCK_N\"]) if META[\"SEQUENCE_PARALLEL\"] else 1, batch * nheads)\n _bwd_kernel[grid](\n q,\n k,\n v,\n bias,\n do,\n dq_accum,\n dk,\n dv,\n lse,\n delta,\n softmax_scale,\n q.stride(0),\n q.stride(2),\n q.stride(1),\n k.stride(0),\n k.stride(2),\n k.stride(1),\n v.stride(0),\n v.stride(2),\n v.stride(1),\n *bias_strides,\n do.stride(0),\n do.stride(2),\n do.stride(1),\n dq_accum.stride(0),\n dq_accum.stride(2),\n dq_accum.stride(1),\n dk.stride(0),\n dk.stride(2),\n dk.stride(1),\n dv.stride(0),\n dv.stride(2),\n dv.stride(1),\n nheads,\n seqlen_q,\n seqlen_k,\n seqlen_q_rounded,\n d,\n seqlen_q // 32,\n seqlen_k // 32,\n bias_type,\n causal,\n BLOCK_HEADDIM\n )\n dq.copy_(dq_accum)\n", - "description_1": "Use triton language to implement forward and backward kernels for FlashAttention. The forward kernel (_fwd_kernel) computes the attention output using queries, keys, values, and an optional bias. The backward preprocessing kernel (_bwd_preprocess_do_o_dot) computes the delta for gradient updates, and the backward kernel (_bwd_kernel) computes the gradients with respect to queries, keys, and values. Parameters are: (Q, K, V, Bias, Out, etc.) for the forward pass, and (DO, Delta, etc.) for the backward pass, with various constants defining dimensions and strides.", - "description_2": "Use triton language to create attention kernels for FlashAttention, including forward pass computation and backward pass gradient computation, with configurable parameters and dimensions.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n\n@triton.jit\ndef _fwd_recurrence(\n S,\n p1,\n p2,\n O,\n NUM_BLOCK,\n D_MODEL_K: tl.constexpr,\n D_MODEL_V: tl.constexpr,\n BLOCK_MODEL: tl.constexpr,\n):\n offset_bh = tl.program_id(0)\n offset_d = tl.program_id(1)\n offset_s = tl.program_id(2)\n\n S = (\n S\n + offset_bh * NUM_BLOCK * D_MODEL_K * D_MODEL_V\n + offset_d * D_MODEL_V * BLOCK_MODEL\n + tl.arange(0, BLOCK_MODEL)[:, None] * D_MODEL_V\n + offset_s * BLOCK_MODEL\n + tl.arange(0, BLOCK_MODEL)[None, :]\n )\n\n O = (\n O\n + offset_bh * NUM_BLOCK * D_MODEL_K * D_MODEL_V\n + offset_d * D_MODEL_V * BLOCK_MODEL\n + tl.arange(0, BLOCK_MODEL)[:, None] * D_MODEL_V\n + offset_s * BLOCK_MODEL\n + tl.arange(0, BLOCK_MODEL)[None, :]\n + D_MODEL_K * D_MODEL_V\n )\n\n p1 = (\n p1\n + offset_bh * NUM_BLOCK * D_MODEL_K\n + tl.arange(0, BLOCK_MODEL)\n + offset_d * BLOCK_MODEL\n + D_MODEL_K\n )\n\n p2 = (\n p2\n + offset_bh * NUM_BLOCK * D_MODEL_V\n + tl.arange(0, BLOCK_MODEL)\n + offset_s * BLOCK_MODEL\n + D_MODEL_V\n )\n\n acc = tl.zeros([BLOCK_MODEL, BLOCK_MODEL], dtype=tl.float32)\n acc += tl.load(S)\n\n S += D_MODEL_K * D_MODEL_V\n\n tl.store(O, acc.to(O.dtype.element_ty))\n O += D_MODEL_K * D_MODEL_V\n\n for i in range(NUM_BLOCK - 2):\n p_k = tl.load(p1)\n p_v = tl.load(p2)\n S_i = tl.load(S)\n acc = acc * p_k[:, None] * p_v[None, :] + S_i\n tl.store(O, acc.to(O.dtype.element_ty))\n p1 += D_MODEL_K\n p2 += D_MODEL_V\n S += D_MODEL_K * D_MODEL_V\n O += D_MODEL_K * D_MODEL_V\n\n\n@triton.jit\ndef _bwd_recurrence(\n S,\n p1,\n p2,\n DS,\n Dp1,\n Dp2,\n NUM_BLOCK,\n NUM_SPLIT_K,\n NUM_SPLIT_V,\n D_MODEL_K: tl.constexpr,\n D_MODEL_V: tl.constexpr,\n BLOCK_MODEL: tl.constexpr,\n):\n offset_bh = tl.program_id(0)\n offset_d = tl.program_id(1)\n offset_s = tl.program_id(2)\n\n S = (\n S\n + offset_bh * NUM_BLOCK * D_MODEL_K * D_MODEL_V\n + offset_d * D_MODEL_V * BLOCK_MODEL\n + tl.arange(0, BLOCK_MODEL)[:, None] * D_MODEL_V\n + offset_s * BLOCK_MODEL\n + tl.arange(0, BLOCK_MODEL)[None, :]\n + (NUM_BLOCK - 2) * D_MODEL_K * D_MODEL_V\n )\n\n DS = (\n DS\n + offset_bh * NUM_BLOCK * D_MODEL_K * D_MODEL_V\n + offset_d * D_MODEL_V * BLOCK_MODEL\n + tl.arange(0, BLOCK_MODEL)[:, None] * D_MODEL_V\n + offset_s * BLOCK_MODEL\n + tl.arange(0, BLOCK_MODEL)[None, :]\n + (NUM_BLOCK - 1) * D_MODEL_K * D_MODEL_V\n )\n\n p1 = (\n p1\n + offset_bh * NUM_BLOCK * D_MODEL_K\n + tl.arange(0, BLOCK_MODEL)\n + offset_d * BLOCK_MODEL\n + (NUM_BLOCK - 2) * D_MODEL_K\n )\n\n p2 = (\n p2\n + offset_bh * NUM_BLOCK * D_MODEL_V\n + tl.arange(0, BLOCK_MODEL)\n + offset_s * BLOCK_MODEL\n + (NUM_BLOCK - 2) * D_MODEL_V\n )\n\n Dp1 = (\n Dp1\n + offset_bh * NUM_BLOCK * D_MODEL_K * NUM_SPLIT_V\n + offset_s * D_MODEL_K\n + tl.arange(0, BLOCK_MODEL)\n + offset_d * BLOCK_MODEL\n + (NUM_BLOCK - 2) * D_MODEL_K * NUM_SPLIT_V\n )\n\n Dp2 = (\n Dp2\n + offset_bh * NUM_BLOCK * D_MODEL_V * NUM_SPLIT_K\n + offset_d * D_MODEL_V\n + tl.arange(0, BLOCK_MODEL)\n + offset_s * BLOCK_MODEL\n + (NUM_BLOCK - 2) * D_MODEL_V * NUM_SPLIT_K\n )\n\n Dacc = tl.zeros([BLOCK_MODEL, BLOCK_MODEL], dtype=tl.float32)\n\n for i in range(NUM_BLOCK - 1):\n p_key = tl.load(p1)\n p_value = tl.load(p2)\n S_i = tl.load(S)\n DS_i = tl.load(DS)\n Dacc += DS_i\n dp_i = Dacc * S_i\n dp_key = tl.sum(dp_i * p_value[None, :], axis=1)\n tl.store(Dp1, dp_key.to(Dp1.dtype.element_ty))\n dp_value = tl.sum(dp_i * p_key[:, None], axis=0)\n tl.store(Dp2, dp_value.to(Dp2.dtype.element_ty))\n\n tl.store(S, Dacc.to(S.dtype.element_ty))\n\n Dacc *= p_key[:, None]\n Dacc *= p_value[None, :]\n\n S -= D_MODEL_K * D_MODEL_V\n DS -= D_MODEL_K * D_MODEL_V\n p1 -= D_MODEL_K\n p2 -= D_MODEL_V\n Dp1 -= D_MODEL_K * NUM_SPLIT_V\n Dp2 -= D_MODEL_V * NUM_SPLIT_K\n\n\nclass Chunk_memory_update_full(torch.autograd.Function):\n @staticmethod\n def forward(ctx, decay_key_last, decay_value_last, to_add):\n decay_key_last = decay_key_last.contiguous()\n decay_value_last = decay_value_last.contiguous()\n to_add = to_add.contiguous()\n\n B, H, N, D_k, D_v = to_add.shape\n output = torch.empty_like(to_add)\n BLOCK_MODEL = 32\n\n assert D_k % 32 == 0\n assert D_v % 32 == 0\n assert D_k == decay_key_last.shape[-1]\n assert D_v == decay_value_last.shape[-1]\n\n grid = (B * H, D_k // BLOCK_MODEL, D_v // BLOCK_MODEL)\n ctx.grid = grid\n ctx.BLOCK_MODEL = BLOCK_MODEL\n\n _fwd_recurrence[grid](\n to_add,\n decay_key_last,\n decay_value_last,\n output,\n D_MODEL_K=D_k,\n D_MODEL_V=D_v,\n NUM_BLOCK=N,\n BLOCK_MODEL=BLOCK_MODEL,\n )\n\n output[:, :, 0] = 0\n ctx.save_for_backward(output, decay_key_last, decay_value_last)\n\n return output\n\n @staticmethod\n def backward(ctx, DO):\n DO = DO.contiguous()\n\n output, decay_key_last, decay_value_last = ctx.saved_tensors\n\n B, H, N, D_k, D_v = output.shape\n\n num_block = N\n\n BLOCK_MODEL = 32\n\n grid = (B * H, D_k // BLOCK_MODEL, D_v // BLOCK_MODEL)\n\n D_p1 = torch.empty(\n B, H, N, D_v // BLOCK_MODEL, D_k, device=DO.device, dtype=torch.float32\n )\n D_p2 = torch.empty(\n B, H, N, D_k // BLOCK_MODEL, D_v, device=DO.device, dtype=torch.float32\n )\n\n _bwd_recurrence[grid](\n output,\n decay_key_last,\n decay_value_last,\n DO,\n D_p1,\n D_p2,\n NUM_BLOCK=num_block,\n NUM_SPLIT_K=D_k // BLOCK_MODEL,\n NUM_SPLIT_V=D_v // BLOCK_MODEL,\n D_MODEL_K=D_k,\n D_MODEL_V=D_v,\n BLOCK_MODEL=BLOCK_MODEL,\n )\n\n output[:, :, -1] = 0\n D_p1[:, :, 0] = 0\n D_p1[:, :, -1] = 0\n D_p2[:, :, 0] = 0\n D_p2[:, :, -1] = 0\n\n return D_p1.sum(-2), D_p2.sum(-2), output\n", - "description_1": "Use triton language to implement two kernels, _fwd_recurrence and _bwd_recurrence. The _fwd_recurrence kernel takes in seven arguments: S, p1, p2, O, NUM_BLOCK, D_MODEL_K, and D_MODEL_V, where S is the input tensor, p1 and p2 are pointers to decay values, O is the output tensor, and NUM_BLOCK, D_MODEL_K, and D_MODEL_V are model dimensions. The kernel computes a forward recurrence relation, storing results in O. The _bwd_recurrence kernel takes in twelve arguments: S, p1, p2, DS, Dp1, Dp2, NUM_BLOCK, NUM_SPLIT_K, NUM_SPLIT_V, D_MODEL_K, D_MODEL_V, and BLOCK_MODEL. It computes a backward recurrence relation for gradients, storing results in DS, Dp1, and Dp2. The class Chunk_memory_update_full uses these kernels in its forward and backward methods, handling the input and output tensors and managing grid dimensions.", - "description_2": "Use triton language to create forward and backward kernels for a recurrence relation with specified input, output, and model dimensions. Integrate these kernels into an autograd function class for PyTorch.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef _fwd_recurrence(\n S,\n O,\n NUM_BLOCK,\n D_MODEL_K: tl.constexpr,\n D_MODEL_V: tl.constexpr,\n BLOCK_MODEL: tl.constexpr,\n):\n offset_bh = tl.program_id(0)\n offset_d = tl.program_id(1)\n offset_s = tl.program_id(2)\n\n S = (\n S\n + offset_bh * NUM_BLOCK * D_MODEL_K * D_MODEL_V\n + offset_d * D_MODEL_V * BLOCK_MODEL\n + tl.arange(0, BLOCK_MODEL)[:, None] * D_MODEL_V\n + offset_s * BLOCK_MODEL\n + tl.arange(0, BLOCK_MODEL)[None, :]\n )\n\n O = (\n O\n + offset_bh * NUM_BLOCK * D_MODEL_K * D_MODEL_V\n + offset_d * D_MODEL_V * BLOCK_MODEL\n + tl.arange(0, BLOCK_MODEL)[:, None] * D_MODEL_V\n + offset_s * BLOCK_MODEL\n + tl.arange(0, BLOCK_MODEL)[None, :]\n + D_MODEL_K * D_MODEL_V\n )\n\n acc = tl.zeros([BLOCK_MODEL, BLOCK_MODEL], dtype=tl.float32)\n acc += tl.load(S)\n\n S += D_MODEL_K * D_MODEL_V\n\n tl.store(O, acc.to(O.dtype.element_ty))\n O += D_MODEL_K * D_MODEL_V\n\n for i in range(NUM_BLOCK - 2):\n S_i = tl.load(S)\n acc = acc + S_i\n tl.store(O, acc.to(O.dtype.element_ty))\n S += D_MODEL_K * D_MODEL_V\n O += D_MODEL_K * D_MODEL_V\n\n@triton.jit\ndef _bwd_recurrence(\n S,\n DS,\n NUM_BLOCK,\n NUM_SPLIT_K,\n NUM_SPLIT_V,\n D_MODEL_K: tl.constexpr,\n D_MODEL_V: tl.constexpr,\n BLOCK_MODEL: tl.constexpr,\n):\n offset_bh = tl.program_id(0)\n offset_d = tl.program_id(1)\n offset_s = tl.program_id(2)\n\n S = (\n S\n + offset_bh * NUM_BLOCK * D_MODEL_K * D_MODEL_V\n + offset_d * D_MODEL_V * BLOCK_MODEL\n + tl.arange(0, BLOCK_MODEL)[:, None] * D_MODEL_V\n + offset_s * BLOCK_MODEL\n + tl.arange(0, BLOCK_MODEL)[None, :]\n + (NUM_BLOCK - 2) * D_MODEL_K * D_MODEL_V\n )\n\n DS = (\n DS\n + offset_bh * NUM_BLOCK * D_MODEL_K * D_MODEL_V\n + offset_d * D_MODEL_V * BLOCK_MODEL\n + tl.arange(0, BLOCK_MODEL)[:, None] * D_MODEL_V\n + offset_s * BLOCK_MODEL\n + tl.arange(0, BLOCK_MODEL)[None, :]\n + (NUM_BLOCK - 1) * D_MODEL_K * D_MODEL_V\n )\n\n Dacc = tl.zeros([BLOCK_MODEL, BLOCK_MODEL], dtype=tl.float32)\n\n for i in range(NUM_BLOCK - 1):\n DS_i = tl.load(DS)\n Dacc += DS_i\n tl.store(S, Dacc.to(S.dtype.element_ty))\n\n S -= D_MODEL_K * D_MODEL_V\n DS -= D_MODEL_K * D_MODEL_V\n\nclass Chunk_memory_update_no_decay(torch.autograd.Function):\n @staticmethod\n def forward(ctx, to_add):\n to_add = to_add.contiguous()\n\n B, H, N, D_k, D_v = to_add.shape\n output = torch.empty_like(to_add)\n BLOCK_MODEL = 32\n\n assert D_k % 32 == 0\n assert D_v % 32 == 0\n\n grid = (B * H, D_k // BLOCK_MODEL, D_v // BLOCK_MODEL)\n ctx.grid = grid\n ctx.BLOCK_MODEL = BLOCK_MODEL\n\n _fwd_recurrence[grid](\n to_add,\n output,\n D_MODEL_K=D_k,\n D_MODEL_V=D_v,\n NUM_BLOCK=N,\n BLOCK_MODEL=BLOCK_MODEL,\n )\n\n output[:, :, 0] = 0\n ctx.save_for_backward(output)\n\n return output\n\n @staticmethod\n def backward(ctx, DO):\n DO = DO.contiguous()\n\n (output,) = ctx.saved_tensors\n\n B, H, N, D_k, D_v = output.shape\n\n num_block = N\n\n BLOCK_MODEL = 32\n\n grid = (B * H, D_k // BLOCK_MODEL, D_v // BLOCK_MODEL)\n\n _bwd_recurrence[grid](\n output,\n DO,\n NUM_BLOCK=num_block,\n NUM_SPLIT_K=D_k // BLOCK_MODEL,\n NUM_SPLIT_V=D_v // BLOCK_MODEL,\n D_MODEL_K=D_k,\n D_MODEL_V=D_v,\n BLOCK_MODEL=BLOCK_MODEL,\n )\n\n output[:, :, -1] = 0\n\n return output\n", - "description_1": "Use triton language to implement two kernels: _fwd_recurrence and _bwd_recurrence. The _fwd_recurrence kernel takes 6 parameters: S (input tensor), O (output tensor), NUM_BLOCK (number of blocks), D_MODEL_K (model dimension K), D_MODEL_V (model dimension V), and BLOCK_MODEL (block size). It performs a forward recurrence operation on the input tensor S and stores the result in the output tensor O. The _bwd_recurrence kernel takes 8 parameters: S (input tensor), DS (gradient tensor), NUM_BLOCK (number of blocks), NUM_SPLIT_K (number of splits in K dimension), NUM_SPLIT_V (number of splits in V dimension), D_MODEL_K (model dimension K), D_MODEL_V (model dimension V), and BLOCK_MODEL (block size). It performs a backward recurrence operation to compute gradients and updates the input tensor S.", - "description_2": "Use triton language to create a forward recurrence kernel with 6 parameters and a backward recurrence kernel with 8 parameters for tensor operations.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef _fwd_recurrence(\n S, \n p1, \n O, \n NUM_BLOCK, \n D_MODEL_K: tl.constexpr, \n D_MODEL_V: tl.constexpr, \n BLOCK_MODEL: tl.constexpr\n):\n offset_bh = tl.program_id(0)\n offset_d = tl.program_id(1)\n offset_s = tl.program_id(2)\n\n S = (\n S \n + offset_bh * NUM_BLOCK * D_MODEL_K * D_MODEL_V \n + offset_d * D_MODEL_V * BLOCK_MODEL \n + tl.arange(0, BLOCK_MODEL)[:, None] * D_MODEL_V \n + offset_s * BLOCK_MODEL \n + tl.arange(0, BLOCK_MODEL)[None, :]\n )\n\n O = (\n O \n + offset_bh * NUM_BLOCK * D_MODEL_K * D_MODEL_V \n + offset_d * D_MODEL_V * BLOCK_MODEL \n + tl.arange(0, BLOCK_MODEL)[:, None] * D_MODEL_V \n + offset_s * BLOCK_MODEL \n + tl.arange(0, BLOCK_MODEL)[None, :] \n + D_MODEL_K * D_MODEL_V\n )\n\n p1 = (\n p1 \n + offset_bh * NUM_BLOCK * D_MODEL_K \n + tl.arange(0, BLOCK_MODEL) \n + offset_d * BLOCK_MODEL \n + D_MODEL_K\n )\n\n acc = tl.zeros([BLOCK_MODEL, BLOCK_MODEL], dtype=tl.float32)\n acc += tl.load(S)\n\n S += D_MODEL_K * D_MODEL_V\n\n tl.store(O, acc.to(O.dtype.element_ty))\n O += D_MODEL_K * D_MODEL_V\n\n for i in range(NUM_BLOCK - 2):\n p_k = tl.load(p1)\n S_i = tl.load(S)\n acc = acc * p_k[:, None] + S_i\n tl.store(O, acc.to(O.dtype.element_ty))\n p1 += D_MODEL_K\n S += D_MODEL_K * D_MODEL_V\n O += D_MODEL_K * D_MODEL_V\n\n@triton.jit\ndef _bwd_recurrence(\n S, \n p1, \n DS, \n Dp1, \n NUM_BLOCK, \n NUM_SPLIT_K, \n NUM_SPLIT_V, \n D_MODEL_K: tl.constexpr, \n D_MODEL_V: tl.constexpr, \n BLOCK_MODEL: tl.constexpr\n):\n offset_bh = tl.program_id(0)\n offset_d = tl.program_id(1)\n offset_s = tl.program_id(2)\n\n S = (\n S \n + offset_bh * NUM_BLOCK * D_MODEL_K * D_MODEL_V \n + offset_d * D_MODEL_V * BLOCK_MODEL \n + tl.arange(0, BLOCK_MODEL)[:, None] * D_MODEL_V \n + offset_s * BLOCK_MODEL \n + tl.arange(0, BLOCK_MODEL)[None, :] \n + (NUM_BLOCK - 2) * D_MODEL_K * D_MODEL_V\n )\n\n DS = (\n DS \n + offset_bh * NUM_BLOCK * D_MODEL_K * D_MODEL_V \n + offset_d * D_MODEL_V * BLOCK_MODEL \n + tl.arange(0, BLOCK_MODEL)[:, None] * D_MODEL_V \n + offset_s * BLOCK_MODEL \n + tl.arange(0, BLOCK_MODEL)[None, :] \n + (NUM_BLOCK - 1) * D_MODEL_K * D_MODEL_V\n )\n\n p1 = (\n p1 \n + offset_bh * NUM_BLOCK * D_MODEL_K \n + tl.arange(0, BLOCK_MODEL) \n + offset_d * BLOCK_MODEL \n + (NUM_BLOCK - 2) * D_MODEL_K\n )\n\n Dp1 = (\n Dp1 \n + offset_bh * NUM_BLOCK * D_MODEL_K * NUM_SPLIT_V \n + offset_s * D_MODEL_K \n + tl.arange(0, BLOCK_MODEL) \n + offset_d * BLOCK_MODEL \n + (NUM_BLOCK - 2) * D_MODEL_K * NUM_SPLIT_V\n )\n\n Dacc = tl.zeros([BLOCK_MODEL, BLOCK_MODEL], dtype=tl.float32)\n\n for i in range(NUM_BLOCK - 1):\n p_key = tl.load(p1)\n S_i = tl.load(S)\n DS_i = tl.load(DS)\n Dacc += DS_i\n dp_i = Dacc * S_i\n dp_key = tl.sum(dp_i, axis=1)\n tl.store(Dp1, dp_key.to(Dp1.dtype.element_ty))\n tl.store(S, Dacc.to(S.dtype.element_ty))\n\n Dacc *= p_key[:, None]\n\n S -= D_MODEL_K * D_MODEL_V\n DS -= D_MODEL_K * D_MODEL_V\n p1 -= D_MODEL_K\n Dp1 -= D_MODEL_K * NUM_SPLIT_V\n\nclass Chunk_memory_update_only_gk(torch.autograd.Function):\n @staticmethod\n def forward(ctx, decay_key_last, to_add):\n decay_key_last = decay_key_last.contiguous()\n to_add = to_add.contiguous()\n\n B, H, N, D_k, D_v = to_add.shape\n output = torch.empty_like(to_add)\n BLOCK_MODEL = 16\n\n assert D_k % BLOCK_MODEL == 0\n assert D_v % BLOCK_MODEL == 0\n assert D_k == decay_key_last.shape[-1]\n\n grid = (B * H, D_k // BLOCK_MODEL, D_v // BLOCK_MODEL)\n ctx.grid = grid\n ctx.BLOCK_MODEL = BLOCK_MODEL\n\n _fwd_recurrence[grid](\n to_add, \n decay_key_last, \n output, \n D_MODEL_K=D_k, \n D_MODEL_V=D_v, \n NUM_BLOCK=N, \n BLOCK_MODEL=BLOCK_MODEL\n )\n\n output[:, :, 0] = 0\n ctx.save_for_backward(output, decay_key_last)\n\n return output\n\n @staticmethod\n def backward(ctx, DO):\n DO = DO.contiguous()\n\n output, decay_key_last = ctx.saved_tensors\n\n B, H, N, D_k, D_v = output.shape\n\n num_block = N\n\n BLOCK_MODEL = 16\n\n grid = (B * H, D_k // BLOCK_MODEL, D_v // BLOCK_MODEL)\n\n D_p1 = torch.empty(\n B, H, N, D_v // BLOCK_MODEL, D_k, device=DO.device, dtype=torch.float32\n )\n\n _bwd_recurrence[grid](\n output, \n decay_key_last, \n DO, \n D_p1, \n NUM_BLOCK=num_block, \n NUM_SPLIT_K=D_k // BLOCK_MODEL, \n NUM_SPLIT_V=D_v // BLOCK_MODEL, \n D_MODEL_K=D_k, \n D_MODEL_V=D_v, \n BLOCK_MODEL=BLOCK_MODEL\n )\n\n output[:, :, -1] = 0\n D_p1[:, :, 0] = 0\n D_p1[:, :, -1] = 0\n\n return D_p1.sum(-2), output\n", - "description_1": "Use triton language to implement forward and backward recurrence kernels for memory update in a sequence processing task. The `_fwd_recurrence` kernel takes 7 arguments: (1) S: input tensor with model state, (2) p1: decay factors for the input, (3) O: output tensor, (4) NUM_BLOCK: number of blocks, (5) D_MODEL_K: size of model's key dimension, (6) D_MODEL_V: size of model's value dimension, and (7) BLOCK_MODEL: size of block model, which dictates the data distribution and memory management during computation. The `_bwd_recurrence` kernel takes 10 arguments: (1) S: input tensor with model state, (2) p1: decay factors, (3) DS: gradient of the state, (4) Dp1: gradient of the decay factors, (5) NUM_BLOCK, (6) NUM_SPLIT_K, and (7) NUM_SPLIT_V for dimension splits, and the same last three constant parameters (8) D_MODEL_K, (9) D_MODEL_V, (10) BLOCK_MODEL as in the forward kernel. This backward kernel computes the gradients necessary for a custom backward pass.", - "description_2": "Use triton language to design custom autograd function with both forward and backward kernels for optimizing memory usage in sequence models. The function facilitates block-based operations by splitting key-value dimensions across blocks, enabling efficient parallel computation on input sequence data.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef _fwd_recurrence(\n S,\n p2,\n O,\n NUM_BLOCK,\n D_MODEL_K: tl.constexpr,\n D_MODEL_V: tl.constexpr,\n BLOCK_MODEL: tl.constexpr,\n):\n offset_bh = tl.program_id(0)\n offset_d = tl.program_id(1)\n offset_s = tl.program_id(2)\n\n S = (\n S\n + offset_bh * NUM_BLOCK * D_MODEL_K * D_MODEL_V\n + offset_d * D_MODEL_V * BLOCK_MODEL\n + tl.arange(0, BLOCK_MODEL)[:, None] * D_MODEL_V\n + offset_s * BLOCK_MODEL\n + tl.arange(0, BLOCK_MODEL)[None, :]\n )\n\n O = (\n O\n + offset_bh * NUM_BLOCK * D_MODEL_K * D_MODEL_V\n + offset_d * D_MODEL_V * BLOCK_MODEL\n + tl.arange(0, BLOCK_MODEL)[:, None] * D_MODEL_V\n + offset_s * BLOCK_MODEL\n + tl.arange(0, BLOCK_MODEL)[None, :]\n + D_MODEL_K * D_MODEL_V\n )\n\n p2 = (\n p2\n + offset_bh * NUM_BLOCK * D_MODEL_V\n + tl.arange(0, BLOCK_MODEL)\n + offset_s * BLOCK_MODEL\n + D_MODEL_V\n )\n\n acc = tl.zeros([BLOCK_MODEL, BLOCK_MODEL], dtype=tl.float32)\n acc += tl.load(S)\n\n S += D_MODEL_K * D_MODEL_V\n\n tl.store(O, acc.to(O.dtype.element_ty))\n O += D_MODEL_K * D_MODEL_V\n\n for i in range(NUM_BLOCK - 2):\n p_v = tl.load(p2)\n S_i = tl.load(S)\n acc = acc * p_v[None, :] + S_i\n tl.store(O, acc.to(O.dtype.element_ty))\n p2 += D_MODEL_V\n S += D_MODEL_K * D_MODEL_V\n O += D_MODEL_K * D_MODEL_V\n\n@triton.jit\ndef _bwd_recurrence(\n S,\n p2,\n DS,\n Dp2,\n NUM_BLOCK,\n NUM_SPLIT_K,\n NUM_SPLIT_V,\n D_MODEL_K: tl.constexpr,\n D_MODEL_V: tl.constexpr,\n BLOCK_MODEL: tl.constexpr,\n):\n\n offset_bh = tl.program_id(0)\n offset_d = tl.program_id(1)\n offset_s = tl.program_id(2)\n\n S = (\n S\n + offset_bh * NUM_BLOCK * D_MODEL_K * D_MODEL_V\n + offset_d * D_MODEL_V * BLOCK_MODEL\n + tl.arange(0, BLOCK_MODEL)[:, None] * D_MODEL_V\n + offset_s * BLOCK_MODEL\n + tl.arange(0, BLOCK_MODEL)[None, :]\n + (NUM_BLOCK - 2) * D_MODEL_K * D_MODEL_V\n )\n\n DS = (\n DS\n + offset_bh * NUM_BLOCK * D_MODEL_K * D_MODEL_V\n + offset_d * D_MODEL_V * BLOCK_MODEL\n + tl.arange(0, BLOCK_MODEL)[:, None] * D_MODEL_V\n + offset_s * BLOCK_MODEL\n + tl.arange(0, BLOCK_MODEL)[None, :]\n + (NUM_BLOCK - 1) * D_MODEL_K * D_MODEL_V\n )\n\n p2 = (\n p2\n + offset_bh * NUM_BLOCK * D_MODEL_V\n + tl.arange(0, BLOCK_MODEL)\n + offset_s * BLOCK_MODEL\n + (NUM_BLOCK - 2) * D_MODEL_V\n )\n\n Dp2 = (\n Dp2\n + offset_bh * NUM_BLOCK * D_MODEL_V * NUM_SPLIT_K\n + offset_d * D_MODEL_V\n + tl.arange(0, BLOCK_MODEL)\n + offset_s * BLOCK_MODEL\n + (NUM_BLOCK - 2) * D_MODEL_V * NUM_SPLIT_K\n )\n\n Dacc = tl.zeros([BLOCK_MODEL, BLOCK_MODEL], dtype=tl.float32)\n\n for i in range(NUM_BLOCK - 1):\n p_value = tl.load(p2)\n S_i = tl.load(S)\n DS_i = tl.load(DS)\n Dacc += DS_i\n dp_i = Dacc * S_i\n dp_value = tl.sum(dp_i, axis=0)\n tl.store(Dp2, dp_value.to(Dp2.dtype.element_ty))\n tl.store(S, Dacc.to(S.dtype.element_ty))\n Dacc *= p_value[None, :]\n S -= D_MODEL_K * D_MODEL_V\n DS -= D_MODEL_K * D_MODEL_V\n p2 -= D_MODEL_V\n Dp2 -= D_MODEL_V * NUM_SPLIT_K\n\nclass Chunk_memory_update_only_gv(torch.autograd.Function):\n @staticmethod\n def forward(ctx, decay_value_last, to_add):\n decay_value_last = decay_value_last.contiguous()\n to_add = to_add.contiguous()\n\n B, H, N, D_k, D_v = to_add.shape\n output = torch.empty_like(to_add)\n BLOCK_MODEL = 32\n\n assert D_k % 32 == 0\n assert D_v % 32 == 0\n assert D_v == decay_value_last.shape[-1]\n\n grid = (B * H, D_k // BLOCK_MODEL, D_v // BLOCK_MODEL)\n ctx.grid = grid\n ctx.BLOCK_MODEL = BLOCK_MODEL\n\n _fwd_recurrence[grid](\n to_add,\n decay_value_last,\n output,\n D_MODEL_K=D_k,\n D_MODEL_V=D_v,\n NUM_BLOCK=N,\n BLOCK_MODEL=BLOCK_MODEL,\n )\n\n output[:, :, 0] = 0\n ctx.save_for_backward(output, decay_value_last)\n\n return output\n\n @staticmethod\n def backward(ctx, DO):\n DO = DO.contiguous()\n output, decay_value_last = ctx.saved_tensors\n B, H, N, D_k, D_v = output.shape\n num_block = N\n BLOCK_MODEL = 32\n grid = (B * H, D_k // BLOCK_MODEL, D_v // BLOCK_MODEL)\n D_p2 = torch.empty(\n B, H, N, D_k // BLOCK_MODEL, D_v, device=DO.device, dtype=torch.float32\n )\n\n _bwd_recurrence[grid](\n output,\n decay_value_last,\n DO,\n D_p2,\n NUM_BLOCK=num_block,\n NUM_SPLIT_K=D_k // BLOCK_MODEL,\n NUM_SPLIT_V=D_v // BLOCK_MODEL,\n D_MODEL_K=D_k,\n D_MODEL_V=D_v,\n BLOCK_MODEL=BLOCK_MODEL,\n )\n\n output[:, :, -1] = 0\n D_p2[:, :, 0] = 0\n D_p2[:, :, -1] = 0\n\n return D_p2.sum(-2), output\n", - "description_1": "Use triton language to implement two kernels: _fwd_recurrence and _bwd_recurrence. The _fwd_recurrence kernel performs a forward recurrence operation with 7 parameters: S (source tensor), p2 (previous tensor), O (output tensor), NUM_BLOCK (number of blocks), D_MODEL_K (model dimension K), D_MODEL_V (model dimension V), BLOCK_MODEL (block size). The _bwd_recurrence kernel executes a backward recurrence operation with 10 parameters: S, p2, DS (source delta), Dp2 (delta p2), NUM_BLOCK, NUM_SPLIT_K (split in K dimension), NUM_SPLIT_V (split in V dimension), D_MODEL_K, D_MODEL_V, BLOCK_MODEL. Both kernels use grid mapping based on program ids for processing, perform arithmetic operations, load, and store tensor data using Triton primitives.", - "description_2": "Use triton language to perform forward and backward recurrence operations with grid mapping and data processing on tensors utilizing Triton primitives in kernels.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef _fwd_preprocess_cumsum_gk(\n Q, K, GK, GK_cumsum, Q_exp, K_reduce, GK_last_exp, NUM_CHUNK, L,\n normalizer, clamp_min, D_MODEL_K: tl.constexpr, CHUNK_SIZE: tl.constexpr,\n):\n offset_bh = tl.program_id(0)\n offset_c = tl.program_id(1)\n Q_ptr = (\n Q + offset_bh * L * D_MODEL_K + offset_c * CHUNK_SIZE * D_MODEL_K\n + tl.arange(0, D_MODEL_K)\n )\n Q_exp_ptr = (\n Q_exp + offset_bh * L * D_MODEL_K + offset_c * CHUNK_SIZE * D_MODEL_K\n + tl.arange(0, D_MODEL_K)\n )\n\n GK_ptr = (\n GK + offset_bh * L * D_MODEL_K + offset_c * CHUNK_SIZE * D_MODEL_K\n + tl.arange(0, D_MODEL_K)\n )\n\n GK_cumsum_ptr = (\n GK_cumsum + offset_bh * L * D_MODEL_K + offset_c * CHUNK_SIZE * D_MODEL_K\n + tl.arange(0, D_MODEL_K)\n )\n\n GK_last_exp_ptr = (\n GK_last_exp + offset_bh * NUM_CHUNK * D_MODEL_K + offset_c * D_MODEL_K\n + tl.arange(0, D_MODEL_K)\n )\n\n cumsum = tl.zeros([D_MODEL_K], dtype=tl.float32)\n\n for _ in range(CHUNK_SIZE):\n gk = tl.load(GK_ptr).to(tl.float32)\n gk = tl.where(gk >= clamp_min, gk, clamp_min)\n\n cumsum += gk\n tl.store(GK_cumsum_ptr, cumsum.to(GK_cumsum_ptr.dtype.element_ty))\n\n cumsum_exp = tl.exp(cumsum)\n\n q = tl.load(Q_ptr)\n q_exp = q * cumsum_exp\n tl.store(Q_exp_ptr, q_exp)\n\n Q_ptr += D_MODEL_K\n Q_exp_ptr += D_MODEL_K\n GK_ptr += D_MODEL_K\n GK_cumsum_ptr += D_MODEL_K\n\n tl.store(GK_last_exp_ptr, tl.exp(cumsum).to(GK_last_exp_ptr.dtype.element_ty))\n\n tl.debug_barrier()\n\n GK_cumsum_ptr = (\n GK_cumsum + offset_bh * L * D_MODEL_K + offset_c * CHUNK_SIZE * D_MODEL_K\n + tl.arange(0, D_MODEL_K)\n )\n K_ptr = (\n K + offset_bh * L * D_MODEL_K + offset_c * CHUNK_SIZE * D_MODEL_K\n + tl.arange(0, D_MODEL_K)\n )\n K_reduce_ptr = (\n K_reduce + offset_bh * L * D_MODEL_K + offset_c * CHUNK_SIZE * D_MODEL_K\n + tl.arange(0, D_MODEL_K)\n )\n\n for _ in range(CHUNK_SIZE):\n gk_cumsum = tl.load(GK_cumsum_ptr)\n k = tl.load(K_ptr)\n k_reduce = k * tl.exp(cumsum - gk_cumsum)\n tl.store(K_reduce_ptr, k_reduce.to(K_reduce_ptr.dtype.element_ty))\n\n K_ptr += D_MODEL_K\n GK_cumsum_ptr += D_MODEL_K\n K_reduce_ptr += D_MODEL_K\n\n\n@triton.jit\ndef _bwd_preprocess_cumsum_gk(\n Q, K, GK, GK_cumsum, DQ_exp, DK_reduce, DGK_last_exp, DGK_cumsum, DQ, DK, DGK,\n NUM_CHUNK, L, normalizer, clamp_min, D_MODEL_K: tl.constexpr, CHUNK_SIZE: tl.constexpr,\n):\n offset_bh = tl.program_id(0)\n offset_c = tl.program_id(1)\n Q_ptr = (\n Q + offset_bh * L * D_MODEL_K + offset_c * CHUNK_SIZE * D_MODEL_K\n + tl.arange(0, D_MODEL_K)\n )\n K_ptr = (\n K + offset_bh * L * D_MODEL_K + offset_c * CHUNK_SIZE * D_MODEL_K\n + tl.arange(0, D_MODEL_K)\n )\n GK_ptr = (\n GK + offset_bh * L * D_MODEL_K + offset_c * CHUNK_SIZE * D_MODEL_K\n + tl.arange(0, D_MODEL_K)\n )\n GK_cumsum_ptr = (\n GK_cumsum + offset_bh * L * D_MODEL_K + offset_c * CHUNK_SIZE * D_MODEL_K\n + tl.arange(0, D_MODEL_K)\n )\n\n DQ_ptr = (\n DQ + offset_bh * L * D_MODEL_K + offset_c * CHUNK_SIZE * D_MODEL_K\n + tl.arange(0, D_MODEL_K)\n )\n DK_ptr = (\n DK + offset_bh * L * D_MODEL_K + offset_c * CHUNK_SIZE * D_MODEL_K\n + tl.arange(0, D_MODEL_K)\n )\n DQ_exp_ptr = (\n DQ_exp + offset_bh * L * D_MODEL_K + offset_c * CHUNK_SIZE * D_MODEL_K\n + tl.arange(0, D_MODEL_K)\n )\n DK_reduce_ptr = (\n DK_reduce + offset_bh * L * D_MODEL_K + offset_c * CHUNK_SIZE * D_MODEL_K\n + tl.arange(0, D_MODEL_K)\n )\n DGK_cumsum_ptr = (\n DGK_cumsum + offset_bh * L * D_MODEL_K + offset_c * CHUNK_SIZE * D_MODEL_K\n + tl.arange(0, D_MODEL_K)\n )\n DGK_ptr = (\n DGK + offset_bh * L * D_MODEL_K + offset_c * CHUNK_SIZE * D_MODEL_K\n + tl.arange(0, D_MODEL_K)\n )\n\n D_GK_last_exp_ptr = (\n DGK_last_exp + offset_bh * NUM_CHUNK * D_MODEL_K + offset_c * D_MODEL_K\n + tl.arange(0, D_MODEL_K)\n )\n\n cumsum_gradient = tl.zeros([D_MODEL_K], dtype=tl.float32)\n grad_gk_last = tl.zeros([D_MODEL_K], dtype=tl.float32)\n\n gk_last = tl.load(GK_cumsum_ptr + (CHUNK_SIZE - 1) * D_MODEL_K).to(tl.float32)\n cumsum_gradient += tl.load(D_GK_last_exp_ptr) * tl.exp(gk_last)\n\n GK_ptr += (CHUNK_SIZE - 1) * D_MODEL_K\n GK_cumsum_ptr += (CHUNK_SIZE - 1) * D_MODEL_K\n Q_ptr += (CHUNK_SIZE - 1) * D_MODEL_K\n K_ptr += (CHUNK_SIZE - 1) * D_MODEL_K\n\n DQ_exp_ptr += (CHUNK_SIZE - 1) * D_MODEL_K\n DK_reduce_ptr += (CHUNK_SIZE - 1) * D_MODEL_K\n DGK_cumsum_ptr += (CHUNK_SIZE - 1) * D_MODEL_K\n DQ_ptr += (CHUNK_SIZE - 1) * D_MODEL_K\n DK_ptr += (CHUNK_SIZE - 1) * D_MODEL_K\n DGK_ptr += (CHUNK_SIZE - 1) * D_MODEL_K\n\n for idx in range(CHUNK_SIZE - 1, -1, -1):\n gk_cs = tl.load(GK_cumsum_ptr).to(tl.float32)\n k = tl.load(K_ptr).to(tl.float32)\n grad_k = tl.exp(gk_last - gk_cs) * tl.load(DK_reduce_ptr).to(tl.float32)\n tl.store(DK_ptr, grad_k.to(DK_ptr.dtype.element_ty))\n grad_k *= k\n cumsum_gradient -= grad_k\n grad_gk_last += grad_k\n\n q = tl.load(Q_ptr).to(tl.float32)\n grad_q = tl.exp(gk_cs) * tl.load(DQ_exp_ptr)\n tl.store(DQ_ptr, grad_q.to(DK_ptr.dtype.element_ty))\n cumsum_gradient += grad_q * q.to(tl.float32)\n\n cumsum_gradient += tl.load(DGK_cumsum_ptr).to(tl.float32)\n\n tl.store(DGK_ptr, cumsum_gradient.to(DGK_ptr.dtype.element_ty))\n\n Q_ptr -= D_MODEL_K\n DQ_exp_ptr -= D_MODEL_K\n K_ptr -= D_MODEL_K\n DK_reduce_ptr -= D_MODEL_K\n GK_cumsum_ptr -= D_MODEL_K\n DGK_cumsum_ptr -= D_MODEL_K\n DQ_ptr -= D_MODEL_K\n DK_ptr -= D_MODEL_K\n DGK_ptr -= D_MODEL_K\n\n DGK_ptr = (\n DGK + offset_bh * L * D_MODEL_K + offset_c * CHUNK_SIZE * D_MODEL_K\n + tl.arange(0, D_MODEL_K) + (CHUNK_SIZE - 1) * D_MODEL_K\n )\n GK_ptr = (\n GK + offset_bh * L * D_MODEL_K + offset_c * CHUNK_SIZE * D_MODEL_K\n + tl.arange(0, D_MODEL_K) + (CHUNK_SIZE - 1) * D_MODEL_K\n )\n\n grad_gk_last = grad_gk_last + 0.0\n for idx in range(CHUNK_SIZE - 1, -1, -1):\n dgk = tl.load(DGK_ptr).to(tl.float32)\n dgk += grad_gk_last\n\n gk = tl.load(GK_ptr).to(tl.float32)\n dgk = tl.where(gk >= clamp_min, (dgk), 0.0)\n\n tl.store(DGK_ptr, dgk.to(DGK_ptr.dtype.element_ty))\n DGK_ptr -= D_MODEL_K\n GK_ptr -= D_MODEL_K\n\n\nclass PreprocessCumSum_GK(torch.autograd.Function):\n @staticmethod\n def forward(ctx, q, k, gk, normalizer_gk=8, clamp_min=-3):\n q = q.contiguous()\n k = k.contiguous()\n gk = gk.contiguous()\n\n B, H, NUM_CHUNK, CHUNK_SIZE, D = q.shape\n\n D_k = k.shape[-1]\n\n grid = (B * H, NUM_CHUNK)\n ctx.grid = grid\n\n k_reduce = torch.empty_like(k)\n\n q_exp = torch.empty_like(q)\n\n gk_cumsum = torch.empty_like(gk)\n\n gk_last_exp = torch.empty_like(gk[:, :, :, 0], dtype=torch.float32)\n\n _fwd_preprocess_cumsum_gk[grid](\n q, k, gk, gk_cumsum, q_exp, k_reduce, gk_last_exp, CHUNK_SIZE=CHUNK_SIZE,\n NUM_CHUNK=NUM_CHUNK, L=CHUNK_SIZE * NUM_CHUNK, normalizer=normalizer_gk,\n clamp_min=clamp_min, D_MODEL_K=D_k, num_warps=8 if D_k >= 512 else 4,\n )\n\n ctx.grid = grid\n ctx.save_for_backward(q, k, gk, gk_cumsum)\n ctx.normalizer_gk = normalizer_gk\n ctx.clamp_min = clamp_min\n\n return gk_cumsum, k_reduce, q_exp, gk_last_exp\n\n @staticmethod\n def backward(ctx, dgk_cumsum, dk_reduce, dq_exp, dgk_last_exp):\n dgk_cumsum = dgk_cumsum.contiguous()\n dk_reduce = dk_reduce.contiguous()\n dq_exp = dq_exp.contiguous()\n dgk_last_exp = dgk_last_exp.contiguous()\n\n q, k, gk, gk_cumsum = ctx.saved_tensors\n grid = ctx.grid\n\n dq = torch.empty_like(q)\n dk = torch.empty_like(k)\n dgk = torch.empty_like(gk)\n\n B, H, NUM_CHUNK, CHUNK_SIZE, D_k = q.shape\n\n _bwd_preprocess_cumsum_gk[grid](\n q, k, gk, gk_cumsum, dq_exp, dk_reduce, dgk_last_exp, dgk_cumsum,\n dq, dk, dgk, CHUNK_SIZE=CHUNK_SIZE, NUM_CHUNK=NUM_CHUNK,\n L=CHUNK_SIZE * NUM_CHUNK, normalizer=ctx.normalizer_gk, clamp_min=ctx.clamp_min,\n D_MODEL_K=D_k, num_warps=8 if D_k >= 512 else 4,\n )\n\n return dq, dk, dgk, None, None, None\n", - "description_1": "Use triton language to create a forward and backward preprocessing kernel for cumulative sum with guard on key tensors. The forward kernel (_fwd_preprocess_cumsum_gk) takes in 12 parameters, processes the cumulative sum of the given input tensors, and outputs the transformed tensors while applying certain mathematical operations and storing them back. The backward kernel (_bwd_preprocess_cumsum_gk) takes in 16 parameters, computes gradients for the inputs, and adjusts the inputs based on computed gradients.", - "description_2": "Use triton language to implement forward and backward pass kernels for preprocessing cumulative sum operations in a neural network. Ensure to handle input tensors with CUDA support and manage memory operations effectively for high performance.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef _fwd_preprocess_cumsum_gv(\n V,\n GV,\n GV_cumsum,\n GV_exp,\n V_reduce,\n GV_last_exp,\n NUM_CHUNK,\n L,\n normalizer,\n clamp_min,\n D_MODEL_V: tl.constexpr,\n CHUNK_SIZE: tl.constexpr,\n):\n offset_bh = tl.program_id(0)\n offset_c = tl.program_id(1)\n\n GV_ptr = (\n GV\n + offset_bh * L * D_MODEL_V\n + offset_c * CHUNK_SIZE * D_MODEL_V\n + tl.arange(0, D_MODEL_V)\n )\n\n GV_last_exp_ptr = (\n GV_last_exp\n + offset_bh * NUM_CHUNK * D_MODEL_V\n + offset_c * D_MODEL_V\n + tl.arange(0, D_MODEL_V)\n )\n\n GV_cumsum_ptr = (\n GV_cumsum\n + offset_bh * L * D_MODEL_V\n + offset_c * CHUNK_SIZE * D_MODEL_V\n + tl.arange(0, D_MODEL_V)\n )\n GV_exp_ptr = (\n GV_exp\n + offset_bh * L * D_MODEL_V\n + offset_c * CHUNK_SIZE * D_MODEL_V\n + tl.arange(0, D_MODEL_V)\n )\n\n cumsum = tl.zeros([D_MODEL_V], dtype=tl.float32)\n\n for _ in range(CHUNK_SIZE):\n gv = tl.load(GV_ptr).to(tl.float32)\n gv = tl.where(gv >= clamp_min, gv, clamp_min)\n cumsum += gv\n\n tl.store(GV_cumsum_ptr, cumsum.to(GV_cumsum_ptr.dtype.element_ty))\n tl.store(GV_exp_ptr, tl.exp(cumsum).to(GV_cumsum_ptr.dtype.element_ty))\n\n GV_cumsum_ptr += D_MODEL_V\n GV_exp_ptr += D_MODEL_V\n GV_ptr += D_MODEL_V\n\n tl.store(GV_last_exp_ptr, tl.exp(cumsum).to(GV_last_exp_ptr.dtype.element_ty))\n\n tl.debug_barrier()\n\n V_ptr = (\n V\n + offset_bh * L * D_MODEL_V\n + offset_c * CHUNK_SIZE * D_MODEL_V\n + tl.arange(0, D_MODEL_V)\n )\n GV_cumsum_ptr = (\n GV_cumsum\n + offset_bh * L * D_MODEL_V\n + offset_c * CHUNK_SIZE * D_MODEL_V\n + tl.arange(0, D_MODEL_V)\n )\n V_reduce_ptr = (\n V_reduce\n + offset_bh * L * D_MODEL_V\n + offset_c * CHUNK_SIZE * D_MODEL_V\n + tl.arange(0, D_MODEL_V)\n )\n\n for _ in range(CHUNK_SIZE):\n v = tl.load(V_ptr)\n gv = tl.load(GV_cumsum_ptr)\n v_reduce = v * tl.exp(cumsum - gv)\n tl.store(V_reduce_ptr, v_reduce.to(V_reduce_ptr.dtype.element_ty))\n\n V_ptr += D_MODEL_V\n V_reduce_ptr += D_MODEL_V\n GV_cumsum_ptr += D_MODEL_V\n\n@triton.jit\ndef _bwd_preprocess_cumsum_gv(\n V,\n GV,\n GV_cumsum,\n DGV_cumsum_exp,\n DV_reduce,\n DGV_last_exp,\n DGV_cumsum,\n DV,\n DGV,\n NUM_CHUNK,\n L,\n normalizer,\n clamp_min,\n D_MODEL_V: tl.constexpr,\n CHUNK_SIZE: tl.constexpr,\n):\n offset_bh = tl.program_id(0)\n offset_c = tl.program_id(1)\n V_ptr = (\n V\n + offset_bh * L * D_MODEL_V\n + offset_c * CHUNK_SIZE * D_MODEL_V\n + tl.arange(0, D_MODEL_V)\n )\n GV_ptr = (\n GV\n + offset_bh * L * D_MODEL_V\n + offset_c * CHUNK_SIZE * D_MODEL_V\n + tl.arange(0, D_MODEL_V)\n )\n GV_cumsum_ptr = (\n GV_cumsum\n + offset_bh * L * D_MODEL_V\n + offset_c * CHUNK_SIZE * D_MODEL_V\n + tl.arange(0, D_MODEL_V)\n )\n\n DV_ptr = (\n DV\n + offset_bh * L * D_MODEL_V\n + offset_c * CHUNK_SIZE * D_MODEL_V\n + tl.arange(0, D_MODEL_V)\n )\n DV_reduce_ptr = (\n DV_reduce\n + offset_bh * L * D_MODEL_V\n + offset_c * CHUNK_SIZE * D_MODEL_V\n + tl.arange(0, D_MODEL_V)\n )\n DGV_cumsum_ptr = (\n DGV_cumsum\n + offset_bh * L * D_MODEL_V\n + offset_c * CHUNK_SIZE * D_MODEL_V\n + tl.arange(0, D_MODEL_V)\n )\n DGV_cumsum_exp_ptr = (\n DGV_cumsum_exp\n + offset_bh * L * D_MODEL_V\n + offset_c * CHUNK_SIZE * D_MODEL_V\n + tl.arange(0, D_MODEL_V)\n )\n\n DGV_ptr = (\n DGV\n + offset_bh * L * D_MODEL_V\n + offset_c * CHUNK_SIZE * D_MODEL_V\n + tl.arange(0, D_MODEL_V)\n )\n\n D_GV_last_exp_ptr = (\n DGV_last_exp\n + offset_bh * NUM_CHUNK * D_MODEL_V\n + offset_c * D_MODEL_V\n + tl.arange(0, D_MODEL_V)\n )\n\n cumsum_gradient = tl.zeros([D_MODEL_V], dtype=tl.float32)\n grad_gv_last = tl.zeros([D_MODEL_V], dtype=tl.float32)\n\n gv_last = tl.load(GV_cumsum_ptr + (CHUNK_SIZE - 1) * D_MODEL_V)\n cumsum_gradient += tl.load(D_GV_last_exp_ptr) * tl.exp(gv_last).to(tl.float32)\n\n GV_ptr += (CHUNK_SIZE - 1) * D_MODEL_V\n GV_cumsum_ptr += (CHUNK_SIZE - 1) * D_MODEL_V\n\n V_ptr += (CHUNK_SIZE - 1) * D_MODEL_V\n\n DV_reduce_ptr += (CHUNK_SIZE - 1) * D_MODEL_V\n DGV_cumsum_ptr += (CHUNK_SIZE - 1) * D_MODEL_V\n DGV_cumsum_exp_ptr += (CHUNK_SIZE - 1) * D_MODEL_V\n DV_ptr += (CHUNK_SIZE - 1) * D_MODEL_V\n DGV_ptr += (CHUNK_SIZE - 1) * D_MODEL_V\n\n for idx in range(CHUNK_SIZE - 1, -1, -1):\n gv_cs = tl.load(GV_cumsum_ptr).to(tl.float32)\n v = tl.load(V_ptr).to(tl.float32)\n grad_v = tl.exp(gv_last - gv_cs) * tl.load(DV_reduce_ptr).to(tl.float32)\n tl.store(DV_ptr, grad_v.to(DV_ptr.dtype.element_ty))\n grad_v *= v\n cumsum_gradient -= grad_v\n grad_gv_last += grad_v\n\n grad_v = tl.exp(gv_cs) * tl.load(DGV_cumsum_exp_ptr)\n cumsum_gradient += grad_v\n\n cumsum_gradient += tl.load(DGV_cumsum_ptr).to(tl.float32)\n\n tl.store(DGV_ptr, cumsum_gradient.to(DGV_ptr.dtype.element_ty))\n\n V_ptr -= D_MODEL_V\n DV_reduce_ptr -= D_MODEL_V\n GV_cumsum_ptr -= D_MODEL_V\n DGV_cumsum_ptr -= D_MODEL_V\n DV_ptr -= D_MODEL_V\n DGV_ptr -= D_MODEL_V\n DGV_cumsum_exp_ptr -= D_MODEL_V\n\n DGV_ptr = (\n DGV\n + offset_bh * L * D_MODEL_V\n + offset_c * CHUNK_SIZE * D_MODEL_V\n + tl.arange(0, D_MODEL_V)\n + (CHUNK_SIZE - 1) * D_MODEL_V\n )\n GV_ptr = (\n GV\n + offset_bh * L * D_MODEL_V\n + offset_c * CHUNK_SIZE * D_MODEL_V\n + tl.arange(0, D_MODEL_V)\n + (CHUNK_SIZE - 1) * D_MODEL_V\n )\n\n grad_gv_last = grad_gv_last + 0.0\n\n for idx in range(CHUNK_SIZE - 1, -1, -1):\n dgv = tl.load(DGV_ptr).to(tl.float32)\n dgv += grad_gv_last\n gv = tl.load(GV_ptr).to(tl.float32)\n\n dgv = tl.where(gv >= clamp_min, dgv, 0.0)\n\n tl.store(DGV_ptr, dgv.to(DGV_ptr.dtype.element_ty))\n DGV_ptr -= D_MODEL_V\n GV_ptr -= D_MODEL_V\n\n\nclass PreprocessCumSum_GV(torch.autograd.Function):\n @staticmethod\n def forward(ctx, v, gv, normalizer_gv=8, clamp_min=-3):\n v = v.contiguous()\n gv = gv.contiguous()\n\n B, H, NUM_CHUNK, CHUNK_SIZE, D_v = v.shape\n\n grid = (B * H, NUM_CHUNK)\n ctx.grid = grid\n\n gv_cumsum = torch.empty_like(gv, dtype=torch.float32)\n gv_cumsum_exp = torch.empty_like(gv)\n v_reduce = torch.empty_like(v)\n gv_last_exp = torch.empty_like(gv[:, :, :, 0], dtype=torch.float32)\n _fwd_preprocess_cumsum_gv[grid](\n v,\n gv,\n gv_cumsum,\n gv_cumsum_exp,\n v_reduce,\n gv_last_exp,\n CHUNK_SIZE=CHUNK_SIZE,\n NUM_CHUNK=NUM_CHUNK,\n L=CHUNK_SIZE * NUM_CHUNK,\n normalizer=normalizer_gv,\n clamp_min=clamp_min,\n D_MODEL_V=D_v,\n num_warps=8 if D_v >= 512 else 4,\n )\n\n ctx.grid = grid\n ctx.save_for_backward(v, gv, gv_cumsum)\n ctx.normalizer_gv = normalizer_gv\n ctx.clamp_min = clamp_min\n\n return gv_cumsum, v_reduce, gv_cumsum_exp, gv_last_exp\n\n @staticmethod\n def backward(ctx, dgv_cumsum, dv_reduce, dgv_cumsum_exp, dgv_last_exp):\n\n dgv_cumsum = dgv_cumsum.contiguous()\n dv_reduce = dv_reduce.contiguous()\n dgv_cumsum_exp = dgv_cumsum_exp.contiguous()\n dgv_last_exp = dgv_last_exp.contiguous()\n v, gv, gv_cumsum = ctx.saved_tensors\n grid = ctx.grid\n\n B, H, NUM_CHUNK, CHUNK_SIZE, D_v = v.shape\n\n dv = torch.empty_like(v)\n dgv = torch.empty_like(gv)\n _bwd_preprocess_cumsum_gv[grid](\n v,\n gv,\n gv_cumsum,\n dgv_cumsum_exp,\n dv_reduce,\n dgv_last_exp,\n dgv_cumsum,\n dv,\n dgv,\n CHUNK_SIZE=CHUNK_SIZE,\n NUM_CHUNK=NUM_CHUNK,\n L=CHUNK_SIZE * NUM_CHUNK,\n normalizer=ctx.normalizer_gv,\n clamp_min=ctx.clamp_min,\n D_MODEL_V=D_v,\n num_warps=8 if D_v >= 512 else 4,\n )\n return dv, dgv, None, None, None\n", - "description_1": "Use triton language to implement two kernels: _fwd_preprocess_cumsum_gv and _bwd_preprocess_cumsum_gv. The _fwd_preprocess_cumsum_gv kernel takes 12 parameters: V, GV, GV_cumsum, GV_exp, V_reduce, GV_last_exp, NUM_CHUNK, L, normalizer, clamp_min, D_MODEL_V, and CHUNK_SIZE. It computes cumulative sums and exponentials of the input GV, storing results in GV_cumsum, GV_exp, and GV_last_exp. The _bwd_preprocess_cumsum_gv kernel takes 15 parameters: V, GV, GV_cumsum, DGV_cumsum_exp, DV_reduce, DGV_last_exp, DGV_cumsum, DV, DGV, NUM_CHUNK, L, normalizer, clamp_min, D_MODEL_V, and CHUNK_SIZE. It computes gradients for the forward pass, updating DV and DGV based on the input gradients and stored cumulative sums.", - "description_2": "Use triton language to create forward and backward kernels for cumulative sum and exponential operations with gradient computation.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n\n@triton.jit\ndef _fwd_kernel_compute_A(\n Q,\n K,\n GK,\n A,\n stride_q1,\n stride_q2,\n stride_q3,\n stride_q4,\n stride_a1,\n stride_a2,\n stride_a3,\n stride_a4,\n Z,\n H,\n N_CTX,\n D,\n BLOCK_DMODEL_QK: tl.constexpr,\n BLOCK_M: tl.constexpr,\n BLOCK_N: tl.constexpr,\n):\n start_m = tl.program_id(0)\n off_hz = tl.program_id(1)\n off_k = tl.program_id(2)\n\n qk_offset = off_hz * stride_q2 + off_k * BLOCK_DMODEL_QK\n a_offset = (off_k * Z * H + off_hz) * stride_a2\n\n lo = 0\n hi = BLOCK_N\n\n Q_ptr = (\n Q\n + qk_offset\n + (start_m) * stride_q3\n + tl.arange(0, BLOCK_DMODEL_QK)[None, :]\n + tl.arange(0, 16)[:, None] * stride_q4\n )\n\n K_ptr = (\n K\n + qk_offset\n + (start_m) * stride_q3\n + tl.arange(0, BLOCK_DMODEL_QK)[:, None]\n + tl.arange(0, 16)[None, :] * stride_q4\n )\n\n GK_K_ptr = (\n GK\n + qk_offset\n + (start_m) * stride_q3\n + tl.arange(0, BLOCK_DMODEL_QK)[:, None]\n + tl.arange(0, 16)[None, :] * stride_q4\n )\n\n GK_Q_ptr = (\n GK\n + qk_offset\n + (start_m) * stride_q3\n + tl.arange(0, BLOCK_DMODEL_QK)[None, :]\n + tl.arange(0, 16)[:, None] * stride_q4\n )\n\n A_ptr = (\n A\n + a_offset\n + (start_m) * stride_a3\n + tl.arange(0, 16)[None, :]\n + tl.arange(0, 16)[:, None] * stride_a4\n )\n\n for q_high in range(16, hi, 16):\n q = tl.load(Q_ptr + q_high * stride_q4)\n q_gk = tl.load(GK_Q_ptr + q_high * stride_q4).to(tl.float32)\n q_normalizer = tl.load(\n GK\n + qk_offset\n + start_m * stride_q3\n + q_high * stride_q4\n + tl.arange(0, BLOCK_DMODEL_QK)\n ).to(tl.float32)\n q_gk2 = tl.exp(q_gk - q_normalizer[None, :])\n q = q * q_gk2.to(q.dtype)\n\n # inter-chunk bf16\n for k_high in range(0, q_high, 16):\n k = tl.load(K_ptr + k_high * stride_q4)\n k_gk = tl.load(GK_K_ptr + k_high * stride_q4).to(tl.float32)\n k_gk = tl.exp(q_normalizer[:, None] - k_gk)\n k = k * k_gk.to(k.dtype)\n qk = tl.dot(q, k, allow_tf32=False)\n tl.store(A_ptr + q_high * stride_a4 + k_high, qk.to(A_ptr.dtype.element_ty))\n\n ## intra chunk fp32\n for q_high in range(lo, hi, 16):\n q = tl.load(Q_ptr + q_high * stride_q4)\n q_gk = tl.load(GK_Q_ptr + q_high * stride_q4).to(tl.float32)\n q_normalizer = tl.load(\n GK\n + qk_offset\n + start_m * stride_q3\n + q_high * stride_q4\n + tl.arange(0, BLOCK_DMODEL_QK)\n ).to(tl.float32)\n q_gk2 = tl.exp(q_gk - q_normalizer[None, :])\n q = q * q_gk2\n q_gk3 = tl.exp(q_normalizer[None, :] - q_gk)\n k = tl.load(K_ptr + q_high * stride_q4)\n k = k * tl.trans(q_gk3)\n\n qk = tl.dot(q, k, allow_tf32=False)\n qk = tl.where(tl.arange(0, 16)[:, None] >= tl.arange(0, 16)[None, :], qk, 0.0)\n tl.store(A_ptr + q_high * stride_a4 + q_high, qk.to(A_ptr.dtype.element_ty))\n\n\n@triton.jit\ndef _bwd_kernel_dqk(\n Q,\n K,\n GK,\n DA,\n DQ,\n DK,\n DGK,\n stride_q1,\n stride_q2,\n stride_q3,\n stride_q4,\n stride_a1,\n stride_a2,\n stride_a3,\n stride_a4,\n Z,\n H,\n N_CTX,\n D,\n BLOCK_DMODEL_QK: tl.constexpr,\n BLOCK_M: tl.constexpr,\n BLOCK_N: tl.constexpr,\n):\n\n start_m = tl.program_id(0)\n off_hz = tl.program_id(1)\n off_k = tl.program_id(2)\n\n qk_offset = off_hz * stride_q2 + BLOCK_DMODEL_QK * off_k\n a_offset = off_hz * stride_a2\n\n lo = 0\n hi = BLOCK_N\n\n Q_ptr = (\n Q\n + qk_offset\n + (start_m) * stride_q3\n + tl.arange(0, BLOCK_DMODEL_QK)[None, :]\n + tl.arange(0, 16)[:, None] * stride_q4\n )\n\n DQ_ptr = (\n DQ\n + qk_offset\n + (start_m) * stride_q3\n + tl.arange(0, BLOCK_DMODEL_QK)[None, :]\n + tl.arange(0, 16)[:, None] * stride_q4\n )\n\n K_ptr = (\n K\n + qk_offset\n + (start_m) * stride_q3\n + tl.arange(0, BLOCK_DMODEL_QK)[None, :]\n + tl.arange(0, 16)[:, None] * stride_q4\n )\n\n GK_K_ptr = (\n GK\n + qk_offset\n + (start_m) * stride_q3\n + tl.arange(0, BLOCK_DMODEL_QK)[None, :]\n + tl.arange(0, 16)[:, None] * stride_q4\n )\n\n GK_Q_ptr = (\n GK\n + qk_offset\n + (start_m) * stride_q3\n + tl.arange(0, BLOCK_DMODEL_QK)[None, :]\n + tl.arange(0, 16)[:, None] * stride_q4\n )\n\n DA_ptr = (\n DA\n + a_offset\n + (start_m) * stride_a3\n + tl.arange(0, 16)[None, :]\n + tl.arange(0, 16)[:, None] * stride_a4\n )\n\n # inter chunk dq. bf16\n for q_high in range(lo + 16, hi, 16):\n q = tl.load(Q_ptr + q_high * stride_q4)\n\n q_normalizer = tl.load(\n GK\n + qk_offset\n + (start_m * stride_q3)\n + q_high * stride_q4\n + tl.arange(0, BLOCK_DMODEL_QK)\n ).to(tl.float32)\n\n dq2 = tl.zeros([16, BLOCK_DMODEL_QK], dtype=tl.float32)\n\n for k_high in range(0, q_high, 16):\n k = tl.load(K_ptr + k_high * stride_q4)\n k_gk = tl.load(GK_K_ptr + k_high * stride_q4).to(tl.float32)\n dqk = tl.load(DA_ptr + q_high * stride_a4 + k_high).to(k.dtype)\n k_gk = tl.exp(q_normalizer[None, :] - k_gk)\n k = k * k_gk.to(k.dtype)\n dq2 += tl.dot(dqk, k, allow_tf32=False)\n\n dq2 = dq2.to(q.dtype)\n q_gk = tl.load(GK_Q_ptr + q_high * stride_q4).to(tl.float32)\n q_gk = tl.exp(q_gk - q_normalizer[None, :])\n dq = dq2 * q_gk.to(q.dtype)\n dq_gk = dq * q\n\n DQ_ptr = (\n DQ\n + qk_offset\n + (start_m) * stride_q3\n + tl.arange(0, BLOCK_DMODEL_QK)[None, :]\n + tl.arange(0, 16)[:, None] * stride_q4\n + q_high * stride_q4\n )\n tl.store(DQ_ptr, dq.to(DQ_ptr.dtype.element_ty))\n\n DGK_Q_ptr = (\n DGK\n + qk_offset\n + (start_m) * stride_q3\n + tl.arange(0, BLOCK_DMODEL_QK)[None, :]\n + tl.arange(0, 16)[:, None] * stride_q4\n + q_high * stride_q4\n )\n tl.store(DGK_Q_ptr, dq_gk.to(DGK_Q_ptr.dtype.element_ty))\n\n tl.debug_barrier()\n\n for k_high in range(lo, hi - 16, 16):\n k = tl.load(K_ptr + k_high * stride_q4)\n k_gk = tl.load(GK_K_ptr + k_high * stride_q4)\n dk = tl.zeros([16, BLOCK_DMODEL_QK], dtype=tl.float32)\n dgk = tl.zeros([16, BLOCK_DMODEL_QK], dtype=tl.float32)\n\n for q_high in range(k_high + 16, hi, 16):\n q = tl.load(Q_ptr + q_high * stride_q4)\n q_normalizer = tl.load(\n GK\n + qk_offset\n + (start_m * stride_q3)\n + q_high * stride_q4\n + tl.arange(0, BLOCK_DMODEL_QK)\n ).to(tl.float32)\n q_gk = tl.load(GK_Q_ptr + q_high * stride_q4).to(tl.float32)\n q_gk = tl.exp(q_gk - q_normalizer[None, :]).to(q.dtype)\n q = q * q_gk\n dqk = tl.load(DA_ptr + q_high * stride_a4 + k_high).to(q.dtype)\n\n k_gk2 = tl.exp(q_normalizer[None, :] - k_gk)\n\n dk2 = tl.dot(tl.trans(dqk), q, allow_tf32=False)\n dk += dk2 * k_gk2\n dgk -= dk2 * k * k_gk2\n\n DK_ptr = (\n DK\n + qk_offset\n + (start_m) * stride_q3\n + tl.arange(0, BLOCK_DMODEL_QK)[None, :]\n + tl.arange(0, 16)[:, None] * stride_q4\n + k_high * stride_q4\n )\n tl.store(DK_ptr, dk.to(DK_ptr.dtype.element_ty))\n\n DGK_K_ptr = (\n DGK\n + qk_offset\n + (start_m) * stride_q3\n + tl.arange(0, BLOCK_DMODEL_QK)[None, :]\n + tl.arange(0, 16)[:, None] * stride_q4\n + k_high * stride_q4\n )\n prev = tl.load(DGK_K_ptr)\n tl.store(DGK_K_ptr, (prev + dgk).to(DGK_K_ptr.dtype.element_ty))\n\n tl.debug_barrier()\n\n DK_ptr = (\n DK\n + qk_offset\n + (start_m) * stride_q3\n + tl.arange(0, BLOCK_DMODEL_QK)[None, :]\n + tl.arange(0, 16)[:, None] * stride_q4\n )\n\n DGK_K_ptr = (\n DGK\n + qk_offset\n + (start_m) * stride_q3\n + tl.arange(0, BLOCK_DMODEL_QK)[None, :]\n + tl.arange(0, 16)[:, None] * stride_q4\n )\n\n DQ_ptr = (\n DQ\n + qk_offset\n + (start_m) * stride_q3\n + tl.arange(0, BLOCK_DMODEL_QK)[None, :]\n + tl.arange(0, 16)[:, None] * stride_q4\n )\n\n ## intra chunk, fp32.\n for q_high in range(lo, hi, 16):\n q = tl.load(Q_ptr + q_high * stride_q4)\n q_gk = tl.load(GK_Q_ptr + q_high * stride_q4).to(tl.float32)\n q_normalizer = tl.load(\n GK\n + qk_offset\n + start_m * stride_q3\n + q_high * stride_q4\n + tl.arange(0, BLOCK_DMODEL_QK)\n ).to(tl.float32)\n q_gk2 = tl.exp(q_gk - q_normalizer[None, :])\n q2 = q * q_gk2\n q_gk3 = tl.exp(q_normalizer[None, :] - q_gk)\n\n k = tl.load(K_ptr + q_high * stride_q4)\n k2 = k * q_gk3\n\n dqk = tl.load(DA_ptr + q_high * stride_a4 + q_high)\n dqk = tl.where(tl.arange(0, 16)[:, None] >= tl.arange(0, 16)[None, :], dqk, 0.0)\n\n dk2 = tl.dot(tl.trans(dqk), q2, allow_tf32=False)\n dk = dk2 * q_gk3\n prev_dk = tl.load(DK_ptr + q_high * stride_q4)\n tl.store(\n DK_ptr + q_high * stride_q4, (dk + prev_dk).to(DK_ptr.dtype.element_ty)\n )\n\n dgk = -dk * k\n dq2 = tl.dot(dqk, k2, allow_tf32=False)\n dq = dq2 * q_gk2\n\n prev_dq = tl.load(DQ_ptr + q_high * stride_q4)\n tl.store(\n DQ_ptr + q_high * stride_q4, (dq + prev_dq).to(DQ_ptr.dtype.element_ty)\n )\n\n dgk += dq * q\n prev_dq_gk = tl.load(DGK_K_ptr + q_high * stride_q4)\n tl.store(\n DGK_K_ptr + q_high * stride_q4,\n (dgk + prev_dq_gk).to(DGK_K_ptr.dtype.element_ty),\n )\n\n\nclass FlashGRet(torch.autograd.Function):\n @staticmethod\n def forward(ctx, q, k, gk):\n q = q.contiguous()\n k = k.contiguous()\n gk = gk.contiguous()\n\n capability = torch.cuda.get_device_capability()\n if capability[0] < 8:\n raise RuntimeError(\n \"Flash attention currently only supported for compute capability >= 80\"\n )\n\n BLOCK_M = BLOCK_N = q.shape[-2]\n Lq, Lk = q.shape[-1], k.shape[-1]\n assert Lq == Lk\n if Lk > 128:\n assert Lk % 128 == 0\n\n BLOCK_DMODEL_QK = min(Lk, 128)\n ctx.BLOCK_DMODEL_QK = BLOCK_DMODEL_QK\n\n A = torch.zeros(\n max(1, Lk // 128),\n q.shape[0],\n q.shape[1],\n q.shape[2],\n BLOCK_N,\n BLOCK_N,\n device=q.device,\n dtype=q.dtype,\n )\n\n grid = (q.shape[2], q.shape[0] * q.shape[1], max(1, Lk // 128))\n\n _fwd_kernel_compute_A[grid](\n q,\n k,\n gk,\n A,\n q.stride(0),\n q.stride(1),\n q.stride(2),\n q.stride(3),\n A.stride(1),\n A.stride(2),\n A.stride(3),\n A.stride(4),\n q.shape[0],\n q.shape[1],\n q.shape[2],\n q.shape[3],\n BLOCK_N=BLOCK_N,\n BLOCK_DMODEL_QK=BLOCK_DMODEL_QK,\n BLOCK_M=BLOCK_M,\n num_warps=8 if ctx.BLOCK_DMODEL_QK == 128 else 4,\n num_stages=8,\n )\n\n ctx.save_for_backward(q, k, gk)\n ctx.grid = grid\n ctx.BLOCK_N = BLOCK_N\n ctx.BLOCK_N = BLOCK_N\n ctx.head = q.shape[1]\n return A.sum(0).to(q.dtype)\n\n @staticmethod\n def backward(ctx, dA):\n dA = dA.contiguous()\n q, k, gk = ctx.saved_tensors\n\n dq = torch.zeros_like(q)\n dk = torch.zeros_like(k)\n dgk = torch.zeros_like(gk)\n\n BLOCK_N = ctx.BLOCK_N\n BLOCK_M = BLOCK_N\n Lq, Lk = q.shape[-1], k.shape[-1]\n\n _bwd_kernel_dqk[ctx.grid](\n q,\n k,\n gk,\n dA,\n dq,\n dk,\n dgk,\n q.stride(0),\n q.stride(1),\n q.stride(2),\n q.stride(3),\n dA.stride(0),\n dA.stride(1),\n dA.stride(2),\n dA.stride(3),\n q.shape[0],\n q.shape[1],\n q.shape[2],\n q.shape[3],\n BLOCK_N=BLOCK_N,\n BLOCK_DMODEL_QK=ctx.BLOCK_DMODEL_QK,\n BLOCK_M=BLOCK_M,\n num_warps=8 if ctx.BLOCK_DMODEL_QK == 128 else 4,\n num_stages=5,\n )\n\n return dq, dk, dgk, None\n", - "description_1": "Use triton language to implement forward and backward kernel functions for a customized attention mechanism. The forward kernel (_fwd_kernel_compute_A) computes the attention matrix A from input tensors Q, K, and GK using block sizes BLOCK_DMODEL_QK, BLOCK_M, and BLOCK_N for chunked processing. The backward kernel (_bwd_kernel_dqk) computes gradients DQ, DK, and DGK for the input tensors based on the backward pass of the attention mechanism.", - "description_2": "Use triton language to implement kernel functions for computing attention matrices and their gradients using block-level processing with customizable block sizes.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n\n@triton.jit\ndef _fwd_compute_O(\n A,\n V,\n GV,\n O,\n stride_a1,\n stride_a2,\n stride_a3,\n stride_a4,\n stride_v1,\n stride_v2,\n stride_v3,\n stride_v4,\n BLOCK_M: tl.constexpr,\n BLOCK_N: tl.constexpr,\n BLOCK_DMODEL_V: tl.constexpr,\n):\n start_m = tl.program_id(0)\n off_hz = tl.program_id(1)\n off_v = tl.program_id(2)\n\n a_offset = off_hz * stride_a2\n v_offset = off_hz * stride_v2 + off_v * BLOCK_DMODEL_V\n\n lo = 0\n hi = BLOCK_N\n\n V_ptr = (\n V\n + v_offset\n + (start_m) * stride_v3\n + tl.arange(0, BLOCK_DMODEL_V)[None, :]\n + tl.arange(0, 16)[:, None] * stride_v4\n )\n\n O_ptr = (\n O\n + v_offset\n + (start_m) * stride_v3\n + tl.arange(0, BLOCK_DMODEL_V)[None, :]\n + tl.arange(0, 16)[:, None] * stride_v4\n )\n\n GV_ptr = (\n GV\n + v_offset\n + (start_m) * stride_v3\n + tl.arange(0, BLOCK_DMODEL_V)[None, :]\n + tl.arange(0, 16)[:, None] * stride_v4\n )\n\n A_ptr = (\n A\n + a_offset\n + (start_m) * stride_a3\n + tl.arange(0, 16)[None, :]\n + tl.arange(0, 16)[:, None] * stride_a4\n )\n\n for q_high in range(lo + 16, hi, 16):\n q_gv_normalizer = tl.load(\n GV\n + v_offset\n + (start_m) * stride_v3\n + q_high * stride_v4\n + tl.arange(0, BLOCK_DMODEL_V)\n ).to(tl.float32)\n acc = tl.zeros([16, BLOCK_DMODEL_V], dtype=tl.float32)\n\n for k_high in range(0, q_high, 16):\n qk = tl.load(A_ptr + q_high * stride_a4 + k_high)\n v = tl.load(V_ptr + k_high * stride_v4)\n k_gv = tl.load(GV_ptr + k_high * stride_v4)\n k_gv = tl.exp(q_gv_normalizer[None, :] - k_gv)\n v = v * k_gv.to(v.dtype)\n output = tl.dot(qk.to(v.dtype), v, allow_tf32=False)\n acc += output\n\n tl.store(O_ptr + q_high * stride_v4, acc.to(O.dtype.element_ty))\n\n tl.store(\n O_ptr, tl.zeros([16, BLOCK_DMODEL_V], dtype=tl.float32).to(O.dtype.element_ty)\n )\n\n tl.debug_barrier()\n\n for q_high in range(lo, hi, 16):\n q_gv_normalizer = tl.load(\n GV\n + v_offset\n + (start_m) * stride_v3\n + q_high * stride_v4\n + tl.arange(0, BLOCK_DMODEL_V)\n ).to(tl.float32)\n\n qk = tl.load(A_ptr + q_high * stride_a4 + q_high)\n v = tl.load(V_ptr + q_high * stride_v4)\n k_gv = tl.load(GV_ptr + q_high * stride_v4)\n k_gv2 = tl.exp(q_gv_normalizer[None, :] - k_gv)\n\n v = v * k_gv2\n output = tl.dot(qk.to(tl.float32), v, allow_tf32=False)\n\n q_gv = tl.exp(k_gv - q_gv_normalizer[None, :])\n\n prev = tl.load(O_ptr + q_high * stride_v4)\n output += prev\n output = output * q_gv\n\n tl.store(O_ptr + q_high * stride_v4, output.to(O.dtype.element_ty))\n\n\n@triton.jit\ndef _bwd_kernel_dav(\n V,\n GV,\n A,\n O,\n DO,\n DA,\n DV,\n DGV,\n Z,\n H,\n stride_a1,\n stride_a2,\n stride_a3,\n stride_a4,\n stride_v1,\n stride_v2,\n stride_v3,\n stride_v4,\n BLOCK_M: tl.constexpr,\n BLOCK_N: tl.constexpr,\n BLOCK_DMODEL_V: tl.constexpr,\n):\n\n start_m = tl.program_id(0)\n off_hz = tl.program_id(1)\n off_v = tl.program_id(2)\n\n a_offset = off_hz * stride_a2\n da_offset = (off_v * Z * H + off_hz) * stride_a2\n v_offset = off_hz * stride_v2 + off_v * BLOCK_DMODEL_V\n\n lo = 0\n hi = BLOCK_N\n\n DO_ptr = (\n DO\n + v_offset\n + (start_m) * stride_v3\n + tl.arange(0, BLOCK_DMODEL_V)[None, :]\n + tl.arange(0, 16)[:, None] * stride_v4\n )\n\n O_ptr = (\n O\n + v_offset\n + (start_m) * stride_v3\n + tl.arange(0, BLOCK_DMODEL_V)[None, :]\n + tl.arange(0, 16)[:, None] * stride_v4\n )\n\n DV_ptr = (\n DV\n + v_offset\n + (start_m) * stride_v3\n + tl.arange(0, BLOCK_DMODEL_V)[None, :]\n + tl.arange(0, 16)[:, None] * stride_v4\n )\n\n GV_ptr = (\n GV\n + v_offset\n + (start_m) * stride_v3\n + tl.arange(0, BLOCK_DMODEL_V)[None, :]\n + tl.arange(0, 16)[:, None] * stride_v4\n )\n\n DGV_ptr = (\n DGV\n + v_offset\n + (start_m) * stride_v3\n + tl.arange(0, BLOCK_DMODEL_V)[None, :]\n + tl.arange(0, 16)[:, None] * stride_v4\n )\n\n A_ptr = (\n A\n + a_offset\n + (start_m) * stride_a3\n + tl.arange(0, 16)[None, :]\n + tl.arange(0, 16)[:, None] * stride_a4\n )\n\n DA_ptr = (\n DA\n + da_offset\n + (start_m) * stride_a3\n + tl.arange(0, 16)[None, :]\n + tl.arange(0, 16)[:, None] * stride_a4\n )\n\n # pre-compute do*q_gv. in-place update\n for q_high in range(lo, hi, 16):\n do = tl.load(DO_ptr + q_high * stride_v4)\n o = tl.load(O_ptr + q_high * stride_v4)\n tl.store(DGV_ptr + q_high * stride_v4, (do * o))\n\n q_gv_normalizer = tl.load(\n GV\n + v_offset\n + (start_m) * stride_v3\n + q_high * stride_v4\n + tl.arange(0, BLOCK_DMODEL_V)\n ).to(tl.float32)\n q_gv = tl.load(GV_ptr + q_high * stride_v4)\n q_gv = tl.exp(q_gv - q_gv_normalizer[None, :])\n do = do * q_gv\n\n tl.store(DO_ptr + q_high * stride_v4, do.to(DO_ptr.dtype.element_ty))\n\n tl.debug_barrier()\n\n V_ptr = (\n V\n + v_offset\n + (start_m) * stride_v3\n + tl.arange(0, BLOCK_DMODEL_V)[:, None]\n + tl.arange(0, 16)[None, :] * stride_v4\n )\n GV_ptr = (\n GV\n + v_offset\n + (start_m) * stride_v3\n + tl.arange(0, BLOCK_DMODEL_V)[:, None]\n + tl.arange(0, 16)[None, :] * stride_v4\n )\n\n for q_high in range(lo + 16, hi, 16):\n do = tl.load(DO_ptr + q_high * stride_v4)\n q_gv_normalizer = tl.load(\n GV\n + v_offset\n + (start_m) * stride_v3\n + q_high * stride_v4\n + tl.arange(0, BLOCK_DMODEL_V)\n ).to(tl.float32)\n\n for k_high in range(0, q_high, 16):\n v = tl.load(V_ptr + k_high * stride_v4)\n k_gv = tl.load(GV_ptr + k_high * stride_v4)\n k_gv = tl.exp(q_gv_normalizer[:, None] - k_gv)\n\n v2 = v * k_gv.to(v.dtype)\n dqk = tl.dot(do, v2, allow_tf32=False)\n tl.store(DA_ptr + q_high * stride_a4 + k_high, dqk.to(DA.dtype.element_ty))\n\n tl.debug_barrier()\n\n A_ptr = (\n A\n + a_offset\n + (start_m) * stride_a3\n + tl.arange(0, 16)[:, None]\n + tl.arange(0, 16)[None, :] * stride_a4\n )\n\n V_ptr = (\n V\n + v_offset\n + (start_m) * stride_v3\n + tl.arange(0, BLOCK_DMODEL_V)[None, :]\n + tl.arange(0, 16)[:, None] * stride_v4\n )\n GV_ptr = (\n GV\n + v_offset\n + (start_m) * stride_v3\n + tl.arange(0, BLOCK_DMODEL_V)[None, :]\n + tl.arange(0, 16)[:, None] * stride_v4\n )\n\n for k_high in range(0, hi, 16):\n dv = tl.zeros([16, BLOCK_DMODEL_V], dtype=tl.float32)\n\n k_gv = tl.load(GV_ptr + k_high * stride_v4)\n\n for q_high in range(k_high + 16, BLOCK_N, 16):\n do = tl.load(DO_ptr + q_high * stride_v4)\n\n kq = tl.load(A_ptr + q_high * stride_a4 + k_high).to(do.dtype)\n\n q_gv_normalizer = tl.load(\n GV\n + v_offset\n + (start_m) * stride_v3\n + q_high * stride_v4\n + tl.arange(0, BLOCK_DMODEL_V)\n ).to(tl.float32)\n k_gv2 = tl.exp(q_gv_normalizer[None, :] - k_gv)\n\n dv2 = tl.dot(kq, do, allow_tf32=False)\n dv += dv2 * k_gv2\n\n v = tl.load(V_ptr + k_high * stride_v4)\n tl.store(DV_ptr + k_high * stride_v4, dv.to(v.dtype))\n\n prev_dv = tl.load(DGV_ptr + k_high * stride_v4)\n tl.store(DGV_ptr + k_high * stride_v4, prev_dv - dv * v)\n\n tl.debug_barrier()\n\n A_ptr = (\n A\n + a_offset\n + (start_m) * stride_a3\n + tl.arange(0, 16)[:, None]\n + tl.arange(0, 16)[None, :] * stride_a4\n )\n\n for q_high in range(lo, hi, 16):\n do = tl.load(DO_ptr + q_high * stride_v4)\n\n q_gv_normalizer = tl.load(\n GV\n + v_offset\n + start_m * stride_v3\n + q_high * stride_v4\n + tl.arange(0, BLOCK_DMODEL_V)\n ).to(tl.float32)\n\n v = tl.load(V_ptr + q_high * stride_v4)\n k_gv = tl.load(GV_ptr + q_high * stride_v4)\n k_gv = tl.exp(q_gv_normalizer[None, :] - k_gv)\n v2 = v * k_gv\n\n dqk = tl.dot(do.to(v2.dtype), tl.trans(v2), allow_tf32=False)\n dqk = tl.where(tl.arange(0, 16)[:, None] >= tl.arange(0, 16)[None, :], dqk, 0.0)\n tl.store(DA_ptr + q_high * stride_a4 + q_high, dqk.to(DA_ptr.dtype.element_ty))\n\n kq = tl.load(A_ptr + q_high * stride_a4 + q_high).to(do.dtype)\n dv2 = tl.dot(kq, do, allow_tf32=False)\n\n dv = dv2 * k_gv\n prev_dv = tl.load(DV_ptr + q_high * stride_v4)\n tl.store(DV_ptr + q_high * stride_v4, (prev_dv + dv).to(DV.dtype.element_ty))\n\n prev_gdv = tl.load(DGV_ptr + q_high * stride_v4)\n prev_gdv -= dv * v\n tl.store(DGV_ptr + q_high * stride_v4, prev_gdv.to(DGV.dtype.element_ty))\n\n\nclass FlashGRet_O(torch.autograd.Function):\n @staticmethod\n def forward(ctx, A, v, gv, chunk_size=16):\n assert gv.dtype == torch.float32\n\n capability = torch.cuda.get_device_capability()\n if capability[0] < 8:\n raise RuntimeError(\n \"Flash attention currently only supported for compute capability >= 80\"\n )\n\n BLOCK_M = BLOCK_N = v.shape[-2]\n\n Lv = v.shape[-1]\n BLOCK_V = min(128, Lv)\n ctx.BLOCK_V = BLOCK_V\n\n assert v.shape[-1] % BLOCK_V == 0\n\n grid = (v.shape[2], v.shape[0] * v.shape[1], max(1, v.shape[-1] // BLOCK_V))\n\n o = torch.empty_like(v)\n\n _fwd_compute_O[grid](\n A,\n v,\n gv,\n o,\n A.stride(0),\n A.stride(1),\n A.stride(2),\n A.stride(3),\n v.stride(0),\n v.stride(1),\n v.stride(2),\n v.stride(3),\n BLOCK_N=BLOCK_N,\n BLOCK_M=BLOCK_M,\n BLOCK_DMODEL_V=BLOCK_V,\n num_warps=8 if BLOCK_V == 128 else 4,\n num_stages=5,\n )\n\n ctx.save_for_backward(A, v, gv, o)\n ctx.grid = grid\n ctx.chunk_size = chunk_size\n return o\n\n @staticmethod\n def backward(ctx, do):\n do = do.contiguous()\n A, v, gv, o = ctx.saved_tensors\n BLOCK_V = ctx.BLOCK_V\n assert v.shape[-1] % BLOCK_V == 0\n\n dv = torch.zeros_like(v)\n dgv = torch.zeros_like(gv)\n\n BLOCK_M = BLOCK_N = v.shape[-2]\n\n grid = ctx.grid\n\n dA = torch.empty(\n v.shape[-1] // BLOCK_V if BLOCK_V == 128 else 1,\n A.shape[0],\n A.shape[1],\n A.shape[2],\n A.shape[3],\n A.shape[3],\n device=A.device,\n dtype=A.dtype,\n )\n\n _bwd_kernel_dav[grid](\n v,\n gv,\n A,\n o,\n do,\n dA,\n dv,\n dgv,\n v.shape[0],\n v.shape[1],\n A.stride(0),\n A.stride(1),\n A.stride(2),\n A.stride(3),\n v.stride(0),\n v.stride(1),\n v.stride(2),\n v.stride(3),\n BLOCK_N=BLOCK_N,\n BLOCK_M=BLOCK_M,\n BLOCK_DMODEL_V=ctx.BLOCK_V,\n num_warps=8,\n num_stages=4,\n )\n\n return dA.sum(0).to(A), dv.to(v), dgv.to(gv), None\n", - "description_1": "Use triton language to implement two kernels for computing forward and backward passes for a custom operation on inputs A, V, GV, and O with specific strides and block configurations. The forward kernel computes a matrix multiplication and stores results in O, while the backward kernel computes gradients dA, dv, and dgv using inputs DO, DA, DV, and DGV.", - "description_2": "Use triton language to implement forward and backward kernels for a matrix operation, utilizing specific block sizes and memory access patterns for optimization.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\nfrom torch.cuda.amp import custom_bwd, custom_fwd\n\n@triton.jit\ndef fused_recurrent_gla_fwd_kernel(\n q, k, v, g, o, s_qk_h, s_qk_t, s_qk_d, s_vo_h, s_vo_t, s_vo_d, B, H, T, scale, BK: tl.constexpr, BV: tl.constexpr, DK: tl.constexpr, DV: tl.constexpr\n):\n i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n i_bh % H\n\n p_q = q + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK)\n p_k = k + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK)\n p_g = g + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK)\n p_v = v + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV)\n p_o = o + (i_bh + i_k * B * H) * s_vo_h + i_v * BV + tl.arange(0, BV)\n\n mask_bk = (i_k * BK + tl.arange(0, BK)) < DK\n mask_bv = (i_v * BV + tl.arange(0, BV)) < DV\n\n h = tl.zeros([BV, BK], dtype=tl.float32)\n\n for _ in range(0, T):\n _k = tl.load(p_k, mask=mask_bk, other=0).to(tl.float32)\n _v = tl.load(p_v, mask=mask_bv, other=0).to(tl.float32)\n _q = tl.load(p_q, mask=mask_bk, other=0).to(tl.float32) * scale\n _g = tl.load(p_g, mask=mask_bk, other=float(\"-inf\")).to(tl.float32)\n _g = tl.math.exp(_g)\n\n h = h * _g[None, :] + _k[None, :] * _v[:, None]\n _o = h * _q[None, :]\n _o = tl.sum(_o, axis=1)\n tl.store(p_o, _o.to(p_o.dtype.element_ty), mask=mask_bv)\n\n p_q += DK\n p_k += DK\n p_o += DV\n p_v += DV\n p_g += DK\n\n@triton.jit\ndef fused_recurrent_gla_bwd_kernel(\n q, k, v, g, do, dq, dk, dv, s_qk_h, s_qk_t, s_qk_d, s_vo_h, s_vo_t, s_vo_d, B, H, T, scale, BK: tl.constexpr, BV: tl.constexpr, DK: tl.constexpr, DV: tl.constexpr\n):\n i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n i_bh % H\n\n p_q = q + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK)\n p_k = k + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK)\n p_g = g + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK)\n p_v = v + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV)\n p_do = do + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV)\n\n p_dq = dq + (i_bh + i_v * B * H) * s_qk_h + i_k * BK + tl.arange(0, BK)\n mask_bk = i_k * BK + tl.arange(0, BK) < DK\n mask_bv = i_v * BV + tl.arange(0, BV) < DV\n\n h = tl.zeros([BK, BV], dtype=tl.float32)\n\n for i in range(0, T):\n _k = tl.load(p_k, mask=mask_bk, other=0).to(tl.float32)\n _v = tl.load(p_v, mask=mask_bv, other=0).to(tl.float32)\n _do = tl.load(p_do, mask=mask_bv, other=0).to(tl.float32)\n _g = tl.load(p_g, mask=mask_bk, other=0).to(tl.float32)\n _g = tl.exp(_g)\n\n h = h * _g[:, None] + _k[:, None] * _v[None, :]\n _d_q = h * _do[None, :]\n d_q = tl.sum(_d_q, axis=1) * scale\n tl.store(p_dq, d_q.to(p_dq.dtype.element_ty), mask=mask_bk)\n\n p_k += DK\n p_do += DV\n p_v += DV\n p_dq += DK\n p_g += DK\n\n tl.debug_barrier()\n\n p_q = q + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (T - 1) * DK\n p_k = k + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (T - 1) * DK\n p_g = g + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (T - 1) * DK\n p_do = do + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV) + (T - 1) * DV\n p_v = v + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV) + (T - 1) * DV\n p_dk = (\n dk + (i_bh + i_v * B * H) * s_qk_h + i_k * BK + tl.arange(0, BK) + (T - 1) * DK\n )\n p_dv = (\n dv + (i_bh + i_k * B * H) * s_vo_h + i_v * BV + tl.arange(0, BV) + (T - 1) * DV\n )\n d_h = tl.zeros([BK, BV], dtype=tl.float32)\n\n for _ in range(T):\n _do = tl.load(p_do, mask=mask_bv, other=0).to(tl.float32)\n _q = tl.load(p_q, mask=mask_bk, other=0).to(tl.float32) * scale\n _k = tl.load(p_k, mask=mask_bk, other=0).to(tl.float32)\n _v = tl.load(p_v, mask=mask_bv, other=0).to(tl.float32)\n _g = tl.load(p_g, mask=mask_bk, other=0).to(tl.float32)\n _g = tl.exp(_g)\n\n d_h += _q[:, None] * _do[None, :]\n d_k = tl.sum(d_h * _v[None, :], axis=1)\n d_v = tl.sum(d_h * _k[:, None], axis=0)\n d_h = d_h * _g[:, None]\n tl.store(p_dk, d_k.to(p_dk.dtype.element_ty), mask=mask_bk)\n tl.store(p_dv, d_v.to(p_dv.dtype.element_ty), mask=mask_bv)\n\n p_do -= DV\n p_q -= DK\n p_k -= DK\n p_v -= DV\n p_dk -= DK\n p_dv -= DV\n p_g -= DK\n\nclass FusedRecurrentGLAFunction(torch.autograd.Function):\n @staticmethod\n @custom_fwd\n def forward(ctx, q, k, v, g):\n q = q.contiguous()\n k = k.contiguous()\n v = v.contiguous()\n g = g.contiguous()\n batch_size, n_heads, seq_len, d_head_qk = q.shape\n d_head_v = v.shape[-1]\n\n scale = 1\n BK, BV = min(d_head_qk, 32), min(d_head_v, 32)\n NK, NV = triton.cdiv(d_head_qk, BK), triton.cdiv(d_head_v, BV)\n num_stages = 1\n num_warps = 1\n\n o = q.new_empty(NK, batch_size, n_heads, seq_len, d_head_v)\n\n grid = (NV, NK, batch_size * n_heads)\n fused_recurrent_gla_fwd_kernel[grid](\n q,\n k,\n v,\n g,\n o,\n q.stride(1),\n q.stride(2),\n q.stride(3),\n v.stride(1),\n v.stride(2),\n v.stride(3),\n batch_size,\n n_heads,\n seq_len,\n scale,\n DK=d_head_qk,\n DV=d_head_v,\n BK=BK,\n BV=BV,\n num_warps=num_warps,\n num_stages=num_stages,\n )\n\n o = o.sum(0)\n ctx.save_for_backward(q, k, v, g)\n return o.to(q.dtype)\n\n @staticmethod\n @custom_bwd\n def backward(ctx, do):\n q, k, v, g = ctx.saved_tensors\n q = q.contiguous()\n k = k.contiguous()\n v = v.contiguous()\n g = g.contiguous()\n do = do.contiguous()\n\n batch_size, n_heads, seq_len, d_head_qk = q.shape\n d_head_v = v.shape[-1]\n scale = 1\n\n BK, BV = min(d_head_qk, 32), min(d_head_v, 32)\n NK, NV = triton.cdiv(d_head_qk, BK), triton.cdiv(d_head_v, BV)\n num_stages = 1\n num_warps = 1\n\n dq = q.new_empty(NV, batch_size, n_heads, seq_len, d_head_qk)\n dk = q.new_empty(NV, batch_size, n_heads, seq_len, d_head_qk)\n dv = q.new_empty(NK, batch_size, n_heads, seq_len, d_head_v)\n grid = (NV, NK, batch_size * n_heads)\n\n fused_recurrent_gla_bwd_kernel[grid](\n q,\n k,\n v,\n g,\n do,\n dq,\n dk,\n dv,\n q.stride(1),\n q.stride(2),\n q.stride(3),\n v.stride(1),\n v.stride(2),\n v.stride(3),\n batch_size,\n n_heads,\n seq_len,\n scale,\n DK=d_head_qk,\n DV=d_head_v,\n BK=BK,\n BV=BV,\n num_warps=num_warps,\n num_stages=num_stages,\n )\n dq = dq.sum(0)\n dk = dk.sum(0)\n dv = dv.sum(0)\n _dg = dq * q - dk * k\n _dg_cumsum = _dg.cumsum(-2)\n dg = _dg + _dg_cumsum[:, :, -1, None] - _dg_cumsum\n return dq.to(q.dtype), dk.to(k.dtype), dv.to(v.dtype), dg.to(g.dtype)\n\nfused_recurrent_gla = FusedRecurrentGLAFunction.apply\n", - "description_1": "Use triton language to implement a fused recurrent gated linear attention (GLA) forward and backward kernel. The forward kernel takes 18 parameters: q, k, v, g, o, s_qk_h, s_qk_t, s_qk_d, s_vo_h, s_vo_t, s_vo_d, B, H, T, scale, BK, BV, DK, DV. The backward kernel takes 20 parameters: q, k, v, g, do, dq, dk, dv, s_qk_h, s_qk_t, s_qk_d, s_vo_h, s_vo_t, s_vo_d, B, H, T, scale, BK, BV, DK, DV. The kernels perform operations on input tensors to compute the output and gradients for a recurrent GLA layer.", - "description_2": "Use triton language to create a fused recurrent GLA function with forward and backward passes, handling input tensors and computing outputs and gradients efficiently.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n# Triton kernel for forward pass of simple RMS normalization\n@triton.jit\ndef srms_norm_fw(X, Y, V, stride, N, eps, BLOCK_SIZE_N: tl.constexpr):\n row = tl.program_id(0)\n cols = tl.arange(0, BLOCK_SIZE_N)\n mask = cols < N\n\n # Move to this row\n x_ptrs = X + row * stride + cols\n x = tl.load(x_ptrs, mask=mask, other=0.0).to(tl.float32)\n\n x_zm = tl.where(mask, x, 0.0)\n\n x_var = tl.sum(x_zm * x_zm, axis=0) / N\n rstd = 1.0 / tl.sqrt(x_var + eps)\n\n # Normalize, optionally affine\n y = x_zm * rstd\n tl.store(V + row, rstd)\n\n y_ptrs = Y + row * stride + cols\n tl.store(y_ptrs, y, mask=mask)\n\n# Triton kernel for backward pass of simple RMS normalization\n@triton.jit\ndef srms_norm_bwd_dx_fused(\n DX, DY,\n X, V,\n stride, N,\n BLOCK_SIZE_N: tl.constexpr,\n):\n row = tl.program_id(0)\n cols = tl.arange(0, BLOCK_SIZE_N)\n mask = cols < N\n\n # offset data pointers to start at the row of interest\n x_ptrs = X + row * stride + cols\n dy_ptrs = DY + row * stride + cols\n\n # load data to SRAM\n x = tl.load(x_ptrs, mask=mask, other=0)\n dy = tl.load(dy_ptrs, mask=mask, other=0)\n rstd = tl.load(V + row)\n\n # compute dx\n xhat = x * rstd\n wdy = dy\n\n xhat = tl.where(mask, xhat, 0.)\n wdy = tl.where(mask, wdy, 0.)\n mean1 = tl.sum(xhat * wdy, axis=0) / N\n dx = (wdy - (xhat * mean1)) * rstd\n\n # write-back dx\n mask = cols < N # re-materialize the mask to save registers\n dx_ptrs = DX + row * stride + cols\n tl.store(dx_ptrs, dx, mask=mask)\n\nclass _SrmsNorm(torch.autograd.Function):\n @staticmethod\n def forward(ctx, x, eps):\n if x.dtype == torch.float16:\n eps = max(eps, 1.6e-5)\n\n y = torch.empty_like(x)\n x_arg = x.reshape(-1, x.shape[-1])\n M, N = x_arg.shape\n\n rstd = torch.empty((M,), dtype=torch.float32, device=x.device)\n\n MAX_FUSED_SIZE = 65536 // x.element_size()\n BLOCK_SIZE_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))\n if N > BLOCK_SIZE_N:\n raise RuntimeError(\"This layer norm doesn't support feature dim >= 64KB.\")\n\n if not x_arg.is_contiguous() or not y.is_contiguous():\n x_arg = x_arg.contiguous()\n y = y.contiguous()\n\n num_warps = min(max(BLOCK_SIZE_N // 256, 1), 16)\n\n srms_norm_fw[(M,)](\n x_arg, y, rstd,\n x_arg.stride(0),\n N,\n eps,\n num_warps=num_warps,\n BLOCK_SIZE_N=BLOCK_SIZE_N,\n )\n\n ctx.save_for_backward(x, rstd)\n ctx.BLOCK_SIZE_N = BLOCK_SIZE_N\n ctx.num_warps = num_warps\n\n return y.reshape_as(x)\n\n @staticmethod\n def backward(ctx, dy):\n x, rstd = ctx.saved_tensors\n x = x.reshape(-1, x.size(-1))\n M, N = x.size()\n\n GROUP_SIZE_M = 32\n if N <= 8192:\n GROUP_SIZE_M = 64\n if N <= 4096:\n GROUP_SIZE_M = 96\n if N <= 2048:\n GROUP_SIZE_M = 128\n if N <= 1024:\n GROUP_SIZE_M = 256\n\n if dy.dtype == torch.float32:\n GROUP_SIZE_M = GROUP_SIZE_M // 2\n\n dy = dy.contiguous()\n dx = torch.empty_like(dy)\n\n assert (\n dy.numel() == x.numel()\n ), \"Something is wrong in the backward graph, possibly because of an inplace operation after the layernorm\"\n\n num_warps = min(max(ctx.BLOCK_SIZE_N // 256, 1), 16)\n\n srms_norm_bwd_dx_fused[(M,)](\n dx, dy, x,\n rstd,\n x.stride(0),\n N,\n BLOCK_SIZE_N=ctx.BLOCK_SIZE_N,\n num_warps=num_warps\n )\n\n dx = dx.reshape_as(dy)\n return dx, None, None\n", - "description_1": "Use triton language to implement a simple RMS normalization with two kernels: one for the forward pass and one for the backward pass. The forward kernel 'srms_norm_fw' takes 7 parameters: input tensor X, output tensor Y, tensor V for storing rstd, stride, dimension N, epsilon for numerical stability, and BLOCK_SIZE_N for block size. The backward kernel 'srms_norm_bwd_dx_fused' takes 7 parameters: output gradient DX, input gradient DY, input tensor X, tensor V for rstd, stride, dimension N, and BLOCK_SIZE_N for block size.", - "description_2": "Use triton language to create a simple RMS normalization with forward and backward kernels, handling input/output tensors, strides, dimensions, and block sizes.", - "difficulty": 2 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef _act_no_dim_fwd_triton(\n X,\n O,\n n: tl.constexpr,\n d: tl.constexpr,\n ACT: tl.constexpr,\n BLOCK: tl.constexpr,\n):\n off_n = tl.program_id(0)\n off_block_d = tl.program_id(1)\n # compute offset\n offset_n = off_n * d\n offset_d = off_block_d * BLOCK\n # mask\n d_mask = (offset_d + tl.arange(0, BLOCK)) < d\n\n # compute\n x_block_ptr = X + offset_n + offset_d + tl.arange(0, BLOCK)\n o_block_ptr = O + offset_n + offset_d + tl.arange(0, BLOCK)\n x = tl.load(x_block_ptr, mask=d_mask, other=0).to(tl.float32)\n o = x\n\n if ACT == \"relu\":\n o = tl.where(x >= 0, x, 0)\n elif ACT == \"sigmoid\":\n o = tl.sigmoid(x)\n elif ACT == \"silu\":\n o = x * tl.sigmoid(x)\n\n tl.store(o_block_ptr, o.to(o_block_ptr.dtype.element_ty), mask=d_mask)\n\n@triton.jit\ndef _act_no_dim_bwd_triton(\n X,\n DO,\n DX,\n n: tl.constexpr,\n d: tl.constexpr,\n ACT: tl.constexpr,\n BLOCK: tl.constexpr,\n):\n off_n = tl.program_id(0)\n off_block_d = tl.program_id(1)\n # compute offset\n offset_n = off_n * d\n offset_d = off_block_d * BLOCK\n # mask\n d_mask = (offset_d + tl.arange(0, BLOCK)) < d\n\n # compute\n x_block_ptr = X + offset_n + offset_d + tl.arange(0, BLOCK)\n do_block_ptr = DO + offset_n + offset_d + tl.arange(0, BLOCK)\n dx_block_ptr = DX + offset_n + offset_d + tl.arange(0, BLOCK)\n x = tl.load(x_block_ptr, mask=d_mask, other=0).to(tl.float32)\n do = tl.load(do_block_ptr, mask=d_mask, other=0).to(tl.float32)\n dx = do\n\n if ACT == \"relu\":\n dx = tl.where(x >= 0, do, 0)\n elif ACT == \"sigmoid\":\n sigmoid = tl.sigmoid(x)\n dx = do * sigmoid * (1 - sigmoid)\n elif ACT == \"silu\":\n sigmoid = tl.sigmoid(x)\n dx = do * sigmoid * (1 + x * (1 - sigmoid))\n\n tl.store(dx_block_ptr, dx.to(dx_block_ptr.dtype.element_ty), mask=d_mask)\n\ndef act_no_dim_fwd_triton(x, act=\"none\"):\n if act == \"none\":\n return x\n\n shape = x.shape\n n = torch.prod(torch.tensor(shape[:-1])).item()\n d = x.shape[-1]\n o = torch.empty_like(x)\n\n def grid(meta):\n return (n, triton.cdiv(d, meta[\"BLOCK\"]))\n\n _act_no_dim_fwd_triton[grid](\n x,\n o,\n n,\n d,\n act,\n )\n\n return o\n\ndef act_no_dim_bwd_triton(x, do, act=\"none\"):\n if act == \"none\":\n return do\n\n shape = x.shape\n n = torch.prod(torch.tensor(shape[:-1])).item()\n d = x.shape[-1]\n\n dx = torch.empty_like(x)\n\n def grid(meta):\n return (n, triton.cdiv(d, meta[\"BLOCK\"]))\n\n _act_no_dim_bwd_triton[grid](\n x,\n do,\n dx,\n n,\n d,\n act,\n )\n\n return dx\n\ndef act_no_dim_triton(x, act=\"none\"):\n class ActNoDimTriton(torch.autograd.Function):\n @staticmethod\n def forward(ctx, x, act=\"none\"):\n o = act_no_dim_fwd_triton(x, act)\n\n ctx.save_for_backward(x)\n ctx.act = act\n\n return o\n\n @staticmethod\n def backward(ctx, do):\n x = ctx.saved_tensors[0]\n act = ctx.act\n\n dx = act_no_dim_bwd_triton(x, do, act)\n\n return dx, None\n\n return ActNoDimTriton.apply(x, act)\n", - "description_1": "Use triton language to implement forward and backward activation functions without considering dimensions, supporting 'none', 'relu', 'sigmoid', and 'silu' activations.", - "description_2": "Use triton language to create a forward and backward kernel for activation functions handling 'none', 'relu', 'sigmoid', and 'silu' cases.", - "difficulty": 2 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef _softmax_no_cache_fwd_triton(\n X,\n O,\n n: tl.constexpr,\n d: tl.constexpr,\n BLOCK: tl.constexpr,\n):\n off_n = tl.program_id(0)\n # compute offset\n offset_n = off_n * d\n # mask\n d_mask = tl.arange(0, BLOCK) < d\n\n # compute\n x_block_ptr = X + offset_n + tl.arange(0, BLOCK)\n o_block_ptr = O + offset_n + tl.arange(0, BLOCK)\n x = tl.load(x_block_ptr, mask=d_mask, other=-float(\"inf\"))\n # for stable\n x_minus_max = x - tl.max(x, axis=0)\n # softmax\n numerator = tl.exp(x_minus_max)\n denominator = tl.sum(numerator)\n o = numerator / denominator\n\n tl.store(o_block_ptr, o.to(o_block_ptr.dtype.element_ty), mask=d_mask)\n\n@triton.jit\ndef _softmax_no_cache_bwd_triton(\n X,\n DO,\n DX,\n n: tl.constexpr,\n d: tl.constexpr,\n BLOCK: tl.constexpr,\n):\n off_n = tl.program_id(0)\n # compute offset\n offset_n = off_n * d\n # mask\n d_mask = tl.arange(0, BLOCK) < d\n\n # compute\n x_block_ptr = X + offset_n + tl.arange(0, BLOCK)\n do_block_ptr = DO + offset_n + tl.arange(0, BLOCK)\n dx_block_ptr = DX + offset_n + tl.arange(0, BLOCK)\n\n x = tl.load(x_block_ptr, mask=d_mask, other=-float(\"inf\"))\n # for stable\n x_minus_max = x - tl.max(x, axis=0)\n # softmax\n numerator = tl.exp(x_minus_max)\n denominator = tl.sum(numerator)\n o = numerator / denominator\n\n do = tl.load(do_block_ptr, mask=d_mask, other=0)\n # scalar\n c = tl.sum(o * do, axis=0)\n dx = o * do - c * o\n\n tl.store(dx_block_ptr, dx.to(dx_block_ptr.dtype.element_ty), mask=d_mask)\n\n\ndef softmax_no_cache_fwd_triton(x, dim=-1):\n if dim != -1:\n x = x.transpose(dim, -1).contiguous()\n\n shape = x.shape\n n = torch.prod(torch.tensor(shape[:-1])).item()\n d = x.shape[-1]\n BLOCK = triton.next_power_of_2(d)\n o = torch.empty_like(x)\n\n grid = (n,)\n _softmax_no_cache_fwd_triton[grid](\n x,\n o,\n n,\n d,\n BLOCK,\n )\n\n if dim != -1:\n o = o.transpose(dim, -1).contiguous()\n\n return o\n\n\ndef softmax_no_cache_bwd_triton(o, do, dim=-1):\n if dim != -1:\n do = do.transpose(dim, -1).contiguous()\n o = o.transpose(dim, -1).contiguous()\n\n shape = o.shape\n n = torch.prod(torch.tensor(shape[:-1])).item()\n d = o.shape[-1]\n BLOCK = triton.next_power_of_2(d)\n dx = torch.empty_like(o)\n\n grid = (n,)\n _softmax_no_cache_bwd_triton[grid](o, do, dx, n, d, BLOCK)\n\n if dim != -1:\n dx = dx.transpose(dim, -1).contiguous()\n o = o.transpose(dim, -1).contiguous()\n\n return dx\n", - "description_1": "Use triton language to implement a forward and backward kernel for computing the softmax operation along a specified dimension. The forward kernel (_softmax_no_cache_fwd_triton) takes inputs X (input tensor), O (output tensor), n (number of blocks), d (last dimension size), and BLOCK (block size), and performs softmax computation using Triton load/store operations with masking. The backward kernel (_softmax_no_cache_bwd_triton) takes inputs X, DO (gradient of output), DX (gradient of input), n, d, and BLOCK, and computes the gradient of the input from the softmax operation. Both kernels are then called from their respective wrapper functions softmax_no_cache_fwd_triton and softmax_no_cache_bwd_triton.", - "description_2": "Use triton language to create softmax computation kernels (_softmax_no_cache_fwd_triton and _softmax_no_cache_bwd_triton) for forward and backward passes, respectively. Integrate these kernels in wrapper functions that manage input/output transformation and kernel invocation.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef _softmax_fwd_triton(\n X,\n O,\n n: tl.constexpr,\n d: tl.constexpr,\n BLOCK: tl.constexpr,\n):\n off_n = tl.program_id(0)\n # compute offset\n offset_n = off_n * d\n # mask\n d_mask = tl.arange(0, BLOCK) < d\n\n # compute\n x_block_ptr = X + offset_n + tl.arange(0, BLOCK)\n o_block_ptr = O + offset_n + tl.arange(0, BLOCK)\n x = tl.load(x_block_ptr, mask=d_mask, other=-float(\"inf\"))\n # for stable\n x_minus_max = x - tl.max(x, axis=0)\n # softmax\n numerator = tl.exp(x_minus_max)\n denominator = tl.sum(numerator)\n o = numerator / denominator\n\n tl.store(o_block_ptr, o.to(o_block_ptr.dtype.element_ty), mask=d_mask)\n\n\n@triton.jit\ndef _softmax_bwd_triton(\n O,\n DO,\n DX,\n n: tl.constexpr,\n d: tl.constexpr,\n BLOCK: tl.constexpr,\n):\n off_n = tl.program_id(0)\n # compute offset\n offset_n = off_n * d\n # mask\n d_mask = tl.arange(0, BLOCK) < d\n\n # compute\n o_block_ptr = O + offset_n + tl.arange(0, BLOCK)\n do_block_ptr = DO + offset_n + tl.arange(0, BLOCK)\n dx_block_ptr = DX + offset_n + tl.arange(0, BLOCK)\n o = tl.load(o_block_ptr, mask=d_mask, other=0)\n do = tl.load(do_block_ptr, mask=d_mask, other=0)\n # scalar\n c = tl.sum(o * do, axis=0)\n dx = o * do - c * o\n\n tl.store(dx_block_ptr, dx.to(dx_block_ptr.dtype.element_ty), mask=d_mask)\n\n\ndef softmax_fwd_triton(x, dim=-1):\n if dim != -1:\n x = x.transpose(dim, -1).contiguous()\n\n shape = x.shape\n n = torch.prod(torch.tensor(shape[:-1])).item()\n d = x.shape[-1]\n BLOCK = triton.next_power_of_2(d)\n o = torch.empty_like(x)\n\n grid = (n,)\n _softmax_fwd_triton[grid](\n x,\n o,\n n,\n d,\n BLOCK,\n )\n\n if dim != -1:\n o = o.transpose(dim, -1).contiguous()\n\n return o\n\n\ndef softmax_bwd_triton(o, do, dim=-1):\n if dim != -1:\n do = do.transpose(dim, -1).contiguous()\n o = o.transpose(dim, -1).contiguous()\n\n shape = o.shape\n n = torch.prod(torch.tensor(shape[:-1])).item()\n d = o.shape[-1]\n BLOCK = triton.next_power_of_2(d)\n dx = torch.empty_like(o)\n\n grid = (n,)\n _softmax_bwd_triton[grid](o, do, dx, n, d, BLOCK)\n\n if dim != -1:\n dx = dx.transpose(dim, -1).contiguous()\n o = o.transpose(dim, -1).contiguous()\n\n return dx\n", - "description_1": "Use triton language to implement a forward and backward softmax operation. The forward kernel '_softmax_fwd_triton' takes 5 parameters: X (input tensor), O (output tensor), n (number of elements in the batch), d (dimension size), and BLOCK (block size for parallelization). It computes the softmax of the input tensor X and stores the result in O. The backward kernel '_softmax_bwd_triton' takes 6 parameters: O (output from forward pass), DO (gradient of the output), DX (gradient of the input), n, d, and BLOCK. It computes the gradient of the input tensor based on the output and its gradient.", - "description_2": "Use triton language to create a softmax function with forward and backward passes, optimizing for parallel execution using block sizes.", - "difficulty": 3 - }, - { - "code": "import triton\nimport torch\n\n# Triton kernel for additive block recurrence forward pass\n@triton.jit\ndef _additive_block_recurrence_fwd(\n Q, K, V, G, O, S_INITIAL_STATE, DENOM_INITIAL_STATE, M_INITIAL_STATE,\n S_FINAL_STATE, DENOM_FINAL_STATE, M_FINAL_STATE,\n b: tl.constexpr, h: tl.constexpr, n: tl.constexpr, d: tl.constexpr,\n e: tl.constexpr, BLOCK_D: tl.constexpr, BLOCK_E: tl.constexpr,\n NUM_BLOCK_D: tl.constexpr, NUM_BLOCK_E: tl.constexpr,\n USE_INITIAL_STATE: tl.constexpr, OUTPUT_FINAL_STATE: tl.constexpr\n):\n off_bh = tl.program_id(2)\n off_bh % h\n off_bh // h\n off_d, off_e = tl.program_id(0), tl.program_id(1)\n # compute offset\n off_qkg = off_bh * n * d\n off_v = off_bh * n * e\n off_o = (off_d * b * h + off_bh) * n * e\n off_d = off_d * BLOCK_D\n off_e = off_e * BLOCK_E\n off_s = off_bh * d * e\n off_denom_m = off_bh * d\n # mask\n mask_denom_m = (off_d + tl.arange(0, BLOCK_D) < d)[:, None]\n\n # get block ptr\n q_trans_block_ptr = tl.make_block_ptr(\n base=Q + off_qkg,\n shape=(d, n),\n strides=(1, d),\n offsets=(off_d, 0),\n block_shape=(BLOCK_D, 1),\n order=(0, 1),\n )\n k_trans_block_ptr = tl.make_block_ptr(\n base=K + off_qkg,\n shape=(d, n),\n strides=(1, d),\n offsets=(off_d, 0),\n block_shape=(BLOCK_D, 1),\n order=(0, 1),\n )\n v_block_ptr = tl.make_block_ptr(\n base=V + off_v,\n shape=(n, e),\n strides=(e, 1),\n offsets=(0, off_e),\n block_shape=(1, BLOCK_E),\n order=(1, 0),\n )\n g_trans_block_ptr = tl.make_block_ptr(\n base=G + off_qkg,\n shape=(d, n),\n strides=(1, d),\n offsets=(off_d, 0),\n block_shape=(BLOCK_D, 1),\n order=(0, 1),\n )\n o_block_ptr = tl.make_block_ptr(\n base=O + off_o,\n shape=(n, e),\n strides=(e, 1),\n offsets=(0, off_e),\n block_shape=(1, BLOCK_E),\n order=(1, 0),\n )\n\n if USE_INITIAL_STATE:\n s_block_ptr = tl.make_block_ptr(\n base=S_INITIAL_STATE + off_s,\n shape=(d, e),\n strides=(e, 1),\n offsets=(off_d, off_e),\n block_shape=(BLOCK_D, BLOCK_E),\n order=(1, 0),\n )\n denom_block_ptr = (\n DENOM_INITIAL_STATE + off_denom_m + off_d + tl.arange(0, BLOCK_D)[:, None]\n )\n m_block_ptr = (\n M_INITIAL_STATE + off_denom_m + off_d + tl.arange(0, BLOCK_D)[:, None]\n )\n\n s = tl.load(s_block_ptr, boundary_check=(0, 1)).to(tl.float32)\n denom = tl.load(denom_block_ptr, mask=mask_denom_m).to(tl.float32)\n m = tl.load(m_block_ptr, mask=mask_denom_m).to(tl.float32)\n else:\n s = tl.zeros([BLOCK_D, BLOCK_E], dtype=tl.float32)\n denom = tl.zeros([BLOCK_D, 1], dtype=tl.float32)\n m = tl.zeros([BLOCK_D, 1], dtype=tl.float32) + (-1e5)\n\n for i in range(n):\n q_trans = tl.load(q_trans_block_ptr, boundary_check=(0, 1)).to(tl.float32)\n k_trans = tl.load(k_trans_block_ptr, boundary_check=(0, 1)).to(tl.float32)\n v = tl.load(v_block_ptr, boundary_check=(0, 1)).to(tl.float32)\n g_trans = tl.load(g_trans_block_ptr, boundary_check=(0, 1)).to(tl.float32)\n\n m_ = tl.maximum(m, g_trans)\n g_trans = g_trans - m_\n lambda_ = tl.exp(m - m_)\n g_exp_trans = tl.exp(g_trans)\n k_bar_trans = g_exp_trans * k_trans\n s = lambda_ * s + k_bar_trans.to(v.dtype) * v\n denom = lambda_ * denom + g_exp_trans\n o = (q_trans) * (s / denom)\n o = tl.sum(o, axis=0)[None, :]\n\n m = m_\n\n tl.store(o_block_ptr, o.to(o_block_ptr.dtype.element_ty), boundary_check=(0, 1))\n\n q_trans_block_ptr = tl.advance(q_trans_block_ptr, (0, 1))\n k_trans_block_ptr = tl.advance(k_trans_block_ptr, (0, 1))\n v_block_ptr = tl.advance(v_block_ptr, (1, 0))\n g_trans_block_ptr = tl.advance(g_trans_block_ptr, (0, 1))\n o_block_ptr = tl.advance(o_block_ptr, (1, 0))\n\n if OUTPUT_FINAL_STATE:\n s_final_block_ptr = tl.make_block_ptr(\n base=S_FINAL_STATE + off_s,\n shape=(d, e),\n strides=(e, 1),\n offsets=(off_d, off_e),\n block_shape=(BLOCK_D, BLOCK_E),\n order=(1, 0),\n )\n denom_final_block_ptr = (\n DENOM_FINAL_STATE + off_denom_m + off_d + tl.arange(0, BLOCK_D)[:, None]\n )\n m_final_block_ptr = (\n M_FINAL_STATE + off_denom_m + off_d + tl.arange(0, BLOCK_D)[:, None]\n )\n\n tl.store(\n s_final_block_ptr,\n s.to(s_final_block_ptr.dtype.element_ty),\n boundary_check=(0, 1),\n )\n\n tl.store(\n denom_final_block_ptr,\n denom.to(denom_final_block_ptr.dtype.element_ty),\n mask=mask_denom_m,\n )\n\n tl.store(\n m_final_block_ptr,\n m.to(m_final_block_ptr.dtype.element_ty),\n mask=mask_denom_m,\n )\n\n\nclass AdditiveRecurrenceFunction(torch.autograd.Function):\n @staticmethod\n def forward(ctx, q, k, v, g, initial_state=None, output_final_state=None):\n b, h, n, d = q.shape\n e = v.shape[-1]\n\n head_dim_d = max_power_of_2_divisor(d)\n head_dim_e = max_power_of_2_divisor(e)\n\n BLOCK_D, BLOCK_E = min(d, head_dim_d), min(e, head_dim_e)\n NUM_BLOCK_D, NUM_BLOCK_E = triton.cdiv(d, BLOCK_D), triton.cdiv(e, BLOCK_E)\n o = torch.empty(\n (NUM_BLOCK_D, b, h, n, e), dtype=q.dtype, device=torch.cuda.current_device()\n )\n\n if initial_state is not None:\n s_initial_state, denom_initial_state, m_initial_state = initial_state\n else:\n pass\n\n if output_final_state:\n s_final_state = torch.empty(\n (b, h, d, e), dtype=torch.float32, device=torch.cuda.current_device()\n )\n denom_final_state = torch.empty(\n (b, h, d, 1), dtype=torch.float32, device=torch.cuda.current_device()\n )\n m_final_state = torch.empty(\n (b, h, d, 1), dtype=torch.float32, device=torch.cuda.current_device()\n )\n else:\n s_final_state = None\n denom_final_state = None\n m_final_state = None\n\n USE_INITIAL_STATE = initial_state is not None\n OUTPUT_FINAL_STATE = output_final_state\n\n grid = (b * h, d)\n\n _additive_block_recurrence_fwd[grid](\n q, k, v, g, o, s_initial_state, denom_initial_state, m_initial_state,\n s_final_state, denom_final_state, m_final_state,\n b, h, n, d, e, BLOCK_D, BLOCK_E, NUM_BLOCK_D, NUM_BLOCK_E,\n USE_INITIAL_STATE, OUTPUT_FINAL_STATE\n )\n\n if OUTPUT_FINAL_STATE:\n final_state = (s_final_state, denom_final_state, m_final_state)\n else:\n final_state = None\n\n o = o.sum(0)\n\n ctx.save_for_backward(q, k, v, g)\n\n return o, final_state\n\n\ndef additive_rule_block_recurrence_triton(q, k, v, g, initial_state=None, output_final_state=False):\n o, final_state = AdditiveRecurrenceFunction.apply(q, k, v, g, initial_state, output_final_state)\n return o, final_state\n", - "description_1": "Use triton language to implement an additive block recurrence forward pass kernel. The kernel takes 18 parameters: Q, K, V, G, O, S_INITIAL_STATE, DENOM_INITIAL_STATE, M_INITIAL_STATE, S_FINAL_STATE, DENOM_FINAL_STATE, M_FINAL_STATE, and 7 constexpr parameters (b, h, n, d, e, BLOCK_D, BLOCK_E, NUM_BLOCK_D, NUM_BLOCK_E, USE_INITIAL_STATE, OUTPUT_FINAL_STATE). It computes the forward pass of an additive block recurrence operation, optionally using initial states and outputting final states.", - "description_2": "Use triton language to implement a forward pass for additive block recurrence with optional initial and final states, using 18 parameters including tensors and constexpr values.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef _additive_recurrence_fwd(\n Q, K, V, G, O, S_INITIAL_STATE, DENOM_INITIAL_STATE, M_INITIAL_STATE,\n S_FINAL_STATE, DENOM_FINAL_STATE, M_FINAL_STATE, b: tl.constexpr,\n h: tl.constexpr, n: tl.constexpr, d: tl.constexpr, e: tl.constexpr,\n BLOCK_D: tl.constexpr, BLOCK_E: tl.constexpr, NUM_BLOCK_D: tl.constexpr,\n NUM_BLOCK_E: tl.constexpr, USE_INITIAL_STATE: tl.constexpr,\n OUTPUT_FINAL_STATE: tl.constexpr,\n):\n off_bh = tl.program_id(2)\n off_d, off_e = tl.program_id(0), tl.program_id(1)\n off_qkg = off_bh * n * d\n off_v = off_bh * n * e\n off_o = (off_d * b * h + off_bh) * n * e\n off_d = off_d * BLOCK_D\n off_e = off_e * BLOCK_E\n off_s = off_bh * d * e\n off_denom_m = off_bh * d\n mask_denom_m = (off_d + tl.arange(0, BLOCK_D) < d)[:, None]\n\n q_trans_block_ptr = tl.make_block_ptr(\n base=Q + off_qkg, shape=(d, n), strides=(1, d),\n offsets=(off_d, 0), block_shape=(BLOCK_D, 1), order=(0, 1),\n )\n k_trans_block_ptr = tl.make_block_ptr(\n base=K + off_qkg, shape=(d, n), strides=(1, d),\n offsets=(off_d, 0), block_shape=(BLOCK_D, 1), order=(0, 1),\n )\n v_block_ptr = tl.make_block_ptr(\n base=V + off_v, shape=(n, e), strides=(e, 1),\n offsets=(0, off_e), block_shape=(1, BLOCK_E), order=(1, 0),\n )\n g_trans_block_ptr = tl.make_block_ptr(\n base=G + off_qkg, shape=(d, n), strides=(1, d),\n offsets=(off_d, 0), block_shape=(BLOCK_D, 1), order=(0, 1),\n )\n o_block_ptr = tl.make_block_ptr(\n base=O + off_o, shape=(n, e), strides=(e, 1),\n offsets=(0, off_e), block_shape=(1, BLOCK_E), order=(1, 0),\n )\n\n if USE_INITIAL_STATE:\n s_block_ptr = tl.make_block_ptr(\n base=S_INITIAL_STATE + off_s, shape=(d, e), strides=(e, 1),\n offsets=(off_d, off_e), block_shape=(BLOCK_D, BLOCK_E), order=(1, 0),\n )\n denom_block_ptr = (\n DENOM_INITIAL_STATE + off_denom_m + off_d + tl.arange(0, BLOCK_D)[:, None]\n )\n m_block_ptr = (\n M_INITIAL_STATE + off_denom_m + off_d + tl.arange(0, BLOCK_D)[:, None]\n )\n s = tl.load(s_block_ptr, boundary_check=(0, 1)).to(tl.float32)\n denom = tl.load(denom_block_ptr, mask=mask_denom_m).to(tl.float32)\n m = tl.load(m_block_ptr, mask=mask_denom_m).to(tl.float32)\n else:\n s = tl.zeros([BLOCK_D, BLOCK_E], dtype=tl.float32)\n denom = tl.zeros([BLOCK_D, 1], dtype=tl.float32)\n m = tl.zeros([BLOCK_D, 1], dtype=tl.float32) + (-1e5)\n\n for i in range(n):\n q_trans = tl.load(q_trans_block_ptr, boundary_check=(0, 1)).to(tl.float32)\n k_trans = tl.load(k_trans_block_ptr, boundary_check=(0, 1)).to(tl.float32)\n v = tl.load(v_block_ptr, boundary_check=(0, 1)).to(tl.float32)\n g_trans = tl.load(g_trans_block_ptr, boundary_check=(0, 1)).to(tl.float32)\n\n m_ = tl.maximum(m, g_trans)\n g_trans = g_trans - m_\n lambda_ = tl.exp(m - m_)\n g_exp_trans = tl.exp(g_trans)\n k_bar_trans = g_exp_trans * k_trans\n s = lambda_ * s + k_bar_trans.to(v.dtype) * v\n denom = lambda_ * denom + g_exp_trans\n o = (q_trans / denom) * s\n o = tl.sum(o, axis=0)[None, :]\n\n m = m_\n\n tl.store(o_block_ptr, o.to(o_block_ptr.dtype.element_ty), boundary_check=(0, 1))\n q_trans_block_ptr = tl.advance(q_trans_block_ptr, (0, 1))\n k_trans_block_ptr = tl.advance(k_trans_block_ptr, (0, 1))\n v_block_ptr = tl.advance(v_block_ptr, (1, 0))\n g_trans_block_ptr = tl.advance(g_trans_block_ptr, (0, 1))\n o_block_ptr = tl.advance(o_block_ptr, (1, 0))\n\n if OUTPUT_FINAL_STATE:\n s_final_block_ptr = tl.make_block_ptr(\n base=S_FINAL_STATE + off_s, shape=(d, e), strides=(e, 1),\n offsets=(off_d, off_e), block_shape=(BLOCK_D, BLOCK_E), order=(1, 0),\n )\n denom_final_block_ptr = (\n DENOM_FINAL_STATE + off_denom_m + off_d + tl.arange(0, BLOCK_D)[:, None]\n )\n m_final_block_ptr = (\n M_FINAL_STATE + off_denom_m + off_d + tl.arange(0, BLOCK_D)[:, None]\n )\n tl.store(\n s_final_block_ptr, s.to(s_final_block_ptr.dtype.element_ty),\n boundary_check=(0, 1),\n )\n tl.store(\n denom_final_block_ptr, denom.to(denom_final_block_ptr.dtype.element_ty),\n mask=mask_denom_m,\n )\n tl.store(\n m_final_block_ptr, m.to(m_final_block_ptr.dtype.element_ty),\n mask=mask_denom_m,\n )\n\n\nclass AdditiveRecurrenceFunction(torch.autograd.Function):\n @staticmethod\n def forward(ctx, q, k, v, g, initial_state=None, output_final_state=None):\n b, h, n, d = q.shape\n e = v.shape[-1]\n\n head_dim_d = max_power_of_2_divisor(d)\n head_dim_e = max_power_of_2_divisor(e)\n\n BLOCK_D, BLOCK_E = min(d, head_dim_d), min(e, head_dim_e)\n NUM_BLOCK_D, NUM_BLOCK_E = triton.cdiv(d, BLOCK_D), triton.cdiv(e, BLOCK_E)\n o = torch.empty(\n (NUM_BLOCK_D, b, h, n, e), dtype=q.dtype, device=torch.cuda.current_device()\n )\n\n if initial_state is not None:\n s_initial_state, denom_initial_state, m_initial_state = initial_state\n else:\n s_initial_state = None\n denom_initial_state = None\n m_initial_state = None\n\n if output_final_state:\n s_final_state = torch.empty(\n (b, h, d, e), dtype=torch.float32, device=torch.cuda.current_device()\n )\n denom_final_state = torch.empty(\n (b, h, d, 1), dtype=torch.float32, device=torch.cuda.current_device()\n )\n m_final_state = torch.empty(\n (b, h, d, 1), dtype=torch.float32, device=torch.cuda.current_device()\n )\n else:\n s_final_state = None\n denom_final_state = None\n m_final_state = None\n\n USE_INITIAL_STATE = initial_state is not None\n OUTPUT_FINAL_STATE = output_final_state\n\n grid = (NUM_BLOCK_D, NUM_BLOCK_E, b * h)\n\n _additive_recurrence_fwd[grid](\n q, k, v, g, o,\n s_initial_state, denom_initial_state, m_initial_state,\n s_final_state, denom_final_state, m_final_state,\n b, h, n, d, e,\n BLOCK_D, BLOCK_E, NUM_BLOCK_D, NUM_BLOCK_E,\n USE_INITIAL_STATE, OUTPUT_FINAL_STATE,\n )\n\n if OUTPUT_FINAL_STATE:\n final_state = (s_final_state, denom_final_state, m_final_state)\n else:\n final_state = None\n\n o = o.sum(0)\n ctx.save_for_backward(q, k, v, g)\n return o, final_state\n\n\ndef additive_rule_recurrence_triton(q, k, v, g, initial_state=None, output_final_state=False):\n o, final_state = AdditiveRecurrenceFunction.apply(q, k, v, g, initial_state, output_final_state)\n return o, final_state\n", - "description_1": "Use triton language to implement an additive recurrence forward pass kernel. This kernel takes inputs Q, K, V, G, along with several state tensors and configuration constants, and computes an output tensor O. If specified, it uses and updates initial and final state tensors S, DENOM, and M. The kernel processes data in blocks determined by BLOCK_D and BLOCK_E and iterates over the feature dimension n.", - "description_2": "Use triton language to implement an autograd function in PyTorch that utilizes the additive recurrence kernel. This function orchestrates the input and output state handling and manages device memory allocation for intermediate results.", - "difficulty": 2 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef _additive_recurrence_fwd(\n Q,\n K,\n V,\n G,\n O,\n S0,\n b: tl.constexpr,\n h: tl.constexpr,\n n: tl.constexpr,\n d: tl.constexpr,\n e: tl.constexpr,\n BLOCK_D: tl.constexpr,\n BLOCK_E: tl.constexpr,\n NUM_BLOCK_D: tl.constexpr,\n NUM_BLOCK_E: tl.constexpr,\n):\n off_bh = tl.program_id(2)\n off_bh % h\n off_bh // h\n off_d, off_e = tl.program_id(0), tl.program_id(1)\n # compute offset\n off_qkg = off_bh * n * d\n off_v = off_bh * n * e\n off_o = (off_d * b * h + off_bh) * n * e\n off_d = off_d * BLOCK_D\n off_e = off_e * BLOCK_E\n\n # get block ptr\n q_trans_block_ptr = tl.make_block_ptr(\n base=Q + off_qkg,\n shape=(d, n),\n strides=(1, d),\n offsets=(\n off_d,\n 0,\n ),\n block_shape=(\n BLOCK_D,\n 1,\n ),\n order=(0, 1),\n )\n k_trans_block_ptr = tl.make_block_ptr(\n base=K + off_qkg,\n shape=(d, n),\n strides=(1, d),\n offsets=(off_d, 0),\n block_shape=(BLOCK_D, 1),\n order=(0, 1),\n )\n v_block_ptr = tl.make_block_ptr(\n base=V + off_v,\n shape=(n, e),\n strides=(e, 1),\n offsets=(0, off_e),\n block_shape=(1, BLOCK_E),\n order=(1, 0),\n )\n g_trans_block_ptr = tl.make_block_ptr(\n base=G + off_qkg,\n shape=(d, n),\n strides=(1, d),\n offsets=(\n off_d,\n 0,\n ),\n block_shape=(BLOCK_D, 1),\n order=(0, 1),\n )\n o_block_ptr = tl.make_block_ptr(\n base=O + off_o,\n shape=(n, e),\n strides=(e, 1),\n offsets=(0, off_e),\n block_shape=(1, BLOCK_E),\n order=(1, 0),\n )\n\n s = tl.zeros([BLOCK_D, BLOCK_E], dtype=tl.float32)\n denom = tl.zeros([BLOCK_D, 1], dtype=tl.float32)\n\n for i in range(n):\n # boundary check on feature dim\n q_trans = tl.load(q_trans_block_ptr, boundary_check=(0, 1)).to(tl.float32)\n k_trans = tl.load(k_trans_block_ptr, boundary_check=(0, 1)).to(tl.float32)\n v = tl.load(v_block_ptr, boundary_check=(0, 1)).to(tl.float32)\n g_trans = tl.load(g_trans_block_ptr, boundary_check=(0, 1)).to(tl.float32)\n g_exp_trans = tl.exp(g_trans)\n\n k_bar_trans = g_exp_trans * k_trans\n # d 1, 1 e -> d e\n s += k_bar_trans.to(v.dtype) * v\n denom += g_exp_trans\n # d 1, d e -> d e\n o = (q_trans / denom) * (s)\n # d e -> 1 e\n o = tl.sum(o, axis=0)[None, :]\n\n tl.store(o_block_ptr, o.to(o_block_ptr.dtype.element_ty), boundary_check=(0, 1))\n\n q_trans_block_ptr = tl.advance(q_trans_block_ptr, (0, 1))\n k_trans_block_ptr = tl.advance(k_trans_block_ptr, (0, 1))\n v_block_ptr = tl.advance(v_block_ptr, (1, 0))\n g_trans_block_ptr = tl.advance(g_trans_block_ptr, (0, 1))\n o_block_ptr = tl.advance(o_block_ptr, (1, 0))\n\n\nclass AdditiveRecurrenceFunction(torch.autograd.Function):\n @staticmethod\n def forward(ctx, q, k, v, g, s=None):\n b, h, n, d = q.shape\n e = v.shape[-1]\n\n # split over head dim to avoid shared memory not enough\n head_dim = max_power_of_2_divisor(d, e)\n BLOCK_D, BLOCK_E = min(d, head_dim), min(e, head_dim)\n NUM_BLOCK_D, NUM_BLOCK_E = triton.cdiv(d, BLOCK_D), triton.cdiv(e, BLOCK_E)\n o = torch.empty(\n (NUM_BLOCK_D, b, h, n, e), dtype=q.dtype, device=torch.cuda.current_device()\n )\n\n grid = (\n NUM_BLOCK_D,\n NUM_BLOCK_E,\n b * h,\n )\n _additive_recurrence_fwd[grid](\n q,\n k,\n v,\n g,\n o,\n s,\n b,\n h,\n n,\n d,\n e,\n BLOCK_D,\n BLOCK_E,\n NUM_BLOCK_D,\n NUM_BLOCK_E,\n )\n\n o = o.sum(0)\n\n ctx.save_for_backward(q, k, v, g, s)\n\n return o\n\n\ndef additive_rule_recurrence_triton(q, k, v, g, s=None, output_final_state=False):\n o = AdditiveRecurrenceFunction.apply(q, k, v, g, s)\n return o\n", - "description_1": "Use triton language to implement _additive_recurrence_fwd, a kernel for computing a forward pass of an additive recurrence relation. The kernel takes 14 tensor inputs/parameters: Q (query), K (key), V (value), G (additional gradient info), O (output), S0 (initial state), and 8 constexpr integers (b, h, n, d, e, BLOCK_D, BLOCK_E, NUM_BLOCK_D, NUM_BLOCK_E) defining the dimensions and block sizes of the computation. The kernel calculates offsets and advances through blocks of data, computing contributions to an output tensor by iterating over the feature dimension.", - "description_2": "Use triton language to define a function _additive_recurrence_fwd which processes matrices Q, K, V, G to produce an output O, handling blocks of data with specific dimensions and iterating over a feature dimension to update the result based on specific transformations and accumulations.", - "difficulty": 4 - }, - { - "code": "import triton\nimport triton.language as tl\n\n# Kernel function performing additive recurrence forward operation.\n@triton.jit\ndef _additive_recurrence_fwd(\n Q, K, V, O, S,\n b: tl.constexpr, h: tl.constexpr, n: tl.constexpr,\n d: tl.constexpr, e: tl.constexpr,\n BLOCK_D: tl.constexpr, BLOCK_E: tl.constexpr,\n NUM_BLOCK_D: tl.constexpr, NUM_BLOCK_E: tl.constexpr,\n):\n off_bh = tl.program_id(2)\n off_bh % h\n off_bh // h\n off_d, off_e = tl.program_id(0), tl.program_id(1)\n # compute offset\n off_qk = off_bh * n * d\n off_v = off_bh * n * e\n off_o = (off_d * b * h + off_bh) * n * e\n off_d = off_d * BLOCK_D\n off_e = off_e * BLOCK_E\n\n # get block ptr\n q_block_ptr = Q + off_qk + off_d + tl.arange(0, BLOCK_D)\n k_block_ptr = K + off_qk + off_d + tl.arange(0, BLOCK_D)\n v_block_ptr = V + off_v + off_e + tl.arange(0, BLOCK_E)\n o_block_ptr = O + off_o + off_e + tl.arange(0, BLOCK_E)\n\n mask_d = (off_d + tl.arange(0, BLOCK_D)) < d\n mask_e = (off_e + tl.arange(0, BLOCK_E)) < e\n\n s = tl.zeros([BLOCK_D, BLOCK_E], dtype=tl.float32)\n\n for i in range(n):\n # boundary check on feature dim\n q = tl.load(q_block_ptr, mask=mask_d, other=0).to(tl.float32)\n k = tl.load(k_block_ptr, mask=mask_d, other=0).to(tl.float32)\n v = tl.load(v_block_ptr, mask=mask_e, other=0).to(tl.float32)\n\n # d 1, 1 e -> d e\n tl.static_print(\"aaa\", k[None, :], v[:, None])\n s += k[:, None] * v[None, :]\n # d 1, d e -> d e\n tl.static_print(\"aaa\", q[:, None], s)\n # d e -> e\n o = q[:, None] * s\n o = tl.sum(o, axis=0)\n tl.static_print(\"bbb\", o, o_block_ptr)\n tl.store(o_block_ptr, o.to(o_block_ptr.dtype.element_ty), mask=mask_e)\n\n q_block_ptr += BLOCK_D\n k_block_ptr += BLOCK_D\n v_block_ptr += BLOCK_E\n o_block_ptr += BLOCK_E\n", - "description_1": "Use triton language to implement a forward pass kernel for additive recurrence. The kernel takes 13 parameters: 5 pointers (Q, K, V, O, S) and 8 constants (b, h, n, d, e, BLOCK_D, BLOCK_E, NUM_BLOCK_D, NUM_BLOCK_E). The pointers represent input and output matrices, while the constants define dimensions and block sizes. The kernel performs operations on blocks of input data to compute an output using matrix multiplication and summation, with boundary checking.", - "description_2": "Use triton language to create a kernel for matrix operation using additive recurrence on input data blocks.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef _additive_recurrence_fwd(\n Q,\n K,\n V,\n O,\n S,\n b: tl.constexpr,\n h: tl.constexpr,\n n: tl.constexpr,\n d: tl.constexpr,\n e: tl.constexpr,\n BLOCK_D: tl.constexpr,\n BLOCK_E: tl.constexpr,\n NUM_BLOCK_D: tl.constexpr,\n NUM_BLOCK_E: tl.constexpr,\n):\n off_bh = tl.program_id(2)\n off_bh % h\n off_bh // h\n off_d, off_e = tl.program_id(0), tl.program_id(1)\n # compute offset\n off_qk = off_bh * n * d\n off_v = off_bh * n * e\n off_o = (off_d * b * h + off_bh) * n * e\n off_d = off_d * BLOCK_D\n off_e = off_e * BLOCK_E\n\n # get block ptr\n q_trans_block_ptr = tl.make_block_ptr(\n base=Q + off_qk,\n shape=(d, n),\n strides=(1, d),\n offsets=(\n off_d,\n 0,\n ),\n block_shape=(\n BLOCK_D,\n 1,\n ),\n order=(0, 1),\n )\n k_trans_block_ptr = tl.make_block_ptr(\n base=K + off_qk,\n shape=(d, n),\n strides=(1, d),\n offsets=(off_d, 0),\n block_shape=(BLOCK_D, 1),\n order=(0, 1),\n )\n v_block_ptr = tl.make_block_ptr(\n base=V + off_v,\n shape=(n, e),\n strides=(e, 1),\n offsets=(0, off_e),\n block_shape=(1, BLOCK_E),\n order=(1, 0),\n )\n o_block_ptr = tl.make_block_ptr(\n base=O + off_o,\n shape=(n, e),\n strides=(e, 1),\n offsets=(0, off_e),\n block_shape=(1, BLOCK_E),\n order=(1, 0),\n )\n\n s = tl.zeros([BLOCK_D, BLOCK_E], dtype=tl.float32)\n denom = tl.zeros([BLOCK_D, 1], dtype=tl.float32)\n\n for i in range(n):\n # boundary check on feature dim\n q_trans = tl.load(q_trans_block_ptr, boundary_check=(0)).to(tl.float32)\n k_trans = tl.load(k_trans_block_ptr, boundary_check=(0)).to(tl.float32)\n v = tl.load(v_block_ptr, boundary_check=(1)).to(tl.float32)\n\n # d 1, 1 e -> d e\n s += k_trans.to(v.dtype) * v\n # d 1, d e -> d e\n o = q_trans * s\n # d e -> 1 e\n o = tl.sum(o, axis=0)[None, :]\n tl.store(o_block_ptr, o.to(o_block_ptr.dtype.element_ty), boundary_check=(1))\n\n q_trans_block_ptr = tl.advance(q_trans_block_ptr, (0, 1))\n k_trans_block_ptr = tl.advance(k_trans_block_ptr, (0, 1))\n v_block_ptr = tl.advance(v_block_ptr, (1, 0))\n o_block_ptr = tl.advance(o_block_ptr, (1, 0))\n\n\nclass BaseRecurrenceFunction(torch.autograd.Function):\n @staticmethod\n def forward(ctx, q, k, v, s=None):\n b, h, n, d = q.shape\n e = v.shape[-1]\n\n # split over head dim to avoid shared memory not enough\n BLOCK_D, BLOCK_E = min(d, HEAD_DIM), min(e, HEAD_DIM)\n NUM_BLOCK_D, NUM_BLOCK_E = triton.cdiv(d, BLOCK_D), triton.cdiv(e, BLOCK_E)\n\n o = torch.empty(\n (NUM_BLOCK_D, b, h, n, e), dtype=q.dtype, device=torch.cuda.current_device()\n )\n\n grid = (\n NUM_BLOCK_D,\n NUM_BLOCK_E,\n b * h,\n )\n\n _additive_recurrence_fwd[grid](\n q,\n k,\n v,\n o,\n s,\n b,\n h,\n n,\n d,\n e,\n BLOCK_D,\n BLOCK_E,\n NUM_BLOCK_D,\n NUM_BLOCK_E,\n )\n\n o = o.sum(0)\n\n ctx.save_for_backward(q, k, v, s)\n\n return o\n\n\ndef base_rule_recurrence_triton(\n q,\n k,\n v,\n s=None,\n):\n o = BaseRecurrenceFunction.apply(q, k, v, s)\n return o\n", - "description_1": "Use triton language to implement a kernel `_additive_recurrence_fwd` which computes an additive recurrence given input tensors Q, K, V, O, and S, along with constants b, h, n, d, e, BLOCK_D, BLOCK_E, NUM_BLOCK_D, and NUM_BLOCK_E. The function `base_rule_recurrence_triton` applies this kernel using PyTorch Autograd framework.", - "description_2": "Use triton to perform additive recurrence with given tensors and parameters within a PyTorch Autograd Function.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef _flao_non_causal_kv_triton(\n K,\n V,\n KV,\n b: tl.constexpr,\n h: tl.constexpr,\n n: tl.constexpr,\n m: tl.constexpr,\n d: tl.constexpr,\n e: tl.constexpr,\n BLOCK_D: tl.constexpr,\n NUM_BLOCK_D: tl.constexpr,\n BLOCK_NM: tl.constexpr,\n BLOCK_E: tl.constexpr,\n):\n off_bh = tl.program_id(0)\n off_block_d = tl.program_id(1)\n off_block_e = tl.program_id(2)\n # compute offset\n offset_d = off_block_d * BLOCK_D\n offset_e = off_block_e * BLOCK_E\n off_bh * n * d + offset_d\n offset_k = off_bh * m * d + offset_d\n offset_v = off_bh * m * e + offset_e\n off_bh * n * e + offset_e\n off_block_d * b * h * n * e + off_bh * n * e + offset_e\n offset_kv = off_bh * d * e + offset_d * e + offset_e\n # mask\n d_mask = (offset_d + tl.arange(0, BLOCK_D)) < d\n e_mask = (offset_e + tl.arange(0, BLOCK_E)) < e\n\n # compute kv\n k_trans_block_ptr = (\n K\n + offset_k\n + tl.arange(0, BLOCK_NM)[None, :] * d\n + tl.arange(0, BLOCK_D)[:, None]\n )\n v_block_ptr = (\n V\n + offset_v\n + tl.arange(0, BLOCK_NM)[:, None] * e\n + tl.arange(0, BLOCK_E)[None, :]\n )\n kv_block_ptr = (\n KV\n + offset_kv\n + tl.arange(0, BLOCK_D)[:, None] * e\n + tl.arange(0, BLOCK_E)[None, :]\n )\n array = tl.arange(0, BLOCK_NM)\n NUM_BLOCK_M = tl.cdiv(m, BLOCK_NM)\n\n kv = tl.zeros([BLOCK_D, BLOCK_E], dtype=tl.float32)\n for i in range(0, NUM_BLOCK_M):\n mask = array < m\n k_trans = tl.load(\n k_trans_block_ptr, mask=mask[None, :] & d_mask[:, None], other=0\n ).to(tl.float32)\n v = tl.load(v_block_ptr, mask=mask[:, None] & e_mask[None, :], other=0).to(\n tl.float32\n )\n kv += tl.dot(k_trans, v)\n\n k_trans_block_ptr += BLOCK_NM * d\n v_block_ptr += BLOCK_NM * e\n array += BLOCK_NM\n\n tl.store(\n kv_block_ptr,\n kv.to(kv_block_ptr.dtype.element_ty),\n mask=d_mask[:, None] & e_mask[None, :],\n )\n\n\n@triton.jit\ndef _flao_non_causal_fwd_triton(\n Q,\n G,\n KV,\n O,\n b: tl.constexpr,\n h: tl.constexpr,\n n: tl.constexpr,\n m: tl.constexpr,\n d: tl.constexpr,\n e: tl.constexpr,\n BLOCK_D: tl.constexpr,\n NUM_BLOCK_D: tl.constexpr,\n BLOCK_NM: tl.constexpr,\n BLOCK_E: tl.constexpr,\n):\n NUM_BLOCK_N = tl.cdiv(n, BLOCK_NM)\n off_bhn = tl.program_id(0)\n off_bh = off_bhn // NUM_BLOCK_N\n off_n = off_bhn % NUM_BLOCK_N\n off_block_d = tl.program_id(1)\n off_block_e = tl.program_id(2)\n # compute offset\n offset_d = off_block_d * BLOCK_D\n offset_e = off_block_e * BLOCK_E\n offset_n = off_n * BLOCK_NM\n offset_q = off_bh * n * d + offset_n * d + offset_d\n offset_g = off_bh * n * e + offset_n * e + offset_e\n offset_o = off_block_d * b * h * n * e + off_bh * n * e + offset_n * e + offset_e\n offset_kv = off_bh * d * e + offset_d * e + offset_e\n # mask\n d_mask = (offset_d + tl.arange(0, BLOCK_D)) < d\n e_mask = (offset_e + tl.arange(0, BLOCK_E)) < e\n\n array = tl.arange(0, BLOCK_NM)\n\n # compute qkv\n q_block_ptr = (\n Q\n + offset_q\n + tl.arange(0, BLOCK_NM)[:, None] * d\n + tl.arange(0, BLOCK_D)[None, :]\n )\n g_block_ptr = (\n G\n + offset_g\n + tl.arange(0, BLOCK_NM)[:, None] * e\n + tl.arange(0, BLOCK_E)[None, :]\n )\n kv_block_ptr = (\n KV\n + offset_kv\n + tl.arange(0, BLOCK_D)[:, None] * e\n + tl.arange(0, BLOCK_E)[None, :]\n )\n o_block_ptr = (\n O\n + offset_o\n + tl.arange(0, BLOCK_NM)[:, None] * e\n + tl.arange(0, BLOCK_E)[None, :]\n )\n\n array = offset_n + tl.arange(0, BLOCK_NM)\n NUM_BLOCK_N = tl.cdiv(n, BLOCK_NM)\n\n mask = (array < n)[:, None]\n q = tl.load(q_block_ptr, mask=mask & d_mask[None, :], other=0).to(tl.float32)\n kv = tl.load(kv_block_ptr, mask=d_mask[:, None] & e_mask[None, :], other=0).to(\n tl.float32\n )\n g = tl.load(g_block_ptr, mask=mask & e_mask[None, :], other=0).to(tl.float32)\n\n qkv = tl.dot(q, kv)\n o = g * qkv\n\n tl.store(\n o_block_ptr, o.to(o_block_ptr.dtype.element_ty), mask=mask & e_mask[None, :]\n )\n\n\nclass FusedLinearAttentionOutputGateTriton(torch.autograd.Function):\n @staticmethod\n def forward(ctx, q, k, v, g):\n b, h, n, d = q.shape\n m = k.shape[-2]\n e = v.shape[-1]\n\n block_d = min(128, triton.next_power_of_2(d))\n num_block_d = triton.cdiv(d, block_d)\n kv = torch.empty(b, h, d, e, dtype=torch.float32, device=q.device)\n\n def grid(meta):\n return (b * h, num_block_d, triton.cdiv(e, meta[\"BLOCK_E\"]))\n\n # compute kv first\n _flao_non_causal_kv_triton[grid](\n k,\n v,\n kv,\n b,\n h,\n n,\n m,\n d,\n e,\n block_d,\n num_block_d,\n )\n\n o = torch.empty(num_block_d, b, h, n, e, dtype=q.dtype, device=q.device)\n\n def grid(meta):\n return (\n b * h * triton.cdiv(n, meta[\"BLOCK_NM\"]),\n num_block_d,\n triton.cdiv(e, meta[\"BLOCK_E\"]),\n )\n\n _flao_non_causal_fwd_triton[grid](\n q,\n g,\n kv,\n o,\n b,\n h,\n n,\n m,\n d,\n e,\n block_d,\n num_block_d,\n )\n\n o = o.sum(dim=0)\n\n ctx.save_for_backward(q, k, v, g, kv)\n\n return o\n\n @staticmethod\n def backward(ctx, do):\n q, k, v, g, kv = ctx.saved_tensors\n\n qkv = torch.matmul(q, kv.to(q.dtype))\n\n dg = do * qkv\n dqkv = do * g\n dq = torch.einsum(\"... n e, ... d e -> ... n d\", dqkv, kv.to(q.dtype))\n dkv = torch.einsum(\"... n d, ... n e -> ... d e\", q, dqkv)\n dk = torch.einsum(\"... n e, ... d e -> ... n d\", v, dkv)\n dv = torch.einsum(\"... n d, ... d e -> ... n e\", k, dkv)\n\n return dq, dk, dv, dg\n\n\ndef flao_non_causal_triton(q, k, v, g):\n return FusedLinearAttentionOutputGateTriton.apply(q, k, v, g)\n", - "description_1": "Use triton language to implement two kernels: _flao_non_causal_kv_triton and _flao_non_causal_fwd_triton. The first kernel computes the product of K and V matrices and stores the result in KV. It takes 13 parameters: K, V, KV, b, h, n, m, d, e, BLOCK_D, NUM_BLOCK_D, BLOCK_NM, BLOCK_E. The second kernel computes the product of Q and KV matrices, applies a gate G, and stores the result in O. It takes 13 parameters: Q, G, KV, O, b, h, n, m, d, e, BLOCK_D, NUM_BLOCK_D, BLOCK_NM, BLOCK_E.", - "description_2": "Use triton language to implement a fused linear attention mechanism with output gating. The process involves two main steps: first, compute the key-value product using _flao_non_causal_kv_triton; second, compute the query-key-value product, apply gating, and store the result using _flao_non_causal_fwd_triton.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef _grpe_recurrence_fwd(\n Q,\n K,\n V,\n M,\n O,\n S_INITIAL_STATE,\n S_FINAL_STATE,\n b: tl.constexpr,\n h: tl.constexpr,\n n: tl.constexpr,\n d: tl.constexpr,\n e: tl.constexpr,\n USE_INITIAL_STATE: tl.constexpr,\n OUTPUT_FINAL_STATE: tl.constexpr,\n BLOCK_E: tl.constexpr,\n):\n \"\"\"\n q: (1, d)\n k: (1, d)\n v: (1, BLOCK_E)\n m: (d, d)\n s: (d, BLOCK_E)\n \"\"\"\n off_bh = tl.program_id(0)\n off_e = tl.program_id(1)\n # compute offset\n off_qk = off_bh * n * d\n off_ov = off_bh * n * e\n off_m = off_bh * n * d * d\n off_e = off_e * BLOCK_E\n off_s = off_bh * d * e\n\n # compute block ptr\n q_trans_block_ptr = Q + off_qk + tl.arange(0, d)[:, None]\n k_trans_block_ptr = K + off_qk + tl.arange(0, d)[:, None]\n v_block_ptr = V + off_ov + off_e + tl.arange(0, BLOCK_E)[None, :]\n m_block_ptr = M + off_m + tl.arange(0, d)[:, None] * d + tl.arange(0, d)[None, :]\n o_block_ptr = O + off_ov + off_e + tl.arange(0, BLOCK_E)[None, :]\n\n mask = (off_e + tl.arange(0, BLOCK_E)[None, :]) < e\n\n if USE_INITIAL_STATE:\n s_block_ptr = (\n S_INITIAL_STATE\n + off_s\n + tl.arange(0, d)[:, None] * e\n + off_e\n + tl.arange(0, BLOCK_E)[None, :]\n )\n\n s = tl.load(s_block_ptr, mask=mask, other=0).to(tl.float32)\n else:\n s = tl.zeros([d, BLOCK_E], dtype=tl.float32)\n\n for i in range(n):\n q_trans = tl.load(q_trans_block_ptr).to(tl.float32)\n k_trans = tl.load(k_trans_block_ptr).to(tl.float32)\n v = tl.load(v_block_ptr, mask=mask, other=0).to(tl.float32)\n m = tl.load(m_block_ptr).to(tl.float32)\n\n s = tl.dot(m, s) + k_trans.to(v.dtype) * v\n o = q_trans * s\n # d e -> 1 e\n o = tl.sum(o, axis=0)[None, :]\n\n tl.store(o_block_ptr, o.to(o_block_ptr.dtype.element_ty), mask=mask)\n\n q_trans_block_ptr += d\n k_trans_block_ptr += d\n v_block_ptr += e\n m_block_ptr += d * d\n o_block_ptr += e\n\n if OUTPUT_FINAL_STATE:\n s_final_block_ptr = (\n S_FINAL_STATE\n + off_s\n + tl.arange(0, d)[:, None] * e\n + off_e\n + tl.arange(0, BLOCK_E)[None, :]\n )\n\n tl.store(\n s_final_block_ptr,\n s.to(s_final_block_ptr.dtype.element_ty),\n mask=mask,\n )\n\n\n@triton.jit\ndef _grpe_recurrence_bwd(\n Q,\n K,\n V,\n M,\n DO,\n DQ,\n DK,\n DV,\n DM,\n DS,\n S_INITIAL_STATE,\n S_FINAL_STATE,\n b: tl.constexpr,\n h: tl.constexpr,\n n: tl.constexpr,\n d: tl.constexpr,\n e: tl.constexpr,\n USE_INITIAL_STATE: tl.constexpr,\n OUTPUT_FINAL_STATE: tl.constexpr,\n):\n \"\"\"\n q: (d, 1)\n k: (d, 1)\n v: (1, e)\n do: (1, e)\n dq: (d, 1)\n dk: (d, 1)\n dv: (1, e)\n m: (d, d)\n s: (d, e)\n \"\"\"\n off_bh = tl.program_id(0)\n # compute offset\n off_qk = off_bh * n * d\n off_ov = off_bh * n * e\n off_m = off_bh * n * d * d\n off_s = off_bh * d * e\n\n # compute block ptr\n # fwd\n q_trans_block_ptr = Q + off_qk + tl.arange(0, d)[:, None]\n k_trans_block_ptr = K + off_qk + tl.arange(0, d)[:, None]\n v_block_ptr = V + off_ov + tl.arange(0, e)[None, :]\n m_block_ptr = M + off_m + tl.arange(0, d)[:, None] * d + tl.arange(0, d)[None, :]\n # o_block_ptr = O + off_ov + tl.arange(0, e)[None, :]\n\n # bwd\n do_block_ptr = DO + off_ov + tl.arange(0, e)[None, :]\n dq_trans_block_ptr = DQ + off_qk + tl.arange(0, d)[:, None]\n dk_trans_block_ptr = DK + off_qk + tl.arange(0, d)[:, None]\n dv_block_ptr = DV + off_ov + tl.arange(0, e)[None, :]\n\n if USE_INITIAL_STATE:\n s_block_ptr = (\n S_INITIAL_STATE\n + off_s\n + tl.arange(0, d)[:, None] * e\n + tl.arange(0, e)[None, :]\n )\n\n s = tl.load(s_block_ptr).to(tl.float32)\n else:\n s = tl.zeros([d, e], dtype=tl.float32)\n\n for i in range(n):\n # q_trans = tl.load(q_trans_block_ptr).to(tl.float32)\n do = tl.load(do_block_ptr).to(tl.float32)\n k_trans = tl.load(k_trans_block_ptr).to(tl.float32)\n v = tl.load(v_block_ptr).to(tl.float32)\n m = tl.load(m_block_ptr).to(tl.float32)\n\n s = tl.dot(m, s) + k_trans.to(v.dtype) * v\n # o = q_trans * s\n dq_trans = s * do\n # d e -> d 1\n dq_trans = tl.sum(dq_trans, axis=1)[:, None]\n\n tl.store(dq_trans_block_ptr, dq_trans.to(dq_trans_block_ptr.dtype.element_ty))\n\n # q_trans_block_ptr += d\n k_trans_block_ptr += d\n v_block_ptr += e\n m_block_ptr += d * d\n do_block_ptr += e\n dq_trans_block_ptr += d\n\n ds = tl.zeros([d, e], dtype=tl.float32)\n do_block_ptr = DO + off_ov + n * e + tl.arange(0, e)[None, :]\n q_trans_block_ptr = Q + off_qk + n * d + tl.arange(0, d)[:, None]\n k_trans_block_ptr = K + off_qk + n * d + tl.arange(0, d)[:, None]\n v_block_ptr = V + off_ov + n * e + tl.arange(0, e)[None, :]\n m_trans_block_ptr = (\n M + off_m + n * d * d + tl.arange(0, d)[:, None] * d + tl.arange(0, d)[None, :]\n )\n\n dk_trans_block_ptr = DK + off_qk + n * d + tl.arange(0, d)[:, None]\n dv_block_ptr = DV + off_ov + n * e + tl.arange(0, e)[None, :]\n\n for i in range(n - 1, -1, -1):\n do_block_ptr -= e\n dq_trans_block_ptr -= d\n k_trans_block_ptr -= d\n v_block_ptr -= e\n m_block_ptr -= d * d\n\n dk_trans_block_ptr -= d\n dv_block_ptr -= e\n\n q_trans = tl.load(q_trans_block_ptr).to(tl.float32)\n do = tl.load(do_block_ptr).to(tl.float32)\n k_trans = tl.load(k_trans_block_ptr).to(tl.float32)\n v = tl.load(v_block_ptr).to(tl.float32)\n tl.load(m_trans_block_ptr).to(tl.float32)\n\n ds = tl.dot(m, ds) + q_trans.to(v.dtype) * do\n # o = q_trans * s\n dk_trans = ds * v\n # d e -> d 1\n dk_trans = tl.sum(dk_trans, axis=1)[:, None]\n\n dv = ds * k_trans\n dv = tl.sum(dv, axis=0)[None, :]\n\n tl.store(dk_trans_block_ptr, dk_trans.to(dk_trans_block_ptr.dtype.element_ty))\n tl.store(dv_block_ptr, dv.to(dv_block_ptr.dtype.element_ty))\n\n\nclass GrpeRecurrenceFunction(torch.autograd.Function):\n @staticmethod\n def forward(\n ctx, q, k, v, alpha, beta, gamma, initial_state=None, output_final_state=None\n ):\n b, h, n, d = q.shape\n e = v.shape[-1]\n\n # m = exp(alpha + beta * gamma * gamma ^ T)\n identity = torch.eye(d, device=torch.cuda.current_device())\n order_one_term = alpha.unsqueeze(-1) * identity\n order_two_term = (\n beta.unsqueeze(-1).unsqueeze(-1) * gamma.unsqueeze(-1) * gamma.unsqueeze(-2)\n )\n log_m = order_one_term + order_two_term\n m = torch.matrix_exp(log_m)\n\n o = torch.empty((b, h, n, e), dtype=q.dtype, device=torch.cuda.current_device())\n\n if initial_state is not None:\n s_initial_state = initial_state\n else:\n s_initial_state = None\n\n if output_final_state:\n s_final_state = torch.empty(\n (b, h, d, e), dtype=torch.float32, device=torch.cuda.current_device()\n )\n else:\n s_final_state = None\n\n USE_INITIAL_STATE = initial_state is not None\n OUTPUT_FINAL_STATE = output_final_state\n\n def grid(meta):\n return (b * h, triton.cdiv(e, meta[\"BLOCK_E\"]))\n\n _grpe_recurrence_fwd[grid](\n q,\n k,\n v,\n m,\n o,\n s_initial_state,\n s_final_state,\n b,\n h,\n n,\n d,\n e,\n USE_INITIAL_STATE,\n OUTPUT_FINAL_STATE,\n )\n\n if OUTPUT_FINAL_STATE:\n final_state = s_final_state\n else:\n final_state = None\n\n ctx.save_for_backward(q, k, v, alpha, beta, gamma, initial_state)\n ctx.output_final_state = output_final_state\n\n return o, final_state\n\n @staticmethod\n def backward(ctx, do, ds):\n q, k, v, alpha, beta, gamma, initial_state = ctx.saved_tensors\n output_final_state = ctx.output_final_state\n\n b, h, n, d = q.shape\n e = v.shape[-1]\n\n # m = exp(alpha + beta * gamma * gamma ^ T)\n identity = torch.eye(d, device=torch.cuda.current_device())\n order_one_term = alpha.unsqueeze(-1) * identity\n order_two_term = (\n beta.unsqueeze(-1).unsqueeze(-1) * gamma.unsqueeze(-1) * gamma.unsqueeze(-2)\n )\n log_m = order_one_term + order_two_term\n m = torch.matrix_exp(log_m)\n\n dq = torch.empty_like(q, dtype=q.dtype, device=torch.cuda.current_device())\n dk = torch.empty_like(k, dtype=q.dtype, device=torch.cuda.current_device())\n dv = torch.empty_like(v, dtype=q.dtype, device=torch.cuda.current_device())\n ds_ = (\n torch.empty((b, h, d, e), dtype=q.dtype, device=torch.cuda.current_device())\n if initial_state is not None\n else None\n )\n dalpha = torch.empty_like(\n alpha, dtype=q.dtype, device=torch.cuda.current_device()\n )\n dbeta = torch.empty_like(\n beta, dtype=q.dtype, device=torch.cuda.current_device()\n )\n dgamma = torch.empty_like(\n gamma, dtype=q.dtype, device=torch.cuda.current_device()\n )\n dm = torch.empty_like(m, dtype=q.dtype, device=torch.cuda.current_device())\n\n if initial_state is not None:\n s_initial_state = initial_state\n else:\n s_initial_state = None\n\n if output_final_state:\n s_final_state = torch.empty(\n (b, h, d, e), dtype=torch.float32, device=torch.cuda.current_device()\n )\n else:\n s_final_state = None\n\n USE_INITIAL_STATE = initial_state is not None\n OUTPUT_FINAL_STATE = output_final_state\n\n grid = (b * h,)\n\n _grpe_recurrence_bwd[grid](\n q,\n k,\n v,\n m,\n do,\n dq,\n dk,\n dv,\n dm,\n ds_,\n s_initial_state,\n s_final_state,\n b,\n h,\n n,\n d,\n e,\n USE_INITIAL_STATE,\n OUTPUT_FINAL_STATE,\n )\n\n return dq, dk, dv, dalpha, dbeta, dgamma, ds_, None\n\n\ndef grpe_recurrence_triton(\n q, k, v, alpha, beta, gamma, initial_state=None, output_final_state=False\n):\n o, final_state = GrpeRecurrenceFunction.apply(\n q, k, v, alpha, beta, gamma, initial_state, output_final_state\n )\n return o, final_state\n", - "description_1": "Use triton language to implement a forward and backward pass for a GRPE recurrence operation. The forward kernel '_grpe_recurrence_fwd' takes 15 parameters: Q, K, V, M, O, S_INITIAL_STATE, S_FINAL_STATE, and 8 constexpr parameters (b, h, n, d, e, USE_INITIAL_STATE, OUTPUT_FINAL_STATE, BLOCK_E). It computes the output O and optionally updates the final state S_FINAL_STATE. The backward kernel '_grpe_recurrence_bwd' takes 15 parameters: Q, K, V, M, DO, DQ, DK, DV, DM, DS, S_INITIAL_STATE, S_FINAL_STATE, and 6 constexpr parameters (b, h, n, d, e, USE_INITIAL_STATE, OUTPUT_FINAL_STATE). It computes the gradients DQ, DK, DV, DM, and optionally updates the final state S_FINAL_STATE.", - "description_2": "Use triton language to create a GRPE recurrence function with forward and backward passes, handling initial and final states, and computing necessary gradients.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.autotune(\n generate_configs({\"BLOCK_D\": [16, 32, 64, 128], \"num_warps\": [2, 4, 8]}),\n key=[\"n\", \"d\"],\n)\n@triton.jit\ndef _logcumsumexp_block_parallel_compute(\n X,\n O,\n M,\n b: tl.constexpr,\n n: tl.constexpr,\n d: tl.constexpr,\n BLOCK_N: tl.constexpr,\n BLOCK_D: tl.constexpr,\n):\n off_b = tl.program_id(0)\n off_n = tl.program_id(1)\n off_d = tl.program_id(2)\n # compute offset\n off = off_b * n * d + off_n * BLOCK_N * d + off_d * BLOCK_D\n off_m = off_b * tl.cdiv(n, BLOCK_N) * d + off_n * d + off_d * BLOCK_D\n\n m = tl.full([BLOCK_D], float(\"-inf\"), dtype=tl.float32)\n o = tl.full([BLOCK_D], float(\"-inf\"), dtype=tl.float32)\n x_block_ptr = (\n X + off + tl.arange(0, BLOCK_N)[:, None] * d + tl.arange(0, BLOCK_D)[None, :]\n )\n o_block_ptr = (\n O + off + tl.arange(0, BLOCK_N)[:, None] * d + tl.arange(0, BLOCK_D)[None, :]\n )\n m_block_ptr = M + off_m + tl.arange(0, BLOCK_D)\n\n # get accumulation matrix, using this to compute cumsum\n # | 1 0 0 | | x1 | | x1 |\n # | 1 1 0 | | x2 | = | x1 + x2 | = cumsum({x1, x2, x3})\n # | 1 1 1 | | x3 | | x1 + x2 + x3 |\n index = tl.arange(0, BLOCK_N)\n acc_matrix = tl.where(index[:, None] >= index[None, :], 1.0, 0.0)\n feature_mask = off_d * BLOCK_D + tl.arange(0, BLOCK_D) < d\n\n mask = (off_n * BLOCK_N + tl.arange(0, BLOCK_N) < n)[:, None] and feature_mask[\n None, :\n ]\n\n # !!!!! important, we don't know which value for padding, this may cause bug in the future\n x = tl.load(\n x_block_ptr,\n mask=mask,\n ).to(tl.float32)\n\n # get the max value in the block\n m = tl.max(x, axis=0)\n\n # compute cumsum(exp(x - m)) using matrix production\n x_exp_stable = tl.exp(x - m)\n x_cumsum_exp = tl.dot(acc_matrix, x_exp_stable)\n\n o = tl.log(x_cumsum_exp)\n\n tl.store(o_block_ptr, o.to(o_block_ptr.dtype.element_ty), mask=mask)\n tl.store(m_block_ptr, m.to(o_block_ptr.dtype.element_ty), mask=feature_mask)\n\n\n@triton.autotune(\n generate_configs({\"BLOCK_D\": [16, 32, 64, 128], \"num_warps\": [2, 4, 8]}),\n key=[\"n\", \"d\"],\n)\n@triton.jit\ndef _logcumsumexp_block_parallel_reduce(\n X,\n O,\n M,\n b: tl.constexpr,\n n: tl.constexpr,\n d: tl.constexpr,\n BLOCK_N: tl.constexpr,\n BLOCK_D: tl.constexpr,\n):\n off_b = tl.program_id(0)\n off_d = tl.program_id(1)\n # compute offset\n off = off_b * n * d + off_d * BLOCK_D\n off_m = off_b * tl.cdiv(n, BLOCK_N) * d + off_d * BLOCK_D\n\n m = tl.full([BLOCK_D], float(\"-inf\"), dtype=tl.float32)\n o = tl.full([BLOCK_D], float(\"-inf\"), dtype=tl.float32)\n\n o_block_ptr = (\n O + off + tl.arange(0, BLOCK_N)[:, None] * d + tl.arange(0, BLOCK_D)[None, :]\n )\n m_block_ptr = M + off_m + tl.arange(0, BLOCK_D)\n\n feature_mask = off_d * BLOCK_D + tl.arange(0, BLOCK_D) < d\n\n for i in range(tl.cdiv(n, BLOCK_N)):\n mask = (i * BLOCK_N + tl.arange(0, BLOCK_N) < n)[:, None] and feature_mask[\n None, :\n ]\n\n o_stage1 = tl.load(o_block_ptr, mask=mask).to(tl.float32)\n m_stage1 = tl.load(m_block_ptr, mask=feature_mask).to(tl.float32)\n\n # get the max value in the block\n # update cummax\n m_ = tl.maximum(m, m_stage1)\n\n o_ = tl.log(tl.exp(o + m - m_) + tl.exp(o_stage1 + m_stage1 - m_))\n m = m_\n # we whant the get o_[-1], however, triton doesn't support this,\n # since o_ is monotonically increasing on sequence dim,\n # we can use the max to get this\n o = tl.max(o_, 0)\n o_res = o_ + m\n\n tl.store(o_block_ptr, o_res.to(o_block_ptr.dtype.element_ty), mask=mask)\n\n o_block_ptr += BLOCK_N * d\n m_block_ptr += d\n\n\nclass LogCumSumExpBlockParallel(torch.autograd.Function):\n @staticmethod\n def forward(ctx, x, dim=-2):\n if dim >= 0:\n dim -= len(x.shape)\n\n if dim != -2:\n x = x.transpose(-2, dim).contiguous()\n\n b, n, d = x.shape\n o = torch.empty_like(x)\n BLOCK_N = 128\n m = torch.empty(\n b,\n triton.cdiv(n, BLOCK_N),\n d,\n dtype=x.dtype,\n device=torch.cuda.current_device(),\n )\n\n # parallel over batch, sequence and feature\n def grid(meta):\n return (b, triton.cdiv(n, BLOCK_N), triton.cdiv(d, meta[\"BLOCK_D\"]))\n\n _logcumsumexp_block_parallel_compute[grid](x, o, m, b, n, d, BLOCK_N)\n\n # reduce\n def grid(meta):\n return (b, triton.cdiv(d, meta[\"BLOCK_D\"]))\n\n _logcumsumexp_block_parallel_reduce[grid](x, o, m, b, n, d, BLOCK_N)\n\n if dim != -2:\n o = o.transpose(-2, dim).contiguous()\n\n return o\n\n @staticmethod\n def backward(ctx, do):\n return None\n\n\ndef logcumsumexp_block_parallel_triton(x, dim=-2):\n return LogCumSumExpBlockParallel.apply(x, dim)\n", - "description_1": "Use triton language to implement two kernels: _logcumsumexp_block_parallel_compute and _logcumsumexp_block_parallel_reduce. The first kernel computes the cumulative sum of exponentials in a block-wise parallel manner, taking 8 parameters: X (input tensor), O (output tensor), M (max tensor), b (batch size), n (sequence length), d (feature dimension), BLOCK_N (block size for sequence), and BLOCK_D (block size for feature). The second kernel reduces the results across blocks, taking the same parameters. A PyTorch autograd function LogCumSumExpBlockParallel is used to apply these kernels, with a forward method that prepares the input and output tensors and calls the kernels, and a backward method that returns None.", - "description_2": "Use triton language to implement block-wise parallel computation and reduction of cumulative sum of exponentials for a given input tensor, utilizing two kernels and a PyTorch autograd function.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.autotune(\n generate_configs(\n {\"BLOCK_N\": [32, 64, 128], \"BLOCK_D\": [16, 32, 64, 128], \"num_warps\": [2, 4, 8]}\n ),\n key=[\"n\", \"d\"],\n)\n@triton.jit\ndef _logcumsumexp_block_recurrence_fwd(\n X,\n O,\n b: tl.constexpr,\n n: tl.constexpr,\n d: tl.constexpr,\n BLOCK_N: tl.constexpr,\n BLOCK_D: tl.constexpr,\n):\n off_b = tl.program_id(0)\n off_d = tl.program_id(1)\n # compute offset\n off = off_b * n * d + off_d * BLOCK_D\n\n m = tl.full([BLOCK_D], float(\"-inf\"), dtype=tl.float32)\n o = tl.full([BLOCK_D], float(\"-inf\"), dtype=tl.float32)\n x_block_ptr = (\n X + off + tl.arange(0, BLOCK_N)[:, None] * d + tl.arange(0, BLOCK_D)[None, :]\n )\n o_block_ptr = (\n O + off + tl.arange(0, BLOCK_N)[:, None] * d + tl.arange(0, BLOCK_D)[None, :]\n )\n\n # get accumulation matrix, using this to compute cumsum\n # | 1 0 0 | | x1 | | x1 |\n # | 1 1 0 | | x2 | = | x1 + x2 | = cumsum({x1, x2, x3})\n # | 1 1 1 | | x3 | | x1 + x2 + x3 |\n index = tl.arange(0, BLOCK_N)\n acc_matrix = tl.where(index[:, None] >= index[None, :], 1.0, 0.0)\n feature_mask = (off_d * BLOCK_D + tl.arange(0, BLOCK_D) < d)[None, :]\n\n for i in range(tl.cdiv(n, BLOCK_N)):\n mask = (i * BLOCK_N + tl.arange(0, BLOCK_N) < n)[:, None] and feature_mask\n\n x = tl.load(x_block_ptr, mask=mask).to(tl.float32)\n\n # get the max value in the block\n m_ = tl.max(x, axis=0)\n # update cummax\n m_ = tl.maximum(m, m_)\n\n # compute cumsum(exp(x - m_)) using matrix production\n x_exp_stable = tl.exp(x - m_)\n x_cumsum_exp = tl.dot(acc_matrix, x_exp_stable)\n\n o_ = tl.log(tl.exp(o + m - m_) + x_cumsum_exp)\n m = m_\n # we whant the get o_[-1], however, triton doesn't support this,\n # since o_ is monotonically increasing on sequence dim,\n # we can use the max to get this\n o = tl.max(o_, 0)\n o_res = o_ + m\n\n tl.store(o_block_ptr, o_res.to(o_block_ptr.dtype.element_ty), mask=mask)\n\n x_block_ptr += BLOCK_N * d\n o_block_ptr += BLOCK_N * d\n\n\n@triton.autotune(\n generate_configs({\"BLOCK_N\": [32], \"BLOCK_D\": [16], \"num_warps\": [2]}),\n key=[\"n\", \"d\"],\n)\n@triton.jit\ndef _logcumsumexp_block_recurrence_bwd(\n X,\n O,\n DX,\n DO,\n b: tl.constexpr,\n n: tl.constexpr,\n d: tl.constexpr,\n BLOCK_N: tl.constexpr,\n BLOCK_D: tl.constexpr,\n):\n \"\"\"\n ______\n 0 | |\n 1 | |\n ------\n 2 | |\n ——————\n 3 | ||\n 4 | |\n 5 | |\n ——————\n 6 | |\n 7 | |\n ------\n 8 | |\n ——————\n Assume the sequence length is 8, block size is 3, there are 3 blocks, and the index is 1, which is at block 0, we assume the index start from 0.\n The algorithm is as follows:\n 1. Compute the 2th block (mask the position whose index >= 8 with 0)\n 2. Compute the 1th block\n 3. Compute the 0th block (mask the position whose index < 1 with 0)\n \"\"\"\n off_b = tl.program_id(0)\n off_n = tl.program_id(1)\n off_d = tl.program_id(2)\n # compute offset\n off_x = off_b * n * d + off_n * d + off_d * BLOCK_D\n # start from the last block\n num_block = tl.cdiv(n, BLOCK_N)\n block_idx = off_n // BLOCK_N\n off_o = off_b * n * d + (num_block - 1) * BLOCK_N * d + off_d * BLOCK_D\n\n x_block_ptr = X + off_x + tl.arange(0, BLOCK_D)\n o_block_ptr = (\n O + off_o + tl.arange(0, BLOCK_N)[:, None] * d + tl.arange(0, BLOCK_D)[None, :]\n )\n dx_block_ptr = DX + off_x + tl.arange(0, BLOCK_D)\n do_block_ptr = (\n DO + off_o + tl.arange(0, BLOCK_N)[:, None] * d + tl.arange(0, BLOCK_D)[None, :]\n )\n\n # get rev accumulation matrix, using this to compute revcumsum\n # | 1 1 1 | | x1 | | x3 + x2 + x1 |\n # | 0 1 1 | | x2 | = | x3 + x2 | = revcumsum({x1, x2, x3})\n # | 0 0 1 | | x3 | | x3 |\n index = tl.arange(0, BLOCK_N)\n acc_matrix = tl.where(index[:, None] <= index[None, :], 1.0, 0.0)\n feature_mask = off_d * BLOCK_D + tl.arange(0, BLOCK_D) < d\n # feature_mask = tl.arange(0, BLOCK_D) < BLOCK_D\n\n # sequence mask\n # sequence_mask_front = ((block_idx * BLOCK_N + tl.arange(0, BLOCK_N)) >= off_n)[:, None]\n sequence_mask_front = ((block_idx * BLOCK_N + tl.arange(0, BLOCK_N)) >= off_n)[\n :, None\n ] and ((block_idx * BLOCK_N + tl.arange(0, BLOCK_N)) < n)[:, None]\n array = (num_block - 1) * BLOCK_N + tl.arange(0, BLOCK_N)\n\n # use this mask to get first row of a matrix\n index_mask = (tl.arange(0, BLOCK_N) == 0)[:, None]\n\n x = tl.load(x_block_ptr, mask=feature_mask, other=0).to(tl.float32)\n dx = tl.zeros([BLOCK_D], dtype=tl.float32)\n # loop from last block to the first block\n # tl.device_print(\"aaa\", num_block - block_idx)\n\n # pdb.set_trace()\n # print(acc_matrix)\n m = num_block - block_idx\n for j in range(m):\n sequence_mask_end = (array < n)[:, None]\n # tl.static_print(\"aaa\", feature_mask[None, :], sequence_mask_front, sequence_mask_end)\n # if j == m - 1:\n # mask = feature_mask[None, :] and sequence_mask_front\n # else:\n # mask = feature_mask[None, :] and sequence_mask_end\n # mask = feature_mask[None, :] and sequence_mask_end\n mask = (feature_mask[None, :] and sequence_mask_front) and sequence_mask_end\n\n # tl.device_print(\"aaa\", mask)\n\n o = tl.load(o_block_ptr, mask=mask, other=0).to(tl.float32)\n do = tl.load(do_block_ptr, mask=mask, other=0).to(tl.float32)\n\n tmp = do * tl.exp(x - o)\n dx_arr = dx + tl.dot(acc_matrix, tmp)\n\n # we use this to get the first row of dx_arr,\n # since triton doesn't support index operation\n dx = tl.sum(tl.where(index_mask, dx_arr, 0), axis=0)\n\n array -= BLOCK_N\n o_block_ptr -= BLOCK_N * d\n do_block_ptr -= BLOCK_N * d\n\n tl.store(dx_block_ptr, dx.to(dx_block_ptr.dtype.element_ty), mask=feature_mask)\n\n\nclass LogCumSumExpBlockRecurrence(torch.autograd.Function):\n @staticmethod\n def forward(ctx, x, dim=-2):\n x.dtype\n if dim >= 0:\n dim -= len(x.shape)\n\n if dim != -2:\n x = x.transpose(-2, dim).contiguous()\n\n x, ps, is_list = pack(x, \"* n d\")\n b, n, d = x.shape\n o = torch.empty_like(x)\n\n # parallel over batch and feature\n def grid(meta):\n return (b, triton.cdiv(d, meta[\"BLOCK_D\"]))\n\n _logcumsumexp_block_recurrence_fwd[grid](x, o, b, n, d)\n\n ctx.save_for_backward(x, o)\n ctx.dim = dim\n\n o = unpack(o, ps, \"* n d\", is_list)\n if dim != -2:\n o = o.transpose(-2, dim).contiguous()\n\n return o\n\n @staticmethod\n def backward(ctx, do):\n x, o = ctx.saved_tensors\n dim = ctx.dim\n b, n, d = x.shape\n # print(x.shape)\n\n dx = torch.empty_like(x)\n\n if dim != -2:\n do = do.transpose(-2, dim).contiguous()\n\n do, ps, is_list = pack(do, \"* n d\")\n\n # parallel over batch, sequence and feature\n def grid(meta):\n return (b, n, triton.cdiv(d, meta[\"BLOCK_D\"]))\n\n _logcumsumexp_block_recurrence_bwd[grid](x, o, dx, do, b, n, d)\n\n dx = unpack(dx, ps, \"* n d\", is_list)\n if dim != -2:\n dx = dx.transpose(-2, dim).contiguous()\n\n return dx, None\n\n\ndef logcumsumexp_block_recurrence_triton(x, dim=-2):\n return LogCumSumExpBlockRecurrence.apply(x, dim)\n", - "description_1": "Use triton language to implement two kernels: _logcumsumexp_block_recurrence_fwd and _logcumsumexp_block_recurrence_bwd. The forward kernel computes the log cumulative sum of exponentials for a given input tensor X, storing the result in tensor O. It takes 7 parameters: X (input tensor), O (output tensor), b (batch size), n (sequence length), d (feature dimension), BLOCK_N (block size for sequence), and BLOCK_D (block size for feature). The backward kernel computes the gradient of the input tensor X with respect to the output tensor O, storing the result in tensor DX. It takes 8 parameters: X (input tensor), O (output tensor), DX (gradient of input), DO (gradient of output), b (batch size), n (sequence length), d (feature dimension), BLOCK_N (block size for sequence), and BLOCK_D (block size for feature).", - "description_2": "Use triton language to implement a forward kernel for log cumulative sum of exponentials and a backward kernel for computing gradients, both operating on block sizes for sequence and feature dimensions.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef _logcumsumexp_recurrence_fwd(\n X,\n O,\n b: tl.constexpr,\n n: tl.constexpr,\n d: tl.constexpr,\n BLOCK: tl.constexpr,\n):\n off_b = tl.program_id(0)\n off_d = tl.program_id(1)\n # compute offset\n off = off_b * n * d + off_d * BLOCK\n\n m = tl.full([BLOCK], float(\"-inf\"), dtype=tl.float32)\n o = tl.full([BLOCK], float(\"-inf\"), dtype=tl.float32)\n x_block_ptr = X + off + tl.arange(0, BLOCK)\n o_block_ptr = O + off + tl.arange(0, BLOCK)\n mask = off_d * BLOCK + tl.arange(0, BLOCK) < d\n\n for i in range(n):\n x = tl.load(x_block_ptr, mask=mask).to(tl.float32)\n m_ = tl.maximum(x, m)\n\n o = tl.log(tl.exp(o + m - m_) + tl.exp(x - m_))\n m = m_\n o_res = o + m\n\n tl.store(o_block_ptr, o_res.to(o_block_ptr.dtype.element_ty), mask=mask)\n\n x_block_ptr += d\n o_block_ptr += d\n\n@triton.jit\ndef _logcumsumexp_recurrence_bwd(\n X,\n O,\n DX,\n DO,\n b: tl.constexpr,\n n: tl.constexpr,\n d: tl.constexpr,\n BLOCK: tl.constexpr,\n):\n off_b = tl.program_id(0)\n off_n = tl.program_id(1)\n off_d = tl.program_id(2)\n # compute offset\n off = off_b * n * d + off_n * d + off_d * BLOCK\n\n x_block_ptr = X + off + tl.arange(0, BLOCK)\n o_block_ptr = O + off + tl.arange(0, BLOCK)\n dx_block_ptr = DX + off + tl.arange(0, BLOCK)\n do_block_ptr = DO + off + tl.arange(0, BLOCK)\n mask = off_d * BLOCK + tl.arange(0, BLOCK) < d\n\n x = tl.load(x_block_ptr, mask=mask).to(tl.float32)\n dx = tl.zeros([BLOCK], dtype=tl.float32)\n for j in range(off_n, n):\n o = tl.load(o_block_ptr, mask=mask).to(tl.float32)\n do = tl.load(do_block_ptr, mask=mask).to(tl.float32)\n\n dx += do * tl.exp(x - o)\n\n o_block_ptr += d\n do_block_ptr += d\n\n tl.store(dx_block_ptr, dx.to(dx_block_ptr.dtype.element_ty), mask=mask)\n\nclass LogCumSumExpRecurrence(torch.autograd.Function):\n @staticmethod\n def forward(ctx, x, dim=-2):\n if dim >= 0:\n dim -= len(x.shape)\n\n if dim != -2:\n x = x.transpose(-2, dim).contiguous()\n\n x, ps, is_list = pack(x, \"* n d\")\n b, n, d = x.shape\n o = torch.empty_like(x)\n\n # parallel over batch and feature\n def grid(meta):\n return (b, triton.cdiv(d, meta[\"BLOCK\"]))\n\n _logcumsumexp_recurrence_fwd[grid](x, o, b, n, d)\n\n ctx.save_for_backward(x, o)\n ctx.dim = dim\n\n o = unpack(o, ps, \"* n d\", is_list)\n if dim != -2:\n o = o.transpose(-2, dim).contiguous()\n\n return o\n\n @staticmethod\n def backward(ctx, do):\n x, o = ctx.saved_tensors\n dim = ctx.dim\n b, n, d = x.shape\n\n dx = torch.empty_like(x)\n\n if dim != -2:\n do = do.transpose(-2, dim).contiguous()\n\n do, ps, is_list = pack(do, \"* n d\")\n\n # parallel over batch, sequence and feature\n def grid(meta):\n return (b, n, triton.cdiv(d, meta[\"BLOCK\"]))\n\n _logcumsumexp_recurrence_bwd[grid](x, o, dx, do, b, n, d)\n\n dx = unpack(dx, ps, \"* n d\", is_list)\n if dim != -2:\n dx = dx.transpose(-2, dim).contiguous()\n\n return dx, None\n\ndef logcumsumexp_recurrence_triton(x, dim=-2):\n return LogCumSumExpRecurrence.apply(x, dim)\n", - "description_1": "Use triton language to implement a forward and backward pass of a log cumulative sum exponential operation. The forward kernel '_logcumsumexp_recurrence_fwd' takes 5 parameters: X (input tensor), O (output tensor), b (batch size), n (sequence length), d (feature dimension), and BLOCK (block size). It computes the log cumulative sum exponential over the specified dimension. The backward kernel '_logcumsumexp_recurrence_bwd' takes 7 parameters: X (input tensor), O (output tensor from forward pass), DX (gradient of input), DO (gradient of output), b (batch size), n (sequence length), d (feature dimension), and BLOCK (block size). It computes the gradient of the input tensor based on the gradient of the output tensor.", - "description_2": "Use triton language to implement a log cumulative sum exponential operation with forward and backward passes, handling input and output tensors, and computing gradients.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.autotune(\n generate_configs(\n {\n \"BLOCK_N\": [16, 32, 64, 128],\n \"BLOCK_D\": [16, 32, 64, 128],\n \"num_warps\": [2, 4, 8],\n }\n ),\n key=[\"n\", \"d\"],\n)\n@triton.jit\ndef _lrpe_cosine_1d_bp_fwd_triton(\n X,\n Theta,\n O,\n X_STAT1,\n X_STAT2,\n offset: tl.constexpr,\n b: tl.constexpr,\n h: tl.constexpr,\n n: tl.constexpr,\n d: tl.constexpr,\n ACT: tl.constexpr,\n BLOCK_N: tl.constexpr,\n BLOCK_D: tl.constexpr,\n):\n off_b = tl.program_id(0)\n off_h = tl.program_id(1)\n off_d = tl.program_id(2)\n # compute offset\n offset_d = off_d * BLOCK_D\n offset_x = off_b * h * n * d + off_h * n * d + offset_d\n offset_theta = off_h * d + offset_d\n offset_o = off_b * h * n * 2 * d + off_h * n * 2 * d + offset_d\n # compute block ptr\n x_block_ptr = (\n X\n + offset_x\n + tl.arange(0, BLOCK_N)[:, None] * d\n + tl.arange(0, BLOCK_D)[None, :]\n )\n # mask\n d_mask = (offset_d + tl.arange(0, BLOCK_D)) < d\n\n if ACT == \"softmax\":\n value = -float(\"inf\")\n else:\n value = 0\n\n # get stat\n if ACT != \"none\":\n if ACT == \"softmax\":\n x_max = tl.full([BLOCK_D], value, dtype=tl.float32)\n denominator = tl.full([BLOCK_D], 0, dtype=tl.float32)\n for i in range(tl.cdiv(n, BLOCK_N)):\n n_mask = (i * BLOCK_N + tl.arange(0, BLOCK_N)) < n\n x = tl.load(\n x_block_ptr, mask=n_mask[:, None] & d_mask[None, :], other=value\n )\n\n x_block_max = tl.max(x, axis=0)\n x_max_ = tl.where(x_block_max > x_max, x_block_max, x_max)\n # sum(exp(xi - a)) + exp(x - a) = exp(b - a) * sum(exp(xi - b)) + exp(x - b)\n x_exp = tl.exp(x - x_max_)\n lambda_ = tl.exp(x_max - x_max_)\n denominator = lambda_ * denominator + tl.sum(x_exp, axis=0)\n x_max = x_max_\n\n x_block_ptr += BLOCK_N * d\n\n # save\n x_stat1_block_ptr = (\n X_STAT1 + off_b * h * d + off_h * d + offset_d + tl.arange(0, BLOCK_D)\n )\n x_stat2_block_ptr = (\n X_STAT2 + off_b * h * d + off_h * d + offset_d + tl.arange(0, BLOCK_D)\n )\n\n tl.store(\n x_stat1_block_ptr,\n x_max.to(x_stat1_block_ptr.dtype.element_ty),\n mask=d_mask,\n )\n tl.store(\n x_stat2_block_ptr,\n denominator.to(x_stat2_block_ptr.dtype.element_ty),\n mask=d_mask,\n )\n\n # compute block ptr\n theta_block_ptr = Theta + offset_theta + tl.arange(0, BLOCK_D)[None, :]\n x_block_ptr = (\n X\n + offset_x\n + tl.arange(0, BLOCK_N)[:, None] * d\n + tl.arange(0, BLOCK_D)[None, :]\n )\n o_cos_block_ptr = (\n O\n + offset_o\n + tl.arange(0, BLOCK_N)[:, None] * 2 * d\n + tl.arange(0, BLOCK_D)[None, :]\n )\n o_sin_block_ptr = (\n O\n + offset_o\n + d\n + tl.arange(0, BLOCK_N)[:, None] * 2 * d\n + tl.arange(0, BLOCK_D)[None, :]\n )\n array = tl.arange(0, BLOCK_N)\n theta_ = tl.load(theta_block_ptr, mask=d_mask[None, :], other=0).to(tl.float32)\n\n for i in range(tl.cdiv(n, BLOCK_N)):\n n_mask = array < n\n mask = n_mask[:, None] & d_mask[None, :]\n x = tl.load(x_block_ptr, mask=mask, other=0).to(tl.float32)\n\n if ACT != \"none\":\n if ACT == \"relu\":\n x = tl.where(x >= 0, x, 0)\n elif ACT == \"sigmoid\":\n x = tl.sigmoid(x)\n elif ACT == \"silu\":\n x = x * tl.sigmoid(x)\n elif ACT == \"softmax\":\n # for stable\n x_minus_max = x - x_max\n # softmax\n numerator = tl.exp(x_minus_max)\n x = numerator / denominator\n\n theta = theta_ * (array[:, None] + offset)\n o_cos = x * tl.cos(theta)\n o_sin = x * tl.sin(theta)\n\n tl.store(o_cos_block_ptr, o_cos.to(o_cos_block_ptr.dtype.element_ty), mask=mask)\n tl.store(o_sin_block_ptr, o_sin.to(o_cos_block_ptr.dtype.element_ty), mask=mask)\n\n x_block_ptr += BLOCK_N * d\n array += BLOCK_N\n o_cos_block_ptr += BLOCK_N * 2 * d\n o_sin_block_ptr += BLOCK_N * 2 * d\n\n\n@triton.autotune(\n generate_configs(\n {\n \"BLOCK_N\": [16, 32, 64, 128],\n \"BLOCK_D\": [16, 32, 64, 128],\n \"num_warps\": [2, 4, 8],\n }\n ),\n key=[\"n\", \"d\"],\n)\n@triton.jit\ndef _lrpe_cosine_1d_bp_bwd_triton(\n X,\n Theta,\n DO,\n DX,\n X_STAT1,\n X_STAT2,\n offset: tl.constexpr,\n b: tl.constexpr,\n h: tl.constexpr,\n n: tl.constexpr,\n d: tl.constexpr,\n ACT: tl.constexpr,\n BLOCK_N: tl.constexpr,\n BLOCK_D: tl.constexpr,\n):\n off_b = tl.program_id(0)\n off_h = tl.program_id(1)\n off_d = tl.program_id(2)\n # compute offset\n offset_d = off_d * BLOCK_D\n offset_x = off_b * h * n * d + off_h * n * d + offset_d\n offset_theta = off_h * d + offset_d\n offset_o = off_b * h * n * 2 * d + off_h * n * 2 * d + offset_d\n # compute block ptr\n theta_block_ptr = Theta + offset_theta + tl.arange(0, BLOCK_D)[None, :]\n dx_block_ptr = (\n DX\n + offset_x\n + tl.arange(0, BLOCK_N)[:, None] * d\n + tl.arange(0, BLOCK_D)[None, :]\n )\n do_cos_block_ptr = (\n DO\n + offset_o\n + tl.arange(0, BLOCK_N)[:, None] * 2 * d\n + tl.arange(0, BLOCK_D)[None, :]\n )\n do_sin_block_ptr = (\n DO\n + offset_o\n + d\n + tl.arange(0, BLOCK_N)[:, None] * 2 * d\n + tl.arange(0, BLOCK_D)[None, :]\n )\n array = tl.arange(0, BLOCK_N)\n # mask\n d_mask = (offset_d + tl.arange(0, BLOCK_D)) < d\n\n theta_ = tl.load(theta_block_ptr, mask=d_mask[None, :], other=0).to(tl.float32)\n\n if ACT == \"softmax\": # compute c first\n x_block_ptr = (\n X\n + offset_x\n + tl.arange(0, BLOCK_N)[:, None] * d\n + tl.arange(0, BLOCK_D)[None, :]\n )\n x_stat1_block_ptr = (\n X_STAT1 + off_b * h * d + off_h * d + offset_d + tl.arange(0, BLOCK_D)\n )\n x_stat2_block_ptr = (\n X_STAT2 + off_b * h * d + off_h * d + offset_d + tl.arange(0, BLOCK_D)\n )\n x_max = tl.load(x_stat1_block_ptr, mask=d_mask, other=0).to(tl.float32)\n denominator = tl.load(x_stat2_block_ptr, mask=d_mask, other=1).to(tl.float32)\n\n c = tl.zeros([BLOCK_D], dtype=tl.float32)\n\n for i in range(tl.cdiv(n, BLOCK_N)):\n n_mask = array < n\n mask = n_mask[:, None] & d_mask[None, :]\n\n do_cos = tl.load(do_cos_block_ptr, mask=mask, other=0).to(tl.float32)\n do_sin = tl.load(do_sin_block_ptr, mask=mask, other=0).to(tl.float32)\n\n theta = theta_ * (array[:, None] + offset)\n dx = do_cos * tl.cos(theta) + do_sin * tl.sin(theta)\n\n x = tl.load(x_block_ptr, mask=mask, other=0).to(tl.float32)\n # for stable\n x_minus_max = x - x_max\n # softmax\n numerator = tl.exp(x_minus_max)\n o = numerator / denominator\n\n # scalar\n c += tl.sum(o * dx, axis=0)\n\n x_block_ptr += BLOCK_N * d\n array += BLOCK_N\n do_cos_block_ptr += BLOCK_N * 2 * d\n do_sin_block_ptr += BLOCK_N * 2 * d\n\n # reinit\n do_cos_block_ptr = (\n DO\n + offset_o\n + tl.arange(0, BLOCK_N)[:, None] * 2 * d\n + tl.arange(0, BLOCK_D)[None, :]\n )\n do_sin_block_ptr = (\n DO\n + offset_o\n + d\n + tl.arange(0, BLOCK_N)[:, None] * 2 * d\n + tl.arange(0, BLOCK_D)[None, :]\n )\n array = tl.arange(0, BLOCK_N)\n\n for i in range(tl.cdiv(n, BLOCK_N)):\n n_mask = array < n\n mask = n_mask[:, None] & d_mask[None, :]\n\n do_cos = tl.load(do_cos_block_ptr, mask=mask, other=0).to(tl.float32)\n do_sin = tl.load(do_sin_block_ptr, mask=mask, other=0).to(tl.float32)\n\n theta = theta_ * (array[:, None] + offset)\n dx = do_cos * tl.cos(theta) + do_sin * tl.sin(theta)\n\n if ACT != \"none\":\n x_block_ptr = (\n X\n + offset_x\n + i * BLOCK_N * d\n + tl.arange(0, BLOCK_N)[:, None] * d\n + tl.arange(0, BLOCK_D)[None, :]\n )\n x = tl.load(x_block_ptr, mask=mask, other=0).to(tl.float32)\n if ACT == \"relu\":\n dx = tl.where(x >= 0, dx, 0)\n elif ACT == \"sigmoid\":\n sigmoid = tl.sigmoid(x)\n dx = dx * sigmoid * (1 - sigmoid)\n elif ACT == \"silu\":\n sigmoid = tl.sigmoid(x)\n dx = dx * sigmoid * (1 + x * (1 - sigmoid))\n elif ACT == \"softmax\":\n # for stable\n x_minus_max = x - x_max\n # softmax\n numerator = tl.exp(x_minus_max)\n o = numerator / denominator\n # scalar\n dx = o * dx - c * o\n\n tl.store(dx_block_ptr, dx.to(dx_block_ptr.dtype.element_ty), mask=mask)\n\n dx_block_ptr += BLOCK_N * d\n array += BLOCK_N\n do_cos_block_ptr += BLOCK_N * 2 * d\n do_sin_block_ptr += BLOCK_N * 2 * d\n\n\ndef lrpe_cosine_1d_bp_fwd_triton(x, theta, offset=0, act=\"none\", dim=None, **kwargs):\n assert dim in [-2, None], \"dim must in [-2, None]\"\n\n b, h, n, d = x.shape\n o = torch.empty(b, h, n, 2 * d, dtype=x.dtype, device=x.device)\n x_stat1 = torch.empty(b, h, d, dtype=x.dtype, device=x.device)\n x_stat2 = torch.empty(b, h, d, dtype=x.dtype, device=x.device)\n\n def grid(meta):\n return (b, h, triton.cdiv(d, meta[\"BLOCK_D\"]))\n\n _lrpe_cosine_1d_bp_fwd_triton[grid](\n x, theta, o, x_stat1, x_stat2, offset, b, h, n, d, act\n )\n\n return o, x_stat1, x_stat2\n\n\ndef lrpe_cosine_1d_bp_bwd_triton(\n x, theta, do, x_stat1, x_stat2, offset=0, act=\"none\", dim=None, **kwargs\n):\n assert dim in [-2, None], \"dim must in [-2, None]\"\n\n b, h, n, d = x.shape\n dx = torch.empty_like(x)\n\n def grid(meta):\n return (b, h, triton.cdiv(d, meta[\"BLOCK_D\"]))\n\n _lrpe_cosine_1d_bp_bwd_triton[grid](\n x, theta, do, dx, x_stat1, x_stat2, offset, b, h, n, d, act\n )\n\n return dx\n", - "description_1": "Use triton language to implement two kernels: _lrpe_cosine_1d_bp_fwd_triton and _lrpe_cosine_1d_bp_bwd_triton. The forward kernel computes the cosine and sine transformations of input X with parameters Theta, storing results in O, and computes statistics X_STAT1 and X_STAT2 based on the activation function ACT. The backward kernel computes the gradient DX of the input X using the gradients DO of the output O, and the statistics X_STAT1 and X_STAT2. Both kernels use block sizes BLOCK_N and BLOCK_D for parallel processing.", - "description_2": "Use triton language to implement forward and backward kernels for cosine and sine transformations with activation functions, using block sizes for parallel processing.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\nfrom xopes.utils import next_power_of_two\n\n@triton.jit\ndef _lrpe_cosine_1d_sp_fwd_triton(\n X,\n Theta,\n O,\n offset: tl.constexpr,\n b: tl.constexpr,\n h: tl.constexpr,\n n: tl.constexpr,\n d: tl.constexpr,\n ACT: tl.constexpr,\n BLOCK: tl.constexpr,\n):\n off_b = tl.program_id(0)\n off_h = tl.program_id(1)\n off_n = tl.program_id(2)\n # compute offset\n offset_x = off_b * h * n * d + off_h * n * d + off_n * d\n offset_theta = off_h * d\n offset_o = off_b * h * n * 2 * d + off_h * n * 2 * d + off_n * 2 * d\n # mask\n d_mask = tl.arange(0, BLOCK) < d\n\n x_block_ptr = X + offset_x + tl.arange(0, BLOCK)\n theta_block_ptr = Theta + offset_theta + tl.arange(0, BLOCK)\n o_cos_block_ptr = O + offset_o + tl.arange(0, BLOCK)\n o_sin_block_ptr = O + offset_o + d + tl.arange(0, BLOCK)\n\n if ACT == \"softmax\":\n value = -float(\"inf\")\n else:\n value = 0\n\n x = tl.load(x_block_ptr, mask=d_mask, other=value).to(tl.float32)\n if ACT != \"none\":\n if ACT == \"relu\":\n x = tl.where(x >= 0, x, 0)\n elif ACT == \"sigmoid\":\n x = tl.sigmoid(x)\n elif ACT == \"silu\":\n x = x * tl.sigmoid(x)\n elif ACT == \"softmax\":\n # for stable\n x_minus_max = x - tl.max(x, axis=0)\n # softmax\n numerator = tl.exp(x_minus_max)\n denominator = tl.sum(numerator)\n x = numerator / denominator\n\n theta = tl.load(theta_block_ptr, mask=d_mask, other=0).to(tl.float32) * (\n off_n + offset\n )\n o_cos = x * tl.cos(theta)\n o_sin = x * tl.sin(theta)\n\n tl.store(o_cos_block_ptr, o_cos.to(o_cos_block_ptr.dtype.element_ty), mask=d_mask)\n tl.store(o_sin_block_ptr, o_sin.to(o_cos_block_ptr.dtype.element_ty), mask=d_mask)\n\n@triton.jit\ndef _lrpe_cosine_1d_sp_bwd_triton(\n X,\n Theta,\n DO,\n DX,\n offset: tl.constexpr,\n b: tl.constexpr,\n h: tl.constexpr,\n n: tl.constexpr,\n d: tl.constexpr,\n ACT: tl.constexpr,\n BLOCK: tl.constexpr,\n):\n off_b = tl.program_id(0)\n off_h = tl.program_id(1)\n off_n = tl.program_id(2)\n # compute offset\n offset_x = off_b * h * n * d + off_h * n * d + off_n * d\n offset_theta = off_h * d\n offset_o = off_b * h * n * 2 * d + off_h * n * 2 * d + off_n * 2 * d\n # mask\n d_mask = tl.arange(0, BLOCK) < d\n\n theta_block_ptr = Theta + offset_theta + tl.arange(0, BLOCK)\n dx_block_ptr = DX + offset_x + tl.arange(0, BLOCK)\n do_cos_block_ptr = DO + offset_o + tl.arange(0, BLOCK)\n do_sin_block_ptr = DO + offset_o + d + tl.arange(0, BLOCK)\n\n do_cos = tl.load(do_cos_block_ptr, mask=d_mask, other=0).to(tl.float32)\n do_sin = tl.load(do_sin_block_ptr, mask=d_mask, other=0).to(tl.float32)\n\n theta = tl.load(theta_block_ptr, mask=d_mask, other=0).to(tl.float32) * (\n off_n + offset\n )\n dx = do_cos * tl.cos(theta) + do_sin * tl.sin(theta)\n\n if ACT != \"none\":\n if ACT == \"softmax\":\n value = -float(\"inf\")\n else:\n value = 0\n\n x_block_ptr = X + offset_x + tl.arange(0, BLOCK)\n x = tl.load(x_block_ptr, mask=d_mask, other=value).to(tl.float32)\n\n if ACT == \"relu\":\n dx = tl.where(x >= 0, dx, 0)\n elif ACT == \"sigmoid\":\n sigmoid = tl.sigmoid(x)\n dx = dx * sigmoid * (1 - sigmoid)\n elif ACT == \"silu\":\n sigmoid = tl.sigmoid(x)\n dx = dx * sigmoid * (1 + x * (1 - sigmoid))\n elif ACT == \"softmax\":\n # for stable\n x_minus_max = x - tl.max(x, axis=0)\n # softmax\n numerator = tl.exp(x_minus_max)\n denominator = tl.sum(numerator)\n o = numerator / denominator\n\n # scalar\n c = tl.sum(o * dx, axis=0)\n dx = o * dx - c * o\n\n tl.store(dx_block_ptr, dx.to(dx_block_ptr.dtype.element_ty), mask=d_mask)\n\ndef lrpe_cosine_1d_sp_fwd_triton(x, theta, offset=0, act=\"none\", dim=None, **kwargs):\n assert dim in [-1, None], \"dim must in [-1, None]\"\n\n b, h, n, d = x.shape\n o = torch.empty(b, h, n, 2 * d, dtype=x.dtype, device=x.device)\n BLOCK = next_power_of_two(d)\n\n def grid(meta):\n return (b, h, n)\n\n _lrpe_cosine_1d_sp_fwd_triton[grid](x, theta, o, offset, b, h, n, d, act, BLOCK)\n\n return o\n\ndef lrpe_cosine_1d_sp_bwd_triton(\n x, theta, do, offset=0, act=\"none\", dim=None, **kwargs\n):\n assert dim in [-1, None], \"dim must in [-1, None]\"\n\n b, h, n, d = x.shape\n dx = torch.empty_like(x)\n BLOCK = next_power_of_two(d)\n\n def grid(meta):\n return (b, h, n)\n\n _lrpe_cosine_1d_sp_bwd_triton[grid](\n x, theta, do, dx, offset, b, h, n, d, act, BLOCK\n )\n\n return dx\n\nclass LrpeCosine1dSpTriton(torch.autograd.Function):\n @staticmethod\n def forward(ctx, x, theta, offset=0, act=\"none\", dim=None):\n o = lrpe_cosine_1d_sp_fwd_triton(x, theta, offset, act, dim)\n\n ctx.save_for_backward(x, theta)\n ctx.offset = offset\n ctx.act = act\n ctx.dim = dim\n\n return o\n\n @staticmethod\n def backward(ctx, do):\n x, theta = ctx.saved_tensors\n offset = ctx.offset\n act = ctx.act\n dim = ctx.dim\n\n dx = lrpe_cosine_1d_sp_bwd_triton(x, theta, do, offset, act, dim)\n\n return dx, None, None, None, None\n\ndef lrpe_cosine_1d_sp_triton(x, theta, offset=0, act=\"none\", dim=None, **kwargs):\n # x: b, h, n, d\n # theta: h, d\n assert dim in [-1, None], \"dim must in [-1, None]\"\n return LrpeCosine1dSpTriton.apply(x, theta, offset, act, dim)\n", - "description_1": "Use triton language to implement a forward and backward pass of a 1D cosine function with optional activation functions. The forward kernel (_lrpe_cosine_1d_sp_fwd_triton) takes 8 parameters: X (input tensor), Theta (angle tensor), O (output tensor), offset (constant offset), b, h, n, d (dimensions of the input tensor), ACT (activation type), and BLOCK (block size). The backward kernel (_lrpe_cosine_1d_sp_bwd_triton) takes 9 parameters: X, Theta, DO (gradient of output), DX (gradient of input), offset, b, h, n, d, ACT, and BLOCK. The function lrpe_cosine_1d_sp_fwd_triton calls the forward kernel, and lrpe_cosine_1d_sp_bwd_triton calls the backward kernel. The class LrpeCosine1dSpTriton implements the autograd function for PyTorch, using these kernels for forward and backward passes.", - "description_2": "Use triton language to create a 1D cosine function with optional activation for forward and backward passes, integrated with PyTorch autograd.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\nfrom xopes.utils import ACT_SET, next_power_of_two\n\n\n@triton.autotune(\n generate_configs({\"BLOCK_N\": [16, 32, 64, 128], \"num_warps\": [2, 4, 8]}),\n key=[\"h\", \"n\", \"d\", \"m\"],\n)\n@triton.jit\ndef _lrpe_cosine_md_bp_fwd_triton(\n X,\n Theta,\n O,\n Shape,\n ThetaCache,\n X_STAT1,\n X_STAT2,\n b: tl.constexpr,\n h: tl.constexpr,\n n: tl.constexpr,\n l: tl.constexpr,\n d: tl.constexpr,\n e: tl.constexpr,\n m: tl.constexpr,\n ACT: tl.constexpr,\n BLOCK_D: tl.constexpr,\n BLOCK_E: tl.constexpr,\n BLOCK_L: tl.constexpr,\n BLOCK_N: tl.constexpr,\n):\n off_b = tl.program_id(0)\n off_h = tl.program_id(1)\n # compute offset\n offset_x = off_b * h * n * d + off_h * n * d\n offset_theta = off_h * e\n offset_o = off_b * h * n * 2 * d + off_h * n * 2 * d\n offset_d = m * e\n offset_theta_cache = off_h * n * d + l * d\n\n if ACT == \"softmax\":\n value = -float(\"inf\")\n else:\n value = 0\n\n # get stat\n # for softmax act, we should compute max and denominator first\n if ACT == \"softmax\":\n # mask\n d_mask = tl.arange(0, BLOCK_D) < d\n\n x_block_ptr_ = (\n X\n + offset_x\n + tl.arange(0, BLOCK_N)[:, None] * d\n + tl.arange(0, BLOCK_D)[None, :]\n )\n x_max = tl.full([BLOCK_D], value, dtype=tl.float32)\n denominator = tl.full([BLOCK_D], 0, dtype=tl.float32)\n\n for i in range(tl.cdiv(n, BLOCK_N)):\n n_mask = (i * BLOCK_N + tl.arange(0, BLOCK_N)) < n\n x_ = tl.load(\n x_block_ptr_, mask=n_mask[:, None] & d_mask[None, :], other=value\n )\n\n x_block_max = tl.max(x_, axis=0)\n x_max_ = tl.where(x_block_max > x_max, x_block_max, x_max)\n # sum(exp(xi - a)) + exp(x - a) = exp(b - a) * sum(exp(xi - b)) + exp(x - b)\n x_exp = tl.exp(x_ - x_max_)\n lambda_ = tl.exp(x_max - x_max_)\n denominator = lambda_ * denominator + tl.sum(x_exp, axis=0)\n x_max = x_max_\n\n x_block_ptr_ += BLOCK_N * d\n\n # save\n x_stat1_block_ptr = X_STAT1 + off_b * h * d + off_h * d + tl.arange(0, BLOCK_D)\n x_stat2_block_ptr = X_STAT2 + off_b * h * d + off_h * d + tl.arange(0, BLOCK_D)\n\n tl.store(\n x_stat1_block_ptr,\n x_max.to(x_stat1_block_ptr.dtype.element_ty),\n mask=d_mask,\n )\n tl.store(\n x_stat2_block_ptr,\n denominator.to(x_stat2_block_ptr.dtype.element_ty),\n mask=d_mask,\n )\n\n # compute the first l element\n if l > 0:\n offset_theta_cache_l = off_h * n * d\n theta_cache_block_ptr_l = (\n ThetaCache\n + offset_theta_cache_l\n + tl.arange(0, BLOCK_L)[:, None]\n + tl.arange(0, BLOCK_D)[None, :]\n )\n\n x_block_ptr_l = (\n X\n + offset_x\n + tl.arange(0, BLOCK_L)[:, None] * d\n + tl.arange(0, BLOCK_D)[None, :]\n )\n o_cos_block_ptr_l = (\n O\n + offset_o\n + tl.arange(0, BLOCK_L)[:, None] * 2 * d\n + tl.arange(0, BLOCK_D)[None, :]\n )\n o_sin_block_ptr_l = (\n O\n + offset_o\n + d\n + tl.arange(0, BLOCK_L)[:, None] * 2 * d\n + tl.arange(0, BLOCK_D)[None, :]\n )\n ld_mask = (tl.arange(0, BLOCK_L) < l)[:, None] & (\n tl.arange(0, BLOCK_D)[None, :] < d\n )\n x_l = tl.load(x_block_ptr_l, mask=ld_mask, other=0)\n\n if ACT != \"none\":\n if ACT == \"relu\":\n x_l = tl.where(x_l >= 0, x_l, 0)\n elif ACT == \"sigmoid\":\n x_l = tl.sigmoid(x_l)\n elif ACT == \"silu\":\n x_l = x_l * tl.sigmoid(x_l)\n elif ACT == \"softmax\":\n # for stable\n x_l_minus_max = x_l - x_max\n # softmax\n numerator_l = tl.exp(x_l_minus_max)\n x_l = numerator_l / denominator\n\n zero = tl.zeros([BLOCK_L, BLOCK_D], dtype=x_l.dtype)\n # save\n tl.store(\n o_cos_block_ptr_l, x_l.to(o_cos_block_ptr_l.dtype.element_ty), mask=ld_mask\n )\n tl.store(\n o_sin_block_ptr_l, zero.to(o_sin_block_ptr_l.dtype.element_ty), mask=ld_mask\n )\n tl.store(\n theta_cache_block_ptr_l,\n zero.to(theta_cache_block_ptr_l.dtype.element_ty),\n mask=ld_mask,\n )\n\n # compute from the last theta block\n x_block_ptr = (\n X\n + offset_x\n + l * d\n + offset_d\n + tl.arange(0, BLOCK_N)[:, None] * d\n + tl.arange(0, BLOCK_E)[None, :]\n )\n theta_block_ptr = Theta + offset_theta + tl.arange(0, BLOCK_E)[None, :]\n o_cos_block_ptr = (\n O\n + offset_o\n + 2 * l * d\n + offset_d\n + tl.arange(0, BLOCK_N)[:, None] * 2 * d\n + tl.arange(0, BLOCK_E)[None, :]\n )\n o_sin_block_ptr = (\n O\n + offset_o\n + 2 * l * d\n + offset_d\n + d\n + tl.arange(0, BLOCK_N)[:, None] * 2 * d\n + tl.arange(0, BLOCK_E)[None, :]\n )\n # triton only support load block at least 16 elements, use this to get shape\n shape_mask = tl.arange(0, 16) < 1\n # mask\n e_mask = tl.arange(0, BLOCK_E) < e\n\n theta_ = tl.load(theta_block_ptr, mask=e_mask[None, :], other=0).to(tl.float32)\n array = tl.arange(0, BLOCK_N)\n theta_cache_block_ptr = (\n ThetaCache\n + offset_theta_cache\n + offset_d\n + tl.arange(0, BLOCK_N)[:, None] * e\n + tl.arange(0, BLOCK_E)[None, :]\n )\n\n for i in range(tl.cdiv(n - l, BLOCK_N)):\n n_mask = array < n - l # !!! important\n c = array[:, None]\n offset_d = m * e\n # triton only support load block at least 16 elements, use this to get shape\n shape_block_ptr = Shape + m + tl.arange(0, 16)\n if ACT == \"softmax\":\n x_max_block_ptr = (\n X_STAT1\n + off_b * h * d\n + off_h * d\n + offset_d\n + tl.arange(0, BLOCK_E)[None, :]\n )\n denominator_block_ptr = (\n X_STAT2\n + off_b * h * d\n + off_h * d\n + offset_d\n + tl.arange(0, BLOCK_E)[None, :]\n )\n\n for j in range(m):\n # update block ptr\n shape_block_ptr -= 1\n x_block_ptr -= e\n o_cos_block_ptr -= e\n o_sin_block_ptr -= e\n offset_d -= e\n theta_cache_block_ptr -= e\n\n de_mask = ((offset_d + tl.arange(0, BLOCK_E)) < d) & e_mask\n mask = n_mask[:, None] & de_mask[None, :]\n\n # compute dim\n dim = tl.sum(\n tl.load(shape_block_ptr, mask=shape_mask, other=0).to(tl.int32)\n )\n offset = c % dim\n c = c // dim\n\n x = tl.load(x_block_ptr, mask=mask, other=value).to(tl.float32)\n if ACT != \"none\":\n if ACT == \"relu\":\n x = tl.where(x >= 0, x, 0)\n elif ACT == \"sigmoid\":\n x = tl.sigmoid(x)\n elif ACT == \"silu\":\n x = x * tl.sigmoid(x)\n elif ACT == \"softmax\":\n x_max_block_ptr -= e\n denominator_block_ptr -= e\n x_max_ = tl.load(\n x_max_block_ptr, mask=de_mask[None, :], other=0\n ).to(tl.float32)\n denominator_ = tl.load(\n denominator_block_ptr, mask=de_mask[None, :], other=1\n ).to(tl.float32)\n # for stable\n x_minus_max_ = x - x_max_\n # softmax\n numerator_ = tl.exp(x_minus_max_)\n x = numerator_ / denominator_\n\n theta = theta_ * offset\n o_cos = x * tl.cos(theta)\n o_sin = x * tl.sin(theta)\n\n # save\n tl.store(\n o_cos_block_ptr, o_cos.to(o_cos_block_ptr.dtype.element_ty), mask=mask\n )\n tl.store(\n o_sin_block_ptr, o_sin.to(o_sin_block_ptr.dtype.element_ty), mask=mask\n )\n if i == 0:\n tl.store(\n theta_cache_block_ptr,\n theta.to(theta_cache_block_ptr.dtype.element_ty),\n mask=mask,\n )\n\n x_block_ptr += BLOCK_N * d + e * m\n array += BLOCK_N\n o_cos_block_ptr += BLOCK_N * 2 * d + e * m\n o_sin_block_ptr += BLOCK_N * 2 * d + e * m\n\n\n@triton.autotune(\n generate_configs({\"num_warps\": [2, 4, 8]}),\n key=[\"h\", \"n\", \"d\", \"m\"],\n)\n@triton.jit\ndef _lrpe_cosine_md_bp_bwd_triton(\n X,\n Theta,\n DO,\n DX,\n Shape,\n ThetaCache,\n X_STAT1,\n X_STAT2,\n b: tl.constexpr,\n h: tl.constexpr,\n n: tl.constexpr,\n l: tl.constexpr,\n d: tl.constexpr,\n e: tl.constexpr,\n m: tl.constexpr,\n ACT: tl.constexpr,\n BLOCK_D: tl.constexpr,\n BLOCK_E: tl.constexpr,\n):\n off_b = tl.program_id(0)\n off_h = tl.program_id(1)\n # compute offset\n offset_x = off_b * h * n * d + off_h * n * d\n offset_o = off_b * h * n * 2 * d + off_h * n * 2 * d\n offset_theta_cache = off_h * n * d\n # compute block ptr\n theta_block_ptr = (\n ThetaCache\n + offset_theta_cache\n + tl.arange(0, BLOCK_N) * d\n + tl.arange(0, BLOCK_D)[None, :]\n )\n dx_block_ptr = (\n DX\n + offset_x\n + tl.arange(0, BLOCK_N)[:, None] * d\n + tl.arange(0, BLOCK_D)[None, :]\n )\n do_cos_block_ptr = (\n DO\n + offset_o\n + tl.arange(0, BLOCK_N)[:, None] * 2 * d\n + tl.arange(0, BLOCK_D)[None, :]\n )\n do_sin_block_ptr = (\n DO\n + offset_o\n + d\n + tl.arange(0, BLOCK_N)[:, None] * 2 * d\n + tl.arange(0, BLOCK_D)[None, :]\n )\n array = tl.arange(0, BLOCK_N)\n # mask\n d_mask = tl.arange(0, BLOCK_D) < d\n\n if ACT == \"softmax\": # compute c first\n x_block_ptr = (\n X\n + offset_x\n + tl.arange(0, BLOCK_N)[:, None] * d\n + tl.arange(0, BLOCK_D)[None, :]\n )\n x_stat1_block_ptr = X_STAT1 + off_b * h * d + off_h * d + tl.arange(0, BLOCK_D)\n x_stat2_block_ptr = X_STAT2 + off_b * h * d + off_h * d + tl.arange(0, BLOCK_D)\n x_max = tl.load(x_stat1_block_ptr, mask=d_mask, other=0).to(tl.float32)\n denominator = tl.load(x_stat2_block_ptr, mask=d_mask, other=1).to(tl.float32)\n\n c = tl.zeros([BLOCK_D], dtype=tl.float32)\n\n for i in range(tl.cdiv(n, BLOCK_N)):\n n_mask = array < n\n mask = n_mask[:, None] & d_mask[None, :]\n\n do_cos = tl.load(do_cos_block_ptr, mask=mask, other=0).to(tl.float32)\n do_sin = tl.load(do_sin_block_ptr, mask=mask, other=0).to(tl.float32)\n theta = tl.load(theta_block_ptr, mask=mask, other=0).to(tl.float32)\n\n dx = do_cos * tl.cos(theta) + do_sin * tl.sin(theta)\n\n x = tl.load(x_block_ptr, mask=mask, other=0).to(tl.float32)\n # for stable\n x_minus_max = x - x_max\n # softmax\n numerator = tl.exp(x_minus_max)\n o = numerator / denominator\n\n # scalar\n c += tl.sum(o * dx, axis=0)\n\n x_block_ptr += BLOCK_N * d\n array += BLOCK_N\n do_cos_block_ptr += BLOCK_N * 2 * d\n do_sin_block_ptr += BLOCK_N * 2 * d\n theta_block_ptr += BLOCK_N * d\n\n # reinit\n do_cos_block_ptr = (\n DO\n + offset_o\n + tl.arange(0, BLOCK_N)[:, None] * 2 * d\n + tl.arange(0, BLOCK_D)[None, :]\n )\n do_sin_block_ptr = (\n DO\n + offset_o\n + d\n + tl.arange(0, BLOCK_N)[:, None] * 2 * d\n + tl.arange(0, BLOCK_D)[None, :]\n )\n array = tl.arange(0, BLOCK_N)\n theta_block_ptr = (\n ThetaCache\n + offset_theta_cache\n + tl.arange(0, BLOCK_N) * d\n + tl.arange(0, BLOCK_D)[None, :]\n )\n\n for i in range(tl.cdiv(n, BLOCK_N)):\n n_mask = array < n\n mask = n_mask[:, None] & d_mask[None, :]\n\n do_cos = tl.load(do_cos_block_ptr, mask=mask, other=0).to(tl.float32)\n do_sin = tl.load(do_sin_block_ptr, mask=mask, other=0).to(tl.float32)\n theta = tl.load(theta_block_ptr, mask=mask, other=0).to(tl.float32)\n\n dx = do_cos * tl.cos(theta) + do_sin * tl.sin(theta)\n\n if ACT != \"none\":\n x_block_ptr = (\n X\n + offset_x\n + i * BLOCK_N * d\n + tl.arange(0, BLOCK_N)[:, None] * d\n + tl.arange(0, BLOCK_D)[None, :]\n )\n x = tl.load(x_block_ptr, mask=mask, other=0).to(tl.float32)\n if ACT == \"relu\":\n dx = tl.where(x >= 0, dx, 0)\n elif ACT == \"sigmoid\":\n sigmoid = tl.sigmoid(x)\n dx = dx * sigmoid * (1 - sigmoid)\n elif ACT == \"silu\":\n sigmoid = tl.sigmoid(x)\n dx = dx * sigmoid * (1 + x * (1 - sigmoid))\n elif ACT == \"softmax\":\n # for stable\n x_minus_max = x - x_max\n # softmax\n numerator = tl.exp(x_minus_max)\n o = numerator / denominator\n # scalar\n dx = o * dx - c * o\n\n tl.store(dx_block_ptr, dx.to(dx_block_ptr.dtype.element_ty), mask=mask)\n\n dx_block_ptr += BLOCK_N * d\n array += BLOCK_N\n do_cos_block_ptr += BLOCK_N * 2 * d\n do_sin_block_ptr += BLOCK_N * 2 * d\n theta_block_ptr += BLOCK_N * d\n\n\ndef lrpe_cosine_md_bp_fwd_triton(x, theta, shape, l=0, act=\"none\", dim=None):\n assert act in ACT_SET, f\"act: {act} not in {ACT_SET}\"\n assert dim in [-2, None], \"dim must in [-2, None]\"\n\n b, h, n, d = x.shape\n e = theta.shape[-1]\n m = len(shape)\n\n output_shape = list(x.shape)\n output_shape[-1] *= 2\n\n o = torch.empty(output_shape, dtype=x.dtype, device=x.device)\n theta_cache = torch.empty((h, n, d), dtype=torch.float32, device=theta.device)\n x_stat1 = torch.empty(b, h, d, dtype=x.dtype, device=x.device)\n x_stat2 = torch.empty(b, h, d, dtype=x.dtype, device=x.device)\n\n BLOCK_D = next_power_of_two(d)\n BLOCK_E = next_power_of_two(e)\n BLOCK_L = next_power_of_two(l) if l > 0 else 0\n\n def grid(meta):\n return (b, h)\n\n _lrpe_cosine_md_bp_fwd_triton[grid](\n x,\n theta,\n o,\n shape,\n theta_cache,\n x_stat1,\n x_stat2,\n b,\n h,\n n,\n l,\n d,\n e,\n m,\n act,\n BLOCK_D,\n BLOCK_E,\n BLOCK_L,\n )\n\n return o, theta_cache, x_stat1, x_stat2\n\n\ndef lrpe_cosine_md_bp_bwd_triton(\n x,\n theta,\n do,\n shape,\n theta_cache,\n x_stat1,\n x_stat2,\n l=0,\n act=\"none\",\n dim=None,\n **kwargs,\n):\n assert act in ACT_SET, f\"act: {act} not in {ACT_SET}\"\n assert dim in [-2, None], \"dim must in [-2, None]\"\n\n b, h, n, d = x.shape\n e = theta.shape[-1]\n m = len(shape)\n\n dx = torch.empty_like(x)\n BLOCK_D = next_power_of_two(d)\n BLOCK_E = next_power_of_two(e)\n BLOCK_L = next_power_of_two(l) if l > 0 else 0\n\n def grid(meta):\n return (b, h, n)\n\n _lrpe_cosine_md_bp_bwd_triton[grid](\n x,\n theta,\n do,\n dx,\n shape,\n theta_cache,\n x_stat1,\n x_stat2,\n b,\n h,\n n,\n l,\n d,\n e,\n m,\n act,\n BLOCK_D,\n BLOCK_E,\n BLOCK_L,\n )\n\n return dx\n", - "description_1": "Use triton language to implement two functions: '_lrpe_cosine_md_bp_fwd_triton' and '_lrpe_cosine_md_bp_bwd_triton'. The '_lrpe_cosine_md_bp_fwd_triton' function takes 15 parameters including tensors and constants for performing element-wise computations with optional activation functions, across multidimensional data represented in a forward pass of a model. It computes cosine and sine outputs using triton parallel processing. The '_lrpe_cosine_md_bp_bwd_triton' function, with the same number of parameters, calculates gradients for a backward pass in the model, applying various element-wise activation functions and utilizing previously stored state tensors.", - "description_2": "Use triton language to create a forward kernel '_lrpe_cosine_md_bp_fwd_triton' to compute element-wise operations on multidimensional data with activation functions, and a backward kernel '_lrpe_cosine_md_bp_bwd_triton' for calculating gradients, both designed for efficient parallel execution.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.autotune(\n configs=[\n triton.Config({\"num_warps\": 2}),\n triton.Config({\"num_warps\": 4}),\n triton.Config({\"num_warps\": 8}),\n ],\n key=[\"h\", \"n\", \"d\", \"m\"],\n)\n@triton.jit\ndef _lrpe_cosine_md_cache_fwd_triton(\n X,\n Theta,\n O,\n Shape,\n ThetaCache,\n b: tl.constexpr,\n h: tl.constexpr,\n n: tl.constexpr,\n l: tl.constexpr,\n d: tl.constexpr,\n e: tl.constexpr,\n m: tl.constexpr,\n ACT: tl.constexpr,\n BLOCK_D: tl.constexpr,\n BLOCK_E: tl.constexpr,\n):\n off_b = tl.program_id(0)\n off_h = tl.program_id(1)\n off_n = tl.program_id(2)\n # compute offset\n offset_x = off_b * h * n * d + off_h * n * d + off_n * d\n offset_theta = off_h * e\n offset_o = off_b * h * n * 2 * d + off_h * n * 2 * d + off_n * 2 * d\n offset_d = m * e\n offset_theta_cache = off_h * n * d + off_n * d\n\n # compute from the last theta block\n x_block_ptr = X + offset_x + offset_d + tl.arange(0, BLOCK_E)\n theta_block_ptr = Theta + offset_theta + tl.arange(0, BLOCK_E)\n o_cos_block_ptr = O + offset_o + offset_d + tl.arange(0, BLOCK_E)\n o_sin_block_ptr = O + offset_o + offset_d + d + tl.arange(0, BLOCK_E)\n theta_cache_block_ptr = (\n ThetaCache + offset_theta_cache + offset_d + tl.arange(0, BLOCK_E)\n )\n # triton only support load block at least 16 elements, use this to get shape\n shape_block_ptr = Shape + m + tl.arange(0, 16)\n shape_mask = tl.arange(0, 16) < 1\n # mask\n e_mask = tl.arange(0, BLOCK_E) < e\n\n c = off_n - l\n offset = 0\n\n n_mask = c >= 0\n theta_ = tl.load(theta_block_ptr, mask=e_mask & n_mask[None], other=0).to(\n tl.float32\n )\n\n # for softmax act, we should compute max and denominator first\n if ACT == \"softmax\":\n x_block_ptr_ = X + offset_x + tl.arange(0, BLOCK_D)\n d_mask = tl.arange(0, BLOCK_D) < d\n x_ = tl.load(x_block_ptr_, mask=d_mask, other=-float(\"inf\")).to(tl.float32)\n x_max = tl.max(x_, axis=0)\n numerator_ = tl.exp(x_ - x_max)\n denominator = tl.sum(numerator_)\n\n for i in range(m):\n # update block ptr\n shape_block_ptr -= 1\n x_block_ptr -= e\n o_cos_block_ptr -= e\n o_sin_block_ptr -= e\n offset_d -= e\n theta_cache_block_ptr -= e\n mask = ((offset_d + tl.arange(0, BLOCK_E)) < d) & e_mask\n\n # compute dim\n dim = tl.sum(tl.load(shape_block_ptr, mask=shape_mask, other=0).to(tl.int32))\n offset = c % dim\n c = c // dim\n\n # compute\n if ACT == \"softmax\":\n value = -float(\"inf\")\n else:\n value = 0\n\n x = tl.load(x_block_ptr, mask=mask, other=value).to(tl.float32)\n if ACT != \"none\":\n if ACT == \"relu\":\n x = tl.where(x >= 0, x, 0)\n elif ACT == \"sigmoid\":\n x = tl.sigmoid(x)\n elif ACT == \"silu\":\n x = x * tl.sigmoid(x)\n elif ACT == \"softmax\":\n # for stable\n x_minus_max = x - x_max\n # softmax\n numerator = tl.exp(x_minus_max)\n x = numerator / denominator\n\n theta = theta_ * offset\n o_cos = x * tl.cos(theta)\n o_sin = x * tl.sin(theta)\n\n # save\n tl.store(o_cos_block_ptr, o_cos.to(o_cos_block_ptr.dtype.element_ty), mask=mask)\n tl.store(o_sin_block_ptr, o_sin.to(o_sin_block_ptr.dtype.element_ty), mask=mask)\n tl.store(\n theta_cache_block_ptr,\n theta.to(theta_cache_block_ptr.dtype.element_ty),\n mask=mask,\n )\n\n\n@triton.autotune(\n configs=[\n triton.Config({\"num_warps\": 2}),\n triton.Config({\"num_warps\": 4}),\n triton.Config({\"num_warps\": 8}),\n ],\n key=[\"h\", \"n\", \"d\", \"m\"],\n)\n@triton.jit\ndef _lrpe_cosine_md_cache_bwd_triton(\n X,\n Theta,\n DO,\n DX,\n Shape,\n ThetaCache,\n b: tl.constexpr,\n h: tl.constexpr,\n n: tl.constexpr,\n d: tl.constexpr,\n e: tl.constexpr,\n m: tl.constexpr,\n ACT: tl.constexpr,\n BLOCK_D: tl.constexpr,\n BLOCK_E: tl.constexpr,\n):\n off_b = tl.program_id(0)\n off_h = tl.program_id(1)\n off_n = tl.program_id(2)\n # compute offset\n offset_x = off_b * h * n * d + off_h * n * d + off_n * d\n offset_o = off_b * h * n * 2 * d + off_h * n * 2 * d + off_n * 2 * d\n offset_theta_cache = off_h * n * d + off_n * d\n\n # compute in parallel\n theta_cache_block_ptr = ThetaCache + offset_theta_cache + tl.arange(0, BLOCK_D)\n dx_block_ptr = DX + offset_x + tl.arange(0, BLOCK_D)\n do_cos_block_ptr = DO + offset_o + tl.arange(0, BLOCK_D)\n do_sin_block_ptr = DO + offset_o + d + tl.arange(0, BLOCK_D)\n # mask\n d_mask = tl.arange(0, BLOCK_D) < d\n\n # compute\n theta = tl.load(theta_cache_block_ptr, mask=d_mask, other=0).to(tl.float32)\n do_cos = tl.load(do_cos_block_ptr, mask=d_mask, other=0).to(tl.float32)\n do_sin = tl.load(do_sin_block_ptr, mask=d_mask, other=0).to(tl.float32)\n dx = do_cos * tl.cos(theta) + do_sin * tl.sin(theta)\n\n if ACT != \"none\":\n x_block_ptr = X + offset_x + tl.arange(0, BLOCK_D)\n\n if ACT == \"softmax\":\n value = -float(\"inf\")\n else:\n value = 0\n\n x = tl.load(x_block_ptr, mask=d_mask, other=value).to(tl.float32)\n\n if ACT == \"relu\":\n dx = tl.where(x >= 0, dx, 0)\n elif ACT == \"sigmoid\":\n sigmoid = tl.sigmoid(x)\n dx = dx * sigmoid * (1 - sigmoid)\n elif ACT == \"silu\":\n sigmoid = tl.sigmoid(x)\n dx = dx * sigmoid * (1 + x * (1 - sigmoid))\n elif ACT == \"softmax\":\n # for stable\n x_minus_max = x - tl.max(x, axis=0)\n # softmax\n numerator = tl.exp(x_minus_max)\n denominator = tl.sum(numerator)\n o = numerator / denominator\n\n # scalar\n c = tl.sum(o * dx, axis=0)\n dx = o * dx - c * o\n\n tl.store(dx_block_ptr, dx.to(dx_block_ptr.dtype.element_ty), mask=d_mask)\n\n\ndef lrpe_cosine_md_cache_fwd_triton(\n x, theta, shape, l=0, act=\"none\", dim=None, **kwargs\n):\n b, h, n, d = x.shape\n e = theta.shape[-1]\n m = len(shape)\n\n output_shape = list(x.shape)\n output_shape[-1] *= 2\n\n o = torch.empty(output_shape, dtype=x.dtype, device=x.device)\n theta_cache = torch.empty((h, n, d), dtype=torch.float32, device=theta.device)\n BLOCK_D = next_power_of_two(d)\n BLOCK_E = next_power_of_two(e)\n\n def grid(meta):\n return (b, h, n)\n\n _lrpe_cosine_md_cache_fwd_triton[grid](\n x, theta, o, shape, theta_cache, b, h, n, l, d, e, m, act, BLOCK_D, BLOCK_E\n )\n\n return o, theta_cache\n\n\ndef lrpe_cosine_md_cache_bwd_triton(\n x, theta, do, shape, theta_cache, l=0, act=\"none\", dim=None, **kwargs\n):\n b, h, n, d = x.shape\n e = theta.shape[-1]\n m = len(shape)\n\n dx = torch.empty_like(x)\n BLOCK_D = next_power_of_two(d)\n BLOCK_E = next_power_of_two(e)\n\n def grid(meta):\n return (b, h, n)\n\n _lrpe_cosine_md_cache_bwd_triton[grid](\n x, theta, do, dx, shape, theta_cache, b, h, n, d, e, m, act, BLOCK_D, BLOCK_E\n )\n\n return dx\n", - "description_1": "Use triton language to implement forward and backward kernels for a cosine-based multi-dimensional cache operation. The forward kernel takes 15 parameters: X (input tensor), Theta (angle tensor), O (output tensor), Shape (shape tensor), ThetaCache (cache tensor), and several compile-time constants including batch size (b), number of heads (h), sequence length (n), initial sequence length (l), feature dimension (d), angle dimension (e), number of shape dimensions (m), activation function (ACT), and block sizes (BLOCK_D, BLOCK_E). The backward kernel takes similar parameters with the addition of DO (gradient of output) and DX (gradient of input). The forward function computes cosine and sine transformations of the input tensor based on the angles and stores the results in the output tensor. The backward function computes the gradient of the input tensor based on the gradient of the output tensor and the cached angles.", - "description_2": "Use triton language to create kernels for computing cosine transformations and their gradients for multi-dimensional data, with support for various activation functions and efficient memory access patterns.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef _lrpe_cosine_md_fwd_triton(\n X,\n Theta,\n O,\n Shape,\n b: tl.constexpr,\n h: tl.constexpr,\n n: tl.constexpr,\n l: tl.constexpr,\n d: tl.constexpr,\n e: tl.constexpr,\n m: tl.constexpr,\n ACT: tl.constexpr,\n BLOCK_D: tl.constexpr,\n BLOCK_E: tl.constexpr,\n):\n off_b = tl.program_id(0)\n off_h = tl.program_id(1)\n off_n = tl.program_id(2)\n # compute offset\n offset_x = off_b * h * n * d + off_h * n * d + off_n * d\n offset_theta = off_h * e\n offset_o = off_b * h * n * 2 * d + off_h * n * 2 * d + off_n * 2 * d\n offset_d = m * e\n\n # compute from the last theta block\n x_block_ptr = X + offset_x + offset_d + tl.arange(0, BLOCK_E)\n theta_block_ptr = Theta + offset_theta + tl.arange(0, BLOCK_E)\n o_cos_block_ptr = O + offset_o + offset_d + tl.arange(0, BLOCK_E)\n o_sin_block_ptr = O + offset_o + offset_d + d + tl.arange(0, BLOCK_E)\n # triton only support load block at least 16 elements, use this to get shape\n shape_block_ptr = Shape + m + tl.arange(0, 16)\n shape_mask = tl.arange(0, 16) < 1\n # mask\n e_mask = tl.arange(0, BLOCK_E) < e\n\n c = off_n - l\n offset = 0\n\n n_mask = c >= 0\n theta_ = tl.load(theta_block_ptr, mask=e_mask & n_mask[None], other=0).to(\n tl.float32\n )\n # this is equivalent to:\n # if off_n >= l:\n # theta_ = tl.load(theta_block_ptr, mask=e_mask, other=0).to(tl.float32)\n # else:\n # # concat((x, 0)) = concat(x * cos(0), x * sin(0))\n # theta_ = tl.zeros((e,), dtype=tl.float32)\n\n # for softmax act, we should compute max and denominator first\n if ACT == \"softmax\":\n x_block_ptr_ = X + offset_x + tl.arange(0, BLOCK_D)\n d_mask = tl.arange(0, BLOCK_D) < d\n x_ = tl.load(x_block_ptr_, mask=d_mask, other=-float(\"inf\")).to(tl.float32)\n x_max = tl.max(x_, axis=0)\n numerator_ = tl.exp(x_ - x_max)\n denominator = tl.sum(numerator_)\n\n for i in range(m):\n # update block ptr\n shape_block_ptr -= 1\n x_block_ptr -= e\n o_cos_block_ptr -= e\n o_sin_block_ptr -= e\n offset_d -= e\n mask = ((offset_d + tl.arange(0, BLOCK_E)) < d) & e_mask\n\n # compute dim\n dim = tl.sum(tl.load(shape_block_ptr, mask=shape_mask, other=0).to(tl.int32))\n offset = c % dim\n c = c // dim\n\n # compute\n if ACT == \"softmax\":\n value = -float(\"inf\")\n else:\n value = 0\n\n x = tl.load(x_block_ptr, mask=mask, other=value).to(tl.float32)\n if ACT != \"none\":\n if ACT == \"relu\":\n x = tl.where(x >= 0, x, 0)\n elif ACT == \"sigmoid\":\n x = tl.sigmoid(x)\n elif ACT == \"silu\":\n x = x * tl.sigmoid(x)\n elif ACT == \"softmax\":\n # for stable\n x_minus_max = x - x_max\n # softmax\n numerator = tl.exp(x_minus_max)\n x = numerator / denominator\n\n theta = theta_ * offset\n o_cos = x * tl.cos(theta)\n o_sin = x * tl.sin(theta)\n\n # save\n tl.store(o_cos_block_ptr, o_cos.to(o_cos_block_ptr.dtype.element_ty), mask=mask)\n tl.store(o_sin_block_ptr, o_sin.to(o_sin_block_ptr.dtype.element_ty), mask=mask)\n\n\n@triton.jit\ndef _lrpe_cosine_md_bwd_triton(\n X,\n Theta,\n DO,\n DX,\n Shape,\n b: tl.constexpr,\n h: tl.constexpr,\n n: tl.constexpr,\n l: tl.constexpr,\n d: tl.constexpr,\n e: tl.constexpr,\n m: tl.constexpr,\n ACT: tl.constexpr,\n BLOCK_D: tl.constexpr,\n BLOCK_E: tl.constexpr,\n):\n off_b = tl.program_id(0)\n off_h = tl.program_id(1)\n off_n = tl.program_id(2)\n # compute offset\n offset_x = off_b * h * n * d + off_h * n * d + off_n * d\n offset_theta = off_h * e\n offset_o = off_b * h * n * 2 * d + off_h * n * 2 * d + off_n * 2 * d\n offset_d = m * e\n\n # compute from the last theta block\n theta_block_ptr = Theta + offset_theta + tl.arange(0, BLOCK_E)\n dx_block_ptr = DX + offset_x + offset_d + tl.arange(0, BLOCK_E)\n x_block_ptr = X + offset_x + offset_d + tl.arange(0, BLOCK_E)\n do_cos_block_ptr = DO + offset_o + offset_d + tl.arange(0, BLOCK_E)\n do_sin_block_ptr = DO + offset_o + offset_d + d + tl.arange(0, BLOCK_E)\n # triton only support load block at least 16 elements, use this to get shape\n shape_block_ptr = Shape + m + tl.arange(0, 16)\n shape_mask = tl.arange(0, 16) < 1\n # mask\n e_mask = tl.arange(0, BLOCK_E) < e\n\n c = off_n - l\n offset = 0\n\n n_mask = c >= 0\n theta_ = tl.load(theta_block_ptr, mask=e_mask & n_mask[None], other=0).to(\n tl.float32\n )\n # this is equivalent to:\n # if off_n >= l:\n # theta_ = tl.load(theta_block_ptr).to(tl.float32)\n # else:\n # # concat((x, 0)) = concat(x * cos(0), x * sin(0))\n # theta_ = tl.zeros((e,), dtype=tl.float32)\n\n for i in range(m):\n # update block ptr\n shape_block_ptr -= 1\n dx_block_ptr -= e\n x_block_ptr -= e\n do_cos_block_ptr -= e\n do_sin_block_ptr -= e\n offset_d -= e\n mask = ((offset_d + tl.arange(0, BLOCK_E)) < d) & e_mask\n\n # compute dim\n dim = tl.sum(tl.load(shape_block_ptr, mask=shape_mask, other=0).to(tl.int32))\n offset = c % dim\n c = c // dim\n\n # compute\n do_cos = tl.load(do_cos_block_ptr, mask=mask, other=0).to(tl.float32)\n do_sin = tl.load(do_sin_block_ptr, mask=mask, other=0).to(tl.float32)\n theta = theta_ * offset\n dx = do_cos * tl.cos(theta) + do_sin * tl.sin(theta)\n\n if ACT != \"none\":\n if ACT == \"softmax\":\n value = -float(\"inf\")\n else:\n value = 0\n\n x = tl.load(x_block_ptr, mask=mask, other=value).to(tl.float32)\n\n if ACT == \"relu\":\n dx = tl.where(x >= 0, dx, 0)\n elif ACT == \"sigmoid\":\n sigmoid = tl.sigmoid(x)\n dx = dx * sigmoid * (1 - sigmoid)\n elif ACT == \"silu\":\n sigmoid = tl.sigmoid(x)\n dx = dx * sigmoid * (1 + x * (1 - sigmoid))\n\n tl.store(dx_block_ptr, dx.to(dx_block_ptr.dtype.element_ty), mask=mask)\n\n # for softmax, since s involves dx, we shoud compute again\n if ACT == \"softmax\":\n x_block_ptr_ = X + offset_x + tl.arange(0, BLOCK_D)\n d_mask = tl.arange(0, BLOCK_D) < d\n x_ = tl.load(x_block_ptr_, mask=d_mask, other=-float(\"inf\")).to(tl.float32)\n x_minus_max = x_ - tl.max(x_, axis=0)\n numerator = tl.exp(x_minus_max)\n denominator = tl.sum(numerator)\n o = numerator / denominator\n\n dx_block_ptr_ = DX + offset_x + tl.arange(0, BLOCK_D)\n dx_ = tl.load(dx_block_ptr_, mask=d_mask, other=0).to(tl.float32)\n\n # compute\n s = tl.sum(o * dx_, axis=0)\n dx_ = o * dx_ - s * o\n tl.store(dx_block_ptr_, dx_.to(dx_block_ptr_.dtype.element_ty), mask=d_mask)\n\n\ndef lrpe_cosine_md_fwd_triton(x, theta, shape, l=0, act=\"none\", dim=None):\n b, h, n, d = x.shape\n e = theta.shape[-1]\n m = len(shape)\n\n output_shape = list(x.shape)\n output_shape[-1] *= 2\n\n o = torch.empty(output_shape, dtype=x.dtype, device=x.device)\n BLOCK_D = next_power_of_two(d)\n BLOCK_E = next_power_of_two(e)\n\n def grid(meta):\n return (b, h, n)\n\n _lrpe_cosine_md_fwd_triton[grid](\n x, theta, o, shape, b, h, n, l, d, e, m, act, BLOCK_D, BLOCK_E\n )\n\n return o\n\n\ndef lrpe_cosine_md_bwd_triton(x, theta, do, shape, l=0, act=\"none\", dim=None, **kwargs):\n b, h, n, d = x.shape\n e = theta.shape[-1]\n m = len(shape)\n\n dx = torch.empty_like(x)\n BLOCK_D = next_power_of_two(d)\n BLOCK_E = next_power_of_two(e)\n\n def grid(meta):\n return (b, h, n)\n\n _lrpe_cosine_md_bwd_triton[grid](\n x, theta, do, dx, shape, b, h, n, l, d, e, m, act, BLOCK_D, BLOCK_E\n )\n\n return dx\n\n\ndef lrpe_cosine_md_triton(x, theta, shape, l=0, act=\"none\", dim=None, **kwargs):\n shape = torch.tensor(shape, dtype=torch.int32, device=x.device)\n return LrpeCosineMdTriton.apply(x, theta, shape, l, act, dim)\n\n\nclass LrpeCosineMdTriton(torch.autograd.Function):\n @staticmethod\n def forward(ctx, x, theta, shape, l=0, act=\"none\", dim=None):\n o = lrpe_cosine_md_fwd_triton(x, theta, shape, l, act, dim)\n\n ctx.save_for_backward(x, theta, shape)\n ctx.l = l\n ctx.act = act\n ctx.dim = dim\n\n return o\n\n @staticmethod\n def backward(ctx, do):\n x, theta, shape = ctx.saved_tensors\n l = ctx.l\n act = ctx.act\n dim = ctx.dim\n\n dx = lrpe_cosine_md_bwd_triton(x, theta, do, shape, l, act, dim)\n\n return dx, None, None, None, None, None\n", - "description_1": "Use triton language to implement two kernels and their respective Python wrapper functions for the forward and backward operations of a cosine-modulated function. The first kernel '_lrpe_cosine_md_fwd_triton' computes a forward operation, which takes 14 parameters: X (input tensor), Theta (modulation tensor), O (output tensor), Shape (shape tensor), and several block and configuration parameters. The second kernel '_lrpe_cosine_md_bwd_triton' computes a backward operation and takes 15 parameters: X, Theta, DO (derivative of output), DX (derivative of input), Shape, and several block and configuration parameters. Both kernels support various activation functions and perform computations in parallel across a grid specified by batch size b, head count h, and sequence length n. The results are used in the 'LrpeCosineMdTriton' class to perform autograd operations with Torch.", - "description_2": "Use triton language to implement forward and backward cosine modulation operations with parallel execution over grid dimensions specified by batch size, head count, and sequence length.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\nfrom xopes.utils import next_power_of_two\n\n@triton.jit\ndef _gumbel_multinomial_reduce_triton(\n Sample, # b k m\n Lse, # b m\n Sample_out, # b k\n seed,\n b: tl.constexpr,\n k: tl.constexpr,\n m: tl.constexpr, # num samples\n top_k: tl.constexpr,\n BLOCK_M: tl.constexpr,\n):\n off_b = tl.program_id(0)\n off_k = tl.program_id(1)\n # compute offset\n offset_b = off_b\n offset_k = off_k\n offset_sample = offset_b * k * m + offset_k * m\n offset_sample_out = offset_b * k + offset_k\n offset_lse = offset_b * m\n # mask\n m_mask = tl.arange(0, BLOCK_M) < m\n\n # 1, 1\n sample_out_block_ptr = Sample_out + offset_sample_out + tl.arange(0, 1)[:, None] * k\n # 1, m\n lse_ptr = Lse + offset_lse + tl.arange(0, BLOCK_M)[None, :]\n # for random\n # 1, 1\n rand_block_ptr = tl.zeros([1, 1], dtype=tl.float32)\n\n value = -float(\"inf\")\n\n logits = tl.load(lse_ptr, mask=m_mask[None, :], other=value)\n if top_k != -1:\n logits_ = tl.sort(logits, dim=1, descending=True)\n # triton doesn't support index, this is equivalent to logits_mask = logits >= logits_[:, top_k - 1]\n index = (\n tl.full([1, 1], 1, tl.int1) & (tl.arange(0, BLOCK_M) == top_k - 1)[None, :]\n )\n threshold = tl.sum(tl.where(index, logits_, 0))\n logits_mask = logits >= threshold\n logits = tl.where(logits_mask, logits, value)\n # use Gumbel Max to sample\n # sample from p1, ..., pk is equivalent to sample\n # argmax {log pi - log(-log(ui))} = argmax {logits - log(-log(ui))}, ui ~ U(0,1)\n # (1, 1)\n u = tl.rand(seed, rand_block_ptr)\n stat = logits - tl.log(-tl.log(u))\n # (1,)\n index = tl.argmax(stat, axis=1)\n\n # 1, 1\n sample_index_block_ptr = Sample + offset_sample + index[:, None]\n sample_out = tl.load(sample_index_block_ptr)\n tl.store(\n sample_out_block_ptr,\n sample_out.to(sample_out_block_ptr.dtype.element_ty),\n )\n\n\ndef gumbel_multinomial_reduce_triton(sample, lse, top_k=-1):\n \"\"\"\n sample: b k m\n lse: b m\n \"\"\"\n b, k, m = sample.shape\n\n def grid(meta):\n return (b, k)\n\n sample_out = torch.empty((b, k), dtype=torch.int32, device=sample.device)\n seed = 0\n BLOCK_M = next_power_of_two(m)\n\n _gumbel_multinomial_reduce_triton[grid](\n sample, lse, sample_out, seed, b, k, m, top_k, BLOCK_M\n )\n\n return sample_out.to(torch.int64)\n", - "description_1": "Use triton language to implement a kernel function '_gumbel_multinomial_reduce_triton' with 9 parameters: Sample (input tensor of shape b k m), Lse (input tensor of shape b m), Sample_out (output tensor of shape b k), seed (random seed), b (batch size as constexpr), k (number of categories as constexpr), m (number of samples as constexpr), top_k (top-k selection as constexpr), and BLOCK_M (block size as constexpr). The kernel performs Gumbel-Max sampling with optional top-k filtering. The function 'gumbel_multinomial_reduce_triton' is a wrapper that prepares inputs and calls the kernel with appropriate grid dimensions.", - "description_2": "Use triton language to create a Gumbel-Max sampling kernel with optional top-k filtering, and a wrapper function to execute it with specified grid dimensions.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef _online_multinomial_triton(\n X,\n W,\n Sample,\n Lse,\n Max_value,\n seed,\n load_lse: tl.constexpr,\n load_max_value: tl.constexpr,\n b: tl.constexpr,\n d: tl.constexpr,\n v: tl.constexpr,\n k: tl.constexpr, # num samples\n BLOCK_K: tl.constexpr,\n BLOCK_B: tl.constexpr,\n BLOCK_D: tl.constexpr,\n BLOCK_V: tl.constexpr,\n):\n off_b = tl.program_id(0)\n # compute offset\n offset_b = off_b * BLOCK_B\n offset_x = offset_b * d\n offset_sample = offset_b * k\n # mask\n b_mask = (offset_b + tl.arange(0, BLOCK_B)) < b\n k_mask = tl.arange(0, BLOCK_K) < k\n\n # BLOCK_B, k\n sample_block_ptr = (\n Sample\n + offset_sample\n + tl.arange(0, BLOCK_B)[:, None] * k\n + tl.arange(0, BLOCK_K)[None, :]\n )\n # for random\n # BLOCK_B, 1, k\n rand_block_ptr1 = tl.zeros([BLOCK_B, 1, BLOCK_K], dtype=tl.float32)\n # BLOCK_B, k\n rand_block_ptr2 = tl.zeros([BLOCK_B, BLOCK_K], dtype=tl.float32)\n\n value = -float(\"inf\")\n\n if load_lse:\n lse_ptr = Lse + offset_b + tl.arange(0, BLOCK_B)[:, None]\n lse = tl.load(lse_ptr, mask=b_mask, other=value)\n else:\n lse = tl.full([BLOCK_B, 1], value=value, dtype=tl.float32)\n\n if load_max_value:\n max_valuek_ptr = Max_value + offset_b + tl.arange(0, BLOCK_B)[:, None]\n max_value = tl.load(max_valuek_ptr, mask=b_mask, other=value)\n else:\n max_value = tl.full([BLOCK_B, 1], value=value, dtype=tl.float32)\n\n sample = tl.zeros([BLOCK_B, BLOCK_K], dtype=tl.int32)\n\n for i in range(tl.cdiv(v, BLOCK_V)):\n logits = tl.zeros([BLOCK_B, BLOCK_V], dtype=tl.float32)\n v_mask = (i * BLOCK_V + tl.arange(0, BLOCK_V)) < v\n\n # BLOCK_B, BLOCK_D\n x_block_ptr = (\n X\n + offset_x\n + tl.arange(0, BLOCK_B)[:, None] * d\n + tl.arange(0, BLOCK_D)[None, :]\n )\n\n # BLOCK_D, BLOCK_V\n w_block_ptr = (\n W\n + tl.arange(0, BLOCK_D)[:, None] * v\n + i * BLOCK_V\n + tl.arange(0, BLOCK_V)[None, :]\n )\n\n for j in range(tl.cdiv(d, BLOCK_D)):\n d_mask = (j * BLOCK_D + tl.arange(0, BLOCK_D)) < d\n x = tl.load(x_block_ptr, mask=b_mask[:, None] * d_mask[None, :], other=0)\n w = tl.load(w_block_ptr, mask=d_mask[:, None] * v_mask[None, :], other=0)\n logits += tl.dot(x, w)\n\n x_block_ptr += BLOCK_D\n w_block_ptr += BLOCK_D * v\n\n logits = tl.where(v_mask[None, :], logits, value)\n\n # sample by multinomial\n max_value_curr = tl.max(logits, axis=1)[:, None]\n numerator = tl.exp(logits - max_value_curr)\n denominator = tl.sum(numerator, axis=1)[:, None]\n # lse(x) = lse(x - a) + a\n lse_curr = tl.log(denominator) + max_value_curr\n prob_curr = numerator / denominator\n # BLOCK_B, BLOCK_V\n prob_cum_curr = tl.cumsum(prob_curr, axis=1)\n # sample by uniform\n # BLOCK_B, 1, k\n p = tl.rand(seed, rand_block_ptr1)\n # find k, such that p1 + ... + p(k-1) < p <= p1 + ... + pk\n # e.g.\n # prob = [0.1, 0.2, 0.6, 0.1], p = 0.35 => k = 2\n # prob_cum = [0.1, 0.3, 0.9, 1.0]\n # upper = [0, 0, 1, 1]\n # (BLOCK_B, BLOCK_V, k)\n upper = (prob_cum_curr[:, :, None] >= p).to(tl.int32)\n # (BLOCK_B, k)\n sample_curr = i * BLOCK_V + tl.argmax(upper, axis=1)\n\n # sample by binomial\n # m = max(ma, mb)\n # lse(a, b) = log(exp(lse(a)) + exp(lse(b))) = log(exp(lse(a) - m) + exp(lse(b) - m)) + m\n max_value = tl.where(max_value > max_value_curr, max_value, max_value_curr)\n lse = tl.log(tl.exp(lse - max_value) + tl.exp(lse_curr - max_value)) + max_value\n # BLOCK_B, 1\n prob = tl.exp(lse_curr - lse)\n # x = 1: sample_curr\n # x = 0: sample\n # BLOCK_B, k\n index = tl.rand(seed, rand_block_ptr2) < prob\n sample = tl.where(\n index,\n sample_curr,\n sample,\n )\n\n tl.store(\n sample_block_ptr,\n sample.to(sample_block_ptr.dtype.element_ty),\n mask=k_mask[None, :],\n )\n\ndef online_multinomial_triton(x, W, num_samples, lse=None, max_value=None):\n \"\"\"\n x: b d\n W: d v\n lse: b\n max_value: b\n \"\"\"\n b, d = x.shape\n d, v = W.shape\n sample = torch.empty((b, num_samples), dtype=torch.int32, device=x.device)\n load_lse = lse is not None\n load_max_value = max_value is not None\n BLOCK_K = max(16, next_power_of_two(num_samples))\n seed = 0\n\n def grid(meta):\n return (triton.cdiv(b, meta[\"BLOCK_B\"]),)\n\n _online_multinomial_triton[grid](\n x,\n W,\n sample,\n lse,\n max_value,\n seed,\n load_lse,\n load_max_value,\n b,\n d,\n v,\n num_samples,\n BLOCK_K,\n )\n\n return sample.to(torch.int64)\n", - "description_1": "Use triton language to implement a multinomial sampling kernel. The kernel '_online_multinomial_triton' takes 17 parameters: X (input tensor), W (weight tensor), Sample (output tensor), Lse (log-sum-exp tensor), Max_value (max value tensor), seed (random seed), load_lse (boolean to load Lse), load_max_value (boolean to load Max_value), b (batch size), d (dimension size), v (vocab size), k (number of samples), BLOCK_K, BLOCK_B, BLOCK_D, BLOCK_V (block sizes for kernel execution). The function 'online_multinomial_triton' is a wrapper that prepares the input tensors and calls the kernel with appropriate grid size.", - "description_2": "Use triton language to create a kernel for multinomial sampling with parameters for input, weights, output, and execution configuration, and a wrapper function to manage inputs and execute the kernel.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.autotune(\n configs=[\n triton.Config({\"BLOCK_B\": 32, \"BLOCK_D\": 32, \"num_warps\": 2}),\n triton.Config({\"BLOCK_B\": 64, \"BLOCK_D\": 64, \"num_warps\": 4}),\n triton.Config({\"BLOCK_B\": 128, \"BLOCK_D\": 128, \"num_warps\": 8}),\n ],\n key=[\"b\", \"d\", \"v\", \"k\"],\n)\n@triton.jit\ndef _parallel_gumbel_multinomial_triton(\n X,\n W,\n Sample,\n Lse,\n Lse_cache,\n seed,\n load_lse: tl.constexpr,\n b: tl.constexpr,\n d: tl.constexpr,\n v: tl.constexpr,\n k: tl.constexpr, # num samples\n top_k: tl.constexpr,\n BLOCK_V: tl.constexpr,\n NUM_BLOCK_V: tl.constexpr,\n BLOCK_K: tl.constexpr,\n BLOCK_B: tl.constexpr,\n BLOCK_D: tl.constexpr,\n):\n off_b = tl.program_id(0)\n off_v = tl.program_id(1)\n # compute offset\n offset_b = off_b * BLOCK_B\n offset_x = offset_b * d\n offset_v = off_v * BLOCK_V\n offset_sample = offset_b * NUM_BLOCK_V * k + off_v * k\n offset_lse = offset_b * NUM_BLOCK_V + off_v\n # mask\n b_mask = (offset_b + tl.arange(0, BLOCK_B)) < b\n\n # 1, BLOCK_K\n sample_block_ptr = (\n Sample\n + offset_sample\n + tl.arange(0, BLOCK_B)[:, None] * NUM_BLOCK_V * k\n + tl.arange(0, k)[None, :]\n )\n # BLOCK_B, BLOCK_D\n x_block_ptr = (\n X\n + offset_x\n + tl.arange(0, BLOCK_B)[:, None] * d\n + tl.arange(0, BLOCK_D)[None, :]\n )\n\n # BLOCK_D, BLOCK_V\n w_block_ptr = (\n W\n + offset_v\n + tl.arange(0, BLOCK_D)[:, None] * v\n + tl.arange(0, BLOCK_V)[None, :]\n )\n lse_cache_ptr = (\n Lse_cache + offset_lse + tl.arange(0, BLOCK_B)[:, None] * NUM_BLOCK_V\n )\n # for random\n # BLOCK_B, 1, k\n rand_block_ptr = tl.zeros([BLOCK_B, 1, k], dtype=tl.float32)\n\n value = -float(\"inf\")\n\n if load_lse:\n lse_ptr = Lse + offset_b + tl.arange(0, BLOCK_B)[:, None]\n lse = tl.load(lse_ptr, mask=b_mask, other=value)\n else:\n lse = tl.full([BLOCK_B, 1], value=value, dtype=tl.float32)\n\n logits = tl.zeros([BLOCK_B, BLOCK_V], dtype=tl.float32)\n v_mask = (offset_v + tl.arange(0, BLOCK_V)) < v\n\n for i in range(tl.cdiv(d, BLOCK_D)):\n d_mask = (i * BLOCK_D + tl.arange(0, BLOCK_D)) < d\n x = tl.load(x_block_ptr, mask=b_mask[:, None] & d_mask[None, :], other=0)\n w = tl.load(w_block_ptr, mask=d_mask[:, None] & v_mask[None, :], other=0)\n logits = tl.dot(x, w, logits)\n\n x_block_ptr += BLOCK_D\n w_block_ptr += BLOCK_D * v\n\n logits = tl.where(b_mask[:, None] & v_mask[None, :], logits, value)\n\n if top_k != -1:\n logits_ = tl.sort(logits, dim=1, descending=True)\n # triton doesn't support index, this is equivalent to logits_mask = logits >= logits_[:, top_k - 1]\n index = (\n tl.full([BLOCK_B, 1], 1, tl.int1)\n & (tl.arange(0, BLOCK_V) == top_k - 1)[None, :]\n )\n threshold = tl.sum(tl.where(index, logits_, 0))\n logits_mask = logits >= threshold\n logits = tl.where(logits_mask, logits, value)\n # use Gumbel Max to sample\n # sample from p1, ..., pk is equivalent to sample\n # argmax {log pi - log(-log(ui))} = argmax {logits - log(-log(ui))}, ui ~ U(0,1)\n # (BLOCK_B, 1, k)\n u = tl.rand(seed, rand_block_ptr)\n stat = logits[:, :, None] - tl.log(-tl.log(u))\n # (BLOCK_B, k)\n sample = offset_v + tl.argmax(stat, axis=1)\n\n # compute lse\n max_value = tl.max(logits, axis=1)[:, None]\n numerator = tl.exp(logits - max_value)\n denominator = tl.sum(numerator, axis=1)[:, None]\n # lse(x) = lse(x - a) + a\n lse = tl.log(denominator) + max_value\n\n tl.store(\n sample_block_ptr,\n sample.to(sample_block_ptr.dtype.element_ty),\n mask=b_mask[:, None],\n )\n tl.store(\n lse_cache_ptr, lse.to(lse_cache_ptr.dtype.element_ty), mask=b_mask[:, None]\n )\n\n\n@triton.autotune(\n configs=[\n triton.Config({\"num_warps\": 2}),\n triton.Config({\"num_warps\": 4}),\n triton.Config({\"num_warps\": 8}),\n ],\n key=[\"b\", \"NUM_BLOCK_V\", \"k\"],\n)\n@triton.jit\ndef _parallel_gumbel_multinomial_reduce_triton(\n Sample,\n Lse_cache,\n Sample_out,\n Lse_out,\n seed,\n output_lse: tl.constexpr,\n b: tl.constexpr,\n d: tl.constexpr,\n v: tl.constexpr,\n k: tl.constexpr, # num samples\n top_k: tl.constexpr,\n BLOCK_V: tl.constexpr,\n NUM_BLOCK_V: tl.constexpr,\n NUM_BLOCK_V_PAD: tl.constexpr,\n BLOCK_K: tl.constexpr,\n):\n off_b = tl.program_id(0)\n off_k = tl.program_id(1)\n # compute offset\n offset_b = off_b\n offset_k = off_k\n offset_sample = offset_b * NUM_BLOCK_V * k + offset_k\n offset_sample_out = offset_b * k + offset_k\n offset_lse = offset_b * NUM_BLOCK_V\n # mask\n num_block_v_mask = tl.arange(0, NUM_BLOCK_V_PAD) < NUM_BLOCK_V\n\n # 1, 1\n sample_out_block_ptr = Sample_out + offset_sample_out + tl.arange(0, 1)[:, None] * k\n # 1, NUM_BLOCK_V\n lse_cache_ptr = Lse_cache + offset_lse + tl.arange(0, NUM_BLOCK_V_PAD)[None, :]\n # for random\n # 1, 1\n rand_block_ptr = tl.zeros([1, 1], dtype=tl.float32)\n\n value = -float(\"inf\")\n\n logits = tl.load(lse_cache_ptr, mask=num_block_v_mask[None, :], other=value)\n if top_k != -1:\n logits_ = tl.sort(logits, dim=1, descending=True)\n # triton doesn't support index, this is equivalent to logits_mask = logits >= logits_[:, top_k - 1]\n index = (\n tl.full([1, 1], 1, tl.int1)\n & (tl.arange(0, NUM_BLOCK_V_PAD) == top_k - 1)[None, :]\n )\n threshold = tl.sum(tl.where(index, logits_, 0))\n logits_mask = logits >= threshold\n logits = tl.where(logits_mask, logits, value)\n # use Gumbel Max to sample\n # sample from p1, ..., pk is equivalent to sample\n # argmax {log pi - log(-log(ui))} = argmax {logits - log(-log(ui))}, ui ~ U(0,1)\n # (1, 1)\n u = tl.rand(seed, rand_block_ptr)\n stat = logits - tl.log(-tl.log(u))\n # (1,)\n index = tl.argmax(stat, axis=1)\n\n # 1, 1\n sample_index_block_ptr = Sample + offset_sample + index[:, None] * k\n sample_out = tl.load(sample_index_block_ptr)\n tl.store(\n sample_out_block_ptr,\n sample_out.to(sample_out_block_ptr.dtype.element_ty),\n )\n\n if output_lse: # only save once\n if off_k == 0: # work around compiler bug\n lse_out_block_ptr = Lse_out + offset_b + tl.arange(0, 1)[:, None]\n max_value = tl.max(logits, axis=1)[:, None]\n numerator = tl.exp(logits - max_value)\n denominator = tl.sum(numerator, axis=1)[:, None]\n # lse(x) = lse(x - a) + a\n lse = tl.log(denominator) + max_value\n tl.store(lse_out_block_ptr, lse.to(lse_out_block_ptr.dtype.element_ty))\n\n\ndef parallel_gumbel_multinomial_triton(\n x, W, num_samples=1, lse=None, output_lse=False, top_k=-1\n):\n \"\"\"\n x: b d or b 1 d\n W: d v\n lse: b\n max_value: b\n \"\"\"\n assert top_k in [-1, 1], \"top_k should be -1 or 1\"\n b = x.shape[0]\n d = x.shape[1]\n d, v = W.shape\n x = x.contiguous()\n W = W.contiguous()\n\n # BLOCK_V = min(128, v)\n BLOCK_V = 128\n NUM_BLOCK_V = (v + BLOCK_V - 1) // BLOCK_V\n sample = torch.empty(\n (b, NUM_BLOCK_V, num_samples), dtype=torch.int32, device=x.device\n )\n lse_cache = torch.empty((b, NUM_BLOCK_V), dtype=torch.float32, device=x.device)\n\n load_lse = lse is not None\n BLOCK_K = max(16, next_power_of_two(num_samples))\n seed = 0\n\n def grid(meta):\n return (triton.cdiv(b, meta[\"BLOCK_B\"]), triton.cdiv(v, BLOCK_V))\n\n _parallel_gumbel_multinomial_triton[grid](\n x,\n W,\n sample,\n lse,\n lse_cache,\n seed,\n load_lse,\n b,\n d,\n v,\n num_samples,\n top_k,\n BLOCK_V,\n NUM_BLOCK_V,\n BLOCK_K,\n )\n\n def grid(meta):\n return (b, num_samples)\n\n sample_out = torch.empty((b, num_samples), dtype=torch.int32, device=x.device)\n lse_out = torch.empty((b, 1), dtype=torch.float32, device=x.device)\n\n NUM_BLOCK_V_PAD = next_power_of_two(NUM_BLOCK_V)\n _parallel_gumbel_multinomial_reduce_triton[grid](\n sample,\n lse_cache,\n sample_out,\n lse_out,\n seed,\n output_lse,\n b,\n d,\n v,\n num_samples,\n top_k,\n BLOCK_V,\n NUM_BLOCK_V,\n NUM_BLOCK_V_PAD,\n BLOCK_K,\n )\n\n return sample_out.to(torch.int64), lse_out\n", - "description_1": "Use triton language to create two kernels: (1) '_parallel_gumbel_multinomial_triton' which performs parallel sampling from a multinomial distribution using a Gumbel-max trick, taking in parameters like tensors X, W, Sample, and Lse, constants for block dimensions, and configuration settings like the number of samples (k), and (2) '_parallel_gumbel_multinomial_reduce_triton' which reduces the sampled values and possibly outputs the log-sum-exp of the distribution. The main function 'parallel_gumbel_multinomial_triton' calls these kernels to perform the sampling and reduction tasks.", - "description_2": "Use triton language to implement parallel multinomial sampling using the Gumbel-max trick with kernels for sampling and reducing sample results, employing autotuning for optimization.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.autotune(\n {\n \"BLOCK_B\": [32, 64, 128],\n \"BLOCK_D\": [32, 64, 128],\n \"num_warps\": [2, 4, 8],\n },\n key=[\"b\", \"d\", \"v\", \"k\"],\n)\n@triton.jit\ndef _parallel_multinomial_triton(\n X,\n W,\n Sample,\n Lse,\n Max_value,\n Lse_cache,\n Max_value_cache,\n seed,\n load_lse: tl.constexpr,\n load_max_value: tl.constexpr,\n b: tl.constexpr,\n d: tl.constexpr,\n v: tl.constexpr,\n k: tl.constexpr,\n BLOCK_V: tl.constexpr,\n NUM_BLOCK_V: tl.constexpr,\n BLOCK_K: tl.constexpr,\n BLOCK_B: tl.constexpr,\n BLOCK_D: tl.constexpr,\n):\n off_b = tl.program_id(0)\n off_v = tl.program_id(1)\n offset_b = off_b * BLOCK_B\n offset_x = offset_b * d\n offset_v = off_v * BLOCK_V\n offset_sample = offset_b * NUM_BLOCK_V * k + off_v * k\n offset_lse_max_value = offset_b * NUM_BLOCK_V + off_v\n b_mask = (offset_b + tl.arange(0, BLOCK_B)) < b\n\n sample_block_ptr = (\n Sample\n + offset_sample\n + tl.arange(0, BLOCK_B)[:, None] * NUM_BLOCK_V * k\n + tl.arange(0, k)[None, :]\n )\n x_block_ptr = (\n X\n + offset_x\n + tl.arange(0, BLOCK_B)[:, None] * d\n + tl.arange(0, BLOCK_D)[None, :]\n )\n w_block_ptr = (\n W\n + offset_v\n + tl.arange(0, BLOCK_D)[:, None] * v\n + tl.arange(0, BLOCK_V)[None, :]\n )\n lse_cache_ptr = (\n Lse_cache + offset_lse_max_value + tl.arange(0, BLOCK_B)[:, None] * NUM_BLOCK_V\n )\n max_value_cache_ptr = (\n Max_value_cache\n + offset_lse_max_value\n + tl.arange(0, BLOCK_B)[:, None] * NUM_BLOCK_V\n )\n\n rand_block_ptr = tl.zeros([BLOCK_B, 1, k], dtype=tl.float32)\n value = -float(\"inf\")\n\n if load_lse:\n lse_ptr = Lse + offset_b + tl.arange(0, BLOCK_B)[:, None]\n lse = tl.load(lse_ptr, mask=b_mask, other=value)\n else:\n lse = tl.full([BLOCK_B, 1], value=value, dtype=tl.float32)\n\n if load_max_value:\n max_valuek_ptr = Max_value + offset_b + tl.arange(0, BLOCK_B)[:, None]\n max_value = tl.load(max_valuek_ptr, mask=b_mask, other=value)\n else:\n max_value = tl.full([BLOCK_B, 1], value=value, dtype=tl.float32)\n\n logits = tl.zeros([BLOCK_B, BLOCK_V], dtype=tl.float32)\n v_mask = (offset_v + tl.arange(0, BLOCK_V)) < v\n\n for i in range(tl.cdiv(d, BLOCK_D)):\n d_mask = (i * BLOCK_D + tl.arange(0, BLOCK_D)) < d\n x = tl.load(x_block_ptr, mask=b_mask[:, None] * d_mask[None, :], other=0)\n w = tl.load(w_block_ptr, mask=d_mask[:, None] * v_mask[None, :], other=0)\n logits = tl.dot(x, w, logits)\n\n x_block_ptr += BLOCK_D\n w_block_ptr += BLOCK_D * v\n\n logits = tl.where(v_mask[None, :], logits, value)\n\n max_value_curr = tl.max(logits, axis=1)[:, None]\n numerator = tl.exp(logits - max_value_curr)\n denominator = tl.sum(numerator, axis=1)[:, None]\n lse_curr = tl.log(denominator) + max_value_curr\n prob_curr = numerator / denominator\n prob_cum_curr = tl.cumsum(prob_curr, axis=1)\n\n p = tl.rand(seed, rand_block_ptr)\n upper = (prob_cum_curr[:, :, None] >= p).to(tl.int32)\n sample = offset_v + tl.argmax(upper, axis=1)\n\n tl.store(\n sample_block_ptr,\n sample.to(sample_block_ptr.dtype.element_ty),\n mask=b_mask[:, None],\n )\n tl.store(\n lse_cache_ptr, lse_curr.to(lse_cache_ptr.dtype.element_ty), mask=b_mask[:, None]\n )\n tl.store(\n max_value_cache_ptr,\n max_value_curr.to(max_value_cache_ptr.dtype.element_ty),\n mask=b_mask[:, None],\n )\n\n\n@triton.autotune(\n {\n \"BLOCK_B\": [32, 64, 128],\n \"num_warps\": [2, 4, 8],\n },\n key=[\"b\", \"NUM_BLOCK_V\", \"k\"],\n)\n@triton.jit\ndef _parallel_multinomial_reduce_triton(\n Sample,\n Lse_cache,\n Max_value_cache,\n Sample_out,\n seed,\n load_lse: tl.constexpr,\n load_max_value: tl.constexpr,\n b: tl.constexpr,\n d: tl.constexpr,\n v: tl.constexpr,\n k: tl.constexpr,\n BLOCK_V: tl.constexpr,\n NUM_BLOCK_V: tl.constexpr,\n BLOCK_K: tl.constexpr,\n BLOCK_B: tl.constexpr,\n):\n off_b = tl.program_id(0)\n off_k = tl.program_id(1)\n offset_b = off_b * BLOCK_B\n offset_k = off_k\n offset_sample = offset_b * NUM_BLOCK_V * k + offset_k\n offset_sample_out = offset_b * k + offset_k\n offset_lse_max_value = offset_b * NUM_BLOCK_V\n b_mask = (offset_b + tl.arange(0, BLOCK_B)) < b\n\n sample_out_block_ptr = (\n Sample_out + offset_sample_out + tl.arange(0, BLOCK_B)[:, None] * k\n )\n lse_cache_ptr = (\n Lse_cache\n + offset_lse_max_value\n + tl.arange(0, BLOCK_B)[:, None] * NUM_BLOCK_V\n + tl.arange(0, NUM_BLOCK_V)[None, :]\n )\n max_value_cache_ptr = (\n Max_value_cache\n + offset_lse_max_value\n + tl.arange(0, BLOCK_B)[:, None] * NUM_BLOCK_V\n + tl.arange(0, NUM_BLOCK_V)[None, :]\n )\n\n rand_block_ptr = tl.zeros([BLOCK_B, 1], dtype=tl.float32)\n value = -float(\"inf\")\n\n logits = tl.load(lse_cache_ptr, mask=b_mask[:, None], other=value)\n max_value = tl.load(max_value_cache_ptr, mask=b_mask[:, None], other=value)\n\n max_value_curr = tl.max(logits, axis=1)[:, None]\n numerator = tl.exp(logits - max_value_curr)\n denominator = tl.sum(numerator, axis=1)[:, None]\n prob_curr = numerator / denominator\n prob_cum_curr = tl.cumsum(prob_curr, axis=1)\n\n p = tl.rand(seed, rand_block_ptr)\n upper = (prob_cum_curr >= p).to(tl.int32)\n index = tl.argmax(upper, axis=1)\n\n sample_index_block_ptr = Sample + offset_sample + index[:, None] * k\n sample_out = tl.load(sample_index_block_ptr, mask=b_mask[:, None])\n\n tl.store(\n sample_out_block_ptr,\n sample_out.to(sample_out_block_ptr.dtype.element_ty),\n mask=b_mask[:, None],\n )\n\n\ndef parallel_multinomial_triton(x, W, num_samples, lse=None, max_value=None):\n b, d = x.shape\n d, v = W.shape\n\n BLOCK_V = 128\n NUM_BLOCK_V = (v + BLOCK_V - 1) // BLOCK_V\n sample = torch.empty(\n (b, NUM_BLOCK_V, num_samples), dtype=torch.int32, device=x.device\n )\n lse_cache = torch.empty((b, NUM_BLOCK_V), dtype=torch.float32, device=x.device)\n max_value_cache = torch.empty(\n (b, NUM_BLOCK_V), dtype=torch.float32, device=x.device\n )\n\n load_lse = lse is not None\n load_max_value = max_value is not None\n BLOCK_K = max(16, next_power_of_two(num_samples))\n seed = 0\n\n def grid(meta):\n return (triton.cdiv(b, meta[\"BLOCK_B\"]), triton.cdiv(v, BLOCK_V))\n\n _parallel_multinomial_triton[grid](\n x,\n W,\n sample,\n lse,\n max_value,\n lse_cache,\n max_value_cache,\n seed,\n load_lse,\n load_max_value,\n b,\n d,\n v,\n num_samples,\n BLOCK_V,\n NUM_BLOCK_V,\n BLOCK_K,\n )\n\n def grid(meta):\n return (triton.cdiv(b, meta[\"BLOCK_B\"]), num_samples)\n\n sample_out = torch.empty((b, num_samples), dtype=torch.int32, device=x.device)\n\n _parallel_multinomial_reduce_triton[grid](\n sample,\n lse_cache,\n max_value_cache,\n sample_out,\n seed,\n load_lse,\n load_max_value,\n b,\n d,\n v,\n num_samples,\n BLOCK_V,\n NUM_BLOCK_V,\n BLOCK_K,\n )\n\n return sample_out.to(torch.int64)\n", - "description_1": "Use triton language to implement two parallel multinomial sampling kernels. The first kernel '_parallel_multinomial_triton' samples using input tensors and stores intermediate results, while the second kernel '_parallel_multinomial_reduce_triton' reduces these intermediate results to generate final samples. Both kernels handle input tensors with a specified number of blocks and perform multinomial sampling with optional log-sum-exp calculations and maximum value caching for optimization.", - "description_2": "Use triton language to create two kernels that perform multinomial sampling in parallel, using configurable blocks and handling intermediate computations for optimized sampling.", - "difficulty": 4 - }, - { - "code": "import triton\nimport triton.language as tl\n\ninv_ln2 = 1.44269504\n\n# Kernel for forward decay cumulative sum\n@triton.jit\ndef fwd_decay_cumsum(\n g, g_o, s_qk_h, s_qk_t, s_qk_d, B, H, T, scale, BT: tl.constexpr, BK: tl.constexpr, DK: tl.constexpr\n):\n i_k, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n p_g = g + i_bh * s_qk_h + i_c * BT * DK + i_k * BK + tl.arange(0, BK)\n p_go = g_o + i_bh * s_qk_h + i_c * BT * DK + i_k * BK + tl.arange(0, BK)\n cum_decay = tl.zeros([BK], dtype=tl.float32)\n mask = (i_k * BK + tl.arange(0, BK)) < DK\n\n for i in range(BT):\n _g = tl.load(p_g, mask=mask, other=0).to(tl.float32)\n cum_decay += _g * inv_ln2\n tl.store(p_go, cum_decay.to(p_go.dtype.element_ty), mask=mask)\n p_g += DK\n p_go += DK\n\n# Kernel to prepare qg and kg\n@triton.jit\ndef prepare_qg_kg(\n q, k, g, qg, kg, s_qk_h, s_qk_t, s_qk_d, B, H, T, scale, BT: tl.constexpr, BK: tl.constexpr, DK: tl.constexpr\n):\n i_k, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n p_q = q + i_bh * s_qk_h + i_c * BT * DK + i_k * BK + tl.arange(0, BK)\n p_g = g + i_bh * s_qk_h + i_c * BT * DK + i_k * BK + tl.arange(0, BK)\n p_k = k + i_bh * s_qk_h + i_c * BT * DK + i_k * BK + tl.arange(0, BK)\n p_qg = qg + i_bh * s_qk_h + i_c * BT * DK + i_k * BK + tl.arange(0, BK)\n p_kg = kg + i_bh * s_qk_h + i_c * BT * DK + i_k * BK + tl.arange(0, BK)\n\n mask = (i_k * BK + tl.arange(0, BK)) < DK\n\n last_decay = tl.load(\n g + i_bh * s_qk_h + (i_c * BT + BT - 1) * DK + i_k * BK + tl.arange(0, BK)\n )\n\n for i in range(BT):\n _q = tl.load(p_q, mask=mask, other=0)\n _k = tl.load(p_k, mask=mask, other=0)\n _g = tl.load(p_g, mask=mask, other=0).to(tl.float32)\n _q *= tl.math.exp2(_g) * scale\n _k *= tl.math.exp2(last_decay - _g)\n tl.store(p_kg, _k.to(p_kg.dtype.element_ty), mask=mask)\n tl.store(p_qg, _q.to(p_qg.dtype.element_ty), mask=mask)\n p_q += DK\n p_g += DK\n p_k += DK\n p_kg += DK\n p_qg += DK\n\n# Kernel for backward decay global cumulative sum\n@triton.jit\ndef bwd_decay_global_cumsum(\n dq_inner, dq_inter, dk_inner, dk_inter, q, k, g, dg, s_qk_h, s_qk_t, s_qk_d, B, H, T, scale, BT: tl.constexpr, BK: tl.constexpr, DK: tl.constexpr\n):\n i_k, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n p_q = q + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (i_c * BT + BT - 1) * DK\n p_k = k + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (i_c * BT + BT - 1) * DK\n p_g = g + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (i_c * BT + BT - 1) * DK\n p_dg = dg + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (i_c * BT + BT - 1) * DK\n p_dq_inner = (\n dq_inner\n + i_bh * s_qk_h\n + i_k * BK\n + tl.arange(0, BK)\n + (i_c * BT + BT - 1) * DK\n )\n p_dk_inner = (\n dk_inner\n + i_bh * s_qk_h\n + i_k * BK\n + tl.arange(0, BK)\n + (i_c * BT + BT - 1) * DK\n )\n p_dq_inter = (\n dq_inter\n + i_bh * s_qk_h\n + i_k * BK\n + tl.arange(0, BK)\n + (i_c * BT + BT - 1) * DK\n )\n p_dk_inter = (\n dk_inter\n + i_bh * s_qk_h\n + i_k * BK\n + tl.arange(0, BK)\n + (i_c * BT + BT - 1) * DK\n )\n cum_grad_dg = tl.zeros([BK], dtype=tl.float32)\n mask = (i_k * BK + tl.arange(0, BK)) < DK\n last_g = tl.zeros([BK], dtype=tl.float32)\n for j in range(BT - 1, -1, -1):\n _g = tl.load(p_g, mask=mask, other=0).to(tl.float32)\n if j == (BT - 1):\n last_g = _g\n _dq1 = tl.load(p_dq_inner, mask=mask, other=0)\n _dq2 = tl.load(p_dq_inter, mask=mask, other=0)\n _dq2 *= tl.math.exp2(_g)\n _dq = _dq1 + _dq2\n tl.store(p_dq_inter, _dq, mask=mask)\n _dk1 = tl.load(p_dk_inner, mask=mask, other=0)\n _dk2 = tl.load(p_dk_inter, mask=mask, other=0)\n _dk2 *= tl.math.exp2(last_g - _g)\n _dk = _dk1 + _dk2\n tl.store(p_dk_inter, _dk, mask=mask)\n _q = tl.load(p_q, mask=mask, other=0)\n _k = tl.load(p_k, mask=mask, other=0)\n _dg = _dq * _q - _dk * _k\n cum_grad_dg += _dg\n tl.store(p_dg, cum_grad_dg.to(p_dg.dtype.element_ty), mask=mask)\n p_g -= DK\n p_k -= DK\n p_q -= DK\n p_dq_inner -= DK\n p_dk_inner -= DK\n p_dq_inter -= DK\n p_dk_inter -= DK\n p_dg -= DK\n", - "description_1": "Use triton language to implement three kernels: fwd_decay_cumsum, prepare_qg_kg, and bwd_decay_global_cumsum. The fwd_decay_cumsum kernel computes a forward decay cumulative sum with 12 parameters, including input tensors and dimensions. The prepare_qg_kg kernel prepares qg and kg tensors with 12 parameters, including input tensors and dimensions. The bwd_decay_global_cumsum kernel computes a backward decay global cumulative sum with 15 parameters, including input tensors and dimensions.", - "description_2": "Use triton language to implement kernels for forward and backward decay cumulative sums and tensor preparation with specified input tensors and dimensions.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef fused_chunk_gla_fwd_kernel(\n q, k, v, g, o, initial_state, final_state, s_qk_h, s_qk_t, s_qk_d,\n s_vo_h, s_vo_t, s_vo_d, B, H, T, scale, BT: tl.constexpr, BK: tl.constexpr,\n BV: tl.constexpr, DK: tl.constexpr, DV: tl.constexpr, USE_INITIAL_STATE: tl.constexpr,\n STORE_FINAL_STATE: tl.constexpr, CHECK: tl.constexpr,\n):\n i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n b_h = tl.zeros([BK, BV], dtype=tl.float32)\n\n p_q = tl.make_block_ptr(\n q + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (0, i_k * BK), (BT, BK), (1, 0)\n )\n p_db = g + i_bh * s_qk_h + (BT - 1) * s_qk_t + i_k * BK + tl.arange(0, BK)\n p_k = tl.make_block_ptr(\n k + i_bh * s_qk_h, (DK, T), (s_qk_d, s_qk_t), (i_k * BK, 0), (BK, BT), (0, 1)\n )\n p_v = tl.make_block_ptr(\n v + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (0, i_v * BV), (BT, BV), (1, 0)\n )\n p_o = tl.make_block_ptr(\n o + (i_bh + i_k * B * H) * s_vo_h,\n (T, DV),\n (s_vo_t, s_vo_d),\n (0, i_v * BV),\n (BT, BV),\n (1, 0),\n )\n\n if USE_INITIAL_STATE:\n p_h = tl.make_block_ptr(\n initial_state + i_bh * DK * DV,\n (DK, DV),\n (DV, 1),\n (i_k * BK, i_v * BV),\n (BK, BV),\n (1, 0),\n )\n b_h += tl.load(p_h, boundary_check=(0, 1)).to(tl.float32)\n\n mask = (i_k * BK + tl.arange(0, BK)) < DK\n\n for i in range(0, tl.cdiv(T, BT)):\n b_k = tl.load(p_k, boundary_check=(0, 1))\n b_o = tl.zeros([BT, BV], dtype=tl.float32)\n b_v = tl.load(p_v, boundary_check=(0, 1))\n b_q = tl.load(p_q, boundary_check=(0, 1))\n d_b = tl.load(p_db, mask=mask, other=0).to(tl.float32)\n if CHECK and i == 0:\n b_o = tl.dot(b_q.to(b_v.dtype), b_h.to(b_v.dtype), allow_tf32=False)\n b_h = b_h * tl.math.exp2(d_b)[:, None] + tl.dot(\n b_k.to(b_v.dtype), b_v, allow_tf32=False\n )\n else:\n b_o = tl.dot(b_q.to(b_v.dtype), b_h.to(b_v.dtype), allow_tf32=False)\n b_h = b_h * tl.math.exp2(d_b)[:, None] + tl.dot(\n b_k.to(b_v.dtype), b_v, allow_tf32=False\n )\n\n tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))\n p_q = tl.advance(p_q, (BT, 0))\n p_k = tl.advance(p_k, (0, BT))\n p_v = tl.advance(p_v, (BT, 0))\n p_o = tl.advance(p_o, (BT, 0))\n p_db += BT * DK\n\n if STORE_FINAL_STATE:\n p_final = tl.make_block_ptr(\n final_state + i_bh * DK * DV,\n (DK, DV),\n (DV, 1),\n (i_k * BK, i_v * BV),\n (BK, BV),\n (1, 0),\n )\n tl.store(p_final, b_h.to(p_final.dtype.element_ty), boundary_check=(0, 1))\n\n\n@triton.jit\ndef fused_chunk_gla_bwd_kernel(\n q, k, v, g, do, dq, dk, dv, initial_state, s_qk_h, s_qk_t, s_qk_d,\n s_vo_h, s_vo_t, s_vo_d, B, H, T, scale, BT: tl.constexpr, BK: tl.constexpr,\n BV: tl.constexpr, DK: tl.constexpr, DV: tl.constexpr, USE_INITIAL_STATE: tl.constexpr,\n CHECK: tl.constexpr,\n):\n i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n b_h = tl.zeros([BV, BK], dtype=tl.float32)\n\n if USE_INITIAL_STATE:\n p_h = tl.make_block_ptr(\n initial_state + i_bh * DK * DV,\n (DV, DK),\n (1, DV),\n (i_v * BV, i_k * BK),\n (BV, BK),\n (0, 1),\n )\n b_h += tl.load(p_h, boundary_check=(0, 1)).to(tl.float32)\n\n mask = (i_k * BK + tl.arange(0, BK)) < DK\n for i in range(0, tl.cdiv(T, BT)):\n p_k = tl.make_block_ptr(\n k + i_bh * s_qk_h,\n (T, DK),\n (s_qk_t, s_qk_d),\n (i * BT, i_k * BK),\n (BT, BK),\n (1, 0),\n )\n p_db = (\n g\n + i_bh * s_qk_h\n + ((i + 1) * BT - 1) * s_qk_t\n + i_k * BK\n + tl.arange(0, BK)\n )\n p_v = tl.make_block_ptr(\n v + i_bh * s_vo_h,\n (DV, T),\n (s_vo_d, s_vo_t),\n (i_v * BV, i * BT),\n (BV, BT),\n (0, 1),\n )\n p_do = tl.make_block_ptr(\n do + i_bh * s_vo_h,\n (T, DV),\n (s_vo_t, s_vo_d),\n (i * BT, i_v * BV),\n (BT, BV),\n (1, 0),\n )\n p_dq = tl.make_block_ptr(\n dq + (i_bh + i_v * B * H) * s_qk_h,\n (T, DK),\n (s_qk_t, s_qk_d),\n (i * BT, i_k * BK),\n (BT, BK),\n (1, 0),\n )\n b_dq = tl.zeros([BT, BK], dtype=tl.float32)\n b_k = tl.load(p_k, boundary_check=(0, 1))\n d_b = tl.load(p_db, mask=mask, other=0).to(tl.float32)\n b_v = tl.load(p_v, boundary_check=(0, 1))\n b_do = tl.load(p_do, boundary_check=(0, 1))\n if CHECK and i == 0:\n b_dq += tl.dot(b_do, b_h.to(b_do.dtype), allow_tf32=False)\n b_h = b_h * tl.math.exp2(d_b)[None, :] + tl.dot(\n b_v, b_k.to(b_v.dtype), allow_tf32=False\n )\n else:\n b_dq += tl.dot(b_do, b_h.to(b_do.dtype), allow_tf32=False)\n b_h = b_h * tl.math.exp2(d_b)[None, :] + tl.dot(\n b_v, b_k.to(b_v.dtype), allow_tf32=False\n )\n b_dq *= scale\n tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))\n\n b_h = None\n tl.debug_barrier()\n b_dh = tl.zeros([BK, BV], dtype=tl.float32)\n\n for i in range(1, tl.cdiv(T, BT) + 1):\n p_q = tl.make_block_ptr(\n q + i_bh * s_qk_h,\n (DK, T),\n (s_qk_d, s_qk_t),\n (i_k * BK, T - i * BT),\n (BK, BT),\n (0, 1),\n )\n p_k = tl.make_block_ptr(\n k + i_bh * s_qk_h,\n (T, DK),\n (s_qk_t, s_qk_d),\n (T - i * BT, i_k * BK),\n (BT, BK),\n (1, 0),\n )\n p_db = (\n g\n + i_bh * s_qk_h\n + (T - (i - 1) * BT - 1) * s_qk_t\n + i_k * BK\n + tl.arange(0, BK)\n )\n p_v = tl.make_block_ptr(\n v + i_bh * s_vo_h,\n (T, DV),\n (s_vo_t, s_vo_d),\n (T - i * BT, i_v * BV),\n (BT, BV),\n (1, 0),\n )\n p_do = tl.make_block_ptr(\n do + i_bh * s_vo_h,\n (T, DV),\n (s_vo_t, s_vo_d),\n (T - i * BT, i_v * BV),\n (BT, BV),\n (1, 0),\n )\n p_dk = tl.make_block_ptr(\n dk + (i_bh + i_v * B * H) * s_qk_h,\n (T, DK),\n (s_qk_t, s_qk_d),\n (T - i * BT, i_k * BK),\n (BT, BK),\n (1, 0),\n )\n p_dv = tl.make_block_ptr(\n dv + (i_bh + i_k * B * H) * s_vo_h,\n (T, DV),\n (s_vo_t, s_vo_d),\n (T - i * BT, i_v * BV),\n (BT, BV),\n (1, 0),\n )\n b_q = tl.load(p_q, boundary_check=(0, 1))\n b_k = tl.load(p_k, boundary_check=(0, 1))\n b_v = tl.load(p_v, boundary_check=(0, 1))\n b_do = tl.load(p_do, boundary_check=(0, 1))\n b_db = tl.load(p_db, mask=mask, other=0).to(tl.float32)\n\n if CHECK and i == 1:\n b_dk = tl.trans(tl.dot(b_dh.to(b_v.dtype), tl.trans(b_v), allow_tf32=False))\n b_dv = tl.dot((b_k).to(b_v.dtype), b_dh.to(b_v.dtype), allow_tf32=False)\n b_dh = b_dh * tl.math.exp2(b_db)[:, None] + tl.dot(\n b_q.to(b_do.dtype), b_do, allow_tf32=False\n )\n else:\n b_dk = tl.trans(tl.dot(b_dh.to(b_v.dtype), tl.trans(b_v), allow_tf32=False))\n b_dv = tl.dot((b_k).to(b_v.dtype), b_dh.to(b_v.dtype), allow_tf32=False)\n b_dh = b_dh * tl.math.exp2(b_db)[:, None] + tl.dot(\n b_q.to(b_do.dtype), b_do, allow_tf32=False\n )\n\n tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))\n tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))\n\n\nclass FusedChunkGLAFunction(torch.autograd.Function):\n @staticmethod\n def forward(ctx, q, k, v, g, scale, initial_state, output_final_state):\n ctx.g_dtype = g.dtype\n g_original = g\n g = torch.empty_like(g, dtype=torch.float32)\n batch_size, n_heads, seq_len, d_head_qk = q.shape\n d_head_v = v.shape[-1]\n ctx.scale = scale\n\n BT = 16\n BK, BV = min(d_head_qk, 64), min(d_head_v, 64)\n num_stages = 1\n num_warps = 2\n\n NK, NV = triton.cdiv(d_head_qk, BK), triton.cdiv(d_head_v, BV)\n o = q.new_empty(NK, batch_size, n_heads, seq_len, d_head_v)\n q_g = torch.empty_like(q)\n k_g = torch.empty_like(k)\n grid = (NK, triton.cdiv(seq_len, BT), batch_size * n_heads)\n\n if output_final_state:\n final_state = q.new_empty(\n batch_size,\n n_heads,\n d_head_qk,\n d_head_v,\n dtype=torch.float,\n requires_grad=False,\n )\n else:\n final_state = None\n\n CHECK = True\n grid = (NV, NK, batch_size * n_heads)\n fused_chunk_gla_fwd_kernel[grid](\n q_g,\n k_g,\n v,\n g,\n o,\n initial_state,\n final_state,\n q.stride(1),\n q.stride(2),\n q.stride(3),\n v.stride(1),\n v.stride(2),\n v.stride(3),\n batch_size,\n n_heads,\n seq_len,\n scale,\n BT=BT,\n DK=d_head_qk,\n DV=d_head_v,\n BK=BK,\n BV=BV,\n USE_INITIAL_STATE=initial_state is not None,\n STORE_FINAL_STATE=output_final_state,\n CHECK=CHECK,\n num_warps=num_warps,\n num_stages=num_stages,\n )\n\n o = o.sum(0)\n\n ctx.save_for_backward(q, k, v, g_original, initial_state)\n ctx.CHECK = CHECK\n return o.to(v), final_state\n\n @staticmethod\n def backward(ctx, do, d_final_state=None):\n q, k, v, g_origin, initial_state = ctx.saved_tensors\n batch_size, n_heads, seq_len, d_head_qk = q.shape\n d_head_v = v.shape[-1]\n scale = ctx.scale\n\n BT = 16\n g = torch.empty_like(g_origin, dtype=torch.float32)\n BK, BV = min(d_head_qk, 64), min(d_head_v, 64)\n NK, NV = triton.cdiv(d_head_qk, BK), triton.cdiv(d_head_v, BV)\n q_g = torch.empty_like(q)\n k_g = torch.empty_like(k)\n grid = (NK, triton.cdiv(seq_len, BT), batch_size * n_heads)\n\n dq = q.new_empty(NV, batch_size, n_heads, seq_len, d_head_qk)\n dk = q.new_empty(NV, batch_size, n_heads, seq_len, d_head_qk)\n dv = q.new_empty(NK, batch_size, n_heads, seq_len, d_head_v)\n\n grid = (NV, NK, batch_size * n_heads)\n\n fused_chunk_gla_bwd_kernel[grid](\n q_g,\n k_g,\n v,\n g,\n do,\n dq,\n dk,\n dv,\n initial_state,\n q.stride(1),\n q.stride(2),\n q.stride(3),\n v.stride(1),\n v.stride(2),\n v.stride(3),\n batch_size,\n n_heads,\n seq_len,\n scale,\n BT=BT,\n DK=d_head_qk,\n DV=d_head_v,\n BK=BK,\n BV=BV,\n USE_INITIAL_STATE=initial_state is not None,\n CHECK=ctx.CHECK,\n num_warps=num_warps,\n num_stages=num_stages,\n )\n dq = dq.sum(0)\n dk = dk.sum(0)\n dv = dv.sum(0)\n\n return dq.to(q), dk.to(k), dv.to(v), None, None, None, None\n\n\ndef fused_chunk_gla(\n q: torch.Tensor,\n k: torch.Tensor,\n v: torch.Tensor,\n g: torch.Tensor,\n scale: int = -1,\n initial_state: torch.Tensor = None,\n output_final_state: bool = False,\n) -> Tuple[torch.Tensor, torch.Tensor]:\n if scale == -1:\n scale = q.shape[-1] ** -0.5\n if initial_state is not None:\n initial_state = initial_state.detach()\n seq_len = q.shape[-2]\n o, final_state = FusedChunkGLAFunction.apply(\n q, k, v, g, scale, initial_state, output_final_state\n )\n o = o[..., :seq_len, :]\n return o, final_state\n", - "description_1": "Use triton language to implement a fused chunked Gated Linear Attention (GLA) mechanism with both forward and backward kernels. The forward kernel handles inputs (queries, keys, values, cumulative sums, and states) to compute an output tensor by performing block-wise operations and storing intermediate results. The backward kernel computes gradients with respect to the inputs using accumulated states. Both kernels are designed to be efficient using block pointers and constexpr parameters. The main entry point is the `fused_chunk_gla` function, which applies the kernels and handles padding of the input tensors.", - "description_2": "Use triton language to develop a forward and backward kernel for GLA operations that efficiently compute outputs and gradients using block operations and manage states, with the main function wrapping and invoking these kernels.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef _tpe_recurrence_fwd(\n X,\n B,\n LOG_LAMBDA,\n O,\n b: tl.constexpr,\n n: tl.constexpr,\n d: tl.constexpr,\n e: tl.constexpr,\n BLOCK: tl.constexpr,\n):\n off_b = tl.program_id(0)\n off_d = tl.program_id(1)\n # compute offset\n offset_x = off_b * n * d + off_d * BLOCK\n offset_b = off_d * BLOCK * e\n\n x_block_ptr = X + offset_x + tl.arange(0, BLOCK)\n b_block_ptr = B + offset_b + tl.arange(0, e)\n log_lambda_block_ptr = LOG_LAMBDA + tl.arange(0, e)\n o_block_ptr = O + offset_x + tl.arange(0, BLOCK)\n\n h = tl.zeros([BLOCK, e], dtype=tl.float32)\n b = tl.load(b_block_ptr).to(tl.float32)[None, :] # (1, e)\n lambda_ = tl.exp(tl.load(log_lambda_block_ptr).to(tl.float32))[None, :] # (1, e)\n\n for i in range(n):\n x = tl.load(x_block_ptr).to(tl.float32)[:, None] # (d, 1)\n h = lambda_ * h + b * x\n o = tl.sum(h, axis=0)\n\n tl.store(o_block_ptr, o.to(o_block_ptr.dtype.element_type))\n\n x_block_ptr += BLOCK\n o_block_ptr += BLOCK\n\n@triton.jit\ndef _tpe_recurrence_bwd(\n X,\n B,\n LOG_LAMBDA,\n DO,\n DX,\n DB,\n DLOG_LAMBDA,\n b: tl.constexpr,\n n: tl.constexpr,\n d: tl.constexpr,\n e: tl.constexpr,\n BLOCK: tl.constexpr,\n):\n off_b = tl.program_id(0)\n off_d = tl.program_id(1)\n # compute offset\n offset_x = off_b * n * d + off_d * BLOCK\n offset_b = off_d * BLOCK * e\n\n x_block_ptr = X + offset_x + tl.arange(0, BLOCK)\n b_block_ptr = B + offset_b + tl.arange(0, e)\n log_lambda_block_ptr = LOG_LAMBDA + tl.arange(0, e)\n o_block_ptr = O + offset_x + tl.arange(0, BLOCK)\n\n h = tl.zeros([BLOCK, e], dtype=tl.float32)\n b = tl.load(b_block_ptr).to(tl.float32)[None, :] # (1, e)\n lambda_ = tl.exp(tl.load(log_lambda_block_ptr).to(tl.float32))[None, :] # (1, e)\n\n for i in range(n):\n x = tl.load(x_block_ptr).to(tl.float32)[:, None] # (d, 1)\n h = lambda_ * h + b * x\n o = tl.sum(h, axis=0)\n\n tl.store(o_block_ptr, o.to(o_block_ptr.dtype.element_type))\n\n x_block_ptr += BLOCK\n o_block_ptr += BLOCK\n\nclass TpeRecurrence(torch.autograd.Function):\n @staticmethod\n @contiguous\n def forward(ctx, x, b, log_lambda):\n b, n, d = x.shape\n e = log_lambda.shape[-1]\n o = torch.empty_like(x)\n\n def grid(meta):\n return (b, meta[\"BLOCK\"])\n\n _tpe_recurrence_fwd[grid](x, b, log_lambda, o, b, n, d, e)\n\n ctx.save_for_backward(x, b, log_lambda)\n\n return o\n\n @staticmethod\n @contiguous\n def backward(ctx, do):\n x, b, log_lambda = ctx.saved_tensors\n b, h, n, d = x.shape\n\n dx = torch.empty_like(x)\n db = torch.empty_like(b)\n dlog_lambda = torch.empty_like(log_lambda)\n\n def grid(meta):\n return (b, meta[\"BLOCK\"])\n\n _tpe_recurrence_bwd[grid](x, b, log_lambda, do, dx, db, dlog_lambda, b, h, n, d)\n\n return dx, db, dlog_lambda\n", - "description_1": "Use Triton language to implement a forward and backward recurrence for a sequence, where each element in the sequence is updated based on weighted sums controlled by parameter tensors and exponential of log_lambda. The forward pass computes the recurrence, while the backward pass computes gradients with respect to the input, parameter, and lambda tensors.", - "description_2": "Use Triton language to compute forward and backward recurrence over sequences with gradient calculations for backpropagation, using block-based processing for efficiency in parallelism.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n\n@triton.autotune(\n configs=[\n triton.Config({'BLOCK_N': 32}),\n triton.Config({'BLOCK_N': 64}),\n triton.Config({'BLOCK_N': 128}),\n triton.Config({'BLOCK_N': 256}),\n triton.Config({'BLOCK_N': 512}),\n triton.Config({'BLOCK_N': 1024}),\n ],\n key=['ncols'],\n)\n@triton.jit\ndef _swiglu_fwd_kernel(\n X, Y, OUT, stride_x_row, stride_y_row, stride_out_row, ncols, BLOCK_N: tl.constexpr\n):\n # Map the program id to the row of X and Y it should compute.\n row = tl.program_id(0)\n start_col = tl.program_id(1) * BLOCK_N\n X += row * stride_x_row\n Y += row * stride_y_row\n OUT += row * stride_out_row\n cols = start_col + tl.arange(0, BLOCK_N)\n x = tl.load(X + cols, mask=cols < ncols, other=0.).to(tl.float32)\n y = tl.load(Y + cols, mask=cols < ncols, other=0.).to(tl.float32)\n out = x * tl.sigmoid(x) * y\n tl.store(OUT + cols, out, mask=cols < ncols)\n\n\ndef _swiglu_fwd(xy, out=None):\n if xy.stride(-1) != 1:\n xy = xy.contiguous()\n batch_shape = xy.shape[:-1]\n xy = xy.reshape(-1, xy.shape[-1])\n x, y = xy.chunk(2, dim=-1)\n if out is None:\n out = torch.empty_like(x)\n else:\n out = out.reshape(-1, out.shape[-1])\n assert out.shape == x.shape\n assert out.stride(-1) == 1\n M, N = x.shape\n grid = lambda META: (M, triton.cdiv(N, META['BLOCK_N']))\n with torch.cuda.device(x.device.index):\n _swiglu_fwd_kernel[grid](x, y, out, x.stride(0), y.stride(0), out.stride(0), N)\n return out.reshape(*batch_shape, out.shape[-1])\n\n\n@triton.autotune(\n configs=[\n triton.Config({'BLOCK_N': 32}),\n triton.Config({'BLOCK_N': 64}),\n triton.Config({'BLOCK_N': 128}),\n triton.Config({'BLOCK_N': 256}),\n triton.Config({'BLOCK_N': 512}),\n triton.Config({'BLOCK_N': 1024}),\n ],\n key=['ncols'],\n)\n@triton.heuristics({\"RECOMPUTE_OUTPUT\": lambda args: args[\"OUT\"] is not None})\n@triton.jit\ndef _swiglu_bwd_kernel(\n X, Y, DOUT, OUT, DX, DY, stride_x_row, stride_y_row, stride_dout_row,\n stride_out_row, stride_dx_row, stride_dy_row, ncols, BLOCK_N: tl.constexpr,\n RECOMPUTE_OUTPUT: tl.constexpr\n):\n # Map the program id to the row of X and Y it should compute.\n row = tl.program_id(0)\n start_col = tl.program_id(1) * BLOCK_N\n X += row * stride_x_row\n Y += row * stride_y_row\n DOUT += row * stride_dout_row\n if RECOMPUTE_OUTPUT:\n OUT += row * stride_out_row\n DX += row * stride_dx_row\n DY += row * stride_dy_row\n cols = start_col + tl.arange(0, BLOCK_N)\n x = tl.load(X + cols, mask=cols < ncols, other=0.).to(tl.float32)\n y = tl.load(Y + cols, mask=cols < ncols, other=0.).to(tl.float32)\n dout = tl.load(DOUT + cols, mask=cols < ncols, other=0.).to(tl.float32)\n x_sigmoid = tl.sigmoid(x)\n dx = x_sigmoid * (1 + x * (1 - x_sigmoid)) * y * dout\n dy = x * x_sigmoid * dout\n tl.store(DX + cols, dx, mask=cols < ncols)\n tl.store(DY + cols, dy, mask=cols < ncols)\n if RECOMPUTE_OUTPUT:\n out = x * x_sigmoid * y\n tl.store(OUT + cols, out, mask=cols < ncols)\n\n\ndef _swiglu_bwd(xy, dout, dxy=None, recompute_output=False, out=None):\n if xy.stride(-1) != 1:\n xy = xy.contiguous()\n if dout.stride(-1) != 1:\n dout = dout.contiguous()\n batch_shape = xy.shape[:-1]\n xy = xy.reshape(-1, xy.shape[-1])\n x, y = xy.chunk(2, dim=-1)\n dout = dout.reshape(-1, dout.shape[-1])\n assert dout.shape == x.shape\n if dxy is None:\n dxy = torch.empty_like(xy)\n else:\n dxy = dxy.reshape(-1, dxy.shape[-1])\n assert dxy.shape == xy.shape\n dx, dy = dxy.chunk(2, dim=-1)\n assert dx.stride(-1) == 1\n assert dy.stride(-1) == 1\n if recompute_output:\n if out is None:\n out = torch.empty_like(x)\n else:\n out = out.reshape(-1, out.shape[-1])\n assert out.shape == x.shape\n assert out.stride(-1) == 1\n M, N = x.shape\n grid = lambda META: (M, triton.cdiv(N, META['BLOCK_N']))\n with torch.cuda.device(x.device.index):\n _swiglu_bwd_kernel[grid](x, y, dout, out if recompute_output else None, dx, dy,\n x.stride(0), y.stride(0), dout.stride(0),\n out.stride(0) if recompute_output else 0,\n dx.stride(0), dy.stride(0),\n N)\n if not recompute_output:\n return dxy.reshape(*batch_shape, dxy.shape[-1])\n else:\n return dxy.reshape(*batch_shape, dxy.shape[-1]), out.reshape(*batch_shape, out.shape[-1])\n", - "description_1": "Use triton language to implement forward and backward kernels for Swish-Gated Linear Unit (SwigLU). The forward kernel (_swiglu_fwd_kernel) takes 7 parameters: X (input), Y (input), OUT (output), stride_x_row, stride_y_row, stride_out_row (stride values for rows), and ncols (number of columns) to compute the element-wise SwigLU operation. The backward kernel (_swiglu_bwd_kernel) takes 14 parameters: X, Y, DOUT, OUT, DX, DY (input/output), stride_x_row, stride_y_row, stride_dout_row, stride_out_row, stride_dx_row, stride_dy_row (stride values), ncols, and RECOMPUTE_OUTPUT (for output recomputation) to compute gradients.", - "description_2": "Use triton language to perform SwigLU activation and its gradient computation for GPU-accelerated neural networks.", - "difficulty": 3 - }, - { - "code": "import triton\nimport triton.language as tl\nimport torch\n\n@triton.autotune(\n configs=[\n triton.Config({}, num_warps=1),\n triton.Config({}, num_warps=2),\n triton.Config({}, num_warps=4),\n triton.Config({}, num_warps=8),\n triton.Config({}, num_warps=16),\n triton.Config({}, num_warps=32),\n ],\n key=[\"N\", \"HAS_RESIDUAL\", \"STORE_RESIDUAL_OUT\", \"IS_RMS_NORM\", \"HAS_BIAS\"],\n)\n@triton.heuristics({\"HAS_X1\": lambda args: args[\"X1\"] is not None})\n@triton.heuristics({\"HAS_W1\": lambda args: args[\"W1\"] is not None})\n@triton.heuristics({\"HAS_B1\": lambda args: args[\"B1\"] is not None})\n@triton.jit\ndef _layer_norm_fwd_1pass_kernel(\n X, # pointer to the input\n Y, # pointer to the output\n W, # pointer to the weights\n B, # pointer to the biases\n RESIDUAL, # pointer to the residual\n X1,\n W1,\n B1,\n Y1,\n RESIDUAL_OUT, # pointer to the residual\n ROWSCALE,\n SEEDS, # Dropout seeds for each row\n DROPOUT_MASK,\n Mean, # pointer to the mean\n Rstd, # pointer to the 1/std\n stride_x_row, # how much to increase the pointer when moving by 1 row\n stride_y_row,\n stride_res_row,\n stride_res_out_row,\n stride_x1_row,\n stride_y1_row,\n M, # number of rows in X\n N, # number of columns in X\n eps, # epsilon to avoid division by zero\n dropout_p, # Dropout probability\n IS_RMS_NORM: tl.constexpr,\n BLOCK_N: tl.constexpr,\n HAS_RESIDUAL: tl.constexpr,\n STORE_RESIDUAL_OUT: tl.constexpr,\n HAS_BIAS: tl.constexpr,\n HAS_DROPOUT: tl.constexpr,\n STORE_DROPOUT_MASK: tl.constexpr,\n HAS_ROWSCALE: tl.constexpr,\n HAS_X1: tl.constexpr,\n HAS_W1: tl.constexpr,\n HAS_B1: tl.constexpr,\n):\n # Map the program id to the row of X and Y it should compute.\n row = tl.program_id(0)\n X += row * stride_x_row\n Y += row * stride_y_row\n if HAS_RESIDUAL:\n RESIDUAL += row * stride_res_row\n if STORE_RESIDUAL_OUT:\n RESIDUAL_OUT += row * stride_res_out_row\n if HAS_X1:\n X1 += row * stride_x1_row\n if HAS_W1:\n Y1 += row * stride_y1_row\n # Compute mean and variance\n cols = tl.arange(0, BLOCK_N)\n x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)\n if HAS_ROWSCALE:\n rowscale = tl.load(ROWSCALE + row).to(tl.float32)\n x *= rowscale\n if HAS_DROPOUT:\n # Compute dropout mask\n # 7 rounds is good enough, and reduces register pressure\n keep_mask = tl.rand(tl.load(SEEDS + row).to(tl.uint32), cols, n_rounds=7) > dropout_p\n x = tl.where(keep_mask, x / (1.0 - dropout_p), 0.0)\n if STORE_DROPOUT_MASK:\n tl.store(DROPOUT_MASK + row * N + cols, keep_mask, mask=cols < N)\n if HAS_X1:\n x1 = tl.load(X1 + cols, mask=cols < N, other=0.0).to(tl.float32)\n if HAS_ROWSCALE:\n rowscale = tl.load(ROWSCALE + M + row).to(tl.float32)\n x1 *= rowscale\n if HAS_DROPOUT:\n # Compute dropout mask\n # 7 rounds is good enough, and reduces register pressure\n keep_mask = (\n tl.rand(tl.load(SEEDS + M + row).to(tl.uint32), cols, n_rounds=7) > dropout_p\n )\n x1 = tl.where(keep_mask, x1 / (1.0 - dropout_p), 0.0)\n if STORE_DROPOUT_MASK:\n tl.store(DROPOUT_MASK + (M + row) * N + cols, keep_mask, mask=cols < N)\n x += x1\n if HAS_RESIDUAL:\n residual = tl.load(RESIDUAL + cols, mask=cols < N, other=0.0).to(tl.float32)\n x += residual\n if STORE_RESIDUAL_OUT:\n tl.store(RESIDUAL_OUT + cols, x, mask=cols < N)\n if not IS_RMS_NORM:\n mean = tl.sum(x, axis=0) / N\n tl.store(Mean + row, mean)\n xbar = tl.where(cols < N, x - mean, 0.0)\n var = tl.sum(xbar * xbar, axis=0) / N\n else:\n xbar = tl.where(cols < N, x, 0.0)\n var = tl.sum(xbar * xbar, axis=0) / N\n rstd = 1 / tl.sqrt(var + eps)\n tl.store(Rstd + row, rstd)\n # Normalize and apply linear transformation\n mask = cols < N\n w = tl.load(W + cols, mask=mask).to(tl.float32)\n if HAS_BIAS:\n b = tl.load(B + cols, mask=mask).to(tl.float32)\n x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd\n y = x_hat * w + b if HAS_BIAS else x_hat * w\n # Write output\n tl.store(Y + cols, y, mask=mask)\n if HAS_W1:\n w1 = tl.load(W1 + cols, mask=mask).to(tl.float32)\n if HAS_B1:\n b1 = tl.load(B1 + cols, mask=mask).to(tl.float32)\n y1 = x_hat * w1 + b1 if HAS_B1 else x_hat * w1\n tl.store(Y1 + cols, y1, mask=mask)\n\n\ndef _layer_norm_fwd(\n x,\n weight,\n bias,\n eps,\n residual=None,\n x1=None,\n weight1=None,\n bias1=None,\n dropout_p=0.0,\n rowscale=None,\n out_dtype=None,\n residual_dtype=None,\n is_rms_norm=False,\n return_dropout_mask=False,\n):\n if residual is not None:\n residual_dtype = residual.dtype\n M, N = x.shape\n assert x.stride(-1) == 1\n if residual is not None:\n assert residual.stride(-1) == 1\n assert residual.shape == (M, N)\n assert weight.shape == (N,)\n assert weight.stride(-1) == 1\n if bias is not None:\n assert bias.stride(-1) == 1\n assert bias.shape == (N,)\n if x1 is not None:\n assert x1.shape == x.shape\n assert rowscale is None\n assert x1.stride(-1) == 1\n if weight1 is not None:\n assert weight1.shape == (N,)\n assert weight1.stride(-1) == 1\n if bias1 is not None:\n assert bias1.shape == (N,)\n assert bias1.stride(-1) == 1\n if rowscale is not None:\n assert rowscale.is_contiguous()\n assert rowscale.shape == (M,)\n # allocate output\n y = torch.empty_like(x, dtype=x.dtype if out_dtype is None else out_dtype)\n assert y.stride(-1) == 1\n if weight1 is not None:\n y1 = torch.empty_like(y)\n assert y1.stride(-1) == 1\n else:\n y1 = None\n if (\n residual is not None\n or (residual_dtype is not None and residual_dtype != x.dtype)\n or dropout_p > 0.0\n or rowscale is not None\n or x1 is not None\n ):\n residual_out = torch.empty(\n M, N, device=x.device, dtype=residual_dtype if residual_dtype is not None else x.dtype\n )\n assert residual_out.stride(-1) == 1\n else:\n residual_out = None\n mean = torch.empty((M,), dtype=torch.float32, device=x.device) if not is_rms_norm else None\n rstd = torch.empty((M,), dtype=torch.float32, device=x.device)\n if dropout_p > 0.0:\n seeds = torch.randint(\n 2**32, (M if x1 is None else 2 * M,), device=x.device, dtype=torch.int64\n )\n else:\n seeds = None\n if return_dropout_mask and dropout_p > 0.0:\n dropout_mask = torch.empty(M if x1 is None else 2 * M, N, device=x.device, dtype=torch.bool)\n else:\n dropout_mask = None\n # Less than 64KB per feature: enqueue fused kernel\n MAX_FUSED_SIZE = 65536 // x.element_size()\n BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))\n if N > BLOCK_N:\n raise RuntimeError(\"This layer norm doesn't support feature dim >= 64KB.\")\n with torch.cuda.device(x.device.index):\n _layer_norm_fwd_1pass_kernel[(M,)](\n x,\n y,\n weight,\n bias,\n residual,\n x1,\n weight1,\n bias1,\n y1,\n residual_out,\n rowscale,\n seeds,\n dropout_mask,\n mean,\n rstd,\n x.stride(0),\n y.stride(0),\n residual.stride(0) if residual is not None else 0,\n residual_out.stride(0) if residual_out is not None else 0,\n x1.stride(0) if x1 is not None else 0,\n y1.stride(0) if y1 is not None else 0,\n M,\n N,\n eps,\n dropout_p,\n is_rms_norm,\n BLOCK_N,\n residual is not None,\n residual_out is not None,\n bias is not None,\n dropout_p > 0.0,\n dropout_mask is not None,\n rowscale is not None,\n )\n # residual_out is None if residual is None and residual_dtype == input_dtype and dropout_p == 0.0\n if dropout_mask is not None and x1 is not None:\n dropout_mask, dropout_mask1 = dropout_mask.tensor_split(2, dim=0)\n else:\n dropout_mask1 = None\n return (\n y,\n y1,\n mean,\n rstd,\n residual_out if residual_out is not None else x,\n seeds,\n dropout_mask,\n dropout_mask1,\n )\n", - "description_1": "Use triton language to implement a fused layer normalization forward pass kernel. The kernel takes 31 parameters, including pointers to input, output, weights, biases, residuals, dropout settings, and configuration constants. It computes the mean and variance for normalization, applies dropout, and applies a linear transformation with optional residuals.", - "description_2": "Use triton language to create a fused layer normalization forward function. The function takes 14 parameters including input, weights, biases, residuals, and configuration flags. It prepares data for the kernel execution, allocates necessary outputs, and invokes the Triton kernel to compute the layer normalization with optional dropout and residual connections.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.heuristics({\"HAS_BIAS\": lambda args: args[\"B\"] is not None})\n@triton.heuristics({\"HAS_Z\": lambda args: args[\"Z\"] is not None})\n@triton.jit\ndef _layer_norm_fwd_1pass_kernel(\n X, # pointer to the input\n Y, # pointer to the output\n W, # pointer to the weights\n B, # pointer to the biases\n Z, # pointer to the other branch\n Mean, # pointer to the mean\n Rstd, # pointer to the 1/std\n stride_x_row, # how much to increase the pointer when moving by 1 row\n stride_y_row,\n stride_z_row,\n M, # number of rows in X\n N, # number of columns in X\n eps, # epsilon to avoid division by zero\n BLOCK_N: tl.constexpr,\n HAS_BIAS: tl.constexpr,\n HAS_Z: tl.constexpr,\n NORM_BEFORE_GATE: tl.constexpr,\n IS_RMS_NORM: tl.constexpr,\n):\n # Kernel logic\n row = tl.program_id(0)\n group = tl.program_id(1)\n X += row * stride_x_row + group * N\n Y += row * stride_y_row + group * N\n if HAS_Z:\n Z += row * stride_z_row + group * N\n if not IS_RMS_NORM:\n Mean += group * M\n Rstd += group * M\n W += group * N\n if HAS_BIAS:\n B += group * N\n # Compute mean and variance\n cols = tl.arange(0, BLOCK_N)\n x = tl.load(X + cols, mask=cols < N, other=0.).to(tl.float32)\n if HAS_Z and not NORM_BEFORE_GATE:\n z = tl.load(Z + cols, mask=cols < N).to(tl.float32)\n x *= z * tl.sigmoid(z)\n if not IS_RMS_NORM:\n mean = tl.sum(x, axis=0) / N\n tl.store(Mean + row, mean)\n xbar = tl.where(cols < N, x - mean, 0.)\n var = tl.sum(xbar * xbar, axis=0) / N\n else:\n xbar = tl.where(cols < N, x, 0.)\n var = tl.sum(xbar * xbar, axis=0) / N\n rstd = 1 / tl.sqrt(var + eps)\n tl.store(Rstd + row, rstd)\n # Normalize and apply linear transformation\n mask = cols < N\n w = tl.load(W + cols, mask=mask).to(tl.float32)\n if HAS_BIAS:\n b = tl.load(B + cols, mask=mask).to(tl.float32)\n x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd\n y = x_hat * w + b if HAS_BIAS else x_hat * w\n if HAS_Z and NORM_BEFORE_GATE:\n z = tl.load(Z + cols, mask=mask).to(tl.float32)\n y *= z * tl.sigmoid(z)\n # Write output\n tl.store(Y + cols, y, mask=mask)\n\ndef _layer_norm_fwd(x, weight, bias, eps, z=None, out=None, group_size=None, norm_before_gate=True, is_rms_norm=False):\n M, N = x.shape\n if group_size is None:\n group_size = N\n assert N % group_size == 0\n ngroups = N // group_size\n assert x.stride(-1) == 1\n if z is not None:\n assert z.stride(-1) == 1\n assert z.shape == (M, N)\n assert weight.shape == (N,)\n assert weight.stride(-1) == 1\n if bias is not None:\n assert bias.stride(-1) == 1\n assert bias.shape == (N,)\n # allocate output\n if out is not None:\n assert out.shape == x.shape\n else:\n out = torch.empty_like(x)\n assert out.stride(-1) == 1\n mean = torch.empty((ngroups * M, ), dtype=torch.float32, device=x.device) if not is_rms_norm else None\n rstd = torch.empty((ngroups * M, ), dtype=torch.float32, device=x.device)\n # Less than 64KB per feature: enqueue fused kernel\n MAX_FUSED_SIZE = 65536 // x.element_size()\n BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(group_size))\n if group_size > BLOCK_N:\n raise RuntimeError(\"This layer norm doesn't support feature dim >= 64KB.\")\n # heuristics for number of warps\n num_warps = min(max(BLOCK_N // 256, 1), 8)\n grid = (M, ngroups)\n with torch.cuda.device(x.device.index):\n _layer_norm_fwd_1pass_kernel[grid](x, out, weight, bias, z, mean, rstd,\n x.stride(0), out.stride(0), z.stride(0) if z is not None else 0,\n M, group_size, eps,\n BLOCK_N=BLOCK_N,\n NORM_BEFORE_GATE=norm_before_gate,\n IS_RMS_NORM=is_rms_norm,\n num_warps=num_warps)\n return out, mean, rstd\n", - "description_1": "Use triton language to implement a layer normalization operation, including a kernel function '_layer_norm_fwd_1pass_kernel' that normalizes input data with support for optional bias and additional input processing, and a calling function '_layer_norm_fwd' that prepares input data and invokes the kernel with specific configuration parameters.", - "description_2": "Use triton language to create a forward layer normalization kernel with optional features like bias and additional transformations, and implement its execution logic in Python.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef _selective_scan_update_kernel(\n # Pointers to matrices\n state_ptr, x_ptr, dt_ptr, dt_bias_ptr, A_ptr, B_ptr, C_ptr, D_ptr, z_ptr, out_ptr,\n # Matrix dimensions\n batch, nheads, dim, dstate, nheads_ngroups_ratio,\n # Strides\n stride_state_batch, stride_state_head, stride_state_dim, stride_state_dstate,\n stride_x_batch, stride_x_head, stride_x_dim,\n stride_dt_batch, stride_dt_head, stride_dt_dim,\n stride_dt_bias_head, stride_dt_bias_dim,\n stride_A_head, stride_A_dim, stride_A_dstate,\n stride_B_batch, stride_B_group, stride_B_dstate,\n stride_C_batch, stride_C_group, stride_C_dstate,\n stride_D_head, stride_D_dim,\n stride_z_batch, stride_z_head, stride_z_dim,\n stride_out_batch, stride_out_head, stride_out_dim,\n # Meta-parameters\n DT_SOFTPLUS: tl.constexpr,\n TIE_HDIM: tl.constexpr,\n BLOCK_SIZE_M: tl.constexpr,\n HAS_DT_BIAS: tl.constexpr,\n HAS_D: tl.constexpr,\n HAS_Z: tl.constexpr,\n BLOCK_SIZE_DSTATE: tl.constexpr,\n):\n pid_m = tl.program_id(axis=0)\n pid_b = tl.program_id(axis=1)\n pid_h = tl.program_id(axis=2)\n state_ptr += pid_b * stride_state_batch + pid_h * stride_state_head\n x_ptr += pid_b * stride_x_batch + pid_h * stride_x_head\n dt_ptr += pid_b * stride_dt_batch + pid_h * stride_dt_head\n if HAS_DT_BIAS:\n dt_bias_ptr += pid_h * stride_dt_bias_head\n A_ptr += pid_h * stride_A_head\n B_ptr += pid_b * stride_B_batch + (pid_h // nheads_ngroups_ratio) * stride_B_group\n C_ptr += pid_b * stride_C_batch + (pid_h // nheads_ngroups_ratio) * stride_C_group\n if HAS_Z:\n z_ptr += pid_b * stride_z_batch + pid_h * stride_z_head\n out_ptr += pid_b * stride_out_batch + pid_h * stride_out_head\n\n offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_n = tl.arange(0, BLOCK_SIZE_DSTATE)\n state_ptrs = state_ptr + (offs_m[:, None] * stride_state_dim + offs_n[None, :] * stride_state_dstate)\n x_ptrs = x_ptr + offs_m * stride_x_dim\n dt_ptrs = dt_ptr + offs_m * stride_dt_dim\n if HAS_DT_BIAS:\n dt_bias_ptrs = dt_bias_ptr + offs_m * stride_dt_bias_dim\n if HAS_D:\n D_ptr += pid_h * stride_D_head\n A_ptrs = A_ptr + (offs_m[:, None] * stride_A_dim + offs_n[None, :] * stride_A_dstate)\n B_ptrs = B_ptr + offs_n * stride_B_dstate\n C_ptrs = C_ptr + offs_n * stride_C_dstate\n if HAS_D:\n D_ptrs = D_ptr + offs_m * stride_D_dim\n if HAS_Z:\n z_ptrs = z_ptr + offs_m * stride_z_dim\n out_ptrs = out_ptr + offs_m * stride_out_dim\n\n state = tl.load(state_ptrs, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate), other=0.0)\n x = tl.load(x_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)\n if not TIE_HDIM:\n dt = tl.load(dt_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)\n if HAS_DT_BIAS:\n dt += tl.load(dt_bias_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)\n if DT_SOFTPLUS:\n dt = tl.where(dt <= 20.0, tl.math.log1p(tl.exp(dt)), dt)\n A = tl.load(A_ptrs, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate), other=0.0).to(tl.float32)\n dA = tl.exp(A * dt[:, None])\n else:\n dt = tl.load(dt_ptr).to(tl.float32)\n if HAS_DT_BIAS:\n dt += tl.load(dt_bias_ptr).to(tl.float32)\n if DT_SOFTPLUS:\n dt = tl.where(dt <= 20.0, tl.math.log1p(tl.exp(dt)), dt)\n A = tl.load(A_ptr).to(tl.float32)\n dA = tl.exp(A * dt) # scalar, not a matrix\n\n B = tl.load(B_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32)\n C = tl.load(C_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32)\n if HAS_D:\n D = tl.load(D_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)\n if HAS_Z:\n z = tl.load(z_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)\n\n if not TIE_HDIM:\n dB = B[None, :] * dt[:, None]\n else:\n dB = B * dt # vector of size (dstate,)\n state = state * dA + dB * x[:, None]\n tl.store(state_ptrs, state, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate))\n out = tl.sum(state * C[None, :], axis=1)\n if HAS_D:\n out += x * D\n if HAS_Z:\n out *= z * tl.sigmoid(z)\n tl.store(out_ptrs, out, mask=offs_m < dim)\n\n\ndef selective_state_update(state, x, dt, A, B, C, D=None, z=None, dt_bias=None, dt_softplus=False):\n \"\"\"\n Argument:\n state: (batch, dim, dstate) or (batch, nheads, dim, dstate)\n x: (batch, dim) or (batch, nheads, dim)\n dt: (batch, dim) or (batch, nheads, dim)\n A: (dim, dstate) or (nheads, dim, dstate)\n B: (batch, dstate) or (batch, ngroups, dstate)\n C: (batch, dstate) or (batch, ngroups, dstate)\n D: (dim,) or (nheads, dim)\n z: (batch, dim) or (batch, nheads, dim)\n dt_bias: (dim,) or (nheads, dim)\n Return:\n out: (batch, dim) or (batch, nheads, dim)\n \"\"\"\n has_heads = state.dim() > 3\n if state.dim() == 3:\n state = state.unsqueeze(1)\n if x.dim() == 2:\n x = x.unsqueeze(1)\n if dt.dim() == 2:\n dt = dt.unsqueeze(1)\n if A.dim() == 2:\n A = A.unsqueeze(0)\n if B.dim() == 2:\n B = B.unsqueeze(1)\n if C.dim() == 2:\n C = C.unsqueeze(1)\n if D is not None and D.dim() == 1:\n D = D.unsqueeze(0)\n if z is not None and z.dim() == 2:\n z = z.unsqueeze(1)\n if dt_bias is not None and dt_bias.dim() == 1:\n dt_bias = dt_bias.unsqueeze(0)\n batch, nheads, dim, dstate = state.shape\n assert x.shape == (batch, nheads, dim)\n assert dt.shape == x.shape\n assert A.shape == (nheads, dim, dstate)\n ngroups = B.shape[1]\n assert nheads % ngroups == 0, \"nheads must be divisible by ngroups\"\n assert B.shape == (batch, ngroups, dstate)\n assert C.shape == B.shape\n if D is not None:\n assert D.shape == (nheads, dim)\n if z is not None:\n assert z.shape == x.shape\n if dt_bias is not None:\n assert dt_bias.shape == (nheads, dim)\n out = torch.empty_like(x)\n grid = lambda META: (triton.cdiv(dim, META['BLOCK_SIZE_M']), batch, nheads)\n z_strides = ((z.stride(0), z.stride(1), z.stride(2)) if z is not None else (0, 0, 0))\n # We don't want autotune since it will overwrite the state\n # We instead tune by hand.\n BLOCK_SIZE_M, num_warps = ((32, 4) if dstate <= 16\n else ((16, 4) if dstate <= 32 else\n ((8, 4) if dstate <= 64 else\n ((4, 4) if dstate <= 128 else\n ((4, 8))))))\n tie_hdim = A.stride(-1) == 0 and A.stride(-2) == 0 and dt.stride(-1) == 0 and dt_bias.stride(-1) == 0\n with torch.cuda.device(x.device.index):\n _selective_scan_update_kernel[grid](\n state, x, dt, dt_bias, A, B, C, D, z, out,\n batch, nheads, dim, dstate, nheads // ngroups,\n state.stride(0), state.stride(1), state.stride(2), state.stride(3),\n x.stride(0), x.stride(1), x.stride(2),\n dt.stride(0), dt.stride(1), dt.stride(2),\n *(dt_bias.stride(0), dt_bias.stride(1)) if dt_bias is not None else 0,\n A.stride(0), A.stride(1), A.stride(2),\n B.stride(0), B.stride(1), B.stride(2),\n C.stride(0), C.stride(1), C.stride(2),\n *(D.stride(0), D.stride(1)) if D is not None else 0,\n z_strides[0], z_strides[1], z_strides[2],\n out.stride(0), out.stride(1), out.stride(2),\n dt_softplus,\n tie_hdim,\n BLOCK_SIZE_M,\n num_warps=num_warps,\n )\n if not has_heads:\n out = out.squeeze(1)\n return out\n", - "description_1": "Use triton language to implement a kernel and its associated function that selectively updates a state matrix. The kernel is called `_selective_scan_update_kernel` and it takes pointers to matrices, dimensions, and meta-parameters, and updates the state matrix based on various conditions using Triton's load and store operations. The function `selective_state_update` manages the input dimensions and calls the kernel with appropriately calculated parameters.", - "description_2": "Use triton language to perform conditional matrix updates with a kernel function that handles matrix pointers and dimensions, ensuring efficient GPU execution. Implement a wrapper function to prepare data and launch the kernel with optimized grid and block settings.", - "difficulty": 4 - }, - { - "code": "import math\nimport torch\nimport triton\nimport triton.language as tl\n\n@triton.autotune(\n configs=[\n triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64}, num_stages=3, num_warps=8),\n triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2),\n triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2),\n triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=2),\n ],\n key=['chunk_size', 'K', 'IS_CAUSAL'],\n)\n@triton.jit\ndef _bmm_chunk_fwd_kernel(\n a_ptr, b_ptr, out_ptr, seq_idx_ptr,\n seqlen, chunk_size, K, ngroups,\n stride_a_batch, stride_a_seqlen, stride_a_head, stride_ak,\n stride_b_batch, stride_b_seqlen, stride_b_head, stride_bk,\n stride_out_batch, stride_out_chunk, stride_out_head, stride_outm, stride_outn,\n stride_seq_idx_batch, stride_seq_idx_seqlen,\n IS_CAUSAL: tl.constexpr,\n dot_dtype: tl.constexpr,\n HAS_SEQ_IDX: tl.constexpr,\n BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,\n):\n pid_b = tl.program_id(axis=1)\n pid_ch = tl.program_id(axis=2)\n pid_c = pid_ch // ngroups\n pid_h = pid_ch - pid_c * ngroups\n num_pid_n = tl.cdiv(chunk_size, BLOCK_SIZE_N)\n pid_m = tl.program_id(axis=0) // num_pid_n\n pid_n = tl.program_id(axis=0) % num_pid_n\n if IS_CAUSAL:\n if pid_n * BLOCK_SIZE_N >= (pid_m + 1) * BLOCK_SIZE_M:\n return\n a_ptr += pid_b * stride_a_batch + pid_c * chunk_size * stride_a_seqlen + pid_h * stride_a_head\n b_ptr += pid_b * stride_b_batch + pid_c * chunk_size * stride_b_seqlen + pid_h * stride_b_head\n if HAS_SEQ_IDX:\n seq_idx_ptr += pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen\n\n offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n offs_k = tl.arange(0, BLOCK_SIZE_K)\n a_ptrs = a_ptr + (offs_m[:, None] * stride_a_seqlen + offs_k[None, :] * stride_ak)\n b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_n[None, :] * stride_b_seqlen)\n chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)\n\n acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):\n a = tl.load(a_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < K - k * BLOCK_SIZE_K), other=0.0).to(dot_dtype)\n b = tl.load(b_ptrs, mask=(offs_k[:, None] < K - k * BLOCK_SIZE_K) & (offs_n[None, :] < chunk_size_limit), other=0.0).to(dot_dtype)\n acc += tl.dot(a, b)\n a_ptrs += BLOCK_SIZE_K * stride_ak\n b_ptrs += BLOCK_SIZE_K * stride_bk\n\n offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n if HAS_SEQ_IDX:\n chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)\n seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen, mask=offs_m < chunk_size_limit, other=-1)\n seq_idx_n = tl.load(seq_idx_ptr + offs_n * stride_seq_idx_seqlen, mask=offs_n < chunk_size_limit, other=-2)\n acc = tl.where(seq_idx_m[:, None] == seq_idx_n[None, :], acc, 0.0)\n out = acc.to(out_ptr.dtype.element_ty)\n\n out_ptr += pid_b * stride_out_batch + pid_c * stride_out_chunk + pid_h * stride_out_head\n out_ptrs = out_ptr + (stride_outm * offs_m[:, None] + offs_n[None, :] * stride_outn)\n tl.store(out_ptrs, out, mask=(offs_m[:, None] < chunk_size) & (offs_n[None, :] < chunk_size))\n\n\n@triton.autotune(\n configs=[\n triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_CS': 64}, num_stages=3, num_warps=8),\n triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_CS': 32}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_CS': 32}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_CS': 32}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_CS': 32}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_CS': 32}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_CS': 32}, num_stages=5, num_warps=2),\n triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_CS': 32}, num_stages=5, num_warps=2),\n triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_CS': 32}, num_stages=4, num_warps=2),\n ],\n key=['chunk_size', 'K'],\n)\n@triton.jit\ndef _bmm_chunk_bwd_kernel(\n a_ptr, dout_ptr, db_ptr, res_ptr,\n seqlen, chunk_size, K, ngroups,\n stride_a_batch, stride_a_seqlen, stride_a_head, stride_ak,\n stride_dout_batch, stride_dout_chunk, stride_dout_head, stride_dout_csize_m, stride_dout_csize_n,\n stride_db_batch, stride_db_seqlen, stride_db_head, stride_db_k,\n stride_res_batch, stride_res_seqlen, stride_res_head, stride_res_k,\n dot_dtype: tl.constexpr,\n HAS_RESIDUAL: tl.constexpr,\n BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_CS: tl.constexpr,\n):\n pid_b = tl.program_id(axis=1)\n pid_ch = tl.program_id(axis=2)\n pid_c = pid_ch // ngroups\n pid_h = pid_ch - pid_c * ngroups\n num_pid_n = tl.cdiv(K, BLOCK_SIZE_N)\n pid_m = tl.program_id(axis=0) // num_pid_n\n pid_n = tl.program_id(axis=0) % num_pid_n\n\n a_ptr += pid_b * stride_a_batch + pid_c * chunk_size * stride_a_seqlen + pid_h * stride_a_head\n dout_ptr += pid_b * stride_dout_batch + pid_c * stride_dout_chunk + pid_h * stride_dout_head\n\n offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n offs_cs = tl.arange(0, BLOCK_SIZE_CS)\n dout_ptrs = dout_ptr + (offs_m[:, None] * stride_dout_csize_n + offs_cs[None, :] * stride_dout_csize_m)\n a_ptrs = a_ptr + (offs_cs[:, None] * stride_a_seqlen + offs_n[None, :] * stride_ak)\n chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)\n\n acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n for cs in range(0, tl.cdiv(chunk_size_limit, BLOCK_SIZE_CS)):\n dout = tl.load(dout_ptrs, mask=(offs_m[:, None] < chunk_size) & (offs_cs[None, :] < chunk_size_limit - cs * BLOCK_SIZE_CS), other=0.0).to(dot_dtype)\n a = tl.load(a_ptrs, mask=(offs_cs[:, None] < chunk_size_limit - cs * BLOCK_SIZE_CS) & (offs_n[None, :] < K), other=0.0).to(dot_dtype)\n acc += tl.dot(dout, a)\n dout_ptrs += BLOCK_SIZE_CS * stride_dout_csize_m\n a_ptrs += BLOCK_SIZE_CS * stride_a_seqlen\n\n offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n if HAS_RESIDUAL:\n res_ptr += pid_b * stride_res_batch + pid_c * chunk_size * stride_res_seqlen + pid_h * stride_res_head\n res_ptrs = res_ptr + (offs_m[:, None] * stride_res_seqlen + offs_n[None, :] * stride_res_k)\n res = tl.load(res_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < K)).to(tl.float32)\n acc += res\n db = acc.to(db_ptr.dtype.element_ty)\n\n db_ptr += pid_b * stride_db_batch + pid_c * chunk_size * stride_db_seqlen + pid_h * stride_db_head\n db_ptrs = db_ptr + (offs_m[:, None] * stride_db_seqlen + offs_n[None, :] * stride_db_k)\n tl.store(db_ptrs, db, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < K))\n\n\ndef _bmm_chunk_fwd(a, b, chunk_size, seq_idx=None, causal=False, output_dtype=None):\n has_groups = a.dim() == 4\n if not has_groups:\n batch, seqlen, k = a.shape\n else:\n batch, seqlen, ngroups, k = a.shape\n assert b.shape == a.shape\n if seq_idx is not None:\n assert seq_idx.shape == (batch, seqlen)\n if a.stride(-1) != 1 and a.stride(1) != 1:\n a = a.contiguous()\n if b.stride(-1) != 1 and b.stride(1) != 1:\n b = b.contiguous()\n nchunks = math.ceil(seqlen / chunk_size)\n out_dtype = a.dtype if output_dtype is None else output_dtype\n out = torch.empty((batch, nchunks, chunk_size, chunk_size) if not has_groups else (batch, nchunks, ngroups, chunk_size, chunk_size),\n device=a.device, dtype=out_dtype)\n dot_dtype = (tl.bfloat16 if a.dtype == torch.bfloat16 or b.dtype == torch.bfloat16 else\n (tl.float16 if a.dtype == torch.float16 or b.dtype == torch.float16 else tl.float32))\n grid = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv(chunk_size, META['BLOCK_SIZE_N']),\n batch, nchunks if not has_groups else nchunks * ngroups)\n with torch.cuda.device(a.device.index):\n _bmm_chunk_fwd_kernel[grid](\n a, b, out, seq_idx,\n seqlen, chunk_size, k, ngroups if has_groups else 1,\n a.stride(0), a.stride(1), 0 if not has_groups else a.stride(2), a.stride(-1),\n b.stride(0), b.stride(1), 0 if not has_groups else b.stride(2), b.stride(-1),\n out.stride(0), out.stride(1), 0 if not has_groups else out.stride(2), out.stride(-2), out.stride(-1),\n *((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)),\n causal,\n dot_dtype,\n HAS_SEQ_IDX=seq_idx is not None,\n )\n return out\n\n\ndef _bmm_chunk_bwd(a, dout, residual=None, out=None):\n has_groups = a.dim() == 4\n if not has_groups:\n batch, seqlen, k = a.shape\n else:\n batch, seqlen, ngroups, k = a.shape\n nchunks, chunk_size = dout.shape[1], dout.shape[-1]\n if a.stride(-1) != 1 and a.stride(-2) != 1:\n a = a.contiguous()\n if dout.stride(-1) != 1 and dout.stride(-2) != 1:\n dout = dout.contiguous()\n if residual is not None:\n assert residual.shape == (batch, seqlen, k) if not has_groups else (batch, seqlen, ngroups, k)\n if residual.stride(-1) != 1 and residual.stride(1) != 1:\n residual = residual.contiguous()\n if out is not None:\n assert out.shape == a.shape\n assert out.stride(-1) == 1 or out.stride(1) == 1\n else:\n out = torch.empty_like(a)\n dot_dtype = (tl.bfloat16 if a.dtype == torch.bfloat16 or dout.dtype == torch.bfloat16 else\n (tl.float16 if a.dtype == torch.float16 or dout.dtype == torch.float16 else tl.float32))\n grid = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv(k, META['BLOCK_SIZE_N']), batch,\n nchunks if not has_groups else nchunks * ngroups)\n residual_strides = ((residual.stride(0), residual.stride(1), 0 if not has_groups else residual.stride(2),\n residual.stride(-1))\n if residual is not None else (0, 0, 0, 0))\n with torch.cuda.device(a.device.index):\n _bmm_chunk_bwd_kernel[grid](\n a, dout, out, residual,\n seqlen, chunk_size, k, ngroups if has_groups else 1,\n a.stride(0), a.stride(1), 0 if not has_groups else a.stride(2), a.stride(-1),\n dout.stride(0), dout.stride(1), 0 if not has_groups else dout.stride(2), dout.stride(-2), dout.stride(-1),\n out.stride(0), out.stride(1), 0 if not has_groups else out.stride(2), out.stride(-1),\n residual_strides[0], residual_strides[1], residual_strides[2], residual_strides[3],\n dot_dtype,\n HAS_RESIDUAL=residual is not None,\n )\n return out\n", - "description_1": "Use triton language to implement and autotune forward and backward batch matrix multiplication kernels. The forward kernel (_bmm_chunk_fwd_kernel) has 22 parameters including pointers to the input matrices, output matrix, sequence indices, matrix dimensions, stride information, and several meta-parameters for customization. The backward kernel (_bmm_chunk_bwd_kernel) has 21 parameters, including pointers to input, output gradients, and residual, matrix dimensions, stride information, and meta-parameters. The kernels are called by the wrapper functions _bmm_chunk_fwd and _bmm_chunk_bwd, which handle tensor preparation, grid configuration, and kernel invocation using CUDA.", - "description_2": "Use triton language to create optimized kernels for batched matrix multiplication with configurable block sizes and support for causality and sequence index masking. Implement efficient tensor manipulation and grid management to enable high-performance computations on the GPU.", - "difficulty": 3 - }, - { - "code": "import math\nimport torch\nimport triton\nimport triton.language as tl\nfrom einops import rearrange, repeat\nfrom mamba_ssm.ops.triton.ssd_bmm import _bmm_chunk_fwd, _bmm_chunk_bwd\nfrom packaging import version\n\nTRITON_22 = version.parse(triton.__version__) >= version.parse('2.2.0')\n\n@triton.autotune(\n configs=[\n triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64}, num_stages=3, num_warps=8),\n triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2),\n triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2),\n triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=2),\n ],\n key=['chunk_size', 'hdim', 'dstate', 'IS_CAUSAL'],\n)\n@triton.jit\ndef _chunk_scan_fwd_kernel(\n cb_ptr, x_ptr, z_ptr, out_ptr, out_x_ptr, dt_ptr, dA_cumsum_ptr, seq_idx_ptr, C_ptr, prev_states_ptr, D_ptr,\n chunk_size, hdim, dstate, batch, seqlen, nheads_ngroups_ratio,\n stride_cb_batch, stride_cb_chunk, stride_cb_head, stride_cb_csize_m, stride_cb_csize_k,\n stride_x_batch, stride_x_seqlen, stride_x_head, stride_x_hdim,\n stride_z_batch, stride_z_seqlen, stride_z_head, stride_z_hdim,\n stride_out_batch, stride_out_seqlen, stride_out_head, stride_out_hdim,\n stride_dt_batch, stride_dt_chunk, stride_dt_head, stride_dt_csize,\n stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize,\n stride_seq_idx_batch, stride_seq_idx_seqlen,\n stride_C_batch, stride_C_seqlen, stride_C_head, stride_C_dstate,\n stride_states_batch, stride_states_chunk, stride_states_head, stride_states_hdim, stride_states_dstate,\n stride_D_head,\n IS_CAUSAL: tl.constexpr,\n HAS_D: tl.constexpr,\n D_HAS_HDIM: tl.constexpr,\n HAS_Z: tl.constexpr,\n HAS_SEQ_IDX: tl.constexpr,\n BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,\n BLOCK_SIZE_DSTATE: tl.constexpr,\n IS_TRITON_22: tl.constexpr,\n):\n # Triton kernel function implementation\n\ndef _chunk_scan_fwd(cb, x, dt, dA_cumsum, C, states, D=None, z=None, seq_idx=None):\n batch, seqlen, nheads, headdim = x.shape\n _, _, nchunks, chunk_size = dt.shape\n _, _, ngroups, dstate = C.shape\n assert nheads % ngroups == 0\n assert C.shape == (batch, seqlen, ngroups, dstate)\n assert cb.shape == (batch, nchunks, ngroups, chunk_size, chunk_size)\n if z is not None:\n assert z.shape == x.shape\n if D is not None:\n assert D.shape == (nheads, headdim) or D.shape == (nheads,)\n assert dt.shape == (batch, nheads, nchunks, chunk_size)\n assert dA_cumsum.shape == (batch, nheads, nchunks, chunk_size)\n assert states.shape == (batch, nchunks, nheads, headdim, dstate)\n if seq_idx is not None:\n assert seq_idx.shape == (batch, seqlen)\n # Allocates output.\n out = torch.empty(batch, seqlen, nheads, headdim, device=x.device, dtype=x.dtype)\n if z is not None:\n out_x = torch.empty(batch, seqlen, nheads, headdim, device=x.device, dtype=x.dtype)\n assert out_x.stride() == out.stride()\n else:\n out_x = None\n grid = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv(headdim, META['BLOCK_SIZE_N']),\n batch * nchunks, nheads)\n z_strides = ((z.stride(0), z.stride(1), z.stride(2), z.stride(3))\n if z is not None else (0, 0, 0, 0))\n _chunk_scan_fwd_kernel[grid](\n cb, x, z, out, out_x, dt, dA_cumsum, seq_idx, C, states, D,\n chunk_size, headdim, dstate,\n batch, seqlen, nheads // ngroups,\n cb.stride(0), cb.stride(1), cb.stride(2), cb.stride(3), cb.stride(4),\n x.stride(0), x.stride(1), x.stride(2), x.stride(3),\n z_strides[0], z_strides[1], z_strides[2], z_strides[3],\n out.stride(0), out.stride(1), out.stride(2), out.stride(3),\n dt.stride(0), dt.stride(2), dt.stride(1), dt.stride(3),\n dA_cumsum.stride(0), dA_cumsum.stride(2), dA_cumsum.stride(1), dA_cumsum.stride(3),\n *((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)),\n C.stride(0), C.stride(1), C.stride(2), C.stride(3),\n states.stride(0), states.stride(1), states.stride(2), states.stride(3), states.stride(4),\n D.stride(0) if D is not None else 0,\n True,\n D is not None,\n D.dim() == 2 if D is not None else True,\n BLOCK_SIZE_DSTATE=max(triton.next_power_of_2(dstate), 16),\n HAS_Z=z is not None,\n HAS_SEQ_IDX=seq_idx is not None,\n IS_TRITON_22=TRITON_22,\n )\n return out, out_x\n\nclass ChunkScanFn(torch.autograd.Function):\n\n @staticmethod\n def forward(ctx, B, C, x, dt, dA_cumsum, prev_states, D=None, z=None):\n # Check constraints.\n batch, seqlen, nheads, headdim = x.shape\n _, _, ngroups, dstate = B.shape\n assert B.shape == (batch, seqlen, ngroups, dstate)\n _, _, nchunks, chunk_size = dt.shape\n assert seqlen == nchunks * chunk_size\n assert C.shape == B.shape\n if z is not None:\n assert z.shape == x.shape\n if D is not None:\n assert D.shape == (nheads, headdim) or D.shape == (nheads,)\n assert dt.shape == (batch, nheads, nchunks, chunk_size)\n assert dA_cumsum.shape == (batch, nheads, nchunks, chunk_size)\n assert prev_states.shape == (batch, nchunks, nheads, headdim, dstate)\n if B.stride(-1) != 1:\n B = B.contiguous()\n if C.stride(-1) != 1:\n C = C.contiguous()\n if x.stride(-1) != 1 and x.stride(1) != 1: # Either M or K dimension should be contiguous\n x = x.contiguous()\n if z is not None and z.stride(-1) != 1 and z.stride(1) != 1: # Either M or K dimension should be contiguous\n z = z.contiguous()\n if D is not None and D.stride(-1) != 1:\n D = D.contiguous()\n CB = _bmm_chunk_fwd(C, B, chunk_size)\n out, out_x = _chunk_scan_fwd(CB, x, dt, dA_cumsum, C, prev_states, D=D, z=z)\n ctx.save_for_backward(out if z is None else out_x, B, C, CB, x, dt, dA_cumsum, prev_states, D, z)\n return out\n", - "description_1": "Use Triton language to implement a forward scanning operation with custom configurations, optimizing for GPU performance through efficient data handling and parallel computation.", - "description_2": "Utilize Triton kernels to perform matrix operations with flexible configurations, suitable for integration with deep learning frameworks such as PyTorch.", - "difficulty": 4 - }, - { - "code": "import math\nimport torch\nimport triton\nimport triton.language as tl\n\n@triton.autotune(\n configs=[\n triton.Config({'BLOCK_SIZE_H': 1}),\n triton.Config({'BLOCK_SIZE_H': 2}),\n triton.Config({'BLOCK_SIZE_H': 4}),\n triton.Config({'BLOCK_SIZE_H': 8}),\n triton.Config({'BLOCK_SIZE_H': 16}),\n triton.Config({'BLOCK_SIZE_H': 32}),\n triton.Config({'BLOCK_SIZE_H': 64}),\n ],\n key=['chunk_size', 'nheads'],\n)\n@triton.jit\ndef _chunk_cumsum_fwd_kernel(\n dt_ptr, A_ptr, dt_bias_ptr, dt_out_ptr, dA_cumsum_ptr,\n batch, seqlen, nheads, chunk_size,\n dt_min, dt_max,\n stride_dt_batch, stride_dt_seqlen, stride_dt_head,\n stride_A_head,\n stride_dt_bias_head,\n stride_dt_out_batch, stride_dt_out_chunk, stride_dt_out_head, stride_dt_out_csize,\n stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize,\n DT_SOFTPLUS: tl.constexpr,\n HAS_DT_BIAS: tl.constexpr,\n BLOCK_SIZE_H: tl.constexpr, BLOCK_SIZE_CHUNK: tl.constexpr,\n):\n pid_b = tl.program_id(axis=0)\n pid_c = tl.program_id(axis=1)\n pid_h = tl.program_id(axis=2)\n dt_ptr += pid_b * stride_dt_batch + pid_c * chunk_size * stride_dt_seqlen\n dt_out_ptr += pid_b * stride_dt_out_batch + pid_c * stride_dt_out_chunk\n dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk\n\n offs_h = pid_h * BLOCK_SIZE_H + tl.arange(0, BLOCK_SIZE_H)\n offs_c = tl.arange(0, BLOCK_SIZE_CHUNK)\n dt_ptrs = dt_ptr + (offs_h[:, None] * stride_dt_head + offs_c[None, :] * stride_dt_seqlen)\n A_ptrs = A_ptr + offs_h * stride_A_head\n dt_out_ptrs = dt_out_ptr + (offs_h[:, None] * stride_dt_out_head + offs_c[None, :] * stride_dt_out_csize)\n dA_cs_ptrs = dA_cumsum_ptr + (offs_h[:, None] * stride_dA_cs_head + offs_c[None, :] * stride_dA_cs_csize)\n chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)\n\n dt = tl.load(dt_ptrs, mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit), other=0.0).to(tl.float32)\n if HAS_DT_BIAS:\n dt_bias = tl.load(dt_bias_ptr + offs_h * stride_dt_bias_head, mask=offs_h < nheads, other=0.0).to(tl.float32)\n dt += dt_bias[:, None]\n if DT_SOFTPLUS:\n dt = tl.where(dt <= 20.0, tl.math.log1p(tl.exp(dt)), dt)\n dt = tl.minimum(tl.maximum(dt, dt_min), dt_max)\n dt = tl.where((offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit), dt, 0.0)\n tl.store(dt_out_ptrs, dt, mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size))\n A = tl.load(A_ptrs, mask=offs_h < nheads, other=0.0).to(tl.float32)\n dA = dt * A[:, None]\n dA_cs = tl.cumsum(dA, axis=1)\n tl.store(dA_cs_ptrs, dA_cs, mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size))\n\n\ndef _chunk_cumsum_fwd(dt, A, chunk_size, dt_bias=None, dt_softplus=False, dt_limit=(0.0, float(\"inf\"))):\n batch, seqlen, nheads = dt.shape\n assert A.shape == (nheads,)\n if dt_bias is not None:\n assert dt_bias.shape == (nheads,)\n nchunks = math.ceil(seqlen / chunk_size)\n dt_out = torch.empty(batch, nheads, nchunks, chunk_size, device=dt.device, dtype=torch.float32)\n dA_cumsum = torch.empty(batch, nheads, nchunks, chunk_size, device=dt.device, dtype=torch.float32)\n grid_chunk_cs = lambda META: (batch, nchunks, triton.cdiv(nheads, META['BLOCK_SIZE_H']))\n with torch.cuda.device(dt.device.index):\n _chunk_cumsum_fwd_kernel[grid_chunk_cs](\n dt, A, dt_bias, dt_out, dA_cumsum,\n batch, seqlen, nheads, chunk_size,\n dt_limit[0], dt_limit[1],\n dt.stride(0), dt.stride(1), dt.stride(2),\n A.stride(0),\n dt_bias.stride(0) if dt_bias is not None else 0,\n dt_out.stride(0), dt_out.stride(2), dt_out.stride(1), dt_out.stride(3),\n dA_cumsum.stride(0), dA_cumsum.stride(2), dA_cumsum.stride(1), dA_cumsum.stride(3),\n dt_softplus,\n HAS_DT_BIAS=dt_bias is not None,\n BLOCK_SIZE_CHUNK=triton.next_power_of_2(chunk_size),\n )\n return dA_cumsum, dt_out\n", - "description_1": "Use triton language to implement a forward kernel for chunk-wise cumulative sum. The kernel takes 20 parameters: pointers to matrices (dt_ptr, A_ptr, dt_bias_ptr, dt_out_ptr, dA_cumsum_ptr), matrix dimensions (batch, seqlen, nheads, chunk_size), min and max values for dt (dt_min, dt_max), strides for accessing elements in matrices, and meta-parameters (DT_SOFTPLUS, HAS_DT_BIAS, BLOCK_SIZE_H, BLOCK_SIZE_CHUNK). The kernel computes a cumulative sum of the product of dt and A, with optional bias and softplus transformation, and stores the result in dA_cumsum_ptr.", - "description_2": "Use triton language to implement a function that calls the forward kernel for chunk-wise cumulative sum. The function takes 6 parameters: dt, A, chunk_size, optional dt_bias, dt_softplus flag, and dt_limit. It prepares output tensors, calculates grid dimensions, and launches the kernel with appropriate arguments.", - "difficulty": 3 - }, - { - "code": "import triton\nimport triton.language as tl\nimport torch\nfrom einops import rearrange\nfrom torch import Tensor\n\nTRITON_22 = version.parse(triton.__version__) >= version.parse('2.2.0')\n\n\ndef init_to_zero(names):\n return lambda nargs: [nargs[name].zero_() for name in names if nargs[name] is not None]\n\n\n@triton.autotune(\n configs=[\n triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64}, num_stages=3, num_warps=8, pre_hook=init_to_zero([\"ddt_ptr\"])),\n triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero([\"ddt_ptr\"])),\n triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero([\"ddt_ptr\"])),\n triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero([\"ddt_ptr\"])),\n triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero([\"ddt_ptr\"])),\n triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero([\"ddt_ptr\"])),\n triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=4, pre_hook=init_to_zero([\"ddt_ptr\"])),\n triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=4, pre_hook=init_to_zero([\"ddt_ptr\"])),\n triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero([\"ddt_ptr\"])),\n ],\n key=['chunk_size', 'hdim', 'dstate'],\n)\n@triton.jit\ndef _chunk_scan_chunk_state_bwd_dx_kernel(\n x_ptr, cb_ptr, dout_ptr, dt_ptr, dA_cumsum_ptr, seq_idx_ptr, D_ptr,\n b_ptr, dstates_ptr,\n dx_ptr, ddt_ptr, dD_ptr,\n chunk_size, hdim, dstate,\n batch, seqlen, nheads_ngroups_ratio,\n stride_x_batch, stride_x_seqlen, stride_x_head, stride_x_hdim,\n stride_cb_batch, stride_cb_chunk, stride_cb_head, stride_cb_csize_m, stride_cb_csize_k,\n stride_dout_batch, stride_dout_seqlen, stride_dout_head, stride_dout_hdim,\n stride_dt_batch, stride_dt_chunk, stride_dt_head, stride_dt_csize,\n stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize,\n stride_seq_idx_batch, stride_seq_idx_seqlen,\n stride_D_head,\n stride_b_batch, stride_b_seqlen, stride_b_head, stride_b_dstate,\n stride_dstates_batch, stride_dstates_chunk, stride_dstates_head, stride_dstates_hdim, stride_dstates_dstate,\n stride_dx_batch, stride_dx_seqlen, stride_dx_head, stride_dx_hdim,\n stride_ddt_batch, stride_ddt_chunk, stride_ddt_head, stride_ddt_csize,\n stride_dD_batch, stride_dD_chunk, stride_dD_head, stride_dD_csize, stride_dD_hdim,\n HAS_D: tl.constexpr,\n D_HAS_HDIM: tl.constexpr,\n HAS_SEQ_IDX: tl.constexpr,\n BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,\n BLOCK_SIZE_DSTATE: tl.constexpr,\n IS_TRITON_22: tl.constexpr,\n):\n pid_bc = tl.program_id(axis=1)\n pid_c = pid_bc // batch\n pid_b = pid_bc - pid_c * batch\n pid_h = tl.program_id(axis=2)\n num_pid_n = tl.cdiv(hdim, BLOCK_SIZE_N)\n pid_m = tl.program_id(axis=0) // num_pid_n\n pid_n = tl.program_id(axis=0) % num_pid_n\n x_ptr += pid_b * stride_x_batch + pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head\n cb_ptr += pid_b * stride_cb_batch + pid_c * stride_cb_chunk + (pid_h // nheads_ngroups_ratio) * stride_cb_head\n dout_ptr += pid_b * stride_dout_batch + pid_c * chunk_size * stride_dout_seqlen + pid_h * stride_dout_head\n dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head\n ddt_ptr += pid_b * stride_ddt_batch + pid_c * stride_ddt_chunk + pid_h * stride_ddt_head\n dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head\n b_ptr += pid_b * stride_b_batch + pid_c * chunk_size * stride_b_seqlen + (pid_h // nheads_ngroups_ratio) * stride_b_head\n dstates_ptr += pid_b * stride_dstates_batch + pid_c * stride_dstates_chunk + pid_h * stride_dstates_head\n if HAS_SEQ_IDX:\n seq_idx_ptr += pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen\n\n offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n\n chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)\n\n acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n\n dA_cs_m = tl.load(dA_cumsum_ptr + offs_m * stride_dA_cs_csize, mask=offs_m < chunk_size_limit, other=0.0).to(tl.float32)\n\n dA_cs_last = tl.load(dA_cumsum_ptr + (chunk_size - 1) * stride_dA_cs_csize).to(tl.float32)\n if not HAS_SEQ_IDX:\n scale = tl.exp(dA_cs_last - dA_cs_m)\n else:\n seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen, mask=offs_m < chunk_size_limit, other=-1)\n seq_idx_last = tl.load(seq_idx_ptr + (chunk_size_limit - 1) * stride_seq_idx_seqlen)\n scale = tl.where(seq_idx_m == seq_idx_last, tl.exp(dA_cs_last - dA_cs_m), 0.0)\n\n offs_dstate = tl.arange(0, BLOCK_SIZE_DSTATE if IS_TRITON_22 and BLOCK_SIZE_DSTATE <= 128 else BLOCK_SIZE_K)\n b_ptrs = b_ptr + (offs_m[:, None] * stride_b_seqlen + offs_dstate[None, :] * stride_b_dstate)\n dstates_ptrs = dstates_ptr + (offs_n[None, :] * stride_dstates_hdim + offs_dstate[:, None] * stride_dstates_dstate)\n if IS_TRITON_22 and BLOCK_SIZE_DSTATE <= 128:\n b = tl.load(b_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_dstate[None, :] < dstate), other=0.0)\n dstates = tl.load(dstates_ptrs, mask=(offs_dstate[:, None] < dstate) & (offs_n[None, :] < hdim), other=0.0)\n dstates = dstates.to(b_ptr.dtype.element_ty)\n acc = tl.dot(b, dstates) * scale[:, None]\n else:\n for k in range(0, dstate, BLOCK_SIZE_K):\n b = tl.load(b_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_dstate[None, :] < dstate - k), other=0.0)\n dstates = tl.load(dstates_ptrs, mask=(offs_dstate[:, None] < dstate - k) & (offs_n[None, :] < hdim), other=0.0)\n dstates = dstates.to(b_ptr.dtype.element_ty)\n acc += tl.dot(b, dstates)\n b_ptrs += BLOCK_SIZE_K * stride_b_dstate\n dstates_ptrs += BLOCK_SIZE_K * stride_dstates_dstate\n acc *= scale[:, None]\n\n offs_k = tl.arange(0, BLOCK_SIZE_K)\n cb_ptrs = cb_ptr + (offs_m[:, None] * stride_cb_csize_m + offs_k[None, :] * stride_cb_csize_k)\n dout_ptrs = dout_ptr + (offs_k[:, None] * stride_dout_seqlen + offs_n[None, :] * stride_dout_hdim)\n dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize\n K_MAX = chunk_size_limit\n K_MIN = pid_m * BLOCK_SIZE_M\n cb_ptrs += K_MIN * stride_cb_csize_k\n dout_ptrs += K_MIN * stride_dout_seqlen\n dA_cumsum_ptrs += K_MIN * stride_dA_cs_csize\n for k in range(K_MIN, K_MAX, BLOCK_SIZE_K):\n k = tl.multiple_of(k, BLOCK_SIZE_K)\n cb = tl.load(cb_ptrs, mask=(offs_m[:, None] < chunk_size) & (offs_k[None, :] < K_MAX - k), other=0.0)\n dout = tl.load(dout_ptrs, mask=(offs_k[:, None] < K_MAX - k) & (offs_n[None, :] < hdim), other=0.0)\n dA_cs_k = tl.load(dA_cumsum_ptrs, mask=offs_k < K_MAX - k, other=0.0).to(tl.float32)\n cb *= tl.exp(dA_cs_k[None, :] - dA_cs_m[:, None])\n mask = (k + offs_k[None, :] >= offs_m[:, None]) & (k + offs_k[None, :] < K_MAX)\n cb = tl.where(mask, cb, 0.0)\n cb = cb.to(dout_ptr.dtype.element_ty)\n acc += tl.dot(cb, dout)\n cb_ptrs += BLOCK_SIZE_K * stride_cb_csize_k\n dout_ptrs += BLOCK_SIZE_K * stride_dout_seqlen\n dA_cumsum_ptrs += BLOCK_SIZE_K * stride_dA_cs_csize\n\n offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n dt_ptrs = dt_ptr + offs_m * stride_dt_csize\n dt_m = tl.load(dt_ptrs, mask=offs_m < chunk_size_limit, other=0.0).to(tl.float32)\n dx = acc * dt_m[:, None]\n dx_ptr += pid_b * stride_dx_batch + pid_c * chunk_size * stride_dx_seqlen + pid_h * stride_dx_head\n dx_ptrs = dx_ptr + (offs_m[:, None] * stride_dx_seqlen + offs_n[None, :] * stride_dx_hdim)\n if HAS_D:\n dout_res_ptrs = dout_ptr + (offs_m[:, None] * stride_dout_seqlen + offs_n[None, :] * stride_dout_hdim)\n dout_res = tl.load(dout_res_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0).to(tl.float32)\n if D_HAS_HDIM:\n D = tl.load(D_ptr + pid_h * stride_D_head + offs_n, mask=offs_n < hdim, other=0.0).to(tl.float32)\n else:\n D = tl.load(D_ptr + pid_h * stride_D_head).to(tl.float32)\n dx += dout_res * D\n tl.store(dx_ptrs, dx, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim))\n\n x_ptrs = x_ptr + (offs_m[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim)\n x = tl.load(x_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0).to(tl.float32)\n if HAS_D:\n dD_ptr += pid_b * stride_dD_batch + pid_c * stride_dD_chunk + pid_h * stride_dD_head + pid_m * stride_dD_csize\n if D_HAS_HDIM:\n dD_ptrs = dD_ptr + offs_n * stride_dD_hdim\n dD = tl.sum(dout_res * x, axis=0)\n tl.store(dD_ptrs, dD, mask=offs_n < hdim)\n else:\n dD = tl.sum(dout_res * x)\n tl.store(dD_ptr, dD)\n ddt = tl.sum(acc * x, axis=1)\n ddt_ptrs = ddt_ptr + offs_m * stride_ddt_csize\n tl.atomic_add(ddt_ptrs, ddt, mask=offs_m < chunk_size)\n\n\ndef _chunk_scan_chunk_state_bwd_dx(x, dt, dA_cumsum, B, CB, dout, dstates, D=None, seq_idx=None, dx=None):\n batch, seqlen, nheads, headdim = x.shape\n _, _, nchunks, chunk_size = dt.shape\n _, _, ngroups, dstate = B.shape\n assert nheads % ngroups == 0\n assert B.shape == (batch, seqlen, ngroups, dstate)\n assert CB.shape == (batch, nchunks, ngroups, chunk_size, chunk_size)\n assert dt.shape == (batch, nheads, nchunks, chunk_size)\n assert dA_cumsum.shape == dt.shape\n assert dout.shape == x.shape\n assert dstates.shape == (batch, nchunks, nheads, headdim, dstate)\n if seq_idx is not None:\n assert seq_idx.shape == (batch, seqlen)\n if D is not None:\n assert D.shape == (nheads, headdim) or D.shape == (nheads,)\n assert D.stride(-1) == 1\n BLOCK_SIZE_min = 32\n dD = torch.empty(triton.cdiv(chunk_size, BLOCK_SIZE_min), batch, nchunks, nheads,\n headdim if D.dim() == 2 else 1, device=D.device, dtype=torch.float32)\n else:\n dD = None\n dD_strides = ((dD.stride(0), dD.stride(1), dD.stride(2), dD.stride(3), dD.stride(4))\n if D is not None else (0, 0, 0, 0, 0))\n if dx is None:\n dx = torch.empty_like(x)\n else:\n assert dx.shape == x.shape\n ddt = torch.empty(batch, nheads, nchunks, chunk_size, device=dout.device, dtype=torch.float32)\n grid_dx = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv(headdim, META['BLOCK_SIZE_N']),\n batch * nchunks, nheads)\n with torch.cuda.device(x.device.index):\n _chunk_scan_chunk_state_bwd_dx_kernel[grid_dx](\n x, CB, dout, dt, dA_cumsum, seq_idx, D, B, dstates, dx, ddt, dD,\n chunk_size, headdim, dstate,\n batch, seqlen, nheads // ngroups,\n x.stride(0), x.stride(1), x.stride(2), x.stride(3),\n CB.stride(0), CB.stride(1), CB.stride(2), CB.stride(-1), CB.stride(-2),\n dout.stride(0), dout.stride(1), dout.stride(2), dout.stride(3),\n dt.stride(0), dt.stride(2), dt.stride(1), dt.stride(3),\n dA_cumsum.stride(0), dA_cumsum.stride(2), dA_cumsum.stride(1), dA_cumsum.stride(3),\n *((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)),\n D.stride(0) if D is not None else 0,\n B.stride(0), B.stride(1), B.stride(2), B.stride(3),\n dstates.stride(0), dstates.stride(1), dstates.stride(2), dstates.stride(3), dstates.stride(4),\n dx.stride(0), dx.stride(1), dx.stride(2), dx.stride(3),\n ddt.stride(0), ddt.stride(2), ddt.stride(1), ddt.stride(3),\n dD_strides[1], dD_strides[2], dD_strides[3], dD_strides[0], dD_strides[4],\n D is not None,\n D.dim() == 2 if D is not None else True,\n HAS_SEQ_IDX=seq_idx is not None,\n BLOCK_SIZE_DSTATE=max(triton.next_power_of_2(dstate), 16),\n IS_TRITON_22=TRITON_22\n )\n if D is not None:\n BLOCK_SIZE_actual = _chunk_scan_chunk_state_bwd_dx_kernel.best_config.kwargs[\"BLOCK_SIZE_M\"]\n n_valid_blocks = (chunk_size + BLOCK_SIZE_actual - 1) // BLOCK_SIZE_actual\n dD = dD[:n_valid_blocks].sum(dim=(0, 1, 2)).to(dtype=D.dtype)\n if D.dim() == 1:\n dD = rearrange(dD, \"h 1 -> h\")\n return dx, ddt.to(dtype=dt.dtype), dD\n", - "description_1": "Use triton language to implement a kernel function '_chunk_scan_chunk_state_bwd_dx_kernel' which performs backward operation for a chunked scan using several matrix multiplications and data reductions. The kernel accepts 66 parameters including pointers to matrices, matrix dimensions, and strides for memory access, alongside several meta-parameters to control the kernel's behavior depending on the compilation environment and feature flags.", - "description_2": "Use triton language to create a kernel for the backward pass of a chunked scan operation, involving matrix operations controlled by a variety of parameters for pointers, sizes, and meta-configurations.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.autotune(\n configs=[\n triton.Config({'BLOCK_SIZE': 64}),\n triton.Config({'BLOCK_SIZE': 128}),\n triton.Config({'BLOCK_SIZE': 256}),\n triton.Config({'BLOCK_SIZE': 512}),\n triton.Config({'BLOCK_SIZE': 1024}),\n triton.Config({'BLOCK_SIZE': 2048}),\n ],\n key=['dim'],\n)\n@triton.jit\ndef _state_passing_fwd_kernel(\n states_ptr, out_ptr, final_states_ptr, dA_cs_ptr, initstates_ptr, seq_idx_ptr,\n dim, nchunks, seqlen, chunk_size,\n stride_states_batch, stride_states_chunk, stride_states_head, stride_states_dim,\n stride_out_batch, stride_out_chunk, stride_out_head, stride_out_dim,\n stride_final_states_batch, stride_final_states_head, stride_final_states_dim,\n stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head,\n stride_initstates_batch, stride_initstates_head, stride_initstates_dim,\n stride_seq_idx_batch, stride_seq_idx_seqlen,\n HAS_INITSTATES: tl.constexpr,\n HAS_SEQ_IDX: tl.constexpr,\n BLOCK_SIZE: tl.constexpr,\n):\n pid_b = tl.program_id(axis=1)\n pid_h = tl.program_id(axis=2)\n pid_m = tl.program_id(axis=0)\n states_ptr += pid_b * stride_states_batch + pid_h * stride_states_head\n dA_cs_ptr += pid_b * stride_dA_cs_batch + pid_h * stride_dA_cs_head\n out_ptr += pid_b * stride_out_batch + pid_h * stride_out_head\n final_states_ptr += pid_b * stride_final_states_batch + pid_h * stride_final_states_head\n if HAS_INITSTATES:\n initstates_ptr += pid_b * stride_initstates_batch + pid_h * stride_initstates_head\n if HAS_SEQ_IDX:\n seq_idx_ptr += pid_b * stride_seq_idx_batch\n\n offs_m = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)\n states_ptrs = states_ptr + offs_m * stride_states_dim\n out_ptrs = out_ptr + offs_m * stride_out_dim\n final_states_ptrs = final_states_ptr + offs_m * stride_final_states_dim\n\n if not HAS_INITSTATES:\n states = tl.zeros((BLOCK_SIZE, ), dtype=tl.float32)\n else:\n initstates_ptrs = initstates_ptr + offs_m * stride_initstates_dim\n states = tl.load(initstates_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)\n tl.store(out_ptrs, states, mask=offs_m < dim)\n out_ptrs += stride_out_chunk\n seq_idx = 0\n for c in range(nchunks):\n new_states = tl.load(states_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)\n dA_cs = tl.load(dA_cs_ptr).to(tl.float32)\n scale = tl.exp(dA_cs)\n if HAS_SEQ_IDX:\n seq_idx_new = tl.load(seq_idx_ptr + (min((c + 1) * chunk_size, seqlen) - 1) * stride_seq_idx_seqlen)\n scale = tl.where(seq_idx_new == seq_idx, scale, 0.0)\n seq_idx = seq_idx_new\n states = scale * states + new_states\n if c < nchunks - 1:\n tl.store(out_ptrs, states, mask=offs_m < dim)\n else:\n tl.store(final_states_ptrs, states, mask=offs_m < dim)\n states_ptrs += stride_states_chunk\n dA_cs_ptr += stride_dA_cs_chunk\n out_ptrs += stride_out_chunk\n\n@triton.autotune(\n configs=[\n triton.Config({'BLOCK_SIZE': 64}),\n triton.Config({'BLOCK_SIZE': 128}),\n triton.Config({'BLOCK_SIZE': 256}),\n triton.Config({'BLOCK_SIZE': 512}),\n triton.Config({'BLOCK_SIZE': 1024}),\n triton.Config({'BLOCK_SIZE': 2048}),\n ],\n key=['dim'],\n)\n@triton.jit\ndef _state_passing_bwd_kernel(\n dout_ptr, out_ptr, dA_cs_ptr, dfinal_states_ptr, seq_idx_ptr,\n dstates_ptr, ddA_cs_ptr, dinitstates_ptr, states_converted_ptr,\n dim, nchunks, seqlen, chunk_size,\n stride_dout_batch, stride_dout_chunk, stride_dout_head, stride_dout_dim,\n stride_out_batch, stride_out_chunk, stride_out_head, stride_out_dim,\n stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head,\n stride_dfinal_states_batch, stride_dfinal_states_head, stride_dfinal_states_dim,\n stride_seq_idx_batch, stride_seq_idx_seqlen,\n stride_dstates_batch, stride_dstates_chunk, stride_dstates_head, stride_dstates_dim,\n stride_ddA_cs_batch, stride_ddA_cs_chunk, stride_ddA_cs_head,\n stride_dinitstates_batch, stride_dinitstates_head, stride_dinitstates_dim,\n CONVERT_STATES: tl.constexpr,\n HAS_DFINAL_STATES: tl.constexpr,\n HAS_DINITSTATES: tl.constexpr,\n HAS_SEQ_IDX: tl.constexpr,\n BLOCK_SIZE: tl.constexpr,\n):\n pid_b = tl.program_id(axis=1)\n pid_h = tl.program_id(axis=2)\n pid_m = tl.program_id(axis=0)\n dstates_ptr += pid_b * stride_dstates_batch + pid_h * stride_dstates_head + (nchunks - 1) * stride_dstates_chunk\n dA_cs_ptr += pid_b * stride_dA_cs_batch + pid_h * stride_dA_cs_head + (nchunks - 1) * stride_dA_cs_chunk\n ddA_cs_ptr += pid_b * stride_ddA_cs_batch + pid_h * stride_ddA_cs_head + (nchunks - 1) * stride_ddA_cs_chunk + pid_m\n out_ptr += pid_b * stride_out_batch + pid_h * stride_out_head + (nchunks - 1) * stride_out_chunk\n dout_ptr += pid_b * stride_dout_batch + pid_h * stride_dout_head + (nchunks - 1) * stride_dout_chunk\n if CONVERT_STATES:\n states_converted_ptr += pid_b * stride_out_batch + pid_h * stride_out_head + (nchunks - 1) * stride_out_chunk\n if HAS_DFINAL_STATES:\n dfinal_states_ptr += pid_b * stride_dfinal_states_batch + pid_h * stride_dfinal_states_head\n if HAS_DINITSTATES:\n dinitstates_ptr += pid_b * stride_dinitstates_batch + pid_h * stride_dinitstates_head\n if HAS_SEQ_IDX:\n seq_idx_ptr += pid_b * stride_seq_idx_batch\n\n offs_m = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)\n dstates_ptrs = dstates_ptr + offs_m * stride_dstates_dim\n out_ptrs = out_ptr + offs_m * stride_out_dim\n dout_ptrs = dout_ptr + offs_m * stride_dout_dim\n if CONVERT_STATES:\n states_converted_ptrs = states_converted_ptr + offs_m * stride_out_dim\n\n if HAS_DFINAL_STATES:\n dstates = tl.load(dfinal_states_ptr + offs_m * stride_dfinal_states_dim, mask=offs_m < dim, other=0.0).to(tl.float32)\n else:\n dstates = tl.zeros((BLOCK_SIZE, ), dtype=tl.float32)\n tl.store(dstates_ptrs, dstates, mask=offs_m < dim)\n if HAS_SEQ_IDX:\n seq_idx = tl.load(seq_idx_ptr + (seqlen - 1) * stride_seq_idx_seqlen)\n dstates_ptrs -= stride_dstates_chunk\n for c in range(nchunks - 1):\n dA_cs = tl.load(dA_cs_ptr).to(tl.float32)\n scale = tl.exp(dA_cs)\n if HAS_SEQ_IDX:\n seq_idx_new = tl.load(seq_idx_ptr + (((nchunks - c - 1) * chunk_size - 1) * stride_seq_idx_seqlen))\n scale = tl.where(seq_idx_new == seq_idx, scale, 0.0)\n seq_idx = seq_idx_new\n out = tl.load(out_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)\n if CONVERT_STATES:\n tl.store(states_converted_ptrs, out, mask=offs_m < dim)\n ddA = tl.sum(out * dstates) * scale\n tl.store(ddA_cs_ptr, ddA)\n dout = tl.load(dout_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)\n dstates = scale * dstates + dout\n tl.store(dstates_ptrs, dstates, mask=offs_m < dim)\n dout_ptrs -= stride_dout_chunk\n dstates_ptrs -= stride_dstates_chunk\n dA_cs_ptr -= stride_dA_cs_chunk\n ddA_cs_ptr -= stride_ddA_cs_chunk\n out_ptrs -= stride_out_chunk\n if CONVERT_STATES:\n states_converted_ptrs -= stride_out_chunk\n if CONVERT_STATES:\n out = tl.load(out_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)\n tl.store(states_converted_ptrs, out, mask=offs_m < dim)\n if not HAS_DINITSTATES:\n tl.store(ddA_cs_ptr, 0.0)\n else:\n dA_cs = tl.load(dA_cs_ptr).to(tl.float32)\n scale = tl.exp(dA_cs)\n if HAS_SEQ_IDX:\n scale = tl.where(seq_idx == 0, scale, 0.0)\n out = tl.load(out_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)\n ddA = tl.sum(out * dstates) * scale\n tl.store(ddA_cs_ptr, ddA)\n dout = tl.load(dout_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)\n dstates = scale * dstates + dout\n tl.store(dinitstates_ptr + offs_m * stride_dinitstates_dim, dstates, mask=offs_m < dim)\n\ndef _state_passing_fwd(states, dA_chunk_cumsum, initial_states=None, seq_idx=None, chunk_size=None, out_dtype=None):\n batch, nchunks, nheads, dim = states.shape\n assert dA_chunk_cumsum.shape == (batch, nheads, nchunks)\n if initial_states is not None:\n assert initial_states.shape == (batch, nheads, dim)\n if seq_idx is not None:\n assert chunk_size is not None\n seqlen = seq_idx.shape[-1]\n assert seq_idx.shape == (batch, seqlen)\n out_dtype = states.dtype if out_dtype is None else out_dtype\n out = torch.empty((batch, nchunks, nheads, dim), device=states.device, dtype=out_dtype)\n final_states = torch.empty((batch, nheads, dim), device=states.device, dtype=torch.float32)\n grid = lambda META: (triton.cdiv(dim, META['BLOCK_SIZE']), batch, nheads)\n with torch.cuda.device(states.device.index):\n _state_passing_fwd_kernel[grid](\n states, out, final_states, dA_chunk_cumsum, initial_states, seq_idx,\n dim, nchunks, seqlen if seq_idx is not None else 0, chunk_size if seq_idx is not None else 0,\n states.stride(0), states.stride(1), states.stride(2), states.stride(3),\n out.stride(0), out.stride(1), out.stride(2), out.stride(3),\n final_states.stride(0), final_states.stride(1), final_states.stride(2),\n dA_chunk_cumsum.stride(0), dA_chunk_cumsum.stride(2), dA_chunk_cumsum.stride(1),\n *((initial_states.stride(0), initial_states.stride(1), initial_states.stride(2))\n if initial_states is not None else (0, 0, 0)),\n *((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)),\n HAS_INITSTATES=initial_states is not None,\n HAS_SEQ_IDX=seq_idx is not None,\n )\n return out, final_states\n\ndef _state_passing_bwd(states, dA_chunk_cumsum, dout, dfinal_states=None, seq_idx=None, has_initial_states=None, dstates_dtype=None, states_dtype=None, chunk_size=None):\n batch, nchunks, nheads, dim = states.shape\n assert dA_chunk_cumsum.shape == (batch, nheads, nchunks)\n assert dout.shape == (batch, nchunks, nheads, dim)\n if seq_idx is not None:\n assert chunk_size is not None\n seqlen = seq_idx.shape[-1]\n assert seq_idx.shape == (batch, seqlen)\n dstates = torch.empty_like(dout, dtype=dstates_dtype if dstates_dtype is not None else dout.dtype)\n if states_dtype is not None and states_dtype != states.dtype:\n states_converted = torch.empty_like(states, dtype=dstates_dtype if dstates_dtype is not None else dout.dtype)\n assert states_converted.stride() == states.stride()\n else:\n states_converted = None\n if has_initial_states:\n dinitstates = torch.empty_like(dstates[:, 0])\n else:\n dinitstates = None\n if dfinal_states is not None:\n assert dfinal_states.shape == (batch, nheads, dim)\n BLOCK_SIZE_min = 64\n n_blocks = (dim + BLOCK_SIZE_min - 1) // BLOCK_SIZE_min\n ddA_chunk_cumsum = torch.empty(batch, nheads, nchunks, n_blocks,\n dtype=torch.float32, device=dA_chunk_cumsum.device)\n grid = lambda META: (triton.cdiv(dim, META['BLOCK_SIZE']), batch, nheads)\n with torch.cuda.device(dout.device.index):\n _state_passing_bwd_kernel[grid](\n dout, states, dA_chunk_cumsum, dfinal_states, seq_idx,\n dstates, ddA_chunk_cumsum, dinitstates, states_converted,\n dim, nchunks, seqlen if seq_idx is not None else 0, chunk_size if seq_idx is not None else 0,\n dout.stride(0), dout.stride(1), dout.stride(2), dout.stride(3),\n states.stride(0), states.stride(1), states.stride(2), states.stride(3),\n dA_chunk_cumsum.stride(0), dA_chunk_cumsum.stride(2), dA_chunk_cumsum.stride(1),\n *((dfinal_states.stride(0), dfinal_states.stride(1), dfinal_states.stride(2))\n if dfinal_states is not None else (0, 0, 0)),\n *((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)),\n dstates.stride(0), dstates.stride(1), dstates.stride(2), dstates.stride(3),\n ddA_chunk_cumsum.stride(0), ddA_chunk_cumsum.stride(2), ddA_chunk_cumsum.stride(1),\n *((dinitstates.stride(0), dinitstates.stride(1), dinitstates.stride(2))\n if dinitstates is not None else (0, 0, 0)),\n CONVERT_STATES=states_converted is not None,\n HAS_DFINAL_STATES=dfinal_states is not None,\n HAS_DINITSTATES=dinitstates is not None,\n HAS_SEQ_IDX=seq_idx is not None,\n )\n BLOCK_SIZE_actual = _state_passing_bwd_kernel.best_config.kwargs[\"BLOCK_SIZE\"]\n n_valid_blocks = (dim + BLOCK_SIZE_actual - 1) // BLOCK_SIZE_actual\n ddA_chunk_cumsum = ddA_chunk_cumsum[..., :n_valid_blocks].sum(dim=-1).to(dtype=dA_chunk_cumsum.dtype)\n if states_dtype is not None and states_dtype == states.dtype:\n states_converted = states\n return (dstates, ddA_chunk_cumsum, dinitstates) if states_dtype is None else (dstates, ddA_chunk_cumsum, dinitstates, states_converted)\n", - "description_1": "Use triton language to implement forward and backward state-passing kernels. The forward kernel takes pointers to matrices, matrix dimensions, strides, meta-parameters, and computes the forward pass storing results in the output pointers. The backward kernel takes similar parameters and computes gradients required for backpropagation.", - "description_2": "Use triton language to perform state-passing in the forward pass and compute gradients in the backward pass for a neural network model, handling dimensions and configurations through meta-parameters.", - "difficulty": 4 - }, - { - "code": "import math\nimport triton\nimport triton.language as tl\n\nsqrt2pi = math.sqrt(2.0 / math.pi)\nsqrt2 = math.sqrt(2.0)\n\n@triton.jit\ndef tanh(x):\n \"\"\"Tanh activation function\"\"\"\n return tl.libdevice.tanh(x)\n\n@triton.jit\ndef relu(x):\n \"\"\"Relu activation function\"\"\"\n return tl.maximum(0, x)\n\n@triton.jit\ndef fast_gelu(x):\n \"\"\"Fast approximation of the gelu function. May slightly decrease accuracy.\"\"\"\n return 0.5 * x * (1 + tanh(sqrt2pi * (x + 0.044715 * x * x * x)))\n\n@triton.jit\ndef gelu(x):\n \"\"\"Gaussian Error Linear Unit (GELU)\"\"\"\n return x * 0.5 * (1.0 + tl.libdevice.erf(x / sqrt2))\n", - "description_1": "Use triton language to implement four activation functions: tanh, relu, fast_gelu, and gelu. Each function takes a single parameter 'x', which is a tensor. The 'tanh' function computes the hyperbolic tangent of 'x'. The 'relu' function applies the rectified linear unit operation, returning the maximum of 0 and 'x'. The 'fast_gelu' function provides a fast approximation of the Gaussian Error Linear Unit using the tanh function. The 'gelu' function computes the Gaussian Error Linear Unit using the error function.", - "description_2": "Use triton language to create activation functions including tanh, relu, fast_gelu, and gelu, each operating on a tensor input 'x'.", - "difficulty": 1 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\nfrom torch.cuda.amp import custom_fwd\n\n@triton.jit\ndef _fwd_kernel(\n head_size,\n m_size,\n n_size,\n cache_key_m_size,\n cache_key_n_size,\n q_ptr,\n k_ptr,\n v_ptr,\n sm_scale,\n attention_mask_ptr,\n output_ptr,\n q_batch_stride,\n q_head_stride,\n q_m_stride,\n q_k_stride,\n k_batch_stride,\n k_head_stride,\n k_n_stride,\n k_k_stride,\n v_batch_stride,\n v_head_stride,\n v_k_stride,\n v_n_stride,\n output_batch_stride,\n output_head_stride,\n output_row_stride,\n output_col_stride,\n attention_mask_batch_stride,\n attention_mask_head_stride,\n attention_mask_m_stride,\n attention_mask_n_stride,\n min_clamp_value,\n attention_mask_batch_size,\n attention_mask_head_size,\n attention_mask_m_size,\n attention_mask_n_size,\n HAS_MASK: tl.constexpr,\n IS_MATRIX_MASK: tl.constexpr,\n IS_CAUSAL: tl.constexpr,\n BLOCK_DHEAD_SIZE: tl.constexpr,\n BLOCK_M_SIZE: tl.constexpr,\n BLOCK_N_SIZE: tl.constexpr,\n M_LOAD_MASK_NEEDED: tl.constexpr,\n N_LOAD_MASK_NEEDED: tl.constexpr,\n):\n block_m_idx = tl.program_id(0)\n head_idx = tl.program_id(1)\n current_batch_idx = head_idx // head_size\n current_head_idx = head_idx % head_size\n m_range_offs = tl.arange(0, BLOCK_M_SIZE)\n n_range_offs = tl.arange(0, BLOCK_N_SIZE)\n dhead_range_offs = tl.arange(0, BLOCK_DHEAD_SIZE)\n m_offs = block_m_idx * BLOCK_M_SIZE + m_range_offs\n\n q_offs = (\n current_batch_idx * q_batch_stride\n + current_head_idx * q_head_stride\n + (m_offs[:, None] * q_m_stride + dhead_range_offs[None, :] * q_k_stride)\n )\n k_offs = (\n current_batch_idx * k_batch_stride\n + current_head_idx * k_head_stride\n + (n_range_offs[:, None] * k_n_stride + dhead_range_offs[None, :] * k_k_stride)\n )\n v_offs = (\n current_batch_idx * v_batch_stride\n + current_head_idx * v_head_stride\n + (n_range_offs[:, None] * v_k_stride + dhead_range_offs[None, :] * v_n_stride)\n )\n output_offs = (\n current_batch_idx * output_batch_stride\n + current_head_idx * output_head_stride\n + (m_offs[:, None] * output_row_stride + dhead_range_offs[None, :] * output_col_stride)\n )\n q_ptrs = q_ptr + q_offs\n k_ptrs = k_ptr + k_offs\n v_ptrs = v_ptr + v_offs\n output_ptrs = output_ptr + output_offs\n l_i = tl.zeros((BLOCK_M_SIZE,), dtype=tl.float32) - float(\"inf\")\n d_i = tl.zeros((BLOCK_M_SIZE,), dtype=tl.float32)\n acc = tl.zeros((BLOCK_M_SIZE, BLOCK_DHEAD_SIZE), dtype=tl.float32)\n if M_LOAD_MASK_NEEDED | N_LOAD_MASK_NEEDED:\n q = tl.load(q_ptrs, mask=m_offs[:, None] < m_size, other=0.0)\n else:\n q = tl.load(q_ptrs)\n\n block_n_end = n_size\n if IS_CAUSAL:\n block_n_end = (block_m_idx + 1) * BLOCK_N_SIZE\n\n if HAS_MASK:\n attention_mask_batch_idx = (current_batch_idx,)\n if attention_mask_batch_size == 1:\n attention_mask_batch_idx = 0\n\n attention_mask_head_idx = current_head_idx\n if attention_mask_head_size == 1:\n attention_mask_head_idx = 0\n\n attention_mask_off = (\n attention_mask_batch_idx * attention_mask_batch_stride\n + attention_mask_head_idx * attention_mask_head_stride\n )\n\n for block_n_start_idx in range(0, block_n_end, BLOCK_N_SIZE):\n block_n_offs = block_n_start_idx + n_range_offs\n if N_LOAD_MASK_NEEDED:\n k_ptr_mask = block_n_offs[:, None] < n_size\n k = tl.load(k_ptrs + block_n_start_idx * k_n_stride, mask=k_ptr_mask, other=0.0)\n else:\n k = tl.load(k_ptrs + block_n_start_idx * k_n_stride)\n qk = tl.zeros((BLOCK_M_SIZE, BLOCK_N_SIZE), dtype=tl.float32)\n\n if N_LOAD_MASK_NEEDED:\n qk = tl.where(n_range_offs[None, :] < n_size, qk, float(\"-inf\"))\n qk += tl.dot(q, tl.trans(k))\n qk *= sm_scale\n if IS_CAUSAL:\n qk += tl.where(m_offs[:, None] >= block_n_offs[None, :], 0, float(\"-inf\"))\n\n if HAS_MASK:\n attention_mask_offs = attention_mask_off + block_n_offs * attention_mask_n_stride\n if IS_MATRIX_MASK:\n attention_mask_offs = attention_mask_offs[None, :] + m_offs[:, None] * attention_mask_m_stride\n\n if N_LOAD_MASK_NEEDED & (not IS_MATRIX_MASK):\n attention_mask_ptr_mask = block_n_offs < attention_mask_n_size\n if IS_MATRIX_MASK:\n if M_LOAD_MASK_NEEDED & (not N_LOAD_MASK_NEEDED):\n attention_mask_ptr_mask = m_offs[:, None] < attention_mask_m_size\n elif (not M_LOAD_MASK_NEEDED) & N_LOAD_MASK_NEEDED:\n attention_mask_ptr_mask = block_n_offs[None, :] < attention_mask_n_size\n elif M_LOAD_MASK_NEEDED & N_LOAD_MASK_NEEDED:\n attention_mask_ptr_mask = (block_n_offs[None, :] < attention_mask_n_size) & (\n m_offs[:, None] < attention_mask_m_size\n )\n\n if (M_LOAD_MASK_NEEDED & IS_MATRIX_MASK) | N_LOAD_MASK_NEEDED:\n attention_mask = tl.load(\n attention_mask_ptr + attention_mask_offs,\n eviction_policy=\"evict_first\",\n mask=attention_mask_ptr_mask,\n other=float(\"-inf\"),\n )\n else:\n attention_mask = tl.load(\n attention_mask_ptr + attention_mask_offs,\n eviction_policy=\"evict_first\",\n )\n attention_mask = tl.where(attention_mask == float(\"-inf\"), min_clamp_value, attention_mask)\n if IS_MATRIX_MASK:\n qk += attention_mask\n else:\n qk += attention_mask[None, :]\n\n l_j = tl.max(qk, 1)\n\n numerators = tl.exp(qk - l_j[:, None])\n d_j = tl.sum(numerators, 1)\n\n l_new = tl.maximum(l_i, l_j)\n alpha = tl.exp(l_i - l_new)\n beta = tl.exp(l_j - l_new)\n d_new = alpha * d_i + beta * d_j\n\n p_scale = beta / d_new\n\n qk_softmax = numerators * p_scale[:, None]\n\n acc_scale = d_i / d_new * alpha\n\n acc = acc * acc_scale[:, None]\n\n if N_LOAD_MASK_NEEDED:\n v_ptr_mask = block_n_offs[:, None] < n_size\n v = tl.load(v_ptrs + block_n_start_idx * v_k_stride, mask=v_ptr_mask, other=0.0)\n else:\n v = tl.load(v_ptrs + block_n_start_idx * v_k_stride)\n qk_softmax = qk_softmax.to(q_ptr.dtype.element_ty)\n acc += tl.dot(qk_softmax, v)\n\n d_i = d_new\n l_i = l_new\n\n if M_LOAD_MASK_NEEDED:\n output_ptr_mask = m_offs[:, None] < m_size\n tl.store(output_ptrs, acc, mask=output_ptr_mask)\n else:\n tl.store(output_ptrs, acc)\n\n\nclass Attention(torch.autograd.Function):\n @staticmethod\n @custom_fwd(cast_inputs=torch.float16)\n def forward(\n ctx,\n q: torch.Tensor,\n k: torch.Tensor,\n v: torch.Tensor,\n output: torch.Tensor,\n sm_scale: float,\n is_causal: bool,\n attention_mask: Optional[torch.Tensor] = None,\n ):\n assert q.shape[-1] == k.shape[-1]\n assert (\n q.dtype == k.dtype == v.dtype == output.dtype\n ), f\"All tensors must have the same dtype: {q.dtype}, {k.dtype}, {v.dtype}, {output.dtype}\"\n assert q.dtype in [torch.float16, torch.bfloat16], f\"Only float16 and bfloat16 are supported, got {q.dtype}\"\n batch, head_size, m_size, dhead = q.size()\n n_size = k.size(2)\n\n grid = lambda args: (triton.cdiv(m_size, args[\"BLOCK_M_SIZE\"]), batch * head_size)\n\n HAS_MASK = False\n IS_MATRIX_MASK = False\n if attention_mask is not None:\n assert (\n attention_mask.size(0) == batch or attention_mask.size(0) == 1\n ), \"Incompatible broadcast batch dimension\"\n assert (\n attention_mask.size(1) == head_size or attention_mask.size(1) == 1\n ), \"Incompatible broadcast heads dimension\"\n assert (\n attention_mask.size(2) == m_size or attention_mask.size(2) == 1\n ), \"Incompatible broadcast m_size dimension\"\n assert attention_mask.size(3) == n_size, \"Last size of mask must broadcast on QK^t\"\n\n HAS_MASK = True\n IS_MATRIX_MASK = attention_mask.size(2) != 1\n\n _fwd_kernel[grid](\n head_size,\n m_size,\n n_size,\n m_size // 32,\n n_size // 32,\n q,\n k,\n v,\n sm_scale,\n attention_mask,\n output,\n *q.stride(),\n *k.stride(),\n *v.stride(),\n *output.stride(),\n *attention_mask.stride() if HAS_MASK else (0, 0, 0, 0),\n torch.finfo(attention_mask.dtype).min if HAS_MASK else 0,\n *attention_mask.size() if HAS_MASK else (0, 0, 0, 0),\n HAS_MASK,\n IS_MATRIX_MASK,\n is_causal,\n dhead,\n 128,\n 128,\n m_size % 128 != 0,\n n_size % 128 != 0,\n num_warps=4 if k.size(3) <= 64 else 8,\n num_stages=2,\n )\n return output\n\n\ndef attention_forward(\n q: torch.Tensor,\n k: torch.Tensor,\n v: torch.Tensor,\n output: torch.Tensor,\n sm_scale: float,\n is_causal: bool = False,\n attention_mask: Optional[torch.Tensor] = None,\n):\n return Attention.apply(q, k, v, output, sm_scale, is_causal, attention_mask)\n", - "description_1": "Use triton language to define a kernel '_fwd_kernel' that computes attention using query, key, and value matrices. This kernel has 45 parameters including both tensors and constant expressions. The main task is to perform the Q•K^T operation followed by a scaling and optional masking and softmax normalization. Also, define a custom torch.autograd.Function 'Attention' with 7 input parameters to apply the kernel, including query, key, value tensors, output tensor, scaling factor, causal flag, and optional attention mask.", - "description_2": "Use triton language to implement a custom kernel for scaled dot-product attention, utilizing features like causal masking and softmax normalization, and integrate this kernel into a PyTorch autograd function for forward pass operations.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\nfrom torch.cuda.amp import custom_fwd\n\n@triton.jit\ndef _fwd_part_1(\n head_size,\n m_size,\n n_size,\n q_ptr,\n k_ptr,\n v_ptr,\n sm_scale,\n attention_mask_ptr,\n output_ptr,\n maximums_ptr,\n sums_ptr,\n q_batch_stride,\n q_head_stride,\n q_m_stride,\n q_k_stride,\n k_batch_stride,\n k_head_stride,\n k_n_stride,\n k_k_stride,\n v_batch_stride,\n v_head_stride,\n v_k_stride,\n v_n_stride,\n sums_batch_stride,\n sums_head_stride,\n sums_step_stride,\n sums_m_stride,\n maximums_batch_stride,\n maximums_head_stride,\n maximums_step_stride,\n maximums_m_stride,\n output_batch_stride,\n output_head_stride,\n output_step_stride,\n output_m_stride,\n output_n_stride,\n attention_mask_batch_stride,\n attention_mask_head_stride,\n attention_mask_m_stride,\n attention_mask_k_stride,\n min_clamp_value,\n N_LOAD_MASK_NEEDED: tl.constexpr,\n M_LOAD_MASK_NEEDED: tl.constexpr,\n MASK_BATCH_SIZE: tl.constexpr,\n MASK_HEAD_SIZE: tl.constexpr,\n MASK_M_SIZE: tl.constexpr,\n MASK_K_SIZE: tl.constexpr,\n HAS_MASK: tl.constexpr,\n IS_CAUSAL: tl.constexpr,\n BLOCK_M_SIZE: tl.constexpr,\n BLOCK_DHEAD_SIZE: tl.constexpr,\n BLOCK_N_SIZE: tl.constexpr,\n):\n block_n_idx = tl.program_id(0)\n block_m_idx = tl.program_id(1)\n head_idx = tl.program_id(2)\n\n m_range_offs = tl.arange(0, BLOCK_M_SIZE)\n n_range_offs = tl.arange(0, BLOCK_N_SIZE)\n d_range_offs = tl.arange(0, BLOCK_DHEAD_SIZE)\n\n m_offs = block_m_idx * BLOCK_M_SIZE + m_range_offs\n\n current_batch_idx = head_idx // head_size\n current_head_idx = head_idx % head_size\n\n q_offs = (\n current_batch_idx * q_batch_stride\n + current_head_idx * q_head_stride\n + (m_offs[:, None] * q_m_stride + d_range_offs[None, :] * q_k_stride)\n )\n\n k_offs = (\n current_batch_idx * k_batch_stride\n + current_head_idx * k_head_stride\n + (n_range_offs[:, None] * k_n_stride + d_range_offs[None, :] * k_k_stride)\n )\n\n v_offs = (\n current_batch_idx * v_batch_stride\n + current_head_idx * v_head_stride\n + (n_range_offs[:, None] * v_k_stride + d_range_offs[None, :] * v_n_stride)\n )\n\n q_ptrs = q_ptr + q_offs\n k_ptrs = k_ptr + k_offs\n v_ptrs = v_ptr + v_offs\n\n if M_LOAD_MASK_NEEDED:\n q = tl.load(q_ptrs, mask=m_offs[:, None] < m_size, eviction_policy=\"\", other=0.0)\n else:\n q = tl.load(q_ptrs, eviction_policy=\"\")\n\n if HAS_MASK:\n attention_mask_batch_idx = (current_batch_idx,)\n if MASK_BATCH_SIZE == 1:\n attention_mask_batch_idx = 0\n\n attention_mask_head_idx = current_head_idx\n if MASK_HEAD_SIZE == 1:\n attention_mask_head_idx = 0\n\n attention_mask_off = (\n attention_mask_batch_idx * attention_mask_batch_stride\n + attention_mask_head_idx * attention_mask_head_stride\n )\n\n block_n_start_idx = block_n_idx * BLOCK_N_SIZE\n block_n_offs = block_n_start_idx + n_range_offs\n\n if N_LOAD_MASK_NEEDED:\n k_ptr_mask = block_n_offs[:, None] < n_size\n k = tl.load(k_ptrs + block_n_start_idx * k_n_stride, mask=k_ptr_mask, eviction_policy=\"\", other=0.0)\n else:\n k = tl.load(k_ptrs + block_n_start_idx * k_n_stride, eviction_policy=\"\")\n\n qk = tl.zeros((BLOCK_M_SIZE, BLOCK_N_SIZE), dtype=tl.float32)\n\n if N_LOAD_MASK_NEEDED:\n qk = tl.where(n_range_offs[None, :] < n_size, qk, float(\"-inf\"))\n qk += tl.dot(q, tl.trans(k))\n qk *= sm_scale\n if IS_CAUSAL:\n qk += tl.where(m_offs[:, None] >= block_n_offs[None, :], 0, float(\"-inf\"))\n\n if HAS_MASK:\n attention_mask_offs = attention_mask_off + block_n_offs[None, :] * attention_mask_k_stride\n if MASK_M_SIZE != 1:\n attention_mask_offs += m_offs[:, None] * attention_mask_m_stride\n\n if N_LOAD_MASK_NEEDED & MASK_M_SIZE == 1:\n attention_mask_ptr_mask = block_n_offs[None, :] < n_size\n if MASK_M_SIZE != 1:\n if M_LOAD_MASK_NEEDED & (not N_LOAD_MASK_NEEDED):\n attention_mask_ptr_mask = m_offs[:, None] < m_size\n elif (not M_LOAD_MASK_NEEDED) & N_LOAD_MASK_NEEDED:\n attention_mask_ptr_mask = block_n_offs[None, :] < n_size\n elif M_LOAD_MASK_NEEDED & N_LOAD_MASK_NEEDED:\n attention_mask_ptr_mask = (block_n_offs[None, :] < n_size) & (m_offs[:, None] < m_size)\n\n if M_LOAD_MASK_NEEDED | N_LOAD_MASK_NEEDED:\n attention_mask = tl.load(\n attention_mask_ptr + attention_mask_offs,\n eviction_policy=\"\",\n mask=attention_mask_ptr_mask,\n other=float(\"-inf\"),\n )\n else:\n attention_mask = tl.load(\n attention_mask_ptr + attention_mask_offs,\n eviction_policy=\"\",\n )\n attention_mask = tl.where(attention_mask == float(\"-inf\"), min_clamp_value, attention_mask)\n qk += attention_mask\n\n l_j = tl.max(qk, 1)\n numerators = tl.exp(qk - l_j[:, None])\n d_j = tl.sum(numerators, 1)\n\n maximums_offs = (\n current_batch_idx * maximums_batch_stride\n + current_head_idx * maximums_head_stride\n + block_n_idx * maximums_step_stride\n + m_offs * maximums_m_stride\n )\n maximums_ptrs = maximums_ptr + maximums_offs\n tl.store(maximums_ptrs, l_j, mask=m_offs < m_size)\n\n sums_offs = (\n current_batch_idx * sums_batch_stride\n + current_head_idx * sums_head_stride\n + block_n_idx * sums_step_stride\n + m_offs * sums_m_stride\n )\n sums_ptrs = sums_ptr + sums_offs\n tl.store(sums_ptrs, d_j, mask=m_offs < m_size)\n\n if N_LOAD_MASK_NEEDED:\n v_ptr_mask = block_n_offs[:, None] < n_size\n v = tl.load(v_ptrs + block_n_start_idx * v_k_stride, mask=v_ptr_mask, other=0.0, eviction_policy=\"evict_first\")\n else:\n v = tl.load(v_ptrs + block_n_start_idx * v_k_stride, eviction_policy=\"evict_first\")\n\n result = tl.dot(numerators.to(q_ptr.dtype.element_ty), v)\n\n output_offs = (\n current_batch_idx * output_batch_stride\n + current_head_idx * output_head_stride\n + block_n_idx * output_step_stride\n + (m_offs[:, None] * output_m_stride + d_range_offs[None, :] * output_n_stride)\n )\n\n output_ptrs = output_ptr + output_offs\n\n if M_LOAD_MASK_NEEDED:\n output_ptr_mask = m_offs[:, None] < m_size\n tl.store(output_ptrs, result, mask=output_ptr_mask)\n else:\n tl.store(output_ptrs, result)\n\n@triton.jit\ndef _fwd_part_2(\n head_size,\n intermediates_size,\n m_size,\n input_ptr,\n input_batch_stride,\n input_head_stride,\n input_intermediate_stride,\n input_m_stride,\n input_n_stride,\n maximums_ptr,\n maximums_batch_stride,\n maximums_head_stride,\n maximums_intermediate_stride,\n maximums_m_stride,\n sums_ptr,\n sums_batch_stride,\n sums_head_stride,\n sums_intermediate_stride,\n sums_m_stride,\n output_ptr,\n output_batch_stride,\n output_head_stride,\n output_m_stride,\n output_n_stride,\n BLOCK_M_SIZE: tl.constexpr,\n BLOCK_DHEAD_SIZE: tl.constexpr,\n):\n block_m_idx = tl.program_id(0)\n head_idx = tl.program_id(1)\n current_batch_idx = head_idx // head_size\n current_head_idx = head_idx % head_size\n\n m_range_offs = tl.arange(0, BLOCK_M_SIZE)\n dhead_range_offs = tl.arange(0, BLOCK_DHEAD_SIZE)\n\n m_offs = block_m_idx * BLOCK_M_SIZE + m_range_offs\n\n acc = tl.zeros((BLOCK_M_SIZE, BLOCK_DHEAD_SIZE), dtype=tl.float32)\n l_i = tl.zeros((BLOCK_M_SIZE,), dtype=tl.float32) - float(\"inf\")\n d_i = tl.zeros((BLOCK_M_SIZE,), dtype=tl.float32)\n for n_intermediate_idx in range(0, intermediates_size):\n input_offs = (\n current_batch_idx * input_batch_stride\n + current_head_idx * input_head_stride\n + n_intermediate_idx * input_intermediate_stride\n + (m_offs[:, None] * input_m_stride + dhead_range_offs[None, :] * input_n_stride)\n )\n input_ptrs = input_ptr + input_offs\n numerators = tl.load(input_ptrs, mask=m_offs[:, None] < m_size, other=0.0)\n\n sums_offs = (\n current_batch_idx * sums_batch_stride\n + current_head_idx * sums_head_stride\n + n_intermediate_idx * sums_intermediate_stride\n + m_offs * sums_m_stride\n )\n sums_ptrs = sums_ptr + sums_offs\n d_j = tl.load(sums_ptrs, mask=m_offs < m_size, other=0.0)\n\n maximums_offs = (\n current_batch_idx * maximums_batch_stride\n + current_head_idx * maximums_head_stride\n + n_intermediate_idx * maximums_intermediate_stride\n + m_offs * maximums_m_stride\n )\n maximums_ptrs = maximums_ptr + maximums_offs\n l_j = tl.load(maximums_ptrs, mask=m_offs < m_size, other=0.0)\n\n l_new = tl.maximum(l_i, l_j)\n alpha = tl.exp(l_i - l_new)\n beta = tl.exp(l_j - l_new)\n d_new = alpha * d_i + beta * d_j\n\n p_scale = beta / d_new\n\n acc_scale = d_i / d_new * alpha\n acc *= acc_scale[:, None]\n\n acc += numerators * p_scale[:, None]\n\n d_i = d_new\n l_i = l_new\n\n output_offs = (\n current_batch_idx * output_batch_stride\n + current_head_idx * output_head_stride\n + (m_offs[:, None] * output_m_stride + dhead_range_offs[None, :] * output_n_stride)\n )\n output_ptrs = output_ptr + output_offs\n tl.store(output_ptrs, acc, mask=m_offs[:, None] < m_size)\n\nclass SkinnyAttention(torch.autograd.Function):\n @staticmethod\n @custom_fwd(cast_inputs=torch.float16)\n def forward(\n ctx: FunctionCtx,\n q: torch.Tensor,\n k: torch.Tensor,\n v: torch.Tensor,\n output: torch.Tensor,\n sm_scale: float,\n is_causal: bool,\n attention_mask: Optional[torch.Tensor] = None,\n ):\n assert q.shape[-1] == k.shape[-1]\n assert q.dtype in [torch.float16, torch.bfloat16], f\"Only float16 and bfloat16 are supported, got {q.dtype}\"\n batch, heads, size_m, dhead = q.size()\n size_n = k.size(2)\n\n BLOCK_M = 16\n BLOCK_N = 128\n NEED_LOAD_MASK_SIZE_N = size_n % BLOCK_N != 0\n NEED_LOAD_MASK_SIZE_M = size_m % BLOCK_M != 0\n\n n_divisions = triton.cdiv(size_n, BLOCK_N)\n splitted_qkt = torch.empty(\n q.size(0), q.size(1), n_divisions, q.size(2), q.size(3), dtype=torch.float16, device=\"cuda\"\n )\n\n grid = (n_divisions, triton.cdiv(size_m, BLOCK_M), batch * heads)\n\n maximums = torch.zeros(\n (\n batch,\n heads,\n n_divisions,\n size_m,\n ),\n device=q.device,\n dtype=torch.float32,\n )\n sums = torch.zeros(\n (\n batch,\n heads,\n n_divisions,\n size_m,\n ),\n device=q.device,\n dtype=torch.float32,\n )\n\n HAS_MASK = False\n if attention_mask is not None:\n assert (\n attention_mask.size(0) == batch or attention_mask.size(0) == 1\n ), \"Incompatible broadcast batch dimension\"\n assert (\n attention_mask.size(1) == heads or attention_mask.size(1) == 1\n ), \"Incompatible broadcast heads dimension\"\n assert (\n attention_mask.size(2) == size_m or attention_mask.size(2) == 1\n ), \"Incompatible broadcast size_m dimension\"\n assert attention_mask.size(3) == size_n, \"Last size of mask must broadcast on QK^t\"\n\n HAS_MASK = True\n\n _fwd_part_1[grid](\n head_size=heads,\n m_size=size_m,\n n_size=size_n,\n q_ptr=q,\n k_ptr=k,\n v_ptr=v,\n sm_scale=sm_scale,\n attention_mask_ptr=attention_mask,\n output_ptr=splitted_qkt,\n maximums_ptr=maximums,\n sums_ptr=sums,\n q_batch_stride=q.stride(0),\n q_head_stride=q.stride(1),\n q_m_stride=q.stride(2),\n q_k_stride=q.stride(3),\n k_batch_stride=k.stride(0),\n k_head_stride=k.stride(1),\n k_n_stride=k.stride(2),\n k_k_stride=k.stride(3),\n v_batch_stride=v.stride(0),\n v_head_stride=v.stride(1),\n v_k_stride=v.stride(2),\n v_n_stride=v.stride(3),\n sums_batch_stride=sums.stride(0),\n sums_head_stride=sums.stride(1),\n sums_step_stride=sums.stride(2),\n sums_m_stride=sums.stride(3),\n maximums_batch_stride=maximums.stride(0),\n maximums_head_stride=maximums.stride(1),\n maximums_step_stride=maximums.stride(2),\n maximums_m_stride=maximums.stride(3),\n output_batch_stride=splitted_qkt.stride(0),\n output_head_stride=splitted_qkt.stride(1),\n output_step_stride=splitted_qkt.stride(2),\n output_m_stride=splitted_qkt.stride(3),\n output_n_stride=splitted_qkt.stride(4),\n attention_mask_batch_stride=attention_mask.stride(0) if HAS_MASK else 0,\n attention_mask_head_stride=attention_mask.stride(1) if HAS_MASK else 0,\n attention_mask_m_stride=attention_mask.stride(2) if HAS_MASK else 0,\n attention_mask_k_stride=attention_mask.stride(3) if HAS_MASK else 0,\n N_LOAD_MASK_NEEDED=NEED_LOAD_MASK_SIZE_N,\n M_LOAD_MASK_NEEDED=NEED_LOAD_MASK_SIZE_M,\n min_clamp_value=torch.finfo(attention_mask.dtype).min if HAS_MASK else 0,\n MASK_BATCH_SIZE=attention_mask.size(0) if HAS_MASK else 0,\n MASK_HEAD_SIZE=attention_mask.size(1) if HAS_MASK else 0,\n MASK_M_SIZE=attention_mask.size(2) if HAS_MASK else 0,\n MASK_K_SIZE=attention_mask.size(3) if HAS_MASK else 0,\n HAS_MASK=HAS_MASK,\n IS_CAUSAL=is_causal,\n BLOCK_M_SIZE=BLOCK_M,\n BLOCK_N_SIZE=BLOCK_N,\n BLOCK_DHEAD_SIZE=dhead,\n num_warps=1,\n num_stages=8,\n )\n\n batch, heads, steps, size_m, dhead = splitted_qkt.size()\n BLOCK_M = 16\n grid_part2 = (triton.cdiv(size_m, BLOCK_M), batch * heads)\n _fwd_part_2[grid_part2](\n heads,\n steps,\n size_m,\n splitted_qkt,\n *splitted_qkt.stride(),\n maximums,\n *maximums.stride(),\n sums,\n *sums.stride(),\n output,\n *output.stride(),\n BLOCK_M_SIZE=BLOCK_M,\n BLOCK_DHEAD_SIZE=dhead,\n num_warps=4,\n num_stages=1,\n )\n return output\n\ndef skinny_attention_forward(\n q: torch.Tensor,\n k: torch.Tensor,\n v: torch.Tensor,\n output: torch.Tensor,\n sm_scale: float,\n is_causal: bool = False,\n attention_mask: Optional[torch.Tensor] = None,\n):\n return SkinnyAttention.apply(q, k, v, output, sm_scale, is_causal, attention_mask)\n", - "description_1": "Use triton language to implement an attention mechanism with two parts, _fwd_part_1 and _fwd_part_2, performing operations with query, key, and value tensors with support for optional masks and causal configurations.", - "description_2": "Use triton language to implement a forward pass for a custom attention mechanism utilizing query, key, and value tensors, supporting optional attention masks and causal processing.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\nfrom typing import Tuple\n\n@triton.jit\ndef vec_mat(\n vec_col_size: tl.constexpr,\n matrix_row_size: tl.constexpr,\n matrix_col_size: tl.constexpr,\n output_col_size: tl.constexpr,\n vec_ptr,\n vec_batch_stride,\n vec_head_stride,\n vec_row_stride,\n vec_col_stride,\n matrix_ptr,\n matrix_batch_stride,\n matrix_head_stride,\n matrix_row_stride,\n matrix_col_stride,\n output_ptr,\n output_batch_stride,\n output_head_stride,\n output_row_stride,\n output_col_stride,\n SCALER: tl.constexpr,\n SHOULD_VEC_SOFTMAX: tl.constexpr,\n VEC_COL_ROUNDED_SIZE: tl.constexpr,\n N_SIZE: tl.constexpr,\n):\n block_n_idx = tl.program_id(0)\n head_idx = tl.program_id(1)\n batch_idx = tl.program_id(2)\n\n n_range_offs = tl.arange(0, N_SIZE)\n vec_col_rounded_range_offs = tl.arange(0, VEC_COL_ROUNDED_SIZE)\n\n vec_ptrs = vec_ptr + (\n batch_idx * vec_batch_stride + head_idx * vec_head_stride + vec_col_stride * vec_col_rounded_range_offs[:, None]\n )\n vec_ptr_mask = vec_col_rounded_range_offs[:, None] < vec_col_size\n vec = tl.load(pointer=vec_ptrs, mask=vec_ptr_mask, other=0.0).to(tl.float32)\n\n if SCALER != 1.0:\n vec = vec * SCALER\n\n if SHOULD_VEC_SOFTMAX:\n vec_max = tl.max(vec, axis=0)\n vec = vec - vec_max[:, None]\n vec = tl.exp(vec)\n vec = vec / tl.sum(vec, axis=0)[:, None]\n\n matrix_ptrs = matrix_ptr + (\n batch_idx * matrix_batch_stride\n + head_idx * matrix_head_stride\n + vec_col_rounded_range_offs[:, None] * matrix_row_stride # cols\n + (block_n_idx * N_SIZE + n_range_offs)[None, :] * matrix_col_stride # rows\n )\n matrix_ptr_mask = (vec_col_rounded_range_offs[:, None] < matrix_row_size) & (\n (block_n_idx * N_SIZE + n_range_offs)[None, :] < matrix_col_size\n )\n matrix = tl.load(pointer=matrix_ptrs, mask=matrix_ptr_mask, other=0.0).to(tl.float32)\n\n result = vec * matrix\n result = tl.sum(input=result, axis=0)\n\n output_ptrs = output_ptr + (\n batch_idx * output_batch_stride\n + head_idx * output_head_stride\n + (block_n_idx * N_SIZE + n_range_offs) * output_col_stride\n )\n output_ptr_mask = (block_n_idx * N_SIZE + n_range_offs) < output_col_size\n tl.store(pointer=output_ptrs, value=result, mask=output_ptr_mask)\n\n\ndef vec_mat_wrapper(\n vec: torch.Tensor,\n matrix: torch.Tensor,\n output: torch.Tensor,\n scaler: float,\n softmax_vec: bool,\n transpose_mat: bool,\n) -> torch.Tensor:\n vec_cols = vec.shape[-1]\n out_cols = output.shape[-1]\n\n batch, heads, mat_rows, mat_cols = matrix.shape\n matrix_stride = list(matrix.stride())\n if transpose_mat:\n matrix_stride[-1], matrix_stride[-2] = matrix_stride[-2], matrix_stride[-1]\n mat_rows, mat_cols = mat_cols, mat_rows\n\n assert vec.shape[-2] == output.shape[-2] == 1\n assert mat_cols == out_cols\n assert vec_cols == mat_rows\n\n def grid(args) -> Tuple[int, int, int]:\n return triton.cdiv(mat_cols, args[\"N_SIZE\"]), heads, batch\n\n vec_cols_pow_2 = triton.next_power_of_2(vec_cols)\n\n vec_mat[grid](\n vec_cols,\n mat_rows,\n mat_cols,\n out_cols,\n vec,\n *vec.stride(),\n matrix,\n *matrix_stride,\n output,\n *output.stride(),\n scaler,\n softmax_vec,\n vec_cols_pow_2,\n )\n return output\n", - "description_1": "Use triton language to create a vector-matrix multiplication kernel. The kernel has 26 parameters: vec_col_size (column size of vector), matrix_row_size (row size of matrix), matrix_col_size (column size of matrix), output_col_size (column size of output), vec_ptr (pointer to vector), vec_batch_stride (batch stride of vector), vec_head_stride (head stride of vector), vec_row_stride (row stride of vector), vec_col_stride (column stride of vector), matrix_ptr (pointer to matrix), matrix_batch_stride (batch stride of matrix), matrix_head_stride (head stride of matrix), matrix_row_stride (row stride of matrix), matrix_col_stride (column stride of matrix), output_ptr (pointer to output), output_batch_stride (batch stride of output), output_head_stride (head stride of output), output_row_stride (row stride of output), output_col_stride (column stride of output), SCALER (scaling factor), SHOULD_VEC_SOFTMAX (flag for softmax on vector), VEC_COL_ROUNDED_SIZE (rounded size for vector columns), N_SIZE (number of size elements for block), and performs the multiplication by loading data and computing the result in blocks using Triton operations.", - "description_2": "Use triton language to perform vector-matrix multiplication with optional softmax on vector, and scale the vector before multiplication in a block-wise manner.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef matmul_kernel(\n a_ptr,\n b_ptr,\n c_ptr,\n m_size,\n n_size,\n k_size,\n a_batch_stride,\n a_m_stride,\n a_k_stride,\n b_batch_stride,\n b_k_stride,\n b_n_stride,\n c_batch_stride,\n c_m_stride,\n c_n_stride,\n BLOCK_M_SIZE: tl.constexpr,\n BLOCK_N_SIZE: tl.constexpr,\n BLOCK_K_SIZE: tl.constexpr,\n GROUP_M_SIZE: tl.constexpr,\n):\n batch_idx = tl.program_id(axis=1)\n program_idx = tl.program_id(axis=0)\n\n program_m_count = tl.cdiv(m_size, BLOCK_M_SIZE)\n program_n_count = tl.cdiv(n_size, BLOCK_N_SIZE)\n\n program_in_group_count = GROUP_M_SIZE * program_n_count\n group_idx = program_idx // program_in_group_count\n first_program_m_idx = group_idx * GROUP_M_SIZE\n GROUP_M_SIZE = min(program_m_count - first_program_m_idx, GROUP_M_SIZE)\n program_m_idx = first_program_m_idx + (program_idx % GROUP_M_SIZE)\n program_n_idx = (program_idx % program_in_group_count) // GROUP_M_SIZE\n\n a_offs = program_m_idx * BLOCK_M_SIZE + tl.arange(0, BLOCK_M_SIZE)\n b_offs = program_n_idx * BLOCK_N_SIZE + tl.arange(0, BLOCK_N_SIZE)\n\n k_range_offs = tl.arange(0, BLOCK_K_SIZE)\n\n a_ptrs = a_ptr + a_batch_stride * batch_idx + (a_offs[:, None] * a_m_stride + k_range_offs[None, :] * a_k_stride)\n b_ptrs = b_ptr + b_batch_stride * batch_idx + (k_range_offs[:, None] * b_k_stride + b_offs[None, :] * b_n_stride)\n\n accumulator = tl.zeros((BLOCK_M_SIZE, BLOCK_N_SIZE), dtype=tl.float32)\n for k in range(0, k_size, BLOCK_K_SIZE):\n a_ptr_mask = (a_offs[:, None] < m_size) & (k_range_offs[None, :] < k_size)\n a = tl.load(a_ptrs, mask=a_ptr_mask, other=0)\n\n b_ptr_mask = (k_range_offs[:, None] < k_size) & (b_offs[None, :] < n_size)\n b = tl.load(b_ptrs, mask=b_ptr_mask, other=0)\n\n accumulator += tl.dot(a, b)\n a_ptrs += BLOCK_K_SIZE * a_k_stride\n b_ptrs += BLOCK_K_SIZE * b_k_stride\n\n c = accumulator.to(tl.float16)\n\n c_m_offs = program_m_idx * BLOCK_M_SIZE + tl.arange(0, BLOCK_M_SIZE)\n c_n_offs = program_n_idx * BLOCK_N_SIZE + tl.arange(0, BLOCK_N_SIZE)\n c_ptrs = c_ptr + c_batch_stride * batch_idx + c_m_stride * c_m_offs[:, None] + c_n_stride * c_n_offs[None, :]\n c_ptr_mask = (c_m_offs[:, None] < m_size) & (c_n_offs[None, :] < n_size)\n tl.store(c_ptrs, c, mask=c_ptr_mask)\n\ndef batched_matmul(a, b):\n assert a.shape[2] == b.shape[1], \"incompatible dimensions\"\n assert a.is_contiguous(), \"matrix A must be contiguous\"\n assert b.is_contiguous(), \"matrix B must be contiguous\"\n batch_size, M, K = a.shape\n _, K, N = b.shape\n c = torch.empty((batch_size, M, N), device=a.device, dtype=a.dtype)\n\n grid = lambda META: (\n triton.cdiv(M, META[\"BLOCK_M_SIZE\"]) * triton.cdiv(N, META[\"BLOCK_N_SIZE\"]),\n batch_size,\n )\n matmul_kernel[grid](\n a,\n b,\n c,\n M,\n N,\n K,\n a.stride(0),\n a.stride(1),\n a.stride(2),\n b.stride(0),\n b.stride(1),\n b.stride(2),\n c.stride(0),\n c.stride(1),\n c.stride(2),\n )\n return c\n", - "description_1": "Use triton language to implement a batched matrix multiplication kernel. The kernel 'matmul_kernel' takes 19 parameters: pointers to matrices A, B, and C, dimensions m_size, n_size, k_size, strides for each dimension of A, B, and C, and meta-parameters BLOCK_M_SIZE, BLOCK_N_SIZE, BLOCK_K_SIZE, and GROUP_M_SIZE. The function 'batched_matmul' takes two parameters: matrices a and b, checks their dimensions and contiguity, allocates an output matrix c, and launches the kernel with a grid configuration.", - "description_2": "Use triton language to create a kernel for batched matrix multiplication with configurable block sizes and group sizes, and a function to prepare and launch this kernel.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\nfrom torch.cuda.amp import custom_fwd\nfrom triton import JITFunction\n\n@triton.jit\ndef layer_norm_xformers(\n output_ptr,\n a_ptr,\n weight_ptr,\n bias_ptr,\n mean_ptr,\n rstd_ptr,\n output_row_stride,\n output_col_stride,\n a_row_stride,\n a_col_stride,\n N_SIZE,\n eps,\n HAS_BIAS: tl.constexpr, \n IS_RMSNORM: tl.constexpr, \n BLOCK_N_SIZE: tl.constexpr,\n):\n row = tl.program_id(0)\n cols = tl.arange(0, BLOCK_N_SIZE)\n mask = cols < N_SIZE\n\n x_ptrs = a_ptr + row * a_row_stride + cols * a_col_stride\n\n x = tl.load(x_ptrs, mask=mask, other=0.0, eviction_policy=\"evict_first\").to(tl.float32)\n w = tl.load(weight_ptr + cols, mask=mask, other=1.0)\n b = tl.load(bias_ptr + cols, mask=mask, other=0.0)\n\n mean = tl.sum(x, axis=0) / N_SIZE\n x_zm = tl.where(mask, x - mean, 0.0)\n tl.store(mean_ptr + row, mean)\n\n x_var = tl.sum(x_zm * x_zm, axis=0) / N_SIZE\n rstd = 1.0 / tl.sqrt(x_var + eps)\n\n y = x_zm * rstd\n tl.store(rstd_ptr + row, rstd)\n\n y = y * w + b\n y_ptrs = output_ptr + row * output_row_stride + cols * output_col_stride\n tl.store(y_ptrs, y, mask=mask)\n\n@triton.jit\ndef _layer_norm_fwd_fused_single_pass(\n output_ptr,\n a_ptr,\n weight_ptr,\n bias_ptr,\n mean_ptr,\n rstd_ptr,\n output_row_stride,\n output_col_stride,\n a_row_stride,\n a_col_stride,\n N_SIZE,\n eps,\n HAS_BIAS: tl.constexpr,\n IS_RMSNORM: tl.constexpr,\n BLOCK_N_SIZE: tl.constexpr,\n):\n row_idx = tl.program_id(0)\n\n a_row_off = row_idx * a_row_stride\n block_range_offs = tl.arange(0, BLOCK_N_SIZE)\n mean = 0.0\n var = 0.0\n for block_n_start_idx in range(0, N_SIZE, BLOCK_N_SIZE):\n n_end_off = min((block_n_start_idx + BLOCK_N_SIZE), N_SIZE)\n block_cols_count = n_end_off - block_n_start_idx\n col_offs = block_n_start_idx + block_range_offs\n a_ptr_mask = col_offs < N_SIZE\n a = tl.load(\n a_ptr + a_row_off + col_offs * a_col_stride, mask=a_ptr_mask, other=0.0, eviction_policy=\"evict_last\"\n ).to(tl.float32)\n if IS_RMSNORM:\n var += tl.sum(a * a, axis=0)\n else:\n block_mean = tl.sum(a, axis=0) / block_cols_count\n delta_mean = block_mean - mean\n delta_mean_sqr = delta_mean * delta_mean\n\n block_delta = tl.sum((a - block_mean) * a, axis=0)\n mean += tl.sum((a - mean) * a_ptr_mask, axis=0) / n_end_off\n var += block_delta + delta_mean_sqr * (block_n_start_idx * block_cols_count) / n_end_off\n\n var /= N_SIZE\n rstd = 1 / tl.sqrt(var + eps)\n\n tl.store(mean_ptr + row_idx, mean)\n tl.store(rstd_ptr + row_idx, rstd)\n\n for block_n_start_idx in range(0, N_SIZE, BLOCK_N_SIZE):\n col_offs = block_n_start_idx + block_range_offs\n a_ptr_mask = col_offs < N_SIZE\n weight = tl.load(weight_ptr + col_offs, mask=a_ptr_mask)\n a = tl.load(\n a_ptr + a_row_off + col_offs * a_col_stride, mask=a_ptr_mask, other=0.0, eviction_policy=\"evict_first\"\n ).to(tl.float32)\n a_hat = (a - mean) * rstd\n out = a_hat * weight\n if HAS_BIAS:\n bias = tl.load(bias_ptr + col_offs, mask=a_ptr_mask)\n out = out + bias\n tl.store(output_ptr + row_idx * output_row_stride + col_offs * output_col_stride, out, mask=a_ptr_mask)\n\nclass LayerNorm(torch.autograd.Function):\n @staticmethod\n @custom_fwd(cast_inputs=torch.float16)\n def forward(\n ctx,\n x: torch.Tensor,\n weight: torch.Tensor,\n bias: Optional[torch.Tensor],\n eps: float,\n implementation: JITFunction,\n use_rms_norm: bool,\n ):\n assert x.dtype == weight.dtype, f\"input and weight bias must have the same dtype: {x.dtype}, {weight.dtype}\"\n if bias is not None:\n assert x.dtype == bias.dtype, f\"input and bias must have the same dtype: {x.dtype}, {bias.dtype}\"\n if x.dtype == torch.float16:\n eps = max(eps, 1.6e-5)\n out = torch.empty_like(x)\n a_arg = x.reshape(-1, x.shape[-1])\n M, N = a_arg.shape\n mean = torch.empty((M,), dtype=torch.float32, device=\"cuda\")\n std = torch.empty((M,), dtype=torch.float32, device=\"cuda\")\n MAX_FUSED_SIZE = 65536 // x.element_size()\n BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))\n BLOCK_SIZE = max(BLOCK_SIZE, 128)\n if implementation == layer_norm_xformers:\n assert N <= 4096, \"LayerNorm: N is too large for xformers implementation\"\n BLOCK_SIZE = min(BLOCK_SIZE, 4096)\n num_warps = min(max(BLOCK_SIZE // 256, 1), 8)\n implementation[(M,)](\n output_ptr=out,\n a_ptr=a_arg,\n weight_ptr=weight,\n bias_ptr=bias if bias is not None else a_arg,\n mean_ptr=mean,\n rstd_ptr=std,\n output_row_stride=out.stride(-2),\n output_col_stride=out.stride(-1),\n a_row_stride=a_arg.stride(0),\n a_col_stride=a_arg.stride(1),\n N_SIZE=N,\n eps=eps,\n HAS_BIAS=bias is not None,\n IS_RMSNORM=use_rms_norm,\n BLOCK_N_SIZE=BLOCK_SIZE,\n num_warps=num_warps,\n )\n ctx.save_for_backward(x, mean, std, weight)\n ctx.BLOCK_SIZE = BLOCK_SIZE\n ctx.num_warps = num_warps\n return out\n\ndef layer_norm(\n x: torch.Tensor,\n weight: torch.Tensor,\n bias: Optional[torch.Tensor],\n eps: float,\n implementation: JITFunction = _layer_norm_fwd_fused_single_pass,\n use_rms_norm: bool = False,\n):\n return LayerNorm.apply(x, weight, bias, eps, implementation, use_rms_norm)\n", - "description_1": "Use triton language to implement multiple layer normalization kernel functions with different approaches. The kernels take input tensors, perform normalization using layer normalization algorithm with or without bias, and support fused single pass and multi-pass computation methods. The parameters include pointers to output and input tensors, weight and bias tensors, strides for tensors, size parameters, epsilon for numerical stability, and constants for optional bias and RMSNorm settings. A corresponding PyTorch autograd function wraps the kernel for easier use.", - "description_2": "Use triton language to implement layer normalization kernels and integrate with PyTorch for optimized batch processing with optional bias and different normalization strategies.", - "difficulty": 5 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\nfrom torch.autograd.function import FunctionCtx\nfrom torch.cuda.amp import custom_fwd\n\n@triton.jit\ndef kernel_fma(\n C, # Pointers to matrices\n ACT_INPUTS,\n A,\n B,\n bias,\n # Matrix dimensions\n M,\n N,\n K,\n CACHE_KEY_M,\n CACHE_KEY_N,\n CACHE_KEY_K,\n # The stride variables represent how much to increase the ptr by when moving by 1\n # element in a particular dimension. E.g. stride_am is how much to increase a_ptr\n # by to get the element one row down (A has M rows)\n output_m_stride,\n output_n_stride,\n act_inputs_m_stride,\n act_inputs_n_stride,\n a_m_stride,\n a_k_stride,\n b_n_stride,\n b_k_stride,\n # Meta-parameters\n BLOCK_M: tl.constexpr,\n GROUP_M: tl.constexpr,\n BLOCK_N: tl.constexpr,\n BLOCK_K: tl.constexpr,\n SPLIT_K: tl.constexpr,\n K_LOAD_MASK_NEEDED: tl.constexpr,\n HAS_BIAS: tl.constexpr,\n SHOULD_SAVE_ACT_INPUTS: tl.constexpr,\n ACTIVATION: tl.constexpr,\n):\n \"\"\"\n Kernel for computing Out = activation(A x W + C)\n\n - Input has shape (M, K)\n - Weight has shape (K, N)\n - Bias has shape (N,)\n - Output has shape (M, N)\n - ActInputs (optional) has shape (M, N)\n\n 'ActInputs' optionally saves the A x W + C intermediate for backward computations\n\n This kernel will consolidate over K\n \"\"\"\n program_idx = tl.program_id(axis=0)\n\n grid_m = (M + BLOCK_M - 1) // BLOCK_M\n grid_n = (N + BLOCK_N - 1) // BLOCK_N\n width = GROUP_M * grid_n\n group_idx = program_idx // width\n group_size = min(grid_m - group_idx * GROUP_M, GROUP_M)\n block_m_idx = group_idx * GROUP_M + (program_idx % group_size)\n block_n_idx = (program_idx % width) // group_size\n\n m_offs_untagged = block_m_idx * BLOCK_M + tl.arange(0, BLOCK_M)\n n_offs_untagged = block_n_idx * BLOCK_N + tl.arange(0, BLOCK_N)\n\n m_offs = tl.max_contiguous(tl.multiple_of(m_offs_untagged % M, BLOCK_M), BLOCK_M)\n n_offs = tl.max_contiguous(tl.multiple_of(n_offs_untagged % N, BLOCK_N), BLOCK_N)\n\n k_range_offs = tl.arange(0, BLOCK_K)\n\n A = A + (m_offs[:, None] * a_m_stride + k_range_offs[None, :] * a_k_stride)\n B = B + (k_range_offs[:, None] * b_k_stride + n_offs[None, :] * b_n_stride)\n\n acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)\n\n if HAS_BIAS:\n bias = tl.load(bias + n_offs, mask=n_offs < N, other=0.0).to(tl.float32)\n acc += bias[None, :]\n\n for k in range(K, 0, -BLOCK_K):\n if K_LOAD_MASK_NEEDED:\n a = tl.load(A)\n b = tl.load(B)\n else:\n a = tl.load(A, mask=k_range_offs[None, :] < k, other=0.0)\n b = tl.load(B, mask=k_range_offs[:, None] < k, other=0.0)\n acc += tl.dot(a, b)\n\n A += BLOCK_K * a_k_stride\n B += BLOCK_K * b_k_stride\n\n if SHOULD_SAVE_ACT_INPUTS:\n act_in_ptrs = ACT_INPUTS + m_offs[:, None] * act_inputs_m_stride + n_offs[None, :] * act_inputs_n_stride\n tl.store(act_in_ptrs, acc)\n\n if ACTIVATION == \"tanh\":\n acc = activation_func.tanh(acc)\n if ACTIVATION == \"gelu\":\n acc = activation_func.gelu(acc)\n if ACTIVATION == \"fast_gelu\":\n acc = activation_func.fast_gelu(acc)\n if ACTIVATION == \"relu\":\n acc = activation_func.relu(acc)\n\n C = C + m_offs[:, None] * output_m_stride + n_offs[None, :] * output_n_stride\n c_ptr_mask = (m_offs < M)[:, None] & (n_offs < N)[None, :]\n tl.store(C, acc, mask=c_ptr_mask)\n\n\nclass LinearLayer(torch.autograd.Function):\n @staticmethod\n @custom_fwd(cast_inputs=torch.float16)\n def forward(\n ctx: FunctionCtx,\n x: torch.Tensor,\n weight: torch.Tensor,\n bias: Optional[torch.Tensor],\n activation: str,\n act_inputs: Optional[torch.Tensor],\n ) -> torch.Tensor:\n \"\"\"\n Compute e = activation(x @ weight + bias).\n This wrapper kicks the `kernel_fma` Triton kernel\n :param ctx: context for autograd\n :param x: input tensor\n :param weight: weight matrix\n :param bias: an optional bias tensor\n :param activation: Activation name. Needs to be a Triton kernel.\n :param act_inputs: an optional tensor to save the activation inputs (for backward)\n :return: result tensor\n \"\"\"\n x_ = x if x.ndim == 2 else x.flatten(0, 1)\n\n assert x.dtype == weight.dtype, f\"Input and weight must have the same dtype, got {x.dtype} and {weight.dtype}\"\n if bias is not None:\n assert x.dtype == bias.dtype, f\"Input and bias must have the same dtype, got {x.dtype} and {bias.dtype}\"\n assert x_.shape[1] == weight.shape[1], f\"Incompatible dimensions: {x_.shape} - {weight.shape}\"\n\n assert bias is None or bias.is_contiguous()\n assert bias is None or bias.shape[0] == weight.shape[0], \"Incompatible dimensions in between weight and bias\"\n assert weight.is_contiguous()\n\n M, K = x_.shape\n N, K = weight.shape\n\n outputs = torch.empty((M, N), device=x.device, dtype=x.dtype)\n\n grid = lambda META: (triton.cdiv(M, META[\"BLOCK_M\"]) * triton.cdiv(N, META[\"BLOCK_N\"]),) # noqa\n\n kernel_fma[grid](\n outputs,\n act_inputs,\n x_,\n weight, # data ptrs\n bias if bias is not None else x, # auto skip bias if not present\n M, # shapes\n N,\n K,\n M // 32, # key for triton cache (limit number of compilations)\n N // 32,\n K // 32,\n output_m_stride=outputs.stride(0), # strides\n output_n_stride=outputs.stride(1),\n act_inputs_m_stride=act_inputs.stride(0) if act_inputs is not None else 0,\n act_inputs_n_stride=act_inputs.stride(1) if act_inputs is not None else 0,\n a_m_stride=x_.stride(0),\n a_k_stride=x_.stride(1),\n b_n_stride=weight.stride(0),\n b_k_stride=weight.stride(1),\n HAS_BIAS=bias is not None, # optional fused bias\n SHOULD_SAVE_ACT_INPUTS=act_inputs is not None, # optional save activation inputs\n ACTIVATION=activation if not None else x, # optional fused activation\n GROUP_M=8, # speed optimization: group the programs\n )\n\n outputs = outputs if x.ndim == 2 else outputs.reshape(x.shape[0], -1, N)\n ctx.save_for_backward(weight, bias, x)\n return outputs\n\n\ndef linear_layer(\n x: torch.Tensor,\n weight: torch.Tensor,\n bias: Optional[torch.Tensor],\n activation=\"\",\n act_inputs: Optional[torch.Tensor] = None,\n) -> torch.Tensor:\n return LinearLayer.apply(x, weight, bias, activation, act_inputs)\n", - "description_1": "Use triton language to implement a kernel `kernel_fma` that performs matrix multiplication with activation and optional bias addition. The kernel takes pointers to input and weight matrices, optional bias and activation inputs, and meta-parameters to control block sizes and masking. It computes the matrix product A x B + bias and applies the specified activation function. The kernel is called from a PyTorch custom autograd function `LinearLayer` which provides an interface for forward computation using the Triton kernel.", - "description_2": "Use triton language to create a matrix multiplication kernel with activation function, wrapped in a PyTorch autograd function for easy integration.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef rmsnorm_triton(x_ptr, rms_w_ptr, output_ptr,\n stride_x_batch, stride_x_m, stride_x_k,\n stride_rms_w,\n stride_out_batch, stride_out_m, stride_out_k,\n N_SIZE: tl.constexpr, eps: tl.constexpr, BLOCK_N_SIZE: tl.constexpr):\n pid_batch = tl.program_id(0)\n pid_m = tl.program_id(1)\n\n offs_m = pid_batch * stride_x_batch + pid_m * stride_x_m\n block_N = tl.arange(0, BLOCK_N_SIZE)\n var = tl.zeros((BLOCK_N_SIZE,), tl.float32)\n for block_n_start_idx in range(0, N_SIZE, BLOCK_N_SIZE):\n offs_n = block_n_start_idx + block_N\n x_ptr_mask = offs_n < N_SIZE\n x = tl.load(x_ptr + offs_m + offs_n * stride_x_k, mask=x_ptr_mask, other=0.0)\n var += tl.math.pow(x.to(tl.float32), 2)\n\n var = tl.sum(var, axis=0) / N_SIZE\n rstd = tl.math.rsqrt(var + eps)\n\n # multiply by weight and add bias\n for block_n_start_idx in range(0, N_SIZE, BLOCK_N_SIZE):\n offs_n = block_n_start_idx + block_N\n x_ptr_mask = offs_n < N_SIZE\n rms_w = tl.load(rms_w_ptr + offs_n * stride_rms_w, mask=x_ptr_mask)\n\n x = tl.load(x_ptr + offs_m + offs_n * stride_x_k, mask=x_ptr_mask, other=0.0).to(tl.float32)\n x_hat = x * rstd\n out = x_hat * rms_w\n out_off = pid_batch * stride_out_batch + pid_m * stride_out_m + offs_n * stride_out_k\n tl.store(output_ptr + out_off, out, mask=x_ptr_mask)\n\ndef rmsnorm_triton_wrapper(x, rms_w, eps=1e-6):\n batch, M, K = x.shape\n assert rms_w.shape[-1] == K\n out = torch.empty_like(x)\n rmsnorm_triton[(batch, M,)](x, rms_w, out,\n *x.stride(),\n *rms_w.stride(),\n *out.stride(),\n N_SIZE=K, eps=eps, BLOCK_N_SIZE=1024,\n )\n return out\n\n@triton.jit\ndef rbe_triton(x_ptr, out_ptr,\n M, K,\n stride_x_batch, stride_x_m, stride_x_n,\n stride_out_batch, stride_out_m, stride_out_n,\n start_token_position,\n THETA: tl.constexpr, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_K: tl.constexpr):\n pid_batch = tl.program_id(axis=0)\n pid = tl.program_id(axis=1)\n pid_m = pid // tl.cdiv(K, BLOCK_SIZE_K)\n pid_n = pid % tl.cdiv(K, BLOCK_SIZE_K)\n\n offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_n = pid_n * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K // 2) * 2 # take only even numbers\n x_ptrs = x_ptr + (pid_batch * stride_x_batch + stride_x_m * offs_m[:, None] + stride_x_n * offs_n[None, :])\n x_real_mask = (offs_m[:, None] < M) & (offs_n[None, :] < K)\n real = tl.load(x_ptrs, mask=x_real_mask, other=0.0)\n x_imag_mask = (offs_m[:, None] < M) & (1 + offs_n[None, :] < K)\n imag = tl.load(x_ptrs + 1, mask=x_imag_mask, other=0.0)\n tl.debug_barrier()\n start_block = start_token_position + pid_m * BLOCK_SIZE_M\n cos, sin = get_freq_multi_tokens(offs_cn=offs_n, starting_idx=start_block, theta=THETA, NB_TOKENS=BLOCK_SIZE_M)\n\n out_real = real * cos - imag * sin\n out_imag = real * sin + imag * cos\n tl.debug_barrier()\n out_ptrs = out_ptr + (\n pid_batch * stride_out_batch + stride_out_m * offs_m[:, None] + stride_out_n * offs_n[None, :])\n out_real_mask = (offs_m[:, None] < M) & (offs_n[None, :] < K)\n tl.store(out_ptrs, out_real, mask=out_real_mask)\n out_imag_mask = (offs_m[:, None] < M) & (1 + offs_n[None, :] < K)\n tl.store(out_ptrs + 1, out_imag, mask=out_imag_mask)\n\ndef rbe_triton_wrapper(x: torch.Tensor, pos: int) -> torch.Tensor:\n batch, M, K = x.shape\n out = torch.empty_like(x)\n grid = lambda META: (\n batch, triton.cdiv(META[\"M\"], META[\"BLOCK_SIZE_M\"]) * triton.cdiv(META[\"K\"], META[\"BLOCK_SIZE_K\"]),)\n\n rbe_triton[grid](x, out,\n M, K,\n *x.stride(),\n *out.stride(),\n start_token_position=pos, THETA=10000., BLOCK_SIZE_M=2, BLOCK_SIZE_K=1024)\n return out\n\n@triton.jit\ndef rms_matmul_rbe(\n x_ptr, w_ptr, rms_w_ptr, out_ptr,\n M, N, K,\n stride_x_batch, stride_x_m, stride_x_k,\n stride_w_k, stride_w_n,\n stride_rms_w,\n stride_out_batch, stride_out_m, stride_out_n,\n start_token_position,\n USE_FP8: tl.constexpr,\n RBE_EPILOGUE: tl.constexpr,\n THETA: tl.constexpr,\n EPS: tl.constexpr,\n BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,\n):\n pid_batch = tl.program_id(axis=0)\n pid = tl.program_id(axis=1)\n pid_m = pid // tl.cdiv(N, BLOCK_SIZE_N)\n pid_n = pid % tl.cdiv(N, BLOCK_SIZE_N)\n\n offs_m = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M\n offs_n = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N\n offs_k = tl.arange(0, BLOCK_SIZE_K)\n x_ptrs = x_ptr + (pid_batch * stride_x_batch + offs_m[:, None] * stride_x_m + offs_k[None, :] * stride_x_k)\n w_ptrs = w_ptr + (offs_k[:, None] * stride_w_k + offs_n[None, :] * stride_w_n)\n accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n rms_w_ptrs = rms_w_ptr + tl.arange(0, BLOCK_SIZE_K)[None, :] * stride_rms_w\n x_sum = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_K), dtype=tl.float32)\n for _ in range(0, tl.cdiv(K, BLOCK_SIZE_K)):\n x = tl.load(x_ptrs)\n x_sum += tl.math.pow(x.to(tl.float32), 2)\n rms_w = tl.load(rms_w_ptrs)\n if USE_FP8:\n rms_w = rms_w.to(tl.float8e5, bitcast=True)\n rms_w = rms_w.to(tl.float16)\n x = x * rms_w\n w = tl.load(w_ptrs)\n if USE_FP8:\n w = w.to(tl.float8e5, bitcast=True)\n w = w.to(tl.float32)\n w = w.to(tl.float16)\n accumulator += tl.dot(x, w)\n x_ptrs += BLOCK_SIZE_K * stride_x_k\n w_ptrs += BLOCK_SIZE_K * stride_w_k\n rms_w_ptrs += BLOCK_SIZE_K * stride_rms_w\n x_mean = tl.sum(x_sum, axis=1) / K + EPS\n x_norm = tl.math.rsqrt(x_mean)\n accumulator = accumulator * x_norm[:, None]\n\n offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n out_ptrs = out_ptr + (\n pid_batch * stride_out_batch + offs_m[:, None] * stride_out_m + offs_n[None, :] * stride_out_n)\n out_mask = (offs_m[:, None] < M) & (offs_n[:, None] < N)\n\n if RBE_EPILOGUE:\n tl.store(out_ptrs, accumulator, mask=out_mask)\n tl.debug_barrier()\n rbe_triton(out_ptr, out_ptr, M, N, stride_out_batch, stride_out_m, stride_out_n, stride_out_batch, stride_out_m,\n stride_out_n, start_token_position, THETA,\n BLOCK_SIZE_M, BLOCK_SIZE_N)\n else:\n tl.store(out_ptrs, accumulator, mask=out_mask)\n\ndef rms_matmul_rbe_wrapper(x: torch.Tensor, weight: torch.Tensor, rms_w: torch.Tensor, use_rbe: bool, start_pos: int,\n n_heads: int, head_dim: int):\n assert weight.dtype == rms_w.dtype\n assert weight.dtype in [torch.float16, torch.int8]\n batch, M, K = x.shape\n weight_t = weight.t()\n K_W, N = weight_t.shape\n assert K == K_W\n out = torch.empty((batch, M, N), dtype=weight_t.dtype, device=weight_t.device)\n out_ptr = triton.reinterpret(out, tl.float8e5 if out.dtype == torch.int8 else tl.float16)\n\n grid = lambda META: (\n batch, triton.cdiv(META[\"M\"], META[\"BLOCK_SIZE_M\"]) * triton.cdiv(META[\"N\"], META[\"BLOCK_SIZE_N\"]))\n\n rms_matmul_rbe[grid](\n x_ptr=x,\n w_ptr=weight_t, rms_w_ptr=rms_w, out_ptr=out_ptr,\n M=M, N=N, K=K,\n stride_x_batch=x.stride(0), stride_x_m=x.stride(1), stride_x_k=x.stride(2),\n stride_w_k=weight_t.stride(0), stride_w_n=weight_t.stride(1),\n stride_rms_w=rms_w.stride(0),\n stride_out_batch=out.stride(0), stride_out_m=out.stride(1), stride_out_n=out.stride(2),\n start_token_position=start_pos,\n USE_FP8=weight_t.dtype == torch.int8,\n RBE_EPILOGUE=use_rbe,\n THETA=10000.,\n EPS=1e-6,\n BLOCK_SIZE_M=16, BLOCK_SIZE_N=64, BLOCK_SIZE_K=64,\n num_stages=4, num_warps=4\n )\n out = out.view(batch, M, n_heads, head_dim)\n return out\n\n@triton.jit\ndef rms_matmul_rbe_qkv(x_ptr,\n q_weight_ptr, k_weight_ptr, v_weight_ptr,\n rms_w_ptr,\n q_ptr, k_ptr, v_ptr,\n M, N, K,\n stride_x_batch, stride_x_m, stride_x_k,\n stride_q_w_k, stride_q_w_n,\n stride_k_w_k, stride_k_w_n,\n stride_v_w_k, stride_v_w_n,\n stride_rms_w,\n stride_q_batch, stride_q_m, stride_q_n,\n stride_k_batch, stride_k_m, stride_k_n,\n stride_v_batch, stride_v_m, stride_v_n,\n start_token_position,\n USE_FP8: tl.constexpr,\n THETA: tl.constexpr,\n EPS: tl.constexpr,\n BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr):\n # q\n rms_matmul_rbe(\n x_ptr=x_ptr,\n w_ptr=q_weight_ptr, rms_w_ptr=rms_w_ptr, out_ptr=q_ptr,\n M=M, N=N, K=K,\n stride_x_batch=stride_x_batch, stride_x_m=stride_x_m, stride_x_k=stride_x_k,\n stride_w_k=stride_q_w_k, stride_w_n=stride_q_w_n,\n stride_rms_w=stride_rms_w,\n stride_out_batch=stride_q_batch, stride_out_m=stride_q_m, stride_out_n=stride_q_n,\n start_token_position=start_token_position,\n USE_FP8=USE_FP8,\n RBE_EPILOGUE=True,\n THETA=THETA,\n EPS=EPS,\n BLOCK_SIZE_M=BLOCK_SIZE_M, BLOCK_SIZE_N=BLOCK_SIZE_N, BLOCK_SIZE_K=BLOCK_SIZE_K,\n )\n # k\n rms_matmul_rbe(\n x_ptr=x_ptr,\n w_ptr=k_weight_ptr, rms_w_ptr=rms_w_ptr, out_ptr=k_ptr,\n M=M, N=N, K=K,\n stride_x_batch=stride_x_batch, stride_x_m=stride_x_m, stride_x_k=stride_x_k,\n stride_w_k=stride_k_w_k, stride_w_n=stride_k_w_n,\n stride_rms_w=stride_rms_w,\n stride_out_batch=stride_k_batch, stride_out_m=stride_k_m, stride_out_n=stride_k_n,\n start_token_position=start_token_position,\n USE_FP8=USE_FP8,\n RBE_EPILOGUE=True,\n THETA=THETA,\n EPS=EPS,\n BLOCK_SIZE_M=BLOCK_SIZE_M, BLOCK_SIZE_N=BLOCK_SIZE_N, BLOCK_SIZE_K=BLOCK_SIZE_K,\n )\n # v\n rms_matmul_rbe(\n x_ptr=x_ptr,\n w_ptr=v_weight_ptr, rms_w_ptr=rms_w_ptr, out_ptr=v_ptr,\n M=M, N=N, K=K,\n stride_x_batch=stride_x_batch, stride_x_m=stride_x_m, stride_x_k=stride_x_k,\n stride_w_k=stride_v_w_k, stride_w_n=stride_v_w_n,\n stride_rms_w=stride_rms_w,\n stride_out_batch=stride_v_batch, stride_out_m=stride_v_m, stride_out_n=stride_v_n,\n start_token_position=start_token_position,\n USE_FP8=USE_FP8,\n RBE_EPILOGUE=False,\n THETA=THETA,\n EPS=EPS,\n BLOCK_SIZE_M=BLOCK_SIZE_M, BLOCK_SIZE_N=BLOCK_SIZE_N, BLOCK_SIZE_K=BLOCK_SIZE_K,\n )\n\ndef rms_matmul_rbe_qkv_wrapper(x: torch.Tensor,\n start_pos: int,\n q_weight: torch.Tensor, k_weight: torch.Tensor, v_weight: torch.Tensor,\n rms_w: torch.Tensor,\n n_heads: int, head_dim: int,\n k: torch.Tensor,\n v: torch.Tensor,\n eps: float = 1e-6, theta=10000.):\n assert q_weight.shape == k_weight.shape == v_weight.shape\n assert q_weight.dtype == k_weight.dtype == v_weight.dtype == rms_w.dtype\n assert q_weight.dtype in [torch.float16, torch.int8]\n batch, M, K = x.shape\n\n assert K == rms_w.shape[0]\n\n q_weight_t = q_weight.t()\n k_weight_t = k_weight.t()\n v_weight_t = v_weight.t()\n K_W, N = q_weight_t.shape\n assert K == K_W\n q = torch.empty((batch, M, N), dtype=torch.float16, device=q_weight_t.device)\n\n k = k.view((batch, M, N))\n v = v.view((batch, M, N))\n assert k.dtype == k_weight.dtype\n assert v.dtype == v_weight.dtype\n\n q_ptr = triton.reinterpret(q, tl.float16)\n k_ptr = triton.reinterpret(k, tl.float8e5 if k.dtype == torch.int8 else tl.float16)\n v_ptr = triton.reinterpret(v, tl.float8e5 if v.dtype == torch.int8 else tl.float16)\n\n grid = lambda META: (\n batch, triton.cdiv(META[\"M\"], META[\"BLOCK_SIZE_M\"]) * triton.cdiv(META[\"N\"], META[\"BLOCK_SIZE_N\"]))\n\n rms_matmul_rbe_qkv[grid](\n x_ptr=x,\n q_weight_ptr=q_weight_t, k_weight_ptr=k_weight_t, v_weight_ptr=v_weight_t,\n rms_w_ptr=rms_w,\n q_ptr=q_ptr, k_ptr=k_ptr, v_ptr=v_ptr,\n M=M, N=N, K=K,\n stride_x_batch=x.stride(0), stride_x_m=x.stride(1), stride_x_k=x.stride(2),\n stride_q_w_k=q_weight_t.stride(0), stride_q_w_n=q_weight_t.stride(1),\n stride_k_w_k=k_weight_t.stride(0), stride_k_w_n=k_weight_t.stride(1),\n stride_v_w_k=v_weight_t.stride(0), stride_v_w_n=v_weight_t.stride(1),\n stride_rms_w=rms_w.stride(0),\n stride_q_batch=q.stride(0), stride_q_m=q.stride(1), stride_q_n=q.stride(2),\n stride_k_batch=k.stride(0), stride_k_m=k.stride(1), stride_k_n=k.stride(2),\n stride_v_batch=v.stride(0), stride_v_m=v.stride(1), stride_v_n=v.stride(2),\n start_token_position=start_pos,\n USE_FP8=q_weight.dtype == torch.int8,\n THETA=theta,\n EPS=eps,\n BLOCK_SIZE_M=16, BLOCK_SIZE_N=64, BLOCK_SIZE_K=64,\n num_stages=4, num_warps=4\n )\n q = q.view(batch, M, n_heads, head_dim)\n k = k.view(batch, M, n_heads, head_dim)\n v = v.view(batch, M, n_heads, head_dim)\n return q, k, v\n", - "description_1": "Use triton language to implement four kernels: 1) rmsnorm_triton for Root Mean Square Layer Normalization with 11 tensor arguments and 3 meta-parameters. 2) rbe_triton for Rotary Positional Embedding computation with 9 tensor arguments and 3 meta-parameters. 3) rms_matmul_rbe for performing matrix multiplication with RMS normalization and optional rotary embedding epilogue, having 10 tensor arguments and 8 meta-parameters. 4) rms_matmul_rbe_qkv for applying RMS and Rotary embeddings on QKV matrices in sequence with 18 tensor arguments and 5 meta-parameters.", - "description_2": "Use triton language to implement RMS normalization with meta-parameters. Use triton language to apply rotary positional embeddings on tensor matrices.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef kernel(\n V,\n M,\n Out,\n vec_stride_x,\n matrix_stride_x,\n matrix_stride_y,\n out_stride_x,\n out_stride_y,\n SIZE_M: tl.constexpr,\n D_HEAD: tl.constexpr,\n IS_DOT: tl.constexpr,\n):\n size_m_arange = tl.arange(0, SIZE_M)\n d_head_arange = tl.arange(0, D_HEAD)\n # transpose matrix\n matrix_ptr = M + d_head_arange[None, :] * matrix_stride_y + size_m_arange[:, None] * matrix_stride_x\n matrix = tl.load(matrix_ptr)\n out_ptr = Out + size_m_arange * out_stride_y\n\n if IS_DOT:\n vec_ptr = V + vec_stride_x * size_m_arange[:, None] + vec_stride_x * d_head_arange[None, :]\n vec = tl.load(vec_ptr, mask=size_m_arange[:, None] < 1, other=0.0)\n result = tl.dot(matrix, vec, trans_a=False, trans_b=True)\n else:\n vec_ptr = V + vec_stride_x * d_head_arange[None, :]\n vec = tl.load(vec_ptr)\n result = matrix.to(tl.float32) * vec.to(tl.float32)\n\n result = tl.sum(result, axis=1)\n tl.store(out_ptr, result)\n\nsize_m = 16\nd_head = 128\n\nvec = torch.randn((d_head,), dtype=torch.float16, device=\"cuda\")\nmatrix = torch.randn((size_m, d_head), dtype=torch.float16, device=\"cuda\")\nout = torch.zeros((1, size_m), dtype=torch.float16, device=\"cuda\")\n\nn_repeat = 10000\ngrid = (10000,)\n\nprint(\"CUDA times\")\nfor use_dot in [True, False]:\n start_event = [torch.cuda.Event(enable_timing=True) for _ in range(n_repeat)]\n end_event = [torch.cuda.Event(enable_timing=True) for _ in range(n_repeat)]\n # warmup\n for _ in range(n_repeat):\n kernel[grid](\n vec,\n matrix,\n out,\n *vec.stride(),\n *matrix.stride(),\n *out.stride(),\n size_m,\n d_head,\n use_dot,\n )\n # run\n torch.cuda.synchronize()\n for i in range(n_repeat):\n start_event[i].record()\n kernel[grid](\n vec,\n matrix,\n out,\n *vec.stride(),\n *matrix.stride(),\n *out.stride(),\n size_m,\n d_head,\n use_dot,\n )\n torch.cuda.synchronize()\n end_event[i].record()\n times_run = torch.median(torch.tensor([s.elapsed_time(e) for s, e in zip(start_event, end_event)]))\n # overhead\n\n for i in range(n_repeat):\n start_event[i].record()\n overhead_kernel[grid](\n vec,\n matrix,\n out,\n *vec.stride(),\n *matrix.stride(),\n *out.stride(),\n size_m,\n d_head,\n use_dot,\n )\n torch.cuda.synchronize()\n end_event[i].record()\n times_overhead = torch.median(torch.tensor([s.elapsed_time(e) for s, e in zip(start_event, end_event)]))\n assert torch.allclose(out, vec @ matrix.t(), atol=1e-4)\n print(f\"{'tl.dot(a, b)' if use_dot else 'tl.sum(a * b, 1)':<20}{times_run.item() - times_overhead.item():.4f}\")\n", - "description_1": "Use triton language to define a kernel function that performs either a dot product or element-wise multiplication between a matrix and a vector, depending on a boolean flag. The kernel takes 10 parameters: V (vector), M (matrix), Out (output), vec_stride_x, matrix_stride_x, matrix_stride_y, out_stride_x, out_stride_y (stride values for accessing elements), SIZE_M (size of the matrix), D_HEAD (dimension of the head), and IS_DOT (flag to choose operation). The kernel loads the matrix and vector, performs the specified operation, and stores the result in the output.", - "description_2": "Use triton language to implement a kernel that computes either a dot product or element-wise multiplication between a matrix and a vector, controlled by a flag.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef kernel(\n M,\n Out,\n matrix_stridex,\n matrix_stridey,\n out_stridex,\n out_stridey,\n SIZE_M: tl.constexpr,\n D_HEAD: tl.constexpr,\n):\n size_m_arange = tl.arange(0, SIZE_M)\n d_head_arange = tl.arange(0, D_HEAD)\n # transpose\n matrix_ptr = M + d_head_arange[None, :] * matrix_stridey + size_m_arange[:, None] * matrix_stridex\n out_ptr = Out + d_head_arange[None, :] * out_stridex + size_m_arange[:, None] * out_stridey\n matrix = tl.load(matrix_ptr)\n tl.store(out_ptr, matrix)\n\nsize_m = 16\nd_head = 32\n\nmatrix = torch.randn((size_m, d_head), dtype=torch.float16, device=\"cuda\")\nout = torch.zeros((d_head, size_m), dtype=torch.float16, device=\"cuda\")\n\ngrid = (1,)\nkernel[grid](\n matrix,\n out,\n *matrix.stride(),\n *out.stride(),\n size_m,\n d_head,\n)\n\nassert torch.allclose(matrix.t(), out)\n", - "description_1": "Use triton language to define a kernel that transposes a matrix. The kernel takes 8 parameters: M (input matrix), Out (output matrix), matrix_stridex (stride of input matrix in x direction), matrix_stridey (stride of input matrix in y direction), out_stridex (stride of output matrix in x direction), out_stridey (stride of output matrix in y direction), SIZE_M (number of rows in input matrix), and D_HEAD (number of columns in input matrix). The kernel computes the transpose of the input matrix and stores it in the output matrix.", - "description_2": "Use triton language to create a kernel for transposing a matrix. The kernel should handle input and output matrix pointers, strides, and dimensions to perform the transpose operation.", - "difficulty": 2 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n# Triton kernel for fused attention\n@triton.jit\ndef fused_attention_kernel(\n Out, L, M, # outputs\n Q, K, V,\n sm_scale,\n seq_len,\n BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,\n BLOCK_N: tl.constexpr,\n):\n start_m = tl.program_id(0)\n off_hz = tl.program_id(1)\n stride_h = BLOCK_DMODEL * seq_len\n\n # initialize offsets\n offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)\n offs_n = tl.arange(0, BLOCK_N)\n offs_d = tl.arange(0, BLOCK_DMODEL)\n off_q = off_hz * stride_h + offs_m[:, None] * BLOCK_DMODEL + offs_d[None, :]\n off_k = off_hz * stride_h + offs_n[None, :] * BLOCK_DMODEL + offs_d[:, None]\n off_v = off_hz * stride_h + offs_n[:, None] * BLOCK_DMODEL + offs_d[None, :]\n # Initialize pointers to Q, K, V\n q_ptrs = Q + off_q\n k_ptrs = K + off_k\n v_ptrs = V + off_v\n # initialize pointer to m and l\n m_prev = tl.zeros([BLOCK_M], dtype=tl.float32) - float(\"inf\")\n l_prev = tl.zeros([BLOCK_M], dtype=tl.float32)\n acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)\n # load q: it will stay in SRAM throughout\n q = tl.load(q_ptrs)\n # loop over k, v and update accumulator\n for start_n in range(0, (start_m + 1) * BLOCK_M, BLOCK_N):\n # -- compute qk ----\n k = tl.load(k_ptrs)\n qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)\n qk += tl.dot(q, k)\n qk *= sm_scale\n qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float(\"-inf\"))\n # compute new m\n m_curr = tl.maximum(tl.max(qk, 1), m_prev)\n # correct old l\n l_prev *= tl.exp(m_prev - m_curr)\n # attention weights\n p = tl.exp(qk - m_curr[:, None])\n l_curr = tl.sum(p, 1) + l_prev\n # rescale operands of matmuls\n l_rcp = 1. / l_curr\n p *= l_rcp[:, None]\n acc *= (l_prev * l_rcp)[:, None]\n # update acc\n p = p.to(Q.dtype.element_ty)\n v = tl.load(v_ptrs)\n acc += tl.dot(p, v)\n # update m_i and l_i\n l_prev = l_curr\n m_prev = m_curr\n # update pointers\n k_ptrs += BLOCK_N * BLOCK_DMODEL\n v_ptrs += BLOCK_N * BLOCK_DMODEL\n # rematerialize offsets to save registers\n start_m = tl.program_id(0)\n offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)\n # write back l and m\n l_ptrs = L + off_hz * seq_len + offs_m\n m_ptrs = M + off_hz * seq_len + offs_m\n tl.store(l_ptrs, l_prev)\n tl.store(m_ptrs, m_prev)\n # initialize pointers to output\n offs_n = tl.arange(0, BLOCK_DMODEL)\n off_o = off_hz * stride_h + offs_m[:, None] * BLOCK_DMODEL + offs_n[None, :]\n out_ptrs = Out + off_o\n tl.store(out_ptrs, acc)\n\n# Function to call the Triton kernel\ndef fused_attention(q, k, v, sm_scale, o_buf=None, l_buf=None, m_buf=None):\n BLOCK = 128 if q.dtype == torch.float16 else 64\n # shape constraints\n Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]\n assert Lq == Lk and Lk == Lv\n assert Lk in {16, 32, 64, 128}\n o = torch.empty_like(q) if o_buf is None else o_buf\n grid = (triton.cdiv(q.shape[2], BLOCK), q.shape[0] * q.shape[1], 1)\n shape = (q.shape[0] * q.shape[1], q.shape[2])\n L = torch.empty(shape, device=q.device, dtype=torch.float32) if l_buf is None else l_buf\n m = torch.empty(shape, device=q.device, dtype=torch.float32) if m_buf is None else m_buf\n num_warps = 4 if Lk <= 64 else 8\n\n fused_attention_kernel[grid](\n o, L, m,\n q, k, v,\n sm_scale, q.shape[2],\n # tl.constexpr\n BLOCK_M=BLOCK,\n BLOCK_N=BLOCK,\n BLOCK_DMODEL=Lk,\n num_warps=num_warps\n )\n\n return o\n", - "description_1": "Use triton language to implement a fused attention kernel that computes the attention output given query (Q), key (K), and value (V) matrices. The kernel takes 9 parameters: Out (output tensor), L (tensor for storing intermediate results), M (tensor for storing intermediate results), Q (query tensor), K (key tensor), V (value tensor), sm_scale (scale for softmax), seq_len (sequence length), and three block sizes (BLOCK_M, BLOCK_DMODEL, BLOCK_N) as compile-time constants. The kernel computes the attention scores using dot products and applies softmax scaling, updating the output tensor with the accumulated results. The fused_attention function calls this kernel with 7 parameters: q (query tensor), k (key tensor), v (value tensor), sm_scale (scale for softmax), and optional buffers o_buf, l_buf, m_buf for output and intermediate results.", - "description_2": "Use triton language to create a fused attention operator that efficiently computes attention scores and outputs using query, key, and value matrices with softmax scaling.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n# Triton kernel for fused attention\n@triton.jit\ndef fused_attention_kernel(\n Out, L, M, # outputs\n Q, K, V,\n sm_scale,\n seq_len,\n BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,\n BLOCK_N: tl.constexpr,\n):\n start_m = tl.program_id(0)\n off_hz = tl.program_id(1)\n stride_h = BLOCK_DMODEL * seq_len\n\n # initialize offsets\n offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)\n offs_n = tl.arange(0, BLOCK_N)\n offs_d = tl.arange(0, BLOCK_DMODEL)\n off_q = off_hz * stride_h + offs_m[:, None] * BLOCK_DMODEL + offs_d[None, :]\n off_k = off_hz * stride_h + offs_n[None, :] * BLOCK_DMODEL + offs_d[:, None]\n off_v = off_hz * stride_h + offs_n[:, None] * BLOCK_DMODEL + offs_d[None, :]\n # Initialize pointers to Q, K, V\n q_ptrs = Q + off_q\n k_ptrs = K + off_k\n v_ptrs = V + off_v\n # initialize pointer to m and l\n m_prev = tl.zeros([BLOCK_M], dtype=tl.float32) - float(\"inf\")\n l_prev = tl.zeros([BLOCK_M], dtype=tl.float32)\n acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)\n # load q: it will stay in SRAM throughout\n q = tl.load(q_ptrs)\n # loop over k, v and update accumulator\n for start_n in range(0, (start_m + 1) * BLOCK_M, BLOCK_N):\n # -- compute qk ----\n k = tl.load(k_ptrs)\n qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)\n qk += tl.dot(q, k)\n qk *= sm_scale\n qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float(\"-inf\"))\n # compute new m\n m_curr = tl.maximum(tl.max(qk, 1), m_prev)\n # correct old l\n l_prev *= tl.exp(m_prev - m_curr)\n # attention weights\n p = tl.exp(qk - m_curr[:, None])\n l_curr = tl.sum(p, 1) + l_prev\n # rescale operands of matmuls\n l_rcp = 1. / l_curr\n p *= l_rcp[:, None]\n acc *= (l_prev * l_rcp)[:, None]\n # update acc\n p = p.to(Q.dtype.element_ty)\n v = tl.load(v_ptrs)\n acc += tl.dot(p, v)\n # update m_i and l_i\n l_prev = l_curr\n m_prev = m_curr\n # update pointers\n k_ptrs += BLOCK_N * BLOCK_DMODEL\n v_ptrs += BLOCK_N * BLOCK_DMODEL\n # rematerialize offsets to save registers\n start_m = tl.program_id(0)\n offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)\n # write back l and m\n l_ptrs = L + off_hz * seq_len + offs_m\n m_ptrs = M + off_hz * seq_len + offs_m\n tl.store(l_ptrs, l_prev)\n tl.store(m_ptrs, m_prev)\n # initialize pointers to output\n offs_n = tl.arange(0, BLOCK_DMODEL)\n off_o = off_hz * stride_h + offs_m[:, None] * BLOCK_DMODEL + offs_n[None, :]\n out_ptrs = Out + off_o\n tl.store(out_ptrs, acc)\n\n# Function to call the fused_attention_kernel\ndef fused_attention(q, k, v, sm_scale, o_buf=None, l_buf=None, m_buf=None):\n BLOCK = 128 if q.dtype == torch.float16 else 64\n # shape constraints\n Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]\n assert Lq == Lk and Lk == Lv\n assert Lk in {16, 32, 64, 128}\n o = torch.empty_like(q) if o_buf is None else o_buf\n grid = (triton.cdiv(q.shape[2], BLOCK), q.shape[0] * q.shape[1], 1)\n shape = (q.shape[0] * q.shape[1], q.shape[2])\n L = torch.empty(shape, device=q.device, dtype=torch.float32) if l_buf is None else l_buf\n m = torch.empty(shape, device=q.device, dtype=torch.float32) if m_buf is None else m_buf\n num_warps = 4 if Lk <= 64 else 8\n\n fused_attention_kernel[grid](\n o, L, m,\n q, k, v,\n sm_scale, q.shape[2],\n # tl.constexpr\n BLOCK_M=BLOCK,\n BLOCK_N=BLOCK,\n BLOCK_DMODEL=Lk,\n num_warps=num_warps\n )\n\n return o\n", - "description_1": "Use triton language to implement a fused attention mechanism. The kernel 'fused_attention_kernel' takes 12 parameters: Out, L, M as output buffers, Q, K, V as input matrices, sm_scale as scale for softmax, seq_len as sequence length, BLOCK_M, BLOCK_DMODEL, BLOCK_N as block sizes for matrix computation. The kernel computes scaled dot-product attention using these parameters and updates the output buffers accordingly. The function 'fused_attention' calls the kernel and takes 7 parameters: q, k, v matrices, sm_scale as scale for softmax, and optional output buffers o_buf, l_buf, m_buf. It configures execution based on input tensor dimensions and invokes the kernel with appropriate grid and block sizes.", - "description_2": "Use triton language to implement and invoke a fused attention kernel that performs scaled dot-product attention. The kernel handles input matrices Q, K, V and computes attention with respect to output buffers, using block-wise parallelization.", - "difficulty": 4 - }, - { - "code": "import triton\nimport triton.language as tl\n\n# Kernel to calculate 1D offset\n@triton.jit\ndef get_1d_offest(size, n_prev_chunks):\n return n_prev_chunks * size + tl.arange(0, size)\n\n# Kernel to calculate 2D offset\n@triton.jit\ndef get_2d_offset(offs_0, offs_1, stride_0, stride_1=1):\n return tl.expand_dims(offs_0, 1) * stride_0 + tl.expand_dims(offs_1, 0) * stride_1\n\n# Kernel to create a 1D mask\n@triton.jit\ndef get_1d_mask(offs, max):\n return offs < max\n\n# Kernel to create a 2D mask\n@triton.jit\ndef get_2d_mask(offs_0, offs_1, max_0, max_1):\n return (tl.expand_dims(offs_0, 1) < max_0) & (tl.expand_dims(offs_1, 0) < max_1)\n", - "description_1": "Use triton language to define four kernels: (1) get_1d_offest with 2 parameters: size (int) and n_prev_chunks (int), which calculates 1D offsets; (2) get_2d_offset with 4 parameters: offs_0 (tensor), offs_1 (tensor), stride_0 (int), and stride_1 (int, default=1), which calculates 2D offsets; (3) get_1d_mask with 2 parameters: offs (tensor) and max (int), which creates a 1D mask; (4) get_2d_mask with 4 parameters: offs_0 (tensor), offs_1 (tensor), max_0 (int), and max_1 (int), which creates a 2D mask.", - "description_2": "Use triton language to define kernels for calculating offsets and masks in 1D and 2D.", - "difficulty": 2 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef triton_(in_ptr0, in_ptr1, out_ptr0, xnumel, XBLOCK: tl.constexpr):\n # Calculate offset for block\n xoffset = tl.program_id(0) * XBLOCK\n # Compute indices for current block\n xindex = xoffset + tl.arange(0, XBLOCK)[:]\n # Create a mask for valid indices\n xmask = xindex < xnumel\n # Load input data into registers\n tmp0 = tl.load(in_ptr0 + xindex, xmask)\n tmp2 = tl.load(in_ptr1 + xindex, xmask)\n # Compute the element-wise squared sum\n tmp1 = tmp0 * tmp0\n tmp3 = tmp2 * tmp2\n tmp4 = tmp1 + tmp3\n # Store the result\n tl.store(out_ptr0 + xindex, tmp4, xmask)\n\ndef load_triton_kernel() -> None:\n # Prepare input and output tensors\n x = torch.randn(1000, device=\"cuda\")\n y = torch.randn(1000, device=\"cuda\")\n z = torch.empty_like(y)\n # Define block size\n BLOCK_SIZE = 256\n # Launch the Triton kernel\n triton_[(triton.cdiv(1000, 32),)](x, y, z, 1000, XBLOCK=BLOCK_SIZE)\n # Check correctness\n assert torch.allclose(z, x * x + y * y)\n", - "description_1": "Use triton language to define a kernel 'triton_' that computes the element-wise squared sum of two input arrays. The kernel takes 5 arguments: 'in_ptr0' and 'in_ptr1' (pointers to the input data), 'out_ptr0' (pointer to the output data), 'xnumel' (the total number of elements), and 'XBLOCK' (a compile-time constant defining the block size). The kernel is invoked by the 'load_triton_kernel' function, which prepares the input and output tensors, sets the block size, and launches the kernel.", - "description_2": "Use triton language to create a kernel that calculates the element-wise squared sum of two input arrays using GPU parallelization.", - "difficulty": 3 - }, - { - "code": "import triton\nimport triton.language as tl\n\n@triton.jit\ndef sigmoid(inp):\n \"\"\"Applies sigmoid to the input\"\"\"\n return 1 / (1 + tl.exp(-inp))\n\n@triton.jit\ndef sigmoid_grad(inp):\n out = sigmoid(inp)\n return out * (1 - out)\n\n@triton.jit\ndef apply_act_func(inp, drop_p, seed, offset, act_func, dropout):\n if act_func != \"relu\":\n input_tensor = inp.to(tl.float32)\n if act_func == \"sigmoid\":\n output = sigmoid(input_tensor)\n if dropout:\n output = apply_dropout(input_tensor, drop_p, seed, offset)\n return output\n\n@triton.autotune(configs=element_wise_kernel_config(), key=[\"size\"])\n@triton.jit\ndef act_func_forward_kernel(\n input_pointer,\n output_pointer,\n size,\n drop_p,\n seed,\n act_func: tl.constexpr,\n dropout: tl.constexpr,\n BLOCK_SIZE: tl.constexpr,\n):\n pid = tl.program_id(axis=0)\n offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)\n mask = offset < size\n input_tensor = tl.load(input_pointer + offset, mask=mask)\n tl.store(\n output_pointer + offset,\n apply_act_func(input_tensor, drop_p, seed, offset, act_func, dropout),\n mask=mask,\n )\n\n@triton.jit\ndef apply_act_grad_func(\n out_grad, inp, drop_p, seed, offset, act_func: tl.constexpr, dropout: tl.constexpr\n):\n if act_func != \"relu\":\n inp = inp.to(tl.float32)\n if act_func == \"sigmoid\":\n out = sigmoid_grad(inp)\n if dropout:\n out_grad = apply_dropout_grad(out_grad, drop_p, seed, offset)\n return out_grad * out\n\n@triton.autotune(configs=element_wise_kernel_config(), key=[\"size\"])\n@triton.jit\ndef act_func_backward_kernel(\n out_grad_ptr,\n inp_ptr,\n out_ptr,\n size,\n drop_p,\n seed,\n act_func: tl.constexpr,\n dropout: tl.constexpr,\n BLOCK_SIZE: tl.constexpr,\n):\n pid = tl.program_id(axis=0)\n offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)\n mask = offset < size\n out_grad = tl.load(out_grad_ptr + offset, mask=mask)\n inp = tl.load(inp_ptr, mask=mask)\n tl.store(\n out_ptr + offset,\n apply_act_grad_func(out_grad, inp, drop_p, seed, offset, act_func, dropout),\n mask=mask,\n )\n", - "description_1": "Use triton language to define kernels for forward and backward activation functions with dropout support. The `sigmoid` kernel applies the sigmoid function to inputs, and `sigmoid_grad` calculates its gradient. The `apply_act_func` and `apply_act_grad_func` helper functions handle different activation functions and optional dropout. `act_func_forward_kernel` and `act_func_backward_kernel` are triton kernels for forward and backward passes respectively. They take input pointers, output pointers, sizes, dropout parameters, and other configurations to perform element-wise activation and gradient computation.", - "description_2": "Use triton language to create kernels for activation functions with optional dropout. Implement forward and backward operations using configurable parameters for efficient computation.", - "difficulty": 2 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n# Triton kernel for element-wise addition\n@triton.jit\ndef _add(X, Y, Z, N: tl.constexpr, BLOCK_SIZE: tl.constexpr):\n # Obtain the program ID for parallel execution\n pid = tl.program_id(0)\n # Calculate offsets for the current block\n offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)\n # Create a mask to handle boundary conditions\n mask = offsets < N\n # Load elements from X and Y with the mask\n X = tl.load(X + offsets, mask=mask)\n Y = tl.load(Y + offsets, mask=mask)\n # Store the result of X + Y into Z with the mask\n tl.store(Z + offsets, X + Y, mask=mask)\n\n# Function to call the Triton kernel\ndef add(X, Y):\n # Get the number of elements\n N = X.shape[0]\n # Ensure inputs are on CUDA\n assert X.is_cuda and Y.is_cuda\n # Create an output tensor\n Z = torch.empty_like(X)\n # Define block size for Triton kernel\n BLOCK_SIZE = 1024\n # Calculate grid size for kernel launch\n grid = (triton.cdiv(N, BLOCK_SIZE),)\n # Launch the Triton kernel\n _add[grid](X, Y, Z, N, BLOCK_SIZE=BLOCK_SIZE)\n return Z\n\n# Example usage of the Triton kernel\ndef main():\n # Create random input tensors on CUDA\n x = torch.randn(1000, device=\"cuda\")\n y = torch.randn(1000, device=\"cuda\")\n # Perform addition using Triton\n z = add(x, y)\n # Perform addition using PyTorch for verification\n z_torch = x + y\n # Print results\n print(f\"{x = }\")\n print(f\"{y = }\")\n print(f\"{z = }\")\n # Verify the results are close\n assert torch.allclose(z, z_torch)\n print(\"Success! Triton add works correctly\")\n\nif __name__ == \"__main__\":\n main()\n", - "description_1": "Use triton language to implement an element-wise addition kernel. The kernel '_add' takes 5 parameters: X (input tensor), Y (input tensor), Z (output tensor), N (number of elements, constexpr), and BLOCK_SIZE (block size, constexpr). The function 'add' calls this kernel with 2 parameters: X (input tensor) and Y (input tensor), and returns the result tensor Z.", - "description_2": "Use triton language to perform element-wise addition on two input tensors using a custom kernel, ensuring the result matches PyTorch's addition.", - "difficulty": 2 - }, - { - "code": "import triton\nimport triton.language as tl\n\n# Heuristic function to calculate BLOCK_SIZE_SPATIAL based on batch and spatial dimensions\ndef BLOCK_SIZE_SPATIAL_heuristics(args: dict) -> int:\n BLOCK_SIZE_BATCH = triton.next_power_of_2(args[\"b_dim\"])\n BLOCK_SIZE_SPATIAL = triton.next_power_of_2(args[\"s_dim\"])\n return int(min(BLOCK_SIZE_SPATIAL, max(1, 2**14 / BLOCK_SIZE_BATCH)))\n\n# Kernel for RMS normalization\n@triton.autotune(\n configs=warps_kernel_configs(),\n key=[\"b_dim\", \"s_dim\"],\n restore_value=[\"running_mean_pointer\", \"running_var_pointer\"]\n)\n@triton.heuristics({\n \"BLOCK_SIZE_BATCH\": lambda x: triton.next_power_of_2(x[\"b_dim\"]),\n \"BLOCK_SIZE_SPATIAL\": BLOCK_SIZE_SPATIAL_heuristics,\n})\n@triton.jit\ndef rms_norm_forward_kernel(\n input_pointer, weight_pointer, bias_pointer,\n output_pointer, b_dim, s_dim, running_mean_pointer, running_var_pointer,\n BLOCK_SIZE_BATCH: tl.constexpr,\n BLOCK_SIZE_SPATIAL: tl.constexpr\n):\n pass\n\n# Kernel for batch normalization forward pass\n@triton.autotune(\n configs=warps_kernel_configs(),\n key=[\"b_dim\", \"s_dim\"],\n restore_value=[\"running_mean_ptr\", \"running_bias_ptr\"]\n)\n@triton.heuristics(\n values={\n \"BLOCK_SIZE_BATCH\": lambda args: triton.next_power_of_2(args[\"b_dim\"]),\n \"BLOCK_SIZE_SPATIAL\": BLOCK_SIZE_SPATIAL_heuristics\n }\n)\n@triton.jit\ndef batch_norm_forward_kernel(\n inp_ptr, weight_ptr, bias_ptr,\n mean_ptr, inv_std_ptr,\n inp_residual_ptr, pre_act_ptr, out_ptr,\n running_mean_ptr, running_var_ptr,\n b_dim, s_dim,\n inp_b_strd, inp_f_strd, inp_s_strd,\n inp_residual_b_strd, inp_residual_f_strd, inp_residual_s_strd,\n pre_act_b_strd, pre_act_f_strd, pre_act_s_strd,\n out_b_strd, out_f_strd, out_s_strd,\n momentum, eps,\n affine: tl.constexpr,\n is_train: tl.constexpr,\n save_stats: tl.constexpr,\n track_running_stats: tl.constexpr,\n add_residual: tl.constexpr,\n act_func: tl.constexpr,\n save_pre_act: tl.constexpr,\n BLOCK_SIZE_BATCH: tl.constexpr, BLOCK_SIZE_SPATIAL: tl.constexpr\n):\n f_id = tl.program_id(axis=0)\n b_offs = tl.arange(0, BLOCK_SIZE_BATCH)\n b_mask = b_offs < b_dim\n\n m = 0\n mean = 0.0\n var = 0.0\n for s_ind in range(0, tl.cdiv(s_dim, BLOCK_SIZE_SPATIAL)):\n s_offs = s_ind * BLOCK_SIZE_BATCH + tl.arange(0, BLOCK_SIZE_BATCH)\n s_mask = s_offs < s_dim\n\n curr_inp_ptr = (inp_ptr +\n f_id * inp_f_strd +\n b_offs[:, None] * inp_b_strd + \n s_offs[None, :] * inp_s_strd)\n\n curr_inp = tl.load(curr_inp_ptr, mask=b_mask[:, None] & s_mask[None, :]).to(tl.float32)\n s_count = min(BLOCK_SIZE_SPATIAL, s_dim - s_ind * BLOCK_SIZE_SPATIAL)\n curr_m = s_count * b_dim\n m += curr_m\n prev_mean = mean\n mean += (tl.sum(curr_inp) - (prev_mean * curr_m)) / m\n deltas = tl.where(b_mask[:, None] & s_mask[None, :],\n (curr_inp * mean) - (curr_inp * prev_mean), 0.0)\n var += tl.sum(deltas)\n\n var /= m\n inv_std = 1.0 / tl.sqrt(var + eps)\n\n# Kernel for batch normalization backward pass\n@triton.autotune(\n configs=warps_kernel_configs(),\n key=[\"b_dim\", \"s_dim\"],\n)\n@triton.heuristics(\n values={\n \"BLOCK_SIZE_BATCH\": lambda args: triton.next_power_of_2(args[\"b_dim\"]),\n \"BLOCK_SIZE_SPATIAL\": BLOCK_SIZE_SPATIAL_heuristics\n },\n)\n@triton.jit\ndef batch_norm_backward_kernel(\n out_grad_ptr, inp_ptr,\n mean_ptr, inv_std_ptr,\n weight_ptr,\n inp_grad_ptr,\n weight_grad_ptr, bias_grad_ptr,\n b_dim, s_dim,\n out_grad_b_strd, out_grad_f_strd, out_grad_s_strd,\n inp_b_strd, inp_f_strd, inp_s_strd,\n inp_grad_b_strd, inp_grad_f_strd, inp_grad_s_strd,\n affine: tl.constexpr,\n BLOCK_SIZE_BATCH: tl.constexpr,\n BLOCK_SIZE_SPATIAL: tl.constexpr\n):\n f_id = tl.program_id(axis=0)\n\n b_offs = tl.arange(0, BLOCK_SIZE_BATCH)\n b_mask = b_offs < b_dim\n\n mean = tl.load(f_id + mean_ptr)\n inv_std = tl.load(f_id + inv_std_ptr)\n\n inv_std_contrib = 0.0\n mean_contrib = 0.0\n \n for s_ind in range(0, tl.cdiv(s_dim, BLOCK_SIZE_SPATIAL)):\n s_offs = s_ind * BLOCK_SIZE_SPATIAL + tl.arange(0, BLOCK_SIZE_SPATIAL)\n s_mask = s_offs < s_dim\n\n curr_out_grad_ptr = (out_grad_ptr +\n f_id * out_grad_f_strd +\n b_offs[:, None] * out_grad_b_strd +\n s_offs[None, :] * out_grad_s_strd)\n \n\n curr_inp_ptr = (inp_ptr +\n f_id * inp_f_strd +\n b_offs[:, None] * inp_b_strd +\n s_offs[None, :] * inp_s_strd)\n \n\n curr_out_grad = tl.load(curr_out_grad_ptr, mask=b_mask[:, None] & s_mask[None, :])\n curr_inp = tl.load(curr_inp_ptr, mask=b_mask[:, None] & s_mask[None, :])\n\n curr_norm_inp = (curr_inp - mean) * inv_std\n inv_std_contrib += tl.sum(curr_out_grad * curr_norm_inp)\n mean_contrib += tl.sum(curr_out_grad)\n\n weight = tl.load(weight_ptr + f_id)\n m = s_dim * b_dim\n inv_std_contrib *= weight / m\n mean_contrib *= weight / m\n\n if affine:\n weight_grad = tl.load(weight_grad_ptr + f_id)\n bias_grad = tl.load(bias_grad_ptr + f_id)\n weight = tl.load(weight_ptr + f_id)\n else:\n weight = 1.0\n\n for s_ind in range(0, tl.cdiv(s_dim, BLOCK_SIZE_SPATIAL)):\n s_offs = s_ind * BLOCK_SIZE_SPATIAL + tl.arange(0, BLOCK_SIZE_SPATIAL)\n s_mask = s_offs < s_dim\n curr_out_grad_ptr = (out_grad_ptr +\n f_id * out_grad_f_strd +\n b_offs[:, None] * out_grad_b_strd +\n s_offs[None, :] * out_grad_s_strd)\n\n curr_inp_ptr = (inp_ptr +\n f_id * inp_f_strd +\n b_offs[:, None] * inp_b_strd +\n s_offs[None, :] * inp_s_strd)\n\n curr_inp_grad_ptr = (inp_grad_ptr +\n f_id * inp_grad_f_strd +\n b_offs[:, None] * inp_grad_b_strd +\n s_offs[None, :] * inp_grad_s_strd)\n \n curr_inp = tl.load(curr_inp_ptr, mask=b_mask[:, None] & s_mask[None, :])\n curr_norm_inp = (curr_inp - mean) * inv_std\n\n curr_out_grad = tl.load(curr_out_grad_ptr, mask=b_mask[:, None] & s_mask[None, :])\n curr_inp_grad = inv_std * (weight * curr_norm_inp - (mean_contrib - (inv_std_contrib * curr_norm_inp)))\n tl.store(curr_inp_grad_ptr, curr_inp_grad, mask=b_mask[:, None] & s_mask[None, :])\n\n if affine:\n weight_grad += tl.sum(curr_out_grad * curr_norm_inp)\n bias_grad += tl.sum(curr_out_grad)\n\n if affine:\n tl.store(weight_grad_ptr, weight_grad)\n tl.store(bias_grad_ptr, bias_grad)\n", - "description_1": "Use triton language to implement three kernels for RMS normalization and batch normalization. The rms_norm_forward_kernel takes 9 arguments for handling input/output and other parameters for RMS normalization. The batch_norm_forward_kernel takes 31 arguments to execute a forward pass of batch normalization with optional affine transformation, residual connections, activation functions, and more. The batch_norm_backward_kernel takes 21 arguments to perform the backward pass, compute gradients, and update statistics, optionally using an affine transformation.", - "description_2": "Use triton language to implement kernels for RMS and batch normalization forward and backward operations.", - "difficulty": 4 - }, - { - "code": "import triton\nimport triton.language as tl\n\n@triton.jit\ndef apply_dropout(inp, drop_p, seed, offset):\n # dropout in neural network turns out scale the rest\n # of the tensor that's not dropped out by 1 / (1 - drop_p)\n # so the total value across this network is stays the same\n random = tl.rand(seed, offset)\n return tl.where(random < drop_p, 0, inp / (1 - drop_p))\n\n@triton.jit\ndef apply_dropout_grad(out_grad, drop_p, seed, offset):\n # grad dropout is out_grad * (1 / (1 - drop_p))\n # basically the same as forward pass, but now we use\n # out_grad instead of inp\n random = tl.rand(seed, offset)\n return tl.where(random < drop_p, 0.0, out_grad * (1 / (1 - drop_p)))\n\n@triton.jit\ndef dropout_forward_kernel(\n inp_ptr, out_ptr, size, drop_p, seed, BLOCK_SIZE: tl.constexpr\n):\n pid = tl.program_id(0)\n offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)\n mask = offset < size\n input_tensor = tl.load(inp_ptr + offset, mask=mask)\n output_tensor = apply_dropout(input_tensor, drop_p, seed, offset)\n tl.store(out_ptr + offset, output_tensor, mask=mask)\n\n@triton.jit\ndef dropout_backward_kernel(\n out_grad_ptr, \n inp_grad_ptr, \n size,\n drop_p,\n seed,\n BLOCK_SIZE: tl.constexpr\n):\n pid = tl.program_id(axis=0)\n offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)\n mask = offs < size\n\n out_grad = tl.load(out_grad_ptr + offs, mask=mask)\n inp_grad = apply_dropout_grad(out_grad, drop_p, seed, offs)\n\n tl.store(inp_grad_ptr + offs, inp_grad, mask=mask)\n", - "description_1": "Use triton language to implement dropout functionalities. The first kernel 'apply_dropout' takes four arguments: 'inp' (input tensor), 'drop_p' (dropout probability), 'seed' (random seed), and 'offset' (offset index), and returns the dropout result. The second kernel 'apply_dropout_grad' takes the same type of arguments but uses 'out_grad' for backward pass gradient calculation. The 'dropout_forward_kernel' wraps the forward pass of dropout with six arguments: pointers to input and output, 'size', 'drop_p', 'seed', and 'BLOCK_SIZE', utilizing 'apply_dropout' for its operation. The 'dropout_backward_kernel' encapsulates the backward pass logic, taking six similar arguments as the forward kernel but operates on gradient pointers, utilizing 'apply_dropout_grad'.", - "description_2": "Use triton language to implement both forward and backward dropout kernels with triton.jit decorator, involving the application of dropout during the forward pass and gradient adjustment during the backward pass.", - "difficulty": 2 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef _fwd_kernel(\n Q, K, V, Bias, Out, Lse, TMP, softmax_scale,\n stride_qb, stride_qh, stride_qm,\n stride_kb, stride_kh, stride_kn,\n stride_vb, stride_vh, stride_vn,\n stride_bb, stride_bh, stride_bm,\n stride_ob, stride_oh, stride_om,\n nheads, seqlen_q, seqlen_k, seqlen_q_rounded, headdim,\n CACHE_KEY_SEQLEN_Q, CACHE_KEY_SEQLEN_K,\n BIAS_TYPE: tl.constexpr, IS_CAUSAL: tl.constexpr,\n BLOCK_HEADDIM: tl.constexpr, EVEN_M: tl.constexpr,\n EVEN_N: tl.constexpr, EVEN_HEADDIM: tl.constexpr,\n BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,\n):\n start_m = tl.program_id(0)\n off_hb = tl.program_id(1)\n off_b = off_hb // nheads\n off_h = off_hb % nheads\n offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)\n offs_n = tl.arange(0, BLOCK_N)\n offs_d = tl.arange(0, BLOCK_HEADDIM)\n q_ptrs = (\n Q\n + off_b * stride_qb\n + off_h * stride_qh\n + (offs_m[:, None] * stride_qm + offs_d[None, :])\n )\n k_ptrs = (\n K\n + off_b * stride_kb\n + off_h * stride_kh\n + (offs_m[:, None] * stride_kn + offs_d[None, :])\n )\n v_ptrs = (\n V\n + off_b * stride_vb\n + off_h * stride_vh\n + (offs_m[:, None] * stride_vn + offs_d[None, :])\n )\n if BIAS_TYPE == \"vector\":\n b_ptrs = Bias + off_b * stride_bb + off_h * stride_bh + offs_n\n elif BIAS_TYPE == \"matrix\":\n b_ptrs = (\n Bias\n + off_b * stride_bb\n + off_h * stride_bh\n + (offs_m[:, None] * stride_bm + offs_n[None, :])\n )\n t_ptrs = TMP + off_hb * seqlen_q_rounded + offs_m\n lse_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float(\"inf\")\n m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float(\"inf\")\n acc_o = tl.zeros([BLOCK_M, BLOCK_HEADDIM], dtype=tl.float32)\n if EVEN_M & EVEN_N:\n if EVEN_HEADDIM:\n q = tl.load(q_ptrs)\n else:\n q = tl.load(\n q_ptrs,\n mask=offs_d[None, :] < headdim,\n other=0.0,\n )\n else:\n if EVEN_HEADDIM:\n q = tl.load(q_ptrs, mask=(offs_m[None, :] < seqlen_q), other=0.0)\n else:\n q = tl.load(\n q_ptrs,\n mask=(offs_m[None, :] < seqlen_q) & (offs_d[None, :] < headdim),\n other=0.0,\n )\n end_n = seqlen_k if not IS_CAUSAL else tl.minimum((start_m + 1) * BLOCK_M, seqlen_k)\n for start_n in range(0, end_n, BLOCK_N):\n start_n = tl.multiple_of(start_n, BLOCK_M)\n if EVEN_N & EVEN_M:\n if EVEN_HEADDIM:\n k = tl.load(k_ptrs + start_n * stride_kn)\n else:\n k = tl.load(\n k_ptrs + start_n * stride_kn,\n mask=(offs_d[None, :] < headdim),\n other=0.0,\n )\n else:\n if EVEN_HEADDIM:\n k = tl.load(\n k_ptrs + start_n * stride_kn,\n mask=(offs_n[None, :] < seqlen_k),\n other=0.0,\n )\n else:\n k = tl.load(\n k_ptrs + start_n * stride_kn,\n mask=(offs_n[None, :] < seqlen_k) & (offs_d[None, :] < headdim),\n other=0.0,\n )\n qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)\n qk += tl.dot(q, tl.trans(k))\n if not EVEN_N:\n qk += tl.where((start_n + offs_n[None, :]) < seqlen_k, 0.0, float(\"-inf\"))\n if IS_CAUSAL:\n qk += tl.where(\n (offs_m[:, None] >= (start_n + offs_n)[None, :]), 0.0, float(\"-inf\")\n )\n if BIAS_TYPE != \"none\":\n if BIAS_TYPE == \"vector\":\n if EVEN_N:\n bias = tl.load(b_ptrs + start_n).to(tl.float32)\n else:\n bias = tl.load(\n b_ptrs + start_n, mask=(offs_n < seqlen_k), other=0.0\n ).to(tl.float32)\n bias = bias[None, :]\n elif BIAS_TYPE == \"matrix\":\n if EVEN_M & EVEN_N:\n bias = tl.load(b_ptrs + start_n).to(tl.float32)\n else:\n bias = tl.load(\n b_ptrs + start_n,\n mask=(offs_m[:, None] < seqlen_q)\n & ((offs_n + start_n)[None, :] < seqlen_k),\n other=0.0,\n ).to(tl.float32)\n qk = qk * softmax_scale + bias\n m_ij = tl.maximum(tl.max(qk, 1), lse_i)\n p = tl.exp(qk - m_ij[:, None])\n else:\n m_ij = tl.maximum(tl.max(qk, 1) * softmax_scale, lse_i)\n p = tl.exp(qk * softmax_scale - m_ij[:, None])\n l_ij = tl.sum(p, 1)\n acc_o_scale = tl.exp(m_i - m_ij)\n tl.store(t_ptrs, acc_o_scale)\n acc_o_scale = tl.load(t_ptrs)\n acc_o = acc_o * acc_o_scale[:, None]\n if EVEN_N & EVEN_M:\n if EVEN_HEADDIM:\n v = tl.load(v_ptrs + start_n * stride_vn)\n else:\n v = tl.load(\n v_ptrs + start_n * stride_vn,\n mask=(offs_d[None, :] < headdim),\n other=0.0,\n )\n else:\n if EVEN_HEADDIM:\n v = tl.load(\n v_ptrs + start_n * stride_vn,\n mask=(start_n + offs_n)[:, None] < seqlen_k,\n )\n else:\n v = tl.load(\n v_ptrs + start_n * stride_vn,\n mask=((start_n + offs_n)[:, None] < seqlen_k)\n & (offs_d[None, :] < headdim),\n )\n p = p.to(v.dtype)\n acc_o += tl.dot(p, v)\n m_i = m_ij\n l_i_new = tl.exp(lse_i - m_ij) + l_ij\n lse_i = m_ij + tl.log(l_i_new)\n o_scale = tl.exp(m_i - lse_i)\n tl.store(t_ptrs, o_scale)\n o_scale = tl.load(t_ptrs)\n acc_o = acc_o * o_scale[:, None]\n start_m = tl.program_id(0)\n", - "description_1": "Use triton language to implement a forward kernel for a transformer model. The kernel takes 36 parameters: Q, K, V, Bias, Out, Lse, TMP, softmax_scale, stride_qb, stride_qh, stride_qm, stride_kb, stride_kh, stride_kn, stride_vb, stride_vh, stride_vn, stride_bb, stride_bh, stride_bm, stride_ob, stride_oh, stride_om, nheads, seqlen_q, seqlen_k, seqlen_q_rounded, headdim, CACHE_KEY_SEQLEN_Q, CACHE_KEY_SEQLEN_K, BIAS_TYPE, IS_CAUSAL, BLOCK_HEADDIM, EVEN_M, EVEN_N, EVEN_HEADDIM, BLOCK_M, BLOCK_N. The kernel performs matrix multiplications and applies softmax with optional bias and causal masking.", - "description_2": "Use triton language to create a transformer forward kernel with 36 parameters for matrix operations and softmax.", - "difficulty": 4 - }, - { - "code": "import triton\nimport triton.language as tl\nfrom .act_kernels import apply_act_func\nfrom .utils import allow_tf32, get_n_stages\n\ndef linear_forward_config(\n BLOCK_SIZE_BATCH: int,\n BLOCK_SIZE_IN_FEAT: int,\n BLOCK_SIZE_OUT_FEAT: int,\n GROUP_SIZE_BATCH: int = 8,\n n_warps: int = 4,\n n_stages: int = 2,\n) -> triton.Config:\n return triton.Config(\n {\n \"BLOCK_SIZE_BATCH\": BLOCK_SIZE_BATCH,\n \"BLOCK_SIZE_IN_FEAT\": BLOCK_SIZE_IN_FEAT,\n \"BLOCK_SIZE_OUT_FEAT\": BLOCK_SIZE_OUT_FEAT,\n \"GROUP_SIZE_BATCH\": GROUP_SIZE_BATCH,\n },\n num_warps=n_warps,\n num_stages=get_n_stages(n_stages),\n )\n\n\n@triton.autotune(\n configs=[\n linear_forward_config(32, 32, 32, n_warps=2, n_stages=2),\n linear_forward_config(64, 32, 32, n_warps=2, n_stages=5),\n linear_forward_config(64, 32, 128, n_warps=4, n_stages=4),\n linear_forward_config(64, 32, 256, n_warps=4, n_stages=4),\n linear_forward_config(128, 32, 32, n_warps=4, n_stages=4),\n linear_forward_config(128, 32, 64, n_warps=4, n_stages=4),\n linear_forward_config(128, 32, 128, n_warps=4, n_stages=4),\n linear_forward_config(128, 64, 256, n_warps=8, n_stages=3),\n ],\n key=[\"batch_dim\", \"in_feat_dim\", \"out_feat_dim\", \"fp16\"],\n)\n@triton.heuristics({\"tf32\": lambda _: allow_tf32()})\n@triton.jit\ndef linear_forward_kernel(\n input_pointer,\n weight_pointer,\n bias_pointer,\n pre_act_pointer,\n output_pointer,\n batch_dim,\n in_feat_dim,\n out_feat_dim,\n input_batch_stride,\n input_in_feat_stride,\n weight_in_feat_stride,\n weight_out_feat_stride,\n pre_act_batch_stride,\n pre_act_out_feat_stride,\n output_batch_stride,\n output_out_feat_stride,\n add_bias: tl.constexpr,\n act_func: tl.constexpr,\n save_pre_act: tl.constexpr,\n fp16: tl.constexpr,\n tf32: tl.constexpr,\n BLOCK_SIZE_BATCH: tl.constexpr,\n BLOCK_SIZE_IN_FEAT: tl.constexpr,\n BLOCK_SIZE_OUT_FEAT: tl.constexpr,\n GROUP_SIZE_BATCH: tl.constexpr,\n):\n # Programs are blocked together, GROUP_SIZE_BATCH at at time\n # to alleviate L2 Miss rates\n pid = tl.program_id(axis=0)\n\n n_batch_pids = tl.cdiv(batch_dim, BLOCK_SIZE_BATCH)\n n_out_feat_pids = tl.cdiv(out_feat_dim, BLOCK_SIZE_OUT_FEAT)\n\n # now create grouping\n pids_per_group = GROUP_SIZE_BATCH * n_out_feat_pids\n group_id = pid // pids_per_group\n first_batch_pids = group_id * GROUP_SIZE_BATCH\n GROUP_SIZE_BATCH = min(GROUP_SIZE_BATCH, n_out_feat_pids - first_batch_pids)\n batch_pid = first_batch_pids + (pid % GROUP_SIZE_BATCH)\n out_feat_pid = (pid % pids_per_group) // GROUP_SIZE_BATCH\n\n batch_offset = batch_pid * BLOCK_SIZE_BATCH + tl.arange(0, BLOCK_SIZE_BATCH)\n out_feat_offset = out_feat_pid * BLOCK_SIZE_OUT_FEAT + tl.arange(\n 0, BLOCK_SIZE_OUT_FEAT\n )\n\n batch_mask = batch_offset < batch_dim\n out_feat_mask = out_feat_offset < out_feat_dim\n\n input_pointer += input_batch_stride * batch_offset[:, None]\n weight_pointer += weight_out_feat_stride * out_feat_offset[:, None]\n\n accum = tl.zeros((BLOCK_SIZE_BATCH, BLOCK_SIZE_OUT_FEAT), dtype=tl.float32)\n\n for block_ind in range(0, tl.cdiv(in_feat_dim, BLOCK_SIZE_IN_FEAT)):\n in_feat_offset = block_ind * BLOCK_SIZE_IN_FEAT + tl.arange(\n 0, BLOCK_SIZE_IN_FEAT\n )\n in_feat_mask = in_feat_offset < in_feat_dim\n\n curr_input_pointer = (\n input_pointer + input_in_feat_stride * in_feat_offset[:, None]\n )\n curr_weight_pointer = (\n weight_pointer + weight_in_feat_stride * in_feat_offset[:, None]\n )\n\n input_block = tl.load(\n curr_input_pointer, mask=batch_mask[:, None] & in_feat_mask[None, :]\n )\n weight_block = tl.load(\n curr_weight_pointer, mask=out_feat_mask[None, :] & in_feat_mask[None, :]\n )\n\n if fp16:\n input_block = input_block.to(tl.float16)\n weight_block = weight_block.to(tl.float16)\n\n accum += tl.dot(input_block, weight_block, allow_tf32=tf32)\n\n if add_bias:\n bias = tl.load(bias_pointer + out_feat_offset, mask=out_feat_mask)\n\n if fp16:\n bias = bias.to(tl.float16)\n\n accum += bias[None, :]\n\n if act_func is not None:\n if save_pre_act:\n pre_act_pointer += (\n pre_act_batch_stride * batch_offset[:, None]\n + pre_act_out_feat_stride * out_feat_offset[None, :]\n )\n tl.store(\n pre_act_pointer,\n accum,\n mask=batch_mask[:, None] & out_feat_mask[None, :],\n )\n\n accum = apply_act_func(accum, None, None, None, act_func, False)\n\n output_pointer += (\n output_batch_stride * batch_offset[:, None]\n + output_out_feat_stride * out_feat_offset[None, :]\n )\n tl.store(output_pointer, accum, mask=batch_mask[:, None] & out_feat_mask[None, :])\n", - "description_1": "Use triton language to implement a linear forward kernel function with 27 parameters that transforms input data using weights, optionally adds biases, and can apply an activation function. The function uses parameters like input and weight pointers, dimension sizes, strides, and compile-time constants to efficiently compute the result.", - "description_2": "Use triton language to create a linear transformation operator with optional bias addition and activation application for matrix operations.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef _mm_naive(\n A, B, C, stride_AX, stride_AY, stride_BX, stride_BY, stride_CX, stride_CY, N\n):\n row, col = tl.program_id(0), tl.program_id(1)\n\n sum_ = 0.0\n for k in range(N):\n a = tl.load(A + row * stride_AX + k)\n b = tl.load(B + k * stride_BX + col)\n sum_ += a * b\n c = tl.load(C + row * stride_CX + col)\n c += sum_\n tl.store(C + row * stride_CX + col, c)\n\n\ndef mm_naive_triton(A: torch.FloatTensor, B: torch.FloatTensor):\n assert (\n A.shape[0] == A.shape[1] == B.shape[0] == B.shape[1]\n ), \"Shape must be the same for all matrix\"\n assert A.is_cuda and B.is_cuda\n N = A.shape[1]\n C = torch.zeros_like(A)\n _mm_naive[(N, N)](A, B, C, *A.stride(), *B.stride(), *C.stride(), A.shape[0])\n return C\n", - "description_1": "Use triton language to implement a naive matrix multiplication kernel. The kernel '_mm_naive' takes 10 parameters: A, B, C (pointers to matrices), stride_AX, stride_AY, stride_BX, stride_BY, stride_CX, stride_CY (stride information for matrices), and N (the size of the matrices). It computes the matrix product of A and B and stores the result in C. The function 'mm_naive_triton' is a wrapper that prepares the input matrices and calls the kernel with appropriate grid size.", - "description_2": "Use triton language to perform matrix multiplication on CUDA tensors by implementing a kernel that computes the product of two square matrices and stores the result in a third matrix.", - "difficulty": 2 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef _reduce_sum_naive(A, B, stride_AX, N: tl.constexpr):\n row = tl.program_id(0)\n\n sum_ = 0.0\n for k in range(N):\n offsets = row * stride_AX + k\n mask = offsets < N\n a = tl.load(A + offsets, mask=mask)\n sum_ += a\n tl.store(B, sum_)\n\ndef reduce_sum_naive(A: torch.FloatTensor):\n assert A.is_cuda\n N = A.shape[0]\n B = torch.zeros(1, device=\"cuda\")\n _reduce_sum_naive[(N,)](A, B, *A.stride(), N)\n return B\n\n@triton.jit\ndef _reduce_sum(A, B, stride_AX, N: tl.constexpr, BLOCK_SIZE: tl.constexpr):\n pid = tl.program_id(0)\n block_start = pid * BLOCK_SIZE\n thread_idx = block_start + tl.arange(0, BLOCK_SIZE)\n\n i = thread_idx * 2\n\n stride = 1\n while stride < BLOCK_SIZE:\n thread_offset = i + stride\n a = tl.load(A + i, mask=(thread_offset < N) and (thread_idx % stride == 0))\n b = tl.load(\n A + thread_offset, mask=(thread_offset < N) and (thread_idx % stride == 0)\n )\n c = a + b\n tl.store(A + i, c, mask=thread_offset < N)\n\n tl.debug_barrier()\n\n stride *= 2\n\n tl.store(B + thread_idx, tl.load(A + thread_idx))\n\ndef reduce_sum(A: torch.FloatTensor):\n assert A.is_cuda\n N = A.shape[0]\n B = torch.zeros(1, device=\"cuda\")\n _reduce_sum[(1,)](A, B, *A.stride(), N, BLOCK_SIZE=32)\n return B\n\nprint(reduce_sum_naive(torch.arange(10, device=\"cuda\")))\nprint(reduce_sum(torch.arange(11, device=\"cuda\")))\n", - "description_1": "Use triton language to implement two kernels for reducing a tensor to a sum. The first kernel, _reduce_sum_naive, takes 4 parameters: A (input tensor), B (output tensor), stride_AX (stride of A), and N (number of elements). It computes the sum of elements in A and stores the result in B. The second kernel, _reduce_sum, takes 5 parameters: A (input tensor), B (output tensor), stride_AX (stride of A), N (number of elements), and BLOCK_SIZE (size of the block). It performs a parallel reduction on A and stores the result in B. Both kernels are called by their respective wrapper functions, reduce_sum_naive and reduce_sum, which prepare the input tensors and launch the kernels.", - "description_2": "Use triton language to implement two reduction kernels: one for naive summation and another for parallel reduction, each with their respective wrapper functions.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef _vector_addition(A, B, C, N: tl.constexpr, BLOCK_SIZE: tl.constexpr):\n # Triton kernel for vector addition\n row = tl.program_id(axis=0)\n block_start = row * BLOCK_SIZE\n offsets = block_start + tl.arange(0, BLOCK_SIZE)\n mask = offsets < N\n a = tl.load(A + offsets, mask=mask)\n b = tl.load(B + offsets, mask=mask)\n c = a + b\n tl.store(C + offsets, c, mask=mask)\n\ndef vector_addition(A: torch.FloatTensor, B: torch.FloatTensor) -> torch.FloatTensor:\n # Function to call the Triton kernel for vector addition\n assert A.is_cuda and B.is_cuda\n N = A.shape[0]\n assert N == B.shape[0]\n C = torch.zeros_like(A)\n\n block_size = 128\n grid_size = triton.cdiv(N, block_size)\n grid = (grid_size,)\n\n _vector_addition[grid](\n A,\n B,\n C,\n N,\n block_size,\n )\n return C\n", - "description_1": "Use triton language to implement a vector addition kernel. The kernel '_vector_addition' takes 5 parameters: A (input tensor), B (input tensor), C (output tensor), N (size of the vectors, a constant expression), and BLOCK_SIZE (size of each block, a constant expression). The function 'vector_addition' calls this kernel with 2 parameters: A (input tensor) and B (input tensor), and returns the result of the addition in a new tensor C.", - "description_2": "Use triton language to perform element-wise addition of two vectors on the GPU using a custom kernel.", - "difficulty": 2 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef mean_kernel_batch_major(\n input_ptr, output_ptr, batch_size, spatial_size, BLOCK_SIZE: tl.constexpr\n):\n pid = tl.program_id(0)\n\n # Initialize accumulator\n acc = 0.0\n count = 0\n\n # Iterate over the batch dimension\n for i in range(0, batch_size, BLOCK_SIZE):\n batch_offset = i + tl.arange(0, BLOCK_SIZE)\n batch_mask = batch_offset < batch_size\n\n # Load and accumulate\n x = tl.load(\n input_ptr + pid * spatial_size * batch_size + batch_offset * spatial_size,\n mask=batch_mask,\n )\n acc += tl.sum(x * batch_mask, axis=0)\n count += tl.sum(batch_mask, axis=0)\n\n # Compute and store mean\n mean = acc / count\n tl.store(output_ptr + pid, mean)\n\ndef mean_triton(x, layout=\"batch_major\"):\n output = torch.empty(\n x.shape[1] if layout == \"batch_major\" else x.shape[0],\n device=x.device,\n dtype=x.dtype,\n )\n\n if layout == \"batch_major\":\n mean_kernel_batch_major[(x.shape[1],)](\n x, output, x.shape[0], x.shape[1], BLOCK_SIZE=32\n )\n else:\n mean_kernel_spatial_major[(x.shape[0],)](\n x, output, x.shape[0], x.shape[1], BLOCK_SIZE=32\n )\n\n return output\n\n@triton.jit\ndef mean_kernel_batch_major_2(\n inp_ptr,\n out_ptr,\n inp_b_strd,\n inp_s_strd,\n out_b_strd,\n s_dim,\n BLOCK_SIZE: tl.constexpr,\n):\n b_pid = tl.program_id(0)\n\n count = 0\n mean = 0.0\n\n for block_ind in range(0, tl.cdiv(s_dim, BLOCK_SIZE)):\n s_offs = block_ind * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)\n s_mask = s_offs < s_dim\n\n curr_inp_ptr = inp_ptr + b_pid * inp_b_strd + s_offs * inp_s_strd\n curr_inp = tl.load(curr_inp_ptr, mask=s_mask, other=0.0)\n\n s_count = min(BLOCK_SIZE, s_dim - BLOCK_SIZE * block_ind)\n count += s_count\n\n prev_mean = mean\n\n mean += (tl.sum(curr_inp) - (s_count * prev_mean)) / count\n\n tl.store(\n out_ptr + b_pid * out_b_strd,\n mean,\n )\n\ndef mean_kernel(x):\n out = torch.empty(x.shape[0], device=x.device, dtype=x.dtype)\n mean_kernel_batch_major_2[(x.shape[0],)](\n x, out, *x.stride(), *out.stride(), x.shape[1], BLOCK_SIZE=32\n )\n return out\n\n# Test the kernels\nbatch_size, spatial_size = 1024, 256\nx_batch_major = torch.randn(batch_size, spatial_size, device=\"cuda\")\nx_spatial_major = x_batch_major.t().contiguous()\n\ntriton_mean = mean_kernel(x_spatial_major)\nprint(f\"{triton_mean = }\")\n", - "description_1": "Use triton language to implement a kernel function 'mean_kernel_batch_major' that computes the mean of a batch-major input tensor. The function takes 5 parameters: input_ptr (pointer to input tensor), output_ptr (pointer to output tensor), batch_size (size of the batch dimension), spatial_size (size of the spatial dimension), and BLOCK_SIZE (block size for parallel processing). Another kernel function 'mean_kernel_batch_major_2' computes the mean for a batch-major input tensor with 7 parameters: inp_ptr (pointer to input tensor), out_ptr (pointer to output tensor), inp_b_strd (input batch stride), inp_s_strd (input spatial stride), out_b_strd (output batch stride), s_dim (spatial dimension size), and BLOCK_SIZE (block size for parallel processing). The 'mean_triton' function calls these kernels based on the layout of the input tensor.", - "description_2": "Use triton language to implement kernel functions for computing the mean of batch-major input tensors with parameters for input/output pointers, strides, dimensions, and block size.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\nimport triton_util as tu\n\n# Triton kernel to load a 2D block of data, perform an operation, and store the result\n@triton.jit\ndef load_2d_kernel(\n x_ptr, y_ptr, x_size_0, x_size_1, x_stride_0, x_stride_1, BLOCK_SIZE: tl.constexpr\n):\n # Get the program IDs for the 2D grid\n pid_0 = tl.program_id(0)\n pid_1 = tl.program_id(1)\n\n # Load a 2D block of data from x_ptr\n x = tu.load_2d(\n x_ptr, BLOCK_SIZE, BLOCK_SIZE, pid_0, pid_1, x_size_0, x_size_1, x_stride_0\n )\n\n # Perform an operation on the loaded data\n x += pid_0 * pid_1\n\n # Calculate the offsets for storing the result\n y_offsets = tu.get_2d_offset(\n tu.get_1d_offset(BLOCK_SIZE, pid_0),\n tu.get_1d_offset(BLOCK_SIZE, pid_1),\n x_stride_0,\n x_stride_1,\n )\n\n # Store the result in y_ptr\n tl.store(y_ptr + y_offsets, x)\n\n# Function to initialize data and launch the Triton kernel\ndef load_2d():\n # Create a 16x16 tensor on the CUDA device\n a = torch.zeros(16, 16, device=\"cuda\", dtype=torch.float16)\n x_size_0, x_size_1 = a.size()\n x_stride_0, x_stride_1 = a.stride()\n b = torch.empty_like(a)\n\n # Launch the Triton kernel with a grid size determined by the input tensor dimensions\n load_2d_kernel[(tu.cdiv(x_size_0, 4), tu.cdiv(x_size_1, 4))](\n a,\n b,\n x_size_0,\n x_size_1,\n x_stride_0,\n x_stride_1,\n 4, # type: ignore\n )\n\n # Print the result\n print(b)\n", - "description_1": "Use triton language to define a kernel 'load_2d_kernel' that loads a 2D block of data from a pointer 'x_ptr', performs an operation by adding the product of program IDs, and stores the result in 'y_ptr'. The kernel takes 7 parameters: two pointers (x_ptr, y_ptr), two sizes (x_size_0, x_size_1), two strides (x_stride_0, x_stride_1), and a block size (BLOCK_SIZE). The function 'load_2d' initializes a 16x16 tensor, calculates its size and stride, and launches the kernel with a grid size based on the tensor dimensions.", - "description_2": "Use triton language to create a kernel that processes a 2D block of data by loading, modifying, and storing it, and a function to initialize data and execute the kernel.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef triton_cross_scan(\n x, # (B, C, H, W)\n y, # (B, 4, C, H, W)\n BC: tl.constexpr,\n BH: tl.constexpr,\n BW: tl.constexpr,\n DC: tl.constexpr,\n DH: tl.constexpr,\n DW: tl.constexpr,\n NH: tl.constexpr,\n NW: tl.constexpr,\n):\n i_hw, i_c, i_b = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n i_h, i_w = (i_hw // NW), (i_hw % NW)\n _mask_h = (i_h * BH + tl.arange(0, BH)) < DH\n _mask_w = (i_w * BW + tl.arange(0, BW)) < DW\n _mask_hw = _mask_h[:, None] & _mask_w[None, :]\n _for_C = min(DC - i_c * BC, BC)\n\n _tmp0 = i_c * BC * DH * DW\n _tmp1 = DC * DH * DW\n _tmp2 = _tmp0 + i_h * BH * DW + tl.arange(0, BH)[:, None] * DW + i_w * BW + tl.arange(0, BW)[None, :]\n p_x = x + i_b * _tmp1 + _tmp2\n p_y1 = y + i_b * 4 * _tmp1 + _tmp2 # same\n p_y2 = y + i_b * 4 * _tmp1 + _tmp1 + _tmp0 + i_w * BW * DH + tl.arange(0, BW)[None, :] * DH + i_h * BH + tl.arange(0, BH)[:, None] # trans\n p_y3 = y + i_b * 4 * _tmp1 + 2 * _tmp1 + _tmp0 + (NH - i_h - 1) * BH * DW + (BH - 1 - tl.arange(0, BH)[:, None]) * DW + (NW - i_w - 1) * BW + (BW - 1 - tl.arange(0, BW)[None, :]) + (DH - NH * BH) * DW + (DW - NW * BW) # flip\n p_y4 = y + i_b * 4 * _tmp1 + 3 * _tmp1 + _tmp0 + (NW - i_w - 1) * BW * DH + (BW - 1 - tl.arange(0, BW)[None, :]) * DH + (NH - i_h - 1) * BH + (BH - 1 - tl.arange(0, BH)[:, None]) + (DH - NH * BH) + (DW - NW * BW) * DH # trans + flip\n\n for idxc in range(_for_C):\n _idx = idxc * DH * DW\n _x = tl.load(p_x + _idx, mask=_mask_hw)\n tl.store(p_y1 + _idx, _x, mask=_mask_hw)\n tl.store(p_y2 + _idx, _x, mask=_mask_hw)\n tl.store(p_y3 + _idx, _x, mask=_mask_hw)\n tl.store(p_y4 + _idx, _x, mask=_mask_hw)\n\n\n@triton.jit\ndef triton_cross_merge(\n x, # (B, C, H, W)\n y, # (B, 4, C, H, W)\n BC: tl.constexpr,\n BH: tl.constexpr,\n BW: tl.constexpr,\n DC: tl.constexpr,\n DH: tl.constexpr,\n DW: tl.constexpr,\n NH: tl.constexpr,\n NW: tl.constexpr,\n):\n i_hw, i_c, i_b = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n i_h, i_w = (i_hw // NW), (i_hw % NW)\n _mask_h = (i_h * BH + tl.arange(0, BH)) < DH\n _mask_w = (i_w * BW + tl.arange(0, BW)) < DW\n _mask_hw = _mask_h[:, None] & _mask_w[None, :]\n _for_C = min(DC - i_c * BC, BC)\n\n _tmp0 = i_c * BC * DH * DW\n _tmp1 = DC * DH * DW\n _tmp2 = _tmp0 + i_h * BH * DW + tl.arange(0, BH)[:, None] * DW + i_w * BW + tl.arange(0, BW)[None, :]\n p_x = x + i_b * _tmp1 + _tmp2\n p_y1 = y + i_b * 4 * _tmp1 + _tmp2 # same\n p_y2 = y + i_b * 4 * _tmp1 + _tmp1 + _tmp0 + i_w * BW * DH + tl.arange(0, BW)[None, :] * DH + i_h * BH + tl.arange(0, BH)[:, None] # trans\n p_y3 = y + i_b * 4 * _tmp1 + 2 * _tmp1 + _tmp0 + (NH - i_h - 1) * BH * DW + (BH - 1 - tl.arange(0, BH)[:, None]) * DW + (NW - i_w - 1) * BW + (BW - 1 - tl.arange(0, BW)[None, :]) + (DH - NH * BH) * DW + (DW - NW * BW) # flip\n p_y4 = y + i_b * 4 * _tmp1 + 3 * _tmp1 + _tmp0 + (NW - i_w - 1) * BW * DH + (BW - 1 - tl.arange(0, BW)[None, :]) * DH + (NH - i_h - 1) * BH + (BH - 1 - tl.arange(0, BH)[:, None]) + (DH - NH * BH) + (DW - NW * BW) * DH # trans + flip\n\n for idxc in range(_for_C):\n _idx = idxc * DH * DW\n _y1 = tl.load(p_y1 + _idx, mask=_mask_hw)\n _y2 = tl.load(p_y2 + _idx, mask=_mask_hw)\n _y3 = tl.load(p_y3 + _idx, mask=_mask_hw)\n _y4 = tl.load(p_y4 + _idx, mask=_mask_hw)\n tl.store(p_x + _idx, _y1 + _y2 + _y3 + _y4, mask=_mask_hw)\n\n\n@triton.jit\ndef triton_cross_scan_1b1(\n x, # (B, C, H, W)\n y, # (B, 4, C, H, W)\n BC: tl.constexpr,\n BH: tl.constexpr,\n BW: tl.constexpr,\n DC: tl.constexpr,\n DH: tl.constexpr,\n DW: tl.constexpr,\n NH: tl.constexpr,\n NW: tl.constexpr,\n):\n i_hw, i_c, i_b = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n i_h, i_w = (i_hw // NW), (i_hw % NW)\n _mask_h = (i_h * BH + tl.arange(0, BH)) < DH\n _mask_w = (i_w * BW + tl.arange(0, BW)) < DW\n _mask_hw = _mask_h[:, None] & _mask_w[None, :]\n _for_C = min(DC - i_c * BC, BC)\n\n _tmp0 = i_c * BC * DH * DW\n _tmp1 = DC * DH * DW\n _tmp2 = _tmp0 + i_h * BH * DW + tl.arange(0, BH)[:, None] * DW + i_w * BW + tl.arange(0, BW)[None, :]\n p_y1 = y + i_b * 4 * _tmp1 + _tmp2 # same\n p_y2 = y + i_b * 4 * _tmp1 + _tmp1 + _tmp0 + i_w * BW * DH + tl.arange(0, BW)[None, :] * DH + i_h * BH + tl.arange(0, BH)[:, None] # trans\n p_y3 = y + i_b * 4 * _tmp1 + 2 * _tmp1 + _tmp0 + (NH - i_h - 1) * BH * DW + (BH - 1 - tl.arange(0, BH)[:, None]) * DW + (NW - i_w - 1) * BW + (BW - 1 - tl.arange(0, BW)[None, :]) + (DH - NH * BH) * DW + (DW - NW * BW) # flip\n p_y4 = y + i_b * 4 * _tmp1 + 3 * _tmp1 + _tmp0 + (NW - i_w - 1) * BW * DH + (BW - 1 - tl.arange(0, BW)[None, :]) * DH + (NH - i_h - 1) * BH + (BH - 1 - tl.arange(0, BH)[:, None]) + (DH - NH * BH) + (DW - NW * BW) * DH # trans + flip\n \n p_x1 = x + i_b * 4 * _tmp1 + _tmp2\n p_x2 = p_x1 + _tmp1\n p_x3 = p_x2 + _tmp1\n p_x4 = p_x3 + _tmp1\n for idxc in range(_for_C):\n _idx = idxc * DH * DW\n tl.store(p_y1 + _idx, tl.load(p_x1 + _idx), mask=_mask_hw)\n tl.store(p_y2 + _idx, tl.load(p_x2 + _idx), mask=_mask_hw)\n tl.store(p_y3 + _idx, tl.load(p_x3 + _idx), mask=_mask_hw)\n tl.store(p_y4 + _idx, tl.load(p_x4 + _idx), mask=_mask_hw)\n\n\n@triton.jit\ndef triton_cross_merge_1b1(\n x, # (B, C, H, W)\n y, # (B, 4, C, H, W)\n BC: tl.constexpr,\n BH: tl.constexpr,\n BW: tl.constexpr,\n DC: tl.constexpr,\n DH: tl.constexpr,\n DW: tl.constexpr,\n NH: tl.constexpr,\n NW: tl.constexpr,\n):\n i_hw, i_c, i_b = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n i_h, i_w = (i_hw // NW), (i_hw % NW)\n _mask_h = (i_h * BH + tl.arange(0, BH)) < DH\n _mask_w = (i_w * BW + tl.arange(0, BW)) < DW\n _mask_hw = _mask_h[:, None] & _mask_w[None, :]\n _for_C = min(DC - i_c * BC, BC)\n\n _tmp0 = i_c * BC * DH * DW\n _tmp1 = DC * DH * DW\n _tmp2 = _tmp0 + i_h * BH * DW + tl.arange(0, BH)[:, None] * DW + i_w * BW + tl.arange(0, BW)[None, :]\n p_y1 = y + i_b * 4 * _tmp1 + _tmp2 # same\n p_y2 = y + i_b * 4 * _tmp1 + _tmp1 + _tmp0 + i_w * BW * DH + tl.arange(0, BW)[None, :] * DH + i_h * BH + tl.arange(0, BH)[:, None] # trans\n p_y3 = y + i_b * 4 * _tmp1 + 2 * _tmp1 + _tmp0 + (NH - i_h - 1) * BH * DW + (BH - 1 - tl.arange(0, BH)[:, None]) * DW + (NW - i_w - 1) * BW + (BW - 1 - tl.arange(0, BW)[None, :]) + (DH - NH * BH) * DW + (DW - NW * BW) # flip\n p_y4 = y + i_b * 4 * _tmp1 + 3 * _tmp1 + _tmp0 + (NW - i_w - 1) * BW * DH + (BW - 1 - tl.arange(0, BW)[None, :]) * DH + (NH - i_h - 1) * BH + (BH - 1 - tl.arange(0, BH)[:, None]) + (DH - NH * BH) + (DW - NW * BW) * DH # trans + flip\n\n p_x1 = x + i_b * 4 * _tmp1 + _tmp2\n p_x2 = p_x1 + _tmp1\n p_x3 = p_x2 + _tmp1\n p_x4 = p_x3 + _tmp1\n for idxc in range(_for_C):\n _idx = idxc * DH * DW\n tl.store(p_x1 + _idx, tl.load(p_y1 + _idx), mask=_mask_hw)\n tl.store(p_x2 + _idx, tl.load(p_y2 + _idx), mask=_mask_hw)\n tl.store(p_x3 + _idx, tl.load(p_y3 + _idx), mask=_mask_hw)\n tl.store(p_x4 + _idx, tl.load(p_y4 + _idx), mask=_mask_hw)\n\n\nclass CrossScanTriton(torch.autograd.Function):\n @staticmethod\n def forward(ctx, x: torch.Tensor):\n B, C, H, W = x.shape\n B, C, H, W = int(B), int(C), int(H), int(W)\n BC, BH, BW = min(triton.next_power_of_2(C), 2), min(triton.next_power_of_2(H), 32), min(triton.next_power_of_2(W), 32)\n NH, NW, NC = triton.cdiv(H, BH), triton.cdiv(W, BW), triton.cdiv(C, BC)\n ctx.shape = (B, C, H, W)\n ctx.triton_shape = (BC, BH, BW, NC, NH, NW)\n x = x.contiguous()\n y = x.new_empty((B, 4, C, H, W))\n triton_cross_scan[(NH * NW, NC, B)](x, y, BC, BH, BW, C, H, W, NH, NW)\n return y.view(B, 4, C, -1)\n \n @staticmethod\n def backward(ctx, y: torch.Tensor):\n # out: (b, k, d, l)\n B, C, H, W = ctx.shape\n BC, BH, BW, NC, NH, NW = ctx.triton_shape\n y = y.contiguous().view(B, 4, C, H, W)\n x = y.new_empty((B, C, H, W))\n triton_cross_merge[(NH * NW, NC, B)](x, y, BC, BH, BW, C, H, W, NH, NW)\n return x\n\n\nclass CrossMergeTriton(torch.autograd.Function):\n @staticmethod\n def forward(ctx, y: torch.Tensor):\n B, K, C, H, W = y.shape\n B, C, H, W = int(B), int(C), int(H), int(W)\n BC, BH, BW = min(triton.next_power_of_2(C), 2), min(triton.next_power_of_2(H), 32), min(triton.next_power_of_2(W), 32)\n NH, NW, NC = triton.cdiv(H, BH), triton.cdiv(W, BW), triton.cdiv(C, BC)\n ctx.shape = (B, C, H, W)\n ctx.triton_shape = (BC, BH, BW, NC, NH, NW)\n y = y.contiguous().view(B, 4, C, H, W)\n x = y.new_empty((B, C, H, W))\n triton_cross_merge[(NH * NW, NC, B)](x, y, BC, BH, BW, C, H, W, NH, NW)\n return x.view(B, C, -1)\n \n @staticmethod\n def backward(ctx, x: torch.Tensor):\n # out: (b, d, l)\n B, C, H, W = ctx.shape\n BC, BH, BW, NC, NH, NW = ctx.triton_shape\n x = x.contiguous()\n y = x.new_empty((B, 4, C, H, W))\n triton_cross_scan[(NH * NW, NC, B)](x, y, BC, BH, BW, C, H, W, NH, NW)\n return y\n\n\nclass CrossScanTriton1b1(torch.autograd.Function):\n @staticmethod\n def forward(ctx, x: torch.Tensor):\n B, K, C, H, W = x.shape\n B, C, H, W = int(B), int(C), int(H), int(W)\n BC, BH, BW = min(triton.next_power_of_2(C), 2), min(triton.next_power_of_2(H), 32), min(triton.next_power_of_2(W), 32)\n NH, NW, NC = triton.cdiv(H, BH), triton.cdiv(W, BW), triton.cdiv(C, BC)\n ctx.shape = (B, C, H, W)\n ctx.triton_shape = (BC, BH, BW, NC, NH, NW)\n x = x.contiguous()\n y = x.new_empty((B, 4, C, H, W))\n triton_cross_scan_1b1[(NH * NW, NC, B)](x, y, BC, BH, BW, C, H, W, NH, NW)\n return y.view(B, 4, C, -1)\n \n @staticmethod\n def backward(ctx, y: torch.Tensor):\n # out: (b, k, d, l)\n B, C, H, W = ctx.shape\n BC, BH, BW, NC, NH, NW = ctx.triton_shape\n y = y.contiguous().view(B, 4, C, H, W)\n x = y.new_empty((B, 4, C, H, W))\n triton_cross_merge_1b1[(NH * NW, NC, B)](x, y, BC, BH, BW, C, H, W, NH, NW)\n return x\n", - "description_1": "Use triton language to define kernels for cross-scan and cross-merge operations. The kernels operate on 4D tensors (B, C, H, W) and a tensor y with shape (B, 4, C, H, W). The function parameters include the number of blocks for each dimension (BC, BH, BW), and the dimensions of the data (DC, DH, DW). Constants NH and NW are the number of horizontal and vertical blocks, respectively.", - "description_2": "Use triton language to efficiently perform cross-scan and cross-merge on 4D tensors using block-level parallelism.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef max_fn(x, y):\n return tl.math.max(x, y)\n\n@triton.jit\ndef _rescale_kernel(\n peer_m,\n m,\n peer_l,\n l,\n peer_o,\n o,\n L,\n stride_oz, stride_oh, stride_om, stride_on,\n Z, H, N_CTX,\n BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,\n BLOCK_N: tl.constexpr,\n LAST_STEP: tl.constexpr,\n):\n start_m = tl.program_id(0)\n off_hz = tl.program_id(1)\n o_offset = off_hz * stride_oh\n peer_o_block_ptr = tl.make_block_ptr(\n base=peer_o + o_offset,\n shape=(N_CTX, BLOCK_DMODEL),\n strides=(stride_om, stride_on),\n offsets=(start_m * BLOCK_M, 0),\n block_shape=(BLOCK_M, BLOCK_DMODEL),\n order=(1, 0)\n )\n o_block_ptr = tl.make_block_ptr(\n base=o + o_offset,\n shape=(N_CTX, BLOCK_DMODEL),\n strides=(stride_om, stride_on),\n offsets=(start_m * BLOCK_M, 0),\n block_shape=(BLOCK_M, BLOCK_DMODEL),\n order=(1, 0)\n )\n # initialize offsets\n offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)\n offs_n = tl.arange(0, BLOCK_N)\n\n peer_m_ptrs = peer_m + off_hz * N_CTX + offs_m\n m_ptrs = m + off_hz * N_CTX + offs_m\n peer_l_ptrs = peer_l + off_hz * N_CTX + offs_m\n l_ptrs = l + off_hz * N_CTX + offs_m\n \n peer_m_i = tl.load(peer_m_ptrs) \n peer_m_i = peer_m_i.to(tl.float32)\n m_i = tl.load(m_ptrs) \n m_i = m_i.to(tl.float32)\n peer_l_i = tl.load(peer_l_ptrs) \n peer_l_i = peer_l_i.to(tl.float32)\n l_i = tl.load(l_ptrs) \n l_i = l_i.to(tl.float32)\n\n peer_acc = tl.load(peer_o_block_ptr)\n peer_acc = peer_acc.to(tl.float32)\n acc = tl.load(o_block_ptr) \n acc = acc.to(tl.float32)\n lo = 0\n hi = N_CTX\n m_i_sync = tl.maximum(m_i, peer_m_i)\n alpha = tl.math.exp2(m_i - m_i_sync)\n peer_alpha = tl.math.exp2(peer_m_i - m_i_sync)\n # -- scale and update acc --\n acc_scale = l_i * 0 + alpha # workaround some compiler bug\n peer_acc_scale = peer_l_i * 0 + peer_alpha # workaround some compiler bug\n \n acc *= acc_scale[:, None]\n peer_acc *= peer_acc_scale[:, None]\n acc += peer_acc\n l_i = l_i * acc_scale + peer_l_i * peer_acc_scale\n # write back O, l, m\n tl.store(m_ptrs, m_i_sync)\n tl.store(l_ptrs, l_i)\n if LAST_STEP:\n acc = acc / l_i[:, None]\n L_ptrs = L + off_hz * N_CTX + offs_m\n tl.store(L_ptrs, m_i_sync / 1.44269504 + tl.math.log(l_i))\n tl.store(o_block_ptr, acc.to(tl.bfloat16))\n\n@triton.jit\ndef _fwd_kernel(\n Q, K, V, sm_scale,\n m,\n l,\n O,\n L,\n stride_qz, stride_qh, stride_qm, stride_qk,\n stride_kz, stride_kh, stride_kn, stride_kk,\n stride_vz, stride_vh, stride_vk, stride_vn,\n stride_oz, stride_oh, stride_om, stride_on,\n Z, H, N_CTX,\n BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,\n BLOCK_N: tl.constexpr,\n IS_CAUSAL: tl.constexpr,\n LAST_STEP: tl.constexpr\n):\n start_m = tl.program_id(0)\n off_hz = tl.program_id(1)\n qvk_offset = off_hz * stride_qh\n Q_block_ptr = tl.make_block_ptr(\n base=Q + qvk_offset,\n shape=(N_CTX, BLOCK_DMODEL),\n strides=(stride_qm, stride_qk),\n offsets=(start_m * BLOCK_M, 0),\n block_shape=(BLOCK_M, BLOCK_DMODEL),\n order=(1, 0)\n )\n K_block_ptr = tl.make_block_ptr(\n base=K + qvk_offset,\n shape=(BLOCK_DMODEL, N_CTX),\n strides=(stride_kk, stride_kn),\n offsets=(0, 0),\n block_shape=(BLOCK_DMODEL, BLOCK_N),\n order=(0, 1)\n )\n V_block_ptr = tl.make_block_ptr(\n base=V + qvk_offset,\n shape=(N_CTX, BLOCK_DMODEL),\n strides=(stride_vk, stride_vn),\n offsets=(0, 0),\n block_shape=(BLOCK_N, BLOCK_DMODEL),\n order=(1, 0)\n )\n O_block_ptr = tl.make_block_ptr(\n base=O + qvk_offset,\n shape=(N_CTX, BLOCK_DMODEL),\n strides=(stride_om, stride_on),\n offsets=(start_m * BLOCK_M, 0),\n block_shape=(BLOCK_M, BLOCK_DMODEL),\n order=(1, 0)\n )\n # initialize offsets\n offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)\n offs_n = tl.arange(0, BLOCK_N)\n # initialize pointer to m and l -> load from provided pointer\n m_ptrs = m + off_hz * N_CTX + offs_m\n l_ptrs = l + off_hz * N_CTX + offs_m\n m_i = tl.load(m_ptrs) \n m_i = m_i.to(tl.float32)\n l_i = tl.load(l_ptrs) \n l_i = l_i.to(tl.float32)\n acc = tl.load(O_block_ptr) \n acc = acc.to(tl.float32)\n # scale sm_scale by log_2(e) and use\n # 2^x instead of exp in the loop because CSE and LICM\n # don't work as expected with `exp` in the loop\n qk_scale = sm_scale * 1.44269504\n # load q: it will stay in SRAM throughout\n q = tl.load(Q_block_ptr)\n q = (q * qk_scale).to(tl.bfloat16)\n # loop over k, v and update accumulator\n lo = 0\n hi = (start_m + 1) * BLOCK_M if IS_CAUSAL else N_CTX\n for start_n in range(lo, hi, BLOCK_N):\n # -- load k, v --\n k = tl.load(K_block_ptr)\n v = tl.load(V_block_ptr)\n # -- compute qk ---\n qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)\n if IS_CAUSAL:\n qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float(\"-inf\"))\n qk += tl.dot(q, k)\n # -- compute scaling constant ---\n m_i_new = tl.maximum(m_i, tl.max(qk, 1))\n alpha = tl.math.exp2(m_i - m_i_new)\n p = tl.math.exp2(qk - m_i_new[:, None])\n # -- scale and update acc --\n acc_scale = l_i * 0 + alpha # workaround some compiler bug\n acc *= acc_scale[:, None]\n acc += tl.dot(p.to(tl.bfloat16), v)\n # -- update m_i and l_i --\n l_i = l_i * alpha + tl.sum(p, 1)\n m_i = m_i_new\n # update pointers\n K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N))\n V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0))\n # write back original l and m\n tl.store(m_ptrs, m_i)\n tl.store(l_ptrs, l_i)\n # write back O, L\n if LAST_STEP:\n acc = acc / l_i[:, None]\n L_ptrs = L + off_hz * N_CTX + offs_m\n tl.store(L_ptrs, m_i / 1.44269504 + tl.math.log(l_i))\n tl.store(O_block_ptr, acc.to(tl.bfloat16))\n\ndef _lightseq_forward(q, k, v, causal, sm_scale, comm_mode):\n # shape constraints\n Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]\n assert Lq == Lk and Lk == Lv\n assert Lk in {16, 32, 64, 128}\n # Why do I have to change it from 128 64 to 32 32?\n BLOCK_M = 32\n BLOCK_N = 32\n \n bsz, nh, seq_len, hdim = q.shape\n\n m = torch.full((bsz * nh, seq_len), fill_value=-float(\"inf\"), device=q.device, dtype=torch.float32)\n l = torch.zeros_like(m)\n L = torch.zeros_like(m)\n o = torch.zeros_like(q)\n \n grid = (triton.cdiv(seq_len, BLOCK_M), bsz * nh, 1)\n num_warps = 4 if Lk <= 64 else 8\n \n seq_rank = get_sequence_parallel_rank()\n seq_world_size = get_sequence_parallel_size()\n\n # Initialize all buffers\n peer_q, peer_k, peer_v, peer_m, peer_l, peer_o = maybe_get_set_global_memory_buffer(q, k, v, m, l, o)\n \n fwd_launch_helper = lambda q, k, v, m, l, o, L, IS_CAUSAL, LAST_STEP: _fwd_kernel[grid](\n q, k, v, sm_scale,\n m,\n l,\n o,\n L,\n q.stride(0), q.stride(1), q.stride(2), q.stride(3),\n k.stride(0), k.stride(1), k.stride(2), k.stride(3),\n v.stride(0), v.stride(1), v.stride(2), v.stride(3),\n o.stride(0), o.stride(1), o.stride(2), o.stride(3),\n q.shape[0], q.shape[1], q.shape[2],\n BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_DMODEL=Lk,\n IS_CAUSAL=IS_CAUSAL,\n LAST_STEP=LAST_STEP,\n num_warps=num_warps,\n num_stages=4)\n \n for time_step in range(seq_world_size // 2 + 1):\n # This is important for cuda scheduler to execute nccl calls first.\n torch.cuda.synchronize()\n # Communication uses buffer_idx_1, and compute uses buffer_idx_2, which effectively are contents from the last time step.\n buffer_idx_1 = time_step % 2\n buffer_idx_2 = (time_step - 1) % 2\n\n reqs = maybe_send_recv_fwd_qkvo(q, peer_q[buffer_idx_1], k, peer_k[buffer_idx_1], v, peer_v[buffer_idx_1], \n [peer_o[buffer_idx_1], peer_m[buffer_idx_1], peer_l[buffer_idx_1]], time_step, comm_mode)\n if comm_mode == \"sync\":\n # if seq_rank == 0:\n # print(\"Immediate wait for abalation\")\n wait_async_handles(reqs)\n if is_compute_for_local_query(time_step):\n # print(f\"t={time_step}: (Comp) R={seq_rank} local compute\")\n if time_step == 0:\n fwd_launch_helper(q, maybe_repeat_kv_fwd(q.shape[1], k), maybe_repeat_kv_fwd(q.shape[1], v), m, l, o, L, True, is_last_time(time_step))\n else:\n # if needs to sync from others, do not normalize here\n fwd_launch_helper(q, maybe_repeat_kv_fwd(q.shape[1], peer_k[buffer_idx_2]), maybe_repeat_kv_fwd(q.shape[1], peer_v[buffer_idx_2]), m, l, o, L, False, not is_sync_from_remote(time_step) and is_last_time(time_step))\n elif is_idle(time_step):\n # print(f\"t={time_step}: (Comp) R={seq_rank} idle\")\n pass\n else:\n # print(f\"t={time_step}: (Comp) R={seq_rank} helps other\")\n peer_m[buffer_idx_2] = torch.full_like(m, fill_value=-float(\"inf\"))\n peer_l[buffer_idx_2] = torch.zeros_like(l)\n peer_o[buffer_idx_2] = torch.zeros_like(o)\n\n #print(f\"rank 3 q is: {peer_q[buffer_idx_2]}\")\n fwd_launch_helper(peer_q[buffer_idx_2], maybe_repeat_kv_fwd(q.shape[1], k), maybe_repeat_kv_fwd(q.shape[1], v), peer_m[buffer_idx_2], peer_l[buffer_idx_2], peer_o[buffer_idx_2], None, False, False)\n\n if comm_mode == \"lightseq\":\n # Make sure tensors for next steps are ready\n wait_async_handles(reqs)\n # sync between statistics get from other ranks and the local ones\n if is_sync_from_remote(time_step):\n _rescale_kernel[grid](\n peer_m[buffer_idx_1],\n m,\n peer_l[buffer_idx_1],\n l,\n peer_o[buffer_idx_1],\n o,\n L,\n o.stride(0), o.stride(1), o.stride(2), o.stride(3),\n o.shape[0], o.shape[1], o.shape[2],\n BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_DMODEL=Lk,\n LAST_STEP=is_last_time(time_step),\n num_warps=num_warps,\n num_stages=4)\n return q, k, v, o, L\n", - "description_1": "Use triton language to create three kernels: 'max_fn', '_rescale_kernel', and '_fwd_kernel'. 'max_fn' computes the maximum of two inputs. '_rescale_kernel' rescales input tensors based on peer values and updates accumulation, with parameters for peer tensors, strides, grid dimensions, and constants for block dimensions. '_fwd_kernel' computes the forward pass for attention, taking tensors Q, K, V, scaling factors, masks, and dimensions, looping over key-value pairs to update accumulators.", - "description_2": "Use triton language to define kernels for computing maximum values, rescaling tensor blocks, and performing attention operations in neural networks.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\nimport math\nfrom einops import rearrange\n\n@triton.jit\ndef max_fn(x, y):\n return tl.math.max(x, y)\n\n@triton.jit\ndef _rescale_kernel(\n peer_m,\n m,\n peer_l,\n l,\n peer_o,\n o,\n L,\n stride_oz, stride_oh, stride_om, stride_on,\n Z, H, N_CTX,\n seqlen_q_rounded, seqlen_peer_q_rounded,\n BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,\n BLOCK_N: tl.constexpr,\n LAST_STEP: tl.constexpr,\n):\n start_m = tl.program_id(0)\n off_hz = tl.program_id(1)\n o_offset = off_hz * stride_oh\n peer_o_block_ptr = tl.make_block_ptr(\n base=peer_o + o_offset,\n shape=(N_CTX, BLOCK_DMODEL),\n strides=(stride_om, stride_on),\n offsets=(start_m * BLOCK_M, 0),\n block_shape=(BLOCK_M, BLOCK_DMODEL),\n order=(1, 0)\n )\n o_block_ptr = tl.make_block_ptr(\n base=o + o_offset,\n shape=(N_CTX, BLOCK_DMODEL),\n strides=(stride_om, stride_on),\n offsets=(start_m * BLOCK_M, 0),\n block_shape=(BLOCK_M, BLOCK_DMODEL),\n order=(1, 0)\n )\n offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)\n offs_n = tl.arange(0, BLOCK_N)\n\n peer_m_ptrs = peer_m + off_hz * seqlen_peer_q_rounded + offs_m\n m_ptrs = m + off_hz * seqlen_q_rounded + offs_m\n peer_l_ptrs = peer_l + off_hz * seqlen_peer_q_rounded + offs_m\n l_ptrs = l + off_hz * seqlen_q_rounded + offs_m\n \n peer_m_i = tl.load(peer_m_ptrs) \n peer_m_i = peer_m_i.to(tl.float32)\n m_i = tl.load(m_ptrs) \n m_i = m_i.to(tl.float32)\n peer_l_i = tl.load(peer_l_ptrs) \n peer_l_i = peer_l_i.to(tl.float32)\n l_i = tl.load(l_ptrs) \n l_i = l_i.to(tl.float32)\n\n peer_acc = tl.load(peer_o_block_ptr)\n peer_acc = peer_acc.to(tl.float32)\n acc = tl.load(o_block_ptr) \n acc = acc.to(tl.float32)\n lo = 0\n hi = N_CTX\n m_i_sync = tl.maximum(m_i, peer_m_i)\n alpha = tl.math.exp2(m_i - m_i_sync)\n peer_alpha = tl.math.exp2(peer_m_i - m_i_sync)\n \n acc_scale = l_i * 0 + alpha \n peer_acc_scale = peer_l_i * 0 + peer_alpha \n \n acc *= acc_scale[:, None]\n peer_acc *= peer_acc_scale[:, None]\n acc += peer_acc\n l_i = l_i * acc_scale + peer_l_i * peer_acc_scale\n \n tl.store(m_ptrs, m_i_sync)\n tl.store(l_ptrs, l_i)\n if LAST_STEP:\n acc = acc / l_i[:, None]\n L_ptrs = L + off_hz * N_CTX + offs_m\n tl.store(L_ptrs, m_i_sync / 1.44269504 + tl.math.log(l_i))\n tl.store(o_block_ptr, acc.to(tl.bfloat16), boundary_check=(0, 1))\n\n@triton.jit\ndef _fwd_kernel(\n Q, K, V, sm_scale,\n m,\n l,\n O,\n L,\n stride_qz, stride_qh, stride_qm, stride_qk,\n stride_kz, stride_kh, stride_kn, stride_kk,\n stride_vz, stride_vh, stride_vk, stride_vn,\n stride_oz, stride_oh, stride_om, stride_on,\n Z, H, N_CTX,\n seqlen_q_rounded,\n BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,\n BLOCK_N: tl.constexpr,\n IS_CAUSAL: tl.constexpr,\n LAST_STEP: tl.constexpr\n):\n start_m = tl.program_id(0)\n off_hz = tl.program_id(1)\n qvk_offset = off_hz * stride_qh\n Q_block_ptr = tl.make_block_ptr(\n base=Q + qvk_offset,\n shape=(N_CTX, BLOCK_DMODEL),\n strides=(stride_qm, stride_qk),\n offsets=(start_m * BLOCK_M, 0),\n block_shape=(BLOCK_M, BLOCK_DMODEL),\n order=(1, 0)\n )\n K_block_ptr = tl.make_block_ptr(\n base=K + qvk_offset,\n shape=(BLOCK_DMODEL, N_CTX),\n strides=(stride_kk, stride_kn),\n offsets=(0, 0),\n block_shape=(BLOCK_DMODEL, BLOCK_N),\n order=(0, 1)\n )\n V_block_ptr = tl.make_block_ptr(\n base=V + qvk_offset,\n shape=(N_CTX, BLOCK_DMODEL),\n strides=(stride_vk, stride_vn),\n offsets=(0, 0),\n block_shape=(BLOCK_N, BLOCK_DMODEL),\n order=(1, 0)\n )\n O_block_ptr = tl.make_block_ptr(\n base=O + qvk_offset,\n shape=(N_CTX, BLOCK_DMODEL),\n strides=(stride_om, stride_on),\n offsets=(start_m * BLOCK_M, 0),\n block_shape=(BLOCK_M, BLOCK_DMODEL),\n order=(1, 0)\n )\n offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)\n offs_n = tl.arange(0, BLOCK_N)\n\n m_ptrs = m + off_hz * seqlen_q_rounded + offs_m\n l_ptrs = l + off_hz * seqlen_q_rounded + offs_m\n m_i = tl.load(m_ptrs) \n m_i = m_i.to(tl.float32)\n l_i = tl.load(l_ptrs) \n l_i = l_i.to(tl.float32)\n acc = tl.load(O_block_ptr) \n acc = acc.to(tl.float32)\n qk_scale = sm_scale * 1.44269504\n q = tl.load(Q_block_ptr, boundary_check=(0,), padding_option='zero')\n q = (q * qk_scale).to(tl.bfloat16)\n\n lo = 0\n hi = (start_m + 1) * BLOCK_M if IS_CAUSAL else N_CTX\n for start_n in range(lo, hi, BLOCK_N):\n k = tl.load(K_block_ptr, boundary_check=(1,), padding_option='zero')\n v = tl.load(V_block_ptr, boundary_check=(0,), padding_option='zero')\n qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)\n if IS_CAUSAL:\n qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float(\"-inf\"))\n qk += tl.dot(q, k)\n m_i_new = tl.maximum(m_i, tl.max(qk, 1))\n alpha = tl.math.exp2(m_i - m_i_new)\n p = tl.math.exp2(qk - m_i_new[:, None])\n acc_scale = l_i * 0 + alpha \n acc *= acc_scale[:, None]\n acc += tl.dot(p.to(tl.bfloat16), v)\n l_i = l_i * alpha + tl.sum(p, 1)\n m_i = m_i_new\n\n K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N))\n V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0))\n \n tl.store(m_ptrs, m_i)\n tl.store(l_ptrs, l_i)\n if LAST_STEP:\n acc = acc / l_i[:, None]\n L_ptrs = L + off_hz * seqlen_q_rounded + offs_m\n tl.store(L_ptrs, m_i / 1.44269504 + tl.math.log(l_i))\n tl.store(O_block_ptr, acc.to(tl.bfloat16), boundary_check=(0, 1))\n\ndef _lightseq_forward_varlen(q, k, v, causal, sm_scale, comm_mode):\n Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]\n BLOCK_M = 128\n BLOCK_N = 64\n\n bsz, nh, unpadded_seq_len, hdim = q.shape\n cu_seq_lens = torch.arange(0, (bsz+1) * unpadded_seq_len, unpadded_seq_len, dtype=torch.int32, device=q.device)\n max_seqlen = unpadded_seq_len\n seqlen_q_rounded = math.ceil(q.shape[2] / BLOCK_M) * BLOCK_M\n\n m = torch.full((bsz * nh, seqlen_q_rounded), fill_value=-float(\"inf\"), device=q.device, dtype=torch.float32)\n l = torch.zeros((bsz * nh, seqlen_q_rounded), device=q.device, dtype=torch.float32)\n L = torch.zeros((bsz * nh, seqlen_q_rounded), device=q.device, dtype=torch.float32)\n o = torch.zeros_like(q)\n \n grid = (triton.cdiv(q.shape[2], BLOCK_M), bsz * nh, 1)\n num_warps = 4 if Lk <= 64 else 8\n \n seq_rank = get_sequence_parallel_rank()\n seq_world_size = get_sequence_parallel_size()\n\n peer_q, peer_k, peer_v, peer_m, peer_l, peer_o = maybe_get_set_global_memory_buffer(q, k, v, m, l, o)\n \n fwd_launch_helper = lambda q, k, v, m, l, o, L, IS_CAUSAL, LAST_STEP: _fwd_kernel[grid](\n q, k, v, sm_scale,\n m,\n l,\n o,\n L,\n q.stride(0), q.stride(1), q.stride(2), q.stride(3),\n k.stride(0), k.stride(1), k.stride(2), k.stride(3),\n v.stride(0), v.stride(1), v.stride(2), v.stride(3),\n o.stride(0), o.stride(1), o.stride(2), o.stride(3),\n q.shape[0], q.shape[1], q.shape[2],\n seqlen_q_rounded,\n BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_DMODEL=Lk,\n IS_CAUSAL=IS_CAUSAL,\n LAST_STEP=LAST_STEP,\n num_warps=num_warps,\n num_stages=4)\n \n for time_step in range(seq_world_size // 2 + 1):\n torch.cuda.synchronize()\n buffer_idx_1 = time_step % 2\n buffer_idx_2 = (time_step - 1) % 2\n\n reqs = maybe_send_recv_fwd_qkvo(q, peer_q[buffer_idx_1], k, peer_k[buffer_idx_1], v, peer_v[buffer_idx_1], \n [peer_o[buffer_idx_1], peer_m[buffer_idx_1], peer_l[buffer_idx_1]], time_step, comm_mode)\n if comm_mode == \"sync\":\n wait_async_handles(reqs)\n if is_compute_for_local_query(time_step):\n if time_step == 0:\n fwd_launch_helper(q, maybe_repeat_kv_fwd(q.shape[1], k), maybe_repeat_kv_fwd(q.shape[1], v), m, l, o, L, True, is_last_time(time_step))\n else:\n fwd_launch_helper(q, maybe_repeat_kv_fwd(q.shape[1], peer_k[buffer_idx_2]), maybe_repeat_kv_fwd(q.shape[1], peer_v[buffer_idx_2]), m, l, o, L, False, not is_sync_from_remote(time_step) and is_last_time(time_step))\n elif is_idle(time_step):\n pass\n else:\n peer_m[buffer_idx_2] = torch.full_like(m, fill_value=-float(\"inf\"))\n peer_l[buffer_idx_2] = torch.zeros_like(l)\n peer_o[buffer_idx_2] = torch.zeros_like(o)\n\n fwd_launch_helper(peer_q[buffer_idx_2], maybe_repeat_kv_fwd(q.shape[1], k), maybe_repeat_kv_fwd(q.shape[1], v), peer_m[buffer_idx_2], peer_l[buffer_idx_2], peer_o[buffer_idx_2], None, False, False)\n\n if comm_mode == \"lightseq\":\n wait_async_handles(reqs)\n if is_sync_from_remote(time_step):\n seqlen_peer_q_rounded = peer_l[buffer_idx_1].shape[-1]\n _rescale_kernel[grid](\n peer_m[buffer_idx_1],\n m,\n peer_l[buffer_idx_1],\n l,\n peer_o[buffer_idx_1],\n o,\n L,\n o.stride(0), o.stride(1), o.stride(2), o.stride(3),\n o.shape[0], o.shape[1], o.shape[2],\n seqlen_q_rounded, seqlen_peer_q_rounded,\n BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_DMODEL=Lk,\n LAST_STEP=is_last_time(time_step),\n num_warps=num_warps,\n num_stages=4)\n return q, k, v, o, L, cu_seq_lens, max_seqlen\n\nclass _attention_varlen(torch.autograd.Function):\n @staticmethod\n def forward(ctx, q, k, v, causal, sm_scale):\n try:\n global args\n comm_mode = args.comm_mode\n backward_engine = args.backward_engine\n except:\n comm_mode = 'lightseq'\n backward_engine = 'flash'\n \n q, k, v, o, L, cu_seq_lens, max_seqlen = _lightseq_forward_varlen(q, k, v, causal, sm_scale, comm_mode)\n\n ctx.save_for_backward(q, k, v, o, L, cu_seq_lens)\n ctx.max_seqlen = max_seqlen\n ctx.sm_scale = sm_scale\n ctx.comm_mode = comm_mode\n ctx.backward_engine = backward_engine\n return o\n\ndist_attn_varlen = _attention_varlen.apply\n", - "description_1": "Use triton language to implement kernels for scaling and computing matrix operations, especially for tasks such as multi-head attention. It involves defining kernels for maximum computation, rescaling, and forward pass with specific attention to data types, pointer arithmetic, and maintaining numerical stability during these operations.", - "description_2": "Use triton language to create kernels for maximum value computation and rescaling in matrix operations, focusing on multi-head attention tasks, utilizing pointer arithmetic for efficiency.", - "difficulty": 4 - }, - { - "code": "import triton\nimport triton.language as tl\nimport torch\n\n@triton.jit\ndef create_flashinfer_kv_indices_triton(\n req_to_token_ptr, # [max_batch, max_context_len]\n req_pool_indices_ptr,\n page_kernel_lens_ptr,\n kv_indptr,\n kv_start_idx,\n kv_indices_ptr,\n max_context_len: tl.constexpr,\n):\n BLOCK_SIZE: tl.constexpr = 512\n pid = tl.program_id(axis=0)\n req_pool_index = tl.load(req_pool_indices_ptr + pid)\n kv_indices_offset = tl.load(kv_indptr + pid)\n\n kv_start = 0\n kv_end = 0\n if kv_start_idx:\n kv_start = tl.load(kv_start_idx + pid).to(tl.int32)\n kv_end = kv_start\n kv_end += tl.load(page_kernel_lens_ptr + pid).to(tl.int32)\n\n req_to_token_ptr += req_pool_index * max_context_len\n kv_indices_ptr += kv_indices_offset\n\n ld_offset = kv_start + tl.arange(0, BLOCK_SIZE)\n st_offset = tl.arange(0, BLOCK_SIZE)\n num_loop = tl.cdiv(kv_end - kv_start, BLOCK_SIZE)\n for _ in range(num_loop):\n mask = ld_offset < kv_end\n data = tl.load(req_to_token_ptr + ld_offset, mask=mask)\n tl.store(kv_indices_ptr + st_offset, data, mask=mask)\n ld_offset += BLOCK_SIZE\n st_offset += BLOCK_SIZE\n\n\nclass FlashinferUpdater:\n def __init__(\n self,\n forward_mode,\n model_runner,\n req_pool_indices,\n seq_lens,\n prefix_lens,\n decode_wrapper=None,\n use_ragged=False,\n ):\n self.forward_mode = forward_mode\n self.model_runner = model_runner\n self.req_pool_indices = req_pool_indices\n self.seq_lens = seq_lens\n self.prefix_lens = prefix_lens\n self.use_ragged = use_ragged\n\n self.num_qo_heads = (\n model_runner.model_config.num_attention_heads // model_runner.tp_size\n )\n self.num_kv_heads = model_runner.model_config.get_num_kv_heads(\n model_runner.tp_size\n )\n self.head_dim = model_runner.model_config.head_dim\n self.batch_size = len(req_pool_indices)\n\n self.decode_wrapper = (\n decode_wrapper or self.model_runner.attn_backend.decode_wrapper\n )\n self.prefill_wrapper_ragged = (\n self.model_runner.attn_backend.prefill_wrapper_ragged\n )\n self.prefill_wrapper_paged = (\n self.model_runner.attn_backend.prefill_wrapper_paged\n )\n\n self.kv_last_page_len = torch.ones(\n (self.batch_size,), dtype=torch.int32, device=\"cuda\"\n )\n\n def _init_indices_no_sliding_window(self):\n if self.use_ragged:\n paged_kernel_lens = self.prefix_lens\n else:\n paged_kernel_lens = self.seq_lens\n\n self.kv_indptr = torch.zeros(\n (self.batch_size + 1,), dtype=torch.int32, device=\"cuda\"\n )\n self.kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)\n self.kv_indices = torch.empty(\n self.kv_indptr[-1], dtype=torch.int32, device=\"cuda\"\n )\n\n create_flashinfer_kv_indices_triton[(self.batch_size,)](\n self.model_runner.req_to_token_pool.req_to_token,\n self.req_pool_indices,\n paged_kernel_lens,\n self.kv_indptr,\n None,\n self.kv_indices,\n self.model_runner.req_to_token_pool.req_to_token.size(1),\n )\n\n def _init_indices_sliding_window(self, wrapper_id):\n if wrapper_id == 0:\n # window attention use paged only\n if self.forward_mode.is_decode():\n paged_kernel_lens = torch.minimum(\n self.seq_lens,\n torch.tensor(self.model_runner.sliding_window_size + 1),\n )\n else:\n paged_kernel_lens = torch.minimum(\n self.seq_lens,\n torch.tensor(self.model_runner.sliding_window_size)\n + self.seq_lens\n - self.prefix_lens,\n )\n else:\n # full attention\n paged_kernel_lens = self.seq_lens\n\n kv_start_idx = self.seq_lens - paged_kernel_lens\n self.kv_indptr = torch.zeros(\n (self.batch_size + 1,), dtype=torch.int32, device=\"cuda\"\n )\n self.kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)\n self.kv_indices = torch.empty(\n self.kv_indptr[-1], dtype=torch.int32, device=\"cuda\"\n )\n create_flashinfer_kv_indices_triton[(self.batch_size,)](\n self.model_runner.req_to_token_pool.req_to_token,\n self.req_pool_indices,\n paged_kernel_lens,\n self.kv_indptr,\n kv_start_idx,\n self.kv_indices,\n self.model_runner.req_to_token_pool.req_to_token.size(1),\n )\n\n def update_indices_no_sliding_window(self):\n self._init_indices_no_sliding_window()\n\n def update_indices_sliding_window(self):\n assert self.use_ragged is False\n\n for wrapper_id in range(2):\n self._init_indices_sliding_window(wrapper_id)\n\n\ndef update_flashinfer_indices(\n forward_mode,\n model_runner,\n req_pool_indices,\n seq_lens,\n prefix_lens,\n decode_wrapper=None,\n use_ragged=False,\n):\n updater = FlashinferUpdater(\n forward_mode,\n model_runner,\n req_pool_indices,\n seq_lens,\n prefix_lens,\n decode_wrapper,\n use_ragged,\n )\n\n if model_runner.sliding_window_size is None:\n updater.update_indices_no_sliding_window()\n else:\n updater.update_indices_sliding_window()\n", - "description_1": "Use triton language to implement a kernel function 'create_flashinfer_kv_indices_triton' that processes token indices for a batch of requests. The kernel takes 7 parameters: req_to_token_ptr (pointer to token data), req_pool_indices_ptr (pointer to request pool indices), page_kernel_lens_ptr (pointer to page kernel lengths), kv_indptr (pointer to key-value index pointers), kv_start_idx (pointer to start indices for key-value), kv_indices_ptr (pointer to key-value indices), and max_context_len (maximum context length as a constant). The kernel calculates offsets and loops over blocks to load and store data with masking. The 'FlashinferUpdater' class initializes and updates these indices based on the mode (sliding window or not) and calls the kernel with appropriate parameters.", - "description_2": "Use triton language to create a kernel for processing token indices with parameters for token pointers, request indices, page lengths, and context length, and implement a class to manage and update these indices based on different modes.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\nfrom typing import Any, Dict, Optional, Tuple\n\n@triton.jit\ndef fused_moe_kernel(\n a_ptr, b_ptr, c_ptr, a_scale_ptr, b_scale_ptr, topk_weights_ptr,\n sorted_token_ids_ptr, expert_ids_ptr, num_tokens_post_padded_ptr,\n N, K, EM, num_valid_tokens, stride_am, stride_ak, stride_be, stride_bk,\n stride_bn, stride_cm, stride_cn, BLOCK_SIZE_M: tl.constexpr,\n BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,\n GROUP_SIZE_M: tl.constexpr, MUL_ROUTED_WEIGHT: tl.constexpr,\n top_k: tl.constexpr, compute_type: tl.constexpr, use_fp8: tl.constexpr,\n):\n pid = tl.program_id(axis=0)\n num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M)\n num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)\n num_pid_in_group = GROUP_SIZE_M * num_pid_n\n group_id = pid // num_pid_in_group\n first_pid_m = group_id * GROUP_SIZE_M\n group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)\n pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)\n pid_n = (pid % num_pid_in_group) // group_size_m\n\n num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr)\n if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded:\n return\n offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_token = tl.load(sorted_token_ids_ptr + offs_token_id)\n token_mask = offs_token < num_valid_tokens\n\n offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N\n offs_k = tl.arange(0, BLOCK_SIZE_K)\n a_ptrs = a_ptr + (\n offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak\n )\n\n off_experts = tl.load(expert_ids_ptr + pid_m)\n b_ptrs = (\n b_ptr\n + off_experts * stride_be\n + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)\n )\n\n if use_fp8:\n a_scale = tl.load(a_scale_ptr)\n b_scale = tl.load(b_scale_ptr + off_experts)\n\n accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n\n for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):\n a = tl.load(\n a_ptrs,\n mask=token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K),\n other=0.0,\n )\n b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)\n if use_fp8:\n accumulator = tl.dot(a, b, acc=accumulator)\n else:\n accumulator += tl.dot(a, b)\n a_ptrs += BLOCK_SIZE_K * stride_ak\n b_ptrs += BLOCK_SIZE_K * stride_bk\n\n if MUL_ROUTED_WEIGHT:\n moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0)\n accumulator = accumulator * moe_weight[:, None]\n\n if use_fp8:\n accumulator = (accumulator * a_scale * b_scale).to(compute_type)\n else:\n accumulator = accumulator.to(compute_type)\n\n offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :]\n c_mask = token_mask[:, None] & (offs_cn[None, :] < N)\n tl.store(c_ptrs, accumulator, mask=c_mask)\n\n\ndef invoke_fused_moe_kernel(\n A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,\n A_scale: Optional[torch.Tensor], B_scale: Optional[torch.Tensor],\n topk_weights: torch.Tensor, topk_ids: torch.Tensor,\n sorted_token_ids: torch.Tensor, expert_ids: torch.Tensor,\n num_tokens_post_padded: torch.Tensor, mul_routed_weight: bool,\n top_k: int, config: Dict[str, Any], compute_type: tl.dtype,\n use_fp8: bool,\n) -> None:\n assert topk_weights.stride(1) == 1\n assert sorted_token_ids.stride(0) == 1\n\n if not use_fp8:\n assert A_scale is None\n assert B_scale is None\n else:\n A, A_scale = ops.scaled_fp8_quant(A, A_scale)\n assert B_scale is not None\n\n grid = lambda META: (\n triton.cdiv(sorted_token_ids.shape[0], META[\"BLOCK_SIZE_M\"])\n * triton.cdiv(B.shape[1], META[\"BLOCK_SIZE_N\"]),\n )\n\n fused_moe_kernel[grid](\n A, B, C, A_scale, B_scale, topk_weights, sorted_token_ids,\n expert_ids, num_tokens_post_padded, B.shape[1], B.shape[2],\n sorted_token_ids.shape[0], topk_ids.numel(), A.stride(0),\n A.stride(1), B.stride(0), B.stride(2), B.stride(1), C.stride(1),\n C.stride(2), MUL_ROUTED_WEIGHT=mul_routed_weight, top_k=top_k,\n compute_type=compute_type, use_fp8=use_fp8, **config,\n )\n", - "description_1": "Use triton language to implement a fused Mixture of Experts (MoE) kernel. The kernel, 'fused_moe_kernel', takes 28 parameters including pointers to input matrices, matrix dimensions, stride variables, and meta-parameters. It performs block matrix multiplication using token and expert matrices, with optional scaling and routing weights. The function 'invoke_fused_moe_kernel' calls this kernel with 16 parameters, setting up the grid and handling optional scaling.", - "description_2": "Use triton language to create a kernel for block matrix multiplication in a Mixture of Experts model, and a function to invoke this kernel with appropriate parameters.", - "difficulty": 4 - }, - { - "code": "import triton\nimport triton.language as tl\n\n@triton.jit\ndef tanh(x):\n # Tanh is just a scaled sigmoid\n return 2 * tl.sigmoid(2 * x) - 1\n\n@triton.jit\ndef _fwd_kernel_stage1(\n Q,\n K_Buffer,\n sm_scale,\n Req_to_tokens,\n B_req_idx,\n B_Start_Loc,\n B_Seqlen,\n Att_Out,\n stride_req_to_tokens_b,\n stride_qbs,\n stride_qh,\n stride_buf_kbs,\n stride_buf_kh,\n att_stride_h,\n kv_group_num: tl.constexpr,\n BLOCK_DMODEL: tl.constexpr,\n BLOCK_N: tl.constexpr,\n logit_cap: tl.constexpr,\n Lk: tl.constexpr,\n):\n cur_batch = tl.program_id(0)\n cur_head = tl.program_id(1)\n start_n = tl.program_id(2)\n\n reduce_dtype = Att_Out.dtype.element_ty\n cur_kv_head = cur_head // kv_group_num\n\n offs_d = tl.arange(0, BLOCK_DMODEL)\n cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)\n cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)\n cur_batch_req_idx = tl.load(B_req_idx + cur_batch)\n\n cur_batch_start_index = 0\n cur_batch_end_index = cur_batch_seq_len\n\n off_q = cur_batch * stride_qbs + cur_head * stride_qh + offs_d\n\n offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N)\n\n block_stard_index = start_n * BLOCK_N\n block_mask = tl.where(block_stard_index < cur_batch_seq_len, 1, 0)\n\n for start_mark in range(0, block_mask, 1):\n q = tl.load(Q + off_q + start_mark).to(reduce_dtype)\n offs_n_new = cur_batch_start_index + offs_n\n k_loc = tl.load(\n Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + offs_n_new,\n mask=offs_n_new < cur_batch_end_index,\n other=0,\n )\n offs_buf_k = (\n k_loc[:, None] * stride_buf_kbs\n + cur_kv_head * stride_buf_kh\n + offs_d[None, :]\n )\n k = tl.load(\n K_Buffer + offs_buf_k,\n mask=(offs_n_new[:, None] < cur_batch_end_index) & (offs_d[None, :] < Lk),\n other=0.0,\n ).to(reduce_dtype)\n att_value = tl.sum(q[None, :] * k, 1)\n att_value *= sm_scale\n\n if logit_cap > 0:\n att_value = logit_cap * tanh(att_value / logit_cap)\n\n off_o = cur_head * att_stride_h + (cur_batch_in_all_start_index + offs_n)\n tl.store(Att_Out + off_o, att_value, mask=offs_n_new < cur_batch_end_index)\n\n@triton.jit\ndef _fwd_kernel_stage2(\n logits,\n V_Buffer,\n Out,\n Req_to_tokens,\n B_req_idx,\n B_Start_Loc,\n B_Seqlen,\n stride_logic_h,\n stride_buf_vbs,\n stride_buf_vh,\n stride_obs,\n stride_oh,\n stride_req_to_token_b,\n kv_group_num: tl.constexpr,\n BLOCK_DMODEL: tl.constexpr,\n BLOCK_N: tl.constexpr,\n Lv: tl.constexpr,\n):\n cur_batch = tl.program_id(0)\n cur_head = tl.program_id(1)\n\n cur_kv_head = cur_head // kv_group_num\n\n cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)\n cur_batch_start_loc = tl.load(B_Start_Loc + cur_batch)\n cur_batch_req_idx = tl.load(B_req_idx + cur_batch)\n\n offs_n = tl.arange(0, BLOCK_N)\n offs_d = tl.arange(0, BLOCK_DMODEL)\n\n offs_buf_v = cur_kv_head * stride_buf_vh + offs_d[None, :]\n v_ptrs = V_Buffer + offs_buf_v\n\n e_max = float(\"-inf\")\n e_sum = 0.0\n acc = tl.zeros([BLOCK_DMODEL], dtype=tl.float32)\n\n for start_n in range(0, cur_batch_seq_len, BLOCK_N):\n start_n = tl.multiple_of(start_n, BLOCK_N)\n v_index = tl.load(\n Req_to_tokens\n + cur_batch_req_idx * stride_req_to_token_b\n + (start_n + offs_n),\n mask=(start_n + offs_n) < cur_batch_seq_len,\n other=0,\n )\n\n qk = tl.load(\n logits\n + cur_head * stride_logic_h\n + (cur_batch_start_loc + start_n + offs_n),\n mask=start_n + offs_n < cur_batch_seq_len,\n other=float(\"-inf\"),\n )\n\n n_e_max = tl.maximum(tl.max(qk, 0), e_max)\n old_scale = tl.exp(e_max - n_e_max)\n p = tl.exp(qk - n_e_max)\n e_sum = e_sum * old_scale + tl.sum(p, 0)\n v = tl.load(\n v_ptrs + v_index[:, None] * stride_buf_vbs, mask=(offs_d[None, :] < Lv)\n )\n acc = acc * old_scale + tl.sum(p[:, None] * v, 0)\n e_max = n_e_max\n\n acc = acc / e_sum\n off_o = cur_batch * stride_obs + cur_head * stride_oh + offs_d\n out_ptrs = Out + off_o\n tl.store(out_ptrs, acc, mask=(offs_d < Lv))\n\ndef _decode_att_m_fwd(\n q,\n k_buffer,\n att_out,\n Req_to_tokens,\n B_req_idx,\n B_Start_Loc,\n B_Seqlen,\n max_len_in_batch,\n sm_scale,\n logit_cap,\n):\n BLOCK = 32\n Lk = k_buffer.shape[-1]\n\n batch, head_num = B_req_idx.shape[0], q.shape[1]\n\n grid = (batch, head_num, triton.cdiv(max_len_in_batch, BLOCK))\n kv_group_num = q.shape[1] // k_buffer.shape[1]\n\n if kv_group_num == 1:\n num_warps = 4\n else:\n num_warps = 2\n\n BLOCK_DMODEL = triton.next_power_of_2(Lk)\n\n _fwd_kernel_stage1[grid](\n q,\n k_buffer,\n sm_scale,\n Req_to_tokens,\n B_req_idx,\n B_Start_Loc,\n B_Seqlen,\n att_out,\n Req_to_tokens.stride(0),\n q.stride(0),\n q.stride(1),\n k_buffer.stride(0),\n k_buffer.stride(1),\n att_out.stride(0),\n kv_group_num=kv_group_num,\n BLOCK_DMODEL=BLOCK_DMODEL,\n BLOCK_N=BLOCK,\n logit_cap=logit_cap,\n num_warps=num_warps,\n num_stages=1,\n Lk=Lk,\n )\n\ndef _decode_softmax_reducev_fwd(\n logits,\n v_buffer,\n o,\n req_to_tokens,\n b_req_idx,\n b_start_loc,\n b_seq_len,\n):\n BLOCK = 64\n batch, head = b_seq_len.shape[0], logits.shape[0]\n grid = (batch, head, 1)\n kv_group_num = logits.shape[0] // v_buffer.shape[1]\n\n num_warps = 1\n\n Lv = v_buffer.shape[-1]\n BLOCK_DMODEL = triton.next_power_of_2(Lv)\n\n _fwd_kernel_stage2[grid](\n logits,\n v_buffer,\n o,\n req_to_tokens,\n b_req_idx,\n b_start_loc,\n b_seq_len,\n logits.stride(0),\n v_buffer.stride(0),\n v_buffer.stride(1),\n o.stride(0),\n o.stride(1),\n req_to_tokens.stride(0),\n kv_group_num=kv_group_num,\n BLOCK_DMODEL=BLOCK_DMODEL,\n BLOCK_N=BLOCK,\n num_warps=num_warps,\n num_stages=3,\n Lv=Lv,\n )\n\ndef decode_attention_fwd(\n q,\n k_buffer,\n v_buffer,\n o,\n req_to_token,\n b_req_idx,\n b_start_loc,\n b_seq_len,\n attn_logits,\n max_len_in_batch,\n sm_scale,\n logit_cap=0.0,\n):\n kv_group_num = q.shape[1] // v_buffer.shape[1]\n\n if kv_group_num == 1:\n # MHA\n _decode_att_m_fwd(\n q,\n k_buffer,\n attn_logits,\n req_to_token,\n b_req_idx,\n b_start_loc,\n b_seq_len,\n max_len_in_batch,\n sm_scale,\n logit_cap,\n )\n _decode_softmax_reducev_fwd(\n attn_logits,\n v_buffer,\n o,\n req_to_token,\n b_req_idx,\n b_start_loc,\n b_seq_len,\n )\n", - "description_1": "Use triton language to implement a memory-efficient attention mechanism for decoding. The implementation includes two main stages: the first stage computes the attention logits using the query and key buffers, and the second stage applies softmax and reduces the values using the logits and value buffer. The kernels are parameterized by constants such as block sizes and strides, and they handle different configurations like multi-head attention (MHA) and grouped query attention (GQA).", - "description_2": "Use triton language to implement a two-stage attention mechanism for decoding, with kernels for computing attention logits and applying softmax and reduction.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef tanh(x):\n # Tanh is just a scaled sigmoid\n return 2 * tl.sigmoid(2 * x) - 1\n\n@triton.jit\ndef _fwd_kernel(\n Q_Extend,\n K_Extend,\n V_Extend,\n O_Extend,\n K_Buffer,\n V_Buffer,\n Req_to_tokens,\n B_req_idx,\n B_Seq_Len,\n B_Start_Loc_Extend,\n B_Seq_Len_Extend,\n sm_scale,\n kv_group_num,\n stride_qbs,\n stride_qh,\n stride_kbs,\n stride_kh,\n stride_vbs,\n stride_vh,\n stride_obs,\n stride_oh,\n stride_buf_kbs,\n stride_buf_kh,\n stride_buf_vbs,\n stride_buf_vh,\n stride_req_to_tokens_b,\n logit_cap: tl.constexpr,\n Lq: tl.constexpr,\n Lv: tl.constexpr,\n BLOCK_DMODEL: tl.constexpr,\n BLOCK_DPE: tl.constexpr,\n BLOCK_DV: tl.constexpr,\n BLOCK_M: tl.constexpr,\n BLOCK_N: tl.constexpr,\n):\n cur_seq = tl.program_id(0)\n cur_head = tl.program_id(1)\n cur_block_m = tl.program_id(2)\n cur_kv_head = cur_head // kv_group_num\n\n cur_seq_len = tl.load(B_Seq_Len + cur_seq)\n cur_seq_len_extend = tl.load(B_Seq_Len_Extend + cur_seq)\n cur_seq_len_prefix = cur_seq_len - cur_seq_len_extend\n\n cur_seq_prefix_start_in_loc = 0\n cur_seq_extend_start_contiguous = tl.load(B_Start_Loc_Extend + cur_seq)\n cur_batch_req_idx = tl.load(B_req_idx + cur_seq)\n\n offs_d = tl.arange(0, BLOCK_DMODEL)\n offs_dv = tl.arange(0, BLOCK_DV)\n offs_m = tl.arange(0, BLOCK_M)\n mask_m = (cur_block_m * BLOCK_M + offs_m) < cur_seq_len_extend\n\n mask_d = offs_d < Lq\n mask_dv = offs_dv < Lv\n\n offs_q = (\n (cur_seq_extend_start_contiguous + cur_block_m * BLOCK_M + offs_m[:, None])\n * stride_qbs\n + cur_head * stride_qh\n + offs_d[None, :]\n )\n q = tl.load(\n Q_Extend + offs_q, mask=(mask_m[:, None]) & (mask_d[None, :]), other=0.0\n )\n\n if BLOCK_DPE > 0:\n offs_dpe = BLOCK_DMODEL + tl.arange(0, BLOCK_DPE)\n offs_qpe = (\n (cur_seq_extend_start_contiguous + cur_block_m * BLOCK_M + offs_m[:, None])\n * stride_qbs\n + cur_head * stride_qh\n + offs_dpe[None, :]\n )\n qpe = tl.load(Q_Extend + offs_qpe, mask=mask_m[:, None], other=0.0)\n\n # stage 1: compute scores with prefix\n offs_n = tl.arange(0, BLOCK_N)\n\n acc = tl.zeros([BLOCK_M, BLOCK_DV], dtype=tl.float32)\n deno = tl.zeros([BLOCK_M], dtype=tl.float32)\n e_max = tl.zeros([BLOCK_M], dtype=tl.float32) - float(\"inf\")\n\n for start_n in range(0, cur_seq_len_prefix, BLOCK_N):\n start_n = tl.multiple_of(start_n, BLOCK_N)\n mask_n = (start_n + offs_n) < cur_seq_len_prefix\n offs_b_loc_prefix = cur_batch_req_idx * stride_req_to_tokens_b + (\n cur_seq_prefix_start_in_loc + start_n + offs_n\n )\n offs_kv_loc = tl.load(Req_to_tokens + offs_b_loc_prefix, mask=mask_n, other=0)\n\n # load k in transposed way\n offs_buf_k = (\n offs_kv_loc[None, :] * stride_buf_kbs\n + cur_kv_head * stride_buf_kh\n + offs_d[:, None]\n )\n k = tl.load(\n K_Buffer + offs_buf_k, mask=(mask_n[None, :]) & (mask_d[:, None]), other=0.0\n )\n\n qk = tl.dot(q.to(k.dtype), k)\n if BLOCK_DPE > 0:\n offs_kpe = (\n offs_kv_loc[None, :] * stride_buf_kbs\n + cur_kv_head * stride_buf_kh\n + offs_dpe[:, None]\n )\n kpe = tl.load(\n K_Buffer + offs_kpe,\n mask=mask_n[None, :],\n other=0.0,\n )\n qk += tl.dot(qpe.to(kpe.dtype), kpe)\n qk *= sm_scale\n\n if logit_cap > 0:\n qk = logit_cap * tanh(qk / logit_cap)\n\n qk = tl.where(mask_m[:, None] & mask_n[None, :], qk, float(\"-inf\"))\n\n n_e_max = tl.maximum(tl.max(qk, 1), e_max)\n re_scale = tl.exp(e_max - n_e_max)\n p = tl.exp(qk - n_e_max[:, None])\n deno = deno * re_scale + tl.sum(p, 1)\n\n offs_buf_v = (\n offs_kv_loc[:, None] * stride_buf_vbs\n + cur_kv_head * stride_buf_vh\n + offs_dv[None, :]\n )\n v = tl.load(\n V_Buffer + offs_buf_v, mask=mask_n[:, None] & mask_dv[None, :], other=0.0\n )\n p = p.to(v.dtype)\n acc = acc * re_scale[:, None] + tl.dot(p, v)\n\n e_max = n_e_max\n\n # stage 2: compute the triangle part\n\n cur_block_m_end = tl.minimum(cur_seq_len_extend, (cur_block_m + 1) * BLOCK_M)\n for start_n in range(0, cur_block_m_end, BLOCK_N):\n start_n = tl.multiple_of(start_n, BLOCK_N)\n mask_n = (start_n + offs_n) < cur_block_m_end\n\n # load k in transposed way\n offs_k = (\n (cur_seq_extend_start_contiguous + start_n + offs_n[None, :]) * stride_kbs\n + cur_kv_head * stride_kh\n + offs_d[:, None]\n )\n k = tl.load(\n K_Extend + offs_k, mask=(mask_n[None, :]) & (mask_d[:, None]), other=0.0\n )\n\n qk = tl.dot(q, k, out_dtype=tl.float32)\n if BLOCK_DPE > 0:\n offs_kpe = (\n (cur_seq_extend_start_contiguous + start_n + offs_n[None, :])\n * stride_kbs\n + cur_kv_head * stride_kh\n + offs_dpe[:, None]\n )\n kpe = tl.load(\n K_Extend + offs_kpe,\n mask=mask_n[None, :],\n other=0.0,\n )\n qk += tl.dot(qpe, kpe)\n\n qk *= sm_scale\n\n if logit_cap > 0:\n qk = logit_cap * tanh(qk / logit_cap)\n\n mask_causual = (cur_block_m * BLOCK_M + offs_m[:, None]) >= (\n start_n + offs_n[None, :]\n )\n mask_causual &= mask_m[:, None] & mask_n[None, :]\n qk = tl.where(mask_causual, qk, float(\"-inf\"))\n\n n_e_max = tl.maximum(tl.max(qk, 1), e_max)\n re_scale = tl.exp(e_max - n_e_max)\n p = tl.exp(qk - n_e_max[:, None])\n deno = deno * re_scale + tl.sum(p, 1)\n\n offs_v = (\n (cur_seq_extend_start_contiguous + start_n + offs_n[:, None]) * stride_vbs\n + cur_kv_head * stride_vh\n + offs_dv[None, :]\n )\n v = tl.load(\n V_Extend + offs_v, mask=mask_n[:, None] & mask_dv[None, :], other=0.0\n )\n p = p.to(v.dtype)\n acc = acc * re_scale[:, None] + tl.dot(p, v)\n\n e_max = n_e_max\n\n offs_o = (\n (cur_seq_extend_start_contiguous + cur_block_m * BLOCK_M + offs_m[:, None])\n * stride_obs\n + cur_head * stride_oh\n + offs_dv[None, :]\n )\n tl.store(\n O_Extend + offs_o, acc / deno[:, None], mask=mask_m[:, None] & mask_dv[None, :]\n )\n\ndef extend_attention_fwd(\n q_extend,\n k_extend,\n v_extend,\n o_extend,\n k_buffer,\n v_buffer,\n req_to_tokens,\n b_req_idx,\n b_seq_len,\n b_seq_len_extend,\n b_start_loc_extend,\n max_len_extend,\n sm_scale=None,\n logit_cap=0.0,\n):\n \"\"\"\n q_extend, k_extend, v_extend, o_extend: contiguous tensors\n\n k_buffer, v_buffer: (prefix + extend) tensors in mem_manager\n \"\"\"\n Lq, Lk, Lv = (\n q_extend.shape[-1],\n k_extend.shape[-1],\n v_extend.shape[-1],\n )\n\n if Lq == 576:\n BLOCK_DMODEL = 512\n BLOCK_DPE = 64\n elif Lq == 288:\n BLOCK_DMODEL = 256\n BLOCK_DPE = 32\n else:\n BLOCK_DMODEL = triton.next_power_of_2(Lq)\n BLOCK_DPE = 0\n BLOCK_DV = triton.next_power_of_2(Lv)\n\n if CUDA_CAPABILITY[0] >= 9:\n if Lq <= 256:\n BLOCK_M, BLOCK_N = (128, 64)\n else:\n BLOCK_M, BLOCK_N = (32, 64)\n elif CUDA_CAPABILITY[0] >= 8:\n if Lq <= 128:\n BLOCK_M, BLOCK_N = (128, 128)\n elif Lq <= 256:\n BLOCK_M, BLOCK_N = (64, 64)\n else:\n BLOCK_M, BLOCK_N = (32, 64)\n else:\n BLOCK_M, BLOCK_N = (64, 64) if Lq <= 128 else (32, 32)\n\n sm_scale = sm_scale or 1.0 / (Lq**0.5)\n batch_size, head_num = b_seq_len.shape[0], q_extend.shape[1]\n kv_group_num = q_extend.shape[1] // k_extend.shape[1]\n\n grid = (batch_size, head_num, triton.cdiv(max_len_extend, BLOCK_M))\n num_warps = 4 if Lk <= 64 else 8\n num_stages = 1\n\n _fwd_kernel[grid](\n q_extend,\n k_extend,\n v_extend,\n o_extend,\n k_buffer,\n v_buffer,\n req_to_tokens,\n b_req_idx,\n b_seq_len,\n b_start_loc_extend,\n b_seq_len_extend,\n sm_scale,\n kv_group_num,\n q_extend.stride(0),\n q_extend.stride(1),\n k_extend.stride(0),\n k_extend.stride(1),\n v_extend.stride(0),\n v_extend.stride(1),\n o_extend.stride(0),\n o_extend.stride(1),\n k_buffer.stride(0),\n k_buffer.stride(1),\n v_buffer.stride(0),\n v_buffer.stride(1),\n req_to_tokens.stride(0),\n logit_cap=logit_cap,\n BLOCK_DMODEL=BLOCK_DMODEL,\n BLOCK_DPE=BLOCK_DPE,\n BLOCK_DV=BLOCK_DV,\n BLOCK_M=BLOCK_M,\n BLOCK_N=BLOCK_N,\n Lq=Lq,\n Lv=Lv,\n num_warps=num_warps,\n num_stages=num_stages,\n )\n", - "description_1": "Use triton language to implement a forward kernel for attention mechanism. The kernel (_fwd_kernel) has 37 parameters: Q_Extend, K_Extend, V_Extend, O_Extend, K_Buffer, V_Buffer (tensors for query, key, value and output, both extended and buffered), Req_to_tokens, B_req_idx, B_Seq_Len, B_Start_Loc_Extend, B_Seq_Len_Extend (tensors for sequence processing), sm_scale, kv_group_num (scaling and group information), stride_qbs, stride_qh, stride_kbs, stride_kh, stride_vbs, stride_vh, stride_obs, stride_oh, stride_buf_kbs, stride_buf_kh, stride_buf_vbs, stride_buf_vh, stride_req_to_tokens_b (stride lengths for different dimensions), and constants: logit_cap, Lq, Lv, BLOCK_DMODEL, BLOCK_DPE, BLOCK_DV, BLOCK_M, BLOCK_N for tensor configurations and execution parameters. The function implements the scaled dot-product attention with pre-fetched key and value buffers and writes the output in a blocked fashion for efficiency.", - "description_2": "Use triton language to compute an attention mechanism on input tensors with variable length support and specified kernel launch configuration.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef _fwd_kernel(\n Q,\n K,\n V,\n sm_scale,\n B_Start_Loc,\n B_Seqlen,\n Out,\n stride_qbs,\n stride_qh,\n stride_kbs,\n stride_kh,\n stride_vbs,\n stride_vh,\n stride_obs,\n stride_oh,\n kv_group_num: tl.constexpr,\n BLOCK_M: tl.constexpr,\n BLOCK_DMODEL: tl.constexpr,\n BLOCK_N: tl.constexpr,\n Lk: tl.constexpr,\n):\n cur_batch = tl.program_id(0)\n cur_head = tl.program_id(1)\n start_m = tl.program_id(2)\n\n cur_kv_head = cur_head // kv_group_num\n\n cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)\n cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)\n\n block_start_loc = BLOCK_M * start_m\n\n # initialize offsets\n offs_n = tl.arange(0, BLOCK_N)\n offs_d = tl.arange(0, BLOCK_DMODEL)\n offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)\n off_q = (\n (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs\n + cur_head * stride_qh\n + offs_d[None, :]\n )\n off_k = offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh + offs_d[:, None]\n off_v = offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh + offs_d[None, :]\n\n mask_d = offs_d < Lk\n\n q = tl.load(\n Q + off_q, mask=(offs_m[:, None] < cur_batch_seq_len) & (mask_d), other=0.0\n )\n\n k_ptrs = K + off_k\n v_ptrs = V + off_v\n\n # initialize pointer to m and l\n m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float(\"inf\")\n l_i = tl.zeros([BLOCK_M], dtype=tl.float32)\n acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)\n\n block_mask = tl.where(block_start_loc < cur_batch_seq_len, 1, 0)\n\n for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N):\n start_n = tl.multiple_of(start_n, BLOCK_N)\n # -- compute qk ----\n k = tl.load(\n k_ptrs + (cur_batch_in_all_start_index + start_n) * stride_kbs,\n mask=((start_n + offs_n[None, :]) < cur_batch_seq_len) & (mask_d[:, None]),\n other=0.0,\n )\n # mask = tl.load(mask_ptrs + start_n, mask=start_n + offs_n < cur_batch_end_loc, other=0.0)\n\n qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)\n qk += tl.dot(q, k)\n qk *= sm_scale\n qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float(\"-inf\"))\n\n # -- compute m_ij, p, l_ij\n m_ij = tl.max(qk, 1)\n p = tl.exp(qk - m_ij[:, None])\n l_ij = tl.sum(p, 1)\n # -- update m_i and l_i\n m_i_new = tl.maximum(m_i, m_ij)\n alpha = tl.exp(m_i - m_i_new)\n beta = tl.exp(m_ij - m_i_new)\n l_i_new = alpha * l_i + beta * l_ij\n # -- update output accumulator --\n # scale p\n p_scale = beta / l_i_new\n p = p * p_scale[:, None]\n # scale acc\n acc_scale = l_i / l_i_new * alpha\n acc = acc * acc_scale[:, None]\n # update acc\n v = tl.load(\n v_ptrs + (cur_batch_in_all_start_index + start_n) * stride_vbs,\n mask=((start_n + offs_n[:, None]) < cur_batch_seq_len) & (mask_d[None, :]),\n other=0.0,\n )\n\n p = p.to(v.dtype)\n acc += tl.dot(p, v)\n # update m_i and l_i\n l_i = l_i_new\n m_i = m_i_new\n # initialize pointers to output\n off_o = (\n (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs\n + cur_head * stride_oh\n + offs_d[None, :]\n )\n out_ptrs = Out + off_o\n tl.store(\n out_ptrs, acc, mask=(offs_m[:, None] < cur_batch_seq_len) & (mask_d[None, :])\n )\n\n\ndef context_attention_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len):\n if CUDA_CAPABILITY[0] >= 8:\n BLOCK = 128\n else:\n BLOCK = 64\n\n Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]\n\n sm_scale = 1.0 / (Lq**0.5)\n batch, head = b_seq_len.shape[0], q.shape[1]\n kv_group_num = q.shape[1] // k.shape[1]\n\n grid = (batch, head, triton.cdiv(max_input_len, BLOCK))\n num_warps = 4 if Lk <= 64 else 8\n\n _fwd_kernel[grid](\n q,\n k,\n v,\n sm_scale,\n b_start_loc,\n b_seq_len,\n o,\n q.stride(0),\n q.stride(1),\n k.stride(0),\n k.stride(1),\n v.stride(0),\n v.stride(1),\n o.stride(0),\n o.stride(1),\n kv_group_num=kv_group_num,\n BLOCK_M=BLOCK,\n BLOCK_DMODEL=triton.next_power_of_2(Lk),\n BLOCK_N=BLOCK,\n num_warps=num_warps,\n num_stages=1,\n Lk=Lk,\n )\n", - "description_1": "Use triton language to implement a forward kernel for memory-efficient attention. The kernel takes 15 parameters: Q, K, V (query, key, value tensors), sm_scale (scale for softmax), B_Start_Loc, B_Seqlen (batch start location and sequence length), Out (output tensor), stride_qbs, stride_qh, stride_kbs, stride_kh, stride_vbs, stride_vh, stride_obs, stride_oh (stride values for accessing tensor elements), kv_group_num (number of key-value groups), BLOCK_M, BLOCK_DMODEL, BLOCK_N (block sizes for matrix operations), and Lk (length of key). The kernel computes the attention scores and updates the output tensor. The context_attention_fwd function sets up the grid and block sizes and calls the kernel with appropriate parameters.", - "description_2": "Use triton language to create a kernel for efficient attention computation, handling query, key, and value tensors with specific block sizes and strides, and a function to configure and launch this kernel.", - "difficulty": 4 - }, - { - "code": "import triton\nimport torch\n\n# Kernel function for matrix multiplication\n@triton.jit\ndef matmul_kernel(A, B, C, M, N, K, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr):\n # Triton kernel code for matrix multiplication\n pass\n\n# Function to call the kernel\ndef call_matmul_kernel(A, B, C, M, N, K):\n # Call the Triton kernel\n matmul_kernel[(M, N)](A, B, C, M, N, K, BLOCK_SIZE_M=128, BLOCK_SIZE_N=128, BLOCK_SIZE_K=32)\n", - "description_1": "Use triton language to implement a matrix multiplication kernel with parameters A, B, C (matrices), M, N, K (dimensions), and BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K (block sizes). The kernel performs matrix multiplication and is called with specific block sizes.", - "description_2": "Use triton language to implement and call a matrix multiplication kernel with specified block sizes and dimensions.", - "difficulty": 3 - }, - { - "code": "import triton\nimport triton.language as tl\nimport torch\nimport math\n\n@triton.jit\ndef rotate_half_kernel(\n qk_seq_ptr,\n position_ids_ptr,\n qk_seq_stride,\n position_ids_batch_stride,\n seq_len,\n HEAD_DIM: tl.constexpr,\n BLOCK_HEIGHT: tl.constexpr,\n BLOCK_WIDTH: tl.constexpr,\n INV_BASE: tl.constexpr\n):\n # qk_seq_ptr: (bsz, seq_len, 2, num_heads, head_dim) -- OK to be discontinuous in 2nd dimension.\n # position ids: (bsz, seq_len) -- must be contiguous in the last dimension.\n\n HALF_HEAD: tl.constexpr = HEAD_DIM // 2\n STEPS_PER_ROW: tl.constexpr = HALF_HEAD // BLOCK_WIDTH\n\n batch_seq = tl.program_id(axis=0)\n row_blk_x_col_blk = tl.program_id(axis=1)\n\n row_blk = row_blk_x_col_blk // STEPS_PER_ROW\n row = row_blk * BLOCK_HEIGHT\n if BLOCK_WIDTH < HALF_HEAD:\n col_blk = row_blk_x_col_blk % STEPS_PER_ROW\n col = col_blk * BLOCK_WIDTH\n else:\n col: tl.constexpr = 0\n\n # A block will never cross a sequence boundary, which simplifies things a lot.\n batch = batch_seq // seq_len\n seq = batch_seq % seq_len\n position_id = tl.load(position_ids_ptr + batch * position_ids_batch_stride + seq)\n # As sometimes happens, just calculating this on the fly is faster than loading it from memory.\n # Use `tl.libdevice.exp` rather than `tl.exp` -- the latter is less accurate.\n freq = tl.libdevice.exp((col + tl.arange(0, BLOCK_WIDTH)).to(tl.float32) * INV_BASE) * position_id\n cos = tl.cos(freq).to(tl.float32)\n sin = tl.sin(freq).to(tl.float32)\n\n col_offsets: tl.constexpr = tl.arange(0, BLOCK_WIDTH)\n embed_offsets = (row * HEAD_DIM + col) + col_offsets\n x_ptrs = (qk_seq_ptr + batch_seq * qk_seq_stride) + embed_offsets\n\n for k in range(0, BLOCK_HEIGHT):\n x = tl.load(x_ptrs).to(tl.float32)\n y = tl.load(x_ptrs + HALF_HEAD).to(tl.float32)\n out_x = x * cos - y * sin\n tl.store(x_ptrs, out_x)\n out_y = x * sin + y * cos\n tl.store(x_ptrs + HALF_HEAD, out_y)\n x_ptrs += HEAD_DIM\n\n\ndef triton_rotate_half_(qk, position_ids, config=None):\n with torch.cuda.device(qk.device):\n batch_size, seq_len, qandk, num_heads, head_dim = qk.shape\n\n # This default is the fastest for most job sizes, at least on my RTX 4090, and when it's not it's within spitting distance of the best option. There are some odd cases where having a block height of 2 or 4 helps but the difference is within 5%. It makes sense that this configuration is fast from a memory bandwidth and caching perspective.\n config = config or {'BLOCK_HEIGHT': 1, 'BLOCK_WIDTH': min(128, head_dim // 2), 'num_warps': 1}\n config['BLOCK_HEIGHT'] = min(config['BLOCK_HEIGHT'], 2 * num_heads)\n\n assert qk.stride(3) == head_dim\n assert qk.stride(4) == 1\n assert position_ids.shape == (batch_size, seq_len)\n assert position_ids.stride(1) == 1, 'position_ids must be contiguous in the last dimension'\n assert (2 * num_heads) % config['BLOCK_HEIGHT'] == 0, f'number of rows not evenly divisible by {config[\"BLOCK_HEIGHT\"]}'\n assert (head_dim // 2) % config['BLOCK_WIDTH'] == 0, f'number of columns ({head_dim // 2}) not evenly divisible by {config[\"BLOCK_WIDTH\"]}'\n\n qk_by_seq = qk.view(batch_size * seq_len, 2 * num_heads * head_dim)\n grid = (qk_by_seq.shape[0], (2 * num_heads // config['BLOCK_HEIGHT']) * (head_dim // 2 // config['BLOCK_WIDTH']))\n\n # Must be the same as the theta of the frequencies used to train the model.\n BASE = 10000.0\n\n rotate_half_kernel[grid](\n qk_by_seq,\n position_ids,\n qk_by_seq.stride(0),\n position_ids.stride(0),\n seq_len,\n HEAD_DIM=head_dim,\n BLOCK_HEIGHT=config['BLOCK_HEIGHT'],\n BLOCK_WIDTH=config['BLOCK_WIDTH'],\n INV_BASE=-2.0 * math.log(BASE) / head_dim,\n num_warps=config['num_warps']\n )\n", - "description_1": "Use triton language to implement a kernel function 'rotate_half_kernel' that performs in-place rotation of half of the head dimensions of a query-key sequence tensor. The kernel takes 9 parameters: qk_seq_ptr (pointer to the query-key sequence), position_ids_ptr (pointer to position ids), qk_seq_stride (stride of the query-key sequence), position_ids_batch_stride (stride of position ids), seq_len (sequence length), HEAD_DIM (head dimension), BLOCK_HEIGHT (block height), BLOCK_WIDTH (block width), and INV_BASE (inverse base for frequency calculation). The kernel computes cosine and sine of frequencies and applies them to rotate the input tensor. The function 'triton_rotate_half_' is a wrapper that configures and launches the kernel with appropriate grid and block sizes.", - "description_2": "Use triton language to implement a kernel that rotates half of the head dimensions of a tensor in-place, and a wrapper function to configure and launch this kernel.", - "difficulty": 3 - }, - { - "code": "import triton\nimport triton.language as tl\nimport torch\n\n@triton.jit\ndef fusedmatmul_248_kernel(a_ptr, c_ptr, b1_ptr, scales1_ptr, zeros1_ptr, g1_ptr, b2_ptr, scales2_ptr, zeros2_ptr, g2_ptr, M, N, K, bits, maxq, stride_am, stride_ak, stride_bk, stride_bn,\n stride_cm, stride_cn, stride_scales, stride_zeros, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr):\n \"\"\"\n Computes: C = silu(A * B1) * (A * B2)\n A is of shape (M, K) float16\n B is of shape (K//8, N) int32\n C is of shape (M, N) float16\n scales is of shape (1, N) float16\n zeros is of shape (1, N//8) int32\n \"\"\"\n infearure_per_bits = 32 // bits\n\n pid = tl.program_id(axis=0)\n num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)\n num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)\n num_pid_k = tl.cdiv(K, BLOCK_SIZE_K)\n num_pid_in_group = GROUP_SIZE_M * num_pid_n\n group_id = pid // num_pid_in_group\n first_pid_m = group_id * GROUP_SIZE_M\n group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)\n pid_m = first_pid_m + (pid % group_size_m)\n pid_n = (pid % num_pid_in_group) // group_size_m\n\n offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n offs_k = tl.arange(0, BLOCK_SIZE_K)\n a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) # (BLOCK_SIZE_M, BLOCK_SIZE_K)\n a_mask = (offs_am[:, None] < M)\n # b_ptrs is set up such that it repeats elements along the K axis 8 times\n b1_ptrs = b1_ptr + ((offs_k[:, None] // infearure_per_bits) * stride_bk + offs_bn[None, :] * stride_bn)\n b2_ptrs = b2_ptr + ((offs_k[:, None] // infearure_per_bits) * stride_bk + offs_bn[None, :] * stride_bn)\n g1_ptrs = g1_ptr + offs_k\n g2_ptrs = g2_ptr + offs_k\n # shifter is used to extract the N bits of each element in the 32-bit word from B\n scales1_ptrs = scales1_ptr + offs_bn[None, :]\n scales2_ptrs = scales2_ptr + offs_bn[None, :]\n zeros1_ptrs = zeros1_ptr + (offs_bn[None, :] // infearure_per_bits)\n zeros2_ptrs = zeros2_ptr + (offs_bn[None, :] // infearure_per_bits)\n\n shifter = (offs_k % infearure_per_bits) * bits\n zeros_shifter = (offs_bn % infearure_per_bits) * bits\n accumulator1 = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n accumulator2 = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n for k in range(0, num_pid_k):\n g1_idx = tl.load(g1_ptrs)\n g2_idx = tl.load(g2_ptrs)\n\n # Fetch scales and zeros; these are per-outfeature and thus reused in the inner loop\n scales1 = tl.load(scales1_ptrs + g1_idx[:, None] * stride_scales) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)\n scales2 = tl.load(scales2_ptrs + g2_idx[:, None] * stride_scales)\n\n zeros1 = tl.load(zeros1_ptrs + g1_idx[:, None] * stride_zeros) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)\n zeros1 = (zeros1 >> zeros_shifter[None, :]) & maxq\n zeros1 = (zeros1 + 1)\n\n zeros2 = tl.load(zeros2_ptrs + g2_idx[:, None] * stride_zeros) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)\n zeros2 = (zeros2 >> zeros_shifter[None, :]) & maxq\n zeros2 = (zeros2 + 1)\n\n a = tl.load(a_ptrs, mask=a_mask, other=0.) # (BLOCK_SIZE_M, BLOCK_SIZE_K)\n b1 = tl.load(b1_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated\n b2 = tl.load(b2_ptrs)\n\n # Now we need to unpack b (which is N-bit values) into 32-bit values\n b1 = (b1 >> shifter[:, None]) & maxq # Extract the N-bit values\n b1 = (b1 - zeros1) * scales1 # Scale and shift\n accumulator1 += tl.dot(a, b1)\n\n b2 = (b2 >> shifter[:, None]) & maxq\n b2 = (b2 - zeros2) * scales2\n accumulator2 += tl.dot(a, b2)\n\n a_ptrs += BLOCK_SIZE_K\n b1_ptrs += (BLOCK_SIZE_K // infearure_per_bits) * stride_bk\n b2_ptrs += (BLOCK_SIZE_K // infearure_per_bits) * stride_bk\n g1_ptrs += BLOCK_SIZE_K\n g2_ptrs += BLOCK_SIZE_K\n\n accumulator1 = silu(accumulator1)\n c = accumulator1 * accumulator2\n c = c.to(tl.float16)\n c_ptrs = c_ptr + stride_cm * offs_am[:, None] + stride_cn * offs_bn[None, :]\n c_mask = (offs_am[:, None] < M) & (offs_bn[None, :] < N)\n tl.store(c_ptrs, c, mask=c_mask)\n\n@triton.jit\ndef silu(x):\n return x * tl.sigmoid(x)\n\nclass QuantLlamaMLP(nn.Module):\n def __init__(self, gate_proj, down_proj, up_proj):\n super().__init__()\n self.register_buffer('gate_proj_qweight', gate_proj.qweight)\n self.register_buffer('gate_proj_scales', gate_proj.scales)\n self.register_buffer('gate_proj_qzeros', gate_proj.qzeros)\n self.register_buffer('gate_proj_g_idx', gate_proj.g_idx)\n self.register_buffer('up_proj_qweight', up_proj.qweight)\n self.register_buffer('up_proj_scales', up_proj.scales)\n self.register_buffer('up_proj_qzeros', up_proj.qzeros)\n self.register_buffer('up_proj_g_idx', up_proj.g_idx)\n\n self.infeatures = gate_proj.infeatures\n self.intermediate_size = gate_proj.outfeatures\n self.outfeatures = down_proj.outfeatures\n self.bits = gate_proj.bits\n self.maxq = gate_proj.maxq\n\n self.down_proj = down_proj\n\n def forward(self, x):\n return self.down_proj(self.triton_llama_mlp(x))\n\n def triton_llama_mlp(self, x):\n with torch.cuda.device(x.device):\n out_shape = x.shape[:-1] + (self.intermediate_size, )\n x = x.reshape(-1, x.shape[-1])\n M, K = x.shape\n N = self.intermediate_size\n c = torch.empty((M, N), device=x.device, dtype=torch.float16)\n grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), )\n fusedmatmul_248_kernel[grid](x, c, self.gate_proj_qweight, self.gate_proj_scales, self.gate_proj_qzeros, self.gate_proj_g_idx, self.up_proj_qweight, self.up_proj_scales,\n self.up_proj_qzeros, self.up_proj_g_idx, M, N, K, self.bits, self.maxq, x.stride(0), x.stride(1), self.gate_proj_qweight.stride(0),\n self.gate_proj_qweight.stride(1), c.stride(0), c.stride(1), self.gate_proj_scales.stride(0), self.gate_proj_qzeros.stride(0))\n c = c.reshape(out_shape)\n return c\n", - "description_1": "Use triton language to implement a fused matrix multiplication kernel that computes C = silu(A * B1) * (A * B2) with inputs A, B1, B2, scales, and zeros, and outputs C. The kernel takes 28 parameters: pointers to input and output matrices, dimensions M, N, K, bit width, max quantization value, strides for accessing elements, and block sizes for tiling.", - "description_2": "Use triton language to implement a SiLU activation function as a separate kernel, which takes a single parameter: the input tensor x.", - "difficulty": 4 - }, - { - "code": "import triton\nimport triton.language as tl\nimport torch\n\n@triton.jit\ndef matmul_248_kernel(a_ptr, b_ptr, c_ptr, scales_ptr, zeros_ptr, g_ptr, M, N, K, bits, maxq, stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn, stride_scales, stride_zeros,\n BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr):\n \"\"\"\n Compute the matrix multiplication C = A x B.\n A is of shape (M, K) float16\n B is of shape (K//8, N) int32\n C is of shape (M, N) float16\n scales is of shape (G, N) float16\n zeros is of shape (G, N) float16\n g_ptr is of shape (K) int32 \n \"\"\"\n infearure_per_bits = 32 // bits\n\n pid = tl.program_id(axis=0)\n num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)\n num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)\n num_pid_k = tl.cdiv(K, BLOCK_SIZE_K)\n num_pid_in_group = GROUP_SIZE_M * num_pid_n\n group_id = pid // num_pid_in_group\n first_pid_m = group_id * GROUP_SIZE_M\n group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)\n pid_m = first_pid_m + (pid % group_size_m)\n pid_n = (pid % num_pid_in_group) // group_size_m\n\n offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n offs_k = tl.arange(0, BLOCK_SIZE_K)\n a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) # (BLOCK_SIZE_M, BLOCK_SIZE_K)\n a_mask = (offs_am[:, None] < M)\n # b_ptrs is set up such that it repeats elements along the K axis 8 times\n b_ptrs = b_ptr + ((offs_k[:, None] // infearure_per_bits) * stride_bk + offs_bn[None, :] * stride_bn) # (BLOCK_SIZE_K, BLOCK_SIZE_N)\n g_ptrs = g_ptr + offs_k\n # shifter is used to extract the N bits of each element in the 32-bit word from B\n scales_ptrs = scales_ptr + offs_bn[None, :]\n zeros_ptrs = zeros_ptr + (offs_bn[None, :] // infearure_per_bits)\n\n shifter = (offs_k % infearure_per_bits) * bits\n zeros_shifter = (offs_bn % infearure_per_bits) * bits\n accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n\n for k in range(0, num_pid_k):\n g_idx = tl.load(g_ptrs)\n\n # Fetch scales and zeros; these are per-outfeature and thus reused in the inner loop\n scales = tl.load(scales_ptrs + g_idx[:, None] * stride_scales) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)\n zeros = tl.load(zeros_ptrs + g_idx[:, None] * stride_zeros) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)\n\n zeros = (zeros >> zeros_shifter[None, :]) & maxq\n zeros = (zeros + 1) & maxq\n\n a = tl.load(a_ptrs, mask=a_mask, other=0.) # (BLOCK_SIZE_M, BLOCK_SIZE_K)\n b = tl.load(b_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated\n\n # Now we need to unpack b (which is N-bit values) into 32-bit values\n b = (b >> shifter[:, None]) & maxq # Extract the N-bit values\n b = (b - zeros) * scales # Scale and shift\n\n accumulator += tl.dot(a, b)\n a_ptrs += BLOCK_SIZE_K\n b_ptrs += (BLOCK_SIZE_K // infearure_per_bits) * stride_bk\n g_ptrs += BLOCK_SIZE_K\n\n c_ptrs = c_ptr + stride_cm * offs_am[:, None] + stride_cn * offs_bn[None, :]\n c_mask = (offs_am[:, None] < M) & (offs_bn[None, :] < N)\n tl.store(c_ptrs, accumulator, mask=c_mask)\n\n@triton.jit\ndef transpose_matmul_248_kernel(a_ptr, b_ptr, c_ptr, scales_ptr, zeros_ptr, g_ptr, M, N, K, bits, maxq, stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn, stride_scales,\n stride_zeros, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr):\n \"\"\"\n Compute the matrix multiplication C = A x B.\n A is of shape (M, N) float16\n B is of shape (K//8, N) int32\n C is of shape (M, K) float16\n scales is of shape (G, N) float16\n zeros is of shape (G, N) float16\n g_ptr is of shape (K) int32 \n \"\"\"\n infearure_per_bits = 32 // bits\n\n pid = tl.program_id(axis=0)\n num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)\n num_pid_k = tl.cdiv(K, BLOCK_SIZE_K)\n num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)\n num_pid_in_group = GROUP_SIZE_M * num_pid_k\n group_id = pid // num_pid_in_group\n first_pid_m = group_id * GROUP_SIZE_M\n group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)\n pid_m = first_pid_m + (pid % group_size_m)\n pid_k = (pid % num_pid_in_group) // group_size_m\n\n offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_bk = pid_k * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)\n offs_n = tl.arange(0, BLOCK_SIZE_N)\n a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_n[None, :] * stride_ak) # (BLOCK_SIZE_M, BLOCK_SIZE_N)\n a_mask = (offs_am[:, None] < M)\n # b_ptrs is set up such that it repeats elements along the K axis 8 times\n b_ptrs = b_ptr + ((offs_bk[:, None] // infearure_per_bits) * stride_bk + offs_n[None, :] * stride_bn) # (BLOCK_SIZE_K, BLOCK_SIZE_N)\n g_ptrs = g_ptr + offs_bk\n g_idx = tl.load(g_ptrs)\n\n # shifter is used to extract the N bits of each element in the 32-bit word from B\n scales_ptrs = scales_ptr + offs_n[None, :] + g_idx[:, None] * stride_scales\n zeros_ptrs = zeros_ptr + (offs_n[None, :] // infearure_per_bits) + g_idx[:, None] * stride_zeros\n\n shifter = (offs_bk % infearure_per_bits) * bits\n zeros_shifter = (offs_n % infearure_per_bits) * bits\n accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_K), dtype=tl.float32)\n\n for n in range(0, num_pid_n):\n # Fetch scales and zeros; these are per-outfeature and thus reused in the inner loop\n scales = tl.load(scales_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)\n zeros = tl.load(zeros_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)\n\n zeros = (zeros >> zeros_shifter[None, :]) & maxq\n zeros = (zeros + 1) & maxq\n\n a = tl.load(a_ptrs, mask=a_mask, other=0.) # (BLOCK_SIZE_M, BLOCK_SIZE_N)\n b = tl.load(b_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated\n\n # Now we need to unpack b (which is N-bit values) into 32-bit values\n b = (b >> shifter[:, None]) & maxq # Extract the N-bit values\n b = (b - zeros) * scales # Scale and shift\n b = tl.trans(b)\n\n accumulator += tl.dot(a, b)\n a_ptrs += BLOCK_SIZE_N\n b_ptrs += BLOCK_SIZE_N\n scales_ptrs += BLOCK_SIZE_N\n zeros_ptrs += (BLOCK_SIZE_N // infearure_per_bits)\n\n c_ptrs = c_ptr + stride_cm * offs_am[:, None] + stride_cn * offs_bk[None, :]\n c_mask = (offs_am[:, None] < M) & (offs_bk[None, :] < K)\n tl.store(c_ptrs, accumulator, mask=c_mask)\n\ndef matmul248(input, qweight, scales, qzeros, g_idx, bits, maxq):\n with torch.cuda.device(input.device):\n output = torch.empty((input.shape[0], qweight.shape[1]), device=input.device, dtype=torch.float16)\n grid = lambda META: (triton.cdiv(input.shape[0], META['BLOCK_SIZE_M']) * triton.cdiv(qweight.shape[1], META['BLOCK_SIZE_N']), )\n matmul_248_kernel[grid](input, qweight, output, scales, qzeros, g_idx, input.shape[0], qweight.shape[1], input.shape[1], bits, maxq, input.stride(0), input.stride(1), qweight.stride(0),\n qweight.stride(1), output.stride(0), output.stride(1), scales.stride(0), qzeros.stride(0))\n return output\n\ndef transpose_matmul248(input, qweight, scales, qzeros, g_idx, bits, maxq):\n with torch.cuda.device(input.device):\n output_dim = (qweight.shape[0] * 32) // bits\n output = torch.empty((input.shape[0], output_dim), device=input.device, dtype=torch.float16)\n grid = lambda META: (triton.cdiv(input.shape[0], META['BLOCK_SIZE_M']) * triton.cdiv(output_dim, META['BLOCK_SIZE_K']), )\n transpose_matmul_248_kernel[grid](input, qweight, output, scales, qzeros, g_idx, input.shape[0], qweight.shape[1], output_dim, bits, maxq, input.stride(0), input.stride(1), qweight.stride(0),\n qweight.stride(1), output.stride(0), output.stride(1), scales.stride(0), qzeros.stride(0))\n return output\n", - "description_1": "Use triton language to create a matrix multiplication kernel `matmul_248_kernel` that computes C = A x B for matrices A and B with specified shapes. The kernel reads input data through pointers and uses loop unrolling and bit manipulations for efficiency. There is a supporting transpose kernel `transpose_matmul_248_kernel` for transposed matrix multiplication scenarios. Both kernels handle quantization and scaling factors during operations. The functions `matmul248` and `transpose_matmul248` manage grid configuration and call the respective kernels.", - "description_2": "Use triton language to implement matrix multiplication with quantization and scaling adjustments. Include a transpose variant for versatility.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef rms_norm_fwd_fused(\n X, # pointer to the input\n Y, # pointer to the output\n W, # pointer to the weights\n stride, # how much to increase the pointer when moving by 1 row\n N, # number of columns in X\n eps, # epsilon to avoid division by zero\n BLOCK_SIZE: tl.constexpr,\n):\n # Map the program id to the row of X and Y it should compute.\n row = tl.program_id(0)\n Y += row * stride\n X += row * stride\n # Compute variance\n _var = tl.zeros([BLOCK_SIZE], dtype=tl.float32)\n for off in range(0, N, BLOCK_SIZE):\n cols = off + tl.arange(0, BLOCK_SIZE)\n x = tl.load(X + cols, mask=cols < N, other=0.).to(tl.float32)\n x = tl.where(cols < N, x, 0.)\n _var += x * x\n var = tl.sum(_var, axis=0) / N\n rstd = 1 / tl.sqrt(var + eps)\n # Normalize and apply linear transformation\n for off in range(0, N, BLOCK_SIZE):\n cols = off + tl.arange(0, BLOCK_SIZE)\n mask = cols < N\n w = tl.load(W + cols, mask=mask)\n x = tl.load(X + cols, mask=mask, other=0.).to(tl.float32)\n x_hat = x * rstd\n y = x_hat * w\n # Write output\n tl.store(Y + cols, y, mask=mask)\n\nclass TritonLlamaRMSNorm(nn.Module):\n def __init__(self, weight, eps=1e-6):\n super().__init__()\n self.weight = weight\n self.variance_epsilon = eps\n\n def forward(self, x):\n with torch.cuda.device(x.device):\n y = torch.empty_like(x)\n # reshape input data into 2D tensor\n x_arg = x.reshape(-1, x.shape[-1])\n M, N = x_arg.shape\n # Less than 64KB per feature: enqueue fused kernel\n MAX_FUSED_SIZE = 65536 // x.element_size()\n BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))\n if N > BLOCK_SIZE:\n raise RuntimeError(\"This layer norm doesn't support feature dim >= 64KB.\")\n # heuristics for number of warps\n num_warps = min(max(BLOCK_SIZE // 256, 1), 8)\n # enqueue kernel\n rms_norm_fwd_fused[(M,)](x_arg, y, self.weight, \n x_arg.stride(0), N, self.variance_epsilon,\n BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps)\n return y\n", - "description_1": "Use triton language to implement RMS normalization as a kernel with parameters: input pointer X, output pointer Y, weights pointer W, stride for row traversal, number of columns N, epsilon for numerical stability, and BLOCK_SIZE for loading data. A forward function in TritonLlamaRMSNorm uses this kernel with input x and applies layer normalization.", - "description_2": "Use triton language to create a kernel and a forward pass for RMS normalization with necessary input, output, weights, and parameters for execution in a parallelized manner.", - "difficulty": 2 - }, - { - "code": "import triton\nimport math\n\n@triton.jit\ndef kernel(x_ptr, x_size, **META):\n BLOCK_SIZE = META['BLOCK_SIZE']\n\ndef autotune(configs, key, prune_configs_by=None, reset_to_zero=None, nearest_power_of_two=False):\n def decorator(fn):\n return Autotuner(fn, fn.arg_names, configs, key, reset_to_zero, prune_configs_by, nearest_power_of_two)\n return decorator\n\n@autotune(configs=[\n triton.Config(meta={'BLOCK_SIZE': 128}, num_warps=4),\n triton.Config(meta={'BLOCK_SIZE': 1024}, num_warps=8),\n], key=['x_size'])\n@triton.jit\ndef kernel(x_ptr, x_size, **META):\n BLOCK_SIZE = META['BLOCK_SIZE']\n", - "description_1": "Use triton language to define a kernel function with two parameters: x_ptr (pointer to data) and x_size (size of data). The kernel uses a meta-parameter BLOCK_SIZE. The kernel is autotuned with two configurations, each specifying a different BLOCK_SIZE and number of warps.", - "description_2": "Use triton language to define and autotune a kernel with parameters for data pointer and size, using meta-parameters for block size.", - "difficulty": 3 - }, - { - "code": "import triton\nimport triton.language as tl\nimport torch\n\n@triton.jit\ndef fusedmatmul_248_kernel(a_ptr, c_ptr, b1_ptr, scales1_ptr, zeros1_ptr, g1_ptr, b2_ptr, scales2_ptr, zeros2_ptr, g2_ptr, M, N, K, bits, maxq, stride_am, stride_ak, stride_bk, stride_bn,\n stride_cm, stride_cn, stride_scales, stride_zeros, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr):\n \"\"\"\n Computes: C = silu(A * B1) * (A * B2)\n A is of shape (M, K) float16\n B is of shape (K//8, N) int32\n C is of shape (M, N) float16\n scales is of shape (1, N) float16\n zeros is of shape (1, N//8) int32\n \"\"\"\n infearure_per_bits = 32 // bits\n\n pid = tl.program_id(axis=0)\n num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)\n num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)\n num_pid_k = tl.cdiv(K, BLOCK_SIZE_K)\n num_pid_in_group = GROUP_SIZE_M * num_pid_n\n group_id = pid // num_pid_in_group\n first_pid_m = group_id * GROUP_SIZE_M\n group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)\n pid_m = first_pid_m + (pid % group_size_m)\n pid_n = (pid % num_pid_in_group) // group_size_m\n\n offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n offs_k = tl.arange(0, BLOCK_SIZE_K)\n a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)\n a_mask = (offs_am[:, None] < M)\n b1_ptrs = b1_ptr + ((offs_k[:, None] // infearure_per_bits) * stride_bk + offs_bn[None, :] * stride_bn)\n b2_ptrs = b2_ptr + ((offs_k[:, None] // infearure_per_bits) * stride_bk + offs_bn[None, :] * stride_bn)\n g1_ptrs = g1_ptr + offs_k\n g2_ptrs = g2_ptr + offs_k\n scales1_ptrs = scales1_ptr + offs_bn[None, :]\n scales2_ptrs = scales2_ptr + offs_bn[None, :]\n zeros1_ptrs = zeros1_ptr + (offs_bn[None, :] // infearure_per_bits)\n zeros2_ptrs = zeros2_ptr + (offs_bn[None, :] // infearure_per_bits)\n\n shifter = (offs_k % infearure_per_bits) * bits\n zeros_shifter = (offs_bn % infearure_per_bits) * bits\n accumulator1 = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n accumulator2 = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n for k in range(0, num_pid_k):\n g1_idx = tl.load(g1_ptrs)\n g2_idx = tl.load(g2_ptrs)\n\n scales1 = tl.load(scales1_ptrs + g1_idx[:, None] * stride_scales)\n scales2 = tl.load(scales2_ptrs + g2_idx[:, None] * stride_scales)\n\n zeros1 = tl.load(zeros1_ptrs + g1_idx[:, None] * stride_zeros)\n zeros1 = (zeros1 >> zeros_shifter[None, :]) & maxq\n zeros1 = (zeros1 + 1)\n\n zeros2 = tl.load(zeros2_ptrs + g2_idx[:, None] * stride_zeros)\n zeros2 = (zeros2 >> zeros_shifter[None, :]) & maxq\n zeros2 = (zeros2 + 1)\n\n a = tl.load(a_ptrs, mask=a_mask, other=0.)\n b1 = tl.load(b1_ptrs)\n b2 = tl.load(b2_ptrs)\n\n b1 = (b1 >> shifter[:, None]) & maxq\n b1 = (b1 - zeros1) * scales1\n accumulator1 += tl.dot(a, b1)\n\n b2 = (b2 >> shifter[:, None]) & maxq\n b2 = (b2 - zeros2) * scales2\n accumulator2 += tl.dot(a, b2)\n\n a_ptrs += BLOCK_SIZE_K\n b1_ptrs += (BLOCK_SIZE_K // infearure_per_bits) * stride_bk\n b2_ptrs += (BLOCK_SIZE_K // infearure_per_bits) * stride_bk\n g1_ptrs += BLOCK_SIZE_K\n g2_ptrs += BLOCK_SIZE_K\n\n accumulator1 = silu(accumulator1)\n c = accumulator1 * accumulator2\n c = c.to(tl.float16)\n c_ptrs = c_ptr + stride_cm * offs_am[:, None] + stride_cn * offs_bn[None, :]\n c_mask = (offs_am[:, None] < M) & (offs_bn[None, :] < N)\n tl.store(c_ptrs, c, mask=c_mask)\n\n@triton.jit\ndef silu(x):\n return x * tl.sigmoid(x)\n\nclass QuantLlamaMLP(nn.Module):\n def __init__(self, gate_proj, down_proj, up_proj):\n super().__init__()\n self.register_buffer('gate_proj_qweight', gate_proj.qweight)\n self.register_buffer('gate_proj_scales', gate_proj.scales)\n self.register_buffer('gate_proj_qzeros', gate_proj.qzeros)\n self.register_buffer('gate_proj_g_idx', gate_proj.g_idx)\n self.register_buffer('up_proj_qweight', up_proj.qweight)\n self.register_buffer('up_proj_scales', up_proj.scales)\n self.register_buffer('up_proj_qzeros', up_proj.qzeros)\n self.register_buffer('up_proj_g_idx', up_proj.g_idx)\n\n self.infeatures = gate_proj.infeatures\n self.intermediate_size = gate_proj.outfeatures\n self.outfeatures = down_proj.outfeatures\n self.bits = gate_proj.bits\n self.maxq = gate_proj.maxq\n\n self.down_proj = down_proj\n\n def forward(self, x):\n return self.down_proj(self.triton_llama_mlp(x))\n\n def triton_llama_mlp(self, x):\n with torch.cuda.device(x.device):\n out_shape = x.shape[:-1] + (self.intermediate_size, )\n x = x.reshape(-1, x.shape[-1])\n M, K = x.shape\n N = self.intermediate_size\n c = torch.empty((M, N), device='cuda', dtype=torch.float16)\n grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), )\n fusedmatmul_248_kernel[grid](x, c, self.gate_proj_qweight, self.gate_proj_scales, self.gate_proj_qzeros, self.gate_proj_g_idx, self.up_proj_qweight, self.up_proj_scales,\n self.up_proj_qzeros, self.up_proj_g_idx, M, N, K, self.bits, self.maxq, x.stride(0), x.stride(1), self.gate_proj_qweight.stride(0),\n self.gate_proj_qweight.stride(1), c.stride(0), c.stride(1), self.gate_proj_scales.stride(0), self.gate_proj_qzeros.stride(0))\n c = c.reshape(out_shape)\n return c\n", - "description_1": "Use triton language to implement a kernel that performs a fused matrix multiplication operation followed by a SiLU activation and element-wise multiplication with another matrix multiplication result. The kernel, `fusedmatmul_248_kernel`, takes 28 parameters: pointers to input tensors (a_ptr, c_ptr, b1_ptr, scales1_ptr, zeros1_ptr, g1_ptr, b2_ptr, scales2_ptr, zeros2_ptr, g2_ptr), dimensions (M, N, K), quantization parameters (bits, maxq), stride information (stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn, stride_scales, stride_zeros), and compile-time constants (BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K, GROUP_SIZE_M). Additionally, a utility function `silu` is used for the SiLU activation function. The `QuantLlamaMLP` class calls this kernel in its `triton_llama_mlp` method, which reshapes input tensors and prepares them for Triton kernel execution.", - "description_2": "Use triton language to create a matrix multiplication kernel with quantization support and fused SiLU activation for deep learning models, suitable for integration into a PyTorch module.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n# Triton kernel for matrix multiplication C = A x B with packed weights\n@triton.jit\ndef matmul_248_kernel(a_ptr, b_ptr, c_ptr, scales_ptr, zeros_ptr, g_ptr, M, N, K, bits, maxq, stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn, stride_scales, stride_zeros,\n NO_GROUP: tl.constexpr, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr):\n \"\"\"\n Compute the matrix multiplication C = A x B.\n A is of shape (M, K) float16\n B is of shape (K//8, N) int32\n C is of shape (M, N) float16\n scales is of shape (G, N) float16\n zeros is of shape (G, N) float16\n g_ptr is of shape (K) int32 \n \"\"\"\n infearure_per_bits = 32 // bits\n\n pid = tl.program_id(axis=0)\n num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)\n num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)\n num_pid_k = tl.cdiv(K, BLOCK_SIZE_K)\n num_pid_in_group = GROUP_SIZE_M * num_pid_n\n group_id = pid // num_pid_in_group\n first_pid_m = group_id * GROUP_SIZE_M\n group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)\n pid_m = first_pid_m + (pid % group_size_m)\n pid_n = (pid % num_pid_in_group) // group_size_m\n\n offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n offs_k = tl.arange(0, BLOCK_SIZE_K)\n a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) # (BLOCK_SIZE_M, BLOCK_SIZE_K)\n a_mask = (offs_am[:, None] < M)\n b_ptrs = b_ptr + ((offs_k[:, None] // infearure_per_bits) * stride_bk + offs_bn[None, :] * stride_bn) # (BLOCK_SIZE_K, BLOCK_SIZE_N)\n g_ptrs = g_ptr + offs_k\n scales_ptrs = scales_ptr + offs_bn[None, :]\n zeros_ptrs = zeros_ptr + (offs_bn[None, :] // infearure_per_bits)\n\n if NO_GROUP:\n scales = tl.load(scales_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)\n zeros = tl.load(zeros_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)\n\n zeros = (zeros >> zeros_shifter[None, :]) & maxq\n zeros = (zeros + 1)\n \n shifter = (offs_k % infearure_per_bits) * bits\n zeros_shifter = (offs_bn % infearure_per_bits) * bits\n accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n\n for k in range(0, num_pid_k):\n g_idx = tl.load(g_ptrs)\n\n if not NO_GROUP: \n scales = tl.load(scales_ptrs + g_idx[:, None] * stride_scales) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)\n zeros = tl.load(zeros_ptrs + g_idx[:, None] * stride_zeros) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)\n\n zeros = (zeros >> zeros_shifter[None, :]) & maxq\n zeros = (zeros + 1)\n\n a = tl.load(a_ptrs, mask=a_mask, other=0.) # (BLOCK_SIZE_M, BLOCK_SIZE_K)\n b = tl.load(b_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated\n\n b = (b >> shifter[:, None]) & maxq\n b = (b - zeros) * scales\n\n accumulator += tl.dot(a, b)\n a_ptrs += BLOCK_SIZE_K\n b_ptrs += (BLOCK_SIZE_K // infearure_per_bits) * stride_bk\n g_ptrs += BLOCK_SIZE_K\n\n c_ptrs = c_ptr + stride_cm * offs_am[:, None] + stride_cn * offs_bn[None, :]\n c_mask = (offs_am[:, None] < M) & (offs_bn[None, :] < N)\n tl.store(c_ptrs, accumulator, mask=c_mask)\n\n# Triton kernel for transposed matrix multiplication C = A x B\n@triton.jit\ndef transpose_matmul_248_kernel(a_ptr, b_ptr, c_ptr, scales_ptr, zeros_ptr, g_ptr, M, N, K, bits, maxq, stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn, stride_scales,\n stride_zeros, NO_GROUP: tl.constexpr, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr):\n \"\"\"\n Compute the matrix multiplication C = A x B.\n A is of shape (M, N) float16\n B is of shape (K//8, N) int32\n C is of shape (M, K) float16\n scales is of shape (G, N) float16\n zeros is of shape (G, N) float16\n g_ptr is of shape (K) int32 \n \"\"\"\n infearure_per_bits = 32 // bits\n\n pid = tl.program_id(axis=0)\n num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)\n num_pid_k = tl.cdiv(K, BLOCK_SIZE_K)\n num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)\n num_pid_in_group = GROUP_SIZE_M * num_pid_k\n group_id = pid // num_pid_in_group\n first_pid_m = group_id * GROUP_SIZE_M\n group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)\n pid_m = first_pid_m + (pid % group_size_m)\n pid_k = (pid % num_pid_in_group) // group_size_m\n\n offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_bk = pid_k * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)\n offs_n = tl.arange(0, BLOCK_SIZE_N)\n a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_n[None, :] * stride_ak) # (BLOCK_SIZE_M, BLOCK_SIZE_N)\n a_mask = (offs_am[:, None] < M)\n b_ptrs = b_ptr + ((offs_bk[:, None] // infearure_per_bits) * stride_bk + offs_n[None, :] * stride_bn) # (BLOCK_SIZE_K, BLOCK_SIZE_N)\n g_ptrs = g_ptr + offs_bk\n g_idx = tl.load(g_ptrs)\n\n scales_ptrs = scales_ptr + offs_n[None, :] + g_idx[:, None] * stride_scales\n zeros_ptrs = zeros_ptr + (offs_n[None, :] // infearure_per_bits) + g_idx[:, None] * stride_zeros\n\n if NO_GROUP:\n scales = tl.load(scales_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)\n zeros = tl.load(zeros_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)\n\n zeros = (zeros >> zeros_shifter[None, :]) & maxq\n zeros = (zeros + 1)\n \n shifter = (offs_bk % infearure_per_bits) * bits\n zeros_shifter = (offs_n % infearure_per_bits) * bits\n accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_K), dtype=tl.float32)\n\n for k in range(0, num_pid_n):\n if not NO_GROUP:\n scales = tl.load(scales_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)\n zeros = tl.load(zeros_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)\n\n zeros = (zeros >> zeros_shifter[None, :]) & maxq\n zeros = (zeros + 1)\n\n a = tl.load(a_ptrs, mask=a_mask, other=0.) # (BLOCK_SIZE_M, BLOCK_SIZE_N)\n b = tl.load(b_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated\n\n b = (b >> shifter[:, None]) & maxq\n b = (b - zeros) * scales\n b = tl.trans(b)\n\n accumulator += tl.dot(a, b)\n a_ptrs += BLOCK_SIZE_N\n b_ptrs += BLOCK_SIZE_N\n scales_ptrs += BLOCK_SIZE_N\n zeros_ptrs += (BLOCK_SIZE_N // infearure_per_bits)\n\n c_ptrs = c_ptr + stride_cm * offs_am[:, None] + stride_cn * offs_bk[None, :]\n c_mask = (offs_am[:, None] < M) & (offs_bk[None, :] < K)\n tl.store(c_ptrs, accumulator, mask=c_mask)\n\n# Call the first Triton kernel\ndef matmul248(input, qweight, scales, qzeros, g_idx, bits, maxq, no_group):\n with torch.cuda.device(input.device):\n output = torch.empty((input.shape[0], qweight.shape[1]), device='cuda', dtype=torch.float16)\n grid = lambda META: (triton.cdiv(input.shape[0], META['BLOCK_SIZE_M']) * triton.cdiv(qweight.shape[1], META['BLOCK_SIZE_N']), )\n matmul_248_kernel[grid](input, qweight, output, scales, qzeros, g_idx, input.shape[0], qweight.shape[1], input.shape[1], bits, maxq, input.stride(0), input.stride(1), qweight.stride(0),\n qweight.stride(1), output.stride(0), output.stride(1), scales.stride(0), qzeros.stride(0), no_group)\n return output\n\n# Call the second Triton kernel\ndef transpose_matmul248(input, qweight, scales, qzeros, g_idx, bits, maxq, no_group):\n with torch.cuda.device(input.device):\n output_dim = (qweight.shape[0] * 32) // bits\n output = torch.empty((input.shape[0], output_dim), device='cuda', dtype=torch.float16)\n grid = lambda META: (triton.cdiv(input.shape[0], META['BLOCK_SIZE_M']) * triton.cdiv(output_dim, META['BLOCK_SIZE_K']), )\n transpose_matmul_248_kernel[grid](input, qweight, output, scales, qzeros, g_idx, input.shape[0], qweight.shape[1], output_dim, bits, maxq, input.stride(0), input.stride(1), qweight.stride(0),\n qweight.stride(1), output.stride(0), output.stride(1), scales.stride(0), qzeros.stride(0), no_group)\n return output\n", - "description_1": "Use triton language to implement matrix multiplication kernels optimized for 2/4/8-bit quantized weights. The kernels compute A x B where A is in float16, B is int32 packed and quantized. The first kernel handles the forward pass and the second one transposed multiplication for gradients, managing scaling and zero-points for quantization.", - "description_2": "Use triton language to implement matrix multiplication of float16 and packed quantized int32 matrices, with forward and gradient operations.", - "difficulty": 3 - }, - { - "code": "import triton\nimport triton.language as tl\nimport torch\n\n# Triton kernel for forward sequential scan\n@triton.jit\ndef fwd_sequential_scan(\n v, f1, hidden, B, L, C, BLOCK_M: tl.constexpr,\n):\n offset_b = tl.program_id(0)\n if offset_b >= B:\n return\n offset_n = tl.program_id(1)\n ptr = tl.arange(0, BLOCK_M) + offset_b * L * C + offset_n * BLOCK_M\n h1 = tl.zeros([BLOCK_M,], dtype=tl.float32)\n for _ in range(L):\n x0 = tl.load(v + ptr).to(tl.float32)\n decay1 = tl.load(f1 + ptr).to(tl.float32)\n h1 = (h1 - x0) * decay1 + x0\n tl.store(hidden + ptr, h1.to(hidden.dtype.element_ty))\n ptr += C\n\n# Triton kernel for backward sequential scan\n@triton.jit\ndef bwd_sequential_scan(\n grad_output, v, f, h, B, L, C, BLOCK_M: tl.constexpr,\n):\n offset_b = tl.program_id(0)\n if offset_b >= B:\n return\n offset_n = tl.program_id(1)\n ptr = tl.arange(0, BLOCK_M) + offset_b * L * C + (L - 1) * C + offset_n * BLOCK_M\n grad_h = tl.zeros([BLOCK_M,], dtype=tl.float32)\n for time_step in range(L - 1, -1, -1):\n grad = tl.load(grad_output + ptr).to(tl.float32)\n grad_h += grad\n decay = tl.load(f + ptr).to(tl.float32)\n input = tl.load(v + ptr).to(tl.float32)\n grad_v = (1 - decay) * grad_h\n tl.store(v + ptr, grad_v.to(v.dtype.element_ty))\n hidden_state = tl.load(h + ptr - C, mask=ptr >= (offset_b * L * C + C), other=0.0).to(tl.float32)\n grad_f = grad_h * (hidden_state - input)\n tl.store(f + ptr, grad_f.to(f.dtype.element_ty))\n grad_h *= decay\n ptr -= C\n\n# Triton kernel for fused forward sequential scan\n@triton.jit\ndef fwd_sequential_scan_fused(\n v, f1, hidden, B, L, C, BLOCK_M: tl.constexpr,\n):\n offset_b = tl.program_id(0)\n if offset_b >= B:\n return\n offset_n = tl.program_id(1)\n ptr = tl.arange(0, BLOCK_M) + offset_b * L * C + offset_n * BLOCK_M\n h1 = tl.zeros([BLOCK_M,], dtype=tl.float32)\n for _ in range(L):\n x0 = tl.load(v + ptr).to(tl.float32)\n decay1 = tl.load(f1 + ptr).to(tl.float32)\n decay1 = tl.sigmoid(decay1)\n h1 = (h1 - x0) * decay1 + x0\n tl.store(hidden + ptr, h1.to(hidden.dtype.element_ty))\n ptr += C\n\n# Triton kernel for fused backward sequential scan\n@triton.jit\ndef bwd_sequential_scan_fused(\n grad_output, v, f, h, B, L, C, BLOCK_M: tl.constexpr,\n):\n offset_b = tl.program_id(0)\n if offset_b >= B:\n return\n offset_n = tl.program_id(1)\n ptr = tl.arange(0, BLOCK_M) + offset_b * L * C + (L - 1) * C + offset_n * BLOCK_M\n grad_h = tl.zeros([BLOCK_M,], dtype=tl.float32)\n for time_step in range(L - 1, -1, -1):\n grad = tl.load(grad_output + ptr).to(tl.float32)\n grad_h += grad\n decay = tl.load(f + ptr).to(tl.float32)\n decay = tl.sigmoid(decay)\n input = tl.load(v + ptr).to(tl.float32)\n grad_v = (1 - decay) * grad_h\n tl.store(v + ptr, grad_v.to(v.dtype.element_ty))\n hidden_state = tl.load(h + ptr - C, mask=ptr >= (offset_b * L * C + C), other=0.0).to(tl.float32)\n grad_f = grad_h * (hidden_state - input) * decay * (1 - decay)\n tl.store(f + ptr, grad_f.to(f.dtype.element_ty))\n grad_h *= decay\n ptr -= C\n\nclass TritonSequentialScan(torch.autograd.Function):\n @staticmethod\n @torch.cuda.amp.custom_fwd\n def forward(ctx, v, f1):\n B, L, C = v.shape\n num_warps = 8\n assert C % 256 == 0\n v = v.contiguous()\n f1 = f1.contiguous()\n hidden = torch.zeros_like(v).contiguous()\n fwd_sequential_scan[(B, int(C / 256))](\n v, f1, hidden, B, L, C, BLOCK_M=256, num_warps=num_warps\n )\n ctx.save_for_backward(v, f1, hidden)\n return hidden\n\n @staticmethod\n @torch.cuda.amp.custom_bwd\n def backward(ctx, grad_output):\n v, f1, hidden = ctx.saved_tensors\n B, L, C = v.shape\n num_warps = 8\n bwd_sequential_scan[(B, int(C / 256))](\n grad_output, v, f1, hidden, B, L, C, BLOCK_M=256, num_warps=num_warps\n )\n return v, f1\n\nclass TritonSequentialScanFused(torch.autograd.Function):\n @staticmethod\n @torch.cuda.amp.custom_fwd\n def forward(ctx, v, f1):\n B, L, C = v.shape\n num_warps = 8\n assert C % 256 == 0\n v = v.contiguous()\n f1 = f1.contiguous()\n hidden = torch.zeros_like(v).contiguous()\n fwd_sequential_scan_fused[(B, int(C / 256))](\n v, f1, hidden, B, L, C, BLOCK_M=256, num_warps=num_warps\n )\n ctx.save_for_backward(v, f1, hidden)\n return hidden\n\n @staticmethod\n @torch.cuda.amp.custom_bwd\n def backward(ctx, grad_output):\n v, f1, hidden = ctx.saved_tensors\n B, L, C = v.shape\n num_warps = 8\n bwd_sequential_scan_fused[(B, int(C / 256))](\n grad_output, v, f1, hidden, B, L, C, BLOCK_M=256, num_warps=num_warps\n )\n return v, f1\n\nreal_scan_tie_input_gate = TritonSequentialScan.apply\nreal_scan_tie_input_gate_fused = TritonSequentialScanFused.apply\n", - "description_1": "Use triton language to implement both forward and backward sequential scan operations, with options for using a sigmoid function on the decay factor. The forward function operates with parameters: a 1D tensor `v`, a decay factor tensor `f1`, an output tensor `hidden`, and integers for batch size `B`, sequence length `L`, embedding dimension `C`, and a block size constant `BLOCK_M`. The backward function also takes gradient output tensor `grad_output` and tensor `h`. The parameters `B`, `L`, `C`, and `BLOCK_M` control the data dimensions and block sizes used in processing.", - "description_2": "Use triton language to create sequential scan operations that include triton.jit kernels for both forward and backward passes, supporting operations with and without a sigmoid applied to decay, parameterized by tensor shapes and constants like `BLOCK_M`.", - "difficulty": 4 - }, - { - "code": "import triton\nimport triton.language as tl\nimport torch\n\n# Forward sequential scan kernel (without fusion)\n@triton.jit\ndef fwd_sequential_scan(\n v,\n f1,\n hidden,\n B,\n L,\n C, \n BLOCK_M: tl.constexpr,\n):\n offset_b = tl.program_id(0)\n \n if offset_b >= B:\n return\n\n offset_n = tl.program_id(1)\n ptr = tl.arange(0, BLOCK_M) + offset_b * L * C + offset_n * BLOCK_M \n h1 = tl.zeros([BLOCK_M,], dtype=tl.float32)\n\n for _ in range(L): \n x0 = tl.load(v + ptr).to(tl.float32) \n decay1 = tl.load(f1 + ptr).to(tl.float32)\n h1 = (h1 - x0) * decay1 + x0 # (h1 * decay1 + (1 - decay1) * x0)\n tl.store(hidden + ptr, h1.to(hidden.dtype.element_ty) )\n ptr += C\n\n\n# Forward sequential scan kernel (with fusion)\n@triton.jit\ndef fwd_sequential_scan_fused(\n v,\n f1,\n hidden,\n B,\n L,\n C, \n BLOCK_M: tl.constexpr,\n):\n offset_b = tl.program_id(0)\n \n if offset_b >= B:\n return\n\n offset_n = tl.program_id(1)\n ptr = tl.arange(0, BLOCK_M) + offset_b * L * C + offset_n * BLOCK_M \n h1 = tl.zeros([BLOCK_M,], dtype=tl.float32)\n\n for _ in range(L): \n x0 = tl.load(v + ptr).to(tl.float32) \n decay1 = tl.load(f1 + ptr).to(tl.float32)\n decay1 = tl.sigmoid(decay1)\n h1 = (h1 - x0) * decay1 + x0\n tl.store(hidden + ptr, h1.to(hidden.dtype.element_ty) )\n ptr += C\n\n\n# Backward sequential scan kernel (without fusion)\n@triton.jit\ndef bwd_sequential_scan(\n grad_output,\n v,\n f,\n h,\n B,\n L,\n C, \n BLOCK_M: tl.constexpr,\n):\n offset_b = tl.program_id(0)\n \n if offset_b >= B:\n return\n\n offset_n = tl.program_id(1) \n\n ptr = tl.arange(0, BLOCK_M) + offset_b * L * C + (L-1) * C + offset_n * BLOCK_M\n\n grad_h = tl.zeros([BLOCK_M,], dtype=tl.float32)\n\n for time_step in range(L-1, -1, -1): \n grad = tl.load(grad_output + ptr).to(tl.float32) \n grad_h += grad\n\n decay = tl.load(f + ptr).to(tl.float32)\n input = tl.load(v + ptr).to(tl.float32)\n\n grad_v = (1 - decay) * grad_h\n tl.store(v + ptr, grad_v.to(v.dtype.element_ty))\n\n # TODO: set the last one to h0\n hidden_state = tl.load(h + ptr - C, mask= ptr >= (offset_b * L * C + C), other=0.0).to(tl.float32)\n\n grad_f = grad_h * (hidden_state - input) \n tl.store(f + ptr, grad_f.to(f.dtype.element_ty))\n\n grad_h *= decay \n ptr -= C \n\n\n# Backward sequential scan kernel (with fusion)\n@triton.jit\ndef bwd_sequential_scan_fused(\n grad_output,\n v,\n f,\n h,\n B,\n L,\n C, \n BLOCK_M: tl.constexpr,\n):\n offset_b = tl.program_id(0)\n \n if offset_b >= B:\n return\n\n offset_n = tl.program_id(1) \n\n ptr = tl.arange(0, BLOCK_M) + offset_b * L * C + (L-1) * C + offset_n * BLOCK_M\n\n grad_h = tl.zeros([BLOCK_M,], dtype=tl.float32)\n\n for time_step in range(L-1, -1, -1): \n grad = tl.load(grad_output + ptr).to(tl.float32) \n grad_h += grad\n\n decay = tl.load(f + ptr).to(tl.float32)\n decay = tl.sigmoid(decay)\n input = tl.load(v + ptr).to(tl.float32)\n\n grad_v = (1 - decay) * grad_h\n tl.store(v + ptr, grad_v.to(v.dtype.element_ty))\n\n # TODO: set the last one to h0\n hidden_state = tl.load(h + ptr - C, mask= ptr >= (offset_b * L * C + C), other=0.0).to(tl.float32)\n\n grad_f = grad_h * (hidden_state - input) * decay * (1 - decay)\n tl.store(f + ptr, grad_f.to(f.dtype.element_ty))\n\n grad_h *= decay \n ptr -= C \n\n\n# Example function calls for forward and backward passes\n\nclass TritonSequentialScan(Function):\n @staticmethod\n @torch.cuda.amp.custom_fwd\n def forward(ctx, v, f1):\n B, L, C = v.shape\n num_warps = 8\n assert C % 256 == 0\n v = v.contiguous()\n f1 = f1.contiguous()\n hidden = torch.zeros_like(v).contiguous()\n \n fwd_sequential_scan[(B, int(C/256))](\n v,\n f1,\n hidden,\n B,\n L,\n C, \n BLOCK_M=256,\n num_warps=num_warps\n )\n\n ctx.save_for_backward(v, f1, hidden)\n return hidden\n \n @staticmethod\n @torch.cuda.amp.custom_bwd\n def backward(ctx, grad_output):\n v, f1, hidden = ctx.saved_tensors \n B, L, C = v.shape\n num_warps = 8\n\n bwd_sequential_scan[(B, int(C/256))](\n grad_output, \n v,\n f1,\n hidden,\n B,\n L,\n C, \n BLOCK_M=256,\n num_warps=num_warps\n )\n return v, f1\n\n\nclass TritonSequentialScanFused(Function):\n @staticmethod\n @torch.cuda.amp.custom_fwd\n def forward(ctx, v, f1):\n B, L, C = v.shape\n num_warps = 8\n assert C % 256 == 0\n v = v.contiguous()\n f1 = f1.contiguous()\n hidden = torch.zeros_like(v).contiguous()\n \n fwd_sequential_scan_fused[(B, int(C/256))](\n v,\n f1,\n hidden,\n B,\n L,\n C, \n BLOCK_M=256,\n num_warps=num_warps\n )\n\n ctx.save_for_backward(v, f1, hidden)\n return hidden\n \n @staticmethod\n @torch.cuda.amp.custom_bwd\n def backward(ctx, grad_output):\n v, f1, hidden = ctx.saved_tensors \n B, L, C = v.shape\n num_warps = 8\n\n bwd_sequential_scan_fused[(B, int(C/256))](\n grad_output, \n v,\n f1,\n hidden,\n B,\n L,\n C, \n BLOCK_M=256,\n num_warps=num_warps\n )\n return v, f1\n", - "description_1": "Use Triton language to implement forward and backward sequential scan kernels for a sequence of input vectors (v) and their corresponding decay factors (f1). The kernels are used for processing sequences in parallel and calculating hidden states with respect to each timestep, while storing intermediate results in hidden and gradients in backward pass.", - "description_2": "Use Triton language to compute forward and backward sequential scan of vectors with optional fusion of the decay step in the forward pass for efficiency. The forward and backward steps involve loading vectors, performing calculations, and storing results with parallelism across blocks.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n# Kernel for complex operator with element-wise operations\n@triton.jit\ndef _complex_operator_element_(_x_real_a, _a_imag_a,\n _x_imag_a, _a_real_a, start, num,\n interval, offset_b, offset_n, L, C, last_interval,\n BLOCK_M: tl.constexpr):\n offset_t = tl.program_id(0)\n # Compute the thread index\n range_batch = tl.arange(0, BLOCK_M) + offset_b * L * C + offset_n * BLOCK_M\n range_time = (tl.arange(0, num) * interval + start) * C\n range_2dim = range_batch[:, None] + range_time[None, :]\n # range_2dim = tl.arange(0, BLOCK_M) + offset_b * L * C + offset_n * BLOCK_M + (offset_t * interval + start) * C\n ptr = range_2dim\n ptr_last = range_2dim - last_interval * C\n x_real_a = tl.load(_x_real_a + ptr).to(tl.float32)\n x_real_a_last = tl.load(_x_real_a + ptr_last).to(tl.float32)\n a_imag_a = tl.load(_a_imag_a + ptr).to(tl.float32)\n a_imag_a_last = tl.load(_a_imag_a + ptr_last).to(tl.float32)\n x_imag_a = tl.load(_x_imag_a + ptr).to(tl.float32)\n x_imag_a_last = tl.load(_x_imag_a + ptr_last).to(tl.float32)\n a_real_a = tl.load(_a_real_a + ptr).to(tl.float32)\n a_real_a_last = tl.load(_a_real_a + ptr_last).to(tl.float32)\n x_real_a = x_real_a + a_real_a * x_real_a_last - a_imag_a * x_imag_a_last\n x_imag_a = x_imag_a + a_real_a * x_imag_a_last + a_imag_a * x_real_a_last\n tl.store(_x_real_a + ptr, x_real_a.to(_x_real_a.dtype.element_ty))\n tl.store(_x_imag_a + ptr, x_imag_a.to(_x_imag_a.dtype.element_ty))\n\n a_real_a_next = a_real_a * a_real_a_last - a_imag_a * a_imag_a_last\n a_imag_a_next = a_imag_a * a_real_a_last - a_real_a * a_imag_a_last\n\n tl.store(_a_real_a + ptr, a_real_a_next.to(_a_real_a.dtype.element_ty))\n tl.store(_a_imag_a + ptr, a_imag_a_next.to(_a_imag_a.dtype.element_ty))\n\n\n# Forward pass kernel for complex sequential scan\n@triton.jit\ndef fwd_sequential_scan_complex(\n v_real,\n v_imag,\n decay_real,\n decay_imag,\n hidden_real,\n hidden_imag,\n hidden_real_input,\n hidden_imag_input,\n B,\n L,\n C, \n BLOCK_M: tl.constexpr,\n):\n offset_b = tl.program_id(0)\n \n if offset_b >= B:\n return\n\n offset_n = tl.program_id(1)\n ptr = tl.arange(0, BLOCK_M) + offset_b * L * C + offset_n * BLOCK_M \n ptr_input_hidden = tl.arange(0, BLOCK_M) + offset_b * C + offset_n * BLOCK_M\n\n h_real = tl.load(hidden_real_input + ptr_input_hidden).to(tl.float32)\n h_imag = tl.load(hidden_imag_input + ptr_input_hidden).to(tl.float32)\n\n for _ in range(L): \n x_real = tl.load(v_real + ptr).to(tl.float32) \n x_imag = tl.load(v_imag + ptr).to(tl.float32)\n \n f_real = tl.load(decay_real + ptr).to(tl.float32) \n f_imag = tl.load(decay_imag + ptr).to(tl.float32) \n \n h_real_new = h_real * f_real - h_imag * f_imag + x_real\n h_imag_new = h_real * f_imag + h_imag * f_real + x_imag\n\n tl.store(hidden_real + ptr, h_real_new.to(hidden_real.dtype.element_ty))\n tl.store(hidden_imag + ptr, h_imag_new.to(hidden_imag.dtype.element_ty))\n h_real = h_real_new\n h_imag = h_imag_new\n ptr += C\n\n# Backward pass kernel for complex sequential scan\n@triton.jit\ndef bwd_sequential_scan_complex(\n grad_output_real,\n grad_output_imag,\n v_real,\n v_imag,\n f_real,\n f_imag,\n hidden_real,\n hidden_imag,\n grad_detach,\n B,\n L,\n C, \n BLOCK_M: tl.constexpr,\n):\n offset_b = tl.program_id(0)\n \n if offset_b >= B:\n return\n\n offset_n = tl.program_id(1) \n\n ptr = tl.arange(0, BLOCK_M) + offset_b * L * C + (L-1) * C + offset_n * BLOCK_M\n grad_detach_ptr = grad_detach + offset_b * L + (L - 1)\n grad_h_real = tl.zeros([BLOCK_M,], dtype=tl.float32)\n grad_h_imag = tl.zeros([BLOCK_M,], dtype=tl.float32)\n\n for time_step in range(L-1, -1, -1): # L-1, L-2, ..., 0\n grad_real = tl.load(grad_output_real + ptr).to(tl.float32)\n grad_imag = tl.load(grad_output_imag + ptr).to(tl.float32)\n\n grad_detach_item = tl.load(grad_detach_ptr).to(tl.float32)\n\n grad_h_real = grad_h_real * (1 - grad_detach_item)\n grad_h_imag = grad_h_imag * (1 - grad_detach_item)\n\n grad_h_real += grad_real\n grad_h_imag += grad_imag\n \n decay_real = tl.load(f_real + ptr).to(tl.float32) \n decay_imag = tl.load(f_imag + ptr).to(tl.float32) \n h_real = tl.load(hidden_real + ptr).to(tl.float32)\n h_imag = tl.load(hidden_imag + ptr).to(tl.float32)\n\n grad_f_real = (grad_h_real * h_real + grad_h_imag * h_imag) \n grad_f_imag = (grad_h_imag * h_real - grad_h_real * h_imag) \n\n tl.store(f_real + ptr, grad_f_real.to(f_real.dtype.element_ty)) \n tl.store(f_imag + ptr, grad_f_imag.to(f_real.dtype.element_ty)) \n\n tl.store(v_real + ptr, grad_h_real.to(v_real.dtype.element_ty))\n tl.store(v_imag + ptr, grad_h_imag.to(v_real.dtype.element_ty))\n\n grad_h_real_new = grad_h_real * decay_real + grad_h_imag * decay_imag \n grad_h_imag_new = grad_h_imag * decay_real - grad_h_real * decay_imag\n \n grad_h_real = grad_h_real_new\n grad_h_imag = grad_h_imag_new\n \n ptr -= C\n grad_detach_ptr -= 1\n\n# Wrapper class to call the forward and backward kernels\nclass TritonSequentialScan_Complex(Function):\n @staticmethod\n @torch.cuda.amp.custom_fwd\n def forward(ctx, v_real, v_imag, f_real, f_imag, hidden_real_input, hidden_imag_input, grad_detach):\n B,L,C = v_real.shape\n num_warps = 8\n assert C % 256 == 0, 'Hidden dimension must be multiple of 256'\n v_real = v_real.contiguous()\n v_imag = v_imag.contiguous()\n f_real = f_real.contiguous()\n f_imag = f_imag.contiguous()\n\n hidden_real_input = hidden_real_input.contiguous()\n hidden_imag_input = hidden_imag_input.contiguous()\n\n hidden_real = torch.zeros_like(v_real).contiguous()\n hidden_imag = torch.zeros_like(v_imag).contiguous()\n fwd_sequential_scan_complex[(B, int(C/256))](\n v_real,\n v_imag,\n f_real,\n f_imag,\n hidden_real,\n hidden_imag,\n hidden_real_input,\n hidden_imag_input,\n B,\n L,\n C, \n BLOCK_M=256,\n num_warps=num_warps\n )\n\n ctx.save_for_backward(v_real, v_imag, f_real, f_imag, hidden_real, hidden_imag, hidden_real_input, hidden_imag_input, grad_detach)\n return hidden_real, hidden_imag\n \n @staticmethod\n @torch.cuda.amp.custom_bwd\n def backward(ctx, grad_output_real, grad_output_imag):\n \n v_real, v_imag, f_real, f_imag, hidden_real, hidden_imag, hidden_real_input, hidden_imag_input, grad_detach = ctx.saved_tensors\n B, L, C = v_real.shape\n \n num_warps = 8\n hidden_real = torch.cat((hidden_real_input[..., :1, :], hidden_real[..., :-1, :]), dim=-2)\n hidden_imag = torch.cat((hidden_imag_input[..., :1, :], hidden_imag[..., :-1, :]), dim=-2)\n\n bwd_sequential_scan_complex[(B, int(C/256))](\n grad_output_real, \n grad_output_imag,\n v_real, \n v_imag,\n f_real,\n f_imag,\n hidden_real, \n hidden_imag,\n grad_detach,\n B,\n L,\n C, \n BLOCK_M=256,\n num_warps=num_warps\n )\n return v_real, v_imag, f_real, f_imag, None, None, None\n\n# Function to apply TritonSequentialScan_Complex\ncomplex_scan = TritonSequentialScan_Complex.apply\n", - "description_1": "Use triton language to define three kernels. The first one performs complex element-wise operations on multiple input arrays, adjusting based on provided parameters. The second kernel performs a forward sequential scan of complex numbers over a batch, length, and channel dimensions, updating hidden states. The third kernel computes the backward pass for this sequential scan, adjusting gradients and storing results. These operations leverage Triton's parallel execution for efficiency. A PyTorch function class wraps these kernels for forward and backward GPU computations, applied as complex_scan.", - "description_2": "Use triton language to implement kernels for complex element operations, forward sequential scan, and backward gradient computation, utilizing PyTorch's autograd Function for efficient GPU execution.", - "difficulty": 4 - }, - { - "code": "import triton\nimport triton.language as tl\nimport torch\n\n@triton.jit\ndef _complex_operator_fix_batch_tl_(_x_real_a, _x_real_a_last, _a_imag_a, _a_imag_a_last,\n _x_imag_a, _x_imag_a_last, _a_real_a, _a_real_a_last, mask, B,\n BLOCK_M: tl.constexpr):\n # Compute the thread index\n idx = tl.program_id(0)\n if idx < B:\n ptr = tl.arange(0, BLOCK_M) + idx * BLOCK_M\n x_real_a = tl.load(_x_real_a + ptr).to(tl.float32)\n x_real_a_last = tl.load(_x_real_a_last + ptr).to(tl.float32)\n a_imag_a = tl.load(_a_imag_a + ptr).to(tl.float32)\n a_imag_a_last = tl.load(_a_imag_a_last + ptr).to(tl.float32)\n x_imag_a = tl.load(_x_imag_a + ptr).to(tl.float32)\n x_imag_a_last = tl.load(_x_imag_a_last + ptr).to(tl.float32)\n a_real_a = tl.load(_a_real_a + ptr).to(tl.float32)\n a_real_a_last = tl.load(_a_real_a_last + ptr).to(tl.float32)\n mask_a = tl.load(mask + ptr).to(tl.float32)\n\n x_real_a = (x_real_a + a_real_a * x_real_a_last - a_imag_a * x_imag_a_last) * mask_a\n x_imag_a = (x_imag_a + a_real_a * x_imag_a_last + a_imag_a * x_real_a_last) * mask_a\n\n tl.store(_x_real_a + ptr, x_real_a.to(_x_real_a.dtype.element_ty))\n tl.store(_x_imag_a + ptr, x_imag_a.to(_x_imag_a.dtype.element_ty))\n\n a_real_a_next = (a_real_a * a_real_a_last - a_imag_a * a_imag_a_last) * mask_a + (1 - mask_a) * a_real_a\n a_imag_a_next = (a_imag_a * a_real_a_last - a_real_a * a_imag_a_last) * mask_a + (1 - mask_a) * a_imag_a\n\n tl.store(_a_real_a + ptr, a_real_a_next.to(_a_real_a.dtype.element_ty))\n tl.store(_a_imag_a + ptr, a_imag_a_next.to(_a_imag_a.dtype.element_ty))\n\ndef _complex_operator_fix_batch_tl(_x_real_a: torch.Tensor, _x_real_a_last: torch.Tensor, _a_imag_a: torch.Tensor, _a_imag_a_last: torch.Tensor,\n _x_imag_a: torch.Tensor, _x_imag_a_last: torch.Tensor, _a_real_a: torch.Tensor, _a_real_a_last: torch.Tensor, mask: torch.Tensor):\n B = _x_real_a.shape[0]\n _complex_operator_fix_batch_tl_[(B,)](\n _x_real_a, _x_real_a_last, _a_imag_a, _a_imag_a_last,\n _x_imag_a, _x_imag_a_last, _a_real_a, _a_real_a_last, mask, B, 256, num_warps=8\n )\n pass\n", - "description_1": "Use triton language to implement a kernel function '_complex_operator_fix_batch_tl_' that performs complex arithmetic operations on batches of tensors. The kernel takes 9 parameters: 8 tensors representing real and imaginary parts of complex numbers and a mask tensor, and a constant BLOCK_M. The function computes new values for the real and imaginary parts of the complex numbers based on the mask and stores the results back into the input tensors.", - "description_2": "Use triton language to perform batch-wise complex arithmetic operations with masking on input tensors.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef fwd_recurrence(\n A,\n B,\n C,\n Dt,\n X,\n Y,\n H,\n start,\n initial_state,\n T: tl.constexpr,\n D: tl.constexpr,\n K: tl.constexpr,\n BV: tl.constexpr,\n):\n i_bh = tl.program_id(0)\n i_v = tl.program_id(1)\n\n dt_ptr = Dt + i_bh * T * D + i_v * BV + tl.arange(0, BV)\n u_ptr = X + i_bh * T * D + i_v * BV + tl.arange(0, BV)\n o_ptr = Y + i_bh * T * D + i_v * BV + tl.arange(0, BV)\n start_ptr = start + i_bh * T\n h = tl.zeros([BV, K], dtype=tl.float32)\n\n b_ptr = B + i_bh * T * K + tl.arange(0, K)\n\n A = A + ((i_v * BV) + tl.arange(0, BV)\n [:, None]) * K + tl.arange(0, K)[None, :]\n _A = tl.load(A)\n\n H_ptr = H + i_bh * T * D * K + \\\n (i_v * BV + tl.arange(0, BV)[:, None]) * K + tl.arange(0, K)[None, :]\n\n h += tl.load(initial_state + i_bh * D * K + (i_v * BV +\n tl.arange(0, BV)[:, None]) * K + tl.arange(0, K)[None, :])\n\n for i in range(T):\n b = tl.load(b_ptr).to(tl.float32)\n dt = tl.load(dt_ptr)\n start_flag = tl.load(start_ptr).to(tl.float32)\n u = tl.load(u_ptr)\n x_dt = u * dt\n x_dt_b = x_dt[:, None] * b[None, :]\n dt_a = tl.exp(dt[:, None] * _A) * (1 - start_flag)\n h = h * dt_a + x_dt_b\n tl.store(H_ptr, h)\n\n b_ptr += K\n dt_ptr += D\n start_ptr += 1\n u_ptr += D\n o_ptr += D\n H_ptr += D * K\n\n@triton.jit\ndef bwd_recurrence(\n A,\n B,\n C,\n U,\n Dt,\n DO,\n H,\n start,\n DA,\n DB,\n DC,\n dDt,\n dU,\n batch,\n initial_state,\n grad_detach,\n T: tl.constexpr,\n D: tl.constexpr,\n K: tl.constexpr,\n BV: tl.constexpr,\n):\n i_bh = tl.program_id(0)\n i_v = tl.program_id(1)\n NV = tl.cdiv(D, BV)\n\n dt_ptr = Dt + i_bh * T * D + i_v * BV + tl.arange(0, BV) + (T - 1) * D\n ddt_ptr = dDt + i_bh * T * D + i_v * BV + tl.arange(0, BV) + (T - 1) * D\n u_ptr = U + i_bh * T * D + i_v * BV + tl.arange(0, BV) + (T - 1) * D\n du_ptr = dU + i_bh * T * D + i_v * BV + tl.arange(0, BV) + (T - 1) * D\n do_ptr = DO + i_bh * T * D + i_v * BV + tl.arange(0, BV) + (T - 1) * D\n\n start_ptr = start + i_bh * T + (T-1)\n grad_detach_ptr = grad_detach + i_bh * T + (T-1)\n\n dh = tl.zeros([BV, K], dtype=tl.float32)\n dA = tl.zeros([BV, K], dtype=tl.float32)\n\n b_ptr = B + i_bh * T * K + tl.arange(0, K) + (T - 1) * K\n c_ptr = C + i_bh * T * K + tl.arange(0, K) + (T - 1) * K\n dc_ptr = DC + (i_bh + batch * i_v) * T * K + tl.arange(0, K) + (T - 1) * K\n db_ptr = DB + (i_bh + batch * i_v) * T * K + tl.arange(0, K) + (T - 1) * K\n\n A = A + ((i_v * BV) + tl.arange(0, BV)\n [:, None]) * K + tl.arange(0, K)[None, :]\n _A = tl.load(A)\n H_ptr = H + i_bh * T * D * K + \\\n (i_v * BV + tl.arange(0, BV)[:, None]) * K + \\\n tl.arange(0, K)[None, :] + (T - 1) * D * K\n\n for i in range(T):\n h = tl.load(H_ptr)\n if i < T - 1:\n next_h = tl.load(H_ptr - D * K)\n else:\n next_h = tl.load(\n initial_state + i_bh * D * K + (i_v * BV + tl.arange(0, BV)[:, None]) * K + tl.arange(0, K)[None, :])\n b = tl.load(b_ptr).to(tl.float32)\n c = tl.load(c_ptr).to(tl.float32)\n do = tl.load(do_ptr).to(tl.float32)\n u = tl.load(u_ptr).to(tl.float32)\n dt = tl.load(dt_ptr).to(tl.float32)\n start_flag = tl.load(start_ptr).to(tl.float32)\n grad_detach_flag = tl.load(grad_detach_ptr).to(tl.float32)\n # detach grad here\n dh = dh * (1 - grad_detach_flag)\n # dA = dA * (1 - grad_detach_flag)\n # gradient wrt output proj\n dc = tl.sum(h * do[:, None], axis=0)\n tl.store(dc_ptr, dc)\n\n # graident wrt input\n dh += do[:, None] * c[None, :]\n dt_u = dt * u\n db = tl.sum(dh * dt_u[:, None], axis=0)\n tl.store(db_ptr, db)\n ddt_u = tl.sum(dh * b[None, :], axis=1)\n ddt = ddt_u * u\n du = ddt_u * dt\n tl.store(du_ptr, du)\n\n # gradient wrt decay\n dt_a = tl.exp(dt[:, None] * _A) * (1 - start_flag)\n dh *= dt_a\n\n d_decay = dh * next_h\n dA += d_decay * dt[:, None]\n ddt += tl.sum(d_decay * _A, axis=1)\n tl.store(ddt_ptr, ddt)\n\n # update ptr\n b_ptr -= K\n c_ptr -= K\n dc_ptr -= K\n db_ptr -= K\n dt_ptr -= D\n ddt_ptr -= D\n u_ptr -= D\n du_ptr -= D\n do_ptr -= D\n H_ptr -= D * K\n start_ptr -= 1\n grad_detach_ptr -= 1\n\n DA_ptr = DA + i_bh * D * K + \\\n (i_v * BV + tl.arange(0, BV)[:, None]) * K + tl.arange(0, K)[None, :]\n tl.store(DA_ptr, dA)\n\n\nclass SelectiveScan(torch.autograd.Function):\n\n @staticmethod\n def forward(ctx, u, delta, A, B, C, start, grad_detach, initial_state=None):\n \"\"\"\n u: shape (b, l, d_in) (See Glossary at top for definitions of b, l, d_in, n...)\n delta: shape (b, l, d_in)\n A: shape (d_in, n)\n B: shape (b, l, n)\n C: shape (b, l, n)\n D: shape (d_in,)\n start: (b, l, 1)\n \"\"\"\n b_size, T, d = u.shape\n K = B.shape[-1]\n\n ctx.b_size = b_size\n ctx.T = T\n ctx.d = d\n ctx.K = K\n BV = 64\n num_warps = 4\n\n if b_size <= 16:\n BV = 32\n num_warps = 2\n\n NV = triton.cdiv(d, BV)\n\n o = torch.empty_like(u)\n H = torch.empty(b_size, T, d, K, device=u.device, dtype=torch.float32)\n\n if initial_state is None:\n initial_state = torch.zeros(\n b_size, d, K, device=u.device, dtype=torch.float32)\n A = A.contiguous()\n B = B.contiguous()\n C = C.contiguous()\n delta = delta.contiguous()\n u = u.contiguous()\n o = o.contiguous()\n H = H.contiguous()\n start = start.contiguous()\n initial_state = initial_state.contiguous()\n grad_detach = grad_detach.contiguous()\n fwd_recurrence[(b_size, NV)](A, B, C, delta, u, o, H, start,\n initial_state, T, d, K, BV, num_warps=num_warps, num_stages=1)\n o = reduce(H, C)\n ctx.save_for_backward(A, B, C, delta, H, u, start, grad_detach)\n ctx.initial_state = initial_state\n return o, H[:, -1]\n\n @staticmethod\n def backward(ctx, grad_output, d_final_state):\n do = grad_output\n A, B, C, delta, H, u, start, grad_detach = ctx.saved_tensors\n b_size = ctx.b_size\n T = ctx.T\n d = ctx.d\n K = ctx.K\n\n BV = 64\n num_warps = 4\n\n if b_size <= 16:\n BV = 32\n num_warps = 2\n\n NV = triton.cdiv(d, BV)\n dA = A.new_empty(b_size, d, K)\n du = torch.empty_like(u)\n d_delta = torch.empty_like(delta)\n db = B.new_empty(NV, b_size, T, K)\n dc = C.new_empty(NV, b_size, T, K)\n\n bwd_recurrence[(b_size, NV)](A, B, C, u, delta, do, H, start,\n dA, db, dc,\n d_delta, du, b_size, ctx.initial_state, grad_detach, T, d, K, BV, num_warps=num_warps)\n # dA = dA / valid_num\n db = db.sum(0)\n dc = dc.sum(0)\n\n return du, d_delta, dA.sum(0), db, dc, None, None, None\n\n\ndef triton_selective_scan_sequential(u, delta, A, B, C, D, start, grad_detach, initial_state=None):\n \"\"\"\n u: shape (b, l, d_in) (See Glossary at top for definitions of b, l, d_in, n...)\n delta: shape (b, l, d_in)\n A: shape (d_in, n)\n B: shape (b, l, n)\n C: shape (b, l, n)\n D: shape (d_in,)\n start: (b, l, 1)\n \"\"\"\n original_dtype = u.dtype\n D = D.float()\n A = A.float()\n if initial_state is not None:\n initial_state = initial_state.detach()\n o, final_state = SelectiveScan.apply(u, delta, A, B, C, start, grad_detach, initial_state)\n o = o + D * u\n return o.to(original_dtype), final_state\n", - "description_1": "Use triton language to implement two kernels: fwd_recurrence and bwd_recurrence, and a function SelectiveScan for a selective scan operation. fwd_recurrence has 13 parameters for forward data manipulation including tensors and constants. bwd_recurrence has 17 parameters for backward propagation including gradients. The SelectiveScan class performs forward and backward operations using these kernels, with 8 forward parameters including input tensors and flags, and backward parameters retrieved from saved tensors and gradients.", - "description_2": "Use triton language to implement forward and backward recurrence operations for a selective scan with 13 parameters for fwd_recurrence and 17 parameters for bwd_recurrence, incorporating forward data manipulation and backward gradient propagation.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n\n@triton.autotune(\n configs=[\n triton.Config({}, num_warps=1),\n triton.Config({}, num_warps=2),\n triton.Config({}, num_warps=4),\n triton.Config({}, num_warps=8),\n triton.Config({}, num_warps=16),\n triton.Config({}, num_warps=32),\n ],\n key=[\"N\", \"HAS_RESIDUAL\", \"STORE_RESIDUAL_OUT\", \"IS_RMS_NORM\", \"HAS_BIAS\"],\n)\n@triton.jit\ndef _layer_norm_fwd_1pass_kernel(\n X, Y, W, B, RESIDUAL, RESIDUAL_OUT, Mean, Rstd,\n stride_x_row, stride_y_row, stride_res_row, stride_res_out_row, N,\n eps, IS_RMS_NORM: tl.constexpr, BLOCK_N: tl.constexpr, HAS_RESIDUAL: tl.constexpr,\n STORE_RESIDUAL_OUT: tl.constexpr, HAS_BIAS: tl.constexpr,\n):\n row = tl.program_id(0)\n X += row * stride_x_row\n Y += row * stride_y_row\n if HAS_RESIDUAL:\n RESIDUAL += row * stride_res_row\n if STORE_RESIDUAL_OUT:\n RESIDUAL_OUT += row * stride_res_out_row\n cols = tl.arange(0, BLOCK_N)\n x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)\n if HAS_RESIDUAL:\n residual = tl.load(RESIDUAL + cols, mask=cols < N, other=0.0).to(tl.float32)\n x += residual\n if STORE_RESIDUAL_OUT:\n tl.store(RESIDUAL_OUT + cols, x, mask=cols < N)\n if not IS_RMS_NORM:\n mean = tl.sum(x, axis=0) / N\n tl.store(Mean + row, mean)\n xbar = tl.where(cols < N, x - mean, 0.0)\n var = tl.sum(xbar * xbar, axis=0) / N\n else:\n xbar = tl.where(cols < N, x, 0.0)\n var = tl.sum(xbar * xbar, axis=0) / N\n rstd = 1 / tl.sqrt(var + eps)\n tl.store(Rstd + row, rstd)\n mask = cols < N\n w = tl.load(W + cols, mask=mask).to(tl.float32)\n if HAS_BIAS:\n b = tl.load(B + cols, mask=mask).to(tl.float32)\n x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd\n y = x_hat * w + b if HAS_BIAS else x_hat * w\n tl.store(Y + cols, y, mask=mask)\n\n\ndef _layer_norm_fwd(\n x, weight, bias, eps, residual=None, out_dtype=None, residual_dtype=None, is_rms_norm=False\n):\n if residual is not None:\n residual_dtype = residual.dtype\n M, N = x.shape\n assert x.stride(-1) == 1\n if residual is not None:\n assert residual.stride(-1) == 1\n assert residual.shape == (M, N)\n assert weight.shape == (N,)\n assert weight.stride(-1) == 1\n if bias is not None:\n assert bias.stride(-1) == 1\n assert bias.shape == (N,)\n y = torch.empty_like(x, dtype=x.dtype if out_dtype is None else out_dtype)\n assert y.stride(-1) == 1\n if residual is not None or (residual_dtype is not None and residual_dtype != x.dtype):\n residual_out = torch.empty(M, N, device=x.device, dtype=residual_dtype)\n assert residual_out.stride(-1) == 1\n else:\n residual_out = None\n mean = torch.empty((M,), dtype=torch.float32, device=x.device) if not is_rms_norm else None\n rstd = torch.empty((M,), dtype=torch.float32, device=x.device)\n MAX_FUSED_SIZE = 65536 // x.element_size()\n BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))\n if N > BLOCK_N:\n raise RuntimeError(\"This layer norm doesn't support feature dim >= 64KB.\")\n with torch.cuda.device(x.device.index):\n _layer_norm_fwd_1pass_kernel[(M,)](\n x, y, weight, bias, residual, residual_out,\n mean, rstd, x.stride(0), y.stride(0),\n residual.stride(0) if residual is not None else 0,\n residual_out.stride(0) if residual_out is not None else 0,\n N, eps, is_rms_norm, BLOCK_N, residual is not None, \n residual_out is not None, bias is not None\n )\n return y, mean, rstd, residual_out if residual_out is not None else x\n", - "description_1": "Use triton language to implement a forward pass for layer normalization. The function _layer_norm_fwd_1pass_kernel has 18 parameters: X, Y, W, B, RESIDUAL, RESIDUAL_OUT, Mean, Rstd are pointers for input, output, weights, biases, residuals, residual output, mean, and reciprocal of the standard deviation; stride_x_row, stride_y_row, stride_res_row, stride_res_out_row are the stride increments for each row in corresponding arrays; N is the number of columns in X; eps is a small epsilon value for numerical stability; IS_RMS_NORM, BLOCK_N, HAS_RESIDUAL, STORE_RESIDUAL_OUT, HAS_BIAS are compile-time constants. The function _layer_norm_fwd is a Python function that uses _layer_norm_fwd_1pass_kernel to perform layer normalization on input tensor x.", - "description_2": "Use triton language to create a kernel for layer normalization, handling input data with optional residuals and biases, to output normalized data with optional mean and variance storage.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef _selective_scan_update_kernel(\n # Pointers to matrices\n state_ptr, x_ptr, dt_ptr, dt_bias_ptr, A_ptr, B_ptr, C_ptr, D_ptr, z_ptr, out_ptr,\n # Matrix dimensions\n batch, dim, dstate,\n # Strides\n stride_state_batch, stride_state_dim, stride_state_dstate,\n stride_x_batch, stride_x_dim,\n stride_dt_batch, stride_dt_dim,\n stride_dt_bias_dim,\n stride_A_dim, stride_A_dstate,\n stride_B_batch, stride_B_dstate,\n stride_C_batch, stride_C_dstate,\n stride_D_dim,\n stride_z_batch, stride_z_dim,\n stride_out_batch, stride_out_dim,\n # Meta-parameters\n DT_SOFTPLUS: tl.constexpr,\n BLOCK_SIZE_M: tl.constexpr,\n HAS_DT_BIAS: tl.constexpr,\n HAS_D: tl.constexpr,\n HAS_Z: tl.constexpr,\n BLOCK_SIZE_DSTATE: tl.constexpr,\n):\n pid_m = tl.program_id(axis=0)\n pid_b = tl.program_id(axis=1)\n state_ptr += pid_b * stride_state_batch\n x_ptr += pid_b * stride_x_batch\n dt_ptr += pid_b * stride_dt_batch\n B_ptr += pid_b * stride_B_batch\n C_ptr += pid_b * stride_C_batch\n if HAS_Z:\n z_ptr += pid_b * stride_z_batch\n out_ptr += pid_b * stride_out_batch\n\n offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_n = tl.arange(0, BLOCK_SIZE_DSTATE)\n state_ptrs = state_ptr + (offs_m[:, None] * stride_state_dim + offs_n[None, :] * stride_state_dstate)\n x_ptrs = x_ptr + offs_m * stride_x_dim\n dt_ptrs = dt_ptr + offs_m * stride_dt_dim\n if HAS_DT_BIAS:\n dt_bias_ptrs = dt_bias_ptr + offs_m * stride_dt_bias_dim\n A_ptrs = A_ptr + (offs_m[:, None] * stride_A_dim + offs_n[None, :] * stride_A_dstate)\n B_ptrs = B_ptr + offs_n * stride_B_dstate\n C_ptrs = C_ptr + offs_n * stride_C_dstate\n if HAS_D:\n D_ptrs = D_ptr + offs_m * stride_D_dim\n if HAS_Z:\n z_ptrs = z_ptr + offs_m * stride_z_dim\n out_ptrs = out_ptr + offs_m * stride_out_dim\n\n state = tl.load(state_ptrs, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate), other=0.0)\n x = tl.load(x_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)\n dt = tl.load(dt_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)\n if HAS_DT_BIAS:\n dt += tl.load(dt_bias_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)\n if DT_SOFTPLUS:\n dt = tl.where(dt <= 20.0, tl.math.log1p(tl.exp(dt)), dt)\n A = tl.load(A_ptrs, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate), other=0.0).to(tl.float32)\n dA = tl.exp(A * dt[:, None])\n B = tl.load(B_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32)\n C = tl.load(C_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32)\n if HAS_D:\n D = tl.load(D_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)\n if HAS_Z:\n z = tl.load(z_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)\n\n dB = B[None, :] * dt[:, None]\n state = state * dA + dB * x[:, None]\n tl.store(state_ptrs, state, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate))\n out = tl.sum(state * C[None, :], axis=1)\n if HAS_D:\n out += x * D\n if HAS_Z:\n out *= z * tl.sigmoid(z)\n tl.store(out_ptrs, out, mask=offs_m < dim)\n\n\ndef selective_state_update(state, x, dt, A, B, C, D=None, z=None, dt_bias=None, dt_softplus=False):\n \"\"\"\n Argument:\n state: (batch, dim, dstate)\n x: (batch, dim)\n dt: (batch, dim)\n A: (dim, dstate)\n B: (batch, dstate)\n C: (batch, dstate)\n D: (dim,)\n z: (batch, dim)\n dt_bias: (dim,)\n Return:\n out: (batch, dim)\n \"\"\"\n batch, dim, dstate = state.shape\n assert x.shape == (batch, dim)\n assert dt.shape == x.shape\n assert A.shape == (dim, dstate)\n assert B.shape == (batch, dstate)\n assert C.shape == B.shape\n if D is not None:\n assert D.shape == (dim,)\n if z is not None:\n assert z.shape == x.shape\n if dt_bias is not None:\n assert dt_bias.shape == (dim,)\n out = torch.empty_like(x)\n grid = lambda META: (triton.cdiv(dim, META['BLOCK_SIZE_M']), batch)\n z_strides = ((z.stride(0), z.stride(1)) if z is not None else (0, 0))\n # We don't want autotune since it will overwrite the state\n # We instead tune by hand.\n BLOCK_SIZE_M, num_warps = ((32, 4) if dstate <= 16\n else ((16, 4) if dstate <= 32 else\n ((8, 4) if dstate <= 64 else\n ((4, 4) if dstate <= 128 else\n ((4, 8))))))\n with torch.cuda.device(x.device.index):\n _selective_scan_update_kernel[grid](\n state, x, dt, dt_bias, A, B, C, D, z, out,\n batch, dim, dstate,\n state.stride(0), state.stride(1), state.stride(2),\n x.stride(0), x.stride(1),\n dt.stride(0), dt.stride(1),\n dt_bias.stride(0) if dt_bias is not None else 0,\n A.stride(0), A.stride(1),\n B.stride(0), B.stride(1),\n C.stride(0), C.stride(1),\n D.stride(0) if D is not None else 0,\n z_strides[0], z_strides[1],\n out.stride(0), out.stride(1),\n dt_softplus,\n BLOCK_SIZE_M,\n num_warps=num_warps,\n )\n return out\n", - "description_1": "Use triton language to implement a kernel function '_selective_scan_update_kernel' with 34 parameters for updating state matrices with optional bias and scaling, and a wrapper function 'selective_state_update' with 10 parameters to manage data preparation and kernel invocation.", - "description_2": "Use triton language to create a kernel for selective state update with optional bias and scaling, and a wrapper function to handle data and invoke the kernel.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef gather_transposed_gemv_flag_atomicadd_kernel(\n Y, # Pointers to matrices\n A,\n X,\n IDX,\n # Matrix dimensions\n M,\n N,\n CACHE_KEY_M,\n CACHE_KEY_N,\n # The stride variables represent how much to increase the ptr by when moving by 1\n # element in a particular dimension. E.g. stride_am is how much to increase a_ptr\n # by to get the element one row down (A has M rows)\n stride_am,\n # Meta-parameters\n BATCHSIZE: tl.constexpr,\n SPARSITY_BIN: tl.constexpr,\n BLOCK_M: tl.constexpr,\n BLOCK_N: tl.constexpr,\n):\n \"\"\"\n Kernel for computing Y = A[IDX, :]^T @ X + BIAS, where A is a dense matrix\n with Z rows and N columns. We also batch across the batch dimension of the input X.\n We will not check that the indices are valid, for performance reason.\n - Input X has shape (BATCHSIZE, M)\n - Weight has shape (Z, N)\n - IDX has shape (M), where M is the number of non-zero rows in A\n - Bias has shape (N)\n - Output has shape (BATCHSIZE, N)\n \"\"\"\n start_m = tl.program_id(0)\n start_n = tl.program_id(1)\n # now compute the block that each program will go through\n # rm (resp. rn) denotes a range of indices for rows (resp. col) of A\n rm = start_m * BLOCK_M + tl.arange(0, BLOCK_M)\n rn = start_n * BLOCK_N + tl.arange(0, BLOCK_N)\n\n IDX = IDX + rm\n idx = tl.load(IDX, mask=rm < M, other=0) > 0\n A = A + (rm[:, None] * stride_am + rn[None, :])\n X = X + rm\n Y = Y + rn\n \n if BATCHSIZE == 1:\n a = tl.load(A, mask=idx[:, None], other=0.0)\n x0 = tl.load(X)#, mask=idx, other=0.0) # if flag_gemv is correct, this will be unnecessary.\n acc0 = tl.sum(a.to(tl.float32) * x0.to(tl.float32)[:, None], 0)\n\n # rematerialize rm and rn to save registers\n rn = start_n * BLOCK_N + tl.arange(0, BLOCK_N)\n tl.atomic_add(Y, acc0, mask=rn < N)\n \ndef gather_transposed_gemv_flag_3d(\n x: torch.Tensor,\n weight: torch.Tensor,\n idx: torch.Tensor,\n sparsity_bin: int\n) -> torch.Tensor:\n \"\"\"\n Compute y = weight[idx, :]^T @ x.\n :param x: input tensor\n :param weight: weight matrix\n :param idx: indices\n :return: result tensor\n \"\"\"\n Z, N = weight.shape\n beam_width, seq_len, _ = x.shape\n assert x.shape[2] == Z\n x = x.contiguous()\n if weight.stride(1) > 1:\n weight = weight.contiguous()\n\n output = torch.empty(\n beam_width,\n seq_len,\n N,\n device=x.device,\n dtype=torch.float32,\n )\n\n # 1D launch kernel where each block gets its own program.\n grid = lambda META: (\n triton.cdiv(Z, META[\"BLOCK_M\"]),\n triton.cdiv(N, META[\"BLOCK_N\"]),\n ) # noqa\n\n kernel = gather_transposed_gemv_flag_atomicadd_kernel\n kernel[grid](\n output, # data ptrs\n weight,\n x,\n idx,\n Z, # shapes\n N,\n Z // 128, # key for triton cache (limit number of compilations)\n N // 32,\n weight.stride(0), # strides\n beam_width, # can't use kwargs because auto-tuner requires args\n sparsity_bin,\n )\n return output# .to(dtype=weight.dtype)\n", - "description_1": "Use triton language to implement a kernel that computes Y = A[IDX, :]^T @ X + BIAS for a dense matrix A with Z rows and N columns. The kernel takes pointers to matrices Y, A, X, and IDX, matrix dimensions M and N, cache keys CACHE_KEY_M and CACHE_KEY_N, stride_am for pointer increment, and meta-parameters BATCHSIZE, SPARSITY_BIN, BLOCK_M, and BLOCK_N. The kernel is called by a function that computes y = weight[idx, :]^T @ x, where x is the input tensor, weight is the weight matrix, idx is the indices, and sparsity_bin is an integer representing the sparsity bin.", - "description_2": "Use triton language to create a matrix multiplication kernel with sparsity support, and a function to call this kernel for computing the product of a transposed indexed matrix and an input tensor.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n# Triton kernel for sparse GEMV\n@triton.autotune(\n configs=[\n triton.Config({\"BLOCK_M\": 64, \"BLOCK_N\": 128}, num_warps=2),\n triton.Config({\"BLOCK_M\": 64, \"BLOCK_N\": 64}, num_warps=4),\n triton.Config({\"BLOCK_M\": 8, \"BLOCK_N\": 128}, num_warps=2),\n triton.Config({\"BLOCK_M\": 16, \"BLOCK_N\": 256}, num_warps=4),\n triton.Config({\"BLOCK_M\": 32, \"BLOCK_N\": 256}, num_warps=4),\n triton.Config({\"BLOCK_M\": 64, \"BLOCK_N\": 256}, num_warps=4),\n triton.Config({\"BLOCK_M\": 128, \"BLOCK_N\": 16}, num_warps=4),\n triton.Config({\"BLOCK_M\": 128, \"BLOCK_N\": 32}, num_warps=4),\n triton.Config({\"BLOCK_M\": 128, \"BLOCK_N\": 64}, num_warps=4),\n triton.Config({\"BLOCK_M\": 128, \"BLOCK_N\": 128}, num_warps=4),\n triton.Config({\"BLOCK_M\": 128, \"BLOCK_N\": 256}, num_warps=4),\n triton.Config({\"BLOCK_M\": 128, \"BLOCK_N\": 512}, num_warps=4),\n triton.Config({\"BLOCK_M\": 64, \"BLOCK_N\": 512}, num_warps=4),\n triton.Config({\"BLOCK_M\": 32, \"BLOCK_N\": 512}, num_warps=4),\n triton.Config({\"BLOCK_M\": 16, \"BLOCK_N\": 512}, num_warps=4),\n ],\n key=[\"CACHE_KEY_M\", \"CACHE_KEY_N\", \"BATCHSIZE\", \"SPARSITY_BIN\"],\n)\n@triton.jit\ndef splitk_sparse_gemv_kernel(\n Y, # Pointers to matrices\n A, X, threshold,\n N, M,\n CACHE_KEY_N, CACHE_KEY_M,\n BATCHSIZE: tl.constexpr, SPARSITY_BIN: tl.constexpr,\n BLOCK_N: tl.constexpr, BLOCK_M: tl.constexpr,\n):\n start_n = tl.program_id(0)\n start_m = tl.program_id(1)\n rn = start_n * BLOCK_N + tl.arange(0, BLOCK_N)\n rm = start_m * BLOCK_M + tl.arange(0, BLOCK_M)\n A_ptr = A + (rm[:, None] * N + rn[None, :])\n X_ptr = X + rm\n Y_ptr = Y + rn\n\n if BATCHSIZE == 1:\n x0 = tl.load(X_ptr, mask=rm < M, other=0.0, eviction_policy='evict_last')\n idx = tl.abs(x0) > threshold\n a = tl.load(A_ptr, mask=idx[:, None], other=0.0, eviction_policy='evict_first')\n acc0 = tl.sum(a.to(tl.float32) * x0.to(tl.float32)[:, None], 0)\n rn = start_n * BLOCK_N + tl.arange(0, BLOCK_N)\n tl.atomic_add(Y_ptr, acc0, mask=rn < N)\n\ndef splitk_sparse_gemv(\n x: torch.Tensor,\n weight: torch.Tensor,\n threshold: float,\n sparsity_bin: int\n) -> torch.Tensor:\n N, Z = weight.shape\n beam_width, seq_len, _ = x.shape\n assert x.shape[2] == Z\n x = x.contiguous()\n assert weight.stride(1) > 1, \"weight should be column major\"\n grid = lambda META: (\n triton.cdiv(N, META[\"BLOCK_N\"]),\n triton.cdiv(Z, META[\"BLOCK_M\"]),\n )\n output = torch.empty(\n beam_width,\n seq_len,\n N,\n device=x.device,\n dtype=torch.float16,\n )\n splitk_sparse_gemv_kernel[grid](\n output, weight, x, threshold,\n N, Z,\n N // 16, Z // 16,\n beam_width, sparsity_bin,\n )\n if x.dtype is not output.dtype:\n print(f\"Warning: incurring dtype conversion overhead since input dtype is not torch.float16. Detected dtype: {x.dtype}. \")\n return output.to(dtype=x.dtype)\n return output\n\n@triton.autotune(\n configs=[\n triton.Config({\"BLOCK_M\": 64, \"BLOCK_N\": 128}, num_warps=2),\n triton.Config({\"BLOCK_M\": 64, \"BLOCK_N\": 64}, num_warps=4),\n triton.Config({\"BLOCK_M\": 8, \"BLOCK_N\": 128}, num_warps=2),\n triton.Config({\"BLOCK_M\": 16, \"BLOCK_N\": 256}, num_warps=4),\n triton.Config({\"BLOCK_M\": 32, \"BLOCK_N\": 256}, num_warps=4),\n triton.Config({\"BLOCK_M\": 64, \"BLOCK_N\": 256}, num_warps=4),\n triton.Config({\"BLOCK_M\": 128, \"BLOCK_N\": 16}, num_warps=4),\n triton.Config({\"BLOCK_M\": 128, \"BLOCK_N\": 32}, num_warps=4),\n triton.Config({\"BLOCK_M\": 128, \"BLOCK_N\": 64}, num_warps=4),\n triton.Config({\"BLOCK_M\": 128, \"BLOCK_N\": 128}, num_warps=4),\n triton.Config({\"BLOCK_M\": 128, \"BLOCK_N\": 256}, num_warps=4),\n triton.Config({\"BLOCK_M\": 128, \"BLOCK_N\": 512}, num_warps=4),\n triton.Config({\"BLOCK_M\": 64, \"BLOCK_N\": 512}, num_warps=4),\n triton.Config({\"BLOCK_M\": 32, \"BLOCK_N\": 512}, num_warps=4),\n triton.Config({\"BLOCK_M\": 16, \"BLOCK_N\": 512}, num_warps=4),\n ],\n key=[\"CACHE_KEY_M\", \"CACHE_KEY_N\", \"BATCHSIZE\", \"SPARSITY_BIN\"],\n)\n@triton.jit\ndef qkv_kernel(\n Y, # Pointers to output matrices\n A, \n X, \n threshold_q, threshold_k, threshold_v,\n N, N_q, N_kv, M,\n CACHE_KEY_N, CACHE_KEY_M,\n BATCHSIZE: tl.constexpr, SPARSITY_BIN: tl.constexpr,\n BLOCK_N: tl.constexpr, BLOCK_M: tl.constexpr,\n):\n start_n = tl.program_id(0)\n start_m = tl.program_id(1)\n is_q = start_n * BLOCK_N < N_q\n is_v = N_q + N_kv <= start_n * BLOCK_N\n rm = start_m * BLOCK_M + tl.arange(0, BLOCK_M)\n rn = start_n*BLOCK_N + tl.arange(0, BLOCK_N)\n A_ptr = A + rm[:, None] * N + rn[None, :]\n X_ptr = X + rm\n Y_ptr = Y + rn\n threshold = tl.where(is_q, threshold_q, tl.where(is_v, threshold_v, threshold_k))\n if BATCHSIZE == 1:\n x0 = tl.load(X_ptr, mask=rm < M, other=0.0, eviction_policy='evict_last')\n idx = tl.abs(x0) > threshold\n a = tl.load(A_ptr, mask=idx[:, None], other=0.0, eviction_policy='evict_first')\n acc = tl.sum(a.to(tl.float32) * x0.to(tl.float32)[:, None], 0)\n rn = start_n*BLOCK_N + tl.arange(0, BLOCK_N)\n mask_n = rn < N\n tl.atomic_add(Y_ptr, acc, mask=mask_n)\n\ndef qkv_gemv(\n x: torch.Tensor,\n weight: torch.Tensor,\n threshold_q: float,\n threshold_k: float,\n threshold_v: float,\n sparsity_bin: int,\n kv_size: int\n):\n N, Z = weight.shape\n beam_width, seq_len, _ = x.shape\n assert x.shape[2] == Z\n x = x.contiguous()\n assert weight.stride(1) > 1, \"weights should be column major\"\n N_q = N - 2*kv_size\n N_k = kv_size\n grid = lambda META: (\n triton.cdiv(N, META[\"BLOCK_N\"]),\n triton.cdiv(Z, META[\"BLOCK_M\"]),\n )\n output = torch.empty(beam_width, seq_len, N, device=x.device, dtype=torch.float16)\n qkv_kernel[grid](\n output, weight, x,\n threshold_q, threshold_k, threshold_v,\n N, N_q, N_k, Z,\n N // 16, Z // 16,\n beam_width, sparsity_bin,\n )\n if x.dtype is not output.dtype:\n print(f\"Warning: incurring dtype conversion overhead. Input dtype: {x.dtype}\")\n return output.to(dtype=x.dtype)\n return output\n", - "description_1": "Use triton language to implement a sparse generalized matrix-vector multiplication (GEMV) operator and a fused QKV operator. The sparse GEMV kernel takes pointers to matrices and vectors, a threshold for sparsity, matrix dimensions, and meta-parameters for block sizes and configurations. It computes a sparse matrix-vector product based on the threshold, using atomic additions to accumulate results. The sparse GEMV function wraps this kernel, managing input and output tensors. The QKV kernel performs a fused operation on Q, K, V matrices with different sparsity thresholds and similar parameters as the sparse GEMV. The QKV function manages tensor shapes and invokes the kernel accordingly.", - "description_2": "Use triton language to create sparse matrix-vector multiplication and fused QKV operators with configurable block sizes and thresholds, leveraging Triton's parallel execution model to perform efficient computations on GPU.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef _relative_information_injection_kernel_forward(t_q, t_emb, t_info, t_output,\n m_q, n_q: tl.constexpr,\n m_emb, n_emb: tl.constexpr,\n m_info, n_info: tl.constexpr,\n b_out, m_out, n_out: tl.constexpr,\n idxs_batch_sparsity, idxs_row_sparsity,\n stride_q_b, stride_q_m, stride_q_n,\n stride_emb_b, stride_emb_m, stride_emb_n,\n stride_info_b, stride_info_m, stride_info_n,\n stride_output_b, stride_output_m, stride_output_n,\n block_size_sparsity: tl.constexpr,\n BLOCK_SIZE_TRITON: tl.constexpr):\n pid_batch = tl.program_id(axis=0)\n pid_row = tl.program_id(axis=1)\n pid_col = tl.program_id(axis=2)\n\n spar_batch = tl.load(idxs_batch_sparsity + pid_batch)\n spar_row = tl.load(idxs_row_sparsity + pid_batch)\n\n for i_row in range(0, BLOCK_SIZE_TRITON):\n for i_col in range(0, BLOCK_SIZE_TRITON):\n info_index = ((pid_batch * stride_info_b) +\n (pid_row * BLOCK_SIZE_TRITON * stride_info_m) +\n (pid_col * BLOCK_SIZE_TRITON * stride_info_n) +\n (i_row * stride_info_m) +\n (i_col * stride_info_n))\n info_mask = (pid_row * BLOCK_SIZE_TRITON + i_row < m_info) & (pid_col * BLOCK_SIZE_TRITON + i_col < n_info)\n info_value = tl.load(t_info + info_index, mask=info_mask).to(tl.int32)\n\n q_index = ((spar_batch * stride_q_b) +\n (spar_row * block_size_sparsity * stride_q_m) +\n (pid_row * BLOCK_SIZE_TRITON * stride_q_m) +\n (i_row * stride_q_m))\n q_offsets = (tl.arange(0, n_q) * stride_q_n)\n q_mask = (spar_row * block_size_sparsity + pid_row * BLOCK_SIZE_TRITON + i_row < m_q)\n q_values = tl.load(t_q + q_index + q_offsets, mask=q_mask)\n\n emb_index = ((spar_batch * stride_emb_b) +\n (info_value * stride_emb_m))\n emb_offsets = (tl.arange(0, n_emb) * stride_emb_n)\n emb_mask = (info_value < m_emb)\n emb_values = tl.load(t_emb + emb_index + emb_offsets, mask=emb_mask)\n\n output_index = ((pid_batch * stride_output_b) +\n (pid_row * BLOCK_SIZE_TRITON * stride_output_m) +\n (pid_col * BLOCK_SIZE_TRITON * stride_output_n) +\n (i_row * stride_output_m) +\n (i_col * stride_output_n))\n output_mask = (pid_row * BLOCK_SIZE_TRITON + i_row < m_out) & (pid_col * BLOCK_SIZE_TRITON + i_col < n_out)\n final_mask = (output_index <\n (b_out * stride_output_b + m_out * stride_output_m + n_out * stride_output_n))\n tl.store(t_output + output_index, tl.sum(q_values * emb_values), mask=output_mask & final_mask)\n\n\n@triton.jit\ndef _relative_information_injection_kernel_backward_q(t_grad, t_emb, t_info, t_output,\n m_grad, n_grad: tl.constexpr,\n m_emb, n_emb: tl.constexpr,\n m_info, n_info: tl.constexpr,\n b_out, m_out, n_out: tl.constexpr,\n idxs_batch_sparsity, idxs_row_sparsity,\n stride_grad_b, stride_grad_m, stride_grad_n: tl.constexpr,\n stride_emb_b, stride_emb_m, stride_emb_n: tl.constexpr,\n stride_info_b, stride_info_m, stride_info_n: tl.constexpr,\n stride_output_b, stride_output_m, stride_output_n: tl.constexpr,\n block_size_sparsity: tl.constexpr,\n BLOCK_SIZE_TRITON: tl.constexpr):\n pid_batch = tl.program_id(axis=0)\n pid_row = tl.program_id(axis=1)\n pid_col = tl.program_id(axis=2)\n\n spar_batch = tl.load(idxs_batch_sparsity + pid_batch)\n spar_row = tl.load(idxs_row_sparsity + pid_batch)\n\n for i_row in range(0, BLOCK_SIZE_TRITON):\n info_index = ((pid_batch * stride_info_b) +\n (pid_row * BLOCK_SIZE_TRITON * stride_info_m) +\n (pid_col * BLOCK_SIZE_TRITON * stride_info_n) +\n (i_row * stride_info_m))\n info_offsets = (tl.arange(0, BLOCK_SIZE_TRITON) * stride_info_n)\n info_mask = (pid_row * BLOCK_SIZE_TRITON + i_row < m_info) & (info_offsets < n_info * stride_info_n)\n info_values = tl.load(t_info + info_index + info_offsets, mask=info_mask).to(tl.int32)\n\n grad_index = ((pid_batch * stride_grad_b) +\n (pid_row * BLOCK_SIZE_TRITON * stride_grad_m) +\n (pid_col * BLOCK_SIZE_TRITON * stride_grad_n) +\n (i_row * stride_grad_m))\n grad_offsets = (tl.arange(0, BLOCK_SIZE_TRITON) * stride_grad_n)\n grad_mask = (pid_row * BLOCK_SIZE_TRITON + i_row < m_grad) & (\n grad_offsets + pid_col * BLOCK_SIZE_TRITON < n_grad * stride_grad_n)\n grad_values = tl.load(t_grad + grad_index + grad_offsets, mask=grad_mask)\n\n for i_dim in range(0, n_emb):\n emb_index = ((spar_batch * stride_emb_b) +\n (i_dim * stride_emb_n))\n emb_offsets = (info_values * stride_emb_m)\n emb_mask = (i_dim < n_emb) & (emb_offsets < m_emb * stride_emb_m)\n emb_values = tl.load(t_emb + emb_index + emb_offsets, mask=emb_mask)\n\n output_index = ((spar_batch * stride_output_b) +\n (spar_row * block_size_sparsity * stride_output_m) +\n (pid_row * BLOCK_SIZE_TRITON * stride_output_m) +\n (i_row * stride_output_m) +\n (i_dim * stride_output_n))\n output_mask = (spar_row * block_size_sparsity + pid_row * BLOCK_SIZE_TRITON + i_row < m_out) & (\n i_dim < n_out)\n final_mask = (output_index <\n (b_out * stride_output_b + m_out * stride_output_m + n_out * stride_output_n))\n tl.atomic_add(t_output + output_index, tl.sum(grad_values * emb_values), mask=output_mask & final_mask)\n\n\n@triton.jit\ndef _relative_information_injection_kernel_backward_emb(t_grad, t_q, t_info, t_output,\n m_grad, n_grad: tl.constexpr,\n m_q, n_q: tl.constexpr,\n m_info, n_info: tl.constexpr,\n b_out, m_out, n_out: tl.constexpr,\n idxs_batch_sparsity, idxs_row_sparsity,\n stride_grad_b, stride_grad_m, stride_grad_n,\n stride_q_b, stride_q_m, stride_q_n,\n stride_info_b, stride_info_m, stride_info_n,\n stride_output_b, stride_output_m, stride_output_n,\n block_size_sparsity: tl.constexpr,\n BLOCK_SIZE_TRITON: tl.constexpr):\n pid_batch = tl.program_id(axis=0)\n pid_row = tl.program_id(axis=1)\n pid_col = tl.program_id(axis=2)\n\n spar_batch = tl.load(idxs_batch_sparsity + pid_batch)\n spar_row = tl.load(idxs_row_sparsity + pid_batch)\n\n for i_col in range(0, BLOCK_SIZE_TRITON):\n info_index = ((pid_batch * stride_info_b) +\n (pid_row * BLOCK_SIZE_TRITON * stride_info_m) +\n (pid_col * BLOCK_SIZE_TRITON * stride_info_n) +\n (i_col * stride_info_n))\n info_offsets = (tl.arange(0, BLOCK_SIZE_TRITON) * stride_info_m)\n info_mask = ((pid_row * BLOCK_SIZE_TRITON * stride_info_m + info_offsets < m_info * stride_info_m) &\n (i_col < n_info))\n info_values = tl.load(t_info + info_index + info_offsets, mask=info_mask).to(tl.int32)\n\n grad_index = ((pid_batch * stride_grad_b) +\n (pid_row * BLOCK_SIZE_TRITON * stride_grad_m) +\n (pid_col * BLOCK_SIZE_TRITON * stride_grad_n) +\n (i_col * stride_grad_n))\n grad_offsets = (tl.arange(0, BLOCK_SIZE_TRITON) * stride_grad_m)\n grad_mask = ((pid_row * BLOCK_SIZE_TRITON * stride_grad_m + grad_offsets < m_grad * stride_grad_m) &\n (i_col < n_grad))\n grad_values = tl.load(t_grad + grad_index + grad_offsets, mask=grad_mask)\n\n for i_dim in range(0, n_q):\n q_index = ((spar_batch * stride_q_b) +\n (spar_row * block_size_sparsity * stride_q_m) +\n (pid_row * BLOCK_SIZE_TRITON * stride_q_m) +\n (i_dim * stride_q_n))\n q_offsets = (tl.arange(0, BLOCK_SIZE_TRITON) * stride_q_m)\n q_mask = (spar_row * block_size_sparsity * stride_q_m +\n pid_row * BLOCK_SIZE_TRITON * stride_q_m +\n q_offsets < m_q * stride_q_m) & (i_dim < n_q)\n q_values = tl.load(t_q + q_index + q_offsets, mask=q_mask)\n\n output_index = ((spar_batch * stride_output_b) +\n (i_dim * stride_output_n))\n output_offsets = (info_values * stride_output_m)\n output_mask = (info_values < m_out) & (i_dim < n_out)\n final_mask = (output_index + output_offsets <\n b_out * stride_output_b +\n m_out * stride_output_m +\n n_out * stride_output_n)\n tl.atomic_add(t_output + output_index + output_offsets, grad_values * q_values,\n mask=output_mask & final_mask)\n\n\nclass _RelativeInformationInjection(torch.autograd.Function):\n bst = 32\n\n @staticmethod\n def forward(ctx, q, emb, info, sparsity_layout, block_size_sparsity):\n t_q = compact(q)\n t_emb = compact(emb)\n t_info = compact(info)\n t_sparsity_layout = compact(sparsity_layout)\n output = torch.zeros_like(t_info, dtype=torch.float)\n\n idxs_batch_sparsity, idxs_row_sparsity, idxs_col_sparsity = t_sparsity_layout.nonzero(as_tuple=True)\n\n b_info, m_info, n_info = t_info.shape\n b_q, m_q, n_q = t_q.shape\n b_emb, m_emb, n_emb = t_emb.shape\n b_out, m_out, n_out = output.shape\n\n triton_grid = lambda meta: [b_info,\n triton.cdiv(m_info, meta[\"BLOCK_SIZE_TRITON\"]),\n triton.cdiv(n_info, meta[\"BLOCK_SIZE_TRITON\"])]\n\n ctx.save_for_backward(t_q, t_emb, t_info, t_sparsity_layout)\n ctx.size_q = q.size()\n ctx.size_emb = emb.size()\n ctx.block_size_sparsity = block_size_sparsity\n ctx.triton_grid = triton_grid\n\n _relative_information_injection_kernel_forward[triton_grid](t_q, t_emb, t_info, output,\n m_q, n_q,\n m_emb, n_emb,\n m_info, n_info,\n b_out, m_out, n_out,\n idxs_batch_sparsity, idxs_row_sparsity,\n t_q.stride(0), t_q.stride(1), t_q.stride(2),\n t_emb.stride(0), t_emb.stride(1), t_emb.stride(2),\n t_info.stride(0), t_info.stride(1),\n t_info.stride(2),\n output.stride(0), output.stride(1),\n output.stride(2),\n block_size_sparsity,\n BLOCK_SIZE_TRITON=_RelativeInformationInjection.bst)\n\n output = decompact(output, info.size())\n\n return output\n\n @staticmethod\n def backward(ctx, grad_output):\n prt_gradient = grad_output.contiguous()\n t_q, t_emb, t_info, t_sparsity_layout = ctx.saved_tensors\n size_q = ctx.size_q\n size_emb = ctx.size_emb\n block_size_sparsity = ctx.block_size_sparsity\n triton_grid = ctx.triton_grid\n\n b_grad, m_grad, n_grad = prt_gradient.shape\n b_q, m_q, n_q = t_q.shape\n b_emb, m_emb, n_emb = t_emb.shape\n b_info, m_info, n_info = t_info.shape\n\n idxs_batch_sparsity, idxs_row_sparsity, idxs_col_sparsity = t_sparsity_layout.nonzero(as_tuple=True)\n\n grad_q = torch.zeros_like(t_q, dtype=torch.float)\n b_out, m_out, n_out = grad_q.shape\n _relative_information_injection_kernel_backward_q[triton_grid](prt_gradient, t_emb, t_info, grad_q,\n m_grad, n_grad,\n m_emb, n_emb,\n m_info, n_info,\n b_out, m_out, n_out,\n idxs_batch_sparsity, idxs_row_sparsity,\n prt_gradient.stride(0), prt_gradient.stride(1),\n prt_gradient.stride(2),\n t_emb.stride(0), t_emb.stride(1),\n t_emb.stride(2),\n t_info.stride(0), t_info.stride(1),\n t_info.stride(2),\n grad_q.stride(0), grad_q.stride(1),\n grad_q.stride(2),\n block_size_sparsity,\n BLOCK_SIZE_TRITON=_RelativeInformationInjection.bst)\n grad_q = decompact(grad_q, size_q)\n\n grad_emb = torch.zeros_like(t_emb, dtype=torch.float)\n b_out, m_out, n_out = grad_emb.shape\n _relative_information_injection_kernel_backward_emb[triton_grid](prt_gradient, t_q, t_info, grad_emb,\n m_grad, n_grad,\n m_q, n_q,\n m_info, n_info,\n b_out, m_out, n_out,\n idxs_batch_sparsity, idxs_row_sparsity,\n prt_gradient.stride(0), prt_gradient.stride(1),\n prt_gradient.stride(2),\n t_q.stride(0), t_q.stride(1),\n t_q.stride(2),\n t_info.stride(0), t_info.stride(1),\n t_info.stride(2),\n grad_emb.stride(0), grad_emb.stride(1),\n grad_emb.stride(2),\n block_size_sparsity,\n BLOCK_SIZE_TRITON=_RelativeInformationInjection.bst)\n grad_emb = decompact(grad_emb, size_emb)\n\n return grad_q, grad_emb, None, None, None, None\n", - "description_1": "Use triton language to implement three kernels for a neural network operation. The forward kernel computes interactions between a query tensor and an embedding tensor using sparsity information. It requires 24 parameters, including tensors for query, embedding, information, and output; dimensions for these tensors; strides for each tensor; and constants for block sizes. The first backward kernel computes gradients of the query tensor. It takes 23 parameters similar to the forward kernel. The second backward kernel computes gradients of the embedding tensor, also requiring 23 parameters with a focus on embedding and query dimensions.", - "description_2": "Use triton language to create a custom operation with three parts: a forward kernel for tensor interaction with sparsity; a backward kernel for gradient computation of the query tensor; and another backward kernel for the embedding tensor, using respective tensor strides and block size constants.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nfrom triton import language as tl\n\ndef build_distribution_layout(indices: torch.Tensor, sparsity_layout_indices: torch.Tensor,\n size_target: torch.Size, sparsity_block_size: int, triton_block_size: int = None) -> torch.Tensor:\n \"\"\"Builds the sparsity layout of either the source of a gather or the target of a scatter operation.\"\"\"\n sparsity_lut_i = torch.nonzero(sparsity_layout_indices).contiguous()\n\n output = torch.zeros(size_target[0], size_target[1] // sparsity_block_size, size_target[2] // sparsity_block_size,\n dtype=torch.bool, device=indices.device)\n\n i_b, i_r, i_c = indices.size()\n i_b_s, i_r_s, i_c_s = indices.stride()\n s_l_i_b, s_l_i_r, s_l_i_c = sparsity_layout_indices.size()\n s_l_i_b_s, s_l_i_r_s, s_l_i_c_s = sparsity_layout_indices.stride()\n s_lut_i_r, s_lut_i_c = sparsity_lut_i.size()\n s_lut_i_r_s, s_lut_i_c_s = sparsity_lut_i.stride()\n o_b, o_r, o_c = output.size()\n o_b_s, o_r_s, o_c_s = output.stride()\n\n if triton_block_size is None:\n triton_block_size = get_triton_block_size(sparsity_block_size)\n\n triton_grid = lambda meta: [i_b,\n triton.cdiv(i_r, meta[\"TRITON_BLOCK_SIZE\"]),\n triton.cdiv(i_c, meta[\"TRITON_BLOCK_SIZE\"])]\n\n (kernel_distribution_layout[triton_grid]\n (indices,\n i_b, i_b_s, i_r_s, i_c_s,\n sparsity_layout_indices,\n s_l_i_b, s_l_i_b_s, s_l_i_r, s_l_i_r_s, s_l_i_c, s_l_i_c_s,\n sparsity_lut_i,\n s_lut_i_r, s_lut_i_r_s, s_lut_i_c, s_lut_i_c_s,\n output,\n o_b, o_b_s, o_r, o_r_s, o_c, o_c_s,\n sparsity_block_size,\n triton_block_size))\n\n return output\n\n@triton.jit\ndef kernel_distribution_layout(i,\n i_b, i_b_s, i_r_s, i_c_s,\n s_l_i,\n s_l_i_b, s_l_i_b_s, s_l_i_r, s_l_i_r_s, s_l_i_c, s_l_i_c_s,\n s_lut_i,\n s_lut_i_r, s_lut_i_r_s, s_lut_i_c, s_lut_i_c_s,\n o,\n o_b, o_b_s, o_r, o_r_s, o_c, o_c_s,\n sparsity_block_size,\n TRITON_BLOCK_SIZE: tl.constexpr) -> None:\n # Get triton block indices\n pid_blk = tl.program_id(axis=0)\n pid_row = tl.program_id(axis=1)\n pid_col = tl.program_id(axis=2)\n\n # Get position of current sparsity block consisting of its batch, row, and column index\n spa_bat_i_idx = (pid_blk * s_lut_i_r_s + 0 * s_lut_i_c_s)\n spa_bat_i_msk = (spa_bat_i_idx < s_lut_i_r * s_lut_i_r_s)\n spa_bat_i = tl.load(s_lut_i + spa_bat_i_idx, mask=spa_bat_i_msk)\n\n spa_row_i_idx = (pid_blk * s_lut_i_r_s + 1 * s_lut_i_c_s)\n spa_row_i_msk = (spa_row_i_idx < s_lut_i_r * s_lut_i_r_s)\n spa_row_i = tl.load(s_lut_i + spa_row_i_idx, mask=spa_row_i_msk)\n\n blk_i_idx = (pid_blk * i_b_s +\n ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_r_s)[:, None] +\n ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_c_s)[None, :])\n blk_i_msk = (blk_i_idx < i_b * i_b_s)\n blk_i = tl.load(i + blk_i_idx, mask=blk_i_msk)\n\n blk_i = blk_i // sparsity_block_size\n blk_v = tl.full((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE), 1, dtype=tl.int32)\n\n blk_o_idx = ((spa_bat_i * o_b_s) +\n (spa_row_i * o_r_s) +\n (blk_i * o_c_s))\n blk_o_msk = (blk_o_idx < o_b * o_b_s)\n tl.store(o + blk_o_idx, blk_v, mask=blk_o_msk)\n", - "description_1": "Use triton language to implement a kernel that builds the sparsity layout for a gather or scatter operation. The kernel takes 20 parameters: input tensor, its batch size, strides, sparsity layout indices, their sizes and strides, sparsity lookup table indices, their sizes and strides, output tensor, its sizes and strides, sparsity block size, and TRITON_BLOCK_SIZE. The kernel calculates the position of the current sparsity block and updates the output tensor accordingly.", - "description_2": "Use triton language to create a kernel that processes block-sparse indices and updates an output tensor based on sparsity layout and block size.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nfrom torch import Tensor\nfrom triton import language as tl\n\ndef build_sparsity_layout(x: Tensor, sparsity_block_size: int, triton_block_size: int = None) -> Tensor:\n validate_dimensions(x)\n validate_contiguous(x)\n validate_device(x)\n\n output = torch.zeros(x.size(0), x.size(1) // sparsity_block_size, x.size(2) // sparsity_block_size,\n dtype=torch.bool, device=x.device)\n\n x_b, x_r, x_c = x.size()\n x_b_s, x_r_s, x_c_s = x.stride()\n o_b, o_r, o_c = output.size()\n o_b_s, o_r_s, o_c_s = output.stride()\n\n if triton_block_size is None:\n triton_block_size = get_triton_block_size(sparsity_block_size)\n\n validate_triton_block_size(triton_block_size, sparsity_block_size)\n\n triton_grid = lambda meta: [x_b,\n triton.cdiv(x_r, meta[\"TRITON_BLOCK_SIZE\"]),\n triton.cdiv(x_c, meta[\"TRITON_BLOCK_SIZE\"])]\n\n (kernel_sparsity_layout[triton_grid]\n (x,\n x_b, x_b_s, x_r_s, x_c_s,\n output,\n o_b, o_b_s, o_r_s, o_c_s,\n sparsity_block_size,\n triton_block_size))\n\n return output\n\n@triton.jit\ndef kernel_sparsity_layout(x,\n x_b, x_b_s, x_r_s, x_c_s,\n o,\n o_b, o_b_s, o_r_s, o_c_s,\n sparsity_block_size,\n TRITON_BLOCK_SIZE: tl.constexpr) -> None:\n pid_bat = tl.program_id(axis=0)\n pid_row = tl.program_id(axis=1)\n pid_col = tl.program_id(axis=2)\n\n blk_x_idx = (pid_bat * x_b_s +\n ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +\n ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])\n blk_x_msk = (blk_x_idx < x_b * x_b_s)\n blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)\n\n if tl.min(blk_x) != 0 or tl.max(blk_x) != 0:\n blk_o_idx = (pid_bat * o_b_s +\n (((pid_row * TRITON_BLOCK_SIZE) // sparsity_block_size) * o_r_s +\n ((pid_col * TRITON_BLOCK_SIZE) // sparsity_block_size) * o_c_s))\n blk_o_msk = (blk_o_idx < o_b * o_b_s)\n tl.store(o + blk_o_idx, 1, mask=blk_o_msk)\n\ndef build_sparsity_layout_adaption(x: Tensor, sparsity_layout_from: Tensor,\n sparsity_block_size_from: int, sparsity_block_size_to: int,\n triton_block_size: int = None) -> Tensor:\n validate_dimensions(x)\n validate_contiguous(x, sparsity_layout_from)\n validate_device(x)\n validate_sparsity(sparsity_block_size_from, (x, sparsity_layout_from))\n validate_sparsity_block_size(sparsity_block_size_from, x)\n validate_sparsity_block_size(sparsity_block_size_to)\n min_sparsity_block_size = min(sparsity_block_size_from, sparsity_block_size_to)\n validate_triton_block_size(triton_block_size, min_sparsity_block_size)\n\n sparsity_lut = torch.nonzero(sparsity_layout_from).contiguous()\n\n validate_contiguous(sparsity_layout_from, sparsity_lut)\n\n o_b = sparsity_layout_from.size(0)\n o_r = math.ceil(sparsity_layout_from.size(1) * sparsity_block_size_from // sparsity_block_size_to)\n o_c = math.ceil(sparsity_layout_from.size(2) * sparsity_block_size_from // sparsity_block_size_to)\n\n output = torch.zeros(o_b, o_r, o_c, dtype=torch.bool, device=x.device)\n\n x_b, x_r, x_c = x.size()\n x_b_s, x_r_s, x_c_s = x.stride()\n s_lut_r, s_lut_c = sparsity_lut.size()\n s_lut_r_s, s_lut_c_s = sparsity_lut.stride()\n o_b_s, o_r_s, o_c_s = output.stride()\n\n if triton_block_size is None:\n triton_block_size = get_triton_block_size(sparsity_block_size_from)\n\n triton_grid = lambda meta: [x_b,\n triton.cdiv(x_r, meta[\"TRITON_BLOCK_SIZE\"]),\n triton.cdiv(x_c, meta[\"TRITON_BLOCK_SIZE\"])]\n\n (kernel_sparsity_layout_adaption[triton_grid]\n (x,\n x_b, x_b_s, x_r_s, x_c_s,\n sparsity_lut, s_lut_r, s_lut_r_s, s_lut_c_s,\n output,\n o_b, o_b_s, o_r_s, o_c_s,\n sparsity_block_size_from,\n sparsity_block_size_to,\n triton_block_size))\n\n return output\n\n@triton.jit\ndef kernel_sparsity_layout_adaption(x,\n x_b, x_b_s, x_r_s, x_c_s,\n s_lut, s_lut_r, s_lut_r_s, s_lut_c_s,\n o,\n o_b, o_b_s, o_r_s, o_c_s,\n sparsity_block_size_from,\n sparsity_block_size_to,\n TRITON_BLOCK_SIZE: tl.constexpr) -> None:\n pid_blk = tl.program_id(axis=0)\n pid_row = tl.program_id(axis=1)\n pid_col = tl.program_id(axis=2)\n\n spa_bat_idx = (pid_blk * s_lut_r_s + 0 * s_lut_c_s)\n spa_bat_msk = (spa_bat_idx < s_lut_r * s_lut_r_s)\n spa_bat = tl.load(s_lut + spa_bat_idx, mask=spa_bat_msk)\n\n spa_row_idx = (pid_blk * s_lut_r_s + 1 * s_lut_c_s)\n spa_row_msk = (spa_row_idx < s_lut_r * s_lut_r_s)\n spa_row = tl.load(s_lut + spa_row_idx, mask=spa_row_msk)\n\n spa_col_idx = (pid_blk * s_lut_r_s + 2 * s_lut_c_s)\n spa_col_msk = (spa_col_idx < s_lut_r * s_lut_r_s)\n spa_col = tl.load(s_lut + spa_col_idx, mask=spa_col_msk)\n\n blk_x_idx = ((pid_blk * x_b_s) +\n ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +\n ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])\n blk_x_msk = (blk_x_idx < x_b * x_b_s)\n blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)\n\n if tl.min(blk_x) != 0 or tl.max(blk_x) != 0:\n blk_o_idx = ((spa_bat * o_b_s) +\n (((spa_row * sparsity_block_size_from + pid_row * TRITON_BLOCK_SIZE)\n // sparsity_block_size_to) * o_r_s) +\n (((spa_col * sparsity_block_size_from + pid_col * TRITON_BLOCK_SIZE)\n // sparsity_block_size_to) * o_c_s))\n blk_o_msk = (blk_o_idx < o_b * o_b_s)\n tl.store(o + blk_o_idx, 1, mask=blk_o_msk)\n", - "description_1": "Use triton language to define two kernels: `kernel_sparsity_layout` and `kernel_sparsity_layout_adaption`. The first kernel computes the sparsity layout of a dense tensor into a block-sparse format. It takes 11 parameters including input tensor `x`, strides, output tensor `o`, and block sizes. The second kernel adapts a block-sparse tensor to a new sparsity layout given a lookup table. It takes 15 parameters including input tensor `x`, strides, lookup table `s_lut`, output tensor `o`, and block sizes.", - "description_2": "Use triton language to create kernels for computing and adapting the sparsity layout of tensors using specified block sizes and a lookup table.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nfrom triton import language as tl\n\n@triton.jit\ndef kernel_broadcast_addition(x,\n x_b, x_b_s, x_c_s,\n y,\n y_b, y_b_s, y_c_s,\n o,\n o_b, o_b_s, o_r_s, o_c_s,\n s_lut_o, s_lut_o_r, s_lut_o_r_s, s_lut_o_c_s,\n sparsity_block_size,\n TRITON_BLOCK_SIZE: tl.constexpr) -> None:\n # Get triton block indices\n pid_blk = tl.program_id(axis=0)\n pid_row = tl.program_id(axis=1)\n pid_col = tl.program_id(axis=2)\n\n # Get position of current sparsity block consisting of its batch, row, and column index\n spa_bat_o_idx = (pid_blk * s_lut_o_r_s + 0 * s_lut_o_c_s)\n spa_bat_o_msk = (spa_bat_o_idx < s_lut_o_r * s_lut_o_r_s)\n spa_bat_o = tl.load(s_lut_o + spa_bat_o_idx, mask=spa_bat_o_msk)\n\n spa_row_o_idx = (pid_blk * s_lut_o_r_s + 1 * s_lut_o_c_s)\n spa_row_o_msk = (spa_row_o_idx < s_lut_o_r * s_lut_o_r_s)\n spa_row_o = tl.load(s_lut_o + spa_row_o_idx, mask=spa_row_o_msk)\n\n spa_col_o_idx = (pid_blk * s_lut_o_r_s + 2 * s_lut_o_c_s)\n spa_col_o_msk = (spa_col_o_idx < s_lut_o_r * s_lut_o_r_s)\n spa_col_o = tl.load(s_lut_o + spa_col_o_idx, mask=spa_col_o_msk)\n\n # Load x block\n blk_x_idx = (spa_bat_o * x_b_s +\n ((spa_row_o * sparsity_block_size + pid_row * TRITON_BLOCK_SIZE +\n tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])\n blk_x_msk = (blk_x_idx < x_b * x_b_s)\n blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)\n\n # Load y block\n blk_y_idx = (spa_bat_o * y_b_s +\n ((spa_col_o * sparsity_block_size + pid_col * TRITON_BLOCK_SIZE +\n tl.arange(0, TRITON_BLOCK_SIZE)) * y_c_s)[None, :])\n blk_y_msk = (blk_y_idx < y_b * y_b_s)\n blk_y = tl.load(y + blk_y_idx, mask=blk_y_msk)\n\n # Compute sum\n blk_x, blk_y = tl.broadcast(tl.trans(blk_x), blk_y)\n buf = blk_x + blk_y\n\n # Store result\n blk_o_idx = ((pid_blk * o_b_s) +\n ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +\n ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])\n blk_o_msk = (blk_o_idx < o_b * o_b_s)\n tl.store(o + blk_o_idx, buf, mask=blk_o_msk)\n\ndef broadcast_add(x: torch.Tensor, y: torch.Tensor, sparsity_layout_output: torch.Tensor,\n sparsity_block_size: int, triton_block_size: int = None) -> torch.Tensor:\n x = x.contiguous()\n y = y.contiguous()\n\n sparsity_lut_o = torch.nonzero(sparsity_layout_output).contiguous()\n\n n_sparse_blocks = torch.sum(sparsity_layout_output.to(torch.int)).item()\n\n output = torch.zeros(n_sparse_blocks, sparsity_block_size, sparsity_block_size, dtype=x.dtype, device=x.device)\n\n x_b, x_c = x.size()\n x_b_s, x_c_s = x.stride()\n y_b, y_c = y.size()\n y_b_s, y_c_s = y.stride()\n o_b, o_r, o_c = output.size()\n o_b_s, o_r_s, o_c_s = output.stride()\n s_lut_o_r, s_lut_o_c = sparsity_lut_o.size()\n s_lut_o_r_s, s_lut_o_c_s = sparsity_lut_o.stride()\n\n if triton_block_size is None:\n triton_block_size = sparsity_block_size\n\n triton_grid = lambda meta: [o_b,\n triton.cdiv(o_r, meta[\"TRITON_BLOCK_SIZE\"]),\n triton.cdiv(o_c, meta[\"TRITON_BLOCK_SIZE\"])]\n\n (kernel_broadcast_addition[triton_grid]\n (x,\n x_b, x_b_s, x_c_s,\n y,\n y_b, y_b_s, y_c_s,\n output,\n o_b, o_b_s, o_r_s, o_c_s,\n sparsity_lut_o, s_lut_o_r, s_lut_o_r_s, s_lut_o_c_s,\n sparsity_block_size,\n triton_block_size))\n\n return output\n", - "description_1": "Use triton language to implement a kernel that performs block-wise addition of two tensors x and y, based on a sparsity layout. The kernel takes 17 parameters: x, x_b, x_b_s, x_c_s, y, y_b, y_b_s, y_c_s, o, o_b, o_b_s, o_r_s, o_c_s, s_lut_o, s_lut_o_r, s_lut_o_r_s, s_lut_o_c_s, sparsity_block_size, and TRITON_BLOCK_SIZE. The function broadcast_add is a wrapper that prepares the input tensors and calls the kernel with the appropriate grid configuration.", - "description_2": "Use triton language to create a kernel for block-wise tensor addition with sparsity, and a wrapper function to set up and invoke the kernel.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nfrom triton import language as tl\n\n@triton.jit\ndef kernel_repeat_interleave(x,\n x_b, x_b_s, x_r_s, x_c_s,\n s_lut, s_lut_r, s_lut_r_s, s_lut_c_s,\n o,\n o_b, o_b_s, o_r_s, o_c_s,\n s_l_o_b, s_l_o_b_s, s_l_o_r_s, s_l_o_c_s,\n r_lut_o,\n repeats,\n TRITON_BLOCK_SIZE: tl.constexpr) -> None:\n # Get triton block indices\n pid_blk = tl.program_id(axis=0)\n pid_row = tl.program_id(axis=1)\n pid_col = tl.program_id(axis=2)\n\n # Get sparsity index of current output block consisting of its batch, row, and column index\n spa_bat_idx = (pid_blk * s_lut_r_s + 0 * s_lut_c_s)\n spa_bat_msk = (spa_bat_idx < s_lut_r * s_lut_r_s)\n spa_bat = tl.load(s_lut + spa_bat_idx, mask=spa_bat_msk)\n\n spa_row_idx = (pid_blk * s_lut_r_s + 1 * s_lut_c_s)\n spa_row_msk = (spa_row_idx < s_lut_r * s_lut_r_s)\n spa_row = tl.load(s_lut + spa_row_idx, mask=spa_row_msk)\n\n spa_col_idx = (pid_blk * s_lut_r_s + 2 * s_lut_c_s)\n spa_col_msk = (spa_col_idx < s_lut_r * s_lut_r_s)\n spa_col = tl.load(s_lut + spa_col_idx, mask=spa_col_msk)\n\n # Load block\n blk_x_idx = ((pid_blk * x_b_s) +\n ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +\n ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])\n blk_x_msk = (blk_x_idx < x_b * x_b_s)\n blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)\n\n for repeat in range(repeats):\n # Get reverse sparsity index\n rev_idx_spa_idx = ((spa_bat * repeats + repeat) * s_l_o_b_s +\n spa_row * s_l_o_r_s +\n spa_col * s_l_o_c_s)\n rev_idx_spa_msk = (rev_idx_spa_idx < s_l_o_b * s_l_o_b_s)\n rev_idx_spa = tl.load(r_lut_o + rev_idx_spa_idx, mask=rev_idx_spa_msk).to(tl.int32)\n\n # Store block\n blk_o_idx = ((rev_idx_spa * o_b_s) +\n ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +\n ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])\n blk_o_msk = (blk_o_idx < o_b * o_b_s)\n tl.store(o + blk_o_idx, blk_x, mask=blk_o_msk)\n\ndef repeat_interleave(x: torch.Tensor, sparsity_layout: torch.Tensor, repeats: int,\n sparsity_block_size: int, triton_block_size: int = None) -> tuple[torch.Tensor, torch.Tensor]:\n x = x.contiguous()\n\n sparsity_layout_output = torch.repeat_interleave(sparsity_layout, 3, dim=0).contiguous()\n\n sparsity_lut = torch.nonzero(sparsity_layout).contiguous()\n\n sparsity_layout_output_flat = sparsity_layout_output.reshape(-1)\n sparsity_output_reverse_lut = ((torch.cumsum(sparsity_layout_output_flat, dim=-1) - 1) *\n (sparsity_layout_output_flat == 1) -\n (1 * (sparsity_layout_output_flat == 0)))\n\n n_sparse_blocks = torch.sum(sparsity_layout.to(torch.int)).item()\n\n output = torch.empty(n_sparse_blocks * repeats, sparsity_block_size, sparsity_block_size,\n dtype=x.dtype, device=x.device)\n\n x_b, x_r, x_c = x.size()\n x_b_s, x_r_s, x_c_s = x.stride()\n s_lut_r, s_lut_c = sparsity_lut.size()\n s_lut_r_s, s_lut_c_s = sparsity_lut.stride()\n o_b, o_r, o_c = output.size()\n o_b_s, o_r_s, o_c_s = output.stride()\n s_l_o_b, s_l_o_r, s_l_o_c = sparsity_layout_output.size()\n s_l_o_b_s, s_l_o_r_s, s_l_o_c_s = sparsity_layout_output.stride()\n\n if triton_block_size is None:\n triton_block_size = sparsity_block_size # Assuming a function to get block size\n\n triton_grid = lambda meta: [x_b,\n triton.cdiv(x_r, meta[\"TRITON_BLOCK_SIZE\"]),\n triton.cdiv(x_c, meta[\"TRITON_BLOCK_SIZE\"])]\n\n (kernel_repeat_interleave[triton_grid]\n (x,\n x_b, x_b_s, x_r_s, x_c_s,\n sparsity_lut, s_lut_r, s_lut_r_s, s_lut_c_s,\n output,\n o_b, o_b_s, o_r_s, o_c_s,\n s_l_o_b, s_l_o_b_s, s_l_o_r_s, s_l_o_c_s,\n sparsity_output_reverse_lut,\n repeats,\n triton_block_size))\n\n return output, sparsity_layout_output\n", - "description_1": "Use triton language to implement a kernel that repeats and interleaves a block-sparse tensor. The kernel takes 15 parameters: the input tensor, its dimensions and strides, a sparsity lookup table, the output tensor, its dimensions and strides, a reverse lookup table, the number of repeats, and a block size. The kernel computes the output by loading blocks from the input tensor and storing them in the output tensor according to the sparsity pattern.", - "description_2": "Use triton language to create a kernel for repeating and interleaving block-sparse tensors based on a sparsity pattern and repeat count.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nfrom torch import Tensor\nfrom triton import language as tl\n\ndef row_wise_sum(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int,\n flag_slice_only: bool = False, triton_block_size: int = None) -> tuple[Tensor, Tensor]:\n x = x.contiguous()\n\n sparsity_lut = torch.nonzero(sparsity_layout).contiguous()\n\n sparsity_layout_output, _ = torch.max(sparsity_layout, dim=-1, keepdim=True)\n sparsity_layout_output_flat = sparsity_layout_output.reshape(-1)\n sparsity_reverse_lut_output = ((torch.cumsum(sparsity_layout_output_flat, dim=-1) - 1) *\n (sparsity_layout_output_flat == 1) -\n (1 * (sparsity_layout_output_flat == 0)))\n\n n_sparse_blocks_output = torch.sum(sparsity_layout_output.to(torch.int)).item()\n\n output = torch.zeros(size=(n_sparse_blocks_output,\n sparsity_block_size,\n 1 if flag_slice_only else sparsity_block_size),\n dtype=x.dtype,\n device=x.device)\n\n x_b, x_r, x_c = x.size()\n x_b_s, x_r_s, x_c_s = x.stride()\n s_lut_x_r, s_lut_x_c = sparsity_lut.size()\n s_lut_x_r_s, s_lut_x_c_s = sparsity_lut.stride()\n o_b, o_r, o_c = output.size()\n o_b_s, o_r_s, o_c_s = output.stride()\n s_l_o_b, s_l_o_r, s_l_o_c = sparsity_layout_output.size()\n s_l_o_b_s, s_l_o_r_s, s_l_o_c_s = sparsity_layout_output.stride()\n\n if triton_block_size is None:\n triton_block_size = sparsity_block_size\n\n triton_grid = lambda meta: [x_b,\n triton.cdiv(x_r, meta[\"TRITON_BLOCK_SIZE\"]),\n triton.cdiv(x_c, meta[\"TRITON_BLOCK_SIZE\"])]\n\n (kernel_blocksparse_row_wise_sum[triton_grid]\n (x,\n x_b, x_b_s, x_r_s, x_c_s,\n sparsity_lut, s_lut_x_r, s_lut_x_r_s, s_lut_x_c_s,\n output,\n o_b, o_b_s, o_r_s,\n s_l_o_b, s_l_o_b_s, s_l_o_r_s,\n sparsity_reverse_lut_output,\n triton_block_size))\n\n return (output, sparsity_layout_output)\n\n\n@triton.jit\ndef kernel_blocksparse_row_wise_sum(x,\n x_b, x_b_s, x_r_s, x_c_s,\n s_lut_x, s_lut_x_r, s_lut_x_r_s, s_lut_x_c_s,\n o,\n o_b, o_b_s, o_r_s,\n s_l_o_b, s_l_o_b_s, s_l_o_r_s,\n r_lut_o,\n TRITON_BLOCK_SIZE: tl.constexpr) -> None:\n pid_blk = tl.program_id(axis=0)\n pid_row = tl.program_id(axis=1)\n pid_col = tl.program_id(axis=2)\n\n # Get position of current sparsity block consisting of its batch and row index\n spa_bat_idx = (pid_blk * s_lut_x_r_s + 0 * s_lut_x_c_s)\n spa_bat_msk = (spa_bat_idx < s_lut_x_r * s_lut_x_r_s)\n spa_bat = tl.load(s_lut_x + spa_bat_idx, mask=spa_bat_msk)\n\n spa_row_idx = (pid_blk * s_lut_x_r_s + 1 * s_lut_x_c_s)\n spa_row_msk = (spa_row_idx < s_lut_x_r * s_lut_x_r_s)\n spa_row = tl.load(s_lut_x + spa_row_idx, mask=spa_row_msk)\n\n # Load reverse sparsity index for current block\n rev_idx_spa_idx = (spa_bat * s_l_o_b_s +\n spa_row * s_l_o_r_s)\n rev_idx_spa_msk = (rev_idx_spa_idx < s_l_o_b * s_l_o_b_s)\n rev_idx_spa = tl.load(r_lut_o + rev_idx_spa_idx, mask=rev_idx_spa_msk).to(tl.int32)\n\n blk_idx = ((pid_blk * x_b_s) +\n ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +\n ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])\n blk_msk = (blk_idx < x_b * x_b_s)\n blk = tl.load(x + blk_idx, mask=blk_msk)\n\n buf = tl.reshape(tl.sum(blk, axis=-1), (TRITON_BLOCK_SIZE, 1))\n\n o_idx = (rev_idx_spa * o_b_s +\n ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +\n (tl.arange(0, 1))[None, :])\n o_msk = (o_idx < o_b * o_b_s)\n tl.atomic_add(o + o_idx, buf, o_msk)\n\ndef row_wise_max(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int,\n flag_slice_only: bool = False, triton_block_size: int = None) -> tuple[Tensor, Tensor]:\n x = x.contiguous()\n\n sparsity_lut = torch.nonzero(sparsity_layout).contiguous()\n\n sparsity_layout_output, _ = torch.max(sparsity_layout, dim=-1, keepdim=True)\n sparsity_layout_output_flat = sparsity_layout_output.reshape(-1)\n sparsity_reverse_lut_output = ((torch.cumsum(sparsity_layout_output_flat, dim=-1) - 1) *\n (sparsity_layout_output_flat == 1) -\n (1 * (sparsity_layout_output_flat == 0)))\n\n n_sparse_blocks_output = torch.sum(sparsity_layout_output.to(torch.int)).item()\n\n output = torch.full(size=(n_sparse_blocks_output,\n sparsity_block_size,\n 1 if flag_slice_only else sparsity_block_size),\n fill_value=float(\"-inf\"),\n device=x.device)\n\n x_b, x_r, x_c = x.size()\n x_b_s, x_r_s, x_c_s = x.stride()\n s_lut_x_r, s_lut_x_c = sparsity_lut.size()\n s_lut_x_r_s, s_lut_x_c_s = sparsity_lut.stride()\n o_b, o_r, o_c = output.size()\n o_b_s, o_r_s, o_c_s = output.stride()\n s_l_o_b, s_l_o_r, s_l_o_c = sparsity_layout_output.size()\n s_l_o_b_s, s_l_o_r_s, s_l_o_c_s = sparsity_layout_output.stride()\n\n if triton_block_size is None:\n triton_block_size = sparsity_block_size\n\n triton_grid = lambda meta: [x_b,\n triton.cdiv(x_r, meta[\"TRITON_BLOCK_SIZE\"]),\n triton.cdiv(x_c, meta[\"TRITON_BLOCK_SIZE\"])]\n\n (kernel_blocksparse_row_wise_max[triton_grid]\n (x,\n x_b, x_b_s, x_r_s, x_c_s,\n sparsity_lut, s_lut_x_r, s_lut_x_r_s, s_lut_x_c_s,\n output,\n o_b, o_b_s, o_r_s,\n s_l_o_b, s_l_o_b_s, s_l_o_r_s,\n sparsity_reverse_lut_output,\n triton_block_size))\n\n return output, sparsity_layout_output\n\n\n@triton.jit\ndef kernel_blocksparse_row_wise_max(x,\n x_b, x_b_s, x_r_s, x_c_s,\n s_lut_x, s_lut_x_r, s_lut_x_r_s, s_lut_x_c_s,\n o,\n o_b, o_b_s, o_r_s,\n s_l_o_b, s_l_o_b_s, s_l_o_r_s,\n r_lut_o,\n TRITON_BLOCK_SIZE: tl.constexpr) -> None:\n pid_blk = tl.program_id(axis=0)\n pid_row = tl.program_id(axis=1)\n pid_col = tl.program_id(axis=2)\n\n # Get position of current sparsity block consisting of its batch and row index\n spa_bat_idx = (pid_blk * s_lut_x_r_s + 0 * s_lut_x_c_s)\n spa_bat_msk = (spa_bat_idx < s_lut_x_r * s_lut_x_r_s)\n spa_bat = tl.load(s_lut_x + spa_bat_idx, mask=spa_bat_msk)\n\n spa_row_idx = (pid_blk * s_lut_x_r_s + 1 * s_lut_x_c_s)\n spa_row_msk = (spa_row_idx < s_lut_x_r * s_lut_x_r_s)\n spa_row = tl.load(s_lut_x + spa_row_idx, mask=spa_row_msk)\n\n # Load reverse sparsity index for current block\n rev_idx_spa_idx = (spa_bat * s_l_o_b_s +\n spa_row * s_l_o_r_s)\n rev_idx_spa_msk = (rev_idx_spa_idx < s_l_o_b * s_l_o_b_s)\n rev_idx_spa = tl.load(r_lut_o + rev_idx_spa_idx, mask=rev_idx_spa_msk).to(tl.int32)\n\n blk_idx = ((pid_blk * x_b_s) +\n ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +\n ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])\n blk_msk = (blk_idx < x_b * x_b_s)\n blk = tl.load(x + blk_idx, mask=blk_msk)\n\n buf = tl.reshape(tl.max(blk, axis=-1), (TRITON_BLOCK_SIZE, 1))\n\n o_idx = (rev_idx_spa * o_b_s +\n ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +\n (tl.arange(0, 1))[None, :])\n o_msk = (o_idx < o_b * o_b_s)\n tl.atomic_max(o + o_idx, buf, o_msk)\n\ndef row_wise_add(x: Tensor, sparsity_layout_x: Tensor, y: Tensor,\n sparsity_block_size: int, triton_block_size: int = None) -> Tensor:\n sparsity_lut = torch.nonzero(sparsity_layout_x).contiguous()\n\n sparsity_layout_rwm, _ = torch.max(sparsity_layout_x, dim=-1, keepdim=True)\n sparsity_layout_rwm_flat = sparsity_layout_rwm.reshape(-1)\n sparsity_reverse_lut_rwm = ((torch.cumsum(sparsity_layout_rwm_flat, dim=-1) - 1) *\n (sparsity_layout_rwm_flat == 1) -\n (1 * (sparsity_layout_rwm_flat == 0)))\n\n output = torch.empty_like(x)\n\n x_b, x_r, x_c = x.size()\n x_b_s, x_r_s, x_c_s = x.stride()\n s_lut_r, s_lut_c = sparsity_lut.size()\n s_lut_r_s, s_lut_c_s = sparsity_lut.stride()\n y_b, y_r, y_c = y.size()\n y_b_s, y_r_s, y_c_s = y.stride()\n s_l_y_b, s_l_y_r, s_l_y_c = sparsity_layout_rwm.size()\n s_l_y_b_s, s_l_y_r_s, s_l_y_c_s = sparsity_layout_rwm.stride()\n o_b, o_r, o_c = output.size()\n o_b_s, o_r_s, o_c_s = output.stride()\n\n if triton_block_size is None:\n triton_block_size = sparsity_block_size\n\n triton_grid = lambda meta: [o_b,\n triton.cdiv(o_r, meta[\"TRITON_BLOCK_SIZE\"]),\n triton.cdiv(o_c, meta[\"TRITON_BLOCK_SIZE\"])]\n\n (kernel_blocksparse_row_wise_add[triton_grid]\n (x,\n x_b, x_b_s, x_r_s, x_c_s,\n sparsity_lut, s_lut_r, s_lut_r_s, s_lut_c_s,\n y, y_b, y_b_s, y_r_s, y_c_s,\n s_l_y_b, s_l_y_b_s, s_l_y_r_s,\n sparsity_reverse_lut_rwm,\n output,\n o_b, o_b_s, o_r_s, o_c_s,\n triton_block_size\n ))\n\n return output\n\n\n@triton.jit\ndef kernel_blocksparse_row_wise_add(x,\n x_b, x_b_s, x_r_s, x_c_s,\n s_lut, s_lut_r, s_lut_r_s, s_lut_c_s,\n y, y_b, y_b_s, y_r_s, y_c_s,\n s_l_y_b, s_l_y_b_s, s_l_y_r_s,\n r_lut_y,\n o,\n o_b, o_b_s, o_r_s, o_c_s,\n TRITON_BLOCK_SIZE: tl.constexpr) -> None:\n # Get triton block indices\n pid_blk = tl.program_id(axis=0)\n pid_row = tl.program_id(axis=1)\n pid_col = tl.program_id(axis=2)\n\n # Get position of current sparsity block consisting of its batch and row index\n spa_bat_idx = (pid_blk * s_lut_r_s + 0 * s_lut_c_s)\n spa_bat_msk = (spa_bat_idx < s_lut_r * s_lut_r_s)\n spa_bat = tl.load(s_lut + spa_bat_idx, mask=spa_bat_msk)\n\n spa_row_idx = (pid_blk * s_lut_r_s + 1 * s_lut_c_s)\n spa_row_msk = (spa_row_idx < s_lut_r * s_lut_r_s)\n spa_row = tl.load(s_lut + spa_row_idx, mask=spa_row_msk)\n\n # Get reverse sparsity indices for s\n rev_idx_spa_s_idx = (spa_bat * s_l_y_b_s +\n spa_row * s_l_y_r_s)\n rev_idx_spa_s_msk = (rev_idx_spa_s_idx < s_l_y_b * s_l_y_b_s)\n rev_idx_spa_s = tl.load(r_lut_y + rev_idx_spa_s_idx, mask=rev_idx_spa_s_msk).to(tl.int32)\n\n if rev_idx_spa_s == -1:\n assert False, \"Invalid sparsity block\"\n\n # Load x block\n blk_x_idx = ((pid_blk * x_b_s) +\n ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +\n ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])\n blk_x_msk = (blk_x_idx < x_b * x_b_s)\n blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)\n\n # Load sum block\n blk_s_idx = (rev_idx_spa_s * y_b_s +\n ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * y_r_s)[:, None] +\n (tl.arange(0, 1) * y_c_s)[None, :])\n blk_s_msk = (blk_s_idx < y_b * y_b_s)\n blk_s = tl.load(y + blk_s_idx, mask=blk_s_msk)\n\n # Compute exp\n buf = blk_x + tl.broadcast_to(blk_s, (TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE))\n\n # Store block\n blk_o_idx = ((pid_blk * o_b_s) +\n ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +\n ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])\n blk_o_msk = (blk_o_idx < o_b * o_b_s)\n tl.store(o + blk_o_idx, buf, mask=blk_o_msk)\n", - "description_1": "Use triton language to implement three different kernels for block-sparse operations: (1) row_wise_sum computes the sum of each row for a block-sparse tensor, handling the sparsity layout and output in compressed form, (2) row_wise_max computes the max of each row for a block-sparse tensor, handling sparsity similarly to the sum, and (3) row_wise_add adds a single-column sparse block tensor to another block-sparse tensor row-wise, managing sparsity indices appropriately. Each function accepts tensors, sparsity layouts, block sizes, and an optional triton block size, and outputs adjusted tensors.", - "description_2": "Use triton language to implement kernels that compute the row-wise sum, max, and element-wise addition for block-sparse tensors, each function taking specific tensor inputs and block sizes, handling the sparsity layout efficiently, and producing adjusted tensor outputs.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nfrom triton import language as tl\n\nclass _BlocksparseToDense(torch.autograd.Function):\n\n @staticmethod\n def forward(ctx, x: torch.Tensor,\n sparsity_layout: torch.Tensor, sparsity_reverse_lut: torch.Tensor,\n sparsity_block_size: int, fill_value: float,\n triton_block_size: int) -> torch.Tensor:\n output = torch.full(size=(sparsity_layout.size(0), sparsity_layout.size(1) * sparsity_block_size,\n sparsity_layout.size(2) * sparsity_block_size), fill_value=fill_value,\n dtype=x.dtype, device=x.device)\n\n x_b, x_r, x_c = x.shape\n x_b_s, x_r_s, x_c_s = x.stride()\n s_l_b, s_l_r, s_l_c = sparsity_layout.size()\n s_l_b_s, s_l_r_s, s_l_c_s = sparsity_layout.stride()\n o_b, o_r, o_c = output.size()\n o_b_s, o_r_s, o_c_s = output.stride()\n\n if triton_block_size is None:\n triton_block_size = get_triton_block_size(sparsity_block_size)\n\n triton_grid = lambda meta: [o_b,\n triton.cdiv(o_r, meta[\"TRITON_BLOCK_SIZE\"]),\n triton.cdiv(o_c, meta[\"TRITON_BLOCK_SIZE\"])]\n\n (_BlocksparseToDense.kernel_blocksparse_to_dense[triton_grid]\n (x,\n x_b, x_b_s, x_r_s, x_c_s,\n s_l_b, s_l_b_s, s_l_r_s, s_l_c_s,\n sparsity_reverse_lut,\n output,\n o_b, o_b_s, o_r_s, o_c_s,\n sparsity_block_size,\n triton_block_size))\n\n ctx.save_for_backward(sparsity_layout)\n ctx.sparsity_block_size = sparsity_block_size\n ctx.triton_block_size = triton_block_size\n\n return output\n\n @staticmethod\n @triton.jit\n def kernel_blocksparse_to_dense(x,\n x_b, x_b_s, x_r_s, x_c_s,\n s_l_b, s_l_b_s, s_l_r_s, s_l_c_s,\n sparsity_reverse_lut,\n o,\n o_b, o_b_s, o_r_s, o_c_s,\n sparsity_block_size,\n TRITON_BLOCK_SIZE: tl.constexpr) -> None:\n pid_blk = tl.program_id(axis=0)\n pid_row = tl.program_id(axis=1)\n pid_col = tl.program_id(axis=2)\n\n spa_row = (pid_row * TRITON_BLOCK_SIZE) // sparsity_block_size\n spa_col = (pid_col * TRITON_BLOCK_SIZE) // sparsity_block_size\n\n rev_idx_spa_idx = (pid_blk * s_l_b_s + spa_row * s_l_r_s + spa_col * s_l_c_s)\n rev_idx_spa_msk = (rev_idx_spa_idx < s_l_b * s_l_b_s)\n rev_idx_spa = tl.load(sparsity_reverse_lut + rev_idx_spa_idx, mask=rev_idx_spa_msk).to(tl.int32)\n\n if rev_idx_spa >= 0:\n blk_idx = (rev_idx_spa * x_b_s +\n (((pid_row % (sparsity_block_size // TRITON_BLOCK_SIZE)) * TRITON_BLOCK_SIZE +\n tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +\n (((pid_col % (sparsity_block_size // TRITON_BLOCK_SIZE)) * TRITON_BLOCK_SIZE +\n tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])\n blk_msk = (blk_idx < x_b * x_b_s)\n blk = tl.load(x + blk_idx, mask=blk_msk)\n\n o_idx = (pid_blk * o_b_s +\n ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +\n ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])\n o_msk = (o_idx < o_b * o_b_s)\n tl.store(o + o_idx, blk, o_msk)\n\nclass _BlocksparseToSparse(torch.autograd.Function):\n\n @staticmethod\n def forward(ctx, x: torch.Tensor,\n sparsity_layout: torch.Tensor, sparsity_lut: torch.Tensor,\n sparsity_block_size: int, n_sparse_blocks: int, triton_block_size: int) -> torch.Tensor:\n output = torch.empty(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size),\n dtype=x.dtype, device=x.device)\n\n x_b, x_r, x_c = x.size()\n x_b_s, x_r_s, x_c_s = x.stride()\n s_lut_r, s_lut_c = sparsity_lut.size()\n s_lut_r_s, s_lut_c_s = sparsity_lut.stride()\n o_b, o_r, o_c = output.size()\n o_b_s, o_r_s, o_c_s = output.stride()\n\n if triton_block_size is None:\n triton_block_size = get_triton_block_size(sparsity_block_size)\n\n triton_grid = lambda meta: [o_b,\n triton.cdiv(o_r, meta[\"TRITON_BLOCK_SIZE\"]),\n triton.cdiv(o_c, meta[\"TRITON_BLOCK_SIZE\"])]\n\n (_BlocksparseToSparse.kernel_blocksparse_to_sparse[triton_grid]\n (x, x_b, x_b_s, x_r_s, x_c_s,\n sparsity_lut, s_lut_r, s_lut_r_s, s_lut_c_s,\n output, o_b_s, o_r_s, o_c_s,\n sparsity_block_size,\n triton_block_size))\n\n ctx.save_for_backward(sparsity_layout)\n ctx.sparsity_block_size = sparsity_block_size\n ctx.triton_block_size = triton_block_size\n\n return output\n\n @staticmethod\n @triton.jit\n def kernel_blocksparse_to_sparse(x,\n x_b, x_b_s, x_r_s, x_c_s,\n s_lut, s_lut_r, s_lut_r_s, s_lut_c_s,\n o,\n o_b_s, o_r_s, o_c_s,\n sparsity_block_size,\n TRITON_BLOCK_SIZE: tl.constexpr) -> None:\n pid_blk = tl.program_id(axis=0)\n pid_row = tl.program_id(axis=1)\n pid_col = tl.program_id(axis=2)\n\n spa_bat_idx = (pid_blk * s_lut_r_s + 0 * s_lut_c_s)\n spa_bat_msk = (spa_bat_idx < s_lut_r * s_lut_r_s)\n spa_bat = tl.load(s_lut + spa_bat_idx, mask=spa_bat_msk)\n\n spa_row_idx = (pid_blk * s_lut_r_s + 1 * s_lut_c_s)\n spa_row_msk = (spa_row_idx < s_lut_r * s_lut_r_s)\n spa_row = tl.load(s_lut + spa_row_idx, mask=spa_row_msk)\n\n spa_col_idx = (pid_blk * s_lut_r_s + 2 * s_lut_c_s)\n spa_col_msk = (spa_col_idx < s_lut_r * s_lut_r_s)\n spa_col = tl.load(s_lut + spa_col_idx, mask=spa_col_msk)\n\n blk_d_idx = (spa_bat * x_b_s +\n ((spa_row * sparsity_block_size + pid_row * TRITON_BLOCK_SIZE +\n tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +\n ((spa_col * sparsity_block_size + pid_col * TRITON_BLOCK_SIZE +\n tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])\n blk_d_msk = (blk_d_idx < x_b * x_b_s)\n blk_d = tl.load(x + blk_d_idx, mask=blk_d_msk)\n\n blk_o_idx = ((pid_blk * o_b_s) +\n ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +\n ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE) * o_c_s))[None, :])\n blk_o_msk = (blk_o_idx < (pid_blk + 1) * o_b_s)\n tl.store(o + blk_o_idx, blk_d, mask=blk_o_msk)\n\nclass _BlocksparseAdaptLayout(torch.autograd.Function):\n\n @staticmethod\n def forward(ctx, x: torch.Tensor,\n sparsity_layout_from: torch.Tensor, sparsity_reverse_lut_from: torch.Tensor, sparsity_block_size_from: int,\n sparsity_layout_to: torch.Tensor, sparsity_lut_to: torch.Tensor, sparsity_block_size_to: int,\n n_sparse_blocks_to: int, min_sparsity_block_size: int, triton_block_size: int) -> torch.Tensor:\n output = torch.zeros(size=(n_sparse_blocks_to, sparsity_block_size_to, sparsity_block_size_to),\n dtype=x.dtype, device=x.device)\n\n x_b, x_r, x_c = x.size()\n x_b_s, x_r_s, x_c_s = x.stride()\n s_l_x_b, s_l_x_r, s_l_x_c = sparsity_layout_from.size()\n s_l_x_b_s, s_l_x_r_s, s_l_x_c_s = sparsity_layout_from.stride()\n o_b, o_r, o_c = output.size()\n o_b_s, o_r_s, o_c_s = output.stride()\n s_lut_o_r, s_lut_o_c = sparsity_lut_to.size()\n s_lut_o_r_s, s_lut_o_c_s = sparsity_lut_to.stride()\n\n if triton_block_size is None:\n triton_block_size = get_triton_block_size(min_sparsity_block_size)\n\n triton_grid = lambda meta: [o_b,\n triton.cdiv(o_r, meta[\"TRITON_BLOCK_SIZE\"]),\n triton.cdiv(o_c, meta[\"TRITON_BLOCK_SIZE\"])]\n\n (_BlocksparseAdaptLayout.kernel_adapt_layout[triton_grid]\n (x,\n x_b, x_b_s, x_r_s, x_c_s,\n s_l_x_b, s_l_x_b_s, s_l_x_r_s, s_l_x_c_s,\n sparsity_reverse_lut_from,\n output,\n o_b, o_b_s, o_r_s, o_c_s,\n sparsity_lut_to, s_lut_o_r, s_lut_o_r_s, s_lut_o_c_s,\n sparsity_block_size_from,\n sparsity_block_size_to,\n triton_block_size))\n\n ctx.save_for_backward(x, sparsity_layout_from, sparsity_layout_to)\n ctx.sparsity_block_size_from = sparsity_block_size_from\n ctx.sparsity_block_size_to = sparsity_block_size_to\n ctx.triton_block_size = triton_block_size\n\n return output\n\n @staticmethod\n @triton.jit\n def kernel_adapt_layout(x,\n x_b, x_b_s, x_r_s, x_c_s,\n s_l_x_b, s_l_x_b_s, s_l_x_r_s, s_l_x_c_s,\n r_lut_x,\n o,\n o_b, o_b_s, o_r_s, o_c_s,\n s_lut_o, s_lut_o_r, s_lut_o_r_s, s_lut_o_c_s,\n sparsity_block_size_from,\n sparsity_block_size_to,\n TRITON_BLOCK_SIZE: tl.constexpr) -> None:\n pid_blk = tl.program_id(axis=0)\n pid_row = tl.program_id(axis=1)\n pid_col = tl.program_id(axis=2)\n\n spa_bat_o_idx = (pid_blk * s_lut_o_r_s + 0 * s_lut_o_c_s)\n spa_bat_o_msk = (spa_bat_o_idx < s_lut_o_r * s_lut_o_r_s)\n spa_bat_o = tl.load(s_lut_o + spa_bat_o_idx, mask=spa_bat_o_msk)\n\n spa_row_o_idx = (pid_blk * s_lut_o_r_s + 1 * s_lut_o_c_s)\n spa_row_o_msk = (spa_row_o_idx < s_lut_o_r * s_lut_o_r_s)\n spa_row_o = tl.load(s_lut_o + spa_row_o_idx, mask=spa_row_o_msk)\n\n spa_col_o_idx = (pid_blk * s_lut_o_r_s + 2 * s_lut_o_c_s)\n spa_col_o_msk = (spa_col_o_idx < s_lut_o_r * s_lut_o_r_s)\n spa_col_o = tl.load(s_lut_o + spa_col_o_idx, mask=spa_col_o_msk)\n\n spa_bat_x = spa_bat_o\n spa_row_x = (spa_row_o * sparsity_block_size_to + pid_row * TRITON_BLOCK_SIZE) // sparsity_block_size_from\n spa_col_x = (spa_col_o * sparsity_block_size_to + pid_col * TRITON_BLOCK_SIZE) // sparsity_block_size_from\n\n rev_idx_spa_x_idx = (spa_bat_x * s_l_x_b_s +\n spa_row_x * s_l_x_r_s +\n spa_col_x * s_l_x_c_s)\n rev_idx_spa_x_msk = (rev_idx_spa_x_idx < s_l_x_b * s_l_x_b_s)\n rev_idx_spa_x = tl.load(r_lut_x + rev_idx_spa_x_idx, mask=rev_idx_spa_x_msk).to(tl.int32)\n\n if rev_idx_spa_x >= 0:\n shift_row_x = ((spa_row_o * sparsity_block_size_to + pid_row * TRITON_BLOCK_SIZE)\n % sparsity_block_size_from) // TRITON_BLOCK_SIZE\n shift_col_x = ((spa_col_o * sparsity_block_size_to + pid_col * TRITON_BLOCK_SIZE)\n % sparsity_block_size_from) // TRITON_BLOCK_SIZE\n\n blk_x_idx = ((rev_idx_spa_x * x_b_s) +\n ((shift_row_x * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +\n ((shift_col_x * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])\n blk_x_msk = (blk_x_idx < x_b * x_b_s)\n blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)\n\n blk_o_idx = ((pid_blk * o_b_s) +\n ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +\n ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])\n blk_o_msk = (blk_o_idx < o_b * o_b_s)\n tl.store(o + blk_o_idx, blk_x, mask=blk_o_msk)\n", - "description_1": "Use triton language to implement a kernel that converts block-sparse tensors between compressed and regular forms, and adapts the sparsity layout. The functions include three kernels: 'kernel_blocksparse_to_dense' for converting compressed to dense format, 'kernel_blocksparse_to_sparse' for converting dense to compressed format, and 'kernel_adapt_layout' for adapting the sparsity layout. Each function uses parameters for tensor data, sparsity layout, block size, and output storage.", - "description_2": "Use triton language to implement kernels for converting block-sparse tensors and adapting sparsity layouts, including parameters for tensor manipulation.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nfrom torch import Tensor\nfrom triton import language as tl\nfrom blksprs.utils.tools import get_triton_block_size\n\nclass _BlocksparseGather(torch.autograd.Function):\n\n @staticmethod\n def forward(ctx, x: Tensor, sparsity_layout_x: Tensor, sparsity_reverse_lut_x: Tensor,\n i: Tensor, sparsity_layout_i: Tensor, sparsity_lut_i: Tensor,\n sparsity_block_size: int, triton_block_size: int = None) -> Tensor:\n output = torch.empty_like(i, dtype=x.dtype)\n\n x_b, x_r, x_c = x.size()\n x_b_s, x_r_s, x_c_s = x.stride()\n s_l_x_b, s_l_x_r, s_l_x_c = sparsity_layout_x.size()\n s_l_x_b_s, s_l_x_r_s, s_l_x_c_s = sparsity_layout_x.stride()\n i_b, i_r, i_c = i.size()\n i_b_s, i_r_s, i_c_s = i.stride()\n s_lut_i_r, s_lut_i_c = sparsity_lut_i.size()\n s_lut_i_r_s, s_lut_i_c_s = sparsity_lut_i.stride()\n o_b, o_r, o_c = output.size()\n o_b_s, o_r_s, o_c_s = output.stride()\n\n if triton_block_size is None:\n triton_block_size = get_triton_block_size(sparsity_block_size)\n\n triton_grid = lambda meta: [o_b,\n triton.cdiv(o_r, meta[\"TRITON_BLOCK_SIZE\"]),\n triton.cdiv(o_c, meta[\"TRITON_BLOCK_SIZE\"])]\n\n (_BlocksparseGather.kernel_blocksparse_gather[triton_grid]\n (x,\n x_b, x_b_s, x_r_s, x_c_s,\n s_l_x_b, s_l_x_b_s, s_l_x_r_s, s_l_x_c_s,\n sparsity_reverse_lut_x,\n i,\n i_b, i_b_s, i_r_s, i_c_s,\n output,\n o_b, o_b_s, o_r_s, o_c_s,\n sparsity_lut_i, s_lut_i_r, s_lut_i_r_s, s_lut_i_c_s,\n sparsity_block_size,\n triton_block_size))\n\n ctx.save_for_backward(sparsity_layout_x, i, sparsity_layout_i)\n ctx.sparsity_block_size = sparsity_block_size\n ctx.triton_block_size = triton_block_size\n\n return output\n\n @staticmethod\n @triton.jit\n def kernel_blocksparse_gather(x,\n x_b, x_b_s, x_r_s, x_c_s,\n s_l_x_b, s_l_x_b_s, s_l_x_r_s, s_l_x_c_s,\n r_lut_x,\n i,\n i_b, i_b_s, i_r_s, i_c_s,\n o,\n o_b, o_b_s, o_r_s, o_c_s,\n s_lut_o, s_lut_o_r, s_lut_o_r_s, s_lut_o_c_s,\n sparsity_block_size,\n TRITON_BLOCK_SIZE: tl.constexpr) -> None:\n # Get triton block indices\n pid_blk = tl.program_id(axis=0)\n pid_row = tl.program_id(axis=1)\n pid_col = tl.program_id(axis=2)\n\n # Get position of current sparsity block consisting of its batch, row, and column index\n spa_bat_o_idx = (pid_blk * s_lut_o_r_s + 0 * s_lut_o_c_s)\n spa_bat_o_msk = (spa_bat_o_idx < s_lut_o_r * s_lut_o_r_s)\n spa_bat_o = tl.load(s_lut_o + spa_bat_o_idx, mask=spa_bat_o_msk)\n\n spa_row_o_idx = (pid_blk * s_lut_o_r_s + 1 * s_lut_o_c_s)\n spa_row_o_msk = (spa_row_o_idx < s_lut_o_r * s_lut_o_r_s)\n spa_row_o = tl.load(s_lut_o + spa_row_o_idx, mask=spa_row_o_msk)\n\n # Load index values\n blk_i_idx = ((pid_blk * i_b_s) +\n ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_r_s)[:, None] +\n ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_c_s)[None, :])\n blk_i_msk = (blk_i_idx < i_b * i_b_s)\n blk_i = tl.load(i + blk_i_idx, mask=blk_i_msk).to(tl.int32)\n\n # Get positions of sparsity blocks\n pos_spa_blk_x = blk_i // sparsity_block_size\n pos_spa_col_x = blk_i % sparsity_block_size\n\n # Load reverse sparsity indices for x\n rev_idx_spa_x_idx = ((spa_bat_o * s_l_x_b_s) +\n (spa_row_o * s_l_x_r_s) +\n (pos_spa_blk_x * s_l_x_c_s))\n rev_idx_spa_x_msk = (rev_idx_spa_x_idx < s_l_x_b * s_l_x_b_s)\n rev_idx_spa_x = tl.load(r_lut_x + rev_idx_spa_x_idx, mask=rev_idx_spa_x_msk).to(tl.int32)\n\n # Load x values\n blk_x_idx = ((rev_idx_spa_x * x_b_s) +\n ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +\n (pos_spa_col_x * x_c_s))\n blk_x_msk = (blk_x_idx < x_b * x_b_s)\n blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)\n\n # Store output\n blk_o_idx = ((pid_blk * o_b_s) +\n ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +\n ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])\n blk_o_msk = (blk_o_idx < o_b * o_b_s)\n tl.store(o + blk_o_idx, blk_x, mask=blk_o_msk)\n\nclass _BlocksparseScatterReduce(torch.autograd.Function):\n\n @staticmethod\n def forward(ctx, x: Tensor, sparsity_layout_x: Tensor, sparsity_lut_x: Tensor,\n i: Tensor,\n sparsity_layout_o: Tensor, sparsity_reverse_lut_o: Tensor,\n sparsity_block_size: int, n_sparse_blocks: int,\n reduce_op: str, triton_block_size: int) -> Tensor:\n output = torch.zeros(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size),\n dtype=x.dtype, device=x.device)\n\n x_b, x_r, x_c = x.size()\n x_b_s, x_r_s, x_c_s = x.stride()\n s_lut_x_r, s_lut_x_c = sparsity_lut_x.size()\n s_lut_x_r_s, s_lut_x_c_s = sparsity_lut_x.stride()\n i_b, i_r, i_c = i.size()\n i_b_s, i_r_s, i_c_s = i.stride()\n o_b, o_r, o_c = output.size()\n o_b_s, o_r_s, o_c_s = output.stride()\n s_l_o_b, s_l_o_r, s_l_o_c = sparsity_layout_o.size()\n s_l_o_b_s, s_l_o_r_s, s_l_o_c_s = sparsity_layout_o.stride()\n\n if triton_block_size is None:\n triton_block_size = get_triton_block_size(sparsity_block_size)\n\n triton_grid = lambda meta: [x_b,\n triton.cdiv(x_r, meta[\"TRITON_BLOCK_SIZE\"]),\n triton.cdiv(x_c, meta[\"TRITON_BLOCK_SIZE\"])]\n\n reduce_op_ind = 0\n if reduce_op == \"sum\":\n reduce_op_ind = 1\n\n (_BlocksparseScatterReduce.kernel_blocksparse_scatter[triton_grid]\n (x,\n x_b, x_b_s, x_r_s, x_c_s,\n sparsity_lut_x, s_lut_x_r, s_lut_x_r_s, s_lut_x_c_s,\n i,\n i_b, i_b_s, i_r_s, i_c_s,\n output,\n o_b, o_b_s, o_r_s, o_c_s,\n s_l_o_b, s_l_o_b_s, s_l_o_r_s, s_l_o_c_s,\n sparsity_reverse_lut_o,\n reduce_op_ind,\n sparsity_block_size,\n triton_block_size))\n\n ctx.save_for_backward(sparsity_layout_x, i, sparsity_layout_o)\n ctx.sparsity_block_size = sparsity_block_size\n ctx.reduce_op = reduce_op\n ctx.triton_block_size = triton_block_size\n\n return output\n\n @staticmethod\n def backward(ctx, grad_output):\n sparsity_layout_x, i, sparsity_layout_o = ctx.saved_tensors\n sparsity_block_size = ctx.sparsity_block_size\n reduce_op = ctx.reduce_op\n triton_block_size = ctx.triton_block_size\n\n if reduce_op == \"sum\":\n return gather(grad_output, sparsity_layout_o, i, sparsity_layout_x, sparsity_block_size,\n triton_block_size=triton_block_size), None, None, None, None, None, None, None, None, None\n else:\n raise ValueError(f\"Reduction operation '{reduce_op}' does not support backward pass\")\n\n @staticmethod\n @triton.jit\n def kernel_blocksparse_scatter(x,\n x_b, x_b_s, x_r_s, x_c_s,\n s_lut_x, s_lut_x_r, s_lut_x_r_s, s_lut_x_c_s,\n i,\n i_b, i_b_s, i_r_s, i_c_s,\n o,\n o_b, o_b_s, o_r_s, o_c_s,\n s_l_o_b, s_l_o_b_s, s_l_o_r_s, s_l_o_c_s,\n r_lut_o,\n reduce_op_ind,\n sparsity_block_size,\n TRITON_BLOCK_SIZE: tl.constexpr) -> None:\n # Get triton block indices\n pid_blk = tl.program_id(axis=0)\n pid_row = tl.program_id(axis=1)\n pid_col = tl.program_id(axis=2)\n\n # Get position of current sparsity block consisting of its batch, row, and column index\n spa_bat_x_idx = (pid_blk * s_lut_x_r_s + 0 * s_lut_x_c_s)\n spa_bat_x_msk = (spa_bat_x_idx < s_lut_x_r * s_lut_x_r_s)\n spa_bat_x = tl.load(s_lut_x + spa_bat_x_idx, mask=spa_bat_x_msk)\n\n spa_row_x_idx = (pid_blk * s_lut_x_r_s + 1 * s_lut_x_c_s)\n spa_row_x_msk = (spa_row_x_idx < s_lut_x_r * s_lut_x_r_s)\n spa_row_x = tl.load(s_lut_x + spa_row_x_idx, mask=spa_row_x_msk)\n\n # Load x values\n blk_x_idx = ((pid_blk * x_b_s) +\n ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +\n ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])\n blk_x_msk = (blk_x_idx < x_b * x_b_s)\n blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)\n\n # Load index values\n blk_i_idx = ((pid_blk * i_b_s) +\n ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_r_s)[:, None] +\n ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_c_s)[None, :])\n blk_i_msk = (blk_i_idx < i_b * i_b_s)\n blk_i = tl.load(i + blk_i_idx, mask=blk_i_msk).to(tl.int32)\n\n # Get positions of sparsity blocks\n pos_spa_blk_o = blk_i // sparsity_block_size\n pos_spa_col_o = blk_i % sparsity_block_size\n\n # Load reverse sparsity indices for o\n rev_idx_spa_o_idx = ((spa_bat_x * s_l_o_b_s) +\n (spa_row_x * s_l_o_r_s) +\n (pos_spa_blk_o * s_l_o_c_s))\n rev_idx_spa_o_msk = (rev_idx_spa_o_idx < s_l_o_b * s_l_o_b_s)\n rev_idx_spa_o = tl.load(r_lut_o + rev_idx_spa_o_idx, mask=rev_idx_spa_o_msk).to(tl.int32)\n\n # Store output\n blk_o_idx = ((rev_idx_spa_o * o_b_s) +\n ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +\n (pos_spa_col_o * o_c_s))\n blk_o_msk = (blk_o_idx < o_b * o_b_s)\n\n if reduce_op_ind == 0:\n tl.store(o + blk_o_idx, blk_x, mask=blk_o_msk)\n elif reduce_op_ind == 1:\n tl.atomic_add(o + blk_o_idx, blk_x, mask=blk_o_msk)\n", - "description_1": "Use triton language to create two kernels: a blocksparse gather and a blocksparse scatter reduce. The gather kernel requires parameters for input tensor x, indices tensor i, output tensor o, and related sparsity information. The scatter reduce kernel needs similar parameters with a focus on reducing operations with either 'none' or 'sum'.", - "description_2": "Use triton language to implement a blocksparse gather operation and a blocksparse scatter reduce operation, including necessary sparsity layouts and tensor manipulations.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nfrom torch import Tensor\nfrom triton import language as tl\nfrom blksprs.utils.tools import get_triton_block_size\n\ndef exp(x: Tensor, sparsity_block_size: int, triton_block_size: int = None) -> Tensor:\n x = x.contiguous()\n return _BlocksparseExp.apply(x, sparsity_block_size, triton_block_size)\n\nclass _BlocksparseExp(torch.autograd.Function):\n\n @staticmethod\n def forward(ctx, x: Tensor, sparsity_block_size: int, triton_block_size: int) -> Tensor:\n output = torch.empty_like(x)\n x_b, x_r, x_c = x.shape\n x_b_s, x_r_s, x_c_s = x.stride()\n o_b, o_r, o_c = output.shape\n o_b_s, o_r_s, o_c_s = output.stride()\n if triton_block_size is None:\n triton_block_size = get_triton_block_size(sparsity_block_size)\n triton_grid = lambda meta: [o_b,\n triton.cdiv(o_r, meta[\"TRITON_BLOCK_SIZE\"]),\n triton.cdiv(o_c, meta[\"TRITON_BLOCK_SIZE\"])]\n (_BlocksparseExp.kernel_blocksparse_exp[triton_grid]\n (x,\n x_b, x_b_s, x_r_s, x_c_s,\n output,\n o_b, o_b_s, o_r_s, o_c_s,\n triton_block_size))\n ctx.save_for_backward(output)\n return output\n\n @staticmethod\n @triton.jit\n def kernel_blocksparse_exp(x,\n x_b, x_b_s, x_r_s, x_c_s,\n o,\n o_b, o_b_s, o_r_s, o_c_s,\n TRITON_BLOCK_SIZE: tl.constexpr) -> None:\n pid_blk = tl.program_id(axis=0)\n pid_row = tl.program_id(axis=1)\n pid_col = tl.program_id(axis=2)\n blk_x_idx = ((pid_blk * x_b_s) +\n ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +\n ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])\n blk_x_msk = (blk_x_idx < x_b * x_b_s)\n blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)\n buf = tl.exp(blk_x)\n blk_o_idx = ((pid_blk * o_b_s) +\n ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +\n ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])\n blk_o_msk = (blk_o_idx < o_b * o_b_s)\n tl.store(o + blk_o_idx, buf, mask=blk_o_msk)\n", - "description_1": "Use triton language to implement a block-sparse exponential function for tensors. The function requires three main components: 1) A Python function 'exp' that sets up input tensor parameters and calls a custom torch autograd function. 2) An autograd function '_BlocksparseExp' with a 'forward' method that prepares tensor metadata and invokes the triton kernel. 3) A triton kernel 'kernel_blocksparse_exp' that performs element-wise exponential operations on block indices, using masks to handle tensor boundaries. The exp function handles tensors in compressed sparse row format, facilitating computation on large, sparse datasets.", - "description_2": "Use triton language to define a block-sparse matrix operation that computes element-wise exponentials efficiently on large tensors, leveraging triton kernels for GPU acceleration in PyTorch.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nfrom triton import language as tl\nfrom torch import Tensor\nfrom blksprs.utils.tools import get_triton_block_size\n\nclass _BlocksparseMatmulSSS(torch.autograd.Function):\n\n @staticmethod\n @triton.jit\n def kernel_blocksparse_matmul_sss(x,\n x_b, x_b_s, x_r_s, x_c_s,\n s_l_x_b, s_l_x_b_s, s_l_x_r_s, s_l_x_c, s_l_x_c_s,\n r_lut_x,\n y,\n y_b, y_b_s, y_r_s, y_c_s,\n s_l_y_b, s_l_y_b_s, s_l_y_r_s, s_l_y_c_s,\n r_lut_y,\n o,\n o_b, o_b_s, o_r_s, o_c_s,\n s_lut_o,\n s_lut_o_r, s_lut_o_r_s,\n s_lut_o_c_s,\n sparsity_block_size,\n TRITON_BLOCK_SIZE: tl.constexpr) -> None:\n # Get triton block indices\n pid_blk = tl.program_id(axis=0)\n pid_row = tl.program_id(axis=1)\n pid_col = tl.program_id(axis=2)\n\n # Get position of current sparsity block consisting of its batch, row, and column index\n spa_bat_o_idx = (pid_blk * s_lut_o_r_s + 0 * s_lut_o_c_s)\n spa_bat_o_msk = (spa_bat_o_idx < s_lut_o_r * s_lut_o_r_s)\n spa_bat_o = tl.load(s_lut_o + spa_bat_o_idx, mask=spa_bat_o_msk)\n\n spa_row_o_idx = (pid_blk * s_lut_o_r_s + 1 * s_lut_o_c_s)\n spa_row_o_msk = (spa_row_o_idx < s_lut_o_r * s_lut_o_r_s)\n spa_row_o = tl.load(s_lut_o + spa_row_o_idx, mask=spa_row_o_msk)\n\n spa_col_o_idx = (pid_blk * s_lut_o_r_s + 2 * s_lut_o_c_s)\n spa_col_o_msk = (spa_col_o_idx < s_lut_o_r * s_lut_o_r_s)\n spa_col_o = tl.load(s_lut_o + spa_col_o_idx, mask=spa_col_o_msk)\n\n # Setup buffer\n buf = tl.zeros(shape=(TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE), dtype=tl.float32)\n\n # Slide over triton block sized segments of input tensors\n for i_seg_tri in range(0, tl.cdiv(s_l_x_c * sparsity_block_size, TRITON_BLOCK_SIZE)):\n # Convert to segment index of sparsity layout\n i_seg_spa = (i_seg_tri * TRITON_BLOCK_SIZE) // sparsity_block_size\n # Calculate the triton segment index within a block\n i_seg_tri_mod = i_seg_tri % (sparsity_block_size // TRITON_BLOCK_SIZE)\n\n # Get reverse sparsity indices for input tensors x and y\n # These are either -1 if the block is empty or equal to the index of the block in the sparse tensor\n\n # Get reverse sparsity indices for x\n rev_idx_spa_x_idx = (spa_bat_o * s_l_x_b_s +\n spa_row_o * s_l_x_r_s +\n i_seg_spa * s_l_x_c_s)\n rev_idx_spa_x_msk = (rev_idx_spa_x_idx < s_l_x_b * s_l_x_b_s)\n rev_idx_spa_x = tl.load(r_lut_x + rev_idx_spa_x_idx, mask=rev_idx_spa_x_msk).to(tl.int32)\n\n # Get reverse sparsity indices for y\n rev_idx_spa_y_idx = (spa_bat_o * s_l_y_b_s + i_seg_spa * s_l_y_r_s + spa_col_o * s_l_y_c_s)\n rev_idx_spa_y_msk = (rev_idx_spa_y_idx < s_l_y_b * s_l_y_b_s)\n rev_idx_spa_y = tl.load(r_lut_y + rev_idx_spa_y_idx, mask=rev_idx_spa_y_msk).to(tl.int32)\n\n # If both blocks are present commence calculation\n if rev_idx_spa_x >= 0 and rev_idx_spa_y >= 0:\n blk_x_idx = ((rev_idx_spa_x * x_b_s) +\n ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +\n ((i_seg_tri_mod * TRITON_BLOCK_SIZE +\n tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])\n blk_x_msk = (blk_x_idx < x_b * x_b_s)\n blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)\n\n blk_y_idx = ((rev_idx_spa_y * y_b_s) +\n ((i_seg_tri_mod * TRITON_BLOCK_SIZE +\n tl.arange(0, TRITON_BLOCK_SIZE)) * y_r_s)[:, None] +\n ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * y_c_s)[None, :])\n blk_y_msk = (blk_y_idx < y_b * y_b_s)\n blk_y = tl.load(y + blk_y_idx, mask=blk_y_msk)\n\n # Perform matrix multiplication\n buf += tl.dot(blk_x, blk_y, input_precision=\"tf32\")\n\n # Store output\n blk_o_idx = ((pid_blk * o_b_s) +\n ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +\n ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])\n blk_o_msk = (blk_o_idx < o_b * o_b_s)\n tl.store(o + blk_o_idx, buf, mask=blk_o_msk)\n\n @staticmethod\n def forward(ctx, x: Tensor, y: Tensor,\n sparsity_layout_x: Tensor, sparsity_reverse_lut_x: Tensor,\n sparsity_layout_y: Tensor, sparsity_reverse_lut_y: Tensor,\n sparsity_layout_o: Tensor, sparsity_lut_o: Tensor,\n sparsity_block_size: int, n_sparse_blocks: int, triton_block_size: int) -> Tensor:\n output = torch.empty(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size),\n dtype=x.dtype, device=x.device)\n\n x_b, x_r, x_c = x.size()\n x_b_s, x_r_s, x_c_s = x.stride()\n s_l_x_b, s_l_x_r, s_l_x_c = sparsity_layout_x.size()\n s_l_x_b_s, s_l_x_r_s, s_l_x_c_s = sparsity_layout_x.stride()\n y_b, y_r, y_c = y.size()\n y_b_s, y_r_s, y_c_s = y.stride()\n s_l_y_b, s_l_y_r, s_l_y_c = sparsity_layout_y.size()\n s_l_y_b_s, s_l_y_r_s, s_l_y_c_s = sparsity_layout_y.stride()\n o_b, o_r, o_c = output.size()\n o_b_s, o_r_s, o_c_s = output.stride()\n s_lut_o_r, s_lut_o_c = sparsity_lut_o.size()\n s_lut_o_r_s, s_lut_o_c_s = sparsity_lut_o.stride()\n\n if triton_block_size is None:\n triton_block_size = get_triton_block_size(sparsity_block_size)\n\n triton_grid = lambda meta: [o_b,\n triton.cdiv(o_r, meta[\"TRITON_BLOCK_SIZE\"]),\n triton.cdiv(o_c, meta[\"TRITON_BLOCK_SIZE\"])]\n\n (_BlocksparseMatmulSSS.kernel_blocksparse_matmul_sss[triton_grid]\n (x,\n x_b, x_b_s, x_r_s, x_c_s,\n s_l_x_b, s_l_x_b_s, s_l_x_r_s, s_l_x_c, s_l_x_c_s,\n sparsity_reverse_lut_x,\n y,\n y_b, y_b_s, y_r_s, y_c_s,\n s_l_y_b, s_l_y_b_s, s_l_y_r_s, s_l_y_c_s,\n sparsity_reverse_lut_y,\n output,\n o_b, o_b_s, o_r_s, o_c_s,\n sparsity_lut_o,\n s_lut_o_r, s_lut_o_r_s, s_lut_o_c_s,\n sparsity_block_size,\n triton_block_size))\n\n ctx.save_for_backward(x, sparsity_layout_x, y, sparsity_layout_y, sparsity_layout_o)\n ctx.sparsity_block_size = sparsity_block_size\n ctx.triton_block_size = triton_block_size\n\n return output\n", - "description_1": "Use triton language to implement a kernel for performing block-sparse matrix multiplication. The kernel, decorated with @triton.jit, requires parameters for input tensors x and y, their corresponding sizes and strides, sparsity layouts and reverse lookup tables, an output tensor, its size and stride, lookup table for sparsity, sparsity block size, and a constexpr for TRITON_BLOCK_SIZE.", - "description_2": "Use triton language to execute a kernel that computes block-sparse matrix multiplication using parameters such as input tensors, sparsity layouts, and block sizes.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nfrom triton import language as tl\n\ndef softmax(x: torch.Tensor, sparsity_layout: torch.Tensor, sparsity_block_size: int, triton_block_size: int = None) -> torch.Tensor:\n x = x.contiguous()\n sparsity_lut = torch.nonzero(sparsity_layout).contiguous()\n\n sparsity_layout_rws, _ = torch.max(sparsity_layout, dim=-1, keepdim=True)\n sparsity_layout_rws_flat = sparsity_layout_rws.reshape(-1)\n sparsity_reverse_lut_rws = ((torch.cumsum(sparsity_layout_rws_flat, dim=-1) - 1) *\n (sparsity_layout_rws_flat == 1) -\n (1 * (sparsity_layout_rws_flat == 0)))\n\n return _BlocksparseSoftmax.apply(x, sparsity_layout,\n sparsity_lut,\n sparsity_reverse_lut_rws,\n sparsity_block_size, triton_block_size)\n\n\nclass _BlocksparseSoftmax(torch.autograd.Function):\n\n @staticmethod\n def forward(ctx, x: torch.Tensor, sparsity_layout: torch.Tensor,\n sparsity_lut: torch.Tensor,\n sparsity_reverse_lut_rws: torch.Tensor,\n sparsity_block_size: int, triton_block_size: int) -> torch.Tensor:\n output = torch.empty_like(x)\n x_b, x_r, x_c = x.size()\n x_b_s, x_r_s, x_c_s = x.stride()\n s_lut_r, s_lut_c = sparsity_lut.size()\n s_lut_r_s, s_lut_c_s = sparsity_lut.stride()\n o_b, o_r, o_c = output.size()\n\n x_row_wise_max, sparsity_layout_rwm = row_wise_max(x, sparsity_layout, sparsity_block_size,\n flag_slice_only=True,\n triton_block_size=triton_block_size)\n x_scaled = row_wise_sub(x, sparsity_layout, x_row_wise_max, sparsity_block_size, triton_block_size)\n x_exp = exp(x_scaled, sparsity_block_size, triton_block_size=triton_block_size)\n x_exp_row_wise_sum, sparsity_layout_rws = row_wise_sum(x_exp, sparsity_layout, sparsity_block_size,\n flag_slice_only=True,\n triton_block_size=triton_block_size)\n\n s_b, s_r, s_c = x_exp_row_wise_sum.shape\n s_b_s, s_r_s, s_c_s = x_exp_row_wise_sum.stride()\n s_l_s_b, s_l_s_r, s_l_s_c = sparsity_layout_rws.shape\n s_l_s_b_s, s_l_s_r_s, s_l_s_c_s = sparsity_layout_rws.stride()\n\n if triton_block_size is None:\n triton_block_size = get_triton_block_size(sparsity_block_size)\n\n triton_grid = lambda meta: [o_b,\n triton.cdiv(o_r, meta[\"TRITON_BLOCK_SIZE\"]),\n triton.cdiv(o_c, meta[\"TRITON_BLOCK_SIZE\"])]\n\n (_BlocksparseSoftmax.kernel_blocksparse_softmax[triton_grid]\n (x_exp,\n x_b, x_b_s, x_r_s, x_c_s,\n sparsity_lut, s_lut_r, s_lut_r_s, s_lut_c_s,\n x_exp_row_wise_sum, s_b, s_b_s, s_r_s, s_c_s,\n s_l_s_b, s_l_s_b_s, s_l_s_r_s,\n sparsity_reverse_lut_rws,\n output,\n triton_block_size))\n\n ctx.save_for_backward(output, sparsity_layout, sparsity_lut)\n ctx.sparsity_block_size = sparsity_block_size\n ctx.triton_block_size = triton_block_size\n\n return output\n\n @staticmethod\n @triton.jit\n def kernel_blocksparse_softmax(x,\n x_b, x_b_s, x_r_s, x_c_s,\n s_lut, s_lut_r, s_lut_r_s, s_lut_c_s,\n s, s_b, s_b_s, s_r_s, s_c_s,\n s_l_s_b, s_l_s_b_s, s_l_s_r_s,\n r_lut_s,\n o,\n TRITON_BLOCK_SIZE: tl.constexpr) -> None:\n pid_blk = tl.program_id(axis=0)\n pid_row = tl.program_id(axis=1)\n pid_col = tl.program_id(axis=2)\n\n spa_bat_idx = (pid_blk * s_lut_r_s + 0 * s_lut_c_s)\n spa_bat_msk = (spa_bat_idx < s_lut_r * s_lut_r_s)\n spa_bat = tl.load(s_lut + spa_bat_idx, mask=spa_bat_msk)\n\n spa_row_idx = (pid_blk * s_lut_r_s + 1 * s_lut_c_s)\n spa_row_msk = (spa_row_idx < s_lut_r * s_lut_r_s)\n spa_row = tl.load(s_lut + spa_row_idx, mask=spa_row_msk)\n\n rev_idx_spa_s_idx = (spa_bat * s_l_s_b_s +\n spa_row * s_l_s_r_s)\n rev_idx_spa_s_msk = (rev_idx_spa_s_idx < s_l_s_b * s_l_s_b_s)\n rev_idx_spa_s = tl.load(r_lut_s + rev_idx_spa_s_idx, mask=rev_idx_spa_s_msk).to(tl.int32)\n\n if rev_idx_spa_s == -1:\n assert False, \"Invalid sparsity block\"\n\n blk_x_idx = ((pid_blk * x_b_s) +\n ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +\n ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])\n blk_x_msk = (blk_x_idx < x_b * x_b_s)\n blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)\n\n blk_s_idx = (rev_idx_spa_s * s_b_s +\n ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * s_r_s)[:, None] +\n (tl.arange(0, 1) * s_c_s)[None, :])\n blk_s_msk = (blk_s_idx < s_b * s_b_s)\n blk_s = tl.load(s + blk_s_idx, mask=blk_s_msk)\n\n buf = tl.div_rn(blk_x, blk_s)\n tl.store(o + blk_x_idx, buf, mask=blk_x_msk)\n\n @staticmethod\n @triton.jit\n def kernel_blocksparse_softmax_grad_x(g,\n g_b, g_b_s, g_r_s, g_c_s,\n x,\n x_b, x_b_s, x_r_s, x_c_s,\n s_lut, s_lut_r, s_lut_r_s, s_lut_c_s,\n s,\n s_b, s_b_s, s_r_s, s_c_s,\n s_l_s_b, s_l_s_b_s, s_l_s_r_s,\n r_lut_s,\n o,\n o_b, o_b_s, o_r_s, o_c_s,\n TRITON_BLOCK_SIZE: tl.constexpr) -> None:\n pid_blk = tl.program_id(axis=0)\n pid_row = tl.program_id(axis=1)\n pid_col = tl.program_id(axis=2)\n\n spa_bat_idx = (pid_blk * s_lut_r_s + 0 * s_lut_c_s)\n spa_bat_msk = (spa_bat_idx < s_lut_r * s_lut_r_s)\n spa_bat = tl.load(s_lut + spa_bat_idx, mask=spa_bat_msk)\n\n spa_row_idx = (pid_blk * s_lut_r_s + 1 * s_lut_c_s)\n spa_row_msk = (spa_row_idx < s_lut_r * s_lut_r_s)\n spa_row = tl.load(s_lut + spa_row_idx, mask=spa_row_msk)\n\n rev_idx_spa_s_idx = (spa_bat * s_l_s_b_s +\n spa_row * s_l_s_r_s)\n rev_idx_spa_s_msk = (rev_idx_spa_s_idx < s_l_s_b * s_l_s_b_s)\n rev_idx_spa_s = tl.load(r_lut_s + rev_idx_spa_s_idx, mask=rev_idx_spa_s_msk).to(tl.int32)\n\n blk_s_idx = (rev_idx_spa_s * s_b_s +\n ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * s_r_s)[:, None] +\n (tl.arange(0, 1) * s_c_s)[None, :])\n blk_s_msk = (blk_s_idx < s_b * s_b_s)\n blk_s = tl.load(s + blk_s_idx, mask=blk_s_msk)\n\n blk_g_idx = ((pid_blk * g_b_s) +\n ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * g_r_s)[:, None] +\n ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * g_c_s)[None, :])\n blk_g_msk = (blk_g_idx < g_b * g_b_s)\n blk_g = tl.load(g + blk_g_idx, mask=blk_g_msk)\n\n blk_x_idx = ((pid_blk * x_b_s) +\n ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +\n ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])\n blk_x_msk = (blk_x_idx < x_b * x_b_s)\n blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)\n\n buf = blk_x * (blk_g - blk_s)\n\n blk_o_idx = ((pid_blk * o_b_s) +\n ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +\n ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])\n blk_o_msk = (blk_o_idx < o_b * o_b_s)\n tl.store(o + blk_o_idx, buf, mask=blk_o_msk)\n", - "description_1": "Use triton language to implement a block-sparse softmax operator. The main operator and its backward pass kernels take tensors for the input, sparsity layout, lookup tables, and block sizes. The forward pass divides input blocks by sum blocks, while the backward computes gradient adjustments.", - "description_2": "Use triton language to create kernels for block-sparse softmax and its gradient calculation, handling tensors with specific sparsity patterns and block sizes.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nfrom torch import Tensor\nfrom triton import language as tl\n\nclass _BlocksparseTranspose(torch.autograd.Function):\n\n @staticmethod\n def forward(ctx, x: Tensor,\n sparsity_layout: Tensor, sparsity_lut: Tensor, sparsity_reverse_lut: Tensor, sparsity_block_size: int,\n n_sparse_blocks: int, triton_block_size: int) -> Tensor:\n output = torch.empty(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size),\n dtype=x.dtype, device=x.device)\n\n x_b, x_r, x_c = x.size()\n x_b_s, x_r_s, x_c_s = x.stride()\n s_l_b, s_l_r, s_l_c = sparsity_layout.size()\n s_l_b_s, s_l_r_s, s_l_c_s = sparsity_layout.stride()\n s_lut_r, s_lut_c = sparsity_lut.shape\n s_lut_r_s, s_lut_c_s = sparsity_lut.stride()\n o_b, o_r, o_c = output.size()\n o_b_s, o_r_s, o_c_s = output.stride()\n\n if triton_block_size is None:\n triton_block_size = get_triton_block_size(sparsity_block_size)\n\n triton_grid = lambda meta: [o_b,\n triton.cdiv(o_r, meta[\"TRITON_BLOCK_SIZE\"]),\n triton.cdiv(o_c, meta[\"TRITON_BLOCK_SIZE\"])]\n\n (_BlocksparseTranspose.kernel_blocksparse_transpose[triton_grid]\n (x,\n x_b, x_b_s, x_r_s, x_c_s,\n s_l_b, s_l_b_s, s_l_r_s, s_l_c_s,\n sparsity_lut, s_lut_r, s_lut_r_s, s_lut_c_s,\n sparsity_reverse_lut,\n output,\n o_b, o_b_s,\n triton_block_size))\n\n # Save for backward pass\n ctx.save_for_backward(sparsity_layout)\n ctx.sparsity_layout = sparsity_layout\n ctx.sparsity_block_size = sparsity_block_size\n ctx.triton_block_size = triton_block_size\n\n return output\n\n @staticmethod\n @triton.jit\n def kernel_blocksparse_transpose(x,\n x_b, x_b_s, x_r_s, x_c_s,\n s_l_b, s_l_b_s, s_l_r_s, s_l_c_s,\n s_lut, s_lut_r, s_lut_r_s, s_lut_c_s,\n r_lut,\n o,\n o_b, o_b_s,\n TRITON_BLOCK_SIZE: tl.constexpr) -> None:\n # Get triton block indices\n pid_blk = tl.program_id(axis=0)\n pid_row = tl.program_id(axis=1)\n pid_col = tl.program_id(axis=2)\n\n # Get sparsity index of current output block consisting of its batch, row, and column index\n spa_bat_idx = (pid_blk * s_lut_r_s + 0 * s_lut_c_s)\n spa_bat_msk = (spa_bat_idx < s_lut_r * s_lut_r_s)\n spa_bat = tl.load(s_lut + spa_bat_idx, mask=spa_bat_msk)\n\n spa_row_idx = (pid_blk * s_lut_r_s + 1 * s_lut_c_s)\n spa_row_msk = (spa_row_idx < s_lut_r * s_lut_r_s)\n spa_row = tl.load(s_lut + spa_row_idx, mask=spa_row_msk)\n\n spa_col_idx = (pid_blk * s_lut_r_s + 2 * s_lut_c_s)\n spa_col_msk = (spa_col_idx < s_lut_r * s_lut_r_s)\n spa_col = tl.load(s_lut + spa_col_idx, mask=spa_col_msk)\n\n # Get reverse sparsity index\n rev_idx_spa_idx = (spa_bat * s_l_b_s +\n spa_row * s_l_r_s +\n spa_col * s_l_c_s)\n rev_idx_spa_msk = (rev_idx_spa_idx < s_l_b * s_l_b_s)\n rev_idx_spa = tl.load(r_lut + rev_idx_spa_idx, mask=rev_idx_spa_msk).to(tl.int32)\n\n if rev_idx_spa == -1:\n assert False, \"Invalid sparsity block\"\n\n blk_x_idx = (rev_idx_spa * x_b_s +\n ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +\n ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])\n blk_x_msk = (blk_x_idx < x_b * x_b_s)\n blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)\n\n blk_x_t = tl.trans(blk_x)\n\n blk_o_idx = (pid_blk * o_b_s +\n ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +\n ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])\n blk_o_msk = (blk_o_idx < o_b * o_b_s)\n tl.store(o + blk_o_idx, blk_x_t, mask=blk_o_msk)\n", - "description_1": "Use triton language to implement a kernel function 'kernel_blocksparse_transpose' that transposes blocks of a sparse tensor. The kernel takes 16 parameters: the input tensor 'x', its batch size 'x_b', and strides 'x_b_s', 'x_r_s', 'x_c_s', the sparsity layout dimensions 's_l_b', 's_l_b_s', 's_l_r_s', 's_l_c_s', the sparsity lookup table 's_lut', its dimensions 's_lut_r', 's_lut_r_s', 's_lut_c_s', the reverse lookup table 'r_lut', the output tensor 'o', its batch size 'o_b', and stride 'o_b_s'. The kernel uses a constant 'TRITON_BLOCK_SIZE' to determine block sizes for processing.", - "description_2": "Use triton language to create a function 'forward' that prepares and calls the 'kernel_blocksparse_transpose' kernel. It takes 7 parameters: the input tensor 'x', the sparsity layout 'sparsity_layout', the sparsity lookup table 'sparsity_lut', the reverse lookup table 'sparsity_reverse_lut', the sparsity block size 'sparsity_block_size', the number of sparse blocks 'n_sparse_blocks', and the triton block size 'triton_block_size'. The function initializes an output tensor and calculates grid dimensions for the kernel launch.", - "difficulty": 4 - }, - { - "code": "import triton\nimport triton.language as tl\nfrom torch import empty_strided_cuda\nfrom torch._inductor.runtime.triton_heuristics import grid\nfrom torch._C import _cuda_getCurrentRawStream as get_raw_stream\nfrom torch._inductor.utils import maybe_profile\nfrom torch._inductor.codegen.memory_planning import _align as align\nfrom torch._inductor.hooks import run_intermediate_hooks\nfrom torch._inductor.codecache import AsyncCompile\nfrom torch._inductor.select_algorithm import extern_kernels\nimport time\n\n# Triton kernel for a fused clone operation\n@triton.jit\ndef triton_(in_ptr0, out_ptr0, ynumel, xnumel, YBLOCK: tl.constexpr, XBLOCK: tl.constexpr):\n ynumel = 93161984\n xnumel = 4\n yoffset = tl.program_id(1) * (tl.program_id(2) + 1) * YBLOCK\n yindex = yoffset + tl.arange(0, YBLOCK)[None, :]\n ymask = yindex < ynumel\n xoffset = tl.program_id(0) * XBLOCK\n xindex = xoffset + tl.arange(0, XBLOCK)[:, None]\n xmask = xindex < xnumel\n x2 = xindex\n y0 = yindex % 128\n y1 = (yindex // 128)\n y3 = yindex\n tmp0 = tl.load(in_ptr0 + (y0 + (128 * x2) + (512 * y1)), xmask & ymask, eviction_policy='evict_last')\n tl.store(out_ptr0 + (x2 + (4 * y3)), tmp0, xmask & ymask)\n\n# Function to call the Triton kernel\ndef call(args):\n arg0_1, arg1_1, arg2_1 = args\n # args.clear()\n from torch._C._dynamo.guards import assert_size_stride\n assert_size_stride(arg0_1, (2, 4), (4, 1))\n assert_size_stride(arg1_1, (2,), (1,))\n assert_size_stride(arg2_1, (727828, 512), (512, 1))\n\n with torch.cuda._DeviceGuard(0):\n torch.cuda.set_device(0)\n buf0 = empty_strided_cuda((727828, 128, 4), (512, 4, 1), torch.float32)\n stream0 = get_raw_stream(0)\n triton_poi_fused_clone_0.run(arg2_1, buf0, 93161984, 4, grid=grid(93161984, 4), stream=stream0)\n\n return (buf0,)\n", - "description_1": "Use triton language to create a pointwise kernel for a fused clone operation that processes a large array. The kernel is designed to load elements from a source array into a destination array using 2D grid dimensions and processes elements in blocks defined by YBLOCK and XBLOCK constants. The function 'call' serves as a wrapper that initializes CUDA streams, handles device settings, and validates the input tensor's size and strides before invoking the kernel.", - "description_2": "Use triton language to implement a CUDA-accelerated fused clone operation that processes a large tensor using a block-based approach, with a wrapper function managing CUDA resources and input validation.", - "difficulty": 2 - }, - { - "code": "import triton\nimport triton.language as tl\nfrom torch._inductor.runtime.triton_heuristics import grid\nfrom torch._C import _cuda_getCurrentRawStream as get_raw_stream\nimport torch\nfrom torch import empty_strided_cuda\nfrom torch._C._dynamo.guards import assert_size_stride\n\n@triton.jit\ndef triton_(in_ptr0, out_ptr0, ynumel, xnumel, YBLOCK: tl.constexpr, XBLOCK: tl.constexpr):\n ynumel = 93161984\n xnumel = 4\n yoffset = (tl.program_id(1) + tl.program_id(2) * tl.num_programs(1)) * YBLOCK\n yindex = yoffset + tl.arange(0, YBLOCK)[None, :]\n ymask = yindex < ynumel\n xoffset = tl.program_id(0) * XBLOCK\n xindex = xoffset + tl.arange(0, XBLOCK)[:, None]\n xmask = xindex < xnumel\n if (tl.program_id(0) == 0 and tl.program_id(1) == 45488) and tl.program_id(2) == 0:\n tl.device_print(\"==>debug: yoffset: \", yoffset)\n tl.device_print(\"==>debug: yindex: \", yindex)\n tl.device_print(\"==>debug: xoffset: \", xoffset)\n tl.device_print(\"==>debug: xindex: \", xindex)\n x2 = xindex\n y0 = yindex % 128\n y1 = (yindex // 128)\n y3 = yindex\n tmp0 = tl.load(in_ptr0 + (y0 + (128*x2) + (512*y1)), xmask & ymask, eviction_policy='evict_last')\n tl.store(out_ptr0 + (x2 + (4*y3)), tmp0, xmask & ymask)\n\ndef call(args):\n arg0_1, = args\n args.clear()\n assert_size_stride(arg0_1, (727828, 512), (512, 1))\n with torch.cuda._DeviceGuard(0):\n torch.cuda.set_device(0)\n buf0 = empty_strided_cuda((727828, 128, 4), (512, 4, 1), torch.float32)\n stream0 = get_raw_stream(0)\n triton_poi_fused_clone_0.run(arg0_1, buf0, 93161984, 4, grid=grid(93161984, 4), stream=stream0)\n del arg0_1\n return (buf0, )\n", - "description_1": "Use triton language to implement a kernel function 'triton_' that takes 6 parameters: two pointers (in_ptr0, out_ptr0) for input and output data, two integers (ynumel, xnumel) representing the number of elements in y and x dimensions, and two constexpr integers (YBLOCK, XBLOCK) for block sizes. The kernel computes indices and masks for loading and storing data, and includes debug prints for specific program IDs. The 'call' function prepares input data, sets up the CUDA device and stream, and runs the kernel with specified grid dimensions.", - "description_2": "Use triton language to implement a kernel for data manipulation with index computation and conditional debug printing, and a Python function to execute this kernel on CUDA.", - "difficulty": 3 - }, - { - "code": "import triton\nimport triton.language as tl\nimport torch\n\n@triton.jit\ndef triton_(in_ptr0, out_ptr0, ynumel, xnumel, YBLOCK: tl.constexpr, XBLOCK: tl.constexpr):\n ynumel = 67108864\n xnumel = 4\n yoffset = tl.program_id(1) * (tl.program_id(2) + 1) * YBLOCK\n yindex = yoffset + tl.arange(0, YBLOCK)[None, :]\n ymask = yindex < ynumel\n xoffset = tl.program_id(0) * XBLOCK\n xindex = xoffset + tl.arange(0, XBLOCK)[:, None]\n xmask = xindex < xnumel\n x2 = xindex\n y0 = yindex % 128\n y1 = (yindex // 128)\n y3 = yindex\n tmp0 = tl.load(in_ptr0 + (y0 + (128 * x2) + (512 * y1)), xmask, eviction_policy='evict_last')\n tl.store(out_ptr0 + (x2 + (4 * y3)), tmp0, xmask)\n\ndef run_triton(x, y):\n x = torch.randn((67108864, 4), device=\"cuda\")\n y = torch.empty((32768, 4), dtype=torch.float32, device='cuda')\n triton_(x, y, YBLOCK=67108864, XBLOCK=4)\n return y\n", - "description_1": "Use triton language to define a kernel `triton_` with 6 parameters: `in_ptr0` (input tensor pointer), `out_ptr0` (output tensor pointer), `ynumel` (number of y elements), `xnumel` (number of x elements), `YBLOCK` (block size in y dimension as a compile-time constant), and `XBLOCK` (block size in x dimension as a compile-time constant). The kernel computes indices for y and x dimensions, applies masks for valid index ranges, and uses these indices to load data from the input pointer and store results to the output pointer. The function `run_triton` initializes input and output tensors on GPU, sets up tensor shapes, and calls the `triton_` kernel for execution.", - "description_2": "Use triton language to define a kernel and execute it on GPU tensors, handling index computation and data movement between input and output tensors.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n# Triton kernel to add elements of two input tensors\n@triton.jit\ndef add_kernel(\n in_ptr0, # Pointer to the first input tensor\n in_ptr1, # Pointer to the second input tensor\n out_ptr, # Pointer to the output tensor\n n_elements, # Number of elements to process\n BLOCK_SIZE: \"tl.constexpr\", # Size of each block\n):\n pid = tl.program_id(axis=0)\n block_start = pid * BLOCK_SIZE\n offsets = block_start + tl.arange(0, BLOCK_SIZE)\n mask = offsets < n_elements\n x = tl.load(in_ptr0 + offsets, mask=mask)\n y = tl.load(in_ptr1 + offsets, mask=mask)\n output = x + y\n tl.store(out_ptr + offsets, output, mask=mask)\n\nx = torch.randn(2, 2, device=\"cuda\")\nother = torch.randn(2, 2, device=\"cuda\")\n\n# Function to preprocess inputs and call the Triton kernel\ndef f(x, other):\n y = x.t().contiguous().t() # Transpose and make contiguous\n z = y.sin().t() # Apply sine and transpose\n grid = (z.numel(),)\n out = torch.empty_like(other)\n add_kernel[grid](z, other, out, z.numel(), BLOCK_SIZE=16)\n return out\n\nf_compile = torch.compile(f)\n\nout = f(x, other)\nout_compile = f_compile(x, other)\nprint(out)\nprint(out_compile)\nassert torch.allclose(out_compile, out)\n", - "description_1": "Use triton language to implement a kernel that performs element-wise addition of two input tensors. The kernel takes pointers to the input tensors, a pointer to the output tensor, the number of elements to process, and a block size as parameters. The kernel computes element-wise addition within blocks of the specified size, utilizing a mask to handle boundaries. Use torch to prepare and manipulate the inputs, and call the triton kernel with specified grid size.", - "description_2": "Use triton language to write a kernel for element-wise addition of tensors, handling block processing. Utilize torch for input manipulation and kernel invocation.", - "difficulty": 3 - }, - { - "code": "import triton\nimport triton.language as tl\nfrom torch._inductor.ir import ReductionHint\nfrom torch._inductor.triton_heuristics import reduction, pointwise, persistent_reduction\nfrom torch._inductor import triton_helpers\n\n@reduction(\n size_hints=[8192, 512],\n reduction_hint=ReductionHint.DEFAULT,\n meta={'signature': {0: '*i64', 1: '*bf16', 2: '*bf16', 3: '*bf16', 4: 'i32', 5: 'i32'}}\n)\n@triton.jit\ndef triton_(in_ptr0, in_ptr1, in_ptr2, out_ptr1, xnumel, rnumel, XBLOCK: tl.constexpr, RBLOCK: tl.constexpr):\n xnumel = 8192\n rnumel = 512\n xoffset = tl.program_id(0) * XBLOCK\n xindex = xoffset + tl.arange(0, XBLOCK)[:, None]\n xmask = xindex < xnumel\n rbase = tl.arange(0, RBLOCK)[None, :]\n x0 = xindex\n tmp0 = tl.load(in_ptr0 + (x0), None, eviction_policy='evict_last')\n _tmp5 = tl.full([XBLOCK, RBLOCK], 0, tl.float32)\n for roffset in range(0, rnumel, RBLOCK):\n rindex = roffset + rbase\n rmask = rindex < rnumel\n r1 = rindex\n tmp1 = triton_helpers.promote_to_tensor(tmp0)\n tl.device_assert((0 <= tmp1) & (tmp1 < 32128), \"index out of bounds: 0 <= tmp1 < 32128\")\n tmp2 = tl.load(in_ptr1 + (r1 + (512 * tmp0)), rmask, eviction_policy='evict_last', other=0).to(tl.float32)\n tmp3 = tmp2.to(tl.float32)\n tmp4 = tmp3 * tmp3\n tmp6 = _tmp5 + tmp4\n _tmp5 = tl.where(rmask, tmp6, _tmp5)\n tmp5 = tl.sum(_tmp5, 1)[:, None]\n for roffset in range(0, rnumel, RBLOCK):\n rindex = roffset + rbase\n rmask = rindex < rnumel\n r1 = rindex\n tmp7 = tl.load(in_ptr2 + (r1), rmask, eviction_policy='evict_last', other=0).to(tl.float32)\n tmp8 = triton_helpers.promote_to_tensor(tmp0)\n tl.device_assert((0 <= tmp8) & (tmp8 < 32128), \"index out of bounds: 0 <= tmp8 < 32128\")\n tmp9 = tl.load(in_ptr1 + (r1 + (512 * tmp0)), rmask, other=0).to(tl.float32)\n tmp10 = tmp9.to(tl.float32)\n tmp11 = 512.0\n tmp12 = tmp5 / tmp11\n tmp13 = 1e-06\n tmp14 = tmp12 + tmp13\n tmp15 = tl.math.rsqrt(tmp14)\n tmp16 = tmp10 * tmp15\n tmp17 = tmp16.to(tl.float32)\n tmp18 = tmp7 * tmp17\n tl.store(out_ptr1 + (r1 + (512 * x0)), tmp18, rmask)\n\n@pointwise(size_hints=[4194304], meta={'signature': {0: '*bf16', 1: '*bf16', 2: 'i32'}})\n@triton.jit\ndef triton_(in_ptr0, out_ptr0, xnumel, XBLOCK: tl.constexpr):\n xnumel = 4194304\n xoffset = tl.program_id(0) * XBLOCK\n xindex = xoffset + tl.arange(0, XBLOCK)[:]\n xmask = xindex < xnumel\n x0 = xindex % 64\n x1 = (xindex // 64) % 2048\n x2 = (xindex // 131072) % 8\n x3 = (xindex // 1048576)\n x4 = xindex\n tmp0 = tl.load(in_ptr0 + (x0 + (64 * x2) + (512 * x1) + (1048576 * x3)), None).to(tl.float32)\n tl.store(out_ptr0 + (x4), tmp0, None)\n\n@pointwise(size_hints=[2048, 2048], tile_hint=TileHint.SQUARE, meta={'signature': {0: '*bf16', 1: '*bf16', 2: 'i32', 3: 'i32'}})\n@triton.jit\ndef triton_(in_ptr0, out_ptr0, xnumel, ynumel, XBLOCK: tl.constexpr, YBLOCK: tl.constexpr):\n xnumel = 2048\n ynumel = 2048\n xoffset = tl.program_id(0) * XBLOCK\n xindex = xoffset + tl.arange(0, XBLOCK)[:, None]\n xmask = xindex < xnumel\n yoffset = tl.program_id(1) * YBLOCK\n yindex = yoffset + tl.arange(0, YBLOCK)[None, :]\n ymask = yindex < ynumel\n x0 = xindex % 512\n x1 = (xindex // 512)\n y2 = yindex\n x3 = xindex\n tmp0 = tl.load(in_ptr0 + (x0 + (512 * y2) + (1048576 * x1)), None).to(tl.float32)\n tl.store(out_ptr0 + (y2 + (2048 * x3)), tmp0, None)\n\n@reduction(\n size_hints=[65536, 2048],\n reduction_hint=ReductionHint.INNER,\n meta={'signature': {0: '*bf16', 1: '*bf16', 2: 'i32', 3: 'i32'}}\n)\n@triton.jit\ndef triton_(in_ptr0, out_ptr2, xnumel, rnumel, XBLOCK: tl.constexpr, RBLOCK: tl.constexpr):\n xnumel = 65536\n rnumel = 2048\n xoffset = tl.program_id(0) * XBLOCK\n xindex = xoffset + tl.arange(0, XBLOCK)[:, None]\n xmask = xindex < xnumel\n rbase = tl.arange(0, RBLOCK)[None, :]\n x0 = xindex\n _tmp2 = tl.full([XBLOCK, RBLOCK], float(\"-inf\"), tl.float32)\n for roffset in range(0, rnumel, RBLOCK):\n rindex = roffset + rbase\n rmask = rindex < rnumel\n r1 = rindex\n tmp0 = tl.load(in_ptr0 + (r1 + (2048 * x0)), rmask, eviction_policy='evict_last', other=0).to(tl.float32)\n tmp1 = tmp0.to(tl.float32)\n tmp3 = triton_helpers.maximum(_tmp2, tmp1)\n _tmp2 = tl.where(rmask, tmp3, _tmp2)\n tmp2 = triton_helpers.max2(_tmp2, 1)[:, None]\n _tmp8 = tl.full([XBLOCK, RBLOCK], 0, tl.float32)\n for roffset in range(0, rnumel, RBLOCK):\n rindex = roffset + rbase\n rmask = rindex < rnumel\n r1 = rindex\n tmp4 = tl.load(in_ptr0 + (r1 + (2048 * x0)), rmask, eviction_policy='evict_last', other=0).to(tl.float32)\n tmp5 = tmp4.to(tl.float32)\n tmp6 = tmp5 - tmp2\n tmp7 = tl.exp(tmp6)\n tmp9 = _tmp8 + tmp7\n _tmp8 = tl.where(rmask, tmp9, _tmp8)\n tmp8 = tl.sum(_tmp8, 1)[:, None]\n for roffset in range(0, rnumel, RBLOCK):\n rindex = roffset + rbase\n rmask = rindex < rnumel\n r1 = rindex\n tmp10 = tl.load(in_ptr0 + (r1 + (2048 * x0)), rmask, other=0).to(tl.float32)\n tmp11 = tmp10.to(tl.float32)\n tmp12 = tmp11 - tmp2\n tmp13 = tl.exp(tmp12)\n tmp14 = tmp13 / tmp8\n tmp15 = tmp14.to(tl.float32)\n tl.store(out_ptr2 + (r1 + (2048 * x0)), tmp15, rmask)\n\n@persistent_reduction(\n size_hints=[8192, 512],\n reduction_hint=ReductionHint.INNER,\n meta={'signature': {0: '*bf16', 1: '*bf16', 2: '*bf16', 3: '*bf16', 4: '*bf16', 5: 'i32', 6: 'i32'}}\n)\n@triton.jit\ndef triton_(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr1, xnumel, rnumel, XBLOCK: tl.constexpr):\n xnumel = 8192\n rnumel = 512\n RBLOCK: tl.constexpr = 512\n xoffset = tl.program_id(0) * XBLOCK\n xindex = tl.full([1], xoffset, tl.int32)\n xmask = xindex < xnumel\n rindex = tl.arange(0, RBLOCK)[:]\n rmask = rindex < rnumel\n r1 = rindex\n x0 = xindex\n tmp0 = tl.load(in_ptr0 + (r1 + (512 * x0)), rmask, other=0).to(tl.float32)\n tmp1 = tl.load(in_ptr1 + (r1 + (512 * x0)), rmask, other=0).to(tl.float32)\n tmp8 = tl.load(in_ptr2 + (r1), rmask, eviction_policy='evict_last', other=0).to(tl.float32)\n tmp2 = tmp0 + tmp1\n tmp3 = tmp2.to(tl.float32)\n tmp4 = tmp3 * tmp3\n tmp6 = tl.where(rmask, tmp4, 0)\n tmp7 = triton_helpers.promote_to_tensor(tl.sum(tmp6, 0))\n tmp9 = 512.0\n tmp10 = tmp7 / tmp9\n tmp11 = 1e-06\n tmp12 = tmp10 + tmp11\n tmp13 = tl.math.rsqrt(tmp12)\n tmp14 = tmp3 * tmp13\n tmp15 = tmp14.to(tl.float32)\n tmp16 = tmp8 * tmp15\n tl.store(out_ptr1 + (r1 + (512 * x0)), tmp16, rmask)\n\n@persistent_reduction(\n size_hints=[8192, 512],\n reduction_hint=ReductionHint.INNER,\n meta={'signature': {0: '*bf16', 1: '*bf16', 2: '*bf16', 3: '*bf16', 4: '*bf16', 5: '*bf16', 6: 'i32', 7: 'i32'}}\n)\n@triton.jit\ndef triton_(in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, out_ptr1, xnumel, rnumel, XBLOCK: tl.constexpr):\n xnumel = 8192\n rnumel = 512\n RBLOCK: tl.constexpr = 512\n xoffset = tl.program_id(0) * XBLOCK\n xindex = tl.full([1], xoffset, tl.int32)\n xmask = xindex < xnumel\n rindex = tl.arange(0, RBLOCK)[:]\n rmask = rindex < rnumel\n r1 = rindex\n x0 = xindex\n tmp0 = tl.load(in_ptr0 + (r1 + (512 * x0)), rmask, other=0).to(tl.float32)\n tmp1 = tl.load(in_ptr1 + (r1 + (512 * x0)), rmask, other=0).to(tl.float32)\n tmp3 = tl.load(in_ptr2 + (r1 + (512 * x0)), rmask, other=0).to(tl.float32)\n tmp5 = tl.load(in_ptr3 + (r1 + (512 * x0)), rmask, other=0).to(tl.float32)\n tmp12 = tl.load(in_ptr4 + (r1), rmask, eviction_policy='evict_last', other=0).to(tl.float32)\n tmp2 = tmp0 + tmp1\n tmp4 = tmp2 + tmp3\n tmp6 = tmp4 + tmp5\n tmp7 = tmp6.to(tl.float32)\n tmp8 = tmp7 * tmp7\n tmp10 = tl.where(rmask, tmp8, 0)\n tmp11 = triton_helpers.promote_to_tensor(tl.sum(tmp10, 0))\n tmp13 = 512.0\n tmp14 = tmp11 / tmp13\n tmp15 = 1e-06\n tmp16 = tmp14 + tmp15\n tmp17 = tl.math.rsqrt(tmp16)\n tmp18 = tmp7 * tmp17\n tmp19 = tmp18.to(tl.float32)\n tmp20 = tmp12 * tmp19\n tl.store(out_ptr1 + (r1 + (512 * x0)), tmp20, rmask)\n", - "description_1": "Use triton language to implement a fused kernel that performs reduction across specified dimensions and computes intermediate results with tensor loading, mathematical operations, and storing results back.", - "description_2": "Use triton language to implement a pointwise operation that clones tensor data from input to output with specific dimensions using `triton.jit`.", - "difficulty": 3 - }, - { - "code": "import triton\nimport triton.language as tl\nfrom torch import empty_strided, as_strided, device\nfrom torch._inductor.codecache import AsyncCompile\nfrom torch._inductor.select_algorithm import extern_kernels\n\nasync_compile = AsyncCompile()\n\n# Kernel 1\n@triton.jit\ndef triton_(in_ptr0, in_ptr1, in_ptr2, out_ptr1, xnumel, rnumel, XBLOCK: tl.constexpr, RBLOCK: tl.constexpr):\n xnumel = 8192\n rnumel = 512\n xoffset = tl.program_id(0) * XBLOCK\n xindex = xoffset + tl.arange(0, XBLOCK)[:, None]\n xmask = xindex < xnumel\n rbase = tl.arange(0, RBLOCK)[None, :]\n x0 = xindex\n tmp0 = tl.load(in_ptr0 + (x0), None, eviction_policy='evict_last')\n _tmp5 = tl.full([XBLOCK, RBLOCK], 0, tl.float32)\n for roffset in range(0, rnumel, RBLOCK):\n rindex = roffset + rbase\n rmask = rindex < rnumel\n r1 = rindex\n tmp1 = triton_helpers.promote_to_tensor(tmp0)\n tl.device_assert((0 <= tmp1) & (tmp1 < 32128), \"index out of bounds: 0 <= tmp1 < 32128\")\n tmp2 = tl.load(in_ptr1 + (r1 + (512 * tmp0)), rmask, eviction_policy='evict_last', other=0).to(tl.float32)\n tmp3 = tmp2.to(tl.float32)\n tmp4 = tmp3 * tmp3\n tmp6 = _tmp5 + tmp4\n _tmp5 = tl.where(rmask, tmp6, _tmp5)\n tmp5 = tl.sum(_tmp5, 1)[:, None]\n for roffset in range(0, rnumel, RBLOCK):\n rindex = roffset + rbase\n rmask = rindex < rnumel\n r1 = rindex\n tmp7 = tl.load(in_ptr2 + (r1), rmask, eviction_policy='evict_last', other=0).to(tl.float32)\n tmp8 = triton_helpers.promote_to_tensor(tmp0)\n tl.device_assert((0 <= tmp8) & (tmp8 < 32128), \"index out of bounds: 0 <= tmp8 < 32128\")\n tmp9 = tl.load(in_ptr1 + (r1 + (512 * tmp0)), rmask, other=0).to(tl.float32)\n tmp10 = tmp9.to(tl.float32)\n tmp11 = 512.0\n tmp12 = tmp5 / tmp11\n tmp13 = 1e-06\n tmp14 = tmp12 + tmp13\n tmp15 = tl.math.rsqrt(tmp14)\n tmp16 = tmp10 * tmp15\n tmp17 = tmp16.to(tl.float32)\n tmp18 = tmp7 * tmp17\n tl.store(out_ptr1 + (r1 + (512 * x0)), tmp18, rmask)\n\n# Call function\ndef call(args):\n arg0_1, arg13_1, arg32_1, buf1 = args\n args.clear()\n with torch.cuda._DeviceGuard(0):\n buf1 = empty_strided((4, 2048, 512), (1048576, 512, 1), device='cuda', dtype=torch.bfloat16)\n triton_red_fused__to_copy_add_embedding_mean_mul_pow_rsqrt_0.run(arg133_1, arg32_1, arg13_1, buf1, 8192, 512, grid=grid(8192), stream=stream0)\n\nasync_compile.wait(globals())\ndel async_compile\n", - "description_1": "Use triton language to define a kernel that performs operations on input pointers with specified grid sizes, performs assertions, loads data, and writes output using rsqrt operations.", - "description_2": "Use triton language to execute defined kernels with specific grid sizes for tensor computations on CUDA devices.", - "difficulty": 4 - }, - { - "code": "import triton\nimport triton.language as tl\nfrom torch._inductor.ir import ReductionHint\nfrom torch._inductor.triton_heuristics import reduction, persistent_reduction, pointwise\nfrom torch._inductor.utils import instance_descriptor\nfrom torch._inductor import triton_helpers\n\n\n@reduction(\n size_hints=[8192, 512],\n reduction_hint=ReductionHint.DEFAULT,\n filename=__file__,\n meta={\n 'signature': {0: '*i64', 1: '*bf16', 2: '*bf16', 3: '*bf16', 4: 'i32', 5: 'i32'},\n 'device': 0, 'constants': {}, 'mutated_arg_names': [],\n 'configs': [instance_descriptor(divisible_by_16=(0, 1, 2, 3, 4, 5), equal_to_1=())]}\n)\n@triton.jit\ndef triton_kernel_1(in_ptr0, in_ptr1, in_ptr2, out_ptr1, xnumel, rnumel, XBLOCK: tl.constexpr, RBLOCK: tl.constexpr):\n # Kernel implementation with necessary logic\n # ...\n\n\n@pointwise(size_hints=[4194304], filename=__file__, \n meta={'signature': {0: '*bf16', 1: '*bf16', 2: 'i32'}, 'device': 0, 'constants': {}, 'mutated_arg_names': [], \n 'configs': [instance_descriptor(divisible_by_16=(0, 1, 2), equal_to_1=())]})\n@triton.jit\ndef triton_kernel_2(in_ptr0, out_ptr0, xnumel, XBLOCK: tl.constexpr):\n # Kernel implementation with necessary logic\n # ...\n\n\n# Similar structure for other Triton kernels follows here...\n\n@reduction(\n size_hints=[8192, 512],\n reduction_hint=ReductionHint.DEFAULT,\n filename=__file__,\n meta={\n 'signature': {0: '*i64', 1: '*bf16', 2: '*bf16', 3: '*i64', 4: '*bf16', 5: '*bf16', 6: '*bf16', 7: '*bf16', 8: 'i32', 9: 'i32'},\n 'device': 0, 'constants': {}, 'mutated_arg_names': [],\n 'configs': [instance_descriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 6, 7, 8, 9), equal_to_1=())]}\n)\n@triton.jit\ndef triton_kernel_3(in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, in_ptr5, out_ptr2, out_ptr3, xnumel, rnumel, XBLOCK: tl.constexpr, RBLOCK: tl.constexpr):\n # Kernel implementation with necessary logic\n # ...\n\n\n# Functions calling the kernels\ndef call_triton_kernel_1(args):\n triton_kernel_1(args[0], args[1], args[2], args[3], 8192, 512, XBLOCK=128, RBLOCK=128)\n\ndef call_triton_kernel_2(args):\n triton_kernel_2(args[0], args[1], 4194304, XBLOCK=128)\n\ndef call_triton_kernel_3(args):\n triton_kernel_3(args[0], args[1], args[2], args[3], args[4], args[5], args[6], args[7], 8192, 512, XBLOCK=128, RBLOCK=128)\n", - "description_1": "Use triton language to implement kernels for reduction and pointwise operations. The kernels handle various data pointers, element numbers, and block constants for optimized computation.", - "description_2": "Implement Triton kernels to execute reduction and pointwise operations on provided input pointers with specified block and element configurations.", - "difficulty": 2 - }, - { - "code": "import triton\nimport triton.language as tl\nfrom torch._inductor.triton_heuristics import reduction, pointwise, persistent_reduction\n\n# Kernel 1: Fused operations including convert_element_type, add, embedding, mean, mul, pow, rsqrt\n@reduction(\n size_hints=[8192, 512],\n reduction_hint=ReductionHint.DEFAULT,\n filename=__file__,\n meta={\n 'signature': {0: '*i64', 1: '*bf16', 2: '*bf16', 3: '*bf16', 4: 'i32', 5: 'i32'},\n 'device': 0,\n 'constants': {},\n 'mutated_arg_names': [],\n 'configs': [instance_descriptor(divisible_by_16=(0, 1, 2, 3, 4, 5), equal_to_1=())]\n }\n)\n@triton.jit\ndef triton_fused_op1(in_ptr0, in_ptr1, in_ptr2, out_ptr1, xnumel, rnumel, XBLOCK: tl.constexpr, RBLOCK: tl.constexpr):\n # Kernel implementation with triton language\n ...\n\n# Kernel 2: Cloning operation\n@pointwise(size_hints=[4194304], filename=__file__, meta={\n 'signature': {0: '*bf16', 1: '*bf16', 2: 'i32'},\n 'device': 0,\n 'constants': {},\n 'mutated_arg_names': [],\n 'configs': [instance_descriptor(divisible_by_16=(0, 1, 2), equal_to_1=())]\n})\n@triton.jit\ndef triton_clone_op1(in_ptr0, out_ptr0, xnumel, XBLOCK: tl.constexpr):\n # Kernel implementation with triton language\n ...\n\n# Kernel 3: Fused softmax, convert_element_type, add, mul, rsub operations\n@reduction(\n size_hints=[65536, 2048],\n reduction_hint=ReductionHint.INNER,\n filename=__file__,\n meta={\n 'signature': {0: '*bf16', 1: '*bf16', 2: '*bf16', 3: 'i32', 4: 'i32'},\n 'device': 0,\n 'constants': {},\n 'mutated_arg_names': ['in_out_ptr0'],\n 'configs': [instance_descriptor(divisible_by_16=(0, 1, 2, 3, 4), equal_to_1=())]\n }\n)\n@triton.jit\ndef triton_fused_op2(in_out_ptr0, in_ptr0, out_ptr2, xnumel, rnumel, XBLOCK: tl.constexpr, RBLOCK: tl.constexpr):\n # Kernel implementation with triton language\n ...\n\n# Kernel 4: Pointwise ReLU operation\n@pointwise(size_hints=[16777216], filename=__file__, meta={\n 'signature': {0: '*bf16', 1: 'i32'},\n 'device': 0,\n 'constants': {},\n 'mutated_arg_names': ['in_out_ptr0'],\n 'configs': [instance_descriptor(divisible_by_16=(0, 1), equal_to_1=())]\n})\n@triton.jit\ndef triton_relu_op(in_out_ptr0, xnumel, XBLOCK: tl.constexpr):\n # Kernel implementation with triton language\n ...\n\n# Kernel 5: Persistent reduction operation\n@persistent_reduction(\n size_hints=[8192, 512],\n reduction_hint=ReductionHint.INNER,\n filename=__file__,\n meta={\n 'signature': {0: '*bf16', 1: '*bf16', 2: '*bf16', 3: '*bf16', 4: '*bf16', 5: 'i32', 6: 'i32'},\n 'device': 0,\n 'constants': {},\n 'mutated_arg_names': [],\n 'configs': [instance_descriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 6), equal_to_1=())]\n }\n)\n@triton.jit\ndef triton_persistent_op(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr1, xnumel, rnumel, XBLOCK: tl.constexpr):\n # Kernel implementation with triton language\n ...\n\n", - "description_1": "Use triton language to implement fused operations including data type conversion, addition, embedding, mean, multiplication, power, reciprocal square root, cloning, softmax, and ReLU. Implement persistent reductions using Triton for various element-wise and reduction operations with GPU-specific configurations and optimizations.", - "description_2": "Use triton language for complex operation fusion, combining multiple mathematical operations for efficiency. Employ persistent reduction techniques in Triton to perform high-performance parallel reductions on GPU data.", - "difficulty": 4 - }, - { - "code": "import triton\nimport triton.language as tl\nfrom torch._inductor.ir import ReductionHint, TileHint\nfrom torch._inductor.triton_heuristics import pointwise, reduction, persistent_reduction\nfrom torch._inductor.utils import instance_descriptor\nfrom torch._inductor import triton_helpers\nimport torch\n\n@reduction(\n size_hints=[8192, 512],\n reduction_hint=ReductionHint.DEFAULT,\n filename=__file__,\n meta={'signature': {0: '*i64', 1: '*bf16', 2: '*bf16', 3: '*bf16', 4: 'i32', 5: 'i32'}, 'device': 0, 'constants': {}, 'mutated_arg_names': [], 'configs': [instance_descriptor(divisible_by_16=(0, 1, 2, 3, 4, 5), equal_to_1=())]}\n)\n@triton.jit\ndef triton_(in_ptr0, in_ptr1, in_ptr2, out_ptr1, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):\n xnumel = 8192\n rnumel = 512\n xoffset = tl.program_id(0) * XBLOCK\n xindex = xoffset + tl.arange(0, XBLOCK)[:, None]\n xmask = xindex < xnumel\n rbase = tl.arange(0, RBLOCK)[None, :]\n x0 = xindex\n tmp0 = tl.load(in_ptr0 + (x0), None, eviction_policy='evict_last')\n _tmp5 = tl.full([XBLOCK, RBLOCK], 0, tl.float32)\n for roffset in range(0, rnumel, RBLOCK):\n rindex = roffset + rbase\n rmask = rindex < rnumel\n r1 = rindex\n tmp1 = triton_helpers.promote_to_tensor(tmp0)\n tl.device_assert((0 <= tmp1) & (tmp1 < 32128), \"index out of bounds: 0 <= tmp1 < 32128\")\n tmp2 = tl.load(in_ptr1 + (r1 + (512*tmp0)), rmask, eviction_policy='evict_last', other=0).to(tl.float32)\n tmp3 = tmp2.to(tl.float32)\n tmp4 = tmp3 * tmp3\n tmp6 = _tmp5 + tmp4\n _tmp5 = tl.where(rmask, tmp6, _tmp5)\n tmp5 = tl.sum(_tmp5, 1)[:, None]\n for roffset in range(0, rnumel, RBLOCK):\n rindex = roffset + rbase\n rmask = rindex < rnumel\n r1 = rindex\n tmp7 = tl.load(in_ptr2 + (r1), rmask, eviction_policy='evict_last', other=0).to(tl.float32)\n tmp8 = triton_helpers.promote_to_tensor(tmp0)\n tl.device_assert((0 <= tmp8) & (tmp8 < 32128), \"index out of bounds: 0 <= tmp8 < 32128\")\n tmp9 = tl.load(in_ptr1 + (r1 + (512*tmp0)), rmask, other=0).to(tl.float32)\n tmp10 = tmp9.to(tl.float32)\n tmp11 = 512.0\n tmp12 = tmp5 / tmp11\n tmp13 = 1e-06\n tmp14 = tmp12 + tmp13\n tmp15 = tl.math.rsqrt(tmp14)\n tmp16 = tmp10 * tmp15\n tmp17 = tmp16.to(tl.float32)\n tmp18 = tmp7 * tmp17\n tl.store(out_ptr1 + (r1 + (512*x0)), tmp18, rmask)\n\n@pointwise(\n size_hints=[4194304],\n filename=__file__,\n meta={'signature': {0: '*bf16', 1: '*bf16', 2: 'i32'}, 'device': 0, 'constants': {}, 'mutated_arg_names': [], 'configs': [instance_descriptor(divisible_by_16=(0, 1, 2), equal_to_1=())]}\n)\n@triton.jit\ndef triton_(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):\n xnumel = 4194304\n xoffset = tl.program_id(0) * XBLOCK\n xindex = xoffset + tl.arange(0, XBLOCK)[:]\n xmask = xindex < xnumel\n x0 = xindex % 64\n x1 = (xindex // 64) % 2048\n x2 = (xindex // 131072) % 8\n x3 = (xindex // 1048576)\n x4 = xindex\n tmp0 = tl.load(in_ptr0 + (x0 + (64*x2) + (512*x1) + (1048576*x3)), None).to(tl.float32)\n tl.store(out_ptr0 + (x4), tmp0, None)\n", - "description_1": "Use triton language to implement a reduction kernel that processes input data, performs bounds checking, multiplication, and reciprocal square root operations for efficient large-scale computations. Another triton kernel performs pointwise operations, transforming and storing input data across specified dimensions.", - "description_2": "Use triton language to create efficient GPU kernels for reduction operations on input arrays with bounds checking and mathematical transformations, and separate pointwise kernels for data transformation and storage.", - "difficulty": 4 - }, - { - "code": "import triton\nimport triton.language as tl\nfrom torch._inductor.utils import instance_descriptor\nfrom torch._inductor import triton_helpers\n\n# Kernel 1\n@triton.jit\ndef triton_kernel1(in_ptr0, in_ptr1, in_ptr2, out_ptr1, xnumel, rnumel, XBLOCK: tl.constexpr, RBLOCK: tl.constexpr):\n # ...\n # Implementation of the kernel\n # ...\n\n\n# Kernel 2\n@triton.jit\ndef triton_kernel2(in_ptr0, out_ptr0, xnumel, XBLOCK: tl.constexpr):\n # ...\n # Implementation of the kernel\n # ...\n\n\n# Kernel 3\n@triton.jit\ndef triton_kernel3(in_ptr0, in_ptr1, out_ptr3, xnumel, rnumel, XBLOCK: tl.constexpr, RBLOCK: tl.constexpr):\n # ...\n # Implementation of the kernel\n # ...\n\n\n# Kernel 4\n@triton.jit\ndef triton_kernel4(in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, out_ptr2, out_ptr3, xnumel, rnumel, XBLOCK: tl.constexpr, RBLOCK: tl.constexpr):\n # ...\n # Implementation of the kernel\n # ...\n\n\n# Calling the kernels\ndef call_triton_kernels():\n # ...\n # Code to prepare inputs and call the Triton kernels\n # ...\n\n", - "description_1": "Use triton language to implement several kernels. The first kernel takes six input parameters and performs reduction operations, manipulating pointers and numerical values for computation. The second kernel involves three parameters and executes a pointwise operation with specific block size constraints. The third kernel takes five input parameters and carries out a reduction with precise handling of pointers and constraints. The fourth kernel manages ten parameters to conduct a complex reduction operation.", - "description_2": "Use triton language to create and execute GPU kernels designed for efficient memory and computation handling with specific constraints on block sizes and reduction operations.", - "difficulty": 3 - }, - { - "code": "import triton\nimport triton.language as tl\nfrom torch._inductor.triton_heuristics import reduction, pointwise, persistent_reduction\n\n@reduction(\n size_hints=[8192, 512],\n reduction_hint=ReductionHint.DEFAULT,\n meta={'signature': {0: '*i64', 1: '*bf16', 2: '*bf16', 3: '*bf16', 4: 'i32', 5: 'i32'}, 'device': 0}\n)\n@triton.jit\ndef triton_fused_kernel(in_ptr0, in_ptr1, in_ptr2, out_ptr1, xnumel, rnumel, XBLOCK: tl.constexpr, RBLOCK: tl.constexpr):\n xnumel = 8192\n rnumel = 512\n xoffset = tl.program_id(0) * XBLOCK\n xindex = xoffset + tl.arange(0, XBLOCK)[:, None]\n xmask = xindex < xnumel\n rbase = tl.arange(0, RBLOCK)[None, :]\n x0 = xindex\n tmp0 = tl.load(in_ptr0 + x0, None, eviction_policy='evict_last')\n _tmp5 = tl.full([XBLOCK, RBLOCK], 0, tl.float32)\n for roffset in range(0, rnumel, RBLOCK):\n rindex = roffset + rbase\n rmask = rindex < rnumel\n r1 = rindex\n tmp1 = tl.load(in_ptr1 + (r1 + 512 * tmp0), rmask, eviction_policy='evict_last', other=0).to(tl.float32)\n tmp2 = tl.load(in_ptr2 + r1, rmask, eviction_policy='evict_last', other=0).to(tl.float32)\n tmp3 = tmp1.to(tl.float32)\n tmp4 = tmp3 * tmp3\n tmp6 = _tmp5 + tmp4\n _tmp5 = tl.where(rmask, tmp6, _tmp5)\n tmp5 = tl.sum(_tmp5, 1)[:, None]\n for roffset in range(0, rnumel, RBLOCK):\n rindex = roffset + rbase\n rmask = rindex < rnumel\n r1 = rindex\n tmp7 = tl.load(in_ptr2 + r1, rmask, eviction_policy='evict_last', other=0).to(tl.float32)\n tmp8 = tmp5 / 512.0\n tmp9 = tmp8 + 1e-06\n tmp10 = tl.math.rsqrt(tmp9)\n tmp11 = tmp7 * tmp10\n tl.store(out_ptr1 + (r1 + 512 * x0), tmp11, rmask)\n\n@pointwise(size_hints=[4194304], meta={'signature': {0: '*bf16', 1: '*bf16', 2: 'i32'}, 'device': 0})\n@triton.jit\ndef triton_clone_kernel(in_ptr0, out_ptr0, xnumel, XBLOCK: tl.constexpr):\n xnumel = 4194304\n xoffset = tl.program_id(0) * XBLOCK\n xindex = xoffset + tl.arange(0, XBLOCK)[:]\n xmask = xindex < xnumel\n x0 = xindex % 64\n x1 = (xindex // 64) % 2048\n x2 = (xindex // 131072) % 8\n x3 = (xindex // 1048576)\n x4 = xindex\n tmp0 = tl.load(in_ptr0 + (x0 + 64 * x2 + 512 * x1 + 1048576 * x3), None).to(tl.float32)\n tl.store(out_ptr0 + x4, tmp0, None)\n\n@persistent_reduction(\n size_hints=[65536, 2048],\n reduction_hint=ReductionHint.INNER,\n meta={'signature': {0: '*bf16', 1: '*bf16', 2: 'i32', 3: 'i32'}, 'device': 0}\n)\n@triton.jit\ndef triton_softmax_kernel(in_ptr0, out_ptr2, xnumel, rnumel, XBLOCK: tl.constexpr, RBLOCK: tl.constexpr):\n xnumel = 65536\n rnumel = 2048\n xoffset = tl.program_id(0) * XBLOCK\n xindex = xoffset + tl.arange(0, XBLOCK)[:, None]\n xmask = xindex < xnumel\n rbase = tl.arange(0, RBLOCK)[None, :]\n x0 = xindex\n _tmp2 = tl.full([XBLOCK, RBLOCK], float(\"-inf\"), tl.float32)\n for roffset in range(0, rnumel, RBLOCK):\n rindex = roffset + rbase\n rmask = rindex < rnumel\n r1 = rindex\n tmp0 = tl.load(in_ptr0 + (r1 + 2048 * x0), rmask, eviction_policy='evict_last', other=0).to(tl.float32)\n tmp1 = triton_helpers.maximum(_tmp2, tmp0.to(tl.float32))\n _tmp2 = tl.where(rmask, tmp1, _tmp2)\n tmp2 = triton_helpers.max2(_tmp2, 1)[:, None]\n _tmp8 = tl.full([XBLOCK, RBLOCK], 0, tl.float32)\n for roffset in range(0, rnumel, RBLOCK):\n rindex = roffset + rbase\n rmask = rindex < rnumel\n r1 = rindex\n tmp4 = tl.load(in_ptr0 + (r1 + 2048 * x0), rmask, eviction_policy='evict_last', other=0).to(tl.float32)\n tmp5 = tmp4 - tmp2\n tmp6 = tl.exp(tmp5)\n tmp8 = _tmp8 + tmp6\n _tmp8 = tl.where(rmask, tmp8, _tmp8)\n tmp8 = tl.sum(_tmp8, 1)[:, None]\n for roffset in range(0, rnumel, RBLOCK):\n rindex = roffset + rbase\n rmask = rindex < rnumel\n r1 = rindex\n tmp10 = tl.load(in_ptr0 + (r1 + 2048 * x0), rmask, other=0).to(tl.float32)\n tmp11 = tmp10 - tmp2\n tmp12 = tl.exp(tmp11)\n tmp13 = tmp12 / tmp8\n tl.store(out_ptr2 + (r1 + 2048 * x0), tmp13.to(tl.float32), rmask)\n", - "description_1": "Use triton language to define kernels for fused reduction operations, cloning of elements from one buffer to another, and a softmax computation on specified dimensions.", - "description_2": "Use triton language to perform reduction and pointwise operations for deep learning tasks, including tensor cloning and applying the softmax function efficiently on large tensors.", - "difficulty": 3 - }, - { - "code": "import triton\nimport triton.language as tl\nfrom torch._inductor.triton_heuristics import reduction, pointwise, persistent_reduction\nfrom torch._inductor.utils import instance_descriptor\nfrom torch._inductor import triton_helpers\n\n@reduction(\n size_hints=[8192, 512],\n reduction_hint='default',\n filename=__file__,\n meta={\n 'signature': {0: '*i64', 1: '*bf16', 2: '*bf16', 3: '*bf16', 4: 'i32', 5: 'i32'}, \n 'device': 0, \n 'constants': {}, \n 'mutated_arg_names': [], \n 'configs': [instance_descriptor(divisible_by_16=(0, 1, 2, 3, 4, 5), equal_to_1=())]\n }\n)\n@triton.jit\ndef triton_kernel_1(in_ptr0, in_ptr1, in_ptr2, out_ptr1, xnumel, rnumel, XBLOCK: tl.constexpr, RBLOCK: tl.constexpr):\n # Parameters: 6 (in_ptr0, in_ptr1, in_ptr2, out_ptr1, xnumel, rnumel)\n # Description: Triton kernel function performing reduction operations\n ...\n\n@pointwise(\n size_hints=[4194304], \n filename=__file__, \n meta={\n 'signature': {0: '*bf16', 1: '*bf16', 2: 'i32'}, \n 'device': 0, \n 'constants': {}, \n 'mutated_arg_names': [], \n 'configs': [instance_descriptor(divisible_by_16=(0, 1, 2), equal_to_1=())]\n }\n)\n@triton.jit\ndef triton_kernel_2(in_ptr0, out_ptr0, xnumel, XBLOCK: tl.constexpr):\n # Parameters: 3 (in_ptr0, out_ptr0, xnumel)\n # Description: Triton kernel function for element-wise operations\n ...\n\n@persistent_reduction(\n size_hints=[8192, 512],\n reduction_hint='inner',\n filename=__file__,\n meta={\n 'signature': {0: '*bf16', 1: '*i64', 2: '*bf16', 3: '*bf16', 4: '*bf16', 5: '*bf16', 6: '*bf16', 7: 'i32', 8: 'i32'}, \n 'device': 0, \n 'constants': {}, \n 'mutated_arg_names': ['in_out_ptr0'], \n 'configs': [instance_descriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 6, 7, 8), equal_to_1=())]\n }\n)\n@triton.jit\ndef triton_kernel_3(in_out_ptr0, in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, out_ptr1, xnumel, rnumel, XBLOCK: tl.constexpr):\n # Parameters: 9 (in_out_ptr0, in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, out_ptr1, xnumel, rnumel)\n # Description: Triton kernel function performing persistent reduction\n ...\n\n@reduction(\n size_hints=[65536, 2048],\n reduction_hint='inner',\n filename=__file__,\n meta={\n 'signature': {0: '*bf16', 1: '*bf16', 2: '*bf16', 3: 'i32', 4: 'i32'}, \n 'device': 0, \n 'constants': {}, \n 'mutated_arg_names': ['in_out_ptr0'], \n 'configs': [instance_descriptor(divisible_by_16=(0, 1, 2, 3, 4), equal_to_1=())]\n }\n)\n@triton.jit\ndef triton_kernel_4(in_out_ptr0, in_ptr0, out_ptr2, xnumel, rnumel, XBLOCK: tl.constexpr, RBLOCK: tl.constexpr):\n # Parameters: 5 (in_out_ptr0, in_ptr0, out_ptr2, xnumel, rnumel)\n # Description: Triton kernel function for reduction with softmax operation\n ...\n\n@persistent_reduction(\n size_hints=[8192, 512],\n reduction_hint='inner',\n filename=__file__,\n meta={\n 'signature': {0: '*bf16', 1: '*bf16', 2: '*bf16', 3: '*bf16', 4: '*bf16', 5: 'i32', 6: 'i32'}, \n 'device': 0, \n 'constants': {}, \n 'mutated_arg_names': [], \n 'configs': [instance_descriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 6), equal_to_1=())]\n }\n)\n@triton.jit\ndef triton_kernel_5(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr1, xnumel, rnumel, XBLOCK: tl.constexpr):\n # Parameters: 7 (in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr1, xnumel, rnumel)\n # Description: Triton kernel function performing persistent reduction\n ...\n\n@persistent_reduction(\n size_hints=[8192, 512],\n reduction_hint='inner',\n filename=__file__,\n meta={\n 'signature': {0: '*bf16', 1: '*bf16', 2: '*bf16', 3: '*bf16', 4: '*bf16', 5: '*bf16', 6: 'i32', 7: 'i32'}, \n 'device': 0, \n 'constants': {}, \n 'mutated_arg_names': [], \n 'configs': [instance_descriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 6, 7), equal_to_1=())]\n }\n)\n@triton.jit\ndef triton_kernel_6(in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, out_ptr1, xnumel, rnumel, XBLOCK: tl.constexpr):\n # Parameters: 8 (in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, out_ptr1, xnumel, rnumel)\n # Description: Triton kernel function performing persistent reduction\n ...\n", - "description_1": "Use triton language to define and implement various kernel functions including reduction, pointwise, and persistent reduction operations. The kernels use parameters for input and output pointers, element numbers, and block sizes for computations on GPU.", - "description_2": "Use triton language to perform reduction, pointwise, and persistent reduction operations with specific parameters and block sizes for GPU-based computations.", - "difficulty": 4 - }, - { - "code": "import triton\nimport triton.language as tl\nfrom torch._inductor.ir import ReductionHint, TileHint\nfrom torch._inductor.triton_heuristics import reduction, pointwise, persistent_reduction\nfrom torch._inductor.utils import instance_descriptor\nfrom torch._inductor import triton_helpers\n\n# Kernel triton_red_fused__to_copy_add_embedding_mean_mul_pow_rsqrt_0\n@reduction(\n size_hints=[8192, 512],\n reduction_hint=ReductionHint.DEFAULT,\n meta={'signature': {0: '*i64', 1: '*bf16', 2: '*bf16', 3: '*bf16', 4: 'i32', 5: 'i32'}}\n)\n@triton.jit\ndef triton_(in_ptr0, in_ptr1, in_ptr2, out_ptr1, xnumel, rnumel, XBLOCK: tl.constexpr, RBLOCK: tl.constexpr):\n xnumel = 8192\n rnumel = 512\n xoffset = tl.program_id(0) * XBLOCK\n xindex = xoffset + tl.arange(0, XBLOCK)[:, None]\n xmask = xindex < xnumel\n rbase = tl.arange(0, RBLOCK)[None, :]\n x0 = xindex\n tmp0 = tl.load(in_ptr0 + x0, None, eviction_policy='evict_last')\n _tmp5 = tl.full([XBLOCK, RBLOCK], 0, tl.float32)\n for roffset in range(0, rnumel, RBLOCK):\n rindex = roffset + rbase\n rmask = rindex < rnumel\n r1 = rindex\n tmp1 = triton_helpers.promote_to_tensor(tmp0)\n tl.device_assert((0 <= tmp1) & (tmp1 < 32128), \"index out of bounds: 0 <= tmp1 < 32128\")\n tmp2 = tl.load(in_ptr1 + (r1 + (512*tmp0)), rmask, eviction_policy='evict_last', other=0).to(tl.float32)\n tmp3 = tmp2.to(tl.float32)\n tmp4 = tmp3 * tmp3\n tmp6 = _tmp5 + tmp4\n _tmp5 = tl.where(rmask, tmp6, _tmp5)\n tmp5 = tl.sum(_tmp5, 1)[:, None]\n for roffset in range(0, rnumel, RBLOCK):\n rindex = roffset + rbase\n rmask = rindex < rnumel\n r1 = rindex\n tmp7 = tl.load(in_ptr2 + r1, rmask, eviction_policy='evict_last', other=0).to(tl.float32)\n tmp8 = triton_helpers.promote_to_tensor(tmp0)\n tl.device_assert((0 <= tmp8) & (tmp8 < 32128), \"index out of bounds: 0 <= tmp8 < 32128\")\n tmp9 = tl.load(in_ptr1 + (r1 + (512*tmp0)), rmask, other=0).to(tl.float32)\n tmp10 = tmp9.to(tl.float32)\n tmp11 = 512.0\n tmp12 = tmp5 / tmp11\n tmp13 = 1e-06\n tmp14 = tmp12 + tmp13\n tmp15 = tl.math.rsqrt(tmp14)\n tmp16 = tmp10 * tmp15\n tmp17 = tmp16.to(tl.float32)\n tmp18 = tmp7 * tmp17\n tl.store(out_ptr1 + (r1 + (512*x0)), tmp18, rmask)\n\n\n# Kernel triton_red_fused__softmax__to_copy_add_mul_rsub_3\n@reduction(\n size_hints=[65536, 2048],\n reduction_hint=ReductionHint.INNER,\n meta={'signature': {0: '*bf16', 1: '*bf16', 2: '*bf16', 3: 'i32', 4: 'i32'}}\n)\n@triton.jit\ndef triton_(in_out_ptr0, in_ptr0, out_ptr2, xnumel, rnumel, XBLOCK: tl.constexpr, RBLOCK: tl.constexpr):\n xnumel = 65536\n rnumel = 2048\n xoffset = tl.program_id(0) * XBLOCK\n xindex = xoffset + tl.arange(0, XBLOCK)[:, None]\n xmask = xindex < xnumel\n rbase = tl.arange(0, RBLOCK)[None, :]\n x4 = xindex\n x0 = xindex % 2048\n x1 = (xindex // 2048) % 8\n _tmp35 = tl.full([XBLOCK, RBLOCK], float(\"-inf\"), tl.float32)\n for roffset in range(0, rnumel, RBLOCK):\n rindex = roffset + rbase\n rmask = rindex < rnumel\n r3 = rindex\n tmp0 = tl.load(in_out_ptr0 + (r3 + (2048*x4)), rmask, other=0).to(tl.float32)\n tmp1 = r3 + ((-1)*x0)\n tmp2 = 0\n tmp3 = triton_helpers.minimum(tmp1, tmp2)\n tmp4 = -tmp3\n tmp5 = 16\n tmp6 = tmp4 < tmp5\n tmp7 = tmp4.to(tl.float32)\n tmp8 = 16.0\n tmp9 = tmp7 / tmp8\n tmp10 = tl.log(tmp9)\n tmp11 = 2.0794415416798357\n tmp12 = tmp10 / tmp11\n tmp13 = tmp12 * tmp8\n tmp14 = tmp13.to(tl.int64)\n tmp15 = tmp14 + tmp5\n tmp16 = 31\n tmp17 = triton_helpers.minimum(tmp15, tmp16)\n tmp18 = tl.where(tmp6, tmp4, tmp17)\n tmp19 = tmp18 + tmp2\n tmp20 = triton_helpers.promote_to_tensor(tmp19)\n tl.device_assert((0 <= tmp20) & (tmp20 < 32), \"index out of bounds: 0 <= tmp20 < 32\")\n tmp21 = tl.load(in_ptr0 + (x1 + (8*tmp19)), None).to(tl.float32)\n tmp22 = r3\n tmp23 = x0\n tmp24 = tmp22 <= tmp23\n tmp25 = tmp24.to(tl.float32)\n tmp26 = 1.0\n tmp27 = tmp25 * tmp26\n tmp28 = tmp27.to(tl.float32)\n tmp29 = tmp26 - tmp28\n tmp30 = -3.3895313892515355e+38\n tmp31 = tmp29 * tmp30\n tmp32 = tmp21 + tmp31\n tmp33 = tmp0 + tmp32\n tmp34 = tmp33.to(tl.float32)\n tmp36 = triton_helpers.maximum(_tmp35, tmp34)\n _tmp35 = tl.where(rmask, tmp36, _tmp35)\n tl.store(in_out_ptr0 + (r3 + (2048*x4)), tmp33, rmask)\n tmp35 = triton_helpers.max2(_tmp35, 1)[:, None]\n _tmp41 = tl.full([XBLOCK, RBLOCK], 0, tl.float32)\n for roffset in range(0, rnumel, RBLOCK):\n rindex = roffset + rbase\n rmask = rindex < rnumel\n r3 = rindex\n tmp37 = tl.load(in_out_ptr0 + (r3 + (2048*x4)), rmask, eviction_policy='evict_last', other=0).to(tl.float32)\n tmp38 = tmp37.to(tl.float32)\n tmp39 = tmp38 - tmp35\n tmp40 = tl.exp(tmp39)\n tmp42 = _tmp41 + tmp40\n _tmp41 = tl.where(rmask, tmp42, _tmp41)\n tmp41 = tl.sum(_tmp41, 1)[:, None]\n for roffset in range(0, rnumel, RBLOCK):\n rindex = roffset + rbase\n rmask = rindex < rnumel\n r3 = rindex\n tmp43 = tl.load(in_out_ptr0 + (r3 + (2048*x4)), rmask, other=0).to(tl.float32)\n tmp44 = tmp43.to(tl.float32)\n tmp45 = tmp44 - tmp35\n tmp46 = tl.exp(tmp45)\n tmp47 = tmp46 / tmp41\n tmp48 = tmp47.to(tl.float32)\n tl.store(out_ptr2 + (r3 + (2048*x4)), tmp48, rmask)\n\n# Similar reduction and pointwise decorated triton kernels can be added here...\n\n", - "description_1": "Use triton language to create a reduction kernel that handles memory safely with device assertions and optimized eviction policies for tensor operations across multiple elements.", - "description_2": "Use triton language to implement efficient softmax-like operations by managing exponential calculations and reduction tasks on tensors, ensuring safe memory operations with device assertions.", - "difficulty": 3 - }, - { - "code": "import triton\nimport triton.language as tl\nfrom torch._inductor.ir import ReductionHint\nfrom torch._inductor.triton_heuristics import reduction, pointwise\nfrom torch._inductor.utils import instance_descriptor\nfrom torch._inductor import triton_helpers\nfrom torch import empty_strided, as_strided\n\n@reduction(\n size_hints=[8192, 512],\n reduction_hint=ReductionHint.DEFAULT,\n filename=__file__,\n meta={'signature': {0: '*i64', 1: '*bf16', 2: '*bf16', 3: '*bf16', 4: 'i32', 5: 'i32'}, 'device': 0, 'constants': {}, 'mutated_arg_names': [], 'configs': [instance_descriptor(divisible_by_16=(0, 1, 2, 3, 4, 5), equal_to_1=())]}\n)\n@triton.jit\ndef triton_red_fused__to_copy_add_embedding_mean_mul_pow_rsqrt_0(in_ptr0, in_ptr1, in_ptr2, out_ptr1, xnumel, rnumel, XBLOCK: tl.constexpr, RBLOCK: tl.constexpr):\n # Kernel logic with six parameters for computation\n ...\n\n@pointwise(\n size_hints=[4194304],\n filename=__file__,\n meta={'signature': {0: '*bf16', 1: '*bf16', 2: 'i32'}, 'device': 0, 'constants': {}, 'mutated_arg_names': [], 'configs': [instance_descriptor(divisible_by_16=(0, 1, 2), equal_to_1=())]}\n)\n@triton.jit\ndef triton_poi_fused_clone_1(in_ptr0, out_ptr0, xnumel, XBLOCK: tl.constexpr):\n # Clone operation logic with three parameters\n ...\n\n@pointwise(\n size_hints=[2048, 2048],\n tile_hint=TileHint.SQUARE,\n filename=__file__,\n meta={'signature': {0: '*bf16', 1: '*bf16', 2: 'i32', 3: 'i32'}, 'device': 0, 'constants': {}, 'mutated_arg_names': [], 'configs': [instance_descriptor(divisible_by_16=(0, 1, 2, 3), equal_to_1=())]}\n)\n@triton.jit\ndef triton_poi_fused_clone_2(in_ptr0, out_ptr0, xnumel, ynumel, XBLOCK: tl.constexpr, YBLOCK: tl.constexpr):\n # Clone operation logic with four parameters for computation\n ...\n\n@reduction(\n size_hints=[65536, 2048],\n reduction_hint=ReductionHint.INNER,\n filename=__file__,\n meta={'signature': {0: '*bf16', 1: '*bf16', 2: '*bf16', 3: 'i32', 4: 'i32'}, 'device': 0, 'constants': {}, 'mutated_arg_names': ['in_out_ptr0'], 'configs': [instance_descriptor(divisible_by_16=(0, 1, 2, 3, 4), equal_to_1=())]}\n)\n@triton.jit\ndef triton_red_fused__softmax__to_copy_add_mul_rsub_3(in_out_ptr0, in_ptr0, out_ptr2, xnumel, rnumel, XBLOCK: tl.constexpr, RBLOCK: tl.constexpr):\n # Kernel logic with five parameters for softmax and other operations\n ...\n\n@reduction(\n size_hints=[8192, 512],\n reduction_hint=ReductionHint.DEFAULT,\n filename=__file__,\n meta={'signature': {0: '*i64', 1: '*bf16', 2: '*bf16', 3: '*i64', 4: '*bf16', 5: '*bf16', 6: '*bf16', 7: '*bf16', 8: 'i32', 9: 'i32'}, 'device': 0, 'constants': {}, 'mutated_arg_names': [], 'configs': [instance_descriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 6, 7, 8, 9), equal_to_1=())]}\n)\n@triton.jit\ndef triton_red_fused__to_copy_add_embedding_mean_mul_pow_rsqrt_5(in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, in_ptr5, out_ptr2, out_ptr3, xnumel, rnumel, XBLOCK: tl.constexpr, RBLOCK: tl.constexpr):\n # Kernel logic with ten parameters for various tensor operations\n ...\n\n\ndef call(args):\n arg0_1, arg13_1, arg14_1, arg32_1, arg33_1, arg34_1, arg70_1, arg71_1, arg72_1, arg73_1, arg74_1, arg75_1, arg132_1, arg133_1 = args\n args.clear()\n # Tensor assertions\n assert_size_stride(arg0_1, (512, ), (1, ))\n assert_size_stride(arg13_1, (512, ), (1, ))\n assert_size_stride(arg14_1, (512, ), (1, ))\n assert_size_stride(arg32_1, (32128, 512), (512, 1))\n assert_size_stride(arg33_1, (512, 512), (512, 1))\n assert_size_stride(arg34_1, (512, 512), (512, 1))\n assert_size_stride(arg70_1, (512, 512), (512, 1))\n assert_size_stride(arg71_1, (512, 512), (512, 1))\n assert_size_stride(arg72_1, (512, 512), (512, 1))\n assert_size_stride(arg73_1, (32, 8), (8, 1))\n assert_size_stride(arg74_1, (512, 512), (512, 1))\n assert_size_stride(arg75_1, (512, 512), (512, 1))\n assert_size_stride(arg132_1, (4, 2048), (2048, 1))\n assert_size_stride(arg133_1, (4, 2048), (2048, 1))\n with torch.cuda._DeviceGuard(0):\n torch.cuda.set_device(0) # no-op to ensure context\n buf1 = empty_strided((4, 2048, 512), (1048576, 512, 1), device='cuda', dtype=torch.bfloat16)\n triton_red_fused__to_copy_add_embedding_mean_mul_pow_rsqrt_0.run(arg133_1, arg32_1, arg13_1, buf1, 8192, 512, grid=grid(8192), stream=stream0)\n del arg13_1\n buf2 = empty_strided((8192, 512), (512, 1), device='cuda', dtype=torch.bfloat16)\n extern_kernels.mm(as_strided(buf1, (8192, 512), (512, 1)), as_strided(arg70_1, (512, 512), (1, 512)), out=buf2)\n del arg70_1\n buf3 = empty_strided((8192, 512), (512, 1), device='cuda', dtype=torch.bfloat16)\n extern_kernels.mm(as_strided(buf1, (8192, 512), (512, 1)), as_strided(arg71_1, (512, 512), (1, 512)), out=buf3)\n del arg71_1\n buf4 = empty_strided((4, 8, 2048, 64), (1048576, 131072, 64, 1), device='cuda', dtype=torch.bfloat16)\n triton_poi_fused_clone_1.run(buf2, buf4, 4194304, grid=grid(4194304), stream=stream0)\n del buf2\n buf5 = empty_strided((4, 8, 64, 2048), (1048576, 131072, 2048, 1), device='cuda', dtype=torch.bfloat16)\n triton_poi_fused_clone_2.run(buf3, buf5, 2048, 2048, grid=grid(2048, 2048), stream=stream0)\n buf6 = empty_strided((32, 2048, 2048), (4194304, 2048, 1), device='cuda', dtype=torch.bfloat16)\n extern_kernels.bmm(as_strided(buf4, (32, 2048, 64), (131072, 64, 1)), as_strided(buf5, (32, 64, 2048), (131072, 2048, 1)), out=buf6)\n del buf4\n del buf5\n buf7 = as_strided(buf6, (4, 8, 2048, 2048), (33554432, 4194304, 2048, 1)); del buf6 # reuse\n buf11 = empty_strided((4, 8, 2048, 2048), (33554432, 4194304, 2048, 1), device='cuda', dtype=torch.bfloat16)\n triton_red_fused__softmax__to_copy_add_mul_rsub_3.run(buf7, arg73_1, buf11, 65536, 2048, grid=grid(65536), stream=stream0)\n del buf7\n buf10 = empty_strided((8192, 512), (512, 1), device='cuda', dtype=torch.bfloat16)\n extern_kernels.mm(as_strided(buf1, (8192, 512), (512, 1)), as_strided(arg72_1, (512, 512), (1, 512)), out=buf10)\n del arg72_1\n del buf1\n buf12 = empty_strided((4, 8, 2048, 64), (1048576, 131072, 64, 1), device='cuda', dtype=torch.bfloat16)\n triton_poi_fused_clone_1.run(buf10, buf12, 4194304, grid=grid(4194304), stream=stream0)\n buf13 = empty_strided((32, 2048, 64), (131072, 64, 1), device='cuda', dtype=torch.bfloat16)\n extern_kernels.bmm(as_strided(buf11, (32, 2048, 2048), (4194304, 2048, 1)), as_strided(buf12, (32, 2048, 64), (131072, 64, 1)), out=buf13)\n del buf11\n del buf12\n buf14 = empty_strided((4, 2048, 8, 64), (1048576, 512, 64, 1), device='cuda', dtype=torch.bfloat16)\n triton_poi_fused_clone_4.run(buf13, buf14, 4194304, grid=grid(4194304), stream=stream0)\n del buf13\n buf15 = empty_strided((8192, 512), (512, 1), device='cuda', dtype=torch.bfloat16)\n extern_kernels.mm(as_strided(buf14, (8192, 512), (512, 1)), as_strided(arg74_1, (512, 512), (1, 512)), out=buf15)\n del arg74_1\n del buf14\n buf17 = empty_strided((4, 2048, 512), (1048576, 512, 1), device='cuda', dtype=torch.bfloat16)\n buf20 = empty_strided((4, 2048, 512), (1048576, 512, 1), device='cuda', dtype=torch.bfloat16)\n event_buf16_buf19_buf17_buf20 = torch.cuda.Event()\n triton_red_fused__to_copy_add_embedding_mean_mul_pow_rsqrt_5.run(arg133_1, arg32_1, buf15, arg132_1, arg14_1, arg0_1, buf17, buf20, 8192, 512, grid=grid(8192), stream=stream0)\n event_buf16_buf19_buf17_buf20.record(stream0_raw)\n del arg0_1\n del arg14_1\n torch.cuda.set_stream(stream5_raw)\n buf18 = empty_strided((8192, 512), (512, 1), device='cuda', dtype=torch.bfloat16)\n stream5_raw.wait_event(event_buf16_buf19_buf17_buf20)\n extern_kernels.mm(as_strided(buf17, (8192, 512), (512, 1)), as_strided(arg75_1, (512, 512), (1, 512)), out=buf18)\n torch.cuda.set_stream(stream0_raw)\n del arg75_1\n del buf17\n buf21 = empty_strided((8192, 512), (512, 1), device='cuda', dtype=torch.bfloat16)\n extern_kernels.mm(as_strided(buf20, (8192, 512), (512, 1)), as_strided(arg33_1, (512, 512), (1, 512)), out=buf21)\n del arg33_1\n buf22 = empty_strided((8192, 512), (512, 1), device='cuda', dtype=torch.bfloat16)\n extern_kernels.mm(as_strided(buf20, (8192, 512), (512, 1)), as_strided(arg34_1, (512, 512), (1, 512)), out=buf22)\n del arg34_1\n return buf18, buf21, buf22\n", - "description_1": "Use triton language to define multiple kernels with various reduction and pointwise operations, and a function to call these kernels for computations, involving tensor loading, element-wise operations, and matrix multiplications.", - "description_2": "Use triton language to implement kernels for reduction and pointwise operations and manage them using a call function with CUDA streams.", - "difficulty": 4 - }, - { - "code": "import triton\nimport triton.language as tl\nfrom torch._inductor.triton_heuristics import reduction, pointwise, persistent_reduction\n\n# Kernel 1\n@reduction(\n size_hints=[8192, 512],\n reduction_hint='default',\n filename=__file__,\n meta={'signature': {0: '*i64', 1: '*bf16', 2: '*bf16', 3: '*bf16', 4: 'i32', 5: 'i32'}, 'device': 0, 'constants': {}, 'mutated_arg_names': [], 'configs': [()]}\n)\n@triton.jit\ndef triton_1(in_ptr0, in_ptr1, in_ptr2, out_ptr1, xnumel, rnumel, XBLOCK: tl.constexpr, RBLOCK: tl.constexpr):\n xnumel = 8192\n rnumel = 512\n # ... kernel logic ...\n\n# Kernel 2\n@pointwise(size_hints=[4194304], filename=__file__, meta={'signature': {0: '*bf16', 1: '*bf16', 2: 'i32'}, 'device': 0, 'constants': {}, 'mutated_arg_names': [], 'configs': [()]})\n@triton.jit\ndef triton_2(in_ptr0, out_ptr0, xnumel, XBLOCK: tl.constexpr):\n xnumel = 4194304\n # ... kernel logic ...\n\n# Kernel 3\n@persistent_reduction(\n size_hints=[8192, 512],\n reduction_hint='inner',\n filename=__file__,\n meta={'signature': {0: '*bf16', 1: '*bf16', 2: '*bf16', 3: '*bf16', 4: '*bf16', 5: '*bf16', 6: 'i32', 7: 'i32'}, 'device': 0, 'constants': {}, 'mutated_arg_names': [], 'configs': [()]}\n)\n@triton.jit\ndef triton_3(in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, out_ptr1, xnumel, rnumel, XBLOCK: tl.constexpr):\n xnumel = 8192\n rnumel = 512\n # ... kernel logic ...\n\n# Kernel 4\n@reduction(\n size_hints=[65536, 2048],\n reduction_hint='inner',\n filename=__file__,\n meta={'signature': {0: '*bf16', 1: '*bf16', 2: 'i32', 3: 'i32'}, 'device': 0, 'constants': {}, 'mutated_arg_names': [], 'configs': [()]}\n)\n@triton.jit\ndef triton_4(in_ptr0, out_ptr2, xnumel, rnumel, XBLOCK: tl.constexpr, RBLOCK: tl.constexpr):\n xnumel = 65536\n rnumel = 2048\n # ... kernel logic ...\n\n# Kernel 5\n@persistent_reduction(\n size_hints=[8192, 512],\n reduction_hint='inner',\n filename=__file__,\n meta={'signature': {0: '*bf16', 1: '*bf16', 2: '*bf16', 3: '*bf16', 4: 'i32', 5: 'i32'}, 'device': 0, 'constants': {}, 'mutated_arg_names': [], 'configs': [()]}\n)\n@triton.jit\ndef triton_5(in_ptr0, in_ptr1, in_ptr2, out_ptr1, xnumel, rnumel, XBLOCK: tl.constexpr):\n xnumel = 8192\n rnumel = 512\n # ... kernel logic ...\n\n# Kernel 6\n@persistent_reduction(\n size_hints=[8192, 512],\n reduction_hint='inner',\n filename=__file__,\n meta={'signature': {0: '*bf16', 1: '*bf16', 2: '*bf16', 3: '*bf16', 4: '*bf16', 5: '*bf16', 6: 'i32', 7: 'i32'}, 'device': 0, 'constants': {}, 'mutated_arg_names': [], 'configs': [()]}\n)\n@triton.jit\ndef triton_6(in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, out_ptr1, xnumel, rnumel, XBLOCK: tl.constexpr):\n xnumel = 8192\n rnumel = 512\n # ... kernel logic ...\n", - "description_1": "Use triton language to define and execute a series of kernels for various operations on tensors, including reduction and pointwise operations. The kernels are designed to handle input pointers, output pointers, and execute operations using the Triton language's specific features like tile sizes and reduction hints.", - "description_2": "Use triton language to define kernels for tensor operations. Utilize reduction and pointwise operations with given parameters for input, output, and execution blocks.", - "difficulty": 4 - }, - { - "code": "import triton\nimport triton.language as tl\nfrom torch._inductor.triton_heuristics import reduction, pointwise, persistent_reduction\nfrom torch._inductor.utils import instance_descriptor\nfrom torch._inductor import triton_helpers\nimport torch\n\n@reduction(\n size_hints=[8192, 512],\n reduction_hint=tl.ReductionHint.DEFAULT,\n filename=__file__,\n meta={\n 'signature': {0: '*i64', 1: '*bf16', 2: '*bf16', 3: '*bf16', 4: 'i32', 5: 'i32'},\n 'device': 0,\n 'constants': {},\n 'mutated_arg_names': [],\n 'configs': [instance_descriptor(divisible_by_16=(0, 1, 2, 3, 4, 5), equal_to_1=())]\n }\n)\n@triton.jit\ndef triton_red_fused__to_copy_add_embedding_mean_mul_pow_rsqrt_0(in_ptr0, in_ptr1, in_ptr2, out_ptr1, xnumel, rnumel, XBLOCK: tl.constexpr, RBLOCK: tl.constexpr):\n xnumel = 8192\n rnumel = 512\n xoffset = tl.program_id(0) * XBLOCK\n xindex = xoffset + tl.arange(0, XBLOCK)[:, None]\n xmask = xindex < xnumel\n rbase = tl.arange(0, RBLOCK)[None, :]\n x0 = xindex\n tmp0 = tl.load(in_ptr0 + (x0), None, eviction_policy='evict_last')\n _tmp5 = tl.full([XBLOCK, RBLOCK], 0, tl.float32)\n for roffset in range(0, rnumel, RBLOCK):\n rindex = roffset + rbase\n rmask = rindex < rnumel\n r1 = rindex\n tmp1 = triton_helpers.promote_to_tensor(tmp0)\n tl.device_assert((0 <= tmp1) & (tmp1 < 32128), \"index out of bounds: 0 <= tmp1 < 32128\")\n tmp2 = tl.load(in_ptr1 + (r1 + (512*tmp0)), rmask, eviction_policy='evict_last', other=0).to(tl.float32)\n tmp3 = tmp2.to(tl.float32)\n tmp4 = tmp3 * tmp3\n tmp6 = _tmp5 + tmp4\n _tmp5 = tl.where(rmask, tmp6, _tmp5)\n tmp5 = tl.sum(_tmp5, 1)[:, None]\n for roffset in range(0, rnumel, RBLOCK):\n rindex = roffset + rbase\n rmask = rindex < rnumel\n r1 = rindex\n tmp7 = tl.load(in_ptr2 + (r1), rmask, eviction_policy='evict_last', other=0).to(tl.float32)\n tmp8 = triton_helpers.promote_to_tensor(tmp0)\n tl.device_assert((0 <= tmp8) & (tmp8 < 32128), \"index out of bounds: 0 <= tmp8 < 32128\")\n tmp9 = tl.load(in_ptr1 + (r1 + (512*tmp0)), rmask, other=0).to(tl.float32)\n tmp10 = tmp9.to(tl.float32)\n tmp11 = 512.0\n tmp12 = tmp5 / tmp11\n tmp13 = 1e-06\n tmp14 = tmp12 + tmp13\n tmp15 = tl.math.rsqrt(tmp14)\n tmp16 = tmp10 * tmp15\n tmp17 = tmp16.to(tl.float32)\n tmp18 = tmp7 * tmp17\n tl.store(out_ptr1 + (r1 + (512*x0)), tmp18, rmask)\n\ndef call(args):\n arg0, arg1, arg2, arg3, arg4, arg5 = args\n args.clear()\n torch.cuda.set_device(0)\n buf1 = torch.empty((8192, 512), dtype=torch.bfloat16, device='cuda')\n triton_red_fused__to_copy_add_embedding_mean_mul_pow_rsqrt_0.run(arg0, arg1, arg2, buf1, 8192, 512, grid=(8192,), stream=torch.cuda.current_stream())\n return buf1\n", - "description_1": "Use triton language to implement a kernel that performs an element-wise square operation on a subset of a 1D buffer, followed by a reduction operation (sum), which is then used to normalize another element-wise multiplication of the input buffers. The function takes six arguments: three input buffers, one output buffer, and two scalar dimensions indicating the sizes for block-wise operations.", - "description_2": "Use triton language to perform square and reduce operations on a 1D buffer, then use the result to normalize multiplication of input buffers with dimensions for block operations.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\nfrom triton.language.math import erf, pow, tanh\n\n@triton.jit\ndef gelu_none_and_mul_kernel(x, y):\n # Convert input to float32 for better precision in operations\n x_fp32 = x.to(tl.float32)\n # Compute the GELU function using the error function approximation\n x_gelu = 0.5 * x_fp32 * (1 + erf(x_fp32 * 0.7071067811))\n # Multiply the result by y and return\n return x_gelu * y\n\n@triton.jit\ndef gelu_tanh_and_mul_kernel(x, y):\n # Convert input to float32 for better precision in operations\n x_fp32 = x.to(tl.float32)\n # Compute the GELU function using the tanh approximation\n x_gelu = (\n 0.5\n * x_fp32\n * (\n 1\n + tanh(x_fp32 * 0.79788456 * (1 + 0.044715 * pow(x_fp32.to(tl.float32), 2)))\n )\n )\n # Multiply the result by y and return\n return x_gelu * y\n\nclass GeluAndMul(torch.autograd.Function):\n @staticmethod\n def forward(ctx, A, B, approximate=\"none\"):\n # Log debug information\n logging.debug(\"GEMS GELU AND MUL FORWARD\")\n # Choose the kernel based on the approximation method\n if approximate == \"none\":\n return gelu_none_and_mul_kernel(A, B)\n elif approximate == \"tanh\":\n return gelu_tanh_and_mul_kernel(A, B)\n else:\n raise ValueError(f\"Invalid approximate value: {approximate}\")\n\ndef gelu_and_mul(A, B, approximate=\"none\"):\n # Wrapper function for using GeluAndMul class\n return GeluAndMul.apply(A, B, approximate)\n", - "description_1": "Use triton language to implement two kernels, gelu_none_and_mul_kernel and gelu_tanh_and_mul_kernel, each taking two parameters x and y. gelu_none_and_mul_kernel applies the Gaussian Error Linear Unit (GELU) function using the error function approximation on x and multiplies the result by y. gelu_tanh_and_mul_kernel does the same using the tanh approximation. Both kernels return the product. Additionally, a GeluAndMul class is provided which selects one of these kernels based on an approximation method specified as 'none' or 'tanh'.", - "description_2": "Use triton language to implement kernels applying the GELU function on x, using different approximations, and then multiply by y. Include a mechanism to select the approximation method.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\nfrom typing import Optional\n\n@triton.jit\ndef apply_rotary_pos_emb_kernel(\n oq_ptr, ok_ptr, q_ptr, k_ptr, cos_ptr, sin_ptr, pos_ptr,\n q_stride_s, q_stride_h, q_stride_d, k_stride_s, k_stride_h, k_stride_d,\n oq_stride_s, oq_stride_h, oq_stride_d, ok_stride_s, ok_stride_h, ok_stride_d,\n p_stride_s, cos_stride_s, sin_stride_s, seq_len,\n NUM_Q_HEADS: tl.constexpr, NUM_K_HEADS: tl.constexpr,\n HEAD_DIM: tl.constexpr, PADDED_HEAD_DIM: tl.constexpr,\n ROTARY_INTERLEAVED: tl.constexpr, MAX_POSITION_EMBEDDINGS: tl.constexpr,\n):\n s_id = tl.program_id(0)\n if pos_ptr is None:\n pos_id = s_id % seq_len\n else:\n pos_ptr += s_id * p_stride_s\n pos_id = tl.load(pos_ptr)\n cos_ptr += pos_id * cos_stride_s\n sin_ptr += pos_id * sin_stride_s\n tl.device_assert(pos_id < MAX_POSITION_EMBEDDINGS, \"position id out of bound\")\n\n ordered_block = tl.arange(0, PADDED_HEAD_DIM)\n mask = ordered_block < HEAD_DIM\n if ROTARY_INTERLEAVED:\n odd_mask = ordered_block % 2 == 0\n rotated_block = tl.where(odd_mask, ordered_block + 1, ordered_block - 1)\n sin_cos_block = ordered_block // 2\n cos = tl.load(cos_ptr + sin_cos_block, mask=mask, other=0.0).to(tl.float32)\n sin = tl.load(sin_ptr + sin_cos_block, mask=mask, other=0.0).to(tl.float32)\n sin = tl.where(odd_mask, -sin, sin)\n else:\n rotated_block = (ordered_block + HEAD_DIM // 2) % HEAD_DIM\n sin_cos_block = ordered_block % (HEAD_DIM // 2)\n cos = tl.load(cos_ptr + sin_cos_block, mask=mask, other=0.0).to(tl.float32)\n sin = tl.load(sin_ptr + sin_cos_block, mask=mask, other=0.0).to(tl.float32)\n sin = tl.where(rotated_block < HEAD_DIM // 2, sin, -sin)\n\n oq_ptr += s_id * oq_stride_s\n q_ptr += s_id * q_stride_s\n\n for off_h in range(0, NUM_Q_HEADS):\n ordered_cols = off_h * q_stride_h + (ordered_block * q_stride_d)\n rotated_cols = off_h * q_stride_h + (rotated_block * q_stride_d)\n output_offs = off_h * oq_stride_h + (ordered_block * oq_stride_d)\n\n q = tl.load(q_ptr + ordered_cols, mask=mask, other=0.0)\n rotated_q = tl.load(q_ptr + rotated_cols, mask=mask, other=0.0)\n y = q * cos + rotated_q * sin\n tl.store(oq_ptr + output_offs, y, mask=mask)\n\n ok_ptr += s_id * ok_stride_s\n k_ptr += s_id * k_stride_s\n\n for off_h in range(0, NUM_K_HEADS):\n ordered_cols = off_h * k_stride_h + (ordered_block * k_stride_d)\n rotated_cols = off_h * k_stride_h + (rotated_block * k_stride_d)\n output_offs = off_h * ok_stride_h + (ordered_block * ok_stride_d)\n\n k = tl.load(k_ptr + ordered_cols, mask=mask, other=0.0)\n rotated_k = tl.load(k_ptr + rotated_cols, mask=mask, other=0.0)\n y = k * cos + rotated_k * sin\n tl.store(ok_ptr + output_offs, y, mask=mask)\n\n\ndef apply_rotary_pos_emb(\n q, k, cos, sin, position_ids: Optional[torch.IntTensor] = None, rotary_interleaved: bool = False,\n):\n assert k.shape[-1] == q.shape[-1]\n assert cos.shape[-1] == sin.shape[-1]\n assert cos.shape[-1] * 2 == q.shape[-1]\n assert cos.stride(-1) == 1\n assert sin.stride(-1) == 1\n\n q_shape = q.shape\n k_shape = k.shape\n assert q.shape[:-2] == k.shape[:-2]\n if position_ids is None:\n assert len(q.shape) == 4\n seq_len = q.shape[-3]\n else:\n assert position_ids.shape == q.shape[:-2]\n position_ids = position_ids.view(-1)\n seq_len = None\n\n q = q.view(-1, q.shape[-2], q.shape[-1])\n k = k.view(-1, k.shape[-2], k.shape[-1])\n\n q_embed = torch.empty_like(q)\n k_embed = torch.empty_like(k)\n\n n_tokens, q_heads, head_dim = q.shape\n padded_head_dim = max(triton.next_power_of_2(head_dim), 16)\n\n grid = (n_tokens,)\n with torch.cuda.device(q_embed.device):\n apply_rotary_pos_emb_kernel[grid](\n q_embed, k_embed, q, k, cos, sin, position_ids,\n q.stride(0), q.stride(1), q.stride(2), k.stride(0), k.stride(1), k.stride(2),\n q_embed.stride(0), q_embed.stride(1), q_embed.stride(2),\n k_embed.stride(0), k_embed.stride(1), k_embed.stride(2),\n position_ids.stride(0) if position_ids is not None else 0,\n cos.stride(0), sin.stride(0), seq_len,\n q.shape[-2], k.shape[-2], head_dim, padded_head_dim,\n rotary_interleaved, MAX_POSITION_EMBEDDINGS=cos.shape[0],\n )\n q_embed = q_embed.view(q_shape)\n k_embed = k_embed.view(k_shape)\n return q_embed, k_embed\n", - "description_1": "Use triton language to create a kernel apply_rotary_pos_emb_kernel that takes 30 parameters including pointers to tensors, strides, sequence length, and several constant expressions. This kernel applies rotary positional embeddings to queries and keys in a transformer model. Also, create a wrapper function apply_rotary_pos_emb in Python that takes 6 parameters including queries, keys, cosine and sine embedding tensors, optional position IDs, and a boolean flag. This function calls the kernel using the appropriate grid size and reshapes the results.", - "description_2": "Use triton language to create a kernel for applying rotary positional embeddings to transformer model tensors with a Python wrapper function for execution.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef silu_and_mul_kernel(x, y):\n # Convert input to float32\n x_fp32 = x.to(tl.float32)\n # Compute the SiLU activation\n x_silu = tl.fdiv(x_fp32, (1.0 + tl.exp(-x_fp32)))\n # Multiply the SiLU result with y\n return x_silu * y\n\nclass SiluAndMul(torch.autograd.Function):\n @staticmethod\n def forward(ctx, A, B):\n # Call the Triton kernel\n return silu_and_mul_kernel(A, B)\n\ndef silu_and_mul(A, B):\n # Wrapper function to apply the Triton kernel\n return SiluAndMul.apply(A, B)\n", - "description_1": "Use triton language to implement a kernel that computes the SiLU activation of input tensor x and multiplies it with tensor y. The kernel takes two parameters: x and y, both of which are tensors. The function silu_and_mul_kernel performs the computation, and the function silu_and_mul serves as a wrapper to apply this kernel.", - "description_2": "Use triton language to create a kernel for SiLU activation and multiplication with another tensor, and provide a wrapper function for easy application.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit(do_not_specialize=[\"eps\"])\ndef skip_layer_norm_kernel(\n Y, # pointer to the output\n X, # pointer to the input\n R, # pointer to the residual\n W, # pointer to the weights\n B, # pointer to the biases\n y_stride_r,\n y_stride_c,\n x_stride_r, # how much to increase the pointer when moving by 1 row\n x_stride_c, # how much to increase the pointer when moving by 1 col\n r_stride_r, # how much to increase the pointer when moving by 1 row\n r_stride_c, # how much to increase the pointer when moving by 1 col\n N, # number of columns in X\n eps, # epsilon to avoid division by zero\n BLOCK_SIZE: tl.constexpr,\n):\n pid = tl.program_id(0)\n Y += pid * y_stride_r\n X += pid * x_stride_r\n R += pid * r_stride_r\n\n mask = tl.arange(0, BLOCK_SIZE) < N\n cols = tl.arange(0, BLOCK_SIZE)\n x = tl.load(X + cols * x_stride_c, mask, other=0.0).to(tl.float32)\n r = tl.load(R + cols * r_stride_c, mask, other=0.0).to(tl.float32)\n\n x += r\n\n mean = tl.sum(x, axis=0) / N\n\n # Compute variance\n _var = tl.where(mask, x - mean, 0.0)\n _var = _var * _var\n var = tl.sum(_var, axis=0) / N\n rstd = 1 / tl.sqrt(var + eps)\n\n w = tl.load(W + tl.arange(0, BLOCK_SIZE), mask=mask, other=0.0).to(tl.float32)\n b = tl.load(B + tl.arange(0, BLOCK_SIZE), mask=mask, other=0.0).to(tl.float32)\n\n x_hat = (x - mean) * rstd\n y = w * x_hat + b\n y = y.to(Y.dtype.element_ty)\n tl.store(Y + cols * y_stride_c, y, mask=mask)\n\n\nclass SkipLayerNorm(torch.autograd.Function):\n @staticmethod\n def forward(ctx, x, residual, normalized_shape, weight, bias, eps=1e-5):\n dim = x.ndim - len(normalized_shape)\n M = math.prod(x.shape[:dim])\n N = math.prod(normalized_shape)\n\n BLOCK_SIZE = triton.next_power_of_2(N)\n x = x.contiguous()\n residual = residual.contiguous()\n weight = weight.contiguous()\n bias = bias.contiguous()\n y = torch.empty_like(x)\n\n with torch.cuda.device(x.device):\n skip_layer_norm_kernel[M,](\n y, x, residual, weight, bias, N, 1, N, 1, N, 1, N, eps, BLOCK_SIZE\n )\n return y\n\n\ndef skip_layer_norm(x, residual, normalized_shape, weight, bias, eps=1e-5):\n return SkipLayerNorm.apply(x, residual, normalized_shape, weight, bias, eps)\n", - "description_1": "Use triton language to implement a skip-layer normalization operation as a kernel function 'skip_layer_norm_kernel' that performs normalization on input tensors with residual connection. It accepts 14 parameters: output pointer (Y), input pointer (X), residual pointer (R), weights pointer (W), biases pointer (B), various strides for Y, X, R (y_stride_r, y_stride_c, x_stride_r, x_stride_c, r_stride_r, r_stride_c), number of columns in X (N), epsilon for numerical stability (eps), and block size (BLOCK_SIZE). A wrapper function 'SkipLayerNorm' manages data preparation and kernel launch. The final interface 'skip_layer_norm' is used to invoke this functionality from PyTorch, accepting tensors and an epsilon as parameters.", - "description_2": "Use triton language to create a kernel for skip-layer normalization, handling input tensors with a residual addition. Encapsulate the operation in a class-based autograd function for seamless PyTorch integration.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit(do_not_specialize=[\"eps\"])\ndef skip_rms_norm_kernel(\n Y, # pointer to the output\n X, # pointer to the input\n R, # pointer to the residual\n W, # pointer to the weights\n y_stride_r,\n y_stride_c,\n x_stride_r, # how much to increase the pointer when moving by 1 row\n x_stride_c, # how much to increase the pointer when moving by 1 col\n r_stride_r, # how much to increase the pointer when moving by 1 row\n r_stride_c, # how much to increase the pointer when moving by 1 col\n N, # number of columns in X\n eps, # epsilon to avoid division by zero\n BLOCK_SIZE: tl.constexpr,\n):\n pid = tl.program_id(0)\n Y += pid * y_stride_r\n X += pid * x_stride_r\n R += pid * r_stride_r\n\n mask = tl.arange(0, BLOCK_SIZE) < N\n cols = tl.arange(0, BLOCK_SIZE)\n x = tl.load(X + cols * x_stride_c, mask, other=0.0).to(tl.float32)\n r = tl.load(R + cols * r_stride_c, mask, other=0.0).to(tl.float32)\n\n x += r\n\n var = tl.sum(x * x / N, axis=0)\n rrms = 1 / tl.sqrt(var + eps)\n\n w = tl.load(W + tl.arange(0, BLOCK_SIZE), mask=mask, other=0.0)\n y = (x * rrms).to(Y.dtype.element_ty) * w\n tl.store(Y + cols * y_stride_c, y, mask=mask)\n\n\nclass SkipRmsNorm(torch.autograd.Function):\n @staticmethod\n def forward(ctx, x, residual, normalized_shape, weight, eps=1e-5):\n dim = x.ndim - len(normalized_shape)\n M = math.prod(x.shape[:dim])\n N = math.prod(normalized_shape)\n\n BLOCK_SIZE = triton.next_power_of_2(N)\n x = x.contiguous()\n residual = residual.contiguous()\n weight = weight.contiguous()\n y = torch.empty_like(x)\n\n with torch.cuda.device(x.device):\n skip_rms_norm_kernel[M,](\n y, x, residual, weight, N, 1, N, 1, N, 1, N, eps, BLOCK_SIZE\n )\n return y\n\n\ndef skip_rms_norm(x, residual, normalized_shape, weight, eps=1e-5):\n return SkipRmsNorm.apply(x, residual, normalized_shape, weight, eps)\n", - "description_1": "Use triton language to implement a kernel function 'skip_rms_norm_kernel' that performs skip residual RMS normalization. The kernel takes 13 parameters: pointers to output (Y), input (X), residual (R), weights (W), strides for Y, X, and R, number of columns (N), epsilon (eps) to avoid division by zero, and a block size (BLOCK_SIZE). The kernel computes the variance, root mean square, and applies weights to store the result in Y. The function 'skip_rms_norm' is a wrapper that prepares inputs and calls the kernel.", - "description_2": "Use triton language to create a kernel for skip residual RMS normalization with input, residual, and weight pointers, and a wrapper function to execute it.", - "difficulty": 2 - }, - { - "code": "import torch\nimport triton\n\n@triton.jit\ndef add_func(x, y, alpha):\n # Triton kernel to add two tensors with a scalar multiplier\n return x + y * alpha\n\n@triton.jit\ndef add_func_tensor_scalar(x, y, alpha):\n # Triton kernel to add a tensor and a scalar with a scalar multiplier\n return x + y * alpha\n\n@triton.jit\ndef add_func_scalar_tensor(x, y, alpha):\n # Triton kernel to add a scalar and a tensor with a scalar multiplier\n return x + y * alpha\n\ndef add(A, B, *, alpha=1):\n # Function to select appropriate Triton kernel based on input types\n if isinstance(A, torch.Tensor) and isinstance(B, torch.Tensor):\n return add_func(A, B, alpha)\n elif isinstance(A, torch.Tensor):\n return add_func_tensor_scalar(A, B, alpha)\n elif isinstance(B, torch.Tensor):\n return add_func_scalar_tensor(A, B, alpha)\n else:\n return torch.tensor(A + B * alpha)\n", - "description_1": "Use triton language to implement three kernels: (1) add_func with 3 parameters (x, y, alpha) which adds two tensors with a scalar multiplier. (2) add_func_tensor_scalar with 3 parameters (x, y, alpha) which adds a tensor and a scalar with a scalar multiplier. (3) add_func_scalar_tensor with 3 parameters (x, y, alpha) which adds a scalar and a tensor with a scalar multiplier. Implement an add function that selects the appropriate kernel based on the types of the inputs A and B.", - "description_2": "Use triton language to create kernels for tensor and scalar addition with scalar multiplication, and implement a dispatcher function to choose the correct kernel.", - "difficulty": 2 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit(do_not_specialize=[\"alpha\", \"beta\"])\ndef addmm_kernel(\n a_ptr,\n b_ptr,\n bias_ptr,\n c_ptr,\n alpha,\n beta,\n M,\n N,\n K,\n stride_am,\n stride_ak,\n stride_bk,\n stride_bn,\n stride_cm,\n stride_cn,\n BLOCK_SIZE_M: tl.constexpr,\n BLOCK_SIZE_N: tl.constexpr,\n BLOCK_SIZE_K: tl.constexpr,\n):\n pid_m = tl.program_id(0)\n pid_n = tl.program_id(1)\n\n offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n offs_k = tl.arange(0, BLOCK_SIZE_K)\n a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)\n b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)\n bias_ptrs = bias_ptr + offs_bn\n\n accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):\n a = tl.load(\n a_ptrs,\n mask=(offs_am[:, None] < M) & (offs_k[None, :] < K - k * BLOCK_SIZE_K),\n other=0.0,\n )\n b = tl.load(\n b_ptrs,\n mask=(offs_k[:, None] < K - k * BLOCK_SIZE_K) & (offs_bn[None, :] < N),\n other=0.0,\n )\n accumulator += tl.dot(a, b, allow_tf32=False)\n a_ptrs += BLOCK_SIZE_K * stride_ak\n b_ptrs += BLOCK_SIZE_K * stride_bk\n bias = tl.load(bias_ptrs, mask=offs_bn < N, other=0.0)\n accumulator = accumulator * alpha + bias * beta\n c = accumulator.to(bias.dtype)\n\n offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]\n c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)\n tl.store(c_ptrs, c, mask=c_mask)\n\n\ndef addmm(bias, mat1, mat2, *, beta=1, alpha=1):\n assert mat1.shape[1] == mat2.shape[0], \"Incompatible dimensions\"\n M, K = mat1.shape\n _, N = mat2.shape\n\n mat1 = mat1.contiguous()\n mat2 = mat2.contiguous()\n out = torch.empty((M, N), device=mat1.device, dtype=mat1.dtype)\n\n grid = lambda META: (\n triton.cdiv(M, META[\"BLOCK_SIZE_M\"]),\n triton.cdiv(N, META[\"BLOCK_SIZE_N\"]),\n )\n with torch.cuda.device(mat1.device):\n addmm_kernel[grid](\n mat1,\n mat2,\n bias,\n out,\n alpha,\n beta,\n M,\n N,\n K,\n mat1.stride(0),\n mat1.stride(1),\n mat2.stride(0),\n mat2.stride(1),\n out.stride(0),\n out.stride(1),\n )\n return out\n", - "description_1": "Use triton language to implement a matrix multiplication kernel with bias addition. The kernel 'addmm_kernel' takes 18 parameters: pointers to matrices A, B, bias, and output C, scalars alpha and beta, dimensions M, N, K, strides for A, B, and C, and block sizes for M, N, and K. The function 'addmm' prepares the input matrices, sets up the grid for execution, and calls the kernel with the appropriate parameters.", - "description_2": "Use triton language to perform matrix multiplication with bias addition using a custom kernel. The kernel computes the product of two matrices and adds a bias, controlled by alpha and beta scaling factors.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n# Helper function used to combine boolean values\n@triton.jit\ndef reduce_all(a, b):\n return a and b\n\n# Triton kernel that computes if all elements are non-zero along specified dimensions\n@triton.jit\ndef all_kernel_dim(\n inp,\n out,\n M,\n N,\n BLOCK_M: tl.constexpr,\n BLOCK_N: tl.constexpr,\n):\n pid = tl.program_id(0)\n rows = pid * BLOCK_M + tl.arange(0, BLOCK_M)[:, None]\n inp = inp + rows * N\n out = out + rows\n row_mask = rows < M\n\n _all = tl.full([BLOCK_M, BLOCK_N], value=1, dtype=tl.int1)\n for off in range(0, N, BLOCK_N):\n cols = off + tl.arange(0, BLOCK_N)[None, :]\n col_mask = cols < N\n mask = row_mask and col_mask\n\n a = tl.load(inp + cols, mask, other=1.0)\n _all = _all and (a != 0)\n all = tl.reduce(_all, axis=1, combine_fn=reduce_all)\n tl.store(out, all[:, None], row_mask)\n\n# Triton kernel that computes if all elements are non-zero in the entire input\n@triton.jit\ndef all_kernel_1(\n inp,\n mid,\n n_elements,\n mid_size,\n BLOCK_SIZE: tl.constexpr,\n):\n pid = tl.program_id(0)\n offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)\n inp_ptrs = inp + offset\n mask = offset < n_elements\n inp_val = tl.load(inp_ptrs, mask=mask, other=1.0)\n all_val = tl.reduce(inp_val != 0, axis=0, combine_fn=reduce_all)\n mid_ptr = mid + pid\n tl.store(mid_ptr, all_val)\n\n# Triton kernel that reduces the mid results to a single output\n@triton.jit\ndef all_kernel_2(mid, out, MID_SIZE, BLOCK_MID: tl.constexpr):\n offset = tl.arange(0, BLOCK_MID)\n mid_ptrs = mid + offset\n mask = offset < MID_SIZE\n mid_val = tl.load(mid_ptrs, mask=mask, other=1).to(tl.int1)\n all_val = tl.reduce(mid_val, axis=0, combine_fn=reduce_all)\n tl.store(out, all_val)\n\n# Wrapper function for all_kernel_1 and all_kernel_2 to compute \"all\" operation on input\ndef all(inp):\n n_elements = inp.numel()\n block_size = triton.next_power_of_2(math.ceil(math.sqrt(n_elements)))\n mid_size = triton.cdiv(n_elements, block_size)\n block_mid = triton.next_power_of_2(mid_size)\n\n mid = torch.empty((mid_size,), dtype=torch.bool, device=inp.device)\n out = torch.empty([], dtype=torch.bool, device=inp.device)\n\n with torch.cuda.device(inp.device):\n all_kernel_1[(mid_size, 1)](inp, mid, n_elements, mid_size, block_size)\n all_kernel_2[(1, 1)](mid, out, mid_size, block_mid)\n\n return out\n\n# Wrapper function for all_kernel_dim to compute \"all\" operation along specified dimensions\ndef all_dim(inp, dim=None, keepdim=False):\n shape = list(inp.shape)\n if dim is None:\n out = all(inp)\n if keepdim:\n out = torch.reshape(out, [1] * inp.ndim)\n else:\n dim = dim % inp.ndim\n N = shape[dim]\n shape[dim] = 1\n M = inp.numel() // N\n\n out = torch.empty(shape, dtype=torch.bool, device=inp.device)\n\n grid = lambda meta: (triton.cdiv(M, meta[\"BLOCK_M\"]),)\n with torch.cuda.device(inp.device):\n all_kernel_dim[grid](inp, out, M, N)\n if not keepdim:\n out = out.squeeze(dim=dim)\n return out\n", - "description_1": "Use triton language to implement three kernels: reduce_all, all_kernel_dim, all_kernel_1, and all_kernel_2. The reduce_all kernel takes two arguments and returns their logical AND. The all_kernel_dim takes 6 arguments: inp, out, M, N, BLOCK_M, BLOCK_N and computes whether all elements are non-zero along specified dimensions using BLOCK_M and BLOCK_N as block sizes. The all_kernel_1 takes 5 arguments: inp, mid, n_elements, mid_size, BLOCK_SIZE and computes whether all elements are non-zero in the input, storing intermediate results in mid. The all_kernel_2 takes 4 arguments: mid, out, MID_SIZE, BLOCK_MID and reduces the intermediate results from mid to a final boolean output.", - "description_2": "Use triton language to create a set of kernels that evaluate whether all elements are non-zero in a tensor, both for the entire tensor and along specified dimensions, using logical operations.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n# Kernel 1: amax_kernel_1\n@triton.jit\ndef amax_kernel_1(\n inp,\n mid,\n M,\n BLOCK_SIZE: tl.constexpr,\n INT64_INDEX: tl.constexpr = False,\n):\n pid = tl.program_id(0)\n if INT64_INDEX:\n pid = pid.to(tl.int64)\n offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)\n inp_ptrs = inp + offset\n mask = offset < M\n inp_val = tl.load(inp_ptrs, mask=mask, other=-float(\"inf\"))\n amax_val = tl.max(inp_val)\n mid_ptr = mid + pid\n tl.store(mid_ptr, amax_val)\n\n\n# Kernel 2: amax_kernel_2\n@triton.jit\ndef amax_kernel_2(mid, out, mid_size, BLOCK_MID: tl.constexpr):\n offset = tl.arange(0, BLOCK_MID)\n mid_ptrs = mid + offset\n mask = offset < mid_size\n mid_val = tl.load(mid_ptrs, mask=mask, other=-float(\"inf\"))\n amax_val = tl.max(mid_val)\n tl.store(out, amax_val)\n\n\n# Kernel 3: amax_kernel\n@triton.autotune(configs=cfggen(), key=[\"M\", \"N\"])\n@triton.jit\ndef amax_kernel(\n inp,\n out,\n M,\n N,\n BLOCK_M: tl.constexpr,\n BLOCK_N: tl.constexpr,\n INT64_INDEX: tl.constexpr = False,\n):\n # Map the program id to the row of inp it should compute.\n pid = tl.program_id(0)\n if INT64_INDEX:\n pid = pid.to(tl.int64)\n rows = pid * BLOCK_M + tl.arange(0, BLOCK_M)[:, None]\n inp = inp + rows * N\n out = out + rows\n row_mask = rows < M\n\n _all = tl.full([BLOCK_M, BLOCK_N], value=-float(\"inf\"), dtype=tl.float32)\n for off in range(0, N, BLOCK_N):\n cols = off + tl.arange(0, BLOCK_N)[None, :]\n col_mask = cols < N\n mask = row_mask and col_mask\n\n a = tl.load(inp + cols, mask, other=-float(\"inf\")).to(tl.float32)\n _all = tl.maximum(_all, a)\n all = tl.max(_all, axis=1)[:, None]\n tl.store(out, all, row_mask)\n\n\n# Function to call the kernels\ndef amax(inp, dim=None, keepdim=False):\n logging.debug(\"GEMS AMAX\")\n if dim is None or len(dim) == 0:\n M = inp.numel()\n block_size = triton.next_power_of_2(math.ceil(math.sqrt(M)))\n mid_size = triton.cdiv(M, block_size)\n block_mid = triton.next_power_of_2(mid_size)\n dtype = inp.dtype\n mid = torch.empty((mid_size,), dtype=dtype, device=inp.device)\n use_int64_index = not can_use_int32_index(inp)\n if not keepdim:\n out = torch.empty([], dtype=dtype, device=inp.device)\n else:\n shape = list(inp.shape)\n for i in range(0, inp.dim()):\n shape[i] = 1\n out = torch.empty(shape, dtype=dtype, device=inp.device)\n with torch.cuda.device(inp.device):\n amax_kernel_1[(mid_size, 1)](\n inp, mid, M, block_size, INT64_INDEX=use_int64_index\n )\n amax_kernel_2[(1, 1)](\n mid, out, mid_size, block_mid\n ) # max block size is 128k, so mid does not require int64 index\n return out\n else:\n if isinstance(dim, int):\n dim = [dim]\n assert ((i >= -inp.ndim and i < inp.ndim) for i in dim), \"Invalid dim\"\n dtype = inp.dtype\n\n shape = list(inp.shape)\n dim = [d % inp.ndim for d in dim]\n inp = dim_compress(inp, dim)\n use_int64_index = not can_use_int32_index(inp)\n N = 1\n for i in dim:\n N *= shape[i]\n shape[i] = 1\n M = inp.numel() // N\n\n out = torch.empty(shape, dtype=dtype, device=inp.device)\n\n grid = lambda meta: (triton.cdiv(M, meta[\"BLOCK_M\"]),)\n with torch.cuda.device(inp.device):\n amax_kernel[grid](inp, out, M, N, INT64_INDEX=use_int64_index)\n if not keepdim:\n out = out.squeeze(dim=dim)\n return out\n\n\n# Helper function to generate configurations for autotuning\ndef cfggen():\n block_m = [1, 2, 4, 8]\n configs = [\n triton.Config({\"BLOCK_M\": m, \"BLOCK_N\": 1024}, num_warps=4) for m in block_m\n ]\n return configs\n\n", - "description_1": "Use Triton language to implement three kernels for computing the maximum values of a tensor along a given axis or across the entire tensor, utilizing blocks for parallelism. The kernels employ block size tuning and the option to use 64-bit indexing for handling larger tensor sizes. The first kernel, amax_kernel_1, calculates intermediate maximum values over blocks; the second kernel, amax_kernel_2, computes the final maximum from these intermediate results; and the third kernel, amax_kernel, directly computes the result for a general case of maximum reduction.", - "description_2": "Use Triton language to implement parallelized maximum reduction kernels (amax_kernel_1, amax_kernel_2, amax_kernel) with block-level computation and automatic block size tuning, supporting both 32-bit and 64-bit index types for large tensors.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n# Simple reduction function to check if any value is non-zero\n@triton.jit\ndef reduce_any(a, b):\n return a or b\n\n# Kernel that operates on a specified dimension of the input tensor\n@triton.jit\ndef any_kernel_dim(\n inp,\n out,\n M,\n N,\n BLOCK_M: tl.constexpr,\n BLOCK_N: tl.constexpr,\n):\n # Map the program id to the row of inp it should compute.\n pid = tl.program_id(0)\n rows = pid * BLOCK_M + tl.arange(0, BLOCK_M)[:, None]\n inp = inp + rows * N\n out = out + rows\n row_mask = rows < M\n\n _any = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.int1)\n for off in range(0, N, BLOCK_N):\n cols = off + tl.arange(0, BLOCK_N)[None, :]\n col_mask = cols < N\n mask = row_mask & col_mask\n\n a = tl.load(inp + cols, mask, other=0.0)\n _any = _any | (a != 0)\n any = tl.reduce(_any, axis=1, combine_fn=reduce_any)\n tl.store(out, any[:, None], row_mask)\n\n# Kernel to check if any element is non-zero across blocks of a tensor\n@triton.jit\ndef any_kernel_1(\n inp,\n mid,\n n_elements,\n mid_size,\n BLOCK_SIZE: tl.constexpr,\n):\n pid = tl.program_id(0)\n offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)\n inp_ptrs = inp + offset\n mask = offset < n_elements\n inp_val = tl.load(inp_ptrs, mask=mask, other=0.0)\n any_val = tl.reduce(inp_val != 0, axis=0, combine_fn=reduce_any)\n mid_ptr = mid + pid\n tl.store(mid_ptr, any_val)\n\n# Final kernel that reduces intermediate results to a final output\n@triton.jit\ndef any_kernel_2(mid, out, MID_SIZE, BLOCK_MID: tl.constexpr):\n offset = tl.arange(0, BLOCK_MID)\n mid_ptrs = mid + offset\n mask = offset < MID_SIZE\n mid_val = tl.load(mid_ptrs, mask=mask, other=0).to(tl.int1)\n any_val = tl.reduce(mid_val, axis=0, combine_fn=reduce_any)\n tl.store(out, any_val)\n\n# Function to check if any element in the input tensor is non-zero\ndef any(inp):\n n_elements = inp.numel()\n block_size = triton.next_power_of_2(math.ceil(math.sqrt(n_elements)))\n mid_size = triton.cdiv(n_elements, block_size)\n block_mid = triton.next_power_of_2(mid_size)\n\n mid = torch.empty((mid_size,), dtype=torch.bool, device=inp.device)\n out = torch.empty([], dtype=torch.bool, device=inp.device)\n\n with torch.cuda.device(inp.device):\n any_kernel_1[(mid_size, 1)](inp, mid, n_elements, mid_size, block_size)\n any_kernel_2[(1, 1)](mid, out, mid_size, block_mid)\n\n return out\n\n# Function to check if any element in a specified dimension of the input tensor is non-zero\ndef any_dim(inp, dim=None, keepdim=False):\n shape = list(inp.shape)\n if dim is None:\n out = any(inp)\n if keepdim:\n out = torch.reshape(out, [1] * inp.ndim)\n else:\n assert dim >= -inp.ndim and dim < inp.ndim, \"Invalid dim\"\n dim = dim % inp.ndim\n inp = dim_compress(inp, dim)\n N = shape[dim]\n shape[dim] = 1\n M = inp.numel() // N\n\n out = torch.empty(shape, dtype=torch.bool, device=inp.device)\n\n grid = lambda meta: (triton.cdiv(M, meta[\"BLOCK_M\"]),)\n with torch.cuda.device(inp.device):\n any_kernel_dim[grid](inp, out, M, N)\n if not keepdim:\n out = out.squeeze(dim=dim)\n return out\n", - "description_1": "Use triton language to implement three kernels: reduce_any to check if any input is non-zero, any_kernel_dim to perform reduction along a specific dimension of input tensor, and any_kernel_1 and any_kernel_2 for block-based reduction across entire input tensor. Corresponding Python functions any and any_dim call these kernels.", - "description_2": "Use triton language to implement kernels for checking if any value is non-zero in a tensor and perform reductions along specific dimensions or across entire input tensor.", - "difficulty": 2 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef arange_func(y_ptr, start, end, step, size, BLOCK_SIZE: tl.constexpr):\n pid = tl.program_id(0)\n y_ptr += pid * BLOCK_SIZE\n step_offset = pid * BLOCK_SIZE * step\n\n cols = tl.arange(0, BLOCK_SIZE)\n arange_val = cols * step + step_offset + start\n mask = cols + pid * BLOCK_SIZE\n tl.store(y_ptr + cols, arange_val, mask=mask < size)\n\ndef arange_start(\n start, end, step=1, *, dtype=None, layout=None, device=None, pin_memory=None\n):\n if dtype is torch.int64:\n sgn = (step > 0) - (step < 0)\n size = (end - start + step - sgn) // step\n else:\n size = math.ceil((end - start) / step)\n\n BLOCK_SIZE = 128\n grid = triton.cdiv(size, BLOCK_SIZE)\n\n if dtype is None:\n dtype = torch.int64\n\n if pin_memory is None:\n pin_memory = False\n\n if device is None:\n device = torch.device(\"cuda\")\n\n result = torch.empty((size,), device=device, dtype=dtype, pin_memory=pin_memory)\n arange_func[grid,](result, start, end, step, size, BLOCK_SIZE)\n return result\n\ndef arange(end, *, dtype=None, layout=None, device=None, pin_memory=None):\n return arange_start(\n 0, end, 1, dtype=dtype, layout=layout, device=device, pin_memory=pin_memory\n )\n", - "description_1": "Use triton language to implement a kernel 'arange_func' that generates a sequence of numbers on the GPU. This kernel takes six parameters: y_ptr (output pointer), start (starting value), end (end value, unused), step (difference between consecutive values), size (total number of elements), and BLOCK_SIZE (constant block size). It calculates each element's value based on its index and stores the result in 'y_ptr'. The 'arange_start' function orchestrates the kernel launch, calculating the total size, setting up grid dimensions, and preparing the output tensor. The 'arange' function is a wrapper that calls 'arange_start' with a default start of 0 and step of 1.", - "description_2": "Use triton language to create an arange function generating sequences on the GPU with customizable start, end, and step values.", - "difficulty": 2 - }, - { - "code": "import triton\n\n@triton.jit\ndef bitwise_and_func_scalar(x, y):\n return x & y\n\ndef bitwise_and_scalar(A, B):\n return bitwise_and_func_scalar(A, B)\n\ndef bitwise_and_scalar_tensor(A, B):\n return bitwise_and_func_scalar(B, A)\n", - "description_1": "Use triton language to define a kernel 'bitwise_and_func_scalar' that performs a bitwise AND operation on two inputs 'x' and 'y'. The kernel is called by two functions: 'bitwise_and_scalar' which takes two arguments 'A' and 'B' and calls the kernel with these arguments, and 'bitwise_and_scalar_tensor' which also takes two arguments 'A' and 'B' but calls the kernel with 'B' and 'A' in reverse order.", - "description_2": "Use triton language to define a kernel for bitwise AND operation and implement two functions to call this kernel with different argument orders.", - "difficulty": 1 - }, - { - "code": "import logging\nimport triton\n\n\n# Triton kernel function\n@triton.jit\ndef bitwise_or_func(x, y):\n return x | y\n\n\n# Wrapper function for calling Triton kernel\ndef bitwise_or_tensor(A, B):\n logging.debug(\"GEMS BITWISE OR\")\n return bitwise_or_func(A, B)\n\n\n# Triton kernel function\n@triton.jit\ndef bitwise_or_func_scalar(x, y):\n return x | y\n\n\n# Wrapper function for calling Triton kernel\ndef bitwise_or_scalar(A, B):\n logging.debug(\"GEMS BITWISE OR SCALAR\")\n return bitwise_or_func_scalar(A, B)\n\n\n# Wrapper function for calling Triton kernel\ndef bitwise_or_scalar_tensor(A, B):\n logging.debug(\"GEMS BITWISE OR SCALAR TENSOR\")\n return bitwise_or_func_scalar(B, A)\n", - "description_1": "Use triton language to define two kernel functions: bitwise_or_func and bitwise_or_func_scalar. Both perform a bitwise OR operation between two inputs. The bitwise_or_func kernel operates on tensors, while bitwise_or_func_scalar is designed to handle scalar and tensor inputs. Each function is wrapped by a corresponding wrapper function (bitwise_or_tensor and bitwise_or_scalar) for easier calling, which takes in two inputs and returns the result of the kernel function.", - "description_2": "Use triton language to perform bitwise OR operation on two inputs, supporting both tensor and scalar input types.", - "difficulty": 1 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\nimport logging\n\n# Kernel for batched matrix multiplication\n@triton.jit\ndef bmm_kernel(\n A, B, O, M, N, K,\n TILE_M: tl.constexpr, TILE_N: tl.constexpr, TILE_K: tl.constexpr,\n GROUP_M: tl.constexpr, DIVISIBLE_M: tl.constexpr,\n DIVISIBLE_N: tl.constexpr, DIVISIBLE_K: tl.constexpr\n):\n # batch offsets\n pid_b = tl.program_id(2)\n A += pid_b * M * K\n B += pid_b * K * N\n O += pid_b * M * N\n\n pidx = tl.program_id(0)\n pidy = tl.program_id(1)\n\n if GROUP_M == 1:\n pid_m, pid_n = pidx, pidy\n else:\n # reorder CTAs\n gridx = tl.num_programs(0)\n gridy = tl.num_programs(1)\n pid = pidx + pidy * gridx\n\n num_CTA_per_group = gridy * GROUP_M\n\n group_id = pid // num_CTA_per_group\n inner_group_id = pid % num_CTA_per_group\n if (group_id * GROUP_M + GROUP_M) > gridx:\n GROUP_SIZE = gridx % GROUP_M\n else:\n GROUP_SIZE = GROUP_M\n pid_m = group_id * GROUP_M + inner_group_id % GROUP_SIZE\n pid_n = inner_group_id // GROUP_SIZE\n\n offs_m = pid_m * TILE_M + tl.arange(0, TILE_M)\n offs_n = pid_n * TILE_N + tl.arange(0, TILE_N)\n offs_k = tl.arange(0, TILE_K)\n\n if not DIVISIBLE_M:\n mask_m = offs_m < M\n if not DIVISIBLE_N:\n mask_n = offs_n < N\n\n a_ptrs = A + offs_m[:, None] * K + offs_k[None, :]\n b_ptrs = B + offs_k[:, None] * N + offs_n[None, :]\n o_ptrs = O + offs_m[:, None] * N + offs_n[None, :]\n\n num_iters = tl.cdiv(K, TILE_K)\n o = tl.zeros((TILE_M, TILE_N), dtype=tl.float32)\n for _ in range(num_iters):\n if DIVISIBLE_K:\n if DIVISIBLE_M:\n mask_a = None\n else:\n mask_a = mask_m[:, None]\n if DIVISIBLE_N:\n mask_b = None\n else:\n mask_b = mask_n[None, :]\n else:\n mask_k = offs_k < K\n if DIVISIBLE_M:\n mask_a = mask_k[None, :]\n else:\n mask_a = mask_m[:, None] & mask_k[None, :]\n if DIVISIBLE_N:\n mask_b = mask_k[:, None]\n else:\n mask_b = mask_k[:, None] & mask_n[None, :]\n\n a = tl.load(a_ptrs, mask_a)\n b = tl.load(b_ptrs, mask_b)\n\n offs_k += TILE_K\n a_ptrs += TILE_K\n b_ptrs += TILE_K * N\n\n o += tl.dot(a, b, allow_tf32=False)\n\n if DIVISIBLE_M and DIVISIBLE_N:\n mask_c = None\n elif DIVISIBLE_M and not DIVISIBLE_N:\n mask_c = mask_n[None, :]\n elif not DIVISIBLE_M and DIVISIBLE_N:\n mask_c = mask_m[:, None]\n else:\n mask_c = mask_m[:, None] & mask_n[None, :]\n tl.store(o_ptrs, o, mask_c)\n\n# Function to execute the batched matrix multiplication kernel\ndef bmm(A, B):\n logging.debug(\"GEMS BMM\")\n batch, M, K = A.shape\n _, _, N = B.shape\n A = A.contiguous()\n B = B.contiguous()\n out = torch.empty((batch, M, N), dtype=A.dtype, device=A.device)\n\n grid_fn = lambda meta: (\n triton.cdiv(meta[\"M\"], meta[\"TILE_M\"]),\n triton.cdiv(meta[\"N\"], meta[\"TILE_N\"]),\n batch,\n )\n with torch.cuda.device(A.device):\n bmm_kernel[grid_fn](A, B, out, M, N, K)\n return out\n", - "description_1": "Use triton language to implement a batched matrix multiplication kernel. The kernel bmm_kernel takes 15 parameters: A, B, O (matrices), M, N, K (dimensions), TILE_M, TILE_N, TILE_K (tile sizes), GROUP_M, DIVISIBLE_M, DIVISIBLE_N, DIVISIBLE_K (group and divisibility flags). The bmm function calls this kernel and manages input matrices, setting up output storage, and configuration.", - "description_2": "Use triton language to define and call a kernel for performing batched matrix multiplication of matrices A and B with specified tiling and grouping configurations.", - "difficulty": 4 - }, - { - "code": "import triton\n\n@triton.jit\ndef copy_func(x):\n # The kernel copies the input tensor to the output tensor.\n return x\n\n\ndef cat(\n A: Union[Tuple[torch.Tensor, ...], List[torch.Tensor]], dim: int = 0\n) -> torch.Tensor:\n # Concatenate a list of tensors along a specified dimension using a Triton kernel.\n if len(A) == 0:\n raise RuntimeError(\"torch.cat(): expected a non-empty list of Tensors\")\n if len(A) == 1:\n return A[0]\n inp_shapes = [list(_.shape) for _ in A]\n inp0_shape = inp_shapes[0]\n for s in inp_shapes[1:]:\n if len(s) != len(inp0_shape):\n raise RuntimeError(\n f\"Tensors must have same number of dimensions: got {len(inp0_shape)} and {len(s)}\"\n )\n for tensor_idx, inp_shape in enumerate(inp_shapes):\n for idx, (common_length, length) in enumerate(zip(inp0_shape, inp_shape)):\n if idx == dim:\n continue\n elif length != common_length:\n raise RuntimeError(\n f\"Sizes of tensors must match except in dimension {dim}. \"\n f\"Expected size {common_length} but got size {length} for tensor number \"\n f\"{tensor_idx} in the list\"\n )\n\n out_shape = list(inp0_shape)\n out_shape[dim] = sum(s[dim] for s in inp_shapes)\n out0 = torch.empty(out_shape, dtype=A[0].dtype, device=A[0].device)\n out0_strides = out0.stride()\n out0_offsets = list(\n itertools.accumulate(\n [s[dim] * out0_strides[dim] for s in inp_shapes[:-1]], initial=0\n )\n )\n\n for a, out0_offset in zip(A, out0_offsets):\n in_view = StridedBuffer(a, a.shape, a.stride())\n out_view = StridedBuffer(out0, a.shape, out0.stride(), offset=out0_offset)\n copy_func.instantiate(a.ndim)(in_view, out0=out_view)\n return out0\n", - "description_1": "Use triton language to define a copy kernel named 'copy_func' which copies input tensor x to output tensor. Implement a function 'cat' to concatenate a list of torch.Tensors along a given dimension, utilizing the 'copy_func' kernel. The function 'cat' takes two parameters: A (the list or tuple of torch.Tensors to concatenate) and dim (the dimension along which to concatenate the tensors).", - "description_2": "Use triton language to implement a tensor copy kernel and a function to concatenate tensors using this kernel.", - "difficulty": 4 - }, - { - "code": "import logging\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef clamp_func_tensor(x, mini, maxi):\n # Clamp each element of x between mini and maxi.\n return tl.minimum(maxi, tl.maximum(mini, x.to(tl.float32)))\n\n@triton.jit\ndef clamp_func_min_tensor(x, mini):\n # Clamp each element of x to be at least mini.\n return tl.maximum(mini, x.to(tl.float32))\n\n@triton.jit\ndef clamp_func_max_tensor(x, maxi):\n # Clamp each element of x to be at most maxi.\n return tl.minimum(maxi, x.to(tl.float32))\n\ndef clamp_tensor(A, mini=None, maxi=None):\n logging.debug(\"GEMS CLAMP TENSOR\")\n if mini is None and maxi is None:\n raise ValueError(\"At least one of mini or maxi must not be None\")\n elif mini is None:\n return clamp_func_max_tensor(A, maxi)\n elif maxi is None:\n return clamp_func_min_tensor(A, mini)\n else:\n return clamp_func_tensor(A, mini, maxi)\n\n@triton.jit\ndef clamp_func(x, mini, maxi):\n # Clamp each element of x between mini and maxi.\n return tl.minimum(maxi, tl.maximum(mini, x.to(tl.float32)))\n\n@triton.jit\ndef clamp_func_min(x, mini):\n # Clamp each element of x to be at least mini.\n return tl.maximum(mini, x.to(tl.float32))\n\n@triton.jit\ndef clamp_func_max(x, maxi):\n # Clamp each element of x to be at most maxi.\n return tl.minimum(maxi, x.to(tl.float32))\n\ndef clamp(A, mini=None, maxi=None):\n logging.debug(\"GEMS CLAMP\")\n if mini is None and maxi is None:\n raise ValueError(\"At least one of mini or maxi must not be None\")\n elif mini is None:\n return clamp_func_max(A, maxi)\n elif maxi is None:\n return clamp_func_min(A, mini)\n else:\n return clamp_func(A, mini, maxi)\n", - "description_1": "Use triton language to implement element-wise clamping operations. The kernel 'clamp_func_tensor' takes three arguments: a tensor x, a minimum scalar mini, and a maximum scalar maxi, and clamps each element of x between mini and maxi. The kernel 'clamp_func_min_tensor' takes two arguments: a tensor x and a minimum scalar mini, and clamps each element of x to be at least mini. The kernel 'clamp_func_max_tensor' takes two arguments: a tensor x and a maximum scalar maxi, and clamps each element of x to be at most maxi. The functions 'clamp_tensor' and 'clamp' handle the logic of selecting the appropriate kernel to use based on the provided arguments mini and maxi.", - "description_2": "Use triton language to perform tensor element-wise clamping within specified bounds or to specified minima or maxima. Implement logic to determine appropriate clamping operation based on input parameters.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit(do_not_specialize=[\"ignore_index\"])\ndef celoss_indice_kernel(\n inp_ptr,\n tgt_ptr,\n w_ptr,\n out_ptr,\n w_tgt_ptr,\n ignore_index,\n N,\n C,\n D,\n BLOCK_C: tl.constexpr,\n BLOCK_D: tl.constexpr,\n):\n pid_n = tl.program_id(0)\n pid_d = tl.program_id(1)\n offset_d = pid_d * BLOCK_D + tl.arange(0, BLOCK_D)\n\n tgt_ptrs = tgt_ptr + pid_n * D + offset_d\n tgt_mask = offset_d < D\n tgt = tl.load(tgt_ptrs, mask=tgt_mask, other=0)\n\n ignore_mask = not (tgt == ignore_index)\n\n w_ptrs = w_ptr + tgt\n w_tgt = tl.load(w_ptrs, mask=tgt_mask, other=0).to(tl.float32)\n w_tgt_ptrs = w_tgt_ptr + pid_n * D + offset_d\n tl.store(w_tgt_ptrs, w_tgt, mask=tgt_mask and ignore_mask)\n\n tmp_max = tl.zeros([BLOCK_C, BLOCK_D], dtype=tl.float32)\n tmp_sum = tl.zeros([BLOCK_C, BLOCK_D], dtype=tl.float32)\n\n for off in range(0, C, BLOCK_C):\n offset_c = off + tl.arange(0, BLOCK_C)\n inp_ptrs = inp_ptr + pid_n * C * D + offset_c[:, None] * D + offset_d[None, :]\n inp_mask = offset_c[:, None] < C and offset_d[None, :] < D\n inp = tl.load(inp_ptrs, inp_mask, other=-float(\"inf\")).to(tl.float32)\n cur_max = tl.maximum(tmp_max, inp)\n cur_exp = tl.exp(inp - cur_max)\n tmp_sum = tmp_sum * tl.exp(tmp_max - cur_max) + cur_exp\n tmp_max = cur_max\n final_max = tl.max(tmp_max, axis=0)\n tmp_sum = tmp_sum * tl.exp(tmp_max - final_max[None, :])\n final_sum = tl.log(tl.sum(tmp_sum, axis=0))\n\n inp_tgt_ptrs = inp_ptr + pid_n * C * D + tgt * D + offset_d\n inp_tgt = tl.load(inp_tgt_ptrs, mask=tgt_mask, other=-float(\"inf\")).to(tl.float32)\n\n out = (final_sum + final_max - inp_tgt) * w_tgt\n out_ptrs = out_ptr + pid_n * D + offset_d\n tl.store(out_ptrs, out, mask=tgt_mask and ignore_mask)\n\n@triton.jit(do_not_specialize=[\"label_smoothing\"])\ndef celoss_probability_kernel(\n inp_ptr,\n tgt_ptr,\n w_ptr,\n out_ptr,\n label_smoothing,\n N,\n C,\n D,\n BLOCK_C: tl.constexpr,\n BLOCK_D: tl.constexpr,\n):\n pid_n = tl.program_id(0)\n pid_d = tl.program_id(1)\n offset_d = pid_d * BLOCK_D + tl.arange(0, BLOCK_D)\n\n tmp_max = tl.zeros([BLOCK_C, BLOCK_D], dtype=tl.float32)\n tmp_sum = tl.zeros([BLOCK_C, BLOCK_D], dtype=tl.float32)\n\n for off in range(0, C, BLOCK_C):\n offset_c = off + tl.arange(0, BLOCK_C)\n inp_ptrs = inp_ptr + pid_n * C * D + offset_c[:, None] * D + offset_d[None, :]\n inp_mask = offset_c[:, None] < C and offset_d[None, :] < D\n inp = tl.load(inp_ptrs, inp_mask, other=-float(\"inf\")).to(tl.float32)\n cur_max = tl.maximum(tmp_max, inp)\n cur_exp = tl.exp(inp - cur_max)\n tmp_sum = tmp_sum * tl.exp(tmp_max - cur_max) + cur_exp\n tmp_max = cur_max\n final_max = tl.max(tmp_max, axis=0)[None, :]\n tmp_sum = tmp_sum * tl.exp(tmp_max - final_max)\n final_sum = tl.log(tl.sum(tmp_sum, axis=0))[None, :]\n\n _sum = tl.zeros([BLOCK_C, BLOCK_D], dtype=tl.float32)\n for off in range(0, C, BLOCK_C):\n offset_c = off + tl.arange(0, BLOCK_C)\n inp_ptrs = inp_ptr + pid_n * C * D + offset_c[:, None] * D + offset_d[None, :]\n tgt_ptrs = tgt_ptr + pid_n * C * D + offset_c[:, None] * D + offset_d[None, :]\n mask = offset_c[:, None] < C and offset_d[None, :] < D\n w_ptrs = w_ptr + offset_c\n w_mask = offset_c < C\n inp = tl.load(inp_ptrs, mask, other=0).to(tl.float32)\n tgt = tl.load(tgt_ptrs, mask, other=1).to(tl.float32)\n tgt = tgt * (1.0 - label_smoothing) + label_smoothing / C\n w = tl.load(w_ptrs, w_mask, other=0).to(tl.float32)[:, None]\n log = final_sum + final_max - inp\n _sum += w * log * tgt\n\n out = tl.sum(_sum, axis=0)\n out_ptrs = out_ptr + pid_n * D + offset_d\n tl.store(out_ptrs, out, mask=offset_d < D)\n\n@triton.jit(do_not_specialize=[\"ignore_index\", \"label_smoothing\"])\ndef celoss_indice_smooth_kernel(\n inp_ptr,\n tgt_ptr,\n w_ptr,\n out_ptr,\n w_tgt_ptr,\n ignore_index,\n label_smoothing,\n N,\n C,\n D,\n BLOCK_C: tl.constexpr,\n BLOCK_D: tl.constexpr,\n):\n pid_n = tl.program_id(0)\n pid_d = tl.program_id(1)\n offset_d = pid_d * BLOCK_D + tl.arange(0, BLOCK_D)\n\n tgt_ptrs = tgt_ptr + pid_n * D + offset_d\n tgt_mask = offset_d < D\n tgt = tl.load(tgt_ptrs, mask=tgt_mask, other=0)\n\n ignore_mask = not (tgt == ignore_index)\n\n w_tgt = tl.load(w_ptr + tgt, mask=tgt_mask, other=0).to(tl.float32)\n w_tgt_ptrs = w_tgt_ptr + pid_n * D + offset_d\n tl.store(w_tgt_ptrs, w_tgt, mask=tgt_mask and ignore_mask)\n\n tmp_max = tl.zeros([BLOCK_C, BLOCK_D], dtype=tl.float32)\n tmp_sum = tl.zeros([BLOCK_C, BLOCK_D], dtype=tl.float32)\n\n for off in range(0, C, BLOCK_C):\n offset_c = off + tl.arange(0, BLOCK_C)\n inp_ptrs = inp_ptr + pid_n * C * D + offset_c[:, None] * D + offset_d[None, :]\n mask = offset_c[:, None] < C and offset_d[None, :] < D\n inp = tl.load(inp_ptrs, mask, other=-float(\"inf\")).to(tl.float32)\n cur_max = tl.maximum(tmp_max, inp)\n cur_exp = tl.exp(inp - cur_max)\n tmp_sum = tmp_sum * tl.exp(tmp_max - cur_max) + cur_exp\n tmp_max = cur_max\n final_max = tl.max(tmp_max, axis=0)[None, :]\n tmp_sum = tmp_sum * tl.exp(tmp_max - final_max)\n final_sum = tl.log(tl.sum(tmp_sum, axis=0))[None, :]\n\n _sum = tl.zeros([BLOCK_C, BLOCK_D], dtype=tl.float32)\n for off in range(0, C, BLOCK_C):\n offset_c = off + tl.arange(0, BLOCK_C)\n offset = offset_c[:, None] * D + offset_d[None, :]\n inp_ptrs = inp_ptr + pid_n * C * D + offset\n mask = offset_c[:, None] < C and offset_d[None, :] < D\n inp = tl.load(inp_ptrs, mask, other=0).to(tl.float32)\n\n w_ptrs = w_ptr + offset_c\n w = tl.load(w_ptrs, offset_c < C, other=0).to(tl.float32)\n\n smooth = tl.full([BLOCK_C, BLOCK_D], label_smoothing / C, dtype=tl.float32)\n smooth = tl.where(\n offset_c[:, None] == tgt[None, :],\n 1 - label_smoothing + label_smoothing / C,\n smooth,\n )\n\n log = final_sum + final_max - inp\n _sum += log * smooth * w[:, None]\n\n out = tl.sum(_sum, axis=0)\n out_ptrs = out_ptr + pid_n * D + offset_d\n tl.store(out_ptrs, out, mask=tgt_mask and ignore_mask)\n\n@triton.jit(do_not_specialize=[\"ignore_index\", \"mean_num\"])\ndef celoss_indice_bwd(\n out_grad_ptr,\n inp_ptr,\n tgt_ptr,\n w_ptr,\n inp_grad_ptr,\n ignore_index,\n mean_num,\n N,\n C,\n D,\n BLOCK_C: tl.constexpr,\n BLOCK_D: tl.constexpr,\n):\n pid_n = tl.program_id(0)\n pid_d = tl.program_id(1)\n offset_d = pid_d * BLOCK_D + tl.arange(0, BLOCK_D)\n\n tgt_ptrs = tgt_ptr + pid_n * D + offset_d\n tgt_mask = offset_d < D\n tgt = tl.load(tgt_ptrs, mask=tgt_mask, other=0)\n out_grad_ptrs = out_grad_ptr + pid_n * D + offset_d\n out_grad = tl.load(out_grad_ptrs, mask=tgt_mask, other=0).to(tl.float32)[None, :]\n w_ptrs = w_ptr + tgt\n w_tgt = tl.load(w_ptrs, mask=tgt_mask, other=0).to(tl.float32)[None, :]\n\n ignore_mask = (tgt != ignore_index)[None, :]\n\n tmp_max = tl.zeros([BLOCK_C, BLOCK_D], dtype=tl.float32)\n tmp_sum = tl.zeros([BLOCK_C, BLOCK_D], dtype=tl.float32)\n\n for off in range(0, C, BLOCK_C):\n offset_c = off + tl.arange(0, BLOCK_C)\n inp_ptrs = inp_ptr + pid_n * C * D + offset_c[:, None] * D + offset_d[None, :]\n inp_mask = offset_c[:, None] < C and offset_d[None, :] < D\n inp = tl.load(inp_ptrs, inp_mask, other=-float(\"inf\")).to(tl.float32)\n cur_max = tl.maximum(tmp_max, inp)\n cur_exp = tl.exp(inp - cur_max)\n tmp_sum = tmp_sum * tl.exp(tmp_max - cur_max) + cur_exp\n tmp_max = cur_max\n final_max = tl.max(tmp_max, axis=0)[None, :]\n tmp_sum = tmp_sum * tl.exp(tmp_max - final_max)\n final_sum = tl.sum(tmp_sum, axis=0)[None, :]\n\n for off in range(0, C, BLOCK_C):\n offset_c = off + tl.arange(0, BLOCK_C)\n inp_ptrs = inp_ptr + pid_n * C * D + offset_c[:, None] * D + offset_d[None, :]\n inp_mask = offset_c[:, None] < C and offset_d[None, :] < D\n inp = tl.load(inp_ptrs, inp_mask, other=-float(\"inf\")).to(tl.float32)\n minus_one = offset_c[:, None] == tgt[None, :]\n inp_grad = (\n (tl.exp(inp - final_max) / final_sum - minus_one)\n * w_tgt\n * out_grad\n * mean_num\n )\n inp_grad_ptrs = (\n inp_grad_ptr + pid_n * C * D + offset_c[:, None] * D + offset_d[None, :]\n )\n tl.store(inp_grad_ptrs, inp_grad, mask=inp_mask and ignore_mask)\n\n@triton.jit(do_not_specialize=[\"label_smoothing\", \"mean_num\"])\ndef celoss_probability_bwd(\n out_grad_ptr,\n inp_ptr,\n tgt_ptr,\n w_ptr,\n inp_grad_ptr,\n label_smoothing,\n mean_num,\n N,\n C,\n D,\n BLOCK_C: tl.constexpr,\n BLOCK_D: tl.constexpr,\n):\n pid_n = tl.program_id(0)\n pid_d = tl.program_id(1)\n offset_d = pid_d * BLOCK_D + tl.arange(0, BLOCK_D)\n\n out_grad_ptrs = out_grad_ptr + pid_n * D + offset_d\n out_grad = tl.load(out_grad_ptrs, mask=offset_d < D, other=0).to(tl.float32)[\n None, :\n ]\n\n tmp_max = tl.zeros([BLOCK_C, BLOCK_D], dtype=tl.float32)\n tmp_sum = tl.zeros([BLOCK_C, BLOCK_D], dtype=tl.float32)\n w_tgt_sum = tl.zeros([BLOCK_C, BLOCK_D], dtype=tl.float32)\n\n for off in range(0, C, BLOCK_C):\n offset_c = off + tl.arange(0, BLOCK_C)\n mask = offset_c[:, None] < C and offset_d[None, :] < D\n inp_ptrs = inp_ptr + pid_n * C * D + offset_c[:, None] * D + offset_d[None, :]\n inp = tl.load(inp_ptrs, mask, other=-float(\"inf\")).to(tl.float32)\n\n tgt_ptrs = tgt_ptr + pid_n * C * D + offset_c[:, None] * D + offset_d[None, :]\n tgt = tl.load(tgt_ptrs, mask, other=0).to(tl.float32)\n tgt = tgt * (1 - label_smoothing) + label_smoothing / C\n\n w_ptrs = w_ptr + offset_c\n w_mask = offset_c < C\n w = tl.load(w_ptrs, w_mask, other=0).to(tl.float32)[:, None]\n\n w_tgt_sum += tgt * w\n\n cur_max = tl.maximum(tmp_max, inp)\n cur_exp = tl.exp(inp - cur_max)\n tmp_sum = tmp_sum * tl.exp(tmp_max - cur_max) + cur_exp\n tmp_max = cur_max\n final_max = tl.max(tmp_max, axis=0)[None, :]\n tmp_sum = tmp_sum * tl.exp(tmp_max - final_max)\n final_sum = tl.sum(tmp_sum, axis=0)[None, :]\n w_tgt_sum = tl.sum(w_tgt_sum, axis=0)[None, :]\n\n for off in range(0, C, BLOCK_C):\n offset_c = off + tl.arange(0, BLOCK_C)\n offset = pid_n * C * D + offset_c[:, None] * D + offset_d[None, :]\n inp_ptrs = inp_ptr + offset\n mask = offset_c[:, None] < C and offset_d[None, :] < D\n inp = tl.load(inp_ptrs, mask, other=0).to(tl.float32)\n\n tgt_ptrs = tgt_ptr + offset\n tgt = tl.load(tgt_ptrs, mask, other=0).to(tl.float32)\n tgt = tgt * (1 - label_smoothing) + label_smoothing / C\n\n w_ptrs = w_ptr + offset_c\n w_mask = offset_c < C\n w = tl.load(w_ptrs, w_mask, other=0).to(tl.float32)[:, None]\n\n grad = w_tgt_sum / final_sum * tl.exp(inp - final_max) - w * tgt\n inp_grad = grad * out_grad * mean_num\n\n inp_grad_ptrs = inp_grad_ptr + offset\n tl.store(inp_grad_ptrs, inp_grad, mask)\n\n@triton.jit(do_not_specialize=[\"ignore_index\", \"label_smoothing\", \"mean_num\"])\ndef celoss_indice_smooth_bwd(\n out_grad_ptr,\n inp_ptr,\n tgt_ptr,\n w_ptr,\n inp_grad_ptr,\n ignore_index,\n label_smoothing,\n mean_num,\n N,\n C,\n D,\n BLOCK_C: tl.constexpr,\n BLOCK_D: tl.constexpr,\n):\n pid_n = tl.program_id(0)\n pid_d = tl.program_id(1)\n offset_d = pid_d * BLOCK_D + tl.arange(0, BLOCK_D)\n\n tgt_ptrs = tgt_ptr + pid_n * D + offset_d\n tgt_mask = offset_d < D\n tgt = tl.load(tgt_ptrs, mask=tgt_mask, other=0)\n out_grad_ptrs = out_grad_ptr + pid_n * D + offset_d\n out_grad = tl.load(out_grad_ptrs, mask=tgt_mask, other=0).to(tl.float32)[None, :]\n\n ignore_mask = (tgt != ignore_index)[None, :]\n\n tmp_max = tl.zeros([BLOCK_C, BLOCK_D], dtype=tl.float32)\n tmp_sum = tl.zeros([BLOCK_C, BLOCK_D], dtype=tl.float32)\n w_sum = tl.zeros([BLOCK_C, BLOCK_D], dtype=tl.float32)\n\n for off in range(0, C, BLOCK_C):\n offset_c = off + tl.arange(0, BLOCK_C)\n inp_ptrs = inp_ptr + pid_n * C * D + offset_c[:, None] * D + offset_d[None, :]\n inp_mask = offset_c[:, None] < C and offset_d[None, :] < D\n inp = tl.load(inp_ptrs, inp_mask, other=-float(\"inf\")).to(tl.float32)\n\n w_ptrs = w_ptr + offset_c\n w_mask = offset_c < C\n w = tl.load(w_ptrs, w_mask, other=0).to(tl.float32)\n\n smooth = tl.full([BLOCK_C, BLOCK_D], label_smoothing / C, dtype=tl.float32)\n smooth = tl.where(\n offset_c[:, None] == tgt[None, :],\n 1 - label_smoothing + label_smoothing / C,\n smooth,\n )\n\n w_sum += smooth * w[:, None]\n\n cur_max = tl.maximum(tmp_max, inp)\n cur_exp = tl.exp(inp - cur_max)\n tmp_sum = tmp_sum * tl.exp(tmp_max - cur_max) + cur_exp\n tmp_max = cur_max\n final_max = tl.max(tmp_max, axis=0)[None, :]\n tmp_sum = tmp_sum * tl.exp(tmp_max - final_max)\n final_sum = tl.sum(tmp_sum, axis=0)[None, :]\n w_sum = tl.sum(w_sum, axis=0)[None, :]\n\n for off in range(0, C, BLOCK_C):\n offset_c = off + tl.arange(0, BLOCK_C)\n inp_ptrs = inp_ptr + pid_n * C * D + offset_c[:, None] * D + offset_d[None, :]\n inp_mask = offset_c[:, None] < C and offset_d[None, :] < D\n inp = tl.load(inp_ptrs, inp_mask, other=-float(\"inf\")).to(tl.float32)\n\n w_ptrs = w_ptr + offset_c\n w_mask = offset_c < C\n w = tl.load(w_ptrs, w_mask, other=0).to(tl.float32)\n\n smooth = tl.full([BLOCK_C, BLOCK_D], label_smoothing / C, dtype=tl.float32)\n smooth = tl.where(\n offset_c[:, None] == tgt[None, :],\n 1 - label_smoothing + label_smoothing / C,\n smooth,\n )\n\n grad = w_sum / final_sum * tl.exp(inp - final_max) - smooth * w[:, None]\n inp_grad = grad * out_grad * mean_num\n inp_grad_ptrs = (\n inp_grad_ptr + pid_n * C * D + offset_c[:, None] * D + offset_d[None, :]\n )\n tl.store(inp_grad_ptrs, inp_grad, mask=inp_mask and ignore_mask)\n\nclass CrossEntropyLoss(torch.autograd.Function):\n @staticmethod\n def forward(ctx, inp, target, weight, reduction, ignore_index, label_smoothing):\n\n shape = list(inp.shape)\n dim = inp.ndim\n N = 1 if dim == 1 else shape[0]\n C = shape[0] if dim == 1 else shape[1]\n D = inp.numel() // N // C\n axis = 0 if dim == 1 else 1\n del shape[axis]\n\n if weight is None:\n weight = torch.ones(\n [\n C,\n ],\n dtype=inp.dtype,\n device=inp.device,\n )\n\n inp = inp.contiguous()\n tgt = target.contiguous()\n weight = weight.contiguous()\n out = torch.zeros(shape, dtype=torch.float32, device=inp.device)\n grid = lambda meta: (N, triton.cdiv(D, meta[\"BLOCK_D\"]))\n\n if tgt.ndim == dim:\n with torch.cuda.device(inp.device):\n celoss_probability_kernel[grid](\n inp, tgt, weight, out, label_smoothing, N, C, D\n )\n elif label_smoothing == 0:\n w_tgt = torch.zeros(shape, dtype=torch.float32, device=inp.device)\n with torch.cuda.device(inp.device):\n celoss_indice_kernel[grid](\n inp, tgt, weight, out, w_tgt, ignore_index, N, C, D\n )\n else:\n w_tgt = torch.zeros(shape, dtype=torch.float32, device=inp.device)\n with torch.cuda.device(inp.device):\n celoss_indice_smooth_kernel[grid](\n inp, tgt, weight, out, w_tgt, ignore_index, label_smoothing, N, C, D\n )\n ctx.save_for_backward(inp, tgt, weight)\n ctx.N = N\n ctx.C = C\n ctx.D = D\n ctx.ignore_index = ignore_index\n ctx.label_smoothing = label_smoothing\n ctx.mean_num = 1\n ctx.shape = shape\n\n if reduction == 0: # NONE\n return out.to(inp.dtype)\n elif reduction == 1: # MEAN\n if tgt.ndim == dim:\n ctx.mean_num = 1 / (N * D)\n else:\n ctx.mean_num = 1 / sum(w_tgt).item()\n return (sum(out) * ctx.mean_num).to(inp.dtype)\n else: # SUM\n return sum(out).to(inp.dtype)\n\n @staticmethod\n def backward(ctx, out_grad):\n\n inp, tgt, weight = ctx.saved_tensors\n N = ctx.N\n C = ctx.C\n D = ctx.D\n ignore_index = ctx.ignore_index\n label_smoothing = ctx.label_smoothing\n mean_num = ctx.mean_num\n shape = ctx.shape\n\n out_grad = out_grad.broadcast_to(shape).contiguous()\n\n inp_grad = torch.zeros(inp.shape, dtype=inp.dtype, device=inp.device)\n grid = lambda meta: (N, triton.cdiv(D, meta[\"BLOCK_D\"]))\n if tgt.ndim == inp.ndim:\n celoss_probability_bwd[grid](\n out_grad, inp, tgt, weight, inp_grad, label_smoothing, mean_num, N, C, D\n )\n elif label_smoothing == 0:\n celoss_indice_bwd[grid](\n out_grad, inp, tgt, weight, inp_grad, ignore_index, mean_num, N, C, D\n )\n else:\n celoss_indice_smooth_bwd[grid](\n out_grad,\n inp,\n tgt,\n weight,\n inp_grad,\n ignore_index,\n label_smoothing,\n mean_num,\n N,\n C,\n D,\n )\n return inp_grad, None, None, None, None, None\n\ndef cross_entropy_loss(\n inp, target, weight=None, reduction=1, ignore_index=-100, label_smoothing=0.0\n):\n return CrossEntropyLoss.apply(\n inp, target, weight, reduction, ignore_index, label_smoothing\n )\n", - "description_1": "Use triton language to implement cross-entropy loss calculation and its backward pass for a tensor input, supporting both target indices and target probabilities with optional label smoothing and ignore index.", - "description_2": "Use triton language to define kernels that compute the cross-entropy loss and gradients for various modes (index or probability targets, with label smoothing), and integrate them into a PyTorch custom autograd function.", - "difficulty": 4 - }, - { - "code": "import math\nimport torch\nimport triton\nimport triton.language as tl\n\n@triton.jit(do_not_specialize=[\"n_elements\", \"part_num\"])\ndef scan_part_sum_kernel(\n inp, out, partial_sum, n_elements, part_num, BLOCK_SIZE: tl.constexpr\n):\n # Kernel for calculating partial sums within a block\n pid = tl.program_id(0)\n offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)\n mask = offset < n_elements\n inp_ptrs = inp + offset\n inp_vals = tl.load(inp_ptrs, mask=mask)\n if (\n tl.constexpr(inp_vals.dtype.is_int64())\n or tl.constexpr(inp_vals.dtype.is_uint64())\n ) or tl.constexpr(inp_vals.dtype.is_fp64()):\n inp_vals = inp_vals\n elif tl.constexpr(inp_vals.dtype.is_int()):\n inp_vals = inp_vals.to(tl.int32)\n else:\n inp_vals = inp_vals.to(tl.float32)\n result = tl.cumsum(inp_vals, axis=0)\n part_sum_via_sum = tl.sum(inp_vals)\n out_ptrs = out + offset\n tl.store(out_ptrs, result, mask=mask)\n partial_sum_ptrs = partial_sum + pid\n tl.store(partial_sum_ptrs, part_sum_via_sum)\n\n@triton.jit(do_not_specialize=[\"n_elements\", \"part_num\"])\ndef add_base_sum_kernel(\n out, partial_sum, n_elements, part_num, BLOCK_SIZE: tl.constexpr\n):\n # Kernel for adding base sum to each part's sum\n pid = tl.program_id(0)\n offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)\n mask = offset < n_elements\n out_ptrs = out + offset\n out_vals = tl.load(out_ptrs, mask=mask)\n if pid > 0:\n partial_sum_ptrs = partial_sum + pid - 1\n last_part_sum_via_sum = tl.load(partial_sum_ptrs)\n final_vals = out_vals + last_part_sum_via_sum\n tl.store(out_ptrs, final_vals.to(out_vals.dtype), mask=mask)\n\ndef scan_then_fan_col(inp, out, n_ele, dtype):\n BLOCK_SIZE = 1024\n if n_ele <= 1024 * 4:\n BLOCK_SIZE = triton.next_power_of_2(n_ele)\n part_num = math.ceil(n_ele / BLOCK_SIZE)\n partial_sum = torch.empty(part_num, dtype=dtype, device=inp.device)\n grid = (part_num,)\n with torch.cuda.device(inp.device):\n scan_part_sum_kernel[grid](inp, out, partial_sum, n_ele, part_num, BLOCK_SIZE)\n if part_num >= 2:\n scan_then_fan_col(partial_sum, partial_sum, part_num, dtype)\n with torch.cuda.device(inp.device):\n add_base_sum_kernel[grid](out, partial_sum, n_ele, part_num, BLOCK_SIZE)\n", - "description_1": "Use triton language to implement a cumulative sum operation on input tensors. The operation divides the input into blocks, computes partial sums within each block, and then adds base sums across blocks. It consists of two main kernels: `scan_part_sum_kernel` which computes the cumulative sums within each block and stores partial sums, and `add_base_sum_kernel` which adjusts each block's result by adding in the sum of the previous blocks. This is wrapped in a Python function `scan_then_fan_col` that handles block size determination and grid configuration.", - "description_2": "Use triton language to compute block-wise cumulative sums on input data, accumulating results across blocks.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef _int_floordiv(x, y):\n r = x % y\n c1 = r != 0\n c2 = (x < 0) ^ (y < 0)\n return tl.where(c1 & c2, x // y - 1, x // y)\n\n@triton.jit\ndef _float_floordiv(x, y):\n remainder = fmod(x, y)\n imperfect = remainder != 0.0\n different_sign = (x < 0) ^ (y < 0)\n q = div_rn(x - remainder, y)\n q = tl.where(imperfect & different_sign, q - 1, q)\n floor_q = tl.math.floor(q)\n c = q - floor_q > 0.5\n floor_q = tl.where(c, floor_q + 1.0, floor_q)\n q_is_zeros = q == 0.0\n floor_q = tl.where(q_is_zeros, tl.where(different_sign, -0.0, 0.0), floor_q)\n is_div_by_zero = y == 0.0\n float_division = x / y\n out = tl.where(is_div_by_zero, float_division, floor_q)\n return out\n\n@triton.jit\ndef _remainder(x, y):\n r = x % y\n c1 = r != 0\n c2 = (x < 0) ^ (y < 0)\n return tl.where(c1 & c2, r + y, r)\n\ndef true_divide(A, B):\n if isinstance(A, torch.Tensor) and isinstance(B, torch.Tensor):\n return true_div_func(A, B)\n elif isinstance(A, torch.Tensor):\n return true_div_func_tensor_scalar(A, B)\n elif isinstance(B, torch.Tensor):\n return true_div_func_scalar_tensor(A, B)\n else:\n return torch.tensor(A / B)\n\ndef trunc_divide(A, B):\n if isinstance(A, torch.Tensor) and isinstance(B, torch.Tensor):\n return trunc_div_func(A, B)\n elif isinstance(A, torch.Tensor):\n return trunc_div_func_tensor_scalar(A, B)\n elif isinstance(B, torch.Tensor):\n return trunc_div_func_scalar_tensor(A, B)\n else:\n return torch.tensor(A / B)\n\ndef floor_divide(A, B):\n if isinstance(A, torch.Tensor) and isinstance(B, torch.Tensor):\n return floor_div_func(A, B)\n elif isinstance(A, torch.Tensor):\n return floor_div_func_tensor_scalar(A, B)\n elif isinstance(B, torch.Tensor):\n return floor_div_func_scalar_tensor(A, B)\n else:\n return torch.tensor(A // B)\n\ndef remainder(A, B):\n if isinstance(A, torch.Tensor) and isinstance(B, torch.Tensor):\n return rem_tt(A, B)\n elif isinstance(A, torch.Tensor):\n return rem_ts(A, B)\n elif isinstance(B, torch.Tensor):\n return rem_st(A, B)\n else:\n return torch.tensor(A % B)\n", - "description_1": "Use triton language to create kernels for performing integer floor division, float floor division, and remainder operations. The operations are defined for both tensor-tensor and mixed tensor-scalar inputs. The kernels leverage Triton's parallel programming capabilities to execute efficiently on GPU.", - "description_2": "Use triton language to implement kernels for tensor division and remainder operations, handling both integer and float inputs with different combinations of tensor and scalar values.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\nfrom flag_gems.utils.random_utils import philox_cuda_seed_offset, uint_to_uniform_float\n\ndef heur_block(args):\n if args[\"N\"] <= 512:\n return 512\n else:\n return 1024\n\ndef heur_num_warps(args):\n if args[\"N\"] <= 512:\n return 4\n elif args[\"N\"] <= 1024:\n return 8\n else:\n return 16\n\n@triton.heuristics(\n {\n \"BLOCK\": heur_block,\n \"num_warps\": heur_num_warps,\n }\n)\n@triton.jit(do_not_specialize=[\"p\", \"philox_seed\", \"philox_offset\"])\ndef dropout_forward_kernel(\n X,\n Y,\n N,\n p,\n philox_seed,\n philox_offset,\n BLOCK: tl.constexpr,\n):\n UNROLL: tl.constexpr = 4 # philox generates 128 random bits at a time\n philox_seed = philox_seed.to(tl.int64)\n philox_offset = philox_offset.to(tl.int64)\n c0 = (philox_offset & 0xFFFFFFFF).to(tl.uint32)\n c1 = ((philox_offset >> 32) & 0xFFFFFFFF).to(tl.uint32)\n i4 = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)\n c0 += i4\n _O = c0 * 0\n r0, r1, r2, r3 = tl.philox(philox_seed, c0, c1, _O, _O)\n r0 = uint_to_uniform_float(r0)\n r1 = uint_to_uniform_float(r1)\n r2 = uint_to_uniform_float(r2)\n r3 = uint_to_uniform_float(r3)\n\n mask0 = r0 > p\n mask1 = r1 > p\n mask2 = r2 > p\n mask3 = r3 > p\n p = 1.0 / (1.0 - p)\n\n off_0 = tl.program_id(0) * BLOCK * UNROLL + tl.arange(0, BLOCK)\n off_1 = off_0 + BLOCK\n off_2 = off_1 + BLOCK\n off_3 = off_2 + BLOCK\n\n x0 = tl.load(X + off_0, mask=off_0 < N, other=0.0, eviction_policy=\"evict_first\")\n x1 = tl.load(X + off_1, mask=off_1 < N, other=0.0, eviction_policy=\"evict_first\")\n x2 = tl.load(X + off_2, mask=off_2 < N, other=0.0, eviction_policy=\"evict_first\")\n x3 = tl.load(X + off_3, mask=off_3 < N, other=0.0, eviction_policy=\"evict_first\")\n\n y0 = x0 * p * mask0\n y1 = x1 * p * mask1\n y2 = x2 * p * mask2\n y3 = x3 * p * mask3\n\n tl.store(Y + off_0, y0, mask=off_0 < N, eviction_policy=\"evict_first\")\n tl.store(Y + off_1, y1, mask=off_1 < N, eviction_policy=\"evict_first\")\n tl.store(Y + off_2, y2, mask=off_2 < N, eviction_policy=\"evict_first\")\n tl.store(Y + off_3, y3, mask=off_3 < N, eviction_policy=\"evict_first\")\n\n\n@triton.heuristics(\n {\n \"BLOCK\": heur_block,\n \"num_warps\": heur_num_warps,\n }\n)\n@triton.jit(do_not_specialize=[\"p\", \"philox_seed\", \"philox_offset\"])\ndef dropout_backward_kernel(\n DY,\n DX,\n N,\n p,\n philox_seed,\n philox_offset,\n BLOCK: tl.constexpr,\n):\n UNROLL = 4\n philox_seed = philox_seed.to(tl.int64)\n philox_offset = philox_offset.to(tl.int64)\n c0 = (philox_offset & 0xFFFFFFFF).to(tl.uint32)\n c1 = ((philox_offset >> 32) & 0xFFFFFFFF).to(tl.uint32)\n i4 = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)\n c0 += i4\n _O = c0 * 0\n r0, r1, r2, r3 = tl.philox(philox_seed, c0, c1, _O, _O)\n r0 = uint_to_uniform_float(r0)\n r1 = uint_to_uniform_float(r1)\n r2 = uint_to_uniform_float(r2)\n r3 = uint_to_uniform_float(r3)\n\n mask0 = r0 > p\n mask1 = r1 > p\n mask2 = r2 > p\n mask3 = r3 > p\n off_0 = tl.program_id(0) * BLOCK * UNROLL + tl.arange(0, BLOCK)\n off_1 = off_0 + BLOCK\n off_2 = off_1 + BLOCK\n off_3 = off_2 + BLOCK\n\n dy_0 = tl.load(DY + off_0, mask=off_0 < N, other=0.0, eviction_policy=\"evict_first\")\n dy_1 = tl.load(DY + off_1, mask=off_1 < N, other=0.0, eviction_policy=\"evict_first\")\n dy_2 = tl.load(DY + off_2, mask=off_2 < N, other=0.0, eviction_policy=\"evict_first\")\n dy_3 = tl.load(DY + off_3, mask=off_3 < N, other=0.0, eviction_policy=\"evict_first\")\n\n p = 1.0 / (1.0 - p)\n dx_0 = p * dy_0 * mask0\n dx_1 = p * dy_1 * mask1\n dx_2 = p * dy_2 * mask2\n dx_3 = p * dy_3 * mask3\n\n tl.store(DX + off_0, dx_0, mask=off_0 < N, eviction_policy=\"evict_first\")\n tl.store(DX + off_1, dx_1, mask=off_1 < N, eviction_policy=\"evict_first\")\n tl.store(DX + off_2, dx_2, mask=off_2 < N, eviction_policy=\"evict_first\")\n tl.store(DX + off_3, dx_3, mask=off_3 < N, eviction_policy=\"evict_first\")\n\n\nUNROLL = 4\n\nclass NativeDropout(torch.autograd.Function):\n @staticmethod\n def forward(ctx, x, p, train):\n device = x.device\n x = x.contiguous()\n out = torch.empty_like(x)\n N = x.numel()\n grid_fn = lambda meta: (triton.cdiv(N, meta[\"BLOCK\"] * UNROLL),)\n increment = triton.cdiv(N, UNROLL)\n with torch.cuda.device(device):\n philox_seed, philox_offset = philox_cuda_seed_offset(increment)\n dropout_forward_kernel[grid_fn](x, out, N, p, philox_seed, philox_offset)\n ctx.p = p\n ctx.philox_seed = philox_seed\n ctx.philox_offset = philox_offset\n return out, None\n\n @staticmethod\n def backward(ctx, grad_outputs, kwargs):\n device = grad_outputs.device\n grad_outputs = grad_outputs.contiguous()\n grad_inputs = torch.empty_like(grad_outputs)\n N = grad_outputs.numel()\n grid_fn = lambda meta: (triton.cdiv(N, meta[\"BLOCK\"] * UNROLL),)\n with torch.cuda.device(device):\n dropout_backward_kernel[grid_fn](\n grad_outputs, grad_inputs, N, ctx.p, ctx.philox_seed, ctx.philox_offset\n )\n return grad_inputs, None, None\n\n\ndef native_dropout(x, p=0.5, train=True):\n return NativeDropout.apply(x, p, train)\n", - "description_1": "Use triton language to implement dropout forward and backward kernels. The forward kernel (dropout_forward_kernel) takes 6 parameters: X (input tensor), Y (output tensor), N (number of elements), p (dropout probability), philox_seed (random seed), and philox_offset (random offset). It performs dropout by generating random masks using Philox RNG, scaling the input X, and storing the result in Y. The backward kernel (dropout_backward_kernel) takes the gradient of Y (DY), the gradient of X (DX), and similar parameters as the forward kernel to backpropagate through the dropout operation.", - "description_2": "Use triton language to create dropout kernels for forward and backward passes, using Philox RNG for mask generation and applying dropout scaling.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n\n@triton.jit\ndef embedding_kernel(\n out_ptr, # pointer to the output\n in_ptr, # pointer to the input\n weight_ptr, # pointer to the weights\n N: tl.constexpr, # number of columns in X\n BLOCK_SIZE: tl.constexpr,\n):\n pid = tl.program_id(0)\n out_ptr += pid * N\n in_ptr += pid\n\n mask = tl.arange(0, BLOCK_SIZE) < N\n cols = tl.arange(0, BLOCK_SIZE)\n\n row_idx = tl.load(in_ptr)\n weight_ptr += row_idx * N\n embedding_weight = tl.load(weight_ptr + cols, mask, other=0.0)\n tl.store(out_ptr + cols, embedding_weight, mask)\n\n\n@triton.jit\ndef indice_freq_kernel(\n indices_freq,\n indices, # pointer to the input\n elem_cnt: tl.constexpr, # number of columns in X\n INDICE_BLOCK_SIZE: tl.constexpr,\n):\n pid = tl.program_id(0)\n block_start = pid * INDICE_BLOCK_SIZE\n\n offsets = block_start + tl.arange(0, INDICE_BLOCK_SIZE)\n mask = offsets < elem_cnt\n\n index_element = tl.load(indices + offsets, mask=mask)\n tl.atomic_add(indices_freq + index_element, 1, mask=mask)\n\n\n@triton.jit(do_not_specialize=[\"padding_idx\"])\ndef embedding_backward_kernel(\n grad_in, # pointer to the gradient input\n grad_out, # pointer to the gradient output\n indices, # pointer to the input\n padding_idx, # padding_idx\n HAS_PADDING_IDX: tl.constexpr,\n N: tl.constexpr, # number of columns in X\n BLOCK_SIZE: tl.constexpr,\n):\n pid = tl.program_id(0)\n grad_out += pid * N\n indices += pid\n\n mask = tl.arange(0, BLOCK_SIZE) < N\n cols = tl.arange(0, BLOCK_SIZE)\n\n row_idx = tl.load(indices).to(tl.int32)\n if not HAS_PADDING_IDX:\n grad_in += row_idx * N\n embedding_grad = tl.load(grad_out + cols, mask, other=0.0)\n tl.atomic_add(grad_in + cols, embedding_grad, mask=mask)\n else:\n if row_idx != padding_idx:\n grad_in += row_idx * N\n embedding_grad = tl.load(grad_out + cols, mask, other=0.0)\n tl.atomic_add(grad_in + cols, embedding_grad, mask=mask)\n\n\n@triton.jit(do_not_specialize=[\"n_rows\"])\ndef embedding_grad_scale_kernel(\n grad_out,\n indice_freq,\n n_rows,\n N,\n BLOCK_SIZE: tl.constexpr,\n):\n row_start = tl.program_id(0)\n row_step = tl.num_programs(0)\n\n for row_idx in range(row_start, n_rows, row_step):\n embedding_scale = 1.0\n indice_freq_val = tl.load(indice_freq + row_idx)\n if indice_freq_val > 1:\n embedding_scale = 1.0 / indice_freq_val\n\n cols = tl.arange(0, BLOCK_SIZE)\n mask = tl.arange(0, BLOCK_SIZE) < N\n embedding_grad = tl.load(grad_out + row_idx * N + cols, mask=mask)\n scaled_embedding_grad = embedding_grad * embedding_scale\n tl.store(grad_out + row_idx * N + cols, scaled_embedding_grad, mask=mask)\n\n\nclass Embedding(torch.autograd.Function):\n @staticmethod\n def forward(\n ctx, weight, indices, padding_idx=-1, scale_grad_by_freq=False, sparse=False\n ):\n M = math.prod(indices.shape)\n N = weight.shape[-1]\n\n BLOCK_SIZE = triton.next_power_of_2(N)\n indices = indices.contiguous()\n weight = weight.contiguous()\n output = torch.empty(\n (*indices.shape, N), device=indices.device, dtype=weight.dtype\n )\n\n with torch.cuda.device(weight.device):\n embedding_kernel[M,](output, indices, weight, N, BLOCK_SIZE)\n\n ctx.M = M\n ctx.N = N\n ctx.num_weights = weight.shape[0]\n ctx.padding_idx = padding_idx\n ctx.scale_grad_by_freq = scale_grad_by_freq\n ctx.sparse = sparse\n ctx.indices = indices\n\n return output\n\n @staticmethod\n def backward(ctx, grad_outputs):\n grad_inputs = torch.zeros(\n (ctx.num_weights, grad_outputs.shape[-1]),\n device=grad_outputs.device,\n dtype=grad_outputs.dtype,\n )\n\n if ctx.scale_grad_by_freq:\n indice_freq = torch.zeros(\n (ctx.num_weights,),\n requires_grad=False,\n device=grad_outputs.device,\n dtype=torch.int32,\n )\n INDICE_BLOCK_SIZE = 256\n indice_grid = lambda meta: (triton.cdiv(ctx.M, INDICE_BLOCK_SIZE),)\n\n with torch.cuda.device(grad_outputs.device):\n indice_freq_kernel[indice_grid](\n indice_freq, ctx.indices, ctx.M, INDICE_BLOCK_SIZE\n )\n else:\n indice_freq = None\n\n BLOCK_SIZE = triton.next_power_of_2(ctx.N)\n\n HAS_PADDING_IDX = ctx.padding_idx is not None\n\n with torch.cuda.device(grad_outputs.device):\n embedding_backward_kernel[ctx.M,](\n grad_inputs,\n grad_outputs,\n ctx.indices,\n ctx.padding_idx,\n HAS_PADDING_IDX,\n ctx.N,\n BLOCK_SIZE,\n )\n\n if ctx.scale_grad_by_freq:\n with torch.cuda.device(grad_outputs.device):\n embedding_grad_scale_kernel[ctx.M,](\n grad_inputs, indice_freq, ctx.num_weights, ctx.N, BLOCK_SIZE\n )\n return grad_inputs, None, None, None, None\n\n\ndef embedding(weight, indices, padding_idx=-1, scale_grad_by_freq=False, sparse=False):\n return Embedding.apply(weight, indices, padding_idx, scale_grad_by_freq, sparse)\n", - "description_1": "Use triton language to implement kernels for embedding operations in neural networks. The kernels are designed to handle forward and backward passes for embeddings with options for scaling gradients by frequency and handling padding indices. The embedding_kernel has parameters: output pointer, input pointer, weights pointer, number of columns in X, and block size. The indice_freq_kernel calculates the frequency of indices with parameters: indices frequency pointer, input pointer, element count, and block size. The embedding_backward_kernel computes gradients with parameters: gradient input pointer, gradient output pointer, input indices pointer, padding index, flag for padding index, number of columns in X, and block size. Lastly, embedding_grad_scale_kernel scales the gradient based on frequency with parameters: gradient output pointer, indices frequency pointer, number of rows, number of columns in X, and block size.", - "description_2": "Use triton language to implement multiple kernels for forward and backward embedding operations in neural networks, supporting gradient scaling by frequency and padding index handling.", - "difficulty": 3 - }, - { - "code": "import triton\nimport triton.language as tl\nimport logging\n\n# Triton kernel for element-wise equality check\n@triton.jit\ndef eq_func(x, y):\n return x.to(tl.float32) == y.to(tl.float32)\n\n# Function to call the eq_func kernel\ndef eq(A, B):\n if A.device != B.device:\n if A.device.type == \"cuda\":\n B = B.to(A.device)\n else:\n A = A.to(B.device)\n logging.debug(\"GEMS EQ\")\n return eq_func(A, B)\n\n# Triton kernel for element-wise equality check with a scalar\n@triton.jit\ndef eq_func_scalar(x, y):\n return x.to(tl.float32) == y.to(tl.float32)\n\n# Function to call the eq_func_scalar kernel\ndef eq_scalar(A, B):\n logging.debug(\"GEMS EQ SCALAR\")\n return eq_func_scalar(A, B)\n", - "description_1": "Use triton language to implement two kernels: eq_func and eq_func_scalar. The eq_func kernel takes two tensor arguments, x and y, and returns a tensor indicating element-wise equality after converting both to float32. The eq function calls eq_func, ensuring both tensors are on the same device. The eq_func_scalar kernel also takes two arguments, x (tensor) and y (scalar), and performs a similar equality check. The eq_scalar function calls eq_func_scalar.", - "description_2": "Use triton language to implement element-wise equality check kernels for tensors and tensor-scalar pairs, ensuring device compatibility and type conversion to float32.", - "difficulty": 1 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\ndef heur_block(args):\n if args[\"N\"] <= 512:\n return 512\n else:\n return 1024\n\ndef heur_num_warps(args):\n if args[\"N\"] <= 512:\n return 4\n elif args[\"N\"] <= 1024:\n return 8\n else:\n return 16\n\n@triton.heuristics(\n {\n \"BLOCK\": heur_block,\n \"num_warps\": heur_num_warps,\n }\n)\n@triton.jit(do_not_specialize=[\"philox_seed\", \"philox_offset\", \"N\"])\ndef fused_exponential_kernel(\n out_ptr,\n N,\n is_double,\n lambd,\n eps,\n philox_seed,\n philox_offset,\n BLOCK: tl.constexpr,\n):\n philox_seed = philox_seed.to(tl.int64)\n philox_offset = philox_offset.to(tl.int64)\n c0 = (philox_offset & 0xFFFFFFFF).to(tl.uint32)\n c1 = ((philox_offset >> 32) & 0xFFFFFFFF).to(tl.uint32)\n i4 = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)\n c0 += i4\n _O = c0 * 0\n r0, r1, r2, r3 = tl.philox(philox_seed, c0, c1, _O, _O)\n if is_double:\n d0 = uint_to_uniform_float(paste_u64(r0, r2))\n d1 = uint_to_uniform_float(paste_u64(r1, r3))\n y0 = transform_exponential(d0, lambd, eps)\n y1 = transform_exponential(d1, lambd, eps)\n UNROLL = 2\n start = tl.program_id(0).to(tl.uint64) * BLOCK * UNROLL\n off_0 = start + tl.arange(0, BLOCK)\n off_1 = off_0 + BLOCK\n tl.store(out_ptr + off_0, y0, mask=off_0 < N, eviction_policy=\"evict_first\")\n tl.store(out_ptr + off_1, y1, mask=off_1 < N, eviction_policy=\"evict_first\")\n else:\n f0 = uint_to_uniform_float(r0)\n f1 = uint_to_uniform_float(r1)\n f2 = uint_to_uniform_float(r2)\n f3 = uint_to_uniform_float(r3)\n y0 = transform_exponential(f0, lambd, eps)\n y1 = transform_exponential(f1, lambd, eps)\n y2 = transform_exponential(f2, lambd, eps)\n y3 = transform_exponential(f3, lambd, eps)\n UNROLL = 4\n start = tl.program_id(0).to(tl.uint64) * BLOCK * UNROLL\n off_0 = start + tl.arange(0, BLOCK)\n off_1 = off_0 + BLOCK\n off_2 = off_1 + BLOCK\n off_3 = off_2 + BLOCK\n tl.store(out_ptr + off_0, y0, mask=off_0 < N, eviction_policy=\"evict_first\")\n tl.store(out_ptr + off_1, y1, mask=off_1 < N, eviction_policy=\"evict_first\")\n tl.store(out_ptr + off_2, y2, mask=off_2 < N, eviction_policy=\"evict_first\")\n tl.store(out_ptr + off_3, y3, mask=off_3 < N, eviction_policy=\"evict_first\")\n\n@triton.jit\ndef paste_u64(hi: tl.uint32, lo: tl.uint32):\n hi = hi.to(tl.uint64) << 32\n x = hi | lo.to(tl.uint64)\n return x\n\n@triton.jit\ndef transform_exponential(u, lambd, eps):\n eps1 = -0.5 * eps\n is_min = u >= 1.0 + eps1\n log = tl.where(is_min, eps1, tl.math.log(u))\n v = -1.0 / lambd * log\n return v\n\ndef exponential_(x, lambd: float = 1.0, *, gen=None):\n dtype = x.dtype\n device = x.device\n inplace = x.is_contiguous()\n assert dtype in (torch.float16, torch.bfloat16, torch.float32, torch.float64)\n is_double = dtype in (torch.float64,)\n UNROLL = 2 if is_double else 4\n N = x.numel()\n grid_fn = lambda meta: (triton.cdiv(N, meta[\"BLOCK\"] * UNROLL),)\n increment = triton.cdiv(N, UNROLL)\n philox_seed, philox_offset = philox_cuda_seed_offset(increment)\n eps = torch.finfo(dtype).eps\n x_ = x if inplace else torch.empty(x.size(), dtype=dtype, device=device)\n with torch.cuda.device(device):\n fused_exponential_kernel[grid_fn](\n x_, N, is_double, lambd, eps, philox_seed, philox_offset\n )\n if not inplace:\n x.copy_(x_)\n return x\n", - "description_1": "Use triton language to implement an exponential random number generator kernel (fused_exponential_kernel) with parameters: out_ptr (output pointer), N (total number of elements), is_double (boolean indicating if data type is double), lambd (rate parameter for exponential distribution), eps (machine epsilon for numerical stability), philox_seed (seed for random number generation), philox_offset (offset for random number generation), and BLOCK (block size for execution). The kernel uses helper functions paste_u64 (to combine two 32-bit integers into a 64-bit integer) and transform_exponential (to apply exponential transformation) and is invoked in the exponential_ function which manages input tensors and grid configuration.", - "description_2": "Use triton language to create an exponential distribution kernel with random number generation and block-level parallelization, incorporating transformations and storage of the generated values.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit(do_not_specialize=[\"value_scalar\"])\ndef fill_scalar_kernel(\n out_ptr,\n N,\n value_scalar,\n BLOCK_SIZE: tl.constexpr,\n):\n # Kernel to fill a tensor with a scalar value.\n pid = tl.program_id(0)\n cols = tl.arange(0, BLOCK_SIZE)\n offset = pid * BLOCK_SIZE + cols\n tl.store(out_ptr + offset, value_scalar, mask=offset < N)\n\n@triton.jit\ndef fill_tensor_kernel(\n out_ptr,\n N,\n value_ptr,\n BLOCK_SIZE: tl.constexpr,\n):\n # Kernel to fill a tensor with values from another tensor.\n pid = tl.program_id(0)\n cols = tl.arange(0, BLOCK_SIZE)\n offset = pid * BLOCK_SIZE + cols\n value_scalar = tl.load(value_ptr) # load the value from the tensor.\n tl.store(out_ptr + offset, value_scalar, mask=offset < N)\n\ndef fill_tensor(input, value):\n out = torch.empty_like(input)\n N = out.numel()\n BLOCK_SIZE = 512\n grid = triton.cdiv(N, BLOCK_SIZE)\n\n with torch.cuda.device(input.device):\n fill_tensor_kernel[grid,](out, N, value, BLOCK_SIZE)\n return out\n\ndef fill_scalar(input, value):\n out = torch.empty_like(input)\n N = out.numel()\n BLOCK_SIZE = 512\n grid = triton.cdiv(N, BLOCK_SIZE)\n\n with torch.cuda.device(input.device):\n fill_scalar_kernel[grid,](out, N, value, BLOCK_SIZE)\n return out\n", - "description_1": "Use triton language to implement two kernels: one for filling a tensor with a scalar value and the other for filling a tensor with values from another tensor. The fill_scalar_kernel takes an output pointer, the number of elements N, a scalar value, and block size as inputs. It computes the offset and fills the tensor with the scalar value. The fill_tensor_kernel takes an output pointer, the number of elements N, a pointer to the value tensor, and block size as inputs. It loads a value from the input tensor and fills the output tensor using this value. Both kernels make use of Triton's parallel computation capabilities by dividing the workload into blocks using grid-based execution.", - "description_2": "Use triton language to create a kernel that fills a tensor with a scalar value. Use triton language to create a kernel that fills a tensor using values from another tensor.", - "difficulty": 2 - }, - { - "code": "import triton\nimport torch\n\n@triton.jit\ndef copy_func(x):\n return x\n\ndef flip(A: torch.Tensor, dims) -> torch.Tensor:\n strides = list(A.stride())\n flip_dims_b = [False for _ in A.stride()]\n for dim in dims:\n assert (\n dim >= -A.dim() and dim < A.dim()\n ), \"Dimension out of range (expected to be in range of [{}, {}], but got {})\".format(\n -A.dim(), A.dim() - 1, dim\n )\n assert not flip_dims_b[\n dim\n ], \"dim {} appears multiple times in the list of dims\".format(dim)\n flip_dims_b[dim] = True\n n = 0\n offset = 0\n for i in range(len(flip_dims_b)):\n if flip_dims_b[i] and A.size(i) > 1 and A.stride(i) != 0:\n offset += strides[i] * (A.shape[i] - 1)\n strides[i] = -strides[i]\n n += 1\n if n == 0 or A.numel() <= 1:\n return A.clone()\n out = torch.empty_like(A)\n flipped_A = StridedBuffer(A, strides=strides, offset=offset)\n overload = copy_func.instantiate(A.ndim)\n overload(flipped_A, out0=out)\n return out\n", - "description_1": "Use triton language to define a kernel 'copy_func' that takes one parameter 'x' and returns it. Define a function 'flip' that takes a PyTorch tensor 'A' and a list of dimensions 'dims'. It calculates the strides and offsets to create a flipped view of 'A' and uses the 'copy_func' kernel to copy the flipped view into an output tensor.", - "description_2": "Use triton language to create a kernel that copies input data. Implement a function to flip a tensor along specified dimensions using this kernel.", - "difficulty": 2 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\nfrom flag_gems.utils.shape_utils import volume\n\n# Triton kernel to fill a tensor with a specified value\n@triton.jit(do_not_specialize=[\"fill_value\"])\ndef full_kernel(\n output_ptr, # Pointer to the output tensor\n n_elements, # Total number of elements in the tensor\n fill_value, # Value to fill the tensor with\n BLOCK_SIZE: tl.constexpr, # Size of each block\n):\n pid = tl.program_id(axis=0) # Program ID for the current block\n block_start = pid * BLOCK_SIZE # Start index for the current block\n offsets = block_start + tl.arange(0, BLOCK_SIZE) # Offsets for the current block\n mask = offsets < n_elements # Mask to ensure we don't write out of bounds\n tl.store(output_ptr + offsets, fill_value, mask=mask) # Store the fill value\n\n# Function to create a tensor filled with a specified value using the Triton kernel\ndef full(size, fill_value, *, dtype=None, layout=None, device=None, pin_memory=None):\n if dtype is None:\n dtype = torch.get_default_dtype() # Use default dtype if not specified\n if device is None:\n device = torch.device(\"cuda\") # Use CUDA device if not specified\n\n out = torch.empty(size, device=device, dtype=dtype) # Create an empty tensor\n N = volume(size) # Calculate the total number of elements\n grid_fn = lambda meta: (triton.cdiv(N, meta[\"BLOCK_SIZE\"]),) # Define grid size\n with torch.cuda.device(device):\n full_kernel[grid_fn](out, N, fill_value, BLOCK_SIZE=1024) # Launch the kernel\n return out\n", - "description_1": "Use triton language to implement a kernel 'full_kernel' that fills a tensor with a specified value. The kernel takes four parameters: a pointer to the output tensor, the total number of elements, the fill value, and the block size. The function 'full' wraps this kernel to create a tensor of a given size and fill it with the specified value, using optional parameters for data type and device.", - "description_2": "Use triton language to create a kernel that fills a tensor with a specified value and a function to wrap this kernel for tensor creation and filling.", - "difficulty": 2 - }, - { - "code": "import triton\nimport triton.language as tl\nimport torch\n\n@triton.jit\ndef _gather_jit_function(\n inp, out, index,\n inp_stride_0: int, inp_stride_1: int,\n index_stride_0: int, index_stride_1: int,\n index_shape_0: int, index_shape_1: int,\n dim, stride_dim, M, N,\n BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr\n):\n pid_x = tl.program_id(0)\n pid_y = tl.program_id(1)\n rows_offsets = pid_x * BLOCK_M + tl.arange(0, BLOCK_M)[:, None]\n cols_offsets = pid_y * BLOCK_N + tl.arange(0, BLOCK_N)[None, :]\n rows_mask = rows_offsets < M\n cols_mask = cols_offsets < N\n\n offsets = (rows_offsets * N + cols_offsets).to(tl.int64)\n mask = rows_mask & cols_mask\n\n # 1. Calculate inp_offsets and idx_offsets\n inp_offsets = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.int64)\n idx_offsets = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.int64)\n cur_idx = rows_offsets * N + cols_offsets\n\n # 2. snippets\n mod = cur_idx % index_shape_0\n inp_offsets += mod * inp_stride_0\n idx_offsets += mod * index_stride_0\n cur_idx //= index_shape_0\n\n mod = cur_idx % index_shape_1\n inp_offsets += mod * inp_stride_1\n idx_offsets += mod * index_stride_1\n\n # Use offsets to gather\n cur_index = tl.load(index + idx_offsets, mask=mask, other=0)\n inp_offsets += cur_index * stride_dim\n cur_inp = tl.load(inp + inp_offsets, mask=mask, other=0)\n tl.store(out + idx_offsets, cur_inp, mask=mask)\n\ndef gather(inp, dim, index, out=None, sparse_grad=False):\n inp = inp.contiguous()\n index = index.contiguous()\n if out is None:\n out = torch.empty_like(index, dtype=inp.dtype, device=inp.device)\n out = out.contiguous()\n stride_dim = inp.stride(dim)\n\n inp_strided = restride_dim(inp, dim, index.shape)\n N = list(index.shape)[index.ndim - 1]\n M = index.numel() // N\n\n _gather_func(inp_strided, out, index, dim, stride_dim, M, N)\n return out\n", - "description_1": "Use triton language to implement a gather kernel that takes input tensor `inp`, output tensor `out`, and index tensor `index` with given strides and shapes. It uses BLOCK_M and BLOCK_N as constexpr values to determine block sizes for processing, iterates over tensor indices to compute offsets, and gathers values from `inp` using these offsets into `out` based on `index`.", - "description_2": "Use triton language to create a gather operation on tensor inputs utilizing block processing and offsets computation.", - "difficulty": 4 - }, - { - "code": "import triton\nimport triton.language as tl\n\n# Triton kernel that compares two tensors element-wise.\n@triton.jit\ndef ge_func(x, y):\n return x.to(tl.float32) >= y\n\n# Triton kernel that compares a tensor with a scalar element-wise.\n@triton.jit\ndef ge_func_scalar(x, y):\n return x.to(tl.float32) >= y\n\ndef ge(A, B):\n return ge_func(A, B)\n\ndef ge_scalar(A, B):\n return ge_func_scalar(A, B)\n", - "description_1": "Use triton language to define two kernels: 'ge_func' and 'ge_func_scalar'. 'ge_func' takes two arguments (x, y) which are tensors and returns a tensor where each element is the result of the element-wise comparison (>=) of x and y, both cast to float32. 'ge_func_scalar' also takes two arguments (x, y) where x is a tensor and y is a scalar, performing an element-wise (>=) comparison, returning a tensor.", - "description_2": "Use triton language to define kernels for element-wise tensor and scalar comparisons using '>='.", - "difficulty": 2 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\nfrom triton.language.libdevice import erf, exp, pow, tanh\n\n@triton.jit\ndef gelu_none(x):\n scale: tl.constexpr = 0.7071067811 # 1 / math.sqrt(2)\n output = 0.5 * x * (1 + erf(x * scale))\n return output\n\n@triton.jit\ndef gelu_tanh(x):\n output = (\n 0.5 * x * (1 + tanh(x * 0.79788456 * (1 + 0.044715 * pow(x.to(tl.float32), 2))))\n )\n return output\n\n@triton.jit\ndef gelu_backward_none(x, dy):\n scale1: tl.constexpr = 0.7071067811 # 1 / math.sqrt(2)\n scale2: tl.constexpr = 0.3989422803 # 1 / math.sqrt(2 * math.pi)\n x_fp32 = x.to(tl.float32)\n dydx = (\n scale2 * x_fp32 * exp(-pow(scale1 * x_fp32, 2))\n + 0.5 * erf(scale1 * x_fp32)\n + 0.5\n )\n dx = dydx * dy\n return dx\n\n@triton.jit\ndef gelu_backward_tanh(x, dy):\n x_fp32 = x.to(tl.float32)\n # 0.79788456 = math.sqrt(2 / math.pi)\n tanh_out = tanh(0.79788456 * x * (1 + 0.044715 * pow(x_fp32, 2)))\n dydx = 0.5 * x * (\n (1 - pow(tanh_out, 2)) * (0.79788456 + 0.1070322243 * pow(x_fp32, 2))\n ) + 0.5 * (1 + tanh_out)\n dx = dydx * dy\n return dx\n\nclass Gelu(torch.autograd.Function):\n @staticmethod\n def forward(ctx, A, approximate):\n if approximate == \"tanh\":\n out = gelu_tanh(A)\n else:\n out = gelu_none(A)\n ctx.save_for_backward(A)\n ctx.approximate = approximate\n return out\n\n @staticmethod\n def backward(ctx, out_grad):\n (inp,) = ctx.saved_tensors\n approximate = ctx.approximate\n if approximate == \"tanh\":\n in_grad = gelu_backward_tanh(inp, out_grad)\n else:\n in_grad = gelu_backward_none(inp, out_grad)\n return in_grad, None\n\ndef gelu(A, *, approximate=\"none\"):\n return Gelu.apply(A, approximate)\n", - "description_1": "Use triton language to implement GELU activation and its backward pass. The kernels include gelu_none(x) for standard GELU, gelu_tanh(x) for approximate GELU using tanh, gelu_backward_none(x, dy) for the backward pass of standard GELU, and gelu_backward_tanh(x, dy) for the backward pass of approximate GELU. The Gelu class wraps these kernels for use in PyTorch's autograd system.", - "description_2": "Use triton language to implement GELU activation and its backward pass with both standard and tanh approximations.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\nfrom triton.language.extra.cuda.libdevice import rsqrt\n\n@triton.jit\ndef group_norm_kernel(\n X,\n Y,\n W,\n B,\n Mean,\n Rstd,\n group_size,\n C,\n HW,\n num_groups,\n eps,\n BLOCK_GROUP_SIZE: tl.constexpr,\n BLOCK_HW_SIZE: tl.constexpr,\n):\n pid = tl.program_id(0)\n group = pid % num_groups\n num_elements = group_size * HW\n group_offset = tl.arange(0, BLOCK_GROUP_SIZE)\n hw_offset = tl.arange(0, BLOCK_HW_SIZE)\n\n wb_offset = group * group_size + group_offset\n wb_mask = wb_offset < C\n W_ptr = W + wb_offset\n B_ptr = B + wb_offset\n\n xy_offset = pid * num_elements + group_offset[:, None] * HW + hw_offset[None, :]\n xy_mask = wb_offset[:, None] < C and hw_offset[None, :] < HW\n\n Mean_ptr = Mean + pid\n Rstd_ptr = Rstd + pid\n\n X_ptr = X + xy_offset\n Y_ptr = Y + xy_offset\n\n X_val = tl.load(X_ptr, mask=xy_mask, other=0.0).to(tl.float32)\n mean = tl.sum(X_val) / num_elements\n x = tl.where(xy_mask, X_val - mean, 0.0)\n\n var = tl.sum(x * x) / num_elements\n rstd = rsqrt(var + eps)\n x_hat = x * rstd\n\n weight = tl.load(W_ptr, mask=wb_mask, other=0.0)[:, None]\n bias = tl.load(B_ptr, mask=wb_mask, other=0.0)[:, None]\n Y_val = x_hat * weight + bias\n\n tl.store(Y_ptr, Y_val, mask=xy_mask)\n tl.store(Mean_ptr, mean)\n tl.store(Rstd_ptr, rstd)\n\n@triton.jit\ndef group_norm_backward_kernel(\n grad_y,\n X,\n W,\n Mean,\n Rstd,\n num_groups,\n group_size,\n grad_x,\n C,\n HW,\n BLOCK_GROUP_SIZE: tl.constexpr,\n BLOCK_HW_SIZE: tl.constexpr,\n):\n pid = tl.program_id(0)\n group = pid % num_groups\n num_elements = group_size * HW\n\n group_offset = tl.arange(0, BLOCK_GROUP_SIZE)\n hw_offset = tl.arange(0, BLOCK_HW_SIZE)\n wb_offset = group * group_size + group_offset\n\n wb_mask = wb_offset < C\n W_ptr = W + wb_offset\n\n xy_offset = pid * num_elements + group_offset[:, None] * HW + hw_offset[None, :]\n xy_mask = wb_offset[:, None] < C and hw_offset[None, :] < HW\n\n Mean_ptr = Mean + pid\n Rstd_ptr = Rstd + pid\n X_ptr = X + xy_offset\n dY_ptr = grad_y + xy_offset\n dX_ptr = grad_x + xy_offset\n\n rstd = tl.load(Rstd_ptr).to(tl.float32)\n mean = tl.load(Mean_ptr).to(tl.float32)\n dY_val = tl.load(dY_ptr, mask=xy_mask, other=0.0).to(tl.float32)\n X_val = tl.load(X_ptr, mask=xy_mask, other=0.0).to(tl.float32)\n weight = tl.load(W_ptr, mask=wb_mask, other=0.0).to(tl.float32)[:, None]\n\n dx_hat = weight * dY_val\n\n x = tl.where(xy_mask, X_val - mean, 0.0)\n\n grad_std = tl.sum(dx_hat * x)\n grad_var = grad_std * -(0.5 * rstd * rstd * rstd) / (HW * group_size)\n grad_distance = 2 * x * grad_var\n grad_centered_mean = dx_hat * rstd + grad_distance\n grad_mean = -tl.sum(grad_centered_mean) / num_elements\n grad_X = grad_centered_mean + grad_mean\n tl.store(dX_ptr, grad_X, mask=xy_mask)\n\n@triton.jit\ndef weight_bias_backward_kernel(\n dY,\n X,\n Mean,\n Rstd,\n dW,\n dB,\n num_groups,\n group_size,\n N,\n C,\n HW,\n BLOCK_N: tl.constexpr,\n BLOCK_HW: tl.constexpr,\n):\n pid = tl.program_id(0)\n group = pid // group_size\n n_offset = tl.arange(0, BLOCK_N)\n hw_offset = tl.arange(0, BLOCK_HW)\n xy_mask = n_offset[:, None] < N and hw_offset[None, :] < HW\n mr_mask = n_offset < N\n\n dW_ptr = dW + pid\n dB_ptr = dB + pid\n\n mean_ptr = Mean + group + n_offset * num_groups\n rstd_ptr = Rstd + group + n_offset * num_groups\n\n dY_ptr = dY + pid * HW + n_offset[:, None] * C * HW + hw_offset[None, :]\n x_ptr = X + pid * HW + n_offset[:, None] * C * HW + hw_offset[None, :]\n\n grad_y = tl.load(dY_ptr, mask=xy_mask, other=0.0).to(tl.float32)\n x = tl.load(x_ptr, mask=xy_mask, other=0.0)\n x_f32 = x.to(tl.float32)\n mean = tl.load(mean_ptr, mask=mr_mask, other=0.0).to(tl.float32)[:, None]\n rstd = tl.load(rstd_ptr, mask=mr_mask, other=0.0).to(tl.float32)[:, None]\n\n dB = tl.sum(grad_y)\n dW = tl.sum((x_f32 - mean) * rstd * grad_y)\n tl.store(dW_ptr, dW.to(x.dtype))\n tl.store(dB_ptr, dB.to(x.dtype))\n\ndef group_norm(x, weight, bias, N, C, HW, num_groups, eps):\n return GroupNorm.apply(x, weight, bias, N, C, HW, num_groups, eps)\n\n", - "description_1": "Use triton language to implement three kernels: group_norm_kernel for forward group normalization, group_norm_backward_kernel for backward pass through the group normalization, and weight_bias_backward_kernel for computing gradients of weights and biases. These kernels manage memory loads and stores, perform arithmetic operations, and ensure correct threading and block sizes.", - "description_2": "Use triton language to implement group normalization and its gradient computations in three separate kernels, optimizing memory access and compute efficiency.", - "difficulty": 4 - }, - { - "code": "import triton\nimport triton.language as tl\n\n@triton.jit\ndef gt_func_scalar(x, y):\n # Compare elements in x with the scalar y and return a boolean tensor.\n return x.to(tl.float32) > y\n\ndef gt_scalar(A, B):\n # A is a tensor, B is a scalar.\n # Logs \"GEMS GT SCALAR\" and calls the gt_func_scalar kernel to perform the comparison.\n return gt_func_scalar(A, B)\n", - "description_1": "Use triton language to define a kernel (gt_func_scalar) that compares a tensor's elements (x) with a scalar (y) and returns a tensor of booleans. The function gt_scalar logs a debug message and calls this kernel, taking A (a tensor) and B (a scalar) as inputs.", - "description_2": "Use triton language to create a kernel to compare tensor elements with a scalar. Log a message and execute the kernel.", - "difficulty": 2 - }, - { - "code": "import triton\nimport torch\n\n# Triton kernel to copy elements\n@triton.jit\ndef copy_func(x, out0):\n # x: input tensor, out0: output tensor\n return out0.store(x)\n\n# Function to horizontally stack a list of tensors\ndef hstack(tensors):\n # tensors: list of input tensors\n if len(tensors) == 0:\n raise RuntimeError(\"hstack expected a non-empty TensorList\")\n\n if tensors[0].ndim == 0:\n tensors[0] = tensors[0].view(1)\n inp0_shape = tensors[0].shape\n out_shape = list(inp0_shape)\n inp_shapes = [inp0_shape]\n\n if len(inp0_shape) == 1:\n dim = 0\n else:\n dim = 1\n\n for tensor in tensors[1:]:\n if tensor.ndim == 0:\n tensor = tensor.view(1)\n inp_shape = tensor.shape\n inp_shapes.append(inp_shape)\n\n out_shape[dim] = sum(s[dim] for s in inp_shapes)\n\n out0 = torch.empty(out_shape, dtype=tensors[0].dtype, device=tensors[0].device)\n out0_strides = out0.stride()\n out0_offsets = list(\n itertools.accumulate(\n [s[dim] * out0_strides[dim] for s in inp_shapes[:-1]], initial=0\n )\n )\n\n for a, out0_offset in zip(tensors, out0_offsets):\n copy_func(a, out0[0])\n \n return out0\n", - "description_1": "Use triton language to create a kernel `copy_func` that takes an input tensor `x` and an output tensor `out0`, and copies elements from `x` to `out0`. Additionally, implement a function `hstack` that takes a list of tensors as input, determines their output shape when horizontally stacked, and uses the `copy_func` kernel to populate an output tensor with the elements of the input tensors.", - "description_2": "Use triton language to define a kernel that copies elements from an input tensor to an output tensor. Implement a function to horizontally stack input tensors using this kernel.", - "difficulty": 2 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef index_select_kernel(\n inp, out, M, N, index, index_len, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr\n):\n # Get program ids for x and y axes\n pid_x = tl.program_id(axis=0)\n pid_y = tl.program_id(axis=1)\n \n # Calculate row and column offsets\n rows_offsets = pid_x * BLOCK_M + tl.arange(0, BLOCK_M)[:, None]\n rows_mask = rows_offsets < M\n cols_offsets = pid_y * BLOCK_N + tl.arange(0, BLOCK_N)\n cols_mask = cols_offsets < N\n\n # Compute masks for blocks and output\n block_mask = rows_mask & cols_mask\n out_mask = rows_mask & (cols_offsets < index_len)\n\n # Load indices and compute offsets\n indices = tl.load(index + cols_offsets, mask=(cols_offsets < index_len), other=0)\n inp_off = rows_offsets * N + indices[None, :]\n out_off = rows_offsets * index_len + cols_offsets[None, :]\n\n # Load selected input and store in output\n selected = tl.load(inp + inp_off, mask=block_mask, other=0.0)\n tl.store(out + out_off, selected, mask=out_mask)\n\n\ndef index_select(inp, dim, index):\n assert dim >= -inp.ndim and dim < inp.ndim, \"Invalid dim\"\n assert index.ndim <= 1, \"Index should have dimension 1 or 0\"\n assert all((i >= 0 and i < inp.size(dim)) for i in index), \"Index out of range\"\n\n # Adjust dimension and index\n if index.ndim == 0:\n index = index.unsqueeze(0)\n dim = dim % inp.ndim\n inp_shape = list(inp.shape)\n index_len = index.numel()\n\n # Compress input along the dimension\n inp = dim_compress(inp, dim)\n N = inp_shape[dim]\n M = inp.numel() // N\n out_shape = list(inp.shape)\n out_shape[inp.ndim - 1] = index_len\n out = torch.empty(out_shape, dtype=inp.dtype, device=inp.device)\n\n # Define grid based on blocks\n grid = lambda meta: (\n triton.cdiv(M, meta[\"BLOCK_M\"]),\n triton.cdiv(index_len, meta[\"BLOCK_N\"]),\n )\n \n # Call the kernel with calculated grid\n index_select_kernel[grid](inp, out, M, N, index, index_len)\n \n # Adjust output order if necessary\n if dim != out.ndim - 1:\n order = [i for i in range(out.ndim - 1)]\n order.insert(dim, out.ndim - 1)\n return out.permute(order)\n else:\n return out\n", - "description_1": "Use triton language to implement an index select kernel. The kernel (index_select_kernel) takes 8 parameters: input tensor (inp), output tensor (out), number of rows (M), number of columns (N), index tensor, index length, and block dimensions (BLOCK_M, BLOCK_N) as compile-time constants. It uses triton's parallel execution to load specified indices from the input and store them in the output tensor. The associated function (index_select) manages the input's dimension adjustment, calculates grid size based on input and index, and invokes the triton kernel.", - "description_2": "Use triton language to efficiently load indexed rows from a large tensor into an output tensor using parallel processing, employing kernel invocation based on computed grid dimensions.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\nimport logging\n\ntry:\n from triton.language.extra.cuda.libdevice import isfinited as _isfinited\nexcept ImportError:\n try:\n from triton.language.math import isfinited as _isfinited\n except ImportError:\n from triton.language.libdevice import isfinited as _isfinited\n\ntry:\n from triton.language.extra.cuda.libdevice import finitef as _finitef\nexcept ImportError:\n try:\n from triton.language.math import finitef as _finitef\n except ImportError:\n from triton.language.libdevice import finitef as _finitef\n\n\n@triton.jit\ndef isclose_func(\n x,\n y,\n rtol,\n atol,\n equal_nan: tl.constexpr,\n zero_tol: tl.constexpr,\n):\n cast_x = x if x.dtype.is_fp64() else x.to(tl.float32)\n cast_y = y if x.dtype.is_fp64() else y.to(tl.float32)\n if x.dtype.is_bf16():\n close = cast_x == cast_y\n else:\n close = x == y\n if equal_nan:\n close |= (cast_x != cast_x) & (cast_y != cast_y)\n if not zero_tol:\n allowed = atol + tl.abs(rtol * cast_y)\n actual = tl.abs(cast_x - cast_y)\n actual_finite = _isfinited(actual) if x.dtype.is_fp64() else _finitef(actual)\n close |= actual_finite.to(tl.int1) & (actual <= allowed)\n return close\n\n\ndef isclose(\n A: torch.Tensor,\n B: torch.Tensor,\n rtol=1e-05,\n atol=1e-08,\n equal_nan: bool = False,\n) -> torch.Tensor:\n logging.debug(\"GEMS ISCLOSE\")\n if A.dtype == torch.bool:\n return A == B\n if A.dtype != B.dtype:\n raise RuntimeError(\"{} did not match {}\".format(A.dtype, B.dtype))\n if A.is_quantized or B.is_quantized:\n raise RuntimeError(\"isclose is not supported for quantized inputs.\")\n if rtol < 0:\n raise RuntimeError(\n \"rtol must be greater than or equal to zero, but got {}\".format(rtol)\n )\n if atol < 0:\n raise RuntimeError(\n \"atol must be greater than or equal to zero, but got {}\".format(atol)\n )\n zero_tol = (rtol == 0) and (atol == 0)\n return isclose_func(A, B, rtol, atol, equal_nan, zero_tol)\n\n\ndef allclose(\n A: torch.Tensor,\n B: torch.Tensor,\n rtol=1e-05,\n atol=1e-08,\n equal_nan: bool = False,\n) -> bool:\n logging.debug(\"GEMS ALLCLOSE\")\n return all(isclose(A, B, rtol, atol, equal_nan)).item()\n", - "description_1": "Use triton language to implement a kernel function 'isclose_func' that checks element-wise closeness of two tensors 'x' and 'y' with relative tolerance 'rtol', absolute tolerance 'atol', and options 'equal_nan' and 'zero_tol'. The function handles different data types and uses Triton's math functions for finite checks. The 'isclose' function wraps this kernel for PyTorch tensors, ensuring type compatibility and handling special cases. The 'allclose' function checks if all elements are close using 'isclose'.", - "description_2": "Use triton language to create a kernel for element-wise tensor closeness check with tolerance parameters and special case handling. Wrap this kernel for PyTorch tensors and provide a function to check if all elements are close.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef isfinite_func(x):\n # Check if elements are finite\n return _isfinited(x) if x.dtype.is_fp64() else _finitef(x.to(tl.float32))\n\ndef isfinite(\n A: torch.Tensor,\n) -> torch.Tensor:\n # Determine if elements of A are finite\n if A.is_floating_point():\n return isfinite_func(A)\n else:\n return torch.full(A.shape, True, dtype=torch.bool, device=A.device)\n", - "description_1": "Use triton language to define a kernel 'isfinite_func' that checks if elements of a tensor are finite. The kernel takes one parameter: 'x', a tensor. The function 'isfinite' calls this kernel and takes one parameter: 'A', a torch tensor, and returns a tensor indicating if each element is finite.", - "description_2": "Use triton language to create a kernel that checks for finite elements in a tensor and a function to apply this kernel to a torch tensor.", - "difficulty": 2 - }, - { - "code": "import math\nimport torch\nimport triton\nimport triton.language as tl\n\ndef launch_arg(BLOCK_M, BLOCK_N, N, num_warps):\n return BLOCK_M, min(BLOCK_N, triton.next_power_of_2(N)), num_warps\n\n@triton.jit\ndef isin_by_comparation_impl(\n global_pid,\n in0_ravel_ptr: tl.tensor,\n in1_ravel_ptr: tl.tensor, # in\n out_ptr: tl.tensor, # out\n M: int, # num_tasks\n N: int, # num_tasks_1\n BLOCK_M: tl.constexpr, # tile_size\n BLOCK_N: tl.constexpr, # tile_size_1\n invert: tl.constexpr,\n):\n row_off = global_pid * BLOCK_M\n rows = row_off + tl.arange(0, BLOCK_M)[:, None]\n row_mask = rows < M\n out_ptr += rows\n in0_ravel_ptr += rows + tl.zeros([BLOCK_N], dtype=tl.int32)\n in1_ravel_ptr += tl.zeros([BLOCK_M], dtype=tl.int32)[:, None]\n\n block = tl.full([BLOCK_M, BLOCK_N], value=(1 if invert else 0), dtype=tl.int1)\n in0 = tl.load(in0_ravel_ptr, row_mask, other=0)\n for col_off in range(0, N, BLOCK_N):\n cols = col_off + tl.arange(0, BLOCK_N)[None, :]\n col_mask = cols < N\n mask = row_mask and col_mask\n in1 = tl.load(in1_ravel_ptr + cols, mask, other=0)\n block = tl.where(\n mask,\n tl.where(invert, block and (in0 != in1), block or (in0 == in1)),\n invert,\n )\n out = tl.reduce(block, axis=1, combine_fn=(reduce_all if invert else reduce_any))\n tl.store(out_ptr, out[:, None], row_mask)\n\n@triton.jit\ndef isin_by_search_impl(\n global_pid,\n in0_ravel_ptr: tl.tensor,\n in1_sorted_ptr: tl.tensor, # in\n out_ptr: tl.tensor, # out\n M: int, # num_tasks\n N: int, # num_tasks_1\n log_n: tl.constexpr,\n BLOCK_M: tl.constexpr, # tile_size\n invert: tl.constexpr,\n):\n r = tl.arange(0, BLOCK_M)\n i0 = global_pid * BLOCK_M + r\n mask = i0 < M\n\n # load in0_ravel\n in0_ravel = tl.load(in0_ravel_ptr + i0, mask=mask)\n\n # binary search: lower_bound\n out = tl.zeros_like(r).to(tl.int1)\n start = tl.zeros_like(r)\n end = start + N\n while_mask = start < end\n for i in range(log_n):\n mid = tl.where(while_mask, start + (end - start) // 2, 0)\n mid_val = tl.load(in1_sorted_ptr + mid, mask=while_mask)\n out = tl.where(while_mask, out or (mid_val == in0_ravel), out) # found\n start = tl.where(while_mask and (mid_val < in0_ravel), mid + 1, start)\n end = tl.where(while_mask and (mid_val > in0_ravel), mid, end)\n while_mask = start < end\n\n # store out\n tl.store(out_ptr + i0, not out if invert else out, mask=mask)\n\ndef isin_by_comparation(\n in0: torch.tensor,\n in1: torch.tensor,\n invert: bool,\n):\n in0_ravel = in0.contiguous().ravel()\n in1_ravel = in1.contiguous().ravel()\n M = in0.numel()\n N = in1.numel()\n if M <= 1024:\n BLOCK_M, BLOCK_N, num_warps = launch_arg(1, 256, N, 4)\n elif M <= 3072:\n BLOCK_M, BLOCK_N, num_warps = launch_arg(2, 256, N, 4)\n elif M <= 6144:\n BLOCK_M, BLOCK_N, num_warps = launch_arg(4, 128, N, 4)\n elif M <= 9216:\n BLOCK_M, BLOCK_N, num_warps = launch_arg(4, 256, N, 8)\n else:\n BLOCK_M, BLOCK_N, num_warps = launch_arg(4, 128, N, 4)\n ctas_num = min(65536, triton.cdiv(M, BLOCK_M))\n tiles_per_cta = triton.cdiv(M, BLOCK_M * ctas_num)\n grid = (ctas_num,)\n out = torch.empty_like(in0_ravel, dtype=torch.bool)\n with torch.cuda.device(in0_ravel.device.index):\n isin_by_comparation_kernel[grid](\n in0_ravel,\n in1_ravel, # in\n out, # out\n M,\n N,\n BLOCK_M,\n BLOCK_N,\n tiles_per_cta=tiles_per_cta,\n invert=invert,\n num_warps=num_warps,\n )\n return out.view_as(in0)\n\ndef isin_by_search(\n in0: torch.tensor,\n in1: torch.tensor,\n invert: bool,\n unique_in0: bool,\n unique_in1: bool,\n):\n # unique or sort or ravel\n if unique_in0:\n in0_ravel, unique_order, _ = _unique2(\n in0, sorted=True, return_inverse=True, return_counts=False\n )\n else:\n in0_ravel = in0.contiguous().ravel()\n if unique_in1:\n in1_ravel, _, _ = _unique2(\n in1, sorted=True, return_inverse=False, return_counts=False\n )\n else:\n in1_ravel, _ = torch.sort(in1.ravel())\n # launch kernel func\n M = in0_ravel.numel()\n N = in1_ravel.numel()\n if M <= 1048576: # 2 ** 20 = 1024 * 1024\n _, BLOCK_M, num_warps = launch_arg(None, 512, M, 8)\n elif M <= 4194304: # 2 ** 22 = 1024 * 4096\n _, BLOCK_M, num_warps = launch_arg(None, 1024, M, 8)\n elif M <= 8388608: # 2 ** 23 = 1024 * 8192\n _, BLOCK_M, num_warps = launch_arg(None, 2048, M, 16)\n elif M <= 268435456: # 2 ** 28 = 1024 * 262144\n _, BLOCK_M, num_warps = launch_arg(None, 4096, M, 32)\n else:\n _, BLOCK_M, num_warps = launch_arg(None, 2048, M, 16)\n log_n = int(math.log2(N)) + 1\n ctas_num = min(65536, triton.cdiv(M, BLOCK_M))\n tiles_per_cta = triton.cdiv(M, BLOCK_M * ctas_num)\n grid = (ctas_num,)\n out = torch.empty_like(in0_ravel, dtype=torch.bool)\n with torch.cuda.device(in0_ravel.device.index):\n isin_by_search_kernel[grid](\n in0_ravel,\n in1_ravel, # in\n out, # out\n M,\n N,\n log_n,\n BLOCK_M,\n tiles_per_cta=tiles_per_cta,\n invert=invert,\n num_warps=num_warps,\n )\n if unique_in0:\n out = torch.gather(out, 0, unique_order.ravel().to(torch.int64))\n return out.view_as(in0)\n", - "description_1": "Use triton language to implement two kernels: 'isin_by_comparation_impl' and 'isin_by_search_impl'. The first kernel checks if elements of a tensor are in another tensor using a comparison method, while the second uses a binary search method. Both kernels take pointers to input tensors, output tensor, task sizes, block sizes, and an invert flag. The 'isin_by_comparation_impl' kernel iterates over blocks of elements, comparing them and storing results, while 'isin_by_search_impl' performs a binary search to find elements. The kernels are called by 'isin_by_comparation' and 'isin_by_search' functions, which prepare the input data, set up grid and block sizes, and launch the kernels.", - "description_2": "Use triton language to implement kernels for checking tensor membership using comparison and binary search methods, with functions to prepare data and launch these kernels.", - "difficulty": 4 - }, - { - "code": "import logging\nimport triton\nimport triton.language as tl\n\ntry:\n from triton.language.extra.cuda.libdevice import isinf as _isinf\nexcept ImportError:\n try:\n from triton.language.math import isinf as _isinf\n except ImportError:\n from triton.language.libdevice import isinf as _isinf\n\n@triton.jit\ndef isinf_func(x):\n # Kernel to check if elements are infinite\n return _isinf(x.to(tl.float32))\n\ndef isinf(A):\n # Wrapper function to call the isinf_func kernel\n logging.debug(\"GEMS ISINF\")\n return isinf_func(A)\n", - "description_1": "Use triton language to define a kernel 'isinf_func' that checks if elements in a tensor are infinite. The kernel takes one parameter 'x', which is a tensor. The function 'isinf' is a wrapper that calls 'isinf_func' with one parameter 'A', which is the input tensor.", - "description_2": "Use triton language to create a kernel that checks for infinite values in a tensor and a wrapper function to call this kernel.", - "difficulty": 1 - }, - { - "code": "import triton\nimport triton.language as tl\n\n@triton.jit\ndef isnan_func(x):\n return _isnan(x.to(tl.float32))\n\ndef isnan(A):\n return isnan_func(A)\n", - "description_1": "Use triton language to create a kernel 'isnan_func' which takes one argument, a tensor 'x', and checks if the values in 'x' are NaN by converting them to float32 using Triton's intrinsic isnan function. The function 'isnan' serves as a wrapper that takes one argument 'A' and calls 'isnan_func' with 'A'.", - "description_2": "Use triton language to create a function to check if tensor elements are NaN by converting them to float32 and calling Triton's isnan.", - "difficulty": 1 - }, - { - "code": "import triton\nimport triton.language as tl\nimport torch\n\n@triton.jit\ndef prev_multiple_of(a, b):\n return tl.cdiv(a, b) * b - b\n\n\n@triton.jit(do_not_specialize=[\"eps\"])\ndef layer_norm_persistent_kernel(\n in_ptr, out_ptr, weight_ptr, bias_ptr, out_mean_ptr, out_rstd_ptr, M, N, eps, TILE_N: tl.constexpr\n):\n pid = tl.program_id(0)\n n_offsets = tl.arange(0, TILE_N)\n mask = n_offsets < N\n\n x = tl.load(in_ptr + pid * N + n_offsets, mask, other=0.0).to(tl.float32)\n m = tl.sum(x) / N\n d = x - m\n s = tl.where(mask, d * d, 0)\n sum_square = tl.sum(s)\n var = sum_square / N\n rstd = tl.math.rsqrt(var + eps)\n\n tl.store(out_mean_ptr + pid, m)\n tl.store(out_rstd_ptr + pid, rstd)\n\n w = tl.load(weight_ptr + n_offsets, mask=mask)\n b = tl.load(bias_ptr + n_offsets, mask=mask)\n out = (x - m) * rstd * w + b\n\n tl.store(out_ptr + pid * N + n_offsets, out, mask=mask)\n\n\n@triton.jit(do_not_specialize=[\"eps\"])\ndef layer_norm_persistent_kernel_multiline(\n in_ptr, out_ptr, weight_ptr, bias_ptr, out_mean_ptr, out_rstd_ptr, M, N, eps, TILE_M: tl.constexpr, TILE_N: tl.constexpr\n):\n pid = tl.program_id(0)\n m_offsets = pid * TILE_M + tl.arange(0, TILE_M)\n m_mask = m_offsets < M\n\n n_offsets = tl.arange(0, TILE_N)[None, :]\n n_mask = n_offsets < N\n mask = m_mask[:, None] & n_mask\n\n x = tl.load(in_ptr + m_offsets[:, None] * N + n_offsets, mask, other=0.0).to(tl.float32)\n m = tl.sum(x, axis=1) / N\n d = x - m[:, None]\n s = tl.where(mask, d * d, 0)\n sum_square = tl.sum(s, axis=1)\n var = sum_square / N\n rstd = tl.math.rsqrt(var + eps)\n\n tl.store(out_mean_ptr + m_offsets, m, mask=m_mask)\n tl.store(out_rstd_ptr + m_offsets, rstd, mask=m_mask)\n\n w = tl.load(weight_ptr + n_offsets, mask=n_mask)\n b = tl.load(bias_ptr + n_offsets, mask=n_mask)\n out = (x - m[:, None]) * rstd[:, None] * w + b\n\n tl.store(out_ptr + m_offsets[:, None] * N + n_offsets, out, mask=mask)\n\n\n@triton.jit(do_not_specialize=[\"eps\"])\ndef layer_norm_loop_kernel(\n in_ptr, out_ptr, weight_ptr, bias_ptr, out_mean_ptr, out_rstd_ptr, M, N, eps, TILE_N: tl.constexpr\n):\n pid = tl.program_id(0)\n m = tl.zeros((TILE_N,), dtype=tl.float32)\n s = tl.zeros((TILE_N,), dtype=tl.float32)\n cnt = tl.zeros((TILE_N,), dtype=tl.int32)\n num_steps = tl.cdiv(N, TILE_N)\n for step in range(0, num_steps - 1, 1):\n start_n = step * TILE_N\n n_offsets = start_n + tl.arange(0, TILE_N)\n x = tl.load(in_ptr + pid * N + n_offsets).to(tl.float32)\n new_m = m + (x - m) / (step + 1)\n new_s = s + (x - new_m) * (x - m)\n cnt += 1\n m = new_m\n s = new_s\n\n for step in range(num_steps - 1, num_steps, 1):\n start_n = step * TILE_N\n n_offsets = start_n + tl.arange(0, TILE_N)\n mask = n_offsets < N\n x = tl.load(in_ptr + pid * N + n_offsets, mask=mask).to(tl.float32)\n new_m = tl.where(mask, m + (x - m) / (step + 1), m)\n new_s = tl.where(mask, s + (x - new_m) * (x - m), s)\n cnt += mask.to(tl.int32)\n m = new_m\n s = new_s\n\n final_m = tl.sum(m * cnt) / N\n var = tl.sum(s + cnt * (m - final_m) * (m - final_m)) / N\n rstd = tl.math.rsqrt(var + eps)\n m = final_m\n\n tl.store(out_mean_ptr + pid, m)\n tl.store(out_rstd_ptr + pid, rstd)\n\n prev_multiple = prev_multiple_of(N, TILE_N)\n for start_n in range(0, TILE_N, TILE_N):\n n_offsets = (prev_multiple - start_n) + tl.arange(0, TILE_N)\n mask = n_offsets < N\n x = tl.load(in_ptr + pid * N + n_offsets, mask=mask, other=0.0, eviction_policy=\"evict_first\").to(tl.float32)\n w = tl.load(weight_ptr + n_offsets, mask=mask)\n b = tl.load(bias_ptr + n_offsets, mask=mask)\n out = w * (x - m) * rstd + b\n tl.store(out_ptr + pid * N + n_offsets, out, mask=mask)\n\n for start_n in range(TILE_N, N, TILE_N):\n n_offsets = (prev_multiple - start_n) + tl.arange(0, TILE_N)\n x = tl.load(in_ptr + pid * N + n_offsets, eviction_policy=\"evict_first\").to(tl.float32)\n w = tl.load(weight_ptr + n_offsets)\n b = tl.load(bias_ptr + n_offsets)\n out = w * (x - m) * rstd + b\n tl.store(out_ptr + pid * N + n_offsets, out)\n\n\n@triton.jit\ndef layer_norm_backward_kernel(\n dY, X, W, Mean, Rstd, dX, M, N, BLOCK_ROW_SIZE: tl.constexpr, BLOCK_COL_SIZE: tl.constexpr\n):\n pid = tl.program_id(0) * BLOCK_ROW_SIZE + tl.arange(0, BLOCK_ROW_SIZE)[:, None]\n row_mask = pid < M\n dY += pid * N\n X += pid * N\n dX += pid * N\n Mean += pid\n Rstd += pid\n\n mean = tl.load(Mean).to(tl.float32)\n rstd = tl.load(Rstd).to(tl.float32)\n\n dx_part2 = tl.zeros([BLOCK_ROW_SIZE, BLOCK_COL_SIZE], dtype=tl.float32)\n dx_part3 = tl.zeros([BLOCK_ROW_SIZE, BLOCK_COL_SIZE], dtype=tl.float32)\n\n for off in range(0, N, BLOCK_COL_SIZE):\n cols = off + tl.arange(0, BLOCK_COL_SIZE)\n col_mask = cols[None, :] < N\n mask = row_mask and col_mask\n dy = tl.load(dY + cols[None, :], mask).to(tl.float32)\n x = tl.load(X + cols[None, :], mask).to(tl.float32)\n x = tl.where(mask, x - mean, 0.0)\n x_hat = x * rstd\n w = tl.load(W + cols, mask=cols < N).to(tl.float32)\n dx_hat = dy * w\n dx_part2 += dx_hat\n dx_part3 += dx_hat * x_hat\n\n dx_2 = tl.sum(dx_part2, axis=1)[:, None]\n dx_3 = tl.sum(dx_part3, axis=1)[:, None]\n\n for off in range(0, N, BLOCK_COL_SIZE):\n cols = off + tl.arange(0, BLOCK_COL_SIZE)\n col_mask = cols[None, :] < N\n mask = row_mask and col_mask\n dy = tl.load(dY + cols[None, :], mask).to(tl.float32)\n x = tl.load(X + cols[None, :], mask).to(tl.float32)\n w = tl.load(W + cols, mask=cols < N).to(tl.float32)\n x = tl.where(mask, x - mean, 0.0)\n x_hat = x * rstd\n dx_hat = dy * w\n dx = rstd * (dx_hat - (dx_2 + x_hat * dx_3) / N)\n tl.store(dX + cols, dx, mask=mask)\n\n\n@triton.jit\ndef weight_bias_backward_kernel(\n dY, X, Mean, Rstd, dW, dB, M, N, BLOCK_ROW_SIZE: tl.constexpr, BLOCK_COL_SIZE: tl.constexpr\n):\n pid = tl.program_id(0) * BLOCK_COL_SIZE + tl.arange(0, BLOCK_COL_SIZE)[None, :]\n col_mask = pid < N\n dY += pid\n X += pid\n dW += pid\n dB += pid\n accW = tl.zeros([BLOCK_ROW_SIZE, BLOCK_COL_SIZE], dtype=tl.float32)\n accB = tl.zeros([BLOCK_ROW_SIZE, BLOCK_COL_SIZE], dtype=tl.float32)\n for off in range(0, M, BLOCK_ROW_SIZE):\n rows = off + tl.arange(0, BLOCK_ROW_SIZE)\n row_mask = rows[:, None] < M\n mask = row_mask and col_mask\n dy = tl.load(dY + rows[:, None] * N, mask).to(tl.float32)\n x = tl.load(X + rows[:, None] * N, mask).to(tl.float32)\n mean = tl.load(Mean + rows, mask=rows < M)[:, None].to(tl.float32)\n rstd = tl.load(Rstd + rows, mask=rows < M)[:, None].to(tl.float32)\n x = tl.where(col_mask, x - mean, 0.0)\n x_hat = x * rstd\n accW += dy * x_hat\n accB += dy\n dw = tl.sum(accW, axis=0)\n db = tl.sum(accB, axis=0)\n tl.store(dW, dw[None, :], mask=col_mask)\n tl.store(dB, db[None, :], mask=col_mask)\n\n\nclass LayerNorm(torch.autograd.Function):\n @staticmethod\n def forward(ctx, x, normalized_shape, weight, bias, eps=1e-5, cudnn_enable=True):\n N = math.prod(normalized_shape)\n M = x.numel() // N\n\n x = x.contiguous()\n weight = weight.contiguous()\n bias = bias.contiguous()\n y = torch.empty_like(x)\n\n acc_type = get_accumulator_dtype(x.dtype)\n mean = torch.empty(M, dtype=acc_type, device=x.device)\n rstd = torch.empty(M, dtype=acc_type, device=x.device)\n\n with torch.cuda.device(x.device):\n if N <= 128:\n TILE_N = triton.next_power_of_2(N)\n TILE_M = triton.cdiv(1024, TILE_N)\n grid = (triton.cdiv(M, TILE_M), 1, 1)\n layer_norm_persistent_kernel_multiline[grid](\n x, y, weight, bias, mean, rstd, M, N, eps, TILE_M, TILE_N\n )\n elif N <= 4096:\n TILE_N = triton.next_power_of_2(N)\n grid = (M, 1, 1)\n layer_norm_persistent_kernel[grid](x, y, weight, bias, mean, rstd, M, N, eps, TILE_N)\n else:\n grid = (M, 1, 1)\n layer_norm_loop_kernel[grid](x, y, weight, bias, mean, rstd, M, N, eps)\n ctx.save_for_backward(x, weight, mean, rstd)\n ctx.M = M\n ctx.N = N\n return y, mean, rstd\n\n @staticmethod\n def backward(ctx, out_grad, mean_grad, rstd_grad):\n out_grad = out_grad.contiguous()\n (x, weight, mean, rstd) = ctx.saved_tensors\n M = ctx.M\n N = ctx.N\n\n with torch.cuda.device(x.device):\n in_grad = torch.empty_like(x)\n grid = lambda meta: (triton.cdiv(M, meta[\"BLOCK_ROW_SIZE\"]), 1, 1)\n layer_norm_backward_kernel[grid](out_grad, x, weight, mean, rstd, in_grad, M, N)\n\n grid = lambda meta: (triton.cdiv(N, meta[\"BLOCK_COL_SIZE\"]), 1, 1)\n weight_grad = torch.empty_like(weight)\n bias_grad = torch.empty_like(weight)\n weight_bias_backward_kernel[grid](out_grad, x, mean, rstd, weight_grad, bias_grad, M, N)\n return in_grad, None, weight_grad, bias_grad, None, None\n\n\ndef layer_norm(x, normalized_shape, weight, bias, eps=1e-5, cudnn_enable=True):\n return LayerNorm.apply(x, normalized_shape, weight, bias, eps, cudnn_enable)\n", - "description_1": "Use triton language to implement various layer normalization kernels and their backward passes, optimizing for different problem sizes. The kernels accept pointers to input and output data, weight, bias, mean and reciprocal of standard deviation, along with the dimensions M and N, a small constant epsilon, and tile sizes as constexpr. They perform computations necessary for layer normalization and store results appropriately, handling different cases where N is less than or equal to 128, 4096, or more, using different kernel strategies.", - "description_2": "Use triton language to implement layer normalization and its backward pass with various optimized kernels, depending on the size of the normalization dimension.", - "difficulty": 3 - }, - { - "code": "import triton\nimport triton.language as tl\n\n@triton.jit\ndef le_func(x, y):\n return x.to(tl.float32) <= y\n\ndef le(A, B):\n return le_func(A, B)\n\n@triton.jit\ndef le_func_scalar(x, y):\n return x.to(tl.float32) <= y\n\ndef le_scalar(A, B):\n return le_func_scalar(A, B)\n", - "description_1": "Use triton language to define two kernel functions, `le_func` and `le_func_scalar`. Both functions accept two parameters `x` and `y`. The `le_func` checks if elements in `x`, converted to float32, are less than or equal to elements in `y`. The `le_func_scalar` performs the same operation assuming `y` is a scalar value. Both kernels are wrapped in functions `le` and `le_scalar`, respectively, which pass the input arguments `A` and `B` directly to these kernels.", - "description_2": "Use triton language to implement kernels for element-wise and scalar comparison using <= operator.", - "difficulty": 1 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\ndef heur_block_n(args):\n return triton.next_power_of_2(args[\"N\"])\n\ndef heur_num_warps(args):\n if args[\"N\"] <= 1024:\n return 4\n elif args[\"N\"] <= 2048:\n return 8\n else:\n return 16\n\n@triton.jit\ndef log_softmax_kernel(\n output_ptr,\n input_ptr,\n M,\n N,\n K,\n BLOCK_M: tl.constexpr,\n BLOCK_N: tl.constexpr,\n):\n # Triton kernel implementation for log softmax\n pid_m = tl.program_id(0)\n pid_k = tl.program_id(1)\n m_offset = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)\n n_offset = tl.arange(0, BLOCK_N)\n offset = m_offset[:, None] * N * K + n_offset[None, :] * K + pid_k\n mask = m_offset[:, None] < M and n_offset[None, :] < N\n input_ptrs = input_ptr + offset\n inp = tl.load(input_ptrs, mask=mask, other=-float(\"inf\")).to(tl.float32)\n row_minus_max = inp - tl.max(inp, axis=1)[:, None]\n numerator = tl.exp(row_minus_max)\n denominator = tl.sum(numerator, axis=1)[:, None]\n softmax_output = tl.log(numerator / denominator)\n output_ptrs = output_ptr + offset\n tl.store(output_ptrs, softmax_output, mask=mask)\n\n@triton.jit\ndef log_softmax_backward_kernel(\n out_ptr,\n out_grad_ptr,\n in_grad_ptr,\n M,\n N,\n K,\n BLOCK_M: tl.constexpr,\n BLOCK_N: tl.constexpr,\n):\n # Triton kernel implementation for log softmax backward\n pid_m = tl.program_id(0)\n pid_k = tl.program_id(1)\n m_offset = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)\n n_offset = tl.arange(0, BLOCK_N)\n\n offsets = m_offset[:, None] * N * K + n_offset[None, :] * K + pid_k\n mask = m_offset[:, None] < M and n_offset[None, :] < N\n out_ptrs = out_ptr + offsets\n out = tl.load(out_ptrs, mask=mask).to(tl.float32)\n out_grad_ptrs = out_grad_ptr + offsets\n out_grad = tl.load(out_grad_ptrs, mask=mask).to(tl.float32)\n\n scale = tl.sum(out_grad, 1)\n in_grad = out_grad - tl.exp(out.to(tl.float32)) * scale[:, None]\n\n in_grad_ptrs = in_grad_ptr + offsets\n tl.store(in_grad_ptrs, in_grad, mask=mask)\n\nclass LogSoftmax(torch.autograd.Function):\n @staticmethod\n def forward(ctx, x, dim, dtype):\n # Forward method for log softmax\n assert dim >= -x.ndim and dim < x.ndim, \"Invalid dim\"\n dim = dim % x.ndim\n M = 1\n N = x.shape[dim]\n for i in range(dim):\n M *= x.shape[i]\n inp = x.contiguous()\n if dtype is None:\n dtype = x.dtype\n out = torch.empty_like(inp, dtype=dtype)\n K = inp.numel() // M // N\n\n grid = lambda meta: (\n triton.cdiv(M, meta[\"BLOCK_M\"]),\n K,\n )\n with torch.cuda.device(inp.device):\n log_softmax_kernel[grid](\n out,\n inp,\n M,\n N,\n K,\n )\n ctx.save_for_backward(out)\n ctx.dim = dim\n return out\n\n @staticmethod\n def backward(ctx, out_grad):\n # Backward method for log softmax\n dim = ctx.dim\n (out,) = ctx.saved_tensors\n\n assert dim >= -out.ndim and dim < out.ndim, \"Invalid dim\"\n dim = dim % out.ndim\n M = 1\n N = out.shape[dim]\n for i in range(dim):\n M *= out.shape[i]\n\n out_grad = out_grad.contiguous()\n in_grad = torch.empty_like(out)\n K = out.numel() // M // N\n\n grid = lambda meta: (\n triton.cdiv(M, meta[\"BLOCK_M\"]),\n K,\n )\n with torch.cuda.device(in_grad.device):\n log_softmax_backward_kernel[grid](\n out,\n out_grad,\n in_grad,\n M,\n N,\n K,\n )\n return in_grad, None, None\n\ndef log_softmax(x, dim=-1, dtype=None):\n return LogSoftmax.apply(x, dim, dtype)\n", - "description_1": "Use triton language to implement a log softmax and its backward pass for tensors. The kernel function 'log_softmax_kernel' takes 7 arguments: output_ptr, input_ptr, M, N, K, BLOCK_M, BLOCK_N. It calculates the log softmax of the input tensor. The backward kernel function 'log_softmax_backward_kernel' also takes 7 arguments: out_ptr, out_grad_ptr, in_grad_ptr, M, N, K, BLOCK_M, BLOCK_N. It computes the gradient of the log softmax operation.", - "description_2": "Use triton language to implement log softmax operations and their gradients for efficient parallel computation on tensors.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\ndef cfggen():\n block_m = [1, 2, 4]\n block_n = [1024, 2048, 4096]\n warps = [4, 8, 16]\n configs = [\n triton.Config({\"BLOCK_ROW_SIZE\": m, \"BLOCK_COL_SIZE\": n}, num_warps=w)\n for m in block_m\n for n in block_n\n for w in warps\n ]\n return configs\n\n@triton.jit\ndef masked_fill_kernel(\n inp, expand_mask, value, out, M, N, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr\n):\n pid_x = tl.program_id(axis=0)\n pid_y = tl.program_id(axis=1)\n rows_offset = pid_x * BLOCK_M + tl.arange(0, BLOCK_M)[:, None]\n cols_offset = pid_y * BLOCK_N + tl.arange(0, BLOCK_N)[None, :]\n mask = rows_offset < M and cols_offset < N\n\n offsets = rows_offset * N + cols_offset\n fill_mask = tl.load(expand_mask + offsets, mask=mask, other=0).to(tl.int1)\n cur_inp = tl.load(inp + offsets, mask=(not fill_mask) and mask, other=0)\n tl.store(out + offsets, cur_inp, (not fill_mask) and mask)\n\n cur_val = tl.full((BLOCK_M, BLOCK_N), value, dtype=cur_inp.dtype)\n tl.store(out + offsets, cur_val, fill_mask and mask)\n\ndef masked_fill(inp, mask, value):\n assert (\n isinstance(value, float)\n or isinstance(value, int)\n or (torch.is_tensor(value) and value.ndim == 0)\n ), \"masked_fill_ only supports a Number or a 0-dimensional value tensor\"\n if torch.is_tensor(value):\n value = value.item()\n inp_shape = tuple(inp.shape)\n mask_shape = tuple(mask.shape)\n assert broadcastable_to(\n mask_shape, inp_shape\n ), \"The shape of mask must be broadcastable with the shape of the underlying tensor\"\n\n inp = inp.contiguous()\n mask = mask.contiguous()\n value = value.contiguous()\n expand_mask = mask.expand(inp.shape)\n out = torch.empty_like(inp, dtype=inp.dtype, device=inp.device)\n\n N = inp.size(inp.ndim - 1)\n M = inp.numel() // N\n grid = lambda meta: (\n triton.cdiv(M, meta[\"BLOCK_M\"]),\n triton.cdiv(N, meta[\"BLOCK_N\"]),\n )\n masked_fill_kernel[grid](inp, expand_mask.to(torch.int), value, out, M, N)\n return out\n", - "description_1": "Use triton language to implement a masked fill operation. The kernel 'masked_fill_kernel' takes 7 parameters: 'inp' (input tensor), 'expand_mask' (expanded mask tensor), 'value' (value to fill), 'out' (output tensor), 'M' (number of rows), 'N' (number of columns), and two constexpr parameters 'BLOCK_M' and 'BLOCK_N' for block sizes. The function 'masked_fill' is a wrapper that prepares the input, mask, and value, calculates grid dimensions, and launches the kernel.", - "description_2": "Use triton language to create a kernel for masked filling of a tensor, and a wrapper function to handle input preparation and kernel execution.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\ndef cfggen():\n configs = [\n triton.Config({\"BLOCK_SIZE\": bs}, num_warps=w)\n for w in [4, 8, 16, 32]\n for bs in [256, 512, 1024, 2048, 4096]\n ]\n return configs\n\n@triton.jit\ndef masked_select_kernel(\n inp_ptr,\n select_mask_ptr,\n prefix_sum_ptr,\n out_ptr,\n n_elements,\n BLOCK_SIZE: tl.constexpr,\n):\n pid = tl.program_id(axis=0)\n offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)\n mask = offsets < n_elements\n\n inp = tl.load(inp_ptr + offsets, mask=mask, other=0.0)\n select_mask = tl.load(select_mask_ptr + offsets, mask=mask, other=0.0).to(tl.int1)\n out_offset = tl.load(prefix_sum_ptr + offsets, mask=mask, other=0.0) - 1\n\n tl.store(out_ptr + out_offset, inp, mask=(select_mask and mask))\n\ndef masked_select(inp, mask):\n inp_shape = tuple(inp.shape)\n mask_shape = tuple(mask.shape)\n\n assert broadcastable(\n inp_shape, mask_shape\n ), \"The shapes of the `mask` and the `input` tensor must be broadcastable\"\n inp, mask = torch.broadcast_tensors(inp, mask)\n\n inp = inp.contiguous()\n mask = mask.contiguous()\n\n mask_flattened = mask.ravel()\n\n prefix_sum = mask_flattened.cumsum(axis=0)\n out = torch.empty(prefix_sum[-1].item(), dtype=inp.dtype, device=inp.device)\n\n n_elements = inp.numel()\n grid = lambda meta: (triton.cdiv(n_elements, meta[\"BLOCK_SIZE\"]),)\n with torch.cuda.device(inp.device):\n masked_select_kernel[grid](inp, mask_flattened, prefix_sum, out, n_elements)\n return out\n", - "description_1": "Use triton language to implement a masked select operation. The kernel 'masked_select_kernel' takes 6 parameters: inp_ptr (input tensor pointer), select_mask_ptr (mask tensor pointer), prefix_sum_ptr (prefix sum of mask), out_ptr (output tensor pointer), n_elements (number of elements), and BLOCK_SIZE (block size for parallel execution). The function 'masked_select' prepares the input and mask tensors, computes the prefix sum of the mask, and calls the kernel with appropriate grid size.", - "description_2": "Use triton language to create a kernel for masked selection of elements from an input tensor based on a mask, and implement a function to prepare data and invoke this kernel.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef max_kernel_1(\n inp,\n mid,\n M,\n BLOCK_SIZE: tl.constexpr,\n):\n # Compute the program's ID and offset\n pid = tl.program_id(0)\n offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)\n # Load input elements and apply the mask\n inp_ptrs = inp + offset\n mask = offset < M\n inp_val = tl.load(inp_ptrs, mask=mask, other=-float(\"inf\"))\n # Compute the maximum value\n max_val = tl.max(inp_val)\n # Store the result\n mid_ptr = mid + pid\n tl.store(mid_ptr, max_val)\n\n@triton.jit\ndef max_kernel_2(mid, out, mid_size, BLOCK_MID: tl.constexpr):\n # Calculate offset\n offset = tl.arange(0, BLOCK_MID)\n # Load intermediate values and apply the mask\n mid_ptrs = mid + offset\n mask = offset < mid_size\n mid_val = tl.load(mid_ptrs, mask=mask, other=-float(\"inf\"))\n # Compute the maximum value\n max_val = tl.max(mid_val)\n # Store the result\n tl.store(out, max_val)\n\ndef max(inp):\n M = inp.numel()\n block_size = triton.next_power_of_2(math.ceil(math.sqrt(M)))\n mid_size = triton.cdiv(M, block_size)\n block_mid = triton.next_power_of_2(mid_size)\n\n dtype = inp.dtype\n mid = torch.empty((mid_size,), dtype=dtype, device=inp.device)\n out = torch.empty([], dtype=dtype, device=inp.device)\n\n with torch.cuda.device(inp.device):\n max_kernel_1[(mid_size, 1, 1)](inp, mid, M, block_size)\n max_kernel_2[(1, 1, 1)](mid, out, mid_size, block_mid)\n return out\n\n@triton.jit\ndef max_kernel(\n inp,\n out_value,\n out_index,\n M,\n N,\n K,\n BLOCK_M: tl.constexpr,\n BLOCK_N: tl.constexpr,\n):\n # set offset\n pid_m = tl.program_id(0)\n pid_k = tl.program_id(1)\n m_offset = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)\n n_offset = tl.arange(0, BLOCK_N)\n offset = m_offset[:, None] * N * K + n_offset[None, :] * K + pid_k\n offset_index = m_offset * K + pid_k\n # set mask\n mask1 = m_offset < M\n mask = m_offset[:, None] < M and n_offset[None, :] < N\n inp_ptrs = inp + offset\n inp_vals = tl.load(inp_ptrs, mask=mask, other=-float(\"inf\"))\n result_value, result_index = tl.max(inp_vals, axis=1, return_indices=True)\n\n out_value_ptrs = out_value + offset_index\n out_index_ptrs = out_index + offset_index\n\n tl.store(out_value_ptrs, result_value, mask=mask1)\n tl.store(out_index_ptrs, result_index, mask=mask1)\n\ndef max_dim(inp, dim=None, keepdim=False):\n shape = inp.shape\n dim = dim % inp.ndim\n N = shape[dim]\n M = math.prod(shape[:dim])\n K = inp.numel() // M // N\n\n inp = inp.contiguous()\n\n shape_list = list(shape)\n shape_list[dim] = 1\n out_value = torch.empty(shape_list, dtype=inp.dtype, device=inp.device)\n out_index = torch.empty(shape_list, dtype=torch.int64, device=inp.device)\n\n if not keepdim:\n out_value = torch.squeeze(out_value, dim)\n out_index = torch.squeeze(out_index, dim)\n\n grid = lambda meta: (\n triton.cdiv(M, meta[\"BLOCK_M\"]),\n K,\n )\n with torch.cuda.device(inp.device):\n max_kernel[grid](inp, out_value, out_index, M, N, K)\n Max_out = namedtuple(\"max\", [\"values\", \"indices\"])\n out = Max_out(values=out_value, indices=out_index)\n return out\n", - "description_1": "Use triton language to implement three kernels: `max_kernel_1` calculates the maximum value for a given block of an input array, with parameters for the input array, output array, array size, and block size; `max_kernel_2` calculates the maximum from an intermediate array, with parameters for the intermediate array, output variable, intermediate size, and block size; `max_kernel` finds maximum values and indices in a 2D matrix with parameters for input array, output values and indices, and dimensions M, N, K with block sizes.", - "description_2": "Use triton language to create kernels for computing max values over segments of an input array, and for finding maximums and indices in matrices, by implementing GPU parallel processing.", - "difficulty": 3 - }, - { - "code": "import triton\nimport triton.language as tl\nimport logging\n\n@triton.jit\ndef maximum_kernel(X, Y):\n # Check if input tensors have bfloat16 data type, and cast them to float32 if needed\n if X.dtype == tl.bfloat16:\n X = X.to(tl.float32)\n Y = Y.to(tl.float32)\n\n # Element-wise maximum operation on tensors X and Y\n return tl.maximum(X, Y)\n\ndef maximum(X, Y):\n logging.debug(\"GEMS MAXIMUM\")\n # Assert that both input tensors are on CUDA device\n assert X.is_cuda and Y.is_cuda\n # Call the triton kernel for maximum operation\n return maximum_kernel(X, Y)\n", - "description_1": "Use triton language to implement an element-wise maximum operation between two tensors X and Y. If either tensor has data type bfloat16, convert them to float32 before performing the maximum operation.", - "description_2": "Use triton language to perform element-wise maximum between two tensors, with optional type conversion from bfloat16 to float32.", - "difficulty": 3 - }, - { - "code": "import logging\nimport math\n\nimport torch\nimport triton\nimport triton.language as tl\n\n\n@triton.jit\ndef mean_kernel_1(\n inp,\n mid,\n M,\n BLOCK_SIZE: tl.constexpr,\n):\n pid = tl.program_id(0)\n offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)\n inp_ptrs = inp + offset\n mask = offset < M\n inp_val = tl.load(inp_ptrs, mask=mask, other=0.0)\n sum_val = tl.sum(inp_val, axis=0)\n mid_ptr = mid + pid\n tl.store(mid_ptr, sum_val)\n\n\n@triton.jit\ndef mean_kernel_2(mid, out, M, MID_SIZE, BLOCK_MID: tl.constexpr):\n offset = tl.arange(0, BLOCK_MID)\n mid_ptrs = mid + offset\n mask = offset < MID_SIZE\n mid_val = tl.load(mid_ptrs, mask=mask, other=0.0)\n sum_val = tl.sum(mid_val, axis=0) / M\n tl.store(out, sum_val)\n\n\ndef mean(inp, *, dtype=None):\n logging.debug(\"GEMS MEAN\")\n M = inp.numel()\n if dtype is None:\n dtype = inp.dtype\n block_size = triton.next_power_of_2(math.ceil(math.sqrt(M)))\n mid_size = triton.cdiv(M, block_size)\n block_mid = triton.next_power_of_2(mid_size)\n\n mid = torch.empty((mid_size,), dtype=dtype, device=inp.device)\n out = torch.empty([], dtype=dtype, device=inp.device)\n\n with torch.cuda.device(inp.device):\n mean_kernel_1[(mid_size, 1, 1)](inp, mid, M, block_size)\n mean_kernel_2[(1, 1, 1)](mid, out, M, mid_size, block_mid)\n return out\n\n\n@triton.jit\ndef mean_dim_kernel(X, Mean, M, N, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):\n # Map the program id to the row of X it should compute.\n pid = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)[:, None]\n X = X + pid * N\n Mean = Mean + pid\n row_mask = pid < M\n\n # Compute mean\n _mean = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)\n for off in range(0, N, BLOCK_N):\n cols = off + tl.arange(0, BLOCK_N)[None, :]\n col_mask = cols < N\n mask = row_mask and col_mask\n\n a = tl.load(X + cols, mask, other=0.0).to(tl.float32)\n _mean += a\n mean = tl.sum(_mean, axis=1) / N\n mean = mean[:, None]\n tl.store(Mean, mean, row_mask)\n\n\ndef mean_dim(x, dim, keepdim=False, *, dtype=None):\n logging.debug(\"GEMS MEAN DIM\")\n\n if dtype is None:\n dtype = x.dtype\n if dim is None:\n out = mean(x, dtype=dtype)\n if not keepdim:\n out = out.reshape([1] * x.ndim)\n return out\n\n shape = list(x.shape)\n dim = [d % x.ndim for d in dim]\n x = dim_compress(x, dim)\n N = 1\n for i in dim:\n N *= shape[i]\n shape[i] = 1\n M = x.numel() // N\n out = torch.empty(shape, dtype=dtype, device=x.device)\n grid = lambda META: (triton.cdiv(M, META[\"BLOCK_M\"]),)\n\n with torch.cuda.device(x.device):\n mean_dim_kernel[grid](x, out, M, N)\n if not keepdim:\n out = out.squeeze(dim)\n return out\n", - "description_1": "Use triton language to define two kernels, mean_kernel_1 and mean_kernel_2. The first one (mean_kernel_1) takes four parameters: inp (input tensor), mid (intermediate storage tensor), M (number of elements in input tensor), and BLOCK_SIZE (block size). It calculates the sum of elements in blocks. The second kernel (mean_kernel_2) takes five parameters: mid (intermediate storage tensor), out (output tensor), M (number of elements in input tensor), MID_SIZE (size of intermediate storage), and BLOCK_MID (block size for intermediate storage). It computes the mean of the elements stored in mid and stores the result in out.", - "description_2": "Use triton language to define a kernel, mean_dim_kernel, which computes the mean across a specific dimension. The kernel takes six parameters: X (input tensor), Mean (output tensor to store the mean results), M (number of rows to process), N (number of columns to process), BLOCK_M (block size for rows), and BLOCK_N (block size for columns). It computes the mean value for each row in a block and stores the result in the Mean tensor.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\nimport math\nfrom collections import namedtuple\n\n@triton.jit\ndef min_kernel_1(\n inp,\n mid,\n M,\n BLOCK_SIZE: tl.constexpr,\n):\n pid = tl.program_id(0)\n offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)\n inp_ptrs = inp + offset\n mask = offset < M\n inp_val = tl.load(inp_ptrs, mask=mask, other=float(\"inf\"))\n min_val = tl.min(inp_val)\n mid_ptr = mid + pid\n tl.store(mid_ptr, min_val)\n\n\n@triton.jit\ndef min_kernel_2(mid, out, mid_size, BLOCK_MID: tl.constexpr):\n offset = tl.arange(0, BLOCK_MID)\n mid_ptrs = mid + offset\n mask = offset < mid_size\n mid_val = tl.load(mid_ptrs, mask=mask, other=float(\"inf\"))\n min_val = tl.min(mid_val)\n tl.store(out, min_val)\n\n\n@triton.jit\ndef min_kernel(\n inp,\n out_value,\n out_index,\n M,\n N,\n K,\n BLOCK_M: tl.constexpr,\n BLOCK_N: tl.constexpr,\n):\n # set offset\n pid_m = tl.program_id(0)\n pid_k = tl.program_id(1)\n m_offset = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)\n n_offset = tl.arange(0, BLOCK_N)\n offset = m_offset[:, None] * N * K + n_offset[None, :] * K + pid_k\n offset_index = m_offset * K + pid_k\n # set mask\n mask1 = m_offset < M\n mask = m_offset[:, None] < M and n_offset[None, :] < N\n inp_ptrs = inp + offset\n inp_vals = tl.load(inp_ptrs, mask=mask, other=float(\"inf\")).to(tl.float32)\n result_value, result_index = tl.min(inp_vals, axis=1, return_indices=True)\n\n out_value_ptrs = out_value + offset_index\n out_index_ptrs = out_index + offset_index\n\n tl.store(out_value_ptrs, result_value, mask=mask1)\n tl.store(out_index_ptrs, result_index, mask=mask1)\n\n\ndef min(inp):\n M = inp.numel()\n block_size = triton.next_power_of_2(math.ceil(math.sqrt(M)))\n mid_size = triton.cdiv(M, block_size)\n block_mid = triton.next_power_of_2(mid_size)\n\n dtype = inp.dtype\n mid = torch.empty((mid_size,), dtype=dtype, device=inp.device)\n out = torch.empty([], dtype=dtype, device=inp.device)\n\n with torch.cuda.device(inp.device):\n min_kernel_1[(mid_size, 1, 1)](inp, mid, M, block_size)\n min_kernel_2[(1, 1, 1)](mid, out, mid_size, block_mid)\n return out\n\n\ndef min_dim(inp, dim=None, keepdim=False):\n shape = inp.shape\n dim = dim % inp.ndim\n N = shape[dim]\n M = math.prod(shape[:dim])\n K = inp.numel() // M // N\n\n inp = inp.contiguous()\n\n shape_list = list(shape)\n shape_list[dim] = 1\n out_value = torch.empty(shape_list, dtype=inp.dtype, device=inp.device)\n out_index = torch.empty(shape_list, dtype=torch.int64, device=inp.device)\n\n if not keepdim:\n out_value = torch.squeeze(out_value, dim)\n out_index = torch.squeeze(out_index, dim)\n\n grid = lambda meta: (\n triton.cdiv(M, meta[\"BLOCK_M\"]),\n K,\n )\n with torch.cuda.device(inp.device):\n min_kernel[grid](inp, out_value, out_index, M, N, K)\n Min_out = namedtuple(\"min\", [\"values\", \"indices\"])\n out = Min_out(values=out_value, indices=out_index)\n return out\n", - "description_1": "Use triton language to implement three kernels: min_kernel_1 to compute the minimum of blocks in an input tensor and store them in a midpoint tensor; min_kernel_2 to compute the minimum value from the midpoint tensor and store it in an output tensor; and min_kernel to compute the minimum values and indices along a specified dimension of a 2D input tensor. Helper functions are included to configure and launch the kernels.", - "description_2": "Use triton language to implement minimum computation kernels for both 1D reduction and 2D dimension-wise reduction.", - "difficulty": 4 - }, - { - "code": "import triton\nimport triton.language as tl\nimport logging\n\n@triton.jit\ndef minimum_kernel(X, Y):\n # Convert inputs to float32 if they are of bfloat16 type\n if X.dtype == tl.bfloat16:\n X = X.to(tl.float32)\n Y = Y.to(tl.float32)\n # Return the element-wise minimum of X and Y\n return tl.minimum(X, Y)\n\n\ndef minimum(X, Y):\n logging.debug(\"GEMS MINIMUM\")\n # Ensure inputs are CUDA tensors before invoking the Triton kernel\n assert X.is_cuda and Y.is_cuda\n return minimum_kernel(X, Y)\n", - "description_1": "Use triton language to create a kernel that computes the element-wise minimum of two input tensors X and Y. If the tensors are of type bfloat16, convert them to float32 before the computation. Ensure that the input tensors are on CUDA devices before invoking the kernel.", - "description_2": "Use triton language to implement a minimum operation for CUDA tensors, ensuring proper type conversion.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\ndef heur_even_k(args):\n return args[\"K\"] % (args[\"BLOCK_K\"] * args[\"SPLIT_K\"]) == 0\n\n@triton.autotune(\n configs=[\n triton.Config(\n {\"BLOCK_M\": 128, \"BLOCK_N\": 256, \"BLOCK_K\": 32, \"SPLIT_K\": 1},\n num_stages=3,\n num_warps=8,\n ),\n triton.Config(\n {\"BLOCK_M\": 256, \"BLOCK_N\": 128, \"BLOCK_K\": 32, \"SPLIT_K\": 1},\n num_stages=3,\n num_warps=8,\n ),\n triton.Config(\n {\"BLOCK_M\": 256, \"BLOCK_N\": 64, \"BLOCK_K\": 32, \"SPLIT_K\": 1},\n num_stages=4,\n num_warps=4,\n ),\n triton.Config(\n {\"BLOCK_M\": 64, \"BLOCK_N\": 256, \"BLOCK_K\": 32, \"SPLIT_K\": 1},\n num_stages=4,\n num_warps=4,\n ),\n triton.Config(\n {\"BLOCK_M\": 128, \"BLOCK_N\": 128, \"BLOCK_K\": 32, \"SPLIT_K\": 1},\n num_stages=4,\n num_warps=4,\n ),\n triton.Config(\n {\"BLOCK_M\": 128, \"BLOCK_N\": 64, \"BLOCK_K\": 32, \"SPLIT_K\": 1},\n num_stages=4,\n num_warps=4,\n ),\n triton.Config(\n {\"BLOCK_M\": 64, \"BLOCK_N\": 128, \"BLOCK_K\": 32, \"SPLIT_K\": 1},\n num_stages=4,\n num_warps=4,\n ),\n triton.Config(\n {\"BLOCK_M\": 128, \"BLOCK_N\": 32, \"BLOCK_K\": 32, \"SPLIT_K\": 1},\n num_stages=4,\n num_warps=4,\n ),\n triton.Config(\n {\"BLOCK_M\": 64, \"BLOCK_N\": 32, \"BLOCK_K\": 32, \"SPLIT_K\": 1},\n num_stages=5,\n num_warps=2,\n ),\n triton.Config(\n {\"BLOCK_M\": 128, \"BLOCK_N\": 256, \"BLOCK_K\": 128, \"SPLIT_K\": 1},\n num_stages=3,\n num_warps=8,\n ),\n triton.Config(\n {\"BLOCK_M\": 256, \"BLOCK_N\": 128, \"BLOCK_K\": 128, \"SPLIT_K\": 1},\n num_stages=3,\n num_warps=8,\n ),\n triton.Config(\n {\"BLOCK_M\": 256, \"BLOCK_N\": 64, \"BLOCK_K\": 128, \"SPLIT_K\": 1},\n num_stages=4,\n num_warps=4,\n ),\n triton.Config(\n {\"BLOCK_M\": 64, \"BLOCK_N\": 256, \"BLOCK_K\": 128, \"SPLIT_K\": 1},\n num_stages=4,\n num_warps=4,\n ),\n triton.Config(\n {\"BLOCK_M\": 128, \"BLOCK_N\": 128, \"BLOCK_K\": 128, \"SPLIT_K\": 1},\n num_stages=4,\n num_warps=4,\n ),\n triton.Config(\n {\"BLOCK_M\": 128, \"BLOCK_N\": 64, \"BLOCK_K\": 64, \"SPLIT_K\": 1},\n num_stages=4,\n num_warps=4,\n ),\n triton.Config(\n {\"BLOCK_M\": 64, \"BLOCK_N\": 128, \"BLOCK_K\": 64, \"SPLIT_K\": 1},\n num_stages=4,\n num_warps=4,\n ),\n triton.Config(\n {\"BLOCK_M\": 128, \"BLOCK_N\": 32, \"BLOCK_K\": 64, \"SPLIT_K\": 1},\n num_stages=4,\n num_warps=4,\n ),\n triton.Config(\n {\"BLOCK_M\": 64, \"BLOCK_N\": 32, \"BLOCK_K\": 64, \"SPLIT_K\": 1},\n num_stages=5,\n num_warps=2,\n ),\n ],\n key=[\"M\", \"N\", \"K\"],\n)\n@triton.heuristics(\n {\n \"EVEN_K\": heur_even_k,\n }\n)\n@triton.jit\ndef mm_kernel(\n A,\n B,\n C,\n M,\n N,\n K,\n stride_am,\n stride_ak,\n stride_bk,\n stride_bn,\n stride_cm,\n stride_cn,\n dot_out_dtype: tl.constexpr,\n BLOCK_M: tl.constexpr,\n BLOCK_N: tl.constexpr,\n BLOCK_K: tl.constexpr,\n GROUP_M: tl.constexpr,\n SPLIT_K: tl.constexpr,\n EVEN_K: tl.constexpr,\n):\n # matrix multiplication\n pid = tl.program_id(0)\n pid_z = tl.program_id(1)\n grid_m = tl.cdiv(M, BLOCK_M)\n grid_n = tl.cdiv(N, BLOCK_N)\n # re-order program ID for better L2 performance\n width = GROUP_M * grid_n\n group_id = pid // width\n group_size = min(grid_m - group_id * GROUP_M, GROUP_M)\n pid_m = group_id * GROUP_M + (pid % group_size)\n pid_n = (pid % width) // (group_size)\n # do matrix multiplication\n rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)\n rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)\n ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)\n rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)\n rk = pid_z * BLOCK_K + tl.arange(0, BLOCK_K)\n # pointers\n A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)\n B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)\n acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=dot_out_dtype)\n for k in range(0, tl.cdiv(K, BLOCK_K * SPLIT_K)):\n if EVEN_K:\n a = tl.load(A)\n b = tl.load(B)\n else:\n k_remaining = K - k * (BLOCK_K * SPLIT_K)\n _0 = tl.zeros((1, 1), dtype=C.dtype.element_ty)\n a = tl.load(A, mask=rk[None, :] < k_remaining, other=_0)\n b = tl.load(B, mask=rk[:, None] < k_remaining, other=_0)\n if a.dtype != b.dtype:\n a = a.to(C.dtype.element_ty)\n b = b.to(C.dtype.element_ty)\n acc += tl.dot(a, b, out_dtype=dot_out_dtype, allow_tf32=False)\n A += BLOCK_K * SPLIT_K * stride_ak\n B += BLOCK_K * SPLIT_K * stride_bk\n acc = acc.to(C.dtype.element_ty)\n # rematerialize rm and rn to save registers\n rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)\n rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)\n C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)\n mask = (rm < M)[:, None] & (rn < N)[None, :]\n # handles write-back with reduction-splitting\n if SPLIT_K == 1:\n tl.store(C, acc, mask=mask)\n else:\n tl.atomic_add(C, acc, mask=mask)\n\n\n_ordered_datatypes = [torch.float16, torch.bfloat16, torch.float32]\n\n\ndef get_higher_dtype(a, b):\n if a is b:\n return a\n\n assert a in _ordered_datatypes\n assert b in _ordered_datatypes\n\n for d in _ordered_datatypes:\n if a is d:\n return b\n if b is d:\n return a\n\n\ndef mm(a, b):\n device = a.device\n # handle non-contiguous inputs if necessary\n if a.stride(0) > 1 and a.stride(1) > 1:\n a = a.contiguous()\n if b.stride(0) > 1 and b.stride(1) > 1:\n b = b.contiguous()\n # checks constraints\n assert a.shape[1] == b.shape[0], \"incompatible dimensions\"\n M, K = a.shape\n _, N = b.shape\n # allocates output\n c_dtype = get_higher_dtype(a.dtype, b.dtype)\n c = torch.empty((M, N), device=device, dtype=c_dtype)\n dot_out_dtype = tl.float32\n # launch kernel\n grid = lambda META: (\n triton.cdiv(M, META[\"BLOCK_M\"]) * triton.cdiv(N, META[\"BLOCK_N\"]),\n META[\"SPLIT_K\"],\n )\n with torch.cuda.device(a.device):\n mm_kernel[grid](\n a,\n b,\n c,\n M,\n N,\n K,\n a.stride(0),\n a.stride(1),\n b.stride(0),\n b.stride(1),\n c.stride(0),\n c.stride(1),\n dot_out_dtype=dot_out_dtype,\n GROUP_M=8,\n )\n return c\n", - "description_1": "Use triton language to implement a matrix multiplication kernel (mm_kernel) with 15 parameters including tensors A, B, C, their dimensions M, N, K, and stride values, along with constexpr parameters for block sizes and group settings. The wrapper function mm handles input tensors, checks dimensionality, allocates output, and calls the kernel with the specified grid size.", - "description_2": "Use triton language to create a matrix multiplication function with configurable block and group sizes, and stride handling.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\n\n@triton.jit\ndef mul_func(x, y):\n return x * y\n\n@triton.jit\ndef mul_func_scalar(x, y):\n return x * y\n\ndef mul(A, B):\n if isinstance(A, torch.Tensor) and isinstance(B, torch.Tensor):\n return mul_func(A, B)\n elif isinstance(A, torch.Tensor):\n return mul_func_scalar(A, B)\n elif isinstance(B, torch.Tensor):\n return mul_func_scalar(B, A)\n else:\n return torch.tensor(A * B)\n", - "description_1": "Use triton language to define two kernels: 'mul_func' and 'mul_func_scalar'. Both kernels take two arguments, 'x' and 'y', and return their product. The 'mul' function determines the type of inputs 'A' and 'B', and calls the appropriate kernel or returns a PyTorch tensor for scalar multiplication.", - "description_2": "Use triton language to create kernels for element-wise multiplication of tensors and scalars, and a function to select the appropriate kernel based on input types.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\nfrom flag_gems.utils.random_utils import philox_cuda_seed_offset, uniform\nfrom flag_gems.ops import normed_cumsum\n\n@triton.jit(do_not_specialize=[\"K\", \"N\", \"philox_seed\", \"philox_offset\"])\ndef multinomial_with_replacement(\n cdf_ptr, out_ptr, K, N, philox_seed, philox_offset, NBLOCK: tl.constexpr\n):\n # The computation is arranged in a 2d grid of blocks, each producing\n # a batch of samples for a particular distribution.\n y_off = tl.program_id(1) * N\n n = tl.program_id(0) * NBLOCK + tl.arange(0, NBLOCK)\n rv, _, _, _ = uniform(philox_seed, philox_offset, y_off + n)\n\n # Do a binary search for each random number on the cumulative probabilities.\n rv += 0.0001\n rv = tl.where(rv > 0.9999, 0.9999, rv)\n\n cdf_ptr += tl.program_id(1) * K\n start = tl.zeros((NBLOCK,), dtype=tl.int32)\n end = tl.zeros((NBLOCK,), dtype=tl.int32) + K - 1\n steps = tl.math.log2(K.to(tl.float32)).to(tl.int32) + 1\n for _ in range(steps):\n mid = start + (end - start) // 2\n x = tl.load(cdf_ptr + mid, mask=n < N)\n start = tl.where(x < rv, mid + 1, start)\n end = tl.where(x < rv, end, mid)\n\n # Returns the last index in case of an overflow\n start = tl.where(start >= K, K - 1, start)\n\n tl.store(out_ptr + y_off + n, start, mask=n < N)\n\ndef multinomial(prob, n_samples, with_replacement=False, *, gen=None):\n assert prob.dtype in (torch.float16, torch.float32, torch.bfloat16, torch.float64)\n assert 0 < prob.dim() <= 2, \"prob_dist must be 1 or 2 dim\"\n n_categories = prob.size(-1)\n assert n_categories <= (1 << 24), \"number of categories cannot exceed 2^24\"\n assert (\n with_replacement or n_samples <= n_categories\n ), \"cannot sample n_samples > prob.size(-1) samples without replacement.\"\n\n # Sampling without replacement\n if (not with_replacement) or n_samples == 1:\n q = torch.empty_like(prob).exponential_(1.0)\n s = torch.div(prob, q, out=q)\n if n_samples == 1:\n return torch.argmax(s, dim=-1, keepdim=True).to(torch.int64)\n else:\n vals, indices = torch.topk(s, n_samples, dim=-1)\n return indices.to(torch.int64)\n\n cum_prob = normed_cumsum(prob, dim=-1)\n\n if cum_prob.dim() == 1:\n n_dist = 1\n out = torch.empty((n_samples,), device=prob.device, dtype=torch.int64)\n else:\n n_dist = cum_prob.size(0)\n out = torch.empty((n_dist, n_samples), device=prob.device, dtype=torch.int64)\n \n increment = n_dist * n_samples\n philox_seed, philox_offset = philox_cuda_seed_offset(increment)\n grid = lambda META: (triton.cdiv(n_samples, META[\"NBLOCK\"]), n_dist)\n multinomial_with_replacement[grid](\n cum_prob, out, n_categories, n_samples, philox_seed, philox_offset\n )\n return out\n", - "description_1": "Use triton language to implement a multinomial sampling kernel `multinomial_with_replacement` and a calling function `multinomial`. The kernel takes 7 parameters: `cdf_ptr` and `out_ptr` as pointers to memory, `K` as the number of categories, `N` as the number of samples, `philox_seed` and `philox_offset` for random number generation, and `NBLOCK` as a constant for the block size. It performs binary search over cumulative distribution functions (CDF) to generate multinomial samples. The calling function `multinomial` handles input validation, cumulative probability normalization, and sets up the kernel execution grid, taking 4 parameters: `prob` as the input probabilities tensor, `n_samples` as the number of samples to draw, `with_replacement` as a boolean flag for sampling with replacement, and optional generator `gen`.", - "description_2": "Use triton language to implement multinomial sampling with replacement using a binary search approach on cumulative distribution functions, leveraging the `multinomial_with_replacement` kernel and managing input validation and execution through a `multinomial` function.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef mv_kernel(\n A,\n B,\n C,\n N,\n M,\n stride_an,\n stride_am,\n stride_bm,\n stride_cn,\n BLOCK_N: tl.constexpr,\n BLOCK_M: tl.constexpr,\n):\n pid = tl.program_id(0)\n offset_n = pid * BLOCK_N + tl.arange(0, BLOCK_N)[:, None]\n offset_m = tl.arange(0, BLOCK_M)[None, :]\n n_mask = offset_n < N\n A_ptrs = A + offset_n * stride_an + offset_m * stride_am\n B_ptrs = B + offset_m * stride_bm\n acc = tl.zeros((BLOCK_N, BLOCK_M), dtype=tl.float32)\n for m in range(0, M, BLOCK_M):\n m_mask = m + offset_m < M\n a = tl.load(A_ptrs, mask=n_mask & m_mask, other=0.0).to(tl.float32)\n b = tl.load(B_ptrs, mask=m_mask, other=0.0).to(tl.float32)\n acc += a * b\n A_ptrs += BLOCK_M * stride_am\n B_ptrs += BLOCK_M * stride_bm\n\n acc = tl.sum(acc, axis=1)\n C_ptrs = C + offset_n * stride_cn\n tl.store(C_ptrs, acc[:, None], mask=n_mask)\n\n\ndef mv(inp, vec):\n assert inp.shape[1] == vec.shape[0], \"incompatible dimensions\"\n N, M = inp.shape\n out = torch.empty((N,), device=inp.device, dtype=inp.dtype)\n grid = lambda META: (triton.cdiv(N, META[\"BLOCK_N\"]),)\n with torch.cuda.device(inp.device):\n mv_kernel[grid](\n inp,\n vec,\n out,\n N,\n M,\n inp.stride(0),\n inp.stride(1),\n vec.stride(0),\n out.stride(0),\n )\n return out\n", - "description_1": "Use triton language to implement a matrix-vector multiplication kernel (mv_kernel) and a wrapper function (mv). The mv_kernel takes 10 parameters: A (matrix), B (vector), C (output vector), N (number of rows in A), M (number of columns in A), stride_an (stride of A in the n dimension), stride_am (stride of A in the m dimension), stride_bm (stride of B in the m dimension), stride_cn (stride of C in the n dimension), and two constexpr parameters BLOCK_N and BLOCK_M which define the block size for the kernel. The kernel computes the matrix-vector product by iterating over blocks of the matrix and vector, performing element-wise multiplication and accumulation, and storing the result in the output vector C. The mv function is a wrapper that prepares the input data, sets up the grid size for the kernel launch, and calls the mv_kernel with the appropriate parameters.", - "description_2": "Use triton language to create a matrix-vector multiplication kernel and a wrapper function to execute it on GPU. The kernel processes data in blocks, performing element-wise multiplication and accumulation, and the wrapper sets up and launches the kernel.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\nimport logging\n\n@triton.jit\ndef nonzero_kernel(\n inp,\n prefix_sum,\n out,\n n_elements,\n shape,\n ndim: tl.constexpr,\n BLOCK_SIZE: tl.constexpr,\n):\n pid = tl.program_id(0)\n\n offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)\n mask = offset < n_elements\n\n inp_vals = tl.load(inp + offset, mask=mask)\n out_offset = tl.load(prefix_sum + offset, mask=mask) - 1\n\n nonzero_mask = mask and inp_vals == True # noqa\n\n idx_flat = offset\n for dim in range(ndim - 1, -1, -1):\n dim_size = tl.load(shape + dim)\n remainder = idx_flat % dim_size\n idx_flat //= dim_size\n tl.store(out + out_offset * ndim + dim, remainder, mask=nonzero_mask)\n\n\ndef nonzero(inp, *, as_tuple=False):\n logging.debug(\"GEMS NONZERO\")\n\n inp_ndim = inp.ndim\n\n inp = inp.contiguous()\n n_elements = inp.numel()\n inp_view = inp.view(n_elements)\n\n shape = torch.tensor(inp.shape, dtype=torch.int32, device=inp.device)\n\n inp_bool = inp_view\n if inp_view.dtype != torch.bool:\n inp_bool = inp_view != 0\n\n prefix_sum = inp_bool.cumsum(axis=0)\n\n num_nonzeros = n_elements\n out = torch.empty(num_nonzeros, inp_ndim, dtype=torch.int64, device=inp.device)\n\n grid = lambda meta: (triton.cdiv(n_elements, meta[\"BLOCK_SIZE\"]),)\n with torch.cuda.device(inp.device):\n nonzero_kernel[grid](inp_bool, prefix_sum, out, n_elements, shape, inp_ndim)\n\n num_nonzeros = prefix_sum[n_elements - 1].item()\n out = out[0:num_nonzeros]\n\n if as_tuple:\n return torch.unbind(out, dim=0)\n else:\n return out\n", - "description_1": "Use triton language to implement a kernel function 'nonzero_kernel' that identifies non-zero elements in a flattened input tensor. The kernel takes 7 parameters: 'inp' (input tensor), 'prefix_sum' (cumulative sum of boolean input), 'out' (output tensor for non-zero indices), 'n_elements' (number of elements in input), 'shape' (shape of the input tensor), 'ndim' (number of dimensions, a compile-time constant), and 'BLOCK_SIZE' (block size, a compile-time constant). The function 'nonzero' is a wrapper that prepares the input, calculates the prefix sum, and calls the kernel with appropriate grid size.", - "description_2": "Use triton language to create a kernel that computes the indices of non-zero elements in a tensor, and a wrapper function to handle input preparation and kernel invocation.", - "difficulty": 3 - }, - { - "code": "import triton\nimport torch\n\nUNROLL = 4\n\n@triton.jit\ndef transform_func_tensor_tensor(val, std, mean):\n return val * std + mean\n\n@triton.jit\ndef transform_func_tensor_float(val, std, mean):\n return val * std + mean\n\n@triton.jit\ndef transform_func_float_tensor(val, std, mean):\n return val * std + mean\n\n@triton.jit\ndef transform_func_float_float(val, std, mean):\n return val * std + mean\n\ndef normal_distribution(mean, std, *, generator=None):\n shape = broadcast_shapes([mean.shape, std.shape])\n out = torch.empty(shape, device=mean.device, dtype=torch.float32)\n N = volume(shape)\n grid_fn = lambda meta: (triton.cdiv(N, meta[\"BLOCK\"] * UNROLL),)\n\n increment = triton.cdiv(N, UNROLL)\n philox_seed, philox_offset = philox_cuda_seed_offset(increment)\n with torch.cuda.device(mean.device):\n randn_kernel[grid_fn](out, N, philox_seed, philox_offset)\n return out\n\ndef normal_tensor_tensor(mean, std, *, generator=None):\n out = normal_distribution(mean, std)\n return transform_func_tensor_tensor(out, std, mean)\n\ndef normal_tensor_float(mean, std, *, generator=None):\n out = normal_distribution(mean, std)\n return transform_func_tensor_float(out, std, mean)\n\ndef normal_float_tensor(mean, std, *, generator=None):\n out = normal_distribution(mean, std)\n return transform_func_float_tensor(out, std, mean)\n\ndef normal_float_float(mean, std, *, generator=None):\n out = normal_distribution(mean, std)\n return transform_func_float_float(out, std, mean)\n", - "description_1": "Use triton language to define four kernels, each transforming values by multiplying with 'std' and adding 'mean'. Each kernel handles different combinations of tensor and float inputs for 'val', 'std', and 'mean'. Define functions to execute a normal distribution calculation using these kernels.", - "description_2": "Use triton language to implement kernels that perform element-wise transformations of input values using standard deviation and mean, and integrate these in a normal distribution workflow.", - "difficulty": 2 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\nfrom flag_gems.utils.shape_utils import volume\n\n# Triton kernel that sets all elements in the output tensor to 1.0\n@triton.jit\ndef ones_kernel(\n output_ptr, # Pointer to the output tensor in GPU memory\n n_elements, # Total number of elements to process\n BLOCK_SIZE: tl.constexpr, # Size of each block of threads\n):\n pid = tl.program_id(axis=0) # Get the block index\n block_start = pid * BLOCK_SIZE # Calculate the start index for this block\n offsets = block_start + tl.arange(0, BLOCK_SIZE) # Calculate offsets for each thread\n mask = offsets < n_elements # Mask to ensure we don't write out of bounds\n tl.store(output_ptr + offsets, 1.0, mask=mask) # Store 1.0 in all valid positions\n\n# Function to initialize a tensor of given size with ones using the Triton kernel\ndef ones(size, *, dtype=None, layout=None, device=None, pin_memory=None):\n if dtype is None:\n dtype = torch.get_default_dtype() # Use default PyTorch dtype if none provided\n if device is None:\n device = torch.device(\"cuda\") # Default to CUDA device\n\n out = torch.empty(size, device=device, dtype=dtype) # Create an empty tensor\n N = volume(size) # Calculate the total number of elements\n grid_fn = lambda meta: (triton.cdiv(N, meta[\"BLOCK_SIZE\"]),) # Determine the grid size\n with torch.cuda.device(device):\n ones_kernel[grid_fn](out, N, BLOCK_SIZE=1024) # Launch the Triton kernel\n return out # Return the initialized tensor\n", - "description_1": "Use triton language to create a kernel called ones_kernel that initializes a tensor with 1.0 values on the GPU. The kernel takes three arguments: output_ptr (the GPU memory pointer to the output tensor), n_elements (the total number of elements to process), and BLOCK_SIZE (a compile-time constant specifying the size of each thread block). The ones function wraps this kernel to accept standard tensor creation parameters like size, dtype, and device, computes the necessary grid size, and then launches the kernel on the GPU.", - "description_2": "Use triton language to develop a GPU kernel for initializing a tensor with ones, and provide a Python function to configure and launch this kernel.", - "difficulty": 3 - }, - { - "code": "import triton\nimport torch\nfrom triton import language as tl\n\n@triton.jit(do_not_specialize=[\"value\"])\ndef _jit_function(\n in0_ptr: tl.tensor, \n out0_ptr: tl.tensor, \n x_shape0: int, x_shape1: int, x_shape2: int, \n in_strides0: int, in_strides1: int, in_strides2: int, \n out_strides0: int, out_strides1: int, out_strides2: int, \n valid_dim0_start: int, valid_dim1_start: int, valid_dim2_start: int, \n valid_dim0_end: int, valid_dim1_end: int, valid_dim2_end: int, \n in_elem_cnt: tl.constexpr, \n out_elem_cnt: tl.constexpr, \n value, \n IS_CONSTANT: tl.constexpr, \n IS_REFLECT: tl.constexpr, \n IS_REPLICATE: tl.constexpr, \n IS_CIRCULAR: tl.constexpr, \n BLOCK_SIZE: tl.constexpr,\n):\n pid = tl.program_id(0)\n block_offset = pid * BLOCK_SIZE\n offset = block_offset + tl.arange(0, BLOCK_SIZE)\n remaining = offset\n idx = remaining // out_strides0\n dst_index_0 = idx\n remaining = remaining - idx * out_strides0\n idx = remaining // out_strides1\n dst_index_1 = idx\n remaining = remaining - idx * out_strides1\n dst_index_2 = remaining // out_strides2\n \n if_pad_false_mask = tl.zeros((BLOCK_SIZE, ), dtype=tl.int32)\n if_pad_true_mask = tl.full((BLOCK_SIZE, ), 1, dtype=tl.int32)\n \n cond = (dst_index_0 >= valid_dim0_start and dst_index_0 < valid_dim0_end) \n cond &= (dst_index_1 >= valid_dim1_start and dst_index_1 < valid_dim1_end)\n cond &= (dst_index_2 >= valid_dim2_start and dst_index_2 < valid_dim2_end)\n \n if_pad = tl.where(cond, if_pad_false_mask, if_pad_true_mask).to(tl.int1)\n \n src_index_0 = dst_index_0 - valid_dim0_start \n src_index_1 = dst_index_1 - valid_dim1_start \n src_index_2 = dst_index_2 - valid_dim2_start \n \n src_index_0 = tl.where(src_index_0 < 0, 0, src_index_0)\n src_index_1 = tl.where(src_index_1 < 0, 0, src_index_1)\n src_index_2 = tl.where(src_index_2 < 0, 0, src_index_2)\n\n if IS_REFLECT: \n src_index_0 = tl.where(dst_index_0 < valid_dim0_start,\n valid_dim0_start - dst_index_0, src_index_0)\n src_index_1 = tl.where(dst_index_1 < valid_dim1_start,\n valid_dim1_start - dst_index_1, src_index_1)\n src_index_2 = tl.where(dst_index_2 < valid_dim2_start,\n valid_dim2_start - dst_index_2, src_index_2)\n\n src_index_0 = tl.where(dst_index_0 >= valid_dim0_end,\n (x_shape0 + valid_dim0_start - 1) * 2 - dst_index_0 - valid_dim0_start, src_index_0)\n src_index_1 = tl.where(dst_index_1 >= valid_dim1_end,\n (x_shape1 + valid_dim1_start - 1) * 2 - dst_index_1 - valid_dim1_start, src_index_1)\n src_index_2 = tl.where(dst_index_2 >= valid_dim2_end,\n (x_shape2 + valid_dim2_start - 1) * 2 - dst_index_2 - valid_dim2_start, src_index_2)\n\n if IS_REPLICATE: \n src_index_0 = tl.where(dst_index_0 < valid_dim0_start, 0, src_index_0)\n src_index_1 = tl.where(dst_index_1 < valid_dim1_start, 0, src_index_1)\n src_index_2 = tl.where(dst_index_2 < valid_dim2_start, 0, src_index_2)\n\n src_index_0 = tl.where(dst_index_0 >= valid_dim0_end, x_shape0 - 1, src_index_0)\n src_index_1 = tl.where(dst_index_1 >= valid_dim1_end, x_shape1 - 1, src_index_1)\n src_index_2 = tl.where(dst_index_2 >= valid_dim2_end, x_shape2 - 1, src_index_2)\n\n if IS_CIRCULAR: \n src_index_0 = tl.where(dst_index_0 < valid_dim0_start,\n dst_index_0 + x_shape0 - valid_dim0_start, src_index_0)\n src_index_1 = tl.where(dst_index_1 < valid_dim1_start,\n dst_index_1 + x_shape1 - valid_dim1_start, src_index_1)\n src_index_2 = tl.where(dst_index_2 < valid_dim2_start,\n dst_index_2 + x_shape2 - valid_dim2_start, src_index_2)\n\n src_index_0 = tl.where(dst_index_0 >= valid_dim0_end,\n dst_index_0 - valid_dim0_end, src_index_0)\n src_index_1 = tl.where(dst_index_1 >= valid_dim1_end,\n dst_index_1 - valid_dim1_end, src_index_1)\n src_index_2 = tl.where(dst_index_2 >= valid_dim2_end,\n dst_index_2 - valid_dim2_end, src_index_2)\n\n src_offset = src_index_0 * in_strides0 + src_index_1 * in_strides1 + src_index_2 * in_strides2\n\n load_cond = src_index_0 < x_shape0\n load_cond &= src_index_1 < x_shape1\n load_cond &= src_index_2 < x_shape2\n\n if IS_CONSTANT: \n x_val = tl.load(in0_ptr + src_offset, mask=(not if_pad) and load_cond, other=value)\n else: \n x_val = tl.load(in0_ptr + src_offset, mask=load_cond, other=0)\n \n tl.store(out0_ptr + offset, x_val, mask=offset < out_elem_cnt)\n\n\ndef pad(self, pad, mode=\"constant\", value=None):\n BLOCK_SIZE = 256\n grid = (triton.cdiv(out0.numel(), BLOCK_SIZE), 1, 1)\n\n x_shape = in0.shape\n in_strides0 = in0.stride()\n out_strides = out0.stride()\n\n if rank > 0:\n for i in range(rank):\n valid_dim_start = pad_before[i]\n valid_dim_end = dst_shape[i] - pad_after[i]\n\n IS_CONSTANT = mode == 'constant'\n IS_REFLECT = mode == 'reflect'\n IS_REPLICATE = mode == 'replicate'\n IS_CIRCULAR = mode == 'circular'\n\n with torch.cuda.device(in0.device):\n _jit_function[grid](\n in0, out0,\n x_shape[0], x_shape[1], x_shape[2], # shape for x\n in_strides0[0], in_strides0[1], in_strides0[2], # stride for x\n out_strides[0], out_strides[1], out_strides[2], # stride for out\n valid_dim0_start, valid_dim1_start, valid_dim2_start, # valid dim start\n valid_dim0_end, valid_dim1_end, valid_dim2_end, # valid dim end\n in0.numel(),\n out0.numel(),\n value,\n IS_CONSTANT,\n IS_REFLECT,\n IS_REPLICATE,\n IS_CIRCULAR,\n BLOCK_SIZE,\n )\n\n return out0\n", - "description_1": "Use triton language to define a padding kernel that processes multi-dimensional tensor inputs for various padding modes including constant, reflect, replicate, and circular. The kernel accepts input tensor pointers, output tensor pointers, shape dimensions, strides, valid start and end indices for each dimension, padding value, boolean flags for each mode, and block size for parallel execution. A corresponding wrapper function sets up grid configuration and calls the kernel to perform padding operation on given input tensor with specified padding parameters.", - "description_2": "Use triton language to implement a kernel that applies various padding strategies to tensors using input tensor pointers, output tensor pointers, shape dimensions, and mode parameters.", - "difficulty": 4 - }, - { - "code": "import triton\nimport triton.language as tl\nimport logging\n\n@triton.jit\ndef pow_func_tensor_scalar(x, exponent):\n # Apply power function to tensor `x` with a scalar `exponent`\n return tl.libdevice.pow(x.to(tl.float32), exponent)\n\ndef pow_tensor_scalar(A, exponent):\n logging.debug(\"GEMS POW_TENSOR_SCALAR\")\n return pow_func_tensor_scalar(A, exponent)\n\n@triton.jit\ndef pow_func_scalar_tensor(x, exponent):\n # Apply power function to scalar `x` with a tensor `exponent`\n return tl.libdevice.pow(x.to(tl.float32), exponent)\n\ndef pow_scalar(A, exponent):\n logging.debug(\"GEMS POW_SCALAR\")\n return pow_func_scalar_tensor(A, exponent)\n", - "description_1": "Use triton language to implement kernels that apply a power function. The 'pow_func_tensor_scalar' kernel takes a tensor 'x' and a scalar 'exponent' and computes x^exponent element-wise. The 'pow_tensor_scalar' function logs a debug message and calls 'pow_func_tensor_scalar'. The 'pow_func_scalar_tensor' kernel takes a scalar 'x' and a tensor 'exponent' and computes x^exponent element-wise. The 'pow_scalar' function logs a debug message and calls 'pow_func_scalar_tensor'.", - "description_2": "Use triton language to implement kernels that compute the power of a tensor raised to a scalar and a scalar raised to a tensor, both element-wise. Include logging for debugging.", - "difficulty": 2 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n# Reduce multiplication kernel\n@triton.jit\ndef reduce_mul(a, b):\n return a * b\n\n# Kernel to compute product for intermediate results\n@triton.jit\ndef prod_kernel_mid(\n inp,\n mid,\n M,\n BLOCK_SIZE: tl.constexpr,\n):\n pid = tl.program_id(0)\n offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)\n inp_ptrs = inp + offset\n mask = offset < M\n inp_val = tl.load(inp_ptrs, mask=mask, other=1.0).to(tl.float32)\n mid_value = tl.reduce(inp_val, axis=0, combine_fn=reduce_mul)\n mid_ptr = mid + pid\n tl.store(mid_ptr, mid_value.to(inp_val.dtype))\n\n# Kernel to compute final product result\n@triton.jit\ndef prod_kernel_result(mid, out, mid_size, BLOCK_MID: tl.constexpr):\n offset = tl.arange(0, BLOCK_MID)\n mid_ptrs = mid + offset\n mask = offset < mid_size\n mid_val = tl.load(mid_ptrs, mask=mask, other=1.0).to(tl.float32)\n prod_val = tl.reduce(mid_val, axis=0, combine_fn=reduce_mul)\n tl.store(out, prod_val)\n\n# Product function calling prod_kernel_mid and prod_kernel_result\ndef prod(inp, *, dtype=None):\n if dtype is None:\n dtype = inp.dtype\n\n M = inp.numel()\n block_size = triton.next_power_of_2(math.ceil(math.sqrt(M)))\n mid_size = triton.cdiv(M, block_size)\n block_mid = triton.next_power_of_2(mid_size)\n\n mid = torch.empty((mid_size,), dtype=dtype, device=inp.device)\n out = torch.empty([], dtype=dtype, device=inp.device)\n\n with torch.cuda.device(inp.device):\n prod_kernel_mid[(mid_size, 1, 1)](inp, mid, M, block_size)\n prod_kernel_result[(1, 1, 1)](mid, out, mid_size, block_mid)\n return out\n\n# Autotuned and heuristic-based product kernel\n@triton.autotune(\n configs=[\n triton.Config({\"BLOCK_M\": 8}, num_warps=8),\n triton.Config({\"BLOCK_M\": 16}, num_warps=8),\n triton.Config({\"BLOCK_M\": 32}, num_warps=8),\n ],\n key=[\n \"M\",\n \"N\",\n ],\n)\n@triton.jit\ndef prod_kernel(\n inp,\n out,\n M,\n N,\n K,\n BLOCK_M: tl.constexpr,\n BLOCK_N: tl.constexpr,\n):\n pid_m = tl.program_id(0)\n pid_k = tl.program_id(1)\n m_offset = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)\n n_offset = tl.arange(0, BLOCK_N)\n offset = m_offset[:, None] * N * K + n_offset[None, :] * K + pid_k\n offset_index = m_offset * K + pid_k\n mask1 = m_offset < M\n mask = m_offset[:, None] < M and n_offset[None, :] < N\n inp_ptrs = inp + offset\n inp_vals = tl.load(inp_ptrs, mask=mask, other=1.0).to(tl.float32)\n result_index = tl.reduce(inp_vals, axis=1, combine_fn=reduce_mul)\n\n out_ptrs = out + offset_index\n tl.store(out_ptrs, result_index, mask=mask1)\n\n# Function calling the autotuned product kernel\ndef prod_dim(inp, dim=None, keepdim=False, *, dtype=None):\n shape = inp.shape\n dim = dim % inp.ndim\n N = shape[dim]\n M = math.prod(shape[:dim])\n K = inp.numel() // M // N\n\n inp = inp.contiguous()\n\n shape_list = list(shape)\n shape_list[dim] = 1\n\n if dtype is None:\n dtype = inp.dtype\n out = torch.empty(shape_list, dtype=dtype, device=inp.device)\n if not keepdim:\n out = torch.squeeze(out, dim)\n\n grid = lambda meta: (\n triton.cdiv(M, meta[\"BLOCK_M\"]),\n K,\n )\n with torch.cuda.device(inp.device):\n prod_kernel[grid](inp, out, M, N, K)\n\n return out\n", - "description_1": "Use triton language to define and compute element-wise product using reduce_mul, compute intermediate products with prod_kernel_mid, and finalize with prod_kernel_result. Use prod_kernel for dimensional products with autotuning and heuristics to optimize BLOCK_M and BLOCK_N.", - "description_2": "Use triton language to compute element-wise and dimensional products optimized with autotuning and heuristics.", - "difficulty": 4 - }, - { - "code": "import logging\nimport torch\nimport triton\nimport triton.language as tl\n\nfrom flag_gems.utils.random_utils import philox_cuda_seed_offset, uint_to_uniform_float\nfrom flag_gems.utils.shape_utils import volume\n\n\ndef heur_block(args):\n if args[\"N\"] <= 512:\n return 512\n else:\n return 1024\n\n\ndef heur_num_warps(args):\n if args[\"N\"] <= 512:\n return 4\n elif args[\"N\"] <= 1024:\n return 8\n else:\n return 16\n\n\n@triton.heuristics(\n {\n \"BLOCK\": heur_block,\n \"num_warps\": heur_num_warps,\n }\n)\n@triton.jit(do_not_specialize=[\"philox_seed\", \"philox_offset\"])\ndef rand_kernel(\n out_ptr,\n N,\n philox_seed,\n philox_offset,\n BLOCK: tl.constexpr,\n):\n philox_seed = philox_seed.to(tl.int64)\n philox_offset = philox_offset.to(tl.int64)\n c0 = (philox_offset & 0xFFFFFFFF).to(tl.uint32)\n c1 = ((philox_offset >> 32) & 0xFFFFFFFF).to(tl.uint32)\n i4 = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)\n c0 += i4\n _O = c0 * 0\n r0, r1, r2, r3 = tl.philox(philox_seed, c0, c1, _O, _O)\n r0 = uint_to_uniform_float(r0)\n r1 = uint_to_uniform_float(r1)\n r2 = uint_to_uniform_float(r2)\n r3 = uint_to_uniform_float(r3)\n off_0 = tl.program_id(0) * BLOCK * 4 + tl.arange(0, BLOCK)\n off_1 = off_0 + BLOCK\n off_2 = off_1 + BLOCK\n off_3 = off_2 + BLOCK\n tl.store(out_ptr + off_0, r0, mask=off_0 < N, eviction_policy=\"evict_first\")\n tl.store(out_ptr + off_1, r1, mask=off_1 < N, eviction_policy=\"evict_first\")\n tl.store(out_ptr + off_2, r2, mask=off_2 < N, eviction_policy=\"evict_first\")\n tl.store(out_ptr + off_3, r3, mask=off_3 < N, eviction_policy=\"evict_first\")\n\n\nUNROLL = 4\n\n\ndef rand(size, *, dtype=None, layout=None, device=None, pin_memory=None):\n logging.debug(\"GEMS RAND\")\n if dtype is None:\n dtype = torch.get_default_dtype()\n if device is None:\n device = torch.device(\"cuda\")\n\n out = torch.empty(size, device=device, dtype=dtype)\n N = volume(size)\n grid_fn = lambda meta: (triton.cdiv(N, meta[\"BLOCK\"] * UNROLL),)\n # (TODO) Using Triton autotuner makes kernel parameters opaque to the caller,\n # hence we cannot obtain the per thread offset as in Pytorch.\n increment = triton.cdiv(N, UNROLL)\n philox_seed, philox_offset = philox_cuda_seed_offset(increment)\n with torch.cuda.device(device):\n rand_kernel[grid_fn](out, N, philox_seed, philox_offset)\n return out\n", - "description_1": "Use triton language to implement a random number generator kernel (`rand_kernel`) and a function (`rand`) to invoke this kernel. The kernel accepts five parameters: `out_ptr` (pointer to output memory), `N` (number of random numbers), `philox_seed` and `philox_offset` (used for generating random numbers), and `BLOCK` (block size for parallel computation). The function `rand` initializes output tensor and calls `rand_kernel` with appropriate grid and block configurations.", - "description_2": "Use triton language to create a random number generator kernel and a Python function to execute this kernel with specified parameters for generating random numbers.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\nfrom flag_gems.utils.random_utils import philox_cuda_seed_offset, uint_to_uniform_float\nfrom flag_gems.utils.shape_utils import volume\n\n@triton.jit\ndef pair_uniform_to_normal(u1, u2):\n \"\"\"Box-Muller transform\"\"\"\n u1 = tl.maximum(1.0e-7, u1)\n th = 6.283185307179586 * u2\n r = tl.sqrt(-2.0 * tl.log(u1))\n return r * tl.cos(th), r * tl.sin(th)\n\n\ndef heur_block(args):\n if args[\"N\"] <= 512:\n return 512\n else:\n return 1024\n\n\ndef heur_num_warps(args):\n if args[\"N\"] <= 512:\n return 4\n elif args[\"N\"] <= 1024:\n return 8\n else:\n return 16\n\n\n@triton.heuristics(\n {\n \"BLOCK\": heur_block,\n \"num_warps\": heur_num_warps,\n }\n)\n@triton.jit(do_not_specialize=[\"philox_seed\", \"philox_offset\"])\ndef randn_kernel(\n out_ptr,\n N,\n philox_seed,\n philox_offset,\n BLOCK: tl.constexpr,\n):\n philox_seed = philox_seed.to(tl.int64)\n philox_offset = philox_offset.to(tl.int64)\n c0 = (philox_offset & 0xFFFFFFFF).to(tl.uint32)\n c1 = ((philox_offset >> 32) & 0xFFFFFFFF).to(tl.uint32)\n i4 = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)\n c0 += i4\n _O = c0 * 0\n r0, r1, r2, r3 = tl.philox(philox_seed, c0, c1, _O, _O)\n r0 = uint_to_uniform_float(r0)\n r1 = uint_to_uniform_float(r1)\n r2 = uint_to_uniform_float(r2)\n r3 = uint_to_uniform_float(r3)\n n0, n1 = pair_uniform_to_normal(r0, r1)\n n2, n3 = pair_uniform_to_normal(r2, r3)\n off_0 = tl.program_id(0) * BLOCK * 4 + tl.arange(0, BLOCK)\n off_1 = off_0 + BLOCK\n off_2 = off_1 + BLOCK\n off_3 = off_2 + BLOCK\n tl.store(out_ptr + off_0, n0, mask=off_0 < N, eviction_policy=\"evict_first\")\n tl.store(out_ptr + off_1, n1, mask=off_1 < N, eviction_policy=\"evict_first\")\n tl.store(out_ptr + off_2, n2, mask=off_2 < N, eviction_policy=\"evict_first\")\n tl.store(out_ptr + off_3, n3, mask=off_3 < N, eviction_policy=\"evict_first\")\n\n\nUNROLL = 4\n\n\ndef randn(size, *, dtype=None, layout=None, device=None, pin_memory=None):\n if dtype is None:\n dtype = torch.get_default_dtype()\n if device is None:\n device = torch.device(\"cuda\")\n out = torch.empty(size, device=device, dtype=dtype)\n N = volume(size)\n grid_fn = lambda meta: (triton.cdiv(N, meta[\"BLOCK\"] * UNROLL),)\n increment = triton.cdiv(N, UNROLL)\n philox_seed, philox_offset = philox_cuda_seed_offset(increment)\n with torch.cuda.device(device):\n randn_kernel[grid_fn](out, N, philox_seed, philox_offset)\n return out\n", - "description_1": "Use triton language to define and invoke a kernel (randn_kernel) that generates random numbers on a GPU using the Philox algorithm for random number generation. It utilizes heuristics to determine block sizes and number of warps for efficient execution, and employs a custom Box-Muller transform (pair_uniform_to_normal) to convert uniform random numbers to normally distributed numbers. It has 5 parameters: output pointer (out_ptr), total number of elements (N), Philox seed (philox_seed), Philox offset (philox_offset), and block size (BLOCK) as a compile-time constant.", - "description_2": "Use triton language to generate normally distributed random numbers on a GPU using a customized Box-Muller transform and the Philox random number generator.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef relu_forward(x):\n # Element-wise ReLU operation\n return tl.where(x > 0, x, 0)\n\n@triton.jit\ndef relu_backward(x, dy):\n # Element-wise ReLU backward operation\n return tl.where(x > 0, dy, 0)\n\nclass Relu(torch.autograd.Function):\n @staticmethod\n def forward(ctx, A):\n # Forward pass for ReLU\n out = relu_forward(A)\n ctx.save_for_backward(A)\n return out\n\n @staticmethod\n def backward(ctx, out_grad):\n # Backward pass for ReLU\n (inp,) = ctx.saved_tensors\n in_grad = relu_backward(inp, out_grad)\n return in_grad\n\ndef relu(A):\n # Apply the custom autograd function\n return Relu.apply(A)\n", - "description_1": "Use triton language to implement a ReLU activation function with two kernels: relu_forward and relu_backward. The relu_forward kernel takes one argument, x, which is a tensor, and applies the ReLU operation element-wise. The relu_backward kernel takes two arguments, x and dy, where x is the input tensor and dy is the gradient of the output, and computes the gradient of the input for the backward pass. The Relu class wraps these kernels for use in PyTorch's autograd system.", - "description_2": "Use triton language to create a ReLU activation function with forward and backward kernels, and integrate it with PyTorch's autograd.", - "difficulty": 2 - }, - { - "code": "import torch\nimport triton\nfrom triton import language as tl\nfrom flag_gems.utils.shape_utils import volume\nfrom flag_gems.utils.libentry import libentry\n\ndef generate_destination_passing_repeat_wrapper(\n rank: int,\n wrapper_name: str,\n kernel_name: str,\n code: IndentedBuffer,\n) -> IndentedBuffer:\n parameters: str = parameter_for_wrapper_out()\n wrapper_signature: str = f\"def {wrapper_name}({parameters}):\"\n code.writeline(wrapper_signature)\n\n with code.indent():\n if rank > 0:\n code.writeline(\"shape = out0.shape\")\n code.writeline(\"num_tasks = volume(shape)\")\n\n if rank > 0:\n code.writeline(\"tile_size = min(512, triton.next_power_of_2(num_tasks))\")\n code.writeline(\"num_warps = 4\")\n code.writeline(\"num_ctas = min(65535, triton.cdiv(num_tasks, tile_size))\")\n code.writeline(\n \"tiles_per_cta = triton.cdiv(num_tasks, tile_size * num_ctas)\"\n )\n else:\n code.writeline(\"num_warps = 1\")\n code.writeline(\"num_ctas = 1\")\n code.writeline(\"grid = (num_ctas, 1, 1)\")\n code.newline()\n\n if rank > 0:\n code.writeline(\"# strides of each tensor argument w.r.t the task space\")\n code.writeline(\"in0_strides = in0.stride()\")\n code.writeline(\"in0_shape = in0.shape\")\n code.writeline(\"out0_strides = out0.stride()\")\n code.newline()\n\n code.writeline(\"# kernel launch\")\n\n code.writeline(\"with torch.cuda.device(in0.device.index):\")\n with code.indent():\n kernel_launch: str = f\"{kernel_name}[grid](\"\n code.writeline(kernel_launch)\n\n with code.indent():\n code.writeline(\"in0, out0, \")\n\n if rank > 0:\n s = \", \".join(f\"in0_strides[{j}]\" for j in range(rank))\n code.writeline(f\"{s}, # stride for in0\")\n\n s = \", \".join(f\"out0_strides[{j}]\" for j in range(rank))\n code.writeline(f\"{s}, # stride for out0\")\n\n shape_args: str = \", \".join(f\"shape[{i}]\" for i in range(rank))\n code.writeline(f\"{shape_args}, # task indexing space\")\n in_shape_args: str = \", \".join(f\"in0_shape[{i}]\" for i in range(rank))\n code.writeline(\n f\"{in_shape_args}, # task indexing space used when input and ouput tensor has different shape\"\n )\n code.writeline(\"num_tasks, # num tasks\")\n code.writeline(\"tiles_per_cta=tiles_per_cta, # tiles_per_cta\")\n code.writeline(\"tile_size=tile_size,\")\n code.writeline(\"one_tile_per_cta=tiles_per_cta==1,\")\n code.writeline(\"num_warps=num_warps,\")\n code.writeline(\")\")\n\n code.writeline(\"return out0\")\n code.newline()\n code.newline()\n return code\n\ndef generate_repeat_kernel(\n rank: int,\n kernel_name: str,\n code: IndentedBuffer,\n) -> IndentedBuffer:\n code.newline()\n\n code.writeline(\"@libentry()\")\n code.writeline(\"@triton.jit\")\n\n code.writeline(f\"def {kernel_name}(\")\n function_ns = NameSpace()\n with code.indent():\n code.writeline(\"in0_ptr: tl.tensor, # of tl.pointer_type\")\n function_ns.create_name(\"in0_ptr\")\n\n code.writeline(\"out0_ptr: tl.tensor, # of tl.pointer_type\")\n function_ns.create_name(\"out0_ptr\")\n\n if rank > 0:\n for j in range(rank):\n function_ns.create_name(f\"in0_stride{j}\")\n stride_args = \", \".join(f\"in0_stride{j}: int\" for j in range(rank))\n code.writeline(f\"{stride_args}, # strides for in0\")\n\n for j in range(rank):\n function_ns.create_name(f\"out0_stride{j}\")\n stride_args = \", \".join(f\"out0_stride{j}: int\" for j in range(rank))\n code.writeline(f\"{stride_args}, # strides for out0\")\n\n task_space_args = \", \".join(f\"s{i}: int\" for i in range(rank))\n for i in range(rank):\n function_ns.create_name(f\"s{i}\")\n code.writeline(f\"{task_space_args}, # task_space\")\n\n task_space_args2 = \", \".join(f\"in_s{i}: int\" for i in range(rank))\n for i in range(rank):\n function_ns.create_name(f\"in_s{i}\")\n code.writeline(\n f\"{task_space_args2}, # task_space2 used when input and output tensor has different shape\"\n )\n\n code.writeline(\"num_tasks: int,\")\n function_ns.create_name(\"num_tasks\")\n\n if rank > 0:\n code.writeline(\"tiles_per_cta,\")\n function_ns.create_name(\"tiles_per_cta\")\n\n code.writeline(\"tile_size: tl.constexpr,\")\n function_ns.create_name(\"tile_size\")\n\n code.writeline(\"one_tile_per_cta: tl.constexpr,\")\n function_ns.create_name(\"one_tile_per_cta\")\n code.writeline(\"):\")\n\n with code.indent():\n code.writeline(\"# task id & masking\")\n pid_stmt = \"pid = tl.program_id(0)\"\n code.writeline(pid_stmt)\n function_ns.create_name(\"pid\")\n\n code.writeline(\"num_ctas = tl.num_programs(0)\")\n function_ns.create_name(\"num_ctas\")\n\n tid_stmt = \"init_tid = pid * tile_size + tl.arange(0, tile_size)\"\n code.writeline(tid_stmt)\n function_ns.create_name(\"init_tid\")\n\n code.writeline(\"if one_tile_per_cta: # monolitic kernel style\")\n with code.indent():\n tid_stmt = \"tid = init_tid\"\n code.writeline(tid_stmt)\n function_ns.create_name(\"tid\")\n\n mask_stmt: str = \"mask = tid < num_tasks\"\n code.writeline(mask_stmt)\n function_ns.create_name(\"mask\")\n code.newline()\n\n code.writeline(\"# multi index recontruction\")\n for i in reversed(range(rank)):\n if i > 0:\n code.writeline(f\"i{i} = tid % s{i}\")\n code.writeline(f\"tid //= s{i}\")\n else:\n code.writeline(f\"i{i} = tid\")\n function_ns.create_name(f\"{i}\")\n code.newline()\n\n code.writeline(\"# loads\")\n ptrs_expr: str = \" + \".join(\n f\"(i{j} % in_s{j}) * in{i}_stride{j}\" for j in range(rank)\n )\n ptrs_expr: str = f\"in0_ptr + {ptrs_expr}\"\n load_stmt: str = f\"in0 = tl.load({ptrs_expr}, mask=mask)\"\n function_ns.create_name(\"in0\") \n code.writeline(load_stmt)\n code.newline()\n\n code.writeline(\"# compute\")\n code.writeline(\"out0 = in0\")\n code.newline()\n\n code.writeline(\"# stores\")\n ptrs_expr: str = \" + \".join(f\"i{j} * out0_stride{j}\" for j in range(rank))\n ptrs_expr: str = f\"out0_ptr + {ptrs_expr}\"\n store_stmt: str = f\"tl.store({ptrs_expr}, out0, mask=mask)\"\n code.writeline(store_stmt)\n\n code.writeline(\"else: # grid-stride-loop style kernel\")\n with code.indent():\n code.writeline(\"for j in range(0, tiles_per_cta):\")\n function_ns.create_name(\"j\")\n with code.indent():\n tid_stmt = \"tid = init_tid + j * tile_size * num_ctas\"\n code.writeline(tid_stmt)\n function_ns.create_name(\"tid\")\n\n mask_stmt: str = \"mask = tid < num_tasks\"\n code.writeline(mask_stmt)\n function_ns.create_name(\"mask\")\n code.newline()\n\n code.writeline(\"# multi index recontruction\")\n for i in reversed(range(rank)):\n if i > 0:\n code.writeline(f\"i{i} = tid % s{i}\")\n code.writeline(f\"tid //= s{i}\")\n else:\n code.writeline(f\"i{i} = tid\")\n function_ns.create_name(f\"{i}\")\n code.newline()\n\n code.writeline(\"# loads\")\n ptrs_expr: str = \" + \".join(\n f\"(i{j} % in_s{j}) * in{i}_stride{j}\" for j in range(rank)\n )\n ptrs_expr: str = f\"in0_ptr + {ptrs_expr}\"\n load_stmt: str = f\"in0 = tl.load({ptrs_expr}, mask=mask)\"\n function_ns.create_name(\"in0\") \n code.writeline(load_stmt)\n code.newline()\n\n code.writeline(\"# compute\")\n code.writeline(\"out0 = in0\")\n code.newline()\n\n code.writeline(\"# stores\")\n ptrs_expr: str = \" + \".join(\n f\"i{j} * out0_stride{j}\" for j in range(rank)\n )\n ptrs_expr: str = f\"out0_ptr + {ptrs_expr}\"\n store_stmt: str = f\"tl.store({ptrs_expr}, out0, mask=mask)\"\n code.writeline(store_stmt)\n code.newline()\n return code\n", - "description_1": "Use triton language to implement a repeat function with a kernel that supports multi-dimensional tensors. The kernel has parameters including input and output pointers, strides for inputs and outputs, task spaces, number of tasks, tile size, and execution style (monolithic or grid-stride-loop). The functional wrapper prepares the task grid, computes strides, and calls the kernel.", - "description_2": "Use triton language to implement a repeat kernel that loads data, computes repeat operation, and stores the result efficiently using grid-stride-loop technique and tensor task spaces.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nfrom triton import language as tl\n\n@triton.jit\ndef copy_func(x):\n return x\n\ndef repeat_interleave_self_int(inp, repeats, dim=None, *, output_size=None):\n if dim is None:\n inp = inp.flatten()\n dim = 0\n else:\n if (dim < -inp.ndim) or (dim >= inp.ndim):\n raise IndexError(\n \"Dimension out of range (expected to be in range of [{}, {}], but got {})\".format(\n -inp.ndim, inp.ndim - 1, dim\n )\n )\n inp_shape = list(inp.shape)\n inp_stride = list(inp.stride())\n output_shape = list(inp.shape)\n\n if dim < 0:\n dim = dim + len(inp_shape)\n\n output_shape[dim] *= repeats\n\n if output_size is not None and output_size != output_shape[dim]:\n raise RuntimeError(\n \"repeat_interleave: Invalid output_size, expected {} but got {}\".format(\n output_shape[dim], output_size\n )\n )\n\n output = torch.empty(output_shape, dtype=inp.dtype, device=inp.device)\n\n if repeats == 0:\n return output\n\n in_view_stride = inp_stride[: dim + 1] + [0] + inp_stride[dim + 1 :]\n out_view_shape = inp_shape[: dim + 1] + [repeats] + inp_shape[dim + 1 :]\n out_view_stride = c_contiguous_stride(out_view_shape)\n\n in_view = StridedBuffer(inp, out_view_shape, in_view_stride)\n out_view = StridedBuffer(output, out_view_shape, out_view_stride)\n ndim = len(out_view_shape)\n copy_func.instantiate(ndim)(in_view, out0=out_view)\n return output\n\n@triton.jit\ndef repeat_interleave_tensor_kernel(\n repeats_ptr, cumsum_ptr, out_ptr, size, BLOCK_SIZE: tl.constexpr\n):\n pid = tl.program_id(0)\n mask = pid < size\n cumsum = tl.load(cumsum_ptr + pid, mask, other=0)\n repeats = tl.load(repeats_ptr + pid, mask, other=0)\n out_offset = cumsum - repeats\n\n tl.device_assert(repeats >= 0, \"repeats can not be negative\")\n\n out_ptr += out_offset\n for start_k in range(0, repeats, BLOCK_SIZE):\n offsets_k = start_k + tl.arange(0, BLOCK_SIZE)\n mask_k = offsets_k < repeats\n tl.store(out_ptr + offsets_k, pid, mask=mask_k)\n\ndef repeat_interleave_tensor(repeats, *, output_size=None):\n assert repeats.ndim == 1, \"repeat_interleave only accept 1D vector as repeat\"\n\n cumsum = repeats.cumsum(axis=0)\n result_size = cumsum[-1].item()\n\n assert result_size >= 0, \"repeats can not be negative\"\n\n out = torch.empty((result_size,), dtype=repeats.dtype, device=repeats.device)\n size = repeats.size(0)\n\n grid = (size,)\n BLOCK_SIZE = 32\n repeat_interleave_tensor_kernel[grid](\n repeats,\n cumsum,\n out,\n size,\n BLOCK_SIZE=BLOCK_SIZE,\n num_warps=1,\n )\n return out\n", - "description_1": "Use triton language to implement two kernels: 'copy_func' and 'repeat_interleave_tensor_kernel'. 'copy_func' takes one argument 'x' and returns it. 'repeat_interleave_tensor_kernel' takes five arguments: 'repeats_ptr', 'cumsum_ptr', 'out_ptr', 'size', and 'BLOCK_SIZE'. It performs repeat interleave operation on a tensor using the given repeat counts and cumulative sum, storing the result in 'out_ptr'. The function 'repeat_interleave_self_int' calls 'copy_func' to repeat elements of a tensor along a specified dimension. The function 'repeat_interleave_tensor' calls 'repeat_interleave_tensor_kernel' to repeat elements of a 1D tensor based on repeat counts.", - "description_2": "Use triton language to create a kernel for repeating elements of a tensor along a specified dimension. Use triton language to create a kernel for repeating elements of a 1D tensor based on repeat counts.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef _scatter_jit_function(\n src_strided,\n index,\n inp,\n out,\n inp_stride_0: int,\n inp_stride_1: int,\n index_stride_0: int,\n index_stride_1: int,\n index_shape_0: int,\n index_shape_1: int,\n dim,\n stride_dim,\n M,\n N,\n IS_ADD: tl.constexpr,\n IS_MUL: tl.constexpr,\n BLOCK_M: tl.constexpr,\n BLOCK_N: tl.constexpr,\n):\n pid_x = tl.program_id(0)\n pid_y = tl.program_id(1)\n rows_offsets = pid_x * BLOCK_M + tl.arange(0, BLOCK_M)[:, None]\n cols_offsets = pid_y * BLOCK_N + tl.arange(0, BLOCK_N)[None, :]\n rows_mask = rows_offsets < M\n cols_mask = cols_offsets < N\n\n offsets = (rows_offsets * N + cols_offsets).to(tl.int64)\n mask = rows_mask & cols_mask\n\n inp_offsets = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.int64)\n idx_offsets = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.int64)\n cur_idx = rows_offsets * N + cols_offsets\n\n mod = cur_idx % index_shape_0\n inp_offsets += mod * inp_stride_0\n idx_offsets += mod * index_stride_0\n cur_idx = cur_idx // index_shape_0\n\n mod = cur_idx % index_shape_1\n inp_offsets += mod * inp_stride_1\n idx_offsets += mod * index_stride_1\n\n cur_src = tl.load(src_strided + idx_offsets, mask=mask, other=0)\n cur_index = tl.load(index + idx_offsets, mask=mask, other=0)\n inp_offsets += cur_index * stride_dim\n\n if IS_ADD: \n cur_inp = tl.load(inp + inp_offsets, mask=mask, other=0)\n res = cur_inp + cur_src\n tl.store(out + inp_offsets, res, mask=mask)\n elif IS_MUL:\n cur_inp = tl.load(inp + inp_offsets, mask=mask, other=0)\n res = cur_inp * cur_src\n tl.store(out + inp_offsets, res, mask=mask)\n else:\n tl.store(out + inp_offsets, cur_src, mask=mask)\n\ndef _scatter_wrapper(src_strided, index, inp, out, dim, M, N, reduce):\n inp_strides = list(inp.stride())\n index_strides = index.stride()\n index_shapes = list(index.shape)\n stride_dim = inp_strides[dim]\n inp_strides[dim] = 0\n\n IS_ADD = reduce == \"add\"\n IS_MUL = reduce == \"multiply\"\n\n grid = lambda meta: (\n triton.cdiv(M, meta[\"BLOCK_M\"]),\n triton.cdiv(N, meta[\"BLOCK_N\"])\n )\n\n _scatter_jit_function[grid](\n src_strided, index, inp, out, \n inp_strides[0], inp_strides[1],\n index_strides[0], index_strides[1],\n index_shapes[0], index_shapes[1],\n dim, stride_dim, M, N,\n IS_ADD, IS_MUL\n )\n return out\n\ndef scatter(inp, dim, index, src, reduce=None):\n inp = inp.contiguous()\n index = index.contiguous()\n src = src.contiguous()\n out = inp.clone()\n\n src_strided = src.as_strided(index.shape, src.stride()).contiguous()\n N = list(index.shape)[index.ndim - 1]\n M = index.numel() // N\n\n return _scatter_wrapper(src_strided, index, inp, out, dim, M, N, reduce)\n", - "description_1": "Use triton language to implement a scatter operation with optional reduction (add or multiply) in a custom kernel. The kernel function '_scatter_jit_function' takes 18 parameters, including source and index tensors, strides and shapes of the input and index, dimensions and stride dimensions for calculation, grid size (M and N), and constants for reduction type and block size. The function calculates offsets, applies the specified reduction (if any), and stores the result. The wrapper function '_scatter_wrapper' prepares inputs and calls the kernel, while the 'scatter' function acts as the main API for users to input tensors and specify the operation.", - "description_2": "Use triton language to perform a scatter operation with optional reduction in a customized kernel. It requires inputs such as source tensor, index, and operation details, executed by '_scatter_jit_function'.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\nfrom triton.language.libdevice import exp2\n\n@triton.jit\ndef sigmoid_forward(x):\n # log2e: tl.constexpr = math.log2(math.e)\n # triton 3.0.0 disallow calling non-jitted function inside jitted function, even if it is in\n # the rhs of an assignment to a constexpr, so we use numeric literal instead to work around this.\n log2e: tl.constexpr = 1.4426950408889634\n return 1 / (1 + exp2(-x.to(tl.float32) * log2e))\n\n@triton.jit\ndef sigmoid_backward(y, dy):\n y_f32 = y.to(tl.float32)\n dy_f32 = dy.to(tl.float32)\n return dy_f32 * (1.0 - y_f32) * y_f32\n\nclass Sigmoid(torch.autograd.Function):\n @staticmethod\n def forward(ctx, A):\n if A.requires_grad is True:\n out = sigmoid_forward(A.to(torch.float32))\n ctx.save_for_backward(out)\n return out.to(A.dtype)\n else:\n out = sigmoid_forward(A)\n return out\n\n @staticmethod\n def backward(ctx, out_grad):\n (out,) = ctx.saved_tensors\n in_grad = sigmoid_backward(out, out_grad)\n return in_grad\n\ndef sigmoid(A):\n return Sigmoid.apply(A)\n", - "description_1": "Use triton language to implement a sigmoid function with two kernels: sigmoid_forward and sigmoid_backward. The sigmoid_forward kernel takes one argument, x, which is a tensor, and computes the sigmoid function using a constant log2e. The sigmoid_backward kernel takes two arguments, y and dy, which are tensors, and computes the gradient of the sigmoid function. The Sigmoid class wraps these kernels for use in PyTorch's autograd system, with forward and backward methods handling the computation and gradient propagation respectively.", - "description_2": "Use triton language to create a sigmoid function with forward and backward kernels for PyTorch autograd. The forward kernel computes the sigmoid, and the backward kernel computes its gradient.", - "difficulty": 2 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\nfrom triton.language.libdevice import div_rn\n\n@triton.jit\ndef silu_forward(x):\n # Convert input to float32\n x_fp32 = x.to(tl.float32)\n # Compute SiLU activation function\n y = tl.fdiv(x_fp32, (1.0 + tl.exp(-x_fp32)))\n return y\n\n@triton.jit\ndef silu_backward(x, dy):\n # Convert inputs to float32\n dy_fp32 = dy.to(tl.float32)\n x_fp32 = x.to(tl.float32)\n # Compute the gradient of SiLU\n sigma = div_rn(1.0, 1.0 + tl.exp(-x_fp32))\n dx = dy_fp32 * sigma * (1.0 + x_fp32 * (1.0 - sigma))\n return dx\n\nclass Silu(torch.autograd.Function):\n @staticmethod\n def forward(ctx, A):\n out = silu_forward(A)\n ctx.save_for_backward(A)\n return out\n\n @staticmethod\n def backward(ctx, out_grad):\n (inp,) = ctx.saved_tensors\n in_grad = silu_backward(inp, out_grad)\n return in_grad\n\ndef silu(A):\n return Silu.apply(A)\n", - "description_1": "Use triton language to implement the SiLU activation function and its gradient for autograd. The kernel 'silu_forward' computes the SiLU function, taking 1 argument: a tensor 'x'. It returns the result after applying the SiLU operation. The kernel 'silu_backward' computes the gradient, taking 2 arguments: a tensor 'x' and a tensor 'dy'. It returns the gradient 'dx'. The class 'Silu' implements the forward and backward functions for autograd using these kernels.", - "description_2": "Use triton language to create kernels for the SiLU activation function and its gradient, and integrate them with PyTorch's autograd.", - "difficulty": 2 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\ndef heur_tile_k(args):\n tile_k = 1\n MAX_TILE_K = 8192\n NUM_SMS = torch.cuda.get_device_properties(torch.cuda.current_device()).multi_processor_count\n upper_bound = min(args[\"K\"], MAX_TILE_K)\n while tile_k <= upper_bound:\n num_blocks = args[\"M\"] * triton.cdiv(args[\"K\"], tile_k)\n num_waves = num_blocks / NUM_SMS\n if (num_waves > 1) and (tile_k * 2 <= upper_bound):\n tile_k *= 2\n else:\n break\n return tile_k\n\ndef heur_tile_n_non_inner(args):\n return triton.cdiv(8192, args[\"TILE_K\"])\n\ndef heur_one_tile_per_cta(args):\n return args[\"TILE_N\"] >= args[\"N\"]\n\ndef heur_num_warps_non_inner(args):\n tile_size = args[\"TILE_N\"] * args[\"TILE_K\"]\n if tile_size < 2048:\n return 4\n elif tile_size < 4096:\n return 8\n else:\n return 16\n\n@triton.heuristics(\n {\n \"TILE_K\": heur_tile_k,\n \"TILE_N\": heur_tile_n_non_inner,\n \"ONE_TILE_PER_CTA\": heur_one_tile_per_cta,\n \"num_warps\": heur_num_warps_non_inner,\n }\n)\n@triton.jit\ndef softmax_kernel_non_inner(\n output_ptr,\n input_ptr,\n M,\n N,\n K,\n TILE_N: tl.constexpr,\n TILE_K: tl.constexpr,\n ONE_TILE_PER_CTA: tl.constexpr,\n):\n # Kernel logic here...\n pass\n\ndef heur_tile_n_inner(args):\n if args[\"N\"] <= (32 * 1024):\n return triton.next_power_of_2(args[\"N\"])\n else:\n return 4096\n\ndef heur_num_warps_inner(args):\n tile_size = args[\"TILE_N\"]\n if tile_size < 2048:\n return 4\n elif tile_size < 4096:\n return 8\n else:\n return 16\n\n@triton.heuristics(\n {\n \"TILE_N\": heur_tile_n_inner,\n \"ONE_TILE_PER_CTA\": heur_one_tile_per_cta,\n \"num_warps\": heur_num_warps_inner,\n }\n)\n@triton.jit\ndef softmax_kernel_inner(\n output_ptr,\n input_ptr,\n M,\n N,\n TILE_N: tl.constexpr,\n ONE_TILE_PER_CTA: tl.constexpr,\n):\n # Kernel logic here...\n pass\n\nclass Softmax(torch.autograd.Function):\n @staticmethod\n def forward(ctx, x, dim, dtype):\n assert dim >= -x.ndim and dim < x.ndim, \"Invalid dim\"\n dim = dim % x.ndim\n M = 1\n N = x.shape[dim]\n for i in range(dim):\n M *= x.shape[i] # pre_dim\n inp = x.contiguous()\n if dtype is None:\n dtype = x.dtype\n out = torch.empty_like(inp, dtype=dtype)\n K = inp.numel() // M // N # post_dim\n\n with torch.cuda.device(inp.device):\n if K > 1:\n grid = lambda meta: (M, triton.cdiv(K, meta[\"TILE_K\"]), 1)\n softmax_kernel_non_inner[grid](\n out,\n inp,\n M,\n N,\n K,\n )\n else:\n grid = (M, 1, 1)\n softmax_kernel_inner[grid](\n out,\n inp,\n M,\n N,\n )\n ctx.save_for_backward(out)\n ctx.dim = dim\n return out\n\n @staticmethod\n def backward(ctx, out_grad):\n dim = ctx.dim\n (out,) = ctx.saved_tensors\n assert dim >= -out.ndim and dim < out.ndim, \"Invalid dim\"\n dim = dim % out.ndim\n M = 1\n N = out.shape[dim]\n for i in range(dim):\n M *= out.shape[i]\n\n out_grad = out_grad.contiguous()\n in_grad = torch.empty_like(out)\n K = out.numel() // M // N\n\n with torch.cuda.device(in_grad.device):\n if K > 1:\n grid = lambda meta: (M, triton.cdiv(K, meta[\"TILE_K\"]), 1)\n softmax_backward_kernel_non_inner[grid](\n out,\n out_grad,\n in_grad,\n M,\n N,\n K,\n )\n else:\n grid = lambda meta: (triton.cdiv(M, meta[\"TILE_M\"]), 1, 1)\n softmax_backward_kernel_inner[grid](\n out,\n out_grad,\n in_grad,\n M,\n N,\n )\n return in_grad, None, None\n\ndef softmax(x, dim=-1, dtype=None):\n return Softmax.apply(x, dim, dtype)\n", - "description_1": "Use triton language to implement a softmax function for input tensors, utilizing different strategies for computing the softmax when the post-dimension (K) is larger than 1 or equals 1. The code includes forward and backward pass implementations with heuristics to adjust tile sizes for optimal performance.", - "description_2": "Use triton language to optimize the softmax function with kernel implementations for efficient GPU computation, including forward and backward passes.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\n\n@triton.jit\ndef copy_func(x):\n return x\n\ndef stack(\n tensors: Union[Tuple[torch.Tensor, ...], List[torch.Tensor]], dim: int = 0\n) -> torch.Tensor:\n if len(tensors) == 0:\n raise RuntimeError(\"stack expected a non-empty TensorList\")\n\n inp_shapes = [list(_.shape) for _ in tensors]\n inp0_shape = inp_shapes[0]\n for i, s in enumerate(inp_shapes[1:]):\n if (dim < -tensors[i + 1].dim() - 1) or (dim > tensors[i + 1].dim()):\n raise IndexError(\n \"Dimension out of range (expected to be in range of [{}, {}], but got {})\".format(\n -tensors[i + 1].dim() - 1, tensors[i + 1].dim(), dim\n )\n )\n if s != inp0_shape:\n raise RuntimeError(\n f\"stack expects each tensor to be equal size, but got {inp0_shape} at entry 0 and {s} at entry {i+1}\"\n )\n\n if dim < 0:\n dim = dim + len(inp0_shape) + 1\n\n in0_shape = inp0_shape[:dim] + [1] + inp0_shape[dim:]\n out_shape = inp0_shape[:dim] + [len(tensors)] + inp0_shape[dim:]\n out0 = torch.empty(out_shape, dtype=tensors[0].dtype, device=tensors[0].device)\n out0_strides = out0.stride()\n out0_offsets = list(\n itertools.accumulate([out0_strides[dim] for _ in inp_shapes[:-1]], initial=0)\n )\n\n for a, out0_offset in zip(tensors, out0_offsets):\n a = a.reshape(in0_shape)\n in_view = StridedBuffer(a, in0_shape, a.stride())\n out_view = StridedBuffer(out0, in0_shape, out0.stride(), offset=out0_offset)\n copy_func.instantiate(a.ndim)(in_view, out0=out_view)\n\n return out0\n", - "description_1": "Use triton language to create a kernel function 'copy_func' that takes a single argument 'x' and returns it unchanged. Then, implement a function 'stack' that stacks a list or tuple of PyTorch tensors along a specified dimension. It checks the shapes of input tensors for compatibility and reshapes them before using 'copy_func' to copy data into an output tensor.", - "description_2": "Use triton language to create a basic copy kernel and a stack function for tensors.", - "difficulty": 2 - }, - { - "code": "import torch\nimport triton\n\n@triton.jit\ndef sub_func(x, y, alpha):\n return x - y * alpha\n\n@triton.jit\ndef sub_func_tensor_scalar(x, y, alpha):\n return x - y * alpha\n\n@triton.jit\ndef sub_func_scalar_tensor(x, y, alpha):\n return x - y * alpha\n\ndef sub(A, B, *, alpha=1):\n if isinstance(A, torch.Tensor) and isinstance(B, torch.Tensor):\n return sub_func(A, B, alpha)\n elif isinstance(A, torch.Tensor):\n return sub_func_tensor_scalar(A, B, alpha)\n elif isinstance(B, torch.Tensor):\n return sub_func_scalar_tensor(A, B, alpha)\n else:\n return torch.tensor(A - B * alpha)\n", - "description_1": "Use triton language to define three kernels: sub_func, sub_func_tensor_scalar, and sub_func_scalar_tensor. Each kernel takes three parameters: x, y, and alpha. The kernels perform element-wise subtraction of y multiplied by alpha from x. The sub function determines which kernel to call based on whether A and B are tensors or scalars.", - "description_2": "Use triton language to create kernels for element-wise subtraction with scalar and tensor inputs. Implement a function to select the appropriate kernel based on input types.", - "difficulty": 2 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef sum_kernel_1(\n inp,\n mid,\n M,\n BLOCK_SIZE: tl.constexpr,\n):\n pid = tl.program_id(0)\n offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)\n inp_ptrs = inp + offset\n mask = offset < M\n inp_val = tl.load(inp_ptrs, mask=mask, other=0.0).to(tl.float32)\n sum_val = tl.sum(inp_val)\n mid_ptr = mid + pid\n tl.store(mid_ptr, sum_val)\n\n@triton.jit\ndef sum_kernel_2(mid, out, mid_size, BLOCK_MID: tl.constexpr):\n offset = tl.arange(0, BLOCK_MID)\n mid_ptrs = mid + offset\n mask = offset < mid_size\n mid_val = tl.load(mid_ptrs, mask=mask, other=0.0).to(tl.float32)\n sum_val = tl.sum(mid_val)\n tl.store(out, sum_val)\n\ndef cfggen():\n block_m = [1, 2, 4, 8]\n configs = [\n triton.Config({\"BLOCK_M\": m, \"BLOCK_N\": 1024}, num_warps=4) for m in block_m\n ]\n return configs\n\n@triton.jit\ndef sum_kernel(\n inp,\n out,\n M,\n N,\n BLOCK_M: tl.constexpr,\n BLOCK_N: tl.constexpr,\n):\n pid = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)[:, None]\n inp = inp + pid * N\n out = out + pid\n row_mask = pid < M\n\n _sum = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)\n for off in range(0, N, BLOCK_N):\n cols = off + tl.arange(0, BLOCK_N)[None, :]\n col_mask = cols < N\n mask = row_mask and col_mask\n\n a = tl.load(inp + cols, mask, other=0.0).to(tl.float32)\n _sum += a\n sum = tl.sum(_sum, axis=1)[:, None]\n tl.store(out, sum, row_mask)\n\ndef sum(inp, *, dtype=None):\n M = inp.numel()\n if dtype is None:\n dtype = inp.dtype\n if dtype is torch.bool:\n inp = inp.to(torch.int64)\n dtype = torch.int64\n block_size = triton.next_power_of_2(math.ceil(math.sqrt(M)))\n mid_size = triton.cdiv(M, block_size)\n block_mid = triton.next_power_of_2(mid_size)\n\n mid = torch.empty((mid_size,), dtype=dtype, device=inp.device)\n out = torch.empty([], dtype=dtype, device=inp.device)\n\n with torch.cuda.device(inp.device):\n sum_kernel_1[(mid_size, 1, 1)](inp, mid, M, block_size)\n sum_kernel_2[(1, 1, 1)](mid, out, mid_size, block_mid)\n return out\n\ndef sum_dim(inp, dim=None, keepdim=False, *, dtype=None):\n if dtype is None:\n dtype = inp.dtype\n if dtype is torch.bool:\n dtype = torch.int64\n\n shape = list(inp.shape)\n dim = [d % inp.ndim for d in dim]\n inp = dim_compress(inp, dim)\n N = 1\n for i in dim:\n N *= shape[i]\n shape[i] = 1\n M = inp.numel() // N\n\n out = torch.empty(shape, dtype=dtype, device=inp.device)\n\n grid = lambda meta: (triton.cdiv(M, meta[\"BLOCK_M\"]),)\n with torch.cuda.device(inp.device):\n sum_kernel[grid](inp, out, M, N)\n if not keepdim:\n out = out.squeeze(dim=dim)\n return out\n", - "description_1": "Use triton language to define three kernels for summing up elements in an input tensor. The first kernel, sum_kernel_1, takes four arguments: inp (the input tensor), mid (an intermediate tensor), M (the size of the input), and BLOCK_SIZE (a constexpr). The second kernel, sum_kernel_2, takes four arguments: mid (the intermediate tensor), out (the output tensor), mid_size (the size of the intermediate tensor), and BLOCK_MID (a constexpr). The third kernel, sum_kernel, is configured for autotuning and takes six arguments: inp (input tensor), out (output tensor), M (number of rows), N (number of columns), BLOCK_M (a constexpr), and BLOCK_N (a constexpr). The function sum calls these kernels to compute the sum of the input tensor elements in two stages and returns the output tensor. The function sum_dim calls the third kernel to compute the sum of input tensor elements along specified dimensions.", - "description_2": "Use triton language to implement kernels for tensor summation with two-stage and dimension-specific summation capabilities.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\nfrom triton.language.math import tanh as _tanh\nfrom triton.language.math import pow\n\n@triton.jit\ndef tanh_forward(x):\n # x: input tensor\n return _tanh(x.to(tl.float32))\n\n@triton.jit\ndef tanh_backward(y, dy):\n # y: output tensor from forward pass\n # dy: gradient of the loss with respect to y\n return dy * (1.0 - pow(y.to(tl.float32), 2))\n\nclass Tanh(torch.autograd.Function):\n @staticmethod\n def forward(ctx, A):\n # ctx: context object to save tensors for backward\n # A: input tensor\n if A.requires_grad is True:\n out = tanh_forward(A.to(torch.float32))\n ctx.save_for_backward(out)\n return out.to(A.dtype)\n else:\n out = tanh_forward(A)\n return out\n\n @staticmethod\n def backward(ctx, out_grad):\n # out_grad: gradient of the loss with respect to the output\n (out,) = ctx.saved_tensors\n in_grad = tanh_backward(out, out_grad)\n return in_grad\n\ndef tanh(A):\n # A: input tensor\n return Tanh.apply(A)\n", - "description_1": "Use triton language to implement two kernels, `tanh_forward` with 1 parameter (x: input tensor), and `tanh_backward` with 2 parameters (y: output tensor from forward pass, dy: gradient of the loss with respect to y). The `tanh_forward` kernel computes the hyperbolic tangent of the input tensor, while `tanh_backward` computes the gradient of the hyperbolic tangent function for backpropagation. A custom autograd function `Tanh` is also implemented, with `forward` and `backward` methods to utilize these kernels, having parameters ctx (context object), A (input tensor), and out_grad (gradient of the loss with respect to the output).", - "description_2": "Use triton language to create a custom activation function utilizing two kernels: one to compute the hyperbolic tangent of an input tensor, and another to compute the corresponding gradient for backpropagation. Integrate these kernels with PyTorch's autograd functionality.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nfrom triton import language as tl\n\n# Kernel implementation\n@triton.jit\ndef _jit_function(\n in0_ptr: tl.tensor, # of tl.pointer_type\n out0_ptr: tl.tensor, # of tl.pointer_type\n in0_stride0: int, in0_stride1: int, # strides for in0\n out0_stride0: int, out0_stride1: int, # strides for out0\n s0: int, s1: int, # task_space\n in_s0: int, in_s1: int, # task_space2 used when input and output tensor has different shape\n num_tasks: int,\n tiles_per_cta,\n tile_size: tl.constexpr,\n one_tile_per_cta: tl.constexpr,\n):\n # task id & masking\n pid = tl.program_id(0)\n num_ctas = tl.num_programs(0)\n init_tid = pid * tile_size + tl.arange(0, tile_size)\n\n if one_tile_per_cta: # monolithic kernel style\n tid = init_tid\n mask = tid < num_tasks\n\n # multi index reconstruction\n i1 = tid % s1\n tid //= s1\n i0 = tid\n\n # loads\n in0 = tl.load(in0_ptr + (i0 % in_s0) * in0_stride0 + (i1 % in_s1) * in0_stride1, mask=mask)\n\n # compute\n out0 = in0\n\n # stores\n tl.store(out0_ptr + i0 * out0_stride0 + i1 * out0_stride1, out0, mask=mask)\n\n else: # grid-stride-loop style kernel\n for j in range(0, tiles_per_cta):\n tid = init_tid + j * tile_size * num_ctas\n mask = tid < num_tasks\n\n # multi index reconstruction\n i1 = tid % s1\n tid //= s1\n i0 = tid\n\n # loads\n in0 = tl.load(in0_ptr + (i0 % in_s0) * in0_stride0 + (i1 % in_s1) * in0_stride1, mask=mask)\n\n # compute\n out0 = in0\n\n # stores\n tl.store(out0_ptr + i0 * out0_stride0 + i1 * out0_stride1, out0, mask=mask)\n\n# Tile function invocation\ndef _wrapper(in0, dims):\n in0_rank = in0.dim()\n dims_rank = len(dims)\n in0_shape = list(in0.shape)\n dims_shape = list(dims)\n\n if dims_rank < in0_rank:\n diff = in0_rank - dims_rank\n ones = [1 for _ in range(diff)]\n dims_shape = ones + dims_shape\n elif dims_rank > in0_rank:\n diff = dims_rank - in0_rank\n ones = [1 for _ in range(diff)]\n in0_shape = ones + in0_shape\n\n is_empty = False\n out_shape = []\n for i in range(len(in0_shape)):\n assert dims_shape[i] >= 0, 'the number of repetitions per dimension out of range (expected to >= 0) but got {}'.format(dims_shape[i])\n if dims_shape[i] == 0:\n is_empty = True\n out_shape.append(in0_shape[i] * dims_shape[i])\n\n out0 = torch.empty(out_shape, device=in0.device, dtype=in0.dtype)\n in0 = in0.reshape(in0_shape)\n\n if not is_empty:\n out0 = _wrapper_out(in0, out0)\n\n return out0\n\n\ndef _wrapper_out(in0, out0):\n shape = out0.shape\n num_tasks = shape[0] * shape[1] # volume(shape)\n tile_size = min(512, triton.next_power_of_2(num_tasks))\n num_warps = 4\n num_ctas = min(65535, (num_tasks + tile_size - 1) // tile_size) # cdiv\n tiles_per_cta = (num_tasks + tile_size * num_ctas - 1) // (tile_size * num_ctas)\n grid = (num_ctas, 1, 1)\n\n # strides of each tensor argument w.r.t the task space\n in0_strides = in0.stride()\n out0_strides = out0.stride()\n\n with torch.cuda.device(in0.device.index):\n _jit_function[grid](\n in0, out0,\n in0_strides[0], in0_strides[1], # stride for in0\n out0_strides[0], out0_strides[1], # stride for out0\n shape[0], shape[1], # task indexing space\n in0.shape[0], in0.shape[1], # task indexing space used when input and ouput tensor has different shape\n num_tasks, # num tasks\n tiles_per_cta=tiles_per_cta, # tiles_per_cta\n tile_size=tile_size,\n one_tile_per_cta=(tiles_per_cta == 1),\n num_warps=num_warps,\n )\n return out0\n", - "description_1": "Use triton language to implement a kernel for tensor tiling. The kernel takes two tensor pointers (input and output), strides for input and output tensors, task space dimensions, tile size, number of tasks, and other parameters. It computes a tiled version of the input tensor into the output tensor using either a monolithic kernel style or grid-stride-loop style based on the tiling configuration.", - "description_2": "Use triton language to perform tensor tiling by computing a tiled version of the input tensor into the output tensor, configurable with task space and tile size parameters.", - "difficulty": 2 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\nimport triton.language.core as core\n\n\n@triton.jit\ndef _get_finfo_val(\n dtype,\n return_max,\n):\n if dtype is tl.float32:\n if return_max:\n return torch.finfo(torch.float32).max\n else:\n return torch.finfo(torch.float32).min\n elif dtype is tl.float16:\n if return_max:\n return torch.finfo(torch.float16).max\n else:\n return torch.finfo(torch.float16).min\n elif dtype is tl.bfloat16:\n if return_max:\n return torch.finfo(torch.bfloat16).max\n else:\n return torch.finfo(torch.bfloat16).min\n\n\n@triton.jit\ndef topk_stage1_kernel(\n y_ptr,\n index_ptr,\n x_ptr,\n k,\n N: tl.constexpr,\n CHUNK_SIZE: tl.constexpr,\n DESCENDING: tl.constexpr,\n):\n cur_batch = tl.program_id(0)\n cur_chunk_idx = tl.program_id(1)\n chunk_num = tl.num_programs(1)\n\n y_ptr += cur_batch * chunk_num * k + cur_chunk_idx * k\n index_ptr += cur_batch * chunk_num * k + cur_chunk_idx * k\n\n chunk_offset = cur_chunk_idx * CHUNK_SIZE\n x_ptr += cur_batch * N + chunk_offset\n\n cols = tl.arange(0, CHUNK_SIZE)\n mask = (chunk_offset + cols) < N\n\n mask_val = _get_finfo_val(x_ptr.dtype.element_ty, return_max=not DESCENDING)\n x_val = tl.load(x_ptr + cols, mask=mask, other=mask_val).to(tl.float32)\n for k_idx in range(k):\n if DESCENDING:\n chunk_select_val = tl.max(x_val)\n chunk_select_idx = tl.argmax(x_val, axis=0)\n else:\n chunk_select_val = tl.min(x_val)\n chunk_select_idx = tl.argmin(x_val, axis=0)\n\n tl.store(y_ptr + k_idx, chunk_select_val)\n tl.store(index_ptr + k_idx, chunk_select_idx + chunk_offset)\n\n if DESCENDING:\n x_val = tl.where(\n cols == chunk_select_idx,\n _get_finfo_val(tl.float32, return_max=False),\n x_val,\n )\n else:\n x_val = tl.where(\n cols == chunk_select_idx,\n _get_finfo_val(tl.float32, return_max=True),\n x_val,\n )\n\n\n@triton.jit\ndef argsort(x, ids, dim: tl.constexpr, descending: core.constexpr):\n _dim: core.constexpr = dim\n n_dims: core.constexpr = (x.shape[_dim]).bit_length() - 1\n for i in range(1, n_dims + 1):\n x, ids = _bitonic_merge(x, ids, i, 2 if i < n_dims else descending, n_dims)\n return x, ids\n\n\n@triton.jit\ndef topk_stage2_kernel(\n y_ptr,\n index_ptr,\n chunk_x,\n chunk_index,\n sort_dim: tl.constexpr,\n k: tl.constexpr,\n N: tl.constexpr,\n BLOCK_SIZE: tl.constexpr,\n DESCENDING: tl.constexpr,\n):\n cur_batch = tl.program_id(0)\n chunk_x += cur_batch * N\n chunk_index += cur_batch * N\n y_ptr += cur_batch * k\n index_ptr += cur_batch * k\n\n cols = tl.arange(0, BLOCK_SIZE)\n mask = cols < N\n\n mask_val = _get_finfo_val(chunk_x.dtype.element_ty, return_max=not DESCENDING)\n mask_index_val = torch.iinfo(torch.int32).min if DESCENDING else torch.iinfo(torch.int32).max\n\n chunk_x_val = tl.load(chunk_x + cols, mask=mask, other=mask_val).to(tl.float32)\n chunk_index_val = tl.load(chunk_index + cols, mask=mask, other=mask_index_val).to(\n tl.int32\n )\n\n sorted_chunk_x, sorted_chunk_index = argsort(\n chunk_x_val, chunk_index_val, 0, descending=DESCENDING\n )\n tl.store(y_ptr + cols, sorted_chunk_x, mask=cols < k)\n tl.store(index_ptr + cols, sorted_chunk_index, mask=cols < k)\n\n\ndef topk(x, k, dim=-1, largest=True, sorted=True):\n if dim < 0:\n dim = dim + x.ndim\n\n descending = True\n if not largest:\n descending = False\n\n topk_elem_cnt = x.shape[dim]\n batch_size = (x.numel() // topk_elem_cnt)\n\n if topk_elem_cnt < 1024:\n chunk_size = 256\n else:\n chunk_size = 1024\n\n if chunk_size < k:\n chunk_size = triton.next_power_of_2(k)\n\n chunk_num = triton.cdiv(topk_elem_cnt, chunk_size)\n\n stage1_out = torch.empty(batch_size * chunk_num * k, device=x.device, dtype=x.dtype)\n stage1_out_idx = torch.empty(\n batch_size * chunk_num * k, device=x.device, dtype=torch.int64\n )\n\n out_shape = x.shape[:-1] + (k,)\n stage2_out = torch.empty(out_shape, device=x.device, dtype=x.dtype)\n stage2_out_idx = torch.empty(out_shape, device=x.device, dtype=torch.int64)\n\n with torch.cuda.device(x.device):\n topk_stage1_kernel[\n batch_size,\n chunk_num,\n ](\n stage1_out,\n stage1_out_idx,\n x,\n k,\n topk_elem_cnt,\n chunk_size,\n descending,\n )\n stage2_elem_cnt = chunk_num * k\n BLOCK_SIZE = triton.next_power_of_2(stage2_elem_cnt)\n\n with torch.cuda.device(x.device):\n topk_stage2_kernel[batch_size,](\n stage2_out,\n stage2_out_idx,\n stage1_out,\n stage1_out_idx,\n dim,\n k,\n stage2_elem_cnt,\n BLOCK_SIZE,\n descending,\n )\n\n return (stage2_out, stage2_out_idx)\n", - "description_1": "Use triton language to implement a top-k operation on a batch of vectors. The top-k operation consists of two stages: The first stage processes the input in chunks and retrieves local top-k values and their indices using topk_stage1_kernel, which uses parameters y_ptr (output pointer for values), index_ptr (output pointer for indices), x_ptr (input pointer), k (number of top elements), N (input length), CHUNK_SIZE, and DESCENDING (order type). The second stage uses topk_stage2_kernel to combine local top-k results into final top-k results, using argsort for sorting, with parameters y_ptr (output pointer for values), index_ptr (output pointer for indices), chunk_x (pointer to chunk values), chunk_index (pointer to chunk indices), sort_dim, k, N (combined chunk size), BLOCK_SIZE, and DESCENDING.", - "description_2": "Use triton language to implement two-stage top-k sorting on input data by processing chunks and then merging results for final top-k values and indices.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\nfrom ..utils.shape_utils import can_use_int32_index\n\ndef cfggen():\n warps = [1, 2, 4, 8, 16, 32]\n configs = [\n triton.Config({\"M_BLOCK_SIZE\": 1, \"N_BLOCK_SIZE\": 2048}, num_warps=w)\n for w in warps\n ]\n return configs\n\ndef cfggen_batch():\n warps = [1, 2, 4, 8, 16, 32]\n configs = [\n triton.Config({\"BATCH_BLOCK_SIZE\": 1, \"MN_BLOCK_SIZE\": 512}, num_warps=w)\n for w in warps\n ]\n return configs\n\n@triton.jit(do_not_specialize=[\"diagonal\"])\ndef triu_kernel(\n X,\n Y,\n M,\n N,\n diagonal,\n M_BLOCK_SIZE: tl.constexpr,\n N_BLOCK_SIZE: tl.constexpr,\n INT64_INDEX: tl.constexpr = False,\n):\n pid = tl.program_id(0)\n if INT64_INDEX:\n pid = pid.to(tl.int64)\n row = pid * M_BLOCK_SIZE + tl.arange(0, M_BLOCK_SIZE)[:, None]\n m_mask = row < M\n X += row * N\n Y += row * N\n\n for n_offset in range(0, N, N_BLOCK_SIZE):\n cols = n_offset + tl.arange(0, N_BLOCK_SIZE)[None, :]\n n_mask = cols < N\n mask = m_mask and n_mask\n\n x = tl.load(X + cols, mask, other=0.0)\n y = tl.where(row + diagonal <= cols, x, 0.0)\n tl.store(Y + cols, y, mask=mask)\n\n@triton.jit(do_not_specialize=[\"diagonal\"])\ndef triu_batch_kernel(\n X,\n Y,\n batch,\n MN,\n N,\n diagonal,\n BATCH_BLOCK_SIZE: tl.constexpr,\n MN_BLOCK_SIZE: tl.constexpr,\n INT64_INDEX: tl.constexpr = False,\n):\n batch_id = tl.program_id(0)\n mn_id = tl.program_id(1)\n if INT64_INDEX:\n batch_id = batch_id.to(tl.int64)\n mn_id = mn_id.to(tl.int64)\n row = batch_id * BATCH_BLOCK_SIZE + tl.arange(0, BATCH_BLOCK_SIZE)[:, None]\n batch_mask = row < batch\n X += row * MN\n Y += row * MN\n\n cols = mn_id * MN_BLOCK_SIZE + tl.arange(0, MN_BLOCK_SIZE)[None, :]\n mn_mask = cols < MN\n mask = batch_mask and mn_mask\n x = tl.load(X + cols, mask, other=0.0)\n m = cols // N\n n = cols % N\n y = tl.where(m + diagonal <= n, x, 0.0)\n tl.store(Y + cols, y, mask=mask)\n\ndef triu(A, diagonal=0):\n A = A.contiguous()\n out = torch.empty_like(A)\n assert len(A.shape) > 1, \"Input tensor must have at least 2 dimensions\"\n use_int64_index = not can_use_int32_index(A)\n M, N = A.shape[-2:]\n with torch.cuda.device(A.device):\n if len(A.shape) == 2:\n grid = lambda meta: (triton.cdiv(M, meta[\"M_BLOCK_SIZE\"]),)\n triu_kernel[grid](A, out, M, N, diagonal, INT64_INDEX=use_int64_index)\n else:\n batch = int(torch.numel(A) / M / N)\n B = A.view(batch, -1)\n grid = lambda meta: (\n triton.cdiv(batch, meta[\"BATCH_BLOCK_SIZE\"]),\n triton.cdiv(M * N, meta[\"MN_BLOCK_SIZE\"]),\n )\n triu_batch_kernel[grid](\n B, out, batch, M * N, N, diagonal, INT64_INDEX=use_int64_index\n )\n out = out.view(A.shape)\n return out\n", - "description_1": "Use triton language to implement two kernels: 'triu_kernel' and 'triu_batch_kernel'. 'triu_kernel' takes 8 parameters: X (input tensor), Y (output tensor), M (number of rows), N (number of columns), diagonal (offset for diagonal), M_BLOCK_SIZE (block size for rows), N_BLOCK_SIZE (block size for columns), and INT64_INDEX (flag for index type). It computes the upper triangular part of a matrix. 'triu_batch_kernel' takes 9 parameters: X (input tensor), Y (output tensor), batch (number of batches), MN (total elements in a batch), N (number of columns), diagonal (offset for diagonal), BATCH_BLOCK_SIZE (block size for batches), MN_BLOCK_SIZE (block size for elements), and INT64_INDEX (flag for index type). It computes the upper triangular part of a batch of matrices. The 'triu' function calls these kernels based on the input tensor's dimensions.", - "description_2": "Use triton language to create kernels for computing the upper triangular part of matrices and batches of matrices, with support for configurable block sizes and index types.", - "difficulty": 3 - }, - { - "code": "import logging\nimport torch\nimport triton\nimport triton.language as tl\nfrom flag_gems.utils.random_utils import philox_cuda_seed_offset, uint_to_uniform_float\nfrom flag_gems.utils.shape_utils import volume\n\ndef heur_block(args):\n if args[\"N\"] <= 512:\n return 512\n else:\n return 1024\n\ndef heur_num_warps(args):\n if args[\"N\"] <= 512:\n return 4\n elif args[\"N\"] <= 1024:\n return 8\n else:\n return 16\n\n@triton.heuristics(\n {\n \"BLOCK\": heur_block,\n \"num_warps\": heur_num_warps,\n }\n)\n@triton.jit(do_not_specialize=[\"philox_seed\", \"philox_offset\"])\ndef uniform_kernel(\n out_ptr,\n N,\n philox_seed,\n philox_offset,\n from_,\n to,\n BLOCK: tl.constexpr,\n):\n # Convert philox_seed and philox_offset to 64-bit integers\n philox_seed = philox_seed.to(tl.int64)\n philox_offset = philox_offset.to(tl.int64)\n # Calculate counter values for Philox RNG\n c0 = (philox_offset & 0xFFFFFFFF).to(tl.uint32)\n c1 = ((philox_offset >> 32) & 0xFFFFFFFF).to(tl.uint32)\n i4 = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)\n c0 += i4\n _O = c0 * 0\n # Generate random numbers using Philox RNG\n r0, r1, r2, r3 = tl.philox(philox_seed, c0, c1, _O, _O)\n # Convert random numbers to uniform floats in the range [from_, to]\n r0 = uint_to_uniform_float(r0) * (to - from_) + from_\n r1 = uint_to_uniform_float(r1) * (to - from_) + from_\n r2 = uint_to_uniform_float(r2) * (to - from_) + from_\n r3 = uint_to_uniform_float(r3) * (to - from_) + from_\n # Calculate offsets for storing results\n off_0 = tl.program_id(0) * BLOCK * 4 + tl.arange(0, BLOCK)\n off_1 = off_0 + BLOCK\n off_2 = off_1 + BLOCK\n off_3 = off_2 + BLOCK\n # Store the results in the output pointer\n tl.store(out_ptr + off_0, r0, mask=off_0 < N, eviction_policy=\"evict_first\")\n tl.store(out_ptr + off_1, r1, mask=off_1 < N, eviction_policy=\"evict_first\")\n tl.store(out_ptr + off_2, r2, mask=off_2 < N, eviction_policy=\"evict_first\")\n tl.store(out_ptr + off_3, r3, mask=off_3 < N, eviction_policy=\"evict_first\")\n\nUNROLL = 4\n\ndef uniform_(self, from_=0.0, to=1.0, *, generator=None):\n logging.debug(\"GEMS UNIFORM\")\n N = volume(self.shape)\n grid_fn = lambda meta: (triton.cdiv(N, meta[\"BLOCK\"] * UNROLL),)\n increment = triton.cdiv(N, UNROLL)\n philox_seed, philox_offset = philox_cuda_seed_offset(increment)\n with torch.cuda.device(self.device):\n uniform_kernel[grid_fn](self, N, philox_seed, philox_offset, from_, to)\n return self\n", - "description_1": "Use triton language to implement a kernel that generates uniform random numbers using the Philox algorithm. The kernel takes 6 parameters: out_ptr (output pointer), N (number of elements), philox_seed (seed for RNG), philox_offset (offset for RNG), from_ (lower bound of uniform distribution), and to (upper bound of uniform distribution). The kernel uses heuristics to determine the block size and number of warps, and stores the generated random numbers in the output pointer.", - "description_2": "Use triton language to create a function that fills a tensor with uniform random numbers in a specified range using the Philox RNG algorithm.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef simple_unique_flat_kernel(\n sorted_data_ptr: tl.tensor,\n sorted_indices_ptr: tl.tensor, # in\n data_out_ptr: tl.tensor,\n inverse_indices_ptr: tl.tensor,\n idx_ptr: tl.tensor,\n unique_size_ptr: tl.tensor, # out\n return_inverse: tl.constexpr,\n return_counts: tl.constexpr,\n num_tasks: int,\n tile_size: tl.constexpr,\n):\n i0 = tl.arange(0, tile_size)\n mask = i0 < num_tasks\n\n # load\n a = tl.load(sorted_data_ptr + i0, mask=mask)\n i0_prev = tl.where(i0 > 0, i0 - 1, 0)\n b = tl.load(sorted_data_ptr + i0_prev, mask=mask)\n\n # ne & cumsum\n ne_result = tl.where(i0 > 0, a != b, 0)\n cumsum = tl.cumsum(ne_result)\n\n # unique_size\n unique_size_mask = i0 == tile_size - 1\n tl.store(unique_size_ptr + tl.zeros_like(i0), cumsum, mask=unique_size_mask)\n\n # data_out: scatter_(to=cumsum, sorted_data)\n tl.store(data_out_ptr + cumsum, a, mask=mask)\n\n # inverse_indices: scatter_(to=sorted_indices, cumsum)\n if return_inverse:\n sorted_indices = tl.load(sorted_indices_ptr + i0, mask=mask)\n tl.store(inverse_indices_ptr + sorted_indices, cumsum, mask=mask)\n\n # idx\n if return_counts:\n idx_mask = ((i0 == 0) | ne_result.to(tl.int1)) & mask\n tl.store(idx_ptr + cumsum, i0, mask=idx_mask)\n\n\n@triton.jit\ndef output_counts_flat_impl(\n global_pid,\n idx_ptr: tl.tensor,\n origin_num_tasks: int, # in\n counts_ptr: tl.tensor, # out\n num_tasks: int,\n tile_size: tl.constexpr,\n):\n r = tl.arange(0, tile_size)\n\n # load idx\n i0 = global_pid * tile_size + r\n mask = i0 < num_tasks\n idx = tl.load(idx_ptr + i0, mask=mask)\n\n # load idx_next\n i0_next = i0 + 1\n next_mask = i0_next < num_tasks\n idx_next = tl.load(idx_ptr + i0_next, mask=next_mask)\n\n # diff\n counts = tl.where(i0_next < num_tasks, idx_next - idx, origin_num_tasks - idx)\n\n # store counts\n tl.store(counts_ptr + i0, counts, mask=mask)\n\n\n@triton.jit\ndef local_ne_flat_impl(\n global_pid,\n sorted_data_ptr: tl.tensor, # in\n ne_result_ptr: tl.tensor,\n tile_sum_ptr: tl.tensor, # out\n global_ctas_num: int,\n num_tasks: int,\n tile_size: tl.constexpr,\n):\n r = tl.arange(0, tile_size)\n i0 = global_pid * tile_size + r\n mask = i0 < num_tasks\n i0_prev = tl.where(i0 > 0, i0 - 1, 0)\n\n # load\n a = tl.load(sorted_data_ptr + i0, mask=mask)\n b = tl.load(sorted_data_ptr + i0_prev, mask=mask)\n\n # compute\n ne_result = tl.where(i0 > 0, a != b, 0)\n\n # store ne_result\n tl.store(ne_result_ptr + i0, ne_result, mask=mask)\n\n # store tile_sum\n tile_sum = tl.sum(ne_result)\n tile_sum_mask = global_pid < global_ctas_num\n tl.store(tile_sum_ptr + global_pid, tile_sum, mask=tile_sum_mask)\n\n\ndef simple_unique_flat(\n sorted_data: torch.Tensor,\n sorted_indices: torch.Tensor,\n return_inverse: bool,\n return_counts: bool,\n):\n num_tasks = sorted_data.numel()\n grid = (1, 1, 1)\n\n # allocate tensor\n data_out = torch.empty_like(sorted_data)\n if return_inverse:\n inverse_indices = torch.empty_like(sorted_data, dtype=torch.int64)\n else:\n inverse_indices = None\n if return_counts:\n idx = torch.empty_like(sorted_data, dtype=torch.int64)\n else:\n idx = None\n unique_size = torch.empty([1], dtype=torch.int64, device=sorted_data.device)\n\n # launch kernel\n with torch.cuda.device(sorted_data.device.index):\n simple_unique_flat_kernel[grid](\n sorted_data,\n sorted_indices, # in\n data_out,\n inverse_indices,\n idx,\n unique_size, # out\n return_inverse,\n return_counts,\n num_tasks,\n tile_size=triton.next_power_of_2(num_tasks),\n num_warps=8,\n )\n out_size = unique_size.item() + 1\n counts = None\n if return_counts:\n idx = idx[:out_size]\n counts = torch.empty_like(idx)\n with torch.cuda.device(sorted_data.device.index):\n output_counts_flat_kernel[grid](\n idx,\n num_tasks, # in\n counts, # out\n num_tasks=out_size,\n tiles_per_cta=1,\n tile_size=triton.next_power_of_2(out_size),\n num_warps=8,\n )\n return data_out[:out_size], inverse_indices, counts\n", - "description_1": "Use triton language to implement unique and counting operations on sorted tensor data. The kernel functions manage tensor pointers and constants, load data, compute unique elements and cumulative sums, scatter results to output tensors, and store unique sizes. These functions handle tasks according to mask conditions based on input tensor sizes, return conditions for inverse indices and counts, and launch kernels on specified grids.", - "description_2": "Use triton language to compute unique elements and counts from sorted data. Functions handle tensor loading, processing, and storing to manage tasks based on conditions.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\ndef cfggen():\n block_m = [1, 2, 4, 8]\n block_n = [1024, 2048]\n warps = [4, 8, 16]\n configs = [\n triton.Config({\"BLOCK_M\": m, \"BLOCK_N\": n}, num_warps=w)\n for m in block_m\n for n in block_n\n for w in warps\n ]\n return configs\n\n@triton.jit\ndef welford_func(mean_x, count_x, M_x, mean_y, count_y, M_y):\n count = count_x + count_y\n _count = tl.maximum(count, 1)\n mc_x = mean_x * count_x\n mc_y = mean_y * count_y\n mean = (mc_x + mc_y) / _count\n M = M_x + mc_x * mean_x + M_y + mc_y * mean_y - count * mean * mean\n return mean, count, M\n\n@triton.jit(do_not_specialize=[\"correction\"])\ndef var_mean_welford_kernel(\n X,\n Var,\n Mean,\n M,\n N,\n correction,\n BLOCK_M: tl.constexpr,\n BLOCK_N: tl.constexpr,\n):\n pid = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)[:, None]\n X = X + pid * N\n Var = Var + pid\n Mean = Mean + pid\n row_mask = pid < M\n\n _mean = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)\n _acc = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)\n _count = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)\n for off in range(0, N, BLOCK_N):\n cols = off + tl.arange(0, BLOCK_N)[None, :]\n col_mask = cols < N\n mask = row_mask and col_mask\n\n x = tl.load(X + cols, mask, other=0.0).to(tl.float32)\n\n count = _count + mask\n cnt = tl.maximum(count, 1)\n cur_mean = (_mean * _count + x) / cnt\n _acc += (x - cur_mean) * (x - _mean) * mask\n _mean = cur_mean\n _count = count\n\n mean, _, acc = tl.reduce((_mean, _count, _acc), axis=1, combine_fn=welford_func)\n var = acc / (N - correction)\n mean = mean[:, None]\n var = var[:, None]\n tl.store(Mean, mean, row_mask)\n tl.store(Var, var, row_mask)\n\n@triton.jit\ndef var_mean_kernel_1(\n X,\n Acc,\n Average,\n Count,\n N,\n BLOCK_N: tl.constexpr,\n):\n pid = tl.program_id(0)\n offset = pid * BLOCK_N + tl.arange(0, BLOCK_N)\n\n X = X + offset\n Acc = Acc + pid\n Average = Average + pid\n Count = Count + pid\n mask = offset < N\n\n x = tl.load(X, mask, other=0.0).to(tl.float32)\n\n count = tl.sum(mask.to(tl.float32))\n average = tl.sum(x) / count\n acc = tl.sum(x * x) - count * average * average\n\n tl.store(Average, average)\n tl.store(Acc, acc)\n tl.store(Count, count)\n\n@triton.jit(do_not_specialize=[\"correction\"])\ndef var_mean_kernel_2(\n Acc,\n Average,\n Count,\n Var,\n Mean,\n N,\n correction,\n BLOCK_NUM,\n BLOCK_N: tl.constexpr,\n):\n offset = tl.arange(0, BLOCK_N)\n mask = offset < BLOCK_NUM\n Acc = Acc + offset\n Average = Average + offset\n Count = Count + offset\n acc = tl.load(Acc, mask, other=0.0).to(tl.float32)\n average = tl.load(Average, mask, other=0.0).to(tl.float32)\n count = tl.load(Count, mask, other=0.0).to(tl.float32)\n\n mean, _, nvar = tl.reduce((average, count, acc), axis=0, combine_fn=welford_func)\n\n var = nvar / (N - correction)\n tl.store(Mean, mean)\n tl.store(Var, var)\n\ndef var_mean(x, dim=None, *, correction=None, keepdim=False):\n if correction is None:\n correction = 1.0\n\n if dim is None or len(dim) == x.ndim:\n dim = list(range(x.ndim))\n shape = [1] * x.ndim\n N = x.numel()\n var = torch.empty(shape, dtype=x.dtype, device=x.device)\n mean = torch.empty(shape, dtype=x.dtype, device=x.device)\n BLOCK_N = 1024\n BLOCK_NUM = triton.cdiv(N, BLOCK_N)\n acc = torch.empty([BLOCK_NUM], dtype=x.dtype, device=x.device)\n average = torch.empty([BLOCK_NUM], dtype=x.dtype, device=x.device)\n count = torch.empty([BLOCK_NUM], dtype=x.dtype, device=x.device)\n\n with torch.cuda.device(x.device):\n var_mean_kernel_1[(BLOCK_NUM,)](x, acc, average, count, N, BLOCK_N=BLOCK_N)\n var_mean_kernel_2[(1,)](\n acc, average, count, var, mean, N, correction, BLOCK_NUM\n )\n else:\n shape = list(x.shape)\n dim = [d % x.ndim for d in dim]\n x = dim_compress(x, dim)\n N = 1\n for i in dim:\n N *= shape[i]\n shape[i] = 1\n M = x.numel() // N\n var = torch.empty(shape, dtype=x.dtype, device=x.device)\n mean = torch.empty(shape, dtype=x.dtype, device=x.device)\n\n grid = lambda META: (triton.cdiv(M, META[\"BLOCK_M\"]),)\n with torch.cuda.device(x.device):\n var_mean_welford_kernel[grid](x, var, mean, M, N, correction)\n\n if not keepdim:\n var = var.squeeze(dim=dim)\n mean = mean.squeeze(dim=dim)\n return var, mean\n", - "description_1": "Use triton language to implement three kernels: welford_func, var_mean_welford_kernel, and var_mean_kernel_1. The welford_func kernel takes six parameters: mean_x, count_x, M_x, mean_y, count_y, and M_y, and computes the combined mean, count, and M. The var_mean_welford_kernel takes nine parameters: X, Var, Mean, M, N, correction, BLOCK_M, and BLOCK_N, and computes the variance and mean of input X using a Welford's method. The var_mean_kernel_1 takes six parameters: X, Acc, Average, Count, N, and BLOCK_N, and computes the sum of squares, average, and count of input X.", - "description_2": "Use triton language to implement a kernel var_mean_kernel_2 that takes eight parameters: Acc, Average, Count, Var, Mean, N, correction, BLOCK_NUM, and BLOCK_N, and computes the variance and mean using a reduction operation. Implement a Python function var_mean that calls these kernels to compute variance and mean of a tensor x along specified dimensions.", - "difficulty": 3 - }, - { - "code": "import math\nimport torch\nimport triton\nimport triton.language as tl\n\ndef cfggen():\n block_m = [1, 2, 4, 8]\n configs = [\n triton.Config({\"BLOCK_M\": m, \"BLOCK_N\": 1024}, num_warps=4) for m in block_m\n ]\n return configs\n\n# L2 norm kernel\n@triton.autotune(configs=cfggen(), key=[\"M\", \"N\"])\n@triton.jit\ndef l2_norm_kernel(X, Out, M, N, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):\n pid = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)[:, None]\n X = X + pid * N\n Out = Out + pid\n row_mask = pid < M\n\n _sum = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)\n for off in range(0, N, BLOCK_N):\n cols = off + tl.arange(0, BLOCK_N)[None, :]\n col_mask = cols < N\n mask = row_mask & col_mask\n\n a = tl.load(X + cols, mask, other=0.0).to(tl.float32)\n _sum += a * a\n sum = tl.sum(_sum, axis=1)\n\n out = tl.sqrt(sum)[:, None]\n tl.store(Out, out, row_mask)\n\n# Max norm kernel\n@triton.autotune(configs=cfggen(), key=[\"M\", \"N\"])\n@triton.jit\ndef max_norm_kernel(X, Out, M, N, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):\n pid = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)[:, None]\n X = X + pid * N\n Out = Out + pid\n row_mask = pid < M\n\n _max = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)\n for off in range(0, N, BLOCK_N):\n cols = off + tl.arange(0, BLOCK_N)[None, :]\n col_mask = cols < N\n mask = row_mask & col_mask\n\n a = tl.load(X + cols, mask, other=0.0).to(tl.float32)\n _max = tl.maximum(tl.abs(a), _max)\n\n max = tl.max(_max, axis=1)\n out = max[:, None]\n tl.store(Out, out, row_mask)\n\n# Min norm kernel\n@triton.autotune(configs=cfggen(), key=[\"M\", \"N\"])\n@triton.jit\ndef min_norm_kernel(X, Out, M, N, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):\n pid = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)[:, None]\n X = X + pid * N\n Out = Out + pid\n row_mask = pid < M\n\n _min = tl.full([BLOCK_M, BLOCK_N], value=float(\"inf\"), dtype=tl.float32)\n for off in range(0, N, BLOCK_N):\n cols = off + tl.arange(0, BLOCK_N)[None, :]\n col_mask = cols < N\n mask = row_mask & col_mask\n\n a = tl.load(X + cols, mask, other=float(\"inf\")).to(tl.float32)\n _min = tl.minimum(tl.abs(a), _min)\n\n min = tl.min(_min, axis=1)\n out = min[:, None]\n tl.store(Out, out, row_mask)\n\n# L0 norm kernel\n@triton.autotune(configs=cfggen(), key=[\"M\", \"N\"])\n@triton.jit\ndef l0_norm_kernel(X, Out, M, N, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):\n pid = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)[:, None]\n X = X + pid * N\n Out = Out + pid\n row_mask = pid < M\n\n _sum = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)\n for off in range(0, N, BLOCK_N):\n cols = off + tl.arange(0, BLOCK_N)[None, :]\n col_mask = cols < N\n mask = row_mask & col_mask\n\n a = tl.load(X + cols, mask, other=0).to(tl.float32)\n _sum += tl.where(a != 0, 1, 0)\n sum = tl.sum(_sum, axis=1)\n out = sum[:, None]\n tl.store(Out, out, row_mask)\n\n# V norm kernel\n@triton.autotune(configs=cfggen(), key=[\"M\", \"N\"])\n@triton.jit(do_not_specialize=[\"ord\"])\ndef v_norm_kernel(X, Out, M, N, ord, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):\n pid = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)[:, None]\n X = X + pid * N\n Out = Out + pid\n row_mask = pid < M\n\n _sum = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)\n for off in range(0, N, BLOCK_N):\n cols = off + tl.arange(0, BLOCK_N)[None, :]\n col_mask = cols < N\n mask = row_mask & col_mask\n\n a = tl.load(X + cols, mask, other=0.0).to(tl.float32)\n _sum += pow(tl.abs(a), ord)\n sum = tl.sum(_sum, axis=1)\n out = pow(sum, 1 / ord)[:, None]\n tl.store(Out, out, row_mask)\n\ndef vector_norm(x, ord=2, dim=None, keepdim=False, dtype=None):\n if dtype is not None:\n dtype = torch.dtype(dtype)\n else:\n dtype = x.dtype\n if dtype not in [torch.float16, torch.float32, torch.bfloat16]:\n raise NotImplementedError(f\"vector_norm not implemented for {dtype}\")\n\n with torch.cuda.device(x.device):\n if dim is None or len(dim) == x.ndim:\n dim = list(range(x.ndim))\n shape = [1] * x.ndim\n x = dim_compress(x, dim)\n M = x.numel()\n BLOCK_SIZE = triton.next_power_of_2(math.ceil(math.sqrt(M)))\n MID_SIZE = triton.cdiv(M, BLOCK_SIZE)\n BLOCK_MID = triton.next_power_of_2(MID_SIZE)\n\n mid = torch.empty([MID_SIZE], dtype=dtype, device=x.device)\n out = torch.empty(shape, dtype=dtype, device=x.device)\n if ord == 2:\n l2_norm_kernel_1[(MID_SIZE,)](x, mid, M, BLOCK_SIZE)\n l2_norm_kernel_2[(1,)](mid, out, MID_SIZE, BLOCK_MID)\n elif ord == float(\"inf\"):\n max_norm_kernel_1[(MID_SIZE,)](x, mid, M, BLOCK_SIZE)\n max_norm_kernel_2[(1,)](mid, out, MID_SIZE, BLOCK_MID)\n elif ord == -float(\"inf\"):\n min_norm_kernel_1[(MID_SIZE,)](x, mid, M, BLOCK_SIZE)\n min_norm_kernel_2[(1,)](mid, out, MID_SIZE, BLOCK_MID)\n elif ord == 0:\n l0_norm_kernel_1[(MID_SIZE,)](x, mid, M, BLOCK_SIZE)\n l0_norm_kernel_2[(1,)](mid, out, MID_SIZE, BLOCK_MID)\n else:\n l1_norm_kernel_1[(MID_SIZE,)](x, mid, ord, M, BLOCK_SIZE)\n l1_norm_kernel_2[(1,)](mid, out, ord, MID_SIZE, BLOCK_MID)\n else:\n shape = list(x.shape)\n dim = [d % x.ndim for d in dim]\n x = dim_compress(x, dim)\n N = 1\n for i in dim:\n N *= shape[i]\n shape[i] = 1\n M = x.numel() // N\n out = torch.empty(shape, dtype=dtype, device=x.device)\n grid = lambda META: (triton.cdiv(M, META[\"BLOCK_M\"]),)\n if ord == 2:\n l2_norm_kernel[grid](x, out, M, N)\n elif ord == float(\"inf\"):\n max_norm_kernel[grid](x, out, M, N)\n elif ord == -float(\"inf\"):\n min_norm_kernel[grid](x, out, M, N)\n elif ord == 0:\n l0_norm_kernel[grid](x, out, M, N)\n else:\n v_norm_kernel[grid](x, out, M, N, ord)\n if not keepdim:\n out = out.squeeze(dim=dim)\n return out\n", - "description_1": "Use triton language to implement a series of norm kernels including L2, L0, L1, max, and min norms. Each kernel has specific parameters: input X, output Out, dimensions M and N, and block sizes BLOCK_M and BLOCK_N. The kernels use triton's parallel computing capabilities to efficiently calculate norms on matrices.", - "description_2": "Use triton language to create and use parallelizable kernels for computing various matrix norms (L2, L0, L1, max, and min) with specified dimensions and block sizes.", - "difficulty": 3 - }, - { - "code": "import logging\nimport torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef vstack_kernel(\n itensor_ptr0,\n itensor_ptr1,\n itensor_ptr2,\n itensor_ptr3,\n output_ptr,\n local_row0,\n local_row1,\n local_row2,\n local_row3,\n exc_row_offset0,\n exc_row_offset1,\n exc_row_offset2,\n exc_row_offset3,\n total_row_offset,\n row_stride,\n max_tile_elems,\n BLOCK_SIZE: tl.constexpr,\n):\n pid_x = tl.program_id(axis=0)\n tensor_idx = tl.program_id(axis=1)\n col_idx = tl.arange(0, BLOCK_SIZE)\n\n intensor_ptr = tl.where(tensor_idx == 0, itensor_ptr0, itensor_ptr1)\n intensor_ptr = tl.where(tensor_idx == 2, itensor_ptr2, intensor_ptr)\n intensor_ptr = tl.where(tensor_idx == 3, itensor_ptr3, intensor_ptr)\n base_exc_row_idx = tl.where(tensor_idx == 0, exc_row_offset0, exc_row_offset1)\n base_exc_row_idx = tl.where(tensor_idx == 2, exc_row_offset2, base_exc_row_idx)\n base_exc_row_idx = tl.where(tensor_idx == 3, exc_row_offset3, base_exc_row_idx)\n local_row = tl.where(tensor_idx == 0, local_row0, local_row1)\n local_row = tl.where(tensor_idx == 2, local_row2, local_row)\n local_row = tl.where(tensor_idx == 3, local_row3, local_row)\n\n end_idx = local_row * row_stride.to(tl.int64)\n idx = (pid_x * BLOCK_SIZE + col_idx).to(tl.int64)\n offset_mask = idx < end_idx\n in_offset = intensor_ptr + idx\n row_stride_offset = (total_row_offset + base_exc_row_idx) * row_stride.to(tl.int64)\n out_offset = output_ptr + row_stride_offset + idx\n out = tl.load(in_offset, mask=offset_mask)\n tl.store(out_offset, out, mask=offset_mask)\n\n\ndef vstack(tensors: list):\n logging.debug(\"GEMS VSTACK\")\n\n tensors = torch.atleast_2d(tensors)\n num_tensors = len(tensors)\n assert num_tensors > 0\n\n device = tensors[0].device\n dtype = tensors[0].dtype\n for tensor in tensors:\n assert (\n tensor.device == device\n and tensor.dtype == dtype\n and tensors[0].shape[1:] == tensor.shape[1:]\n )\n\n c_tensors = [t.contiguous() for t in tensors]\n total_rows = sum(tensor.shape[0] for tensor in c_tensors)\n output_shape = list(c_tensors[0].shape)\n output_shape[0] = total_rows\n output = torch.empty(output_shape, device=device, dtype=dtype)\n row_stride = c_tensors[0].stride(0)\n\n outer_iters = triton.cdiv(num_tensors, 4)\n total_row_offset = 0\n for i in range(outer_iters):\n max_rows = 1\n itensors = []\n exclusive_row = []\n local_row = []\n array_row_offset = 0\n scheduled_num_tensors = 0\n for j in range(4):\n tensor_idx = i * 4 + j\n if tensor_idx < num_tensors:\n scheduled_num_tensors += 1\n itensors.append(c_tensors[tensor_idx])\n local_row.append(c_tensors[tensor_idx].shape[0])\n exclusive_row.append(array_row_offset)\n array_row_offset += c_tensors[tensor_idx].shape[0]\n max_rows = max(max_rows, c_tensors[tensor_idx].shape[0])\n else:\n empty_tensor = torch.empty(\n 0, dtype=c_tensors[0].dtype, device=c_tensors[0].device\n )\n itensors.append(empty_tensor)\n local_row.append(local_row[-1])\n exclusive_row.append(exclusive_row[-1])\n max_tile_elems = max_rows * row_stride\n grid = lambda META: (\n triton.cdiv(max_tile_elems, META[\"BLOCK_SIZE\"]),\n scheduled_num_tensors,\n )\n with torch.cuda.device(c_tensors[0].device):\n vstack_kernel[grid](\n itensors[0],\n itensors[1],\n itensors[2],\n itensors[3],\n output,\n local_row[0],\n local_row[1],\n local_row[2],\n local_row[3],\n exclusive_row[0],\n exclusive_row[1],\n exclusive_row[2],\n exclusive_row[3],\n total_row_offset,\n row_stride,\n max_tile_elems,\n )\n total_row_offset += array_row_offset\n return output\n", - "description_1": "Use triton language to define a kernel 'vstack_kernel' that vertically stacks multiple tensors into an output tensor. It takes 17 parameters: four input tensor pointers, one output pointer, four local row values, four exclusive row offsets, a total row offset, a row stride, a maximum tile elements integer, and a block size constant. The function uses Triton's parallel programming features to load elements from input tensors and store them into the output tensor, considering provided row offsets and strides. Additionally, use a wrapper function 'vstack' to prepare and launch this kernel with a given list of tensors, making sure they all have the same device and dtype.", - "description_2": "Use triton language to create a kernel that efficiently stacks multiple 2D tensors vertically into a single output tensor, handling tensors on the same device with the same dtype, and considering row offsets and strides.", - "difficulty": 3 - }, - { - "code": "import math\nimport torch\nimport triton\nimport triton.language as tl\n\ndef cfggen_first():\n block_m = [1, 2, 4, 8, 32]\n block_n = [512, 1024, 2048]\n warps = [4, 8, 16]\n configs = [\n triton.Config({\"BLOCK_ROW_SIZE\": m, \"BLOCK_COL_SIZE\": n}, num_warps=w)\n for m in block_m\n for n in block_n\n for w in warps\n ]\n return configs\n\ndef cfggen_last():\n block_m = [512, 1024, 2048]\n block_n = [1, 2, 4, 8, 32]\n warps = [4, 8, 16]\n configs = [\n triton.Config({\"BLOCK_ROW_SIZE\": m, \"BLOCK_COL_SIZE\": n}, num_warps=w)\n for m in block_m\n for n in block_n\n for w in warps\n ]\n return configs\n\n@triton.jit(do_not_specialize=[\"eps\"])\ndef weight_norm_kernel_last(\n output,\n norm,\n v,\n g,\n M,\n N,\n eps,\n BLOCK_ROW_SIZE: tl.constexpr,\n BLOCK_COL_SIZE: tl.constexpr,\n):\n tx = tl.arange(0, BLOCK_COL_SIZE)[:, None]\n bx = tl.program_id(axis=0) * BLOCK_COL_SIZE\n col_offset = bx + tx\n col_mask = col_offset < N\n\n ty = tl.arange(0, BLOCK_ROW_SIZE)[None, :]\n v_block = tl.zeros([BLOCK_COL_SIZE, BLOCK_ROW_SIZE], dtype=tl.float32)\n for base in range(0, M, BLOCK_ROW_SIZE):\n row_offset = base + ty\n mask = row_offset < M and col_mask\n v_value = tl.load(v + row_offset * N + col_offset, mask=mask).to(tl.float32)\n v_block += v_value * v_value\n\n normalized = tl.sqrt(tl.sum(v_block, axis=1) + eps)\n tl.store(norm + col_offset, normalized[:, None], mask=col_mask)\n g_value = tl.load(g + col_offset, mask=col_mask).to(tl.float32)\n\n for base in range(0, M, BLOCK_ROW_SIZE):\n row_offset = base + ty\n mask = row_offset < M and col_mask\n v_value = tl.load(v + row_offset * N + col_offset, mask=mask).to(tl.float32)\n v_vec = v_value / normalized[:, None]\n out = v_vec * g_value\n tl.store(output + row_offset * N + col_offset, out, mask=mask)\n\n@triton.jit(do_not_specialize=[\"eps\"])\ndef weight_norm_kernel_first(\n output,\n norm,\n v,\n g,\n M,\n N,\n eps,\n BLOCK_ROW_SIZE: tl.constexpr,\n BLOCK_COL_SIZE: tl.constexpr,\n):\n ty = tl.arange(0, BLOCK_ROW_SIZE)[:, None]\n by = tl.program_id(axis=0) * BLOCK_ROW_SIZE\n row_offset = by + ty\n row_mask = row_offset < M\n\n tx = tl.arange(0, BLOCK_COL_SIZE)[None, :]\n v_block = tl.zeros([BLOCK_ROW_SIZE, BLOCK_COL_SIZE], dtype=tl.float32)\n for base in range(0, N, BLOCK_COL_SIZE):\n col_offset = base + tx\n mask = col_offset < N and row_mask\n v_value = tl.load(v + row_offset * N + col_offset, mask=mask).to(tl.float32)\n v_block += v_value * v_value\n\n normalized = tl.sqrt(tl.sum(v_block, axis=1) + eps)\n tl.store(norm + row_offset, normalized[:, None], mask=row_mask)\n g_value = tl.load(g + row_offset, mask=row_mask).to(tl.float32)\n\n for base in range(0, N, BLOCK_COL_SIZE):\n col_offset = base + tx\n mask = col_offset < N and row_mask\n v_value = tl.load(v + row_offset * N + col_offset, mask=mask).to(tl.float32)\n v_vec = v_value / normalized[:, None]\n out = v_vec * g_value\n tl.store(output + row_offset * N + col_offset, out, mask=mask)\n\n@triton.jit(do_not_specialize=[\"eps\"])\ndef weight_norm_bwd_kernel_last(\n v_grad,\n g_grad,\n w,\n v,\n g,\n norm,\n M,\n N,\n eps,\n BLOCK_ROW_SIZE: tl.constexpr,\n BLOCK_COL_SIZE: tl.constexpr,\n):\n tx = tl.arange(0, BLOCK_COL_SIZE)[:, None]\n bx = tl.program_id(axis=0) * BLOCK_COL_SIZE\n col_offset = tx + bx\n col_mask = col_offset < N\n\n g_value = tl.load(g + col_offset, mask=col_mask).to(tl.float32)\n norm_value = tl.load(norm + col_offset, mask=col_mask).to(tl.float32)\n\n ty = tl.arange(0, BLOCK_ROW_SIZE)[None, :]\n\n vw_block = tl.zeros([BLOCK_COL_SIZE, BLOCK_ROW_SIZE], dtype=tl.float32)\n for base in range(0, M, BLOCK_ROW_SIZE):\n row_offset = base + ty\n mask = row_offset < M and col_mask\n v_value = tl.load(v + row_offset * N + col_offset, mask=mask).to(tl.float32)\n w_value = tl.load(w + row_offset * N + col_offset, mask=mask).to(tl.float32)\n vw_block += v_value * w_value\n vw_sum = tl.sum(vw_block, 1)[:, None]\n\n for base in range(0, M, BLOCK_ROW_SIZE):\n row_offset = base + ty\n mask = row_offset < M and col_mask\n v_value = tl.load(v + row_offset * N + col_offset, mask=mask).to(tl.float32)\n w_value = tl.load(w + row_offset * N + col_offset, mask=mask).to(tl.float32)\n v_grad_value = g_value * (\n w_value / (norm_value + eps)\n - v_value / (norm_value * norm_value * norm_value + eps) * vw_sum\n )\n tl.store(v_grad + row_offset * N + col_offset, v_grad_value, mask=mask)\n\n g_grad_value = vw_sum / (norm_value + eps)\n tl.store(g_grad + col_offset, g_grad_value, mask=col_mask)\n\n@triton.jit(do_not_specialize=[\"eps\"])\ndef weight_norm_bwd_kernel_first(\n v_grad,\n g_grad,\n w,\n v,\n g,\n norm,\n M,\n N,\n eps,\n BLOCK_ROW_SIZE: tl.constexpr,\n BLOCK_COL_SIZE: tl.constexpr,\n):\n ty = tl.arange(0, BLOCK_ROW_SIZE)[:, None]\n by = tl.program_id(axis=0) * BLOCK_ROW_SIZE\n row_offset = by + ty\n row_mask = row_offset < M\n\n g_value = tl.load(g + row_offset, mask=row_mask).to(tl.float32)\n norm_value = tl.load(norm + row_offset, mask=row_mask).to(tl.float32)\n\n tx = tl.arange(0, BLOCK_COL_SIZE)[None, :]\n\n v_block = tl.zeros([BLOCK_ROW_SIZE, BLOCK_COL_SIZE], dtype=tl.float32)\n for base in range(0, N, BLOCK_COL_SIZE):\n col_offset = base + tx\n mask = col_offset < N and row_mask\n v_value = tl.load(v + row_offset * N + col_offset, mask=mask).to(tl.float32)\n w_value = tl.load(w + row_offset * N + col_offset, mask=mask).to(tl.float32)\n v_block += v_value * w_value\n vw_sum = tl.sum(v_block, 1)[:, None]\n\n for base in range(0, N, BLOCK_COL_SIZE):\n col_offset = base + tx\n mask = col_offset < N and row_mask\n v_value = tl.load(v + row_offset * N + col_offset, mask=mask).to(tl.float32)\n w_value = tl.load(w + row_offset * N + col_offset, mask=mask).to(tl.float32)\n v_grad_value = g_value * (\n w_value / (norm_value + eps)\n - v_value / (norm_value * norm_value * norm_value + eps) * vw_sum\n )\n tl.store(v_grad + row_offset * N + col_offset, v_grad_value, mask=mask)\n\n g_grad_value = vw_sum / (norm_value + eps)\n tl.store(g_grad + row_offset, g_grad_value, mask=row_mask)\n\nclass WeightNorm(torch.autograd.Function):\n @staticmethod\n def forward(ctx, v, g, dim):\n v = v.contiguous()\n g = g.contiguous()\n output = torch.empty_like(v)\n norm = torch.empty_like(g, dtype=torch.float32)\n if dim == 0:\n M = v.shape[0]\n N = math.prod(v.shape[1:])\n grid = lambda META: (triton.cdiv(M, META[\"BLOCK_ROW_SIZE\"]),)\n with torch.cuda.device(v.device):\n weight_norm_kernel_first[grid](\n output, norm, v, g, M, N, eps=torch.finfo(torch.float32).tiny\n )\n elif dim == len(v.shape) - 1:\n M = math.prod(v.shape[:-1])\n N = v.shape[dim]\n grid = lambda META: (triton.cdiv(N, META[\"BLOCK_COL_SIZE\"]),)\n with torch.cuda.device(v.device):\n weight_norm_kernel_last[grid](\n output, norm, v, g, M, N, eps=torch.finfo(torch.float32).tiny\n )\n ctx.save_for_backward(v, g, norm)\n ctx.DIM = dim\n return output, norm\n\n @staticmethod\n def backward(ctx, w_grad, norm_grad):\n v, g, norm = ctx.saved_tensors\n dim = ctx.DIM\n v_grad = torch.empty_like(v)\n g_grad = torch.empty_like(g)\n\n if dim == 0:\n M = v.shape[0]\n N = math.prod(v.shape[1:])\n grid = lambda META: (triton.cdiv(M, META[\"BLOCK_ROW_SIZE\"]),)\n with torch.cuda.device(v.device):\n weight_norm_bwd_kernel_first[grid](\n v_grad,\n g_grad,\n w_grad,\n v,\n g,\n norm,\n M,\n N,\n eps=torch.finfo(torch.float32).tiny,\n )\n elif dim == len(v.shape) - 1:\n M = math.prod(v.shape[:dim])\n N = v.shape[dim]\n grid = lambda META: (triton.cdiv(N, META[\"BLOCK_COL_SIZE\"]),)\n with torch.cuda.device(v.device):\n weight_norm_bwd_kernel_last[grid](\n v_grad,\n g_grad,\n w_grad,\n v,\n g,\n norm,\n M,\n N,\n eps=torch.finfo(torch.float32).tiny,\n )\n return v_grad, g_grad, None\n\ndef weight_norm(v, g, dim=0):\n return WeightNorm.apply(v, g, dim)\n", - "description_1": "Use triton language to implement weight normalization kernels and backward kernels with parameters: output, norm, v, g, M, N, eps, BLOCK_ROW_SIZE, BLOCK_COL_SIZE for both last and first dimension. The forward function computes the normalized output and norm using the kernels based on the dimension, while the backward function computes gradients for v and g.", - "description_2": "Use triton language to perform weight normalization with gradient computation support.", - "difficulty": 3 - }, - { - "code": "import triton\nimport triton.language as tl\n\n@triton.jit\ndef where_self_func(condition, self, other):\n # This kernel uses triton's where function to select elements from 'self' or 'other' based on 'condition'.\n return tl.where(condition, self, other)\n\ndef where_self(condition, self, other):\n # Calls the where_self_func kernel with the given condition, self, and other tensors.\n return where_self_func(condition, self, other)\n\n@triton.jit\ndef where_scalar_self_func(condition, self, other):\n # This kernel uses triton's where function to select elements from 'self' or 'other' based on 'condition'.\n return tl.where(condition, self, other)\n\ndef where_scalar_self(condition, self, other):\n # Calls the where_scalar_self_func kernel with the given condition, self, and other tensors.\n return where_scalar_self_func(condition, self, other)\n\n@triton.jit\ndef where_scalar_other_func(condition, self, other):\n # This kernel uses triton's where function to select elements from 'self' or 'other' based on 'condition'.\n return tl.where(condition, self, other)\n\ndef where_scalar_other(condition, self, other):\n # Calls the where_scalar_other_func kernel with the given condition, self, and other tensors.\n return where_scalar_other_func(condition, self, other)\n", - "description_1": "Use triton language to define three kernels: where_self_func, where_scalar_self_func, and where_scalar_other_func. Each kernel takes three parameters: 'condition', 'self', and 'other'. The kernels use triton's tl.where function to select elements from 'self' or 'other' based on the 'condition'. Corresponding wrapper functions where_self, where_scalar_self, and where_scalar_other call these kernels with the same parameters.", - "description_2": "Use triton language to create kernels that perform element selection based on a condition. Implement wrapper functions to call these kernels.", - "difficulty": 1 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\nfrom flag_gems.utils.shape_utils import volume\n\n# Triton kernel to set elements to zero\n@triton.jit\ndef zeros_kernel(\n output_ptr,\n n_elements,\n BLOCK_SIZE: tl.constexpr,\n):\n pid = tl.program_id(axis=0) # We use a 1D launch grid so axis is 0.\n block_start = pid * BLOCK_SIZE\n offsets = block_start + tl.arange(0, BLOCK_SIZE)\n mask = offsets < n_elements\n tl.store(output_ptr + offsets, 0.0, mask=mask)\n\n# Function to call the Triton kernel\ndef zeros(size, *, dtype=None, layout=None, device=None, pin_memory=None):\n if dtype is None:\n dtype = torch.get_default_dtype()\n if device is None:\n device = torch.device(\"cuda\")\n\n out = torch.empty(size, device=device, dtype=dtype)\n N = volume(size)\n grid_fn = lambda meta: (triton.cdiv(N, meta[\"BLOCK_SIZE\"]),)\n with torch.cuda.device(device):\n zeros_kernel[grid_fn](out, N, BLOCK_SIZE=1024)\n return out\n", - "description_1": "Use triton language to implement a kernel function 'zeros_kernel' that sets elements of an output tensor to zero. The kernel takes three parameters: 'output_ptr' (pointer to the output tensor), 'n_elements' (total number of elements to process), and 'BLOCK_SIZE' (a compile-time constant defining the block size for processing). The function 'zeros' is a wrapper that prepares the output tensor and launches the kernel with appropriate grid size.", - "description_2": "Use triton language to create a kernel that initializes a tensor to zero and a wrapper function to manage tensor creation and kernel execution.", - "difficulty": 1 - }, - { - "code": "import torch\nimport triton\n\n@triton.jit\ndef example_kernel(x_ptr, y_ptr, n_elements, **meta):\n idx = tl.program_id(axis=0)\n for i in range(idx, n_elements, tl.num_programs(axis=0)):\n x_val = tl.load(x_ptr + i)\n y_val = x_val * 2 # Example operation\n tl.store(y_ptr + i, y_val)\n\ndef call_example_kernel(x, y):\n n_elements = x.numel()\n grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)\n example_kernel[grid](x, y, n_elements, BLOCK_SIZE=1024)\n", - "description_1": "Use triton language to implement a kernel function that doubles each element of an input tensor. The kernel processes elements in a parallel manner using a given block size and grid configuration, and stores the results in an output tensor.", - "description_2": "Use triton language to define a kernel that multiplies elements of an input tensor by 2, and a corresponding function to launch this kernel.", - "difficulty": 1 - }, - { - "code": "import triton\nimport triton.language as tl\n\n# Kernel to convert a random uint into a random float uniformly sampled in [0, 1).\n@triton.jit\ndef uint_to_uniform_float(x):\n \"\"\"\n Numerically stable function to convert a random uint into a random float uniformly sampled in [0, 1).\n \"\"\"\n if tl.constexpr(x.dtype == tl.uint32) or tl.constexpr(x.dtype == tl.int32):\n x = x.to(tl.int32, bitcast=True)\n scale = 4.6566127342e-10\n else:\n tl.static_assert(\n tl.constexpr(x.dtype == tl.uint64) or tl.constexpr(x.dtype == tl.int64)\n )\n x = x.to(tl.int64, bitcast=True)\n scale = 1.0842020432385337e-19\n x = tl.where(x < 0, -x - 1, x)\n return x * scale\n\n# Kernel to generate uniform random numbers using Philox RNG.\n@triton.jit\ndef uniform(seed, philox_offset, offset):\n seed = seed.to(tl.int64)\n philox_offset = philox_offset.to(tl.int64)\n c0 = (philox_offset & 0xFFFFFFFF).to(tl.uint32)\n c1 = ((philox_offset >> 32) & 0xFFFFFFFF).to(tl.uint32)\n i4 = offset\n c0 += i4\n _O = c0 * 0\n r0, r1, r2, r3 = tl.philox(seed, c0, c1, _O, _O)\n r0 = uint_to_uniform_float(r0)\n r1 = uint_to_uniform_float(r1)\n r2 = uint_to_uniform_float(r2)\n r3 = uint_to_uniform_float(r3)\n return r0, r1, r2, r3\n", - "description_1": "Use triton language to implement two kernels: one for converting a random uint to a float uniformly sampled in [0, 1) with 1 parameter (x: the input tensor of type uint32/int32 or uint64/int64), and another for generating uniform random numbers using Philox RNG with 3 parameters (seed: the seed for RNG, philox_offset: the offset for Philox RNG, offset: additional offset for RNG).", - "description_2": "Use triton language to create a kernel for converting uint to uniform float and another kernel for generating uniform random numbers using Philox RNG.", - "difficulty": 1 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\ndef cfggen():\n block_m = [1, 2, 4]\n block_n = [256, 1024, 2048, 4096]\n configs = [\n triton.Config({\"BLOCK_M\": m, \"BLOCK_N\": n}, num_warps=4)\n for m in block_m\n for n in block_n\n ]\n return configs\n\n@triton.autotune(configs=cfggen(), key=[\"M\", \"N\"])\n@triton.jit\ndef add_on_kernel(\n idx,\n add_on,\n cur_shape,\n cur_strides,\n M,\n N,\n BLOCK_M: tl.constexpr,\n BLOCK_N: tl.constexpr,\n):\n pid_x = tl.program_id(axis=0)\n pid_y = tl.program_id(axis=1)\n rows_offset = pid_x * BLOCK_M + tl.arange(0, BLOCK_M)[:, None]\n rows_mask = rows_offset < M\n\n cols_offset = pid_y + tl.arange(0, BLOCK_N)[None, :]\n cols_mask = cols_offset < N\n block_mask = rows_mask and cols_mask\n\n offsets = rows_offset * N + cols_offset\n cur_idx = tl.load(idx + offsets, mask=block_mask, other=1)\n mod = cur_idx % cur_shape\n res = mod * cur_strides\n tl.store(add_on + offsets, res, mask=block_mask)\n\n\ndef offset_calculator(inp, idx, strides, dim, isInp):\n ndim = inp.ndim\n shape = list(inp.shape)\n offsets = torch.zeros_like(inp, dtype=torch.int32, device=inp.device)\n idx_dim = torch.zeros_like(inp, dtype=torch.int32, device=inp.device)\n for d in range(0, ndim):\n add_on = torch.zeros_like(inp, dtype=torch.int32, device=inp.device)\n N = idx.size(idx.ndim - 1)\n M = idx.numel() // N\n grid = lambda meta: (\n triton.cdiv(M, meta[\"BLOCK_M\"]),\n triton.cdiv(N, meta[\"BLOCK_N\"]),\n )\n add_on_kernel[grid](idx, add_on, shape[d], strides[d], M, N)\n\n offsets = torch.add(offsets, add_on)\n if d == dim:\n idx_dim = add_on\n idx = idx // shape[d]\n return offsets if not isInp else (offsets - idx_dim)\n", - "description_1": "Use triton language to implement a kernel 'add_on_kernel' that computes offsets based on input indices and strides. The kernel takes seven arguments: idx, add_on, cur_shape, cur_strides, M, N, BLOCK_M, and BLOCK_N. It calculates offsets in a 2D grid using triton's program_id and performs element-wise operations, storing results using tl.store. The calling function 'offset_calculator' computes total offsets for an input tensor based on its shape and dimension, employing the kernel for offset computation.", - "description_2": "Use triton language to create an offset calculation kernel that processes indices and strides within a 2D grid layout.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nfrom triton import language as tl\n\n@triton.autotune(\n configs=[\n triton.Config({\"TILE_N\": 32}),\n triton.Config({\"TILE_N\": 64}),\n triton.Config({\"TILE_N\": 128}),\n triton.Config({\"TILE_N\": 256}),\n triton.Config({\"TILE_N\": 512}),\n triton.Config({\"TILE_N\": 1024}),\n ],\n key=[\"N\"],\n)\n@triton.heuristics(\n values={\n \"TILE_M\": lambda args: 1024 // args[\"TILE_N\"],\n \"ONE_TILE_PER_CTA\": lambda args: args[\"TILE_N\"] >= args[\"N\"],\n },\n)\n@triton.jit\ndef softmax_kernel_inner(\n output_ptr,\n input_ptr,\n M,\n N,\n TILE_M: tl.constexpr,\n TILE_N: tl.constexpr,\n ONE_TILE_PER_CTA: tl.constexpr,\n DUMMY=42,\n):\n _ = DUMMY\n pid_m = tl.program_id(0)\n m_offsets = pid_m * TILE_M + tl.arange(0, TILE_M)\n if ONE_TILE_PER_CTA:\n n_offsets = tl.arange(0, TILE_N)\n offset = m_offsets[:, None] * N + n_offsets\n input_ptrs = input_ptr + offset\n mask = (m_offsets[:, None] < M) & (n_offsets < N)\n inp = tl.load(input_ptrs, mask=mask, other=-float(\"inf\"))\n m = tl.max(inp, 1)\n e = tl.exp(inp - m[:, None])\n z = tl.sum(e, 1)\n out = e / z[:, None]\n output_ptrs = output_ptr + offset\n tl.store(output_ptrs, out, mask=mask)\n else:\n m = tl.full([TILE_M], value=float(\"-inf\"), dtype=tl.float32)\n z = tl.full([TILE_M], value=0.0, dtype=tl.float32)\n\n n_offsets = tl.arange(0, TILE_N)\n offset = m_offsets[:, None] * N + n_offsets\n for _ in range(0, N, TILE_N):\n mask = (m_offsets[:, None] < M) & (n_offsets < N)\n input_ptrs = input_ptr + offset\n inp = tl.load(input_ptrs, mask=mask, other=-float(\"inf\"))\n m_new = tl.maximum(m, tl.max(inp, 1))\n alpha = m - m_new\n z = z * tl.exp(alpha) + tl.sum(tl.exp(inp - m_new[:, None]), axis=1)\n m = m_new\n n_offsets += TILE_N\n offset += TILE_N\n\n n_offsets = tl.arange(0, TILE_N)\n offset = m_offsets[:, None] * N + n_offsets\n for _ in range(0, N, TILE_N):\n mask = (m_offsets[:, None] < M) & (n_offsets < N)\n input_ptrs = input_ptr + offset\n inp = tl.load(input_ptrs, mask=mask, other=-float(\"inf\"))\n o = tl.exp(inp - m[:, None]) / z[:, None]\n output_ptrs = output_ptr + offset\n tl.store(output_ptrs, o, mask=mask)\n n_offsets += TILE_N\n offset += TILE_N\n\n\ndef softmax_inner_decorator_cascade(x, dim, dtype=None):\n assert dim >= -x.ndim and dim < x.ndim, \"Invalid dim\"\n dim = dim % x.ndim\n M = 1\n N = x.shape[dim]\n for i in range(dim):\n M *= x.shape[i] # pre_dim\n inp = x.contiguous()\n if dtype is None:\n dtype = x.dtype\n\n out = torch.empty_like(inp, dtype=dtype)\n\n with torch.cuda.device(out.device):\n grid = lambda meta: (triton.cdiv(M, meta[\"TILE_M\"]), 1, 1)\n softmax_kernel_inner[grid](\n out,\n inp,\n M,\n N,\n DUMMY=60,\n )\n return out\n\n\ndef softmax_inner_pass_kernel_arg_via_kw(x, dim, dtype=None):\n assert dim >= -x.ndim and dim < x.ndim, \"Invalid dim\"\n dim = dim % x.ndim\n M = 1\n N = x.shape[dim]\n for i in range(dim):\n M *= x.shape[i] # pre_dim\n inp = x.contiguous()\n if dtype is None:\n dtype = x.dtype\n out = torch.empty_like(inp, dtype=dtype)\n\n grid = lambda meta: (triton.cdiv(M, meta[\"TILE_M\"]), 1, 1)\n softmax_kernel_inner[grid](\n out,\n inp,\n M,\n N=N,\n DUMMY=60,\n )\n return out\n\n\ndef softmax_inner_kernel_arg_apply_default(x, dim, dtype=None):\n assert dim >= -x.ndim and dim < x.ndim, \"Invalid dim\"\n dim = dim % x.ndim\n M = 1\n N = x.shape[dim]\n for i in range(dim):\n M *= x.shape[i] # pre_dim\n inp = x.contiguous()\n if dtype is None:\n dtype = x.dtype\n out = torch.empty_like(inp, dtype=dtype)\n\n grid = lambda meta: (triton.cdiv(M, meta[\"TILE_M\"]), 1, 1)\n softmax_kernel_inner[grid](\n out,\n inp,\n M,\n N,\n )\n return out\n", - "description_1": "Use triton language to implement a softmax kernel function 'softmax_kernel_inner' with 8 parameters: output_ptr, input_ptr, M, N, TILE_M, TILE_N, ONE_TILE_PER_CTA, and DUMMY. The kernel computes the softmax of input data in a tiled manner, using triton's parallel programming capabilities. The function 'softmax_inner_decorator_cascade' calls this kernel with 3 parameters: x, dim, and dtype, setting up the grid and passing necessary arguments to the kernel. Similarly, 'softmax_inner_pass_kernel_arg_via_kw' and 'softmax_inner_kernel_arg_apply_default' are wrapper functions that call the kernel with different argument configurations.", - "description_2": "Use triton language to create a softmax kernel that processes input data in tiles, optimizing for parallel execution. Implement wrapper functions to call this kernel with various argument setups.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport pytest\nfrom flag_gems.utils.pointwise_dynamic import pointwise_dynamic\nfrom flag_gems.utils.tensor_wrapper import StridedBuffer\n\nUSE_BLOCK_POINTER = [True, False]\n\n@pytest.mark.parametrize(\"use_block_pointer\", USE_BLOCK_POINTER)\ndef test_dynamic_function_without_non_tensor_args(use_block_pointer):\n config = CodeGenConfig(\n max_tile_size=1024,\n max_grid_size=(65536, 65536, 65536),\n max_num_warps_per_cta=32,\n prefer_block_pointer=use_block_pointer,\n prefer_1d_tile=False,\n )\n\n @pointwise_dynamic(\n num_inputs=2, promotion_methods=[(0, 1, \"DEFAULT\")], config=config\n )\n @triton.jit\n def add(x, y):\n return x + y\n\n SIZE = 2\n for ndim in range(8):\n shape = [SIZE] * ndim\n x = torch.randn(shape, device=\"cuda\")\n y = torch.randn_like(x)\n out = add(x, y)\n torch.testing.assert_close(out, x + y)\n\n@pytest.mark.parametrize(\"use_block_pointer\", USE_BLOCK_POINTER)\ndef test_dynamic_function_with_non_tensor_args(use_block_pointer):\n config = CodeGenConfig(\n max_tile_size=1024,\n max_grid_size=(65536, 65536, 65536),\n max_num_warps_per_cta=32,\n prefer_block_pointer=use_block_pointer,\n prefer_1d_tile=False,\n )\n\n @pointwise_dynamic(\n num_inputs=3,\n is_tensor=[True, True, False],\n promotion_methods=[(0, 1, \"DEFAULT\")],\n config=config,\n )\n @triton.jit\n def axpy(x, y, alpha):\n return alpha * x + y\n\n SIZE = 2\n for ndim in range(8):\n shape = [SIZE] * ndim\n x = torch.randn(shape, device=\"cuda\")\n y = torch.randn_like(x)\n alpha = 2.0\n out = axpy(x, y, alpha)\n torch.testing.assert_close(out, alpha * x + y)\n\n@pytest.mark.parametrize(\"use_block_pointer\", USE_BLOCK_POINTER)\ndef test_dynamic_function_with_multiple_outputs(use_block_pointer):\n config = CodeGenConfig(\n max_tile_size=1024,\n max_grid_size=(65536, 65536, 65536),\n max_num_warps_per_cta=32,\n prefer_block_pointer=use_block_pointer,\n prefer_1d_tile=False,\n )\n\n @pointwise_dynamic(\n num_inputs=3,\n is_tensor=[True, True, False],\n num_outputs=2,\n promotion_methods=[(0, 1, \"DEFAULT\"), (0, 1, \"DEFAULT\")],\n config=config,\n )\n @triton.jit\n def multiple_out(x, y, alpha):\n return alpha * x + y, alpha * x - y\n\n SIZE = 2\n for ndim in range(8):\n shape = [SIZE] * ndim\n x = torch.randn(shape, device=\"cuda\")\n y = torch.randn_like(x)\n alpha = 2.0\n out0, out1 = multiple_out(x, y, alpha)\n torch.testing.assert_close(out0, alpha * x + y)\n torch.testing.assert_close(out1, alpha * x - y)\n\n@pytest.mark.parametrize(\"use_block_pointer\", USE_BLOCK_POINTER)\ndef test_dynamic_function_with_broadcasting(use_block_pointer):\n config = CodeGenConfig(\n max_tile_size=1024,\n max_grid_size=(65536, 65536, 65536),\n max_num_warps_per_cta=32,\n prefer_block_pointer=use_block_pointer,\n prefer_1d_tile=True, # [misaligned address]\n )\n\n @pointwise_dynamic(\n num_inputs=3,\n is_tensor=[True, True, False],\n promotion_methods=[(0, 1, \"DEFAULT\")],\n config=config,\n )\n @triton.jit\n def axpy(x, y, alpha):\n return alpha * x + y\n\n SIZE = 10\n x = torch.randn([SIZE, 1, SIZE], device=\"cuda\")\n y = torch.randn([1, SIZE, 1], device=\"cuda\")\n alpha = 2.0\n out = axpy(x, y, alpha)\n torch.testing.assert_close(out, alpha * x + y)\n\n@pytest.mark.parametrize(\"use_block_pointer\", USE_BLOCK_POINTER)\ndef test_dynamic_function_with_predefined_out(use_block_pointer):\n config = CodeGenConfig(\n max_tile_size=1024,\n max_grid_size=(65536, 65536, 65536),\n max_num_warps_per_cta=32,\n prefer_block_pointer=use_block_pointer,\n prefer_1d_tile=False,\n )\n\n @pointwise_dynamic(\n num_inputs=3,\n is_tensor=[True, True, False],\n promotion_methods=[(0, 1, \"DEFAULT\")],\n config=config,\n )\n @triton.jit\n def axpy(x, y, alpha):\n return alpha * x + y\n\n SIZE = 10\n x = torch.randn([SIZE, SIZE, SIZE], device=\"cuda\")\n y = torch.randn([], device=\"cuda\")\n alpha = 2.0\n o = torch.empty([SIZE, SIZE, SIZE], device=\"cuda\")\n out = axpy(x, y, alpha, out0=o)\n torch.testing.assert_close(out, alpha * x + y)\n\n@pytest.mark.parametrize(\"use_block_pointer\", USE_BLOCK_POINTER)\ndef test_dynamic_function_manual_instantiation_mixing_strided_buffer_and_tensor(\n use_block_pointer,\n):\n config = CodeGenConfig(\n max_tile_size=1024,\n max_grid_size=(65536, 65536, 65536),\n max_num_warps_per_cta=32,\n prefer_block_pointer=use_block_pointer,\n prefer_1d_tile=False,\n )\n\n @pointwise_dynamic(\n num_inputs=3,\n is_tensor=[True, True, False],\n promotion_methods=[(0, 1, \"DEFAULT\"), (0, 1, \"DEFAULT\")],\n config=config,\n )\n @triton.jit\n def axpyaxmy(x, y, alpha):\n return alpha * x + y, alpha * x - y\n\n SIZE = 10\n x = torch.randn([SIZE, SIZE, SIZE], device=\"cuda\")\n y = torch.randn([SIZE, SIZE, SIZE], device=\"cuda\")\n alpha = 2.0\n _out0 = torch.empty([SIZE, SIZE, SIZE], device=\"cuda\")\n _out1 = StridedBuffer(torch.empty([SIZE, SIZE, SIZE], device=\"cuda\"))\n out0, out1 = axpyaxmy.instantiate(3)(x, y, alpha, out0=_out0, out1=_out1)\n\n assert isinstance(out0, torch.Tensor)\n assert isinstance(out1, StridedBuffer)\n", - "description_1": "Use triton language to implement several kernel functions for dynamic element-wise operations. The `add` kernel takes two tensors `x` and `y` as input and returns their element-wise sum. The `axpy` kernel takes two tensors `x` and `y` and a scalar `alpha`, computing the operation `alpha * x + y`. The `multiple_out` kernel similarly takes two tensors and a scalar and returns two outputs, `alpha * x + y` and `alpha * x - y`. These functions use a `pointwise_dynamic` decorator for code generation, allowing different configurations for execution based on parameters like tile size and grid size.", - "description_2": "Use triton language to create pointwise dynamic operations including element-wise addition and scalar-tensor multiplication and addition with triton.jit.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nfrom triton import language as tl\nfrom flag_gems.utils import tensor_wrapper\n\n@triton.jit\ndef double(in_ptr, out_ptr, n, TILE_SIZE: tl.constexpr):\n # Triton kernel to double the values in the input pointer\n pid = tl.program_id(0)\n offsets = pid * TILE_SIZE + tl.arange(0, TILE_SIZE)\n mask = offsets < n\n x = tl.load(in_ptr + offsets, mask=mask)\n out = x * 2.0\n tl.store(out_ptr + offsets, out, mask=mask)\n\ndef test_typed_pointer():\n # Test the double kernel with complex tensor input\n real = torch.randn(10, 10, device=\"cuda\")\n imag = torch.randn(10, 10, device=\"cuda\")\n x = torch.complex(real, imag)\n\n out = torch.empty_like(x)\n TILE_SIZE = 128\n n = x.numel() * 2\n grid = (\n triton.cdiv(n, TILE_SIZE),\n 1,\n )\n in_ptr = tensor_wrapper.TypedPtr(x.data_ptr(), dtype=x.dtype.to_real())\n out_ptr = tensor_wrapper.TypedPtr(out.data_ptr(), dtype=out.dtype.to_real())\n double[grid](in_ptr, out_ptr, n, TILE_SIZE)\n torch.testing.assert_close(out, x * 2.0)\n\ndef test_typed_pointer_reinterpret_with_offset():\n # Test the double kernel with complex tensor and offset\n real = torch.randn(100, device=\"cuda\")\n imag = torch.randn(100, device=\"cuda\")\n x = torch.complex(real, imag)\n\n out = torch.empty_like(x)\n TILE_SIZE = 128\n k = 10\n n = (x.numel() - k) * 2\n grid = (\n triton.cdiv(n, TILE_SIZE),\n 1,\n )\n in_ptr = tensor_wrapper.TypedPtr.reinterpret_tensor(x, x.dtype.to_real(), 2 * k)\n out_ptr = tensor_wrapper.TypedPtr.reinterpret_tensor(\n out, out.dtype.to_real(), 2 * k\n )\n double[grid](in_ptr, out_ptr, n, TILE_SIZE)\n torch.testing.assert_close(out[k:], x[k:] * 2.0)\n\ndef test_typed_pointer_as_is():\n # Test the double kernel with regular tensor\n x = torch.randn(100, device=\"cuda\")\n out = torch.empty_like(x)\n TILE_SIZE = 128\n k = 10\n n = x.numel() - k\n grid = (\n triton.cdiv(n, TILE_SIZE),\n 1,\n )\n in_ptr = tensor_wrapper.TypedPtr.from_tensor(x, k)\n out_ptr = tensor_wrapper.TypedPtr.from_tensor(out, k)\n double[grid](in_ptr, out_ptr, n, TILE_SIZE)\n torch.testing.assert_close(out[k:], x[k:] * 2.0)\n", - "description_1": "Use triton language to implement a kernel function 'double' that takes four parameters: two pointers ('in_ptr', 'out_ptr'), an integer 'n', and a constant expression 'TILE_SIZE'. The kernel reads values from the 'in_ptr', doubles them, and stores the result in 'out_ptr'. This functionality is executed over a grid defined by a single-dimensional index space. The kernel is used in three different test functions that handle various input data structures, such as complex tensors and strided buffers.", - "description_2": "Use triton language to create a kernel that doubles elements in a tensor with grid-based parallel execution.", - "difficulty": 2 - }, - { - "code": "import math\nimport torch\nimport triton\nimport triton.language as tl\n\n@triton.autotune(\n configs=[\n triton.Config({}, num_stages=1, num_warps=8),\n triton.Config({}, num_stages=2, num_warps=8),\n triton.Config({}, num_stages=4, num_warps=8),\n triton.Config({}, num_stages=8, num_warps=8),\n triton.Config({}, num_stages=1),\n triton.Config({}, num_stages=2),\n triton.Config({}, num_stages=4),\n triton.Config({}, num_stages=8),\n triton.Config({}, num_warps=1),\n triton.Config({}, num_warps=2),\n triton.Config({}, num_warps=4),\n triton.Config({}, num_warps=8),\n ],\n key=[\"n_elements\"],\n)\n@triton.jit\ndef _dequantize_rowwise(\n x_ptr,\n state_x,\n output_ptr,\n inv_127,\n n_elements,\n BLOCK_SIZE: tl.constexpr,\n P2: tl.constexpr,\n):\n pid = tl.program_id(axis=0)\n block_start = pid * BLOCK_SIZE\n arange = tl.arange(0, P2)\n offsets = block_start + arange\n row_mask = arange < BLOCK_SIZE\n x = tl.load(x_ptr + offsets, mask=row_mask)\n max_val = tl.load(state_x + pid)\n output = max_val * x * inv_127\n tl.store(output_ptr + offsets, output, mask=row_mask)\n\ndef dequantize_rowwise(x: torch.Tensor, state_x: torch.Tensor):\n output = torch.empty(*x.shape, device=x.device, dtype=torch.float16)\n P2 = int(2 ** (math.ceil(math.log2(x.shape[1]))))\n assert x.is_cuda and output.is_cuda\n n_elements = output.numel()\n grid = lambda meta: (x.shape[0],)\n _dequantize_rowwise[grid](x, state_x, output, 1.0 / 127, n_elements, BLOCK_SIZE=x.shape[1], P2=P2)\n return output\n", - "description_1": "Use triton language to implement a row-wise dequantization kernel '_dequantize_rowwise', which takes 7 parameters. The parameters are: 1) x_ptr: pointer to the input tensor, 2) state_x: tensor containing max values for each row, 3) output_ptr: pointer to the output tensor, 4) inv_127: constant to normalize values, 5) n_elements: total number of elements, 6) BLOCK_SIZE: block size for triton kernel execution, 7) P2: power-of-two value for the block size. The function calculates normalized output by multiplying max value, input value, and inv_127. The function 'dequantize_rowwise' is a wrapper around this kernel, receiving a torch tensor, and executing the kernel to fill an output tensor with the dequantized values.", - "description_2": "Use triton language to implement a kernel for row-wise dequantization with given pointers and constants. Use torch to prepare data and invoke this kernel.", - "difficulty": 2 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\nfrom triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time\n\n@triton.autotune(\n configs=[\n triton.Config({\"BLOCK_M\": 128, \"BLOCK_N\": 256, \"BLOCK_K\": 32, \"SPLIT_K\": 1}, num_stages=3, num_warps=8),\n triton.Config({\"BLOCK_M\": 256, \"BLOCK_N\": 128, \"BLOCK_K\": 32, \"SPLIT_K\": 1}, num_stages=3, num_warps=8),\n triton.Config({\"BLOCK_M\": 256, \"BLOCK_N\": 64, \"BLOCK_K\": 32, \"SPLIT_K\": 1}, num_stages=4, num_warps=4),\n triton.Config({\"BLOCK_M\": 64, \"BLOCK_N\": 256, \"BLOCK_K\": 32, \"SPLIT_K\": 1}, num_stages=4, num_warps=4),\n triton.Config({\"BLOCK_M\": 128, \"BLOCK_N\": 128, \"BLOCK_K\": 32, \"SPLIT_K\": 1}, num_stages=4, num_warps=4),\n triton.Config({\"BLOCK_M\": 128, \"BLOCK_N\": 64, \"BLOCK_K\": 32, \"SPLIT_K\": 1}, num_stages=4, num_warps=4),\n triton.Config({\"BLOCK_M\": 64, \"BLOCK_N\": 128, \"BLOCK_K\": 32, \"SPLIT_K\": 1}, num_stages=4, num_warps=4),\n triton.Config({\"BLOCK_M\": 128, \"BLOCK_N\": 32, \"BLOCK_K\": 32, \"SPLIT_K\": 1}, num_stages=4, num_warps=4),\n triton.Config({\"BLOCK_M\": 64, \"BLOCK_N\": 32, \"BLOCK_K\": 32, \"SPLIT_K\": 1}, num_stages=5, num_warps=2),\n triton.Config({\"BLOCK_M\": 128, \"BLOCK_N\": 256, \"BLOCK_K\": 128, \"SPLIT_K\": 1}, num_stages=3, num_warps=8),\n triton.Config({\"BLOCK_M\": 256, \"BLOCK_N\": 128, \"BLOCK_K\": 128, \"SPLIT_K\": 1}, num_stages=3, num_warps=8),\n triton.Config({\"BLOCK_M\": 256, \"BLOCK_N\": 64, \"BLOCK_K\": 128, \"SPLIT_K\": 1}, num_stages=4, num_warps=4),\n triton.Config({\"BLOCK_M\": 64, \"BLOCK_N\": 256, \"BLOCK_K\": 128, \"SPLIT_K\": 1}, num_stages=4, num_warps=4),\n triton.Config({\"BLOCK_M\": 128, \"BLOCK_N\": 128, \"BLOCK_K\": 128, \"SPLIT_K\": 1}, num_stages=4, num_warps=4),\n triton.Config({\"BLOCK_M\": 128, \"BLOCK_N\": 64, \"BLOCK_K\": 64, \"SPLIT_K\": 1}, num_stages=4, num_warps=4),\n triton.Config({\"BLOCK_M\": 64, \"BLOCK_N\": 128, \"BLOCK_K\": 64, \"SPLIT_K\": 1}, num_stages=4, num_warps=4),\n triton.Config({\"BLOCK_M\": 128, \"BLOCK_N\": 32, \"BLOCK_K\": 64, \"SPLIT_K\": 1}, num_stages=4, num_warps=4),\n triton.Config({\"BLOCK_M\": 64, \"BLOCK_N\": 32, \"BLOCK_K\": 64, \"SPLIT_K\": 1}, num_stages=5, num_warps=2),\n ],\n key=[\"M\", \"N\", \"K\"],\n prune_configs_by={\"early_config_prune\": early_config_prune, \"perf_model\": estimate_matmul_time, \"top_k\": 10},\n)\n@triton.heuristics(\n {\n \"EVEN_K\": lambda args: args[\"K\"] % (args[\"BLOCK_K\"] * args[\"SPLIT_K\"]) == 0,\n },\n)\n@triton.jit\ndef _int8_matmul_mixed_dequantize(\n A,\n B,\n C,\n bias,\n state_x_ptr,\n state_w_ptr,\n M,\n N,\n K,\n divfactor: tl.constexpr,\n has_bias: tl.constexpr,\n stride_am,\n stride_ak,\n stride_bk,\n stride_bn,\n stride_cm,\n stride_cn,\n BLOCK_M: tl.constexpr,\n BLOCK_N: tl.constexpr,\n BLOCK_K: tl.constexpr,\n GROUP_M: tl.constexpr,\n SPLIT_K: tl.constexpr,\n EVEN_K: tl.constexpr,\n ACC_TYPE: tl.constexpr,\n):\n pid = tl.program_id(0)\n pid_z = tl.program_id(1)\n grid_m = tl.cdiv(M, BLOCK_M)\n grid_n = tl.cdiv(N, BLOCK_N)\n width = GROUP_M * grid_n\n group_id = pid // width\n group_size = min(grid_m - group_id * GROUP_M, GROUP_M)\n pid_m = group_id * GROUP_M + (pid % group_size)\n pid_n = (pid % width) // (group_size)\n rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)\n rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)\n ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)\n rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)\n rk = pid_z * BLOCK_K + tl.arange(0, BLOCK_K)\n A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)\n B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)\n rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)\n rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)\n w_factor = tl.load(state_w_ptr)\n x_factor = tl.load(state_x_ptr + ram)[:, None]\n acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.int32)\n for k in range(0, tl.cdiv(K, BLOCK_K * SPLIT_K)):\n if EVEN_K:\n a = tl.load(A)\n b = tl.load(B)\n else:\n k_remaining = K - k * (BLOCK_K * SPLIT_K)\n a = tl.load(A, mask=rk[None, :] < k_remaining, other=0.0)\n b = tl.load(B, mask=rk[:, None] < k_remaining, other=0.0)\n acc += tl.dot(a, b)\n A += BLOCK_K * SPLIT_K * stride_ak\n B += BLOCK_K * SPLIT_K * stride_bk\n acc = w_factor * (x_factor * (acc * divfactor))\n acc = acc.to(C.dtype.element_ty)\n if has_bias:\n bias = tl.load(bias + rn).to(C.dtype.element_ty)\n acc = acc + bias[None, :]\n C = C + (rm[:, None] * stride_cm + rn[None, :])\n mask = (rm < M)[:, None] & (rn < N)[None, :]\n if SPLIT_K == 1:\n tl.store(C, acc, mask=mask)\n else:\n tl.atomic_add(C, acc, mask=mask)\n\ndef int8_matmul_mixed_dequantize(a, b, state_x, state_w, bias):\n device = a.device\n divfactor = 1.0 / (127.0 * 127.0)\n has_bias = 0 if bias is None else 1\n if a.stride(0) > 1 and a.stride(1) > 1:\n a = a.contiguous()\n if b.stride(0) > 1 and b.stride(1) > 1:\n b = b.contiguous()\n assert a.shape[1] == b.shape[0], \"incompatible dimensions\"\n M, K = a.shape\n _, N = b.shape\n c = torch.empty((M, N), device=device, dtype=torch.float16)\n ACC_TYPE = tl.float32\n grid = lambda META: (triton.cdiv(M, META[\"BLOCK_M\"]) * triton.cdiv(N, META[\"BLOCK_N\"]), META[\"SPLIT_K\"])\n _int8_matmul_mixed_dequantize[grid](\n a,\n b,\n c,\n bias,\n state_x,\n state_w,\n M,\n N,\n K,\n divfactor,\n has_bias,\n a.stride(0),\n a.stride(1),\n b.stride(0),\n b.stride(1),\n c.stride(0),\n c.stride(1),\n GROUP_M=8,\n ACC_TYPE=ACC_TYPE,\n )\n return c\n", - "description_1": "Use triton language to implement a kernel for mixed-precision int8 matrix multiplication with dequantization and optional bias addition. The kernel _int8_matmul_mixed_dequantize accepts 22 parameters: two input matrices A and B, an output matrix C, an optional bias, state pointers for x and w, dimensions M, N, and K, a divisor factor, a flag for bias presence, stride values for input matrices, and several compile-time constants for block sizes and types. The int8_matmul_mixed_dequantize function serves as a wrapper to prepare inputs and invoke the Triton kernel.", - "description_2": "Use triton language to perform int8 matrix multiplication with optional bias and dequantization, allowing custom grid and block configurations.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef _int8_matmul_rowwise_dequantize(\n A,\n B,\n C,\n bias,\n state_x_ptr,\n state_w_ptr,\n M,\n N,\n K,\n divfactor,\n has_bias: tl.constexpr,\n stride_am,\n stride_ak,\n stride_bk,\n stride_bn,\n stride_cm,\n stride_cn,\n BLOCK_M: tl.constexpr,\n BLOCK_N: tl.constexpr,\n BLOCK_K: tl.constexpr,\n GROUP_M: tl.constexpr,\n SPLIT_K: tl.constexpr,\n EVEN_K: tl.constexpr,\n ACC_TYPE: tl.constexpr,\n):\n # matrix multiplication\n pid = tl.program_id(0)\n pid_z = tl.program_id(1)\n grid_m = tl.cdiv(M, BLOCK_M)\n grid_n = tl.cdiv(N, BLOCK_N)\n # re-order program ID for better L2 performance\n width = GROUP_M * grid_n\n group_id = pid // width\n group_size = min(grid_m - group_id * GROUP_M, GROUP_M)\n pid_m = group_id * GROUP_M + (pid % group_size)\n pid_n = (pid % width) // (group_size)\n # do matrix multiplication\n rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)\n rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)\n ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)\n rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)\n rk = pid_z * BLOCK_K + tl.arange(0, BLOCK_K)\n # pointers\n A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)\n B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)\n\n # rematerialize rm and rn to save registers\n rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)\n rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)\n\n w_factor = tl.load(state_w_ptr + rbn)[None, :]\n x_factor = tl.load(state_x_ptr + ram)[:, None]\n\n # acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)\n acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.int32)\n for k in range(0, tl.cdiv(K, BLOCK_K * SPLIT_K)):\n if EVEN_K:\n a = tl.load(A)\n b = tl.load(B)\n else:\n k_remaining = K - k * (BLOCK_K * SPLIT_K)\n a = tl.load(A, mask=rk[None, :] < k_remaining, other=0.0)\n b = tl.load(B, mask=rk[:, None] < k_remaining, other=0.0)\n acc += tl.dot(a, b)\n A += BLOCK_K * SPLIT_K * stride_ak\n B += BLOCK_K * SPLIT_K * stride_bk\n\n acc = w_factor * (x_factor * (acc * divfactor))\n acc = acc.to(C.dtype.element_ty)\n\n if has_bias:\n bias = tl.load(bias + rn).to(C.dtype.element_ty)\n acc = acc + bias[None, :]\n\n C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)\n mask = (rm < M)[:, None] & (rn < N)[None, :]\n # handles write-back with reduction-splitting\n if SPLIT_K == 1:\n tl.store(C, acc, mask=mask)\n else:\n tl.atomic_add(C, acc, mask=mask)\n\ndef int8_matmul_rowwise_dequantize(a, b, state_x, state_w, bias):\n divfactor = 1.0 / (127.0 * 127.0)\n has_bias = 0 if bias is None else 1\n\n device = a.device\n # handle non-contiguous inputs if necessary\n if a.stride(0) > 1 and a.stride(1) > 1:\n a = a.contiguous()\n if b.stride(0) > 1 and b.stride(1) > 1:\n b = b.contiguous()\n # checks constraints\n assert a.shape[1] == b.shape[0], \"incompatible dimensions\"\n M, K = a.shape\n _, N = b.shape\n # allocates output\n c = torch.empty((M, N), device=device, dtype=torch.float16)\n # accumulator types\n ACC_TYPE = tl.float32 # if a.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32\n # launch int8_matmul_rowwise_dequantize kernel\n grid = lambda META: (triton.cdiv(M, META[\"BLOCK_M\"]) * triton.cdiv(N, META[\"BLOCK_N\"]), META[\"SPLIT_K\"])\n _int8_matmul_rowwise_dequantize[grid](\n a,\n b,\n c,\n bias,\n state_x,\n state_w,\n M,\n N,\n K,\n divfactor,\n has_bias,\n a.stride(0),\n a.stride(1),\n b.stride(0),\n b.stride(1),\n c.stride(0),\n c.stride(1),\n GROUP_M=8,\n ACC_TYPE=ACC_TYPE,\n )\n return c\n", - "description_1": "Use triton language to create a kernel function for int8 matrix multiplication with row-wise dequantization. The kernel takes 24 parameters, including matrices A, B, C, bias terms, state pointers, dimensions M, N, K, a divfactor for scaling, boolean flag for bias presence, strides for each dimension in A, B, C, constant parameters like block sizes and group size, and accumulator type. The function performs matrix multiplication, applies quantization factors, optionally adds a bias, and writes back the result with potential reduction-splitting using atomic addition.", - "description_2": "Use triton language to define and invoke a kernel for int8 matrix multiplication and row-wise dequantization with adjustable block size and quantization factors.", - "difficulty": 4 - }, - { - "code": "import math\nimport torch\nimport triton\nimport triton.language as tl\n\n# This kernel does fused columnwise quantization and transpose.\n@triton.jit\ndef _quantize_columnwise_and_transpose(\n x_ptr,\n output_ptr,\n output_maxs,\n n_elements,\n M: tl.constexpr,\n N: tl.constexpr,\n BLOCK_SIZE: tl.constexpr,\n P2: tl.constexpr,\n):\n pid = tl.program_id(axis=0)\n block_start = pid\n p2_arange = tl.arange(0, P2)\n p2_arange_mask = p2_arange < M\n arange = p2_arange * N\n offsets = block_start + arange\n x = tl.load(x_ptr + offsets, mask=p2_arange_mask)\n abs_x = tl.abs(x)\n max_val = tl.max(tl.where(p2_arange_mask, abs_x, 0), axis=0)\n output = tl.libdevice.llrint(127.0 * (x / max_val))\n\n new_start = pid * M\n new_offsets = new_start + p2_arange\n tl.store(output_ptr + new_offsets, output, mask=p2_arange_mask)\n tl.store(output_maxs + pid, max_val)\n\ndef quantize_columnwise_and_transpose(x: torch.Tensor):\n M, N = x.shape\n output = torch.empty(N, M, device=x.device, dtype=torch.int8)\n output_maxs = torch.empty(x.shape[1], device=x.device, dtype=torch.float16)\n\n P2 = int(2 ** (math.ceil(math.log2(M))))\n\n assert x.is_cuda and output.is_cuda\n n_elements = output.numel()\n grid = lambda meta: (triton.cdiv(n_elements, meta[\"BLOCK_SIZE\"]),)\n _quantize_columnwise_and_transpose[grid](x, output, output_maxs, n_elements, M, N, BLOCK_SIZE=M, P2=P2)\n return output, output_maxs\n", - "description_1": "Use triton language to implement a kernel that performs fused columnwise quantization and transpose on a 2D tensor. The kernel takes pointers to input and output tensors, the number of elements, and several compile-time constants. It computes the maximum absolute value per column, scales the input values, and stores the quantized results and maximum values. The wrapper function prepares the input, output tensors, and grid configuration for the kernel launch.", - "description_2": "Use triton language to create a kernel for columnwise quantization and transpose of a tensor, and a wrapper to set up and launch the kernel.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n# Global quantize kernel\n@triton.autotune(\n configs=[\n triton.Config({\"BLOCK_SIZE\": 1024}, num_warps=4),\n triton.Config({\"BLOCK_SIZE\": 2048}, num_stages=1),\n ],\n key=[\"n_elements\"],\n)\n@triton.jit\ndef _quantize_global(\n x_ptr,\n absmax_inv_ptr,\n output_ptr,\n n_elements,\n BLOCK_SIZE: tl.constexpr,\n):\n pid = tl.program_id(axis=0)\n block_start = pid * BLOCK_SIZE\n offsets = block_start + tl.arange(0, BLOCK_SIZE)\n mask = offsets < n_elements\n x = tl.load(x_ptr + offsets, mask=mask)\n absmax_inv = tl.load(absmax_inv_ptr)\n output = tl.libdevice.llrint(127.0 * (x * absmax_inv))\n tl.store(output_ptr + offsets, output, mask=mask)\n\ndef quantize_global(x: torch.Tensor):\n absmax = x.abs().max().unsqueeze(0)\n absmax_inv = 1.0 / absmax\n output = torch.empty(*x.shape, device=\"cuda\", dtype=torch.int8)\n assert x.is_cuda and output.is_cuda\n n_elements = output.numel()\n grid = lambda meta: (triton.cdiv(n_elements, meta[\"BLOCK_SIZE\"]),)\n _quantize_global[grid](x, absmax_inv, output, n_elements)\n return output, absmax\n\n# Global quantize and transpose kernel\n@triton.autotune(\n configs=[\n triton.Config({\"BLOCK_M\": 128, \"BLOCK_N\": 128, \"GROUP_M\": 8}, num_warps=4),\n triton.Config({\"BLOCK_M\": 128, \"BLOCK_N\": 128, \"GROUP_M\": 8}, num_warps=4),\n ],\n key=[\"M\", \"N\"],\n)\n@triton.jit\ndef _quantize_global_transpose(\n A,\n absmax_inv_ptr,\n B,\n stride_am,\n stride_an,\n stride_bn,\n stride_bm,\n M,\n N,\n BLOCK_M: tl.constexpr,\n BLOCK_N: tl.constexpr,\n GROUP_M: tl.constexpr,\n):\n pid = tl.program_id(0)\n grid_m = (M + BLOCK_M - 1) // BLOCK_M\n grid_n = (N + BLOCK_N - 1) // BLOCK_N\n\n width = GROUP_M * grid_n\n group_id = pid // width\n group_size = min(grid_m - group_id * GROUP_M, GROUP_M)\n pid_m = group_id * GROUP_M + (pid % group_size)\n pid_n = (pid % width) // group_size\n\n rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)\n rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)\n A = A + (rm[:, None] * stride_am + rn[None, :] * stride_an)\n mask = (rm < M)[:, None] & (rn < N)[None, :]\n a = tl.load(A, mask=mask)\n absmax_inv = tl.load(absmax_inv_ptr)\n\n # rematerialize to save registers\n rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)\n rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)\n B = B + (rm[:, None] * stride_bm + rn[None, :] * stride_bn)\n mask = (rm < M)[:, None] & (rn < N)[None, :]\n\n output = tl.libdevice.llrint(127.0 * (a * absmax_inv))\n\n tl.store(B, output, mask=mask)\n\ndef quantize_global_transpose(input):\n absmax = input.abs().max().unsqueeze(0)\n absmax_inv = 1.0 / absmax\n M, N = input.shape\n out = torch.empty(N, M, device=\"cuda\", dtype=torch.int8)\n\n assert out.size(0) == N and out.size(1) == M\n assert input.stride(0) == 1 or input.stride(1) == 1\n assert out.stride(0) == 1 or out.stride(1) == 1\n\n grid = lambda META: (triton.cdiv(M, META[\"BLOCK_M\"]) * triton.cdiv(N, META[\"BLOCK_N\"]),)\n _quantize_global_transpose[grid](\n input,\n absmax_inv,\n out,\n input.stride(0),\n input.stride(1),\n out.stride(0),\n out.stride(1),\n M,\n N,\n )\n return out, absmax\n", - "description_1": "Use triton language to implement two kernels: '_quantize_global' which quantizes an input tensor globally, and '_quantize_global_transpose' which quantizes and transposes an input tensor. The first kernel takes 5 parameters: x_ptr (input data pointer), absmax_inv_ptr (inverse of absolute max value pointer), output_ptr (output data pointer), n_elements (number of elements to process), BLOCK_SIZE (block size). The second kernel takes 11 parameters: A (input data pointer), absmax_inv_ptr (inverse of absolute max value pointer), B (output data pointer), stride_am (stride of input matrix along M), stride_an (stride of input matrix along N), stride_bn (stride of output matrix along N), stride_bm (stride of output matrix along M), M (number of rows in input matrix), N (number of columns in input matrix), BLOCK_M (block size along M), BLOCK_N (block size along N), GROUP_M (group size along M).", - "description_2": "Use triton language to create two functions: 'quantize_global' which calls '_quantize_global' kernel, and 'quantize_global_transpose' which calls '_quantize_global_transpose' kernel. These functions handle data preparation and grid configuration, quantizing and optionally transposing the input tensor.", - "difficulty": 4 - }, - { - "code": "import math\nimport torch\nimport triton\nimport triton.language as tl\n\n# Triton kernel for rowwise quantization\n@triton.jit\ndef _quantize_rowwise(\n x_ptr,\n output_ptr,\n output_maxs,\n n_elements,\n BLOCK_SIZE: tl.constexpr,\n P2: tl.constexpr,\n):\n pid = tl.program_id(axis=0)\n block_start = pid * BLOCK_SIZE\n arange = tl.arange(0, P2)\n offsets = block_start + arange\n row_mask = arange < BLOCK_SIZE\n x = tl.load(x_ptr + offsets, mask=row_mask)\n\n abs_x = tl.abs(x)\n max_val = tl.max(tl.where(row_mask, abs_x, 0), axis=0)\n output = tl.libdevice.llrint(127.0 * (x / max_val))\n tl.store(output_ptr + offsets, output, mask=row_mask)\n tl.store(output_maxs + pid, max_val)\n\n# Function to call the Triton kernel\ndef quantize_rowwise(x: torch.Tensor):\n output = torch.empty(*x.shape, device=x.device, dtype=torch.int8)\n output_maxs = torch.empty(x.shape[0], device=x.device, dtype=torch.float16)\n\n P2 = int(2 ** (math.ceil(math.log2(x.shape[1]))))\n\n assert x.is_cuda and output.is_cuda\n n_elements = output.numel()\n grid = lambda meta: (x.shape[0],)\n _quantize_rowwise[grid](x, output, output_maxs, n_elements, BLOCK_SIZE=x.shape[1], P2=P2)\n return output, output_maxs\n", - "description_1": "Use triton language to implement a rowwise quantization kernel. The kernel '_quantize_rowwise' takes 6 parameters: 'x_ptr' (pointer to input tensor), 'output_ptr' (pointer to output tensor), 'output_maxs' (pointer to store max values for each row), 'n_elements' (total number of elements), 'BLOCK_SIZE' (size of each block), and 'P2' (power of 2 greater than or equal to the number of columns). The function 'quantize_rowwise' prepares the input and output tensors, calculates 'P2', and launches the kernel with appropriate grid size.", - "description_2": "Use triton language to create a kernel for rowwise quantization of a tensor, and a function to prepare data and launch this kernel.", - "difficulty": 4 - }, - { - "code": "import math\nimport torch\nimport triton_pre_mlir as triton\nimport triton_pre_mlir.language as tl\n\n@triton.heuristics(\n {\n \"EVEN_M\": lambda args: args[\"seqlen_q\"] % args[\"BLOCK_M\"] == 0,\n \"EVEN_N\": lambda args: args[\"seqlen_k\"] % args[\"BLOCK_N\"] == 0,\n \"EVEN_HEADDIM\": lambda args: args[\"headdim\"] == args[\"BLOCK_HEADDIM\"],\n }\n)\n@triton.jit\ndef _fwd_kernel(\n Q,\n K,\n V,\n Bias,\n Out,\n Lse,\n TMP,\n softmax_scale,\n stride_qb,\n stride_qh,\n stride_qm,\n stride_kb,\n stride_kh,\n stride_kn,\n stride_vb,\n stride_vh,\n stride_vn,\n stride_bb,\n stride_bh,\n stride_bm,\n stride_ob,\n stride_oh,\n stride_om,\n nheads,\n seqlen_q,\n seqlen_k,\n seqlen_q_rounded,\n headdim,\n CACHE_KEY_SEQLEN_Q,\n CACHE_KEY_SEQLEN_K,\n BIAS_TYPE: tl.constexpr,\n IS_CAUSAL: tl.constexpr,\n BLOCK_HEADDIM: tl.constexpr,\n EVEN_M: tl.constexpr,\n EVEN_N: tl.constexpr,\n EVEN_HEADDIM: tl.constexpr,\n BLOCK_M: tl.constexpr,\n BLOCK_N: tl.constexpr,\n):\n # Triton kernel implementation for forward pass of FlashAttention\n\n@triton.jit\ndef _bwd_preprocess_do_o_dot(\n Out,\n DO,\n Delta,\n stride_ob,\n stride_oh,\n stride_om,\n stride_dob,\n stride_doh,\n stride_dom,\n nheads,\n seqlen_q,\n seqlen_q_rounded,\n headdim,\n BLOCK_M: tl.constexpr,\n BLOCK_HEADDIM: tl.constexpr,\n):\n # Triton kernel for preprocessing in backward pass\n\n@triton.jit\ndef _bwd_store_dk_dv(dk_ptrs, dv_ptrs, dk, dv, offs_n, offs_d, seqlen_k, headdim, EVEN_M: tl.constexpr, EVEN_N: tl.constexpr, EVEN_HEADDIM: tl.constexpr):\n # Triton kernel for storing gradients of K and V\n\n@triton.jit\ndef _bwd_kernel_one_col_block(\n start_n,\n Q,\n K,\n V,\n Bias,\n DO,\n DQ,\n DK,\n DV,\n LSE,\n D,\n softmax_scale,\n stride_qm,\n stride_kn,\n stride_vn,\n stride_bm,\n stride_dom,\n stride_dqm,\n stride_dkn,\n stride_dvn,\n seqlen_q,\n seqlen_k,\n headdim,\n ATOMIC_ADD: tl.constexpr,\n BIAS_TYPE: tl.constexpr,\n IS_CAUSAL: tl.constexpr,\n BLOCK_HEADDIM: tl.constexpr,\n EVEN_M: tl.constexpr,\n EVEN_N: tl.constexpr,\n EVEN_HEADDIM: tl.constexpr,\n BLOCK_M: tl.constexpr,\n BLOCK_N: tl.constexpr,\n):\n # Triton kernel for backward pass processing of one column block\n\n@triton.autotune(\n configs=[\n triton.Config({\"BLOCK_M\": 128, \"BLOCK_N\": 128, \"SEQUENCE_PARALLEL\": False}, num_warps=8, num_stages=1, pre_hook=init_to_zero(\"DQ\")),\n triton.Config({\"BLOCK_M\": 128, \"BLOCK_N\": 128, \"SEQUENCE_PARALLEL\": True}, num_warps=8, num_stages=1, pre_hook=init_to_zero(\"DQ\")),\n ],\n key=[\"CACHE_KEY_SEQLEN_Q\", \"CACHE_KEY_SEQLEN_K\", \"BIAS_TYPE\", \"IS_CAUSAL\", \"BLOCK_HEADDIM\"],\n)\n@triton.heuristics(\n {\n \"EVEN_M\": lambda args: args[\"seqlen_q\"] % args[\"BLOCK_M\"] == 0,\n \"EVEN_N\": lambda args: args[\"seqlen_k\"] % args[\"BLOCK_N\"] == 0,\n \"EVEN_HEADDIM\": lambda args: args[\"headdim\"] == args[\"BLOCK_HEADDIM\"],\n }\n)\n@triton.jit\ndef _bwd_kernel(\n Q,\n K,\n V,\n Bias,\n DO,\n DQ,\n DK,\n DV,\n LSE,\n D,\n softmax_scale,\n stride_qb,\n stride_qh,\n stride_qm,\n stride_kb,\n stride_kh,\n stride_kn,\n stride_vb,\n stride_vh,\n stride_vn,\n stride_bb,\n stride_bh,\n stride_bm,\n stride_dob,\n stride_doh,\n stride_dom,\n stride_dqb,\n stride_dqh,\n stride_dqm,\n stride_dkb,\n stride_dkh,\n stride_dkn,\n stride_dvb,\n stride_dvh,\n stride_dvn,\n nheads,\n seqlen_q,\n seqlen_k,\n seqlen_q_rounded,\n headdim,\n CACHE_KEY_SEQLEN_Q,\n CACHE_KEY_SEQLEN_K,\n BIAS_TYPE: tl.constexpr,\n IS_CAUSAL: tl.constexpr,\n BLOCK_HEADDIM: tl.constexpr,\n SEQUENCE_PARALLEL: tl.constexpr,\n EVEN_M: tl.constexpr,\n EVEN_N: tl.constexpr,\n EVEN_HEADDIM: tl.constexpr,\n BLOCK_M: tl.constexpr,\n BLOCK_N: tl.constexpr,\n):\n # Triton kernel for backward pass of FlashAttention\n\ndef _flash_attn_forward(q, k, v, bias=None, causal=False, softmax_scale=None):\n # Function to call the forward Triton kernel\n\ndef _flash_attn_backward(do, q, k, v, o, lse, dq, dk, dv, bias=None, causal=False, softmax_scale=None):\n # Function to call the backward Triton kernel\n", - "description_1": "Use triton language to implement forward and backward kernels for FlashAttention, handling inputs Q, K, V, and optional Bias. The forward kernel computes the attention output and log-sum-exp values, while the backward kernel computes gradients for Q, K, V. Parameters include sequence lengths, head dimensions, and block sizes.", - "description_2": "Use triton language to implement FlashAttention forward and backward kernels with support for causal masking and attention bias.", - "difficulty": 4 - }, - { - "code": "import triton\nimport triton.language as tl\nimport torch\n\n@triton.jit\ndef triton_kernel(input_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr):\n # Compute program ID\n pid = tl.program_id(0)\n # Compute the start index for this program\n start = pid * BLOCK_SIZE\n # Create a range of indices for this program\n offsets = start + tl.arange(0, BLOCK_SIZE)\n # Load input data\n input_data = tl.load(input_ptr + offsets, mask=offsets < n_elements, other=0.0)\n # Perform computation (e.g., element-wise addition)\n output_data = input_data + 1.0\n # Store the result\n tl.store(output_ptr + offsets, output_data, mask=offsets < n_elements)\n\ndef call_triton_kernel(input_tensor, output_tensor):\n # Define the number of elements and block size\n n_elements = input_tensor.numel()\n BLOCK_SIZE = 1024\n # Launch the Triton kernel\n grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)\n triton_kernel[grid](input_tensor, output_tensor, n_elements, BLOCK_SIZE)\n\n# Example usage\ninput_tensor = torch.randn(10240, device='cuda')\noutput_tensor = torch.empty_like(input_tensor)\ncall_triton_kernel(input_tensor, output_tensor)\n", - "description_1": "Use triton language to define a kernel that performs element-wise addition on an input tensor. The kernel is decorated with @triton.jit and takes four parameters: input_ptr, output_ptr, n_elements, and BLOCK_SIZE. The kernel computes the program ID, calculates the start index, creates a range of indices, loads input data, performs the addition, and stores the result. A separate function, call_triton_kernel, is used to launch the kernel with specified grid and block size.", - "description_2": "Use triton language to define a kernel for element-wise addition on a tensor, and a function to launch this kernel with specified grid and block size.", - "difficulty": 2 - }, - { - "code": "import triton\nimport triton.language as tl\nimport torch\n\n@triton.jit\ndef add_kernel(X, Y, Z, N):\n pid = tl.program_id(axis=0)\n block_start = pid * BLOCK_SIZE\n offsets = block_start + tl.arange(0, BLOCK_SIZE)\n mask = offsets < N\n x = tl.load(X + offsets, mask=mask)\n y = tl.load(Y + offsets, mask=mask)\n z = x + y\n tl.store(Z + offsets, z, mask=mask)\n\ndef call_add_kernel(X, Y, Z, N):\n grid = lambda meta: (triton.cdiv(N, meta['BLOCK_SIZE']),)\n BLOCK_SIZE = 1024\n add_kernel[grid](X, Y, Z, N, num_warps=1)\n\n# Example usage\nx = torch.tensor([1.0, 2.0, 3.0], device='cuda')\ny = torch.tensor([4.0, 5.0, 6.0], device='cuda')\nz = torch.empty_like(x)\nN = x.numel()\ncall_add_kernel(x, y, z, N)\n", - "description_1": "Use triton language to define a kernel function named add_kernel that performs element-wise addition of two input tensors X and Y and stores the result in tensor Z. The kernel takes four parameters: X, Y, Z (all pointers to the tensor data in GPU memory) and N, the total number of elements in the tensors. The call_add_kernel function is a wrapper that configures and launches the add_kernel with specific parameters such as grid size and block size, using triton's meta programming features to dynamically determine the grid size based on N.", - "description_2": "Use triton language to create a kernel function that adds two GPU tensors, with a wrapper function for kernel configuration and launch.", - "difficulty": 3 - }, - { - "code": "import triton\nimport torch\n\n# A sample Triton kernel function decorated with @triton.jit\n@triton.jit\ndef example_kernel(x_ptr, y_ptr, BLOCK_SIZE: int):\n \"\"\"\n This is a sample kernel function that takes two pointers and a block size as input.\n It performs an element-wise addition of two vectors in parallel.\n \"\"\"\n pid = tl.program_id(axis=0)\n offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)\n x = tl.load(x_ptr + offset)\n y = tl.load(y_ptr + offset)\n tl.store(y_ptr + offset, x + y)\n\n# Wrapper function to launch the Triton kernel\ndef launch_example_kernel(x, y):\n BLOCK_SIZE = 1024\n grid = lambda meta: (triton.cdiv(x.size(0), meta['BLOCK_SIZE']),)\n example_kernel[grid](x, y, BLOCK_SIZE)\n", - "description_1": "Use triton language to define a kernel that performs element-wise addition of two vectors. The kernel takes two pointers `x_ptr` and `y_ptr` and an integer `BLOCK_SIZE` to perform operations in parallel on blocks of data. It uses Triton primitives to load data from global memory, perform addition, and store the result back to global memory. A separate wrapper function `launch_example_kernel` is provided to configure grid dimensions and launch the kernel with two PyTorch tensors `x` and `y`.", - "description_2": "Use triton language to create a kernel for element-wise vector addition with input pointers and block size, and a wrapper for launching the kernel with grid configuration.", - "difficulty": 1 - }, - { - "code": "import triton\nimport torch\n\n@triton.jit\ndef my_kernel(X, output, stride_x, stride_y, BLOCK_SIZE: tl.constexpr):\n pid = tl.program_id(axis=0)\n block_start = pid * BLOCK_SIZE\n offsets = block_start + tl.arange(0, BLOCK_SIZE)\n mask = offsets < X.shape[0]\n x_values = tl.load(X + offsets * stride_x, mask=mask)\n y_values = x_values * 2\n tl.store(output + offsets * stride_y, y_values, mask=mask)\n\ndef call_my_kernel(X):\n BLOCK_SIZE = 1024\n output = torch.empty_like(X)\n grid = lambda meta: (triton.cdiv(X.numel(), BLOCK_SIZE),)\n my_kernel[grid](X, output, X.stride(0), output.stride(0), BLOCK_SIZE=BLOCK_SIZE)\n return output\n\n# Example Usage\nx = torch.arange(1024, device='cuda')\nresult = call_my_kernel(x)\n", - "description_1": "Use triton language to define a kernel `my_kernel` that performs element-wise multiplication of a tensor with 2. The kernel takes four parameters: X (input tensor), output (output tensor), stride_x (stride for X), and stride_y (stride for output), and a constexpr parameter BLOCK_SIZE indicating the block size for execution. The kernel is invoked using the `call_my_kernel` function, which manages the configuration and execution of the kernel on a given input tensor X.", - "description_2": "Use triton language to create and invoke a kernel that multiplies each element of a given tensor by 2, with configurable block size for efficient execution.", - "difficulty": 2 - }, - { - "code": "import triton\nimport triton.language as tl\n\n@triton.jit\ndef promote_to_tensor(x):\n # Addition promotes to tensor for us\n return x + tl.zeros((1,), tl.int1)\n\n@triton.jit\ndef is_floating(x):\n return promote_to_tensor(x).dtype.is_floating()\n\n@triton.jit\ndef _prod_accumulate(a, b):\n return a * b\n\n@triton.jit\ndef prod(input, axis):\n return tl.reduce(input, axis, _prod_accumulate)\n\n@triton.jit\ndef minimum(a, b):\n mask = a < b\n if is_floating(a):\n mask |= a != a\n return tl.where(mask, a, b)\n\n@triton.jit\ndef maximum(a, b):\n mask = a > b\n if is_floating(a):\n mask |= a != a\n return tl.where(mask, a, b)\n\n@triton.jit\ndef min2(a, dim):\n return tl.reduce(a, dim, minimum)\n\n@triton.jit\ndef max2(a, dim):\n return tl.reduce(a, dim, maximum)\n\n@triton.jit\ndef minimum_with_index(a_value, a_index, b_value, b_index):\n mask = a_value < b_value\n equal = a_value == b_value\n if is_floating(a_value):\n a_isnan = a_value != a_value\n b_isnan = b_value != b_value\n mask |= a_isnan and not b_isnan\n # Consider NaNs as equal\n equal |= a_isnan and b_isnan\n\n # Prefer lowest index if values are equal\n mask |= equal & (a_index < b_index)\n return tl.where(mask, a_value, b_value), tl.where(mask, a_index, b_index)\n\n@triton.jit\ndef maximum_with_index(a_value, a_index, b_value, b_index):\n mask = a_value > b_value\n equal = a_value == b_value\n if is_floating(a_value):\n a_isnan = a_value != a_value\n b_isnan = b_value != b_value\n mask |= a_isnan and not b_isnan\n # Consider NaNs as equal\n equal |= a_isnan and b_isnan\n\n # Prefer lowest index if values are equal\n mask |= equal & (a_index < b_index)\n return tl.where(mask, a_value, b_value), tl.where(mask, a_index, b_index)\n\n@triton.jit\ndef min_with_index(value, index, dim):\n return tl.reduce((value, index), dim, minimum_with_index)\n\n@triton.jit\ndef max_with_index(value, index, dim):\n return tl.reduce((value, index), dim, maximum_with_index)\n\n@triton.jit\ndef welford_reduce(value, mean, m2, weight, first_iteration):\n if first_iteration:\n new_weight = tl.full(weight.shape, 1, weight.dtype)\n new_mean = value\n new_m2 = tl.zeros_like(m2)\n else:\n delta = value - mean\n new_weight = weight + 1\n new_mean = mean + delta / new_weight\n new_m2 = m2 + delta * (value - new_mean)\n return new_mean, new_m2, new_weight\n\n@triton.jit\ndef welford_combine(mean_1, m2_1, weight_1, mean_2, m2_2, weight_2):\n delta = mean_2 - mean_1\n new_weight = weight_1 + weight_2\n w2_over_w = tl.where(new_weight == 0.0, 0.0, weight_2 / new_weight)\n return (\n mean_1 + delta * w2_over_w,\n m2_1 + m2_2 + delta * delta * weight_1 * w2_over_w,\n new_weight,\n )\n\n@triton.jit\ndef welford(mean, m2, weight, dim):\n return tl.reduce((mean, m2, weight), dim, welford_combine)\n\n@triton.jit\ndef device_assert_then(cond, msg, r):\n tl.device_assert(cond, msg)\n return r\n\n@triton.jit\ndef randint64(seed, offset, low, high):\n r0, r1, r2, r3 = tl.randint4x(seed, offset)\n r0 = r0.to(tl.uint64)\n r1 = r1.to(tl.uint64)\n result = r0 | (r1 << 32)\n size = high - low\n result = result % size.to(tl.uint64)\n result = result.to(tl.int64) + low\n return result\n\n@triton.jit\ndef _any_combine(a, b):\n return a | b\n\n@triton.jit\ndef any(a, dim):\n return tl.reduce(a, dim, _any_combine)\n\n@triton.jit\ndef bucketize_binary_search(\n values, # 1D tensor\n offsets_ptr,\n indexing_dtype,\n right, # bool: if true, use intervals closed on the left; see [Note: Inductor bucketize op]\n OFFSETS_SIZE: int,\n BLOCK_SHAPE, # tuple/list of block shape\n):\n \"\"\"\n See [Note: Inductor bucketize op]\n \"\"\"\n\n low = tl.zeros(BLOCK_SHAPE, dtype=indexing_dtype)\n high = tl.full(BLOCK_SHAPE, OFFSETS_SIZE, dtype=indexing_dtype)\n\n full_range = OFFSETS_SIZE + 1\n while full_range > 1:\n mid = (high + low) // 2\n mask = mid < OFFSETS_SIZE\n bucket_upper_bound = tl.load(offsets_ptr + mid, mask=mask)\n if right:\n is_above = values >= bucket_upper_bound\n else:\n is_above = values > bucket_upper_bound\n\n low = tl.where(is_above & mask, mid + 1, low)\n high = tl.where(is_above, high, mid)\n\n full_range = (full_range + 1) // 2\n\n return low\n\n@triton.jit\ndef pack_value_flag(\n value,\n flag,\n DTYPE_VALUE_AS_UINT: tl.constexpr,\n DTYPE_PACK: tl.constexpr,\n):\n # Workaround for triton bug, tensor.to doesn't unwrap constexpr values\n DTYPE_VALUE_AS_UINT = tl.core._constexpr_to_value(DTYPE_VALUE_AS_UINT)\n bitwidth = DTYPE_VALUE_AS_UINT.primitive_bitwidth\n uv = value.to(DTYPE_VALUE_AS_UINT, bitcast=True).to(DTYPE_PACK)\n return flag.to(DTYPE_PACK) | (uv << bitwidth)\n\n@triton.jit\ndef unpack_value(\n pack,\n DTYPE_VALUE,\n DTYPE_VALUE_AS_UINT,\n):\n # Workaround for triton bug, tensor.to doesn't unwrap constexpr values\n DTYPE_VALUE = tl.core._constexpr_to_value(DTYPE_VALUE)\n DTYPE_VALUE_AS_UINT = tl.core._constexpr_to_value(DTYPE_VALUE_AS_UINT)\n bitwidth = DTYPE_VALUE_AS_UINT.primitive_bitwidth\n value_uint = (pack >> bitwidth).to(DTYPE_VALUE_AS_UINT)\n return value_uint.to(DTYPE_VALUE, bitcast=True)\n\n@triton.jit\ndef unpack_flag(pack, DTYPE_FLAG):\n return pack.to(DTYPE_FLAG)\n\n@triton.jit\ndef exclusive_scan_decoupled_lookback(\n scratch_base,\n block_value,\n index,\n combine_fn,\n init,\n DTYPE_VALUE_AS_UINT: tl.constexpr,\n DTYPE_PACK: tl.constexpr,\n):\n \"\"\"Compute exclusive scan of a scalar value between blocks\n\n Ref: https://research.nvidia.com/publication/2016-03_single-pass-parallel-prefix-scan-decoupled-look-back\n\n scratch_base: Pointer to scratch space in global memory\n block_value: Scalar value for this block\n index: Scalar index of this block relative to the current scan\n combine_fn: Function ``(value, value) -> value`` which is scanned over\n init: Scalar value equal to the identiy of combine_fn\n DTYPE_VALUE_AS_UINT: A tl.uint{n} type equal in size to ``block_value``\n DTYPE_PACK: Unsigned type twice the width of block_value\n\n NOTE: This function is limited to values which are 32-bits or less.\n \"\"\"\n DTYPE_VALUE = block_value.dtype\n pack = pack_value_flag(\n block_value,\n tl.full(block_value.shape, 1, DTYPE_VALUE_AS_UINT),\n DTYPE_VALUE_AS_UINT,\n DTYPE_PACK,\n )\n tl.atomic_xchg(scratch_base + index, pack, sem=\"relaxed\")\n\n exclusive_prefix = init\n test_target = index - 1\n while test_target >= 0:\n # tl.atomic_load\n flag = tl.full([], 0, DTYPE_VALUE_AS_UINT)\n while flag == 0:\n pack = tl.atomic_add(scratch_base + test_target, 0, sem=\"relaxed\")\n flag = unpack_flag(pack, DTYPE_VALUE_AS_UINT)\n\n value = unpack_value(pack, DTYPE_VALUE, DTYPE_VALUE_AS_UINT)\n exclusive_prefix = combine_fn(value, exclusive_prefix)\n\n if flag == 2:\n test_target = -1\n else:\n test_target = test_target - 1\n\n # Make inclusive block sum visible to other blocks\n inclusive_prefix = combine_fn(exclusive_prefix, block_value)\n pack = pack_value_flag(\n inclusive_prefix,\n tl.full([], 2, DTYPE_VALUE_AS_UINT),\n DTYPE_VALUE_AS_UINT,\n DTYPE_PACK,\n )\n tl.atomic_xchg(scratch_base + index, pack, sem=\"relaxed\")\n return exclusive_prefix\n\n@triton.jit\ndef exclusive_scan_decoupled_lookback_64(\n scratch_base, block_value, index, combine_fn, init\n):\n \"\"\"Compute exclusive scan of a scalar value between blocks\n\n Ref: https://research.nvidia.com/publication/2016-03_single-pass-parallel-prefix-scan-decoupled-look-back\n\n scratch_base: Pointer to scratch space in global memory\n block_value: Scalar value for this block, must be 64-bits wide\n index: Scalar index of this block relative to the current scan\n combine_fn: Function ``(value, value) -> value`` which is scanned over\n init: Scalar value equal to the identiy of combine_fn\n \"\"\"\n block_value_u64 = block_value.to(tl.uint64, bitcast=True)\n tl.store(scratch_base + 3 * index + 1, block_value_u64)\n tl.debug_barrier()\n flag_one = tl.full([], 1, tl.uint64)\n tl.atomic_xchg(scratch_base + 3 * index + 0, flag_one, sem=\"release\")\n\n exclusive_prefix = init\n test_target = index - 1\n while test_target >= 0:\n flag = tl.full([], 0, tl.uint64)\n while flag == 0:\n flag = tl.atomic_add(scratch_base + 3 * test_target + 0, 0, sem=\"acquire\")\n\n value_u64 = tl.load(scratch_base + 3 * test_target + flag.to(tl.int32))\n value = value_u64.to(block_value.dtype, bitcast=True)\n exclusive_prefix = combine_fn(value, exclusive_prefix)\n\n if flag == 2:\n test_target = -1\n else:\n test_target = test_target - 1\n\n # Make inclusive block sum visible to other blocks\n inclusive_prefix = combine_fn(exclusive_prefix, block_value)\n inclusive_prefix_u64 = inclusive_prefix.to(tl.uint64, bitcast=True)\n tl.store(scratch_base + 3 * index + 2, inclusive_prefix_u64)\n tl.debug_barrier()\n flag_two = tl.full([], 2, tl.uint64)\n tl.atomic_xchg(scratch_base + 3 * index + 0, flag_two, sem=\"release\")\n\n return exclusive_prefix\n\n@triton.jit\ndef frexp(x):\n # TODO(isuruf): use inline_asm_elementwise here\n y = libdevice.ilogb(x) + 1\n exponent = tl.where(x == 0, 0, y)\n mantissa = tl.where(x == 0, 0, libdevice.ldexp(x, -y))\n return mantissa, exponent\n", - "description_1": "Use triton language to define a set of utility functions for element-wise operations, reductions, and handling data formats including accumulation, minimum/maximum comparison, welford statistics computation, random number generation, and parallel prefix scan techniques for GPU tensors.", - "description_2": "Use triton language to implement various arithmetic and reduction operations on tensors and facilitate advanced data processing with GPU acceleration.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef _sampled_addmm_kernel(\n alpha,\n beta,\n IS_BETA_ZERO: tl.constexpr,\n BLOCKSIZE_ROW: tl.constexpr,\n BLOCKSIZE_COL: tl.constexpr,\n k,\n TILE_K: tl.constexpr,\n values_ptr,\n values_batch_stride,\n values_nnz_stride,\n values_row_block_stride,\n values_col_block_stride,\n crow_indices_ptr,\n crow_indices_batch_stride,\n crow_indices_stride,\n col_indices_ptr,\n col_indices_batch_stride,\n col_indices_stride,\n mat1_ptr,\n mat1_batch_stride,\n mat1_tiled_row_stride,\n mat1_tiled_col_stride,\n mat1_row_block_stride,\n mat1_col_block_stride,\n mat2_ptr,\n mat2_batch_stride,\n mat2_tiled_row_stride,\n mat2_tiled_col_stride,\n mat2_row_block_stride,\n mat2_col_block_stride,\n acc_dtype: tl.constexpr,\n allow_tf32: tl.constexpr,\n):\n batch_pid = tl.program_id(axis=1)\n row_block_pid = tl.program_id(axis=0)\n\n crow_indices_offset_ptr = (\n crow_indices_ptr\n + crow_indices_batch_stride * batch_pid\n + crow_indices_stride * row_block_pid\n )\n nnz_offset = tl.load(crow_indices_offset_ptr)\n nnz_offset_next = tl.load(crow_indices_offset_ptr + crow_indices_stride)\n\n # Compute nnz for the row with number row_block_pid.\n # If it is zero, skip the row.\n row_nnz = nnz_offset_next - nnz_offset\n if row_nnz == 0:\n return\n\n row_block_arange = tl.arange(0, BLOCKSIZE_ROW)\n col_block_arange = tl.arange(0, BLOCKSIZE_COL)\n\n # Pointers are set to the first block of the current row.\n values_block_ptrs = (\n values_ptr\n + values_batch_stride * batch_pid\n + values_nnz_stride * nnz_offset\n + values_row_block_stride * row_block_arange[:, None]\n + values_col_block_stride * col_block_arange[None, :]\n )\n\n col_index_nnz_ptr = (\n col_indices_ptr\n + col_indices_batch_stride * batch_pid\n + col_indices_stride * nnz_offset\n )\n\n # Advance mat1 to the current tiled row, ignore columns.\n mat1_block_ptrs = (\n mat1_ptr\n + mat1_batch_stride * batch_pid\n + mat1_tiled_row_stride * row_block_pid\n + mat1_row_block_stride * row_block_arange[:, None]\n )\n\n # Advance mat2 in batch and block col dimension.\n mat2_block_ptrs = (\n mat2_ptr\n + mat2_batch_stride * batch_pid\n + mat2_col_block_stride * col_block_arange[None, :]\n )\n\n k_tile_arange = tl.arange(0, TILE_K)\n for _ in range(row_nnz):\n acc_block = tl.zeros((BLOCKSIZE_ROW, BLOCKSIZE_COL), dtype=acc_dtype)\n\n # find column block index\n col_block = tl.load(col_index_nnz_ptr)\n\n for k_tile in range(0, k, TILE_K):\n k_offsets = k_tile + k_tile_arange\n mask_k = k_offsets < k\n\n mat1_block = tl.load(\n mat1_block_ptrs\n + mat1_col_block_stride * k_offsets[None, :],\n mask=mask_k[None, :], other=0.0\n )\n\n mat2_block = tl.load(\n mat2_block_ptrs\n + mat2_tiled_col_stride * col_block\n + mat2_row_block_stride * k_offsets[:, None],\n mask=mask_k[:, None], other=0.0\n )\n\n acc_block += tl.dot(mat1_block, mat2_block, allow_tf32=allow_tf32, out_dtype=acc_dtype)\n\n if IS_BETA_ZERO:\n acc_block *= alpha\n else:\n acc_block = alpha * acc_block + beta * tl.load(values_block_ptrs)\n\n # write result\n tl.store(values_block_ptrs, acc_block.to(values_ptr.dtype.element_ty))\n\n # advance val/col_index ptrs to the next block in the row.\n values_block_ptrs += values_nnz_stride\n col_index_nnz_ptr += col_indices_stride\n\n\ndef _run_sampled_addmm_kernel(\n alpha, beta, is_beta_zero,\n blocksize, k, tile_k,\n values, crow_indices, col_indices,\n mat1, mat2,\n max_grid\n):\n n_batches = values.size(0)\n n_block_rows = crow_indices.size(-1) - 1\n\n full_grid = (n_batches, n_block_rows)\n if max_grid is not None:\n grid_blocks = tuple(max_grid[:2][::-1]) + (None,) * (2 - len(max_grid[:2]))\n else:\n grid_blocks = None\n tensor_dims_map = {\n values: (0, None),\n crow_indices: (0, -1),\n col_indices: (0, None),\n mat1: (0, -4),\n mat2: (0, None),\n }\n if values.dtype in (torch.half, torch.bfloat16):\n acc_dtype = tl.float32\n allow_tf32 = True\n else:\n acc_dtype = tl.float64\n allow_tf32 = False\n\n def kernel(grid, *sliced_tensors):\n _sampled_addmm_kernel[grid](\n alpha, beta, is_beta_zero,\n *blocksize, k, tile_k,\n *ptr_stride_extractor(*sliced_tensors),\n acc_dtype=acc_dtype,\n allow_tf32=allow_tf32,\n num_stages=1,\n num_warps=4\n )\n\n launch_kernel(kernel, tensor_dims_map, full_grid, grid_blocks)\n\n\ndef sampled_addmm(\n input: torch.Tensor,\n mat1: torch.Tensor,\n mat2: torch.Tensor,\n *,\n beta=1.0,\n alpha=1.0,\n out: Optional[torch.Tensor] = None,\n skip_checks: bool = False,\n max_grid: Optional[Tuple[Optional[int], Optional[int], Optional[int]]] = None,\n):\n f_name = \"sampled_addmm\"\n\n check_bsr_layout(f_name, input)\n input_broadcasted = broadcast_batch_dims_bsr(f_name, input, mat1, mat2)\n\n if not skip_checks:\n check_device(f_name, mat1, input.device)\n check_device(f_name, mat2, input.device)\n if beta != 0.0 and input.dtype is torch.bool:\n check(\n False,\n f\"{f_name}(): having beta == {beta} not equal to 0.0 with boolean mask is not allowed.\"\n )\n if input.dtype is not torch.bool:\n check_dtype(f_name, mat1, input.dtype)\n check_dtype(f_name, mat2, input.dtype)\n else:\n check_dtype(f_name, mat1, mat2.dtype)\n check_mm_compatible_shapes(f_name, mat1, mat2)\n if out is not None:\n check_bsr_layout(f_name, out)\n check_device(f_name, out, mat1.device)\n check_dtype(f_name, out, input.dtype)\n check(\n out.shape == input_broadcasted.shape\n and out._nnz() == input._nnz(),\n f\"{f_name}(): Expects `out` to be of shape {input_broadcasted.shape} \"\n f\"and with nnz equal to {input_broadcasted._nnz()} \"\n f\"but got out.shape = {out.shape} and out.nnz = {out._nnz()}\"\n )\n\n if out is None:\n out = input_broadcasted.to(mat1.dtype, copy=True)\n else:\n out.copy_(input_broadcasted)\n\n if out.numel() == 0 or out._nnz() == 0:\n return out\n\n blocksize = out.values().shape[-2:]\n m = mat1.size(-2)\n n = mat2.size(-1)\n k = mat1.size(-1)\n\n # NOTE: (m, 0) @ (0, n) == zeros(m, n)\n if alpha == 0.0 or k == 0:\n out.values().mul_(beta)\n return out\n\n # prepare inputs by reshaping them to be kernel-compatible\n out_backup = out\n crow_indices, col_indices, values, mat1, mat2 = prepare_inputs(out, mat1, mat2)\n\n mat1 = tile_to_blocksize(mat1, (blocksize[0], k))\n mat2 = tile_to_blocksize(mat2, (k, blocksize[1]))\n tile_k = max(*blocksize)\n\n _run_sampled_addmm_kernel(\n alpha, beta, beta == 0.0,\n blocksize, k, tile_k,\n values, crow_indices, col_indices,\n mat1, mat2,\n max_grid\n )\n\n # If nnz x block strides are not the same in out_backup.values and values,\n # it means that out_backup.values and values are not the views of each other,\n # so we have to copy.\n if out_backup.values().stride()[-3:] != values.stride()[-3:]:\n out_backup.values().copy_(values.reshape(out_backup.values().shape))\n return out_backup\n", - "description_1": "Use triton language to implement a sampled matrix multiplication kernel that multiplies a sparse BSR tensor with a dense matrix and accumulates the results using provided beta and alpha coefficients. The kernel needs to be invoked with grid launch parameters based on input tensor dimensions.", - "description_2": "Use triton language to create a kernel for sparse-dense matrix multiplication and execute it using grid-stride loop method.", - "difficulty": 5 - }, - { - "code": "import triton\nfrom triton import language as tl\nfrom triton.language import load, store\n\n# Kernel to add two arrays element-wise\n@triton.jit\ndef add_kernel(\n in_ptr0,\n in_ptr1,\n out_ptr,\n n_elements,\n BLOCK_SIZE: \"tl.constexpr\",\n):\n pid = tl.program_id(axis=0)\n block_start = pid * BLOCK_SIZE\n offsets = block_start + tl.arange(0, BLOCK_SIZE)\n mask = offsets < n_elements\n x = tl.load(in_ptr0 + offsets, mask=mask)\n y = tl.load(in_ptr1 + offsets, mask=mask)\n output = x + y\n tl.store(out_ptr + offsets, output, mask=mask)\n\n# Kernel to add two arrays element-wise with optional parameter\n@triton.jit\ndef add_kernel_with_optional_param(\n in_ptr0,\n in_ptr1,\n out_ptr,\n n_elements,\n ARGS_PASSED: \"tl.constexpr\",\n BLOCK_SIZE: \"tl.constexpr\",\n):\n pid = tl.program_id(axis=0)\n block_start = pid * BLOCK_SIZE\n offsets = block_start + tl.arange(0, BLOCK_SIZE)\n mask = offsets < n_elements\n x = tl.load(in_ptr0 + offsets, mask=mask)\n if ARGS_PASSED == \"two\":\n y = tl.load(in_ptr1 + offsets, mask=mask)\n output = x + y\n else:\n output = x\n tl.store(out_ptr + offsets, output, mask=mask)\n\n# Autotuned kernel to add two arrays element-wise\n@triton.autotune(\n configs=[\n triton.Config({\"BLOCK_SIZE\": 128}, num_stages=3, num_warps=8),\n triton.Config({\"BLOCK_SIZE\": 128}, num_stages=4, num_warps=4),\n triton.Config({\"BLOCK_SIZE\": 64}, num_stages=3, num_warps=8),\n triton.Config({\"BLOCK_SIZE\": 64}, num_stages=4, num_warps=4),\n ],\n key=[],\n)\n@triton.jit\ndef add_kernel_autotuned(\n in_ptr0,\n in_ptr1,\n out_ptr,\n n_elements,\n BLOCK_SIZE: \"tl.constexpr\",\n):\n pid = tl.program_id(axis=0)\n block_start = pid * BLOCK_SIZE\n offsets = block_start + tl.arange(0, BLOCK_SIZE)\n mask = offsets < n_elements\n x = tl.load(in_ptr0 + offsets, mask=mask)\n y = tl.load(in_ptr1 + offsets, mask=mask)\n output = x + y\n tl.store(out_ptr + offsets, output, mask=mask)\n\n# 2D Autotuned kernel to add two arrays element-wise\n@triton.autotune(\n configs=[\n triton.Config(\n {\"BLOCK_SIZE_X\": 128, \"BLOCK_SIZE_Y\": 128}, num_stages=3, num_warps=8\n ),\n triton.Config(\n {\"BLOCK_SIZE_X\": 128, \"BLOCK_SIZE_Y\": 128}, num_stages=4, num_warps=4\n ),\n triton.Config(\n {\"BLOCK_SIZE_X\": 64, \"BLOCK_SIZE_Y\": 64}, num_stages=3, num_warps=8\n ),\n triton.Config(\n {\"BLOCK_SIZE_X\": 64, \"BLOCK_SIZE_Y\": 64}, num_stages=4, num_warps=4\n ),\n ],\n key=[],\n)\n@triton.jit\ndef add_kernel_2d_autotuned(\n in_ptr0,\n in_ptr1,\n out_ptr,\n x_elements,\n y_elements,\n BLOCK_SIZE_X: \"tl.constexpr\",\n BLOCK_SIZE_Y: \"tl.constexpr\",\n):\n xoffset = tl.program_id(0) * BLOCK_SIZE_X\n xindex = xoffset + tl.arange(0, BLOCK_SIZE_X)[:, None]\n xmask = xindex < x_elements\n yoffset = tl.program_id(1) * BLOCK_SIZE_Y\n yindex = yoffset + tl.arange(0, BLOCK_SIZE_Y)[None, :]\n ymask = yindex < y_elements\n x1 = xindex\n y0 = yindex\n tmp0 = tl.load(in_ptr0 + (x1 + (x_elements * y0)), xmask & ymask)\n tmp1 = tl.load(in_ptr0 + (y0 + (y_elements * x1)), xmask & ymask)\n tmp2 = tmp0 + tmp1\n tl.store(out_ptr + (x1 + (x_elements * y0)), tmp2, xmask & ymask)\n\n# Kernel to multiply an array by 2\n@triton.jit\ndef mul2_kernel(\n in_ptr0,\n out_ptr,\n n_elements,\n BLOCK_SIZE: \"tl.constexpr\",\n):\n pid = tl.program_id(axis=0)\n block_start = pid * BLOCK_SIZE\n offsets = block_start + tl.arange(0, BLOCK_SIZE)\n mask = offsets < n_elements\n x = tl.load(in_ptr0 + offsets, mask=mask)\n output = 2 * x\n tl.store(out_ptr + offsets, output, mask=mask)\n\n# In-place kernel to multiply an array by 2\n@triton.jit\ndef mul2_inplace_kernel(\n ptr,\n n_elements,\n BLOCK_SIZE: \"tl.constexpr\",\n):\n pid = tl.program_id(axis=0)\n block_start = pid * BLOCK_SIZE\n offsets = block_start + tl.arange(0, BLOCK_SIZE)\n mask = offsets < n_elements\n x = tl.load(ptr + offsets, mask=mask)\n output = 2 * x\n tl.store(ptr + offsets, output, mask=mask)\n\n# Kernel to apply an activation function\n@triton.jit\ndef indirection_kernel(\n in_ptr0,\n out_ptr,\n n_elements,\n BLOCK_SIZE: \"tl.constexpr\",\n ACTIVATION: \"tl.constexpr\",\n):\n pid = tl.program_id(axis=0)\n block_start = pid * BLOCK_SIZE\n offsets = block_start + tl.arange(0, BLOCK_SIZE)\n mask = offsets < n_elements\n if ACTIVATION == \"mul2_inplace_kernel\":\n mul2_inplace_kernel(in_ptr0, n_elements, BLOCK_SIZE=BLOCK_SIZE)\n elif ACTIVATION == \"add_kernel\":\n add_kernel(in_ptr0, in_ptr0, out_ptr, n_elements, BLOCK_SIZE=BLOCK_SIZE)\n x = tl.load(in_ptr0 + offsets, mask=mask)\n tl.store(out_ptr + offsets, x, mask=mask)\n\n# Kernel to add two arrays element-wise with block pointers\n@triton.jit\ndef add_kernel_with_block_ptr(\n x_ptr,\n y_ptr,\n output_ptr,\n n_elements,\n BLOCK_SIZE: tl.constexpr,\n):\n pid = tl.program_id(axis=0)\n block_start = pid * BLOCK_SIZE\n x = tl.load(\n tl.make_block_ptr(\n base=x_ptr,\n shape=[n_elements],\n strides=[1],\n offsets=[block_start],\n block_shape=[BLOCK_SIZE],\n order=[0],\n ),\n boundary_check=[0],\n )\n y = tl.load(\n tl.make_block_ptr(\n base=y_ptr,\n shape=[n_elements],\n strides=[1],\n offsets=[block_start],\n block_shape=[BLOCK_SIZE],\n order=[0],\n ),\n boundary_check=[0],\n )\n output = x + y\n tl.store(\n tl.make_block_ptr(\n base=output_ptr,\n shape=[n_elements],\n strides=[1],\n offsets=[block_start],\n block_shape=[BLOCK_SIZE],\n order=[0],\n ),\n output,\n boundary_check=[0],\n )\n\n# Kernel to add two arrays element-wise with import\n@triton.jit\ndef add_kernel_with_import(\n in_ptr0,\n in_ptr1,\n out_ptr,\n n_elements,\n BLOCK_SIZE: \"tl.constexpr\",\n):\n pid = tl.program_id(axis=0)\n block_start = pid * BLOCK_SIZE\n offsets = block_start + tl.arange(0, BLOCK_SIZE)\n mask = offsets < n_elements\n x = load(in_ptr0 + offsets, mask=mask)\n y = load(in_ptr1 + offsets, mask=mask)\n output = x + y\n store(out_ptr + offsets, output, mask=mask)\n", - "description_1": "Use triton language to implement various kernels for element-wise operations on arrays. These include addition, multiplication by 2, and conditional operations. The kernels utilize block pointers and support optional parameters and autotuning for performance optimization.", - "description_2": "Use triton language to create kernels for element-wise addition and multiplication of arrays, with support for block pointers and autotuning.", - "difficulty": 2 - }, - { - "code": "import triton\nimport triton.language as tl\nfrom torch._inductor.ir import ReductionHint\nfrom torch._inductor.triton_heuristics import reduction\nfrom torch._inductor.utils import instance_descriptor\n\n@reduction(\n size_hints=[4096, 256],\n reduction_hint=ReductionHint.DEFAULT,\n filename=__file__,\n meta={\n \"signature\": {\n 0: (tl.pointer_type(tl.float32), 1),\n 1: (tl.pointer_type(tl.float32), 1),\n 2: (tl.float32, 1),\n 3: (tl.float32, 1)\n },\n \"device\": 0,\n \"device_type\": \"cuda\",\n \"constants\": {},\n \"mutated_arg_names\": [\"out\"],\n \"autotune_hints\": set(),\n \"kernel_name\": \"example_kernel\",\n \"configs\": [instance_descriptor()]\n }\n)\n@triton.jit\ndef example_kernel(in_ptr0, in_ptr1, out_ptr2, xnumel, rnumel, XBLOCK: tl.constexpr, RBLOCK: tl.constexpr):\n xoffset = tl.program_id(0) * XBLOCK\n roffset = tl.program_id(1) * RBLOCK\n rindex = roffset + tl.arange(0, RBLOCK)\n xmask = xoffset < xnumel\n rmask = rindex < rnumel\n \n xbase = xoffset + tl.arange(0, XBLOCK)\n \n out_ptr = out_ptr2 + (xbase, )\n \n if xmask:\n for _ in range(RBLOCK):\n if rmask:\n tl.store(out_ptr, tl.load(in_ptr0 + (xbase, )) + tl.load(in_ptr1 + (rindex, )))\n\ndef call_kernel_example(in_ptr0, in_ptr1, out_ptr2, xnumel, rnumel):\n grid = (xnumel // XBLOCK, rnumel // RBLOCK)\n stream = torch.cuda.current_stream(0)\n example_kernel[(grid, stream)](in_ptr0, in_ptr1, out_ptr2, xnumel, rnumel, XBLOCK=128, RBLOCK=32)\n\n", - "description_1": "Use triton language to implement a kernel called 'example_kernel'. The kernel is decorated with @triton.jit and takes six parameters. in_ptr0 and in_ptr1 are input pointers, out_ptr2 is an output pointer. xnumel and rnumel are integer arguments representing dimensions. XBLOCK and RBLOCK are compile-time constant expressions determining the block sizes. The kernel uses triton language constructs to perform element-wise addition of inputs and stores results in out_ptr2, with masking based on xnumel and rnumel dimensions.", - "description_2": "Use triton language to define a kernel with element-wise addition using input pointers, and output results with dimension-based masking.", - "difficulty": 3 - }, - { - "code": "import triton\nimport triton.language as tl\nfrom torch._inductor.triton_heuristics import foreach\nfrom torch._inductor.utils import instance_descriptor\nfrom torch._inductor import triton_helpers\n\nclass ForeachKernel:\n MAX_NUM_ARGS = 250\n\n def __init__(self):\n self.blocking_2d = False\n self.block_size_1d = 1024\n self.block_size_2d = 32\n self.num_warps = 8\n self.sub_kernels = []\n self.iter_vars_count = itertools.count()\n self.x_block_count = 0\n\n def create_sub_kernel(self, *groups, index_dtype, mutations, reduction_hint):\n sub_kernel = TritonKernel(\n *groups,\n index_dtype=index_dtype,\n mutations=mutations,\n pid_cache={\n \"tl.program_id(0)\": \"xpid_offset\",\n \"tl.program_id(1)\": \"ypid\",\n },\n reduction_hint=reduction_hint,\n )\n self.blocking_2d |= groups[1] != 1 and len(groups) == 3\n self.sub_kernels.append(sub_kernel)\n return sub_kernel\n\n def jit_line(self):\n can_use_32bit = all(k.index_dtype == \"tl.int32\" for k in self.sub_kernels)\n index_dtype = \"tl.int32\" if can_use_32bit else \"tl.int64\"\n _, _, signature = self.args.python_argdefs()\n triton_meta = {\n \"signature\": signature_to_meta(signature, size_dtype=can_use_32bit),\n \"device\": V.graph.scheduler.current_device.index,\n \"device_type\": V.graph.scheduler.current_device.type,\n \"constants\": {},\n }\n triton_meta[\"configs\"] = [config_of(signature)]\n return (\n f\"@foreach(num_warps={self.num_warps}, meta={triton_meta!r})\\n\"\n + \"@triton.jit\"\n )\n\n def grid(self):\n return (\n self.x_block_count,\n ceildiv(int(self.sub_kernels[0].numels[0]), self.block_size_2d)\n if self.blocking_2d\n else 1,\n 1,\n )\n\n def codegen_kernel(self, name=None):\n code = IndentedBuffer()\n code.splice(\n \"\"\"\n import triton\n import triton.language as tl\n from torch._inductor.triton_heuristics import foreach\n from torch._inductor.utils import instance_descriptor\n from torch._inductor import triton_helpers\n \"\"\"\n )\n argdefs, _, _ = self.args.python_argdefs()\n code.writeline(self.jit_line())\n code.writeline(f\"def {name or 'KERNEL_NAME'}({', '.join(argdefs)}):\")\n with code.indent():\n code.splice(\"xpid = tl.program_id(0)\")\n if self.blocking_2d:\n code.splice(\"ypid = tl.program_id(1)\")\n code.splice(f\"XBLOCK: tl.constexpr = {self.block_size_2d}\")\n code.splice(f\"YBLOCK: tl.constexpr = {self.block_size_2d}\")\n else:\n code.splice(f\"XBLOCK: tl.constexpr = {self.block_size_1d}\")\n\n for sub_kernel in self.sub_kernels:\n assert len(sub_kernel.numels) <= 3\n numel_ind = 0 if not self.blocking_2d else 1\n self.codegen_pid_range(code, int(sub_kernel.numels[numel_ind]))\n with code.indent():\n if self.blocking_2d:\n code.splice(f\"ynumel = {sub_kernel.numels[0]}\")\n code.splice(f\"xnumel = {sub_kernel.numels[1]}\")\n else:\n code.splice(f\"xnumel = {sub_kernel.numels[0]}\")\n\n sub_kernel.codegen_body()\n code.splice(sub_kernel.body)\n\n code.splice(\"else:\")\n with code.indent():\n code.splice(\"pass\")\n return code.getvalue()\n\n def call_kernel(self, code, name: str):\n _, call_args, _ = self.args.python_argdefs()\n for i in range(len(call_args)):\n if V.graph.is_unspec_arg(call_args[i]):\n call_args[i] = call_args[i] + \".item()\"\n if V.graph.cpp_wrapper:\n V.graph.wrapper_code.generate_kernel_call(\n name, call_args, device_index=V.graph.scheduler.current_device.index\n )\n else:\n call_args_str = \", \".join(call_args)\n stream_name = code.write_get_cuda_stream(\n V.graph.scheduler.current_device.index\n )\n code.writeline(\n f\"{name}.run({call_args_str}, grid=({self.grid()}), stream={stream_name})\"\n )\n", - "description_1": "Use triton language to define a kernel class 'ForeachKernel'. The class initializes several parameters such as block size and sub kernels. It includes methods to create sub-kernels, generate JIT lines, define grid size, generate kernel code, and call the kernel. The kernel function 'codegen_kernel' is decorated with '@triton.jit' and requires several input arguments.", - "description_2": "Use triton language to define a kernel class 'ForeachKernel' with methods for creating and managing Triton kernels, including initialization, kernel code generation, and kernel execution.", - "difficulty": 3 - }, - { - "code": "import triton\nimport triton.language as tl\nimport torch\n\n# Triton kernel for a specific operation\n@triton.jit\ndef triton_kernel_example(X, Y, Z, BLOCK_SIZE: tl.constexpr):\n pid = tl.program_id(0)\n block_start = pid * BLOCK_SIZE\n offsets = block_start + tl.arange(0, BLOCK_SIZE)\n mask = offsets < X.shape[0]\n x = tl.load(X + offsets, mask=mask)\n y = tl.load(Y + offsets, mask=mask)\n z = x + y\n tl.store(Z + offsets, z, mask=mask)\n\n# Function to call the Triton kernel\ndef call_triton_kernel_example(X, Y, Z, BLOCK_SIZE):\n grid = lambda meta: (triton.cdiv(X.shape[0], meta['BLOCK_SIZE']),)\n triton_kernel_example[grid](X, Y, Z, BLOCK_SIZE=BLOCK_SIZE)\n\n# Example usage\nX = torch.randn(1024, device='cuda')\nY = torch.randn(1024, device='cuda')\nZ = torch.empty(1024, device='cuda')\ncall_triton_kernel_example(X, Y, Z, BLOCK_SIZE=128)\n", - "description_1": "Use triton language to define a kernel 'triton_kernel_example' that performs element-wise addition of two input tensors X and Y, storing the result in tensor Z. The kernel uses a block size defined by BLOCK_SIZE and handles out-of-bounds accesses with a mask. The function 'call_triton_kernel_example' sets up the grid and launches the kernel with the specified block size.", - "description_2": "Use triton language to create a kernel for element-wise addition of two tensors with a specified block size, and a function to launch this kernel.", - "difficulty": 2 - }, - { - "code": "import triton\nimport triton.language as tl\n\n@triton.jit\ndef promote_to_tensor(x):\n # Addition promotes to tensor for us\n return x + tl.zeros((1,), tl.int1)\n\n@triton.jit\ndef is_floating(x):\n return promote_to_tensor(x).dtype.is_floating()\n\n@triton.jit\ndef _prod_accumulate(a, b):\n return a * b\n\n@triton.jit\ndef prod(input, axis):\n return tl.reduce(input, axis, _prod_accumulate)\n\n@triton.jit\ndef minimum(a, b):\n mask = a < b\n if is_floating(a):\n mask |= a != a\n return tl.where(mask, a, b)\n\n@triton.jit\ndef maximum(a, b):\n mask = a > b\n if is_floating(a):\n mask |= a != a\n return tl.where(mask, a, b)\n\n@triton.jit\ndef min2(a, dim):\n return tl.reduce(a, dim, minimum)\n\n@triton.jit\ndef max2(a, dim):\n return tl.reduce(a, dim, maximum)\n\n@triton.jit\ndef minimum_with_index(a_value, a_index, b_value, b_index):\n mask = a_value < b_value\n equal = a_value == b_value\n if is_floating(a_value):\n a_isnan = a_value != a_value\n b_isnan = b_value != b_value\n mask |= a_isnan and not b_isnan\n # Consider NaNs as equal\n equal |= a_isnan and b_isnan\n\n # Prefer lowest index if values are equal\n mask |= equal & (a_index < b_index)\n return tl.where(mask, a_value, b_value), tl.where(mask, a_index, b_index)\n\n@triton.jit\ndef maximum_with_index(a_value, a_index, b_value, b_index):\n mask = a_value > b_value\n equal = a_value == b_value\n if is_floating(a_value):\n a_isnan = a_value != a_value\n b_isnan = b_value != b_value\n mask |= a_isnan and not b_isnan\n # Consider NaNs as equal\n equal |= a_isnan and b_isnan\n\n # Prefer lowest index if values are equal\n mask |= equal & (a_index < b_index)\n return tl.where(mask, a_value, b_value), tl.where(mask, a_index, b_index)\n\n@triton.jit\ndef min_with_index(value, index, dim):\n return tl.reduce((value, index), dim, minimum_with_index)\n\n@triton.jit\ndef max_with_index(value, index, dim):\n return tl.reduce((value, index), dim, maximum_with_index)\n\n@triton.jit\ndef welford_reduce(value, mean, m2, weight):\n delta = value - mean\n new_weight = weight + 1\n new_mean = mean + delta / new_weight\n return (\n new_mean,\n m2 + delta * (value - new_mean),\n new_weight,\n )\n\n@triton.jit\ndef welford_combine(mean_1, m2_1, weight_1, mean_2, m2_2, weight_2):\n delta = mean_2 - mean_1\n new_weight = weight_1 + weight_2\n w2_over_w = tl.where(new_weight == 0.0, 0.0, weight_2 / new_weight)\n return (\n mean_1 + delta * w2_over_w,\n m2_1 + m2_2 + delta * delta * weight_1 * w2_over_w,\n new_weight,\n )\n\n@triton.jit\ndef welford(mean, m2, weight, dim):\n return tl.reduce((mean, m2, weight), dim, welford_combine)\n\n@triton.jit\ndef device_assert_then(cond, msg, r):\n tl.device_assert(cond, msg)\n return r\n\n@triton.jit\ndef randint64(seed, offset, low, high):\n r0, r1, r2, r3 = tl.randint4x(seed, offset)\n r0 = r0.to(tl.uint64)\n r1 = r1.to(tl.uint64)\n result = r0 | (r1 << 32)\n size = high - low\n result = result % size.to(tl.uint64)\n result = result.to(tl.int64) + low\n return result\n\n@triton.jit\ndef _any_combine(a, b):\n return a | b\n\n@triton.jit\ndef any(a, dim):\n return tl.reduce(a, dim, _any_combine)\n\n@triton.jit\ndef bucketize_binary_search(\n values, # 1D tensor\n offsets_ptr,\n indexing_dtype,\n right, # bool: if true, use intervals closed on the left; see [Note: Inductor bucketize op]\n OFFSETS_SIZE: int,\n BLOCK_SHAPE, # tuple/list of block shape\n):\n \"\"\"\n See [Note: Inductor bucketize op]\n \"\"\"\n\n low = tl.zeros(BLOCK_SHAPE, dtype=indexing_dtype)\n high = tl.full(BLOCK_SHAPE, OFFSETS_SIZE, dtype=indexing_dtype)\n\n full_range = OFFSETS_SIZE + 1\n while full_range > 1:\n mid = (high + low) // 2\n mask = mid < OFFSETS_SIZE\n bucket_upper_bound = tl.load(offsets_ptr + mid, mask=mask)\n if right:\n is_above = values >= bucket_upper_bound\n else:\n is_above = values > bucket_upper_bound\n\n low = tl.where(is_above & mask, mid + 1, low)\n high = tl.where(is_above, high, mid)\n\n full_range = (full_range + 1) // 2\n\n return low\n", - "description_1": "Use triton language to implement various reduction and comparison operations, including tensor promotion, floating point checks, product accumulation, minimum and maximum value calculations with and without indices, Welford reduction for variance calculation, device assertions, random integer generation, and binary search bucketization.", - "description_2": "Use triton language to perform reduction operations and comparisons, including min/max calculations and Welford variance reduction.", - "difficulty": 3 - }, - { - "code": "import math\nimport torch\nimport triton\nimport triton.language as tl\nfrom typing import Optional, Tuple\nfrom torch._inductor.cuda_properties import get_device_capability\n\ndef _has_triton():\n if not torch.cuda.is_available():\n return False\n try:\n import triton\n return triton is not None and get_device_capability() >= (7, 0)\n except ImportError:\n return False\n\nif _has_triton():\n @triton.jit\n def _sampled_addmm_kernel(\n alpha,\n beta,\n IS_BETA_ZERO: tl.constexpr,\n BLOCKSIZE_ROW: tl.constexpr,\n BLOCKSIZE_COL: tl.constexpr,\n k,\n TILE_K: tl.constexpr,\n values_ptr,\n values_batch_stride,\n values_nnz_stride,\n values_row_block_stride,\n values_col_block_stride,\n crow_indices_ptr,\n crow_indices_batch_stride,\n crow_indices_stride,\n col_indices_ptr,\n col_indices_batch_stride,\n col_indices_stride,\n mat1_ptr,\n mat1_batch_stride,\n mat1_tiled_row_stride,\n mat1_tiled_col_stride,\n mat1_row_block_stride,\n mat1_col_block_stride,\n mat2_ptr,\n mat2_batch_stride,\n mat2_tiled_row_stride,\n mat2_tiled_col_stride,\n mat2_row_block_stride,\n mat2_col_block_stride,\n acc_dtype: tl.constexpr,\n allow_tf32: tl.constexpr,\n ):\n batch_pid = tl.program_id(axis=1)\n row_block_pid = tl.program_id(axis=0)\n\n crow_indices_offset_ptr = (\n crow_indices_ptr\n + crow_indices_batch_stride * batch_pid\n + crow_indices_stride * row_block_pid\n )\n nnz_offset = tl.load(crow_indices_offset_ptr)\n nnz_offset_next = tl.load(crow_indices_offset_ptr + crow_indices_stride)\n\n # Compute nnz for the row with number row_block_pid.\n # If it is zero, skip the row.\n row_nnz = nnz_offset_next - nnz_offset\n if row_nnz == 0:\n return\n\n row_block_arange = tl.arange(0, BLOCKSIZE_ROW)\n col_block_arange = tl.arange(0, BLOCKSIZE_COL)\n\n # Pointers are set to the first block of the current row.\n values_block_ptrs = (\n values_ptr\n + values_batch_stride * batch_pid\n + values_nnz_stride * nnz_offset\n + values_row_block_stride * row_block_arange[:, None]\n + values_col_block_stride * col_block_arange[None, :]\n )\n\n col_index_nnz_ptr = (\n col_indices_ptr\n + col_indices_batch_stride * batch_pid\n + col_indices_stride * nnz_offset\n )\n\n # Advance mat1 to the current tiled row, ignore columns.\n mat1_block_ptrs = (\n mat1_ptr\n + mat1_batch_stride * batch_pid\n + mat1_tiled_row_stride * row_block_pid\n + mat1_row_block_stride * row_block_arange[:, None]\n )\n\n # Advance mat2 in batch and block col dimension.\n mat2_block_ptrs = (\n mat2_ptr\n + mat2_batch_stride * batch_pid\n + mat2_col_block_stride * col_block_arange[None, :]\n )\n\n k_tile_arange = tl.arange(0, TILE_K)\n for _ in range(row_nnz):\n acc_block = tl.zeros((BLOCKSIZE_ROW, BLOCKSIZE_COL), dtype=acc_dtype)\n\n # find column block index\n col_block = tl.load(col_index_nnz_ptr)\n\n for k_tile in range(0, k, TILE_K):\n k_offsets = k_tile + k_tile_arange\n mask_k = k_offsets < k\n\n mat1_block = tl.load(\n mat1_block_ptrs\n + mat1_col_block_stride * k_offsets[None, :],\n mask=mask_k[None, :], other=0.0\n )\n\n mat2_block = tl.load(\n mat2_block_ptrs\n + mat2_tiled_col_stride * col_block\n + mat2_row_block_stride * k_offsets[:, None],\n mask=mask_k[:, None], other=0.0\n )\n\n acc_block += tl.dot(mat1_block, mat2_block, allow_tf32=allow_tf32)\n\n if IS_BETA_ZERO:\n acc_block *= alpha\n else:\n acc_block = alpha * acc_block + beta * tl.load(values_block_ptrs)\n\n # write result\n tl.store(values_block_ptrs, acc_block.to(values_ptr.dtype.element_ty))\n\n # advance val/col_index ptrs to the next block in the row.\n values_block_ptrs += values_nnz_stride\n col_index_nnz_ptr += col_indices_stride\n\n @triton.jit\n def _bsr_strided_dense_rowspace_kernel(\n BLOCKSIZE_ROW: tl.constexpr,\n BLOCKSIZE_COL: tl.constexpr,\n # values prologue\n values_ptr,\n values_batch_stride,\n values_nnz_stride,\n values_row_block_stride,\n values_col_block_stride,\n # values epilogue\n # crow_indices prologue\n crow_indices_ptr,\n crow_indices_batch_stride,\n crow_indices_stride,\n # crow_indices epilogue\n # col_indices prologue\n col_indices_ptr,\n col_indices_batch_stride,\n col_indices_stride,\n # col_indices epilogue\n # dense prologue\n dense_ptr,\n dense_batch_stride,\n dense_tiled_row_stride,\n dense_tiled_col_stride,\n dense_row_block_stride,\n dense_col_block_stride,\n # dense epilogue\n # output prologue\n output_ptr,\n output_batch_stride,\n output_tiled_row_stride,\n output_tiled_col_stride,\n output_row_block_stride,\n output_col_block_stride,\n # output epilogue\n acc_dtype: tl.constexpr,\n allow_tf32: tl.constexpr,\n GROUP_SIZE_ROW: tl.constexpr,\n ):\n batch_pid = tl.program_id(axis=2)\n row_block_pid = tl.program_id(axis=0)\n col_block_pid = tl.program_id(axis=1)\n n_block_rows = tl.num_programs(axis=0)\n n_block_cols = tl.num_programs(axis=1)\n\n row_block_pid, col_block_pid = tl.swizzle2d(\n row_block_pid, col_block_pid, n_block_rows, n_block_cols, GROUP_SIZE_ROW\n )\n\n crow_indices_offset_ptr = (\n crow_indices_ptr\n + crow_indices_batch_stride * batch_pid\n + crow_indices_stride * row_block_pid\n )\n nnz_offset = tl.load(crow_indices_offset_ptr)\n nnz_offset_next = tl.load(crow_indices_offset_ptr + crow_indices_stride)\n\n # Compute nnz for the row with number row_block_pid.\n # If it is zero, skip the row.\n row_nnz = nnz_offset_next - nnz_offset\n if row_nnz == 0:\n return\n\n row_block_arange = tl.arange(0, BLOCKSIZE_ROW)\n col_block_arange = tl.arange(0, BLOCKSIZE_COL)\n\n # Pointers are set to the first block of the current row.\n values_block_ptrs = (\n values_ptr\n + values_batch_stride * batch_pid\n + values_nnz_stride * nnz_offset\n + values_row_block_stride * row_block_arange[:, None]\n + values_col_block_stride * col_block_arange[None, :]\n )\n\n # NOTE: dense is advanced into all dimensions but the tiled row one.\n # That will be advanced in the loop according to values in col_indices.\n dense_block_ptrs = (\n dense_ptr\n + dense_batch_stride * batch_pid\n + dense_tiled_col_stride * col_block_pid\n + dense_row_block_stride * col_block_arange[:, None]\n + dense_col_block_stride * row_block_arange[None, :]\n )\n\n # Pointers are set to exact write-to locations\n output_ptrs = (\n output_ptr\n + output_batch_stride * batch_pid\n + output_tiled_row_stride * row_block_pid\n + output_tiled_col_stride * col_block_pid\n + output_row_block_stride * row_block_arange[:, None]\n + output_col_block_stride * row_block_arange[None, :]\n )\n\n # Set pointer to the first nonzero element in the current row\n col_index_nnz_ptr = (\n col_indices_ptr\n + col_indices_batch_stride * batch_pid\n + col_indices_stride * nnz_offset\n )\n\n output_acc_block = tl.zeros((BLOCKSIZE_ROW, BLOCKSIZE_ROW), dtype=acc_dtype)\n for _ in range(row_nnz):\n values_block = tl.load(values_block_ptrs)\n\n # find which row of dense needs to get loaded\n # for multiplication with values_block.\n dense_row_idx = tl.load(col_index_nnz_ptr)\n dense_block = tl.load(dense_block_ptrs + dense_tiled_row_stride * dense_row_idx)\n\n # do block mm\n output_acc_block += tl.dot(values_block, dense_block, allow_tf32=allow_tf32)\n\n # move val/col_index ptrs to the next block in the row\n values_block_ptrs += values_nnz_stride\n col_index_nnz_ptr += col_indices_stride\n\n # write back the result\n tl.store(output_ptrs, output_acc_block.to(output_ptr.dtype.element_ty))\n\n\n def _run_dense_rowspace_kernel(\n blocksize, values, crow_indices, col_indices, dense, output, max_grid\n ):\n n_batches = dense.size(0)\n n_block_rows = crow_indices.size(-1) - 1\n n_block_cols = dense.size(-3)\n\n full_grid = (n_batches, n_block_cols, n_block_rows)\n if max_grid is not None:\n grid_blocks = tuple(max_grid[:3][::-1]) + (None,) * (3 - len(max_grid[:3]))\n else:\n grid_blocks = None\n tensor_dims_map = {\n values: (0, None, None),\n crow_indices: (0, None, -1),\n col_indices: (0, None, None),\n dense: (0, -3, None),\n output: (0, -3, -4)\n }\n if values.dtype in (torch.half, torch.bfloat16):\n acc_dtype = tl.float32\n allow_tf32 = True\n else:\n acc_dtype = tl.float64\n allow_tf32 = False\n\n def kernel(grid, *sliced_tensors):\n _bsr_strided_dense_rowspace_kernel[grid](\n *blocksize,\n *ptr_stride_extractor(*sliced_tensors),\n acc_dtype=acc_dtype,\n allow_tf32=allow_tf32,\n GROUP_SIZE_ROW=4,\n num_stages=1,\n num_warps=4\n )\n\n launch_kernel(kernel, tensor_dims_map, full_grid, grid_blocks)\n\n\n def _run_sampled_addmm_kernel(\n alpha, beta, is_beta_zero,\n blocksize, k, tile_k,\n values, crow_indices, col_indices,\n mat1, mat2,\n max_grid\n ):\n n_batches = values.size(0)\n n_block_rows = crow_indices.size(-1) - 1\n\n full_grid = (n_batches, n_block_rows)\n if max_grid is not None:\n grid_blocks = tuple(max_grid[:2][::-1]) + (None,) * (2 - len(max_grid[:2]))\n else:\n grid_blocks = None\n tensor_dims_map = {\n values: (0, None),\n crow_indices: (0, -1),\n col_indices: (0, None),\n mat1: (0, -4),\n mat2: (0, None),\n }\n if values.dtype in (torch.half, torch.bfloat16):\n acc_dtype = tl.float32\n allow_tf32 = True\n else:\n acc_dtype = tl.float64\n allow_tf32 = False\n\n def kernel(grid, *sliced_tensors):\n _sampled_addmm_kernel[grid](\n alpha, beta, is_beta_zero,\n *blocksize, k, tile_k,\n *ptr_stride_extractor(*sliced_tensors),\n acc_dtype=acc_dtype,\n allow_tf32=allow_tf32,\n num_stages=1,\n num_warps=4\n )\n\n launch_kernel(kernel, tensor_dims_map, full_grid, grid_blocks)\n\n\n def sampled_addmm(\n input: torch.Tensor,\n mat1: torch.Tensor,\n mat2: torch.Tensor,\n *,\n beta=1.0,\n alpha=1.0,\n out: Optional[torch.Tensor] = None,\n skip_checks: bool = False,\n max_grid: Optional[Tuple[Optional[int], Optional[int], Optional[int]]] = None,\n ):\n f_name = \"sampled_addmm\"\n\n check_bsr_layout(f_name, input)\n input_broadcasted = broadcast_batch_dims_bsr(f_name, input, mat1, mat2)\n\n if not skip_checks:\n check_device(f_name, mat1, input.device)\n check_device(f_name, mat2, input.device)\n if beta != 0.0 and input.dtype is torch.bool:\n check(\n False,\n f\"{f_name}(): having beta == {beta} not equal to 0.0 with boolean mask is not allowed.\"\n )\n if input.dtype is not torch.bool:\n check_dtype(f_name, mat1, input.dtype)\n check_dtype(f_name, mat2, input.dtype)\n else:\n check_dtype(f_name, mat1, mat2.dtype)\n check_mm_compatible_shapes(f_name, mat1, mat2)\n if out is not None:\n check_bsr_layout(f_name, out)\n check_device(f_name, out, mat1.device)\n check_dtype(f_name, out, input.dtype)\n check(\n out.shape == input_broadcasted.shape\n and out._nnz() == input._nnz(),\n f\"{f_name}(): Expects `out` to be of shape {input_broadcasted.shape} \"\n f\"and with nnz equal to {input_broadcasted._nnz()} \"\n f\"but got out.shape = {out.shape} and out.nnz = {out._nnz()}\"\n )\n\n if out is None:\n out = input_broadcasted.to(mat1.dtype, copy=True)\n else:\n out.copy_(input_broadcasted)\n\n if out.numel() == 0 or out._nnz() == 0:\n return out\n\n blocksize = out.values().shape[-2:]\n m = mat1.size(-2)\n n = mat2.size(-1)\n k = mat1.size(-1)\n\n # NOTE: (m, 0) @ (0, n) == zeros(m, n)\n if alpha == 0.0 or k == 0:\n out.values().mul_(beta)\n return out\n\n # prepare inputs by reshaping them to be kernel-compatible\n out_backup = out\n crow_indices, col_indices, values, mat1, mat2 = prepare_inputs(out, mat1, mat2)\n\n mat1 = tile_to_blocksize(mat1, (blocksize[0], k))\n mat2 = tile_to_blocksize(mat2, (k, blocksize[1]))\n tile_k = max(*blocksize)\n\n _run_sampled_addmm_kernel(\n alpha, beta, beta == 0.0,\n blocksize, k, tile_k,\n values, crow_indices, col_indices,\n mat1, mat2,\n max_grid\n )\n\n # If nnz x block strides are not the same in out_backup.values and values,\n # it means that out_backup.values and values are not the views of each other,\n # so we have to copy.\n if out_backup.values().stride()[-3:] != values.stride()[-3:]:\n out_backup.values().copy_(values.reshape(out_backup.values().shape))\n return out_backup\n\n\n def bsr_dense_mm(\n bsr: torch.Tensor,\n dense: torch.Tensor,\n *,\n out: Optional[torch.Tensor] = None,\n skip_checks: bool = False,\n max_grid: Optional[Tuple[Optional[int], Optional[int], Optional[int]]] = None,\n ):\n f_name = \"bsr_dense_mm\"\n if not skip_checks:\n check_bsr_layout(f_name, bsr)\n check_device(f_name, bsr, dense.device)\n check_dtype(f_name, bsr, dense.dtype)\n check_mm_compatible_shapes(f_name, bsr, dense)\n\n m = bsr.size(-2)\n n = dense.size(-1)\n row_block, col_block = bsr.values().shape[-2:]\n check(\n not n % row_block,\n f\"bsr_dense_mm(): dense.size(-1) == {n} should be divisible by \"\n f\"blocksize[0] == {row_block}.\",\n )\n check_blocksize(f_name, (row_block, col_block))\n else:\n m, kl = bsr.shape[-2:]\n kr, n = dense.shape[-2:]\n\n original_batch_dims_broadcasted = broadcast_batch_dims(f_name, bsr, dense)\n\n if out is not None and not skip_checks:\n expected_out_shape = original_batch_dims_broadcasted + (m, n)\n check(\n out.shape == expected_out_shape,\n \"bsr_dense_mm(): `out` argument has wrong shape, \"\n f\"expected {expected_out_shape}, but got {out.shape}.\",\n )\n check(\n out.is_contiguous() or out.transpose(-2, -1).is_contiguous(),\n \"bsr_dense_mm(): only row-major/col-major `out` arguments are supported, \"\n \"i.e. (out.is_contiguous() or out.transpose(-2, -1).is_contiguous()) \"\n \"should be True.\",\n )\n\n # Allocate out\n if out is None:\n out = dense.new_empty(original_batch_dims_broadcasted + (m, n))\n\n # Short circuit if lhs is zero\n if bsr._nnz() == 0:\n return out.zero_()\n\n blocksize = bsr.values().shape[-2:]\n\n # NOTE: out is contiguous, so prepare_inputs will create a view.\n # out gets modified in-place, so we store a backup copy.\n out_backup = out\n\n # prepare inputs by reshaping them to be kernel-compatible.\n crow_indices, col_indices, values, dense, out = prepare_inputs(bsr, dense, out)\n\n # \"Blockify\" the row dimension of dense with blocksize[1]\n # since dense is on the rhs of matmul\n dense = tile_to_blocksize(dense, blocksize[::-1])\n # \"Blockify\" the row dimension of out with blocksize[0]\n # which is inherited from the bsr input.\n # NOTE: tile_to_blocksize will create a view.\n # NOTE: out.blocksize[-1] == dense.blocksize[-1],\n # so it could be any value in [1, dense.shape[-1]).\n # We need to probably use the largest possible blocksize\n # so that it fits into SRAM.\n out = tile_to_blocksize(out, (blocksize[0], blocksize[0]))\n\n # Launch kernel\n _run_dense_rowspace_kernel(blocksize, values, crow_indices, col_indices, dense, out, max_grid)\n\n return out_backup\n\n @triton.jit\n def _bsr_softmax_kernel(\n crow_indices_ptr,\n crow_indices_batch_stride,\n crow_indices_stride,\n values_ptr,\n values_batch_stride,\n values_row_block_stride,\n values_nnz_col_block_stride,\n row_block, col_block,\n MAX_ROW_NNZ: tl.constexpr,\n TILE: tl.constexpr\n ):\n batch_pid = tl.program_id(axis=2)\n row_block_offset_pid = tl.program_id(axis=1)\n row_block_pid = tl.program_id(axis=0)\n\n crow_indices_offset_ptr = (\n crow_indices_ptr\n + crow_indices_batch_stride * batch_pid\n + crow_indices_stride * row_block_pid\n )\n nnz_offset = tl.load(crow_indices_offset_ptr)\n nnz_offset_next = tl.load(crow_indices_offset_ptr + crow_indices_stride)\n\n # Compute nnz for the row with number row_block_pid.\n # If it is zero, skip the row.\n row_nnz = nnz_offset_next - nnz_offset\n if row_nnz == 0:\n return\n\n row_arange = tl.arange(0, TILE)\n mask = row_arange < row_nnz * col_block\n\n curr_row_values_ptrs = (\n values_ptr\n + values_batch_stride * batch_pid\n + values_row_block_stride * row_block_offset_pid\n + nnz_offset * col_block\n )\n\n # find max in the row\n row_tile = tl.load(curr_row_values_ptrs + row_arange, mask=mask, other=-float('inf')).to(tl.float32)\n max_row_value = tl.max(row_tile, axis=0)\n for _ in range(TILE, MAX_ROW_NNZ, TILE):\n row_arange += TILE\n mask = row_arange < row_nnz * col_block\n row_tile = tl.load(curr_row_values_ptrs + row_arange, mask=mask, other=-float('inf')).to(tl.float32)\n curr_max_row_value = tl.max(row_tile, axis=0)\n max_row_value = tl.where(max_row_value > curr_max_row_value, max_row_value, curr_max_row_value)\n\n # find denominator for stable softmax\n num = tl.exp(row_tile - max_row_value)\n denom = tl.sum(num, axis=0)\n for _ in range(TILE, MAX_ROW_NNZ, TILE):\n row_arange -= TILE\n mask = row_arange < row_nnz * col_block\n row_tile = tl.load(curr_row_values_ptrs + row_arange, mask=mask, other=-float('inf')).to(tl.float32)\n num = tl.exp(row_tile - max_row_value)\n denom += tl.sum(num, axis=0)\n\n # populate output\n tl.store(curr_row_values_ptrs + row_arange, (num / denom).to(values_ptr.dtype.element_ty), mask=mask)\n for _ in range(TILE, MAX_ROW_NNZ, TILE):\n row_arange += TILE\n mask = row_arange < row_nnz * col_block\n row_tile = tl.load(curr_row_values_ptrs + row_arange, mask=mask, other=-float('inf')).to(tl.float32)\n num = tl.exp(row_tile - max_row_value)\n tl.store(curr_row_values_ptrs + row_arange, (num / denom).to(values_ptr.dtype.element_ty), mask=mask)\n\n\n def bsr_softmax(input, max_row_nnz=None):\n f_name = \"bsr_softmax\"\n\n check_bsr_layout(f_name, input)\n check_dtype(f_name, input, input.dtype)\n\n if input._nnz() == 0 or input.numel() == 0:\n return input.clone()\n\n m, n = input.shape[-2:]\n nnz = input._nnz()\n row_block, col_block = input.values().shape[-2:]\n\n if max_row_nnz is None:\n max_row_nnz = triton.next_power_of_2(n)\n else:\n max_row_nnz = triton.next_power_of_2(max_row_nnz)\n\n crow_indices = input.crow_indices().unsqueeze(0).flatten(0, -2)\n # reshape values from\n # (b1, ..., bn, nnz, row_block, col_block) to\n # (b1 * ... * bn, row_block, nnz * col_block).\n # This simplifies batch dim manipulation and unlocks\n # the possibility to access all nnzs in any given row.\n if input.values().transpose(-3, -2).is_contiguous():\n # Need to clone to avoid `contiguous` returning a view.\n values = input.values().clone()\n else:\n values = input.values()\n values = values.transpose(-3, -2).contiguous().unsqueeze(0).flatten(0, -4).reshape(-1, row_block, nnz * col_block)\n full_grid = (values.shape[0], row_block, m // row_block)\n grid_blocks = None\n tensor_dims_map = {\n # We span nnz number of blocks, not nnz + 1,\n # hence crow_indices[..., :-1]\n crow_indices[..., :-1]: (0, None, -1),\n values: (0, None, None),\n }\n\n def kernel(grid, *sliced_tensors):\n _bsr_softmax_kernel[grid](\n *ptr_stride_extractor(*sliced_tensors),\n row_block, col_block,\n max_row_nnz,\n # Triton's max numel is bounded by 2 ** 17.\n min(2 ** 17, max_row_nnz)\n )\n\n launch_kernel(kernel, tensor_dims_map, full_grid, grid_blocks)\n\n values = values.reshape(-1, row_block, nnz, col_block).transpose(-3, -2).reshape(*input.values().shape)\n\n return torch.sparse_compressed_tensor(\n input.crow_indices().clone(),\n input.col_indices().clone(),\n values,\n size=input.shape,\n layout=input.layout\n )\n\n def _scaled_dot_product_attention(\n query: torch.Tensor,\n key: torch.Tensor,\n value: torch.Tensor,\n attn_mask: Optional[torch.Tensor],\n dropout_p: float = 0.0,\n is_causal: bool = False,\n scale: Optional[float] = None\n ):\n f_name = \"_scaled_dot_product_attention\"\n check(\n not is_causal,\n f\"{f_name}(): is_causal == True is not supported.\"\n )\n check(\n attn_mask is not None,\n f\"{f_name}(): attn_mask == None is not supported.\"\n )\n assert attn_mask is not None\n\n check(\n attn_mask.layout == torch.sparse_bsr,\n f\"{f_name}(): \"\n f\"attn_mask.layout must be {torch.sparse_bsr}, but got \"\n f\"attn_mask.layout == {attn_mask.layout}.\"\n )\n\n check_device(f_name, key, query.device)\n check_device(f_name, value, query.device)\n check_device(f_name, attn_mask, query.device)\n\n check_dtype(f_name, key, query.dtype)\n check_dtype(f_name, value, query.dtype)\n if attn_mask.dtype is not torch.bool:\n check_dtype(f_name, attn_mask, query.dtype)\n\n sdpa = sampled_addmm(attn_mask, query, key.transpose(-2, -1), beta=0.0, skip_checks=False)\n if scale is None and query.size(-1) == 0 or scale == 0.0:\n check(\n False,\n f\"{f_name}(): current value of scale == {scale} \"\n \"results in division by zero.\"\n )\n scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale\n sdpa.values().mul_(scale_factor)\n sdpa = bsr_softmax(sdpa)\n torch.nn.functional.dropout(sdpa.values(), p=dropout_p, inplace=True)\n sdpa = bsr_dense_mm(sdpa, value)\n return sdpa\nelse:\n bsr_softmax = None # type: ignore[assignment]\n bsr_dense_mm = None # type: ignore[assignment]\n sampled_addmm = None # type: ignore[assignment]\n _scaled_dot_product_attention = None # type: ignore[assignment]\n", - "description_1": "Use triton language to implement three kernels: (1) A sampled matrix multiplication kernel, '_sampled_addmm_kernel', that takes 32 arguments, including scalar factors, block sizes, matrix pointers, and configuration constants. (2) A BSR strided dense multiplication kernel, '_bsr_strided_dense_rowspace_kernel', which multiplies sparse BSR matrices with dense matrices using 30 arguments that specify matrix data and layout details. (3) A BSR softmax kernel, '_bsr_softmax_kernel', that computes the softmax of sparse matrices stored in block row format, using 11 arguments including pointers to data and block size constants. Three wrapper functions, 'sampled_addmm', 'bsr_dense_mm', and 'bsr_softmax', each use these kernels to perform high-level matrix operations by managing input validation, data preparation, and kernel launch configurations with up to 12 parameters for managing tensor data and execution options.", - "description_2": "Use triton language to create an attention mechanism with sparse matrices. Implement three primary kernels: a sampled matrix multiplication, a BSR strided multiplication, and a BSR softmax. Provide high-level functions to manage input validation, data preparation, and kernel execution for advanced sparse-dense computations.", - "difficulty": 4 - }, - { - "code": "import triton\nimport triton.language as tl\n\n@triton.jit\ndef _kernel(A, B, C, stride_za, stride_ha, stride_ma, stride_ka, stride_zb, stride_hb, stride_kb, stride_nb, stride_zc,\n stride_hc, stride_mc, stride_nc, DS0, DS1, SDD_K, SDD_off_width, lut, locks, nlocks, **meta):\n TM = meta['TM']\n TN = meta['TN']\n TK = meta['TK']\n TZ = meta['TZ']\n BLOCK = meta['BLOCK']\n #------------#\n #- Prologue -#\n #------------#\n pid0 = tl.program_id(0)\n pid1 = tl.program_id(1)\n pidz = tl.program_id(2)\n if meta['SDD']:\n pid1 = pid1 + SDD_off_width\n blockidm = tl.arange(0, TM) // BLOCK\n blockidn = tl.arange(0, TN) // BLOCK\n offlutm = blockidm * (TN // BLOCK) * 4\n offlutn = blockidn * 4\n header = lut + pid1 * (TM // BLOCK) * (TN // BLOCK) * 4\n z = tl.load(header + 0)\n i = tl.load(header + 1 + offlutm)\n j = tl.load(header + 2 + offlutn)\n AS1 = SDD_K // TZ\n lockid = tl.where(TZ > 1, 1, 0)\n offka = pid0 * AS1\n offkb = pid0 * AS1\n offmc = 0\n offnc = 0\n offpa = 0\n offpb = 0\n maxid = TZ\n offhc = 0\n offha = z\n offhb = z\n ram = i * BLOCK + (tl.arange(0, TM) % BLOCK)\n rbn = j * BLOCK + (tl.arange(0, TN) % BLOCK)\n else:\n header = lut + pid0 * 6\n offset = tl.load(header + 0)\n AS1 = tl.load(header + 1)\n column = tl.load(header + 2)\n depth = tl.load(header + 3)\n lockid = tl.load(header + 4)\n maxid = tl.load(header + 5)\n pinc = lut + offset\n offhc = depth\n if meta['DSD']:\n # output offset\n offnc = pid1 * TN\n offmc = column * TM\n offpc = 0\n # dense input offset\n offnb = pid1 * TN\n offkb = tl.load(pinc)\n offkb = tl.multiple_of(offkb, 8) # compiler hint\n offpb = 0\n # sparse input offset\n offma = 0\n offka = 0\n offpa = tl.load(pinc + 1)\n offpa = tl.multiple_of(offpa, 8) # compiler hint\n offpa = offpa * BLOCK * BLOCK\n offha = 0\n offhb = depth\n else:\n # output offset\n offmc = pid1 * TM\n offnc = column * TN\n offpc = 0\n # dense input offset\n offma = pid1 * TM\n offka = tl.load(pinc)\n offka = tl.multiple_of(offka, 8) # compiler hint\n offpa = 0\n # sparse input offset\n offnb = 0\n offkb = 0\n offpb = tl.load(pinc + 1)\n offpb = tl.multiple_of(offpb, 8) # compiler hint\n offpb = offpb * BLOCK * BLOCK\n offha = depth\n offhb = 0\n ram = offma + tl.arange(0, TM)\n rbn = offnb + tl.arange(0, TN)\n\n # initialize a, b pointers\n rka = offka + tl.arange(0, TK)\n rkb = offkb + tl.arange(0, TK)\n pa = A + pidz * stride_za + offha * stride_ha + offpa + ram[:, None] * stride_ma + rka[None, :] * stride_ka\n pb = B + pidz * stride_zb + offhb * stride_hb + offpb + rbn[None, :] * stride_nb + rkb[:, None] * stride_kb\n if meta['DDS']:\n checkam = ram[:, None] < DS0\n else:\n checkam = AS1 > 0\n if meta['DSD']:\n checkbn = rbn[None, :] < DS0\n else:\n checkbn = AS1 > 0\n a = tl.load(pa, mask=checkam, other=0.)\n b = tl.load(pb, mask=checkbn, other=0.)\n\n ## ---------------- ##\n ## Inner Loop ##\n ## ---------------- ##\n acc = tl.zeros((TM, TN), dtype=tl.float32)\n for k in range(AS1, 0, -TK):\n acc += tl.dot(a, b)\n if meta['SDD']:\n inc_a = TK * stride_ka\n inc_b = TK * stride_kb\n else:\n pinc += 2\n if meta['DSD']:\n inc_b = tl.load(pinc)\n inc_a = tl.load(pinc + 1)\n inc_b = tl.multiple_of(inc_b, 8)\n inc_a = tl.multiple_of(inc_a, 8)\n inc_b = inc_b * stride_kb\n if meta['DDS']:\n inc_a = tl.load(pinc)\n inc_b = tl.load(pinc + 1)\n inc_a = tl.multiple_of(inc_a, 8)\n inc_b = tl.multiple_of(inc_b, 8)\n inc_a = inc_a * stride_ka\n pa += inc_a\n pb += inc_b\n # pre-fetch\n checkak = k > TK\n checkbk = k > TK\n checka = checkam & checkak\n checkb = checkbn & checkbk\n a = tl.load(pa, mask=checka)\n b = tl.load(pb, mask=checkb)\n c = acc.to(C.dtype.element_ty)\n\n if meta['SDD']:\n checkc = True\n rr_blockidm = tl.arange(0, TM) // BLOCK\n rr_blockidn = tl.arange(0, TN) // BLOCK\n rr_offlutm = rr_blockidm * (TN // BLOCK) * 4\n rr_offlutn = rr_blockidn * 4\n off_bkid = 3 + rr_offlutm[:, None] + rr_offlutn[None, :]\n bkid = tl.load(header + off_bkid)\n offpc = bkid * BLOCK * BLOCK\n rcm = tl.arange(0, TM) % BLOCK\n rcn = tl.arange(0, TN) % BLOCK\n else:\n rcm = offmc + tl.arange(0, TM)\n rcn = offnc + tl.arange(0, TN)\n if meta['DSD']:\n checkc = rcn[None, :] < DS0\n if meta['DDS']:\n checkc = rcm[:, None] < DS0\n\n pc = C + offpc + offhc * stride_hc + pidz * stride_zc + rcm[:, None] * stride_mc + rcn[None, :] * stride_nc\n # write-back directly\n if lockid == 0:\n tl.store(pc, c, mask=checkc)\n # accumulate partial results using spin-locks\n else:\n plock = locks + tl.program_id(2) * nlocks * tl.num_programs(1) + tl.program_id(1) * nlocks + lockid - 1\n pcount = plock + tl.num_programs(2) * tl.num_programs(1) * nlocks\n while tl.atomic_cas(plock, 0, 1) == 1:\n pass\n count = tl.load(pcount)\n if count == 0:\n tl.store(pc, c, mask=checkc)\n else:\n d = tl.load(pc, mask=checkc)\n tl.store(pc, d + c, mask=checkc)\n tl.atomic_xchg(pcount, (count + 1) % maxid)\n tl.atomic_xchg(plock, 0)\n", - "description_1": "Use triton language to create a kernel function with 22 positional arguments and variadic keyword arguments to perform advanced block sparse matrix multiplication with consideration for different blocking schemes and sparse/dense/dense modes, supporting locking and conditional loading. Specifically, it utilizes 3 program ids for parallel execution over different dimensions, and complex control flow to handle different data organizations in memory.", - "description_2": "Use triton language to perform block sparse matrix multiplication with advanced prologue, inner loop, and write-back stages, supporting multiple block configurations and sparse/dense combinations.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\ndef next_power_of_2(n):\n n -= 1\n n |= n >> 1\n n |= n >> 2\n n |= n >> 4\n n |= n >> 8\n n |= n >> 16\n n += 1\n return n\n\ndef num_warps(n):\n if n < 512:\n return 4\n if n < 2048:\n return 8\n return 16\n\n@triton.jit\ndef _forward(X, scale, LUT, RPE, KP_M, ATTN_M, sizemax, stride_zx, stride_zrpe, stride_hrpe, stride_srpe, stride_zkpm,\n stride_zattnm, **meta):\n TN = meta['TN']\n BLOCK = meta['BLOCK']\n pidhm = tl.program_id(0)\n pidz = tl.program_id(1)\n # create index ranges\n rxm = pidhm % BLOCK\n rbm = pidhm // BLOCK\n rxn = tl.arange(0, TN) % BLOCK\n rbn = tl.arange(0, TN) // BLOCK\n # extract information from LUT\n header = LUT + rbm * 2\n size = tl.load(header + 0)\n offset = tl.load(header + 1)\n check = rbn < size\n rbmn = tl.where(check, rbn, size - 1)\n # block id and column id\n blockid = tl.load(LUT + offset + rbmn * 4 + 0)\n columnid = tl.load(LUT + offset + rbmn * 4 + 1)\n rowid = tl.load(LUT + offset + rbmn * 4 + 2)\n headid = tl.load(LUT + offset + rbmn * 4 + 3)\n # pointers to X\n px = X + pidz * stride_zx + blockid * BLOCK * BLOCK + rxm * BLOCK + rxn\n x = tl.load(px, mask=check, other=-float('inf'))\n x = x.to(tl.float32)\n # apply scale\n if meta['APPLY_SCALE']:\n x = x * scale\n # apply RPE\n if meta['APPLY_RPE']:\n prpe = RPE + pidz * stride_zrpe + headid * stride_hrpe + columnid * BLOCK + rowid * BLOCK * stride_srpe + rxm * stride_srpe + rxn\n rpe = tl.load(prpe, mask=check, other=0)\n x = x + rpe\n # apply key-padding mask\n if meta['APPLY_KP_MASK']:\n pkp_m = KP_M + pidz * stride_zkpm + columnid * BLOCK + rxn\n kp_m = tl.load(pkp_m, mask=check, other=-float('inf'))\n if meta['KP_MASK_MUL']:\n kp_m = tl.where(kp_m == 0, -float('inf'), 0.)\n x = x + kp_m\n # apply attention mask\n if meta['APPLY_ATTN_MASK']:\n pattn_m = ATTN_M + columnid * BLOCK + rowid * BLOCK * stride_zattnm + rxm * stride_zattnm + rxn\n attn_m = tl.load(pattn_m, mask=check, other=-float('inf'))\n if meta['ATTN_MASK_MUL']:\n attn_m = tl.where(attn_m == 0, -float('inf'), 0.)\n x = x + attn_m\n # computation\n x = tl.softmax(x)\n tl.store(px, x, mask=check)\n\n@triton.jit\ndef _backward(X, scale, DX, LUT, sizemax, stride_zx, stride_zdx, **meta):\n pidhm = tl.program_id(0)\n pidz = tl.program_id(1)\n TN = meta['TN']\n BLOCK = meta['BLOCK']\n # create index ranges\n rxm = pidhm % BLOCK\n rbm = pidhm // BLOCK\n rxn = tl.arange(0, TN) % BLOCK\n rbn = tl.arange(0, TN) // BLOCK\n # extract information from look-up table\n header = LUT + rbm * 2\n size = tl.load(header + 0)\n offset = tl.load(header + 1)\n # bounds checking on lut\n check = rbn < size\n rbmn = tl.where(check, rbn, size - 1)\n # initialize pointers to block-sparse input\n blockid = tl.load(LUT + offset + rbmn * 4)\n X = X + pidz * stride_zx + blockid * BLOCK * BLOCK + rxm * BLOCK + rxn\n DX = DX + pidz * stride_zdx + blockid * BLOCK * BLOCK + rxm * BLOCK + rxn\n # compute fused softmax backward\n x = tl.load(X, mask=check, other=0)\n dx = tl.load(DX, mask=check, other=0)\n x = x.to(tl.float32)\n dx = dx.to(tl.float32)\n y = x * (dx - tl.sum(x * dx, 0)) * scale\n tl.store(DX, y, mask=check)\n\nclass _sparse_softmax(torch.autograd.Function):\n\n @staticmethod\n def forward(ctx, x, scale, rpe, key_padding_mask, attn_mask, kp_mask_mode, attn_mask_mode, spdims, block, lut,\n num_blocks, maxlut, bench, time):\n\n apply_scale = False if scale == 1.0 else True\n\n # handle None rpe\n if rpe is None:\n apply_rpe = False\n stride_zrpe, stride_hrpe, stride_srpe = 0, 0, 0\n rpe = torch.empty(0, dtype=x.dtype, device=x.device)\n else:\n apply_rpe = True\n stride_zrpe, stride_hrpe, stride_srpe = rpe.stride(0), rpe.stride(1), rpe.stride(2)\n\n # handle None key_padding_mask\n if key_padding_mask is None:\n apply_kp_mask = False\n stride_zkpm = 0\n key_padding_mask = torch.empty(0, dtype=x.dtype, device=x.device)\n else:\n apply_kp_mask = True\n stride_zkpm = key_padding_mask.stride(0)\n\n # handle None attention_mask\n if attn_mask is None:\n apply_attn_mask = False\n stride_zattnm = 0\n attn_mask = torch.empty(0, dtype=x.dtype, device=x.device)\n else:\n apply_attn_mask = True\n stride_zattnm = attn_mask.stride(0)\n\n # run kernel\n M = x.shape[0]\n meta = {\n 'BLOCK': block,\n 'APPLY_SCALE': apply_scale,\n 'APPLY_RPE': apply_rpe,\n 'APPLY_KP_MASK': apply_kp_mask,\n 'APPLY_ATTN_MASK': apply_attn_mask,\n 'KP_MASK_MUL': kp_mask_mode == 'mul',\n 'ATTN_MASK_MUL': attn_mask_mode == 'mul',\n }\n grid = lambda opt: [spdims[0] * spdims[1] * block, M]\n _forward[grid](x, scale, lut, rpe, key_padding_mask, attn_mask, maxlut, x.stride(0),\\\n stride_zrpe, stride_hrpe, stride_srpe, stride_zkpm, stride_zattnm, **meta)\n\n # save to context\n ctx.mark_dirty(x)\n ctx.save_for_backward(x, lut)\n ctx.spdims = spdims\n ctx.block = block\n ctx.maxlut = maxlut\n ctx.scale = scale\n ctx.apply_scale = apply_scale\n ctx.apply_rpe = apply_rpe\n ctx.apply_kp_mask = apply_kp_mask\n ctx.apply_attn_mask = apply_attn_mask\n ctx.kp_mask_mode = kp_mask_mode\n ctx.attn_mask_mode = attn_mask_mode\n return x\n\n @staticmethod\n def backward(ctx, dx):\n\n # retrieve from context\n x, lut = ctx.saved_tensors\n # run kernel\n M = x.shape[0]\n grid = lambda opt: [ctx.spdims[0] * ctx.spdims[1] * ctx.block, M]\n _backward[grid](x, ctx.scale, dx, lut, ctx.maxlut, x.stride(0), dx.stride(0), BLOCK=ctx.block)\n return dx, None, None, None, None, None, None, None, None, None, None, None, None, None, None\n\nclass Softmax:\n\n def sparse_softmax(*args, **kwargs):\n return _sparse_softmax.apply(*args, **kwargs)\n\n def __init__(self, layout, block, bench=False):\n\n self.num_blocks = layout.sum().item()\n self.spdims = layout.shape\n self.layout = layout\n self.block = block\n self.bench = bench\n self.lut_cache = dict()\n\n def make_lut(self, device):\n\n key = (device, )\n if key not in self.lut_cache:\n self.lut_cache[key] = _sparse_softmax.make_lut(self.layout, self.block, device)\n return self.lut_cache[key]\n\n def __call__(self,\n x,\n scale=1.,\n rpe=None,\n key_padding_mask=None,\n attn_mask=None,\n key_padding_mask_mode='add',\n attn_mask_mode='add'):\n\n time_y = [None]\n if rpe is not None and rpe.dtype != x.dtype:\n raise ValueError('relative position embedding must be %s' % x.dtype)\n if attn_mask is not None and attn_mask.dtype != x.dtype:\n raise ValueError('Attention mask must be %s' % x.dtype)\n if key_padding_mask is not None and key_padding_mask.dtype != x.dtype:\n raise ValueError('Key padding mask must be %s' % x.dtype)\n lut, maxlut = self.make_lut(x.device)\n x = Softmax.sparse_softmax(x, scale, rpe, key_padding_mask, attn_mask, key_padding_mask_mode, attn_mask_mode,\n self.spdims, self.block, lut, self.num_blocks, maxlut, self.bench, time_y)\n self.time_y = time_y[0]\n return x\n", - "description_1": "Use triton language to implement block-sparse softmax with forward and backward kernels. The _forward kernel takes 13 parameters: X, scale, LUT, RPE, KP_M, ATTN_M, sizemax, stride_zx, stride_zrpe, stride_hrpe, stride_srpe, stride_zkpm, stride_zattnm, and applies scaling, relative position embedding, key-padding mask, and attention mask before computing softmax on a block-sparse matrix. The _backward kernel computes the backward pass for softmax with 7 parameters: X, scale, DX, LUT, sizemax, stride_zx, stride_zdx, computing gradients in a block-sparse format.", - "description_2": "Use triton language to implement block-sparse softmax. The implementation includes both forward and backward passes, handling sparse data layout and optional scaling and masking operations, by processing data in blocks as defined by a lookup table.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\nfrom deepspeed.ops.transformer.inference.triton import score_4d_matmul, context_4d_matmul\n\n\n@triton.jit\ndef _flash_packed_kernel(\n QKV,\n mask,\n ADD_MASK: tl.constexpr,\n IS_CAUSAL: tl.constexpr,\n sm_scale,\n Out,\n stride_qz,\n stride_qn,\n stride_qm,\n stride_mz,\n stride_oz,\n stride_on,\n Z,\n H,\n N_CTX,\n P_SEQ,\n hidden_size,\n BLOCK_M: tl.constexpr,\n BLOCK_DMODEL: tl.constexpr,\n BLOCK_N: tl.constexpr,\n):\n start_m = tl.program_id(0)\n off_hz = tl.program_id(1)\n batch = off_hz // H\n head = off_hz % H\n\n q_offset = batch * stride_qz + head * BLOCK_DMODEL\n k_offset = q_offset + hidden_size\n v_offset = k_offset + hidden_size\n\n # initialize offsets\n offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)\n offs_n = tl.arange(0, BLOCK_N)\n offs_d = tl.arange(0, BLOCK_DMODEL)\n\n q_ptrs = QKV + q_offset + offs_m[:, None] * stride_qn + offs_d[None, :]\n k_ptrs = QKV + hidden_size + q_offset + offs_n[:, None] * stride_qn + offs_d[None, :]\n v_ptrs = QKV + 2 * hidden_size + q_offset + offs_n[:, None] * stride_qn + offs_d[None, :]\n\n # mask\n off_mask = batch * stride_mz + offs_n[None, :]\n mask_ptrs = mask + off_mask\n\n # initialize pointer to m and l\n m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float(\"inf\")\n l_i = tl.zeros([BLOCK_M], dtype=tl.float32)\n acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)\n # scale sm_scale by log_2(e) and use\n # 2^x instead of exp in the loop because CSE and LICM\n # don't work as expected with `exp` in the loop\n qk_scale = sm_scale * 1.44269504\n # load q: it will stay in SRAM throughout\n q = tl.load(q_ptrs, mask=offs_m[:, None] < N_CTX, other=0.0)\n q = (q * qk_scale).to(tl.float16)\n # loop over k, v and update accumulator\n lo = 0\n hi = P_SEQ + (start_m + 1) * BLOCK_M if IS_CAUSAL else N_CTX + P_SEQ\n for start_n in range(lo, hi, BLOCK_N):\n # -- load k, v --\n k = tl.load(k_ptrs + start_n * stride_qn, mask=(start_n + offs_n)[:, None] < N_CTX, other=0.0)\n v = tl.load(v_ptrs + start_n * stride_qn, mask=(start_n + offs_n)[:, None] < N_CTX, other=0.0)\n # -- compute qk ---\n qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float16)\n\n if ADD_MASK:\n mask_val = tl.load(mask_ptrs)\n mask_ptrs += BLOCK_N\n qk = qk + mask_val.to(tl.float32)\n\n if IS_CAUSAL:\n qk = tl.where(P_SEQ + offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float(\"-inf\"))\n\n qk += tl.dot(q, tl.trans(k), out_dtype=tl.float16)\n qk += tl.where((start_n + offs_n)[None, :] < N_CTX, 0, minus_inf)\n # -- compute scaling constant ---\n m_i_new = tl.maximum(m_i, tl.max(qk, 1))\n alpha = tl.math.exp2(m_i - m_i_new)\n p = tl.math.exp2(qk - m_i_new[:, None])\n # -- scale and update acc --\n acc_scale = l_i * 0 + alpha # workaround some compiler bug\n acc *= acc_scale[:, None]\n acc += tl.dot(p.to(tl.float16), v.to(tl.float16))\n # -- update m_i and l_i --\n l_i = l_i * alpha + tl.sum(p, 1)\n m_i = m_i_new\n\n # write back l and m\n acc = acc / l_i[:, None]\n o_offset = batch * stride_oz + head * BLOCK_DMODEL\n out_ptrs = Out + o_offset + (offs_m[:, None] * stride_on + offs_d[None, :])\n tl.store(out_ptrs, acc.to(tl.float16), mask=offs_m[:, None] < N_CTX)\n\n\ndef _triton_packed_flash(qkv, head_size, mask, sm_scale, causal=False, add_mask=True):\n heads = qkv.shape[-1] // 3 // head_size\n hidden_size = qkv.shape[-1] // 3\n\n BLOCK_M = 128\n BLOCK_N = 64 if head_size <= 64 else 32\n\n o = torch.empty((qkv.shape[0], qkv.shape[1], hidden_size), device=qkv.device, dtype=torch.half)\n if mask is None:\n mask = torch.empty(0)\n add_mask = False\n\n grid = (triton.cdiv(qkv.shape[1], BLOCK_M), qkv.shape[0] * heads, 1)\n num_stages = 4 if head_size <= 64 else 3\n num_warps = 4\n P_SEQ = 0\n\n _flash_packed_kernel[grid](qkv,\n mask,\n add_mask,\n causal,\n sm_scale,\n o,\n qkv.stride(0),\n qkv.stride(1),\n qkv.stride(2),\n mask.stride(1) if add_mask else 0,\n o.stride(0),\n o.stride(1),\n qkv.shape[0],\n heads,\n qkv.shape[1],\n P_SEQ,\n hidden_size,\n BLOCK_M=BLOCK_M,\n BLOCK_N=BLOCK_N,\n BLOCK_DMODEL=head_size,\n num_warps=num_warps,\n num_stages=num_stages)\n\n return o\n\n\n###NULL!###\n", - "description_1": "Use triton language to implement a packed flash attention kernel that computes scaled dot-product attention on input queries, keys, and values with optional causal and masking mechanisms. The kernel loads QKV tensors and performs computation in blocks, optimizing the process with memory hierarchy and parallelism using custom grid and block configurations.", - "description_2": "Use triton language to implement a packed flash attention kernel that performs attention computations using QKV input tensors and scales by a specified factor. It handles both causal and masked attention modes, using optimized grid and block structures for parallel computation.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\nfrom deepspeed.accelerator import get_accelerator\n\n@triton.jit\ndef gelu_functor(x):\n # Using approximation introduces greater parity errors.\n # return tl.sigmoid(1.702 * x) * x\n return x * 0.5 * (1.0 + tl.math.erf(x / 1.41421356237))\n\n@triton.jit\ndef gelu_kernel(x_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr):\n pid = tl.program_id(axis=0)\n block_start = pid * BLOCK_SIZE\n offsets = block_start + tl.arange(0, BLOCK_SIZE)\n mask = offsets < n_elements\n x = tl.load(x_ptr + offsets, mask=mask)\n output = gelu_functor(x)\n tl.store(output_ptr + offsets, output, mask=mask)\n\ndef gelu(activations: torch.Tensor) -> torch.Tensor:\n assert activations.is_contiguous()\n assert get_accelerator().on_accelerator(activations)\n\n output = torch.empty_like(activations)\n n_elements = output.numel()\n grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']), )\n gelu_kernel[grid](activations, output, n_elements, BLOCK_SIZE=1024)\n return output\n", - "description_1": "Use triton language to implement a GELU activation function. The `gelu_functor` takes one parameter `x` (a tensor element) and returns the GELU activation using the error function. The `gelu_kernel` takes four parameters: `x_ptr` (pointer to input tensor), `output_ptr` (pointer to output tensor), `n_elements` (number of elements in the tensor), and `BLOCK_SIZE` (block size for parallel execution). It computes the GELU activation for each block of the input tensor and stores the result in the output tensor. The `gelu` function is a wrapper that prepares the input tensor, sets up the grid for kernel execution, and calls the `gelu_kernel`.", - "description_2": "Use triton language to create a kernel for computing the GELU activation function on a tensor using parallel execution. Implement a functor for the GELU computation and a kernel to apply this functor across the tensor. Provide a wrapper function to handle tensor preparation and kernel invocation.", - "difficulty": 2 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef layer_norm_kernel(\n Out,\n A,\n Weight,\n Bias,\n stride,\n N,\n eps,\n BLOCK_SIZE: tl.constexpr,\n):\n row = tl.program_id(0)\n Out += row * stride\n A += row * stride\n mean = 0\n _mean = tl.zeros([BLOCK_SIZE], dtype=tl.float32)\n for off in range(0, N, BLOCK_SIZE):\n cols = off + tl.arange(0, BLOCK_SIZE)\n a = tl.load(A + cols, mask=cols < N, other=0.0).to(tl.float32)\n _mean += a\n mean = tl.sum(_mean, axis=0) / N\n _var = tl.zeros([BLOCK_SIZE], dtype=tl.float32)\n for off in range(0, N, BLOCK_SIZE):\n cols = off + tl.arange(0, BLOCK_SIZE)\n a = tl.load(A + cols, mask=cols < N, other=0.0).to(tl.float32)\n a = tl.where(cols < N, a - mean, 0.0)\n _var += a * a\n var = tl.sum(_var, axis=0) / N\n rstd = 1 / tl.sqrt(var + eps)\n for off in range(0, N, BLOCK_SIZE):\n cols = off + tl.arange(0, BLOCK_SIZE)\n mask = cols < N\n weight = tl.load(Weight + cols, mask=mask)\n bias = tl.load(Bias + cols, mask=mask)\n a = tl.load(A + cols, mask=mask, other=0.0).to(tl.float32)\n a_hat = (a - mean) * rstd\n out = a_hat * weight + bias\n tl.store(Out + cols, out, mask=mask)\n\n@triton.jit\ndef layer_norm_residual_kernel(\n Out,\n A,\n Residual,\n ln_input,\n Weight,\n Bias,\n stride,\n N,\n eps,\n BLOCK_SIZE: tl.constexpr,\n):\n row = tl.program_id(0)\n Out += row * stride\n A += row * stride\n Residual += row * stride\n ln_input += row * stride\n mean = 0\n _mean = tl.zeros([BLOCK_SIZE], dtype=tl.float32)\n for off in range(0, N, BLOCK_SIZE):\n cols = off + tl.arange(0, BLOCK_SIZE)\n a = tl.load(A + cols, mask=cols < N, other=0.0).to(tl.float32)\n res = tl.load(Residual + cols, mask=cols < N, other=0.0).to(tl.float32)\n a = a + res\n tl.store(ln_input + cols, a, mask=cols < N)\n _mean += a\n mean = tl.sum(_mean, axis=0) / N\n _var = tl.zeros([BLOCK_SIZE], dtype=tl.float32)\n for off in range(0, N, BLOCK_SIZE):\n cols = off + tl.arange(0, BLOCK_SIZE)\n a = tl.load(ln_input + cols, mask=cols < N, other=0.0).to(tl.float32)\n a = tl.where(cols < N, a - mean, 0.0)\n _var += a * a\n var = tl.sum(_var, axis=0) / N\n rstd = 1 / tl.sqrt(var + eps)\n for off in range(0, N, BLOCK_SIZE):\n cols = off + tl.arange(0, BLOCK_SIZE)\n mask = cols < N\n weight = tl.load(Weight + cols, mask=mask)\n bias = tl.load(Bias + cols, mask=mask)\n a = tl.load(ln_input + cols, mask=mask, other=0.0).to(tl.float32)\n a_hat = (a - mean) * rstd\n out = a_hat * weight + bias\n tl.store(Out + cols, out, mask=mask)\n\n@triton.jit\ndef layer_norm_residual_bias_kernel(\n Out,\n A,\n Residual,\n InputBias,\n ln_input,\n Weight,\n Bias,\n stride,\n N,\n eps,\n BLOCK_SIZE: tl.constexpr,\n):\n row = tl.program_id(0)\n Out += row * stride\n A += row * stride\n Residual += row * stride\n ln_input += row * stride\n mean = 0\n _mean = tl.zeros([BLOCK_SIZE], dtype=tl.float32)\n for off in range(0, N, BLOCK_SIZE):\n cols = off + tl.arange(0, BLOCK_SIZE)\n a = tl.load(A + cols, mask=cols < N, other=0.0).to(tl.float32)\n res = tl.load(Residual + cols, mask=cols < N, other=0.0).to(tl.float32)\n b = tl.load(InputBias + cols, mask=cols < N, other=0.0).to(tl.float32)\n a = a + b + res\n tl.store(ln_input + cols, a, mask=cols < N)\n _mean += a\n mean = tl.sum(_mean, axis=0) / N\n _var = tl.zeros([BLOCK_SIZE], dtype=tl.float32)\n for off in range(0, N, BLOCK_SIZE):\n cols = off + tl.arange(0, BLOCK_SIZE)\n a = tl.load(ln_input + cols, mask=cols < N, other=0.0).to(tl.float32)\n a = tl.where(cols < N, a - mean, 0.0)\n _var += a * a\n var = tl.sum(_var, axis=0) / N\n rstd = 1 / tl.sqrt(var + eps)\n for off in range(0, N, BLOCK_SIZE):\n cols = off + tl.arange(0, BLOCK_SIZE)\n mask = cols < N\n weight = tl.load(Weight + cols, mask=mask)\n bias = tl.load(Bias + cols, mask=mask)\n a = tl.load(ln_input + cols, mask=mask, other=0.0).to(tl.float32)\n a_hat = (a - mean) * rstd\n out = a_hat * weight + bias\n tl.store(Out + cols, out, mask=mask)\n\ndef layer_norm(a, weight, bias, eps):\n assert a.is_contiguous()\n assert weight.is_contiguous()\n assert bias.is_contiguous()\n\n out = torch.empty_like(a)\n a_arg = a.view(-1, a.shape[-1])\n M, N = a_arg.shape\n MAX_FUSED_SIZE = 65536 // a.element_size()\n BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))\n BLOCK_SIZE = max(BLOCK_SIZE, 128)\n BLOCK_SIZE = min(BLOCK_SIZE, 4096)\n BLOCK_SIZE = BLOCK_SIZE if N <= 4096 else 8192\n num_warps = min(max(BLOCK_SIZE // 256, 1), 8)\n layer_norm_kernel[(M, )](\n out,\n a_arg,\n weight,\n bias,\n a_arg.stride(0),\n N,\n eps,\n BLOCK_SIZE=BLOCK_SIZE,\n num_warps=num_warps,\n )\n return out\n\ndef layer_norm_residual(a, input_bias, residual, weight, bias, eps):\n assert a.is_contiguous()\n assert weight.is_contiguous()\n assert bias.is_contiguous()\n assert residual.is_contiguous()\n\n out = torch.empty_like(a)\n ln_input = torch.empty_like(a)\n a_arg = a.view(-1, a.shape[-1])\n residual = residual.view(-1, residual.shape[-1])\n M, N = a_arg.shape\n MAX_FUSED_SIZE = 65536 // a.element_size()\n BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))\n BLOCK_SIZE = max(BLOCK_SIZE, 128)\n BLOCK_SIZE = min(BLOCK_SIZE, 4096)\n BLOCK_SIZE = BLOCK_SIZE if N <= 4096 else 8192\n num_warps = min(max(BLOCK_SIZE // 256, 1), 8)\n if input_bias is None:\n layer_norm_residual_kernel[(M, )](\n out,\n a_arg,\n residual,\n ln_input,\n weight,\n bias,\n a_arg.stride(0),\n N,\n eps,\n BLOCK_SIZE=BLOCK_SIZE,\n num_warps=num_warps,\n )\n else:\n layer_norm_residual_bias_kernel[(M, )](\n out,\n a_arg,\n residual,\n input_bias,\n ln_input,\n weight,\n bias,\n a_arg.stride(0),\n N,\n eps,\n BLOCK_SIZE=BLOCK_SIZE,\n num_warps=num_warps,\n )\n return out\n", - "description_1": "Use triton language to implement three types of layer normalization kernels: simple layer normalization, layer normalization with residual, and layer normalization with residual and input bias. The kernels compute mean and variance for normalization, adjust weights and bias, and are configured with BLOCK_SIZE for optimal parallel execution.", - "description_2": "Use triton language to create kernels for layer normalization, including variants with residual and bias, optimizing parallelism with BLOCK_SIZE.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef _fwd_kernel(\n Q,\n K,\n V,\n sm_scale,\n Out,\n stride_qz,\n stride_qh,\n stride_qm,\n stride_qk,\n stride_kz,\n stride_kh,\n stride_kn,\n stride_kk,\n stride_vz,\n stride_vh,\n stride_vk,\n stride_vn,\n stride_oz,\n stride_oh,\n stride_om,\n stride_on,\n Z,\n H,\n N_CTX,\n BLOCK_M: tl.constexpr,\n BLOCK_DMODEL: tl.constexpr,\n BLOCK_N: tl.constexpr,\n):\n start_m = tl.program_id(0)\n off_hz = tl.program_id(1)\n qvk_offset = off_hz * stride_qh\n Q_block_ptr = tl.make_block_ptr(base=Q + qvk_offset,\n shape=(N_CTX, BLOCK_DMODEL),\n strides=(stride_qm, stride_qk),\n offsets=(start_m * BLOCK_M, 0),\n block_shape=(BLOCK_M, BLOCK_DMODEL),\n order=(1, 0))\n K_block_ptr = tl.make_block_ptr(base=K + qvk_offset,\n shape=(BLOCK_DMODEL, N_CTX),\n strides=(stride_kk, stride_kn),\n offsets=(0, 0),\n block_shape=(BLOCK_DMODEL, BLOCK_N),\n order=(0, 1))\n V_block_ptr = tl.make_block_ptr(base=V + qvk_offset,\n shape=(N_CTX, BLOCK_DMODEL),\n strides=(stride_vk, stride_vn),\n offsets=(0, 0),\n block_shape=(BLOCK_N, BLOCK_DMODEL),\n order=(1, 0))\n offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)\n offs_n = tl.arange(0, BLOCK_N)\n m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float(\"inf\")\n l_i = tl.zeros([BLOCK_M], dtype=tl.float32)\n acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)\n qk_scale = sm_scale * 1.44269504\n q = tl.load(Q_block_ptr)\n q = (q * qk_scale).to(tl.float16)\n lo = 0\n hi = N_CTX\n for start_n in range(lo, hi, BLOCK_N):\n k = tl.load(K_block_ptr)\n v = tl.load(V_block_ptr)\n qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)\n qk += tl.dot(q, k)\n m_i_new = tl.maximum(m_i, tl.max(qk, 1))\n alpha = tl.math.exp2(m_i - m_i_new)\n p = tl.math.exp2(qk - m_i_new[:, None])\n acc_scale = l_i * 0 + alpha\n acc *= acc_scale[:, None]\n acc += tl.dot(p.to(tl.float16), v)\n l_i = l_i * alpha + tl.sum(p, 1)\n m_i = m_i_new\n K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N))\n V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0))\n acc = acc / l_i[:, None]\n O_block_ptr = tl.make_block_ptr(base=Out + qvk_offset,\n shape=(N_CTX, BLOCK_DMODEL),\n strides=(stride_om, stride_on),\n offsets=(start_m * BLOCK_M, 0),\n block_shape=(BLOCK_M, BLOCK_DMODEL),\n order=(1, 0))\n tl.store(O_block_ptr, acc.to(tl.float16))\n\n\nclass triton_flash_attn(torch.nn.Module):\n\n def __init__(self, ):\n super(triton_flash_attn, self).__init__()\n\n def forward(self, q, k, v, sm_scale, block_128=True):\n BLOCK = 128 if block_128 else 64\n Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]\n o = torch.empty_like(q)\n grid = (triton.cdiv(q.shape[2], BLOCK), q.shape[0] * q.shape[1])\n num_warps = 4 if Lk <= 64 else 8\n\n _fwd_kernel[grid](\n q,\n k,\n v,\n sm_scale,\n o,\n q.stride(0),\n q.stride(1),\n q.stride(2),\n q.stride(3),\n k.stride(0),\n k.stride(1),\n k.stride(2),\n k.stride(3),\n v.stride(0),\n v.stride(1),\n v.stride(2),\n v.stride(3),\n o.stride(0),\n o.stride(1),\n o.stride(2),\n o.stride(3),\n k.shape[0],\n k.shape[1],\n k.shape[2],\n BLOCK_M=BLOCK,\n BLOCK_N=BLOCK,\n BLOCK_DMODEL=Lk,\n num_warps=num_warps,\n num_stages=1,\n )\n return o\n", - "description_1": "Use triton language to implement a forward kernel for flash attention. The kernel takes 25 parameters: Q, K, V, sm_scale, Out, stride_qz, stride_qh, stride_qm, stride_qk, stride_kz, stride_kh, stride_kn, stride_kk, stride_vz, stride_vh, stride_vk, stride_vn, stride_oz, stride_oh, stride_om, stride_on, Z, H, N_CTX, and three constexpr parameters BLOCK_M, BLOCK_DMODEL, BLOCK_N. The kernel computes the attention output by iterating over blocks of the input matrices.", - "description_2": "Use triton language to create a PyTorch module 'triton_flash_attn' that uses the forward kernel to compute attention. The module's forward method takes 4 parameters: q, k, v, sm_scale, and an optional block_128. It sets up the grid and block size, and calls the kernel to compute the output.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\nfrom deepspeed.accelerator import get_accelerator\n\n@triton.jit\ndef residual_add_bias_kernel(\n hidden_state_ptr,\n residual_ptr,\n attn_output_ptr,\n hidden_state_size,\n attn_bias_ptr,\n final_bias_ptr,\n bias_size,\n output_ptr,\n mp_size: tl.constexpr,\n mlp_after_attn: tl.constexpr,\n pre_attn_norm: tl.constexpr,\n add_attn_bias: tl.constexpr,\n BLOCK_SIZE: tl.constexpr,\n):\n pid = tl.program_id(axis=0)\n\n block_start = pid * BLOCK_SIZE\n\n offsets = block_start + tl.arange(0, BLOCK_SIZE)\n mask = offsets < hidden_state_size\n\n bias_offsets = offsets % bias_size\n bias_mask = bias_offsets < bias_size\n\n tl_hidden_state = tl.load(hidden_state_ptr + offsets, mask=mask)\n tl_residual = tl.load(residual_ptr + offsets, mask=mask)\n tl_attn_output = tl.load(attn_output_ptr + offsets, mask=mask)\n tl_attn_bias = tl.load(attn_bias_ptr + bias_offsets, mask=bias_mask)\n tl_final_bias = tl.load(final_bias_ptr + bias_offsets, mask=bias_mask)\n\n if mlp_after_attn:\n if pre_attn_norm:\n output = tl_hidden_state + (tl_residual + tl_final_bias + tl_attn_output + tl_attn_bias) / mp_size\n else:\n output = tl_hidden_state + tl_residual + tl_final_bias\n else:\n output = tl_hidden_state + tl_attn_output + (tl_residual + tl_final_bias) / mp_size\n if add_attn_bias:\n output += tl_attn_bias / mp_size\n\n tl.store(output_ptr + offsets, output, mask=mask)\n\n\ndef residual_add_bias(hidden_state: torch.Tensor, residual: torch.Tensor, attn_output: torch.Tensor,\n attn_bias: torch.Tensor, final_bias: torch.Tensor, mp_size: int, mlp_after_attn: bool,\n add_attn_bias: bool, pre_attn_norm: bool):\n # check that all tensors are on the same device\n assert get_accelerator().on_accelerator(hidden_state) \\\n and get_accelerator().on_accelerator(residual) \\\n and get_accelerator().on_accelerator(attn_output) \\\n and get_accelerator().on_accelerator(attn_bias) \\\n and get_accelerator().on_accelerator(final_bias)\n\n # check that all tensors have the same dtype\n assert hidden_state.dtype == residual.dtype == attn_output.dtype \\\n == attn_bias.dtype == final_bias.dtype\n\n # check that all tensors have the right shape\n assert hidden_state.shape == residual.shape == attn_output.shape\n assert attn_bias.shape == final_bias.shape\n assert attn_bias.shape[0] == hidden_state.shape[2]\n\n output = torch.empty_like(hidden_state)\n\n hidden_state_size = output.numel()\n bias_size = attn_bias.numel()\n\n grid = lambda meta: (triton.cdiv(hidden_state_size, meta['BLOCK_SIZE']), )\n\n residual_add_bias_kernel[grid](hidden_state, residual, attn_output, hidden_state_size,\\\n attn_bias, final_bias, bias_size, output, mp_size, mlp_after_attn, pre_attn_norm, \\\n add_attn_bias, \\\n BLOCK_SIZE=1024)\n\n return output\n", - "description_1": "Use triton language to implement a kernel that performs a residual addition with bias on tensors. The kernel takes 13 parameters: pointers to hidden state, residual, attention output, attention bias, final bias, and output tensors, sizes of hidden state and bias, and several compile-time constants for configuration. The kernel computes the output by conditionally adding the residual, attention output, and biases to the hidden state based on the provided flags. The function 'residual_add_bias' prepares the input tensors, checks their properties, and launches the kernel with appropriate grid size.", - "description_2": "Use triton language to create a kernel for residual addition with bias, and a function to set up and launch this kernel with given tensor inputs and configuration flags.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef softmax_kernel(output_ptr, input_ptr, stride, n_cols, BLOCK_SIZE: tl.constexpr):\n row_idx = tl.program_id(0)\n row_start_ptr = input_ptr + row_idx * stride\n col_offsets = tl.arange(0, BLOCK_SIZE)\n input_ptrs = row_start_ptr + col_offsets\n row = tl.load(input_ptrs, mask=col_offsets < n_cols, other=-float('inf')).to(tl.float32)\n row_minus_max = row - tl.max(row, axis=0)\n numerator = tl.exp(row_minus_max)\n denominator = tl.sum(numerator, axis=0)\n softmax_output = numerator / denominator\n output_row_start_ptr = output_ptr + row_idx * stride\n output_ptrs = output_row_start_ptr + col_offsets\n tl.store(output_ptrs, softmax_output, mask=col_offsets < n_cols)\n\n@triton.jit\ndef masked_softmax_kernel(output_ptr, input_ptr, stride, mask_ptr, mask_stride, n_cols, BLOCK_SIZE: tl.constexpr):\n row_idx = tl.program_id(0)\n row_start_ptr = input_ptr + row_idx * stride\n col_offsets = tl.arange(0, BLOCK_SIZE)\n input_ptrs = row_start_ptr + col_offsets\n mask_ptrs = mask_ptr + col_offsets + row_idx * mask_stride\n row = tl.load(input_ptrs, mask=col_offsets < n_cols, other=-float('inf')).to(tl.float32)\n mask = tl.load(mask_ptrs, mask=col_offsets < n_cols, other=0).to(tl.float32)\n row_minus_max = row - tl.max(row, axis=0)\n row_minus_max = row_minus_max + mask\n numerator = tl.exp(row_minus_max)\n denominator = tl.sum(numerator, axis=0)\n softmax_output = numerator / denominator\n output_row_start_ptr = output_ptr + row_idx * stride\n output_ptrs = output_row_start_ptr + col_offsets\n tl.store(output_ptrs, softmax_output, mask=col_offsets < n_cols)\n\ndef softmax(input: torch.Tensor, mask: torch.Tensor = None, dim=-1) -> torch.Tensor:\n assert input.is_contiguous()\n assert (dim == -1) or (dim == len(input.shape) - 1), \"Only dim=-1 is supported\"\n\n use_mask = False if mask is None else True\n input_arg = input.view(-1, input.shape[-1])\n n_rows, n_cols = input_arg.shape\n BLOCK_SIZE = max(triton.next_power_of_2(n_cols), 2)\n num_warps = 4\n if BLOCK_SIZE >= 2048:\n num_warps = 8\n if BLOCK_SIZE >= 4096:\n num_warps = 16\n output = torch.empty_like(input)\n if use_mask:\n assert mask.is_contiguous()\n mask = mask.view(-1, mask.shape[-1])\n mask_stride = mask.shape[-1] if mask.shape[-2] > 1 else 0\n masked_softmax_kernel[(n_rows, )](\n output,\n input,\n input_arg.stride(0),\n mask,\n mask_stride,\n n_cols,\n num_warps=num_warps,\n BLOCK_SIZE=BLOCK_SIZE,\n )\n else:\n softmax_kernel[(n_rows, )](\n output,\n input,\n input_arg.stride(0),\n n_cols,\n num_warps=num_warps,\n BLOCK_SIZE=BLOCK_SIZE,\n )\n return output\n", - "description_1": "Use triton language to implement a softmax operation with optional masking. The softmax_kernel function takes 5 parameters: output_ptr (output tensor pointer), input_ptr (input tensor pointer), stride (stride of the input tensor), n_cols (number of columns in the input tensor), and BLOCK_SIZE (block size for parallel execution). The masked_softmax_kernel function takes 7 parameters: output_ptr, input_ptr, stride, mask_ptr (mask tensor pointer), mask_stride (stride of the mask tensor), n_cols, and BLOCK_SIZE. The softmax function is a wrapper that prepares the input and mask tensors, determines the block size and number of warps, and calls the appropriate kernel function.", - "description_2": "Use triton language to create a softmax operation with optional mask support, utilizing parallel execution with configurable block size and warp count.", - "difficulty": 3 - }, - { - "code": "import triton\nimport triton.language as tl\nfrom .gelu import gelu_functor\n\n@triton.autotune(\n configs=[\n triton.Config({\n 'BLOCK_M': 128,\n 'BLOCK_N': 256,\n 'BLOCK_K': 32,\n 'SPLIT_K': 1\n }, num_stages=3, num_warps=8),\n triton.Config({\n 'BLOCK_M': 256,\n 'BLOCK_N': 128,\n 'BLOCK_K': 32,\n 'SPLIT_K': 1\n }, num_stages=3, num_warps=8),\n triton.Config({\n 'BLOCK_M': 256,\n 'BLOCK_N': 64,\n 'BLOCK_K': 32,\n 'SPLIT_K': 1\n }, num_stages=4, num_warps=4),\n triton.Config({\n 'BLOCK_M': 64,\n 'BLOCK_N': 256,\n 'BLOCK_K': 32,\n 'SPLIT_K': 1\n }, num_stages=4, num_warps=4),\n triton.Config({\n 'BLOCK_M': 128,\n 'BLOCK_N': 128,\n 'BLOCK_K': 32,\n 'SPLIT_K': 1\n }, num_stages=4, num_warps=4),\n triton.Config({\n 'BLOCK_M': 128,\n 'BLOCK_N': 64,\n 'BLOCK_K': 32,\n 'SPLIT_K': 1\n }, num_stages=4, num_warps=4),\n triton.Config({\n 'BLOCK_M': 64,\n 'BLOCK_N': 128,\n 'BLOCK_K': 32,\n 'SPLIT_K': 1\n }, num_stages=4, num_warps=4),\n triton.Config({\n 'BLOCK_M': 128,\n 'BLOCK_N': 32,\n 'BLOCK_K': 32,\n 'SPLIT_K': 1\n }, num_stages=4, num_warps=4),\n triton.Config({\n 'BLOCK_M': 64,\n 'BLOCK_N': 32,\n 'BLOCK_K': 32,\n 'SPLIT_K': 1\n }, num_stages=5, num_warps=2),\n ],\n key=['CACHE_M', 'CACHE_N', 'CACHE_K'],\n prune_configs_by={\n 'early_config_prune': _fp16_matmul_prune_config,\n 'perf_model': None,\n 'top_k': AUTOTUNE_TOP_K\n },\n)\n@triton.heuristics({\n 'EVEN_K': lambda args: args['K'] % (args['BLOCK_K'] * args['SPLIT_K']) == 0,\n})\n@triton.jit\ndef _fp_matmul(\n A,\n B,\n C,\n M,\n N,\n K,\n bias,\n stride_am,\n stride_ak,\n stride_bk,\n stride_bn,\n stride_cm,\n stride_cn,\n CACHE_M,\n CACHE_N,\n CACHE_K,\n BLOCK_M: tl.constexpr,\n BLOCK_N: tl.constexpr,\n BLOCK_K: tl.constexpr,\n GROUP_M: tl.constexpr,\n SPLIT_K: tl.constexpr,\n EVEN_K: tl.constexpr,\n ACC_TYPE: tl.constexpr,\n BIAS_ADD: tl.constexpr,\n ACTIVATION: tl.constexpr,\n):\n # matrix multiplication\n pid = tl.program_id(0)\n pid_z = tl.program_id(1)\n grid_m = (M + BLOCK_M - 1) // BLOCK_M\n grid_n = (N + BLOCK_N - 1) // BLOCK_N\n # re-order program ID for better L2 performance\n width = GROUP_M * grid_n\n group_id = pid // width\n group_size = min(grid_m - group_id * GROUP_M, GROUP_M)\n pid_m = group_id * GROUP_M + (pid % group_size)\n pid_n = (pid % width) // (group_size)\n # do matrix multiplication\n rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)\n rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)\n ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)\n rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)\n rk = pid_z * BLOCK_K + tl.arange(0, BLOCK_K)\n # pointers\n A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)\n B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)\n acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)\n for k in range(K, 0, -BLOCK_K * SPLIT_K):\n if EVEN_K:\n a = tl.load(A)\n b = tl.load(B)\n else:\n a = tl.load(A, mask=rk[None, :] < k, other=0.)\n b = tl.load(B, mask=rk[:, None] < k, other=0.)\n acc += tl.dot(a, b)\n A += BLOCK_K * SPLIT_K * stride_ak\n B += BLOCK_K * SPLIT_K * stride_bk\n # bias addition\n if BIAS_ADD:\n bias_offset = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)\n bias_ptr = bias + bias_offset\n b = tl.load(bias_ptr, mask=bias_offset < N)\n acc = acc + b[None, :]\n # activation\n if ACTIVATION == \"relu\":\n acc = tl.where(acc >= 0, acc, 0)\n elif ACTIVATION == \"leaky_relu\":\n acc = tl.where(acc >= 0, acc, 0.01 * acc)\n elif ACTIVATION == \"gelu\":\n #acc = tl.sigmoid(1.702 * acc) * acc\n acc = gelu_functor(acc)\n elif ACTIVATION == \"sigmoid\":\n acc = tl.sigmoid(acc) # sigmoid\n acc = acc.to(C.dtype.element_ty)\n # rematerialize rm and rn to save registers\n rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)\n rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)\n C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)\n mask = (rm < M)[:, None] & (rn < N)[None, :]\n # handles write-back with reduction-splitting\n if SPLIT_K == 1:\n tl.store(C, acc, mask=mask)\n else:\n tl.atomic_add(C, acc, mask=mask)\n\n@triton.autotune(\n configs=[\n triton.Config(\n {\n \"BLOCK_SIZE_M\": 64,\n \"BLOCK_SIZE_N\": 64,\n \"BLOCK_SIZE_K\": 64,\n \"GROUP_SIZE_M\": 8\n },\n num_stages=1, # this is mainly for unit test, to minimize the share memory usage\n num_warps=8),\n triton.Config(\n {\n \"BLOCK_SIZE_M\": 128,\n \"BLOCK_SIZE_N\": 128,\n \"BLOCK_SIZE_K\": 32,\n \"GROUP_SIZE_M\": 8,\n },\n num_stages=4,\n num_warps=4,\n ),\n triton.Config(\n {\n \"BLOCK_SIZE_M\": 128,\n \"BLOCK_SIZE_N\": 64,\n \"BLOCK_SIZE_K\": 32,\n \"GROUP_SIZE_M\": 8,\n },\n num_stages=4,\n num_warps=4,\n ),\n triton.Config(\n {\n \"BLOCK_SIZE_M\": 64,\n \"BLOCK_SIZE_N\": 128,\n \"BLOCK_SIZE_K\": 32,\n \"GROUP_SIZE_M\": 8,\n },\n num_stages=4,\n num_warps=4,\n ),\n triton.Config(\n {\n \"BLOCK_SIZE_M\": 128,\n \"BLOCK_SIZE_N\": 32,\n \"BLOCK_SIZE_K\": 32,\n \"GROUP_SIZE_M\": 8,\n },\n num_stages=4,\n num_warps=4,\n ),\n triton.Config(\n {\n \"BLOCK_SIZE_M\": 64,\n \"BLOCK_SIZE_N\": 32,\n \"BLOCK_SIZE_K\": 32,\n \"GROUP_SIZE_M\": 8,\n },\n num_stages=5,\n num_warps=2,\n ),\n triton.Config(\n {\n \"BLOCK_SIZE_M\": 32,\n \"BLOCK_SIZE_N\": 64,\n \"BLOCK_SIZE_K\": 32,\n \"GROUP_SIZE_M\": 8,\n },\n num_stages=5,\n num_warps=2,\n ),\n ],\n key=['CACHE_M', 'CACHE_N', 'CACHE_K'],\n prune_configs_by={\n 'early_config_prune': matmul_4d_prune_config,\n 'perf_model': None,\n 'top_k': AUTOTUNE_TOP_K\n },\n)\n@triton.jit\ndef matmul_4d_kernel(\n # Pointers to matrices\n a_ptr,\n b_ptr,\n c_ptr,\n # Matrix dimensions\n M,\n N,\n K,\n CACHE_M,\n CACHE_N,\n CACHE_K,\n stride_ab,\n stride_ah,\n stride_am,\n stride_ak,\n stride_bb,\n stride_bh,\n stride_bk,\n stride_bn,\n stride_cb,\n stride_ch,\n stride_cm,\n stride_cn,\n scale,\n # Meta-parameters\n BLOCK_SIZE_M: tl.constexpr,\n BLOCK_SIZE_N: tl.constexpr,\n BLOCK_SIZE_K: tl.constexpr,\n GROUP_SIZE_M: tl.constexpr,\n MASK: tl.constexpr,\n):\n \"\"\"Kernel for computing the matmul C = A x B.\n A has shape (M, K), B has shape (K, N) and C has shape (M, N)\n \"\"\"\n pid = tl.program_id(axis=0)\n head = tl.program_id(axis=1)\n batch = tl.program_id(axis=2)\n num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)\n num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)\n num_pid_in_group = GROUP_SIZE_M * num_pid_n\n group_id = pid // num_pid_in_group\n first_pid_m = group_id * GROUP_SIZE_M\n group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)\n pid_m = first_pid_m + (pid % group_size_m)\n pid_n = (pid % num_pid_in_group) // group_size_m\n\n if MASK:\n if (pid_m + 1) * BLOCK_SIZE_M - 1 < pid_n * BLOCK_SIZE_N:\n c = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=c_ptr.dtype.element_ty) - float(\"inf\")\n offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n c_ptrs = (c_ptr + batch * stride_cb + head * stride_ch + stride_cm * offs_cm[:, None] +\n stride_cn * offs_cn[None, :])\n tl.store(c_ptrs, c)\n return\n\n offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n offs_k = tl.arange(0, BLOCK_SIZE_K)\n a_ptrs = (a_ptr + batch * stride_ab + head * stride_ah +\n (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak))\n b_ptrs = (b_ptr + batch * stride_bb + head * stride_bh +\n (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn))\n\n accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n for k in range(0, K, BLOCK_SIZE_K):\n a_mask = (offs_am[:, None] < M) & (offs_k[None, :] + k < K)\n b_mask = (offs_k[:, None] + k < K) & (offs_bn[None, :] < N)\n a = tl.load(a_ptrs, mask=a_mask, other=0.)\n b = tl.load(b_ptrs, mask=b_mask, other=0.)\n accumulator += tl.dot(a, b)\n a_ptrs += BLOCK_SIZE_K * stride_ak\n b_ptrs += BLOCK_SIZE_K * stride_bk\n\n c = accumulator.to(c_ptr.dtype.element_ty)\n if scale > 0:\n c = c * scale.to(c_ptr.dtype.element_ty)\n\n offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n if MASK:\n c += tl.where(offs_cm[:, None] >= offs_cn[None, :], 0, float(\"-inf\"))\n c_ptrs = (c_ptr + batch * stride_cb + head * stride_ch + stride_cm * offs_cm[:, None] +\n stride_cn * offs_cn[None, :])\n c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)\n tl.store(c_ptrs, c, mask=c_mask)\n", - "description_1": "Use triton language to implement two matrix multiplication kernels. The first kernel, _fp_matmul, takes 22 parameters: A, B, C (matrices), M, N, K (dimensions), bias, stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn (strides), CACHE_M, CACHE_N, CACHE_K (cache sizes), BLOCK_M, BLOCK_N, BLOCK_K, GROUP_M, SPLIT_K, EVEN_K, ACC_TYPE, BIAS_ADD, ACTIVATION (meta-parameters). It performs matrix multiplication with optional bias addition and activation functions. The second kernel, matmul_4d_kernel, takes 23 parameters: a_ptr, b_ptr, c_ptr (pointers to matrices), M, N, K (dimensions), CACHE_M, CACHE_N, CACHE_K (cache sizes), stride_ab, stride_ah, stride_am, stride_ak, stride_bb, stride_bh, stride_bk, stride_bn, stride_cb, stride_ch, stride_cm, stride_cn (strides), scale, BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K, GROUP_SIZE_M, MASK (meta-parameters). It computes the matrix multiplication C = A x B with optional scaling and masking.", - "description_2": "Use triton language to create two matrix multiplication kernels with configurable block sizes and optional features like bias addition, activation, scaling, and masking.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef _uniform_to_exponential_kernel(input, output, n: tl.constexpr):\n idx = tl.arange(0, n)\n x = tl.load(input + idx)\n y = _uniform_to_exponential(x)\n tl.store(output + idx, y)\n\ndef test_uniform_to_exponential():\n \"\"\"Test that we can convert uniform to exponential without div by 0.\"\"\"\n input = torch.tensor([0.0, 1.0 - torch.finfo(torch.float32).eps],\n dtype=torch.float32,\n device=\"cuda\")\n output = torch.zeros(input.shape, dtype=torch.float32, device=\"cuda\")\n _uniform_to_exponential_kernel[(1, )](input, output, 2)\n assert torch.all(torch.isfinite(output))\n assert torch.all(output > 0)\n assert torch.all(torch.isfinite(torch.full_like(output, 1.0) / output))\n", - "description_1": "Use triton language to implement a kernel that converts uniform random numbers to exponential random numbers. The kernel '_uniform_to_exponential_kernel' takes three parameters: 'input' (a pointer to the input tensor), 'output' (a pointer to the output tensor), and 'n' (a compile-time constant representing the number of elements to process). The kernel uses Triton's parallel programming model to load elements from the input tensor, apply the '_uniform_to_exponential' transformation, and store the results in the output tensor. The function 'test_uniform_to_exponential' tests this kernel by creating a CUDA tensor with specific values, invoking the kernel, and asserting that the output is finite and greater than zero.", - "description_2": "Use triton language to create a kernel that transforms uniform random numbers to exponential random numbers and test it using CUDA tensors.", - "difficulty": 1 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\nif triton.__version__ >= \"2.1.0\":\n\n @triton.jit\n def _fwd_kernel(\n Q,\n K,\n V,\n K_cache,\n V_cache,\n B_Loc,\n sm_scale,\n B_Start_Loc,\n B_Seqlen,\n B_Ctxlen,\n block_size,\n x,\n Out,\n stride_b_loc_b,\n stride_b_loc_s,\n stride_qbs,\n stride_qh,\n stride_qd,\n stride_kbs,\n stride_kh,\n stride_kd,\n stride_vbs,\n stride_vh,\n stride_vd,\n stride_obs,\n stride_oh,\n stride_od,\n stride_k_cache_bs,\n stride_k_cache_h,\n stride_k_cache_d,\n stride_k_cache_bl,\n stride_k_cache_x,\n stride_v_cache_bs,\n stride_v_cache_h,\n stride_v_cache_d,\n stride_v_cache_bl,\n num_queries_per_kv: int,\n BLOCK_M: tl.constexpr,\n BLOCK_DMODEL: tl.constexpr,\n BLOCK_N: tl.constexpr,\n ):\n cur_batch = tl.program_id(0)\n cur_head = tl.program_id(1)\n start_m = tl.program_id(2)\n\n cur_kv_head = cur_head // num_queries_per_kv\n\n cur_batch_ctx_len = tl.load(B_Ctxlen + cur_batch)\n cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)\n cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)\n\n block_start_loc = BLOCK_M * start_m\n\n # initialize offsets\n offs_n = tl.arange(0, BLOCK_N)\n offs_d = tl.arange(0, BLOCK_DMODEL)\n offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)\n off_q = (\n (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs +\n cur_head * stride_qh + offs_d[None, :] * stride_qd)\n\n q = tl.load(\n Q + off_q,\n mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len,\n other=0.0)\n\n # # initialize pointer to m and l\n m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float(\"inf\")\n l_i = tl.zeros([BLOCK_M], dtype=tl.float32)\n acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)\n\n for start_n in range(0, cur_batch_ctx_len, BLOCK_N):\n start_n = tl.multiple_of(start_n, BLOCK_N)\n # -- compute qk ----\n bn = tl.load(B_Loc + cur_batch * stride_b_loc_b +\n ((start_n + offs_n) // block_size) * stride_b_loc_s,\n mask=(start_n + offs_n) < cur_batch_ctx_len,\n other=0)\n off_k = (bn[None, :] * stride_k_cache_bs +\n cur_kv_head * stride_k_cache_h +\n (offs_d[:, None] // x) * stride_k_cache_d +\n ((start_n + offs_n[None, :]) % block_size) *\n stride_k_cache_bl +\n (offs_d[:, None] % x) * stride_k_cache_x)\n off_v = (\n bn[:, None] * stride_v_cache_bs +\n cur_kv_head * stride_v_cache_h +\n offs_d[None, :] * stride_v_cache_d +\n (start_n + offs_n[:, None]) % block_size * stride_v_cache_bl)\n k = tl.load(K_cache + off_k,\n mask=(start_n + offs_n[None, :]) < cur_batch_ctx_len,\n other=0.0)\n\n qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)\n qk += tl.dot(q, k)\n qk = tl.where((start_n + offs_n[None, :]) < cur_batch_ctx_len, qk,\n float(\"-inf\"))\n qk *= sm_scale\n\n # -- compute m_ij, p, l_ij\n m_ij = tl.max(qk, 1)\n p = tl.exp(qk - m_ij[:, None])\n l_ij = tl.sum(p, 1)\n # -- update m_i and l_i\n m_i_new = tl.maximum(m_i, m_ij)\n alpha = tl.exp(m_i - m_i_new)\n beta = tl.exp(m_ij - m_i_new)\n l_i_new = alpha * l_i + beta * l_ij\n # -- update output accumulator --\n # scale p\n p_scale = beta / l_i_new\n p = p * p_scale[:, None]\n # scale acc\n acc_scale = l_i / l_i_new * alpha\n acc = acc * acc_scale[:, None]\n # update acc\n v = tl.load(V_cache + off_v,\n mask=(start_n + offs_n[:, None]) < cur_batch_ctx_len,\n other=0.0)\n\n p = p.to(v.dtype)\n acc += tl.dot(p, v)\n # # update m_i and l_i\n l_i = l_i_new\n m_i = m_i_new\n\n off_k = (offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh +\n offs_d[:, None] * stride_kd)\n off_v = (offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh +\n offs_d[None, :] * stride_vd)\n k_ptrs = K + off_k\n v_ptrs = V + off_v\n\n block_mask = tl.where(\n block_start_loc < cur_batch_seq_len - cur_batch_ctx_len, 1, 0)\n\n for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N):\n start_n = tl.multiple_of(start_n, BLOCK_N)\n # -- compute qk ----\n k = tl.load(k_ptrs +\n (cur_batch_in_all_start_index + start_n) * stride_kbs,\n mask=(start_n + offs_n[None, :]) <\n cur_batch_seq_len - cur_batch_ctx_len,\n other=0.0)\n\n qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)\n qk += tl.dot(q, k)\n qk *= sm_scale\n qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk,\n float(\"-inf\"))\n\n # -- compute m_ij, p, l_ij\n m_ij = tl.max(qk, 1)\n p = tl.exp(qk - m_ij[:, None])\n l_ij = tl.sum(p, 1)\n # -- update m_i and l_i\n m_i_new = tl.maximum(m_i, m_ij)\n alpha = tl.exp(m_i - m_i_new)\n beta = tl.exp(m_ij - m_i_new)\n l_i_new = alpha * l_i + beta * l_ij\n # -- update output accumulator --\n # scale p\n p_scale = beta / l_i_new\n p = p * p_scale[:, None]\n # scale acc\n acc_scale = l_i / l_i_new * alpha\n acc = acc * acc_scale[:, None]\n # update acc\n v = tl.load(v_ptrs +\n (cur_batch_in_all_start_index + start_n) * stride_vbs,\n mask=(start_n + offs_n[:, None]) <\n cur_batch_seq_len - cur_batch_ctx_len,\n other=0.0)\n\n p = p.to(v.dtype)\n acc += tl.dot(p, v)\n # update m_i and l_i\n l_i = l_i_new\n m_i = m_i_new\n # initialize pointers to output\n off_o = (\n (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs +\n cur_head * stride_oh + offs_d[None, :] * stride_od)\n out_ptrs = Out + off_o\n tl.store(out_ptrs,\n acc,\n mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len)\n return\n\n @triton.jit\n def _fwd_kernel_alibi(\n Q,\n K,\n V,\n K_cache,\n V_cache,\n B_Loc,\n sm_scale,\n B_Start_Loc,\n B_Seqlen,\n B_Ctxlen,\n Alibi_slopes,\n block_size,\n x,\n Out,\n stride_b_loc_b,\n stride_b_loc_s,\n stride_qbs,\n stride_qh,\n stride_qd,\n stride_kbs,\n stride_kh,\n stride_kd,\n stride_vbs,\n stride_vh,\n stride_vd,\n stride_obs,\n stride_oh,\n stride_od,\n stride_k_cache_bs,\n stride_k_cache_h,\n stride_k_cache_d,\n stride_k_cache_bl,\n stride_k_cache_x,\n stride_v_cache_bs,\n stride_v_cache_h,\n stride_v_cache_d,\n stride_v_cache_bl,\n num_queries_per_kv: int,\n BLOCK_M: tl.constexpr,\n BLOCK_DMODEL: tl.constexpr,\n BLOCK_N: tl.constexpr,\n ):\n # attn_bias[]\n cur_batch = tl.program_id(0)\n cur_head = tl.program_id(1)\n start_m = tl.program_id(2)\n\n cur_kv_head = cur_head // num_queries_per_kv\n\n # cur_batch_seq_len: the length of prompts\n # cur_batch_ctx_len: the length of prefix\n # cur_batch_in_all_start_index: the start id of the dim=0\n cur_batch_ctx_len = tl.load(B_Ctxlen + cur_batch)\n cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)\n cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)\n\n block_start_loc = BLOCK_M * start_m\n\n # initialize offsets\n offs_n = tl.arange(0, BLOCK_N)\n offs_d = tl.arange(0, BLOCK_DMODEL)\n offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)\n off_q = (\n (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs +\n cur_head * stride_qh + offs_d[None, :] * stride_qd)\n\n q = tl.load(\n Q + off_q,\n mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len,\n other=0.0)\n\n # # initialize pointer to m and l\n m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float(\"inf\")\n l_i = tl.zeros([BLOCK_M], dtype=tl.float32)\n acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)\n\n alibi_slope = tl.load(Alibi_slopes + cur_head)\n alibi_start_q = tl.arange(\n 0, BLOCK_M) + block_start_loc + cur_batch_ctx_len\n alibi_start_k = 0\n for start_n in range(0, cur_batch_ctx_len, BLOCK_N):\n start_n = tl.multiple_of(start_n, BLOCK_N)\n # -- compute qk ----\n bn = tl.load(B_Loc + cur_batch * stride_b_loc_b +\n ((start_n + offs_n) // block_size) * stride_b_loc_s,\n mask=(start_n + offs_n) < cur_batch_ctx_len,\n other=0)\n off_k = (bn[None, :] * stride_k_cache_bs +\n cur_kv_head * stride_k_cache_h +\n (offs_d[:, None] // x) * stride_k_cache_d +\n ((start_n + offs_n[None, :]) % block_size) *\n stride_k_cache_bl +\n (offs_d[:, None] % x) * stride_k_cache_x)\n off_v = (\n bn[:, None] * stride_v_cache_bs +\n cur_kv_head * stride_v_cache_h +\n offs_d[None, :] * stride_v_cache_d +\n (start_n + offs_n[:, None]) % block_size * stride_v_cache_bl)\n k = tl.load(K_cache + off_k,\n mask=(start_n + offs_n[None, :]) < cur_batch_ctx_len,\n other=0.0)\n\n qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)\n qk += tl.dot(q, k)\n qk = tl.where((start_n + offs_n[None, :]) < cur_batch_ctx_len, qk,\n float(\"-inf\"))\n qk *= sm_scale\n\n # load alibi\n alibi = (tl.arange(0, BLOCK_N)[None, :] + alibi_start_k -\n alibi_start_q[:, None]) * alibi_slope\n alibi = tl.where(\n (alibi <= 0) & (alibi_start_q[:, None] < cur_batch_seq_len),\n alibi, float(\"-inf\"))\n qk += alibi\n alibi_start_k += BLOCK_N\n\n # -- compute m_ij, p, l_ij\n m_ij = tl.max(qk, 1)\n m_i_new = tl.maximum(m_i, m_ij)\n p = tl.math.exp(qk - m_i_new[:, None])\n l_ij = tl.sum(p, 1)\n # -- update m_i and l_i\n\n alpha = tl.math.exp(m_i - m_i_new)\n l_i_new = alpha * l_i + l_ij\n # -- update output accumulator --\n # scale p\n # scale acc\n acc_scale = alpha\n # acc_scale = l_i / l_i_new * alpha\n acc = acc * acc_scale[:, None]\n # update acc\n v = tl.load(V_cache + off_v,\n mask=(start_n + offs_n[:, None]) < cur_batch_ctx_len,\n other=0.0)\n\n p = p.to(v.dtype)\n acc += tl.dot(p, v, allow_tf32=False)\n # update m_i and l_i\n l_i = l_i_new\n m_i = m_i_new\n\n off_k = (offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh +\n offs_d[:, None] * stride_kd)\n off_v = (offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh +\n offs_d[None, :] * stride_vd)\n k_ptrs = K + off_k\n v_ptrs = V + off_v\n\n block_mask = tl.where(\n block_start_loc < cur_batch_seq_len - cur_batch_ctx_len, 1, 0)\n\n # init alibi\n alibi_slope = tl.load(Alibi_slopes + cur_head)\n alibi_start_q = tl.arange(\n 0, BLOCK_M) + block_start_loc + cur_batch_ctx_len\n alibi_start_k = cur_batch_ctx_len\n # # init debugger\n # offset_db_q = tl.arange(0, BLOCK_M) + block_start_loc\n # offset_db_k = tl.arange(0, BLOCK_N)\n # calc q[BLOCK_M, BLOCK_MODEL] mul k[prefix_len: , BLOCK_DMODEL]\n for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N):\n start_n = tl.multiple_of(start_n, BLOCK_N)\n # -- compute qk ----\n k = tl.load(k_ptrs +\n (cur_batch_in_all_start_index + start_n) * stride_kbs,\n mask=(start_n + offs_n[None, :]) <\n cur_batch_seq_len - cur_batch_ctx_len,\n other=0.0)\n\n qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)\n qk += tl.dot(q, k, allow_tf32=False)\n qk *= sm_scale\n qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk,\n float(\"-inf\"))\n\n # load alibi\n alibi = (tl.arange(0, BLOCK_N)[None, :] + alibi_start_k -\n alibi_start_q[:, None]) * alibi_slope\n alibi = tl.where(\n (alibi <= 0) & (alibi_start_q[:, None] < cur_batch_seq_len),\n alibi, float(\"-inf\"))\n qk += alibi\n alibi_start_k += BLOCK_N\n\n # -- compute m_ij, p, l_ij\n m_ij = tl.max(qk, 1)\n m_i_new = tl.maximum(m_i, m_ij)\n p = tl.math.exp(qk - m_i_new[:, None])\n l_ij = tl.sum(p, 1)\n # -- update m_i and l_i\n\n alpha = tl.math.exp(m_i - m_i_new)\n l_i_new = alpha * l_i + l_ij\n # -- update output accumulator --\n # scale p\n # scale acc\n acc_scale = alpha\n # acc_scale = l_i / l_i_new * alpha\n acc = acc * acc_scale[:, None]\n # update acc\n v = tl.load(v_ptrs +\n (cur_batch_in_all_start_index + start_n) * stride_vbs,\n mask=(start_n + offs_n[:, None]) <\n cur_batch_seq_len - cur_batch_ctx_len,\n other=0.0)\n\n p = p.to(v.dtype)\n acc += tl.dot(p, v, allow_tf32=False)\n # update m_i and l_i\n l_i = l_i_new\n m_i = m_i_new\n\n acc = acc / l_i[:, None]\n\n # initialize pointers to output\n off_o = (\n (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs +\n cur_head * stride_oh + offs_d[None, :] * stride_od)\n out_ptrs = Out + off_o\n tl.store(out_ptrs,\n acc,\n mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len)\n return\n\n @torch.inference_mode()\n def context_attention_fwd(q,\n k,\n v,\n o,\n k_cache,\n v_cache,\n b_loc,\n b_start_loc,\n b_seq_len,\n b_ctx_len,\n max_input_len,\n alibi_slopes=None):\n\n cap = torch.cuda.get_device_capability()\n BLOCK = 128 if cap[0] >= 8 else 64\n # shape constraints\n Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]\n assert Lq == Lk and Lk == Lv\n assert Lk in {16, 32, 64, 128}\n\n sm_scale = 1.0 / (Lq**0.5)\n batch, head = b_seq_len.shape[0], q.shape[1]\n num_queries_per_kv = q.shape[1] // k.shape[1]\n\n grid = (batch, head, triton.cdiv(max_input_len, BLOCK)) # batch, head,\n\n num_warps = 8 if Lk <= 64 else 8\n if alibi_slopes is not None:\n _fwd_kernel_alibi[grid](\n q,\n k,\n v,\n k_cache,\n v_cache,\n b_loc,\n sm_scale,\n b_start_loc,\n b_seq_len,\n b_ctx_len,\n alibi_slopes,\n v_cache.shape[3],\n 8,\n o,\n b_loc.stride(0),\n b_loc.stride(1),\n q.stride(0),\n q.stride(1),\n q.stride(2),\n k.stride(0),\n k.stride(1),\n k.stride(2),\n v.stride(0),\n v.stride(1),\n v.stride(2),\n o.stride(0),\n o.stride(1),\n o.stride(2),\n k_cache.stride(0),\n k_cache.stride(1),\n k_cache.stride(2),\n k_cache.stride(3),\n k_cache.stride(\n 4\n ), #[num_blocks, num_kv_heads, head_size/x, block_size, x]\n v_cache.stride(0),\n v_cache.stride(1),\n v_cache.stride(2),\n v_cache.stride(\n 3), #[num_blocks, num_kv_heads, head_size, block_size]\n num_queries_per_kv=num_queries_per_kv,\n BLOCK_M=BLOCK,\n BLOCK_DMODEL=Lk,\n BLOCK_N=BLOCK,\n num_warps=num_warps,\n num_stages=1,\n )\n return\n\n _fwd_kernel[grid](\n q,\n k,\n v,\n k_cache,\n v_cache,\n b_loc,\n sm_scale,\n b_start_loc,\n b_seq_len,\n b_ctx_len,\n v_cache.shape[3],\n 8,\n o,\n b_loc.stride(0),\n b_loc.stride(1),\n q.stride(0),\n q.stride(1),\n q.stride(2),\n k.stride(0),\n k.stride(1),\n k.stride(2),\n v.stride(0),\n v.stride(1),\n v.stride(2),\n o.stride(0),\n o.stride(1),\n o.stride(2),\n k_cache.stride(0),\n k_cache.stride(1),\n k_cache.stride(2),\n k_cache.stride(3),\n k_cache.stride(\n 4), #[num_blocks, num_kv_heads, head_size/x, block_size, x]\n v_cache.stride(0),\n v_cache.stride(1),\n v_cache.stride(2),\n v_cache.stride(\n 3), #[num_blocks, num_kv_heads, head_size, block_size]\n num_queries_per_kv=num_queries_per_kv,\n BLOCK_M=BLOCK,\n BLOCK_DMODEL=Lk,\n BLOCK_N=BLOCK,\n num_warps=num_warps,\n num_stages=1,\n )\n return\n", - "description_1": "Use triton language to implement multiple forward kernels for context attention. The kernels use triton's JIT compilation for parallel execution on CUDA devices. The _fwd_kernel function accepts 45 parameters, performing scaled dot-product attention with input tensors Q, K, and V, handling caching with K_cache and V_cache, and adjusting for sequence lengths, block sizes, and strides. It calculates attention scores, updates, and writes the results to Out tensor. The function allows configuration of constants BLOCK_M, BLOCK_DMODEL, and BLOCK_N. _fwd_kernel_alibi enhances this operation by applying alibi biases to attention calculations, making use of Alibi_slopes among 47 input parameters. Finally, the context_attention_fwd wrapper function, decorated with torch's inference_mode, orchestrates these kernel launches based on input arguments and CUDA capabilities, selecting the appropriate kernel and configuring its execution grid.", - "description_2": "Use triton language to implement forward kernels for context attention computation with dot-product and alibi biasing, compiled for parallel execution on CUDA, and utilize a wrapper function to manage kernel selection and execution configurations.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef fused_moe_kernel(\n a_ptr,\n b_ptr,\n c_ptr,\n topk_weights_ptr,\n sorted_token_ids_ptr,\n expert_ids_ptr,\n num_tokens_post_padded_ptr,\n N,\n K,\n EM,\n num_valid_tokens,\n stride_am,\n stride_ak,\n stride_be,\n stride_bk,\n stride_bn,\n stride_cm,\n stride_cn,\n BLOCK_SIZE_M: tl.constexpr,\n BLOCK_SIZE_N: tl.constexpr,\n BLOCK_SIZE_K: tl.constexpr,\n GROUP_SIZE_M: tl.constexpr,\n MUL_ROUTED_WEIGHT: tl.constexpr,\n top_k: tl.constexpr,\n compute_type: tl.constexpr,\n):\n pid = tl.program_id(axis=0)\n num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M)\n num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)\n num_pid_in_group = GROUP_SIZE_M * num_pid_n\n group_id = pid // num_pid_in_group\n first_pid_m = group_id * GROUP_SIZE_M\n group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)\n pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)\n pid_n = (pid % num_pid_in_group) // group_size_m\n\n num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr)\n if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded:\n return\n offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_token = tl.load(sorted_token_ids_ptr + offs_token_id)\n token_mask = offs_token < num_valid_tokens\n\n offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N\n offs_k = tl.arange(0, BLOCK_SIZE_K)\n a_ptrs = a_ptr + (offs_token[:, None] // top_k * stride_am +\n offs_k[None, :] * stride_ak)\n\n off_experts = tl.load(expert_ids_ptr + pid_m)\n b_ptrs = b_ptr + off_experts * stride_be + (offs_k[:, None] * stride_bk +\n offs_bn[None, :] * stride_bn)\n\n accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n\n for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):\n a = tl.load(a_ptrs,\n mask=token_mask[:, None] &\n (offs_k[None, :] < K - k * BLOCK_SIZE_K),\n other=0.0)\n b = tl.load(b_ptrs,\n mask=offs_k[:, None] < K - k * BLOCK_SIZE_K,\n other=0.0)\n accumulator += tl.dot(a, b)\n a_ptrs += BLOCK_SIZE_K * stride_ak\n b_ptrs += BLOCK_SIZE_K * stride_bk\n\n if MUL_ROUTED_WEIGHT:\n moe_weight = tl.load(topk_weights_ptr + offs_token,\n mask=token_mask,\n other=0)\n accumulator = accumulator * moe_weight[:, None]\n\n accumulator = accumulator.to(compute_type)\n offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[\n None, :]\n c_mask = token_mask[:, None] & (offs_cn[None, :] < N)\n tl.store(c_ptrs, accumulator, mask=c_mask)\n\n\ndef invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,\n topk_weights: torch.Tensor, topk_ids: torch.Tensor,\n sorted_token_ids: torch.Tensor,\n expert_ids: torch.Tensor,\n num_tokens_post_padded: torch.Tensor,\n mul_routed_weight: bool, top_k: int,\n config: Dict[str, Any]) -> None:\n assert topk_weights.stride(1) == 1\n assert sorted_token_ids.stride(0) == 1\n\n grid = lambda META: (triton.cdiv(sorted_token_ids.shape[0], META[\n 'BLOCK_SIZE_M']) * triton.cdiv(B.shape[1], META['BLOCK_SIZE_N']), )\n\n fused_moe_kernel[grid](\n A,\n B,\n C,\n topk_weights,\n sorted_token_ids,\n expert_ids,\n num_tokens_post_padded,\n B.shape[1],\n B.shape[2],\n sorted_token_ids.shape[0],\n topk_ids.numel(),\n A.stride(0),\n A.stride(1),\n B.stride(0),\n B.stride(2),\n B.stride(1),\n C.stride(1),\n C.stride(2),\n MUL_ROUTED_WEIGHT=mul_routed_weight,\n top_k=top_k,\n compute_type=tl.bfloat16 if A.dtype == torch.bfloat16 else tl.float16,\n **config,\n )\n", - "description_1": "Use triton language to implement a fused Mixture of Experts (MoE) kernel. The kernel takes pointers to input matrices, dimensions, and meta-parameters to perform block matrix multiplication. It computes the product of a token matrix and an expert matrix, using sorted token IDs and expert IDs to determine the correct expert for each token. The kernel supports optional multiplication by routed weights and writes the result back to an output matrix. The invoke function sets up the grid and calls the kernel with the appropriate parameters.", - "description_2": "Use triton language to implement a fused MoE kernel for block matrix multiplication with optional routed weight multiplication, and provide a function to invoke this kernel with grid setup.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\ndef seeded_uniform(\n *size,\n seeds: torch.Tensor,\n out: Optional[torch.Tensor] = None,\n dtype: Optional[torch.dtype] = None,\n device: Optional[Union[torch.device, str]] = None,\n pin_memory: Optional[bool] = False,\n) -> torch.Tensor:\n \"\"\"Similar to torch.rand, but allows for seeds to be set per row.\"\"\"\n n_dims = len(size)\n\n if n_dims > 3:\n raise ValueError(\"seeded_uniform only supports up to 3D tensors\")\n\n if out is None:\n out = torch.empty(*size,\n dtype=dtype,\n device=device,\n pin_memory=pin_memory)\n elif out.shape != size:\n raise ValueError(\"shape of out and size must be the same\")\n\n if n_dims == 3:\n n_rows, n_3d, n_cols = out.shape\n stride_row = out.stride(0)\n stride_3d = out.stride(1)\n elif n_dims == 2:\n n_rows, n_cols = out.shape\n n_3d = 1\n stride_row = out.stride(0)\n stride_3d = 1\n else:\n n_cols = out.shape[0]\n n_rows = 1\n n_3d = 1\n stride_row = 1\n stride_3d = 1\n\n if seeds.ndim != 1:\n raise ValueError(\"seeds must be a 1D tensor\")\n\n if seeds.numel() != n_rows:\n raise ValueError(\n \"seeds must have the same number of elements as out has rows\")\n\n full_block_size = triton.next_power_of_2(n_cols)\n philox_block_size = max(full_block_size // 4, 1)\n n_slices = full_block_size // philox_block_size\n num_warps = 4\n if philox_block_size >= 8192:\n num_warps = 32\n elif philox_block_size >= 4096:\n num_warps = 16\n elif philox_block_size >= 2048:\n num_warps = 8\n\n _seeded_uniform_triton[(n_rows, n_3d)](\n out,\n seeds,\n stride_row,\n stride_3d,\n seeds.stride(0),\n n_rows,\n n_3d,\n n_cols,\n n_slices=n_slices,\n num_warps=num_warps,\n block_size=philox_block_size,\n )\n return out\n\n@triton.jit\ndef _seeded_uniform_triton(\n out_ptr: torch.Tensor,\n seed_ptr: torch.Tensor,\n out_row_stride: int,\n out_3d_stride: int,\n seed_row_stride: int,\n n_rows: int,\n n_3d: int,\n n_cols: int,\n n_slices: tl.constexpr,\n block_size: tl.constexpr,\n):\n \"\"\"\n Generate a random float32 number in [0, 1) for each element in the output\n tensor. The random numbers in a row generated using the seed for that row.\n \"\"\"\n tl.static_assert(n_slices > 0 and n_slices <= 4, \"0 < n_slices <= 4\")\n\n row_idx = tl.program_id(axis=0)\n three_d_idx = tl.program_id(axis=1)\n\n philox_offsets = tl.arange(0, block_size)\n seed = tl.load(seed_ptr + row_idx * seed_row_stride)\n if three_d_idx > 0:\n seed ^= three_d_idx\n out1, out2, out3, out4 = tl.rand4x(seed, philox_offsets)\n\n output_row_start_ptr = (out_ptr + row_idx * out_row_stride +\n three_d_idx * out_3d_stride)\n out1_offsets = philox_offsets\n tl.store(output_row_start_ptr + out1_offsets,\n out1,\n mask=out1_offsets < n_cols)\n if n_slices > 1:\n out2_offsets = tl.arange(block_size, block_size * 2)\n tl.store(output_row_start_ptr + out2_offsets,\n out2,\n mask=out2_offsets < n_cols)\n if n_slices > 2:\n out3_offsets = tl.arange(block_size * 2, block_size * 3)\n tl.store(output_row_start_ptr + out3_offsets,\n out3,\n mask=out3_offsets < n_cols)\n if n_slices > 3:\n out4_offsets = tl.arange(block_size * 3, block_size * 4)\n tl.store(output_row_start_ptr + out4_offsets,\n out4,\n mask=out4_offsets < n_cols)\n", - "description_1": "Use triton language to implement a seeded uniform random number generator. The function `seeded_uniform` takes parameters for tensor size, seeds, output tensor, data type, device, and pin memory. It calculates the necessary strides and block sizes, then calls the Triton kernel `_seeded_uniform_triton`. The kernel generates random float32 numbers in [0, 1) for each element in the output tensor using per-row seeds. It handles up to 3D tensors and uses the Philox PRNG to generate random numbers efficiently.", - "description_2": "Use triton language to create a random number generator that produces float32 numbers in [0, 1) for each element in a tensor, using per-row seeds and the Philox PRNG for efficiency.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n_EPS = 1e-6\nMAX_TRITON_N_COLS = 131072\n\n@triton.jit\ndef _uniform_to_exponential(uniform_noise):\n \"\"\"Convert uniform samples to exponential samples.\"\"\"\n lb = tl.full(uniform_noise.shape, _EPS, uniform_noise.dtype)\n uniform_noise = tl.maximum(uniform_noise, lb)\n exponential_noise = -tl.log(uniform_noise)\n return exponential_noise\n\n@triton.jit\ndef _sample_triton(\n sample_indices_ptr: torch.Tensor, output_ptr: torch.Tensor,\n output_logprobs_ptr: torch.Tensor,\n output_modified_probs_ptr: torch.Tensor, probs_ptr: torch.Tensor,\n logprobs_ptr: torch.Tensor, seeds_ptr: torch.Tensor,\n uniform_noise_ptr: torch.Tensor, output_row_stride: int,\n probs_row_stride: int, uniform_noise_row_stride: int,\n uniform_noise_best_stride: int, n_samples: int, n_cols: int,\n n_best: int, block_size: tl.constexpr,\n modify_greedy_probs: tl.constexpr, save_logprobs: tl.constexpr,\n save_modified_probs: tl.constexpr):\n # The rows are independent, so we parallelize across those\n sample_idx = tl.program_id(0)\n best_idx = tl.program_id(1)\n\n # Load the row index from DRAM\n row_idx = tl.load(sample_indices_ptr + sample_idx)\n seed = tl.load(seeds_ptr + sample_idx)\n uses_random_sampling = seed != 0\n\n # The stride represents how much we need to increase the\n # pointer to advance 1 row\n row_start_ptr = probs_ptr + row_idx * probs_row_stride\n\n # The block size is the next power of two greater than n_cols,\n # so we can fit each row in a single block\n col_offsets = tl.arange(0, block_size)\n\n # Load the row into SRAM, using a mask since block_size may be > than n_cols\n row = tl.load(row_start_ptr + col_offsets,\n mask=col_offsets < n_cols,\n other=float(\"-inf\"))\n\n if uses_random_sampling:\n uniform_noise_start_ptr = (uniform_noise_ptr +\n sample_idx * uniform_noise_row_stride +\n best_idx * uniform_noise_best_stride)\n uniform_noise = tl.load(uniform_noise_start_ptr + col_offsets,\n mask=col_offsets < n_cols,\n other=0.5)\n exponential_noise = _uniform_to_exponential(uniform_noise)\n row /= exponential_noise\n\n sampled_value, sampled_token = tl.max(row, axis=0, return_indices=True)\n # clamp sampled token to n_cols - 1\n if sampled_token >= n_cols:\n sampled_token = n_cols - 1\n # Write back output to DRAM\n output_row_start_ptr = (output_ptr + sample_idx * output_row_stride +\n best_idx)\n tl.store(output_row_start_ptr, sampled_token)\n\n if modify_greedy_probs:\n if not uses_random_sampling:\n # Set the probability of the sampled token to 1, all other\n # tokens to zero.\n row = tl.where(col_offsets == sampled_token, 1.0, 0.0)\n tl.store(row_start_ptr + col_offsets,\n row,\n mask=col_offsets < n_cols)\n\n if save_modified_probs:\n output_row_start_ptr = (output_modified_probs_ptr +\n sample_idx * output_row_stride + best_idx)\n tl.store(output_row_start_ptr, sampled_value)\n\n if save_logprobs:\n sampled_logprob = tl.load(logprobs_ptr + row_idx * probs_row_stride +\n sampled_token)\n output_row_start_ptr = (output_logprobs_ptr +\n sample_idx * output_row_stride + best_idx)\n tl.store(output_row_start_ptr, sampled_logprob)\n", - "description_1": "Use triton language to implement a token sampling operator with two kernels: one to convert uniform noise into exponential noise, and another to sample tokens from given probabilities. The main sampling kernel takes 17 arguments: sample indices, output samples, output logprobs, output modified probs, probabilities, logprobabilities, seeds, uniform noise, output row stride, probs row stride, uniform noise row stride, uniform noise best stride, number of samples, number of columns, number of best tokens, block size, modify greedy probs flag, save logprobs flag, and save modified probs flag. The first kernel (_uniform_to_exponential) converts uniform noise to exponential noise using logarithmic transformation to assist in sampling. The second kernel (_sample_triton) performs the sampling using the Gumbel-max trick with options to modify probabilities for speculative decoding and save various results.", - "description_2": "Use triton language to implement a token sampling operator by creating a main sampling kernel that performs Gumbel-max trick sampling and auxiliary kernel to convert uniform noise into exponential noise.", - "difficulty": 2 - }, - { - "code": "import math\nimport torch\nimport triton_pre_mlir as triton\nimport triton_pre_mlir.language as tl\n\n# Triton kernel for forward pass\n@triton.heuristics({'EVEN_M': lambda args: args['seqlen_q'] % args['BLOCK_M'] == 0, 'EVEN_N': lambda args: args['seqlen_k'] % args['BLOCK_N'] == 0, 'EVEN_HEADDIM': lambda args: args['headdim'] == args['BLOCK_HEADDIM']})\n@triton.jit\ndef _fwd_kernel(Q, K, V, Bias, Out, Lse, TMP, softmax_scale, stride_qb, stride_qh, stride_qm, stride_kb, stride_kh, stride_kn, stride_vb, stride_vh, stride_vn, stride_bb, stride_bh, stride_bm, stride_ob, stride_oh, stride_om, nheads, seqlen_q, seqlen_k, seqlen_q_rounded, headdim, CACHE_KEY_SEQLEN_Q, CACHE_KEY_SEQLEN_K, BIAS_TYPE: tl.constexpr, IS_CAUSAL: tl.constexpr, BLOCK_HEADDIM: tl.constexpr, EVEN_M: tl.constexpr, EVEN_N: tl.constexpr, EVEN_HEADDIM: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):\n # Kernel implementation\n ...\n\n# Triton kernel for backward pass\n@triton.jit\ndef _bwd_preprocess_do_o_dot(Out, DO, Delta, stride_ob, stride_oh, stride_om, stride_dob, stride_doh, stride_dom, nheads, seqlen_q, seqlen_q_rounded, headdim, BLOCK_M: tl.constexpr, BLOCK_HEADDIM: tl.constexpr):\n # Kernel implementation\n ...\n\n@triton.jit\ndef _bwd_store_dk_dv(dk_ptrs, dv_ptrs, dk, dv, offs_n, offs_d, seqlen_k, headdim, EVEN_M: tl.constexpr, EVEN_N: tl.constexpr, EVEN_HEADDIM: tl.constexpr):\n # Kernel implementation\n ...\n\n@triton.jit\ndef _bwd_kernel_one_col_block(start_n, Q, K, V, Bias, DO, DQ, DK, DV, LSE, D, softmax_scale, stride_qm, stride_kn, stride_vn, stride_bm, stride_dom, stride_dqm, stride_dkn, stride_dvn, seqlen_q, seqlen_k, headdim, ATOMIC_ADD: tl.constexpr, BIAS_TYPE: tl.constexpr, IS_CAUSAL: tl.constexpr, BLOCK_HEADDIM: tl.constexpr, EVEN_M: tl.constexpr, EVEN_N: tl.constexpr, EVEN_HEADDIM: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):\n # Kernel implementation\n ...\n\n@triton.autotune(configs=[triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'SEQUENCE_PARALLEL': False}, num_warps=8, num_stages=1, pre_hook=init_to_zero('DQ')), triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'SEQUENCE_PARALLEL': True}, num_warps=8, num_stages=1, pre_hook=init_to_zero('DQ'))], key=['CACHE_KEY_SEQLEN_Q', 'CACHE_KEY_SEQLEN_K', 'BIAS_TYPE', 'IS_CAUSAL', 'BLOCK_HEADDIM'])\n@triton.heuristics({'EVEN_M': lambda args: args['seqlen_q'] % args['BLOCK_M'] == 0, 'EVEN_N': lambda args: args['seqlen_k'] % args['BLOCK_N'] == 0, 'EVEN_HEADDIM': lambda args: args['headdim'] == args['BLOCK_HEADDIM']})\n@triton.jit\ndef _bwd_kernel(Q, K, V, Bias, DO, DQ, DK, DV, LSE, D, softmax_scale, stride_qb, stride_qh, stride_qm, stride_kb, stride_kh, stride_kn, stride_vb, stride_vh, stride_vn, stride_bb, stride_bh, stride_bm, stride_dob, stride_doh, stride_dom, stride_dqb, stride_dqh, stride_dqm, stride_dkb, stride_dkh, stride_dkn, stride_dvb, stride_dvh, stride_dvn, nheads, seqlen_q, seqlen_k, seqlen_q_rounded, headdim, CACHE_KEY_SEQLEN_Q, CACHE_KEY_SEQLEN_K, BIAS_TYPE: tl.constexpr, IS_CAUSAL: tl.constexpr, BLOCK_HEADDIM: tl.constexpr, SEQUENCE_PARALLEL: tl.constexpr, EVEN_M: tl.constexpr, EVEN_N: tl.constexpr, EVEN_HEADDIM: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):\n # Kernel implementation\n ...\n\ndef _flash_attn_forward(q, k, v, bias=None, causal=False, softmax_scale=None):\n # Function implementation\n ...\n\ndef _flash_attn_backward(do, q, k, v, o, lse, dq, dk, dv, bias=None, causal=False, softmax_scale=None):\n # Function implementation\n ...\n\nclass FlashAttnQKVPackedFunc(torch.autograd.Function):\n # Autograd function implementation\n ...\n\nflash_attn_qkvpacked_func = FlashAttnQKVPackedFunc.apply\n\nclass FlashAttnKVPackedFunc(torch.autograd.Function):\n # Autograd function implementation\n ...\n\nflash_attn_kvpacked_func = FlashAttnKVPackedFunc.apply\n\nclass FlashAttnFunc(torch.autograd.Function):\n # Autograd function implementation\n ...\n\nflash_attn_func = FlashAttnFunc.apply\n", - "description_1": "Use triton language to implement FlashAttention forward and backward kernels for processing queries, keys, and values with optional bias and causal masking. Implement forward kernels (`_fwd_kernel`) with parameters for query (Q), key (K), value (V), Bias, and output tensor (Out). Backward kernels (`_bwd_kernel`, `_bwd_preprocess_do_o_dot`, etc.) compute gradients with respect to inputs by processing deltas (DO), LSE, and other tensors. Includes classes `FlashAttnQKVPackedFunc`, `FlashAttnKVPackedFunc`, and `FlashAttnFunc` for torch.autograd.Function applications.", - "description_2": "Implement Triton kernels to perform FlashAttention operations with forward and backward passes, incorporating inputs such as queries, keys, values, biases, and causal flags.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton_pre_mlir as triton\nimport triton_pre_mlir.language as tl\nimport math\n\n@triton.heuristics(\n {\n \"EVEN_M\": lambda args: args[\"seqlen_q\"] % args[\"BLOCK_M\"] == 0,\n \"EVEN_N\": lambda args: args[\"seqlen_k\"] % args[\"BLOCK_N\"] == 0,\n \"EVEN_HEADDIM\": lambda args: args[\"headdim\"] == args[\"BLOCK_HEADDIM\"],\n }\n)\n@triton.jit\ndef _fwd_kernel(\n Q, K, V, Bias, Out, Lse, TMP, softmax_scale,\n stride_qb, stride_qh, stride_qm, stride_kb, stride_kh, stride_kn,\n stride_vb, stride_vh, stride_vn, stride_bb, stride_bh, stride_bm,\n stride_ob, stride_oh, stride_om, nheads, seqlen_q, seqlen_k,\n seqlen_q_rounded, headdim, CACHE_KEY_SEQLEN_Q, CACHE_KEY_SEQLEN_K,\n BIAS_TYPE: tl.constexpr, IS_CAUSAL: tl.constexpr, BLOCK_HEADDIM: tl.constexpr,\n EVEN_M: tl.constexpr, EVEN_N: tl.constexpr, EVEN_HEADDIM: tl.constexpr,\n BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,\n):\n # Triton kernel implementation\n ...\n\n@triton.jit\ndef _bwd_preprocess_do_o_dot(\n Out, DO, Delta, stride_ob, stride_oh, stride_om,\n stride_dob, stride_doh, stride_dom, nheads, seqlen_q,\n seqlen_q_rounded, headdim, BLOCK_M: tl.constexpr,\n BLOCK_HEADDIM: tl.constexpr,\n):\n # Triton kernel implementation\n ...\n\n@triton.jit\ndef _bwd_kernel_one_col_block(\n start_n, Q, K, V, Bias, DO, DQ, DK, DV, LSE, D,\n softmax_scale, stride_qm, stride_kn, stride_vn, stride_bm,\n stride_dom, stride_dqm, stride_dkn, stride_dvn,\n seqlen_q, seqlen_k, headdim, ATOMIC_ADD: tl.constexpr,\n BIAS_TYPE: tl.constexpr, IS_CAUSAL: tl.constexpr,\n BLOCK_HEADDIM: tl.constexpr, EVEN_M: tl.constexpr,\n EVEN_N: tl.constexpr, EVEN_HEADDIM: tl.constexpr,\n BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,\n):\n # Triton kernel implementation\n ...\n\n@triton.autotune(\n configs=[\n triton.Config(\n {\"BLOCK_M\": 128, \"BLOCK_N\": 128, \"SEQUENCE_PARALLEL\": False},\n num_warps=8,\n num_stages=1,\n pre_hook=init_to_zero(\"DQ\"),\n ),\n triton.Config(\n {\"BLOCK_M\": 128, \"BLOCK_N\": 128, \"SEQUENCE_PARALLEL\": True},\n num_warps=8,\n num_stages=1,\n pre_hook=init_to_zero(\"DQ\"),\n ),\n ],\n key=[\n \"CACHE_KEY_SEQLEN_Q\", \"CACHE_KEY_SEQLEN_K\", \"BIAS_TYPE\",\n \"IS_CAUSAL\", \"BLOCK_HEADDIM\",\n ],\n)\n@triton.heuristics(\n {\n \"EVEN_M\": lambda args: args[\"seqlen_q\"] % args[\"BLOCK_M\"] == 0,\n \"EVEN_N\": lambda args: args[\"seqlen_k\"] % args[\"BLOCK_N\"] == 0,\n \"EVEN_HEADDIM\": lambda args: args[\"headdim\"] == args[\"BLOCK_HEADDIM\"],\n }\n)\n@triton.jit\ndef _bwd_kernel(\n Q, K, V, Bias, DO, DQ, DK, DV, LSE, D, softmax_scale,\n stride_qb, stride_qh, stride_qm, stride_kb, stride_kh, stride_kn,\n stride_vb, stride_vh, stride_vn, stride_bb, stride_bh, stride_bm,\n stride_dob, stride_doh, stride_dom, stride_dqb, stride_dqh, stride_dqm,\n stride_dkb, stride_dkh, stride_dkn, stride_dvb, stride_dvh, stride_dvn,\n nheads, seqlen_q, seqlen_k, seqlen_q_rounded, headdim,\n CACHE_KEY_SEQLEN_Q, CACHE_KEY_SEQLEN_K, BIAS_TYPE: tl.constexpr,\n IS_CAUSAL: tl.constexpr, BLOCK_HEADDIM: tl.constexpr, SEQUENCE_PARALLEL: tl.constexpr,\n EVEN_M: tl.constexpr, EVEN_N: tl.constexpr, EVEN_HEADDIM: tl.constexpr,\n BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,\n):\n # Triton kernel implementation\n ...\n\ndef _flash_attn_forward(q, k, v, bias=None, causal=False, softmax_scale=None):\n # Function description\n ...\n\ndef _flash_attn_backward(\n do, q, k, v, o, lse, dq, dk, dv, bias=None, causal=False, softmax_scale=None\n):\n # Function description\n ...\n", - "description_1": "Use triton language to implement forward and backward pass for FlashAttention kernels. The _fwd_kernel function has 35 parameters including input tensors, strides, dimensions, constants, and configurations. It computes scaled dot-product attention with optional bias and causal masking. The _bwd_preprocess_do_o_dot function has 12 parameters and computes an intermediate delta for backward pass. The _bwd_kernel_one_col_block has 36 parameters and computes gradients with respect to Q, K, V, and optional bias. Finally, the _bwd_kernel orchestrates the backward pass with 42 parameters. Additionally, _flash_attn_forward wraps the forward kernel and configures it for specific input conditions while _flash_attn_backward wraps backward kernels for computing gradients.", - "description_2": "Use triton language to implement FlashAttention forward kernel with support for causal and non-causal attention, handling up to 128 head dimensions. Use triton to implement backward kernel to compute gradients for Q, K, V using pre-processed intermediate results from forward pass.", - "difficulty": 4 - }, - { - "code": "import triton\nimport triton.language as tl\nimport torch\n\ndef sum_row_blocked(A: torch.Tensor) -> torch.Tensor:\n M, N = A.shape\n outputs = torch.empty((M,), dtype=A.dtype, device=A.device)\n\n dynamic_launch_grid = lambda params: (triton.cdiv(M, params[\"BLOCK_M\"]), )\n sum_row_blocked_kernel[dynamic_launch_grid](\n A_ptr=A, outputs_ptr=outputs,\n M=M, N=N,\n A_strides_x=A.stride(0), A_strides_y=A.stride(1),\n BLOCK_M=2,\n )\n\n return outputs\n\n@triton.jit\ndef sum_row_blocked_kernel(\n A_ptr, outputs_ptr,\n M, N,\n BLOCK_M,\n A_strides_x, A_strides_y,\n):\n program_id = tl.program_id(axis=0)\n input_block_ptr = tl.make_block_ptr(\n base=A_ptr,\n shape=(M, N),\n strides=(A_strides_x, A_strides_y),\n offsets=(program_id * BLOCK_M, 0),\n block_shape=(BLOCK_M, N),\n order=(1, 0),\n )\n", - "description_1": "Use triton language to create a kernel 'sum_row_blocked_kernel' that processes a tensor by dividing it into blocks of rows, specified by 'BLOCK_M'. The kernel takes pointers to input tensor 'A_ptr' and output 'outputs_ptr', along with their dimensions 'M' and 'N', and strides 'A_strides_x', 'A_strides_y'. The launching function 'sum_row_blocked' sets up the grid using the dynamic launch grid calculation.", - "description_2": "Use triton language to implement a row-wise block processing kernel that divides the tensor into row blocks and assigns a block to each program, using dynamic grid launch.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef _fwd_kernel_aligned(\n Q, K, V, B0, sm_scale,\n Out,\n stride_qh, stride_qm, stride_qk,\n stride_kh, stride_kn, stride_kk,\n stride_vh, stride_vk, stride_vn,\n stride_oh, stride_om, stride_on,\n stride_b0h, stride_b0m,\n Z,\n H,\n N_CTX,\n P_SEQ,\n OUT_DTYPE: tl.constexpr,\n BIAS_LAST_SIZE: tl.constexpr,\n B0_NUMEL: tl.constexpr,\n BLOCK_DMODEL: tl.constexpr,\n BLOCK_M: tl.constexpr,\n BLOCK_N: tl.constexpr,\n):\n # Triton kernel for computing attention with alignment considerations\n\ndef _attention_rel_h_rel_w_kernel_aligned_device(q, k, v, rel_h_w, sm_scale, o,\n BLOCK_M,\n BLOCK_N,\n num_warps,\n num_stages):\n # Function to prepare and launch the Triton kernel on the device\n grid = (triton.cdiv(q.shape[2], BLOCK_M), q.shape[0] * q.shape[1], 1)\n P_SEQ = 0 if q.shape[-2] == k.shape[-2] else k.shape[-2] - q.shape[-2]\n assert P_SEQ == 0\n _fwd_kernel_aligned[grid](\n q, k, v,\n rel_h_w,\n sm_scale,\n o,\n q.stride(1), q.stride(2), q.stride(3),\n k.stride(1), k.stride(2), k.stride(3),\n v.stride(1), v.stride(2), v.stride(3),\n o.stride(1), o.stride(2), o.stride(3),\n rel_h_w.stride(1), rel_h_w.stride(2),\n q.shape[0],\n q.shape[1],\n q.shape[2],\n P_SEQ,\n OUT_DTYPE=tl.float16 if q.dtype == torch.float16 else tl.bfloat16,\n BIAS_LAST_SIZE=(rel_h_w.size(-1) // 2),\n B0_NUMEL=rel_h_w.size(-1),\n BLOCK_M=BLOCK_M,\n BLOCK_N=BLOCK_N,\n BLOCK_DMODEL=q.shape[-1],\n num_warps=num_warps,\n num_stages=num_stages)\n\n@torch.library.impl(lib, \"custom_flash_aligned\", \"CUDA\")\ndef _attention_rel_h_rel_w_kernel_aligned(q, k, v, rel_h_w, sm_scale):\n # Custom CUDA implementation for Flash Attention using Triton\n q = q.contiguous()\n k = k.contiguous()\n v = v.contiguous()\n o = torch.empty_like(q, memory_format=torch.contiguous_format)\n\n global BEST_CONFIGS\n if BEST_CONFIGS is None:\n BEST_CONFIGS = _load_best_configs()\n if BEST_CONFIGS is None:\n BEST_CONFIGS = {}\n key = _create_best_configs_key(q, k, v, rel_h_w, o)\n if key not in BEST_CONFIGS:\n import functools\n import itertools\n configs = []\n for (BLOCK_M, BLOCK_N, num_warps) in itertools.product([64, 128], [64, 128], [1, 2, 4, 8]):\n for num_stages in range(1, num_warps + 1):\n configs.append((BLOCK_M, BLOCK_N, num_warps, num_stages))\n best, best_config = _autotune(configs, functools.partial(_attention_rel_h_rel_w_kernel_aligned_device,\n q, k, v, rel_h_w, sm_scale, o))\n BEST_CONFIGS[key] = best_config\n _save_best_configs(BEST_CONFIGS)\n best_config = BEST_CONFIGS[key]\n if best_config is None:\n return torch.tensor([])\n\n _attention_rel_h_rel_w_kernel_aligned_device(q,\n k,\n v,\n rel_h_w,\n sm_scale,\n o,\n best_config[0],\n best_config[1],\n best_config[2],\n best_config[3])\n\n return o\n\n", - "description_1": "Use triton language to implement a forward kernel for aligned attention computation with parameters for query, key, value tensors, bias, and output. The kernel accounts for block sizes and strides to efficiently compute scaled dot-product attention with bias adjustments.", - "description_2": "Use triton language to compute aligned attention on CUDA devices, leveraging custom block configurations to optimize for specific device architectures and tensor shapes.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef update_fn_kernel(\n p_ptr,\n grad_ptr,\n exp_avg_ptr,\n lr,\n wd,\n beta1,\n beta2,\n n_elements,\n BLOCK_SIZE, # tl.constexpr\n):\n pid = tl.program_id(axis = 0)\n\n block_start = pid * BLOCK_SIZE\n offsets = block_start + tl.arange(0, BLOCK_SIZE)\n\n mask = offsets < n_elements\n\n # offsetted pointers\n offset_p_ptr = p_ptr + offsets\n offset_grad_ptr = grad_ptr + offsets\n offset_exp_avg_ptr = exp_avg_ptr + offsets\n\n # load\n p = tl.load(offset_p_ptr, mask = mask)\n grad = tl.load(offset_grad_ptr, mask = mask)\n exp_avg = tl.load(offset_exp_avg_ptr, mask = mask)\n\n # stepweight decay\n p = p * (1 - lr * wd)\n\n # diff between momentum running average and grad\n diff = exp_avg - grad\n\n # weight update\n update = diff * beta1 + grad\n\n # torch.sign\n can_update = update != 0\n update_sign = tl.where(update > 0, -lr, lr)\n\n p = p + update_sign * can_update\n\n # decay the momentum running average coefficient\n exp_avg = diff * beta2 + grad\n\n # store new params and momentum running average coefficient\n tl.store(offset_p_ptr, p, mask = mask)\n tl.store(offset_exp_avg_ptr, exp_avg, mask = mask)\n\ndef update_fn_triton(\n p: torch.Tensor,\n grad: torch.Tensor,\n exp_avg: torch.Tensor,\n lr: float,\n wd: float,\n beta1: float,\n beta2: float\n):\n assert all([t.is_cuda for t in (p, grad, exp_avg)])\n n_elements = p.numel()\n\n grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)\n\n update_fn_kernel[grid](\n p,\n grad,\n exp_avg,\n lr,\n wd,\n beta1,\n beta2,\n n_elements\n )\n", - "description_1": "Use triton language to create a kernel 'update_fn_kernel' and a wrapper function 'update_fn_triton'. The kernel performs the update of model parameters using gradient descent with momentum and weight decay, where each parameter is updated individually using input pointers 'p_ptr', 'grad_ptr', and 'exp_avg_ptr' for parameters, gradients, and exponential moving averages, respectively. The learning rate 'lr', weight decay 'wd', beta coefficients 'beta1', 'beta2', and the total number of elements 'n_elements' determine the update process. The 'BLOCK_SIZE' specifies the number of threads per block. The wrapper function 'update_fn_triton' prepares the grid and launches this kernel.", - "description_2": "Use triton language to implement a kernel that applies weight decay, momentum-based gradient updates, and manages the moving average of gradients for optimization tasks, wrapping it in a Python function to set up and execute the computation on GPU.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\nfrom functools import partial\nfrom torch.distributed._tensor.experimental import local_map\nfrom torch.distributed._tensor import Partial, Shard, Replicate\nimport math\n\n@triton.autotune(\n configs=[\n triton.Config({}, num_warps=1),\n triton.Config({}, num_warps=2),\n triton.Config({}, num_warps=4),\n triton.Config({}, num_warps=8),\n triton.Config({}, num_warps=16),\n triton.Config({}, num_warps=32),\n ],\n key=[\"N\"],\n)\n@triton.jit\ndef _rms_norm_fwd_kernel(\n X,\n stride_x,\n Y,\n stride_y,\n W,\n Rstd,\n eps,\n M, # num rows\n N, # num cols\n block_N: tl.constexpr,\n):\n row = tl.program_id(0)\n cols = tl.arange(0, block_N)\n\n # Load input data and weights\n mask = cols < N\n x = tl.load(X + row * stride_x + cols, mask=mask, other=0.0).to(tl.float32)\n w = tl.load(W + cols, mask=mask, other=0.0).to(tl.float32)\n\n # Compute mean and variance\n xbar = tl.where(cols < N, x, 0.0)\n var = tl.sum(xbar * xbar, axis=0) / N\n rstd = 1 / tl.sqrt(var + eps)\n\n # Store the reciprocal standard deviation\n tl.store(Rstd + row, rstd)\n\n # Normalize and apply linear transformation\n x_hat = x * rstd\n y = x_hat * w\n\n # Write output\n tl.store(Y + row * stride_y + cols, y, mask=mask)\n\n\n@triton.autotune(\n configs=[\n triton.Config({}, num_warps=1),\n triton.Config({}, num_warps=2),\n triton.Config({}, num_warps=4),\n triton.Config({}, num_warps=8),\n triton.Config({}, num_warps=16),\n triton.Config({}, num_warps=32),\n ],\n key=[\"N\"],\n)\n@triton.jit\ndef _rms_norm_bwd_kernel_sm(\n X,\n stride_x,\n W,\n DY,\n stride_dy,\n DX,\n stride_dx,\n Rstd,\n DW,\n eps,\n M, # num rows\n N, # num cols\n rows_per_program,\n block_N: tl.constexpr,\n):\n row_block_id = tl.program_id(0)\n row_start = row_block_id * rows_per_program\n cols = tl.arange(0, block_N)\n mask = cols < N\n\n # Load weights\n w = tl.load(W + cols, mask=mask, other=0.0).to(tl.float32)\n\n # Accumulate gradients for weights\n dw = tl.zeros((block_N,), dtype=tl.float32)\n\n row_end = min(row_start + rows_per_program, M)\n for row in range(row_start, row_end):\n # Load input, output gradient, and reciprocal standard deviation\n x = tl.load(X + row * stride_x + cols, mask=mask, other=0.0).to(tl.float32)\n dy = tl.load(DY + row * stride_dy + cols, mask=mask, other=0.0).to(tl.float32)\n rstd = tl.load(Rstd + row)\n\n # Compute normalized input and gradients\n x_hat = x * rstd\n wdy = w * dy\n dw += dy * x_hat\n c1 = tl.sum(x_hat * wdy, axis=0) / N\n dx = (wdy - x_hat * c1) * rstd\n\n # Store input gradient\n tl.store(DX + row * stride_dx + cols, dx, mask=mask)\n\n # Store weight gradients\n tl.store(DW + row_block_id * N + cols, dw, mask=mask)\n\n\nclass TritonFusedRMSNorm(torch.autograd.Function):\n @partial(\n local_map,\n out_placements=[Shard(1)],\n in_placements=(None, [Shard(1)], [Replicate()], None),\n )\n @staticmethod\n def forward(ctx, x, weight, eps):\n x_shape_start = x.shape\n\n # Flatten input\n x = x.view(-1, x.shape[-1])\n if x.stride(-1) != 1:\n x = x.contiguous()\n if weight.stride(-1) != 1:\n weight = weight.contiguous()\n\n M, N = x.shape\n y = torch.empty_like(x)\n rstd = torch.empty((M,), dtype=torch.float32, device=x.device)\n\n max_size = 65536 // x.element_size()\n block_N = min(max_size, triton.next_power_of_2(N))\n\n if N > block_N:\n raise ValueError(f\"N {N} must be <= {block_N=}\")\n\n grid = lambda meta: (M,)\n _rms_norm_fwd_kernel[grid](\n x,\n x.stride(0),\n y,\n y.stride(0),\n weight,\n rstd,\n eps,\n M,\n N,\n block_N,\n )\n\n ctx.eps = eps\n ctx.save_for_backward(x, weight, rstd)\n ctx.x_shape_start = x_shape_start\n\n y = y.reshape(x_shape_start)\n return y\n\n @partial(\n local_map,\n out_placements=([Shard(1)], [Partial()], None),\n in_placements=(None, [Shard(1)]),\n )\n @staticmethod\n def backward(ctx, dy):\n x, weight, rstd = ctx.saved_tensors\n eps = ctx.eps\n x_shape_start = ctx.x_shape_start\n\n # Flatten input and output gradients\n dy = dy.view(-1, dy.shape[-1])\n if dy.stride(-1) != 1:\n dy = dy.contiguous()\n\n M, N = dy.shape\n dx = torch.empty_like(x)\n\n sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count\n _dw = torch.empty((sm_count, N), dtype=torch.float32, device=weight.device)\n\n max_size = 65536 // x.element_size()\n block_N = min(max_size, triton.next_power_of_2(N))\n rows_per_sm = math.ceil(M / sm_count)\n\n if N > block_N:\n raise ValueError(f\"N {N} must be <= {block_N=}\")\n\n grid = lambda meta: (sm_count,)\n _rms_norm_bwd_kernel_sm[grid](\n x,\n x.stride(0),\n weight,\n dy,\n dy.stride(0),\n dx,\n dx.stride(0),\n rstd,\n _dw,\n eps,\n M,\n N,\n rows_per_sm,\n block_N,\n )\n dw = _dw.sum(0).to(weight.dtype)\n dx = dx.view(x_shape_start)\n return dx, dw, None\n\n\n# expose fusedRMSNorm as a function\ndef fused_rms_norm_fn(\n x,\n weight,\n eps=1e-6,\n):\n return TritonFusedRMSNorm.apply(\n x,\n weight,\n eps,\n )\n", - "description_1": "Use triton language to implement a fused RMS normalization operation with forward and backward kernels. The forward kernel (_rms_norm_fwd_kernel) takes 9 parameters: X (input tensor), stride_x (stride of X), Y (output tensor), stride_y (stride of Y), W (weights), Rstd (reciprocal standard deviation), eps (epsilon for numerical stability), M (number of rows), N (number of columns), and block_N (block size for columns). The backward kernel (_rms_norm_bwd_kernel_sm) takes 13 parameters: X (input tensor), stride_x (stride of X), W (weights), DY (gradient of output), stride_dy (stride of DY), DX (gradient of input), stride_dx (stride of DX), Rstd (reciprocal standard deviation), DW (gradient of weights), eps (epsilon for numerical stability), M (number of rows), N (number of columns), rows_per_program (number of rows per program), and block_N (block size for columns). The TritonFusedRMSNorm class provides the forward and backward methods for autograd, and the fused_rms_norm_fn function exposes the operation.", - "description_2": "Use triton language to create a fused RMS normalization operation with both forward and backward passes, utilizing triton.jit decorated kernels for efficient computation.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n# Forward kernel function for fused attention mechanism in Triton\n@triton.jit\ndef _fwd_kernel(\n Q, K, V, sm_scale, TMP, L, M, Out, # Inputs and buffers\n stride_qz, stride_qh, stride_qm, stride_qk,\n stride_kz, stride_kh, stride_kn, stride_kk,\n stride_vz, stride_vh, stride_vk, stride_vn,\n stride_oz, stride_oh, stride_om, stride_on,\n Z, H, N_CTX, BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr\n):\n # Kernel implementation\n\n# Backward preprocess kernel function for attention gradients in Triton\n@triton.jit\ndef _bwd_preprocess(\n Out, DO, L, NewDO, Delta, # Inputs and buffers\n BLOCK_M: tl.constexpr, D_HEAD: tl.constexpr\n):\n # Kernel implementation\n\n# Backward kernel function for fused attention mechanism in Triton\n@triton.jit\ndef _bwd_kernel(\n Q, K, V, sm_scale, Out, DO, DQ, DK, DV, L, M, D, # Inputs and buffers\n stride_qz, stride_qh, stride_qm, stride_qk,\n stride_kz, stride_kh, stride_kn, stride_kk,\n stride_vz, stride_vh, stride_vk, stride_vn,\n Z, H, N_CTX, num_block,\n BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr\n):\n # Kernel implementation\n\nclass _TritonFlashAttention(torch.autograd.Function):\n\n @staticmethod\n def forward(ctx, q, k, v, sm_scale):\n # Forward pass\n BLOCK = 128\n Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]\n assert Lq == Lk and Lk == Lv\n assert Lk in {16, 32, 64, 128}\n o = torch.empty_like(q)\n grid = (triton.cdiv(q.shape[2], BLOCK), q.shape[0] * q.shape[1])\n tmp = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)\n L = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)\n m = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)\n num_warps = 4 if Lk <= 64 else 8\n\n _fwd_kernel[grid](\n q, k, v, sm_scale, tmp, L, m, o,\n q.stride(0), q.stride(1), q.stride(2), q.stride(3),\n k.stride(0), k.stride(1), k.stride(2), k.stride(3),\n v.stride(0), v.stride(1), v.stride(2), v.stride(3),\n o.stride(0), o.stride(1), o.stride(2), o.stride(3),\n q.shape[0], q.shape[1], q.shape[2],\n BLOCK_M=BLOCK, BLOCK_N=BLOCK, BLOCK_DMODEL=Lk,\n num_warps=num_warps, num_stages=1,\n )\n ctx.save_for_backward(q, k, v, o, L, m)\n ctx.BLOCK = BLOCK\n ctx.grid = grid\n ctx.sm_scale = sm_scale\n ctx.BLOCK_DMODEL = Lk\n return o\n\n @staticmethod\n def backward(ctx, do):\n # Backward pass\n q, k, v, o, l, m = ctx.saved_tensors\n do = do.contiguous()\n dq = torch.zeros_like(q, dtype=torch.float32)\n dk = torch.empty_like(k)\n dv = torch.empty_like(v)\n do_scaled = torch.empty_like(do)\n delta = torch.empty_like(l)\n _bwd_preprocess[(ctx.grid[0] * ctx.grid[1],)](\n o, do, l, do_scaled, delta,\n BLOCK_M=ctx.BLOCK, D_HEAD=ctx.BLOCK_DMODEL,\n )\n num_warps = 8\n _bwd_kernel[(ctx.grid[1],)](\n q, k, v, ctx.sm_scale, o, do_scaled, dq, dk, dv, l, m, delta,\n q.stride(0), q.stride(1), q.stride(2), q.stride(3),\n k.stride(0), k.stride(1), k.stride(2), k.stride(3),\n v.stride(0), v.stride(1), v.stride(2), v.stride(3),\n q.shape[0], q.shape[1], q.shape[2], ctx.grid[0],\n BLOCK_M=ctx.BLOCK, BLOCK_N=ctx.BLOCK, BLOCK_DMODEL=ctx.BLOCK_DMODEL,\n num_warps=num_warps, num_stages=1,\n )\n return dq, dk, dv, None\n\ndef triton_flash_attention(q, k, v, sm_scale):\n \"\"\"\n Arguments:\n q: (batch, nheads, seq, headdim)\n k: (batch, nheads, seq, headdim)\n v: (batch, nheads, seq, headdim)\n sm_scale: float. The scaling of QK^T before applying softmax.\n Return:\n out: (batch, nheads, seq, headdim)\n \"\"\"\n if HAS_TRITON:\n return _TritonFlashAttention.apply(q, k, v, sm_scale)\n else:\n raise RuntimeError(\"Triton kernel requires CUDA 11.4+!\")\n", - "description_1": "Use triton language to implement a fused attention mechanism with a forward kernel, a backward preprocess kernel, and a backward kernel. These kernels handle matrix multiplications and apply softmax operations on attention scores. The kernels operate on tensors Q, K, V for queries, keys, and values, respectively, with a scaling factor applied to the QK^T product before softmax. The system handles the forward and backward propagation of gradients, computing outputs and gradients with respect to the inputs using Triton's JIT compilation.", - "description_2": "Use triton language to build a fused attention mechanism for efficient computation of attention scores and gradients, using JIT-compiled kernels for forward and backward passes.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef triton_softmax(X_ptr, Y_ptr, M, N, BLOCK_SIZE):\n pid = tl.program_id(0) # Get the current block ID\n block_start = pid * BLOCK_SIZE # Calculate the start index of the current block\n offsets = tl.arange(0, BLOCK_SIZE) # Generate thread offsets for the current block\n idx = block_start + offsets # Calculate the index each thread is responsible for\n mask = idx < M # Create a mask to prevent out-of-bounds access\n \n # Load row data\n x_row = tl.load(X_ptr + idx*N, mask=mask) # Assume rows are stored contiguously\n x_max = tl.max(x_row)\n x_shifted = x_row - x_max\n exp_x = tl.exp(x_shifted)\n sum_x = tl.sum(exp_x)\n \n softmax_ret = exp_x / sum_x\n tl.store(Y_ptr + idx * N, softmax_ret, mask=mask)\n\ndef softmax_triton(X):\n M, N = X.shape\n Y = torch.empty_like(X[:,])\n \n grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE']),)\n triton_softmax[grid](X, Y, M, N, BLOCK_SIZE=1024)\n return Y\n", - "description_1": "Use triton language to implement a softmax function. The kernel 'triton_softmax' takes 5 parameters: X_ptr (pointer to input tensor), Y_ptr (pointer to output tensor), M (number of rows), N (number of columns), and BLOCK_SIZE (size of each block). It calculates the softmax of each row in the input tensor. The function 'softmax_triton' is a wrapper that prepares the input and output tensors and launches the kernel with appropriate grid dimensions.", - "description_2": "Use triton language to create a row-wise softmax operation for a 2D tensor, optimizing for GPU execution by dividing the work into blocks and threads.", - "difficulty": 2 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\nimport math\n\n@triton.jit\ndef tanh(x):\n return 2 * tl.sigmoid(2 * x) - 1\n\n@triton.jit\ndef gelu_new(x):\n pi = math.pi\n a = tl.math.sqrt(2.0 / pi)\n b = x + 0.044715 * x * x * x\n return 0.5 * x * (1.0 + tanh(a * b))\n\n@triton.jit\ndef dropout(x, p, seed, offset):\n random = tl.rand(seed, offset)\n return tl.where(random > p, x / (1 - p), 0.0)\n\n@triton.jit\ndef fused_linear_kernel(\n x_ptr, # Pointer to the first element of input data matrix\n w_ptr, # Pointer to the first element of weight matrix\n z_ptr, # Output result address\n M, N, K, # Matrix dimensions\n b_ptr=None,\n r_ptr=None,\n apply_gelu=False, # gelu activation and dropout\n dropout_prob=0.0,\n seed=1337,\n BLOCK_SIZE_M: tl.constexpr = 128, # Block size\n BLOCK_SIZE_N: tl.constexpr = 128, \n BLOCK_SIZE_K: tl.constexpr = 64,\n):\n pid_m = tl.program_id(0)\n pid_n = tl.program_id(1)\n \n offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)[:, None]\n offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)[None, :]\n \n z = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n for k in range(0, K, BLOCK_SIZE_K):\n x_k = tl.arange(0, BLOCK_SIZE_K)[None,:] + k\n x = tl.load(x_ptr + offs_m * K + x_k, mask=(offs_m < M) & (x_k < K), other=0.0)\n x = x.to(tl.float16)\n \n w_k = tl.arange(0, BLOCK_SIZE_K)[:, None] + k\n w = tl.load(w_ptr + w_k * N + offs_n, mask=(w_k < K) & (offs_n < N), other=0.0)\n w = w.to(tl.float16)\n \n z = tl.dot(x, w, acc=z)\n \n if b_ptr is not None:\n b = tl.load(b_ptr + offs_n, mask=(offs_n < N), other=0.0)\n z += b.to(tl.float32)\n \n z_offset = offs_m * N + offs_n\n z_mask = (offs_m < M) & (offs_n < N)\n \n if apply_gelu:\n z = gelu_new(z)\n if dropout_prob > 0.0:\n z = dropout(z, dropout_prob, seed, z_offset)\n\n if r_ptr is not None:\n r = tl.load(r_ptr + z_offset, mask=z_mask)\n z += r.to(tl.float32)\n\n tl.store(z_ptr + z_offset, z, mask=z_mask)\n\n@torch.no_grad()\ndef fused_ffn(\n x,\n weight,\n bias=None,\n residual=None,\n add_gelu=False,\n dropout_prob=0.0,\n):\n out_shape_0 = x.shape[:-1]\n x = x.view((-1, x.shape[-1]))\n M, K = x.shape\n N = weight.shape[1]\n \n z = torch.empty((M, N), device=x.device, dtype=x.dtype)\n \n assert x.shape[1] == weight.shape[0]\n assert x.is_contiguous()\n assert weight.is_contiguous()\n\n if bias is not None:\n assert bias.is_contiguous()\n assert weight.shape[1] == bias.shape[0]\n if residual is not None:\n residual = residual.view(z.shape)\n assert residual.is_contiguous()\n \n BLOCK_SIZE_M = 64\n BLOCK_SIZE_N = 64\n BLOCK_SIZE_K = 32\n \n grid = (triton.cdiv(M, BLOCK_SIZE_M), triton.cdiv(N, BLOCK_SIZE_N), 1)\n fused_linear_kernel[grid](\n x, \n weight, \n z,\n M, N, K,\n apply_gelu=add_gelu,\n dropout_prob=dropout_prob,\n b_ptr=bias,\n r_ptr=residual,\n BLOCK_SIZE_M=BLOCK_SIZE_M,\n BLOCK_SIZE_N=BLOCK_SIZE_N,\n BLOCK_SIZE_K=BLOCK_SIZE_K,\n )\n return z.view((*out_shape_0, N))\n", - "description_1": "Use triton language to implement a fused linear kernel with optional GELU activation and dropout. The kernel takes pointers to input data, weights, and output, along with matrix dimensions and optional bias and residual pointers. It performs matrix multiplication in blocks and applies GELU and dropout if specified.", - "description_2": "Use triton language to implement a fused feedforward network function that prepares input data, allocates output, and launches the fused linear kernel with specified parameters.", - "difficulty": 3 - }, - { - "code": "import triton\nimport torch\nimport triton.language as tl\n\n# Triton kernel to perform layer normalization\n@triton.jit\ndef layernorm_kernel(\n x_ptr, # pointer to input data\n weight_ptr, # pointer to weights\n bias_ptr, # pointer to bias\n z_ptr, # pointer to output data\n H, # size of the embedding layer\n eps=1e-5, # epsilon for numerical stability\n BLOCK_SIZE: tl.constexpr = 16, # size of blocks\n):\n row_idx = tl.program_id(0)\n x_row_ptr = x_ptr + row_idx * H # compute the starting pointer for the current row\n z_row_ptr = z_ptr + row_idx * H\n \n # 1. Compute mean\n _sum = tl.zeros([BLOCK_SIZE], dtype=tl.float32)\n for i in range(0, H, BLOCK_SIZE):\n col_offsets = i + tl.arange(0, BLOCK_SIZE)\n x = tl.load(x_row_ptr + col_offsets, mask=col_offsets < H)\n _sum += x.to(tl.float32)\n \n mean = tl.sum(_sum, axis=0) / H\n \n # 2. Compute variance\n x_var = tl.zeros([BLOCK_SIZE], dtype=tl.float32)\n for i in range(0, H, BLOCK_SIZE):\n col_offsets = i + tl.arange(0, BLOCK_SIZE)\n x = tl.load(x_row_ptr + col_offsets, mask=col_offsets < H).to(tl.float32)\n x = tl.where(col_offsets < H, x - mean, 0.)\n x_var += x * x\n \n x_var = tl.sum(x_var, axis=0) / H\n rtsd = tl.sqrt(x_var + eps)\n \n # 3. Normalize and scale\n for i in range(0, H, BLOCK_SIZE):\n col_offsets = i + tl.arange(0, BLOCK_SIZE)\n mask = col_offsets < H\n x = tl.load(x_row_ptr + col_offsets, mask=mask)\n w = tl.load(weight_ptr + col_offsets, mask=mask)\n b = tl.load(bias_ptr + col_offsets)\n \n x_hat = (x - mean) / rtsd\n z = x_hat * w + b\n tl.store(z_row_ptr + col_offsets, z, mask=mask)\n\n# Function to call the Triton kernel for layer normalization\n@torch.no_grad()\ndef layernorm(\n x, # input tensor\n weight, # weights for scaling\n bias, # bias for shifting\n eps=1e-5 # epsilon for numerical stability\n):\n # Ensure input tensors are contiguous\n assert x.is_contiguous()\n assert weight.is_contiguous()\n assert bias.is_contiguous()\n \n # Reshape input tensor for processing\n assert x.shape[-1] == weight.shape[0] == bias.shape[0]\n out_shape = x.shape\n x = x.view(-1, x.shape[-1]) # reshape to 2D tensor\n BL, H = x.shape\n z = torch.empty(x.shape, device=x.device, dtype=x.dtype)\n \n # Configure kernel parameters\n MAX_FUSED_SIZE = 4096 // x.element_size()\n BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(H))\n num_warps = min(max(BLOCK_SIZE // 256, 1), 8)\n \n # Launch kernel\n layernorm_kernel[BL,](\n x,\n weight,\n bias,\n z,\n H, \n eps,\n BLOCK_SIZE,\n num_warps=num_warps\n ) \n return z.view(out_shape)\n", - "description_1": "Use triton language to define a layer normalization kernel and its corresponding calling function. The kernel is decorated with @triton.jit and performs operations such as computing mean, variance, and applying normalization on a given tensor using pointers to the input data, weights, bias, and output data. The kernel requires 7 parameters: pointers to input data, weights, bias, output data, the size of the embedding layer (H), a numerical stability parameter (eps), and a block size for processing. The calling function 'layernorm' uses PyTorch, ensures data contiguity, reshapes the input tensor, prepares the output tensor, configures kernel parameters, and then launches the kernel. It requires 4 parameters: input tensor, weights, bias, and epsilon for stability.", - "description_2": "Use triton language to create a customizable layer normalization operator, which includes defining a @triton.jit kernel for computing mean and variance, normalizing inputs, and a PyTorch calling function for managing input/output tensor configurations and launching the kernel.", - "difficulty": 3 - }, - { - "code": "import triton\nimport torch\nimport triton.language as tl\n\n@triton.jit\ndef rmsnorm_kernel(\n x_ptr, # Pointer to input tensor x, shape is [M, N]\n w_ptr, # Pointer to weight tensor w (gamma parameter)\n z_ptr, # Pointer to output tensor z\n K, # Number of elements in the last dimension\n eps=1e-5, # Epsilon to avoid division by zero\n BLOCK_SIZE: tl.constexpr = 8, # Block size for processing\n):\n # z = (x / (rms + eps)) * w\n\n row_idx = tl.program_id(0)\n x_row_ptr = x_ptr + row_idx * K # Pointer to the start of the row in x\n w_row_ptr = w_ptr + row_idx * K # Pointer to the start of the row in w\n z_row_ptr = z_ptr + row_idx * K # Pointer to the start of the row in z\n\n # Compute variance\n _var = tl.zeros([BLOCK_SIZE], dtype=tl.float32)\n for col_index in range(0, K, BLOCK_SIZE):\n col_offsets = col_index + tl.arange(0, BLOCK_SIZE)\n x_ptrs = x_row_ptr + col_offsets\n\n x = tl.load(x_ptrs, mask=col_offsets < K, other=0.0).to(tl.float32)\n _var += x * x\n\n var = tl.sum(_var, axis=0) / K\n rms = 1 / tl.sqrt(var + eps)\n\n # Normalize and apply rmsnorm\n for col_index in range(0, K, BLOCK_SIZE):\n col_offsets = col_index + tl.arange(0, BLOCK_SIZE)\n mask = col_offsets < K\n\n x = tl.load(x_row_ptr + col_offsets, mask=mask, other=0.0)\n w = tl.load(w_ptr + col_offsets, mask=mask).to(tl.float32)\n\n z = x * rms * w\n tl.store(z_row_ptr + col_offsets, z, mask=mask)\n\n@torch.no_grad()\ndef rmsnorm(\n x, # Input tensor\n weight, # Weight tensor (gamma parameter)\n eps=1e-5 # Epsilon to avoid division by zero\n):\n # Only for NLP layernorm, normalized_shape parameter is omitted\n assert x.is_contiguous()\n assert weight.is_contiguous()\n assert x.shape[-1] == weight.shape[0]\n\n out_shape = x.shape\n # Flatten x to a 2D tensor, [B, L, K] -> [M, K], K is the hidden dimension.\n x = x.view((-1, x.shape[-1]))\n M, K = x.shape\n x = x.view((M, K))\n z = torch.empty(x.shape, device=x.device, dtype=x.dtype)\n\n # Less than 64KB per feature: enqueue fused kernel\n MAX_FUSED_SIZE = 1024 // x.element_size()\n BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(K))\n\n grid = (triton.cdiv(K, BLOCK_SIZE), 1)\n rmsnorm_kernel[M, ](\n x,\n weight,\n z,\n K,\n eps,\n BLOCK_SIZE=BLOCK_SIZE,\n )\n return z.view(out_shape)\n", - "description_1": "Use triton language to implement a root mean square normalization (RMSNorm) kernel. The kernel takes pointers to input tensor x, weight tensor w, and output tensor z, along with the number of elements K in the last dimension, an epsilon value to avoid division by zero, and a block size for processing. The kernel computes the variance of each row, calculates the root mean square (RMS), and normalizes the input tensor x by dividing it by the RMS and multiplying by the weight tensor w. The rmsnorm function prepares the input and weight tensors, sets up the grid and block size, and calls the kernel to perform the normalization.", - "description_2": "Use triton language to create a kernel for RMS normalization, which normalizes input tensor x using a weight tensor w and an epsilon value to avoid division by zero, and applies it to each row of the input.", - "difficulty": 3 - }, - { - "code": "import triton\nimport torch\nimport triton.language as tl\n\n@triton.jit\ndef _fwd_softmax_kernel(\n input_ptr, # pointer to the input data\n stride_input_row, # stride of input rows\n output_ptr, # pointer to the output data\n stride_output_row, # stride of output rows\n num_cols, # number of columns in input\n BLOCK_SIZE: tl.constexpr # block size for triton kernel\n):\n row_id = tl.program_id(axis=0)\n row_start_ptr = input_ptr + row_id * stride_input_row\n col_offsets = tl.arange(0, BLOCK_SIZE)\n input_pointers = row_start_ptr + col_offsets\n \n row_data_mask = col_offsets < num_cols\n \n x = tl.load(input_pointers, mask=row_data_mask, other=0.0)\n \n safe_row = x - tl.max(x, axis=0)\n numerator = tl.exp(safe_row)\n denominator = tl.sum(numerator, axis=0)\n softmax_out = numerator / denominator\n \n output_row_ptr = output_ptr + row_id * stride_input_row\n output_pointers = output_row_ptr + col_offsets\n tl.store(output_pointers, softmax_out, mask=row_data_mask)\n\n@torch.no_grad()\ndef softmax(x: torch.Tensor) -> torch.Tensor:\n \"\"\"Triton implementation of Softmax, only supports 2D tensor in forward pass.\"\"\"\n rows, cols = x.shape\n assert x.ndim == 2, f\"only accepts 2D tensor now\"\n BLOCK_SIZE = triton.next_power_of_2(cols)\n num_warps = 4\n if BLOCK_SIZE > 2047:\n num_warps = 8\n elif BLOCK_SIZE > 4095:\n num_warps = 16\n \n grid = (rows, 1)\n \n softmax_out = torch.empty_like(x)\n \n _fwd_softmax_kernel[grid](\n x,\n x.stride(0),\n softmax_out,\n softmax_out.stride(0),\n cols,\n BLOCK_SIZE=BLOCK_SIZE,\n num_warps=num_warps,\n )\n \n return softmax_out\n", - "description_1": "Use triton language to implement a forward softmax operation on a 2D tensor. The kernel function, _fwd_softmax_kernel, takes 6 parameters: input pointer, input row stride, output pointer, output row stride, number of columns, and block size. It computes the softmax of each row in parallel using Triton, storing results in the output pointer. The softmax function wraps this kernel for use with PyTorch tensors, setting up the grid size and managing memory.", - "description_2": "Use triton language to perform a row-wise softmax operation on 2D tensor data. Create a Triton kernel to compute softmax per row, manage memory and grid size, and provide a wrapper for PyTorch integration.", - "difficulty": 2 - }, - { - "code": "import torch\nimport numpy as np\nimport triton\nimport triton.language as tl\nimport time\n\n@triton.jit\ndef sum_op(a, b):\n return a + b\n\n@triton.jit\ndef kernel(X, Z, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, AXIS: tl.constexpr):\n range_m = tl.arange(0, BLOCK_M)\n range_n = tl.arange(0, BLOCK_N)\n x = tl.load(X + range_m[:, None] * BLOCK_N + range_n[None, :])\n z = tl.associative_scan(x, 0, sum_op)\n tl.store(Z + range_m[:, None] * BLOCK_N + range_n[None, :], z)\n\ndef to_triton(x: np.ndarray, device=\"cuda\", dst_type=None):\n t = x.dtype.name\n if t in [\"uint8\", \"uint16\", \"uint32\", \"uint64\"]:\n signed_type_name = t.lstrip(\"u\")\n x_signed = x.astype(getattr(np, signed_type_name))\n return torch.tensor(x_signed, device=device).contiguous()\n else:\n return torch.tensor(x, device=device).contiguous()\n\ndef to_numpy(x):\n if isinstance(x, torch.Tensor):\n return x.cpu().numpy()\n else:\n raise ValueError(f\"Not a triton-compatible tensor: {x}\")\n\nif __name__ == \"__main__\":\n device = torch.device(\"cuda:0\")\n triton_times = []\n print(\"Initializing\")\n num_warps = 16\n dim = 1\n seq_len = 2048\n batch = 4\n dtype_str = \"float32\"\n axis = 0\n shape = (batch, seq_len, dim)\n n_timings = 10000\n x = np.random.rand(*shape).astype(dtype=np.float32)\n z = np.empty_like(x)\n x_tri = to_triton(x, device=device)\n z_tri = to_triton(z, device=device)\n kernel[(1,)](x_tri, z_tri, BLOCK_M=shape[0], BLOCK_N=shape[1], AXIS=axis, num_warps=num_warps)\n out_triton = to_numpy(z_tri)\n\n for _ in range(n_timings):\n start = time.monotonic_ns()\n kernel[(1,)](x_tri, z_tri, BLOCK_M=shape[0], BLOCK_N=shape[1], AXIS=axis, num_warps=num_warps)\n stop = time.monotonic_ns()\n triton_times.append((stop - start) / (10**9))\n\n print(\"Times triton \" + str(np.array(triton_times).mean()))\n", - "description_1": "Use triton language to implement a kernel that performs an associative scan (cumulative sum) on a 2D tensor. The kernel takes two input tensors X and Z, and three block constants BLOCK_M, BLOCK_N, and AXIS. It loads data from X, performs the scan using a sum operation, and stores the result in Z. The kernel is executed with a specified number of warps.", - "description_2": "Use triton language to perform a cumulative sum on a 2D tensor using a kernel with specified block sizes and axis.", - "difficulty": 2 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef _selective_scan_update_kernel(\n # Pointers to matrices\n state_ptr, x_ptr, dt_ptr, dt_bias_ptr, A_ptr, B_ptr, C_ptr, D_ptr, z_ptr, out_ptr,\n # Matrix dimensions\n batch, dim, dstate,\n # Strides\n stride_state_batch, stride_state_dim, stride_state_dstate,\n stride_x_batch, stride_x_dim,\n stride_dt_batch, stride_dt_dim,\n stride_dt_bias_dim,\n stride_A_dim, stride_A_dstate,\n stride_B_batch, stride_B_dstate,\n stride_C_batch, stride_C_dstate,\n stride_D_dim,\n stride_z_batch, stride_z_dim,\n stride_out_batch, stride_out_dim,\n # Meta-parameters\n DT_SOFTPLUS: tl.constexpr,\n BLOCK_SIZE_M: tl.constexpr,\n HAS_DT_BIAS: tl.constexpr,\n HAS_D: tl.constexpr,\n HAS_Z: tl.constexpr,\n BLOCK_SIZE_DSTATE: tl.constexpr,\n):\n pid_m = tl.program_id(axis=0)\n pid_b = tl.program_id(axis=1)\n state_ptr += pid_b * stride_state_batch\n x_ptr += pid_b * stride_x_batch\n dt_ptr += pid_b * stride_dt_batch\n B_ptr += pid_b * stride_B_batch\n C_ptr += pid_b * stride_C_batch\n if HAS_Z:\n z_ptr += pid_b * stride_z_batch\n out_ptr += pid_b * stride_out_batch\n\n offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_n = tl.arange(0, BLOCK_SIZE_DSTATE)\n state_ptrs = state_ptr + (offs_m[:, None] * stride_state_dim + offs_n[None, :] * stride_state_dstate)\n x_ptrs = x_ptr + offs_m * stride_x_dim\n dt_ptrs = dt_ptr + offs_m * stride_dt_dim\n if HAS_DT_BIAS:\n dt_bias_ptrs = dt_bias_ptr + offs_m * stride_dt_bias_dim\n A_ptrs = A_ptr + (offs_m[:, None] * stride_A_dim + offs_n[None, :] * stride_A_dstate)\n B_ptrs = B_ptr + offs_n * stride_B_dstate\n C_ptrs = C_ptr + offs_n * stride_C_dstate\n if HAS_D:\n D_ptrs = D_ptr + offs_m * stride_D_dim\n if HAS_Z:\n z_ptrs = z_ptr + offs_m * stride_z_dim\n out_ptrs = out_ptr + offs_m * stride_out_dim\n\n state = tl.load(state_ptrs, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate), other=0.0)\n x = tl.load(x_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)\n dt = tl.load(dt_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)\n if HAS_DT_BIAS:\n dt += tl.load(dt_bias_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)\n if DT_SOFTPLUS:\n dt = tl.log(1.0 + tl.exp(dt))\n A = tl.load(A_ptrs, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate), other=0.0).to(tl.float32)\n dA = tl.exp(A * dt[:, None])\n B = tl.load(B_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32)\n C = tl.load(C_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32)\n if HAS_D:\n D = tl.load(D_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)\n if HAS_Z:\n z = tl.load(z_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)\n\n dB = B[None, :] * dt[:, None]\n state = state * dA + dB * x[:, None]\n tl.store(state_ptrs, state, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate))\n out = tl.sum(state * C[None, :], axis=1)\n if HAS_D:\n out += x * D\n if HAS_Z:\n out *= z * tl.sigmoid(z)\n tl.store(out_ptrs, out, mask=offs_m < dim)\n\n\ndef selective_state_update(state, x, dt, A, B, C, D=None, z=None, dt_bias=None, dt_softplus=False):\n \"\"\"\n Argument:\n state: (batch, dim, dstate)\n x: (batch, dim)\n dt: (batch, dim)\n A: (dim, dstate)\n B: (batch, dstate)\n C: (batch, dstate)\n D: (dim,)\n z: (batch, dim)\n dt_bias: (dim,)\n Return:\n out: (batch, dim)\n \"\"\"\n batch, dim, dstate = state.shape\n assert x.shape == (batch, dim)\n assert dt.shape == x.shape\n assert A.shape == (dim, dstate)\n assert B.shape == (batch, dstate)\n assert C.shape == B.shape\n if D is not None:\n assert D.shape == (dim,)\n if z is not None:\n assert z.shape == x.shape\n if dt_bias is not None:\n assert dt_bias.shape == (dim,)\n out = torch.empty_like(x)\n grid = lambda META: (triton.cdiv(dim, META['BLOCK_SIZE_M']), batch)\n z_strides = ((z.stride(0), z.stride(1)) if z is not None else (0, 0))\n # We don't want autotune since it will overwrite the state\n # We instead tune by hand.\n BLOCK_SIZE_M, num_warps = ((32, 4) if dstate <= 16\n else ((16, 4) if dstate <= 32 else\n ((8, 4) if dstate <= 64 else\n ((4, 4) if dstate <= 128 else\n ((4, 8))))))\n with torch.cuda.device(x.device.index):\n _selective_scan_update_kernel[grid](\n state, x, dt, dt_bias, A, B, C, D, z, out,\n batch, dim, dstate,\n state.stride(0), state.stride(1), state.stride(2),\n x.stride(0), x.stride(1),\n dt.stride(0), dt.stride(1),\n dt_bias.stride(0) if dt_bias is not None else 0,\n A.stride(0), A.stride(1),\n B.stride(0), B.stride(1),\n C.stride(0), C.stride(1),\n D.stride(0) if D is not None else 0,\n z_strides[0], z_strides[1],\n out.stride(0), out.stride(1),\n dt_softplus,\n BLOCK_SIZE_M,\n num_warps=num_warps,\n )\n return out\n", - "description_1": "Use triton language to implement a kernel named _selective_scan_update_kernel and its corresponding call function selective_state_update. The kernel requires 39 parameters which include pointers to various matrices, dimensions, strides, and meta-parameters like DT_SOFTPLUS, BLOCK_SIZE_M, HAS_DT_BIAS, HAS_D, HAS_Z, BLOCK_SIZE_DSTATE. The function selective_state_update takes 9 to 10 parameters including state, x, dt, A, B, C, optional D, optional z, optional dt_bias, and dt_softplus. It ensures correct shapes for these matrices and uses the Triton kernel to perform computation and return an output tensor 'out'.", - "description_2": "Use triton language to create a kernel for updating states with pointers to matrices and calculate the results based on various conditions. The accompanying function sets up parameters and uses the kernel for computation, returning the resultant matrix.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef _update_step(\n kv_state_ptr, v_ptr, k_ptr, q_ptr, out_ptr,\n dim, dstate,\n stride_kv_state_batch, stride_kv_state_dim, stride_kv_state_dstate,\n stride_v_batch, stride_v_dim,\n stride_k_batch, stride_k_dstate,\n stride_q_batch, stride_q_dstate,\n stride_out_batch, stride_out_dim,\n BLOCK_SIZE_M: tl.constexpr,\n BLOCK_SIZE_DSTATE: tl.constexpr,\n):\n pid_m = tl.program_id(axis=0)\n pid_b = tl.program_id(axis=1)\n kv_state_ptr += pid_b * stride_kv_state_batch\n v_ptr += pid_b * stride_v_batch\n k_ptr += pid_b * stride_k_batch\n q_ptr += pid_b * stride_q_batch\n\n out_ptr += pid_b * stride_out_batch\n\n offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_n = tl.arange(0, BLOCK_SIZE_DSTATE)\n kv_state_ptrs = kv_state_ptr + (offs_m[:, None] * stride_kv_state_dim + offs_n[None, :] * stride_kv_state_dstate)\n v_ptrs = v_ptr + offs_m * stride_v_dim\n k_ptrs = k_ptr + offs_n * stride_k_dstate\n q_ptrs = q_ptr + offs_n * stride_q_dstate\n out_ptrs = out_ptr + offs_m * stride_out_dim\n\n kv_state = tl.load(kv_state_ptrs, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate), other=0.0)\n V = tl.load(v_ptrs, mask=offs_m < dim, other=0.0)\n K = tl.load(k_ptrs, mask=offs_n < dstate, other=0.0)\n Q = tl.load(q_ptrs, mask=offs_n < dstate, other=0.0)\n\n kv_state = kv_state + K[None, :] * V[:, None]\n tl.store(kv_state_ptrs, kv_state, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate))\n num = tl.sum(kv_state * Q[None, :], axis=1)\n tl.store(out_ptrs, num, mask=offs_m < dim)\n\n\ndef lin_attn_step(kv_state, v, k, q):\n \"\"\"\n Argument:\n kv state: (batch, dim, dstate)\n v: (batch, dim)\n k: (batch, dstate)\n q: (batch, dstate)\n Return:\n out: (batch, dim)\n \"\"\"\n batch, dim, dstate = kv_state.shape\n assert v.shape == (batch, dim)\n assert k.shape == (batch, dstate)\n assert q.shape == k.shape\n\n out = torch.empty_like(v)\n grid = lambda META: (triton.cdiv(dim, META['BLOCK_SIZE_M']), batch)\n\n BLOCK_SIZE_M, num_warps = (4, 8)\n\n with torch.cuda.device(v.device.index):\n _update_step[grid](\n kv_state, v, k, q, out,\n dim, dstate,\n kv_state.stride(0), kv_state.stride(1), kv_state.stride(2),\n v.stride(0), v.stride(1),\n k.stride(0), k.stride(1),\n q.stride(0), q.stride(1),\n out.stride(0), out.stride(1),\n BLOCK_SIZE_M,\n num_warps=num_warps,\n )\n return out\n", - "description_1": "Use triton language to implement a kernel `_update_step` that performs matrix updates and transformations with parameters for matrix pointers, dimensions, strides, and meta-parameters like `BLOCK_SIZE_M` and `BLOCK_SIZE_DSTATE`. This kernel is called by the function `lin_attn_step` which is a linear attention step function for processing input tensors `kv_state`, `v`, `k`, and `q`, all with specific batch, dimension, and dstate shapes.", - "description_2": "Use triton language to implement a kernel for matrix operations and create a linear attention function to call this kernel with specific inputs.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.autotune(\n configs=[\n triton.Config({}, num_warps=1),\n triton.Config({}, num_warps=2),\n triton.Config({}, num_warps=4),\n triton.Config({}, num_warps=8),\n triton.Config({}, num_warps=16),\n triton.Config({}, num_warps=32),\n ],\n key=[\"N\", \"HAS_RESIDUAL\", \"STORE_RESIDUAL_OUT\", \"IS_RMS_NORM\", \"HAS_BIAS\"],\n)\n@triton.jit\ndef _layer_norm_fwd_1pass_kernel(\n X, # pointer to the input\n Y, # pointer to the output\n W, # pointer to the weights\n B, # pointer to the biases\n RESIDUAL, # pointer to the residual\n RESIDUAL_OUT, # pointer to the residual\n Mean, # pointer to the mean\n Rstd, # pointer to the 1/std\n stride_x_row, # how much to increase the pointer when moving by 1 row\n stride_y_row,\n stride_res_row,\n stride_res_out_row,\n N, # number of columns in X\n eps, # epsilon to avoid division by zero\n IS_RMS_NORM: tl.constexpr,\n BLOCK_N: tl.constexpr,\n HAS_RESIDUAL: tl.constexpr,\n STORE_RESIDUAL_OUT: tl.constexpr,\n HAS_BIAS: tl.constexpr,\n):\n # Map the program id to the row of X and Y it should compute.\n row = tl.program_id(0)\n X += row * stride_x_row\n Y += row * stride_y_row\n if HAS_RESIDUAL:\n RESIDUAL += row * stride_res_row\n if STORE_RESIDUAL_OUT:\n RESIDUAL_OUT += row * stride_res_out_row\n # Compute mean and variance\n cols = tl.arange(0, BLOCK_N)\n x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)\n if HAS_RESIDUAL:\n residual = tl.load(RESIDUAL + cols, mask=cols < N, other=0.0).to(tl.float32)\n x += residual\n if STORE_RESIDUAL_OUT:\n tl.store(RESIDUAL_OUT + cols, x, mask=cols < N)\n if not IS_RMS_NORM:\n mean = tl.sum(x, axis=0) / N\n tl.store(Mean + row, mean)\n xbar = tl.where(cols < N, x - mean, 0.0)\n var = tl.sum(xbar * xbar, axis=0) / N\n else:\n xbar = tl.where(cols < N, x, 0.0)\n var = tl.sum(xbar * xbar, axis=0) / N\n rstd = 1 / tl.sqrt(var + eps)\n tl.store(Rstd + row, rstd)\n # Normalize and apply linear transformation\n mask = cols < N\n w = tl.load(W + cols, mask=mask).to(tl.float32)\n if HAS_BIAS:\n b = tl.load(B + cols, mask=mask).to(tl.float32)\n x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd\n y = x_hat * w + b if HAS_BIAS else x_hat * w\n # Write output\n tl.store(Y + cols, y, mask=mask)\n\n\ndef _layer_norm_fwd(\n x, weight, bias, eps, residual=None, out_dtype=None, residual_dtype=None, is_rms_norm=False\n):\n if residual is not None:\n residual_dtype = residual.dtype\n M, N = x.shape\n assert x.stride(-1) == 1\n if residual is not None:\n assert residual.stride(-1) == 1\n assert residual.shape == (M, N)\n assert weight.shape == (N,)\n assert weight.stride(-1) == 1\n if bias is not None:\n assert bias.stride(-1) == 1\n assert bias.shape == (N,)\n # allocate output\n y = torch.empty_like(x, dtype=x.dtype if out_dtype is None else out_dtype)\n assert y.stride(-1) == 1\n if residual is not None or (residual_dtype is not None and residual_dtype != x.dtype):\n residual_out = torch.empty(M, N, device=x.device, dtype=residual_dtype)\n assert residual_out.stride(-1) == 1\n else:\n residual_out = None\n mean = torch.empty((M,), dtype=torch.float32, device=\"cuda\") if not is_rms_norm else None\n rstd = torch.empty((M,), dtype=torch.float32, device=\"cuda\")\n # Less than 64KB per feature: enqueue fused kernel\n MAX_FUSED_SIZE = 65536 // x.element_size()\n BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))\n if N > BLOCK_N:\n raise RuntimeError(\"This layer norm doesn't support feature dim >= 64KB.\")\n # heuristics for number of warps\n with torch.cuda.device(x.device.index):\n _layer_norm_fwd_1pass_kernel[(M,)](\n x,\n y,\n weight,\n bias,\n residual,\n residual_out,\n mean,\n rstd,\n x.stride(0),\n y.stride(0),\n residual.stride(0) if residual is not None else 0,\n residual_out.stride(0) if residual_out is not None else 0,\n N,\n eps,\n is_rms_norm,\n BLOCK_N,\n residual is not None,\n residual_out is not None,\n bias is not None,\n )\n # residual_out is None if residual is None and residual_dtype == input_dtype\n return y, mean, rstd, residual_out if residual_out is not None else x\n\n@triton.autotune(\n configs=[\n triton.Config({}, num_warps=1),\n triton.Config({}, num_warps=2),\n triton.Config({}, num_warps=4),\n triton.Config({}, num_warps=8),\n triton.Config({}, num_warps=16),\n triton.Config({}, num_warps=32),\n ],\n key=[\"N\", \"HAS_DRESIDUAL\", \"STORE_DRESIDUAL\", \"IS_RMS_NORM\", \"HAS_BIAS\"],\n)\n@triton.heuristics({\"RECOMPUTE_OUTPUT\": lambda args: args[\"Y\"] is not None})\n@triton.jit\ndef _layer_norm_bwd_kernel(\n X, # pointer to the input\n W, # pointer to the weights\n B, # pointer to the biases\n Y, # pointer to the output to be recomputed\n DY, # pointer to the output gradient\n DX, # pointer to the input gradient\n DW, # pointer to the partial sum of weights gradient\n DB, # pointer to the partial sum of biases gradient\n DRESIDUAL,\n DRESIDUAL_IN,\n Mean, # pointer to the mean\n Rstd, # pointer to the 1/std\n stride_x_row, # how much to increase the pointer when moving by 1 row\n stride_y_row,\n stride_dy_row,\n stride_dx_row,\n stride_dres_row,\n stride_dres_in_row,\n M, # number of rows in X\n N, # number of columns in X\n eps, # epsilon to avoid division by zero\n rows_per_program,\n IS_RMS_NORM: tl.constexpr,\n BLOCK_N: tl.constexpr,\n HAS_DRESIDUAL: tl.constexpr,\n STORE_DRESIDUAL: tl.constexpr,\n HAS_BIAS: tl.constexpr,\n RECOMPUTE_OUTPUT: tl.constexpr,\n):\n # Map the program id to the elements of X, DX, and DY it should compute.\n row_block_id = tl.program_id(0)\n row_start = row_block_id * rows_per_program\n cols = tl.arange(0, BLOCK_N)\n mask = cols < N\n X += row_start * stride_x_row\n if HAS_DRESIDUAL:\n DRESIDUAL += row_start * stride_dres_row\n if STORE_DRESIDUAL:\n DRESIDUAL_IN += row_start * stride_dres_in_row\n DY += row_start * stride_dy_row\n DX += row_start * stride_dx_row\n if RECOMPUTE_OUTPUT:\n Y += row_start * stride_y_row\n w = tl.load(W + cols, mask=mask).to(tl.float32)\n if RECOMPUTE_OUTPUT and HAS_BIAS:\n b = tl.load(B + cols, mask=mask, other=0.0).to(tl.float32)\n dw = tl.zeros((BLOCK_N,), dtype=tl.float32)\n if HAS_BIAS:\n db = tl.zeros((BLOCK_N,), dtype=tl.float32)\n row_end = min((row_block_id + 1) * rows_per_program, M)\n for row in range(row_start, row_end):\n # Load data to SRAM\n x = tl.load(X + cols, mask=mask, other=0).to(tl.float32)\n dy = tl.load(DY + cols, mask=mask, other=0).to(tl.float32)\n if not IS_RMS_NORM:\n mean = tl.load(Mean + row)\n rstd = tl.load(Rstd + row)\n # Compute dx\n xhat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd\n xhat = tl.where(mask, xhat, 0.0)\n if RECOMPUTE_OUTPUT:\n y = xhat * w + b if HAS_BIAS else xhat * w\n tl.store(Y + cols, y, mask=mask)\n wdy = w * dy\n dw += dy * xhat\n if HAS_BIAS:\n db += dy\n if not IS_RMS_NORM:\n c1 = tl.sum(xhat * wdy, axis=0) / N\n c2 = tl.sum(wdy, axis=0) / N\n dx = (wdy - (xhat * c1 + c2)) * rstd\n else:\n c1 = tl.sum(xhat * wdy, axis=0) / N\n dx = (wdy - xhat * c1) * rstd\n if HAS_DRESIDUAL:\n dres = tl.load(DRESIDUAL + cols, mask=mask, other=0).to(tl.float32)\n dx += dres\n # Write dx\n if STORE_DRESIDUAL:\n tl.store(DRESIDUAL_IN + cols, dx, mask=mask)\n tl.store(DX + cols, dx, mask=mask)\n\n X += stride_x_row\n if HAS_DRESIDUAL:\n DRESIDUAL += stride_dres_row\n if STORE_DRESIDUAL:\n DRESIDUAL_IN += stride_dres_in_row\n if RECOMPUTE_OUTPUT:\n Y += stride_y_row\n DY += stride_dy_row\n DX += stride_dx_row\n tl.store(DW + row_block_id * N + cols, dw, mask=mask)\n if HAS_BIAS:\n tl.store(DB + row_block_id * N + cols, db, mask=mask)\n\n\ndef _layer_norm_bwd(\n dy,\n x,\n weight,\n bias,\n eps,\n mean,\n rstd,\n dresidual=None,\n has_residual=False,\n is_rms_norm=False,\n x_dtype=None,\n recompute_output=False,\n):\n M, N = x.shape\n assert x.stride(-1) == 1\n assert dy.stride(-1) == 1\n assert dy.shape == (M, N)\n if dresidual is not None:\n assert dresidual.stride(-1) == 1\n assert dresidual.shape == (M, N)\n assert weight.shape == (N,)\n assert weight.stride(-1) == 1\n if bias is not None:\n assert bias.stride(-1) == 1\n assert bias.shape == (N,)\n # allocate output\n dx = (\n torch.empty_like(x)\n if x_dtype is None\n else torch.empty(M, N, dtype=x_dtype, device=x.device)\n )\n dresidual_in = torch.empty_like(x) if has_residual and dx.dtype != x.dtype else None\n y = torch.empty(M, N, dtype=dy.dtype, device=dy.device) if recompute_output else None\n\n # Less than 64KB per feature: enqueue fused kernel\n MAX_FUSED_SIZE = 65536 // x.element_size()\n BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))\n if N > BLOCK_N:\n raise RuntimeError(\"This layer norm doesn't support feature dim >= 64KB.\")\n sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count\n _dw = torch.empty((sm_count, N), dtype=torch.float32, device=weight.device)\n _db = (\n torch.empty((sm_count, N), dtype=torch.float32, device=bias.device)\n if bias is not None\n else None\n )\n rows_per_program = math.ceil(M / sm_count)\n grid = (sm_count,)\n with torch.cuda.device(x.device.index):\n _layer_norm_bwd_kernel[grid](\n x,\n weight,\n bias,\n y,\n dy,\n dx,\n _dw,\n _db,\n dresidual,\n dresidual_in,\n mean,\n rstd,\n x.stride(0),\n 0 if not recompute_output else y.stride(0),\n dy.stride(0),\n dx.stride(0),\n dresidual.stride(0) if dresidual is not None else 0,\n dresidual_in.stride(0) if dresidual_in is not None else 0,\n M,\n N,\n eps,\n rows_per_program,\n is_rms_norm,\n BLOCK_N,\n dresidual is not None,\n dresidual_in is not None,\n bias is not None,\n )\n dw = _dw.sum(0).to(weight.dtype)\n db = _db.sum(0).to(bias.dtype) if bias is not None else None\n # Don't need to compute dresidual_in separately in this case\n if has_residual and dx.dtype == x.dtype:\n dresidual_in = dx\n return (dx, dw, db, dresidual_in) if not recompute_output else (dx, dw, db, dresidual_in, y)\n", - "description_1": "Use triton language to implement layer normalization kernels with support for optional residuals, RMS norm, and bias. Implement forward and backward passes.", - "description_2": "Use triton language to implement forward and backward pass kernels for layer normalization with optional bias, residuals, and RMS norm.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef _selective_scan_update_kernel(\n # Pointers to matrices\n state_ptr, x_ptr, dt_ptr, dt_bias_ptr, A_ptr, B_ptr, C_ptr, D_ptr, z_ptr, out_ptr,\n # Matrix dimensions\n batch, dim, dstate,\n # Strides\n stride_state_batch, stride_state_dim, stride_state_dstate,\n stride_x_batch, stride_x_dim,\n stride_dt_batch, stride_dt_dim,\n stride_dt_bias_dim,\n stride_A_dim, stride_A_dstate,\n stride_B_batch, stride_B_dstate,\n stride_C_batch, stride_C_dstate,\n stride_D_dim,\n stride_z_batch, stride_z_dim,\n stride_out_batch, stride_out_dim,\n # Meta-parameters\n DT_SOFTPLUS: tl.constexpr,\n BLOCK_SIZE_M: tl.constexpr,\n HAS_DT_BIAS: tl.constexpr,\n HAS_D: tl.constexpr,\n HAS_Z: tl.constexpr,\n BLOCK_SIZE_DSTATE: tl.constexpr,\n):\n pid_m = tl.program_id(axis=0)\n pid_b = tl.program_id(axis=1)\n state_ptr += pid_b * stride_state_batch\n x_ptr += pid_b * stride_x_batch\n dt_ptr += pid_b * stride_dt_batch\n B_ptr += pid_b * stride_B_batch\n C_ptr += pid_b * stride_C_batch\n if HAS_Z:\n z_ptr += pid_b * stride_z_batch\n out_ptr += pid_b * stride_out_batch\n\n offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_n = tl.arange(0, BLOCK_SIZE_DSTATE)\n state_ptrs = state_ptr + (offs_m[:, None] * stride_state_dim + offs_n[None, :] * stride_state_dstate)\n x_ptrs = x_ptr + offs_m * stride_x_dim\n dt_ptrs = dt_ptr + offs_m * stride_dt_dim\n if HAS_DT_BIAS:\n dt_bias_ptrs = dt_bias_ptr + offs_m * stride_dt_bias_dim\n A_ptrs = A_ptr + (offs_m[:, None] * stride_A_dim + offs_n[None, :] * stride_A_dstate)\n B_ptrs = B_ptr + offs_n * stride_B_dstate\n C_ptrs = C_ptr + offs_n * stride_C_dstate\n if HAS_D:\n D_ptrs = D_ptr + offs_m * stride_D_dim\n if HAS_Z:\n z_ptrs = z_ptr + offs_m * stride_z_dim\n out_ptrs = out_ptr + offs_m * stride_out_dim\n\n state = tl.load(state_ptrs, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate), other=0.0)\n x = tl.load(x_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)\n dt = tl.load(dt_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)\n if HAS_DT_BIAS:\n dt += tl.load(dt_bias_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)\n if DT_SOFTPLUS:\n dt = tl.log(1.0 + tl.exp(dt))\n A = tl.load(A_ptrs, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate), other=0.0).to(tl.float32)\n dA = tl.exp(A * dt[:, None])\n B = tl.load(B_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32)\n C = tl.load(C_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32)\n if HAS_D:\n D = tl.load(D_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)\n if HAS_Z:\n z = tl.load(z_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)\n\n dB = B[None, :] * dt[:, None]\n state = state * dA + dB * x[:, None]\n tl.store(state_ptrs, state, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate))\n out = tl.sum(state * C[None, :], axis=1)\n if HAS_D:\n out += x * D\n if HAS_Z:\n out *= z * tl.sigmoid(z)\n tl.store(out_ptrs, out, mask=offs_m < dim)\n\n\ndef selective_state_update(state, x, dt, A, B, C, D=None, z=None, dt_bias=None, dt_softplus=False):\n \"\"\"\n Argument:\n state: (batch, dim, dstate)\n x: (batch, dim)\n dt: (batch, dim)\n A: (dim, dstate)\n B: (batch, dstate)\n C: (batch, dstate)\n D: (dim,)\n z: (batch, dim)\n dt_bias: (dim,)\n Return:\n out: (batch, dim)\n \"\"\"\n batch, dim, dstate = state.shape\n assert x.shape == (batch, dim)\n assert dt.shape == x.shape\n assert A.shape == (dim, dstate)\n assert B.shape == (batch, dstate)\n assert C.shape == B.shape\n if D is not None:\n assert D.shape == (dim,)\n if z is not None:\n assert z.shape == x.shape\n if dt_bias is not None:\n assert dt_bias.shape == (dim,)\n out = torch.empty_like(x)\n grid = lambda META: (triton.cdiv(dim, META['BLOCK_SIZE_M']), batch)\n z_strides = ((z.stride(0), z.stride(1)) if z is not None else (0, 0))\n BLOCK_SIZE_M, num_warps = ((32, 4) if dstate <= 16\n else ((16, 4) if dstate <= 32 else\n ((8, 4) if dstate <= 64 else\n ((4, 4) if dstate <= 128 else\n ((4, 8))))))\n with torch.cuda.device(x.device.index):\n _selective_scan_update_kernel[grid](\n state, x, dt, dt_bias, A, B, C, D, z, out,\n batch, dim, dstate,\n state.stride(0), state.stride(1), state.stride(2),\n x.stride(0), x.stride(1),\n dt.stride(0), dt.stride(1),\n dt_bias.stride(0) if dt_bias is not None else 0,\n A.stride(0), A.stride(1),\n B.stride(0), B.stride(1),\n C.stride(0), C.stride(1),\n D.stride(0) if D is not None else 0,\n z_strides[0], z_strides[1],\n out.stride(0), out.stride(1),\n dt_softplus,\n BLOCK_SIZE_M,\n num_warps=num_warps,\n )\n return out\n", - "description_1": "Use triton language to define a kernel function that updates state with matrix and vector operations, taking into account parameters such as dt (possibly modified by dt_bias), and applying conditional operations based on the existence of D and z. It is invoked by the selective_state_update function that calculates the output for given input tensors using specified grid and meta parameters.", - "description_2": "Use triton language to implement a kernel for state update with optional bias and scaling, executed by a Python function managing tensor dimensions and grid setup.", - "difficulty": 3 - }, - { - "code": "import triton\nimport triton.language as tl\n\n# Tanh implementation using Triton\n@triton.jit\ndef tanh(x):\n return 2 * tl.sigmoid(2 * x) - 1\n\n# Cosh implementation using Triton\n@triton.jit\ndef cosh(x):\n exp_x = tl.exp(x)\n return (exp_x + 1.0 / exp_x) * 0.5\n\n# ReLU activation function using Triton\n@triton.jit\ndef relu(x):\n zero = 0.0\n return tl.where(x >= 0, x, zero.to(x.dtype))\n\n# ReLU gradient computation using Triton\n@triton.jit\ndef relu_grad(x):\n zero = 0.0\n one = 1.0\n return tl.where(x >= 0, one.to(x.dtype), zero.to(x.dtype))\n\n# Squared ReLU activation function using Triton\n@triton.jit\ndef squared_relu(x):\n x_ = relu(x)\n return (x_ * x_).to(x.dtype)\n\n# Squared ReLU gradient computation using Triton\n@triton.jit\ndef squared_relu_grad(x):\n return tl.where(x >= 0, 2.0 * x, 0.0)\n\n# Leaky ReLU activation function using Triton\n@triton.jit\ndef leaky_relu(x):\n scale = 0.01 + 0.0\n scale = scale.to(x.dtype)\n return tl.where(x >= 0, x, scale * x)\n\n# Leaky ReLU gradient computation using Triton\n@triton.jit\ndef leaky_relu_grad(x):\n min_grad = 0.01\n max_grad = 1\n min_grad = min_grad.to(x.dtype)\n max_grad = max_grad.to(x.dtype)\n return tl.where(x >= 0, max_grad, min_grad)\n\n# GeLU activation function using Triton\n@triton.jit\ndef gelu(x):\n return x * 0.5 * (1.0 + tl.libdevice.erf(x * _sqrt1_2))\n\n# GeLU gradient computation using Triton\n@triton.jit\ndef gelu_grad(x):\n cdf = 0.5 * (1.0 + tl.libdevice.erf(x * _sqrt1_2))\n pdf = tl.exp(-0.5 * x * x) * _gaussian_pdf_normalization\n return cdf + x * pdf\n\n# GeLU approximation activation function using Triton\n@triton.jit\ndef gelu_approx(x):\n return 0.5 * x * (1.0 + tanh(_sqrt2pi * x * (1.0 + 0.044715 * x * x)))\n\n# GeLU approximation gradient computation using Triton\n@triton.jit\ndef gelu_approx_grad(x):\n tanh_out = tanh(0.79788456 * x * (1 + 0.044715 * x * x))\n return 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out)\n", - "description_1": "Use triton language to implement several activation functions and their gradients: relu, squared_relu, leaky_relu, gelu, and gelu_approx. Each function takes a single argument 'x' which represents input tensor elements, and the kernels apply the respective activation or gradient logic element-wise.", - "description_2": "Use triton language to create element-wise activation functions and gradients: ReLU, Squared ReLU, Leaky ReLU, GELU, and approximate GELU.", - "difficulty": 2 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\nfrom flash_attn.ops.triton.k_activations import (\n gelu,\n gelu_approx,\n gelu_grad,\n gelu_approx_grad,\n squared_relu,\n squared_relu_grad,\n)\n\n@triton.jit\ndef kernel_fwd(\n C, ACT_INPUT, A, B, bias, M, N, K, CACHE_KEY_M, CACHE_KEY_N, CACHE_KEY_K,\n stride_cm, stride_am, stride_ak, stride_bn, stride_bk,\n BLOCK_M: tl.constexpr, GROUP_M: tl.constexpr, BLOCK_N: tl.constexpr,\n BLOCK_K: tl.constexpr, SPLIT_K: tl.constexpr, EVEN_K: tl.constexpr,\n A_ROWMAJOR: tl.constexpr, B_COLMAJOR: tl.constexpr,\n BIAS: tl.constexpr, SAVE_ACT_INPUT: tl.constexpr,\n ACTIVATION: tl.constexpr\n):\n pid = tl.program_id(axis=0)\n\n grid_m = (M + BLOCK_M - 1) // BLOCK_M\n grid_n = (N + BLOCK_N - 1) // BLOCK_N\n width = GROUP_M * grid_n\n group_id = pid // width\n group_size = min(grid_m - group_id * GROUP_M, GROUP_M)\n pid_m = group_id * GROUP_M + (pid % group_size)\n pid_n = (pid % width) // (group_size)\n\n rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)\n rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)\n ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)\n rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)\n rk = tl.arange(0, BLOCK_K)\n\n if A_ROWMAJOR:\n A = A + (ram[:, None] * stride_am + rk[None, :])\n else:\n A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)\n if B_COLMAJOR:\n B = B + (rk[:, None] + rbn[None, :] * stride_bn)\n else:\n B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)\n\n acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)\n\n for k in range(K, 0, -BLOCK_K):\n if EVEN_K:\n a = tl.load(A)\n b = tl.load(B)\n else:\n a = tl.load(A, mask=rk[None, :] < k, other=0.0)\n b = tl.load(B, mask=rk[:, None] < k, other=0.0)\n acc += tl.dot(a, b)\n\n if A_ROWMAJOR:\n A += BLOCK_K\n else:\n A += BLOCK_K * stride_ak\n if B_COLMAJOR:\n B += BLOCK_K\n else:\n B += BLOCK_K * stride_bk\n\n if BIAS:\n bias = tl.load(bias + rn, mask=rn < N, other=0.0).to(tl.float32)\n acc += bias[None, :]\n\n if SAVE_ACT_INPUT:\n act_in_ptrs = ACT_INPUT + ram[:, None] * stride_cm + rbn[None, :]\n tl.store(act_in_ptrs, acc)\n\n if ACTIVATION == \"gelu\":\n acc = gelu(acc)\n elif ACTIVATION == \"gelu_approx\":\n acc = gelu_approx(acc)\n elif ACTIVATION == \"squared_relu\":\n acc = squared_relu(acc)\n\n rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)\n rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)\n\n C = C + rm[:, None] * stride_cm + rn[None, :]\n mask = (rm < M)[:, None] & (rn < N)[None, :]\n tl.store(C, acc)\n\n\ndef triton_linear_act(\n x: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None,\n activation: str = \"id\", save_act_input: bool = False\n) -> torch.Tensor:\n assert activation in [\"id\", \"gelu\", \"gelu_approx\", \"squared_relu\"]\n\n batch_shape, n = x.shape[:-1], x.shape[-1]\n batch_dim = batch_shape.numel()\n x_reshaped = x.reshape(batch_dim, n)\n\n if x_reshaped.stride(0) > 1 and x_reshaped.stride(1) > 1:\n x_reshaped = x_reshaped.contiguous()\n if weight.stride(0) > 1 and weight.stride(1) > 1:\n weight = weight.contiguous()\n bias = bias.contiguous() if bias is not None else None\n\n assert (\n x.dtype == weight.dtype\n ), f\"Input and weight must have the same dtype, got {x.dtype} and {weight.dtype}\"\n if bias is not None:\n assert (\n x.dtype == bias.dtype\n ), f\"Input and bias must have the same dtype, got {x.dtype} and {bias.dtype}\"\n assert (\n x_reshaped.shape[1] == weight.shape[1]\n ), f\"Incompatible dimensions: {x_reshaped.shape} - {weight.shape}\"\n\n assert (\n bias is None or bias.shape[0] == weight.shape[0]\n ), \"Incompatible dimensions in between weight and bias\"\n\n M, K = x_reshaped.shape\n N, K = weight.shape\n\n output = torch.empty((M, N), device=x.device, dtype=x.dtype)\n act_input = torch.empty_like(output) if save_act_input else None\n\n grid = lambda META: (triton.cdiv(M, META[\"BLOCK_M\"]) * triton.cdiv(N, META[\"BLOCK_N\"]),) # noqa\n\n kernel_fwd[grid](\n output,\n act_input,\n x_reshaped,\n weight,\n bias if bias is not None else x,\n M,\n N,\n K,\n M // 32,\n N // 32,\n K // 32,\n stride_cm=output.stride(0),\n stride_am=x_reshaped.stride(0),\n stride_ak=x_reshaped.stride(1),\n stride_bk=weight.stride(1),\n stride_bn=weight.stride(0),\n BIAS=bias is not None,\n SAVE_ACT_INPUT=save_act_input,\n ACTIVATION=activation,\n A_ROWMAJOR=x_reshaped.stride(1) == 1,\n B_COLMAJOR=weight.stride(1) == 1,\n GROUP_M=8,\n )\n\n if not save_act_input:\n return output.reshape(*batch_shape, output.shape[-1])\n else:\n return (\n output.reshape(*batch_shape, output.shape[-1]),\n act_input.reshape(*batch_shape, act_input.shape[-1]),\n )\n\n\n@triton.jit\ndef kernel_bwd(\n C, ACT_INPUT, A, B, M, N, K, CACHE_KEY_M, CACHE_KEY_N, CACHE_KEY_K,\n stride_cm, stride_am, stride_ak, stride_bk, stride_bn,\n BLOCK_M: tl.constexpr, GROUP_M: tl.constexpr, BLOCK_N: tl.constexpr,\n BLOCK_K: tl.constexpr, SPLIT_K: tl.constexpr, EVEN_K: tl.constexpr,\n ACTIVATION: tl.constexpr\n):\n pid = tl.program_id(axis=0)\n\n grid_m = (M + BLOCK_M - 1) // BLOCK_M\n grid_n = (N + BLOCK_N - 1) // BLOCK_N\n width = GROUP_M * grid_n\n group_id = pid // width\n group_size = min(grid_m - group_id * GROUP_M, GROUP_M)\n pid_m = group_id * GROUP_M + (pid % group_size)\n pid_n = (pid % width) // (group_size)\n\n rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)\n rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)\n ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)\n rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)\n rk = tl.arange(0, BLOCK_K)\n\n A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)\n B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)\n\n acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)\n\n for k in range(K, 0, -BLOCK_K):\n if EVEN_K:\n a = tl.load(A)\n b = tl.load(B)\n else:\n a = tl.load(A, mask=rk[None, :] < k, other=0.0)\n b = tl.load(B, mask=rk[:, None] < k, other=0.0)\n acc += tl.dot(a, b)\n\n A += BLOCK_K * stride_ak\n B += BLOCK_K * stride_bk\n\n if ACTIVATION != \"id\":\n act_in_ptrs = ACT_INPUT + ram[:, None] * stride_cm + rbn[None, :]\n act_input = tl.load(act_in_ptrs).to(acc.dtype)\n if ACTIVATION == \"gelu\":\n acc *= gelu_grad(act_input)\n elif ACTIVATION == \"gelu_approx\":\n acc *= gelu_approx_grad(act_input)\n elif ACTIVATION == \"squared_relu\":\n acc *= squared_relu_grad(act_input)\n\n rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)\n rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)\n\n C = C + rm[:, None] * stride_cm + rn[None, :]\n mask = (rm < M)[:, None] & (rn < N)[None, :]\n tl.store(C, acc, mask=mask)\n\n\ndef triton_dgrad_act(\n grad_output: torch.Tensor, weight: torch.Tensor,\n activation: str = \"id\", act_input: Optional[torch.Tensor] = None\n) -> torch.Tensor:\n assert activation in [\"id\", \"gelu\", \"gelu_approx\", \"squared_relu\"]\n\n batch_shape, n = grad_output.shape[:-1], grad_output.shape[-1]\n batch_dim = batch_shape.numel()\n grad_output_reshaped = grad_output.reshape(batch_dim, n)\n\n if grad_output_reshaped.stride(0) > 1 and grad_output_reshaped.stride(1) > 1:\n grad_output_reshaped = grad_output_reshaped.contiguous()\n if weight.stride(0) > 1 and weight.stride(1) > 1:\n weight = weight.contiguous()\n\n assert (\n grad_output.dtype == weight.dtype\n ), f\"grad_output and weight must have the same dtype, got {grad_output.dtype} and {weight.dtype}\"\n assert (\n grad_output_reshaped.shape[1] == weight.shape[0]\n ), f\"Incompatible dimensions: {grad_output_reshaped.shape} - {weight.shape}\"\n if activation != \"id\":\n assert act_input is not None, f\"act_input is required for activation {activation}\"\n\n M, K = grad_output_reshaped.shape\n K, N = weight.shape\n\n grad_input = torch.empty((M, N), device=grad_output.device, dtype=grad_output.dtype)\n\n grid = lambda META: (triton.cdiv(M, META[\"BLOCK_M\"]) * triton.cdiv(N, META[\"BLOCK_N\"]),) # noqa\n\n kernel_bwd[grid](\n grad_input,\n act_input,\n grad_output_reshaped,\n weight,\n M,\n N,\n K,\n M // 32,\n N // 32,\n K // 32,\n stride_cm=grad_input.stride(0),\n stride_am=grad_output_reshaped.stride(0),\n stride_ak=grad_output_reshaped.stride(1),\n stride_bk=weight.stride(0),\n stride_bn=weight.stride(1),\n ACTIVATION=activation,\n GROUP_M=8,\n )\n\n return grad_input.reshape(*batch_shape, grad_input.shape[-1])\n", - "description_1": "Use triton language to implement a matrix multiplication kernel with optional bias addition and activation function. The function 'kernel_fwd' has 24 parameters: pointers to matrices, matrix dimensions, stride values, and meta-parameters for configuration. The function 'triton_linear_act' has 5 parameters: input tensor, weight matrix, optional bias, activation function, and a boolean for saving activation inputs. Another function 'kernel_bwd' is for back-propagation, with 21 parameters for matrix multiplication, activation, and grad calculation. The 'triton_dgrad_act' function wraps this kernel with 4 parameters: gradient output, weight, activation function, and optional activation inputs.", - "description_2": "Use triton language to create forward and backward matrix multiplication kernels with activation options and optimize tensor operations for performance on GPU.", - "difficulty": 4 - }, - { - "code": "import triton.language as tl\nimport triton\nimport torch\n\n@triton.heuristics({'num_warps': lambda *args, **meta: num_warps(args[3] * meta['BLOCK'])})\n@triton.heuristics({'TN': lambda *args, **meta: next_power_of_2(args[3] * meta['BLOCK'])})\n@triton.jit\ndef _forward(\n X, OUT, LUT, sizemax, stride_zx, stride_zout, stride_hout, **meta\n):\n TN = meta['TN']\n BLOCK = meta['BLOCK']\n pidhm = tl.program_id(0)\n pidz = tl.program_id(1)\n # create index ranges\n rxm = pidhm % BLOCK\n rbm = pidhm // BLOCK\n rxn = tl.arange(0, TN) % BLOCK\n rbn = tl.arange(0, TN) // BLOCK\n # extract information from LUT\n header = LUT + rbm * 2\n size = tl.load(header + 0)\n offset = tl.load(header + 1)\n check = rbn < size\n rbmn = tl.where(check, rbn, size - 1)\n # block id and column id\n blockid = tl.load(LUT + offset + rbmn * 4 + 0)\n rowid = tl.load(LUT + offset + rbmn * 4 + 2)\n headid = tl.load(LUT + offset + rbmn * 4 + 3)\n # pointers to X\n px = X + pidz * stride_zx + blockid * BLOCK * BLOCK + rxm * BLOCK + rxn\n x = tl.load(px, mask=check, other=-float('inf'))\n x = x.to(tl.float32)\n # computation\n c = tl.max(x, axis=0)\n out = tl.log(tl.sum(tl.exp(x - c), axis=0)) + c\n # pointers to OUT\n pout = OUT + pidz * stride_zout + headid * stride_hout + rowid * BLOCK + rxm\n tl.store(pout, out)\n\n@triton.heuristics({'num_warps': lambda *args, **meta: num_warps(args[5] * meta['BLOCK'])})\n@triton.heuristics({'TN': lambda *args, **meta: next_power_of_2(args[5]) * meta['BLOCK']})\n@triton.jit\ndef _backward(X, OUT, DX, DOUT, LUT, sizemax, stride_zx, stride_zout, stride_hout,\n stride_zdx, stride_zdout, stride_hdout, **meta):\n pidhm = tl.program_id(0)\n pidz = tl.program_id(1)\n TN = meta['TN']\n BLOCK = meta['BLOCK']\n # create index ranges\n rxm = pidhm % BLOCK\n rbm = pidhm // BLOCK\n rxn = tl.arange(0, TN) % BLOCK\n rbn = tl.arange(0, TN) // BLOCK\n # extract information from look-up table\n header = LUT + rbm * 2\n size = tl.load(header + 0)\n offset = tl.load(header + 1)\n # bounds checking on lut\n check = rbn < size\n rbmn = tl.where(check, rbn, size - 1)\n # initialize pointers to block-sparse input\n blockid = tl.load(LUT + offset + rbmn * 4)\n rowid = tl.load(LUT + offset + rbmn * 4 + 2)\n headid = tl.load(LUT + offset + rbmn * 4 + 3)\n px = X + pidz * stride_zx + blockid * BLOCK * BLOCK + rxm * BLOCK + rxn\n pdx = DX + pidz * stride_zdx + blockid * BLOCK * BLOCK + rxm * BLOCK + rxn\n pout = OUT + pidz * stride_zout + headid * stride_hout + rowid * BLOCK + rxm\n pdout = DOUT + pidz * stride_zdout + headid * stride_hdout + rowid * BLOCK + rxm\n # Load\n x = tl.load(px, mask=check, other=-float('inf'))\n out = tl.load(pout)\n dout = tl.load(pdout)\n x = x.to(tl.float32)\n out = out.to(tl.float32)\n dout = dout.to(tl.float32)\n # Computation\n # [2021-09-14] TD: -(out - x) works but x - out segfaults, I think bc of a bug in broadcasting\n dx = dout * tl.exp(-(out - x))\n tl.store(pdx, dx, mask=check)\n\nclass _logsumexp(torch.autograd.Function):\n @staticmethod\n def forward(ctx, x, spdims, block, lut, maxlut, n_head, n_row, bench, time):\n out = torch.zeros((x.shape[0], n_head, n_row), dtype=x.dtype, device=x.device)\n # run kernel\n M = x.shape[0]\n meta = {'BLOCK': block}\n grid = lambda opt: [spdims[0] * spdims[1] * block, M]\n _forward[grid](x, out, lut, maxlut, x.stride(0), out.stride(0), out.stride(1),\n force_nc_cache=True, **meta)\n\n # save to context\n ctx.save_for_backward(x, out, lut)\n ctx.spdims = spdims\n ctx.block = block\n ctx.maxlut = maxlut\n return out\n\n @staticmethod\n def backward(ctx, dout):\n # retrieve from context\n x, out, lut = ctx.saved_tensors\n dx = torch.zeros_like(x)\n # run kernel\n M = x.shape[0]\n grid = lambda opt: [ctx.spdims[0] * ctx.spdims[1] * ctx.block, M]\n _backward[grid](x, out, dx, dout, lut, ctx.maxlut, x.stride(0), out.stride(0),\n out.stride(1), dx.stride(0), dout.stride(0), dout.stride(1),\n force_nc_cache=True, BLOCK=ctx.block)\n return dx, None, None, None, None, None, None, None, None\n", - "description_1": "Use triton language to implement two kernels: _forward and _backward. _forward takes 7 arguments (X, OUT, LUT, sizemax, stride_zx, stride_zout, stride_hout) and computes a sparse logsumexp operation using a lookup table (LUT) to extract blocks of data from X, compute their max, then perform a logsumexp reduction and store the results in OUT. _backward takes 13 arguments (X, OUT, DX, DOUT, LUT, sizemax, stride_zx, stride_zout, stride_hout, stride_zdx, stride_zdout, stride_hdout) and computes the gradient of the sparse logsumexp operation with respect to the input X, storing the results in DX.", - "description_2": "Use triton language to develop _forward kernel that efficiently performs a block-wise logsumexp using provided metadata and indexing through lookup tables. Additionally, design _backward kernel to compute gradients for the logsumexp operation by utilizing the forward pass outputs and adjusting based on incoming gradients.", - "difficulty": 4 - }, - { - "code": "import triton\nimport triton.language as tl\nimport torch\n\n@triton.jit\ndef _kernel(\n A, B, C, stride_za, stride_ha, stride_ma, stride_ka, stride_zb, stride_hb, stride_kb, stride_nb, stride_zc, stride_hc,\n stride_mc, stride_nc, DS0, DS1, SDD_K, SDD_off_width, lut, locks, nlocks, **meta\n):\n TM = meta['TM']\n TN = meta['TN']\n TK = meta['TK']\n TZ = meta['TZ']\n BLOCK = meta['BLOCK']\n pid0 = tl.program_id(0)\n pid1 = tl.program_id(1)\n pidz = tl.program_id(2)\n if meta['SDD']:\n pid1 = pid1 + SDD_off_width\n blockidm = tl.arange(0, TM) // BLOCK\n blockidn = tl.arange(0, TN) // BLOCK\n offlutm = blockidm * (TN // BLOCK) * 4\n offlutn = blockidn * 4\n header = lut + pid1 * (TM // BLOCK) * (TN // BLOCK) * 4\n z = tl.load(header + 0)\n i = tl.load(header + 1 + offlutm)\n j = tl.load(header + 2 + offlutn)\n AS1 = SDD_K // TZ\n lockid = tl.where(TZ > 1, 1, 0)\n offka = pid0 * AS1\n offkb = pid0 * AS1\n offmc = 0\n offnc = 0\n offpa = 0\n offpb = 0\n maxid = TZ\n offhc = 0\n offha = z\n offhb = z\n ram = i * BLOCK + (tl.arange(0, TM) % BLOCK)\n rbn = j * BLOCK + (tl.arange(0, TN) % BLOCK)\n else:\n header = lut + pid0 * 6\n offset = tl.load(header + 0)\n AS1 = tl.load(header + 1)\n column = tl.load(header + 2)\n depth = tl.load(header + 3)\n lockid = tl.load(header + 4)\n maxid = tl.load(header + 5)\n pinc = lut + offset\n offhc = depth\n if meta['DSD']:\n offnc = pid1 * TN\n offmc = column * TM\n offpc = 0\n offnb = pid1 * TN\n offkb = tl.load(pinc)\n offkb = tl.multiple_of(offkb, 8)\n offpb = 0\n offma = 0\n offka = 0\n offpa = tl.load(pinc + 1)\n offpa = tl.multiple_of(offpa, 8)\n offpa = offpa * BLOCK * BLOCK\n offha = 0\n offhb = depth\n else:\n offmc = pid1 * TM\n offnc = column * TN\n offpc = 0\n offma = pid1 * TM\n offka = tl.load(pinc)\n offka = tl.multiple_of(offka, 8)\n offpa = 0\n offnb = 0\n offkb = 0\n offpb = tl.load(pinc + 1)\n offpb = tl.multiple_of(offpb, 8)\n offpb = offpb * BLOCK * BLOCK\n offha = depth\n offhb = 0\n ram = offma + tl.arange(0, TM)\n rbn = offnb + tl.arange(0, TN)\n rka = offka + tl.arange(0, TK)\n rkb = offkb + tl.arange(0, TK)\n pa = A + pidz * stride_za + offha * stride_ha + offpa + ram[:, None] * stride_ma + rka[None, :] * stride_ka\n pb = B + pidz * stride_zb + offhb * stride_hb + offpb + rbn[None, :] * stride_nb + rkb[:, None] * stride_kb\n if meta['DDS']:\n checkam = ram[:, None] < DS0\n else:\n checkam = AS1 > 0\n if meta['DSD']:\n checkbn = rbn[None, :] < DS0\n else:\n checkbn = AS1 > 0\n a = tl.load(pa, mask=checkam, other=0.)\n b = tl.load(pb, mask=checkbn, other=0.)\n acc = tl.zeros((TM, TN), dtype=tl.float32)\n for k in range(AS1, 0, -TK):\n acc += tl.dot(a, b)\n if meta['SDD']:\n inc_a = TK * stride_ka\n inc_b = TK * stride_kb\n else:\n pinc += 2\n if meta['DSD']:\n inc_b = tl.load(pinc)\n inc_a = tl.load(pinc + 1)\n inc_b = tl.multiple_of(inc_b, 8)\n inc_a = tl.multiple_of(inc_a, 8)\n inc_b = inc_b * stride_kb\n if meta['DDS']:\n inc_a = tl.load(pinc)\n inc_b = tl.load(pinc + 1)\n inc_a = tl.multiple_of(inc_a, 8)\n inc_b = tl.multiple_of(inc_b, 8)\n inc_a = inc_a * stride_ka\n pa += inc_a\n pb += inc_b\n checkak = k > TK\n checkbk = k > TK\n checka = checkam & checkak\n checkb = checkbn & checkbk\n a = tl.load(pa, mask=checka)\n b = tl.load(pb, mask=checkb)\n c = acc.to(C.dtype.element_ty)\n if meta['SDD']:\n checkc = True\n rr_blockidm = tl.arange(0, TM) // BLOCK\n rr_blockidn = tl.arange(0, TN) // BLOCK\n rr_offlutm = rr_blockidm * (TN // BLOCK) * 4\n rr_offlutn = rr_blockidn * 4\n off_bkid = 3 + rr_offlutm[:, None] + rr_offlutn[None, :]\n bkid = tl.load(header + off_bkid)\n offpc = bkid * BLOCK * BLOCK\n rcm = tl.arange(0, TM) % BLOCK\n rcn = tl.arange(0, TN) % BLOCK\n else:\n rcm = offmc + tl.arange(0, TM)\n rcn = offnc + tl.arange(0, TN)\n if meta['DSD']:\n checkc = rcn[None, :] < DS0\n if meta['DDS']:\n checkc = rcm[:, None] < DS0\n pc = C + offpc + offhc * stride_hc + pidz * stride_zc + rcm[:, None] * stride_mc + rcn[None, :] * stride_nc\n if lockid == 0:\n tl.store(pc, c, mask=checkc)\n else:\n plock = locks + tl.program_id(2) * nlocks * tl.num_programs(1) + tl.program_id(1) * nlocks + lockid - 1\n pcount = plock + tl.num_programs(2) * tl.num_programs(1) * nlocks\n while tl.atomic_cas(plock, 0, 1) == 1:\n pass\n count = tl.load(pcount)\n if count == 0:\n tl.store(pc, c, mask=checkc)\n else:\n d = tl.load(pc, mask=checkc)\n tl.store(pc, d + c, mask=checkc)\n tl.atomic_xchg(pcount, (count + 1) % maxid)\n tl.atomic_xchg(plock, 0)\n\nclass _matmul(torch.autograd.Function):\n @staticmethod\n def _sdd_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, luts, num_locks, widths, packs):\n if trans_c:\n a, b = b, a\n trans_a, trans_b = not trans_b, not trans_a\n a_dim = -2 if trans_a else -1\n b_dim = -1 if trans_b else -2\n a_inner, b_inner = a.shape[a_dim], b.shape[b_dim]\n if a_inner != b_inner:\n raise ValueError(f\"Size of tensor A along the {_dim_to_name(a_dim)} dim ({a_inner}) must match size \"\n f\"of tensor B along the {_dim_to_name(b_dim)} dim ({b_inner})\")\n if a_inner % 16 != 0:\n raise ValueError('Reduction size for SDD must be a multiple of 16')\n batch_size = a.size(0)\n a_outer = a.size(3 if trans_a else 2)\n dtype = a.dtype\n device = a.device\n total_width = sum([width * pack * pack for width, pack in zip(widths, packs)])\n c = torch.zeros((batch_size, total_width, block, block), dtype=dtype, device=device)\n for lut, width, pack in zip(luts, widths, packs):\n num_lock = 1\n TK = 16 if block == 16 and (a_inner // 16) % 2 == 1 else 32\n meta = {'TM': block * pack, 'TN': block * pack, 'BLOCK': block, 'TK': TK, 'TZ': 1,\n 'SDD': True, 'DSD': False, 'DDS': False}\n locks = _matmul.get_locks(2 * width * batch_size * num_lock, a.device)\n max_width = 49152\n for off_width in range(0, width, max_width):\n grid = lambda meta: [meta['TZ'], min(max_width, width - off_width), batch_size]\n _kernel[grid](\n a,\n b,\n c,\n a.stride(0),\n a.stride(1),\n a.stride(3 if trans_a else 2),\n a.stride(2 if trans_a else 3),\n b.stride(0),\n b.stride(1),\n b.stride(3 if trans_b else 2),\n b.stride(2 if trans_b else 3),\n c.stride(0),\n c.stride(0),\n c.stride(2),\n c.stride(3),\n a_outer,\n a_outer,\n a_inner,\n off_width,\n lut,\n locks,\n num_lock,\n num_warps=4,\n **meta\n )\n return c\n\n @staticmethod\n def _dds_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, num_locks, width, packs):\n AS0 = a.size(0)\n AS1 = a.size(1)\n AS2 = a.size(3 if trans_a else 2)\n BS2 = block * spdims[1 if trans_b else 2]\n dtype = a.dtype\n meta = {'TN': block, 'TM': 128, 'TK': 16, 'BLOCK': block, 'TZ': 1,\n 'SDD': False, 'DSD': False, 'DDS': True}\n CS0 = AS0\n CS1 = AS1\n CS2 = BS2 if trans_c else AS2\n CS3 = AS2 if trans_c else BS2\n locks = _matmul.get_locks(2 * AS0 * AS2 // 32 * num_locks, a.device)\n c = torch.empty((CS0, CS1, CS2, CS3), dtype=dtype, device=a.device)\n grid = lambda meta: [width, triton.cdiv(AS2, meta['TM']), AS0]\n _kernel[grid](\n a,\n b,\n c,\n a.stride(0),\n a.stride(1),\n a.stride(3 if trans_a else 2),\n a.stride(2 if trans_a else 3),\n b.stride(0),\n b.stride(1),\n b.stride(3 if trans_b else 2),\n b.stride(2 if trans_b else 3),\n c.stride(0),\n c.stride(1),\n c.stride(3 if trans_c else 2),\n c.stride(2 if trans_c else 3),\n AS2,\n BS2,\n 0,\n 0,\n lut,\n locks,\n num_locks,\n num_warps=4,\n **meta\n )\n return c\n\n @staticmethod\n def _dsd_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, num_locks, width, packs):\n AS1 = block * spdims[2 if trans_a else 1]\n BS0 = b.size(0)\n BS1 = b.size(1)\n BS3 = b.size(2 if trans_b else 3)\n dtype = a.dtype\n meta = {'TM': block, 'TN': 128, 'TK': 16, 'BLOCK': block, 'TZ': 1,\n 'SDD': False, 'DSD': True, 'DDS': False}\n CS0 = BS0\n CS1 = BS1\n CS2 = BS3 if trans_c else AS1\n CS3 = AS1 if trans_c else BS3\n locks = _matmul.get_locks(2 * BS0 * BS3 // 32 * num_locks, a.device)\n c = torch.empty((CS0, CS1, CS2, CS3), dtype=dtype, device=a.device)\n grid = lambda meta: [width, triton.cdiv(BS3, meta['TN']), BS0]\n _kernel[grid](\n a,\n b,\n c,\n a.stride(0),\n a.stride(1),\n a.stride(3 if trans_a else 2),\n a.stride(2 if trans_a else 3),\n b.stride(0),\n b.stride(1),\n b.stride(3 if trans_b else 2),\n b.stride(2 if trans_b else 3),\n c.stride(0),\n c.stride(1),\n c.stride(3 if trans_c else 2),\n c.stride(2 if trans_c else 3),\n BS3,\n AS1,\n 0,\n 0,\n lut,\n locks,\n num_locks,\n num_warps=4,\n **meta\n )\n return c\n\n fn = {'sdd': _sdd_matmul.__get__(object), 'dsd': _dsd_matmul.__get__(object), 'dds': _dds_matmul.__get__(object)}\n\n @staticmethod\n def forward(\n ctx, a, b, trans_a, trans_b, trans_c, mode, spdims, block, c_lut, c_num_locks, c_width, c_packs, da_lut, da_num_locks,\n da_width, da_packs, db_lut, db_num_locks, db_width, db_packs\n ):\n c = _matmul.fn[mode](a, b, trans_a, trans_b, trans_c, spdims, block, c_lut, c_num_locks, c_width, c_packs)\n ctx.save_for_backward(a, b)\n ctx.da_num_locks = da_num_locks\n ctx.da_lut = da_lut\n ctx.da_width = da_width\n ctx.da_packs = da_packs\n ctx.db_lut = db_lut\n ctx.db_num_locks = db_num_locks\n ctx.db_width = db_width\n ctx.db_packs = db_packs\n ctx.mode = mode\n ctx.spdims = spdims\n ctx.block = block\n ctx.trans_a = trans_a\n ctx.trans_b = trans_b\n return c\n\n @staticmethod\n def backward(ctx, dc):\n a, b = ctx.saved_tensors\n da, db = None, None\n mode = ctx.mode\n if ctx.needs_input_grad[0]:\n mode_da = mode[1] + mode[0] + mode[2]\n da = _matmul.fn[mode_da](\n dc, b, False, not ctx.trans_b, ctx.trans_a, ctx.spdims, ctx.block, ctx.da_lut, ctx.da_num_locks, ctx.da_width,\n ctx.da_packs\n )\n if ctx.needs_input_grad[1]:\n mode_db = mode[2] + mode[1] + mode[0]\n db = _matmul.fn[mode_db](\n a, dc, not ctx.trans_a, False, ctx.trans_b, ctx.spdims, ctx.block, ctx.db_lut, ctx.db_num_locks, ctx.db_width,\n ctx.db_packs\n )\n return da, db, None, None, None,\\\n None, None, None, None,\\\n None, None, None, None, None, None,\\\n None, None, None, None, None, None,\\\n None, None, None, None, None, None\n\nclass matmul:\n def make_lut(self, dtype, device):\n key = (dtype, device)\n if key in self.lut_cache:\n return self.lut_cache[key]\n layout, block = self.layout, self.block\n step = 16\n if self.mode == 'sdd':\n c_lut, c_num_locks, c_width, c_packs = _matmul.make_sdd_lut(layout, block, device)\n elif self.mode == 'dsd':\n c_lut, c_num_locks, c_width, c_packs = _matmul.make_dxx_lut(layout, block, step, not self.trans_a, device)\n elif self.mode == 'dds':\n c_lut, c_num_locks, c_width, c_packs = _matmul.make_dxx_lut(layout, block, step, self.trans_b, device)\n if self.mode == 'sdd':\n da_lut, da_num_locks, da_width, da_packs = _matmul.make_dxx_lut(layout, block, step, True, device)\n elif self.mode == 'dsd':\n da_lut, da_num_locks, da_width, da_packs = _matmul.make_sdd_lut(layout, block, device)\n elif self.mode == 'dds':\n da_lut, da_num_locks, da_width, da_packs = _matmul.make_dxx_lut(layout, block, step, not self.trans_b, device)\n if self.mode == 'sdd':\n db_lut, db_num_locks, db_width, db_packs = _matmul.make_dxx_lut(layout, block, step, False, device)\n elif self.mode == 'dsd':\n db_lut, db_num_locks, db_width, db_packs = _matmul.make_dxx_lut(layout, block, step, self.trans_a, device)\n elif self.mode == 'dds':\n db_lut, db_num_locks, db_width, db_packs = _matmul.make_sdd_lut(layout, block, device)\n self.lut_cache[key] = (c_lut, c_num_locks, c_width, c_packs,\n da_lut, da_num_locks, da_width, da_packs,\n db_lut, db_num_locks, db_width, db_packs)\n return self.lut_cache[key]\n\n def __init__(self, layout, block, mode, trans_a=False, trans_b=False):\n if mode not in ['sdd', 'dsd', 'dds']:\n raise NotImplementedError('Supported modes are: sdd, dsd, dds')\n self.lut_cache = dict()\n self.block = block\n self.mode = mode\n self.trans_a = trans_a\n self.trans_b = trans_b\n layout_dim = layout.ndim\n assert layout_dim in (2, 3), \"Layout should be a 2 or 3 dimensional tensor of 0s and 1s\"\n if not mode == 'sdd':\n trans_dense, trans_sparse, sparse_inner = (trans_b, trans_a, -1) if mode == 'dsd' else (trans_a, trans_b, -2)\n self.dense_inner_dim = -((sparse_inner % 2) + 1) if not trans_dense else sparse_inner\n sparse_inner = sparse_inner if not trans_sparse else -((sparse_inner % 2) + 1)\n self.dense_inner_size = layout.shape[sparse_inner] * block\n self.sparse_shape = (layout.sum().item(), block, block)\n if layout_dim == 2:\n layout = layout.unsqueeze(0)\n layout = layout.long()\n self.layout = layout\n self.spdims = layout.shape\n\n def __call__(self, a, b):\n c_lut, c_num_locks, c_width, c_packs,\\\n da_lut, da_num_locks, da_width, da_packs,\\\n db_lut, db_num_locks, db_width, db_packs = self.make_lut(a.dtype, a.device)\n original_dims = max(a.ndim, b.ndim)\n a, b = self._validate_inputs(a, b)\n c = _matmul.apply(\n a, b, self.trans_a, self.trans_b, False, self.mode, self.spdims, self.block, c_lut, c_num_locks, c_width,\n c_packs, da_lut, da_num_locks, da_width, da_packs, db_lut, db_num_locks, db_width, db_packs\n )\n dims_to_trim = c.ndim - original_dims\n for _ in range(dims_to_trim):\n c = c.squeeze(0)\n return c\n\n def _validate_inputs(self, a, b):\n if a.device != b.device:\n raise ValueError(f\"Inputs must be on the same device; got {a.device} for tensor A \"\n f\"and {b.device} for tensor B\")\n if not a.is_cuda:\n raise ValueError(\"Only GPU devices are supported for now\")\n if torch.is_autocast_enabled():\n a, b = a.half(), b.half()\n elif a.dtype != b.dtype:\n raise ValueError(f\"Inputs must be the same dtype; got {a.dtype} for A and {b.dtype} for B\")\n mode, trans_a, trans_b = self.mode, self.trans_a, self.trans_b\n if mode != 'sdd':\n dense, dense_name, sparse, sparse_name = (a, 'A', b, 'B') if mode == 'dds' else (b, 'B', a, 'A')\n dense_inner = dense.shape[self.dense_inner_dim]\n if dense_inner != self.dense_inner_size:\n raise ValueError(f\"Expected tensor {dense_name} to have size {self.dense_inner_size} at dim \"\n f\"{self.dense_inner_dim % dense.ndim}, got {dense_inner}.\")\n if sparse.shape[-len(self.sparse_shape):] != self.sparse_shape:\n raise ValueError(f\"Expected tensor with trailing dimensions of shape {self.sparse_shape} for argument \"\n f\"{sparse_name}, got {sparse.shape}\")\n\n def add_extra_dims(x):\n dims_needed = 4 - x.ndim\n if dims_needed > 0:\n singletons = [1] * dims_needed\n x = x.view(*singletons, *x.shape)\n elif dims_needed < 0:\n raise ValueError(\"Tensors with more than 4 dimensions are not currently supported\")\n return x\n\n a = add_extra_dims(a)\n b = add_extra_dims(b)\n return a, b\n\ndef _dim_to_name(x):\n return \"last\" if x == -1 else \"second to last\"\n", - "description_1": "Use triton language to define a matrix multiplication kernel `_kernel` and encapsulate it in the `_matmul` class with modes for sparse-dense-dense (`sdd`), dense-sparse-dense (`dsd`), and dense-dense-sparse (`dds`) matrix multiplications.", - "description_2": "Use triton language to create a kernel for performing block-sparse matrix multiplications, and a wrapper class to handle different sparsity patterns.", - "difficulty": 4 - }, - { - "code": "import triton.language as tl\nimport triton\nimport torch\n\ndef next_power_of_2(n):\n n -= 1\n n |= n >> 1\n n |= n >> 2\n n |= n >> 4\n n |= n >> 8\n n |= n >> 16\n n += 1\n return n\n\n\ndef num_warps(n):\n if n < 512:\n return 4\n if n < 2048:\n return 8\n return 16\n\n\n@triton.heuristics({'num_warps': lambda *args, **meta: num_warps(args[3] * meta['BLOCK'])})\n@triton.heuristics({'TN': lambda *args, **meta: next_power_of_2(args[3] * meta['BLOCK'])})\n@triton.jit\ndef _forward(\n X, OUT, LUT, sizemax, stride_zx, stride_zout, stride_hout, **meta\n):\n \"\"\"\n Forward kernel for block-sparse sum.\n Arguments:\n - X: Input tensor of shape (M, H, N)\n - OUT: Output tensor of shape (M, H, N)\n - LUT: Look-up table containing block-sparse information\n - sizemax: Maximum size for LUT lookup\n - stride_zx: Stride for input tensor X\n - stride_zout: Stride for output tensor OUT\n - stride_hout: Stride for the second dimension of output tensor OUT\n \"\"\"\n TN = meta['TN']\n BLOCK = meta['BLOCK']\n pidhm = tl.program_id(0)\n pidz = tl.program_id(1)\n # create index ranges\n rxm = pidhm % BLOCK\n rbm = pidhm // BLOCK\n rxn = tl.arange(0, TN) % BLOCK\n rbn = tl.arange(0, TN) // BLOCK\n # extract information from LUT\n header = LUT + rbm * 2\n size = tl.load(header + 0)\n offset = tl.load(header + 1)\n check = rbn < size\n rbmn = tl.where(check, rbn, size - 1)\n # block id and column id\n blockid = tl.load(LUT + offset + rbmn * 4 + 0)\n rowid = tl.load(LUT + offset + rbmn * 4 + 2)\n headid = tl.load(LUT + offset + rbmn * 4 + 3)\n # pointers to X\n px = X + pidz * stride_zx + blockid * BLOCK * BLOCK + rxm * BLOCK + rxn\n x = tl.load(px, mask=check, other=0)\n x = x.to(tl.float32)\n # computation\n out = tl.sum(x, axis=0)\n # pointers to OUT\n pout = OUT + pidz * stride_zout + headid * stride_hout + rowid * BLOCK + rxm\n tl.store(pout, out)\n\n\n@triton.heuristics({'num_warps': lambda *args, **meta: num_warps(args[3] * meta['BLOCK'])})\n@triton.heuristics({'TN': lambda *args, **meta: next_power_of_2(args[3]) * meta['BLOCK']})\n@triton.jit\ndef _backward(DX, DOUT, LUT, sizemax, stride_zdx, stride_zdout, stride_hdout, **meta):\n \"\"\"\n Backward kernel for block-sparse sum.\n Arguments:\n - DX: Gradient tensor for input X\n - DOUT: Gradient tensor for output OUT\n - LUT: Look-up table containing block-sparse information\n - sizemax: Maximum size for LUT lookup\n - stride_zdx: Stride for gradient tensor DX\n - stride_zdout: Stride for gradient tensor DOUT\n - stride_hdout: Stride for the second dimension of gradient tensor DOUT\n \"\"\"\n pidhm = tl.program_id(0)\n pidz = tl.program_id(1)\n TN = meta['TN']\n BLOCK = meta['BLOCK']\n # create index ranges\n rxm = pidhm % BLOCK\n rbm = pidhm // BLOCK\n rxn = tl.arange(0, TN) % BLOCK\n rbn = tl.arange(0, TN) // BLOCK\n # extract information from look-up table\n header = LUT + rbm * 2\n size = tl.load(header + 0)\n offset = tl.load(header + 1)\n # bounds checking on lut\n check = rbn < size\n rbmn = tl.where(check, rbn, size - 1)\n # initialize pointers to block-sparse input\n blockid = tl.load(LUT + offset + rbmn * 4)\n rowid = tl.load(LUT + offset + rbmn * 4 + 2)\n headid = tl.load(LUT + offset + rbmn * 4 + 3)\n pdx = DX + pidz * stride_zdx + blockid * BLOCK * BLOCK + rxm * BLOCK + rxn\n pdout = DOUT + pidz * stride_zdout + headid * stride_hdout + rowid * BLOCK + rxm\n # Load\n # [2021-09-14] TD: Triton's broadcasting is very buggy, I have to read from dx (which is all\n # zeros) just so that I can broadcast dout (a scalar).\n dx_zeros = tl.load(pdx, mask=check, other=0)\n dout = tl.load(pdout)\n # Computation\n dx = dout - dx_zeros\n tl.store(pdx, dx, mask=check)\n", - "description_1": "Use triton language to implement a block-sparse sum operation in both forward and backward passes.", - "description_2": "Use triton language to compute the block-sparse sum and its gradient for tensors with a given LUT and block configuration.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\ndef get_depth(K):\n return triton.next_power_of_2(K)\n\n@triton.autotune(\n configs=[\n triton.Config({}, num_warps=1),\n triton.Config({}, num_warps=2),\n triton.Config({}, num_warps=4),\n triton.Config({}, num_warps=8),\n triton.Config({}, num_warps=16),\n triton.Config({}, num_warps=32),\n ],\n key=[\"K\"],\n)\n@triton.heuristics({'DEPTH': lambda nargs: get_depth(nargs['K'])})\n@triton.heuristics({'IS_FP16': lambda nargs: nargs['Y'].dtype == torch.float16})\n@triton.jit\ndef _softmax(\n Y, X, M,\n stride_ym, stride_yn,\n stride_xm, stride_xn,\n stride_m,\n K,\n LOG: tl.constexpr,\n MASK_TYPE: tl.constexpr,\n CAUSAL: tl.constexpr,\n DEPTH: tl.constexpr,\n IS_FP16: tl.constexpr,\n):\n \"\"\"\n Fused softmax kernel over a 3d tensor.\n The softmax is applied over the last dimension, equivalent to torch.softmax(tensor, dim=-1)\n \"\"\"\n m = tl.program_id(0)\n n = tl.program_id(1)\n k = tl.arange(0, DEPTH)\n x_ptrs = X + m * stride_xm + n * stride_xn + k\n io_mask = k < K\n if CAUSAL:\n io_mask = io_mask & (k <= n)\n x = tl.load(x_ptrs, mask=io_mask, other=float(\"-inf\"))\n if CAUSAL:\n off = float(\"-inf\")\n off = off.to(x.dtype)\n x = tl.where(k > n, off, x)\n if MASK_TYPE is not None:\n if MASK_TYPE == 'qk':\n mask_ptrs = M + n * stride_m + k\n elif MASK_TYPE == 'bk':\n mask_ptrs = M + m * stride_m + k\n add_mask = tl.load(mask_ptrs, io_mask, other=float(\"-inf\"))\n x += add_mask\n z = x - tl.max(x, axis=0)\n if IS_FP16:\n z = z.to(tl.float32)\n num = tl.exp(z)\n denom = tl.sum(num, axis=0)\n if LOG:\n y = z - tl.log(denom)\n else:\n y = num / denom\n y_ptrs = Y + m * stride_ym + n * stride_yn + k\n tl.store(y_ptrs, y, mask=k < K)\n\n@triton.autotune(\n configs=[\n triton.Config({}, num_warps=1),\n triton.Config({}, num_warps=2),\n triton.Config({}, num_warps=4),\n triton.Config({}, num_warps=8),\n triton.Config({}, num_warps=16),\n triton.Config({}, num_warps=32),\n ],\n key=[\"K\"],\n)\n@triton.heuristics({'DEPTH': lambda nargs: get_depth(nargs['K'])})\n@triton.heuristics({'IS_FP16': lambda nargs: nargs['GradIn'].dtype == torch.float16})\n@triton.jit\ndef _softmax_backward(\n GradIn, GradOut, Out,\n stride_bm, stride_bn,\n stride_gm, stride_gn,\n stride_om, stride_on,\n K,\n LOG: tl.constexpr,\n CAUSAL: tl.constexpr,\n DEPTH: tl.constexpr,\n IS_FP16: tl.constexpr,\n):\n \"\"\"\n Compute the softmax gradients.\n \"\"\"\n m = tl.program_id(0)\n n = tl.program_id(1)\n k = tl.arange(0, DEPTH)\n grad_out_ptrs = GradOut + m * stride_gm + n * stride_gn + k\n out_ptrs = Out + m * stride_om + n * stride_on + k\n io_mask = k < K\n if CAUSAL:\n io_mask = io_mask & (k <= n)\n g = tl.load(grad_out_ptrs, mask=io_mask, other=float(0))\n o = tl.load(out_ptrs, mask=io_mask, other=float(0))\n if CAUSAL:\n zero = float(0)\n zero = zero.to(g.dtype)\n g = tl.where(k > n, zero, g)\n o = tl.where(k > n, zero, o)\n if LOG:\n s = tl.sum(g, 0)\n if IS_FP16:\n o = o.to(tl.float32)\n grad_in = g - tl.exp(o) * s\n else:\n s = tl.sum(g * o, 0)\n grad_in = o * (g - s)\n grad_in_ptrs = GradIn + m * stride_bm + n * stride_bn + k\n tl.store(grad_in_ptrs, grad_in, mask=k < K)\n", - "description_1": "Use triton language to implement fused softmax and its backward pass. The _softmax kernel computes the fused softmax over a 3D tensor. It has parameters for input/output pointers, strides, and other configurations like LOG, MASK_TYPE, CAUSAL, DEPTH, and IS_FP16. The kernel applies softmax over the last dimension, considering optional masks and handling fp16 values. The _softmax_backward kernel computes the gradients for softmax, taking input/output pointers, strides, and configurations like LOG, CAUSAL, DEPTH, and IS_FP16. It handles both standard and log-softmax cases.", - "description_2": "Use triton language to implement fused softmax operation applied over the last dimension of a 3D tensor with support for masking and handling float16 precision. Additionally, implement the backward pass for computing gradients of the softmax operation.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\ndef get_depth(K):\n return triton.next_power_of_2(K)\n\n@triton.autotune(\n configs=[\n triton.Config({}, num_warps=1),\n triton.Config({}, num_warps=2),\n triton.Config({}, num_warps=4),\n triton.Config({}, num_warps=8),\n triton.Config({}, num_warps=16),\n triton.Config({}, num_warps=32),\n ],\n key=[\"K\"],\n)\n@triton.heuristics({'DEPTH': lambda nargs: get_depth(nargs['K'])})\n@triton.heuristics({'IS_FP16': lambda nargs: nargs['GradIn'].dtype == torch.float16})\n@triton.jit\ndef _softmax_dropout_backward(\n GradIn, GradOut, Out, DropoutMask, dropout_prob,\n stride_bm, stride_bn,\n stride_gm, stride_gn,\n stride_om, stride_on,\n stride_mm, stride_mn,\n K,\n CAUSAL: tl.constexpr,\n DEPTH: tl.constexpr,\n IS_FP16: tl.constexpr,\n):\n \"\"\"\n Compute the softmax gradients.\n ..Note: Not autotuning for now because this would lead to broken accumulated gradients\n \"\"\"\n\n m = tl.program_id(0)\n n = tl.program_id(1)\n\n # col indices\n k = tl.arange(0, DEPTH)\n\n # the memory address of all the elements that we want to load can be computed as follows\n grad_out_ptrs = GradOut + m * stride_gm + n * stride_gn + k\n out_ptrs = Out + m * stride_om + n * stride_on + k\n dropout_mask_ptrs = DropoutMask + m * stride_mm + n * stride_mn + k\n\n # load input data; pad out-of-bounds elements with 0\n io_mask = k < K\n\n # Causal - 1: skip on the loads directly\n if CAUSAL:\n io_mask = io_mask & (k <= n)\n\n g = tl.load(grad_out_ptrs, mask=io_mask, other=float(0))\n o = tl.load(out_ptrs, mask=io_mask, other=float(0))\n\n zero = float(0)\n zero = zero.to(g.dtype)\n # Causal - 2: enforce correctness over a couple of misloaded values\n if CAUSAL:\n g = tl.where(k > n, zero, g)\n o = tl.where(k > n, zero, o)\n\n dropout_mask = tl.load(dropout_mask_ptrs, mask=io_mask, other=float(0))\n g = tl.where(dropout_mask != 0, g / (1 - dropout_prob), zero)\n\n # Step 1: Compute the intermediate sum used for the gradient\n s = tl.sum(g * o, 0)\n\n # Step 2: Compute the gradients\n grad_in = o * (g - s)\n\n # write back to the input gradients\n # technically we could write only the lower triangular matrix in the causal case\n # but this is deemed to error prone\n grad_in_ptrs = GradIn + m * stride_bm + n * stride_bn + k\n tl.store(grad_in_ptrs, grad_in, mask=k < K)\n", - "description_1": "Use triton language to implement a softmax dropout backward kernel. The kernel has 15 parameters: GradIn (gradient input tensor), GradOut (gradient output tensor), Out (output tensor), DropoutMask (dropout mask tensor), dropout_prob (dropout probability), stride_bm, stride_bn, stride_gm, stride_gn, stride_om, stride_on, stride_mm, stride_mn (stride values for memory access), K (size of the last dimension), CAUSAL (boolean for causal masking), DEPTH (depth of computation), and IS_FP16 (boolean for half-precision). The kernel computes the gradient of the softmax function with dropout, considering causal masking if specified.", - "description_2": "Use triton language to create a kernel that computes gradients for softmax with dropout, supporting causal masking and half-precision.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n\n@triton.jit\ndef _selective_scan_update_kernel(\n # Pointers to matrices\n state_ptr, x_ptr, dt_ptr, dt_bias_ptr, A_ptr, B_ptr, C_ptr, D_ptr, z_ptr, out_ptr,\n # Matrix dimensions\n batch, dim, dstate,\n # Strides\n stride_state_batch, stride_state_dim, stride_state_dstate,\n stride_x_batch, stride_x_dim,\n stride_dt_batch, stride_dt_dim,\n stride_dt_bias_dim,\n stride_A_dim, stride_A_dstate,\n stride_B_batch, stride_B_dstate,\n stride_C_batch, stride_C_dstate,\n stride_D_dim,\n stride_z_batch, stride_z_dim,\n stride_out_batch, stride_out_dim,\n # Meta-parameters\n DT_SOFTPLUS: tl.constexpr,\n BLOCK_SIZE_M: tl.constexpr,\n HAS_DT_BIAS: tl.constexpr,\n HAS_D: tl.constexpr,\n HAS_Z: tl.constexpr,\n BLOCK_SIZE_DSTATE: tl.constexpr,\n):\n pid_m = tl.program_id(axis=0)\n pid_b = tl.program_id(axis=1)\n state_ptr += pid_b * stride_state_batch\n x_ptr += pid_b * stride_x_batch\n dt_ptr += pid_b * stride_dt_batch\n B_ptr += pid_b * stride_B_batch\n C_ptr += pid_b * stride_C_batch\n if HAS_Z:\n z_ptr += pid_b * stride_z_batch\n out_ptr += pid_b * stride_out_batch\n\n offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_n = tl.arange(0, BLOCK_SIZE_DSTATE)\n state_ptrs = state_ptr + (offs_m[:, None] * stride_state_dim + offs_n[None, :] * stride_state_dstate)\n x_ptrs = x_ptr + offs_m * stride_x_dim\n dt_ptrs = dt_ptr + offs_m * stride_dt_dim\n if HAS_DT_BIAS:\n dt_bias_ptrs = dt_bias_ptr + offs_m * stride_dt_bias_dim\n A_ptrs = A_ptr + (offs_m[:, None] * stride_A_dim + offs_n[None, :] * stride_A_dstate)\n B_ptrs = B_ptr + offs_n * stride_B_dstate\n C_ptrs = C_ptr + offs_n * stride_C_dstate\n if HAS_D:\n D_ptrs = D_ptr + offs_m * stride_D_dim\n if HAS_Z:\n z_ptrs = z_ptr + offs_m * stride_z_dim\n out_ptrs = out_ptr + offs_m * stride_out_dim\n\n state = tl.load(state_ptrs, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate), other=0.0)\n x = tl.load(x_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)\n dt = tl.load(dt_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)\n if HAS_DT_BIAS:\n dt += tl.load(dt_bias_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)\n if DT_SOFTPLUS:\n dt = tl.log(1.0 + tl.exp(dt))\n A = tl.load(A_ptrs, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate), other=0.0).to(tl.float32)\n dA = tl.exp(A * dt[:, None])\n B = tl.load(B_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32)\n C = tl.load(C_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32)\n if HAS_D:\n D = tl.load(D_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)\n if HAS_Z:\n z = tl.load(z_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)\n\n dB = B[None, :] * dt[:, None]\n state = state * dA + dB * x[:, None]\n tl.store(state_ptrs, state, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate))\n out = tl.sum(state * C[None, :], axis=1)\n if HAS_D:\n out += x * D\n if HAS_Z:\n out *= z * tl.sigmoid(z)\n tl.store(out_ptrs, out, mask=offs_m < dim)\n\n\ndef selective_state_update(state, x, dt, A, B, C, D=None, z=None, dt_bias=None, dt_softplus=False):\n \"\"\"\n Argument:\n state: (batch, dim, dstate)\n x: (batch, dim)\n dt: (batch, dim)\n A: (dim, dstate)\n B: (batch, dstate)\n C: (batch, dstate)\n D: (dim,)\n z: (batch, dim)\n dt_bias: (dim,)\n Return:\n out: (batch, dim)\n \"\"\"\n batch, dim, dstate = state.shape\n assert x.shape == (batch, dim)\n assert dt.shape == x.shape\n assert A.shape == (dim, dstate)\n assert B.shape == (batch, dstate)\n assert C.shape == B.shape\n if D is not None:\n assert D.shape == (dim,)\n if z is not None:\n assert z.shape == x.shape\n if dt_bias is not None:\n assert dt_bias.shape == (dim,)\n out = torch.empty_like(x)\n grid = lambda META: (triton.cdiv(dim, META['BLOCK_SIZE_M']), batch)\n z_strides = ((z.stride(0), z.stride(1)) if z is not None else (0, 0))\n # We don't want autotune since it will overwrite the state\n # We instead tune by hand.\n BLOCK_SIZE_M, num_warps = ((32, 4) if dstate <= 16\n else ((16, 4) if dstate <= 32 else\n ((8, 4) if dstate <= 64 else\n ((4, 4) if dstate <= 128 else\n ((4, 8))))))\n with torch.cuda.device(x.device.index):\n _selective_scan_update_kernel[grid](\n state, x, dt, dt_bias, A, B, C, D, z, out,\n batch, dim, dstate,\n state.stride(0), state.stride(1), state.stride(2),\n x.stride(0), x.stride(1),\n dt.stride(0), dt.stride(1),\n dt_bias.stride(0) if dt_bias is not None else 0,\n A.stride(0), A.stride(1),\n B.stride(0), B.stride(1),\n C.stride(0), C.stride(1),\n D.stride(0) if D is not None else 0,\n z_strides[0], z_strides[1],\n out.stride(0), out.stride(1),\n dt_softplus,\n BLOCK_SIZE_M,\n num_warps=num_warps,\n )\n return out\n", - "description_1": "Use triton language to implement a kernel function '_selective_scan_update_kernel' that updates the state of a given matrix using selective scan based on input matrices. The kernel has 10 pointer parameters for input and output matrices, 3 matrix dimension parameters, 18 stride parameters for navigating through input matrices, and 6 meta-parameters to control optional computations. The function is invoked by 'selective_state_update' which computes matrix multiplications and optionally applies non-linear transformations and bias adjustments.", - "description_2": "Use triton language to create a kernel for selective matrix state update with support for bias and non-linear transformations, to be called by a Python function managing input and output tensors.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef _selective_scan_update_kernel(\n kv_state_ptr, k_state_ptr, x_ptr, B_ptr, C_ptr, out_ptr,\n dim, dstate,\n stride_kv_state_batch, stride_kv_state_dim, stride_kv_state_dstate,\n stride_k_state_batch, stride_k_state_dstate,\n stride_x_batch, stride_x_dim,\n stride_B_batch, stride_B_dstate,\n stride_C_batch, stride_C_dstate,\n stride_out_batch, stride_out_dim,\n BLOCK_SIZE_M: tl.constexpr,\n BLOCK_SIZE_DSTATE: tl.constexpr,\n):\n pid_m = tl.program_id(axis=0)\n pid_b = tl.program_id(axis=1)\n kv_state_ptr += pid_b * stride_kv_state_batch\n k_state_ptr += pid_b * stride_k_state_batch\n x_ptr += pid_b * stride_x_batch\n B_ptr += pid_b * stride_B_batch\n C_ptr += pid_b * stride_C_batch\n\n out_ptr += pid_b * stride_out_batch\n\n offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_n = tl.arange(0, BLOCK_SIZE_DSTATE)\n kv_state_ptrs = kv_state_ptr + (offs_m[:, None] * stride_kv_state_dim + offs_n[None, :] * stride_kv_state_dstate)\n k_state_ptrs = k_state_ptr + offs_n * stride_k_state_dstate\n x_ptrs = x_ptr + offs_m * stride_x_dim\n B_ptrs = B_ptr + offs_n * stride_B_dstate\n C_ptrs = C_ptr + offs_n * stride_C_dstate\n out_ptrs = out_ptr + offs_m * stride_out_dim\n\n kv_state = tl.load(kv_state_ptrs, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate), other=0.0)\n k_state = tl.load(k_state_ptrs, mask=offs_n < dstate, other=0.0)\n\n x = tl.load(x_ptrs, mask=offs_m < dim, other=0.0)\n B = tl.load(B_ptrs, mask=offs_n < dstate, other=0.0)\n C = tl.load(C_ptrs, mask=offs_n < dstate, other=0.0)\n\n kv_state = kv_state + B[None, :] * x[:, None]\n tl.store(kv_state_ptrs, kv_state, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate))\n \n k_state = k_state + B\n tl.store(k_state_ptrs, k_state, mask=offs_n < dstate)\n \n num = tl.sum(kv_state * C[None, :], axis=1)\n tl.store(out_ptrs, num, mask=offs_m < dim)\n\n\ndef selective_state_update(\n kv_state, \n k_state,\n x, \n B, \n C\n):\n batch, dim, dstate = kv_state.shape\n assert x.shape == (batch, dim)\n assert B.shape == (batch, dstate)\n assert C.shape == B.shape\n\n out = torch.empty_like(x)\n grid = lambda META: (triton.cdiv(dim, META['BLOCK_SIZE_M']), batch)\n BLOCK_SIZE_M, num_warps = (4, 8)\n\n with torch.cuda.device(x.device.index):\n _selective_scan_update_kernel[grid](\n kv_state, k_state, x, B, C, out,\n dim, dstate,\n kv_state.stride(0), kv_state.stride(1), kv_state.stride(2),\n k_state.stride(0), k_state.stride(1),\n x.stride(0), x.stride(1),\n B.stride(0), B.stride(1),\n C.stride(0), C.stride(1),\n out.stride(0), out.stride(1),\n BLOCK_SIZE_M,\n num_warps=num_warps,\n )\n return out\n", - "description_1": "Use triton language to implement a kernel function '_selective_scan_update_kernel' with 22 parameters for updating state matrices and a wrapper function 'selective_state_update' with 5 parameters to call the kernel.", - "description_2": "Use triton language to create a kernel for state update and a wrapper to execute it.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef _fwd_recurrence(\n S, p1, p2, \n O,\n NUM_BLOCK, \n D_MODEL_K: tl.constexpr, D_MODEL_V: tl.constexpr,\n BLOCK_MODEL: tl.constexpr\n ):\n offset_bh = tl.program_id(0)\n offset_d = tl.program_id(1)\n offset_s = tl.program_id(2) \n\n S = S + offset_bh * NUM_BLOCK * D_MODEL_K * D_MODEL_V + offset_d * D_MODEL_V * BLOCK_MODEL + tl.arange(0, BLOCK_MODEL)[:, None] * D_MODEL_V + offset_s * BLOCK_MODEL + tl.arange(0, BLOCK_MODEL)[None, :]\n\n O = O + offset_bh * NUM_BLOCK * D_MODEL_K * D_MODEL_V + offset_d * D_MODEL_V * BLOCK_MODEL + tl.arange(0, BLOCK_MODEL)[:, None] * D_MODEL_V + offset_s * BLOCK_MODEL + tl.arange(0, BLOCK_MODEL)[None, :] + D_MODEL_K * D_MODEL_V \n\n p1 = p1 + offset_bh * NUM_BLOCK * D_MODEL_K + tl.arange(0, BLOCK_MODEL) + offset_d * BLOCK_MODEL + D_MODEL_K \n\n p2 = p2 + offset_bh * NUM_BLOCK * D_MODEL_V + tl.arange(0, BLOCK_MODEL) + offset_s * BLOCK_MODEL + D_MODEL_V \n\n acc = tl.zeros([BLOCK_MODEL, BLOCK_MODEL], dtype=tl.float32)\n acc += tl.load(S) \n \n S += D_MODEL_K * D_MODEL_V \n\n tl.store(O, acc.to(O.dtype.element_ty))\n O += D_MODEL_K * D_MODEL_V\n\n for i in range(NUM_BLOCK-2):\n p_k = tl.load(p1)\n p_v = tl.load(p2)\n S_i = tl.load(S) \n acc = acc * p_k[:, None] * p_v[None, :] + S_i\n tl.store(O, acc.to(O.dtype.element_ty))\n p1 += D_MODEL_K\n p2 += D_MODEL_V\n S += D_MODEL_K * D_MODEL_V\n O += D_MODEL_K * D_MODEL_V \n\n@triton.jit\ndef _bwd_recurrence(\n S, p1, p2, \n DS, Dp1, Dp2, \n NUM_BLOCK, NUM_SPLIT_K, NUM_SPLIT_V,\n D_MODEL_K: tl.constexpr, D_MODEL_V: tl.constexpr,\n BLOCK_MODEL: tl.constexpr\n ):\n offset_bh = tl.program_id(0)\n offset_d = tl.program_id(1)\n offset_s = tl.program_id(2) \n\n S = S + offset_bh * NUM_BLOCK * D_MODEL_K * D_MODEL_V + offset_d * D_MODEL_V * BLOCK_MODEL + tl.arange(0, BLOCK_MODEL)[:, None] * D_MODEL_V + offset_s * BLOCK_MODEL + tl.arange(0, BLOCK_MODEL)[None, :] + (NUM_BLOCK - 2) * D_MODEL_K * D_MODEL_V\n\n DS = DS + offset_bh * NUM_BLOCK * D_MODEL_K * D_MODEL_V + offset_d * D_MODEL_V * BLOCK_MODEL + tl.arange(0, BLOCK_MODEL)[:, None] * D_MODEL_V + offset_s * BLOCK_MODEL + tl.arange(0, BLOCK_MODEL)[None, :] + (NUM_BLOCK - 1) * D_MODEL_K * D_MODEL_V\n\n p1 = p1 + offset_bh * NUM_BLOCK * D_MODEL_K + tl.arange(0, BLOCK_MODEL) + offset_d * BLOCK_MODEL + (NUM_BLOCK - 2) * D_MODEL_K \n\n p2 = p2 + offset_bh * NUM_BLOCK * D_MODEL_V + tl.arange(0, BLOCK_MODEL) + offset_s * BLOCK_MODEL + (NUM_BLOCK - 2) * D_MODEL_V \n\n Dp1 = Dp1 + offset_bh * NUM_BLOCK * D_MODEL_K * NUM_SPLIT_V + offset_s * D_MODEL_K + tl.arange(0, BLOCK_MODEL) + offset_d * BLOCK_MODEL + (NUM_BLOCK - 2) * D_MODEL_K * NUM_SPLIT_V\n\n Dp2 = Dp2 + offset_bh * NUM_BLOCK * D_MODEL_V * NUM_SPLIT_K + offset_d * D_MODEL_V + tl.arange(0, BLOCK_MODEL) + offset_s * BLOCK_MODEL + (NUM_BLOCK - 2) * D_MODEL_V * NUM_SPLIT_K\n\n Dacc = tl.zeros([BLOCK_MODEL, BLOCK_MODEL], dtype=tl.float32) \n\n for i in range(NUM_BLOCK - 1):\n p_key = tl.load(p1)\n p_value = tl.load(p2)\n S_i = tl.load(S)\n DS_i = tl.load(DS)\n Dacc += DS_i \n dp_i = Dacc * S_i\n dp_key = tl.sum(dp_i * p_value[None, :], axis=1)\n tl.store(Dp1, dp_key.to(Dp1.dtype.element_ty))\n dp_value = tl.sum(dp_i * p_key[:, None], axis=0) \n tl.store(Dp2, dp_value.to(Dp2.dtype.element_ty))\n\n tl.store(S, Dacc.to(S.dtype.element_ty)) \n\n Dacc *= p_key[:, None]\n Dacc *= p_value[None, :]\n\n S -= D_MODEL_K * D_MODEL_V \n DS -= D_MODEL_K * D_MODEL_V \n p1 -= D_MODEL_K \n p2 -= D_MODEL_V \n Dp1 -= D_MODEL_K * NUM_SPLIT_V\n Dp2 -= D_MODEL_V * NUM_SPLIT_K\n\nclass Chunk_memory_update_full(torch.autograd.Function):\n @staticmethod\n def forward(ctx, decay_key_last, decay_value_last, to_add):\n decay_key_last = decay_key_last.contiguous()\n decay_value_last = decay_value_last.contiguous()\n to_add = to_add.contiguous()\n\n B, H, N, D_k, D_v = to_add.shape \n output = torch.empty_like(to_add) \n BLOCK_MODEL = 32\n \n assert D_k % 32 == 0\n assert D_v % 32 == 0\n assert D_k == decay_key_last.shape[-1]\n assert D_v == decay_value_last.shape[-1]\n\n grid = (B*H, D_k//BLOCK_MODEL, D_v//BLOCK_MODEL)\n ctx.grid = grid \n ctx.BLOCK_MODEL = BLOCK_MODEL\n\n _fwd_recurrence[grid](\n to_add, \n decay_key_last,\n decay_value_last,\n output,\n D_MODEL_K=D_k, D_MODEL_V=D_v,\n NUM_BLOCK=N, \n BLOCK_MODEL=BLOCK_MODEL\n )\n \n output[:, :, 0] = 0\n ctx.save_for_backward(output, decay_key_last, decay_value_last) \n \n return output\n\n @staticmethod\n def backward(ctx, DO):\n DO = DO.contiguous()\n\n output, decay_key_last, decay_value_last = ctx.saved_tensors \n\n B, H, N, D_k, D_v = output.shape \n\n num_block = N\n \n BLOCK_MODEL = 32 \n\n grid = (B*H, D_k//BLOCK_MODEL, D_v//BLOCK_MODEL)\n\n D_p1 = torch.empty(B, H, N, D_v // BLOCK_MODEL, D_k, device=DO.device, dtype=torch.float32)\n D_p2 = torch.empty(B, H, N, D_k // BLOCK_MODEL, D_v, device=DO.device, dtype=torch.float32)\n\n _bwd_recurrence[grid](\n output, decay_key_last, decay_value_last,\n DO, D_p1, D_p2, \n NUM_BLOCK = num_block, NUM_SPLIT_K = D_k // BLOCK_MODEL, NUM_SPLIT_V = D_v // BLOCK_MODEL, \n D_MODEL_K = D_k,\n D_MODEL_V = D_v, \n BLOCK_MODEL = BLOCK_MODEL\n )\n\n output[:, :, -1] = 0\n D_p1[:, :, 0] = 0\n D_p1[:, :, -1] = 0\n D_p2[:, :, 0] = 0\n D_p2[:, :, -1] = 0\n \n return D_p1.sum(-2), D_p2.sum(-2), output\n", - "description_1": "Use triton language to implement two kernels: _fwd_recurrence and _bwd_recurrence. The _fwd_recurrence kernel takes 8 parameters: S, p1, p2, O, NUM_BLOCK, D_MODEL_K, D_MODEL_V, and BLOCK_MODEL. It performs forward recurrence operations on input tensors. The _bwd_recurrence kernel takes 11 parameters: S, p1, p2, DS, Dp1, Dp2, NUM_BLOCK, NUM_SPLIT_K, NUM_SPLIT_V, D_MODEL_K, and D_MODEL_V. It performs backward recurrence operations on input tensors. Both kernels use triton's parallel programming model to handle tensor operations efficiently.", - "description_2": "Use triton language to create a custom autograd function Chunk_memory_update_full with forward and backward methods. The forward method calls _fwd_recurrence kernel with parameters to_add, decay_key_last, decay_value_last, and output. The backward method calls _bwd_recurrence kernel with parameters output, decay_key_last, decay_value_last, DO, D_p1, and D_p2. This function is designed to handle memory updates in a chunked manner for efficient computation.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef _fwd_recurrence(\n S, \n O,\n NUM_BLOCK, \n D_MODEL_K: tl.constexpr, D_MODEL_V: tl.constexpr,\n BLOCK_MODEL: tl.constexpr\n):\n offset_bh = tl.program_id(0)\n offset_d = tl.program_id(1)\n offset_s = tl.program_id(2) \n\n S = S + offset_bh * NUM_BLOCK * D_MODEL_K * D_MODEL_V + offset_d * D_MODEL_V * BLOCK_MODEL + tl.arange(0, BLOCK_MODEL)[:, None] * D_MODEL_V + offset_s * BLOCK_MODEL + tl.arange(0, BLOCK_MODEL)[None, :]\n\n O = O + offset_bh * NUM_BLOCK * D_MODEL_K * D_MODEL_V + offset_d * D_MODEL_V * BLOCK_MODEL + tl.arange(0, BLOCK_MODEL)[:, None] * D_MODEL_V + offset_s * BLOCK_MODEL + tl.arange(0, BLOCK_MODEL)[None, :] + D_MODEL_K * D_MODEL_V \n\n acc = tl.zeros([BLOCK_MODEL, BLOCK_MODEL], dtype=tl.float32)\n acc += tl.load(S) \n \n S += D_MODEL_K * D_MODEL_V \n\n tl.store(O, acc.to(O.dtype.element_ty))\n O += D_MODEL_K * D_MODEL_V\n\n for i in range(NUM_BLOCK-2):\n S_i = tl.load(S) \n acc = acc + S_i\n tl.store(O, acc.to(O.dtype.element_ty))\n S += D_MODEL_K * D_MODEL_V\n O += D_MODEL_K * D_MODEL_V \n\n@triton.jit\ndef _bwd_recurrence(\n S, \n DS, \n NUM_BLOCK, NUM_SPLIT_K, NUM_SPLIT_V,\n D_MODEL_K: tl.constexpr, D_MODEL_V: tl.constexpr,\n BLOCK_MODEL: tl.constexpr\n):\n offset_bh = tl.program_id(0)\n offset_d = tl.program_id(1)\n offset_s = tl.program_id(2) \n\n S = S + offset_bh * NUM_BLOCK * D_MODEL_K * D_MODEL_V + offset_d * D_MODEL_V * BLOCK_MODEL + tl.arange(0, BLOCK_MODEL)[:, None] * D_MODEL_V + offset_s * BLOCK_MODEL + tl.arange(0, BLOCK_MODEL)[None, :] + (NUM_BLOCK - 2) * D_MODEL_K * D_MODEL_V\n\n DS = DS + offset_bh * NUM_BLOCK * D_MODEL_K * D_MODEL_V + offset_d * D_MODEL_V * BLOCK_MODEL + tl.arange(0, BLOCK_MODEL)[:, None] * D_MODEL_V + offset_s * BLOCK_MODEL + tl.arange(0, BLOCK_MODEL)[None, :] + (NUM_BLOCK - 1) * D_MODEL_K * D_MODEL_V\n\n Dacc = tl.zeros([BLOCK_MODEL, BLOCK_MODEL], dtype=tl.float32) \n\n for i in range(NUM_BLOCK - 1):\n DS_i = tl.load(DS)\n Dacc += DS_i \n tl.store(S, Dacc.to(S.dtype.element_ty)) \n\n S -= D_MODEL_K * D_MODEL_V \n DS -= D_MODEL_K * D_MODEL_V \n\nclass Chunk_memory_update_no_decay(torch.autograd.Function):\n @staticmethod\n def forward(ctx, to_add):\n to_add = to_add.contiguous()\n\n B, H, N, D_k, D_v = to_add.shape \n output = torch.empty_like(to_add) \n BLOCK_MODEL = 32\n \n assert D_k % 32 == 0\n assert D_v % 32 == 0\n\n grid = (B*H, D_k//BLOCK_MODEL, D_v//BLOCK_MODEL)\n ctx.grid = grid \n ctx.BLOCK_MODEL = BLOCK_MODEL\n\n _fwd_recurrence[grid](\n to_add, \n output,\n D_MODEL_K=D_k, D_MODEL_V=D_v,\n NUM_BLOCK=N, \n BLOCK_MODEL=BLOCK_MODEL\n )\n\n output[:, :, 0] = 0\n ctx.save_for_backward(output) \n \n return output\n\n @staticmethod\n def backward(ctx, DO):\n DO = DO.contiguous()\n\n output, = ctx.saved_tensors \n\n B, H, N, D_k, D_v = output.shape \n\n num_block = N\n \n BLOCK_MODEL = 32 \n\n grid = (B*H, D_k//BLOCK_MODEL, D_v//BLOCK_MODEL)\n\n _bwd_recurrence[grid](\n output, \n DO, \n NUM_BLOCK = num_block, NUM_SPLIT_K = D_k // BLOCK_MODEL, NUM_SPLIT_V = D_v // BLOCK_MODEL, \n D_MODEL_K = D_k,\n D_MODEL_V = D_v, \n BLOCK_MODEL = BLOCK_MODEL\n )\n\n output[:, :, -1] = 0\n \n return output\n", - "description_1": "Use triton language to implement two kernels: _fwd_recurrence and _bwd_recurrence. The _fwd_recurrence kernel takes 6 parameters: S (input tensor), O (output tensor), NUM_BLOCK (number of blocks), D_MODEL_K (model dimension for K), D_MODEL_V (model dimension for V), and BLOCK_MODEL (block size). It performs a forward recurrence operation on the input tensor S and stores the result in O. The _bwd_recurrence kernel takes 8 parameters: S (input tensor), DS (gradient tensor), NUM_BLOCK (number of blocks), NUM_SPLIT_K (number of splits for K dimension), NUM_SPLIT_V (number of splits for V dimension), D_MODEL_K (model dimension for K), D_MODEL_V (model dimension for V), and BLOCK_MODEL (block size). It performs a backward recurrence operation to compute gradients. The Chunk_memory_update_no_decay class uses these kernels in its forward and backward static methods to perform memory update operations without decay.", - "description_2": "Use triton language to create forward and backward recurrence kernels for memory update operations, handling input and gradient tensors with specified block and model dimensions.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef _fwd_recurrence(\n S, p1, \n O,\n NUM_BLOCK, \n D_MODEL_K: tl.constexpr, D_MODEL_V: tl.constexpr,\n BLOCK_MODEL: tl.constexpr\n):\n offset_bh = tl.program_id(0)\n offset_d = tl.program_id(1)\n offset_s = tl.program_id(2) \n\n S = S + offset_bh * NUM_BLOCK * D_MODEL_K * D_MODEL_V + offset_d * D_MODEL_V * BLOCK_MODEL + tl.arange(0, BLOCK_MODEL)[:, None] * D_MODEL_V + offset_s * BLOCK_MODEL + tl.arange(0, BLOCK_MODEL)[None, :]\n\n O = O + offset_bh * NUM_BLOCK * D_MODEL_K * D_MODEL_V + offset_d * D_MODEL_V * BLOCK_MODEL + tl.arange(0, BLOCK_MODEL)[:, None] * D_MODEL_V + offset_s * BLOCK_MODEL + tl.arange(0, BLOCK_MODEL)[None, :] + D_MODEL_K * D_MODEL_V \n\n p1 = p1 + offset_bh * NUM_BLOCK * D_MODEL_K + tl.arange(0, BLOCK_MODEL) + offset_d * BLOCK_MODEL + D_MODEL_K \n\n acc = tl.zeros([BLOCK_MODEL, BLOCK_MODEL], dtype=tl.float32)\n acc += tl.load(S) \n \n S += D_MODEL_K * D_MODEL_V \n\n tl.store(O, acc.to(O.dtype.element_ty))\n O += D_MODEL_K * D_MODEL_V\n\n for i in range(NUM_BLOCK-2):\n p_k = tl.load(p1)\n S_i = tl.load(S) \n acc = acc * p_k[:, None] + S_i\n tl.store(O, acc.to(O.dtype.element_ty))\n p1 += D_MODEL_K\n S += D_MODEL_K * D_MODEL_V\n O += D_MODEL_K * D_MODEL_V \n\n@triton.jit\ndef _bwd_recurrence(\n S, p1, \n DS, Dp1, \n NUM_BLOCK, NUM_SPLIT_K, NUM_SPLIT_V,\n D_MODEL_K: tl.constexpr, D_MODEL_V: tl.constexpr,\n BLOCK_MODEL: tl.constexpr\n):\n offset_bh = tl.program_id(0)\n offset_d = tl.program_id(1)\n offset_s = tl.program_id(2) \n\n S = S + offset_bh * NUM_BLOCK * D_MODEL_K * D_MODEL_V + offset_d * D_MODEL_V * BLOCK_MODEL + tl.arange(0, BLOCK_MODEL)[:, None] * D_MODEL_V + offset_s * BLOCK_MODEL + tl.arange(0, BLOCK_MODEL)[None, :] + (NUM_BLOCK - 2) * D_MODEL_K * D_MODEL_V\n\n DS = DS + offset_bh * NUM_BLOCK * D_MODEL_K * D_MODEL_V + offset_d * D_MODEL_V * BLOCK_MODEL + tl.arange(0, BLOCK_MODEL)[:, None] * D_MODEL_V + offset_s * BLOCK_MODEL + tl.arange(0, BLOCK_MODEL)[None, :] + (NUM_BLOCK - 1) * D_MODEL_K * D_MODEL_V\n\n p1 = p1 + offset_bh * NUM_BLOCK * D_MODEL_K + tl.arange(0, BLOCK_MODEL) + offset_d * BLOCK_MODEL + (NUM_BLOCK - 2) * D_MODEL_K \n\n Dp1 = Dp1 + offset_bh * NUM_BLOCK * D_MODEL_K * NUM_SPLIT_V + offset_s * D_MODEL_K + tl.arange(0, BLOCK_MODEL) + offset_d * BLOCK_MODEL + (NUM_BLOCK - 2) * D_MODEL_K * NUM_SPLIT_V\n\n Dacc = tl.zeros([BLOCK_MODEL, BLOCK_MODEL], dtype=tl.float32) \n\n for i in range(NUM_BLOCK - 1):\n p_key = tl.load(p1)\n S_i = tl.load(S)\n DS_i = tl.load(DS)\n Dacc += DS_i \n dp_i = Dacc * S_i\n dp_key = tl.sum(dp_i, axis=1)\n tl.store(Dp1, dp_key.to(Dp1.dtype.element_ty))\n\n tl.store(S, Dacc.to(S.dtype.element_ty)) \n\n Dacc *= p_key[:, None]\n\n S -= D_MODEL_K * D_MODEL_V \n DS -= D_MODEL_K * D_MODEL_V \n p1 -= D_MODEL_K \n Dp1 -= D_MODEL_K * NUM_SPLIT_V\n\nclass Chunk_memory_update_only_gk(torch.autograd.Function):\n @staticmethod\n def forward(ctx, decay_key_last, to_add):\n decay_key_last = decay_key_last.contiguous()\n to_add = to_add.contiguous()\n\n B, H, N, D_k, D_v = to_add.shape \n output = torch.empty_like(to_add) \n BLOCK_MODEL = 32\n \n assert D_k % 32 == 0\n assert D_v % 32 == 0\n assert D_k == decay_key_last.shape[-1]\n\n grid = (B*H, D_k//BLOCK_MODEL, D_v//BLOCK_MODEL)\n ctx.grid = grid \n ctx.BLOCK_MODEL = BLOCK_MODEL\n\n _fwd_recurrence[grid](\n to_add, \n decay_key_last,\n output,\n D_MODEL_K=D_k, D_MODEL_V=D_v,\n NUM_BLOCK=N, \n BLOCK_MODEL=BLOCK_MODEL\n )\n \n output[:, :, 0] = 0\n ctx.save_for_backward(output, decay_key_last) \n \n return output\n\n @staticmethod\n def backward(ctx, DO):\n DO = DO.contiguous()\n\n output, decay_key_last = ctx.saved_tensors \n\n B, H, N, D_k, D_v = output.shape \n\n num_block = N\n \n BLOCK_MODEL = 32 \n\n grid = (B*H, D_k//BLOCK_MODEL, D_v//BLOCK_MODEL)\n\n D_p1 = torch.empty(B, H, N, D_v // BLOCK_MODEL, D_k, device=DO.device, dtype=torch.float32)\n\n _bwd_recurrence[grid](\n output, decay_key_last, \n DO, D_p1, \n NUM_BLOCK = num_block, NUM_SPLIT_K = D_k // BLOCK_MODEL, NUM_SPLIT_V = D_v // BLOCK_MODEL, \n D_MODEL_K = D_k,\n D_MODEL_V = D_v, \n BLOCK_MODEL = BLOCK_MODEL\n )\n\n output[:, :, -1] = 0\n D_p1[:, :, 0] = 0\n D_p1[:, :, -1] = 0\n \n return D_p1.sum(-2), output\n", - "description_1": "Use triton language to implement two kernels: _fwd_recurrence and _bwd_recurrence. The _fwd_recurrence kernel takes 7 parameters: S (input tensor), p1 (decay factor), O (output tensor), NUM_BLOCK (number of blocks), D_MODEL_K (key dimension), D_MODEL_V (value dimension), and BLOCK_MODEL (block size). It performs a forward recurrence operation on the input tensor S, updating the output tensor O. The _bwd_recurrence kernel takes 10 parameters: S (input tensor), p1 (decay factor), DS (gradient of output), Dp1 (gradient of decay factor), NUM_BLOCK (number of blocks), NUM_SPLIT_K (number of key splits), NUM_SPLIT_V (number of value splits), D_MODEL_K (key dimension), D_MODEL_V (value dimension), and BLOCK_MODEL (block size). It performs a backward recurrence operation to compute gradients for the input tensor S and decay factor p1.", - "description_2": "Use triton language to create a forward kernel for recurrence operations with 7 parameters and a backward kernel for computing gradients with 10 parameters.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef _fwd_recurrence(\n S, p2, \n O,\n NUM_BLOCK, \n D_MODEL_K: tl.constexpr, D_MODEL_V: tl.constexpr,\n BLOCK_MODEL: tl.constexpr\n):\n offset_bh = tl.program_id(0)\n offset_d = tl.program_id(1)\n offset_s = tl.program_id(2) \n\n S = S + offset_bh * NUM_BLOCK * D_MODEL_K * D_MODEL_V + offset_d * D_MODEL_V * BLOCK_MODEL + tl.arange(0, BLOCK_MODEL)[:, None] * D_MODEL_V + offset_s * BLOCK_MODEL + tl.arange(0, BLOCK_MODEL)[None, :]\n\n O = O + offset_bh * NUM_BLOCK * D_MODEL_K * D_MODEL_V + offset_d * D_MODEL_V * BLOCK_MODEL + tl.arange(0, BLOCK_MODEL)[:, None] * D_MODEL_V + offset_s * BLOCK_MODEL + tl.arange(0, BLOCK_MODEL)[None, :] + D_MODEL_K * D_MODEL_V \n\n p2 = p2 + offset_bh * NUM_BLOCK * D_MODEL_V + tl.arange(0, BLOCK_MODEL) + offset_s * BLOCK_MODEL + D_MODEL_V \n\n acc = tl.zeros([BLOCK_MODEL, BLOCK_MODEL], dtype=tl.float32)\n acc += tl.load(S) \n \n S += D_MODEL_K * D_MODEL_V \n\n tl.store(O, acc.to(O.dtype.element_ty))\n O += D_MODEL_K * D_MODEL_V\n\n for i in range(NUM_BLOCK-2):\n p_v = tl.load(p2)\n S_i = tl.load(S) \n acc = acc * p_v[None, :] + S_i\n tl.store(O, acc.to(O.dtype.element_ty))\n p2 += D_MODEL_V\n S += D_MODEL_K * D_MODEL_V\n O += D_MODEL_K * D_MODEL_V \n\n@triton.jit\ndef _bwd_recurrence(\n S, p2, \n DS, Dp2, \n NUM_BLOCK, NUM_SPLIT_K, NUM_SPLIT_V,\n D_MODEL_K: tl.constexpr, D_MODEL_V: tl.constexpr,\n BLOCK_MODEL: tl.constexpr\n):\n offset_bh = tl.program_id(0)\n offset_d = tl.program_id(1)\n offset_s = tl.program_id(2) \n\n S = S + offset_bh * NUM_BLOCK * D_MODEL_K * D_MODEL_V + offset_d * D_MODEL_V * BLOCK_MODEL + tl.arange(0, BLOCK_MODEL)[:, None] * D_MODEL_V + offset_s * BLOCK_MODEL + tl.arange(0, BLOCK_MODEL)[None, :] + (NUM_BLOCK - 2) * D_MODEL_K * D_MODEL_V\n\n DS = DS + offset_bh * NUM_BLOCK * D_MODEL_K * D_MODEL_V + offset_d * D_MODEL_V * BLOCK_MODEL + tl.arange(0, BLOCK_MODEL)[:, None] * D_MODEL_V + offset_s * BLOCK_MODEL + tl.arange(0, BLOCK_MODEL)[None, :] + (NUM_BLOCK - 1) * D_MODEL_K * D_MODEL_V\n\n p2 = p2 + offset_bh * NUM_BLOCK * D_MODEL_V + tl.arange(0, BLOCK_MODEL) + offset_s * BLOCK_MODEL + (NUM_BLOCK - 2) * D_MODEL_V \n\n Dp2 = Dp2 + offset_bh * NUM_BLOCK * D_MODEL_V * NUM_SPLIT_K + offset_d * D_MODEL_V + tl.arange(0, BLOCK_MODEL) + offset_s * BLOCK_MODEL + (NUM_BLOCK - 2) * D_MODEL_V * NUM_SPLIT_K\n\n Dacc = tl.zeros([BLOCK_MODEL, BLOCK_MODEL], dtype=tl.float32) \n\n for i in range(NUM_BLOCK - 1):\n p_value = tl.load(p2)\n S_i = tl.load(S)\n DS_i = tl.load(DS)\n Dacc += DS_i \n dp_i = Dacc * S_i\n dp_value = tl.sum(dp_i, axis=0) \n tl.store(Dp2, dp_value.to(Dp2.dtype.element_ty))\n\n tl.store(S, Dacc.to(S.dtype.element_ty)) \n\n Dacc *= p_value[None, :]\n\n S -= D_MODEL_K * D_MODEL_V \n DS -= D_MODEL_K * D_MODEL_V \n p2 -= D_MODEL_V \n Dp2 -= D_MODEL_V * NUM_SPLIT_K\n\nclass Chunk_memory_update_only_gv(torch.autograd.Function):\n @staticmethod\n def forward(ctx, decay_value_last, to_add):\n decay_value_last = decay_value_last.contiguous()\n to_add = to_add.contiguous()\n\n B, H, N, D_k, D_v = to_add.shape \n output = torch.empty_like(to_add) \n BLOCK_MODEL = 32\n \n assert D_k % 32 == 0\n assert D_v % 32 == 0\n assert D_v == decay_value_last.shape[-1]\n\n grid = (B*H, D_k//BLOCK_MODEL, D_v//BLOCK_MODEL)\n ctx.grid = grid \n ctx.BLOCK_MODEL = BLOCK_MODEL\n\n _fwd_recurrence[grid](\n to_add, \n decay_value_last,\n output,\n D_MODEL_K=D_k, D_MODEL_V=D_v,\n NUM_BLOCK=N, \n BLOCK_MODEL=BLOCK_MODEL\n )\n \n output[:, :, 0] = 0\n ctx.save_for_backward(output, decay_value_last) \n \n return output\n\n @staticmethod\n def backward(ctx, DO):\n DO = DO.contiguous()\n\n output, decay_value_last = ctx.saved_tensors \n\n B, H, N, D_k, D_v = output.shape \n\n num_block = N\n \n BLOCK_MODEL = 32 \n\n grid = (B*H, D_k//BLOCK_MODEL, D_v//BLOCK_MODEL)\n\n D_p2 = torch.empty(B, H, N, D_k // BLOCK_MODEL, D_v, device=DO.device, dtype=torch.float32)\n\n _bwd_recurrence[grid](\n output, decay_value_last,\n DO, D_p2, \n NUM_BLOCK=num_block, NUM_SPLIT_K=D_k // BLOCK_MODEL, NUM_SPLIT_V=D_v // BLOCK_MODEL, \n D_MODEL_K=D_k,\n D_MODEL_V=D_v, \n BLOCK_MODEL=BLOCK_MODEL\n )\n\n output[:, :, -1] = 0\n D_p2[:, :, 0] = 0\n D_p2[:, :, -1] = 0\n \n return D_p2.sum(-2), output\n", - "description_1": "Use triton language to implement forward and backward recurrence kernels for a memory update operation. The forward kernel (_fwd_recurrence) takes 7 parameters: S (input tensor), p2 (decay values), O (output tensor), NUM_BLOCK (number of blocks), D_MODEL_K (key dimension), D_MODEL_V (value dimension), and BLOCK_MODEL (block size). It computes a recurrence relation over blocks of the input tensor. The backward kernel (_bwd_recurrence) takes 10 parameters: S (input tensor), p2 (decay values), DS (gradient of S), Dp2 (gradient of p2), NUM_BLOCK (number of blocks), NUM_SPLIT_K (key splits), NUM_SPLIT_V (value splits), D_MODEL_K (key dimension), D_MODEL_V (value dimension), and BLOCK_MODEL (block size). It computes gradients for the recurrence relation. The Chunk_memory_update_only_gv class wraps these kernels for use in PyTorch's autograd system, with forward and backward methods handling the data flow and gradient computation.", - "description_2": "Use triton language to create kernels for a block-wise recurrence operation with forward and backward passes, integrated with PyTorch autograd.", - "difficulty": 4 - }, - { - "code": "import triton\nimport triton.language as tl\nimport torch\n\n@triton.jit \ndef stable_log_sigmoid(x):\n # Compute stable log sigmoid\n max_value = tl.where(x < 0, x, 0)\n abs_value = tl.where(x > 0, x, -x)\n return max_value - tl.log(1 + tl.exp(-abs_value))\n\n@triton.jit\ndef _fwd_preprocess_cumsum_gk(\n Q, K, GK, GK_cumsum, \n Q_exp, K_reduce, GK_last_exp, \n NUM_CHUNK, L, normalizer, clamp_min, \n D_MODEL_K: tl.constexpr, \n CHUNK_SIZE: tl.constexpr, \n ):\n # Forward pass for cumulative sum with gating key\n offset_bh = tl.program_id(0)\n offset_c = tl.program_id(1)\n Q_ptr = Q + offset_bh * L * D_MODEL_K + offset_c * CHUNK_SIZE * D_MODEL_K + tl.arange(0, D_MODEL_K)\n Q_exp_ptr = Q_exp + offset_bh * L * D_MODEL_K + offset_c * CHUNK_SIZE * D_MODEL_K + tl.arange(0, D_MODEL_K)\n GK_ptr = GK + offset_bh * L * D_MODEL_K + offset_c * CHUNK_SIZE * D_MODEL_K + tl.arange(0, D_MODEL_K)\n GK_cumsum_ptr = GK_cumsum + offset_bh * L * D_MODEL_K + offset_c * CHUNK_SIZE * D_MODEL_K + tl.arange(0, D_MODEL_K)\n GK_last_exp_ptr = GK_last_exp + offset_bh * NUM_CHUNK * D_MODEL_K + offset_c * D_MODEL_K + tl.arange(0, D_MODEL_K)\n cumsum = tl.zeros([D_MODEL_K], dtype=tl.float32)\n\n for _ in range(CHUNK_SIZE):\n gk = tl.load(GK_ptr).to(tl.float32) \n gk = stable_log_sigmoid(gk) / normalizer\n gk = tl.where(gk >= clamp_min, gk, clamp_min)\n cumsum += gk \n tl.store(GK_cumsum_ptr, cumsum.to(GK_cumsum_ptr.dtype.element_ty))\n cumsum_exp = tl.exp(cumsum)\n q = tl.load(Q_ptr) \n q_exp = q * cumsum_exp\n tl.store(Q_exp_ptr, q_exp)\n Q_ptr += D_MODEL_K\n Q_exp_ptr += D_MODEL_K\n GK_ptr += D_MODEL_K\n GK_cumsum_ptr += D_MODEL_K\n\n tl.store(GK_last_exp_ptr, tl.exp(cumsum).to(GK_last_exp_ptr.dtype.element_ty))\n tl.debug_barrier()\n \n GK_cumsum_ptr = GK_cumsum + offset_bh * L * D_MODEL_K + offset_c * CHUNK_SIZE * D_MODEL_K + tl.arange(0, D_MODEL_K)\n K_ptr = K + offset_bh * L * D_MODEL_K + offset_c * CHUNK_SIZE * D_MODEL_K + tl.arange(0, D_MODEL_K)\n K_reduce_ptr = K_reduce + offset_bh * L * D_MODEL_K + offset_c * CHUNK_SIZE * D_MODEL_K + tl.arange(0, D_MODEL_K)\n\n for _ in range(CHUNK_SIZE):\n gk_cumsum = tl.load(GK_cumsum_ptr)\n k = tl.load(K_ptr)\n k_reduce = k * tl.exp(cumsum - gk_cumsum)\n tl.store(K_reduce_ptr, k_reduce.to(K_reduce_ptr.dtype.element_ty))\n K_ptr += D_MODEL_K\n GK_cumsum_ptr += D_MODEL_K\n K_reduce_ptr += D_MODEL_K\n\n@triton.jit\ndef _bwd_preprocess_cumsum_gk(\n Q, K, GK, GK_cumsum, \n DQ_exp, DK_reduce, DGK_last_exp, DGK_cumsum, \n DQ, DK, DGK, \n NUM_CHUNK, L, normalizer, clamp_min, \n D_MODEL_K: tl.constexpr, \n CHUNK_SIZE: tl.constexpr, \n ):\n # Backward pass for cumulative sum with gating key\n offset_bh = tl.program_id(0)\n offset_c = tl.program_id(1)\n Q_ptr = Q + offset_bh * L * D_MODEL_K + offset_c * CHUNK_SIZE * D_MODEL_K + tl.arange(0, D_MODEL_K)\n K_ptr = K + offset_bh * L * D_MODEL_K + offset_c * CHUNK_SIZE * D_MODEL_K + tl.arange(0, D_MODEL_K)\n GK_ptr = GK + offset_bh * L * D_MODEL_K + offset_c * CHUNK_SIZE * D_MODEL_K + tl.arange(0, D_MODEL_K)\n GK_cumsum_ptr = GK_cumsum + offset_bh * L * D_MODEL_K + offset_c * CHUNK_SIZE * D_MODEL_K + tl.arange(0, D_MODEL_K)\n DQ_ptr = DQ + offset_bh * L * D_MODEL_K + offset_c * CHUNK_SIZE * D_MODEL_K + tl.arange(0, D_MODEL_K)\n DK_ptr = DK + offset_bh * L * D_MODEL_K + offset_c * CHUNK_SIZE * D_MODEL_K + tl.arange(0, D_MODEL_K)\n DQ_exp_ptr = DQ_exp + offset_bh * L * D_MODEL_K + offset_c * CHUNK_SIZE * D_MODEL_K + tl.arange(0, D_MODEL_K)\n DK_reduce_ptr = DK_reduce + offset_bh * L * D_MODEL_K + offset_c * CHUNK_SIZE * D_MODEL_K + tl.arange(0, D_MODEL_K)\n DGK_cumsum_ptr = DGK_cumsum + offset_bh * L * D_MODEL_K + offset_c * CHUNK_SIZE * D_MODEL_K + tl.arange(0, D_MODEL_K)\n DGK_ptr = DGK + offset_bh * L * D_MODEL_K + offset_c * CHUNK_SIZE * D_MODEL_K + tl.arange(0, D_MODEL_K)\n D_GK_last_exp_ptr = DGK_last_exp + offset_bh * NUM_CHUNK * D_MODEL_K + offset_c * D_MODEL_K + tl.arange(0, D_MODEL_K) \n cumsum_gradient = tl.zeros([D_MODEL_K], dtype=tl.float32)\n grad_gk_last = tl.zeros([D_MODEL_K], dtype=tl.float32)\n\n gk_last = tl.load(GK_cumsum_ptr + (CHUNK_SIZE - 1) * D_MODEL_K).to(tl.float32) \n cumsum_gradient += tl.load(D_GK_last_exp_ptr) * tl.exp(gk_last)\n \n GK_ptr += (CHUNK_SIZE - 1) * D_MODEL_K\n GK_cumsum_ptr += (CHUNK_SIZE - 1) * D_MODEL_K\n Q_ptr += (CHUNK_SIZE - 1) * D_MODEL_K\n K_ptr += (CHUNK_SIZE - 1) * D_MODEL_K\n DQ_exp_ptr += (CHUNK_SIZE - 1) * D_MODEL_K\n DK_reduce_ptr += (CHUNK_SIZE - 1) * D_MODEL_K\n DGK_cumsum_ptr += (CHUNK_SIZE - 1) * D_MODEL_K\n DQ_ptr += (CHUNK_SIZE - 1) * D_MODEL_K\n DK_ptr += (CHUNK_SIZE - 1) * D_MODEL_K\n DGK_ptr += (CHUNK_SIZE - 1) * D_MODEL_K\n\n for idx in range(CHUNK_SIZE -1, -1, -1):\n gk_cs = tl.load(GK_cumsum_ptr).to(tl.float32)\n k = tl.load(K_ptr).to(tl.float32)\n grad_k = tl.exp(gk_last - gk_cs) * tl.load(DK_reduce_ptr).to(tl.float32)\n tl.store(DK_ptr, grad_k.to(DK_ptr.dtype.element_ty))\n grad_k *= k \n cumsum_gradient -= grad_k\n grad_gk_last += grad_k\n\n q = tl.load(Q_ptr).to(tl.float32)\n grad_q = tl.exp(gk_cs) * tl.load(DQ_exp_ptr) \n tl.store(DQ_ptr, grad_q.to(DK_ptr.dtype.element_ty))\n cumsum_gradient += grad_q * q.to(tl.float32)\n\n cumsum_gradient += tl.load(DGK_cumsum_ptr).to(tl.float32) \n \n tl.store(DGK_ptr, cumsum_gradient.to(DGK_ptr.dtype.element_ty))\n\n Q_ptr -= D_MODEL_K\n DQ_exp_ptr -= D_MODEL_K\n K_ptr -= D_MODEL_K\n DK_reduce_ptr -= D_MODEL_K\n GK_cumsum_ptr -= D_MODEL_K\n DGK_cumsum_ptr -= D_MODEL_K\n DQ_ptr -= D_MODEL_K\n DK_ptr -= D_MODEL_K\n DGK_ptr -= D_MODEL_K\n\n DGK_ptr = DGK + offset_bh * L * D_MODEL_K + offset_c * CHUNK_SIZE * D_MODEL_K + tl.arange(0, D_MODEL_K) + (CHUNK_SIZE - 1) * D_MODEL_K\n GK_ptr = GK + offset_bh * L * D_MODEL_K + offset_c * CHUNK_SIZE * D_MODEL_K + tl.arange(0, D_MODEL_K) + (CHUNK_SIZE - 1) * D_MODEL_K\n\n grad_gk_last = grad_gk_last + 0.\n for idx in range(CHUNK_SIZE -1, -1, -1): \n dgk = tl.load(DGK_ptr).to(tl.float32)\n dgk += grad_gk_last\n \n gk = tl.load(GK_ptr).to(tl.float32) \n gk_logit = stable_log_sigmoid(gk) / normalizer\n dgk = tl.where(gk_logit >= clamp_min, (dgk / normalizer) * (1 - tl.sigmoid(gk)), 0.)\n\n tl.store(DGK_ptr, dgk.to(DGK_ptr.dtype.element_ty))\n DGK_ptr -= D_MODEL_K\n GK_ptr -= D_MODEL_K\n\nclass PreprocessCumSum_GK(torch.autograd.Function):\n @staticmethod\n def forward(ctx, q, k, gk, normalizer_gk=8, clamp_min=-3):\n # Forward function for PreprocessCumSum_GK\n q = q.contiguous()\n k = k.contiguous()\n gk = gk.contiguous()\n \n B, H, NUM_CHUNK, CHUNK_SIZE, D = q.shape\n D_k = k.shape[-1]\n \n grid = (B * H, NUM_CHUNK)\n ctx.grid = grid \n\n k_reduce = torch.empty_like(k)\n q_exp = torch.empty_like(q)\n gk_cumsum = torch.empty_like(gk)\n gk_last_exp = torch.empty_like(gk[:, :, :, 0], dtype=torch.float32)\n\n _fwd_preprocess_cumsum_gk[grid](\n q, k, gk, gk_cumsum, \n q_exp, k_reduce, gk_last_exp, \n CHUNK_SIZE=CHUNK_SIZE, NUM_CHUNK = NUM_CHUNK, L = CHUNK_SIZE * NUM_CHUNK, normalizer=normalizer_gk, clamp_min=clamp_min,\n D_MODEL_K=D_k, num_warps=8 if D_k >= 512 else 4\n )\n \n ctx.grid = grid \n ctx.save_for_backward(q, k, gk, gk_cumsum)\n ctx.normalizer_gk = normalizer_gk\n ctx.clamp_min = clamp_min\n\n return gk_cumsum, k_reduce, q_exp, gk_last_exp\n\n @staticmethod\n def backward(ctx, dgk_cumsum, dk_reduce, dq_exp, dgk_last_exp):\n # Backward function for PreprocessCumSum_GK\n dgk_cumsum = dgk_cumsum.contiguous()\n dk_reduce = dk_reduce.contiguous()\n dq_exp = dq_exp.contiguous()\n dgk_last_exp = dgk_last_exp.contiguous()\n\n q, k, gk, gk_cumsum = ctx.saved_tensors\n grid = ctx.grid\n\n dq = torch.empty_like(q)\n dk = torch.empty_like(k)\n dgk = torch.empty_like(gk)\n\n B, H, NUM_CHUNK, CHUNK_SIZE, D_k = q.shape\n\n _bwd_preprocess_cumsum_gk[grid](\n q, k, gk, gk_cumsum, \n dq_exp, dk_reduce, dgk_last_exp, dgk_cumsum,\n dq, dk, dgk,\n CHUNK_SIZE=CHUNK_SIZE, NUM_CHUNK = NUM_CHUNK, L = CHUNK_SIZE * NUM_CHUNK, normalizer=ctx.normalizer_gk, clamp_min = ctx.clamp_min,\n D_MODEL_K=D_k, num_warps=8 if D_k >= 512 else 4\n )\n\n return dq, dk, dgk, None, None, None\n", - "description_1": "Use triton language to implement a stable log sigmoid function and a forward and backward pass for cumulative sum with gating key. The stable_log_sigmoid kernel takes 1 argument: x, which is a tensor. The _fwd_preprocess_cumsum_gk kernel takes 13 arguments: Q, K, GK, GK_cumsum, Q_exp, K_reduce, GK_last_exp, NUM_CHUNK, L, normalizer, clamp_min, D_MODEL_K, CHUNK_SIZE. The _bwd_preprocess_cumsum_gk kernel takes 14 arguments: Q, K, GK, GK_cumsum, DQ_exp, DK_reduce, DGK_last_exp, DGK_cumsum, DQ, DK, DGK, NUM_CHUNK, L, normalizer, clamp_min, D_MODEL_K, CHUNK_SIZE.", - "description_2": "Use triton language to create a stable log sigmoid function and implement forward and backward passes for cumulative sum with gating key, handling tensors and constants.", - "difficulty": 3 - }, - { - "code": "import triton\nimport triton.language as tl\nimport torch\n\n@triton.jit \ndef stable_log_sigmoid(x):\n # Compute stable log sigmoid\n max_value = tl.where(x < 0, x, 0)\n abs_value = tl.where(x > 0, x, -x)\n return max_value - tl.log(1 + tl.exp(-abs_value))\n\n@triton.jit\ndef _fwd_preprocess_cumsum_gv(\n V, GV, \n GV_cumsum, GV_exp, V_reduce, GV_last_exp, \n NUM_CHUNK, L, normalizer, clamp_min,\n D_MODEL_V: tl.constexpr, \n CHUNK_SIZE: tl.constexpr, \n ):\n # Forward pass for cumulative sum with gradient value processing\n offset_bh = tl.program_id(0)\n offset_c = tl.program_id(1)\n\n GV_ptr = GV + offset_bh * L * D_MODEL_V + offset_c * CHUNK_SIZE * D_MODEL_V + tl.arange(0, D_MODEL_V)\n GV_last_exp_ptr = GV_last_exp + offset_bh * NUM_CHUNK * D_MODEL_V + offset_c * D_MODEL_V + tl.arange(0, D_MODEL_V)\n GV_cumsum_ptr = GV_cumsum + offset_bh * L * D_MODEL_V + offset_c * CHUNK_SIZE * D_MODEL_V + tl.arange(0, D_MODEL_V)\n GV_exp_ptr = GV_exp + offset_bh * L * D_MODEL_V + offset_c * CHUNK_SIZE * D_MODEL_V + tl.arange(0, D_MODEL_V)\n\n cumsum = tl.zeros([D_MODEL_V], dtype=tl.float32)\n \n for _ in range(CHUNK_SIZE):\n gv = tl.load(GV_ptr).to(tl.float32) \n gv = stable_log_sigmoid(gv) / normalizer\n gv = tl.where(gv >= clamp_min, gv, clamp_min)\n cumsum += gv\n\n tl.store(GV_cumsum_ptr, cumsum.to(GV_cumsum_ptr.dtype.element_ty))\n tl.store(GV_exp_ptr, tl.exp(cumsum).to(GV_cumsum_ptr.dtype.element_ty))\n \n GV_cumsum_ptr += D_MODEL_V\n GV_exp_ptr += D_MODEL_V\n GV_ptr += D_MODEL_V\n\n tl.store(GV_last_exp_ptr, tl.exp(cumsum).to(GV_last_exp_ptr.dtype.element_ty))\n \n tl.debug_barrier()\n \n V_ptr = V + offset_bh * L * D_MODEL_V + offset_c * CHUNK_SIZE * D_MODEL_V + tl.arange(0, D_MODEL_V) \n GV_cumsum_ptr = GV_cumsum + offset_bh * L * D_MODEL_V + offset_c * CHUNK_SIZE * D_MODEL_V + tl.arange(0, D_MODEL_V)\n V_reduce_ptr = V_reduce + offset_bh * L * D_MODEL_V + offset_c * CHUNK_SIZE * D_MODEL_V + tl.arange(0, D_MODEL_V) \n\n for _ in range(CHUNK_SIZE):\n v = tl.load(V_ptr) \n gv = tl.load(GV_cumsum_ptr)\n v_reduce = v * tl.exp(cumsum - gv)\n tl.store(V_reduce_ptr, v_reduce.to(V_reduce_ptr.dtype.element_ty))\n \n V_ptr += D_MODEL_V\n V_reduce_ptr += D_MODEL_V\n GV_cumsum_ptr += D_MODEL_V\n \n@triton.jit\ndef _bwd_preprocess_cumsum_gv(\n V, GV, GV_cumsum, \n DGV_cumsum_exp, DV_reduce, DGV_last_exp, DGV_cumsum, \n DV, DGV, \n NUM_CHUNK, L, normalizer, clamp_min, \n D_MODEL_V: tl.constexpr, \n CHUNK_SIZE: tl.constexpr, \n ):\n # Backward pass for cumulative sum with gradient value processing\n offset_bh = tl.program_id(0)\n offset_c = tl.program_id(1)\n V_ptr = V + offset_bh * L * D_MODEL_V + offset_c * CHUNK_SIZE * D_MODEL_V + tl.arange(0, D_MODEL_V)\n GV_ptr = GV + offset_bh * L * D_MODEL_V + offset_c * CHUNK_SIZE * D_MODEL_V + tl.arange(0, D_MODEL_V)\n GV_cumsum_ptr = GV_cumsum + offset_bh * L * D_MODEL_V + offset_c * CHUNK_SIZE * D_MODEL_V + tl.arange(0, D_MODEL_V)\n\n DV_ptr = DV + offset_bh * L * D_MODEL_V + offset_c * CHUNK_SIZE * D_MODEL_V + tl.arange(0, D_MODEL_V)\n DV_reduce_ptr = DV_reduce + offset_bh * L * D_MODEL_V + offset_c * CHUNK_SIZE * D_MODEL_V + tl.arange(0, D_MODEL_V)\n DGV_cumsum_ptr = DGV_cumsum + offset_bh * L * D_MODEL_V + offset_c * CHUNK_SIZE * D_MODEL_V + tl.arange(0, D_MODEL_V)\n DGV_cumsum_exp_ptr = DGV_cumsum_exp + offset_bh * L * D_MODEL_V + offset_c * CHUNK_SIZE * D_MODEL_V + tl.arange(0, D_MODEL_V)\n\n DGV_ptr = DGV + offset_bh * L * D_MODEL_V + offset_c * CHUNK_SIZE * D_MODEL_V + tl.arange(0, D_MODEL_V)\n\n D_GV_last_exp_ptr = DGV_last_exp + offset_bh * NUM_CHUNK * D_MODEL_V + offset_c * D_MODEL_V + tl.arange(0, D_MODEL_V) \n \n cumsum_gradient = tl.zeros([D_MODEL_V], dtype=tl.float32)\n grad_gv_last = tl.zeros([D_MODEL_V], dtype=tl.float32)\n\n gv_last = tl.load(GV_cumsum_ptr + (CHUNK_SIZE - 1) * D_MODEL_V) \n cumsum_gradient += tl.load(D_GV_last_exp_ptr) * tl.exp(gv_last).to(tl.float32)\n \n GV_ptr += (CHUNK_SIZE - 1) * D_MODEL_V\n GV_cumsum_ptr += (CHUNK_SIZE - 1) * D_MODEL_V\n\n V_ptr += (CHUNK_SIZE - 1) * D_MODEL_V \n\n DV_reduce_ptr += (CHUNK_SIZE - 1) * D_MODEL_V\n DGV_cumsum_ptr += (CHUNK_SIZE - 1) * D_MODEL_V\n DGV_cumsum_exp_ptr += (CHUNK_SIZE - 1) * D_MODEL_V\n DV_ptr += (CHUNK_SIZE - 1) * D_MODEL_V\n DGV_ptr += (CHUNK_SIZE - 1) * D_MODEL_V\n\n for idx in range(CHUNK_SIZE -1, -1, -1):\n gv_cs = tl.load(GV_cumsum_ptr).to(tl.float32)\n v = tl.load(V_ptr).to(tl.float32)\n grad_v = tl.exp(gv_last - gv_cs) * tl.load(DV_reduce_ptr).to(tl.float32)\n tl.store(DV_ptr, grad_v.to(DV_ptr.dtype.element_ty))\n grad_v *= v\n cumsum_gradient -= grad_v\n grad_gv_last += grad_v\n\n grad_v = tl.exp(gv_cs) * tl.load(DGV_cumsum_exp_ptr) \n cumsum_gradient += grad_v\n\n cumsum_gradient += tl.load(DGV_cumsum_ptr).to(tl.float32) \n \n tl.store(DGV_ptr, cumsum_gradient.to(DGV_ptr.dtype.element_ty))\n\n V_ptr -= D_MODEL_V\n DV_reduce_ptr -= D_MODEL_V\n GV_cumsum_ptr -= D_MODEL_V\n DGV_cumsum_ptr -= D_MODEL_V\n DV_ptr -= D_MODEL_V\n DGV_ptr -= D_MODEL_V\n DGV_cumsum_exp_ptr -= D_MODEL_V\n \n DGV_ptr = DGV + offset_bh * L * D_MODEL_V + offset_c * CHUNK_SIZE * D_MODEL_V + tl.arange(0, D_MODEL_V) + (CHUNK_SIZE - 1) * D_MODEL_V\n GV_ptr = GV + offset_bh * L * D_MODEL_V + offset_c * CHUNK_SIZE * D_MODEL_V + tl.arange(0, D_MODEL_V) + (CHUNK_SIZE - 1) * D_MODEL_V\n \n grad_gv_last = grad_gv_last + 0.\n\n for idx in range(CHUNK_SIZE -1, -1, -1): \n dgv = tl.load(DGV_ptr).to(tl.float32)\n dgv += grad_gv_last\n gv = tl.load(GV_ptr).to(tl.float32) \n\n gv_logit = stable_log_sigmoid(gv) / normalizer\n gv = tl.sigmoid(gv) \n dgv = (dgv / normalizer) * (1 - gv) \n dgv = tl.where(gv_logit >= clamp_min, dgv, 0.)\n\n tl.store(DGV_ptr, dgv.to(DGV_ptr.dtype.element_ty))\n DGV_ptr -= D_MODEL_V\n GV_ptr -= D_MODEL_V\n\nclass PreprocessCumSum_GV(torch.autograd.Function):\n @staticmethod\n def forward(ctx, v, gv, normalizer_gv=8, clamp_min=-3):\n # Forward pass for PreprocessCumSum_GV\n v = v.contiguous()\n gv = gv.contiguous()\n \n B, H, NUM_CHUNK, CHUNK_SIZE, D_v = v.shape\n\n grid = (B * H, NUM_CHUNK)\n ctx.grid = grid \n\n gv_cumsum = torch.empty_like(gv, dtype=torch.float32) \n gv_cumsum_exp = torch.empty_like(gv)\n v_reduce = torch.empty_like(v)\n gv_last_exp = torch.empty_like(gv[:, :, :, 0], dtype=torch.float32)\n _fwd_preprocess_cumsum_gv[grid](\n v, gv, gv_cumsum, gv_cumsum_exp, \n v_reduce, gv_last_exp, \n CHUNK_SIZE=CHUNK_SIZE, NUM_CHUNK = NUM_CHUNK, L = CHUNK_SIZE * NUM_CHUNK, normalizer=normalizer_gv, clamp_min=clamp_min,\n D_MODEL_V=D_v, num_warps=8 if D_v >= 512 else 4\n ) \n \n ctx.grid = grid \n ctx.save_for_backward(v, gv, gv_cumsum)\n ctx.normalizer_gv = normalizer_gv\n ctx.clamp_min = clamp_min\n\n return gv_cumsum, v_reduce, gv_cumsum_exp, gv_last_exp\n\n @staticmethod\n def backward(ctx, dgv_cumsum, dv_reduce, dgv_cumsum_exp, dgv_last_exp):\n # Backward pass for PreprocessCumSum_GV\n dgv_cumsum = dgv_cumsum.contiguous()\n dv_reduce = dv_reduce.contiguous()\n dgv_cumsum_exp = dgv_cumsum_exp.contiguous()\n dgv_last_exp = dgv_last_exp.contiguous()\n v, gv, gv_cumsum = ctx.saved_tensors\n grid = ctx.grid\n\n B, H, NUM_CHUNK, CHUNK_SIZE, D_v = v.shape\n\n dv = torch.empty_like(v)\n dgv = torch.empty_like(gv) \n _bwd_preprocess_cumsum_gv[grid](\n v, gv, gv_cumsum, dgv_cumsum_exp, dv_reduce, dgv_last_exp, dgv_cumsum, \n dv, dgv, \n CHUNK_SIZE=CHUNK_SIZE, NUM_CHUNK = NUM_CHUNK, L = CHUNK_SIZE * NUM_CHUNK, normalizer=ctx.normalizer_gv, clamp_min = ctx.clamp_min,\n D_MODEL_V=D_v, num_warps=8 if D_v >= 512 else 4 \n ) \n return dv, dgv, None, None, None\n", - "description_1": "Use triton language to implement a stable log sigmoid function and a forward and backward pass for cumulative sum with gradient value processing. The stable_log_sigmoid kernel takes 1 argument: x, which is a tensor. The _fwd_preprocess_cumsum_gv kernel takes 11 arguments: V, GV, GV_cumsum, GV_exp, V_reduce, GV_last_exp, NUM_CHUNK, L, normalizer, clamp_min, and D_MODEL_V, CHUNK_SIZE as constexpr. The _bwd_preprocess_cumsum_gv kernel takes 13 arguments: V, GV, GV_cumsum, DGV_cumsum_exp, DV_reduce, DGV_last_exp, DGV_cumsum, DV, DGV, NUM_CHUNK, L, normalizer, clamp_min, and D_MODEL_V, CHUNK_SIZE as constexpr. The PreprocessCumSum_GV class has a forward method with 4 arguments: v, gv, normalizer_gv, and clamp_min, and a backward method with 4 arguments: dgv_cumsum, dv_reduce, dgv_cumsum_exp, and dgv_last_exp.", - "description_2": "Use triton language to create a stable log sigmoid function and implement forward and backward passes for cumulative sum with gradient value processing, handling tensors and constants efficiently.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n# Triton kernel for forward computation of A\n@triton.jit\ndef _fwd_kernel_compute_A(\n Q, K, GK, \n A, \n stride_q1, stride_q2, stride_q3, stride_q4,\n stride_a1, stride_a2, stride_a3, stride_a4,\n Z, H, N_CTX, D,\n BLOCK_DMODEL_QK: tl.constexpr,\n BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,\n):\n start_m = tl.program_id(0)\n off_hz = tl.program_id(1)\n off_k = tl.program_id(2)\n\n qk_offset = off_hz * stride_q2 + off_k * BLOCK_DMODEL_QK\n a_offset = (off_k * Z*H + off_hz) * stride_a2 \n\n lo = 0\n hi = BLOCK_N \n\n Q_ptr = Q + qk_offset + (start_m) * stride_q3 + tl.arange(0, BLOCK_DMODEL_QK)[None, :] + tl.arange(0, 16)[:, None] * stride_q4\n\n K_ptr = K + qk_offset + (start_m) * stride_q3 + tl.arange(0, BLOCK_DMODEL_QK)[:, None] + tl.arange(0, 16)[None, :] * stride_q4 \n\n GK_K_ptr = GK + qk_offset + (start_m) * stride_q3 + tl.arange(0, BLOCK_DMODEL_QK)[:, None] + tl.arange(0, 16)[None, :] * stride_q4 \n\n GK_Q_ptr = GK + qk_offset + (start_m) * stride_q3 + tl.arange(0, BLOCK_DMODEL_QK)[None, :] + tl.arange(0, 16)[:, None] * stride_q4 \n\n A_ptr = A + a_offset + (start_m) * stride_a3 + tl.arange(0, 16)[None, :] + tl.arange(0, 16)[:, None] * stride_a4 \n\n for q_high in range(16, hi, 16):\n q = tl.load(Q_ptr + q_high * stride_q4)\n q_gk = tl.load(GK_Q_ptr + q_high * stride_q4).to(tl.float32)\n q_normalizer = tl.load(GK + qk_offset + start_m * stride_q3 + q_high * stride_q4 + tl.arange(0,BLOCK_DMODEL_QK)).to(tl.float32)\n q_gk2 = tl.exp(q_gk - q_normalizer[None, :])\n q = q * q_gk2.to(q.dtype)\n\n #inter-chunk bf16\n for k_high in range(0, q_high, 16):\n k = tl.load(K_ptr + k_high * stride_q4)\n k_gk = tl.load(GK_K_ptr + k_high * stride_q4).to(tl.float32) \n k_gk = tl.exp(q_normalizer[:, None] - k_gk)\n k = k * k_gk.to(k.dtype)\n qk = tl.dot(q, k, allow_tf32=False) \n tl.store(A_ptr + q_high * stride_a4 + k_high, qk.to(A_ptr.dtype.element_ty)) \n\n\n ## intra chunk fp32\n for q_high in range(lo, hi, 16):\n q = tl.load(Q_ptr + q_high * stride_q4)\n q_gk = tl.load(GK_Q_ptr + q_high * stride_q4).to(tl.float32)\n q_normalizer = tl.load(GK + qk_offset + start_m * stride_q3 + q_high * stride_q4 + tl.arange(0,BLOCK_DMODEL_QK)).to(tl.float32)\n q_gk2 = tl.exp(q_gk - q_normalizer[None, :])\n q = q * q_gk2\n q_gk3 = tl.exp(q_normalizer[None, :] - q_gk)\n k = tl.load(K_ptr + q_high * stride_q4)\n k = k * tl.trans(q_gk3)\n\n qk = tl.dot(q, k, allow_tf32=False)\n qk = tl.where(tl.arange(0, 16)[:, None]>=tl.arange(0, 16)[None, :], qk, 0.)\n tl.store(A_ptr + q_high * stride_a4 + q_high, qk.to(A_ptr.dtype.element_ty)) \n\n# Triton kernel for backward computation of dqk\n@triton.jit\ndef _bwd_kernel_dqk(Q, K, GK, DA, \n DQ, \n DK, DGK,\n stride_q1, stride_q2, stride_q3, stride_q4,\n stride_a1, stride_a2, stride_a3, stride_a4,\n Z, H, N_CTX, D,\n BLOCK_DMODEL_QK: tl.constexpr,\n BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr\n ):\n\n start_m = tl.program_id(0)\n off_hz = tl.program_id(1)\n off_k = tl.program_id(2)\n\n qk_offset = off_hz * stride_q2 + BLOCK_DMODEL_QK * off_k\n a_offset = off_hz * stride_a2\n\n lo = 0\n hi = BLOCK_N \n\n Q_ptr = Q + qk_offset + (start_m) * stride_q3 + tl.arange(0, BLOCK_DMODEL_QK)[None, :] + tl.arange(0, 16)[:, None] * stride_q4\n\n DQ_ptr = DQ + qk_offset + (start_m) * stride_q3 + tl.arange(0, BLOCK_DMODEL_QK)[None, :] + tl.arange(0, 16)[:, None] * stride_q4\n\n K_ptr = K + qk_offset + (start_m) * stride_q3 + tl.arange(0, BLOCK_DMODEL_QK)[None, :] + tl.arange(0, 16)[:, None] * stride_q4\n\n GK_K_ptr = GK + qk_offset + (start_m) * stride_q3 + tl.arange(0, BLOCK_DMODEL_QK)[None, :] + tl.arange(0, 16)[:, None] * stride_q4\n\n GK_Q_ptr = GK + qk_offset + (start_m) * stride_q3+ tl.arange(0, BLOCK_DMODEL_QK)[None, :] + tl.arange(0, 16)[:, None] * stride_q4\n\n DA_ptr = DA + a_offset + (start_m) * stride_a3 + tl.arange(0, 16)[None, :] + tl.arange(0, 16)[:, None] * stride_a4 \n\n # inter chunk dq. bf16\n for q_high in range(lo+16, hi, 16):\n q = tl.load(Q_ptr + q_high * stride_q4) \n\n q_normalizer = tl.load(GK + qk_offset + (start_m * stride_q3)+ q_high * stride_q4 + tl.arange(0,BLOCK_DMODEL_QK)).to(tl.float32)\n\n dq2 = tl.zeros([16, BLOCK_DMODEL_QK], dtype=tl.float32)\n\n for k_high in range(0, q_high, 16):\n k = tl.load(K_ptr + k_high * stride_q4)\n k_gk = tl.load(GK_K_ptr + k_high * stride_q4).to(tl.float32) \n dqk = tl.load(DA_ptr + q_high * stride_a4 + k_high).to(k.dtype)\n k_gk = tl.exp(q_normalizer[None, :] - k_gk)\n k = k * k_gk.to(k.dtype)\n dq2 += tl.dot(dqk, k, allow_tf32=False)\n\n dq2 = dq2.to(q.dtype)\n q_gk = tl.load(GK_Q_ptr + q_high * stride_q4).to(tl.float32)\n q_gk = tl.exp(q_gk - q_normalizer[None, :])\n dq = dq2 * q_gk.to(q.dtype) \n dq_gk = dq * q\n\n DQ_ptr = DQ + qk_offset + (start_m) * stride_q3+ tl.arange(0, BLOCK_DMODEL_QK)[None, :] + tl.arange(0, 16)[:, None] * stride_q4 + q_high * stride_q4\n tl.store(DQ_ptr, dq.to(DQ_ptr.dtype.element_ty))\n\n DGK_Q_ptr = DGK + qk_offset + (start_m) * stride_q3+ tl.arange(0, BLOCK_DMODEL_QK)[None, :] + tl.arange(0, 16)[:, None] * stride_q4 + q_high * stride_q4\n tl.store(DGK_Q_ptr, dq_gk.to(DGK_Q_ptr.dtype.element_ty))\n\n tl.debug_barrier()\n\n for k_high in range(lo, hi-16, 16):\n k = tl.load(K_ptr + k_high * stride_q4)\n k_gk = tl.load(GK_K_ptr + k_high * stride_q4)\n dk = tl.zeros([16, BLOCK_DMODEL_QK], dtype=tl.float32)\n dgk = tl.zeros([16, BLOCK_DMODEL_QK], dtype=tl.float32)\n\n for q_high in range(k_high+16, hi, 16):\n q = tl.load(Q_ptr + q_high * stride_q4) \n q_normalizer = tl.load(GK + qk_offset + (start_m * stride_q3)+ q_high * stride_q4 + tl.arange(0,\n BLOCK_DMODEL_QK)).to(tl.float32)\n q_gk = tl.load(GK_Q_ptr + q_high * stride_q4).to(tl.float32)\n q_gk = tl.exp(q_gk - q_normalizer[None, :]).to(q.dtype)\n q = q * q_gk\n dqk = tl.load(DA_ptr + q_high * stride_a4 + k_high).to(q.dtype)\n\n k_gk2 = tl.exp(q_normalizer[None, :] - k_gk)\n\n dk2 = tl.dot(tl.trans(dqk), q, allow_tf32=False)\n dk += dk2 * k_gk2\n dgk -= dk2 * k * k_gk2\n\n DK_ptr = DK + qk_offset + (start_m) * stride_q3+ tl.arange(0, BLOCK_DMODEL_QK)[None, :] + tl.arange(0, 16)[:, None] * stride_q4 + k_high * stride_q4\n tl.store(DK_ptr, dk.to(DK_ptr.dtype.element_ty))\n\n DGK_K_ptr = DGK + qk_offset + (start_m) * stride_q3+ tl.arange(0, BLOCK_DMODEL_QK)[None, :] + tl.arange(0, 16)[:, None] * stride_q4 + k_high * stride_q4\n prev = tl.load(DGK_K_ptr)\n tl.store(DGK_K_ptr, (prev + dgk).to(DGK_K_ptr.dtype.element_ty))\n\n tl.debug_barrier()\n\n DK_ptr = DK + qk_offset + (start_m) * stride_q3 + tl.arange(0, BLOCK_DMODEL_QK)[None, :] + tl.arange(0, 16)[:, None] * stride_q4\n\n DGK_K_ptr = DGK + qk_offset + (start_m) * stride_q3 + tl.arange(0, BLOCK_DMODEL_QK)[None, :] + tl.arange(0, 16)[:, None] * stride_q4\n\n DQ_ptr = DQ + qk_offset + (start_m) * stride_q3 + tl.arange(0, BLOCK_DMODEL_QK)[None, :] + tl.arange(0, 16)[:, None] * stride_q4\n\n ## intra chunk, fp32.\n for q_high in range(lo, hi, 16):\n q = tl.load(Q_ptr + q_high * stride_q4)\n q_gk = tl.load(GK_Q_ptr + q_high * stride_q4).to(tl.float32)\n q_normalizer = tl.load(GK + qk_offset + start_m * stride_q3 + q_high * stride_q4 + tl.arange(0,BLOCK_DMODEL_QK)).to(tl.float32)\n q_gk2 = tl.exp(q_gk - q_normalizer[None, :])\n q2 = q * q_gk2\n q_gk3 = tl.exp(q_normalizer[None, :] - q_gk)\n\n k = tl.load(K_ptr + q_high * stride_q4)\n k2 = k * q_gk3\n\n dqk = tl.load(DA_ptr + q_high * stride_a4 + q_high)\n dqk = tl.where(tl.arange(0, 16)[:, None]>=tl.arange(0, 16)[None, :], dqk, 0.)\n\n dk2 = tl.dot(tl.trans(dqk), q2, allow_tf32=False) \n dk = dk2 * q_gk3\n prev_dk = tl.load(DK_ptr + q_high * stride_q4)\n tl.store(DK_ptr + q_high * stride_q4, (dk + prev_dk).to(DK_ptr.dtype.element_ty))\n\n dgk = - dk * k\n dq2 = tl.dot(dqk, k2, allow_tf32=False)\n dq = dq2 * q_gk2\n\n prev_dq = tl.load(DQ_ptr + q_high * stride_q4)\n tl.store(DQ_ptr + q_high * stride_q4, (dq + prev_dq).to(DQ_ptr.dtype.element_ty))\n\n dgk += dq * q\n prev_dq_gk = tl.load(DGK_K_ptr + q_high * stride_q4)\n tl.store(DGK_K_ptr + q_high * stride_q4, (dgk + prev_dq_gk).to(DGK_K_ptr.dtype.element_ty))\n\n# Class wrapping the forward and backward computation using the Triton kernels\nclass FlashGRet(torch.autograd.Function):\n @staticmethod\n def forward(ctx, q, k, gk):\n q = q.contiguous()\n k = k.contiguous()\n gk = gk.contiguous()\n\n capability = torch.cuda.get_device_capability()\n if capability[0] < 8:\n raise RuntimeError(\"Flash attention currently only supported for compute capability >= 80\")\n\n BLOCK_M = BLOCK_N = q.shape[-2]\n\n Lq, Lk = q.shape[-1], k.shape[-1]\n assert Lq == Lk \n if Lk > 128:\n assert Lk % 128 == 0\n\n BLOCK_DMODEL_QK = min(Lk, 128)\n ctx.BLOCK_DMODEL_QK = BLOCK_DMODEL_QK\n\n A = torch.zeros(max(1, Lk//128) , q.shape[0], q.shape[1], q.shape[2], BLOCK_N, BLOCK_N, device=q.device, dtype=q.dtype)\n\n grid = (q.shape[2], q.shape[0] * q.shape[1], max(1, Lk//128))\n\n _fwd_kernel_compute_A[grid](\n q, k, gk, A,\n q.stride(0), q.stride(1), q.stride(2), q.stride(3),\n A.stride(1), A.stride(2), A.stride(3), A.stride(4),\n q.shape[0], q.shape[1], q.shape[2], q.shape[3], \n BLOCK_N=BLOCK_N, BLOCK_DMODEL_QK=BLOCK_DMODEL_QK, BLOCK_M=BLOCK_M, num_warps=8 if ctx.BLOCK_DMODEL_QK == 128 else 4, num_stages=8\n )\n\n ctx.save_for_backward(q, k, gk)\n ctx.grid = grid\n ctx.BLOCK_N = BLOCK_N\n ctx.BLOCK_N = BLOCK_N\n ctx.head = q.shape[1]\n return A.sum(0).to(q.dtype)\n\n @staticmethod\n def backward(ctx, dA):\n dA = dA.contiguous()\n q, k, gk = ctx.saved_tensors\n\n dq = torch.zeros_like(q)\n dk = torch.zeros_like(k)\n dgk = torch.zeros_like(gk)\n\n BLOCK_N = ctx.BLOCK_N\n BLOCK_M = BLOCK_N\n\n _bwd_kernel_dqk[ctx.grid](\n q, k, gk, dA,\n dq, \n dk, dgk,\n q.stride(0), q.stride(1), q.stride(2), q.stride(3),\n dA.stride(0), dA.stride(1), dA.stride(2), dA.stride(3),\n q.shape[0], q.shape[1], q.shape[2], q.shape[3],\n BLOCK_N=BLOCK_N, BLOCK_DMODEL_QK=ctx.BLOCK_DMODEL_QK, BLOCK_M=BLOCK_M, num_warps=8 if ctx.BLOCK_DMODEL_QK == 128 else 4, num_stages=5\n )\n \n return dq, dk, dgk, None\n", - "description_1": "Use triton language to implement a forward and backward computation kernel for processing tensor operations involving query, key, and decay_key tensors, and their gradients. The forward kernel computes a matrix A by performing element-wise operations and matrix multiplication on input tensors. The backward kernel computes the gradients of the input tensors. Both kernels support configurable block sizes for the operations.", - "description_2": "Use triton language to implement a forward and backward computation kernel for matrix operations involving three tensors and their gradients.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef _fwd_compute_O(\n A, V, GV, O, \n stride_a1, stride_a2, stride_a3, stride_a4,\n stride_v1, stride_v2, stride_v3, stride_v4,\n BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_DMODEL_V: tl.constexpr,\n):\n start_m = tl.program_id(0)\n off_hz = tl.program_id(1)\n off_v = tl.program_id(2)\n\n a_offset = off_hz * stride_a2\n v_offset = off_hz * stride_v2 + off_v * BLOCK_DMODEL_V\n\n lo = 0\n hi = BLOCK_N \n\n V_ptr = V + v_offset + (start_m) * stride_v3 + tl.arange(0, BLOCK_DMODEL_V)[None, :] + tl.arange(0, 16)[:, None] * stride_v4\n\n O_ptr = O + v_offset + (start_m) * stride_v3 + tl.arange(0, BLOCK_DMODEL_V)[None, :] + tl.arange(0, 16)[:, None] * stride_v4\n\n GV_ptr = GV + v_offset + (start_m) * stride_v3 + tl.arange(0, BLOCK_DMODEL_V)[None, :] + tl.arange(0, 16)[:, None] * stride_v4\n\n A_ptr = A + a_offset + (start_m) * stride_a3 + tl.arange(0, 16)[None, :] + tl.arange(0, 16)[:, None] * stride_a4 \n\n for q_high in range(lo+16, hi, 16):\n q_gv_normalizer = tl.load(GV + v_offset + (start_m) * stride_v3 + q_high * stride_v4 + tl.arange(0, BLOCK_DMODEL_V)).to(tl.float32)\n acc = tl.zeros([16, BLOCK_DMODEL_V], dtype=tl.float32)\n \n for k_high in range(0, q_high, 16): \n qk = tl.load(A_ptr + q_high * stride_a4 + k_high) \n v = tl.load(V_ptr + k_high * stride_v4)\n k_gv = tl.load(GV_ptr + k_high * stride_v4)\n k_gv = tl.exp(q_gv_normalizer[None, :] - k_gv)\n v = v * k_gv.to(v.dtype) \n output = tl.dot(qk.to(v.dtype), v, allow_tf32=False) \n acc += output\n \n tl.store(O_ptr + q_high * stride_v4, acc.to(O.dtype.element_ty)) \n \n tl.store(O_ptr, tl.zeros([16, BLOCK_DMODEL_V], dtype=tl.float32).to(O.dtype.element_ty))\n \n tl.debug_barrier()\n \n for q_high in range(lo, hi, 16):\n q_gv_normalizer = tl.load(GV + v_offset + (start_m) * stride_v3 + q_high * stride_v4 + tl.arange(0, BLOCK_DMODEL_V)).to(tl.float32)\n\n qk = tl.load(A_ptr + q_high * stride_a4 + q_high) \n v = tl.load(V_ptr + q_high * stride_v4)\n k_gv = tl.load(GV_ptr + q_high * stride_v4)\n k_gv2 = tl.exp(q_gv_normalizer[None, :] - k_gv)\n \n v = v * k_gv2\n output = tl.dot(qk.to(tl.float32), v, allow_tf32=False)\n \n q_gv = tl.exp(k_gv - q_gv_normalizer[None, :])\n\n prev = tl.load(O_ptr + q_high * stride_v4)\n output += prev \n output = output * q_gv\n\n tl.store(O_ptr + q_high * stride_v4, output.to(O.dtype.element_ty))\n\n@triton.jit\ndef _bwd_kernel_dav(V, GV, A, O, \n DO, DA,\n DV, DGV, \n Z, H, \n stride_a1, stride_a2, stride_a3, stride_a4,\n stride_v1, stride_v2, stride_v3, stride_v4,\n BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_DMODEL_V: tl.constexpr\n ):\n \n start_m = tl.program_id(0)\n off_hz = tl.program_id(1)\n off_v = tl.program_id(2)\n\n a_offset = off_hz * stride_a2\n da_offset = (off_v * Z * H + off_hz) * stride_a2 \n v_offset = off_hz * stride_v2 + off_v * BLOCK_DMODEL_V \n\n lo = 0\n hi = BLOCK_N \n \n DO_ptr = DO + v_offset + (start_m) * stride_v3 + tl.arange(0, BLOCK_DMODEL_V)[None, :] + tl.arange(0, 16)[:, None] * stride_v4\n\n O_ptr = O + v_offset + (start_m ) * stride_v3 + tl.arange(0, BLOCK_DMODEL_V)[None, :] + tl.arange(0, 16)[:, None] * stride_v4\n \n DV_ptr = DV + v_offset + (start_m ) * stride_v3 + tl.arange(0, BLOCK_DMODEL_V)[None, :] + tl.arange(0, 16)[:, None] * stride_v4\n\n GV_ptr = GV + v_offset + (start_m ) * stride_v3 + tl.arange(0, BLOCK_DMODEL_V)[None, :] + tl.arange(0, 16)[:, None] * stride_v4\n\n DGV_ptr = DGV + v_offset + (start_m ) * stride_v3 + tl.arange(0, BLOCK_DMODEL_V)[None, :] + tl.arange(0, 16)[:, None] * stride_v4\n\n A_ptr = A + a_offset + (start_m ) * stride_a3 + tl.arange(0, 16)[None, :] + tl.arange(0, 16)[:, None] * stride_a4\n\n DA_ptr = DA + da_offset + (start_m ) * stride_a3 + tl.arange(0, 16)[None, :] + tl.arange(0, 16)[:, None] * stride_a4\n\n for q_high in range(lo, hi, 16):\n do = tl.load(DO_ptr + q_high * stride_v4) \n o = tl.load(O_ptr + q_high * stride_v4)\n tl.store(DGV_ptr + q_high * stride_v4, (do * o)) \n \n q_gv_normalizer = tl.load(GV + v_offset + (start_m) * stride_v3 + q_high * stride_v4 + tl.arange(0, BLOCK_DMODEL_V)).to(tl.float32)\n q_gv = tl.load(GV_ptr + q_high * stride_v4)\n q_gv = tl.exp(q_gv - q_gv_normalizer[None, :])\n do = do * q_gv\n\n tl.store(DO_ptr + q_high * stride_v4, do.to(DO_ptr.dtype.element_ty))\n \n tl.debug_barrier()\n\n V_ptr = V + v_offset + (start_m) * stride_v3 + tl.arange(0, BLOCK_DMODEL_V)[:, None] + tl.arange(0, 16)[None, :] * stride_v4\n GV_ptr = GV + v_offset + (start_m) * stride_v3 + tl.arange(0, BLOCK_DMODEL_V)[:, None] + tl.arange(0, 16)[None, :] * stride_v4\n\n for q_high in range(lo+16, hi, 16):\n do = tl.load(DO_ptr + q_high * stride_v4) \n q_gv_normalizer = tl.load(GV + v_offset + (start_m) * stride_v3 + q_high * stride_v4 + tl.arange(0, \n BLOCK_DMODEL_V)).to(tl.float32)\n \n for k_high in range(0, q_high, 16):\n v = tl.load(V_ptr + k_high * stride_v4) \n k_gv = tl.load(GV_ptr + k_high * stride_v4)\n k_gv = tl.exp(q_gv_normalizer[:, None] - k_gv)\n \n v2 = v * k_gv.to(v.dtype) \n dqk = tl.dot(do, v2, allow_tf32=False) \n tl.store(DA_ptr + q_high * stride_a4 + k_high, dqk.to(DA.dtype.element_ty)) \n \n tl.debug_barrier()\n\n A_ptr = A + a_offset + (start_m) * stride_a3 + tl.arange(0, 16)[:, None] + tl.arange(0, 16)[ None, :] * stride_a4\n\n V_ptr = V + v_offset + (start_m) * stride_v3 + tl.arange(0, BLOCK_DMODEL_V)[None, :] + tl.arange(0, 16)[ :, None] * stride_v4\n GV_ptr = GV + v_offset + (start_m) * stride_v3 + tl.arange(0, BLOCK_DMODEL_V)[None, :] + tl.arange(0, 16)[ :, None] * stride_v4\n\n for k_high in range(0, hi, 16): \n dv = tl.zeros([16, BLOCK_DMODEL_V], dtype=tl.float32)\n\n k_gv = tl.load(GV_ptr + k_high * stride_v4)\n\n for q_high in range(k_high + 16, BLOCK_N, 16):\n do = tl.load(DO_ptr + q_high * stride_v4) \n\n kq = tl.load(A_ptr + q_high * stride_a4 + k_high).to(do.dtype) \n\n q_gv_normalizer = tl.load(GV + v_offset + (start_m) * stride_v3 + q_high * stride_v4 + tl.arange(0, BLOCK_DMODEL_V)).to(tl.float32)\n k_gv2 = tl.exp(q_gv_normalizer[None, :] - k_gv) \n\n dv2 = tl.dot(kq, do, allow_tf32=False) \n dv += dv2 * k_gv2\n\n v = tl.load(V_ptr + k_high * stride_v4)\n tl.store(DV_ptr + k_high * stride_v4, dv.to(v.dtype))\n \n prev_dv = tl.load(DGV_ptr + k_high * stride_v4)\n tl.store(DGV_ptr + k_high * stride_v4, prev_dv - dv*v)\n \n tl.debug_barrier()\n\n A_ptr = A + a_offset + (start_m ) * stride_a3 + tl.arange(0, 16)[:, None] + tl.arange(0, 16)[ None, :] * stride_a4 \n\n for q_high in range(lo, hi, 16):\n do = tl.load(DO_ptr + q_high * stride_v4) \n\n q_gv_normalizer = tl.load(GV + v_offset + start_m * stride_v3 + q_high * stride_v4 + tl.arange(0, BLOCK_DMODEL_V)).to(tl.float32)\n\n v = tl.load(V_ptr + q_high * stride_v4)\n k_gv = tl.load(GV_ptr + q_high * stride_v4)\n k_gv = tl.exp(q_gv_normalizer[None, :] - k_gv)\n v2 = v * k_gv\n\n dqk = tl.dot(do.to(v2.dtype), tl.trans(v2), allow_tf32=False)\n dqk = tl.where(tl.arange(0, 16)[:, None] >= tl.arange(0, 16)[None, :], dqk, 0.)\n tl.store(DA_ptr + q_high * stride_a4 + q_high, dqk.to(DA_ptr.dtype.element_ty))\n\n kq = tl.load(A_ptr + q_high * stride_a4 + q_high).to(do.dtype)\n dv2 = tl.dot(kq, do, allow_tf32=False)\n \n dv = dv2 * k_gv\n prev_dv = tl.load(DV_ptr + q_high * stride_v4)\n tl.store(DV_ptr + q_high * stride_v4, (prev_dv + dv).to(DV.dtype.element_ty))\n\n prev_gdv = tl.load(DGV_ptr + q_high * stride_v4)\n prev_gdv -= dv * v \n tl.store(DGV_ptr + q_high * stride_v4, prev_gdv.to(DGV.dtype.element_ty))\n\nclass FlashGRet_O(torch.autograd.Function):\n @staticmethod\n def forward(ctx, A, v, gv, chunk_size=16):\n assert gv.dtype == torch.float32\n\n capability = torch.cuda.get_device_capability()\n if capability[0] < 8:\n raise RuntimeError(\"Flash attention currently only supported for compute capability >= 80\")\n \n BLOCK_M = BLOCK_N = v.shape[-2]\n\n Lv = v.shape[-1]\n BLOCK_V = min(128, Lv)\n ctx.BLOCK_V = BLOCK_V \n\n assert v.shape[-1] % BLOCK_V == 0\n \n grid = (v.shape[2] , v.shape[0] * v.shape[1], max(1, v.shape[-1] // BLOCK_V))\n \n o = torch.empty_like(v) \n\n _fwd_compute_O[grid](A, v, gv, o,\n A.stride(0), A.stride(1), A.stride(2), A.stride(3),\n v.stride(0), v.stride(1), v.stride(2), v.stride(3),\n BLOCK_N=BLOCK_N, BLOCK_M=BLOCK_M,\n BLOCK_DMODEL_V=BLOCK_V, num_warps= 8 if BLOCK_V==128 else 4, num_stages=5\n )\n\n ctx.save_for_backward(A, v,gv, o)\n ctx.grid = grid \n ctx.chunk_size = chunk_size\n return o\n\n @staticmethod\n def backward(ctx, do):\n do = do.contiguous()\n A, v, gv, o = ctx.saved_tensors\n BLOCK_V = ctx.BLOCK_V\n assert v.shape[-1] % BLOCK_V == 0\n\n dv = torch.zeros_like(v)\n dgv = torch.zeros_like(gv)\n \n BLOCK_M = BLOCK_N = v.shape[-2]\n \n grid = ctx.grid \n\n dA = torch.empty(v.shape[-1] // BLOCK_V if BLOCK_V == 128 else 1, A.shape[0], A.shape[1], A.shape[2], A.shape[3], A.shape[3], device=A.device, dtype=A.dtype)\n\n _bwd_kernel_dav[grid](\n v, gv, A, o, \n do, dA,\n dv, dgv,\n v.shape[0], v.shape[1],\n A.stride(0), A.stride(1), A.stride(2), A.stride(3),\n v.stride(0), v.stride(1), v.stride(2), v.stride(3),\n BLOCK_N=BLOCK_N, BLOCK_M=BLOCK_M, \n BLOCK_DMODEL_V=ctx.BLOCK_V, num_warps=8, num_stages=4\n ) \n\n return dA.sum(0).to(A), dv.to(v), dgv.to(gv), None\n", - "description_1": "Use triton language to implement two kernels: _fwd_compute_O and _bwd_kernel_dav. The _fwd_compute_O kernel computes the forward pass of a matrix operation with parameters A, V, GV, O, and strides for A and V. It uses BLOCK_M, BLOCK_N, and BLOCK_DMODEL_V as block sizes. The _bwd_kernel_dav kernel computes the backward pass with parameters V, GV, A, O, DO, DA, DV, DGV, Z, H, and strides for A and V. It also uses BLOCK_M, BLOCK_N, and BLOCK_DMODEL_V as block sizes. The FlashGRet_O class wraps these kernels for use in PyTorch's autograd system, with forward and backward methods that call the respective kernels.", - "description_2": "Use triton language to create a forward kernel _fwd_compute_O for matrix operations with parameters A, V, GV, O, and strides, and a backward kernel _bwd_kernel_dav for gradients with parameters V, GV, A, O, DO, DA, DV, DGV, Z, H, and strides. Implement a PyTorch autograd function FlashGRet_O to utilize these kernels.", - "difficulty": 4 - }, - { - "code": "import triton\nimport triton.language as tl\n\n# Tanh is just a scaled sigmoid\n@triton.jit\ndef tanh(x):\n return 2 * tl.sigmoid(2 * x) - 1\n\n# ReLU activation function\n@triton.jit\ndef relu(x):\n zero = 0.0\n return tl.where(x >= 0, x, zero.to(x.dtype))\n\n# ReLU gradient\n@triton.jit\ndef relu_grad(x):\n zero = 0.0\n one = 1.0\n return tl.where(x >= 0, one.to(x.dtype), zero.to(x.dtype))\n\n# Squared ReLU activation\n@triton.jit\ndef squared_relu(x):\n x_ = relu(x)\n return (x_ * x_).to(x.dtype)\n\n# Squared ReLU gradient\n@triton.jit\ndef squared_relu_grad(x):\n return tl.where(x >= 0, 2.0 * x, 0.0)\n\n# Leaky ReLU activation\n@triton.jit\ndef leaky_relu(x):\n scale = 0.01 + 0.0\n scale = scale.to(x.dtype)\n return tl.where(x >= 0, x, scale * x)\n\n# Leaky ReLU gradient\n@triton.jit\ndef leaky_relu_grad(x):\n min_grad = 0.01\n max_grad = 1\n min_grad = min_grad.to(x.dtype)\n max_grad = max_grad.to(x.dtype)\n return tl.where(x >= 0, max_grad, min_grad)\n\n# Gaussian Error Linear Unit (GELU)\n@triton.jit\ndef gelu(x):\n _sqrt1_2 = 0.70710678118 # precomputed sqrt(1/2)\n return x * 0.5 * (1.0 + tl.libdevice.erf(x * _sqrt1_2))\n\n# GELU gradient\n@triton.jit\ndef gelu_grad(x):\n _gaussian_pdf_normalization = 0.3989422804014337 # precomputed 1/sqrt(2*pi)\n _sqrt1_2 = 0.70710678118 # precomputed sqrt(1/2)\n cdf = 0.5 * (1.0 + tl.libdevice.erf(x * _sqrt1_2))\n pdf = tl.exp(-0.5 * x * x) * _gaussian_pdf_normalization\n return cdf + x * pdf\n\n# GeLU activation with tanh approximation\n@triton.jit\ndef gelu_approx(x):\n _sqrt2pi = 0.7978845608 # precomputed sqrt(2/pi)\n return 0.5 * x * (1.0 + tanh(_sqrt2pi * x * (1.0 + 0.044715 * x * x)))\n\n# GeLU approximation gradient\n@triton.jit\ndef gelu_approx_grad(x):\n tanh_out = tanh(0.79788456 * x * (1 + 0.044715 * x * x))\n return 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (\n 1 + tanh_out\n )\n", - "description_1": "Use triton language to implement various activation functions and their gradients, including ReLU, Squared ReLU, Leaky ReLU, GELU, and GELU with tanh approximation. Each function takes a single tensor input and applies the respective activation or gradient computation.", - "description_2": "Use triton language to create activation functions and their gradients for neural networks, such as ReLU and GELU.", - "difficulty": 2 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\nfrom flash_attn.ops.triton.k_activations import (\n gelu,\n gelu_approx,\n squared_relu,\n)\n\n@triton.autotune(\n configs=[\n triton.Config(\n {\"BLOCK_M\": 128, \"BLOCK_N\": 256, \"BLOCK_K\": 32, \"SPLIT_K\": 1}, num_stages=3, num_warps=8\n ),\n triton.Config(\n {\"BLOCK_M\": 256, \"BLOCK_N\": 128, \"BLOCK_K\": 32, \"SPLIT_K\": 1}, num_stages=3, num_warps=8\n ),\n triton.Config(\n {\"BLOCK_M\": 256, \"BLOCK_N\": 64, \"BLOCK_K\": 32, \"SPLIT_K\": 1}, num_stages=4, num_warps=4\n ),\n triton.Config(\n {\"BLOCK_M\": 64, \"BLOCK_N\": 256, \"BLOCK_K\": 32, \"SPLIT_K\": 1}, num_stages=4, num_warps=4\n ),\n triton.Config(\n {\"BLOCK_M\": 128, \"BLOCK_N\": 128, \"BLOCK_K\": 32, \"SPLIT_K\": 1}, num_stages=4, num_warps=4\n ),\n triton.Config(\n {\"BLOCK_M\": 128, \"BLOCK_N\": 64, \"BLOCK_K\": 32, \"SPLIT_K\": 1}, num_stages=4, num_warps=4\n ),\n triton.Config(\n {\"BLOCK_M\": 64, \"BLOCK_N\": 128, \"BLOCK_K\": 32, \"SPLIT_K\": 1}, num_stages=4, num_warps=4\n ),\n triton.Config(\n {\"BLOCK_M\": 128, \"BLOCK_N\": 32, \"BLOCK_K\": 32, \"SPLIT_K\": 1}, num_stages=4, num_warps=4\n ),\n triton.Config(\n {\"BLOCK_M\": 64, \"BLOCK_N\": 32, \"BLOCK_K\": 32, \"SPLIT_K\": 1}, num_stages=5, num_warps=2\n ),\n # good for int8\n triton.Config(\n {\"BLOCK_M\": 128, \"BLOCK_N\": 256, \"BLOCK_K\": 128, \"SPLIT_K\": 1},\n num_stages=3,\n num_warps=8,\n ),\n triton.Config(\n {\"BLOCK_M\": 256, \"BLOCK_N\": 128, \"BLOCK_K\": 128, \"SPLIT_K\": 1},\n num_stages=3,\n num_warps=8,\n ),\n triton.Config(\n {\"BLOCK_M\": 256, \"BLOCK_N\": 64, \"BLOCK_K\": 128, \"SPLIT_K\": 1}, num_stages=4, num_warps=4\n ),\n triton.Config(\n {\"BLOCK_M\": 64, \"BLOCK_N\": 256, \"BLOCK_K\": 128, \"SPLIT_K\": 1}, num_stages=4, num_warps=4\n ),\n triton.Config(\n {\"BLOCK_M\": 128, \"BLOCK_N\": 128, \"BLOCK_K\": 128, \"SPLIT_K\": 1},\n num_stages=4,\n num_warps=4,\n ),\n triton.Config(\n {\"BLOCK_M\": 128, \"BLOCK_N\": 64, \"BLOCK_K\": 64, \"SPLIT_K\": 1}, num_stages=4, num_warps=4\n ),\n triton.Config(\n {\"BLOCK_M\": 64, \"BLOCK_N\": 128, \"BLOCK_K\": 64, \"SPLIT_K\": 1}, num_stages=4, num_warps=4\n ),\n triton.Config(\n {\"BLOCK_M\": 128, \"BLOCK_N\": 32, \"BLOCK_K\": 64, \"SPLIT_K\": 1}, num_stages=4, num_warps=4\n ),\n triton.Config(\n {\"BLOCK_M\": 64, \"BLOCK_N\": 32, \"BLOCK_K\": 64, \"SPLIT_K\": 1}, num_stages=5, num_warps=2\n ),\n ],\n key=[\"CACHE_KEY_M\", \"CACHE_KEY_N\", \"CACHE_KEY_K\"],\n prune_configs_by={\n \"early_config_prune\": early_config_prune,\n \"perf_model\": estimate_matmul_time,\n \"top_k\": 10,\n },\n)\n@triton.heuristics(\n {\n \"EVEN_K\": lambda args: args[\"K\"] % (args[\"BLOCK_K\"] * args[\"SPLIT_K\"]) == 0,\n }\n)\n@triton.jit\ndef kernel_fwd(\n C, # Pointers to matrices\n ACT_INPUT,\n A,\n B,\n bias,\n # Matrix dimensions\n M,\n N,\n K,\n CACHE_KEY_M,\n CACHE_KEY_N,\n CACHE_KEY_K,\n # The stride variables represent how much to increase the ptr by when moving by 1\n # element in a particular dimension. E.g. stride_am is how much to increase a_ptr\n # by to get the element one row down (A has M rows)\n stride_cm,\n # stride_cn, # Assume that stride_cn == 1\n stride_am,\n stride_ak,\n stride_bn,\n stride_bk,\n # Meta-parameters\n BLOCK_M: tl.constexpr,\n GROUP_M: tl.constexpr,\n BLOCK_N: tl.constexpr,\n BLOCK_K: tl.constexpr,\n # split k not used, not performant with activation, kept because early_config_prune is expecting it\n SPLIT_K: tl.constexpr,\n EVEN_K: tl.constexpr,\n A_ROWMAJOR: tl.constexpr,\n B_COLMAJOR: tl.constexpr,\n BIAS: tl.constexpr,\n SAVE_ACT_INPUT: tl.constexpr,\n ACTIVATION: tl.constexpr,\n):\n\n \"\"\"\n Kernel for computing Out = activation(A x W + C)\n - Input has shape (M, K)\n - Weight has shape (K, N)\n - Bias has shape (N,)\n - Output has shape (M, N)\n - ActInputs (optional) has shape (M, N)\n 'ActInputs' optionally saves the A x W + C intermediate for backward computations\n This kernel will consolidate over K\n \"\"\"\n\n pid = tl.program_id(axis=0)\n\n grid_m = (M + BLOCK_M - 1) // BLOCK_M\n grid_n = (N + BLOCK_N - 1) // BLOCK_N\n # re-order program ID for better L2 performance\n width = GROUP_M * grid_n\n group_id = pid // width\n group_size = min(grid_m - group_id * GROUP_M, GROUP_M)\n pid_m = group_id * GROUP_M + (pid % group_size)\n pid_n = (pid % width) // (group_size)\n\n # now compute the block that each program will go through\n # rm (resp. rn) denotes a range of indices\n # for rows (resp. col) of C\n rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)\n rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)\n # trick to avoid masking on M and N axis\n ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)\n rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)\n rk = tl.arange(0, BLOCK_K)\n\n if A_ROWMAJOR:\n A = A + (ram[:, None] * stride_am + rk[None, :])\n else:\n A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)\n if B_COLMAJOR:\n B = B + (rk[:, None] + rbn[None, :] * stride_bn)\n else:\n B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)\n\n acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)\n\n for k in range(K, 0, -BLOCK_K):\n if EVEN_K:\n a = tl.load(A)\n b = tl.load(B)\n else:\n a = tl.load(A, mask=rk[None, :] < k, other=0.0)\n b = tl.load(B, mask=rk[:, None] < k, other=0.0)\n acc += tl.dot(a, b)\n\n if A_ROWMAJOR:\n A += BLOCK_K\n else:\n A += BLOCK_K * stride_ak\n if B_COLMAJOR:\n B += BLOCK_K\n else:\n B += BLOCK_K * stride_bk\n\n # Putting bias after the matmul (instead of before) is faster, idk why\n if BIAS:\n bias = tl.load(bias + rn, mask=rn < N, other=0.0).to(tl.float32)\n acc += bias[None, :]\n\n # optional: save the activation inputs\n if SAVE_ACT_INPUT:\n # act_in_ptrs = ACT_INPUT + ram[:, None] * stride_cm + rbn[None, :] * stride_cn\n act_in_ptrs = ACT_INPUT + ram[:, None] * stride_cm + rbn[None, :]\n tl.store(act_in_ptrs, acc)\n\n # optional: fused activation (while the data is in shared memory)\n if ACTIVATION == \"gelu\":\n acc = gelu(acc)\n elif ACTIVATION == \"gelu_approx\":\n acc = gelu_approx(acc)\n elif ACTIVATION == \"squared_relu\":\n acc = squared_relu(acc)\n # rematerialize rm and rn to save registers\n rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)\n rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)\n\n # write back result\n # C = C + rm[:, None] * stride_cm + rn[None, :] * stride_cn\n C = C + rm[:, None] * stride_cm + rn[None, :]\n mask = (rm < M)[:, None] & (rn < N)[None, :]\n tl.store(C, acc)\n\n\ndef triton_linear_act(\n x: torch.Tensor,\n weight: torch.Tensor,\n bias: Optional[torch.Tensor] = None,\n activation: str = \"id\",\n save_act_input: bool = False,\n) -> torch.Tensor:\n \"\"\"\n Compute e = activation(x @ weight.T + bias).\n This wrapper kicks the `kernel_fwd` Triton kernel\n :param x: input tensor\n :param weight: weight matrix\n :param bias: an optional bias tensor\n :param activation: Activation name. Needs to be a Triton kernel.\n :param act_input: an optional tensor to save the activation inputs (for backward)\n :return: result tensor\n \"\"\"\n # if torch.is_autocast_enabled():\n # dtype = torch.get_autocast_gpu_dtype()\n # x, weight, bias = [a.to(dtype=dtype) for a in [x, weight, bias]]\n\n assert activation in [\"id\", \"gelu\", \"gelu_approx\", \"squared_relu\"]\n\n batch_shape, n = x.shape[:-1], x.shape[-1]\n batch_dim = batch_shape.numel()\n x_reshaped = x.reshape(batch_dim, n)\n\n if x_reshaped.stride(0) > 1 and x_reshaped.stride(1) > 1:\n x_reshaped = x_reshaped.contiguous()\n if weight.stride(0) > 1 and weight.stride(1) > 1:\n weight = weight.contiguous()\n bias = bias.contiguous() if bias is not None else None\n\n assert (\n x.dtype == weight.dtype\n ), f\"Input and weight must have the same dtype, got {x.dtype} and {weight.dtype}\"\n if bias is not None:\n assert (\n x.dtype == bias.dtype\n ), f\"Input and bias must have the same dtype, got {x.dtype} and {bias.dtype}\"\n assert (\n x_reshaped.shape[1] == weight.shape[1]\n ), f\"Incompatible dimensions: {x_reshaped.shape} - {weight.shape}\"\n\n assert (\n bias is None or bias.shape[0] == weight.shape[0]\n ), \"Incompatible dimensions in between weight and bias\"\n\n M, K = x_reshaped.shape\n N, K = weight.shape\n\n output = torch.empty((M, N), device=x.device, dtype=x.dtype)\n act_input = torch.empty_like(output) if save_act_input else None\n\n # 1D launch kernel where each block gets its own program.\n grid = lambda META: (triton.cdiv(M, META[\"BLOCK_M\"]) * triton.cdiv(N, META[\"BLOCK_N\"]),) # noqa\n\n kernel_fwd[grid](\n output,\n act_input,\n x_reshaped,\n weight, # data ptrs\n bias if bias is not None else x, # auto skip bias if not present\n M, # shapes\n N,\n K,\n M // 32, # key for triton cache (limit number of compilations)\n N // 32,\n K // 32,\n stride_cm=output.stride(0), # strides\n # stride_cn=output.stride(1),\n stride_am=x_reshaped.stride(0),\n stride_ak=x_reshaped.stride(1),\n stride_bk=weight.stride(1),\n stride_bn=weight.stride(0),\n BIAS=bias is not None, # optional fused bias\n SAVE_ACT_INPUT=save_act_input, # optional save activation inputs\n ACTIVATION=activation, # optional fused activation\n A_ROWMAJOR=x_reshaped.stride(1) == 1,\n B_COLMAJOR=weight.stride(1) == 1,\n GROUP_M=8, # speed optimization: group the programs\n )\n\n if not save_act_input:\n return output.reshape(*batch_shape, output.shape[-1])\n else:\n return (\n output.reshape(*batch_shape, output.shape[-1]),\n act_input.reshape(*batch_shape, act_input.shape[-1]),\n )\n", - "description_1": "Use triton language to implement a forward kernel for matrix multiplication with optional activation and bias. The kernel takes pointers to matrices, dimensions, strides, and meta-parameters as inputs. It computes the output matrix by performing a dot product of input matrices A and B, adds bias if provided, and applies an activation function if specified. The kernel is optimized for performance using autotuning and heuristics.", - "description_2": "Use triton language to implement a forward kernel for matrix multiplication with optional activation and bias, optimized with autotuning.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.autotune(\n configs=[\n triton.Config({}, num_warps=1),\n triton.Config({}, num_warps=2),\n triton.Config({}, num_warps=4),\n triton.Config({}, num_warps=8),\n triton.Config({}, num_warps=16),\n triton.Config({}, num_warps=32),\n ],\n key=[\"N\", \"HAS_RESIDUAL\", \"STORE_RESIDUAL_OUT\", \"IS_RMS_NORM\", \"HAS_BIAS\"],\n)\n@triton.jit\ndef _layer_norm_fwd_1pass_kernel(\n X, # pointer to the input\n Y, # pointer to the output\n W, # pointer to the weights\n B, # pointer to the biases\n RESIDUAL, # pointer to the residual\n RESIDUAL_OUT, # pointer to the residual\n Mean, # pointer to the mean\n Rstd, # pointer to the 1/std\n stride_x_row, # how much to increase the pointer when moving by 1 row\n stride_y_row,\n stride_res_row,\n stride_res_out_row,\n N, # number of columns in X\n eps, # epsilon to avoid division by zero\n IS_RMS_NORM: tl.constexpr,\n BLOCK_N: tl.constexpr,\n HAS_RESIDUAL: tl.constexpr,\n STORE_RESIDUAL_OUT: tl.constexpr,\n HAS_BIAS: tl.constexpr,\n):\n # Map the program id to the row of X and Y it should compute.\n row = tl.program_id(0)\n X += row * stride_x_row\n Y += row * stride_y_row\n if HAS_RESIDUAL:\n RESIDUAL += row * stride_res_row\n if STORE_RESIDUAL_OUT:\n RESIDUAL_OUT += row * stride_res_out_row\n # Compute mean and variance\n cols = tl.arange(0, BLOCK_N)\n x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)\n if HAS_RESIDUAL:\n residual = tl.load(RESIDUAL + cols, mask=cols < N, other=0.0).to(tl.float32)\n x += residual\n if STORE_RESIDUAL_OUT:\n tl.store(RESIDUAL_OUT + cols, x, mask=cols < N)\n if not IS_RMS_NORM:\n mean = tl.sum(x, axis=0) / N\n tl.store(Mean + row, mean)\n xbar = tl.where(cols < N, x - mean, 0.0)\n var = tl.sum(xbar * xbar, axis=0) / N\n else:\n xbar = tl.where(cols < N, x, 0.0)\n var = tl.sum(xbar * xbar, axis=0) / N\n rstd = 1 / tl.sqrt(var + eps)\n tl.store(Rstd + row, rstd)\n # Normalize and apply linear transformation\n mask = cols < N\n w = tl.load(W + cols, mask=mask).to(tl.float32)\n if HAS_BIAS:\n b = tl.load(B + cols, mask=mask).to(tl.float32)\n x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd\n y = x_hat * w + b if HAS_BIAS else x_hat * w\n # Write output\n tl.store(Y + cols, y, mask=mask)\n\n\ndef _layer_norm_fwd(\n x, weight, bias, eps, residual=None, out_dtype=None, residual_dtype=None, is_rms_norm=False\n):\n if residual is not None:\n residual_dtype = residual.dtype\n M, N = x.shape\n assert x.stride(-1) == 1\n if residual is not None:\n assert residual.stride(-1) == 1\n assert residual.shape == (M, N)\n assert weight.shape == (N,)\n assert weight.stride(-1) == 1\n if bias is not None:\n assert bias.stride(-1) == 1\n assert bias.shape == (N,)\n # allocate output\n y = torch.empty_like(x, dtype=x.dtype if out_dtype is None else out_dtype)\n assert y.stride(-1) == 1\n if residual is not None or (residual_dtype is not None and residual_dtype != x.dtype):\n residual_out = torch.empty(M, N, device=x.device, dtype=residual_dtype)\n assert residual_out.stride(-1) == 1\n else:\n residual_out = None\n mean = torch.empty((M,), dtype=torch.float32, device=\"cuda\") if not is_rms_norm else None\n rstd = torch.empty((M,), dtype=torch.float32, device=\"cuda\")\n # Less than 64KB per feature: enqueue fused kernel\n MAX_FUSED_SIZE = 65536 // x.element_size()\n BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))\n if N > BLOCK_N:\n raise RuntimeError(\"This layer norm doesn't support feature dim >= 64KB.\")\n # heuristics for number of warps\n with torch.cuda.device(x.device.index):\n _layer_norm_fwd_1pass_kernel[(M,)](\n x,\n y,\n weight,\n bias,\n residual,\n residual_out,\n mean,\n rstd,\n x.stride(0),\n y.stride(0),\n residual.stride(0) if residual is not None else 0,\n residual_out.stride(0) if residual_out is not None else 0,\n N,\n eps,\n is_rms_norm,\n BLOCK_N,\n residual is not None,\n residual_out is not None,\n bias is not None,\n )\n # residual_out is None if residual is None and residual_dtype == input_dtype\n return y, mean, rstd, residual_out if residual_out is not None else x\n", - "description_1": "Use triton language to implement a layer normalization forward pass kernel. The kernel function '_layer_norm_fwd_1pass_kernel' takes 18 parameters: pointers to input, output, weights, biases, residuals, mean, and 1/std, strides for input, output, and residuals, number of columns, epsilon for numerical stability, and several compile-time constants. The function computes the mean and variance of the input, normalizes it, applies a linear transformation using weights and biases, and stores the result. The wrapper function '_layer_norm_fwd' prepares the input data, allocates output tensors, and launches the kernel with appropriate configurations.", - "description_2": "Use triton language to create a kernel for layer normalization that computes mean and variance, normalizes input, applies weights and biases, and stores the result. Implement a wrapper to handle input preparation and kernel launch.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef _selective_scan_update_kernel(\n # Pointers to matrices\n state_ptr, x_ptr, dt_ptr, dt_bias_ptr, A_ptr, B_ptr, C_ptr, D_ptr, z_ptr, out_ptr,\n # Matrix dimensions\n batch, dim, dstate,\n # Strides\n stride_state_batch, stride_state_dim, stride_state_dstate,\n stride_x_batch, stride_x_dim,\n stride_dt_batch, stride_dt_dim,\n stride_dt_bias_dim,\n stride_A_dim, stride_A_dstate,\n stride_B_batch, stride_B_dstate,\n stride_C_batch, stride_C_dstate,\n stride_D_dim,\n stride_z_batch, stride_z_dim,\n stride_out_batch, stride_out_dim,\n # Meta-parameters\n DT_SOFTPLUS: tl.constexpr,\n BLOCK_SIZE_M: tl.constexpr,\n HAS_DT_BIAS: tl.constexpr,\n HAS_D: tl.constexpr,\n HAS_Z: tl.constexpr,\n BLOCK_SIZE_DSTATE: tl.constexpr,\n):\n pid_m = tl.program_id(axis=0)\n pid_b = tl.program_id(axis=1)\n state_ptr += pid_b * stride_state_batch\n x_ptr += pid_b * stride_x_batch\n dt_ptr += pid_b * stride_dt_batch\n B_ptr += pid_b * stride_B_batch\n C_ptr += pid_b * stride_C_batch\n if HAS_Z:\n z_ptr += pid_b * stride_z_batch\n out_ptr += pid_b * stride_out_batch\n\n offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_n = tl.arange(0, BLOCK_SIZE_DSTATE)\n state_ptrs = state_ptr + (offs_m[:, None] * stride_state_dim + offs_n[None, :] * stride_state_dstate)\n x_ptrs = x_ptr + offs_m * stride_x_dim\n dt_ptrs = dt_ptr + offs_m * stride_dt_dim\n if HAS_DT_BIAS:\n dt_bias_ptrs = dt_bias_ptr + offs_m * stride_dt_bias_dim\n A_ptrs = A_ptr + (offs_m[:, None] * stride_A_dim + offs_n[None, :] * stride_A_dstate)\n B_ptrs = B_ptr + offs_n * stride_B_dstate\n C_ptrs = C_ptr + offs_n * stride_C_dstate\n if HAS_D:\n D_ptrs = D_ptr + offs_m * stride_D_dim\n if HAS_Z:\n z_ptrs = z_ptr + offs_m * stride_z_dim\n out_ptrs = out_ptr + offs_m * stride_out_dim\n\n state = tl.load(state_ptrs, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate), other=0.0)\n x = tl.load(x_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)\n dt = tl.load(dt_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)\n if HAS_DT_BIAS:\n dt += tl.load(dt_bias_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)\n if DT_SOFTPLUS:\n dt = tl.log(1.0 + tl.exp(dt))\n A = tl.load(A_ptrs, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate), other=0.0).to(tl.float32)\n dA = tl.exp(A * dt[:, None])\n B = tl.load(B_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32)\n C = tl.load(C_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32)\n if HAS_D:\n D = tl.load(D_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)\n if HAS_Z:\n z = tl.load(z_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)\n\n dB = B[None, :] * dt[:, None]\n state = state * dA + dB * x[:, None]\n tl.store(state_ptrs, state, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate))\n out = tl.sum(state * C[None, :], axis=1)\n if HAS_D:\n out += x * D\n if HAS_Z:\n out *= z * tl.sigmoid(z)\n tl.store(out_ptrs, out, mask=offs_m < dim)\n\n\ndef selective_state_update(state, x, dt, A, B, C, D=None, z=None, dt_bias=None, dt_softplus=False):\n \"\"\"\n Argument:\n state: (batch, dim, dstate)\n x: (batch, dim)\n dt: (batch, dim)\n A: (dim, dstate)\n B: (batch, dstate)\n C: (batch, dstate)\n D: (dim,)\n z: (batch, dim)\n dt_bias: (dim,)\n Return:\n out: (batch, dim)\n \"\"\"\n batch, dim, dstate = state.shape\n assert x.shape == (batch, dim)\n assert dt.shape == x.shape\n assert A.shape == (dim, dstate)\n assert B.shape == (batch, dstate)\n assert C.shape == B.shape\n if D is not None:\n assert D.shape == (dim,)\n if z is not None:\n assert z.shape == x.shape\n if dt_bias is not None:\n assert dt_bias.shape == (dim,)\n out = torch.empty_like(x)\n grid = lambda META: (triton.cdiv(dim, META['BLOCK_SIZE_M']), batch)\n z_strides = ((z.stride(0), z.stride(1)) if z is not None else (0, 0))\n # We don't want autotune since it will overwrite the state\n # We instead tune by hand.\n BLOCK_SIZE_M, num_warps = ((32, 4) if dstate <= 16\n else ((16, 4) if dstate <= 32 else\n ((8, 4) if dstate <= 64 else\n ((4, 4) if dstate <= 128 else\n ((4, 8))))))\n with torch.cuda.device(x.device.index):\n _selective_scan_update_kernel[grid](\n state, x, dt, dt_bias, A, B, C, D, z, out,\n batch, dim, dstate,\n state.stride(0), state.stride(1), state.stride(2),\n x.stride(0), x.stride(1),\n dt.stride(0), dt.stride(1),\n dt_bias.stride(0) if dt_bias is not None else 0,\n A.stride(0), A.stride(1),\n B.stride(0), B.stride(1),\n C.stride(0), C.stride(1),\n D.stride(0) if D is not None else 0,\n z_strides[0], z_strides[1],\n out.stride(0), out.stride(1),\n dt_softplus,\n BLOCK_SIZE_M,\n num_warps=num_warps,\n )\n return out\n", - "description_1": "Use triton language to implement a kernel function '_selective_scan_update_kernel' with 34 parameters for updating state matrices with optional bias and scaling, and a wrapper function 'selective_state_update' with 10 parameters to manage data preparation and kernel invocation.", - "description_2": "Use triton language to create a kernel for matrix state updates with optional bias and scaling, and a wrapper to handle data and call the kernel.", - "difficulty": 4 - }, - { - "code": "import math\n\nimport torch\nimport triton\nimport triton.language as tl\n\n@triton.autotune(\n configs=[\n triton.Config({}, num_warps=1),\n triton.Config({}, num_warps=2),\n triton.Config({}, num_warps=4),\n triton.Config({}, num_warps=8),\n triton.Config({}, num_warps=16),\n triton.Config({}, num_warps=32),\n ],\n key=[\"N\", \"HAS_RESIDUAL\", \"STORE_RESIDUAL_OUT\", \"IS_RMS_NORM\", \"HAS_BIAS\"],\n)\n@triton.jit\ndef _layer_norm_fwd_1pass_kernel(\n X, # pointer to the input\n Y, # pointer to the output\n W, # pointer to the weights\n B, # pointer to the biases\n RESIDUAL, # pointer to the residual\n RESIDUAL_OUT, # pointer to the residual\n Mean, # pointer to the mean\n Rstd, # pointer to the 1/std\n stride_x_row, # how much to increase the pointer when moving by 1 row\n stride_y_row,\n stride_res_row,\n stride_res_out_row,\n N, # number of columns in X\n eps, # epsilon to avoid division by zero\n IS_RMS_NORM: tl.constexpr,\n BLOCK_N: tl.constexpr,\n HAS_RESIDUAL: tl.constexpr,\n STORE_RESIDUAL_OUT: tl.constexpr,\n HAS_BIAS: tl.constexpr,\n):\n row = tl.program_id(0)\n X += row * stride_x_row\n Y += row * stride_y_row\n if HAS_RESIDUAL:\n RESIDUAL += row * stride_res_row\n if STORE_RESIDUAL_OUT:\n RESIDUAL_OUT += row * stride_res_out_row\n cols = tl.arange(0, BLOCK_N)\n x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)\n if HAS_RESIDUAL:\n residual = tl.load(RESIDUAL + cols, mask=cols < N, other=0.0).to(tl.float32)\n x += residual\n if STORE_RESIDUAL_OUT:\n tl.store(RESIDUAL_OUT + cols, x, mask=cols < N)\n if not IS_RMS_NORM:\n mean = tl.sum(x, axis=0) / N\n tl.store(Mean + row, mean)\n xbar = tl.where(cols < N, x - mean, 0.0)\n var = tl.sum(xbar * xbar, axis=0) / N\n else:\n xbar = tl.where(cols < N, x, 0.0)\n var = tl.sum(xbar * xbar, axis=0) / N\n rstd = 1 / tl.sqrt(var + eps)\n tl.store(Rstd + row, rstd)\n mask = cols < N\n w = tl.load(W + cols, mask=mask).to(tl.float32)\n if HAS_BIAS:\n b = tl.load(B + cols, mask=mask).to(tl.float32)\n x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd\n y = x_hat * w + b if HAS_BIAS else x_hat * w\n tl.store(Y + cols, y, mask=mask)\n\n\ndef _layer_norm_fwd(\n x, weight, bias, eps, residual=None, out_dtype=None, residual_dtype=None, is_rms_norm=False\n):\n if residual is not None:\n residual_dtype = residual.dtype\n M, N = x.shape\n assert x.stride(-1) == 1\n if residual is not None:\n assert residual.stride(-1) == 1\n assert residual.shape == (M, N)\n assert weight.shape == (N,)\n assert weight.stride(-1) == 1\n if bias is not None:\n assert bias.stride(-1) == 1\n assert bias.shape == (N,)\n y = torch.empty_like(x, dtype=x.dtype if out_dtype is None else out_dtype)\n assert y.stride(-1) == 1\n if residual is not None or (residual_dtype is not None and residual_dtype != x.dtype):\n residual_out = torch.empty(M, N, device=x.device, dtype=residual_dtype)\n assert residual_out.stride(-1) == 1\n else:\n residual_out = None\n mean = torch.empty((M,), dtype=torch.float32, device=\"cuda\") if not is_rms_norm else None\n rstd = torch.empty((M,), dtype=torch.float32, device=\"cuda\")\n MAX_FUSED_SIZE = 65536 // x.element_size()\n BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))\n if N > BLOCK_N:\n raise RuntimeError(\"This layer norm doesn't support feature dim >= 64KB.\")\n with torch.cuda.device(x.device.index):\n _layer_norm_fwd_1pass_kernel[(M,)](\n x,\n y,\n weight,\n bias,\n residual,\n residual_out,\n mean,\n rstd,\n x.stride(0),\n y.stride(0),\n residual.stride(0) if residual is not None else 0,\n residual_out.stride(0) if residual_out is not None else 0,\n N,\n eps,\n is_rms_norm,\n BLOCK_N,\n residual is not None,\n residual_out is not None,\n bias is not None,\n )\n return y, mean, rstd, residual_out if residual_out is not None else x\n\n@triton.autotune(\n configs=[\n triton.Config({}, num_warps=1),\n triton.Config({}, num_warps=2),\n triton.Config({}, num_warps=4),\n triton.Config({}, num_warps=8),\n triton.Config({}, num_warps=16),\n triton.Config({}, num_warps=32),\n ],\n key=[\"N\", \"HAS_DRESIDUAL\", \"STORE_DRESIDUAL\", \"IS_RMS_NORM\", \"HAS_BIAS\"],\n)\n@triton.heuristics({\"RECOMPUTE_OUTPUT\": lambda args: args[\"Y\"] is not None})\n@triton.jit\ndef _layer_norm_bwd_kernel(\n X, # pointer to the input\n W, # pointer to the weights\n B, # pointer to the biases\n Y, # pointer to the output to be recomputed\n DY, # pointer to the output gradient\n DX, # pointer to the input gradient\n DW, # pointer to the partial sum of weights gradient\n DB, # pointer to the partial sum of biases gradient\n DRESIDUAL,\n DRESIDUAL_IN,\n Mean, # pointer to the mean\n Rstd, # pointer to the 1/std\n stride_x_row, # how much to increase the pointer when moving by 1 row\n stride_y_row,\n stride_dy_row,\n stride_dx_row,\n stride_dres_row,\n stride_dres_in_row,\n M, # number of rows in X\n N, # number of columns in X\n eps, # epsilon to avoid division by zero\n rows_per_program,\n IS_RMS_NORM: tl.constexpr,\n BLOCK_N: tl.constexpr,\n HAS_DRESIDUAL: tl.constexpr,\n STORE_DRESIDUAL: tl.constexpr,\n HAS_BIAS: tl.constexpr,\n RECOMPUTE_OUTPUT: tl.constexpr,\n):\n row_block_id = tl.program_id(0)\n row_start = row_block_id * rows_per_program\n cols = tl.arange(0, BLOCK_N)\n mask = cols < N\n X += row_start * stride_x_row\n if HAS_DRESIDUAL:\n DRESIDUAL += row_start * stride_dres_row\n if STORE_DRESIDUAL:\n DRESIDUAL_IN += row_start * stride_dres_in_row\n DY += row_start * stride_dy_row\n DX += row_start * stride_dx_row\n if RECOMPUTE_OUTPUT:\n Y += row_start * stride_y_row\n w = tl.load(W + cols, mask=mask).to(tl.float32)\n if RECOMPUTE_OUTPUT and HAS_BIAS:\n b = tl.load(B + cols, mask=mask, other=0.0).to(tl.float32)\n dw = tl.zeros((BLOCK_N,), dtype=tl.float32)\n if HAS_BIAS:\n db = tl.zeros((BLOCK_N,), dtype=tl.float32)\n row_end = min((row_block_id + 1) * rows_per_program, M)\n for row in range(row_start, row_end):\n x = tl.load(X + cols, mask=mask, other=0).to(tl.float32)\n dy = tl.load(DY + cols, mask=mask, other=0).to(tl.float32)\n if not IS_RMS_NORM:\n mean = tl.load(Mean + row)\n rstd = tl.load(Rstd + row)\n xhat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd\n xhat = tl.where(mask, xhat, 0.0)\n if RECOMPUTE_OUTPUT:\n y = xhat * w + b if HAS_BIAS else xhat * w\n tl.store(Y + cols, y, mask=mask)\n wdy = w * dy\n dw += dy * xhat\n if HAS_BIAS:\n db += dy\n if not IS_RMS_NORM:\n c1 = tl.sum(xhat * wdy, axis=0) / N\n c2 = tl.sum(wdy, axis=0) / N\n dx = (wdy - (xhat * c1 + c2)) * rstd\n else:\n c1 = tl.sum(xhat * wdy, axis=0) / N\n dx = (wdy - xhat * c1) * rstd\n if HAS_DRESIDUAL:\n dres = tl.load(DRESIDUAL + cols, mask=mask, other=0).to(tl.float32)\n dx += dres\n if STORE_DRESIDUAL:\n tl.store(DRESIDUAL_IN + cols, dx, mask=mask)\n tl.store(DX + cols, dx, mask=mask)\n X += stride_x_row\n if HAS_DRESIDUAL:\n DRESIDUAL += stride_dres_row\n if STORE_DRESIDUAL:\n DRESIDUAL_IN += stride_dres_in_row\n if RECOMPUTE_OUTPUT:\n Y += stride_y_row\n DY += stride_dy_row\n DX += stride_dx_row\n tl.store(DW + row_block_id * N + cols, dw, mask=mask)\n if HAS_BIAS:\n tl.store(DB + row_block_id * N + cols, db, mask=mask)\n\n\ndef _layer_norm_bwd(\n dy,\n x,\n weight,\n bias,\n eps,\n mean,\n rstd,\n dresidual=None,\n has_residual=False,\n is_rms_norm=False,\n x_dtype=None,\n recompute_output=False,\n):\n M, N = x.shape\n assert x.stride(-1) == 1\n assert dy.stride(-1) == 1\n assert dy.shape == (M, N)\n if dresidual is not None:\n assert dresidual.stride(-1) == 1\n assert dresidual.shape == (M, N)\n assert weight.shape == (N,)\n assert weight.stride(-1) == 1\n if bias is not None:\n assert bias.stride(-1) == 1\n assert bias.shape == (N,)\n dx = (\n torch.empty_like(x)\n if x_dtype is None\n else torch.empty(M, N, dtype=x_dtype, device=x.device)\n )\n dresidual_in = torch.empty_like(x) if has_residual and dx.dtype != x.dtype else None\n y = torch.empty(M, N, dtype=dy.dtype, device=dy.device) if recompute_output else None\n MAX_FUSED_SIZE = 65536 // x.element_size()\n BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))\n if N > BLOCK_N:\n raise RuntimeError(\"This layer norm doesn't support feature dim >= 64KB.\")\n sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count\n _dw = torch.empty((sm_count, N), dtype=torch.float32, device=weight.device)\n _db = (\n torch.empty((sm_count, N), dtype=torch.float32, device=bias.device)\n if bias is not None\n else None\n )\n rows_per_program = math.ceil(M / sm_count)\n grid = (sm_count,)\n with torch.cuda.device(x.device.index):\n _layer_norm_bwd_kernel[grid](\n x,\n weight,\n bias,\n y,\n dy,\n dx,\n _dw,\n _db,\n dresidual,\n dresidual_in,\n mean,\n rstd,\n x.stride(0),\n 0 if not recompute_output else y.stride(0),\n dy.stride(0),\n dx.stride(0),\n dresidual.stride(0) if dresidual is not None else 0,\n dresidual_in.stride(0) if dresidual_in is not None else 0,\n M,\n N,\n eps,\n rows_per_program,\n is_rms_norm,\n BLOCK_N,\n dresidual is not None,\n dresidual_in is not None,\n bias is not None,\n )\n dw = _dw.sum(0).to(weight.dtype)\n db = _db.sum(0).to(bias.dtype) if bias is not None else None\n if has_residual and dx.dtype == x.dtype:\n dresidual_in = dx\n return (dx, dw, db, dresidual_in) if not recompute_output else (dx, dw, db, dresidual_in, y)\n", - "description_1": "Use triton language to implement two kernels: (1) _layer_norm_fwd_1pass_kernel, which performs layer normalization on input data with configurable features such as residual connections, bias application, and RMS norm option; (2) _layer_norm_bwd_kernel, which computes gradients for layer normalization parameters, taking similar features into account. Both functions are configured with parameters such as block size and presence of bias to optimize computation.", - "description_2": "Use triton language to implement layer normalization and its backward pass with support for optional features like residuals and RMS norm using optimized kernels.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef _selective_scan_update_kernel(\n # Pointers to matrices\n state_ptr, x_ptr, dt_ptr, dt_bias_ptr, A_ptr, B_ptr, C_ptr, D_ptr, z_ptr, out_ptr,\n # Matrix dimensions\n batch, dim, dstate,\n # Strides\n stride_state_batch, stride_state_dim, stride_state_dstate,\n stride_x_batch, stride_x_dim,\n stride_dt_batch, stride_dt_dim,\n stride_dt_bias_dim,\n stride_A_dim, stride_A_dstate,\n stride_B_batch, stride_B_dstate,\n stride_C_batch, stride_C_dstate,\n stride_D_dim,\n stride_z_batch, stride_z_dim,\n stride_out_batch, stride_out_dim,\n # Meta-parameters\n DT_SOFTPLUS: tl.constexpr,\n BLOCK_SIZE_M: tl.constexpr,\n HAS_DT_BIAS: tl.constexpr,\n HAS_D: tl.constexpr,\n HAS_Z: tl.constexpr,\n BLOCK_SIZE_DSTATE: tl.constexpr,\n):\n pid_m = tl.program_id(axis=0)\n pid_b = tl.program_id(axis=1)\n state_ptr += pid_b * stride_state_batch\n x_ptr += pid_b * stride_x_batch\n dt_ptr += pid_b * stride_dt_batch\n B_ptr += pid_b * stride_B_batch\n C_ptr += pid_b * stride_C_batch\n if HAS_Z:\n z_ptr += pid_b * stride_z_batch\n out_ptr += pid_b * stride_out_batch\n\n offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_n = tl.arange(0, BLOCK_SIZE_DSTATE)\n state_ptrs = state_ptr + (offs_m[:, None] * stride_state_dim + offs_n[None, :] * stride_state_dstate)\n x_ptrs = x_ptr + offs_m * stride_x_dim\n dt_ptrs = dt_ptr + offs_m * stride_dt_dim\n if HAS_DT_BIAS:\n dt_bias_ptrs = dt_bias_ptr + offs_m * stride_dt_bias_dim\n A_ptrs = A_ptr + (offs_m[:, None] * stride_A_dim + offs_n[None, :] * stride_A_dstate)\n B_ptrs = B_ptr + offs_n * stride_B_dstate\n C_ptrs = C_ptr + offs_n * stride_C_dstate\n if HAS_D:\n D_ptrs = D_ptr + offs_m * stride_D_dim\n if HAS_Z:\n z_ptrs = z_ptr + offs_m * stride_z_dim\n out_ptrs = out_ptr + offs_m * stride_out_dim\n\n state = tl.load(state_ptrs, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate), other=0.0)\n x = tl.load(x_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)\n dt = tl.load(dt_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)\n if HAS_DT_BIAS:\n dt += tl.load(dt_bias_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)\n if DT_SOFTPLUS:\n dt = tl.log(1.0 + tl.exp(dt))\n A = tl.load(A_ptrs, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate), other=0.0).to(tl.float32)\n dA = tl.exp(A * dt[:, None])\n B = tl.load(B_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32)\n C = tl.load(C_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32)\n if HAS_D:\n D = tl.load(D_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)\n if HAS_Z:\n z = tl.load(z_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)\n\n dB = B[None, :] * dt[:, None]\n state = state * dA + dB * x[:, None]\n tl.store(state_ptrs, state, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate))\n out = tl.sum(state * C[None, :], axis=1)\n if HAS_D:\n out += x * D\n if HAS_Z:\n out *= z * tl.sigmoid(z)\n tl.store(out_ptrs, out, mask=offs_m < dim)\n\n\ndef selective_state_update(state, x, dt, A, B, C, D=None, z=None, dt_bias=None, dt_softplus=False):\n \"\"\"\n Argument:\n state: (batch, dim, dstate)\n x: (batch, dim)\n dt: (batch, dim)\n A: (dim, dstate)\n B: (batch, dstate)\n C: (batch, dstate)\n D: (dim,)\n z: (batch, dim)\n dt_bias: (dim,)\n Return:\n out: (batch, dim)\n \"\"\"\n batch, dim, dstate = state.shape\n assert x.shape == (batch, dim)\n assert dt.shape == x.shape\n assert A.shape == (dim, dstate)\n assert B.shape == (batch, dstate)\n assert C.shape == B.shape\n if D is not None:\n assert D.shape == (dim,)\n if z is not None:\n assert z.shape == x.shape\n if dt_bias is not None:\n assert dt_bias.shape == (dim,)\n out = torch.empty_like(x)\n grid = lambda META: (triton.cdiv(dim, META['BLOCK_SIZE_M']), batch)\n z_strides = ((z.stride(0), z.stride(1)) if z is not None else (0, 0))\n # We don't want autotune since it will overwrite the state\n # We instead tune by hand.\n BLOCK_SIZE_M, num_warps = ((32, 4) if dstate <= 16\n else ((16, 4) if dstate <= 32 else\n ((8, 4) if dstate <= 64 else\n ((4, 4) if dstate <= 128 else\n ((4, 8))))))\n with torch.cuda.device(x.device.index):\n _selective_scan_update_kernel[grid](\n state, x, dt, dt_bias, A, B, C, D, z, out,\n batch, dim, dstate,\n state.stride(0), state.stride(1), state.stride(2),\n x.stride(0), x.stride(1),\n dt.stride(0), dt.stride(1),\n dt_bias.stride(0) if dt_bias is not None else 0,\n A.stride(0), A.stride(1),\n B.stride(0), B.stride(1),\n C.stride(0), C.stride(1),\n D.stride(0) if D is not None else 0,\n z_strides[0], z_strides[1],\n out.stride(0), out.stride(1),\n dt_softplus,\n BLOCK_SIZE_M,\n num_warps=num_warps,\n )\n return out\n", - "description_1": "Use triton language to implement a kernel function '_selective_scan_update_kernel' with 30 parameters for matrix operations and a wrapper function 'selective_state_update' with 9 parameters to manage state updates in a batch processing context.", - "description_2": "Use triton language to create a kernel for selective state updates with matrix operations and a wrapper function to handle batch processing.", - "difficulty": 4 - }, - { - "code": "import triton\nimport triton.language as tl\nimport torch\n\n@triton.jit\ndef _fwd_kernel_apply_penalty(\n Logits, presence_penalty, freqency_penalty,\n p_token_ids, p_token_counts, p_cumsum_seq_len, \n stride_logit_b, stride_logit_s,\n BLOCK_P: tl.constexpr\n):\n # Determine the current batch index and load penalties\n cur_batch = tl.program_id(0)\n cur_freqency = tl.load(freqency_penalty + cur_batch)\n cur_presence = tl.load(presence_penalty + cur_batch)\n\n # Load the start and end indices for the current batch\n cur_batch_start_index = tl.load(p_cumsum_seq_len + cur_batch)\n cur_batch_end_index = tl.load(p_cumsum_seq_len + cur_batch + 1)\n\n # Compute the offsets and load token ids and their counts\n cur_batch_id_offset = cur_batch_start_index + tl.arange(0, BLOCK_P)\n batch_ids = tl.load(p_token_ids + cur_batch_id_offset, mask=cur_batch_id_offset> 28) + 16) & 15\n int4_zp = ((zp_b << (28 - i * 4) >> 28) + 16) & 15\n bs_offs = (offs_k * 8 + i)[:, None] * stride_bsk + (offs_n // group_size)[None, :] * stride_bsn\n fpb_offs = (offs_k * 8 + i)[:, None] * stride_fpbk + offs_n[None, :] * stride_fpbn\n k8_mask = (offs_k * 8 + i)[:, None] < K * 8\n scale_b = tl.load(b_scale_ptr + bs_offs, mask=n_mask & k8_mask, other=0.0)\n fp_weight = (int4_b - int4_zp) * scale_b\n tl.store(fpb_ptr + fpb_offs, fp_weight, mask=n_mask & k8_mask)\n\ndef dequantize_int4(b, b_scale, b_zero_point, device, dtype, group_size):\n Kw, N = b.shape\n fp_b = torch.empty((b_scale.shape[0], b.shape[1]), device=device, dtype=dtype)\n grid = lambda META: (\n triton.cdiv(Kw, META['BLOCK_SIZE_K']),\n triton.cdiv(N, META['BLOCK_SIZE_N']), \n )\n dequantize_kernel[grid](\n b, b_scale, b_zero_point, fp_b,\n Kw, N, group_size,\n b.stride(0), b.stride(1),\n b_scale.stride(0), b_scale.stride(1),\n b_zero_point.stride(0), b_zero_point.stride(1),\n fp_b.stride(0), fp_b.stride(1)\n )\n return fp_b\n\ndef matmul_dequantize_int4(a, b, b_scale, b_zero_point, group_size=128, out=None):\n # Check constraints.\n assert a.is_contiguous(), \"Matrix A must be contiguous\"\n assert b.is_contiguous(), \"Matrix B must be contiguous\"\n M, K = a.shape\n Kw, N = b.shape\n if out is None:\n # Allocates output.\n c = torch.empty((M, N), device=a.device, dtype=a.dtype)\n else:\n c = out\n fp_b = dequantize_int4(b, b_scale, b_zero_point, a.device, a.dtype, group_size)\n torch.mm(a, fp_b, out=c)\n fp_b = None\n return c\n", - "description_1": "Use triton language to create a kernel called dequantize_kernel that dequantizes integer matrices to floating-point matrices. This kernel takes pointers to matrices (b_ptr, b_scale_ptr, b_zp_ptr, fpb_ptr), matrix dimensions (K, N, group_size), and strides for each dimension (stride_bk, stride_bn, stride_bsk, stride_bsn, stride_bzpk, stride_bzpn, stride_fpbk, stride_fpbn) as inputs. It also uses meta-parameters BLOCK_SIZE_K and BLOCK_SIZE_N. The kernel dequantizes int4 weights using scale and zero point matrices and stores the resulting floating-point weights in fpb_ptr. The corresponding function dequantize_int4 sets up a triton grid for launching this kernel and returns the dequantized matrix. Another function matmul_dequantize_int4 uses the dequantize_int4 function to first dequantize matrix b and then performs a matrix multiplication of a with the dequantized b, returning the result.", - "description_2": "Use triton language to implement a kernel to dequantize int4 matrices and perform matrix multiplication using the dequantized matrices.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef dequantize_kernel(\n # Pointers to matrices\n b_ptr, b_scale_ptr, fpb_ptr,\n # Matrix dimensions\n K, N,\n stride_bk, stride_bn,\n stride_fpbk, stride_fpbn,\n # Meta-parameters\n BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,\n):\n \"\"\"Kernel for computing the matmul C = A x B.\n A has shape (M, K), B has shape (K, N) and C has shape (M, N)\n \"\"\"\n k_block_idx = tl.program_id(axis=0)\n n_block_idx = tl.program_id(axis=1)\n offs_k = tl.arange(0, BLOCK_SIZE_K)\n offs_n = tl.arange(0, BLOCK_SIZE_N)\n b_offs = (k_block_idx * BLOCK_SIZE_K + offs_k[:, None]) * stride_bk + \\\n (n_block_idx * BLOCK_SIZE_N + offs_n[None, :]) * stride_bn\n fpb_offs = (k_block_idx * BLOCK_SIZE_K + offs_k[:, None]) * stride_fpbk + \\\n (n_block_idx * BLOCK_SIZE_N + offs_n[None, :]) * stride_fpbn\n bs_offs = n_block_idx * BLOCK_SIZE_N + offs_n[None, :]\n n_mask = n_block_idx * BLOCK_SIZE_N + offs_n[None, :] < N\n mask = (k_block_idx * BLOCK_SIZE_K + offs_k[:, None] < K) & n_mask\n int_b = tl.load(b_ptr + b_offs, mask=mask, other=0.0)\n scale_b = tl.load(b_scale_ptr + bs_offs, mask=n_mask, other=0.0)\n tl.store(fpb_ptr + fpb_offs, int_b * scale_b, mask=mask)\n\ndef matmul_dequantize_int8(a, b, b_scale, out=None):\n # Check constraints.\n assert a.shape[1] == b.shape[0], \"Incompatible dimensions\"\n assert a.is_contiguous(), \"Matrix A must be contiguous\"\n M, K = a.shape\n K, N = b.shape\n if out == None:\n # Allocates output.\n c = torch.empty((M, N), device=a.device, dtype=a.dtype)\n else:\n c = out\n fp_b = torch.empty((K, N), device=a.device, dtype=a.dtype)\n grid = lambda META: (\n triton.cdiv(K, META['BLOCK_SIZE_K']), triton.cdiv(N, META['BLOCK_SIZE_N']),\n )\n dequantize_kernel[grid](\n b, b_scale, fp_b,\n K, N,\n b.stride(0), b.stride(1),\n fp_b.stride(0), fp_b.stride(1)\n )\n torch.mm(a, fp_b, out=c)\n return c\n", - "description_1": "Use triton language to implement a kernel function 'dequantize_kernel' that dequantizes an int8 matrix 'b' using a scale matrix 'b_scale' and stores the result in 'fpb'. The kernel takes 10 parameters: 3 pointers to matrices (b_ptr, b_scale_ptr, fpb_ptr), 2 matrix dimensions (K, N), 4 strides (stride_bk, stride_bn, stride_fpbk, stride_fpbn), and 2 meta-parameters (BLOCK_SIZE_N, BLOCK_SIZE_K). The function 'matmul_dequantize_int8' calls this kernel to perform matrix multiplication with dequantization, taking 4 parameters: matrices 'a', 'b', 'b_scale', and an optional output matrix 'out'.", - "description_2": "Use triton language to create a kernel for dequantizing an int8 matrix with a scale matrix and perform matrix multiplication with the dequantized result.", - "difficulty": 2 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n\n@triton.jit\ndef _fwd_kernel_destindex_copy_kv(\n K, Dest_loc,\n Out,\n stride_k_bs, stride_k_h, stride_k_d,\n stride_o_bs, stride_o_h, stride_o_d,\n head_num,\n BLOCK_DMODEL: tl.constexpr,\n BLOCK_HEAD: tl.constexpr\n):\n cur_index = tl.program_id(0)\n offs_h = tl.arange(0, BLOCK_HEAD)\n offs_d = tl.arange(0, BLOCK_DMODEL)\n\n dest_index = tl.load(Dest_loc + cur_index)\n\n k_ptrs = K + cur_index * stride_k_bs + stride_k_h * offs_h[:, None] + stride_k_d * offs_d[None, :]\n o_ptrs = Out + dest_index * stride_o_bs + stride_o_h * offs_h[:, None] + stride_o_d * offs_d[None, :]\n\n k = tl.load(k_ptrs, mask=offs_h[:, None] < head_num, other=0.0)\n tl.store(o_ptrs, k, mask=offs_h[:, None] < head_num)\n return\n\n\n@torch.no_grad()\ndef destindex_copy_kv(K, DestLoc, Out):\n seq_len = DestLoc.shape[0]\n head_num = K.shape[1]\n head_dim = K.shape[2]\n assert K.shape[1] == Out.shape[1] and K.shape[2] == Out.shape[2]\n BLOCK_HEAD = triton.next_power_of_2(head_num)\n grid = (seq_len,)\n num_warps = 1\n\n _fwd_kernel_destindex_copy_kv[grid](\n K, DestLoc, Out,\n K.stride(0), K.stride(1), K.stride(2),\n Out.stride(0), Out.stride(1), Out.stride(2),\n head_num,\n BLOCK_DMODEL=head_dim,\n BLOCK_HEAD=BLOCK_HEAD,\n num_warps=num_warps,\n num_stages=1,\n )\n return\n\n\n@triton.jit\ndef _fwd_kernel_destindex_copy_quantize_kv(\n K, Dest_loc, Out, Out_scale,\n stride_k_bs, stride_k_h, stride_k_d,\n stride_o_bs, stride_o_h, stride_o_d,\n stride_os_bs, stride_os_h, stride_os_d,\n head_num,\n BLOCK_DMODEL: tl.constexpr,\n BLOCK_HEAD: tl.constexpr\n):\n cur_index = tl.program_id(0)\n offs_h = tl.arange(0, BLOCK_HEAD)\n offs_d = tl.arange(0, BLOCK_DMODEL)\n\n dest_index = tl.load(Dest_loc + cur_index)\n src_data = tl.load(K + cur_index * stride_k_bs + offs_h[:, None] * stride_k_h + stride_k_d * offs_d[None, :], \n mask=offs_h[:, None] < head_num, other=0.0)\n abs_data = tl.abs(src_data)\n data_scale = (tl.max(abs_data, axis=1) / 127.).to(tl.float16)[:, None]\n q_src_data = (src_data / data_scale).to(tl.int8)\n o_ptrs = Out + dest_index * stride_o_bs + stride_o_h * offs_h[:, None] + stride_o_d * offs_d[None, :]\n os_ptrs = Out_scale + dest_index * stride_os_bs + stride_os_h * offs_h[:, None]\n tl.store(o_ptrs, q_src_data, mask=offs_h[:, None] < head_num)\n tl.store(os_ptrs, data_scale, mask=offs_h[:, None] < head_num)\n\n\n@torch.no_grad()\ndef destindex_copy_quantize_kv(K, DestLoc, Out, Out_scale):\n seq_len = DestLoc.shape[0]\n head_num = K.shape[1]\n head_dim = K.shape[2]\n assert K.shape[1] == Out.shape[1] and K.shape[2] == Out.shape[2]\n BLOCK_HEAD = triton.next_power_of_2(head_num)\n grid = (seq_len,)\n num_warps = 1\n\n _fwd_kernel_destindex_copy_quantize_kv[grid](\n K, DestLoc, Out, Out_scale,\n K.stride(0), K.stride(1), K.stride(2),\n Out.stride(0), Out.stride(1), Out.stride(2),\n Out_scale.stride(0), Out_scale.stride(1), Out_scale.stride(2),\n head_num,\n BLOCK_DMODEL=head_dim,\n BLOCK_HEAD=BLOCK_HEAD,\n num_warps=num_warps,\n num_stages=1,\n )\n return\n", - "description_1": "Use triton language to implement two kernels and their callers. The first kernel '_fwd_kernel_destindex_copy_kv' takes 11 parameters: K (key tensor), Dest_loc (destination index tensor), Out (output tensor), stride_k_bs, stride_k_h, stride_k_d (strides for key tensor), stride_o_bs, stride_o_h, stride_o_d (strides for output tensor), head_num (number of heads), BLOCK_DMODEL (constant for block size in model dimensions), and BLOCK_HEAD (constant for block size in head dimension). It copies data from K to Out based on indices from Dest_loc. The caller 'destindex_copy_kv' calculates required parameters and launches the kernel. The second kernel '_fwd_kernel_destindex_copy_quantize_kv' takes 15 parameters: similar parameters as the first one, with additional Out_scale (output scale tensor), stride_os_bs, stride_os_h, stride_os_d (strides for output scale tensor). It performs quantization on the data before storing it to Out and also stores scale factors to Out_scale. Its caller 'destindex_copy_quantize_kv' also calculates parameters and launches the kernel.", - "description_2": "Use triton language to copy and quantize data from source tensors to destination tensors with specified block sizes and head numbers.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\nif triton.__version__ >= \"2.1.0\":\n @triton.jit\n def _fwd_kernel(\n Q, K, V, sm_scale, Alibi, B_Start_Loc, B_Seqlen,\n Out,\n stride_qbs, stride_qh, stride_qd,\n stride_kbs, stride_kh, stride_kd,\n stride_vbs, stride_vh, stride_vd,\n stride_obs, stride_oh, stride_od,\n BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,\n BLOCK_N: tl.constexpr,\n ):\n cur_batch = tl.program_id(0)\n cur_head = tl.program_id(1)\n start_m = tl.program_id(2)\n\n cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)\n cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)\n\n block_start_loc = BLOCK_M * start_m\n\n # initialize offsets\n offs_n = tl.arange(0, BLOCK_N)\n offs_d = tl.arange(0, BLOCK_DMODEL)\n offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)\n off_q = (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs + cur_head * stride_qh + offs_d[None, :] * stride_qd\n off_k = offs_n[None, :] * stride_kbs + cur_head * stride_kh + offs_d[:, None] * stride_kd\n off_v = offs_n[:, None] * stride_vbs + cur_head * stride_vh + offs_d[None, :] * stride_vd\n\n q = tl.load(Q + off_q, mask=offs_m[:, None] < cur_batch_seq_len, other=0.0)\n\n k_ptrs = K + off_k\n v_ptrs = V + off_v\n\n m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float(\"inf\")\n l_i = tl.zeros([BLOCK_M], dtype=tl.float32)\n acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)\n\n alibi_m = tl.load(Alibi + cur_head)\n\n block_mask = tl.where(block_start_loc < cur_batch_seq_len, 1, 0)\n\n for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N):\n start_n = tl.multiple_of(start_n, BLOCK_N)\n # -- compute qk ----\n k = tl.load(k_ptrs + (cur_batch_in_all_start_index + start_n) * stride_kbs,\n mask=(start_n + offs_n[None, :]) < cur_batch_seq_len, other=0.0)\n\n qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)\n qk += tl.dot(q, k)\n qk *= sm_scale\n\n alibi_loc = offs_m[:, None] - (start_n + offs_n[None, :])\n qk -= alibi_loc * alibi_m\n\n qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float(\"-inf\"))\n\n m_ij = tl.max(qk, 1)\n p = tl.exp(qk - m_ij[:, None])\n l_ij = tl.sum(p, 1)\n # -- update m_i and l_i\n m_i_new = tl.maximum(m_i, m_ij)\n alpha = tl.exp(m_i - m_i_new)\n beta = tl.exp(m_ij - m_i_new)\n l_i_new = alpha * l_i + beta * l_ij\n # -- update output accumulator --\n # scale p\n p_scale = beta / l_i_new\n p = p * p_scale[:, None]\n # scale acc\n acc_scale = l_i / l_i_new * alpha\n acc = acc * acc_scale[:, None]\n # update acc\n v = tl.load(v_ptrs + (cur_batch_in_all_start_index + start_n) * stride_vbs,\n mask=(start_n + offs_n[:, None]) < cur_batch_seq_len, other=0.0)\n\n p = p.to(v.dtype)\n acc += tl.dot(p, v)\n # update m_i and l_i\n l_i = l_i_new\n m_i = m_i_new\n # initialize pointers to output\n off_o = (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs + cur_head * stride_oh + offs_d[None, :] * stride_od\n out_ptrs = Out + off_o\n tl.store(out_ptrs, acc, mask=offs_m[:, None] < cur_batch_seq_len)\n return\n\n @torch.no_grad()\n def context_attention_fwd(q, k, v, o, alibi, b_start_loc, b_seq_len, max_input_len):\n BLOCK = 128\n # shape constraints\n Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]\n assert Lq == Lk and Lk == Lv\n assert Lk in {16, 32, 64, 128}\n\n sm_scale = 1.0 / (Lq**0.5)\n batch, head = b_seq_len.shape[0], q.shape[1]\n\n grid = (batch, head, triton.cdiv(max_input_len, BLOCK))\n\n num_warps = 4 if Lk <= 64 else 8\n\n _fwd_kernel[grid](\n q, k, v, sm_scale, alibi, b_start_loc, b_seq_len,\n o,\n q.stride(0), q.stride(1), q.stride(2),\n k.stride(0), k.stride(1), k.stride(2),\n v.stride(0), v.stride(1), v.stride(2),\n o.stride(0), o.stride(1), o.stride(2),\n BLOCK_M=BLOCK,\n BLOCK_DMODEL=Lk,\n BLOCK_N=BLOCK,\n num_warps=num_warps,\n num_stages=1,\n )\n return\nelif triton.__version__ == \"2.0.0\":\n @triton.jit\n def _fwd_kernel(\n Q, K, V, sm_scale, Alibi, B_Start_Loc, B_Seqlen,\n TMP, # NOTE: TMP is a scratchpad buffer to workaround a compiler bug\n Out,\n stride_qbs, stride_qh, stride_qd,\n stride_kbs, stride_kh, stride_kd,\n stride_vbs, stride_vh, stride_vd,\n stride_obs, stride_oh, stride_od,\n stride_tmp_b, stride_tmp_h, stride_tmp_s,\n BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,\n BLOCK_N: tl.constexpr,\n ):\n cur_batch = tl.program_id(0)\n cur_head = tl.program_id(1)\n start_m = tl.program_id(2)\n\n cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)\n cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)\n\n block_start_loc = BLOCK_M * start_m\n\n # initialize offsets\n offs_n = tl.arange(0, BLOCK_N)\n offs_d = tl.arange(0, BLOCK_DMODEL)\n offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)\n off_q = (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs + cur_head * stride_qh + offs_d[None, :] * stride_qd\n off_k = offs_n[None, :] * stride_kbs + cur_head * stride_kh + offs_d[:, None] * stride_kd\n off_v = offs_n[:, None] * stride_vbs + cur_head * stride_vh + offs_d[None, :] * stride_vd\n\n q = tl.load(Q + off_q, mask=offs_m[:, None] < cur_batch_seq_len, other=0.0)\n\n k_ptrs = K + off_k\n v_ptrs = V + off_v\n t_ptrs = TMP + cur_batch * stride_tmp_b + cur_head * stride_tmp_h + offs_m * stride_tmp_s\n\n m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float(\"inf\")\n l_i = tl.zeros([BLOCK_M], dtype=tl.float32)\n acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)\n\n alibi_m = tl.load(Alibi + cur_head)\n\n block_mask = tl.where(block_start_loc < cur_batch_seq_len, 1, 0)\n\n for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N):\n start_n = tl.multiple_of(start_n, BLOCK_N)\n # -- compute qk ----\n k = tl.load(k_ptrs + (cur_batch_in_all_start_index + start_n) * stride_kbs,\n mask=(start_n + offs_n[None, :]) < cur_batch_seq_len, other=0.0)\n\n qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)\n qk += tl.dot(q, k)\n qk *= sm_scale\n\n alibi_loc = offs_m[:, None] - (start_n + offs_n[None, :])\n qk -= alibi_loc * alibi_m\n\n qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float(\"-inf\"))\n\n m_ij = tl.max(qk, 1)\n p = tl.exp(qk - m_ij[:, None])\n l_ij = tl.sum(p, 1)\n # -- update m_i and l_i\n m_i_new = tl.maximum(m_i, m_ij)\n alpha = tl.exp(m_i - m_i_new)\n beta = tl.exp(m_ij - m_i_new)\n l_i_new = alpha * l_i + beta * l_ij\n # -- update output accumulator --\n # scale p\n p_scale = beta / l_i_new\n p = p * p_scale[:, None]\n # scale acc\n acc_scale = l_i / l_i_new * alpha\n tl.store(t_ptrs, acc_scale)\n acc_scale = tl.load(t_ptrs)\n acc = acc * acc_scale[:, None]\n # update acc\n v = tl.load(v_ptrs + (cur_batch_in_all_start_index + start_n) * stride_vbs,\n mask=(start_n + offs_n[:, None]) < cur_batch_seq_len, other=0.0)\n\n p = p.to(v.dtype)\n acc += tl.dot(p, v)\n # update m_i and l_i\n l_i = l_i_new\n m_i = m_i_new\n # initialize pointers to output\n off_o = (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs + cur_head * stride_oh + offs_d[None, :] * stride_od\n out_ptrs = Out + off_o\n tl.store(out_ptrs, acc, mask=offs_m[:, None] < cur_batch_seq_len)\n return\n\n @torch.no_grad()\n def context_attention_fwd(q, k, v, o, alibi, b_start_loc, b_seq_len, max_input_len):\n BLOCK = 128\n # shape constraints\n Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]\n assert Lq == Lk and Lk == Lv\n assert Lk in {16, 32, 64, 128}\n\n sm_scale = 1.0 / (Lq**0.5)\n batch, head = b_seq_len.shape[0], q.shape[1]\n\n grid = (batch, head, triton.cdiv(max_input_len, BLOCK))\n\n num_warps = 4 if Lk <= 64 else 8\n\n tmp = torch.empty((batch, head, max_input_len + 256), device=q.device, dtype=torch.float32)\n\n _fwd_kernel[grid](\n q, k, v, sm_scale, alibi, b_start_loc, b_seq_len,\n tmp,\n o,\n q.stride(0), q.stride(1), q.stride(2),\n k.stride(0), k.stride(1), k.stride(2),\n v.stride(0), v.stride(1), v.stride(2),\n o.stride(0), o.stride(1), o.stride(2),\n tmp.stride(0), tmp.stride(1), tmp.stride(2),\n BLOCK_M=BLOCK,\n BLOCK_DMODEL=Lk,\n BLOCK_N=BLOCK,\n num_warps=num_warps,\n num_stages=1,\n )\n return\nelse:\n raise Exception(\"error triton version!\")\n", - "description_1": "Use triton language to implement a forward kernel for context attention. The kernel takes 22 parameters: Q, K, V, sm_scale, Alibi, B_Start_Loc, B_Seqlen, Out, stride_qbs, stride_qh, stride_qd, stride_kbs, stride_kh, stride_kd, stride_vbs, stride_vh, stride_vd, stride_obs, stride_oh, stride_od, BLOCK_M, BLOCK_DMODEL, BLOCK_N. It computes the attention scores and updates the output accumulator using a loop over the sequence length. The context_attention_fwd function calls this kernel with 8 parameters: q, k, v, o, alibi, b_start_loc, b_seq_len, max_input_len, and sets up the grid and block dimensions for the kernel execution.", - "description_2": "Use triton language to implement a forward kernel for context attention with 23 parameters, including a temporary buffer TMP for version 2.0.0. The kernel computes attention scores and updates the output accumulator. The context_attention_fwd function calls this kernel with 8 parameters and sets up the grid and block dimensions for execution.", - "difficulty": 3 - }, - { - "code": "import triton\nimport triton.language as tl\n\n@triton.jit\ndef _layer_norm_fwd_fused(\n X, # pointer to the input\n Y, # pointer to the output\n W, # pointer to the weights\n B, # pointer to the biases\n stride, # how much to increase the pointer when moving by 1 row\n N, # number of columns in X\n eps, # epsilon to avoid division by zero\n BLOCK_SIZE: tl.constexpr,\n):\n # Map the program id to the row of X and Y it should compute.\n row = tl.program_id(0)\n Y += row * stride\n X += row * stride\n # Compute mean\n mean = 0\n _mean = tl.zeros([BLOCK_SIZE], dtype=tl.float32)\n for off in range(0, N, BLOCK_SIZE):\n cols = off + tl.arange(0, BLOCK_SIZE)\n a = tl.load(X + cols, mask=cols < N, other=0.).to(tl.float32)\n _mean += a\n mean = tl.sum(_mean, axis=0) / N\n # Compute variance\n _var = tl.zeros([BLOCK_SIZE], dtype=tl.float32)\n for off in range(0, N, BLOCK_SIZE):\n cols = off + tl.arange(0, BLOCK_SIZE)\n x = tl.load(X + cols, mask=cols < N, other=0.).to(tl.float32)\n x = tl.where(cols < N, x - mean, 0.)\n _var += x * x\n var = tl.sum(_var, axis=0) / N\n rstd = 1 / tl.sqrt(var + eps)\n # Normalize and apply linear transformation\n for off in range(0, N, BLOCK_SIZE):\n cols = off + tl.arange(0, BLOCK_SIZE)\n mask = cols < N\n w = tl.load(W + cols, mask=mask).to(tl.float32)\n b = tl.load(B + cols, mask=mask).to(tl.float32)\n x = tl.load(X + cols, mask=mask, other=0.).to(tl.float32)\n x_hat = (x - mean) * rstd\n y = x_hat * w + b\n # Write output\n tl.store(Y + cols, y.to(tl.float16), mask=mask)\n\ndef layernorm_forward(x, weight, bias, eps):\n # allocate output\n y = torch.empty_like(x)\n # reshape input data into 2D tensor\n x_arg = x.view(-1, x.shape[-1])\n M, N = x_arg.shape\n # Less than 64KB per feature: enqueue fused kernel\n MAX_FUSED_SIZE = 65536 // x.element_size()\n BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))\n if N > BLOCK_SIZE:\n raise RuntimeError(\"This layer norm doesn't support feature dim >= 64KB.\")\n # heuristics for number of warps\n num_warps = min(max(BLOCK_SIZE // 256, 1), 8)\n # enqueue kernel\n _layer_norm_fwd_fused[(M,)](x_arg, y, weight, bias,\n x_arg.stride(0), N, eps,\n BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps)\n return y\n", - "description_1": "Use triton language to create a fused layer normalization kernel. The kernel '_layer_norm_fwd_fused' takes 8 arguments: 1) X, a pointer to the input data; 2) Y, a pointer to the output data; 3) W, a pointer to the weights; 4) B, a pointer to the biases; 5) stride, an integer indicating the row stride in memory; 6) N, the number of columns in X; 7) eps, a small epsilon value to avoid division by zero; 8) BLOCK_SIZE, a compile-time constant indicating the block size for operations. The kernel computes the mean and variance of X, applies normalization, multiplies by weights, adds biases, and writes the result to Y. The 'layernorm_forward' function calls this kernel, preparing and validating input dimensions and setting the number of warps based on BLOCK_SIZE.", - "description_2": "Use triton language to implement and execute a layer normalization operation, normalizing input using compute mean and variance, and applying weights and biases.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef _fwd_kernel_token_att1(\n Q, K, sm_scale, Alibi, B_Loc, B_Start_Loc, B_Seqlen, max_input_len,\n Att_Out,\n stride_b_loc_b, stride_b_loc_s,\n stride_qbs, stride_qh, stride_qd,\n stride_kbs, stride_kh, stride_kd,\n att_stride_h, att_stride_bs,\n BLOCK_DMODEL: tl.constexpr,\n BLOCK_N: tl.constexpr\n):\n cur_batch = tl.program_id(0)\n cur_head = tl.program_id(1)\n start_n = tl.program_id(2)\n\n offs_d = tl.arange(0, BLOCK_DMODEL)\n cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)\n cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)\n\n cur_batch_start_index = max_input_len - cur_batch_seq_len\n cur_batch_end_index = max_input_len\n\n off_q = cur_batch * stride_qbs + cur_head * stride_qh + offs_d * stride_qd\n\n offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N)\n\n block_stard_index = start_n * BLOCK_N\n block_mask = tl.where(block_stard_index < cur_batch_seq_len, 1, 0)\n\n for start_mark in range(0, block_mask, 1):\n alibi_m = tl.load(Alibi + cur_head)\n q = tl.load(Q + off_q + start_mark)\n offs_n_new = cur_batch_start_index + offs_n\n k_loc = tl.load(B_Loc + stride_b_loc_b * cur_batch + stride_b_loc_s * offs_n_new, mask=offs_n_new < cur_batch_end_index, other=0)\n off_k = k_loc[:, None] * stride_kbs + cur_head * stride_kh + offs_d[None, :] * stride_kd\n k = tl.load(K + off_k, mask=offs_n_new[:, None] < cur_batch_end_index, other=0.0)\n att_value = tl.sum(q[None, :] * k, 1)\n att_value *= sm_scale\n att_value -= alibi_m * (cur_batch_seq_len - 1 - offs_n)\n off_o = cur_head * att_stride_h + (cur_batch_in_all_start_index + offs_n) * att_stride_bs\n tl.store(Att_Out + off_o, att_value, mask=offs_n_new < cur_batch_end_index)\n return\n\n@torch.no_grad()\ndef token_att_fwd(q, k, att_out, alibi, B_Loc, B_Start_Loc, B_Seqlen, max_input_len):\n BLOCK = 32\n Lq, Lk = q.shape[-1], k.shape[-1]\n assert Lq == Lk\n assert Lk in {16, 32, 64, 128}\n sm_scale = 1.0 / (Lk ** 0.5)\n\n batch, head_num = B_Loc.shape[0], q.shape[1]\n\n grid = (batch, head_num, triton.cdiv(max_input_len, BLOCK))\n\n num_warps = 4 if Lk <= 64 else 8\n num_warps = 2\n\n _fwd_kernel_token_att1[grid](\n q, k, sm_scale, alibi, B_Loc, B_Start_Loc, B_Seqlen, max_input_len,\n att_out,\n B_Loc.stride(0), B_Loc.stride(1),\n q.stride(0), q.stride(1), q.stride(2),\n k.stride(0), k.stride(1), k.stride(2),\n att_out.stride(0), att_out.stride(1),\n BLOCK_DMODEL=Lk,\n BLOCK_N=BLOCK,\n num_warps=num_warps,\n num_stages=1,\n )\n return\n", - "description_1": "Use triton language to implement a forward kernel for token attention. The kernel function '_fwd_kernel_token_att1' takes 18 parameters: Q, K, sm_scale, Alibi, B_Loc, B_Start_Loc, B_Seqlen, max_input_len, Att_Out, stride_b_loc_b, stride_b_loc_s, stride_qbs, stride_qh, stride_qd, stride_kbs, stride_kh, stride_kd, att_stride_h, att_stride_bs, and two constexpr parameters BLOCK_DMODEL and BLOCK_N. It computes the attention values using the provided query and key tensors, scaling factor, and alibi, and stores the result in Att_Out. The function 'token_att_fwd' is a wrapper that sets up the grid and block dimensions and calls the kernel with the appropriate parameters.", - "description_2": "Use triton language to create a token attention forward kernel that computes attention scores using query and key tensors, scaling factor, and alibi, and stores the results in an output tensor. The kernel is executed with a grid configuration based on batch size, number of heads, and maximum input length.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef _fwd_kernel_token_att2(\n Prob, V, Out, B_Loc, B_Start_Loc, B_Seqlen, max_input_len,\n stride_b_loc_b, stride_b_loc_s,\n stride_ph, stride_pbs,\n stride_vbs, stride_vh, stride_vd,\n stride_obs, stride_oh, stride_od,\n BLOCK_DMODEL: tl.constexpr,\n BLOCK_N: tl.constexpr\n):\n cur_batch = tl.program_id(0)\n cur_head = tl.program_id(1)\n\n offs_n = tl.arange(0, BLOCK_N)\n offs_d = tl.arange(0, BLOCK_DMODEL)\n cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)\n cur_batch_start_index = max_input_len - cur_batch_seq_len\n cur_batch_end_index = cur_batch_seq_len\n cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)\n\n v_loc_off = cur_batch * stride_b_loc_b + (cur_batch_start_index + offs_n) * stride_b_loc_s\n p_offs = cur_head * stride_ph + (cur_batch_in_all_start_index + offs_n) * stride_pbs\n v_offs = cur_head * stride_vh + offs_d[None, :] * stride_vd\n\n acc = tl.zeros([BLOCK_DMODEL], dtype=tl.float32)\n for start_n in range(0, cur_batch_seq_len, BLOCK_N):\n start_n = tl.multiple_of(start_n, BLOCK_N)\n p_value = tl.load(Prob + p_offs + start_n * stride_b_loc_s, mask=(start_n + offs_n) < cur_batch_seq_len, other=0.0)\n v_loc = tl.load(B_Loc + v_loc_off + start_n * stride_b_loc_s, mask=(start_n + offs_n) < cur_batch_seq_len, other=0.0)\n v_value = tl.load(V + v_offs + v_loc[:, None] * stride_vbs, mask=(start_n + offs_n[:, None]) < cur_batch_seq_len, other=0.0)\n acc += tl.sum(p_value[:, None] * v_value, 0)\n\n acc = acc.to(tl.float16)\n off_o = cur_batch * stride_obs + cur_head * stride_oh + offs_d * stride_od\n out_ptrs = Out + off_o\n tl.store(out_ptrs, acc)\n return\n\n@torch.no_grad()\ndef token_att_fwd2(prob, v, out, B_Loc, B_Start_Loc, B_Seqlen, max_input_len):\n if triton.__version__ >= \"2.1.0\":\n BLOCK = 128\n else:\n BLOCK = 64\n batch, head = B_Loc.shape[0], v.shape[1]\n grid = (batch, head)\n num_warps = 4\n dim = v.shape[-1]\n\n _fwd_kernel_token_att2[grid](\n prob, v, out, B_Loc, B_Start_Loc, B_Seqlen, max_input_len,\n B_Loc.stride(0), B_Loc.stride(1),\n prob.stride(0), prob.stride(1),\n v.stride(0), v.stride(1), v.stride(2),\n out.stride(0), out.stride(1), out.stride(2),\n BLOCK_DMODEL=dim,\n BLOCK_N=BLOCK,\n num_warps=num_warps,\n num_stages=1,\n )\n return\n", - "description_1": "Use triton language to implement a forward kernel for token attention. The kernel '_fwd_kernel_token_att2' takes 18 parameters: Prob, V, Out, B_Loc, B_Start_Loc, B_Seqlen, max_input_len, stride_b_loc_b, stride_b_loc_s, stride_ph, stride_pbs, stride_vbs, stride_vh, stride_vd, stride_obs, stride_oh, stride_od, and two constexpr parameters BLOCK_DMODEL and BLOCK_N. It computes the attention output by iterating over the sequence length in blocks and accumulating the weighted sum of values. The function 'token_att_fwd2' is a wrapper that sets up the grid and block dimensions and calls the kernel with the appropriate parameters.", - "description_2": "Use triton language to create a token attention forward kernel that processes input tensors in blocks, computes weighted sums, and stores the results. The kernel is invoked with a wrapper function that configures execution parameters.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef _fwd_kernel(\n Q, K, V, sm_scale, Alibi, B_Loc, B_Seqlen, max_input_len,\n Out,\n stride_qbs, stride_qh, stride_qd,\n stride_kbs, stride_kh, stride_kd,\n stride_vbs, stride_vh, stride_vd,\n stride_obs, stride_oh, stride_od,\n stride_b_loc_b, stride_b_loc_s,\n BLOCK_DMODEL: tl.constexpr,\n BLOCK_N: tl.constexpr,\n):\n cur_batch = tl.program_id(0)\n cur_head = tl.program_id(1)\n\n cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)\n\n # initialize offsets\n offs_n = tl.arange(0, BLOCK_N)\n offs_d = tl.arange(0, BLOCK_DMODEL)\n\n off_q = cur_batch * stride_qbs + cur_head * stride_qh + offs_d * stride_qd\n off_k = cur_head * stride_kh + offs_d[None, :] * stride_kd\n off_v = cur_head * stride_vh + offs_d[None, :] * stride_vd\n off_b_loc = cur_batch * stride_b_loc_b + (max_input_len - cur_batch_seq_len) * stride_b_loc_s\n\n q = tl.load(Q + off_q)\n\n k_ptrs = K + off_k\n v_ptrs = V + off_v\n\n m_i = -float(\"inf\")\n l_i = 0.0\n acc = tl.zeros([BLOCK_DMODEL], dtype=tl.float32)\n\n alibi_m = tl.load(Alibi + cur_head)\n\n for start_n in range(0, cur_batch_seq_len, BLOCK_N):\n start_n = tl.multiple_of(start_n, BLOCK_N)\n # -- compute qk ----\n k_index = tl.load(B_Loc + off_b_loc + (start_n + offs_n) * stride_b_loc_s, mask=(start_n + offs_n) < cur_batch_seq_len, other=0)\n k = tl.load(k_ptrs + k_index[:, None] * stride_kbs, mask=(start_n + offs_n[:, None]) < cur_batch_seq_len, other=0.0)\n\n qk = tl.zeros([BLOCK_N,], dtype=tl.float32)\n qk += tl.sum(q[None, :] * k, 1)\n qk *= sm_scale\n\n alibi_loc = cur_batch_seq_len - 1 - (start_n + offs_n)\n qk -= alibi_loc * alibi_m\n\n qk = tl.where(cur_batch_seq_len > (start_n + offs_n), qk, float(\"-inf\"))\n\n m_ij = tl.max(qk, 0)\n p = tl.exp(qk - m_ij)\n l_ij = tl.sum(p, 0)\n # -- update m_i and l_i\n m_i_new = tl.maximum(m_i, m_ij)\n alpha = tl.exp(m_i - m_i_new)\n beta = tl.exp(m_ij - m_i_new)\n l_i_new = alpha * l_i + beta * l_ij\n # -- update output accumulator --\n # scale p\n p_scale = beta / l_i_new\n p = p * p_scale\n # scale acc\n acc_scale = l_i / l_i_new * alpha\n acc = acc * acc_scale\n # update acc\n v_index = k_index\n v = tl.load(v_ptrs + v_index[:, None] * stride_vbs, mask=(start_n + offs_n[:, None]) < cur_batch_seq_len, other=0.0)\n # print(p)\n acc += tl.sum(p[:, None] * v, 0)\n\n # update m_i and l_i\n l_i = l_i_new\n m_i = m_i_new\n # initialize pointers to output\n\n off_o = cur_batch * stride_obs + cur_head * stride_oh + offs_d * stride_od\n out_ptrs = Out + off_o\n tl.store(out_ptrs, acc)\n return\n\n@torch.no_grad()\ndef token_attention_fwd(q, k, v, o, alibi, b_loc, b_start_loc, b_seq_len, max_input_len):\n BLOCK = 128\n Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]\n assert Lq == Lk and Lk == Lv\n assert Lk in {16, 32, 64, 128}\n\n sm_scale = 1.0 / (Lq**0.5) # 计算scale系数\n batch, head = b_seq_len.shape[0], q.shape[1]\n\n grid = (batch, head) # batch, head,\n\n num_warps = 4 if Lk <= 64 else 8\n num_warps = 4\n\n _fwd_kernel[grid](\n q, k, v, sm_scale, alibi, b_loc, b_seq_len, max_input_len,\n o,\n q.stride(0), q.stride(1), q.stride(2),\n k.stride(0), k.stride(1), k.stride(2),\n v.stride(0), v.stride(1), v.stride(2),\n o.stride(0), o.stride(1), o.stride(2),\n b_loc.stride(0), b_loc.stride(1),\n BLOCK_DMODEL=Lk,\n BLOCK_N=BLOCK,\n num_warps=num_warps,\n num_stages=1,\n )\n return\n", - "description_1": "Use triton language to implement a forward kernel for token attention. The kernel function '_fwd_kernel' takes 22 parameters: Q, K, V, sm_scale, Alibi, B_Loc, B_Seqlen, max_input_len, Out, and 12 stride parameters, along with two block size constants. It computes the attention scores and updates the output accumulator. The function 'token_attention_fwd' is a wrapper that prepares the input data and launches the '_fwd_kernel' with the appropriate grid and block configurations.", - "description_2": "Use triton language to implement a token attention forward kernel with 22 parameters, including input tensors, scaling factors, and block sizes. The kernel computes attention scores and updates outputs, with a wrapper function to handle input preparation and kernel launch.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef _rotary_kernel(\n Q, Cos, Sin,\n stride_qbs, stride_qh, stride_qd,\n stride_cosbs, stride_cosd,\n stride_sinbs, stride_sind,\n max_total_len,\n H, # N_CTX represents the context length to compute\n BLOCK_HEAD: tl.constexpr,\n BLOCK_SEQ: tl.constexpr,\n BLOCK_DMODEL: tl.constexpr,\n):\n cur_head_index = tl.program_id(0)\n cur_seq_index = tl.program_id(1)\n\n cur_head_range = cur_head_index * BLOCK_HEAD + tl.arange(0, BLOCK_HEAD)\n cur_seq_range = cur_seq_index * BLOCK_SEQ + tl.arange(0, BLOCK_SEQ)\n\n dim_range0 = tl.arange(0, BLOCK_DMODEL // 2) * 2\n dim_range1 = dim_range0 + 1\n off_q0 = cur_seq_range[:, None, None] * stride_qbs + cur_head_range[None, :, None] * stride_qh + dim_range0[None, None, :] * stride_qd\n off_q1 = cur_seq_range[:, None, None] * stride_qbs + cur_head_range[None, :, None] * stride_qh + dim_range1[None, None, :] * stride_qd\n\n cos_range = tl.arange(0, BLOCK_DMODEL // 2)\n off_dimcos_sin = cur_seq_range[:, None, None] * stride_cosbs + cos_range[None, None, :] * stride_cosd\n\n q0 = tl.load(Q + off_q0, mask=(cur_seq_range[:, None, None] < max_total_len) & (cur_head_range[None, :, None] < H), other=0.0)\n q1 = tl.load(Q + off_q1, mask=(cur_seq_range[:, None, None] < max_total_len) & (cur_head_range[None, :, None] < H), other=0.0)\n\n cos = tl.load(Cos + off_dimcos_sin, mask=cur_seq_range[:, None, None] < max_total_len, other=0.0)\n sin = tl.load(Sin + off_dimcos_sin, mask=cur_seq_range[:, None, None] < max_total_len, other=0.0)\n\n out0 = q0 * cos - q1 * sin\n out1 = q0 * sin + q1 * cos\n\n tl.store(Q + off_q0, out0, mask=(cur_seq_range[:, None, None] < max_total_len) & (cur_head_range[None, :, None] < H))\n tl.store(Q + off_q1, out1, mask=(cur_seq_range[:, None, None] < max_total_len) & (cur_head_range[None, :, None] < H))\n\n return\n\n@torch.no_grad()\ndef rotary_emb_fwd(q, cos, sin):\n total_len = q.shape[0]\n head_num = q.shape[1]\n head_dim = q.shape[2] // 2\n assert q.shape[0] == cos.shape[0] and q.shape[0] == sin.shape[0], f\"q shape {q.shape} cos shape {cos.shape}\"\n BLOCK_HEAD = 4\n BLOCK_SEQ = 32\n grid = (triton.cdiv(head_num, BLOCK_HEAD), triton.cdiv(total_len, BLOCK_SEQ))\n if head_dim >= 128:\n num_warps = 8\n else:\n num_warps = 4\n _rotary_kernel[grid](\n q, cos, sin,\n q.stride(0), q.stride(1), q.stride(2),\n cos.stride(0), cos.stride(1),\n sin.stride(0), sin.stride(1),\n total_len, head_num,\n BLOCK_HEAD=BLOCK_HEAD,\n BLOCK_SEQ=BLOCK_SEQ,\n BLOCK_DMODEL=head_dim,\n num_warps=num_warps,\n num_stages=1,\n )\n return\n", - "description_1": "Use triton language to implement a rotary embedding kernel function (_rotary_kernel) for tensor Q, cos, and sin with grid-strided loop. It requires 11 parameters: Q (query tensor), Cos (cosine values), Sin (sine values), stride_qbs, stride_qh, stride_qd (strides for query tensor), stride_cosbs, stride_cosd (strides for cosine tensor), stride_sinbs, stride_sind (strides for sine tensor), max_total_len, and H (head count). The kernel computes the rotary embedding by applying the cosine and sine transformations on sub-parts of Q, specifically on blocks specified by BLOCK_HEAD, BLOCK_SEQ, and BLOCK_DMODEL. The output is stored back into Q.", - "description_2": "Use triton language to compute rotary embeddings for a query tensor using cosine and sine matrices. The function rotary_emb_fwd launches this kernel using grid-strided logic where grid is defined by head number and sequence length, and sets the number of warps based on head dimension, then invokes the triton kernel with parameters including tensor strides and grid dimensions.", - "difficulty": 2 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\nif triton.__version__ >= \"2.1.0\":\n @triton.jit\n def _fwd_kernel(\n Q, K, V, sm_scale, B_Start_Loc, B_Seqlen,\n Out,\n stride_qbs, stride_qh, stride_qd,\n stride_kbs, stride_kh, stride_kd,\n stride_vbs, stride_vh, stride_vd,\n stride_obs, stride_oh, stride_od,\n kv_group_num,\n BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,\n BLOCK_N: tl.constexpr,\n ):\n cur_batch = tl.program_id(0)\n cur_head = tl.program_id(1)\n start_m = tl.program_id(2)\n\n cur_kv_head = cur_head // kv_group_num\n\n cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)\n cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)\n\n block_start_loc = BLOCK_M * start_m\n\n # initialize offsets\n offs_n = tl.arange(0, BLOCK_N)\n offs_d = tl.arange(0, BLOCK_DMODEL)\n offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)\n off_q = (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs + cur_head * stride_qh + offs_d[None, :] * stride_qd\n off_k = offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh + offs_d[:, None] * stride_kd\n off_v = offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh + offs_d[None, :] * stride_vd\n\n q = tl.load(Q + off_q, mask=offs_m[:, None] < cur_batch_seq_len, other=0.0)\n\n k_ptrs = K + off_k\n v_ptrs = V + off_v\n\n m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float(\"inf\")\n l_i = tl.zeros([BLOCK_M], dtype=tl.float32)\n acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)\n\n block_mask = tl.where(block_start_loc < cur_batch_seq_len, 1, 0)\n\n for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N):\n start_n = tl.multiple_of(start_n, BLOCK_N)\n k = tl.load(k_ptrs + (cur_batch_in_all_start_index + start_n) * stride_kbs,\n mask=(start_n + offs_n[None, :]) < cur_batch_seq_len, other=0.0)\n\n qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)\n qk += tl.dot(q, k)\n qk *= sm_scale\n qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float(\"-inf\"))\n\n m_ij = tl.max(qk, 1)\n p = tl.exp(qk - m_ij[:, None])\n l_ij = tl.sum(p, 1)\n m_i_new = tl.maximum(m_i, m_ij)\n alpha = tl.exp(m_i - m_i_new)\n beta = tl.exp(m_ij - m_i_new)\n l_i_new = alpha * l_i + beta * l_ij\n\n p_scale = beta / l_i_new\n p = p * p_scale[:, None]\n acc_scale = l_i / l_i_new * alpha\n acc = acc * acc_scale[:, None]\n\n v = tl.load(v_ptrs + (cur_batch_in_all_start_index + start_n) * stride_vbs,\n mask=(start_n + offs_n[:, None]) < cur_batch_seq_len, other=0.0)\n\n p = p.to(v.dtype)\n acc += tl.dot(p, v)\n l_i = l_i_new\n m_i = m_i_new\n\n off_o = (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs + cur_head * stride_oh + offs_d[None, :] * stride_od\n out_ptrs = Out + off_o\n tl.store(out_ptrs, acc, mask=offs_m[:, None] < cur_batch_seq_len)\n return\n\n @torch.no_grad()\n def context_attention_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len):\n BLOCK = 128\n Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]\n assert Lq == Lk and Lk == Lv\n assert Lk in {16, 32, 64, 128}\n\n sm_scale = 1.0 / (Lq**0.5)\n batch, head = b_seq_len.shape[0], q.shape[1]\n kv_group_num = q.shape[1] // k.shape[1]\n \n grid = (batch, head, triton.cdiv(max_input_len, BLOCK))\n\n num_warps = 4 if Lk <= 64 else 8\n _fwd_kernel[grid](\n q, k, v, sm_scale, b_start_loc, b_seq_len,\n o,\n q.stride(0), q.stride(1), q.stride(2),\n k.stride(0), k.stride(1), k.stride(2),\n v.stride(0), v.stride(1), v.stride(2),\n o.stride(0), o.stride(1), o.stride(2),\n kv_group_num=kv_group_num,\n BLOCK_M=BLOCK,\n BLOCK_DMODEL=Lk,\n BLOCK_N=BLOCK,\n num_warps=num_warps,\n num_stages=1,\n )\n return\n\nelif triton.__version__ == \"2.0.0\":\n @triton.jit\n def _fwd_kernel(\n Q, K, V, sm_scale, B_Start_Loc, B_Seqlen,\n TMP,\n Out,\n stride_qbs, stride_qh, stride_qd,\n stride_kbs, stride_kh, stride_kd,\n stride_vbs, stride_vh, stride_vd,\n stride_obs, stride_oh, stride_od,\n stride_tmp_b, stride_tmp_h, stride_tmp_s,\n kv_group_num,\n BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,\n BLOCK_N: tl.constexpr,\n ):\n cur_batch = tl.program_id(0)\n cur_head = tl.program_id(1)\n start_m = tl.program_id(2)\n\n cur_kv_head = cur_head // kv_group_num\n\n cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)\n cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)\n\n block_start_loc = BLOCK_M * start_m\n\n # initialize offsets\n offs_n = tl.arange(0, BLOCK_N)\n offs_d = tl.arange(0, BLOCK_DMODEL)\n offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)\n off_q = (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs + cur_head * stride_qh + offs_d[None, :] * stride_qd\n off_k = offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh + offs_d[:, None] * stride_kd\n off_v = offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh + offs_d[None, :] * stride_vd\n q = tl.load(Q + off_q, mask=offs_m[:, None] < cur_batch_seq_len, other=0.0)\n\n k_ptrs = K + off_k\n v_ptrs = V + off_v\n\n t_ptrs = TMP + cur_batch * stride_tmp_b + cur_head * stride_tmp_h + offs_m * stride_tmp_s\n m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float(\"inf\")\n l_i = tl.zeros([BLOCK_M], dtype=tl.float32)\n acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)\n\n block_mask = tl.where(block_start_loc < cur_batch_seq_len, 1, 0)\n\n for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N):\n start_n = tl.multiple_of(start_n, BLOCK_N)\n k = tl.load(k_ptrs + (cur_batch_in_all_start_index + start_n) * stride_kbs,\n mask=(start_n + offs_n[None, :]) < cur_batch_seq_len, other=0.0)\n\n qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)\n qk += tl.dot(q, k)\n qk *= sm_scale\n qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float(\"-inf\"))\n\n m_ij = tl.max(qk, 1)\n p = tl.exp(qk - m_ij[:, None])\n l_ij = tl.sum(p, 1)\n m_i_new = tl.maximum(m_i, m_ij)\n alpha = tl.exp(m_i - m_i_new)\n beta = tl.exp(m_ij - m_i_new)\n l_i_new = alpha * l_i + beta * l_ij\n\n p_scale = beta / l_i_new\n p = p * p_scale[:, None]\n acc_scale = l_i / l_i_new * alpha\n tl.store(t_ptrs, acc_scale)\n acc_scale = tl.load(t_ptrs)\n acc = acc * acc_scale[:, None]\n\n v = tl.load(v_ptrs + (cur_batch_in_all_start_index + start_n) * stride_vbs,\n mask=(start_n + offs_n[:, None]) < cur_batch_seq_len, other=0.0)\n\n p = p.to(v.dtype)\n acc += tl.dot(p, v)\n l_i = l_i_new\n m_i = m_i_new\n\n off_o = (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs + cur_head * stride_oh + offs_d[None, :] * stride_od\n out_ptrs = Out + off_o\n tl.store(out_ptrs, acc, mask=offs_m[:, None] < cur_batch_seq_len)\n\n return\n\n @torch.no_grad()\n def context_attention_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len):\n BLOCK = 128\n Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]\n assert Lq == Lk and Lk == Lv\n assert Lk in {16, 32, 64, 128}\n\n sm_scale = 1.0 / (Lq**0.5)\n batch, head = b_seq_len.shape[0], q.shape[1]\n kv_group_num = q.shape[1] // k.shape[1]\n \n grid = (batch, head, triton.cdiv(max_input_len, BLOCK))\n\n tmp = torch.empty((batch, head, max_input_len + 256), device=q.device, dtype=torch.float32)\n num_warps = 4 if Lk <= 64 else 8\n _fwd_kernel[grid](\n q, k, v, sm_scale, b_start_loc, b_seq_len,\n tmp,\n o,\n q.stride(0), q.stride(1), q.stride(2),\n k.stride(0), k.stride(1), k.stride(2),\n v.stride(0), v.stride(1), v.stride(2),\n o.stride(0), o.stride(1), o.stride(2),\n tmp.stride(0), tmp.stride(1), tmp.stride(2),\n kv_group_num=kv_group_num,\n BLOCK_M=BLOCK,\n BLOCK_DMODEL=Lk,\n BLOCK_N=BLOCK,\n num_warps=num_warps,\n num_stages=1,\n )\n return\n", - "description_1": "Use triton language to implement a forward kernel for attention mechanism with dynamic batch and sequence length handling. The function '_fwd_kernel' has 21+3 constexpr arguments, where Q, K, V are tensors representing query, key, and value respectively. 'sm_scale' is a scalar for scaling attention scores. 'B_Start_Loc' and 'B_Seqlen' are arrays indicating the start location and sequence length for each batch. 'Out' is the output tensor. 'stride_*' are the strides for accessing elements in Q, K, V, Out. 'kv_group_num' divides heads into groups for shared K, V. The BLOCK_M, BLOCK_DMODEL, BLOCK_N are compile-time constants defining block sizes.", - "description_2": "Use triton language to define and execute an efficient forward attention kernel supporting variable batch and sequence lengths with optional scaling, utilizing multiple warps.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n# Triton kernel for forward token attention\n@triton.jit\ndef _fwd_kernel_token_att2(\n Prob, V, Out, B_Loc, B_Start_Loc, B_Seqlen, max_input_len, # B_Start_Loc stores cumulative input sum if stored continuously\n stride_b_loc_b, stride_b_loc_s,\n stride_ph, stride_pbs,\n stride_vbs, stride_vh, stride_vd,\n stride_obs, stride_oh, stride_od,\n kv_group_num,\n BLOCK_DMODEL: tl.constexpr,\n BLOCK_N: tl.constexpr\n):\n cur_batch = tl.program_id(0)\n cur_head = tl.program_id(1)\n \n cur_kv_head = cur_head // kv_group_num\n\n offs_n = tl.arange(0, BLOCK_N)\n offs_d = tl.arange(0, BLOCK_DMODEL)\n cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)\n cur_batch_start_index = max_input_len - cur_batch_seq_len\n cur_batch_end_index = cur_batch_seq_len\n cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)\n\n v_loc_off = cur_batch * stride_b_loc_b + (cur_batch_start_index + offs_n) * stride_b_loc_s\n p_offs = cur_head * stride_ph + (cur_batch_in_all_start_index + offs_n) * stride_pbs\n v_offs = cur_kv_head * stride_vh + offs_d[None, :] * stride_vd\n\n acc = tl.zeros([BLOCK_DMODEL], dtype=tl.float32)\n for start_n in range(0, cur_batch_seq_len, BLOCK_N):\n start_n = tl.multiple_of(start_n, BLOCK_N)\n p_value = tl.load(Prob + p_offs + start_n * stride_b_loc_s, mask=(start_n + offs_n) < cur_batch_seq_len, other=0.0)\n v_loc = tl.load(B_Loc + v_loc_off + start_n * stride_b_loc_s, mask=(start_n + offs_n) < cur_batch_seq_len, other=0.0)\n v_value = tl.load(V + v_offs + v_loc[:, None] * stride_vbs, mask=(start_n + offs_n[:, None]) < cur_batch_seq_len, other=0.0)\n acc += tl.sum(p_value[:, None] * v_value, 0)\n\n acc = acc.to(tl.float16)\n off_o = cur_batch * stride_obs + cur_head * stride_oh + offs_d * stride_od\n out_ptrs = Out + off_o\n tl.store(out_ptrs, acc)\n return\n\n# Function to call the Triton kernel\n@torch.no_grad()\ndef token_att_fwd2(prob, v, out, B_Loc, B_Start_Loc, B_Seqlen, max_input_len):\n if triton.__version__ >= \"2.1.0\":\n BLOCK = 128\n else:\n BLOCK = 64\n batch, head = B_Loc.shape[0], prob.shape[0]\n grid = (batch, head)\n num_warps = 4\n dim = v.shape[-1]\n \n kv_group_num = prob.shape[0] // v.shape[1]\n\n _fwd_kernel_token_att2[grid](\n prob, v, out, B_Loc, B_Start_Loc, B_Seqlen, max_input_len,\n B_Loc.stride(0), B_Loc.stride(1),\n prob.stride(0), prob.stride(1),\n v.stride(0), v.stride(1), v.stride(2),\n out.stride(0), out.stride(1), out.stride(2),\n kv_group_num=kv_group_num,\n BLOCK_DMODEL=dim,\n BLOCK_N=BLOCK,\n num_warps=num_warps,\n num_stages=1,\n )\n return\n", - "description_1": "Use triton language to implement a forward token attention kernel. The kernel takes 16 parameters: Prob, V, Out, B_Loc, B_Start_Loc, B_Seqlen, max_input_len, stride_b_loc_b, stride_b_loc_s, stride_ph, stride_pbs, stride_vbs, stride_vh, stride_vd, stride_obs, stride_oh, stride_od, kv_group_num, BLOCK_DMODEL, and BLOCK_N. It computes the attention output by iterating over the sequence length in blocks, loading probability and value tensors, and accumulating the results. The function token_att_fwd2 calls this kernel with appropriate grid and block settings based on input tensor dimensions and strides.", - "description_2": "Use triton language to create a kernel for forward token attention, processing input tensors in blocks and accumulating results. Implement a function to configure and launch this kernel with specific grid and block parameters.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n# Triton kernel for forward pass of token softmax.\n@triton.jit\ndef _fwd_kernel_token_softmax(\n Logics, B_Start_Loc, B_Seqlen,\n Prob_Out,\n stride_logic_h, stride_logic_bs,\n stride_prob_h, stride_prob_bs,\n BLOCK_SIZE: tl.constexpr\n):\n # Determine the current batch and head being processed.\n cur_batch = tl.program_id(0)\n cur_head = tl.program_id(1)\n\n # Calculate the column offsets and load batch sequence length and start index.\n col_offsets = tl.arange(0, BLOCK_SIZE)\n cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)\n cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)\n\n # Load logic values for the current head and batch.\n row = tl.load(Logics + cur_head * stride_logic_h + (cur_batch_in_all_start_index + col_offsets) * stride_logic_bs,\n mask=col_offsets < cur_batch_seq_len, other=-float('inf')).to(tl.float32)\n\n # Compute softmax.\n row_minus_max = row - tl.max(row, axis=0)\n numerator = tl.exp(row_minus_max)\n denominator = tl.sum(numerator, axis=0)\n softmax_output = numerator / denominator\n\n # Store the softmax output.\n tl.store(Prob_Out + cur_head * stride_prob_h + (cur_batch_in_all_start_index + col_offsets)\n * stride_prob_bs, softmax_output, mask=col_offsets < cur_batch_seq_len)\n return\n\n\n# Python function to invoke the Triton kernel for token softmax forward pass.\n@torch.no_grad()\ndef token_softmax_fwd(Logics, B_Start_Loc, B_Seqlen, Prob_Out, max_input_len):\n BLOCK_SIZE = triton.next_power_of_2(max_input_len)\n batch, head_num = B_Start_Loc.shape[0], Logics.shape[0]\n\n # Determine number of warps based on block size.\n num_warps = 4\n if BLOCK_SIZE >= 2048:\n num_warps = 8\n if BLOCK_SIZE >= 4096:\n num_warps = 16\n\n # Launch the Triton kernel with calculated configurations.\n _fwd_kernel_token_softmax[(batch, head_num)](\n Logics, B_Start_Loc, B_Seqlen,\n Prob_Out,\n Logics.stride(0), Logics.stride(1),\n Prob_Out.stride(0), Prob_Out.stride(1),\n num_warps=num_warps,\n BLOCK_SIZE=BLOCK_SIZE,\n )\n return\n", - "description_1": "Use triton language to implement a softmax function for batched and multi-headed logic data. The _fwd_kernel_token_softmax function has 9 parameters: Logics, B_Start_Loc, B_Seqlen, Prob_Out, stride_logic_h, stride_logic_bs, stride_prob_h, stride_prob_bs, BLOCK_SIZE. Logics is a matrix containing logic data, B_Start_Loc indicates the start position of each batch, B_Seqlen provides sequence lengths, and Prob_Out is the output buffer for storing softmax probabilities. The stride parameters define memory strides, and BLOCK_SIZE is a constant defining the maximum block size. The function calculates softmax over each logic row in a parallelized manner using Triton's parallel computing capabilities. The token_softmax_fwd function is a wrapper to configure the kernel launch, with 5 parameters: Logics, B_Start_Loc, B_Seqlen, Prob_Out, max_input_len. It determines optimal execution parameters based on input lengths, setting num_warps and BLOCK_SIZE, before invoking the Triton kernel.", - "description_2": "Use triton language to implement and execute a parallel softmax function on batched, multi-headed data using kernel and warp configurations for optimal performance.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n\n@triton.jit\ndef _fwd_kernel_destindex_copy_quantize_kv(\n K, Dest_loc, Out, Out_scale,\n stride_k_bs, stride_k_h, stride_k_g, stride_k_d,\n stride_o_bs, stride_o_h, stride_o_g, stride_o_d,\n stride_os_bs, stride_os_h, stride_os_g,\n group_size,\n BLOCK_GROUP_NUM: tl.constexpr,\n BLOCK_GROUP_DIM: tl.constexpr \n):\n cur_index = tl.program_id(0)\n cur_head = tl.program_id(1)\n \n offs_g = tl.arange(0, BLOCK_GROUP_NUM)\n offs_d = tl.arange(0, BLOCK_GROUP_DIM)\n\n dest_index = tl.load(Dest_loc + cur_index)\n\n src_data = tl.load(K + cur_index * stride_k_bs + cur_head * stride_k_h + offs_g[:, None] * stride_k_g + offs_d[None, :], \n mask=offs_g[:, None] < group_size, other=0.0)\n abs_data = tl.abs(src_data)\n data_scale = (tl.max(abs_data, axis=1) / 127.).to(tl.float16)\n q_src_data = (src_data / data_scale[:, None]).to(tl.int8)\n \n o_ptrs = Out + dest_index * stride_o_bs + cur_head * stride_o_h + offs_g[:, None] * stride_o_g + offs_d[None, :]\n os_ptrs = Out_scale + dest_index * stride_os_bs + cur_head * stride_os_h + offs_g\n tl.store(o_ptrs, q_src_data, mask=offs_g[:, None]= \"2.1.0\":\n @triton.jit\n def _fwd_kernel(\n Q, K, V, sm_scale, B_Start_Loc, B_Seqlen,\n Out,\n stride_qbs, stride_qh, stride_qd,\n stride_kbs, stride_kh, stride_kd,\n stride_vbs, stride_vh, stride_vd,\n stride_obs, stride_oh, stride_od,\n BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,\n BLOCK_N: tl.constexpr,\n ):\n cur_batch = tl.program_id(0)\n cur_head = tl.program_id(1)\n start_m = tl.program_id(2)\n\n cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)\n cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)\n\n block_start_loc = BLOCK_M * start_m\n\n # initialize offsets\n offs_n = tl.arange(0, BLOCK_N)\n offs_d = tl.arange(0, BLOCK_DMODEL)\n offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)\n off_q = (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs + cur_head * stride_qh + offs_d[None, :] * stride_qd\n off_k = offs_n[None, :] * stride_kbs + cur_head * stride_kh + offs_d[:, None] * stride_kd\n off_v = offs_n[:, None] * stride_vbs + cur_head * stride_vh + offs_d[None, :] * stride_vd\n\n q = tl.load(Q + off_q, mask=offs_m[:, None] < cur_batch_seq_len, other=0.0)\n\n k_ptrs = K + off_k\n v_ptrs = V + off_v\n\n # initialize pointer to m and l\n m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float(\"inf\")\n l_i = tl.zeros([BLOCK_M], dtype=tl.float32)\n acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)\n\n block_mask = tl.where(block_start_loc < cur_batch_seq_len, 1, 0)\n\n for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N):\n start_n = tl.multiple_of(start_n, BLOCK_N)\n # -- compute qk ----\n k = tl.load(k_ptrs + (cur_batch_in_all_start_index + start_n) * stride_kbs,\n mask=(start_n + offs_n[None, :]) < cur_batch_seq_len, other=0.0)\n\n qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)\n qk += tl.dot(q, k)\n qk *= sm_scale\n qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float(\"-inf\"))\n\n # -- compute m_ij, p, l_ij\n m_ij = tl.max(qk, 1)\n p = tl.exp(qk - m_ij[:, None])\n l_ij = tl.sum(p, 1)\n # -- update m_i and l_i\n m_i_new = tl.maximum(m_i, m_ij)\n alpha = tl.exp(m_i - m_i_new)\n beta = tl.exp(m_ij - m_i_new)\n l_i_new = alpha * l_i + beta * l_ij\n # -- update output accumulator --\n # scale p\n p_scale = beta / l_i_new\n p = p * p_scale[:, None]\n # scale acc\n acc_scale = l_i / l_i_new * alpha\n acc = acc * acc_scale[:, None]\n # update acc\n v = tl.load(v_ptrs + (cur_batch_in_all_start_index + start_n) * stride_vbs,\n mask=(start_n + offs_n[:, None]) < cur_batch_seq_len, other=0.0)\n\n p = p.to(v.dtype)\n acc += tl.dot(p, v)\n # update m_i and l_i\n l_i = l_i_new\n m_i = m_i_new\n # initialize pointers to output\n off_o = (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs + cur_head * stride_oh + offs_d[None, :] * stride_od\n out_ptrs = Out + off_o\n tl.store(out_ptrs, acc, mask=offs_m[:, None] < cur_batch_seq_len)\n return\n\n @torch.no_grad()\n def context_attention_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len):\n BLOCK = 128\n # shape constraints\n Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]\n assert Lq == Lk and Lk == Lv\n assert Lk in {16, 32, 64, 128}\n\n sm_scale = 1.0 / (Lq**0.5)\n batch, head = b_seq_len.shape[0], q.shape[1]\n\n grid = (batch, head, triton.cdiv(max_input_len, BLOCK))\n\n num_warps = 4 if Lk <= 64 else 8\n _fwd_kernel[grid](\n q, k, v, sm_scale, b_start_loc, b_seq_len,\n o,\n q.stride(0), q.stride(1), q.stride(2),\n k.stride(0), k.stride(1), k.stride(2),\n v.stride(0), v.stride(1), v.stride(2),\n o.stride(0), o.stride(1), o.stride(2),\n BLOCK_M=BLOCK,\n BLOCK_DMODEL=Lk,\n BLOCK_N=BLOCK,\n num_warps=num_warps,\n num_stages=1,\n )\n return\n\nelif triton.__version__ == \"2.0.0\":\n @triton.jit\n def _fwd_kernel(\n Q, K, V, sm_scale, B_Start_Loc, B_Seqlen,\n TMP,\n Out,\n stride_qbs, stride_qh, stride_qd,\n stride_kbs, stride_kh, stride_kd,\n stride_vbs, stride_vh, stride_vd,\n stride_obs, stride_oh, stride_od,\n stride_tmp_b, stride_tmp_h, stride_tmp_s,\n BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,\n BLOCK_N: tl.constexpr,\n ):\n cur_batch = tl.program_id(0)\n cur_head = tl.program_id(1)\n start_m = tl.program_id(2)\n\n cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)\n cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)\n\n block_start_loc = BLOCK_M * start_m\n\n # initialize offsets\n offs_n = tl.arange(0, BLOCK_N)\n offs_d = tl.arange(0, BLOCK_DMODEL)\n offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)\n off_q = (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs + cur_head * stride_qh + offs_d[None, :] * stride_qd\n off_k = offs_n[None, :] * stride_kbs + cur_head * stride_kh + offs_d[:, None] * stride_kd\n off_v = offs_n[:, None] * stride_vbs + cur_head * stride_vh + offs_d[None, :] * stride_vd\n q = tl.load(Q + off_q, mask=offs_m[:, None] < cur_batch_seq_len, other=0.0)\n\n k_ptrs = K + off_k\n v_ptrs = V + off_v\n\n t_ptrs = TMP + cur_batch * stride_tmp_b + cur_head * stride_tmp_h + offs_m * stride_tmp_s\n # t_ptrs = TMP + offs_m\n m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float(\"inf\")\n l_i = tl.zeros([BLOCK_M], dtype=tl.float32)\n acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)\n\n block_mask = tl.where(block_start_loc < cur_batch_seq_len, 1, 0)\n\n for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N):\n start_n = tl.multiple_of(start_n, BLOCK_N)\n # -- compute qk ----\n k = tl.load(k_ptrs + (cur_batch_in_all_start_index + start_n) * stride_kbs,\n mask=(start_n + offs_n[None, :]) < cur_batch_seq_len, other=0.0)\n\n qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)\n qk += tl.dot(q, k)\n qk *= sm_scale\n qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float(\"-inf\"))\n\n m_ij = tl.max(qk, 1)\n p = tl.exp(qk - m_ij[:, None])\n l_ij = tl.sum(p, 1)\n # -- update m_i and l_i\n m_i_new = tl.maximum(m_i, m_ij)\n alpha = tl.exp(m_i - m_i_new)\n beta = tl.exp(m_ij - m_i_new)\n l_i_new = alpha * l_i + beta * l_ij\n # -- update output accumulator --\n # scale p\n p_scale = beta / l_i_new\n p = p * p_scale[:, None]\n # scale acc\n acc_scale = l_i / l_i_new * alpha\n tl.store(t_ptrs, acc_scale)\n acc_scale = tl.load(t_ptrs) # BUG: have to store and immediately load\n acc = acc * acc_scale[:, None]\n # update acc\n v = tl.load(v_ptrs + (cur_batch_in_all_start_index + start_n) * stride_vbs,\n mask=(start_n + offs_n[:, None]) < cur_batch_seq_len, other=0.0)\n\n p = p.to(v.dtype)\n acc += tl.dot(p, v)\n # update m_i and l_i\n l_i = l_i_new\n m_i = m_i_new\n # initialize pointers to output\n off_o = (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs + cur_head * stride_oh + offs_d[None, :] * stride_od\n out_ptrs = Out + off_o\n tl.store(out_ptrs, acc, mask=offs_m[:, None] < cur_batch_seq_len)\n\n return\n\n @torch.no_grad()\n def context_attention_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len):\n BLOCK = 128\n # shape constraints\n Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]\n assert Lq == Lk and Lk == Lv\n assert Lk in {16, 32, 64, 128}\n\n sm_scale = 1.0 / (Lq**0.5)\n batch, head = b_seq_len.shape[0], q.shape[1]\n\n grid = (batch, head, triton.cdiv(max_input_len, BLOCK))\n\n tmp = torch.empty((batch, head, max_input_len + 256), device=q.device, dtype=torch.float32)\n num_warps = 4 if Lk <= 64 else 8\n _fwd_kernel[grid](\n q, k, v, sm_scale, b_start_loc, b_seq_len,\n tmp,\n o,\n q.stride(0), q.stride(1), q.stride(2),\n k.stride(0), k.stride(1), k.stride(2),\n v.stride(0), v.stride(1), v.stride(2),\n o.stride(0), o.stride(1), o.stride(2),\n tmp.stride(0), tmp.stride(1), tmp.stride(2),\n BLOCK_M=BLOCK,\n BLOCK_DMODEL=Lk,\n BLOCK_N=BLOCK,\n num_warps=num_warps,\n num_stages=1,\n )\n return\n", - "description_1": "Use triton language to define a forward kernel (_fwd_kernel) for computing attention mechanism. The kernel takes 20 parameters, including input tensors Q, K, V, a scaling factor (sm_scale), start and length of batches (B_Start_Loc, B_Seqlen), output tensor (Out), strides for all input and output tensors, and constants for block sizes (BLOCK_M, BLOCK_DMODEL, BLOCK_N). It performs matrix operations (dot products, scaling, and accumulation) within a loop to compute attention scores and updates outputs based on these scores. A context_attention_fwd function is defined to configure and launch this kernel with 7 parameters: q, k, v, o, b_start_loc, b_seq_len, and max_input_len.", - "description_2": "Use triton language to implement an attention mechanism kernel and a corresponding forward pass function that launches the kernel with appropriate configurations.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef _rms_norm_fwd_fused(\n X, # pointer to the input\n Y, # pointer to the output\n W, # pointer to the weights\n stride, # how much to increase the pointer when moving by 1 row\n N, # number of columns in X\n eps, # epsilon to avoid division by zero\n BLOCK_SIZE: tl.constexpr,\n):\n # Map the program id to the row of X and Y it should compute.\n row = tl.program_id(0)\n Y += row * stride\n X += row * stride\n # Compute variance\n _var = tl.zeros([BLOCK_SIZE], dtype=tl.float32)\n for off in range(0, N, BLOCK_SIZE):\n cols = off + tl.arange(0, BLOCK_SIZE)\n x = tl.load(X + cols, mask=cols < N, other=0.).to(tl.float32)\n _var += x * x\n var = tl.sum(_var, axis=0) / N\n rstd = 1 / tl.sqrt(var + eps)\n # Normalize and apply linear transformation\n for off in range(0, N, BLOCK_SIZE):\n cols = off + tl.arange(0, BLOCK_SIZE)\n mask = cols < N\n w = tl.load(W + cols, mask=mask).to(tl.float32)\n x = tl.load(X + cols, mask=mask, other=0.).to(tl.float32)\n x_hat = x * rstd\n y = x_hat * w\n # Write output\n tl.store(Y + cols, y.to(tl.float16), mask=mask)\n\ndef rmsnorm_forward(x, weight, eps):\n # allocate output\n y = torch.empty_like(x)\n # reshape input data into 2D tensor\n x_arg = x.view(-1, x.shape[-1])\n M, N = x_arg.shape\n # Less than 64KB per feature: enqueue fused kernel\n MAX_FUSED_SIZE = 65536 // x.element_size()\n BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))\n # print(\"BLOCK_SIZE:\", BLOCK_SIZE)\n if N > BLOCK_SIZE:\n raise RuntimeError(\"This layer norm doesn't support feature dim >= 64KB.\")\n # heuristics for number of warps\n num_warps = min(max(BLOCK_SIZE // 256, 1), 8)\n # print(BLOCK_SIZE, num_warps, \"block_size, numwarps\")\n BLOCK_SIZE = 128 * 2 * 2 * 2 * 2 * 2 * 2 * 2\n num_warps = 8\n # enqueue kernel\n _rms_norm_fwd_fused[(M,)](x_arg, y, weight,\n x_arg.stride(0), N, eps,\n BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps)\n return y\n", - "description_1": "Use triton language to implement a fused RMS normalization kernel. The kernel '_rms_norm_fwd_fused' takes 7 arguments: X (input tensor pointer), Y (output tensor pointer), W (weights pointer), stride (integer indicating row stride), N (number of columns in X), eps (epsilon for numerical stability), and BLOCK_SIZE (constant for block size). The 'rmsnorm_forward' function prepares the inputs and invokes the kernel.", - "description_2": "Use triton language to perform a forward pass RMS normalization using a custom kernel with fused operations.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef _rotary_kernel(\n Q, Cos, Sin,\n stride_qbs, stride_qh, stride_qd,\n stride_cosbs, stride_cosd,\n stride_sinbs, stride_sind,\n max_total_len,\n H, # N_CTX represents the context length to compute\n BLOCK_HEAD: tl.constexpr,\n BLOCK_SEQ: tl.constexpr,\n BLOCK_DMODEL: tl.constexpr,\n):\n cur_head_index = tl.program_id(0)\n cur_seq_index = tl.program_id(1)\n\n cur_head_range = cur_head_index * BLOCK_HEAD + tl.arange(0, BLOCK_HEAD)\n cur_seq_range = cur_seq_index * BLOCK_SEQ + tl.arange(0, BLOCK_SEQ)\n\n dim_range0 = tl.arange(0, BLOCK_DMODEL // 2)\n dim_range1 = tl.arange(BLOCK_DMODEL // 2, BLOCK_DMODEL)\n\n off_q0 = cur_seq_range[:, None, None] * stride_qbs + cur_head_range[None, :, None] * stride_qh + dim_range0[None, None, :] * stride_qd\n off_q1 = cur_seq_range[:, None, None] * stride_qbs + cur_head_range[None, :, None] * stride_qh + dim_range1[None, None, :] * stride_qd\n\n off_dimcos_sin = cur_seq_range[:, None, None] * stride_cosbs + dim_range0[None, None, :] * stride_cosd\n\n q0 = tl.load(Q + off_q0, mask=(cur_seq_range[:, None, None] < max_total_len) & (cur_head_range[None, :, None] < H), other=0.0)\n q1 = tl.load(Q + off_q1, mask=(cur_seq_range[:, None, None] < max_total_len) & (cur_head_range[None, :, None] < H), other=0.0)\n\n cos = tl.load(Cos + off_dimcos_sin, mask=cur_seq_range[:, None, None] < max_total_len, other=0.0)\n sin = tl.load(Sin + off_dimcos_sin, mask=cur_seq_range[:, None, None] < max_total_len, other=0.0)\n\n out0 = q0 * cos - q1 * sin\n out1 = q0 * sin + q1 * cos\n\n tl.store(Q + off_q0, out0, mask=(cur_seq_range[:, None, None] < max_total_len) & (cur_head_range[None, :, None] < H))\n tl.store(Q + off_q1, out1, mask=(cur_seq_range[:, None, None] < max_total_len) & (cur_head_range[None, :, None] < H))\n\n return\n\n@torch.no_grad()\ndef rotary_emb_fwd(q, cos, sin):\n total_len = q.shape[0]\n head_num = q.shape[1]\n head_dim = q.shape[2]\n assert q.shape[0] == cos.shape[0] and q.shape[0] == sin.shape[0], f\"q shape {q.shape} cos shape {cos.shape}\"\n BLOCK_HEAD = 4\n BLOCK_SEQ = 32\n grid = (triton.cdiv(head_num, BLOCK_HEAD), triton.cdiv(total_len, BLOCK_SEQ))\n if head_dim >= 128:\n num_warps = 8\n else:\n num_warps = 4\n\n _rotary_kernel[grid](\n q, cos, sin,\n q.stride(0), q.stride(1), q.stride(2),\n cos.stride(0), cos.stride(1),\n sin.stride(0), sin.stride(1),\n total_len, head_num,\n BLOCK_HEAD=BLOCK_HEAD,\n BLOCK_SEQ=BLOCK_SEQ,\n BLOCK_DMODEL=head_dim,\n num_warps=num_warps,\n num_stages=1,\n )\n return\n", - "description_1": "Use triton language to implement a rotary kernel function that performs element-wise operations on input tensors Q, Cos, and Sin. The kernel uses block-based indexing to load, compute, and store results in a parallelized manner. The rotary_emb_fwd function sets up the grid and block dimensions and calls the kernel with appropriate parameters.", - "description_2": "Use triton language to create a rotary kernel for element-wise tensor operations and a wrapper function to configure and launch the kernel.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef _fwd_kernel_token_att1(\n Q, K, sm_scale, B_Loc, B_Start_Loc, B_Seqlen, max_input_len,\n Att_Out,\n stride_b_loc_b, stride_b_loc_s,\n stride_qbs, stride_qh, stride_qd,\n stride_kbs, stride_kh, stride_kd,\n att_stride_h, att_stride_bs,\n BLOCK_DMODEL: tl.constexpr,\n BLOCK_N: tl.constexpr\n):\n cur_batch = tl.program_id(0)\n cur_head = tl.program_id(1)\n start_n = tl.program_id(2)\n\n offs_d = tl.arange(0, BLOCK_DMODEL)\n cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)\n cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)\n\n cur_batch_start_index = max_input_len - cur_batch_seq_len\n cur_batch_end_index = max_input_len\n\n off_q = cur_batch * stride_qbs + cur_head * stride_qh + offs_d * stride_qd\n\n offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N)\n\n block_stard_index = start_n * BLOCK_N\n block_mask = tl.where(block_stard_index < cur_batch_seq_len, 1, 0)\n\n for start_mark in range(0, block_mask, 1):\n q = tl.load(Q + off_q + start_mark)\n offs_n_new = cur_batch_start_index + offs_n\n k_loc = tl.load(B_Loc + stride_b_loc_b * cur_batch + stride_b_loc_s * offs_n_new, mask=offs_n_new < cur_batch_end_index, other=0)\n off_k = k_loc[:, None] * stride_kbs + cur_head * stride_kh + offs_d[None, :] * stride_kd\n k = tl.load(K + off_k, mask=offs_n_new[:, None] < cur_batch_end_index, other=0.0)\n att_value = tl.sum(q[None, :] * k, 1)\n att_value *= sm_scale\n off_o = cur_head * att_stride_h + (cur_batch_in_all_start_index + offs_n) * att_stride_bs\n tl.store(Att_Out + off_o, att_value, mask=offs_n_new < cur_batch_end_index)\n return\n\n@torch.no_grad()\ndef token_att_fwd(q, k, att_out, B_Loc, B_Start_Loc, B_Seqlen, max_input_len):\n BLOCK = 32\n Lq, Lk = q.shape[-1], k.shape[-1]\n assert Lq == Lk\n assert Lk in {16, 32, 64, 128}\n sm_scale = 1.0 / (Lk ** 0.5)\n\n batch, head_num = B_Loc.shape[0], q.shape[1]\n\n grid = (batch, head_num, triton.cdiv(max_input_len, BLOCK))\n\n num_warps = 4\n \n _fwd_kernel_token_att1[grid](\n q, k, sm_scale, B_Loc, B_Start_Loc, B_Seqlen, max_input_len,\n att_out,\n B_Loc.stride(0), B_Loc.stride(1),\n q.stride(0), q.stride(1), q.stride(2),\n k.stride(0), k.stride(1), k.stride(2),\n att_out.stride(0), att_out.stride(1),\n BLOCK_DMODEL=Lk,\n BLOCK_N=BLOCK,\n num_warps=num_warps,\n num_stages=1,\n )\n return\n\n@triton.jit\ndef _fwd_kernel_token_att1_int8(\n Q, K, K_scale, sm_scale, B_Loc, B_Start_Loc, B_Seqlen, max_input_len,\n Att_Out,\n stride_b_loc_b, stride_b_loc_s,\n stride_qbs, stride_qh, stride_qd,\n stride_kbs, stride_kh, stride_kd,\n stride_ksbs, stride_ksh, stride_ksd,\n att_stride_h, att_stride_bs,\n BLOCK_DMODEL: tl.constexpr,\n BLOCK_N: tl.constexpr\n):\n cur_batch = tl.program_id(0)\n cur_head = tl.program_id(1)\n start_n = tl.program_id(2)\n\n offs_d = tl.arange(0, BLOCK_DMODEL)\n cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)\n cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)\n\n cur_batch_start_index = max_input_len - cur_batch_seq_len\n cur_batch_end_index = max_input_len\n\n off_q = cur_batch * stride_qbs + cur_head * stride_qh + offs_d * stride_qd\n\n offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N)\n\n block_stard_index = start_n * BLOCK_N\n block_mask = tl.where(block_stard_index < cur_batch_seq_len, 1, 0)\n\n for start_mark in range(0, block_mask, 1):\n q = tl.load(Q + off_q + start_mark)\n offs_n_new = cur_batch_start_index + offs_n\n k_loc = tl.load(B_Loc + stride_b_loc_b * cur_batch + stride_b_loc_s * offs_n_new, mask=offs_n_new < cur_batch_end_index, other=0)\n off_k = k_loc[:, None] * stride_kbs + cur_head * stride_kh + offs_d[None, :] * stride_kd\n k = tl.load(K + off_k, mask=offs_n_new[:, None] < cur_batch_end_index, other=0.0)\n off_ks = k_loc[:, None] * stride_ksbs + cur_head * stride_ksh\n k_scale = tl.load(K_scale + off_ks, mask=offs_n_new[:, None] < cur_batch_end_index, other=0.0)\n att_value = tl.sum(q[None, :] * k * k_scale, 1)\n att_value *= sm_scale\n off_o = cur_head * att_stride_h + (cur_batch_in_all_start_index + offs_n) * att_stride_bs\n tl.store(Att_Out + off_o, att_value, mask=offs_n_new < cur_batch_end_index)\n return\n\n@torch.no_grad()\ndef token_att_fwd_int8k(q, k, k_scale, att_out, B_Loc, B_Start_Loc, B_Seqlen, max_input_len):\n BLOCK = 32\n Lq, Lk = q.shape[-1], k.shape[-1]\n assert Lq == Lk\n assert Lk in {16, 32, 64, 128}\n sm_scale = 1.0 / (Lk ** 0.5)\n\n batch, head_num = B_Loc.shape[0], q.shape[1]\n\n grid = (batch, head_num, triton.cdiv(max_input_len, BLOCK))\n\n num_warps = 4 if Lk <= 64 else 8\n num_warps = 2\n\n _fwd_kernel_token_att1_int8[grid](\n q, k, k_scale, sm_scale, B_Loc, B_Start_Loc, B_Seqlen, max_input_len,\n att_out,\n B_Loc.stride(0), B_Loc.stride(1),\n q.stride(0), q.stride(1), q.stride(2),\n k.stride(0), k.stride(1), k.stride(2),\n k_scale.stride(0), k_scale.stride(1), k_scale.stride(2),\n att_out.stride(0), att_out.stride(1),\n BLOCK_DMODEL=Lk,\n BLOCK_N=BLOCK,\n num_warps=num_warps,\n num_stages=1,\n )\n return\n", - "description_1": "Use triton language to implement two kernels for token attention. The first kernel, _fwd_kernel_token_att1, takes 18 parameters: Q, K, sm_scale, B_Loc, B_Start_Loc, B_Seqlen, max_input_len, Att_Out, stride_b_loc_b, stride_b_loc_s, stride_qbs, stride_qh, stride_qd, stride_kbs, stride_kh, stride_kd, att_stride_h, att_stride_bs, and two constexpr parameters BLOCK_DMODEL and BLOCK_N. It computes attention values by loading and processing blocks of Q and K matrices. The second kernel, _fwd_kernel_token_att1_int8, is similar but includes an additional parameter K_scale for int8 quantized K matrices. It also takes 21 parameters and computes attention values with scaling for int8 quantization. Both kernels are called by their respective wrapper functions, token_att_fwd and token_att_fwd_int8k, which set up the grid and block dimensions and pass the necessary parameters.", - "description_2": "Use triton language to create two kernels for token attention computation, one for standard floating-point and another for int8 quantized inputs, each with specific parameters for matrix dimensions and strides.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef _fwd_kernel_token_att2(\n Prob, V, Out, B_Loc, B_Start_Loc, B_Seqlen, max_input_len,\n stride_b_loc_b, stride_b_loc_s,\n stride_ph, stride_pbs,\n stride_vbs, stride_vh, stride_vd,\n stride_obs, stride_oh, stride_od,\n BLOCK_DMODEL: tl.constexpr,\n BLOCK_N: tl.constexpr\n):\n cur_batch = tl.program_id(0)\n cur_head = tl.program_id(1)\n\n offs_n = tl.arange(0, BLOCK_N)\n offs_d = tl.arange(0, BLOCK_DMODEL)\n cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)\n cur_batch_start_index = max_input_len - cur_batch_seq_len\n cur_batch_end_index = cur_batch_seq_len\n cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)\n\n v_loc_off = cur_batch * stride_b_loc_b + (cur_batch_start_index + offs_n) * stride_b_loc_s\n p_offs = cur_head * stride_ph + (cur_batch_in_all_start_index + offs_n) * stride_pbs\n v_offs = cur_head * stride_vh + offs_d[None, :] * stride_vd\n\n acc = tl.zeros([BLOCK_DMODEL], dtype=tl.float32)\n for start_n in range(0, cur_batch_seq_len, BLOCK_N):\n start_n = tl.multiple_of(start_n, BLOCK_N)\n p_value = tl.load(Prob + p_offs + start_n * stride_b_loc_s, mask=(start_n + offs_n) < cur_batch_seq_len, other=0.0)\n v_loc = tl.load(B_Loc + v_loc_off + start_n * stride_b_loc_s, mask=(start_n + offs_n) < cur_batch_seq_len, other=0.0)\n v_value = tl.load(V + v_offs + v_loc[:, None] * stride_vbs, mask=(start_n + offs_n[:, None]) < cur_batch_seq_len, other=0.0)\n acc += tl.sum(p_value[:, None] * v_value, 0)\n\n acc = acc.to(tl.float16)\n off_o = cur_batch * stride_obs + cur_head * stride_oh + offs_d * stride_od\n out_ptrs = Out + off_o\n tl.store(out_ptrs, acc)\n return\n\n@torch.no_grad()\ndef token_att_fwd2(prob, v, out, B_Loc, B_Start_Loc, B_Seqlen, max_input_len):\n if triton.__version__ >= \"2.1.0\":\n BLOCK = 128\n else:\n BLOCK = 64\n batch, head = B_Loc.shape[0], v.shape[1]\n grid = (batch, head)\n num_warps = 4\n dim = v.shape[-1]\n\n _fwd_kernel_token_att2[grid](\n prob, v, out, B_Loc, B_Start_Loc, B_Seqlen, max_input_len,\n B_Loc.stride(0), B_Loc.stride(1),\n prob.stride(0), prob.stride(1),\n v.stride(0), v.stride(1), v.stride(2),\n out.stride(0), out.stride(1), out.stride(2),\n BLOCK_DMODEL=dim,\n BLOCK_N=BLOCK,\n num_warps=num_warps,\n num_stages=1,\n )\n return\n\n@triton.jit\ndef _fwd_kernel_token_att2_int8v(\n Prob, V, V_scale, Out, B_Loc, B_Start_Loc, B_Seqlen, max_input_len,\n stride_b_loc_b, stride_b_loc_s,\n stride_ph, stride_pbs,\n stride_vbs, stride_vh, stride_vd,\n stride_vsbs, stride_vsh, stride_vsd,\n stride_obs, stride_oh, stride_od,\n BLOCK_DMODEL: tl.constexpr,\n BLOCK_N: tl.constexpr\n):\n cur_batch = tl.program_id(0)\n cur_head = tl.program_id(1)\n\n offs_n = tl.arange(0, BLOCK_N)\n offs_d = tl.arange(0, BLOCK_DMODEL)\n cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)\n cur_batch_start_index = max_input_len - cur_batch_seq_len\n cur_batch_end_index = cur_batch_seq_len\n cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)\n\n v_loc_off = cur_batch * stride_b_loc_b + (cur_batch_start_index + offs_n) * stride_b_loc_s\n p_offs = cur_head * stride_ph + (cur_batch_in_all_start_index + offs_n) * stride_pbs\n v_offs = cur_head * stride_vh + offs_d[None, :] * stride_vd\n vs_offs = cur_head * stride_vsh\n\n acc = tl.zeros([BLOCK_DMODEL], dtype=tl.float32)\n for start_n in range(0, cur_batch_seq_len, BLOCK_N):\n start_n = tl.multiple_of(start_n, BLOCK_N)\n p_value = tl.load(Prob + p_offs + start_n * stride_b_loc_s, mask=(start_n + offs_n) < cur_batch_seq_len, other=0.0)\n v_loc = tl.load(B_Loc + v_loc_off + start_n * stride_b_loc_s, mask=(start_n + offs_n) < cur_batch_seq_len, other=0.0)\n v_value = tl.load(V + v_offs + v_loc[:, None] * stride_vbs, mask=(start_n + offs_n[:, None]) < cur_batch_seq_len, other=0.0)\n vs_value = tl.load(V_scale + vs_offs + v_loc[:, None] * stride_vsbs, mask=(start_n + offs_n[:, None]) < cur_batch_seq_len, other=0.0)\n acc += tl.sum(p_value[:, None] * v_value * vs_value, 0)\n\n acc = acc.to(tl.float16)\n off_o = cur_batch * stride_obs + cur_head * stride_oh + offs_d * stride_od\n out_ptrs = Out + off_o\n tl.store(out_ptrs, acc)\n return\n\n@torch.no_grad()\ndef token_att_fwd2_int8v(prob, v, v_scale, out, B_Loc, B_Start_Loc, B_Seqlen, max_input_len):\n if max_input_len < 512:\n BLOCK = triton.next_power_of_2(max_input_len)\n else:\n BLOCK = 512\n batch, head = B_Loc.shape[0], v.shape[1]\n grid = (batch, head)\n num_warps = 4\n dim = v.shape[-1]\n\n _fwd_kernel_token_att2_int8v[grid](\n prob, v, v_scale, out, B_Loc, B_Start_Loc, B_Seqlen, max_input_len,\n B_Loc.stride(0), B_Loc.stride(1),\n prob.stride(0), prob.stride(1),\n v.stride(0), v.stride(1), v.stride(2),\n v_scale.stride(0), v_scale.stride(1), v_scale.stride(2),\n out.stride(0), out.stride(1), out.stride(2),\n BLOCK_DMODEL=dim,\n BLOCK_N=BLOCK,\n num_warps=num_warps,\n num_stages=1,\n )\n return\n", - "description_1": "Use triton language to implement two kernels for token attention. The first kernel, _fwd_kernel_token_att2, takes 15 parameters: Prob, V, Out, B_Loc, B_Start_Loc, B_Seqlen, max_input_len, and strides for various dimensions, along with BLOCK_DMODEL and BLOCK_N as constexpr. It computes the attention output by iterating over the sequence length and accumulating results. The second kernel, _fwd_kernel_token_att2_int8v, is similar but includes V_scale for int8 operations. It takes 18 parameters, including V_scale and its strides. Both kernels are called by their respective functions, token_att_fwd2 and token_att_fwd2_int8v, which set up the grid and block dimensions and invoke the kernels with appropriate arguments.", - "description_2": "Use triton language to create two token attention kernels. The first kernel processes float inputs, while the second handles int8 inputs with scaling. Both kernels iterate over sequence lengths to compute attention outputs, using grid and block dimensions set by their calling functions.", - "difficulty": 3 - }, - { - "code": "import triton\nimport triton.language as tl\nfrom torch._inductor.triton_heuristics import grid\nfrom torch import empty_strided\nfrom torch._C import _cuda_getCurrentRawStream as get_raw_stream\nfrom torch._inductor.codecache import AsyncCompile\n\nasync_compile = AsyncCompile()\n\n# Triton kernel for pointwise operation with a single input pointer and single output pointer\n@triton.jit\ndef triton_(in_ptr0, out_ptr0, xnumel, XBLOCK: tl.constexpr):\n xoffset = tl.program_id(0) * XBLOCK\n xindex = xoffset + tl.arange(0, XBLOCK)[:]\n xmask = xindex < xnumel\n x0 = xindex % 32\n x1 = (xindex // 32) % 256\n x2 = (xindex // 8192) % 16\n x3 = (xindex // 131072)\n x4 = xindex\n tmp0 = tl.load(in_ptr0 + (x0 + (32 * x2) + (1536 * x1) + (393216 * x3)), None).to(tl.float32)\n tl.store(out_ptr0 + x4, tmp0, None)\n\n# Triton kernel for fused softmax and division\n@triton.jit\ndef triton_(in_ptr0, out_ptr2, xnumel, rnumel):\n XBLOCK: tl.constexpr = 1\n RBLOCK: tl.constexpr = 256\n xoffset = tl.program_id(0) * XBLOCK\n xindex = tl.full([1], xoffset, tl.int32)\n xmask = xindex < xnumel\n rindex = tl.arange(0, RBLOCK)[:]\n rmask = rindex < rnumel\n r1 = rindex\n x0 = xindex\n tmp0 = tl.load(in_ptr0 + (r1 + (256 * x0)), rmask, other=0.0).to(tl.float32)\n tmp1 = 5.656854249492381\n tmp2 = tmp0 / tmp1\n tmp3 = tmp2.to(tl.float32)\n tmp4 = tl.broadcast_to(tmp3, [RBLOCK])\n tmp6 = tl.where(rmask, tmp4, float(\"-inf\"))\n tmp7 = tl.max(tmp6, axis=0)\n tmp8 = tmp3 - tmp7\n tmp9 = tl.exp(tmp8)\n tmp10 = tl.broadcast_to(tmp9, [RBLOCK])\n tmp12 = tl.where(rmask, tmp10, 0)\n tmp13 = tl.sum(tmp12, axis=0)\n tmp14 = tmp9 / tmp13\n tmp15 = tmp14.to(tl.float32)\n tl.store(out_ptr2 + (r1 + (256 * x0)), tmp15, rmask)\n\nasync_compile.wait(globals())\ndel async_compile\n\ndef call(args):\n arg0_1, arg1_1, arg2_1 = args\n args.clear()\n buf0 = empty_strided((16, 16, 256, 32), (131072, 8192, 32, 1), torch.float16, device='cuda')\n stream0 = get_raw_stream(0)\n triton_poi_fused_clone_0.run(arg1_1, buf0, 2097152, grid=grid(2097152), stream=stream0)\n buf1 = empty_strided((16, 16, 32, 256), (131072, 8192, 256, 1), torch.float16, device='cuda')\n triton_poi_fused_clone_1.run(arg0_1, buf1, 8192, 256, grid=grid(8192, 256), stream=stream0)\n buf2 = empty_strided((256, 256, 256), (65536, 256, 1), torch.float16, device='cuda')\n buf5 = empty_strided((16, 16, 256, 256), (1048576, 65536, 256, 1), torch.float16, device='cuda')\n triton_per_fused__softmax_div_2.run(buf2, buf5, 65536, 256, grid=grid(65536), stream=stream0)\n buf6 = empty_strided((16, 16, 256, 32), (131072, 8192, 32, 1), torch.float16, device='cuda')\n triton_poi_fused_clone_0.run(arg2_1, buf6, 2097152, grid=grid(2097152), stream=stream0)\n buf7 = empty_strided((256, 256, 32), (8192, 32, 1), torch.float16, device='cuda')\n return (buf7, )\n\n", - "description_1": "Use triton language to define three kernels for pointwise operations and fused softmax and division, each with specific argument counts and tensor manipulations.", - "description_2": "Use triton language to execute CUDA kernels for tensor cloning and softmax operation.", - "difficulty": 4 - }, - { - "code": "import triton\nimport triton.language as tl\nfrom torch._inductor.triton_heuristics import grid\nfrom torch._C import _cuda_getCurrentRawStream as get_raw_stream\nfrom torch import empty_strided\n\n# Kernel 1: triton_poi_fused_clone_0\n@triton.jit\ndef triton_poi_fused_clone_0(in_ptr0, out_ptr0, xnumel, XBLOCK: tl.constexpr):\n xnumel = 2097152\n xoffset = tl.program_id(0) * XBLOCK\n xindex = xoffset + tl.arange(0, XBLOCK)[:]\n xmask = xindex < xnumel\n x0 = xindex % 32\n x1 = (xindex // 32) % 256\n x2 = (xindex // 8192) % 16\n x3 = (xindex // 131072)\n x4 = xindex\n tmp0 = tl.load(in_ptr0 + (x0 + (32*x2) + (1536*x1) + (393216*x3)), None).to(tl.float32)\n tl.store(out_ptr0 + (x4), tmp0, None)\n\n# Kernel 2: triton_poi_fused_clone_1\n@triton.jit\ndef triton_poi_fused_clone_1(in_ptr0, out_ptr0, ynumel, xnumel, YBLOCK: tl.constexpr, XBLOCK: tl.constexpr):\n ynumel = 8192\n xnumel = 256\n yoffset = tl.program_id(1) * YBLOCK\n yindex = yoffset + tl.arange(0, YBLOCK)[None, :]\n ymask = yindex < ynumel\n xoffset = tl.program_id(0) * XBLOCK\n xindex = xoffset + tl.arange(0, XBLOCK)[:, None]\n xmask = xindex < xnumel\n x2 = xindex\n y0 = yindex % 512\n y1 = (yindex // 512)\n y3 = yindex\n tmp0 = tl.load(in_ptr0 + (y0 + (1536*x2) + (393216*y1)), xmask, eviction_policy='evict_last').to(tl.float32)\n tl.store(out_ptr0 + (x2 + (256*y3)), tmp0, xmask)\n\n# Kernel 3: triton_per_fused__softmax_div_2\n@triton.jit\ndef triton_per_fused__softmax_div_2(in_ptr0, out_ptr2, xnumel, rnumel):\n xnumel = 65536\n XBLOCK: tl.constexpr = 1\n rnumel = 256\n RBLOCK: tl.constexpr = 256\n xoffset = tl.program_id(0) * XBLOCK\n xindex = tl.full([1], xoffset, tl.int32)\n xmask = xindex < xnumel\n rindex = tl.arange(0, RBLOCK)[:]\n rmask = rindex < rnumel\n r1 = rindex\n x0 = xindex\n tmp0 = tl.load(in_ptr0 + (r1 + (256*x0)), rmask, other=0.0).to(tl.float32)\n tmp1 = 5.656854249492381\n tmp2 = tmp0 / tmp1\n tmp3 = tmp2.to(tl.float32)\n tmp4 = tl.broadcast_to(tmp3, [RBLOCK])\n tmp6 = tl.where(rmask, tmp4, float(\"-inf\"))\n tmp7 = triton_helpers.promote_to_tensor(triton_helpers.max2(tmp6, 0))\n tmp8 = tmp3 - tmp7\n tmp9 = tl.exp(tmp8)\n tmp10 = tl.broadcast_to(tmp9, [RBLOCK])\n tmp12 = tl.where(rmask, tmp10, 0)\n tmp13 = triton_helpers.promote_to_tensor(tl.sum(tmp12, 0))\n tmp14 = tmp9 / tmp13\n tmp15 = tmp14.to(tl.float32)\n tl.store(out_ptr2 + (r1 + (256*x0)), tmp15, rmask)\n\ndef call(args):\n arg0_1, arg1_1, arg2_1 = args\n args.clear()\n assert_size_stride(arg0_1, (16, 16, 256, 32), (393216, 32, 1536, 1))\n assert_size_stride(arg1_1, (16, 16, 256, 32), (393216, 32, 1536, 1))\n assert_size_stride(arg2_1, (16, 16, 256, 32), (393216, 32, 1536, 1))\n with torch.cuda._DeviceGuard(0):\n torch.cuda.set_device(0)\n buf0 = empty_strided((16, 16, 256, 32), (131072, 8192, 32, 1), torch.float16, device='cuda')\n stream0 = get_raw_stream(0)\n triton_poi_fused_clone_0.run(arg1_1, buf0, 2097152, grid=grid(2097152), stream=stream0)\n del arg1_1\n buf1 = empty_strided((16, 16, 32, 256), (131072, 8192, 256, 1), torch.float16, device='cuda')\n triton_poi_fused_clone_1.run(arg0_1, buf1, 8192, 256, grid=grid(8192, 256), stream=stream0)\n del arg0_1\n buf2 = empty_strided((256, 256, 256), (65536, 256, 1), torch.float16, device='cuda')\n extern_kernels.bmm(reinterpret_tensor(buf0, (256, 256, 32), (8192, 32, 1), 0), reinterpret_tensor(buf1, (256, 32, 256), (8192, 256, 1), 0), out=buf2)\n buf5 = empty_strided((16, 16, 256, 256), (1048576, 65536, 256, 1), torch.float16, device='cuda')\n triton_per_fused__softmax_div_2.run(buf2, buf5, 65536, 256, grid=grid(65536), stream=stream0)\n del buf2\n buf6 = reinterpret_tensor(buf1, (16, 16, 256, 32), (131072, 8192, 32, 1), 0); del buf1\n triton_poi_fused_clone_0.run(arg2_1, buf6, 2097152, grid=grid(2097152), stream=stream0)\n del arg2_1\n buf7 = reinterpret_tensor(buf0, (256, 256, 32), (8192, 32, 1), 0); del buf0\n extern_kernels.bmm(reinterpret_tensor(buf5, (256, 256, 256), (65536, 256, 1), 0), reinterpret_tensor(buf6, (256, 256, 32), (8192, 32, 1), 0), out=buf7)\n del buf5\n del buf6\n return (reinterpret_tensor(buf7, (16, 16, 256, 32), (131072, 8192, 32, 1), 0), )\n", - "description_1": "Use triton language to implement three kernels: triton_poi_fused_clone_0, triton_poi_fused_clone_1, and triton_per_fused__softmax_div_2. The first kernel takes three arguments: in_ptr0 (input pointer), out_ptr0 (output pointer), and xnumel (number of elements), and performs a pointwise operation. The second kernel takes five arguments: in_ptr0, out_ptr0, ynumel, xnumel, and two block sizes, performing a pointwise operation with tiling. The third kernel takes four arguments: in_ptr0, out_ptr2, xnumel, and rnumel, and performs a persistent reduction operation. The call function orchestrates the execution of these kernels, managing memory and device settings.", - "description_2": "Use triton language to create three CUDA kernels for pointwise and reduction operations, and manage their execution with a call function that handles memory and device settings.", - "difficulty": 4 - }, - { - "code": "import triton\nimport triton.language as tl\nfrom torch._inductor.triton_heuristics import grid\nfrom torch import empty_strided_cuda\nfrom torch._C import _cuda_getCurrentRawStream as get_raw_stream\nfrom torch._inductor.select_algorithm import extern_kernels\nfrom torch._inductor.utils import reinterpret_tensor\n\n# Kernel 1: triton_poi_fused_clone_0\n@triton.jit\ndef triton_(in_ptr0, out_ptr0, xnumel, XBLOCK: tl.constexpr):\n xnumel = 2097152\n xoffset = tl.program_id(0) * XBLOCK\n xindex = xoffset + tl.arange(0, XBLOCK)[:]\n xmask = xindex < xnumel\n x0 = xindex % 32\n x1 = (xindex // 32) % 256\n x2 = (xindex // 8192) % 16\n x3 = (xindex // 131072)\n x4 = xindex\n tmp0 = tl.load(in_ptr0 + (x0 + (32 * x2) + (1536 * x1) + (393216 * x3)), None).to(tl.float32)\n tl.store(out_ptr0 + (x4), tmp0, None)\n\n# Kernel 2: triton_poi_fused_clone_1\n@triton.jit\ndef triton_(in_ptr0, out_ptr0, ynumel, xnumel, YBLOCK: tl.constexpr, XBLOCK: tl.constexpr):\n ynumel = 8192\n xnumel = 256\n yoffset = tl.program_id(1) * YBLOCK\n yindex = yoffset + tl.arange(0, YBLOCK)[None, :]\n ymask = yindex < ynumel\n xoffset = tl.program_id(0) * XBLOCK\n xindex = xoffset + tl.arange(0, XBLOCK)[:, None]\n xmask = xindex < xnumel\n x2 = xindex\n y0 = yindex % 512\n y1 = (yindex // 512)\n y3 = yindex\n tmp0 = tl.load(in_ptr0 + (y0 + (1536 * x2) + (393216 * y1)), xmask, eviction_policy='evict_last').to(tl.float32)\n tl.store(out_ptr0 + (x2 + (256 * y3)), tmp0, xmask)\n\n# Kernel 3: triton_per_fused__softmax_div_2\n@triton.jit\ndef triton_(in_ptr0, out_ptr2, xnumel, rnumel):\n xnumel = 65536\n XBLOCK: tl.constexpr = 1\n rnumel = 256\n RBLOCK: tl.constexpr = 256\n xoffset = tl.program_id(0) * XBLOCK\n xindex = tl.full([1], xoffset, tl.int32)\n xmask = xindex < xnumel\n rindex = tl.arange(0, RBLOCK)[:]\n roffset = 0\n rmask = rindex < rnumel\n r1 = rindex\n x0 = xindex\n tmp0 = tl.load(in_ptr0 + (r1 + (256 * x0)), rmask, other=0.0).to(tl.float32)\n tmp1 = 5.656854249492381\n tmp2 = tmp0 / tmp1\n tmp3 = tmp2.to(tl.float32)\n tmp4 = tl.broadcast_to(tmp3, [RBLOCK])\n tmp6 = tl.where(rmask, tmp4, float(\"-inf\"))\n tmp7 = triton_helpers.promote_to_tensor(triton_helpers.max2(tmp6, 0))\n tmp8 = tmp3 - tmp7\n tmp9 = tl.exp(tmp8)\n tmp10 = tl.broadcast_to(tmp9, [RBLOCK])\n tmp12 = tl.where(rmask, tmp10, 0)\n tmp13 = triton_helpers.promote_to_tensor(tl.sum(tmp12, 0))\n tmp14 = tmp9 / tmp13\n tmp15 = tmp14.to(tl.float32)\n tl.store(out_ptr2 + (r1 + (256 * x0)), tmp15, rmask)\n\ndef call(args):\n arg0_1, arg1_1, arg2_1 = args\n args.clear()\n buf0 = empty_strided_cuda((16, 16, 256, 32), (131072, 8192, 32, 1), torch.float16)\n stream0 = get_raw_stream(0)\n triton_poi_fused_clone_0.run(arg1_1, buf0, 2097152, grid=grid(2097152), stream=stream0)\n del arg1_1\n buf1 = empty_strided_cuda((16, 16, 32, 256), (131072, 8192, 256, 1), torch.float16)\n triton_poi_fused_clone_1.run(arg0_1, buf1, 8192, 256, grid=grid(8192, 256), stream=stream0)\n del arg0_1\n buf2 = empty_strided_cuda((256, 256, 256), (65536, 256, 1), torch.float16)\n extern_kernels.bmm(reinterpret_tensor(buf0, (256, 256, 32), (8192, 32, 1), 0), reinterpret_tensor(buf1, (256, 32, 256), (8192, 256, 1), 0), out=buf2)\n buf5 = empty_strided_cuda((16, 16, 256, 256), (1048576, 65536, 256, 1), torch.float16)\n triton_per_fused__softmax_div_2.run(buf2, buf5, 65536, 256, grid=grid(65536), stream=stream0)\n del buf2\n buf6 = reinterpret_tensor(buf1, (16, 16, 256, 32), (131072, 8192, 32, 1), 0)\n del buf1\n triton_poi_fused_clone_0.run(arg2_1, buf6, 2097152, grid=grid(2097152), stream=stream0)\n del arg2_1\n buf7 = reinterpret_tensor(buf0, (256, 256, 32), (8192, 32, 1), 0)\n del buf0\n extern_kernels.bmm(reinterpret_tensor(buf5, (256, 256, 256), (65536, 256, 1), 0), reinterpret_tensor(buf6, (256, 256, 32), (8192, 32, 1), 0), out=buf7)\n del buf5\n del buf6\n return (reinterpret_tensor(buf7, (16, 16, 256, 32), (131072, 8192, 32, 1), 0), )\n", - "description_1": "Use triton language to implement three kernels: (1) triton_poi_fused_clone_0 which performs a pointwise clone operation for a tensor of size 2097152 with 3 input pointers and a constant block size parameter; (2) triton_poi_fused_clone_1 which performs another clone operation with a 4-pointer input, for 8192 x 256 elements, taking block sizes for both dimensions as parameters; (3) triton_per_fused__softmax_div_2 which performs a softmax and division operation on input elements with a reduction along the inner dimension, requiring 4 input pointers. The call function then orchestrates these kernels using CUDA streams and grid settings.", - "description_2": "Use triton language to implement multiple kernels for tensor operations including clone and softmax, along with their orchestration via CUDA.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nfrom triton import language as tl\n\n@triton.jit\ndef add_kernel(\n in_ptr0,\n in_ptr1,\n out_ptr,\n n_elements,\n BLOCK_SIZE: \"tl.constexpr\",\n):\n # Compute offsets for this program instance\n pid = tl.program_id(axis=0)\n block_start = pid * BLOCK_SIZE\n offsets = block_start + tl.arange(0, BLOCK_SIZE)\n mask = offsets < n_elements\n x = tl.load(in_ptr0 + offsets, mask=mask)\n y = tl.load(in_ptr1 + offsets, mask=mask)\n output = x + y\n tl.store(out_ptr + offsets, output, mask=mask)\n\n@torch.compile(fullgraph=True)\ndef add_fn(x, y):\n output = torch.zeros_like(x)\n n_elements = output.numel()\n grid = lambda meta: (triton.cdiv(n_elements, meta[\"BLOCK_SIZE\"]),)\n add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=4)\n return output\n\n@triton.autotune(\n configs=[\n triton.Config({\"BLOCK_SIZE\": 4}, num_stages=3, num_warps=8),\n triton.Config({\"BLOCK_SIZE\": 4}, num_stages=4, num_warps=4),\n triton.Config({\"BLOCK_SIZE\": 2}, num_stages=3, num_warps=8),\n triton.Config({\"BLOCK_SIZE\": 2}, num_stages=4, num_warps=4),\n ],\n key=[],\n)\n@triton.jit\ndef add_kernel_autotuned(\n in_ptr0,\n in_ptr1,\n out_ptr,\n n_elements,\n BLOCK_SIZE: \"tl.constexpr\",\n):\n # Compute offsets for this program instance\n pid = tl.program_id(axis=0)\n block_start = pid * BLOCK_SIZE\n offsets = block_start + tl.arange(0, BLOCK_SIZE)\n mask = offsets < n_elements\n x = tl.load(in_ptr0 + offsets, mask=mask)\n y = tl.load(in_ptr1 + offsets, mask=mask)\n output = x + y\n tl.store(out_ptr + offsets, output, mask=mask)\n\n@torch.compile(fullgraph=True)\ndef add_fn(x, y):\n output = torch.zeros_like(x)\n n_elements = output.numel()\n grid = lambda meta: (triton.cdiv(n_elements, meta[\"BLOCK_SIZE\"]),)\n add_kernel_autotuned[grid](x, y, output, n_elements)\n return output\n\nx = torch.randn(4, device=\"cuda\")\ny = torch.randn(4, device=\"cuda\")\nout = add_fn(x, y)\nprint(f\"Vector addition of\\nX:\\t{x}\\nY:\\t{y}\\nis equal to\\n{out}\")\n", - "description_1": "Use triton language to define two kernels 'add_kernel' and 'add_kernel_autotuned'. Both kernels take five parameters: two input pointers 'in_ptr0', 'in_ptr1', an output pointer 'out_ptr', the number of elements 'n_elements', and a block size 'BLOCK_SIZE' which is a compile-time constant for 'add_kernel'. The kernels compute element-wise addition of two vectors. The 'add_kernel_autotuned' includes multiple configurations for auto-tuning different block sizes, number of stages, and warps. 'add_fn' is a torch function which compiles the given tensors using these kernels and calculates the grid size to execute the kernels accordingly.", - "description_2": "Use triton language to create two vector addition kernels with both static and autotuned block sizes, integrating with torch.compile for computation on CUDA.", - "difficulty": 2 - }, - { - "code": "import triton\nimport triton.language as tl\nfrom torch._inductor.triton_heuristics import grid\nfrom torch import empty_strided, device\nfrom torch._C import _cuda_getCurrentRawStream as get_raw_stream\n\n# Kernel to perform pointwise addition and store results\n@triton.jit\ndef triton_poi_fused_0(in_ptr0, in_ptr1, out_ptr0, xnumel, XBLOCK: tl.constexpr):\n xnumel = 2097152\n xoffset = tl.program_id(0) * XBLOCK\n xindex = xoffset + tl.arange(0, XBLOCK)[:]\n xmask = xindex < xnumel\n x0 = xindex % 512\n x1 = (xindex // 512)\n x2 = xindex\n tmp0 = tl.load(in_ptr0 + (x0 + (1536 * x1)), None).to(tl.float32)\n tmp1 = tl.load(in_ptr1 + (x0), None, eviction_policy='evict_last').to(tl.float32)\n tmp2 = tmp0 + tmp1\n tl.store(out_ptr0 + (x2), tmp2, None)\n\n# Kernel to perform pointwise addition and store results\n@triton.jit\ndef triton_poi_fused_1(in_ptr0, in_ptr1, out_ptr0, xnumel, XBLOCK: tl.constexpr):\n xnumel = 2097152\n xoffset = tl.program_id(0) * XBLOCK\n xindex = xoffset + tl.arange(0, XBLOCK)[:]\n xmask = xindex < xnumel\n x0 = xindex % 512\n x1 = (xindex // 512)\n x2 = xindex\n tmp0 = tl.load(in_ptr0 + (512 + x0 + (1536 * x1)), None).to(tl.float32)\n tmp1 = tl.load(in_ptr1 + (512 + x0), None, eviction_policy='evict_last').to(tl.float32)\n tmp2 = tmp0 + tmp1\n tl.store(out_ptr0 + (x2), tmp2, None)\n\n# Kernel to perform pointwise addition and store results\n@triton.jit\ndef triton_poi_fused_2(in_ptr0, in_ptr1, out_ptr0, xnumel, XBLOCK: tl.constexpr):\n xnumel = 2097152\n xoffset = tl.program_id(0) * XBLOCK\n xindex = xoffset + tl.arange(0, XBLOCK)[:]\n xmask = xindex < xnumel\n x0 = xindex % 512\n x1 = (xindex // 512)\n x2 = xindex\n tmp0 = tl.load(in_ptr0 + (1024 + x0 + (1536 * x1)), None).to(tl.float32)\n tmp1 = tl.load(in_ptr1 + (1024 + x0), None, eviction_policy='evict_last').to(tl.float32)\n tmp2 = tmp0 + tmp1\n tl.store(out_ptr0 + (x2), tmp2, None)\n\n# Kernel for fused add and native layer normalization\n@triton.jit\ndef triton_per_fused_add_native_layer_norm_3(in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, out_ptr2, xnumel, rnumel):\n xnumel = 4096\n XBLOCK: tl.constexpr = 1\n rnumel = 512\n RBLOCK: tl.constexpr = 512\n xoffset = tl.program_id(0) * XBLOCK\n xindex = tl.full([1], xoffset, tl.int32)\n xmask = xindex < xnumel\n rindex = tl.arange(0, RBLOCK)[:]\n roffset = 0\n rmask = rindex < rnumel\n r1 = rindex\n x0 = xindex\n tmp0 = tl.load(in_ptr0 + (r1 + (512 * x0)), rmask, other=0.0).to(tl.float32)\n tmp1 = tl.load(in_ptr1 + (r1), rmask, eviction_policy='evict_last', other=0.0).to(tl.float32)\n tmp3 = tl.load(in_ptr2 + (r1 + (512 * x0)), rmask, other=0.0).to(tl.float32)\n tmp29 = tl.load(in_ptr3 + (r1), rmask, eviction_policy='evict_last', other=0.0).to(tl.float32)\n tmp32 = tl.load(in_ptr4 + (r1), rmask, eviction_policy='evict_last', other=0.0).to(tl.float32)\n tmp2 = tmp0 + tmp1\n tmp4 = tmp2 + tmp3\n tmp5 = tmp4.to(tl.float32)\n tmp6 = tl.broadcast_to(tmp5, [RBLOCK])\n tmp8 = tl.where(rmask, tmp6, 0)\n tmp9 = tl.broadcast_to(tmp6, [RBLOCK])\n tmp11 = tl.where(rmask, tmp9, 0)\n tmp12 = triton_helpers.promote_to_tensor(tl.sum(tmp11, 0))\n tmp13 = tl.full([1], 512, tl.int32)\n tmp14 = tmp13.to(tl.float32)\n tmp15 = tmp12 / tmp14\n tmp16 = tmp6 - tmp15\n tmp17 = tmp16 * tmp16\n tmp18 = tl.broadcast_to(tmp17, [RBLOCK])\n tmp20 = tl.where(rmask, tmp18, 0)\n tmp21 = triton_helpers.promote_to_tensor(tl.sum(tmp20, 0))\n tmp22 = tmp5 - tmp15\n tmp23 = 512.0\n tmp24 = tmp21 / tmp23\n tmp25 = 1e-05\n tmp26 = tmp24 + tmp25\n tmp27 = tl.math.rsqrt(tmp26)\n tmp28 = tmp22 * tmp27\n tmp30 = tmp29.to(tl.float32)\n tmp31 = tmp28 * tmp30\n tmp33 = tmp32.to(tl.float32)\n tmp34 = tmp31 + tmp33\n tmp35 = tmp34.to(tl.float32)\n tl.store(out_ptr2 + (r1 + (512 * x0)), tmp35, rmask)\n\n# Kernel for fused add and GELU activation\n@triton.jit\ndef triton_poi_fused_add_gelu_4(in_out_ptr0, in_ptr0, xnumel, XBLOCK: tl.constexpr):\n xnumel = 8388608\n xoffset = tl.program_id(0) * XBLOCK\n xindex = xoffset + tl.arange(0, XBLOCK)[:]\n xmask = xindex < xnumel\n x2 = xindex\n x0 = xindex % 2048\n tmp0 = tl.load(in_out_ptr0 + (x2), None).to(tl.float32)\n tmp1 = tl.load(in_ptr0 + (x0), None, eviction_policy='evict_last').to(tl.float32)\n tmp2 = tmp0 + tmp1\n tmp3 = tmp2.to(tl.float32)\n tmp4 = 0.5\n tmp5 = tmp3 * tmp4\n tmp6 = 0.7071067811865476\n tmp7 = tmp3 * tmp6\n tmp8 = tl.math.erf(tmp7)\n tmp9 = 1.0\n tmp10 = tmp8 + tmp9\n tmp11 = tmp5 * tmp10\n tmp12 = tmp11.to(tl.float32)\n tl.store(in_out_ptr0 + (x2), tmp12, None)\n\n# Function to invoke the Triton kernels\ndef call(args):\n arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1, arg7_1, arg8_1, arg9_1, arg10_1, arg11_1, arg12_1 = args\n args.clear()\n with torch.cuda._DeviceGuard(0):\n torch.cuda.set_device(0)\n buf0 = empty_strided((4096, 1536), (1536, 1), device='cuda', dtype=torch.float16)\n extern_kernels.mm(reinterpret_tensor(arg0_1, (4096, 512), (512, 1), 0), arg1_1, out=buf0)\n del arg1_1\n buf1 = empty_strided((16, 16, 256, 32), (131072, 32, 512, 1), device='cuda', dtype=torch.float16)\n stream0 = get_raw_stream(0)\n triton_poi_fused_0.run(buf0, arg2_1, buf1, 2097152, grid=grid(2097152), stream=stream0)\n buf2 = empty_strided((16, 16, 256, 32), (131072, 32, 512, 1), device='cuda', dtype=torch.float16)\n triton_poi_fused_1.run(buf0, arg2_1, buf2, 2097152, grid=grid(2097152), stream=stream0)\n buf3 = empty_strided((16, 16, 256, 32), (131072, 32, 512, 1), device='cuda', dtype=torch.float16)\n triton_poi_fused_2.run(buf0, arg2_1, buf3, 2097152, grid=grid(2097152), stream=stream0)\n del arg2_1\n del buf0\n buf4 = aten._scaled_dot_product_flash_attention.default(buf1, buf2, buf3, scale=0.17677669529663687)\n del buf1\n buf5 = buf4[0]\n del buf4\n buf10 = reinterpret_tensor(buf3, (4096, 512), (512, 1), 0); del buf3\n extern_kernels.mm(reinterpret_tensor(buf5, (4096, 512), (512, 1), 0), arg3_1, out=buf10)\n del arg3_1\n buf14 = reinterpret_tensor(buf5, (16, 256, 512), (131072, 512, 1), 0); del buf5\n triton_per_fused_add_native_layer_norm_3.run(buf10, arg4_1, arg0_1, arg5_1, arg6_1, buf14, 4096, 512, grid=grid(4096), stream=stream0)\n del arg0_1\n del arg4_1\n del arg5_1\n del arg6_1\n buf15 = empty_strided((4096, 2048), (2048, 1), device='cuda', dtype=torch.float16)\n extern_kernels.mm(reinterpret_tensor(buf14, (4096, 512), (512, 1), 0), arg7_1, out=buf15)\n del arg7_1\n buf16 = reinterpret_tensor(buf15, (16, 256, 2048), (524288, 2048, 1), 0); del buf15\n triton_poi_fused_add_gelu_4.run(buf16, arg8_1, 8388608, grid=grid(8388608), stream=stream0)\n del arg8_1\n buf17 = buf10; del buf10\n extern_kernels.mm(reinterpret_tensor(buf16, (4096, 2048), (2048, 1), 0), arg9_1, out=buf17)\n del arg9_1\n del buf16\n buf21 = reinterpret_tensor(buf2, (16, 256, 512), (131072, 512, 1), 0); del buf2\n triton_per_fused_add_native_layer_norm_3.run(buf17, arg10_1, buf14, arg11_1, arg12_1, buf21, 4096, 512, grid=grid(4096), stream=stream0)\n del arg10_1\n del arg11_1\n del arg12_1\n del buf14\n del buf17\n return (buf21, )\n\n", - "description_1": "Use triton language to implement multiple kernels for pointwise addition, layer normalization, and GELU activation. Each kernel uses triton.jit and is configured to execute on a CUDA device. Parameters include input/output pointers, grid configuration, and execution constraints.", - "description_2": "Use triton language to implement and run kernels for element-wise operations and normalization on GPU.", - "difficulty": 3 - }, - { - "code": "import triton\nimport triton.language as tl\nimport torch\nfrom torch._C import _cuda_getCurrentRawStream as get_raw_stream\nfrom torch._inductor.triton_heuristics import grid\nfrom torch._inductor.triton_heuristics import template, pointwise, persistent_reduction\nfrom torch._inductor.utils import instance_descriptor\nfrom torch._inductor import triton_helpers\n\n@template(\n num_stages=3,\n num_warps=8,\n triton_meta={'signature': {0: '*fp16', 1: '*fp16', 2: '*fp16'}, 'device': 0, 'device_type': 'cuda', 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2), equal_to_1=(), divisible_by_8=())]},\n inductor_meta={'kernel_name': 'triton_tem_fused_mm_0', 'backend_hash': '7e9a460acc4bd8827e2448ca0e8a42787e1dddb62b2cb1089d7ca1dcc9b86db3'},\n)\n@triton.jit\ndef triton_(arg_A, arg_B, out_ptr0):\n GROUP_M : tl.constexpr = 8\n EVEN_K : tl.constexpr = True\n ALLOW_TF32 : tl.constexpr = True\n ACC_TYPE : tl.constexpr = tl.float32\n B_PROLOGUE_CAST_TYPE : tl.constexpr = None\n BLOCK_M : tl.constexpr = 64\n BLOCK_N : tl.constexpr = 64\n BLOCK_K : tl.constexpr = 64\n\n A = arg_A\n B = arg_B\n\n M = 4096\n N = 1536\n K = 512\n if M * N == 0:\n return\n stride_am = 512\n stride_ak = 1\n stride_bk = 1536\n stride_bn = 1\n\n pid = tl.program_id(0)\n grid_m = (M + BLOCK_M - 1) // BLOCK_M\n grid_n = (N + BLOCK_N - 1) // BLOCK_N\n\n width = GROUP_M * grid_n\n group_id = pid // width\n group_size = min(grid_m - group_id * GROUP_M, GROUP_M)\n pid_m = group_id * GROUP_M + (pid % group_size)\n pid_n = (pid % width) // (group_size)\n\n rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)\n rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)\n ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)\n rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)\n rk = tl.arange(0, BLOCK_K)\n A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)\n B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)\n\n acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)\n for k in range(K, 0, -BLOCK_K):\n if EVEN_K:\n a = tl.load(A)\n b = tl.load(B)\n else:\n a = tl.load(A, mask=rk[None, :] < k, other=0.)\n b = tl.load(B, mask=rk[:, None] < k, other=0.)\n if B_PROLOGUE_CAST_TYPE is not None:\n b = b.to(B_PROLOGUE_CAST_TYPE)\n acc += tl.dot(a, b, allow_tf32=ALLOW_TF32)\n A += BLOCK_K * stride_ak\n B += BLOCK_K * stride_bk\n\n rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)\n rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)\n idx_m = rm[:, None]\n idx_n = rn[None, :]\n mask = (idx_m < M) & (idx_n < N)\n\n xindex = idx_n + (1536*idx_m)\n tl.store(out_ptr0 + (tl.broadcast_to(xindex, mask.shape)), acc, mask)\n\n\n@pointwise(\n size_hints=[2097152], \n filename=__file__,\n triton_meta={'signature': {0: '*fp16', 1: '*fp16', 2: '*fp16', 3: 'i32'}, 'device': 0, 'device_type': 'cuda', 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3), equal_to_1=(), divisible_by_8=(3,))]},\n inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_1', 'mutated_arg_names': [], 'no_x_dim': False, 'backend_hash': '7e9a460acc4bd8827e2448ca0e8a42787e1dddb62b2cb1089d7ca1dcc9b86db3'},\n min_elem_per_thread=0\n)\n@triton.jit\ndef triton_(in_ptr0, in_ptr1, out_ptr0, xnumel, XBLOCK : tl.constexpr):\n xnumel = 2097152\n xoffset = tl.program_id(0) * XBLOCK\n xindex = xoffset + tl.arange(0, XBLOCK)[:]\n xmask = xindex < xnumel\n x0 = xindex % 512\n x1 = (xindex // 512)\n x2 = xindex\n tmp0 = tl.load(in_ptr0 + (x0 + (1536*x1)), None).to(tl.float32)\n tmp1 = tl.load(in_ptr1 + (x0), None, eviction_policy='evict_last').to(tl.float32)\n tmp2 = tmp0 + tmp1\n tl.store(out_ptr0 + (x2), tmp2, None)\n\n\n@persistent_reduction(\n size_hints=[4096, 512],\n reduction_hint=ReductionHint.INNER,\n filename=__file__,\n triton_meta={'signature': {0: '*fp16', 1: '*fp16', 2: '*fp16', 3: '*fp16', 4: '*fp16', 5: '*fp16', 6: 'i32', 7: 'i32'}, 'device': 0, 'device_type': 'cuda', 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 6, 7), equal_to_1=(), divisible_by_8=(6, 7))]},\n inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_per_fused_add_native_layer_norm_4', 'mutated_arg_names': [], 'no_x_dim': True, 'backend_hash': '7e9a460acc4bd8827e2448ca0e8a42787e1dddb62b2cb1089d7ca1dcc9b86db3'}\n)\n@triton.jit\ndef triton_(in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, out_ptr2, xnumel, rnumel):\n xnumel = 4096\n XBLOCK: tl.constexpr = 1\n rnumel = 512\n RBLOCK: tl.constexpr = 512\n xoffset = tl.program_id(0) * XBLOCK\n xindex = tl.full([1], xoffset, tl.int32)\n xmask = xindex < xnumel\n rindex = tl.arange(0, RBLOCK)[:]\n roffset = 0\n rmask = rindex < rnumel\n r1 = rindex\n x0 = xindex\n tmp0 = tl.load(in_ptr0 + (r1 + (512*x0)), rmask, other=0.0).to(tl.float32)\n tmp1 = tl.load(in_ptr1 + (r1), rmask, eviction_policy='evict_last', other=0.0).to(tl.float32)\n tmp3 = tl.load(in_ptr2 + (r1 + (512*x0)), rmask, other=0.0).to(tl.float32)\n tmp29 = tl.load(in_ptr3 + (r1), rmask, eviction_policy='evict_last', other=0.0).to(tl.float32)\n tmp32 = tl.load(in_ptr4 + (r1), rmask, eviction_policy='evict_last', other=0.0).to(tl.float32)\n tmp2 = tmp0 + tmp1\n tmp4 = tmp2 + tmp3\n tmp5 = tmp4.to(tl.float32)\n tmp6 = tl.broadcast_to(tmp5, [RBLOCK])\n tmp8 = tl.where(rmask, tmp6, 0)\n tmp9 = tl.broadcast_to(tmp6, [RBLOCK])\n tmp11 = tl.where(rmask, tmp9, 0)\n tmp12 = triton_helpers.promote_to_tensor(tl.sum(tmp11, 0))\n tmp13 = tl.full([1], 512, tl.int32)\n tmp14 = tmp13.to(tl.float32)\n tmp15 = tmp12 / tmp14\n tmp16 = tmp6 - tmp15\n tmp17 = tmp16 * tmp16\n tmp18 = tl.broadcast_to(tmp17, [RBLOCK])\n tmp20 = tl.where(rmask, tmp18, 0)\n tmp21 = triton_helpers.promote_to_tensor(tl.sum(tmp20, 0))\n tmp22 = tmp5 - tmp15\n tmp23 = 512.0\n tmp24 = tmp21 / tmp23\n tmp25 = 1e-05\n tmp26 = tmp24 + tmp25\n tmp27 = tl.math.rsqrt(tmp26)\n tmp28 = tmp22 * tmp27\n tmp30 = tmp29.to(tl.float32)\n tmp31 = tmp28 * tmp30\n tmp33 = tmp32.to(tl.float32)\n tmp34 = tmp31 + tmp33\n tmp35 = tmp34.to(tl.float32)\n tl.store(out_ptr2 + (r1 + (512*x0)), tmp35, rmask)\n\n\n@pointwise(\n size_hints=[8388608], \n filename=__file__,\n triton_meta={'signature': {0: '*fp16', 1: '*fp16', 2: 'i32'}, 'device': 0, 'device_type': 'cuda', 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2), equal_to_1=(), divisible_by_8=(2,))]},\n inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_gelu_5', 'mutated_arg_names': ['in_out_ptr0'], 'no_x_dim': False, 'backend_hash': '7e9a460acc4bd8827e2448ca0e8a42787e1dddb62b2cb1089d7ca1dcc9b86db3'},\n min_elem_per_thread=0\n)\n@triton.jit\ndef triton_(in_out_ptr0, in_ptr0, xnumel, XBLOCK : tl.constexpr):\n xnumel = 8388608\n xoffset = tl.program_id(0) * XBLOCK\n xindex = xoffset + tl.arange(0, XBLOCK)[:]\n xmask = xindex < xnumel\n x2 = xindex\n x0 = xindex % 2048\n tmp0 = tl.load(in_out_ptr0 + (x2), None).to(tl.float32)\n tmp1 = tl.load(in_ptr0 + (x0), None, eviction_policy='evict_last').to(tl.float32)\n tmp2 = tmp0 + tmp1\n tmp3 = tmp2.to(tl.float32)\n tmp4 = 0.5\n tmp5 = tmp3 * tmp4\n tmp6 = 0.7071067811865476\n tmp7 = tmp3 * tmp6\n tmp8 = tl.math.erf(tmp7)\n tmp9 = 1.0\n tmp10 = tmp8 + tmp9\n tmp11 = tmp5 * tmp10\n tmp12 = tmp11.to(tl.float32)\n tl.store(in_out_ptr0 + (x2), tmp12, None)\n\n\ndef call(args):\n arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1, arg7_1, arg8_1, arg9_1, arg10_1, arg11_1, arg12_1 = args\n args.clear()\n with torch.cuda._DeviceGuard(0):\n torch.cuda.set_device(0)\n buf0 = torch.empty((4096, 1536), dtype=torch.float16, device='cuda')\n stream0 = get_raw_stream(0)\n triton_tem_fused_mm_0.run(arg0_1, arg1_1, buf0, grid=torch._inductor.kernel.mm_common.mm_grid(4096, 1536, meta0), stream=stream0)\n buf1 = torch.empty((16, 16, 256, 32), dtype=torch.float16, device='cuda')\n triton_poi_fused_1.run(buf0, arg2_1, buf1, 2097152, grid=grid(2097152), stream=stream0)\n buf2 = torch.empty((16, 16, 256, 32), dtype=torch.float16, device='cuda')\n triton_poi_fused_2.run(buf0, arg2_1, buf2, 2097152, grid=grid(2097152), stream=stream0)\n buf3 = torch.empty((16, 16, 256, 32), dtype=torch.float16, device='cuda')\n triton_poi_fused_3.run(buf0, arg2_1, buf3, 2097152, grid=grid(2097152), stream=stream0)\n buf4 = torch.ops.aten._scaled_dot_product_flash_attention.default(buf1, buf2, buf3, scale=0.17677669529663687)\n buf5 = buf4[0]\n buf10 = torch.ops.inductor._reinterpret_tensor(buf3, (4096, 512), (512, 1), 0)\n torch.ops.inductor.extern_kernels.mm(torch.ops.inductor._reinterpret_tensor(buf5, (4096, 512), (512, 1), 0), arg3_1, out=buf10)\n buf14 = torch.ops.inductor._reinterpret_tensor(buf5, (16, 256, 512), (131072, 512, 1), 0)\n triton_per_fused_add_native_layer_norm_4.run(buf10, arg4_1, arg0_1, arg5_1, arg6_1, buf14, 4096, 512, grid=grid(4096), stream=stream0)\n buf15 = torch.empty((4096, 2048), dtype=torch.float16, device='cuda')\n torch.ops.inductor.extern_kernels.mm(torch.ops.inductor._reinterpret_tensor(buf14, (4096, 512), (512, 1), 0), arg7_1, out=buf15)\n buf16 = torch.ops.inductor._reinterpret_tensor(buf15, (16, 256, 2048), (524288, 2048, 1), 0)\n triton_poi_fused_add_gelu_5.run(buf16, arg8_1, 8388608, grid=grid(8388608), stream=stream0)\n buf17 = buf10\n torch.ops.inductor.extern_kernels.mm(torch.ops.inductor._reinterpret_tensor(buf16, (4096, 2048), (2048, 1), 0), arg9_1, out=buf17)\n buf21 = torch.ops.inductor._reinterpret_tensor(buf2, (16, 256, 512), (131072, 512, 1), 0)\n triton_per_fused_add_native_layer_norm_4.run(buf17, arg10_1, buf14, arg11_1, arg12_1, buf21, 4096, 512, grid=grid(4096), stream=stream0)\n return (buf21, )\n", - "description_1": "Use triton language to implement a matrix multiplication kernel (triton_tem_fused_mm_0) with 3 input parameters (arg_A, arg_B, out_ptr0) for performing batched matrix multiplication with constants defining block sizes and accumulation precision. A second kernel (triton_poi_fused_1) adds two half-precision floating-point inputs and outputs the result. Additional kernels (triton_per_fused_add_native_layer_norm_4, triton_poi_fused_add_gelu_5) are implemented for layer normalization and GELU activation. The kernels are executed via a call function which manages CUDA streams and tensor memory allocations for a sequence of operations, including matrix multiplications and element-wise computations.", - "description_2": "Use triton language to define and run a sequence of CUDA kernels for optimized matrix multiplication, addition, layer normalization, and GELU activation.", - "difficulty": 4 - }, - { - "code": "import triton\nimport triton.language as tl\nimport torch\nfrom torch._C import _cuda_getCurrentRawStream as get_raw_stream\nfrom torch import empty_strided\n\n# Kernel: Matmul (Matrix Multiplication)\n@triton.jit\ndef triton_(arg_A, arg_B, out_ptr0):\n # Constants\n GROUP_M : tl.constexpr = 8\n EVEN_K : tl.constexpr = True\n ALLOW_TF32 : tl.constexpr = True\n ACC_TYPE : tl.constexpr = tl.float32\n BLOCK_M : tl.constexpr = 64\n BLOCK_N : tl.constexpr = 128\n BLOCK_K : tl.constexpr = 32\n\n A = arg_A\n B = arg_B\n\n M = 16384\n N = 1536\n K = 512\n\n if M * N == 0:\n return\n\n stride_am = 512\n stride_ak = 1\n stride_bk = 1536\n stride_bn = 1\n\n pid = tl.program_id(0)\n grid_m = (M + BLOCK_M - 1) // BLOCK_M\n grid_n = (N + BLOCK_N - 1) // BLOCK_N\n\n width = GROUP_M * grid_n\n group_id = pid // width\n group_size = min(grid_m - group_id * GROUP_M, GROUP_M)\n pid_m = group_id * GROUP_M + (pid % group_size)\n pid_n = (pid % width) // (group_size)\n\n rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)\n rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)\n ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)\n rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)\n rk = tl.arange(0, BLOCK_K)\n A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)\n B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)\n\n acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)\n for k in range(K, 0, -BLOCK_K):\n if EVEN_K:\n a = tl.load(A)\n b = tl.load(B)\n else:\n a = tl.load(A, mask=rk[None, :] < k, other=0.)\n b = tl.load(B, mask=rk[:, None] < k, other=0.)\n acc += tl.dot(a, b, allow_tf32=ALLOW_TF32)\n A += BLOCK_K * stride_ak\n B += BLOCK_K * stride_bk\n\n rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)\n rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)\n idx_m = rm[:, None]\n idx_n = rn[None, :]\n mask = (idx_m < M) & (idx_n < N)\n xindex = idx_n + (1536 * idx_m)\n tl.store(out_ptr0 + (tl.broadcast_to(xindex, mask.shape)), acc, mask)\n\ndef call(args):\n arg0_1, arg1_1, *_ = args\n buf0 = empty_strided((16384, 1536), (1536, 1), torch.float16)\n stream0 = get_raw_stream(0)\n triton_.run(arg0_1, arg1_1, buf0, grid=torch._inductor.kernel.mm_common.mm_grid(16384, 1536, {}), stream=stream0)\n return buf0\n", - "description_1": "Use triton language to implement a matrix multiplication kernel (triton_) with parameters arg_A, arg_B, out_ptr0. The grid and block dimensions, along with specific constants such as BLOCK_M, BLOCK_N, BLOCK_K, are defined. The kernel computes the matrix product of inputs A and B and stores the result in out_ptr0, with optimizations for memory access and performance tuning. The function call executes this kernel with input arguments and a CUDA stream.", - "description_2": "Use triton language to perform optimized matrix multiplication on GPU using custom grid and block settings with memory access optimizations.", - "difficulty": 3 - }, - { - "code": "import triton\nimport triton.language as tl\nfrom torch._inductor.triton_heuristics import grid\nfrom torch._C import _cuda_getCurrentRawStream as get_raw_stream\nfrom torch import empty_strided\n\n# Kernel 1\n@triton.jit\ndef triton_kernel_0(in_ptr0, in_ptr1, out_ptr0, xnumel, XBLOCK: tl.constexpr):\n xnumel = 8388608\n xoffset = tl.program_id(0) * XBLOCK\n xindex = xoffset + tl.arange(0, XBLOCK)[:]\n x0 = xindex % 512\n x1 = (xindex // 512)\n tmp0 = tl.load(in_ptr0 + (x0 + (1536 * x1)), None).to(tl.float32)\n tmp1 = tl.load(in_ptr1 + (x0), None, eviction_policy='evict_last').to(tl.float32)\n tmp2 = tmp0 + tmp1\n tl.store(out_ptr0 + xindex, tmp2, None)\n\n# Kernel 2\n@triton.jit\ndef triton_kernel_1(in_ptr0, in_ptr1, out_ptr0, xnumel, XBLOCK: tl.constexpr):\n xnumel = 8388608\n xoffset = tl.program_id(0) * XBLOCK\n xindex = xoffset + tl.arange(0, XBLOCK)[:]\n x0 = xindex % 512\n x1 = (xindex // 512)\n tmp0 = tl.load(in_ptr0 + (512 + x0 + (1536 * x1)), None).to(tl.float32)\n tmp1 = tl.load(in_ptr1 + (512 + x0), None, eviction_policy='evict_last').to(tl.float32)\n tmp2 = tmp0 + tmp1\n tl.store(out_ptr0 + xindex, tmp2, None)\n\n# Kernel 3\n@triton.jit\ndef triton_kernel_2(in_ptr0, in_ptr1, out_ptr0, xnumel, XBLOCK: tl.constexpr):\n xnumel = 8388608\n xoffset = tl.program_id(0) * XBLOCK\n xindex = xoffset + tl.arange(0, XBLOCK)[:]\n x0 = xindex % 512\n x1 = (xindex // 512)\n tmp0 = tl.load(in_ptr0 + (1024 + x0 + (1536 * x1)), None).to(tl.float32)\n tmp1 = tl.load(in_ptr1 + (1024 + x0), None, eviction_policy='evict_last').to(tl.float32)\n tmp2 = tmp0 + tmp1\n tl.store(out_ptr0 + xindex, tmp2, None)\n\n# Call function\ndef call(args):\n arg0, arg1, arg2 = args\n args.clear()\n with torch.cuda._DeviceGuard(0):\n torch.cuda.set_device(0)\n buf0 = empty_strided((16384, 1536), (1536, 1), torch.float16)\n stream0 = get_raw_stream(0)\n # Run Kernel 1\n triton_kernel_0.run(arg0, arg1, buf0, 8388608, grid=grid(8388608), stream=stream0)\n # Run Kernel 2\n triton_kernel_1.run(arg0, arg1, buf0, 8388608, grid=grid(8388608), stream=stream0)\n # Run Kernel 3\n triton_kernel_2.run(arg0, arg1, buf0, 8388608, grid=grid(8388608), stream=stream0)\n return buf0\n", - "description_1": "Use triton language to define three separate pointwise kernels each performing element-wise addition of two input tensors and storing the result in an output tensor. Each kernel handles a specific offset for its inputs and uses a grid-stride loop for processing large data sizes efficiently. The call function manages the execution of these kernels in a CUDA environment, leveraging streams for efficient GPU resource usage.", - "description_2": "Use triton language to define multiple kernels that perform vectorized addition of input tensors on a CUDA device, and execute these kernels sequentially using a high-level Python call function.", - "difficulty": 3 - }, - { - "code": "import triton\nimport triton.language as tl\nfrom torch._inductor.triton_heuristics import pointwise, persistent_reduction\nfrom torch._inductor.utils import grid\nfrom torch._C import _cuda_getCurrentRawStream as get_raw_stream\nfrom torch import empty_strided\n\n# Triton kernel for pointwise addition\n@pointwise(\n size_hints=[2097152],\n filename=__file__,\n triton_meta={'signature': {0: '*fp16', 1: '*fp16', 2: '*fp16', 3: 'i32'}, 'device': 0, 'device_type': 'cuda'},\n min_elem_per_thread=0\n)\n@triton.jit\ndef triton_poi_fused_0(in_ptr0, in_ptr1, out_ptr0, xnumel, XBLOCK: tl.constexpr):\n xnumel = 2097152\n xoffset = tl.program_id(0) * XBLOCK\n xindex = xoffset + tl.arange(0, XBLOCK)[:]\n xmask = xindex < xnumel\n x0 = xindex % 512\n x1 = (xindex // 512)\n x2 = xindex\n tmp0 = tl.load(in_ptr0 + (x0 + (1536 * x1)), None).to(tl.float32)\n tmp1 = tl.load(in_ptr1 + (x0), None, eviction_policy='evict_last').to(tl.float32)\n tmp2 = tmp0 + tmp1\n tl.store(out_ptr0 + (x2), tmp2, None)\n\n# Triton kernel for persistent reduction and layer normalization\n@persistent_reduction(\n size_hints=[4096, 512],\n reduction_hint=ReductionHint.INNER,\n filename=__file__,\n triton_meta={'signature': {0: '*fp16', 1: '*fp16', 2: '*fp16', 3: '*fp16', 4: '*fp16', 5: '*fp16', 6: 'i32', 7: 'i32'}, 'device': 0, 'device_type': 'cuda'},\n)\n@triton.jit\ndef triton_per_fused_add_native_layer_norm_3(in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, out_ptr2, xnumel, rnumel):\n xnumel = 4096\n XBLOCK: tl.constexpr = 1\n rnumel = 512\n RBLOCK: tl.constexpr = 512\n xoffset = tl.program_id(0) * XBLOCK\n xindex = tl.full([1], xoffset, tl.int32)\n xmask = xindex < xnumel\n rindex = tl.arange(0, RBLOCK)[:]\n rmask = rindex < rnumel\n r1 = rindex\n x0 = xindex\n tmp0 = tl.load(in_ptr0 + (r1 + (512 * x0)), rmask, other=0.0).to(tl.float32)\n tmp1 = tl.load(in_ptr1 + (r1), rmask, eviction_policy='evict_last', other=0.0).to(tl.float32)\n tmp3 = tl.load(in_ptr2 + (r1 + (512 * x0)), rmask, other=0.0).to(tl.float32)\n tmp29 = tl.load(in_ptr3 + (r1), rmask, eviction_policy='evict_last', other=0.0).to(tl.float32)\n tmp32 = tl.load(in_ptr4 + (r1), rmask, eviction_policy='evict_last', other=0.0).to(tl.float32)\n tmp2 = tmp0 + tmp1\n tmp4 = tmp2 + tmp3\n tmp5 = tmp4.to(tl.float32)\n tmp6 = tl.broadcast_to(tmp5, [RBLOCK])\n tmp8 = tl.where(rmask, tmp6, 0)\n tmp9 = tl.broadcast_to(tmp6, [RBLOCK])\n tmp11 = tl.where(rmask, tmp9, 0)\n tmp12 = triton_helpers.promote_to_tensor(tl.sum(tmp11, 0))\n tmp13 = tl.full([1], 512, tl.int32)\n tmp14 = tmp13.to(tl.float32)\n tmp15 = tmp12 / tmp14\n tmp16 = tmp6 - tmp15\n tmp17 = tmp16 * tmp16\n tmp18 = tl.broadcast_to(tmp17, [RBLOCK])\n tmp20 = tl.where(rmask, tmp18, 0)\n tmp21 = triton_helpers.promote_to_tensor(tl.sum(tmp20, 0))\n tmp22 = tmp5 - tmp15\n tmp23 = 512.0\n tmp24 = tmp21 / tmp23\n tmp25 = 1e-05\n tmp26 = tmp24 + tmp25\n tmp27 = tl.math.rsqrt(tmp26)\n tmp28 = tmp22 * tmp27\n tmp30 = tmp29.to(tl.float32)\n tmp31 = tmp28 * tmp30\n tmp33 = tmp32.to(tl.float32)\n tmp34 = tmp31 + tmp33\n tmp35 = tmp34.to(tl.float32)\n tl.store(out_ptr2 + (r1 + (512 * x0)), tmp35, rmask)\n\n# Triton kernel for pointwise addition and GELU activation\n@pointwise(\n size_hints=[8388608],\n filename=__file__,\n triton_meta={'signature': {0: '*fp16', 1: '*fp16', 2: 'i32'}, 'device': 0, 'device_type': 'cuda'},\n min_elem_per_thread=0\n)\n@triton.jit\ndef triton_poi_fused_add_gelu_4(in_out_ptr0, in_ptr0, xnumel, XBLOCK: tl.constexpr):\n xnumel = 8388608\n xoffset = tl.program_id(0) * XBLOCK\n xindex = xoffset + tl.arange(0, XBLOCK)[:]\n xmask = xindex < xnumel\n x2 = xindex\n x0 = xindex % 2048\n tmp0 = tl.load(in_out_ptr0 + (x2), None).to(tl.float32)\n tmp1 = tl.load(in_ptr0 + (x0), None, eviction_policy='evict_last').to(tl.float32)\n tmp2 = tmp0 + tmp1\n tmp3 = tmp2.to(tl.float32)\n tmp4 = 0.5\n tmp5 = tmp3 * tmp4\n tmp6 = 0.7071067811865476\n tmp7 = tmp3 * tmp6\n tmp8 = tl.math.erf(tmp7)\n tmp9 = 1.0\n tmp10 = tmp8 + tmp9\n tmp11 = tmp5 * tmp10\n tmp12 = tmp11.to(tl.float32)\n tl.store(in_out_ptr0 + (x2), tmp12, None)\n\ndef call(args):\n arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1, arg7_1, arg8_1, arg9_1, arg10_1, arg11_1, arg12_1 = args\n args.clear()\n with torch.cuda._DeviceGuard(0):\n torch.cuda.set_device(0)\n buf0 = empty_strided((4096, 1536), (1536, 1), torch.float16, device='cuda')\n buf1 = empty_strided((16, 16, 256, 32), (131072, 32, 512, 1), torch.float16, device='cuda')\n stream0 = get_raw_stream(0)\n triton_poi_fused_0.run(buf0, arg2_1, buf1, 2097152, grid=grid(2097152), stream=stream0)\n buf2 = empty_strided((16, 16, 256, 32), (131072, 32, 512, 1), torch.float16, device='cuda')\n triton_poi_fused_0.run(buf0, arg2_1, buf2, 2097152, grid=grid(2097152), stream=stream0)\n buf3 = empty_strided((16, 16, 256, 32), (131072, 32, 512, 1), torch.float16, device='cuda')\n triton_poi_fused_0.run(buf0, arg2_1, buf3, 2097152, grid=grid(2097152), stream=stream0)\n buf4 = aten._scaled_dot_product_flash_attention.default(buf1, buf2, buf3, scale=0.17677669529663687)\n buf5 = buf4[0]\n buf10 = reinterpret_tensor(buf3, (4096, 512), (512, 1), 0)\n extern_kernels.mm(reinterpret_tensor(buf5, (4096, 512), (512, 1), 0), arg3_1, out=buf10)\n buf14 = reinterpret_tensor(buf5, (16, 256, 512), (131072, 512, 1), 0)\n triton_per_fused_add_native_layer_norm_3.run(buf10, arg4_1, arg0_1, arg5_1, arg6_1, buf14, 4096, 512, grid=grid(4096), stream=stream0)\n buf15 = empty_strided((4096, 2048), (2048, 1), torch.float16, device='cuda')\n extern_kernels.mm(reinterpret_tensor(buf14, (4096, 512), (512, 1), 0), arg7_1, out=buf15)\n buf16 = reinterpret_tensor(buf15, (16, 256, 2048), (524288, 2048, 1), 0)\n triton_poi_fused_add_gelu_4.run(buf16, arg8_1, 8388608, grid=grid(8388608), stream=stream0)\n buf17 = buf10\n extern_kernels.mm(reinterpret_tensor(buf16, (4096, 2048), (2048, 1), 0), arg9_1, out=buf17)\n buf21 = reinterpret_tensor(buf2, (16, 256, 512), (131072, 512, 1), 0)\n triton_per_fused_add_native_layer_norm_3.run(buf17, arg10_1, buf14, arg11_1, arg12_1, buf21, 4096, 512, grid=grid(4096), stream=stream0)\n return (buf21, )\n", - "description_1": "Use triton language to implement multiple kernels for pointwise addition, persistent reduction with layer normalization, and GELU activation. The kernels handle operations on tensors with specific shapes and strides, utilizing CUDA for parallel execution.", - "description_2": "Use triton language to create kernels for tensor operations including addition, layer normalization, and GELU activation, optimized for CUDA execution.", - "difficulty": 4 - }, - { - "code": "import triton\nimport triton.language as tl\nfrom torch._inductor.triton_heuristics import pointwise, persistent_reduction\nfrom torch._inductor.utils import grid\nfrom torch._C import _cuda_getCurrentRawStream as get_raw_stream\nfrom torch import empty_strided\n\n@pointwise(\n size_hints=[2097152],\n filename=__file__,\n triton_meta={'signature': {0: '*fp16', 1: '*fp16', 2: '*fp16', 3: 'i32'}, 'device': 0, 'device_type': 'cuda'},\n min_elem_per_thread=0\n)\n@triton.jit\ndef triton_poi_fused_0(in_ptr0, in_ptr1, out_ptr0, xnumel, XBLOCK: tl.constexpr):\n xnumel = 2097152\n xoffset = tl.program_id(0) * XBLOCK\n xindex = xoffset + tl.arange(0, XBLOCK)[:]\n xmask = xindex < xnumel\n x0 = xindex % 512\n x1 = (xindex // 512)\n x2 = xindex\n tmp0 = tl.load(in_ptr0 + (x0 + (1536 * x1)), None).to(tl.float32)\n tmp1 = tl.load(in_ptr1 + (x0), None, eviction_policy='evict_last').to(tl.float32)\n tmp2 = tmp0 + tmp1\n tl.store(out_ptr0 + (x2), tmp2, None)\n\n@pointwise(\n size_hints=[2097152],\n filename=__file__,\n triton_meta={'signature': {0: '*fp16', 1: '*fp16', 2: '*fp16', 3: 'i32'}, 'device': 0, 'device_type': 'cuda'},\n min_elem_per_thread=0\n)\n@triton.jit\ndef triton_poi_fused_1(in_ptr0, in_ptr1, out_ptr0, xnumel, XBLOCK: tl.constexpr):\n xnumel = 2097152\n xoffset = tl.program_id(0) * XBLOCK\n xindex = xoffset + tl.arange(0, XBLOCK)[:]\n xmask = xindex < xnumel\n x0 = xindex % 512\n x1 = (xindex // 512)\n x2 = xindex\n tmp0 = tl.load(in_ptr0 + (512 + x0 + (1536 * x1)), None).to(tl.float32)\n tmp1 = tl.load(in_ptr1 + (512 + x0), None, eviction_policy='evict_last').to(tl.float32)\n tmp2 = tmp0 + tmp1\n tl.store(out_ptr0 + (x2), tmp2, None)\n\n@pointwise(\n size_hints=[2097152],\n filename=__file__,\n triton_meta={'signature': {0: '*fp16', 1: '*fp16', 2: '*fp16', 3: 'i32'}, 'device': 0, 'device_type': 'cuda'},\n min_elem_per_thread=0\n)\n@triton.jit\ndef triton_poi_fused_2(in_ptr0, in_ptr1, out_ptr0, xnumel, XBLOCK: tl.constexpr):\n xnumel = 2097152\n xoffset = tl.program_id(0) * XBLOCK\n xindex = xoffset + tl.arange(0, XBLOCK)[:]\n xmask = xindex < xnumel\n x0 = xindex % 512\n x1 = (xindex // 512)\n x2 = xindex\n tmp0 = tl.load(in_ptr0 + (1024 + x0 + (1536 * x1)), None).to(tl.float32)\n tmp1 = tl.load(in_ptr1 + (1024 + x0), None, eviction_policy='evict_last').to(tl.float32)\n tmp2 = tmp0 + tmp1\n tl.store(out_ptr0 + (x2), tmp2, None)\n\n@persistent_reduction(\n size_hints=[4096, 512],\n reduction_hint=ReductionHint.INNER,\n filename=__file__,\n triton_meta={'signature': {0: '*fp16', 1: '*fp16', 2: '*fp16', 3: '*fp16', 4: '*fp16', 5: '*fp16', 6: 'i32', 7: 'i32'}, 'device': 0, 'device_type': 'cuda'}\n)\n@triton.jit\ndef triton_per_fused_add_native_layer_norm_3(in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, out_ptr2, xnumel, rnumel):\n xnumel = 4096\n XBLOCK: tl.constexpr = 1\n rnumel = 512\n RBLOCK: tl.constexpr = 512\n xoffset = tl.program_id(0) * XBLOCK\n xindex = tl.full([1], xoffset, tl.int32)\n xmask = xindex < xnumel\n rindex = tl.arange(0, RBLOCK)[:]\n rmask = rindex < rnumel\n r1 = rindex\n x0 = xindex\n tmp0 = tl.load(in_ptr0 + (r1 + (512 * x0)), rmask, other=0.0).to(tl.float32)\n tmp1 = tl.load(in_ptr1 + (r1), rmask, eviction_policy='evict_last', other=0.0).to(tl.float32)\n tmp3 = tl.load(in_ptr2 + (r1 + (512 * x0)), rmask, other=0.0).to(tl.float32)\n tmp29 = tl.load(in_ptr3 + (r1), rmask, eviction_policy='evict_last', other=0.0).to(tl.float32)\n tmp32 = tl.load(in_ptr4 + (r1), rmask, eviction_policy='evict_last', other=0.0).to(tl.float32)\n tmp2 = tmp0 + tmp1\n tmp4 = tmp2 + tmp3\n tmp5 = tmp4.to(tl.float32)\n tmp6 = tl.broadcast_to(tmp5, [RBLOCK])\n tmp8 = tl.where(rmask, tmp6, 0)\n tmp9 = tl.broadcast_to(tmp6, [RBLOCK])\n tmp11 = tl.where(rmask, tmp9, 0)\n tmp12 = triton_helpers.promote_to_tensor(tl.sum(tmp11, 0))\n tmp13 = tl.full([1], 512, tl.int32)\n tmp14 = tmp13.to(tl.float32)\n tmp15 = tmp12 / tmp14\n tmp16 = tmp6 - tmp15\n tmp17 = tmp16 * tmp16\n tmp18 = tl.broadcast_to(tmp17, [RBLOCK])\n tmp20 = tl.where(rmask, tmp18, 0)\n tmp21 = triton_helpers.promote_to_tensor(tl.sum(tmp20, 0))\n tmp22 = tmp5 - tmp15\n tmp23 = 512.0\n tmp24 = tmp21 / tmp23\n tmp25 = 1e-05\n tmp26 = tmp24 + tmp25\n tmp27 = tl.math.rsqrt(tmp26)\n tmp28 = tmp22 * tmp27\n tmp30 = tmp29.to(tl.float32)\n tmp31 = tmp28 * tmp30\n tmp33 = tmp32.to(tl.float32)\n tmp34 = tmp31 + tmp33\n tmp35 = tmp34.to(tl.float32)\n tl.store(out_ptr2 + (r1 + (512 * x0)), tmp35, rmask)\n\n@pointwise(\n size_hints=[8388608],\n filename=__file__,\n triton_meta={'signature': {0: '*fp16', 1: '*fp16', 2: 'i32'}, 'device': 0, 'device_type': 'cuda'},\n min_elem_per_thread=0\n)\n@triton.jit\ndef triton_poi_fused_add_gelu_4(in_out_ptr0, in_ptr0, xnumel, XBLOCK: tl.constexpr):\n xnumel = 8388608\n xoffset = tl.program_id(0) * XBLOCK\n xindex = xoffset + tl.arange(0, XBLOCK)[:]\n xmask = xindex < xnumel\n x2 = xindex\n x0 = xindex % 2048\n tmp0 = tl.load(in_out_ptr0 + (x2), None).to(tl.float32)\n tmp1 = tl.load(in_ptr0 + (x0), None, eviction_policy='evict_last').to(tl.float32)\n tmp2 = tmp0 + tmp1\n tmp3 = tmp2.to(tl.float32)\n tmp4 = 0.5\n tmp5 = tmp3 * tmp4\n tmp6 = 0.7071067811865476\n tmp7 = tmp3 * tmp6\n tmp8 = tl.math.erf(tmp7)\n tmp9 = 1.0\n tmp10 = tmp8 + tmp9\n tmp11 = tmp5 * tmp10\n tmp12 = tmp11.to(tl.float32)\n tl.store(in_out_ptr0 + (x2), tmp12, None)\n\ndef call(args):\n arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1, arg7_1, arg8_1, arg9_1, arg10_1, arg11_1, arg12_1 = args\n args.clear()\n with torch.cuda._DeviceGuard(0):\n torch.cuda.set_device(0)\n buf0 = empty_strided((4096, 1536), (1536, 1), torch.float16, device='cuda')\n extern_kernels.mm(reinterpret_tensor(arg0_1, (4096, 512), (512, 1), 0), arg1_1, out=buf0)\n del arg1_1\n buf1 = empty_strided((16, 16, 256, 32), (131072, 32, 512, 1), torch.float16, device='cuda')\n stream0 = get_raw_stream(0)\n triton_poi_fused_0.run(buf0, arg2_1, buf1, 2097152, grid=grid(2097152), stream=stream0)\n buf2 = empty_strided((16, 16, 256, 32), (131072, 32, 512, 1), torch.float16, device='cuda')\n triton_poi_fused_1.run(buf0, arg2_1, buf2, 2097152, grid=grid(2097152), stream=stream0)\n buf3 = empty_strided((16, 16, 256, 32), (131072, 32, 512, 1), torch.float16, device='cuda')\n triton_poi_fused_2.run(buf0, arg2_1, buf3, 2097152, grid=grid(2097152), stream=stream0)\n del arg2_1\n del buf0\n buf4 = aten._scaled_dot_product_flash_attention.default(buf1, buf2, buf3, scale=0.17677669529663687)\n del buf1\n buf5 = buf4[0]\n del buf4\n buf10 = reinterpret_tensor(buf3, (4096, 512), (512, 1), 0); del buf3\n extern_kernels.mm(reinterpret_tensor(buf5, (4096, 512), (512, 1), 0), arg3_1, out=buf10)\n del arg3_1\n buf14 = reinterpret_tensor(buf5, (16, 256, 512), (131072, 512, 1), 0); del buf5\n triton_per_fused_add_native_layer_norm_3.run(buf10, arg4_1, arg0_1, arg5_1, arg6_1, buf14, 4096, 512, grid=grid(4096), stream=stream0)\n del arg0_1\n del arg4_1\n del arg5_1\n del arg6_1\n buf15 = empty_strided((4096, 2048), (2048, 1), torch.float16, device='cuda')\n extern_kernels.mm(reinterpret_tensor(buf14, (4096, 512), (512, 1), 0), arg7_1, out=buf15)\n del arg7_1\n buf16 = reinterpret_tensor(buf15, (16, 256, 2048), (524288, 2048, 1), 0); del buf15\n triton_poi_fused_add_gelu_4.run(buf16, arg8_1, 8388608, grid=grid(8388608), stream=stream0)\n del arg8_1\n buf17 = buf10; del buf10\n extern_kernels.mm(reinterpret_tensor(buf16, (4096, 2048), (2048, 1), 0), arg9_1, out=buf17)\n del arg9_1\n del buf16\n buf21 = reinterpret_tensor(buf2, (16, 256, 512), (131072, 512, 1), 0); del buf2\n triton_per_fused_add_native_layer_norm_3.run(buf17, arg10_1, buf14, arg11_1, arg12_1, buf21, 4096, 512, grid=grid(4096), stream=stream0)\n del arg10_1\n del arg11_1\n del arg12_1\n del buf14\n del buf17\n return (buf21, )\n", - "description_1": "Use triton language to define multiple kernels for pointwise and persistent reduction operations. The kernels perform operations such as element-wise addition, layer normalization, and GELU activation on input tensors. Each kernel is decorated with @triton.jit and uses triton.language for tensor operations. The call function orchestrates the execution of these kernels on CUDA devices, managing input and output buffers.", - "description_2": "Use triton language to implement CUDA kernels for tensor operations including addition, layer normalization, and GELU activation, and manage their execution on GPU.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef diag_ssm_forward_kernel(s_ptr, x_ptr, lambda_ptr, y_ptr, length,\n batch_size, dim, BLOCK_SIZE: tl.constexpr):\n \"\"\"\n Args:\n s_ptr: [batch_size, dim]\n x_ptr: [length, batch_size, dim]\n lambda_ptr: [dim]\n y_ptr: [length, batch_size, dim]\n \"\"\"\n col_idx = tl.program_id(0) * BLOCK_SIZE\n col_offsets = col_idx + tl.arange(0, BLOCK_SIZE)\n mask = col_offsets < batch_size * dim\n s = tl.load(s_ptr + col_offsets, mask=mask, other=0)\n Lambda = tl.load(lambda_ptr + col_offsets % dim, mask=mask, other=0)\n for t in range(length):\n offsets = t * batch_size * dim + col_offsets\n x = tl.load(x_ptr + offsets, mask=mask, other=0)\n s = s * Lambda + x\n tl.store(y_ptr + offsets, s, mask=mask)\n\n@triton.jit\ndef diag_ssm_backward_kernel(\n s_ptr, lambda_ptr, y_ptr, grad_s_ptr, grad_x_ptr, grad_lambda_ptr,\n grad_y_ptr, length, batch_size, dim, BLOCK_SIZE: tl.constexpr):\n \"\"\"\n Args:\n s_ptr: [batch_size, dim]\n lambda_ptr: [dim]\n y_ptr: [length, batch_size, dim]\n grad_s_ptr: [batch_size, dim]\n grad_x_ptr: [length, batch_size, dim]\n grad_lambda_ptr: [batch_size, dim]. The shape is different from ``grad_s_ptr``\n because we need the caller to sum the gradients after the kernel finish.\n It's more complicated to sum the gradients inside the kernel.\n grad_y_ptr: [length, batch_size, dim]\n \"\"\"\n\n col_idx = tl.program_id(0) * BLOCK_SIZE\n col_offsets = col_idx + tl.arange(0, BLOCK_SIZE)\n mask = col_offsets < batch_size * dim\n\n Lambda = tl.load(lambda_ptr + col_offsets % dim, mask=mask, other=0)\n\n # Initialize gradients to zero\n grad_s = tl.zeros_like(Lambda)\n grad_Lambda = tl.zeros_like(Lambda)\n\n for i in range(length):\n # range(length - 1, -1, -1) is not correctly implemented by Triton\n t = length - 1 - i\n offsets = t * batch_size * dim + col_offsets\n\n grad_y = tl.load(grad_y_ptr + offsets, mask=mask, other=0)\n if t > 0:\n s = tl.load(\n y_ptr + offsets - batch_size * dim, mask=mask, other=0)\n else:\n s = tl.load(s_ptr + col_offsets, mask=mask, other=0)\n\n grad_s = grad_y + grad_s\n grad_x = grad_s\n grad_Lambda += grad_s * s\n grad_s = grad_s * Lambda\n\n tl.store(grad_x_ptr + offsets, grad_x, mask=mask)\n\n tl.store(grad_s_ptr + col_offsets, grad_s, mask=mask)\n tl.store(grad_lambda_ptr + col_offsets, grad_Lambda, mask=mask)\n\n@triton.jit\ndef diag_ssm_forward_kernel_complex(s_ptr, x_ptr, y_ptr, lambda_ptr,\n length, batch_size, dim,\n BLOCK_SIZE: tl.constexpr):\n \"\"\"\n Args:\n s_ptr: [batch_size, dim, 2]\n x_ptr: [length, batch_size, dim, 2]\n lambda_ptr: [dim, 2]\n y_ptr: [length, batch_size, dim, 2]\n \"\"\"\n col_idx = tl.program_id(0) * BLOCK_SIZE\n col_offsets = col_idx + tl.arange(0, BLOCK_SIZE)\n mask = col_offsets < batch_size * dim\n\n # Load real and imaginary parts of 's' and 'Lambda'\n s_real = tl.load(s_ptr + col_offsets * 2, mask=mask, other=0)\n s_imag = tl.load(s_ptr + col_offsets * 2 + 1, mask=mask, other=0)\n lambda_real = tl.load(\n lambda_ptr + (col_offsets % dim) * 2, mask=mask, other=0)\n lambda_imag = tl.load(\n lambda_ptr + (col_offsets % dim) * 2 + 1, mask=mask, other=0)\n\n for t in range(length):\n offsets = (t * batch_size * dim + col_offsets) * 2\n # Load real and imaginary parts of 'x'\n x_real = tl.load(x_ptr + offsets, mask=mask, other=0)\n x_imag = tl.load(x_ptr + offsets + 1, mask=mask, other=0)\n\n # Complex multiplication and addition\n new_s_real = s_real * lambda_real - s_imag * lambda_imag + x_real\n new_s_imag = s_real * lambda_imag + s_imag * lambda_real + x_imag\n\n # Store the updated real and imaginary parts\n tl.store(y_ptr + offsets, new_s_real, mask=mask)\n tl.store(y_ptr + offsets + 1, new_s_imag, mask=mask)\n\n # Update s for the next iteration\n s_real, s_imag = new_s_real, new_s_imag\n\n@triton.jit\ndef diag_ssm_backward_kernel_complex(\n s_ptr, lambda_ptr, y_ptr, grad_s_ptr, grad_x_ptr, grad_lambda_ptr,\n grad_y_ptr, length, batch_size, dim, BLOCK_SIZE: tl.constexpr):\n \"\"\"\n Args:\n s_ptr: [batch_size, dim, 2]\n lambda_ptr: [dim, 2]\n y_ptr: [length, batch_size, dim, 2]\n grad_s_ptr: [batch_size, dim, 2]\n grad_x_ptr: [length, batch_size, dim, 2]\n grad_lambda_ptr: [batch_size, dim, 2]. The shape is different from ``grad_s_ptr``\n because we need the caller to sum the gradients after the kernel finish.\n It's more complicated to sum the gradients inside the kernel.\n grad_y_ptr: [length, batch_size, dim, 2]\n \"\"\"\n\n # autograd for complex numbers calculates \\partial f / \\partial z^*\n # so we need to take conjugate during the calculation.\n # https://pytorch.org/docs/stable/notes/autograd.html#autograd-for-complex-numbers\n # So in the following code, when we load/store the imaginary part of a gradient,\n # we need to negate it.\n\n col_idx = tl.program_id(0) * BLOCK_SIZE\n col_offsets = col_idx + tl.arange(0, BLOCK_SIZE)\n mask = col_offsets < batch_size * dim\n\n # Load real and imaginary parts of 's' and 'Lambda'\n lambda_real = tl.load(\n lambda_ptr + (col_offsets % dim) * 2, mask=mask, other=0)\n lambda_imag = tl.load(\n lambda_ptr + (col_offsets % dim) * 2 + 1, mask=mask, other=0)\n\n # Initialize gradients to zero\n grad_s_real = tl.zeros_like(lambda_real)\n grad_s_imag = tl.zeros_like(lambda_imag)\n grad_lambda_real = tl.zeros_like(lambda_real)\n grad_lambda_imag = tl.zeros_like(lambda_imag)\n\n for i in range(length):\n # range(length - 1, -1, -1) is not correctly implemented by Triton\n t = length - 1 - i\n offsets = (t * batch_size * dim + col_offsets) * 2\n\n grad_y_real = tl.load(grad_y_ptr + offsets, mask=mask, other=0)\n grad_y_imag = -tl.load(\n grad_y_ptr + offsets + 1, mask=mask, other=0)\n if t > 0:\n s_real = tl.load(\n y_ptr + offsets - 2 * batch_size * dim, mask=mask, other=0)\n s_imag = tl.load(\n y_ptr + offsets - 2 * batch_size * dim + 1,\n mask=mask,\n other=0)\n else:\n s_real = tl.load(s_ptr + 2 * col_offsets, mask=mask, other=0)\n s_imag = tl.load(\n s_ptr + 2 * col_offsets + 1, mask=mask, other=0)\n\n grad_s_real = grad_y_real + grad_s_real\n grad_s_imag = grad_y_imag + grad_s_imag\n grad_x_real = grad_s_real\n grad_x_imag = grad_s_imag\n grad_lambda_real += grad_s_real * s_real - grad_s_imag * s_imag\n grad_lambda_imag += grad_s_real * s_imag + grad_s_imag * s_real\n grad_s_real = grad_x_real * lambda_real - grad_x_imag * lambda_imag\n grad_s_imag = grad_x_real * lambda_imag + grad_x_imag * lambda_real\n\n tl.store(grad_x_ptr + offsets, grad_x_real, mask=mask)\n tl.store(grad_x_ptr + offsets + 1, -grad_x_imag, mask=mask)\n\n # Store the final gradients for s and Lambda\n tl.store(grad_s_ptr + col_offsets * 2, grad_s_real, mask=mask)\n tl.store(grad_s_ptr + col_offsets * 2 + 1, -grad_s_imag, mask=mask)\n tl.store(\n grad_lambda_ptr + col_offsets * 2, grad_lambda_real, mask=mask)\n tl.store(\n grad_lambda_ptr + col_offsets * 2 + 1,\n -grad_lambda_imag,\n mask=mask)\n\nclass _ssm_forward(torch.autograd.Function):\n # TODO use @triton.autotune to choose the best BLOCK_SIZE\n # BLOCK_SIZE = 128 seems work well for 3090\n BLOCK_SIZE = 128\n\n @staticmethod\n def forward(ctx, s, x, Lambda):\n assert s.is_contiguous() and x.is_contiguous(\n ) and Lambda.is_contiguous()\n length, batch_size, dim = x.shape\n n = batch_size * dim\n y = torch.zeros_like(x)\n grid = lambda meta: (triton.cdiv(n, meta['BLOCK_SIZE']), )\n\n if Lambda.dtype == torch.complex64:\n diag_ssm_forward_kernel_complex[grid](\n torch.view_as_real(s), torch.view_as_real(x),\n torch.view_as_real(y), torch.view_as_real(Lambda), length,\n batch_size, dim, _ssm_forward.BLOCK_SIZE)\n elif Lambda.dtype.is_floating_point:\n diag_ssm_forward_kernel[grid](s, x, Lambda, y, length,\n batch_size, dim,\n _ssm_forward.BLOCK_SIZE)\n else:\n raise ValueError(\"Unsupported dtype: %s\" % Lambda.dtype)\n ctx.save_for_backward(s, y, Lambda)\n return y\n\n @staticmethod\n def backward(ctx, grad_y):\n s, y, Lambda = ctx.saved_tensors\n length, batch_size, dim = y.shape\n grad_y = grad_y.contiguous()\n n = batch_size * dim\n grad_s = torch.empty_like(s)\n grad_x = torch.empty_like(grad_y)\n # Here grad_lambda stores the gradients of Lambda for each sample\n # in the batch. We will sum them up after the kernel finishes.\n grad_lambda = torch.empty_like(s)\n grid = lambda meta: (triton.cdiv(n, meta['BLOCK_SIZE']), )\n if Lambda.dtype == torch.complex64:\n diag_ssm_backward_kernel_complex[grid](\n torch.view_as_real(s), torch.view_as_real(Lambda),\n torch.view_as_real(y), torch.view_as_real(grad_s),\n torch.view_as_real(grad_x),\n torch.view_as_real(grad_lambda),\n torch.view_as_real(grad_y), length, batch_size, dim,\n _ssm_forward.BLOCK_SIZE)\n else:\n diag_ssm_backward_kernel[grid](\n s, Lambda, y, grad_s, grad_x, grad_lambda, grad_y, length,\n batch_size, dim, _ssm_forward.BLOCK_SIZE)\n return grad_s, grad_x, grad_lambda.sum(dim=0)\n\ndiag_ssm_forward_triton = _ssm_forward.apply\n\ndef diag_ssm_forward(s, x, Lambda):\n r\"\"\"Diagonal SSM forward pass\n\n Calculate :math:`y_t = Lambda * y_{t-1} + x_t` for t > 0\n and :math:`y_0 = Lambda * s + x_0`\n\n Args:\n s (torch.Tensor): shape is [batch_size, state_dim]\n x (torch.Tensor): shape is [length, batch_size, state_dim]\n Lambda (torch.Tensor): shape is [state_dim]\n Returns:\n torch.Tensor: y in the above equation. The shape is\n [length, batch_size, state_dim]\n \"\"\"\n if x.is_cuda:\n return diag_ssm_forward_triton(s, x, Lambda)\n else:\n return diag_ssm_forward_slow(s, x, Lambda)\n", - "description_1": "Use triton language to implement diagonal state-space model (SSM) forward and backward kernels for both real and complex numbers. The forward kernel computes the state update y_t = Lambda * y_{t-1} + x_t for a given length, batch size, and dimension. The backward kernel computes gradients for s, x, and Lambda. The complex version handles real and imaginary parts separately. The kernels are wrapped in a PyTorch autograd function for automatic differentiation.", - "description_2": "Use triton language to create kernels for diagonal SSM forward and backward passes, supporting both real and complex data, and integrate with PyTorch autograd.", - "difficulty": 2 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.autotune(\n configs=[\n triton.Config({'BLOCK_N': 32}),\n triton.Config({'BLOCK_N': 64}),\n triton.Config({'BLOCK_N': 128}),\n triton.Config({'BLOCK_N': 256}),\n triton.Config({'BLOCK_N': 512}),\n triton.Config({'BLOCK_N': 1024}),\n ],\n key=['ncols'],\n)\n@triton.jit\ndef _swiglu_fwd_kernel(\n X, Y, OUT, stride_x_row, stride_y_row, stride_out_row, ncols, BLOCK_N: tl.constexpr\n):\n # Map the program id to the row of X and Y it should compute.\n row = tl.program_id(0)\n start_col = tl.program_id(1) * BLOCK_N\n X += row * stride_x_row\n Y += row * stride_y_row\n OUT += row * stride_out_row\n cols = start_col + tl.arange(0, BLOCK_N)\n x = tl.load(X + cols, mask=cols < ncols, other=0.).to(tl.float32)\n y = tl.load(Y + cols, mask=cols < ncols, other=0.).to(tl.float32)\n out = x * tl.sigmoid(x) * y\n tl.store(OUT + cols, out, mask=cols < ncols)\n\ndef _swiglu_fwd(xy, out=None):\n if xy.stride(-1) != 1:\n xy = xy.contiguous()\n batch_shape = xy.shape[:-1]\n xy = xy.reshape(-1, xy.shape[-1])\n x, y = xy.chunk(2, dim=-1)\n if out is None:\n out = torch.empty_like(x)\n else:\n out = out.reshape(-1, out.shape[-1])\n assert out.shape == x.shape\n assert out.stride(-1) == 1\n M, N = x.shape\n grid = lambda META: (M, triton.cdiv(N, META['BLOCK_N']))\n with torch.cuda.device(x.device.index):\n _swiglu_fwd_kernel[grid](x, y, out, x.stride(0), y.stride(0), out.stride(0), N)\n return out.reshape(*batch_shape, out.shape[-1])\n\n@triton.autotune(\n configs=[\n triton.Config({'BLOCK_N': 32}),\n triton.Config({'BLOCK_N': 64}),\n triton.Config({'BLOCK_N': 128}),\n triton.Config({'BLOCK_N': 256}),\n triton.Config({'BLOCK_N': 512}),\n triton.Config({'BLOCK_N': 1024}),\n ],\n key=['ncols'],\n)\n@triton.heuristics({\"RECOMPUTE_OUTPUT\": lambda args: args[\"OUT\"] is not None})\n@triton.jit\ndef _swiglu_bwd_kernel(\n X, Y, DOUT, OUT, DX, DY, stride_x_row, stride_y_row, stride_dout_row, stride_out_row,\n stride_dx_row, stride_dy_row, ncols, BLOCK_N: tl.constexpr, RECOMPUTE_OUTPUT: tl.constexpr\n):\n # Map the program id to the row of X and Y it should compute.\n row = tl.program_id(0)\n start_col = tl.program_id(1) * BLOCK_N\n X += row * stride_x_row\n Y += row * stride_y_row\n DOUT += row * stride_dout_row\n if RECOMPUTE_OUTPUT:\n OUT += row * stride_out_row\n DX += row * stride_dx_row\n DY += row * stride_dy_row\n cols = start_col + tl.arange(0, BLOCK_N)\n x = tl.load(X + cols, mask=cols < ncols, other=0.).to(tl.float32)\n y = tl.load(Y + cols, mask=cols < ncols, other=0.).to(tl.float32)\n dout = tl.load(DOUT + cols, mask=cols < ncols, other=0.).to(tl.float32)\n x_sigmoid = tl.sigmoid(x)\n dx = x_sigmoid * (1 + x * (1 - x_sigmoid)) * y * dout\n dy = x * x_sigmoid * dout\n tl.store(DX + cols, dx, mask=cols < ncols)\n tl.store(DY + cols, dy, mask=cols < ncols)\n if RECOMPUTE_OUTPUT:\n out = x * x_sigmoid * y\n tl.store(OUT + cols, out, mask=cols < ncols)\n\ndef _swiglu_bwd(xy, dout, dxy=None, recompute_output=False, out=None):\n if xy.stride(-1) != 1:\n xy = xy.contiguous()\n if dout.stride(-1) != 1:\n dout = dout.contiguous()\n batch_shape = xy.shape[:-1]\n xy = xy.reshape(-1, xy.shape[-1])\n x, y = xy.chunk(2, dim=-1)\n dout = dout.reshape(-1, dout.shape[-1])\n assert dout.shape == x.shape\n if dxy is None:\n dxy = torch.empty_like(xy)\n else:\n dxy = dxy.reshape(-1, dxy.shape[-1])\n assert dxy.shape == xy.shape\n dx, dy = dxy.chunk(2, dim=-1)\n assert dx.stride(-1) == 1\n assert dy.stride(-1) == 1\n if recompute_output:\n if out is None:\n out = torch.empty_like(x)\n else:\n out = out.reshape(-1, out.shape[-1])\n assert out.shape == x.shape\n assert out.stride(-1) == 1\n M, N = x.shape\n grid = lambda META: (M, triton.cdiv(N, META['BLOCK_N']))\n with torch.cuda.device(x.device.index):\n _swiglu_bwd_kernel[grid](x, y, dout, out if recompute_output else None, dx, dy,\n x.stride(0), y.stride(0), dout.stride(0),\n out.stride(0) if recompute_output else 0,\n dx.stride(0), dy.stride(0),\n N)\n if not recompute_output:\n return dxy.reshape(*batch_shape, dxy.shape[-1])\n else:\n return dxy.reshape(*batch_shape, dxy.shape[-1]), out.reshape(*batch_shape, out.shape[-1])\n", - "description_1": "Use triton language to implement forward and backward kernels for the SwiGLU activation function. The forward kernel (_swiglu_fwd_kernel) takes 7 parameters: X, Y, OUT, stride_x_row, stride_y_row, stride_out_row, and ncols, with BLOCK_N as a compile-time constant. It computes the element-wise product of X and Y after applying the sigmoid function to X, storing the result in OUT. The backward kernel (_swiglu_bwd_kernel) takes 14 parameters: X, Y, DOUT, OUT, DX, DY, stride_x_row, stride_y_row, stride_dout_row, stride_out_row, stride_dx_row, stride_dy_row, ncols, and RECOMPUTE_OUTPUT, with BLOCK_N as a compile-time constant. It computes the gradients of X and Y with respect to the output gradient DOUT, optionally recomputing the output if RECOMPUTE_OUTPUT is true.", - "description_2": "Use triton language to create kernels for computing the forward and backward passes of the SwiGLU activation, handling input and output tensors with specific strides and block sizes.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n\nconfigs_autotune = [\n triton.Config({}, num_warps=1),\n triton.Config({}, num_warps=2),\n triton.Config({}, num_warps=4),\n triton.Config({}, num_warps=8),\n triton.Config({}, num_warps=16),\n triton.Config({}, num_warps=32),\n]\n\n\n@triton.autotune(\n configs=configs_autotune,\n key=[\"N\", \"HAS_RESIDUAL\", \"STORE_RESIDUAL_OUT\", \"IS_RMS_NORM\", \"HAS_BIAS\"],\n)\n@triton.heuristics({\"HAS_X1\": lambda args: args[\"X1\"] is not None})\n@triton.heuristics({\"HAS_W1\": lambda args: args[\"W1\"] is not None})\n@triton.heuristics({\"HAS_B1\": lambda args: args[\"B1\"] is not None})\n@triton.jit\ndef _layer_norm_fwd_1pass_kernel(\n X, Y, W, B, RESIDUAL, X1, W1, B1, Y1, RESIDUAL_OUT, ROWSCALE, SEEDS, DROPOUT_MASK,\n Mean, Rstd, stride_x_row, stride_y_row, stride_res_row, stride_res_out_row, stride_x1_row,\n stride_y1_row, M, N, eps, dropout_p, IS_RMS_NORM: tl.constexpr, BLOCK_N: tl.constexpr,\n HAS_RESIDUAL: tl.constexpr, STORE_RESIDUAL_OUT: tl.constexpr, HAS_BIAS: tl.constexpr,\n HAS_DROPOUT: tl.constexpr, STORE_DROPOUT_MASK: tl.constexpr, HAS_ROWSCALE: tl.constexpr,\n HAS_X1: tl.constexpr, HAS_W1: tl.constexpr, HAS_B1: tl.constexpr\n):\n row = tl.program_id(0)\n X += row * stride_x_row\n Y += row * stride_y_row\n if HAS_RESIDUAL:\n RESIDUAL += row * stride_res_row\n if STORE_RESIDUAL_OUT:\n RESIDUAL_OUT += row * stride_res_out_row\n if HAS_X1:\n X1 += row * stride_x1_row\n if HAS_W1:\n Y1 += row * stride_y1_row\n cols = tl.arange(0, BLOCK_N)\n x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)\n if HAS_ROWSCALE:\n rowscale = tl.load(ROWSCALE + row).to(tl.float32)\n x *= rowscale\n if HAS_DROPOUT:\n keep_mask = tl.rand(tl.load(SEEDS + row).to(tl.uint32), cols, n_rounds=7) > dropout_p\n x = tl.where(keep_mask, x / (1.0 - dropout_p), 0.0)\n if STORE_DROPOUT_MASK:\n tl.store(DROPOUT_MASK + row * N + cols, keep_mask, mask=cols < N)\n if HAS_X1:\n x1 = tl.load(X1 + cols, mask=cols < N, other=0.0).to(tl.float32)\n if HAS_ROWSCALE:\n rowscale = tl.load(ROWSCALE + M + row).to(tl.float32)\n x1 *= rowscale\n if HAS_DROPOUT:\n keep_mask = (\n tl.rand(tl.load(SEEDS + M + row).to(tl.uint32), cols, n_rounds=7) > dropout_p\n )\n x1 = tl.where(keep_mask, x1 / (1.0 - dropout_p), 0.0)\n if STORE_DROPOUT_MASK:\n tl.store(DROPOUT_MASK + (M + row) * N + cols, keep_mask, mask=cols < N)\n x += x1\n if HAS_RESIDUAL:\n residual = tl.load(RESIDUAL + cols, mask=cols < N, other=0.0).to(tl.float32)\n x += residual\n if STORE_RESIDUAL_OUT:\n tl.store(RESIDUAL_OUT + cols, x, mask=cols < N)\n if not IS_RMS_NORM:\n mean = tl.sum(x, axis=0) / N\n tl.store(Mean + row, mean)\n xbar = tl.where(cols < N, x - mean, 0.0)\n var = tl.sum(xbar * xbar, axis=0) / N\n else:\n xbar = tl.where(cols < N, x, 0.0)\n var = tl.sum(xbar * xbar, axis=0) / N\n rstd = 1 / tl.sqrt(var + eps)\n tl.store(Rstd + row, rstd)\n mask = cols < N\n w = tl.load(W + cols, mask=mask).to(tl.float32)\n if HAS_BIAS:\n b = tl.load(B + cols, mask=mask).to(tl.float32)\n x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd\n y = x_hat * w + b if HAS_BIAS else x_hat * w\n tl.store(Y + cols, y, mask=mask)\n if HAS_W1:\n w1 = tl.load(W1 + cols, mask=mask).to(tl.float32)\n if HAS_B1:\n b1 = tl.load(B1 + cols, mask=mask).to(tl.float32)\n y1 = x_hat * w1 + b1 if HAS_B1 else x_hat * w1\n tl.store(Y1 + cols, y1, mask=mask)\n\n\ndef _layer_norm_fwd(\n x, weight, bias, eps, residual=None, x1=None, weight1=None, bias1=None,\n dropout_p=0.0, rowscale=None, out_dtype=None, residual_dtype=None, is_rms_norm=False,\n return_dropout_mask=False\n):\n M, N = x.shape\n y = torch.empty_like(x, dtype=x.dtype if out_dtype is None else out_dtype)\n y1 = torch.empty_like(y) if weight1 is not None else None\n if (\n residual is not None\n or (residual_dtype is not None and residual_dtype != x.dtype)\n or dropout_p > 0.0\n or rowscale is not None\n or x1 is not None\n ):\n residual_out = torch.empty(\n M, N, device=x.device, dtype=residual_dtype if residual_dtype is not None else x.dtype\n )\n else:\n residual_out = None\n mean = torch.empty((M,), dtype=torch.float32, device=x.device) if not is_rms_norm else None\n rstd = torch.empty((M,), dtype=torch.float32, device=x.device)\n if dropout_p > 0.0:\n seeds = torch.randint(\n 2**32, (M if x1 is None else 2 * M,), device=x.device, dtype=torch.int64\n )\n else:\n seeds = None\n if return_dropout_mask and dropout_p > 0.0:\n dropout_mask = torch.empty(M if x1 is None else 2 * M, N, device=x.device, dtype=torch.bool)\n else:\n dropout_mask = None\n MAX_FUSED_SIZE = 65536 // x.element_size()\n BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))\n with torch.cuda.device(x.device.index):\n _layer_norm_fwd_1pass_kernel[(M,)](\n x, y, weight, bias, residual, x1, weight1, bias1, y1, residual_out, rowscale,\n seeds, dropout_mask, mean, rstd, x.stride(0), y.stride(0),\n residual.stride(0) if residual is not None else 0,\n residual_out.stride(0) if residual_out is not None else 0,\n x1.stride(0) if x1 is not None else 0,\n y1.stride(0) if y1 is not None else 0,\n M, N, eps, dropout_p, is_rms_norm, BLOCK_N, residual is not None,\n residual_out is not None, bias is not None, dropout_p > 0.0,\n dropout_mask is not None, rowscale is not None, x1 is not None\n )\n if dropout_mask is not None and x1 is not None:\n dropout_mask, dropout_mask1 = dropout_mask.tensor_split(2, dim=0)\n else:\n dropout_mask1 = None\n return (\n y, y1, mean, rstd, residual_out if residual_out is not None else x, seeds, dropout_mask, dropout_mask1\n )\n", - "description_1": "Use triton language to implement forward pass of layer normalization with support for residual connections, dropout, optional biases, row scaling, and parallel operations using block-wise operations.", - "description_2": "Use triton language to prepare and execute the forward pass layer normalization kernel, handling tensor allocations and configurations for fused operations, supporting dropout and parallelism.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n\n@triton.heuristics({\"HAS_BIAS\": lambda args: args[\"B\"] is not None})\n@triton.heuristics({\"HAS_Z\": lambda args: args[\"Z\"] is not None})\n@triton.jit\ndef _layer_norm_fwd_1pass_kernel(\n X, # pointer to the input\n Y, # pointer to the output\n W, # pointer to the weights\n B, # pointer to the biases\n Z, # pointer to the other branch\n Mean, # pointer to the mean\n Rstd, # pointer to the 1/std\n stride_x_row, # how much to increase the pointer when moving by 1 row\n stride_y_row,\n stride_z_row,\n M, # number of rows in X\n N, # number of columns in X\n eps, # epsilon to avoid division by zero\n BLOCK_N: tl.constexpr,\n HAS_BIAS: tl.constexpr,\n HAS_Z: tl.constexpr,\n NORM_BEFORE_GATE: tl.constexpr,\n IS_RMS_NORM: tl.constexpr,\n):\n # Map the program id to the row of X and Y it should compute.\n row = tl.program_id(0)\n group = tl.program_id(1)\n X += row * stride_x_row + group * N\n Y += row * stride_y_row + group * N\n if HAS_Z:\n Z += row * stride_z_row + group * N\n if not IS_RMS_NORM:\n Mean += group * M\n Rstd += group * M\n W += group * N\n if HAS_BIAS:\n B += group * N\n # Compute mean and variance\n cols = tl.arange(0, BLOCK_N)\n x = tl.load(X + cols, mask=cols < N, other=0.).to(tl.float32)\n if HAS_Z and not NORM_BEFORE_GATE:\n z = tl.load(Z + cols, mask=cols < N).to(tl.float32)\n x *= z * tl.sigmoid(z)\n if not IS_RMS_NORM:\n mean = tl.sum(x, axis=0) / N\n tl.store(Mean + row, mean)\n xbar = tl.where(cols < N, x - mean, 0.)\n var = tl.sum(xbar * xbar, axis=0) / N\n else:\n xbar = tl.where(cols < N, x, 0.)\n var = tl.sum(xbar * xbar, axis=0) / N\n rstd = 1 / tl.sqrt(var + eps)\n tl.store(Rstd + row, rstd)\n # Normalize and apply linear transformation\n mask = cols < N\n w = tl.load(W + cols, mask=mask).to(tl.float32)\n if HAS_BIAS:\n b = tl.load(B + cols, mask=mask).to(tl.float32)\n x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd\n y = x_hat * w + b if HAS_BIAS else x_hat * w\n if HAS_Z and NORM_BEFORE_GATE:\n z = tl.load(Z + cols, mask=mask).to(tl.float32)\n y *= z * tl.sigmoid(z)\n # Write output\n tl.store(Y + cols, y, mask=mask)\n\n\ndef _layer_norm_fwd(x, weight, bias, eps, z=None, out=None, group_size=None, norm_before_gate=True, is_rms_norm=False):\n M, N = x.shape\n if group_size is None:\n group_size = N\n assert N % group_size == 0\n ngroups = N // group_size\n assert x.stride(-1) == 1\n if z is not None:\n assert z.stride(-1) == 1\n assert z.shape == (M, N)\n assert weight.shape == (N,)\n assert weight.stride(-1) == 1\n if bias is not None:\n assert bias.stride(-1) == 1\n assert bias.shape == (N,)\n # allocate output\n if out is not None:\n assert out.shape == x.shape\n else:\n out = torch.empty_like(x)\n assert out.stride(-1) == 1\n mean = torch.empty((ngroups * M, ), dtype=torch.float32, device=x.device) if not is_rms_norm else None\n rstd = torch.empty((ngroups * M, ), dtype=torch.float32, device=x.device)\n # Less than 64KB per feature: enqueue fused kernel\n MAX_FUSED_SIZE = 65536 // x.element_size()\n BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(group_size))\n if group_size > BLOCK_N:\n raise RuntimeError(\"This layer norm doesn't support feature dim >= 64KB.\")\n # heuristics for number of warps\n num_warps = min(max(BLOCK_N // 256, 1), 8)\n grid = (M, ngroups)\n with torch.cuda.device(x.device.index):\n _layer_norm_fwd_1pass_kernel[grid](x, out, weight, bias, z, mean, rstd,\n x.stride(0), out.stride(0), z.stride(0) if z is not None else 0,\n M, group_size, eps,\n BLOCK_N=BLOCK_N,\n NORM_BEFORE_GATE=norm_before_gate,\n IS_RMS_NORM=is_rms_norm,\n num_warps=num_warps)\n return out, mean, rstd\n\n\n@triton.heuristics({\"HAS_BIAS\": lambda args: args[\"B\"] is not None})\n@triton.heuristics({\"HAS_Z\": lambda args: args[\"Z\"] is not None})\n@triton.heuristics({\"RECOMPUTE_OUTPUT\": lambda args: args[\"Y\"] is not None})\n@triton.jit\ndef _layer_norm_bwd_kernel(\n X, # pointer to the input\n W, # pointer to the weights\n B, # pointer to the biases\n Z, # pointer to the other branch\n Y, # pointer to the output to be recomputed\n DY, # pointer to the output gradient\n DX, # pointer to the input gradient\n DW, # pointer to the partial sum of weights gradient\n DB, # pointer to the partial sum of biases gradient\n DZ, # pointer to the other branch\n Mean, # pointer to the mean\n Rstd, # pointer to the 1/std\n stride_x_row, # how much to increase the pointer when moving by 1 row\n stride_z_row,\n stride_y_row,\n stride_dy_row,\n stride_dx_row,\n stride_dz_row,\n stride_dw_row,\n stride_db_row,\n M, # number of rows in X\n N, # number of columns in X\n eps, # epsilon to avoid division by zero\n rows_per_program,\n NORM_BEFORE_GATE: tl.constexpr,\n IS_RMS_NORM: tl.constexpr,\n HAS_BIAS: tl.constexpr,\n HAS_Z: tl.constexpr,\n RECOMPUTE_OUTPUT: tl.constexpr,\n BLOCK_N: tl.constexpr,\n):\n # Map the program id to the elements of X, DX, and DY it should compute.\n row_block_id = tl.program_id(0)\n group = tl.program_id(1)\n row_start = row_block_id * rows_per_program\n cols = tl.arange(0, BLOCK_N)\n mask = cols < N\n X += row_start * stride_x_row + group * N\n if HAS_Z:\n Z += row_start * stride_z_row + group * N\n DZ += row_start * stride_dz_row + group * N\n DY += row_start * stride_dy_row + group * N\n DX += row_start * stride_dx_row + group * N\n if RECOMPUTE_OUTPUT:\n Y += row_start * stride_y_row + group * N\n if not IS_RMS_NORM:\n Mean += group * M\n Rstd += group * M\n W += group * N\n w = tl.load(W + cols, mask=mask).to(tl.float32)\n if (RECOMPUTE_OUTPUT or HAS_Z) and HAS_BIAS:\n B += group * N\n b = tl.load(B + cols, mask=mask, other=0.).to(tl.float32)\n dw = tl.zeros((BLOCK_N,), dtype=tl.float32)\n if HAS_BIAS:\n db = tl.zeros((BLOCK_N,), dtype=tl.float32)\n row_end = min((row_block_id + 1) * rows_per_program, M)\n for row in range(row_start, row_end):\n # Load data to SRAM\n x = tl.load(X + cols, mask=mask, other=0).to(tl.float32)\n dy = tl.load(DY + cols, mask=mask, other=0).to(tl.float32)\n if not IS_RMS_NORM:\n mean = tl.load(Mean + row)\n if HAS_Z and not NORM_BEFORE_GATE:\n z = tl.load(Z + cols, mask=mask, other=0.).to(tl.float32)\n x_og = x\n x = x_og * z * tl.sigmoid(z)\n rstd = tl.load(Rstd + row)\n # Compute dx\n xhat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd\n xhat = tl.where(mask, xhat, 0.)\n if HAS_Z and NORM_BEFORE_GATE:\n z = tl.load(Z + cols, mask=mask, other=0.).to(tl.float32)\n z_sigmoid = tl.sigmoid(z)\n y = xhat * w + b if HAS_BIAS else xhat * w\n if RECOMPUTE_OUTPUT:\n tl.store(Y + cols, y * z * z_sigmoid, mask=mask)\n dz = dy * y * z_sigmoid * (1 + z * (1 - z_sigmoid))\n tl.store(DZ + cols, dz, mask=mask)\n dy *= z * z_sigmoid\n else:\n if RECOMPUTE_OUTPUT:\n y = xhat * w + b if HAS_BIAS else xhat * w\n tl.store(Y + cols, y, mask=mask)\n wdy = w * dy\n c1 = tl.sum(xhat * wdy, axis=0) / N\n if not IS_RMS_NORM:\n c2 = tl.sum(wdy, axis=0) / N\n dx = (wdy - (xhat * c1 + c2)) * rstd\n else:\n dx = (wdy - xhat * c1) * rstd\n dw += dy * xhat\n if HAS_BIAS:\n db += dy\n if HAS_Z and not NORM_BEFORE_GATE:\n z_sigmoid = tl.sigmoid(z)\n dz = dx * x_og * z_sigmoid * (1 + z * (1 - z_sigmoid))\n tl.store(DZ + cols, dz, mask=mask)\n dx *= z * z_sigmoid\n # Write dx\n tl.store(DX + cols, dx, mask=mask)\n\n X += stride_x_row\n if HAS_Z:\n Z += stride_z_row\n DZ += stride_dz_row\n if RECOMPUTE_OUTPUT:\n Y += stride_y_row\n DY += stride_dy_row\n DX += stride_dx_row\n tl.store(DW + row_block_id * stride_dw_row + group * N + cols, dw, mask=mask)\n if HAS_BIAS:\n tl.store(DB + row_block_id * stride_db_row + group * N + cols, db, mask=mask)\n\n\ndef _layer_norm_bwd(dy, x, weight, bias, eps, mean, rstd, z=None, group_size=None,\n norm_before_gate=True, is_rms_norm=False, recompute_output=False, dz=None, out=None):\n M, N = x.shape\n if group_size is None:\n group_size = N\n assert N % group_size == 0\n ngroups = N // group_size\n assert x.stride(-1) == 1\n assert dy.stride(-1) == 1\n assert dy.shape == (M, N)\n if z is not None:\n assert z.stride(-1) == 1\n assert z.shape == (M, N)\n assert weight.shape == (N,)\n assert weight.stride(-1) == 1\n if bias is not None:\n assert bias.stride(-1) == 1\n assert bias.shape == (N,)\n # allocate output\n dx = torch.empty_like(x)\n if dz is not None:\n assert z is not None\n assert dz.shape == z.shape\n assert dz.stride(-1) == 1\n else:\n dz = torch.empty_like(z) if z is not None else None\n if recompute_output:\n if out is None:\n out = torch.empty_like(x)\n assert out.shape == x.shape\n\n # Less than 64KB per feature: enqueue fused kernel\n MAX_FUSED_SIZE = 65536 // x.element_size()\n BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(group_size))\n if group_size > BLOCK_N:\n raise RuntimeError(\"This layer norm doesn't support feature dim >= 64KB.\")\n # heuristics for number of warps\n num_warps = min(max(BLOCK_N // 256, 1), 8)\n sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count\n # If group size is small (e.g., 64), we're only using 1 warp. So having just 108 programs\n # would limit the occupancy.\n nrow_groups = math.ceil(sm_count * math.ceil(4 / num_warps) / ngroups)\n _dw = torch.empty((nrow_groups, N), dtype=torch.float32, device=weight.device)\n _db = torch.empty((nrow_groups, N), dtype=torch.float32, device=bias.device) if bias is not None else None\n rows_per_program = math.ceil(M / nrow_groups)\n grid = (nrow_groups, ngroups)\n with torch.cuda.device(x.device.index):\n _layer_norm_bwd_kernel[grid](x, weight, bias, z, out if recompute_output else None,\n dy, dx, _dw, _db, dz, mean, rstd,\n x.stride(0),\n z.stride(0) if z is not None else 0,\n 0 if not recompute_output else out.stride(0),\n dy.stride(0), dx.stride(0),\n dz.stride(0) if dz is not None else 0,\n _dw.stride(0),\n _db.stride(0) if _db is not None else 0,\n M, group_size, eps,\n rows_per_program,\n BLOCK_N=BLOCK_N,\n NORM_BEFORE_GATE=norm_before_gate,\n IS_RMS_NORM=is_rms_norm,\n num_warps=num_warps)\n dw = _dw.sum(0).to(weight.dtype)\n db = _db.sum(0).to(bias.dtype) if bias is not None else None\n return (dx, dw, db, dz) if not recompute_output else (dx, dw, db, dz, out)\n", - "description_1": "Use triton language to implement a forward pass kernel and a backward pass kernel for layer normalization. The forward pass kernel (_layer_norm_fwd_1pass_kernel) has 16 parameters including pointers to input, output, weights, biases, another branch, mean, and 1/std, various strides, number of rows and columns, epsilon, and several constant expressions. It computes the mean, variance, normalization, and applies linear transformations. The backward pass kernel (_layer_norm_bwd_kernel) has 26 parameters including pointers to input, weights, biases, branches, output, gradients, mean, 1/std, various strides, number of rows and columns, epsilon, and several constant expressions. It computes the gradient of the inputs, weights, biases, and other branches. The calling functions allocate necessary memory, handle the input shapes, and determine the grid configuration for execution on the GPU.", - "description_2": "Use triton language to create kernels for both forward and backward pass of a layer normalization operation, optimizing for dimensions up to a certain size, handling optional biases and other branches, and executing efficiently on the GPU with appropriate grid and block configurations.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\nfrom ssd.bi.softplus import softplus\n\n@triton.jit\ndef _selective_scan_update_kernel(\n state_ptr, x_ptr, dt_ptr, dt_bias_ptr, A_ptr, B_ptr, C_ptr, D_ptr, z_ptr, out_ptr,\n batch, nheads, dim, dstate, nheads_ngroups_ratio,\n stride_state_batch, stride_state_head, stride_state_dim, stride_state_dstate,\n stride_x_batch, stride_x_head, stride_x_dim,\n stride_dt_batch, stride_dt_head, stride_dt_dim,\n stride_dt_bias_head, stride_dt_bias_dim,\n stride_A_head, stride_A_dim, stride_A_dstate,\n stride_B_batch, stride_B_group, stride_B_dstate,\n stride_C_batch, stride_C_group, stride_C_dstate,\n stride_D_head, stride_D_dim,\n stride_z_batch, stride_z_head, stride_z_dim,\n stride_out_batch, stride_out_head, stride_out_dim,\n DT_SOFTPLUS: tl.constexpr,\n TIE_HDIM: tl.constexpr,\n BLOCK_SIZE_M: tl.constexpr,\n HAS_DT_BIAS: tl.constexpr,\n HAS_D: tl.constexpr,\n HAS_Z: tl.constexpr,\n BLOCK_SIZE_DSTATE: tl.constexpr,\n):\n pid_m = tl.program_id(axis=0)\n pid_b = tl.program_id(axis=1)\n pid_h = tl.program_id(axis=2)\n state_ptr += pid_b * stride_state_batch + pid_h * stride_state_head\n x_ptr += pid_b * stride_x_batch + pid_h * stride_x_head\n dt_ptr += pid_b * stride_dt_batch + pid_h * stride_dt_head\n if HAS_DT_BIAS:\n dt_bias_ptr += pid_h * stride_dt_bias_head\n A_ptr += pid_h * stride_A_head\n B_ptr += pid_b * stride_B_batch + (pid_h // nheads_ngroups_ratio) * stride_B_group\n C_ptr += pid_b * stride_C_batch + (pid_h // nheads_ngroups_ratio) * stride_C_group\n if HAS_Z:\n z_ptr += pid_b * stride_z_batch + pid_h * stride_z_head\n out_ptr += pid_b * stride_out_batch + pid_h * stride_out_head\n\n offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_n = tl.arange(0, BLOCK_SIZE_DSTATE)\n state_ptrs = state_ptr + (offs_m[:, None] * stride_state_dim + offs_n[None, :] * stride_state_dstate)\n x_ptrs = x_ptr + offs_m * stride_x_dim\n dt_ptrs = dt_ptr + offs_m * stride_dt_dim\n if HAS_DT_BIAS:\n dt_bias_ptrs = dt_bias_ptr + offs_m * stride_dt_bias_dim\n if HAS_D:\n D_ptr += pid_h * stride_D_head\n A_ptrs = A_ptr + (offs_m[:, None] * stride_A_dim + offs_n[None, :] * stride_A_dstate)\n B_ptrs = B_ptr + offs_n * stride_B_dstate\n C_ptrs = C_ptr + offs_n * stride_C_dstate\n if HAS_D:\n D_ptrs = D_ptr + offs_m * stride_D_dim\n if HAS_Z:\n z_ptrs = z_ptr + offs_m * stride_z_dim\n out_ptrs = out_ptr + offs_m * stride_out_dim\n\n state = tl.load(state_ptrs, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate), other=0.0)\n x = tl.load(x_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)\n if not TIE_HDIM:\n dt = tl.load(dt_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)\n if HAS_DT_BIAS:\n dt += tl.load(dt_bias_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)\n if DT_SOFTPLUS:\n dt = softplus(dt)\n A = tl.load(A_ptrs, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate), other=0.0).to(tl.float32)\n dA = tl.exp(A * dt[:, None])\n else:\n dt = tl.load(dt_ptr).to(tl.float32)\n if HAS_DT_BIAS:\n dt += tl.load(dt_bias_ptr).to(tl.float32)\n if DT_SOFTPLUS:\n dt = softplus(dt)\n A = tl.load(A_ptr).to(tl.float32)\n dA = tl.exp(A * dt)\n\n B = tl.load(B_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32)\n C = tl.load(C_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32)\n if HAS_D:\n D = tl.load(D_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)\n if HAS_Z:\n z = tl.load(z_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)\n\n if not TIE_HDIM:\n dB = B[None, :] * dt[:, None]\n else:\n dB = B * dt\n state = state * dA + dB * x[:, None]\n tl.store(state_ptrs, state, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate))\n out = tl.sum(state * C[None, :], axis=1)\n if HAS_D:\n out += x * D\n if HAS_Z:\n out *= z * tl.sigmoid(z)\n tl.store(out_ptrs, out, mask=offs_m < dim)\n\ndef selective_state_update(state, x, dt, A, B, C, D=None, z=None, dt_bias=None, dt_softplus=False):\n has_heads = state.dim() > 3\n if state.dim() == 3:\n state = state.unsqueeze(1)\n if x.dim() == 2:\n x = x.unsqueeze(1)\n if dt.dim() == 2:\n dt = dt.unsqueeze(1)\n if A.dim() == 2:\n A = A.unsqueeze(0)\n if B.dim() == 2:\n B = B.unsqueeze(1)\n if C.dim() == 2:\n C = C.unsqueeze(1)\n if D is not None and D.dim() == 1:\n D = D.unsqueeze(0)\n if z is not None and z.dim() == 2:\n z = z.unsqueeze(1)\n if dt_bias is not None and dt_bias.dim() == 1:\n dt_bias = dt_bias.unsqueeze(0)\n batch, nheads, dim, dstate = state.shape\n assert x.shape == (batch, nheads, dim)\n assert dt.shape == x.shape\n assert A.shape == (nheads, dim, dstate)\n ngroups = B.shape[1]\n assert nheads % ngroups == 0, \"nheads must be divisible by ngroups\"\n assert B.shape == (batch, ngroups, dstate)\n assert C.shape == B.shape\n if D is not None:\n assert D.shape == (nheads, dim)\n if z is not None:\n assert z.shape == x.shape\n if dt_bias is not None:\n assert dt_bias.shape == (nheads, dim)\n out = torch.empty_like(x)\n grid = lambda META: (triton.cdiv(dim, META['BLOCK_SIZE_M']), batch, nheads)\n z_strides = ((z.stride(0), z.stride(1), z.stride(2)) if z is not None else (0, 0, 0))\n BLOCK_SIZE_M, num_warps = ((32, 4) if dstate <= 16\n else ((16, 4) if dstate <= 32 else\n ((8, 4) if dstate <= 64 else\n ((4, 4) if dstate <= 128 else\n ((4, 8))))))\n tie_hdim = A.stride(-1) == 0 and A.stride(-2) == 0 and dt.stride(-1) == 0 and dt_bias.stride(-1) == 0\n with torch.cuda.device(x.device.index):\n _selective_scan_update_kernel[grid](\n state, x, dt, dt_bias, A, B, C, D, z, out,\n batch, nheads, dim, dstate, nheads // ngroups,\n state.stride(0), state.stride(1), state.stride(2), state.stride(3),\n x.stride(0), x.stride(1), x.stride(2),\n dt.stride(0), dt.stride(1), dt.stride(2),\n *(dt_bias.stride(0), dt_bias.stride(1)) if dt_bias is not None else 0,\n A.stride(0), A.stride(1), A.stride(2),\n B.stride(0), B.stride(1), B.stride(2),\n C.stride(0), C.stride(1), C.stride(2),\n *(D.stride(0), D.stride(1)) if D is not None else 0,\n z_strides[0], z_strides[1], z_strides[2],\n out.stride(0), out.stride(1), out.stride(2),\n dt_softplus,\n tie_hdim,\n BLOCK_SIZE_M,\n num_warps=num_warps,\n )\n if not has_heads:\n out = out.squeeze(1)\n return out\n", - "description_1": "Use triton language to implement a kernel function '_selective_scan_update_kernel' with 44 parameters for updating state matrices with optional bias and scaling, and a wrapper function 'selective_state_update' with 10 parameters to prepare and call the kernel.", - "description_2": "Use triton language to create a kernel for matrix state updates with optional bias and scaling, and a wrapper to manage inputs and call the kernel.", - "difficulty": 4 - }, - { - "code": "import triton\nimport triton.language as tl\nfrom packaging import version\n\nTRITON3 = version.parse(triton.__version__) >= version.parse(\"3.0.0\")\n\nif TRITON3:\n @triton.jit\n def softplus(dt):\n # Apply the softplus function using Triton 3.0.0 or newer\n dt = tl.where(dt <= 20.0, tl.math.log(tl.math.exp(dt) + 1), dt)\n return dt\nelse:\n @triton.jit\n def softplus(dt):\n # Apply the softplus function using Triton versions older than 3.0.0\n dt = tl.where(dt <= 20.0, tl.math.log1p(tl.exp(dt)), dt)\n return dt\n", - "description_1": "Use triton language to implement a softplus function kernel that takes one parameter 'dt'. The kernel applies the softplus operation, which is defined differently based on the Triton version. For Triton 3.0.0 or newer, it uses 'tl.math.log(tl.math.exp(dt) + 1)', and for older versions, it uses 'tl.math.log1p(tl.exp(dt))'.", - "description_2": "Use triton language to implement a version-dependent softplus function kernel with one parameter.", - "difficulty": 2 - }, - { - "code": "import math\nimport torch\nimport triton\nimport triton.language as tl\n\n@triton.autotune(\n configs=[\n triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64}, num_stages=3, num_warps=8),\n triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2),\n triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2),\n triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=2),\n ],\n key=['chunk_size', 'K', 'IS_CAUSAL'],\n)\n@triton.jit\ndef _bmm_chunk_fwd_kernel(\n a_ptr, b_ptr, out_ptr, seq_idx_ptr,\n seqlen, chunk_size, K, ngroups,\n stride_a_batch, stride_a_seqlen, stride_a_head, stride_ak,\n stride_b_batch, stride_b_seqlen, stride_b_head, stride_bk,\n stride_out_batch, stride_out_chunk, stride_out_head, stride_outm, stride_outn,\n stride_seq_idx_batch, stride_seq_idx_seqlen,\n IS_CAUSAL: tl.constexpr,\n dot_dtype: tl.constexpr,\n HAS_SEQ_IDX: tl.constexpr,\n BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,\n):\n pid_b = tl.program_id(axis=1)\n pid_ch = tl.program_id(axis=2)\n pid_c = pid_ch // ngroups\n pid_h = pid_ch - pid_c * ngroups\n num_pid_n = tl.cdiv(chunk_size, BLOCK_SIZE_N)\n pid_m = tl.program_id(axis=0) // num_pid_n\n pid_n = tl.program_id(axis=0) % num_pid_n\n if IS_CAUSAL:\n if pid_n * BLOCK_SIZE_N >= (pid_m + 1) * BLOCK_SIZE_M:\n return\n a_ptr += pid_b * stride_a_batch + pid_c * chunk_size * stride_a_seqlen + pid_h * stride_a_head\n b_ptr += pid_b * stride_b_batch + pid_c * chunk_size * stride_b_seqlen + pid_h * stride_b_head\n if HAS_SEQ_IDX:\n seq_idx_ptr += pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen\n\n offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n offs_k = tl.arange(0, BLOCK_SIZE_K)\n a_ptrs = a_ptr + (offs_m[:, None] * stride_a_seqlen + offs_k[None, :] * stride_ak)\n b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_n[None, :] * stride_b_seqlen)\n chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)\n\n acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):\n a = tl.load(a_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < K - k * BLOCK_SIZE_K), other=0.0).to(dot_dtype)\n b = tl.load(b_ptrs, mask=(offs_k[:, None] < K - k * BLOCK_SIZE_K) & (offs_n[None, :] < chunk_size_limit), other=0.0).to(dot_dtype)\n acc += tl.dot(a, b)\n a_ptrs += BLOCK_SIZE_K * stride_ak\n b_ptrs += BLOCK_SIZE_K * stride_bk\n\n offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n if HAS_SEQ_IDX:\n chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)\n seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen, mask=offs_m < chunk_size_limit, other=-1)\n seq_idx_n = tl.load(seq_idx_ptr + offs_n * stride_seq_idx_seqlen, mask=offs_n < chunk_size_limit, other=-2)\n acc = tl.where(seq_idx_m[:, None] == seq_idx_n[None, :], acc, 0.0)\n out = acc.to(out_ptr.dtype.element_ty)\n\n out_ptr += pid_b * stride_out_batch + pid_c * stride_out_chunk + pid_h * stride_out_head\n out_ptrs = out_ptr + (stride_outm * offs_m[:, None] + offs_n[None, :] * stride_outn)\n tl.store(out_ptrs, out, mask=(offs_m[:, None] < chunk_size) & (offs_n[None, :] < chunk_size))\n\n\n@triton.autotune(\n configs=[\n triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_CS': 64}, num_stages=3, num_warps=8),\n triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_CS': 32}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_CS': 32}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_CS': 32}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_CS': 32}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_CS': 32}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_CS': 32}, num_stages=5, num_warps=2),\n triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_CS': 32}, num_stages=5, num_warps=2),\n triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_CS': 32}, num_stages=4, num_warps=2),\n ],\n key=['chunk_size', 'K'],\n)\n@triton.jit\ndef _bmm_chunk_bwd_kernel(\n a_ptr, dout_ptr, db_ptr, res_ptr,\n seqlen, chunk_size, K, ngroups,\n stride_a_batch, stride_a_seqlen, stride_a_head, stride_ak,\n stride_dout_batch, stride_dout_chunk, stride_dout_head, stride_dout_csize_m, stride_dout_csize_n,\n stride_db_batch, stride_db_seqlen, stride_db_head, stride_db_k,\n stride_res_batch, stride_res_seqlen, stride_res_head, stride_res_k,\n dot_dtype: tl.constexpr,\n HAS_RESIDUAL: tl.constexpr,\n BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_CS: tl.constexpr,\n):\n pid_b = tl.program_id(axis=1)\n pid_ch = tl.program_id(axis=2)\n pid_c = pid_ch // ngroups\n pid_h = pid_ch - pid_c * ngroups\n num_pid_n = tl.cdiv(K, BLOCK_SIZE_N)\n pid_m = tl.program_id(axis=0) // num_pid_n\n pid_n = tl.program_id(axis=0) % num_pid_n\n\n a_ptr += pid_b * stride_a_batch + pid_c * chunk_size * stride_a_seqlen + pid_h * stride_a_head\n dout_ptr += pid_b * stride_dout_batch + pid_c * stride_dout_chunk + pid_h * stride_dout_head\n\n offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n offs_cs = tl.arange(0, BLOCK_SIZE_CS)\n dout_ptrs = dout_ptr + (offs_m[:, None] * stride_dout_csize_n + offs_cs[None, :] * stride_dout_csize_m)\n a_ptrs = a_ptr + (offs_cs[:, None] * stride_a_seqlen + offs_n[None, :] * stride_ak)\n chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)\n\n acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n for cs in range(0, tl.cdiv(chunk_size_limit, BLOCK_SIZE_CS)):\n dout = tl.load(dout_ptrs, mask=(offs_m[:, None] < chunk_size) & (offs_cs[None, :] < chunk_size_limit - cs * BLOCK_SIZE_CS), other=0.0).to(dot_dtype)\n a = tl.load(a_ptrs, mask=(offs_cs[:, None] < chunk_size_limit - cs * BLOCK_SIZE_CS) & (offs_n[None, :] < K), other=0.0).to(dot_dtype)\n acc += tl.dot(dout, a)\n dout_ptrs += BLOCK_SIZE_CS * stride_dout_csize_m\n a_ptrs += BLOCK_SIZE_CS * stride_a_seqlen\n\n offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n if HAS_RESIDUAL:\n res_ptr += pid_b * stride_res_batch + pid_c * chunk_size * stride_res_seqlen + pid_h * stride_res_head\n res_ptrs = res_ptr + (offs_m[:, None] * stride_res_seqlen + offs_n[None, :] * stride_res_k)\n res = tl.load(res_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < K)).to(tl.float32)\n acc += res\n db = acc.to(db_ptr.dtype.element_ty)\n\n db_ptr += pid_b * stride_db_batch + pid_c * chunk_size * stride_db_seqlen + pid_h * stride_db_head\n db_ptrs = db_ptr + (offs_m[:, None] * stride_db_seqlen + offs_n[None, :] * stride_db_k)\n tl.store(db_ptrs, db, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < K))\n\n\ndef _bmm_chunk_fwd(a, b, chunk_size, seq_idx=None, causal=False, output_dtype=None):\n has_groups = a.dim() == 4\n if not has_groups:\n batch, seqlen, k = a.shape\n else:\n batch, seqlen, ngroups, k = a.shape\n assert b.shape == a.shape\n if seq_idx is not None:\n assert seq_idx.shape == (batch, seqlen)\n if a.stride(-1) != 1 and a.stride(1) != 1:\n a = a.contiguous()\n if b.stride(-1) != 1 and b.stride(1) != 1:\n b = b.contiguous()\n nchunks = math.ceil(seqlen / chunk_size)\n out_dtype = a.dtype if output_dtype is None else output_dtype\n out = torch.empty((batch, nchunks, chunk_size, chunk_size) if not has_groups else (batch, nchunks, ngroups, chunk_size, chunk_size),\n device=a.device, dtype=out_dtype)\n dot_dtype = (tl.bfloat16 if a.dtype == torch.bfloat16 or b.dtype == torch.bfloat16 else\n (tl.float16 if a.dtype == torch.float16 or b.dtype == torch.float16 else tl.float32))\n grid = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv(chunk_size, META['BLOCK_SIZE_N']),\n batch, nchunks if not has_groups else nchunks * ngroups)\n with torch.cuda.device(a.device.index):\n _bmm_chunk_fwd_kernel[grid](\n a, b, out, seq_idx,\n seqlen, chunk_size, k, ngroups if has_groups else 1,\n a.stride(0), a.stride(1), 0 if not has_groups else a.stride(2), a.stride(-1),\n b.stride(0), b.stride(1), 0 if not has_groups else b.stride(2), b.stride(-1),\n out.stride(0), out.stride(1), 0 if not has_groups else out.stride(2), out.stride(-2), out.stride(-1),\n *((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)),\n causal,\n dot_dtype,\n HAS_SEQ_IDX=seq_idx is not None,\n )\n return out\n\n\ndef _bmm_chunk_bwd(a, dout, residual=None, out=None):\n has_groups = a.dim() == 4\n if not has_groups:\n batch, seqlen, k = a.shape\n else:\n batch, seqlen, ngroups, k = a.shape\n nchunks, chunk_size = dout.shape[1], dout.shape[-1]\n if a.stride(-1) != 1 and a.stride(-2) != 1:\n a = a.contiguous()\n if dout.stride(-1) != 1 and dout.stride(-2) != 1:\n dout = dout.contiguous()\n if residual is not None:\n assert residual.shape == (batch, seqlen, k) if not has_groups else (batch, seqlen, ngroups, k)\n if residual.stride(-1) != 1 and residual.stride(1) != 1:\n residual = residual.contiguous()\n if out is not None:\n assert out.shape == a.shape\n assert out.stride(-1) == 1 or out.stride(1) == 1\n else:\n out = torch.empty_like(a)\n dot_dtype = (tl.bfloat16 if a.dtype == torch.bfloat16 or dout.dtype == torch.bfloat16 else\n (tl.float16 if a.dtype == torch.float16 or dout.dtype == torch.float16 else tl.float32))\n grid = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv(k, META['BLOCK_SIZE_N']), batch,\n nchunks if not has_groups else nchunks * ngroups)\n residual_strides = ((residual.stride(0), residual.stride(1), 0 if not has_groups else residual.stride(2),\n residual.stride(-1))\n if residual is not None else (0, 0, 0, 0))\n with torch.cuda.device(a.device.index):\n _bmm_chunk_bwd_kernel[grid](\n a, dout, out, residual,\n seqlen, chunk_size, k, ngroups if has_groups else 1,\n a.stride(0), a.stride(1), 0 if not has_groups else a.stride(2), a.stride(-1),\n dout.stride(0), dout.stride(1), 0 if not has_groups else dout.stride(2), dout.stride(-2), dout.stride(-1),\n out.stride(0), out.stride(1), 0 if not has_groups else out.stride(2), out.stride(-1),\n residual_strides[0], residual_strides[1], residual_strides[2], residual_strides[3],\n dot_dtype,\n HAS_RESIDUAL=residual is not None,\n )\n return out\n", - "description_1": "Use triton language to implement two kernels: _bmm_chunk_fwd_kernel and _bmm_chunk_bwd_kernel. The _bmm_chunk_fwd_kernel performs a batched matrix multiplication with optional sequence index masking and causal masking. It takes 24 parameters: pointers to input matrices, matrix dimensions, strides, and meta-parameters for configuration. The _bmm_chunk_bwd_kernel computes the gradient of the batched matrix multiplication with respect to one of the input matrices. It takes 23 parameters: pointers to input matrices, matrix dimensions, strides, and meta-parameters for configuration. Both kernels are called by their respective wrapper functions _bmm_chunk_fwd and _bmm_chunk_bwd, which handle input preparation and kernel invocation.", - "description_2": "Use triton language to create forward and backward kernels for batched matrix multiplication with optional sequence and causal masking, and implement wrapper functions to manage input preparation and kernel execution.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\nfrom einops import rearrange, repeat\nfrom ssd.bi.ssd_bmm import _bmm_chunk_fwd, _bmm_chunk_bwd\n\nTRITON_22 = version.parse(triton.__version__) >= version.parse('2.2.0')\n\n@triton.autotune(\n configs=[\n triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64}, num_stages=3, num_warps=8),\n triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2),\n triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2),\n triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=2),\n ],\n key=['chunk_size', 'hdim', 'dstate'],\n)\n@triton.jit\ndef _chunk_scan_fwd_kernel(\n cb_ptr, x_ptr, z_ptr, out_ptr, out_x_ptr, dt_ptr, dA_cumsum_f_ptr, dA_cumsum_b_ptr,\n C_ptr, prev_states_f_ptr, prev_states_b_ptr, D_ptr,\n chunk_size, hdim, dstate,\n batch, seqlen, nheads_ngroups_ratio,\n stride_cb_batch, stride_cb_chunk, stride_cb_head, stride_cb_csize_m, stride_cb_csize_k,\n stride_x_batch, stride_x_seqlen, stride_x_head, stride_x_hdim,\n stride_z_batch, stride_z_seqlen, stride_z_head, stride_z_hdim,\n stride_out_batch, stride_out_seqlen, stride_out_head, stride_out_hdim,\n stride_dt_batch, stride_dt_chunk, stride_dt_head, stride_dt_csize,\n stride_dA_cs_f_batch, stride_dA_cs_f_chunk, stride_dA_cs_f_head, stride_dA_cs_f_csize,\n stride_dA_cs_b_batch, stride_dA_cs_b_chunk, stride_dA_cs_b_head, stride_dA_cs_b_csize,\n stride_C_batch, stride_C_seqlen, stride_C_head, stride_C_dstate,\n stride_states_f_batch, stride_states_f_chunk, stride_states_f_head, stride_states_f_hdim, stride_states_f_dstate,\n stride_states_b_batch, stride_states_b_chunk, stride_states_b_head, stride_states_b_hdim, stride_states_b_dstate,\n stride_D_head,\n HAS_D: tl.constexpr,\n D_HAS_HDIM: tl.constexpr,\n HAS_Z: tl.constexpr,\n BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,\n BLOCK_SIZE_DSTATE: tl.constexpr,\n IS_TRITON_22: tl.constexpr,\n):\n # Triton kernel implementation\n # ...\n\ndef _chunk_scan_fwd(cb, x, dt, dA_cumsum_f, dA_cumsum_b, C, states_f, states_b, D=None, z=None):\n batch, seqlen, nheads, headdim = x.shape\n _, _, nchunks, chunk_size = dt.shape\n _, _, ngroups, dstate = C.shape\n assert nheads % ngroups == 0\n assert C.shape == (batch, seqlen, ngroups, dstate)\n assert cb.shape == (batch, nchunks, ngroups, chunk_size, chunk_size)\n if z is not None:\n assert z.shape == x.shape\n if D is not None:\n assert D.shape == (nheads, headdim) or D.shape == (nheads,)\n assert dt.shape == (batch, nheads, nchunks, chunk_size)\n assert dA_cumsum_f.shape == (batch, nheads, nchunks, chunk_size)\n assert states_f.shape == (batch, nchunks, nheads, headdim, dstate)\n assert dA_cumsum_b.shape == (batch, nheads, nchunks, chunk_size)\n assert states_b.shape == (batch, nchunks, nheads, headdim, dstate)\n # Allocates output.\n out = torch.empty(batch, seqlen, nheads, headdim, device=x.device, dtype=x.dtype)\n if z is not None:\n out_x = torch.empty(batch, seqlen, nheads, headdim, device=x.device, dtype=x.dtype)\n assert out_x.stride() == out.stride()\n else:\n out_x = None\n grid = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv(headdim, META['BLOCK_SIZE_N']),\n batch * nchunks, nheads)\n z_strides = ((z.stride(0), z.stride(1), z.stride(2), z.stride(3))\n if z is not None else (0, 0, 0, 0))\n _chunk_scan_fwd_kernel[grid](\n cb, x, z, out, out_x, dt, dA_cumsum_f, dA_cumsum_b, C, states_f, states_b, D,\n chunk_size, headdim, dstate,\n batch, seqlen, nheads // ngroups,\n cb.stride(0), cb.stride(1), cb.stride(2), cb.stride(3), cb.stride(4),\n x.stride(0), x.stride(1), x.stride(2), x.stride(3),\n z_strides[0], z_strides[1], z_strides[2], z_strides[3],\n out.stride(0), out.stride(1), out.stride(2), out.stride(3),\n dt.stride(0), dt.stride(2), dt.stride(1), dt.stride(3),\n dA_cumsum_f.stride(0), dA_cumsum_f.stride(2), dA_cumsum_f.stride(1), dA_cumsum_f.stride(3),\n dA_cumsum_b.stride(0), dA_cumsum_b.stride(2), dA_cumsum_b.stride(1), dA_cumsum_b.stride(3),\n C.stride(0), C.stride(1), C.stride(2), C.stride(3),\n states_f.stride(0), states_f.stride(1), states_f.stride(2), states_f.stride(3), states_f.stride(4),\n states_b.stride(0), states_b.stride(1), states_b.stride(2), states_b.stride(3), states_b.stride(4),\n D.stride(0) if D is not None else 0,\n D is not None,\n D.dim() == 2 if D is not None else True,\n BLOCK_SIZE_DSTATE=max(triton.next_power_of_2(dstate), 16),\n HAS_Z=z is not None,\n IS_TRITON_22=TRITON_22,\n )\n return out, out_x\n", - "description_1": "Use triton language to implement a forward scan operation on chunks of data, applying transformations and aggregations based on input matrices and configurations.", - "description_2": "Use triton language to perform a forward scan operation on data chunks, utilizing input matrices and configurations for transformations and aggregations.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\nfrom ssd.bi.softplus import softplus\n\n@triton.autotune(\n configs=[\n triton.Config({'BLOCK_SIZE_H': 1}),\n triton.Config({'BLOCK_SIZE_H': 2}),\n triton.Config({'BLOCK_SIZE_H': 4}),\n triton.Config({'BLOCK_SIZE_H': 8}),\n triton.Config({'BLOCK_SIZE_H': 16}),\n triton.Config({'BLOCK_SIZE_H': 32}),\n triton.Config({'BLOCK_SIZE_H': 64}),\n ],\n key=['chunk_size', 'nheads'],\n)\n@triton.jit\ndef _chunk_cumsum_fwd_kernel(\n # Pointers to matrices\n dt_ptr, A_ptr, dt_bias_ptr, dt_out_ptr, dA_cumsum_f_ptr, dA_cumsum_b_ptr,\n # Matrix dimension\n batch, seqlen, nheads, chunk_size,\n dt_min, dt_max,\n # Strides\n stride_dt_batch, stride_dt_seqlen, stride_dt_head,\n stride_A_head,\n stride_dt_bias_head,\n stride_dt_out_batch, stride_dt_out_chunk, stride_dt_out_head, stride_dt_out_csize,\n stride_dA_cs_f_batch, stride_dA_cs_f_chunk, stride_dA_cs_f_head, stride_dA_cs_f_csize,\n stride_dA_cs_b_batch, stride_dA_cs_b_chunk, stride_dA_cs_b_head, stride_dA_cs_b_csize,\n # Meta-parameters\n DT_SOFTPLUS: tl.constexpr,\n HAS_DT_BIAS: tl.constexpr,\n BLOCK_SIZE_H: tl.constexpr, BLOCK_SIZE_CHUNK: tl.constexpr,\n):\n pid_b = tl.program_id(axis=0)\n pid_c = tl.program_id(axis=1)\n pid_h = tl.program_id(axis=2)\n dt_ptr += pid_b * stride_dt_batch + pid_c * chunk_size * stride_dt_seqlen\n dt_out_ptr += pid_b * stride_dt_out_batch + pid_c * stride_dt_out_chunk\n dA_cumsum_f_ptr += pid_b * stride_dA_cs_f_batch + pid_c * stride_dA_cs_f_chunk\n dA_cumsum_b_ptr += pid_b * stride_dA_cs_b_batch + pid_c * stride_dA_cs_b_chunk\n\n offs_h = pid_h * BLOCK_SIZE_H + tl.arange(0, BLOCK_SIZE_H)\n offs_c = tl.arange(0, BLOCK_SIZE_CHUNK)\n dt_ptrs = dt_ptr + (offs_h[:, None] * stride_dt_head + offs_c[None, :] * stride_dt_seqlen)\n A_ptrs = A_ptr + offs_h * stride_A_head\n dt_out_ptrs = dt_out_ptr + (offs_h[:, None] * stride_dt_out_head + offs_c[None, :] * stride_dt_out_csize)\n dA_cs_f_ptrs = dA_cumsum_f_ptr + (offs_h[:, None] * stride_dA_cs_f_head + offs_c[None, :] * stride_dA_cs_f_csize)\n dA_cs_b_ptrs = dA_cumsum_b_ptr + (offs_h[:, None] * stride_dA_cs_b_head + offs_c[None, :] * stride_dA_cs_b_csize)\n chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)\n\n dt = tl.load(dt_ptrs, mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit), other=0.0).to(tl.float32)\n if HAS_DT_BIAS:\n dt_bias = tl.load(dt_bias_ptr + offs_h * stride_dt_bias_head, mask=offs_h < nheads, other=0.0).to(tl.float32)\n dt += dt_bias[:, None]\n if DT_SOFTPLUS:\n dt = softplus(dt)\n # As of Triton 2.2.0, tl.clamp is not available yet\n # dt = tl.clamp(dt, dt_min, dt_max)\n dt = tl.minimum(tl.maximum(dt, dt_min), dt_max)\n dt = tl.where((offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit), dt, 0.0)\n tl.store(dt_out_ptrs, dt, mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size))\n A = tl.load(A_ptrs, mask=offs_h < nheads, other=0.0).to(tl.float32)\n dA = dt * A[:, None]\n dA_cs_f = tl.cumsum(dA, axis=1)\n tl.store(dA_cs_f_ptrs, dA_cs_f, mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size))\n # Triton reverse cumsum is broken as of 3.0.0 (You can remove this hack when fixed)\n # dA_cs_last = tl.load(dA_cumsum_f_ptr + (offs_h[:, None] * stride_dA_cs_f_head + (chunk_size_limit - 1) * stride_dA_cs_f_csize), mask=(offs_h[:, None] < nheads), other=0.0).to(tl.float32)\n dA_cs_b = tl.flip(tl.cumsum(tl.flip(dA, dim=1), axis=1))\n # dA_cs_b = tl.cumsum(dA, axis=1, reverse=True)\n # print(\"last\", dA_cs_last)\n # dA_cs_b = dA_cs_last - dA_cs_f + dA # Reverse scan is broken thus (dA_cumsum_last - dA_cs_f) + dA will reverse the cumsum\n tl.store(dA_cs_b_ptrs, dA_cs_b, mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size))\n\n\ndef _chunk_cumsum_fwd(dt, A, chunk_size, dt_bias=None, dt_softplus=False, dt_limit=(0.0, float(\"inf\"))):\n batch, seqlen, nheads = dt.shape\n assert A.shape == (nheads,)\n if dt_bias is not None:\n assert dt_bias.shape == (nheads,)\n nchunks = math.ceil(seqlen / chunk_size)\n dt_out = torch.empty(batch, nheads, nchunks, chunk_size, device=dt.device, dtype=torch.float32)\n dA_f_cumsum = torch.empty(batch, nheads, nchunks, chunk_size, device=dt.device, dtype=torch.float32)\n dA_b_cumsum = torch.empty(batch, nheads, nchunks, chunk_size, device=dt.device, dtype=torch.float32)\n grid_chunk_cs = lambda META: (batch, nchunks, triton.cdiv(nheads, META['BLOCK_SIZE_H']))\n with torch.cuda.device(dt.device.index):\n _chunk_cumsum_fwd_kernel[grid_chunk_cs](\n dt, A, dt_bias, dt_out, dA_f_cumsum, dA_b_cumsum,\n batch, seqlen, nheads, chunk_size,\n dt_limit[0], dt_limit[1],\n dt.stride(0), dt.stride(1), dt.stride(2),\n A.stride(0),\n dt_bias.stride(0) if dt_bias is not None else 0,\n dt_out.stride(0), dt_out.stride(2), dt_out.stride(1), dt_out.stride(3),\n dA_f_cumsum.stride(0), dA_f_cumsum.stride(2), dA_f_cumsum.stride(1), dA_f_cumsum.stride(3),\n dA_b_cumsum.stride(0), dA_b_cumsum.stride(2), dA_b_cumsum.stride(1), dA_b_cumsum.stride(3),\n dt_softplus,\n HAS_DT_BIAS=dt_bias is not None,\n BLOCK_SIZE_CHUNK=triton.next_power_of_2(chunk_size),\n )\n return dA_f_cumsum, dA_b_cumsum, dt_out\n", - "description_1": "Use triton language to create a kernel that performs cumulative sum operations on chunks of matrix data. Handle optional bias addition and softplus application.", - "description_2": "Use triton language to implement a function that sets up and launches the cumulative sum kernel on tensor data, managing output storage and device setup.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\nfrom einops import rearrange\n\nTRITON_22 = version.parse(triton.__version__) >= version.parse('2.2.0')\n\n@triton.autotune(\n configs=[\n triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64}, num_stages=3, num_warps=8, pre_hook=init_to_zero([\"ddt_ptr\"])),\n triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero([\"ddt_ptr\"])),\n triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero([\"ddt_ptr\"])),\n triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero([\"ddt_ptr\"])),\n triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero([\"ddt_ptr\"])),\n triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero([\"ddt_ptr\"])),\n triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=4, pre_hook=init_to_zero([\"ddt_ptr\"])),\n triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=4, pre_hook=init_to_zero([\"ddt_ptr\"])),\n triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero([\"ddt_ptr\"])),\n ],\n key=['chunk_size', 'hdim', 'dstate'],\n)\n@triton.jit\ndef _chunk_scan_chunk_state_bwd_dx_kernel(\n x_ptr, cb_ptr, dout_ptr, dt_ptr, dA_cumsum_f_ptr, dA_cumsum_b_ptr, D_ptr,\n b_ptr, dstates_f_ptr, dstates_b_ptr,\n dx_ptr, ddt_ptr, dD_ptr,\n chunk_size, hdim, dstate,\n batch, seqlen, nheads_ngroups_ratio,\n stride_x_batch, stride_x_seqlen, stride_x_head, stride_x_hdim,\n stride_cb_batch, stride_cb_chunk, stride_cb_head, stride_cb_csize_m, stride_cb_csize_k,\n stride_dout_batch, stride_dout_seqlen, stride_dout_head, stride_dout_hdim,\n stride_dt_batch, stride_dt_chunk, stride_dt_head, stride_dt_csize,\n stride_dA_cs_f_batch, stride_dA_cs_f_chunk, stride_dA_cs_f_head, stride_dA_cs_f_csize,\n stride_dA_cs_b_batch, stride_dA_cs_b_chunk, stride_dA_cs_b_head, stride_dA_cs_b_csize,\n stride_D_head,\n stride_b_batch, stride_b_seqlen, stride_b_head, stride_b_dstate,\n stride_dstates_f_batch, stride_dstates_f_chunk, stride_dstates_f_head, stride_dstates_f_hdim, stride_dstates_f_dstate,\n stride_dstates_b_batch, stride_dstates_b_chunk, stride_dstates_b_head, stride_dstates_b_hdim, stride_dstates_b_dstate,\n stride_dx_batch, stride_dx_seqlen, stride_dx_head, stride_dx_hdim,\n stride_ddt_batch, stride_ddt_chunk, stride_ddt_head, stride_ddt_csize,\n stride_dD_batch, stride_dD_chunk, stride_dD_head, stride_dD_csize, stride_dD_hdim,\n HAS_D: tl.constexpr,\n D_HAS_HDIM: tl.constexpr,\n BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,\n BLOCK_SIZE_DSTATE: tl.constexpr,\n IS_TRITON_22: tl.constexpr,\n):\n # Kernel implementation\n\ndef _chunk_scan_chunk_state_bwd_dx(x, dt, dA_cumsum_f, dA_cumsum_b, B, CB, dout, dstates_f, dstates_b, D=None, dx=None):\n batch, seqlen, nheads, headdim = x.shape\n _, _, nchunks, chunk_size = dt.shape\n _, _, ngroups, dstate = B.shape\n assert nheads % ngroups == 0\n assert B.shape == (batch, seqlen, ngroups, dstate)\n assert CB.shape == (batch, nchunks, ngroups, chunk_size, chunk_size)\n assert dt.shape == (batch, nheads, nchunks, chunk_size)\n assert dA_cumsum_f.shape == dt.shape\n assert dA_cumsum_b.shape == dt.shape\n assert dout.shape == x.shape\n assert dstates_f.shape == (batch, nchunks, nheads, headdim, dstate)\n assert dstates_b.shape == (batch, nchunks, nheads, headdim, dstate)\n if D is not None:\n assert D.shape == (nheads, headdim) or D.shape == (nheads,)\n assert D.stride(-1) == 1\n BLOCK_SIZE_min = 32\n dD = torch.empty(triton.cdiv(chunk_size, BLOCK_SIZE_min), batch, nchunks, nheads,\n headdim if D.dim() == 2 else 1, device=D.device, dtype=torch.float32)\n else:\n dD = None\n dD_strides = ((dD.stride(0), dD.stride(1), dD.stride(2), dD.stride(3), dD.stride(4))\n if D is not None else (0, 0, 0, 0, 0))\n if dx is None:\n dx = torch.empty_like(x)\n else:\n assert dx.shape == x.shape\n ddt = torch.empty(batch, nheads, nchunks, chunk_size, device=dout.device, dtype=torch.float32)\n grid_dx = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv(headdim, META['BLOCK_SIZE_N']),\n batch * nchunks, nheads)\n with torch.cuda.device(x.device.index):\n _chunk_scan_chunk_state_bwd_dx_kernel[grid_dx](\n x, CB, dout, dt, dA_cumsum_f, dA_cumsum_b, D, B, dstates_f, dstates_b, dx, ddt, dD,\n chunk_size, headdim, dstate,\n batch, seqlen, nheads // ngroups,\n x.stride(0), x.stride(1), x.stride(2), x.stride(3),\n CB.stride(0), CB.stride(1), CB.stride(2), CB.stride(-1), CB.stride(-2),\n dout.stride(0), dout.stride(1), dout.stride(2), dout.stride(3),\n dt.stride(0), dt.stride(2), dt.stride(1), dt.stride(3),\n dA_cumsum_f.stride(0), dA_cumsum_f.stride(2), dA_cumsum_f.stride(1), dA_cumsum_f.stride(3),\n dA_cumsum_b.stride(0), dA_cumsum_b.stride(2), dA_cumsum_b.stride(1), dA_cumsum_b.stride(3),\n D.stride(0) if D is not None else 0,\n B.stride(0), B.stride(1), B.stride(2), B.stride(3),\n dstates_f.stride(0), dstates_f.stride(1), dstates_f.stride(2), dstates_f.stride(3), dstates_f.stride(4),\n dstates_b.stride(0), dstates_b.stride(1), dstates_b.stride(2), dstates_b.stride(3), dstates_b.stride(4),\n dx.stride(0), dx.stride(1), dx.stride(2), dx.stride(3),\n ddt.stride(0), ddt.stride(2), ddt.stride(1), ddt.stride(3),\n dD_strides[1], dD_strides[2], dD_strides[3], dD_strides[0], dD_strides[4],\n D is not None,\n D.dim() == 2 if D is not None else True,\n BLOCK_SIZE_DSTATE=max(triton.next_power_of_2(dstate), 16),\n IS_TRITON_22=TRITON_22\n )\n if D is not None:\n BLOCK_SIZE_actual = _chunk_scan_chunk_state_bwd_dx_kernel.best_config.kwargs[\"BLOCK_SIZE_M\"]\n n_valid_blocks = (chunk_size + BLOCK_SIZE_actual - 1) // BLOCK_SIZE_actual\n dD = dD[:n_valid_blocks].sum(dim=(0, 1, 2)).to(dtype=D.dtype)\n if D.dim() == 1:\n dD = rearrange(dD, \"h 1 -> h\")\n return dx, ddt.to(dtype=dt.dtype), dD\n", - "description_1": "Use triton language to implement a backward kernel for a chunked scan operation. The kernel, _chunk_scan_chunk_state_bwd_dx_kernel, is decorated with @triton.jit and is responsible for computing gradients with respect to inputs x, dt, and optionally D. The kernel takes pointers to input and output tensors, matrix dimensions, strides, and meta-parameters as arguments. The function _chunk_scan_chunk_state_bwd_dx serves as a wrapper to set up the kernel execution, handling input validation, memory allocation, and grid configuration.", - "description_2": "Use triton language to create a backward kernel for a chunked scan operation, computing gradients for inputs.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n# Triton kernel for state passing forward computation\n@triton.autotune(\n configs=[\n triton.Config({'BLOCK_SIZE': 64}),\n triton.Config({'BLOCK_SIZE': 128}),\n triton.Config({'BLOCK_SIZE': 256}),\n triton.Config({'BLOCK_SIZE': 512}),\n triton.Config({'BLOCK_SIZE': 1024}),\n triton.Config({'BLOCK_SIZE': 2048}),\n ],\n key=['dim'],\n)\n@triton.jit\ndef _state_passing_fwd_kernel(\n states_ptr, out_ptr, final_states_ptr, dA_cs_ptr,\n dim, nchunks, seqlen, chunk_size,\n stride_states_batch, stride_states_chunk, stride_states_head, stride_states_dim,\n stride_out_batch, stride_out_chunk, stride_out_head, stride_out_dim,\n stride_final_states_batch, stride_final_states_head, stride_final_states_dim,\n stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head,\n REVERSE: tl.constexpr,\n BLOCK_SIZE: tl.constexpr,\n):\n # Compute ids\n pid_b = tl.program_id(axis=1)\n pid_h = tl.program_id(axis=2)\n pid_m = tl.program_id(axis=0)\n # Update pointers\n states_ptr += pid_b * stride_states_batch + pid_h * stride_states_head\n dA_cs_ptr += pid_b * stride_dA_cs_batch + pid_h * stride_dA_cs_head\n out_ptr += pid_b * stride_out_batch + pid_h * stride_out_head\n final_states_ptr += pid_b * stride_final_states_batch + pid_h * stride_final_states_head\n # Offsets for matrix elements\n offs_m = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)\n states_ptrs = states_ptr + offs_m * stride_states_dim\n out_ptrs = out_ptr + offs_m * stride_out_dim\n final_states_ptrs = final_states_ptr + offs_m * stride_final_states_dim\n\n if not REVERSE:\n # Forward pass logic\n states = tl.zeros((BLOCK_SIZE, ), dtype=tl.float32)\n tl.store(out_ptrs, states, mask=offs_m < dim)\n out_ptrs += stride_out_chunk\n for c in range(nchunks):\n new_states = tl.load(states_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)\n dA_cs = tl.load(dA_cs_ptr).to(tl.float32)\n scale = tl.exp(dA_cs)\n states = scale * states + new_states\n if c < nchunks - 1:\n tl.store(out_ptrs, states, mask=offs_m < dim)\n else:\n tl.store(final_states_ptrs, states, mask=offs_m < dim)\n states_ptrs += stride_states_chunk\n dA_cs_ptr += stride_dA_cs_chunk\n out_ptrs += stride_out_chunk\n else:\n # Reverse pass logic\n states_ptrs += (nchunks - 1) * stride_states_chunk\n dA_cs_ptr += (nchunks - 1) * stride_dA_cs_chunk\n out_ptrs += (nchunks - 1) * stride_out_chunk\n states = tl.zeros((BLOCK_SIZE, ), dtype=tl.float32)\n tl.store(out_ptrs, states, mask=offs_m < dim)\n out_ptrs -= stride_out_chunk\n for c in range(nchunks - 1, -1, -1):\n new_states = tl.load(states_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)\n dA_cs = tl.load(dA_cs_ptr).to(tl.float32)\n scale = tl.exp(dA_cs)\n states = scale * states + new_states\n if c > 0:\n tl.store(out_ptrs, states, mask=offs_m < dim)\n else:\n tl.store(final_states_ptrs, states, mask=offs_m < dim)\n states_ptrs -= stride_states_chunk\n dA_cs_ptr -= stride_dA_cs_chunk\n out_ptrs -= stride_out_chunk\n\n\n# Triton kernel for state passing backward computation\n@triton.autotune(\n configs=[\n triton.Config({'BLOCK_SIZE': 64}),\n triton.Config({'BLOCK_SIZE': 128}),\n triton.Config({'BLOCK_SIZE': 256}),\n triton.Config({'BLOCK_SIZE': 512}),\n triton.Config({'BLOCK_SIZE': 1024}),\n triton.Config({'BLOCK_SIZE': 2048}),\n ],\n key=['dim'],\n)\n@triton.jit\ndef _state_passing_bwd_kernel(\n dout_ptr, out_ptr, dA_cs_ptr, \n dstates_ptr, ddA_cs_ptr, states_converted_ptr,\n dim, nchunks, seqlen, chunk_size,\n stride_dout_batch, stride_dout_chunk, stride_dout_head, stride_dout_dim,\n stride_out_batch, stride_out_chunk, stride_out_head, stride_out_dim,\n stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head,\n stride_dstates_batch, stride_dstates_chunk, stride_dstates_head, stride_dstates_dim,\n stride_ddA_cs_batch, stride_ddA_cs_chunk, stride_ddA_cs_head,\n CONVERT_STATES: tl.constexpr,\n BLOCK_SIZE: tl.constexpr,\n REVERSE: tl.constexpr,\n):\n # Compute ids\n pid_b = tl.program_id(axis=1)\n pid_h = tl.program_id(axis=2)\n pid_m = tl.program_id(axis=0)\n # Update pointers\n dstates_ptr += pid_b * stride_dstates_batch + pid_h * stride_dstates_head\n dA_cs_ptr += pid_b * stride_dA_cs_batch + pid_h * stride_dA_cs_head \n ddA_cs_ptr += pid_b * stride_ddA_cs_batch + pid_h * stride_ddA_cs_head \n out_ptr += pid_b * stride_out_batch + pid_h * stride_out_head \n dout_ptr += pid_b * stride_dout_batch + pid_h * stride_dout_head \n\n if not REVERSE:\n dstates_ptr += (nchunks - 1) * stride_dstates_chunk\n dA_cs_ptr += (nchunks - 1) * stride_dA_cs_chunk\n ddA_cs_ptr += (nchunks - 1) * stride_ddA_cs_chunk\n out_ptr += (nchunks - 1) * stride_out_chunk\n dout_ptr += (nchunks - 1) * stride_dout_chunk\n\n if CONVERT_STATES:\n states_converted_ptr += pid_b * stride_out_batch + pid_h * stride_out_head\n if not REVERSE:\n states_converted_ptr += (nchunks - 1) * stride_out_chunk\n\n offs_m = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)\n dstates_ptrs = dstates_ptr + offs_m * stride_dstates_dim\n out_ptrs = out_ptr + offs_m * stride_out_dim\n dout_ptrs = dout_ptr + offs_m * stride_dout_dim\n if CONVERT_STATES:\n states_converted_ptrs = states_converted_ptr + offs_m * stride_out_dim\n\n dstates = tl.zeros((BLOCK_SIZE, ), dtype=tl.float32)\n tl.store(dstates_ptrs, dstates, mask=offs_m < dim)\n if not REVERSE:\n dstates_ptrs -= stride_dstates_chunk\n else:\n dstates_ptrs += stride_dstates_chunk\n\n for c in range(nchunks - 1):\n dA_cs = tl.load(dA_cs_ptr).to(tl.float32)\n scale = tl.exp(dA_cs)\n out = tl.load(out_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)\n if CONVERT_STATES:\n tl.store(states_converted_ptrs, out, mask=offs_m < dim)\n ddA = tl.sum(out * dstates) * scale\n tl.store(ddA_cs_ptr, ddA)\n dout = tl.load(dout_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)\n dstates = scale * dstates + dout\n tl.store(dstates_ptrs, dstates, mask=offs_m < dim)\n if not REVERSE:\n dout_ptrs -= stride_dout_chunk\n dstates_ptrs -= stride_dstates_chunk\n dA_cs_ptr -= stride_dA_cs_chunk\n ddA_cs_ptr -= stride_ddA_cs_chunk\n out_ptrs -= stride_out_chunk\n if CONVERT_STATES:\n states_converted_ptrs -= stride_out_chunk\n else:\n dout_ptrs += stride_dout_chunk\n dstates_ptrs += stride_dstates_chunk\n dA_cs_ptr += stride_dA_cs_chunk\n ddA_cs_ptr += stride_ddA_cs_chunk\n out_ptrs += stride_out_chunk\n if CONVERT_STATES:\n states_converted_ptrs += stride_out_chunk\n if CONVERT_STATES:\n out = tl.load(out_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)\n tl.store(states_converted_ptrs, out, mask=offs_m < dim)\n tl.store(ddA_cs_ptr, 0.0)\n\n\n# Function to launch the forward Triton kernel\ndef _state_passing_fwd(states, dA_chunk_cumsum, chunk_size=None, out_dtype=None, reverse=False):\n batch, nchunks, nheads, dim = states.shape\n assert dA_chunk_cumsum.shape == (batch, nheads, nchunks)\n out_dtype = states.dtype if out_dtype is None else out_dtype\n out = torch.empty((batch, nchunks, nheads, dim), device=states.device, dtype=out_dtype)\n final_states = torch.empty((batch, nheads, dim), device=states.device, dtype=torch.float32)\n grid = lambda META: (triton.cdiv(dim, META['BLOCK_SIZE']), batch, nheads)\n with torch.cuda.device(states.device.index):\n _state_passing_fwd_kernel[grid](\n states, out, final_states, dA_chunk_cumsum,\n dim, nchunks, 0, 0,\n states.stride(0), states.stride(1), states.stride(2), states.stride(3),\n out.stride(0), out.stride(1), out.stride(2), out.stride(3),\n final_states.stride(0), final_states.stride(1), final_states.stride(2),\n dA_chunk_cumsum.stride(0), dA_chunk_cumsum.stride(2), dA_chunk_cumsum.stride(1),\n REVERSE=reverse,\n )\n return out, final_states\n\n\n# Function to launch the backward Triton kernel\ndef _state_passing_bwd(\n states, dA_chunk_cumsum, dout, dstates_dtype=None, states_dtype=None, chunk_size=None, reverse=False,\n):\n batch, nchunks, nheads, dim = states.shape\n assert dA_chunk_cumsum.shape == (batch, nheads, nchunks)\n assert dout.shape == (batch, nchunks, nheads, dim)\n dstates = torch.empty_like(dout, dtype=dstates_dtype if dstates_dtype is not None else dout.dtype)\n if states_dtype is not None and states_dtype != states.dtype:\n states_converted = torch.empty_like(states, dtype=dstates_dtype if dstates_dtype is not None else dout.dtype)\n assert states_converted.stride() == states.stride()\n else:\n states_converted = None\n BLOCK_SIZE_min = 64\n n_blocks = (dim + BLOCK_SIZE_min - 1) // BLOCK_SIZE_min\n ddA_chunk_cumsum = torch.empty(batch, nheads, nchunks, n_blocks,\n dtype=torch.float32, device=dA_chunk_cumsum.device)\n grid = lambda META: (triton.cdiv(dim, META['BLOCK_SIZE']), batch, nheads)\n with torch.cuda.device(dout.device.index):\n _state_passing_bwd_kernel[grid](\n dout, states, dA_chunk_cumsum,\n dstates, ddA_chunk_cumsum, states_converted,\n dim, nchunks, 0, 0,\n dout.stride(0), dout.stride(1), dout.stride(2), dout.stride(3),\n states.stride(0), states.stride(1), states.stride(2), states.stride(3),\n dA_chunk_cumsum.stride(0), dA_chunk_cumsum.stride(2), dA_chunk_cumsum.stride(1),\n dstates.stride(0), dstates.stride(1), dstates.stride(2), dstates.stride(3),\n ddA_chunk_cumsum.stride(0), ddA_chunk_cumsum.stride(2), ddA_chunk_cumsum.stride(1),\n CONVERT_STATES=states_converted is not None,\n REVERSE=reverse,\n )\n BLOCK_SIZE_actual = _state_passing_bwd_kernel.best_config.kwargs[\"BLOCK_SIZE\"]\n n_valid_blocks = (dim + BLOCK_SIZE_actual - 1) // BLOCK_SIZE_actual\n ddA_chunk_cumsum = ddA_chunk_cumsum[..., :n_valid_blocks].sum(dim=-1).to(dtype=dA_chunk_cumsum.dtype)\n if states_dtype is not None and states_dtype == states.dtype:\n states_converted = states\n return (dstates, ddA_chunk_cumsum) if states_dtype is None else (dstates, ddA_chunk_cumsum, states_converted)\n", - "description_1": "Use triton language to implement two kernels: `_state_passing_fwd_kernel` and `_state_passing_bwd_kernel`, both with 21+ parameters for matrix pointers, dimensions, strides, and meta-parameters. `_state_passing_fwd_kernel` computes forward state passing in both forward and reverse directions based on the REVERSE meta-parameter. `_state_passing_bwd_kernel` computes backward state passing, handling state conversions and gradients. Functions `_state_passing_fwd` and `_state_passing_bwd` launch these kernels, respectively, setting up the grid based on the problem size.", - "description_2": "Use triton language to create forward and backward kernels for state passing with flexible block sizes and directions, and use Python functions to manage kernel execution.", - "difficulty": 5 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.autotune(\n configs=[\n triton.Config({'BLOCK_N': 32}),\n triton.Config({'BLOCK_N': 64}),\n triton.Config({'BLOCK_N': 128}),\n triton.Config({'BLOCK_N': 256}),\n triton.Config({'BLOCK_N': 512}),\n triton.Config({'BLOCK_N': 1024}),\n ],\n key=['ncols'],\n)\n@triton.jit\ndef _swiglu_fwd_kernel(\n X,\n Y,\n OUT,\n stride_x_row, # how much to increase the pointer when moving by 1 row\n stride_y_row,\n stride_out_row,\n ncols,\n BLOCK_N: tl.constexpr,\n):\n # Map the program id to the row of X and Y it should compute.\n row = tl.program_id(0)\n start_col = tl.program_id(1) * BLOCK_N\n X += row * stride_x_row\n Y += row * stride_y_row\n OUT += row * stride_out_row\n cols = start_col + tl.arange(0, BLOCK_N)\n x = tl.load(X + cols, mask=cols < ncols, other=0.).to(tl.float32)\n y = tl.load(Y + cols, mask=cols < ncols, other=0.).to(tl.float32)\n out = x * tl.sigmoid(x) * y\n tl.store(OUT + cols, out, mask=cols < ncols)\n\ndef _swiglu_fwd(xy, out=None):\n if xy.stride(-1) != 1:\n xy = xy.contiguous()\n batch_shape = xy.shape[:-1]\n xy = xy.reshape(-1, xy.shape[-1])\n x, y = xy.chunk(2, dim=-1)\n if out is None:\n out = torch.empty_like(x)\n else:\n out = out.reshape(-1, out.shape[-1])\n assert out.shape == x.shape\n assert out.stride(-1) == 1\n M, N = x.shape\n grid = lambda META: (M, triton.cdiv(N, META['BLOCK_N']))\n with torch.cuda.device(x.device.index):\n _swiglu_fwd_kernel[grid](x, y, out, x.stride(0), y.stride(0), out.stride(0), N)\n return out.reshape(*batch_shape, out.shape[-1])\n\n@triton.autotune(\n configs=[\n triton.Config({'BLOCK_N': 32}),\n triton.Config({'BLOCK_N': 64}),\n triton.Config({'BLOCK_N': 128}),\n triton.Config({'BLOCK_N': 256}),\n triton.Config({'BLOCK_N': 512}),\n triton.Config({'BLOCK_N': 1024}),\n ],\n key=['ncols'],\n)\n@triton.heuristics({\"RECOMPUTE_OUTPUT\": lambda args: args[\"OUT\"] is not None})\n@triton.jit\ndef _swiglu_bwd_kernel(\n X,\n Y,\n DOUT,\n OUT,\n DX,\n DY,\n stride_x_row, # how much to increase the pointer when moving by 1 row\n stride_y_row,\n stride_dout_row,\n stride_out_row,\n stride_dx_row,\n stride_dy_row,\n ncols,\n BLOCK_N: tl.constexpr,\n RECOMPUTE_OUTPUT: tl.constexpr,\n):\n # Map the program id to the row of X and Y it should compute.\n row = tl.program_id(0)\n start_col = tl.program_id(1) * BLOCK_N\n X += row * stride_x_row\n Y += row * stride_y_row\n DOUT += row * stride_dout_row\n if RECOMPUTE_OUTPUT:\n OUT += row * stride_out_row\n DX += row * stride_dx_row\n DY += row * stride_dy_row\n cols = start_col + tl.arange(0, BLOCK_N)\n x = tl.load(X + cols, mask=cols < ncols, other=0.).to(tl.float32)\n y = tl.load(Y + cols, mask=cols < ncols, other=0.).to(tl.float32)\n dout = tl.load(DOUT + cols, mask=cols < ncols, other=0.).to(tl.float32)\n x_sigmoid = tl.sigmoid(x)\n dx = x_sigmoid * (1 + x * (1 - x_sigmoid)) * y * dout\n dy = x * x_sigmoid * dout\n tl.store(DX + cols, dx, mask=cols < ncols)\n tl.store(DY + cols, dy, mask=cols < ncols)\n if RECOMPUTE_OUTPUT:\n out = x * x_sigmoid * y\n tl.store(OUT + cols, out, mask=cols < ncols)\n\ndef _swiglu_bwd(xy, dout, dxy=None, recompute_output=False, out=None):\n if xy.stride(-1) != 1:\n xy = xy.contiguous()\n if dout.stride(-1) != 1:\n dout = dout.contiguous()\n batch_shape = xy.shape[:-1]\n xy = xy.reshape(-1, xy.shape[-1])\n x, y = xy.chunk(2, dim=-1)\n dout = dout.reshape(-1, dout.shape[-1])\n assert dout.shape == x.shape\n if dxy is None:\n dxy = torch.empty_like(xy)\n else:\n dxy = dxy.reshape(-1, dxy.shape[-1])\n assert dxy.shape == xy.shape\n dx, dy = dxy.chunk(2, dim=-1)\n assert dx.stride(-1) == 1\n assert dy.stride(-1) == 1\n if recompute_output:\n if out is None:\n out = torch.empty_like(x)\n else:\n out = out.reshape(-1, out.shape[-1])\n assert out.shape == x.shape\n assert out.stride(-1) == 1\n M, N = x.shape\n grid = lambda META: (M, triton.cdiv(N, META['BLOCK_N']))\n with torch.cuda.device(x.device.index):\n _swiglu_bwd_kernel[grid](x, y, dout, out if recompute_output else None, dx, dy,\n x.stride(0), y.stride(0), dout.stride(0),\n out.stride(0) if recompute_output else 0,\n dx.stride(0), dy.stride(0),\n N)\n if not recompute_output:\n return dxy.reshape(*batch_shape, dxy.shape[-1])\n else:\n return dxy.reshape(*batch_shape, dxy.shape[-1]), out.reshape(*batch_shape, out.shape[-1])\n", - "description_1": "Use triton language to implement the forward and backward kernel operations for the SwiGLU activation function. The forward kernel _swiglu_fwd_kernel accepts 7 parameters: X, Y, OUT, stride_x_row, stride_y_row, stride_out_row, ncols, and BLOCK_N (a compile-time constant). It computes the element-wise product of X and Y, modified by the sigmoid of X, and stores the result in OUT. The backward kernel _swiglu_bwd_kernel accepts 14 parameters: X, Y, DOUT, OUT, DX, DY, stride_x_row, stride_y_row, stride_dout_row, stride_out_row, stride_dx_row, stride_dy_row, ncols, BLOCK_N, and RECOMPUTE_OUTPUT (a compile-time constant). It computes gradients for X and Y (stored in DX and DY) using the derivative of the SwiGLU function, optionally recomputing the forward pass output.", - "description_2": "Use triton language to create custom kernels for the forward and backward passes of the SwiGLU function, efficiently utilizing GPU parallelism. The forward pass computes an element-wise operation, storing results in an output tensor, while the backward pass calculates gradients, optionally recalculating the forward output as needed.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.heuristics({\"HAS_BIAS\": lambda args: args[\"B\"] is not None})\n@triton.heuristics({\"HAS_Z\": lambda args: args[\"Z\"] is not None})\n@triton.jit\ndef _layer_norm_fwd_1pass_kernel(\n X, Y, W, B, Z, Mean, Rstd, stride_x_row, stride_y_row, stride_z_row, M, N, eps, BLOCK_N: tl.constexpr,\n HAS_BIAS: tl.constexpr, HAS_Z: tl.constexpr, NORM_BEFORE_GATE: tl.constexpr, IS_RMS_NORM: tl.constexpr):\n row = tl.program_id(0)\n group = tl.program_id(1)\n X += row * stride_x_row + group * N\n Y += row * stride_y_row + group * N\n if HAS_Z:\n Z += row * stride_z_row + group * N\n if not IS_RMS_NORM:\n Mean += group * M\n Rstd += group * M\n W += group * N\n if HAS_BIAS:\n B += group * N\n cols = tl.arange(0, BLOCK_N)\n x = tl.load(X + cols, mask=cols < N, other=0.).to(tl.float32)\n if HAS_Z and not NORM_BEFORE_GATE:\n z = tl.load(Z + cols, mask=cols < N).to(tl.float32)\n x *= z * tl.sigmoid(z)\n if not IS_RMS_NORM:\n mean = tl.sum(x, axis=0) / N\n tl.store(Mean + row, mean)\n xbar = tl.where(cols < N, x - mean, 0.)\n var = tl.sum(xbar * xbar, axis=0) / N\n else:\n xbar = tl.where(cols < N, x, 0.)\n var = tl.sum(xbar * xbar, axis=0) / N\n rstd = 1 / tl.sqrt(var + eps)\n tl.store(Rstd + row, rstd)\n mask = cols < N\n w = tl.load(W + cols, mask=mask).to(tl.float32)\n if HAS_BIAS:\n b = tl.load(B + cols, mask=mask).to(tl.float32)\n x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd\n y = x_hat * w + b if HAS_BIAS else x_hat * w\n if HAS_Z and NORM_BEFORE_GATE:\n z = tl.load(Z + cols, mask=mask).to(tl.float32)\n y *= z * tl.sigmoid(z)\n tl.store(Y + cols, y, mask=mask)\n\n\ndef _layer_norm_fwd(x, weight, bias, eps, z=None, out=None, group_size=None, norm_before_gate=True, is_rms_norm=False):\n M, N = x.shape\n if group_size is None:\n group_size = N\n ngroups = N // group_size\n out = torch.empty_like(x) if out is None else out\n mean = torch.empty((ngroups * M, ), dtype=torch.float32, device=x.device) if not is_rms_norm else None\n rstd = torch.empty((ngroups * M, ), dtype=torch.float32, device=x.device)\n MAX_FUSED_SIZE = 65536 // x.element_size()\n BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(group_size))\n num_warps = min(max(BLOCK_N // 256, 1), 8)\n grid = (M, ngroups)\n with torch.cuda.device(x.device.index):\n _layer_norm_fwd_1pass_kernel[grid](x, out, weight, bias, z, mean, rstd,\n x.stride(0), out.stride(0), z.stride(0) if z is not None else 0,\n M, group_size, eps,\n BLOCK_N=BLOCK_N,\n NORM_BEFORE_GATE=norm_before_gate,\n IS_RMS_NORM=is_rms_norm,\n num_warps=num_warps)\n return out, mean, rstd\n\n@triton.heuristics({\"HAS_BIAS\": lambda args: args[\"B\"] is not None})\n@triton.heuristics({\"HAS_Z\": lambda args: args[\"Z\"] is not None})\n@triton.heuristics({\"RECOMPUTE_OUTPUT\": lambda args: args[\"Y\"] is not None})\n@triton.jit\ndef _layer_norm_bwd_kernel(\n X, W, B, Z, Y, DY, DX, DW, DB, DZ, Mean, Rstd, stride_x_row, stride_z_row, stride_y_row, stride_dy_row,\n stride_dx_row, stride_dz_row, stride_dw_row, stride_db_row, M, N, eps, rows_per_program, \n NORM_BEFORE_GATE: tl.constexpr, IS_RMS_NORM: tl.constexpr, HAS_BIAS: tl.constexpr, HAS_Z: tl.constexpr, \n RECOMPUTE_OUTPUT: tl.constexpr, BLOCK_N: tl.constexpr):\n row_block_id = tl.program_id(0)\n group = tl.program_id(1)\n row_start = row_block_id * rows_per_program\n cols = tl.arange(0, BLOCK_N)\n mask = cols < N\n X += row_start * stride_x_row + group * N\n if HAS_Z:\n Z += row_start * stride_z_row + group * N\n DZ += row_start * stride_dz_row + group * N\n DY += row_start * stride_dy_row + group * N\n DX += row_start * stride_dx_row + group * N\n if RECOMPUTE_OUTPUT:\n Y += row_start * stride_y_row + group * N\n if not IS_RMS_NORM:\n Mean += group * M\n Rstd += group * M\n W += group * N\n w = tl.load(W + cols, mask=mask).to(tl.float32)\n if (RECOMPUTE_OUTPUT or HAS_Z) and HAS_BIAS:\n B += group * N\n b = tl.load(B + cols, mask=mask, other=0.).to(tl.float32)\n dw = tl.zeros((BLOCK_N,), dtype=tl.float32)\n if HAS_BIAS:\n db = tl.zeros((BLOCK_N,), dtype=tl.float32)\n row_end = min((row_block_id + 1) * rows_per_program, M)\n for row in range(row_start, row_end):\n x = tl.load(X + cols, mask=mask, other=0).to(tl.float32)\n dy = tl.load(DY + cols, mask=mask, other=0).to(tl.float32)\n if not IS_RMS_NORM:\n mean = tl.load(Mean + row)\n if HAS_Z and not NORM_BEFORE_GATE:\n z = tl.load(Z + cols, mask=mask, other=0.).to(tl.float32)\n x_og = x\n x = x_og * z * tl.sigmoid(z)\n rstd = tl.load(Rstd + row)\n xhat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd\n xhat = tl.where(mask, xhat, 0.)\n if HAS_Z and NORM_BEFORE_GATE:\n z = tl.load(Z + cols, mask=mask, other=0.).to(tl.float32)\n z_sigmoid = tl.sigmoid(z)\n y = xhat * w + b if HAS_BIAS else xhat * w\n if RECOMPUTE_OUTPUT:\n tl.store(Y + cols, y * z * z_sigmoid, mask=mask)\n dz = dy * y * z_sigmoid * (1 + z * (1 - z_sigmoid))\n tl.store(DZ + cols, dz, mask=mask)\n dy *= z * z_sigmoid\n else:\n if RECOMPUTE_OUTPUT:\n y = xhat * w + b if HAS_BIAS else xhat * w\n tl.store(Y + cols, y, mask=mask)\n wdy = w * dy\n c1 = tl.sum(xhat * wdy, axis=0) / N\n if not IS_RMS_NORM:\n c2 = tl.sum(wdy, axis=0) / N\n dx = (wdy - (xhat * c1 + c2)) * rstd\n else:\n dx = (wdy - xhat * c1) * rstd\n dw += dy * xhat\n if HAS_BIAS:\n db += dy\n if HAS_Z and not NORM_BEFORE_GATE:\n z_sigmoid = tl.sigmoid(z)\n dz = dx * x_og * z_sigmoid * (1 + z * (1 - z_sigmoid))\n tl.store(DZ + cols, dz, mask=mask)\n dx *= z * z_sigmoid\n tl.store(DX + cols, dx, mask=mask)\n\n X += stride_x_row\n if HAS_Z:\n Z += stride_z_row\n DZ += stride_dz_row\n if RECOMPUTE_OUTPUT:\n Y += stride_y_row\n DY += stride_dy_row\n DX += stride_dx_row\n tl.store(DW + row_block_id * stride_dw_row + group * N + cols, dw, mask=mask)\n if HAS_BIAS:\n tl.store(DB + row_block_id * stride_db_row + group * N + cols, db, mask=mask)\n\n\ndef _layer_norm_bwd(dy, x, weight, bias, eps, mean, rstd, z=None, group_size=None,\n norm_before_gate=True, is_rms_norm=False, recompute_output=False, dz=None, out=None):\n M, N = x.shape\n if group_size is None:\n group_size = N\n ngroups = N // group_size\n dx = torch.empty_like(x)\n if dz is not None:\n dz = torch.empty_like(z) if z is not None else None\n if recompute_output:\n if out is None:\n out = torch.empty_like(x)\n \n MAX_FUSED_SIZE = 65536 // x.element_size()\n BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(group_size))\n num_warps = min(max(BLOCK_N // 256, 1), 8)\n sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count\n nrow_groups = math.ceil(sm_count * math.ceil(4 / num_warps) / ngroups)\n _dw = torch.empty((nrow_groups, N), dtype=torch.float32, device=weight.device)\n _db = torch.empty((nrow_groups, N), dtype=torch.float32, device=bias.device) if bias is not None else None\n rows_per_program = math.ceil(M / nrow_groups)\n grid = (nrow_groups, ngroups)\n with torch.cuda.device(x.device.index):\n _layer_norm_bwd_kernel[grid](x, weight, bias, z, out if recompute_output else None,\n dy, dx, _dw, _db, dz, mean, rstd,\n x.stride(0),\n z.stride(0) if z is not None else 0,\n 0 if not recompute_output else out.stride(0),\n dy.stride(0), dx.stride(0),\n dz.stride(0) if dz is not None else 0,\n _dw.stride(0),\n _db.stride(0) if _db is not None else 0,\n M, group_size, eps,\n rows_per_program,\n BLOCK_N=BLOCK_N,\n NORM_BEFORE_GATE=norm_before_gate,\n IS_RMS_NORM=is_rms_norm,\n num_warps=num_warps)\n dw = _dw.sum(0).to(weight.dtype)\n db = _db.sum(0).to(bias.dtype) if bias is not None else None\n return (dx, dw, db, dz) if not recompute_output else (dx, dw, db, dz, out)\n\n", - "description_1": "Use triton language to implement two kernels for layer normalization operations. The first kernel (_layer_norm_fwd_1pass_kernel) performs the forward pass of layer normalization on a 2D input tensor X, applying weight and bias transformations and optionally using another tensor Z. It computes mean and variance, handles optional bias and additional branch Z, and stores normalized output Y. The second kernel (_layer_norm_bwd_kernel) handles the backward pass, computing gradients for input, weights, biases, and optional tensor Z using precomputed mean and rstd. Both kernels are configured with grid/block settings based on input dimensions.", - "description_2": "Use triton language to define and execute layer normalization forward and backward kernels on GPU, including mean/variance computation, optional bias and branch handling, and gradient calculations.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\nfrom ssd.uni.softplus import softplus\n\n@triton.jit\ndef _selective_scan_update_kernel(\n # Pointers to matrices\n state_ptr, x_ptr, dt_ptr, dt_bias_ptr, A_ptr, B_ptr, C_ptr, D_ptr, z_ptr, out_ptr,\n # Matrix dimensions\n batch, nheads, dim, dstate, nheads_ngroups_ratio,\n # Strides\n stride_state_batch, stride_state_head, stride_state_dim, stride_state_dstate,\n stride_x_batch, stride_x_head, stride_x_dim,\n stride_dt_batch, stride_dt_head, stride_dt_dim,\n stride_dt_bias_head, stride_dt_bias_dim,\n stride_A_head, stride_A_dim, stride_A_dstate,\n stride_B_batch, stride_B_group, stride_B_dstate,\n stride_C_batch, stride_C_group, stride_C_dstate,\n stride_D_head, stride_D_dim,\n stride_z_batch, stride_z_head, stride_z_dim,\n stride_out_batch, stride_out_head, stride_out_dim,\n # Meta-parameters\n DT_SOFTPLUS: tl.constexpr,\n TIE_HDIM: tl.constexpr,\n BLOCK_SIZE_M: tl.constexpr,\n HAS_DT_BIAS: tl.constexpr,\n HAS_D: tl.constexpr,\n HAS_Z: tl.constexpr,\n BLOCK_SIZE_DSTATE: tl.constexpr,\n):\n # Triton kernel logic here...\n\ndef selective_state_update(state, x, dt, A, B, C, D=None, z=None, dt_bias=None, dt_softplus=False):\n has_heads = state.dim() > 3\n if state.dim() == 3:\n state = state.unsqueeze(1)\n if x.dim() == 2:\n x = x.unsqueeze(1)\n if dt.dim() == 2:\n dt = dt.unsqueeze(1)\n if A.dim() == 2:\n A = A.unsqueeze(0)\n if B.dim() == 2:\n B = B.unsqueeze(1)\n if C.dim() == 2:\n C = C.unsqueeze(1)\n if D is not None and D.dim() == 1:\n D = D.unsqueeze(0)\n if z is not None and z.dim() == 2:\n z = z.unsqueeze(1)\n if dt_bias is not None and dt_bias.dim() == 1:\n dt_bias = dt_bias.unsqueeze(0)\n batch, nheads, dim, dstate = state.shape\n assert x.shape == (batch, nheads, dim)\n assert dt.shape == x.shape\n assert A.shape == (nheads, dim, dstate)\n ngroups = B.shape[1]\n assert nheads % ngroups == 0, \"nheads must be divisible by ngroups\"\n assert B.shape == (batch, ngroups, dstate)\n assert C.shape == B.shape\n if D is not None:\n assert D.shape == (nheads, dim)\n if z is not None:\n assert z.shape == x.shape\n if dt_bias is not None:\n assert dt_bias.shape == (nheads, dim)\n out = torch.empty_like(x)\n grid = lambda META: (triton.cdiv(dim, META['BLOCK_SIZE_M']), batch, nheads)\n z_strides = ((z.stride(0), z.stride(1), z.stride(2)) if z is not None else (0, 0, 0))\n BLOCK_SIZE_M, num_warps = ((32, 4) if dstate <= 16\n else ((16, 4) if dstate <= 32 else\n ((8, 4) if dstate <= 64 else\n ((4, 4) if dstate <= 128 else\n ((4, 8))))))\n tie_hdim = A.stride(-1) == 0 and A.stride(-2) == 0 and dt.stride(-1) == 0 and dt_bias.stride(-1) == 0\n with torch.cuda.device(x.device.index):\n _selective_scan_update_kernel[grid](\n state, x, dt, dt_bias, A, B, C, D, z, out,\n batch, nheads, dim, dstate, nheads // ngroups,\n state.stride(0), state.stride(1), state.stride(2), state.stride(3),\n x.stride(0), x.stride(1), x.stride(2),\n dt.stride(0), dt.stride(1), dt.stride(2),\n *(dt_bias.stride(0), dt_bias.stride(1)) if dt_bias is not None else 0,\n A.stride(0), A.stride(1), A.stride(2),\n B.stride(0), B.stride(1), B.stride(2),\n C.stride(0), C.stride(1), C.stride(2),\n *(D.stride(0), D.stride(1)) if D is not None else 0,\n z_strides[0], z_strides[1], z_strides[2],\n out.stride(0), out.stride(1), out.stride(2),\n dt_softplus,\n tie_hdim,\n BLOCK_SIZE_M,\n num_warps=num_warps,\n )\n if not has_heads:\n out = out.squeeze(1)\n return out\n", - "description_1": "Use triton language to implement a kernel and a wrapper function for selective state updates in a neural network. The kernel '_selective_scan_update_kernel' has parameters for pointers to input and output matrices, matrix dimensions, strides, and meta-parameters. The wrapper function 'selective_state_update' manages inputs and outputs for the kernel, setting up appropriate configurations and ensuring tensor shapes match expected dimensions. This involves optional bias additions and softplus activations.", - "description_2": "Use triton language to efficiently update neural network states with a custom kernel.", - "difficulty": 4 - }, - { - "code": "import triton\nimport triton.language as tl\nfrom packaging import version\n\nTRITON3 = version.parse(triton.__version__) >= version.parse(\"3.0.0\")\n\n# Triton kernel for the softplus function\nif TRITON3:\n @triton.jit\n def softplus(dt):\n # Compute the softplus function: log(exp(dt) + 1) if dt <= 20.0; otherwise, return dt\n dt = tl.where(dt <= 20.0, tl.math.log(tl.math.exp(dt) + 1), dt)\n return dt\nelse:\n @triton.jit\n def softplus(dt):\n # Compute the softplus function: log1p(exp(dt)) if dt <= 20.0; otherwise, return dt\n dt = tl.where(dt <= 20.0, tl.math.log1p(tl.exp(dt)), dt)\n return dt\n", - "description_1": "Use triton language to implement a softplus function kernel that takes 1 parameter 'dt'. The function computes the softplus operation element-wise on 'dt': it returns log(exp(dt) + 1) if dt <= 20.0; otherwise, it returns dt directly. The behavior changes slightly for Triton versions below 3.0.0, where log1p(exp(dt)) is used instead for numerical stability.", - "description_2": "Use triton language to implement a softplus function with a version-dependent computation for improved numerical stability.", - "difficulty": 2 - }, - { - "code": "import math\nimport torch\nimport triton\nimport triton.language as tl\n\n@triton.autotune(\n configs=[\n triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64}, num_stages=3, num_warps=8),\n triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2),\n triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2),\n triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=2),\n ],\n key=['chunk_size', 'K', 'IS_CAUSAL'],\n)\n@triton.jit\ndef _bmm_chunk_fwd_kernel(\n a_ptr, b_ptr, out_ptr, seq_idx_ptr,\n seqlen, chunk_size, K, ngroups,\n stride_a_batch, stride_a_seqlen, stride_a_head, stride_ak,\n stride_b_batch, stride_b_seqlen, stride_b_head, stride_bk,\n stride_out_batch, stride_out_chunk, stride_out_head, stride_outm, stride_outn,\n stride_seq_idx_batch, stride_seq_idx_seqlen,\n IS_CAUSAL: tl.constexpr,\n dot_dtype: tl.constexpr,\n HAS_SEQ_IDX: tl.constexpr,\n BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,\n):\n pid_b = tl.program_id(axis=1)\n pid_ch = tl.program_id(axis=2)\n pid_c = pid_ch // ngroups\n pid_h = pid_ch - pid_c * ngroups\n num_pid_n = tl.cdiv(chunk_size, BLOCK_SIZE_N)\n pid_m = tl.program_id(axis=0) // num_pid_n\n pid_n = tl.program_id(axis=0) % num_pid_n\n if IS_CAUSAL:\n if pid_n * BLOCK_SIZE_N >= (pid_m + 1) * BLOCK_SIZE_M:\n return\n a_ptr += pid_b * stride_a_batch + pid_c * chunk_size * stride_a_seqlen + pid_h * stride_a_head\n b_ptr += pid_b * stride_b_batch + pid_c * chunk_size * stride_b_seqlen + pid_h * stride_b_head\n if HAS_SEQ_IDX:\n seq_idx_ptr += pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen\n\n offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n offs_k = tl.arange(0, BLOCK_SIZE_K)\n a_ptrs = a_ptr + (offs_m[:, None] * stride_a_seqlen + offs_k[None, :] * stride_ak)\n b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_n[None, :] * stride_b_seqlen)\n chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)\n\n acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):\n a = tl.load(a_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < K - k * BLOCK_SIZE_K), other=0.0).to(dot_dtype)\n b = tl.load(b_ptrs, mask=(offs_k[:, None] < K - k * BLOCK_SIZE_K) & (offs_n[None, :] < chunk_size_limit), other=0.0).to(dot_dtype)\n acc += tl.dot(a, b)\n a_ptrs += BLOCK_SIZE_K * stride_ak\n b_ptrs += BLOCK_SIZE_K * stride_bk\n\n offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n if HAS_SEQ_IDX:\n chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)\n seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen, mask=offs_m < chunk_size_limit, other=-1)\n seq_idx_n = tl.load(seq_idx_ptr + offs_n * stride_seq_idx_seqlen, mask=offs_n < chunk_size_limit, other=-2)\n acc = tl.where(seq_idx_m[:, None] == seq_idx_n[None, :], acc, 0.0)\n out = acc.to(out_ptr.dtype.element_ty)\n\n out_ptr += pid_b * stride_out_batch + pid_c * stride_out_chunk + pid_h * stride_out_head\n out_ptrs = out_ptr + (stride_outm * offs_m[:, None] + offs_n[None, :] * stride_outn)\n tl.store(out_ptrs, out, mask=(offs_m[:, None] < chunk_size) & (offs_n[None, :] < chunk_size))\n\n\n@triton.autotune(\n configs=[\n triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_CS': 64}, num_stages=3, num_warps=8),\n triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_CS': 32}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_CS': 32}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_CS': 32}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_CS': 32}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_CS': 32}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_CS': 32}, num_stages=5, num_warps=2),\n triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_CS': 32}, num_stages=5, num_warps=2),\n triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_CS': 32}, num_stages=4, num_warps=2),\n ],\n key=['chunk_size', 'K'],\n)\n@triton.jit\ndef _bmm_chunk_bwd_kernel(\n a_ptr, dout_ptr, db_ptr, res_ptr,\n seqlen, chunk_size, K, ngroups,\n stride_a_batch, stride_a_seqlen, stride_a_head, stride_ak,\n stride_dout_batch, stride_dout_chunk, stride_dout_head, stride_dout_csize_m, stride_dout_csize_n,\n stride_db_batch, stride_db_seqlen, stride_db_head, stride_db_k,\n stride_res_batch, stride_res_seqlen, stride_res_head, stride_res_k,\n dot_dtype: tl.constexpr,\n HAS_RESIDUAL: tl.constexpr,\n BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_CS: tl.constexpr,\n):\n pid_b = tl.program_id(axis=1)\n pid_ch = tl.program_id(axis=2)\n pid_c = pid_ch // ngroups\n pid_h = pid_ch - pid_c * ngroups\n num_pid_n = tl.cdiv(K, BLOCK_SIZE_N)\n pid_m = tl.program_id(axis=0) // num_pid_n\n pid_n = tl.program_id(axis=0) % num_pid_n\n\n a_ptr += pid_b * stride_a_batch + pid_c * chunk_size * stride_a_seqlen + pid_h * stride_a_head\n dout_ptr += pid_b * stride_dout_batch + pid_c * stride_dout_chunk + pid_h * stride_dout_head\n\n offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n offs_cs = tl.arange(0, BLOCK_SIZE_CS)\n dout_ptrs = dout_ptr + (offs_m[:, None] * stride_dout_csize_n + offs_cs[None, :] * stride_dout_csize_m)\n a_ptrs = a_ptr + (offs_cs[:, None] * stride_a_seqlen + offs_n[None, :] * stride_ak)\n chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)\n\n acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n for cs in range(0, tl.cdiv(chunk_size_limit, BLOCK_SIZE_CS)):\n dout = tl.load(dout_ptrs, mask=(offs_m[:, None] < chunk_size) & (offs_cs[None, :] < chunk_size_limit - cs * BLOCK_SIZE_CS), other=0.0).to(dot_dtype)\n a = tl.load(a_ptrs, mask=(offs_cs[:, None] < chunk_size_limit - cs * BLOCK_SIZE_CS) & (offs_n[None, :] < K), other=0.0).to(dot_dtype)\n acc += tl.dot(dout, a)\n dout_ptrs += BLOCK_SIZE_CS * stride_dout_csize_m\n a_ptrs += BLOCK_SIZE_CS * stride_a_seqlen\n\n offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n if HAS_RESIDUAL:\n res_ptr += pid_b * stride_res_batch + pid_c * chunk_size * stride_res_seqlen + pid_h * stride_res_head\n res_ptrs = res_ptr + (offs_m[:, None] * stride_res_seqlen + offs_n[None, :] * stride_res_k)\n res = tl.load(res_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < K)).to(tl.float32)\n acc += res\n db = acc.to(db_ptr.dtype.element_ty)\n\n db_ptr += pid_b * stride_db_batch + pid_c * chunk_size * stride_db_seqlen + pid_h * stride_db_head\n db_ptrs = db_ptr + (offs_m[:, None] * stride_db_seqlen + offs_n[None, :] * stride_db_k)\n tl.store(db_ptrs, db, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < K))\n\n\ndef _bmm_chunk_fwd(a, b, chunk_size, seq_idx=None, causal=False, output_dtype=None):\n has_groups = a.dim() == 4\n if not has_groups:\n batch, seqlen, k = a.shape\n else:\n batch, seqlen, ngroups, k = a.shape\n assert b.shape == a.shape\n if seq_idx is not None:\n assert seq_idx.shape == (batch, seqlen)\n if a.stride(-1) != 1 and a.stride(1) != 1:\n a = a.contiguous()\n if b.stride(-1) != 1 and b.stride(1) != 1:\n b = b.contiguous()\n nchunks = math.ceil(seqlen / chunk_size)\n out_dtype = a.dtype if output_dtype is None else output_dtype\n out = torch.empty((batch, nchunks, chunk_size, chunk_size) if not has_groups else (batch, nchunks, ngroups, chunk_size, chunk_size),\n device=a.device, dtype=out_dtype)\n dot_dtype = (tl.bfloat16 if a.dtype == torch.bfloat16 or b.dtype == torch.bfloat16 else\n (tl.float16 if a.dtype == torch.float16 or b.dtype == torch.float16 else tl.float32))\n grid = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv(chunk_size, META['BLOCK_SIZE_N']),\n batch, nchunks if not has_groups else nchunks * ngroups)\n with torch.cuda.device(a.device.index):\n _bmm_chunk_fwd_kernel[grid](\n a, b, out, seq_idx,\n seqlen, chunk_size, k, ngroups if has_groups else 1,\n a.stride(0), a.stride(1), 0 if not has_groups else a.stride(2), a.stride(-1),\n b.stride(0), b.stride(1), 0 if not has_groups else b.stride(2), b.stride(-1),\n out.stride(0), out.stride(1), 0 if not has_groups else out.stride(2), out.stride(-2), out.stride(-1),\n *((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)),\n causal,\n dot_dtype,\n HAS_SEQ_IDX=seq_idx is not None,\n )\n return out\n\n\ndef _bmm_chunk_bwd(a, dout, residual=None, out=None):\n has_groups = a.dim() == 4\n if not has_groups:\n batch, seqlen, k = a.shape\n else:\n batch, seqlen, ngroups, k = a.shape\n nchunks, chunk_size = dout.shape[1], dout.shape[-1]\n if a.stride(-1) != 1 and a.stride(-2) != 1:\n a = a.contiguous()\n if dout.stride(-1) != 1 and dout.stride(-2) != 1:\n dout = dout.contiguous()\n if residual is not None:\n assert residual.shape == (batch, seqlen, k) if not has_groups else (batch, seqlen, ngroups, k)\n if residual.stride(-1) != 1 and residual.stride(1) != 1:\n residual = residual.contiguous()\n if out is not None:\n assert out.shape == a.shape\n assert out.stride(-1) == 1 or out.stride(1) == 1\n else:\n out = torch.empty_like(a)\n dot_dtype = (tl.bfloat16 if a.dtype == torch.bfloat16 or dout.dtype == torch.bfloat16 else\n (tl.float16 if a.dtype == torch.float16 or dout.dtype == torch.float16 else tl.float32))\n grid = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv(k, META['BLOCK_SIZE_N']), batch,\n nchunks if not has_groups else nchunks * ngroups)\n residual_strides = ((residual.stride(0), residual.stride(1), 0 if not has_groups else residual.stride(2),\n residual.stride(-1))\n if residual is not None else (0, 0, 0, 0))\n with torch.cuda.device(a.device.index):\n _bmm_chunk_bwd_kernel[grid](\n a, dout, out, residual,\n seqlen, chunk_size, k, ngroups if has_groups else 1,\n a.stride(0), a.stride(1), 0 if not has_groups else a.stride(2), a.stride(-1),\n dout.stride(0), dout.stride(1), 0 if not has_groups else dout.stride(2), dout.stride(-2), dout.stride(-1),\n out.stride(0), out.stride(1), 0 if not has_groups else out.stride(2), out.stride(-1),\n residual_strides[0], residual_strides[1], residual_strides[2], residual_strides[3],\n dot_dtype,\n HAS_RESIDUAL=residual is not None,\n )\n return out\n", - "description_1": "Use triton language to implement a block matrix multiplication (BMM) forward and backward kernel for handling chunks of matrices. The forward kernel (_bmm_chunk_fwd_kernel) takes pointers to input matrices a and b, output matrix pointer out, and optional sequence index seq_idx. It calculates the block matrix multiplication for chunks and writes to the output. The backward kernel (_bmm_chunk_bwd_kernel) processes input matrix a, the gradient of the output matrix dout, and an optional residual matrix. It computes the gradient with respect to the input matrices for the block matrix multiplication operation. Both kernels optimize performance using grid and block strategies, and support batching, grouping, and causality constraints. Key arguments include block size parameters for optimization, matrix dimensions, and meta-parameters like IS_CAUSAL and HAS_SEQ_IDX.", - "description_2": "Use triton language to create efficient block matrix multiplication kernels for forward and backward passes, supporting batching, grouping, and causality.", - "difficulty": 3 - }, - { - "code": "import math\nimport torch\nimport triton\nimport triton.language as tl\nfrom einops import rearrange, repeat\nfrom ssd.uni.ssd_bmm import _bmm_chunk_fwd, _bmm_chunk_bwd\n\nTRITON_22 = version.parse(triton.__version__) >= version.parse('2.2.0')\n\n@triton.autotune(\n configs=[\n triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64}, num_stages=3, num_warps=8),\n triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2),\n triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2),\n triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=2),\n ],\n key=['chunk_size', 'hdim', 'dstate', 'IS_CAUSAL'],\n)\n@triton.jit\ndef _chunk_scan_fwd_kernel(\n cb_ptr, x_ptr, z_ptr, out_ptr, out_x_ptr, dt_ptr, dA_cumsum_ptr, seq_idx_ptr, C_ptr, prev_states_ptr, D_ptr,\n chunk_size, hdim, dstate,\n batch, seqlen, nheads_ngroups_ratio,\n stride_cb_batch, stride_cb_chunk, stride_cb_head, stride_cb_csize_m, stride_cb_csize_k,\n stride_x_batch, stride_x_seqlen, stride_x_head, stride_x_hdim,\n stride_z_batch, stride_z_seqlen, stride_z_head, stride_z_hdim,\n stride_out_batch, stride_out_seqlen, stride_out_head, stride_out_hdim,\n stride_dt_batch, stride_dt_chunk, stride_dt_head, stride_dt_csize,\n stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize,\n stride_seq_idx_batch, stride_seq_idx_seqlen,\n stride_C_batch, stride_C_seqlen, stride_C_head, stride_C_dstate,\n stride_states_batch, stride_states_chunk, stride_states_head, stride_states_hdim, stride_states_dstate,\n stride_D_head,\n IS_CAUSAL: tl.constexpr,\n HAS_D: tl.constexpr,\n D_HAS_HDIM: tl.constexpr,\n HAS_Z: tl.constexpr,\n HAS_SEQ_IDX: tl.constexpr,\n BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,\n BLOCK_SIZE_DSTATE: tl.constexpr,\n IS_TRITON_22: tl.constexpr,\n):\n # Triton kernel code...\n\ndef _chunk_scan_fwd(cb, x, dt, dA_cumsum, C, states, D=None, z=None, seq_idx=None):\n batch, seqlen, nheads, headdim = x.shape\n _, _, nchunks, chunk_size = dt.shape\n _, _, ngroups, dstate = C.shape\n assert nheads % ngroups == 0\n assert C.shape == (batch, seqlen, ngroups, dstate)\n assert cb.shape == (batch, nchunks, ngroups, chunk_size, chunk_size)\n if z is not None:\n assert z.shape == x.shape\n if D is not None:\n assert D.shape == (nheads, headdim) or D.shape == (nheads,)\n assert dt.shape == (batch, nheads, nchunks, chunk_size)\n assert dA_cumsum.shape == (batch, nheads, nchunks, chunk_size)\n assert states.shape == (batch, nchunks, nheads, headdim, dstate)\n if seq_idx is not None:\n assert seq_idx.shape == (batch, seqlen)\n\n out = torch.empty(batch, seqlen, nheads, headdim, device=x.device, dtype=x.dtype)\n if z is not None:\n out_x = torch.empty(batch, seqlen, nheads, headdim, device=x.device, dtype=x.dtype)\n else:\n out_x = None\n grid = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv(headdim, META['BLOCK_SIZE_N']),\n batch * nchunks, nheads)\n z_strides = ((z.stride(0), z.stride(1), z.stride(2), z.stride(3)) if z is not None else (0, 0, 0, 0))\n _chunk_scan_fwd_kernel[grid](\n cb, x, z, out, out_x, dt, dA_cumsum, seq_idx, C, states, D,\n chunk_size, headdim, dstate,\n batch, seqlen, nheads // ngroups,\n cb.stride(0), cb.stride(1), cb.stride(2), cb.stride(3), cb.stride(4),\n x.stride(0), x.stride(1), x.stride(2), x.stride(3),\n z_strides[0], z_strides[1], z_strides[2], z_strides[3],\n out.stride(0), out.stride(1), out.stride(2), out.stride(3),\n dt.stride(0), dt.stride(2), dt.stride(1), dt.stride(3),\n dA_cumsum.stride(0), dA_cumsum.stride(2), dA_cumsum.stride(1), dA_cumsum.stride(3),\n *((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)),\n C.stride(0), C.stride(1), C.stride(2), C.stride(3),\n states.stride(0), states.stride(1), states.stride(2), states.stride(3), states.stride(4),\n D.stride(0) if D is not None else 0,\n True,\n D is not None,\n D.dim() == 2 if D is not None else True,\n BLOCK_SIZE_DSTATE=max(triton.next_power_of_2(dstate), 16),\n HAS_Z=z is not None,\n HAS_SEQ_IDX=seq_idx is not None,\n IS_TRITON_22=TRITON_22,\n )\n return out, out_x\n", - "description_1": "Use triton language to implement a chunk-based forward scanning kernel (_chunk_scan_fwd_kernel) and its corresponding PyTorch wrapper function (_chunk_scan_fwd) for parallel processing of matrix blocks. The kernel is optimized for different configurations and supports optional features such as causal masking, additional input Z, and custom dimensional strides.", - "description_2": "Use triton language to create a kernel for efficient parallel processing of matrix blocks, optimized for configurations like block size and causal masking.", - "difficulty": 4 - }, - { - "code": "import math\nimport torch\nimport triton\nimport triton.language as tl\nfrom ssd.uni.softplus import softplus\n\n@triton.autotune(\n configs=[\n triton.Config({'BLOCK_SIZE_H': 1}),\n triton.Config({'BLOCK_SIZE_H': 2}),\n triton.Config({'BLOCK_SIZE_H': 4}),\n triton.Config({'BLOCK_SIZE_H': 8}),\n triton.Config({'BLOCK_SIZE_H': 16}),\n triton.Config({'BLOCK_SIZE_H': 32}),\n triton.Config({'BLOCK_SIZE_H': 64}),\n ],\n key=['chunk_size', 'nheads'],\n)\n@triton.jit\ndef _chunk_cumsum_fwd_kernel(\n dt_ptr, A_ptr, dt_bias_ptr, dt_out_ptr, dA_cumsum_ptr,\n batch, seqlen, nheads, chunk_size,\n dt_min, dt_max,\n stride_dt_batch, stride_dt_seqlen, stride_dt_head,\n stride_A_head,\n stride_dt_bias_head,\n stride_dt_out_batch, stride_dt_out_chunk, stride_dt_out_head, stride_dt_out_csize,\n stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize,\n DT_SOFTPLUS: tl.constexpr,\n HAS_DT_BIAS: tl.constexpr,\n BLOCK_SIZE_H: tl.constexpr, BLOCK_SIZE_CHUNK: tl.constexpr,\n):\n pid_b = tl.program_id(axis=0)\n pid_c = tl.program_id(axis=1)\n pid_h = tl.program_id(axis=2)\n dt_ptr += pid_b * stride_dt_batch + pid_c * chunk_size * stride_dt_seqlen\n dt_out_ptr += pid_b * stride_dt_out_batch + pid_c * stride_dt_out_chunk\n dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk\n\n offs_h = pid_h * BLOCK_SIZE_H + tl.arange(0, BLOCK_SIZE_H)\n offs_c = tl.arange(0, BLOCK_SIZE_CHUNK)\n dt_ptrs = dt_ptr + (offs_h[:, None] * stride_dt_head + offs_c[None, :] * stride_dt_seqlen)\n A_ptrs = A_ptr + offs_h * stride_A_head\n dt_out_ptrs = dt_out_ptr + (offs_h[:, None] * stride_dt_out_head + offs_c[None, :] * stride_dt_out_csize)\n dA_cs_ptrs = dA_cumsum_ptr + (offs_h[:, None] * stride_dA_cs_head + offs_c[None, :] * stride_dA_cs_csize)\n chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)\n\n dt = tl.load(dt_ptrs, mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit), other=0.0).to(tl.float32)\n if HAS_DT_BIAS:\n dt_bias = tl.load(dt_bias_ptr + offs_h * stride_dt_bias_head, mask=offs_h < nheads, other=0.0).to(tl.float32)\n dt += dt_bias[:, None]\n if DT_SOFTPLUS:\n dt = softplus(dt)\n dt = tl.minimum(tl.maximum(dt, dt_min), dt_max)\n dt = tl.where((offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit), dt, 0.0)\n tl.store(dt_out_ptrs, dt, mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size))\n A = tl.load(A_ptrs, mask=offs_h < nheads, other=0.0).to(tl.float32)\n dA = dt * A[:, None]\n dA_cs = tl.cumsum(dA, axis=1)\n tl.store(dA_cs_ptrs, dA_cs, mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size))\n\n\n@triton.autotune(\n configs=[\n triton.Config({'BLOCK_SIZE_H': 1}, pre_hook=init_to_zero([\"dA_ptr\", \"ddt_bias_ptr\"])),\n triton.Config({'BLOCK_SIZE_H': 2}, pre_hook=init_to_zero([\"dA_ptr\", \"ddt_bias_ptr\"])),\n triton.Config({'BLOCK_SIZE_H': 4}, pre_hook=init_to_zero([\"dA_ptr\", \"ddt_bias_ptr\"])),\n triton.Config({'BLOCK_SIZE_H': 8}, pre_hook=init_to_zero([\"dA_ptr\", \"ddt_bias_ptr\"])),\n triton.Config({'BLOCK_SIZE_H': 16}, pre_hook=init_to_zero([\"dA_ptr\", \"ddt_bias_ptr\"])),\n triton.Config({'BLOCK_SIZE_H': 32}, pre_hook=init_to_zero([\"dA_ptr\", \"ddt_bias_ptr\"])),\n triton.Config({'BLOCK_SIZE_H': 64}, pre_hook=init_to_zero([\"dA_ptr\", \"ddt_bias_ptr\"])),\n ],\n key=['chunk_size', 'nheads'],\n)\n@triton.jit\ndef _chunk_cumsum_bwd_kernel(\n ddA_ptr, ddt_out_ptr, dt_ptr, A_ptr, dt_bias_ptr,\n ddt_ptr, dA_ptr, ddt_bias_ptr,\n batch, seqlen, nheads, chunk_size,\n dt_min, dt_max,\n stride_ddA_batch, stride_ddA_chunk, stride_ddA_head, stride_ddA_csize,\n stride_ddt_out_batch, stride_ddt_out_chunk, stride_ddt_out_head, stride_ddt_out_csize,\n stride_dt_batch, stride_dt_seqlen, stride_dt_head,\n stride_A_head,\n stride_dt_bias_head,\n stride_ddt_batch, stride_ddt_seqlen, stride_ddt_head,\n stride_dA_head,\n stride_ddt_bias_head,\n DT_SOFTPLUS: tl.constexpr,\n HAS_DT_BIAS: tl.constexpr,\n BLOCK_SIZE_H: tl.constexpr, BLOCK_SIZE_CHUNK: tl.constexpr,\n):\n pid_b = tl.program_id(axis=0)\n pid_c = tl.program_id(axis=1)\n pid_h = tl.program_id(axis=2)\n ddt_out_ptr += pid_b * stride_ddt_out_batch + pid_c * stride_ddt_out_chunk\n ddA_ptr += pid_b * stride_ddA_batch + pid_c * stride_ddA_chunk\n dt_ptr += pid_b * stride_dt_batch + pid_c * chunk_size * stride_dt_seqlen\n ddt_ptr += pid_b * stride_ddt_batch + pid_c * chunk_size * stride_ddt_seqlen\n\n offs_h = pid_h * BLOCK_SIZE_H + tl.arange(0, BLOCK_SIZE_H)\n offs_c = tl.arange(0, BLOCK_SIZE_CHUNK)\n ddt_out_ptrs = ddt_out_ptr + (offs_h[:, None] * stride_ddt_out_head + offs_c[None, :] * stride_ddt_out_csize)\n ddA_ptrs = ddA_ptr + (offs_h[:, None] * stride_ddA_head + offs_c[None, :] * stride_ddA_csize)\n dt_ptrs = dt_ptr + (offs_h[:, None] * stride_dt_head + offs_c[None, :] * stride_dt_seqlen)\n ddt_ptrs = ddt_ptr + (offs_h[:, None] * stride_ddt_head + offs_c[None, :] * stride_ddt_seqlen)\n A_ptrs = A_ptr + offs_h * stride_A_head\n chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)\n\n ddA = tl.load(ddA_ptrs, mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit), other=0.0).to(tl.float32)\n ddt_out = tl.load(ddt_out_ptrs, mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit), other=0.0).to(tl.float32)\n A = tl.load(A_ptrs, mask=offs_h < nheads, other=0.0).to(tl.float32)\n ddt = ddA * A[:, None] + ddt_out\n dt = tl.load(dt_ptrs, mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit), other=0.0).to(tl.float32)\n if HAS_DT_BIAS:\n dt_bias = tl.load(dt_bias_ptr + offs_h * stride_dt_bias_head, mask=offs_h < nheads, other=0.0).to(tl.float32)\n dt += dt_bias[:, None]\n if DT_SOFTPLUS:\n dt_presoftplus = dt\n dt = softplus(dt)\n clamp_mask = (dt < dt_min) | (dt > dt_max)\n dt = tl.minimum(tl.maximum(dt, dt_min), dt_max)\n dt = tl.where((offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit), dt, 0.0)\n ddt = tl.where((offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit), ddt, 0.0)\n ddt = tl.where(clamp_mask, 0.0, ddt)\n if DT_SOFTPLUS:\n ddt = tl.where(dt_presoftplus <= 20.0, ddt * tl.sigmoid(dt_presoftplus), ddt)\n tl.store(ddt_ptrs, ddt, mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit))\n dA = tl.sum(ddA * dt, axis=1)\n tl.atomic_add(dA_ptr + offs_h * stride_dA_head, dA, mask=offs_h < nheads)\n if HAS_DT_BIAS:\n ddt_bias = tl.sum(ddt, axis=1)\n tl.atomic_add(ddt_bias_ptr + offs_h * stride_ddt_bias_head, ddt_bias, mask=offs_h < nheads)\n\n\ndef _chunk_cumsum_fwd(dt, A, chunk_size, dt_bias=None, dt_softplus=False, dt_limit=(0.0, float(\"inf\"))):\n batch, seqlen, nheads = dt.shape\n assert A.shape == (nheads,)\n if dt_bias is not None:\n assert dt_bias.shape == (nheads,)\n nchunks = math.ceil(seqlen / chunk_size)\n dt_out = torch.empty(batch, nheads, nchunks, chunk_size, device=dt.device, dtype=torch.float32)\n dA_cumsum = torch.empty(batch, nheads, nchunks, chunk_size, device=dt.device, dtype=torch.float32)\n grid_chunk_cs = lambda META: (batch, nchunks, triton.cdiv(nheads, META['BLOCK_SIZE_H']))\n with torch.cuda.device(dt.device.index):\n _chunk_cumsum_fwd_kernel[grid_chunk_cs](\n dt, A, dt_bias, dt_out, dA_cumsum,\n batch, seqlen, nheads, chunk_size,\n dt_limit[0], dt_limit[1],\n dt.stride(0), dt.stride(1), dt.stride(2),\n A.stride(0),\n dt_bias.stride(0) if dt_bias is not None else 0,\n dt_out.stride(0), dt_out.stride(2), dt_out.stride(1), dt_out.stride(3),\n dA_cumsum.stride(0), dA_cumsum.stride(2), dA_cumsum.stride(1), dA_cumsum.stride(3),\n dt_softplus,\n HAS_DT_BIAS=dt_bias is not None,\n BLOCK_SIZE_CHUNK=triton.next_power_of_2(chunk_size),\n )\n return dA_cumsum, dt_out\n\n\ndef _chunk_cumsum_bwd(ddA, ddt_out, dt, A, dt_bias=None, dt_softplus=False, dt_limit=(0.0, float(\"inf\")), ddt=None):\n batch, seqlen, nheads = dt.shape\n _, _, nchunks, chunk_size = ddA.shape\n assert ddA.shape == (batch, nheads, nchunks, chunk_size)\n assert ddt_out.shape == (batch, nheads, nchunks, chunk_size)\n assert A.shape == (nheads,)\n if dt_bias is not None:\n assert dt_bias.shape == (nheads,)\n ddt_bias = torch.empty_like(dt_bias, dtype=torch.float32)\n else:\n ddt_bias = None\n if ddt is not None:\n assert ddt.shape == dt.shape\n else:\n ddt = torch.empty_like(dt)\n dA = torch.empty_like(A, dtype=torch.float32)\n grid_chunk_cs = lambda META: (batch, nchunks, triton.cdiv(nheads, META['BLOCK_SIZE_H']))\n with torch.cuda.device(dt.device.index):\n _chunk_cumsum_bwd_kernel[grid_chunk_cs](\n ddA, ddt_out, dt, A, dt_bias, ddt, dA, ddt_bias,\n batch, seqlen, nheads, chunk_size,\n dt_limit[0], dt_limit[1],\n ddA.stride(0), ddA.stride(2), ddA.stride(1), ddA.stride(3),\n ddt_out.stride(0), ddt_out.stride(2), ddt_out.stride(1), ddt_out.stride(3),\n dt.stride(0), dt.stride(1), dt.stride(2),\n A.stride(0),\n dt_bias.stride(0) if dt_bias is not None else 0,\n ddt.stride(0), ddt.stride(1), ddt.stride(2),\n dA.stride(0),\n ddt_bias.stride(0) if ddt_bias is not None else 0,\n dt_softplus,\n HAS_DT_BIAS=dt_bias is not None,\n BLOCK_SIZE_CHUNK=triton.next_power_of_2(chunk_size),\n )\n return ddt, dA, ddt_bias\n", - "description_1": "Use triton language to implement forward and backward kernels for chunked cumulative sum operations. The forward kernel (_chunk_cumsum_fwd_kernel) takes 24 parameters: pointers to matrices (5), matrix dimensions (4), min/max values for clamping (2), strides (12), and meta-parameters (3). The backward kernel (_chunk_cumsum_bwd_kernel) takes 27 parameters: pointers to matrices (8), matrix dimensions (4), min/max values for clamping (2), strides (12), and meta-parameters (3). The forward function _chunk_cumsum_fwd calls the forward kernel with 15 parameters, and the backward function _chunk_cumsum_bwd calls the backward kernel with 16 parameters.", - "description_2": "Use triton language to create kernels for chunked cumulative sum operations with forward and backward passes, handling matrix pointers, dimensions, strides, and meta-parameters.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\nfrom einops import rearrange\n\nTRITON_22 = version.parse(triton.__version__) >= version.parse('2.2.0')\n\n@triton.autotune(\n configs=[\n triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64}, num_stages=3, num_warps=8, pre_hook=init_to_zero([\"ddt_ptr\"])),\n triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero([\"ddt_ptr\"])),\n triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero([\"ddt_ptr\"])),\n triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero([\"ddt_ptr\"])),\n triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero([\"ddt_ptr\"])),\n triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero([\"ddt_ptr\"])),\n triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=4, pre_hook=init_to_zero([\"ddt_ptr\"])),\n triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=4, pre_hook=init_to_zero([\"ddt_ptr\"])),\n triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero([\"ddt_ptr\"])),\n ],\n key=['chunk_size', 'hdim', 'dstate'],\n)\n@triton.jit\ndef _chunk_scan_chunk_state_bwd_dx_kernel(\n x_ptr, cb_ptr, dout_ptr, dt_ptr, dA_cumsum_ptr, seq_idx_ptr, D_ptr,\n b_ptr, dstates_ptr,\n dx_ptr, ddt_ptr, dD_ptr,\n chunk_size, hdim, dstate,\n batch, seqlen, nheads_ngroups_ratio,\n stride_x_batch, stride_x_seqlen, stride_x_head, stride_x_hdim,\n stride_cb_batch, stride_cb_chunk, stride_cb_head, stride_cb_csize_m, stride_cb_csize_k,\n stride_dout_batch, stride_dout_seqlen, stride_dout_head, stride_dout_hdim,\n stride_dt_batch, stride_dt_chunk, stride_dt_head, stride_dt_csize,\n stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize,\n stride_seq_idx_batch, stride_seq_idx_seqlen,\n stride_D_head,\n stride_b_batch, stride_b_seqlen, stride_b_head, stride_b_dstate,\n stride_dstates_batch, stride_dstates_chunk, stride_dstates_head, stride_dstates_hdim, stride_dstates_dstate,\n stride_dx_batch, stride_dx_seqlen, stride_dx_head, stride_dx_hdim,\n stride_ddt_batch, stride_ddt_chunk, stride_ddt_head, stride_ddt_csize,\n stride_dD_batch, stride_dD_chunk, stride_dD_head, stride_dD_csize, stride_dD_hdim,\n HAS_D: tl.constexpr,\n D_HAS_HDIM: tl.constexpr,\n HAS_SEQ_IDX: tl.constexpr,\n BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,\n BLOCK_SIZE_DSTATE: tl.constexpr,\n IS_TRITON_22: tl.constexpr,\n):\n pid_bc = tl.program_id(axis=1)\n pid_c = pid_bc // batch\n pid_b = pid_bc - pid_c * batch\n pid_h = tl.program_id(axis=2)\n num_pid_n = tl.cdiv(hdim, BLOCK_SIZE_N)\n pid_m = tl.program_id(axis=0) // num_pid_n\n pid_n = tl.program_id(axis=0) % num_pid_n\n x_ptr += pid_b * stride_x_batch + pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head\n cb_ptr += pid_b * stride_cb_batch + pid_c * stride_cb_chunk + (pid_h // nheads_ngroups_ratio) * stride_cb_head\n dout_ptr += pid_b * stride_dout_batch + pid_c * chunk_size * stride_dout_seqlen + pid_h * stride_dout_head\n dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head\n ddt_ptr += pid_b * stride_ddt_batch + pid_c * stride_ddt_chunk + pid_h * stride_ddt_head\n dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head\n b_ptr += pid_b * stride_b_batch + pid_c * chunk_size * stride_b_seqlen + (pid_h // nheads_ngroups_ratio) * stride_b_head\n dstates_ptr += pid_b * stride_dstates_batch + pid_c * stride_dstates_chunk + pid_h * stride_dstates_head\n if HAS_SEQ_IDX:\n seq_idx_ptr += pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen\n\n offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n\n chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)\n\n acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n\n dA_cs_m = tl.load(dA_cumsum_ptr + offs_m * stride_dA_cs_csize, mask=offs_m < chunk_size_limit, other=0.0).to(tl.float32)\n\n dA_cs_last = tl.load(dA_cumsum_ptr + (chunk_size - 1) * stride_dA_cs_csize).to(tl.float32)\n if not HAS_SEQ_IDX:\n scale = tl.exp(dA_cs_last - dA_cs_m)\n else:\n seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen, mask=offs_m < chunk_size_limit, other=-1)\n seq_idx_last = tl.load(seq_idx_ptr + (chunk_size_limit - 1) * stride_seq_idx_seqlen)\n scale = tl.where(seq_idx_m == seq_idx_last, tl.exp(dA_cs_last - dA_cs_m), 0.0)\n offs_dstate = tl.arange(0, BLOCK_SIZE_DSTATE if IS_TRITON_22 and BLOCK_SIZE_DSTATE <= 128 else BLOCK_SIZE_K)\n b_ptrs = b_ptr + (offs_m[:, None] * stride_b_seqlen + offs_dstate[None, :] * stride_b_dstate)\n dstates_ptrs = dstates_ptr + (offs_n[None, :] * stride_dstates_hdim + offs_dstate[:, None] * stride_dstates_dstate)\n if IS_TRITON_22 and BLOCK_SIZE_DSTATE <= 128:\n b = tl.load(b_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_dstate[None, :] < dstate), other=0.0)\n dstates = tl.load(dstates_ptrs, mask=(offs_dstate[:, None] < dstate) & (offs_n[None, :] < hdim), other=0.0)\n dstates = dstates.to(b_ptr.dtype.element_ty)\n acc = tl.dot(b, dstates) * scale[:, None]\n else:\n for k in range(0, dstate, BLOCK_SIZE_K):\n b = tl.load(b_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_dstate[None, :] < dstate - k), other=0.0)\n dstates = tl.load(dstates_ptrs, mask=(offs_dstate[:, None] < dstate - k) & (offs_n[None, :] < hdim), other=0.0)\n dstates = dstates.to(b_ptr.dtype.element_ty)\n acc += tl.dot(b, dstates)\n b_ptrs += BLOCK_SIZE_K * stride_b_dstate\n dstates_ptrs += BLOCK_SIZE_K * stride_dstates_dstate\n acc *= scale[:, None]\n\n offs_k = tl.arange(0, BLOCK_SIZE_K)\n cb_ptrs = cb_ptr + (offs_m[:, None] * stride_cb_csize_m + offs_k[None, :] * stride_cb_csize_k)\n dout_ptrs = dout_ptr + (offs_k[:, None] * stride_dout_seqlen + offs_n[None, :] * stride_dout_hdim)\n dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize\n K_MAX = chunk_size_limit\n K_MIN = pid_m * BLOCK_SIZE_M\n cb_ptrs += K_MIN * stride_cb_csize_k\n dout_ptrs += K_MIN * stride_dout_seqlen\n dA_cumsum_ptrs += K_MIN * stride_dA_cs_csize\n for k in range(K_MIN, K_MAX, BLOCK_SIZE_K):\n k = tl.multiple_of(k, BLOCK_SIZE_K)\n cb = tl.load(cb_ptrs, mask=(offs_m[:, None] < chunk_size) & (offs_k[None, :] < K_MAX - k), other=0.0)\n dout = tl.load(dout_ptrs, mask=(offs_k[:, None] < K_MAX - k) & (offs_n[None, :] < hdim), other=0.0)\n dA_cs_k = tl.load(dA_cumsum_ptrs, mask=offs_k < K_MAX - k, other=0.0).to(tl.float32)\n cb *= tl.exp(dA_cs_k[None, :] - dA_cs_m[:, None])\n mask = (k + offs_k[None, :] >= offs_m[:, None]) & (k + offs_k[None, :] < K_MAX)\n cb = tl.where(mask, cb, 0.0)\n cb = cb.to(dout_ptr.dtype.element_ty)\n acc += tl.dot(cb, dout)\n cb_ptrs += BLOCK_SIZE_K * stride_cb_csize_k\n dout_ptrs += BLOCK_SIZE_K * stride_dout_seqlen\n dA_cumsum_ptrs += BLOCK_SIZE_K * stride_dA_cs_csize\n\n offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n dt_ptrs = dt_ptr + offs_m * stride_dt_csize\n dt_m = tl.load(dt_ptrs, mask=offs_m < chunk_size_limit, other=0.0).to(tl.float32)\n dx = acc * dt_m[:, None]\n dx_ptr += pid_b * stride_dx_batch + pid_c * chunk_size * stride_dx_seqlen + pid_h * stride_dx_head\n dx_ptrs = dx_ptr + (offs_m[:, None] * stride_dx_seqlen + offs_n[None, :] * stride_dx_hdim)\n if HAS_D:\n dout_res_ptrs = dout_ptr + (offs_m[:, None] * stride_dout_seqlen + offs_n[None, :] * stride_dout_hdim)\n dout_res = tl.load(dout_res_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0).to(tl.float32)\n if D_HAS_HDIM:\n D = tl.load(D_ptr + pid_h * stride_D_head + offs_n, mask=offs_n < hdim, other=0.0).to(tl.float32)\n else:\n D = tl.load(D_ptr + pid_h * stride_D_head).to(tl.float32)\n dx += dout_res * D\n tl.store(dx_ptrs, dx, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim))\n\n x_ptrs = x_ptr + (offs_m[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim)\n x = tl.load(x_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0).to(tl.float32)\n if HAS_D:\n dD_ptr += pid_b * stride_dD_batch + pid_c * stride_dD_chunk + pid_h * stride_dD_head + pid_m * stride_dD_csize\n if D_HAS_HDIM:\n dD_ptrs = dD_ptr + offs_n * stride_dD_hdim\n dD = tl.sum(dout_res * x, axis=0)\n tl.store(dD_ptrs, dD, mask=offs_n < hdim)\n else:\n dD = tl.sum(dout_res * x)\n tl.store(dD_ptr, dD)\n ddt = tl.sum(acc * x, axis=1)\n ddt_ptrs = ddt_ptr + offs_m * stride_ddt_csize\n tl.atomic_add(ddt_ptrs, ddt, mask=offs_m < chunk_size)\n\n\ndef _chunk_scan_chunk_state_bwd_dx(x, dt, dA_cumsum, B, CB, dout, dstates, D=None, seq_idx=None, dx=None):\n batch, seqlen, nheads, headdim = x.shape\n _, _, nchunks, chunk_size = dt.shape\n _, _, ngroups, dstate = B.shape\n assert nheads % ngroups == 0\n assert B.shape == (batch, seqlen, ngroups, dstate)\n assert CB.shape == (batch, nchunks, ngroups, chunk_size, chunk_size)\n assert dt.shape == (batch, nheads, nchunks, chunk_size)\n assert dA_cumsum.shape == dt.shape\n assert dout.shape == x.shape\n assert dstates.shape == (batch, nchunks, nheads, headdim, dstate)\n if seq_idx is not None:\n assert seq_idx.shape == (batch, seqlen)\n if D is not None:\n assert D.shape == (nheads, headdim) or D.shape == (nheads,)\n assert D.stride(-1) == 1\n BLOCK_SIZE_min = 32\n dD = torch.empty(triton.cdiv(chunk_size, BLOCK_SIZE_min), batch, nchunks, nheads,\n headdim if D.dim() == 2 else 1, device=D.device, dtype=torch.float32)\n else:\n dD = None\n dD_strides = ((dD.stride(0), dD.stride(1), dD.stride(2), dD.stride(3), dD.stride(4))\n if D is not None else (0, 0, 0, 0, 0))\n if dx is None:\n dx = torch.empty_like(x)\n else:\n assert dx.shape == x.shape\n ddt = torch.empty(batch, nheads, nchunks, chunk_size, device=dout.device, dtype=torch.float32)\n grid_dx = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv(headdim, META['BLOCK_SIZE_N']),\n batch * nchunks, nheads)\n with torch.cuda.device(x.device.index):\n _chunk_scan_chunk_state_bwd_dx_kernel[grid_dx](\n x, CB, dout, dt, dA_cumsum, seq_idx, D, B, dstates, dx, ddt, dD,\n chunk_size, headdim, dstate,\n batch, seqlen, nheads // ngroups,\n x.stride(0), x.stride(1), x.stride(2), x.stride(3),\n CB.stride(0), CB.stride(1), CB.stride(2), CB.stride(-1), CB.stride(-2),\n dout.stride(0), dout.stride(1), dout.stride(2), dout.stride(3),\n dt.stride(0), dt.stride(2), dt.stride(1), dt.stride(3),\n dA_cumsum.stride(0), dA_cumsum.stride(2), dA_cumsum.stride(1), dA_cumsum.stride(3),\n *((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)),\n D.stride(0) if D is not None else 0,\n B.stride(0), B.stride(1), B.stride(2), B.stride(3),\n dstates.stride(0), dstates.stride(1), dstates.stride(2), dstates.stride(3), dstates.stride(4),\n dx.stride(0), dx.stride(1), dx.stride(2), dx.stride(3),\n ddt.stride(0), ddt.stride(2), ddt.stride(1), ddt.stride(3),\n dD_strides[1], dD_strides[2], dD_strides[3], dD_strides[0], dD_strides[4],\n D is not None,\n D.dim() == 2 if D is not None else True,\n HAS_SEQ_IDX=seq_idx is not None,\n BLOCK_SIZE_DSTATE=max(triton.next_power_of_2(dstate), 16),\n IS_TRITON_22=TRITON_22\n )\n if D is not None:\n BLOCK_SIZE_actual = _chunk_scan_chunk_state_bwd_dx_kernel.best_config.kwargs[\"BLOCK_SIZE_M\"]\n n_valid_blocks = (chunk_size + BLOCK_SIZE_actual - 1) // BLOCK_SIZE_actual\n dD = dD[:n_valid_blocks].sum(dim=(0, 1, 2)).to(dtype=D.dtype)\n if D.dim() == 1:\n dD = rearrange(dD, \"h 1 -> h\")\n return dx, ddt.to(dtype=dt.dtype), dD\n", - "description_1": "Use triton language to implement a backward pass kernel for a chunked scan operation, handling gradients with respect to input matrices and intermediate states.", - "description_2": "Use triton language to efficiently compute matrix operations and reductions for the backward pass of a chunked scan operation.", - "difficulty": 5 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.autotune(\n configs=[\n triton.Config({'BLOCK_SIZE': 64}),\n triton.Config({'BLOCK_SIZE': 128}),\n triton.Config({'BLOCK_SIZE': 256}),\n triton.Config({'BLOCK_SIZE': 512}),\n triton.Config({'BLOCK_SIZE': 1024}),\n triton.Config({'BLOCK_SIZE': 2048}),\n ],\n key=['dim'],\n)\n@triton.jit\ndef _state_passing_fwd_kernel(\n states_ptr, out_ptr, final_states_ptr, dA_cs_ptr, initstates_ptr, seq_idx_ptr,\n dim, nchunks, seqlen, chunk_size,\n stride_states_batch, stride_states_chunk, stride_states_head, stride_states_dim,\n stride_out_batch, stride_out_chunk, stride_out_head, stride_out_dim,\n stride_final_states_batch, stride_final_states_head, stride_final_states_dim,\n stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head,\n stride_initstates_batch, stride_initstates_head, stride_initstates_dim,\n stride_seq_idx_batch, stride_seq_idx_seqlen,\n HAS_INITSTATES: tl.constexpr,\n HAS_SEQ_IDX: tl.constexpr,\n BLOCK_SIZE: tl.constexpr,\n):\n pid_b = tl.program_id(axis=1)\n pid_h = tl.program_id(axis=2)\n pid_m = tl.program_id(axis=0)\n states_ptr += pid_b * stride_states_batch + pid_h * stride_states_head\n dA_cs_ptr += pid_b * stride_dA_cs_batch + pid_h * stride_dA_cs_head\n out_ptr += pid_b * stride_out_batch + pid_h * stride_out_head\n final_states_ptr += pid_b * stride_final_states_batch + pid_h * stride_final_states_head\n if HAS_INITSTATES:\n initstates_ptr += pid_b * stride_initstates_batch + pid_h * stride_initstates_head\n if HAS_SEQ_IDX:\n seq_idx_ptr += pid_b * stride_seq_idx_batch\n\n offs_m = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)\n states_ptrs = states_ptr + offs_m * stride_states_dim\n out_ptrs = out_ptr + offs_m * stride_out_dim\n final_states_ptrs = final_states_ptr + offs_m * stride_final_states_dim\n\n if not HAS_INITSTATES:\n states = tl.zeros((BLOCK_SIZE, ), dtype=tl.float32)\n else:\n initstates_ptrs = initstates_ptr + offs_m * stride_initstates_dim\n states = tl.load(initstates_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)\n tl.store(out_ptrs, states, mask=offs_m < dim)\n out_ptrs += stride_out_chunk\n seq_idx = 0\n for c in range(nchunks):\n new_states = tl.load(states_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)\n dA_cs = tl.load(dA_cs_ptr).to(tl.float32)\n scale = tl.exp(dA_cs)\n if HAS_SEQ_IDX:\n seq_idx_new = tl.load(seq_idx_ptr + (min((c + 1) * chunk_size, seqlen) - 1) * stride_seq_idx_seqlen)\n scale = tl.where(seq_idx_new == seq_idx, scale, 0.0)\n seq_idx = seq_idx_new\n states = scale * states + new_states\n if c < nchunks - 1:\n tl.store(out_ptrs, states, mask=offs_m < dim)\n else:\n tl.store(final_states_ptrs, states, mask=offs_m < dim)\n states_ptrs += stride_states_chunk\n dA_cs_ptr += stride_dA_cs_chunk\n out_ptrs += stride_out_chunk\n\n@triton.autotune(\n configs=[\n triton.Config({'BLOCK_SIZE': 64}),\n triton.Config({'BLOCK_SIZE': 128}),\n triton.Config({'BLOCK_SIZE': 256}),\n triton.Config({'BLOCK_SIZE': 512}),\n triton.Config({'BLOCK_SIZE': 1024}),\n triton.Config({'BLOCK_SIZE': 2048}),\n ],\n key=['dim'],\n)\n@triton.jit\ndef _state_passing_bwd_kernel(\n dout_ptr, out_ptr, dA_cs_ptr, dfinal_states_ptr, seq_idx_ptr,\n dstates_ptr, ddA_cs_ptr, dinitstates_ptr, states_converted_ptr,\n dim, nchunks, seqlen, chunk_size,\n stride_dout_batch, stride_dout_chunk, stride_dout_head, stride_dout_dim,\n stride_out_batch, stride_out_chunk, stride_out_head, stride_out_dim,\n stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head,\n stride_dfinal_states_batch, stride_dfinal_states_head, stride_dfinal_states_dim,\n stride_seq_idx_batch, stride_seq_idx_seqlen,\n stride_dstates_batch, stride_dstates_chunk, stride_dstates_head, stride_dstates_dim,\n stride_ddA_cs_batch, stride_ddA_cs_chunk, stride_ddA_cs_head,\n stride_dinitstates_batch, stride_dinitstates_head, stride_dinitstates_dim,\n CONVERT_STATES: tl.constexpr,\n HAS_DFINAL_STATES: tl.constexpr,\n HAS_DINITSTATES: tl.constexpr,\n HAS_SEQ_IDX: tl.constexpr,\n BLOCK_SIZE: tl.constexpr,\n):\n pid_b = tl.program_id(axis=1)\n pid_h = tl.program_id(axis=2)\n pid_m = tl.program_id(axis=0)\n dstates_ptr += pid_b * stride_dstates_batch + pid_h * stride_dstates_head + (nchunks - 1) * stride_dstates_chunk\n dA_cs_ptr += pid_b * stride_dA_cs_batch + pid_h * stride_dA_cs_head + (nchunks - 1) * stride_dA_cs_chunk\n ddA_cs_ptr += pid_b * stride_ddA_cs_batch + pid_h * stride_ddA_cs_head + (nchunks - 1) * stride_ddA_cs_chunk + pid_m\n out_ptr += pid_b * stride_out_batch + pid_h * stride_out_head + (nchunks - 1) * stride_out_chunk\n dout_ptr += pid_b * stride_dout_batch + pid_h * stride_dout_head + (nchunks - 1) * stride_dout_chunk\n if CONVERT_STATES:\n states_converted_ptr += pid_b * stride_out_batch + pid_h * stride_out_head + (nchunks - 1) * stride_out_chunk\n if HAS_DFINAL_STATES:\n dfinal_states_ptr += pid_b * stride_dfinal_states_batch + pid_h * stride_dfinal_states_head\n if HAS_DINITSTATES:\n dinitstates_ptr += pid_b * stride_dinitstates_batch + pid_h * stride_dinitstates_head\n if HAS_SEQ_IDX:\n seq_idx_ptr += pid_b * stride_seq_idx_batch\n\n offs_m = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)\n dstates_ptrs = dstates_ptr + offs_m * stride_dstates_dim\n out_ptrs = out_ptr + offs_m * stride_out_dim\n dout_ptrs = dout_ptr + offs_m * stride_dout_dim\n if CONVERT_STATES:\n states_converted_ptrs = states_converted_ptr + offs_m * stride_out_dim\n\n if HAS_DFINAL_STATES:\n dstates = tl.load(dfinal_states_ptr + offs_m * stride_dfinal_states_dim, mask=offs_m < dim, other=0.0).to(tl.float32)\n else:\n dstates = tl.zeros((BLOCK_SIZE, ), dtype=tl.float32)\n tl.store(dstates_ptrs, dstates, mask=offs_m < dim)\n if HAS_SEQ_IDX:\n seq_idx = tl.load(seq_idx_ptr + (seqlen - 1) * stride_seq_idx_seqlen)\n dstates_ptrs -= stride_dstates_chunk\n for c in range(nchunks - 1):\n dA_cs = tl.load(dA_cs_ptr).to(tl.float32)\n scale = tl.exp(dA_cs)\n if HAS_SEQ_IDX:\n seq_idx_new = tl.load(seq_idx_ptr + (((nchunks - c - 1) * chunk_size - 1) * stride_seq_idx_seqlen))\n scale = tl.where(seq_idx_new == seq_idx, scale, 0.0)\n seq_idx = seq_idx_new\n out = tl.load(out_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)\n if CONVERT_STATES:\n tl.store(states_converted_ptrs, out, mask=offs_m < dim)\n ddA = tl.sum(out * dstates) * scale\n tl.store(ddA_cs_ptr, ddA)\n dout = tl.load(dout_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)\n dstates = scale * dstates + dout\n tl.store(dstates_ptrs, dstates, mask=offs_m < dim)\n dout_ptrs -= stride_dout_chunk\n dstates_ptrs -= stride_dstates_chunk\n dA_cs_ptr -= stride_dA_cs_chunk\n ddA_cs_ptr -= stride_ddA_cs_chunk\n out_ptrs -= stride_out_chunk\n if CONVERT_STATES:\n states_converted_ptrs -= stride_out_chunk\n if CONVERT_STATES:\n out = tl.load(out_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)\n tl.store(states_converted_ptrs, out, mask=offs_m < dim)\n if not HAS_DINITSTATES:\n tl.store(ddA_cs_ptr, 0.0)\n else:\n dA_cs = tl.load(dA_cs_ptr).to(tl.float32)\n scale = tl.exp(dA_cs)\n if HAS_SEQ_IDX:\n scale = tl.where(seq_idx == 0, scale, 0.0)\n out = tl.load(out_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)\n ddA = tl.sum(out * dstates) * scale\n tl.store(ddA_cs_ptr, ddA)\n dout = tl.load(dout_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)\n dstates = scale * dstates + dout\n tl.store(dinitstates_ptr + offs_m * stride_dinitstates_dim, dstates, mask=offs_m < dim)\n\ndef _state_passing_fwd(states, dA_chunk_cumsum, initial_states=None, seq_idx=None, chunk_size=None,\n out_dtype=None):\n batch, nchunks, nheads, dim = states.shape\n assert dA_chunk_cumsum.shape == (batch, nheads, nchunks)\n if initial_states is not None:\n assert initial_states.shape == (batch, nheads, dim)\n if seq_idx is not None:\n assert chunk_size is not None\n seqlen = seq_idx.shape[-1]\n assert seq_idx.shape == (batch, seqlen)\n out_dtype = states.dtype if out_dtype is None else out_dtype\n out = torch.empty((batch, nchunks, nheads, dim), device=states.device, dtype=out_dtype)\n final_states = torch.empty((batch, nheads, dim), device=states.device, dtype=torch.float32)\n grid = lambda META: (triton.cdiv(dim, META['BLOCK_SIZE']), batch, nheads)\n with torch.cuda.device(states.device.index):\n _state_passing_fwd_kernel[grid](\n states, out, final_states, dA_chunk_cumsum, initial_states, seq_idx,\n dim, nchunks, seqlen if seq_idx is not None else 0, chunk_size if seq_idx is not None else 0,\n states.stride(0), states.stride(1), states.stride(2), states.stride(3),\n out.stride(0), out.stride(1), out.stride(2), out.stride(3),\n final_states.stride(0), final_states.stride(1), final_states.stride(2),\n dA_chunk_cumsum.stride(0), dA_chunk_cumsum.stride(2), dA_chunk_cumsum.stride(1),\n *((initial_states.stride(0), initial_states.stride(1), initial_states.stride(2))\n if initial_states is not None else (0, 0, 0)),\n *((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)),\n HAS_INITSTATES=initial_states is not None,\n HAS_SEQ_IDX=seq_idx is not None,\n )\n return out, final_states\n\ndef _state_passing_bwd(\n states, dA_chunk_cumsum, dout, dfinal_states=None, seq_idx=None, has_initial_states=None,\n dstates_dtype=None, states_dtype=None, chunk_size=None\n):\n batch, nchunks, nheads, dim = states.shape\n assert dA_chunk_cumsum.shape == (batch, nheads, nchunks)\n assert dout.shape == (batch, nchunks, nheads, dim)\n if seq_idx is not None:\n assert chunk_size is not None\n seqlen = seq_idx.shape[-1]\n assert seq_idx.shape == (batch, seqlen)\n dstates = torch.empty_like(dout, dtype=dstates_dtype if dstates_dtype is not None else dout.dtype)\n if states_dtype is not None and states_dtype != states.dtype:\n states_converted = torch.empty_like(states, dtype=dstates_dtype if dstates_dtype is not None else dout.dtype)\n assert states_converted.stride() == states.stride()\n else:\n states_converted = None\n if has_initial_states:\n dinitstates = torch.empty_like(dstates[:, 0])\n else:\n dinitstates = None\n if dfinal_states is not None:\n assert dfinal_states.shape == (batch, nheads, dim)\n BLOCK_SIZE_min = 64\n n_blocks = (dim + BLOCK_SIZE_min - 1) // BLOCK_SIZE_min\n ddA_chunk_cumsum = torch.empty(batch, nheads, nchunks, n_blocks,\n dtype=torch.float32, device=dA_chunk_cumsum.device)\n grid = lambda META: (triton.cdiv(dim, META['BLOCK_SIZE']), batch, nheads)\n with torch.cuda.device(dout.device.index):\n _state_passing_bwd_kernel[grid](\n dout, states, dA_chunk_cumsum, dfinal_states, seq_idx,\n dstates, ddA_chunk_cumsum, dinitstates, states_converted,\n dim, nchunks, seqlen if seq_idx is not None else 0, chunk_size if seq_idx is not None else 0,\n dout.stride(0), dout.stride(1), dout.stride(2), dout.stride(3),\n states.stride(0), states.stride(1), states.stride(2), states.stride(3),\n dA_chunk_cumsum.stride(0), dA_chunk_cumsum.stride(2), dA_chunk_cumsum.stride(1),\n *((dfinal_states.stride(0), dfinal_states.stride(1), dfinal_states.stride(2))\n if dfinal_states is not None else (0, 0, 0)),\n *((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)),\n dstates.stride(0), dstates.stride(1), dstates.stride(2), dstates.stride(3),\n ddA_chunk_cumsum.stride(0), ddA_chunk_cumsum.stride(2), ddA_chunk_cumsum.stride(1),\n *((dinitstates.stride(0), dinitstates.stride(1), dinitstates.stride(2))\n if dinitstates is not None else (0, 0, 0)),\n CONVERT_STATES=states_converted is not None,\n HAS_DFINAL_STATES=dfinal_states is not None,\n HAS_DINITSTATES=dinitstates is not None,\n HAS_SEQ_IDX=seq_idx is not None,\n )\n BLOCK_SIZE_actual = _state_passing_bwd_kernel.best_config.kwargs[\"BLOCK_SIZE\"]\n n_valid_blocks = (dim + BLOCK_SIZE_actual - 1) // BLOCK_SIZE_actual\n ddA_chunk_cumsum = ddA_chunk_cumsum[..., :n_valid_blocks].sum(dim=-1).to(dtype=dA_chunk_cumsum.dtype)\n if states_dtype is not None and states_dtype == states.dtype:\n states_converted = states\n return (dstates, ddA_chunk_cumsum, dinitstates) if states_dtype is None else (dstates, ddA_chunk_cumsum, dinitstates, states_converted)\n", - "description_1": "Use triton language to implement forward and backward state passing kernels. The forward kernel (_state_passing_fwd_kernel) takes 25 parameters: pointers to input matrices, matrix dimensions, strides, and meta-parameters. It computes the forward pass of state passing with optional initial states and sequence indices. The backward kernel (_state_passing_bwd_kernel) takes 30 parameters: pointers to input matrices, matrix dimensions, strides, and meta-parameters. It computes the backward pass of state passing, handling gradients and optional sequence indices.", - "description_2": "Use triton language to create kernels for forward and backward state passing operations, handling optional initial states and sequence indices, and computing gradients.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n_SUPPORTED_SIZES = {16, 32, 64, 128}\n\ndef _get_configs():\n configs = []\n for block_m in [64, 128, 256]:\n for block_n in [32, 64, 128]:\n for num_stage in [3, 4, 5, 6, 7, 8]:\n for num_warps in [4, 8]:\n configs.append(\n triton.Config(\n {\"BLOCK_M\": block_m, \"BLOCK_N\": block_n},\n num_warps=num_warps,\n num_stages=num_stage,\n )\n )\n return configs\n\n@triton.autotune(\n configs=_get_configs(),\n key=[\"N_CTX\", \"H\", \"Z\"],\n)\n@triton.heuristics({\"EVEN_CTX\": lambda args: args[\"N_CTX\"] % args[\"BLOCK_M\"] == 0})\n@triton.jit\ndef _fwd_kernel(\n Q,\n K,\n V,\n sm_scale,\n qkv_scale_ptr,\n out_scale_ptr,\n Out,\n stride_qz,\n stride_qh,\n stride_qm,\n stride_qk,\n stride_kz,\n stride_kh,\n stride_kn,\n stride_kk,\n stride_vz,\n stride_vh,\n stride_vk,\n stride_vn,\n stride_oz,\n stride_oh,\n stride_om,\n stride_on,\n Z,\n H,\n N_CTX,\n EVEN_CTX: tl.constexpr,\n BLOCK_M: tl.constexpr,\n BLOCK_DMODEL: tl.constexpr,\n BLOCK_N: tl.constexpr,\n):\n start_m = tl.program_id(0)\n off_hz = tl.program_id(1)\n off_z = off_hz // H\n off_h = off_hz % H\n qvk_offset = off_z.to(tl.int64) * stride_qz + off_h.to(tl.int64) * stride_qh\n\n Q_block_ptr = tl.make_block_ptr(\n base=Q + qvk_offset,\n shape=(N_CTX, BLOCK_DMODEL),\n strides=(stride_qm, stride_qk),\n offsets=(start_m * BLOCK_M, 0),\n block_shape=(BLOCK_M, BLOCK_DMODEL),\n order=(1, 0),\n )\n K_block_ptr = tl.make_block_ptr(\n base=K + qvk_offset,\n shape=(BLOCK_DMODEL, N_CTX),\n strides=(stride_kk, stride_kn),\n offsets=(0, 0),\n block_shape=(BLOCK_DMODEL, BLOCK_N),\n order=(0, 1),\n )\n V_block_ptr = tl.make_block_ptr(\n base=V + qvk_offset,\n shape=(N_CTX, BLOCK_DMODEL),\n strides=(stride_vk, stride_vn),\n offsets=(0, 0),\n block_shape=(BLOCK_N, BLOCK_DMODEL),\n order=(1, 0),\n )\n m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float(\"inf\")\n l_i = tl.zeros([BLOCK_M], dtype=tl.float32)\n acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)\n\n qkv_scale = tl.load(qkv_scale_ptr)\n qk_scale = qkv_scale * qkv_scale * sm_scale * 1.44269504\n\n if EVEN_CTX:\n q = tl.load(Q_block_ptr)\n else:\n q = tl.load(Q_block_ptr, boundary_check=(0,), padding_option=\"zero\")\n for start_n in range(0, N_CTX, BLOCK_N):\n start_n = tl.multiple_of(start_n, BLOCK_N)\n if EVEN_CTX:\n k = tl.load(K_block_ptr)\n else:\n k = tl.load(K_block_ptr, boundary_check=(1,), padding_option=\"zero\")\n qk = tl.dot(q, k, allow_tf32=False, out_dtype=tl.int32)\n qk_fp32 = qk * qk_scale\n\n m_ij = tl.maximum(m_i, tl.max(qk_fp32, 1))\n p = tl.math.exp2(qk_fp32 - m_ij[:, None])\n alpha = tl.math.exp2(m_i - m_ij)\n m_i = m_ij\n if EVEN_CTX:\n v = tl.load(V_block_ptr)\n else:\n v = tl.load(V_block_ptr, boundary_check=(0,), padding_option=\"zero\")\n v = (v * qkv_scale).to(tl.bfloat16)\n acc *= alpha[:, None]\n acc += tl.dot(\n p.to(tl.bfloat16),\n v,\n allow_tf32=True,\n )\n l_i = l_i * alpha + tl.sum(p, 1)\n K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N))\n V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0))\n\n out_scale = tl.load(out_scale_ptr)\n acc = tl.math.llrint(acc / (l_i[:, None] * out_scale)).to(tl.int8)\n O_block_ptr = tl.make_block_ptr(\n base=Out + qvk_offset,\n shape=(N_CTX, BLOCK_DMODEL),\n strides=(stride_om, stride_on),\n offsets=(start_m * BLOCK_M, 0),\n block_shape=(BLOCK_M, BLOCK_DMODEL),\n order=(1, 0),\n )\n if EVEN_CTX:\n tl.store(O_block_ptr, acc)\n else:\n tl.store(O_block_ptr, acc, boundary_check=(0,))\n\nclass _attention(torch.autograd.Function):\n @staticmethod\n def forward(\n ctx,\n q,\n k,\n v,\n sm_scale,\n qkv_scale,\n out_scale,\n ):\n capability = torch.cuda.get_device_capability()\n if capability[0] < 8:\n raise RuntimeError(\"Flash attention currently only supported for compute capability >= 80\")\n Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]\n assert Lq == Lk and Lk == Lv\n assert Lk in _SUPPORTED_SIZES\n o = torch.empty_like(q)\n grid = lambda META: (\n triton.cdiv(q.shape[2], META[\"BLOCK_M\"]),\n q.shape[0] * q.shape[1],\n 1,\n )\n if isinstance(qkv_scale, float):\n qkv_scale = torch.tensor(qkv_scale, device=q.device)\n out_scale = torch.tensor(out_scale, device=q.device)\n _fwd_kernel[grid](\n q,\n k,\n v,\n sm_scale,\n qkv_scale,\n out_scale,\n o,\n q.stride(0),\n q.stride(1),\n q.stride(2),\n q.stride(3),\n k.stride(0),\n k.stride(1),\n k.stride(2),\n k.stride(3),\n v.stride(0),\n v.stride(1),\n v.stride(2),\n v.stride(3),\n o.stride(0),\n o.stride(1),\n o.stride(2),\n o.stride(3),\n q.shape[0],\n q.shape[1],\n q.shape[2],\n BLOCK_DMODEL=Lk,\n )\n return o\n", - "description_1": "Use triton language to implement a forward kernel for fused attention. The kernel takes 25 parameters: Q, K, V (input tensors), sm_scale, qkv_scale_ptr, out_scale_ptr (scaling factors), Out (output tensor), 16 stride parameters for tensor dimensions, Z, H, N_CTX (context dimensions), and 3 constexpr parameters (EVEN_CTX, BLOCK_M, BLOCK_DMODEL, BLOCK_N). The kernel computes scaled dot-product attention using block pointers and stores the result in the output tensor.", - "description_2": "Use triton language to implement a fused attention forward function. The function takes 6 parameters: q, k, v (input tensors), sm_scale, qkv_scale, out_scale (scaling factors). It checks device capability, asserts shape constraints, and calls the triton kernel to compute the attention output.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef _kernel(\n A,\n B,\n C,\n bias,\n M,\n N,\n K,\n stride_am,\n stride_ak,\n stride_bk,\n stride_bn,\n stride_cm,\n stride_cn,\n a_scale_ptr,\n b_scale_ptr,\n out_scale_ptr,\n BLOCK_M: tl.constexpr,\n BLOCK_N: tl.constexpr,\n BLOCK_K: tl.constexpr,\n GROUP_M: tl.constexpr,\n EVEN_K: tl.constexpr,\n BIAS_ADD: tl.constexpr,\n A_PER_CHANNEL: tl.constexpr,\n B_PER_CHANNEL: tl.constexpr,\n):\n pid = tl.program_id(0)\n pid_z = tl.program_id(1)\n grid_m = tl.cdiv(M, BLOCK_M)\n grid_n = tl.cdiv(N, BLOCK_N)\n\n width = GROUP_M * grid_n\n group_id = pid // width\n group_size = min(grid_m - group_id * GROUP_M, GROUP_M)\n pid_m = group_id * GROUP_M + (pid % group_size)\n pid_n = (pid % width) // (group_size)\n\n rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)\n rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)\n ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)\n rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)\n rk = pid_z * BLOCK_K + tl.arange(0, BLOCK_K)\n\n A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)\n B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)\n acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.int32)\n for k in range(0, tl.cdiv(K, BLOCK_K)):\n if EVEN_K:\n a = tl.load(A)\n b = tl.load(B)\n else:\n k_remaining = K - k * BLOCK_K\n _0 = tl.zeros((1, 1), dtype=tl.int8)\n a = tl.load(A, mask=rk[None, :] < k_remaining, other=_0)\n b = tl.load(B, mask=rk[:, None] < k_remaining, other=_0)\n acc += tl.dot(a, b, allow_tf32=True, out_dtype=tl.int32)\n A += BLOCK_K * stride_ak\n B += BLOCK_K * stride_bk\n if A_PER_CHANNEL:\n _0 = tl.zeros((1,), dtype=a_scale_ptr.dtype.element_ty)\n mask = ram < M\n a_scale = tl.load(a_scale_ptr + ram, mask=mask, other=_0)\n else:\n a_scale = tl.load(a_scale_ptr)\n if B_PER_CHANNEL:\n _0 = tl.zeros((1,), dtype=b_scale_ptr.dtype.element_ty)\n mask = rbn < N\n b_scale = tl.load(b_scale_ptr + rbn, mask=mask, other=_0)\n else:\n b_scale = tl.load(b_scale_ptr)\n if BIAS_ADD:\n bias = tl.load(bias + rn)\n if A_PER_CHANNEL and B_PER_CHANNEL:\n bias = tl.math.llrint(bias / (a_scale[:, None] * b_scale[None, :])).to(tl.int32)\n acc = acc + bias\n else:\n bias = tl.math.llrint(bias / (a_scale * b_scale)).to(tl.int32)\n acc = acc + bias[None, :]\n\n if A_PER_CHANNEL and B_PER_CHANNEL:\n mask = ram < M\n _0 = tl.zeros((1,), dtype=out_scale_ptr.dtype.element_ty)\n out_scale = tl.load(out_scale_ptr + ram, mask=mask, other=_0)\n acc = tl.math.llrint((acc.to(tl.float32) * a_scale[:, None] * b_scale[None, :] * out_scale[:, None])).to(\n tl.int8\n )\n else:\n out_scale = tl.load(out_scale_ptr)\n acc = tl.math.llrint((acc.to(tl.float32) * (a_scale * b_scale * out_scale))).to(tl.int8)\n\n rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)\n rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)\n C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)\n mask = (rm < M)[:, None] & (rn < N)[None, :]\n tl.store(C, acc, mask=mask)\n\n\nclass _matmul(torch.autograd.Function):\n kernel = _kernel\n\n @staticmethod\n def forward(\n ctx,\n input,\n other,\n input_scale,\n other_scale,\n out_scale,\n bias=None,\n a_per_channel=False,\n b_per_channel=False,\n ) -> torch.Tensor:\n device = input.device\n if input.stride(0) > 1 and input.stride(1) > 1:\n input = input.contiguous()\n if other.stride(0) > 1 and other.stride(1) > 1:\n other = other.contiguous()\n assert input.shape[1] == other.shape[0], \"incompatible dimensions\"\n M, K = input.shape\n _, N = other.shape\n c = torch.empty((M, N), device=device, dtype=torch.int8)\n grid = lambda META: (triton.cdiv(M, META[\"BLOCK_M\"]) * triton.cdiv(N, META[\"BLOCK_N\"]),) # noqa: E731\n BIAS_ADD = 0 if bias is None else 1\n _kernel[grid](\n input,\n other,\n c,\n bias,\n M,\n N,\n K,\n input.stride(0),\n input.stride(1),\n other.stride(0),\n other.stride(1),\n c.stride(0),\n c.stride(1),\n a_scale_ptr=input_scale,\n b_scale_ptr=other_scale,\n out_scale_ptr=1.0 / out_scale,\n GROUP_M=8,\n BIAS_ADD=BIAS_ADD,\n A_PER_CHANNEL=a_per_channel,\n B_PER_CHANNEL=b_per_channel,\n )\n return c\n", - "description_1": "Use triton language to implement a quantized matrix multiplication kernel with support for per-channel scaling and optional bias addition. The kernel takes 21 parameters: matrices A, B, C, optional bias, dimensions M, N, K, strides for A, B, C, scaling factors for A, B, output, and several compile-time constants for block sizes and flags. The forward function in the _matmul class prepares inputs, sets up the execution grid, and calls the kernel with 17 parameters including input matrices, output matrix, dimensions, strides, scaling factors, and flags.", - "description_2": "Use triton language to create a quantized matrix multiplication operator with per-channel scaling and optional bias, using a kernel with 21 parameters and a forward function with 17 parameters.", - "difficulty": 4 - }, - { - "code": "import triton\nimport triton.language as tl\n\n@triton.jit\ndef clamp(x: tl.tensor, min_val, max_val) -> tl.tensor:\n \"\"\"Clamps all elements in `x` into range [min, max].\n\n Args:\n x (tl.tensor): the input tensor.\n min_val (Number): lower bound of the range.\n max_val (Number): upper bound of the range.\n\n Returns:\n tl.tensor: the output tensor.\n \"\"\"\n return tl.math.min(tl.math.max(x, min_val), max_val)\n\n@triton.jit\ndef dequantize(x: tl.tensor, scale: tl.tensor) -> tl.tensor:\n \"\"\"Dequantize quantized tensor to floating point.\n\n Args:\n x (tl.tensor): quantized tensor.\n scale (tl.tensor): quantization scaling factor\n\n Returns:\n tl.tensor: Dequantized floating-point tensor.\n \"\"\"\n return (x * scale).to(tl.float32)\n\n@triton.jit\ndef quantize(x, scale, qmin, qmax) -> tl.tensor:\n \"\"\"Quantize the tensor given quantization scale and data type.\n\n Args:\n x (tl.tensor): floating-point tensor\n scale (tl.tensor): quantization scale factor.\n qmin (Number): quantization minimum range.\n qmax (Number): quantization maximum range\n\n Returns:\n tl.tensor: rounded and clamped tensor.\n Note: this is still in floating point as we can't pass dtype to function\n\n Example:\n \n out = quantize(out, scale, -128, 127).to(tl.int8)\n \"\"\"\n return clamp(tl.math.round(x / scale), qmin, qmax)\n", - "description_1": "Use triton language to implement three functions: 'clamp', 'dequantize', and 'quantize'. The 'clamp' function takes three arguments: a tensor 'x', a minimum value 'min_val', and a maximum value 'max_val', and returns a tensor with all elements clamped within the specified range. The 'dequantize' function takes two arguments: a quantized tensor 'x' and a scaling factor 'scale', and returns a dequantized floating-point tensor. The 'quantize' function takes four arguments: a floating-point tensor 'x', a scaling factor 'scale', a minimum quantization range 'qmin', and a maximum quantization range 'qmax', and returns a rounded and clamped tensor.", - "description_2": "Use triton language to create a 'clamp' function to restrict tensor values within a range, a 'dequantize' function to convert quantized tensors to floating-point, and a 'quantize' function to quantize floating-point tensors with specified scale and range.", - "difficulty": 2 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef _kernel(\n A,\n B,\n C,\n bias,\n M,\n N,\n K,\n stride_am,\n stride_ak,\n stride_bk,\n stride_bn,\n stride_cm,\n stride_cn,\n a_scale_ptr,\n b_scale_ptr,\n out_dtype: tl.constexpr,\n BLOCK_M: tl.constexpr,\n BLOCK_N: tl.constexpr,\n BLOCK_K: tl.constexpr,\n GROUP_M: tl.constexpr,\n SPLIT_K: tl.constexpr,\n EVEN_K: tl.constexpr,\n BIAS_ADD: tl.constexpr,\n A_PER_CHANNEL: tl.constexpr,\n B_PER_CHANNEL: tl.constexpr,\n):\n \"\"\"Kernel for computing the matmul C = A x B (optional + bias).\n A has shape (M, K), B has shape (K, N) and C has shape (M, N)\n \"\"\"\n pid = tl.program_id(0)\n pid_z = tl.program_id(1)\n grid_m = tl.cdiv(M, BLOCK_M)\n grid_n = tl.cdiv(N, BLOCK_N)\n\n width = GROUP_M * grid_n\n group_id = pid // width\n group_size = min(grid_m - group_id * GROUP_M, GROUP_M)\n pid_m = group_id * GROUP_M + (pid % group_size)\n pid_n = (pid % width) // (group_size)\n\n rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)\n rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)\n ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)\n rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)\n rk = pid_z * BLOCK_K + tl.arange(0, BLOCK_K)\n # pointers\n A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)\n B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)\n acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.int32)\n if A_PER_CHANNEL:\n a_scale = tl.load(a_scale_ptr + ram)\n else:\n a_scale = tl.load(a_scale_ptr)\n\n for k in range(0, tl.cdiv(K, BLOCK_K * SPLIT_K)):\n if EVEN_K:\n a = tl.load(A)\n b = tl.load(B)\n else:\n k_remaining = K - k * BLOCK_K\n _0 = tl.zeros((1, 1), dtype=tl.int8)\n a = tl.load(A, mask=rk[None, :] < k_remaining, other=_0)\n b = tl.load(B, mask=rk[:, None] < k_remaining, other=_0)\n if A_PER_CHANNEL:\n a = tl.math.llrint((a / a_scale[:, None])).to(tl.int8)\n else:\n a = tl.math.llrint((a / a_scale)).to(tl.int8)\n acc += tl.dot(a, b, allow_tf32=True, out_dtype=tl.int32)\n A += BLOCK_K * SPLIT_K * stride_ak\n B += BLOCK_K * SPLIT_K * stride_bk\n if B_PER_CHANNEL:\n b_scale = tl.load(b_scale_ptr + rbn)\n else:\n b_scale = tl.load(b_scale_ptr)\n if A_PER_CHANNEL and B_PER_CHANNEL:\n acc = (acc.to(tl.float32) * (a_scale[:, None] * b_scale[None, :])).to(out_dtype)\n else:\n acc = (acc.to(tl.float32) * (a_scale * b_scale)).to(out_dtype)\n if BIAS_ADD:\n bias = tl.load(bias + rn)\n acc = acc + bias[None, :]\n # rematerialize rm and rn to save registers\n rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)\n rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)\n C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)\n mask = (rm < M)[:, None] & (rn < N)[None, :]\n # handles write-back with reduction-splitting\n if SPLIT_K == 1:\n tl.store(C, acc, mask=mask)\n else:\n tl.atomic_add(C, acc, mask=mask)\n\n\nclass _matmul(torch.autograd.Function):\n kernel = _kernel\n\n @staticmethod\n def forward(\n ctx,\n a,\n b,\n a_scale,\n b_scale,\n bias=None,\n a_per_channel=False,\n b_per_channel=False,\n ):\n device = a.device\n out_dtype = a.dtype\n # handle non-contiguous inputs if necessary\n if a.stride(0) > 1 and a.stride(1) > 1:\n a = a.contiguous()\n if b.stride(0) > 1 and b.stride(1) > 1:\n b = b.contiguous()\n # checks constraints\n assert a.shape[1] == b.shape[0], \"incompatible dimensions\"\n M, K = a.shape\n _, N = b.shape\n # allocates output\n c = torch.empty((M, N), device=device, dtype=out_dtype)\n tl_outdtype = tl.float32\n if out_dtype == torch.float16:\n tl_outdtype = tl.float16\n elif out_dtype == torch.bfloat16:\n tl_outdtype = tl.bfloat16\n # launch kernel\n grid = lambda META: ( # noqa: E731\n triton.cdiv(M, META[\"BLOCK_M\"]) * triton.cdiv(N, META[\"BLOCK_N\"]),\n META[\"SPLIT_K\"],\n )\n BIAS_ADD = 0 if bias is None else 1\n _kernel[grid](\n a,\n b,\n c,\n bias,\n M,\n N,\n K,\n a.stride(0),\n a.stride(1),\n b.stride(0),\n b.stride(1),\n c.stride(0),\n c.stride(1),\n GROUP_M=8,\n BIAS_ADD=BIAS_ADD,\n a_scale_ptr=a_scale,\n b_scale_ptr=b_scale,\n out_dtype=tl_outdtype,\n A_PER_CHANNEL=a_per_channel,\n B_PER_CHANNEL=b_per_channel,\n )\n return c\n", - "description_1": "Use triton language to implement a kernel for quantized dynamic matrix multiplication. The kernel function '_kernel' takes 22 parameters including input matrices A and B, output matrix C, optional bias, dimensions M, N, K, strides for A, B, C, scale pointers, output data type, block sizes, group size, split factor, and flags for even K, bias addition, and per-channel scaling. The function performs matrix multiplication with optional bias addition and scaling, handling reduction-splitting if necessary. The '_matmul' class wraps this kernel for use in PyTorch's autograd, with a forward method that prepares inputs, checks constraints, allocates output, and launches the kernel with appropriate grid configuration.", - "description_2": "Use triton language to create a quantized dynamic matrix multiplication kernel with optional bias and scaling, and integrate it with PyTorch's autograd.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n\n@triton.jit\ndef example_kernel(X_ptr, Y_ptr, N, BLOCK_SIZE: tl.constexpr):\n # Triton kernel to perform element-wise addition\n pid = tl.program_id(0)\n offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)\n mask = offset < N\n x = tl.load(X_ptr + offset, mask=mask)\n y = tl.load(Y_ptr + offset, mask=mask)\n z = x + y\n tl.store(Y_ptr + offset, z, mask=mask)\n\n\ndef call_example_kernel(X, Y, N):\n BLOCK_SIZE = 1024\n grid = lambda meta: (triton.cdiv(N, BLOCK_SIZE),)\n example_kernel[grid](X, Y, N, BLOCK_SIZE)\n", - "description_1": "Use triton language to define a kernel that performs element-wise addition on two arrays. The kernel is called example_kernel and takes in four parameters: X_ptr (pointer to the first array), Y_ptr (pointer to the second array), N (total number of elements), and BLOCK_SIZE (number of elements to process per block). The kernel loads a block of elements from each array, performs the addition, and stores the result back into the second array. The call_example_kernel function sets the BLOCK_SIZE and grid, then calls the kernel with the input arrays, result array, and element count.", - "description_2": "Use triton language to perform element-wise addition on two arrays using a specified block size for processing.", - "difficulty": 2 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef rmsnorm_triton(\n x_ptr,\n rms_w_ptr,\n output_ptr,\n stride_x_batch,\n stride_x_m,\n stride_x_k,\n stride_rms_w,\n stride_out_batch,\n stride_out_m,\n stride_out_k,\n N_SIZE: tl.constexpr,\n eps: tl.constexpr,\n BLOCK_N_SIZE: tl.constexpr,\n):\n pid_batch = tl.program_id(0)\n pid_m = tl.program_id(1)\n\n offs_m = pid_batch * stride_x_batch + pid_m * stride_x_m\n block_N = tl.arange(0, BLOCK_N_SIZE)\n var = tl.zeros((BLOCK_N_SIZE,), tl.float32)\n for block_n_start_idx in range(0, N_SIZE, BLOCK_N_SIZE):\n offs_n = block_n_start_idx + block_N\n x_ptr_mask = offs_n < N_SIZE\n x = tl.load(x_ptr + offs_m + offs_n * stride_x_k, mask=x_ptr_mask, other=0.0)\n var += x.to(tl.float32) * x.to(tl.float32)\n\n var = tl.sum(var, axis=0) / N_SIZE\n rstd = 1 / tl.math.sqrt(var + eps)\n\n for block_n_start_idx in range(0, N_SIZE, BLOCK_N_SIZE):\n offs_n = block_n_start_idx + block_N\n x_ptr_mask = offs_n < N_SIZE\n rms_w = tl.load(rms_w_ptr + offs_n * stride_rms_w, mask=x_ptr_mask)\n\n x = tl.load(\n x_ptr + offs_m + offs_n * stride_x_k, mask=x_ptr_mask, other=0.0\n ).to(tl.float32)\n x_hat = x * rstd\n out = x_hat * rms_w\n out_off = (\n pid_batch * stride_out_batch + pid_m * stride_out_m + offs_n * stride_out_k\n )\n tl.store(output_ptr + out_off, out, mask=x_ptr_mask)\n\n\ndef rmsnorm_triton_wrapper(x, rms_w, eps=1e-6):\n out = torch.empty_like(x)\n if len(x.shape) == 3:\n batch, M, K = x.shape\n stride_x_batch, stride_x_m, stride_x_k = x.stride()\n stride_rms_w = rms_w.stride()[0]\n stride_out_batch, stride_out_m, stride_out_k = out.stride()\n else:\n batch, K = x.shape\n M = 1\n stride_x_batch, stride_x_k = x.stride()\n stride_x_m = 1\n stride_rms_w = rms_w.stride()[0]\n stride_out_batch, stride_out_k = out.stride()\n stride_out_m = 1\n assert rms_w.shape[-1] == K\n\n rmsnorm_triton[\n (\n batch,\n M,\n )\n ](\n x,\n rms_w,\n out,\n stride_x_batch,\n stride_x_m,\n stride_x_k,\n stride_rms_w,\n stride_out_batch,\n stride_out_m,\n stride_out_k,\n eps=eps,\n N_SIZE=K,\n )\n return out\n", - "description_1": "Use triton language to implement a RMS Norm kernel. The kernel function 'rmsnorm_triton' takes 13 parameters: three pointers (x_ptr, rms_w_ptr, output_ptr) for input, weights, and output data respectively; six strides (stride_x_batch, stride_x_m, stride_x_k, stride_rms_w, stride_out_batch, stride_out_m, stride_out_k) for accessing elements in the input and output tensors; and three compile-time constants (N_SIZE, eps, BLOCK_N_SIZE) for the size of the data, epsilon for numerical stability, and block size for processing. The wrapper function 'rmsnorm_triton_wrapper' prepares the input data and calls the kernel with appropriate launch grid configuration.", - "description_2": "Use triton language to create a kernel for RMS Norm that processes input data with given weights and outputs normalized results. The kernel is configured with specific strides and block sizes, and is called through a wrapper function that handles input preparation and kernel launch.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef _kernel(A, B, C, stride_za, stride_ha, stride_ma, stride_ka, stride_zb, stride_hb, stride_kb, stride_nb, stride_zc,\n stride_hc, stride_mc, stride_nc, DS0, DS1, SDD_K, SDD_off_width, lut, locks, nlocks, **meta):\n TM = meta['TM']\n TN = meta['TN']\n TK = meta['TK']\n TZ = meta['TZ']\n BLOCK = meta['BLOCK']\n #------------#\n #- Prologue -#\n #------------#\n pid0 = tl.program_id(0)\n pid1 = tl.program_id(1)\n pidz = tl.program_id(2)\n if meta['SDD']:\n pid1 = pid1 + SDD_off_width\n blockidm = tl.arange(0, TM) // BLOCK\n blockidn = tl.arange(0, TN) // BLOCK\n offlutm = blockidm * (TN // BLOCK) * 4\n offlutn = blockidn * 4\n header = lut + pid1 * (TM // BLOCK) * (TN // BLOCK) * 4\n z = tl.load(header + 0)\n i = tl.load(header + 1 + offlutm)\n j = tl.load(header + 2 + offlutn)\n AS1 = SDD_K // TZ\n lockid = tl.where(TZ > 1, 1, 0)\n offka = pid0 * AS1\n offkb = pid0 * AS1\n offmc = 0\n offnc = 0\n offpa = 0\n offpb = 0\n maxid = TZ\n offhc = 0\n offha = z\n offhb = z\n ram = i * BLOCK + (tl.arange(0, TM) % BLOCK)\n rbn = j * BLOCK + (tl.arange(0, TN) % BLOCK)\n else:\n header = lut + pid0 * 6\n offset = tl.load(header + 0)\n AS1 = tl.load(header + 1)\n column = tl.load(header + 2)\n depth = tl.load(header + 3)\n lockid = tl.load(header + 4)\n maxid = tl.load(header + 5)\n pinc = lut + offset\n offhc = depth\n if meta['DSD']:\n # output offset\n offnc = pid1 * TN\n offmc = column * TM\n offpc = 0\n # dense input offset\n offnb = pid1 * TN\n offkb = tl.load(pinc)\n offkb = tl.multiple_of(offkb, 8) # compiler hint\n offpb = 0\n # sparse input offset\n offma = 0\n offka = 0\n offpa = tl.load(pinc + 1)\n offpa = tl.multiple_of(offpa, 8) # compiler hint\n offpa = offpa * BLOCK * BLOCK\n offha = 0\n offhb = depth\n else:\n # output offset\n offmc = pid1 * TM\n offnc = column * TN\n offpc = 0\n # dense input offset\n offma = pid1 * TM\n offka = tl.load(pinc)\n offka = tl.multiple_of(offka, 8) # compiler hint\n offpa = 0\n # sparse input offset\n offnb = 0\n offkb = 0\n offpb = tl.load(pinc + 1)\n offpb = tl.multiple_of(offpb, 8) # compiler hint\n offpb = offpb * BLOCK * BLOCK\n offha = depth\n offhb = 0\n ram = offma + tl.arange(0, TM)\n rbn = offnb + tl.arange(0, TN)\n\n # initialize a, b pointers\n rka = offka + tl.arange(0, TK)\n rkb = offkb + tl.arange(0, TK)\n pa = A + pidz * stride_za + offha * stride_ha + offpa + ram[:, None] * stride_ma + rka[None, :] * stride_ka\n pb = B + pidz * stride_zb + offhb * stride_hb + offpb + rbn[None, :] * stride_nb + rkb[:, None] * stride_kb\n if meta['DDS']:\n checkam = ram[:, None] < DS0\n else:\n checkam = AS1 > 0\n if meta['DSD']:\n checkbn = rbn[None, :] < DS0\n else:\n checkbn = AS1 > 0\n a = tl.load(pa, mask=checkam, other=0.)\n b = tl.load(pb, mask=checkbn, other=0.)\n\n ## ---------------- ##\n ## Inner Loop ##\n ## ---------------- ##\n acc = tl.zeros((TM, TN), dtype=tl.float32)\n for k in range(AS1, 0, -TK):\n acc += tl.dot(a, b)\n if meta['SDD']:\n inc_a = TK * stride_ka\n inc_b = TK * stride_kb\n else:\n pinc += 2\n if meta['DSD']:\n inc_b = tl.load(pinc)\n inc_a = tl.load(pinc + 1)\n inc_b = tl.multiple_of(inc_b, 8)\n inc_a = tl.multiple_of(inc_a, 8)\n inc_b = inc_b * stride_kb\n if meta['DDS']:\n inc_a = tl.load(pinc)\n inc_b = tl.load(pinc + 1)\n inc_a = tl.multiple_of(inc_a, 8)\n inc_b = tl.multiple_of(inc_b, 8)\n inc_a = inc_a * stride_ka\n pa += inc_a\n pb += inc_b\n # pre-fetch\n checkak = k > TK\n checkbk = k > TK\n checka = checkam & checkak\n checkb = checkbn & checkbk\n a = tl.load(pa, mask=checka)\n b = tl.load(pb, mask=checkb)\n c = acc.to(C.dtype.element_ty)\n\n if meta['SDD']:\n checkc = True\n rr_blockidm = tl.arange(0, TM) // BLOCK\n rr_blockidn = tl.arange(0, TN) // BLOCK\n rr_offlutm = rr_blockidm * (TN // BLOCK) * 4\n rr_offlutn = rr_blockidn * 4\n off_bkid = 3 + rr_offlutm[:, None] + rr_offlutn[None, :]\n bkid = tl.load(header + off_bkid)\n offpc = bkid * BLOCK * BLOCK\n rcm = tl.arange(0, TM) % BLOCK\n rcn = tl.arange(0, TN) % BLOCK\n else:\n rcm = offmc + tl.arange(0, TM)\n rcn = offnc + tl.arange(0, TN)\n if meta['DSD']:\n checkc = rcn[None, :] < DS0\n if meta['DDS']:\n checkc = rcm[:, None] < DS0\n\n pc = C + offpc + offhc * stride_hc + pidz * stride_zc + rcm[:, None] * stride_mc + rcn[None, :] * stride_nc\n # write-back directly\n if lockid == 0:\n tl.store(pc, c, mask=checkc)\n # accumulate partial results using spin-locks\n else:\n plock = locks + tl.program_id(2) * nlocks * tl.num_programs(1) + tl.program_id(1) * nlocks + lockid - 1\n pcount = plock + tl.num_programs(2) * tl.num_programs(1) * nlocks\n while tl.atomic_cas(plock, 0, 1) == 1:\n pass\n count = tl.load(pcount)\n if count == 0:\n tl.store(pc, c, mask=checkc)\n else:\n d = tl.load(pc, mask=checkc)\n tl.store(pc, d + c, mask=checkc)\n tl.atomic_xchg(pcount, (count + 1) % maxid)\n tl.atomic_xchg(plock, 0)\n\ndef _sdd_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, luts, num_locks, widths, packs, bench, time):\n if trans_c:\n a, b = b, a\n trans_a, trans_b = not trans_b, not trans_a\n AS0 = a.size(0)\n # Shape check\n a_dim = -2 if trans_a else -1\n b_dim = -1 if trans_b else -2\n a_inner, b_inner = a.shape[a_dim], b.shape[b_dim]\n if a_inner != b_inner:\n raise ValueError(f\"Size of tensor A along the {a_dim} dim ({a_inner}) must match size \"\n f\"of tensor B along the {b_dim} dim ({b_inner})\")\n if a_inner % 16 != 0:\n raise ValueError('Reduction size for SDD must be a multiple of 16')\n\n batch_size = a.size(0)\n a_outer = a.size(3 if trans_a else 2)\n dtype = a.dtype\n is_16_multiple = a_inner % 16 == 0\n is_32_multiple = a_inner % 32 == 0\n is_64_multiple = a_inner % 64 == 0\n if not is_16_multiple:\n raise ValueError('Reduction size for SDD must be a multiple of 16')\n device = a.device\n # create kernel\n total_width = sum([width * pack * pack for width, pack in zip(widths, packs)])\n c = torch.empty((batch_size, total_width, block, block), dtype=dtype, device=a.device)\n for lut, width, pack in zip(luts, widths, packs):\n F32TK = [8, 16]\n F16TK = [16]\n F16TK += [32] if is_32_multiple else []\n F16TK += [64] if is_64_multiple else []\n TK = {torch.float32: F32TK, torch.float16: F16TK}[dtype]\n num_lock = 1\n meta = {\n 'TM': block * pack,\n 'TN': block * pack,\n 'BLOCK': block,\n 'TK': TK[0],\n 'TZ': 1,\n 'SDD': True,\n 'DSD': False,\n 'DDS': False\n }\n # create output\n locks = _sparse_matmul.get_locks(2 * width * AS0 * num_lock, a.device)\n # maximum grid size is 65535\n # so operation might be decomposed into multiple\n # kernel calls\n max_width = 49152\n total = 0 if bench else None\n for off_width in range(0, width, max_width):\n grid = lambda meta: [meta['TZ'], min(max_width, width - off_width), batch_size]\n _kernel[grid](a,\n b,\n c,\n a.stride(0),\n a.stride(1),\n a.stride(3 if trans_a else 2),\n a.stride(2 if trans_a else 3),\n b.stride(0),\n b.stride(1),\n b.stride(3 if trans_b else 2),\n b.stride(2 if trans_b else 3),\n c.stride(0),\n c.stride(0),\n c.stride(2),\n c.stride(3),\n a_outer,\n a_outer,\n a_inner,\n off_width,\n lut,\n locks,\n num_lock,\n num_warps=4,\n **meta)\n # save for backward pass\n return c\n", - "description_1": "Use triton language to implement a sparse-dense-dense (SDD) matrix multiplication kernel. The kernel function '_kernel' takes 22 parameters including input matrices A, B, C, and various strides and metadata for matrix dimensions and block sizes. The '_sdd_matmul' function calls this kernel, handling input validation and setting up the grid for execution.", - "description_2": "Use triton language to create a kernel for SDD matrix multiplication with input validation and grid setup.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef _forward(X, scale, LUT, RPE, KP_M, ATTN_M, sizemax, stride_zx, stride_zrpe, stride_hrpe, stride_srpe, stride_zkpm,\n stride_zattnm, **meta):\n TN = meta['TN']\n BLOCK = meta['BLOCK']\n pidhm = tl.program_id(0)\n pidz = tl.program_id(1)\n # create index ranges\n rxm = pidhm % BLOCK\n rbm = pidhm // BLOCK\n rxn = tl.arange(0, TN) % BLOCK\n rbn = tl.arange(0, TN) // BLOCK\n # extract information from LUT\n header = LUT + rbm * 2\n size = tl.load(header + 0)\n offset = tl.load(header + 1)\n check = rbn < size\n rbmn = tl.where(check, rbn, size - 1)\n # block id and column id\n blockid = tl.load(LUT + offset + rbmn * 4 + 0)\n columnid = tl.load(LUT + offset + rbmn * 4 + 1)\n rowid = tl.load(LUT + offset + rbmn * 4 + 2)\n headid = tl.load(LUT + offset + rbmn * 4 + 3)\n # pointers to X\n px = X + pidz * stride_zx + blockid * BLOCK * BLOCK + rxm * BLOCK + rxn\n x = tl.load(px, mask=check, other=-float('inf'))\n x = x.to(tl.float32)\n # apply scale\n if meta['APPLY_SCALE']:\n x = x * scale\n # apply RPE\n if meta['APPLY_RPE']:\n prpe = RPE + pidz * stride_zrpe + headid * stride_hrpe + columnid * BLOCK + rowid * BLOCK * stride_srpe + rxm * stride_srpe + rxn\n rpe = tl.load(prpe, mask=check, other=0)\n x = x + rpe\n # apply key-padding mask\n if meta['APPLY_KP_MASK']:\n pkp_m = KP_M + pidz * stride_zkpm + columnid * BLOCK + rxn\n kp_m = tl.load(pkp_m, mask=check, other=-float('inf'))\n if meta['KP_MASK_MUL']:\n kp_m = tl.where(kp_m == 0, -float('inf'), 0.)\n x = x + kp_m\n # apply attention mask\n if meta['APPLY_ATTN_MASK']:\n pattn_m = ATTN_M + columnid * BLOCK + rowid * BLOCK * stride_zattnm + rxm * stride_zattnm + rxn\n attn_m = tl.load(pattn_m, mask=check, other=-float('inf'))\n if meta['ATTN_MASK_MUL']:\n attn_m = tl.where(attn_m == 0, -float('inf'), 0.)\n x = x + attn_m\n # computation\n x = tl.softmax(x)\n tl.store(px, x, mask=check)\n\n@triton.jit\ndef _backward(X, scale, DX, LUT, sizemax, stride_zx, stride_zdx, **meta):\n pidhm = tl.program_id(0)\n pidz = tl.program_id(1)\n TN = meta['TN']\n BLOCK = meta['BLOCK']\n # create index ranges\n rxm = pidhm % BLOCK\n rbm = pidhm // BLOCK\n rxn = tl.arange(0, TN) % BLOCK\n rbn = tl.arange(0, TN) // BLOCK\n # extract information from look-up table\n header = LUT + rbm * 2\n size = tl.load(header + 0)\n offset = tl.load(header + 1)\n # bounds checking on lut\n check = rbn < size\n rbmn = tl.where(check, rbn, size - 1)\n # initialize pointers to block-sparse input\n blockid = tl.load(LUT + offset + rbmn * 4)\n X = X + pidz * stride_zx + blockid * BLOCK * BLOCK + rxm * BLOCK + rxn\n DX = DX + pidz * stride_zdx + blockid * BLOCK * BLOCK + rxm * BLOCK + rxn\n # compute fused softmax backward\n x = tl.load(X, mask=check, other=0)\n dx = tl.load(DX, mask=check, other=0)\n x = x.to(tl.float32)\n dx = dx.to(tl.float32)\n y = x * (dx - tl.sum(x * dx, 0)) * scale\n tl.store(DX, y, mask=check)\n\nclass _sparse_softmax(torch.autograd.Function):\n\n @staticmethod\n def forward(ctx, x, scale, rpe, key_padding_mask, attn_mask, kp_mask_mode, attn_mask_mode, spdims, block, lut,\n num_blocks, maxlut, bench, time):\n\n apply_scale = False if scale == 1.0 else True\n\n # handle None rpe\n if rpe is None:\n apply_rpe = False\n stride_zrpe, stride_hrpe, stride_srpe = 0, 0, 0\n rpe = torch.empty(0, dtype=x.dtype, device=x.device)\n else:\n apply_rpe = True\n stride_zrpe, stride_hrpe, stride_srpe = rpe.stride(0), rpe.stride(1), rpe.stride(2)\n\n # handle None key_padding_mask\n if key_padding_mask is None:\n apply_kp_mask = False\n stride_zkpm = 0\n key_padding_mask = torch.empty(0, dtype=x.dtype, device=x.device)\n else:\n apply_kp_mask = True\n stride_zkpm = key_padding_mask.stride(0)\n\n # handle None attention_mask\n if attn_mask is None:\n apply_attn_mask = False\n stride_zattnm = 0\n attn_mask = torch.empty(0, dtype=x.dtype, device=x.device)\n else:\n apply_attn_mask = True\n stride_zattnm = attn_mask.stride(0)\n\n # run kernel\n M = x.shape[0]\n meta = {\n 'BLOCK': block,\n 'APPLY_SCALE': apply_scale,\n 'APPLY_RPE': apply_rpe,\n 'APPLY_KP_MASK': apply_kp_mask,\n 'APPLY_ATTN_MASK': apply_attn_mask,\n 'KP_MASK_MUL': kp_mask_mode == 'mul',\n 'ATTN_MASK_MUL': attn_mask_mode == 'mul',\n }\n grid = lambda opt: [spdims[0] * spdims[1] * block, M]\n _forward[grid](x, scale, lut, rpe, key_padding_mask, attn_mask, maxlut, x.stride(0),\\\n stride_zrpe, stride_hrpe, stride_srpe, stride_zkpm, stride_zattnm, **meta)\n\n # save to context\n ctx.mark_dirty(x)\n ctx.save_for_backward(x, lut)\n ctx.spdims = spdims\n ctx.block = block\n ctx.maxlut = maxlut\n ctx.scale = scale\n ctx.apply_scale = apply_scale\n ctx.apply_rpe = apply_rpe\n ctx.apply_kp_mask = apply_kp_mask\n ctx.apply_attn_mask = apply_attn_mask\n ctx.kp_mask_mode = kp_mask_mode\n ctx.attn_mask_mode = attn_mask_mode\n return x\n\n @staticmethod\n def backward(ctx, dx):\n\n # retrieve from context\n x, lut = ctx.saved_tensors\n # run kernel\n M = x.shape[0]\n grid = lambda opt: [ctx.spdims[0] * ctx.spdims[1] * ctx.block, M]\n _backward[grid](x, ctx.scale, dx, lut, ctx.maxlut, x.stride(0), dx.stride(0), BLOCK=ctx.block)\n return dx, None, None, None, None, None, None, None, None, None, None, None, None, None, None\n", - "description_1": "Use triton language to implement block-sparse softmax and its backward pass. The _forward kernel takes 13 parameters: X (input tensor), scale (scaling factor), LUT (look-up table), RPE (relative position embedding), KP_M (key padding mask), ATTN_M (attention mask), sizemax (maximum size), stride_zx (stride for X), stride_zrpe (stride for RPE), stride_hrpe (stride for RPE head), stride_srpe (stride for RPE sequence), stride_zkpm (stride for key padding mask), and stride_zattnm (stride for attention mask). The _backward kernel takes 7 parameters: X (input tensor), scale (scaling factor), DX (gradient tensor), LUT (look-up table), sizemax (maximum size), stride_zx (stride for X), and stride_zdx (stride for DX).", - "description_2": "Use triton language to create a block-sparse softmax function with forward and backward kernels, handling optional scaling, relative position embedding, key padding mask, and attention mask.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\nminus_inf = -10000.0\n\n@triton.jit\ndef _flash_packed_kernel(\n QKV,\n mask,\n ADD_MASK: tl.constexpr,\n IS_CAUSAL: tl.constexpr,\n sm_scale,\n Out,\n stride_qz,\n stride_qn,\n stride_qm,\n stride_mz,\n stride_oz,\n stride_on,\n Z,\n H,\n N_CTX,\n P_SEQ,\n hidden_size,\n BLOCK_M: tl.constexpr,\n BLOCK_DMODEL: tl.constexpr,\n BLOCK_N: tl.constexpr,\n):\n start_m = tl.program_id(0)\n off_hz = tl.program_id(1)\n batch = off_hz // H\n head = off_hz % H\n\n q_offset = batch * stride_qz + head * BLOCK_DMODEL\n k_offset = q_offset + hidden_size\n v_offset = k_offset + hidden_size\n\n # initialize offsets\n offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)\n offs_n = tl.arange(0, BLOCK_N)\n offs_d = tl.arange(0, BLOCK_DMODEL)\n\n q_ptrs = QKV + q_offset + offs_m[:, None] * stride_qn + offs_d[None, :]\n k_ptrs = QKV + hidden_size + q_offset + offs_n[:, None] * stride_qn + offs_d[None, :]\n v_ptrs = QKV + 2 * hidden_size + q_offset + offs_n[:, None] * stride_qn + offs_d[None, :]\n\n # mask\n off_mask = batch * stride_mz + offs_n[None, :]\n mask_ptrs = mask + off_mask\n\n # initialize pointer to m and l\n m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float(\"inf\")\n l_i = tl.zeros([BLOCK_M], dtype=tl.float32)\n acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)\n # scale sm_scale by log_2(e) and use\n # 2^x instead of exp in the loop because CSE and LICM\n # don't work as expected with `exp` in the loop\n qk_scale = sm_scale * 1.44269504\n # load q: it will stay in SRAM throughout\n q = tl.load(q_ptrs, mask=offs_m[:, None] < N_CTX, other=0.0)\n q = (q * qk_scale).to(tl.float16)\n # loop over k, v and update accumulator\n lo = 0\n hi = P_SEQ + (start_m + 1) * BLOCK_M if IS_CAUSAL else N_CTX + P_SEQ\n for start_n in range(lo, hi, BLOCK_N):\n # -- load k, v --\n k = tl.load(k_ptrs + start_n * stride_qn, mask=(start_n + offs_n)[:, None] < N_CTX, other=0.0)\n v = tl.load(v_ptrs + start_n * stride_qn, mask=(start_n + offs_n)[:, None] < N_CTX, other=0.0)\n # -- compute qk ---\n qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float16)\n\n if ADD_MASK:\n mask_val = tl.load(mask_ptrs)\n mask_ptrs += BLOCK_N\n qk = qk + mask_val.to(tl.float32)\n\n if IS_CAUSAL:\n qk = tl.where(P_SEQ + offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float(\"-inf\"))\n\n qk += tl.dot(q, tl.trans(k), out_dtype=tl.float16)\n qk += tl.where((start_n + offs_n)[None, :] < N_CTX, 0, minus_inf)\n # -- compute scaling constant ---\n m_i_new = tl.maximum(m_i, tl.max(qk, 1))\n alpha = tl.math.exp2(m_i - m_i_new)\n p = tl.math.exp2(qk - m_i_new[:, None])\n # -- scale and update acc --\n acc_scale = l_i * 0 + alpha # workaround some compiler bug\n acc *= acc_scale[:, None]\n acc += tl.dot(p.to(tl.float16), v.to(tl.float16))\n # -- update m_i and l_i --\n l_i = l_i * alpha + tl.sum(p, 1)\n m_i = m_i_new\n\n # write back l and m\n acc = acc / l_i[:, None]\n o_offset = batch * stride_oz + head * BLOCK_DMODEL\n out_ptrs = Out + o_offset + (offs_m[:, None] * stride_on + offs_d[None, :])\n tl.store(out_ptrs, acc.to(tl.float16), mask=offs_m[:, None] < N_CTX)\n\n\ndef _triton_packed_flash(qkv, head_size, mask, sm_scale, causal=False, add_mask=True):\n heads = qkv.shape[-1] // 3 // head_size\n hidden_size = qkv.shape[-1] // 3\n\n BLOCK_M = 128\n BLOCK_N = 64 if head_size <= 64 else 32\n\n o = torch.empty((qkv.shape[0], qkv.shape[1], hidden_size), device=qkv.device, dtype=torch.half)\n if mask is None:\n mask = torch.empty(0)\n add_mask = False\n\n grid = (triton.cdiv(qkv.shape[1], BLOCK_M), qkv.shape[0] * heads, 1)\n num_stages = 4 if head_size <= 64 else 3\n num_warps = 4\n P_SEQ = 0\n\n _flash_packed_kernel[grid](qkv,\n mask,\n add_mask,\n causal,\n sm_scale,\n o,\n qkv.stride(0),\n qkv.stride(1),\n qkv.stride(2),\n mask.stride(1) if add_mask else 0,\n o.stride(0),\n o.stride(1),\n qkv.shape[0],\n heads,\n qkv.shape[1],\n P_SEQ,\n hidden_size,\n BLOCK_M=BLOCK_M,\n BLOCK_N=BLOCK_N,\n BLOCK_DMODEL=head_size,\n num_warps=num_warps,\n num_stages=num_stages)\n\n return o\n", - "description_1": "Use triton language to implement a flash attention kernel. The kernel function '_flash_packed_kernel' takes 18 parameters: QKV (query, key, value tensor), mask (attention mask), ADD_MASK (whether to add mask), IS_CAUSAL (whether the attention is causal), sm_scale (softmax scale), Out (output tensor), stride_qz, stride_qn, stride_qm (strides for QKV tensor), stride_mz (stride for mask), stride_oz, stride_on (strides for output tensor), Z (batch size), H (number of heads), N_CTX (context size), P_SEQ (sequence length), hidden_size (hidden size), BLOCK_M, BLOCK_DMODEL, BLOCK_N (block sizes for matrix multiplication). The function '_triton_packed_flash' is a wrapper that sets up the grid and calls the kernel with appropriate parameters.", - "description_2": "Use triton language to create a flash attention mechanism with a kernel function that processes QKV tensors and applies optional masking and causal attention. The kernel is executed with a grid configuration based on the input tensor dimensions and block sizes.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\nfrom deepspeed.accelerator import get_accelerator\n\n@triton.jit\ndef gelu_functor(x):\n # Using approximation introduces greater parity errors.\n # return tl.sigmoid(1.702 * x) * x\n return x * 0.5 * (1.0 + tl.math.erf(x / 1.41421356237))\n\n@triton.jit\ndef gelu_kernel(x_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr):\n pid = tl.program_id(axis=0)\n block_start = pid * BLOCK_SIZE\n offsets = block_start + tl.arange(0, BLOCK_SIZE)\n mask = offsets < n_elements\n x = tl.load(x_ptr + offsets, mask=mask)\n output = gelu_functor(x)\n tl.store(output_ptr + offsets, output, mask=mask)\n\ndef gelu(activations: torch.Tensor) -> torch.Tensor:\n assert activations.is_contiguous()\n assert get_accelerator().on_accelerator(activations)\n\n output = torch.empty_like(activations)\n n_elements = output.numel()\n grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']), )\n gelu_kernel[grid](activations, output, n_elements, BLOCK_SIZE=1024)\n return output\n", - "description_1": "Use triton language to implement a GELU activation function. The `gelu_functor` kernel takes one parameter `x` (a tensor element) and applies the GELU function using the error function approximation. The `gelu_kernel` takes four parameters: `x_ptr` (pointer to input tensor), `output_ptr` (pointer to output tensor), `n_elements` (number of elements in the tensor), and `BLOCK_SIZE` (block size for parallel execution). It computes the GELU activation for each element in the input tensor and stores the result in the output tensor. The `gelu` function is a wrapper that prepares the input tensor, sets up the grid for kernel execution, and calls the `gelu_kernel`.", - "description_2": "Use triton language to implement a GELU activation function with a functor and a kernel, and provide a wrapper function to execute the kernel on a tensor.", - "difficulty": 2 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef layer_norm_kernel(\n Out,\n A,\n Weight,\n Bias,\n stride,\n N,\n eps,\n BLOCK_SIZE: tl.constexpr,\n):\n row = tl.program_id(0)\n Out += row * stride\n A += row * stride\n mean = 0\n _mean = tl.zeros([BLOCK_SIZE], dtype=tl.float32)\n for off in range(0, N, BLOCK_SIZE):\n cols = off + tl.arange(0, BLOCK_SIZE)\n a = tl.load(A + cols, mask=cols < N, other=0.0).to(tl.float32)\n _mean += a\n mean = tl.sum(_mean, axis=0) / N\n _var = tl.zeros([BLOCK_SIZE], dtype=tl.float32)\n for off in range(0, N, BLOCK_SIZE):\n cols = off + tl.arange(0, BLOCK_SIZE)\n a = tl.load(A + cols, mask=cols < N, other=0.0).to(tl.float32)\n a = tl.where(cols < N, a - mean, 0.0)\n _var += a * a\n var = tl.sum(_var, axis=0) / N\n rstd = 1 / tl.sqrt(var + eps)\n for off in range(0, N, BLOCK_SIZE):\n cols = off + tl.arange(0, BLOCK_SIZE)\n mask = cols < N\n weight = tl.load(Weight + cols, mask=mask)\n bias = tl.load(Bias + cols, mask=mask)\n a = tl.load(A + cols, mask=mask, other=0.0).to(tl.float32)\n a_hat = (a - mean) * rstd\n out = a_hat * weight + bias\n tl.store(Out + cols, out, mask=mask)\n\n@triton.jit\ndef layer_norm_residual_kernel(\n Out,\n A,\n Residual,\n ln_input,\n Weight,\n Bias,\n stride,\n N,\n eps,\n BLOCK_SIZE: tl.constexpr,\n):\n row = tl.program_id(0)\n Out += row * stride\n A += row * stride\n Residual += row * stride\n ln_input += row * stride\n mean = 0\n _mean = tl.zeros([BLOCK_SIZE], dtype=tl.float32)\n for off in range(0, N, BLOCK_SIZE):\n cols = off + tl.arange(0, BLOCK_SIZE)\n a = tl.load(A + cols, mask=cols < N, other=0.0).to(tl.float32)\n res = tl.load(Residual + cols, mask=cols < N, other=0.0).to(tl.float32)\n a = a + res\n tl.store(ln_input + cols, a, mask=cols < N)\n _mean += a\n mean = tl.sum(_mean, axis=0) / N\n _var = tl.zeros([BLOCK_SIZE], dtype=tl.float32)\n for off in range(0, N, BLOCK_SIZE):\n cols = off + tl.arange(0, BLOCK_SIZE)\n a = tl.load(ln_input + cols, mask=cols < N, other=0.0).to(tl.float32)\n a = tl.where(cols < N, a - mean, 0.0)\n _var += a * a\n var = tl.sum(_var, axis=0) / N\n rstd = 1 / tl.sqrt(var + eps)\n for off in range(0, N, BLOCK_SIZE):\n cols = off + tl.arange(0, BLOCK_SIZE)\n mask = cols < N\n weight = tl.load(Weight + cols, mask=mask)\n bias = tl.load(Bias + cols, mask=mask)\n a = tl.load(ln_input + cols, mask=mask, other=0.0).to(tl.float32)\n a_hat = (a - mean) * rstd\n out = a_hat * weight + bias\n tl.store(Out + cols, out, mask=mask)\n\n@triton.jit\ndef layer_norm_residual_bias_kernel(\n Out,\n A,\n Residual,\n InputBias,\n ln_input,\n Weight,\n Bias,\n stride,\n N,\n eps,\n BLOCK_SIZE: tl.constexpr,\n):\n row = tl.program_id(0)\n Out += row * stride\n A += row * stride\n Residual += row * stride\n ln_input += row * stride\n mean = 0\n _mean = tl.zeros([BLOCK_SIZE], dtype=tl.float32)\n for off in range(0, N, BLOCK_SIZE):\n cols = off + tl.arange(0, BLOCK_SIZE)\n a = tl.load(A + cols, mask=cols < N, other=0.0).to(tl.float32)\n res = tl.load(Residual + cols, mask=cols < N, other=0.0).to(tl.float32)\n b = tl.load(InputBias + cols, mask=cols < N, other=0.0).to(tl.float32)\n a = a + b + res\n tl.store(ln_input + cols, a, mask=cols < N)\n _mean += a\n mean = tl.sum(_mean, axis=0) / N\n _var = tl.zeros([BLOCK_SIZE], dtype=tl.float32)\n for off in range(0, N, BLOCK_SIZE):\n cols = off + tl.arange(0, BLOCK_SIZE)\n a = tl.load(ln_input + cols, mask=cols < N, other=0.0).to(tl.float32)\n a = tl.where(cols < N, a - mean, 0.0)\n _var += a * a\n var = tl.sum(_var, axis=0) / N\n rstd = 1 / tl.sqrt(var + eps)\n for off in range(0, N, BLOCK_SIZE):\n cols = off + tl.arange(0, BLOCK_SIZE)\n mask = cols < N\n weight = tl.load(Weight + cols, mask=mask)\n bias = tl.load(Bias + cols, mask=mask)\n a = tl.load(ln_input + cols, mask=mask, other=0.0).to(tl.float32)\n a_hat = (a - mean) * rstd\n out = a_hat * weight + bias\n tl.store(Out + cols, out, mask=mask)\n\ndef layer_norm(a, weight, bias, eps):\n assert a.is_contiguous()\n assert weight.is_contiguous()\n assert bias.is_contiguous()\n\n out = torch.empty_like(a)\n a_arg = a.view(-1, a.shape[-1])\n M, N = a_arg.shape\n MAX_FUSED_SIZE = 65536 // a.element_size()\n BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))\n BLOCK_SIZE = max(BLOCK_SIZE, 128)\n BLOCK_SIZE = min(BLOCK_SIZE, 4096)\n BLOCK_SIZE = BLOCK_SIZE if N <= 4096 else 8192\n num_warps = min(max(BLOCK_SIZE // 256, 1), 8)\n layer_norm_kernel[(M, )](\n out,\n a_arg,\n weight,\n bias,\n a_arg.stride(0),\n N,\n eps,\n BLOCK_SIZE=BLOCK_SIZE,\n num_warps=num_warps,\n )\n return out\n\ndef layer_norm_residual(a, input_bias, residual, weight, bias, eps):\n assert a.is_contiguous()\n assert weight.is_contiguous()\n assert bias.is_contiguous()\n assert residual.is_contiguous()\n\n out = torch.empty_like(a)\n ln_input = torch.empty_like(a)\n a_arg = a.view(-1, a.shape[-1])\n residual = residual.view(-1, residual.shape[-1])\n M, N = a_arg.shape\n MAX_FUSED_SIZE = 65536 // a.element_size()\n BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))\n BLOCK_SIZE = max(BLOCK_SIZE, 128)\n BLOCK_SIZE = min(BLOCK_SIZE, 4096)\n BLOCK_SIZE = BLOCK_SIZE if N <= 4096 else 8192\n num_warps = min(max(BLOCK_SIZE // 256, 1), 8)\n if input_bias is None:\n layer_norm_residual_kernel[(M, )](\n out,\n a_arg,\n residual,\n ln_input,\n weight,\n bias,\n a_arg.stride(0),\n N,\n eps,\n BLOCK_SIZE=BLOCK_SIZE,\n num_warps=num_warps,\n )\n else:\n layer_norm_residual_bias_kernel[(M, )](\n out,\n a_arg,\n residual,\n input_bias,\n ln_input,\n weight,\n bias,\n a_arg.stride(0),\n N,\n eps,\n BLOCK_SIZE=BLOCK_SIZE,\n num_warps=num_warps,\n )\n return out\n", - "description_1": "Use triton language to implement a layer normalization kernel with 8 parameters: Out, A, Weight, Bias, stride, N, eps, BLOCK_SIZE, where 'Out' is the output tensor, 'A' is the input tensor, 'Weight' and 'Bias' are parameters for scaling and shifting, 'stride' is the stride of rows in memory, 'N' is the number of columns to normalize, 'eps' is a small constant to prevent division by zero, and 'BLOCK_SIZE' is a triton constant for block size. Implement additional kernels for layer normalization with residual connections, with and without input bias, using similar parameters and logic.", - "description_2": "Use triton language to create layer normalization and layer normalization with residual kernels, which compute the mean and variance of the input tensor, apply normalization, and optionally add residuals and biases, using parameters for input/output tensors, weights, biases, strides, and block sizes.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef _fwd_kernel(\n Q,\n K,\n V,\n sm_scale,\n Out,\n stride_qz,\n stride_qh,\n stride_qm,\n stride_qk,\n stride_kz,\n stride_kh,\n stride_kn,\n stride_kk,\n stride_vz,\n stride_vh,\n stride_vk,\n stride_vn,\n stride_oz,\n stride_oh,\n stride_om,\n stride_on,\n Z,\n H,\n N_CTX,\n BLOCK_M: tl.constexpr,\n BLOCK_DMODEL: tl.constexpr,\n BLOCK_N: tl.constexpr,\n):\n start_m = tl.program_id(0)\n off_hz = tl.program_id(1)\n qvk_offset = off_hz * stride_qh\n Q_block_ptr = tl.make_block_ptr(base=Q + qvk_offset,\n shape=(N_CTX, BLOCK_DMODEL),\n strides=(stride_qm, stride_qk),\n offsets=(start_m * BLOCK_M, 0),\n block_shape=(BLOCK_M, BLOCK_DMODEL),\n order=(1, 0))\n K_block_ptr = tl.make_block_ptr(base=K + qvk_offset,\n shape=(BLOCK_DMODEL, N_CTX),\n strides=(stride_kk, stride_kn),\n offsets=(0, 0),\n block_shape=(BLOCK_DMODEL, BLOCK_N),\n order=(0, 1))\n V_block_ptr = tl.make_block_ptr(base=V + qvk_offset,\n shape=(N_CTX, BLOCK_DMODEL),\n strides=(stride_vk, stride_vn),\n offsets=(0, 0),\n block_shape=(BLOCK_N, BLOCK_DMODEL),\n order=(1, 0))\n offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)\n offs_n = tl.arange(0, BLOCK_N)\n m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float(\"inf\")\n l_i = tl.zeros([BLOCK_M], dtype=tl.float32)\n acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)\n qk_scale = sm_scale * 1.44269504\n q = tl.load(Q_block_ptr)\n q = (q * qk_scale).to(tl.float16)\n lo = 0\n hi = N_CTX\n for start_n in range(lo, hi, BLOCK_N):\n k = tl.load(K_block_ptr)\n v = tl.load(V_block_ptr)\n qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)\n qk += tl.dot(q, k)\n m_i_new = tl.maximum(m_i, tl.max(qk, 1))\n alpha = tl.math.exp2(m_i - m_i_new)\n p = tl.math.exp2(qk - m_i_new[:, None])\n acc_scale = l_i * 0 + alpha\n acc *= acc_scale[:, None]\n acc += tl.dot(p.to(tl.float16), v)\n l_i = l_i * alpha + tl.sum(p, 1)\n m_i = m_i_new\n K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N))\n V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0))\n acc = acc / l_i[:, None]\n O_block_ptr = tl.make_block_ptr(base=Out + qvk_offset,\n shape=(N_CTX, BLOCK_DMODEL),\n strides=(stride_om, stride_on),\n offsets=(start_m * BLOCK_M, 0),\n block_shape=(BLOCK_M, BLOCK_DMODEL),\n order=(1, 0))\n tl.store(O_block_ptr, acc.to(tl.float16))\n\n\nclass triton_flash_attn(torch.nn.Module):\n\n def __init__(self, ):\n super(triton_flash_attn, self).__init__()\n\n def forward(self, q, k, v, sm_scale, block_128=True):\n BLOCK = 128 if block_128 else 64\n Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]\n o = torch.empty_like(q)\n grid = (triton.cdiv(q.shape[2], BLOCK), q.shape[0] * q.shape[1])\n num_warps = 4 if Lk <= 64 else 8\n\n _fwd_kernel[grid](\n q,\n k,\n v,\n sm_scale,\n o,\n q.stride(0),\n q.stride(1),\n q.stride(2),\n q.stride(3),\n k.stride(0),\n k.stride(1),\n k.stride(2),\n k.stride(3),\n v.stride(0),\n v.stride(1),\n v.stride(2),\n v.stride(3),\n o.stride(0),\n o.stride(1),\n o.stride(2),\n o.stride(3),\n k.shape[0],\n k.shape[1],\n k.shape[2],\n BLOCK_M=BLOCK,\n BLOCK_N=BLOCK,\n BLOCK_DMODEL=Lk,\n num_warps=num_warps,\n num_stages=1,\n )\n return o\n", - "description_1": "Use triton language to implement a forward kernel for computing the attention mechanism. The kernel has 25 parameters: Q, K, V (input tensors), sm_scale (a scaling factor), Out (output tensor), multiple stride parameters for addressing, Z, H, and N_CTX for grid dimensions, and BLOCK_M, BLOCK_DMODEL, BLOCK_N which are block dimensions marked as constexpr. The kernel processes blocks of Q, K, V to compute the attention scores and values, storing results in Out.", - "description_2": "Use triton language to create a torch.nn.Module named triton_flash_attn. The forward method takes 5 parameters: q, k, v (input tensors), sm_scale (a scaling factor), and block_128 (a boolean to determine block size). It computes an output tensor using the _fwd_kernel, adapting execution grid and warps based on input dimensions and the provided block size.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\nfrom deepspeed.accelerator import get_accelerator\n\n@triton.jit\ndef residual_add_bias_kernel(\n hidden_state_ptr,\n residual_ptr,\n attn_output_ptr,\n hidden_state_size,\n attn_bias_ptr,\n final_bias_ptr,\n bias_size,\n output_ptr,\n mp_size: tl.constexpr,\n mlp_after_attn: tl.constexpr,\n pre_attn_norm: tl.constexpr,\n add_attn_bias: tl.constexpr,\n BLOCK_SIZE: tl.constexpr,\n):\n pid = tl.program_id(axis=0)\n\n block_start = pid * BLOCK_SIZE\n\n offsets = block_start + tl.arange(0, BLOCK_SIZE)\n mask = offsets < hidden_state_size\n\n bias_offsets = offsets % bias_size\n bias_mask = bias_offsets < bias_size\n\n tl_hidden_state = tl.load(hidden_state_ptr + offsets, mask=mask)\n tl_residual = tl.load(residual_ptr + offsets, mask=mask)\n tl_attn_output = tl.load(attn_output_ptr + offsets, mask=mask)\n tl_attn_bias = tl.load(attn_bias_ptr + bias_offsets, mask=bias_mask)\n tl_final_bias = tl.load(final_bias_ptr + bias_offsets, mask=bias_mask)\n\n if mlp_after_attn:\n if pre_attn_norm:\n output = tl_hidden_state + (tl_residual + tl_final_bias + tl_attn_output + tl_attn_bias) / mp_size\n else:\n output = tl_hidden_state + tl_residual + tl_final_bias\n else:\n output = tl_hidden_state + tl_attn_output + (tl_residual + tl_final_bias) / mp_size\n if add_attn_bias:\n output += tl_attn_bias / mp_size\n\n tl.store(output_ptr + offsets, output, mask=mask)\n\n\ndef residual_add_bias(hidden_state: torch.Tensor, residual: torch.Tensor, attn_output: torch.Tensor,\n attn_bias: torch.Tensor, final_bias: torch.Tensor, mp_size: int, mlp_after_attn: bool,\n add_attn_bias: bool, pre_attn_norm: bool):\n # check that all tensors are on the same device\n assert get_accelerator().on_accelerator(hidden_state) \\\n and get_accelerator().on_accelerator(residual) \\\n and get_accelerator().on_accelerator(attn_output) \\\n and get_accelerator().on_accelerator(attn_bias) \\\n and get_accelerator().on_accelerator(final_bias)\n\n # check that all tensors have the same dtype\n assert hidden_state.dtype == residual.dtype == attn_output.dtype \\\n == attn_bias.dtype == final_bias.dtype\n\n # check that all tensors have the right shape\n assert hidden_state.shape == residual.shape == attn_output.shape\n assert attn_bias.shape == final_bias.shape\n assert attn_bias.shape[0] == hidden_state.shape[2]\n\n output = torch.empty_like(hidden_state)\n\n hidden_state_size = output.numel()\n bias_size = attn_bias.numel()\n\n grid = lambda meta: (triton.cdiv(hidden_state_size, meta['BLOCK_SIZE']), )\n\n residual_add_bias_kernel[grid](hidden_state, residual, attn_output, hidden_state_size,\\\n attn_bias, final_bias, bias_size, output, mp_size, mlp_after_attn, pre_attn_norm, \\\n add_attn_bias, \\\n BLOCK_SIZE=1024)\n\n return output\n", - "description_1": "Use triton language to implement a kernel function 'residual_add_bias_kernel' that performs element-wise addition of hidden state, residual, attention output, and biases with optional scaling and normalization. The kernel takes 13 parameters: pointers to hidden state, residual, attention output, attention bias, final bias, and output, sizes of hidden state and bias, and several compile-time constants for configuration. The function 'residual_add_bias' wraps this kernel, ensuring input tensors are on the same device and have compatible shapes and types, and launches the kernel with a computed grid size.", - "description_2": "Use triton language to create a kernel for element-wise tensor addition with optional scaling and biasing, and a wrapper function to prepare and launch this kernel.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n# Triton kernel for softmax without mask\n@triton.jit\ndef softmax_kernel(output_ptr, input_ptr, stride, n_cols, BLOCK_SIZE: tl.constexpr):\n row_idx = tl.program_id(0)\n row_start_ptr = input_ptr + row_idx * stride\n col_offsets = tl.arange(0, BLOCK_SIZE)\n input_ptrs = row_start_ptr + col_offsets\n row = tl.load(input_ptrs, mask=col_offsets < n_cols, other=-float('inf')).to(tl.float32)\n row_minus_max = row - tl.max(row, axis=0)\n numerator = tl.exp(row_minus_max)\n denominator = tl.sum(numerator, axis=0)\n softmax_output = numerator / denominator\n output_row_start_ptr = output_ptr + row_idx * stride\n output_ptrs = output_row_start_ptr + col_offsets\n tl.store(output_ptrs, softmax_output, mask=col_offsets < n_cols)\n\n# Triton kernel for softmax with mask\n@triton.jit\ndef masked_softmax_kernel(output_ptr, input_ptr, stride, mask_ptr, mask_stride, n_cols, BLOCK_SIZE: tl.constexpr):\n row_idx = tl.program_id(0)\n row_start_ptr = input_ptr + row_idx * stride\n col_offsets = tl.arange(0, BLOCK_SIZE)\n input_ptrs = row_start_ptr + col_offsets\n mask_ptrs = mask_ptr + col_offsets + row_idx * mask_stride # mask_stride is 0 for 1d mask\n row = tl.load(input_ptrs, mask=col_offsets < n_cols, other=-float('inf')).to(tl.float32)\n mask = tl.load(mask_ptrs, mask=col_offsets < n_cols, other=0).to(tl.float32)\n row_minus_max = row - tl.max(row, axis=0)\n row_minus_max = row_minus_max + mask\n numerator = tl.exp(row_minus_max)\n denominator = tl.sum(numerator, axis=0)\n softmax_output = numerator / denominator\n output_row_start_ptr = output_ptr + row_idx * stride\n output_ptrs = output_row_start_ptr + col_offsets\n tl.store(output_ptrs, softmax_output, mask=col_offsets < n_cols)\n\n# Wrapper function to call the appropriate Triton softmax kernel\ndef softmax(input: torch.Tensor, mask: torch.Tensor = None, dim=-1) -> torch.Tensor:\n assert input.is_contiguous()\n assert (dim == -1) or (dim == len(input.shape) - 1), \"Only dim=-1 is supported\"\n\n use_mask = False if mask is None else True\n input_arg = input.view(-1, input.shape[-1])\n n_rows, n_cols = input_arg.shape\n BLOCK_SIZE = max(triton.next_power_of_2(n_cols), 2)\n num_warps = 4\n if BLOCK_SIZE >= 2048:\n num_warps = 8\n if BLOCK_SIZE >= 4096:\n num_warps = 16\n # Allocate output\n output = torch.empty_like(input)\n if use_mask:\n assert mask.is_contiguous()\n mask = mask.view(-1, mask.shape[-1])\n mask_stride = mask.shape[-1] if mask.shape[-2] > 1 else 0\n masked_softmax_kernel[(n_rows, )](\n output,\n input,\n input_arg.stride(0),\n mask,\n mask_stride,\n n_cols,\n num_warps=num_warps,\n BLOCK_SIZE=BLOCK_SIZE,\n )\n else:\n softmax_kernel[(n_rows, )](\n output,\n input,\n input_arg.stride(0),\n n_cols,\n num_warps=num_warps,\n BLOCK_SIZE=BLOCK_SIZE,\n )\n return output\n", - "description_1": "Use triton language to define a softmax kernel (softmax_kernel) which computes the softmax of a matrix without masking. This kernel takes in 5 parameters: the output pointer, input pointer, stride, number of columns, and a BLOCK_SIZE as a constant expression. Another softmax kernel with masking (masked_softmax_kernel) is defined to handle masked softmax computation. It takes in 7 parameters: the output pointer, input pointer, stride, mask pointer, mask stride, number of columns, and a BLOCK_SIZE as a constant expression. The function softmax is a wrapper to choose the appropriate kernel based on whether a mask is provided and prepare parameters for kernel execution.", - "description_2": "Use triton language to define a softmax kernel and its masked version, then implement a wrapper function to execute the appropriate kernel based on input conditions.", - "difficulty": 3 - }, - { - "code": "import triton\nimport triton.language as tl\nfrom .gelu import gelu_functor\n\n@triton.autotune(\n configs=[\n triton.Config({\n 'BLOCK_M': 128,\n 'BLOCK_N': 256,\n 'BLOCK_K': 32,\n 'SPLIT_K': 1\n }, num_stages=3, num_warps=8),\n triton.Config({\n 'BLOCK_M': 256,\n 'BLOCK_N': 128,\n 'BLOCK_K': 32,\n 'SPLIT_K': 1\n }, num_stages=3, num_warps=8),\n triton.Config({\n 'BLOCK_M': 256,\n 'BLOCK_N': 64,\n 'BLOCK_K': 32,\n 'SPLIT_K': 1\n }, num_stages=4, num_warps=4),\n triton.Config({\n 'BLOCK_M': 64,\n 'BLOCK_N': 256,\n 'BLOCK_K': 32,\n 'SPLIT_K': 1\n }, num_stages=4, num_warps=4),\n triton.Config({\n 'BLOCK_M': 128,\n 'BLOCK_N': 128,\n 'BLOCK_K': 32,\n 'SPLIT_K': 1\n }, num_stages=4, num_warps=4),\n triton.Config({\n 'BLOCK_M': 128,\n 'BLOCK_N': 64,\n 'BLOCK_K': 32,\n 'SPLIT_K': 1\n }, num_stages=4, num_warps=4),\n triton.Config({\n 'BLOCK_M': 64,\n 'BLOCK_N': 128,\n 'BLOCK_K': 32,\n 'SPLIT_K': 1\n }, num_stages=4, num_warps=4),\n triton.Config({\n 'BLOCK_M': 128,\n 'BLOCK_N': 32,\n 'BLOCK_K': 32,\n 'SPLIT_K': 1\n }, num_stages=4, num_warps=4),\n triton.Config({\n 'BLOCK_M': 64,\n 'BLOCK_N': 32,\n 'BLOCK_K': 32,\n 'SPLIT_K': 1\n }, num_stages=5, num_warps=2),\n ],\n key=['CACHE_M', 'CACHE_N', 'CACHE_K'],\n prune_configs_by={\n 'early_config_prune': _fp16_matmul_prune_config,\n 'perf_model': None,\n 'top_k': AUTOTUNE_TOP_K\n },\n)\n@triton.heuristics({\n 'EVEN_K': lambda args: args['K'] % (args['BLOCK_K'] * args['SPLIT_K']) == 0,\n})\n@triton.jit\ndef _fp_matmul(\n A,\n B,\n C,\n M,\n N,\n K,\n bias,\n stride_am,\n stride_ak,\n stride_bk,\n stride_bn,\n stride_cm,\n stride_cn,\n CACHE_M,\n CACHE_N,\n CACHE_K,\n BLOCK_M: tl.constexpr,\n BLOCK_N: tl.constexpr,\n BLOCK_K: tl.constexpr,\n GROUP_M: tl.constexpr,\n SPLIT_K: tl.constexpr,\n EVEN_K: tl.constexpr,\n ACC_TYPE: tl.constexpr,\n BIAS_ADD: tl.constexpr,\n ACTIVATION: tl.constexpr,\n):\n # matrix multiplication\n pid = tl.program_id(0)\n pid_z = tl.program_id(1)\n grid_m = (M + BLOCK_M - 1) // BLOCK_M\n grid_n = (N + BLOCK_N - 1) // BLOCK_N\n # re-order program ID for better L2 performance\n width = GROUP_M * grid_n\n group_id = pid // width\n group_size = min(grid_m - group_id * GROUP_M, GROUP_M)\n pid_m = group_id * GROUP_M + (pid % group_size)\n pid_n = (pid % width) // (group_size)\n # do matrix multiplication\n rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)\n rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)\n ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)\n rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)\n rk = pid_z * BLOCK_K + tl.arange(0, BLOCK_K)\n # pointers\n A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)\n B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)\n acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)\n for k in range(K, 0, -BLOCK_K * SPLIT_K):\n if EVEN_K:\n a = tl.load(A)\n b = tl.load(B)\n else:\n a = tl.load(A, mask=rk[None, :] < k, other=0.)\n b = tl.load(B, mask=rk[:, None] < k, other=0.)\n acc += tl.dot(a, b)\n A += BLOCK_K * SPLIT_K * stride_ak\n B += BLOCK_K * SPLIT_K * stride_bk\n # bias addition\n if BIAS_ADD:\n bias_offset = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)\n bias_ptr = bias + bias_offset\n b = tl.load(bias_ptr, mask=bias_offset < N)\n acc = acc + b[None, :]\n # activation\n if ACTIVATION == \"relu\":\n acc = tl.where(acc >= 0, acc, 0)\n elif ACTIVATION == \"leaky_relu\":\n acc = tl.where(acc >= 0, acc, 0.01 * acc)\n elif ACTIVATION == \"gelu\":\n #acc = tl.sigmoid(1.702 * acc) * acc\n acc = gelu_functor(acc)\n elif ACTIVATION == \"sigmoid\":\n acc = tl.sigmoid(acc) # sigmoid\n acc = acc.to(C.dtype.element_ty)\n # rematerialize rm and rn to save registers\n rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)\n rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)\n C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)\n mask = (rm < M)[:, None] & (rn < N)[None, :]\n # handles write-back with reduction-splitting\n if SPLIT_K == 1:\n tl.store(C, acc, mask=mask)\n else:\n tl.atomic_add(C, acc, mask=mask)\n\n@triton.autotune(\n configs=[\n triton.Config(\n {\n \"BLOCK_SIZE_M\": 64,\n \"BLOCK_SIZE_N\": 64,\n \"BLOCK_SIZE_K\": 64,\n \"GROUP_SIZE_M\": 8\n },\n num_stages=1, # this is mainly for unit test, to minimize the share memory usage\n num_warps=8),\n triton.Config(\n {\n \"BLOCK_SIZE_M\": 128,\n \"BLOCK_SIZE_N\": 128,\n \"BLOCK_SIZE_K\": 32,\n \"GROUP_SIZE_M\": 8,\n },\n num_stages=4,\n num_warps=4,\n ),\n triton.Config(\n {\n \"BLOCK_SIZE_M\": 128,\n \"BLOCK_SIZE_N\": 64,\n \"BLOCK_SIZE_K\": 32,\n \"GROUP_SIZE_M\": 8,\n },\n num_stages=4,\n num_warps=4,\n ),\n triton.Config(\n {\n \"BLOCK_SIZE_M\": 64,\n \"BLOCK_SIZE_N\": 128,\n \"BLOCK_SIZE_K\": 32,\n \"GROUP_SIZE_M\": 8,\n },\n num_stages=4,\n num_warps=4,\n ),\n triton.Config(\n {\n \"BLOCK_SIZE_M\": 128,\n \"BLOCK_SIZE_N\": 32,\n \"BLOCK_SIZE_K\": 32,\n \"GROUP_SIZE_M\": 8,\n },\n num_stages=4,\n num_warps=4,\n ),\n triton.Config(\n {\n \"BLOCK_SIZE_M\": 64,\n \"BLOCK_SIZE_N\": 32,\n \"BLOCK_SIZE_K\": 32,\n \"GROUP_SIZE_M\": 8,\n },\n num_stages=5,\n num_warps=2,\n ),\n triton.Config(\n {\n \"BLOCK_SIZE_M\": 32,\n \"BLOCK_SIZE_N\": 64,\n \"BLOCK_SIZE_K\": 32,\n \"GROUP_SIZE_M\": 8,\n },\n num_stages=5,\n num_warps=2,\n ),\n ],\n key=['CACHE_M', 'CACHE_N', 'CACHE_K'],\n prune_configs_by={\n 'early_config_prune': matmul_4d_prune_config,\n 'perf_model': None,\n 'top_k': AUTOTUNE_TOP_K\n },\n)\n@triton.jit\ndef matmul_4d_kernel(\n # Pointers to matrices\n a_ptr,\n b_ptr,\n c_ptr,\n # Matrix dimensions\n M,\n N,\n K,\n CACHE_M,\n CACHE_N,\n CACHE_K,\n stride_ab,\n stride_ah,\n stride_am,\n stride_ak,\n stride_bb,\n stride_bh,\n stride_bk,\n stride_bn,\n stride_cb,\n stride_ch,\n stride_cm,\n stride_cn,\n scale,\n # Meta-parameters\n BLOCK_SIZE_M: tl.constexpr,\n BLOCK_SIZE_N: tl.constexpr,\n BLOCK_SIZE_K: tl.constexpr,\n GROUP_SIZE_M: tl.constexpr,\n MASK: tl.constexpr,\n):\n \"\"\"Kernel for computing the matmul C = A x B.\n A has shape (M, K), B has shape (K, N) and C has shape (M, N)\n \"\"\"\n pid = tl.program_id(axis=0)\n head = tl.program_id(axis=1)\n batch = tl.program_id(axis=2)\n num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)\n num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)\n num_pid_in_group = GROUP_SIZE_M * num_pid_n\n group_id = pid // num_pid_in_group\n first_pid_m = group_id * GROUP_SIZE_M\n group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)\n pid_m = first_pid_m + (pid % group_size_m)\n pid_n = (pid % num_pid_in_group) // group_size_m\n\n if MASK:\n if (pid_m + 1) * BLOCK_SIZE_M - 1 < pid_n * BLOCK_SIZE_N:\n c = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=c_ptr.dtype.element_ty) - float(\"inf\")\n offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n c_ptrs = (c_ptr + batch * stride_cb + head * stride_ch + stride_cm * offs_cm[:, None] +\n stride_cn * offs_cn[None, :])\n tl.store(c_ptrs, c)\n return\n\n offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n offs_k = tl.arange(0, BLOCK_SIZE_K)\n a_ptrs = (a_ptr + batch * stride_ab + head * stride_ah +\n (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak))\n b_ptrs = (b_ptr + batch * stride_bb + head * stride_bh +\n (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn))\n\n accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n for k in range(0, K, BLOCK_SIZE_K):\n a_mask = (offs_am[:, None] < M) & (offs_k[None, :] + k < K)\n b_mask = (offs_k[:, None] + k < K) & (offs_bn[None, :] < N)\n a = tl.load(a_ptrs, mask=a_mask, other=0.)\n b = tl.load(b_ptrs, mask=b_mask, other=0.)\n accumulator += tl.dot(a, b)\n a_ptrs += BLOCK_SIZE_K * stride_ak\n b_ptrs += BLOCK_SIZE_K * stride_bk\n\n c = accumulator.to(c_ptr.dtype.element_ty)\n if scale > 0:\n c = c * scale.to(c_ptr.dtype.element_ty)\n\n offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n if MASK:\n c += tl.where(offs_cm[:, None] >= offs_cn[None, :], 0, float(\"-inf\"))\n c_ptrs = (c_ptr + batch * stride_cb + head * stride_ch + stride_cm * offs_cm[:, None] +\n stride_cn * offs_cn[None, :])\n c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)\n tl.store(c_ptrs, c, mask=c_mask)\n", - "description_1": "Use triton language to implement two matrix multiplication kernels. The first kernel, _fp_matmul, takes 22 parameters including matrices A, B, C, dimensions M, N, K, and various strides and constants. It performs matrix multiplication with optional bias addition and activation functions. The second kernel, matmul_4d_kernel, takes 22 parameters including pointers to matrices a_ptr, b_ptr, c_ptr, dimensions M, N, K, and various strides and constants. It computes the matrix multiplication C = A x B with optional scaling and masking.", - "description_2": "Use triton language to implement matrix multiplication kernels with optional bias, activation, scaling, and masking.", - "difficulty": 4 - }, - { - "code": "import triton\nimport triton.language as tl\n\n# triton kernel\n@triton.jit\ndef kernel(X, stride_xm, #\n Z, stride_zn, #\n BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):\n off_m = tl.arange(0, BLOCK_M)\n off_n = tl.arange(0, BLOCK_N)\n Xs = X + off_m[:, None] * stride_xm + off_n[None, :] * 1\n Zs = Z + off_m[:, None] * 1 + off_n[None, :] * stride_zn\n tl.store(Zs, tl.load(Xs))\n\nret = triton.compile(kernel, signature=\"*fp32,i32,*fp32,i32\", constants={\"BLOCK_M\": 64, \"BLOCK_N\": 64})\nprint(ret.asm[\"ttgir\"])\n", - "description_1": "Use triton language to define a kernel function that takes four arguments: X (tensor pointer of fp32), stride_xm (integer), Z (tensor pointer of fp32), and stride_zn (integer). The kernel uses two constexpr parameters, BLOCK_M and BLOCK_N, to define the block size. The kernel performs a block-wise memory load from X and stores the data into Z using the defined strides.", - "description_2": "Use triton language to define and compile a kernel for block-wise memory transfer between two fp32 tensors with configurable strides and block size.", - "difficulty": 2 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\nimport importlib.util\nfrom triton.common.backend import register_backend\n\nclass ExtensionBackend:\n stub_so_path = \"\"\n\ndef test_dummy_backend():\n register_backend(\"cpu\", ExtensionBackend)\n\n @triton.jit\n def kernel(in_ptr0, out_ptr0, xnumel, XBLOCK: tl.constexpr):\n xnumel = 10\n xoffset = tl.program_id(0) * XBLOCK\n xindex = xoffset + tl.arange(0, XBLOCK)[:]\n xmask = xindex < xnumel\n x0 = xindex\n tmp0 = tl.load(in_ptr0 + (x0), xmask)\n tl.store(out_ptr0 + (x0 + tl.zeros([XBLOCK], tl.int32)), tmp0, xmask)\n\n inp = torch.randn(10)\n out = torch.randn(10)\n kernel[(10, )](inp, out, 10, XBLOCK=16)\n spec = importlib.util.spec_from_file_location(\"__triton_launcher\", ExtensionBackend.stub_so_path)\n mod = importlib.util.module_from_spec(spec)\n spec.loader.exec_module(mod)\n launch_counter = getattr(mod, \"launch_counter\")\n\n for _ in range(100):\n kernel[(10, )](inp, out, 10, XBLOCK=16)\n\n assert launch_counter() > 0\n", - "description_1": "Use triton language to define a kernel that loads data from an input pointer, processes it, and stores it to an output pointer. The kernel takes four parameters: in_ptr0 (input pointer), out_ptr0 (output pointer), xnumel (number of elements), and XBLOCK (block size). The kernel is launched with a grid size of 10 and a block size of 16.", - "description_2": "Use triton language to create a kernel that performs element-wise operations on input data and stores the result in an output buffer, with specific grid and block dimensions.", - "difficulty": 2 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit()\ndef kernel(x_ptr, y_ptr, out_ptr):\n # Triton kernel to perform element-wise addition of two vectors\n pid = tl.program_id(axis=0)\n x = tl.load(x_ptr + pid)\n y = tl.load(y_ptr + pid)\n out = x + y\n tl.store(out_ptr + pid, out)\n\ndef test_xpu_backend(cmdopt):\n if cmdopt == \"xpu\":\n has_ipex = False\n try:\n import intel_extension_for_pytorch # type: ignore # noqa: F401\n has_ipex = True if hasattr(torch, \"xpu\") else False\n except Exception:\n has_ipex = False\n\n if has_ipex:\n for _ in range(1000):\n x = torch.randn((65536, ), device=\"xpu\", dtype=torch.float32)\n y = torch.randn((65536, ), device=\"xpu\", dtype=torch.float32)\n z = torch.zeros((65536, ), device=\"xpu\", dtype=torch.float32)\n # Call the Triton kernel\n kernel[(65536, )](x, y, z, num_warps=32)\n assert torch.all(x + y == z)\n", - "description_1": "Use triton language to define a kernel that performs element-wise addition of two vectors. The kernel takes three pointers as arguments: x_ptr, y_ptr, and out_ptr, which point to the input vectors and the output vector, respectively. The kernel uses the program_id to identify the current element to process. The kernel is called with a grid size of 65536 and num_warps set to 32. The test_xpu_backend function checks for Intel GPU runtime support and calls the kernel 1000 times with random input vectors on the 'xpu' device.", - "description_2": "Use triton language to create a kernel for element-wise vector addition and execute it on an Intel GPU if available.", - "difficulty": 1 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\nfrom numpy.random import RandomState\nimport numpy as np\n\ndef test_chained_matmul():\n def chained_matmul_reference(a, b, c):\n intermediate = torch.einsum('MK,NK->MN', a, b)\n return torch.einsum('MN,NK->MK', intermediate, c)\n\n @triton.jit\n def chained_matmul_kernel(A, # shape: (m, k)\n B, # shape: (n, k)\n C, # shape: (n, k)\n out, # shape: (m, k)\n m, n, k: tl.constexpr, #\n block_m: tl.constexpr, block_n: tl.constexpr, block_k: tl.constexpr):\n\n tl.static_assert(block_k == k, f\"expected block_k == k but got {block_k} != {k}\")\n\n block_ix = tl.program_id(0)\n a_tile = (block_ix * block_m + tl.arange(0, block_m))[:, None] * block_k \\\n + tl.arange(0, block_k)[None, :]\n\n a = tl.load(A + a_tile, mask=a_tile < m * k, other=0.0)\n\n acc = tl.zeros([block_m, block_k], dtype=tl.float32)\n\n for loop_block_start in range(0, n, block_n):\n bc_tile = (loop_block_start + tl.arange(0, block_n))[:, None] * block_k \\\n + tl.arange(0, block_k)[None, :]\n b = tl.load(B + bc_tile, mask=bc_tile < n * k, other=0.0)\n\n intermediate = tl.dot(a, tl.trans(b))\n intermediate_mask = ((loop_block_start + tl.arange(0, block_n)) < n)[None, :] \\\n * (tl.arange(0, block_m) < m)[:, None]\n\n intermediate = tl.where(intermediate_mask, intermediate, 0.0)\n\n c = tl.load(C + bc_tile, mask=bc_tile < n * k)\n\n acc += tl.dot(intermediate.to(A.dtype.element_ty), c)\n\n tl.store(out + a_tile, acc.to(A.dtype.element_ty), mask=a_tile < m * k)\n\n m, n, k = 32, 64, 128\n block_m, block_n, block_k = 16, 32, k\n\n grid = (triton.cdiv(m, block_m), )\n a = torch.randint(low=0, high=2, size=(m, k), dtype=torch.float16, device='cuda')\n b = torch.randint(low=0, high=2, size=(n, k), dtype=torch.float16, device='cuda')\n c = torch.randint_like(b, low=0, high=2)\n triton_result = torch.zeros_like(a)\n\n torch_result = chained_matmul_reference(a, b, c)\n chained_matmul_kernel[grid](\n a, b, c, triton_result, m, n, k, #\n block_m=block_m, block_n=block_n, block_k=block_k)\n\n assert (torch_result == triton_result).all()\n\n\ndef test_vecmat():\n\n @triton.jit\n def batched_vecmat(\n # inputs\n A, # shape: [dim_m, dim_k]\n B, # shape: [dim_m, dim_n, dim_k]\n # dimensions\n dim_m, dim_n, dim_k,\n # outputs\n output,\n # block information\n block_m: tl.constexpr, block_n: tl.constexpr, block_k: tl.constexpr):\n m_index = tl.program_id(0)\n n_index = tl.program_id(1)\n # Output tile\n output_tile = (m_index * block_m + tl.arange(0, block_m))[:, None] * dim_n \\\n + (n_index * block_n + tl.arange(0, block_n))[None, :]\n\n vecmat = tl.zeros([block_m, block_n], dtype=A.dtype.element_ty)\n k_blocks = dim_k // block_k\n for k_index in range(k_blocks):\n # Load A tile\n a_tile = (m_index * block_m + tl.arange(0, block_m))[:, None] * dim_k \\\n + (k_index * block_k + tl.arange(0, block_k))[None, :]\n a = tl.load(A + a_tile)\n\n # Load B tile, transposed to [n, m, k] in order to broadcast A on a\n # leading dimension.\n b_tile = (m_index * block_m + tl.arange(0, block_m))[None, :, None] * dim_n * dim_k \\\n + (n_index * block_n + tl.arange(0, block_n))[:, None, None] * dim_k \\\n + (k_index * block_k + tl.arange(0, block_k))[None, None, :]\n b = tl.load(B + b_tile)\n\n expanded_a, _ = tl.broadcast(a, b)\n vecmat += tl.trans(tl.sum(expanded_a * b, axis=2))\n\n tl.store(output + output_tile, vecmat)\n\n M, N, K = 128, 128, 128\n block_m, block_n, block_k = 16, 32, 64\n\n rs = RandomState(17)\n A_vec = rs.randint(0, 4, (M, K)).astype('float32')\n B_vec = rs.randint(0, 4, (M, N, K)).astype('float32')\n A = A_vec\n B = B_vec\n\n A_tri = torch.tensor(A, device='cuda')\n B_tri = torch.tensor(B, device='cuda')\n C_tri = torch.zeros((M, N), dtype=torch.float32, device='cuda')\n\n grid = (M // block_m, N // block_n)\n\n batched_vecmat[grid](\n A_tri, B_tri, M, N, K, C_tri, #\n block_m=block_m, block_n=block_n, block_k=block_k, #\n num_warps=4, num_stages=1)\n\n A_expanded = A[:, np.newaxis, :]\n A_broadcasted = np.broadcast_to(A_expanded, (M, N, K))\n AB = A_broadcasted * B\n C_ref = np.sum(AB, axis=2)\n\n np.testing.assert_allclose(C_ref, C_tri.cpu().numpy(), rtol=0.01, atol=1e-3)\n\n\ndef test_iv_dependent_matmul(type):\n\n @triton.jit\n def kernel(a_ptr, b_ptr, c_ptr, #\n M, N, K, #\n stride_am, stride_ak, #\n stride_bk, stride_bn, #\n stride_cm, stride_cn, #\n BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, #\n type: tl.constexpr):\n pid = tl.program_id(axis=0)\n num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)\n pid_m = pid // num_pid_n\n pid_n = pid % num_pid_n\n\n offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M\n offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N\n offs_k = tl.arange(0, BLOCK_SIZE_K)\n a_ptr = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)\n b_ptr = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)\n a_ptrs = a_ptr\n b_ptrs = b_ptr\n if type == \"post_load_two_iters\":\n a_ptrs_next = a_ptr + BLOCK_SIZE_K * stride_ak\n b_ptrs_next = b_ptr + BLOCK_SIZE_K * stride_bk\n elif type == \"post_load_three_iters\":\n a_ptrs_next = a_ptr + BLOCK_SIZE_K * stride_ak\n b_ptrs_next = b_ptr + BLOCK_SIZE_K * stride_bk\n a_ptrs_next_next = a_ptr + 2 * BLOCK_SIZE_K * stride_ak\n b_ptrs_next_next = b_ptr + 2 * BLOCK_SIZE_K * stride_bk\n\n accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):\n if type == \"pre_load\":\n a_ptrs = a_ptr + k * BLOCK_SIZE_K * stride_ak\n b_ptrs = b_ptr + k * BLOCK_SIZE_K * stride_bk\n elif type == \"post_pre_mixed\":\n a_ptrs = a_ptr + k * BLOCK_SIZE_K * stride_ak\n a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)\n b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)\n accumulator += tl.dot(a, b)\n if type == \"post_load\":\n a_ptrs = a_ptr + (k + 1) * BLOCK_SIZE_K * stride_ak\n b_ptrs = b_ptr + (k + 1) * BLOCK_SIZE_K * stride_bk\n elif type == \"post_pre_mixed\":\n b_ptrs = b_ptr + (k + 1) * BLOCK_SIZE_K * stride_bk\n elif type == \"post_load_two_iters\":\n a_ptrs = a_ptrs_next\n b_ptrs = b_ptrs_next\n a_ptrs_next = a_ptr + (k + 2) * BLOCK_SIZE_K * stride_ak\n b_ptrs_next = b_ptr + (k + 2) * BLOCK_SIZE_K * stride_bk\n elif type == \"post_load_three_iters\":\n a_ptrs = a_ptrs_next\n b_ptrs = b_ptrs_next\n a_ptrs_next = a_ptrs_next_next\n b_ptrs_next = b_ptrs_next_next\n a_ptrs_next_next = a_ptr + (k + 3) * BLOCK_SIZE_K * stride_ak\n b_ptrs_next_next = b_ptr + (k + 3) * BLOCK_SIZE_K * stride_bk\n c = accumulator.to(tl.float16)\n\n offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]\n c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)\n tl.store(c_ptrs, c, mask=c_mask)\n\n M = 256\n K = 256\n N = 256\n BLOCK_SIZE_K = 32\n BLOCK_SIZE_N = 32\n BLOCK_SIZE_M = 32\n\n a = torch.rand((M, K), device='cuda')\n b = torch.rand((K, N), device='cuda')\n\n torch_output = torch.mm(a, b)\n triton_output = torch.empty_like(torch_output, device=torch_output.device)\n\n def grid(META):\n return (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), )\n\n num_stages = 4 if type == \"post_load_three_iters\" else 3\n kernel[grid](\n a, b, triton_output, M, N, K, #\n a.stride(0), a.stride(1), b.stride(0), b.stride(1), #\n triton_output.stride(0), triton_output.stride(1), #\n BLOCK_SIZE_M=BLOCK_SIZE_M, BLOCK_SIZE_N=BLOCK_SIZE_N, BLOCK_SIZE_K=BLOCK_SIZE_K, type=type, #\n num_stages=num_stages)\n torch.testing.assert_close(torch_output, triton_output, rtol=1e-2, atol=1e-2)\n", - "description_1": "Use triton language to implement three kernels: 1) 'chained_matmul_kernel' for performing a chained matrix multiplication on inputs A, B, and C with output stored in 'out'. It requires parameters for matrix dimensions (m, n, k) and block sizes (block_m, block_n, block_k). 2) 'batched_vecmat' for computing a batched vector-matrix multiplication with inputs A and B, output stored in 'output', and requires dimensions (dim_m, dim_n, dim_k) and block sizes (block_m, block_n, block_k). 3) 'kernel' for an induction variable dependent matrix multiplication with inputs a_ptr, b_ptr, c_ptr, dimensions (M, N, K), strides (stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn), block sizes (BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K), and a type parameter to determine the loading strategy.", - "description_2": "Use triton language to create kernels for matrix operations: 1) a chained matrix multiplication with specific block sizes and dimensions, 2) a batched vector-matrix multiplication with broadcasting, and 3) an induction variable dependent matrix multiplication with configurable loading strategies.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n# Element-Wise Addition Kernel\n@triton.jit\ndef _add(x_ptr, y_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr):\n pid = tl.program_id(axis=0)\n block_start = pid * BLOCK_SIZE\n offsets = block_start + tl.arange(0, BLOCK_SIZE)\n mask = offsets < n_elements\n x = tl.load(x_ptr + offsets, mask=mask)\n y = tl.load(y_ptr + offsets, mask=mask)\n output = x + y\n tl.store(output_ptr + offsets, output, mask=mask)\n\n# Element-Wise Addition Test\ndef test_elementwise(N, dtype_str):\n stream = torch.cuda.Stream()\n torch.cuda.set_stream(stream)\n torch.manual_seed(0)\n dtype = {'float16': torch.float16, 'bfloat16': torch.bfloat16, 'float32': torch.float32}[dtype_str]\n z = torch.empty((N, ), dtype=dtype, device='cuda')\n x = torch.randn_like(z)\n y = torch.randn_like(z)\n grid = lambda args: (triton.cdiv(N, args['BLOCK_SIZE']), )\n fn = lambda: _add[grid](x, y, z, N, BLOCK_SIZE=1024)\n ms = triton.testing.do_bench_cudagraph(fn)\n\n# Reduction Kernel\n@triton.jit\ndef _sum(x_ptr, y_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr):\n pid = tl.program_id(axis=0)\n block_start = pid * BLOCK_SIZE\n offsets = block_start + tl.arange(0, BLOCK_SIZE)\n mask = offsets < n_elements\n x = tl.load(x_ptr + offsets, mask=mask)\n y = tl.load(y_ptr + offsets, mask=mask)\n # run in a loop to only to make it compute bound.\n for i in range(100):\n x = tl.sum(x, axis=0) + y\n tl.store(output_ptr + offsets, x, mask=mask)\n\n# Reduction Test\ndef test_reductions(N, dtype_str):\n stream = torch.cuda.Stream()\n torch.cuda.set_stream(stream)\n torch.manual_seed(0)\n dtype = {'float16': torch.float16, 'float32': torch.float32, 'int16': torch.int16, 'int32': torch.int32}[dtype_str]\n z = torch.empty((N, ), dtype=dtype, device='cuda')\n if dtype == torch.float16 or dtype == torch.float32:\n x = torch.randn_like(z)\n y = torch.randn_like(z)\n else:\n info = torch.iinfo(dtype)\n x = torch.randint(info.min, info.max, (N, ), dtype=dtype, device='cuda')\n y = torch.randint(info.min, info.max, (N, ), dtype=dtype, device='cuda')\n grid = lambda args: (triton.cdiv(N, args['BLOCK_SIZE']), )\n fn = lambda: _sum[grid](x, y, z, N, BLOCK_SIZE=1024)\n ms = triton.testing.do_bench_cudagraph(fn)\n", - "description_1": "Use triton language to implement two kernels: one for element-wise addition and another for reduction. The element-wise addition kernel (_add) takes five parameters: pointers to input arrays x and y, a pointer to the output array, the number of elements to process, and a block size. It performs addition on elements of x and y and stores the result in the output array. The reduction kernel (_sum) also takes five parameters: pointers to input arrays x and y, a pointer to the output array, the number of elements to process, and a block size. It performs a reduction operation by summing elements of x and y in a loop to make it compute-bound, and stores the result in the output array.", - "description_2": "Use triton language to create a kernel for element-wise addition of two arrays and another kernel for performing a reduction operation by summing elements of two arrays in a loop.", - "difficulty": 2 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef _fwd_kernel(Q, K, V, sm_scale, #\n L, M, #\n Out, #\n stride_qz, stride_qh, stride_qm, stride_qk, #\n stride_kz, stride_kh, stride_kn, stride_kk, #\n stride_vz, stride_vh, stride_vk, stride_vn, #\n stride_oz, stride_oh, stride_om, stride_on, #\n Z, H, N_CTX, D0, #\n BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr):\n start_m = tl.program_id(0)\n off_hz = tl.program_id(1)\n\n # TODO: may replace with TMA store without range offset\n # initialize offsets for store\n offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)\n offs_n = tl.arange(0, BLOCK_N)\n # initialize pointer to m and l\n m_prev = tl.zeros([BLOCK_M], dtype=tl.float32) - float(\"inf\")\n l_prev = tl.zeros([BLOCK_M], dtype=tl.float32)\n acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)\n\n stride_qh_2d = stride_qh // stride_qm // stride_qk\n\n q_tile_ptr = tl.make_block_ptr(\n base=Q,\n shape=(D0, BLOCK_DMODEL),\n strides=(stride_qm, stride_qk),\n offsets=(off_hz * stride_qh_2d + start_m * BLOCK_M, 0),\n block_shape=(BLOCK_M, BLOCK_DMODEL),\n order=(1, 0),\n )\n k_tile_ptr = tl.make_block_ptr(\n base=K,\n shape=(D0, BLOCK_DMODEL),\n strides=(stride_kn, stride_kk),\n offsets=(off_hz * stride_qh_2d, 0),\n block_shape=(BLOCK_N, BLOCK_DMODEL),\n order=(1, 0),\n )\n v_tile_ptr = tl.make_block_ptr(\n base=V,\n shape=(D0, BLOCK_DMODEL),\n strides=(stride_vk, stride_vn),\n offsets=(off_hz * stride_qh_2d, 0),\n block_shape=(BLOCK_N, BLOCK_DMODEL),\n order=(1, 0),\n )\n out_tile_ptr = tl.make_block_ptr(\n base=Out,\n shape=(D0, BLOCK_DMODEL),\n strides=(stride_om, stride_on),\n offsets=(off_hz * stride_qh_2d + start_m * BLOCK_M, 0),\n block_shape=(BLOCK_M, BLOCK_DMODEL),\n order=(1, 0),\n )\n # load q: it will stay in SRAM throughout\n q = tl.load(q_tile_ptr)\n\n # loop over k, v and update accumulators\n for start_n in range(0, (start_m + 1) * BLOCK_M, BLOCK_N):\n # -- compute qk ----\n k = tl.load(k_tile_ptr, boundary_check=(0, 1))\n qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)\n qk += tl.dot(q, tl.trans(k))\n qk *= sm_scale\n qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float(\"-inf\"))\n # compute new m\n m_curr = tl.maximum(tl.max(qk, 1), m_prev)\n # correct old l\n l_prev *= tl.exp(m_prev - m_curr)\n # attention weights\n p = tl.exp(qk - m_curr[:, None])\n l_curr = tl.sum(p, 1) + l_prev\n # rescale operands of matmuls\n l_rcp = 1. / l_curr\n p *= l_rcp[:, None]\n acc *= (l_prev * l_rcp)[:, None]\n # update acc\n p = p.to(tl.float16)\n v = tl.load(v_tile_ptr, boundary_check=(0, 1))\n acc += tl.dot(p, v)\n # update m_i and l_i\n l_prev = l_curr\n m_prev = m_curr\n # update pointers\n k_tile_ptr = tl.advance(k_tile_ptr, [BLOCK_N, 0])\n v_tile_ptr = tl.advance(v_tile_ptr, [BLOCK_N, 0])\n # rematerialize offsets to save registers\n start_m = tl.program_id(0)\n offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)\n # write back l and m\n l_ptrs = L + off_hz * N_CTX + offs_m\n m_ptrs = M + off_hz * N_CTX + offs_m\n tl.store(l_ptrs, l_prev)\n tl.store(m_ptrs, m_prev)\n\n acc = acc.to(tl.float16)\n tl.store(out_tile_ptr, acc, boundary_check=(0, 1))\n\n\n@triton.jit\ndef _bwd_preprocess(Out, DO, L, #\n NewDO, Delta, #\n BLOCK_M: tl.constexpr, D_HEAD: tl.constexpr):\n off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)\n off_n = tl.arange(0, D_HEAD)\n # load\n o = tl.load(Out + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32)\n do = tl.load(DO + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32)\n denom = tl.load(L + off_m).to(tl.float32)\n # compute\n do = do / denom[:, None]\n delta = tl.sum(o * do, axis=1)\n # write-back\n tl.store(NewDO + off_m[:, None] * D_HEAD + off_n[None, :], do)\n tl.store(Delta + off_m, delta)\n\n\n@triton.jit\ndef _bwd_kernel(Q, K, V, sm_scale, Out, DO, #\n DQ, DK, DV, #\n L, M, #\n D, stride_qz, stride_qh, stride_qm, stride_qk, #\n stride_kz, stride_kh, stride_kn, stride_kk, #\n stride_vz, stride_vh, stride_vk, stride_vn, #\n Z, H, N_CTX, D0, #\n num_block, BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr):\n off_hz = tl.program_id(0)\n off_z = off_hz // H\n off_h = off_hz % H\n # init tile_ptr\n stride_qz_2d = stride_qz // stride_qm // stride_qk\n stride_qh_2d = stride_qh // stride_qm // stride_qk\n\n q_tile_ptr = tl.make_block_ptr(\n base=Q,\n shape=(D0, BLOCK_DMODEL),\n strides=(stride_qm, stride_qk),\n offsets=(off_z * stride_qz_2d + off_h * stride_qh_2d, 0),\n block_shape=(BLOCK_M, BLOCK_DMODEL),\n order=(1, 0),\n )\n k_tile_ptr = tl.make_block_ptr(\n base=K,\n shape=(D0, BLOCK_DMODEL),\n strides=(stride_kn, stride_kk),\n offsets=(off_z * stride_qz_2d + off_h * stride_qh_2d, 0),\n block_shape=(BLOCK_M, BLOCK_DMODEL),\n order=(1, 0),\n )\n v_tile_ptr = tl.make_block_ptr(\n base=V,\n shape=(D0, BLOCK_DMODEL),\n strides=(stride_vk, stride_vn),\n offsets=(off_z * stride_qz_2d + off_h * stride_qh_2d, 0),\n block_shape=(BLOCK_M, BLOCK_DMODEL),\n order=(1, 0),\n )\n do_tile_ptr = tl.make_block_ptr(\n base=DO,\n shape=(D0, BLOCK_DMODEL),\n strides=(stride_qm, stride_qk),\n offsets=(off_z * stride_qz_2d + off_h * stride_qh_2d, 0),\n block_shape=(BLOCK_M, BLOCK_DMODEL),\n order=(1, 0),\n )\n dq_tile_ptr = tl.make_block_ptr(\n base=DQ,\n shape=(D0, BLOCK_DMODEL),\n strides=(stride_qm, stride_qk),\n offsets=(off_z * stride_qz_2d + off_h * stride_qh_2d, 0),\n block_shape=(BLOCK_M, BLOCK_DMODEL),\n order=(1, 0),\n )\n dk_tile_ptr = tl.make_block_ptr(\n base=DK,\n shape=(D0, BLOCK_DMODEL),\n strides=(stride_qm, stride_qk),\n offsets=(off_z * stride_qz_2d + off_h * stride_qh_2d, 0),\n block_shape=(BLOCK_M, BLOCK_DMODEL),\n order=(1, 0),\n )\n dv_tile_ptr = tl.make_block_ptr(\n base=DV,\n shape=(D0, BLOCK_DMODEL),\n strides=(stride_qm, stride_qk),\n offsets=(off_z * stride_qz_2d + off_h * stride_qh_2d, 0),\n block_shape=(BLOCK_M, BLOCK_DMODEL),\n order=(1, 0),\n )\n # offset pointers for batch/head\n DQ += off_z * stride_qz + off_h * stride_qh\n for start_n in range(0, num_block):\n lo = start_n * BLOCK_M\n # initialize row/col offsets\n offs_qm = lo + tl.arange(0, BLOCK_M)\n offs_n = start_n * BLOCK_M + tl.arange(0, BLOCK_M)\n offs_m = tl.arange(0, BLOCK_N)\n offs_k = tl.arange(0, BLOCK_DMODEL)\n # initialize pointers to value-like data\n dq_ptrs = DQ + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk)\n # pointer to row-wise quantities in value-like data\n D_ptrs = D + off_hz * N_CTX\n m_ptrs = M + off_hz * N_CTX\n # initialize dv amd dk\n dv = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)\n dk = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)\n # k and v stay in SRAM throughout\n k = tl.load(k_tile_ptr, boundary_check=(0, 1))\n v = tl.load(v_tile_ptr, boundary_check=(0, 1))\n # loop over rows\n for start_m in range(lo, num_block * BLOCK_M, BLOCK_M):\n offs_m_curr = start_m + offs_m\n # load q, k, v, do on-chip\n q = tl.load(q_tile_ptr, boundary_check=(0, 1))\n # recompute p = softmax(qk, dim=-1).T\n # NOTE: `do` is pre-divided by `l`; no normalization here\n qk = tl.dot(q, tl.trans(k))\n qk = tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), qk, float(\"-inf\"))\n m = tl.load(m_ptrs + offs_m_curr)\n p = tl.exp(qk * sm_scale - m[:, None])\n # compute dv\n do = tl.load(do_tile_ptr, boundary_check=(0, 1))\n dv += tl.dot(tl.trans(p.to(tl.float16)), do)\n # compute dp = dot(v, do)\n Di = tl.load(D_ptrs + offs_m_curr)\n dp = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - Di[:, None]\n dp += tl.dot(do, tl.trans(v))\n # compute ds = p * (dp - delta[:, None])\n ds = p * dp * sm_scale\n # compute dk = dot(ds.T, q)\n dk += tl.dot(tl.trans(ds.to(tl.float16)), q)\n # compute dq\n dq = tl.load(dq_tile_ptr)\n dq += tl.dot(ds.to(tl.float16), k)\n tl.store(dq_tile_ptr, dq)\n # increment pointers\n dq_ptrs += BLOCK_M * stride_qm\n q_tile_ptr = tl.advance(q_tile_ptr, [BLOCK_M, 0])\n do_tile_ptr = tl.advance(do_tile_ptr, [BLOCK_M, 0])\n dq_tile_ptr = tl.advance(dq_tile_ptr, [BLOCK_M, 0])\n q_tile_ptr = tl.advance(q_tile_ptr, [lo + (1 - num_block) * BLOCK_M, 0])\n do_tile_ptr = tl.advance(do_tile_ptr, [lo + (1 - num_block) * BLOCK_M, 0])\n dq_tile_ptr = tl.advance(dq_tile_ptr, [lo + (1 - num_block) * BLOCK_M, 0])\n # increment tile pointers\n k_tile_ptr = tl.advance(k_tile_ptr, [BLOCK_M, 0])\n v_tile_ptr = tl.advance(v_tile_ptr, [BLOCK_M, 0])\n # write-back\n tl.store(dv_tile_ptr, dv.to(tl.float16), boundary_check=(0, 1))\n tl.store(dk_tile_ptr, dk.to(tl.float16), boundary_check=(0, 1))\n dv_tile_ptr = tl.advance(dv_tile_ptr, [BLOCK_M, 0])\n dk_tile_ptr = tl.advance(dk_tile_ptr, [BLOCK_M, 0])\n\n\nclass _attention(torch.autograd.Function):\n\n @staticmethod\n def forward(ctx, q, k, v, sm_scale):\n BLOCK = 128\n # shape constraints\n Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]\n assert Lq == Lk and Lk == Lv\n assert Lk in {16, 32, 64, 128}\n o = torch.empty_like(q)\n grid = (triton.cdiv(q.shape[2], BLOCK), q.shape[0] * q.shape[1], 1)\n L = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)\n m = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)\n num_warps = 4 if Lk <= 64 else 8\n D0 = q.shape[0] * q.shape[1] * q.shape[2]\n _fwd_kernel[grid](\n q, k, v, sm_scale, #\n L, m, #\n o, #\n q.stride(0), q.stride(1), q.stride(2), q.stride(3), #\n k.stride(0), k.stride(1), k.stride(2), k.stride(3), #\n v.stride(0), v.stride(1), v.stride(2), v.stride(3), #\n o.stride(0), o.stride(1), o.stride(2), o.stride(3), #\n q.shape[0], q.shape[1], q.shape[2], D0, #\n BLOCK_M=BLOCK, BLOCK_N=BLOCK, BLOCK_DMODEL=Lk, #\n num_warps=num_warps, num_stages=2)\n\n ctx.save_for_backward(q, k, v, o, L, m)\n ctx.grid = grid\n ctx.sm_scale = sm_scale\n ctx.BLOCK_DMODEL = Lk\n return o\n\n @staticmethod\n def backward(ctx, do):\n BLOCK = 128\n q, k, v, o, l, m = ctx.saved_tensors\n do = do.contiguous()\n dq = torch.zeros_like(q, dtype=torch.float32)\n dk = torch.empty_like(k)\n dv = torch.empty_like(v)\n do_scaled = torch.empty_like(do)\n delta = torch.empty_like(l)\n D0 = q.shape[0] * q.shape[1] * q.shape[2]\n _bwd_preprocess[(ctx.grid[0] * ctx.grid[1], )](\n o, do, l, #\n do_scaled, delta, #\n BLOCK_M=BLOCK, D_HEAD=ctx.BLOCK_DMODEL)\n _bwd_kernel[(ctx.grid[1], )](\n q, k, v, ctx.sm_scale, #\n o, do_scaled, #\n dq, dk, dv, #\n l, m, #\n delta, #\n q.stride(0), q.stride(1), q.stride(2), q.stride(3), #\n k.stride(0), k.stride(1), k.stride(2), k.stride(3), #\n v.stride(0), v.stride(1), v.stride(2), v.stride(3), #\n q.shape[0], q.shape[1], q.shape[2], D0, #\n ctx.grid[0], #\n BLOCK_M=BLOCK, BLOCK_N=BLOCK, BLOCK_DMODEL=ctx.BLOCK_DMODEL, #\n num_warps=8, num_stages=1)\n return dq, dk, dv, None\n\n\nattention = _attention.apply\n", - "description_1": "Use triton language to implement a fused attention mechanism with forward and backward kernels. The forward kernel (_fwd_kernel) takes 25 parameters: Q, K, V (query, key, value tensors), sm_scale (softmax scaling factor), L, M (intermediate tensors for storing results), Out (output tensor), various stride parameters for memory access, Z, H, N_CTX, D0 (dimensions and context size), and BLOCK_M, BLOCK_DMODEL, BLOCK_N (block sizes for computation). The backward kernel (_bwd_kernel) takes 30 parameters: Q, K, V, sm_scale, Out, DO (derivative of output), DQ, DK, DV (derivatives of Q, K, V), L, M, D (intermediate tensors), various stride parameters, Z, H, N_CTX, D0, num_block (number of blocks), and BLOCK_M, BLOCK_DMODEL, BLOCK_N (block sizes). The _bwd_preprocess function is used to preprocess the gradients before the backward pass.", - "description_2": "Use triton language to create a fused attention operator with forward and backward passes, handling query, key, value tensors, and their gradients efficiently using block-wise operations.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n\n@triton.jit\ndef matmul_no_scf_kernel(a_ptr, b_ptr, c_ptr, #\n M, N, K, #\n stride_am, stride_ak, #\n stride_bk, stride_bn, #\n stride_cm, stride_cn, #\n BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, #\n FLOAT16_OUTPUT: tl.constexpr, USE_TMA_EPILOGUE: tl.constexpr #\n ):\n a_block_ptr = tl.make_block_ptr(\n base=a_ptr,\n shape=(M, K),\n strides=(stride_am, stride_ak),\n offsets=(0, 0),\n block_shape=(BLOCK_M, BLOCK_K),\n order=(1, 0),\n )\n b_block_ptr = tl.make_block_ptr(\n base=b_ptr,\n shape=(K, N),\n strides=(stride_bk, stride_bn),\n offsets=(0, 0),\n block_shape=(BLOCK_K, BLOCK_N),\n order=(0, 1),\n )\n a = tl.load(a_block_ptr)\n b = tl.load(b_block_ptr)\n\n c = tl.dot(a, b)\n\n if FLOAT16_OUTPUT:\n c = c.to(tl.float16)\n\n if USE_TMA_EPILOGUE:\n c_block_ptr = tl.make_block_ptr(base=c_ptr, shape=(M, N), strides=(stride_cm, stride_cn), offsets=(0, 0),\n block_shape=(BLOCK_M, BLOCK_N), order=(1, 0))\n tl.store(c_block_ptr, c)\n else:\n offs_m = tl.arange(0, BLOCK_M)\n offs_n = tl.arange(0, BLOCK_N)\n c_ptrs = c_ptr + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn\n tl.store(c_ptrs, c)\n\n\ndef test_gemm_no_scf(M, N, K, NUM_CTAS, NUM_WARPS, TRANS_A, TRANS_B, OUTPUT_TYPE, USE_TMA_EPILOGUE, ENABLE_WS):\n if (TRANS_A):\n a = torch.randn((K, M), device='cuda', dtype=torch.float16).T\n else:\n a = torch.randn((M, K), device='cuda', dtype=torch.float16)\n if (TRANS_B):\n b = torch.randn((N, K), device='cuda', dtype=torch.float16).T\n else:\n b = torch.randn((K, N), device='cuda', dtype=torch.float16)\n\n if OUTPUT_TYPE == \"float16\":\n c = torch.empty((M, N), device=a.device, dtype=torch.float16)\n else:\n c = torch.empty((M, N), device=a.device, dtype=torch.float32)\n\n matmul_no_scf_kernel[(1, 1)](\n a_ptr=a, b_ptr=b, c_ptr=c, #\n M=M, N=N, K=K, #\n stride_am=a.stride(0), stride_ak=a.stride(1), #\n stride_bk=b.stride(0), stride_bn=b.stride(1), #\n stride_cm=c.stride(0), stride_cn=c.stride(1), #\n BLOCK_M=M, BLOCK_N=N, BLOCK_K=K, #\n num_warps=NUM_WARPS, #\n num_ctas=NUM_CTAS, #\n FLOAT16_OUTPUT=(OUTPUT_TYPE == \"float16\"), #\n USE_TMA_EPILOGUE=USE_TMA_EPILOGUE, #\n enable_warp_specialization=ENABLE_WS)\n a_f32 = a.to(torch.float32)\n b_f32 = b.to(torch.float32)\n golden = torch.matmul(a_f32, b_f32)\n torch.set_printoptions(profile=\"full\")\n assert torch.allclose(c, golden, rtol=1e-2, atol=1e-3, equal_nan=True)\n\n\n@triton.jit\ndef matmul_kernel(a_ptr, b_ptr, w_ptr, bias_ptr, z_ptr, #\n M, N, K, #\n stride_am, stride_ak, #\n stride_bk, stride_bn, #\n stride_wm, stride_wn, #\n stride_zm, stride_zn, #\n BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr, #\n out_dtype: tl.constexpr, USE_TMA_STORE: tl.constexpr, #\n ADD_MATRIX: tl.constexpr, ADD_ROWS: tl.constexpr, ADD_COLS: tl.constexpr, #\n DO_SOFTMAX: tl.constexpr, CHAIN_DOT: tl.constexpr, #\n A_ORDER_0: tl.constexpr, A_ORDER_1: tl.constexpr, #\n B_ORDER_0: tl.constexpr, B_ORDER_1: tl.constexpr, #\n W_ORDER_0: tl.constexpr, W_ORDER_1: tl.constexpr, #\n Z_ORDER_0: tl.constexpr, Z_ORDER_1: tl.constexpr #\n ):\n pid = tl.program_id(axis=0)\n num_pid_n = tl.cdiv(N, BLOCK_N)\n num_pid_m = tl.cdiv(M, BLOCK_M)\n num_pid_in_group = GROUP_SIZE_M * num_pid_n\n group_id = pid // num_pid_in_group\n first_pid_m = group_id * GROUP_SIZE_M\n group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)\n pid_m = first_pid_m + (pid % group_size_m)\n pid_n = (pid % num_pid_in_group) // group_size_m\n block_offset_m = pid_m * BLOCK_M\n block_offset_n = pid_n * BLOCK_N\n\n a_tile_ptr = tl.make_block_ptr(\n base=a_ptr,\n shape=(M, K),\n strides=(stride_am, stride_ak),\n offsets=(block_offset_m, 0),\n block_shape=(BLOCK_M, BLOCK_K),\n order=(A_ORDER_0, A_ORDER_1),\n )\n b_tile_ptr = tl.make_block_ptr(\n base=b_ptr,\n shape=(K, N),\n strides=(stride_bk, stride_bn),\n offsets=(0, block_offset_n),\n block_shape=(BLOCK_K, BLOCK_N),\n order=(B_ORDER_0, B_ORDER_1),\n )\n # for chain-dot, BLOCK_N must always be equal to N, and each program loads the whole W matrix\n w_tile_ptr = tl.make_block_ptr(\n base=w_ptr,\n shape=(N, N),\n strides=(stride_wm, stride_wn),\n offsets=(0, 0),\n block_shape=(BLOCK_N, BLOCK_N),\n order=(W_ORDER_0, W_ORDER_1),\n )\n z = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)\n\n offs_m = block_offset_m + tl.arange(0, BLOCK_M)\n offs_n = block_offset_n + tl.arange(0, BLOCK_N)\n z_ptrs = z_ptr + offs_m[:, None] * stride_zm + offs_n[None, :] * stride_zn\n bias_ptrs = bias_ptr + offs_m[:, None] * stride_zm + offs_n[None, :] * stride_zn\n mask = (offs_m < M)[:, None] & (offs_n < N)[None, :]\n\n for k in range(0, K, BLOCK_K):\n a = tl.load(a_tile_ptr, boundary_check=(0, 1))\n b = tl.load(b_tile_ptr, boundary_check=(0, 1))\n z += tl.dot(a, b)\n a_tile_ptr = tl.advance(a_tile_ptr, [0, BLOCK_K])\n b_tile_ptr = tl.advance(b_tile_ptr, [BLOCK_K, 0])\n\n z = z.to(out_dtype)\n\n if ADD_MATRIX:\n z += tl.load(bias_ptrs, mask=mask)\n if ADD_ROWS:\n ZRs = bias_ptr + offs_m * stride_zm\n z += tl.load(ZRs)[:, None]\n if ADD_COLS:\n ZCs = bias_ptr + offs_n * stride_zn\n z += tl.load(ZCs)[None, :]\n if DO_SOFTMAX:\n max = tl.max(z, 1)\n z = z - max[:, None]\n num = tl.exp(z.to(tl.float32)).to(max.dtype)\n den = tl.sum(num, 1)\n z = num / den[:, None]\n if CHAIN_DOT:\n w = tl.load(w_tile_ptr)\n z = tl.dot(z.to(w.dtype), w)\n z = z.to(out_dtype)\n\n if USE_TMA_STORE:\n z_block_ptr = tl.make_block_ptr(base=z_ptr, shape=(M, N), strides=(stride_zm, stride_zn),\n offsets=(block_offset_m, block_offset_n), block_shape=(BLOCK_M, BLOCK_N),\n order=(Z_ORDER_0, Z_ORDER_1))\n tl.store(z_block_ptr, z, boundary_check=(0, 1))\n else:\n tl.store(z_ptrs, z, mask=mask)\n\n\ndef test_gemm(BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B, TRANS_OUTPUT, epilogue,\n out_dtype, USE_TMA_STORE, NUM_STAGES, ENABLE_WS):\n if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B])) in [\n '16-32-64-4-4-512-256-64-True-False',\n '16-32-64-4-4-512-256-64-True-True',\n '16-32-64-4-4-512-256-64-False-False',\n '16-32-64-4-4-512-256-64-False-True',\n ]:\n return\n\n if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_B])) in [\n '16-32-64-4-1-256-256-256-False',\n '16-32-64-4-2-256-256-256-False',\n '16-32-64-4-2-256-256-256-True',\n '16-32-64-8-2-256-256-256-False',\n '16-32-64-8-2-256-256-256-True',\n ]:\n return\n enable_tma = os.environ.get('ENABLE_TMA', 'not found').lower()\n if NUM_CTAS > 1 and enable_tma in [\"on\", \"true\", \"1\"]:\n return\n\n M = BLOCK_M if M is None else M\n N = BLOCK_N if N is None else N\n K = BLOCK_K if K is None else K\n\n if (TRANS_A):\n a = torch.randn((K, M), device='cuda', dtype=torch.float16).T\n a_order = [0, 1]\n else:\n a = torch.randn((M, K), device='cuda', dtype=torch.float16)\n a_order = [1, 0]\n\n if (TRANS_B):\n b = torch.randn((N, K), device='cuda', dtype=torch.float16).T\n b_order = [0, 1]\n else:\n b = torch.randn((K, N), device='cuda', dtype=torch.float16)\n b_order = [1, 0]\n\n if out_dtype == 'float16' and epilogue != 'softmax':\n out_dtype = tl.float16\n torch_out_dtype = torch.float16\n else:\n out_dtype = tl.float32\n torch_out_dtype = torch.float32\n\n if epilogue in ['add-matrix', 'add-rows', 'add-cols']:\n if (TRANS_OUTPUT):\n bias = torch.randn((N, M), device='cuda', dtype=torch_out_dtype).T\n else:\n bias = torch.randn((M, N), device='cuda', dtype=torch_out_dtype)\n else:\n bias = torch.randn((1, 1), device='cuda', dtype=torch_out_dtype)\n\n w = torch.randn((N, N), device='cuda', dtype=torch.float16).T\n w_order = [0, 1]\n\n if (TRANS_OUTPUT):\n z = torch.full((N, M), 1., device='cuda', dtype=torch_out_dtype).T\n z_order = [0, 1]\n else:\n z = torch.full((M, N), 1., device='cuda', dtype=torch_out_dtype)\n z_order = [1, 0]\n\n a_f32 = a.to(torch.float32)\n b_f32 = b.to(torch.float32)\n dot = torch.matmul(a_f32, b_f32)\n\n def process_epilogue(d, bias, w, epilogue):\n if epilogue == 'add-matrix':\n ref = d + bias\n elif epilogue == 'add-rows':\n ref = d + bias[:, 0][:, None]\n elif epilogue == 'add-cols':\n ref = d + bias[0, :][None, :]\n elif epilogue == 'softmax':\n num = torch.exp(d - torch.max(d, dim=-1, keepdims=True)[0])\n denom = torch.sum(num, dim=-1, keepdims=True)\n ref = num / denom\n elif epilogue == 'chain-dot':\n ref = torch.matmul(d, w.to(torch.float32))\n else:\n ref = d\n return ref\n\n golden = process_epilogue(dot, bias, w, epilogue)\n\n def grid(META):\n return (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), )\n\n pgm = matmul_kernel[grid](\n a_ptr=a, b_ptr=b, w_ptr=w, bias_ptr=bias, z_ptr=z, #\n M=M, N=N, K=K, #\n stride_am=a.stride(0), stride_ak=a.stride(1), #\n stride_bk=b.stride(0), stride_bn=b.stride(1), #\n stride_wm=w.stride(0), stride_wn=w.stride(1), #\n stride_zm=z.stride(0), stride_zn=z.stride(1), #\n BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, GROUP_SIZE_M=8, #\n out_dtype=out_dtype, #\n USE_TMA_STORE=USE_TMA_STORE, #\n ADD_MATRIX=epilogue == 'add-matrix', #\n ADD_ROWS=epilogue == 'add-rows', #\n ADD_COLS=epilogue == 'add-cols', #\n DO_SOFTMAX=epilogue == 'softmax', #\n CHAIN_DOT=epilogue == 'chain-dot', #\n A_ORDER_0=a_order[0], A_ORDER_1=a_order[1], #\n B_ORDER_0=b_order[0], B_ORDER_1=b_order[1], #\n W_ORDER_0=w_order[0], W_ORDER_1=w_order[1], #\n Z_ORDER_0=z_order[0], Z_ORDER_1=z_order[1], #\n num_warps=NUM_WARPS, num_ctas=NUM_CTAS, num_stages=NUM_STAGES, #\n enable_warp_specialization=ENABLE_WS)\n\n torch.set_printoptions(profile=\"full\")\n golden = torch.nn.functional.normalize(golden)\n z = torch.nn.functional.normalize(z)\n assert torch.allclose(z, golden, rtol=1e-2, atol=1e-3, equal_nan=True)\n", - "description_1": "Use triton language to implement two matrix multiplication kernels: (1) 'matmul_no_scf_kernel' takes 18 parameters including pointers to input matrices A and B and output matrix C, their dimensions M, N, K, respective strides, block dimensions, and flags for output type and epilogue usage; performs matrix multiplication with optional post-processing steps. (2) 'matmul_kernel' handles more complex operations with additional parameters for bias, auxiliary matrix W, matrix orders, group sizes, and multiple epilogues; efficiently calculates Z as a result with adjustable stages and warp specialization. Both kernels are invoked with respective grid launch configurations and helper functions that setup required input conditions.", - "description_2": "Use triton language to define matrix multiplication kernels with support for advanced epilogues and matrix layouts, optimizing computation with configurable grid and warp settings.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef gemm_fusion_kernel(A, B, C, E, #\n M, N, K, #\n stride_am, stride_ak, stride_bn, stride_bk, stride_cn, stride_ck, stride_em, stride_ek, #\n BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr):\n pid = tl.program_id(0)\n\n a_tile_ptr = tl.make_block_ptr(base=A, shape=(M, K), strides=(stride_am, stride_ak), offsets=(pid * BLOCK_M, 0),\n block_shape=(BLOCK_M, BLOCK_K), order=(1, 0))\n b_tile_ptr = tl.make_block_ptr(base=B, shape=(N, K), strides=(stride_bn, stride_bk), offsets=(0, 0),\n block_shape=(BLOCK_N, BLOCK_K), order=(1, 0))\n c_tile_ptr = tl.make_block_ptr(base=C, shape=(N, K), strides=(stride_cn, stride_ck), offsets=(0, 0),\n block_shape=(BLOCK_N, BLOCK_K), order=(1, 0))\n e_tile_ptr = tl.make_block_ptr(base=E, shape=(M, K), strides=(stride_em, stride_ek), offsets=(pid * BLOCK_M, 0),\n block_shape=(BLOCK_M, BLOCK_K), order=(1, 0))\n\n acc_e = tl.zeros((BLOCK_M, BLOCK_K), dtype=tl.float32)\n a = tl.load(a_tile_ptr)\n for i in range(0, N, BLOCK_N):\n b = tl.load(b_tile_ptr)\n o_ab = tl.dot(a, tl.trans(b))\n c = tl.load(c_tile_ptr)\n o_ab = o_ab.to(tl.float16)\n acc_e += tl.dot(o_ab, c)\n b_tile_ptr = tl.advance(b_tile_ptr, [BLOCK_N, 0])\n c_tile_ptr = tl.advance(c_tile_ptr, [BLOCK_N, 0])\n\n acc_e = acc_e.to(tl.float16)\n tl.store(e_tile_ptr, acc_e)\n\ndef test_gemm_fusion():\n M, N, K = 4096, 4096, 64\n BLOCK_M, BLOCK_N, BLOCK_K = 128, 128, 64\n A = torch.empty((M, K), dtype=torch.float16, device='cuda').normal_(mean=0.1, std=0.2)\n B = torch.empty((N, K), dtype=torch.float16, device='cuda').normal_(mean=0.1, std=0.2)\n C = torch.empty((N, K), dtype=torch.float16, device='cuda').normal_(mean=0.1, std=0.2)\n E = torch.empty((M, K), dtype=torch.float16, device='cuda')\n ref_out = torch.matmul(torch.matmul(A, B.T), C)\n num_warps = 4\n grid = (triton.cdiv(M, BLOCK_M), 1)\n gemm_fusion_kernel[grid](\n A, B, C, E, M, N, K, #\n A.stride(0), A.stride(1), #\n B.stride(0), B.stride(1), #\n C.stride(0), C.stride(1), #\n E.stride(0), E.stride(1), #\n BLOCK_M, BLOCK_N, BLOCK_K, #\n num_warps=num_warps)\n\n torch.testing.assert_close(ref_out, E, atol=1e-2, rtol=0)\n\n@triton.jit\ndef batched_gemm_fusion(Q, K, V, Out, #\n stride_qz, stride_qh, stride_qm, stride_qk, #\n stride_kz, stride_kh, stride_kn, stride_kk, #\n stride_vz, stride_vh, stride_vk, stride_vn, #\n stride_oz, stride_oh, stride_om, stride_on, #\n Z, NH, N_CTX, #\n BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, #\n BLOCK_N: tl.constexpr):\n start_m = tl.program_id(0)\n off_hz = tl.program_id(1)\n acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)\n\n q_tile_ptr = tl.make_block_ptr(\n base=Q,\n shape=(Z, NH, N_CTX, BLOCK_DMODEL),\n strides=(stride_qz, stride_qh, stride_qm, stride_qk),\n offsets=(off_hz // NH, off_hz % NH, start_m, 0),\n block_shape=(1, 1, BLOCK_M, BLOCK_DMODEL),\n order=(3, 2, 1, 0),\n )\n k_tile_ptr = tl.make_block_ptr(\n base=K,\n shape=(Z, NH, N_CTX, BLOCK_DMODEL),\n strides=(stride_kz, stride_kh, stride_kn, stride_kk),\n offsets=(off_hz // NH, off_hz % NH, 0, 0),\n block_shape=(1, 1, BLOCK_N, BLOCK_DMODEL),\n order=(3, 2, 1, 0),\n )\n v_tile_ptr = tl.make_block_ptr(\n base=V,\n shape=(Z, NH, N_CTX, BLOCK_DMODEL),\n strides=(stride_vz, stride_vh, stride_vk, stride_vn),\n offsets=(off_hz // NH, off_hz % NH, 0, 0),\n block_shape=(1, 1, BLOCK_N, BLOCK_DMODEL),\n order=(3, 2, 1, 0),\n )\n o_tile_ptr = tl.make_block_ptr(\n base=Out,\n shape=(Z, NH, N_CTX, BLOCK_DMODEL),\n strides=(stride_oz, stride_oh, stride_om, stride_on),\n offsets=(off_hz // NH, off_hz % NH, start_m, 0),\n block_shape=(1, 1, BLOCK_M, BLOCK_DMODEL),\n order=(3, 2, 1, 0),\n )\n\n q = tl.load(q_tile_ptr, boundary_check=(0, 1, 2, 3))\n q = tl.view(q, (BLOCK_M, BLOCK_DMODEL))\n for i in range(0, N_CTX, BLOCK_N):\n k = tl.load(k_tile_ptr, boundary_check=(0, 1, 2, 3))\n k = tl.view(k, (BLOCK_N, BLOCK_DMODEL))\n qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)\n qk += tl.dot(q, tl.trans(k))\n\n p = qk.to(tl.float16)\n v = tl.load(v_tile_ptr, boundary_check=(0, 1, 2, 3))\n v = tl.view(v, (BLOCK_N, BLOCK_DMODEL))\n acc += tl.dot(p, v)\n\n k_tile_ptr = tl.advance(k_tile_ptr, [0, 0, BLOCK_N, 0])\n v_tile_ptr = tl.advance(v_tile_ptr, [0, 0, BLOCK_N, 0])\n\n acc = tl.view(acc, (1, 1, BLOCK_M, BLOCK_DMODEL))\n acc = acc.to(tl.float16)\n tl.store(o_tile_ptr, acc)\n\ndef test_batched_gemm_fusion():\n Z = 4\n NH = 48\n H = 64\n N_CTX = 2048\n BLOCK_M, BLOCK_N, BLOCK_DMODEL = 128, 128, H\n torch.manual_seed(20)\n A = torch.empty((Z, NH, N_CTX, H), dtype=torch.float16, device='cuda').normal_(mean=0.1, std=0.2)\n B = torch.empty((Z, NH, N_CTX, H), dtype=torch.float16, device='cuda').normal_(mean=0.1, std=0.2)\n C = torch.empty((Z, NH, N_CTX, H), dtype=torch.float16, device='cuda').normal_(mean=0.1, std=0.2)\n E = torch.empty_like(A)\n BT = B.transpose(-1, -2)\n ref_out = torch.matmul(torch.matmul(A, BT), C)\n num_warps = 4\n grid = (triton.cdiv(N_CTX, BLOCK_M), B * NH)\n batched_gemm_fusion[grid](\n A, B, C, E, #\n A.stride(0), A.stride(1), A.stride(2), A.stride(3), #\n B.stride(0), B.stride(1), B.stride(2), B.stride(3), #\n C.stride(0), C.stride(1), C.stride(2), C.stride(3), #\n E.stride(0), E.stride(1), E.stride(2), E.stride(3), #\n Z, NH, N_CTX, #\n BLOCK_M, BLOCK_DMODEL, BLOCK_N, num_warps=num_warps)\n\n torch.testing.assert_close(ref_out, E, atol=1e-2, rtol=0)\n", - "description_1": "Use triton language to implement two kernels: 'gemm_fusion_kernel' and 'batched_gemm_fusion'. The 'gemm_fusion_kernel' takes 17 parameters: A, B, C, E (input matrices), M, N, K (dimensions), stride_am, stride_ak, stride_bn, stride_bk, stride_cn, stride_ck, stride_em, stride_ek (strides), and BLOCK_M, BLOCK_N, BLOCK_K (block sizes). It performs a fused matrix multiplication and accumulation operation. The 'batched_gemm_fusion' kernel takes 22 parameters: Q, K, V, Out (input matrices), stride_qz, stride_qh, stride_qm, stride_qk, stride_kz, stride_kh, stride_kn, stride_kk, stride_vz, stride_vh, stride_vk, stride_vn, stride_oz, stride_oh, stride_om, stride_on (strides), Z, NH, N_CTX (dimensions), and BLOCK_M, BLOCK_DMODEL, BLOCK_N (block sizes). It performs a batched matrix multiplication and accumulation operation.", - "description_2": "Use triton language to implement two kernels for matrix operations: one for fused matrix multiplication and accumulation, and another for batched matrix multiplication and accumulation, each with specific input matrices, dimensions, strides, and block sizes.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\nfrom torch.testing import assert_close\n\n# Kernel to add two vectors\n@triton.jit\ndef add_kernel(\n x_ptr,\n y_ptr,\n output_ptr,\n n_elements,\n BLOCK_SIZE: tl.constexpr,\n):\n pid = tl.program_id(axis=0) # We use a 1D launch grid so axis is 0.\n block_start = pid * BLOCK_SIZE\n offsets = block_start + tl.arange(0, BLOCK_SIZE)\n mask = offsets < n_elements\n\n x_block_ptr = tl.make_block_ptr(base=x_ptr, shape=(n_elements, ), strides=(1, ), offsets=(pid * BLOCK_SIZE, ),\n block_shape=(BLOCK_SIZE, ), order=(0, ))\n x = tl.load(x_block_ptr, boundary_check=(0, ), padding_option='zero')\n\n y = tl.load(y_ptr + offsets, mask=mask)\n output = x + y\n tl.store(output_ptr + offsets, output, mask=mask)\n\n# Test function for add_kernel\ndef test_add(SIZE, BLOCK_SIZE, dtype_str):\n dtype_mapping = {\n 'float16': torch.float16,\n 'float32': torch.float32,\n }\n dtype = dtype_mapping[dtype_str]\n output = torch.empty(SIZE, device='cuda', dtype=dtype)\n x = torch.randn(SIZE, device='cuda', dtype=dtype)\n y = torch.randn(SIZE, device='cuda', dtype=dtype)\n\n def grid(meta):\n return (triton.cdiv(SIZE, meta['BLOCK_SIZE']), )\n\n add_kernel[grid](x, y, output, SIZE, BLOCK_SIZE=BLOCK_SIZE)\n\n output_torch = x + y\n torch.set_printoptions(profile='full')\n assert_close(output, output_torch, rtol=1e-2, atol=1e-3, check_dtype=False)\n\n# Kernel to load and reduce a matrix\n@triton.jit\ndef load_reduce_kernel(\n x_ptr,\n y_ptr,\n stride_xm,\n stride_xn,\n stride_y,\n BLOCK_M: tl.constexpr,\n BLOCK_N: tl.constexpr,\n):\n x_ptr = tl.make_block_ptr(base=x_ptr, shape=(BLOCK_M, BLOCK_N), strides=(stride_xm, stride_xn), offsets=(0, 0),\n block_shape=(BLOCK_M, BLOCK_N), order=(1, 0))\n x = tl.load(x_ptr)\n y = tl.max(x, axis=1)\n tl.store(y_ptr + tl.arange(0, BLOCK_M), y)\n\n# Test function for load_reduce_kernel\ndef test_load_reduce(BLOCK_M, BLOCK_N, dtype_str):\n dtype_mapping = {\n 'float16': torch.float16,\n }\n dtype = dtype_mapping[dtype_str]\n x = torch.randn((BLOCK_M, BLOCK_N), device='cuda', dtype=dtype)\n y = torch.empty((BLOCK_M, ), device='cuda', dtype=dtype)\n\n load_reduce_kernel[(1, )](x, y, x.stride(0), x.stride(1), y.stride(0), BLOCK_M, BLOCK_N)\n\n golden = x.max(dim=1)[0]\n torch.set_printoptions(profile='full')\n assert_close(y, golden, rtol=1e-2, atol=1e-3, check_dtype=False)\n", - "description_1": "Use triton language to implement two kernels: one for element-wise addition of two vectors and another for loading a matrix and reducing it along the rows. The add_kernel takes five parameters: pointers to input vectors x and y, a pointer to the output vector, the number of elements, and a block size. The load_reduce_kernel takes seven parameters: pointers to input matrix x and output vector y, strides for x and y, and block sizes for the matrix dimensions.", - "description_2": "Use triton language to create a vector addition kernel and a matrix row reduction kernel, each with specific parameters for data pointers, strides, and block sizes.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n\n@triton.jit\ndef _fwd_kernel(Q, K, V, sm_scale, #\n L, M, #\n Out, #\n stride_qz, stride_qh, stride_qm, stride_qk, #\n stride_kz, stride_kh, stride_kn, stride_kk, #\n stride_vz, stride_vh, stride_vk, stride_vn, #\n stride_oz, stride_oh, stride_om, stride_on, #\n Z, H, N_CTX, #\n BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, #\n BLOCK_N: tl.constexpr #\n ):\n start_m = tl.program_id(0)\n off_hz = tl.program_id(1)\n offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)\n offs_n = tl.arange(0, BLOCK_N)\n offs_d = tl.arange(0, BLOCK_DMODEL)\n off_q = off_hz * stride_qh + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qk\n off_k = off_hz * stride_qh + offs_n[None, :] * stride_kn + offs_d[:, None] * stride_kk\n off_v = off_hz * stride_qh + offs_n[:, None] * stride_qm + offs_d[None, :] * stride_qk\n q_ptrs = Q + off_q\n k_ptrs = K + off_k\n v_ptrs = V + off_v\n m_prev = tl.zeros([BLOCK_M], dtype=tl.float32) - float(\"inf\")\n l_prev = tl.zeros([BLOCK_M], dtype=tl.float32)\n acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)\n q = tl.load(q_ptrs)\n for start_n in range(0, (start_m + 1) * BLOCK_M, BLOCK_N):\n k = tl.load(k_ptrs)\n qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)\n qk += tl.dot(q, k)\n qk *= sm_scale\n qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float(\"-inf\"))\n m_curr = tl.maximum(tl.max(qk, 1), m_prev)\n l_prev *= tl.exp(m_prev - m_curr)\n p = tl.exp(qk - m_curr[:, None])\n l_curr = tl.sum(p, 1) + l_prev\n l_rcp = 1. / l_curr\n p *= l_rcp[:, None]\n acc *= (l_prev * l_rcp)[:, None]\n p = p.to(Q.dtype.element_ty)\n v = tl.load(v_ptrs)\n acc += tl.dot(p, v)\n l_prev = l_curr\n m_prev = m_curr\n k_ptrs += BLOCK_N * stride_kn\n v_ptrs += BLOCK_N * stride_vk\n start_m = tl.program_id(0)\n offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)\n l_ptrs = L + off_hz * N_CTX + offs_m\n m_ptrs = M + off_hz * N_CTX + offs_m\n tl.store(l_ptrs, l_prev)\n tl.store(m_ptrs, m_prev)\n offs_n = tl.arange(0, BLOCK_DMODEL)\n off_o = off_hz * stride_oh + offs_m[:, None] * stride_om + offs_n[None, :]\n out_ptrs = Out + off_o\n tl.store(out_ptrs, acc)\n\n\n@triton.jit\ndef _bwd_preprocess(Out, DO, L, #\n NewDO, Delta, #\n BLOCK_M: tl.constexpr, D_HEAD: tl.constexpr #\n ):\n off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)\n off_n = tl.arange(0, D_HEAD)\n o = tl.load(Out + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32)\n do = tl.load(DO + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32)\n denom = tl.load(L + off_m).to(tl.float32)\n do = do / denom[:, None]\n delta = tl.sum(o * do, axis=1)\n tl.store(NewDO + off_m[:, None] * D_HEAD + off_n[None, :], do)\n tl.store(Delta + off_m, delta)\n\n\n@triton.jit\ndef _bwd_kernel(Q, K, V, sm_scale, Out, DO, #\n DQ, DK, DV, #\n L, M, #\n D, #\n stride_qz, stride_qh, stride_qm, stride_qk, #\n stride_kz, stride_kh, stride_kn, stride_kk, #\n stride_vz, stride_vh, stride_vk, stride_vn, #\n Z, H, N_CTX, #\n num_block, #\n BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, #\n BLOCK_N: tl.constexpr, #\n ):\n off_hz = tl.program_id(0)\n off_z = off_hz // H\n off_h = off_hz % H\n Q += off_z * stride_qz + off_h * stride_qh\n K += off_z * stride_qz + off_h * stride_qh\n V += off_z * stride_qz + off_h * stride_qh\n DO += off_z * stride_qz + off_h * stride_qh\n DQ += off_z * stride_qz + off_h * stride_qh\n DK += off_z * stride_qz + off_h * stride_qh\n DV += off_z * stride_qz + off_h * stride_qh\n for start_n in range(0, num_block):\n lo = start_n * BLOCK_M\n offs_qm = lo + tl.arange(0, BLOCK_M)\n offs_n = start_n * BLOCK_M + tl.arange(0, BLOCK_M)\n offs_m = tl.arange(0, BLOCK_N)\n offs_k = tl.arange(0, BLOCK_DMODEL)\n q_ptrs = Q + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk)\n k_ptrs = K + (offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk)\n v_ptrs = V + (offs_n[:, None] * stride_qm + offs_k[None, :] * stride_qk)\n do_ptrs = DO + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk)\n dq_ptrs = DQ + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk)\n D_ptrs = D + off_hz * N_CTX\n m_ptrs = M + off_hz * N_CTX\n dv = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)\n dk = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)\n k = tl.load(k_ptrs)\n v = tl.load(v_ptrs)\n for start_m in range(lo, num_block * BLOCK_M, BLOCK_M):\n offs_m_curr = start_m + offs_m\n q = tl.load(q_ptrs)\n qk = tl.dot(q, tl.trans(k))\n qk = tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), qk, float(\"-inf\"))\n m = tl.load(m_ptrs + offs_m_curr)\n p = tl.exp(qk * sm_scale - m[:, None])\n do = tl.load(do_ptrs)\n dv += tl.dot(tl.trans(p.to(Q.dtype.element_ty)), do)\n Di = tl.load(D_ptrs + offs_m_curr)\n dp = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - Di[:, None]\n dp += tl.dot(do, tl.trans(v))\n ds = p * dp * sm_scale\n dk += tl.dot(tl.trans(ds.to(Q.dtype.element_ty)), q)\n dq = tl.load(dq_ptrs)\n dq += tl.dot(ds.to(Q.dtype.element_ty), k)\n tl.store(dq_ptrs, dq)\n dq_ptrs += BLOCK_M * stride_qm\n q_ptrs += BLOCK_M * stride_qm\n do_ptrs += BLOCK_M * stride_qm\n dv_ptrs = DV + (offs_n[:, None] * stride_qm + offs_k[None, :] * stride_qk)\n dk_ptrs = DK + (offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk)\n tl.store(dv_ptrs, dv)\n tl.store(dk_ptrs, dk)\n\n\nclass _attention(torch.autograd.Function):\n\n @staticmethod\n def forward(ctx, q, k, v, sm_scale):\n BLOCK = 128\n Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]\n assert Lq == Lk and Lk == Lv\n assert Lk in {16, 32, 64, 128}\n o = torch.empty_like(q)\n grid = (triton.cdiv(q.shape[2], BLOCK), q.shape[0] * q.shape[1], 1)\n L = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)\n m = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)\n num_warps = 4 if Lk <= 64 else 8\n assert num_warps == 4\n\n _fwd_kernel[grid](\n q, k, v, sm_scale, #\n L, m, #\n o, #\n q.stride(0), q.stride(1), q.stride(2), q.stride(3), #\n k.stride(0), k.stride(1), k.stride(2), k.stride(3), #\n v.stride(0), v.stride(1), v.stride(2), v.stride(3), #\n o.stride(0), o.stride(1), o.stride(2), o.stride(3), #\n q.shape[0], q.shape[1], q.shape[2], #\n BLOCK_M=BLOCK, BLOCK_N=BLOCK, #\n BLOCK_DMODEL=Lk #\n )\n\n ctx.save_for_backward(q, k, v, o, L, m)\n ctx.grid = grid\n ctx.sm_scale = sm_scale\n ctx.BLOCK_DMODEL = Lk\n return o\n\n @staticmethod\n def backward(ctx, do):\n BLOCK = 128\n q, k, v, o, l, m = ctx.saved_tensors\n do = do.contiguous()\n dq = torch.zeros_like(q, dtype=torch.float32)\n dk = torch.empty_like(k)\n dv = torch.empty_like(v)\n do_scaled = torch.empty_like(do)\n delta = torch.empty_like(l)\n _bwd_preprocess[(ctx.grid[0] * ctx.grid[1], )](\n o, do, l, #\n do_scaled, delta, #\n BLOCK_M=BLOCK, D_HEAD=ctx.BLOCK_DMODEL)\n _bwd_kernel[(ctx.grid[1], )](\n q, k, v, ctx.sm_scale, #\n o, do_scaled, #\n dq, dk, dv, #\n l, m, #\n delta, #\n q.stride(0), q.stride(1), q.stride(2), q.stride(3), #\n k.stride(0), k.stride(1), k.stride(2), k.stride(3), #\n v.stride(0), v.stride(1), v.stride(2), v.stride(3), #\n q.shape[0], q.shape[1], q.shape[2], #\n ctx.grid[0], #\n BLOCK_M=BLOCK, BLOCK_N=BLOCK, #\n BLOCK_DMODEL=ctx.BLOCK_DMODEL, #\n num_warps=8, num_stages=1 #\n )\n return dq, dk, dv, None\n\n\nattention = _attention.apply\n", - "description_1": "Use triton language to define three kernels and an attention function: '_fwd_kernel' with 27 parameters for forward pass computation using blocks of dimensions, '_bwd_preprocess' with 5 parameters to preprocess data for the backward pass, and '_bwd_kernel' with 32 parameters for backward pass involving gradient calculations. The '_attention' function implements the autograd function with forward and backward methods utilizing these kernels for attention mechanism computations.", - "description_2": "Use triton language to implement a fused attention mechanism with forward and backward kernels for efficient GPU computation, ensuring proper input and output tensor handling and stride management.", - "difficulty": 5 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n\n@triton.jit\ndef static_persistent_matmul_kernel( #\n a_ptr, b_ptr, c_ptr, #\n M, N, K, #\n stride_am, stride_ak, #\n stride_bk, stride_bn, #\n stride_cm, stride_cn, #\n BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, #\n NUM_SM: tl.constexpr #\n):\n start_tile = tl.program_id(axis=0)\n m_tiles = tl.cdiv(M, BLOCK_M)\n n_tiles = tl.cdiv(N, BLOCK_N)\n num_tiles = m_tiles * n_tiles\n offs_k = tl.arange(0, BLOCK_K)\n\n for tile_id in range(start_tile, num_tiles, NUM_SM):\n pid_m = tile_id // n_tiles\n pid_n = tile_id % n_tiles\n accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)\n offs_am = (pid_m * BLOCK_M + tl.arange(0, BLOCK_M)) % M\n offs_bn = (pid_n * BLOCK_N + tl.arange(0, BLOCK_N)) % N\n a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)\n b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)\n\n for k in range(0, K, BLOCK_K):\n a = tl.load(a_ptrs)\n b = tl.load(b_ptrs)\n accumulator += tl.dot(a, b)\n a_ptrs += BLOCK_K * stride_ak\n b_ptrs += BLOCK_K * stride_bk\n\n offs_cm = tl.arange(0, BLOCK_M) + pid_m * BLOCK_M\n offs_cn = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N\n\n c_ptrs = c_ptr + offs_cm[:, None] * stride_cm + offs_cn[None, :] * stride_cn\n tl.store(c_ptrs, accumulator)\n\n\n@triton.jit\ndef static_persistent_tma_matmul_kernel( #\n a_ptr, b_ptr, c_ptr, #\n M, N, K, #\n stride_am, stride_ak, #\n stride_bk, stride_bn, #\n stride_cm, stride_cn, #\n BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, #\n NUM_SM: tl.constexpr #\n):\n start_tile = tl.program_id(axis=0)\n m_tiles = tl.cdiv(M, BLOCK_M)\n n_tiles = tl.cdiv(N, BLOCK_N)\n k_tiles = tl.cdiv(K, BLOCK_K)\n num_tiles = m_tiles * n_tiles\n\n pre_pid_m = start_tile // n_tiles\n pre_pid_n = start_tile % n_tiles\n\n block_offset_m = pre_pid_m * BLOCK_M\n block_offset_n = pre_pid_n * BLOCK_N\n a_tile_ptr = tl.make_block_ptr(base=a_ptr, shape=(M, K), strides=(stride_am, stride_ak),\n offsets=(block_offset_m, 0), block_shape=(BLOCK_M, BLOCK_K), order=(1, 0))\n b_tile_ptr = tl.make_block_ptr(base=b_ptr, shape=(K, N), strides=(stride_bk, stride_bn),\n offsets=(0, block_offset_n), block_shape=(BLOCK_K, BLOCK_N), order=(0, 1))\n for tile_id in range(start_tile, num_tiles, NUM_SM):\n pid_m = tile_id // n_tiles\n pid_n = tile_id % n_tiles\n accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)\n if tile_id >= NUM_SM:\n a_tile_ptr = tl.advance(a_tile_ptr, [(pid_m - pre_pid_m) * BLOCK_M, -k_tiles * BLOCK_K])\n b_tile_ptr = tl.advance(b_tile_ptr, [-k_tiles * BLOCK_K, (pid_n - pre_pid_n) * BLOCK_N])\n\n for k in range(0, K, BLOCK_K):\n a = tl.load(a_tile_ptr)\n b = tl.load(b_tile_ptr)\n accumulator += tl.dot(a, b)\n a_tile_ptr = tl.advance(a_tile_ptr, [0, BLOCK_K])\n b_tile_ptr = tl.advance(b_tile_ptr, [BLOCK_K, 0])\n\n offs_m = tl.arange(0, BLOCK_M) + pid_m * BLOCK_M\n offs_n = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N\n\n c_ptrs = c_ptr + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn\n tl.store(c_ptrs, accumulator)\n pre_pid_m = pid_m\n pre_pid_n = pid_n\n\n\ndef test_user_defined_persistent_non_warp_specialized_gemm(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS,\n TRANS_A, TRANS_B, USE_TMA):\n if (TRANS_A):\n a = .1 * torch.randn((K, M), device='cuda', dtype=torch.float16).T\n else:\n a = .1 * torch.randn((M, K), device='cuda', dtype=torch.float16)\n\n if (TRANS_B):\n b = .1 * torch.randn((N, K), device='cuda', dtype=torch.float16).T\n else:\n b = .1 * torch.randn((K, N), device='cuda', dtype=torch.float16)\n c = torch.empty((M, N), device=a.device, dtype=torch.float32)\n\n num_SMs = torch.cuda.get_device_properties('cuda').multi_processor_count\n grid = lambda META: (min(META['NUM_SM'], triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N'])), )\n\n if USE_TMA:\n static_persistent_tma_matmul_kernel[grid](a_ptr=a, b_ptr=b, c_ptr=c, M=M, N=N, K=K, stride_am=a.stride(0),\n stride_ak=a.stride(1), stride_bk=b.stride(0), stride_bn=b.stride(1),\n stride_cm=c.stride(0), stride_cn=c.stride(1), BLOCK_M=BLOCK_M,\n BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, NUM_SM=num_SMs, num_warps=NUM_WARPS,\n num_ctas=NUM_CTAS)\n else:\n static_persistent_matmul_kernel[grid](a_ptr=a, b_ptr=b, c_ptr=c, M=M, N=N, K=K, stride_am=a.stride(0),\n stride_ak=a.stride(1), stride_bk=b.stride(0), stride_bn=b.stride(1),\n stride_cm=c.stride(0), stride_cn=c.stride(1), BLOCK_M=BLOCK_M,\n BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, NUM_SM=num_SMs, num_warps=NUM_WARPS,\n num_ctas=NUM_CTAS)\n\n th_c = torch.matmul(a, b)\n torch.testing.assert_close(th_c, c, atol=1e-2, rtol=0, check_dtype=False)\n", - "description_1": "Use triton language to implement two matrix multiplication kernels, static_persistent_matmul_kernel and static_persistent_tma_matmul_kernel, which take 18 parameters including pointers to input matrices, dimensions M, N, K, stride values, block sizes, and a constant NUM_SM for GPU hardware-specific value. The kernels perform matrix multiplication using tiling to optimize for GPU execution. The function test_user_defined_persistent_non_warp_specialized_gemm calls these kernels based on a USE_TMA flag, allocating input and output matrices on CUDA, calculating strides, and validating results against PyTorch's matmul.", - "description_2": "Use triton language to perform tiled matrix multiplication on GPU using triton.jit decorated kernels, handling different strides and block sizes with support for hardware-specific execution paths based on configuration flags.", - "difficulty": 5 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\nfrom torch.testing import assert_close\n\n# Triton kernel for matrix multiplication using TMA load/store\n@triton.jit\ndef matmul_tma_load_store(\n a_ptr, b_ptr, c_ptr,\n M, N, K,\n stride_am, stride_ak,\n stride_bk, stride_bn,\n stride_cm, stride_cn,\n BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,\n OUTPUT_F16: tl.constexpr\n):\n a_block_ptr = tl.make_block_ptr(base=a_ptr, shape=(M, K), strides=(stride_am, stride_ak), offsets=(0, 0),\n block_shape=(BLOCK_M, BLOCK_K), order=(1, 0))\n b_block_ptr = tl.make_block_ptr(base=b_ptr, shape=(K, N), strides=(stride_bk, stride_bn), offsets=(0, 0),\n block_shape=(BLOCK_K, BLOCK_N), order=(0, 1))\n c_block_ptr = tl.make_block_ptr(base=c_ptr, shape=(M, N), strides=(stride_cm, stride_cn), offsets=(0, 0),\n block_shape=(BLOCK_M, BLOCK_N), order=(1, 0))\n a = tl.load(a_block_ptr)\n b = tl.load(b_block_ptr)\n\n c = tl.dot(a, b)\n if OUTPUT_F16:\n c = c.to(tl.float16)\n\n tl.store(c_block_ptr, c)\n\n\n# Function to test the Triton kernel\ndef test_tma_load_store(M, N, K, NUM_CTAS, NUM_WARPS, TRANS_A, TRANS_B, OUTPUT_F16):\n if TRANS_A:\n a = torch.randn((K, M), device='cuda', dtype=torch.float16).T\n else:\n a = torch.randn((M, K), device='cuda', dtype=torch.float16)\n if TRANS_B:\n b = torch.randn((N, K), device='cuda', dtype=torch.float16).T\n else:\n b = torch.randn((K, N), device='cuda', dtype=torch.float16)\n\n c = torch.empty((M, N), device=a.device, dtype=torch.float32)\n if OUTPUT_F16:\n c = torch.empty((M, N), device=a.device, dtype=torch.float16)\n\n matmul_tma_load_store[(1, 1)](\n a_ptr=a, b_ptr=b, c_ptr=c,\n M=M, N=N, K=K,\n stride_am=a.stride(0), stride_ak=a.stride(1),\n stride_bk=b.stride(0), stride_bn=b.stride(1),\n stride_cm=c.stride(0), stride_cn=c.stride(1),\n BLOCK_M=M, BLOCK_N=N, BLOCK_K=K,\n num_warps=NUM_WARPS, num_ctas=NUM_CTAS,\n OUTPUT_F16=OUTPUT_F16)\n\n golden = torch.matmul(a, b)\n torch.set_printoptions(profile=\"full\")\n assert_close(c, golden, rtol=1e-2, atol=1e-3, check_dtype=False)\n", - "description_1": "Use triton language to create a kernel `matmul_tma_load_store` for matrix multiplication. The kernel requires 14 regular parameters: a_ptr, b_ptr, c_ptr (pointers to matrices), M, N, K (dimensions), stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn (stride information); and 4 constexpr parameters: BLOCK_M, BLOCK_N, BLOCK_K (block sizes), OUTPUT_F16 (output precision control). The kernel uses block pointers and performs a matrix multiplication using `tl.dot`. If OUTPUT_F16 is true, it converts the result to float16 before storing. A testing function `test_tma_load_store` is provided to validate the kernel with varying dimensions and configurations using PyTorch for comparison.", - "description_2": "Use triton language to implement a matrix multiplication kernel `matmul_tma_load_store` with block pointer loading/storing and optional float16 output. A test function validates its correctness against PyTorch's matmul.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\nfrom torch.testing import assert_close\n\n@triton.jit\ndef kernel_device_assert(X, Y, BLOCK: tl.constexpr):\n x = tl.load(X + tl.arange(0, BLOCK))\n tl.device_assert(x == 0, \"x != 0\")\n tl.store(Y + tl.arange(0, BLOCK), x)\n\n@triton.jit\ndef kernel_assert_passes(X, Y, BLOCK: tl.constexpr):\n x = tl.load(X + tl.arange(0, BLOCK))\n tl.device_assert(0 == 0, \"x != 0\")\n tl.store(Y + tl.arange(0, BLOCK), x)\n\n@triton.jit(debug=False)\ndef kernel_device_assert_no_debug(X, Y, BLOCK: tl.constexpr):\n x = tl.load(X + tl.arange(0, BLOCK))\n tl.device_assert(x == 0, \"x != 0\")\n tl.store(Y + tl.arange(0, BLOCK), x)\n\n@triton.jit\ndef kernel_assert(X, Y, BLOCK: tl.constexpr):\n x = tl.load(X + tl.arange(0, BLOCK))\n assert x == 0, \"x != 0\"\n tl.store(Y + tl.arange(0, BLOCK), x)\n\n@triton.jit\ndef kernel_static_assert(X, Y, BLOCK: tl.constexpr):\n x = tl.load(X + tl.arange(0, BLOCK))\n tl.static_assert(BLOCK == 128, \"BLOCK != 128\")\n tl.store(Y + tl.arange(0, BLOCK), x)\n\ndef test_assert(func: str):\n shape = (128, )\n x = torch.arange(0, shape[0], dtype=torch.int32, device='cuda')\n y = torch.zeros(shape, dtype=x.dtype, device=\"cuda\")\n if func == \"device_assert\":\n kernel_device_assert[(1, )](x, y, BLOCK=shape[0])\n if func == \"device_assert_passes\":\n kernel_assert_passes[(1, )](x, y, BLOCK=shape[0])\n elif func == \"no_debug\":\n kernel_device_assert_no_debug[(1, )](x, y, BLOCK=shape[0])\n elif func == \"assert\":\n kernel_assert[(1, )](x, y, BLOCK=shape[0])\n elif func == \"static_assert\":\n kernel_static_assert[(1, )](x, y, BLOCK=shape[0])\n elif func == \"double_assert\":\n kernel_device_assert[(1, )](x, y, BLOCK=shape[0])\n kernel_assert_passes[(1, )](x, y, BLOCK=shape[0])\n assert_close(y, x)\n\n@triton.jit\ndef jit_device_assert_none(x):\n tl.device_assert(x == 0, \"x != 0\")\n\n@triton.jit(debug=True)\ndef jit_device_assert_true(x):\n tl.device_assert(x == 0, \"x != 0\")\n\n@triton.jit(debug=False)\ndef jit_device_assert_false(x):\n tl.device_assert(x == 0, \"x != 0\")\n\n@triton.jit\ndef kernel_device_assert_nested(X, Y, BLOCK: tl.constexpr, jit_debug: tl.constexpr):\n x = tl.load(X + tl.arange(0, BLOCK))\n if jit_debug == \"true\":\n jit_device_assert_true(x)\n elif jit_debug == \"false\":\n jit_device_assert_false(x)\n else:\n jit_device_assert_none(x)\n tl.store(Y + tl.arange(0, BLOCK), x)\n\n@triton.jit(debug=True)\ndef kernel_device_assert_nested_true(X, Y, BLOCK: tl.constexpr, jit_debug: tl.constexpr):\n x = tl.load(X + tl.arange(0, BLOCK))\n if jit_debug == \"true\":\n jit_device_assert_true(x)\n elif jit_debug == \"false\":\n jit_device_assert_false(x)\n else:\n jit_device_assert_none(x)\n tl.store(Y + tl.arange(0, BLOCK), x)\n\n@triton.jit(debug=False)\ndef kernel_device_assert_nested_false(X, Y, BLOCK: tl.constexpr, jit_debug: tl.constexpr):\n x = tl.load(X + tl.arange(0, BLOCK))\n if jit_debug == \"true\":\n jit_device_assert_true(x)\n elif jit_debug == \"false\":\n jit_device_assert_false(x)\n else:\n jit_device_assert_none(x)\n tl.store(Y + tl.arange(0, BLOCK), x)\n\ndef test_assert_nested(caller: str, callee: str):\n shape = (128, )\n x = torch.arange(0, shape[0], dtype=torch.int32, device='cuda')\n y = torch.zeros(shape, dtype=x.dtype, device=\"cuda\")\n if caller == \"none\":\n kernel_device_assert_nested[(1, )](x, y, BLOCK=shape[0], jit_debug=callee)\n elif caller == \"true\":\n kernel_device_assert_nested_true[(1, )](x, y, BLOCK=shape[0], jit_debug=callee)\n elif caller == \"false\":\n kernel_device_assert_nested_false[(1, )](x, y, BLOCK=shape[0], jit_debug=callee)\n assert_close(y, x)\n", - "description_1": "Use triton language to define multiple kernels that perform device assertions and store results. Each kernel takes three parameters: X (input tensor), Y (output tensor), and BLOCK (block size). The kernels perform various assertions on the input data and store the results in the output tensor. Additionally, there are nested kernels that call other kernels based on a debug flag.", - "description_2": "Use triton language to create kernels for device assertions with input and output tensors, and handle nested kernel calls based on debug flags.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\nimport uuid\nfrom torch.testing import assert_close\n\n@triton.jit\ndef kernel_device_print(X, Y, BLOCK: tl.constexpr):\n x = tl.load(X + tl.arange(0, BLOCK))\n tl.device_print(\"x: \", x)\n tl.store(Y + tl.arange(0, BLOCK), x)\n\n@triton.jit\ndef kernel_print(X, Y, BLOCK: tl.constexpr):\n x = tl.load(X + tl.arange(0, BLOCK))\n # Triton should add a space after this prefix.\n print(\"x:\", x)\n tl.store(Y + tl.arange(0, BLOCK), x)\n\n@triton.jit\ndef kernel_device_print_large(\n BLOCK_M: tl.constexpr,\n BLOCK_N: tl.constexpr,\n):\n x = tl.full([BLOCK_M, BLOCK_N], 1, tl.int32)\n # Triton should change this prefix to \"x: \".\n tl.device_print(\"x \", x)\n\n@triton.jit\ndef kernel_print_multiple_args(X, Y, BLOCK: tl.constexpr):\n x = tl.load(X + tl.arange(0, BLOCK))\n y = tl.full((BLOCK, ), 1, tl.int32)\n print(\"\", x, y)\n\n@triton.jit\ndef kernel_device_print_multiple_args(X, Y, BLOCK: tl.constexpr):\n x = tl.load(X + tl.arange(0, BLOCK))\n y = tl.full((BLOCK, ), 1, tl.int32)\n tl.device_print(\"\", x, y)\n tl.store(Y + tl.arange(0, BLOCK), y)\n\n@triton.jit\ndef kernel_static_print(X, Y, BLOCK: tl.constexpr, PLACEHOLDER: tl.constexpr):\n # This function takes an extra value as a tl.constexpr so this kernel is not\n # cached. This way the static print is run every time.\n x = tl.load(X + tl.arange(0, BLOCK))\n tl.static_print(\"\", x)\n tl.store(Y + tl.arange(0, BLOCK), x)\n\n@triton.jit\ndef kernel_no_arg_print():\n print(\"\", tl.program_id(0))\n\n@triton.jit\ndef kernel_print_no_arg():\n print(\"no arg\")\n\ndef test_print(func: str, data_type: str):\n shape = (128, )\n x = torch.arange(0, shape[0], dtype=torch.int32, device='cuda').to(getattr(torch, data_type))\n y = torch.zeros(shape, dtype=x.dtype, device=\"cuda\")\n if func == \"device_print\":\n kernel_device_print[(1, )](x, y, BLOCK=shape[0])\n elif func == \"print\":\n kernel_print[(1, )](x, y, BLOCK=shape[0])\n elif func == \"device_print_large\":\n kernel_device_print_large[(1, 2)](BLOCK_M=64, BLOCK_N=128)\n elif func == \"print_multiple_args\":\n kernel_print_multiple_args[(1, )](x, y, BLOCK=shape[0])\n elif func == \"device_print_multiple_args\":\n kernel_device_print_multiple_args[(1, )](x, y, BLOCK=shape[0])\n elif func == \"static_print\":\n kernel_static_print[(1, )](x, y, BLOCK=shape[0], PLACEHOLDER=uuid.uuid4())\n elif func == \"no_arg_print\":\n kernel_no_arg_print[(1, )](num_warps=4)\n elif func == \"print_no_arg\":\n kernel_print_no_arg[(1, )](num_warps=4)\n else:\n assert f\"Unknown kernel: {func}\"\n\n if func != \"print_no_arg\" and func != \"no_arg_print\" and func != \"device_print_large\" and \\\n func != \"print_multiple_args\" and func != \"device_print_multiple_args\":\n assert_close(y, x)\n\nif __name__ == \"__main__\":\n test_print(sys.argv[1], sys.argv[2])\n", - "description_1": "Use triton language to define multiple kernels that perform operations such as device printing and storing data. Each kernel takes a varying number of arguments depending on its functionality. The test function orchestrates the execution of these kernels based on string input to match kernel names, manages data initialization using PyTorch, and ensures results are as expected with assert_close.", - "description_2": "Use triton language to create kernels for printing and manipulating arrays, with a Python test function to execute them based on input.", - "difficulty": 2 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\ndef test_annotations(device):\n\n @triton.jit\n def _kernel(X: torch.Tensor, N: int, BLOCK_SIZE: tl.constexpr):\n pass\n\n x = torch.empty(1, device=device)\n _kernel[(1, )](x, x.shape[0], 32)\n try:\n _kernel[(1, )](x.shape[0], x.shape[0], 32)\n except AttributeError:\n pass\n", - "description_1": "Use triton language to define a kernel function '_kernel' that takes three parameters: X (a torch.Tensor), N (an integer), and BLOCK_SIZE (a triton constexpr). The kernel is called with a 1D grid of size 1, passing a tensor 'x', its size, and a block size of 32. The kernel is also tested with incorrect parameters to handle an AttributeError.", - "description_2": "Use triton language to define a kernel with a tensor, an integer, and a constexpr as parameters, and call it with a 1D grid.", - "difficulty": 1 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef block_copy_kernel(a_ptr, b_ptr, N, BLOCK_SIZE: tl.constexpr, padding_option: tl.constexpr):\n pid = tl.program_id(0)\n # We only copy half of the data to see if the padding works\n a_block_ptr = tl.make_block_ptr(base=a_ptr, shape=(N // 2, ), strides=(1, ), offsets=(pid * BLOCK_SIZE, ),\n block_shape=(BLOCK_SIZE, ), order=(0, ))\n b_block_ptr = tl.make_block_ptr(base=b_ptr, shape=(N, ), strides=(1, ), offsets=(pid * BLOCK_SIZE, ),\n block_shape=(BLOCK_SIZE, ), order=(0, ))\n a = tl.load(a_block_ptr, boundary_check=(0, ), padding_option=padding_option)\n tl.store(b_block_ptr, a, boundary_check=(0, ))\n\ndef test_block_copy(dtype_str, n, padding_option):\n dtype = getattr(torch, dtype_str)\n if dtype_str in (\"bool\", \"int16\"):\n a = torch.randint(0, 2, (n, ), device=\"cuda\", dtype=dtype)\n else:\n a = torch.randn((n, ), device=\"cuda\", dtype=dtype)\n b = torch.zeros((n, ), device=\"cuda\", dtype=dtype)\n\n grid = lambda meta: (triton.cdiv(n, meta[\"BLOCK_SIZE\"]), )\n block_copy_kernel[grid](a_ptr=a, b_ptr=b, N=n, BLOCK_SIZE=64, padding_option=padding_option)\n\n@triton.jit\ndef matmul_no_scf_with_advance_kernel( #\n a_ptr, b_ptr, c_ptr, #\n M, N, K, #\n stride_am, stride_ak, #\n stride_bk, stride_bn, #\n stride_cm, stride_cn, #\n BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr #\n):\n offs_m = tl.arange(0, BLOCK_M)\n offs_n = tl.arange(0, BLOCK_N)\n a_block_ptr = tl.make_block_ptr(base=a_ptr, shape=(M, K), strides=(stride_am, stride_ak), offsets=(0, 0),\n block_shape=(BLOCK_M, BLOCK_K), order=(1, 0))\n b_block_ptr = tl.make_block_ptr(base=b_ptr, shape=(K, N), strides=(stride_bk, stride_bn), offsets=(0, 0),\n block_shape=(BLOCK_K, BLOCK_N), order=(1, 0))\n # Below two lines are just for testing negative offsets for the `advance` API, which could be removed\n a_block_ptr = tl.advance(a_block_ptr, (BLOCK_M, -BLOCK_K))\n a_block_ptr = tl.advance(a_block_ptr, (-BLOCK_M, BLOCK_K))\n a = tl.load(a_block_ptr, boundary_check=(1, ), padding_option=\"zero\")\n b = tl.load(b_block_ptr, boundary_check=(0, ), padding_option=\"zero\")\n\n c = tl.dot(a, b)\n c_ptrs = c_ptr + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn\n tl.store(c_ptrs, c)\n\ndef test_block_ptr_matmul_no_scf(shape, num_warps):\n m, n, k = shape\n a = torch.randn((m, k), device=\"cuda\", dtype=torch.float16)\n b = torch.randn((k, n), device=\"cuda\", dtype=torch.float16)\n c = torch.empty((m, n), device=\"cuda\", dtype=torch.float32)\n\n grid = lambda META: (1, )\n matmul_no_scf_with_advance_kernel[grid](\n a_ptr=a, b_ptr=b, c_ptr=c, #\n M=m, N=n, K=k, #\n stride_am=a.stride(0), stride_ak=a.stride(1), #\n stride_bk=b.stride(0), stride_bn=b.stride(1), #\n stride_cm=c.stride(0), stride_cn=c.stride(1), #\n BLOCK_M=m, BLOCK_N=n, BLOCK_K=k, #\n num_warps=num_warps)\n", - "description_1": "Use triton language to implement two kernels: block_copy_kernel and matmul_no_scf_with_advance_kernel. The block_copy_kernel copies half of the data from a_ptr to b_ptr with padding options, using parameters: a_ptr (source pointer), b_ptr (destination pointer), N (total elements), BLOCK_SIZE (block size), and padding_option (padding type). The matmul_no_scf_with_advance_kernel performs matrix multiplication with parameters: a_ptr (matrix A pointer), b_ptr (matrix B pointer), c_ptr (matrix C pointer), M (rows of A and C), N (columns of B and C), K (columns of A and rows of B), stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn (strides for matrices), BLOCK_M, BLOCK_N, BLOCK_K (block sizes for matrices).", - "description_2": "Use triton language to create a kernel for copying data with padding and another for matrix multiplication without using the SCF dialect, utilizing block pointers and strides.", - "difficulty": 3 - }, - { - "code": "import pytest\nimport torch\nimport triton\nimport triton.language as tl\nfrom triton.runtime.jit import reinterpret\n\n\n@pytest.mark.parametrize(\"dtype_x\", ['int8', 'int16', 'int32', 'int64', 'uint8', 'uint16', 'uint32', 'uint64',\n 'float16', 'float32', 'float64', 'bfloat16'])\ndef test_empty_kernel(dtype_x, device):\n SIZE = 128\n\n @triton.jit\n def kernel(X, SIZE: tl.constexpr):\n pass\n\n x = torch.randint(0, 127, (SIZE,), dtype=getattr(torch, dtype_x), device=device)\n kernel[(1,)](x, SIZE=SIZE, num_warps=4)\n\n\n@pytest.mark.parametrize(\"dtype_x\", ['int8', 'int16', 'int32', 'int64', 'uint8', 'uint16', 'uint32', 'uint64',\n 'float16', 'float32', 'float64', 'bfloat16'])\n@pytest.mark.parametrize(\"expr\", ['x', 'x+1', 'x-1'])\ndef test_unary_op(dtype_x, expr, device):\n SIZE = 128\n\n @triton.jit\n def kernel(Z, X, SIZE: tl.constexpr):\n off = tl.arange(0, SIZE)\n x = tl.load(X + off)\n z = GENERATE_TEST_HERE\n tl.store(Z + off, z)\n\n kernel_patched = patch_kernel(kernel, {'GENERATE_TEST_HERE': expr})\n x = torch.randint(0, 127, (SIZE,), dtype=getattr(torch, dtype_x), device=device)\n z = torch.empty_like(x)\n kernel_patched[(1,)](z, x, SIZE=SIZE, num_warps=4)\n np.testing.assert_allclose(eval(expr), z.cpu().numpy(), rtol=0.01)\n\n\n@pytest.mark.parametrize(\"op\", ['+', '-', '*', '/', '%'])\n@pytest.mark.parametrize(\"dtype_x, dtype_y\", [('int8', 'int8'), ('int16', 'int16'), ('int32', 'int32'), ('int64', 'int64'),\n ('uint8', 'uint8'), ('uint16', 'uint16'), ('uint32', 'uint32'), ('uint64', 'uint64'),\n ('float16', 'float16'), ('float32', 'float32'), ('float64', 'float64')])\ndef test_bin_op(dtype_x, dtype_y, op, device):\n expr = f' x {op} y'\n SIZE = 128\n\n @triton.jit\n def kernel(Z, X, Y, SIZE: tl.constexpr):\n off = tl.arange(0, SIZE)\n x = tl.load(X + off)\n y = tl.load(Y + off)\n z = GENERATE_TEST_HERE\n tl.store(Z + off, z)\n\n kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': expr})\n x = torch.randint(0, 127, (SIZE,), dtype=getattr(torch, dtype_x), device=device)\n y = torch.randint(0, 127, (SIZE,), dtype=getattr(torch, dtype_y), device=device)\n z = torch.empty_like(x)\n kernel[(1,)](z, x, y, SIZE=SIZE, num_warps=4)\n np.testing.assert_allclose(eval(expr), z.cpu().numpy(), rtol=0.01)\n\n\n@pytest.mark.parametrize(\"dtype\", ['int8', 'int16', 'int32', 'int64', 'uint8', 'uint16', 'uint32', 'uint64',\n 'float16', 'float32', 'float64', 'bfloat16'])\ndef test_arange(dtype, device):\n BLOCK = 128\n\n @triton.jit\n def kernel(X, N: tl.constexpr):\n off = tl.arange(0, BLOCK)\n tl.store(X + off, off)\n\n x = torch.empty(BLOCK, dtype=getattr(torch, dtype), device=device)\n kernel[(1,)](x, N=BLOCK)\n np.testing.assert_allclose(x.cpu().numpy(), np.arange(0, BLOCK))\n\n\ndef patch_kernel(kernel, to_replace):\n kernel = triton.JITFunction(kernel.fn)\n for key, value in to_replace.items():\n kernel.src = kernel.src.replace(key, value)\n return kernel\n", - "description_1": "Implement triton kernels to perform unary, binary operations and evaluation of arange for tensors with various data types.", - "description_2": "1. Implement triton kernels to perform unary and binary operations on tensors with various data types and evaluate the result with numpy. 2. Implement triton kernels to perform a range of values and identity mapping on tensors using various data types.", - "difficulty": 2 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n# Kernel that loads data from X and stores it in Y\n@triton.jit\ndef kernel_single(X, Y, BLOCK: tl.constexpr):\n x = tl.load(X + tl.arange(0, BLOCK))\n tl.store(Y + tl.arange(0, BLOCK), x)\n\n# Kernel that calls an inline device function\n@triton.jit\ndef device_inline(x):\n return x + x\n\n@triton.jit\ndef kernel_call(X, Y, BLOCK: tl.constexpr):\n x = tl.load(X + tl.arange(0, BLOCK))\n y = device_inline(x)\n tl.store(Y + tl.arange(0, BLOCK), y)\n\n# Kernel that calls a noinline device function\n@triton.jit(noinline=True)\ndef device_noinline(X, Y, BLOCK: tl.constexpr):\n x = tl.load(X + tl.arange(0, BLOCK))\n y = x + x\n tl.store(Y + tl.arange(0, BLOCK), y)\n\n@triton.jit\ndef kernel_call_noinline(X, Y, BLOCK: tl.constexpr):\n device_noinline(X, Y, BLOCK)\n\n# Kernel that applies softmax to the loaded data\n@triton.jit\ndef kernel_multi_files(X, Y, BLOCK: tl.constexpr):\n x = tl.load(X + tl.arange(0, BLOCK))\n y = tl.softmax(x)\n tl.store(Y + tl.arange(0, BLOCK), y)\n\n# Autotuned kernel\n@triton.autotune(\n configs=[\n triton.Config({\"BLOCK\": 128}, num_warps=4),\n ],\n key=[],\n)\n@triton.jit\ndef kernel_autotune(X, Y, SIZE: tl.constexpr, BLOCK: tl.constexpr):\n for i in range(0, SIZE, BLOCK):\n x = tl.load(X + i + tl.arange(0, BLOCK))\n tl.store(Y + i + tl.arange(0, BLOCK), x)\n\n# Test function to call the kernels\ndef test_line_info(func: str):\n shape = (128, )\n x = torch.arange(0, shape[0], dtype=torch.float32, device='cuda')\n y = torch.zeros(shape, dtype=x.dtype, device=\"cuda\")\n if func == \"single\":\n kernel_single[(1,)](x, y, BLOCK=shape[0])\n elif func == \"call\":\n kernel_call[(1,)](x, y, BLOCK=shape[0])\n elif func == \"call_noinline\":\n kernel_call_noinline[(1,)](x, y, BLOCK=shape[0])\n elif func == \"multi_files\":\n kernel_multi_files[(1,)](x, y, BLOCK=shape[0])\n elif func == \"autotune\":\n kernel_autotune[(1,)](x, y, SIZE=shape[0])\n", - "description_1": "Use triton language to define multiple kernels: 'kernel_single' loads data from input X and stores it in output Y using a block size; 'device_inline' is an inline function that doubles the input; 'kernel_call' uses 'device_inline' to process data; 'device_noinline' is a noinline function that doubles the input; 'kernel_call_noinline' calls 'device_noinline'; 'kernel_multi_files' applies softmax to the input data; 'kernel_autotune' is an autotuned kernel that processes data in blocks. Each kernel is called in 'test_line_info' function based on the input string.", - "description_2": "Use triton language to define and call kernels for data loading, processing with inline and noinline functions, applying softmax, and autotuning with block processing.", - "difficulty": 2 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\nimport numpy as np\nimport scipy.stats\n\nBLOCK = 1024\n\n# Kernel for generating random uint32\n@triton.jit\ndef kernel_randint(X, N, seed):\n offset = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)\n rand = tl.randint(seed, offset)\n tl.store(X + offset, rand, mask=offset < N)\n\n# Kernel for generating uniform random numbers\n@triton.jit\ndef kernel_rand(X, N, seed):\n offset = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)\n rand = tl.rand(seed, offset)\n tl.store(X + offset, rand, mask=offset < N)\n\n# Kernel for generating normal random numbers\n@triton.jit\ndef kernel_randn(X, N, seed):\n offset = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)\n rand = tl.randn(seed, offset)\n tl.store(X + offset, rand, mask=offset < N)\n\n# Kernel to test rand limits\n@triton.jit\ndef kernel_rand_limits(input, output, n: tl.constexpr):\n idx = tl.arange(0, n)\n x = tl.load(input + idx)\n y = tl.random.uint32_to_uniform_float(x)\n tl.store(output + idx, y)\n\n# Test function for random uint32 generation\ndef test_randint(size, seed, device):\n size = list(map(int, size.split(',')))\n x = torch.empty(size, dtype=torch.int32, device=device)\n N = x.numel()\n grid = (triton.cdiv(N, BLOCK), )\n kernel_randint[grid](x, N, seed)\n out_tri = x.cpu().numpy().astype(np.uint32).flatten().tolist()\n gen = CustomPhilox4x(seed, config=PHILOX_32)\n out_ref = [gen.random_raw()[0] for _ in out_tri]\n assert out_tri == out_ref\n\n# Test function for uniform PRNG\ndef test_rand(size, seed, device):\n x = torch.empty(size, dtype=torch.float32, device=device)\n N = x.numel()\n grid = (triton.cdiv(N, BLOCK), )\n kernel_rand[grid](x, N, seed)\n assert all((x >= 0) & (x <= 1))\n assert scipy.stats.kstest(x.tolist(), 'uniform', args=(0, 1)).statistic < 0.01\n\n# Test function for normal PRNG\ndef test_randn(size, seed, device):\n x = torch.empty(size, dtype=torch.float32, device=device)\n N = x.numel()\n grid = (triton.cdiv(N, BLOCK), )\n kernel_randn[grid](x, N, seed)\n assert abs(x.mean()) < 1e-2\n assert abs(x.std() - 1) < 1e-2\n\n# Test function for rand limits\ndef test_rand_limits(device):\n min_max_int32 = torch.tensor([\n torch.iinfo(torch.int32).min,\n torch.iinfo(torch.int32).max,\n ], dtype=torch.int32, device=device)\n output = torch.empty(2, dtype=torch.float32, device=device)\n kernel_rand_limits[(1, )](min_max_int32, output, 2)\n assert output[0] == output[1]\n assert 1.0 - torch.finfo(torch.float32).eps <= output[0].item() < 1.0\n", - "description_1": "Use triton language to implement kernels for generating random numbers. The 'kernel_randint' function generates random uint32 numbers, taking three parameters: X (output tensor), N (number of elements), and seed (random seed). The 'kernel_rand' function generates uniform random numbers, with the same parameters. The 'kernel_randn' function generates normal random numbers, also with the same parameters. The 'kernel_rand_limits' function tests the limits of random number generation, taking three parameters: input (input tensor), output (output tensor), and n (number of elements as a constant expression).", - "description_2": "Use triton language to create kernels for random number generation, including uint32, uniform, and normal distributions, and test the limits of these random numbers.", - "difficulty": 1 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n# Kernel for normalization with rematerialization\n@triton.jit\ndef triton_normalization(in_out_ptr0, in_out_ptr1, in_ptr0, in_ptr1, in_ptr2, in_ptr3, xnumel, rnumel, XBLOCK: tl.constexpr, RBLOCK: tl.constexpr):\n xnumel = 512\n rnumel = 4096\n xoffset = tl.program_id(0) * XBLOCK\n xindex = xoffset + tl.arange(0, XBLOCK)[:, None]\n xmask = xindex < xnumel\n rbase = tl.arange(0, RBLOCK)[None, :]\n x3 = xindex\n x0 = xindex % 64\n tmp1 = tl.load(in_ptr0 + (x0), xmask)\n tmp3 = tl.load(in_ptr1 + (x0), xmask)\n tmp11 = tl.load(in_ptr2 + (x0), xmask)\n tmp13 = tl.load(in_ptr3 + (x0), xmask)\n _tmp17 = tl.zeros([XBLOCK, RBLOCK], tl.float32) + 0\n for roffset in range(0, rnumel, RBLOCK):\n rindex = roffset + rbase\n rmask = rindex < rnumel\n r2 = rindex\n tmp0 = tl.load(in_out_ptr0 + (r2 + (4096 * x3)), rmask & xmask, eviction_policy='evict_last', other=0)\n tmp2 = tmp0 - tmp1\n tmp4 = 1e-05\n tmp5 = tmp3 + tmp4\n tmp6 = tl.sqrt(tmp5)\n tmp7 = 1 / tmp6\n tmp8 = 1.0\n tmp9 = tmp7 * tmp8\n tmp10 = tmp2 * tmp9\n tmp12 = tmp10 * tmp11\n tmp14 = tmp12 + tmp13\n _tmp17 = tl.where(rmask & xmask, _tmp17 + tmp14, _tmp17)\n tl.store(in_out_ptr0 + (r2 + (4096 * x3) + tl.zeros([XBLOCK, RBLOCK], tl.int32)), tmp14, rmask & xmask)\n tmp17 = tl.sum(_tmp17, 1)[:, None]\n tmp18 = 4096.0\n tmp19 = tmp17 / tmp18\n tl.store(in_out_ptr1 + (x3 + tl.zeros([XBLOCK, 1], tl.int32)), tmp19, xmask)\n\n# Call the normalization kernel\ntorch.manual_seed(123)\nbuf14 = torch.rand(8, 64, 64, 64, device=\"cuda\")\nbuf16 = torch.rand(8, 1, 64, device=\"cuda\")\narg114_1 = torch.rand(64, device=\"cuda\")\narg115_1 = torch.rand(64, device=\"cuda\")\narg8_1 = torch.rand(64, device=\"cuda\")\narg9_1 = torch.rand(64, device=\"cuda\")\ntriton_normalization[(512, )](buf14, buf16, arg114_1, arg115_1, arg8_1, arg9_1, 512, 4096, 1, 2048)\ntorch.testing.assert_close(buf16.mean().item(), buf14.mean().item(), atol=1e-7, rtol=0)\n\n# Kernel for average pooling backward\n@triton.jit\ndef triton_avg_pool_bw(in_ptr0, out_ptr0, XBLOCK: tl.constexpr):\n xoffset = tl.program_id(0) * XBLOCK\n xindex = xoffset + tl.arange(0, XBLOCK)[:]\n x1 = (xindex // 8) % 8\n x0 = xindex % 8\n x2 = (xindex // 64)\n x5 = xindex\n tmp0 = (-1) + x1\n tmp1 = (-1) + x0\n tmp2 = 2 + x1\n tmp3 = 2 + x0\n tmp4 = 0\n tmp5 = tl.where(tmp0 != tmp0, tmp0, tl.where(tmp0 > tmp4, tmp0, tmp4))\n tmp6 = tl.where(tmp1 != tmp1, tmp1, tl.where(tmp1 > tmp4, tmp1, tmp4))\n tmp7 = 8\n tmp8 = tl.where(tmp2 != tmp2, tmp2, tl.where(tmp2 < tmp7, tmp2, tmp7))\n tmp9 = tl.where(tmp3 != tmp3, tmp3, tl.where(tmp3 < tmp7, tmp3, tmp7))\n tmp10 = tmp5 + tmp4\n tmp11 = tmp6 + tmp4\n tmp12 = 1\n tmp13 = tmp8 - tmp12\n tmp14 = tl.where(tmp10 != tmp10, tmp10, tl.where(tmp10 < tmp13, tmp10, tmp13))\n tmp15 = tmp9 - tmp12\n tmp16 = tl.where(tmp11 != tmp11, tmp11, tl.where(tmp11 < tmp15, tmp11, tmp15))\n tmp17 = tl.load(in_ptr0 + (tmp16 + (8 * tmp14) + (64 * x2)), None).to(tl.float32)\n tmp18 = tmp17 / 9\n tmp19 = tmp10 < tmp8\n tmp20 = tmp11 < tmp9\n tmp21 = tmp19 & tmp20\n tmp22 = 0.0\n tmp23 = tl.where(tmp21, tmp18, tmp22)\n tmp24 = tmp6 + tmp12\n tmp25 = tl.where(tmp24 != tmp24, tmp24, tl.where(tmp24 < tmp15, tmp24, tmp15))\n tmp26 = tl.load(in_ptr0 + (tmp25 + (8 * tmp14) + (64 * x2)), None).to(tl.float32)\n tmp27 = tmp26 / 9\n tmp28 = tmp24 < tmp9\n tmp29 = tmp19 & tmp28\n tmp30 = tmp23 + tmp27\n tmp31 = tl.where(tmp29, tmp30, tmp23)\n tmp32 = 2\n tmp33 = tmp6 + tmp32\n tmp34 = tl.where(tmp33 != tmp33, tmp33, tl.where(tmp33 < tmp15, tmp33, tmp15))\n tmp35 = tl.load(in_ptr0 + (tmp34 + (8 * tmp14) + (64 * x2)), None).to(tl.float32)\n tmp36 = tmp35 / 9\n tmp37 = tmp33 < tmp9\n tmp38 = tmp19 & tmp37\n tmp39 = tmp31 + tmp36\n tmp40 = tl.where(tmp38, tmp39, tmp31)\n tmp41 = tmp5 + tmp12\n tmp42 = tl.where(tmp41 != tmp41, tmp41, tl.where(tmp41 < tmp13, tmp41, tmp13))\n tmp43 = tl.load(in_ptr0 + (tmp16 + (8 * tmp42) + (64 * x2)), None).to(tl.float32)\n tmp44 = tmp43 / 9\n tmp45 = tmp41 < tmp8\n tmp46 = tmp45 & tmp20\n tmp47 = tmp40 + tmp44\n tmp48 = tl.where(tmp46, tmp47, tmp40)\n tmp49 = tl.load(in_ptr0 + (tmp25 + (8 * tmp42) + (64 * x2)), None).to(tl.float32)\n tmp50 = tmp49 / 9\n tmp51 = tmp45 & tmp28\n tmp52 = tmp48 + tmp50\n tmp53 = tl.where(tmp51, tmp52, tmp48)\n tmp54 = tl.load(in_ptr0 + (tmp34 + (8 * tmp42) + (64 * x2)), None).to(tl.float32)\n tmp55 = tmp54 / 9\n tmp56 = tmp45 & tmp37\n tmp57 = tmp53 + tmp55\n tmp58 = tl.where(tmp56, tmp57, tmp53)\n tmp59 = tmp5 + tmp32\n tmp60 = tl.where(tmp59 != tmp59, tmp59, tl.where(tmp59 < tmp13, tmp59, tmp13))\n tmp61 = tl.load(in_ptr0 + (tmp16 + (8 * tmp60) + (64 * x2)), None).to(tl.float32)\n tmp62 = tmp61 / 9\n tmp63 = tmp59 < tmp8\n tmp64 = tmp63 & tmp20\n tmp65 = tmp58 + tmp62\n tmp66 = tl.where(tmp64, tmp65, tmp58)\n tmp67 = tl.load(in_ptr0 + (tmp25 + (8 * tmp60) + (64 * x2)), None).to(tl.float32)\n tmp68 = tmp67 / 9\n tmp69 = tmp63 & tmp28\n tmp70 = tmp66 + tmp68\n tmp71 = tl.where(tmp69, tmp70, tmp66)\n tmp72 = tl.load(in_ptr0 + (tmp34 + (8 * tmp60) + (64 * x2)), None).to(tl.float32)\n tmp73 = tmp72 / 9\n tmp74 = tmp63 & tmp37\n tmp75 = tmp71 + tmp73\n tmp76 = tl.where(tmp74, tmp75, tmp71)\n tl.store(out_ptr0 + (x5 + tl.zeros([XBLOCK], tl.int32)), tmp76, None)\n\n# Call the average pooling backward kernel\ninp = torch.ones(8, 2048, 8, 8, device=\"cuda\", dtype=torch.half)\nout = torch.ones_like(inp) * 3\nnumel = inp.numel()\ntriton_avg_pool_bw[(numel // 1024, )](inp, out, 1024)\nout_ref = torch.ones_like(inp)\nout_ref[:, :, 1:7, 0::7] = 2 / 3\nout_ref[:, :, 0::7, 1:7] = 2 / 3\nout_ref[:, :, 0::7, 0::7] = 4 / 9\ntorch.testing.assert_close(out, out_ref)\n", - "description_1": "Use triton language to implement two kernels: one for normalization with rematerialization and another for average pooling backward. The normalization kernel takes 10 parameters: two output pointers, four input pointers, two integers for element counts, and two block size constants. It performs element-wise operations and stores results. The average pooling backward kernel takes three parameters: an input pointer, an output pointer, and a block size constant. It computes average pooling gradients and stores the results.", - "description_2": "Use triton language to create kernels for normalization and average pooling backward, each with specific input/output pointers and block size parameters, performing element-wise computations and storing results.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\ndef f8_to_f16(x, dtype):\n @triton.jit\n def kernel(Y, X, N, BLOCK_SIZE: tl.constexpr):\n pid = tl.program_id(0)\n offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)\n mask = offs < N\n x = tl.load(X + offs, mask=mask)\n tl.store(Y + offs, x, mask=mask)\n\n ret = torch.empty(x.shape, dtype=torch.float16, device=x.device)\n grid = lambda META: (triton.cdiv(x.numel(), META['BLOCK_SIZE']), )\n dtype = getattr(tl, dtype)\n kernel[grid](ret, triton.reinterpret(x, dtype), ret.numel(), BLOCK_SIZE=1024)\n return ret\n", - "description_1": "Use triton language to define a kernel that converts a tensor from float8 to float16. The kernel takes four parameters: Y (output tensor), X (input tensor), N (number of elements), and BLOCK_SIZE (block size for parallel processing). The kernel uses triton's program_id and arange to calculate offsets and masks for loading and storing data. The function f8_to_f16 calls this kernel with appropriate grid and block size configurations.", - "description_2": "Use triton language to implement a kernel for converting float8 tensors to float16, utilizing parallel processing with specified block sizes.", - "difficulty": 2 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n# Description: Kernel for copying data from source to destination with block sizes.\n@triton.jit\ndef _kernel(dst, src, N, BLOCK_SIZE: tl.constexpr):\n # Calculate offsets within the current block\n offsets = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)\n # Load data from source based on calculated offsets with bounds checking\n x = tl.load(src + offsets, mask=offsets < N)\n # Store data to destination based on calculated offsets with bounds checking\n tl.store(dst + offsets, x, mask=offsets < N)\n\ndef test_kwargs():\n N = 1024\n src = torch.empty(N, device='cuda')\n dst = torch.empty(N, device='cuda')\n \n # Define autotuning configurations for block sizes\n configs = [triton.Config(kwargs={'BLOCK_SIZE': 32}), triton.Config(kwargs={'BLOCK_SIZE': 128})]\n \n @triton.autotune(configs=configs, key=['N'])\n def kernel_autotuned(dst, src, N):\n grid = lambda META: (triton.cdiv(N, META['BLOCK_SIZE']), )\n _kernel[grid](dst, src, N)\n _kernel[grid](dst=dst, src=src, N=N)\n \n # Test the autotuned kernel\n kernel_autotuned(dst, src, N)\n\n# Description: Kernel for incrementing each element of source by 1 with block sizes and restore capability.\n@triton.jit\ndef _kernel_restore(src, N, BLOCK_SIZE: tl.constexpr):\n # Calculate offsets within the current block\n offsets = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)\n # Load and increment data from source based on calculated offsets with bounds checking\n x = tl.load(src + offsets, mask=offsets < N) + 1\n # Store incremented data back to source\n tl.store(src + offsets, x, mask=offsets < N)\n\ndef test_restore():\n N = 1024\n src = torch.zeros(N, device='cuda')\n\n # Define autotuning configurations for block sizes with restore capability\n configs = [triton.Config(kwargs={'BLOCK_SIZE': 32}), triton.Config(kwargs={'BLOCK_SIZE': 128})]\n\n @triton.autotune(configs=configs, key=['N'], restore_value=['src'])\n def kernel_restore_autotuned(src, N):\n grid = lambda META: (triton.cdiv(N, META['BLOCK_SIZE']), )\n _kernel_restore[grid](src, N)\n \n # Test the autotuned kernel with restore functionality\n kernel_restore_autotuned(src, N)\n triton.testing.assert_close(src, torch.ones_like(src))\n", - "description_1": "Use triton language to implement two kernels: 1) A kernel that copies data from a source tensor to a destination tensor with configurable block sizes using Triton's `@jit` and `autotune` functionalities. It requires four parameters: `dst`, `src`, `N`, and `BLOCK_SIZE`, where `dst` and `src` are tensors, `N` is the size of the data, and `BLOCK_SIZE` is the configurable block size. 2) A kernel that increments each element of a source tensor by 1, also using block sizes and including restore functionality, requiring three parameters: `src`, `N`, and `BLOCK_SIZE`.", - "description_2": "Use triton language to create a block-wise data copy kernel and an increment kernel with autotuning for different block sizes, employing features like `@jit`, `autotune`, and restore capabilities.", - "difficulty": 2 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n# Triton kernel function that increments an integer and stores it\n@triton.jit\ndef function_1(i):\n i = i + 1\n i = function_2(i)\n return i\n\n# Triton kernel function that increments an integer\n@triton.jit\ndef function_2(i):\n i = i + 1\n return i\n\n# Triton kernel that uses function_1 and stores the result\n@triton.jit\ndef kernel(X, i, BLOCK: tl.constexpr):\n i = i + 1\n i = function_1(i)\n tl.store(X, i)\n\n# Triton kernel with no specialization on 'i'\n@triton.jit(do_not_specialize=[\"i\"])\ndef kernel_nospec(X, i, BLOCK: tl.constexpr):\n i = i + 1\n i = function_1(i)\n tl.store(X, i)\n\n# Test function to check cache reuse\ndef test_reuse():\n counter = 0\n\n def inc_counter(*args, **kwargs):\n nonlocal counter\n counter += 1\n\n JITFunction.cache_hook = inc_counter\n reset_tmp_dir()\n x = torch.empty(1, dtype=torch.int32, device='cuda')\n for i in range(10):\n kernel[(1, )](x, 1, BLOCK=1024)\n assert counter == 1\n\n# Test function to check specialization\n@pytest.mark.parametrize('mode', ['enable', 'disable'])\ndef test_specialize(mode):\n counter = 0\n\n def inc_counter(*args, **kwargs):\n nonlocal counter\n counter += 1\n\n JITFunction.cache_hook = inc_counter\n reset_tmp_dir()\n x = torch.empty(1, dtype=torch.int32, device='cuda')\n function = {'enable': kernel, 'disable': kernel_nospec}[mode]\n target = {'enable': 4, 'disable': 1}[mode]\n for i in [1, 2, 4, 8, 16, 32]:\n function[(1, )](x, i, BLOCK=512)\n assert counter == target\n", - "description_1": "Use triton language to define a series of kernels: 'function_1' and 'function_2' increment an integer, 'kernel' and 'kernel_nospec' use these functions to increment and store a value in a tensor. 'kernel_nospec' does not specialize on the integer parameter. Test functions 'test_reuse' and 'test_specialize' ensure cache reuse and specialization behavior.", - "description_2": "Use triton language to create kernels that increment integers and store results, with tests for cache and specialization.", - "difficulty": 1 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\nimport tracemalloc\nimport gc\n\ndef test_memory_leak() -> None:\n\n @triton.jit\n def kernel(in_ptr0, out_ptr0, xnumel, XBLOCK: tl.constexpr):\n xnumel = 10\n xoffset = tl.program_id(0) * XBLOCK\n xindex = xoffset + tl.arange(0, XBLOCK)[:]\n xmask = xindex < xnumel\n x0 = xindex\n tmp0 = tl.load(in_ptr0 + (x0), xmask)\n tl.store(out_ptr0 + (x0 + tl.zeros([XBLOCK], tl.int32)), tmp0, xmask)\n\n tracemalloc.start()\n try:\n inp = torch.randn(10, device='cuda')\n out = torch.randn(10, device='cuda')\n kernel[(10, )](inp, out, 10, XBLOCK=16)\n gc.collect()\n begin, _ = tracemalloc.get_traced_memory()\n for _ in range(100):\n kernel[(10, )](inp, out, 10, XBLOCK=16)\n gc.collect()\n end, _ = tracemalloc.get_traced_memory()\n assert end - begin < 30000\n finally:\n tracemalloc.stop()\n", - "description_1": "Use triton language to define a kernel function that takes four arguments: in_ptr0 (input pointer), out_ptr0 (output pointer), xnumel (number of elements), and XBLOCK (block size as a constexpr). The kernel initializes xnumel to 10, computes an offset based on the program ID and block size, and processes data in a block-wise manner. It uses triton's load and store operations with masks to handle conditional operations based on the element index. The kernel is called within a Python function to test for memory leaks by repeatedly executing the kernel and comparing memory usage before and after the execution.", - "description_2": "Use triton language to create a kernel function for memory leak testing by performing block-wise operations on input and output pointers.", - "difficulty": 1 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\nimport multiprocessing\nfrom collections import namedtuple\n\ninstance_descriptor = namedtuple(\"instance_descriptor\",\n [\"divisible_by_16\", \"equal_to_1\", \"ids_of_folded_args\", \"divisible_by_8\"])\n\n\ndef compile_fn(config, cc):\n # Kernel function for element-wise subtraction and multiplication\n @triton.jit\n def kernel_sub(a, b, o, N: tl.constexpr):\n idx = tl.arange(0, N)\n tl.store(o + idx, tl.load(a + idx) - tl.load(b + idx) * 777)\n\n triton.compile(\n fn=kernel_sub,\n signature={0: \"*fp32\", 1: \"*fp32\", 2: \"*fp32\"},\n device=0,\n constants={3: 32},\n configs=[config],\n warm_cache_only=True,\n cc=cc,\n )\n\n\ndef test_compile_in_subproc() -> None:\n major, minor = torch.cuda.get_device_capability(0)\n cc = major * 10 + minor\n config = instance_descriptor(tuple(range(4)), (), (), ())\n\n multiprocessing.set_start_method('fork')\n proc = multiprocessing.Process(target=compile_fn, args=(config, cc))\n proc.start()\n proc.join()\n assert proc.exitcode == 0\n\n\ndef compile_fn_dot(config, cc):\n # Kernel function for matrix dot product\n @triton.jit\n def kernel_dot(Z):\n offs = tl.arange(0, 16)[:, None] * 16 + tl.arange(0, 16)[None, :]\n z = tl.load(Z + offs)\n z = tl.dot(z, z)\n tl.store(Z + offs, z)\n\n triton.compile(\n fn=kernel_dot,\n signature={0: \"*fp32\"},\n device=0,\n configs=[config],\n warm_cache_only=True,\n cc=cc,\n )\n\n\ndef test_compile_in_forked_subproc() -> None:\n reset_tmp_dir()\n major, minor = torch.cuda.get_device_capability(0)\n cc = major * 10 + minor\n config = instance_descriptor(tuple(range(1)), (), (), ())\n\n assert multiprocessing.get_start_method() == 'fork'\n proc = multiprocessing.Process(target=compile_fn_dot, args=(config, cc))\n proc.start()\n proc.join()\n assert proc.exitcode == 0\n", - "description_1": "Use triton language to define two kernels: one for element-wise subtraction and multiplication of two arrays, and another for computing the dot product of a matrix. The first kernel, 'kernel_sub', takes four parameters: two input arrays 'a' and 'b', an output array 'o', and a constant 'N' representing the size of the arrays. It computes the element-wise subtraction of 'b' multiplied by 777 from 'a' and stores the result in 'o'. The second kernel, 'kernel_dot', takes one parameter: a matrix 'Z'. It computes the dot product of 'Z' with itself and stores the result back in 'Z'. Both kernels are compiled with specific configurations and device capabilities.", - "description_2": "Use triton language to create a kernel for element-wise operations on arrays and another for matrix dot product, compiling them with specific configurations.", - "difficulty": 4 - }, - { - "code": "import triton\nimport triton.language as tl\nimport kernel_utils\n\n# Kernel to perform a matrix multiplication with customization for block sizes\n@triton.jit\ndef kernel(C, A, B, M, N, K,\n stride_cm, stride_cn,\n stride_am, stride_ak,\n stride_bk, stride_bn,\n BLOCK_M: tl.constexpr,\n BLOCK_N: tl.constexpr,\n BLOCK_K: tl.constexpr):\n pid_m = tl.program_id(0)\n pid_n = tl.program_id(1)\n\n offs_am = (pid_m * BLOCK_M + tl.arange(0, BLOCK_M)) % M\n offs_bn = (pid_n * BLOCK_N + tl.arange(0, BLOCK_N)) % N\n offs_k = tl.arange(0, BLOCK_K)\n a_ptrs = A + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)\n b_ptrs = B + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)\n\n accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)\n for k in range(0, tl.cdiv(K, BLOCK_K)):\n a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_K, other=0.0)\n b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_K, other=0.0)\n accumulator += tl.dot(a, b)\n a_ptrs += BLOCK_K * stride_ak\n b_ptrs += BLOCK_K * stride_bk\n\n c = kernel_utils.mul(accumulator, accumulator)\n offs_cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)\n offs_cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)\n c_ptrs = C + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]\n tl.store(c_ptrs, c)\n", - "description_1": "Use triton language to implement a matrix multiplication kernel that computes the product of two matrices A and B, storing the result in C. It uses block-based matrix multiplication for efficient memory access and parallel computation. Each thread block computes one block of the output matrix C, defined by the constants BLOCK_M, BLOCK_N, and BLOCK_K. The kernel also includes functionality to handle arbitrary sizes of matrices and accumulates the product in float32 precision before writing back to C.", - "description_2": "Use triton language to create a matrix multiplication kernel that performs block-wise matrix computation with customizable block sizes for efficient GPU execution.", - "difficulty": 3 - }, - { - "code": "import triton\nimport torch\n\n# Triton kernel to add two tensors\n@triton.jit\ndef add_kernel(X, Y, Z, N):\n pid = triton.program_id(0)\n block_size = 1024\n offset = pid * block_size + triton.arange(0, block_size)\n mask = offset < N\n x = triton.load(X + offset, mask=mask)\n y = triton.load(Y + offset, mask=mask)\n z = x + y\n triton.store(Z + offset, z, mask=mask)\n\n# Function to call the Triton kernel\ndef add_tensors(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:\n assert x.is_cuda and y.is_cuda\n assert x.shape == y.shape\n z = torch.empty_like(x)\n N = x.numel()\n grid = lambda meta: (triton.cdiv(N, meta['block_size']),)\n add_kernel[grid](x, y, z, N)\n return z\n", - "description_1": "Use triton language to implement a kernel that adds two tensors element-wise. The kernel is decorated with @triton.jit and takes four parameters: X, Y, Z, and N. X and Y are input tensors, Z is the output tensor, and N is the number of elements. The kernel computes the sum of X and Y and stores the result in Z. The function add_tensors calls this kernel, ensuring the input tensors are on CUDA and have the same shape, and returns the result tensor.", - "description_2": "Use triton language to create a kernel for element-wise addition of two tensors, and a function to invoke this kernel.", - "difficulty": 2 - }, - { - "code": "import triton\nimport torch\n\n# Triton kernel for element-wise addition\n@triton.jit\ndef add_kernel(X, Y, Z, N, BLOCK_SIZE: tl.constexpr):\n pid = tl.program_id(0)\n block_start = pid * BLOCK_SIZE\n offsets = block_start + tl.arange(0, BLOCK_SIZE)\n mask = offsets < N\n x = tl.load(X + offsets, mask=mask)\n y = tl.load(Y + offsets, mask=mask)\n z = x + y\n tl.store(Z + offsets, z, mask=mask)\n\n# Function to call the Triton kernel\ndef add_tensors(x: torch.Tensor, y: torch.Tensor):\n assert x.is_cuda and y.is_cuda\n assert x.numel() == y.numel()\n z = torch.empty_like(x)\n BLOCK_SIZE = 1024\n grid = lambda meta: (triton.cdiv(x.numel(), meta['BLOCK_SIZE']),)\n add_kernel[grid](x, y, z, x.numel(), BLOCK_SIZE=BLOCK_SIZE)\n return z\n", - "description_1": "Use triton language to implement an element-wise addition kernel. The kernel 'add_kernel' takes 5 parameters: X, Y, Z, N, and BLOCK_SIZE. X, Y, and Z are pointers to the input and output tensors, N is the number of elements, and BLOCK_SIZE is a compile-time constant defining the number of elements processed by each program instance. The function 'add_tensors' calls this kernel, ensuring the input tensors are on CUDA and have the same number of elements. It creates an output tensor Z and launches the kernel with a grid size calculated based on the number of elements and BLOCK_SIZE.", - "description_2": "Use triton language to create a kernel for element-wise addition of two tensors and a function to execute this kernel on CUDA tensors.", - "difficulty": 2 - }, - { - "code": "import triton\nimport torch\n\n# Example Triton kernel\n@triton.jit\ndef kernel(x_ptr, x_size, **META):\n BLOCK_SIZE = META['BLOCK_SIZE']\n # Kernel implementation here\n\n# Example function calling the Triton kernel\ndef call_kernel(x, x_size):\n meta = {'BLOCK_SIZE': 128}\n kernel[(x_size,)](x, x_size, **meta)\n\n# Another Triton kernel with autotuning\n@triton.autotune(configs=[\n triton.Config(meta={'BLOCK_SIZE': 128}, num_warps=4),\n triton.Config(meta={'BLOCK_SIZE': 1024}, num_warps=8),\n], key=['x_size'])\n@triton.jit\ndef autotuned_kernel(x_ptr, x_size, **META):\n BLOCK_SIZE = META['BLOCK_SIZE']\n # Autotuned kernel implementation here\n\n# Example function calling the autotuned Triton kernel\ndef call_autotuned_kernel(x, x_size):\n autotuned_kernel[(x_size,)](x, x_size)\n", - "description_1": "Use triton language to implement two kernels: `kernel` and `autotuned_kernel`. The `kernel` takes pointers and a size, using a block size from meta-parameters for computation. The `autotuned_kernel` enhances this with automatic tuning, selecting optimal configurations based on input size.", - "description_2": "Implement Triton kernels with support for block size meta-parameters, and apply autotuning for optimal performance with different input sizes.", - "difficulty": 3 - }, - { - "code": "import triton\nimport torch\n\n# Example Triton kernel\n@triton.jit\ndef example_kernel(X, Y, **meta):\n idx = triton.program_id(0)\n x = X[idx]\n Y[idx] = x * 2.0\n\n# Example function calling the Triton kernel\ndef call_example_kernel(X):\n Y = torch.empty_like(X)\n # Launch the Triton kernel with a single block and the number of elements as the grid size\n example_kernel[(X.numel(),)](X, Y)\n return Y\n\n# Example usage\nif __name__ == \"__main__\":\n X = torch.tensor([1.0, 2.0, 3.0, 4.0], device='cuda')\n Y = call_example_kernel(X)\n print(Y)\n", - "description_1": "Use triton language to define a kernel 'example_kernel' that multiplies each element in a tensor X by 2, storing the result in a tensor Y. The kernel takes two parameters, X and Y, and uses the current program ID to index into these tensors. 'call_example_kernel' function launches this kernel on a single grid covering the entire tensor.", - "description_2": "Use triton language to create a kernel for element-wise operations on tensors, specifically multiplying each tensor element by 2 and storing the result.", - "difficulty": 1 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef add_kernel(x_ptr, # *Pointer* to first input vector.\n y_ptr, # *Pointer* to second input vector.\n output_ptr, # *Pointer* to output vector.\n n_elements, # Size of the vector.\n BLOCK_SIZE: tl.constexpr, # Number of elements each program should process.\n ):\n pid = tl.program_id(axis=0) # We use a 1D launch grid so axis is 0.\n block_start = pid * BLOCK_SIZE\n offsets = block_start + tl.arange(0, BLOCK_SIZE)\n mask = offsets < n_elements\n x = tl.load(x_ptr + offsets, mask=mask)\n y = tl.load(y_ptr + offsets, mask=mask)\n output = x + y\n tl.store(output_ptr + offsets, output, mask=mask)\n\ndef add(x: torch.Tensor, y: torch.Tensor):\n output = torch.empty_like(x)\n assert x.is_cuda and y.is_cuda and output.is_cuda\n n_elements = output.numel()\n grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']), )\n add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024)\n return output\n\ntorch.manual_seed(0)\nsize = 98432\nx = torch.rand(size, device='cuda')\ny = torch.rand(size, device='cuda')\noutput_torch = x + y\noutput_triton = add(x, y)\nprint(output_torch)\nprint(output_triton)\nprint(f'The maximum difference between torch and triton is '\n f'{torch.max(torch.abs(output_torch - output_triton))}')\n", - "description_1": "Use triton language to implement a vector addition kernel. The kernel 'add_kernel' takes five parameters: pointers to the input vectors x and y, a pointer to the output vector, the number of elements in the vectors, and a block size as a compile-time constant. The kernel computes the element-wise sum of x and y, storing the result in the output vector. The 'add' function prepares the output tensor, sets up the grid for kernel execution, and calls the kernel with the appropriate parameters.", - "description_2": "Use triton language to create a kernel for element-wise vector addition, and implement a function to execute this kernel on CUDA tensors.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.autotune(\n configs=[\n triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8),\n triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2),\n triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2),\n ],\n key=['M', 'N', 'K'],\n)\n@triton.jit\ndef matmul_kernel(\n a_ptr, b_ptr, c_ptr,\n M, N, K,\n stride_am, stride_ak,\n stride_bk, stride_bn,\n stride_cm, stride_cn,\n BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,\n GROUP_SIZE_M: tl.constexpr,\n ACTIVATION: tl.constexpr\n):\n \"\"\"Kernel for computing the matmul C = A x B. A has shape (M, K), B has shape (K, N) and C has shape (M, N)\"\"\"\n pid = tl.program_id(axis=0)\n num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)\n num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)\n num_pid_in_group = GROUP_SIZE_M * num_pid_n\n group_id = pid // num_pid_in_group\n first_pid_m = group_id * GROUP_SIZE_M\n group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)\n pid_m = first_pid_m + (pid % group_size_m)\n pid_n = (pid % num_pid_in_group) // group_size_m\n\n offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M\n offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N\n offs_k = tl.arange(0, BLOCK_SIZE_K)\n a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)\n b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)\n\n accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):\n a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)\n b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)\n accumulator += tl.dot(a, b)\n a_ptrs += BLOCK_SIZE_K * stride_ak\n b_ptrs += BLOCK_SIZE_K * stride_bk\n\n if ACTIVATION == \"leaky_relu\":\n accumulator = leaky_relu(accumulator)\n c = accumulator.to(tl.float16)\n\n offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]\n c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)\n tl.store(c_ptrs, c, mask=c_mask)\n\n@triton.jit\ndef leaky_relu(x):\n x = x + 1\n return tl.where(x >= 0, x, 0.01 * x)\n\ndef matmul(a, b, activation=\"\"):\n assert a.shape[1] == b.shape[0], \"Incompatible dimensions\"\n assert a.is_contiguous(), \"Matrix A must be contiguous\"\n assert b.is_contiguous(), \"Matrix B must be contiguous\"\n M, K = a.shape\n K, N = b.shape\n c = torch.empty((M, N), device=a.device, dtype=a.dtype)\n grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), )\n matmul_kernel[grid](\n a, b, c,\n M, N, K,\n a.stride(0), a.stride(1),\n b.stride(0), b.stride(1),\n c.stride(0), c.stride(1),\n ACTIVATION=activation\n )\n return c\n", - "description_1": "Use triton language to implement a matrix multiplication kernel that takes three matrix pointers (a_ptr, b_ptr, c_ptr), three matrix dimensions (M, N, K), six stride variables, and five meta-parameters. The kernel computes the matrix multiplication C = A x B with an optional activation function, storing the result in C.", - "description_2": "Use triton language to create a wrapper function 'matmul' that checks input constraints, allocates output matrix, and launches the 'matmul_kernel' for matrix multiplication with optional activation.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n\n@triton.jit\ndef _dropout(\n x_ptr, # pointer to the input\n x_keep_ptr, # pointer to a mask of 0s and 1s\n output_ptr, # pointer to the output\n n_elements, # number of elements in the `x` tensor\n p, # probability that an element of `x` is changed to zero\n BLOCK_SIZE: tl.constexpr,\n):\n pid = tl.program_id(axis=0)\n block_start = pid * BLOCK_SIZE\n offsets = block_start + tl.arange(0, BLOCK_SIZE)\n mask = offsets < n_elements\n x = tl.load(x_ptr + offsets, mask=mask)\n x_keep = tl.load(x_keep_ptr + offsets, mask=mask)\n output = tl.where(x_keep, x / (1 - p), 0.0)\n tl.store(output_ptr + offsets, output, mask=mask)\n\n\ndef dropout(x, x_keep, p):\n output = torch.empty_like(x)\n assert x.is_contiguous()\n n_elements = x.numel()\n grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']), )\n _dropout[grid](x, x_keep, output, n_elements, p, BLOCK_SIZE=1024)\n return output\n\n\n@triton.jit\ndef _seeded_dropout(\n x_ptr,\n output_ptr,\n n_elements,\n p,\n seed,\n BLOCK_SIZE: tl.constexpr,\n):\n pid = tl.program_id(axis=0)\n block_start = pid * BLOCK_SIZE\n offsets = block_start + tl.arange(0, BLOCK_SIZE)\n mask = offsets < n_elements\n x = tl.load(x_ptr + offsets, mask=mask)\n random = tl.rand(seed, offsets)\n x_keep = random > p\n output = tl.where(x_keep, x / (1 - p), 0.0)\n tl.store(output_ptr + offsets, output, mask=mask)\n\n\ndef seeded_dropout(x, p, seed):\n output = torch.empty_like(x)\n assert x.is_contiguous()\n n_elements = x.numel()\n grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']), )\n _seeded_dropout[grid](x, output, n_elements, p, seed, BLOCK_SIZE=1024)\n return output\n\n\n# Inputs for dropout function\nx = torch.randn(size=(10, )).cuda()\np = 0.5\nx_keep = (torch.rand(size=(10, )) > p).to(torch.int32).cuda()\noutput = dropout(x, x_keep=x_keep, p=p)\nprint(tabulate.tabulate([\n [\"input\"] + x.tolist(),\n [\"keep mask\"] + x_keep.tolist(),\n [\"output\"] + output.tolist(),\n]))\n\n# Inputs for seeded_dropout function\nx = torch.randn(size=(10, )).cuda()\noutput = seeded_dropout(x, p=0.5, seed=123)\noutput2 = seeded_dropout(x, p=0.5, seed=123)\noutput3 = seeded_dropout(x, p=0.5, seed=512)\nprint(\n tabulate.tabulate([\n [\"input\"] + x.tolist(),\n [\"output (seed = 123)\"] + output.tolist(),\n [\"output (seed = 123)\"] + output2.tolist(),\n [\"output (seed = 512)\"] + output3.tolist(),\n ]))\n", - "description_1": "Use triton language to create a dropout kernel that receives a pointer to input data, a dropout mask, and other parameters to zero out elements with a probability p, storing results in output memory. Another seeded version generates the mask using a random function with a seed, ensuring the same dropout mask if the seed is unchanged.", - "description_2": "Use triton language to implement a dropout kernel that zeros input elements with probability p, using both direct and pseudo-random generated masks.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef _layer_norm_fwd_fused(\n X, # pointer to the input\n Y, # pointer to the output\n W, # pointer to the weights\n B, # pointer to the biases\n Mean, # pointer to the mean\n Rstd, # pointer to the 1/std\n stride, # how much to increase the pointer when moving by 1 row\n N, # number of columns in X\n eps, # epsilon to avoid division by zero\n BLOCK_SIZE: tl.constexpr,\n):\n row = tl.program_id(0)\n Y += row * stride\n X += row * stride\n # Compute mean\n mean = 0\n _mean = tl.zeros([BLOCK_SIZE], dtype=tl.float32)\n for off in range(0, N, BLOCK_SIZE):\n cols = off + tl.arange(0, BLOCK_SIZE)\n a = tl.load(X + cols, mask=cols < N, other=0.).to(tl.float32)\n _mean += a\n mean = tl.sum(_mean, axis=0) / N\n # Compute variance\n _var = tl.zeros([BLOCK_SIZE], dtype=tl.float32)\n for off in range(0, N, BLOCK_SIZE):\n cols = off + tl.arange(0, BLOCK_SIZE)\n x = tl.load(X + cols, mask=cols < N, other=0.).to(tl.float32)\n x = tl.where(cols < N, x - mean, 0.)\n _var += x * x\n var = tl.sum(_var, axis=0) / N\n rstd = 1 / tl.sqrt(var + eps)\n # Write mean / rstd\n tl.store(Mean + row, mean)\n tl.store(Rstd + row, rstd)\n # Normalize and apply linear transformation\n for off in range(0, N, BLOCK_SIZE):\n cols = off + tl.arange(0, BLOCK_SIZE)\n mask = cols < N\n w = tl.load(W + cols, mask=mask)\n b = tl.load(B + cols, mask=mask)\n x = tl.load(X + cols, mask=mask, other=0.).to(tl.float32)\n x_hat = (x - mean) * rstd\n y = x_hat * w + b\n tl.store(Y + cols, y, mask=mask)\n\n@triton.jit\ndef _layer_norm_bwd_dx_fused(DX, DY, DW, DB, X, W, B, Mean, Rstd, Lock,\n stride, N, eps, GROUP_SIZE_M: tl.constexpr,\n BLOCK_SIZE_N: tl.constexpr):\n row = tl.program_id(0)\n cols = tl.arange(0, BLOCK_SIZE_N)\n mask = cols < N\n X += row * stride\n DY += row * stride\n DX += row * stride\n lock_id = row % GROUP_SIZE_M\n Lock += lock_id\n Count = Lock + GROUP_SIZE_M\n DW = DW + lock_id * N + cols\n DB = DB + lock_id * N + cols\n # Load data to SRAM\n x = tl.load(X + cols, mask=mask, other=0).to(tl.float32)\n dy = tl.load(DY + cols, mask=mask, other=0).to(tl.float32)\n w = tl.load(W + cols, mask=mask).to(tl.float32)\n mean = tl.load(Mean + row)\n rstd = tl.load(Rstd + row)\n # Compute dx\n xhat = (x - mean) * rstd\n wdy = w * dy\n xhat = tl.where(mask, xhat, 0.)\n wdy = tl.where(mask, wdy, 0.)\n c1 = tl.sum(xhat * wdy, axis=0) / N\n c2 = tl.sum(wdy, axis=0) / N\n dx = (wdy - (xhat * c1 + c2)) * rstd\n tl.store(DX + cols, dx, mask=mask)\n # Accumulate partial sums for dw/db\n partial_dw = (dy * xhat).to(w.dtype)\n partial_db = (dy).to(w.dtype)\n while tl.atomic_cas(Lock, 0, 1) == 1:\n pass\n count = tl.load(Count)\n if count == 0:\n tl.atomic_xchg(Count, 1)\n else:\n partial_dw += tl.load(DW, mask=mask)\n partial_db += tl.load(DB, mask=mask)\n tl.store(DW, partial_dw, mask=mask)\n tl.store(DB, partial_db, mask=mask)\n tl.atomic_xchg(Lock, 0)\n\n@triton.jit\ndef _layer_norm_bwd_dwdb(DW, DB, FINAL_DW, FINAL_DB, M, N,\n BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr):\n pid = tl.program_id(0)\n cols = pid * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n dw = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n db = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n for i in range(0, M, BLOCK_SIZE_M):\n rows = i + tl.arange(0, BLOCK_SIZE_M)\n mask = (rows[:, None] < M) & (cols[None, :] < N)\n offs = rows[:, None] * N + cols[None, :]\n dw += tl.load(DW + offs, mask=mask, other=0.)\n db += tl.load(DB + offs, mask=mask, other=0.)\n sum_dw = tl.sum(dw, axis=0)\n sum_db = tl.sum(db, axis=0)\n tl.store(FINAL_DW + cols, sum_dw, mask=cols < N)\n tl.store(FINAL_DB + cols, sum_db, mask=cols < N)\n\n\nclass LayerNorm(torch.autograd.Function):\n\n @staticmethod\n def forward(ctx, x, normalized_shape, weight, bias, eps):\n y = torch.empty_like(x)\n x_arg = x.reshape(-1, x.shape[-1])\n M, N = x_arg.shape\n mean = torch.empty((M, ), dtype=torch.float32, device='cuda')\n rstd = torch.empty((M, ), dtype=torch.float32, device='cuda')\n MAX_FUSED_SIZE = 65536 // x.element_size()\n BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))\n if N > BLOCK_SIZE:\n raise RuntimeError(\"This layer norm doesn't support feature dim >= 64KB.\")\n num_warps = min(max(BLOCK_SIZE // 256, 1), 8)\n _layer_norm_fwd_fused[(M, )](\n x_arg, y, weight, bias, mean, rstd,\n x_arg.stride(0), N, eps,\n BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps, num_ctas=1)\n ctx.save_for_backward(x, weight, bias, mean, rstd)\n ctx.BLOCK_SIZE = BLOCK_SIZE\n ctx.num_warps = num_warps\n ctx.eps = eps\n return y\n\n @staticmethod\n def backward(ctx, dy):\n x, w, b, m, v = ctx.saved_tensors\n N = w.shape[0]\n GROUP_SIZE_M = 64\n if N <= 8192: GROUP_SIZE_M = 96\n if N <= 4096: GROUP_SIZE_M = 128\n if N <= 1024: GROUP_SIZE_M = 256\n locks = torch.zeros(2 * GROUP_SIZE_M, dtype=torch.int32, device='cuda')\n _dw = torch.empty((GROUP_SIZE_M, w.shape[0]), dtype=x.dtype, device=w.device)\n _db = torch.empty((GROUP_SIZE_M, w.shape[0]), dtype=x.dtype, device=w.device)\n dw = torch.empty((w.shape[0], ), dtype=w.dtype, device=w.device)\n db = torch.empty((w.shape[0], ), dtype=w.dtype, device=w.device)\n dx = torch.empty_like(dy)\n x_arg = x.reshape(-1, x.shape[-1])\n M, N = x_arg.shape\n _layer_norm_bwd_dx_fused[(M, )](\n dx, dy, _dw, _db, x, w, b, m, v, locks,\n x_arg.stride(0), N, ctx.eps,\n BLOCK_SIZE_N=ctx.BLOCK_SIZE,\n GROUP_SIZE_M=GROUP_SIZE_M,\n num_warps=ctx.num_warps)\n grid = lambda meta: [triton.cdiv(N, meta['BLOCK_SIZE_N'])]\n _layer_norm_bwd_dwdb[grid](\n _dw, _db, dw, db, GROUP_SIZE_M, N,\n BLOCK_SIZE_M=32,\n BLOCK_SIZE_N=128, num_ctas=1)\n return dx, None, dw, db, None\n\nlayer_norm = LayerNorm.apply\n", - "description_1": "Use triton language to implement a high-performance layer normalization kernel and its backward pass. The forward kernel takes 10 parameters: pointers to input/output data, weights, biases, mean, reciprocal of std deviation, stride, the number of columns in the input, epsilon for stability, and block size. The backward dx kernel takes 14 parameters including pointers to input/output gradients and locks for parallel reduction. The backward dw/db kernel takes 7 parameters to compute the weight/bias gradients by accumulating partial sums.", - "description_2": "Use triton language to create a layer normalization operator with parallel reduction in backward pass for efficient computation of gradients.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef _attn_fwd_inner(acc, l_i, m_i, q, #\n K_block_ptr, V_block_ptr, #\n start_m, qk_scale, #\n BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, #\n STAGE: tl.constexpr, offs_m: tl.constexpr, offs_n: tl.constexpr, #\n N_CTX: tl.constexpr):\n if STAGE == 1:\n lo, hi = 0, start_m * BLOCK_M\n elif STAGE == 2:\n lo, hi = start_m * BLOCK_M, (start_m + 1) * BLOCK_M\n lo = tl.multiple_of(lo, BLOCK_M)\n else:\n lo, hi = 0, N_CTX\n K_block_ptr = tl.advance(K_block_ptr, (0, lo))\n V_block_ptr = tl.advance(V_block_ptr, (lo, 0))\n for start_n in range(lo, hi, BLOCK_N):\n start_n = tl.multiple_of(start_n, BLOCK_N)\n k = tl.load(K_block_ptr)\n qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)\n qk += tl.dot(q, k)\n if STAGE == 2:\n mask = offs_m[:, None] >= (start_n + offs_n[None, :])\n qk = qk * qk_scale + tl.where(mask, 0, -1.0e6)\n m_ij = tl.maximum(m_i, tl.max(qk, 1))\n qk -= m_ij[:, None]\n else:\n m_ij = tl.maximum(m_i, tl.max(qk, 1) * qk_scale)\n qk = qk * qk_scale - m_ij[:, None]\n p = tl.math.exp2(qk)\n l_ij = tl.sum(p, 1)\n alpha = tl.math.exp2(m_i - m_ij)\n l_i = l_i * alpha + l_ij\n acc = acc * alpha[:, None]\n v = tl.load(V_block_ptr)\n acc += tl.dot(p.to(tl.float16), v)\n m_i = m_ij\n V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0))\n K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N))\n return acc, l_i, m_i\n\n@triton.jit\ndef _attn_fwd(Q, K, V, sm_scale, M, Out, #\n stride_qz, stride_qh, stride_qm, stride_qk, #\n stride_kz, stride_kh, stride_kn, stride_kk, #\n stride_vz, stride_vh, stride_vk, stride_vn, #\n stride_oz, stride_oh, stride_om, stride_on, #\n Z, H, #\n N_CTX: tl.constexpr, #\n BLOCK_M: tl.constexpr, #\n BLOCK_DMODEL: tl.constexpr, #\n BLOCK_N: tl.constexpr, #\n STAGE: tl.constexpr #\n ):\n start_m = tl.program_id(0)\n off_hz = tl.program_id(1)\n off_z = off_hz // H\n off_h = off_hz % H\n qvk_offset = off_z.to(tl.int64) * stride_qz + off_h.to(tl.int64) * stride_qh\n\n Q_block_ptr = tl.make_block_ptr(\n base=Q + qvk_offset,\n shape=(N_CTX, BLOCK_DMODEL),\n strides=(stride_qm, stride_qk),\n offsets=(start_m * BLOCK_M, 0),\n block_shape=(BLOCK_M, BLOCK_DMODEL),\n order=(1, 0),\n )\n V_block_ptr = tl.make_block_ptr(\n base=V + qvk_offset,\n shape=(N_CTX, BLOCK_DMODEL),\n strides=(stride_vk, stride_vn),\n offsets=(0, 0),\n block_shape=(BLOCK_N, BLOCK_DMODEL),\n order=(1, 0),\n )\n K_block_ptr = tl.make_block_ptr(\n base=K + qvk_offset,\n shape=(BLOCK_DMODEL, N_CTX),\n strides=(stride_kk, stride_kn),\n offsets=(0, 0),\n block_shape=(BLOCK_DMODEL, BLOCK_N),\n order=(0, 1),\n )\n O_block_ptr = tl.make_block_ptr(\n base=Out + qvk_offset,\n shape=(N_CTX, BLOCK_DMODEL),\n strides=(stride_om, stride_on),\n offsets=(start_m * BLOCK_M, 0),\n block_shape=(BLOCK_M, BLOCK_DMODEL),\n order=(1, 0),\n )\n offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)\n offs_n = tl.arange(0, BLOCK_N)\n m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float(\"inf\")\n l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + 1.0\n acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)\n qk_scale = sm_scale\n qk_scale *= 1.44269504\n q = tl.load(Q_block_ptr)\n if STAGE & 1:\n acc, l_i, m_i = _attn_fwd_inner(acc, l_i, m_i, q, K_block_ptr, V_block_ptr, #\n start_m, qk_scale, #\n BLOCK_M, BLOCK_DMODEL, BLOCK_N, #\n 4 - STAGE, offs_m, offs_n, N_CTX #\n )\n if STAGE & 2:\n tl.debug_barrier()\n acc, l_i, m_i = _attn_fwd_inner(acc, l_i, m_i, q, K_block_ptr, V_block_ptr, #\n start_m, qk_scale, #\n BLOCK_M, BLOCK_DMODEL, BLOCK_N, #\n 2, offs_m, offs_n, N_CTX #\n )\n m_i += tl.math.log2(l_i)\n acc = acc / l_i[:, None]\n m_ptrs = M + off_hz * N_CTX + offs_m\n tl.store(m_ptrs, m_i)\n tl.store(O_block_ptr, acc.to(Out.type.element_ty))\n\n@triton.jit\ndef _attn_bwd_preprocess(O, DO, #\n Delta, #\n Z, H, N_CTX, #\n BLOCK_M: tl.constexpr, D_HEAD: tl.constexpr #\n ):\n off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)\n off_hz = tl.program_id(1)\n off_n = tl.arange(0, D_HEAD)\n o = tl.load(O + off_hz * D_HEAD * N_CTX + off_m[:, None] * D_HEAD + off_n[None, :])\n do = tl.load(DO + off_hz * D_HEAD * N_CTX + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32)\n delta = tl.sum(o * do, axis=1)\n tl.store(Delta + off_hz * N_CTX + off_m, delta)\n\n@triton.jit\ndef _attn_bwd_dkdv(dk, dv, #\n Q, k, v, sm_scale, #\n DO, #\n M, D, #\n stride_tok, stride_d, #\n H, N_CTX, BLOCK_M1: tl.constexpr, #\n BLOCK_N1: tl.constexpr, #\n BLOCK_DMODEL: tl.constexpr, #\n start_n, start_m, num_steps, #\n MASK: tl.constexpr):\n offs_m = start_m + tl.arange(0, BLOCK_M1)\n offs_n = start_n + tl.arange(0, BLOCK_N1)\n offs_k = tl.arange(0, BLOCK_DMODEL)\n qT_ptrs = Q + offs_m[None, :] * stride_tok + offs_k[:, None] * stride_d\n do_ptrs = DO + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d\n tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0)\n curr_m = start_m\n step_m = BLOCK_M1\n for blk_idx in range(num_steps):\n qT = tl.load(qT_ptrs)\n offs_m = curr_m + tl.arange(0, BLOCK_M1)\n m = tl.load(M + offs_m)\n qkT = tl.dot(k, qT)\n pT = tl.math.exp2(qkT - m[None, :])\n if MASK:\n mask = (offs_m[None, :] >= offs_n[:, None])\n pT = tl.where(mask, pT, 0.0)\n do = tl.load(do_ptrs)\n ppT = pT\n ppT = ppT.to(tl.float16)\n dv += tl.dot(ppT, do)\n Di = tl.load(D + offs_m)\n dpT = tl.dot(v, tl.trans(do)).to(tl.float32)\n dsT = pT * (dpT - Di[None, :])\n dsT = dsT.to(tl.float16)\n dk += tl.dot(dsT, tl.trans(qT))\n curr_m += step_m\n qT_ptrs += step_m * stride_tok\n do_ptrs += step_m * stride_tok\n return dk, dv\n\n@triton.jit\ndef _attn_bwd_dq(dq, q, K, V, #\n do, m, D,\n stride_tok, stride_d, #\n H, N_CTX, #\n BLOCK_M2: tl.constexpr, #\n BLOCK_N2: tl.constexpr, #\n BLOCK_DMODEL: tl.constexpr,\n start_m, start_n, num_steps, #\n MASK: tl.constexpr):\n offs_m = start_m + tl.arange(0, BLOCK_M2)\n offs_n = start_n + tl.arange(0, BLOCK_N2)\n offs_k = tl.arange(0, BLOCK_DMODEL)\n kT_ptrs = K + offs_n[None, :] * stride_tok + offs_k[:, None] * stride_d\n vT_ptrs = V + offs_n[None, :] * stride_tok + offs_k[:, None] * stride_d\n Di = tl.load(D + offs_m)\n tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0)\n curr_n = start_n\n step_n = BLOCK_N2\n for blk_idx in range(num_steps):\n kT = tl.load(kT_ptrs)\n vT = tl.load(vT_ptrs)\n qk = tl.dot(q, kT)\n p = tl.math.exp2(qk - m)\n if MASK:\n offs_n = curr_n + tl.arange(0, BLOCK_N2)\n mask = (offs_m[:, None] >= offs_n[None, :])\n p = tl.where(mask, p, 0.0)\n dp = tl.dot(do, vT).to(tl.float32)\n ds = p * (dp - Di[:, None])\n ds = ds.to(tl.float16)\n dq += tl.dot(ds, tl.trans(kT))\n curr_n += step_n\n kT_ptrs += step_n * stride_tok\n vT_ptrs += step_n * stride_tok\n return dq\n\n@triton.jit\ndef _attn_bwd(Q, K, V, sm_scale, #\n DO, #\n DQ, DK, DV, #\n M, D,\n stride_z, stride_h, stride_tok, stride_d, #\n H, N_CTX, #\n BLOCK_M1: tl.constexpr, #\n BLOCK_N1: tl.constexpr, #\n BLOCK_M2: tl.constexpr, #\n BLOCK_N2: tl.constexpr, #\n BLK_SLICE_FACTOR: tl.constexpr, #\n BLOCK_DMODEL: tl.constexpr):\n LN2: tl.constexpr = 0.6931471824645996\n\n bhid = tl.program_id(2)\n off_chz = (bhid * N_CTX).to(tl.int64)\n adj = (stride_h * (bhid % H) + stride_z * (bhid // H)).to(tl.int64)\n pid = tl.program_id(0)\n\n Q += adj\n K += adj\n V += adj\n DO += adj\n DQ += adj\n DK += adj\n DV += adj\n M += off_chz\n D += off_chz\n\n offs_k = tl.arange(0, BLOCK_DMODEL)\n\n start_n = pid * BLOCK_N1\n start_m = start_n\n\n MASK_BLOCK_M1: tl.constexpr = BLOCK_M1 // BLK_SLICE_FACTOR\n offs_n = start_n + tl.arange(0, BLOCK_N1)\n\n dv = tl.zeros([BLOCK_N1, BLOCK_DMODEL], dtype=tl.float32)\n dk = tl.zeros([BLOCK_N1, BLOCK_DMODEL], dtype=tl.float32)\n\n k = tl.load(K + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d)\n v = tl.load(V + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d)\n\n num_steps = BLOCK_N1 // MASK_BLOCK_M1\n\n dk, dv = _attn_bwd_dkdv(dk, dv, #\n Q, k, v, sm_scale, #\n DO, #\n M, D, #\n stride_tok, stride_d, #\n H, N_CTX, #\n MASK_BLOCK_M1, BLOCK_N1, BLOCK_DMODEL, #\n start_n, start_m, num_steps, #\n MASK=True #\n )\n\n start_m += num_steps * MASK_BLOCK_M1\n num_steps = (N_CTX - start_m) // BLOCK_M1\n\n dk, dv = _attn_bwd_dkdv( #\n dk, dv, #\n Q, k, v, sm_scale, #\n DO, #\n M, D, #\n stride_tok, stride_d, #\n H, N_CTX, #\n BLOCK_M1, BLOCK_N1, BLOCK_DMODEL, #\n start_n, start_m, num_steps, #\n MASK=False #\n )\n\n dv_ptrs = DV + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d\n tl.store(dv_ptrs, dv)\n\n dk *= sm_scale\n dk_ptrs = DK + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d\n tl.store(dk_ptrs, dk)\n\n start_m = pid * BLOCK_M2\n end_n = start_m + BLOCK_M2\n\n MASK_BLOCK_N2: tl.constexpr = BLOCK_N2 // BLK_SLICE_FACTOR\n offs_m = start_m + tl.arange(0, BLOCK_M2)\n\n q = tl.load(Q + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d)\n dq = tl.zeros([BLOCK_M2, BLOCK_DMODEL], dtype=tl.float32)\n do = tl.load(DO + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d)\n\n m = tl.load(M + offs_m)\n m = m[:, None]\n\n num_steps = BLOCK_M2 // MASK_BLOCK_N2\n dq = _attn_bwd_dq(dq, q, K, V, #\n do, m, D, #\n stride_tok, stride_d, #\n H, N_CTX, #\n BLOCK_M2, MASK_BLOCK_N2, BLOCK_DMODEL, #\n start_m, end_n - num_steps * MASK_BLOCK_N2, num_steps, #\n MASK=True #\n )\n end_n -= num_steps * MASK_BLOCK_N2\n num_steps = end_n // BLOCK_N2\n dq = _attn_bwd_dq(dq, q, K, V, #\n do, m, D, #\n stride_tok, stride_d, #\n H, N_CTX, #\n BLOCK_M2, BLOCK_N2, BLOCK_DMODEL, #\n start_m, end_n - num_steps * BLOCK_N2, num_steps, #\n MASK=False #\n )\n dq_ptrs = DQ + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d\n dq *= LN2\n tl.store(dq_ptrs, dq)\n\nclass _attention(torch.autograd.Function):\n\n @staticmethod\n def forward(ctx, q, k, v, causal, sm_scale):\n Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]\n assert Lq == Lk and Lk == Lv\n assert Lk in {16, 32, 64, 128}\n o = torch.empty_like(q)\n BLOCK_M = 128\n BLOCK_N = 64 if Lk <= 64 else 32\n num_stages = 4 if Lk <= 64 else 3\n num_warps = 4\n stage = 3 if causal else 1\n if torch.cuda.get_device_capability()[0] == 9:\n num_warps = 8\n num_stages = 7 if Lk >= 64 else 3\n grid = (triton.cdiv(q.shape[2], BLOCK_M), q.shape[0] * q.shape[1], 1)\n M = torch.empty((q.shape[0], q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)\n _attn_fwd[grid](\n q, k, v, sm_scale, M, o, #\n q.stride(0), q.stride(1), q.stride(2), q.stride(3), #\n k.stride(0), k.stride(1), k.stride(2), k.stride(3), #\n v.stride(0), v.stride(1), v.stride(2), v.stride(3), #\n o.stride(0), o.stride(1), o.stride(2), o.stride(3), #\n q.shape[0], q.shape[1], #\n N_CTX=q.shape[2], #\n BLOCK_M=BLOCK_M, #\n BLOCK_N=BLOCK_N, #\n BLOCK_DMODEL=Lk, #\n STAGE=stage, #\n num_warps=num_warps, #\n num_stages=num_stages #\n )\n\n ctx.save_for_backward(q, k, v, o, M)\n ctx.grid = grid\n ctx.sm_scale = sm_scale\n ctx.BLOCK_DMODEL = Lk\n ctx.causal = causal\n return o\n\n @staticmethod\n def backward(ctx, do):\n q, k, v, o, M = ctx.saved_tensors\n assert do.is_contiguous()\n assert q.stride() == k.stride() == v.stride() == o.stride() == do.stride()\n dq = torch.empty_like(q)\n dk = torch.empty_like(k)\n dv = torch.empty_like(v)\n BATCH, N_HEAD, N_CTX = q.shape[:3]\n PRE_BLOCK = 128\n NUM_WARPS, NUM_STAGES = 4, 1\n if torch.cuda.get_device_capability()[0] == 9:\n NUM_STAGES = 5\n BLOCK_M1, BLOCK_N1, BLOCK_M2, BLOCK_N2 = 32, 128, 128, 32\n BLK_SLICE_FACTOR = 2\n RCP_LN2 = 1.4426950408889634\n arg_k = k\n arg_k = arg_k * (ctx.sm_scale * RCP_LN2)\n PRE_BLOCK = 128\n assert N_CTX % PRE_BLOCK == 0\n pre_grid = (N_CTX // PRE_BLOCK, BATCH * N_HEAD)\n delta = torch.empty_like(M)\n _attn_bwd_preprocess[pre_grid](\n o, do, #\n delta, #\n BATCH, N_HEAD, N_CTX, #\n BLOCK_M=PRE_BLOCK, D_HEAD=ctx.BLOCK_DMODEL #\n )\n grid = (N_CTX // BLOCK_N1, 1, BATCH * N_HEAD)\n _attn_bwd[grid](\n q, arg_k, v, ctx.sm_scale, do, dq, dk, dv, #\n M, delta, #\n q.stride(0), q.stride(1), q.stride(2), q.stride(3), #\n N_HEAD, N_CTX, #\n BLOCK_M1=BLOCK_M1, BLOCK_N1=BLOCK_N1, #\n BLOCK_M2=BLOCK_M2, BLOCK_N2=BLOCK_N2, #\n BLK_SLICE_FACTOR=BLK_SLICE_FACTOR, #\n BLOCK_DMODEL=ctx.BLOCK_DMODEL, #\n num_warps=NUM_WARPS, #\n num_stages=NUM_STAGES #\n )\n\n return dq, dk, dv, None, None\n\nattention = _attention.apply\n", - "description_1": "Use triton language to implement a fused attention mechanism with forward and backward passes. The forward pass (_attn_fwd) computes the attention output given query (Q), key (K), and value (V) matrices, along with scaling and other parameters. The backward pass (_attn_bwd) computes gradients for Q, K, and V given the gradient of the output. The kernels handle block-wise operations and support both causal and non-causal attention. The main function, _attention, is a PyTorch autograd function that wraps these kernels for use in neural network training.", - "description_2": "Use triton language to implement a fused attention mechanism with forward and backward passes, supporting block-wise operations and both causal and non-causal attention.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef asin_kernel(\n x_ptr,\n y_ptr,\n n_elements,\n BLOCK_SIZE: tl.constexpr,\n):\n pid = tl.program_id(axis=0)\n block_start = pid * BLOCK_SIZE\n offsets = block_start + tl.arange(0, BLOCK_SIZE)\n mask = offsets < n_elements\n x = tl.load(x_ptr + offsets, mask=mask)\n x = tl.math.asin(x)\n tl.store(y_ptr + offsets, x, mask=mask)\n\ntorch.manual_seed(0)\nsize = 98432\nx = torch.rand(size, device='cuda')\noutput_triton = torch.zeros(size, device='cuda')\noutput_torch = torch.asin(x)\nassert x.is_cuda and output_triton.is_cuda\nn_elements = output_torch.numel()\ngrid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']), )\nasin_kernel[grid](x, output_triton, n_elements, BLOCK_SIZE=1024)\nprint(output_torch)\nprint(output_triton)\nprint(f'The maximum difference between torch and triton is '\n f'{torch.max(torch.abs(output_torch - output_triton))}')\n\noutput_triton = torch.empty_like(x)\nasin_kernel[grid](x, output_triton, n_elements, BLOCK_SIZE=1024,\n extern_libs={'libdevice': '/usr/local/cuda/nvvm/libdevice/libdevice.10.bc'})\nprint(output_torch)\nprint(output_triton)\nprint(f'The maximum difference between torch and triton is '\n f'{torch.max(torch.abs(output_torch - output_triton))}')\n", - "description_1": "Use triton language to implement a kernel function 'asin_kernel' that computes the arc sine of each element in a given input tensor. The kernel takes four parameters: 'x_ptr' (pointer to input tensor), 'y_ptr' (pointer to output tensor), 'n_elements' (number of elements in the tensor), and 'BLOCK_SIZE' (block size for parallel execution). The kernel uses triton's math library to compute the arc sine and stores the result in the output tensor. The kernel is invoked with a grid configuration based on the number of elements and block size.", - "description_2": "Use triton language to create a kernel that calculates the arc sine of tensor elements using triton's math library, and execute it with appropriate grid configuration.", - "difficulty": 2 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.autotune(\n configs=[\n triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2),\n triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2),\n ],\n key=['M', 'N', 'K'],\n)\n@triton.jit\ndef matmul_kernel_with_block_pointers(\n a_ptr, b_ptr, c_ptr,\n M, N, K,\n stride_am, stride_ak,\n stride_bk, stride_bn,\n stride_cm, stride_cn,\n BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr):\n \"\"\"Kernel for computing the matmul C = A x B.\n A has shape (M, K), B has shape (K, N) and C has shape (M, N)\n \"\"\"\n pid = tl.program_id(axis=0)\n num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)\n num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)\n num_pid_in_group = GROUP_SIZE_M * num_pid_n\n group_id = pid // num_pid_in_group\n first_pid_m = group_id * GROUP_SIZE_M\n group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)\n pid_m = first_pid_m + (pid % group_size_m)\n pid_n = (pid % num_pid_in_group) // group_size_m\n\n a_block_ptr = tl.make_block_ptr(base=a_ptr, shape=(M, K), strides=(stride_am, stride_ak),\n offsets=(pid_m * BLOCK_SIZE_M, 0), block_shape=(BLOCK_SIZE_M, BLOCK_SIZE_K),\n order=(1, 0))\n b_block_ptr = tl.make_block_ptr(base=b_ptr, shape=(K, N), strides=(stride_bk, stride_bn),\n offsets=(0, pid_n * BLOCK_SIZE_N), block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_N),\n order=(1, 0))\n\n accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n for k in range(0, K, BLOCK_SIZE_K):\n a = tl.load(a_block_ptr, boundary_check=(0, 1))\n b = tl.load(b_block_ptr, boundary_check=(0, 1))\n accumulator += tl.dot(a, b)\n a_block_ptr = tl.advance(a_block_ptr, (0, BLOCK_SIZE_K))\n b_block_ptr = tl.advance(b_block_ptr, (BLOCK_SIZE_K, 0))\n c = accumulator.to(tl.float16)\n\n c_block_ptr = tl.make_block_ptr(base=c_ptr, shape=(M, N), strides=(stride_cm, stride_cn),\n offsets=(pid_m * BLOCK_SIZE_M, pid_n * BLOCK_SIZE_N),\n block_shape=(BLOCK_SIZE_M, BLOCK_SIZE_N), order=(1, 0))\n tl.store(c_block_ptr, c, boundary_check=(0, 1))\n\ndef matmul(a, b):\n assert a.shape[1] == b.shape[0], \"Incompatible dimensions\"\n assert a.is_contiguous(), \"Matrix A must be contiguous\"\n assert b.is_contiguous(), \"Matrix B must be contiguous\"\n M, K = a.shape\n K, N = b.shape\n c = torch.empty((M, N), device=a.device, dtype=a.dtype)\n grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), )\n matmul_kernel_with_block_pointers[grid](\n a, b, c,\n M, N, K,\n a.stride(0), a.stride(1),\n b.stride(0), b.stride(1),\n c.stride(0), c.stride(1))\n return c\n\ntorch.manual_seed(0)\na = torch.randn((512, 512), device='cuda', dtype=torch.float16)\nb = torch.randn((512, 512), device='cuda', dtype=torch.float16)\ntriton_output = matmul(a, b)\ntorch_output = torch.matmul(a, b)\nprint(f\"triton_output={triton_output}\")\nprint(f\"torch_output={torch_output}\")\nif torch.allclose(triton_output, torch_output, atol=1e-2, rtol=0):\n print(\"✅ Triton and Torch match\")\nelse:\n print(\"❌ Triton and Torch differ\")\n", - "description_1": "Use triton language to implement a matrix multiplication kernel with block pointers. The kernel 'matmul_kernel_with_block_pointers' takes 14 parameters: three pointers to matrices (a_ptr, b_ptr, c_ptr), three integers for matrix dimensions (M, N, K), six integers for strides (stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn), and four compile-time constants for block sizes and group size (BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K, GROUP_SIZE_M). The kernel computes the product of matrices A and B, storing the result in matrix C. The 'matmul' function is a wrapper that checks input constraints, allocates output, and launches the kernel.", - "description_2": "Use triton language to create a matrix multiplication kernel using block pointers for optimized memory access. Implement a wrapper function to handle input validation and kernel execution.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef matmul_kernel(a_ptr, b_ptr, z_ptr, #\n M, N, K, #\n stride_am, stride_ak, #\n stride_bk, stride_bn, #\n stride_zm, stride_zn, #\n BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,\n GROUP_SIZE_M: tl.constexpr, #\n A_ORDER_0: tl.constexpr, A_ORDER_1: tl.constexpr, #\n B_ORDER_0: tl.constexpr, B_ORDER_1: tl.constexpr #\n ):\n pid = tl.program_id(axis=0)\n num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)\n num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)\n num_pid_in_group = GROUP_SIZE_M * num_pid_n\n group_id = pid // num_pid_in_group\n first_pid_m = group_id * GROUP_SIZE_M\n group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)\n pid_m = first_pid_m + (pid % group_size_m)\n pid_n = (pid % num_pid_in_group) // group_size_m\n block_offset_m = pid_m * BLOCK_SIZE_M\n block_offset_n = pid_n * BLOCK_SIZE_N\n\n a_tile_ptr = tl.make_block_ptr(base=a_ptr, shape=(M, K), strides=(stride_am, stride_ak),\n offsets=(block_offset_m, 0), block_shape=(BLOCK_SIZE_M, BLOCK_SIZE_K),\n order=(A_ORDER_0, A_ORDER_1))\n b_tile_ptr = tl.make_block_ptr(base=b_ptr, shape=(K, N), strides=(stride_bk, stride_bn),\n offsets=(0, block_offset_n), block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_N),\n order=(B_ORDER_0, B_ORDER_1))\n z = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n\n offs_m = block_offset_m + tl.arange(0, BLOCK_SIZE_M)\n offs_n = block_offset_n + tl.arange(0, BLOCK_SIZE_N)\n z_ptrs = z_ptr + offs_m[:, None] * stride_zm + offs_n[None, :] * stride_zn\n mask = (offs_m < M)[:, None] & (offs_n < N)[None, :]\n\n for k in range(0, K, BLOCK_SIZE_K):\n a = tl.load(a_tile_ptr)\n b = tl.load(b_tile_ptr)\n z += tl.dot(a, b)\n a_tile_ptr = tl.advance(a_tile_ptr, [0, BLOCK_SIZE_K])\n b_tile_ptr = tl.advance(b_tile_ptr, [BLOCK_SIZE_K, 0])\n\n z = z.to(tl.float16)\n\n tl.store(z_ptrs, z, mask=mask)\n\n\ndef matmul(a, b, a_order, b_order):\n # checks constraints\n assert a.shape[1] == b.shape[0], \"incompatible dimensions\"\n M, K = a.shape\n K, N = b.shape\n\n z = torch.empty((M, N), device=a.device, dtype=torch.float16)\n\n def grid(META):\n return (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), )\n\n matmul_kernel[grid](\n a_ptr=a, b_ptr=b, z_ptr=z, #\n M=M, N=N, K=K, #\n stride_am=a.stride(0), stride_ak=a.stride(1), #\n stride_bk=b.stride(0), stride_bn=b.stride(1), #\n stride_zm=z.stride(0), stride_zn=z.stride(1), #\n A_ORDER_0=a_order[0], A_ORDER_1=a_order[1], #\n B_ORDER_0=b_order[0], B_ORDER_1=b_order[1] #\n )\n return z\n", - "description_1": "Use triton language to implement a matrix multiplication kernel (matmul_kernel) that takes 19 parameters including pointers to input matrices, matrix dimensions, stride information, block size, and order constants for matrix multiplication, and an outer function (matmul) to call this kernel.", - "description_2": "Use triton language to create a matrix multiplication kernel and a wrapper function to execute the kernel on given matrices with specified dimensions and strides.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n\n@triton.jit\ndef matmul_kernel(a_ptr, b_ptr, c_ptr,\n M, N, K,\n stride_am, stride_ak,\n stride_bk, stride_bn,\n stride_cm, stride_cn,\n BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,\n GROUP_SIZE_M: tl.constexpr):\n pid = tl.program_id(axis=0)\n num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)\n num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)\n num_pid_in_group = GROUP_SIZE_M * num_pid_n\n group_id = pid // num_pid_in_group\n first_pid_m = group_id * GROUP_SIZE_M\n group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)\n pid_m = first_pid_m + (pid % group_size_m)\n pid_n = (pid % num_pid_in_group) // group_size_m\n block_offset_m = pid_m * BLOCK_SIZE_M\n block_offset_n = pid_n * BLOCK_SIZE_N\n\n a_tile_ptr = tl.make_block_ptr(base=a_ptr, shape=(M, K), strides=(stride_am, stride_ak),\n offsets=(block_offset_m, 0), block_shape=(BLOCK_SIZE_M, BLOCK_SIZE_K), order=(1, 0))\n b_tile_ptr = tl.make_block_ptr(base=b_ptr, shape=(K, N), strides=(stride_bk, stride_bn),\n offsets=(0, block_offset_n), block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_N), order=(0, 1))\n accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n\n for k in range(0, K, BLOCK_SIZE_K):\n a = tl.load(a_tile_ptr)\n b = tl.load(b_tile_ptr)\n accumulator += tl.dot(a, b)\n a_tile_ptr = tl.advance(a_tile_ptr, [0, BLOCK_SIZE_K])\n b_tile_ptr = tl.advance(b_tile_ptr, [BLOCK_SIZE_K, 0])\n\n c_block_ptr = tl.make_block_ptr(base=c_ptr, shape=(M, N), strides=(stride_cm, stride_cn),\n offsets=(block_offset_m, block_offset_n), block_shape=(BLOCK_SIZE_M, BLOCK_SIZE_N),\n order=(1, 0))\n\n tl.store(c_block_ptr, accumulator)\n\n\ndef matmul(a, b):\n assert a.shape[1] == b.shape[0], \"incompatible dimensions\"\n M, K = a.shape\n K, N = b.shape\n assert (K % 32 == 0), \"We don't check memory-out-of-bounds with K so K must be divisible by BLOCK_SIZE_K\"\n\n c = torch.empty((M, N), device=a.device, dtype=torch.float32)\n\n def grid(META):\n return (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), )\n\n matmul_kernel[grid](\n a_ptr=a, b_ptr=b, c_ptr=c,\n M=M, N=N, K=K,\n stride_am=a.stride(0), stride_ak=a.stride(1),\n stride_bk=b.stride(0), stride_bn=b.stride(1),\n stride_cm=c.stride(0), stride_cn=c.stride(1))\n return c\n\n\na = torch.randn((512, 512), device='cuda', dtype=torch.float16)\nb = torch.randn((512, 512), device='cuda', dtype=torch.float16).T\nc = matmul(a, b)\n", - "description_1": "Use triton language to implement a matrix multiplication kernel and its corresponding call function. The `matmul_kernel` takes 14 arguments: pointers to matrices a, b, c, and their respective dimensions (M, N, K). It also takes strides for the matrices a, b, c and compile-time constants for block sizes (BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K) and group size (GROUP_SIZE_M). The kernel computes matrix multiplication of a and b, storing the result in c. The `matmul` function, which wraps the kernel call, validates input matrices, initializes an output matrix, and invokes the kernel with computed grid dimensions.", - "description_2": "Use triton language to create a matrix multiplication operation with customizable block and grid sizes, leveraging compile-time constants for optimal parallel execution on GPU.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef grouped_matmul_kernel(\n group_a_ptrs,\n group_b_ptrs,\n group_c_ptrs,\n group_gemm_sizes,\n g_lds,\n group_size,\n NUM_SM: tl.constexpr,\n BLOCK_SIZE_M: tl.constexpr,\n BLOCK_SIZE_N: tl.constexpr,\n BLOCK_SIZE_K: tl.constexpr,\n):\n tile_idx = tl.program_id(0)\n last_problem_end = 0\n for g in range(group_size):\n gm = tl.load(group_gemm_sizes + g * 3)\n gn = tl.load(group_gemm_sizes + g * 3 + 1)\n gk = tl.load(group_gemm_sizes + g * 3 + 2)\n num_m_tiles = tl.cdiv(gm, BLOCK_SIZE_M)\n num_n_tiles = tl.cdiv(gn, BLOCK_SIZE_N)\n num_tiles = num_m_tiles * num_n_tiles\n while (tile_idx >= last_problem_end and tile_idx < last_problem_end + num_tiles):\n k = gk\n lda = tl.load(g_lds + g * 3)\n ldb = tl.load(g_lds + g * 3 + 1)\n ldc = tl.load(g_lds + g * 3 + 2)\n a_ptr = tl.load(group_a_ptrs + g).to(tl.pointer_type(tl.float16))\n b_ptr = tl.load(group_b_ptrs + g).to(tl.pointer_type(tl.float16))\n c_ptr = tl.load(group_c_ptrs + g).to(tl.pointer_type(tl.float16))\n tile_idx_in_gemm = tile_idx - last_problem_end\n tile_m_idx = tile_idx_in_gemm // num_n_tiles\n tile_n_idx = tile_idx_in_gemm % num_n_tiles\n\n offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n offs_k = tl.arange(0, BLOCK_SIZE_K)\n a_ptrs = a_ptr + offs_am[:, None] * lda + offs_k[None, :]\n b_ptrs = b_ptr + offs_k[:, None] * ldb + offs_bn[None, :]\n accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n for kk in range(0, tl.cdiv(k, BLOCK_SIZE_K)):\n tl.multiple_of(a_ptrs, [16, 16])\n tl.multiple_of(b_ptrs, [16, 16])\n a = tl.load(a_ptrs)\n b = tl.load(b_ptrs)\n accumulator += tl.dot(a, b)\n a_ptrs += BLOCK_SIZE_K\n b_ptrs += BLOCK_SIZE_K * ldb\n c = accumulator.to(tl.float16)\n\n offs_cm = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_cn = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n c_ptrs = c_ptr + ldc * offs_cm[:, None] + offs_cn[None, :]\n\n tl.store(c_ptrs, c)\n\n tile_idx += NUM_SM\n\n last_problem_end = last_problem_end + num_tiles\n\n\ndef group_gemm_fn(group_A, group_B):\n device = torch.device('cuda')\n assert len(group_A) == len(group_B)\n group_size = len(group_A)\n\n A_addrs = []\n B_addrs = []\n C_addrs = []\n g_sizes = []\n g_lds = []\n group_C = []\n for i in range(group_size):\n A = group_A[i]\n B = group_B[i]\n assert A.shape[1] == B.shape[0]\n M, K = A.shape\n K, N = B.shape\n C = torch.empty((M, N), device=device, dtype=A.dtype)\n group_C.append(C)\n A_addrs.append(A.data_ptr())\n B_addrs.append(B.data_ptr())\n C_addrs.append(C.data_ptr())\n g_sizes += [M, N, K]\n g_lds += [A.stride(0), B.stride(0), C.stride(0)]\n\n d_a_ptrs = torch.tensor(A_addrs, device=device)\n d_b_ptrs = torch.tensor(B_addrs, device=device)\n d_c_ptrs = torch.tensor(C_addrs, device=device)\n d_g_sizes = torch.tensor(g_sizes, dtype=torch.int32, device=device)\n d_g_lds = torch.tensor(g_lds, dtype=torch.int32, device=device)\n grid = lambda META: (META['NUM_SM'], )\n grouped_matmul_kernel[grid](\n d_a_ptrs,\n d_b_ptrs,\n d_c_ptrs,\n d_g_sizes,\n d_g_lds,\n group_size,\n )\n\n return group_C\n\n\ngroup_m = [1024, 512, 256, 128]\ngroup_n = [1024, 512, 256, 128]\ngroup_k = [1024, 512, 256, 128]\ngroup_A = []\ngroup_B = []\nassert len(group_m) == len(group_n)\nassert len(group_n) == len(group_k)\ngroup_size = len(group_m)\nfor i in range(group_size):\n M = group_m[i]\n N = group_n[i]\n K = group_k[i]\n A = torch.rand((M, K), device=\"cuda\", dtype=torch.float16)\n B = torch.rand((K, N), device=\"cuda\", dtype=torch.float16)\n group_A.append(A)\n group_B.append(B)\n\ntri_out = group_gemm_fn(group_A, group_B)\nref_out = [torch.matmul(a, b) for a, b in zip(group_A, group_B)]\nfor i in range(group_size):\n assert torch.allclose(ref_out[i], tri_out[i], atol=1e-2, rtol=0)\n", - "description_1": "Use triton language to implement a grouped matrix multiplication kernel that processes multiple GEMM operations in parallel. The kernel takes pointers to matrices, their sizes, and leading dimensions, and computes the result using a fixed number of streaming multiprocessors. The kernel is called from a function that prepares the input matrices and launches the kernel on the GPU.", - "description_2": "Use triton language to create a kernel for grouped GEMM operations and a function to prepare and launch this kernel on the GPU.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n\n@triton.autotune(\n configs=[\n triton.Config(\n {\n \"BLOCK_SIZE_M\": 128,\n \"BLOCK_SIZE_N\": 256,\n \"BLOCK_SIZE_K\": 64,\n \"GROUP_SIZE_M\": 8,\n },\n num_stages=3,\n num_warps=8,\n ),\n triton.Config(\n {\n \"BLOCK_SIZE_M\": 64,\n \"BLOCK_SIZE_N\": 256,\n \"BLOCK_SIZE_K\": 32,\n \"GROUP_SIZE_M\": 8,\n },\n num_stages=4,\n num_warps=4,\n ),\n triton.Config(\n {\n \"BLOCK_SIZE_M\": 128,\n \"BLOCK_SIZE_N\": 128,\n \"BLOCK_SIZE_K\": 32,\n \"GROUP_SIZE_M\": 8,\n },\n num_stages=4,\n num_warps=4,\n ),\n triton.Config(\n {\n \"BLOCK_SIZE_M\": 128,\n \"BLOCK_SIZE_N\": 64,\n \"BLOCK_SIZE_K\": 32,\n \"GROUP_SIZE_M\": 8,\n },\n num_stages=4,\n num_warps=4,\n ),\n triton.Config(\n {\n \"BLOCK_SIZE_M\": 64,\n \"BLOCK_SIZE_N\": 128,\n \"BLOCK_SIZE_K\": 32,\n \"GROUP_SIZE_M\": 8,\n },\n num_stages=4,\n num_warps=4,\n ),\n triton.Config(\n {\n \"BLOCK_SIZE_M\": 128,\n \"BLOCK_SIZE_N\": 32,\n \"BLOCK_SIZE_K\": 32,\n \"GROUP_SIZE_M\": 8,\n },\n num_stages=4,\n num_warps=4,\n ),\n triton.Config(\n {\n \"BLOCK_SIZE_M\": 64,\n \"BLOCK_SIZE_N\": 32,\n \"BLOCK_SIZE_K\": 32,\n \"GROUP_SIZE_M\": 8,\n },\n num_stages=5,\n num_warps=2,\n ),\n triton.Config(\n {\n \"BLOCK_SIZE_M\": 32,\n \"BLOCK_SIZE_N\": 64,\n \"BLOCK_SIZE_K\": 32,\n \"GROUP_SIZE_M\": 8,\n },\n num_stages=5,\n num_warps=2,\n ),\n ],\n key=[\"m\", \"n\", \"k\"],\n)\n@triton.jit\ndef triton_matmul_kernel(\n lhs_ptr,\n rhs_ptr,\n output_ptr,\n m,\n n,\n k,\n lhs_stride_m,\n lhs_stride_k,\n rhs_stride_k,\n rhs_stride_n,\n output_stride_m,\n output_stride_n,\n BLOCK_SIZE_M: tl.constexpr,\n BLOCK_SIZE_N: tl.constexpr,\n BLOCK_SIZE_K: tl.constexpr,\n GROUP_SIZE_M: tl.constexpr,\n):\n pid = tl.program_id(0)\n num_pid_m = tl.cdiv(m, BLOCK_SIZE_M)\n num_pid_n = tl.cdiv(n, BLOCK_SIZE_N)\n num_pid_in_group = GROUP_SIZE_M * num_pid_n\n group_id = pid // num_pid_in_group\n first_pid_m = group_id * GROUP_SIZE_M\n group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)\n pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)\n pid_n = (pid % num_pid_in_group) // group_size_m\n\n offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % m\n offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % n\n offs_k = tl.arange(0, BLOCK_SIZE_K)\n lhs_ptrs = lhs_ptr + (\n offs_am[:, None] * lhs_stride_m + offs_k[None, :] * lhs_stride_k\n )\n rhs_ptrs = rhs_ptr + (\n offs_k[:, None] * rhs_stride_k + offs_bn[None, :] * rhs_stride_n\n )\n\n accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n for i in range(0, tl.cdiv(k, BLOCK_SIZE_K)):\n lhs = tl.load(lhs_ptrs, mask=offs_k[None, :] < k - i * BLOCK_SIZE_K, other=0.0)\n rhs = tl.load(rhs_ptrs, mask=offs_k[:, None] < k - i * BLOCK_SIZE_K, other=0.0)\n accumulator = tl.dot(lhs, rhs, accumulator)\n lhs_ptrs += BLOCK_SIZE_K * lhs_stride_k\n rhs_ptrs += BLOCK_SIZE_K * rhs_stride_k\n output = accumulator.to(tl.float16)\n\n offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n output_ptrs = (\n output_ptr\n + output_stride_m * offs_cm[:, None]\n + output_stride_n * offs_cn[None, :]\n )\n output_mask = (offs_cm[:, None] < m) & (offs_cn[None, :] < n)\n tl.store(output_ptrs, output, mask=output_mask)\n\n\ndef triton_matmul(lhs, rhs):\n output = torch.empty(\n (lhs.shape[0], rhs.shape[1]), device=lhs.device, dtype=torch.float16\n )\n\n def grid(meta):\n return (\n triton.cdiv(lhs.shape[0], meta[\"BLOCK_SIZE_M\"])\n * triton.cdiv(rhs.shape[1], meta[\"BLOCK_SIZE_N\"]),\n )\n\n triton_matmul_kernel[grid](\n lhs,\n rhs,\n output,\n lhs.shape[0],\n rhs.shape[1],\n lhs.shape[1],\n lhs.stride(0),\n lhs.stride(1),\n rhs.stride(0),\n rhs.stride(1),\n output.stride(0),\n output.stride(1),\n )\n\n return output\n", - "description_1": "Use triton language to define a kernel called 'triton_matmul_kernel' which performs matrix multiplication on input matrices. The kernel takes 15 input parameters: three pointers to input and output matrices, three integers m, n, k representing matrix dimensions, six strides for memory access, and four compile-time constants for block size and group size. A separate Python function 'triton_matmul' sets up the output matrix, grid size, and calls the kernel to execute the multiplication, returning the output matrix.", - "description_2": "Use triton language to implement matrix multiplication for matrices of size defined by m, n, k with configurable block sizes and group sizes. Utilize the provided matrix strides for accessing elements and handle output storage with a specified mask for proper boundary handling.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n# Triton kernel for element-wise addition\n@triton.jit\ndef add_kernel(x_ptr, y_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr):\n # Obtain the program ID for the current block\n pid = tl.program_id(axis=0)\n # Calculate the start index for the block\n block_start = pid * BLOCK_SIZE\n # Create offsets for each element in the block\n offsets = block_start + tl.arange(0, BLOCK_SIZE)\n # Create a mask to handle out-of-bounds accesses\n mask = offsets < n_elements\n # Load elements from x and y using the mask\n x = tl.load(x_ptr + offsets, mask=mask)\n y = tl.load(y_ptr + offsets, mask=mask)\n # Perform element-wise addition\n output = x + y\n # Store the result back to the output pointer\n tl.store(output_ptr + offsets, output, mask=mask)\n\n# Function to call the Triton kernel\ndef add(x: torch.Tensor, y: torch.Tensor):\n # Create an output tensor with the same shape as x\n output = torch.empty_like(x)\n # Ensure all tensors are on the CUDA device\n assert x.is_cuda and y.is_cuda and output.is_cuda\n # Get the total number of elements\n n_elements = output.numel()\n # Define the grid size for the kernel launch\n grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']), )\n # Launch the Triton kernel\n add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024)\n return output\n\n# Example usage\ntorch.manual_seed(42)\nsize = 98432\nx = torch.rand(size, device='cuda:0')\ny = torch.rand(size, device='cuda:0')\noutput_torch = x + y\noutput_triton = add(x, y)\nprint(output_torch)\nprint(output_triton)\nprint(f'The maximum difference between torch and triton is '\n f'{torch.max(torch.abs(output_torch - output_triton))}')\n", - "description_1": "Use triton language to implement an element-wise addition kernel. The kernel 'add_kernel' takes five parameters: x_ptr (pointer to the first input tensor), y_ptr (pointer to the second input tensor), output_ptr (pointer to the output tensor), n_elements (total number of elements to process), and BLOCK_SIZE (block size for the kernel). The function 'add' wraps this kernel, taking two torch.Tensor objects as input, ensuring they are on the CUDA device, and launching the kernel with appropriate grid size.", - "description_2": "Use triton language to create a kernel for element-wise addition of two tensors, and a wrapper function to execute this kernel on CUDA tensors.", - "difficulty": 2 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef softmax_kernel(output_ptr, input_ptr, input_row_stride, output_row_stride, n_rows,\n n_cols, BLOCK_SIZE: tl.constexpr):\n row_start = tl.program_id(0)\n row_step = tl.num_programs(0)\n for row_idx in tl.range(row_start, n_rows, row_step):\n row_start_ptr = input_ptr + row_idx * input_row_stride\n col_offsets = tl.arange(0, BLOCK_SIZE)\n input_ptrs = row_start_ptr + col_offsets\n mask = col_offsets < n_cols\n row = tl.load(input_ptrs, mask=mask, other=-float('inf'))\n row_minus_max = row - tl.max(row, axis=0)\n numerator = tl.exp(row_minus_max)\n denominator = tl.sum(numerator, axis=0)\n softmax_output = numerator / denominator\n output_row_start_ptr = output_ptr + row_idx * output_row_stride\n output_ptrs = output_row_start_ptr + col_offsets\n tl.store(output_ptrs, softmax_output, mask=mask)\n\ndef softmax(x):\n n_rows, n_cols = x.shape\n BLOCK_SIZE = triton.next_power_of_2(n_cols)\n num_warps = 8\n num_stages = 4 if SIZE_SMEM > 200_000 else 2\n y = torch.empty_like(x)\n kernel, num_programs = kernels.get(BLOCK_SIZE, (None, 0))\n if kernel is None:\n kernel = softmax_kernel.warmup(y, x, x.stride(0), y.stride(0), n_rows, n_cols, BLOCK_SIZE=BLOCK_SIZE,\n num_stages=num_stages, num_warps=num_warps, grid=(1,))\n kernel._init_handles()\n n_regs = kernel.n_regs\n size_smem = kernel.metadata.shared\n occupancy = NUM_REGS // (n_regs * WARP_SIZE * num_warps)\n occupancy = min(occupancy, SIZE_SMEM // size_smem)\n num_programs = NUM_SM * occupancy\n kernels[BLOCK_SIZE] = (kernel, num_programs)\n num_programs = min(num_programs, n_rows)\n kernel[(num_programs, 1, 1)](\n y,\n x,\n x.stride(0),\n y.stride(0),\n n_rows,\n n_cols\n )\n return y\n\ntorch.manual_seed(42)\nx = torch.randn(1823, 781, device='cuda')\ny_triton = softmax(x)\n", - "description_1": "Use triton language to implement a fused softmax kernel for matrices that can fit in the GPU's SRAM. The kernel 'softmax_kernel' computes the softmax for each row of the input matrix in parallel, by subtracting the maximum value in the row, computing exponentials, summing them up, and then normalizing each element. The softmax function handles the preparation of parameters, kernel execution, and post-processing.", - "description_2": "Use triton language to write a softmax operation as a triton kernel optimized for small matrices fitting in GPU's SRAM, which reduces the max value from each row before applying exponentials and normalization.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.autotune(\n configs=[\n triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8),\n triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2),\n triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2),\n triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8),\n triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8),\n triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4)\n ],\n key=['M', 'N', 'K']\n)\n@triton.jit\ndef matmul_kernel(\n a_ptr, b_ptr, c_ptr,\n M, N, K,\n stride_am, stride_ak,\n stride_bk, stride_bn,\n stride_cm, stride_cn,\n BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,\n GROUP_SIZE_M: tl.constexpr,\n ACTIVATION: tl.constexpr\n):\n pid = tl.program_id(axis=0)\n num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)\n num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)\n num_pid_in_group = GROUP_SIZE_M * num_pid_n\n group_id = pid // num_pid_in_group\n first_pid_m = group_id * GROUP_SIZE_M\n group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)\n pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)\n pid_n = (pid % num_pid_in_group) // group_size_m\n\n offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M\n offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N\n offs_k = tl.arange(0, BLOCK_SIZE_K)\n a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)\n b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)\n\n accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):\n a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)\n b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)\n\n accumulator = tl.dot(a, b, accumulator)\n\n a_ptrs += BLOCK_SIZE_K * stride_ak\n b_ptrs += BLOCK_SIZE_K * stride_bk\n\n if ACTIVATION == 'leaky_relu':\n accumulator = leaky_relu(accumulator)\n c = accumulator.to(tl.float16)\n\n offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]\n c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)\n tl.store(c_ptrs, c, mask=c_mask)\n\n@triton.jit\ndef leaky_relu(x):\n return tl.where(x >= 0, x, 0.01 * x)\n\ndef matmul(a, b, activation=\"\"):\n assert a.shape[1] == b.shape[0], \"Incompatible dimensions\"\n assert a.is_contiguous(), \"Matrix A must be contiguous\"\n M, K = a.shape\n K, N = b.shape\n c = torch.empty((M, N), device=a.device, dtype=torch.float16)\n grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), )\n matmul_kernel[grid](\n a, b, c,\n M, N, K,\n a.stride(0), a.stride(1),\n b.stride(0), b.stride(1),\n c.stride(0), c.stride(1),\n ACTIVATION=activation\n )\n return c\n\ntorch.manual_seed(42)\na = torch.randn((512, 512), device='cuda', dtype=torch.float16)\nb = torch.randn((512, 512), device='cuda', dtype=torch.float16)\ntriton_output = matmul(a, b)\ntorch_output = torch.matmul(a, b)\n\nprint(f\"triton_output_with_fp16_inputs={triton_output}\")\nprint(f\"torch_output_with_fp16_inputs={torch_output}\")\n\nrtol = 1e-2 if if_hip_mi200() else 0\n\nif torch.allclose(triton_output, torch_output, atol=1e-2, rtol=rtol):\n print(\"✅ Triton and Torch match\")\nelse:\n print(\"❌ Triton and Torch differ\")\n", - "description_1": "Use triton language to implement a matrix multiplication kernel with optional leaky ReLU activation. The kernel takes pointers to matrices A, B, and C, dimensions M, N, K, and strides for each matrix. It uses block sizes and group size for efficient computation. The kernel computes the product of A and B, optionally applies leaky ReLU, and stores the result in C.", - "description_2": "Use triton language to perform matrix multiplication with optional leaky ReLU activation, utilizing block sizes and group size for optimization.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef _dropout(\n x_ptr, # pointer to the input\n x_keep_ptr, # pointer to a mask of 0s and 1s\n output_ptr, # pointer to the output\n n_elements, # number of elements in the `x` tensor\n p, # probability that an element of `x` is changed to zero\n BLOCK_SIZE: tl.constexpr,\n):\n pid = tl.program_id(axis=0)\n block_start = pid * BLOCK_SIZE\n offsets = block_start + tl.arange(0, BLOCK_SIZE)\n mask = offsets < n_elements\n # Load data\n x = tl.load(x_ptr + offsets, mask=mask)\n x_keep = tl.load(x_keep_ptr + offsets, mask=mask)\n # The line below is the crucial part, described in the paragraph above!\n output = tl.where(x_keep, x / (1 - p), 0.0)\n # Write-back output\n tl.store(output_ptr + offsets, output, mask=mask)\n\n\ndef dropout(x, x_keep, p):\n output = torch.empty_like(x)\n assert x.is_contiguous()\n n_elements = x.numel()\n grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']), )\n _dropout[grid](x, x_keep, output, n_elements, p, BLOCK_SIZE=1024)\n return output\n\n\n@triton.jit\ndef _seeded_dropout(\n x_ptr,\n output_ptr,\n n_elements,\n p,\n seed,\n BLOCK_SIZE: tl.constexpr,\n):\n # compute memory offsets of elements handled by this instance\n pid = tl.program_id(axis=0)\n block_start = pid * BLOCK_SIZE\n offsets = block_start + tl.arange(0, BLOCK_SIZE)\n # load data from x\n mask = offsets < n_elements\n x = tl.load(x_ptr + offsets, mask=mask)\n # randomly prune it\n random = tl.rand(seed, offsets)\n x_keep = random > p\n # write-back\n output = tl.where(x_keep, x / (1 - p), 0.0)\n tl.store(output_ptr + offsets, output, mask=mask)\n\n\ndef seeded_dropout(x, p, seed):\n output = torch.empty_like(x)\n assert x.is_contiguous()\n n_elements = x.numel()\n grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']), )\n _seeded_dropout[grid](x, output, n_elements, p, seed, BLOCK_SIZE=1024)\n return output\n", - "description_1": "Use triton language to implement two dropout kernels. The first kernel, _dropout, takes six parameters: pointers to input, mask, and output tensors, the number of elements, dropout probability, and block size. It applies dropout using a precomputed mask. The second kernel, _seeded_dropout, takes six parameters: pointers to input and output tensors, the number of elements, dropout probability, a random seed, and block size. It applies dropout by generating a random mask on-the-fly using the seed.", - "description_2": "Use triton language to implement dropout kernels with precomputed and on-the-fly random masks.", - "difficulty": 2 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef _seeded_dropout(\n x_ptr,\n output_ptr,\n n_cols,\n p,\n seeds,\n BLOCK_SIZE: tl.constexpr,\n):\n pid = tl.program_id(0)\n row_start = tl.program_id(0) * BLOCK_SIZE\n col_offsets = tl.arange(0, BLOCK_SIZE)\n row_seed = seeds + pid\n mask = col_offsets < n_cols\n x = tl.load(x_ptr + row_start + col_offsets, mask=mask)\n random = tl.rand(row_seed, row_start + col_offsets)\n x_keep = random > p\n output = tl.where(x_keep, x / (1 - p), 0.0)\n tl.store(output_ptr + row_start + col_offsets, output, mask=mask)\n\ndef seeded_dropout(x, p, seeds):\n output = torch.empty_like(x)\n assert x.is_contiguous()\n n_elements = x.numel()\n grid = lambda meta: (x.shape[0], 1)\n BLOCKSIZE = triton.next_power_of_2(x.shape[1])\n _seeded_dropout[grid](x, output, x.shape[1], p, seeds, BLOCK_SIZE=BLOCKSIZE)\n return output\n\nx = torch.randn(size=(3, 5)).cuda()\nseeds_1 = torch.rand(size=(x.shape[0], )).cuda()\nseeds_2 = torch.rand(size=(x.shape[0], )).cuda()\noutput = seeded_dropout(x, p=0.5, seeds=seeds_1)\noutput2 = seeded_dropout(x, p=0.5, seeds=seeds_1)\noutput3 = seeded_dropout(x, p=0.5, seeds=seeds_2)\n", - "description_1": "Use triton language to implement a seeded dropout operation on a 2D tensor. The kernel '_seeded_dropout' is decorated with @triton.jit and takes 6 parameters: 'x_ptr' (pointer to the input tensor), 'output_ptr' (pointer to the output tensor), 'n_cols' (number of columns in the input tensor), 'p' (dropout probability), 'seeds' (seed for random number generation), and 'BLOCK_SIZE' (size of the block for computation). The kernel computes a dropout mask using pseudorandom numbers and applies it to the input tensor. The function 'seeded_dropout' wraps this kernel call, preparing the input data and configuring the grid size for the Triton kernel execution.", - "description_2": "Use triton language to implement a kernel for applying seeded dropout on input tensors, utilizing pseudorandom number generation and block-level parallelism for efficient execution.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef _layer_norm_fwd_fused(\n X,\n Y,\n W,\n B,\n Mean,\n Rstd,\n stride,\n N,\n eps,\n BLOCK_SIZE: tl.constexpr,\n):\n row = tl.program_id(0)\n Y += row * stride\n X += row * stride\n mean = 0\n _mean = tl.zeros([BLOCK_SIZE], dtype=tl.float32)\n for off in range(0, N, BLOCK_SIZE):\n cols = off + tl.arange(0, BLOCK_SIZE)\n a = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)\n _mean += a\n mean = tl.sum(_mean, axis=0) / N\n _var = tl.zeros([BLOCK_SIZE], dtype=tl.float32)\n for off in range(0, N, BLOCK_SIZE):\n cols = off + tl.arange(0, BLOCK_SIZE)\n x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)\n x = tl.where(cols < N, x - mean, 0.0)\n _var += x * x\n var = tl.sum(_var, axis=0) / N\n rstd = 1 / tl.sqrt(var + eps)\n tl.store(Mean + row, mean)\n tl.store(Rstd + row, rstd)\n for off in range(0, N, BLOCK_SIZE):\n cols = off + tl.arange(0, BLOCK_SIZE)\n mask = cols < N\n w = tl.load(W + cols, mask=mask)\n b = tl.load(B + cols, mask=mask)\n x = tl.load(X + cols, mask=mask, other=0.0).to(tl.float32)\n x_hat = (x - mean) * rstd\n y = x_hat * w + b\n tl.store(Y + cols, y, mask=mask)\n\n@triton.jit\ndef _layer_norm_bwd_dx_fused(\n DX,\n DY,\n DW,\n DB,\n X,\n W,\n Mean,\n Rstd,\n Lock,\n stride,\n N,\n GROUP_SIZE_M: tl.constexpr,\n BLOCK_SIZE_N: tl.constexpr\n):\n row = tl.program_id(0)\n cols = tl.arange(0, BLOCK_SIZE_N)\n mask = cols < N\n X += row * stride\n DY += row * stride\n DX += row * stride\n\n lock_id = row % GROUP_SIZE_M\n Lock += lock_id\n Count = Lock + GROUP_SIZE_M\n DW = DW + lock_id * N + cols\n DB = DB + lock_id * N + cols\n\n x = tl.load(X + cols, mask=mask, other=0).to(tl.float32)\n dy = tl.load(DY + cols, mask=mask, other=0).to(tl.float32)\n w = tl.load(W + cols, mask=mask).to(tl.float32)\n mean = tl.load(Mean + row)\n rstd = tl.load(Rstd + row)\n\n xhat = (x - mean) * rstd\n wdy = w * dy\n c1 = tl.sum(xhat * wdy, axis=0) / N\n c2 = tl.sum(wdy, axis=0) / N\n dx = (wdy - (xhat * c1 + c2)) * rstd\n\n tl.store(DX + cols, dx, mask=mask)\n\n partial_dw = (dy * xhat).to(w.dtype)\n partial_db = dy.to(w.dtype)\n while tl.atomic_cas(Lock, 0, 1) == 1:\n pass\n count = tl.load(Count)\n\n if count == 0:\n tl.atomic_xchg(Count, 1)\n else:\n partial_dw += tl.load(DW, mask=mask)\n partial_db += tl.load(DB, mask=mask)\n tl.store(DW, partial_dw, mask=mask)\n tl.store(DB, partial_db, mask=mask)\n\n tl.atomic_xchg(Lock, 0)\n\n@triton.jit\ndef _layer_norm_bwd_dwbd(\n DW,\n DB,\n FINAL_DW,\n FINAL_DB,\n M,\n N,\n BLOCK_SIZE_M: tl.constexpr,\n BLOCK_SIZE_N: tl.constexpr\n):\n pid = tl.program_id(0)\n cols = pid * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n dw = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n db = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n\n for i in range(0, M, BLOCK_SIZE_M):\n rows = i + tl.arange(0, BLOCK_SIZE_M)\n mask = (rows[:, None] < M) & (cols[None, :] < N)\n offs = rows[:, None] * N + cols[None, :]\n dw += tl.load(DW + offs, mask=mask, other=0.)\n db += tl.load(DB + offs, mask=mask, other=0.)\n\n sum_dw = tl.sum(dw, axis=0)\n sum_db = tl.sum(db, axis=0)\n\n tl.store(FINAL_DW + cols, sum_dw, mask=cols < N)\n tl.store(FINAL_DB + cols, sum_db, mask=cols < N)\n\nclass LayerNorm(torch.autograd.Function):\n\n @staticmethod\n def forward(ctx, x, normalized_shape, weight, bias, eps):\n y = torch.empty_like(x)\n x_arg = x.reshape(-1, x.shape[-1])\n M, N = x_arg.shape\n mean = torch.empty((M, ), dtype=torch.float32, device=x.device)\n rstd = torch.empty((M, ), dtype=torch.float32, device=x.device)\n\n MAX_FUSED_SIZE = 65536 // x.element_size()\n BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))\n if N > BLOCK_SIZE:\n raise RuntimeError(\"This layer norm does not support feature din >= 64KB\")\n num_warps = min(max(BLOCK_SIZE // 256, 1), 8)\n\n _layer_norm_fwd_fused[(M, )](\n x_arg, y, weight, bias, mean, rstd,\n x_arg.stride(0), N, eps,\n BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps, num_ctas=1\n )\n\n ctx.save_for_backward(x, weight, bias, mean, rstd)\n ctx.BLOCK_SIZE = BLOCK_SIZE\n ctx.num_warps = num_warps\n ctx.eps = eps\n return y\n\n @staticmethod\n def backward(ctx, dy):\n x, w, b, m, v = ctx.saved_tensors\n\n N = w.shape[0]\n GROUP_SIZE_M = 64\n\n if N <= 8192: GROUP_SIZE_M = 96\n if N <= 4096: GROUP_SIZE_M = 128\n if N <= 1024: GROUP_SIZE_M = 256\n\n locks = torch.zeros(2 * GROUP_SIZE_M, dtype=torch.int32, device=w.device)\n _dw = torch.zeros((GROUP_SIZE_M, N), dtype=x.dtype, device=w.device)\n _db = torch.zeros((GROUP_SIZE_M, N), dtype=x.dtype, device=w.device)\n dw = torch.empty((N, ), dtype=w.dtype, device=w.device)\n db = torch.empty((N, ), dtype=w.dtype, device=w.device)\n dx = torch.empty_like(dy)\n\n x_arg = x.reshape(-1, x.shape[-1])\n M, N = x_arg.shape\n _layer_norm_bwd_dx_fused[(M, )](\n dx, dy, _dw, _db, x, w, m, v, locks,\n x_arg.stride(0), N,\n BLOCK_SIZE_N=ctx.BLOCK_SIZE,\n GROUP_SIZE_M=GROUP_SIZE_M,\n num_warps=ctx.num_warps\n )\n grid = lambda meta: [triton.cdiv(N, meta['BLOCK_SIZE_N'])]\n _layer_norm_bwd_dwbd[grid](\n _dw, _db, dw, db, min(GROUP_SIZE_M, M), N,\n BLOCK_SIZE_M=32,\n BLOCK_SIZE_N=128, num_ctas=1\n )\n return dx, None, dw, db, None\n\nlayer_norm = LayerNorm.apply\n", - "description_1": "Use triton language to implement a layer normalization operation with three kernels: one for the forward pass (_layer_norm_fwd_fused) and two for the backward pass (_layer_norm_bwd_dx_fused and _layer_norm_bwd_dwbd). The forward kernel computes the mean and variance of the input and normalizes it, while the backward kernels compute the gradients with respect to the input, weights, and biases. The LayerNorm class wraps these kernels for use in PyTorch's autograd system.", - "description_2": "Use triton language to create a layer normalization operation with forward and backward passes, utilizing three kernels for computation and PyTorch for integration.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n\n@triton.jit\ndef _attn_fwd_inner(acc, l_i, m_i, q, #\n K_block_ptr, V_block_ptr, #\n start_m, qk_scale, #\n BLOCK_M: tl.constexpr, HEAD_DIM: tl.constexpr, BLOCK_N: tl.constexpr, #\n STAGE: tl.constexpr, offs_m: tl.constexpr, offs_n: tl.constexpr, #\n N_CTX: tl.constexpr, fp8_v: tl.constexpr):\n # range of values handled by this stage\n if STAGE == 1:\n lo, hi = 0, start_m * BLOCK_M\n elif STAGE == 2:\n lo, hi = start_m * BLOCK_M, (start_m + 1) * BLOCK_M\n lo = tl.multiple_of(lo, BLOCK_M)\n # causal = False\n else:\n lo, hi = 0, N_CTX\n K_block_ptr = tl.advance(K_block_ptr, (0, lo))\n V_block_ptr = tl.advance(V_block_ptr, (lo, 0))\n # loop over k, v and update accumulator\n for start_n in range(lo, hi, BLOCK_N):\n start_n = tl.multiple_of(start_n, BLOCK_N)\n # -- compute qk ----\n k = tl.load(K_block_ptr)\n qk = tl.dot(q, k)\n if STAGE == 2:\n mask = offs_m[:, None] >= (start_n + offs_n[None, :])\n qk = qk * qk_scale + tl.where(mask, 0, -1.0e6)\n m_ij = tl.maximum(m_i, tl.max(qk, 1))\n qk -= m_ij[:, None]\n else:\n m_ij = tl.maximum(m_i, tl.max(qk, 1) * qk_scale)\n qk = qk * qk_scale - m_ij[:, None]\n p = tl.math.exp2(qk)\n l_ij = tl.sum(p, 1)\n # -- update m_i and l_i\n alpha = tl.math.exp2(m_i - m_ij)\n l_i = l_i * alpha + l_ij\n # -- update output accumulator --\n acc = acc * alpha[:, None]\n # update acc\n v = tl.load(V_block_ptr)\n if fp8_v:\n p = p.to(tl.float8e5)\n else:\n p = p.to(tl.float16)\n acc = tl.dot(p, v, acc)\n # update m_i and l_i\n m_i = m_ij\n V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0))\n K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N))\n return acc, l_i, m_i\n\n\n@triton.jit\ndef _attn_fwd(Q, K, V, sm_scale, M, Out, #\n stride_qz, stride_qh, stride_qm, stride_qk, #\n stride_kz, stride_kh, stride_kn, stride_kk, #\n stride_vz, stride_vh, stride_vk, stride_vn, #\n stride_oz, stride_oh, stride_om, stride_on, #\n Z, H, N_CTX, #\n HEAD_DIM: tl.constexpr, #\n BLOCK_M: tl.constexpr, #\n BLOCK_N: tl.constexpr, #\n STAGE: tl.constexpr #\n ):\n tl.static_assert(BLOCK_N <= HEAD_DIM)\n start_m = tl.program_id(0)\n off_hz = tl.program_id(1)\n off_z = off_hz // H\n off_h = off_hz % H\n qvk_offset = off_z.to(tl.int64) * stride_qz + off_h.to(tl.int64) * stride_qh\n\n # block pointers\n Q_block_ptr = tl.make_block_ptr(\n base=Q + qvk_offset,\n shape=(N_CTX, HEAD_DIM),\n strides=(stride_qm, stride_qk),\n offsets=(start_m * BLOCK_M, 0),\n block_shape=(BLOCK_M, HEAD_DIM),\n order=(1, 0),\n )\n v_order: tl.constexpr = (0, 1) if V.dtype.element_ty == tl.float8e5 else (1, 0)\n V_block_ptr = tl.make_block_ptr(\n base=V + qvk_offset,\n shape=(N_CTX, HEAD_DIM),\n strides=(stride_vk, stride_vn),\n offsets=(0, 0),\n block_shape=(BLOCK_N, HEAD_DIM),\n order=v_order,\n )\n K_block_ptr = tl.make_block_ptr(\n base=K + qvk_offset,\n shape=(HEAD_DIM, N_CTX),\n strides=(stride_kk, stride_kn),\n offsets=(0, 0),\n block_shape=(HEAD_DIM, BLOCK_N),\n order=(0, 1),\n )\n O_block_ptr = tl.make_block_ptr(\n base=Out + qvk_offset,\n shape=(N_CTX, HEAD_DIM),\n strides=(stride_om, stride_on),\n offsets=(start_m * BLOCK_M, 0),\n block_shape=(BLOCK_M, HEAD_DIM),\n order=(1, 0),\n )\n # initialize offsets\n offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)\n offs_n = tl.arange(0, BLOCK_N)\n # initialize pointer to m and l\n m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float(\"inf\")\n l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + 1.0\n acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32)\n # load scales\n qk_scale = sm_scale\n qk_scale *= 1.44269504 # 1/log(2)\n # load q: it will stay in SRAM throughout\n q = tl.load(Q_block_ptr)\n # stage 1: off-band\n # For causal = True, STAGE = 3 and _attn_fwd_inner gets 1 as its STAGE\n # For causal = False, STAGE = 1, and _attn_fwd_inner gets 3 as its STAGE\n if STAGE & 1:\n acc, l_i, m_i = _attn_fwd_inner(acc, l_i, m_i, q, K_block_ptr, V_block_ptr, #\n start_m, qk_scale, #\n BLOCK_M, HEAD_DIM, BLOCK_N, #\n 4 - STAGE, offs_m, offs_n, N_CTX, V.dtype.element_ty == tl.float8e5 #\n )\n # stage 2: on-band\n if STAGE & 2:\n # barrier makes it easier for compielr to schedule the\n # two loops independently\n acc, l_i, m_i = _attn_fwd_inner(acc, l_i, m_i, q, K_block_ptr, V_block_ptr, #\n start_m, qk_scale, #\n BLOCK_M, HEAD_DIM, BLOCK_N, #\n 2, offs_m, offs_n, N_CTX, V.dtype.element_ty == tl.float8e5 #\n )\n # epilogue\n m_i += tl.math.log2(l_i)\n acc = acc / l_i[:, None]\n m_ptrs = M + off_hz * N_CTX + offs_m\n tl.store(m_ptrs, m_i)\n tl.store(O_block_ptr, acc.to(Out.type.element_ty))\n\n\nclass _attention(torch.autograd.Function):\n\n @staticmethod\n def forward(ctx, q, k, v, causal, sm_scale):\n # shape constraints\n HEAD_DIM_Q, HEAD_DIM_K = q.shape[-1], k.shape[-1]\n # when v is in float8_e5m2 it is transposed.\n HEAD_DIM_V = v.shape[-1]\n assert HEAD_DIM_Q == HEAD_DIM_K and HEAD_DIM_K == HEAD_DIM_V\n assert HEAD_DIM_K in {16, 32, 64, 128, 256}\n o = torch.empty_like(q)\n stage = 3 if causal else 1\n extra_kern_args = {}\n # Tuning for AMD target\n if is_hip():\n waves_per_eu = 3 if HEAD_DIM_K <= 64 else 2\n extra_kern_args = {\"waves_per_eu\": waves_per_eu, \"allow_flush_denorm\": True}\n\n grid = lambda args: (triton.cdiv(q.shape[2], args[\"BLOCK_M\"]), q.shape[0] * q.shape[1], 1)\n M = torch.empty((q.shape[0], q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)\n _attn_fwd[grid](\n q, k, v, sm_scale, M, o, #\n q.stride(0), q.stride(1), q.stride(2), q.stride(3), #\n k.stride(0), k.stride(1), k.stride(2), k.stride(3), #\n v.stride(0), v.stride(1), v.stride(2), v.stride(3), #\n o.stride(0), o.stride(1), o.stride(2), o.stride(3), #\n q.shape[0], q.shape[1], #\n N_CTX=q.shape[2], #\n HEAD_DIM=HEAD_DIM_K, #\n STAGE=stage, #\n **extra_kern_args)\n\n ctx.save_for_backward(q, k, v, o, M)\n ctx.grid = grid\n ctx.sm_scale = sm_scale\n ctx.HEAD_DIM = HEAD_DIM_K\n ctx.causal = causal\n return o\n\n\nattention = _attention.apply\n", - "description_1": "Use triton language to implement a fused attention operation comprising multiple kernels: _attn_fwd_inner, _attn_fwd, and a PyTorch autograd function _attention. _attn_fwd_inner calculates part of the attention forward pass, updating accumulator values based on input matrices. _attn_fwd manages memory pointers and orchestrates calls to _attn_fwd_inner, computing the full forward pass. The _attention class integrates these kernels into PyTorch, handling input/output tensors and context data for gradient calculations. Key parameters include block sizes, strides, stage identifiers, and causality flags.", - "description_2": "Use triton language to implement a fused attention operation with kernels for computing forward pass and integration with PyTorch autograd system, handling tensors and context data.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\nfrom triton.language.extra import libdevice\n\n@triton.jit\ndef asin_kernel(\n x_ptr,\n y_ptr,\n n_elements,\n BLOCK_SIZE: tl.constexpr,\n):\n # Calculate program id and offsets\n pid = tl.program_id(axis=0)\n block_start = pid * BLOCK_SIZE\n offsets = block_start + tl.arange(0, BLOCK_SIZE)\n \n # Create mask for valid elements\n mask = offsets < n_elements\n \n # Load input data\n x = tl.load(x_ptr + offsets, mask=mask)\n \n # Apply asin function using libdevice\n x = libdevice.asin(x)\n \n # Store result in output\n tl.store(y_ptr + offsets, x, mask=mask)\n\ntorch.manual_seed(42)\nsize = 98432\nx = torch.rand(size, device=\"cuda\")\n\noutput_triton = torch.zeros(size, device=\"cuda\")\noutput_torch = torch.asin(x)\n\nassert x.is_cuda and output_triton.is_cuda\n\nn_elements = output_torch.numel()\ngrid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']), )\n\n# Call the Triton kernel\nasin_kernel[grid](x, output_triton, n_elements, BLOCK_SIZE=1024)\n\nprint(output_torch)\nprint(output_triton)\n\nprint(f'The maximum difference between torch and triton is '\n f'{torch.max(torch.abs(output_torch - output_triton))}')\n\nextern_libs = {\"libdevice\": \"third_party/libdevice.10.bc\"}\n\noutput_triton = torch.empty_like(x)\nasin_kernel[grid](x, output_triton, n_elements, BLOCK_SIZE=1024, extern_libs=extern_libs)\n\nprint(output_torch)\nprint(output_triton)\nprint(f'The maximum difference between torch and triton is '\n f'{torch.max(torch.abs(output_torch - output_triton))}')\n", - "description_1": "Use triton language to implement a kernel function 'asin_kernel' that computes the element-wise arcsine of input tensor elements in CUDA memory. The function takes pointers to input and output tensors, the number of elements to process, and a block size. It operates on blocks of data, loading input values, applying the asin operation using libdevice, and storing the results.", - "description_2": "Use triton language to implement a kernel that computes the element-wise arcsine of a tensor in GPU memory.", - "difficulty": 2 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef grouped_matmul_kernel(\n group_a_ptrs,\n group_b_ptrs,\n group_c_ptrs,\n group_gemm_sizes,\n g_lds,\n group_size,\n NUM_SM: tl.constexpr,\n BLOCK_SIZE_M: tl.constexpr,\n BLOCK_SIZE_N: tl.constexpr,\n BLOCK_SIZE_K: tl.constexpr,\n):\n tile_idx = tl.program_id(0)\n last_problem_end = 0\n for g in range(group_size):\n gm = tl.load(group_gemm_sizes + g * 3)\n gn = tl.load(group_gemm_sizes + g * 3 + 1)\n gk = tl.load(group_gemm_sizes + g * 3 + 2)\n num_m_tiles = tl.cdiv(gm, BLOCK_SIZE_M)\n num_n_tiles = tl.cdiv(gn, BLOCK_SIZE_N)\n num_tiles = num_m_tiles * num_n_tiles\n while (tile_idx >= last_problem_end and tile_idx < last_problem_end + num_tiles):\n k = gk\n lda = tl.load(g_lds + g * 3)\n ldb = tl.load(g_lds + g * 3 + 1)\n ldc = tl.load(g_lds + g * 3 + 2)\n a_ptr = tl.load(group_a_ptrs + g).to(tl.pointer_type(tl.float16))\n b_ptr = tl.load(group_b_ptrs + g).to(tl.pointer_type(tl.float16))\n c_ptr = tl.load(group_c_ptrs + g).to(tl.pointer_type(tl.float16))\n\n tile_idx_in_gemm = tile_idx - last_problem_end\n tile_m_idx = tile_idx_in_gemm // num_n_tiles\n tile_n_idx = tile_idx_in_gemm % num_n_tiles\n\n offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n offs_k = tl.arange(0, BLOCK_SIZE_K)\n a_ptrs = a_ptr + offs_am[:, None] * lda + offs_k[None, :]\n b_ptrs = b_ptr + offs_k[:, None] * ldb + offs_bn[None, :]\n accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n for kk in range(0, tl.cdiv(k, BLOCK_SIZE_K)):\n tl.multiple_of(a_ptrs, [16, 16])\n tl.multiple_of(b_ptrs, [16, 16])\n a = tl.load(a_ptrs)\n b = tl.load(b_ptrs)\n accumulator += tl.dot(a, b)\n a_ptrs += BLOCK_SIZE_K\n b_ptrs += BLOCK_SIZE_K * ldb\n c = accumulator.to(tl.float16)\n\n offs_cm = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_cn = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n c_ptrs = c_ptr + ldc * offs_cm[:, None] + offs_cn[None, :]\n\n tl.store(c_ptrs, c)\n\n tile_idx += NUM_SM\n\n last_problem_end = last_problem_end + num_tiles\n\ndef group_gemm_fn(group_A, group_B):\n device = torch.device(\"cuda\")\n assert len(group_A) == len(group_B)\n group_size = len(group_A)\n\n A_addrs = []\n B_addrs = []\n C_addrs = []\n\n g_sizes = []\n g_lds = []\n group_C = []\n\n for i in range(group_size):\n A = group_A[i]\n B = group_B[i]\n assert A.shape[1] == B.shape[0]\n M, K = A.shape\n K, N = B.shape\n C = torch.empty((M, N), device=device, dtype=A.dtype)\n group_C.append(C)\n A_addrs.append(A.data_ptr())\n B_addrs.append(B.data_ptr())\n C_addrs.append(C.data_ptr())\n\n g_sizes += [M, N, K]\n g_lds += [A.stride(0), B.stride(0), C.stride(0)]\n\n d_a_ptrs = torch.tensor(A_addrs, device=device)\n d_b_ptrs = torch.tensor(B_addrs, device=device)\n d_c_ptrs = torch.tensor(C_addrs, device=device)\n\n d_g_sizes = torch.tensor(g_sizes, dtype=torch.int32, device=device)\n d_g_lds = torch.tensor(g_lds, dtype=torch.int32, device=device)\n\n grid = lambda META: (META[\"NUM_SM\"], )\n grouped_matmul_kernel[grid](\n d_a_ptrs,\n d_b_ptrs,\n d_c_ptrs,\n d_g_sizes,\n d_g_lds,\n group_size\n )\n\n return group_C\n", - "description_1": "Use triton language to implement a grouped matrix multiplication kernel that processes multiple matrix multiplications in parallel. The kernel takes pointers to groups of matrices A, B, and C, along with their sizes and leading dimensions. It computes the product of each pair of matrices A and B, storing the result in C. The kernel is optimized for different block sizes and numbers of streaming multiprocessors (SMs).", - "description_2": "Use triton language to create a function that prepares and launches the grouped matrix multiplication kernel. This function takes lists of matrices A and B, checks their compatibility, and prepares device pointers and size information. It then calls the kernel to perform the matrix multiplications and returns the resulting matrices.", - "difficulty": 4 - }, - { - "code": "import triton\nimport torch\nfrom triton import language as tl\nfrom equitriton.utils import calculate_lastdim_num_blocks\n\nclass FusedSecondOrderSphericalHarmonic(torch.autograd.Function):\n @staticmethod\n def forward(\n ctx,\n coords: torch.Tensor,\n mask: torch.Tensor | None = None,\n block_size: int = 64,\n ):\n output_tensor = torch.empty(\n (*coords.shape[:-1], 9), dtype=coords.dtype, device=coords.device\n )\n coord_numel = coords.numel()\n output_numel = output_tensor.numel()\n num_blocks = calculate_lastdim_num_blocks(coords, block_size)\n # apply the kernel\n joint_second_order_fwd[num_blocks,](\n coords, output_tensor, block_size, coord_numel, output_numel\n )\n ctx.save_for_backward(coords)\n return output_tensor\n\n @staticmethod\n def backward(\n ctx, sph_grad_tensor: torch.Tensor, block_size: int = 64\n ) -> torch.Tensor:\n (coords,) = ctx.saved_tensors\n coord_grad_output = torch.zeros_like(coords)\n num_blocks = calculate_lastdim_num_blocks(coords, block_size)\n # call backward kernel\n joint_second_order_bwd[num_blocks,](\n coords,\n coord_grad_output,\n sph_grad_tensor,\n block_size,\n coords.numel(),\n sph_grad_tensor.numel(),\n )\n return coord_grad_output\n\n@triton.jit\ndef joint_second_order_fwd(\n coord_ptr: tl.tensor,\n output_ptr: tl.tensor,\n block_size: tl.constexpr,\n coord_numel: tl.constexpr,\n output_numel: tl.constexpr,\n):\n \"\"\"\n This Triton implementation includes l=0, 1, 2 within the\n same kernel, as it would be a common operation.\n \"\"\"\n # these are hardcoded because they are predetermined;\n coord_stride = 3\n # work out the row offsets\n block_id = tl.program_id(0)\n coord_striding = tl.arange(0, block_size) * coord_stride\n # as the name suggests, this is effectively every node/atom\n coord_row_offset = coord_striding + (block_size * coord_stride * block_id)\n x = tl.load(coord_ptr + coord_row_offset, mask=coord_row_offset < coord_numel)\n y = tl.load(\n coord_ptr + coord_row_offset + 1, mask=coord_row_offset + 1 < coord_numel\n )\n z = tl.load(\n coord_ptr + coord_row_offset + 2, mask=coord_row_offset + 2 < coord_numel\n )\n CONST_00 = 3.87298334620742\n CONST_01 = 2.23606797749979\n CONST_02 = -1.11803398874989\n CONST_03 = 1.93649167310371\n CONST_04 = tl.sqrt(3.0)\n Y10 = CONST_04 * x\n Y11 = CONST_04 * y\n Y12 = CONST_04 * z\n Y20 = CONST_00 * x * z\n Y21 = CONST_00 * x * y\n Y23 = CONST_00 * y * z # looks jarring but just helping the compiler ;)\n Y22 = CONST_02 * x * x + CONST_01 * y * y + CONST_02 * z * z\n Y24 = -CONST_03 * x * x + CONST_03 * z * z\n output_stride = 9 # sum of [2l + 1] over l=0, 1, 2\n output_striding = tl.arange(0, block_size) * output_stride\n output_row_offset = output_striding + (block_size * output_stride * block_id)\n # first column are all zeros, per zeroth order\n tl.store(output_ptr + output_row_offset, 1.0, mask=output_row_offset < output_numel)\n tl.store(\n output_ptr + output_row_offset + 1,\n Y10,\n mask=output_row_offset + 1 < output_numel,\n )\n tl.store(\n output_ptr + output_row_offset + 2,\n Y11,\n mask=output_row_offset + 2 < output_numel,\n )\n tl.store(\n output_ptr + output_row_offset + 3,\n Y12,\n mask=output_row_offset + 3 < output_numel,\n )\n tl.store(\n output_ptr + output_row_offset + 4,\n Y20,\n mask=output_row_offset + 4 < output_numel,\n )\n tl.store(\n output_ptr + output_row_offset + 5,\n Y21,\n mask=output_row_offset + 5 < output_numel,\n )\n tl.store(\n output_ptr + output_row_offset + 6,\n Y22,\n mask=output_row_offset + 6 < output_numel,\n )\n tl.store(\n output_ptr + output_row_offset + 7,\n Y23,\n mask=output_row_offset + 6 < output_numel,\n )\n tl.store(\n output_ptr + output_row_offset + 8,\n Y24,\n mask=output_row_offset + 7 < output_numel,\n )\n\n@triton.jit\ndef joint_second_order_bwd(\n coord_ptr: tl.tensor,\n coord_grad_ptr: tl.tensor,\n sph_grad_ptr: tl.tensor,\n block_size: tl.constexpr,\n coord_numel: tl.constexpr,\n output_numel: tl.constexpr,\n):\n # work out the row offsets\n block_id = tl.program_id(0)\n # these are hardcoded because they are predetermined;\n coord_stride = 3\n coord_striding = tl.arange(0, block_size) * coord_stride\n # as the name suggests, this is effectively every node/atom\n coord_row_offset = coord_striding + (block_size * coord_stride * block_id)\n x = tl.load(coord_ptr + coord_row_offset, mask=coord_row_offset < coord_numel)\n y = tl.load(\n coord_ptr + coord_row_offset + 1, mask=coord_row_offset + 1 < coord_numel\n )\n z = tl.load(\n coord_ptr + coord_row_offset + 2, mask=coord_row_offset + 2 < coord_numel\n )\n output_stride = 9 # [2l + 1]\n output_striding = tl.arange(0, block_size) * output_stride\n output_row_offset = output_striding + (block_size * output_stride * block_id)\n CONST_00 = 3.87298334620742\n CONST_01 = 2.23606797749979\n CONST_02 = 4.47213595499958\n CONST_03 = tl.sqrt(3.0)\n # load in gradients w.r.t. spherical harmonic projections.\n # gradient of l = 0 goes to zero\n g_Y10 = tl.load(\n sph_grad_ptr + output_row_offset + 1, mask=output_row_offset + 1 < output_numel\n )\n g_Y11 = tl.load(\n sph_grad_ptr + output_row_offset + 2, mask=output_row_offset + 2 < output_numel\n )\n g_Y12 = tl.load(\n sph_grad_ptr + output_row_offset + 3, mask=output_row_offset + 3 < output_numel\n )\n g_Y20 = tl.load(\n sph_grad_ptr + output_row_offset + 4, mask=output_row_offset + 4 < output_numel\n )\n g_Y21 = tl.load(\n sph_grad_ptr + output_row_offset + 5, mask=output_row_offset + 5 < output_numel\n )\n g_Y22 = tl.load(\n sph_grad_ptr + output_row_offset + 6, mask=output_row_offset + 6 < output_numel\n )\n g_Y23 = tl.load(\n sph_grad_ptr + output_row_offset + 7, mask=output_row_offset + 7 < output_numel\n )\n g_Y24 = tl.load(\n sph_grad_ptr + output_row_offset + 8, mask=output_row_offset + 8 < output_numel\n )\n g_x = (\n CONST_00 * g_Y20 * z\n + CONST_00 * g_Y21 * y\n - CONST_01 * g_Y22 * x\n - CONST_00 * g_Y24 * x\n + CONST_03 * g_Y10\n )\n g_y = (\n CONST_00 * g_Y21 * x\n + CONST_02 * g_Y22 * y\n + CONST_00 * g_Y23 * z\n + CONST_03 * g_Y11\n )\n g_z = (\n CONST_00 * g_Y20 * x\n - CONST_01 * g_Y22 * z\n + CONST_00 * g_Y23 * y\n + CONST_00 * g_Y24 * z\n + CONST_03 * g_Y12\n )\n # write out gradients\n tl.store(\n coord_grad_ptr + coord_row_offset, g_x, mask=coord_row_offset < coord_numel\n )\n tl.store(\n coord_grad_ptr + coord_row_offset + 1,\n g_y,\n mask=coord_row_offset + 1 < coord_numel,\n )\n tl.store(\n coord_grad_ptr + coord_row_offset + 2,\n g_z,\n mask=coord_row_offset + 2 < coord_numel,\n )\n", - "description_1": "Use triton language to implement two kernels: joint_second_order_fwd and joint_second_order_bwd. The joint_second_order_fwd kernel computes the second order spherical harmonics for a given set of coordinates. It takes 5 parameters: coord_ptr (input coordinates), output_ptr (output tensor), block_size (size of each block), coord_numel (number of elements in the input coordinates), and output_numel (number of elements in the output tensor). The joint_second_order_bwd kernel computes the gradient of the input coordinates with respect to the spherical harmonics. It takes 6 parameters: coord_ptr (input coordinates), coord_grad_ptr (gradient of the input coordinates), sph_grad_ptr (gradient of the spherical harmonics), block_size (size of each block), coord_numel (number of elements in the input coordinates), and output_numel (number of elements in the output tensor).", - "description_2": "Use triton language to implement kernels for computing second order spherical harmonics and their gradients. The forward kernel calculates the harmonics for input coordinates, while the backward kernel computes the gradients of these coordinates.", - "difficulty": 2 - }, - { - "code": "import triton\nimport torch\nfrom triton import language as tl\nfrom equitriton.utils import calculate_lastdim_num_blocks\n\n@triton.jit\ndef zeroth_order_fwd(\n coord_ptr: tl.tensor,\n output_ptr: tl.tensor,\n block_size: tl.constexpr,\n coord_numel: tl.constexpr,\n output_numel: tl.constexpr,\n col_offset: tl.constexpr,\n output_stride: tl.constexpr,\n):\n # work out the row offsets\n block_id = tl.program_id(0)\n output_striding = tl.arange(0, block_size) * output_stride\n output_row_offset = (\n output_striding + (block_size * output_stride * block_id) + col_offset\n )\n tl.store(output_ptr + output_row_offset, 1.0, mask=output_row_offset < output_numel)\n\n@triton.jit\ndef zeroth_order_bwd(\n coord_ptr: tl.tensor,\n coord_grad_ptr: tl.tensor,\n sph_grad_ptr: tl.tensor,\n block_size: tl.constexpr,\n coord_numel: tl.constexpr,\n output_numel: tl.constexpr,\n col_offset: tl.constexpr,\n output_stride: tl.constexpr,\n):\n # work out the row offsets\n block_id = tl.program_id(0) # noqa: F841\n # do nothing in this function because no gradient contributions!\n\nclass ZerothOrderSphericalHarmonic(torch.autograd.Function):\n @staticmethod\n def forward(\n ctx,\n coords: torch.Tensor,\n output_tensor: torch.Tensor | None = None,\n mask: torch.Tensor | None = None,\n block_size: int = 64,\n col_offset: int = 0,\n ):\n if not isinstance(output_tensor, torch.Tensor):\n output_tensor = torch.ones(\n (*coords.shape[:-1], 1), dtype=coords.dtype, device=coords.device\n )\n ctx.save_for_backward(coords)\n coord_numel = coords.numel()\n output_numel = output_tensor.numel()\n num_blocks = calculate_lastdim_num_blocks(coords, block_size)\n zeroth_order_fwd[num_blocks,](\n coords,\n output_tensor,\n block_size,\n coord_numel,\n output_numel,\n col_offset,\n output_tensor.stride(-2),\n )\n return output_tensor\n\n @staticmethod\n def backward(\n ctx, sph_grad_tensor: torch.Tensor, block_size: int = 64, col_offset: int = 0\n ) -> torch.Tensor:\n (coords,) = ctx.saved_tensors\n coord_grad_output = torch.zeros_like(coords)\n num_blocks = calculate_lastdim_num_blocks(coords, block_size)\n # call backward kernel\n zeroth_order_bwd[num_blocks,](\n coord_grad_output,\n sph_grad_tensor,\n block_size,\n coords.numel(),\n sph_grad_tensor.numel(),\n col_offset,\n sph_grad_tensor.stride(-2),\n )\n return coord_grad_output\n", - "description_1": "Use triton language to implement two kernels: 'zeroth_order_fwd' and 'zeroth_order_bwd'. The 'zeroth_order_fwd' kernel takes 7 parameters: coord_ptr (tensor), output_ptr (tensor), block_size (constexpr), coord_numel (constexpr), output_numel (constexpr), col_offset (constexpr), and output_stride (constexpr). It calculates row offsets and stores a value of 1.0 in the output tensor. The 'zeroth_order_bwd' kernel takes 8 parameters: coord_ptr (tensor), coord_grad_ptr (tensor), sph_grad_ptr (tensor), block_size (constexpr), coord_numel (constexpr), output_numel (constexpr), col_offset (constexpr), and output_stride (constexpr). It calculates row offsets but does not perform any operations as there are no gradient contributions. The 'ZerothOrderSphericalHarmonic' class uses these kernels in its 'forward' and 'backward' methods to compute the zeroth order spherical harmonic and its gradient.", - "description_2": "Use triton language to create a forward kernel that initializes an output tensor with ones based on calculated row offsets, and a backward kernel that does not perform any operations. These kernels are used in a PyTorch autograd function to compute zeroth order spherical harmonics.", - "difficulty": 2 - }, - { - "code": "import triton\nimport torch\nfrom triton import language as tl\nfrom equitriton.utils import calculate_lastdim_num_blocks\n\nclass FirstOrderSphericalHarmonic(torch.autograd.Function):\n @staticmethod\n def forward(\n ctx,\n coords: torch.Tensor,\n mask: torch.Tensor | None = None,\n block_size: int = 64,\n col_offset: int = 0,\n ):\n output_tensor = torch.empty(\n (*coords.shape[:-1], 3), dtype=coords.dtype, device=coords.device\n )\n coord_numel = coords.numel()\n output_numel = output_tensor.numel()\n num_blocks = calculate_lastdim_num_blocks(coords, block_size)\n # apply the kernel\n first_order_fwd[num_blocks,](\n coords, output_tensor, block_size, coord_numel, output_numel, col_offset\n )\n ctx.save_for_backward(coords)\n return output_tensor\n\n @staticmethod\n def backward(\n ctx, sph_grad_tensor: torch.Tensor, block_size: int = 64, col_offset: int = 0\n ) -> torch.Tensor:\n (coords,) = ctx.saved_tensors\n coord_grad_output = torch.zeros_like(coords)\n num_blocks = calculate_lastdim_num_blocks(coords, block_size)\n # call backward kernel\n first_order_bwd[num_blocks,](\n coord_grad_output,\n sph_grad_tensor,\n block_size,\n coords.numel(),\n sph_grad_tensor.numel(),\n col_offset,\n )\n return coord_grad_output\n\n@triton.jit\ndef first_order_fwd(\n coord_ptr: tl.tensor,\n output_ptr: tl.tensor,\n block_size: tl.constexpr,\n coord_numel: tl.constexpr,\n output_numel: tl.constexpr,\n col_offset: tl.constexpr,\n output_stride: tl.constexpr,\n):\n # these are hardcoded because they are predetermined;\n coord_stride = 3\n # work out the row offsets\n block_id = tl.program_id(0)\n coord_striding = tl.arange(0, block_size) * coord_stride\n # as the name suggests, this is effectively every node/atom\n coord_row_offset = coord_striding + (block_size * coord_stride * block_id)\n x = tl.load(coord_ptr + coord_row_offset, mask=coord_row_offset < coord_numel)\n y = tl.load(\n coord_ptr + coord_row_offset + 1, mask=coord_row_offset + 1 < coord_numel\n )\n z = tl.load(\n coord_ptr + coord_row_offset + 2, mask=coord_row_offset + 2 < coord_numel\n )\n CONST_00 = tl.sqrt(3.0)\n Y10 = CONST_00 * x\n Y11 = CONST_00 * y\n Y12 = CONST_00 * z\n output_striding = tl.arange(0, block_size) * output_stride\n output_row_offset = (\n output_striding + (block_size * output_stride * block_id) + col_offset\n )\n tl.store(output_ptr + output_row_offset, Y10, mask=output_row_offset < output_numel)\n tl.store(\n output_ptr + output_row_offset + 1,\n Y11,\n mask=output_row_offset + 1 < output_numel,\n )\n tl.store(\n output_ptr + output_row_offset + 2,\n Y12,\n mask=output_row_offset + 2 < output_numel,\n )\n\n@triton.jit\ndef first_order_bwd(\n coord_ptr: tl.tensor, # noqa: F403\n coord_grad_ptr: tl.tensor,\n sph_grad_ptr: tl.tensor,\n block_size: tl.constexpr,\n coord_numel: tl.constexpr,\n output_numel: tl.constexpr,\n col_offset: tl.constexpr,\n output_stride: tl.constexpr,\n):\n # work out the row offsets\n block_id = tl.program_id(0)\n # these are hardcoded because they are predetermined;\n coord_stride = 3\n coord_striding = tl.arange(0, block_size) * coord_stride\n # as the name suggests, this is effectively every node/atom\n coord_row_offset = coord_striding + (block_size * coord_stride * block_id)\n output_striding = tl.arange(0, block_size) * output_stride\n output_row_offset = (\n output_striding + (block_size * output_stride * block_id) + col_offset\n )\n # load in gradients w.r.t. spherical harmonic projections\n g_Y10 = tl.load(\n sph_grad_ptr + output_row_offset, mask=output_row_offset < output_numel\n )\n g_Y11 = tl.load(\n sph_grad_ptr + output_row_offset + 1, mask=output_row_offset + 1 < output_numel\n )\n g_Y12 = tl.load(\n sph_grad_ptr + output_row_offset + 2, mask=output_row_offset + 2 < output_numel\n )\n # read in current gradients\n g_x = tl.load(\n coord_grad_ptr + coord_row_offset, mask=coord_row_offset < coord_numel\n )\n g_y = tl.load(\n coord_grad_ptr + coord_row_offset + 1, mask=coord_row_offset + 1 < coord_numel\n )\n g_z = tl.load(\n coord_grad_ptr + coord_row_offset + 2, mask=coord_row_offset + 2 < coord_numel\n )\n CONST_00 = tl.sqrt(3.0)\n g_x += CONST_00 * g_Y10\n g_y += CONST_00 * g_Y11\n g_z += CONST_00 * g_Y12\n # write out gradients\n tl.store(\n coord_grad_ptr + coord_row_offset, g_x, mask=coord_row_offset < coord_numel\n )\n tl.store(\n coord_grad_ptr + coord_row_offset + 1,\n g_y,\n mask=coord_row_offset + 1 < coord_numel,\n )\n tl.store(\n coord_grad_ptr + coord_row_offset + 2,\n g_z,\n mask=coord_row_offset + 2 < coord_numel,\n )\n", - "description_1": "Use triton language to implement a forward and backward kernel for computing first order spherical harmonics. The forward kernel 'first_order_fwd' takes 7 parameters: coord_ptr (input tensor), output_ptr (output tensor), block_size (size of each block), coord_numel (number of elements in input), output_numel (number of elements in output), col_offset (column offset), and output_stride (stride for output). It computes the spherical harmonics for each block of input coordinates. The backward kernel 'first_order_bwd' takes 8 parameters: coord_ptr (input tensor), coord_grad_ptr (gradient tensor for input), sph_grad_ptr (gradient tensor for spherical harmonics), block_size, coord_numel, output_numel, col_offset, and output_stride. It computes the gradient of the input coordinates based on the gradient of the spherical harmonics.", - "description_2": "Use triton language to create kernels for computing and backpropagating first order spherical harmonics, handling input and output tensors with specified block sizes and offsets.", - "difficulty": 3 - }, - { - "code": "import triton\nimport torch\nfrom triton import language as tl\nfrom equitriton.utils import calculate_lastdim_num_blocks\n\nclass TenthOrderSphericalHarmonic(torch.autograd.Function):\n @staticmethod\n def forward(\n ctx,\n coords: torch.Tensor,\n output_tensor: torch.Tensor | None = None,\n mask: torch.Tensor | None = None,\n block_size: int = 64,\n col_offset: int = 0,\n ):\n if not isinstance(output_tensor, torch.Tensor):\n output_tensor = torch.empty(\n (*coords.shape[:-1], 21), dtype=coords.dtype, device=coords.device\n )\n coord_numel = coords.numel()\n output_numel = output_tensor.numel()\n num_blocks = calculate_lastdim_num_blocks(coords, block_size)\n # apply the kernel\n tenth_order_fwd[num_blocks,](\n coords,\n output_tensor,\n block_size,\n coord_numel,\n output_numel,\n col_offset,\n output_tensor.stride(-2),\n )\n ctx.save_for_backward(coords)\n return output_tensor\n\n @staticmethod\n def backward(\n ctx,\n sph_grad_tensor: torch.Tensor,\n block_size: int = 64,\n col_offset: int = 0,\n ) -> torch.Tensor:\n (coords,) = ctx.saved_tensors\n coord_grad_output = torch.zeros_like(coords)\n num_blocks = calculate_lastdim_num_blocks(coords, block_size)\n # call backward kernel\n tenth_order_bwd[num_blocks,](\n coords,\n coord_grad_output,\n sph_grad_tensor,\n block_size,\n coords.numel(),\n sph_grad_tensor.numel(),\n col_offset,\n sph_grad_tensor.stride(-2),\n )\n return coord_grad_output\n\n@triton.jit\ndef tenth_order_fwd(\n coord_ptr: tl.tensor,\n output_ptr: tl.tensor,\n block_size: tl.constexpr,\n coord_numel: tl.constexpr,\n output_numel: tl.constexpr,\n col_offset: tl.constexpr,\n output_stride: tl.constexpr,\n):\n coord_stride = 3\n block_id = tl.program_id(0)\n coord_striding = tl.arange(0, block_size) * coord_stride\n coord_row_offset = coord_striding + (block_size * coord_stride * block_id)\n x = tl.load(coord_ptr + coord_row_offset, mask=coord_row_offset < coord_numel)\n y = tl.load(\n coord_ptr + coord_row_offset + 1, mask=coord_row_offset + 1 < coord_numel\n )\n z = tl.load(\n coord_ptr + coord_row_offset + 2, mask=coord_row_offset + 2 < coord_numel\n )\n # -------------------- kernel implementations\n Y00 = (\n 27.2034486491732 * x**5 * x * z\n + 27.2034486491732 * z**5 * z * x\n + 685.526905959165 * x**5 * z**5\n - 326.441383790078 * x**3 * z**3 * z\n - 326.441383790078 * x**3 * z * z**3\n )\n # This section is shortened for brevity. The actual kernel contains many more constants and calculations.\n output_striding = tl.arange(0, block_size) * output_stride\n output_row_offset = (\n output_striding + (block_size * output_stride * block_id) + col_offset\n )\n tl.store(output_ptr + output_row_offset, Y00, mask=output_row_offset < output_numel)\n\n@triton.jit\ndef tenth_order_bwd(\n coord_ptr: tl.tensor,\n coord_grad_ptr: tl.tensor,\n sph_grad_ptr: tl.tensor,\n block_size: tl.constexpr,\n coord_numel: tl.constexpr,\n output_numel: tl.constexpr,\n col_offset: tl.constexpr,\n output_stride: tl.constexpr,\n):\n block_id = tl.program_id(0)\n coord_stride = 3\n coord_striding = tl.arange(0, block_size) * coord_stride\n coord_row_offset = coord_striding + (block_size * coord_stride * block_id)\n x = tl.load(coord_ptr + coord_row_offset, mask=coord_row_offset < coord_numel)\n y = tl.load(\n coord_ptr + coord_row_offset + 1, mask=coord_row_offset + 1 < coord_numel\n )\n z = tl.load(\n coord_ptr + coord_row_offset + 2, mask=coord_row_offset + 2 < coord_numel\n )\n output_striding = tl.arange(0, block_size) * output_stride\n output_row_offset = (\n output_striding + (block_size * output_stride * block_id) + col_offset\n )\n g_0 = tl.load(\n sph_grad_ptr + output_row_offset, mask=output_row_offset < output_numel\n )\n # Load gradients\n # -------------------- kernel implementations\n g_x = tl.load(\n coord_grad_ptr + coord_row_offset, mask=coord_row_offset < coord_numel\n )\n g_y = tl.load(\n coord_grad_ptr + coord_row_offset + 1, mask=coord_row_offset + 1 < coord_numel\n )\n g_z = tl.load(\n coord_grad_ptr + coord_row_offset + 2, mask=coord_row_offset + 2 < coord_numel\n )\n g_x += (\n g_0\n * (\n 225.548647486108 * x**4 * x * z\n - 979.324151370235 * x**2 * z**3 * z\n + 3427.63452979582 * x**3 * z**5\n + 3862.96644634988 * x**4 * z**4\n - 27.2034486491732 * z**5\n )\n )\n # Write out gradients\n tl.store(\n coord_grad_ptr + coord_row_offset, g_x, mask=coord_row_offset < coord_numel\n )\n tl.store(\n coord_grad_ptr + coord_row_offset + 1,\n g_y,\n mask=coord_row_offset + 1 < coord_numel,\n )\n tl.store(\n coord_grad_ptr + coord_row_offset + 2,\n g_z,\n mask=coord_row_offset + 2 < coord_numel,\n )\n", - "description_1": "Use triton language to implement forward and backward kernels for a spherical harmonic transformation. The forward kernel takes 7 parameters: coord_ptr (pointer to input coordinates), output_ptr (pointer to output), block_size (size of the block), coord_numel (total number of coordinates), output_numel (total number of output elements), col_offset (offset for columns), and output_stride (stride for output storage). It calculates transformations for spherical harmonics up to the tenth order using tensor computations. The backward kernel similarly takes 8 parameters: coord_ptr, coord_grad_ptr (pointer for coordinate gradients), sph_grad_ptr (pointer for spherical harmonic gradients), block_size, coord_numel, output_numel, col_offset, and output_stride, to compute gradient updates for input coordinates.", - "description_2": "Use triton language to create and apply spherical harmonics transformation kernels for both forward and backward passes with complex polynomial calculations.", - "difficulty": 3 - }, - { - "code": "import triton\nimport torch\nfrom triton import language as tl\nfrom equitriton.utils import calculate_lastdim_num_blocks\n\n@triton.jit\ndef second_order_fwd(\n coord_ptr: tl.tensor,\n output_ptr: tl.tensor,\n block_size: tl.constexpr,\n coord_numel: tl.constexpr,\n output_numel: tl.constexpr,\n col_offset: tl.constexpr,\n output_stride: tl.constexpr,\n):\n # these are hardcoded because they are predetermined;\n coord_stride = 3\n # work out the row offsets\n block_id = tl.program_id(0)\n coord_striding = tl.arange(0, block_size) * coord_stride\n # as the name suggests, this is effectively every node/atom\n coord_row_offset = coord_striding + (block_size * coord_stride * block_id)\n x = tl.load(coord_ptr + coord_row_offset, mask=coord_row_offset < coord_numel)\n y = tl.load(\n coord_ptr + coord_row_offset + 1, mask=coord_row_offset + 1 < coord_numel\n )\n z = tl.load(\n coord_ptr + coord_row_offset + 2, mask=coord_row_offset + 2 < coord_numel\n )\n CONST_00 = 3.87298334620742\n CONST_01 = 2.23606797749979\n CONST_02 = -1.11803398874989\n CONST_03 = 1.93649167310371\n Y20 = CONST_00 * x * z\n Y21 = CONST_00 * x * y\n Y23 = CONST_00 * y * z # looks jarring but just helping the compiler ;)\n Y22 = CONST_02 * x * x + CONST_01 * y * y + CONST_02 * z * z\n Y24 = -CONST_03 * x * x + CONST_03 * z * z\n output_striding = tl.arange(0, block_size) * output_stride\n output_row_offset = (\n output_striding + (block_size * output_stride * block_id) + col_offset\n )\n tl.store(output_ptr + output_row_offset, Y20, mask=output_row_offset < output_numel)\n tl.store(\n output_ptr + output_row_offset + 1,\n Y21,\n mask=output_row_offset + 1 < output_numel,\n )\n tl.store(\n output_ptr + output_row_offset + 2,\n Y22,\n mask=output_row_offset + 2 < output_numel,\n )\n tl.store(\n output_ptr + output_row_offset + 3,\n Y23,\n mask=output_row_offset + 3 < output_numel,\n )\n tl.store(\n output_ptr + output_row_offset + 4,\n Y24,\n mask=output_row_offset + 4 < output_numel,\n )\n\n@triton.jit\ndef second_order_bwd(\n coord_ptr: tl.tensor,\n coord_grad_ptr: tl.tensor,\n sph_grad_ptr: tl.tensor,\n block_size: tl.constexpr,\n coord_numel: tl.constexpr,\n output_numel: tl.constexpr,\n col_offset: tl.constexpr,\n output_stride: tl.constexpr,\n):\n # work out the row offsets\n block_id = tl.program_id(0)\n # these are hardcoded because they are predetermined;\n coord_stride = 3\n coord_striding = tl.arange(0, block_size) * coord_stride\n # as the name suggests, this is effectively every node/atom\n coord_row_offset = coord_striding + (block_size * coord_stride * block_id)\n x = tl.load(coord_ptr + coord_row_offset, mask=coord_row_offset < coord_numel)\n y = tl.load(\n coord_ptr + coord_row_offset + 1, mask=coord_row_offset + 1 < coord_numel\n )\n z = tl.load(\n coord_ptr + coord_row_offset + 2, mask=coord_row_offset + 2 < coord_numel\n )\n output_striding = tl.arange(0, block_size) * output_stride\n output_row_offset = (\n output_striding + (block_size * output_stride * block_id) + col_offset\n )\n CONST_00 = 3.87298334620742\n CONST_01 = 2.23606797749979\n CONST_02 = 4.47213595499958\n # load in gradients w.r.t. spherical harmonic projections\n g_Y20 = tl.load(\n sph_grad_ptr + output_row_offset, mask=output_row_offset < output_numel\n )\n g_Y21 = tl.load(\n sph_grad_ptr + output_row_offset + 1, mask=output_row_offset + 1 < output_numel\n )\n g_Y22 = tl.load(\n sph_grad_ptr + output_row_offset + 2, mask=output_row_offset + 2 < output_numel\n )\n g_Y23 = tl.load(\n sph_grad_ptr + output_row_offset + 3, mask=output_row_offset + 3 < output_numel\n )\n g_Y24 = tl.load(\n sph_grad_ptr + output_row_offset + 4, mask=output_row_offset + 4 < output_numel\n )\n g_x = tl.load(\n coord_grad_ptr + coord_row_offset, mask=coord_row_offset < coord_numel\n )\n g_y = tl.load(\n coord_grad_ptr + coord_row_offset + 1, mask=coord_row_offset + 1 < coord_numel\n )\n g_z = tl.load(\n coord_grad_ptr + coord_row_offset + 2, mask=coord_row_offset + 2 < coord_numel\n )\n g_x += (\n CONST_00 * g_Y20 * z\n + CONST_00 * g_Y21 * y\n - CONST_01 * g_Y22 * x\n - CONST_00 * g_Y24 * x\n )\n g_y += CONST_00 * g_Y21 * x + CONST_02 * g_Y22 * y + CONST_00 * g_Y23 * z\n g_z += (\n CONST_00 * g_Y20 * x\n - CONST_01 * g_Y22 * z\n + CONST_00 * g_Y23 * y\n + CONST_00 * g_Y24 * z\n )\n # write out gradients\n tl.store(\n coord_grad_ptr + coord_row_offset, g_x, mask=coord_row_offset < coord_numel\n )\n tl.store(\n coord_grad_ptr + coord_row_offset + 1,\n g_y,\n mask=coord_row_offset + 1 < coord_numel,\n )\n tl.store(\n coord_grad_ptr + coord_row_offset + 2,\n g_z,\n mask=coord_row_offset + 2 < coord_numel,\n )\n\nclass SecondOrderSphericalHarmonic(torch.autograd.Function):\n @staticmethod\n def forward(\n ctx,\n coords: torch.Tensor,\n output_tensor: torch.Tensor | None = None,\n mask: torch.Tensor | None = None,\n block_size: int = 64,\n col_offset: int = 0,\n ):\n num_projections = 5 # 2l + 1\n # allocate a tensor if one isn't given\n if not isinstance(output_tensor, torch.Tensor):\n output_tensor = torch.empty(\n (*coords.shape[:-1], num_projections),\n dtype=coords.dtype,\n device=coords.device,\n )\n coord_numel = coords.numel()\n output_numel = output_tensor.numel()\n num_blocks = calculate_lastdim_num_blocks(coords, block_size)\n # apply the kernel\n second_order_fwd[num_blocks,](\n coords,\n output_tensor,\n block_size,\n coord_numel,\n output_numel,\n col_offset,\n output_tensor.stride(-2),\n )\n ctx.save_for_backward(coords)\n return output_tensor\n\n @staticmethod\n def backward(\n ctx,\n sph_grad_tensor: torch.Tensor,\n block_size: int = 64,\n col_offset: int = 0,\n ) -> torch.Tensor:\n (coords,) = ctx.saved_tensors\n coord_grad_output = torch.zeros_like(coords)\n num_blocks = calculate_lastdim_num_blocks(coords, block_size)\n # call backward kernel\n second_order_bwd[num_blocks,](\n coords,\n coord_grad_output,\n sph_grad_tensor,\n block_size,\n coords.numel(),\n sph_grad_tensor.numel(),\n col_offset,\n sph_grad_tensor.stride(-2),\n )\n return coord_grad_output\n", - "description_1": "Use triton language to implement two kernels: 'second_order_fwd' and 'second_order_bwd'. The 'second_order_fwd' kernel computes the second order spherical harmonics for a given set of coordinates. It takes 7 parameters: coord_ptr (input coordinates), output_ptr (output tensor), block_size (size of each block), coord_numel (number of elements in coordinates), output_numel (number of elements in output), col_offset (column offset), and output_stride (stride of the output tensor). The 'second_order_bwd' kernel computes the gradient of the input coordinates with respect to the spherical harmonics. It takes 8 parameters: coord_ptr (input coordinates), coord_grad_ptr (gradient of coordinates), sph_grad_ptr (gradient of spherical harmonics), block_size, coord_numel, output_numel, col_offset, and output_stride.", - "description_2": "Use triton language to create kernels for computing second order spherical harmonics and their gradients. The forward kernel calculates the harmonics based on input coordinates, while the backward kernel computes the gradients of these coordinates.", - "difficulty": 3 - }, - { - "code": "import triton\nimport torch\nfrom triton import language as tl\nfrom equitriton.utils import calculate_lastdim_num_blocks\n\nclass ThirdOrderSphericalHarmonic(torch.autograd.Function):\n @staticmethod\n def forward(\n ctx,\n coords: torch.Tensor,\n output_tensor: torch.Tensor | None = None,\n mask: torch.Tensor | None = None,\n block_size: int = 64,\n col_offset: int = 0,\n ):\n if not isinstance(output_tensor, torch.Tensor):\n output_tensor = torch.empty(\n (*coords.shape[:-1], 7), dtype=coords.dtype, device=coords.device\n )\n coord_numel = coords.numel()\n output_numel = output_tensor.numel()\n num_blocks = calculate_lastdim_num_blocks(coords, block_size)\n third_order_fwd[num_blocks,](\n coords,\n output_tensor,\n block_size,\n coord_numel,\n output_numel,\n col_offset,\n output_tensor.stride(-2),\n )\n ctx.save_for_backward(coords)\n return output_tensor\n\n @staticmethod\n def backward(\n ctx,\n sph_grad_tensor: torch.Tensor,\n coord_grad_output: torch.Tensor | None = None,\n block_size: int = 64,\n col_offset: int = 0,\n ) -> torch.Tensor:\n (coords,) = ctx.saved_tensors\n if not isinstance(coord_grad_output, torch.Tensor):\n coord_grad_output = torch.zeros_like(coords)\n num_blocks = calculate_lastdim_num_blocks(coords, block_size)\n third_order_bwd[num_blocks,](\n coords,\n coord_grad_output,\n sph_grad_tensor,\n block_size,\n coords.numel(),\n sph_grad_tensor.numel(),\n col_offset,\n sph_grad_tensor.stride(-2),\n )\n return coord_grad_output\n\n@triton.jit\ndef third_order_fwd(\n coord_ptr: tl.tensor,\n output_ptr: tl.tensor,\n block_size: tl.constexpr,\n coord_numel: tl.constexpr,\n output_numel: tl.constexpr,\n col_offset: tl.constexpr,\n output_stride: tl.constexpr,\n):\n coord_stride = 3\n block_id = tl.program_id(0)\n coord_striding = tl.arange(0, block_size) * coord_stride\n coord_row_offset = coord_striding + (block_size * coord_stride * block_id)\n x = tl.load(coord_ptr + coord_row_offset, mask=coord_row_offset < coord_numel)\n y = tl.load(\n coord_ptr + coord_row_offset + 1, mask=coord_row_offset + 1 < coord_numel\n )\n z = tl.load(\n coord_ptr + coord_row_offset + 2, mask=coord_row_offset + 2 < coord_numel\n )\n CONST000 = 2.64575131106459\n CONST002 = 5.12347538297980\n CONST004 = 6.48074069840786\n CONST005 = 10.2469507659596\n CONST006 = -2.09165006633519\n CONST007 = -1\n CONST008 = -6.27495019900557\n CONST009 = -3.96862696659689\n CONST010 = -1.62018517460197\n VAR07 = x * x * x\n VAR08 = x * x\n VAR16 = y * y * y\n VAR17 = y * y\n VAR25 = z * z * z\n VAR26 = z * z\n Y00 = CONST006 * VAR07 - CONST008 * VAR26 * x\n Y01 = CONST005 * x * y * z\n Y02 = CONST010 * VAR07 + x * (CONST004 * VAR17 + CONST010 * VAR26)\n Y03 = CONST000 * VAR16 + CONST009 * VAR08 * y + CONST009 * VAR26 * y\n Y04 = CONST010 * VAR25 + z * (CONST004 * VAR17 + CONST010 * VAR08)\n Y05 = CONST002 * y * (CONST007 * VAR08 + VAR26)\n Y06 = -CONST006 * VAR25 + CONST008 * VAR08 * z\n output_striding = tl.arange(0, block_size) * output_stride\n output_row_offset = (\n output_striding + (block_size * output_stride * block_id) + col_offset\n )\n tl.store(output_ptr + output_row_offset, Y00, mask=output_row_offset < output_numel)\n tl.store(\n output_ptr + output_row_offset + 1,\n Y01,\n mask=output_row_offset + 1 < output_numel,\n )\n tl.store(\n output_ptr + output_row_offset + 2,\n Y02,\n mask=output_row_offset + 2 < output_numel,\n )\n tl.store(\n output_ptr + output_row_offset + 3,\n Y03,\n mask=output_row_offset + 3 < output_numel,\n )\n tl.store(\n output_ptr + output_row_offset + 4,\n Y04,\n mask=output_row_offset + 4 < output_numel,\n )\n tl.store(\n output_ptr + output_row_offset + 5,\n Y05,\n mask=output_row_offset + 5 < output_numel,\n )\n tl.store(\n output_ptr + output_row_offset + 6,\n Y06,\n mask=output_row_offset + 6 < output_numel,\n )\n\n@triton.jit\ndef third_order_bwd(\n coord_ptr: tl.tensor,\n coord_grad_ptr: tl.tensor,\n sph_grad_ptr: tl.tensor,\n block_size: tl.constexpr,\n coord_numel: tl.constexpr,\n output_numel: tl.constexpr,\n col_offset: tl.constexpr,\n output_stride: tl.constexpr,\n):\n block_id = tl.program_id(0)\n coord_stride = 3\n coord_striding = tl.arange(0, block_size) * coord_stride\n coord_row_offset = coord_striding + (block_size * coord_stride * block_id)\n x = tl.load(coord_ptr + coord_row_offset, mask=coord_row_offset < coord_numel)\n y = tl.load(\n coord_ptr + coord_row_offset + 1, mask=coord_row_offset + 1 < coord_numel\n )\n z = tl.load(\n coord_ptr + coord_row_offset + 2, mask=coord_row_offset + 2 < coord_numel\n )\n output_striding = tl.arange(0, block_size) * output_stride\n output_row_offset = (\n output_striding + (block_size * output_stride * block_id) + col_offset\n )\n g_0 = tl.load(\n sph_grad_ptr + output_row_offset, mask=output_row_offset < output_numel\n )\n g_1 = tl.load(\n sph_grad_ptr + output_row_offset + 1, mask=output_row_offset + 1 < output_numel\n )\n g_2 = tl.load(\n sph_grad_ptr + output_row_offset + 2, mask=output_row_offset + 2 < output_numel\n )\n g_3 = tl.load(\n sph_grad_ptr + output_row_offset + 3, mask=output_row_offset + 3 < output_numel\n )\n g_4 = tl.load(\n sph_grad_ptr + output_row_offset + 4, mask=output_row_offset + 4 < output_numel\n )\n g_5 = tl.load(\n sph_grad_ptr + output_row_offset + 5, mask=output_row_offset + 5 < output_numel\n )\n g_6 = tl.load(\n sph_grad_ptr + output_row_offset + 6, mask=output_row_offset + 6 < output_numel\n )\n CONST002 = 6.48074069840786\n CONST005 = 12.9614813968157\n CONST007 = -3.96862696659689\n CONST008 = -12.5499003980111\n CONST009 = -10.2469507659596\n CONST010 = -7.93725393319377\n CONST011 = -6.27495019900557\n CONST012 = -5.12347538297980\n CONST013 = -4.86055552380590\n CONST014 = -3.24037034920393\n CONST015 = -1.62018517460197\n VAR08 = x * x\n VAR17 = y * y\n VAR26 = z * z\n g_x = tl.load(\n coord_grad_ptr + coord_row_offset, mask=coord_row_offset < coord_numel\n )\n g_y = tl.load(\n coord_grad_ptr + coord_row_offset + 1, mask=coord_row_offset + 1 < coord_numel\n )\n g_z = tl.load(\n coord_grad_ptr + coord_row_offset + 2, mask=coord_row_offset + 2 < coord_numel\n )\n g_x += (\n CONST008 * g_6 * x * z\n - CONST009 * g_1 * y * z\n + CONST009 * g_5 * x * y\n + CONST010 * g_3 * x * y\n + CONST014 * g_4 * x * z\n + g_0 * (CONST011 * VAR08 - CONST011 * VAR26)\n + g_2 * (CONST002 * VAR17 + CONST013 * VAR08 + CONST015 * VAR26)\n )\n g_y += (\n CONST005 * g_2 * x * y\n + CONST005 * g_4 * y * z\n - CONST009 * g_1 * x * z\n + g_3 * (CONST007 * VAR08 + CONST007 * VAR26 - CONST010 * VAR17)\n + g_5 * (CONST012 * VAR08 - CONST012 * VAR26)\n )\n g_z += (\n -CONST008 * g_0 * x * z\n - CONST009 * g_1 * x * y\n - CONST009 * g_5 * y * z\n + CONST010 * g_3 * y * z\n + CONST014 * g_2 * x * z\n + g_4 * (CONST002 * VAR17 + CONST013 * VAR26 + CONST015 * VAR08)\n + g_6 * (CONST011 * VAR08 - CONST011 * VAR26)\n )\n tl.store(\n coord_grad_ptr + coord_row_offset, g_x, mask=coord_row_offset < coord_numel\n )\n tl.store(\n coord_grad_ptr + coord_row_offset + 1,\n g_y,\n mask=coord_row_offset + 1 < coord_numel,\n )\n tl.store(\n coord_grad_ptr + coord_row_offset + 2,\n g_z,\n mask=coord_row_offset + 2 < coord_numel,\n )\n", - "description_1": "Use triton language to implement two kernels: 'third_order_fwd' and 'third_order_bwd'. The 'third_order_fwd' kernel computes the third-order spherical harmonics for a given set of coordinates. It takes 7 parameters: coord_ptr (input coordinates), output_ptr (output tensor), block_size (size of each block), coord_numel (number of elements in coordinates), output_numel (number of elements in output), col_offset (column offset), and output_stride (stride of the output tensor). The 'third_order_bwd' kernel computes the gradient of the spherical harmonics with respect to the input coordinates. It takes the same 7 parameters as 'third_order_fwd', with the addition of coord_grad_ptr (gradient of coordinates) and sph_grad_ptr (gradient of spherical harmonics).", - "description_2": "Use triton language to create kernels for computing third-order spherical harmonics and their gradients, with parameters for input/output tensors, block size, and strides.", - "difficulty": 4 - }, - { - "code": "import triton\nfrom triton import language as tl\n\n@triton.jit\ndef fourth_order_fwd(\n coord_ptr: tl.tensor,\n output_ptr: tl.tensor,\n block_size: tl.constexpr,\n coord_numel: tl.constexpr,\n output_numel: tl.constexpr,\n col_offset: tl.constexpr,\n output_stride: tl.constexpr,\n):\n coord_stride = 3\n block_id = tl.program_id(0)\n coord_striding = tl.arange(0, block_size) * coord_stride\n coord_row_offset = coord_striding + (block_size * coord_stride * block_id)\n x = tl.load(coord_ptr + coord_row_offset, mask=coord_row_offset < coord_numel)\n y = tl.load(\n coord_ptr + coord_row_offset + 1, mask=coord_row_offset + 1 < coord_numel\n )\n z = tl.load(\n coord_ptr + coord_row_offset + 2, mask=coord_row_offset + 2 < coord_numel\n )\n CONST000 = 1.12500000000000\n CONST001 = 2.25000000000000\n CONST002 = 3.00000000000000\n CONST005 = 2.21852991866236\n CONST007 = 9.48683298050514\n CONST010 = 20.1246117974981\n CONST011 = -18.8248505970167\n CONST012 = -13.3111795119741\n CONST013 = -10.0623058987491\n CONST014 = -9.00000000000000\n CONST015 = -8.87411967464942\n CONST016 = -7.11512473537885\n CONST017 = -6.27495019900557\n CONST018 = -3.35410196624968\n CONST019 = -1.67705098312484\n VAR06 = x * x * x * x\n VAR07 = x * x * x\n VAR08 = x * x\n VAR15 = y * y * y * y\n VAR16 = y * y * y\n VAR17 = y * y\n VAR24 = z * z * z * z\n VAR25 = z * z * z\n VAR26 = z * z\n Y00 = CONST015 * VAR07 * z - CONST015 * VAR25 * x\n Y01 = y * (-CONST011 * VAR26 * x + CONST017 * VAR07)\n Y02 = CONST018 * VAR07 * z + x * (CONST010 * VAR17 * z + CONST018 * VAR25)\n Y03 = CONST016 * VAR07 * y + x * (CONST007 * VAR16 + CONST016 * VAR26 * y)\n Y04 = (\n CONST000 * VAR06\n + CONST000 * VAR24\n + CONST002 * VAR15\n + CONST014 * VAR17 * VAR26\n + VAR08 * (CONST001 * VAR26 + CONST014 * VAR17)\n )\n Y05 = CONST016 * VAR25 * y + z * (CONST007 * VAR16 + CONST016 * VAR08 * y)\n Y06 = (\n -CONST019 * VAR06\n + CONST019 * VAR24\n + VAR17 * (CONST013 * VAR08 - CONST013 * VAR26)\n )\n Y07 = y * (CONST011 * VAR08 * z - CONST017 * VAR25)\n Y08 = CONST005 * VAR06 + CONST005 * VAR24 + CONST012 * VAR08 * VAR26\n output_striding = tl.arange(0, block_size) * output_stride\n output_row_offset = (\n output_striding + (block_size * output_stride * block_id) + col_offset\n )\n tl.store(output_ptr + output_row_offset, Y00, mask=output_row_offset < output_numel)\n tl.store(\n output_ptr + output_row_offset + 1,\n Y01,\n mask=output_row_offset + 1 < output_numel,\n )\n tl.store(\n output_ptr + output_row_offset + 2,\n Y02,\n mask=output_row_offset + 2 < output_numel,\n )\n tl.store(\n output_ptr + output_row_offset + 3,\n Y03,\n mask=output_row_offset + 3 < output_numel,\n )\n tl.store(\n output_ptr + output_row_offset + 4,\n Y04,\n mask=output_row_offset + 4 < output_numel,\n )\n tl.store(\n output_ptr + output_row_offset + 5,\n Y05,\n mask=output_row_offset + 5 < output_numel,\n )\n tl.store(\n output_ptr + output_row_offset + 6,\n Y06,\n mask=output_row_offset + 6 < output_numel,\n )\n tl.store(\n output_ptr + output_row_offset + 7,\n Y07,\n mask=output_row_offset + 7 < output_numel,\n )\n tl.store(\n output_ptr + output_row_offset + 8,\n Y08,\n mask=output_row_offset + 8 < output_numel,\n )\n\n@triton.jit\ndef fourth_order_bwd(\n coord_ptr: tl.tensor,\n coord_grad_ptr: tl.tensor,\n sph_grad_ptr: tl.tensor,\n block_size: tl.constexpr,\n coord_numel: tl.constexpr,\n output_numel: tl.constexpr,\n col_offset: tl.constexpr,\n output_stride: tl.constexpr,\n):\n block_id = tl.program_id(0)\n coord_stride = 3\n coord_striding = tl.arange(0, block_size) * coord_stride\n coord_row_offset = coord_striding + (block_size * coord_stride * block_id)\n x = tl.load(coord_ptr + coord_row_offset, mask=coord_row_offset < coord_numel)\n y = tl.load(\n coord_ptr + coord_row_offset + 1, mask=coord_row_offset + 1 < coord_numel\n )\n z = tl.load(\n coord_ptr + coord_row_offset + 2, mask=coord_row_offset + 2 < coord_numel\n )\n output_striding = tl.arange(0, block_size) * output_stride\n output_row_offset = (\n output_striding + (block_size * output_stride * block_id) + col_offset\n )\n g_0 = tl.load(\n sph_grad_ptr + output_row_offset, mask=output_row_offset < output_numel\n )\n g_1 = tl.load(\n sph_grad_ptr + output_row_offset + 1, mask=output_row_offset + 1 < output_numel\n )\n g_2 = tl.load(\n sph_grad_ptr + output_row_offset + 2, mask=output_row_offset + 2 < output_numel\n )\n g_3 = tl.load(\n sph_grad_ptr + output_row_offset + 3, mask=output_row_offset + 3 < output_numel\n )\n g_4 = tl.load(\n sph_grad_ptr + output_row_offset + 4, mask=output_row_offset + 4 < output_numel\n )\n g_5 = tl.load(\n sph_grad_ptr + output_row_offset + 5, mask=output_row_offset + 5 < output_numel\n )\n g_6 = tl.load(\n sph_grad_ptr + output_row_offset + 6, mask=output_row_offset + 6 < output_numel\n )\n g_7 = tl.load(\n sph_grad_ptr + output_row_offset + 7, mask=output_row_offset + 7 < output_numel\n )\n g_8 = tl.load(\n sph_grad_ptr + output_row_offset + 8, mask=output_row_offset + 8 < output_numel\n )\n CONST000 = 2.00000000000000\n CONST001 = 4.50000000000000\n CONST002 = 2.25000000000000\n CONST006 = 9.48683298050514\n CONST008 = 12.0000000000000\n CONST012 = 28.4604989415154\n CONST014 = 40.2492235949962\n CONST015 = -37.6497011940334\n CONST016 = -6.70820393249937\n CONST017 = -26.6223590239483\n CONST018 = -21.3453742061366\n CONST019 = -20.1246117974981\n CONST020 = -18.8248505970167\n CONST021 = -18.0000000000000\n CONST022 = -14.2302494707577\n CONST023 = -10.0623058987491\n CONST024 = -9.00000000000000\n CONST025 = -8.87411967464942\n CONST026 = -7.11512473537885\n CONST027 = -6.27495019900557\n CONST028 = -3.35410196624968\n VAR07 = x * x * x\n VAR08 = x * x\n VAR16 = y * y * y\n VAR17 = y * y\n VAR25 = z * z * z\n VAR26 = z * z\n g_x = tl.load(\n coord_grad_ptr + coord_row_offset, mask=coord_row_offset < coord_numel\n )\n g_y = tl.load(\n coord_grad_ptr + coord_row_offset + 1, mask=coord_row_offset + 1 < coord_numel\n )\n g_z = tl.load(\n coord_grad_ptr + coord_row_offset + 2, mask=coord_row_offset + 2 < coord_numel\n )\n g_x += (\n CONST015 * g_7 * x * y * z\n + CONST022 * g_5 * x * y * z\n + g_0 * (CONST017 * VAR08 * z - CONST025 * VAR25)\n + g_1 * y * (CONST020 * VAR08 - CONST020 * VAR26)\n + g_2 * (-CONST019 * VAR17 * z + CONST023 * VAR08 * z + CONST028 * VAR25)\n + g_3 * (CONST006 * VAR16 + CONST018 * VAR08 * y + CONST026 * VAR26 * y)\n + g_4\n * (CONST000 * x * (CONST002 * VAR26 + CONST024 * VAR17) + CONST001 * VAR07)\n + g_6 * (-CONST016 * VAR07 + CONST019 * VAR17 * x)\n + g_8 * (CONST017 * VAR26 * x - CONST025 * VAR07)\n )\n g_y += (\n CONST000 * g_6 * y * (CONST023 * VAR08 - CONST023 * VAR26)\n + CONST014 * g_2 * x * y * z\n + g_1 * (-CONST020 * VAR26 * x + CONST027 * VAR07)\n + g_3 * (CONST026 * VAR07 + x * (CONST012 * VAR17 + CONST026 * VAR26))\n + g_4 * (CONST008 * VAR16 + CONST021 * VAR08 * y + CONST021 * VAR26 * y)\n + g_5 * (CONST026 * VAR25 + z * (CONST012 * VAR17 + CONST026 * VAR08))\n + g_7 * (CONST020 * VAR08 * z - CONST027 * VAR25)\n )\n g_z += (\n -CONST015 * g_1 * x * y * z\n + CONST022 * g_3 * x * y * z\n + g_0 * (-CONST017 * VAR26 * x + CONST025 * VAR07)\n + g_2 * (CONST028 * VAR07 + x * (-CONST019 * VAR17 + CONST023 * VAR26))\n + g_4 * (CONST001 * VAR08 * z + CONST001 * VAR25 + CONST021 * VAR17 * z)\n + g_5 * (CONST006 * VAR16 + CONST018 * VAR26 * y + CONST026 * VAR08 * y)\n + g_6 * (CONST016 * VAR25 - CONST019 * VAR17 * z)\n + g_7 * y * (CONST020 * VAR08 - CONST020 * VAR26)\n + g_8 * (CONST017 * VAR08 * z - CONST025 * VAR25)\n )\n tl.store(\n coord_grad_ptr + coord_row_offset, g_x, mask=coord_row_offset < coord_numel\n )\n tl.store(\n coord_grad_ptr + coord_row_offset + 1,\n g_y,\n mask=coord_row_offset + 1 < coord_numel,\n )\n tl.store(\n coord_grad_ptr + coord_row_offset + 2,\n g_z,\n mask=coord_row_offset + 2 < coord_numel,\n )\n", - "description_1": "Use triton language to implement two kernels: (1) 'fourth_order_fwd' computes fourth-order spherical harmonics projections from input coordinates. It takes seven parameters: coord_ptr (input coordinates tensor), output_ptr (output tensor for spherical harmonics), block_size (size of each block for parallel computation), coord_numel (number of elements in the coordinates tensor), output_numel (number of elements in the output tensor), col_offset (column offset for output storing), and output_stride (stride of the output tensor). (2) 'fourth_order_bwd' computes the gradient of the coordinates with respect to the spherical harmonics projections. It takes the same parameters as 'fourth_order_fwd', plus coord_grad_ptr (gradient of coordinates tensor) and sph_grad_ptr (gradient of the spherical harmonics projections tensor).", - "description_2": "Use triton language to implement forward and backward kernels for computing fourth-order spherical harmonics and their gradients, utilizing parallel computation with block sizes, strides, and offsets.", - "difficulty": 2 - }, - { - "code": "import triton\nimport torch\nfrom triton import language as tl\nfrom equitriton.utils import calculate_lastdim_num_blocks\n\n@triton.jit\ndef sixth_order_fwd(\n coord_ptr: tl.tensor,\n output_ptr: tl.tensor,\n block_size: tl.constexpr,\n coord_numel: tl.constexpr,\n output_numel: tl.constexpr,\n col_offset: tl.constexpr,\n output_stride: tl.constexpr,\n):\n coord_stride = 3\n block_id = tl.program_id(0)\n coord_striding = tl.arange(0, block_size) * coord_stride\n coord_row_offset = coord_striding + (block_size * coord_stride * block_id)\n x = tl.load(coord_ptr + coord_row_offset, mask=coord_row_offset < coord_numel)\n y = tl.load(\n coord_ptr + coord_row_offset + 1, mask=coord_row_offset + 1 < coord_numel\n )\n z = tl.load(\n coord_ptr + coord_row_offset + 2, mask=coord_row_offset + 2 < coord_numel\n )\n CONST002 = 3.26558761940328\n CONST003 = 3.26558761940328\n CONST004 = 6.53117523880657\n CONST006 = 8.38944649544891\n CONST007 = 9.79676285820985\n CONST008 = 10.3266947761614\n CONST009 = 3.60555127546399\n CONST010 = -1.78863600265677\n CONST011 = 14.5309475774982\n CONST012 = 8.94318001328386\n CONST013 = 16.5227116418583\n CONST014 = 16.5227116418583\n CONST015 = 17.8863600265677\n CONST017 = 20.6533895523229\n CONST018 = 20.2812259244849\n CONST019 = -107.318160159406\n CONST020 = 17.8863600265677\n CONST022 = 29.3902885746295\n CONST024 = 40.5624518489699\n CONST025 = 41.9472324772445\n CONST026 = -1.63279380970164\n CONST027 = -83.8944649544891\n CONST028 = -78.3741028656788\n CONST030 = -71.5454401062709\n CONST032 = -52.2494019104525\n CONST033 = -52.2494019104525\n CONST035 = -48.4364919249939\n CONST036 = -41.3067791046458\n CONST037 = -36.3273689437454\n CONST038 = -29.3902885746295\n CONST039 = -27.0416345659799\n CONST040 = -26.1247009552263\n CONST041 = -26.1247009552263\n CONST042 = -19.5935257164197\n CONST043 = -2.42182459624970\n CONST044 = -9.79676285820985\n CONST045 = -7.15454401062709\n CONST046 = -3.38020432074749\n CONST047 = -1.12673477358250\n VAR07 = x * x * x\n VAR08 = x * x\n VAR04 = VAR07 * VAR07\n VAR05 = VAR07 * VAR08\n VAR06 = VAR08 * VAR08\n VAR16 = y * y * y\n VAR17 = y * y\n VAR13 = VAR16 * VAR16\n VAR14 = VAR16 * VAR17\n VAR15 = VAR17 * VAR17\n VAR25 = z * z * z\n VAR26 = z * z\n VAR22 = VAR25 * VAR25\n VAR23 = VAR25 * VAR26\n VAR24 = VAR26 * VAR26\n Y00 = CONST011 * VAR05 * z + CONST011 * VAR23 * x + CONST035 * VAR07 * VAR25\n Y01 = y * (CONST006 * VAR05 + CONST025 * VAR24 * x + CONST027 * VAR07 * VAR26)\n Y02 = (\n -CONST045 * VAR05 * z\n + CONST045 * VAR23 * x\n + VAR17 * (CONST030 * VAR07 * z - CONST030 * VAR25 * x)\n )\n Y03 = VAR16 * (-CONST028 * VAR26 * x + CONST040 * VAR07) + y * (\n CONST007 * VAR05 + CONST038 * VAR24 * x + CONST042 * VAR07 * VAR26\n )\n Y04 = (\n CONST003 * VAR05 * z\n + VAR07 * (CONST004 * VAR25 + CONST033 * VAR17 * z)\n + x * (CONST002 * VAR23 - CONST032 * VAR15 * z + CONST032 * VAR17 * VAR25)\n )\n Y05 = (\n CONST008 * VAR05 * y\n + VAR07 * (CONST017 * VAR26 * y + CONST036 * VAR16)\n + x * (CONST008 * VAR24 * y + CONST013 * VAR14 + CONST036 * VAR16 * VAR26)\n )\n Y06 = (\n CONST009 * VAR13\n + CONST018 * VAR17 * VAR24\n + CONST039 * VAR15 * VAR26\n + CONST047 * VAR04\n + CONST047 * VAR22\n + VAR06 * (CONST018 * VAR17 + CONST046 * VAR26)\n + VAR08 * (CONST024 * VAR17 * VAR26 + CONST039 * VAR15 + CONST046 * VAR24)\n )\n Y07 = (\n CONST008 * VAR23 * y\n + VAR25 * (CONST017 * VAR08 * y + CONST036 * VAR16)\n + z * (CONST008 * VAR06 * y + CONST014 * VAR14 + CONST036 * VAR08 * VAR16)\n )\n Y08 = (\n CONST026 * VAR04\n - CONST026 * VAR22\n + CONST040 * VAR17 * VAR24\n - CONST041 * VAR15 * VAR26\n + VAR06 * (CONST026 * VAR26 - CONST041 * VAR17)\n + VAR08 * (-CONST026 * VAR24 + CONST041 * VAR15)\n )\n Y09 = VAR16 * (CONST028 * VAR08 * z - CONST041 * VAR25) + y * (\n CONST022 * VAR06 * z - CONST042 * VAR08 * VAR25 + CONST044 * VAR23\n )\n Y10 = (\n CONST010 * VAR04\n + CONST010 * VAR22\n + CONST020 * VAR17 * VAR24\n + VAR06 * (CONST012 * VAR26 + CONST015 * VAR17)\n + VAR08 * (CONST012 * VAR24 + CONST019 * VAR17 * VAR26)\n )\n Y11 = y * (CONST006 * VAR23 + CONST025 * VAR06 * z + CONST027 * VAR08 * VAR25)\n Y12 = (\n -CONST037 * VAR06 * VAR26\n + CONST037 * VAR08 * VAR24\n + CONST043 * VAR04\n - CONST043 * VAR22\n )\n output_striding = tl.arange(0, block_size) * output_stride\n output_row_offset = (\n output_striding + (block_size * output_stride * block_id) + col_offset\n )\n tl.store(output_ptr + output_row_offset, Y00, mask=output_row_offset < output_numel)\n tl.store(\n output_ptr + output_row_offset + 1,\n Y01,\n mask=output_row_offset + 1 < output_numel,\n )\n tl.store(\n output_ptr + output_row_offset + 2,\n Y02,\n mask=output_row_offset + 2 < output_numel,\n )\n tl.store(\n output_ptr + output_row_offset + 3,\n Y03,\n mask=output_row_offset + 3 < output_numel,\n )\n tl.store(\n output_ptr + output_row_offset + 4,\n Y04,\n mask=output_row_offset + 4 < output_numel,\n )\n tl.store(\n output_ptr + output_row_offset + 5,\n Y05,\n mask=output_row_offset + 5 < output_numel,\n )\n tl.store(\n output_ptr + output_row_offset + 6,\n Y06,\n mask=output_row_offset + 6 < output_numel,\n )\n tl.store(\n output_ptr + output_row_offset + 7,\n Y07,\n mask=output_row_offset + 7 < output_numel,\n )\n tl.store(\n output_ptr + output_row_offset + 8,\n Y08,\n mask=output_row_offset + 8 < output_numel,\n )\n tl.store(\n output_ptr + output_row_offset + 9,\n Y09,\n mask=output_row_offset + 9 < output_numel,\n )\n tl.store(\n output_ptr + output_row_offset + 10,\n Y10,\n mask=output_row_offset + 10 < output_numel,\n )\n tl.store(\n output_ptr + output_row_offset + 11,\n Y11,\n mask=output_row_offset + 11 < output_numel,\n )\n tl.store(\n output_ptr + output_row_offset + 12,\n Y12,\n mask=output_row_offset + 12 < output_numel,\n )\n\n@triton.jit\ndef sixth_order_bwd(\n coord_ptr: tl.tensor,\n coord_grad_ptr: tl.tensor,\n sph_grad_ptr: tl.tensor,\n block_size: tl.constexpr,\n coord_numel: tl.constexpr,\n output_numel: tl.constexpr,\n col_offset: tl.constexpr,\n output_stride: tl.constexpr,\n):\n block_id = tl.program_id(0)\n coord_stride = 3\n coord_striding = tl.arange(0, block_size) * coord_stride\n coord_row_offset = coord_striding + (block_size * coord_stride * block_id)\n x = tl.load(coord_ptr + coord_row_offset, mask=coord_row_offset < coord_numel)\n y = tl.load(\n coord_ptr + coord_row_offset + 1, mask=coord_row_offset + 1 < coord_numel\n )\n z = tl.load(\n coord_ptr + coord_row_offset + 2, mask=coord_row_offset + 2 < coord_numel\n )\n output_striding = tl.arange(0, block_size) * output_stride\n output_row_offset = (\n output_striding + (block_size * output_stride * block_id) + col_offset\n )\n g_0 = tl.load(\n sph_grad_ptr + output_row_offset, mask=output_row_offset < output_numel\n )\n g_1 = tl.load(\n sph_grad_ptr + output_row_offset + 1, mask=output_row_offset + 1 < output_numel\n )\n g_2 = tl.load(\n sph_grad_ptr + output_row_offset + 2, mask=output_row_offset + 2 < output_numel\n )\n g_3 = tl.load(\n sph_grad_ptr + output_row_offset + 3, mask=output_row_offset + 3 < output_numel\n )\n g_4 = tl.load(\n sph_grad_ptr + output_row_offset + 4, mask=output_row_offset + 4 < output_numel\n )\n g_5 = tl.load(\n sph_grad_ptr + output_row_offset + 5, mask=output_row_offset + 5 < output_numel\n )\n g_6 = tl.load(\n sph_grad_ptr + output_row_offset + 6, mask=output_row_offset + 6 < output_numel\n )\n g_7 = tl.load(\n sph_grad_ptr + output_row_offset + 7, mask=output_row_offset + 7 < output_numel\n )\n g_8 = tl.load(\n sph_grad_ptr + output_row_offset + 8, mask=output_row_offset + 8 < output_numel\n )\n g_9 = tl.load(\n sph_grad_ptr + output_row_offset + 9, mask=output_row_offset + 9 < output_numel\n )\n g_10 = tl.load(\n sph_grad_ptr + output_row_offset + 10,\n mask=output_row_offset + 10 < output_numel,\n )\n g_11 = tl.load(\n sph_grad_ptr + output_row_offset + 11,\n mask=output_row_offset + 11 < output_numel,\n )\n g_12 = tl.load(\n sph_grad_ptr + output_row_offset + 12,\n mask=output_row_offset + 12 < output_numel,\n )\n CONST000 = 2.00000000000000\n CONST002 = 4.00000000000000\n CONST003 = 3.00000000000000\n CONST004 = 6.53117523880657\n CONST006 = 8.94318001328386\n CONST007 = 8.38944649544891\n CONST008 = 10.3266947761614\n CONST009 = 9.79676285820985\n CONST013 = 16.3279380970164\n CONST014 = 17.8863600265677\n CONST015 = 16.5227116418583\n CONST016 = 20.6533895523229\n CONST017 = 20.2812259244849\n CONST018 = 21.6333076527839\n CONST020 = 17.8863600265677\n CONST022 = 29.3902885746295\n CONST024 = 35.7727200531355\n CONST026 = 40.5624518489699\n CONST028 = 41.9472324772445\n CONST029 = 48.9838142910493\n CONST030 = 51.6334738808072\n CONST035 = 71.5454401062709\n CONST037 = 81.1249036979398\n CONST039 = 82.6135582092915\n CONST040 = -3.26558761940328\n CONST042 = 117.561154298518\n CONST046 = 208.997607641810\n CONST048 = -251.683394863467\n CONST049 = -214.636320318813\n CONST050 = -214.636320318813\n CONST051 = 16.5227116418583\n CONST052 = -167.788929908978\n CONST053 = -156.748205731358\n CONST054 = -145.309475774982\n CONST055 = -123.920337313937\n CONST056 = -117.561154298518\n CONST057 = 3.26558761940328\n CONST058 = -108.166538263920\n CONST059 = -107.318160159406\n CONST060 = -104.498803820905\n CONST061 = -104.498803820905\n CONST062 = -83.8944649544891\n CONST063 = -82.6135582092915\n CONST064 = -78.3741028656788\n CONST065 = -72.6547378874909\n CONST066 = -71.5454401062709\n CONST067 = -58.7805771492591\n CONST068 = -54.0832691319598\n CONST069 = -52.2494019104525\n CONST070 = -52.2494019104525\n CONST071 = -48.9838142910492\n CONST072 = -41.3067791046458\n CONST073 = -39.1870514328394\n CONST074 = -35.7727200531355\n CONST075 = -29.3902885746295\n CONST076 = -27.0416345659799\n CONST077 = -26.1247009552263\n CONST078 = -26.1247009552263\n CONST079 = -19.5935257164197\n CONST080 = -14.5309475774982\n CONST081 = -13.5208172829900\n CONST082 = -10.7318160159406\n CONST083 = -9.79676285820985\n CONST084 = -7.15454401062709\n CONST085 = -6.76040864149498\n CONST086 = -3.38020432074749\n CONST087 = -1.63279380970164\n VAR07 = x * x * x\n VAR08 = x * x\n VAR05 = VAR07 * VAR08\n VAR06 = VAR08 * VAR08\n VAR16 = y * y * y\n VAR17 = y * y\n VAR14 = VAR16 * VAR17\n VAR15 = VAR17 * VAR17\n VAR25 = z * z * z\n VAR26 = z * z\n VAR23 = VAR25 * VAR26\n VAR24 = VAR26 * VAR26\n g_x = tl.load(\n coord_grad_ptr + coord_row_offset, mask=coord_row_offset < coord_numel\n )\n g_y = tl.load(\n coord_grad_ptr + coord_row_offset + 1, mask=coord_row_offset + 1 < coord_numel\n )\n g_z = tl.load(\n coord_grad_ptr + coord_row_offset + 2, mask=coord_row_offset + 2 < coord_numel\n )\n g_x += (\n g_0 * (CONST054 * VAR08 * VAR25 - CONST065 * VAR06 * z - CONST080 * VAR23)\n + g_1 * y * (CONST028 * VAR06 + CONST028 * VAR24 + CONST048 * VAR08 * VAR26)\n + g_10\n * (\n CONST000 * x * (CONST006 * VAR24 + CONST059 * VAR17 * VAR26)\n + CONST002 * VAR07 * (CONST006 * VAR26 + CONST014 * VAR17)\n + CONST082 * VAR05\n )\n + g_11 * y * (-CONST052 * VAR07 * z + CONST052 * VAR25 * x)\n + g_12 * (-CONST054 * VAR07 * VAR26 + CONST065 * VAR24 * x + CONST080 * VAR05)\n + g_2\n * (\n -CONST074 * VAR06 * z\n + CONST084 * VAR23\n + VAR17 * (CONST049 * VAR08 * z - CONST066 * VAR25)\n )\n + g_3\n * (\n VAR16 * (CONST064 * VAR08 - CONST064 * VAR26)\n + y * (CONST029 * VAR06 + CONST067 * VAR08 * VAR26 + CONST075 * VAR24)\n )\n + g_4\n * (\n CONST003 * VAR08 * (CONST004 * VAR25 + CONST069 * VAR17 * z)\n + CONST013 * VAR06 * z\n - CONST040 * VAR23\n - CONST070 * VAR15 * z\n + CONST070 * VAR17 * VAR25\n )\n + g_5\n * (\n CONST003 * VAR08 * (CONST016 * VAR26 * y + CONST072 * VAR16)\n + CONST008 * VAR24 * y\n + CONST015 * VAR14\n + CONST030 * VAR06 * y\n + CONST072 * VAR16 * VAR26\n )\n + g_6\n * (\n CONST000\n * x\n * (CONST026 * VAR17 * VAR26 + CONST076 * VAR15 + CONST086 * VAR24)\n + CONST002 * VAR07 * (CONST017 * VAR17 + CONST086 * VAR26)\n + CONST085 * VAR05\n )\n + g_7\n * (\n -CONST072 * VAR25 * x * y\n + z * (CONST063 * VAR16 * x - CONST072 * VAR07 * y)\n )\n + g_8\n * (\n CONST000 * x * (CONST077 * VAR15 - CONST087 * VAR24)\n + CONST002 * VAR07 * (-CONST077 * VAR17 + CONST087 * VAR26)\n + CONST083 * VAR05\n )\n + g_9\n * (CONST053 * VAR16 * x * z + y * (CONST042 * VAR07 * z - CONST073 * VAR25 * x))\n )\n g_y += (\n CONST000 * g_2 * y * (CONST066 * VAR07 * z - CONST066 * VAR25 * x)\n + g_1 * (CONST007 * VAR05 + CONST028 * VAR24 * x + CONST062 * VAR07 * VAR26)\n + g_10\n * (CONST024 * VAR06 * y + CONST050 * VAR08 * VAR26 * y - CONST074 * VAR24 * y)\n + g_11 * (CONST007 * VAR23 + CONST028 * VAR06 * z + CONST062 * VAR08 * VAR25)\n + g_3\n * (\n CONST003 * VAR17 * (-CONST064 * VAR26 * x + CONST078 * VAR07)\n + CONST009 * VAR05\n + CONST075 * VAR24 * x\n + CONST079 * VAR07 * VAR26\n )\n + g_4\n * (CONST061 * VAR07 * y * z + x * (CONST046 * VAR16 * z + CONST060 * VAR25 * y))\n + g_5\n * (\n CONST008 * VAR05\n + VAR07 * (CONST016 * VAR26 + CONST055 * VAR17)\n + x * (CONST008 * VAR24 + CONST055 * VAR17 * VAR26 - CONST063 * VAR15)\n )\n + g_6\n * (\n CONST018 * VAR14\n + CONST026 * VAR06 * y\n + CONST026 * VAR24 * y\n + CONST058 * VAR16 * VAR26\n + VAR08 * (CONST037 * VAR26 * y + CONST058 * VAR16)\n )\n + g_7\n * (\n CONST008 * VAR23\n + VAR25 * (CONST016 * VAR08 + CONST055 * VAR17)\n + z * (CONST008 * VAR06 + CONST039 * VAR15 + CONST055 * VAR08 * VAR17)\n )\n + g_8\n * (\n CONST060 * VAR08 * VAR16\n - CONST060 * VAR16 * VAR26\n + CONST069 * VAR24 * y\n - CONST070 * VAR06 * y\n )\n + g_9\n * (\n CONST003 * VAR17 * (CONST064 * VAR08 * z - CONST077 * VAR25)\n + CONST022 * VAR06 * z\n - CONST079 * VAR08 * VAR25\n + CONST083 * VAR23\n )\n )\n g_z += (\n g_0 * (CONST054 * VAR07 * VAR26 - CONST065 * VAR24 * x - CONST080 * VAR05)\n + g_1 * y * (CONST052 * VAR07 * z - CONST052 * VAR25 * x)\n + g_10\n * (\n CONST020 * VAR06 * z\n + CONST035 * VAR17 * VAR25\n + CONST082 * VAR23\n + VAR08 * (CONST050 * VAR17 * z - CONST074 * VAR25)\n )\n + g_11 * y * (CONST028 * VAR06 + CONST028 * VAR24 + CONST048 * VAR08 * VAR26)\n + g_12 * (CONST054 * VAR08 * VAR25 - CONST065 * VAR06 * z - CONST080 * VAR23)\n + g_2\n * (\n CONST074 * VAR24 * x\n - CONST084 * VAR05\n + VAR17 * (-CONST049 * VAR26 * x + CONST066 * VAR07)\n )\n + g_3\n * (\n -CONST053 * VAR16 * x * z\n + y * (CONST056 * VAR25 * x + CONST073 * VAR07 * z)\n )\n + g_4\n * (\n CONST057 * VAR05\n + VAR07 * (CONST069 * VAR17 - CONST079 * VAR26)\n + x * (CONST013 * VAR24 + CONST053 * VAR17 * VAR26 - CONST070 * VAR15)\n )\n + g_5\n * (\n -CONST072 * VAR07 * y * z\n + x * (CONST063 * VAR16 * z - CONST072 * VAR25 * y)\n )\n + g_6\n * (\n CONST037 * VAR17 * VAR25\n + CONST068 * VAR15 * z\n + CONST085 * VAR06 * z\n + CONST085 * VAR23\n + VAR08 * (CONST037 * VAR17 * z + CONST081 * VAR25)\n )\n + g_7\n * (\n CONST003 * VAR26 * (CONST016 * VAR08 * y + CONST072 * VAR16)\n + CONST008 * VAR06 * y\n + CONST030 * VAR24 * y\n + CONST051 * VAR14\n + CONST072 * VAR08 * VAR16\n )\n + g_8\n * (\n CONST004 * VAR08 * VAR25\n + CONST040 * VAR06 * z\n + CONST061 * VAR17 * VAR25\n - CONST070 * VAR15 * z\n - CONST083 * VAR23\n )\n + g_9\n * (\n VAR16 * (CONST064 * VAR08 - CONST064 * VAR26)\n + y * (CONST022 * VAR06 - CONST067 * VAR08 * VAR26 + CONST071 * VAR24)\n )\n )\n tl.store(\n coord_grad_ptr + coord_row_offset, g_x, mask=coord_row_offset < coord_numel\n )\n tl.store(\n coord_grad_ptr + coord_row_offset + 1,\n g_y,\n mask=coord_row_offset + 1 < coord_numel,\n )\n tl.store(\n coord_grad_ptr + coord_row_offset + 2,\n g_z,\n mask=coord_row_offset + 2 < coord_numel,\n )\n\nclass SixthOrderSphericalHarmonic(torch.autograd.Function):\n @staticmethod\n def forward(\n ctx,\n coords: torch.Tensor,\n output_tensor: torch.Tensor | None = None,\n mask: torch.Tensor | None = None,\n block_size: int = 64,\n col_offset: int = 0,\n ):\n if not isinstance(output_tensor, torch.Tensor):\n output_tensor = torch.empty(\n (*coords.shape[:-1], 13), dtype=coords.dtype, device=coords.device\n )\n coord_numel = coords.numel()\n output_numel = output_tensor.numel()\n num_blocks = calculate_lastdim_num_blocks(coords, block_size)\n sixth_order_fwd[num_blocks,](\n coords,\n output_tensor,\n block_size,\n coord_numel,\n output_numel,\n col_offset,\n output_tensor.stride(-2),\n )\n ctx.save_for_backward(coords)\n return output_tensor\n\n @staticmethod\n def backward(\n ctx, sph_grad_tensor: torch.Tensor, block_size: int = 64, col_offset: int = 0\n ) -> torch.Tensor:\n (coords,) = ctx.saved_tensors\n coord_grad_output = torch.zeros_like(coords)\n num_blocks = calculate_lastdim_num_blocks(coords, block_size)\n sixth_order_bwd[num_blocks,](\n coords,\n coord_grad_output,\n sph_grad_tensor,\n block_size,\n coords.numel(),\n sph_grad_tensor.numel(),\n col_offset,\n sph_grad_tensor.stride(-2),\n )\n return coord_grad_output\n", - "description_1": "Use triton language to implement forward and backward kernels for sixth order spherical harmonics. The forward kernel takes 7 parameters: coord_ptr (tensor), output_ptr (tensor), block_size (constexpr), coord_numel (constexpr), output_numel (constexpr), col_offset (constexpr), and output_stride (constexpr). The backward kernel also takes 8 parameters: coord_ptr (tensor), coord_grad_ptr (tensor), sph_grad_ptr (tensor), block_size (constexpr), coord_numel (constexpr), output_numel (constexpr), col_offset (constexpr), and output_stride (constexpr). These kernels compute the forward and backward pass for the spherical harmonic transformation of the input coordinates.", - "description_2": "Use triton language to compute forward and backward passes of sixth order spherical harmonics using given tensors and parameters.", - "difficulty": 4 - }, - { - "code": "import triton\nimport torch\nfrom triton import language as tl\nfrom equitriton.utils import calculate_lastdim_num_blocks\n\nclass SeventhOrderSphericalHarmonic(torch.autograd.Function):\n @staticmethod\n def forward(\n ctx,\n coords: torch.Tensor,\n output_tensor: torch.Tensor | None = None,\n mask: torch.Tensor | None = None,\n block_size: int = 64,\n col_offset: int = 0,\n ):\n if not isinstance(output_tensor, torch.Tensor):\n output_tensor = torch.empty(\n (*coords.shape[:-1], 15), dtype=coords.dtype, device=coords.device\n )\n coord_numel = coords.numel()\n output_numel = output_tensor.numel()\n num_blocks = calculate_lastdim_num_blocks(coords, block_size)\n # apply the kernel\n seventh_order_fwd[num_blocks,](\n coords,\n output_tensor,\n block_size,\n coord_numel,\n output_numel,\n col_offset,\n output_tensor.stride(-2),\n )\n ctx.save_for_backward(coords)\n return output_tensor\n\n @staticmethod\n def backward(\n ctx,\n sph_grad_tensor: torch.Tensor,\n block_size: int = 64,\n col_offset: int = 0,\n ) -> torch.Tensor:\n (coords,) = ctx.saved_tensors\n coord_grad_output = torch.zeros_like(coords)\n num_blocks = calculate_lastdim_num_blocks(coords, block_size)\n # call backward kernel\n seventh_order_bwd[num_blocks,](\n coords,\n coord_grad_output,\n sph_grad_tensor,\n block_size,\n coords.numel(),\n sph_grad_tensor.numel(),\n col_offset,\n sph_grad_tensor.stride(-2),\n )\n return coord_grad_output\n\n@triton.jit\ndef seventh_order_fwd(\n coord_ptr: tl.tensor,\n output_ptr: tl.tensor,\n block_size: tl.constexpr,\n coord_numel: tl.constexpr,\n output_numel: tl.constexpr,\n col_offset: tl.constexpr,\n output_stride: tl.constexpr,\n):\n coord_stride = 3\n block_id = tl.program_id(0)\n coord_striding = tl.arange(0, block_size) * coord_stride\n coord_row_offset = coord_striding + (block_size * coord_stride * block_id)\n x = tl.load(coord_ptr + coord_row_offset, mask=coord_row_offset < coord_numel)\n y = tl.load(\n coord_ptr + coord_row_offset + 1, mask=coord_row_offset + 1 < coord_numel\n )\n z = tl.load(\n coord_ptr + coord_row_offset + 2, mask=coord_row_offset + 2 < coord_numel\n )\n # Variables and constants are defined here\n CONST002 = 3.87298334620742\n # The rest of the constants would be here\n # -------------------- kernel implementations\n # Computation and storage for outputs Y00 to Y14 would go here\n output_striding = tl.arange(0, block_size) * output_stride\n output_row_offset = (\n output_striding + (block_size * output_stride * block_id) + col_offset\n )\n # Store results back to output_ptr with proper masking\n tl.store(output_ptr + output_row_offset, x, mask=output_row_offset < output_numel)\n\n@triton.jit\ndef seventh_order_bwd(\n coord_ptr: tl.tensor,\n coord_grad_ptr: tl.tensor,\n sph_grad_ptr: tl.tensor,\n block_size: tl.constexpr,\n coord_numel: tl.constexpr,\n output_numel: tl.constexpr,\n col_offset: tl.constexpr,\n output_stride: tl.constexpr,\n):\n block_id = tl.program_id(0)\n coord_stride = 3\n coord_striding = tl.arange(0, block_size) * coord_stride\n coord_row_offset = coord_striding + (block_size * coord_stride * block_id)\n x = tl.load(coord_ptr + coord_row_offset, mask=coord_row_offset < coord_numel)\n y = tl.load(\n coord_ptr + coord_row_offset + 1, mask=coord_row_offset + 1 < coord_numel\n )\n z = tl.load(\n coord_ptr + coord_row_offset + 2, mask=coord_row_offset + 2 < coord_numel\n )\n output_striding = tl.arange(0, block_size) * output_stride\n output_row_offset = (\n output_striding + (block_size * output_stride * block_id) + col_offset\n )\n # Load gradients\n g_0 = tl.load(\n sph_grad_ptr + output_row_offset, mask=output_row_offset < output_numel\n )\n # The rest of the gradient loads g_1 to g_14 would go here\n # -------------------- kernel implementations\n # Computation for gradients g_x, g_y, g_z would go here\n # Write out gradients\n tl.store(\n coord_grad_ptr + coord_row_offset, x, mask=coord_row_offset < coord_numel\n )\n # The rest of the gradient stores would go here\n", - "description_1": "Use triton language to implement two kernels, 'seventh_order_fwd' and 'seventh_order_bwd'. The forward kernel computes the seventh order spherical harmonics given input coordinates, and stores the result in the output tensor. It requires 7 parameters: 'coord_ptr', 'output_ptr', 'block_size', 'coord_numel', 'output_numel', 'col_offset', and 'output_stride'. The backward kernel calculates the gradient of the coordinates with respect to the input spherical harmonic gradients. It also requires 8 parameters: 'coord_ptr', 'coord_grad_ptr', 'sph_grad_ptr', 'block_size', 'coord_numel', 'output_numel', 'col_offset', and 'output_stride'.", - "description_2": "Use triton language to compute forward and backward passes of seventh order spherical harmonics using the specified kernel parameters.", - "difficulty": 3 - }, - { - "code": "import triton\nimport torch\nfrom triton import language as tl\nfrom equitriton.utils import calculate_lastdim_num_blocks\n\n@triton.jit\ndef eighth_order_fwd(\n coord_ptr: tl.tensor,\n output_ptr: tl.tensor,\n block_size: tl.constexpr,\n coord_numel: tl.constexpr,\n output_numel: tl.constexpr,\n col_offset: tl.constexpr,\n output_stride: tl.constexpr,\n):\n # These are hardcoded because they are predetermined;\n coord_stride = 3\n # Work out the row offsets\n block_id = tl.program_id(0)\n coord_striding = tl.arange(0, block_size) * coord_stride\n coord_row_offset = coord_striding + (block_size * coord_stride * block_id)\n x = tl.load(coord_ptr + coord_row_offset, mask=coord_row_offset < coord_numel)\n y = tl.load(\n coord_ptr + coord_row_offset + 1, mask=coord_row_offset + 1 < coord_numel\n )\n z = tl.load(\n coord_ptr + coord_row_offset + 2, mask=coord_row_offset + 2 < coord_numel\n )\n CONST000 = 1.12741169450483\n # Additional constants omitted for brevity...\n # Compute high order spherical harmonics\n VAR06 = x * x * x * x\n VAR07 = x * x * x\n VAR08 = x * x\n VAR02 = VAR06 * VAR06\n VAR03 = VAR06 * VAR07\n VAR04 = VAR07 * VAR07\n VAR05 = VAR07 * VAR08\n VAR15 = y * y * y * y\n VAR16 = y * y * y\n VAR17 = y * y\n VAR11 = VAR15 * VAR16\n VAR12 = VAR15 * VAR16\n VAR13 = VAR16 * VAR16\n VAR14 = VAR16 * VAR17\n VAR24 = z * z * z * z\n VAR25 = z * z * z\n VAR26 = z * z\n VAR20 = VAR24 * VAR24\n VAR21 = VAR24 * VAR25\n VAR22 = VAR25 * VAR25\n VAR23 = VAR25 * VAR26\n Y00 = (\n -CONST066 * VAR05 * VAR25\n + CONST066 * VAR07 * VAR23\n + CONST089 * VAR03 * z\n - CONST089 * VAR21 * x\n )\n # Additional Y calculations omitted for brevity...\n output_striding = tl.arange(0, block_size) * output_stride\n output_row_offset = (\n output_striding + (block_size * output_stride * block_id) + col_offset\n )\n tl.store(output_ptr + output_row_offset, Y00, mask=output_row_offset < output_numel)\n # Additional stores omitted for brevity...\n\n@triton.jit\ndef eighth_order_bwd(\n coord_ptr: tl.tensor,\n coord_grad_ptr: tl.tensor,\n sph_grad_ptr: tl.tensor,\n block_size: tl.constexpr,\n coord_numel: tl.constexpr,\n output_numel: tl.constexpr,\n col_offset: tl.constexpr,\n output_stride: tl.constexpr,\n):\n # Work out the row offsets\n block_id = tl.program_id(0)\n coord_stride = 3\n coord_striding = tl.arange(0, block_size) * coord_stride\n coord_row_offset = coord_striding + (block_size * coord_stride * block_id)\n x = tl.load(coord_ptr + coord_row_offset, mask=coord_row_offset < coord_numel)\n y = tl.load(\n coord_ptr + coord_row_offset + 1, mask=coord_row_offset + 1 < coord_numel\n )\n z = tl.load(\n coord_ptr + coord_row_offset + 2, mask=coord_row_offset + 2 < coord_numel\n )\n output_striding = tl.arange(0, block_size) * output_stride\n output_row_offset = (\n output_striding + (block_size * output_stride * block_id) + col_offset\n )\n g_0 = tl.load(\n sph_grad_ptr + output_row_offset, mask=output_row_offset < output_numel\n )\n g_1 = tl.load(\n sph_grad_ptr + output_row_offset + 1, mask=output_row_offset + 1 < output_numel\n )\n # Additional gradient loads omitted for brevity...\n CONST000 = 2.00000000000000\n CONST001 = 3.00000000000000\n # Additional constants omitted for brevity...\n VAR06 = x * x * x * x\n VAR07 = x * x * x\n VAR08 = x * x\n VAR03 = VAR06 * VAR07\n VAR04 = VAR07 * VAR07\n VAR05 = VAR07 * VAR08\n VAR15 = y * y * y * y\n VAR16 = y * y * y\n VAR17 = y * y\n VAR12 = VAR15 * VAR16\n VAR13 = VAR16 * VAR16\n VAR14 = VAR16 * VAR17\n VAR24 = z * z * z * z\n VAR25 = z * z * z\n VAR26 = z * z\n VAR21 = VAR24 * VAR25\n VAR22 = VAR25 * VAR25\n VAR23 = VAR25 * VAR26\n g_x = tl.load(\n coord_grad_ptr + coord_row_offset, mask=coord_row_offset < coord_numel\n )\n g_y = tl.load(\n coord_grad_ptr + coord_row_offset + 1, mask=coord_row_offset + 1 < coord_numel\n )\n g_z = tl.load(\n coord_grad_ptr + coord_row_offset + 2, mask=coord_row_offset + 2 < coord_numel\n )\n g_x += (\n g_0\n * (\n CONST049 * VAR08 * VAR23\n - CONST131 * VAR06 * VAR25\n + CONST151 * VAR04 * z\n - CONST211 * VAR21\n )\n )\n # Additional updates omitted for brevity...\n # Write out gradients\n tl.store(\n coord_grad_ptr + coord_row_offset, g_x, mask=coord_row_offset < coord_numel\n )\n tl.store(\n coord_grad_ptr + coord_row_offset + 1,\n g_y,\n mask=coord_row_offset + 1 < coord_numel,\n )\n tl.store(\n coord_grad_ptr + coord_row_offset + 2,\n g_z,\n mask=coord_row_offset + 2 < coord_numel,\n )\n\nclass EighthOrderSphericalHarmonic(torch.autograd.Function):\n @staticmethod\n def forward(\n ctx,\n coords: torch.Tensor,\n output_tensor: torch.Tensor | None = None,\n mask: torch.Tensor | None = None,\n block_size: int = 64,\n col_offset: int = 0,\n ):\n if not isinstance(output_tensor, torch.Tensor):\n output_tensor = torch.empty(\n (*coords.shape[:-1], 17), dtype=coords.dtype, device=coords.device\n )\n coord_numel = coords.numel()\n output_numel = output_tensor.numel()\n num_blocks = calculate_lastdim_num_blocks(coords, block_size)\n # Apply the kernel\n eighth_order_fwd[num_blocks,](\n coords,\n output_tensor,\n block_size,\n coord_numel,\n output_numel,\n col_offset,\n output_tensor.stride(-2),\n )\n ctx.save_for_backward(coords)\n return output_tensor\n\n @staticmethod\n def backward(\n ctx,\n sph_grad_tensor: torch.Tensor,\n block_size: int = 64,\n col_offset: int = 0,\n ) -> torch.Tensor:\n (coords,) = ctx.saved_tensors\n coord_grad_output = torch.zeros_like(coords)\n num_blocks = calculate_lastdim_num_blocks(coords, block_size)\n # Call backward kernel\n eighth_order_bwd[num_blocks,](\n coords,\n coord_grad_output,\n sph_grad_tensor,\n block_size,\n coords.numel(),\n sph_grad_tensor.numel(),\n col_offset,\n sph_grad_tensor.stride(-2),\n )\n return coord_grad_output\n", - "description_1": "Use triton language to implement a spherical harmonic transformation of eighth order. The forward kernel 'eighth_order_fwd' computes this transformation with parameters: (1) coord_ptr: pointer to input tensor; (2) output_ptr: pointer to output tensor; (3) block_size: size of the processing block; (4) coord_numel: total number of elements in coord; (5) output_numel: total number of elements in output; (6) col_offset: offset in the column; (7) output_stride: stride of the output tensor. The backward kernel 'eighth_order_bwd' calculates gradients with respect to the inputs and has the same parameters as the forward kernel, with additional pointers for gradients. These kernels are used within 'EighthOrderSphericalHarmonic', a PyTorch autograd function, which encapsulates the forward and backward passes.", - "description_2": "Use triton language to develop forward and backward kernels for computing eighth order spherical harmonics and integrating these kernels into a PyTorch autograd function.", - "difficulty": 4 - }, - { - "code": "import triton\nimport torch\nfrom triton import language as tl\n\ndef calculate_lastdim_num_blocks(coords, block_size):\n last_dim = coords.shape[-1]\n return (last_dim + block_size - 1) // block_size\n\n@triton.jit\ndef ninth_order_fwd(\n coord_ptr: tl.tensor,\n output_ptr: tl.tensor,\n block_size: tl.constexpr,\n coord_numel: tl.constexpr,\n output_numel: tl.constexpr,\n col_offset: tl.constexpr,\n output_stride: tl.constexpr,\n):\n coord_stride = 3\n block_id = tl.program_id(0)\n coord_striding = tl.arange(0, block_size) * coord_stride\n coord_row_offset = coord_striding + (block_size * coord_stride * block_id)\n x = tl.load(coord_ptr + coord_row_offset, mask=coord_row_offset < coord_numel)\n y = tl.load(coord_ptr + coord_row_offset + 1, mask=coord_row_offset + 1 < coord_numel)\n z = tl.load(coord_ptr + coord_row_offset + 2, mask=coord_row_offset + 2 < coord_numel)\n \n CONST001 = 2.65478475211798\n CONST020 = 23.8930627690618\n CONST078 = -223.001919177910\n CONST091 = 334.502878766866\n CONST105 = -95.5722510762473\n \n VAR06 = x * x * x * x\n VAR07 = x * x * x\n VAR08 = x * x\n VAR01 = VAR07 * VAR07 * VAR07\n VAR02 = VAR06 * VAR06\n VAR03 = VAR06 * VAR07\n VAR04 = VAR07 * VAR07\n VAR05 = VAR07 * VAR08\n VAR15 = y * y * y * y\n VAR16 = y * y * y\n VAR17 = y * y\n VAR10 = VAR16 * VAR16 * VAR16\n VAR11 = VAR15 * VAR15\n VAR12 = VAR15 * VAR16\n VAR13 = VAR16 * VAR16\n VAR14 = VAR16 * VAR17\n VAR24 = z * z * z * z\n VAR25 = z * z * z\n VAR26 = z * z\n VAR19 = VAR25 * VAR25 * VAR25\n VAR20 = VAR24 * VAR24\n VAR21 = VAR24 * VAR25\n VAR22 = VAR25 * VAR25\n VAR23 = VAR25 * VAR26\n \n Y00 = (\n CONST001 * VAR01\n + CONST020 * VAR20 * x\n + CONST078 * VAR07 * VAR22\n + CONST091 * VAR05 * VAR24\n + CONST105 * VAR03 * VAR26\n )\n output_striding = tl.arange(0, block_size) * output_stride\n output_row_offset = (output_striding + (block_size * output_stride * block_id) + col_offset)\n tl.store(output_ptr + output_row_offset, Y00, mask=output_row_offset < output_numel)\n\n@triton.jit\ndef ninth_order_bwd(\n coord_ptr: tl.tensor,\n coord_grad_ptr: tl.tensor,\n sph_grad_ptr: tl.tensor,\n block_size: tl.constexpr,\n coord_numel: tl.constexpr,\n output_numel: tl.constexpr,\n col_offset: tl.constexpr,\n output_stride: tl.constexpr,\n):\n block_id = tl.program_id(0)\n coord_stride = 3\n coord_striding = tl.arange(0, block_size) * coord_stride\n coord_row_offset = coord_striding + (block_size * coord_stride * block_id)\n x = tl.load(coord_ptr + coord_row_offset, mask=coord_row_offset < coord_numel)\n y = tl.load(coord_ptr + coord_row_offset + 1, mask=coord_row_offset + 1 < coord_numel)\n z = tl.load(coord_ptr + coord_row_offset + 2, mask=coord_row_offset + 2 < coord_numel)\n output_striding = tl.arange(0, block_size) * output_stride\n output_row_offset = (output_striding + (block_size * output_stride * block_id) + col_offset)\n g_0 = tl.load(sph_grad_ptr + output_row_offset, mask=output_row_offset < output_numel)\n\n g_x = tl.load(coord_grad_ptr + coord_row_offset, mask=coord_row_offset < coord_numel)\n g_y = tl.load(coord_grad_ptr + coord_row_offset + 1, mask=coord_row_offset + 1 < coord_numel)\n g_z = tl.load(coord_grad_ptr + coord_row_offset + 2, mask=coord_row_offset + 2 < coord_numel)\n\n g_x += (\n g_0\n * (\n CONST021 * VAR20\n + CONST022 * VAR02\n + CONST179 * VAR04 * VAR26\n + CONST180 * VAR08 * VAR22\n + CONST204 * VAR06 * VAR24\n )\n )\n tl.store(coord_grad_ptr + coord_row_offset, g_x, mask=coord_row_offset < coord_numel)\n tl.store(coord_grad_ptr + coord_row_offset + 1, g_y, mask=coord_row_offset + 1 < coord_numel)\n tl.store(coord_grad_ptr + coord_row_offset + 2, g_z, mask=coord_row_offset + 2 < coord_numel)\n\nclass NinthOrderSphericalHarmonic(torch.autograd.Function):\n @staticmethod\n def forward(\n ctx,\n coords: torch.Tensor,\n output_tensor: torch.Tensor | None = None,\n mask: torch.Tensor | None = None,\n block_size: int = 64,\n col_offset: int = 0,\n ):\n if not isinstance(output_tensor, torch.Tensor):\n output_tensor = torch.empty(\n (*coords.shape[:-1], 19), dtype=coords.dtype, device=coords.device\n )\n coord_numel = coords.numel()\n output_numel = output_tensor.numel()\n num_blocks = calculate_lastdim_num_blocks(coords, block_size)\n ninth_order_fwd[num_blocks,](\n coords,\n output_tensor,\n block_size,\n coord_numel,\n output_numel,\n col_offset,\n output_tensor.stride(-2),\n )\n ctx.save_for_backward(coords)\n return output_tensor\n\n @staticmethod\n def backward(\n ctx,\n sph_grad_tensor: torch.Tensor,\n block_size: int = 64,\n col_offset: int = 0,\n ) -> torch.Tensor:\n (coords,) = ctx.saved_tensors\n coord_grad_output = torch.zeros_like(coords)\n num_blocks = calculate_lastdim_num_blocks(coords, block_size)\n ninth_order_bwd[num_blocks,](\n coords,\n coord_grad_output,\n sph_grad_tensor,\n block_size,\n coords.numel(),\n sph_grad_tensor.numel(),\n col_offset,\n sph_grad_tensor.stride(-2),\n )\n return coord_grad_output\n", - "description_1": "Use triton language to implement forward and backward kernels for ninth order spherical harmonic transformation with six input parameters: coord_ptr, output_ptr, block_size, coord_numel, output_numel, and col_offset, used in a PyTorch autograd function. The function handles tensors for coordinate transformation and gradient calculation.", - "description_2": "Use triton language to implement forward and backward kernels with input params for spherical harmonic transformation, integrated in PyTorch autograd.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef _silu_and_mul_kernel(\n gateup_ptr,\n out_ptr,\n N: tl.constexpr,\n stride_gum: tl.constexpr,\n stride_gun: tl.constexpr,\n stride_om: tl.constexpr,\n stride_on: tl.constexpr,\n BLOCK_SIZE_N: tl.constexpr,\n):\n \"\"\"silu and mul kernel.\"\"\"\n m_id = tl.program_id(0)\n\n up_ptr = gateup_ptr + N * stride_gun\n\n offs_n = tl.arange(0, BLOCK_SIZE_N)\n gate_ptrs = gateup_ptr + m_id * stride_gum + offs_n * stride_gun\n up_ptrs = up_ptr + m_id * stride_gum + offs_n * stride_gun\n out_ptrs = out_ptr + m_id * stride_om + offs_n * stride_on\n\n for _ in range(0, N, BLOCK_SIZE_N):\n gate = tl.load(gate_ptrs).to(tl.float32)\n up = tl.load(up_ptrs).to(tl.float32)\n\n gate = gate / (1 + fast_expf(-gate))\n out = gate * up\n\n tl.store(out_ptrs, out)\n\n gate_ptrs += BLOCK_SIZE_N * stride_gun\n up_ptrs += BLOCK_SIZE_N * stride_gun\n out_ptrs += BLOCK_SIZE_N * stride_on\n\n\n@triton.jit\ndef _silu_and_mul_no_align_kernel(\n gateup_ptr,\n out_ptr,\n N: tl.constexpr,\n stride_gum: tl.constexpr,\n stride_gun: tl.constexpr,\n stride_om: tl.constexpr,\n stride_on: tl.constexpr,\n BLOCK_SIZE_N: tl.constexpr,\n):\n \"\"\"silu and mul kernel.\"\"\"\n m_id = tl.program_id(0)\n\n up_ptr = gateup_ptr + N * stride_gun\n\n offs_n = tl.arange(0, BLOCK_SIZE_N)\n gate_ptrs = gateup_ptr + m_id * stride_gum + offs_n * stride_gun\n up_ptrs = up_ptr + m_id * stride_gum + offs_n * stride_gun\n out_ptrs = out_ptr + m_id * stride_om + offs_n * stride_on\n\n for n in range(0, N, BLOCK_SIZE_N):\n mask = n + offs_n < N\n gate = tl.load(gate_ptrs, mask=mask).to(tl.float32)\n up = tl.load(up_ptrs, mask=mask).to(tl.float32)\n\n gate = gate / (1 + fast_expf(-gate))\n out = gate * up\n\n tl.store(out_ptrs, out, mask=mask)\n\n gate_ptrs += BLOCK_SIZE_N * stride_gun\n up_ptrs += BLOCK_SIZE_N * stride_gun\n out_ptrs += BLOCK_SIZE_N * stride_on\n\n\ndef silu_and_mul(gate_up: torch.Tensor, out: torch.Tensor = None):\n \"\"\"silu and mul.\"\"\"\n assert gate_up.dim() == 2\n\n M = gate_up.size(0)\n N = gate_up.size(-1) // 2\n if out is None:\n out_shape = (M, N)\n out = gate_up.new_empty(out_shape)\n\n BLOCK_SIZE_N = triton.next_power_of_2(N)\n BLOCK_SIZE_N = min(BLOCK_SIZE_N, 1024)\n num_warps = 4\n num_stages = 2\n grid = (M, )\n if N % BLOCK_SIZE_N == 0:\n _silu_and_mul_kernel[grid](gate_up,\n out,\n N,\n stride_gum=gate_up.stride(0),\n stride_gun=gate_up.stride(1),\n stride_om=out.stride(0),\n stride_on=out.stride(1),\n BLOCK_SIZE_N=BLOCK_SIZE_N,\n num_warps=num_warps,\n num_stages=num_stages)\n else:\n _silu_and_mul_no_align_kernel[grid](gate_up,\n out,\n N,\n stride_gum=gate_up.stride(0),\n stride_gun=gate_up.stride(1),\n stride_om=out.stride(0),\n stride_on=out.stride(1),\n BLOCK_SIZE_N=BLOCK_SIZE_N,\n num_warps=num_warps,\n num_stages=num_stages)\n\n return out\n", - "description_1": "Use triton language to implement two kernels, _silu_and_mul_kernel and _silu_and_mul_no_align_kernel, each with 7 parameters: gateup_ptr, out_ptr, N, stride_gum, stride_gun, stride_om, stride_on, and BLOCK_SIZE_N. These kernels perform element-wise operations on input tensors, applying the SiLU activation function followed by multiplication. The silu_and_mul function, with 2 parameters: gate_up and out, calls these kernels based on the alignment of the input tensor dimensions.", - "description_2": "Use triton language to create kernels for element-wise SiLU activation and multiplication on tensors, with conditional kernel selection based on tensor alignment.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\nLOG2 = math.log(2)\n\n@triton.jit\ndef tl_pow(a, b):\n \"\"\"triton pow.\"\"\"\n return tl.exp(b * tl.log(a))\n\n@triton.jit\ndef tl_2pow(b):\n \"\"\"triton pow2.\"\"\"\n return tl.exp(b * LOG2)\n\n@triton.jit\ndef tl_log2(a):\n \"\"\"triton log2.\"\"\"\n return tl.log(a) / LOG2\n\n@triton.jit\ndef _get_interleave_power_of_2(i, n):\n \"\"\"get interleave power of 2.\"\"\"\n start = -tl_2pow(3 - tl_log2(n))\n start = tl_2pow(start)\n ratio = start\n return start * tl_pow(ratio, i)\n\n@triton.jit\ndef get_slope(i, n):\n \"\"\"get slope.\"\"\"\n closest_power_of_2 = tl_2pow(tl_log2(n).to(tl.int32))\n if i < closest_power_of_2:\n return _get_interleave_power_of_2(i, closest_power_of_2)\n else:\n return _get_interleave_power_of_2((i - closest_power_of_2) * 2,\n 2 * closest_power_of_2)\n\n@triton.jit\ndef _load_block_offsets(offset_ptr, block_id, num_sub_blocks: tl.constexpr,\n BLOCK: tl.constexpr):\n if num_sub_blocks > 1:\n offs_sub = tl.arange(0, num_sub_blocks)\n offs_n = tl.arange(0, BLOCK // num_sub_blocks)\n ret = tl.load(offset_ptr + block_id * num_sub_blocks + offs_sub)[\n None, :] * BLOCK // num_sub_blocks + offs_n[:, None]\n return tl.ravel(ret)\n else:\n offs_n = tl.arange(0, BLOCK)\n return tl.load(offset_ptr + block_id) * BLOCK + offs_n\n\n@triton.jit\ndef _fwd_split_kernel(\n Q, K, V, sm_scale, alibi_scale, B_kvlen, Block_offsets, Acc_out,\n stride_qbs, stride_qh, stride_qd, stride_kbs, stride_kh, stride_kd,\n stride_vbs, stride_vh, stride_vd, stride_ok, stride_obs, stride_oh,\n stride_od, stride_boffb, head_offset, num_heads, kv_group_num, block_per_cta,\n num_sub_blocks: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr\n):\n \"\"\"first step kernel of split k attention.\"\"\"\n cur_batch = tl.program_id(0)\n cur_head = tl.program_id(1)\n split_k_id = tl.program_id(2)\n\n cur_kv_head = cur_head // kv_group_num\n\n cur_batch_seq_len = 1\n cur_batch_kv_len = tl.load(B_kvlen + cur_batch)\n history_len = cur_batch_kv_len - cur_batch_seq_len\n\n # initialize offsets\n offs_n = tl.arange(0, BLOCK_N)\n offs_d = tl.arange(0, BLOCK_DMODEL)\n off_q = (cur_batch * stride_qbs + cur_head * stride_qh +\n offs_d * stride_qd)\n off_k = (cur_kv_head * stride_kh + offs_d[None, :] * stride_kd)\n off_v = (cur_kv_head * stride_vh + offs_d[None, :] * stride_vd)\n\n q = tl.load(Q + off_q).to(tl.float32)\n\n k_ptrs = K + off_k\n v_ptrs = V + off_v\n\n block_offset_ptrs = Block_offsets + cur_batch * stride_boffb\n head_slope = get_slope(\n cur_head.to(tl.float32) + head_offset, num_heads.to(tl.float32))\n\n # initialize pointer to m and l\n m_i = -float('inf')\n l_i = float(0)\n acc = tl.zeros([BLOCK_DMODEL], dtype=tl.float32)\n\n kv_len_per_prog = block_per_cta * BLOCK_N\n loop_start = kv_len_per_prog * split_k_id\n loop_end = tl.minimum(loop_start + kv_len_per_prog, cur_batch_kv_len)\n\n # load block offset\n start_block_id = loop_start // BLOCK_N\n b_offset = _load_block_offsets(block_offset_ptrs, start_block_id,\n num_sub_blocks, BLOCK_N)\n\n for start_n in range(loop_start, loop_end, BLOCK_N):\n start_n = tl.multiple_of(start_n, BLOCK_N)\n\n mask = (start_n + offs_n[:, None]) < cur_batch_kv_len\n\n # -- compute qk ----\n k = tl.load(\n k_ptrs + b_offset[:, None] * stride_kbs,\n mask=mask,\n other=0.0,\n )\n\n v = tl.load(\n v_ptrs + b_offset[:, None] * stride_vbs,\n mask=mask,\n other=0.0,\n )\n\n # prefetch b_offset\n if start_n + BLOCK_N < loop_end:\n start_block_id += 1\n b_offset = _load_block_offsets(block_offset_ptrs, start_block_id,\n num_sub_blocks, BLOCK_N)\n\n qk = tl.sum(q[None, :] * k, 1)\n qk *= sm_scale\n\n mask = start_n + offs_n\n bias = mask.to(tl.float32) * (head_slope * alibi_scale)\n qk += bias\n\n # NOTE: inf - inf = nan, and nan will leads to error\n qk = tl.where(\n history_len >= (start_n + offs_n),\n qk,\n -float('inf'),\n )\n\n # -- compute p, m_i and l_i\n m_i_new = tl.maximum(m_i, tl.max(qk, 0))\n p = tl.exp(qk - m_i_new)\n alpha = tl.exp(m_i - m_i_new)\n l_i_new = alpha * l_i + tl.sum(p, 0)\n\n # -- update output accumulator --\n # scale acc\n acc = acc * alpha\n\n # update acc\n p_new = p.to(v.dtype)\n acc += tl.sum(p_new[:, None] * v, 0)\n # update m_i and l_i\n l_i = l_i_new\n m_i = m_i_new\n\n # initialize pointers to output\n off_acc = (cur_batch * stride_obs + split_k_id * stride_ok +\n cur_head * stride_oh + offs_d * stride_od)\n tl.store(Acc_out + off_acc, acc)\n\n off_meta = (cur_batch * stride_obs + split_k_id * stride_ok +\n cur_head * stride_oh + BLOCK_DMODEL)\n tl.store(Acc_out + off_meta + tl.arange(0, 1), m_i)\n tl.store(Acc_out + off_meta + 1 + tl.arange(0, 1), l_i)\n\n@triton.jit\ndef _reduce_split_kernel(\n Acc, Out, stride_ak, stride_abs, stride_ah, stride_ad,\n stride_obs, stride_oh, stride_od, SPLIT_K: tl.constexpr,\n BLOCK_DMODEL: tl.constexpr,\n):\n \"\"\"second step kernel of split k attention.\"\"\"\n cur_batch = tl.program_id(0)\n cur_head = tl.program_id(1)\n\n # initialize offsets\n offs_d = tl.arange(0, BLOCK_DMODEL)\n offs_k = tl.arange(0, SPLIT_K)\n\n offs_acc = (cur_batch * stride_abs + cur_head * stride_ah +\n offs_k[:, None] * stride_ak + offs_d[None, :] * stride_ad)\n offs_mi = (cur_batch * stride_abs + cur_head * stride_ah +\n stride_ak * offs_k + BLOCK_DMODEL)\n\n acc_k = tl.load(Acc + offs_acc)\n m_k = tl.load(Acc + offs_mi)\n l_k = tl.load(Acc + offs_mi + 1)\n\n m_max = tl.max(m_k, 0)\n alpha = tl.exp(m_k - m_max)\n acc_k = acc_k * alpha[:, None]\n l_k = l_k * alpha\n\n acc = tl.sum(acc_k, 0)\n l_sum = tl.sum(l_k, 0)\n acc = acc / l_sum\n\n out_offs = (cur_batch * stride_obs + cur_head * stride_oh +\n offs_d * stride_od)\n tl.store(Out + out_offs, acc)\n\ndef alibi_paged_attention_fwd(\n q: Tensor, k: Tensor, v: Tensor, o: Tensor,\n block_offsets: Tensor, b_start_loc: Tensor,\n b_seq_len: Tensor, b_kv_seq_len: Tensor,\n max_input_len: int, head_offset: int = 0,\n num_heads: int = -1, alibi_scale: float = 1.0,\n k_scales_zeros: Tensor = None, v_scales_zeros: Tensor = None,\n quant_policy: Literal[0, 4, 8] = 0,\n):\n \"\"\"Paged attention forward with alibi bias.\"\"\"\n # shape constraints\n Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]\n if quant_policy == 4:\n assert Lq == Lk * 2 and Lk == Lv\n assert Lk in {8, 16, 32, 64}\n else:\n assert Lq == Lk and Lk == Lv\n assert Lk in {16, 32, 64, 128}\n\n sm_scale = 1.0 / (Lq**0.5) # 计算scale系数\n batch, head = b_seq_len.shape[0], q.shape[-2]\n kv_group_num = q.shape[-2] // k[0].shape[-2]\n if num_heads <= 0:\n num_heads = head\n\n BLOCK = 64 if k.size(1) < 16 else k.size(1)\n num_sub_blocks = BLOCK // k.size(1)\n\n grid = (batch, head, triton.cdiv(max_input_len, BLOCK)) # batch, head,\n\n num_warps = 4 if Lq <= 64 else 8\n kernel_meta = get_kernel_meta(q)\n is_decoding = q.shape[-3] == b_seq_len.size(0)\n if not is_decoding:\n if quant_policy > 0:\n _fwd_kernel_quant[grid](q,\n k,\n v,\n k_scales_zeros,\n v_scales_zeros,\n sm_scale,\n alibi_scale,\n b_start_loc,\n b_seq_len,\n b_kv_seq_len,\n block_offsets,\n o,\n q.stride(-3),\n q.stride(-2),\n q.stride(-1),\n k.stride(-3),\n k.stride(-2),\n k.stride(-1),\n v.stride(-3),\n v.stride(-2),\n v.stride(-1),\n k_scales_zeros.stride(-3),\n k_scales_zeros.stride(-2),\n k_scales_zeros.stride(-1),\n v_scales_zeros.stride(-3),\n v_scales_zeros.stride(-2),\n v_scales_zeros.stride(-1),\n quant_policy,\n o.stride(-3),\n o.stride(-2),\n o.stride(-1),\n block_offsets.stride(0),\n head_offset=head_offset,\n num_heads=num_heads,\n kv_group_num=kv_group_num,\n num_sub_blocks=num_sub_blocks,\n BLOCK_M=BLOCK,\n BLOCK_DMODEL=Lq,\n BLOCK_N=BLOCK,\n num_warps=num_warps,\n num_stages=1,\n **kernel_meta)\n else:\n _fwd_kernel[grid](q,\n k,\n v,\n sm_scale,\n alibi_scale,\n b_start_loc,\n b_seq_len,\n b_kv_seq_len,\n block_offsets,\n o,\n q.stride(-3),\n q.stride(-2),\n q.stride(-1),\n k.stride(-3),\n k.stride(-2),\n k.stride(-1),\n v.stride(-3),\n v.stride(-2),\n v.stride(-1),\n o.stride(-3),\n o.stride(-2),\n o.stride(-1),\n block_offsets.stride(0),\n head_offset=head_offset,\n num_heads=num_heads,\n kv_group_num=kv_group_num,\n num_sub_blocks=num_sub_blocks,\n BLOCK_M=BLOCK,\n BLOCK_DMODEL=Lq,\n BLOCK_N=BLOCK,\n num_warps=num_warps,\n num_stages=1,\n **kernel_meta)\n else:\n SPLIT_K = 4\n grid = (batch, head, SPLIT_K)\n block_per_cta = triton.cdiv(block_offsets.size(-1), SPLIT_K)\n acc = q.new_empty(batch, head, SPLIT_K, Lq + 2, dtype=torch.float32)\n if quant_policy > 0:\n _fwd_split_kernel_quant[grid](\n q,\n k,\n v,\n k_scales_zeros,\n v_scales_zeros,\n sm_scale,\n alibi_scale,\n b_kv_seq_len,\n block_offsets,\n acc,\n stride_qbs=q.stride(-3),\n stride_qh=q.stride(-2),\n stride_qd=q.stride(-1),\n stride_kbs=k.stride(-3),\n stride_kh=k.stride(-2),\n stride_kd=k.stride(-1),\n stride_vbs=v.stride(-3),\n stride_vh=v.stride(-2),\n stride_vd=v.stride(-1),\n stride_kszbs=k_scales_zeros.stride(-3),\n stride_kszh=k_scales_zeros.stride(-2),\n stride_kszd=k_scales_zeros.stride(-1),\n stride_vszbs=v_scales_zeros.stride(-3),\n stride_vszh=v_scales_zeros.stride(-2),\n stride_vszd=v_scales_zeros.stride(-1),\n quant_policy=quant_policy,\n stride_ok=acc.stride(-2),\n stride_obs=acc.stride(-4),\n stride_oh=acc.stride(-3),\n stride_od=acc.stride(-1),\n stride_boffb=block_offsets.stride(0),\n head_offset=head_offset,\n num_heads=num_heads,\n kv_group_num=kv_group_num,\n block_per_cta=block_per_cta,\n num_sub_blocks=num_sub_blocks,\n BLOCK_DMODEL=Lq,\n BLOCK_N=BLOCK,\n num_warps=4,\n num_stages=1,\n **kernel_meta)\n\n else:\n _fwd_split_kernel[grid](q,\n k,\n v,\n sm_scale,\n alibi_scale,\n b_kv_seq_len,\n block_offsets,\n acc,\n stride_qbs=q.stride(-3),\n stride_qh=q.stride(-2),\n stride_qd=q.stride(-1),\n stride_kbs=k.stride(-3),\n stride_kh=k.stride(-2),\n stride_kd=k.stride(-1),\n stride_vbs=v.stride(-3),\n stride_vh=v.stride(-2),\n stride_vd=v.stride(-1),\n stride_ok=acc.stride(-2),\n stride_obs=acc.stride(-4),\n stride_oh=acc.stride(-3),\n stride_od=acc.stride(-1),\n stride_boffb=block_offsets.stride(0),\n head_offset=head_offset,\n num_heads=num_heads,\n kv_group_num=kv_group_num,\n block_per_cta=block_per_cta,\n num_sub_blocks=num_sub_blocks,\n BLOCK_DMODEL=Lq,\n BLOCK_N=BLOCK,\n num_warps=4,\n num_stages=1,\n **kernel_meta)\n\n grid = (batch, head)\n _reduce_split_kernel[grid](acc,\n o,\n stride_ak=acc.stride(-2),\n stride_abs=acc.stride(-4),\n stride_ah=acc.stride(-3),\n stride_ad=acc.stride(-1),\n stride_obs=o.stride(-3),\n stride_oh=o.stride(-2),\n stride_od=o.stride(-1),\n SPLIT_K=SPLIT_K,\n BLOCK_DMODEL=Lq,\n num_warps=num_warps,\n num_stages=1,\n **kernel_meta)\n", - "description_1": "Use triton language to implement a split-k attention mechanism with forward and reduction kernels. This involves loading block offsets, computing QK products with scaling, handling attention weights and accumulation, and finally storing outputs with considerations for quantization.", - "description_2": "Use triton language to compute paged attention with alibi bias. This includes determining scales and shapes, setting up kernel grids, and invoking forward and reduction kernels with quantization support if necessary.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\nfrom torch import Tensor\n\n@triton.jit(do_not_specialize=('seq_len', ))\ndef apply_rotary_pos_emb_qk_kernel(\n Q,\n K,\n COS,\n SIN,\n Q_EMB,\n K_EMB,\n seq_len,\n stride_qs: tl.constexpr,\n stride_qh: tl.constexpr,\n stride_qd: tl.constexpr,\n stride_ks: tl.constexpr,\n stride_kh: tl.constexpr,\n stride_kd: tl.constexpr,\n stride_qes: tl.constexpr,\n stride_qeh: tl.constexpr,\n stride_qed: tl.constexpr,\n stride_kes: tl.constexpr,\n stride_keh: tl.constexpr,\n stride_ked: tl.constexpr,\n half_size: tl.constexpr,\n BLOCK: tl.constexpr,\n BLOCK_QH: tl.constexpr,\n BLOCK_N: tl.constexpr,\n):\n \"\"\"apply rotary on key AND query kernel.\"\"\"\n seq_block_id = tl.program_id(0)\n head_id = tl.program_id(1)\n\n pos_offset = seq_block_id * BLOCK + tl.arange(0, BLOCK)\n pos_mask = pos_offset < seq_len\n pos_offset = tl.max_contiguous(tl.multiple_of(pos_offset % seq_len, BLOCK),\n BLOCK)\n\n feat_size = half_size * 2\n feat_offset_l = tl.arange(0, BLOCK_N)\n feat_mask = feat_offset_l < half_size\n feat_offset_l = feat_offset_l % half_size\n feat_offset_h = half_size + feat_offset_l\n seq_mask = pos_mask[:, None] and feat_mask[None, :]\n cs_offset_l = pos_offset[:, None] * feat_size + feat_offset_l[None, :]\n cs_offset_h = pos_offset[:, None] * feat_size + feat_offset_h[None, :]\n q_elem_type = Q.dtype.element_ty\n cos_l = tl.load(COS + cs_offset_l).to(q_elem_type)\n cos_h = tl.load(COS + cs_offset_h).to(q_elem_type)\n sin_l = tl.load(SIN + cs_offset_l).to(q_elem_type)\n sin_h = tl.load(SIN + cs_offset_h).to(q_elem_type)\n\n if head_id < BLOCK_QH:\n q_ptr = Q + pos_offset * stride_qs\n qe_ptr = Q_EMB + pos_offset * stride_qes\n ql_ptrs = q_ptr[:, None] + feat_offset_l[None, :] * stride_qd\n qh_ptrs = q_ptr[:, None] + feat_offset_h[None, :] * stride_qd\n qel_ptrs = qe_ptr[:, None] + feat_offset_l[None, :] * stride_qed\n qeh_ptrs = qe_ptr[:, None] + feat_offset_h[None, :] * stride_qed\n ql_ptrs += head_id * stride_qh\n qh_ptrs += head_id * stride_qh\n qel_ptrs += head_id * stride_qeh\n qeh_ptrs += head_id * stride_qeh\n\n q_l = tl.load(ql_ptrs)\n q_h = tl.load(qh_ptrs)\n qe_l = q_l * cos_l - q_h * sin_l\n qe_h = q_h * cos_h + q_l * sin_h\n\n tl.store(qel_ptrs, qe_l, mask=seq_mask)\n tl.store(qeh_ptrs, qe_h, mask=seq_mask)\n else:\n head_id = head_id - BLOCK_QH\n k_ptr = K + pos_offset * stride_ks\n ke_ptr = K_EMB + pos_offset * stride_kes\n kl_ptrs = k_ptr[:, None] + feat_offset_l[None, :] * stride_kd\n kh_ptrs = k_ptr[:, None] + feat_offset_h[None, :] * stride_kd\n kel_ptrs = ke_ptr[:, None] + feat_offset_l[None, :] * stride_ked\n keh_ptrs = ke_ptr[:, None] + feat_offset_h[None, :] * stride_ked\n kl_ptrs += head_id * stride_kh\n kh_ptrs += head_id * stride_kh\n kel_ptrs += head_id * stride_keh\n keh_ptrs += head_id * stride_keh\n k_l = tl.load(kl_ptrs)\n k_h = tl.load(kh_ptrs)\n ke_l = k_l * cos_l - k_h * sin_l\n ke_h = k_h * cos_h + k_l * sin_h\n\n tl.store(kel_ptrs, ke_l, mask=seq_mask)\n tl.store(keh_ptrs, ke_h, mask=seq_mask)\n\n\ndef apply_rotary_pos_emb(q: Tensor,\n k: Tensor,\n cos: Tensor,\n sin: Tensor,\n q_embed: Tensor = None,\n k_embed: Tensor = None):\n \"\"\"Apply rotary positional embedding on query and key.\n\n Args:\n q (Tensor): Query state.\n k (Tensor): Key state.\n cos (Tensor): cosine matrix (seq_len, dim).\n sin (Tensor): sine matrix (seq_len, dim).\n q_embed (Tensor): output q, can be same as q\n k_embed (Tensor): output k, can be same as k\n\n Returns:\n Tuple[Tensor, Tensor]: Embedded query and key.\n \"\"\"\n if cos.device != q.device:\n cos = cos.to(device=q.device)\n if sin.device != q.device:\n sin = sin.to(device=q.device)\n\n if q_embed is None:\n q_embed = torch.empty_like(q)\n if k_embed is None:\n k_embed = torch.empty_like(k)\n\n seq_len = cos.numel() // cos.size(-1)\n BLOCK = 16\n half_size = q.size(-1) // 2\n BLOCK_N = triton.next_power_of_2(half_size)\n num_heads_q = q.size(-2)\n num_heads_k = k.size(-2)\n num_warps = 4\n num_stages = 4\n\n grid = [triton.cdiv(seq_len, BLOCK), num_heads_q + num_heads_k]\n apply_rotary_pos_emb_qk_kernel[grid](q,\n k,\n cos,\n sin,\n q_embed,\n k_embed,\n seq_len=seq_len,\n stride_qs=q.stride(-3),\n stride_qh=q.stride(-2),\n stride_qd=q.stride(-1),\n stride_ks=k.stride(-3),\n stride_kh=k.stride(-2),\n stride_kd=k.stride(-1),\n stride_qes=q_embed.stride(-3),\n stride_qeh=q_embed.stride(-2),\n stride_qed=q_embed.stride(-1),\n stride_kes=k_embed.stride(-3),\n stride_keh=k_embed.stride(-2),\n stride_ked=k_embed.stride(-1),\n half_size=half_size,\n BLOCK=BLOCK,\n BLOCK_QH=num_heads_q,\n BLOCK_N=BLOCK_N,\n num_warps=num_warps,\n num_stages=num_stages)\n\n return q_embed, k_embed\n", - "description_1": "Use triton language to define a kernel that applies rotary positional embeddings on query (Q) and key (K) tensors. This kernel takes 21 parameters including Q, K, COS, SIN, Q_EMB, K_EMB, and several stride and size constants. It calculates positional offsets, loads cosine and sine values, and applies rotary transformations on query and key vectors. The function apply_rotary_pos_emb wraps this kernel to facilitate the computation, taking 6 parameters (q, k, cos, sin, q_embed, k_embed) and setting up kernel execution parameters such as grid size and memory strides.", - "description_2": "Use triton language to implement a kernel for rotary positional embeddings on input tensors, and provide a wrapper function to manage inputs and kernel execution.", - "difficulty": 3 - }, - { - "code": "import triton\nfrom triton import language as tl\n\n@triton.jit\ndef _get_unpacked_order(offs_n, elem_per_int: tl.constexpr):\n \"\"\"get unpacked order.\"\"\"\n origin_order = offs_n % elem_per_int\n unpacked_order = (origin_order & 1) * 4 + origin_order // 2\n return unpacked_order\n\n@triton.jit\ndef _broadcast_pack(weight, width: tl.constexpr):\n \"\"\"broadcast pack.\"\"\"\n broadcast_tmp = tl.arange(0, width)\n BLOCK_SIZE_K: tl.constexpr = weight.shape[0]\n BLOCK_SIZE_QN: tl.constexpr = weight.shape[1]\n BLOCK_SIZE_N: tl.constexpr = BLOCK_SIZE_QN * width\n weight = tl.broadcast(weight[:, :, None], broadcast_tmp[None, None, :])\n weight = tl.reshape(weight, (BLOCK_SIZE_K, BLOCK_SIZE_N))\n return weight\n\n@triton.jit\ndef _unpack_weight(weight, order):\n \"\"\"unpack weight.\"\"\"\n weight = _broadcast_pack(weight, 8)\n weight = weight >> (order * 4)\n # cast to float16\n immLut = (0xf0 & 0xcc) | 0xaa\n BOTTOM_MASK = 0xf\n I4s_TO_F16s_MAGIC_NUM = 0x6400\n FP16_TOP_MAGIC_NUM = 0x6400\n weight = tl.inline_asm_elementwise(\n \"\"\"lop3.b32 $1, $1, $2, $3, $4;\n sub.f16x2 $1, $1, $5;\n mov.b32 {$0, _}, $1;\"\"\",\n '=h, r, n, n, n, r', [\n weight, BOTTOM_MASK, I4s_TO_F16s_MAGIC_NUM, immLut,\n FP16_TOP_MAGIC_NUM\n ],\n dtype=tl.float16,\n is_pure=False,\n pack=1)\n return weight\n\n@triton.jit\ndef awq_linear_kernel(\n a_ptr,\n qw_ptr,\n s_ptr,\n qz_ptr,\n c_ptr,\n M,\n N: tl.constexpr,\n K: tl.constexpr,\n stride_am,\n stride_ak: tl.constexpr, #\n stride_wk: tl.constexpr,\n stride_wn: tl.constexpr, #\n stride_sk: tl.constexpr,\n stride_sn: tl.constexpr, #\n stride_zk: tl.constexpr,\n stride_zn: tl.constexpr, #\n stride_cm,\n stride_ck: tl.constexpr,\n stride_cn: tl.constexpr,\n # Meta-parameters\n M_NEXT_P2: tl.constexpr,\n Q_GROUP_SIZE: tl.constexpr,\n SPLIT_K_ITERS: tl.constexpr,\n BLOCK_SIZE_M: tl.constexpr,\n BLOCK_SIZE_N: tl.constexpr,\n BLOCK_SIZE_K: tl.constexpr, #\n GROUP_SIZE_M: tl.constexpr, #\n):\n \"\"\"Kernel for computing the matmul C = A x B.\n A has shape (M, K), B has shape (K, N) and C has shape (M, N)\n \"\"\"\n ELEM_PER_INT = 8\n if Q_GROUP_SIZE > BLOCK_SIZE_K:\n GROUP_SIZE_K: tl.constexpr = BLOCK_SIZE_K\n else:\n GROUP_SIZE_K: tl.constexpr = Q_GROUP_SIZE\n K_PER_GROUP: tl.constexpr = Q_GROUP_SIZE // GROUP_SIZE_K\n\n # -----------------------------------------------------------\n # Map program ids `pid` to the block of C it should compute.\n # This is done in a grouped ordering to promote L2 data reuse.\n # See above `L2 Cache Optimizations` section for details.\n pid = tl.program_id(axis=0)\n split_kid = tl.program_id(axis=1)\n num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)\n num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)\n num_pid_in_group = GROUP_SIZE_M * num_pid_n\n group_id = pid // num_pid_in_group\n first_pid_m = group_id * GROUP_SIZE_M\n group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)\n pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)\n pid_n = (pid % num_pid_in_group) // group_size_m\n\n # ----------------------------------------------------------\n # Create pointers for the first blocks of A and B.\n offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M\n offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N\n BLOCK_SIZE_QN: tl.constexpr = BLOCK_SIZE_N // 8\n offs_wn = pid_n * BLOCK_SIZE_QN + tl.arange(0, BLOCK_SIZE_QN)\n offs_k = tl.arange(0, GROUP_SIZE_K)\n unpacked_order = _get_unpacked_order(offs_bn, ELEM_PER_INT)\n a_ptrs = a_ptr + (offs_am[:, None] * stride_am +\n offs_k[None, :] * stride_ak)\n qw_ptrs = qw_ptr + (offs_k[:, None] * stride_wk +\n offs_wn[None, :] * stride_wn)\n s_ptrs = s_ptr + offs_bn * stride_sn\n qz_ptrs = qz_ptr + offs_wn * stride_zn\n\n # split k\n NUM_K_BLOCKS = K // GROUP_SIZE_K\n K_PER_SPLIT = tl.cdiv(NUM_K_BLOCKS, SPLIT_K_ITERS)\n k_start = split_kid * K_PER_SPLIT\n k_last = min(k_start + K_PER_SPLIT, NUM_K_BLOCKS)\n a_ptrs += k_start * GROUP_SIZE_K * stride_ak\n qw_ptrs += k_start * GROUP_SIZE_K * stride_wk\n qg_id = k_start // K_PER_GROUP\n\n # -----------------------------------------------------------\n # Iterate to compute a block of the C matrix.\n # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block\n # of fp32 values for higher accuracy.\n # `accumulator` will be converted back to fp16 after the loop.\n accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n s = tl.zeros((1, BLOCK_SIZE_N), dtype=s_ptrs.dtype.element_ty)\n zs = tl.zeros((1, BLOCK_SIZE_N), dtype=s_ptrs.dtype.element_ty)\n\n # prefetch\n next_qw = tl.load(qw_ptrs)\n qw_ptrs += GROUP_SIZE_K * stride_wk\n\n for k in range(k_start, k_last):\n a = tl.load(a_ptrs)\n qw = next_qw\n if k + 1 < k_last:\n next_qw = tl.load(qw_ptrs)\n w = _unpack_weight(qw, unpacked_order)\n\n if k == k_start or k % K_PER_GROUP == 0:\n s = tl.load(s_ptrs + qg_id * stride_sk)[None, :]\n qz = tl.load(qz_ptrs + qg_id * stride_zk)[None, :]\n qg_id += 1\n z = _unpack_weight(qz, unpacked_order)\n zs = -z * s\n b = w * s + zs\n\n # We accumulate along the K dimension.\n accumulator += tl.dot(a, b)\n\n # Advance the ptrs to the next K block.\n a_ptrs += GROUP_SIZE_K * stride_ak\n qw_ptrs += GROUP_SIZE_K * stride_wk\n\n c = accumulator.to(tl.float16)\n\n # -----------------------------------------------------------\n # Write back the block of the output matrix C with masks.\n offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n c_ptrs = c_ptr + stride_cm * offs_cm[:,\n None] + stride_cn * offs_cn[None, :]\n c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)\n if stride_ck > 0:\n c_ptrs += split_kid * stride_ck\n tl.store(c_ptrs, c, mask=c_mask)\n else:\n tl.atomic_add(c_ptrs, c, mask=c_mask)\n\ndef awq_linear(x, qweight, scales, qzeros):\n \"\"\"awq linear.\"\"\"\n M = x.size(0)\n K = qweight.size(0)\n N = scales.size(1)\n SPLIT_K_ITERS = 4\n group_size = K // scales.size(0)\n\n def grid(META):\n \"\"\"grid.\"\"\"\n return (triton.cdiv(M, META['BLOCK_SIZE_M']) *\n triton.cdiv(N, META['BLOCK_SIZE_N']), SPLIT_K_ITERS)\n\n out = scales.new_empty(M, SPLIT_K_ITERS, N)\n M_NEXT_P2 = triton.next_power_of_2(M)\n\n awq_linear_kernel[grid](\n # Pointers to matrices\n x,\n qweight,\n scales,\n qzeros,\n out,\n # Matrix dimensions\n M,\n N,\n K,\n stride_am=x.stride(0),\n stride_ak=x.stride(1), #\n stride_wk=qweight.stride(0),\n stride_wn=qweight.stride(1), #\n stride_sk=scales.stride(0),\n stride_sn=scales.stride(1), #\n stride_zk=qzeros.stride(0),\n stride_zn=qzeros.stride(1), #\n stride_cm=out.stride(0),\n stride_ck=out.stride(1),\n stride_cn=out.stride(2),\n # Meta-parameters\n M_NEXT_P2=M_NEXT_P2,\n Q_GROUP_SIZE=group_size,\n SPLIT_K_ITERS=SPLIT_K_ITERS)\n\n return out.sum(1)\n", - "description_1": "Use triton language to implement a matrix multiplication kernel with quantization support. The kernel function 'awq_linear_kernel' takes 30 parameters: 5 pointers to matrices (a_ptr, qw_ptr, s_ptr, qz_ptr, c_ptr), 3 matrix dimensions (M, N, K), 11 stride parameters for accessing matrix elements, and 11 meta-parameters for controlling the kernel execution. The kernel computes the product of matrices A and B, where A has shape (M, K), B has shape (K, N), and the result C has shape (M, N). The function 'awq_linear' is a wrapper that prepares the input data and calls the kernel with appropriate grid and meta-parameters.", - "description_2": "Use triton language to create a quantized matrix multiplication kernel with support for custom block sizes and group sizes, optimizing for L2 cache reuse and allowing for split-K iterations.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n\ndef get_autotune_config():\n \"\"\"get autotune config.\"\"\"\n return [\n triton.Config(\n {\n 'BLOCK_SIZE_M': 64,\n 'BLOCK_SIZE_N': 256,\n 'BLOCK_SIZE_K': 128\n },\n num_stages=4,\n num_warps=4),\n triton.Config(\n {\n 'BLOCK_SIZE_M': 16,\n 'BLOCK_SIZE_N': 256,\n 'BLOCK_SIZE_K': 128\n },\n num_stages=4,\n num_warps=4),\n ]\n\n\n@triton.autotune(\n configs=get_autotune_config(),\n key=['N', 'K'],\n)\n@triton.jit\ndef _fused_lora_kernel(\n a_ptr,\n lora_a_ptr,\n lora_b_ptr,\n c_ptr,\n scaling_ptr,\n rank_start_ptr,\n ranks_ptr,\n seq_start_ptr,\n seq_lens_ptr,\n adapter_ids_ptr,\n N: tl.constexpr,\n K: tl.constexpr,\n stride_am: tl.constexpr,\n stride_ak: tl.constexpr,\n stride_lar: tl.constexpr,\n stride_lak: tl.constexpr,\n stride_lbr: tl.constexpr,\n stride_lbn: tl.constexpr,\n stride_cm: tl.constexpr,\n stride_cn: tl.constexpr,\n BLOCK_SIZE_R: tl.constexpr,\n BLOCK_SIZE_M: tl.constexpr,\n BLOCK_SIZE_N: tl.constexpr,\n BLOCK_SIZE_K: tl.constexpr,\n):\n \"\"\"fused lora kernel.\"\"\"\n pid = tl.program_id(axis=0)\n bid = tl.program_id(axis=1)\n\n M = tl.load(seq_lens_ptr + bid)\n if M <= 0:\n return\n\n seq_start = tl.load(seq_start_ptr + bid)\n adapter_id = tl.load(adapter_ids_ptr + bid)\n rank_start = tl.load(rank_start_ptr + adapter_id)\n rank = tl.load(ranks_ptr + adapter_id)\n\n num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)\n num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)\n GROUP_SIZE_M: tl.constexpr = 1\n num_pid_in_group = GROUP_SIZE_M * num_pid_n\n group_id = pid // num_pid_in_group\n first_pid_m = group_id * GROUP_SIZE_M\n group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)\n pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)\n pid_n = (pid % num_pid_in_group) // group_size_m\n\n if pid_m * BLOCK_SIZE_M >= M:\n return\n\n offs_m = (seq_start + pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M))\n offs_n = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N))\n\n mask_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) < M\n if rank == 0:\n offs_cm = offs_m\n offs_cn = offs_n\n c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[\n None, :]\n c_mask = mask_cm[:, None] & (offs_cn[None, :] < N)\n tl.store(c_ptrs, 0, mask=c_mask)\n return\n\n offs_am = (seq_start +\n (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M)\n offs_r = rank_start + tl.arange(0, BLOCK_SIZE_R) % rank\n offs_k = tl.arange(0, BLOCK_SIZE_K)\n a_ptrs = a_ptr + (offs_am[:, None] * stride_am +\n offs_k[None, :] * stride_ak)\n la_ptrs = lora_a_ptr + (offs_k[:, None] * stride_lak +\n offs_r[None, :] * stride_lar)\n\n accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_R), dtype=tl.float32)\n for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):\n a = tl.load(a_ptrs,\n mask=offs_k[None, :] < K - k * BLOCK_SIZE_K,\n other=0.0)\n la = tl.load(la_ptrs,\n mask=offs_k[:, None] < K - k * BLOCK_SIZE_K,\n other=0.0)\n accumulator += tl.dot(a, la)\n a_ptrs += BLOCK_SIZE_K * stride_ak\n la_ptrs += BLOCK_SIZE_K * stride_lak\n ar = accumulator.to(lora_b_ptr.dtype.element_ty)\n\n offs_lbn = offs_n % N\n lb_ptrs = lora_b_ptr + (offs_r[:, None] * stride_lbr +\n offs_lbn * stride_lbn)\n lb = tl.load(lb_ptrs, mask=tl.arange(0, BLOCK_SIZE_R)[:, None] < rank)\n\n c = tl.dot(ar, lb)\n\n scaling = tl.load(scaling_ptr + adapter_id)\n c *= scaling\n\n c = c.to(c_ptr.dtype.element_ty)\n offs_cm = offs_m\n offs_cn = offs_n\n c_ptrs = c_ptr + stride_cm * offs_cm[:,\n None] + stride_cn * offs_cn[None, :]\n c_mask = mask_cm[:, None] & (offs_cn[None, :] < N)\n tl.store(c_ptrs, c, mask=c_mask)\n\n\ndef fused_lora(input: torch.Tensor, lora_a: torch.Tensor, lora_b: torch.Tensor,\n scaling: torch.LongTensor, rank_start: torch.LongTensor,\n ranks: torch.LongTensor, seq_start: torch.LongTensor,\n seq_lens: torch.LongTensor, adapter_ids: torch.LongTensor,\n max_rank: int, max_seqlen: int):\n \"\"\"fused lora.\"\"\"\n\n def grid(META):\n ret = ((triton.cdiv(max_seqlen, META['BLOCK_SIZE_M']) *\n triton.cdiv(N, META['BLOCK_SIZE_N'])), batch_size)\n return ret\n\n assert input.dim() == 2\n batch_size = seq_lens.numel()\n M, K = input.shape\n N = lora_b.size(1)\n\n output = input.new_empty((M, N))\n\n BLOCK_SIZE_R = max(16, max_rank)\n _fused_lora_kernel[grid](\n input,\n lora_a,\n lora_b,\n output,\n scaling,\n rank_start,\n ranks,\n seq_start,\n seq_lens,\n adapter_ids,\n N,\n K,\n stride_am=input.stride(0),\n stride_ak=input.stride(1),\n stride_lar=lora_a.stride(0),\n stride_lak=lora_a.stride(1),\n stride_lbr=lora_b.stride(0),\n stride_lbn=lora_b.stride(1),\n stride_cm=output.stride(0),\n stride_cn=output.stride(1),\n BLOCK_SIZE_R=BLOCK_SIZE_R,\n )\n\n return output\n", - "description_1": "Use triton language to implement a fused kernel for the LoRA (Low-Rank Adaptation) mechanism, optimizing matrix multiplication with LoRA matrices. This involves efficiently loading, processing, and storing data using Triton programs, with 26 parameters defining pointers, constants, strides, and block sizes.", - "description_2": "Use triton language to create a fused kernel for optimized matrix multiplication using the LoRA technique, managing data through pointers and block sizes.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef fused_moe_kernel(\n A,\n B,\n C,\n SortedIdx,\n ExpStart,\n ExpEnd,\n Weights,\n N: tl.constexpr,\n K: tl.constexpr,\n stride_am: tl.constexpr,\n stride_ak: tl.constexpr,\n stride_be: tl.constexpr,\n stride_bn: tl.constexpr,\n stride_bk: tl.constexpr,\n stride_cm: tl.constexpr,\n stride_cn: tl.constexpr,\n BLOCK_SIZE_M: tl.constexpr,\n BLOCK_SIZE_N: tl.constexpr,\n BLOCK_SIZE_K: tl.constexpr,\n GROUP_SIZE_M: tl.constexpr,\n M_NP2: tl.constexpr,\n ENABLE_WEIGHTS: tl.constexpr,\n top_k: tl.constexpr,\n expert_offset: tl.constexpr,\n reindex_a: tl.constexpr,\n reindex_c: tl.constexpr,\n):\n \"\"\"fused moe kernel.\"\"\"\n exp_id = tl.program_id(1)\n pid = tl.program_id(0)\n\n exp_start = tl.load(ExpStart + exp_id + expert_offset)\n exp_end = tl.load(ExpEnd + exp_id + expert_offset)\n M = exp_end - exp_start\n if M <= 0:\n return\n\n num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)\n num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)\n num_pid_in_group = GROUP_SIZE_M * num_pid_n\n group_id = pid // num_pid_in_group\n first_pid_m = group_id * GROUP_SIZE_M\n group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)\n pid_m = first_pid_m + (pid % group_size_m)\n pid_n = (pid % num_pid_in_group) // group_size_m\n\n if pid_m * BLOCK_SIZE_M >= M:\n return\n\n offs_sid = exp_start + pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n mask_sid = offs_sid < exp_end\n sid = tl.load(SortedIdx + offs_sid, mask=mask_sid, other=0)\n\n offs_k = tl.arange(0, BLOCK_SIZE_K)\n if reindex_a:\n offs_am = sid // top_k\n else:\n offs_am = offs_sid\n a_ptrs = A + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)\n offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N\n offs_bn = tl.max_contiguous(tl.multiple_of(offs_bn, BLOCK_SIZE_N),\n BLOCK_SIZE_N)\n\n # deepseek has 160 experts, exp index would overflow int32\n exp_off = stride_be * exp_id.to(tl.int64)\n b_ptrs = B + exp_off + (offs_k[:, None] * stride_bk +\n offs_bn[None, :] * stride_bn)\n\n accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n\n for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):\n a = tl.load(a_ptrs,\n mask=mask_sid[:, None] &\n (offs_k[None, :] < K - k * BLOCK_SIZE_K),\n other=0.0)\n b = tl.load(b_ptrs,\n mask=offs_k[:, None] < K - k * BLOCK_SIZE_K,\n other=0.0)\n accumulator += tl.dot(a, b)\n a_ptrs += BLOCK_SIZE_K * stride_ak\n b_ptrs += BLOCK_SIZE_K * stride_bk\n\n if ENABLE_WEIGHTS:\n weight = tl.load(Weights + sid, mask=mask_sid)\n accumulator = accumulator * weight[:, None].to(accumulator.dtype)\n\n c = accumulator.to(A.dtype.element_ty)\n\n if reindex_c:\n offs_cm = sid\n else:\n offs_cm = offs_sid\n c_ptrs = C + stride_cm * offs_cm[:, None] + stride_cn * offs_bn[None, :]\n tl.store(c_ptrs, c, mask=mask_sid[:, None])\n\n\ndef fused_moe_kernel_launcher(\n A: torch.Tensor,\n B: torch.Tensor,\n C: torch.Tensor,\n sorted_idx: torch.Tensor,\n exp_start: torch.Tensor,\n exp_end: torch.Tensor,\n weights: torch.Tensor,\n enable_weights: bool = False,\n top_k: int = 1,\n num_tokens: int = None,\n expert_offset: int = 0,\n reindex_a: bool = True,\n reindex_c: bool = True,\n):\n \"\"\"fused moe kernel launcher.\"\"\"\n\n if num_tokens is None:\n num_tokens = A.size(0)\n M_NP2 = triton.next_power_of_2(num_tokens)\n M_NP2 = max(32, M_NP2)\n E, N, K = B.shape\n\n def _grid_fn(META):\n grid = (triton.cdiv(num_tokens, META['BLOCK_SIZE_M']) *\n triton.cdiv(N, META['BLOCK_SIZE_N']), E)\n return grid\n\n A = A.flatten(0, -2)\n C = C.flatten(0, -2)\n\n grid = _grid_fn\n kernel_meta = get_kernel_meta(A)\n fused_moe_kernel[grid](\n A,\n B,\n C,\n sorted_idx,\n exp_start,\n exp_end,\n weights,\n N=N,\n K=K,\n stride_am=A.stride(0),\n stride_ak=A.stride(1),\n stride_be=B.stride(0),\n stride_bn=B.stride(1),\n stride_bk=B.stride(2),\n stride_cm=C.stride(0),\n stride_cn=C.stride(1),\n ENABLE_WEIGHTS=enable_weights,\n top_k=top_k,\n expert_offset=expert_offset,\n reindex_a=reindex_a,\n reindex_c=reindex_c,\n M_NP2=M_NP2,\n **kernel_meta,\n )\n\n\n@triton.jit\ndef _start_end_kernel(TopkIdx, SortedIdx, ExpStart, ExpEnd,\n len_sorted_idx: int, num_experts: tl.constexpr,\n BLOCK: tl.constexpr):\n \"\"\"start end kernel.\"\"\"\n exp_id = tl.program_id(0)\n exp_start = -1\n cnt = 0\n\n s_off = tl.arange(0, BLOCK)\n\n # find start\n for sidx_start in range(0, len_sorted_idx, BLOCK):\n sidx_off = sidx_start + s_off\n sidx_mask = sidx_off < len_sorted_idx\n sidx = tl.load(SortedIdx + sidx_off, mask=sidx_mask, other=0)\n tidx = tl.load(TopkIdx + sidx, mask=sidx_mask, other=num_experts)\n tidx_mask = tidx == exp_id\n cnt += tl.sum(tidx_mask.to(tl.int32))\n if cnt > 0 and exp_start < 0:\n exp_start = sidx_start + tl.argmax(tidx_mask, axis=0)\n\n if exp_start < 0:\n exp_start *= 0\n exp_end = exp_start + cnt\n tl.store(ExpStart + exp_id, exp_start)\n tl.store(ExpEnd + exp_id, exp_end)\n\n\ndef get_start_end(topk_idx: torch.Tensor, sorted_idx: torch.Tensor,\n num_experts: int):\n \"\"\"get start and end.\n\n same process as:\n >>> exp_tok_cnt = F.one_hot(flatten_topk_ids, num_classes=E).sum(0)\n >>> exp_end = exp_tok_cnt.cumsum(0)\n >>> exp_start = exp_end - exp_tok_cnt\n \"\"\"\n start_end = sorted_idx.new_empty(2, num_experts)\n exp_start = start_end[0, :]\n exp_end = start_end[1, :]\n\n BLOCK = 128\n kernel_meta = get_kernel_meta(topk_idx)\n _start_end_kernel[(num_experts, )](\n topk_idx,\n sorted_idx,\n exp_start,\n exp_end,\n len_sorted_idx=sorted_idx.numel(),\n num_experts=num_experts,\n BLOCK=BLOCK,\n num_warps=4,\n num_stages=1,\n **kernel_meta,\n )\n\n return exp_start, exp_end\n\n\ndef fused_moe(hidden_states: torch.Tensor,\n w1: torch.Tensor,\n w2: torch.Tensor,\n topk_weights: torch.Tensor,\n topk_ids: torch.Tensor,\n topk: int,\n expert_offset: int = 0,\n num_experts: int = None,\n renormalize: bool = False) -> torch.Tensor:\n \"\"\"fused moe.\"\"\"\n M = hidden_states.size(0)\n E, N, _ = w1.shape\n full_exp = False\n if num_experts is None:\n num_experts = E\n elif num_experts == E:\n full_exp = True\n\n def __get_sorted_idx(topk_ids: torch.Tensor):\n flatten_topk_ids = topk_ids.flatten()\n sorted_idx = flatten_topk_ids.argsort()\n\n exp_start, exp_end = get_start_end(flatten_topk_ids, sorted_idx,\n num_experts)\n return sorted_idx, exp_start, exp_end\n\n if renormalize:\n topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)\n if not topk_weights.is_contiguous():\n topk_weights = topk_weights.contiguous()\n\n sorted_idx, exp_start, exp_end = __get_sorted_idx(topk_ids)\n\n if full_exp:\n intermediate_cache1 = hidden_states.new_empty((M, topk, N))\n else:\n intermediate_cache1 = hidden_states.new_zeros((M, topk, N))\n # gate and up\n fused_moe_kernel_launcher(\n hidden_states,\n w1,\n intermediate_cache1,\n sorted_idx=sorted_idx,\n exp_start=exp_start,\n exp_end=exp_end,\n weights=topk_weights,\n enable_weights=False,\n top_k=topk,\n num_tokens=M,\n expert_offset=expert_offset,\n reindex_a=True,\n reindex_c=False,\n )\n\n # activate\n unflat_size = intermediate_cache1.shape[:-1]\n intermediate_cache1 = intermediate_cache1.flatten(0, -2)\n gate_cache = silu_and_mul(intermediate_cache1)\n gate_cache = gate_cache.unflatten(0, unflat_size)\n\n if full_exp:\n intermediate_cache2 = hidden_states.new_empty((M, topk, w2.shape[1]))\n else:\n intermediate_cache2 = hidden_states.new_zeros((M, topk, w2.shape[1]))\n # down\n fused_moe_kernel_launcher(\n gate_cache,\n w2,\n intermediate_cache2,\n sorted_idx=sorted_idx,\n exp_start=exp_start,\n exp_end=exp_end,\n weights=topk_weights,\n enable_weights=True,\n top_k=1,\n num_tokens=M,\n expert_offset=expert_offset,\n reindex_a=False,\n reindex_c=True,\n )\n\n ret = intermediate_cache2.sum(dim=1)\n return ret\n", - "description_1": "Use triton language to implement a fused Mixture of Experts (MoE) kernel, which performs matrix multiplication and weighting for selected experts, along with supporting functions to determine start and end of expert sections.", - "description_2": "Use triton language to create a fused kernel for Mixture of Experts (MoE) operations, including sorting and expert section computations.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\nfrom torch import Tensor\n\n@triton.jit\ndef _fused_rotary_emb_kernel(\n Q, K, PostionIds, InvFreq, scaling_factor, OutQ, OutK, stride_bq,\n stride_sq, stride_hq: tl.constexpr, stride_dq: tl.constexpr, stride_bk,\n stride_sk, stride_hk: tl.constexpr, stride_dk: tl.constexpr, stride_bp,\n stride_sp, max_seq_len, BLOCK: tl.constexpr, BLOCK_HQ: tl.constexpr,\n BLOCK_HK: tl.constexpr, BLOCK_F: tl.constexpr):\n \"\"\"fused rotary emb kernel.\"\"\"\n batch_id = tl.program_id(0)\n seq_block_id = tl.program_id(1)\n\n s_off = seq_block_id * BLOCK + tl.arange(0, BLOCK)[:, None]\n f_off = tl.arange(0, BLOCK_F)[None, :]\n s_mask = s_off < max_seq_len\n\n bp_off = stride_bp * batch_id\n p_off = bp_off + stride_sp * s_off\n\n sq_off = batch_id * stride_bq + s_off * stride_sq\n q0_off = sq_off + f_off * stride_dq\n q1_off = q0_off + BLOCK_F * stride_dq\n\n sk_off = batch_id * stride_bk + s_off * stride_sk\n k0_off = sk_off + f_off * stride_dk\n k1_off = k0_off + BLOCK_F * stride_dk\n\n inv_freq = tl.load(InvFreq + f_off).to(tl.float32)\n position_ids = tl.load(PostionIds + p_off, mask=s_mask).to(tl.float32)\n position_ids = position_ids / scaling_factor\n\n # pos_freq = tl.dot(position_ids, inv_freq)\n pos_freq = position_ids * inv_freq\n cos = tl.cos(pos_freq).to(Q.dtype.element_ty)\n sin = tl.sin(pos_freq).to(Q.dtype.element_ty)\n\n for h in range(BLOCK_HQ):\n q0 = tl.load(Q + q0_off + h * stride_hq, mask=s_mask)\n q1 = tl.load(Q + q1_off + h * stride_hq, mask=s_mask)\n q0_out = q0 * cos - q1 * sin\n tl.store(OutQ + q0_off + h * stride_hq, q0_out, mask=s_mask)\n q1_out = q1 * cos + q0 * sin\n tl.store(OutQ + q1_off + h * stride_hq, q1_out, mask=s_mask)\n\n for h in range(BLOCK_HK):\n k0 = tl.load(K + k0_off + h * stride_hk, mask=s_mask)\n k1 = tl.load(K + k1_off + h * stride_hk, mask=s_mask)\n k0_out = k0 * cos - k1 * sin\n tl.store(OutK + k0_off + h * stride_hk, k0_out, mask=s_mask)\n k1_out = k1 * cos + k0 * sin\n tl.store(OutK + k1_off + h * stride_hk, k1_out, mask=s_mask)\n\n\ndef fused_rotary_emb(q: Tensor,\n k: Tensor,\n position_ids: torch.LongTensor,\n inv_freq: Tensor,\n scaling_factor: float,\n out_q: Tensor = None,\n out_k: Tensor = None):\n \"\"\"Fuse `rotary_embedding` and `apply_rotary_pos_emb`.\"\"\"\n\n if out_q is None:\n out_q = torch.empty_like(q)\n else:\n assert q.stride() == out_q.stride()\n if out_k is None:\n out_k = torch.empty_like(k)\n else:\n assert k.stride() == out_k.stride()\n\n assert q.dim() == 4\n assert k.dim() == 4\n assert q.size(0) == position_ids.size(0)\n\n BLOCK = 32\n BLOCK_HQ = q.size(-2)\n BLOCK_HK = k.size(-2)\n BLOCK_F = q.size(-1) // 2\n batch_size = q.size(0)\n max_seq_len = q.size(1)\n kernel_meta = get_kernel_meta(q)\n num_warps = 4\n\n grid = (batch_size, triton.cdiv(max_seq_len, BLOCK))\n _fused_rotary_emb_kernel[grid](q,\n k,\n position_ids,\n inv_freq,\n scaling_factor,\n out_q,\n out_k,\n stride_bq=q.stride(0),\n stride_sq=q.stride(1),\n stride_hq=q.stride(2),\n stride_dq=q.stride(3),\n stride_bk=k.stride(0),\n stride_sk=k.stride(1),\n stride_hk=k.stride(2),\n stride_dk=k.stride(3),\n stride_bp=position_ids.stride(0),\n stride_sp=position_ids.stride(1),\n max_seq_len=max_seq_len,\n BLOCK=BLOCK,\n BLOCK_HQ=BLOCK_HQ,\n BLOCK_HK=BLOCK_HK,\n BLOCK_F=BLOCK_F,\n num_warps=num_warps,\n num_stages=1,\n **kernel_meta)\n\n return out_q, out_k\n", - "description_1": "Use triton language to define a kernel function '_fused_rotary_emb_kernel' with 22 parameters including tensors Q, K, PostionIds, InvFreq, float scaling_factor, tensors OutQ, OutK, and integer strides. This kernel applies rotary embeddings to the input tensors and outputs transformed tensors OutQ and OutK. The kernel uses batch and sequence block IDs for parallel processing, loading tensor slices based on calculated offsets and applying frequency-based sinusoidal transformations. The function 'fused_rotary_emb' is a wrapper with 7 parameters including input tensors q, k, position_ids, inv_freq, float scaling_factor, and optional output tensors out_q, out_k. It prepares necessary configuration for launching the kernel with defined grid and block sizes.", - "description_2": "Use triton language to create a fused rotary embedding kernel with 22 parameters that applies frequency-based sinusoidal transformations on input tensors for optimized GPU execution, and a wrapper function with 7 parameters that configures and launches the kernel.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef tanh(x):\n \"\"\"tanh.\"\"\"\n return 2 * tl.sigmoid(2 * x) - 1\n\nfast_expf = tl.math.exp\nfast_dividef = tl.math.fdiv\n\n@triton.autotune(configs=[\n triton.Config({}, num_stages=2, num_warps=16),\n triton.Config({}, num_stages=2, num_warps=8),\n triton.Config({}, num_stages=2, num_warps=4),\n],\n key=['BLOCK_H', 'BLOCK_N', 'BLOCK_DMODEL', 'BLOCK_DV'])\n@triton.jit\ndef _fwd_grouped_split_kernel(\n Q,\n K,\n V,\n sm_scale,\n KV_seqlens,\n Block_offsets,\n Acc_out,\n stride_qbs: tl.constexpr,\n stride_qh: tl.constexpr,\n stride_qd: tl.constexpr,\n stride_kp: tl.constexpr,\n stride_kbs: tl.constexpr,\n stride_kh: tl.constexpr,\n stride_kd: tl.constexpr,\n stride_vp: tl.constexpr,\n stride_vbs: tl.constexpr,\n stride_vh: tl.constexpr,\n stride_vd: tl.constexpr,\n stride_ok: tl.constexpr,\n stride_obs: tl.constexpr,\n stride_oh: tl.constexpr,\n stride_od: tl.constexpr,\n stride_boffb,\n kv_group_num: tl.constexpr,\n window_size: tl.constexpr,\n head_size: tl.constexpr,\n head_size_v: tl.constexpr,\n num_heads_q: tl.constexpr,\n logit_softcapping: tl.constexpr,\n SPLIT_K: tl.constexpr,\n BLOCK_DMODEL: tl.constexpr,\n BLOCK_DV: tl.constexpr,\n BLOCK_N: tl.constexpr,\n BLOCK_H: tl.constexpr,\n BLOCK_DMODEL1: tl.constexpr,\n):\n \"\"\"first step kernel of split k attention.\"\"\"\n # Kernel implementation here\n\n@triton.autotune(configs=[\n triton.Config({}, num_stages=2, num_warps=16),\n triton.Config({}, num_stages=2, num_warps=8),\n triton.Config({}, num_stages=2, num_warps=4),\n],\n key=['BLOCK_H', 'BLOCK_N', 'BLOCK_DMODEL', 'BLOCK_DV'])\n@triton.jit\ndef _fwd_grouped_split_quant_kernel(\n Q,\n K,\n V,\n KScalesZeros,\n VScalesZeros,\n sm_scale,\n KV_seqlens,\n Block_offsets,\n Acc_out,\n stride_qbs: tl.constexpr,\n stride_qh: tl.constexpr,\n stride_qd: tl.constexpr,\n stride_kp: tl.constexpr,\n stride_kbs: tl.constexpr,\n stride_kh: tl.constexpr,\n stride_kd: tl.constexpr,\n stride_vp: tl.constexpr,\n stride_vbs: tl.constexpr,\n stride_vh: tl.constexpr,\n stride_vd: tl.constexpr,\n stride_kszp: tl.constexpr,\n stride_kszbs: tl.constexpr,\n stride_kszh: tl.constexpr,\n stride_kszd: tl.constexpr,\n stride_vszp: tl.constexpr,\n stride_vszbs: tl.constexpr,\n stride_vszh: tl.constexpr,\n stride_vszd: tl.constexpr,\n quant_policy: tl.constexpr,\n stride_ok: tl.constexpr,\n stride_obs: tl.constexpr,\n stride_oh: tl.constexpr,\n stride_od: tl.constexpr,\n stride_boffb,\n kv_group_num: tl.constexpr,\n window_size: tl.constexpr,\n head_size: tl.constexpr,\n head_size_v: tl.constexpr,\n num_heads_q: tl.constexpr,\n logit_softcapping: tl.constexpr,\n SPLIT_K: tl.constexpr,\n BLOCK_DMODEL: tl.constexpr,\n BLOCK_DV: tl.constexpr,\n BLOCK_N: tl.constexpr,\n BLOCK_H: tl.constexpr,\n BLOCK_DMODEL1: tl.constexpr,\n):\n \"\"\"first step kernel of split k attention with quantization.\"\"\"\n # Kernel implementation here\n\n@triton.jit\ndef _reduce_split_kernel(\n Acc,\n Out,\n stride_ak,\n stride_abs,\n stride_ah,\n stride_ad,\n stride_obs,\n stride_oh,\n stride_od,\n head_size_v: tl.constexpr,\n SPLIT_K: tl.constexpr,\n BLOCK_DV: tl.constexpr,\n):\n \"\"\"second step kernel of split k attention.\"\"\"\n # Kernel implementation here\n\ndef paged_attention_fwd(\n q: torch.Tensor,\n k: torch.Tensor,\n v: torch.Tensor,\n o: torch.Tensor,\n block_offsets: torch.Tensor,\n q_start_loc: torch.Tensor,\n q_seqlens: torch.Tensor,\n kv_seqlens: torch.Tensor,\n max_seqlen: int,\n k_scales_zeros: torch.Tensor = None,\n v_scales_zeros: torch.Tensor = None,\n quant_policy: int = 0,\n window_size: int = None,\n sm_scale: float = None,\n logit_softcapping: float = None,\n):\n \"\"\"Paged Attention forward.\"\"\"\n # Function implementation here\n", - "description_1": "Use triton language to implement kernels for split k attention with optional quantization and a paged attention forward pass. The primary kernels are _fwd_grouped_split_kernel and _fwd_grouped_split_quant_kernel for computing initial attention and _reduce_split_kernel for reducing results. The paged_attention_fwd function orchestrates these kernels to perform the attention operation, handling various parameters like query, key, value tensors and their dimensions.", - "description_2": "Use triton language to create kernels for split k attention with quantization, performing a paged attention forward pass. The process includes initial computation and reduction of attention scores, managed by a central function that configures and invokes these kernels.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\nfrom torch import Tensor\nfrom .triton_utils import get_kernel_meta\n\n@triton.jit\ndef _compute_rms_norm(x, w, eps: tl.constexpr, N_COLS: tl.constexpr):\n \"\"\"compute rms norm.\"\"\"\n xf = x.to(tl.float32)\n var = tl.sum(xf * xf, 0) * float(1.0 / N_COLS)\n out = xf * tl.math.rsqrt(var + eps)\n out = (w * out).to(x.dtype)\n return out\n\n@triton.jit\ndef rms_norm_kernel(input, weight, output, input_row_stride: tl.constexpr,\n eps: tl.constexpr, N_COLS: tl.constexpr,\n BLOCK_N: tl.constexpr):\n \"\"\"rms norm kernel.\"\"\"\n prog_id = tl.program_id(0)\n offsets = tl.arange(0, BLOCK_N)\n w = tl.load(weight + offsets, mask=offsets < N_COLS)\n x_ptr = input + prog_id * input_row_stride\n x = tl.load(x_ptr + offsets, mask=offsets < N_COLS)\n out = _compute_rms_norm(x, w, eps, N_COLS)\n out_ptr = output + prog_id * input_row_stride\n tl.store(out_ptr + offsets, out, mask=offsets < N_COLS)\n\n@triton.jit\ndef add_rms_norm_kernel(input, weight, residual, output, out_residual,\n input_row_stride: tl.constexpr,\n residual_row_stride: tl.constexpr, eps: tl.constexpr,\n N_COLS: tl.constexpr, BLOCK_N: tl.constexpr):\n \"\"\"rms norm kernel with additional residual.\"\"\"\n prog_id = tl.program_id(0)\n offsets = tl.arange(0, BLOCK_N)\n w = tl.load(weight + offsets, mask=offsets < N_COLS)\n x_ptr = input + prog_id * input_row_stride\n x = tl.load(x_ptr + offsets, mask=offsets < N_COLS)\n res_ptr = residual + prog_id * residual_row_stride\n res = tl.load(res_ptr + offsets, mask=offsets < N_COLS)\n new_x = x + res\n out_res_ptr = out_residual + prog_id * residual_row_stride\n tl.store(out_res_ptr + offsets, new_x, mask=offsets < N_COLS)\n out = _compute_rms_norm(new_x, w, eps, N_COLS)\n out_ptr = output + prog_id * input_row_stride\n tl.store(out_ptr + offsets, out, mask=offsets < N_COLS)\n\ndef rms_norm(hidden_states: Tensor,\n weight: Tensor,\n eps: float = 1e-6,\n residual: Tensor = None,\n out: Tensor = None,\n out_residual: Tensor = None):\n \"\"\"rms norm function calling Triton kernels.\"\"\"\n if not hidden_states.is_contiguous():\n hidden_states = hidden_states.contiguous()\n\n feat_size = weight.shape[0]\n seq_len = hidden_states.numel() // hidden_states.size(-1)\n input_stride = hidden_states.stride(-2)\n\n BLOCK_N = triton.next_power_of_2(feat_size)\n\n if out is None:\n out = torch.empty_like(hidden_states)\n\n kernel_meta = get_kernel_meta(hidden_states)\n grid = (seq_len, )\n\n if residual is None:\n rms_norm_kernel[grid](hidden_states,\n weight,\n out,\n input_row_stride=input_stride,\n eps=eps,\n N_COLS=feat_size,\n BLOCK_N=BLOCK_N,\n num_warps=4,\n num_stages=2,\n **kernel_meta)\n return out\n else:\n if out_residual is None:\n out_residual = torch.empty_like(hidden_states)\n\n res_stride = residual.stride(-2)\n add_rms_norm_kernel[grid](hidden_states,\n weight,\n residual,\n out,\n out_residual,\n input_row_stride=input_stride,\n residual_row_stride=res_stride,\n eps=eps,\n N_COLS=feat_size,\n BLOCK_N=BLOCK_N,\n num_warps=4,\n num_stages=2,\n **kernel_meta)\n return out, out_residual\n", - "description_1": "Use triton language to implement rms norm kernels. The kernels handle two cases: one without residual addition and one with residual addition. The _compute_rms_norm kernel takes four arguments: x (Tensor), w (Tensor), eps (float, constant expression), and N_COLS (int, constant expression) to compute the RMS norm. The rms_norm_kernel function has seven parameters: input (Tensor), weight (Tensor), output (Tensor), input_row_stride (int, constant expression), eps (float, constant expression), N_COLS (int, constant expression), and BLOCK_N (int, constant expression). It computes the rms norm of the input without a residual. The add_rms_norm_kernel function has ten parameters: input (Tensor), weight (Tensor), residual (Tensor), output (Tensor), out_residual (Tensor), input_row_stride (int, constant expression), residual_row_stride (int, constant expression), eps (float, constant expression), N_COLS (int, constant expression), and BLOCK_N (int, constant expression). It computes the rms norm of the input with an added residual.", - "description_2": "Use triton language to implement an rms norm operation with optional residual addition. The kernels are parameterized for tensor inputs and include both cases of including and excluding residual contributions in computations.", - "difficulty": 2 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\nfrom .triton_utils import get_kernel_meta\n\n@triton.autotune(\n configs=[\n triton.Config({\n 'BLOCK_N': 64,\n 'BLOCK_K': 128,\n },\n num_stages=4,\n num_warps=4),\n triton.Config({\n 'BLOCK_N': 128,\n 'BLOCK_K': 128,\n },\n num_stages=4,\n num_warps=4)\n ],\n key=['N', 'K'],\n)\n@triton.jit\ndef _linear(\n A,\n B,\n C,\n M,\n N,\n K,\n stride_am,\n stride_ak,\n stride_bk,\n stride_bn,\n stride_cm,\n stride_cn,\n BLOCK_M: tl.constexpr,\n BLOCK_N: tl.constexpr,\n BLOCK_K: tl.constexpr,\n GROUP_SIZE_M: tl.constexpr,\n rms_scale_ptr,\n linear_scale_ptr,\n):\n pid = tl.program_id(axis=0)\n num_pid_m = tl.cdiv(M, BLOCK_M)\n num_pid_n = tl.cdiv(N, BLOCK_N)\n num_pid_in_group = GROUP_SIZE_M * num_pid_n\n group_id = pid // num_pid_in_group\n first_pid_m = group_id * GROUP_SIZE_M\n group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)\n pid_m = first_pid_m + (pid % group_size_m)\n pid_n = (pid % num_pid_in_group) // group_size_m\n\n offs_am = (pid_m * BLOCK_M + tl.arange(0, BLOCK_M)) % M\n offs_bn = (pid_n * BLOCK_N + tl.arange(0, BLOCK_N)) % N\n offs_k = tl.arange(0, BLOCK_K)\n a_ptrs = A + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)\n b_ptrs = B + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)\n\n accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.int32)\n for k in range(0, tl.cdiv(K, BLOCK_K)):\n a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_K, other=0)\n b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_K, other=0)\n accumulator += tl.dot(a, b)\n a_ptrs += BLOCK_K * stride_ak\n b_ptrs += BLOCK_K * stride_bk\n c = accumulator.to(tl.float32)\n\n rms_scale = tl.load(rms_scale_ptr + offs_am)[:, None]\n linear_scale = tl.load(linear_scale_ptr + offs_bn)[None, :]\n c = c * rms_scale * linear_scale\n\n offs_cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)\n offs_cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)\n c_ptrs = C + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]\n c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)\n tl.store(c_ptrs, c, mask=c_mask)\n\n\n@triton.autotune(\n configs=[\n triton.Config({\n 'BLOCK_N': 64,\n 'BLOCK_K': 128,\n },\n num_stages=4,\n num_warps=4),\n triton.Config({\n 'BLOCK_N': 128,\n 'BLOCK_K': 128,\n },\n num_stages=4,\n num_warps=4)\n ],\n key=['N', 'K'],\n)\n@triton.jit\ndef _linear_add(\n A,\n B,\n C,\n residual_ptr,\n M,\n N,\n K,\n stride_am,\n stride_ak,\n stride_bk,\n stride_bn,\n stride_cm,\n stride_cn,\n BLOCK_M: tl.constexpr,\n BLOCK_N: tl.constexpr,\n BLOCK_K: tl.constexpr,\n GROUP_SIZE_M: tl.constexpr,\n rms_scale_ptr,\n linear_scale_ptr,\n):\n pid = tl.program_id(axis=0)\n num_pid_m = tl.cdiv(M, BLOCK_M)\n num_pid_n = tl.cdiv(N, BLOCK_N)\n num_pid_in_group = GROUP_SIZE_M * num_pid_n\n group_id = pid // num_pid_in_group\n first_pid_m = group_id * GROUP_SIZE_M\n group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)\n pid_m = first_pid_m + (pid % group_size_m)\n pid_n = (pid % num_pid_in_group) // group_size_m\n\n offs_am = (pid_m * BLOCK_M + tl.arange(0, BLOCK_M)) % M\n offs_bn = (pid_n * BLOCK_N + tl.arange(0, BLOCK_N)) % N\n offs_k = tl.arange(0, BLOCK_K)\n a_ptrs = A + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)\n b_ptrs = B + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)\n\n accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.int32)\n for k in range(0, tl.cdiv(K, BLOCK_K)):\n a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_K, other=0)\n b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_K, other=0)\n accumulator += tl.dot(a, b)\n a_ptrs += BLOCK_K * stride_ak\n b_ptrs += BLOCK_K * stride_bk\n c = accumulator.to(tl.float32)\n\n rms_scale = tl.load(rms_scale_ptr + offs_am)[:, None]\n linear_scale = tl.load(linear_scale_ptr + offs_bn)[None, :]\n c = c * rms_scale * linear_scale\n c = c.to(residual_ptr.dtype.element_ty)\n\n offs_cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)\n offs_cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)\n\n residual_ptrs = (residual_ptr + stride_cm * offs_cm[:, None] +\n stride_cn * offs_cn[None, :])\n c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)\n residual = tl.load(residual_ptrs, mask=c_mask, other=0.)\n c_ptrs = C + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]\n tl.store(c_ptrs, c + residual, mask=c_mask)\n\n\ndef matmul_kernel_dynamic_quant(a,\n b,\n rms_scale,\n linear_scale,\n residual=None,\n bias=None,\n output_dtype=torch.float16):\n assert a.shape[-1] == b.shape[-1]\n assert b.ndim == 2 and b.is_contiguous()\n M = a.numel() // a.shape[-1]\n N, K = b.shape\n c_shape = a.shape[:-1] + (N, )\n if residual is not None:\n assert residual.shape == c_shape\n assert residual.is_contiguous()\n c = a.new_empty(c_shape, dtype=output_dtype)\n\n BLOCK_M = 128\n if M < BLOCK_M:\n BLOCK_M = triton.next_power_of_2(M)\n BLOCK_M = max(BLOCK_M, 16)\n\n def grid(META):\n return (triton.cdiv(M, BLOCK_M) * triton.cdiv(N, META['BLOCK_N']), )\n\n kernel_meta = get_kernel_meta(a)\n if residual is not None:\n _linear_add[grid](a,\n b,\n c,\n residual,\n M,\n N,\n K,\n a.stride(-2),\n a.stride(-1),\n b.stride(1),\n b.stride(0),\n c.stride(-2),\n c.stride(-1),\n BLOCK_M=BLOCK_M,\n GROUP_SIZE_M=8,\n rms_scale_ptr=rms_scale,\n linear_scale_ptr=linear_scale,\n **kernel_meta)\n else:\n _linear[grid](a,\n b,\n c,\n M,\n N,\n K,\n a.stride(-2),\n a.stride(-1),\n b.stride(1),\n b.stride(0),\n c.stride(-2),\n c.stride(-1),\n BLOCK_M=BLOCK_M,\n GROUP_SIZE_M=8,\n rms_scale_ptr=rms_scale,\n linear_scale_ptr=linear_scale,\n **kernel_meta)\n if bias is not None:\n c += bias\n\n return c\n\n\n@triton.jit\ndef _per_token_quant_int8(\n y_ptr,\n y_q_ptr,\n y_s_ptr,\n y_stride,\n N, # number of columns in X\n eps, # epsilon to avoid division by zero\n BLOCK: tl.constexpr,\n):\n row = tl.program_id(0)\n y_ptr += row * y_stride\n y_q_ptr += row * y_stride\n y_s_ptr += row\n\n cols = tl.arange(0, BLOCK) # N <= BLOCK\n mask = cols < N\n\n y = tl.load(y_ptr + cols, mask=mask, other=0.).to(tl.float32)\n _absmax = tl.maximum(tl.max(tl.abs(y)), eps)\n y_s = _absmax / 127\n y_q = tl.math.round(y / y_s).to(tl.int8)\n\n tl.store(y_q_ptr + cols, y_q, mask=mask)\n tl.store(y_s_ptr, y_s)\n\n\ndef per_token_quant_int8(x, eps):\n x_q = torch.empty_like(x, device=x.device, dtype=torch.int8)\n M = x.numel() // x.shape[-1]\n N = x.shape[-1]\n x_s = torch.empty(x.shape[:-1] + (1, ),\n device=x.device,\n dtype=torch.float32)\n BLOCK = triton.next_power_of_2(N)\n num_warps = min(max(BLOCK // 256, 1), 8)\n kernel_meta = get_kernel_meta(x)\n _per_token_quant_int8[(M, )](x,\n x_q,\n x_s,\n x.stride(-2),\n N,\n eps,\n BLOCK=BLOCK,\n num_warps=num_warps,\n **kernel_meta)\n\n return x_q, x_s\n\n\n@triton.jit\ndef _rms_norm_fwd_fused_dynamic_symmetric(\n X, # pointer to the input\n Y, # pointer to the output\n W, # pointer to the weights\n Scale, # pointer to the scales of the output activation\n stride, # how much to increase the pointer when moving by 1 row\n N, # number of columns in X\n eps, # epsilon to avoid division by zero\n BLOCK_SIZE: tl.constexpr,\n):\n row = tl.program_id(0)\n Y += row * stride\n X += row * stride\n\n cols = tl.arange(0, BLOCK_SIZE)\n mask = cols < N\n x = tl.load(X + cols, mask=mask, other=0.).to(tl.float32)\n _var = x * x\n var = tl.sum(_var, axis=0) / N\n rstd = tl.math.rsqrt(var + eps)\n\n w = tl.load(W + cols, mask=mask)\n x_hat = x * rstd\n y = x_hat * w\n\n scale = tl.max(tl.abs(y)).to(tl.float32) / 127\n tl.store(Scale + row, scale)\n\n y = tl.math.round(y / scale)\n y = tl.minimum(y, 127)\n y = tl.maximum(y, -128)\n tl.store(Y + cols, y, mask=mask)\n\n\ndef rms_norm_dynamic_quant(x, w, eps):\n x_arg = x.flatten(0, -2)\n y = torch.empty_like(x, dtype=torch.int8)\n M, K = x_arg.shape\n MAX_FUSED_SIZE = 65536 // x.element_size()\n BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(K))\n if K > BLOCK_SIZE:\n raise RuntimeError(\n \"This rms norm doesn't support feature dim >= 64KB.\")\n num_warps = min(max(BLOCK_SIZE // 256, 1), 8)\n scale = x.new_empty(x.shape[:-1] + (1, ), dtype=torch.float32)\n kernel_meta = get_kernel_meta(x_arg)\n _rms_norm_fwd_fused_dynamic_symmetric[(M, )](x_arg,\n y,\n w,\n scale,\n x_arg.stride(0),\n K,\n eps,\n BLOCK_SIZE=BLOCK_SIZE,\n num_warps=num_warps,\n **kernel_meta)\n return y, scale\n", - "description_1": "Use triton language to implement a linear operation with optional residual addition, per-token quantization, and RMS normalization with dynamic quantization. The linear operation kernels (_linear and _linear_add) take matrices A and B, perform a dot product, and store the result in matrix C. The per-token quantization kernel (_per_token_quant_int8) quantizes a tensor into signed 8-bit integers. The RMS normalization kernel (_rms_norm_fwd_fused_dynamic_symmetric) normalizes input tensor X using RMS and applies dynamic symmetric quantization.", - "description_2": "Use triton language to implement matrix multiplication with optional residual addition and perform per-token quantization and RMS normalization with dynamic quantization.", - "difficulty": 4 - }, - { - "code": "import triton\nimport triton.language as tl\n\n@triton.jit\ndef _fwd_kernel_inner(\n acc,\n l_i,\n m_i,\n q,\n Q,\n k_block_col_idx,\n layout_col_ptr,\n layout_col_stride_h,\n layout_col_stride_m,\n k_ptrs,\n v_ptrs,\n off_h,\n offs_m,\n offs_n,\n offs_d,\n stride_kt,\n stride_vt,\n sm_scale,\n k_seqlen,\n past_len,\n LAST_K_BLOCK: tl.constexpr,\n BLOCK_M_LOADING: tl.constexpr,\n BLOCK_N: tl.constexpr,\n D_HEAD: tl.constexpr,\n EVEN_D: tl.constexpr,\n M_LT_N: tl.constexpr,\n):\n k_block_id = tl.load(layout_col_ptr + off_h * layout_col_stride_h +\n k_block_col_idx * layout_col_stride_m).to(tl.int32)\n start_n = k_block_id * BLOCK_N\n if LAST_K_BLOCK:\n if EVEN_D:\n k = tl.load(\n k_ptrs + start_n * stride_kt,\n mask=offs_n[None, :] + start_n < k_seqlen,\n )\n else:\n k = tl.load(\n k_ptrs + start_n * stride_kt,\n mask=(offs_n[None, :] + start_n < k_seqlen) &\n (offs_d[:, None] < D_HEAD),\n )\n else:\n if EVEN_D:\n k = tl.load(k_ptrs + start_n * stride_kt)\n else:\n k = tl.load(k_ptrs + start_n * stride_kt,\n mask=offs_d[:, None] < D_HEAD)\n\n qk = tl.zeros([BLOCK_M_LOADING, BLOCK_N], dtype=tl.float32)\n qk += tl.dot(q, k)\n qk *= sm_scale\n\n if LAST_K_BLOCK | M_LT_N:\n qk += tl.where(\n offs_m[:, None] + past_len >= (start_n + offs_n[None, :]),\n 0,\n float(\"-inf\"),\n )\n\n m_ij = tl.maximum(m_i, tl.max(qk, 1))\n p = tl.math.exp2(qk - m_ij[:, None])\n l_ij = tl.sum(p, 1)\n alpha = tl.math.exp2(m_i - m_ij)\n acc = acc * alpha[:, None]\n m_i = m_ij\n l_i = l_i * alpha + l_ij\n\n p = p.to(Q.dtype.element_ty)\n if LAST_K_BLOCK:\n if EVEN_D:\n v = tl.load(\n v_ptrs + start_n * stride_vt,\n mask=offs_n[:, None] + start_n < k_seqlen,\n )\n else:\n v = tl.load(\n v_ptrs + start_n * stride_vt,\n mask=(offs_n[:, None] + start_n < k_seqlen) &\n (offs_d[None, :] < D_HEAD),\n )\n else:\n if EVEN_D:\n v = tl.load(v_ptrs + start_n * stride_vt)\n else:\n v = tl.load(v_ptrs + start_n * stride_vt,\n mask=offs_d[None, :] < D_HEAD)\n\n acc += tl.dot(p, v)\n\n return acc, l_i, m_i\n\n\n@triton.heuristics({\n \"M_LT_N\":\n lambda kwargs: kwargs[\"BLOCK_M\"] < kwargs[\"BLOCK_N\"],\n})\n@triton.jit\ndef _fwd_kernel_batch_inference(\n Q,\n K,\n V,\n Out,\n sm_scale,\n q_batch_starts,\n q_batch_ends,\n k_batch_starts,\n k_batch_ends,\n q_batch_ids,\n q_start_sids,\n stride_qb,\n stride_qt,\n stride_qh,\n stride_qd,\n stride_kb,\n stride_kt,\n stride_kh,\n stride_kd,\n stride_vb,\n stride_vt,\n stride_vh,\n stride_vd,\n stride_ob,\n stride_ot,\n stride_oh,\n stride_od,\n layout_crow_ptr,\n layout_col_ptr,\n layout_crow_stride_h,\n layout_crow_stride_m,\n layout_col_stride_h,\n layout_col_stride_m,\n q_k_ratio,\n HAS_BATCH_DIM: tl.constexpr,\n D_HEAD: tl.constexpr,\n BLOCK_M: tl.constexpr,\n BLOCK_N: tl.constexpr,\n BLOCK_D: tl.constexpr,\n BLOCK_M_LOADING: tl.constexpr,\n EVEN_D: tl.constexpr,\n M_LT_N: tl.constexpr,\n):\n off_zm = tl.program_id(0)\n off_h = tl.program_id(1)\n\n off_h_for_kv = off_h // q_k_ratio\n\n if HAS_BATCH_DIM:\n off_z = tl.program_id(2)\n Q += off_z * stride_qb\n K += off_z * stride_kb\n V += off_z * stride_vb\n Out += off_z * stride_ob\n start_m = off_zm\n q_start_sid = start_m * BLOCK_M\n else:\n off_z = tl.load(q_batch_ids + off_zm).to(tl.int32)\n q_start_sid = tl.load(q_start_sids + off_zm)\n start_m = q_start_sid // BLOCK_M\n\n offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M_LOADING)\n offs_n = tl.arange(0, BLOCK_N)\n offs_d = tl.arange(0, BLOCK_D)\n\n q_cu_start = tl.load(q_batch_starts + off_z).to(tl.int32)\n q_seqlen = tl.load(q_batch_ends + off_z).to(tl.int32) - q_cu_start\n k_cu_start = tl.load(k_batch_starts + off_z).to(tl.int32)\n k_seqlen = tl.load(k_batch_ends + off_z).to(tl.int32) - k_cu_start\n past_len = k_seqlen - q_seqlen\n\n Q += q_cu_start * stride_qt + off_h * stride_qh\n K += k_cu_start * stride_kt + off_h_for_kv * stride_kh\n V += k_cu_start * stride_vt + off_h_for_kv * stride_vh\n Out += q_cu_start * stride_ot + off_h * stride_oh\n\n q_pbid = (past_len + q_start_sid) // BLOCK_M\n\n if EVEN_D:\n q = tl.load(\n Q + offs_m[:, None] * stride_qt + offs_d[None, :] * stride_qd,\n mask=offs_m[:, None] < q_seqlen,\n )\n else:\n q = tl.load(\n Q + offs_m[:, None] * stride_qt + offs_d[None, :] * stride_qd,\n mask=(offs_m[:, None] < q_seqlen) & (offs_d[None, :] < D_HEAD),\n other=0,\n )\n\n sparse_crow_ptr = (layout_crow_ptr + off_h * layout_crow_stride_h +\n q_pbid * layout_crow_stride_m)\n\n k_block_start = tl.load(sparse_crow_ptr).to(tl.int32)\n k_block_end = tl.load(sparse_crow_ptr + 1).to(tl.int32)\n\n m_i = tl.zeros([BLOCK_M_LOADING], dtype=tl.float32) - float(\"inf\")\n l_i = tl.zeros([BLOCK_M_LOADING], dtype=tl.float32)\n acc = tl.zeros([BLOCK_M_LOADING, BLOCK_D], dtype=tl.float32)\n\n k_ptrs = K + offs_n[None, :] * stride_kt + offs_d[:, None] * stride_kd\n v_ptrs = V + offs_n[:, None] * stride_vt + offs_d[None, :] * stride_vd\n\n sm_scale *= (\n 1.44269504\n )\n\n for k_block_col_idx in range(k_block_start, k_block_end - 1):\n acc, l_i, m_i = _fwd_kernel_inner(\n acc,\n l_i,\n m_i,\n q,\n Q,\n k_block_col_idx,\n layout_col_ptr,\n layout_col_stride_h,\n layout_col_stride_m,\n k_ptrs,\n v_ptrs,\n off_h,\n offs_m,\n offs_n,\n offs_d,\n stride_kt,\n stride_vt,\n sm_scale,\n k_seqlen,\n past_len,\n False,\n BLOCK_M_LOADING,\n BLOCK_N,\n D_HEAD,\n EVEN_D,\n M_LT_N,\n )\n\n acc, l_i, m_i = _fwd_kernel_inner(\n acc,\n l_i,\n m_i,\n q,\n Q,\n k_block_end - 1,\n layout_col_ptr,\n layout_col_stride_h,\n layout_col_stride_m,\n k_ptrs,\n v_ptrs,\n off_h,\n offs_m,\n offs_n,\n offs_d,\n stride_kt,\n stride_vt,\n sm_scale,\n k_seqlen,\n past_len,\n True,\n BLOCK_M_LOADING,\n BLOCK_N,\n D_HEAD,\n EVEN_D,\n M_LT_N,\n )\n\n m_i += tl.math.log2(l_i)\n acc = acc / l_i[:, None]\n\n if EVEN_D:\n tl.store(\n Out + offs_m[:, None] * stride_ot + offs_d[None, :] * stride_od,\n acc,\n mask=offs_m[:, None] < q_seqlen,\n )\n else:\n tl.store(\n Out + offs_m[:, None] * stride_ot + offs_d[None, :] * stride_od,\n acc,\n mask=(offs_m[:, None] < q_seqlen) & (offs_d[None, :] < D_HEAD),\n )\n", - "description_1": "Use triton language to create kernels for batched inference of sparse flash attention. The '_fwd_kernel_inner' kernel computes attention weights and updates accumulated attention. It takes 26 inputs, including shared memory pointers, strides, and constants that define the dimensions of the blocks being processed. The '_fwd_kernel_batch_inference' kernel orchestrates the entire process by handling the batching logic and preparing data for the inner kernel. It accepts 39 parameters with specifics on input tensors (Q, K, V, Out), their strides, and batching indices for query and key-value matrices.", - "description_2": "Use triton language to create a flash attention mechanism that efficiently handles variable sequence lengths with sparse data. Implement separate kernels for inner computation of attention weights and batch processing of these computations.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\nfrom vllm.platforms import current_platform\n\nif triton.__version__ >= \"2.1.0\":\n\n # Kernel function for forward attention\n @triton.jit\n def _fwd_kernel(\n Q,\n K,\n V,\n K_cache,\n V_cache,\n B_Loc,\n sm_scale,\n k_scale,\n v_scale,\n B_Start_Loc,\n B_Seqlen,\n B_Ctxlen,\n block_size,\n x,\n Out,\n stride_b_loc_b,\n stride_b_loc_s,\n stride_qbs,\n stride_qh,\n stride_qd,\n stride_kbs,\n stride_kh,\n stride_kd,\n stride_vbs,\n stride_vh,\n stride_vd,\n stride_obs,\n stride_oh,\n stride_od,\n stride_k_cache_bs,\n stride_k_cache_h,\n stride_k_cache_d,\n stride_k_cache_bl,\n stride_k_cache_x,\n stride_v_cache_bs,\n stride_v_cache_h,\n stride_v_cache_d,\n stride_v_cache_bl,\n num_queries_per_kv: int,\n BLOCK_M: tl.constexpr,\n BLOCK_DMODEL: tl.constexpr,\n BLOCK_DMODEL_PADDED: tl.constexpr,\n BLOCK_N: tl.constexpr,\n SLIDING_WINDOW: tl.constexpr,\n ):\n # Kernel logic here...\n\n return\n\n # Kernel function for forward attention with alibi\n @triton.jit\n def _fwd_kernel_alibi(\n Q,\n K,\n V,\n K_cache,\n V_cache,\n B_Loc,\n sm_scale,\n k_scale,\n v_scale,\n B_Start_Loc,\n B_Seqlen,\n B_Ctxlen,\n Alibi_slopes,\n block_size,\n x,\n Out,\n stride_b_loc_b,\n stride_b_loc_s,\n stride_qbs,\n stride_qh,\n stride_qd,\n stride_kbs,\n stride_kh,\n stride_kd,\n stride_vbs,\n stride_vh,\n stride_vd,\n stride_obs,\n stride_oh,\n stride_od,\n stride_k_cache_bs,\n stride_k_cache_h,\n stride_k_cache_d,\n stride_k_cache_bl,\n stride_k_cache_x,\n stride_v_cache_bs,\n stride_v_cache_h,\n stride_v_cache_d,\n stride_v_cache_bl,\n num_queries_per_kv: int,\n BLOCK_M: tl.constexpr,\n BLOCK_DMODEL: tl.constexpr,\n BLOCK_DMODEL_PADDED: tl.constexpr,\n BLOCK_N: tl.constexpr,\n ):\n # Kernel logic here...\n\n return\n\n # Function to perform context attention\n @torch.inference_mode()\n def context_attention_fwd(q,\n k,\n v,\n o,\n kv_cache_dtype: str,\n k_cache,\n v_cache,\n b_loc,\n b_start_loc,\n b_seq_len,\n b_ctx_len,\n max_input_len,\n k_scale: float = 1.0,\n v_scale: float = 1.0,\n alibi_slopes=None,\n sliding_window=None):\n # Function logic here...\n\n if alibi_slopes is not None:\n _fwd_kernel_alibi[grid](\n q,\n k,\n v,\n k_cache,\n v_cache,\n b_loc,\n sm_scale,\n k_scale,\n v_scale,\n b_start_loc,\n b_seq_len,\n b_ctx_len,\n alibi_slopes,\n v_cache.shape[3],\n k_cache.shape[4],\n o,\n b_loc.stride(0),\n b_loc.stride(1),\n q.stride(0),\n q.stride(1),\n q.stride(2),\n k.stride(0),\n k.stride(1),\n k.stride(2),\n v.stride(0),\n v.stride(1),\n v.stride(2),\n o.stride(0),\n o.stride(1),\n o.stride(2),\n k_cache.stride(0),\n k_cache.stride(1),\n k_cache.stride(2),\n k_cache.stride(3),\n k_cache.stride(4),\n v_cache.stride(0),\n v_cache.stride(1),\n v_cache.stride(2),\n v_cache.stride(3),\n num_queries_per_kv=num_queries_per_kv,\n BLOCK_M=BLOCK,\n BLOCK_DMODEL=Lk,\n BLOCK_DMODEL_PADDED=Lk_padded,\n BLOCK_N=BLOCK,\n num_warps=NUM_WARPS,\n num_stages=1,\n )\n return\n\n _fwd_kernel[grid](\n q,\n k,\n v,\n k_cache,\n v_cache,\n b_loc,\n sm_scale,\n k_scale,\n v_scale,\n b_start_loc,\n b_seq_len,\n b_ctx_len,\n v_cache.shape[3],\n k_cache.shape[4],\n o,\n b_loc.stride(0),\n b_loc.stride(1),\n q.stride(0),\n q.stride(1),\n q.stride(2),\n k.stride(0),\n k.stride(1),\n k.stride(2),\n v.stride(0),\n v.stride(1),\n v.stride(2),\n o.stride(0),\n o.stride(1),\n o.stride(2),\n k_cache.stride(0),\n k_cache.stride(1),\n k_cache.stride(2),\n k_cache.stride(3),\n k_cache.stride(4),\n v_cache.stride(0),\n v_cache.stride(1),\n v_cache.stride(2),\n v_cache.stride(3),\n num_queries_per_kv=num_queries_per_kv,\n BLOCK_M=BLOCK,\n BLOCK_DMODEL=Lk,\n BLOCK_DMODEL_PADDED=Lk_padded,\n BLOCK_N=BLOCK,\n SLIDING_WINDOW=sliding_window,\n num_warps=NUM_WARPS,\n num_stages=1,\n )\n return\n", - "description_1": "Use triton language to implement kernels for forward attention with different configurations: standard, with sliding window, and with alibi bias. The functions take numerous parameters, including query, key, value matrices, caches, and configuration constants to perform efficient attention calculations in a block-wise manner, considering masking and scaling.", - "description_2": "Use triton language to write a context attention forward function that utilizes different kernels based on the presence of alibi slopes and efficiently handles various scaling and masking strategies.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef cdiv_fn(x, y):\n return (x + y - 1) // y\n\n@triton.jit\ndef load_fn(block_ptr, first, second, pad):\n if first and second:\n tensor = tl.load(block_ptr, boundary_check=(0, 1), padding_option=pad)\n elif first:\n tensor = tl.load(block_ptr, boundary_check=(0, ), padding_option=pad)\n elif second:\n tensor = tl.load(block_ptr, boundary_check=(1, ), padding_option=pad)\n else:\n tensor = tl.load(block_ptr)\n return tensor\n\n@triton.jit\ndef _attn_fwd_inner(\n acc,\n l_i,\n m_i,\n q,\n K_block_ptr,\n V_block_ptr,\n start_m,\n actual_seqlen_k,\n dropout_p,\n philox_seed,\n batch_philox_offset,\n encoded_softmax_block_ptr,\n block_min,\n block_max,\n offs_n_causal,\n masked_blocks,\n n_extra_tokens,\n bias_ptr,\n IS_CAUSAL: tl.constexpr,\n BLOCK_M: tl.constexpr,\n BLOCK_DMODEL: tl.constexpr,\n BLOCK_N: tl.constexpr,\n OFFS_M: tl.constexpr,\n OFFS_N: tl.constexpr,\n PRE_LOAD_V: tl.constexpr,\n MASK_STEPS: tl.constexpr,\n ENABLE_DROPOUT: tl.constexpr,\n RETURN_ENCODED_SOFTMAX: tl.constexpr,\n PADDED_HEAD: tl.constexpr,\n):\n # loop over k, v, and update accumulator\n for start_n in range(block_min, block_max, BLOCK_N):\n # For padded blocks, we will overrun the tensor size if\n # we load all BLOCK_N. For others, the blocks are all within range.\n k = load_fn(\n K_block_ptr,\n PADDED_HEAD,\n MASK_STEPS and (n_extra_tokens != 0),\n \"zero\",\n )\n if PRE_LOAD_V:\n v = load_fn(\n V_block_ptr,\n MASK_STEPS and (n_extra_tokens != 0),\n PADDED_HEAD,\n \"zero\",\n )\n qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)\n if MASK_STEPS:\n if (start_n + BLOCK_N == block_max) and (n_extra_tokens != 0):\n boundary_m = tl.full([BLOCK_M],\n actual_seqlen_k,\n dtype=tl.int32)\n size_n = start_n + OFFS_N[None, :]\n mask = size_n < boundary_m[:, None]\n qk = tl.where(mask, qk, float(\"-inf\"))\n if IS_CAUSAL:\n causal_boundary = start_n + offs_n_causal\n causal_mask = OFFS_M[:, None] >= causal_boundary[None, :]\n qk = tl.where(causal_mask, qk, float(\"-inf\"))\n qk += tl.dot(q, k)\n if bias_ptr is not None:\n bias = load_fn(bias_ptr, False, MASK_STEPS\n and (n_extra_tokens != 0), \"zero\")\n qk += bias * 1.44269504089\n m_ij = tl.maximum(m_i, tl.max(qk, 1))\n qk = qk - m_ij[:, None]\n p = tl.math.exp2(qk)\n l_ij = tl.sum(p, 1)\n if ENABLE_DROPOUT:\n philox_offset = (batch_philox_offset +\n start_m * BLOCK_M * actual_seqlen_k + start_n -\n BLOCK_N)\n keep = dropout_mask(\n philox_seed,\n philox_offset,\n dropout_p,\n BLOCK_M,\n BLOCK_N,\n actual_seqlen_k,\n )\n if RETURN_ENCODED_SOFTMAX:\n tl.store(\n encoded_softmax_block_ptr,\n tl.where(keep, p,\n -p).to(encoded_softmax_block_ptr.type.element_ty),\n )\n p = tl.where(keep, p, 0.0)\n elif RETURN_ENCODED_SOFTMAX:\n tl.store(\n encoded_softmax_block_ptr,\n p.to(encoded_softmax_block_ptr.type.element_ty),\n )\n alpha = tl.math.exp2(m_i - m_ij)\n acc = acc * alpha[:, None]\n if not PRE_LOAD_V:\n v = load_fn(\n V_block_ptr,\n MASK_STEPS and (n_extra_tokens != 0),\n PADDED_HEAD,\n \"zero\",\n )\n l_i = l_i * alpha + l_ij\n m_i = m_ij\n acc += tl.dot(p.to(V_block_ptr.type.element_ty), v)\n V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0))\n K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N))\n if bias_ptr is not None:\n bias_ptr = tl.advance(bias_ptr, (0, BLOCK_N))\n if RETURN_ENCODED_SOFTMAX:\n encoded_softmax_block_ptr = tl.advance(encoded_softmax_block_ptr,\n (0, BLOCK_N))\n return acc, l_i, m_i\n\n@triton.autotune(\n configs=[\n triton.Config(\n {\n \"BLOCK_M\": 256,\n \"BLOCK_N\": 64,\n \"waves_per_eu\": 2,\n \"PRE_LOAD_V\": False,\n },\n num_stages=1,\n num_warps=8,\n ),\n triton.Config(\n {\n \"BLOCK_M\": 128,\n \"BLOCK_N\": 128,\n \"waves_per_eu\": 2,\n \"PRE_LOAD_V\": False,\n },\n num_stages=1,\n num_warps=4,\n ),\n triton.Config(\n {\n \"BLOCK_M\": 256,\n \"BLOCK_N\": 128,\n \"waves_per_eu\": 2,\n \"PRE_LOAD_V\": False,\n },\n num_stages=1,\n num_warps=8,\n ),\n triton.Config(\n {\n \"BLOCK_M\": 128,\n \"BLOCK_N\": 64,\n \"waves_per_eu\": 1,\n \"PRE_LOAD_V\": False,\n },\n num_stages=1,\n num_warps=4,\n ),\n triton.Config(\n {\n \"BLOCK_M\": 128,\n \"BLOCK_N\": 64,\n \"waves_per_eu\": 3,\n \"PRE_LOAD_V\": True,\n },\n num_stages=1,\n num_warps=4,\n ),\n triton.Config(\n {\n \"BLOCK_M\": 128,\n \"BLOCK_N\": 64,\n \"waves_per_eu\": 3,\n \"PRE_LOAD_V\": False,\n },\n num_stages=1,\n num_warps=4,\n ),\n triton.Config(\n {\n \"BLOCK_M\": 64,\n \"BLOCK_N\": 64,\n \"waves_per_eu\": 4,\n \"PRE_LOAD_V\": False,\n },\n num_stages=1,\n num_warps=8,\n ),\n triton.Config(\n {\n \"BLOCK_M\": 32,\n \"BLOCK_N\": 32,\n \"waves_per_eu\": 4,\n \"PRE_LOAD_V\": False,\n },\n num_stages=1,\n num_warps=8,\n ),\n triton.Config(\n {\n \"BLOCK_M\": 16,\n \"BLOCK_N\": 16,\n \"waves_per_eu\": 1,\n \"PRE_LOAD_V\": False,\n },\n num_stages=1,\n num_warps=4,\n ),\n ],\n key=['IS_CAUSAL', 'dropout_p', 'BLOCK_DMODEL'],\n)\n@triton.jit\ndef attn_fwd(\n Q,\n K,\n V,\n bias,\n sm_scale,\n L,\n Out,\n stride_qz,\n stride_qh,\n stride_qm,\n stride_qk,\n stride_kz,\n stride_kh,\n stride_kn,\n stride_kk,\n stride_vz,\n stride_vh,\n stride_vk,\n stride_vn,\n stride_oz,\n stride_oh,\n stride_om,\n stride_on,\n stride_bz,\n stride_bh,\n stride_bm,\n stride_bn,\n cu_seqlens_q,\n cu_seqlens_k,\n dropout_p,\n philox_seed,\n philox_offset_base,\n encoded_softmax,\n HQ: tl.constexpr,\n HK: tl.constexpr,\n ACTUAL_BLOCK_DMODEL: tl.constexpr,\n MAX_SEQLENS_Q: tl.constexpr,\n MAX_SEQLENS_K: tl.constexpr,\n VARLEN: tl.constexpr,\n IS_CAUSAL: tl.constexpr,\n BLOCK_M: tl.constexpr,\n BLOCK_DMODEL: tl.constexpr,\n BLOCK_N: tl.constexpr,\n PRE_LOAD_V: tl.constexpr,\n BIAS_TYPE: tl.constexpr,\n ENABLE_DROPOUT: tl.constexpr,\n RETURN_ENCODED_SOFTMAX: tl.constexpr,\n):\n start_m = tl.program_id(0)\n off_h_q = tl.program_id(1)\n off_z = tl.program_id(2)\n offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)\n offs_n = tl.arange(0, BLOCK_N)\n if VARLEN:\n cu_seqlens_q_start = tl.load(cu_seqlens_q + off_z)\n cu_seqlens_q_end = tl.load(cu_seqlens_q + off_z + 1)\n seqlen_q = cu_seqlens_q_end - cu_seqlens_q_start\n if start_m * BLOCK_M > seqlen_q:\n return\n cu_seqlens_k_start = tl.load(cu_seqlens_k + off_z)\n cu_seqlens_k_end = tl.load(cu_seqlens_k + off_z + 1)\n seqlen_k = cu_seqlens_k_end - cu_seqlens_k_start\n else:\n cu_seqlens_q_start = 0\n cu_seqlens_k_start = 0\n seqlen_q = MAX_SEQLENS_Q\n seqlen_k = MAX_SEQLENS_K\n n_blocks = cdiv_fn(seqlen_k, BLOCK_N)\n if IS_CAUSAL:\n n_blocks_seqlen = cdiv_fn(\n (start_m + 1) * BLOCK_M + seqlen_k - seqlen_q, BLOCK_N)\n n_blocks = min(n_blocks, n_blocks_seqlen)\n if n_blocks <= 0:\n o_offset = (off_z * stride_oz + cu_seqlens_q_start * stride_om +\n off_h_q * stride_oh)\n O_block_ptr = tl.make_block_ptr(\n base=Out + o_offset,\n shape=(seqlen_q, BLOCK_DMODEL),\n strides=(stride_om, stride_on),\n offsets=(start_m * BLOCK_M, 0),\n block_shape=(BLOCK_M, BLOCK_DMODEL),\n order=(1, 0),\n )\n acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=Out.type.element_ty)\n return\n\n GROUP_SIZE: tl.constexpr = HQ // HK\n off_h_k = off_h_q // GROUP_SIZE if GROUP_SIZE != 1 else off_h_q\n\n n_extra_tokens = 0\n if seqlen_k < BLOCK_N:\n n_extra_tokens = BLOCK_N - seqlen_k\n elif seqlen_k % BLOCK_N:\n n_extra_tokens = seqlen_k % BLOCK_N\n padded_head = ACTUAL_BLOCK_DMODEL != BLOCK_DMODEL\n\n q_offset = (off_z * stride_qz + off_h_q * stride_qh +\n cu_seqlens_q_start * stride_qm)\n Q_block_ptr = tl.make_block_ptr(\n base=Q + q_offset,\n shape=(seqlen_q, ACTUAL_BLOCK_DMODEL),\n strides=(stride_qm, stride_qk),\n offsets=(start_m * BLOCK_M, 0),\n block_shape=(BLOCK_M, BLOCK_DMODEL),\n order=(1, 0),\n )\n k_offset = (off_z * stride_kz + off_h_k * stride_kh +\n cu_seqlens_k_start * stride_kn)\n K_block_ptr = tl.make_block_ptr(\n base=K + k_offset,\n shape=(ACTUAL_BLOCK_DMODEL, seqlen_k),\n strides=(stride_kk, stride_kn),\n offsets=(0, 0),\n block_shape=(BLOCK_DMODEL, BLOCK_N),\n order=(0, 1),\n )\n v_offset = (off_z * stride_vz + off_h_k * stride_vh +\n cu_seqlens_k_start * stride_vk)\n V_block_ptr = tl.make_block_ptr(\n base=V + v_offset,\n shape=(seqlen_k, ACTUAL_BLOCK_DMODEL),\n strides=(stride_vk, stride_vn),\n offsets=(0, 0),\n block_shape=(BLOCK_N, BLOCK_DMODEL),\n order=(1, 0),\n )\n if BIAS_TYPE != 0:\n bias_ptr = tl.make_block_ptr(\n base=bias + off_h_q * stride_bh,\n shape=(seqlen_q, seqlen_k),\n strides=(stride_bm, stride_bn),\n offsets=(start_m * BLOCK_M, 0),\n block_shape=(BLOCK_M, BLOCK_N),\n order=(1, 0),\n )\n else:\n bias_ptr = None\n if ENABLE_DROPOUT:\n batch_philox_offset = philox_offset_base \\\n + (off_z * HQ + off_h_q) \\\n * seqlen_q * seqlen_k\n else:\n batch_philox_offset = 0\n if RETURN_ENCODED_SOFTMAX:\n encoded_softmax_block_ptr = tl.make_block_ptr(\n base=encoded_softmax + off_h_q * seqlen_q * seqlen_k,\n shape=(seqlen_q, seqlen_k),\n strides=(seqlen_k, 1),\n offsets=(start_m * BLOCK_M, 0),\n block_shape=(BLOCK_M, BLOCK_N),\n order=(1, 0),\n )\n else:\n encoded_softmax_block_ptr = 0\n m_i = tl.full([BLOCK_M], float(\"-inf\"), dtype=tl.float32)\n l_i = tl.full([BLOCK_M], 1.0, dtype=tl.float32)\n acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)\n qk_scale = sm_scale * 1.44269504089\n q = load_fn(Q_block_ptr, True, padded_head, \"zero\")\n q = (q * qk_scale).to(Q_block_ptr.type.element_ty)\n\n padded_block_k = n_extra_tokens != 0\n is_modulo_mn = not padded_block_k and (seqlen_q % BLOCK_M == 0)\n if IS_CAUSAL:\n masked_blocks = BLOCK_M // BLOCK_N + (not is_modulo_mn)\n else:\n masked_blocks = padded_block_k\n masked_blocks = min(masked_blocks, n_blocks)\n n_full_blocks = n_blocks - masked_blocks\n block_min = 0\n block_max = n_blocks * BLOCK_N\n if n_full_blocks > 0:\n block_max = (n_blocks - masked_blocks) * BLOCK_N\n acc, l_i, m_i = _attn_fwd_inner(\n acc,\n l_i,\n m_i,\n q,\n K_block_ptr,\n V_block_ptr,\n start_m,\n seqlen_k,\n dropout_p,\n philox_seed,\n batch_philox_offset,\n encoded_softmax_block_ptr,\n block_min,\n block_max,\n 0,\n 0,\n 0,\n bias_ptr,\n False,\n BLOCK_M,\n BLOCK_DMODEL,\n BLOCK_N,\n offs_m,\n offs_n,\n PRE_LOAD_V,\n False,\n ENABLE_DROPOUT,\n RETURN_ENCODED_SOFTMAX,\n padded_head,\n )\n block_min = block_max\n block_max = n_blocks * BLOCK_N\n\n tl.debug_barrier()\n if masked_blocks > 0:\n offs_n_causal = offs_n + (seqlen_q - seqlen_k) if IS_CAUSAL else 0\n K_block_ptr = tl.advance(K_block_ptr, (0, n_full_blocks * BLOCK_N))\n V_block_ptr = tl.advance(V_block_ptr, (n_full_blocks * BLOCK_N, 0))\n if bias_ptr is not None:\n bias_ptr = tl.advance(bias_ptr, (0, n_full_blocks * BLOCK_N))\n if RETURN_ENCODED_SOFTMAX:\n encoded_softmax_block_ptr = tl.advance(encoded_softmax_block_ptr,\n (0, n_full_blocks))\n acc, l_i, m_i = _attn_fwd_inner(\n acc,\n l_i,\n m_i,\n q,\n K_block_ptr,\n V_block_ptr,\n start_m,\n seqlen_k,\n dropout_p,\n philox_seed,\n batch_philox_offset,\n encoded_softmax_block_ptr,\n block_min,\n block_max,\n offs_n_causal,\n masked_blocks,\n n_extra_tokens,\n bias_ptr,\n IS_CAUSAL,\n BLOCK_M,\n BLOCK_DMODEL,\n BLOCK_N,\n offs_m,\n offs_n,\n PRE_LOAD_V,\n True,\n ENABLE_DROPOUT,\n RETURN_ENCODED_SOFTMAX,\n padded_head,\n )\n acc = acc / l_i[:, None]\n if ENABLE_DROPOUT:\n acc = acc / (1 - dropout_p)\n end_m_idx = (start_m + 1) * BLOCK_M\n start_m_idx = start_m * BLOCK_M\n causal_start_idx = seqlen_q - seqlen_k\n acc = acc.to(Out.type.element_ty)\n if IS_CAUSAL:\n if causal_start_idx > start_m_idx and causal_start_idx < end_m_idx:\n out_mask_boundary = tl.full((BLOCK_DMODEL, ),\n causal_start_idx,\n dtype=tl.int32)\n mask_m_offsets = start_m_idx + tl.arange(0, BLOCK_M)\n out_ptrs_mask = (mask_m_offsets[:, None] >=\n out_mask_boundary[None, :])\n z = 0.0\n acc = tl.where(out_ptrs_mask, acc, z.to(acc.type.element_ty))\n o_offset = (off_z * stride_oz + cu_seqlens_q_start * stride_om +\n off_h_q * stride_oh)\n O_block_ptr = tl.make_block_ptr(\n base=Out + o_offset,\n shape=(seqlen_q, ACTUAL_BLOCK_DMODEL),\n strides=(stride_om, stride_on),\n offsets=(start_m * BLOCK_M, 0),\n block_shape=(BLOCK_M, BLOCK_DMODEL),\n order=(1, 0),\n )\n tl.store(O_block_ptr, acc, boundary_check=(0, 1))\n\nclass _attention(torch.autograd.Function):\n\n @staticmethod\n def forward(\n ctx,\n q,\n k,\n v,\n o,\n cu_seqlens_q,\n cu_seqlens_k,\n max_seqlens_q,\n max_seqlens_k,\n causal=False,\n sm_scale=1.0,\n bias=None,\n ):\n if o is None:\n o = torch.empty_like(q, dtype=v.dtype)\n\n check_args(\n q,\n k,\n v,\n o,\n varlen=True,\n cu_seqlens_q=cu_seqlens_q,\n cu_seqlens_k=cu_seqlens_k,\n )\n if True:\n total_q, nheads_q, head_size = q.shape\n total_k, nheads_k, _ = k.shape\n batch = len(cu_seqlens_q) - 1\n q_strides = (0, q.stride(1), q.stride(0), q.stride(2))\n k_strides = (0, k.stride(1), k.stride(0), k.stride(2))\n v_strides = (0, v.stride(1), v.stride(0), v.stride(2))\n o_strides = (0, o.stride(1), o.stride(0), o.stride(2))\n else:\n batch, seqlen_q, nheads_q, head_size = q.shape\n _, seqlen_k, nheads_k, _ = k.shape\n q_strides = (q.stride(0), q.stride(2), q.stride(1), q.stride(3))\n k_strides = (k.stride(0), k.stride(2), k.stride(1), k.stride(3))\n v_strides = (v.stride(0), v.stride(2), v.stride(1), v.stride(3))\n o_strides = (o.stride(0), o.stride(2), o.stride(1), o.stride(3))\n\n unpadded_head_dims = {32, 64, 128, 256}\n if head_size not in unpadded_head_dims:\n padded_d_model = None\n for i in unpadded_head_dims:\n if i > head_size:\n padded_d_model = i\n break\n assert padded_d_model is not None\n else:\n padded_d_model = head_size\n\n grid = lambda META: (\n triton.cdiv(max_seqlens_q, META[\"BLOCK_M\"]),\n nheads_q,\n batch,\n )\n\n encoded_softmax = None\n\n philox_seed = 0x1BF52\n philox_offset = 0x1D4B42\n\n if bias is not None:\n bias_strides = (\n bias.stride(0),\n bias.stride(1),\n bias.stride(2),\n bias.stride(3),\n )\n else:\n bias_strides = (0, 0, 0, 0)\n\n attn_fwd[grid](\n q,\n k,\n v,\n bias,\n sm_scale,\n None,\n o,\n *q_strides,\n *k_strides,\n *v_strides,\n *o_strides,\n *bias_strides,\n cu_seqlens_q,\n cu_seqlens_k,\n dropout_p=0.0,\n philox_seed=philox_seed,\n philox_offset_base=philox_offset,\n encoded_softmax=encoded_softmax,\n HQ=nheads_q,\n HK=nheads_k,\n ACTUAL_BLOCK_DMODEL=head_size,\n MAX_SEQLENS_Q=max_seqlens_q,\n MAX_SEQLENS_K=max_seqlens_k,\n IS_CAUSAL=causal,\n VARLEN=True,\n BLOCK_DMODEL=padded_d_model,\n BIAS_TYPE=0 if bias is None else 1,\n ENABLE_DROPOUT=False,\n RETURN_ENCODED_SOFTMAX=False,\n )\n\n ctx.grid = grid\n ctx.sm_scale = sm_scale\n ctx.BLOCK_DMODEL = head_size\n ctx.causal = causal\n ctx.dropout_p = 0.0\n ctx.philox_seed = philox_seed\n ctx.philox_offset = philox_offset\n ctx.encoded_softmax = encoded_softmax\n ctx.return_encoded_softmax = False\n return o, encoded_softmax\n\ntriton_attention = _attention.apply\n", - "description_1": "Use triton language to implement Flash Attention v2 with functions: cdiv_fn for division, load_fn to load blocks, _attn_fwd_inner for inner loop of attention with 26 parameters handling masking, dropout and biases, attn_fwd as the main kernel function with 44 parameters dealing with sequences and blocks, and _attention as a wrapper with 12 parameters to call the kernel.", - "description_2": "Use triton language to create attention mechanisms optimized for memory access and dropout handling in sequence data.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\nfrom vllm.model_executor.layers.ops.sample import _uniform_to_exponential\n\n@triton.jit\ndef _uniform_to_exponential_kernel(input, output, n: tl.constexpr):\n idx = tl.arange(0, n)\n x = tl.load(input + idx)\n y = _uniform_to_exponential(x)\n tl.store(output + idx, y)\n\ndef test_uniform_to_exponential():\n \"\"\"Test that we can convert uniform to exponential without div by 0.\"\"\"\n input = torch.tensor([0.0, 1.0 - torch.finfo(torch.float32).eps],\n dtype=torch.float32,\n device=\"cuda\")\n output = torch.zeros(input.shape, dtype=torch.float32, device=\"cuda\")\n _uniform_to_exponential_kernel[(1, )](input, output, 2)\n assert torch.all(torch.isfinite(output))\n assert torch.all(output > 0)\n assert torch.all(torch.isfinite(torch.full_like(output, 1.0) / output))\n", - "description_1": "Use triton language to implement a kernel that converts uniform random numbers to exponential random numbers. The kernel function '_uniform_to_exponential_kernel' takes three parameters: 'input' (a tensor of uniform random numbers), 'output' (a tensor to store the resulting exponential random numbers), and 'n' (a constant expression representing the number of elements to process). The kernel uses Triton's parallel processing capabilities to load data from the input tensor, apply the '_uniform_to_exponential' transformation, and store the results in the output tensor.", - "description_2": "Use triton language to create a kernel that transforms uniform random numbers into exponential random numbers using parallel processing.", - "difficulty": 2 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef add_kernel(x_ptr, y_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr):\n pid = tl.program_id(axis=0)\n block_start = pid * BLOCK_SIZE\n offsets = block_start + tl.arange(0, BLOCK_SIZE)\n mask = offsets < n_elements\n x = tl.load(x_ptr + offsets, mask=mask)\n y = tl.load(y_ptr + offsets, mask=mask)\n output = x + y\n tl.store(output_ptr + offsets, output, mask=mask)\n\ndef add(x: torch.Tensor, y: torch.Tensor):\n output = torch.empty_like(x)\n assert x.is_cuda and y.is_cuda and output.is_cuda\n n_elements = output.numel()\n grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']), )\n add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024)\n return output\n\ntorch.manual_seed(0)\nsize = 98432\nx = torch.rand(size, device='cuda')\ny = torch.rand(size, device='cuda')\noutput_torch = x + y\noutput_triton = add(x, y)\nprint(output_torch)\nprint(output_triton)\nprint(f'The maximum difference between torch and triton is '\n f'{torch.max(torch.abs(output_torch - output_triton))}')\n", - "description_1": "Use triton language to implement a vector addition kernel function that takes five arguments: pointers to the first and second input vectors, pointer to the output vector, the number of elements in the vector, and a block size. The kernel computes element-wise sum of two input vectors. The function `add` serves as a helper to allocate output tensor and launch the kernel with an appropriate grid size.", - "description_2": "Use triton language to implement a vector addition kernel and a helper function to launch the kernel on CUDA tensors.", - "difficulty": 2 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef _bgmv_shrink_kernel(\n input_ptr,\n lora_ptr,\n out_ptr,\n N,\n K,\n lora_indices,\n scaling,\n xm_stride,\n xk_stride,\n l0_stride,\n lora_k_stride,\n lora_n_stride,\n cm_stride,\n cn_stride,\n BLOCK_N: tl.constexpr,\n BLOCK_K: tl.constexpr,\n SPLIT_K: tl.constexpr,\n):\n \"\"\"\n GroupGEMV, additionally, introducing SPLIT-K can improve large hidden_size's\n performance\n \"\"\"\n pid_sk = tl.program_id(axis=0)\n cur_batch = tl.program_id(axis=1)\n lora_index = tl.load(lora_indices + cur_batch)\n if lora_index == -1:\n return\n\n offset_n = tl.arange(0, BLOCK_N)\n offset_k = tl.arange(0, BLOCK_K) + pid_sk * BLOCK_K\n a_ptr = input_ptr + cur_batch * xm_stride\n b_ptr = lora_ptr + l0_stride * lora_index\n accumulator = tl.zeros((BLOCK_N, ), dtype=tl.float32)\n for k in range(0, K, BLOCK_K * SPLIT_K):\n current_k = k + offset_k\n current_k_c = tl.max_contiguous(current_k, BLOCK_K)\n tiled_a = tl.load(\n a_ptr + current_k_c,\n mask=current_k < K,\n other=0.0,\n ) # [BLOCK_K]\n b_ptr_mask = (offset_n[:, None] < N) & (current_k[None, :] < K)\n\n tiled_b = tl.load(\n b_ptr + offset_n[:, None] * lora_k_stride +\n current_k[None, :] * lora_n_stride,\n mask=b_ptr_mask,\n other=0.0,\n ) # [BLOCK_N,BLOCK_K]\n\n accumulator += tl.sum(tiled_a * tiled_b, 1)\n accumulator *= scaling\n offset_cn = tl.arange(0, BLOCK_N)\n c_ptr = out_ptr + cur_batch * cm_stride + offset_cn * cn_stride\n c_mask = offset_cn < N\n if SPLIT_K == 1:\n tl.store(c_ptr, accumulator, mask=c_mask)\n else:\n tl.atomic_add(c_ptr, accumulator, mask=c_mask)\n\n\n@torch.inference_mode()\ndef _bgmv_shrink(\n inputs: torch.Tensor,\n lora_a_weights: torch.Tensor,\n output_tensor: torch.Tensor,\n lora_indices_tensor: torch.Tensor,\n scaling: float = 1.0,\n) -> None:\n \"\"\"\n Args:\n inputs (torch.Tensor): input tensor\n lora_a_weights (torch.Tensor): lora'a weight\n output_tensor (torch.Tensor): output tensor\n lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index\n corresponding to each batch. An index of -1 means no lora should be\n applied.\n batches (int): batch size\n scaling (float): Scaling factor.\n \"\"\"\n assert inputs.dtype == lora_a_weights.dtype\n assert inputs.dtype in [torch.float16, torch.bfloat16]\n assert lora_a_weights.dtype in [\n torch.float16,\n torch.bfloat16,\n ]\n assert inputs.size(1) == lora_a_weights.size(-1)\n assert inputs.is_contiguous()\n\n if lora_a_weights.ndim == 4: # shape:(lora_num,1,rank, size)\n assert lora_a_weights.size(1) == 1\n lora_a_weights = lora_a_weights.squeeze(dim=1)\n else:\n assert lora_a_weights.ndim == 3 # shape:(lora_num,rank, size)\n assert lora_a_weights.is_contiguous()\n assert output_tensor.is_contiguous()\n # TODO tuning this config\n batches = lora_indices_tensor.size(0)\n N, K = lora_a_weights.shape[-2:] # K=hidden_size,N=rank\n BLOCK_N = triton.next_power_of_2(N)\n # First try to load optimal config from the file\n config = get_lora_op_configs(\"bgmv_shrink\", batches, K)\n\n grid = lambda META: (\n META[\"SPLIT_K\"],\n batches,\n )\n _bgmv_shrink_kernel[grid](\n inputs,\n lora_a_weights,\n output_tensor,\n N,\n K,\n lora_indices_tensor,\n scaling,\n inputs.stride(0),\n inputs.stride(1),\n lora_a_weights.stride(0),\n lora_a_weights.stride(1),\n lora_a_weights.stride(2),\n output_tensor.stride(0),\n output_tensor.stride(1),\n BLOCK_N=BLOCK_N,\n **config,\n )\n return\n", - "description_1": "Use triton language to implement a kernel that performs a grouped GEMV operation with additional support for SPLIT-K to optimize performance for large hidden sizes. The kernel takes 14 arguments, including pointers to input, LoRA weights, and output tensors, dimensions N and K, LoRA indices, scaling factor, various stride values, and compile-time constants BLOCK_N, BLOCK_K, and SPLIT_K. The kernel loads input data and LoRA weights, computes the GEMV using a for-loop over the K dimension, scales the result, and stores it in the output. A separate torch function '_bgmv_shrink' prepares the data and calls this kernel with the configured grid.", - "description_2": "Use triton language to implement and execute a kernel performing a specialized GEMV operation with LoRA weights using SPLIT-K for large hidden sizes, managing data with pointers and specific strides.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef _sgmv_expand_kernel(\n input_ptr,\n lora_ptr,\n out_ptr,\n N,\n K,\n b_seq_start_loc,\n seq_lens,\n lora_indices,\n xm_stride,\n xk_stride,\n l0_stride,\n lora_k_stride,\n lora_n_stride,\n cm_stride,\n cn_stride,\n BLOCK_M: tl.constexpr,\n BLOCK_N: tl.constexpr,\n BLOCK_K: tl.constexpr,\n EVEN_K: tl.constexpr,\n ADD_INPUTS: tl.constexpr,\n CAST_TYPE: tl.constexpr,\n):\n \"\"\"\n The sgmv's expand triton kernel is based on GroupGEMM.\n \"\"\"\n pid = tl.program_id(axis=0)\n cur_batch = tl.program_id(axis=1)\n cta_n_num = tl.cdiv(N, BLOCK_N)\n pid_m = pid // cta_n_num\n pid_n = pid % cta_n_num\n M = tl.load(seq_lens + cur_batch)\n if pid_m * BLOCK_M > M:\n return\n lora_index = tl.load(lora_indices + cur_batch)\n if lora_index == -1:\n return\n cur_seq_start = tl.load(b_seq_start_loc + cur_batch)\n offset_m = tl.arange(0, BLOCK_M) + pid_m * BLOCK_M\n offset_n = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N\n offset_k = tl.arange(0, BLOCK_K)\n ram = tl.max_contiguous(tl.multiple_of(offset_m % M, BLOCK_M), BLOCK_M)\n rbn = tl.max_contiguous(tl.multiple_of(offset_n % N, BLOCK_N), BLOCK_N)\n\n a_ptr = (input_ptr + cur_seq_start * xm_stride + ram[:, None] * xm_stride +\n offset_k[None, :] * xk_stride, )\n b_ptr = (lora_ptr + l0_stride * lora_index +\n offset_k[:, None] * lora_n_stride + rbn[None, :] * lora_k_stride)\n accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)\n for k in range(tl.cdiv(K, BLOCK_K)):\n if EVEN_K:\n tiled_a = tl.load(a_ptr)\n tiled_b = tl.load(b_ptr)\n else:\n tiled_a = tl.load(a_ptr,\n mask=offset_k[None, :] < K - k * BLOCK_K,\n other=0)\n tiled_b = tl.load(b_ptr,\n mask=offset_k[:, None] < K - k * BLOCK_K,\n other=0)\n if CAST_TYPE:\n tiled_a = tiled_a.to(lora_ptr.dtype.element_ty)\n accumulator += tl.dot(\n tiled_a,\n tiled_b,\n )\n a_ptr += BLOCK_K * xk_stride\n b_ptr += BLOCK_K * lora_n_stride\n tiled_c = accumulator.to(lora_ptr.dtype.element_ty)\n offset_cm = cur_seq_start + tl.arange(0, BLOCK_M) + pid_m * BLOCK_M\n offset_cn = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N\n c_ptr = (out_ptr + offset_cm[:, None] * cm_stride +\n offset_cn[None, :] * cn_stride)\n M = tl.load(seq_lens + cur_batch)\n c_mask = (offset_cm[:, None] <\n (cur_seq_start + M)) & (offset_cn[None, :] < N)\n if ADD_INPUTS:\n tiled_out = tl.load(c_ptr, mask=c_mask)\n tiled_c += tiled_out\n tl.store(c_ptr, tiled_c, mask=c_mask)\n\n\n@torch.inference_mode()\ndef _sgmv_expand(\n inputs: torch.Tensor,\n lora_b_weights: torch.Tensor,\n output_tensor: torch.Tensor,\n b_seq_start_loc: torch.Tensor,\n seq_len_tensor: torch.Tensor,\n lora_indices_tensor: torch.Tensor,\n batches: int,\n max_seq_length: int,\n add_inputs: bool = False,\n) -> None:\n \"\"\"\n Args:\n inputs (torch.Tensor): input tensor\n lora_b_weights (torch.Tensor): lora'a weight\n output_tensor (torch.Tensor): output tensor\n b_seq_start_loc (torch.Tensor): (batch_size,). The cumulative\n sequence lengths of the sequences in the batch, used to index\n into sequence. E.g.,if the sequence length is [4, 6], it is\n [0, 4, 10].\n seq_len_tensor (torch.Tensor): (batch_size,). record the sequence\n length of the sequences in the batch\n lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index\n corresponding to each batch. An index of -1 means no lora should be\n applied.\n batches (int): batch size\n max_seq_length (int): The max sequence lengths of the sequences\n in the batch\n add_inputs (bool, optional): Defaults to False. adds the final lora \n results to the output.\n \"\"\"\n\n assert inputs.dtype in [torch.float16, torch.bfloat16, torch.float32]\n assert lora_b_weights.dtype in [\n torch.float16,\n torch.bfloat16,\n ]\n assert inputs.size(1) == lora_b_weights.size(-1)\n assert b_seq_start_loc.size(0) == batches\n assert lora_indices_tensor.size(0) == batches\n assert inputs.is_contiguous()\n assert output_tensor.is_contiguous()\n\n if lora_b_weights.ndim == 4: # shape:(lora_num,1,size,rank)\n assert lora_b_weights.size(1) == 1\n lora_b_weights = lora_b_weights.squeeze(dim=1)\n else:\n assert lora_b_weights.ndim == 3 # shape:(lora_num,size,rank)\n\n assert lora_b_weights.is_contiguous()\n\n # TODO tuning this config\n\n N, K = lora_b_weights.shape[-2:] # K= rank,N=hidden_size\n BLOCK_M = 32\n BLOCK_N = 32\n BLOCK_K = 16\n EVEN_K = K % BLOCK_K == 0\n ADD_INPUTS = add_inputs\n CAST_TYPE = False\n if inputs.dtype == torch.float32 and lora_b_weights.dtype in [\n torch.float16,\n torch.bfloat16,\n ]:\n CAST_TYPE = True\n grid = (\n triton.cdiv(max_seq_length, BLOCK_M) * triton.cdiv(N, BLOCK_N),\n batches,\n )\n _sgmv_expand_kernel[grid](\n inputs,\n lora_b_weights,\n output_tensor,\n N,\n K,\n b_seq_start_loc,\n seq_len_tensor,\n lora_indices_tensor,\n inputs.stride(0),\n inputs.stride(1),\n lora_b_weights.stride(0),\n lora_b_weights.stride(1),\n lora_b_weights.stride(2),\n output_tensor.stride(0),\n output_tensor.stride(1),\n BLOCK_M,\n BLOCK_N,\n BLOCK_K,\n EVEN_K,\n ADD_INPUTS,\n CAST_TYPE,\n )\n return\n", - "description_1": "Use triton language to implement a kernel function '_sgmv_expand_kernel' that performs a specialized matrix-vector multiplication with support for LoRA (Low-Rank Adaptation) weights. The kernel takes 22 parameters: input_ptr, lora_ptr, out_ptr, N, K, b_seq_start_loc, seq_lens, lora_indices, xm_stride, xk_stride, l0_stride, lora_k_stride, lora_n_stride, cm_stride, cn_stride, and several constexpr parameters for block sizes and flags. The function '_sgmv_expand' is a wrapper that prepares the inputs and launches the kernel with a grid configuration based on the batch size and sequence length.", - "description_2": "Use triton language to create a kernel for matrix-vector multiplication with LoRA weights, and a wrapper function to set up and execute the kernel.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef _sgmv_expand_slice_kernel(\n input_ptr,\n lora_ptr,\n out_ptr,\n N,\n K,\n b_seq_start_loc,\n seq_lens,\n lora_indices,\n xm_stride,\n xk_stride,\n l0_stride,\n lora_k_stride,\n lora_n_stride,\n cm_stride,\n cn_stride,\n slice_offset,\n BLOCK_M: tl.constexpr,\n BLOCK_N: tl.constexpr,\n BLOCK_K: tl.constexpr,\n EVEN_K: tl.constexpr,\n ADD_INPUTS: tl.constexpr,\n CAST_TYPE: tl.constexpr,\n):\n \"\"\"\n Similar to the 'sgmv_expand' operator, but with an added parameter \n 'slice_offset'. The reason for not reusing the 'sgmv_expand' operator \n might be that in the future, we could implement a fusion operator to \n achieve the current functionality instead of having to call it multiple \n times.\n \"\"\"\n pid = tl.program_id(axis=0)\n cur_batch = tl.program_id(axis=1)\n cta_n_num = tl.cdiv(N, BLOCK_N)\n pid_m = pid // cta_n_num\n pid_n = pid % cta_n_num\n M = tl.load(seq_lens + cur_batch)\n if pid_m * BLOCK_M > M:\n return\n lora_index = tl.load(lora_indices + cur_batch)\n if lora_index == -1:\n return\n cur_seq_start = tl.load(b_seq_start_loc + cur_batch)\n offset_m = tl.arange(0, BLOCK_M) + pid_m * BLOCK_M\n offset_n = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N\n offset_k = tl.arange(0, BLOCK_K)\n ram = tl.max_contiguous(tl.multiple_of(offset_m % M, BLOCK_M), BLOCK_M)\n rbn = tl.max_contiguous(tl.multiple_of(offset_n % N, BLOCK_N), BLOCK_N)\n\n a_ptr = (input_ptr + cur_seq_start * xm_stride + ram[:, None] * xm_stride +\n offset_k[None, :] * xk_stride, )\n b_ptr = (lora_ptr + l0_stride * lora_index +\n offset_k[:, None] * lora_n_stride + rbn[None, :] * lora_k_stride)\n accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)\n for k in range(tl.cdiv(K, BLOCK_K)):\n if EVEN_K:\n tiled_a = tl.load(a_ptr)\n tiled_b = tl.load(b_ptr)\n else:\n tiled_a = tl.load(a_ptr,\n mask=offset_k[None, :] < K - k * BLOCK_K,\n other=0)\n tiled_b = tl.load(b_ptr,\n mask=offset_k[:, None] < K - k * BLOCK_K,\n other=0)\n if CAST_TYPE:\n tiled_a = tiled_a.to(lora_ptr.dtype.element_ty)\n accumulator += tl.dot(\n tiled_a,\n tiled_b,\n )\n a_ptr += BLOCK_K * xk_stride\n b_ptr += BLOCK_K * lora_n_stride\n tiled_c = accumulator.to(lora_ptr.dtype.element_ty)\n offset_cm = cur_seq_start + tl.arange(0, BLOCK_M) + pid_m * BLOCK_M\n offset_cn = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N + slice_offset\n c_ptr = (out_ptr + offset_cm[:, None] * cm_stride +\n offset_cn[None, :] * cn_stride)\n M = tl.load(seq_lens + cur_batch)\n c_mask = (offset_cm[:, None] < (cur_seq_start + M)) & (offset_cn[None, :] <\n (slice_offset + N))\n if ADD_INPUTS:\n tiled_out = tl.load(c_ptr, mask=c_mask)\n tiled_c += tiled_out\n tl.store(c_ptr, tiled_c, mask=c_mask)\n\n\n@torch.inference_mode()\ndef _sgmv_expand_slice(\n inputs: torch.Tensor,\n lora_b_weights: torch.Tensor,\n output_tensor: torch.Tensor,\n b_seq_start_loc: torch.Tensor,\n seq_len_tensor: torch.Tensor,\n lora_indices_tensor: torch.Tensor,\n batches: int,\n max_seq_length: int,\n slice_offset: int,\n slice_size: int,\n add_inputs: bool = False,\n) -> None:\n \"\"\"_summary_\n\n Args:\n inputs (torch.Tensor): input tensor\n lora_b_weights (torch.Tensor): lora'a weight\n output_tensor (torch.Tensor): output tensor\n b_seq_start_loc (torch.Tensor): (batch_size,). The cumulative\n sequence lengths of the sequences in the batch, used to index\n into sequence. E.g.,if the sequence length is [4, 6], it is\n [0, 4, 10].\n seq_len_tensor (torch.Tensor): (batch_size,). record the sequence\n length of the sequences in the batch\n lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index\n corresponding to each batch. An index of -1 means no lora should be\n applied.\n batches (int): batch size\n max_seq_length (int): The max sequence lengths of the sequences\n in the batch\n slice_offst (int): output_tensor's offst\n slice_size (int): current output_tensor's size\n add_inputs (bool, optional): Defaults to False. adds the final lora \n results to the output..\n \"\"\"\n\n assert inputs.dtype in [torch.float16, torch.bfloat16, torch.float32]\n assert lora_b_weights.dtype in [\n torch.float16,\n torch.bfloat16,\n ]\n assert inputs.size(1) == lora_b_weights.size(-1)\n assert b_seq_start_loc.size(0) == batches\n assert lora_indices_tensor.size(0) == batches\n assert slice_size == lora_b_weights.size(-2)\n assert inputs.is_contiguous()\n assert output_tensor.is_contiguous()\n\n if lora_b_weights.ndim == 4: # shape:(lora_num,1,size,rank)\n assert lora_b_weights.size(1) == 1\n lora_b_weights = lora_b_weights.squeeze(dim=1)\n else:\n assert lora_b_weights.ndim == 3 # shape:(lora_num,size,rank)\n\n assert lora_b_weights.is_contiguous()\n\n # TODO tuning this config\n N, K = lora_b_weights.shape[-2:] # K= rank,N=hidden_size\n\n BLOCK_M = 32\n BLOCK_N = 32\n BLOCK_K = 16\n EVEN_K = K % BLOCK_K == 0\n ADD_INPUTS = add_inputs\n CAST_TYPE = False\n if inputs.dtype == torch.float32 and lora_b_weights.dtype in [\n torch.float16,\n torch.bfloat16,\n ]:\n CAST_TYPE = True\n grid = (\n triton.cdiv(max_seq_length, BLOCK_M) * triton.cdiv(N, BLOCK_N),\n batches,\n )\n _sgmv_expand_slice_kernel[grid](\n inputs,\n lora_b_weights,\n output_tensor,\n N,\n K,\n b_seq_start_loc,\n seq_len_tensor,\n lora_indices_tensor,\n inputs.stride(0),\n inputs.stride(1),\n lora_b_weights.stride(0),\n lora_b_weights.stride(1),\n lora_b_weights.stride(2),\n output_tensor.stride(0),\n output_tensor.stride(1),\n slice_offset,\n BLOCK_M,\n BLOCK_N,\n BLOCK_K,\n EVEN_K,\n ADD_INPUTS,\n CAST_TYPE,\n )\n return\n", - "description_1": "Use triton language to implement a kernel function '_sgmv_expand_slice_kernel' with 23 parameters for matrix operations with LoRA weights, and a wrapper function '_sgmv_expand_slice' with 11 parameters to prepare and launch the kernel.", - "description_2": "Use triton language to create a kernel for matrix operations with LoRA weights and a wrapper to manage inputs and launch the kernel.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef _sgmv_shrink_kernel(\n input_ptr,\n lora_ptr,\n out_ptr,\n N,\n K,\n b_seq_start_loc,\n seq_lens,\n lora_indices,\n scaling,\n xm_stride,\n xk_stride,\n l0_stride,\n lora_k_stride,\n lora_n_stride,\n cm_stride,\n cn_stride,\n BLOCK_M: tl.constexpr,\n BLOCK_N: tl.constexpr,\n BLOCK_K: tl.constexpr,\n EVEN_K: tl.constexpr,\n SPLIT_K: tl.constexpr,\n):\n \"\"\"\n The sgmv's shrink triton kernel is based on GroupGEMM+SPLIT-K.\n The GEMM of Multi-LoRA can be considered as GroupGEMM. Additionally,\n introducing SPLIT-K can improve performance\n \"\"\"\n pid = tl.program_id(axis=0)\n pid_sk = tl.program_id(axis=1)\n cur_batch = tl.program_id(axis=2)\n cta_n_num = tl.cdiv(N, BLOCK_N)\n pid_m = pid // cta_n_num\n pid_n = pid % cta_n_num\n\n M = tl.load(seq_lens + cur_batch)\n if pid_m * BLOCK_M > M:\n return\n lora_index = tl.load(lora_indices + cur_batch)\n if lora_index == -1:\n return\n cur_seq_start = tl.load(b_seq_start_loc + cur_batch)\n offset_m = tl.arange(0, BLOCK_M) + pid_m * BLOCK_M\n offset_n = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N\n offset_k = pid_sk * BLOCK_K + tl.arange(0, BLOCK_K)\n\n ram = tl.max_contiguous(tl.multiple_of(offset_m % M, BLOCK_M), BLOCK_M)\n rbn = tl.max_contiguous(tl.multiple_of(offset_n % N, BLOCK_N), BLOCK_N)\n\n a_ptr = (input_ptr + cur_seq_start * xm_stride + ram[:, None] * xm_stride +\n offset_k[None, :] * xk_stride)\n b_ptr = (lora_ptr + l0_stride * lora_index + rbn[None, :] * lora_k_stride +\n offset_k[:, None] * lora_n_stride)\n\n accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)\n for k in range(0, tl.cdiv(K, BLOCK_K * SPLIT_K)):\n if EVEN_K:\n tiled_a = tl.load(a_ptr)\n tiled_b = tl.load(b_ptr)\n else:\n k_remaining = K - k * (BLOCK_K * SPLIT_K)\n tiled_a = tl.load(a_ptr,\n mask=offset_k[None, :] < k_remaining,\n other=0.0)\n tiled_b = tl.load(b_ptr,\n mask=offset_k[:, None] < k_remaining,\n other=0.0)\n accumulator += tl.dot(tiled_a, tiled_b)\n\n a_ptr += BLOCK_K * SPLIT_K * xk_stride\n b_ptr += BLOCK_K * SPLIT_K * lora_n_stride\n offset_cm = cur_seq_start + tl.arange(0, BLOCK_M) + pid_m * BLOCK_M\n\n offset_cn = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N\n c_ptr = (out_ptr + offset_cm[:, None] * cm_stride +\n offset_cn[None, :] * cn_stride)\n c_mask = (offset_cm[:, None] <\n (cur_seq_start + M)) & (offset_cn[None, :] < N)\n accumulator *= scaling\n # handles write-back with reduction-splitting\n if SPLIT_K == 1:\n tl.store(c_ptr, accumulator, mask=c_mask)\n else:\n tl.atomic_add(c_ptr, accumulator, mask=c_mask)\n\n\n@torch.inference_mode()\ndef _sgmv_shrink(\n inputs: torch.Tensor,\n lora_a_weights: torch.Tensor,\n output_tensor: torch.Tensor,\n b_seq_start_loc: torch.Tensor,\n seq_len_tensor: torch.Tensor,\n lora_indices_tensor: torch.Tensor,\n batches: int,\n max_seq_length: int,\n scaling: float,\n) -> None:\n \"\"\"\n Args:\n inputs (torch.Tensor): input tensor\n lora_a_weights (torch.Tensor): lora'a weight\n output_tensor (torch.Tensor): output tensor\n b_seq_start_loc (torch.Tensor): (batch_size,). The cumulative\n sequence lengths of the sequences in the batch, used to index\n into sequence. E.g.,if the sequence length is [4, 6], it is\n [0, 4].\n seq_len_tensor (torch.Tensor): (batch_size,). record the sequence\n length of the sequences in the batch\n lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index\n corresponding to each batch. An index of -1 means no lora should be\n applied.\n batches (int): batch size\n max_seq_length (int): The max sequence lengths of the sequences\n in the batch\n scaling (float): Scaling factor.\n \"\"\"\n assert inputs.dtype == lora_a_weights.dtype\n assert inputs.dtype in [torch.float16, torch.bfloat16]\n assert lora_a_weights.dtype in [\n torch.float16,\n torch.bfloat16,\n ]\n assert inputs.size(1) == lora_a_weights.size(-1)\n assert b_seq_start_loc.size(0) == batches\n assert lora_indices_tensor.size(0) == batches\n assert inputs.is_contiguous()\n\n if lora_a_weights.ndim == 4: # shape:(lora_num,1,rank, size)\n assert lora_a_weights.size(1) == 1\n lora_a_weights = lora_a_weights.squeeze(dim=1)\n else:\n assert lora_a_weights.ndim == 3 # shape:(lora_num,rank, size)\n assert lora_a_weights.is_contiguous()\n assert output_tensor.is_contiguous()\n # TODO tuning this config\n N, K = lora_a_weights.shape[-2:] # K=hidden_size,N=rank\n BLOCK_M = 32\n BLOCK_N = 16\n BLOCK_K = 32\n SPLIT_K = 8\n EVEN_K = K % (BLOCK_K * SPLIT_K) == 0\n grid = (\n triton.cdiv(max_seq_length, BLOCK_M) * triton.cdiv(N, BLOCK_N),\n SPLIT_K,\n batches,\n )\n\n _sgmv_shrink_kernel[grid](\n inputs,\n lora_a_weights,\n output_tensor,\n N,\n K,\n b_seq_start_loc,\n seq_len_tensor,\n lora_indices_tensor,\n scaling,\n inputs.stride(0),\n inputs.stride(1),\n lora_a_weights.stride(0),\n lora_a_weights.stride(1),\n lora_a_weights.stride(2),\n output_tensor.stride(0),\n output_tensor.stride(1),\n BLOCK_M,\n BLOCK_N,\n BLOCK_K,\n EVEN_K,\n SPLIT_K,\n )\n return\n", - "description_1": "Use triton language to implement a kernel function '_sgmv_shrink_kernel' with 22 parameters for matrix operations with LoRA weights, and a wrapper function '_sgmv_shrink' with 9 parameters to prepare and invoke the kernel.", - "description_2": "Use triton language to create a kernel for matrix operations with LoRA weights and a wrapper to manage inputs and execution.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\nfrom typing import Optional, Dict, Any, Tuple, Callable\n\n@triton.jit\ndef fused_moe_kernel(\n a_ptr, b_ptr, c_ptr, a_scale_ptr, b_scale_ptr, topk_weights_ptr,\n sorted_token_ids_ptr, expert_ids_ptr, num_tokens_post_padded_ptr,\n N, K, EM, num_valid_tokens, stride_am, stride_ak, stride_be, stride_bk,\n stride_bn, stride_cm, stride_cn, stride_bse, stride_bsn,\n BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr,\n BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr,\n MUL_ROUTED_WEIGHT: tl.constexpr, top_k: tl.constexpr,\n compute_type: tl.constexpr, use_fp8_w8a8: tl.constexpr,\n use_int8_w8a16: tl.constexpr):\n \"\"\"\n Implements the fused computation for a Mixture of Experts (MOE) using\n token and expert matrices.\n \"\"\"\n pid = tl.program_id(axis=0)\n num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M)\n num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)\n num_pid_in_group = GROUP_SIZE_M * num_pid_n\n group_id = pid // num_pid_in_group\n first_pid_m = group_id * GROUP_SIZE_M\n group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)\n pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)\n pid_n = (pid % num_pid_in_group) // group_size_m\n\n num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr)\n if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded:\n return\n offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_token = tl.load(sorted_token_ids_ptr + offs_token_id)\n token_mask = offs_token < num_valid_tokens\n\n offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N\n offs_k = tl.arange(0, BLOCK_SIZE_K)\n a_ptrs = a_ptr + (offs_token[:, None] // top_k * stride_am +\n offs_k[None, :] * stride_ak)\n\n off_experts = tl.load(expert_ids_ptr + pid_m)\n b_ptrs = b_ptr + off_experts * stride_be + (offs_k[:, None] * stride_bk +\n offs_bn[None, :] * stride_bn)\n if use_int8_w8a16:\n b_scale_ptrs = b_scale_ptr + off_experts * stride_bse + offs_bn[\n None, :] * stride_bsn\n b_scale = tl.load(b_scale_ptrs)\n\n if use_fp8_w8a8:\n a_scale = tl.load(a_scale_ptr)\n b_scale = tl.load(b_scale_ptr + off_experts)\n\n accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n\n for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):\n a = tl.load(a_ptrs,\n mask=token_mask[:, None] &\n (offs_k[None, :] < K - k * BLOCK_SIZE_K),\n other=0.0)\n b = tl.load(b_ptrs,\n mask=offs_k[:, None] < K - k * BLOCK_SIZE_K,\n other=0.0)\n if use_int8_w8a16:\n accumulator = tl.dot(a, b.to(compute_type), acc=accumulator)\n elif use_fp8_w8a8:\n accumulator = tl.dot(a, b, acc=accumulator)\n else:\n accumulator += tl.dot(a, b)\n a_ptrs += BLOCK_SIZE_K * stride_ak\n b_ptrs += BLOCK_SIZE_K * stride_bk\n\n if MUL_ROUTED_WEIGHT:\n moe_weight = tl.load(topk_weights_ptr + offs_token,\n mask=token_mask,\n other=0)\n accumulator = accumulator * moe_weight[:, None]\n if use_int8_w8a16:\n accumulator = (accumulator * b_scale).to(compute_type)\n elif use_fp8_w8a8:\n accumulator = (accumulator * a_scale * b_scale).to(compute_type)\n else:\n accumulator = accumulator.to(compute_type)\n\n offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[\n None, :]\n c_mask = token_mask[:, None] & (offs_cn[None, :] < N)\n tl.store(c_ptrs, accumulator, mask=c_mask)\n\n\ndef moe_align_block_size(\n topk_ids: torch.Tensor, block_size: int,\n num_experts: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:\n \"\"\"\n Aligns the token distribution across experts to be compatible with block\n size for matrix multiplication.\n \"\"\"\n max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)\n sorted_ids = torch.empty((max_num_tokens_padded, ),\n dtype=torch.int32,\n device=topk_ids.device)\n sorted_ids.fill_(topk_ids.numel())\n max_num_m_blocks = triton.cdiv(max_num_tokens_padded, block_size)\n expert_ids = torch.empty((max_num_m_blocks, ),\n dtype=torch.int32,\n device=topk_ids.device)\n num_tokens_post_pad = torch.empty((1),\n dtype=torch.int32,\n device=topk_ids.device)\n ops.moe_align_block_size(topk_ids, num_experts, block_size, sorted_ids,\n expert_ids, num_tokens_post_pad)\n return sorted_ids, expert_ids, num_tokens_post_pad\n\n\ndef invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,\n A_scale: Optional[torch.Tensor],\n B_scale: Optional[torch.Tensor],\n topk_weights: torch.Tensor, topk_ids: torch.Tensor,\n sorted_token_ids: torch.Tensor,\n expert_ids: torch.Tensor,\n num_tokens_post_padded: torch.Tensor,\n mul_routed_weight: bool, top_k: int,\n config: Dict[str, Any], compute_type: tl.dtype,\n use_fp8_w8a8: bool, use_int8_w8a16: bool) -> None:\n assert topk_weights.stride(1) == 1\n assert sorted_token_ids.stride(0) == 1\n\n if use_fp8_w8a8:\n A, A_scale = ops.scaled_fp8_quant(A, A_scale)\n assert B_scale is not None\n elif use_int8_w8a16:\n assert B_scale is not None\n else:\n assert A_scale is None\n assert B_scale is None\n\n grid = lambda META: (triton.cdiv(sorted_token_ids.shape[0], META[\n 'BLOCK_SIZE_M']) * triton.cdiv(B.shape[1], META['BLOCK_SIZE_N']), )\n\n fused_moe_kernel[grid](\n A,\n B,\n C,\n A_scale,\n B_scale,\n topk_weights,\n sorted_token_ids,\n expert_ids,\n num_tokens_post_padded,\n B.shape[1],\n B.shape[2],\n sorted_token_ids.shape[0],\n topk_ids.numel(),\n A.stride(0),\n A.stride(1),\n B.stride(0),\n B.stride(2),\n B.stride(1),\n C.stride(1),\n C.stride(2),\n B_scale.stride(0) if B_scale is not None and use_int8_w8a16 else 0,\n B_scale.stride(1) if B_scale is not None and use_int8_w8a16 else 0,\n MUL_ROUTED_WEIGHT=mul_routed_weight,\n top_k=top_k,\n compute_type=compute_type,\n use_fp8_w8a8=use_fp8_w8a8,\n use_int8_w8a16=use_int8_w8a16,\n **config,\n )\n", - "description_1": "Use triton language to implement a fused Mixture of Experts (MoE) kernel. The kernel performs matrix multiplication for tokens and expert matrices, considering various parameters like block sizes, strides, and compute types. The kernel is invoked with a function that aligns token distribution and sets up the necessary configurations.", - "description_2": "Use triton language to create a fused MoE kernel for efficient matrix multiplication and invoke it with aligned token distribution and configuration settings.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\nfrom packaging import version\n\nTRITON3 = version.parse(triton.__version__) >= version.parse(\"3.0.0\")\n\nif TRITON3:\n\n @triton.jit\n def softplus(dt):\n dt = tl.where(dt <= 20.0, tl.math.log(tl.math.exp(dt) + 1), dt)\n return dt\nelse:\n\n @triton.jit\n def softplus(dt):\n dt = tl.where(dt <= 20.0, tl.math.log1p(tl.exp(dt)), dt)\n return dt\n\n\n@triton.jit\ndef _selective_scan_update_kernel(\n state_ptr,\n x_ptr,\n dt_ptr,\n dt_bias_ptr,\n A_ptr,\n B_ptr,\n C_ptr,\n D_ptr,\n z_ptr,\n out_ptr,\n batch,\n nheads,\n dim,\n dstate,\n nheads_ngroups_ratio,\n stride_state_batch,\n stride_state_head,\n stride_state_dim,\n stride_state_dstate,\n stride_x_batch,\n stride_x_head,\n stride_x_dim,\n stride_dt_batch,\n stride_dt_head,\n stride_dt_dim,\n stride_dt_bias_head,\n stride_dt_bias_dim,\n stride_A_head,\n stride_A_dim,\n stride_A_dstate,\n stride_B_batch,\n stride_B_group,\n stride_B_dstate,\n stride_C_batch,\n stride_C_group,\n stride_C_dstate,\n stride_D_head,\n stride_D_dim,\n stride_z_batch,\n stride_z_head,\n stride_z_dim,\n stride_out_batch,\n stride_out_head,\n stride_out_dim,\n DT_SOFTPLUS: tl.constexpr,\n TIE_HDIM: tl.constexpr,\n BLOCK_SIZE_M: tl.constexpr,\n HAS_DT_BIAS: tl.constexpr,\n HAS_D: tl.constexpr,\n HAS_Z: tl.constexpr,\n BLOCK_SIZE_DSTATE: tl.constexpr,\n):\n pid_m = tl.program_id(axis=0)\n pid_b = tl.program_id(axis=1)\n pid_h = tl.program_id(axis=2)\n state_ptr += pid_b * stride_state_batch + pid_h * stride_state_head\n x_ptr += pid_b * stride_x_batch + pid_h * stride_x_head\n dt_ptr += pid_b * stride_dt_batch + pid_h * stride_dt_head\n if HAS_DT_BIAS:\n dt_bias_ptr += pid_h * stride_dt_bias_head\n A_ptr += pid_h * stride_A_head\n B_ptr += pid_b * stride_B_batch + (pid_h //\n nheads_ngroups_ratio) * stride_B_group\n C_ptr += pid_b * stride_C_batch + (pid_h //\n nheads_ngroups_ratio) * stride_C_group\n if HAS_Z:\n z_ptr += pid_b * stride_z_batch + pid_h * stride_z_head\n out_ptr += pid_b * stride_out_batch + pid_h * stride_out_head\n\n offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_n = tl.arange(0, BLOCK_SIZE_DSTATE)\n state_ptrs = state_ptr + (offs_m[:, None] * stride_state_dim +\n offs_n[None, :] * stride_state_dstate)\n x_ptrs = x_ptr + offs_m * stride_x_dim\n dt_ptrs = dt_ptr + offs_m * stride_dt_dim\n if HAS_DT_BIAS:\n dt_bias_ptrs = dt_bias_ptr + offs_m * stride_dt_bias_dim\n if HAS_D:\n D_ptr += pid_h * stride_D_head\n A_ptrs = A_ptr + (offs_m[:, None] * stride_A_dim +\n offs_n[None, :] * stride_A_dstate)\n B_ptrs = B_ptr + offs_n * stride_B_dstate\n C_ptrs = C_ptr + offs_n * stride_C_dstate\n if HAS_D:\n D_ptrs = D_ptr + offs_m * stride_D_dim\n if HAS_Z:\n z_ptrs = z_ptr + offs_m * stride_z_dim\n out_ptrs = out_ptr + offs_m * stride_out_dim\n\n state = tl.load(state_ptrs,\n mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate),\n other=0.0)\n x = tl.load(x_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)\n if not TIE_HDIM:\n dt = tl.load(dt_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)\n if HAS_DT_BIAS:\n dt += tl.load(dt_bias_ptrs, mask=offs_m < dim,\n other=0.0).to(tl.float32)\n if DT_SOFTPLUS:\n dt = softplus(dt)\n A = tl.load(A_ptrs,\n mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate),\n other=0.0).to(tl.float32)\n dA = tl.exp(A * dt[:, None])\n else:\n dt = tl.load(dt_ptr).to(tl.float32)\n if HAS_DT_BIAS:\n dt += tl.load(dt_bias_ptr).to(tl.float32)\n if DT_SOFTPLUS:\n dt = softplus(dt)\n A = tl.load(A_ptr).to(tl.float32)\n dA = tl.exp(A * dt) # scalar, not a matrix\n\n B = tl.load(B_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32)\n C = tl.load(C_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32)\n if HAS_D:\n D = tl.load(D_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)\n if HAS_Z:\n z = tl.load(z_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)\n\n dB = B[None, :] * dt[:, None] if not TIE_HDIM else B * dt\n state = state * dA + dB * x[:, None]\n tl.store(state_ptrs,\n state,\n mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate))\n out = tl.sum(state * C[None, :], axis=1)\n if HAS_D:\n out += x * D\n if HAS_Z:\n out *= z * tl.sigmoid(z)\n tl.store(out_ptrs, out, mask=offs_m < dim)\n\n\ndef selective_state_update(state,\n x,\n dt,\n A,\n B,\n C,\n D=None,\n z=None,\n dt_bias=None,\n dt_softplus=False):\n has_heads = state.dim() > 3\n if state.dim() == 3:\n state = state.unsqueeze(1)\n if x.dim() == 2:\n x = x.unsqueeze(1)\n if dt.dim() == 2:\n dt = dt.unsqueeze(1)\n if A.dim() == 2:\n A = A.unsqueeze(0)\n if B.dim() == 2:\n B = B.unsqueeze(1)\n if C.dim() == 2:\n C = C.unsqueeze(1)\n if D is not None and D.dim() == 1:\n D = D.unsqueeze(0)\n if z is not None and z.dim() == 2:\n z = z.unsqueeze(1)\n if dt_bias is not None and dt_bias.dim() == 1:\n dt_bias = dt_bias.unsqueeze(0)\n batch, nheads, dim, dstate = state.shape\n assert x.shape == (batch, nheads, dim)\n assert dt.shape == x.shape\n assert A.shape == (nheads, dim, dstate)\n ngroups = B.shape[1]\n assert nheads % ngroups == 0, \"nheads must be divisible by ngroups\"\n assert B.shape == (batch, ngroups, dstate)\n assert C.shape == B.shape\n if D is not None:\n assert D.shape == (nheads, dim)\n if z is not None:\n assert z.shape == x.shape\n if dt_bias is not None:\n assert dt_bias.shape == (nheads, dim)\n out = torch.empty_like(x)\n grid = lambda META: (triton.cdiv(dim, META['BLOCK_SIZE_M']), batch, nheads)\n z_strides = ((z.stride(0), z.stride(1), z.stride(2)) if z is not None else\n (0, 0, 0))\n BLOCK_SIZE_M, num_warps = ((32, 4) if dstate <= 16 else\n ((16, 4) if dstate <= 32 else\n ((8, 4) if dstate <= 64 else\n ((4, 4) if dstate <= 128 else ((4, 8))))))\n tie_hdim = A.stride(-1) == 0 and A.stride(-2) == 0 and dt.stride(\n -1) == 0 and dt_bias.stride(-1) == 0\n with torch.cuda.device(x.device.index):\n _selective_scan_update_kernel[grid](\n state,\n x,\n dt,\n dt_bias,\n A,\n B,\n C,\n D,\n z,\n out,\n batch,\n nheads,\n dim,\n dstate,\n nheads // ngroups,\n state.stride(0),\n state.stride(1),\n state.stride(2),\n state.stride(3),\n x.stride(0),\n x.stride(1),\n x.stride(2),\n dt.stride(0),\n dt.stride(1),\n dt.stride(2),\n *(dt_bias.stride(0),\n dt_bias.stride(1)) if dt_bias is not None else 0,\n A.stride(0),\n A.stride(1),\n A.stride(2),\n B.stride(0),\n B.stride(1),\n B.stride(2),\n C.stride(0),\n C.stride(1),\n C.stride(2),\n *(D.stride(0), D.stride(1)) if D is not None else 0,\n z_strides[0],\n z_strides[1],\n z_strides[2],\n out.stride(0),\n out.stride(1),\n out.stride(2),\n dt_softplus,\n tie_hdim,\n BLOCK_SIZE_M,\n num_warps=num_warps,\n )\n if not has_heads:\n out = out.squeeze(1)\n return out\n", - "description_1": "Use triton language to implement a softplus function and a selective scan update kernel. The softplus function takes one parameter 'dt' and applies a softplus transformation. The selective scan update kernel takes 47 parameters including pointers to matrices, matrix dimensions, strides, and meta-parameters. It performs a selective scan update on the input matrices based on the provided parameters.", - "description_2": "Use triton language to create a softplus function with one parameter for element-wise transformation and a selective scan update kernel with 47 parameters for matrix operations.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n\ndef seeded_uniform(\n *size,\n seeds: torch.Tensor,\n out: Optional[torch.Tensor] = None,\n dtype: Optional[torch.dtype] = None,\n device: Optional[Union[torch.device, str]] = None,\n pin_memory: Optional[bool] = False,\n) -> torch.Tensor:\n \"\"\"Similar to torch.rand, but allows for seeds to be set per row.\n\n seeds must be a 1d tensor. The output tensor may be 1d, 2d, or 3d.\n If it is 3d, the additional seeds needed will be derived automatically\n in a deterministic fashion:\n [\n row 0: [columns_with_seed_0], [columns_with_seed0^1], ...\n ]\n \"\"\"\n n_dims = len(size)\n\n if n_dims > 3:\n raise ValueError(\"seeded_uniform only supports up to 3D tensors\")\n\n if out is None:\n out = torch.empty(*size,\n dtype=dtype,\n device=device,\n pin_memory=pin_memory)\n elif out.shape != size:\n raise ValueError(\"shape of out and size must be the same\")\n\n if n_dims == 3:\n n_rows, n_3d, n_cols = out.shape\n stride_row = out.stride(0)\n stride_3d = out.stride(1)\n elif n_dims == 2:\n n_rows, n_cols = out.shape\n n_3d = 1\n stride_row = out.stride(0)\n stride_3d = 1\n else:\n n_cols = out.shape[0]\n n_rows = 1\n n_3d = 1\n stride_row = 1\n stride_3d = 1\n\n if seeds.ndim != 1:\n raise ValueError(\"seeds must be a 1D tensor\")\n\n if seeds.numel() != n_rows:\n raise ValueError(\n \"seeds must have the same number of elements as out has rows\")\n\n # The philox PRNG Triton uses generates 4 random numbers at once.\n # Therefore, the most efficient use of it is to divide the\n # block size by 4, and then save the generated random numbers to\n # each of the 4 slices of the tensor.\n full_block_size = triton.next_power_of_2(n_cols)\n philox_block_size = max(full_block_size // 4, 1)\n n_slices = full_block_size // philox_block_size\n num_warps = 4\n # Manual tuning. This seems to give best performance on A100 for\n # simple kernels like this.\n if philox_block_size >= 8192:\n num_warps = 32\n elif philox_block_size >= 4096:\n num_warps = 16\n elif philox_block_size >= 2048:\n num_warps = 8\n\n _seeded_uniform_triton[(n_rows, n_3d)](\n out,\n seeds,\n stride_row,\n stride_3d,\n seeds.stride(0),\n n_rows,\n n_3d,\n n_cols,\n n_slices=n_slices,\n num_warps=num_warps,\n block_size=philox_block_size,\n )\n return out\n\n\n@triton.jit\ndef _seeded_uniform_triton(\n out_ptr: torch.Tensor,\n seed_ptr: torch.Tensor,\n out_row_stride: int,\n out_3d_stride: int,\n seed_row_stride: int,\n n_rows: int,\n n_3d: int,\n n_cols: int,\n n_slices: tl.constexpr,\n block_size: tl.constexpr,\n):\n \"\"\"\n Generate a random float32 number in [0, 1) for each element in the output\n tensor. The random numbers in a row generated using the seed for that row.\n\n Args:\n out_ptr: The output tensor.\n seed_ptr: The per-row seeds to use for random number generation.\n out_row_stride: The stride between rows of the output tensor.\n out_3d_stride: The stride between 3D slices of the output tensor.\n seed_row_stride: The stride between rows of the seed tensor.\n n_rows: The number of rows in the output tensor.\n n_3d: The size of second dimension of the output tensor,\n if output tensor is 3D.\n n_cols: The number of columns in the output tensor.\n n_slices: The number of philox outputs to use.\n \"\"\"\n tl.static_assert(n_slices > 0 and n_slices <= 4, \"0 < n_slices <= 4\")\n\n # Get the row index.\n row_idx = tl.program_id(axis=0)\n three_d_idx = tl.program_id(axis=1)\n\n philox_offsets = tl.arange(0, block_size)\n # Get the seed for the current element.\n seed = tl.load(seed_ptr + row_idx * seed_row_stride)\n if three_d_idx > 0:\n seed ^= three_d_idx\n # Generate random numbers in [0, 1).\n out1, out2, out3, out4 = tl.rand4x(seed, philox_offsets)\n\n output_row_start_ptr = (out_ptr + row_idx * out_row_stride +\n three_d_idx * out_3d_stride)\n out1_offsets = philox_offsets\n tl.store(output_row_start_ptr + out1_offsets,\n out1,\n mask=out1_offsets < n_cols)\n if n_slices > 1:\n out2_offsets = tl.arange(block_size, block_size * 2)\n tl.store(output_row_start_ptr + out2_offsets,\n out2,\n mask=out2_offsets < n_cols)\n if n_slices > 2:\n out3_offsets = tl.arange(block_size * 2, block_size * 3)\n tl.store(output_row_start_ptr + out3_offsets,\n out3,\n mask=out3_offsets < n_cols)\n if n_slices > 3:\n out4_offsets = tl.arange(block_size * 3, block_size * 4)\n tl.store(output_row_start_ptr + out4_offsets,\n out4,\n mask=out4_offsets < n_cols)\n", - "description_1": "Use triton language to implement a seeded uniform random number generator function (`seeded_uniform`) and its corresponding kernel (`_seeded_uniform_triton`). The `seeded_uniform` function allows creating a random tensor with a per-row seed, supporting up to 3D tensors. It takes parameters: size (variable dimensions), seeds (1D tensor for row-specific seeds), optional output tensor, dtype, device, and pin_memory flag. The Triton kernel `_seeded_uniform_triton` takes the output tensor pointer, seed pointer, strides for rows and 3D slices, number of rows, 3D size, number of columns, slices, and block size as constexpr to generate random numbers in the range [0,1) efficiently.", - "description_2": "Use triton language to create a random tensor with per-row seeds using `seeded_uniform`, handling up to 3D with optimized block size and warp count.", - "difficulty": 2 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n_EPS: tl.constexpr = 1e-6\n\n@triton.jit\ndef _uniform_to_exponential(uniform_noise):\n \"\"\"Convert uniform samples to exponential samples.\"\"\"\n lb = tl.full(uniform_noise.shape, _EPS, uniform_noise.dtype)\n uniform_noise = tl.maximum(uniform_noise, lb)\n exponential_noise = -tl.log(uniform_noise)\n return exponential_noise\n\n@triton.jit\ndef _sample_triton(\n sample_indices_ptr: torch.Tensor, output_ptr: torch.Tensor,\n output_logprobs_ptr: torch.Tensor,\n output_modified_probs_ptr: torch.Tensor, probs_ptr: torch.Tensor,\n logprobs_ptr: torch.Tensor, seeds_ptr: torch.Tensor,\n uniform_noise_ptr: torch.Tensor, output_row_stride: int,\n probs_row_stride: int, uniform_noise_row_stride: int,\n uniform_noise_best_stride: int, n_samples: int, n_cols: int,\n n_best: int, block_size: tl.constexpr,\n modify_greedy_probs: tl.constexpr, save_logprobs: tl.constexpr,\n save_modified_probs: tl.constexpr):\n sample_idx = tl.program_id(0)\n best_idx = tl.program_id(1)\n row_idx = tl.load(sample_indices_ptr + sample_idx)\n seed = tl.load(seeds_ptr + sample_idx)\n uses_random_sampling = seed != 0\n row_start_ptr = probs_ptr + row_idx * probs_row_stride\n col_offsets = tl.arange(0, block_size)\n row = tl.load(row_start_ptr + col_offsets,\n mask=col_offsets < n_cols,\n other=float(\"-inf\"))\n\n if uses_random_sampling:\n uniform_noise_start_ptr = (uniform_noise_ptr +\n sample_idx * uniform_noise_row_stride +\n best_idx * uniform_noise_best_stride)\n uniform_noise = tl.load(uniform_noise_start_ptr + col_offsets,\n mask=col_offsets < n_cols,\n other=0.5)\n exponential_noise = _uniform_to_exponential(uniform_noise)\n row /= exponential_noise\n\n sampled_value, sampled_token = tl.max(row, axis=0, return_indices=True)\n if sampled_token >= n_cols:\n sampled_token = n_cols - 1\n output_row_start_ptr = (output_ptr + sample_idx * output_row_stride +\n best_idx)\n tl.store(output_row_start_ptr, sampled_token)\n\n if modify_greedy_probs:\n if not uses_random_sampling:\n row = tl.where(col_offsets == sampled_token, 1.0, 0.0)\n tl.store(row_start_ptr + col_offsets,\n row,\n mask=col_offsets < n_cols)\n\n if save_modified_probs:\n output_row_start_ptr = (output_modified_probs_ptr +\n sample_idx * output_row_stride + best_idx)\n tl.store(output_row_start_ptr, sampled_value)\n\n if save_logprobs:\n sampled_logprob = tl.load(logprobs_ptr + row_idx * probs_row_stride +\n sampled_token)\n output_row_start_ptr = (output_logprobs_ptr +\n sample_idx * output_row_stride + best_idx)\n tl.store(output_row_start_ptr, sampled_logprob)\n", - "description_1": "Use triton language to create a kernel function `_uniform_to_exponential` that converts uniform noise to exponential noise, with one input parameter for the tensor of uniform noise. Another kernel function `_sample_triton` is implemented for token sampling, requiring inputs like sample indices, output pointers, probabilities, seeds, uniform noise, strides, number of samples, columns, and best samples. It includes several control parameters for modifying probabilities, saving logprobs, and saving modified probabilities.", - "description_2": "Use triton language to implement a kernel for converting uniform noise to exponential noise and another kernel for sampling tokens from probability distributions, involving operations on tensors with various input parameters and control flags.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\nAWQ_TRITON_SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]\n\n@triton.jit\ndef awq_dequantize_kernel(\n qweight_ptr, # quantized matrix\n scales_ptr, # scales, per group\n zeros_ptr, # zeros, per group\n group_size, # Should always be one of the supported group sizes\n result_ptr, # Output matrix\n num_cols, # input num cols in qweight\n num_rows, # input num rows in qweight\n BLOCK_SIZE_X: tl.constexpr,\n BLOCK_SIZE_Y: tl.constexpr):\n pid_x = tl.program_id(axis=0)\n pid_y = tl.program_id(axis=1)\n\n offsets_y = pid_y * BLOCK_SIZE_Y + tl.arange(0, BLOCK_SIZE_Y)\n offsets_x = pid_x * BLOCK_SIZE_X + tl.arange(0, BLOCK_SIZE_X * 8) // 8\n offsets = num_cols * offsets_y[:, None] + offsets_x[None, :]\n\n masks_y = offsets_y < num_rows\n masks_x = offsets_x < num_cols\n\n masks = masks_y[:, None] & masks_x[None, :]\n\n result_offsets_y = pid_y * BLOCK_SIZE_Y + tl.arange(0, BLOCK_SIZE_Y)\n result_offsets_x = pid_x * BLOCK_SIZE_X * 8 + tl.arange(\n 0, BLOCK_SIZE_X * 8)\n result_offsets = (8 * num_cols * result_offsets_y[:, None] +\n result_offsets_x[None, :])\n\n result_masks_y = result_offsets_y < num_rows\n result_masks_x = result_offsets_x < num_cols * 8\n result_masks = result_masks_y[:, None] & result_masks_x[None, :]\n\n iweights = tl.load(qweight_ptr + offsets, masks)\n\n reverse_awq_order_tensor = ((tl.arange(0, 2) * 4)[None, :] +\n tl.arange(0, 4)[:, None]).reshape(8)\n\n shifts = reverse_awq_order_tensor * 4\n shifts = tl.broadcast_to(shifts[None, :], (BLOCK_SIZE_Y * BLOCK_SIZE_X, 8))\n shifts = tl.reshape(shifts, (BLOCK_SIZE_Y, BLOCK_SIZE_X * 8))\n\n iweights = (iweights >> shifts) & 0xF\n\n zero_offsets_y = (pid_y * BLOCK_SIZE_Y // group_size +\n tl.arange(0, BLOCK_SIZE_Y) // group_size)\n zero_offsets_x = pid_x * BLOCK_SIZE_X + tl.arange(0, BLOCK_SIZE_X * 8) // 8\n zero_offsets = num_cols * zero_offsets_y[:, None] + zero_offsets_x[None, :]\n\n zero_masks_y = zero_offsets_y < num_rows // group_size\n zero_masks_x = zero_offsets_x < num_cols\n zero_masks = zero_masks_y[:, None] & zero_masks_x[None, :]\n\n zeros = tl.load(zeros_ptr + zero_offsets, zero_masks)\n\n zeros = (zeros >> shifts) & 0xF\n\n scale_offsets_y = (pid_y * BLOCK_SIZE_Y // group_size +\n tl.arange(0, BLOCK_SIZE_Y) // group_size)\n scale_offsets_x = (pid_x * BLOCK_SIZE_X * 8 +\n tl.arange(0, BLOCK_SIZE_X * 8))\n scale_offsets = (num_cols * 8 * scale_offsets_y[:, None] +\n scale_offsets_x[None, :])\n scale_masks_y = scale_offsets_y < num_rows // group_size\n scale_masks_x = scale_offsets_x < num_cols * 8\n scale_masks = scale_masks_y[:, None] & scale_masks_x[None, :]\n\n scales = tl.load(scales_ptr + scale_offsets, scale_masks)\n\n iweights = (iweights - zeros) * scales\n iweights = iweights.to(result_ptr.type.element_ty)\n\n tl.store(result_ptr + result_offsets, iweights, result_masks)\n\n@triton.jit\ndef awq_gemm_kernel(a_ptr, b_ptr, c_ptr, zeros_ptr, scales_ptr, M, N, K,\n group_size, BLOCK_SIZE_M: tl.constexpr,\n BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,\n SPLIT_K: tl.constexpr):\n pid = tl.program_id(axis=0)\n pid_z = tl.program_id(1)\n\n num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)\n\n pid_m = pid // num_pid_n\n pid_n = pid % num_pid_n\n\n accumulator_dtype = c_ptr.type.element_ty\n\n accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N),\n dtype=accumulator_dtype)\n\n reverse_awq_order_tensor = ((tl.arange(0, 2) * 4)[None, :] +\n tl.arange(0, 4)[:, None]).reshape(8)\n\n shifts = reverse_awq_order_tensor * 4\n shifts = tl.broadcast_to(shifts[None, :],\n (BLOCK_SIZE_K * (BLOCK_SIZE_N // 8), 8))\n shifts = tl.reshape(shifts, (BLOCK_SIZE_K, BLOCK_SIZE_N))\n\n offsets_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n masks_am = offsets_am < M\n\n offsets_bn = (pid_n * (BLOCK_SIZE_N // 8) +\n tl.arange(0, BLOCK_SIZE_N) // 8)\n masks_bn = offsets_bn < N // 8\n\n offsets_zn = (pid_n * (BLOCK_SIZE_N // 8) +\n tl.arange(0, BLOCK_SIZE_N) // 8)\n masks_zn = offsets_zn < N // 8\n\n offsets_sn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n masks_sn = offsets_sn < N\n\n offsets_k = pid_z * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)\n offsets_a = K * offsets_am[:, None] + offsets_k[None, :]\n offsets_b = (N // 8) * offsets_k[:, None] + offsets_bn[None, :]\n\n a_ptrs = a_ptr + offsets_a\n b_ptrs = b_ptr + offsets_b\n\n for k in range(0, tl.cdiv(K, BLOCK_SIZE_K * SPLIT_K)):\n masks_k = offsets_k < K\n masks_a = masks_am[:, None] & masks_k[None, :]\n a = tl.load(a_ptrs, mask=masks_a)\n\n masks_b = masks_k[:, None] & masks_bn[None, :]\n b = tl.load(b_ptrs, mask=masks_b)\n\n offsets_szk = (\n (BLOCK_SIZE_K * SPLIT_K * k + pid_z * BLOCK_SIZE_K) // group_size +\n tl.arange(0, BLOCK_SIZE_K) // group_size)\n offsets_z = (N // 8) * offsets_szk[:, None] + offsets_zn[None, :]\n masks_zk = offsets_szk < K // group_size\n masks_z = masks_zk[:, None] & masks_zn[None, :]\n zeros_ptrs = zeros_ptr + offsets_z\n zeros = tl.load(zeros_ptrs, mask=masks_z)\n\n offsets_s = N * offsets_szk[:, None] + offsets_sn[None, :]\n masks_sk = offsets_szk < K // group_size\n masks_s = masks_sk[:, None] & masks_sn[None, :]\n scales_ptrs = scales_ptr + offsets_s\n scales = tl.load(scales_ptrs, mask=masks_s)\n\n b = (b >> shifts) & 0xF\n zeros = (zeros >> shifts) & 0xF\n b = (b - zeros) * scales\n b = b.to(c_ptr.type.element_ty)\n\n accumulator = tl.dot(a, b, accumulator, out_dtype=accumulator_dtype)\n\n offsets_k += BLOCK_SIZE_K * SPLIT_K\n a_ptrs += BLOCK_SIZE_K * SPLIT_K\n b_ptrs += BLOCK_SIZE_K * SPLIT_K * (N // 8)\n\n c = accumulator.to(c_ptr.type.element_ty)\n offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n c_ptrs = c_ptr + N * offs_cm[:, None] + offs_cn[None, :]\n c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)\n if SPLIT_K == 1:\n tl.store(c_ptrs, c, mask=c_mask)\n else:\n tl.atomic_add(c_ptrs, c, mask=c_mask)\n\ndef awq_dequantize_triton(qweight: torch.Tensor,\n scales: torch.Tensor,\n zeros: torch.Tensor,\n block_size_x: int = 32,\n block_size_y: int = 32) -> torch.Tensor:\n K = qweight.shape[0]\n M = scales.shape[1]\n group_size = qweight.shape[0] // scales.shape[0]\n\n assert K > 0 and M > 0\n assert scales.shape[0] == K // group_size and scales.shape[1] == M\n assert zeros.shape[0] == K // group_size and zeros.shape[1] == M // 8\n assert group_size <= K\n assert group_size in AWQ_TRITON_SUPPORTED_GROUP_SIZES or group_size == K\n\n result = torch.empty(qweight.shape[0],\n qweight.shape[1] * 8,\n device=qweight.device,\n dtype=scales.dtype)\n\n Y = qweight.shape[0] # num rows\n X = qweight.shape[1] # num cols\n\n grid = lambda META: (\n triton.cdiv(X, META['BLOCK_SIZE_X']),\n triton.cdiv(Y, META['BLOCK_SIZE_Y']),\n )\n awq_dequantize_kernel[grid](qweight,\n scales,\n zeros,\n group_size,\n result,\n X,\n Y,\n BLOCK_SIZE_X=block_size_x,\n BLOCK_SIZE_Y=block_size_y)\n\n return result\n\ndef awq_gemm_triton(input: torch.Tensor,\n qweight: torch.Tensor,\n scales: torch.Tensor,\n qzeros: torch.Tensor,\n split_k_iters: int,\n block_size_m: int = 32,\n block_size_n: int = 32,\n block_size_k: int = 32) -> torch.Tensor:\n M, K = input.shape\n N = qweight.shape[1] * 8\n group_size = qweight.shape[0] // qzeros.shape[0]\n\n assert N > 0 and K > 0 and M > 0\n assert qweight.shape[0] == K and qweight.shape[1] == N // 8\n assert qzeros.shape[0] == K // group_size and qzeros.shape[1] == N // 8\n assert scales.shape[0] == K // group_size and scales.shape[1] == N\n assert split_k_iters & (split_k_iters - 1) == 0 and split_k_iters != 0\n assert split_k_iters <= 32\n assert group_size <= K\n assert group_size in AWQ_TRITON_SUPPORTED_GROUP_SIZES or group_size == K\n\n grid = lambda META: (\n triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(\n N, META['BLOCK_SIZE_N']),\n split_k_iters,\n )\n\n result = torch.zeros((M, N), dtype=scales.dtype, device=input.device)\n\n awq_gemm_kernel[grid](input,\n qweight,\n result,\n qzeros,\n scales,\n M,\n N,\n K,\n group_size,\n BLOCK_SIZE_M=block_size_m,\n BLOCK_SIZE_N=block_size_n,\n BLOCK_SIZE_K=block_size_k,\n SPLIT_K=split_k_iters)\n\n return result\n", - "description_1": "Use triton language to implement two kernels: awq_dequantize_kernel and awq_gemm_kernel. The awq_dequantize_kernel takes 8 parameters: qweight_ptr (quantized matrix), scales_ptr (scales per group), zeros_ptr (zeros per group), group_size (supported group sizes), result_ptr (output matrix), num_cols (number of columns in qweight), num_rows (number of rows in qweight), and two block sizes (BLOCK_SIZE_X and BLOCK_SIZE_Y). It dequantizes the input matrix using the provided scales and zeros. The awq_gemm_kernel takes 13 parameters: a_ptr (input matrix), b_ptr (quantized weight matrix), c_ptr (output matrix), zeros_ptr (zeros per group), scales_ptr (scales per group), M, N, K (dimensions of the matrices), group_size, and three block sizes (BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K), and SPLIT_K. It performs a matrix multiplication with dequantization of the weight matrix.", - "description_2": "Use triton language to implement a dequantization kernel and a GEMM kernel with dequantization. The dequantization kernel processes a quantized matrix using scales and zeros to produce a dequantized output. The GEMM kernel performs matrix multiplication on an input matrix and a dequantized weight matrix, producing a result matrix.", - "difficulty": 4 - }, - { - "code": "import torch\nimport torch.nn.functional as F\nimport triton\nimport triton.language as tl\nfrom PIL import Image\nfrom torchvision.transforms.functional import to_tensor\nfrom time import time\nimport numpy as np\n\nGPU = torch.device(\"cuda\")\n\n# Triton kernel for propagating nearest neighbor field\n@triton.jit\ndef propagate_kernel(A_ptr, B_ptr, kNNF_ptr, output_ptr, l: int, K: int, P: int, A_h: int, A_w: int, A_c: int, **meta):\n \"\"\"\n Args:\n A_ptr : pointer to image A which has shape (A_w, A_h, A_c)\n B_ptr : pointer to image B\n kNNF_ptr : pointer to k nearest neighbor field which is (A_w , A_h, K, 3)\n output_ptr : pointer to output kNNF values for this iteration which is (A_w, A_h, K)\n l (int): looking distance (where to look for candidates)\n K (int): number of nearest neighbors (default: 5)\n P (int): patch size (default: 3)\n A_h (int): height of image A\n A_w (int): width of image A\n A_c (int): number of channels in image A\n \"\"\"\n \n B = meta[\"BLOCK_SIZE\"] # 16\n\n # map pid to block of kNNF that it should compute\n pid = tl.program_id(axis=0)\n num_pid_h = tl.cdiv(A_h, B)\n num_pid_w = tl.cdiv(A_w, B)\n y = pid // num_pid_h\n x = pid // num_pid_w\n ys = y * B + tl.arange(0, B)\n xs = x * B + tl.arange(0, B)\n A_grid = ys[:, None] + xs[None, :] # TODO need to add stride in ptr sum?\n\n # arrays for indexing certain things\n window = tl.arange(0, 3) - P // 2 # tl.arange(-P // 2, P // 2 + 1)\n channels = tl.arange(0, 3) # A_c \n coord_dim = tl.arange(1, 3) \n dist_dim = tl.zeros((1,), dtype=tl.int32)\n neighbors = tl.arange(0, 5) # K\n\n # extract pixel value patches from A\n A_patch_center_idxs = A_ptr + A_grid # B, B\n A_patch_idxs = (\n A_patch_center_idxs[ :, :, None, None, None, None]\n + window [None, None, None, :, None, None]\n + window [None, None, None, None, :, None]\n + channels [None, None, None, None, None, :]\n ) # B B 1 P P A_c\n A_patches = tl.load(A_patch_idxs)\n\n # look for new candidate patches l pixels left, right, up, and down (9 locations)\n candidates = tl.zeros((9, B, B, 2), dtype=tl.int32)\n candidate_distances = tl.zeros((9, B, B), dtype=tl.float32)\n i = 0\n for h in range(-l, l, l):\n for w in range(-l, l, l):\n\n # ensure indexes stay in bounds\n candidate_ys = tl.maximum(0, tl.minimum(A_h - 1, ys + h))\n candidate_xs = tl.maximum(0, tl.minimum(A_w - 1, xs + w))\n\n # load candidate patch centers from kNNF\n candidate_coords = (\n kNNF_ptr\n + candidate_ys[ :, None, None, None]\n + candidate_xs[None, :, None, None]\n + neighbors [None, None, :, None]\n + coord_dim [None, None, None, :]\n ) # B B K 2\n\n # candidate_coords shape here is actually (B, B, K, 3) !?\n # candidate_idxs = tl.sum(candidate_coords, axis=3) # Error: Encountered unimplemented code path in sum. This is likely a bug on our side.\n candidate_idxs = candidate_coords[:, :, :, 0] + candidate_coords[:, :, :, 1] # Error: cannot reshape block of different shape\n B_patch_center_idxs = tl.load(candidate_idxs)\n\n # load corresponding patches from image B\n B_patch_idxs = (\n B_ptr\n + B_patch_center_idxs[ :, :, :, None, None, None]\n + window [None, None, None, :, None, None]\n + window [None, None, None, None, :, None]\n + channels [None, None, None, None, None, :]\n ) # B B K P P A_c\n B_patches = tl.load(B_patch_idxs)\n\n # find distance between image A patches and candidate patches\n distances = tl.sum((B_patches - A_patches) ** 2, axis=[3, 4, 5]) # B, B, K\n\n # remember best candidates\n best_candidates = triton.torch.argmin(distances, axis=2) # B, B\n candidates[i] = B_patch_center_idxs[best_candidates]\n candidate_distances[i] = distances[best_candidates]\n\n i += 1\n\n # find overal best candidates\n idxs_new = triton.torch.argmin(candidate_distances, axis=0) # B, B\n kNNF_coord_new = candidates[idxs_new[None, :, :, None]] # B, B, 2\n kNNF_dist_new = candidate_distances[idxs_new[None, :, :]] # B, B\n\n # store \n tl.store(output_ptr + A_grid[:, :, None] + dist_dim[None, None, :], kNNF_dist_new)\n tl.store(output_ptr + A_grid[:, :, None] + coord_dim[None, None, :], kNNF_coord_new)\n\n# Function to call the Triton kernel\ndef propagate(A: torch.Tensor, B: torch.Tensor, kNNF: torch.Tensor, l: int, K: int, P: int):\n _, A_c, A_h_pad, A_w_pad = A.shape\n A_h, A_w = A_h_pad - P, A_w_pad - P\n\n output = torch.empty((A_h, A_w, 3), device=kNNF.device, dtype=kNNF.dtype)\n\n grid = lambda meta: (triton.cdiv(A_h, meta[\"BLOCK_SIZE\"]) * triton.cdiv(A_w, meta[\"BLOCK_SIZE\"]),)\n\n pgm = propagate_kernel[grid](\n A.squeeze().permute(1, 2, 0),\n B.squeeze().permute(1, 2, 0),\n kNNF,\n output,\n int(l),\n int(K),\n int(P),\n int(A_h),\n int(A_w),\n int(A_c),\n BLOCK_SIZE=16,\n )\n\n return output\n\n# Main function to perform patch match\ndef patch_match(img_A: torch.Tensor, img_B: torch.Tensor, K: int = 7, P: int = 5):\n (A_h, A_w), (B_h, B_w) = img_A.shape[2:], img_B.shape[2:]\n r = P // 2\n patch_range = torch.arange(-r, r + 1).to(GPU)\n patch_window = torch.stack(torch.meshgrid(patch_range, patch_range, indexing=\"ij\"))[None]\n\n img_A = F.pad(img_A, (r, r, r, r), mode=\"reflect\")\n img_B = F.pad(img_B, (r, r, r, r), mode=\"reflect\")\n\n # initialize\n idxsB = torch.randint(B_h * B_w, size=(A_h, A_w, K)).to(GPU)\n ysB = torch.div(idxsB, B_w, rounding_mode=\"floor\") + r\n xsB = idxsB % B_w + r\n patch_ysB = (ysB[..., None, None] + patch_window[None, :, 0]).long()\n patch_xsB = (xsB[..., None, None] + patch_window[None, :, 1]).long()\n patchesB = img_B.squeeze()[:, patch_ysB, patch_xsB].permute(1, 2, 3, 4, 5, 0).reshape(A_h, A_w, K, -1)\n\n idxsA = torch.stack(torch.meshgrid(torch.arange(A_h), torch.arange(A_w), indexing=\"ij\")).to(GPU)\n patchesA = idxsA[..., None, None] + patch_window[0, :, None, None]\n patchesA = img_A.squeeze()[:, patchesA[0], patchesA[1]].permute(1, 2, 3, 4, 0)\n patchesA = torch.tile(patchesA[:, :, None], [1, 1, K, 1, 1, 1]).reshape(A_h, A_w, K, -1)\n\n dAB = torch.sum(torch.square(patchesB - patchesA), dim=-1)\n kNNF = torch.stack((dAB, ysB, xsB), dim=-1).cpu()\n for y in torch.arange(A_h):\n for x in torch.arange(A_w):\n kNNF[y, x] = heapify(kNNF[y, x])\n kNNF = kNNF.to(GPU)\n\n # propagate\n max_side = torch.maximum(torch.tensor(A_h), torch.tensor(A_w))\n ls = torch.floor(max_side / torch.pow(2, torch.arange(torch.log2(max_side))))\n for l in ls:\n\n kNNF[:, :, 0] = propagate(img_A, img_B, kNNF, l.item(), K, P)\n\n for y in range(A_h):\n for x in range(A_w):\n siftup(kNNF[y, x], 0)\n\n return kNNF\n\nif __name__ == \"__main__\":\n with torch.inference_mode():\n img_A = Image.open(\"bike_a.png\").resize((240, 176))\n img_B = Image.open(\"bike_b.png\").resize((256, 192))\n\n img_A = to_tensor(img_A).unsqueeze(0).to(GPU)\n img_B = to_tensor(img_B).unsqueeze(0).to(GPU)\n\n P = 3\n K = 5\n\n t = time()\n kNNF = patch_match(img_A, img_B, K=K, P=P)\n print(time() - t)\n\n result = torch.zeros((img_A.shape[2], img_A.shape[3], img_A.shape[1]))\n all_dists = torch.zeros((img_A.shape[2], img_A.shape[3]))\n for y in range(img_A.shape[2]):\n for x in range(img_A.shape[3]):\n all_dists[y, x] = kNNF[y, x, 0, 0]\n result[y, x] = img_B[:, :, kNNF[y, x, 0, 1].long() - P // 2, kNNF[y, x, 0, 2].long() - P // 2].squeeze()\n result = (result.cpu().numpy() * 255).astype(np.uint8)\n Image.fromarray(result).save(f\"results/bike_{torch.mean(all_dists):.5f}.jpg\")\n", - "description_1": "Use triton language to implement a kernel that propagates the nearest neighbor field (kNNF) for image patches. The kernel takes pointers to images A and B, the kNNF, and an output buffer. It computes the best matching patches in image B for each patch in image A, considering a search window defined by the parameter l. The kernel uses a block size of 16 and processes patches of size P with K nearest neighbors. The function propagate calls this kernel and manages the data preparation and execution on the GPU.", - "description_2": "Use triton language to create a kernel for computing the nearest neighbor field for image patches, and a function to execute this kernel on the GPU.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef softmax_kernel(output_ptr, input_ptr, input_row_stride, output_row_stride, n_cols, **meta):\n # The rows of the softmax are independent, so we parallelize across those\n row_idx = tl.program_id(0)\n BLOCK_SIZE = meta[\"BLOCK_SIZE\"]\n # The stride represents how much we need to increase the pointer to advance 1 row\n row_start_ptr = input_ptr + row_idx * input_row_stride\n # The block size is the next power of two greater than n_cols, so we can fit each\n # row in a single block\n col_offsets = tl.arange(0, BLOCK_SIZE)\n input_ptrs = row_start_ptr + col_offsets\n # Load the row into SRAM, using a mask since BLOCK_SIZE may be > than n_cols\n row = tl.load(input_ptrs, mask=col_offsets < n_cols, other=-float(\"inf\"))\n # Substract maximum for numerical stability\n row_minus_max = row - tl.max(row, axis=0)\n # Note that exponentials in Triton are fast but approximate (i.e., think __expf in CUDA)\n numerator = tl.exp(row_minus_max)\n denominator = tl.sum(numerator, axis=0)\n softmax_output = numerator / denominator\n # Write back output to DRAM\n output_row_start_ptr = output_ptr + row_idx * output_row_stride\n output_ptrs = output_row_start_ptr + col_offsets\n tl.store(output_ptrs, softmax_output, mask=col_offsets < n_cols)\n\ndef softmax(x):\n n_rows, n_cols = x.shape\n # The block size is the smallest power of two greater than the number of columns in `x`\n BLOCK_SIZE = triton.next_power_of_2(n_cols)\n # Another trick we can use is to ask the compiler to use more threads per row by\n # increasing the number of warps (`num_warps`) over which each row is distributed.\n # You will see in the next tutorial how to auto-tune this value in a more natural\n # way so you don't have to come up with manual heuristics yourself.\n num_warps = 4\n if BLOCK_SIZE >= 2048:\n num_warps = 8\n if BLOCK_SIZE >= 4096:\n num_warps = 16\n # Allocate output\n y = torch.empty_like(x)\n # Enqueue kernel. The 1D launch grid is simple: we have one kernel instance per row o\n # f the input matrix\n softmax_kernel[(n_rows,)](\n y,\n x,\n x.stride(0),\n y.stride(0),\n n_cols,\n num_warps=num_warps,\n BLOCK_SIZE=BLOCK_SIZE,\n )\n return y\n\ntorch.manual_seed(0)\nx = torch.randn(1823, 781, device=\"cuda\")\ny_triton = softmax(x)\ny_torch = torch.softmax(x, axis=1)\nassert torch.allclose(y_triton, y_torch)\n", - "description_1": "Use triton language to implement a softmax kernel and its caller function. The softmax_kernel function has six parameters: output_ptr, input_ptr, input_row_stride, output_row_stride, n_cols, and meta. It computes the row-wise softmax using Triton's language capabilities for parallel processing and optimized memory access. The softmax function takes one parameter x, representing a 2D tensor, and uses the softmax_kernel to compute the softmax for each row with optimized block size and warps configuration.", - "description_2": "Use triton language to create a softmax operation leveraging optimized kernel execution for parallel processing on GPU, suitable for large matrices.", - "difficulty": 2 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n# Triton kernel for dropout with mask\n@triton.jit\ndef _dropout(\n x_ptr, # pointer to the input\n x_keep_ptr, # pointer to a mask of 0s and 1s\n output_ptr, # pointer to the output\n n_elements, # number of elements in the `x` tensor\n p, # probability that an element of `x` is changed to zero\n **meta,\n):\n BLOCK_SIZE = meta[\"BLOCK_SIZE\"]\n pid = tl.program_id(axis=0)\n block_start = pid * BLOCK_SIZE\n offsets = block_start + tl.arange(0, BLOCK_SIZE)\n mask = offsets < n_elements\n # Load data\n x = tl.load(x_ptr + offsets, mask=mask)\n x_keep = tl.load(x_keep_ptr + offsets, mask=mask)\n # The line below is the crucial part, described in the paragraph above!\n output = tl.where(x_keep, x / (1 - p), 0.0)\n # Write-back output\n tl.store(output_ptr + offsets, output, mask=mask)\n\ndef dropout(x, x_keep, p):\n output = torch.empty_like(x)\n assert x.is_contiguous()\n n_elements = x.numel()\n grid = lambda meta: (triton.cdiv(n_elements, meta[\"BLOCK_SIZE\"]),)\n _dropout[grid](x, x_keep, output, n_elements, p, BLOCK_SIZE=1024)\n return output\n\n# Triton kernel for seeded dropout\n@triton.jit\ndef _seeded_dropout(\n x_ptr,\n output_ptr,\n n_elements,\n p,\n seed,\n **meta,\n):\n # compute memory offsets of elements handled by this instance\n BLOCK_SIZE = meta[\"BLOCK_SIZE\"]\n pid = tl.program_id(axis=0)\n block_start = pid * BLOCK_SIZE\n offsets = block_start + tl.arange(0, BLOCK_SIZE)\n # load data from x\n mask = offsets < n_elements\n x = tl.load(x_ptr + offsets, mask=mask)\n # randomly prune it\n random = tl.rand(seed, offsets)\n x_keep = random > p\n # write-back\n output = tl.where(x_keep, x / (1 - p), 0.0)\n tl.store(output_ptr + offsets, output, mask=mask)\n\ndef seeded_dropout(x, p, seed):\n output = torch.empty_like(x)\n assert x.is_contiguous()\n n_elements = x.numel()\n grid = lambda meta: (triton.cdiv(n_elements, meta[\"BLOCK_SIZE\"]),)\n _seeded_dropout[grid](x, output, n_elements, p, seed, BLOCK_SIZE=1024)\n return output\n\n# Sample usage\nx = torch.randn(size=(10,)).cuda()\np = 0.5\nx_keep = (torch.rand(size=(10,)) > p).to(torch.int32).cuda()\n\n# Dropout with mask\noutput = dropout(x, x_keep=x_keep, p=p)\n\n# Seeded dropout\noutput = seeded_dropout(x, p=0.5, seed=123)\n", - "description_1": "Use triton language to implement two dropout kernels. The first kernel (_dropout) takes pointers to input tensor, mask tensor, output tensor, number of elements, and a probability p as arguments, and performs dropout using a precomputed mask. The second kernel (_seeded_dropout) takes pointers to input tensor, output tensor, number of elements, a probability p, and a random seed as arguments, and performs dropout using random generation instead of a mask.", - "description_2": "Use triton language to implement dropout kernels, one using a precomputed mask and another using random seed for element selection.", - "difficulty": 2 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.autotune(\n configs=[\n triton.Config(\n {\"BLOCK_SIZE_M\": 128, \"BLOCK_SIZE_N\": 256, \"BLOCK_SIZE_K\": 32, \"GROUP_SIZE_M\": 8}, num_stages=3, num_warps=8\n ),\n triton.Config(\n {\"BLOCK_SIZE_M\": 256, \"BLOCK_SIZE_N\": 128, \"BLOCK_SIZE_K\": 32, \"GROUP_SIZE_M\": 8}, num_stages=3, num_warps=8\n ),\n triton.Config(\n {\"BLOCK_SIZE_M\": 256, \"BLOCK_SIZE_N\": 64, \"BLOCK_SIZE_K\": 32, \"GROUP_SIZE_M\": 8}, num_stages=4, num_warps=4\n ),\n triton.Config(\n {\"BLOCK_SIZE_M\": 64, \"BLOCK_SIZE_N\": 256, \"BLOCK_SIZE_K\": 32, \"GROUP_SIZE_M\": 8}, num_stages=4, num_warps=4\n ),\n triton.Config(\n {\"BLOCK_SIZE_M\": 128, \"BLOCK_SIZE_N\": 128, \"BLOCK_SIZE_K\": 32, \"GROUP_SIZE_M\": 8}, num_stages=4, num_warps=4\n ),\n triton.Config(\n {\"BLOCK_SIZE_M\": 128, \"BLOCK_SIZE_N\": 64, \"BLOCK_SIZE_K\": 32, \"GROUP_SIZE_M\": 8}, num_stages=4, num_warps=4\n ),\n triton.Config(\n {\"BLOCK_SIZE_M\": 64, \"BLOCK_SIZE_N\": 128, \"BLOCK_SIZE_K\": 32, \"GROUP_SIZE_M\": 8}, num_stages=4, num_warps=4\n ),\n triton.Config(\n {\"BLOCK_SIZE_M\": 128, \"BLOCK_SIZE_N\": 32, \"BLOCK_SIZE_K\": 32, \"GROUP_SIZE_M\": 8}, num_stages=4, num_warps=4\n ),\n triton.Config(\n {\"BLOCK_SIZE_M\": 64, \"BLOCK_SIZE_N\": 32, \"BLOCK_SIZE_K\": 32, \"GROUP_SIZE_M\": 8}, num_stages=5, num_warps=2\n ),\n triton.Config(\n {\"BLOCK_SIZE_M\": 32, \"BLOCK_SIZE_N\": 64, \"BLOCK_SIZE_K\": 32, \"GROUP_SIZE_M\": 8}, num_stages=5, num_warps=2\n ),\n ],\n key=[\"M\", \"N\", \"K\"],\n)\n@triton.jit\ndef matmul_kernel(\n a_ptr,\n b_ptr,\n c_ptr,\n M,\n N,\n K,\n stride_am,\n stride_ak,\n stride_bk,\n stride_bn,\n stride_cm,\n stride_cn,\n **meta,\n):\n \"\"\"Kernel for computing the matmul C = A x B.\n A has shape (M, K), B has shape (K, N) and C has shape (M, N)\n \"\"\"\n BLOCK_SIZE_M = meta[\"BLOCK_SIZE_M\"]\n BLOCK_SIZE_N = meta[\"BLOCK_SIZE_N\"]\n BLOCK_SIZE_K = meta[\"BLOCK_SIZE_K\"]\n GROUP_SIZE_M = 8\n\n pid = tl.program_id(axis=0)\n num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)\n num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)\n num_pid_in_group = GROUP_SIZE_M * num_pid_n\n group_id = pid // num_pid_in_group\n first_pid_m = group_id * GROUP_SIZE_M\n group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)\n pid_m = first_pid_m + (pid % group_size_m)\n pid_n = (pid % num_pid_in_group) // group_size_m\n\n offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n offs_k = tl.arange(0, BLOCK_SIZE_K)\n a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)\n b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)\n\n accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n for k in range(0, K, BLOCK_SIZE_K):\n a = tl.load(a_ptrs)\n b = tl.load(b_ptrs)\n accumulator += tl.dot(a, b)\n a_ptrs += BLOCK_SIZE_K * stride_ak\n b_ptrs += BLOCK_SIZE_K * stride_bk\n if meta[\"ACTIVATION\"]:\n accumulator = meta[\"ACTIVATION\"](accumulator)\n c = accumulator.to(tl.float32)\n\n offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]\n c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)\n tl.store(c_ptrs, c, mask=c_mask)\n\n\n@triton.jit\ndef leaky_relu(x):\n return tl.where(x >= 0, x, 0.01 * x)\n\n\ndef matmul(a, b, activation=None):\n assert a.shape[1] == b.shape[0], \"incompatible dimensions\"\n assert a.is_contiguous(), \"matrix A must be contiguous\"\n assert b.is_contiguous(), \"matrix B must be contiguous\"\n M, K = a.shape\n K, N = b.shape\n assert K % 32 == 0, \"We don't check memory-out-of-bounds with K so K must be divisible by BLOCK_SIZE_K\"\n c = torch.empty((M, N), device=a.device, dtype=a.dtype)\n grid = lambda META: (triton.cdiv(M, META[\"BLOCK_SIZE_M\"]) * triton.cdiv(N, META[\"BLOCK_SIZE_N\"]),)\n matmul_kernel[grid](\n a,\n b,\n c,\n M,\n N,\n K,\n a.stride(0),\n a.stride(1),\n b.stride(0),\n b.stride(1),\n c.stride(0),\n c.stride(1),\n ACTIVATION=activation,\n )\n return c\n\n\ntorch.manual_seed(0)\na = torch.randn((512, 512), device=\"cuda\", dtype=torch.float32)\nb = torch.randn((512, 512), device=\"cuda\", dtype=torch.float32)\ntriton_output = matmul(a, b, activation=None)\ntorch_output = torch.matmul(a, b)\nprint(f\"triton_output={triton_output}\")\nprint(f\"torch_output={torch_output}\")\nif triton.testing.allclose(triton_output, torch_output):\n print(\"✅ Triton and Torch match\")\nelse:\n print(\"❌ Triton and Torch differ\")\n", - "description_1": "Use triton language to implement a matrix multiplication kernel (matmul_kernel) and a leaky_relu function. The matmul_kernel function takes 13 parameters, with meta-parameters for block sizes and a custom activation function, and computes matrix C as the product of matrices A and B. The leaky_relu function applies a leaky ReLU activation function to input data. The matmul function wraps the kernel and takes three parameters, A, B, and an optional activation function, ensuring A and B are contiguous and allocating space for the output matrix C. It launches the kernel using a grid configuration based on matrix dimensions.", - "description_2": "Use triton language to create an autotuned matrix multiplication operator with optional leaky ReLU activation, ensuring input matrices are contiguous and block size compatible.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef add_kernel(\n x_ptr, # *Pointer* to first input vector\n y_ptr, # *Pointer* to second input vector\n output_ptr, # *Pointer* to output vector\n n_elements, # Size of the vector\n **meta, # Optional meta-parameters for the kernel\n):\n BLOCK_SIZE = meta[\"BLOCK_SIZE\"] # How many inputs each program should process\n pid = tl.program_id(axis=0) # We use a 1D launch grid so axis is 0\n block_start = pid * BLOCK_SIZE\n offsets = block_start + tl.arange(0, BLOCK_SIZE)\n mask = offsets < n_elements\n x = tl.load(x_ptr + offsets, mask=mask)\n y = tl.load(y_ptr + offsets, mask=mask)\n output = x + y\n tl.store(output_ptr + offsets, output, mask=mask)\n\ndef add(x: torch.Tensor, y: torch.Tensor):\n output = torch.empty_like(x)\n assert x.is_cuda and y.is_cuda and output.is_cuda\n n_elements = output.numel()\n grid = lambda meta: (triton.cdiv(n_elements, meta[\"BLOCK_SIZE\"]),)\n pgm = add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024)\n return output\n\ntorch.manual_seed(0)\nsize = 98432\nx = torch.rand(size, device=\"cuda\")\ny = torch.rand(size, device=\"cuda\")\noutput_torch = x + y\noutput_triton = add(x, y)\nprint(output_torch)\nprint(output_triton)\nprint(f\"The maximum difference between torch and triton is \" f\"{torch.max(torch.abs(output_torch - output_triton))}\")\n", - "description_1": "Use triton language to define a kernel function 'add_kernel' that computes the element-wise addition of two input vectors. The kernel takes four primary parameters: pointers to the input vectors 'x_ptr', 'y_ptr', a pointer for the output 'output_ptr', and the size of the vectors 'n_elements'. It uses a BLOCK_SIZE meta parameter to divide the computation into blocks processed by each program instance. The 'add' function prepares the output tensor, sets up the execution grid based on input size and block size, and launches the kernel.", - "description_2": "Use triton language to create a kernel for element-wise vector addition, and implement a function to manage output allocation, execution grid setup, and kernel invocation.", - "difficulty": 2 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef bone_fwd_kernel(\n a, b, c, bone,\n M, N, K,\n s_am, s_ak, s_bk, s_bn, s_cm, s_cn, s_bonep, s_bonem, s_bonen,\n BM: tl.constexpr, BK: tl.constexpr, BN: tl.constexpr, G: tl.constexpr, ACTIVATION: tl.constexpr,\n):\n \"\"\"Kernel for computing the matmul C = A x B.\n A has shape (M, K), B has shape (K, N) and C has shape (M, N)\n \"\"\"\n NM, NN = tl.num_programs(0), tl.num_programs(1)\n i_m, i_n = tl.program_id(0), tl.program_id(1)\n i_m, i_n = tl.swizzle2d(i_m, i_n, NM, NN, G)\n\n o_am = (i_m * BM + tl.arange(0, BM))\n o_bn = (i_n * BN + tl.arange(0, BN)) % N\n o_k = tl.arange(0, BK)\n\n p_a = a + (o_am[:, None] * s_am + o_k[None, :] * s_ak)\n p_b = b + (o_k[:, None] * s_bk + o_bn[None, :] * s_bn)\n\n p_bone = bone + i_n * s_bonep + o_k[:, None] * s_bonem + o_k[None, :] * s_bonen\n b_bone = tl.load(p_bone)\n\n b_acc = tl.zeros((BM, BN), dtype=tl.float32)\n for k in range(0, tl.cdiv(K, BK)):\n b_a = tl.load(p_a, mask=o_k[None, :] < K - k * BK, other=0.0)\n b_b = tl.load(p_b, mask=o_k[:, None] < K - k * BK, other=0.0)\n b_b += tl.dot(b_b, b_bone, allow_tf32=False).to(b_b.dtype) + b_bone.to(b_b.dtype)\n b_acc += tl.dot(b_a, b_b, allow_tf32=False)\n p_a += BK * s_ak\n p_b += BK * s_bk\n\n o_cm = i_m * BM + tl.arange(0, BM)\n o_cn = i_n * BN + tl.arange(0, BN)\n mask = (o_cn[None, :] < N)\n p_c = c + s_cm * o_cm[:, None] + s_cn * o_cn[None, :]\n tl.store(p_c, b_acc.to(c.dtype.element_ty), mask=mask)\n\ndef bone_fwd(\n a: torch.Tensor, b: torch.Tensor, bone: torch.Tensor\n) -> torch.Tensor:\n B, L, K = a.shape\n M = B * L\n K, N = b.shape\n c = a.new_empty(B, L, N)\n BK = BN = 64\n\n def grid(meta): return (triton.cdiv(M, meta['BM']), triton.cdiv(N, BN))\n bone_fwd_kernel[grid](\n a, b, c, bone,\n M, N, K,\n a.stride(1), a.stride(2),\n b.stride(0), b.stride(1),\n c.stride(1), c.stride(2),\n bone.stride(0), bone.stride(1), bone.stride(2),\n BK=BK, BN=BN, G=4, ACTIVATION=None,\n )\n return c\n\n@triton.jit\ndef bone_gradx_kernel(\n a, b, c, bone,\n M, N, K,\n s_am, s_ak, s_bk, s_bn, s_cm, s_cn, s_bonep, s_bonem, s_bonen,\n BM: tl.constexpr, BK: tl.constexpr, BN: tl.constexpr, G: tl.constexpr, ACTIVATION: tl.constexpr,\n):\n \"\"\"Kernel for computing the matmul C = A x B.\n A has shape (M, K), B has shape (K, N) and C has shape (M, N)\n \"\"\"\n NM, NN = tl.num_programs(0), tl.num_programs(1)\n i_m, i_n = tl.program_id(0), tl.program_id(1)\n i_m, i_n = tl.swizzle2d(i_m, i_n, NM, NN, G)\n\n o_am = (i_m * BM + tl.arange(0, BM))\n o_bn = (i_n * BN + tl.arange(0, BN)) % N\n o_k = tl.arange(0, BK)\n\n p_a = a + (o_am[:, None] * s_am + o_k[None, :] * s_ak)\n p_b = b + (o_k[:, None] * s_bk + o_bn[None, :] * s_bn)\n\n p_bone = bone + o_k[:, None] * s_bonem + o_k[None, :] * s_bonen\n\n b_acc = tl.zeros((BM, BN), dtype=tl.float32)\n for k in range(0, tl.cdiv(K, BK)):\n b_bone = tl.load(p_bone)\n b_a = tl.load(p_a, mask=o_k[None, :] < K - k * BK, other=0.0)\n b_b = tl.load(p_b, mask=o_k[:, None] < K - k * BK, other=0.0)\n b_bone = tl.dot(b_bone, b_b, allow_tf32=False).to(b_b.dtype) + b_bone\n\n b_b = b_b + b_bone\n b_acc += tl.dot(b_a, b_b, allow_tf32=False)\n p_a += BK * s_ak\n p_b += BK * s_bk\n p_bone += s_bonep\n\n o_cm = i_m * BM + tl.arange(0, BM)\n o_cn = i_n * BN + tl.arange(0, BN)\n p_c = c + s_cm * o_cm[:, None] + s_cn * o_cn[None, :]\n\n tl.store(p_c, b_acc.to(c.dtype.element_ty))\n\ndef bone_gradx(\n do: torch.Tensor, b: torch.Tensor, bone: torch.Tensor\n) -> torch.Tensor:\n B, L, K = do.shape\n M = B * L\n K, N = b.shape\n _, block, _ = bone.shape\n c = do.new_empty(B, L, N)\n BK = BN = block\n\n def grid(meta): return (triton.cdiv(M, meta['BM']), triton.cdiv(N, BN))\n bone_gradx_kernel[grid](\n do, b, c, bone,\n M, N, K,\n do.stride(1), do.stride(2),\n b.stride(0), b.stride(1),\n c.stride(1), c.stride(2),\n bone.stride(0), bone.stride(1), bone.stride(2),\n BK=BK, BN=BN, G=4, ACTIVATION=None,\n )\n return c\n", - "description_1": "Use triton language to define two kernels: bone_fwd_kernel and bone_gradx_kernel. bone_fwd_kernel has 22 parameters, which include pointers to matrices a, b, c, bone, matrix dimensions M, N, K, strides s_am, s_ak, s_bk, s_bn, s_cm, s_cn, s_bonep, s_bonem, s_bonen, and meta-parameters BM, BK, BN, G, ACTIVATION. This kernel performs a matrix multiplication C = A x B with additional tensor operations. bone_fwd function, which is a wrapper around bone_fwd_kernel, has three input tensors and computes the product using triton's grid launch. bone_gradx_kernel has a similar parameter structure and functionality as bone_fwd_kernel, and it's used to compute the gradients with respect to input tensors in the matrix multiplication operation. The bone_gradx function encapsulates the call to bone_gradx_kernel.", - "description_2": "Use triton language to create and execute matrix multiplication kernels with additional tensor operations for forward and gradient computation.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef bone_gradx(\n a, b, c, bone,\n M, N, K,\n s_am, s_ak, s_bk, s_bn, s_cm, s_cn, s_bonep, s_bonem, s_bonen,\n BM: tl.constexpr, BK: tl.constexpr, BN: tl.constexpr, G: tl.constexpr, ACTIVATION: tl.constexpr\n):\n \"\"\"Kernel for computing the matmul C = A x B.\n A has shape (M, K), B has shape (K, N) and C has shape (M, N)\n \"\"\"\n NM, NN = tl.num_programs(0), tl.num_programs(1)\n i_m, i_n = tl.program_id(0), tl.program_id(1)\n i_m, i_n = tl.swizzle2d(i_m, i_n, NM, NN, G)\n\n o_am = (i_m * BM + tl.arange(0, BM)) % M\n o_bn = (i_n * BN + tl.arange(0, BN)) % N\n o_k = tl.arange(0, BK)\n\n p_a = a + (o_am[:, None] * s_am + o_k[None, :] * s_ak)\n p_b = b + (o_k[:, None] * s_bk + o_bn[None, :] * s_bn)\n p_bone = bone + o_k[:, None] * s_bonem + o_k[None, :] * s_bonen\n\n b_acc = tl.zeros((BM, BN), dtype=tl.float32)\n for k in range(0, tl.cdiv(K, BK)):\n b_bone = tl.load(p_bone)\n b_a = tl.load(p_a, mask=o_k[None, :] < K - k * BK, other=0.0)\n b_b = tl.load(p_b, mask=o_k[:, None] < K - k * BK, other=0.0)\n b_bone = tl.dot(b_bone, b_b, allow_tf32=False).to(b_b.dtype) + b_bone\n\n b_b = b_b + b_bone\n b_acc += tl.dot(b_a, b_b, allow_tf32=False)\n p_a += BK * s_ak\n p_b += BK * s_bk\n p_bone += s_bonep\n\n o_cm = i_m * BM + tl.arange(0, BM)\n o_cn = i_n * BN + tl.arange(0, BN)\n\n b_c = b_acc\n\n p_c = c + s_cm * o_cm[:, None] + s_cn * o_cn[None, :]\n\n tl.store(p_c, b_c.to(c.dtype.element_ty))\n\n@triton.jit\ndef bone_gradw(\n a, b, c, w, dw,\n M, N, K,\n s_am, s_ak, s_bk, s_bn, s_cp, s_cm, s_cn, s_wm, s_wn, s_dwk, s_dwn,\n BM: tl.constexpr, BK: tl.constexpr, BN: tl.constexpr, G: tl.constexpr, ACTIVATION: tl.constexpr\n):\n \"\"\"Kernel for computing the matmul C = A x B.\n A has shape (M, K), B has shape (K, N) and C has shape (M, N)\n \"\"\"\n NM, NN = tl.num_programs(0), tl.num_programs(1)\n i_m, i_n = tl.program_id(0), tl.program_id(1)\n\n o_am = (i_m * BM + tl.arange(0, BM)) % M\n o_bn = (i_n * BN + tl.arange(0, BN)) % N\n o_k = tl.arange(0, BK)\n\n p_a = a + (o_am[:, None] * s_am + o_k[None, :] * s_ak)\n p_b = b + (o_k[:, None] * s_bk + o_bn[None, :] * s_bn)\n\n b_dw = tl.zeros((BM, BN), dtype=tl.float32)\n for k in range(0, tl.cdiv(K, BK)):\n b_a = tl.load(p_a, mask=o_k[None, :] < K - k * BK, other=0.0)\n b_b = tl.load(p_b, mask=o_k[:, None] < K - k * BK, other=0.0)\n b_dw += tl.dot(b_a, b_b, allow_tf32=False)\n p_a += BK * s_ak\n p_b += BK * s_bk\n\n o_cm = i_m * BM + tl.arange(0, BM)\n o_cn = i_n * BN + tl.arange(0, BN)\n\n p_dw = dw + s_dwk * o_cm[:, None] + s_dwn * o_cn[None, :]\n b_c = b_dw\n\n tl.store(p_dw, b_dw.to(c.dtype.element_ty))\n\n@triton.jit\ndef bone_gradwb(\n a, b, c, w,\n M, N, K,\n s_am, s_ak, s_bk, s_bn, s_cp, s_cm, s_cn, s_wm, s_wn,\n BM: tl.constexpr, BK: tl.constexpr, BN: tl.constexpr, G: tl.constexpr, ACTIVATION: tl.constexpr\n):\n i_n = tl.program_id(0)\n\n o_bn = (i_n * BN + tl.arange(0, BN)) % N\n o_k = tl.arange(0, BK)\n o_m = tl.arange(0, BM)\n o_block = tl.arange(0, 64)\n\n p_a = a + (o_m[:, None] * s_am + o_k[None, :] * s_ak)\n p_b = b + (o_k[:, None] * s_bk + o_bn[None, :] * s_bn)\n\n p_w = w + s_wm * o_m[:, None] + s_wn * o_bn[None, :]\n\n dc = tl.zeros((64, 64), dtype=tl.float32)\n for m in range(0, tl.cdiv(M, BM)):\n b_dw = tl.zeros((BM, BN), dtype=tl.float32)\n for k in range(0, tl.cdiv(K, BK)):\n b_a = tl.load(p_a, mask=(o_k[None, :] < K - k * BK) & (o_m[:, None] < M - m * BM), other=0.0)\n b_b = tl.load(p_b, mask=o_k[:, None] < K - k * BK, other=0.0)\n\n b_dw += tl.dot(b_a, b_b, allow_tf32=False)\n p_a += BK * s_ak\n p_b += BK * s_bk\n\n b_w = tl.load(p_w)\n p_a += BM * s_am\n p_w += BM * s_wm\n p_a -= K * s_ak\n p_b -= K * s_bk\n\n dc += b_dw\n\n p_c = c + o_block[:, None] * s_cm + o_block[None, :] * s_cn + i_n * s_cp\n\n tl.store(p_c, dc.to(c.dtype.element_ty))\n\ndef bone_bwd(a: torch.Tensor, b: torch.Tensor, bone: torch.Tensor) -> torch.Tensor:\n M, K = a.shape\n K, N = b.shape\n _, block, _ = bone.shape\n c = a.new_empty(M, N)\n BK = BN = block\n\n def grid(meta): return (triton.cdiv(M, meta['BM']), triton.cdiv(N, BN))\n bone_gradx[grid](\n a, b, c, bone,\n M, N, K,\n a.stride(0), a.stride(1),\n b.stride(0), b.stride(1),\n c.stride(0), c.stride(1),\n bone.stride(0), bone.stride(1), bone.stride(2),\n BK=BK, BN=BN, G=4,\n ACTIVATION=None,\n )\n return c\n\ndef bone_bwd_wb(a: torch.Tensor, b: torch.Tensor, w: torch.Tensor) -> torch.Tensor:\n M, K = a.shape\n K, N = b.shape\n c = a.new_empty(8, 64, 64)\n BM = 64\n BK = 64\n BN = 64\n\n grid = (triton.cdiv(N, BN),)\n bone_gradwb[grid](\n a, b, c, w,\n M, N, K,\n a.stride(0), a.stride(1),\n b.stride(0), b.stride(1),\n c.stride(0), c.stride(1), c.stride(2),\n w.stride(0), w.stride(1),\n BM=BM, BK=BK, BN=BN, G=4,\n num_stages=1,\n ACTIVATION=None,\n )\n return c\n\ndef bone_bwd_w(a: torch.Tensor, b: torch.Tensor, w: torch.Tensor) -> torch.Tensor:\n M, K = a.shape\n K, N = b.shape\n dw = a.new_empty(M, N)\n BM = 64\n BK = BN = 64\n\n grid = (triton.cdiv(M, BM), triton.cdiv(N, BN))\n bone_gradw[grid](\n a, b, c, w, dw,\n M, N, K,\n a.stride(0), a.stride(1),\n b.stride(0), b.stride(1),\n c.stride(0), c.stride(1), c.stride(2),\n w.stride(0), w.stride(1),\n dw.stride(0), dw.stride(1),\n BM=BM, BK=BK, BN=BN, G=4,\n ACTIVATION=None,\n )\n return dw\n", - "description_1": "Use triton language to create three kernels: 'bone_gradx', 'bone_gradw', and 'bone_gradwb'. Each kernel performs block matrix operations for matrix multiplication and updates. 'bone_gradx' computes matrix C = A x B with 20 parameters: four matrix pointers, three dimension values, ten stride values, and three constexprs for block and group sizes. 'bone_gradw' computes gradient updates for matrices with 23 parameters: five matrix pointers, three dimension values, ten stride values, and three constexprs. 'bone_gradwb' computes weighted backpropagation updates with 19 parameters: four matrix pointers, three dimension values, seven stride values, and three constexprs.", - "description_2": "Use triton language to implement three block matrix operation kernels. Define parameters for matrix pointers, dimensions, strides, and compile-time constants. Perform computation for matrix multiplication and gradients.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\nfrom einops import rearrange\n\n@triton.autotune(\n configs=[\n triton.Config({'BM': 64}, num_stages=3, num_warps=2),\n triton.Config({'BM': 64}, num_stages=3, num_warps=4),\n triton.Config({'BM': 64}, num_stages=3, num_warps=8),\n triton.Config({'BM': 128}, num_stages=3, num_warps=2),\n triton.Config({'BM': 128}, num_stages=3, num_warps=4),\n triton.Config({'BM': 128}, num_stages=3, num_warps=8),\n triton.Config({'BM': 64}, num_stages=2, num_warps=2),\n triton.Config({'BM': 64}, num_stages=2, num_warps=4),\n triton.Config({'BM': 64}, num_stages=2, num_warps=8),\n triton.Config({'BM': 128}, num_stages=2, num_warps=2),\n triton.Config({'BM': 128}, num_stages=2, num_warps=4),\n triton.Config({'BM': 128}, num_stages=2, num_warps=8),\n triton.Config({'BM': 64}, num_stages=4, num_warps=2),\n triton.Config({'BM': 64}, num_stages=4, num_warps=4),\n triton.Config({'BM': 64}, num_stages=4, num_warps=8),\n triton.Config({'BM': 128}, num_stages=4, num_warps=2),\n triton.Config({'BM': 128}, num_stages=4, num_warps=4),\n triton.Config({'BM': 128}, num_stages=4, num_warps=8),\n ],\n key=['M', 'N', 'K'],\n)\n@triton.jit\ndef matmul_kernel(\n a, b, c, bone,\n M, N, K,\n s_am, s_ak, s_bk, s_bn, s_cm, s_cn,\n s_bonep, s_bonem, s_bonen,\n BK: tl.constexpr, BN: tl.constexpr,\n ACTIVATION: tl.constexpr,\n):\n \"\"\"Kernel for computing the matmul C = A x B.\n A has shape (M, K), B has shape (K, N) and C has shape (M, N)\n \"\"\"\n NM, NN = tl.num_programs(0), tl.num_programs(1)\n i_m, i_n = tl.program_id(0), tl.program_id(1)\n\n o_am = (i_m * BM + tl.arange(0, BM))\n o_bn = (i_n * BN + tl.arange(0, BN)) % N\n o_k = tl.arange(0, BK)\n\n p_a = a + (o_am[:, None] * s_am + o_k[None, :] * s_ak)\n p_b = b + (o_k[:, None] * s_bk + o_bn[None, :] * s_bn)\n p_bone = bone + i_n * s_bonep + o_k[:, None] * s_bonem + o_k[None, :] * s_bonen\n b_bone = tl.load(p_bone)\n\n b_acc = tl.zeros((BM, BN), dtype=tl.float32)\n for k in range(0, tl.cdiv(K, BK)):\n b_a = tl.load(p_a, mask=o_k[None, :] < K - k * BK, other=0.0)\n b_b = tl.load(p_b, mask=o_k[:, None] < K - k * BK, other=0.0)\n b_b += tl.dot(b_b, b_bone, allow_tf32=False, acc=b_bone.to(tl.float32)).to(tl.bfloat16)\n\n b_acc += tl.dot(b_a, b_b, allow_tf32=False)\n p_a += BK * s_ak\n p_b += BK * s_bk\n\n o_cm = i_m * BM + tl.arange(0, BM)\n o_cn = i_n * BN + tl.arange(0, BN)\n mask = (o_cn[None, :] < N)\n\n p_c = c + s_cm * o_cm[:, None] + s_cn * o_cn[None, :]\n\n tl.store(p_c, b_acc.to(c.dtype.element_ty), mask=mask)\n\ndef bone_fwd(\n bone: torch.Tensor,\n a: torch.Tensor,\n b: torch.Tensor,\n) -> torch.Tensor:\n B, L, K = a.shape\n M = B * L\n K, N = b.shape\n c = a.new_empty(B, L, N)\n BK = 64\n BN = 64\n\n def grid(meta): return (triton.cdiv(M, meta['BM']), triton.cdiv(N, BN))\n matmul_kernel[grid](\n a, b, c, bone,\n M, N, K,\n a.stride(1), a.stride(2),\n b.stride(0), b.stride(1),\n c.stride(1), c.stride(2),\n bone.stride(0), bone.stride(1), bone.stride(2),\n BK=BK, BN=BN,\n ACTIVATION=None,\n )\n return c\n\n# Example usage\ndtype = torch.bfloat16\nB = 4\nL = 1024\na = torch.randn((B, L, 2048), device='cuda', dtype=dtype)\nb = torch.randn((2048, 4096), device='cuda', dtype=dtype)\nc = torch.randn((64, 64, 64), device='cuda', dtype=dtype)\n\nxx = bone_fwd(c, a, b)\nprint(xx.reshape(-1))\n", - "description_1": "Use triton language to implement a matrix multiplication kernel (matmul_kernel) that performs C = A * B with additional processing involving a bone matrix. The kernel supports a range of block sizes and can operate on blocks of BM x BK and BK x BN sizes. The input tensors include pointers to matrices A, B, C, and an auxiliary matrix bone. Strides for the matrices are provided to support non-contiguous memory layouts. The function bone_fwd is used to invoke the kernel by calculating grid sizes based on input dimensions and meta-parameters.", - "description_2": "Use triton language to create a matmul kernel that handles custom tensor dimensions and grid configurations. Implement a function to facilitate kernel execution with specific input tensor shapes and memory strides.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef bone_gradwb(\n a, b, c, w, # Pointers to matrices\n BL, M, N, K, # Matrix dimensions\n s_ab, s_am, s_ak, s_bb, s_bk, s_bn, s_cp, s_cm, s_cn, s_wm, s_wn, # Strides\n BM: tl.constexpr, BK: tl.constexpr, BN: tl.constexpr, G: tl.constexpr, ACTIVATION: tl.constexpr,\n):\n i_n = tl.program_id(0)\n offs_B = i_n // BL\n o_bn = (i_n * BN + tl.arange(0, BN)) % N\n o_k = tl.arange(0, BK)\n o_m = tl.arange(0, BM)\n o_block = tl.arange(0, 64)\n o_wn = o_bn % N\n\n p_a = a + (o_m[:, None] * s_am + o_k[None, :] * s_ak + offs_B * s_ab)\n p_b = b + (o_k[:, None] * s_bk + o_bn[None, :] * s_bn + offs_B * s_bb)\n p_w = w + s_wm * o_block[:, None] + s_wn * o_wn[None, :]\n\n dc = tl.zeros((64, 64), dtype=tl.float32)\n for m in range(0, tl.cdiv(M, BM)):\n b_dw = tl.zeros((BM, BN), dtype=tl.float32)\n for k in range(0, tl.cdiv(K, BK)):\n b_a = tl.load(p_a, mask=(o_k[None, :] < K - k * BK), other=0.0)\n b_b = tl.load(p_b, mask=o_k[:, None] < K - k * BK, other=0.0)\n b_dw += tl.dot(b_a, b_b, allow_tf32=False)\n # Advance the ptrs to the next K block.\n p_a += BK * s_ak\n p_b += BK * s_bk\n\n b_w = tl.load(p_w)\n p_a += BM * s_am\n p_w += BM * s_wm\n p_a -= K * s_ak\n p_b -= K * s_bk\n dc += tl.dot(b_w.T, b_dw.to(b_w.dtype), allow_tf32=False).to(b_w.dtype) + b_dw\n\n p_c = c + o_block[:, None] * s_cm + o_block[None, :] * s_cn + i_n * s_cp\n tl.store(p_c, dc.to(c.dtype.element_ty))\n\ndef bone_bwd_wb(\n x: torch.Tensor,\n do: torch.Tensor,\n w: torch.Tensor,\n bone_g: int,\n bone_b: int,\n) -> torch.Tensor:\n B, M, K = x.shape\n _, K, O = do.shape\n N = B * O\n \n c = torch.zeros((B, bone_g, bone_b, bone_b), dtype=x.dtype, device=x.device)\n BM = BN = bone_b\n BL = triton.cdiv(O, BN)\n\n grid = (triton.cdiv(N, BN),)\n bone_gradwb[grid](\n x, do, c, w, \n BL, M, O, K,\n x.stride(0), x.stride(1), x.stride(2),\n do.stride(0), do.stride(1), do.stride(2),\n c.stride(1), c.stride(2), c.stride(3),\n w.stride(0), w.stride(1),\n BM=BM, BN=BN, G=4,\n ACTIVATION=None,\n )\n return c\n", - "description_1": "Use triton language to implement a matrix multiplication and gradient accumulation kernel named 'bone_gradwb'. The kernel takes as input pointers to matrices a, b, c, and w, matrix dimensions BL, M, N, K, and the strides for each of these matrices. It also takes several compile-time constants including BM, BK, BN, G, and ACTIVATION. The kernel performs matrix operations including loading submatrices, dot products, and accumulation of gradients into the output matrix c.", - "description_2": "Implement a Triton kernel to perform matrix multiplication and gradient accumulation with support for specific matrix dimensions and strides.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef bone_gradx(\n a, b, c, bone,\n M, N, K,\n s_am, s_ak, s_bk, s_bn, s_cm, s_cn, s_bonep, s_bonem, s_bonen,\n BM: tl.constexpr, BK: tl.constexpr, BN: tl.constexpr, G: tl.constexpr, ACTIVATION: tl.constexpr,\n):\n \"\"\"Kernel for computing the matmul C = A x B.\n A has shape (M, K), B has shape (K, N) and C has shape (M, N)\n \"\"\"\n NM, NN = tl.num_programs(0), tl.num_programs(1)\n i_m, i_n = tl.program_id(0), tl.program_id(1)\n i_m, i_n = tl.swizzle2d(i_m, i_n, NM, NN, G)\n\n o_am = (i_m * BM + tl.arange(0, BM))\n o_bn = (i_n * BN + tl.arange(0, BN)) % N\n o_k = tl.arange(0, BK)\n\n p_a = a + (o_am[:, None] * s_am + o_k[None, :] * s_ak)\n p_b = b + (o_k[:, None] * s_bk + o_bn[None, :] * s_bn)\n p_bone = bone + o_k[:, None] * s_bonem + o_k[None, :] * s_bonen\n\n b_acc = tl.zeros((BM, BN), dtype=tl.float32)\n for k in range(0, tl.cdiv(K, BK)):\n b_bone = tl.load(p_bone)\n b_a = tl.load(p_a, mask=o_k[None, :] < K - k * BK, other=0.0)\n b_b = tl.load(p_b, mask=o_k[:, None] < K - k * BK, other=0.0)\n b_b += tl.dot(b_bone, b_b, allow_tf32=False).to(b_b.dtype) + b_bone\n\n b_acc += tl.dot(b_a, b_b, allow_tf32=False)\n p_a += BK * s_ak\n p_b += BK * s_bk\n p_bone += s_bonep\n\n o_cm = i_m * BM + tl.arange(0, BM)\n o_cn = i_n * BN + tl.arange(0, BN)\n\n p_c = c + s_cm * o_cm[:, None] + s_cn * o_cn[None, :]\n tl.store(p_c, b_acc.to(c.dtype.element_ty))\n\ndef bone_bwd(\n do: torch.Tensor,\n b: torch.Tensor,\n bone: torch.Tensor,\n) -> torch.Tensor:\n B, L, K = do.shape\n M = B * L\n K, N = b.shape\n _, block, _ = bone.shape\n c = do.new_empty(B, L, N)\n BK = BN = block\n\n def grid(meta): return (triton.cdiv(M, meta['BM']), triton.cdiv(N, BN))\n bone_gradx[grid](\n do, b, c, bone,\n M, N, K,\n do.stride(1), do.stride(2),\n b.stride(0), b.stride(1),\n c.stride(1), c.stride(2),\n bone.stride(0), bone.stride(1), bone.stride(2),\n BK=BK, BN=BN, G=4,\n ACTIVATION=None,\n )\n return c\n", - "description_1": "Use triton language to implement a kernel function 'bone_gradx' that computes the matrix multiplication C = A x B with additional operations involving a 'bone' matrix. The kernel takes 20 parameters: 4 pointers to matrices (a, b, c, bone), 9 integers for matrix dimensions and strides (M, N, K, s_am, s_ak, s_bk, s_bn, s_cm, s_cn, s_bonep, s_bonem, s_bonen), and 5 meta-parameters (BM, BK, BN, G, ACTIVATION). The function 'bone_bwd' calls this kernel to perform the backward pass of a custom operation, taking 3 parameters: do (gradient tensor), b (matrix), and bone (matrix), and returns a tensor c.", - "description_2": "Use triton language to create a kernel for matrix multiplication with additional operations, and implement a backward pass function that utilizes this kernel.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef bone_gradx(\n x, do, w, c, bone,\n M, N, K, s_xd, s_sl, s_am, s_ak, s_bk, s_bn, s_cm, s_cn, s_bonep, s_bonem, s_bonen,\n BM: tl.constexpr, BK: tl.constexpr, BN: tl.constexpr, G: tl.constexpr, ACTIVATION: tl.constexpr,\n):\n \"\"\"Kernel for computing the matmul C = A x B.\n A has shape (M, K), B has shape (K, N) and C has shape (M, N)\n \"\"\"\n NM, NN = tl.num_programs(0), tl.num_programs(1)\n i_m, i_n = tl.program_id(0), tl.program_id(1)\n i_m, i_n = tl.swizzle2d(i_m, i_n, NM, NN, G)\n\n o_am = (i_m * BM + tl.arange(0, BM))\n o_bn = (i_n * BN + tl.arange(0, BN)) % N\n o_k = tl.arange(0, BK)\n\n p_do = do + (o_am[:, None] * s_am + o_k[None, :] * s_ak)\n p_w = w + (o_k[:, None] * s_bk + o_bn[None, :] * s_bn)\n p_bone = bone + o_k[:, None] * s_bonem + o_k[None, :] * s_bonen\n\n b_acc = tl.zeros((BM, BN), dtype=tl.float32)\n for k in range(0, tl.cdiv(K, BK)):\n b_bone = tl.load(p_bone)\n b_a = tl.load(p_do, mask=o_k[None, :] < K - k * BK, other=0.0)\n b_b = tl.load(p_w, mask=o_k[:, None] < K - k * BK, other=0.0)\n b_b += tl.dot(b_bone, b_b, allow_tf32=False).to(b_b.dtype) + b_bone\n\n b_acc += tl.dot(b_a, b_b, allow_tf32=False)\n p_do += BK * s_ak\n p_w += BK * s_bk\n p_bone += s_bonep\n\n o_cm = i_m * BM + tl.arange(0, BM)\n o_cn = i_n * BN + tl.arange(0, BN)\n\n p_c = c + s_cm * o_cm[:, None] + s_cn * o_cn[None, :]\n tl.store(p_c, b_acc.to(c.dtype.element_ty))\n\ndef bone_bwd(\n do: torch.Tensor,\n w: torch.Tensor,\n bone: torch.Tensor,\n) -> torch.Tensor:\n B, L, K = do.shape\n M = B * L\n N, K = w.shape\n _, block, _ = bone.shape\n c = do.new_empty(B, L, N)\n BK = BN = block\n\n def grid(meta): return (triton.cdiv(M, meta['BM']), triton.cdiv(N, BN))\n bone_gradx[grid](\n do, w, c, bone,\n M, N, K,\n do.stride(1), do.stride(2),\n w.stride(1), w.stride(0),\n c.stride(1), c.stride(2),\n bone.stride(0), bone.stride(2), bone.stride(1),\n BK=BK, BN=BN, G=4,\n ACTIVATION=None,\n )\n return c\n", - "description_1": "Use triton language to implement a kernel 'bone_gradx' which computes the matrix multiplication C = A x B for matrices with specific strides and sizes. The kernel is decorated with @triton.jit and uses parameters to specify matrix dimensions, strides, and additional metadata for the computation. A wrapper function 'bone_bwd' is also implemented to prepare and invoke the 'bone_gradx' kernel, passing tensors and their respective metadata.", - "description_2": "Use triton language to create a matrix multiplication kernel and a wrapper function to execute the kernel with the appropriate parameters.", - "difficulty": 3 - }, - { - "code": "import math\nimport torch\nimport triton_pre_mlir as triton\nimport triton_pre_mlir.language as tl\n\n@triton.heuristics(\n {\n \"EVEN_M\": lambda args: args[\"seqlen_q\"] % args[\"BLOCK_M\"] == 0,\n \"EVEN_N\": lambda args: args[\"seqlen_k\"] % args[\"BLOCK_N\"] == 0,\n \"EVEN_HEADDIM\": lambda args: args[\"headdim\"] == args[\"BLOCK_HEADDIM\"],\n }\n)\n@triton.jit\ndef _fwd_kernel(\n Q,\n K,\n V,\n Bias,\n Out,\n Lse,\n TMP,\n softmax_scale,\n stride_qb,\n stride_qh,\n stride_qm,\n stride_kb,\n stride_kh,\n stride_kn,\n stride_vb,\n stride_vh,\n stride_vn,\n stride_bb,\n stride_bh,\n stride_bm,\n stride_ob,\n stride_oh,\n stride_om,\n nheads,\n seqlen_q,\n seqlen_k,\n seqlen_q_rounded,\n headdim,\n CACHE_KEY_SEQLEN_Q,\n CACHE_KEY_SEQLEN_K,\n BIAS_TYPE: tl.constexpr,\n IS_CAUSAL: tl.constexpr,\n BLOCK_HEADDIM: tl.constexpr,\n EVEN_M: tl.constexpr,\n EVEN_N: tl.constexpr,\n EVEN_HEADDIM: tl.constexpr,\n BLOCK_M: tl.constexpr,\n BLOCK_N: tl.constexpr,\n):\n # Triton kernel for forward pass of FlashAttention\n\n@triton.jit\ndef _bwd_preprocess_do_o_dot(\n Out,\n DO,\n Delta,\n stride_ob,\n stride_oh,\n stride_om,\n stride_dob,\n stride_doh,\n stride_dom,\n nheads,\n seqlen_q,\n seqlen_q_rounded,\n headdim,\n BLOCK_M: tl.constexpr,\n BLOCK_HEADDIM: tl.constexpr,\n):\n # Triton kernel for backward pass preprocessing\n\n@triton.jit\ndef _bwd_store_dk_dv(\n dk_ptrs,\n dv_ptrs,\n dk,\n dv,\n offs_n,\n offs_d,\n seqlen_k,\n headdim,\n EVEN_M: tl.constexpr,\n EVEN_N: tl.constexpr,\n EVEN_HEADDIM: tl.constexpr,\n):\n # Triton kernel to store gradients for DK and DV\n\n@triton.jit\ndef _bwd_kernel_one_col_block(\n start_n,\n Q,\n K,\n V,\n Bias,\n DO,\n DQ,\n DK,\n DV,\n LSE,\n D,\n softmax_scale,\n stride_qm,\n stride_kn,\n stride_vn,\n stride_bm,\n stride_dom,\n stride_dqm,\n stride_dkn,\n stride_dvn,\n seqlen_q,\n seqlen_k,\n headdim,\n ATOMIC_ADD: tl.constexpr,\n BIAS_TYPE: tl.constexpr,\n IS_CAUSAL: tl.constexpr,\n BLOCK_HEADDIM: tl.constexpr,\n EVEN_M: tl.constexpr,\n EVEN_N: tl.constexpr,\n EVEN_HEADDIM: tl.constexpr,\n BLOCK_M: tl.constexpr,\n BLOCK_N: tl.constexpr,\n):\n # Triton kernel for one column block of backward pass\n\n@triton.autotune(\n configs=[\n triton.Config(\n {\"BLOCK_M\": 128, \"BLOCK_N\": 128, \"SEQUENCE_PARALLEL\": False},\n num_warps=8,\n num_stages=1,\n pre_hook=init_to_zero(\"DQ\"),\n ),\n triton.Config(\n {\"BLOCK_M\": 128, \"BLOCK_N\": 128, \"SEQUENCE_PARALLEL\": True},\n num_warps=8,\n num_stages=1,\n pre_hook=init_to_zero(\"DQ\"),\n ),\n ],\n key=[\n \"CACHE_KEY_SEQLEN_Q\",\n \"CACHE_KEY_SEQLEN_K\",\n \"BIAS_TYPE\",\n \"IS_CAUSAL\",\n \"BLOCK_HEADDIM\",\n ],\n)\n@triton.heuristics(\n {\n \"EVEN_M\": lambda args: args[\"seqlen_q\"] % args[\"BLOCK_M\"] == 0,\n \"EVEN_N\": lambda args: args[\"seqlen_k\"] % args[\"BLOCK_N\"] == 0,\n \"EVEN_HEADDIM\": lambda args: args[\"headdim\"] == args[\"BLOCK_HEADDIM\"],\n }\n)\n@triton.jit\ndef _bwd_kernel(\n Q,\n K,\n V,\n Bias,\n DO,\n DQ,\n DK,\n DV,\n LSE,\n D,\n softmax_scale,\n stride_qb,\n stride_qh,\n stride_qm,\n stride_kb,\n stride_kh,\n stride_kn,\n stride_vb,\n stride_vh,\n stride_vn,\n stride_bb,\n stride_bh,\n stride_bm,\n stride_dob,\n stride_doh,\n stride_dom,\n stride_dqb,\n stride_dqh,\n stride_dqm,\n stride_dkb,\n stride_dkh,\n stride_dkn,\n stride_dvb,\n stride_dvh,\n stride_dvn,\n nheads,\n seqlen_q,\n seqlen_k,\n seqlen_q_rounded,\n headdim,\n CACHE_KEY_SEQLEN_Q,\n CACHE_KEY_SEQLEN_K,\n BIAS_TYPE: tl.constexpr,\n IS_CAUSAL: tl.constexpr,\n BLOCK_HEADDIM: tl.constexpr,\n SEQUENCE_PARALLEL: tl.constexpr,\n EVEN_M: tl.constexpr,\n EVEN_N: tl.constexpr,\n EVEN_HEADDIM: tl.constexpr,\n BLOCK_M: tl.constexpr,\n BLOCK_N: tl.constexpr,\n):\n # Triton kernel for backward pass of FlashAttention\n\ndef _flash_attn_forward(q, k, v, bias=None, causal=False, softmax_scale=None):\n (batch, seqlen_q, nheads, d) = q.shape\n (_, seqlen_k, _, _) = k.shape\n assert k.shape == (batch, seqlen_k, nheads, d)\n assert v.shape == (batch, seqlen_k, nheads, d)\n assert d <= 128, \"FlashAttention only support head dimensions up to 128\"\n assert q.dtype == k.dtype == v.dtype, \"All tensors must have the same type\"\n assert q.dtype in [torch.float16, torch.bfloat16], \"Only support fp16 and bf16\"\n assert q.is_cuda and k.is_cuda and v.is_cuda\n softmax_scale = softmax_scale or 1.0 / math.sqrt(d)\n has_bias = bias is not None\n bias_type = \"none\"\n if has_bias:\n assert bias.dtype in [q.dtype, torch.float]\n assert bias.is_cuda\n assert bias.dim() == 4\n if bias.stride(-1) != 1:\n bias = bias.contiguous()\n if bias.shape[2:] == (1, seqlen_k):\n bias_type = \"vector\"\n elif bias.shape[2:] == (seqlen_q, seqlen_k):\n bias_type = \"matrix\"\n else:\n raise RuntimeError(\n \"Last 2 dimensions of bias must be (1, seqlen_k) or (seqlen_q, seqlen_k)\"\n )\n bias = bias.expand(batch, nheads, seqlen_q, seqlen_k)\n bias_strides = (\n (bias.stride(0), bias.stride(1), bias.stride(2)) if has_bias else (0, 0, 0)\n )\n seqlen_q_rounded = math.ceil(seqlen_q / 128) * 128\n lse = torch.empty(\n (batch, nheads, seqlen_q_rounded), device=q.device, dtype=torch.float32\n )\n tmp = torch.empty(\n (batch, nheads, seqlen_q_rounded), device=q.device, dtype=torch.float32\n )\n o = torch.empty_like(q)\n BLOCK_HEADDIM = max(triton.next_power_of_2(d), 16)\n BLOCK = 128\n num_warps = 4 if d <= 64 else 8\n grid = lambda META: (triton.cdiv(seqlen_q, META[\"BLOCK_M\"]), batch * nheads)\n _fwd_kernel[grid](\n q,\n k,\n v,\n bias,\n o,\n lse,\n tmp,\n softmax_scale,\n q.stride(0),\n q.stride(2),\n q.stride(1),\n k.stride(0),\n k.stride(2),\n k.stride(1),\n v.stride(0),\n v.stride(2),\n v.stride(1),\n *bias_strides,\n o.stride(0),\n o.stride(2),\n o.stride(1),\n nheads,\n seqlen_q,\n seqlen_k,\n seqlen_q_rounded,\n d,\n seqlen_q // 32,\n seqlen_k // 32,\n bias_type,\n causal,\n BLOCK_HEADDIM,\n BLOCK_M=BLOCK,\n BLOCK_N=BLOCK,\n num_warps=num_warps,\n num_stages=1\n )\n return (o, lse, softmax_scale)\n\ndef _flash_attn_backward(\n do, q, k, v, o, lse, dq, dk, dv, bias=None, causal=False, softmax_scale=None\n):\n if do.stride(-1) != 1:\n do = do.contiguous()\n (batch, seqlen_q, nheads, d) = q.shape\n (_, seqlen_k, _, _) = k.shape\n assert d <= 128\n seqlen_q_rounded = math.ceil(seqlen_q / 128) * 128\n assert lse.shape == (batch, nheads, seqlen_q_rounded)\n assert q.stride(-1) == k.stride(-1) == v.stride(-1) == o.stride(-1) == 1\n assert dq.stride(-1) == dk.stride(-1) == dv.stride(-1) == 1\n softmax_scale = softmax_scale or 1.0 / math.sqrt(d)\n dq_accum = torch.empty_like(q, dtype=torch.float32)\n delta = torch.empty_like(lse)\n BLOCK_HEADDIM = max(triton.next_power_of_2(d), 16)\n grid = lambda META: (triton.cdiv(seqlen_q, META[\"BLOCK_M\"]), batch * nheads)\n _bwd_preprocess_do_o_dot[grid](\n o,\n do,\n delta,\n o.stride(0),\n o.stride(2),\n o.stride(1),\n do.stride(0),\n do.stride(2),\n do.stride(1),\n nheads,\n seqlen_q,\n seqlen_q_rounded,\n d,\n BLOCK_M=128,\n BLOCK_HEADDIM=BLOCK_HEADDIM,\n )\n has_bias = bias is not None\n bias_type = \"none\"\n if has_bias:\n assert bias.dtype in [q.dtype, torch.float]\n assert bias.is_cuda\n assert bias.dim() == 4\n assert bias.stride(-1) == 1\n if bias.shape[2:] == (1, seqlen_k):\n bias_type = \"vector\"\n elif bias.shape[2:] == (seqlen_q, seqlen_k):\n bias_type = \"matrix\"\n else:\n raise RuntimeError(\n \"Last 2 dimensions of bias must be (1, seqlen_k) or (seqlen_q, seqlen_k)\"\n )\n bias = bias.expand(batch, nheads, seqlen_q, seqlen_k)\n bias_strides = (\n (bias.stride(0), bias.stride(1), bias.stride(2)) if has_bias else (0, 0, 0)\n )\n grid = lambda META: (\n triton.cdiv(seqlen_k, META[\"BLOCK_N\"]) if META[\"SEQUENCE_PARALLEL\"] else 1,\n batch * nheads,\n )\n _bwd_kernel[grid](\n q,\n k,\n v,\n bias,\n do,\n dq_accum,\n dk,\n dv,\n lse,\n delta,\n softmax_scale,\n q.stride(0),\n q.stride(2),\n q.stride(1),\n k.stride(0),\n k.stride(2),\n k.stride(1),\n v.stride(0),\n v.stride(2),\n v.stride(1),\n *bias_strides,\n do.stride(0),\n do.stride(2),\n do.stride(1),\n dq_accum.stride(0),\n dq_accum.stride(2),\n dq_accum.stride(1),\n dk.stride(0),\n dk.stride(2),\n dk.stride(1),\n dv.stride(0),\n dv.stride(2),\n dv.stride(1),\n nheads,\n seqlen_q,\n seqlen_k,\n seqlen_q_rounded,\n d,\n seqlen_q // 32,\n seqlen_k // 32,\n bias_type,\n causal,\n BLOCK_HEADDIM\n )\n dq.copy_(dq_accum)\n", - "description_1": "Use triton language to implement the forward and backward pass of FlashAttention. The forward kernel (_fwd_kernel) takes parameters such as queries (Q), keys (K), values (V), optional bias (Bias), output tensor (Out), log-sum-exp tensor (Lse), temporary storage (TMP), a scaling factor for softmax (softmax_scale), and stride parameters for each of the input tensors. It computes the scaled dot-product attention. The backward kernel (_bwd_kernel) computes gradients for Q, K, V given the gradient of the output (DO), and also uses intermediate outputs from the forward pass (LSE and Out). It takes similar parameters along with additional tensors for gradients (DQ, DK, DV). Both kernels include tuning options for block sizes and handling different input shapes efficiently.", - "description_2": "Use triton language to compute forward and backward passes for FlashAttention with optional bias and support for efficient parallel execution.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n# Triton kernel for cross scan and merge operations\n@triton.jit\ndef triton_cross_scan_flex(\n x: tl.tensor, # (B, C, H, W) | (B, H, W, C) | (B, 4, C, H, W) | (B, H, W, 4, C)\n y: tl.tensor, # (B, 4, C, H, W) | (B, H, W, 4, C)\n x_layout: tl.constexpr,\n y_layout: tl.constexpr,\n operation: tl.constexpr,\n onebyone: tl.constexpr,\n scans: tl.constexpr,\n BC: tl.constexpr,\n BH: tl.constexpr,\n BW: tl.constexpr,\n DC: tl.constexpr,\n DH: tl.constexpr,\n DW: tl.constexpr,\n NH: tl.constexpr,\n NW: tl.constexpr,\n):\n i_hw, i_c, i_b = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n i_h, i_w = (i_hw // NW), (i_hw % NW)\n _mask_h = (i_h * BH + tl.arange(0, BH)) < DH\n _mask_w = (i_w * BW + tl.arange(0, BW)) < DW\n _mask_hw = _mask_h[:, None] & _mask_w[None, :]\n _for_C = min(DC - i_c * BC, BC)\n\n pos_h = (i_h * BH + tl.arange(0, BH)[:, None])\n pos_w = (i_w * BW + tl.arange(0, BW)[None, :])\n neg_h = (DH - i_h * BH - 1 - tl.arange(0, BH)[:, None])\n neg_w = (DW - i_w * BW - 1 - tl.arange(0, BW)[None, :])\n if scans == 0:\n HWRoute0 = pos_h * DW + pos_w\n HWRoute1 = pos_w * DH + pos_h\n HWRoute2 = neg_h * DW + neg_w\n HWRoute3 = neg_w * DH + neg_h\n elif scans == 1:\n HWRoute0 = pos_h * DW + pos_w\n HWRoute1 = HWRoute0\n HWRoute2 = HWRoute0\n HWRoute3 = HWRoute0\n elif scans == 2:\n HWRoute0 = pos_h * DW + pos_w\n HWRoute1 = HWRoute0\n HWRoute2 = neg_h * DW + neg_w\n HWRoute3 = HWRoute2 \n\n _tmp1 = DC * DH * DW\n\n y_ptr_base = y + i_b * 4 * _tmp1 + (i_c * BC * DH * DW if y_layout == 0 else i_c * BC)\n if y_layout == 0:\n p_y1 = y_ptr_base + HWRoute0\n p_y2 = y_ptr_base + _tmp1 + HWRoute1\n p_y3 = y_ptr_base + 2 * _tmp1 + HWRoute2\n p_y4 = y_ptr_base + 3 * _tmp1 + HWRoute3\n else:\n p_y1 = y_ptr_base + HWRoute0 * 4 * DC\n p_y2 = y_ptr_base + DC + HWRoute1 * 4 * DC\n p_y3 = y_ptr_base + 2 * DC + HWRoute2 * 4 * DC\n p_y4 = y_ptr_base + 3 * DC + HWRoute3 * 4 * DC \n \n if onebyone == 0:\n x_ptr_base = x + i_b * _tmp1 + (i_c * BC * DH * DW if x_layout == 0 else i_c * BC)\n if x_layout == 0:\n p_x = x_ptr_base + HWRoute0\n else:\n p_x = x_ptr_base + HWRoute0 * DC\n\n if operation == 0:\n for idxc in range(_for_C):\n _idx_x = idxc * DH * DW if x_layout == 0 else idxc\n _idx_y = idxc * DH * DW if y_layout == 0 else idxc\n _x = tl.load(p_x + _idx_x, mask=_mask_hw)\n tl.store(p_y1 + _idx_y, _x, mask=_mask_hw)\n tl.store(p_y2 + _idx_y, _x, mask=_mask_hw)\n tl.store(p_y3 + _idx_y, _x, mask=_mask_hw)\n tl.store(p_y4 + _idx_y, _x, mask=_mask_hw)\n elif operation == 1:\n for idxc in range(_for_C):\n _idx_x = idxc * DH * DW if x_layout == 0 else idxc\n _idx_y = idxc * DH * DW if y_layout == 0 else idxc\n _y1 = tl.load(p_y1 + _idx_y, mask=_mask_hw)\n _y2 = tl.load(p_y2 + _idx_y, mask=_mask_hw)\n _y3 = tl.load(p_y3 + _idx_y, mask=_mask_hw)\n _y4 = tl.load(p_y4 + _idx_y, mask=_mask_hw)\n tl.store(p_x + _idx_x, _y1 + _y2 + _y3 + _y4, mask=_mask_hw)\n\n else:\n x_ptr_base = x + i_b * 4 * _tmp1 + (i_c * BC * DH * DW if x_layout == 0 else i_c * BC)\n if x_layout == 0:\n p_x1 = x_ptr_base + HWRoute0\n p_x2 = p_x1 + _tmp1\n p_x3 = p_x2 + _tmp1\n p_x4 = p_x3 + _tmp1 \n else:\n p_x1 = x_ptr_base + HWRoute0 * 4 * DC\n p_x2 = p_x1 + DC\n p_x3 = p_x2 + DC\n p_x4 = p_x3 + DC \n \n if operation == 0:\n for idxc in range(_for_C):\n _idx_x = idxc * DH * DW if x_layout == 0 else idxc\n _idx_y = idxc * DH * DW if y_layout == 0 else idxc\n tl.store(p_y1 + _idx_y, tl.load(p_x1 + _idx_x, mask=_mask_hw), mask=_mask_hw)\n tl.store(p_y2 + _idx_y, tl.load(p_x2 + _idx_x, mask=_mask_hw), mask=_mask_hw)\n tl.store(p_y3 + _idx_y, tl.load(p_x3 + _idx_x, mask=_mask_hw), mask=_mask_hw)\n tl.store(p_y4 + _idx_y, tl.load(p_x4 + _idx_x, mask=_mask_hw), mask=_mask_hw)\n else:\n for idxc in range(_for_C):\n _idx_x = idxc * DH * DW if x_layout == 0 else idxc\n _idx_y = idxc * DH * DW if y_layout == 0 else idxc\n tl.store(p_x1 + _idx_x, tl.load(p_y1 + _idx_y), mask=_mask_hw)\n tl.store(p_x2 + _idx_x, tl.load(p_y2 + _idx_y), mask=_mask_hw)\n tl.store(p_x3 + _idx_x, tl.load(p_y3 + _idx_y), mask=_mask_hw)\n tl.store(p_x4 + _idx_x, tl.load(p_y4 + _idx_y), mask=_mask_hw)\n\n\nclass CrossScanTritonF(torch.autograd.Function):\n @staticmethod\n def forward(ctx, x: torch.Tensor, in_channel_first=True, out_channel_first=True, one_by_one=False, scans=0):\n if one_by_one:\n if in_channel_first:\n B, _, C, H, W = x.shape\n else:\n B, H, W, _, C = x.shape\n else:\n if in_channel_first:\n B, C, H, W = x.shape\n else:\n B, H, W, C = x.shape\n B, C, H, W = int(B), int(C), int(H), int(W)\n BC, BH, BW = 1, 32, 32\n NH, NW, NC = triton.cdiv(H, BH), triton.cdiv(W, BW), triton.cdiv(C, BC)\n \n ctx.in_channel_first = in_channel_first\n ctx.out_channel_first = out_channel_first\n ctx.one_by_one = one_by_one\n ctx.scans = scans\n ctx.shape = (B, C, H, W)\n ctx.triton_shape = (BC, BH, BW, NC, NH, NW)\n\n y = x.new_empty((B, 4, C, H * W)) if out_channel_first else x.new_empty((B, H * W, 4, C))\n triton_cross_scan_flex[(NH * NW, NC, B)](\n x.contiguous(), y, \n (0 if in_channel_first else 1), (0 if out_channel_first else 1), 0, (0 if not one_by_one else 1), scans, \n BC, BH, BW, C, H, W, NH, NW\n )\n return y\n \n @staticmethod\n def backward(ctx, y: torch.Tensor):\n in_channel_first = ctx.in_channel_first\n out_channel_first = ctx.out_channel_first\n one_by_one = ctx.one_by_one\n scans = ctx.scans\n B, C, H, W = ctx.shape\n BC, BH, BW, NC, NH, NW = ctx.triton_shape\n if one_by_one:\n x = y.new_empty((B, 4, C, H, W)) if in_channel_first else y.new_empty((B, H, W, 4, C))\n else:\n x = y.new_empty((B, C, H, W)) if in_channel_first else y.new_empty((B, H, W, C))\n \n triton_cross_scan_flex[(NH * NW, NC, B)](\n x, y.contiguous(), \n (0 if in_channel_first else 1), (0 if out_channel_first else 1), 1, (0 if not one_by_one else 1), scans,\n BC, BH, BW, C, H, W, NH, NW\n )\n return x, None, None, None, None\n\n\nclass CrossMergeTritonF(torch.autograd.Function):\n @staticmethod\n def forward(ctx, y: torch.Tensor, in_channel_first=True, out_channel_first=True, one_by_one=False, scans=0):\n if out_channel_first:\n B, _, C, H, W = y.shape\n else:\n B, H, W, _, C = y.shape\n B, C, H, W = int(B), int(C), int(H), int(W)\n BC, BH, BW = 1, 32, 32\n NH, NW, NC = triton.cdiv(H, BH), triton.cdiv(W, BW), triton.cdiv(C, BC)\n ctx.in_channel_first = in_channel_first\n ctx.out_channel_first = out_channel_first\n ctx.one_by_one = one_by_one\n ctx.scans = scans\n ctx.shape = (B, C, H, W)\n ctx.triton_shape = (BC, BH, BW, NC, NH, NW)\n if one_by_one:\n x = y.new_empty((B, 4, C, H * W)) if in_channel_first else y.new_empty((B, H * W, 4, C))\n else:\n x = y.new_empty((B, C, H * W)) if in_channel_first else y.new_empty((B, H * W, C))\n triton_cross_scan_flex[(NH * NW, NC, B)](\n x, y.contiguous(), \n (0 if in_channel_first else 1), (0 if out_channel_first else 1), 1, (0 if not one_by_one else 1), scans,\n BC, BH, BW, C, H, W, NH, NW\n )\n return x\n \n @staticmethod\n def backward(ctx, x: torch.Tensor):\n in_channel_first = ctx.in_channel_first\n out_channel_first = ctx.out_channel_first\n one_by_one = ctx.one_by_one\n scans = ctx.scans\n B, C, H, W = ctx.shape\n BC, BH, BW, NC, NH, NW = ctx.triton_shape\n y = x.new_empty((B, 4, C, H, W)) if out_channel_first else x.new_empty((B, H, W, 4, C))\n triton_cross_scan_flex[(NH * NW, NC, B)](\n x.contiguous(), y, \n (0 if in_channel_first else 1), (0 if out_channel_first else 1), 0, (0 if not one_by_one else 1), scans,\n BC, BH, BW, C, H, W, NH, NW\n )\n return y, None, None, None, None, None\n\n\ndef cross_scan_fn(x: torch.Tensor, in_channel_first=True, out_channel_first=True, one_by_one=False, scans=0, force_torch=False):\n CSF = CrossScanTritonF if x.is_cuda and (not force_torch) else CrossScanF\n with torch.cuda.device(x.device):\n return CSF.apply(x, in_channel_first, out_channel_first, one_by_one, scans)\n\n\ndef cross_merge_fn(y: torch.Tensor, in_channel_first=True, out_channel_first=True, one_by_one=False, scans=0, force_torch=False):\n CMF = CrossMergeTritonF if y.is_cuda and (not force_torch) else CrossMergeF\n with torch.cuda.device(y.device):\n return CMF.apply(y, in_channel_first, out_channel_first, one_by_one, scans)\n", - "description_1": "Use triton language to implement a flexible cross scan and merge operation on tensors. The kernel function 'triton_cross_scan_flex' takes 14 parameters: two tensors (x and y), four layout and operation specifiers (x_layout, y_layout, operation, onebyone), a scan type specifier (scans), and seven constants (BC, BH, BW, DC, DH, DW, NH, NW) that define the block and grid sizes. The function performs different operations based on the 'operation' parameter: 0 for scan and 1 for merge. The 'scans' parameter determines the type of scan: 0 for cross scan, 1 for unidirectional, and 2 for bidirectional. The 'onebyone' parameter specifies whether the operation is applied one by one. The 'CrossScanTritonF' and 'CrossMergeTritonF' classes wrap this kernel for use in PyTorch's autograd system, providing forward and backward methods for the scan and merge operations, respectively.", - "description_2": "Use triton language to create a kernel for cross scan and merge operations on tensors, with parameters for layout, operation type, scan type, and block/grid sizes. Implement PyTorch autograd functions to wrap this kernel for forward and backward passes.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.autotune(\n configs=[\n triton.Config({'BLOCK_N': 32}),\n triton.Config({'BLOCK_N': 64}),\n triton.Config({'BLOCK_N': 128}),\n triton.Config({'BLOCK_N': 256}),\n triton.Config({'BLOCK_N': 512}),\n triton.Config({'BLOCK_N': 1024}),\n ],\n key=['ncols'],\n)\n@triton.jit\ndef _swiglu_fwd_kernel(\n X,\n Y,\n OUT,\n stride_x_row, # how much to increase the pointer when moving by 1 row\n stride_y_row,\n stride_out_row,\n ncols,\n BLOCK_N: tl.constexpr,\n):\n # Map the program id to the row of X and Y it should compute.\n row = tl.program_id(0)\n start_col = tl.program_id(1) * BLOCK_N\n X += row * stride_x_row\n Y += row * stride_y_row\n OUT += row * stride_out_row\n cols = start_col + tl.arange(0, BLOCK_N)\n x = tl.load(X + cols, mask=cols < ncols, other=0.).to(tl.float32)\n y = tl.load(Y + cols, mask=cols < ncols, other=0.).to(tl.float32)\n out = x * tl.sigmoid(x) * y\n tl.store(OUT + cols, out, mask=cols < ncols)\n\ndef _swiglu_fwd(xy, out=None):\n if xy.stride(-1) != 1:\n xy = xy.contiguous()\n batch_shape = xy.shape[:-1]\n xy = xy.reshape(-1, xy.shape[-1])\n x, y = xy.chunk(2, dim=-1)\n if out is None:\n out = torch.empty_like(x)\n else:\n out = out.reshape(-1, out.shape[-1])\n assert out.shape == x.shape\n assert out.stride(-1) == 1\n M, N = x.shape\n grid = lambda META: (M, triton.cdiv(N, META['BLOCK_N']))\n with torch.cuda.device(x.device.index):\n _swiglu_fwd_kernel[grid](x, y, out, x.stride(0), y.stride(0), out.stride(0), N)\n return out.reshape(*batch_shape, out.shape[-1])\n\n@triton.autotune(\n configs=[\n triton.Config({'BLOCK_N': 32}),\n triton.Config({'BLOCK_N': 64}),\n triton.Config({'BLOCK_N': 128}),\n triton.Config({'BLOCK_N': 256}),\n triton.Config({'BLOCK_N': 512}),\n triton.Config({'BLOCK_N': 1024}),\n ],\n key=['ncols'],\n)\n@triton.heuristics({\"RECOMPUTE_OUTPUT\": lambda args: args[\"OUT\"] is not None})\n@triton.jit\ndef _swiglu_bwd_kernel(\n X,\n Y,\n DOUT,\n OUT,\n DX,\n DY,\n stride_x_row, # how much to increase the pointer when moving by 1 row\n stride_y_row,\n stride_dout_row,\n stride_out_row,\n stride_dx_row,\n stride_dy_row,\n ncols,\n BLOCK_N: tl.constexpr,\n RECOMPUTE_OUTPUT: tl.constexpr,\n):\n # Map the program id to the row of X and Y it should compute.\n row = tl.program_id(0)\n start_col = tl.program_id(1) * BLOCK_N\n X += row * stride_x_row\n Y += row * stride_y_row\n DOUT += row * stride_dout_row\n if RECOMPUTE_OUTPUT:\n OUT += row * stride_out_row\n DX += row * stride_dx_row\n DY += row * stride_dy_row\n cols = start_col + tl.arange(0, BLOCK_N)\n x = tl.load(X + cols, mask=cols < ncols, other=0.).to(tl.float32)\n y = tl.load(Y + cols, mask=cols < ncols, other=0.).to(tl.float32)\n dout = tl.load(DOUT + cols, mask=cols < ncols, other=0.).to(tl.float32)\n x_sigmoid = tl.sigmoid(x)\n dx = x_sigmoid * (1 + x * (1 - x_sigmoid)) * y * dout\n dy = x * x_sigmoid * dout\n tl.store(DX + cols, dx, mask=cols < ncols)\n tl.store(DY + cols, dy, mask=cols < ncols)\n if RECOMPUTE_OUTPUT:\n out = x * x_sigmoid * y\n tl.store(OUT + cols, out, mask=cols < ncols)\n\ndef _swiglu_bwd(xy, dout, dxy=None, recompute_output=False, out=None):\n if xy.stride(-1) != 1:\n xy = xy.contiguous()\n if dout.stride(-1) != 1:\n dout = dout.contiguous()\n batch_shape = xy.shape[:-1]\n xy = xy.reshape(-1, xy.shape[-1])\n x, y = xy.chunk(2, dim=-1)\n dout = dout.reshape(-1, dout.shape[-1])\n assert dout.shape == x.shape\n if dxy is None:\n dxy = torch.empty_like(xy)\n else:\n dxy = dxy.reshape(-1, dxy.shape[-1])\n assert dxy.shape == xy.shape\n dx, dy = dxy.chunk(2, dim=-1)\n assert dx.stride(-1) == 1\n assert dy.stride(-1) == 1\n if recompute_output:\n if out is None:\n out = torch.empty_like(x)\n else:\n out = out.reshape(-1, out.shape[-1])\n assert out.shape == x.shape\n assert out.stride(-1) == 1\n M, N = x.shape\n grid = lambda META: (M, triton.cdiv(N, META['BLOCK_N']))\n with torch.cuda.device(x.device.index):\n _swiglu_bwd_kernel[grid](x, y, dout, out if recompute_output else None, dx, dy,\n x.stride(0), y.stride(0), dout.stride(0),\n out.stride(0) if recompute_output else 0,\n dx.stride(0), dy.stride(0),\n N)\n if not recompute_output:\n return dxy.reshape(*batch_shape, dxy.shape[-1])\n else:\n return dxy.reshape(*batch_shape, dxy.shape[-1]), out.reshape(*batch_shape, out.shape[-1])\n", - "description_1": "Use triton language to implement two kernels: _swiglu_fwd_kernel and _swiglu_bwd_kernel. The _swiglu_fwd_kernel takes 7 parameters: X, Y, OUT, stride_x_row, stride_y_row, stride_out_row, and ncols. It computes the forward pass of the SwiGLU activation function using Triton. The _swiglu_bwd_kernel takes 14 parameters: X, Y, DOUT, OUT, DX, DY, stride_x_row, stride_y_row, stride_dout_row, stride_out_row, stride_dx_row, stride_dy_row, ncols, and RECOMPUTE_OUTPUT. It computes the backward pass of the SwiGLU activation function, optionally recomputing the output if needed.", - "description_2": "Use triton language to create forward and backward kernels for the SwiGLU activation function, handling input and output strides, and optionally recomputing outputs during the backward pass.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.heuristics({\"HAS_BIAS\": lambda args: args[\"B\"] is not None})\n@triton.heuristics({\"HAS_Z\": lambda args: args[\"Z\"] is not None})\n@triton.jit\ndef _layer_norm_fwd_1pass_kernel(\n X, # pointer to the input\n Y, # pointer to the output\n W, # pointer to the weights\n B, # pointer to the biases\n Z, # pointer to the other branch\n Mean, # pointer to the mean\n Rstd, # pointer to the 1/std\n stride_x_row, # how much to increase the pointer when moving by 1 row\n stride_y_row,\n stride_z_row,\n M, # number of rows in X\n N, # number of columns in X\n eps, # epsilon to avoid division by zero\n BLOCK_N: tl.constexpr,\n HAS_BIAS: tl.constexpr,\n HAS_Z: tl.constexpr,\n NORM_BEFORE_GATE: tl.constexpr,\n IS_RMS_NORM: tl.constexpr,\n):\n # Map the program id to the row of X and Y it should compute.\n row = tl.program_id(0)\n group = tl.program_id(1)\n X += row * stride_x_row + group * N\n Y += row * stride_y_row + group * N\n if HAS_Z:\n Z += row * stride_z_row + group * N\n if not IS_RMS_NORM:\n Mean += group * M\n Rstd += group * M\n W += group * N\n if HAS_BIAS:\n B += group * N\n # Compute mean and variance\n cols = tl.arange(0, BLOCK_N)\n x = tl.load(X + cols, mask=cols < N, other=0.).to(tl.float32)\n if HAS_Z and not NORM_BEFORE_GATE:\n z = tl.load(Z + cols, mask=cols < N).to(tl.float32)\n x *= z * tl.sigmoid(z)\n if not IS_RMS_NORM:\n mean = tl.sum(x, axis=0) / N\n tl.store(Mean + row, mean)\n xbar = tl.where(cols < N, x - mean, 0.)\n var = tl.sum(xbar * xbar, axis=0) / N\n else:\n xbar = tl.where(cols < N, x, 0.)\n var = tl.sum(xbar * xbar, axis=0) / N\n rstd = 1 / tl.sqrt(var + eps)\n tl.store(Rstd + row, rstd)\n # Normalize and apply linear transformation\n mask = cols < N\n w = tl.load(W + cols, mask=mask).to(tl.float32)\n if HAS_BIAS:\n b = tl.load(B + cols, mask=mask).to(tl.float32)\n x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd\n y = x_hat * w + b if HAS_BIAS else x_hat * w\n if HAS_Z and NORM_BEFORE_GATE:\n z = tl.load(Z + cols, mask=mask).to(tl.float32)\n y *= z * tl.sigmoid(z)\n # Write output\n tl.store(Y + cols, y, mask=mask)\n\n\ndef _layer_norm_fwd(x, weight, bias, eps, z=None, out=None, group_size=None, norm_before_gate=True, is_rms_norm=False):\n M, N = x.shape\n if group_size is None:\n group_size = N\n assert N % group_size == 0\n ngroups = N // group_size\n assert x.stride(-1) == 1\n if z is not None:\n assert z.stride(-1) == 1\n assert z.shape == (M, N)\n assert weight.shape == (N,)\n assert weight.stride(-1) == 1\n if bias is not None:\n assert bias.stride(-1) == 1\n assert bias.shape == (N,)\n # allocate output\n if out is not None:\n assert out.shape == x.shape\n else:\n out = torch.empty_like(x)\n assert out.stride(-1) == 1\n mean = torch.empty((ngroups * M, ), dtype=torch.float32, device=x.device) if not is_rms_norm else None\n rstd = torch.empty((ngroups * M, ), dtype=torch.float32, device=x.device)\n # Less than 64KB per feature: enqueue fused kernel\n MAX_FUSED_SIZE = 65536 // x.element_size()\n BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(group_size))\n if group_size > BLOCK_N:\n raise RuntimeError(\"This layer norm doesn't support feature dim >= 64KB.\")\n # heuristics for number of warps\n num_warps = min(max(BLOCK_N // 256, 1), 8)\n grid = (M, ngroups)\n with torch.cuda.device(x.device.index):\n _layer_norm_fwd_1pass_kernel[grid](x, out, weight, bias, z, mean, rstd,\n x.stride(0), out.stride(0), z.stride(0) if z is not None else 0,\n M, group_size, eps,\n BLOCK_N=BLOCK_N,\n NORM_BEFORE_GATE=norm_before_gate,\n IS_RMS_NORM=is_rms_norm,\n num_warps=num_warps)\n return out, mean, rstd\n", - "description_1": "Use triton language to implement a layer normalization forward pass kernel with parameters: X (input tensor), Y (output tensor), W (weights), B (biases), Z (optional tensor for gating), Mean (mean of input), Rstd (reciprocal of standard deviation), stride_x_row (stride for input rows), stride_y_row (stride for output rows), stride_z_row (stride for Z tensor), M (number of rows), N (number of columns), eps (epsilon for numerical stability), BLOCK_N (block size for columns), HAS_BIAS (flag for bias), HAS_Z (flag for Z tensor), NORM_BEFORE_GATE (flag for normalization before gating), IS_RMS_NORM (flag for RMS normalization). The kernel computes the mean and variance, normalizes the input, applies a linear transformation, and optionally applies a gating mechanism.", - "description_2": "Use triton language to implement a function that calls the layer normalization forward pass kernel with parameters: x (input tensor), weight (weights), bias (biases), eps (epsilon for numerical stability), z (optional tensor for gating), out (output tensor), group_size (size of groups for normalization), norm_before_gate (flag for normalization before gating), is_rms_norm (flag for RMS normalization). The function prepares the input data, allocates output tensors, and launches the kernel with appropriate grid and block sizes.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef _selective_scan_update_kernel(\n state_ptr, x_ptr, dt_ptr, dt_bias_ptr, A_ptr, B_ptr, C_ptr, D_ptr, z_ptr, out_ptr,\n batch, nheads, dim, dstate, nheads_ngroups_ratio,\n stride_state_batch, stride_state_head, stride_state_dim, stride_state_dstate,\n stride_x_batch, stride_x_head, stride_x_dim,\n stride_dt_batch, stride_dt_head, stride_dt_dim,\n stride_dt_bias_head, stride_dt_bias_dim,\n stride_A_head, stride_A_dim, stride_A_dstate,\n stride_B_batch, stride_B_group, stride_B_dstate,\n stride_C_batch, stride_C_group, stride_C_dstate,\n stride_D_head, stride_D_dim,\n stride_z_batch, stride_z_head, stride_z_dim,\n stride_out_batch, stride_out_head, stride_out_dim,\n DT_SOFTPLUS: tl.constexpr,\n TIE_HDIM: tl.constexpr,\n BLOCK_SIZE_M: tl.constexpr,\n HAS_DT_BIAS: tl.constexpr,\n HAS_D: tl.constexpr,\n HAS_Z: tl.constexpr,\n BLOCK_SIZE_DSTATE: tl.constexpr,\n):\n pid_m = tl.program_id(axis=0)\n pid_b = tl.program_id(axis=1)\n pid_h = tl.program_id(axis=2)\n state_ptr += pid_b * stride_state_batch + pid_h * stride_state_head\n x_ptr += pid_b * stride_x_batch + pid_h * stride_x_head\n dt_ptr += pid_b * stride_dt_batch + pid_h * stride_dt_head\n if HAS_DT_BIAS:\n dt_bias_ptr += pid_h * stride_dt_bias_head\n A_ptr += pid_h * stride_A_head\n B_ptr += pid_b * stride_B_batch + (pid_h // nheads_ngroups_ratio) * stride_B_group\n C_ptr += pid_b * stride_C_batch + (pid_h // nheads_ngroups_ratio) * stride_C_group\n if HAS_Z:\n z_ptr += pid_b * stride_z_batch + pid_h * stride_z_head\n out_ptr += pid_b * stride_out_batch + pid_h * stride_out_head\n\n offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_n = tl.arange(0, BLOCK_SIZE_DSTATE)\n state_ptrs = state_ptr + (offs_m[:, None] * stride_state_dim + offs_n[None, :] * stride_state_dstate)\n x_ptrs = x_ptr + offs_m * stride_x_dim\n dt_ptrs = dt_ptr + offs_m * stride_dt_dim\n if HAS_DT_BIAS:\n dt_bias_ptrs = dt_bias_ptr + offs_m * stride_dt_bias_dim\n if HAS_D:\n D_ptr += pid_h * stride_D_head\n A_ptrs = A_ptr + (offs_m[:, None] * stride_A_dim + offs_n[None, :] * stride_A_dstate)\n B_ptrs = B_ptr + offs_n * stride_B_dstate\n C_ptrs = C_ptr + offs_n * stride_C_dstate\n if HAS_D:\n D_ptrs = D_ptr + offs_m * stride_D_dim\n if HAS_Z:\n z_ptrs = z_ptr + offs_m * stride_z_dim\n out_ptrs = out_ptr + offs_m * stride_out_dim\n\n state = tl.load(state_ptrs, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate), other=0.0)\n x = tl.load(x_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)\n if not TIE_HDIM:\n dt = tl.load(dt_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)\n if HAS_DT_BIAS:\n dt += tl.load(dt_bias_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)\n if DT_SOFTPLUS:\n dt = tl.where(dt <= 20.0, tl.math.log1p(tl.exp(dt)), dt)\n A = tl.load(A_ptrs, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate), other=0.0).to(tl.float32)\n dA = tl.exp(A * dt[:, None])\n else:\n dt = tl.load(dt_ptr).to(tl.float32)\n if HAS_DT_BIAS:\n dt += tl.load(dt_bias_ptr).to(tl.float32)\n if DT_SOFTPLUS:\n dt = tl.where(dt <= 20.0, tl.math.log1p(tl.exp(dt)), dt)\n A = tl.load(A_ptr).to(tl.float32)\n dA = tl.exp(A * dt)\n\n B = tl.load(B_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32)\n C = tl.load(C_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32)\n if HAS_D:\n D = tl.load(D_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)\n if HAS_Z:\n z = tl.load(z_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)\n\n if not TIE_HDIM:\n dB = B[None, :] * dt[:, None]\n else:\n dB = B * dt\n state = state * dA + dB * x[:, None]\n tl.store(state_ptrs, state, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate))\n out = tl.sum(state * C[None, :], axis=1)\n if HAS_D:\n out += x * D\n if HAS_Z:\n out *= z * tl.sigmoid(z)\n tl.store(out_ptrs, out, mask=offs_m < dim)\n\n\ndef selective_state_update(state, x, dt, A, B, C, D=None, z=None, dt_bias=None, dt_softplus=False):\n has_heads = state.dim() > 3\n if state.dim() == 3:\n state = state.unsqueeze(1)\n if x.dim() == 2:\n x = x.unsqueeze(1)\n if dt.dim() == 2:\n dt = dt.unsqueeze(1)\n if A.dim() == 2:\n A = A.unsqueeze(0)\n if B.dim() == 2:\n B = B.unsqueeze(1)\n if C.dim() == 2:\n C = C.unsqueeze(1)\n if D is not None and D.dim() == 1:\n D = D.unsqueeze(0)\n if z is not None and z.dim() == 2:\n z = z.unsqueeze(1)\n if dt_bias is not None and dt_bias.dim() == 1:\n dt_bias = dt_bias.unsqueeze(0)\n batch, nheads, dim, dstate = state.shape\n assert x.shape == (batch, nheads, dim)\n assert dt.shape == x.shape\n assert A.shape == (nheads, dim, dstate)\n ngroups = B.shape[1]\n assert nheads % ngroups == 0, \"nheads must be divisible by ngroups\"\n assert B.shape == (batch, ngroups, dstate)\n assert C.shape == B.shape\n if D is not None:\n assert D.shape == (nheads, dim)\n if z is not None:\n assert z.shape == x.shape\n if dt_bias is not None:\n assert dt_bias.shape == (nheads, dim)\n out = torch.empty_like(x)\n grid = lambda META: (triton.cdiv(dim, META['BLOCK_SIZE_M']), batch, nheads)\n z_strides = ((z.stride(0), z.stride(1), z.stride(2)) if z is not None else (0, 0, 0))\n BLOCK_SIZE_M, num_warps = ((32, 4) if dstate <= 16\n else ((16, 4) if dstate <= 32 else\n ((8, 4) if dstate <= 64 else\n ((4, 4) if dstate <= 128 else\n ((4, 8))))))\n tie_hdim = A.stride(-1) == 0 and A.stride(-2) == 0 and dt.stride(-1) == 0 and dt_bias.stride(-1) == 0\n with torch.cuda.device(x.device.index):\n _selective_scan_update_kernel[grid](\n state, x, dt, dt_bias, A, B, C, D, z, out,\n batch, nheads, dim, dstate, nheads // ngroups,\n state.stride(0), state.stride(1), state.stride(2), state.stride(3),\n x.stride(0), x.stride(1), x.stride(2),\n dt.stride(0), dt.stride(1), dt.stride(2),\n *(dt_bias.stride(0), dt_bias.stride(1)) if dt_bias is not None else 0,\n A.stride(0), A.stride(1), A.stride(2),\n B.stride(0), B.stride(1), B.stride(2),\n C.stride(0), C.stride(1), C.stride(2),\n *(D.stride(0), D.stride(1)) if D is not None else 0,\n z_strides[0], z_strides[1], z_strides[2],\n out.stride(0), out.stride(1), out.stride(2),\n dt_softplus,\n tie_hdim,\n BLOCK_SIZE_M,\n num_warps=num_warps,\n )\n if not has_heads:\n out = out.squeeze(1)\n return out\n", - "description_1": "Use triton language to implement a kernel function '_selective_scan_update_kernel' with 45 parameters for updating state matrices with optional bias and scaling, and a wrapper function 'selective_state_update' with 9 parameters to prepare and invoke the kernel.", - "description_2": "Use triton language to create a kernel for matrix state updates with optional bias and scaling, and a wrapper to manage inputs and call the kernel.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n# Forward kernel function\n@triton.jit\ndef _bmm_chunk_fwd_kernel(\n a_ptr, b_ptr, out_ptr, seq_idx_ptr,\n seqlen, chunk_size, K, ngroups,\n stride_a_batch, stride_a_seqlen, stride_a_head, stride_ak,\n stride_b_batch, stride_b_seqlen, stride_b_head, stride_bk,\n stride_out_batch, stride_out_chunk, stride_out_head, stride_outm, stride_outn,\n stride_seq_idx_batch, stride_seq_idx_seqlen,\n IS_CAUSAL: tl.constexpr,\n dot_dtype: tl.constexpr,\n HAS_SEQ_IDX: tl.constexpr,\n BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,\n):\n # Implementation details omitted for brevity\n pass\n\n# Backward kernel function\n@triton.jit\ndef _bmm_chunk_bwd_kernel(\n a_ptr, dout_ptr, db_ptr, res_ptr,\n seqlen, chunk_size, K, ngroups,\n stride_a_batch, stride_a_seqlen, stride_a_head, stride_ak,\n stride_dout_batch, stride_dout_chunk, stride_dout_head, stride_dout_csize_m, stride_dout_csize_n,\n stride_db_batch, stride_db_seqlen, stride_db_head, stride_db_k,\n stride_res_batch, stride_res_seqlen, stride_res_head, stride_res_k,\n dot_dtype: tl.constexpr,\n HAS_RESIDUAL: tl.constexpr,\n BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_CS: tl.constexpr,\n):\n # Implementation details omitted for brevity\n pass\n\n# Function to call forward kernel\ndef _bmm_chunk_fwd(a, b, chunk_size, seq_idx=None, causal=False, output_dtype=None):\n # Function implementation omitted for brevity\n pass\n\n# Function to call backward kernel\ndef _bmm_chunk_bwd(a, dout, residual=None, out=None):\n # Function implementation omitted for brevity\n pass\n", - "description_1": "Use triton language to implement a batched matrix multiplication (BMM) forward kernel with optional causal masking and sequence index handling. The kernel computes the dot product between slices of two input matrices, handles chunking, and stores the result in the output tensor, considering various strides and chunk sizes.", - "description_2": "Use triton language to implement a batched matrix multiplication (BMM) backward kernel that computes the gradients of the input matrices based on the gradient of the output. It optionally incorporates residuals, updates gradients with respect to both inputs, and stores the results.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\nfrom einops import rearrange, repeat\n\n@triton.autotune(\n configs=[\n triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64}, num_stages=3, num_warps=8),\n triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2),\n triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2),\n triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=2),\n ],\n key=['chunk_size', 'hdim', 'dstate', 'IS_CAUSAL'],\n)\n@triton.jit\ndef _chunk_scan_fwd_kernel(\n cb_ptr, x_ptr, z_ptr, out_ptr, out_x_ptr, dt_ptr, dA_cumsum_ptr, seq_idx_ptr, C_ptr, prev_states_ptr, D_ptr,\n chunk_size, hdim, dstate,\n batch, seqlen, nheads_ngroups_ratio,\n stride_cb_batch, stride_cb_chunk, stride_cb_head, stride_cb_csize_m, stride_cb_csize_k,\n stride_x_batch, stride_x_seqlen, stride_x_head, stride_x_hdim,\n stride_z_batch, stride_z_seqlen, stride_z_head, stride_z_hdim,\n stride_out_batch, stride_out_seqlen, stride_out_head, stride_out_hdim,\n stride_dt_batch, stride_dt_chunk, stride_dt_head, stride_dt_csize,\n stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize,\n stride_seq_idx_batch, stride_seq_idx_seqlen,\n stride_C_batch, stride_C_seqlen, stride_C_head, stride_C_dstate,\n stride_states_batch, stride_states_chunk, stride_states_head, stride_states_hdim, stride_states_dstate,\n stride_D_head,\n IS_CAUSAL: tl.constexpr,\n HAS_D: tl.constexpr,\n D_HAS_HDIM: tl.constexpr,\n HAS_Z: tl.constexpr,\n HAS_SEQ_IDX: tl.constexpr,\n BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,\n BLOCK_SIZE_DSTATE: tl.constexpr,\n IS_TRITON_22: tl.constexpr,\n):\n pid_bc = tl.program_id(axis=1)\n pid_c = pid_bc // batch\n pid_b = pid_bc - pid_c * batch\n pid_h = tl.program_id(axis=2)\n num_pid_n = tl.cdiv(hdim, BLOCK_SIZE_N)\n pid_m = tl.program_id(axis=0) // num_pid_n\n pid_n = tl.program_id(axis=0) % num_pid_n\n cb_ptr += pid_b * stride_cb_batch + pid_c * stride_cb_chunk + (pid_h // nheads_ngroups_ratio) * stride_cb_head\n x_ptr += pid_b * stride_x_batch + pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head\n dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head\n dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head\n C_ptr += pid_b * stride_C_batch + pid_c * chunk_size * stride_C_seqlen + (pid_h // nheads_ngroups_ratio) * stride_C_head\n prev_states_ptr += pid_b * stride_states_batch + pid_c * stride_states_chunk + pid_h * stride_states_head\n if HAS_SEQ_IDX:\n seq_idx_ptr += pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen\n\n offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n dA_cs_m = tl.load(dA_cumsum_ptr + offs_m * stride_dA_cs_csize, mask=offs_m < chunk_size, other=0.0).to(tl.float32)\n\n chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)\n if HAS_SEQ_IDX:\n seq_idx_prev = tl.load(seq_idx_ptr - stride_seq_idx_seqlen, mask=pid_c >= 1, other=0)\n seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen, mask=offs_m < chunk_size_limit, other=-1)\n acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n\n if IS_TRITON_22 or pid_c > -1:\n offs_k_dstate = tl.arange(0, BLOCK_SIZE_DSTATE if BLOCK_SIZE_DSTATE <= 128 else BLOCK_SIZE_K)\n C_ptrs = C_ptr + (offs_m[:, None] * stride_C_seqlen + offs_k_dstate[None, :] * stride_C_dstate)\n prev_states_ptrs = prev_states_ptr + (offs_n[None, :] * stride_states_hdim + offs_k_dstate[:, None] * stride_states_dstate)\n if not HAS_SEQ_IDX:\n scale_m = tl.exp(dA_cs_m)\n else:\n scale_m = tl.where(seq_idx_m == seq_idx_prev, tl.exp(dA_cs_m), 0.0)\n if BLOCK_SIZE_DSTATE <= 128:\n C = tl.load(C_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k_dstate[None, :] < dstate), other=0.0)\n prev_states = tl.load(prev_states_ptrs, mask=(offs_k_dstate[:, None] < dstate) & (offs_n[None, :] < hdim), other=0.0)\n prev_states = prev_states.to(C_ptr.dtype.element_ty)\n acc = tl.dot(C, prev_states) * scale_m[:, None]\n else:\n for k in range(0, dstate, BLOCK_SIZE_K):\n C = tl.load(C_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k_dstate[None, :] < dstate - k), other=0.0)\n prev_states = tl.load(prev_states_ptrs, mask=(offs_k_dstate[:, None] < dstate - k) & (offs_n[None, :] < hdim), other=0.0)\n prev_states = prev_states.to(C_ptr.dtype.element_ty)\n acc += tl.dot(C, prev_states)\n C_ptrs += BLOCK_SIZE_K\n prev_states_ptrs += BLOCK_SIZE_K\n acc *= scale_m[:, None]\n\n offs_k = tl.arange(0, BLOCK_SIZE_K)\n cb_ptrs = cb_ptr + (offs_m[:, None] * stride_cb_csize_m + offs_k[None, :] * stride_cb_csize_k)\n x_ptrs = x_ptr + (offs_k[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim)\n dt_ptrs = dt_ptr + offs_k * stride_dt_csize\n dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize\n K_MAX = chunk_size_limit if not IS_CAUSAL else min((pid_m + 1) * BLOCK_SIZE_M, chunk_size_limit)\n for k in range(0, K_MAX, BLOCK_SIZE_K):\n cb = tl.load(cb_ptrs, mask=(offs_m[:, None] < chunk_size) & (offs_k[None, :] < chunk_size - k), other=0.0).to(tl.float32)\n dA_cs_k = tl.load(dA_cumsum_ptrs, mask=offs_k < chunk_size - k, other=0.0).to(tl.float32)\n cb *= tl.exp((dA_cs_m[:, None] - dA_cs_k[None, :]))\n dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size - k, other=0.0).to(tl.float32)\n cb *= dt_k\n if IS_CAUSAL:\n mask = offs_m[:, None] >= k + offs_k[None, :]\n cb = tl.where(mask, cb, 0.0)\n cb = cb.to(x_ptr.dtype.element_ty)\n x = tl.load(x_ptrs, mask=(offs_k[:, None] < chunk_size_limit - k) & (offs_n[None, :] < hdim), other=0.0)\n acc += tl.dot(cb, x)\n cb_ptrs += BLOCK_SIZE_K * stride_cb_csize_k\n x_ptrs += BLOCK_SIZE_K * stride_x_seqlen\n dt_ptrs += BLOCK_SIZE_K * stride_dt_csize\n dA_cumsum_ptrs += BLOCK_SIZE_K * stride_dA_cs_csize\n\n offs_out_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_out_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n\n if HAS_D:\n if D_HAS_HDIM:\n D = tl.load(D_ptr + pid_h * stride_D_head + offs_n, mask=offs_n < hdim, other=0.0).to(tl.float32)\n else:\n D = tl.load(D_ptr + pid_h * stride_D_head).to(tl.float32)\n x_residual = tl.load(x_ptr + (offs_m[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim),\n mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0).to(tl.float32)\n acc += x_residual * D\n\n if HAS_Z:\n out_x_ptr += pid_b * stride_out_batch + pid_c * chunk_size * stride_out_seqlen + pid_h * stride_out_head\n out_x_ptrs = out_x_ptr + (stride_out_seqlen * offs_out_m[:, None] + offs_out_n[None, :])\n tl.store(out_x_ptrs, acc, mask=(offs_out_m[:, None] < chunk_size_limit) & (offs_out_n[None, :] < hdim))\n\n z_ptr += pid_b * stride_z_batch + pid_c * chunk_size * stride_z_seqlen + pid_h * stride_z_head\n z_ptrs = z_ptr + (stride_z_seqlen * offs_out_m[:, None] + stride_z_hdim * offs_out_n[None, :])\n z = tl.load(z_ptrs, mask=(offs_out_m[:, None] < chunk_size_limit) & (offs_out_n[None, :] < hdim), other=0.0).to(tl.float32)\n acc *= z * tl.sigmoid(z)\n\n out_ptr += pid_b * stride_out_batch + pid_c * chunk_size * stride_out_seqlen + pid_h * stride_out_head\n out_ptrs = out_ptr + (stride_out_seqlen * offs_out_m[:, None] + offs_out_n[None, :] * stride_out_hdim)\n tl.store(out_ptrs, acc, mask=(offs_out_m[:, None] < chunk_size_limit) & (offs_out_n[None, :] < hdim))\n\n\ndef _chunk_scan_fwd(cb, x, dt, dA_cumsum, C, states, D=None, z=None, seq_idx=None):\n batch, seqlen, nheads, headdim = x.shape\n _, _, nchunks, chunk_size = dt.shape\n _, _, ngroups, dstate = C.shape\n assert nheads % ngroups == 0\n assert C.shape == (batch, seqlen, ngroups, dstate)\n assert cb.shape == (batch, nchunks, ngroups, chunk_size, chunk_size)\n if z is not None:\n assert z.shape == x.shape\n if D is not None:\n assert D.shape == (nheads, headdim) or D.shape == (nheads,)\n assert dt.shape == (batch, nheads, nchunks, chunk_size)\n assert dA_cumsum.shape == (batch, nheads, nchunks, chunk_size)\n assert states.shape == (batch, nchunks, nheads, headdim, dstate)\n if seq_idx is not None:\n assert seq_idx.shape == (batch, seqlen)\n out = torch.empty(batch, seqlen, nheads, headdim, device=x.device, dtype=x.dtype)\n if z is not None:\n out_x = torch.empty(batch, seqlen, nheads, headdim, device=x.device, dtype=x.dtype)\n assert out_x.stride() == out.stride()\n else:\n out_x = None\n grid = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv(headdim, META['BLOCK_SIZE_N']),\n batch * nchunks, nheads)\n z_strides = ((z.stride(0), z.stride(1), z.stride(2), z.stride(3))\n if z is not None else (0, 0, 0, 0))\n _chunk_scan_fwd_kernel[grid](\n cb, x, z, out, out_x, dt, dA_cumsum, seq_idx, C, states, D,\n int(chunk_size), int(headdim), int(dstate),\n int(batch), int(seqlen), int(nheads // ngroups),\n cb.stride(0), cb.stride(1), cb.stride(2), cb.stride(3), cb.stride(4),\n x.stride(0), x.stride(1), x.stride(2), x.stride(3),\n z_strides[0], z_strides[1], z_strides[2], z_strides[3],\n out.stride(0), out.stride(1), out.stride(2), out.stride(3),\n dt.stride(0), dt.stride(2), dt.stride(1), dt.stride(3),\n dA_cumsum.stride(0), dA_cumsum.stride(2), dA_cumsum.stride(1), dA_cumsum.stride(3),\n *((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)),\n C.stride(0), C.stride(1), C.stride(2), C.stride(3),\n states.stride(0), states.stride(1), states.stride(2), states.stride(3), states.stride(4),\n D.stride(0) if D is not None else 0,\n True,\n D is not None,\n D.dim() == 2 if D is not None else True,\n BLOCK_SIZE_DSTATE=max(triton.next_power_of_2(int(dstate)), 16),\n HAS_Z=z is not None,\n HAS_SEQ_IDX=seq_idx is not None,\n IS_TRITON_22=version.parse(triton.__version__) >= version.parse('2.2.0'),\n )\n return out, out_x\n", - "description_1": "Use triton language to implement a forward scan operation on chunks of data, processing input matrices with optional parameters like D and z, supporting various configurations and optimizations.", - "description_2": "Use triton language to perform a forward scan on data chunks, utilizing input matrices and optional parameters for optimized computation.", - "difficulty": 4 - }, - { - "code": "import math\nimport torch\nimport triton\nimport triton.language as tl\n\ndef init_to_zero(names):\n return lambda nargs: [nargs[name].zero_() for name in names if nargs[name] is not None]\n\n@triton.autotune(\n configs=[\n triton.Config({'BLOCK_SIZE_H': 1}),\n triton.Config({'BLOCK_SIZE_H': 2}),\n triton.Config({'BLOCK_SIZE_H': 4}),\n triton.Config({'BLOCK_SIZE_H': 8}),\n triton.Config({'BLOCK_SIZE_H': 16}),\n triton.Config({'BLOCK_SIZE_H': 32}),\n triton.Config({'BLOCK_SIZE_H': 64}),\n ],\n key=['chunk_size', 'nheads'],\n)\n@triton.jit\ndef _chunk_cumsum_fwd_kernel(\n dt_ptr, A_ptr, dt_bias_ptr, dt_out_ptr, dA_cumsum_ptr,\n batch, seqlen, nheads, chunk_size,\n dt_min, dt_max,\n stride_dt_batch, stride_dt_seqlen, stride_dt_head,\n stride_A_head,\n stride_dt_bias_head,\n stride_dt_out_batch, stride_dt_out_chunk, stride_dt_out_head, stride_dt_out_csize,\n stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize,\n DT_SOFTPLUS: tl.constexpr,\n HAS_DT_BIAS: tl.constexpr,\n BLOCK_SIZE_H: tl.constexpr, BLOCK_SIZE_CHUNK: tl.constexpr,\n):\n # Implementation details are omitted for brevity\n pass\n\n@triton.autotune(\n configs=[\n triton.Config({'BLOCK_SIZE_H': 1}, pre_hook=init_to_zero([\"dA_ptr\", \"ddt_bias_ptr\"])),\n triton.Config({'BLOCK_SIZE_H': 2}, pre_hook=init_to_zero([\"dA_ptr\", \"ddt_bias_ptr\"])),\n triton.Config({'BLOCK_SIZE_H': 4}, pre_hook=init_to_zero([\"dA_ptr\", \"ddt_bias_ptr\"])),\n triton.Config({'BLOCK_SIZE_H': 8}, pre_hook=init_to_zero([\"dA_ptr\", \"ddt_bias_ptr\"])),\n triton.Config({'BLOCK_SIZE_H': 16}, pre_hook=init_to_zero([\"dA_ptr\", \"ddt_bias_ptr\"])),\n triton.Config({'BLOCK_SIZE_H': 32}, pre_hook=init_to_zero([\"dA_ptr\", \"ddt_bias_ptr\"])),\n triton.Config({'BLOCK_SIZE_H': 64}, pre_hook=init_to_zero([\"dA_ptr\", \"ddt_bias_ptr\"])),\n ],\n key=['chunk_size', 'nheads'],\n)\n@triton.jit\ndef _chunk_cumsum_bwd_kernel(\n ddA_ptr, ddt_out_ptr, dt_ptr, A_ptr, dt_bias_ptr,\n ddt_ptr, dA_ptr, ddt_bias_ptr,\n batch, seqlen, nheads, chunk_size,\n dt_min, dt_max,\n stride_ddA_batch, stride_ddA_chunk, stride_ddA_head, stride_ddA_csize,\n stride_ddt_out_batch, stride_ddt_out_chunk, stride_ddt_out_head, stride_ddt_out_csize,\n stride_dt_batch, stride_dt_seqlen, stride_dt_head,\n stride_A_head,\n stride_dt_bias_head,\n stride_ddt_batch, stride_ddt_seqlen, stride_ddt_head,\n stride_dA_head,\n stride_ddt_bias_head,\n DT_SOFTPLUS: tl.constexpr,\n HAS_DT_BIAS: tl.constexpr,\n BLOCK_SIZE_H: tl.constexpr, BLOCK_SIZE_CHUNK: tl.constexpr,\n):\n # Implementation details are omitted for brevity\n pass\n\ndef _chunk_cumsum_fwd(dt, A, chunk_size, dt_bias=None, dt_softplus=False, dt_limit=(0.0, float(\"inf\"))):\n batch, seqlen, nheads = dt.shape\n nchunks = math.ceil(seqlen / chunk_size)\n dt_out = torch.empty(batch, nheads, nchunks, chunk_size, device=dt.device, dtype=torch.float32)\n dA_cumsum = torch.empty(batch, nheads, nchunks, chunk_size, device=dt.device, dtype=torch.float32)\n grid_chunk_cs = lambda META: (batch, nchunks, triton.cdiv(nheads, META['BLOCK_SIZE_H']))\n with torch.cuda.device(dt.device.index):\n _chunk_cumsum_fwd_kernel[grid_chunk_cs](\n dt, A, dt_bias, dt_out, dA_cumsum,\n int(batch), int(seqlen), int(nheads), int(chunk_size),\n dt_limit[0], dt_limit[1],\n dt.stride(0), dt.stride(1), dt.stride(2),\n A.stride(0),\n dt_bias.stride(0) if dt_bias is not None else 0,\n dt_out.stride(0), dt_out.stride(2), dt_out.stride(1), dt_out.stride(3),\n dA_cumsum.stride(0), dA_cumsum.stride(2), dA_cumsum.stride(1), dA_cumsum.stride(3),\n dt_softplus,\n HAS_DT_BIAS=dt_bias is not None,\n BLOCK_SIZE_CHUNK=triton.next_power_of_2(chunk_size),\n )\n return dA_cumsum, dt_out\n\ndef _chunk_cumsum_bwd(ddA, ddt_out, dt, A, dt_bias=None, dt_softplus=False, dt_limit=(0.0, float(\"inf\")), ddt=None):\n batch, seqlen, nheads = dt.shape\n _, _, nchunks, chunk_size = ddA.shape\n if dt_bias is not None:\n ddt_bias = torch.empty_like(dt_bias, dtype=torch.float32)\n else:\n ddt_bias = None\n if ddt is not None:\n ddt = torch.empty_like(dt)\n else:\n ddt = torch.empty_like(dt)\n dA = torch.empty_like(A, dtype=torch.float32)\n grid_chunk_cs = lambda META: (batch, nchunks, triton.cdiv(nheads, META['BLOCK_SIZE_H']))\n with torch.cuda.device(dt.device.index):\n _chunk_cumsum_bwd_kernel[grid_chunk_cs](\n ddA, ddt_out, dt, A, dt_bias, ddt, dA, ddt_bias,\n int(batch), int(seqlen), int(nheads), int(chunk_size),\n dt_limit[0], dt_limit[1],\n ddA.stride(0), ddA.stride(2), ddA.stride(1), ddA.stride(3),\n ddt_out.stride(0), ddt_out.stride(2), ddt_out.stride(1), ddt_out.stride(3),\n dt.stride(0), dt.stride(1), dt.stride(2),\n A.stride(0),\n dt_bias.stride(0) if dt_bias is not None else 0,\n ddt.stride(0), ddt.stride(1), ddt.stride(2),\n dA.stride(0),\n ddt_bias.stride(0) if ddt_bias is not None else 0,\n dt_softplus,\n HAS_DT_BIAS=dt_bias is not None,\n BLOCK_SIZE_CHUNK=triton.next_power_of_2(chunk_size),\n )\n return ddt, dA, ddt_bias\n", - "description_1": "Use triton language to implement a chunk-based cumulative sum forward kernel (_chunk_cumsum_fwd_kernel) and backward kernel (_chunk_cumsum_bwd_kernel) with support for optional bias and softplus operations. These kernels operate on inputs representing batch, sequence length, heads, and chunks, handling dimensions and strides to perform computations across chunks and heads efficiently. The forward function (_chunk_cumsum_fwd) initializes output tensors and launches the kernel with appropriate meta-parameters, while the backward function (_chunk_cumsum_bwd) computes gradients for inputs and bias.", - "description_2": "Use triton language to implement forward and backward chunk-based cumulative sum kernels that support bias and softplus operations.", - "difficulty": 5 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\nfrom einops import rearrange\n\n@triton.autotune(\n configs=[\n triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64}, num_stages=3, num_warps=8),\n triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),\n ],\n key=['chunk_size', 'hdim', 'dstate'],\n)\n@triton.jit\ndef _chunk_scan_chunk_state_bwd_dx_kernel(\n x_ptr, cb_ptr, dout_ptr, dt_ptr, dA_cumsum_ptr, seq_idx_ptr, D_ptr,\n b_ptr, dstates_ptr,\n dx_ptr, ddt_ptr, dD_ptr,\n chunk_size, hdim, dstate,\n batch, seqlen, nheads_ngroups_ratio,\n stride_x_batch, stride_x_seqlen, stride_x_head, stride_x_hdim,\n stride_cb_batch, stride_cb_chunk, stride_cb_head, stride_cb_csize_m, stride_cb_csize_k,\n stride_dout_batch, stride_dout_seqlen, stride_dout_head, stride_dout_hdim,\n stride_dt_batch, stride_dt_chunk, stride_dt_head, stride_dt_csize,\n stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize,\n stride_seq_idx_batch, stride_seq_idx_seqlen,\n stride_D_head,\n stride_b_batch, stride_b_seqlen, stride_b_head, stride_b_dstate,\n stride_dstates_batch, stride_dstates_chunk, stride_dstates_head, stride_dstates_hdim, stride_dstates_dstate,\n stride_dx_batch, stride_dx_seqlen, stride_dx_head, stride_dx_hdim,\n stride_ddt_batch, stride_ddt_chunk, stride_ddt_head, stride_ddt_csize,\n stride_dD_batch, stride_dD_chunk, stride_dD_head, stride_dD_csize, stride_dD_hdim,\n HAS_D: tl.constexpr,\n D_HAS_HDIM: tl.constexpr,\n HAS_SEQ_IDX: tl.constexpr,\n BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,\n BLOCK_SIZE_DSTATE: tl.constexpr,\n IS_TRITON_22: tl.constexpr,\n):\n pid_bc = tl.program_id(axis=1)\n pid_c = pid_bc // batch\n pid_b = pid_bc - pid_c * batch\n pid_h = tl.program_id(axis=2)\n num_pid_n = tl.cdiv(hdim, BLOCK_SIZE_N)\n pid_m = tl.program_id(axis=0) // num_pid_n\n pid_n = tl.program_id(axis=0) % num_pid_n\n x_ptr += pid_b * stride_x_batch + pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head\n cb_ptr += pid_b * stride_cb_batch + pid_c * stride_cb_chunk + (pid_h // nheads_ngroups_ratio) * stride_cb_head\n dout_ptr += pid_b * stride_dout_batch + pid_c * chunk_size * stride_dout_seqlen + pid_h * stride_dout_head\n dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head\n ddt_ptr += pid_b * stride_ddt_batch + pid_c * stride_ddt_chunk + pid_h * stride_ddt_head\n dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head\n b_ptr += pid_b * stride_b_batch + pid_c * chunk_size * stride_b_seqlen + (pid_h // nheads_ngroups_ratio) * stride_b_head\n dstates_ptr += pid_b * stride_dstates_batch + pid_c * stride_dstates_chunk + pid_h * stride_dstates_head\n if HAS_SEQ_IDX:\n seq_idx_ptr += pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen\n\n offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n\n chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)\n\n acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n\n dA_cs_m = tl.load(dA_cumsum_ptr + offs_m * stride_dA_cs_csize, mask=offs_m < chunk_size_limit, other=0.0).to(tl.float32)\n\n dA_cs_last = tl.load(dA_cumsum_ptr + (chunk_size - 1) * stride_dA_cs_csize).to(tl.float32)\n if not HAS_SEQ_IDX:\n scale = tl.exp(dA_cs_last - dA_cs_m)\n else:\n seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen, mask=offs_m < chunk_size_limit, other=-1)\n seq_idx_last = tl.load(seq_idx_ptr + (chunk_size_limit - 1) * stride_seq_idx_seqlen)\n scale = tl.where(seq_idx_m == seq_idx_last, tl.exp(dA_cs_last - dA_cs_m), 0.0)\n\n offs_dstate = tl.arange(0, BLOCK_SIZE_DSTATE if IS_TRITON_22 and BLOCK_SIZE_DSTATE <= 128 else BLOCK_SIZE_K)\n b_ptrs = b_ptr + (offs_m[:, None] * stride_b_seqlen + offs_dstate[None, :] * stride_b_dstate)\n dstates_ptrs = dstates_ptr + (offs_n[None, :] * stride_dstates_hdim + offs_dstate[:, None] * stride_dstates_dstate)\n if IS_TRITON_22 and BLOCK_SIZE_DSTATE <= 128:\n b = tl.load(b_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_dstate[None, :] < dstate), other=0.0)\n dstates = tl.load(dstates_ptrs, mask=(offs_dstate[:, None] < dstate) & (offs_n[None, :] < hdim), other=0.0)\n dstates = dstates.to(b_ptr.dtype.element_ty)\n acc = tl.dot(b, dstates) * scale[:, None]\n else:\n for k in range(0, dstate, BLOCK_SIZE_K):\n b = tl.load(b_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_dstate[None, :] < dstate - k), other=0.0)\n dstates = tl.load(dstates_ptrs, mask=(offs_dstate[:, None] < dstate - k) & (offs_n[None, :] < hdim), other=0.0)\n dstates = dstates.to(b_ptr.dtype.element_ty)\n acc += tl.dot(b, dstates)\n b_ptrs += BLOCK_SIZE_K * stride_b_dstate\n dstates_ptrs += BLOCK_SIZE_K * stride_dstates_dstate\n acc *= scale[:, None]\n\n offs_k = tl.arange(0, BLOCK_SIZE_K)\n cb_ptrs = cb_ptr + (offs_m[:, None] * stride_cb_csize_m + offs_k[None, :] * stride_cb_csize_k)\n dout_ptrs = dout_ptr + (offs_k[:, None] * stride_dout_seqlen + offs_n[None, :] * stride_dout_hdim)\n dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize\n K_MAX = chunk_size_limit\n K_MIN = pid_m * BLOCK_SIZE_M\n cb_ptrs += K_MIN * stride_cb_csize_k\n dout_ptrs += K_MIN * stride_dout_seqlen\n dA_cumsum_ptrs += K_MIN * stride_dA_cs_csize\n for k in range(K_MIN, K_MAX, BLOCK_SIZE_K):\n k = tl.multiple_of(k, BLOCK_SIZE_K)\n cb = tl.load(cb_ptrs, mask=(offs_m[:, None] < chunk_size) & (offs_k[None, :] < K_MAX - k), other=0.0)\n dout = tl.load(dout_ptrs, mask=(offs_k[:, None] < K_MAX - k) & (offs_n[None, :] < hdim), other=0.0)\n dA_cs_k = tl.load(dA_cumsum_ptrs, mask=offs_k < K_MAX - k, other=0.0).to(tl.float32)\n cb *= tl.exp(dA_cs_k[None, :] - dA_cs_m[:, None])\n mask = k + offs_k[None, :] >= offs_m[:, None]\n cb = tl.where(mask, cb, 0.0)\n cb = cb.to(dout_ptr.dtype.element_ty)\n acc += tl.dot(cb, dout)\n cb_ptrs += BLOCK_SIZE_K * stride_cb_csize_k\n dout_ptrs += BLOCK_SIZE_K * stride_dout_seqlen\n dA_cumsum_ptrs += BLOCK_SIZE_K * stride_dA_cs_csize\n\n offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n dt_ptrs = dt_ptr + offs_m * stride_dt_csize\n dt_m = tl.load(dt_ptrs, mask=offs_m < chunk_size_limit, other=0.0).to(tl.float32)\n dx = acc * dt_m[:, None]\n dx_ptr += pid_b * stride_dx_batch + pid_c * chunk_size * stride_dx_seqlen + pid_h * stride_dx_head\n dx_ptrs = dx_ptr + (offs_m[:, None] * stride_dx_seqlen + offs_n[None, :] * stride_dx_hdim)\n if HAS_D:\n dout_res_ptrs = dout_ptr + (offs_m[:, None] * stride_dout_seqlen + offs_n[None, :] * stride_dout_hdim)\n dout_res = tl.load(dout_res_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0).to(tl.float32)\n if D_HAS_HDIM:\n D = tl.load(D_ptr + pid_h * stride_D_head + offs_n, mask=offs_n < hdim, other=0.0).to(tl.float32)\n else:\n D = tl.load(D_ptr + pid_h * stride_D_head).to(tl.float32)\n dx += dout_res * D\n tl.store(dx_ptrs, dx, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim))\n\n x_ptrs = x_ptr + (offs_m[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim)\n x = tl.load(x_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0).to(tl.float32)\n if HAS_D:\n dD_ptr += pid_b * stride_dD_batch + pid_c * stride_dD_chunk + pid_h * stride_dD_head + pid_m * stride_dD_csize\n if D_HAS_HDIM:\n dD_ptrs = dD_ptr + offs_n * stride_dD_hdim\n dD = tl.sum(dout_res * x, axis=0)\n tl.store(dD_ptrs, dD, mask=offs_n < hdim)\n else:\n dD = tl.sum(dout_res * x)\n tl.store(dD_ptr, dD)\n ddt = tl.sum(acc * x, axis=1)\n ddt_ptrs = ddt_ptr + offs_m * stride_ddt_csize\n tl.atomic_add(ddt_ptrs, ddt, mask=offs_m < chunk_size)\n\n\ndef _chunk_scan_chunk_state_bwd_dx(x, dt, dA_cumsum, B, CB, dout, dstates, D=None, seq_idx=None, dx=None):\n batch, seqlen, nheads, headdim = x.shape\n _, _, nchunks, chunk_size = dt.shape\n _, _, ngroups, dstate = B.shape\n assert nheads % ngroups == 0\n assert B.shape == (batch, seqlen, ngroups, dstate)\n assert CB.shape == (batch, nchunks, ngroups, chunk_size, chunk_size)\n assert dt.shape == (batch, nheads, nchunks, chunk_size)\n assert dA_cumsum.shape == dt.shape\n assert dout.shape == x.shape\n assert dstates.shape == (batch, nchunks, nheads, headdim, dstate)\n if seq_idx is not None:\n assert seq_idx.shape == (batch, seqlen)\n if D is not None:\n assert D.shape == (nheads, headdim) or D.shape == (nheads,)\n assert D.stride(-1) == 1\n BLOCK_SIZE_min = 32\n dD = torch.empty(triton.cdiv(chunk_size, BLOCK_SIZE_min), batch, nchunks, nheads,\n headdim if D.dim() == 2 else 1, device=D.device, dtype=torch.float32)\n else:\n dD = None\n dD_strides = ((dD.stride(0), dD.stride(1), dD.stride(2), dD.stride(3), dD.stride(4))\n if D is not None else (0, 0, 0, 0, 0))\n if dx is None:\n dx = torch.empty_like(x)\n else:\n assert dx.shape == x.shape\n ddt = torch.empty(batch, nheads, nchunks, chunk_size, device=dout.device, dtype=torch.float32)\n grid_dx = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv(headdim, META['BLOCK_SIZE_N']),\n batch * nchunks, nheads)\n with torch.cuda.device(x.device.index):\n _chunk_scan_chunk_state_bwd_dx_kernel[grid_dx](\n x, CB, dout, dt, dA_cumsum, seq_idx, D, B, dstates, dx, ddt, dD,\n int(chunk_size), int(headdim), int(dstate),\n int(batch), int(seqlen), int(nheads // ngroups),\n x.stride(0), x.stride(1), x.stride(2), x.stride(3),\n CB.stride(0), CB.stride(1), CB.stride(2), CB.stride(-1), CB.stride(-2),\n dout.stride(0), dout.stride(1), dout.stride(2), dout.stride(3),\n dt.stride(0), dt.stride(2), dt.stride(1), dt.stride(3),\n dA_cumsum.stride(0), dA_cumsum.stride(2), dA_cumsum.stride(1), dA_cumsum.stride(3),\n *((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)),\n D.stride(0) if D is not None else 0,\n B.stride(0), B.stride(1), B.stride(2), B.stride(3),\n dstates.stride(0), dstates.stride(1), dstates.stride(2), dstates.stride(3), dstates.stride(4),\n dx.stride(0), dx.stride(1), dx.stride(2), dx.stride(3),\n ddt.stride(0), ddt.stride(2), ddt.stride(1), ddt.stride(3),\n dD_strides[1], dD_strides[2], dD_strides[3], dD_strides[0], dD_strides[4],\n D is not None,\n D.dim() == 2 if D is not None else True,\n HAS_SEQ_IDX=seq_idx is not None,\n BLOCK_SIZE_DSTATE=max(triton.next_power_of_2(dstate), 16),\n IS_TRITON_22=TRITON_22\n )\n if D is not None:\n BLOCK_SIZE_actual = _chunk_scan_chunk_state_bwd_dx_kernel.best_config.kwargs[\"BLOCK_SIZE_M\"]\n n_valid_blocks = (chunk_size + BLOCK_SIZE_actual - 1) // BLOCK_SIZE_actual\n dD = dD[:n_valid_blocks].sum(dim=(0, 1, 2)).to(dtype=D.dtype)\n if D.dim() == 1:\n dD = rearrange(dD, \"h 1 -> h\")\n return dx, ddt.to(dtype=dt.dtype), dD\n", - "description_1": "Use triton language to implement a backward kernel for chunk scan operations. The kernel, _chunk_scan_chunk_state_bwd_dx_kernel, takes 60 parameters including pointers to input and output matrices, matrix dimensions, strides, and meta-parameters. It computes the backward pass for a chunk scan operation, handling various configurations and optimizations based on the input parameters.", - "description_2": "Use triton language to implement a backward function, _chunk_scan_chunk_state_bwd_dx, which calls the triton kernel to compute gradients for chunk scan operations. The function takes 9 parameters including input tensors and optional parameters, and returns computed gradients.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.autotune(\n configs=[\n triton.Config({'BLOCK_SIZE': 64}),\n triton.Config({'BLOCK_SIZE': 128}),\n triton.Config({'BLOCK_SIZE': 256}),\n triton.Config({'BLOCK_SIZE': 512}),\n triton.Config({'BLOCK_SIZE': 1024}),\n triton.Config({'BLOCK_SIZE': 2048}),\n ],\n key=['dim'],\n)\n@triton.jit\ndef _state_passing_fwd_kernel(\n states_ptr, out_ptr, final_states_ptr, dA_cs_ptr, initstates_ptr, seq_idx_ptr,\n dim, nchunks, seqlen, chunk_size,\n stride_states_batch, stride_states_chunk, stride_states_head, stride_states_dim,\n stride_out_batch, stride_out_chunk, stride_out_head, stride_out_dim,\n stride_final_states_batch, stride_final_states_head, stride_final_states_dim,\n stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head,\n stride_initstates_batch, stride_initstates_head, stride_initstates_dim,\n stride_seq_idx_batch, stride_seq_idx_seqlen,\n HAS_INITSTATES: tl.constexpr,\n HAS_SEQ_IDX: tl.constexpr,\n BLOCK_SIZE: tl.constexpr,\n):\n pid_b = tl.program_id(axis=1)\n pid_h = tl.program_id(axis=2)\n pid_m = tl.program_id(axis=0)\n states_ptr += pid_b * stride_states_batch + pid_h * stride_states_head\n dA_cs_ptr += pid_b * stride_dA_cs_batch + pid_h * stride_dA_cs_head\n out_ptr += pid_b * stride_out_batch + pid_h * stride_out_head\n final_states_ptr += pid_b * stride_final_states_batch + pid_h * stride_final_states_head\n if HAS_INITSTATES:\n initstates_ptr += pid_b * stride_initstates_batch + pid_h * stride_initstates_head\n if HAS_SEQ_IDX:\n seq_idx_ptr += pid_b * stride_seq_idx_batch\n\n offs_m = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)\n states_ptrs = states_ptr + offs_m * stride_states_dim\n out_ptrs = out_ptr + offs_m * stride_out_dim\n final_states_ptrs = final_states_ptr + offs_m * stride_final_states_dim\n\n if not HAS_INITSTATES:\n states = tl.zeros((BLOCK_SIZE, ), dtype=tl.float32)\n else:\n initstates_ptrs = initstates_ptr + offs_m * stride_initstates_dim\n states = tl.load(initstates_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)\n tl.store(out_ptrs, states, mask=offs_m < dim)\n out_ptrs += stride_out_chunk\n seq_idx = 0\n for c in range(nchunks):\n new_states = tl.load(states_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)\n dA_cs = tl.load(dA_cs_ptr).to(tl.float32)\n scale = tl.exp(dA_cs)\n if HAS_SEQ_IDX:\n seq_idx_new = tl.load(seq_idx_ptr + (min((c + 1) * chunk_size, seqlen) - 1) * stride_seq_idx_seqlen)\n scale = tl.where(seq_idx_new == seq_idx, scale, 0.0)\n seq_idx = seq_idx_new\n states = scale * states + new_states\n if c < nchunks - 1:\n tl.store(out_ptrs, states, mask=offs_m < dim)\n else:\n tl.store(final_states_ptrs, states, mask=offs_m < dim)\n states_ptrs += stride_states_chunk\n dA_cs_ptr += stride_dA_cs_chunk\n out_ptrs += stride_out_chunk\n\ndef _state_passing_fwd(states, dA_chunk_cumsum, initial_states=None, seq_idx=None, chunk_size=None,\n out_dtype=None):\n batch, nchunks, nheads, dim = states.shape\n assert dA_chunk_cumsum.shape == (batch, nheads, nchunks)\n if initial_states is not None:\n assert initial_states.shape == (batch, nheads, dim)\n if seq_idx is not None:\n assert chunk_size is not None\n seqlen = seq_idx.shape[-1]\n assert seq_idx.shape == (batch, seqlen)\n out_dtype = states.dtype if out_dtype is None else out_dtype\n out = torch.empty((batch, nchunks, nheads, dim), device=states.device, dtype=out_dtype)\n final_states = torch.empty((batch, nheads, dim), device=states.device, dtype=torch.float32)\n grid = lambda META: (triton.cdiv(dim, META['BLOCK_SIZE']), batch, nheads)\n with torch.cuda.device(states.device.index):\n _state_passing_fwd_kernel[grid](\n states, out, final_states, dA_chunk_cumsum, initial_states, seq_idx,\n int(dim), int(nchunks), int(seqlen if seq_idx is not None else 0), int(chunk_size if seq_idx is not None else 0),\n states.stride(0), states.stride(1), states.stride(2), states.stride(3),\n out.stride(0), out.stride(1), out.stride(2), out.stride(3),\n final_states.stride(0), final_states.stride(1), final_states.stride(2),\n dA_chunk_cumsum.stride(0), dA_chunk_cumsum.stride(2), dA_chunk_cumsum.stride(1),\n *((initial_states.stride(0), initial_states.stride(1), initial_states.stride(2))\n if initial_states is not None else (0, 0, 0)),\n *((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)),\n HAS_INITSTATES=initial_states is not None,\n HAS_SEQ_IDX=seq_idx is not None,\n )\n return out, final_states\n", - "description_1": "Use triton language to implement a forward state passing kernel with parameters for matrix pointers, dimensions, strides, and meta-parameters. The kernel computes new states based on input states and cumulative sums, optionally using initial states and sequence indices. The function _state_passing_fwd sets up the kernel execution with appropriate grid and strides, handling optional initial states and sequence indices.", - "description_2": "Use triton language to implement a backward state passing kernel with parameters for matrix pointers, dimensions, strides, and meta-parameters. The kernel computes gradients of states and cumulative sums, optionally using final states and sequence indices. The function _state_passing_bwd sets up the kernel execution with appropriate grid and strides, handling optional final states and sequence indices.", - "difficulty": 4 - }, - { - "code": "import triton\nimport triton.language as tl\nimport numpy as np\nimport cupy as cp\nimport torch\n\n@triton.jit\ndef circ_pad(X,\n all_pads_0, all_pads_2, all_pads_4, all_pads_6,\n orig_dims_0, orig_dims_1, orig_dims_2, orig_dims_3,\n Y,\n Y_shape_1, Y_shape_2, Y_shape_3,\n X_len, Y_len, BLOCK_SIZE: tl.constexpr,):\n pid = tl.program_id(0)\n i = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)\n\n mask_y = i < Y_len\n\n i3 = i % Y_shape_3\n i2 = (i // Y_shape_3) % Y_shape_2\n i1 = (i // Y_shape_3 // Y_shape_2) % Y_shape_1\n i0 = i // Y_shape_3 // Y_shape_2 // Y_shape_1\n\n j0 = (i0 - all_pads_0 + orig_dims_0) % orig_dims_0\n j1 = (i1 - all_pads_2 + orig_dims_1) % orig_dims_1\n j2 = (i2 - all_pads_4 + orig_dims_2) % orig_dims_2\n j3 = (i3 - all_pads_6 + orig_dims_3) % orig_dims_3\n\n load_idx = orig_dims_3 * orig_dims_2 * orig_dims_1 * j0 + orig_dims_3 * orig_dims_2 * j1 + orig_dims_3 * j2 + j3\n mask_x = load_idx < X_len\n\n x = tl.load(X + load_idx, mask=mask_x)\n\n tl.store(Y + i, x, mask=mask_y)\n\ndef call_circ_pad(a_t, c_t, pads, orig_dims, out_dims):\n N = len(orig_dims)\n all_pads = np.zeros((N * 2,), dtype=np.int32)\n orig_dims = np.array(orig_dims, dtype=np.int32)\n out_dims = np.array(out_dims, dtype=np.int32)\n\n for i in range(np.size(pads) // 2):\n out_dims[N - i - 1] += pads[i * 2] + pads[i * 2 + 1]\n all_pads[N * 2 - 2 * i - 2] = pads[i * 2]\n all_pads[N * 2 - 2 * i - 1] = pads[i * 2 + 1]\n\n all_pads = all_pads.tolist()\n orig_dims = orig_dims.tolist()\n out_dims = out_dims.tolist()\n\n blockSize = 256\n numBlocks = tuple([int((np.prod(out_dims) + blockSize - 1) // blockSize)])\n\n circ_pad[numBlocks](a_t,\n all_pads[0], all_pads[2], all_pads[4], all_pads[6],\n orig_dims[0], orig_dims[1], orig_dims[2], orig_dims[3],\n c_t,\n out_dims[1], out_dims[2], out_dims[3],\n int(np.prod(orig_dims)), int(np.prod(out_dims)), BLOCK_SIZE=256\n )\n", - "description_1": "Use triton language to implement a circular padding operation. The kernel 'circ_pad' takes 15 parameters: X (input tensor), all_pads_0, all_pads_2, all_pads_4, all_pads_6 (padding values for each dimension), orig_dims_0, orig_dims_1, orig_dims_2, orig_dims_3 (original dimensions of the input tensor), Y (output tensor), Y_shape_1, Y_shape_2, Y_shape_3 (output tensor shapes), X_len, Y_len (lengths of input and output tensors), and BLOCK_SIZE (block size for parallel execution). The function 'call_circ_pad' is used to set up and launch the kernel with the appropriate parameters.", - "description_2": "Use triton language to create a kernel for circular padding of a 4D tensor, and a function to configure and launch this kernel with specified padding and dimensions.", - "difficulty": 3 - }, - { - "code": "import triton\nimport triton.language as tl\nimport torch\nimport numpy as np\nimport cupy as cp\n\n@triton.jit\ndef circ_pad(\n X,\n all_pads_0,\n all_pads_2,\n all_pads_4,\n all_pads_6,\n orig_dims_0,\n orig_dims_1,\n orig_dims_2,\n orig_dims_3,\n Y,\n Y_shape_1,\n Y_shape_2,\n Y_shape_3,\n X_len,\n Y_len,\n BLOCK_SIZE: tl.constexpr,\n):\n pid = tl.program_id(0)\n i = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)\n\n mask_y = i < Y_len\n\n i3 = i % Y_shape_3\n i2 = (i // Y_shape_3) % Y_shape_2\n i1 = (i // Y_shape_3 // Y_shape_2) % Y_shape_1\n i0 = i // Y_shape_3 // Y_shape_2 // Y_shape_1\n\n j0 = (i0 - all_pads_0 + orig_dims_0) % orig_dims_0\n j1 = (i1 - all_pads_2 + orig_dims_1) % orig_dims_1\n j2 = (i2 - all_pads_4 + orig_dims_2) % orig_dims_2\n j3 = (i3 - all_pads_6 + orig_dims_3) % orig_dims_3\n\n load_idx = (\n orig_dims_3 * orig_dims_2 * orig_dims_1 * j0\n + orig_dims_3 * orig_dims_2 * j1\n + orig_dims_3 * j2\n + j3\n )\n mask_x = load_idx < X_len\n\n x = tl.load(X + load_idx, mask=mask_x)\n\n tl.store(Y + i, x, mask=mask_y)\n\ndef call_circ_pad_kernel(inputs, outputs, input_desc, output_desc, pads, X_shape):\n inp_dtype = trt.nptype(input_desc[0].type)\n\n a_mem = cp.cuda.UnownedMemory(\n inputs[0], volume(input_desc[0].dims) * cp.dtype(inp_dtype).itemsize, self\n )\n c_mem = cp.cuda.UnownedMemory(\n outputs[0],\n volume(output_desc[0].dims) * cp.dtype(inp_dtype).itemsize,\n self,\n )\n\n a_ptr = cp.cuda.MemoryPointer(a_mem, 0)\n c_ptr = cp.cuda.MemoryPointer(c_mem, 0)\n\n a_d = cp.ndarray((volume(input_desc[0].dims)), dtype=inp_dtype, memptr=a_ptr)\n c_d = cp.ndarray((volume(output_desc[0].dims)), dtype=inp_dtype, memptr=c_ptr)\n\n a_t = torch.as_tensor(a_d, device=\"cuda\")\n c_t = torch.as_tensor(c_d, device=\"cuda\")\n\n N = len(X_shape)\n all_pads = np.zeros((N * 2,), dtype=np.int32)\n orig_dims = np.array(X_shape, dtype=np.int32)\n out_dims = np.array(X_shape, dtype=np.int32)\n\n for i in range(np.size(pads) // 2):\n out_dims[N - i - 1] += pads[i * 2] + pads[i * 2 + 1]\n all_pads[N * 2 - 2 * i - 2] = pads[i * 2]\n all_pads[N * 2 - 2 * i - 1] = pads[i * 2 + 1]\n\n all_pads = all_pads.tolist()\n orig_dims = orig_dims.tolist()\n out_dims = out_dims.tolist()\n\n blockSize = 256\n numBlocks = (int((np.prod(out_dims) + blockSize - 1) // blockSize),)\n\n circ_pad[numBlocks](\n a_t,\n all_pads[0],\n all_pads[2],\n all_pads[4],\n all_pads[6],\n orig_dims[0],\n orig_dims[1],\n orig_dims[2],\n orig_dims[3],\n c_t,\n out_dims[1],\n out_dims[2],\n out_dims[3],\n int(np.prod(orig_dims)),\n int(np.prod(out_dims)),\n BLOCK_SIZE=256,\n )\n", - "description_1": "Use triton language to implement a circular padding operation on a 4D tensor. The kernel 'circ_pad' takes 15 parameters: input tensor X, padding values for each dimension, original dimensions of the input tensor, output tensor Y, shape of the output tensor, lengths of input and output tensors, and a block size. The function calculates the indices for loading and storing data with circular padding and uses triton's load and store operations to perform the padding. The function 'call_circ_pad_kernel' prepares the input and output tensors, calculates necessary dimensions and padding, and launches the 'circ_pad' kernel with appropriate parameters.", - "description_2": "Use triton language to create a kernel for circular padding of a 4D tensor, and implement a function to prepare data and launch this kernel.", - "difficulty": 3 - }, - { - "code": "import torch\nimport torch.nn as nn\nimport triton\nimport triton.language as tl\n\n# -----------------------------\n# Triton Kernels\n# -----------------------------\n\n@triton.jit\ndef softmax_kernel(\n output_ptr, input_ptr, input_row_stride, output_row_stride, n_cols,\n BLOCK_SIZE: tl.constexpr\n):\n row_idx = tl.program_id(0)\n col_offsets = tl.arange(0, BLOCK_SIZE)\n mask = col_offsets < n_cols\n\n input_row_ptr = input_ptr + row_idx * input_row_stride + col_offsets\n output_row_ptr = output_ptr + row_idx * output_row_stride + col_offsets\n\n logits = tl.load(input_row_ptr, mask=mask, other=float('-inf'))\n max_logits = tl.max(logits, axis=0)\n logits = logits - max_logits\n exp_logits = tl.exp(logits)\n sum_exp_logits = tl.sum(exp_logits, axis=0) + 1e-6\n\n softmax_output = exp_logits / sum_exp_logits\n tl.store(output_row_ptr, softmax_output, mask=mask)\n\n@triton.jit\ndef layer_norm_kernel(\n x_ptr, weight_ptr, bias_ptr, y_ptr,\n N, eps: tl.constexpr,\n BLOCK_SIZE: tl.constexpr\n):\n row_idx = tl.program_id(0)\n cols = tl.arange(0, BLOCK_SIZE)\n mask = cols < N\n\n x_offset = x_ptr + row_idx * N + cols\n x = tl.load(x_offset, mask=mask, other=0.0)\n\n mean = tl.sum(x, axis=0) / N\n x_centered = x - mean\n var = tl.sum(x_centered * x_centered, axis=0) / N\n rstd = 1.0 / tl.sqrt(var + eps)\n\n w = tl.load(weight_ptr + cols, mask=mask, other=1.0)\n b = tl.load(bias_ptr + cols, mask=mask, other=0.0)\n\n y = (x_centered * rstd) * w + b\n tl.store(y_ptr + row_idx * N + cols, y, mask=mask)\n\n@triton.jit\ndef gelu_kernel(\n x_ptr, y_ptr, n_elements,\n BLOCK_SIZE: tl.constexpr\n):\n offsets = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)\n mask = offsets < n_elements\n\n x = tl.load(x_ptr + offsets, mask=mask)\n\n sqrt_2_over_pi = 0.7978845608028654\n coeff = sqrt_2_over_pi * (1 + 0.044715 * x * x)\n y = 0.5 * x * (1 + (x * coeff) / (1 + tl.abs(x * coeff)))\n\n tl.store(y_ptr + offsets, y, mask=mask)\n\n# -----------------------------------\n# Triton-accelerated Launch Functions\n# -----------------------------------\n\nclass TritonLayerNorm(nn.Module):\n def __init__(self, normalized_shape, eps=1e-5):\n super().__init__()\n self.normalized_shape = tuple(normalized_shape) if isinstance(normalized_shape, (tuple, list)) else (normalized_shape,)\n self.weight = nn.Parameter(torch.ones(self.normalized_shape))\n self.bias = nn.Parameter(torch.zeros(self.normalized_shape))\n self.eps = eps\n\n def forward(self, x):\n assert x.shape[-len(self.normalized_shape):] == self.normalized_shape, \"Input shape does not match normalized_shape.\"\n y = torch.empty_like(x)\n x_ = x.reshape(-1, self.normalized_shape[-1])\n y_ = y.reshape(-1, self.normalized_shape[-1])\n M, N = x_.shape\n grid = lambda meta: (triton.cdiv(M, meta['BLOCK_SIZE']),)\n layer_norm_kernel[grid](\n x_, self.weight, self.bias, y_,\n N, eps=self.eps,\n BLOCK_SIZE=128\n )\n return y\n\nclass TritonSoftmax(nn.Module):\n def forward(self, x):\n original_shape = x.shape\n if len(original_shape) > 2:\n x = x.view(-1, original_shape[-1])\n x = x.clamp(-100, 100)\n B, N = x.shape\n y = torch.empty_like(x)\n grid = lambda meta: (B,)\n softmax_kernel[grid](\n y, x,\n x.stride(0), y.stride(0), N,\n BLOCK_SIZE=triton.next_power_of_2(N)\n )\n y = y + 1e-8\n y = y / y.sum(dim=-1, keepdim=True)\n return y.view(original_shape)\n\nclass TritonGELU(nn.Module):\n def forward(self, x):\n n_elements = x.numel()\n y = torch.empty_like(x)\n grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)\n gelu_kernel[grid](\n x, y, n_elements,\n BLOCK_SIZE=1024\n )\n return y\n", - "description_1": "Use triton language to implement three custom kernel functions: softmax_kernel (6 parameters: output_ptr, input_ptr, input_row_stride, output_row_stride, n_cols, BLOCK_SIZE) to perform row-wise softmax operation; layer_norm_kernel (7 parameters: x_ptr, weight_ptr, bias_ptr, y_ptr, N, eps, BLOCK_SIZE) for layer normalization; gelu_kernel (4 parameters: x_ptr, y_ptr, n_elements, BLOCK_SIZE) to apply the GELU activation function. These kernels are accelerated using Triton and integrated into PyTorch modules: TritonLayerNorm, TritonSoftmax, and TritonGELU, each having a forward method that sets up the execution grid and parameters for the corresponding kernel.", - "description_2": "Use triton language to implement custom softmax, layer norm, and GELU kernel functions, then wrap them into PyTorch modules with corresponding forward methods for execution.", - "difficulty": 3 - }, - { - "code": "import math\nimport torch\nimport triton_pre_mlir as triton\nimport triton_pre_mlir.language as tl\n\n\n@triton.heuristics(\n {\n \"EVEN_M\": lambda args: args[\"seqlen_q\"] % args[\"BLOCK_M\"] == 0,\n \"EVEN_N\": lambda args: args[\"seqlen_k\"] % args[\"BLOCK_N\"] == 0,\n \"EVEN_HEADDIM\": lambda args: args[\"headdim\"] == args[\"BLOCK_HEADDIM\"],\n }\n)\n@triton.jit\ndef _fwd_kernel(\n Q,\n K,\n V,\n Bias,\n Out,\n Lse,\n TMP,\n softmax_scale,\n stride_qb,\n stride_qh,\n stride_qm,\n stride_kb,\n stride_kh,\n stride_kn,\n stride_vb,\n stride_vh,\n stride_vn,\n stride_bb,\n stride_bh,\n stride_bm,\n stride_ob,\n stride_oh,\n stride_om,\n nheads,\n seqlen_q,\n seqlen_k,\n seqlen_q_rounded,\n headdim,\n CACHE_KEY_SEQLEN_Q,\n CACHE_KEY_SEQLEN_K,\n BIAS_TYPE: tl.constexpr,\n IS_CAUSAL: tl.constexpr,\n BLOCK_HEADDIM: tl.constexpr,\n EVEN_M: tl.constexpr,\n EVEN_N: tl.constexpr,\n EVEN_HEADDIM: tl.constexpr,\n BLOCK_M: tl.constexpr,\n BLOCK_N: tl.constexpr,\n):\n start_m = tl.program_id(0)\n off_hb = tl.program_id(1)\n off_b = off_hb // nheads\n off_h = off_hb % nheads\n offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)\n offs_n = tl.arange(0, BLOCK_N)\n offs_d = tl.arange(0, BLOCK_HEADDIM)\n q_ptrs = Q + off_b * stride_qb + off_h * stride_qh + (offs_m[:, None] * stride_qm + offs_d[None, :])\n k_ptrs = K + off_b * stride_kb + off_h * stride_kh + (offs_n[:, None] * stride_kn + offs_d[None, :])\n v_ptrs = V + off_b * stride_vb + off_h * stride_vh + (offs_n[:, None] * stride_vn + offs_d[None, :])\n if BIAS_TYPE == \"vector\":\n b_ptrs = Bias + off_b * stride_bb + off_h * stride_bh + offs_n\n elif BIAS_TYPE == \"matrix\":\n b_ptrs = Bias + off_b * stride_bb + off_h * stride_bh + (offs_m[:, None] * stride_bm + offs_n[None, :])\n t_ptrs = TMP + off_hb * seqlen_q_rounded + offs_m\n lse_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float(\"inf\")\n m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float(\"inf\")\n acc_o = tl.zeros([BLOCK_M, BLOCK_HEADDIM], dtype=tl.float32)\n if EVEN_M & EVEN_N:\n if EVEN_HEADDIM:\n q = tl.load(q_ptrs)\n else:\n q = tl.load(q_ptrs, mask=offs_d[None, :] < headdim, other=0.0)\n elif EVEN_HEADDIM:\n q = tl.load(q_ptrs, mask=offs_m[:, None] < seqlen_q, other=0.0)\n else:\n q = tl.load(q_ptrs, mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim), other=0.0)\n end_n = seqlen_k if not IS_CAUSAL else tl.minimum((start_m + 1) * BLOCK_M, seqlen_k)\n for start_n in range(0, end_n, BLOCK_N):\n start_n = tl.multiple_of(start_n, BLOCK_N)\n if EVEN_N & EVEN_M:\n if EVEN_HEADDIM:\n k = tl.load(k_ptrs + start_n * stride_kn)\n else:\n k = tl.load(k_ptrs + start_n * stride_kn, mask=offs_d[None, :] < headdim, other=0.0)\n elif EVEN_HEADDIM:\n k = tl.load(k_ptrs + start_n * stride_kn, mask=(start_n + offs_n)[:, None] < seqlen_k, other=0.0)\n else:\n k = tl.load(k_ptrs + start_n * stride_kn, mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim), other=0.0)\n qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)\n qk += tl.dot(q, k, trans_b=True)\n if not EVEN_N:\n qk += tl.where((start_n + offs_n)[None, :] < seqlen_k, 0, float(\"-inf\"))\n if IS_CAUSAL:\n qk += tl.where(offs_m[:, None] >= (start_n + offs_n)[None, :], 0, float(\"-inf\"))\n if BIAS_TYPE != \"none\":\n if BIAS_TYPE == \"vector\":\n if EVEN_N:\n bias = tl.load(b_ptrs + start_n).to(tl.float32)\n else:\n bias = tl.load(b_ptrs + start_n, mask=start_n + offs_n < seqlen_k, other=0.0).to(tl.float32)\n bias = bias[None, :]\n elif BIAS_TYPE == \"matrix\":\n if EVEN_M & EVEN_N:\n bias = tl.load(b_ptrs + start_n).to(tl.float32)\n else:\n bias = tl.load(b_ptrs + start_n, mask=(offs_m[:, None] < seqlen_q) & ((start_n + offs_n)[None, :] < seqlen_k), other=0.0).to(tl.float32)\n qk = qk * softmax_scale + bias\n m_ij = tl.maximum(tl.max(qk, 1), lse_i)\n p = tl.exp(qk - m_ij[:, None])\n else:\n m_ij = tl.maximum(tl.max(qk, 1) * softmax_scale, lse_i)\n p = tl.exp(qk * softmax_scale - m_ij[:, None])\n l_ij = tl.sum(p, 1)\n acc_o_scale = tl.exp(m_i - m_ij)\n tl.store(t_ptrs, acc_o_scale)\n acc_o_scale = tl.load(t_ptrs)\n acc_o = acc_o * acc_o_scale[:, None]\n if EVEN_N & EVEN_M:\n if EVEN_HEADDIM:\n v = tl.load(v_ptrs + start_n * stride_vn)\n else:\n v = tl.load(v_ptrs + start_n * stride_vn, mask=offs_d[None, :] < headdim, other=0.0)\n elif EVEN_HEADDIM:\n v = tl.load(v_ptrs + start_n * stride_vn, mask=(start_n + offs_n)[:, None] < seqlen_k, other=0.0)\n else:\n v = tl.load(v_ptrs + start_n * stride_vn, mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim), other=0.0)\n p = p.to(v.dtype)\n acc_o += tl.dot(p, v)\n m_i = m_ij\n l_i_new = tl.exp(lse_i - m_ij) + l_ij\n lse_i = m_ij + tl.log(l_i_new)\n o_scale = tl.exp(m_i - lse_i)\n tl.store(t_ptrs, o_scale)\n o_scale = tl.load(t_ptrs)\n acc_o = acc_o * o_scale[:, None]\n start_m = tl.program_id(0)\n offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)\n lse_ptrs = Lse + off_hb * seqlen_q_rounded + offs_m\n tl.store(lse_ptrs, lse_i)\n offs_d = tl.arange(0, BLOCK_HEADDIM)\n out_ptrs = Out + off_b * stride_ob + off_h * stride_oh + (offs_m[:, None] * stride_om + offs_d[None, :])\n if EVEN_M:\n if EVEN_HEADDIM:\n tl.store(out_ptrs, acc_o)\n else:\n tl.store(out_ptrs, acc_o, mask=offs_d[None, :] < headdim)\n elif EVEN_HEADDIM:\n tl.store(out_ptrs, acc_o, mask=offs_m[:, None] < seqlen_q)\n else:\n tl.store(out_ptrs, acc_o, mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim))\n\n\ndef _flash_attn_forward(q, k, v, bias=None, causal=False, softmax_scale=None):\n (batch, seqlen_q, nheads, d) = q.shape\n (_, seqlen_k, _, _) = k.shape\n assert k.shape == (batch, seqlen_k, nheads, d)\n assert v.shape == (batch, seqlen_k, nheads, d)\n assert d <= 128, \"FlashAttention only support head dimensions up to 128\"\n assert q.dtype == k.dtype == v.dtype, \"All tensors must have the same type\"\n assert q.dtype in [torch.float16, torch.bfloat16], \"Only support fp16 and bf16\"\n assert q.is_cuda and k.is_cuda and v.is_cuda\n softmax_scale = softmax_scale or 1.0 / math.sqrt(d)\n has_bias = bias is not None\n bias_type = \"none\"\n if has_bias:\n assert bias.dtype in [q.dtype, torch.float]\n assert bias.is_cuda\n assert bias.dim() == 4\n if bias.stride(-1) != 1:\n bias = bias.contiguous()\n if bias.shape[2:] == (1, seqlen_k):\n bias_type = \"vector\"\n elif bias.shape[2:] == (seqlen_q, seqlen_k):\n bias_type = \"matrix\"\n else:\n raise RuntimeError(\"Last 2 dimensions of bias must be (1, seqlen_k) or (seqlen_q, seqlen_k)\")\n bias = bias.expand(batch, nheads, seqlen_q, seqlen_k)\n bias_strides = (bias.stride(0), bias.stride(1), bias.stride(2)) if has_bias else (0, 0, 0)\n seqlen_q_rounded = math.ceil(seqlen_q / 128) * 128\n lse = torch.empty((batch, nheads, seqlen_q_rounded), device=q.device, dtype=torch.float32)\n tmp = torch.empty((batch, nheads, seqlen_q_rounded), device=q.device, dtype=torch.float32)\n o = torch.empty_like(q)\n BLOCK_HEADDIM = max(triton.next_power_of_2(d), 16)\n BLOCK = 128\n num_warps = 4 if d <= 64 else 8\n grid = lambda META: (triton.cdiv(seqlen_q, META[\"BLOCK_M\"]), batch * nheads)\n _fwd_kernel[grid](\n q,\n k,\n v,\n bias,\n o,\n lse,\n tmp,\n softmax_scale,\n q.stride(0),\n q.stride(2),\n q.stride(1),\n k.stride(0),\n k.stride(2),\n k.stride(1),\n v.stride(0),\n v.stride(2),\n v.stride(1),\n *bias_strides,\n o.stride(0),\n o.stride(2),\n o.stride(1),\n nheads,\n seqlen_q,\n seqlen_k,\n seqlen_q_rounded,\n d,\n seqlen_q // 32,\n seqlen_k // 32,\n bias_type,\n causal,\n BLOCK_HEADDIM,\n BLOCK_M=BLOCK,\n BLOCK_N=BLOCK,\n num_warps=num_warps,\n num_stages=1\n )\n return (o, lse, softmax_scale)\n", - "description_1": "Use triton language to implement a forward kernel for FlashAttention. This kernel computes the scaled dot-product attention with optional causal masking and bias, supporting various block sizes and head dimensions up to 128. The kernel requires 36 parameters: input matrices Q, K, V, Bias, output matrix Out, auxiliary matrices Lse and TMP, scaling factor softmax_scale, strides for accessing memory, dimensions for query and key sequences (nheads, seqlen_q, seqlen_k, seqlen_q_rounded), headdim, cache keys (CACHE_KEY_SEQLEN_Q, CACHE_KEY_SEQLEN_K), constant parameters specifying bias type, causality, block dimensions, and evenness flags for the dimensions.", - "description_2": "Use triton language to implement the FlashAttention forward pass, encapsulating it with a Python function to handle setup and execution. The function `_flash_attn_forward` sets up input tensors, checks prerequisites, calculates grid size for Triton, and executes the `_fwd_kernel` for processing.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.autotune(\n configs=[\n triton.Config({'BLOCK_N': 32}),\n triton.Config({'BLOCK_N': 64}),\n triton.Config({'BLOCK_N': 128}),\n triton.Config({'BLOCK_N': 256}),\n triton.Config({'BLOCK_N': 512}),\n triton.Config({'BLOCK_N': 1024}),\n ],\n key=['ncols'],\n)\n@triton.jit\ndef _swiglu_fwd_kernel(\n X,\n Y,\n OUT,\n stride_x_row, # how much to increase the pointer when moving by 1 row\n stride_y_row,\n stride_out_row,\n ncols,\n BLOCK_N: tl.constexpr,\n):\n # Map the program id to the row of X and Y it should compute.\n row = tl.program_id(0)\n start_col = tl.program_id(1) * BLOCK_N\n X += row * stride_x_row\n Y += row * stride_y_row\n OUT += row * stride_out_row\n cols = start_col + tl.arange(0, BLOCK_N)\n x = tl.load(X + cols, mask=cols < ncols, other=0.).to(tl.float32)\n y = tl.load(Y + cols, mask=cols < ncols, other=0.).to(tl.float32)\n out = x * tl.sigmoid(x) * y\n tl.store(OUT + cols, out, mask=cols < ncols)\n\ndef _swiglu_fwd(xy, out=None):\n if xy.stride(-1) != 1:\n xy = xy.contiguous()\n batch_shape = xy.shape[:-1]\n xy = xy.reshape(-1, xy.shape[-1])\n x, y = xy.chunk(2, dim=-1)\n if out is None:\n out = torch.empty_like(x)\n else:\n out = out.reshape(-1, out.shape[-1])\n assert out.shape == x.shape\n assert out.stride(-1) == 1\n M, N = x.shape\n grid = lambda META: (M, triton.cdiv(N, META['BLOCK_N']))\n with torch.cuda.device(x.device.index):\n _swiglu_fwd_kernel[grid](x, y, out, x.stride(0), y.stride(0), out.stride(0), N)\n return out.reshape(*batch_shape, out.shape[-1])\n\n@triton.autotune(\n configs=[\n triton.Config({'BLOCK_N': 32}),\n triton.Config({'BLOCK_N': 64}),\n triton.Config({'BLOCK_N': 128}),\n triton.Config({'BLOCK_N': 256}),\n triton.Config({'BLOCK_N': 512}),\n triton.Config({'BLOCK_N': 1024}),\n ],\n key=['ncols'],\n)\n@triton.heuristics({\"RECOMPUTE_OUTPUT\": lambda args: args[\"OUT\"] is not None})\n@triton.jit\ndef _swiglu_bwd_kernel(\n X,\n Y,\n DOUT,\n OUT,\n DX,\n DY,\n stride_x_row, # how much to increase the pointer when moving by 1 row\n stride_y_row,\n stride_dout_row,\n stride_out_row,\n stride_dx_row,\n stride_dy_row,\n ncols,\n BLOCK_N: tl.constexpr,\n RECOMPUTE_OUTPUT: tl.constexpr,\n):\n # Map the program id to the row of X and Y it should compute.\n row = tl.program_id(0)\n start_col = tl.program_id(1) * BLOCK_N\n X += row * stride_x_row\n Y += row * stride_y_row\n DOUT += row * stride_dout_row\n if RECOMPUTE_OUTPUT:\n OUT += row * stride_out_row\n DX += row * stride_dx_row\n DY += row * stride_dy_row\n cols = start_col + tl.arange(0, BLOCK_N)\n x = tl.load(X + cols, mask=cols < ncols, other=0.).to(tl.float32)\n y = tl.load(Y + cols, mask=cols < ncols, other=0.).to(tl.float32)\n dout = tl.load(DOUT + cols, mask=cols < ncols, other=0.).to(tl.float32)\n x_sigmoid = tl.sigmoid(x)\n dx = x_sigmoid * (1 + x * (1 - x_sigmoid)) * y * dout\n dy = x * x_sigmoid * dout\n tl.store(DX + cols, dx, mask=cols < ncols)\n tl.store(DY + cols, dy, mask=cols < ncols)\n if RECOMPUTE_OUTPUT:\n out = x * x_sigmoid * y\n tl.store(OUT + cols, out, mask=cols < ncols)\n\ndef _swiglu_bwd(xy, dout, dxy=None, recompute_output=False, out=None):\n if xy.stride(-1) != 1:\n xy = xy.contiguous()\n if dout.stride(-1) != 1:\n dout = dout.contiguous()\n batch_shape = xy.shape[:-1]\n xy = xy.reshape(-1, xy.shape[-1])\n x, y = xy.chunk(2, dim=-1)\n dout = dout.reshape(-1, dout.shape[-1])\n assert dout.shape == x.shape\n if dxy is None:\n dxy = torch.empty_like(xy)\n else:\n dxy = dxy.reshape(-1, dxy.shape[-1])\n assert dxy.shape == xy.shape\n dx, dy = dxy.chunk(2, dim=-1)\n assert dx.stride(-1) == 1\n assert dy.stride(-1) == 1\n if recompute_output:\n if out is None:\n out = torch.empty_like(x)\n else:\n out = out.reshape(-1, out.shape[-1])\n assert out.shape == x.shape\n assert out.stride(-1) == 1\n M, N = x.shape\n grid = lambda META: (M, triton.cdiv(N, META['BLOCK_N']))\n with torch.cuda.device(x.device.index):\n _swiglu_bwd_kernel[grid](x, y, dout, out if recompute_output else None, dx, dy,\n x.stride(0), y.stride(0), dout.stride(0),\n out.stride(0) if recompute_output else 0,\n dx.stride(0), dy.stride(0),\n N)\n if not recompute_output:\n return dxy.reshape(*batch_shape, dxy.shape[-1])\n else:\n return dxy.reshape(*batch_shape, dxy.shape[-1]), out.reshape(*batch_shape, out.shape[-1])\n", - "description_1": "Use triton language to implement a forward and backward pass of the SwiGLU activation function. The forward kernel '_swiglu_fwd_kernel' takes 7 arguments: two input tensors X, Y, an output tensor OUT, strides for X, Y, and OUT, the number of columns, and a BLOCK_N constant for block size. It computes the element-wise product of X, its sigmoid, and Y, storing the result in OUT. The backward kernel '_swiglu_bwd_kernel' has 13 arguments: input tensors X, Y, the gradient DOUT, optional tensor OUT, gradient tensors DX, DY, strides for these tensors, the number of columns, BLOCK_N constant, and a flag RECOMPUTE_OUTPUT. It computes the gradient of the SwiGLU activation with respect to X and Y, optionally recomputing the output if needed. The corresponding functions '_swiglu_fwd' and '_swiglu_bwd' prepare and launch these kernels with appropriate grid and block configurations.", - "description_2": "Use triton language to define the forward and backward computations for the SwiGLU activation, optimizing for various block sizes using autotuning, and manage tensor strides and reshaping for GPU execution.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.heuristics({\"HAS_BIAS\": lambda args: args[\"B\"] is not None})\n@triton.heuristics({\"HAS_Z\": lambda args: args[\"Z\"] is not None})\n@triton.jit\ndef _layer_norm_fwd_1pass_kernel(\n X, # pointer to the input\n Y, # pointer to the output\n W, # pointer to the weights\n B, # pointer to the biases\n Z, # pointer to the other branch\n Mean, # pointer to the mean\n Rstd, # pointer to the 1/std\n stride_x_row, # how much to increase the pointer when moving by 1 row\n stride_y_row,\n stride_z_row,\n M, # number of rows in X\n N, # number of columns in X\n eps, # epsilon to avoid division by zero\n BLOCK_N: tl.constexpr,\n HAS_BIAS: tl.constexpr,\n HAS_Z: tl.constexpr,\n NORM_BEFORE_GATE: tl.constexpr,\n IS_RMS_NORM: tl.constexpr,\n):\n # Map the program id to the row of X and Y it should compute.\n row = tl.program_id(0)\n group = tl.program_id(1)\n X += row * stride_x_row + group * N\n Y += row * stride_y_row + group * N\n if HAS_Z:\n Z += row * stride_z_row + group * N\n if not IS_RMS_NORM:\n Mean += group * M\n Rstd += group * M\n W += group * N\n if HAS_BIAS:\n B += group * N\n # Compute mean and variance\n cols = tl.arange(0, BLOCK_N)\n x = tl.load(X + cols, mask=cols < N, other=0.).to(tl.float32)\n if HAS_Z and not NORM_BEFORE_GATE:\n z = tl.load(Z + cols, mask=cols < N).to(tl.float32)\n x *= z * tl.sigmoid(z)\n if not IS_RMS_NORM:\n mean = tl.sum(x, axis=0) / N\n tl.store(Mean + row, mean)\n xbar = tl.where(cols < N, x - mean, 0.)\n var = tl.sum(xbar * xbar, axis=0) / N\n else:\n xbar = tl.where(cols < N, x, 0.)\n var = tl.sum(xbar * xbar, axis=0) / N\n rstd = 1 / tl.sqrt(var + eps)\n tl.store(Rstd + row, rstd)\n # Normalize and apply linear transformation\n mask = cols < N\n w = tl.load(W + cols, mask=mask).to(tl.float32)\n if HAS_BIAS:\n b = tl.load(B + cols, mask=mask).to(tl.float32)\n x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd\n y = x_hat * w + b if HAS_BIAS else x_hat * w\n if HAS_Z and NORM_BEFORE_GATE:\n z = tl.load(Z + cols, mask=mask).to(tl.float32)\n y *= z * tl.sigmoid(z)\n # Write output\n tl.store(Y + cols, y, mask=mask)\n\n\ndef _layer_norm_fwd(x, weight, bias, eps, z=None, out=None, group_size=None, norm_before_gate=True, is_rms_norm=False):\n M, N = x.shape\n if group_size is None:\n group_size = N\n assert N % group_size == 0\n ngroups = N // group_size\n assert x.stride(-1) == 1\n if z is not None:\n assert z.stride(-1) == 1\n assert z.shape == (M, N)\n assert weight.shape == (N,)\n assert weight.stride(-1) == 1\n if bias is not None:\n assert bias.stride(-1) == 1\n assert bias.shape == (N,)\n # allocate output\n if out is not None:\n assert out.shape == x.shape\n else:\n out = torch.empty_like(x)\n assert out.stride(-1) == 1\n mean = torch.empty((ngroups * M, ), dtype=torch.float32, device=x.device) if not is_rms_norm else None\n rstd = torch.empty((ngroups * M, ), dtype=torch.float32, device=x.device)\n # Less than 64KB per feature: enqueue fused kernel\n MAX_FUSED_SIZE = 65536 // x.element_size()\n BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(group_size))\n if group_size > BLOCK_N:\n raise RuntimeError(\"This layer norm doesn't support feature dim >= 64KB.\")\n # heuristics for number of warps\n num_warps = min(max(BLOCK_N // 256, 1), 8)\n grid = (M, ngroups)\n with torch.cuda.device(x.device.index):\n _layer_norm_fwd_1pass_kernel[grid](x, out, weight, bias, z, mean, rstd,\n x.stride(0), out.stride(0), z.stride(0) if z is not None else 0,\n M, group_size, eps,\n BLOCK_N=BLOCK_N,\n NORM_BEFORE_GATE=norm_before_gate,\n IS_RMS_NORM=is_rms_norm,\n num_warps=num_warps)\n return out, mean, rstd\n", - "description_1": "Use triton language to implement a forward pass of a layer normalization operation. The kernel function '_layer_norm_fwd_1pass_kernel' takes 17 parameters: pointers to input, output, weights, biases, other branch, mean, and 1/std, strides for input, output, and other branch, number of rows and columns in input, epsilon for numerical stability, and several compile-time constants. The function maps program IDs to rows of input and output, computes mean and variance, normalizes the input, applies a linear transformation, and writes the output. The wrapper function '_layer_norm_fwd' prepares the input, output, and other necessary parameters, and launches the kernel with appropriate grid and block sizes.", - "description_2": "Use triton language to implement a forward pass of a layer normalization operation with kernel function '_layer_norm_fwd_1pass_kernel' and wrapper function '_layer_norm_fwd'.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\nfrom mamba_ssm.ops.triton.softplus import softplus\n\n@triton.jit\ndef _selective_scan_update_kernel(\n state_ptr, x_ptr, dt_ptr, dt_bias_ptr, A_ptr, B_ptr, C_ptr, D_ptr, z_ptr, out_ptr,\n batch, nheads, dim, dstate, nheads_ngroups_ratio,\n stride_state_batch, stride_state_head, stride_state_dim, stride_state_dstate,\n stride_x_batch, stride_x_head, stride_x_dim,\n stride_dt_batch, stride_dt_head, stride_dt_dim,\n stride_dt_bias_head, stride_dt_bias_dim,\n stride_A_head, stride_A_dim, stride_A_dstate,\n stride_B_batch, stride_B_group, stride_B_dstate,\n stride_C_batch, stride_C_group, stride_C_dstate,\n stride_D_head, stride_D_dim,\n stride_z_batch, stride_z_head, stride_z_dim,\n stride_out_batch, stride_out_head, stride_out_dim,\n DT_SOFTPLUS: tl.constexpr,\n TIE_HDIM: tl.constexpr,\n BLOCK_SIZE_M: tl.constexpr,\n HAS_DT_BIAS: tl.constexpr,\n HAS_D: tl.constexpr,\n HAS_Z: tl.constexpr,\n BLOCK_SIZE_DSTATE: tl.constexpr,\n):\n pid_m = tl.program_id(axis=0)\n pid_b = tl.program_id(axis=1)\n pid_h = tl.program_id(axis=2)\n state_ptr += pid_b * stride_state_batch + pid_h * stride_state_head\n x_ptr += pid_b * stride_x_batch + pid_h * stride_x_head\n dt_ptr += pid_b * stride_dt_batch + pid_h * stride_dt_head\n if HAS_DT_BIAS:\n dt_bias_ptr += pid_h * stride_dt_bias_head\n A_ptr += pid_h * stride_A_head\n B_ptr += pid_b * stride_B_batch + (pid_h // nheads_ngroups_ratio) * stride_B_group\n C_ptr += pid_b * stride_C_batch + (pid_h // nheads_ngroups_ratio) * stride_C_group\n if HAS_Z:\n z_ptr += pid_b * stride_z_batch + pid_h * stride_z_head\n out_ptr += pid_b * stride_out_batch + pid_h * stride_out_head\n\n offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_n = tl.arange(0, BLOCK_SIZE_DSTATE)\n state_ptrs = state_ptr + (offs_m[:, None] * stride_state_dim + offs_n[None, :] * stride_state_dstate)\n x_ptrs = x_ptr + offs_m * stride_x_dim\n dt_ptrs = dt_ptr + offs_m * stride_dt_dim\n if HAS_DT_BIAS:\n dt_bias_ptrs = dt_bias_ptr + offs_m * stride_dt_bias_dim\n if HAS_D:\n D_ptr += pid_h * stride_D_head\n A_ptrs = A_ptr + (offs_m[:, None] * stride_A_dim + offs_n[None, :] * stride_A_dstate)\n B_ptrs = B_ptr + offs_n * stride_B_dstate\n C_ptrs = C_ptr + offs_n * stride_C_dstate\n if HAS_D:\n D_ptrs = D_ptr + offs_m * stride_D_dim\n if HAS_Z:\n z_ptrs = z_ptr + offs_m * stride_z_dim\n out_ptrs = out_ptr + offs_m * stride_out_dim\n\n state = tl.load(state_ptrs, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate), other=0.0)\n x = tl.load(x_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)\n if not TIE_HDIM:\n dt = tl.load(dt_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)\n if HAS_DT_BIAS:\n dt += tl.load(dt_bias_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)\n if DT_SOFTPLUS:\n dt = softplus(dt)\n A = tl.load(A_ptrs, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate), other=0.0).to(tl.float32)\n dA = tl.exp(A * dt[:, None])\n else:\n dt = tl.load(dt_ptr).to(tl.float32)\n if HAS_DT_BIAS:\n dt += tl.load(dt_bias_ptr).to(tl.float32)\n if DT_SOFTPLUS:\n dt = softplus(dt)\n A = tl.load(A_ptr).to(tl.float32)\n dA = tl.exp(A * dt)\n\n B = tl.load(B_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32)\n C = tl.load(C_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32)\n if HAS_D:\n D = tl.load(D_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)\n if HAS_Z:\n z = tl.load(z_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)\n\n if not TIE_HDIM:\n dB = B[None, :] * dt[:, None]\n else:\n dB = B * dt\n state = state * dA + dB * x[:, None]\n tl.store(state_ptrs, state, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate))\n out = tl.sum(state * C[None, :], axis=1)\n if HAS_D:\n out += x * D\n if HAS_Z:\n out *= z * tl.sigmoid(z)\n tl.store(out_ptrs, out, mask=offs_m < dim)\n\n\ndef selective_state_update(state, x, dt, A, B, C, D=None, z=None, dt_bias=None, dt_softplus=False):\n has_heads = state.dim() > 3\n if state.dim() == 3:\n state = state.unsqueeze(1)\n if x.dim() == 2:\n x = x.unsqueeze(1)\n if dt.dim() == 2:\n dt = dt.unsqueeze(1)\n if A.dim() == 2:\n A = A.unsqueeze(0)\n if B.dim() == 2:\n B = B.unsqueeze(1)\n if C.dim() == 2:\n C = C.unsqueeze(1)\n if D is not None and D.dim() == 1:\n D = D.unsqueeze(0)\n if z is not None and z.dim() == 2:\n z = z.unsqueeze(1)\n if dt_bias is not None and dt_bias.dim() == 1:\n dt_bias = dt_bias.unsqueeze(0)\n batch, nheads, dim, dstate = state.shape\n assert x.shape == (batch, nheads, dim)\n assert dt.shape == x.shape\n assert A.shape == (nheads, dim, dstate)\n ngroups = B.shape[1]\n assert nheads % ngroups == 0, \"nheads must be divisible by ngroups\"\n assert B.shape == (batch, ngroups, dstate)\n assert C.shape == B.shape\n if D is not None:\n assert D.shape == (nheads, dim)\n if z is not None:\n assert z.shape == x.shape\n if dt_bias is not None:\n assert dt_bias.shape == (nheads, dim)\n out = torch.empty_like(x)\n grid = lambda META: (triton.cdiv(dim, META['BLOCK_SIZE_M']), batch, nheads)\n z_strides = ((z.stride(0), z.stride(1), z.stride(2)) if z is not None else (0, 0, 0))\n BLOCK_SIZE_M, num_warps = ((32, 4) if dstate <= 16\n else ((16, 4) if dstate <= 32 else\n ((8, 4) if dstate <= 64 else\n ((4, 4) if dstate <= 128 else\n ((4, 8))))))\n tie_hdim = A.stride(-1) == 0 and A.stride(-2) == 0 and dt.stride(-1) == 0 and dt_bias.stride(-1) == 0\n with torch.cuda.device(x.device.index):\n _selective_scan_update_kernel[grid](\n state, x, dt, dt_bias, A, B, C, D, z, out,\n batch, nheads, dim, dstate, nheads // ngroups,\n state.stride(0), state.stride(1), state.stride(2), state.stride(3),\n x.stride(0), x.stride(1), x.stride(2),\n dt.stride(0), dt.stride(1), dt.stride(2),\n *(dt_bias.stride(0), dt_bias.stride(1)) if dt_bias is not None else 0,\n A.stride(0), A.stride(1), A.stride(2),\n B.stride(0), B.stride(1), B.stride(2),\n C.stride(0), C.stride(1), C.stride(2),\n *(D.stride(0), D.stride(1)) if D is not None else 0,\n z_strides[0], z_strides[1], z_strides[2],\n out.stride(0), out.stride(1), out.stride(2),\n dt_softplus,\n tie_hdim,\n BLOCK_SIZE_M,\n num_warps=num_warps,\n )\n if not has_heads:\n out = out.squeeze(1)\n return out\n", - "description_1": "Use triton language to implement a kernel function '_selective_scan_update_kernel' with 45 parameters for updating state matrices with optional bias and scaling, and a wrapper function 'selective_state_update' with 10 parameters to prepare and call the kernel.", - "description_2": "Use triton language to create a kernel for matrix state updates with optional bias and scaling, and a wrapper to manage inputs and call the kernel.", - "difficulty": 4 - }, - { - "code": "import triton\nimport triton.language as tl\nfrom packaging import version\n\nTRITON3 = version.parse(triton.__version__) >= version.parse(\"3.0.0\")\n\nif TRITON3:\n @triton.jit\n def softplus(dt):\n # Applies the softplus function element-wise.\n dt = tl.where(dt <= 20.0, tl.math.log(tl.math.exp(dt) + 1), dt)\n return dt\nelse:\n @triton.jit\n def softplus(dt):\n # Applies the softplus function element-wise.\n dt = tl.where(dt <= 20.0, tl.math.log1p(tl.exp(dt)), dt)\n return dt\n", - "description_1": "Use triton language to implement a softplus function kernel that takes one parameter 'dt', a tensor, and applies the softplus function element-wise. The function uses different implementations based on the Triton version.", - "description_2": "Use triton language to create a version-dependent softplus function kernel for element-wise computation on tensors.", - "difficulty": 2 - }, - { - "code": "import math\nimport torch\nimport triton\nimport triton.language as tl\n\n@triton.autotune(\n configs=[\n triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64}, num_stages=3, num_warps=8),\n triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2),\n triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2),\n triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=2),\n ],\n key=['chunk_size', 'K', 'IS_CAUSAL'],\n)\n@triton.jit\ndef _bmm_chunk_fwd_kernel(\n a_ptr, b_ptr, out_ptr, seq_idx_ptr,\n seqlen, chunk_size, K, ngroups,\n stride_a_batch, stride_a_seqlen, stride_a_head, stride_ak,\n stride_b_batch, stride_b_seqlen, stride_b_head, stride_bk,\n stride_out_batch, stride_out_chunk, stride_out_head, stride_outm, stride_outn,\n stride_seq_idx_batch, stride_seq_idx_seqlen,\n IS_CAUSAL: tl.constexpr,\n dot_dtype: tl.constexpr,\n HAS_SEQ_IDX: tl.constexpr,\n BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,\n):\n pid_b = tl.program_id(axis=1)\n pid_ch = tl.program_id(axis=2)\n pid_c = pid_ch // ngroups\n pid_h = pid_ch - pid_c * ngroups\n num_pid_n = tl.cdiv(chunk_size, BLOCK_SIZE_N)\n pid_m = tl.program_id(axis=0) // num_pid_n\n pid_n = tl.program_id(axis=0) % num_pid_n\n if IS_CAUSAL:\n if pid_n * BLOCK_SIZE_N >= (pid_m + 1) * BLOCK_SIZE_M:\n return\n a_ptr += pid_b * stride_a_batch + pid_c * chunk_size * stride_a_seqlen + pid_h * stride_a_head\n b_ptr += pid_b * stride_b_batch + pid_c * chunk_size * stride_b_seqlen + pid_h * stride_b_head\n if HAS_SEQ_IDX:\n seq_idx_ptr += pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen\n\n offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n offs_k = tl.arange(0, BLOCK_SIZE_K)\n a_ptrs = a_ptr + (offs_m[:, None] * stride_a_seqlen + offs_k[None, :] * stride_ak)\n b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_n[None, :] * stride_b_seqlen)\n chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)\n\n acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):\n a = tl.load(a_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < K - k * BLOCK_SIZE_K), other=0.0).to(dot_dtype)\n b = tl.load(b_ptrs, mask=(offs_k[:, None] < K - k * BLOCK_SIZE_K) & (offs_n[None, :] < chunk_size_limit), other=0.0).to(dot_dtype)\n acc += tl.dot(a, b)\n a_ptrs += BLOCK_SIZE_K * stride_ak\n b_ptrs += BLOCK_SIZE_K * stride_bk\n\n offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n if HAS_SEQ_IDX:\n chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)\n seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen, mask=offs_m < chunk_size_limit, other=-1)\n seq_idx_n = tl.load(seq_idx_ptr + offs_n * stride_seq_idx_seqlen, mask=offs_n < chunk_size_limit, other=-2)\n acc = tl.where(seq_idx_m[:, None] == seq_idx_n[None, :], acc, 0.0)\n out = acc.to(out_ptr.dtype.element_ty)\n\n out_ptr += pid_b * stride_out_batch + pid_c * stride_out_chunk + pid_h * stride_out_head\n out_ptrs = out_ptr + (stride_outm * offs_m[:, None] + offs_n[None, :] * stride_outn)\n tl.store(out_ptrs, out, mask=(offs_m[:, None] < chunk_size) & (offs_n[None, :] < chunk_size))\n\n\n@triton.autotune(\n configs=[\n triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_CS': 64}, num_stages=3, num_warps=8),\n triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_CS': 32}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_CS': 32}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_CS': 32}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_CS': 32}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_CS': 32}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_CS': 32}, num_stages=5, num_warps=2),\n triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_CS': 32}, num_stages=5, num_warps=2),\n triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_CS': 32}, num_stages=4, num_warps=2),\n ],\n key=['chunk_size', 'K'],\n)\n@triton.jit\ndef _bmm_chunk_bwd_kernel(\n a_ptr, dout_ptr, db_ptr, res_ptr,\n seqlen, chunk_size, K, ngroups,\n stride_a_batch, stride_a_seqlen, stride_a_head, stride_ak,\n stride_dout_batch, stride_dout_chunk, stride_dout_head, stride_dout_csize_m, stride_dout_csize_n,\n stride_db_batch, stride_db_seqlen, stride_db_head, stride_db_k,\n stride_res_batch, stride_res_seqlen, stride_res_head, stride_res_k,\n dot_dtype: tl.constexpr,\n HAS_RESIDUAL: tl.constexpr,\n BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_CS: tl.constexpr,\n):\n pid_b = tl.program_id(axis=1)\n pid_ch = tl.program_id(axis=2)\n pid_c = pid_ch // ngroups\n pid_h = pid_ch - pid_c * ngroups\n num_pid_n = tl.cdiv(K, BLOCK_SIZE_N)\n pid_m = tl.program_id(axis=0) // num_pid_n\n pid_n = tl.program_id(axis=0) % num_pid_n\n\n a_ptr += pid_b * stride_a_batch + pid_c * chunk_size * stride_a_seqlen + pid_h * stride_a_head\n dout_ptr += pid_b * stride_dout_batch + pid_c * stride_dout_chunk + pid_h * stride_dout_head\n\n offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n offs_cs = tl.arange(0, BLOCK_SIZE_CS)\n dout_ptrs = dout_ptr + (offs_m[:, None] * stride_dout_csize_n + offs_cs[None, :] * stride_dout_csize_m)\n a_ptrs = a_ptr + (offs_cs[:, None] * stride_a_seqlen + offs_n[None, :] * stride_ak)\n chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)\n\n acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n for cs in range(0, tl.cdiv(chunk_size_limit, BLOCK_SIZE_CS)):\n dout = tl.load(dout_ptrs, mask=(offs_m[:, None] < chunk_size) & (offs_cs[None, :] < chunk_size_limit - cs * BLOCK_SIZE_CS), other=0.0).to(dot_dtype)\n a = tl.load(a_ptrs, mask=(offs_cs[:, None] < chunk_size_limit - cs * BLOCK_SIZE_CS) & (offs_n[None, :] < K), other=0.0).to(dot_dtype)\n acc += tl.dot(dout, a)\n dout_ptrs += BLOCK_SIZE_CS * stride_dout_csize_m\n a_ptrs += BLOCK_SIZE_CS * stride_a_seqlen\n\n offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n if HAS_RESIDUAL:\n res_ptr += pid_b * stride_res_batch + pid_c * chunk_size * stride_res_seqlen + pid_h * stride_res_head\n res_ptrs = res_ptr + (offs_m[:, None] * stride_res_seqlen + offs_n[None, :] * stride_res_k)\n res = tl.load(res_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < K)).to(tl.float32)\n acc += res\n db = acc.to(db_ptr.dtype.element_ty)\n\n db_ptr += pid_b * stride_db_batch + pid_c * chunk_size * stride_db_seqlen + pid_h * stride_db_head\n db_ptrs = db_ptr + (offs_m[:, None] * stride_db_seqlen + offs_n[None, :] * stride_db_k)\n tl.store(db_ptrs, db, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < K))\n\n\ndef _bmm_chunk_fwd(a, b, chunk_size, seq_idx=None, causal=False, output_dtype=None):\n has_groups = a.dim() == 4\n if not has_groups:\n batch, seqlen, k = a.shape\n else:\n batch, seqlen, ngroups, k = a.shape\n assert b.shape == a.shape\n if seq_idx is not None:\n assert seq_idx.shape == (batch, seqlen)\n if a.stride(-1) != 1 and a.stride(1) != 1:\n a = a.contiguous()\n if b.stride(-1) != 1 and b.stride(1) != 1:\n b = b.contiguous()\n nchunks = math.ceil(seqlen / chunk_size)\n out_dtype = a.dtype if output_dtype is None else output_dtype\n out = torch.empty((batch, nchunks, chunk_size, chunk_size) if not has_groups else (batch, nchunks, ngroups, chunk_size, chunk_size),\n device=a.device, dtype=out_dtype)\n dot_dtype = (tl.bfloat16 if a.dtype == torch.bfloat16 or b.dtype == torch.bfloat16 else\n (tl.float16 if a.dtype == torch.float16 or b.dtype == torch.float16 else tl.float32))\n grid = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv(chunk_size, META['BLOCK_SIZE_N']),\n batch, nchunks if not has_groups else nchunks * ngroups)\n with torch.cuda.device(a.device.index):\n _bmm_chunk_fwd_kernel[grid](\n a, b, out, seq_idx,\n seqlen, chunk_size, k, ngroups if has_groups else 1,\n a.stride(0), a.stride(1), 0 if not has_groups else a.stride(2), a.stride(-1),\n b.stride(0), b.stride(1), 0 if not has_groups else b.stride(2), b.stride(-1),\n out.stride(0), out.stride(1), 0 if not has_groups else out.stride(2), out.stride(-2), out.stride(-1),\n *((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)),\n causal,\n dot_dtype,\n HAS_SEQ_IDX=seq_idx is not None,\n )\n return out\n\n\ndef _bmm_chunk_bwd(a, dout, residual=None, out=None):\n has_groups = a.dim() == 4\n if not has_groups:\n batch, seqlen, k = a.shape\n else:\n batch, seqlen, ngroups, k = a.shape\n nchunks, chunk_size = dout.shape[1], dout.shape[-1]\n if a.stride(-1) != 1 and a.stride(-2) != 1:\n a = a.contiguous()\n if dout.stride(-1) != 1 and dout.stride(-2) != 1:\n dout = dout.contiguous()\n if residual is not None:\n assert residual.shape == (batch, seqlen, k) if not has_groups else (batch, seqlen, ngroups, k)\n if residual.stride(-1) != 1 and residual.stride(1) != 1:\n residual = residual.contiguous()\n if out is not None:\n assert out.shape == a.shape\n assert out.stride(-1) == 1 or out.stride(1) == 1\n else:\n out = torch.empty_like(a)\n dot_dtype = (tl.bfloat16 if a.dtype == torch.bfloat16 or dout.dtype == torch.bfloat16 else\n (tl.float16 if a.dtype == torch.float16 or dout.dtype == torch.float16 else tl.float32))\n grid = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv(k, META['BLOCK_SIZE_N']), batch,\n nchunks if not has_groups else nchunks * ngroups)\n residual_strides = ((residual.stride(0), residual.stride(1), 0 if not has_groups else residual.stride(2),\n residual.stride(-1))\n if residual is not None else (0, 0, 0, 0))\n with torch.cuda.device(a.device.index):\n _bmm_chunk_bwd_kernel[grid](\n a, dout, out, residual,\n seqlen, chunk_size, k, ngroups if has_groups else 1,\n a.stride(0), a.stride(1), 0 if not has_groups else a.stride(2), a.stride(-1),\n dout.stride(0), dout.stride(1), 0 if not has_groups else dout.stride(2), dout.stride(-2), dout.stride(-1),\n out.stride(0), out.stride(1), 0 if not has_groups else out.stride(2), out.stride(-1),\n residual_strides[0], residual_strides[1], residual_strides[2], residual_strides[3],\n dot_dtype,\n HAS_RESIDUAL=residual is not None,\n )\n return out\n", - "description_1": "Use triton language to implement two kernels: _bmm_chunk_fwd_kernel and _bmm_chunk_bwd_kernel. The _bmm_chunk_fwd_kernel performs a batched matrix multiplication with optional sequence index masking and causal masking. It takes 24 parameters: pointers to input matrices, matrix dimensions, strides, and meta-parameters for configuration. The _bmm_chunk_bwd_kernel computes the gradient of the batched matrix multiplication with respect to one of the input matrices. It takes 23 parameters: pointers to input matrices, matrix dimensions, strides, and meta-parameters for configuration. Both kernels are called by their respective wrapper functions _bmm_chunk_fwd and _bmm_chunk_bwd, which handle input preparation and kernel invocation.", - "description_2": "Use triton language to create forward and backward kernels for batched matrix multiplication with optional sequence and causal masking, and implement wrapper functions to manage input preparation and kernel execution.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\nfrom einops import rearrange, repeat\nfrom mamba_ssm.ops.triton.ssd_bmm import _bmm_chunk_fwd, _bmm_chunk_bwd\n\nTRITON_22 = version.parse(triton.__version__) >= version.parse('2.2.0')\n\n@triton.autotune(\n configs=[\n triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64}, num_stages=3, num_warps=8),\n triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2),\n triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2),\n triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=2),\n ],\n key=['chunk_size', 'hdim', 'dstate', 'IS_CAUSAL'],\n)\n@triton.jit\ndef _chunk_scan_fwd_kernel(\n cb_ptr, x_ptr, z_ptr, out_ptr, out_x_ptr, dt_ptr, dA_cumsum_ptr, seq_idx_ptr, C_ptr, prev_states_ptr, D_ptr,\n chunk_size, hdim, dstate,\n batch, seqlen, nheads_ngroups_ratio,\n stride_cb_batch, stride_cb_chunk, stride_cb_head, stride_cb_csize_m, stride_cb_csize_k,\n stride_x_batch, stride_x_seqlen, stride_x_head, stride_x_hdim,\n stride_z_batch, stride_z_seqlen, stride_z_head, stride_z_hdim,\n stride_out_batch, stride_out_seqlen, stride_out_head, stride_out_hdim,\n stride_dt_batch, stride_dt_chunk, stride_dt_head, stride_dt_csize,\n stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize,\n stride_seq_idx_batch, stride_seq_idx_seqlen,\n stride_C_batch, stride_C_seqlen, stride_C_head, stride_C_dstate,\n stride_states_batch, stride_states_chunk, stride_states_head, stride_states_hdim, stride_states_dstate,\n stride_D_head,\n IS_CAUSAL: tl.constexpr,\n HAS_D: tl.constexpr,\n D_HAS_HDIM: tl.constexpr,\n HAS_Z: tl.constexpr,\n HAS_SEQ_IDX: tl.constexpr,\n BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,\n BLOCK_SIZE_DSTATE: tl.constexpr,\n IS_TRITON_22: tl.constexpr,\n):\n # Triton kernel implementation here\n\ndef _chunk_scan_fwd(cb, x, dt, dA_cumsum, C, states, D=None, z=None, seq_idx=None):\n batch, seqlen, nheads, headdim = x.shape\n _, _, nchunks, chunk_size = dt.shape\n _, _, ngroups, dstate = C.shape\n assert nheads % ngroups == 0\n assert C.shape == (batch, seqlen, ngroups, dstate)\n assert cb.shape == (batch, nchunks, ngroups, chunk_size, chunk_size)\n if z is not None:\n assert z.shape == x.shape\n if D is not None:\n assert D.shape == (nheads, headdim) or D.shape == (nheads,)\n assert dt.shape == (batch, nheads, nchunks, chunk_size)\n assert dA_cumsum.shape == (batch, nheads, nchunks, chunk_size)\n assert states.shape == (batch, nchunks, nheads, headdim, dstate)\n if seq_idx is not None:\n assert seq_idx.shape == (batch, seqlen)\n out = torch.empty(batch, seqlen, nheads, headdim, device=x.device, dtype=x.dtype)\n if z is not None:\n out_x = torch.empty(batch, seqlen, nheads, headdim, device=x.device, dtype=x.dtype)\n assert out_x.stride() == out.stride()\n else:\n out_x = None\n grid = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv(headdim, META['BLOCK_SIZE_N']),\n batch * nchunks, nheads)\n z_strides = ((z.stride(0), z.stride(1), z.stride(2), z.stride(3))\n if z is not None else (0, 0, 0, 0))\n _chunk_scan_fwd_kernel[grid](\n cb, x, z, out, out_x, dt, dA_cumsum, seq_idx, C, states, D,\n chunk_size, headdim, dstate,\n batch, seqlen, nheads // ngroups,\n cb.stride(0), cb.stride(1), cb.stride(2), cb.stride(3), cb.stride(4),\n x.stride(0), x.stride(1), x.stride(2), x.stride(3),\n z_strides[0], z_strides[1], z_strides[2], z_strides[3],\n out.stride(0), out.stride(1), out.stride(2), out.stride(3),\n dt.stride(0), dt.stride(2), dt.stride(1), dt.stride(3),\n dA_cumsum.stride(0), dA_cumsum.stride(2), dA_cumsum.stride(1), dA_cumsum.stride(3),\n *((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)),\n C.stride(0), C.stride(1), C.stride(2), C.stride(3),\n states.stride(0), states.stride(1), states.stride(2), states.stride(3), states.stride(4),\n D.stride(0) if D is not None else 0,\n True,\n D is not None,\n D.dim() == 2 if D is not None else True,\n BLOCK_SIZE_DSTATE=max(triton.next_power_of_2(dstate), 16),\n HAS_Z=z is not None,\n HAS_SEQ_IDX=seq_idx is not None,\n IS_TRITON_22=TRITON_22,\n )\n return out, out_x\n\nclass ChunkScanFn(torch.autograd.Function):\n\n @staticmethod\n def forward(ctx, B, C, x, dt, dA_cumsum, prev_states, D=None, z=None):\n batch, seqlen, nheads, headdim = x.shape\n _, _, ngroups, dstate = B.shape\n assert B.shape == (batch, seqlen, ngroups, dstate)\n _, _, nchunks, chunk_size = dt.shape\n assert seqlen == nchunks * chunk_size\n assert C.shape == B.shape\n if z is not None:\n assert z.shape == x.shape\n if D is not None:\n assert D.shape == (nheads, headdim) or D.shape == (nheads,)\n assert dt.shape == (batch, nheads, nchunks, chunk_size)\n assert dA_cumsum.shape == (batch, nheads, nchunks, chunk_size)\n assert prev_states.shape == (batch, nchunks, nheads, headdim, dstate)\n if B.stride(-1) != 1:\n B = B.contiguous()\n if C.stride(-1) != 1:\n C = C.contiguous()\n if x.stride(-1) != 1 and x.stride(1) != 1:\n x = x.contiguous()\n if z is not None and z.stride(-1) != 1 and z.stride(1) != 1:\n z = z.contiguous()\n if D is not None and D.stride(-1) != 1:\n D = D.contiguous()\n CB = _bmm_chunk_fwd(C, B, chunk_size)\n out, out_x = _chunk_scan_fwd(CB, x, dt, dA_cumsum, C, prev_states, D=D, z=z)\n ctx.save_for_backward(out if z is None else out_x, B, C, CB, x, dt, dA_cumsum, prev_states, D, z)\n return out\n\ndef chunk_scan(B, C, x, dt, dA_cumsum, prev_states, D=None, z=None):\n return ChunkScanFn.apply(B, C, x, dt, dA_cumsum, prev_states, D, z)\n", - "description_1": "Use triton language to create a forward kernel _chunk_scan_fwd_kernel, that performs chunk-wise scanning of input matrices. This kernel takes pointers to matrices, their dimensions, strides, and meta-parameters like IS_CAUSAL, HAS_D, D_HAS_HDIM, HAS_Z, and HAS_SEQ_IDX. It computes the forward pass using these inputs. The _chunk_scan_fwd function initializes necessary variables and launches the _chunk_scan_fwd_kernel with triton's grid. The ChunkScanFn class manages the forward pass by preparing inputs, calling _chunk_scan_fwd, and saving necessary tensors for the backward pass.", - "description_2": "Implement a Triton kernel function that executes chunk-wise scanning for matrix operations. Utilize triton.autotune to optimize block sizes and handle meta-parameters for conditional logic. In Python, wrap this kernel in a function that prepares inputs, manages memory, and invokes the Triton kernel with appropriate grid dimensions.", - "difficulty": 4 - }, - { - "code": "import math\nimport torch\nimport triton\nimport triton.language as tl\nfrom mamba_ssm.ops.triton.softplus import softplus\n\n@triton.jit\ndef _chunk_cumsum_fwd_kernel(\n dt_ptr, A_ptr, dt_bias_ptr, dt_out_ptr, dA_cumsum_ptr,\n batch, seqlen, nheads, chunk_size,\n dt_min, dt_max,\n stride_dt_batch, stride_dt_seqlen, stride_dt_head,\n stride_A_head,\n stride_dt_bias_head,\n stride_dt_out_batch, stride_dt_out_chunk, stride_dt_out_head, stride_dt_out_csize,\n stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize,\n DT_SOFTPLUS: tl.constexpr,\n HAS_DT_BIAS: tl.constexpr,\n BLOCK_SIZE_H: tl.constexpr, BLOCK_SIZE_CHUNK: tl.constexpr,\n):\n pid_b = tl.program_id(axis=0)\n pid_c = tl.program_id(axis=1)\n pid_h = tl.program_id(axis=2)\n dt_ptr += pid_b * stride_dt_batch + pid_c * chunk_size * stride_dt_seqlen\n dt_out_ptr += pid_b * stride_dt_out_batch + pid_c * stride_dt_out_chunk\n dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk\n\n offs_h = pid_h * BLOCK_SIZE_H + tl.arange(0, BLOCK_SIZE_H)\n offs_c = tl.arange(0, BLOCK_SIZE_CHUNK)\n dt_ptrs = dt_ptr + (offs_h[:, None] * stride_dt_head + offs_c[None, :] * stride_dt_seqlen)\n A_ptrs = A_ptr + offs_h * stride_A_head\n dt_out_ptrs = dt_out_ptr + (offs_h[:, None] * stride_dt_out_head + offs_c[None, :] * stride_dt_out_csize)\n dA_cs_ptrs = dA_cumsum_ptr + (offs_h[:, None] * stride_dA_cs_head + offs_c[None, :] * stride_dA_cs_csize)\n chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)\n\n dt = tl.load(dt_ptrs, mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit), other=0.0).to(tl.float32)\n if HAS_DT_BIAS:\n dt_bias = tl.load(dt_bias_ptr + offs_h * stride_dt_bias_head, mask=offs_h < nheads, other=0.0).to(tl.float32)\n dt += dt_bias[:, None]\n if DT_SOFTPLUS:\n dt = softplus(dt)\n dt = tl.minimum(tl.maximum(dt, dt_min), dt_max)\n dt = tl.where((offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit), dt, 0.0)\n tl.store(dt_out_ptrs, dt, mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size))\n A = tl.load(A_ptrs, mask=offs_h < nheads, other=0.0).to(tl.float32)\n dA = dt * A[:, None]\n dA_cs = tl.cumsum(dA, axis=1)\n tl.store(dA_cs_ptrs, dA_cs, mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size))\n\n@triton.jit\ndef _chunk_cumsum_bwd_kernel(\n ddA_ptr, ddt_out_ptr, dt_ptr, A_ptr, dt_bias_ptr,\n ddt_ptr, dA_ptr, ddt_bias_ptr,\n batch, seqlen, nheads, chunk_size,\n dt_min, dt_max,\n stride_ddA_batch, stride_ddA_chunk, stride_ddA_head, stride_ddA_csize,\n stride_ddt_out_batch, stride_ddt_out_chunk, stride_ddt_out_head, stride_ddt_out_csize,\n stride_dt_batch, stride_dt_seqlen, stride_dt_head,\n stride_A_head,\n stride_dt_bias_head,\n stride_ddt_batch, stride_ddt_seqlen, stride_ddt_head,\n stride_dA_head,\n stride_ddt_bias_head,\n DT_SOFTPLUS: tl.constexpr,\n HAS_DT_BIAS: tl.constexpr,\n BLOCK_SIZE_H: tl.constexpr, BLOCK_SIZE_CHUNK: tl.constexpr,\n):\n pid_b = tl.program_id(axis=0)\n pid_c = tl.program_id(axis=1)\n pid_h = tl.program_id(axis=2)\n ddt_out_ptr += pid_b * stride_ddt_out_batch + pid_c * stride_ddt_out_chunk\n ddA_ptr += pid_b * stride_ddA_batch + pid_c * stride_ddA_chunk\n dt_ptr += pid_b * stride_dt_batch + pid_c * chunk_size * stride_dt_seqlen\n ddt_ptr += pid_b * stride_ddt_batch + pid_c * chunk_size * stride_ddt_seqlen\n\n offs_h = pid_h * BLOCK_SIZE_H + tl.arange(0, BLOCK_SIZE_H)\n offs_c = tl.arange(0, BLOCK_SIZE_CHUNK)\n ddt_out_ptrs = ddt_out_ptr + (offs_h[:, None] * stride_ddt_out_head + offs_c[None, :] * stride_ddt_out_csize)\n ddA_ptrs = ddA_ptr + (offs_h[:, None] * stride_ddA_head + offs_c[None, :] * stride_ddA_csize)\n dt_ptrs = dt_ptr + (offs_h[:, None] * stride_dt_head + offs_c[None, :] * stride_dt_seqlen)\n ddt_ptrs = ddt_ptr + (offs_h[:, None] * stride_ddt_head + offs_c[None, :] * stride_ddt_seqlen)\n A_ptrs = A_ptr + offs_h * stride_A_head\n chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)\n\n ddA = tl.load(ddA_ptrs, mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit), other=0.0).to(tl.float32)\n ddt_out = tl.load(ddt_out_ptrs, mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit), other=0.0).to(tl.float32)\n A = tl.load(A_ptrs, mask=offs_h < nheads, other=0.0).to(tl.float32)\n ddt = ddA * A[:, None] + ddt_out\n dt = tl.load(dt_ptrs, mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit), other=0.0).to(tl.float32)\n if HAS_DT_BIAS:\n dt_bias = tl.load(dt_bias_ptr + offs_h * stride_dt_bias_head, mask=offs_h < nheads, other=0.0).to(tl.float32)\n dt += dt_bias[:, None]\n if DT_SOFTPLUS:\n dt_presoftplus = dt\n dt = softplus(dt)\n clamp_mask = (dt < dt_min) | (dt > dt_max)\n dt = tl.minimum(tl.maximum(dt, dt_min), dt_max)\n dt = tl.where((offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit), dt, 0.0)\n ddt = tl.where((offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit), ddt, 0.0)\n ddt = tl.where(clamp_mask, 0.0, ddt)\n if DT_SOFTPLUS:\n ddt = tl.where(dt_presoftplus <= 20.0, ddt * tl.sigmoid(dt_presoftplus), ddt)\n tl.store(ddt_ptrs, ddt, mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit))\n dA = tl.sum(ddA * dt, axis=1)\n tl.atomic_add(dA_ptr + offs_h * stride_dA_head, dA, mask=offs_h < nheads)\n if HAS_DT_BIAS:\n ddt_bias = tl.sum(ddt, axis=1)\n tl.atomic_add(ddt_bias_ptr + offs_h * stride_ddt_bias_head, ddt_bias, mask=offs_h < nheads)\n\ndef _chunk_cumsum_fwd(dt, A, chunk_size, dt_bias=None, dt_softplus=False, dt_limit=(0.0, float(\"inf\"))):\n batch, seqlen, nheads = dt.shape\n assert A.shape == (nheads,)\n if dt_bias is not None:\n assert dt_bias.shape == (nheads,)\n nchunks = math.ceil(seqlen / chunk_size)\n dt_out = torch.empty(batch, nheads, nchunks, chunk_size, device=dt.device, dtype=torch.float32)\n dA_cumsum = torch.empty(batch, nheads, nchunks, chunk_size, device=dt.device, dtype=torch.float32)\n grid_chunk_cs = lambda META: (batch, nchunks, triton.cdiv(nheads, META['BLOCK_SIZE_H']))\n with torch.cuda.device(dt.device.index):\n _chunk_cumsum_fwd_kernel[grid_chunk_cs](\n dt, A, dt_bias, dt_out, dA_cumsum,\n batch, seqlen, nheads, chunk_size,\n dt_limit[0], dt_limit[1],\n dt.stride(0), dt.stride(1), dt.stride(2),\n A.stride(0),\n dt_bias.stride(0) if dt_bias is not None else 0,\n dt_out.stride(0), dt_out.stride(2), dt_out.stride(1), dt_out.stride(3),\n dA_cumsum.stride(0), dA_cumsum.stride(2), dA_cumsum.stride(1), dA_cumsum.stride(3),\n dt_softplus,\n HAS_DT_BIAS=dt_bias is not None,\n BLOCK_SIZE_CHUNK=triton.next_power_of_2(chunk_size),\n )\n return dA_cumsum, dt_out\n\ndef _chunk_cumsum_bwd(ddA, ddt_out, dt, A, dt_bias=None, dt_softplus=False, dt_limit=(0.0, float(\"inf\")), ddt=None):\n batch, seqlen, nheads = dt.shape\n _, _, nchunks, chunk_size = ddA.shape\n assert ddA.shape == (batch, nheads, nchunks, chunk_size)\n assert ddt_out.shape == (batch, nheads, nchunks, chunk_size)\n assert A.shape == (nheads,)\n if dt_bias is not None:\n assert dt_bias.shape == (nheads,)\n ddt_bias = torch.empty_like(dt_bias, dtype=torch.float32)\n else:\n ddt_bias = None\n if ddt is not None:\n assert ddt.shape == dt.shape\n else:\n ddt = torch.empty_like(dt)\n dA = torch.empty_like(A, dtype=torch.float32)\n grid_chunk_cs = lambda META: (batch, nchunks, triton.cdiv(nheads, META['BLOCK_SIZE_H']))\n with torch.cuda.device(dt.device.index):\n _chunk_cumsum_bwd_kernel[grid_chunk_cs](\n ddA, ddt_out, dt, A, dt_bias, ddt, dA, ddt_bias,\n batch, seqlen, nheads, chunk_size,\n dt_limit[0], dt_limit[1],\n ddA.stride(0), ddA.stride(2), ddA.stride(1), ddA.stride(3),\n ddt_out.stride(0), ddt_out.stride(2), ddt_out.stride(1), ddt_out.stride(3),\n dt.stride(0), dt.stride(1), dt.stride(2),\n A.stride(0),\n dt_bias.stride(0) if dt_bias is not None else 0,\n ddt.stride(0), ddt.stride(1), ddt.stride(2),\n dA.stride(0),\n ddt_bias.stride(0) if ddt_bias is not None else 0,\n dt_softplus,\n HAS_DT_BIAS=dt_bias is not None,\n BLOCK_SIZE_CHUNK=triton.next_power_of_2(chunk_size),\n )\n return ddt, dA, ddt_bias\n", - "description_1": "Use triton language to implement a forward and backward kernel for chunk-wise cumulative sum operations. The forward kernel (_chunk_cumsum_fwd_kernel) takes 20 parameters: pointers to input matrices, matrix dimensions, strides, and meta-parameters. It computes the cumulative sum of a matrix with optional bias and softplus activation, storing the result in output pointers. The backward kernel (_chunk_cumsum_bwd_kernel) takes 22 parameters: pointers to input matrices, matrix dimensions, strides, and meta-parameters. It computes gradients for the cumulative sum operation, storing the results in output pointers. The functions _chunk_cumsum_fwd and _chunk_cumsum_bwd are Python functions that call these kernels with appropriate grid configurations.", - "description_2": "Use triton language to create kernels for computing the forward and backward pass of a chunk-wise cumulative sum operation with optional bias and softplus activation. The forward kernel computes the cumulative sum and stores it, while the backward kernel computes the gradients for the operation.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n# Kernel function for chunk scan backward dx computation\n@triton.jit\ndef _chunk_scan_chunk_state_bwd_dx_kernel(\n x_ptr, cb_ptr, dout_ptr, dt_ptr, dA_cumsum_ptr, seq_idx_ptr, D_ptr,\n b_ptr, dstates_ptr, dx_ptr, ddt_ptr, dD_ptr,\n chunk_size, hdim, dstate,\n batch, seqlen, nheads_ngroups_ratio,\n stride_x_batch, stride_x_seqlen, stride_x_head, stride_x_hdim,\n stride_cb_batch, stride_cb_chunk, stride_cb_head, stride_cb_csize_m, stride_cb_csize_k,\n stride_dout_batch, stride_dout_seqlen, stride_dout_head, stride_dout_hdim,\n stride_dt_batch, stride_dt_chunk, stride_dt_head, stride_dt_csize,\n stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize,\n stride_seq_idx_batch, stride_seq_idx_seqlen,\n stride_D_head,\n stride_b_batch, stride_b_seqlen, stride_b_head, stride_b_dstate,\n stride_dstates_batch, stride_dstates_chunk, stride_dstates_head, stride_dstates_hdim, stride_dstates_dstate,\n stride_dx_batch, stride_dx_seqlen, stride_dx_head, stride_dx_hdim,\n stride_ddt_batch, stride_ddt_chunk, stride_ddt_head, stride_ddt_csize,\n stride_dD_batch, stride_dD_chunk, stride_dD_head, stride_dD_csize, stride_dD_hdim,\n HAS_D: tl.constexpr,\n D_HAS_HDIM: tl.constexpr,\n HAS_SEQ_IDX: tl.constexpr,\n BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,\n BLOCK_SIZE_DSTATE: tl.constexpr,\n IS_TRITON_22: tl.constexpr,\n):\n pid_bc = tl.program_id(axis=1)\n pid_c = pid_bc // batch\n pid_b = pid_bc - pid_c * batch\n pid_h = tl.program_id(axis=2)\n num_pid_n = tl.cdiv(hdim, BLOCK_SIZE_N)\n pid_m = tl.program_id(axis=0) // num_pid_n\n pid_n = tl.program_id(axis=0) % num_pid_n\n x_ptr += pid_b * stride_x_batch + pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head\n cb_ptr += pid_b * stride_cb_batch + pid_c * stride_cb_chunk + (pid_h // nheads_ngroups_ratio) * stride_cb_head\n dout_ptr += pid_b * stride_dout_batch + pid_c * chunk_size * stride_dout_seqlen + pid_h * stride_dout_head\n dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head\n ddt_ptr += pid_b * stride_ddt_batch + pid_c * stride_ddt_chunk + pid_h * stride_ddt_head\n dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head\n b_ptr += pid_b * stride_b_batch + pid_c * chunk_size * stride_b_seqlen + (pid_h // nheads_ngroups_ratio) * stride_b_head\n dstates_ptr += pid_b * stride_dstates_batch + pid_c * stride_dstates_chunk + pid_h * stride_dstates_head\n if HAS_SEQ_IDX:\n seq_idx_ptr += pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen\n\n offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n\n chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)\n\n acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n\n dA_cs_m = tl.load(dA_cumsum_ptr + offs_m * stride_dA_cs_csize, mask=offs_m < chunk_size_limit, other=0.0).to(tl.float32)\n\n dA_cs_last = tl.load(dA_cumsum_ptr + (chunk_size - 1) * stride_dA_cs_csize).to(tl.float32)\n if not HAS_SEQ_IDX:\n scale = tl.exp(dA_cs_last - dA_cs_m)\n else:\n seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen, mask=offs_m < chunk_size_limit, other=-1)\n seq_idx_last = tl.load(seq_idx_ptr + (chunk_size_limit - 1) * stride_seq_idx_seqlen)\n scale = tl.where(seq_idx_m == seq_idx_last, tl.exp(dA_cs_last - dA_cs_m), 0.0)\n\n offs_dstate = tl.arange(0, BLOCK_SIZE_DSTATE if IS_TRITON_22 and BLOCK_SIZE_DSTATE <= 128 else BLOCK_SIZE_K)\n b_ptrs = b_ptr + (offs_m[:, None] * stride_b_seqlen + offs_dstate[None, :] * stride_b_dstate)\n dstates_ptrs = dstates_ptr + (offs_n[None, :] * stride_dstates_hdim + offs_dstate[:, None] * stride_dstates_dstate)\n if IS_TRITON_22 and BLOCK_SIZE_DSTATE <= 128:\n b = tl.load(b_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_dstate[None, :] < dstate), other=0.0)\n dstates = tl.load(dstates_ptrs, mask=(offs_dstate[:, None] < dstate) & (offs_n[None, :] < hdim), other=0.0)\n dstates = dstates.to(b_ptr.dtype.element_ty)\n acc = tl.dot(b, dstates) * scale[:, None]\n else:\n for k in range(0, dstate, BLOCK_SIZE_K):\n b = tl.load(b_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_dstate[None, :] < dstate - k), other=0.0)\n dstates = tl.load(dstates_ptrs, mask=(offs_dstate[:, None] < dstate - k) & (offs_n[None, :] < hdim), other=0.0)\n dstates = dstates.to(b_ptr.dtype.element_ty)\n acc += tl.dot(b, dstates)\n b_ptrs += BLOCK_SIZE_K * stride_b_dstate\n dstates_ptrs += BLOCK_SIZE_K * stride_dstates_dstate\n acc *= scale[:, None]\n\n offs_k = tl.arange(0, BLOCK_SIZE_K)\n cb_ptrs = cb_ptr + (offs_m[:, None] * stride_cb_csize_m + offs_k[None, :] * stride_cb_csize_k)\n dout_ptrs = dout_ptr + (offs_k[:, None] * stride_dout_seqlen + offs_n[None, :] * stride_dout_hdim)\n dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize\n K_MAX = chunk_size_limit\n K_MIN = pid_m * BLOCK_SIZE_M\n cb_ptrs += K_MIN * stride_cb_csize_k\n dout_ptrs += K_MIN * stride_dout_seqlen\n dA_cumsum_ptrs += K_MIN * stride_dA_cs_csize\n for k in range(K_MIN, K_MAX, BLOCK_SIZE_K):\n k = tl.multiple_of(k, BLOCK_SIZE_K)\n cb = tl.load(cb_ptrs, mask=(offs_m[:, None] < chunk_size) & (offs_k[None, :] < K_MAX - k), other=0.0)\n dout = tl.load(dout_ptrs, mask=(offs_k[:, None] < K_MAX - k) & (offs_n[None, :] < hdim), other=0.0)\n dA_cs_k = tl.load(dA_cumsum_ptrs, mask=offs_k < K_MAX - k, other=0.0).to(tl.float32)\n cb *= tl.exp(dA_cs_k[None, :] - dA_cs_m[:, None])\n mask = (k + offs_k[None, :] >= offs_m[:, None]) & (k + offs_k[None, :] < K_MAX)\n cb = tl.where(mask, cb, 0.0)\n cb = cb.to(dout_ptr.dtype.element_ty)\n acc += tl.dot(cb, dout)\n cb_ptrs += BLOCK_SIZE_K * stride_cb_csize_k\n dout_ptrs += BLOCK_SIZE_K * stride_dout_seqlen\n dA_cumsum_ptrs += BLOCK_SIZE_K * stride_dA_cs_csize\n\n offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n dt_ptrs = dt_ptr + offs_m * stride_dt_csize\n dt_m = tl.load(dt_ptrs, mask=offs_m < chunk_size_limit, other=0.0).to(tl.float32)\n dx = acc * dt_m[:, None]\n dx_ptr += pid_b * stride_dx_batch + pid_c * chunk_size * stride_dx_seqlen + pid_h * stride_dx_head\n dx_ptrs = dx_ptr + (offs_m[:, None] * stride_dx_seqlen + offs_n[None, :] * stride_dx_hdim)\n if HAS_D:\n dout_res_ptrs = dout_ptr + (offs_m[:, None] * stride_dout_seqlen + offs_n[None, :] * stride_dout_hdim)\n dout_res = tl.load(dout_res_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0).to(tl.float32)\n if D_HAS_HDIM:\n D = tl.load(D_ptr + pid_h * stride_D_head + offs_n, mask=offs_n < hdim, other=0.0).to(tl.float32)\n else:\n D = tl.load(D_ptr + pid_h * stride_D_head).to(tl.float32)\n dx += dout_res * D\n tl.store(dx_ptrs, dx, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim))\n\n x_ptrs = x_ptr + (offs_m[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim)\n x = tl.load(x_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0).to(tl.float32)\n if HAS_D:\n dD_ptr += pid_b * stride_dD_batch + pid_c * stride_dD_chunk + pid_h * stride_dD_head + pid_m * stride_dD_csize\n if D_HAS_HDIM:\n dD_ptrs = dD_ptr + offs_n * stride_dD_hdim\n dD = tl.sum(dout_res * x, axis=0)\n tl.store(dD_ptrs, dD, mask=offs_n < hdim)\n else:\n dD = tl.sum(dout_res * x)\n tl.store(dD_ptr, dD)\n ddt = tl.sum(acc * x, axis=1)\n ddt_ptrs = ddt_ptr + offs_m * stride_ddt_csize\n tl.atomic_add(ddt_ptrs, ddt, mask=offs_m < chunk_size)\n\n# Function to call the above kernel function\ndef _chunk_scan_chunk_state_bwd_dx(x, dt, dA_cumsum, B, CB, dout, dstates, D=None, seq_idx=None, dx=None):\n batch, seqlen, nheads, headdim = x.shape\n _, _, nchunks, chunk_size = dt.shape\n _, _, ngroups, dstate = B.shape\n assert nheads % ngroups == 0\n assert B.shape == (batch, seqlen, ngroups, dstate)\n assert CB.shape == (batch, nchunks, ngroups, chunk_size, chunk_size)\n assert dt.shape == (batch, nheads, nchunks, chunk_size)\n assert dA_cumsum.shape == dt.shape\n assert dout.shape == x.shape\n assert dstates.shape == (batch, nchunks, nheads, headdim, dstate)\n if seq_idx is not None:\n assert seq_idx.shape == (batch, seqlen)\n if D is not None:\n assert D.shape == (nheads, headdim) or D.shape == (nheads,)\n assert D.stride(-1) == 1\n BLOCK_SIZE_min = 32\n dD = torch.empty(triton.cdiv(chunk_size, BLOCK_SIZE_min), batch, nchunks, nheads,\n headdim if D.dim() == 2 else 1, device=D.device, dtype=torch.float32)\n else:\n dD = None\n dD_strides = ((dD.stride(0), dD.stride(1), dD.stride(2), dD.stride(3), dD.stride(4))\n if D is not None else (0, 0, 0, 0, 0))\n if dx is None:\n dx = torch.empty_like(x)\n else:\n assert dx.shape == x.shape\n ddt = torch.empty(batch, nheads, nchunks, chunk_size, device=dout.device, dtype=torch.float32)\n grid_dx = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv(headdim, META['BLOCK_SIZE_N']),\n batch * nchunks, nheads)\n with torch.cuda.device(x.device.index):\n _chunk_scan_chunk_state_bwd_dx_kernel[grid_dx](\n x, CB, dout, dt, dA_cumsum, seq_idx, D, B, dstates, dx, ddt, dD,\n chunk_size, headdim, dstate,\n batch, seqlen, nheads // ngroups,\n x.stride(0), x.stride(1), x.stride(2), x.stride(3),\n CB.stride(0), CB.stride(1), CB.stride(2), CB.stride(-1), CB.stride(-2),\n dout.stride(0), dout.stride(1), dout.stride(2), dout.stride(3),\n dt.stride(0), dt.stride(2), dt.stride(1), dt.stride(3),\n dA_cumsum.stride(0), dA_cumsum.stride(2), dA_cumsum.stride(1), dA_cumsum.stride(3),\n *((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)),\n D.stride(0) if D is not None else 0,\n B.stride(0), B.stride(1), B.stride(2), B.stride(3),\n dstates.stride(0), dstates.stride(1), dstates.stride(2), dstates.stride(3), dstates.stride(4),\n dx.stride(0), dx.stride(1), dx.stride(2), dx.stride(3),\n ddt.stride(0), ddt.stride(2), ddt.stride(1), ddt.stride(3),\n dD_strides[1], dD_strides[2], dD_strides[3], dD_strides[0], dD_strides[4],\n D is not None,\n D.dim() == 2 if D is not None else True,\n HAS_SEQ_IDX=seq_idx is not None,\n BLOCK_SIZE_DSTATE=max(triton.next_power_of_2(dstate), 16),\n IS_TRITON_22=TRITON_22\n )\n if D is not None:\n BLOCK_SIZE_actual = _chunk_scan_chunk_state_bwd_dx_kernel.best_config.kwargs[\"BLOCK_SIZE_M\"]\n n_valid_blocks = (chunk_size + BLOCK_SIZE_actual - 1) // BLOCK_SIZE_actual\n dD = dD[:n_valid_blocks].sum(dim=(0, 1, 2)).to(dtype=D.dtype)\n if D.dim() == 1:\n dD = rearrange(dD, \"h 1 -> h\")\n return dx, ddt.to(dtype=dt.dtype), dD\n\n", - "description_1": "Use triton language to implement a kernel for backward pass computation in chunk scan operations. The kernel '_chunk_scan_chunk_state_bwd_dx_kernel' takes pointers to input tensors (such as x, cb, dout, dt, etc.) and performs the gradient calculations. The kernel handles matrix multiplication and accumulation for the specified dimensions and blocks. The associated Python function '_chunk_scan_chunk_state_bwd_dx' sets up the environment and calls the kernel with the necessary grid and block configurations.", - "description_2": "Use triton language to define a kernel for gradient computation in a backward pass. Call the kernel from a Python function with appropriate grid and block configurations.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.autotune(\n configs=[\n triton.Config({'BLOCK_SIZE': 64}),\n triton.Config({'BLOCK_SIZE': 128}),\n triton.Config({'BLOCK_SIZE': 256}),\n triton.Config({'BLOCK_SIZE': 512}),\n triton.Config({'BLOCK_SIZE': 1024}),\n triton.Config({'BLOCK_SIZE': 2048}),\n ],\n key=['dim'],\n)\n@triton.jit\ndef _state_passing_fwd_kernel(\n # Pointers to matrices\n states_ptr, out_ptr, final_states_ptr, dA_cs_ptr, initstates_ptr, seq_idx_ptr,\n # Matrix dimensions\n dim, nchunks, seqlen, chunk_size,\n # Strides\n stride_states_batch, stride_states_chunk, stride_states_head, stride_states_dim,\n stride_out_batch, stride_out_chunk, stride_out_head, stride_out_dim,\n stride_final_states_batch, stride_final_states_head, stride_final_states_dim,\n stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head,\n stride_initstates_batch, stride_initstates_head, stride_initstates_dim,\n stride_seq_idx_batch, stride_seq_idx_seqlen,\n # Meta-parameters\n HAS_INITSTATES: tl.constexpr,\n HAS_SEQ_IDX: tl.constexpr,\n BLOCK_SIZE: tl.constexpr,\n):\n pid_b = tl.program_id(axis=1)\n pid_h = tl.program_id(axis=2)\n pid_m = tl.program_id(axis=0)\n states_ptr += pid_b * stride_states_batch + pid_h * stride_states_head\n dA_cs_ptr += pid_b * stride_dA_cs_batch + pid_h * stride_dA_cs_head\n out_ptr += pid_b * stride_out_batch + pid_h * stride_out_head\n final_states_ptr += pid_b * stride_final_states_batch + pid_h * stride_final_states_head\n if HAS_INITSTATES:\n initstates_ptr += pid_b * stride_initstates_batch + pid_h * stride_initstates_head\n if HAS_SEQ_IDX:\n seq_idx_ptr += pid_b * stride_seq_idx_batch\n\n offs_m = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)\n states_ptrs = states_ptr + offs_m * stride_states_dim\n out_ptrs = out_ptr + offs_m * stride_out_dim\n final_states_ptrs = final_states_ptr + offs_m * stride_final_states_dim\n\n if not HAS_INITSTATES:\n states = tl.zeros((BLOCK_SIZE, ), dtype=tl.float32)\n else:\n initstates_ptrs = initstates_ptr + offs_m * stride_initstates_dim\n states = tl.load(initstates_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)\n tl.store(out_ptrs, states, mask=offs_m < dim)\n out_ptrs += stride_out_chunk\n seq_idx = 0\n for c in range(nchunks):\n new_states = tl.load(states_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)\n dA_cs = tl.load(dA_cs_ptr).to(tl.float32)\n scale = tl.exp(dA_cs)\n if HAS_SEQ_IDX:\n seq_idx_new = tl.load(seq_idx_ptr + (min((c + 1) * chunk_size, seqlen) - 1) * stride_seq_idx_seqlen)\n scale = tl.where(seq_idx_new == seq_idx, scale, 0.0)\n seq_idx = seq_idx_new\n states = scale * states + new_states\n if c < nchunks - 1:\n tl.store(out_ptrs, states, mask=offs_m < dim)\n else:\n tl.store(final_states_ptrs, states, mask=offs_m < dim)\n states_ptrs += stride_states_chunk\n dA_cs_ptr += stride_dA_cs_chunk\n out_ptrs += stride_out_chunk\n\n\n@triton.autotune(\n configs=[\n triton.Config({'BLOCK_SIZE': 64}),\n triton.Config({'BLOCK_SIZE': 128}),\n triton.Config({'BLOCK_SIZE': 256}),\n triton.Config({'BLOCK_SIZE': 512}),\n triton.Config({'BLOCK_SIZE': 1024}),\n triton.Config({'BLOCK_SIZE': 2048}),\n ],\n key=['dim'],\n)\n@triton.jit\ndef _state_passing_bwd_kernel(\n # Pointers to matrices\n dout_ptr, out_ptr, dA_cs_ptr, dfinal_states_ptr, seq_idx_ptr,\n dstates_ptr, ddA_cs_ptr, dinitstates_ptr, states_converted_ptr,\n # Matrix dimensions\n dim, nchunks, seqlen, chunk_size,\n # Strides\n stride_dout_batch, stride_dout_chunk, stride_dout_head, stride_dout_dim,\n stride_out_batch, stride_out_chunk, stride_out_head, stride_out_dim,\n stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head,\n stride_dfinal_states_batch, stride_dfinal_states_head, stride_dfinal_states_dim,\n stride_seq_idx_batch, stride_seq_idx_seqlen,\n stride_dstates_batch, stride_dstates_chunk, stride_dstates_head, stride_dstates_dim,\n stride_ddA_cs_batch, stride_ddA_cs_chunk, stride_ddA_cs_head,\n stride_dinitstates_batch, stride_dinitstates_head, stride_dinitstates_dim,\n # Meta-parameters\n CONVERT_STATES: tl.constexpr,\n HAS_DFINAL_STATES: tl.constexpr,\n HAS_DINITSTATES: tl.constexpr,\n HAS_SEQ_IDX: tl.constexpr,\n BLOCK_SIZE: tl.constexpr,\n):\n pid_b = tl.program_id(axis=1)\n pid_h = tl.program_id(axis=2)\n pid_m = tl.program_id(axis=0)\n dstates_ptr += pid_b * stride_dstates_batch + pid_h * stride_dstates_head + (nchunks - 1) * stride_dstates_chunk\n dA_cs_ptr += pid_b * stride_dA_cs_batch + pid_h * stride_dA_cs_head + (nchunks - 1) * stride_dA_cs_chunk\n ddA_cs_ptr += pid_b * stride_ddA_cs_batch + pid_h * stride_ddA_cs_head + (nchunks - 1) * stride_ddA_cs_chunk + pid_m\n out_ptr += pid_b * stride_out_batch + pid_h * stride_out_head + (nchunks - 1) * stride_out_chunk\n dout_ptr += pid_b * stride_dout_batch + pid_h * stride_dout_head + (nchunks - 1) * stride_dout_chunk\n if CONVERT_STATES:\n states_converted_ptr += pid_b * stride_out_batch + pid_h * stride_out_head + (nchunks - 1) * stride_out_chunk\n if HAS_DFINAL_STATES:\n dfinal_states_ptr += pid_b * stride_dfinal_states_batch + pid_h * stride_dfinal_states_head\n if HAS_DINITSTATES:\n dinitstates_ptr += pid_b * stride_dinitstates_batch + pid_h * stride_dinitstates_head\n if HAS_SEQ_IDX:\n seq_idx_ptr += pid_b * stride_seq_idx_batch\n\n offs_m = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)\n dstates_ptrs = dstates_ptr + offs_m * stride_dstates_dim\n out_ptrs = out_ptr + offs_m * stride_out_dim\n dout_ptrs = dout_ptr + offs_m * stride_dout_dim\n if CONVERT_STATES:\n states_converted_ptrs = states_converted_ptr + offs_m * stride_out_dim\n\n if HAS_DFINAL_STATES:\n dstates = tl.load(dfinal_states_ptr + offs_m * stride_dfinal_states_dim, mask=offs_m < dim, other=0.0).to(tl.float32)\n else:\n dstates = tl.zeros((BLOCK_SIZE, ), dtype=tl.float32)\n tl.store(dstates_ptrs, dstates, mask=offs_m < dim)\n if HAS_SEQ_IDX:\n seq_idx = tl.load(seq_idx_ptr + (seqlen - 1) * stride_seq_idx_seqlen)\n dstates_ptrs -= stride_dstates_chunk\n for c in range(nchunks - 1):\n dA_cs = tl.load(dA_cs_ptr).to(tl.float32)\n scale = tl.exp(dA_cs)\n if HAS_SEQ_IDX:\n seq_idx_new = tl.load(seq_idx_ptr + (((nchunks - c - 1) * chunk_size - 1) * stride_seq_idx_seqlen))\n scale = tl.where(seq_idx_new == seq_idx, scale, 0.0)\n seq_idx = seq_idx_new\n out = tl.load(out_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)\n if CONVERT_STATES:\n tl.store(states_converted_ptrs, out, mask=offs_m < dim)\n ddA = tl.sum(out * dstates) * scale\n tl.store(ddA_cs_ptr, ddA)\n dout = tl.load(dout_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)\n dstates = scale * dstates + dout\n tl.store(dstates_ptrs, dstates, mask=offs_m < dim)\n dout_ptrs -= stride_dout_chunk\n dstates_ptrs -= stride_dstates_chunk\n dA_cs_ptr -= stride_dA_cs_chunk\n ddA_cs_ptr -= stride_ddA_cs_chunk\n out_ptrs -= stride_out_chunk\n if CONVERT_STATES:\n states_converted_ptrs -= stride_out_chunk\n if CONVERT_STATES:\n out = tl.load(out_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)\n tl.store(states_converted_ptrs, out, mask=offs_m < dim)\n if not HAS_DINITSTATES:\n tl.store(ddA_cs_ptr, 0.0)\n else:\n dA_cs = tl.load(dA_cs_ptr).to(tl.float32)\n scale = tl.exp(dA_cs)\n if HAS_SEQ_IDX:\n scale = tl.where(seq_idx == 0, scale, 0.0)\n out = tl.load(out_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)\n ddA = tl.sum(out * dstates) * scale\n tl.store(ddA_cs_ptr, ddA)\n dout = tl.load(dout_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)\n dstates = scale * dstates + dout\n tl.store(dinitstates_ptr + offs_m * stride_dinitstates_dim, dstates, mask=offs_m < dim)\n\n\ndef _state_passing_fwd(states, dA_chunk_cumsum, initial_states=None, seq_idx=None, chunk_size=None,\n out_dtype=None):\n batch, nchunks, nheads, dim = states.shape\n assert dA_chunk_cumsum.shape == (batch, nheads, nchunks)\n if initial_states is not None:\n assert initial_states.shape == (batch, nheads, dim)\n if seq_idx is not None:\n assert chunk_size is not None\n seqlen = seq_idx.shape[-1]\n assert seq_idx.shape == (batch, seqlen)\n out_dtype = states.dtype if out_dtype is None else out_dtype\n out = torch.empty((batch, nchunks, nheads, dim), device=states.device, dtype=out_dtype)\n final_states = torch.empty((batch, nheads, dim), device=states.device, dtype=torch.float32)\n grid = lambda META: (triton.cdiv(dim, META['BLOCK_SIZE']), batch, nheads)\n with torch.cuda.device(states.device.index):\n _state_passing_fwd_kernel[grid](\n states, out, final_states, dA_chunk_cumsum, initial_states, seq_idx,\n dim, nchunks, seqlen if seq_idx is not None else 0, chunk_size if seq_idx is not None else 0,\n states.stride(0), states.stride(1), states.stride(2), states.stride(3),\n out.stride(0), out.stride(1), out.stride(2), out.stride(3),\n final_states.stride(0), final_states.stride(1), final_states.stride(2),\n dA_chunk_cumsum.stride(0), dA_chunk_cumsum.stride(2), dA_chunk_cumsum.stride(1),\n *((initial_states.stride(0), initial_states.stride(1), initial_states.stride(2))\n if initial_states is not None else (0, 0, 0)),\n *((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)),\n HAS_INITSTATES=initial_states is not None,\n HAS_SEQ_IDX=seq_idx is not None,\n )\n return out, final_states\n\n\ndef _state_passing_bwd(\n states, dA_chunk_cumsum, dout, dfinal_states=None, seq_idx=None, has_initial_states=None,\n dstates_dtype=None, states_dtype=None, chunk_size=None\n):\n batch, nchunks, nheads, dim = states.shape\n assert dA_chunk_cumsum.shape == (batch, nheads, nchunks)\n assert dout.shape == (batch, nchunks, nheads, dim)\n if seq_idx is not None:\n assert chunk_size is not None\n seqlen = seq_idx.shape[-1]\n assert seq_idx.shape == (batch, seqlen)\n dstates = torch.empty_like(dout, dtype=dstates_dtype if dstates_dtype is not None else dout.dtype)\n if states_dtype is not None and states_dtype != states.dtype:\n states_converted = torch.empty_like(states, dtype=dstates_dtype if dstates_dtype is not None else dout.dtype)\n assert states_converted.stride() == states.stride()\n else:\n states_converted = None\n if has_initial_states:\n dinitstates = torch.empty_like(dstates[:, 0])\n else:\n dinitstates = None\n if dfinal_states is not None:\n assert dfinal_states.shape == (batch, nheads, dim)\n BLOCK_SIZE_min = 64\n n_blocks = (dim + BLOCK_SIZE_min - 1) // BLOCK_SIZE_min\n ddA_chunk_cumsum = torch.empty(batch, nheads, nchunks, n_blocks,\n dtype=torch.float32, device=dA_chunk_cumsum.device)\n grid = lambda META: (triton.cdiv(dim, META['BLOCK_SIZE']), batch, nheads)\n with torch.cuda.device(dout.device.index):\n _state_passing_bwd_kernel[grid](\n dout, states, dA_chunk_cumsum, dfinal_states, seq_idx,\n dstates, ddA_chunk_cumsum, dinitstates, states_converted,\n dim, nchunks, seqlen if seq_idx is not None else 0, chunk_size if seq_idx is not None else 0,\n dout.stride(0), dout.stride(1), dout.stride(2), dout.stride(3),\n states.stride(0), states.stride(1), states.stride(2), states.stride(3),\n dA_chunk_cumsum.stride(0), dA_chunk_cumsum.stride(2), dA_chunk_cumsum.stride(1),\n *((dfinal_states.stride(0), dfinal_states.stride(1), dfinal_states.stride(2))\n if dfinal_states is not None else (0, 0, 0)),\n *((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)),\n dstates.stride(0), dstates.stride(1), dstates.stride(2), dstates.stride(3),\n ddA_chunk_cumsum.stride(0), ddA_chunk_cumsum.stride(2), ddA_chunk_cumsum.stride(1),\n *((dinitstates.stride(0), dinitstates.stride(1), dinitstates.stride(2))\n if dinitstates is not None else (0, 0, 0)),\n CONVERT_STATES=states_converted is not None,\n HAS_DFINAL_STATES=dfinal_states is not None,\n HAS_DINITSTATES=dinitstates is not None,\n HAS_SEQ_IDX=seq_idx is not None,\n )\n BLOCK_SIZE_actual = _state_passing_bwd_kernel.best_config.kwargs[\"BLOCK_SIZE\"]\n n_valid_blocks = (dim + BLOCK_SIZE_actual - 1) // BLOCK_SIZE_actual\n ddA_chunk_cumsum = ddA_chunk_cumsum[..., :n_valid_blocks].sum(dim=-1).to(dtype=dA_chunk_cumsum.dtype)\n if states_dtype is not None and states_dtype == states.dtype:\n states_converted = states\n return (dstates, ddA_chunk_cumsum, dinitstates) if states_dtype is None else (dstates, ddA_chunk_cumsum, dinitstates, states_converted)\n", - "description_1": "Use triton language to implement forward and backward kernels for state passing. The forward kernel _state_passing_fwd_kernel takes 33 parameters including pointers to matrices, matrix dimensions, strides, and meta-parameters, handling the operations of loading, scaling, and storing states iteratively over chunks. The backward kernel _state_passing_bwd_kernel, with 39 parameters, processes gradients similarly, computing updates to the gradients based on the stored output and scaling factors.", - "description_2": "Use triton language to create kernels for state passing that load, process, and store states and gradients using configurable block sizes and meta-parameters for efficient computation.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\nfrom .triton_utils.kernels import silu\n\n@triton.jit\ndef quant_fused_matmul_248_kernel(\n a_ptr,\n c_ptr,\n b1_ptr,\n scales1_ptr,\n zeros1_ptr,\n g1_ptr,\n b2_ptr,\n scales2_ptr,\n zeros2_ptr,\n g2_ptr,\n M,\n N,\n K,\n bits,\n maxq,\n stride_am,\n stride_ak,\n stride_bk,\n stride_bn,\n stride_cm,\n stride_cn,\n stride_scales,\n stride_zeros,\n BLOCK_SIZE_M: tl.constexpr,\n BLOCK_SIZE_N: tl.constexpr,\n BLOCK_SIZE_K: tl.constexpr,\n GROUP_SIZE_M: tl.constexpr,\n):\n \"\"\"\n Computes: C = silu(A * B1) * (A * B2)\n A is of shape (M, K) float16\n B is of shape (K//8, N) int32\n C is of shape (M, N) float16\n scales is of shape (1, N) float16\n zeros is of shape (1, N//8) int32\n \"\"\"\n infearure_per_bits = 32 // bits\n\n pid = tl.program_id(axis=0)\n num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)\n num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)\n num_pid_k = tl.cdiv(K, BLOCK_SIZE_K)\n num_pid_in_group = GROUP_SIZE_M * num_pid_n\n group_id = pid // num_pid_in_group\n first_pid_m = group_id * GROUP_SIZE_M\n group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)\n pid_m = first_pid_m + (pid % group_size_m)\n pid_n = (pid % num_pid_in_group) // group_size_m\n\n offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n offs_k = tl.arange(0, BLOCK_SIZE_K)\n a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) # (BLOCK_SIZE_M, BLOCK_SIZE_K)\n a_mask = offs_am[:, None] < M\n b1_ptrs = b1_ptr + ((offs_k[:, None] // infearure_per_bits) * stride_bk + offs_bn[None, :] * stride_bn)\n b2_ptrs = b2_ptr + ((offs_k[:, None] // infearure_per_bits) * stride_bk + offs_bn[None, :] * stride_bn)\n g1_ptrs = g1_ptr + offs_k\n g2_ptrs = g2_ptr + offs_k\n scales1_ptrs = scales1_ptr + offs_bn[None, :]\n scales2_ptrs = scales2_ptr + offs_bn[None, :]\n zeros1_ptrs = zeros1_ptr + (offs_bn[None, :] // infearure_per_bits)\n zeros2_ptrs = zeros2_ptr + (offs_bn[None, :] // infearure_per_bits)\n\n shifter = (offs_k % infearure_per_bits) * bits\n zeros_shifter = (offs_bn % infearure_per_bits) * bits\n accumulator1 = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n accumulator2 = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n for k in range(0, num_pid_k):\n g1_idx = tl.load(g1_ptrs)\n g2_idx = tl.load(g2_ptrs)\n\n scales1 = tl.load(scales1_ptrs + g1_idx[:, None] * stride_scales)\n scales2 = tl.load(scales2_ptrs + g2_idx[:, None] * stride_scales)\n\n zeros1 = tl.load(zeros1_ptrs + g1_idx[:, None] * stride_zeros)\n zeros1 = (zeros1 >> zeros_shifter[None, :]) & maxq\n zeros1 = zeros1 + 1\n\n zeros2 = tl.load(zeros2_ptrs + g2_idx[:, None] * stride_zeros)\n zeros2 = (zeros2 >> zeros_shifter[None, :]) & maxq\n zeros2 = zeros2 + 1\n\n a = tl.load(a_ptrs, mask=a_mask, other=0.0)\n b1 = tl.load(b1_ptrs)\n b2 = tl.load(b2_ptrs)\n\n b1 = (b1 >> shifter[:, None]) & maxq\n b1 = (b1 - zeros1) * scales1\n accumulator1 += tl.dot(a, b1)\n\n b2 = (b2 >> shifter[:, None]) & maxq\n b2 = (b2 - zeros2) * scales2\n accumulator2 += tl.dot(a, b2)\n\n a_ptrs += BLOCK_SIZE_K\n b1_ptrs += (BLOCK_SIZE_K // infearure_per_bits) * stride_bk\n b2_ptrs += (BLOCK_SIZE_K // infearure_per_bits) * stride_bk\n g1_ptrs += BLOCK_SIZE_K\n g2_ptrs += BLOCK_SIZE_K\n\n accumulator1 = silu(accumulator1)\n c = accumulator1 * accumulator2\n c = c.to(tl.float16)\n c_ptrs = c_ptr + stride_cm * offs_am[:, None] + stride_cn * offs_bn[None, :]\n c_mask = (offs_am[:, None] < M) & (offs_bn[None, :] < N)\n tl.store(c_ptrs, c, mask=c_mask)\n\n\nclass FusedLlamaMLPForQuantizedModel:\n def __init__(self, gate_proj, down_proj, up_proj):\n self.infeatures = gate_proj.infeatures\n self.intermediate_size = gate_proj.outfeatures\n self.outfeatures = down_proj.outfeatures\n self.bits = gate_proj.bits\n self.maxq = gate_proj.maxq\n\n self.gate_proj = gate_proj\n self.up_proj = up_proj\n self.down_proj = down_proj\n\n def triton_llama_mlp(self, x):\n with torch.cuda.device(x.device):\n out_shape = x.shape[:-1] + (self.intermediate_size,)\n x = x.reshape(-1, x.shape[-1])\n M, K = x.shape\n N = self.intermediate_size\n c = torch.empty((M, N), device=x.device, dtype=torch.float16)\n grid = lambda META: (\n triton.cdiv(M, META[\"BLOCK_SIZE_M\"]) * triton.cdiv(N, META[\"BLOCK_SIZE_N\"]),\n )\n quant_fused_matmul_248_kernel[grid](\n x,\n c,\n self.gate_proj.qweight,\n self.gate_proj.scales,\n self.gate_proj.qzeros,\n self.gate_proj.g_idx,\n self.up_proj.qweight,\n self.up_proj.scales,\n self.up_proj.qzeros,\n self.up_proj.g_idx,\n M,\n N,\n K,\n self.bits,\n self.maxq,\n x.stride(0),\n x.stride(1),\n self.gate_proj.qweight.stride(0),\n self.gate_proj.qweight.stride(1),\n c.stride(0),\n c.stride(1),\n self.gate_proj.scales.stride(0),\n self.gate_proj.qzeros.stride(0),\n )\n c = c.reshape(out_shape)\n return c\n", - "description_1": "Use triton language to define a kernel 'quant_fused_matmul_248_kernel' with 24 parameters for a fused matrix multiplication and SiLU activation. The kernel takes pointers to input matrices A and B, scales, zeros, group indices, matrix dimensions (M, N, K), bit information, max quantization levels, strides, and block/group sizes as parameters. The kernel calculates: C = silu(A * B1) * (A * B2), where A is (M, K) float16 and B1, B2 are quantized (K//8, N) int32 matrices. The calling function, 'triton_llama_mlp', takes an input tensor 'x', reshapes it, and uses Triton for efficient kernel execution, providing dimensions and tensor strides to match the kernel's requirements.", - "description_2": "Use triton language to implement a fused matrix multiplication with SiLU operation, including custom quantization using scales and zeros, defined by a kernel with extensive tuning configurations and called within a Python class method.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef quant_matmul_248_kernel(\n a_ptr, b_ptr, c_ptr, scales_ptr, zeros_ptr, g_ptr, \n M, N, K, bits, maxq, \n stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn, \n stride_scales, stride_zeros, \n BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, \n BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr):\n \"\"\"\n Compute the matrix multiplication C = A x B.\n A is of shape (M, K) float16\n B is of shape (K//8, N) int32\n C is of shape (M, N) float16\n scales is of shape (G, N) float16\n zeros is of shape (G, N) float16\n g_ptr is of shape (K) int32\n \"\"\"\n infearure_per_bits = 32 // bits\n\n pid = tl.program_id(axis=0)\n num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)\n num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)\n num_pid_k = tl.cdiv(K, BLOCK_SIZE_K)\n num_pid_in_group = GROUP_SIZE_M * num_pid_n\n group_id = pid // num_pid_in_group\n first_pid_m = group_id * GROUP_SIZE_M\n group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)\n pid_m = first_pid_m + (pid % group_size_m)\n pid_n = (pid % num_pid_in_group) // group_size_m\n\n offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n offs_k = tl.arange(0, BLOCK_SIZE_K)\n a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)\n a_mask = offs_am[:, None] < M\n b_ptrs = b_ptr + (\n (offs_k[:, None] // infearure_per_bits) * stride_bk + offs_bn[None, :] * stride_bn\n )\n g_ptrs = g_ptr + offs_k\n scales_ptrs = scales_ptr + offs_bn[None, :]\n zeros_ptrs = zeros_ptr + (offs_bn[None, :] // infearure_per_bits)\n\n shifter = (offs_k % infearure_per_bits) * bits\n zeros_shifter = (offs_bn % infearure_per_bits) * bits\n accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n\n for k in range(0, num_pid_k):\n g_idx = tl.load(g_ptrs)\n scales = tl.load(scales_ptrs + g_idx[:, None] * stride_scales)\n zeros = tl.load(zeros_ptrs + g_idx[:, None] * stride_zeros)\n zeros = (zeros >> zeros_shifter[None, :]) & maxq\n zeros = zeros + 1\n\n a = tl.load(a_ptrs, mask=a_mask, other=0.0)\n b = tl.load(b_ptrs)\n\n b = (b >> shifter[:, None]) & maxq\n b = (b - zeros) * scales\n\n accumulator += tl.dot(a, b)\n a_ptrs += BLOCK_SIZE_K\n b_ptrs += (BLOCK_SIZE_K // infearure_per_bits) * stride_bk\n g_ptrs += BLOCK_SIZE_K\n\n c_ptrs = c_ptr + stride_cm * offs_am[:, None] + stride_cn * offs_bn[None, :]\n c_mask = (offs_am[:, None] < M) & (offs_bn[None, :] < N)\n tl.store(c_ptrs, accumulator, mask=c_mask)\n\n\n@triton.jit\ndef transpose_quant_matmul_248_kernel(\n a_ptr, b_ptr, c_ptr, scales_ptr, zeros_ptr, g_ptr, \n M, N, K, bits, maxq, \n stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn, \n stride_scales, stride_zeros, \n BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, \n BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr):\n \"\"\"\n Compute the matrix multiplication C = A x B.\n A is of shape (M, N) float16\n B is of shape (K//8, N) int32\n C is of shape (M, K) float16\n scales is of shape (G, N) float16\n zeros is of shape (G, N) float16\n g_ptr is of shape (K) int32\n \"\"\"\n infearure_per_bits = 32 // bits\n\n pid = tl.program_id(axis=0)\n num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)\n num_pid_k = tl.cdiv(K, BLOCK_SIZE_K)\n num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)\n num_pid_in_group = GROUP_SIZE_M * num_pid_k\n group_id = pid // num_pid_in_group\n first_pid_m = group_id * GROUP_SIZE_M\n group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)\n pid_m = first_pid_m + (pid % group_size_m)\n pid_k = (pid % num_pid_in_group) // group_size_m\n\n offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_bk = pid_k * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)\n offs_n = tl.arange(0, BLOCK_SIZE_N)\n a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_n[None, :] * stride_ak)\n a_mask = offs_am[:, None] < M\n b_ptrs = b_ptr + (\n (offs_bk[:, None] // infearure_per_bits) * stride_bk + offs_n[None, :] * stride_bn\n )\n g_ptrs = g_ptr + offs_bk\n g_idx = tl.load(g_ptrs)\n\n scales_ptrs = scales_ptr + offs_n[None, :] + g_idx[:, None] * stride_scales\n zeros_ptrs = zeros_ptr + (offs_n[None, :] // infearure_per_bits) + g_idx[:, None] * stride_zeros\n\n shifter = (offs_bk % infearure_per_bits) * bits\n zeros_shifter = (offs_n % infearure_per_bits) * bits\n accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_K), dtype=tl.float32)\n\n for k in range(0, num_pid_n):\n scales = tl.load(scales_ptrs)\n zeros = tl.load(zeros_ptrs)\n zeros = (zeros >> zeros_shifter[None, :]) & maxq\n zeros = zeros + 1\n\n a = tl.load(a_ptrs, mask=a_mask, other=0.0)\n b = tl.load(b_ptrs)\n\n b = (b >> shifter[:, None]) & maxq\n b = (b - zeros) * scales\n b = tl.trans(b)\n\n accumulator += tl.dot(a, b)\n a_ptrs += BLOCK_SIZE_N\n b_ptrs += BLOCK_SIZE_N\n scales_ptrs += BLOCK_SIZE_N\n zeros_ptrs += BLOCK_SIZE_N // infearure_per_bits\n\n c_ptrs = c_ptr + stride_cm * offs_am[:, None] + stride_cn * offs_bk[None, :]\n c_mask = (offs_am[:, None] < M) & (offs_bk[None, :] < K)\n tl.store(c_ptrs, accumulator, mask=c_mask)\n\n\ndef quant_matmul_248(input, qweight, scales, qzeros, g_idx, bits, maxq):\n with torch.cuda.device(input.device):\n output = torch.empty((input.shape[0], qweight.shape[1]), device=input.device, dtype=input.dtype)\n grid = lambda META: (\n triton.cdiv(input.shape[0], META[\"BLOCK_SIZE_M\"]) * triton.cdiv(qweight.shape[1], META[\"BLOCK_SIZE_N\"]),\n )\n quant_matmul_248_kernel[grid](\n input, qweight, output, scales.to(input.dtype), qzeros, g_idx,\n input.shape[0], qweight.shape[1], input.shape[1], bits, maxq,\n input.stride(0), input.stride(1), qweight.stride(0), qweight.stride(1),\n output.stride(0), output.stride(1), scales.stride(0), qzeros.stride(0)\n )\n return output\n\n\ndef transpose_quant_matmul_248(input, qweight, scales, qzeros, g_idx, bits, maxq):\n with torch.cuda.device(input.device):\n output_dim = (qweight.shape[0] * 32) // bits\n output = torch.empty((input.shape[0], output_dim), device=input.device, dtype=input.dtype)\n grid = lambda META: (\n triton.cdiv(input.shape[0], META[\"BLOCK_SIZE_M\"]) * triton.cdiv(output_dim, META[\"BLOCK_SIZE_K\"]),\n )\n transpose_quant_matmul_248_kernel[grid](\n input, qweight, output, scales.to(input.dtype), qzeros, g_idx,\n input.shape[0], qweight.shape[1], output_dim, bits, maxq,\n input.stride(0), input.stride(1), qweight.stride(0), qweight.stride(1),\n output.stride(0), output.stride(1), scales.stride(0), qzeros.stride(0)\n )\n return output\n", - "description_1": "Use triton language to implement two kernel functions and their corresponding calling functions: 'quant_matmul_248_kernel' and 'transpose_quant_matmul_248_kernel'. Each kernel function takes 24 parameters, which include pointers to matrices, scalars, and stride information, along with compile-time constants for block sizes and group size. The kernels perform quantized matrix multiplication, supporting different shapes and data types for A, B, C, and the quantization factors. The corresponding Python functions, 'quant_matmul_248' and 'transpose_quant_matmul_248', facilitate the execution of these kernels, managing memory allocation and grid configuration.", - "description_2": "Use triton language to build efficient GPU kernels for quantized matrix multiplication with customizable block sizes and support for quantization parameters. Provide Python wrappers to set up and execute the kernels with proper memory management and grid launch parameters.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n# Triton CUDA kernel\n@triton.jit\ndef update_fn_kernel(\n p_ptr,\n grad_ptr,\n exp_avg_ptr,\n lr,\n wd,\n beta1,\n beta2,\n n_elements,\n BLOCK_SIZE: tl.constexpr,\n):\n pid = tl.program_id(axis=0)\n\n block_start = pid * BLOCK_SIZE\n offsets = block_start + tl.arange(0, BLOCK_SIZE)\n\n mask = offsets < n_elements\n\n # Offsetted pointers\n offset_p_ptr = p_ptr + offsets\n offset_grad_ptr = grad_ptr + offsets\n offset_exp_avg_ptr = exp_avg_ptr + offsets\n\n # Load\n p = tl.load(offset_p_ptr, mask=mask)\n grad = tl.load(offset_grad_ptr, mask=mask)\n exp_avg = tl.load(offset_exp_avg_ptr, mask=mask)\n\n # Stepweight decay\n p = p * (1 - lr * wd)\n\n # Diff between momentum running average and grad\n diff = exp_avg - grad\n\n # Weight update\n update = diff * beta1 + grad\n\n # Torch.sign\n can_update = update != 0\n update_sign = tl.where(update > 0, -lr, lr)\n\n p = p + update_sign * can_update\n\n # Decay the momentum running average coefficient\n exp_avg = diff * beta2 + grad\n\n # Store new params and momentum running average coefficient\n tl.store(offset_p_ptr, p, mask=mask)\n tl.store(offset_exp_avg_ptr, exp_avg, mask=mask)\n\ndef update_fn(\n p: torch.Tensor,\n grad: torch.Tensor,\n exp_avg: torch.Tensor,\n lr: float,\n wd: float,\n beta1: float,\n beta2: float\n):\n assert all([t.is_cuda for t in (p, grad, exp_avg)])\n n_elements = p.numel()\n\n grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)\n\n update_fn_kernel[grid](\n p,\n grad,\n exp_avg,\n lr,\n wd,\n beta1,\n beta2,\n n_elements\n )\n", - "description_1": "Use triton language to implement a CUDA kernel that updates parameters and momentum running averages for optimization. The kernel takes 8 parameters: pointers to parameter, gradient, and exponential average tensors, learning rate, weight decay, two beta coefficients, and the number of elements. It computes updates using stepweight decay and momentum, and stores the results back.", - "description_2": "Use triton language to create a CUDA kernel for parameter updates in optimization, involving stepweight decay and momentum calculations.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef get_kernel(\n x_ptr,\n index_ptr,\n y_ptr,\n BLOCK_SIZE: tl.constexpr\n):\n index_offsets = tl.arange(0, BLOCK_SIZE)\n index = tl.load(index_ptr + index_offsets)\n # Error on line 10\n x = tl.load(x_ptr + index[None] * BLOCK_SIZE + index[None, :])\n y = tl.store(y_ptr + index[:, None] * BLOCK_SIZE + index[None, :], x)\n\nBLOCK_SIZE = 128\nindex = torch.arange(BLOCK_SIZE, device='cuda', dtype=torch.long)\nx = torch.ones((BLOCK_SIZE, BLOCK_SIZE), device='cuda', dtype=torch.long)\ny = torch.zeros((BLOCK_SIZE, BLOCK_SIZE), device='cuda', dtype=torch.long)\n\nget_kernel[(1,)](x, index, y, BLOCK_SIZE)\nprint(y)\n", - "description_1": "Use triton language to define a kernel 'get_kernel' with four parameters: 'x_ptr' (pointer to input matrix x), 'index_ptr' (pointer to index array), 'y_ptr' (pointer to output matrix y), and 'BLOCK_SIZE' (constant expression for block size). The kernel calculates index offsets, loads indices from memory, loads the matrix x at calculated indices, and stores the result in matrix y. A torch script sets up the input and output matrices and calls the kernel function with appropriate parameters.", - "description_2": "Use triton language to load and store matrix elements at specified indices based on a block size.", - "difficulty": 4 - }, - { - "code": "import triton\nimport triton.language as tl\nimport torch\nfrom .utils import calculate_settings, MAX_FUSED_SIZE\n\n@triton.jit\ndef _cross_entropy_forward(\n logits_ptr, logits_row_stride,\n loss_ptr,\n logsumexp_ptr,\n labels_ptr,\n VOCAB_SIZE : tl.constexpr,\n BLOCK_SIZE : tl.constexpr,\n):\n \"\"\"\n Cross Entropy Loss = 1/n sum [ -yi log(Pi) ]\n \"\"\"\n row_idx = tl.program_id(0)\n logits_ptr += row_idx * logits_row_stride.to(tl.int64)\n loss_ptr += row_idx\n logsumexp_ptr += row_idx\n labels_ptr += row_idx\n\n col_offsets = tl.arange(0, BLOCK_SIZE)\n mask = col_offsets < VOCAB_SIZE\n\n label_idx = tl.load(labels_ptr).to(tl.int32)\n logits = tl.load(logits_ptr + col_offsets, mask = mask, other = -float(\"inf\")).to(tl.float32)\n c = tl.max(logits, 0)\n logsumexp = c + tl.log(tl.sum(tl.exp(logits - c), 0))\n\n if label_idx != -100:\n x = tl.load(logits_ptr + label_idx).to(tl.float32)\n loss = logsumexp - x\n else:\n loss = 0.0\n tl.store(logsumexp_ptr, logsumexp)\n tl.store(loss_ptr, loss)\npass\n\n@triton.jit\ndef _chunked_cross_entropy_forward(\n logits_ptr, logits_row_stride,\n loss_ptr,\n logsumexp_ptr,\n labels_ptr,\n VOCAB_SIZE : tl.constexpr,\n N_CHUNKS : tl.constexpr,\n BLOCK_SIZE : tl.constexpr,\n):\n \"\"\"\n 256K vocab divided in 4 chunks\n \"\"\"\n row_idx = tl.program_id(0)\n chunk_idx = tl.program_id(1)\n logits_ptr += row_idx * logits_row_stride.to(tl.int64)\n loss_ptr += row_idx\n logsumexp_ptr += row_idx * N_CHUNKS + chunk_idx\n labels_ptr += row_idx\n\n col_offsets = chunk_idx*BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)\n mask = col_offsets < VOCAB_SIZE\n\n label_idx = tl.load(labels_ptr).to(tl.int32)\n logits = tl.load(logits_ptr + col_offsets, mask = mask, other = -float(\"inf\")).to(tl.float32)\n c = tl.max(logits, 0)\n logsumexp = c + tl.log(tl.sum(tl.exp(logits - c), 0))\n\n if chunk_idx == 0:\n if label_idx != -100:\n x = tl.load(logits_ptr + label_idx).to(tl.float32)\n loss = -1.0 * x\n else:\n loss = 0.0\n tl.store(loss_ptr, loss)\n pass\n tl.store(logsumexp_ptr, logsumexp)\npass\n\n@triton.jit\ndef _cross_entropy_backward(\n logits_ptr, logits_row_stride,\n dloss_ptr, dloss_row_stride,\n logsumexp_ptr,\n labels_ptr,\n VOCAB_SIZE : tl.constexpr,\n BLOCK_SIZE : tl.constexpr,\n):\n \"\"\"\n CE_i = -y log(P) = y * (log[sum(exp(x))] - x)\n \"\"\"\n row_idx = tl.program_id(0)\n block_idx = tl.program_id(1)\n\n logits_ptr += row_idx * logits_row_stride.to(tl.int64)\n dloss_ptr += row_idx * dloss_row_stride\n col_offsets = block_idx*BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)\n mask = col_offsets < VOCAB_SIZE\n label_idx = tl.load(labels_ptr + row_idx).to(tl.int32)\n\n if label_idx != -100:\n dloss = tl.load(dloss_ptr)\n else:\n dloss = 0.0\n\n x = tl.load(logits_ptr + col_offsets, mask = mask, other = -float(\"inf\")).to(tl.float32)\n logsumexp = tl.load(logsumexp_ptr + row_idx)\n y = tl.exp(x - logsumexp)\n y = tl.where(\n col_offsets == label_idx,\n y - 1.0, # exp(x - logsumexp) - 1\n y, # exp(x - logsumexp)\n )\n\n tl.store(logits_ptr + col_offsets, dloss * y, mask = mask)\npass\n\nclass Fast_CrossEntropyLoss(torch.autograd.Function):\n @staticmethod\n def forward(ctx, logits, labels):\n n_rows, vocab_size = logits.shape\n\n div, mod = divmod(vocab_size, MAX_FUSED_SIZE)\n n_chunks = div + (mod != 0)\n losses = torch.empty(n_rows, dtype = torch.float32, device = \"cuda\")\n\n if n_chunks == 1:\n BLOCK_SIZE, num_warps = calculate_settings(vocab_size)\n logsumexp = torch.empty(n_rows, dtype = torch.float32, device = \"cuda\")\n\n _cross_entropy_forward[(n_rows,)](\n logits, logits.stride(0),\n losses,\n logsumexp,\n labels,\n VOCAB_SIZE = vocab_size,\n BLOCK_SIZE = BLOCK_SIZE,\n num_warps = num_warps,\n )\n else:\n logsumexp = torch.empty((n_rows, n_chunks,), dtype = torch.float32, device = \"cuda\")\n\n _chunked_cross_entropy_forward[(n_rows, n_chunks,)](\n logits, logits.stride(0),\n losses,\n logsumexp,\n labels,\n VOCAB_SIZE = vocab_size,\n N_CHUNKS = n_chunks,\n BLOCK_SIZE = MAX_FUSED_SIZE,\n num_warps = 32,\n )\n logsumexp = torch.logsumexp(logsumexp, dim = 1)\n losses += logsumexp\n losses.masked_fill_(labels == -100, 0)\n pass\n\n ctx.save_for_backward(logits, logsumexp, labels)\n return losses\n pass\n\n @staticmethod\n def backward(ctx, dlosses):\n logits, logsumexp, labels = ctx.saved_tensors\n n_rows, vocab_size = logits.shape\n\n BLOCK_SIZE = 4096\n div, mod = divmod(vocab_size, BLOCK_SIZE)\n n_blocks = div + (mod != 0)\n\n _cross_entropy_backward[(n_rows, n_blocks,)](\n logits, logits.stride(0),\n dlosses, dlosses.stride(0),\n logsumexp,\n labels,\n VOCAB_SIZE = vocab_size,\n BLOCK_SIZE = BLOCK_SIZE,\n num_warps = 8,\n )\n return logits, None, None,\n pass\npass\n\ndef fast_cross_entropy_loss(logits, labels):\n \"\"\"\n Arguments:\n logits: (batch, seq_len, vocab_size)\n labels: (batch, seq_len,)\n Returns:\n losses: float\n \"\"\"\n batch, seq_len, d = logits.shape\n assert(labels.shape == (batch, seq_len))\n\n loss = Fast_CrossEntropyLoss.apply(\n logits.view(batch*seq_len, d),\n labels.view(-1),\n )\n n_items = torch.count_nonzero(labels != -100)\n return loss.sum() / n_items\npass\n", - "description_1": "Use triton language to implement cross-entropy loss and its backward pass for a given set of logits and labels. The forward function computes the loss using either a single kernel for small vocabularies or a chunked approach for large vocabularies. The backward function computes the gradient of the loss with respect to the logits.", - "description_2": "Use triton language to implement a cross-entropy loss function with forward and backward passes, handling both small and large vocabularies efficiently.", - "difficulty": 4 - }, - { - "code": "import triton\nimport triton.language as tl\nimport torch\n\n@triton.jit\ndef _exact_forward_kernel(e, g, h, n_elements, BLOCK_SIZE: tl.constexpr):\n block_idx = tl.program_id(0)\n offsets = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)\n mask = offsets < n_elements\n\n # f = 1/2 * e * (1 + erf(1/sqrt(2) * e))\n # h = f * up\n e_row = tl.load(e + offsets, mask=mask, other=0).to(tl.float32)\n g_row = tl.load(g + offsets, mask=mask, other=0)\n\n f_row = 0.5 * e_row * (tl.math.erf(tl.math.rsqrt(2.0) * e_row) + 1.0)\n f_row = f_row.to(g_row.dtype) # Exact copy from HF\n h_row = f_row * g_row\n\n # Store h\n tl.store(h + offsets, h_row, mask=mask)\n\ndef geglu_exact_forward_kernel(gate, up):\n batch, seq_len, hd = gate.shape\n n_elements = gate.numel()\n out = torch.empty((batch, seq_len, hd), dtype=gate.dtype, device=\"cuda\")\n grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)\n _exact_forward_kernel[grid](gate, up, out, n_elements, BLOCK_SIZE=1024)\n return out\n\n@triton.jit\ndef _exact_backward_kernel(DW, e, g, n_elements, BLOCK_SIZE: tl.constexpr):\n block_idx = tl.program_id(0)\n offsets = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)\n mask = offsets < n_elements\n\n DW_row = tl.load(DW + offsets, mask=mask, other=0)\n e_row = tl.load(e + offsets, mask=mask, other=0).to(tl.float32)\n g_row = tl.load(g + offsets, mask=mask, other=0)\n\n f_partial_row = 0.5 * (tl.math.erf(tl.math.rsqrt(2.0) * e_row) + 1.0)\n f_row = f_partial_row * e_row\n f_row = f_row.to(DW_row.dtype)\n h_row = f_row * g_row\n df_row = DW_row * f_row\n dg_row = DW_row * g_row\n\n t = 0.3989422804014327 # 1/sqrt(2*pi)\n df_de = f_partial_row + t * e_row * tl.exp(-0.5 * e_row * e_row)\n\n de_row = dg_row.to(tl.float32) * df_de\n de_row = de_row.to(DW_row.dtype)\n\n tl.store(DW + offsets, h_row, mask=mask)\n tl.store(e + offsets, df_row, mask=mask)\n tl.store(g + offsets, de_row, mask=mask)\n\ndef geglu_exact_backward_kernel(DW, e, g):\n batch_seq_len, hd = e.shape\n n_elements = e.numel()\n grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)\n _exact_backward_kernel[grid](DW, e, g, n_elements, BLOCK_SIZE=1024)\n return DW, e, g\n\n@triton.jit\ndef _approx_forward_kernel(e, g, h, n_elements, BLOCK_SIZE: tl.constexpr):\n block_idx = tl.program_id(0)\n offsets = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)\n mask = offsets < n_elements\n\n s = 0.7978845608028654 # math.sqrt(2 / math.pi)\n e_row = tl.load(e + offsets, mask=mask, other=0).to(tl.float32)\n g_row = tl.load(g + offsets, mask=mask, other=0)\n\n f_row = 0.5 * e_row * (\n tl.math.tanh(s * e_row * (1.0 + 0.044715 * e_row * e_row)) + 1.0\n )\n f_row = f_row.to(g_row.dtype) # Exact copy from HF\n h_row = f_row * g_row\n\n tl.store(h + offsets, h_row, mask=mask)\n\ndef geglu_approx_forward_kernel(gate, up):\n batch, seq_len, hd = gate.shape\n n_elements = gate.numel()\n out = torch.empty((batch, seq_len, hd), dtype=gate.dtype, device=\"cuda\")\n grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)\n _approx_forward_kernel[grid](gate, up, out, n_elements, BLOCK_SIZE=1024)\n return out\n\n@triton.jit\ndef _approx_backward_kernel(DW, e, g, n_elements, BLOCK_SIZE: tl.constexpr):\n block_idx = tl.program_id(0)\n offsets = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)\n mask = offsets < n_elements\n\n DW_row = tl.load(DW + offsets, mask=mask, other=0)\n e_row = tl.load(e + offsets, mask=mask, other=0).to(tl.float32)\n g_row = tl.load(g + offsets, mask=mask, other=0)\n\n s = 0.7978845608028654 # math.sqrt(2 / math.pi)\n a = s * e_row\n b = a * 0.044715 * e_row * e_row\n T = 1.0 + tl.math.tanh(a + b)\n T2 = 0.5 * T\n Q2 = -T2 * (T - 2.0) * (a + 3.0 * b)\n df_de = T2 + Q2\n\n f_row = T2 * e_row\n f_row = f_row.to(DW_row.dtype)\n h_row = f_row * g_row\n df_row = DW_row * f_row\n dg_row = DW_row * g_row\n\n de_row = dg_row.to(tl.float32) * df_de\n de_row = de_row.to(DW_row.dtype)\n\n tl.store(DW + offsets, h_row, mask=mask)\n tl.store(e + offsets, df_row, mask=mask)\n tl.store(g + offsets, de_row, mask=mask)\n\ndef geglu_approx_backward_kernel(DW, e, g):\n batch_seq_len, hd = e.shape\n n_elements = e.numel()\n grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)\n _approx_backward_kernel[grid](DW, e, g, n_elements, BLOCK_SIZE=1024)\n return DW, e, g\n", - "description_1": "Use triton language to implement four kernels: _exact_forward_kernel, _exact_backward_kernel, _approx_forward_kernel, and _approx_backward_kernel. Each kernel processes data in blocks, using a BLOCK_SIZE parameter to determine the size of each block. The forward kernels (_exact_forward_kernel and _approx_forward_kernel) compute a transformation on input tensors e and g, storing the result in tensor h. The backward kernels (_exact_backward_kernel and _approx_backward_kernel) compute gradients for input tensors DW, e, and g. The functions geglu_exact_forward_kernel, geglu_exact_backward_kernel, geglu_approx_forward_kernel, and geglu_approx_backward_kernel are used to call these kernels with appropriate grid settings.", - "description_2": "Use triton language to create forward and backward kernels for exact and approximate transformations on input tensors, utilizing block processing with a specified BLOCK_SIZE.", - "difficulty": 2 - }, - { - "code": "import triton\nimport triton.language as tl\nimport torch\n\n@triton.jit\ndef _rms_layernorm_forward(\n Y, Y_row_stride,\n X, X_row_stride,\n W, W_row_stride,\n r, r_row_stride,\n n_cols, eps,\n BLOCK_SIZE : tl.constexpr\n):\n \"\"\"\n Fast RMS Layernorm kernel\n Inspiration from a Triton tutorial:\n https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html\n \"\"\"\n row_idx = tl.program_id(0)\n col_offsets = tl.arange(0, BLOCK_SIZE)\n mask = col_offsets < n_cols\n\n Y += row_idx * Y_row_stride\n X += row_idx * X_row_stride\n r += row_idx * r_row_stride\n\n X_row = tl.load(X + col_offsets, mask = mask, other = 0).to(tl.float32)\n W_row = tl.load(W + col_offsets, mask = mask, other = 0)#.to(tl.float32)\n\n row_var = tl.sum(X_row * X_row, axis = 0) / n_cols\n inv_var = tl.math.rsqrt(row_var + eps)\n tl.store(r, inv_var)\n normed = X_row * inv_var\n normed = normed.to(W_row.dtype) # Exact copy from HF\n output = normed * W_row\n tl.store(Y + col_offsets, output, mask = mask)\npass\n\n\n@triton.jit\ndef _gemma_rms_layernorm_forward(\n Y, Y_row_stride,\n X, X_row_stride,\n W, W_row_stride,\n r, r_row_stride,\n n_cols, eps,\n BLOCK_SIZE : tl.constexpr,\n):\n # Copies https://github.com/google-deepmind/gemma/blob/main/gemma/layers.py#L31\n # and https://github.com/keras-team/keras-nlp/blob/v0.8.2/keras_nlp/models/gemma/rms_normalization.py#L33\n # exactly. Essentially all in float32!\n row_idx = tl.program_id(0)\n col_offsets = tl.arange(0, BLOCK_SIZE)\n mask = col_offsets < n_cols\n\n Y += row_idx * Y_row_stride\n X += row_idx * X_row_stride\n r += row_idx * r_row_stride\n\n X_row = tl.load(X + col_offsets, mask = mask, other = 0).to(tl.float32)\n W_row = tl.load(W + col_offsets, mask = mask, other = 0).to(tl.float32)\n\n row_var = tl.sum(X_row * X_row, axis = 0) / n_cols\n inv_var = 1.0 / tl.sqrt(row_var + eps) # Must be 1/sqrt to match Deepmind's impl\n tl.store(r, inv_var)\n normed = X_row * inv_var\n output = normed * (W_row + 1.0)\n\n tl.store(Y + col_offsets, output, mask = mask)\npass\n\n\nclass Fast_RMS_Layernorm(torch.autograd.Function):\n @staticmethod\n def forward(ctx, X, W, eps, gemma = False):\n shape = X.shape\n dim = shape[-1]\n X = X.view(-1, dim)\n n_rows, n_cols = X.shape\n BLOCK_SIZE, num_warps = calculate_settings(n_cols)\n\n Y = torch.empty((n_rows, n_cols), dtype = X.dtype, device = \"cuda\")\n r = torch.empty(n_rows, dtype = torch.float32, device = \"cuda\")\n\n fx = _gemma_rms_layernorm_forward if gemma else _rms_layernorm_forward\n fx[(n_rows,)](\n Y, Y.stride(0),\n X, X.stride(0),\n W, W.stride(0),\n r, r.stride(0),\n n_cols, eps,\n BLOCK_SIZE = BLOCK_SIZE,\n num_warps = num_warps,\n )\n ctx.eps = eps\n ctx.BLOCK_SIZE = BLOCK_SIZE\n ctx.num_warps = num_warps\n ctx.GEMMA = gemma\n ctx.save_for_backward(X, W, r)\n return Y.view(*shape)\n pass\n\ndef fast_rms_layernorm(layernorm, X, gemma = False):\n W = layernorm.weight\n eps = layernorm.variance_epsilon\n out = Fast_RMS_Layernorm.apply(X, W, eps, gemma)\n return out\npass\n", - "description_1": "Use triton language to implement forward kernels for RMS Layernorm. The `_rms_layernorm_forward` kernel takes 10 arguments: output tensor Y, Y stride, input tensor X, X stride, weight tensor W, W stride, row variance tensor r, r stride, number of columns n_cols, epsilon eps, and block size BLOCK_SIZE. It performs layer normalization on the input tensor X using weights W and stores the result in Y. The `_gemma_rms_layernorm_forward` is a variant that uses a slightly different normalization approach, multiplying the output by (W + 1.0). The `Fast_RMS_Layernorm` class uses these kernels to provide a function that can be called with PyTorch tensors to compute RMS layernorm with optional GEMMA behavior.", - "description_2": "Use triton language to define two RMS layernorm forward kernels and a PyTorch autograd function to apply these kernels, providing RMS normalization functionality with optional GEMMA modifications.", - "difficulty": 4 - }, - { - "code": "import triton\nimport triton.language as tl\nimport torch\n\n# Triton kernel that calculates the element-wise product of e and g\n# after applying the sigmoid function to e and stores the result in h.\n@triton.jit\ndef _fg_kernel(e, g, h, n_elements, BLOCK_SIZE: tl.constexpr):\n block_idx = tl.program_id(0)\n offsets = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)\n mask = offsets < n_elements\n\n e_row = tl.load(e + offsets, mask=mask, other=0).to(tl.float32)\n g_row = tl.load(g + offsets, mask=mask, other=0)\n\n f_row = e_row * tl.sigmoid(e_row)\n f_row = f_row.to(g_row.dtype)\n h_row = f_row * g_row\n\n tl.store(h + offsets, h_row, mask=mask)\npass\n\n# Wrapper function that sets up and calls the _fg_kernel.\ndef swiglu_fg_kernel(e, g):\n batch, seq_len, hd = e.shape\n n_elements = e.numel()\n h = torch.empty((batch, seq_len, hd), dtype=e.dtype, device=\"cuda\")\n grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)\n _fg_kernel[grid](e, g, h, n_elements, BLOCK_SIZE=1024)\n return h\npass\n\n# Triton kernel for computing gradients for backpropagation.\n@triton.jit\ndef _DWf_DW_dfg_kernel(DW, e, g, n_elements, BLOCK_SIZE: tl.constexpr):\n block_idx = tl.program_id(0)\n offsets = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)\n mask = offsets < n_elements\n\n DW_row = tl.load(DW + offsets, mask=mask, other=0)\n e_row = tl.load(e + offsets, mask=mask, other=0).to(tl.float32)\n g_row = tl.load(g + offsets, mask=mask, other=0)\n\n se_row = tl.sigmoid(e_row)\n f_row = se_row * e_row\n f_row = f_row.to(DW_row.dtype)\n h_row = f_row * g_row\n df_row = DW_row * f_row\n dg_row = DW_row * g_row\n de_row = dg_row.to(tl.float32) * se_row * (1.0 + e_row * (1.0 - se_row))\n de_row = de_row.to(DW_row.dtype)\n\n tl.store(DW + offsets, h_row, mask=mask)\n tl.store(e + offsets, df_row, mask=mask)\n tl.store(g + offsets, de_row, mask=mask)\npass\n\n# Wrapper function that sets up and calls the _DWf_DW_dfg_kernel.\ndef swiglu_DWf_DW_dfg_kernel(DW, e, g):\n batch_seq_len, hd = e.shape\n n_elements = e.numel()\n grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)\n _DWf_DW_dfg_kernel[grid](DW, e, g, n_elements, BLOCK_SIZE=1024)\n return DW, e, g\npass\n", - "description_1": "Use triton language to implement two kernels: one (_fg_kernel) that computes an element-wise product of a tensor e with its sigmoid and another tensor g, storing the result in h; and another (_DWf_DW_dfg_kernel) that computes gradients for e, g, and DW during backpropagation. The former has four parameters: e, g, h, and n_elements, with a constant BLOCK_SIZE for parallel processing. The latter has similar parameters but computes the gradient of the product with respect to its inputs.", - "description_2": "Use triton language to define kernels for computing element-wise operations and gradients involving tensor products and sigmoid activation, suitable for neural network computations.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef linear_xent_fwd_prep_bwd_kernel_matmul_t(\n x_ptr,\n y_ptr,\n A_t_ptr,\n z_nv_ptr,\n losses_ptr,\n lse_ptr,\n stride_x_N,\n stride_x_H,\n stride_A_H,\n stride_A_V,\n stride_z_N,\n stride_z_V,\n stride_lse_N,\n stride_lse_B,\n stride_loss_Nb,\n stride_loss_B,\n idx_N_group,\n N_group: tl.constexpr,\n V: tl.constexpr,\n N: tl.constexpr,\n H: tl.constexpr,\n V_BLOCK_SIZE: tl.constexpr,\n N_BLOCK_SIZE: tl.constexpr,\n H_BLOCK_SIZE: tl.constexpr,\n V_TILES: tl.constexpr = 1,\n):\n idx_N = tl.program_id(axis=0)\n idx_V_group = tl.program_id(axis=1)\n\n V_GROUP_SIZE: tl.constexpr = V_TILES * V_BLOCK_SIZE\n x_block_ptr = tl.make_block_ptr(\n base=x_ptr,\n shape=(N, H),\n strides=(stride_x_N, stride_x_H),\n offsets=(idx_N_group * N_group + idx_N * N_BLOCK_SIZE, 0),\n block_shape=(N_BLOCK_SIZE, H_BLOCK_SIZE),\n order=(1, 0),\n )\n A_block_ptr = tl.make_block_ptr(\n base=A_t_ptr,\n shape=(H, V),\n strides=(stride_A_H, stride_A_V),\n offsets=(0, idx_V_group * V_GROUP_SIZE),\n block_shape=(H_BLOCK_SIZE, V_BLOCK_SIZE),\n order=(1, 0),\n )\n z_block_ptr = tl.make_block_ptr(\n base=z_nv_ptr,\n shape=(N_group, V),\n strides=(stride_z_N, stride_z_V),\n offsets=(idx_N * N_BLOCK_SIZE, idx_V_group * V_GROUP_SIZE),\n block_shape=(N_BLOCK_SIZE, V_BLOCK_SIZE),\n order=(1, 0),\n )\n lse_row_ptr = tl.make_block_ptr(\n base=lse_ptr,\n shape=(N, V // 64),\n strides=(stride_lse_N, stride_lse_B),\n offsets=(idx_N_group * N_group + idx_N * N_BLOCK_SIZE, idx_V_group),\n block_shape=(N_BLOCK_SIZE, 1),\n order=(1, 0),\n )\n loss_val_ptr = (\n losses_ptr + (idx_N + idx_N_group * N_group // N_BLOCK_SIZE) * stride_loss_Nb + idx_V_group * stride_loss_B\n )\n\n N_range = idx_N_group * N_group + idx_N * N_BLOCK_SIZE + tl.arange(0, N_BLOCK_SIZE)\n V_range = idx_V_group * V_GROUP_SIZE + tl.arange(0, V_BLOCK_SIZE)\n y = tl.load(y_ptr + N_range)\n\n m = tl.zeros((N_BLOCK_SIZE,), dtype=tl.float32) - float(10e6)\n s = tl.zeros((N_BLOCK_SIZE,), dtype=tl.float32)\n loss = 0.0\n\n for _ in range(V_TILES):\n z_j_to_k = tl.zeros((N_BLOCK_SIZE, V_BLOCK_SIZE), dtype=tl.float32)\n for _ in range(H // H_BLOCK_SIZE):\n x_chunk = tl.load(x_block_ptr)\n A_v = tl.load(A_block_ptr)\n\n z_j_to_k = tl.dot(x_chunk, A_v, z_j_to_k)\n\n x_block_ptr = tl.advance(x_block_ptr, [0, H_BLOCK_SIZE])\n A_block_ptr = tl.advance(A_block_ptr, [H_BLOCK_SIZE, 0])\n\n m_new = tl.maximum(m, tl.max(z_j_to_k, 1))\n\n s_update = tl.sum(tl.exp((z_j_to_k - m_new[:, None])), axis=1)\n s = s * tl.exp(m - m_new) + s_update\n\n mask = y[:, None] == V_range[None, :]\n loss -= tl.sum(tl.where(mask, z_j_to_k, float(0.0))) / N\n\n tl.store(z_block_ptr, z_j_to_k.to(z_nv_ptr.type.element_ty))\n\n m = m_new\n x_block_ptr = tl.advance(x_block_ptr, [0, -H])\n A_block_ptr = tl.advance(A_block_ptr, [-H, V_BLOCK_SIZE])\n z_block_ptr = tl.advance(z_block_ptr, [0, V_BLOCK_SIZE])\n V_range = V_range + V_BLOCK_SIZE\n\n lse = m + tl.log(s)\n tl.store(loss_val_ptr, loss)\n tl.store(lse_row_ptr, lse[:, None])\n\n@triton.jit\ndef linear_xent_mini_bwd_prologue_kernel(\n z_nv_ptr,\n y_ptr,\n lse_ptr,\n stride_z_N,\n stride_z_V,\n idx_N_group,\n N_group: tl.constexpr,\n V: tl.constexpr,\n N: tl.constexpr,\n V_BLOCK_SIZE: tl.constexpr,\n N_BLOCK_SIZE: tl.constexpr,\n):\n idx_N = tl.program_id(axis=0)\n idx_V = tl.program_id(axis=1)\n z_block_ptr = tl.make_block_ptr(\n base=z_nv_ptr,\n shape=(N_group, V),\n strides=(stride_z_N, stride_z_V),\n offsets=(idx_N * N_BLOCK_SIZE, idx_V * V_BLOCK_SIZE),\n block_shape=(N_BLOCK_SIZE, V_BLOCK_SIZE),\n order=(1, 0),\n )\n N_range = idx_N_group * N_group + idx_N * N_BLOCK_SIZE + tl.arange(0, N_BLOCK_SIZE)\n v_range = idx_V * V_BLOCK_SIZE + tl.arange(0, V_BLOCK_SIZE)\n y = tl.load(y_ptr + N_range)\n lse = tl.load(lse_ptr + N_range)\n z_j_to_k = tl.load(z_block_ptr)\n\n mask = y[:, None] == v_range[None, :]\n softmax_z = (z_j_to_k - lse[:, None]).exp()\n z_grad = (softmax_z - tl.where(mask, 1.0, 0.0)) / N\n\n tl.store(z_block_ptr, z_grad.to(z_nv_ptr.type.element_ty))\n\nclass LinearCrossEntropyLoss(torch.autograd.Function):\n\n @staticmethod\n def forward(\n ctx,\n x,\n y,\n At,\n ignore_index=-100,\n N_chunk_size: int = 4096,\n ):\n N, H = x.shape\n H_A, V = At.shape\n assert H_A == H\n assert y.shape == (N,)\n if ignore_index >= 0:\n y[y == ignore_index] = -100\n At_grad = torch.zeros_like(At)\n x_grad = torch.zeros_like(x)\n\n N_group = min(N, N_chunk_size)\n\n lse_local = -10e5 * torch.ones(N, V // 64, dtype=torch.float32, device=x.device)\n losses = torch.zeros(N // 64, V // 64, dtype=torch.float32, device=x.device)\n z_nv_and_grad = torch.empty((N_group, V), device=x.device, dtype=torch.float32)\n\n fwd_grid = lambda meta: (\n triton.cdiv(N_group, meta[\"N_BLOCK_SIZE\"]),\n triton.cdiv(V, meta[\"V_BLOCK_SIZE\"] * meta[\"V_TILES\"]),\n )\n prologue_grid = lambda meta: (\n triton.cdiv(N_group, meta[\"N_BLOCK_SIZE\"]),\n triton.cdiv(V, meta[\"V_BLOCK_SIZE\"]),\n )\n\n for idx_N_group, x_n_chunk in enumerate(x.split(N_group)):\n linear_xent_fwd_prep_bwd_kernel_matmul_t[fwd_grid](\n x,\n y,\n At,\n z_nv_and_grad,\n losses,\n lse_local,\n x.stride(0),\n x.stride(1),\n At.stride(0),\n At.stride(1),\n z_nv_and_grad.stride(0),\n z_nv_and_grad.stride(1),\n lse_local.stride(0),\n lse_local.stride(1),\n losses.stride(0),\n losses.stride(1),\n idx_N_group=idx_N_group,\n N_group=N_group,\n V=V,\n N=N,\n H=H,\n )\n lse_global = lse_local.logsumexp(dim=1)\n if x.requires_grad or At.requires_grad:\n linear_xent_mini_bwd_prologue_kernel[prologue_grid](\n z_nv_and_grad,\n y,\n lse_global,\n z_nv_and_grad.stride(0),\n z_nv_and_grad.stride(1),\n idx_N_group=idx_N_group,\n N_group=N_group,\n V=V,\n N=N,\n )\n z_grad = z_nv_and_grad.to(x.dtype)\n\n if At.requires_grad:\n torch.addmm(\n At_grad,\n x_n_chunk.detach().T,\n z_grad,\n out=At_grad,\n )\n\n ctx.mark_non_differentiable(y)\n ctx.save_for_backward(x_grad, At_grad.to(At.dtype))\n return losses.sum() + lse_global.sum() / N\n\n @staticmethod\n def backward(ctx, grad_output):\n x_grad, At_grad = ctx.saved_tensors\n\n return x_grad * grad_output, None, At_grad * grad_output, None, None\n\ndef linear_cross_entropy(x, y, At, ignore_index=-100, N_chunk_size=512):\n return LinearCrossEntropyLoss.apply(x, y, At, ignore_index, N_chunk_size)\n", - "description_1": "Use triton language to implement a cross-entropy loss calculation with integrated matrix multiplication and softmax function. This consists of two Triton kernels: 'linear_xent_fwd_prep_bwd_kernel_matmul_t' which performs the forward pass and partial backward pass, handling matrix multiplications and softmax calculations; and 'linear_xent_mini_bwd_prologue_kernel' which processes the backward pass by computing gradients for the softmax function. The main function 'LinearCrossEntropyLoss' orchestrates these kernels to compute the loss and gradients efficiently over input data tensors 'x' and 'At', and target labels 'y'.", - "description_2": "Use triton language to create efficient matrix multiplication and cross-entropy loss computation with gradient calculation using two kernels: one for forward pass and initial gradient calculation, and another for processing gradients through a softmax function.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.autotune(\n configs=[\n triton.Config({\"V_BLOCK_SIZE\": 256, \"N_BLOCK_SIZE\": 64, \"H_BLOCK_SIZE\": 64}),\n ],\n key=[\"V\", \"N\", \"H\"],\n reset_to_zero=[\"losses_ptr\", \"lse_ptr\"],\n)\n@triton.jit\ndef linear_xent_fwd_kernel_matmul_t(\n x_ptr,\n y_ptr,\n A_t_ptr,\n losses_ptr,\n lse_ptr,\n stride_x_N,\n stride_x_H,\n stride_A_H,\n stride_A_V,\n V: tl.constexpr,\n N: tl.constexpr,\n H: tl.constexpr,\n V_BLOCK_SIZE: tl.constexpr,\n N_BLOCK_SIZE: tl.constexpr,\n H_BLOCK_SIZE: tl.constexpr,\n):\n idx = tl.program_id(axis=0)\n x_block_ptr = tl.make_block_ptr(\n base=x_ptr,\n shape=(N, H),\n strides=(stride_x_N, stride_x_H),\n offsets=(idx * N_BLOCK_SIZE, 0),\n block_shape=(N_BLOCK_SIZE, H_BLOCK_SIZE),\n order=(1, 0),\n )\n A_block_ptr = tl.make_block_ptr(\n base=A_t_ptr,\n shape=(H, V),\n strides=(stride_A_H, stride_A_V),\n offsets=(0, 0),\n block_shape=(H_BLOCK_SIZE, V_BLOCK_SIZE),\n order=(1, 0),\n )\n offsets = idx * N_BLOCK_SIZE + tl.arange(0, N_BLOCK_SIZE)\n v_range = tl.arange(0, V_BLOCK_SIZE)\n y = tl.load(y_ptr + offsets)\n\n m = tl.zeros((N_BLOCK_SIZE,), dtype=tl.float32) - float(10e5)\n s = tl.zeros((N_BLOCK_SIZE,), dtype=tl.float32)\n loss = 0.0\n\n for _ in range(V // V_BLOCK_SIZE):\n z_j_to_k = tl.zeros((N_BLOCK_SIZE, V_BLOCK_SIZE), dtype=tl.float32)\n local_x_block_ptr = x_block_ptr\n for _ in range(H // H_BLOCK_SIZE):\n x_chunk = tl.load(local_x_block_ptr) # Nc x H\n A_v = tl.load(A_block_ptr) # Vc x H\n\n z_j_to_k = tl.dot(x_chunk, A_v, z_j_to_k) # (Nc x H) @ (H x Vc)\n\n local_x_block_ptr = tl.advance(local_x_block_ptr, [0, H_BLOCK_SIZE])\n A_block_ptr = tl.advance(A_block_ptr, [H_BLOCK_SIZE, 0])\n\n m_new = tl.maximum(m, tl.max(z_j_to_k, 1))\n\n s_update = tl.sum(tl.exp(z_j_to_k - m_new[:, None]), axis=1)\n s = s * tl.exp(m - m_new) + s_update\n\n mask = y[:, None] == v_range[None, :] # Nc x Vc\n loss -= tl.sum(tl.where(mask, z_j_to_k, float(0.0))) / N\n\n m = m_new\n A_block_ptr = tl.advance(A_block_ptr, [-H_BLOCK_SIZE * (H // H_BLOCK_SIZE), V_BLOCK_SIZE])\n v_range = v_range + V_BLOCK_SIZE\n\n lse = m + tl.log(s)\n loss += tl.sum(lse) / N\n tl.store(losses_ptr + idx, loss)\n tl.store(lse_ptr + offsets, lse)\n\n@triton.autotune(\n configs=[\n triton.Config({\"V_BLOCK_SIZE\": 16, \"N_BLOCK_SIZE\": 16, \"H_BLOCK_SIZE\": 16}),\n triton.Config({\"V_BLOCK_SIZE\": 32, \"N_BLOCK_SIZE\": 32, \"H_BLOCK_SIZE\": 32}),\n triton.Config({\"V_BLOCK_SIZE\": 32, \"N_BLOCK_SIZE\": 32, \"H_BLOCK_SIZE\": 64}),\n triton.Config({\"V_BLOCK_SIZE\": 32, \"N_BLOCK_SIZE\": 64, \"H_BLOCK_SIZE\": 64}),\n triton.Config({\"V_BLOCK_SIZE\": 16, \"N_BLOCK_SIZE\": 16, \"H_BLOCK_SIZE\": 128}),\n triton.Config({\"V_BLOCK_SIZE\": 16, \"N_BLOCK_SIZE\": 16, \"H_BLOCK_SIZE\": 256}),\n triton.Config({\"V_BLOCK_SIZE\": 16, \"N_BLOCK_SIZE\": 16, \"H_BLOCK_SIZE\": 512}),\n triton.Config({\"V_BLOCK_SIZE\": 16, \"N_BLOCK_SIZE\": 64, \"H_BLOCK_SIZE\": 64}),\n triton.Config({\"V_BLOCK_SIZE\": 16, \"N_BLOCK_SIZE\": 128, \"H_BLOCK_SIZE\": 128}),\n triton.Config({\"V_BLOCK_SIZE\": 16, \"N_BLOCK_SIZE\": 256, \"H_BLOCK_SIZE\": 128}),\n ],\n key=[\"V\", \"N\", \"H\"],\n reset_to_zero=[\"A_grad_ptr\"],\n)\n@triton.jit()\ndef linear_xent_bwd_kernel_matmul_t_dA(\n x_ptr,\n y_ptr,\n A_t_ptr,\n lse_global_ptr,\n A_grad_ptr,\n stride_x_N,\n stride_x_H,\n stride_A_H,\n stride_A_V,\n V: tl.constexpr,\n N: tl.constexpr,\n H: tl.constexpr,\n V_BLOCK_SIZE: tl.constexpr = 16,\n N_BLOCK_SIZE: tl.constexpr = 16,\n H_BLOCK_SIZE: tl.constexpr = 16,\n):\n idx_V = tl.program_id(axis=0)\n\n N_offsets = tl.arange(0, N_BLOCK_SIZE)\n V_offsets = idx_V * V_BLOCK_SIZE + tl.arange(0, V_BLOCK_SIZE)\n\n A_block_ptr = tl.make_block_ptr(\n base=A_t_ptr,\n shape=(H, V),\n strides=(stride_A_H, stride_A_V),\n offsets=(0, idx_V * V_BLOCK_SIZE),\n block_shape=(H_BLOCK_SIZE, V_BLOCK_SIZE),\n order=(1, 0),\n )\n\n for idx_N in range(N // N_BLOCK_SIZE):\n x_block_ptr = tl.make_block_ptr(\n base=x_ptr,\n shape=(N, H),\n strides=(stride_x_N, stride_x_H),\n offsets=(idx_N * N_BLOCK_SIZE, 0),\n block_shape=(N_BLOCK_SIZE, H_BLOCK_SIZE),\n order=(1, 0),\n )\n\n y = tl.load(y_ptr + N_offsets)\n lse = tl.load(lse_global_ptr + N_offsets)\n\n local_x_block_ptr = x_block_ptr\n local_A_block_ptr = A_block_ptr\n z_j_to_k = tl.zeros((N_BLOCK_SIZE, V_BLOCK_SIZE), dtype=tl.float32)\n for _ in range(H // H_BLOCK_SIZE):\n x_chunk = tl.load(local_x_block_ptr) # Nc x Hc\n A_v = tl.load(local_A_block_ptr) # Hc x Vc\n\n z_j_to_k = tl.dot(x_chunk, A_v, z_j_to_k) # (Nc x Hc) @ (Hc x Vc)\n\n local_x_block_ptr = tl.advance(local_x_block_ptr, [0, H_BLOCK_SIZE])\n local_A_block_ptr = tl.advance(local_A_block_ptr, [H_BLOCK_SIZE, 0])\n\n mask = (y[:, None] == V_offsets[None, :])[:, :, None] # N_BLOCK_SIZE x V_BLOCK_SIZE x 1\n softmax_z = (z_j_to_k - lse[:, None]).exp()\n\n local_x_block_ptr = x_block_ptr\n local_A_block_ptr = A_block_ptr\n for idx_H in range(H // H_BLOCK_SIZE):\n A_grad_block_ptr = tl.make_block_ptr(\n base=A_grad_ptr,\n shape=(H, V),\n strides=(stride_A_H, stride_A_V),\n offsets=(idx_H * H_BLOCK_SIZE, idx_V * V_BLOCK_SIZE),\n block_shape=(H_BLOCK_SIZE, V_BLOCK_SIZE),\n order=(1, 0),\n )\n\n x_chunk = tl.load(local_x_block_ptr).to(tl.float32) # Nc x Hc\n A_v = tl.load(local_A_block_ptr).to(tl.float32) # Hc x Vc\n\n temp_Agrad = tl.dot(softmax_z.trans(), x_chunk)\n temp_Agrad -= tl.sum(tl.where(mask, x_chunk[:, None, :], 0.0), axis=0)\n temp_AgradT = temp_Agrad.trans() / N + tl.load(A_grad_block_ptr)\n tl.store(A_grad_block_ptr, temp_AgradT, boundary_check=(0, 1))\n\n local_x_block_ptr = tl.advance(local_x_block_ptr, [0, H_BLOCK_SIZE])\n local_A_block_ptr = tl.advance(local_A_block_ptr, [H_BLOCK_SIZE, 0])\n N_offsets += N_BLOCK_SIZE\n\n@triton.autotune(\n configs=[\n triton.Config({\"V_BLOCK_SIZE\": 16, \"N_BLOCK_SIZE\": 16, \"H_BLOCK_SIZE\": 16}),\n triton.Config({\"V_BLOCK_SIZE\": 32, \"N_BLOCK_SIZE\": 32, \"H_BLOCK_SIZE\": 32}),\n triton.Config({\"V_BLOCK_SIZE\": 32, \"N_BLOCK_SIZE\": 32, \"H_BLOCK_SIZE\": 64}),\n triton.Config({\"V_BLOCK_SIZE\": 32, \"N_BLOCK_SIZE\": 64, \"H_BLOCK_SIZE\": 64}),\n triton.Config({\"V_BLOCK_SIZE\": 16, \"N_BLOCK_SIZE\": 16, \"H_BLOCK_SIZE\": 128}),\n triton.Config({\"V_BLOCK_SIZE\": 16, \"N_BLOCK_SIZE\": 16, \"H_BLOCK_SIZE\": 256}),\n triton.Config({\"V_BLOCK_SIZE\": 16, \"N_BLOCK_SIZE\": 16, \"H_BLOCK_SIZE\": 512}),\n triton.Config({\"V_BLOCK_SIZE\": 256, \"N_BLOCK_SIZE\": 16, \"H_BLOCK_SIZE\": 16}),\n triton.Config({\"V_BLOCK_SIZE\": 128, \"N_BLOCK_SIZE\": 16, \"H_BLOCK_SIZE\": 128}),\n ],\n key=[\"V\", \"N\", \"H\"],\n reset_to_zero=[\"x_grad_ptr\"],\n)\n@triton.jit()\ndef linear_xent_bwd_kernel_matmul_t_dx(\n x_ptr,\n y_ptr,\n A_t_ptr,\n lse_global_ptr,\n x_grad_ptr,\n stride_x_N,\n stride_x_H,\n stride_A_H,\n stride_A_V,\n V: tl.constexpr,\n N: tl.constexpr,\n H: tl.constexpr,\n V_BLOCK_SIZE: tl.constexpr = 16,\n N_BLOCK_SIZE: tl.constexpr = 16,\n H_BLOCK_SIZE: tl.constexpr = 16,\n):\n idx_N = tl.program_id(axis=0)\n\n N_offsets = idx_N * N_BLOCK_SIZE + tl.arange(0, N_BLOCK_SIZE)\n V_offsets = tl.arange(0, V_BLOCK_SIZE)\n\n x_block_ptr = tl.make_block_ptr(\n base=x_ptr,\n shape=(N, H),\n strides=(stride_x_N, stride_x_H),\n offsets=(idx_N * N_BLOCK_SIZE, 0),\n block_shape=(N_BLOCK_SIZE, H_BLOCK_SIZE),\n order=(1, 0),\n )\n\n x_grad_block_ptr = tl.make_block_ptr(\n base=x_grad_ptr,\n shape=(N, H),\n strides=(stride_x_N, stride_x_H),\n offsets=(idx_N * N_BLOCK_SIZE, 0),\n block_shape=(N_BLOCK_SIZE, H_BLOCK_SIZE),\n order=(1, 0),\n )\n\n y = tl.load(y_ptr + N_offsets)\n lse = tl.load(lse_global_ptr + N_offsets)\n\n for idx_V in range(V // V_BLOCK_SIZE):\n A_block_ptr = tl.make_block_ptr(\n base=A_t_ptr,\n shape=(H, V),\n strides=(stride_A_H, stride_A_V),\n offsets=(0, idx_V * V_BLOCK_SIZE),\n block_shape=(H_BLOCK_SIZE, V_BLOCK_SIZE),\n order=(1, 0),\n )\n\n local_x_block_ptr = x_block_ptr\n local_A_block_ptr = A_block_ptr\n z_j_to_k = tl.zeros((N_BLOCK_SIZE, V_BLOCK_SIZE), dtype=tl.float32)\n for _ in range(H // H_BLOCK_SIZE):\n x_chunk = tl.load(local_x_block_ptr) # Nc x Hc\n A_v = tl.load(local_A_block_ptr) # Hc x Vc\n\n z_j_to_k = tl.dot(x_chunk, A_v, z_j_to_k) # (Nc x Hc) @ (Hc x Vc)\n\n local_x_block_ptr = tl.advance(local_x_block_ptr, [0, H_BLOCK_SIZE])\n local_A_block_ptr = tl.advance(local_A_block_ptr, [H_BLOCK_SIZE, 0])\n\n mask = (y[:, None] == V_offsets[None, :])[:, :, None] # N_BLOCK_SIZE x V_BLOCK_SIZE x 1\n softmax_z = (z_j_to_k - lse[:, None]).exp()\n\n local_x_block_ptr = x_block_ptr\n local_A_block_ptr = A_block_ptr\n local_x_grad_block_ptr = x_grad_block_ptr\n for idx_H in range(H // H_BLOCK_SIZE):\n x_chunk = tl.load(local_x_block_ptr).to(tl.float32) # Nc x Hc\n A_v = tl.load(local_A_block_ptr).to(tl.float32) # Hc x Vc\n\n temp_xgrad = tl.dot(softmax_z, A_v.trans()) / N\n temp_xgrad -= tl.sum(tl.where(mask, A_v.trans()[None, :, :], 0.0), axis=1) / N\n\n temp_xgrad += tl.load(local_x_grad_block_ptr)\n tl.store(local_x_grad_block_ptr, temp_xgrad, boundary_check=(0, 1))\n\n local_x_block_ptr = tl.advance(local_x_block_ptr, [0, H_BLOCK_SIZE])\n local_x_grad_block_ptr = tl.advance(local_x_grad_block_ptr, [0, H_BLOCK_SIZE])\n local_A_block_ptr = tl.advance(local_A_block_ptr, [H_BLOCK_SIZE, 0])\n\n V_offsets += V_BLOCK_SIZE\n\nclass LinearCrossEntropyLoss(torch.autograd.Function):\n\n @staticmethod\n def forward(\n ctx,\n x,\n y,\n At,\n ignore_index=-100,\n ):\n N, H = x.shape\n H_A, V = At.shape\n assert H_A == H\n assert y.shape == (N,)\n x = x.contiguous()\n y = y.contiguous()\n At = At.contiguous()\n\n assert V % 256 == 0, f\"V is {V}\"\n assert N % 64 == 0, f\"N is {N}\"\n assert H % 64 == 0, f\"H is {H}\"\n\n lse_global = torch.zeros(N, dtype=torch.float32, device=x.device)\n losses = torch.zeros(N // 16, dtype=torch.float32, device=x.device)\n\n grid = lambda meta: (triton.cdiv(N, meta[\"N_BLOCK_SIZE\"]),)\n\n with torch.cuda.device(x.device.index):\n linear_xent_fwd_kernel_matmul_t[grid](\n x, y, At, losses, lse_global, x.stride(0), x.stride(1), At.stride(0), At.stride(1), V=V, N=N, H=H\n )\n print(\"fwd config:\", linear_xent_fwd_kernel_matmul_t.best_config)\n\n ctx.save_for_backward(x, y, At, lse_global)\n\n return losses.sum()\n\n @staticmethod\n def backward(ctx, grad_output):\n x, y, At, lse_global = ctx.saved_tensors\n N, H = x.shape\n _, V = At.shape\n\n xgrad = torch.zeros_like(x, dtype=torch.float32)\n Atgrad = torch.zeros_like(At, dtype=torch.float32)\n\n with torch.cuda.device(x.device.index):\n grid = lambda meta: (triton.cdiv(V, meta[\"V_BLOCK_SIZE\"]),)\n linear_xent_bwd_kernel_matmul_t_dA[grid](\n x,\n y,\n At,\n lse_global,\n Atgrad,\n x.stride(0),\n x.stride(1),\n At.stride(0),\n At.stride(1),\n V=V,\n N=N,\n H=H,\n )\n print(\"bwd config dA:\", linear_xent_bwd_kernel_matmul_t_dA.best_config)\n grid = lambda meta: (triton.cdiv(N, meta[\"N_BLOCK_SIZE\"]),)\n linear_xent_bwd_kernel_matmul_t_dx[grid](\n x,\n y,\n At,\n lse_global,\n xgrad,\n x.stride(0),\n x.stride(1),\n At.stride(0),\n At.stride(1),\n V=V,\n N=N,\n H=H,\n )\n print(\"bwd config dx:\", linear_xent_bwd_kernel_matmul_t_dx.best_config)\n\n ctx.mark_non_differentiable(y)\n return xgrad * grad_output, None, Atgrad * grad_output, None\n\ndef linear_cross_entropy(x, y, At):\n return LinearCrossEntropyLoss.apply(x, y, At)\n\nif __name__ == \"__main__\":\n f = 8\n V, N, H = 131072 // 8, 4096 * 4 // 8, 4096 // 8\n\n compute_dtype = torch.float16\n\n y = torch.randint(0, V, (N,), device=device)\n A = torch.randn(V, H, requires_grad=True, device=device, dtype=compute_dtype)\n At = A.clone().detach().T.contiguous()\n At.requires_grad_()\n\n x = 0.01 * A[y].clone().detach() + torch.randn(N, H, device=device, dtype=compute_dtype)\n x.requires_grad_()\n\n simple_bench(lambda: linear_cross_entropy(x, y, At), reference_loss, reference_x_grad, reference_A_grad)\n", - "description_1": "Use triton language to create a linear cross entropy kernel with forward and backward propagation. It takes parameters for input tensors x, y, and At, strides for accessing these tensors, and constants for matrix dimensions and block sizes.", - "description_2": "Implement a triton-based linear cross entropy loss function with both forward and backward operations for optimization.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.autotune(\n configs=[\n triton.Config({\"V_BLOCK_SIZE\": 256, \"N_BLOCK_SIZE\": 64, \"H_BLOCK_SIZE\": 64}),\n triton.Config({\"V_BLOCK_SIZE\": 64, \"N_BLOCK_SIZE\": 16, \"H_BLOCK_SIZE\": 64}),\n triton.Config({\"V_BLOCK_SIZE\": 64, \"N_BLOCK_SIZE\": 64, \"H_BLOCK_SIZE\": 64}),\n triton.Config({\"V_BLOCK_SIZE\": 256, \"N_BLOCK_SIZE\": 16, \"H_BLOCK_SIZE\": 256}),\n triton.Config({\"V_BLOCK_SIZE\": 512, \"N_BLOCK_SIZE\": 16, \"H_BLOCK_SIZE\": 512}),\n triton.Config({\"V_BLOCK_SIZE\": 256, \"N_BLOCK_SIZE\": 64, \"H_BLOCK_SIZE\": 64}),\n triton.Config({\"V_BLOCK_SIZE\": 256, \"N_BLOCK_SIZE\": 256, \"H_BLOCK_SIZE\": 64}),\n triton.Config({\"V_BLOCK_SIZE\": 256, \"N_BLOCK_SIZE\": 256, \"H_BLOCK_SIZE\": 256}),\n triton.Config({\"V_BLOCK_SIZE\": 256, \"N_BLOCK_SIZE\": 16, \"H_BLOCK_SIZE\": 16}),\n ],\n key=[\"V\", \"N\", \"H\"],\n reset_to_zero=[\"losses_ptr\", \"lse_ptr\"],\n)\n@triton.jit\ndef linear_xent_fwd_kernel_matmul_t(\n x_ptr,\n y_ptr,\n A_t_ptr,\n losses_ptr,\n lse_ptr,\n stride_x_N,\n stride_x_H,\n stride_A_H,\n stride_A_V,\n V: tl.constexpr,\n N: tl.constexpr,\n H: tl.constexpr,\n V_BLOCK_SIZE: tl.constexpr,\n N_BLOCK_SIZE: tl.constexpr,\n H_BLOCK_SIZE: tl.constexpr,\n):\n idx = tl.program_id(axis=0)\n tl.static_assert(N % N_BLOCK_SIZE == 0)\n tl.static_assert(V % V_BLOCK_SIZE == 0)\n tl.static_assert(H % H_BLOCK_SIZE == 0)\n x_block_ptr = tl.make_block_ptr(\n base=x_ptr,\n shape=(N, H),\n strides=(stride_x_N, stride_x_H),\n offsets=(idx * N_BLOCK_SIZE, 0),\n block_shape=(N_BLOCK_SIZE, H_BLOCK_SIZE),\n order=(1, 0),\n )\n A_block_ptr = tl.make_block_ptr(\n base=A_t_ptr,\n shape=(H, V),\n strides=(stride_A_H, stride_A_V),\n offsets=(0, 0),\n block_shape=(H_BLOCK_SIZE, V_BLOCK_SIZE),\n order=(1, 0),\n )\n offsets = idx * N_BLOCK_SIZE + tl.arange(0, N_BLOCK_SIZE)\n v_range = tl.arange(0, V_BLOCK_SIZE)\n y = tl.load(y_ptr + offsets)\n\n m = tl.zeros((N_BLOCK_SIZE,), dtype=tl.float32) - float(10e5)\n s = tl.zeros((N_BLOCK_SIZE,), dtype=tl.float32)\n loss = 0.0\n\n for _ in range(V // V_BLOCK_SIZE):\n z_j_to_k = tl.zeros((N_BLOCK_SIZE, V_BLOCK_SIZE), dtype=tl.float32)\n local_x_block_ptr = x_block_ptr\n for _ in range(H // H_BLOCK_SIZE):\n x_chunk = tl.load(local_x_block_ptr)\n A_v = tl.load(A_block_ptr)\n\n z_j_to_k = tl.dot(x_chunk, A_v, z_j_to_k)\n\n local_x_block_ptr = tl.advance(local_x_block_ptr, [0, H_BLOCK_SIZE])\n A_block_ptr = tl.advance(A_block_ptr, [H_BLOCK_SIZE, 0])\n\n m_new = tl.maximum(m, tl.max(z_j_to_k, 1))\n\n s_update = tl.sum(tl.exp(z_j_to_k - m_new[:, None]), axis=1)\n s = s * tl.exp(m - m_new) + s_update\n\n mask = y[:, None] == v_range[None, :]\n loss -= tl.sum(tl.where(mask, z_j_to_k, float(0.0))) / N\n\n m = m_new\n A_block_ptr = tl.advance(A_block_ptr, [-H_BLOCK_SIZE * (H // H_BLOCK_SIZE), V_BLOCK_SIZE])\n v_range = v_range + V_BLOCK_SIZE\n\n lse = m + tl.log(s)\n loss += tl.sum(lse) / N\n tl.store(losses_ptr + idx, loss)\n tl.store(lse_ptr + offsets, lse)\n\n@triton.autotune(\n configs=[\n triton.Config({\"V_BLOCK_SIZE\": 16, \"N_BLOCK_SIZE\": 16, \"H_BLOCK_SIZE\": 16}),\n triton.Config({\"V_BLOCK_SIZE\": 16, \"N_BLOCK_SIZE\": 16, \"H_BLOCK_SIZE\": 16}),\n triton.Config({\"V_BLOCK_SIZE\": 16, \"N_BLOCK_SIZE\": 16, \"H_BLOCK_SIZE\": 16}),\n triton.Config({\"V_BLOCK_SIZE\": 32, \"N_BLOCK_SIZE\": 32, \"H_BLOCK_SIZE\": 32}),\n triton.Config({\"V_BLOCK_SIZE\": 32, \"N_BLOCK_SIZE\": 32, \"H_BLOCK_SIZE\": 64}),\n triton.Config({\"V_BLOCK_SIZE\": 32, \"N_BLOCK_SIZE\": 64, \"H_BLOCK_SIZE\": 64}),\n triton.Config({\"V_BLOCK_SIZE\": 64, \"N_BLOCK_SIZE\": 64, \"H_BLOCK_SIZE\": 64}),\n triton.Config({\"V_BLOCK_SIZE\": 256, \"N_BLOCK_SIZE\": 256, \"H_BLOCK_SIZE\": 16}),\n triton.Config({\"V_BLOCK_SIZE\": 256, \"N_BLOCK_SIZE\": 16, \"H_BLOCK_SIZE\": 16}),\n triton.Config({\"V_BLOCK_SIZE\": 16, \"N_BLOCK_SIZE\": 16, \"H_BLOCK_SIZE\": 128}),\n triton.Config({\"V_BLOCK_SIZE\": 16, \"N_BLOCK_SIZE\": 16, \"H_BLOCK_SIZE\": 256}),\n triton.Config({\"V_BLOCK_SIZE\": 16, \"N_BLOCK_SIZE\": 16, \"H_BLOCK_SIZE\": 512}),\n triton.Config({\"V_BLOCK_SIZE\": 16, \"N_BLOCK_SIZE\": 64, \"H_BLOCK_SIZE\": 64}),\n triton.Config({\"V_BLOCK_SIZE\": 16, \"N_BLOCK_SIZE\": 128, \"H_BLOCK_SIZE\": 128}),\n triton.Config({\"V_BLOCK_SIZE\": 16, \"N_BLOCK_SIZE\": 256, \"H_BLOCK_SIZE\": 128}),\n ],\n key=[\"V\", \"N\", \"H\"],\n reset_to_zero=[\"A_grad_ptr\"],\n)\n@triton.jit()\ndef linear_xent_bwd_kernel_matmul_t_dA(\n x_ptr,\n y_ptr,\n A_t_ptr,\n lse_global_ptr,\n A_grad_ptr,\n stride_x_N,\n stride_x_H,\n stride_A_H,\n stride_A_V,\n V: tl.constexpr,\n N: tl.constexpr,\n H: tl.constexpr,\n V_BLOCK_SIZE: tl.constexpr = 16,\n N_BLOCK_SIZE: tl.constexpr = 16,\n H_BLOCK_SIZE: tl.constexpr = 16,\n):\n idx_V = tl.program_id(axis=0)\n tl.static_assert(N % N_BLOCK_SIZE == 0)\n tl.static_assert(V % V_BLOCK_SIZE == 0)\n tl.static_assert(H % H_BLOCK_SIZE == 0)\n\n N_offsets = tl.arange(0, N_BLOCK_SIZE)\n V_offsets = idx_V * V_BLOCK_SIZE + tl.arange(0, V_BLOCK_SIZE)\n\n A_block_ptr = tl.make_block_ptr(\n base=A_t_ptr,\n shape=(H, V),\n strides=(stride_A_H, stride_A_V),\n offsets=(0, idx_V * V_BLOCK_SIZE),\n block_shape=(H_BLOCK_SIZE, V_BLOCK_SIZE),\n order=(1, 0),\n )\n A_grad_block_ptr = tl.make_block_ptr(\n base=A_grad_ptr,\n shape=(H, V),\n strides=(stride_A_H, stride_A_V),\n offsets=(0 * H_BLOCK_SIZE, idx_V * V_BLOCK_SIZE),\n block_shape=(H_BLOCK_SIZE, V_BLOCK_SIZE),\n order=(1, 0),\n )\n\n x_block_ptr = tl.make_block_ptr(\n base=x_ptr,\n shape=(N, H),\n strides=(stride_x_N, stride_x_H),\n offsets=(0 * N_BLOCK_SIZE, 0),\n block_shape=(N_BLOCK_SIZE, H_BLOCK_SIZE),\n order=(1, 0),\n )\n\n for idx_N in range(N // N_BLOCK_SIZE):\n\n y = tl.load(y_ptr + N_offsets)\n lse = tl.load(lse_global_ptr + N_offsets)\n\n z_j_to_k = tl.zeros((N_BLOCK_SIZE, V_BLOCK_SIZE), dtype=tl.float32)\n for _ in range(H // H_BLOCK_SIZE):\n x_chunk = tl.load(x_block_ptr)\n A_v = tl.load(A_block_ptr)\n\n z_j_to_k = tl.dot(x_chunk, A_v, z_j_to_k)\n\n x_block_ptr = tl.advance(x_block_ptr, [0, H_BLOCK_SIZE])\n A_block_ptr = tl.advance(A_block_ptr, [H_BLOCK_SIZE, 0])\n x_block_ptr = tl.advance(x_block_ptr, [0, -H])\n A_block_ptr = tl.advance(A_block_ptr, [-H, 0])\n\n mask = (y[:, None] == V_offsets[None, :])[:, :, None]\n softmax_z = (z_j_to_k - lse[:, None]).exp().to(tl.float16)\n\n for idx_H in range(H // H_BLOCK_SIZE):\n x_chunk = tl.load(x_block_ptr)\n temp_Agrad = tl.dot(softmax_z.trans(), x_chunk)\n temp_Agrad -= tl.sum(tl.where(mask, x_chunk[:, None, :], 0.0), axis=0)\n temp_AgradT = temp_Agrad.trans() / N\n tl.store(A_grad_block_ptr, temp_AgradT.to(tl.float16) + tl.load(A_grad_block_ptr))\n\n A_grad_block_ptr = tl.advance(A_grad_block_ptr, [H_BLOCK_SIZE, 0])\n x_block_ptr = tl.advance(x_block_ptr, [0, H_BLOCK_SIZE])\n x_block_ptr = tl.advance(x_block_ptr, [N_BLOCK_SIZE, -H])\n A_grad_block_ptr = tl.advance(A_grad_block_ptr, [-H, 0])\n N_offsets += N_BLOCK_SIZE\n\n@triton.autotune(\n configs=[\n triton.Config({\"V_BLOCK_SIZE\": 16, \"N_BLOCK_SIZE\": 16, \"H_BLOCK_SIZE\": 16}),\n triton.Config({\"V_BLOCK_SIZE\": 16, \"N_BLOCK_SIZE\": 16, \"H_BLOCK_SIZE\": 16}),\n triton.Config({\"V_BLOCK_SIZE\": 16, \"N_BLOCK_SIZE\": 16, \"H_BLOCK_SIZE\": 16}),\n triton.Config({\"V_BLOCK_SIZE\": 32, \"N_BLOCK_SIZE\": 32, \"H_BLOCK_SIZE\": 32}),\n triton.Config({\"V_BLOCK_SIZE\": 32, \"N_BLOCK_SIZE\": 32, \"H_BLOCK_SIZE\": 64}),\n triton.Config({\"V_BLOCK_SIZE\": 32, \"N_BLOCK_SIZE\": 64, \"H_BLOCK_SIZE\": 64}),\n triton.Config({\"V_BLOCK_SIZE\": 64, \"N_BLOCK_SIZE\": 64, \"H_BLOCK_SIZE\": 64}),\n triton.Config({\"V_BLOCK_SIZE\": 512, \"N_BLOCK_SIZE\": 16, \"H_BLOCK_SIZE\": 16}),\n triton.Config({\"V_BLOCK_SIZE\": 16, \"N_BLOCK_SIZE\": 16, \"H_BLOCK_SIZE\": 128}),\n triton.Config({\"V_BLOCK_SIZE\": 16, \"N_BLOCK_SIZE\": 16, \"H_BLOCK_SIZE\": 256}),\n triton.Config({\"V_BLOCK_SIZE\": 16, \"N_BLOCK_SIZE\": 16, \"H_BLOCK_SIZE\": 512}),\n triton.Config({\"V_BLOCK_SIZE\": 16, \"N_BLOCK_SIZE\": 16, \"H_BLOCK_SIZE\": 512}),\n triton.Config({\"V_BLOCK_SIZE\": 128, \"N_BLOCK_SIZE\": 16, \"H_BLOCK_SIZE\": 128}),\n triton.Config({\"V_BLOCK_SIZE\": 256, \"N_BLOCK_SIZE\": 16, \"H_BLOCK_SIZE\": 16}),\n ],\n key=[\"V\", \"N\", \"H\"],\n reset_to_zero=[\"x_grad_ptr\"],\n)\n@triton.jit()\ndef linear_xent_bwd_kernel_matmul_t_dx(\n x_ptr,\n y_ptr,\n A_t_ptr,\n lse_global_ptr,\n x_grad_ptr,\n stride_x_N,\n stride_x_H,\n stride_A_H,\n stride_A_V,\n V: tl.constexpr,\n N: tl.constexpr,\n H: tl.constexpr,\n V_BLOCK_SIZE: tl.constexpr = 16,\n N_BLOCK_SIZE: tl.constexpr = 16,\n H_BLOCK_SIZE: tl.constexpr = 16,\n):\n idx_N = tl.program_id(axis=0)\n tl.static_assert(N % N_BLOCK_SIZE == 0)\n tl.static_assert(V % V_BLOCK_SIZE == 0)\n tl.static_assert(H % H_BLOCK_SIZE == 0)\n\n N_offsets = idx_N * N_BLOCK_SIZE + tl.arange(0, N_BLOCK_SIZE)\n V_offsets = tl.arange(0, V_BLOCK_SIZE)\n\n x_block_ptr = tl.make_block_ptr(\n base=x_ptr,\n shape=(N, H),\n strides=(stride_x_N, stride_x_H),\n offsets=(idx_N * N_BLOCK_SIZE, 0),\n block_shape=(N_BLOCK_SIZE, H_BLOCK_SIZE),\n order=(1, 0),\n )\n x_grad_block_ptr = tl.make_block_ptr(\n base=x_grad_ptr,\n shape=(N, H),\n strides=(stride_x_N, stride_x_H),\n offsets=(idx_N * N_BLOCK_SIZE, 0 * H_BLOCK_SIZE),\n block_shape=(N_BLOCK_SIZE, H_BLOCK_SIZE),\n order=(1, 0),\n )\n A_block_ptr = tl.make_block_ptr(\n base=A_t_ptr,\n shape=(H, V),\n strides=(stride_A_H, stride_A_V),\n offsets=(0, 0),\n block_shape=(H_BLOCK_SIZE, V_BLOCK_SIZE),\n order=(1, 0),\n )\n\n y = tl.load(y_ptr + N_offsets)\n lse = tl.load(lse_global_ptr + N_offsets)\n\n for idx_V in range(V // V_BLOCK_SIZE):\n z_j_to_k = tl.zeros((N_BLOCK_SIZE, V_BLOCK_SIZE), dtype=tl.float32)\n for idx_H_1 in range(H // H_BLOCK_SIZE):\n x_block_ptr = tl.make_block_ptr(\n base=x_ptr,\n shape=(N, H),\n strides=(stride_x_N, stride_x_H),\n offsets=(idx_N * N_BLOCK_SIZE, idx_H_1 * H_BLOCK_SIZE),\n block_shape=(N_BLOCK_SIZE, H_BLOCK_SIZE),\n order=(1, 0),\n )\n A_block_ptr = tl.make_block_ptr(\n base=A_t_ptr,\n shape=(H, V),\n strides=(stride_A_H, stride_A_V),\n offsets=(idx_H_1 * H_BLOCK_SIZE, idx_V * V_BLOCK_SIZE),\n block_shape=(H_BLOCK_SIZE, V_BLOCK_SIZE),\n order=(1, 0),\n )\n x_chunk = tl.load(x_block_ptr)\n A_v = tl.load(A_block_ptr)\n\n z_j_to_k = tl.dot(x_chunk, A_v, z_j_to_k)\n\n mask = (y[:, None] == V_offsets[None, :])[:, :, None]\n softmax_z = (z_j_to_k - lse[:, None]).exp().to(tl.float16)\n\n for idx_H in range(H // H_BLOCK_SIZE):\n x_grad_block_ptr = tl.make_block_ptr(\n base=x_grad_ptr,\n shape=(N, H),\n strides=(stride_x_N, stride_x_H),\n offsets=(idx_N * N_BLOCK_SIZE, idx_H * H_BLOCK_SIZE),\n block_shape=(N_BLOCK_SIZE, H_BLOCK_SIZE),\n order=(1, 0),\n )\n A_block_ptr = tl.make_block_ptr(\n base=A_t_ptr,\n shape=(H, V),\n strides=(stride_A_H, stride_A_V),\n offsets=(idx_H * H_BLOCK_SIZE, idx_V * V_BLOCK_SIZE),\n block_shape=(H_BLOCK_SIZE, V_BLOCK_SIZE),\n order=(1, 0),\n )\n A_v = tl.load(A_block_ptr).trans()\n temp_xgrad = tl.dot(softmax_z, A_v) / N\n temp_xgrad -= tl.sum(tl.where(mask, A_v[None, :, :], 0.0), axis=1) / N\n tl.store(x_grad_block_ptr, tl.load(x_grad_block_ptr) + temp_xgrad.to(tl.float16))\n\n V_offsets += V_BLOCK_SIZE\n\n\nclass LinearCrossEntropyLoss(torch.autograd.Function):\n\n @staticmethod\n def forward(\n ctx,\n x,\n y,\n At,\n ignore_index=-100, # ignores all negative integers ...\n ):\n N, H = x.shape\n H_A, V = At.shape\n assert H_A == H\n assert y.shape == (N,)\n x = x.contiguous()\n y = y.contiguous()\n At = At.contiguous()\n\n assert V % 256 == 0, f\"V is {V}\"\n assert N % 64 == 0, f\"N is {N}\"\n assert H % 64 == 0, f\"H is {H}\"\n\n lse_global = torch.zeros(N, dtype=torch.float32, device=x.device)\n losses = torch.zeros(N // 16, dtype=torch.float32, device=x.device)\n\n grid = lambda meta: (triton.cdiv(N, meta[\"N_BLOCK_SIZE\"]),)\n\n with torch.cuda.device(x.device.index): # actually required\n linear_xent_fwd_kernel_matmul_t[grid](\n x, y, At, losses, lse_global, x.stride(0), x.stride(1), At.stride(0), At.stride(1), V=V, N=N, H=H\n )\n ctx.save_for_backward(x, y, At, lse_global)\n\n return losses.sum()\n\n @staticmethod\n def backward(ctx, grad_output):\n x, y, At, lse_global = ctx.saved_tensors\n N, H = x.shape\n _, V = At.shape\n\n xgrad = torch.zeros_like(x)\n Atgrad = torch.zeros_like(At)\n\n with torch.cuda.device(x.device.index): # actually required\n grid = lambda meta: (triton.cdiv(V, meta[\"V_BLOCK_SIZE\"]),)\n linear_xent_bwd_kernel_matmul_t_dA[grid](\n x,\n y,\n At,\n lse_global,\n Atgrad,\n x.stride(0),\n x.stride(1),\n At.stride(0),\n At.stride(1),\n V=V,\n N=N,\n H=H,\n )\n grid = lambda meta: (triton.cdiv(N, meta[\"N_BLOCK_SIZE\"]),)\n linear_xent_bwd_kernel_matmul_t_dx[grid](\n x,\n y,\n At,\n lse_global,\n xgrad,\n x.stride(0),\n x.stride(1),\n At.stride(0),\n At.stride(1),\n V=V,\n N=N,\n H=H,\n )\n\n ctx.mark_non_differentiable(y)\n return xgrad * grad_output, None, Atgrad * grad_output, None\n\n\ndef linear_cross_entropy(x, y, At):\n return LinearCrossEntropyLoss.apply(x, y, At)\n", - "description_1": "Use triton language to implement three kernels for forward and backward passes of linear cross-entropy loss computation. The forward kernel takes 13 parameters: 3 pointers for input tensors, 2 pointers for results, 4 integer strides, and 4 integer dimensions/constants. The backward kernels each take 13 parameters: 3 pointers for input tensors, a pointer for intermediate results, a pointer for gradients, 4 integer strides, and 4 integer dimensions/constants. The forward kernel computes the loss and log-sum-exp (lse) values, while the backward kernels compute gradients with respect to the weight matrix (A) and input (x). A PyTorch autograd function utilizes these kernels to compute forward and backward passes.", - "description_2": "Use triton language to implement a linear cross-entropy loss computation with forward and backward passes using three kernels.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.autotune(\n configs=[\n triton.Config({\"V_BLOCK_SIZE\": 256, \"N_BLOCK_SIZE\": 64, \"H_BLOCK_SIZE\": 64}),\n ],\n key=[\"V\", \"N\", \"H\"],\n reset_to_zero=[\"losses_ptr\", \"lse_ptr\"],\n)\n@triton.jit\ndef linear_xent_fwd_kernel_matmul_t(\n x_ptr,\n y_ptr,\n A_t_ptr,\n losses_ptr,\n lse_ptr,\n stride_x_N,\n stride_x_H,\n stride_A_H,\n stride_A_V,\n V: tl.constexpr,\n N: tl.constexpr,\n H: tl.constexpr,\n V_BLOCK_SIZE: tl.constexpr,\n N_BLOCK_SIZE: tl.constexpr,\n H_BLOCK_SIZE: tl.constexpr,\n):\n idx = tl.program_id(axis=0)\n tl.static_assert(N % N_BLOCK_SIZE == 0)\n tl.static_assert(V % V_BLOCK_SIZE == 0)\n tl.static_assert(H % H_BLOCK_SIZE == 0)\n x_block_ptr = tl.make_block_ptr(\n base=x_ptr,\n shape=(N, H),\n strides=(stride_x_N, stride_x_H),\n offsets=(idx * N_BLOCK_SIZE, 0),\n block_shape=(N_BLOCK_SIZE, H_BLOCK_SIZE),\n order=(1, 0),\n )\n A_block_ptr = tl.make_block_ptr(\n base=A_t_ptr,\n shape=(H, V),\n strides=(stride_A_H, stride_A_V),\n offsets=(0, 0),\n block_shape=(H_BLOCK_SIZE, V_BLOCK_SIZE),\n order=(1, 0),\n )\n offsets = idx * N_BLOCK_SIZE + tl.arange(0, N_BLOCK_SIZE)\n v_range = tl.arange(0, V_BLOCK_SIZE)\n y = tl.load(y_ptr + offsets)\n\n m = tl.zeros((N_BLOCK_SIZE,), dtype=tl.float32) - float(10e5)\n s = tl.zeros((N_BLOCK_SIZE,), dtype=tl.float32)\n loss = 0.0\n\n for _ in range(V // V_BLOCK_SIZE):\n z_j_to_k = tl.zeros((N_BLOCK_SIZE, V_BLOCK_SIZE), dtype=tl.float32)\n local_x_block_ptr = x_block_ptr\n for _ in range(H // H_BLOCK_SIZE):\n x_chunk = tl.load(local_x_block_ptr)\n A_v = tl.load(A_block_ptr)\n\n z_j_to_k = tl.dot(x_chunk, A_v, z_j_to_k)\n\n local_x_block_ptr = tl.advance(local_x_block_ptr, [0, H_BLOCK_SIZE])\n A_block_ptr = tl.advance(A_block_ptr, [H_BLOCK_SIZE, 0])\n\n m_new = tl.maximum(m, tl.max(z_j_to_k, 1))\n\n s_update = tl.sum(tl.exp(z_j_to_k - m_new[:, None]), axis=1)\n s = s * tl.exp(m - m_new) + s_update\n\n mask = y[:, None] == v_range[None, :]\n loss -= tl.sum(tl.where(mask, z_j_to_k, float(0.0))) / N\n\n m = m_new\n A_block_ptr = tl.advance(A_block_ptr, [-H_BLOCK_SIZE * (H // H_BLOCK_SIZE), V_BLOCK_SIZE])\n v_range = v_range + V_BLOCK_SIZE\n\n lse = m + tl.log(s)\n loss += tl.sum(lse) / N\n tl.store(losses_ptr + idx, loss)\n tl.store(lse_ptr + offsets, lse)\n\n@triton.autotune(\n configs=[\n triton.Config({\"V_BLOCK_SIZE\": 16, \"N_BLOCK_SIZE\": 16, \"H_BLOCK_SIZE\": 16}),\n ],\n key=[\"V\", \"N\", \"H\"],\n reset_to_zero=[\"A_grad_ptr\"],\n)\n@triton.jit()\ndef linear_xent_bwd_kernel_matmul_t_dA(\n x_ptr,\n y_ptr,\n A_t_ptr,\n lse_global_ptr,\n A_grad_ptr,\n stride_x_N,\n stride_x_H,\n stride_A_H,\n stride_A_V,\n V: tl.constexpr,\n N: tl.constexpr,\n H: tl.constexpr,\n V_BLOCK_SIZE: tl.constexpr = 16,\n N_BLOCK_SIZE: tl.constexpr = 16,\n H_BLOCK_SIZE: tl.constexpr = 16,\n):\n idx_V = tl.program_id(axis=0)\n tl.static_assert(N % N_BLOCK_SIZE == 0)\n tl.static_assert(V % V_BLOCK_SIZE == 0)\n tl.static_assert(H % H_BLOCK_SIZE == 0)\n\n N_offsets = tl.arange(0, N_BLOCK_SIZE)\n V_offsets = idx_V * V_BLOCK_SIZE + tl.arange(0, V_BLOCK_SIZE)\n\n A_block_ptr = tl.make_block_ptr(\n base=A_t_ptr,\n shape=(H, V),\n strides=(stride_A_H, stride_A_V),\n offsets=(0, idx_V * V_BLOCK_SIZE),\n block_shape=(H_BLOCK_SIZE, V_BLOCK_SIZE),\n order=(1, 0),\n )\n A_grad_block_ptr = tl.make_block_ptr(\n base=A_grad_ptr,\n shape=(H, V),\n strides=(stride_A_H, stride_A_V),\n offsets=(0 * H_BLOCK_SIZE, idx_V * V_BLOCK_SIZE),\n block_shape=(H_BLOCK_SIZE, V_BLOCK_SIZE),\n order=(1, 0),\n )\n\n x_block_ptr = tl.make_block_ptr(\n base=x_ptr,\n shape=(N, H),\n strides=(stride_x_N, stride_x_H),\n offsets=(0 * N_BLOCK_SIZE, 0),\n block_shape=(N_BLOCK_SIZE, H_BLOCK_SIZE),\n order=(1, 0),\n )\n\n for idx_N in range(N // N_BLOCK_SIZE):\n y = tl.load(y_ptr + N_offsets)\n lse = tl.load(lse_global_ptr + N_offsets)\n\n z_j_to_k = tl.zeros((N_BLOCK_SIZE, V_BLOCK_SIZE), dtype=tl.float32)\n for _ in range(H // H_BLOCK_SIZE):\n x_chunk = tl.load(x_block_ptr)\n A_v = tl.load(A_block_ptr)\n\n z_j_to_k = tl.dot(x_chunk, A_v, z_j_to_k)\n\n x_block_ptr = tl.advance(x_block_ptr, [0, H_BLOCK_SIZE])\n A_block_ptr = tl.advance(A_block_ptr, [H_BLOCK_SIZE, 0])\n x_block_ptr = tl.advance(x_block_ptr, [0, -H])\n A_block_ptr = tl.advance(A_block_ptr, [-H, 0])\n\n mask = (y[:, None] == V_offsets[None, :])[:, :, None]\n softmax_z = (z_j_to_k - lse[:, None]).exp().to(tl.float16)\n\n for idx_H in range(H // H_BLOCK_SIZE):\n x_chunk = tl.load(x_block_ptr)\n temp_Agrad = tl.dot(softmax_z.trans(), x_chunk)\n temp_Agrad -= tl.sum(tl.where(mask, x_chunk[:, None, :], 0.0), axis=0)\n temp_AgradT = temp_Agrad.trans() / N\n tl.store(A_grad_block_ptr, temp_AgradT.to(tl.float16) + tl.load(A_grad_block_ptr))\n\n A_grad_block_ptr = tl.advance(A_grad_block_ptr, [H_BLOCK_SIZE, 0])\n x_block_ptr = tl.advance(x_block_ptr, [0, H_BLOCK_SIZE])\n x_block_ptr = tl.advance(x_block_ptr, [N_BLOCK_SIZE, -H])\n A_grad_block_ptr = tl.advance(A_grad_block_ptr, [-H, 0])\n N_offsets += N_BLOCK_SIZE\n\n@triton.autotune(\n configs=[\n triton.Config({\"V_BLOCK_SIZE\": 16, \"N_BLOCK_SIZE\": 16, \"H_BLOCK_SIZE\": 512}),\n ],\n key=[\"V\", \"N\", \"H\"],\n reset_to_zero=[\"x_grad_ptr\"],\n)\n@triton.jit()\ndef linear_xent_bwd_kernel_matmul_t_dx(\n x_ptr,\n y_ptr,\n A_t_ptr,\n lse_global_ptr,\n x_grad_ptr,\n stride_x_N,\n stride_x_H,\n stride_A_H,\n stride_A_V,\n V: tl.constexpr,\n N: tl.constexpr,\n H: tl.constexpr,\n V_BLOCK_SIZE: tl.constexpr = 16,\n N_BLOCK_SIZE: tl.constexpr = 16,\n H_BLOCK_SIZE: tl.constexpr = 16,\n H_GROUP_SIZE: tl.constexpr = 4,\n):\n idx_N = tl.program_id(axis=0)\n idx_H_group = tl.program_id(axis=1)\n\n tl.static_assert(N % N_BLOCK_SIZE == 0)\n tl.static_assert(V % V_BLOCK_SIZE == 0)\n tl.static_assert(H % H_BLOCK_SIZE == 0)\n\n H_GROUPS: tl.constexpr = H // (H_GROUP_SIZE * H_BLOCK_SIZE)\n\n N_offsets = idx_N * N_BLOCK_SIZE + tl.arange(0, N_BLOCK_SIZE)\n V_offsets = tl.arange(0, V_BLOCK_SIZE)\n H_group_offsets = tl.arange(0, H_GROUP_SIZE)\n\n y = tl.load(y_ptr + N_offsets)\n lse = tl.load(lse_global_ptr + N_offsets)\n\n x_grad_acc = tl.zeros((N_BLOCK_SIZE, H_BLOCK_SIZE, H_GROUP_SIZE), dtype=tl.float16)\n\n for idx_V in range(V // V_BLOCK_SIZE):\n z_j_to_k = tl.zeros((N_BLOCK_SIZE, V_BLOCK_SIZE), dtype=tl.float32)\n for idx_H_1 in range(H // H_BLOCK_SIZE):\n x_block_ptr = tl.make_block_ptr(\n base=x_ptr,\n shape=(N, H),\n strides=(stride_x_N, stride_x_H),\n offsets=(idx_N * N_BLOCK_SIZE, idx_H_1 * H_BLOCK_SIZE),\n block_shape=(N_BLOCK_SIZE, H_BLOCK_SIZE),\n order=(1, 0),\n )\n A_block_ptr = tl.make_block_ptr(\n base=A_t_ptr,\n shape=(H, V),\n strides=(stride_A_H, stride_A_V),\n offsets=(idx_H_1 * H_BLOCK_SIZE, idx_V * V_BLOCK_SIZE),\n block_shape=(H_BLOCK_SIZE, V_BLOCK_SIZE),\n order=(1, 0),\n )\n x_chunk = tl.load(x_block_ptr)\n A_v = tl.load(A_block_ptr)\n\n z_j_to_k = tl.dot(x_chunk, A_v, z_j_to_k)\n\n mask = (y[:, None] == V_offsets[None, :])[:, :, None]\n softmax_z = (z_j_to_k - lse[:, None]).exp().to(tl.float16)\n\n A_block_ptr = tl.make_block_ptr(\n base=A_t_ptr,\n shape=(H, V),\n strides=(stride_A_H, stride_A_V),\n offsets=((H_GROUP_SIZE * H_BLOCK_SIZE) * idx_H_group, idx_V * V_BLOCK_SIZE),\n block_shape=(H_BLOCK_SIZE, V_BLOCK_SIZE),\n order=(1, 0),\n )\n for idx_H_in_group in range(H_GROUP_SIZE):\n A_v = tl.load(A_block_ptr).trans()\n x_grad_block = tl.dot(softmax_z, A_v) / N\n x_grad_block -= tl.sum(tl.where(mask, A_v[None, :, :], 0), axis=1) / N\n x_grad_slice = x_grad_block[:, :, None].to(tl.float16)\n\n accum_mask = (idx_H_in_group == H_group_offsets)[None, None, :]\n x_grad_acc += tl.where(accum_mask, x_grad_slice, 0)\n A_block_ptr = tl.advance(A_block_ptr, [H_BLOCK_SIZE, 0])\n V_offsets += V_BLOCK_SIZE\n\n x_grad_block_ptr = tl.make_block_ptr(\n base=x_grad_ptr,\n shape=(N, H),\n strides=(stride_x_N, stride_x_H),\n offsets=(idx_N * N_BLOCK_SIZE, idx_H_group * H_GROUP_SIZE * H_BLOCK_SIZE),\n block_shape=(N_BLOCK_SIZE, H_GROUP_SIZE * H_BLOCK_SIZE),\n order=(1, 0),\n )\n tl.store(x_grad_block_ptr, x_grad_acc.reshape(N_BLOCK_SIZE, H_GROUP_SIZE * H_BLOCK_SIZE))\n\nclass LinearCrossEntropyLoss(torch.autograd.Function):\n\n @staticmethod\n def forward(\n ctx,\n x,\n y,\n At,\n ignore_index=-100,\n ):\n N, H = x.shape\n H_A, V = At.shape\n assert H_A == H\n assert y.shape == (N,)\n x = x.contiguous()\n y = y.contiguous()\n At = At.contiguous()\n\n assert V % 256 == 0, f\"V is {V}\"\n assert N % 64 == 0, f\"N is {N}\"\n assert H % 64 == 0, f\"H is {H}\"\n\n lse_global = torch.zeros(N, dtype=torch.float32, device=x.device)\n losses = torch.zeros(N // 16, dtype=torch.float32, device=x.device)\n\n grid = lambda meta: (triton.cdiv(N, meta[\"N_BLOCK_SIZE\"]),)\n\n with torch.cuda.device(x.device.index):\n linear_xent_fwd_kernel_matmul_t[grid](\n x, y, At, losses, lse_global, x.stride(0), x.stride(1), At.stride(0), At.stride(1), V=V, N=N, H=H\n )\n print(\"fwd config:\", linear_xent_fwd_kernel_matmul_t.best_config)\n\n ctx.save_for_backward(x, y, At, lse_global)\n\n return losses.sum()\n\n @staticmethod\n def backward(ctx, grad_output):\n x, y, At, lse_global = ctx.saved_tensors\n N, H = x.shape\n _, V = At.shape\n\n xgrad = torch.zeros_like(x)\n Atgrad = torch.zeros_like(At)\n\n with torch.cuda.device(x.device.index):\n grid = lambda meta: (\n triton.cdiv(N, meta[\"N_BLOCK_SIZE\"]),\n triton.cdiv(H, meta[\"H_GROUP_SIZE\"] * meta[\"H_BLOCK_SIZE\"]),\n )\n linear_xent_bwd_kernel_matmul_t_dx[grid](\n x,\n y,\n At,\n lse_global,\n xgrad,\n x.stride(0),\n x.stride(1),\n At.stride(0),\n At.stride(1),\n V=V,\n N=N,\n H=H,\n )\n print(\"bwd config dx:\", linear_xent_bwd_kernel_matmul_t_dx.best_config)\n\n ctx.mark_non_differentiable(y)\n return xgrad * grad_output, None, Atgrad * grad_output, None\n\ndef linear_cross_entropy(x, y, At):\n return LinearCrossEntropyLoss.apply(x, y, At)\n", - "description_1": "Use triton language to implement a linear cross-entropy loss function with forward and backward passes. The forward kernel 'linear_xent_fwd_kernel_matmul_t' computes the loss and log-sum-exp values for given inputs and weights. The backward kernels 'linear_xent_bwd_kernel_matmul_t_dA' and 'linear_xent_bwd_kernel_matmul_t_dx' compute the gradients with respect to the weights and inputs, respectively. The function 'LinearCrossEntropyLoss' manages the forward and backward operations, ensuring the tensors are contiguous and compatible with the kernel requirements.", - "description_2": "Use triton language to create a linear cross-entropy loss function with forward and backward kernels for efficient GPU computation.", - "difficulty": 5 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.autotune(\n configs=[\n triton.Config({\"V_BLOCK_SIZE\": 256, \"N_BLOCK_SIZE\": 64, \"H_BLOCK_SIZE\": 64}),\n ],\n key=[\"V\", \"N\", \"H\"],\n reset_to_zero=[\"losses_ptr\", \"lse_ptr\"],\n)\n@triton.jit\ndef linear_xent_fwd_kernel_matmul_t(\n x_ptr,\n y_ptr,\n A_t_ptr,\n losses_ptr,\n lse_ptr,\n stride_x_N,\n stride_x_H,\n stride_A_H,\n stride_A_V,\n V: tl.constexpr,\n N: tl.constexpr,\n H: tl.constexpr,\n V_BLOCK_SIZE: tl.constexpr,\n N_BLOCK_SIZE: tl.constexpr,\n H_BLOCK_SIZE: tl.constexpr,\n):\n idx = tl.program_id(axis=0)\n x_block_ptr = tl.make_block_ptr(\n base=x_ptr,\n shape=(N, H),\n strides=(stride_x_N, stride_x_H),\n offsets=(idx * N_BLOCK_SIZE, 0),\n block_shape=(N_BLOCK_SIZE, H_BLOCK_SIZE),\n order=(1, 0),\n )\n A_block_ptr = tl.make_block_ptr(\n base=A_t_ptr,\n shape=(H, V),\n strides=(stride_A_H, stride_A_V),\n offsets=(0, 0),\n block_shape=(H_BLOCK_SIZE, V_BLOCK_SIZE),\n order=(1, 0),\n )\n offsets = idx * N_BLOCK_SIZE + tl.arange(0, N_BLOCK_SIZE)\n v_range = tl.arange(0, V_BLOCK_SIZE)\n y = tl.load(y_ptr + offsets)\n\n m = tl.zeros((N_BLOCK_SIZE,), dtype=tl.float32) - float(10e5)\n s = tl.zeros((N_BLOCK_SIZE,), dtype=tl.float32)\n loss = 0.0\n\n for _ in range(V // V_BLOCK_SIZE):\n z_j_to_k = tl.zeros((N_BLOCK_SIZE, V_BLOCK_SIZE), dtype=tl.float32)\n local_x_block_ptr = x_block_ptr\n for _ in range(H // H_BLOCK_SIZE):\n x_chunk = tl.load(local_x_block_ptr)\n A_v = tl.load(A_block_ptr)\n\n z_j_to_k = tl.dot(x_chunk, A_v, z_j_to_k)\n\n local_x_block_ptr = tl.advance(local_x_block_ptr, [0, H_BLOCK_SIZE])\n A_block_ptr = tl.advance(A_block_ptr, [H_BLOCK_SIZE, 0])\n\n m_new = tl.maximum(m, tl.max(z_j_to_k, 1))\n s_update = tl.sum(tl.exp(z_j_to_k - m_new[:, None]), axis=1)\n s = s * tl.exp(m - m_new) + s_update\n\n mask = y[:, None] == v_range[None, :]\n loss -= tl.sum(tl.where(mask, z_j_to_k, float(0.0))) / N\n\n m = m_new\n A_block_ptr = tl.advance(A_block_ptr, [-H_BLOCK_SIZE * (H // H_BLOCK_SIZE), V_BLOCK_SIZE])\n v_range = v_range + V_BLOCK_SIZE\n\n lse = m + tl.log(s)\n loss += tl.sum(lse) / N\n tl.store(losses_ptr + idx, loss)\n tl.store(lse_ptr + offsets, lse)\n\n\n@triton.autotune(\n configs=[\n triton.Config({\"V_BLOCK_SIZE\": 16, \"N_BLOCK_SIZE\": 16, \"H_BLOCK_SIZE\": 16}),\n ],\n key=[\"V\", \"N\", \"H\"],\n reset_to_zero=[\"x_grad_ptr\", \"A_grad_ptr\"],\n)\n@triton.jit()\ndef linear_xent_bwd_kernel_matmul_t_dA_dx(\n x_ptr,\n y_ptr,\n A_t_ptr,\n lse_global_ptr,\n x_grad_ptr,\n A_grad_ptr,\n locks_N_ptr,\n locks_V_ptr,\n stride_x_N,\n stride_x_H,\n stride_A_H,\n stride_A_V,\n V: tl.constexpr,\n N: tl.constexpr,\n H: tl.constexpr,\n V_BLOCK_SIZE: tl.constexpr = 16,\n N_BLOCK_SIZE: tl.constexpr = 16,\n H_BLOCK_SIZE: tl.constexpr = 16,\n):\n linear_xent_bwd_kernel_matmul_t_dA(\n x_ptr,\n y_ptr,\n A_t_ptr,\n lse_global_ptr,\n A_grad_ptr,\n locks_N_ptr,\n stride_x_N,\n stride_x_H,\n stride_A_H,\n stride_A_V,\n V,\n N,\n H,\n V_BLOCK_SIZE,\n N_BLOCK_SIZE,\n H_BLOCK_SIZE,\n )\n linear_xent_bwd_kernel_matmul_t_dx(\n x_ptr,\n y_ptr,\n A_t_ptr,\n lse_global_ptr,\n x_grad_ptr,\n locks_V_ptr,\n stride_x_N,\n stride_x_H,\n stride_A_H,\n stride_A_V,\n V,\n N,\n H,\n V_BLOCK_SIZE,\n N_BLOCK_SIZE,\n H_BLOCK_SIZE,\n )\n\n\n@triton.autotune(\n configs=[\n triton.Config({\"V_BLOCK_SIZE\": 16, \"N_BLOCK_SIZE\": 16, \"H_BLOCK_SIZE\": 16}),\n ],\n key=[\"V\", \"N\", \"H\"],\n reset_to_zero=[\"A_grad_ptr\"],\n)\n@triton.jit()\ndef linear_xent_bwd_kernel_matmul_t_dA(\n x_ptr,\n y_ptr,\n A_t_ptr,\n lse_global_ptr,\n A_grad_ptr,\n locks_N_ptr,\n stride_x_N,\n stride_x_H,\n stride_A_H,\n stride_A_V,\n V: tl.constexpr,\n N: tl.constexpr,\n H: tl.constexpr,\n V_BLOCK_SIZE: tl.constexpr = 16,\n N_BLOCK_SIZE: tl.constexpr = 16,\n H_BLOCK_SIZE: tl.constexpr = 16,\n):\n idx_V, idx_N = tl.program_id(axis=0), tl.program_id(axis=1)\n tl.static_assert(N % N_BLOCK_SIZE == 0)\n tl.static_assert(V % V_BLOCK_SIZE == 0)\n tl.static_assert(H % H_BLOCK_SIZE == 0)\n\n N_offsets = idx_N * N_BLOCK_SIZE + tl.arange(0, N_BLOCK_SIZE)\n V_offsets = idx_V * V_BLOCK_SIZE + tl.arange(0, V_BLOCK_SIZE)\n\n A_block_ptr = tl.make_block_ptr(\n base=A_t_ptr,\n shape=(H, V),\n strides=(stride_A_H, stride_A_V),\n offsets=(0, idx_V * V_BLOCK_SIZE),\n block_shape=(H_BLOCK_SIZE, V_BLOCK_SIZE),\n order=(1, 0),\n )\n A_grad_block_ptr = tl.make_block_ptr(\n base=A_grad_ptr,\n shape=(H, V),\n strides=(stride_A_H, stride_A_V),\n offsets=(0 * H_BLOCK_SIZE, idx_V * V_BLOCK_SIZE),\n block_shape=(H_BLOCK_SIZE, V_BLOCK_SIZE),\n order=(1, 0),\n )\n x_block_ptr = tl.make_block_ptr(\n base=x_ptr,\n shape=(N, H),\n strides=(stride_x_N, stride_x_H),\n offsets=(idx_N * N_BLOCK_SIZE, 0),\n block_shape=(N_BLOCK_SIZE, H_BLOCK_SIZE),\n order=(1, 0),\n )\n\n y = tl.load(y_ptr + N_offsets)\n lse = tl.load(lse_global_ptr + N_offsets)\n\n z_j_to_k = tl.zeros((N_BLOCK_SIZE, V_BLOCK_SIZE), dtype=tl.float32)\n for _ in range(H // H_BLOCK_SIZE):\n x_chunk = tl.load(x_block_ptr)\n A_v = tl.load(A_block_ptr)\n\n z_j_to_k = tl.dot(x_chunk, A_v, z_j_to_k)\n\n x_block_ptr = tl.advance(x_block_ptr, [0, H_BLOCK_SIZE])\n A_block_ptr = tl.advance(A_block_ptr, [H_BLOCK_SIZE, 0])\n x_block_ptr = tl.advance(x_block_ptr, [0, -H])\n\n mask = (y[:, None] == V_offsets[None, :])[:, :, None]\n softmax_z = (z_j_to_k - lse[:, None]).exp().to(tl.float16)\n\n for idx_H in range(H // H_BLOCK_SIZE):\n x_chunk = tl.load(x_block_ptr)\n temp_Agrad = tl.dot(softmax_z.trans(), x_chunk)\n temp_Agrad -= tl.sum(tl.where(mask, x_chunk[:, None, :], 0.0), axis=0)\n temp_AgradT = (temp_Agrad.trans() / N).to(tl.float16)\n while tl.atomic_cas(locks_N_ptr + idx_V, 0, 1) == 1:\n pass\n tl.store(A_grad_block_ptr, temp_AgradT + tl.load(A_grad_block_ptr))\n tl.atomic_xchg(locks_N_ptr + idx_V, 0)\n\n A_grad_block_ptr = tl.advance(A_grad_block_ptr, [H_BLOCK_SIZE, 0])\n x_block_ptr = tl.advance(x_block_ptr, [0, H_BLOCK_SIZE])\n\n\n@triton.autotune(\n configs=[\n triton.Config({\"V_BLOCK_SIZE\": 16, \"N_BLOCK_SIZE\": 16, \"H_BLOCK_SIZE\": 16}),\n ],\n key=[\"V\", \"N\", \"H\"],\n reset_to_zero=[\"x_grad_ptr\"],\n)\n@triton.jit()\ndef linear_xent_bwd_kernel_matmul_t_dx(\n x_ptr,\n y_ptr,\n A_t_ptr,\n lse_global_ptr,\n x_grad_ptr,\n locks_V_ptr,\n stride_x_N,\n stride_x_H,\n stride_A_H,\n stride_A_V,\n V: tl.constexpr,\n N: tl.constexpr,\n H: tl.constexpr,\n V_BLOCK_SIZE: tl.constexpr = 16,\n N_BLOCK_SIZE: tl.constexpr = 16,\n H_BLOCK_SIZE: tl.constexpr = 16,\n):\n idx_V, idx_N = tl.program_id(axis=0), tl.program_id(axis=1)\n tl.static_assert(N % N_BLOCK_SIZE == 0)\n tl.static_assert(V % V_BLOCK_SIZE == 0)\n tl.static_assert(H % H_BLOCK_SIZE == 0)\n\n N_offsets = idx_N * N_BLOCK_SIZE + tl.arange(0, N_BLOCK_SIZE)\n V_offsets = idx_V * V_BLOCK_SIZE + tl.arange(0, V_BLOCK_SIZE)\n\n y = tl.load(y_ptr + N_offsets)\n lse = tl.load(lse_global_ptr + N_offsets)\n\n z_j_to_k = tl.zeros((N_BLOCK_SIZE, V_BLOCK_SIZE), dtype=tl.float32)\n for idx_H_1 in range(H // H_BLOCK_SIZE):\n x_block_ptr = tl.make_block_ptr(\n base=x_ptr,\n shape=(N, H),\n strides=(stride_x_N, stride_x_H),\n offsets=(idx_N * N_BLOCK_SIZE, idx_H_1 * H_BLOCK_SIZE),\n block_shape=(N_BLOCK_SIZE, H_BLOCK_SIZE),\n order=(1, 0),\n )\n A_block_ptr = tl.make_block_ptr(\n base=A_t_ptr,\n shape=(H, V),\n strides=(stride_A_H, stride_A_V),\n offsets=(idx_H_1 * H_BLOCK_SIZE, idx_V * V_BLOCK_SIZE),\n block_shape=(H_BLOCK_SIZE, V_BLOCK_SIZE),\n order=(1, 0),\n )\n x_chunk = tl.load(x_block_ptr)\n A_v = tl.load(A_block_ptr)\n\n z_j_to_k = tl.dot(x_chunk, A_v, z_j_to_k)\n\n mask = (y[:, None] == V_offsets[None, :])[:, :, None]\n softmax_z = (z_j_to_k - lse[:, None]).exp().to(tl.float16)\n\n for idx_H in range(H // H_BLOCK_SIZE):\n x_grad_block_ptr = tl.make_block_ptr(\n base=x_grad_ptr,\n shape=(N, H),\n strides=(stride_x_N, stride_x_H),\n offsets=(idx_N * N_BLOCK_SIZE, idx_H * H_BLOCK_SIZE),\n block_shape=(N_BLOCK_SIZE, H_BLOCK_SIZE),\n order=(1, 0),\n )\n A_block_ptr = tl.make_block_ptr(\n base=A_t_ptr,\n shape=(H, V),\n strides=(stride_A_H, stride_A_V),\n offsets=(idx_H * H_BLOCK_SIZE, idx_V * V_BLOCK_SIZE),\n block_shape=(H_BLOCK_SIZE, V_BLOCK_SIZE),\n order=(1, 0),\n )\n A_v = tl.load(A_block_ptr).trans()\n temp_xgrad = tl.dot(softmax_z, A_v) / N\n temp_xgrad -= tl.sum(tl.where(mask, A_v[None, :, :], 0.0), axis=1) / N\n while tl.atomic_cas(locks_V_ptr + idx_N, 0, 1) == 1:\n pass\n tl.store(x_grad_block_ptr, tl.load(x_grad_block_ptr) + temp_xgrad.to(tl.float16))\n tl.atomic_xchg(locks_V_ptr + idx_N, 0)\n\n\nclass LinearCrossEntropyLoss(torch.autograd.Function):\n\n @staticmethod\n def forward(\n ctx,\n x,\n y,\n At,\n ignore_index=-100,\n ):\n N, H = x.shape\n H_A, V = At.shape\n assert H_A == H\n assert y.shape == (N,)\n x = x.contiguous()\n y = y.contiguous()\n At = At.contiguous()\n\n lse_global = torch.zeros(N, dtype=torch.float32, device=x.device)\n losses = torch.zeros(N // 16, dtype=torch.float32, device=x.device)\n\n grid = lambda meta: (triton.cdiv(N, meta[\"N_BLOCK_SIZE\"]),)\n\n with torch.cuda.device(x.device.index):\n linear_xent_fwd_kernel_matmul_t[grid](\n x, y, At, losses, lse_global, x.stride(0), x.stride(1), At.stride(0), At.stride(1), V=V, N=N, H=H\n )\n ctx.save_for_backward(x, y, At, lse_global)\n\n return losses.sum()\n\n @staticmethod\n def backward(ctx, grad_output):\n x, y, At, lse_global = ctx.saved_tensors\n N, H = x.shape\n _, V = At.shape\n\n xgrad = torch.zeros_like(x)\n Atgrad = torch.zeros_like(At)\n locks_N = torch.zeros(N // 16, dtype=torch.int32, device=x.device)\n locks_V = torch.zeros(V // 16, dtype=torch.int32, device=x.device)\n\n with torch.cuda.device(x.device.index):\n grid = lambda meta: (triton.cdiv(V, meta[\"V_BLOCK_SIZE\"]), triton.cdiv(N, meta[\"N_BLOCK_SIZE\"]))\n linear_xent_bwd_kernel_matmul_t_dA[grid](\n x,\n y,\n At,\n lse_global,\n Atgrad,\n locks_N,\n x.stride(0),\n x.stride(1),\n At.stride(0),\n At.stride(1),\n V=V,\n N=N,\n H=H,\n )\n grid = lambda meta: (triton.cdiv(V, meta[\"V_BLOCK_SIZE\"]), triton.cdiv(N, meta[\"N_BLOCK_SIZE\"]))\n linear_xent_bwd_kernel_matmul_t_dx[grid](\n x,\n y,\n At,\n lse_global,\n xgrad,\n locks_V,\n x.stride(0),\n x.stride(1),\n At.stride(0),\n At.stride(1),\n V=V,\n N=N,\n H=H,\n )\n\n ctx.mark_non_differentiable(y)\n return xgrad * grad_output, None, Atgrad * grad_output, None\n\n\ndef linear_cross_entropy(x, y, At):\n return LinearCrossEntropyLoss.apply(x, y, At)\n\n\nif __name__ == \"__main__\":\n f = 4\n V, N, H = 131072 // f, 4096 * 4 // f, 4096 // f\n\n compute_dtype = torch.float16\n\n y = torch.randint(0, V, (N,), device=torch.device(\"cuda:0\"))\n A = torch.randn(V, H, requires_grad=True, device=torch.device(\"cuda:0\"), dtype=compute_dtype)\n At = A.clone().detach().T.contiguous()\n At.requires_grad_()\n\n x = 0.01 * A[y].clone().detach() + torch.randn(N, H, device=torch.device(\"cuda:0\"), dtype=compute_dtype)\n x.requires_grad_()\n\n def baseline_torch(x, y, A):\n V = A.shape[0]\n return torch.nn.functional.cross_entropy(torch.nn.functional.linear(x, A).view(-1, V).float(), y.view(-1))\n\n loss = baseline_torch(x.float(), y, A.float())\n loss.backward()\n\n reference_A_grad = A.grad.float().clone()\n reference_x_grad = x.grad.float().clone()\n reference_loss = loss.detach().float().clone()\n\n def simple_bench(fn, reference_loss, reference_x_grad, reference_A_grad):\n torch.cuda.synchronize()\n start_event = torch.cuda.Event(enable_timing=True)\n end_event = torch.cuda.Event(enable_timing=True)\n start_event.record()\n loss_triton = fn()\n loss_triton.backward()\n end_event.record()\n torch.cuda.synchronize()\n estimate_ms_bwd = start_event.elapsed_time(end_event)\n print(f\"fwd-bwd : {estimate_ms_bwd}ms\")\n print(f\"fwd error: {torch.dist(loss_triton, reference_loss).item()}\")\n if At.grad is not None:\n A_error = torch.dist(reference_A_grad.T, At.grad).item()\n else:\n A_error = torch.dist(reference_A_grad, A.grad).item()\n print(f\"bwd error: {torch.dist(reference_x_grad, x.grad).item()}, {A_error}\")\n\n simple_bench(lambda: linear_cross_entropy(x, y, At), reference_loss, reference_x_grad, reference_A_grad)\n", - "description_1": "Use triton language to implement linear cross-entropy loss with two main kernels: one for forward pass and another for backward pass. The forward kernel computes the softmax and cross-entropy loss using block pointers for efficient memory access, optimizing matrix multiplication. It takes 15 parameters including pointers to data and constant expressions like block sizes. The backward kernel is split into two sub-kernels: one for calculating the gradient of matrix A (dA) and another for calculating the gradient with respect to x (dx). It similarly uses block pointers and takes 17 parameters, including pointers and constant block sizes.", - "description_2": "Use triton language to compute forward and backward passes for cross-entropy loss in a neural network with multiple configurations for autotuning block sizes. Utilize block pointers for efficient memory access in kernel computation.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.autotune(\n configs=[\n triton.Config({\"V_BLOCK_SIZE\": 256, \"N_BLOCK_SIZE\": 64, \"H_BLOCK_SIZE\": 64}),\n ],\n key=[\"V\", \"N\", \"H\"],\n reset_to_zero=[\"losses_ptr\", \"lse_ptr\"],\n)\n@triton.jit\ndef linear_xent_fwd_kernel_matmul_t(\n x_ptr,\n y_ptr,\n A_t_ptr,\n losses_ptr,\n lse_ptr,\n stride_x_N,\n stride_x_H,\n stride_A_H,\n stride_A_V,\n V: tl.constexpr,\n N: tl.constexpr,\n H: tl.constexpr,\n V_BLOCK_SIZE: tl.constexpr,\n N_BLOCK_SIZE: tl.constexpr,\n H_BLOCK_SIZE: tl.constexpr,\n):\n idx = tl.program_id(axis=0)\n tl.static_print(V_BLOCK_SIZE, N_BLOCK_SIZE, H_BLOCK_SIZE)\n x_block_ptr = tl.make_block_ptr(\n base=x_ptr,\n shape=(N, H),\n strides=(stride_x_N, stride_x_H),\n offsets=(idx * N_BLOCK_SIZE, 0),\n block_shape=(N_BLOCK_SIZE, H_BLOCK_SIZE),\n order=(1, 0),\n )\n A_block_ptr = tl.make_block_ptr(\n base=A_t_ptr,\n shape=(H, V),\n strides=(stride_A_H, stride_A_V),\n offsets=(0, 0),\n block_shape=(H_BLOCK_SIZE, V_BLOCK_SIZE),\n order=(1, 0),\n )\n offsets = idx * N_BLOCK_SIZE + tl.arange(0, N_BLOCK_SIZE)\n v_range = tl.arange(0, V_BLOCK_SIZE)\n y = tl.load(y_ptr + offsets)\n\n m = tl.zeros((N_BLOCK_SIZE,), dtype=tl.float32) - float(10e5)\n s = tl.zeros((N_BLOCK_SIZE,), dtype=tl.float32)\n loss = 0.0\n\n for _ in range(V // V_BLOCK_SIZE):\n\n z_j_to_k = tl.zeros((N_BLOCK_SIZE, V_BLOCK_SIZE), dtype=tl.float32)\n local_x_block_ptr = x_block_ptr\n for _ in range(H // H_BLOCK_SIZE):\n x_chunk = tl.load(local_x_block_ptr) # Nc x H\n A_v = tl.load(A_block_ptr) # Vc x H\n\n z_j_to_k = tl.dot(x_chunk, A_v, z_j_to_k) # (Nc x H) @ (H x Vc)\n\n local_x_block_ptr = tl.advance(local_x_block_ptr, [0, H_BLOCK_SIZE])\n A_block_ptr = tl.advance(A_block_ptr, [H_BLOCK_SIZE, 0])\n\n m_new = tl.maximum(m, tl.max(z_j_to_k, 1))\n\n s_update = tl.sum(tl.exp(z_j_to_k - m_new[:, None]), axis=1)\n s = s * tl.exp(m - m_new) + s_update\n\n mask = y[:, None] == v_range[None, :] # Nc x Vc\n loss -= tl.sum(tl.where(mask, z_j_to_k, float(0.0))) / N\n\n m = m_new\n A_block_ptr = tl.advance(A_block_ptr, [-H_BLOCK_SIZE * (H // H_BLOCK_SIZE), V_BLOCK_SIZE])\n v_range = v_range + V_BLOCK_SIZE\n\n lse = m + tl.log(s)\n loss += tl.sum(lse) / N\n tl.store(losses_ptr + idx, loss)\n tl.store(lse_ptr + offsets, lse)\n\n\n@triton.autotune(\n configs=[\n triton.Config({\"V_BLOCK_SIZE\": 32, \"N_BLOCK_SIZE\": 32, \"H_BLOCK_SIZE\": 32}),\n ],\n key=[\"V\", \"N\", \"H\"],\n reset_to_zero=[\"A_grad_ptr\"],\n)\n@triton.jit()\ndef linear_xent_bwd_kernel_matmul_t_dA(\n x_ptr,\n y_ptr,\n A_t_ptr,\n lse_global_ptr,\n A_grad_ptr,\n stride_x_N,\n stride_x_H,\n stride_A_H,\n stride_A_V,\n V: tl.constexpr,\n N: tl.constexpr,\n H: tl.constexpr,\n V_BLOCK_SIZE: tl.constexpr = 16,\n N_BLOCK_SIZE: tl.constexpr = 16,\n H_BLOCK_SIZE: tl.constexpr = 16,\n):\n idx_V = tl.program_id(axis=0)\n\n tl.static_print(V_BLOCK_SIZE, N_BLOCK_SIZE, H_BLOCK_SIZE)\n\n N_offsets = tl.arange(0, N_BLOCK_SIZE)\n V_offsets = idx_V * V_BLOCK_SIZE + tl.arange(0, V_BLOCK_SIZE)\n\n A_block_ptr = tl.make_block_ptr(\n base=A_t_ptr,\n shape=(H, V),\n strides=(stride_A_H, stride_A_V),\n offsets=(0, idx_V * V_BLOCK_SIZE),\n block_shape=(H_BLOCK_SIZE, V_BLOCK_SIZE),\n order=(1, 0),\n )\n A_grad_block_ptr = tl.make_block_ptr(\n base=A_grad_ptr,\n shape=(H, V),\n strides=(stride_A_H, stride_A_V),\n offsets=(0 * H_BLOCK_SIZE, idx_V * V_BLOCK_SIZE),\n block_shape=(H_BLOCK_SIZE, V_BLOCK_SIZE),\n order=(1, 0),\n )\n\n x_block_ptr = tl.make_block_ptr(\n base=x_ptr,\n shape=(N, H),\n strides=(stride_x_N, stride_x_H),\n offsets=(0 * N_BLOCK_SIZE, 0),\n block_shape=(N_BLOCK_SIZE, H_BLOCK_SIZE),\n order=(1, 0),\n )\n\n for idx_N in range(N // N_BLOCK_SIZE):\n\n y = tl.load(y_ptr + N_offsets)\n lse = tl.load(lse_global_ptr + N_offsets)\n\n z_j_to_k = tl.zeros((N_BLOCK_SIZE, V_BLOCK_SIZE), dtype=tl.float32)\n for _ in range(H // H_BLOCK_SIZE):\n x_chunk = tl.load(x_block_ptr) # Nc x Hc\n A_v = tl.load(A_block_ptr) # Hc x Vc\n\n z_j_to_k = tl.dot(x_chunk, A_v, z_j_to_k) # (Nc x Hc) @ (Hc x Vc)\n\n x_block_ptr = tl.advance(x_block_ptr, [0, H_BLOCK_SIZE])\n A_block_ptr = tl.advance(A_block_ptr, [H_BLOCK_SIZE, 0])\n x_block_ptr = tl.advance(x_block_ptr, [0, -H])\n A_block_ptr = tl.advance(A_block_ptr, [-H, 0])\n\n mask = (y[:, None] == V_offsets[None, :])[:, :, None] # N_BLOCK_SIZE x V_BLOCK_SIZE x 1\n softmax_z = (z_j_to_k - lse[:, None]).exp().to(tl.float16)\n\n for idx_H in range(H // H_BLOCK_SIZE):\n x_chunk = tl.load(x_block_ptr) # Nc x Hc\n temp_Agrad = tl.dot(softmax_z.trans(), x_chunk)\n temp_Agrad -= tl.sum(tl.where(mask, x_chunk[:, None, :], 0.0), axis=0)\n temp_AgradT = temp_Agrad.trans() / N\n tl.store(A_grad_block_ptr, temp_AgradT.to(tl.float16) + tl.load(A_grad_block_ptr))\n\n A_grad_block_ptr = tl.advance(A_grad_block_ptr, [H_BLOCK_SIZE, 0])\n x_block_ptr = tl.advance(x_block_ptr, [0, H_BLOCK_SIZE])\n x_block_ptr = tl.advance(x_block_ptr, [N_BLOCK_SIZE, -H])\n A_grad_block_ptr = tl.advance(A_grad_block_ptr, [-H, 0])\n N_offsets += N_BLOCK_SIZE\n\n\n@triton.autotune(\n configs=[\n triton.Config({\"V_BLOCK_SIZE\": 16, \"N_BLOCK_SIZE\": 4}),\n ],\n key=[\"V\", \"N\", \"H\"],\n reset_to_zero=[\"x_grad_ptr\"],\n)\n@triton.jit()\ndef linear_xent_bwd_kernel_matmul_t_dx(\n x_ptr,\n y_ptr,\n A_t_ptr,\n lse_global_ptr,\n x_grad_ptr,\n stride_x_N,\n stride_x_H,\n stride_A_H,\n stride_A_V,\n V: tl.constexpr,\n N: tl.constexpr,\n H: tl.constexpr,\n V_BLOCK_SIZE: tl.constexpr = 16,\n N_BLOCK_SIZE: tl.constexpr = 1,\n H_BLOCK_SIZE: tl.constexpr = 16,\n):\n idx_N = tl.program_id(axis=0)\n\n tl.static_print(V_BLOCK_SIZE, N_BLOCK_SIZE)\n\n N_offsets = idx_N * N_BLOCK_SIZE + tl.arange(0, N_BLOCK_SIZE)\n V_offsets = tl.arange(0, V_BLOCK_SIZE)\n\n y = tl.load(y_ptr + N_offsets)\n lse = tl.load(lse_global_ptr + N_offsets)\n x_grad_slice = tl.zeros((N_BLOCK_SIZE, H), tl.float16)\n\n for idx_V in range(V // V_BLOCK_SIZE):\n z_j_to_k = tl.zeros((N_BLOCK_SIZE, V_BLOCK_SIZE), dtype=tl.float32)\n x_block_ptr = tl.make_block_ptr(\n base=x_ptr,\n shape=(N, H),\n strides=(stride_x_N, stride_x_H),\n offsets=(idx_N * N_BLOCK_SIZE, 0),\n block_shape=(N_BLOCK_SIZE, H),\n order=(1, 0),\n )\n A_slice_ptr = tl.make_block_ptr(\n base=A_t_ptr,\n shape=(H, V),\n strides=(stride_A_H, stride_A_V),\n offsets=(0, idx_V * V_BLOCK_SIZE),\n block_shape=(H, V_BLOCK_SIZE),\n order=(1, 0),\n )\n x_chunk = tl.load(x_block_ptr) # Nc x Hc\n A_v_full = tl.load(A_slice_ptr) # Hc x Vc\n\n z_j_to_k = tl.sum(x_chunk[:, :, None] * A_v_full[None, :, :], axis=1).to(tl.float32)\n\n mask = (y[:, None] == V_offsets[None, :])[:, :, None] # N_BLOCK_SIZE x V_BLOCK_SIZE x 1\n\n softmax_z = (z_j_to_k - lse[:, None]).exp().to(tl.float16)\n temp_xgrad = tl.sum(softmax_z[:, :, None] * A_v_full.trans()[None, :, :], axis=1) / N\n temp_xgrad -= tl.sum(tl.where(mask, A_v_full.trans()[None, :, :], 0.0), axis=1) / N\n temp_xgrad = temp_xgrad.to(tl.float16)\n x_grad_slice += temp_xgrad\n V_offsets += V_BLOCK_SIZE\n\n x_grad_block_ptr = tl.make_block_ptr(\n base=x_grad_ptr,\n shape=(N, H),\n strides=(stride_x_N, stride_x_H),\n offsets=(idx_N * N_BLOCK_SIZE, 0),\n block_shape=(N_BLOCK_SIZE, H),\n order=(1, 0),\n )\n tl.store(x_grad_block_ptr, x_grad_slice)\n\n\nclass LinearCrossEntropyLoss(torch.autograd.Function):\n\n @staticmethod\n def forward(\n ctx,\n x,\n y,\n At,\n ignore_index=-100,\n ):\n N, H = x.shape\n H_A, V = At.shape\n assert H_A == H\n assert y.shape == (N,)\n x = x.contiguous()\n y = y.contiguous()\n At = At.contiguous()\n\n assert V % 256 == 0, f\"V is {V}\"\n assert N % 64 == 0, f\"N is {N}\"\n assert H % 64 == 0, f\"H is {H}\"\n\n lse_global = torch.zeros(N, dtype=torch.float32, device=x.device)\n losses = torch.zeros(N // 16, dtype=torch.float32, device=x.device)\n\n grid = lambda meta: (triton.cdiv(N, meta[\"N_BLOCK_SIZE\"]),)\n\n with torch.cuda.device(x.device.index):\n linear_xent_fwd_kernel_matmul_t[grid](\n x, y, At, losses, lse_global, x.stride(0), x.stride(1), At.stride(0), At.stride(1), V=V, N=N, H=H\n )\n print(\"fwd config:\", linear_xent_fwd_kernel_matmul_t.best_config)\n\n ctx.save_for_backward(x, y, At, lse_global)\n\n return losses.sum()\n\n @staticmethod\n def backward(ctx, grad_output):\n x, y, At, lse_global = ctx.saved_tensors\n N, H = x.shape\n _, V = At.shape\n\n xgrad = torch.zeros_like(x)\n Atgrad = torch.zeros_like(At)\n\n with torch.cuda.device(x.device.index):\n grid = lambda meta: (triton.cdiv(N, meta[\"N_BLOCK_SIZE\"]),)\n linear_xent_bwd_kernel_matmul_t_dx[grid](\n x,\n y,\n At,\n lse_global,\n xgrad,\n x.stride(0),\n x.stride(1),\n At.stride(0),\n At.stride(1),\n V=V,\n N=N,\n H=H,\n )\n print(\"bwd config dx:\", linear_xent_bwd_kernel_matmul_t_dx.best_config)\n\n ctx.mark_non_differentiable(y)\n return xgrad * grad_output, None, Atgrad * grad_output, None\n\n\ndef linear_cross_entropy(x, y, At):\n return LinearCrossEntropyLoss.apply(x, y, At)\n", - "description_1": "Use triton language to implement a linear cross-entropy loss function with forward and backward passes. The forward kernel 'linear_xent_fwd_kernel_matmul_t' computes the loss and log-sum-exp values for given inputs x, y, and transposed matrix At. The backward kernels 'linear_xent_bwd_kernel_matmul_t_dA' and 'linear_xent_bwd_kernel_matmul_t_dx' compute the gradients with respect to At and x, respectively. The function 'linear_cross_entropy' serves as a wrapper for the autograd function 'LinearCrossEntropyLoss', which manages the forward and backward computations.", - "description_2": "Use triton language to create a linear cross-entropy loss function with kernels for forward and backward computations, handling inputs x, y, and transposed matrix At, and computing necessary gradients.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.autotune(\n configs=[\n triton.Config({\"V_BLOCK_SIZE\": 512, \"N_BLOCK_SIZE\": 64, \"H_BLOCK_SIZE\": 64, \"GROUP_SIZE\": 32}, num_warps=8),\n triton.Config({\"V_BLOCK_SIZE\": 512, \"N_BLOCK_SIZE\": 64, \"H_BLOCK_SIZE\": 64, \"GROUP_SIZE\": 1}, num_warps=8),\n triton.Config({\"V_BLOCK_SIZE\": 512, \"N_BLOCK_SIZE\": 64, \"H_BLOCK_SIZE\": 64, \"GROUP_SIZE\": 2}, num_warps=8),\n triton.Config({\"V_BLOCK_SIZE\": 512, \"N_BLOCK_SIZE\": 64, \"H_BLOCK_SIZE\": 64, \"GROUP_SIZE\": 4}, num_warps=8),\n triton.Config({\"V_BLOCK_SIZE\": 512, \"N_BLOCK_SIZE\": 64, \"H_BLOCK_SIZE\": 64, \"GROUP_SIZE\": 8}, num_warps=8),\n triton.Config({\"V_BLOCK_SIZE\": 512, \"N_BLOCK_SIZE\": 64, \"H_BLOCK_SIZE\": 64, \"GROUP_SIZE\": 32}, num_warps=16),\n triton.Config({\"V_BLOCK_SIZE\": 512, \"N_BLOCK_SIZE\": 64, \"H_BLOCK_SIZE\": 64, \"GROUP_SIZE\": 32}, num_warps=32),\n triton.Config(\n {\"V_BLOCK_SIZE\": 512, \"N_BLOCK_SIZE\": 64, \"H_BLOCK_SIZE\": 64, \"GROUP_SIZE\": 32}, num_warps=8, num_stages=4\n ),\n triton.Config(\n {\"V_BLOCK_SIZE\": 512, \"N_BLOCK_SIZE\": 64, \"H_BLOCK_SIZE\": 64, \"GROUP_SIZE\": 32}, num_warps=8, num_stages=5\n ),\n triton.Config(\n {\"V_BLOCK_SIZE\": 512, \"N_BLOCK_SIZE\": 64, \"H_BLOCK_SIZE\": 64, \"GROUP_SIZE\": 32}, num_warps=8, num_stages=6\n ),\n ],\n key=[\"V\", \"N\", \"H\"],\n)\n@triton.jit\ndef linear_xent_fwd_prep_bwd_kernel_matmul_t(\n x_ptr,\n y_ptr,\n A_t_ptr,\n z_nv_ptr,\n losses_ptr,\n lse_ptr,\n stride_x_N,\n stride_x_H,\n stride_A_H,\n stride_A_V,\n stride_z_N,\n stride_z_V,\n stride_lse_N,\n stride_lse_B,\n stride_loss_Nb,\n stride_loss_B,\n idx_N_group,\n N_group: tl.constexpr,\n V: tl.constexpr,\n N: tl.constexpr,\n H: tl.constexpr,\n V_BLOCK_SIZE: tl.constexpr = 16,\n N_BLOCK_SIZE: tl.constexpr = 16,\n H_BLOCK_SIZE: tl.constexpr = 16,\n GROUP_SIZE: tl.constexpr = 32,\n):\n # Function body...\n\n@triton.autotune(\n configs=[\n triton.Config({\"V_BLOCK_SIZE\": 128, \"N_BLOCK_SIZE\": 128, \"H_BLOCK_SIZE\": 128, \"GROUP_SIZE\": 32}, num_warps=8),\n triton.Config(\n {\"V_BLOCK_SIZE\": 128, \"N_BLOCK_SIZE\": 128, \"H_BLOCK_SIZE\": 128, \"GROUP_SIZE\": 32}, num_warps=4, num_stages=4\n ),\n triton.Config(\n {\"V_BLOCK_SIZE\": 128, \"N_BLOCK_SIZE\": 128, \"H_BLOCK_SIZE\": 128, \"GROUP_SIZE\": 32}, num_warps=8, num_stages=4\n ),\n triton.Config(\n {\"V_BLOCK_SIZE\": 128, \"N_BLOCK_SIZE\": 128, \"H_BLOCK_SIZE\": 128, \"GROUP_SIZE\": 32}, num_warps=4, num_stages=3\n ),\n triton.Config(\n {\"V_BLOCK_SIZE\": 128, \"N_BLOCK_SIZE\": 128, \"H_BLOCK_SIZE\": 128, \"GROUP_SIZE\": 32}, num_warps=8, num_stages=3\n ),\n triton.Config({\"V_BLOCK_SIZE\": 256, \"N_BLOCK_SIZE\": 128, \"H_BLOCK_SIZE\": 128, \"GROUP_SIZE\": 32}, num_warps=8),\n triton.Config(\n {\"V_BLOCK_SIZE\": 256, \"N_BLOCK_SIZE\": 128, \"H_BLOCK_SIZE\": 128, \"GROUP_SIZE\": 32}, num_warps=4, num_stages=4\n ),\n triton.Config(\n {\"V_BLOCK_SIZE\": 256, \"N_BLOCK_SIZE\": 128, \"H_BLOCK_SIZE\": 128, \"GROUP_SIZE\": 32}, num_warps=8, num_stages=4\n ),\n ],\n key=[\"V\", \"N\", \"H\"],\n)\n@triton.jit()\ndef linear_xent_bwd_kernel_matmul_t_epilogue_dx(\n z_nv_ptr,\n y_ptr,\n A_t_ptr,\n x_grad_ptr,\n lse_ptr,\n stride_x_N,\n stride_x_H,\n stride_A_H,\n stride_A_V,\n stride_z_N,\n stride_z_V,\n idx_N_group,\n N_group: tl.constexpr,\n V: tl.constexpr,\n N: tl.constexpr,\n H: tl.constexpr,\n V_BLOCK_SIZE: tl.constexpr = 16,\n N_BLOCK_SIZE: tl.constexpr = 16,\n H_BLOCK_SIZE: tl.constexpr = 16,\n GROUP_SIZE: tl.constexpr = 1,\n):\n # Function body...\n\n@triton.autotune(\n configs=[\n triton.Config({\"V_BLOCK_SIZE\": 128, \"N_BLOCK_SIZE\": 128, \"H_BLOCK_SIZE\": 128, \"GROUP_SIZE\": 64}, num_warps=4),\n triton.Config({\"V_BLOCK_SIZE\": 128, \"N_BLOCK_SIZE\": 128, \"H_BLOCK_SIZE\": 128, \"GROUP_SIZE\": 64}, num_warps=8),\n triton.Config({\"V_BLOCK_SIZE\": 128, \"N_BLOCK_SIZE\": 128, \"H_BLOCK_SIZE\": 256, \"GROUP_SIZE\": 64}, num_warps=4),\n triton.Config({\"V_BLOCK_SIZE\": 128, \"N_BLOCK_SIZE\": 128, \"H_BLOCK_SIZE\": 256, \"GROUP_SIZE\": 64}, num_warps=8),\n triton.Config({\"V_BLOCK_SIZE\": 128, \"N_BLOCK_SIZE\": 256, \"H_BLOCK_SIZE\": 128, \"GROUP_SIZE\": 64}, num_warps=4),\n triton.Config({\"V_BLOCK_SIZE\": 128, \"N_BLOCK_SIZE\": 256, \"H_BLOCK_SIZE\": 128, \"GROUP_SIZE\": 64}, num_warps=8),\n triton.Config(\n {\"V_BLOCK_SIZE\": 128, \"N_BLOCK_SIZE\": 128, \"H_BLOCK_SIZE\": 128, \"GROUP_SIZE\": 64}, num_warps=4, num_stages=3\n ),\n triton.Config(\n {\"V_BLOCK_SIZE\": 128, \"N_BLOCK_SIZE\": 128, \"H_BLOCK_SIZE\": 128, \"GROUP_SIZE\": 64}, num_warps=8, num_stages=3\n ),\n ],\n key=[\"V\", \"N\", \"H\"],\n)\n@triton.jit()\ndef linear_xent_bwd_kernel_matmul_t_epilogue_dA(\n z_nv_ptr,\n y_ptr,\n x_ptr,\n A_grad_ptr,\n lse_ptr,\n stride_x_N,\n stride_x_H,\n stride_A_H,\n stride_A_V,\n stride_z_N,\n stride_z_V,\n idx_N_group,\n N_group: tl.constexpr,\n V: tl.constexpr,\n N: tl.constexpr,\n H: tl.constexpr,\n V_BLOCK_SIZE: tl.constexpr = 16,\n N_BLOCK_SIZE: tl.constexpr = 16,\n H_BLOCK_SIZE: tl.constexpr = 16,\n GROUP_SIZE: tl.constexpr = 16,\n):\n # Function body...\n\nclass LinearCrossEntropyLoss(torch.autograd.Function):\n\n @staticmethod\n def forward(\n ctx,\n x,\n y,\n At,\n ignore_index=-100,\n N_chunk_size: int = 4096,\n ):\n # Function body...\n\n @staticmethod\n @torch.inference_mode()\n def backward(ctx, grad_output):\n # Function body...\n\ndef linear_cross_entropy(x, y, At, ignore_index=-100, N_chunk_size: int = 4096):\n return LinearCrossEntropyLoss.apply(x, y, At, ignore_index, N_chunk_size)\n", - "description_1": "Use triton language to implement a linear cross-entropy loss function with both forward and backward passes, utilizing multiple Triton kernels for efficient matrix multiplications and gradient calculations in blocks.", - "description_2": "Use triton language to create kernels for forward and backward computations of linear cross-entropy loss, supporting auto-tuning for different block sizes and configurations.", - "difficulty": 4 - }, - { - "code": "import triton\nimport triton.language as tl\nimport torch\nimport math\n\n@triton.autotune(\n configs=fwd_configs,\n key=[\"V\", \"N\", \"H\", \"monitoring\"],\n)\n@triton.jit\ndef linear_xent_fwd_prep_bwd_kernel_matmul_t(\n x_ptr,\n y_ptr,\n A_t_ptr,\n z_nv_ptr,\n losses_ptr,\n lse_ptr,\n m_ptr,\n logit_norm_ptr,\n stride_x_N,\n stride_x_H,\n stride_A_H,\n stride_A_V,\n stride_z_N,\n stride_z_V,\n stride_lse_N,\n stride_lse_B,\n stride_loss_Nb,\n stride_loss_B,\n stride_norm_N,\n stride_norm_V,\n reduction_ptr,\n monitoring: tl.constexpr,\n ignore_index: tl.constexpr,\n logit_scale: tl.constexpr,\n idx_N_group,\n N_group: tl.constexpr,\n V: tl.constexpr,\n N: tl.constexpr,\n H: tl.constexpr,\n V_BLOCK_SIZE: tl.constexpr,\n N_BLOCK_SIZE: tl.constexpr,\n H_BLOCK_SIZE: tl.constexpr,\n GROUP_SIZE: tl.constexpr,\n):\n idx_N = tl.program_id(axis=0)\n idx_V_group = tl.program_id(axis=1)\n num_idx_N, num_idx_V_group = tl.num_programs(0), tl.num_programs(1)\n idx_N, idx_V_group = tl.swizzle2d(idx_N, idx_V_group, num_idx_N, num_idx_V_group, GROUP_SIZE) # type:ignore\n\n tl.static_print(V_BLOCK_SIZE, N_BLOCK_SIZE, H_BLOCK_SIZE, GROUP_SIZE)\n\n R = tl.load(reduction_ptr, eviction_policy=\"evict_last\")\n V_GROUP_SIZE: tl.constexpr = V_BLOCK_SIZE\n x_block_ptr = tl.make_block_ptr(\n base=x_ptr,\n shape=(N, H),\n strides=(stride_x_N, stride_x_H),\n offsets=(idx_N_group * N_group + idx_N * N_BLOCK_SIZE, 0),\n block_shape=(N_BLOCK_SIZE, H_BLOCK_SIZE),\n order=(1, 0),\n )\n A_block_ptr = tl.make_block_ptr(\n base=A_t_ptr,\n shape=(H, V),\n strides=(stride_A_H, stride_A_V),\n offsets=(0, idx_V_group * V_GROUP_SIZE),\n block_shape=(H_BLOCK_SIZE, V_BLOCK_SIZE),\n order=(1, 0), # (0, 1) apparently not faster :<\n )\n z_block_ptr = tl.make_block_ptr(\n base=z_nv_ptr,\n shape=(N_group, V),\n strides=(stride_z_N, stride_z_V),\n offsets=(idx_N * N_BLOCK_SIZE, idx_V_group * V_GROUP_SIZE),\n block_shape=(N_BLOCK_SIZE, V_BLOCK_SIZE),\n order=(1, 0),\n )\n\n z_j_to_k = tl.zeros((N_BLOCK_SIZE, V_BLOCK_SIZE), dtype=tl.float32)\n for _ in range(H // H_BLOCK_SIZE):\n x_chunk = tl.load(x_block_ptr) # Nc x H\n A_v = tl.load(A_block_ptr) # Vc x H\n\n z_j_to_k = tl.dot(x_chunk, A_v, z_j_to_k) # (Nc x H) @ (H x Vc)\n\n x_block_ptr = tl.advance(x_block_ptr, [0, H_BLOCK_SIZE])\n A_block_ptr = tl.advance(A_block_ptr, [H_BLOCK_SIZE, 0])\n\n z_j_to_k = z_j_to_k / logit_scale\n if monitoring:\n logit_pow2 = tl.sum(z_j_to_k * z_j_to_k, axis=1)\n norm_val_ptr = logit_norm_ptr + idx_V_group * stride_norm_V + idx_N * stride_norm_N + tl.arange(0, N_BLOCK_SIZE)\n tl.store(norm_val_ptr, logit_pow2 / N)\n m = tl.max(z_j_to_k, 1)\n s = tl.sum(tl.exp((z_j_to_k - m[:, None])), axis=1)\n\n N_range = idx_N_group * N_group + idx_N * N_BLOCK_SIZE + tl.arange(0, N_BLOCK_SIZE)\n V_range = idx_V_group * V_GROUP_SIZE + tl.arange(0, V_BLOCK_SIZE)\n y = tl.load(y_ptr + N_range)\n\n mask = y[:, None] == tl.where(V_range != ignore_index, V_range, -1)[None, :] # Nc x Vc\n loss = -tl.sum(tl.where(mask, z_j_to_k, 0.0)) / R\n\n # save z for later\n tl.store(z_block_ptr, z_j_to_k.to(z_nv_ptr.type.element_ty)) # can move +log(1/N) here\n\n zero_lse_constant: tl.constexpr = tl.log(1 / tl.cdiv(V, V_BLOCK_SIZE)) # type: ignore\n lse = tl.where(y != ignore_index, m + tl.log(s), zero_lse_constant)\n lse_row_ptr = tl.make_block_ptr(\n base=lse_ptr,\n shape=(N_group, V // 128), # fixed to largest number of possible V blocks\n strides=(stride_lse_N, stride_lse_B),\n offsets=(idx_N * N_BLOCK_SIZE, idx_V_group),\n block_shape=(N_BLOCK_SIZE, 1),\n order=(1, 0),\n )\n tl.store(lse_row_ptr, lse[:, None])\n\n loss_val_ptr = losses_ptr + idx_N * stride_loss_Nb + idx_V_group * stride_loss_B\n # loss += tl.sum(lse) / N # defered until all blocks are done\n tl.store(loss_val_ptr, tl.load(loss_val_ptr) + loss)\n\n if monitoring:\n m_val_ptr = m_ptr + idx_N * stride_loss_Nb + idx_V_group * stride_loss_B\n tl.store(m_val_ptr, tl.maximum(tl.load(m_val_ptr), tl.max(m, 0)))\n\n\n@triton.jit()\ndef linear_xent_bwd_kernel_matmul_t_epilogue_dx(\n z_nv_ptr,\n y_ptr,\n A_t_ptr,\n x_grad_ptr,\n lse_ptr,\n stride_x_N,\n stride_x_H,\n stride_A_H,\n stride_A_V,\n stride_z_N,\n stride_z_V,\n reduction_ptr,\n logit_scale: tl.constexpr,\n z_regularization: tl.constexpr,\n fp32_grad_accumulators: tl.constexpr,\n ignore_index: tl.constexpr,\n idx_N_group,\n N_group: tl.constexpr,\n V: tl.constexpr,\n N: tl.constexpr,\n H: tl.constexpr,\n V_BLOCK_SIZE: tl.constexpr,\n N_BLOCK_SIZE: tl.constexpr,\n H_BLOCK_SIZE: tl.constexpr,\n GROUP_SIZE: tl.constexpr,\n SPLIT_N: tl.constexpr,\n SPLIT_V: tl.constexpr,\n):\n idx_N = tl.program_id(axis=0) // SPLIT_V\n idx_H = tl.program_id(axis=1)\n idx_V_tile = tl.program_id(axis=0) % SPLIT_V\n\n num_idx_N = tl.num_programs(0) - (triton.cdiv(V, V_BLOCK_SIZE) * SPLIT_N) # type: ignore\n num_idx_H = tl.num_programs(1)\n idx_N, idx_H = tl.swizzle2d(idx_N, idx_H, num_idx_N // SPLIT_V, num_idx_H, GROUP_SIZE) # type:ignore\n\n V_split_offset = idx_V_tile * tl.cdiv(V, SPLIT_V)\n\n A_t_block_ptr = tl.make_block_ptr(\n base=A_t_ptr,\n shape=(H, V),\n strides=(stride_A_H, stride_A_V),\n offsets=(idx_H * H_BLOCK_SIZE, V_split_offset),\n block_shape=(H_BLOCK_SIZE, V_BLOCK_SIZE),\n order=(0, 1),\n )\n\n z_block_ptr = tl.make_block_ptr(\n base=z_nv_ptr,\n shape=(N_group, V),\n strides=(stride_z_N, stride_z_V),\n offsets=(idx_N * N_BLOCK_SIZE, V_split_offset),\n block_shape=(N_BLOCK_SIZE, V_BLOCK_SIZE),\n order=(1, 0),\n )\n N_range = idx_N_group * N_group + idx_N * N_BLOCK_SIZE + tl.arange(0, N_BLOCK_SIZE)\n v_range = V_split_offset + tl.arange(0, V_BLOCK_SIZE)\n R = tl.load(reduction_ptr, eviction_policy=\"evict_last\")\n\n y = tl.load(y_ptr + N_range, eviction_policy=\"evict_last\")\n lse = tl.load(lse_ptr + idx_N * N_BLOCK_SIZE + tl.arange(0, N_BLOCK_SIZE), eviction_policy=\"evict_last\")\n\n acc_dtype = tl.float32 if fp32_grad_accumulators else x_grad_ptr.type.element_ty\n x_grad_acc = tl.zeros((N_BLOCK_SIZE, H_BLOCK_SIZE), acc_dtype)\n for _ in range(0, tl.cdiv(V, V_BLOCK_SIZE * SPLIT_V)):\n mask = y[:, None] == v_range[None, :]\n A_v = tl.load(A_t_block_ptr, eviction_policy=\"evict_first\") # Hc x Vc\n z_j_to_k = tl.load(z_block_ptr)\n softmax_z = (z_j_to_k - lse[:, None]).exp()\n\n if z_regularization > 0:\n softmax_z += 2.0 * z_regularization * lse[:, None] * softmax_z\n z_grad = softmax_z - tl.where(mask, 1.0, 0.0) # 1/N, 0 if log(1/N) moved\n valid_z_grad = tl.where((y == ignore_index)[:, None], 0.0, z_grad).to(A_v.type.element_ty)\n\n # xgrad\n x_grad_acc = tl.dot(valid_z_grad, A_v.trans(), x_grad_acc, out_dtype=acc_dtype)\n\n A_t_block_ptr = tl.advance(A_t_block_ptr, [0, V_BLOCK_SIZE])\n z_block_ptr = tl.advance(z_block_ptr, [0, V_BLOCK_SIZE])\n v_range += V_BLOCK_SIZE\n\n if SPLIT_V == 1:\n x_grad_block_ptr = tl.make_block_ptr(\n base=x_grad_ptr,\n shape=(N, H),\n strides=(stride_x_N, stride_x_H),\n offsets=(idx_N_group * N_group + idx_N * N_BLOCK_SIZE, idx_H * H_BLOCK_SIZE),\n block_shape=(N_BLOCK_SIZE, H_BLOCK_SIZE),\n order=(1, 0),\n )\n tl.store(x_grad_block_ptr, (x_grad_acc / R / logit_scale).to(x_grad_ptr.type.element_ty))\n # not divided here if 1/N moved\n else:\n row_n = idx_N_group * N_group + idx_N * N_BLOCK_SIZE + tl.arange(0, N_BLOCK_SIZE)\n row_h = idx_H * H_BLOCK_SIZE + tl.arange(0, H_BLOCK_SIZE)\n x_grad_simple_ptr = x_grad_ptr + row_n[:, None] * stride_x_N + row_h[None, :] * stride_x_H\n tl.atomic_add(x_grad_simple_ptr, (x_grad_acc / R / logit_scale).to(x_grad_ptr.type.element_ty))\n\n\n@triton.jit()\ndef linear_xent_bwd_kernel_matmul_t_epilogue_dA(\n z_nv_ptr,\n y_ptr,\n x_ptr,\n A_grad_ptr,\n lse_ptr,\n entropy_ptr,\n stride_x_N,\n stride_x_H,\n stride_A_H,\n stride_A_V,\n stride_z_N,\n stride_z_V,\n stride_ent_H,\n stride_ent_V,\n reduction_ptr,\n monitoring: tl.constexpr,\n logit_scale: tl.constexpr,\n z_regularization: tl.constexpr,\n fp32_grad_accumulators: tl.constexpr,\n ignore_index: tl.constexpr,\n idx_N_group,\n N_group: tl.constexpr,\n V: tl.constexpr,\n N: tl.constexpr,\n H: tl.constexpr,\n V_BLOCK_SIZE: tl.constexpr,\n N_BLOCK_SIZE: tl.constexpr,\n H_BLOCK_SIZE: tl.constexpr,\n GROUP_SIZE: tl.constexpr,\n SPLIT_N: tl.constexpr,\n SPLIT_V: tl.constexpr,\n):\n idx_V = (tl.program_id(axis=0) - N_group // N_BLOCK_SIZE * SPLIT_V) // SPLIT_N\n idx_H = tl.program_id(axis=1)\n idx_N_tile = (tl.program_id(axis=0) - N_group // N_BLOCK_SIZE * SPLIT_V) % SPLIT_N\n\n num_idx_V, num_idx_H = tl.num_programs(0) - (N_group // N_BLOCK_SIZE * SPLIT_V), tl.num_programs(1)\n idx_V, idx_H = tl.swizzle2d(idx_V, idx_H, num_idx_V // SPLIT_N, num_idx_H, GROUP_SIZE) # type:ignore\n\n N_split_offset = idx_N_tile * tl.cdiv(N_group, SPLIT_N)\n\n x_block_ptr = tl.make_block_ptr(\n base=x_ptr,\n shape=(N, H),\n strides=(stride_x_N, stride_x_H),\n offsets=(idx_N_group * N_group + N_split_offset, idx_H * H_BLOCK_SIZE),\n block_shape=(N_BLOCK_SIZE, H_BLOCK_SIZE),\n order=(1, 0),\n )\n\n z_block_ptr = tl.make_block_ptr(\n base=z_nv_ptr,\n shape=(N_group, V),\n strides=(stride_z_N, stride_z_V),\n offsets=(N_split_offset, idx_V * V_BLOCK_SIZE),\n block_shape=(N_BLOCK_SIZE, V_BLOCK_SIZE),\n order=(1, 0),\n )\n\n N_range = N_split_offset + tl.arange(0, N_BLOCK_SIZE)\n V_range = idx_V * V_BLOCK_SIZE + tl.arange(0, V_BLOCK_SIZE)\n R = tl.load(reduction_ptr, eviction_policy=\"evict_last\")\n logit_entropy = 0.0\n\n acc_dtype = tl.float32 if fp32_grad_accumulators else A_grad_ptr.type.element_ty\n A_grad_acc = tl.zeros((H_BLOCK_SIZE, V_BLOCK_SIZE), acc_dtype)\n for _ in range(0, tl.cdiv(N_group, N_BLOCK_SIZE * SPLIT_N)):\n y = tl.load(y_ptr + idx_N_group * N_group + N_range, eviction_policy=\"evict_last\")\n lse = tl.load(lse_ptr + N_range, eviction_policy=\"evict_last\")\n mask = y[:, None] == V_range[None, :]\n\n x_chunk = tl.load(x_block_ptr, eviction_policy=\"evict_first\")\n z_j_to_k = tl.load(z_block_ptr)\n logprobs = z_j_to_k - lse[:, None]\n softmax_z = logprobs.exp()\n if monitoring:\n logit_entropy += tl.sum(tl.where(y == ignore_index, 0.0, tl.sum(-softmax_z * logprobs, axis=1)))\n if z_regularization > 0:\n softmax_z += 2.0 * z_regularization * lse[:, None] * softmax_z\n z_grad = softmax_z - tl.where(mask, 1.0, 0.0)\n valid_z_grad = tl.where((y == ignore_index)[:, None], 0.0, z_grad).to(x_ptr.type.element_ty)\n\n A_grad_acc = tl.dot(x_chunk.trans(), valid_z_grad, A_grad_acc, out_dtype=acc_dtype)\n\n x_block_ptr = tl.advance(x_block_ptr, [N_BLOCK_SIZE, 0])\n z_block_ptr = tl.advance(z_block_ptr, [N_BLOCK_SIZE, 0])\n N_range += N_BLOCK_SIZE\n\n entropy_val_ptr = entropy_ptr + idx_H * stride_ent_H + idx_V * stride_ent_V\n if SPLIT_N == 1:\n A_grad_T_block_ptr = tl.make_block_ptr(\n base=A_grad_ptr,\n shape=(H, V),\n strides=(stride_A_H, stride_A_V),\n offsets=(idx_H * H_BLOCK_SIZE, idx_V * V_BLOCK_SIZE),\n block_shape=(H_BLOCK_SIZE, V_BLOCK_SIZE),\n order=(0, 1),\n )\n if idx_N_group > 0:\n tl.store(\n A_grad_T_block_ptr,\n tl.load(A_grad_T_block_ptr) + (A_grad_acc / R / logit_scale).to(A_grad_ptr.type.element_ty),\n )\n tl.store(entropy_val_ptr, tl.load(entropy_val_ptr) + logit_entropy / R)\n else:\n tl.store(A_grad_T_block_ptr, (A_grad_acc / R / logit_scale).to(A_grad_ptr.type.element_ty))\n if monitoring:\n tl.store(entropy_val_ptr, logit_entropy / R)\n else:\n row_h = idx_H * H_BLOCK_SIZE + tl.arange(0, H_BLOCK_SIZE)\n row_v = idx_V * V_BLOCK_SIZE + tl.arange(0, V_BLOCK_SIZE)\n A_grad_T_simple_ptr = A_grad_ptr + row_h[:, None] * stride_A_H + row_v[None, :] * stride_A_V\n tl.atomic_add(A_grad_T_simple_ptr, (A_grad_acc / R / logit_scale).to(A_grad_ptr.type.element_ty))\n if monitoring:\n tl.atomic_add(entropy_val_ptr, logit_entropy / R)\n\n\n@triton.autotune(\n configs=bwd_configs,\n key=[\"V\", \"N\", \"H\", \"monitoring\"],\n)\n@triton.jit()\ndef linear_xent_bwd_dispatcher(\n logits_ptr,\n y_ptr,\n x_ptr,\n A_t_ptr,\n x_grad,\n At_grad,\n lse_global,\n logit_entropy_local,\n stride_x_N,\n stride_x_H,\n stride_A_H,\n stride_A_V,\n stride_z_N,\n stride_z_V,\n stride_ent_H,\n stride_ent_V,\n reduction_ptr,\n monitoring: tl.constexpr,\n logit_scale: tl.constexpr,\n z_regularization: tl.constexpr,\n fp32_grad_accumulators: tl.constexpr,\n ignore_index: tl.constexpr,\n idx_N_group,\n N_group: tl.constexpr,\n V: tl.constexpr,\n N: tl.constexpr,\n H: tl.constexpr,\n V_BLOCK_SIZE: tl.constexpr = 128, # type: ignore\n N_BLOCK_SIZE: tl.constexpr = 128, # type: ignore\n H_BLOCK_SIZE: tl.constexpr = 128, # type: ignore\n GROUP_SIZE: tl.constexpr = 32, # type: ignore\n SPLIT_N: tl.constexpr = 2, # type: ignore\n SPLIT_V: tl.constexpr = 2, # type: ignore\n):\n tl.static_print(V_BLOCK_SIZE, N_BLOCK_SIZE, H_BLOCK_SIZE, GROUP_SIZE, SPLIT_N, SPLIT_V)\n\n idx_NV = tl.program_id(axis=0)\n if idx_NV < (N_group // N_BLOCK_SIZE * SPLIT_V):\n linear_xent_bwd_kernel_matmul_t_epilogue_dx(\n logits_ptr,\n y_ptr,\n A_t_ptr,\n x_grad,\n lse_global,\n stride_x_N,\n stride_x_H,\n stride_A_H,\n stride_A_V,\n stride_z_N,\n stride_z_V,\n reduction_ptr,\n logit_scale,\n z_regularization,\n fp32_grad_accumulators,\n ignore_index,\n idx_N_group,\n N_group,\n V,\n N,\n H,\n V_BLOCK_SIZE,\n N_BLOCK_SIZE,\n H_BLOCK_SIZE,\n GROUP_SIZE,\n SPLIT_N,\n SPLIT_V,\n )\n else:\n linear_xent_bwd_kernel_matmul_t_epilogue_dA(\n logits_ptr,\n y_ptr,\n x_ptr,\n At_grad,\n lse_global,\n logit_entropy_local,\n stride_x_N,\n stride_x_H,\n stride_A_H,\n stride_A_V,\n stride_z_N,\n stride_z_V,\n stride_ent_H,\n stride_ent_V,\n reduction_ptr,\n monitoring,\n logit_scale,\n z_regularization,\n fp32_grad_accumulators,\n ignore_index,\n idx_N_group,\n N_group,\n V,\n N,\n H,\n V_BLOCK_SIZE,\n N_BLOCK_SIZE,\n H_BLOCK_SIZE,\n GROUP_SIZE,\n SPLIT_N,\n SPLIT_V,\n )\n", - "description_1": "Use triton language to implement a fused forward and backward pass for linear cross-entropy. It involves two primary kernels: linear_xent_fwd_prep_bwd_kernel_matmul_t for forward computation and gradient preparation, and linear_xent_bwd_dispatcher to handle the gradient computation of inputs and weights.", - "description_2": "Use triton language to create optimized kernels for efficient computation of linear cross-entropy loss with integrated forward and backward passes, leveraging auto-tuning for performance enhancement.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\nimport math\n\n@triton.autotune(\n configs=fwd_configs,\n key=[\"V\", \"N\", \"H\"],\n prune_configs_by={\n \"early_config_prune\": early_config_prune,\n },\n warmup=100,\n rep=500,\n)\n@triton.jit\ndef linear_xent_fwd_kernel_matmul_t(\n x_ptr,\n y_ptr,\n A_t_ptr,\n z_nv_ptr,\n losses_ptr,\n lse_ptr,\n stride_x_N,\n stride_x_H,\n stride_A_H,\n stride_A_V,\n stride_z_N,\n stride_z_V,\n stride_lse_N,\n stride_lse_B,\n stride_loss_Nb,\n stride_loss_B,\n reduction_ptr,\n ignore_index: tl.constexpr,\n idx_N_group,\n N_group: tl.constexpr,\n V: tl.constexpr,\n N: tl.constexpr,\n H: tl.constexpr,\n V_BLOCK_SIZE: tl.constexpr,\n N_BLOCK_SIZE: tl.constexpr,\n H_BLOCK_SIZE: tl.constexpr,\n GROUP_SIZE: tl.constexpr,\n):\n idx_N = tl.program_id(axis=0)\n idx_V_group = tl.program_id(axis=1)\n num_idx_N, num_idx_V_group = tl.num_programs(0), tl.num_programs(1)\n idx_N, idx_V_group = tl.swizzle2d(idx_N, idx_V_group, num_idx_N, num_idx_V_group, GROUP_SIZE)\n\n V_GROUP_SIZE: tl.constexpr = V_BLOCK_SIZE\n x_block_ptr = tl.make_block_ptr(\n base=x_ptr,\n shape=(N, H),\n strides=(stride_x_N, stride_x_H),\n offsets=(idx_N_group * N_group + idx_N * N_BLOCK_SIZE, 0),\n block_shape=(N_BLOCK_SIZE, H_BLOCK_SIZE),\n order=(1, 0),\n )\n A_block_ptr = tl.make_block_ptr(\n base=A_t_ptr,\n shape=(H, V),\n strides=(stride_A_H, stride_A_V),\n offsets=(0, idx_V_group * V_GROUP_SIZE),\n block_shape=(H_BLOCK_SIZE, V_BLOCK_SIZE),\n order=(1, 0),\n )\n z_block_ptr = tl.make_block_ptr(\n base=z_nv_ptr,\n shape=(N_group, V),\n strides=(stride_z_N, stride_z_V),\n offsets=(idx_N * N_BLOCK_SIZE, idx_V_group * V_GROUP_SIZE),\n block_shape=(N_BLOCK_SIZE, V_BLOCK_SIZE),\n order=(1, 0),\n )\n\n z_j_to_k = tl.zeros((N_BLOCK_SIZE, V_BLOCK_SIZE), dtype=tl.float32)\n for _ in range(H // H_BLOCK_SIZE):\n x_chunk = tl.load(x_block_ptr, boundary_check=(0, 1), padding_option=\"zero\")\n A_v = tl.load(A_block_ptr, boundary_check=(0, 1), padding_option=\"zero\")\n\n z_j_to_k = tl.dot(x_chunk, A_v, z_j_to_k)\n\n x_block_ptr = tl.advance(x_block_ptr, [0, H_BLOCK_SIZE])\n A_block_ptr = tl.advance(A_block_ptr, [H_BLOCK_SIZE, 0])\n\n V_range = idx_V_group * V_GROUP_SIZE + tl.arange(0, V_BLOCK_SIZE)\n\n N_range = idx_N_group * N_group + idx_N * N_BLOCK_SIZE + tl.arange(0, N_BLOCK_SIZE)\n\n y = tl.load(y_ptr + N_range, mask=N_range < N, other=ignore_index)\n\n reduction = tl.load(reduction_ptr)\n mask = y[:, None] == tl.where(V_range != ignore_index, V_range, -1)[None, :]\n loss = -tl.sum(tl.where(mask, z_j_to_k, float(0.0))) / reduction\n\n tl.store(z_block_ptr, (z_j_to_k + tl.log(1 / reduction)).to(z_nv_ptr.type.element_ty), boundary_check=(0, 1))\n\n m = tl.max(z_j_to_k, 1)\n zero_lse_constant: tl.constexpr = tl.log(1 / tl.cdiv(V, V_BLOCK_SIZE))\n lse = tl.where(y != ignore_index, tl.log(tl.sum(tl.exp((z_j_to_k - m[:, None])), axis=1)) + m, zero_lse_constant)\n\n lse_row_ptr = tl.make_block_ptr(\n base=lse_ptr,\n shape=(N_group, V // 128),\n strides=(stride_lse_N, stride_lse_B),\n offsets=(idx_N * N_BLOCK_SIZE, idx_V_group),\n block_shape=(N_BLOCK_SIZE, 1),\n order=(1, 0),\n )\n loss_val_ptr = losses_ptr + idx_N * stride_loss_Nb + idx_V_group * stride_loss_B\n tl.store(loss_val_ptr, tl.load(loss_val_ptr) + loss)\n tl.store(lse_row_ptr, lse[:, None], boundary_check=(0,))\n\n@triton.jit()\ndef linear_xent_bwd_kernel_matmul_t_epilogue_dx(\n z_nv_ptr,\n y_ptr,\n A_t_ptr,\n x_grad_ptr,\n lse_ptr,\n stride_x_N,\n stride_x_H,\n stride_A_H,\n stride_A_V,\n stride_z_N,\n stride_z_V,\n z_regularization: tl.constexpr,\n fp32_grad_accumulators: tl.constexpr,\n reduction_ptr,\n ignore_index: tl.constexpr,\n idx_N_group,\n N_group: tl.constexpr,\n V: tl.constexpr,\n N: tl.constexpr,\n H: tl.constexpr,\n V_BLOCK_SIZE: tl.constexpr,\n N_BLOCK_SIZE: tl.constexpr,\n H_BLOCK_SIZE: tl.constexpr,\n GROUP_SIZE: tl.constexpr,\n SPLIT_N: tl.constexpr,\n SPLIT_V: tl.constexpr,\n):\n idx_N = tl.program_id(axis=0) // SPLIT_V\n idx_H = tl.program_id(axis=1)\n idx_V_tile = tl.program_id(axis=0) % SPLIT_V\n\n num_idx_N, num_idx_H = tl.num_programs(0) - (triton.cdiv(V, V_BLOCK_SIZE) * SPLIT_N), tl.num_programs(1)\n idx_N, idx_H = tl.swizzle2d(idx_N, idx_H, num_idx_N // SPLIT_V, num_idx_H, GROUP_SIZE)\n\n V_split_offset = idx_V_tile * tl.cdiv(V, SPLIT_V)\n\n A_t_block_ptr = tl.make_block_ptr(\n base=A_t_ptr,\n shape=(H, V),\n strides=(stride_A_H, stride_A_V),\n offsets=(idx_H * H_BLOCK_SIZE, V_split_offset),\n block_shape=(H_BLOCK_SIZE, V_BLOCK_SIZE),\n order=(0, 1),\n )\n\n z_block_ptr = tl.make_block_ptr(\n base=z_nv_ptr,\n shape=(N_group, V),\n strides=(stride_z_N, stride_z_V),\n offsets=(idx_N * N_BLOCK_SIZE, V_split_offset),\n block_shape=(N_BLOCK_SIZE, V_BLOCK_SIZE),\n order=(1, 0),\n )\n N_range = idx_N_group * N_group + idx_N * N_BLOCK_SIZE + tl.arange(0, N_BLOCK_SIZE)\n V_range = V_split_offset + tl.arange(0, V_BLOCK_SIZE)\n\n y = tl.load(y_ptr + N_range, eviction_policy=\"evict_last\")\n lse = tl.load(lse_ptr + idx_N * N_BLOCK_SIZE + tl.arange(0, N_BLOCK_SIZE), eviction_policy=\"evict_last\")\n reduction = tl.load(reduction_ptr)\n\n acc_dtype = tl.float32 if fp32_grad_accumulators else x_grad_ptr.type.element_ty\n x_grad_acc = tl.zeros((N_BLOCK_SIZE, H_BLOCK_SIZE), acc_dtype)\n for _ in range(0, tl.cdiv(V, V_BLOCK_SIZE * SPLIT_V)):\n mask = y[:, None] == V_range[None, :]\n A_v = tl.load(A_t_block_ptr, eviction_policy=\"evict_first\")\n z_j_to_k = tl.load(z_block_ptr, eviction_policy=\"evict_last\")\n softmax_z = (z_j_to_k - lse[:, None]).exp()\n if z_regularization > 0:\n softmax_z += 2.0 * z_regularization * lse[:, None] * softmax_z\n z_grad = softmax_z - tl.where(mask, 1 / reduction, 0.0)\n valid_z_grad = tl.where((y == ignore_index)[:, None], 0.0, z_grad).to(A_v.type.element_ty)\n\n x_grad_acc = tl.dot(valid_z_grad, A_v.trans(), x_grad_acc, out_dtype=acc_dtype)\n\n A_t_block_ptr = tl.advance(A_t_block_ptr, [0, V_BLOCK_SIZE])\n z_block_ptr = tl.advance(z_block_ptr, [0, V_BLOCK_SIZE])\n V_range += V_BLOCK_SIZE\n\n if SPLIT_V == 1:\n x_grad_block_ptr = tl.make_block_ptr(\n base=x_grad_ptr,\n shape=(N, H),\n strides=(stride_x_N, stride_x_H),\n offsets=(idx_N_group * N_group + idx_N * N_BLOCK_SIZE, idx_H * H_BLOCK_SIZE),\n block_shape=(N_BLOCK_SIZE, H_BLOCK_SIZE),\n order=(1, 0),\n )\n tl.store(x_grad_block_ptr, x_grad_acc.to(x_grad_ptr.type.element_ty))\n else:\n row_n = idx_N_group * N_group + idx_N * N_BLOCK_SIZE + tl.arange(0, N_BLOCK_SIZE)\n row_h = idx_H * H_BLOCK_SIZE + tl.arange(0, H_BLOCK_SIZE)\n x_grad_simple_ptr = x_grad_ptr + row_n[:, None] * stride_x_N + row_h[None, :] * stride_x_H\n tl.atomic_add(x_grad_simple_ptr, x_grad_acc.to(x_grad_ptr.type.element_ty))\n\n@triton.jit()\ndef linear_xent_bwd_kernel_matmul_t_epilogue_dA(\n z_nv_ptr,\n y_ptr,\n x_ptr,\n A_grad_ptr,\n lse_ptr,\n stride_x_N,\n stride_x_H,\n stride_A_H,\n stride_A_V,\n stride_z_N,\n stride_z_V,\n z_regularization: tl.constexpr,\n fp32_grad_accumulators: tl.constexpr,\n reduction_ptr,\n ignore_index: tl.constexpr,\n idx_N_group,\n N_group: tl.constexpr,\n V: tl.constexpr,\n N: tl.constexpr,\n H: tl.constexpr,\n V_BLOCK_SIZE: tl.constexpr,\n N_BLOCK_SIZE: tl.constexpr,\n H_BLOCK_SIZE: tl.constexpr,\n GROUP_SIZE: tl.constexpr,\n SPLIT_N: tl.constexpr,\n SPLIT_V: tl.constexpr,\n):\n idx_V = (tl.program_id(axis=0) - N_group // N_BLOCK_SIZE * SPLIT_V) // SPLIT_N\n idx_H = tl.program_id(axis=1)\n idx_N_tile = (tl.program_id(axis=0) - N_group // N_BLOCK_SIZE * SPLIT_V) % SPLIT_N\n\n num_idx_V, num_idx_H = tl.num_programs(0) - (N_group // N_BLOCK_SIZE * SPLIT_V), tl.num_programs(1)\n idx_V, idx_H = tl.swizzle2d(idx_V, idx_H, num_idx_V // SPLIT_N, num_idx_H, GROUP_SIZE)\n\n N_split_offset = idx_N_tile * tl.cdiv(N_group, SPLIT_N)\n\n x_block_ptr = tl.make_block_ptr(\n base=x_ptr,\n shape=(N, H),\n strides=(stride_x_N, stride_x_H),\n offsets=(idx_N_group * N_group + N_split_offset, idx_H * H_BLOCK_SIZE),\n block_shape=(N_BLOCK_SIZE, H_BLOCK_SIZE),\n order=(1, 0),\n )\n\n z_block_ptr = tl.make_block_ptr(\n base=z_nv_ptr,\n shape=(N_group, V),\n strides=(stride_z_N, stride_z_V),\n offsets=(N_split_offset, idx_V * V_BLOCK_SIZE),\n block_shape=(N_BLOCK_SIZE, V_BLOCK_SIZE),\n order=(1, 0),\n )\n\n N_range = N_split_offset + tl.arange(0, N_BLOCK_SIZE)\n V_range = idx_V * V_BLOCK_SIZE + tl.arange(0, V_BLOCK_SIZE)\n reduction = tl.load(reduction_ptr)\n\n acc_dtype = tl.float32 if fp32_grad_accumulators else A_grad_ptr.type.element_ty\n A_grad_acc = tl.zeros((H_BLOCK_SIZE, V_BLOCK_SIZE), acc_dtype)\n for _ in range(0, tl.cdiv(N_group, N_BLOCK_SIZE * SPLIT_N)):\n y = tl.load(y_ptr + idx_N_group * N_group + N_range, eviction_policy=\"evict_last\")\n lse = tl.load(lse_ptr + N_range, eviction_policy=\"evict_last\")\n mask = y[:, None] == V_range[None, :]\n\n x_chunk = tl.load(x_block_ptr, eviction_policy=\"evict_first\")\n z_j_to_k = tl.load(z_block_ptr, eviction_policy=\"evict_last\")\n softmax_z = (z_j_to_k - lse[:, None]).exp()\n if z_regularization > 0:\n softmax_z += 2.0 * z_regularization * lse[:, None] * softmax_z\n z_grad = softmax_z - tl.where(mask, 1 / reduction, 0)\n valid_z_grad = tl.where((y == ignore_index)[:, None], 0.0, z_grad).to(x_ptr.type.element_ty)\n\n A_grad_acc = tl.dot(x_chunk.trans(), valid_z_grad, A_grad_acc, out_dtype=acc_dtype)\n\n x_block_ptr = tl.advance(x_block_ptr, [N_BLOCK_SIZE, 0])\n z_block_ptr = tl.advance(z_block_ptr, [N_BLOCK_SIZE, 0])\n N_range += N_BLOCK_SIZE\n\n if SPLIT_N == 1:\n A_grad_T_block_ptr = tl.make_block_ptr(\n base=A_grad_ptr,\n shape=(H, V),\n strides=(stride_A_H, stride_A_V),\n offsets=(idx_H * H_BLOCK_SIZE, idx_V * V_BLOCK_SIZE),\n block_shape=(H_BLOCK_SIZE, V_BLOCK_SIZE),\n order=(0, 1),\n )\n if idx_N_group > 0:\n tl.store(\n A_grad_T_block_ptr,\n tl.load(A_grad_T_block_ptr) + A_grad_acc.to(A_grad_ptr.type.element_ty),\n )\n else:\n tl.store(A_grad_T_block_ptr, A_grad_acc.to(A_grad_ptr.type.element_ty))\n else:\n row_h = idx_H * H_BLOCK_SIZE + tl.arange(0, H_BLOCK_SIZE)\n row_v = idx_V * V_BLOCK_SIZE + tl.arange(0, V_BLOCK_SIZE)\n A_grad_T_simple_ptr = A_grad_ptr + row_h[:, None] * stride_A_H + row_v[None, :] * stride_A_V\n tl.atomic_add(A_grad_T_simple_ptr, A_grad_acc.to(A_grad_ptr.type.element_ty))\n\n@triton.autotune(\n configs=bwd_configs,\n key=[\"V\", \"N\", \"H\"],\n prune_configs_by={\n \"early_config_prune\": early_config_prune,\n },\n warmup=100,\n rep=500,\n)\n@triton.jit()\ndef linear_xent_bwd_dispatcher(\n logits_ptr,\n y_ptr,\n x_ptr,\n A_t_ptr,\n x_grad,\n At_grad,\n lse_global,\n stride_x_N,\n stride_x_H,\n stride_A_H,\n stride_A_V,\n stride_z_N,\n stride_z_V,\n z_regularization: tl.constexpr,\n fp32_grad_accumulators: tl.constexpr,\n reduction_ptr,\n ignore_index: tl.constexpr,\n idx_N_group,\n N_group,\n V: tl.constexpr,\n N: tl.constexpr,\n H: tl.constexpr,\n V_BLOCK_SIZE: tl.constexpr = 128,\n N_BLOCK_SIZE: tl.constexpr = 128,\n H_BLOCK_SIZE: tl.constexpr = 128,\n GROUP_SIZE: tl.constexpr = 32,\n SPLIT_N: tl.constexpr = 2,\n SPLIT_V: tl.constexpr = 2,\n):\n idx_NV = tl.program_id(axis=0)\n if idx_NV < (N_group // N_BLOCK_SIZE * SPLIT_V):\n linear_xent_bwd_kernel_matmul_t_epilogue_dx(\n logits_ptr,\n y_ptr,\n A_t_ptr,\n x_grad,\n lse_global,\n stride_x_N,\n stride_x_H,\n stride_A_H,\n stride_A_V,\n stride_z_N,\n stride_z_V,\n z_regularization,\n fp32_grad_accumulators,\n reduction_ptr,\n ignore_index,\n idx_N_group,\n N_group,\n V,\n N,\n H,\n V_BLOCK_SIZE,\n N_BLOCK_SIZE,\n H_BLOCK_SIZE,\n GROUP_SIZE,\n SPLIT_N,\n SPLIT_V,\n )\n else:\n linear_xent_bwd_kernel_matmul_t_epilogue_dA(\n logits_ptr,\n y_ptr,\n x_ptr,\n At_grad,\n lse_global,\n stride_x_N,\n stride_x_H,\n stride_A_H,\n stride_A_V,\n stride_z_N,\n stride_z_V,\n z_regularization,\n fp32_grad_accumulators,\n reduction_ptr,\n ignore_index,\n idx_N_group,\n N_group,\n V,\n N,\n H,\n V_BLOCK_SIZE,\n N_BLOCK_SIZE,\n H_BLOCK_SIZE,\n GROUP_SIZE,\n SPLIT_N,\n SPLIT_V,\n )\n", - "description_1": "Use triton language to implement a series of kernels and dispatcher for computing forward and backward passes of a linear cross-entropy operation with matrix multiplication and epilogues. The forward kernel `linear_xent_fwd_kernel_matmul_t` takes 27 parameters including pointers to input and output tensors, stride values, reduction pointer, and block size parameters for efficient computation. The backward pass involves two kernels `linear_xent_bwd_kernel_matmul_t_epilogue_dx` and `linear_xent_bwd_kernel_matmul_t_epilogue_dA`, each requiring 28 parameters to compute gradients with respect to input and weights respectively. The dispatcher function `linear_xent_bwd_dispatcher` oversees the execution of backward kernels based on program ids and splits.", - "description_2": "Use triton language to develop optimized kernels and a dispatcher for linear cross-entropy with matrix multiplication, which support both forward and backward passes. These kernels and the dispatcher handle tensor pointers, strides, and block sizes for efficient parallel computation on GPUs.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\nimport math\n\n@triton.autotune(\n configs=[\n triton.Config(\n {\"V_BLOCK_SIZE\": 512, \"N_BLOCK_SIZE\": 64, \"H_BLOCK_SIZE\": 64, \"GROUP_SIZE\": 1}, num_warps=8, num_stages=2\n ),\n triton.Config(\n {\"V_BLOCK_SIZE\": 256, \"N_BLOCK_SIZE\": 64, \"H_BLOCK_SIZE\": 64, \"GROUP_SIZE\": 32}, num_warps=4, num_stages=3\n ),\n triton.Config(\n {\"V_BLOCK_SIZE\": 256, \"N_BLOCK_SIZE\": 64, \"H_BLOCK_SIZE\": 64, \"GROUP_SIZE\": 16}, num_warps=8, num_stages=3\n ),\n triton.Config(\n {\"V_BLOCK_SIZE\": 128, \"N_BLOCK_SIZE\": 64, \"H_BLOCK_SIZE\": 64, \"GROUP_SIZE\": 1}, num_warps=4, num_stages=2\n ),\n triton.Config(\n {\"V_BLOCK_SIZE\": 512, \"N_BLOCK_SIZE\": 64, \"H_BLOCK_SIZE\": 64, \"GROUP_SIZE\": 16}, num_warps=8, num_stages=2\n ),\n triton.Config(\n {\"V_BLOCK_SIZE\": 256, \"N_BLOCK_SIZE\": 64, \"H_BLOCK_SIZE\": 64, \"GROUP_SIZE\": 1}, num_warps=4, num_stages=2\n ),\n triton.Config(\n {\"V_BLOCK_SIZE\": 128, \"N_BLOCK_SIZE\": 64, \"H_BLOCK_SIZE\": 64, \"GROUP_SIZE\": 16}, num_warps=4, num_stages=3\n ),\n triton.Config(\n {\"V_BLOCK_SIZE\": 256, \"N_BLOCK_SIZE\": 64, \"H_BLOCK_SIZE\": 64, \"GROUP_SIZE\": 1}, num_warps=8, num_stages=3\n ),\n triton.Config(\n {\"V_BLOCK_SIZE\": 512, \"N_BLOCK_SIZE\": 64, \"H_BLOCK_SIZE\": 64, \"GROUP_SIZE\": 32}, num_warps=8, num_stages=2\n ),\n triton.Config(\n {\"V_BLOCK_SIZE\": 256, \"N_BLOCK_SIZE\": 128, \"H_BLOCK_SIZE\": 64, \"GROUP_SIZE\": 1}, num_warps=8, num_stages=2\n ),\n triton.Config(\n {\"V_BLOCK_SIZE\": 256, \"N_BLOCK_SIZE\": 128, \"H_BLOCK_SIZE\": 64, \"GROUP_SIZE\": 64}, num_warps=8, num_stages=2\n ),\n triton.Config(\n {\"V_BLOCK_SIZE\": 256, \"N_BLOCK_SIZE\": 128, \"H_BLOCK_SIZE\": 64, \"GROUP_SIZE\": 64}, num_warps=16, num_stages=2\n ),\n triton.Config(\n {\"V_BLOCK_SIZE\": 256, \"N_BLOCK_SIZE\": 128, \"H_BLOCK_SIZE\": 128, \"GROUP_SIZE\": 1}, num_warps=8, num_stages=2\n ),\n ],\n key=[\"V\", \"N\", \"H\", \"monitoring\"],\n prune_configs_by={\n \"early_config_prune\": early_config_prune,\n },\n)\n@triton.jit\ndef linear_xent_fwd_prep_bwd_kernel_matmul_t(\n x_ptr, y_ptr, A_t_ptr, z_nv_ptr, losses_ptr, lse_ptr, m_ptr, logit_norm_ptr,\n stride_x_N, stride_x_H, stride_A_H, stride_A_V, stride_z_N, stride_z_V,\n stride_lse_N, stride_lse_B, stride_loss_Nb, stride_loss_B, stride_norm_N,\n stride_norm_V, reduction_ptr, monitoring: tl.constexpr, ignore_index: tl.constexpr,\n logit_scale: tl.constexpr, idx_N_group, N_group: tl.constexpr, V: tl.constexpr, N: tl.constexpr,\n H: tl.constexpr, V_BLOCK_SIZE: tl.constexpr, N_BLOCK_SIZE: tl.constexpr, H_BLOCK_SIZE: tl.constexpr,\n GROUP_SIZE: tl.constexpr,\n):\n pass # Kernel code removed for brevity\n\n@triton.jit()\ndef linear_xent_bwd_kernel_matmul_t_epilogue_dx(\n z_nv_ptr, y_ptr, A_t_ptr, x_grad_ptr, lse_ptr, stride_x_N, stride_x_H, stride_A_H, stride_A_V,\n stride_z_N, stride_z_V, reduction_ptr, logit_scale: tl.constexpr, z_regularization: tl.constexpr,\n fp32_grad_accumulators: tl.constexpr, ignore_index: tl.constexpr, idx_N_group,\n N_group: tl.constexpr, V: tl.constexpr, N: tl.constexpr, H: tl.constexpr,\n V_BLOCK_SIZE: tl.constexpr, N_BLOCK_SIZE: tl.constexpr, H_BLOCK_SIZE: tl.constexpr,\n GROUP_SIZE: tl.constexpr, SPLIT_N: tl.constexpr, SPLIT_V: tl.constexpr,\n):\n pass # Kernel code removed for brevity\n\n@triton.jit()\ndef linear_xent_bwd_kernel_matmul_t_epilogue_dA(\n z_nv_ptr, y_ptr, x_ptr, A_grad_ptr, lse_ptr, entropy_ptr, stride_x_N, stride_x_H,\n stride_A_H, stride_A_V, stride_z_N, stride_z_V, stride_ent_H, stride_ent_V,\n reduction_ptr, monitoring: tl.constexpr, logit_scale: tl.constexpr,\n z_regularization: tl.constexpr, fp32_grad_accumulators: tl.constexpr,\n ignore_index: tl.constexpr, idx_N_group, N_group: tl.constexpr, V: tl.constexpr,\n N: tl.constexpr, H: tl.constexpr, V_BLOCK_SIZE: tl.constexpr, N_BLOCK_SIZE: tl.constexpr,\n H_BLOCK_SIZE: tl.constexpr, GROUP_SIZE: tl.constexpr, SPLIT_N: tl.constexpr, SPLIT_V: tl.constexpr,\n):\n pass # Kernel code removed for brevity\n\n@triton.autotune(\n configs=[\n triton.Config(\n {\"V_BLOCK_SIZE\": 128, \"N_BLOCK_SIZE\": 128, \"H_BLOCK_SIZE\": 128, \"GROUP_SIZE\": 16, \"SPLIT_N\": 1, \"SPLIT_V\": 1},\n num_warps=8, num_stages=1,\n ),\n triton.Config(\n {\"V_BLOCK_SIZE\": 128, \"N_BLOCK_SIZE\": 128, \"H_BLOCK_SIZE\": 256, \"GROUP_SIZE\": 32, \"SPLIT_N\": 1, \"SPLIT_V\": 4},\n num_warps=8, num_stages=2,\n ),\n triton.Config(\n {\"V_BLOCK_SIZE\": 128, \"N_BLOCK_SIZE\": 128, \"H_BLOCK_SIZE\": 256, \"GROUP_SIZE\": 16, \"SPLIT_N\": 1, \"SPLIT_V\": 1},\n num_warps=8, num_stages=2,\n ),\n triton.Config(\n {\"V_BLOCK_SIZE\": 128, \"N_BLOCK_SIZE\": 128, \"H_BLOCK_SIZE\": 128, \"GROUP_SIZE\": 32, \"SPLIT_N\": 1, \"SPLIT_V\": 1},\n num_warps=8, num_stages=3,\n ),\n triton.Config(\n {\"V_BLOCK_SIZE\": 128, \"N_BLOCK_SIZE\": 128, \"H_BLOCK_SIZE\": 256, \"GROUP_SIZE\": 32, \"SPLIT_N\": 1, \"SPLIT_V\": 1},\n num_warps=8, num_stages=2,\n ),\n triton.Config(\n {\"V_BLOCK_SIZE\": 128, \"N_BLOCK_SIZE\": 128, \"H_BLOCK_SIZE\": 128, \"GROUP_SIZE\": 16, \"SPLIT_N\": 1, \"SPLIT_V\": 1},\n num_warps=8, num_stages=2,\n ),\n triton.Config(\n {\"V_BLOCK_SIZE\": 128, \"N_BLOCK_SIZE\": 128, \"H_BLOCK_SIZE\": 256, \"GROUP_SIZE\": 16, \"SPLIT_N\": 1, \"SPLIT_V\": 2},\n num_warps=8, num_stages=2,\n ),\n triton.Config(\n {\"V_BLOCK_SIZE\": 128, \"N_BLOCK_SIZE\": 128, \"H_BLOCK_SIZE\": 256, \"GROUP_SIZE\": 16, \"SPLIT_N\": 1, \"SPLIT_V\": 4},\n num_warps=8, num_stages=2,\n ),\n triton.Config(\n {\"V_BLOCK_SIZE\": 128, \"N_BLOCK_SIZE\": 128, \"H_BLOCK_SIZE\": 256, \"GROUP_SIZE\": 64, \"SPLIT_N\": 1, \"SPLIT_V\": 4},\n num_warps=16, num_stages=2,\n ),\n triton.Config(\n {\"V_BLOCK_SIZE\": 128, \"N_BLOCK_SIZE\": 128, \"H_BLOCK_SIZE\": 256, \"GROUP_SIZE\": 64, \"SPLIT_N\": 1, \"SPLIT_V\": 8},\n num_warps=16,\n ),\n triton.Config(\n {\"V_BLOCK_SIZE\": 128, \"N_BLOCK_SIZE\": 128, \"H_BLOCK_SIZE\": 256, \"GROUP_SIZE\": 64, \"SPLIT_N\": 1, \"SPLIT_V\": 8},\n num_warps=8,\n ),\n triton.Config(\n {\"V_BLOCK_SIZE\": 128, \"N_BLOCK_SIZE\": 128, \"H_BLOCK_SIZE\": 256, \"GROUP_SIZE\": 64, \"SPLIT_N\": 1, \"SPLIT_V\": 4},\n num_warps=8, num_stages=2,\n ),\n ],\n key=[\"V\", \"N\", \"H\", \"monitoring\"],\n prune_configs_by={\n \"early_config_prune\": early_config_prune,\n },\n)\n@triton.jit()\ndef linear_xent_bwd_dispatcher(\n logits_ptr, y_ptr, x_ptr, A_t_ptr, x_grad, At_grad, lse_global, logit_entropy_local,\n stride_x_N, stride_x_H, stride_A_H, stride_A_V, stride_z_N, stride_z_V,\n stride_ent_H, stride_ent_V, reduction_ptr, monitoring: tl.constexpr, logit_scale: tl.constexpr,\n z_regularization: tl.constexpr, fp32_grad_accumulators: tl.constexpr, ignore_index: tl.constexpr,\n idx_N_group, N_group: tl.constexpr, V: tl.constexpr, N: tl.constexpr, H: tl.constexpr,\n V_BLOCK_SIZE: tl.constexpr = 128, # type: ignore\n N_BLOCK_SIZE: tl.constexpr = 128, # type: ignore\n H_BLOCK_SIZE: tl.constexpr = 128, # type: ignore\n GROUP_SIZE: tl.constexpr = 32, # type: ignore\n SPLIT_N: tl.constexpr = 2, # type: ignore\n SPLIT_V: tl.constexpr = 2, # type: ignore\n):\n pass # Kernel code removed for brevity\n", - "description_1": "Use triton language to implement several kernels for a linear cross-entropy operation. The main kernel `linear_xent_fwd_prep_bwd_kernel_matmul_t` computes forward preparation for backpropagation in a linear operation involving matrix multiplication and cross-entropy loss. It uses inputs like feature pointers, logit pointers, and several stride values. Additional kernels like `linear_xent_bwd_kernel_matmul_t_epilogue_dx` and `linear_xent_bwd_kernel_matmul_t_epilogue_dA` compute gradient propagation for the inputs and weights. These kernels work in tandem with the dispatcher `linear_xent_bwd_dispatcher` to correctly route the gradient calculations based on program IDs and other configurations. The dispatcher also helps in splitting the tasks among the triton kernel grid.", - "description_2": "Use triton language to build a linear cross-entropy forward and backward propagation system with a focus on optimizing for memory access patterns and computational efficiency using block-based matrix operations. This involves configuring kernel execution parameters to utilize multiple stages and warps within triton's CUDA-like environment.", - "difficulty": 5 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\nimport math\n\n@triton.autotune(\n configs=[\n triton.Config({\"V_BLOCK_SIZE\": 512, \"N_BLOCK_SIZE\": 64, \"H_BLOCK_SIZE\": 64, \"GROUP_SIZE\": 32}, num_warps=8),\n triton.Config({\"V_BLOCK_SIZE\": 512, \"N_BLOCK_SIZE\": 64, \"H_BLOCK_SIZE\": 64, \"GROUP_SIZE\": 1}, num_warps=8),\n triton.Config({\"V_BLOCK_SIZE\": 512, \"N_BLOCK_SIZE\": 64, \"H_BLOCK_SIZE\": 64, \"GROUP_SIZE\": 2}, num_warps=8),\n triton.Config({\"V_BLOCK_SIZE\": 512, \"N_BLOCK_SIZE\": 64, \"H_BLOCK_SIZE\": 64, \"GROUP_SIZE\": 4}, num_warps=8),\n triton.Config({\"V_BLOCK_SIZE\": 512, \"N_BLOCK_SIZE\": 64, \"H_BLOCK_SIZE\": 64, \"GROUP_SIZE\": 8}, num_warps=8),\n triton.Config({\"V_BLOCK_SIZE\": 512, \"N_BLOCK_SIZE\": 64, \"H_BLOCK_SIZE\": 64, \"GROUP_SIZE\": 32}, num_warps=16),\n triton.Config({\"V_BLOCK_SIZE\": 512, \"N_BLOCK_SIZE\": 64, \"H_BLOCK_SIZE\": 64, \"GROUP_SIZE\": 32}, num_warps=32),\n triton.Config(\n {\"V_BLOCK_SIZE\": 512, \"N_BLOCK_SIZE\": 64, \"H_BLOCK_SIZE\": 64, \"GROUP_SIZE\": 32}, num_warps=8, num_stages=4\n ),\n triton.Config(\n {\"V_BLOCK_SIZE\": 512, \"N_BLOCK_SIZE\": 64, \"H_BLOCK_SIZE\": 64, \"GROUP_SIZE\": 32}, num_warps=8, num_stages=5\n ),\n triton.Config(\n {\"V_BLOCK_SIZE\": 512, \"N_BLOCK_SIZE\": 64, \"H_BLOCK_SIZE\": 64, \"GROUP_SIZE\": 32}, num_warps=8, num_stages=6\n ),\n ],\n key=[\"V\", \"N\", \"H\"],\n)\n@triton.jit\ndef linear_xent_fwd_prep_bwd_kernel_matmul_t(\n x_ptr,\n y_ptr,\n A_t_ptr,\n z_nv_ptr,\n losses_ptr,\n lse_ptr,\n stride_x_N,\n stride_x_H,\n stride_A_H,\n stride_A_V,\n stride_z_N,\n stride_z_V,\n stride_lse_N,\n stride_lse_B,\n stride_loss_Nb,\n stride_loss_B,\n idx_N_group,\n N_group: tl.constexpr,\n V: tl.constexpr,\n N: tl.constexpr,\n H: tl.constexpr,\n V_BLOCK_SIZE: tl.constexpr,\n N_BLOCK_SIZE: tl.constexpr,\n H_BLOCK_SIZE: tl.constexpr,\n GROUP_SIZE: tl.constexpr,\n):\n idx_N = tl.program_id(axis=0)\n idx_V_group = tl.program_id(axis=1)\n num_idx_N, num_idx_V_group = tl.num_programs(0), tl.num_programs(1)\n idx_N, idx_V_group = tl.swizzle2d(idx_N, idx_V_group, num_idx_N, num_idx_V_group, GROUP_SIZE) # type:ignore\n\n V_GROUP_SIZE: tl.constexpr = V_BLOCK_SIZE\n x_block_ptr = tl.make_block_ptr(\n base=x_ptr,\n shape=(N, H),\n strides=(stride_x_N, stride_x_H),\n offsets=(idx_N_group * N_group + idx_N * N_BLOCK_SIZE, 0),\n block_shape=(N_BLOCK_SIZE, H_BLOCK_SIZE),\n order=(1, 0),\n )\n A_block_ptr = tl.make_block_ptr(\n base=A_t_ptr,\n shape=(H, V),\n strides=(stride_A_H, stride_A_V),\n offsets=(0, idx_V_group * V_GROUP_SIZE),\n block_shape=(H_BLOCK_SIZE, V_BLOCK_SIZE),\n order=(1, 0),\n )\n z_block_ptr = tl.make_block_ptr(\n base=z_nv_ptr,\n shape=(N_group, V),\n strides=(stride_z_N, stride_z_V),\n offsets=(idx_N * N_BLOCK_SIZE, idx_V_group * V_GROUP_SIZE),\n block_shape=(N_BLOCK_SIZE, V_BLOCK_SIZE),\n order=(1, 0),\n )\n\n z_j_to_k = tl.zeros((N_BLOCK_SIZE, V_BLOCK_SIZE), dtype=tl.float32)\n for _ in range(H // H_BLOCK_SIZE):\n x_chunk = tl.load(x_block_ptr) # Nc x H\n A_v = tl.load(A_block_ptr) # Vc x H\n\n z_j_to_k = tl.dot(x_chunk, A_v, z_j_to_k) # (Nc x H) @ (H x Vc)\n\n x_block_ptr = tl.advance(x_block_ptr, [0, H_BLOCK_SIZE])\n A_block_ptr = tl.advance(A_block_ptr, [H_BLOCK_SIZE, 0])\n\n m = tl.max(z_j_to_k, 1)\n s = tl.sum(tl.exp((z_j_to_k - m[:, None])), axis=1)\n\n N_range = idx_N_group * N_group + idx_N * N_BLOCK_SIZE + tl.arange(0, N_BLOCK_SIZE)\n V_range = idx_V_group * V_GROUP_SIZE + tl.arange(0, V_BLOCK_SIZE)\n y = tl.load(y_ptr + N_range)\n\n mask = y[:, None] == V_range[None, :] # Nc x Vc\n loss = -tl.sum(tl.where(mask, z_j_to_k, float(0.0))) / N\n\n # save z for later\n tl.store(z_block_ptr, z_j_to_k.to(z_nv_ptr.type.element_ty)) # can move +log(1/N) here\n\n lse = m + tl.log(s)\n\n lse_row_ptr = tl.make_block_ptr(\n base=lse_ptr,\n shape=(N_group, V // 128), # fixed to worst case number assuming max(V_TILES)\n strides=(stride_lse_N, stride_lse_B),\n offsets=(idx_N * N_BLOCK_SIZE, idx_V_group),\n block_shape=(N_BLOCK_SIZE, 1),\n order=(1, 0),\n )\n loss_val_ptr = losses_ptr + idx_N * stride_loss_Nb + idx_V_group * stride_loss_B\n # loss += tl.sum(lse) / N # defered until all blocks are done\n tl.store(loss_val_ptr, tl.load(loss_val_ptr) + loss)\n tl.store(lse_row_ptr, lse[:, None])\n\n\n@triton.jit()\ndef linear_xent_bwd_kernel_matmul_t_epilogue_dx(\n z_nv_ptr,\n y_ptr,\n A_t_ptr,\n x_grad_ptr,\n lse_ptr,\n stride_x_N,\n stride_x_H,\n stride_A_H,\n stride_A_V,\n stride_z_N,\n stride_z_V,\n idx_N_group,\n N_group: tl.constexpr,\n V: tl.constexpr,\n N: tl.constexpr,\n H: tl.constexpr,\n V_BLOCK_SIZE: tl.constexpr,\n N_BLOCK_SIZE: tl.constexpr,\n H_BLOCK_SIZE: tl.constexpr,\n GROUP_SIZE: tl.constexpr,\n):\n idx_N = tl.program_id(axis=0)\n idx_H = tl.program_id(axis=1)\n idx_V = 0\n\n num_idx_N, num_idx_H = tl.num_programs(0) - (V // V_BLOCK_SIZE), tl.num_programs(1)\n idx_N, idx_H = tl.swizzle2d(idx_N, idx_H, num_idx_N, num_idx_H, GROUP_SIZE) # type:ignore\n\n A_t_block_ptr = tl.make_block_ptr(\n base=A_t_ptr,\n shape=(H, V),\n strides=(stride_A_H, stride_A_V),\n offsets=(idx_H * H_BLOCK_SIZE, 0),\n block_shape=(H_BLOCK_SIZE, V_BLOCK_SIZE),\n order=(0, 1),\n )\n\n z_block_ptr = tl.make_block_ptr(\n base=z_nv_ptr,\n shape=(N_group, V),\n strides=(stride_z_N, stride_z_V),\n offsets=(idx_N * N_BLOCK_SIZE, idx_V * V_BLOCK_SIZE),\n block_shape=(N_BLOCK_SIZE, V_BLOCK_SIZE),\n order=(1, 0),\n )\n N_range = idx_N_group * N_group + idx_N * N_BLOCK_SIZE + tl.arange(0, N_BLOCK_SIZE)\n v_range = 0 + tl.arange(0, V_BLOCK_SIZE)\n\n y = tl.load(y_ptr + N_range, eviction_policy=\"evict_last\")\n lse = tl.load(lse_ptr + idx_N * N_BLOCK_SIZE + tl.arange(0, N_BLOCK_SIZE), eviction_policy=\"evict_last\")\n\n x_grad_acc = tl.zeros((N_BLOCK_SIZE, H_BLOCK_SIZE), x_grad_ptr.type.element_ty)\n for _ in range(V // V_BLOCK_SIZE):\n mask = y[:, None] == v_range[None, :]\n A_v = tl.load(A_t_block_ptr, eviction_policy=\"evict_first\") # Hc x Vc\n z_j_to_k = tl.load(z_block_ptr, eviction_policy=\"evict_last\")\n softmax_z = (z_j_to_k - lse[:, None]).exp()\n z_grad = (softmax_z - tl.where(mask, 1.0, 0.0)).to(A_t_ptr.type.element_ty) # 1/N, 0 if log(1/N) moved\n\n # xgrad\n x_grad_acc = tl.dot(z_grad, A_v.trans(), x_grad_acc, out_dtype=x_grad_ptr.type.element_ty)\n\n A_t_block_ptr = tl.advance(A_t_block_ptr, [0, V_BLOCK_SIZE])\n z_block_ptr = tl.advance(z_block_ptr, [0, V_BLOCK_SIZE])\n v_range += V_BLOCK_SIZE\n\n x_grad_block_ptr = tl.make_block_ptr(\n base=x_grad_ptr,\n shape=(N, H),\n strides=(stride_x_N, stride_x_H),\n offsets=(idx_N_group * N_group + idx_N * N_BLOCK_SIZE, idx_H * H_BLOCK_SIZE),\n block_shape=(N_BLOCK_SIZE, H_BLOCK_SIZE),\n order=(1, 0),\n )\n tl.store(x_grad_block_ptr, (x_grad_acc / N).to(x_grad_ptr.type.element_ty)) # not divided here if 1/N moved\n\n\n@triton.jit()\ndef linear_xent_bwd_kernel_matmul_t_epilogue_dA(\n z_nv_ptr,\n y_ptr,\n x_ptr,\n A_grad_ptr,\n lse_ptr,\n stride_x_N,\n stride_x_H,\n stride_A_H,\n stride_A_V,\n stride_z_N,\n stride_z_V,\n idx_N_group,\n N_group: tl.constexpr,\n V: tl.constexpr,\n N: tl.constexpr,\n H: tl.constexpr,\n V_BLOCK_SIZE: tl.constexpr,\n N_BLOCK_SIZE: tl.constexpr,\n H_BLOCK_SIZE: tl.constexpr,\n GROUP_SIZE: tl.constexpr,\n):\n idx_V = tl.program_id(axis=0) - N_group // N_BLOCK_SIZE\n idx_H = tl.program_id(axis=1)\n\n num_idx_V, num_idx_H = tl.num_programs(0) - (N_group // N_BLOCK_SIZE), tl.num_programs(1)\n idx_V, idx_H = tl.swizzle2d(idx_V, idx_H, num_idx_V, num_idx_H, GROUP_SIZE) # type:ignore\n\n x_block_ptr = tl.make_block_ptr(\n base=x_ptr,\n shape=(N, H),\n strides=(stride_x_N, stride_x_H),\n offsets=(idx_N_group * N_group, idx_H * H_BLOCK_SIZE),\n block_shape=(N_BLOCK_SIZE, H_BLOCK_SIZE),\n order=(1, 0),\n )\n\n z_block_ptr = tl.make_block_ptr(\n base=z_nv_ptr,\n shape=(N_group, V),\n strides=(stride_z_N, stride_z_V),\n offsets=(0, idx_V * V_BLOCK_SIZE),\n block_shape=(N_BLOCK_SIZE, V_BLOCK_SIZE),\n order=(1, 0),\n )\n\n N_range = tl.arange(0, N_BLOCK_SIZE)\n V_range = idx_V * V_BLOCK_SIZE + tl.arange(0, V_BLOCK_SIZE)\n\n A_grad_acc = tl.zeros((H_BLOCK_SIZE, V_BLOCK_SIZE), A_grad_ptr.type.element_ty)\n for _ in range(N_group // N_BLOCK_SIZE):\n y = tl.load(y_ptr + idx_N_group * N_group + N_range, eviction_policy=\"evict_last\")\n lse = tl.load(lse_ptr + N_range, eviction_policy=\"evict_last\")\n mask = y[:, None] == V_range[None, :]\n\n x_chunk = tl.load(x_block_ptr, eviction_policy=\"evict_first\")\n z_j_to_k = tl.load(z_block_ptr, eviction_policy=\"evict_last\")\n softmax_z = (z_j_to_k - lse[:, None]).exp()\n z_grad = (softmax_z - tl.where(mask, 1.0, 0.0)).to(x_ptr.type.element_ty)\n\n A_grad_acc = tl.dot(x_chunk.trans(), z_grad, A_grad_acc, out_dtype=A_grad_ptr.type.element_ty)\n\n x_block_ptr = tl.advance(x_block_ptr, [N_BLOCK_SIZE, 0])\n z_block_ptr = tl.advance(z_block_ptr, [N_BLOCK_SIZE, 0])\n N_range += N_BLOCK_SIZE\n\n A_grad_T_block_ptr = tl.make_block_ptr(\n base=A_grad_ptr,\n shape=(H, V),\n strides=(stride_A_H, stride_A_V),\n offsets=(idx_H * H_BLOCK_SIZE, idx_V * V_BLOCK_SIZE),\n block_shape=(H_BLOCK_SIZE, V_BLOCK_SIZE),\n order=(0, 1),\n )\n if idx_N_group > 0:\n tl.store(A_grad_T_block_ptr, tl.load(A_grad_T_block_ptr) + (A_grad_acc / N).to(A_grad_ptr.type.element_ty))\n else:\n tl.store(A_grad_T_block_ptr, (A_grad_acc / N).to(A_grad_ptr.type.element_ty))\n\n\n@triton.autotune(\n configs=[\n triton.Config({\"V_BLOCK_SIZE\": 128, \"N_BLOCK_SIZE\": 128, \"H_BLOCK_SIZE\": 128, \"GROUP_SIZE\": 32}, num_warps=8),\n # Configurations with V_BLOCK_SIZE = 128, GROUP_SIZE = 32\n triton.Config({\"V_BLOCK_SIZE\": 128, \"N_BLOCK_SIZE\": 128, \"H_BLOCK_SIZE\": 128, \"GROUP_SIZE\": 32}, num_warps=4),\n triton.Config(\n {\"V_BLOCK_SIZE\": 128, \"N_BLOCK_SIZE\": 128, \"H_BLOCK_SIZE\": 128, \"GROUP_SIZE\": 32}, num_warps=4, num_stages=3\n ),\n triton.Config(\n {\"V_BLOCK_SIZE\": 128, \"N_BLOCK_SIZE\": 128, \"H_BLOCK_SIZE\": 128, \"GROUP_SIZE\": 32}, num_warps=4, num_stages=4\n ),\n triton.Config({\"V_BLOCK_SIZE\": 128, \"N_BLOCK_SIZE\": 128, \"H_BLOCK_SIZE\": 128, \"GROUP_SIZE\": 32}, num_warps=8),\n triton.Config(\n {\"V_BLOCK_SIZE\": 128, \"N_BLOCK_SIZE\": 128, \"H_BLOCK_SIZE\": 128, \"GROUP_SIZE\": 32}, num_warps=8, num_stages=3\n ),\n triton.Config(\n {\"V_BLOCK_SIZE\": 128, \"N_BLOCK_SIZE\": 128, \"H_BLOCK_SIZE\": 128, \"GROUP_SIZE\": 32}, num_warps=8, num_stages=4\n ),\n triton.Config(\n {\"V_BLOCK_SIZE\": 128, \"N_BLOCK_SIZE\": 128, \"H_BLOCK_SIZE\": 128, \"GROUP_SIZE\": 32},\n num_warps=16,\n num_stages=3,\n ),\n triton.Config({\"V_BLOCK_SIZE\": 128, \"N_BLOCK_SIZE\": 128, \"H_BLOCK_SIZE\": 256, \"GROUP_SIZE\": 32}, num_warps=4),\n triton.Config(\n {\"V_BLOCK_SIZE\": 128, \"N_BLOCK_SIZE\": 128, \"H_BLOCK_SIZE\": 256, \"GROUP_SIZE\": 32}, num_warps=4, num_stages=3\n ),\n triton.Config({\"V_BLOCK_SIZE\": 128, \"N_BLOCK_SIZE\": 128, \"H_BLOCK_SIZE\": 256, \"GROUP_SIZE\": 32}, num_warps=8),\n triton.Config(\n {\"V_BLOCK_SIZE\": 128, \"N_BLOCK_SIZE\": 128, \"H_BLOCK_SIZE\": 256, \"GROUP_SIZE\": 32}, num_warps=8, num_stages=3\n ),\n triton.Config(\n {\"V_BLOCK_SIZE\": 128, \"N_BLOCK_SIZE\": 128, \"H_BLOCK_SIZE\": 256, \"GROUP_SIZE\": 32}, num_warps=8, num_stages=4\n ),\n triton.Config(\n {\"V_BLOCK_SIZE\": 128, \"N_BLOCK_SIZE\": 128, \"H_BLOCK_SIZE\": 256, \"GROUP_SIZE\": 32},\n num_warps=16,\n num_stages=3,\n ),\n triton.Config({\"V_BLOCK_SIZE\": 128, \"N_BLOCK_SIZE\": 256, \"H_BLOCK_SIZE\": 128, \"GROUP_SIZE\": 32}, num_warps=4),\n triton.Config(\n {\"V_BLOCK_SIZE\": 128, \"N_BLOCK_SIZE\": 256, \"H_BLOCK_SIZE\": 128, \"GROUP_SIZE\": 32}, num_warps=4, num_stages=3\n ),\n triton.Config({\"V_BLOCK_SIZE\": 128, \"N_BLOCK_SIZE\": 256, \"H_BLOCK_SIZE\": 128, \"GROUP_SIZE\": 32}, num_warps=8),\n triton.Config(\n {\"V_BLOCK_SIZE\": 128, \"N_BLOCK_SIZE\": 256, \"H_BLOCK_SIZE\": 128, \"GROUP_SIZE\": 32}, num_warps=8, num_stages=3\n ),\n triton.Config(\n {\"V_BLOCK_SIZE\": 128, \"N_BLOCK_SIZE\": 256, \"H_BLOCK_SIZE\": 128, \"GROUP_SIZE\": 32}, num_warps=8, num_stages=4\n ),\n triton.Config(\n {\"V_BLOCK_SIZE\": 128, \"N_BLOCK_SIZE\": 256, \"H_BLOCK_SIZE\": 128, \"GROUP_SIZE\": 32},\n num_warps=16,\n num_stages=3,\n ),\n ],\n key=[\"V\", \"N\", \"H\"],\n)\n@triton.jit()\ndef linear_xent_bwd_dispatcher(\n logits,\n y,\n x,\n At,\n x_grad,\n At_grad,\n lse_global,\n stride_x_N,\n stride_x_H,\n stride_A_H,\n stride_A_V,\n stride_z_N,\n stride_z_V,\n idx_N_group,\n N_group: tl.constexpr,\n V: tl.constexpr,\n N: tl.constexpr,\n H: tl.constexpr,\n V_BLOCK_SIZE: tl.constexpr,\n N_BLOCK_SIZE: tl.constexpr,\n H_BLOCK_SIZE: tl.constexpr,\n GROUP_SIZE: tl.constexpr,\n):\n idx_NV = tl.program_id(axis=0)\n if idx_NV < N_group // N_BLOCK_SIZE:\n linear_xent_bwd_kernel_matmul_t_epilogue_dx(\n logits,\n y,\n At,\n x_grad,\n lse_global,\n stride_x_N,\n stride_x_H,\n stride_A_H,\n stride_A_V,\n stride_z_N,\n stride_z_V,\n idx_N_group,\n N_group,\n V,\n N,\n H,\n V_BLOCK_SIZE,\n N_BLOCK_SIZE,\n H_BLOCK_SIZE,\n GROUP_SIZE,\n )\n else:\n linear_xent_bwd_kernel_matmul_t_epilogue_dA(\n logits,\n y,\n x,\n At_grad,\n lse_global,\n stride_x_N,\n stride_x_H,\n stride_A_H,\n stride_A_V,\n stride_z_N,\n stride_z_V,\n idx_N_group,\n N_group,\n V,\n N,\n H,\n V_BLOCK_SIZE,\n N_BLOCK_SIZE,\n H_BLOCK_SIZE,\n GROUP_SIZE,\n )\n\n\nclass LinearCrossEntropyLoss(torch.autograd.Function):\n @staticmethod\n def forward(\n ctx,\n x,\n y,\n At,\n ignore_index=-100,\n N_chunk_size: int = 4096,\n ):\n N, H = x.shape\n H_A, V = At.shape\n assert H_A == H\n assert y.shape == (N,)\n N_group = min(N, N_chunk_size)\n\n if ignore_index >= 0:\n y[y == ignore_index] = -100\n\n At_grad = torch.zeros_like(At)\n x_grad = torch.empty_like(x)\n\n lse_sum = 0.0\n lse_local = -10e5 * torch.ones(N_group, V // 128, dtype=torch.float32, device=x.device)\n losses = torch.zeros(N_group // 64, V // 128, dtype=torch.float32, device=x.device)\n logits = torch.empty((N_group, V), device=x.device, dtype=torch.float32)\n\n with torch.inference_mode():\n\n fwd_grid = lambda meta: (triton.cdiv(N_group, meta[\"N_BLOCK_SIZE\"]), triton.cdiv(V, meta[\"V_BLOCK_SIZE\"]))\n bwd_grid_dx_dA = lambda meta: (\n triton.cdiv(N_group, meta[\"N_BLOCK_SIZE\"]) + triton.cdiv(V, meta[\"V_BLOCK_SIZE\"]),\n triton.cdiv(H, meta[\"H_BLOCK_SIZE\"]),\n )\n\n for idx_N_group in range(math.ceil(N / N_group)):\n linear_xent_fwd_prep_bwd_kernel_matmul_t[fwd_grid](\n x,\n y,\n At,\n logits,\n losses,\n lse_local,\n x.stride(0),\n x.stride(1),\n At.stride(0),\n At.stride(1),\n logits.stride(0),\n logits.stride(1),\n lse_local.stride(0),\n lse_local.stride(1),\n losses.stride(0),\n losses.stride(1),\n idx_N_group=idx_N_group,\n N_group=N_group,\n V=V,\n N=N,\n H=H,\n )\n\n lse_global = lse_local.logsumexp(dim=1)\n lse_sum += lse_global.sum() / N\n\n if x.requires_grad or At.requires_grad:\n linear_xent_bwd_dispatcher[bwd_grid_dx_dA](\n logits,\n y,\n x,\n At,\n x_grad,\n At_grad,\n lse_global,\n x_grad.stride(0),\n x_grad.stride(1),\n At.stride(0),\n At.stride(1),\n logits.stride(0),\n logits.stride(1),\n idx_N_group=idx_N_group,\n N_group=N_group,\n V=V,\n N=N,\n H=H,\n )\n\n ctx.mark_non_differentiable(y)\n ctx.save_for_backward(x_grad, At_grad.to(At.dtype))\n return lse_sum + losses.sum()\n\n @staticmethod\n @torch.inference_mode()\n def backward(ctx, grad_output):\n x_grad, At_grad = ctx.saved_tensors\n\n return x_grad * grad_output, None, At_grad * grad_output, None, None\n\n\ndef linear_cross_entropy(x, y, At, ignore_index=-100, N_chunk_size: int = 2048):\n return LinearCrossEntropyLoss.apply(x, y, At, ignore_index, N_chunk_size)\n\n", - "description_1": "Use triton language to implement a linear cross entropy loss function, including forward and backward passes. The forward pass computes the loss using a matrix multiplication approach with input tensors x and y, and transposed weights At. The backward pass calculates gradients with respect to x and At using a dispatcher kernel. Autotuning is used for optimal performance.", - "description_2": "Use triton language to implement a linear cross entropy loss function with forward and backward computation, optimized with autotuning.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\nimport math\n\n\n@triton.autotune(\n configs=[\n triton.Config({\"V_BLOCK_SIZE\": 512, \"N_BLOCK_SIZE\": 64, \"H_BLOCK_SIZE\": 64, \"GROUP_SIZE\": 32}, num_warps=8),\n # Additional configurations\n ],\n key=[\"V\", \"N\", \"H\"],\n)\n@triton.jit\ndef linear_xent_fwd_prep_bwd_kernel_matmul_t(\n x_ptr, y_ptr, A_t_ptr, z_nv_ptr, losses_ptr, lse_ptr,\n stride_x_N, stride_x_H, stride_A_H, stride_A_V, stride_z_N, stride_z_V,\n stride_lse_N, stride_lse_B, stride_loss_Nb, stride_loss_B, idx_N_group, \n N_group: tl.constexpr, V: tl.constexpr, N: tl.constexpr, H: tl.constexpr,\n V_BLOCK_SIZE: tl.constexpr, N_BLOCK_SIZE: tl.constexpr, H_BLOCK_SIZE: tl.constexpr,\n GROUP_SIZE: tl.constexpr,\n):\n # Triton kernel implementation\n pass # Implementation logic goes here\n\n\n@triton.jit()\ndef linear_xent_bwd_kernel_matmul_t_epilogue_dx(\n z_nv_ptr, y_ptr, A_t_ptr, x_grad_ptr, lse_ptr,\n stride_x_N, stride_x_H, stride_A_H, stride_A_V,\n stride_z_N, stride_z_V, idx_N_group, \n N_group: tl.constexpr, V: tl.constexpr, N: tl.constexpr, H: tl.constexpr,\n V_BLOCK_SIZE: tl.constexpr, N_BLOCK_SIZE: tl.constexpr, H_BLOCK_SIZE: tl.constexpr,\n GROUP_SIZE: tl.constexpr, SPLIT_V: tl.constexpr,\n):\n # Triton kernel implementation\n pass # Implementation logic goes here\n\n\n@triton.jit()\ndef linear_xent_bwd_kernel_matmul_t_epilogue_dA(\n z_nv_ptr, y_ptr, x_ptr, A_grad_ptr, lse_ptr,\n stride_x_N, stride_x_H, stride_A_H, stride_A_V,\n stride_z_N, stride_z_V, idx_N_group, \n N_group: tl.constexpr, V: tl.constexpr, N: tl.constexpr, H: tl.constexpr,\n V_BLOCK_SIZE: tl.constexpr, N_BLOCK_SIZE: tl.constexpr, H_BLOCK_SIZE: tl.constexpr,\n GROUP_SIZE: tl.constexpr, SPLIT_N: tl.constexpr,\n):\n # Triton kernel implementation\n pass # Implementation logic goes here\n\n\nbwd_configs = []\nfor num_stages in [2]:\n for warps in [4]:\n for v_block in [128]:\n for n_block in [128]:\n for h_block in [128]:\n for group in [32]:\n for split in [2]:\n bwd_configs.append(\n triton.Config(\n {\n \"V_BLOCK_SIZE\": v_block,\n \"N_BLOCK_SIZE\": n_block,\n \"H_BLOCK_SIZE\": h_block,\n \"GROUP_SIZE\": group,\n \"SPLIT_NV\": split,\n },\n num_warps=warps,\n num_stages=num_stages,\n )\n )\n\n\n@triton.autotune(\n configs=bwd_configs,\n key=[\"V\", \"N\", \"H\"],\n)\n@triton.jit()\ndef linear_xent_bwd_dispatcher(\n logits, y, x, At, x_grad, At_grad, lse_global,\n stride_x_N, stride_x_H, stride_A_H, stride_A_V, stride_z_N, stride_z_V,\n idx_N_group, N_group, V, N, H,\n V_BLOCK_SIZE: tl.constexpr, N_BLOCK_SIZE: tl.constexpr, H_BLOCK_SIZE: tl.constexpr,\n GROUP_SIZE: tl.constexpr, SPLIT_NV: tl.constexpr,\n):\n idx_NV = tl.program_id(axis=0)\n if idx_NV < N_group // N_BLOCK_SIZE:\n linear_xent_bwd_kernel_matmul_t_epilogue_dx(\n logits, y, At, x_grad, lse_global,\n stride_x_N, stride_x_H, stride_A_H, stride_A_V, stride_z_N, stride_z_V,\n idx_N_group, N_group, V, N, H, V_BLOCK_SIZE, N_BLOCK_SIZE, H_BLOCK_SIZE,\n GROUP_SIZE, SPLIT_V=SPLIT_NV,\n )\n else:\n linear_xent_bwd_kernel_matmul_t_epilogue_dA(\n logits, y, x, At_grad, lse_global,\n stride_x_N, stride_x_H, stride_A_H, stride_A_V, stride_z_N, stride_z_V,\n idx_N_group, N_group, V, N, H, V_BLOCK_SIZE, N_BLOCK_SIZE, H_BLOCK_SIZE,\n GROUP_SIZE, SPLIT_N=SPLIT_NV,\n )\n", - "description_1": "Use triton language to implement a forward and backward pass for cross-entropy loss with a linear layer, involving multiple kernels and configurations to efficiently compute the loss and its gradients.", - "description_2": "Use triton language to perform efficient block matrix operations for computing linear layer cross-entropy loss and its gradients.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\nimport math\n\n\n@triton.autotune(\n configs=[\n triton.Config(\n {\"V_BLOCK_SIZE\": 512, \"N_BLOCK_SIZE\": 64, \"H_BLOCK_SIZE\": 64, \"GROUP_SIZE\": 1}, num_warps=8, num_stages=2\n ),\n # Additional configs...\n ],\n key=[\"V\", \"N\", \"H\"],\n prune_configs_by={\n \"early_config_prune\": lambda configs, named_args: [\n config\n for config in configs\n if config.kwargs[\"H_BLOCK_SIZE\"] <= named_args[\"x_ptr\"].shape[1]\n ],\n },\n warmup=100,\n rep=500,\n)\n@triton.jit\ndef linear_xent_fwd_prep_bwd_kernel_matmul_t(\n x_ptr,\n y_ptr,\n A_t_ptr,\n z_nv_ptr,\n losses_ptr,\n lse_ptr,\n stride_x_N,\n stride_x_H,\n stride_A_H,\n stride_A_V,\n stride_z_N,\n stride_z_V,\n stride_lse_N,\n stride_lse_B,\n stride_loss_Nb,\n stride_loss_B,\n idx_N_group,\n N_group: tl.constexpr,\n V: tl.constexpr,\n N: tl.constexpr,\n H: tl.constexpr,\n V_BLOCK_SIZE: tl.constexpr,\n N_BLOCK_SIZE: tl.constexpr,\n H_BLOCK_SIZE: tl.constexpr,\n GROUP_SIZE: tl.constexpr,\n):\n idx_N = tl.program_id(axis=0)\n idx_V_group = tl.program_id(axis=1)\n num_idx_N, num_idx_V_group = tl.num_programs(0), tl.num_programs(1)\n idx_N, idx_V_group = tl.swizzle2d(idx_N, idx_V_group, num_idx_N, num_idx_V_group, GROUP_SIZE)\n\n x_block_ptr = tl.make_block_ptr(\n base=x_ptr,\n shape=(N, H),\n strides=(stride_x_N, stride_x_H),\n offsets=(idx_N_group * N_group + idx_N * N_BLOCK_SIZE, 0),\n block_shape=(N_BLOCK_SIZE, H_BLOCK_SIZE),\n order=(1, 0),\n )\n A_block_ptr = tl.make_block_ptr(\n base=A_t_ptr,\n shape=(H, V),\n strides=(stride_A_H, stride_A_V),\n offsets=(0, idx_V_group * V_BLOCK_SIZE),\n block_shape=(H_BLOCK_SIZE, V_BLOCK_SIZE),\n order=(1, 0),\n )\n z_block_ptr = tl.make_block_ptr(\n base=z_nv_ptr,\n shape=(N_group, V),\n strides=(stride_z_N, stride_z_V),\n offsets=(idx_N * N_BLOCK_SIZE, idx_V_group * V_BLOCK_SIZE),\n block_shape=(N_BLOCK_SIZE, V_BLOCK_SIZE),\n order=(1, 0),\n )\n\n z_j_to_k = tl.zeros((N_BLOCK_SIZE, V_BLOCK_SIZE), dtype=tl.float32)\n for _ in range(H // H_BLOCK_SIZE):\n x_chunk = tl.load(x_block_ptr)\n A_v = tl.load(A_block_ptr)\n\n z_j_to_k = tl.dot(x_chunk, A_v, z_j_to_k)\n\n x_block_ptr = tl.advance(x_block_ptr, [0, H_BLOCK_SIZE])\n A_block_ptr = tl.advance(A_block_ptr, [H_BLOCK_SIZE, 0])\n\n m = tl.max(z_j_to_k, 1)\n s = tl.sum(tl.exp((z_j_to_k - m[:, None])), axis=1)\n\n N_range = idx_N_group * N_group + idx_N * N_BLOCK_SIZE + tl.arange(0, N_BLOCK_SIZE)\n V_range = idx_V_group * V_BLOCK_SIZE + tl.arange(0, V_BLOCK_SIZE)\n y = tl.load(y_ptr + N_range)\n\n mask = y[:, None] == V_range[None, :]\n loss = -tl.sum(tl.where(mask, z_j_to_k, float(0.0))) / N\n\n tl.store(z_block_ptr, z_j_to_k.to(z_nv_ptr.type.element_ty))\n lse = m + tl.log(s)\n\n lse_row_ptr = tl.make_block_ptr(\n base=lse_ptr,\n shape=(N_group, V // 128),\n strides=(stride_lse_N, stride_lse_B),\n offsets=(idx_N * N_BLOCK_SIZE, idx_V_group),\n block_shape=(N_BLOCK_SIZE, 1),\n order=(1, 0),\n )\n loss_val_ptr = losses_ptr + idx_N * stride_loss_Nb + idx_V_group * stride_loss_B\n tl.store(loss_val_ptr, tl.load(loss_val_ptr) + loss)\n tl.store(lse_row_ptr, lse[:, None])\n\n\n@triton.jit()\ndef linear_xent_bwd_kernel_matmul_t_epilogue_dx(\n z_nv_ptr,\n y_ptr,\n A_t_ptr,\n x_grad_ptr,\n lse_ptr,\n stride_x_N,\n stride_x_H,\n stride_A_H,\n stride_A_V,\n stride_z_N,\n stride_z_V,\n idx_N_group,\n N_group: tl.constexpr,\n V: tl.constexpr,\n N: tl.constexpr,\n H: tl.constexpr,\n V_BLOCK_SIZE: tl.constexpr,\n N_BLOCK_SIZE: tl.constexpr,\n H_BLOCK_SIZE: tl.constexpr,\n GROUP_SIZE: tl.constexpr,\n SPLIT_N: tl.constexpr,\n SPLIT_V: tl.constexpr,\n):\n idx_N = tl.program_id(axis=0) // SPLIT_V\n idx_H = tl.program_id(axis=1)\n idx_V_tile = tl.program_id(axis=0) % SPLIT_V\n\n num_idx_N, num_idx_H = tl.num_programs(0) - (triton.cdiv(V, V_BLOCK_SIZE) * SPLIT_N), tl.num_programs(1)\n idx_N, idx_H = tl.swizzle2d(idx_N, idx_H, num_idx_N // SPLIT_V, num_idx_H, GROUP_SIZE)\n\n V_split_offset = idx_V_tile * tl.cdiv(V, SPLIT_V)\n\n A_t_block_ptr = tl.make_block_ptr(\n base=A_t_ptr,\n shape=(H, V),\n strides=(stride_A_H, stride_A_V),\n offsets=(idx_H * H_BLOCK_SIZE, V_split_offset),\n block_shape=(H_BLOCK_SIZE, V_BLOCK_SIZE),\n order=(0, 1),\n )\n\n z_block_ptr = tl.make_block_ptr(\n base=z_nv_ptr,\n shape=(N_group, V),\n strides=(stride_z_N, stride_z_V),\n offsets=(idx_N * N_BLOCK_SIZE, V_split_offset),\n block_shape=(N_BLOCK_SIZE, V_BLOCK_SIZE),\n order=(1, 0),\n )\n N_range = idx_N_group * N_group + idx_N * N_BLOCK_SIZE + tl.arange(0, N_BLOCK_SIZE)\n v_range = V_split_offset + tl.arange(0, V_BLOCK_SIZE)\n\n y = tl.load(y_ptr + N_range, eviction_policy=\"evict_last\")\n lse = tl.load(lse_ptr + idx_N * N_BLOCK_SIZE + tl.arange(0, N_BLOCK_SIZE), eviction_policy=\"evict_last\")\n\n x_grad_acc = tl.zeros((N_BLOCK_SIZE, H_BLOCK_SIZE), x_grad_ptr.type.element_ty)\n for _ in range(0, tl.cdiv(V, V_BLOCK_SIZE * SPLIT_V)):\n mask = y[:, None] == v_range[None, :]\n A_v = tl.load(A_t_block_ptr, eviction_policy=\"evict_first\")\n z_j_to_k = tl.load(z_block_ptr, eviction_policy=\"evict_last\")\n softmax_z = (z_j_to_k - lse[:, None]).exp()\n z_grad = (softmax_z - tl.where(mask, 1.0, 0.0)).to(A_t_ptr.type.element_ty)\n\n x_grad_acc = tl.dot(z_grad, A_v.trans(), x_grad_acc, out_dtype=x_grad_ptr.type.element_ty)\n\n A_t_block_ptr = tl.advance(A_t_block_ptr, [0, V_BLOCK_SIZE])\n z_block_ptr = tl.advance(z_block_ptr, [0, V_BLOCK_SIZE])\n v_range += V_BLOCK_SIZE\n\n if SPLIT_V == 1:\n x_grad_block_ptr = tl.make_block_ptr(\n base=x_grad_ptr,\n shape=(N, H),\n strides=(stride_x_N, stride_x_H),\n offsets=(idx_N_group * N_group + idx_N * N_BLOCK_SIZE, idx_H * H_BLOCK_SIZE),\n block_shape=(N_BLOCK_SIZE, H_BLOCK_SIZE),\n order=(1, 0),\n )\n tl.store(x_grad_block_ptr, (x_grad_acc / N).to(x_grad_ptr.type.element_ty))\n else:\n row_n = idx_N_group * N_group + idx_N * N_BLOCK_SIZE + tl.arange(0, N_BLOCK_SIZE)\n row_h = idx_H * H_BLOCK_SIZE + tl.arange(0, H_BLOCK_SIZE)\n x_grad_simple_ptr = x_grad_ptr + row_n[:, None] * stride_x_N + row_h[None, :] * stride_x_H\n tl.atomic_add(x_grad_simple_ptr, (x_grad_acc / N).to(x_grad_ptr.type.element_ty))\n\n\n@triton.jit()\ndef linear_xent_bwd_kernel_matmul_t_epilogue_dA(\n z_nv_ptr,\n y_ptr,\n x_ptr,\n A_grad_ptr,\n lse_ptr,\n stride_x_N,\n stride_x_H,\n stride_A_H,\n stride_A_V,\n stride_z_N,\n stride_z_V,\n idx_N_group,\n N_group: tl.constexpr,\n V: tl.constexpr,\n N: tl.constexpr,\n H: tl.constexpr,\n V_BLOCK_SIZE: tl.constexpr,\n N_BLOCK_SIZE: tl.constexpr,\n H_BLOCK_SIZE: tl.constexpr,\n GROUP_SIZE: tl.constexpr,\n SPLIT_N: tl.constexpr,\n SPLIT_V: tl.constexpr,\n):\n idx_V = (tl.program_id(axis=0) - N_group // N_BLOCK_SIZE * SPLIT_V) // SPLIT_N\n idx_H = tl.program_id(axis=1)\n idx_N_tile = (tl.program_id(axis=0) - N_group // N_BLOCK_SIZE * SPLIT_V) % SPLIT_N\n\n num_idx_V, num_idx_H = tl.num_programs(0) - (N_group // N_BLOCK_SIZE * SPLIT_V), tl.num_programs(1)\n idx_V, idx_H = tl.swizzle2d(idx_V, idx_H, num_idx_V // SPLIT_N, num_idx_H, GROUP_SIZE)\n\n N_split_offset = idx_N_tile * tl.cdiv(N_group, SPLIT_N)\n\n x_block_ptr = tl.make_block_ptr(\n base=x_ptr,\n shape=(N, H),\n strides=(stride_x_N, stride_x_H),\n offsets=(idx_N_group * N_group + N_split_offset, idx_H * H_BLOCK_SIZE),\n block_shape=(N_BLOCK_SIZE, H_BLOCK_SIZE),\n order=(1, 0),\n )\n\n z_block_ptr = tl.make_block_ptr(\n base=z_nv_ptr,\n shape=(N_group, V),\n strides=(stride_z_N, stride_z_V),\n offsets=(N_split_offset, idx_V * V_BLOCK_SIZE),\n block_shape=(N_BLOCK_SIZE, V_BLOCK_SIZE),\n order=(1, 0),\n )\n\n N_range = N_split_offset + tl.arange(0, N_BLOCK_SIZE)\n V_range = idx_V * V_BLOCK_SIZE + tl.arange(0, V_BLOCK_SIZE)\n\n A_grad_acc = tl.zeros((H_BLOCK_SIZE, V_BLOCK_SIZE), A_grad_ptr.type.element_ty)\n for _ in range(0, tl.cdiv(N_group, N_BLOCK_SIZE * SPLIT_N)):\n y = tl.load(y_ptr + idx_N_group * N_group + N_range, eviction_policy=\"evict_last\")\n lse = tl.load(lse_ptr + N_range, eviction_policy=\"evict_last\")\n mask = y[:, None] == V_range[None, :]\n\n x_chunk = tl.load(x_block_ptr, eviction_policy=\"evict_first\")\n z_j_to_k = tl.load(z_block_ptr, eviction_policy=\"evict_last\")\n softmax_z = (z_j_to_k - lse[:, None]).exp()\n z_grad = (softmax_z - tl.where(mask, 1.0, 0.0)).to(x_ptr.type.element_ty)\n\n A_grad_acc = tl.dot(x_chunk.trans(), z_grad, A_grad_acc, out_dtype=A_grad_ptr.type.element_ty)\n\n x_block_ptr = tl.advance(x_block_ptr, [N_BLOCK_SIZE, 0])\n z_block_ptr = tl.advance(z_block_ptr, [N_BLOCK_SIZE, 0])\n N_range += N_BLOCK_SIZE\n\n if SPLIT_N == 1:\n A_grad_T_block_ptr = tl.make_block_ptr(\n base=A_grad_ptr,\n shape=(H, V),\n strides=(stride_A_H, stride_A_V),\n offsets=(idx_H * H_BLOCK_SIZE, idx_V * V_BLOCK_SIZE),\n block_shape=(H_BLOCK_SIZE, V_BLOCK_SIZE),\n order=(0, 1),\n )\n if idx_N_group > 0:\n tl.store(A_grad_T_block_ptr, tl.load(A_grad_T_block_ptr) + (A_grad_acc / N).to(A_grad_ptr.type.element_ty))\n else:\n tl.store(A_grad_T_block_ptr, (A_grad_acc / N).to(A_grad_ptr.type.element_ty))\n else:\n row_h = idx_H * H_BLOCK_SIZE + tl.arange(0, H_BLOCK_SIZE)\n row_v = idx_V * V_BLOCK_SIZE + tl.arange(0, V_BLOCK_SIZE)\n A_grad_T_simple_ptr = A_grad_ptr + row_h[:, None] * stride_A_H + row_v[None, :] * stride_A_V\n tl.atomic_add(A_grad_T_simple_ptr, (A_grad_acc / N).to(A_grad_ptr.type.element_ty))\n\n\n@triton.autotune(\n configs=[\n triton.Config(\n {\"V_BLOCK_SIZE\": 128, \"N_BLOCK_SIZE\": 128, \"H_BLOCK_SIZE\": 128, \"GROUP_SIZE\": 16, \"SPLIT_N\": 1, \"SPLIT_V\": 1},\n num_warps=8,\n num_stages=3,\n ),\n # Additional configs...\n ],\n key=[\"V\", \"N\", \"H\"],\n prune_configs_by={\n \"early_config_prune\": lambda configs, named_args: [\n config\n for config in configs\n if config.kwargs[\"H_BLOCK_SIZE\"] <= named_args[\"x_ptr\"].shape[1]\n ],\n },\n warmup=100,\n rep=500,\n)\n@triton.jit()\ndef linear_xent_bwd_dispatcher(\n logits_ptr,\n y_ptr,\n x_ptr,\n A_t_ptr,\n x_grad,\n At_grad,\n lse_global,\n stride_x_N,\n stride_x_H,\n stride_A_H,\n stride_A_V,\n stride_z_N,\n stride_z_V,\n idx_N_group,\n N_group,\n V: tl.constexpr,\n N: tl.constexpr,\n H: tl.constexpr,\n V_BLOCK_SIZE: tl.constexpr,\n N_BLOCK_SIZE: tl.constexpr,\n H_BLOCK_SIZE: tl.constexpr,\n GROUP_SIZE: tl.constexpr,\n SPLIT_N: tl.constexpr,\n SPLIT_V: tl.constexpr,\n):\n idx_NV = tl.program_id(axis=0)\n if idx_NV < (N_group // N_BLOCK_SIZE * SPLIT_V):\n linear_xent_bwd_kernel_matmul_t_epilogue_dx(\n logits_ptr,\n y_ptr,\n A_t_ptr,\n x_grad,\n lse_global,\n stride_x_N,\n stride_x_H,\n stride_A_H,\n stride_A_V,\n stride_z_N,\n stride_z_V,\n idx_N_group,\n N_group,\n V,\n N,\n H,\n V_BLOCK_SIZE,\n N_BLOCK_SIZE,\n H_BLOCK_SIZE,\n GROUP_SIZE,\n SPLIT_N,\n SPLIT_V,\n )\n else:\n linear_xent_bwd_kernel_matmul_t_epilogue_dA(\n logits_ptr,\n y_ptr,\n x_ptr,\n At_grad,\n lse_global,\n stride_x_N,\n stride_x_H,\n stride_A_H,\n stride_A_V,\n stride_z_N,\n stride_z_V,\n idx_N_group,\n N_group,\n V,\n N,\n H,\n V_BLOCK_SIZE,\n N_BLOCK_SIZE,\n H_BLOCK_SIZE,\n GROUP_SIZE,\n SPLIT_N,\n SPLIT_V,\n )\n\n\nclass LinearCrossEntropyLoss(torch.autograd.Function):\n @staticmethod\n def forward(\n ctx,\n x,\n y,\n At,\n ignore_index=-100,\n N_chunk_size: int = 4096,\n ):\n with torch.cuda.device(x.device.index):\n N, H = x.shape\n H_A, V = At.shape\n assert H_A == H\n assert y.shape == (N,)\n N_group = min(N, N_chunk_size)\n\n if ignore_index >= 0:\n y[y == ignore_index] = -100\n\n At_grad = torch.zeros_like(At)\n x_grad = torch.zeros_like(x)\n\n lse_sum = 0.0\n lse_local = -10e5 * torch.ones(N_group, V // 128, dtype=torch.float32, device=x.device)\n losses = torch.zeros(N_group // 64, V // 128, dtype=torch.float32, device=x.device)\n logits = torch.empty((N_group, V), device=x.device, dtype=torch.float32)\n\n with torch.inference_mode():\n\n fwd_grid = lambda meta: (\n triton.cdiv(N_group, meta[\"N_BLOCK_SIZE\"]),\n triton.cdiv(V, meta[\"V_BLOCK_SIZE\"]),\n )\n bwd_grid_dx_dA = lambda meta: (\n triton.cdiv(N_group, meta[\"N_BLOCK_SIZE\"]) * meta[\"SPLIT_V\"]\n + triton.cdiv(V, meta[\"V_BLOCK_SIZE\"]) * meta[\"SPLIT_N\"],\n triton.cdiv(H, meta[\"H_BLOCK_SIZE\"]),\n )\n\n for idx_N_group in range(math.ceil(N / N_group)):\n linear_xent_fwd_prep_bwd_kernel_matmul_t[fwd_grid](\n x,\n y,\n At,\n logits,\n losses,\n lse_local,\n x.stride(0),\n x.stride(1),\n At.stride(0),\n At.stride(1),\n logits.stride(0),\n logits.stride(1),\n lse_local.stride(0),\n lse_local.stride(1),\n losses.stride(0),\n losses.stride(1),\n idx_N_group=idx_N_group,\n N_group=N_group,\n V=V,\n N=N,\n H=H,\n )\n V_BLOCK_SIZE = linear_xent_fwd_prep_bwd_kernel_matmul_t.best_config.kwargs[\"V_BLOCK_SIZE\"]\n buffer_extent = V // V_BLOCK_SIZE\n lse_global = lse_local[:, :buffer_extent].logsumexp(dim=1)\n lse_sum += lse_global.sum() / N\n\n if x.requires_grad or At.requires_grad:\n linear_xent_bwd_dispatcher[bwd_grid_dx_dA](\n logits,\n y,\n x,\n At,\n x_grad,\n At_grad,\n lse_global,\n x_grad.stride(0),\n x_grad.stride(1),\n At.stride(0),\n At.stride(1),\n logits.stride(0),\n logits.stride(1),\n idx_N_group=idx_N_group,\n N_group=N_group,\n V=V,\n N=N,\n H=H,\n )\n\n ctx.mark_non_differentiable(y)\n ctx.save_for_backward(x_grad, At_grad.to(At.dtype))\n return (\n lse_sum + losses.sum(),\n linear_xent_fwd_prep_bwd_kernel_matmul_t.best_config,\n linear_xent_bwd_dispatcher.best_config,\n )\n\n @staticmethod\n @torch.inference_mode()\n def backward(ctx, grad_output, void, void2):\n x_grad, At_grad = ctx.saved_tensors\n\n return x_grad * grad_output, None, At_grad * grad_output, None, None\n\n\ndef linear_cross_entropy(x, y, At, ignore_index=-100, N_chunk_size: int = 4096):\n out, fwd_config, bwd_config = LinearCrossEntropyLoss.apply(x, y, At, ignore_index, N_chunk_size)\n linear_cross_entropy.chosen_fwd_configs.append(fwd_config)\n linear_cross_entropy.chosen_bwd_configs.append(bwd_config)\n return out\n\n\nlinear_cross_entropy.chosen_fwd_configs = []\nlinear_cross_entropy.chosen_bwd_configs = []\n", - "description_1": "Use triton language to implement a cross-entropy loss calculation with its forward and backward passes. The forward kernel 'linear_xent_fwd_prep_bwd_kernel_matmul_t' takes 22 parameters: pointers to input data, strides, constants, and block size configurations. It performs matrix multiplication followed by cross-entropy loss computation. The backward pass is handled by 'linear_xent_bwd_dispatcher', which calls separate epilogues for gradients of x and A, depending on whether the current thread is responsible for N or V dimension. Both forward and backward passes are designed to optimize memory access patterns using block pointers and swizzled indices.", - "description_2": "Use triton language to develop an optimized matrix multiplication-based cross-entropy function with autotuning to select optimal configurations, addressing both forward and backward passes while utilizing efficient memory strategies.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\nimport math\n\n@triton.autotune(\n configs=[\n triton.Config({\"V_BLOCK_SIZE\": 512, \"N_BLOCK_SIZE\": 64, \"H_BLOCK_SIZE\": 64, \"GROUP_SIZE\": 1}, num_warps=8, num_stages=2),\n triton.Config({\"V_BLOCK_SIZE\": 256, \"N_BLOCK_SIZE\": 64, \"H_BLOCK_SIZE\": 64, \"GROUP_SIZE\": 32}, num_warps=4, num_stages=3),\n triton.Config({\"V_BLOCK_SIZE\": 256, \"N_BLOCK_SIZE\": 64, \"H_BLOCK_SIZE\": 64, \"GROUP_SIZE\": 16}, num_warps=8, num_stages=3),\n triton.Config({\"V_BLOCK_SIZE\": 128, \"N_BLOCK_SIZE\": 64, \"H_BLOCK_SIZE\": 64, \"GROUP_SIZE\": 1}, num_warps=4, num_stages=2),\n triton.Config({\"V_BLOCK_SIZE\": 512, \"N_BLOCK_SIZE\": 64, \"H_BLOCK_SIZE\": 64, \"GROUP_SIZE\": 16}, num_warps=8, num_stages=2),\n triton.Config({\"V_BLOCK_SIZE\": 256, \"N_BLOCK_SIZE\": 64, \"H_BLOCK_SIZE\": 64, \"GROUP_SIZE\": 1}, num_warps=4, num_stages=2),\n triton.Config({\"V_BLOCK_SIZE\": 128, \"N_BLOCK_SIZE\": 64, \"H_BLOCK_SIZE\": 64, \"GROUP_SIZE\": 16}, num_warps=4, num_stages=3),\n triton.Config({\"V_BLOCK_SIZE\": 256, \"N_BLOCK_SIZE\": 64, \"H_BLOCK_SIZE\": 64, \"GROUP_SIZE\": 1}, num_warps=8, num_stages=3),\n triton.Config({\"V_BLOCK_SIZE\": 512, \"N_BLOCK_SIZE\": 64, \"H_BLOCK_SIZE\": 64, \"GROUP_SIZE\": 32}, num_warps=8, num_stages=2),\n triton.Config({\"V_BLOCK_SIZE\": 256, \"N_BLOCK_SIZE\": 128, \"H_BLOCK_SIZE\": 64, \"GROUP_SIZE\": 1}, num_warps=8, num_stages=2),\n triton.Config({\"V_BLOCK_SIZE\": 256, \"N_BLOCK_SIZE\": 128, \"H_BLOCK_SIZE\": 64, \"GROUP_SIZE\": 64}, num_warps=8, num_stages=2),\n triton.Config({\"V_BLOCK_SIZE\": 256, \"N_BLOCK_SIZE\": 128, \"H_BLOCK_SIZE\": 64, \"GROUP_SIZE\": 64}, num_warps=16, num_stages=2),\n triton.Config({\"V_BLOCK_SIZE\": 256, \"N_BLOCK_SIZE\": 128, \"H_BLOCK_SIZE\": 128, \"GROUP_SIZE\": 1}, num_warps=8, num_stages=2),\n ],\n key=[\"V\", \"N\", \"H\"],\n)\n@triton.jit\ndef linear_xent_fwd_kernel_matmul_t(\n x_ptr,\n y_ptr,\n A_t_ptr,\n z_nv_ptr,\n losses_ptr,\n lse_ptr,\n stride_x_N,\n stride_x_H,\n stride_A_H,\n stride_A_V,\n stride_z_N,\n stride_z_V,\n stride_lse_N,\n stride_lse_B,\n stride_loss_Nb,\n stride_loss_B,\n reduction_ptr,\n ignore_index: tl.constexpr,\n idx_N_group,\n N_group: tl.constexpr,\n V: tl.constexpr,\n N: tl.constexpr,\n H: tl.constexpr,\n V_BLOCK_SIZE: tl.constexpr,\n N_BLOCK_SIZE: tl.constexpr,\n H_BLOCK_SIZE: tl.constexpr,\n GROUP_SIZE: tl.constexpr,\n):\n idx_N = tl.program_id(axis=0)\n idx_V_group = tl.program_id(axis=1)\n num_idx_N, num_idx_V_group = tl.num_programs(0), tl.num_programs(1)\n idx_N, idx_V_group = tl.swizzle2d(idx_N, idx_V_group, num_idx_N, num_idx_V_group, GROUP_SIZE)\n \n V_GROUP_SIZE: tl.constexpr = V_BLOCK_SIZE\n x_block_ptr = tl.make_block_ptr(\n base=x_ptr,\n shape=(N, H),\n strides=(stride_x_N, stride_x_H),\n offsets=(idx_N_group * N_group + idx_N * N_BLOCK_SIZE, 0),\n block_shape=(N_BLOCK_SIZE, H_BLOCK_SIZE),\n order=(1, 0),\n )\n A_block_ptr = tl.make_block_ptr(\n base=A_t_ptr,\n shape=(H, V),\n strides=(stride_A_H, stride_A_V),\n offsets=(0, idx_V_group * V_GROUP_SIZE),\n block_shape=(H_BLOCK_SIZE, V_BLOCK_SIZE),\n order=(1, 0),\n )\n z_block_ptr = tl.make_block_ptr(\n base=z_nv_ptr,\n shape=(N_group, V),\n strides=(stride_z_N, stride_z_V),\n offsets=(idx_N * N_BLOCK_SIZE, idx_V_group * V_GROUP_SIZE),\n block_shape=(N_BLOCK_SIZE, V_BLOCK_SIZE),\n order=(1, 0),\n )\n\n z_j_to_k = tl.zeros((N_BLOCK_SIZE, V_BLOCK_SIZE), dtype=tl.float32)\n for _ in range(H // H_BLOCK_SIZE):\n x_chunk = tl.load(x_block_ptr) # Nc x H\n A_v = tl.load(A_block_ptr) # Vc x H\n\n z_j_to_k = tl.dot(x_chunk, A_v, z_j_to_k) # (Nc x H) @ (H x Vc)\n\n x_block_ptr = tl.advance(x_block_ptr, [0, H_BLOCK_SIZE])\n A_block_ptr = tl.advance(A_block_ptr, [H_BLOCK_SIZE, 0])\n\n V_range = idx_V_group * V_GROUP_SIZE + tl.arange(0, V_BLOCK_SIZE)\n N_range = idx_N_group * N_group + idx_N * N_BLOCK_SIZE + tl.arange(0, N_BLOCK_SIZE)\n y = tl.load(y_ptr + N_range)\n\n reduction = tl.load(reduction_ptr)\n mask = y[:, None] == tl.where(V_range != ignore_index, V_range, -1)[None, :] # Nc x Vc\n loss = -tl.sum(tl.where(mask, z_j_to_k, float(0.0))) / reduction\n\n tl.store(z_block_ptr, (z_j_to_k + tl.log(1 / reduction)).to(z_nv_ptr.type.element_ty))\n\n m = tl.max(z_j_to_k, 1)\n zero_lse_constant: tl.constexpr = tl.log(1 / tl.cdiv(V, V_BLOCK_SIZE))\n lse = tl.where(y != ignore_index, tl.log(tl.sum(tl.exp((z_j_to_k - m[:, None])), axis=1)) + m, zero_lse_constant)\n\n lse_row_ptr = tl.make_block_ptr(\n base=lse_ptr,\n shape=(N_group, V // 128),\n strides=(stride_lse_N, stride_lse_B),\n offsets=(idx_N * N_BLOCK_SIZE, idx_V_group),\n block_shape=(N_BLOCK_SIZE, 1),\n order=(1, 0),\n )\n loss_val_ptr = losses_ptr + idx_N * stride_loss_Nb + idx_V_group * stride_loss_B\n tl.store(loss_val_ptr, tl.load(loss_val_ptr) + loss)\n tl.store(lse_row_ptr, lse[:, None])\n\n\nclass LinearCrossEntropyLoss(torch.autograd.Function):\n fp32_grad_accumulators: bool = False\n\n @staticmethod\n def forward(\n ctx,\n x,\n y,\n At,\n ignore_index=-100,\n z_regularization=0.0,\n N_chunk_size: int = 4096,\n ):\n with torch.cuda.device(x.device.index):\n N, H = x.shape\n H_A, V = At.shape\n assert H_A == H\n assert y.shape == (N,)\n N_group = min(N, N_chunk_size)\n\n assert N % 64 == 0\n assert V % 128 == 0\n assert H % 64 == 0\n\n At_grad = torch.zeros_like(At)\n x_grad = torch.zeros_like(x)\n\n lse_sum = torch.zeros((1,), dtype=torch.float32, device=x.device)\n lse_local = -10e5 * torch.ones(N_group, V // 128, dtype=torch.float32, device=x.device)\n\n losses = torch.zeros(N_group // 64, V // 128, dtype=torch.float32, device=x.device)\n logits = torch.empty((N_group, V), device=x.device, dtype=torch.float32)\n\n with torch.inference_mode():\n reduction = (y != ignore_index).sum()\n if reduction == 0:\n ctx.mark_non_differentiable(y)\n ctx.save_for_backward(x_grad, At_grad.to(At.dtype))\n return losses.sum()\n\n fwd_grid = lambda meta: (\n triton.cdiv(N_group, meta[\"N_BLOCK_SIZE\"]),\n triton.cdiv(V, meta[\"V_BLOCK_SIZE\"]),\n )\n\n for idx_N_group in range(math.ceil(N / N_group)):\n linear_xent_fwd_kernel_matmul_t[fwd_grid](\n x,\n y,\n At,\n logits,\n losses,\n lse_local,\n x.stride(0),\n x.stride(1),\n At.stride(0),\n At.stride(1),\n logits.stride(0),\n logits.stride(1),\n lse_local.stride(0),\n lse_local.stride(1),\n losses.stride(0),\n losses.stride(1),\n reduction,\n ignore_index=ignore_index,\n idx_N_group=idx_N_group,\n N_group=N_group,\n V=V,\n N=N,\n H=H,\n )\n V_BLOCK_SIZE = linear_xent_fwd_kernel_matmul_t.best_config.kwargs[\"V_BLOCK_SIZE\"]\n buffer_extent = V // V_BLOCK_SIZE\n lse_global = lse_local[:, :buffer_extent].logsumexp(dim=1)\n lse_sum += (lse_global.sum() + z_regularization * lse_global.pow(2).sum()) / reduction\n\n ctx.mark_non_differentiable(y)\n ctx.save_for_backward(x_grad, At_grad.to(At.dtype))\n return lse_sum + losses.sum()\n\ndef linear_cross_entropy(x, y, At, ignore_index=-100, z_regularization=0.0, N_chunk_size: int = 4096):\n return LinearCrossEntropyLoss.apply(x, y, At, ignore_index, z_regularization, N_chunk_size)\n", - "description_1": "Use triton language to implement a kernel for forward propagation of linear cross-entropy loss. The kernel computes the logits by matrix multiplication of input and weights, applies softmax cross entropy with reduction and ignore index, and stores intermediate results for backward pass. The LinearCrossEntropyLoss function orchestrates calling the triton kernel with appropriate parameters.", - "description_2": "Use triton language to implement a forward linear cross-entropy loss computation using matrix multiplication and softmax, storing necessary outputs for backward computation. Integrate this with PyTorch's autograd functionality.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\nimport math\n\n@triton.autotune(\n configs=[\n triton.Config({\"V_BLOCK_SIZE\": 512, \"N_BLOCK_SIZE\": 64, \"H_BLOCK_SIZE\": 64, \"GROUP_SIZE\": 1}, num_warps=8, num_stages=2),\n triton.Config({\"V_BLOCK_SIZE\": 256, \"N_BLOCK_SIZE\": 64, \"H_BLOCK_SIZE\": 64, \"GROUP_SIZE\": 32}, num_warps=4, num_stages=3),\n triton.Config({\"V_BLOCK_SIZE\": 256, \"N_BLOCK_SIZE\": 64, \"H_BLOCK_SIZE\": 64, \"GROUP_SIZE\": 16}, num_warps=8, num_stages=3),\n triton.Config({\"V_BLOCK_SIZE\": 128, \"N_BLOCK_SIZE\": 64, \"H_BLOCK_SIZE\": 64, \"GROUP_SIZE\": 1}, num_warps=4, num_stages=2),\n triton.Config({\"V_BLOCK_SIZE\": 512, \"N_BLOCK_SIZE\": 64, \"H_BLOCK_SIZE\": 64, \"GROUP_SIZE\": 16}, num_warps=8, num_stages=2),\n triton.Config({\"V_BLOCK_SIZE\": 256, \"N_BLOCK_SIZE\": 64, \"H_BLOCK_SIZE\": 64, \"GROUP_SIZE\": 1}, num_warps=4, num_stages=2),\n triton.Config({\"V_BLOCK_SIZE\": 128, \"N_BLOCK_SIZE\": 64, \"H_BLOCK_SIZE\": 64, \"GROUP_SIZE\": 16}, num_warps=4, num_stages=3),\n triton.Config({\"V_BLOCK_SIZE\": 256, \"N_BLOCK_SIZE\": 64, \"H_BLOCK_SIZE\": 64, \"GROUP_SIZE\": 1}, num_warps=8, num_stages=3),\n triton.Config({\"V_BLOCK_SIZE\": 512, \"N_BLOCK_SIZE\": 64, \"H_BLOCK_SIZE\": 64, \"GROUP_SIZE\": 32}, num_warps=8, num_stages=2),\n triton.Config({\"V_BLOCK_SIZE\": 256, \"N_BLOCK_SIZE\": 128, \"H_BLOCK_SIZE\": 64, \"GROUP_SIZE\": 1}, num_warps=8, num_stages=2),\n triton.Config({\"V_BLOCK_SIZE\": 256, \"N_BLOCK_SIZE\": 128, \"H_BLOCK_SIZE\": 64, \"GROUP_SIZE\": 64}, num_warps=8, num_stages=2),\n triton.Config({\"V_BLOCK_SIZE\": 256, \"N_BLOCK_SIZE\": 128, \"H_BLOCK_SIZE\": 64, \"GROUP_SIZE\": 64}, num_warps=16, num_stages=2),\n triton.Config({\"V_BLOCK_SIZE\": 256, \"N_BLOCK_SIZE\": 128, \"H_BLOCK_SIZE\": 128, \"GROUP_SIZE\": 1}, num_warps=8, num_stages=2),\n ],\n key=[\"V\", \"N\", \"H\"],\n prune_configs_by={\n \"early_config_prune\": early_config_prune,\n },\n warmup=100,\n rep=500,\n)\n@triton.jit\ndef linear_xent_fwd_kernel_matmul_t(\n x_ptr,\n y_ptr,\n A_t_ptr,\n z_nv_ptr,\n losses_ptr,\n lse_ptr,\n stride_x_N,\n stride_x_H,\n stride_A_H,\n stride_A_V,\n stride_z_N,\n stride_z_V,\n stride_lse_N,\n stride_lse_B,\n stride_loss_Nb,\n stride_loss_B,\n reduction_ptr,\n ignore_index: tl.constexpr,\n idx_N_group,\n N_group: tl.constexpr,\n V: tl.constexpr,\n N: tl.constexpr,\n H: tl.constexpr,\n V_BLOCK_SIZE: tl.constexpr,\n N_BLOCK_SIZE: tl.constexpr,\n H_BLOCK_SIZE: tl.constexpr,\n GROUP_SIZE: tl.constexpr,\n):\n idx_N = tl.program_id(axis=0)\n idx_V_group = tl.program_id(axis=1)\n num_idx_N, num_idx_V_group = tl.num_programs(0), tl.num_programs(1)\n idx_N, idx_V_group = tl.swizzle2d(idx_N, idx_V_group, num_idx_N, num_idx_V_group, GROUP_SIZE)\n\n V_GROUP_SIZE: tl.constexpr = V_BLOCK_SIZE\n x_block_ptr = tl.make_block_ptr(\n base=x_ptr,\n shape=(N, H),\n strides=(stride_x_N, stride_x_H),\n offsets=(idx_N_group * N_group + idx_N * N_BLOCK_SIZE, 0),\n block_shape=(N_BLOCK_SIZE, H_BLOCK_SIZE),\n order=(1, 0),\n )\n A_block_ptr = tl.make_block_ptr(\n base=A_t_ptr,\n shape=(H, V),\n strides=(stride_A_H, stride_A_V),\n offsets=(0, idx_V_group * V_GROUP_SIZE),\n block_shape=(H_BLOCK_SIZE, V_BLOCK_SIZE),\n order=(1, 0),\n )\n z_block_ptr = tl.make_block_ptr(\n base=z_nv_ptr,\n shape=(N_group, V),\n strides=(stride_z_N, stride_z_V),\n offsets=(idx_N * N_BLOCK_SIZE, idx_V_group * V_GROUP_SIZE),\n block_shape=(N_BLOCK_SIZE, V_BLOCK_SIZE),\n order=(1, 0),\n )\n\n z_j_to_k = tl.zeros((N_BLOCK_SIZE, V_BLOCK_SIZE), dtype=tl.float32)\n for _ in range(H // H_BLOCK_SIZE):\n x_chunk = tl.load(x_block_ptr)\n A_v = tl.load(A_block_ptr)\n\n z_j_to_k = tl.dot(x_chunk, A_v, z_j_to_k)\n\n x_block_ptr = tl.advance(x_block_ptr, [0, H_BLOCK_SIZE])\n A_block_ptr = tl.advance(A_block_ptr, [H_BLOCK_SIZE, 0])\n\n y = tl.load(y_ptr + idx_N_group * N_group + idx_N * N_BLOCK_SIZE + tl.arange(0, N_BLOCK_SIZE))\n V_range = idx_V_group * V_GROUP_SIZE + tl.arange(0, V_BLOCK_SIZE)\n\n mask = y[:, None] == tl.where(V_range != ignore_index, V_range, -1)[None, :]\n loss = -tl.sum(tl.where(mask, z_j_to_k, float(0.0))) / N\n\n tl.store(z_block_ptr, (z_j_to_k + tl.log(1 / N)).to(z_nv_ptr.type.element_ty))\n\n m = tl.max(z_j_to_k, 1)\n zero_lse_constant: tl.constexpr = tl.log(1 / tl.cdiv(V, V_BLOCK_SIZE))\n lse = tl.where(y != ignore_index, tl.log(tl.sum(tl.exp((z_j_to_k - m[:, None])), axis=1)) + m, zero_lse_constant)\n\n lse_row_ptr = tl.make_block_ptr(\n base=lse_ptr,\n shape=(N_group, V // 128),\n strides=(stride_lse_N, stride_lse_B),\n offsets=(idx_N * N_BLOCK_SIZE, idx_V_group),\n block_shape=(N_BLOCK_SIZE, 1),\n order=(1, 0),\n )\n loss_val_ptr = losses_ptr + idx_N * stride_loss_Nb + idx_V_group * stride_loss_B\n tl.store(loss_val_ptr, tl.load(loss_val_ptr) + loss)\n tl.store(lse_row_ptr, lse[:, None])\n\n@triton.jit()\ndef linear_xent_bwd_kernel_matmul_t_epilogue_dx(\n z_nv_ptr,\n y_ptr,\n A_t_ptr,\n x_grad_ptr,\n lse_ptr,\n stride_x_N,\n stride_x_H,\n stride_A_H,\n stride_A_V,\n stride_z_N,\n stride_z_V,\n z_regularization: tl.constexpr,\n fp32_grad_accumulators: tl.constexpr,\n reduction_ptr,\n ignore_index: tl.constexpr,\n idx_N_group,\n N_group: tl.constexpr,\n V: tl.constexpr,\n N: tl.constexpr,\n H: tl.constexpr,\n V_BLOCK_SIZE: tl.constexpr,\n N_BLOCK_SIZE: tl.constexpr,\n H_BLOCK_SIZE: tl.constexpr,\n GROUP_SIZE: tl.constexpr,\n SPLIT_N: tl.constexpr,\n SPLIT_V: tl.constexpr,\n):\n idx_N = tl.program_id(axis=0) // SPLIT_V\n idx_H = tl.program_id(axis=1)\n idx_V_tile = tl.program_id(axis=0) % SPLIT_V\n\n num_idx_N, num_idx_H = tl.num_programs(0) - (triton.cdiv(V, V_BLOCK_SIZE) * SPLIT_N), tl.num_programs(1)\n idx_N, idx_H = tl.swizzle2d(idx_N, idx_H, num_idx_N // SPLIT_V, num_idx_H, GROUP_SIZE)\n\n V_split_offset = idx_V_tile * tl.cdiv(V, SPLIT_V)\n\n A_t_block_ptr = tl.make_block_ptr(\n base=A_t_ptr,\n shape=(H, V),\n strides=(stride_A_H, stride_A_V),\n offsets=(idx_H * H_BLOCK_SIZE, V_split_offset),\n block_shape=(H_BLOCK_SIZE, V_BLOCK_SIZE),\n order=(0, 1),\n )\n\n z_block_ptr = tl.make_block_ptr(\n base=z_nv_ptr,\n shape=(N_group, V),\n strides=(stride_z_N, stride_z_V),\n offsets=(idx_N * N_BLOCK_SIZE, V_split_offset),\n block_shape=(N_BLOCK_SIZE, V_BLOCK_SIZE),\n order=(1, 0),\n )\n N_range = idx_N_group * N_group + idx_N * N_BLOCK_SIZE + tl.arange(0, N_BLOCK_SIZE)\n V_range = V_split_offset + tl.arange(0, V_BLOCK_SIZE)\n\n y = tl.load(y_ptr + N_range, eviction_policy=\"evict_last\")\n lse = tl.load(lse_ptr + idx_N * N_BLOCK_SIZE + tl.arange(0, N_BLOCK_SIZE), eviction_policy=\"evict_last\")\n\n acc_dtype = tl.float32 if fp32_grad_accumulators else x_grad_ptr.type.element_ty\n x_grad_acc = tl.zeros((N_BLOCK_SIZE, H_BLOCK_SIZE), acc_dtype)\n for _ in range(0, tl.cdiv(V, V_BLOCK_SIZE * SPLIT_V)):\n mask = y[:, None] == V_range[None, :]\n A_v = tl.load(A_t_block_ptr, eviction_policy=\"evict_first\")\n z_j_to_k = tl.load(z_block_ptr, eviction_policy=\"evict_last\")\n softmax_z = (z_j_to_k - lse[:, None]).exp()\n\n z_grad = softmax_z - tl.where(mask, 1 / N, 0.0)\n valid_z_grad = tl.where((y == ignore_index)[:, None], 0.0, z_grad).to(A_v.type.element_ty)\n\n x_grad_acc = tl.dot(valid_z_grad, A_v.trans(), x_grad_acc, out_dtype=acc_dtype)\n\n A_t_block_ptr = tl.advance(A_t_block_ptr, [0, V_BLOCK_SIZE])\n z_block_ptr = tl.advance(z_block_ptr, [0, V_BLOCK_SIZE])\n V_range += V_BLOCK_SIZE\n\n if SPLIT_V == 1:\n x_grad_block_ptr = tl.make_block_ptr(\n base=x_grad_ptr,\n shape=(N, H),\n strides=(stride_x_N, stride_x_H),\n offsets=(idx_N_group * N_group + idx_N * N_BLOCK_SIZE, idx_H * H_BLOCK_SIZE),\n block_shape=(N_BLOCK_SIZE, H_BLOCK_SIZE),\n order=(1, 0),\n )\n tl.store(x_grad_block_ptr, x_grad_acc.to(x_grad_ptr.type.element_ty))\n else:\n row_n = idx_N_group * N_group + idx_N * N_BLOCK_SIZE + tl.arange(0, N_BLOCK_SIZE)\n row_h = idx_H * H_BLOCK_SIZE + tl.arange(0, H_BLOCK_SIZE)\n x_grad_simple_ptr = x_grad_ptr + row_n[:, None] * stride_x_N + row_h[None, :] * stride_x_H\n tl.atomic_add(x_grad_simple_ptr, x_grad_acc.to(x_grad_ptr.type.element_ty))\n\n@triton.jit()\ndef linear_xent_bwd_kernel_matmul_t_epilogue_dA(\n z_nv_ptr,\n y_ptr,\n x_ptr,\n A_grad_ptr,\n lse_ptr,\n stride_x_N,\n stride_x_H,\n stride_A_H,\n stride_A_V,\n stride_z_N,\n stride_z_V,\n z_regularization: tl.constexpr,\n fp32_grad_accumulators: tl.constexpr,\n reduction_ptr,\n ignore_index: tl.constexpr,\n idx_N_group,\n N_group: tl.constexpr,\n V: tl.constexpr,\n N: tl.constexpr,\n H: tl.constexpr,\n V_BLOCK_SIZE: tl.constexpr,\n N_BLOCK_SIZE: tl.constexpr,\n H_BLOCK_SIZE: tl.constexpr,\n GROUP_SIZE: tl.constexpr,\n SPLIT_N: tl.constexpr,\n SPLIT_V: tl.constexpr,\n):\n idx_V = (tl.program_id(axis=0) - N_group // N_BLOCK_SIZE * SPLIT_V) // SPLIT_N\n idx_H = tl.program_id(axis=1)\n idx_N_tile = (tl.program_id(axis=0) - N_group // N_BLOCK_SIZE * SPLIT_V) % SPLIT_N\n\n num_idx_V, num_idx_H = tl.num_programs(0) - (N_group // N_BLOCK_SIZE * SPLIT_V), tl.num_programs(1)\n idx_V, idx_H = tl.swizzle2d(idx_V, idx_H, num_idx_V // SPLIT_N, num_idx_H, GROUP_SIZE)\n\n N_split_offset = idx_N_tile * tl.cdiv(N_group, SPLIT_N)\n\n x_block_ptr = tl.make_block_ptr(\n base=x_ptr,\n shape=(N, H),\n strides=(stride_x_N, stride_x_H),\n offsets=(idx_N_group * N_group + N_split_offset, idx_H * H_BLOCK_SIZE),\n block_shape=(N_BLOCK_SIZE, H_BLOCK_SIZE),\n order=(1, 0),\n )\n\n z_block_ptr = tl.make_block_ptr(\n base=z_nv_ptr,\n shape=(N_group, V),\n strides=(stride_z_N, stride_z_V),\n offsets=(N_split_offset, idx_V * V_BLOCK_SIZE),\n block_shape=(N_BLOCK_SIZE, V_BLOCK_SIZE),\n order=(1, 0),\n )\n\n N_range = N_split_offset + tl.arange(0, N_BLOCK_SIZE)\n V_range = idx_V * V_BLOCK_SIZE + tl.arange(0, V_BLOCK_SIZE)\n\n acc_dtype = tl.float32 if fp32_grad_accumulators else A_grad_ptr.type.element_ty\n A_grad_acc = tl.zeros((H_BLOCK_SIZE, V_BLOCK_SIZE), acc_dtype)\n for _ in range(0, tl.cdiv(N_group, N_BLOCK_SIZE * SPLIT_N)):\n y = tl.load(y_ptr + idx_N_group * N_group + N_range, eviction_policy=\"evict_last\")\n lse = tl.load(lse_ptr + N_range, eviction_policy=\"evict_last\")\n mask = y[:, None] == V_range[None, :]\n\n x_chunk = tl.load(x_block_ptr, eviction_policy=\"evict_first\")\n z_j_to_k = tl.load(z_block_ptr, eviction_policy=\"evict_last\")\n softmax_z = (z_j_to_k - lse[:, None]).exp()\n\n z_grad = softmax_z - tl.where(mask, 1 / N, 0)\n valid_z_grad = tl.where((y == ignore_index)[:, None], 0.0, z_grad).to(x_ptr.type.element_ty)\n\n A_grad_acc = tl.dot(x_chunk.trans(), valid_z_grad, A_grad_acc, out_dtype=acc_dtype)\n\n x_block_ptr = tl.advance(x_block_ptr, [N_BLOCK_SIZE, 0])\n z_block_ptr = tl.advance(z_block_ptr, [N_BLOCK_SIZE, 0])\n N_range += N_BLOCK_SIZE\n\n if SPLIT_N == 1:\n A_grad_T_block_ptr = tl.make_block_ptr(\n base=A_grad_ptr,\n shape=(H, V),\n strides=(stride_A_H, stride_A_V),\n offsets=(idx_H * H_BLOCK_SIZE, idx_V * V_BLOCK_SIZE),\n block_shape=(H_BLOCK_SIZE, V_BLOCK_SIZE),\n order=(0, 1),\n )\n if idx_N_group > 0:\n tl.store(\n A_grad_T_block_ptr,\n tl.load(A_grad_T_block_ptr) + A_grad_acc.to(A_grad_ptr.type.element_ty),\n )\n else:\n tl.store(A_grad_T_block_ptr, A_grad_acc.to(A_grad_ptr.type.element_ty))\n else:\n row_h = idx_H * H_BLOCK_SIZE + tl.arange(0, H_BLOCK_SIZE)\n row_v = idx_V * V_BLOCK_SIZE + tl.arange(0, V_BLOCK_SIZE)\n A_grad_T_simple_ptr = A_grad_ptr + row_h[:, None] * stride_A_H + row_v[None, :] * stride_A_V\n tl.atomic_add(A_grad_T_simple_ptr, A_grad_acc.to(A_grad_ptr.type.element_ty))\n\n\n@triton.autotune(\n configs=[\n triton.Config(\n {\"V_BLOCK_SIZE\": 128, \"N_BLOCK_SIZE\": 128, \"H_BLOCK_SIZE\": 128, \"GROUP_SIZE\": 16, \"SPLIT_N\": 1, \"SPLIT_V\": 1},\n num_warps=8,\n num_stages=3,\n ),\n triton.Config(\n {\"V_BLOCK_SIZE\": 128, \"N_BLOCK_SIZE\": 128, \"H_BLOCK_SIZE\": 256, \"GROUP_SIZE\": 32, \"SPLIT_N\": 1, \"SPLIT_V\": 4},\n num_warps=8,\n num_stages=2,\n ),\n triton.Config(\n {\"V_BLOCK_SIZE\": 128, \"N_BLOCK_SIZE\": 128, \"H_BLOCK_SIZE\": 256, \"GROUP_SIZE\": 16, \"SPLIT_N\": 1, \"SPLIT_V\": 1},\n num_warps=8,\n num_stages=2,\n ),\n triton.Config(\n {\"V_BLOCK_SIZE\": 128, \"N_BLOCK_SIZE\": 128, \"H_BLOCK_SIZE\": 128, \"GROUP_SIZE\": 32, \"SPLIT_N\": 1, \"SPLIT_V\": 1},\n num_warps=8,\n num_stages=3,\n ),\n triton.Config(\n {\"V_BLOCK_SIZE\": 128, \"N_BLOCK_SIZE\": 128, \"H_BLOCK_SIZE\": 256, \"GROUP_SIZE\": 32, \"SPLIT_N\": 1, \"SPLIT_V\": 1},\n num_warps=8,\n num_stages=2,\n ),\n triton.Config(\n {\"V_BLOCK_SIZE\": 128, \"N_BLOCK_SIZE\": 128, \"H_BLOCK_SIZE\": 128, \"GROUP_SIZE\": 16, \"SPLIT_N\": 1, \"SPLIT_V\": 1},\n num_warps=8,\n num_stages=2,\n ),\n triton.Config(\n {\"V_BLOCK_SIZE\": 128, \"N_BLOCK_SIZE\": 128, \"H_BLOCK_SIZE\": 256, \"GROUP_SIZE\": 16, \"SPLIT_N\": 1, \"SPLIT_V\": 2},\n num_warps=8,\n num_stages=2,\n ),\n triton.Config(\n {\"V_BLOCK_SIZE\": 128, \"N_BLOCK_SIZE\": 128, \"H_BLOCK_SIZE\": 256, \"GROUP_SIZE\": 16, \"SPLIT_N\": 1, \"SPLIT_V\": 4},\n num_warps=8,\n num_stages=2,\n ),\n triton.Config(\n {\"V_BLOCK_SIZE\": 128, \"N_BLOCK_SIZE\": 128, \"H_BLOCK_SIZE\": 256, \"GROUP_SIZE\": 64, \"SPLIT_N\": 1, \"SPLIT_V\": 4},\n num_warps=16,\n num_stages=2,\n ),\n triton.Config(\n {\"V_BLOCK_SIZE\": 128, \"N_BLOCK_SIZE\": 128, \"H_BLOCK_SIZE\": 256, \"GROUP_SIZE\": 64, \"SPLIT_N\": 1, \"SPLIT_V\": 8},\n num_warps=16,\n num_stages=2,\n ),\n triton.Config(\n {\"V_BLOCK_SIZE\": 128, \"N_BLOCK_SIZE\": 128, \"H_BLOCK_SIZE\": 256, \"GROUP_SIZE\": 64, \"SPLIT_N\": 1, \"SPLIT_V\": 8},\n num_warps=8,\n num_stages=2,\n ),\n triton.Config(\n {\"V_BLOCK_SIZE\": 128, \"N_BLOCK_SIZE\": 128, \"H_BLOCK_SIZE\": 256, \"GROUP_SIZE\": 64, \"SPLIT_N\": 1, \"SPLIT_V\": 4},\n num_warps=8,\n num_stages=2,\n ),\n ],\n key=[\"V\", \"N\", \"H\"],\n prune_configs_by={\n \"early_config_prune\": early_config_prune,\n },\n warmup=100,\n rep=500,\n)\n@triton.jit()\ndef linear_xent_bwd_dispatcher(\n logits_ptr,\n y_ptr,\n x_ptr,\n A_t_ptr,\n x_grad,\n At_grad,\n lse_global,\n stride_x_N,\n stride_x_H,\n stride_A_H,\n stride_A_V,\n stride_z_N,\n stride_z_V,\n z_regularization: tl.constexpr,\n fp32_grad_accumulators: tl.constexpr,\n reduction_ptr,\n ignore_index: tl.constexpr,\n idx_N_group,\n N_group,\n V: tl.constexpr,\n N: tl.constexpr,\n H: tl.constexpr,\n V_BLOCK_SIZE: tl.constexpr = 128,\n N_BLOCK_SIZE: tl.constexpr = 128,\n H_BLOCK_SIZE: tl.constexpr = 128,\n GROUP_SIZE: tl.constexpr = 32,\n SPLIT_N: tl.constexpr = 2,\n SPLIT_V: tl.constexpr = 2,\n):\n idx_NV = tl.program_id(axis=0)\n if idx_NV < (N_group // N_BLOCK_SIZE * SPLIT_V):\n linear_xent_bwd_kernel_matmul_t_epilogue_dx(\n logits_ptr,\n y_ptr,\n A_t_ptr,\n x_grad,\n lse_global,\n stride_x_N,\n stride_x_H,\n stride_A_H,\n stride_A_V,\n stride_z_N,\n stride_z_V,\n z_regularization,\n fp32_grad_accumulators,\n reduction_ptr,\n ignore_index,\n idx_N_group,\n N_group,\n V,\n N,\n H,\n V_BLOCK_SIZE,\n N_BLOCK_SIZE,\n H_BLOCK_SIZE,\n GROUP_SIZE,\n SPLIT_N,\n SPLIT_V,\n )\n else:\n linear_xent_bwd_kernel_matmul_t_epilogue_dA(\n logits_ptr,\n y_ptr,\n x_ptr,\n At_grad,\n lse_global,\n stride_x_N,\n stride_x_H,\n stride_A_H,\n stride_A_V,\n stride_z_N,\n stride_z_V,\n z_regularization,\n fp32_grad_accumulators,\n reduction_ptr,\n ignore_index,\n idx_N_group,\n N_group,\n V,\n N,\n H,\n V_BLOCK_SIZE,\n N_BLOCK_SIZE,\n H_BLOCK_SIZE,\n GROUP_SIZE,\n SPLIT_N,\n SPLIT_V,\n )\n\n\nclass LinearCrossEntropyLoss(torch.autograd.Function):\n fp32_grad_accumulators: bool = False\n\n @staticmethod\n def forward(\n ctx,\n x,\n y,\n At,\n ignore_index=-100,\n z_regularization=0.0,\n N_chunk_size: int = 4096,\n ):\n with torch.cuda.device(x.device.index):\n N, H = x.shape\n H_A, V = At.shape\n assert H_A == H\n assert y.shape == (N,)\n N_group = min(N, N_chunk_size)\n\n assert N % 64 == 0\n assert V % 128 == 0\n assert H % 64 == 0\n\n At_grad = torch.zeros_like(At)\n x_grad = torch.zeros_like(x)\n\n lse_sum = torch.zeros((1,), dtype=torch.float32, device=x.device)\n lse_local = -10e5 * torch.ones(N_group, V // 128, dtype=torch.float32, device=x.device)\n\n losses = torch.zeros(N_group // 64, V // 128, dtype=torch.float32, device=x.device)\n logits = torch.empty((N_group, V), device=x.device, dtype=torch.float32)\n\n with torch.inference_mode():\n reduction = N\n\n fwd_grid = lambda meta: (\n triton.cdiv(N_group, meta[\"N_BLOCK_SIZE\"]),\n triton.cdiv(V, meta[\"V_BLOCK_SIZE\"]),\n )\n\n bwd_grid_dx_dA = lambda meta: (\n triton.cdiv(N_group, meta[\"N_BLOCK_SIZE\"]) * meta[\"SPLIT_V\"]\n + triton.cdiv(V, meta[\"V_BLOCK_SIZE\"]) * meta[\"SPLIT_N\"],\n triton.cdiv(H, meta[\"H_BLOCK_SIZE\"]),\n )\n\n for idx_N_group in range(math.ceil(N / N_group)):\n linear_xent_fwd_kernel_matmul_t[fwd_grid](\n x,\n y,\n At,\n logits,\n losses,\n lse_local,\n x.stride(0),\n x.stride(1),\n At.stride(0),\n At.stride(1),\n logits.stride(0),\n logits.stride(1),\n lse_local.stride(0),\n lse_local.stride(1),\n losses.stride(0),\n losses.stride(1),\n reduction,\n ignore_index=ignore_index,\n idx_N_group=idx_N_group,\n N_group=N_group,\n V=V,\n N=N,\n H=H,\n )\n V_BLOCK_SIZE = linear_xent_fwd_kernel_matmul_t.best_config.kwargs[\"V_BLOCK_SIZE\"]\n\n buffer_extent = V // V_BLOCK_SIZE\n lse_global = lse_local[:, :buffer_extent].logsumexp(dim=1)\n lse_sum += lse_global.sum() / reduction\n\n if x.requires_grad or At.requires_grad:\n linear_xent_bwd_dispatcher[bwd_grid_dx_dA](\n logits,\n y,\n x,\n At,\n x_grad,\n At_grad,\n lse_global,\n x_grad.stride(0),\n x_grad.stride(1),\n At.stride(0),\n At.stride(1),\n logits.stride(0),\n logits.stride(1),\n z_regularization,\n LinearCrossEntropyLoss.fp32_grad_accumulators,\n reduction,\n ignore_index=ignore_index,\n idx_N_group=idx_N_group,\n N_group=N_group,\n V=V,\n N=N,\n H=H,\n )\n\n ctx.mark_non_differentiable(y)\n ctx.save_for_backward(x_grad, At_grad.to(At.dtype))\n return lse_sum + losses.sum()\n\n @staticmethod\n @torch.inference_mode()\n def backward(ctx, grad_output):\n x_grad, At_grad = ctx.saved_tensors\n\n return x_grad * grad_output, None, At_grad * grad_output, None, None, None\n\n\ndef linear_cross_entropy(x, y, At, ignore_index=-100, z_regularization=0.0, N_chunk_size: int = 4096):\n return LinearCrossEntropyLoss.apply(x, y, At, ignore_index, z_regularization, N_chunk_size)\n\n\nif __name__ == \"__main__\":\n f = 1\n V, N, H = 32768 * f, 4096 * f, 1024 * f\n\n compute_dtype = torch.float16\n\n y = torch.randint(0, V, (N,), device=device)\n A = torch.randn(V, H, requires_grad=True, device=device, dtype=compute_dtype)\n At = A.clone().detach().T.contiguous()\n At.requires_grad_()\n\n x = (0.1 * A[y].clone().detach() + torch.randn(N, H, device=device, dtype=compute_dtype)) * 1\n x.requires_grad_()\n z_reg = 0.0\n\n A_ref = A.clone().detach()\n\n loss = baseline_torch(x.float(), y, A.float(), ignore_index=5, z_regularization=z_reg)\n loss.backward()\n\n reference_A_grad = A.grad.float().clone()\n reference_x_grad = x.grad.float().clone()\n reference_loss = loss.detach().float().clone()\n\n z_ref = F.linear(x, A).view(-1, V).float().detach()\n m_ref = z_ref.max(dim=1)[0]\n s_ref = (z_ref - m_ref[:, None]).exp().sum(dim=1)\n\n print(reference_loss)\n\n simple_bench(\n lambda: linear_cross_entropy(x, y, At, ignore_index=5, z_regularization=z_reg),\n reference_loss,\n reference_x_grad,\n reference_A_grad,\n )\n\n simple_bench(lambda: torch.compile(baseline_torch)(x, y, A), reference_loss, reference_x_grad, reference_A_grad)\n", - "description_1": "Use triton language to implement a linear cross-entropy loss with a forward kernel and a backward kernel. The forward kernel (linear_xent_fwd_kernel_matmul_t) takes 24 parameters including pointers to input, target, transposed weight, and several stride values. The backward dispatcher (linear_xent_bwd_dispatcher) also takes 24 parameters to manage the computation of gradients with respect to inputs and weights. It manages both backward kernels for input and weights separately.", - "description_2": "Use triton language to implement and autotune linear cross-entropy loss calculation with both forward and backward kernels, handling the computation of loss and gradients efficiently on a GPU.", - "difficulty": 5 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.autotune(\n configs=[\n triton.Config({\"V_BLOCK_SIZE\": 64, \"N_BLOCK_SIZE\": 16, \"H_BLOCK_SIZE\": 64}),\n triton.Config({\"V_BLOCK_SIZE\": 64, \"N_BLOCK_SIZE\": 64, \"H_BLOCK_SIZE\": 64}),\n triton.Config({\"V_BLOCK_SIZE\": 256, \"N_BLOCK_SIZE\": 16, \"H_BLOCK_SIZE\": 256}),\n triton.Config({\"V_BLOCK_SIZE\": 512, \"N_BLOCK_SIZE\": 16, \"H_BLOCK_SIZE\": 512}),\n triton.Config({\"V_BLOCK_SIZE\": 256, \"N_BLOCK_SIZE\": 64, \"H_BLOCK_SIZE\": 64}),\n triton.Config({\"V_BLOCK_SIZE\": 256, \"N_BLOCK_SIZE\": 256, \"H_BLOCK_SIZE\": 64}),\n triton.Config({\"V_BLOCK_SIZE\": 256, \"N_BLOCK_SIZE\": 256, \"H_BLOCK_SIZE\": 256}),\n triton.Config({\"V_BLOCK_SIZE\": 256, \"N_BLOCK_SIZE\": 16, \"H_BLOCK_SIZE\": 16}),\n ],\n key=[\"V\", \"N\", \"H\"],\n reset_to_zero=[\"loss_ptr\", \"lse_ptr\"],\n)\n@triton.jit\ndef linear_xent_fwd_kernel_matmul_t(\n x_ptr,\n y_ptr,\n A_t_ptr,\n loss_ptr,\n lse_ptr,\n stride_x_N,\n stride_x_H,\n stride_A_H,\n stride_A_V,\n V: tl.constexpr,\n N: tl.constexpr,\n H: tl.constexpr,\n V_BLOCK_SIZE: tl.constexpr,\n N_BLOCK_SIZE: tl.constexpr,\n H_BLOCK_SIZE: tl.constexpr,\n):\n idx = tl.program_id(axis=0)\n tl.static_assert(N % N_BLOCK_SIZE == 0)\n tl.static_assert(V % V_BLOCK_SIZE == 0)\n tl.static_assert(H % H_BLOCK_SIZE == 0)\n x_block_ptr = tl.make_block_ptr(\n base=x_ptr,\n shape=(N, H),\n strides=(stride_x_N, stride_x_H),\n offsets=(idx * N_BLOCK_SIZE, 0),\n block_shape=(N_BLOCK_SIZE, H_BLOCK_SIZE),\n order=(1, 0),\n )\n A_block_ptr = tl.make_block_ptr(\n base=A_t_ptr,\n shape=(H, V),\n strides=(stride_A_H, stride_A_V),\n offsets=(0, 0),\n block_shape=(H_BLOCK_SIZE, V_BLOCK_SIZE),\n order=(1, 0),\n )\n offsets = idx * N_BLOCK_SIZE + tl.arange(0, N_BLOCK_SIZE)\n v_range = tl.arange(0, V_BLOCK_SIZE)\n y = tl.load(y_ptr + offsets)\n\n m = tl.zeros((N_BLOCK_SIZE,), dtype=tl.float32) - float(10e5)\n s = tl.zeros((N_BLOCK_SIZE,), dtype=tl.float32)\n loss = 0.0\n\n for _ in range(V // V_BLOCK_SIZE):\n\n z_j_to_k = tl.zeros((N_BLOCK_SIZE, V_BLOCK_SIZE), dtype=tl.float32)\n local_x_block_ptr = x_block_ptr\n for _ in range(H // H_BLOCK_SIZE):\n x_chunk = tl.load(local_x_block_ptr) # Nc x H\n A_v = tl.load(A_block_ptr) # Vc x H\n\n z_j_to_k = tl.dot(x_chunk, A_v, z_j_to_k) # (Nc x H) @ (H x Vc)\n\n local_x_block_ptr = tl.advance(local_x_block_ptr, [0, H_BLOCK_SIZE])\n A_block_ptr = tl.advance(A_block_ptr, [H_BLOCK_SIZE, 0])\n\n m_new = tl.maximum(m, tl.max(z_j_to_k, 1))\n\n s_update = tl.sum(tl.exp(z_j_to_k - m_new[:, None]), axis=1)\n s = s * tl.exp(m - m_new) + s_update\n\n mask = y[:, None] == v_range[None, :] # Nc x Vc\n loss -= tl.sum(tl.where(mask, z_j_to_k, float(0.0))) / N\n\n m = m_new\n A_block_ptr = tl.advance(A_block_ptr, [-H_BLOCK_SIZE * (H // H_BLOCK_SIZE), V_BLOCK_SIZE])\n v_range = v_range + V_BLOCK_SIZE\n\n lse = m + tl.log(s)\n loss += tl.sum(lse) / N\n tl.atomic_add(loss_ptr, loss)\n tl.store(lse_ptr + offsets, lse)\n\n@triton.autotune(\n configs=[\n triton.Config({\"V_BLOCK_SIZE\": 16, \"N_BLOCK_SIZE\": 16, \"H_BLOCK_SIZE\": 16}),\n triton.Config({\"V_BLOCK_SIZE\": 16, \"N_BLOCK_SIZE\": 16, \"H_BLOCK_SIZE\": 16}),\n triton.Config({\"V_BLOCK_SIZE\": 16, \"N_BLOCK_SIZE\": 16, \"H_BLOCK_SIZE\": 64}),\n triton.Config({\"V_BLOCK_SIZE\": 32, \"N_BLOCK_SIZE\": 32, \"H_BLOCK_SIZE\": 32}),\n triton.Config({\"V_BLOCK_SIZE\": 32, \"N_BLOCK_SIZE\": 32, \"H_BLOCK_SIZE\": 64}),\n triton.Config({\"V_BLOCK_SIZE\": 32, \"N_BLOCK_SIZE\": 64, \"H_BLOCK_SIZE\": 64}),\n triton.Config({\"V_BLOCK_SIZE\": 16, \"N_BLOCK_SIZE\": 16, \"H_BLOCK_SIZE\": 128}),\n triton.Config({\"V_BLOCK_SIZE\": 16, \"N_BLOCK_SIZE\": 16, \"H_BLOCK_SIZE\": 256}),\n triton.Config({\"V_BLOCK_SIZE\": 16, \"N_BLOCK_SIZE\": 16, \"H_BLOCK_SIZE\": 512}),\n ],\n key=[\"V\", \"N\", \"H\"],\n reset_to_zero=[\"sz_ptr\"],\n)\n@triton.jit()\ndef linear_xent_bwd_kernel_matmul_t_prologue(\n sz_ptr,\n x_ptr,\n A_t_ptr,\n lse_global_ptr,\n stride_x_N,\n stride_x_H,\n stride_A_H,\n stride_A_V,\n stride_sz_N,\n stride_sz_V,\n V: tl.constexpr,\n N: tl.constexpr,\n H: tl.constexpr,\n V_BLOCK_SIZE: tl.constexpr = 16,\n N_BLOCK_SIZE: tl.constexpr = 16,\n H_BLOCK_SIZE: tl.constexpr = 16,\n):\n idx_N = tl.program_id(axis=0)\n idx_V = tl.program_id(axis=1)\n tl.static_assert(N % N_BLOCK_SIZE == 0)\n tl.static_assert(V % V_BLOCK_SIZE == 0)\n tl.static_assert(H % H_BLOCK_SIZE == 0)\n\n offsets = idx_N * N_BLOCK_SIZE + tl.arange(0, N_BLOCK_SIZE)\n lse = tl.load(lse_global_ptr + offsets)\n\n x_block_ptr = tl.make_block_ptr(\n base=x_ptr,\n shape=(N, H),\n strides=(stride_x_N, stride_x_H),\n offsets=(idx_N * N_BLOCK_SIZE, 0),\n block_shape=(N_BLOCK_SIZE, H_BLOCK_SIZE),\n order=(1, 0),\n )\n A_block_ptr = tl.make_block_ptr(\n base=A_t_ptr,\n shape=(H, V),\n strides=(stride_A_H, stride_A_V),\n offsets=(0, idx_V * V_BLOCK_SIZE),\n block_shape=(H_BLOCK_SIZE, V_BLOCK_SIZE),\n order=(1, 0),\n )\n sz_block_ptr = tl.make_block_ptr(\n base=sz_ptr,\n shape=(N, V),\n strides=(stride_sz_N, stride_sz_V),\n offsets=(idx_N * N_BLOCK_SIZE, idx_V * V_BLOCK_SIZE),\n block_shape=(N_BLOCK_SIZE, V_BLOCK_SIZE),\n order=(1, 0),\n )\n\n z_j_to_k = tl.zeros((N_BLOCK_SIZE, V_BLOCK_SIZE), dtype=tl.float32)\n for _ in range(H // H_BLOCK_SIZE):\n x_chunk = tl.load(x_block_ptr) # Nc x Hc\n A_v = tl.load(A_block_ptr) # Hc x Vc\n\n z_j_to_k = tl.dot(x_chunk, A_v, z_j_to_k) # (Nc x Hc) @ (Hc x Vc)\n\n x_block_ptr = tl.advance(x_block_ptr, [0, H_BLOCK_SIZE])\n A_block_ptr = tl.advance(A_block_ptr, [H_BLOCK_SIZE, 0])\n\n softmax_z = (z_j_to_k - lse[:, None]).exp()\n tl.store(sz_block_ptr, softmax_z.to(tl.float16))\n\n@triton.autotune(\n configs=[\n triton.Config({\"V_BLOCK_SIZE\": 16, \"N_BLOCK_SIZE\": 16, \"H_BLOCK_SIZE\": 32}),\n triton.Config({\"V_BLOCK_SIZE\": 16, \"N_BLOCK_SIZE\": 16, \"H_BLOCK_SIZE\": 16}),\n triton.Config({\"V_BLOCK_SIZE\": 16, \"N_BLOCK_SIZE\": 16, \"H_BLOCK_SIZE\": 64}),\n triton.Config({\"V_BLOCK_SIZE\": 32, \"N_BLOCK_SIZE\": 32, \"H_BLOCK_SIZE\": 32}),\n triton.Config({\"V_BLOCK_SIZE\": 32, \"N_BLOCK_SIZE\": 32, \"H_BLOCK_SIZE\": 64}),\n triton.Config({\"V_BLOCK_SIZE\": 32, \"N_BLOCK_SIZE\": 64, \"H_BLOCK_SIZE\": 64}),\n triton.Config({\"V_BLOCK_SIZE\": 64, \"N_BLOCK_SIZE\": 64, \"H_BLOCK_SIZE\": 64}),\n triton.Config({\"V_BLOCK_SIZE\": 256, \"N_BLOCK_SIZE\": 256, \"H_BLOCK_SIZE\": 16}),\n triton.Config({\"V_BLOCK_SIZE\": 256, \"N_BLOCK_SIZE\": 16, \"H_BLOCK_SIZE\": 16}),\n triton.Config({\"V_BLOCK_SIZE\": 16, \"N_BLOCK_SIZE\": 16, \"H_BLOCK_SIZE\": 128}),\n triton.Config({\"V_BLOCK_SIZE\": 16, \"N_BLOCK_SIZE\": 16, \"H_BLOCK_SIZE\": 256}),\n triton.Config({\"V_BLOCK_SIZE\": 16, \"N_BLOCK_SIZE\": 16, \"H_BLOCK_SIZE\": 512}),\n triton.Config({\"V_BLOCK_SIZE\": 64, \"N_BLOCK_SIZE\": 64, \"H_BLOCK_SIZE\": 16}),\n triton.Config({\"V_BLOCK_SIZE\": 32, \"N_BLOCK_SIZE\": 32, \"H_BLOCK_SIZE\": 16}),\n ],\n key=[\"V\", \"N\", \"H\"],\n reset_to_zero=[\"A_grad_ptr\", \"x_grad_ptr\"],\n)\n@triton.jit()\ndef linear_xent_bwd_kernel_matmul_t_epilogue(\n sz_ptr,\n x_ptr,\n y_ptr,\n A_t_ptr,\n x_grad_ptr,\n A_grad_ptr,\n stride_x_N,\n stride_x_H,\n stride_A_H,\n stride_A_V,\n stride_sz_N,\n stride_sz_V,\n V: tl.constexpr,\n N: tl.constexpr,\n H: tl.constexpr,\n V_BLOCK_SIZE: tl.constexpr = 16,\n N_BLOCK_SIZE: tl.constexpr = 16,\n H_BLOCK_SIZE: tl.constexpr = 16,\n):\n tl.static_assert(N % N_BLOCK_SIZE == 0)\n tl.static_assert(V % V_BLOCK_SIZE == 0)\n tl.static_assert(H % H_BLOCK_SIZE == 0)\n\n idx_NV = tl.program_id(axis=1)\n if idx_NV < (N // N_BLOCK_SIZE):\n linear_xent_bwd_kernel_matmul_t_epilogue_dx(\n sz_ptr,\n y_ptr,\n A_t_ptr,\n x_grad_ptr,\n stride_x_N,\n stride_x_H,\n stride_A_H,\n stride_A_V,\n stride_sz_N,\n stride_sz_V,\n V,\n N,\n H,\n V_BLOCK_SIZE,\n N_BLOCK_SIZE,\n H_BLOCK_SIZE,\n )\n else:\n linear_xent_bwd_kernel_matmul_t_epilogue_dA(\n sz_ptr,\n x_ptr,\n y_ptr,\n A_grad_ptr,\n stride_x_N,\n stride_x_H,\n stride_A_H,\n stride_A_V,\n stride_sz_N,\n stride_sz_V,\n V,\n N,\n H,\n V_BLOCK_SIZE,\n N_BLOCK_SIZE,\n H_BLOCK_SIZE,\n )\n\n@triton.jit()\ndef linear_xent_bwd_kernel_matmul_t_epilogue_dx(\n sz_ptr,\n y_ptr,\n A_t_ptr,\n x_grad_ptr,\n stride_x_N,\n stride_x_H,\n stride_A_H,\n stride_A_V,\n stride_sz_N,\n stride_sz_V,\n V: tl.constexpr,\n N: tl.constexpr,\n H: tl.constexpr,\n V_BLOCK_SIZE: tl.constexpr = 16,\n N_BLOCK_SIZE: tl.constexpr = 16,\n H_BLOCK_SIZE: tl.constexpr = 16,\n):\n idx_H = tl.program_id(axis=0)\n idx_N = tl.program_id(axis=1)\n\n x_grad_block_ptr = tl.make_block_ptr(\n base=x_grad_ptr,\n shape=(N, H),\n strides=(stride_x_N, stride_x_H),\n offsets=(idx_N * N_BLOCK_SIZE, idx_H * H_BLOCK_SIZE),\n block_shape=(N_BLOCK_SIZE, H_BLOCK_SIZE),\n order=(1, 0),\n )\n A_t_block_ptr = tl.make_block_ptr(\n base=A_t_ptr,\n shape=(H, V),\n strides=(stride_A_H, stride_A_V),\n offsets=(idx_H * H_BLOCK_SIZE, 0),\n block_shape=(H_BLOCK_SIZE, V_BLOCK_SIZE),\n order=(1, 0),\n )\n\n sz_block_ptr = tl.make_block_ptr(\n base=sz_ptr,\n shape=(N, V),\n strides=(stride_sz_N, stride_sz_V),\n offsets=(idx_N * N_BLOCK_SIZE, 0),\n block_shape=(N_BLOCK_SIZE, V_BLOCK_SIZE),\n order=(1, 0),\n )\n\n N_offsets = idx_N * N_BLOCK_SIZE + tl.arange(0, N_BLOCK_SIZE)\n V_offsets = tl.arange(0, V_BLOCK_SIZE)\n y = tl.load(y_ptr + N_offsets)\n\n x_grad_acc = tl.zeros((N_BLOCK_SIZE, H_BLOCK_SIZE), tl.float32)\n for idx_V in range(V // V_BLOCK_SIZE):\n mask = (y[:, None] == V_offsets[None, :])[:, :, None] # N_BLOCK_SIZE x V_BLOCK_SIZE x 1\n A_v = tl.load(A_t_block_ptr).trans() # Hc x Vc\n sz = tl.load(sz_block_ptr)\n\n # xgrad\n x_grad_acc = tl.dot(sz, A_v, x_grad_acc)\n x_grad_acc -= tl.sum(tl.where(mask, A_v[None, :, :], 0.0), axis=1)\n\n A_t_block_ptr = tl.advance(A_t_block_ptr, [0, V_BLOCK_SIZE])\n sz_block_ptr = tl.advance(sz_block_ptr, [0, V_BLOCK_SIZE])\n V_offsets += V_BLOCK_SIZE\n\n tl.store(x_grad_block_ptr, x_grad_acc / N)\n\n@triton.jit()\ndef linear_xent_bwd_kernel_matmul_t_epilogue_dA(\n sz_ptr,\n x_ptr,\n y_ptr,\n A_grad_ptr,\n stride_x_N,\n stride_x_H,\n stride_A_H,\n stride_A_V,\n stride_sz_N,\n stride_sz_V,\n V: tl.constexpr,\n N: tl.constexpr,\n H: tl.constexpr,\n V_BLOCK_SIZE: tl.constexpr = 16,\n N_BLOCK_SIZE: tl.constexpr = 16,\n H_BLOCK_SIZE: tl.constexpr = 16,\n):\n idx_H = tl.program_id(axis=0)\n idx_V = tl.program_id(axis=1) - (N // N_BLOCK_SIZE)\n\n x_block_ptr = tl.make_block_ptr(\n base=x_ptr,\n shape=(N, H),\n strides=(stride_x_N, stride_x_H),\n offsets=(0, idx_H * H_BLOCK_SIZE),\n block_shape=(N_BLOCK_SIZE, H_BLOCK_SIZE),\n order=(1, 0),\n )\n A_t_grad_block_ptr = tl.make_block_ptr(\n base=A_grad_ptr,\n shape=(H, V),\n strides=(stride_A_H, stride_A_V),\n offsets=(idx_H * H_BLOCK_SIZE, idx_V * V_BLOCK_SIZE),\n block_shape=(H_BLOCK_SIZE, V_BLOCK_SIZE),\n order=(1, 0),\n )\n\n sz_block_ptr = tl.make_block_ptr(\n base=sz_ptr,\n shape=(N, V),\n strides=(stride_sz_N, stride_sz_V),\n offsets=(0, idx_V * V_BLOCK_SIZE),\n block_shape=(N_BLOCK_SIZE, V_BLOCK_SIZE),\n order=(1, 0),\n )\n N_offsets = tl.arange(0, N_BLOCK_SIZE)\n V_offsets = idx_V * V_BLOCK_SIZE + tl.arange(0, V_BLOCK_SIZE)\n\n A_grad_acc = tl.zeros((V_BLOCK_SIZE, H_BLOCK_SIZE), tl.float32)\n for idx_N in range(N // N_BLOCK_SIZE):\n y = tl.load(y_ptr + N_offsets)\n\n mask = (y[:, None] == V_offsets[None, :])[:, :, None] # N_BLOCK_SIZE x V_BLOCK_SIZE x 1\n x_chunk = tl.load(x_block_ptr)\n sz = tl.load(sz_block_ptr).trans()\n\n A_grad_acc = tl.dot(sz, x_chunk, A_grad_acc)\n A_grad_acc -= tl.sum(tl.where(mask, x_chunk[:, None, :], 0.0), axis=0)\n\n x_block_ptr = tl.advance(x_block_ptr, [N_BLOCK_SIZE, 0])\n sz_block_ptr = tl.advance(sz_block_ptr, [N_BLOCK_SIZE, 0])\n N_offsets += N_BLOCK_SIZE\n\n tl.store(A_t_grad_block_ptr, A_grad_acc.trans() / N)\n\nclass LinearCrossEntropyLoss(torch.autograd.Function):\n\n @staticmethod\n def forward(\n ctx,\n x,\n y,\n At,\n ignore_index=-100,\n ):\n N, H = x.shape\n H_A, V = At.shape\n assert H_A == H\n assert y.shape == (N,)\n x = x.contiguous()\n y = y.contiguous()\n At = At.contiguous()\n\n assert V % 16 == 0, f\"V is {V}\"\n assert N % 16 == 0, f\"N is {N}\"\n assert H % 16 == 0, f\"H is {H}\"\n\n lse_global = torch.zeros(N, dtype=torch.float32, device=x.device)\n loss = torch.zeros(1, dtype=torch.float32, device=x.device)\n\n grid = lambda meta: (triton.cdiv(N, meta[\"N_BLOCK_SIZE\"]),)\n\n with torch.cuda.device(x.device.index):\n linear_xent_fwd_kernel_matmul_t[grid](\n x, y, At, loss, lse_global, x.stride(0), x.stride(1), At.stride(0), At.stride(1), V=V, N=N, H=H\n )\n ctx.save_for_backward(x, y, At, lse_global)\n return loss\n\n @staticmethod\n def backward(ctx, grad_output):\n x, y, At, lse_global = ctx.saved_tensors\n N, H = x.shape\n _, V = At.shape\n\n xgrad = torch.zeros_like(x, dtype=torch.float32)\n Atgrad = torch.zeros_like(At, dtype=torch.float32)\n\n with torch.cuda.device(x.device.index):\n sz = torch.empty((N, V), dtype=torch.float16, device=x.device)\n grid = lambda meta: (triton.cdiv(N, meta[\"N_BLOCK_SIZE\"]), triton.cdiv(V, meta[\"V_BLOCK_SIZE\"]))\n linear_xent_bwd_kernel_matmul_t_prologue[grid](\n sz,\n x,\n At,\n lse_global,\n x.stride(0),\n x.stride(1),\n At.stride(0),\n At.stride(1),\n sz.stride(0),\n sz.stride(1),\n V,\n N,\n H,\n )\n grid = lambda meta: (\n triton.cdiv(H, meta[\"H_BLOCK_SIZE\"]),\n triton.cdiv(N, meta[\"N_BLOCK_SIZE\"]) + triton.cdiv(V, meta[\"V_BLOCK_SIZE\"]),\n )\n linear_xent_bwd_kernel_matmul_t_epilogue[grid](\n sz,\n x,\n y,\n At,\n xgrad,\n Atgrad,\n x.stride(0),\n x.stride(1),\n At.stride(0),\n At.stride(1),\n sz.stride(0),\n sz.stride(1),\n V,\n N,\n H,\n )\n\n return xgrad * grad_output, None, Atgrad * grad_output, None\n\ndef linear_cross_entropy(x, y, At):\n return LinearCrossEntropyLoss.apply(x, y, At)\n", - "description_1": "Use triton language to implement a forward and backward kernel for linear cross-entropy computation. The forward kernel computes the matrix product of inputs x and weight matrix A_t, applies softmax, and calculates the loss. It requires pointers to inputs, output loss, and intermediate results, and block size parameters to divide computation across grid blocks. The backward prologue kernel computes softmax derivatives, storing them in an intermediate buffer, while the backward epilogue kernel computes gradients for both the input and weight matrix, requiring similar pointer and block parameters.", - "description_2": "Use triton language to implement a linear cross-entropy forward kernel that calculates matrix multiplication, softmax, and loss, and a backward kernel that computes input and weight matrix gradients using block-wise processing.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.autotune(\n configs=[\n triton.Config({\"V_BLOCK_SIZE\": 64, \"N_BLOCK_SIZE\": 16, \"H_BLOCK_SIZE\": 64}),\n triton.Config({\"V_BLOCK_SIZE\": 64, \"N_BLOCK_SIZE\": 64, \"H_BLOCK_SIZE\": 64}),\n triton.Config({\"V_BLOCK_SIZE\": 256, \"N_BLOCK_SIZE\": 16, \"H_BLOCK_SIZE\": 256}),\n triton.Config({\"V_BLOCK_SIZE\": 512, \"N_BLOCK_SIZE\": 16, \"H_BLOCK_SIZE\": 512}),\n triton.Config({\"V_BLOCK_SIZE\": 256, \"N_BLOCK_SIZE\": 64, \"H_BLOCK_SIZE\": 64}),\n triton.Config({\"V_BLOCK_SIZE\": 256, \"N_BLOCK_SIZE\": 256, \"H_BLOCK_SIZE\": 64}),\n triton.Config({\"V_BLOCK_SIZE\": 256, \"N_BLOCK_SIZE\": 256, \"H_BLOCK_SIZE\": 256}),\n triton.Config({\"V_BLOCK_SIZE\": 256, \"N_BLOCK_SIZE\": 16, \"H_BLOCK_SIZE\": 16}),\n ],\n key=[\"V\", \"N\", \"H\"],\n reset_to_zero=[\"loss_ptr\", \"lse_ptr\"],\n)\n@triton.jit\ndef linear_xent_fwd_kernel_matmul_t(\n x_ptr,\n y_ptr,\n A_t_ptr,\n loss_ptr,\n lse_ptr,\n stride_x_N,\n stride_x_H,\n stride_A_H,\n stride_A_V,\n V: tl.constexpr,\n N: tl.constexpr,\n H: tl.constexpr,\n V_BLOCK_SIZE: tl.constexpr,\n N_BLOCK_SIZE: tl.constexpr,\n H_BLOCK_SIZE: tl.constexpr,\n):\n idx = tl.program_id(axis=0)\n tl.static_assert(N % N_BLOCK_SIZE == 0)\n tl.static_assert(V % V_BLOCK_SIZE == 0)\n tl.static_assert(H % H_BLOCK_SIZE == 0)\n tl.static_print(V_BLOCK_SIZE, N_BLOCK_SIZE, H_BLOCK_SIZE)\n x_block_ptr = tl.make_block_ptr(\n base=x_ptr,\n shape=(N, H),\n strides=(stride_x_N, stride_x_H),\n offsets=(idx * N_BLOCK_SIZE, 0),\n block_shape=(N_BLOCK_SIZE, H_BLOCK_SIZE),\n order=(1, 0),\n )\n A_block_ptr = tl.make_block_ptr(\n base=A_t_ptr,\n shape=(H, V),\n strides=(stride_A_H, stride_A_V),\n offsets=(0, 0),\n block_shape=(H_BLOCK_SIZE, V_BLOCK_SIZE),\n order=(1, 0),\n )\n offsets = idx * N_BLOCK_SIZE + tl.arange(0, N_BLOCK_SIZE)\n v_range = tl.arange(0, V_BLOCK_SIZE)\n y = tl.load(y_ptr + offsets)\n\n m = tl.zeros((N_BLOCK_SIZE,), dtype=tl.float32) - float(10e5)\n s = tl.zeros((N_BLOCK_SIZE,), dtype=tl.float32)\n loss = 0.0\n\n for _ in range(V // V_BLOCK_SIZE):\n z_j_to_k = tl.zeros((N_BLOCK_SIZE, V_BLOCK_SIZE), dtype=tl.float32)\n local_x_block_ptr = x_block_ptr\n for _ in range(H // H_BLOCK_SIZE):\n x_chunk = tl.load(local_x_block_ptr) # Nc x H\n A_v = tl.load(A_block_ptr) # Vc x H\n\n z_j_to_k = tl.dot(x_chunk, A_v, z_j_to_k) # (Nc x H) @ (H x Vc)\n\n local_x_block_ptr = tl.advance(local_x_block_ptr, [0, H_BLOCK_SIZE])\n A_block_ptr = tl.advance(A_block_ptr, [H_BLOCK_SIZE, 0])\n\n m_new = tl.maximum(m, tl.max(z_j_to_k, 1))\n\n s_update = tl.sum(tl.exp(z_j_to_k - m_new[:, None]), axis=1)\n s = s * tl.exp(m - m_new) + s_update\n\n mask = y[:, None] == v_range[None, :] # Nc x Vc\n loss -= tl.sum(tl.where(mask, z_j_to_k, float(0.0))) / N\n\n m = m_new\n A_block_ptr = tl.advance(A_block_ptr, [-H_BLOCK_SIZE * (H // H_BLOCK_SIZE), V_BLOCK_SIZE])\n v_range = v_range + V_BLOCK_SIZE\n\n lse = m + tl.log(s)\n loss += tl.sum(lse) / N\n tl.atomic_add(loss_ptr, loss)\n tl.store(lse_ptr + offsets, lse)\n\n@triton.autotune(\n configs=[\n triton.Config({\"V_BLOCK_SIZE\": 16, \"N_BLOCK_SIZE\": 16, \"H_BLOCK_SIZE\": 16}),\n triton.Config({\"V_BLOCK_SIZE\": 16, \"N_BLOCK_SIZE\": 16, \"H_BLOCK_SIZE\": 16}),\n triton.Config({\"V_BLOCK_SIZE\": 16, \"N_BLOCK_SIZE\": 16, \"H_BLOCK_SIZE\": 64}),\n triton.Config({\"V_BLOCK_SIZE\": 32, \"N_BLOCK_SIZE\": 32, \"H_BLOCK_SIZE\": 32}),\n triton.Config({\"V_BLOCK_SIZE\": 32, \"N_BLOCK_SIZE\": 32, \"H_BLOCK_SIZE\": 64}),\n triton.Config({\"V_BLOCK_SIZE\": 32, \"N_BLOCK_SIZE\": 64, \"H_BLOCK_SIZE\": 64}),\n triton.Config({\"V_BLOCK_SIZE\": 16, \"N_BLOCK_SIZE\": 16, \"H_BLOCK_SIZE\": 128}),\n triton.Config({\"V_BLOCK_SIZE\": 16, \"N_BLOCK_SIZE\": 16, \"H_BLOCK_SIZE\": 256}),\n triton.Config({\"V_BLOCK_SIZE\": 16, \"N_BLOCK_SIZE\": 16, \"H_BLOCK_SIZE\": 512}),\n ],\n key=[\"V\", \"N\", \"H\"],\n reset_to_zero=[\"A_grad_ptr\", \"x_grad_ptr\"],\n)\n@triton.jit()\ndef linear_xent_bwd_kernel_matmul_t(\n x_ptr,\n y_ptr,\n A_t_ptr,\n lse_global_ptr,\n A_grad_ptr,\n x_grad_ptr,\n locks_N_ptr,\n locks_V_ptr,\n stride_x_N,\n stride_x_H,\n stride_A_H,\n stride_A_V,\n stride_x_grad_N,\n stride_x_grad_H,\n stride_A_grad_H,\n stride_A_grad_V,\n V: tl.constexpr,\n N: tl.constexpr,\n H: tl.constexpr,\n V_BLOCK_SIZE: tl.constexpr = 16,\n N_BLOCK_SIZE: tl.constexpr = 16,\n H_BLOCK_SIZE: tl.constexpr = 16,\n):\n idx_N = tl.program_id(axis=0)\n idx_V = tl.program_id(axis=1)\n\n tl.static_print(V_BLOCK_SIZE, N_BLOCK_SIZE, H_BLOCK_SIZE)\n tl.static_assert(N % N_BLOCK_SIZE == 0)\n tl.static_assert(V % V_BLOCK_SIZE == 0)\n tl.static_assert(H % H_BLOCK_SIZE == 0)\n\n offsets = idx_N * N_BLOCK_SIZE + tl.arange(0, N_BLOCK_SIZE)\n v_range = idx_V * V_BLOCK_SIZE + tl.arange(0, V_BLOCK_SIZE)\n\n y = tl.load(y_ptr + offsets)\n lse = tl.load(lse_global_ptr + offsets)\n\n x_block_ptr = tl.make_block_ptr(\n base=x_ptr,\n shape=(N, H),\n strides=(stride_x_N, stride_x_H),\n offsets=(idx_N * N_BLOCK_SIZE, 0),\n block_shape=(N_BLOCK_SIZE, H_BLOCK_SIZE),\n order=(1, 0),\n )\n\n x_grad_block_ptr = tl.make_block_ptr(\n base=x_grad_ptr,\n shape=(N, H),\n strides=(stride_x_grad_N, stride_x_grad_H),\n offsets=(idx_N * N_BLOCK_SIZE, 0),\n block_shape=(N_BLOCK_SIZE, H_BLOCK_SIZE),\n order=(1, 0),\n )\n\n A_block_ptr = tl.make_block_ptr(\n base=A_t_ptr,\n shape=(H, V),\n strides=(stride_A_H, stride_A_V),\n offsets=(0, idx_V * V_BLOCK_SIZE),\n block_shape=(H_BLOCK_SIZE, V_BLOCK_SIZE),\n order=(1, 0),\n )\n\n A_grad_block_ptr = tl.make_block_ptr(\n base=A_grad_ptr,\n shape=(H, V),\n strides=(stride_A_grad_H, stride_A_grad_V),\n offsets=(0, idx_V * V_BLOCK_SIZE),\n block_shape=(H_BLOCK_SIZE, V_BLOCK_SIZE),\n order=(1, 0),\n )\n\n z_j_to_k = tl.zeros((N_BLOCK_SIZE, V_BLOCK_SIZE), dtype=tl.float32)\n\n local_x_block_ptr = x_block_ptr\n local_A_block_ptr = A_block_ptr\n for _ in range(H // H_BLOCK_SIZE):\n x_chunk = tl.load(local_x_block_ptr) # Nc x Hc\n A_v = tl.load(local_A_block_ptr) # Hc x Vc\n\n z_j_to_k = tl.dot(x_chunk, A_v, z_j_to_k) # (Nc x Hc) @ (Hc x Vc)\n\n local_x_block_ptr = tl.advance(local_x_block_ptr, [0, H_BLOCK_SIZE])\n local_A_block_ptr = tl.advance(local_A_block_ptr, [H_BLOCK_SIZE, 0])\n\n mask = (y[:, None] == v_range[None, :])[:, :, None] # N_BLOCK_SIZE x V_BLOCK_SIZE x 1\n # the reason for the double loop\n softmax_z = (z_j_to_k - lse[:, None]).exp()\n\n for _ in range(H // H_BLOCK_SIZE):\n x_chunk = tl.load(x_block_ptr).to(tl.float32) # Nc x Hc\n A_v = tl.load(A_block_ptr).to(tl.float32) # Hc x Vc\n\n # xgrad\n temp_xgrad = tl.dot(softmax_z, A_v.trans())\n temp_xgrad -= tl.sum(tl.where(mask, A_v.trans()[None, :, :], 0.0), axis=1)\n\n # Lock in V direction for x accumulation\n # tl.atomic_add(x_grad_block_ptr, temp_xgrad)\n while tl.atomic_cas(locks_V_ptr + idx_N, 0, 1) == 1:\n pass\n temp_xgrad = temp_xgrad / N + tl.load(x_grad_block_ptr)\n tl.store(x_grad_block_ptr, temp_xgrad)\n tl.atomic_xchg(locks_V_ptr + idx_N, 0)\n\n # Agrad\n temp_Agrad = tl.dot(softmax_z.trans(), x_chunk)\n temp_Agrad -= tl.sum(tl.where(mask, x_chunk[:, None, :], 0.0), axis=0)\n temp_Agrad = temp_Agrad.trans() # to T\n\n # Lock in N direction for A accumulation\n # tl.atomic_add(A_grad_block_ptr, temp_Agrad)\n while tl.atomic_cas(locks_N_ptr + idx_V, 0, 1) == 1:\n pass\n temp_Agrad = temp_Agrad / N + tl.load(A_grad_block_ptr)\n\n tl.store(A_grad_block_ptr, temp_Agrad)\n tl.atomic_xchg(locks_N_ptr + idx_V, 0)\n\n x_block_ptr = tl.advance(x_block_ptr, [0, H_BLOCK_SIZE])\n x_grad_block_ptr = tl.advance(x_grad_block_ptr, [0, H_BLOCK_SIZE])\n\n A_block_ptr = tl.advance(A_block_ptr, [H_BLOCK_SIZE, 0])\n A_grad_block_ptr = tl.advance(A_grad_block_ptr, [H_BLOCK_SIZE, 0])\n\nclass LinearCrossEntropyLoss(torch.autograd.Function):\n\n @staticmethod\n def forward(\n ctx,\n x,\n y,\n At,\n ignore_index=-100,\n ):\n N, H = x.shape\n H_A, V = At.shape\n assert H_A == H\n assert y.shape == (N,)\n x = x.contiguous()\n y = y.contiguous()\n At = At.contiguous()\n\n assert V % 16 == 0, f\"V is {V}\"\n assert N % 16 == 0, f\"N is {N}\"\n assert H % 16 == 0, f\"H is {H}\"\n\n lse_global = torch.zeros(N, dtype=torch.float32, device=x.device)\n loss = torch.zeros(1, dtype=torch.float32, device=x.device)\n\n grid = lambda meta: (triton.cdiv(N, meta[\"N_BLOCK_SIZE\"]),)\n\n with torch.cuda.device(x.device.index):\n linear_xent_fwd_kernel_matmul_t[grid](\n x, y, At, loss, lse_global, x.stride(0), x.stride(1), At.stride(0), At.stride(1), V=V, N=N, H=H\n )\n print(\"fwd config:\", linear_xent_fwd_kernel_matmul_t.best_config)\n\n ctx.save_for_backward(x, y, At, lse_global)\n return loss\n\n @staticmethod\n def backward(ctx, grad_output):\n x, y, At, lse_global = ctx.saved_tensors\n N, H = x.shape\n _, V = At.shape\n\n xgrad = torch.zeros_like(x, dtype=torch.float32)\n Atgrad = torch.zeros_like(At, dtype=torch.float32)\n\n grid = lambda meta: (triton.cdiv(N, meta[\"N_BLOCK_SIZE\"]), triton.cdiv(V, meta[\"V_BLOCK_SIZE\"]))\n locks_N = torch.zeros(N // 16, dtype=torch.int32, device=x.device)\n locks_V = torch.zeros(V // 16, dtype=torch.int32, device=x.device)\n\n with torch.cuda.device(x.device.index):\n linear_xent_bwd_kernel_matmul_t[grid](\n x,\n y,\n At,\n lse_global,\n Atgrad,\n xgrad,\n locks_N,\n locks_V,\n x.stride(0),\n x.stride(1),\n At.stride(0),\n At.stride(1),\n xgrad.stride(0),\n xgrad.stride(1),\n Atgrad.stride(0),\n Atgrad.stride(1),\n V=V,\n N=N,\n H=H,\n )\n print(\"bwd config:\", linear_xent_bwd_kernel_matmul_t.best_config)\n\n return xgrad * grad_output, None, Atgrad * grad_output, None\n\ndef linear_cross_entropy(x, y, At):\n return LinearCrossEntropyLoss.apply(x, y, At)\n\n", - "description_1": "Use triton language to implement a cross-entropy loss with linear transformation forward and backward kernels. The forward kernel 'linear_xent_fwd_kernel_matmul_t' takes 15 parameters: input pointers for x, y, transposed A, loss, and lse, strides for x and A, constants V, N, H, and block sizes for V, N, and H. The backward kernel 'linear_xent_bwd_kernel_matmul_t' takes 18 parameters: input pointers for x, y, transposed A, global lse, A gradient, x gradient, and locks for N and V, strides for x, A, x gradient, and A gradient, constants V, N, H, and block sizes for V, N, and H.", - "description_2": "Use triton language to define and call the 'LinearCrossEntropyLoss' function, which uses the above kernels to compute forward and backward passes for cross-entropy loss with a linear transformation of input x and target y with transposed matrix At.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.autotune(\n configs=[\n triton.Config({\"V_BLOCK_SIZE\": 64, \"N_BLOCK_SIZE\": 64, \"H_BLOCK_SIZE\": 64}),\n triton.Config({\"V_BLOCK_SIZE\": 128, \"N_BLOCK_SIZE\": 64, \"H_BLOCK_SIZE\": 64}),\n triton.Config({\"V_BLOCK_SIZE\": 256, \"N_BLOCK_SIZE\": 64, \"H_BLOCK_SIZE\": 64}),\n triton.Config({\"V_BLOCK_SIZE\": 512, \"N_BLOCK_SIZE\": 64, \"H_BLOCK_SIZE\": 64}),\n triton.Config({\"V_BLOCK_SIZE\": 64, \"N_BLOCK_SIZE\": 64, \"H_BLOCK_SIZE\": 64}),\n triton.Config({\"V_BLOCK_SIZE\": 256, \"N_BLOCK_SIZE\": 16, \"H_BLOCK_SIZE\": 256}),\n triton.Config({\"V_BLOCK_SIZE\": 512, \"N_BLOCK_SIZE\": 16, \"H_BLOCK_SIZE\": 16}),\n triton.Config({\"V_BLOCK_SIZE\": 256, \"N_BLOCK_SIZE\": 64, \"H_BLOCK_SIZE\": 128}),\n triton.Config({\"V_BLOCK_SIZE\": 64, \"N_BLOCK_SIZE\": 64, \"H_BLOCK_SIZE\": 64}, num_warps=4),\n triton.Config({\"V_BLOCK_SIZE\": 128, \"N_BLOCK_SIZE\": 64, \"H_BLOCK_SIZE\": 64}, num_warps=4),\n triton.Config({\"V_BLOCK_SIZE\": 256, \"N_BLOCK_SIZE\": 64, \"H_BLOCK_SIZE\": 64}, num_warps=4),\n triton.Config({\"V_BLOCK_SIZE\": 512, \"N_BLOCK_SIZE\": 64, \"H_BLOCK_SIZE\": 64}, num_warps=4),\n triton.Config({\"V_BLOCK_SIZE\": 64, \"N_BLOCK_SIZE\": 64, \"H_BLOCK_SIZE\": 64}, num_warps=4),\n triton.Config({\"V_BLOCK_SIZE\": 256, \"N_BLOCK_SIZE\": 16, \"H_BLOCK_SIZE\": 256}, num_warps=4),\n triton.Config({\"V_BLOCK_SIZE\": 512, \"N_BLOCK_SIZE\": 16, \"H_BLOCK_SIZE\": 16}, num_warps=4),\n triton.Config({\"V_BLOCK_SIZE\": 256, \"N_BLOCK_SIZE\": 64, \"H_BLOCK_SIZE\": 128}, num_warps=4),\n triton.Config({\"V_BLOCK_SIZE\": 64, \"N_BLOCK_SIZE\": 64, \"H_BLOCK_SIZE\": 64}, num_warps=8),\n triton.Config({\"V_BLOCK_SIZE\": 128, \"N_BLOCK_SIZE\": 64, \"H_BLOCK_SIZE\": 64}, num_warps=8),\n triton.Config({\"V_BLOCK_SIZE\": 256, \"N_BLOCK_SIZE\": 64, \"H_BLOCK_SIZE\": 64}, num_warps=8),\n triton.Config({\"V_BLOCK_SIZE\": 512, \"N_BLOCK_SIZE\": 64, \"H_BLOCK_SIZE\": 64}, num_warps=8),\n triton.Config({\"V_BLOCK_SIZE\": 64, \"N_BLOCK_SIZE\": 64, \"H_BLOCK_SIZE\": 64}, num_warps=8),\n triton.Config({\"V_BLOCK_SIZE\": 256, \"N_BLOCK_SIZE\": 16, \"H_BLOCK_SIZE\": 256}, num_warps=8),\n triton.Config({\"V_BLOCK_SIZE\": 512, \"N_BLOCK_SIZE\": 16, \"H_BLOCK_SIZE\": 16}, num_warps=8),\n triton.Config({\"V_BLOCK_SIZE\": 256, \"N_BLOCK_SIZE\": 64, \"H_BLOCK_SIZE\": 128}, num_warps=8),\n ],\n key=[\"V\", \"N\", \"H\"],\n reset_to_zero=[\"losses_ptr\", \"lse_ptr\", \"z_nv_ptr\"],\n)\n@triton.jit\ndef linear_xent_fwd_prep_bwd_kernel_matmul_t(\n x_ptr,\n y_ptr,\n A_t_ptr,\n z_nv_ptr,\n losses_ptr,\n lse_ptr,\n stride_x_N,\n stride_x_H,\n stride_A_H,\n stride_A_V,\n stride_z_N,\n stride_z_V,\n idx_N_group,\n N_group: tl.constexpr,\n V: tl.constexpr,\n N: tl.constexpr,\n H: tl.constexpr,\n V_BLOCK_SIZE: tl.constexpr,\n N_BLOCK_SIZE: tl.constexpr,\n H_BLOCK_SIZE: tl.constexpr,\n):\n idx_N = tl.program_id(axis=0)\n tl.static_print(V_BLOCK_SIZE, N_BLOCK_SIZE, H_BLOCK_SIZE)\n x_block_ptr = tl.make_block_ptr(\n base=x_ptr,\n shape=(N, H),\n strides=(stride_x_N, stride_x_H),\n offsets=(idx_N_group * N_group + idx_N * N_BLOCK_SIZE, 0),\n block_shape=(N_BLOCK_SIZE, H_BLOCK_SIZE),\n order=(1, 0),\n )\n A_block_ptr = tl.make_block_ptr(\n base=A_t_ptr,\n shape=(H, V),\n strides=(stride_A_H, stride_A_V),\n offsets=(0, 0),\n block_shape=(H_BLOCK_SIZE, V_BLOCK_SIZE),\n order=(1, 0),\n )\n z_block_ptr = tl.make_block_ptr(\n base=z_nv_ptr,\n shape=(N_group, V),\n strides=(stride_z_N, stride_z_V),\n offsets=(idx_N * N_BLOCK_SIZE, 0),\n block_shape=(N_BLOCK_SIZE, V_BLOCK_SIZE),\n order=(1, 0),\n )\n\n N_range = idx_N_group * N_group + idx_N * N_BLOCK_SIZE + tl.arange(0, N_BLOCK_SIZE)\n V_range = tl.arange(0, V_BLOCK_SIZE)\n y = tl.load(y_ptr + N_range)\n\n m = tl.zeros((N_BLOCK_SIZE,), dtype=tl.float32) - float(10e6)\n s = tl.zeros((N_BLOCK_SIZE,), dtype=tl.float32)\n loss = 0.0\n\n for _ in range(V // V_BLOCK_SIZE):\n z_j_to_k = tl.zeros((N_BLOCK_SIZE, V_BLOCK_SIZE), dtype=tl.float32)\n for _ in range(H // H_BLOCK_SIZE):\n x_chunk = tl.load(x_block_ptr) # Nc x H\n A_v = tl.load(A_block_ptr) # Vc x H\n\n z_j_to_k = tl.dot(x_chunk, A_v, z_j_to_k) # (Nc x H) @ (H x Vc)\n\n x_block_ptr = tl.advance(x_block_ptr, [0, H_BLOCK_SIZE])\n A_block_ptr = tl.advance(A_block_ptr, [H_BLOCK_SIZE, 0])\n\n m_new = tl.maximum(m, tl.max(z_j_to_k, 1))\n\n s_update = tl.sum(tl.exp(z_j_to_k - m_new[:, None]), axis=1)\n s = s * tl.exp(m - m_new) + s_update\n\n mask = y[:, None] == V_range[None, :] # Nc x Vc\n loss -= tl.sum(tl.where(mask, z_j_to_k, float(0.0))) / N\n\n # save z for later\n tl.store(z_block_ptr, z_j_to_k.to(z_nv_ptr.type.element_ty))\n\n # Reset and advance pointers for next step\n m = m_new\n x_block_ptr = tl.advance(x_block_ptr, [0, -H])\n A_block_ptr = tl.advance(A_block_ptr, [-H, V_BLOCK_SIZE])\n z_block_ptr = tl.advance(z_block_ptr, [0, V_BLOCK_SIZE])\n V_range = V_range + V_BLOCK_SIZE\n\n lse = m + tl.log(s)\n loss += tl.sum(lse) / N\n tl.store(losses_ptr + idx_N + idx_N_group * N_group // N_BLOCK_SIZE, loss)\n tl.store(lse_ptr + N_range, lse)\n\n\n@triton.autotune(\n configs=[\n triton.Config({\"V_BLOCK_SIZE\": 64, \"N_BLOCK_SIZE\": 16}),\n triton.Config({\"V_BLOCK_SIZE\": 256, \"N_BLOCK_SIZE\": 16}),\n triton.Config({\"V_BLOCK_SIZE\": 512, \"N_BLOCK_SIZE\": 16}),\n triton.Config({\"V_BLOCK_SIZE\": 1024, \"N_BLOCK_SIZE\": 16}),\n triton.Config({\"V_BLOCK_SIZE\": 64, \"N_BLOCK_SIZE\": 64}),\n triton.Config({\"V_BLOCK_SIZE\": 256, \"N_BLOCK_SIZE\": 64}),\n triton.Config({\"V_BLOCK_SIZE\": 256, \"N_BLOCK_SIZE\": 128}),\n triton.Config({\"V_BLOCK_SIZE\": 64, \"N_BLOCK_SIZE\": 16}, num_warps=8),\n triton.Config({\"V_BLOCK_SIZE\": 256, \"N_BLOCK_SIZE\": 16}, num_warps=8),\n triton.Config({\"V_BLOCK_SIZE\": 512, \"N_BLOCK_SIZE\": 16}, num_warps=8),\n triton.Config({\"V_BLOCK_SIZE\": 1024, \"N_BLOCK_SIZE\": 16}, num_warps=8),\n triton.Config({\"V_BLOCK_SIZE\": 64, \"N_BLOCK_SIZE\": 64}, num_warps=8),\n triton.Config({\"V_BLOCK_SIZE\": 256, \"N_BLOCK_SIZE\": 64}, num_warps=8),\n triton.Config({\"V_BLOCK_SIZE\": 256, \"N_BLOCK_SIZE\": 128}, num_warps=8),\n triton.Config({\"V_BLOCK_SIZE\": 64, \"N_BLOCK_SIZE\": 16}, num_warps=16),\n triton.Config({\"V_BLOCK_SIZE\": 256, \"N_BLOCK_SIZE\": 16}, num_warps=16),\n triton.Config({\"V_BLOCK_SIZE\": 512, \"N_BLOCK_SIZE\": 16}, num_warps=16),\n triton.Config({\"V_BLOCK_SIZE\": 1024, \"N_BLOCK_SIZE\": 16}, num_warps=16),\n triton.Config({\"V_BLOCK_SIZE\": 64, \"N_BLOCK_SIZE\": 64}, num_warps=16),\n triton.Config({\"V_BLOCK_SIZE\": 256, \"N_BLOCK_SIZE\": 64}, num_warps=16),\n triton.Config({\"V_BLOCK_SIZE\": 256, \"N_BLOCK_SIZE\": 128}, num_warps=16),\n ],\n key=[\"V\", \"N\"],\n restore_value=[\"z_nv_ptr\"],\n)\n@triton.jit\ndef linear_xent_mini_bwd_prologue_kernel(\n z_nv_ptr,\n y_ptr,\n lse_ptr,\n stride_z_N,\n stride_z_V,\n idx_N_group,\n N_group: tl.constexpr,\n V: tl.constexpr,\n N: tl.constexpr,\n V_BLOCK_SIZE: tl.constexpr,\n N_BLOCK_SIZE: tl.constexpr,\n):\n idx_N = tl.program_id(axis=0)\n idx_V = tl.program_id(axis=1)\n tl.static_print(V_BLOCK_SIZE, N_BLOCK_SIZE)\n z_block_ptr = tl.make_block_ptr(\n base=z_nv_ptr,\n shape=(N_group, V),\n strides=(stride_z_N, stride_z_V),\n offsets=(idx_N * N_BLOCK_SIZE, idx_V * V_BLOCK_SIZE),\n block_shape=(N_BLOCK_SIZE, V_BLOCK_SIZE),\n order=(1, 0),\n )\n N_range = idx_N_group * N_group + idx_N * N_BLOCK_SIZE + tl.arange(0, N_BLOCK_SIZE)\n v_range = idx_V * V_BLOCK_SIZE + tl.arange(0, V_BLOCK_SIZE)\n y = tl.load(y_ptr + N_range)\n lse = tl.load(lse_ptr + N_range)\n z_j_to_k = tl.load(z_block_ptr)\n\n mask = y[:, None] == v_range[None, :]\n softmax_z = (z_j_to_k - lse[:, None]).exp()\n z_grad = (softmax_z - tl.where(mask, 1.0, 0.0)) / N\n\n tl.store(z_block_ptr, z_grad.to(z_nv_ptr.type.element_ty))\n\n\nclass LinearCrossEntropyLoss(torch.autograd.Function):\n\n @staticmethod\n def forward(\n ctx,\n x,\n y,\n At,\n ignore_index=-100,\n N_chunk_size: int = 4096,\n ):\n N, H = x.shape\n H_A, V = At.shape\n assert H_A == H\n assert y.shape == (N,)\n if ignore_index >= 0:\n y[y == ignore_index] = -100\n At = At.contiguous()\n A_grad = torch.zeros_like(At.T)\n x_grad = torch.zeros_like(x)\n\n N_group = min(N, N_chunk_size)\n\n lse_global = torch.zeros(N, dtype=torch.float32, device=x.device)\n losses = torch.zeros(N // 16, dtype=torch.float32, device=x.device)\n z_nv_and_grad = torch.empty((N_group, V), device=x.device, dtype=torch.float32)\n\n grid = lambda meta: (triton.cdiv(N_group, meta[\"N_BLOCK_SIZE\"]),)\n prologue_grid = lambda meta: (\n triton.cdiv(N_group, meta[\"N_BLOCK_SIZE\"]),\n triton.cdiv(V, meta[\"V_BLOCK_SIZE\"]),\n )\n\n for idx_N_group, x_n_chunk in enumerate(x.split(N_group)):\n with torch.cuda.device(x.device.index):\n linear_xent_fwd_prep_bwd_kernel_matmul_t[grid](\n x,\n y,\n At,\n z_nv_and_grad,\n losses,\n lse_global,\n x.stride(0),\n x.stride(1),\n At.stride(0),\n At.stride(1),\n z_nv_and_grad.stride(0),\n z_nv_and_grad.stride(1),\n idx_N_group=idx_N_group,\n N_group=N_group,\n V=V,\n N=N,\n H=H,\n )\n if x.requires_grad or At.requires_grad:\n linear_xent_mini_bwd_prologue_kernel[prologue_grid](\n z_nv_and_grad,\n y,\n lse_global,\n z_nv_and_grad.stride(0),\n z_nv_and_grad.stride(1),\n idx_N_group=idx_N_group,\n N_group=N_group,\n V=V,\n N=N,\n )\n z_grad = z_nv_and_grad.to(x.dtype)\n\n if x.requires_grad:\n x_grad[N_group * idx_N_group : x_n_chunk.shape[0] * (idx_N_group + 1)] = z_grad @ At.T\n\n if At.requires_grad:\n torch.addmm(\n A_grad,\n z_grad.T,\n x_n_chunk,\n out=A_grad,\n )\n\n ctx.mark_non_differentiable(y)\n ctx.save_for_backward(x_grad, A_grad.T.to(At.dtype))\n return losses.sum()\n\n @staticmethod\n def backward(ctx, grad_output):\n x_grad, At_grad = ctx.saved_tensors\n\n return x_grad * grad_output, None, At_grad * grad_output, None\n\n\ndef linear_cross_entropy(x, y, At):\n return LinearCrossEntropyLoss.apply(x, y, At)\n", - "description_1": "Use triton language to implement a linear cross-entropy loss function with forward and backward passes. The forward kernel 'linear_xent_fwd_prep_bwd_kernel_matmul_t' takes 19 parameters: pointers to input tensors, strides, and block sizes, and computes the forward pass of the loss. The backward kernel 'linear_xent_mini_bwd_prologue_kernel' takes 10 parameters: pointers to input tensors, strides, and block sizes, and computes the gradient of the loss. The 'LinearCrossEntropyLoss' class wraps these kernels and provides a PyTorch-compatible interface with forward and backward methods.", - "description_2": "Use triton language to create a linear cross-entropy loss function with both forward and backward kernels, and integrate it with PyTorch's autograd system.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.autotune(\n configs=[\n triton.Config({\"V_BLOCK_SIZE\": 256, \"N_BLOCK_SIZE\": 64, \"H_BLOCK_SIZE\": 64}),\n triton.Config({\"V_BLOCK_SIZE\": 64, \"N_BLOCK_SIZE\": 16, \"H_BLOCK_SIZE\": 64}),\n triton.Config({\"V_BLOCK_SIZE\": 64, \"N_BLOCK_SIZE\": 64, \"H_BLOCK_SIZE\": 64}),\n triton.Config({\"V_BLOCK_SIZE\": 256, \"N_BLOCK_SIZE\": 16, \"H_BLOCK_SIZE\": 256}),\n triton.Config({\"V_BLOCK_SIZE\": 512, \"N_BLOCK_SIZE\": 16, \"H_BLOCK_SIZE\": 512}),\n triton.Config({\"V_BLOCK_SIZE\": 256, \"N_BLOCK_SIZE\": 64, \"H_BLOCK_SIZE\": 64}),\n triton.Config({\"V_BLOCK_SIZE\": 256, \"N_BLOCK_SIZE\": 256, \"H_BLOCK_SIZE\": 64}),\n triton.Config({\"V_BLOCK_SIZE\": 256, \"N_BLOCK_SIZE\": 256, \"H_BLOCK_SIZE\": 256}),\n triton.Config({\"V_BLOCK_SIZE\": 256, \"N_BLOCK_SIZE\": 16, \"H_BLOCK_SIZE\": 16}),\n ],\n key=[\"V\", \"N\", \"H\"],\n reset_to_zero=[\"losses_ptr\", \"lse_ptr\"],\n)\n@triton.jit\ndef linear_xent_fwd_prep_bwd_kernel_matmul_t(\n x_ptr,\n y_ptr,\n A_t_ptr,\n z_nv_ptr,\n losses_ptr,\n lse_ptr,\n stride_x_N,\n stride_x_H,\n stride_A_H,\n stride_A_V,\n stride_z_N,\n stride_z_V,\n N_offset,\n V: tl.constexpr,\n N: tl.constexpr,\n H: tl.constexpr,\n V_BLOCK_SIZE: tl.constexpr,\n N_BLOCK_SIZE: tl.constexpr,\n H_BLOCK_SIZE: tl.constexpr,\n):\n idx = tl.program_id(axis=0)\n tl.static_print(V_BLOCK_SIZE, N_BLOCK_SIZE, H_BLOCK_SIZE)\n tl.static_assert(N % N_BLOCK_SIZE == 0)\n tl.static_assert(V % V_BLOCK_SIZE == 0)\n tl.static_assert(H % H_BLOCK_SIZE == 0)\n x_block_ptr = tl.make_block_ptr(\n base=x_ptr,\n shape=(N, H),\n strides=(stride_x_N, stride_x_H),\n offsets=(idx * N_BLOCK_SIZE, 0),\n block_shape=(N_BLOCK_SIZE, H_BLOCK_SIZE),\n order=(1, 0),\n )\n A_block_ptr = tl.make_block_ptr(\n base=A_t_ptr,\n shape=(H, V),\n strides=(stride_A_H, stride_A_V),\n offsets=(0, 0),\n block_shape=(H_BLOCK_SIZE, V_BLOCK_SIZE),\n order=(1, 0),\n )\n z_block_ptr = tl.make_block_ptr(\n base=z_nv_ptr,\n shape=(N, V),\n strides=(stride_z_N, stride_z_V),\n offsets=(idx * N_BLOCK_SIZE, 0),\n block_shape=(N_BLOCK_SIZE, V_BLOCK_SIZE),\n order=(1, 0),\n )\n\n offsets = N_offset + idx * N_BLOCK_SIZE + tl.arange(0, N_BLOCK_SIZE)\n v_range = tl.arange(0, V_BLOCK_SIZE)\n y = tl.load(y_ptr + offsets)\n\n m = tl.zeros((N_BLOCK_SIZE,), dtype=tl.float32) - float(10e5)\n s = tl.zeros((N_BLOCK_SIZE,), dtype=tl.float32)\n loss = 0.0\n\n for _ in range(V // V_BLOCK_SIZE):\n z_j_to_k = tl.zeros((N_BLOCK_SIZE, V_BLOCK_SIZE), dtype=tl.float32)\n for _ in range(H // H_BLOCK_SIZE):\n x_chunk = tl.load(x_block_ptr) # Nc x H\n A_v = tl.load(A_block_ptr) # Vc x H\n\n z_j_to_k = tl.dot(x_chunk, A_v, z_j_to_k) # (Nc x H) @ (H x Vc)\n\n x_block_ptr = tl.advance(x_block_ptr, [0, H_BLOCK_SIZE])\n A_block_ptr = tl.advance(A_block_ptr, [H_BLOCK_SIZE, 0])\n\n m_new = tl.maximum(m, tl.max(z_j_to_k, 1))\n\n s_update = tl.sum(tl.exp(z_j_to_k - m_new[:, None]), axis=1)\n s = s * tl.exp(m - m_new) + s_update\n\n mask = y[:, None] == v_range[None, :] # Nc x Vc\n loss -= tl.sum(tl.where(mask, z_j_to_k, float(0.0))) / N\n\n # save z for later\n tl.store(z_block_ptr, z_j_to_k.to(z_nv_ptr.type.element_ty))\n\n # Reset and advance pointers for next step\n m = m_new\n x_block_ptr = tl.advance(x_block_ptr, [0, -H])\n A_block_ptr = tl.advance(A_block_ptr, [-H, V_BLOCK_SIZE])\n z_block_ptr = tl.advance(z_block_ptr, [0, V_BLOCK_SIZE])\n v_range = v_range + V_BLOCK_SIZE\n\n lse = m + tl.log(s)\n loss += tl.sum(lse) / N\n tl.store(losses_ptr + idx, loss)\n tl.store(lse_ptr + offsets, lse)\n\n\n@triton.autotune(\n configs=[\n triton.Config({\"V_BLOCK_SIZE\": 256, \"N_BLOCK_SIZE\": 64}),\n triton.Config({\"V_BLOCK_SIZE\": 64, \"N_BLOCK_SIZE\": 16}),\n triton.Config({\"V_BLOCK_SIZE\": 64, \"N_BLOCK_SIZE\": 64}),\n triton.Config({\"V_BLOCK_SIZE\": 256, \"N_BLOCK_SIZE\": 16}),\n triton.Config({\"V_BLOCK_SIZE\": 512, \"N_BLOCK_SIZE\": 16}),\n triton.Config({\"V_BLOCK_SIZE\": 256, \"N_BLOCK_SIZE\": 64}),\n triton.Config({\"V_BLOCK_SIZE\": 256, \"N_BLOCK_SIZE\": 256}),\n triton.Config({\"V_BLOCK_SIZE\": 256, \"N_BLOCK_SIZE\": 256}),\n triton.Config({\"V_BLOCK_SIZE\": 256, \"N_BLOCK_SIZE\": 16}),\n ],\n key=[\"V\", \"N\"],\n reset_to_zero=[\"z_grad_ptr\"],\n)\n@triton.jit\ndef linear_xent_mini_bwd_prologue_kernel(\n z_nv_ptr,\n z_grad_ptr,\n y_ptr,\n lse_ptr,\n stride_z_N,\n stride_z_V,\n N_offset,\n V: tl.constexpr,\n N: tl.constexpr,\n V_BLOCK_SIZE: tl.constexpr,\n N_BLOCK_SIZE: tl.constexpr,\n):\n idx_N = tl.program_id(axis=0)\n idx_V = tl.program_id(axis=1)\n tl.static_print(V_BLOCK_SIZE, N_BLOCK_SIZE)\n tl.static_assert(N % N_BLOCK_SIZE == 0)\n tl.static_assert(V % V_BLOCK_SIZE == 0)\n z_block_ptr = tl.make_block_ptr(\n base=z_nv_ptr,\n shape=(N, V),\n strides=(stride_z_N, stride_z_V),\n offsets=(idx_N * N_BLOCK_SIZE, idx_V * V_BLOCK_SIZE),\n block_shape=(N_BLOCK_SIZE, V_BLOCK_SIZE),\n order=(1, 0),\n )\n z_grad_block_ptr = tl.make_block_ptr(\n base=z_grad_ptr,\n shape=(N, V),\n strides=(stride_z_N, stride_z_V),\n offsets=(idx_N * N_BLOCK_SIZE, idx_V * V_BLOCK_SIZE),\n block_shape=(N_BLOCK_SIZE, V_BLOCK_SIZE),\n order=(1, 0),\n )\n N_range = N_offset + idx_N * N_BLOCK_SIZE + tl.arange(0, N_BLOCK_SIZE)\n v_range = idx_V * V_BLOCK_SIZE + tl.arange(0, V_BLOCK_SIZE)\n y = tl.load(y_ptr + N_range)\n lse = tl.load(lse_ptr + N_range)\n z_j_to_k = tl.load(z_block_ptr)\n\n mask = y[:, None] == v_range[None, :]\n softmax_z = (z_j_to_k - lse[:, None]).exp()\n z_grad = (softmax_z - tl.where(mask, 1.0, 0.0)) / N\n\n tl.store(z_grad_block_ptr, z_grad.to(tl.float16))\n\n\n@triton.autotune(\n configs=[\n triton.Config({\"V_BLOCK_SIZE\": 16, \"N_BLOCK_SIZE\": 16, \"H_BLOCK_SIZE\": 32}),\n triton.Config({\"V_BLOCK_SIZE\": 16, \"N_BLOCK_SIZE\": 16, \"H_BLOCK_SIZE\": 16}),\n triton.Config({\"V_BLOCK_SIZE\": 16, \"N_BLOCK_SIZE\": 16, \"H_BLOCK_SIZE\": 64}),\n triton.Config({\"V_BLOCK_SIZE\": 32, \"N_BLOCK_SIZE\": 32, \"H_BLOCK_SIZE\": 32}),\n triton.Config({\"V_BLOCK_SIZE\": 32, \"N_BLOCK_SIZE\": 32, \"H_BLOCK_SIZE\": 64}),\n triton.Config({\"V_BLOCK_SIZE\": 32, \"N_BLOCK_SIZE\": 64, \"H_BLOCK_SIZE\": 64}),\n triton.Config({\"V_BLOCK_SIZE\": 64, \"N_BLOCK_SIZE\": 64, \"H_BLOCK_SIZE\": 64}),\n triton.Config({\"V_BLOCK_SIZE\": 16, \"N_BLOCK_SIZE\": 16, \"H_BLOCK_SIZE\": 128}),\n triton.Config({\"V_BLOCK_SIZE\": 16, \"N_BLOCK_SIZE\": 16, \"H_BLOCK_SIZE\": 256}),\n triton.Config({\"V_BLOCK_SIZE\": 16, \"N_BLOCK_SIZE\": 16, \"H_BLOCK_SIZE\": 512}),\n triton.Config({\"V_BLOCK_SIZE\": 64, \"N_BLOCK_SIZE\": 64, \"H_BLOCK_SIZE\": 16}),\n triton.Config({\"V_BLOCK_SIZE\": 32, \"N_BLOCK_SIZE\": 32, \"H_BLOCK_SIZE\": 16}),\n ],\n key=[\"V\", \"N\", \"H\"],\n reset_to_zero=[\"x_grad_ptr\"],\n)\n@triton.jit()\ndef linear_xent_bwd_kernel_matmul_t_epilogue_dx(\n z_nv_ptr,\n y_ptr,\n A_t_ptr,\n x_grad_ptr,\n lse_ptr,\n stride_x_N,\n stride_x_H,\n stride_A_H,\n stride_A_V,\n stride_z_N,\n stride_z_V,\n N_offset,\n V: tl.constexpr,\n N: tl.constexpr,\n H: tl.constexpr,\n V_BLOCK_SIZE: tl.constexpr = 16,\n N_BLOCK_SIZE: tl.constexpr = 16,\n H_BLOCK_SIZE: tl.constexpr = 16,\n):\n idx_N = tl.program_id(axis=0)\n idx_H = tl.program_id(axis=1)\n idx_V = 0\n\n tl.static_print(V_BLOCK_SIZE, N_BLOCK_SIZE, H_BLOCK_SIZE)\n tl.static_assert(N % N_BLOCK_SIZE == 0)\n tl.static_assert(V % V_BLOCK_SIZE == 0)\n\n x_grad_block_ptr = tl.make_block_ptr(\n base=x_grad_ptr,\n shape=(N, H),\n strides=(stride_x_N, stride_x_H),\n offsets=(N_offset + idx_N * N_BLOCK_SIZE, idx_H * H_BLOCK_SIZE),\n block_shape=(N_BLOCK_SIZE, H_BLOCK_SIZE),\n order=(1, 0),\n )\n A_t_block_ptr = tl.make_block_ptr(\n base=A_t_ptr,\n shape=(H, V),\n strides=(stride_A_H, stride_A_V),\n offsets=(idx_H * H_BLOCK_SIZE, 0),\n block_shape=(H_BLOCK_SIZE, V_BLOCK_SIZE),\n order=(1, 0),\n )\n\n z_block_ptr = tl.make_block_ptr(\n base=z_nv_ptr,\n shape=(N, V),\n strides=(stride_z_N, stride_z_V),\n offsets=(idx_N * N_BLOCK_SIZE, 0),\n block_shape=(N_BLOCK_SIZE, V_BLOCK_SIZE),\n order=(1, 0),\n )\n N_range = N_offset + idx_N * N_BLOCK_SIZE + tl.arange(0, N_BLOCK_SIZE)\n v_range = 0 + tl.arange(0, V_BLOCK_SIZE)\n\n y = tl.load(y_ptr + N_range)\n lse = tl.load(lse_ptr + N_range)\n\n x_grad_acc = tl.zeros((N_BLOCK_SIZE, H_BLOCK_SIZE), tl.float32)\n for idx_V in range(V // V_BLOCK_SIZE):\n mask = (y[:, None] == v_range[None, :])[:, :, None] # N_BLOCK_SIZE x V_BLOCK_SIZE x 1\n A_v = tl.load(A_t_block_ptr).trans() # Hc x Vc\n z_j_to_k = tl.load(z_block_ptr)\n softmax_z = (z_j_to_k - lse[:, None]).exp().to(tl.float16)\n\n # xgrad\n x_grad_acc = tl.dot(softmax_z, A_v, x_grad_acc)\n x_grad_acc -= tl.sum(tl.where(mask, A_v[None, :, :], 0.0), axis=1)\n\n A_t_block_ptr = tl.advance(A_t_block_ptr, [0, V_BLOCK_SIZE])\n z_block_ptr = tl.advance(z_block_ptr, [0, V_BLOCK_SIZE])\n v_range += V_BLOCK_SIZE\n\n tl.store(x_grad_block_ptr, (x_grad_acc / N).to(tl.float16))\n\n\nclass LinearCrossEntropyLoss(torch.autograd.Function):\n\n @staticmethod\n def forward(\n ctx,\n x,\n y,\n At,\n ignore_index=-100,\n N_chunk_size: int = 4096,\n ):\n N, H = x.shape\n H_A, V = At.shape\n assert H_A == H\n assert y.shape == (N,)\n if ignore_index >= 0:\n y[y == ignore_index] = -100\n At = At.contiguous()\n A_grad = torch.zeros_like(At.T)\n x_grad = torch.zeros_like(x)\n\n lse_global = torch.zeros(N, dtype=torch.float32, device=x.device)\n losses = torch.zeros(N // 16, dtype=torch.float32, device=x.device)\n\n fwd_grid = lambda meta: (triton.cdiv(N, meta[\"N_BLOCK_SIZE\"]),)\n bwd_grid_dx = lambda meta: (triton.cdiv(N, meta[\"N_BLOCK_SIZE\"]), triton.cdiv(H, meta[\"H_BLOCK_SIZE\"]))\n bwd_grid_dA = lambda meta: (triton.cdiv(N, meta[\"V_BLOCK_SIZE\"]), triton.cdiv(V, meta[\"H_BLOCK_SIZE\"]))\n\n for idx, x_n_chunk in enumerate(x.split(N_chunk_size)):\n x_input = x_n_chunk.contiguous()\n\n z_nv = torch.empty((N_chunk_size, V), device=x.device, dtype=torch.float32)\n\n with torch.cuda.device(x.device.index):\n linear_xent_fwd_prep_bwd_kernel_matmul_t[fwd_grid](\n x,\n y,\n At,\n z_nv,\n losses,\n lse_global,\n x.stride(0),\n x.stride(1),\n At.stride(0),\n At.stride(1),\n z_nv.stride(0),\n z_nv.stride(1),\n N_offset=idx * N_chunk_size,\n V=V,\n N=N_chunk_size,\n H=H,\n )\n if x.requires_grad:\n linear_xent_bwd_kernel_matmul_t_epilogue_dx[bwd_grid_dx](\n z_nv,\n y,\n At,\n x_grad,\n lse_global,\n x_grad.stride(0),\n x_grad.stride(1),\n At.stride(0),\n At.stride(1),\n z_nv.stride(0),\n z_nv.stride(1),\n idx * N_chunk_size,\n V,\n N_chunk_size,\n H,\n )\n\n if At.requires_grad:\n torch.addmm(\n A_grad,\n z_nv.T.half(),\n x_input,\n out=A_grad,\n )\n\n print(\"fwd config:\", linear_xent_fwd_prep_bwd_kernel_matmul_t.best_config)\n ctx.mark_non_differentiable(y)\n ctx.save_for_backward(x_grad, A_grad.T)\n\n return losses.sum()\n\n @staticmethod\n def backward(ctx, grad_output):\n x_grad, At_grad = ctx.saved_tensors\n\n return x_grad * grad_output, None, At_grad * grad_output, None\n\ndef linear_cross_entropy(x, y, At):\n return LinearCrossEntropyLoss.apply(x, y, At)\n", - "description_1": "Use triton language to implement three kernel functions, and an associated PyTorch Function for calculating cross-entropy loss. The first kernel (linear_xent_fwd_prep_bwd_kernel_matmul_t) computes forward pass with matrix multiplication and computes partial backward data for backward pass. The second kernel (linear_xent_mini_bwd_prologue_kernel) prepares gradient data for a mini-batch using softmax. The third kernel (linear_xent_bwd_kernel_matmul_t_epilogue_dx) computes the backward pass gradients for input data using matrix multiplication. The PyTorch Function handles batching and orchestrates the forward and backward passes, with caching for gradients. Each function takes a number of arguments that include pointers to data, strides, and block sizes as compile-time constants for optimization.", - "description_2": "Use triton language to create optimized kernels for cross-entropy loss computation and implement them in PyTorch for efficient forward and backward passes using matrix multiplication and softmax.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.autotune(\n configs=[\n triton.Config({\"V_BLOCK_SIZE\": 256, \"N_BLOCK_SIZE\": 64, \"H_BLOCK_SIZE\": 64}),\n ],\n key=[\"V\", \"N\", \"H\"],\n reset_to_zero=[\"losses_ptr\", \"lse_ptr\"],\n)\n@triton.jit\ndef linear_xent_fwd_prep_bwd_kernel_matmul_t(\n x_ptr,\n y_ptr,\n A_t_ptr,\n z_nv_ptr,\n losses_ptr,\n lse_ptr,\n stride_x_N,\n stride_x_H,\n stride_A_H,\n stride_A_V,\n stride_z_N,\n stride_z_V,\n idx_N_group,\n N_group: tl.constexpr,\n V: tl.constexpr,\n N: tl.constexpr,\n H: tl.constexpr,\n V_BLOCK_SIZE: tl.constexpr,\n N_BLOCK_SIZE: tl.constexpr,\n H_BLOCK_SIZE: tl.constexpr,\n):\n idx_N = tl.program_id(axis=0)\n tl.static_print(V_BLOCK_SIZE, N_BLOCK_SIZE, H_BLOCK_SIZE)\n tl.static_assert(N % N_BLOCK_SIZE == 0)\n tl.static_assert(V % V_BLOCK_SIZE == 0)\n tl.static_assert(H % H_BLOCK_SIZE == 0)\n x_block_ptr = tl.make_block_ptr(\n base=x_ptr,\n shape=(N, H),\n strides=(stride_x_N, stride_x_H),\n offsets=(idx_N_group * N_group + idx_N * N_BLOCK_SIZE, 0),\n block_shape=(N_BLOCK_SIZE, H_BLOCK_SIZE),\n order=(1, 0),\n )\n A_block_ptr = tl.make_block_ptr(\n base=A_t_ptr,\n shape=(H, V),\n strides=(stride_A_H, stride_A_V),\n offsets=(0, 0),\n block_shape=(H_BLOCK_SIZE, V_BLOCK_SIZE),\n order=(1, 0),\n )\n z_block_ptr = tl.make_block_ptr(\n base=z_nv_ptr,\n shape=(N_group, V),\n strides=(stride_z_N, stride_z_V),\n offsets=(idx_N * N_BLOCK_SIZE, 0),\n block_shape=(N_BLOCK_SIZE, V_BLOCK_SIZE),\n order=(1, 0),\n )\n\n N_range = idx_N_group * N_group + idx_N * N_BLOCK_SIZE + tl.arange(0, N_BLOCK_SIZE)\n V_range = tl.arange(0, V_BLOCK_SIZE)\n y = tl.load(y_ptr + N_range)\n\n m = tl.zeros((N_BLOCK_SIZE,), dtype=tl.float32) - float(10e5)\n s = tl.zeros((N_BLOCK_SIZE,), dtype=tl.float32)\n loss = 0.0\n\n for _ in range(V // V_BLOCK_SIZE):\n z_j_to_k = tl.zeros((N_BLOCK_SIZE, V_BLOCK_SIZE), dtype=tl.float32)\n for _ in range(H // H_BLOCK_SIZE):\n x_chunk = tl.load(x_block_ptr) # Nc x H\n A_v = tl.load(A_block_ptr) # Vc x H\n\n z_j_to_k = tl.dot(x_chunk, A_v, z_j_to_k) # (Nc x H) @ (H x Vc)\n\n x_block_ptr = tl.advance(x_block_ptr, [0, H_BLOCK_SIZE])\n A_block_ptr = tl.advance(A_block_ptr, [H_BLOCK_SIZE, 0])\n\n m_new = tl.maximum(m, tl.max(z_j_to_k, 1))\n\n s_update = tl.sum(tl.exp(z_j_to_k - m_new[:, None]), axis=1)\n s = s * tl.exp(m - m_new) + s_update\n\n mask = y[:, None] == V_range[None, :] # Nc x Vc\n loss -= tl.sum(tl.where(mask, z_j_to_k, float(0.0))) / N\n\n # save z for later\n tl.store(z_block_ptr, z_j_to_k.to(z_nv_ptr.type.element_ty))\n\n # Reset and advance pointers for next step\n m = m_new\n x_block_ptr = tl.advance(x_block_ptr, [0, -H])\n A_block_ptr = tl.advance(A_block_ptr, [-H, V_BLOCK_SIZE])\n z_block_ptr = tl.advance(z_block_ptr, [0, V_BLOCK_SIZE])\n V_range = V_range + V_BLOCK_SIZE\n\n lse = m + tl.log(s)\n loss += tl.sum(lse) / N\n tl.store(losses_ptr + idx_N + idx_N_group * N_group // N_BLOCK_SIZE, loss)\n tl.store(lse_ptr + N_range, lse)\n\n\n@triton.autotune(\n configs=[\n triton.Config({\"V_BLOCK_SIZE\": 16, \"N_BLOCK_SIZE\": 16, \"H_BLOCK_SIZE\": 32}),\n triton.Config({\"V_BLOCK_SIZE\": 16, \"N_BLOCK_SIZE\": 16, \"H_BLOCK_SIZE\": 16}),\n triton.Config({\"V_BLOCK_SIZE\": 16, \"N_BLOCK_SIZE\": 16, \"H_BLOCK_SIZE\": 64}),\n triton.Config({\"V_BLOCK_SIZE\": 32, \"N_BLOCK_SIZE\": 32, \"H_BLOCK_SIZE\": 32}),\n triton.Config({\"V_BLOCK_SIZE\": 32, \"N_BLOCK_SIZE\": 32, \"H_BLOCK_SIZE\": 64}),\n triton.Config({\"V_BLOCK_SIZE\": 32, \"N_BLOCK_SIZE\": 64, \"H_BLOCK_SIZE\": 64}),\n triton.Config({\"V_BLOCK_SIZE\": 64, \"N_BLOCK_SIZE\": 64, \"H_BLOCK_SIZE\": 64}),\n triton.Config({\"V_BLOCK_SIZE\": 16, \"N_BLOCK_SIZE\": 16, \"H_BLOCK_SIZE\": 128}),\n triton.Config({\"V_BLOCK_SIZE\": 16, \"N_BLOCK_SIZE\": 16, \"H_BLOCK_SIZE\": 256}),\n triton.Config({\"V_BLOCK_SIZE\": 16, \"N_BLOCK_SIZE\": 16, \"H_BLOCK_SIZE\": 512}),\n triton.Config({\"V_BLOCK_SIZE\": 64, \"N_BLOCK_SIZE\": 64, \"H_BLOCK_SIZE\": 16}),\n triton.Config({\"V_BLOCK_SIZE\": 32, \"N_BLOCK_SIZE\": 32, \"H_BLOCK_SIZE\": 16}),\n ],\n key=[\"V\", \"N\", \"H\"],\n reset_to_zero=[\"x_grad_ptr\"],\n)\n@triton.jit()\ndef linear_xent_bwd_kernel_matmul_t_epilogue_dx(\n z_nv_ptr,\n y_ptr,\n A_t_ptr,\n x_grad_ptr,\n lse_ptr,\n stride_x_N,\n stride_x_H,\n stride_A_H,\n stride_A_V,\n stride_z_N,\n stride_z_V,\n idx_N_group,\n N_group: tl.constexpr,\n V: tl.constexpr,\n N: tl.constexpr,\n H: tl.constexpr,\n V_BLOCK_SIZE: tl.constexpr = 16,\n N_BLOCK_SIZE: tl.constexpr = 16,\n H_BLOCK_SIZE: tl.constexpr = 16,\n):\n idx_N = tl.program_id(axis=0)\n idx_H = tl.program_id(axis=1)\n idx_V = 0\n\n tl.static_print(V_BLOCK_SIZE, N_BLOCK_SIZE, H_BLOCK_SIZE)\n tl.static_assert(N % N_BLOCK_SIZE == 0)\n tl.static_assert(V % V_BLOCK_SIZE == 0)\n\n x_grad_block_ptr = tl.make_block_ptr(\n base=x_grad_ptr,\n shape=(N, H),\n strides=(stride_x_N, stride_x_H),\n offsets=(idx_N_group * N_group + idx_N * N_BLOCK_SIZE, idx_H * H_BLOCK_SIZE),\n block_shape=(N_BLOCK_SIZE, H_BLOCK_SIZE),\n order=(1, 0),\n )\n A_t_block_ptr = tl.make_block_ptr(\n base=A_t_ptr,\n shape=(H, V),\n strides=(stride_A_H, stride_A_V),\n offsets=(idx_H * H_BLOCK_SIZE, 0),\n block_shape=(H_BLOCK_SIZE, V_BLOCK_SIZE),\n order=(1, 0),\n )\n\n z_block_ptr = tl.make_block_ptr(\n base=z_nv_ptr,\n shape=(N, V),\n strides=(stride_z_N, stride_z_V),\n offsets=(idx_N * N_BLOCK_SIZE, 0),\n block_shape=(N_BLOCK_SIZE, V_BLOCK_SIZE),\n order=(1, 0),\n )\n N_range = idx_N_group * N_group + idx_N * N_BLOCK_SIZE + tl.arange(0, N_BLOCK_SIZE)\n v_range = 0 + tl.arange(0, V_BLOCK_SIZE)\n\n y = tl.load(y_ptr + N_range)\n lse = tl.load(lse_ptr + N_range)\n\n x_grad_acc = tl.zeros((N_BLOCK_SIZE, H_BLOCK_SIZE), tl.float32)\n for idx_V in range(V // V_BLOCK_SIZE):\n mask = (y[:, None] == v_range[None, :])[:, :, None] # N_BLOCK_SIZE x V_BLOCK_SIZE x 1\n A_v = tl.load(A_t_block_ptr).trans() # Hc x Vc\n z_j_to_k = tl.load(z_block_ptr)\n softmax_z = (z_j_to_k - lse[:, None]).exp().to(tl.float16)\n\n # xgrad\n x_grad_acc = tl.dot(softmax_z, A_v, x_grad_acc)\n x_grad_acc -= tl.sum(tl.where(mask, A_v[None, :, :], 0.0), axis=1)\n\n A_t_block_ptr = tl.advance(A_t_block_ptr, [0, V_BLOCK_SIZE])\n z_block_ptr = tl.advance(z_block_ptr, [0, V_BLOCK_SIZE])\n v_range += V_BLOCK_SIZE\n\n tl.store(x_grad_block_ptr, (x_grad_acc / N).to(x_grad_ptr.type.element_ty))\n\n\n@triton.autotune(\n configs=[\n triton.Config({\"V_BLOCK_SIZE\": 16, \"N_BLOCK_SIZE\": 16, \"H_BLOCK_SIZE\": 32}),\n triton.Config({\"V_BLOCK_SIZE\": 16, \"N_BLOCK_SIZE\": 16, \"H_BLOCK_SIZE\": 16}),\n triton.Config({\"V_BLOCK_SIZE\": 16, \"N_BLOCK_SIZE\": 16, \"H_BLOCK_SIZE\": 64}),\n triton.Config({\"V_BLOCK_SIZE\": 32, \"N_BLOCK_SIZE\": 32, \"H_BLOCK_SIZE\": 32}),\n triton.Config({\"V_BLOCK_SIZE\": 32, \"N_BLOCK_SIZE\": 32, \"H_BLOCK_SIZE\": 64}),\n triton.Config({\"V_BLOCK_SIZE\": 32, \"N_BLOCK_SIZE\": 64, \"H_BLOCK_SIZE\": 64}),\n triton.Config({\"V_BLOCK_SIZE\": 64, \"N_BLOCK_SIZE\": 64, \"H_BLOCK_SIZE\": 64}),\n triton.Config({\"V_BLOCK_SIZE\": 16, \"N_BLOCK_SIZE\": 16, \"H_BLOCK_SIZE\": 128}),\n triton.Config({\"V_BLOCK_SIZE\": 16, \"N_BLOCK_SIZE\": 16, \"H_BLOCK_SIZE\": 256}),\n triton.Config({\"V_BLOCK_SIZE\": 16, \"N_BLOCK_SIZE\": 16, \"H_BLOCK_SIZE\": 512}),\n triton.Config({\"V_BLOCK_SIZE\": 64, \"N_BLOCK_SIZE\": 64, \"H_BLOCK_SIZE\": 16}),\n triton.Config({\"V_BLOCK_SIZE\": 32, \"N_BLOCK_SIZE\": 32, \"H_BLOCK_SIZE\": 16}),\n ],\n key=[\"V\", \"N\", \"H\"],\n reset_to_zero=[\"A_grad_ptr\"],\n)\n@triton.jit()\ndef linear_xent_bwd_kernel_matmul_t_epilogue_dA(\n z_nv_ptr,\n y_ptr,\n x_ptr,\n A_grad_ptr,\n lse_ptr,\n stride_x_N,\n stride_x_H,\n stride_A_H,\n stride_A_V,\n stride_z_N,\n stride_z_V,\n idx_N_group,\n N_group: tl.constexpr,\n V: tl.constexpr,\n N: tl.constexpr,\n H: tl.constexpr,\n V_BLOCK_SIZE: tl.constexpr = 16,\n N_BLOCK_SIZE: tl.constexpr = 16,\n H_BLOCK_SIZE: tl.constexpr = 16,\n):\n idx_V = tl.program_id(axis=0)\n idx_H = tl.program_id(axis=1)\n idx_N = 0\n\n tl.static_print(V_BLOCK_SIZE, N_BLOCK_SIZE, H_BLOCK_SIZE)\n tl.static_assert(N % N_BLOCK_SIZE == 0)\n tl.static_assert(V % V_BLOCK_SIZE == 0)\n x_block_ptr = tl.make_block_ptr(\n base=x_ptr,\n shape=(N, H),\n strides=(stride_x_N, stride_x_H),\n offsets=(idx_N_group * N_group, idx_H * H_BLOCK_SIZE),\n block_shape=(N_BLOCK_SIZE, H_BLOCK_SIZE),\n order=(1, 0),\n )\n A_grad_block_ptr = tl.make_block_ptr(\n base=A_grad_ptr,\n shape=(V, H),\n strides=(stride_A_V, stride_A_H),\n offsets=(idx_V * V_BLOCK_SIZE, idx_H * H_BLOCK_SIZE),\n block_shape=(V_BLOCK_SIZE, H_BLOCK_SIZE),\n order=(1, 0),\n )\n\n z_block_ptr = tl.make_block_ptr(\n base=z_nv_ptr,\n shape=(N_group, V),\n strides=(stride_z_N, stride_z_V),\n offsets=(0, idx_V * V_BLOCK_SIZE),\n block_shape=(N_BLOCK_SIZE, V_BLOCK_SIZE),\n order=(1, 0),\n )\n\n N_range = idx_N_group * N_group + tl.arange(0, N_BLOCK_SIZE)\n V_range = idx_V * V_BLOCK_SIZE + tl.arange(0, V_BLOCK_SIZE)\n\n A_grad_acc = tl.zeros((V_BLOCK_SIZE, H_BLOCK_SIZE), tl.float32)\n for idx_N in range(N // N_BLOCK_SIZE):\n y = tl.load(y_ptr + N_range)\n lse = tl.load(lse_ptr + N_range)\n mask = (y[:, None] == V_range[None, :])[:, :, None] # type: ignore\n\n x_chunk = tl.load(x_block_ptr)\n z_j_to_k = tl.load(z_block_ptr)\n softmax_z = (z_j_to_k - lse[:, None]).exp().to(tl.float16)\n\n A_grad_acc = tl.dot(softmax_z.trans(), x_chunk, A_grad_acc)\n A_grad_acc -= tl.sum(tl.where(mask, x_chunk[:, None, :], 0.0), axis=0)\n\n x_block_ptr = tl.advance(x_block_ptr, [N_BLOCK_SIZE, 0])\n z_block_ptr = tl.advance(z_block_ptr, [N_BLOCK_SIZE, 0])\n N_range += N_BLOCK_SIZE\n tl.store(A_grad_block_ptr, (A_grad_acc / N).to(A_grad_ptr.type.element_ty))\n\n\nclass LinearCrossEntropyLoss(torch.autograd.Function):\n\n @staticmethod\n def forward(\n ctx,\n x,\n y,\n At,\n ignore_index=-100, # code ignores all negative integers right now\n N_chunk_size: int = 4096, # N_chunk_size x V is the maximal memory peak\n ):\n N, H = x.shape\n H_A, V = At.shape\n assert H_A == H\n assert y.shape == (N,)\n # x = x.contiguous()\n # y = y.contiguous()\n if ignore_index >= 0:\n y[y == ignore_index] = -100\n At = At.contiguous()\n A_grad = torch.zeros_like(At.T)\n x_grad = torch.zeros_like(x)\n\n lse_global = torch.zeros(N, dtype=torch.float32, device=x.device)\n losses = torch.zeros(N // 16, dtype=torch.float32, device=x.device)\n z_nv = torch.empty((N_chunk_size, V), device=x.device, dtype=torch.float32)\n\n N_group = min(N, N_chunk_size)\n\n fwd_grid = lambda meta: (triton.cdiv(N_group, meta[\"N_BLOCK_SIZE\"]),)\n bwd_grid_dx = lambda meta: (\n triton.cdiv(N_group, meta[\"N_BLOCK_SIZE\"]),\n triton.cdiv(H, meta[\"H_BLOCK_SIZE\"]),\n )\n bwd_grid_dA = lambda meta: (\n triton.cdiv(V, meta[\"V_BLOCK_SIZE\"]),\n triton.cdiv(H, meta[\"H_BLOCK_SIZE\"]),\n )\n\n for idx_N_group, x_n_chunk in enumerate(x.split(N_group)):\n with torch.cuda.device(x.device.index): # actually required\n linear_xent_fwd_prep_bwd_kernel_matmul_t[fwd_grid](\n x,\n y,\n At,\n z_nv,\n losses,\n lse_global,\n x.stride(0),\n x.stride(1),\n At.stride(0),\n At.stride(1),\n z_nv.stride(0),\n z_nv.stride(1),\n idx_N_group=idx_N_group,\n N_group=N_group,\n V=V,\n N=N,\n H=H,\n )\n\n if x.requires_grad:\n linear_xent_bwd_kernel_matmul_t_epilogue_dx[bwd_grid_dx](\n z_nv,\n y,\n At,\n x_grad,\n lse_global,\n x_grad.stride(0),\n x_grad.stride(1),\n At.stride(0),\n At.stride(1),\n z_nv.stride(0),\n z_nv.stride(1),\n idx_N_group=idx_N_group,\n N_group=N_group,\n V=V,\n N=N,\n H=H,\n )\n\n if At.requires_grad:\n linear_xent_bwd_kernel_matmul_t_epilogue_dA[bwd_grid_dA](\n z_nv,\n y,\n x,\n A_grad,\n lse_global,\n x_grad.stride(0),\n x_grad.stride(1),\n At.stride(0),\n At.stride(1),\n z_nv.stride(0),\n z_nv.stride(1),\n idx_N_group=idx_N_group,\n N_group=N_group,\n V=V,\n N=N,\n H=H,\n )\n\n print(\"fwd config:\", linear_xent_fwd_prep_bwd_kernel_matmul_t.best_config)\n print(\"dx config:\", linear_xent_bwd_kernel_matmul_t_epilogue_dx.best_config)\n print(\"dA config:\", linear_xent_bwd_kernel_matmul_t_epilogue_dA.best_config)\n ctx.mark_non_differentiable(y)\n ctx.save_for_backward(x_grad, A_grad.T)\n # print(losses.max())\n return losses.sum()\n\n @staticmethod\n def backward(ctx, grad_output):\n x_grad, At_grad = ctx.saved_tensors\n\n return x_grad * grad_output, None, At_grad * grad_output, None\n\n\ndef linear_cross_entropy(x, y, At):\n return LinearCrossEntropyLoss.apply(x, y, At)\n", - "description_1": "Use triton language to implement a linear cross-entropy loss function with forward and backward passes. The forward kernel computes the loss and log-sum-exp values, while the backward kernels compute gradients with respect to input and weight matrices. The kernels are optimized using triton's autotune feature with various block size configurations.", - "description_2": "Use triton language to create a linear cross-entropy loss function with optimized forward and backward kernels, utilizing autotuning for performance.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.autotune(\n configs=[\n triton.Config({\"V_BLOCK_SIZE\": 64, \"N_BLOCK_SIZE\": 64, \"H_BLOCK_SIZE\": 64, \"V_TILES\": 1}),\n triton.Config({\"V_BLOCK_SIZE\": 128, \"N_BLOCK_SIZE\": 64, \"H_BLOCK_SIZE\": 64, \"V_TILES\": 1}),\n triton.Config({\"V_BLOCK_SIZE\": 256, \"N_BLOCK_SIZE\": 64, \"H_BLOCK_SIZE\": 64, \"V_TILES\": 1}),\n triton.Config({\"V_BLOCK_SIZE\": 512, \"N_BLOCK_SIZE\": 64, \"H_BLOCK_SIZE\": 64, \"V_TILES\": 1}),\n triton.Config({\"V_BLOCK_SIZE\": 64, \"N_BLOCK_SIZE\": 64, \"H_BLOCK_SIZE\": 64, \"V_TILES\": 1}),\n triton.Config({\"V_BLOCK_SIZE\": 256, \"N_BLOCK_SIZE\": 16, \"H_BLOCK_SIZE\": 256, \"V_TILES\": 1}),\n triton.Config({\"V_BLOCK_SIZE\": 512, \"N_BLOCK_SIZE\": 16, \"H_BLOCK_SIZE\": 16, \"V_TILES\": 1}),\n triton.Config({\"V_BLOCK_SIZE\": 256, \"N_BLOCK_SIZE\": 64, \"H_BLOCK_SIZE\": 128, \"V_TILES\": 1}),\n triton.Config({\"V_BLOCK_SIZE\": 64, \"N_BLOCK_SIZE\": 64, \"H_BLOCK_SIZE\": 64, \"V_TILES\": 1}, num_warps=4),\n triton.Config({\"V_BLOCK_SIZE\": 128, \"N_BLOCK_SIZE\": 64, \"H_BLOCK_SIZE\": 64, \"V_TILES\": 1}, num_warps=4),\n triton.Config({\"V_BLOCK_SIZE\": 256, \"N_BLOCK_SIZE\": 64, \"H_BLOCK_SIZE\": 64, \"V_TILES\": 1}, num_warps=4),\n triton.Config({\"V_BLOCK_SIZE\": 512, \"N_BLOCK_SIZE\": 64, \"H_BLOCK_SIZE\": 64, \"V_TILES\": 1}, num_warps=4),\n triton.Config({\"V_BLOCK_SIZE\": 64, \"N_BLOCK_SIZE\": 64, \"H_BLOCK_SIZE\": 64, \"V_TILES\": 1}, num_warps=4),\n triton.Config({\"V_BLOCK_SIZE\": 256, \"N_BLOCK_SIZE\": 16, \"H_BLOCK_SIZE\": 256, \"V_TILES\": 1}, num_warps=4),\n triton.Config({\"V_BLOCK_SIZE\": 512, \"N_BLOCK_SIZE\": 16, \"H_BLOCK_SIZE\": 16, \"V_TILES\": 1}, num_warps=4),\n triton.Config({\"V_BLOCK_SIZE\": 256, \"N_BLOCK_SIZE\": 64, \"H_BLOCK_SIZE\": 128, \"V_TILES\": 1}, num_warps=4),\n triton.Config({\"V_BLOCK_SIZE\": 64, \"N_BLOCK_SIZE\": 64, \"H_BLOCK_SIZE\": 64, \"V_TILES\": 1}, num_warps=8),\n triton.Config({\"V_BLOCK_SIZE\": 128, \"N_BLOCK_SIZE\": 64, \"H_BLOCK_SIZE\": 64, \"V_TILES\": 1}, num_warps=8),\n triton.Config({\"V_BLOCK_SIZE\": 256, \"N_BLOCK_SIZE\": 64, \"H_BLOCK_SIZE\": 64, \"V_TILES\": 1}, num_warps=8),\n triton.Config({\"V_BLOCK_SIZE\": 64, \"N_BLOCK_SIZE\": 64, \"H_BLOCK_SIZE\": 64, \"V_TILES\": 1}, num_warps=8),\n triton.Config({\"V_BLOCK_SIZE\": 256, \"N_BLOCK_SIZE\": 16, \"H_BLOCK_SIZE\": 256, \"V_TILES\": 1}, num_warps=8),\n triton.Config({\"V_BLOCK_SIZE\": 512, \"N_BLOCK_SIZE\": 16, \"H_BLOCK_SIZE\": 16, \"V_TILES\": 1}, num_warps=8),\n triton.Config({\"V_BLOCK_SIZE\": 256, \"N_BLOCK_SIZE\": 64, \"H_BLOCK_SIZE\": 128, \"V_TILES\": 1}, num_warps=8),\n triton.Config({\"V_BLOCK_SIZE\": 512, \"N_BLOCK_SIZE\": 64, \"H_BLOCK_SIZE\": 64, \"V_TILES\": 1}, num_warps=8),\n ],\n key=[\"V\", \"N\", \"H\"],\n reset_to_zero=[\"losses_ptr\", \"lse_ptr\", \"z_nv_ptr\"],\n)\n@triton.jit\ndef linear_xent_fwd_prep_bwd_kernel_matmul_t(\n x_ptr,\n y_ptr,\n A_t_ptr,\n z_nv_ptr,\n losses_ptr,\n lse_ptr,\n stride_x_N,\n stride_x_H,\n stride_A_H,\n stride_A_V,\n stride_z_N,\n stride_z_V,\n stride_lse_N,\n stride_lse_B,\n stride_loss_Nb,\n stride_loss_B,\n idx_N_group,\n N_group: tl.constexpr,\n V: tl.constexpr,\n N: tl.constexpr,\n H: tl.constexpr,\n V_BLOCK_SIZE: tl.constexpr,\n N_BLOCK_SIZE: tl.constexpr,\n H_BLOCK_SIZE: tl.constexpr,\n V_TILES: tl.constexpr = 4,\n):\n idx_N = tl.program_id(axis=0)\n idx_V_group = tl.program_id(axis=1)\n V_GROUP_SIZE: tl.constexpr = V_TILES * V_BLOCK_SIZE\n x_block_ptr = tl.make_block_ptr(\n base=x_ptr,\n shape=(N, H),\n strides=(stride_x_N, stride_x_H),\n offsets=(idx_N_group * N_group + idx_N * N_BLOCK_SIZE, 0),\n block_shape=(N_BLOCK_SIZE, H_BLOCK_SIZE),\n order=(1, 0),\n )\n A_block_ptr = tl.make_block_ptr(\n base=A_t_ptr,\n shape=(H, V),\n strides=(stride_A_H, stride_A_V),\n offsets=(0, idx_V_group * V_GROUP_SIZE),\n block_shape=(H_BLOCK_SIZE, V_BLOCK_SIZE),\n order=(1, 0),\n )\n z_block_ptr = tl.make_block_ptr(\n base=z_nv_ptr,\n shape=(N_group, V),\n strides=(stride_z_N, stride_z_V),\n offsets=(idx_N * N_BLOCK_SIZE, idx_V_group * V_GROUP_SIZE),\n block_shape=(N_BLOCK_SIZE, V_BLOCK_SIZE),\n order=(1, 0),\n )\n lse_row_ptr = tl.make_block_ptr(\n base=lse_ptr,\n shape=(N, V // 16),\n strides=(stride_lse_N, stride_lse_B),\n offsets=(idx_N_group * N_group + idx_N * N_BLOCK_SIZE, idx_V_group),\n block_shape=(N_BLOCK_SIZE, 1),\n order=(1, 0),\n )\n loss_val_ptr = (\n losses_ptr + (idx_N + idx_N_group * N_group // N_BLOCK_SIZE) * stride_loss_Nb + idx_V_group * stride_loss_B\n )\n\n N_range = idx_N_group * N_group + idx_N * N_BLOCK_SIZE + tl.arange(0, N_BLOCK_SIZE)\n V_range = idx_V_group * V_GROUP_SIZE + tl.arange(0, V_BLOCK_SIZE)\n y = tl.load(y_ptr + N_range)\n\n m = tl.zeros((N_BLOCK_SIZE,), dtype=tl.float32) - float(10e6)\n s = tl.zeros((N_BLOCK_SIZE,), dtype=tl.float32)\n loss = 0.0\n\n for _ in range(V_TILES):\n z_j_to_k = tl.zeros((N_BLOCK_SIZE, V_BLOCK_SIZE), dtype=tl.float32)\n for _ in range(H // H_BLOCK_SIZE):\n x_chunk = tl.load(x_block_ptr)\n A_v = tl.load(A_block_ptr)\n z_j_to_k = tl.dot(x_chunk, A_v, z_j_to_k)\n\n x_block_ptr = tl.advance(x_block_ptr, [0, H_BLOCK_SIZE])\n A_block_ptr = tl.advance(A_block_ptr, [H_BLOCK_SIZE, 0])\n\n m_new = tl.maximum(m, tl.max(z_j_to_k, 1))\n s_update = tl.sum(tl.exp((z_j_to_k - m_new[:, None])), axis=1)\n s = s * tl.exp(m - m_new) + s_update\n mask = y[:, None] == V_range[None, :]\n loss -= tl.sum(tl.where(mask, z_j_to_k, float(0.0))) / N\n tl.store(z_block_ptr, z_j_to_k.to(z_nv_ptr.type.element_ty))\n\n m = m_new\n x_block_ptr = tl.advance(x_block_ptr, [0, -H])\n A_block_ptr = tl.advance(A_block_ptr, [-H, V_BLOCK_SIZE])\n z_block_ptr = tl.advance(z_block_ptr, [0, V_BLOCK_SIZE])\n V_range = V_range + V_BLOCK_SIZE\n\n lse = m + tl.log(s)\n tl.store(loss_val_ptr, loss)\n tl.store(lse_row_ptr, lse[:, None])\n\n@triton.autotune(\n configs=[\n triton.Config({\"V_BLOCK_SIZE\": 64, \"N_BLOCK_SIZE\": 16}),\n triton.Config({\"V_BLOCK_SIZE\": 256, \"N_BLOCK_SIZE\": 16}),\n triton.Config({\"V_BLOCK_SIZE\": 512, \"N_BLOCK_SIZE\": 16}),\n triton.Config({\"V_BLOCK_SIZE\": 1024, \"N_BLOCK_SIZE\": 16}),\n triton.Config({\"V_BLOCK_SIZE\": 64, \"N_BLOCK_SIZE\": 64}),\n triton.Config({\"V_BLOCK_SIZE\": 256, \"N_BLOCK_SIZE\": 64}),\n triton.Config({\"V_BLOCK_SIZE\": 256, \"N_BLOCK_SIZE\": 128}),\n triton.Config({\"V_BLOCK_SIZE\": 64, \"N_BLOCK_SIZE\": 16}, num_warps=8),\n triton.Config({\"V_BLOCK_SIZE\": 256, \"N_BLOCK_SIZE\": 16}, num_warps=8),\n triton.Config({\"V_BLOCK_SIZE\": 512, \"N_BLOCK_SIZE\": 16}, num_warps=8),\n triton.Config({\"V_BLOCK_SIZE\": 1024, \"N_BLOCK_SIZE\": 16}, num_warps=8),\n triton.Config({\"V_BLOCK_SIZE\": 64, \"N_BLOCK_SIZE\": 64}, num_warps=8),\n triton.Config({\"V_BLOCK_SIZE\": 256, \"N_BLOCK_SIZE\": 64}, num_warps=8),\n triton.Config({\"V_BLOCK_SIZE\": 256, \"N_BLOCK_SIZE\": 128}, num_warps=8),\n triton.Config({\"V_BLOCK_SIZE\": 64, \"N_BLOCK_SIZE\": 16}, num_warps=16),\n triton.Config({\"V_BLOCK_SIZE\": 256, \"N_BLOCK_SIZE\": 16}, num_warps=16),\n triton.Config({\"V_BLOCK_SIZE\": 512, \"N_BLOCK_SIZE\": 16}, num_warps=16),\n triton.Config({\"V_BLOCK_SIZE\": 1024, \"N_BLOCK_SIZE\": 16}, num_warps=16),\n triton.Config({\"V_BLOCK_SIZE\": 64, \"N_BLOCK_SIZE\": 64}, num_warps=16),\n triton.Config({\"V_BLOCK_SIZE\": 256, \"N_BLOCK_SIZE\": 64}, num_warps=16),\n triton.Config({\"V_BLOCK_SIZE\": 256, \"N_BLOCK_SIZE\": 128}, num_warps=16),\n ],\n key=[\"V\", \"N\"],\n restore_value=[\"z_nv_ptr\"],\n)\n@triton.jit\ndef linear_xent_mini_bwd_prologue_kernel(\n z_nv_ptr,\n y_ptr,\n lse_ptr,\n stride_z_N,\n stride_z_V,\n idx_N_group,\n N_group: tl.constexpr,\n V: tl.constexpr,\n N: tl.constexpr,\n V_BLOCK_SIZE: tl.constexpr,\n N_BLOCK_SIZE: tl.constexpr,\n):\n idx_N = tl.program_id(axis=0)\n idx_V = tl.program_id(axis=1)\n z_block_ptr = tl.make_block_ptr(\n base=z_nv_ptr,\n shape=(N_group, V),\n strides=(stride_z_N, stride_z_V),\n offsets=(idx_N * N_BLOCK_SIZE, idx_V * V_BLOCK_SIZE),\n block_shape=(N_BLOCK_SIZE, V_BLOCK_SIZE),\n order=(1, 0),\n )\n N_range = idx_N_group * N_group + idx_N * N_BLOCK_SIZE + tl.arange(0, N_BLOCK_SIZE)\n v_range = idx_V * V_BLOCK_SIZE + tl.arange(0, V_BLOCK_SIZE)\n y = tl.load(y_ptr + N_range)\n lse = tl.load(lse_ptr + N_range)\n z_j_to_k = tl.load(z_block_ptr)\n\n mask = y[:, None] == v_range[None, :]\n softmax_z = (z_j_to_k - lse[:, None]).exp()\n z_grad = (softmax_z - tl.where(mask, 1.0, 0.0)) / N\n\n tl.store(z_block_ptr, z_grad.to(z_nv_ptr.type.element_ty))\n\nclass LinearCrossEntropyLoss(torch.autograd.Function):\n\n @staticmethod\n def forward(\n ctx,\n x,\n y,\n At,\n ignore_index=-100,\n N_chunk_size: int = 4096,\n ):\n N, H = x.shape\n H_A, V = At.shape\n assert H_A == H\n assert y.shape == (N,)\n if ignore_index >= 0:\n y[y == ignore_index] = -100\n At = At.contiguous()\n A_grad = torch.zeros_like(At.T)\n x_grad = torch.zeros_like(x)\n\n N_group = min(N, N_chunk_size)\n\n lse_local = torch.zeros(N, V // 16, dtype=torch.float32, device=x.device)\n losses = torch.zeros(N // 16, V // 16, dtype=torch.float32, device=x.device)\n z_nv_and_grad = torch.empty((N_group, V), device=x.device, dtype=torch.float32)\n\n fwd_grid = lambda meta: (\n triton.cdiv(N_group, meta[\"N_BLOCK_SIZE\"]),\n triton.cdiv(V, meta[\"V_BLOCK_SIZE\"] * meta[\"V_TILES\"]),\n )\n prologue_grid = lambda meta: (\n triton.cdiv(N_group, meta[\"N_BLOCK_SIZE\"]),\n triton.cdiv(V, meta[\"V_BLOCK_SIZE\"]),\n )\n\n for idx_N_group, x_n_chunk in enumerate(x.split(N_group)):\n with torch.cuda.device(x.device.index):\n linear_xent_fwd_prep_bwd_kernel_matmul_t[fwd_grid](\n x,\n y,\n At,\n z_nv_and_grad,\n losses,\n lse_local,\n x.stride(0),\n x.stride(1),\n At.stride(0),\n At.stride(1),\n z_nv_and_grad.stride(0),\n z_nv_and_grad.stride(1),\n lse_local.stride(0),\n lse_local.stride(1),\n losses.stride(0),\n losses.stride(1),\n idx_N_group=idx_N_group,\n N_group=N_group,\n V=V,\n N=N,\n H=H,\n )\n chosen_tiles = linear_xent_fwd_prep_bwd_kernel_matmul_t.best_config.kwargs[\"V_TILES\"]\n chosen_block = linear_xent_fwd_prep_bwd_kernel_matmul_t.best_config.kwargs[\"V_BLOCK_SIZE\"]\n buffer_extent = V // chosen_block // chosen_tiles\n lse_global = lse_local[:, :buffer_extent].logsumexp(dim=1)\n if x.requires_grad or At.requires_grad:\n linear_xent_mini_bwd_prologue_kernel[prologue_grid](\n z_nv_and_grad,\n y,\n lse_global,\n z_nv_and_grad.stride(0),\n z_nv_and_grad.stride(1),\n idx_N_group=idx_N_group,\n N_group=N_group,\n V=V,\n N=N,\n )\n z_grad = z_nv_and_grad.to(x.dtype)\n\n if x.requires_grad:\n x_grad[N_group * idx_N_group : x_n_chunk.shape[0] * (idx_N_group + 1)] = z_grad @ At.T\n\n if At.requires_grad:\n torch.addmm(\n A_grad,\n z_grad.T,\n x_n_chunk,\n out=A_grad,\n )\n\n ctx.mark_non_differentiable(y)\n ctx.save_for_backward(x_grad, A_grad.T.to(At.dtype))\n return losses.sum() + lse_global.sum() / N\n\n @staticmethod\n def backward(ctx, grad_output):\n x_grad, At_grad = ctx.saved_tensors\n\n return x_grad * grad_output, None, At_grad * grad_output, None, None\n\ndef linear_cross_entropy(x, y, At, ignore_index=-100, N_chunk_size=4096):\n return LinearCrossEntropyLoss.apply(x, y, At, ignore_index, N_chunk_size)\n\n", - "description_1": "Use triton language to implement a cross-entropy loss with linear transformation, employing forward and backward kernel functions. The forward kernel (`linear_xent_fwd_prep_bwd_kernel_matmul_t`) computes dot products, manages data pointers, and stores results for losses and local softmax exponentiation. It requires 29 parameters: three pointers for input data (`x_ptr`, `y_ptr`, `A_t_ptr`), three pointers for output data (`z_nv_ptr`, `losses_ptr`, `lse_ptr`), nine stride parameters for data access (`stride_x_N`, `stride_x_H`, `stride_A_H`, `stride_A_V`, `stride_z_N`, `stride_z_V`, `stride_lse_N`, `stride_lse_B`, `stride_loss_Nb`, `stride_loss_B`), an index parameter (`idx_N_group`), and several configuration constants (`N_group`, `V`, `N`, `H`, `V_BLOCK_SIZE`, `N_BLOCK_SIZE`, `H_BLOCK_SIZE`, `V_TILES`). The backward kernel (`linear_xent_mini_bwd_prologue_kernel`) computes gradients for softmax and requires 11 parameters: three pointers for data (`z_nv_ptr`, `y_ptr`, `lse_ptr`), two stride parameters for data access (`stride_z_N`, `stride_z_V`), an index parameter (`idx_N_group`), and configuration constants (`N_group`, `V`, `N`, `V_BLOCK_SIZE`, `N_BLOCK_SIZE`).", - "description_2": "Use triton language to create kernels that handle forward and backward computation of linear transformation with cross-entropy loss. The forward kernel processes dot products and stores intermediate results for both loss and local exponentiation. The backward kernel computes gradients necessary for updating weights.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef linear_xent_fwd_prep_bwd_kernel_matmul_t(\n x_ptr,\n y_ptr,\n A_t_ptr,\n z_nv_ptr,\n losses_ptr,\n lse_ptr,\n stride_x_N,\n stride_x_H,\n stride_A_H,\n stride_A_V,\n stride_z_N,\n stride_z_V,\n stride_lse_N,\n stride_lse_B,\n stride_loss_Nb,\n stride_loss_B,\n idx_N_group,\n N_group: tl.constexpr,\n V: tl.constexpr,\n N: tl.constexpr,\n H: tl.constexpr,\n V_BLOCK_SIZE: tl.constexpr,\n N_BLOCK_SIZE: tl.constexpr,\n H_BLOCK_SIZE: tl.constexpr,\n V_TILES: tl.constexpr = 1,\n):\n idx_N = tl.program_id(axis=0)\n idx_V_group = tl.program_id(axis=1)\n num_idx_N, num_idx_V_group = tl.num_programs(0), tl.num_programs(1)\n\n V_GROUP_SIZE: tl.constexpr = V_TILES * V_BLOCK_SIZE\n\n x_block_ptr = tl.make_block_ptr(\n base=x_ptr,\n shape=(N, H),\n strides=(stride_x_N, stride_x_H),\n offsets=(idx_N_group * N_group + idx_N * N_BLOCK_SIZE, 0),\n block_shape=(N_BLOCK_SIZE, H_BLOCK_SIZE),\n order=(1, 0),\n )\n A_block_ptr = tl.make_block_ptr(\n base=A_t_ptr,\n shape=(H, V),\n strides=(stride_A_H, stride_A_V),\n offsets=(0, idx_V_group * V_GROUP_SIZE),\n block_shape=(H_BLOCK_SIZE, V_BLOCK_SIZE),\n order=(1, 0),\n )\n z_block_ptr = tl.make_block_ptr(\n base=z_nv_ptr,\n shape=(N_group, V),\n strides=(stride_z_N, stride_z_V),\n offsets=(idx_N * N_BLOCK_SIZE, idx_V_group * V_GROUP_SIZE),\n block_shape=(N_BLOCK_SIZE, V_BLOCK_SIZE),\n order=(1, 0),\n )\n lse_row_ptr = tl.make_block_ptr(\n base=lse_ptr,\n shape=(N_group, V // 16),\n strides=(stride_lse_N, stride_lse_B),\n offsets=(idx_N * N_BLOCK_SIZE, idx_V_group),\n block_shape=(N_BLOCK_SIZE, 1),\n order=(1, 0),\n )\n loss_val_ptr = losses_ptr + idx_N * stride_loss_Nb + idx_V_group * stride_loss_B\n\n N_range = idx_N_group * N_group + idx_N * N_BLOCK_SIZE + tl.arange(0, N_BLOCK_SIZE)\n V_range = idx_V_group * V_GROUP_SIZE + tl.arange(0, V_BLOCK_SIZE)\n y = tl.load(y_ptr + N_range)\n\n m = tl.zeros((N_BLOCK_SIZE,), dtype=tl.float32) - float(10e6)\n s = tl.zeros((N_BLOCK_SIZE,), dtype=tl.float32)\n loss = 0.0\n\n for _ in range(V_TILES):\n z_j_to_k = tl.zeros((N_BLOCK_SIZE, V_BLOCK_SIZE), dtype=tl.float32)\n for _ in range(H // H_BLOCK_SIZE):\n x_chunk = tl.load(x_block_ptr)\n A_v = tl.load(A_block_ptr)\n\n z_j_to_k = tl.dot(x_chunk, A_v, z_j_to_k)\n\n x_block_ptr = tl.advance(x_block_ptr, [0, H_BLOCK_SIZE])\n A_block_ptr = tl.advance(A_block_ptr, [H_BLOCK_SIZE, 0])\n\n m_new = tl.maximum(m, tl.max(z_j_to_k, 1))\n\n s_update = tl.sum(tl.exp((z_j_to_k - m_new[:, None])), axis=1)\n s = s * tl.exp(m - m_new) + s_update\n\n mask = y[:, None] == V_range[None, :]\n loss -= tl.sum(tl.where(mask, z_j_to_k, float(0.0))) / N\n\n tl.store(z_block_ptr, z_j_to_k.to(z_nv_ptr.type.element_ty))\n\n m = m_new\n x_block_ptr = tl.advance(x_block_ptr, [0, -H])\n A_block_ptr = tl.advance(A_block_ptr, [-H, V_BLOCK_SIZE])\n z_block_ptr = tl.advance(z_block_ptr, [0, V_BLOCK_SIZE])\n V_range = V_range + V_BLOCK_SIZE\n\n lse = m + tl.log(s)\n tl.store(loss_val_ptr, loss)\n tl.store(lse_row_ptr, lse[:, None])\n\n@triton.jit\ndef linear_xent_mini_bwd_prologue_kernel(\n z_nv_ptr,\n y_ptr,\n lse_ptr,\n stride_z_N,\n stride_z_V,\n idx_N_group,\n N_group: tl.constexpr,\n V: tl.constexpr,\n N: tl.constexpr,\n V_BLOCK_SIZE: tl.constexpr,\n N_BLOCK_SIZE: tl.constexpr,\n):\n idx_N = tl.program_id(axis=0)\n idx_V = tl.program_id(axis=1)\n\n z_block_ptr = tl.make_block_ptr(\n base=z_nv_ptr,\n shape=(N_group, V),\n strides=(stride_z_N, stride_z_V),\n offsets=(idx_N * N_BLOCK_SIZE, idx_V * V_BLOCK_SIZE),\n block_shape=(N_BLOCK_SIZE, V_BLOCK_SIZE),\n order=(1, 0),\n )\n N_range = idx_N_group * N_group + idx_N * N_BLOCK_SIZE + tl.arange(0, N_BLOCK_SIZE)\n v_range = idx_V * V_BLOCK_SIZE + tl.arange(0, V_BLOCK_SIZE)\n y = tl.load(y_ptr + N_range)\n lse = tl.load(lse_ptr + idx_N * N_BLOCK_SIZE + tl.arange(0, N_BLOCK_SIZE))\n z_j_to_k = tl.load(z_block_ptr)\n\n mask = y[:, None] == v_range[None, :]\n softmax_z = (z_j_to_k - lse[:, None]).exp()\n z_grad = (softmax_z - tl.where(mask, 1.0, 0.0)) / N\n\n tl.store(z_block_ptr, z_grad.to(z_nv_ptr.type.element_ty))\n\nclass LinearCrossEntropyLoss(torch.autograd.Function):\n\n @staticmethod\n def forward(\n ctx,\n x,\n y,\n At,\n ignore_index=-100,\n N_chunk_size: int = 4096,\n ):\n N, H = x.shape\n H_A, V = At.shape\n assert H_A == H\n assert y.shape == (N,)\n if ignore_index >= 0:\n y[y == ignore_index] = -100\n At = At.contiguous()\n A_grad = torch.zeros_like(At.T)\n x_grad = torch.zeros_like(x)\n\n N_group = min(N, N_chunk_size)\n\n loss = torch.zeros(1, dtype=torch.float32, device=x.device)\n z_nv_and_grad = torch.empty((N_group, V), device=x.device, dtype=torch.float32)\n lse_local = -10e5 * torch.ones(N_group, V // 16, dtype=torch.float32, device=x.device)\n losses = torch.zeros(N_group // 16, V // 16, dtype=torch.float32, device=x.device)\n\n fwd_grid = lambda meta: (\n triton.cdiv(N_group, meta[\"N_BLOCK_SIZE\"]),\n triton.cdiv(V, meta[\"V_BLOCK_SIZE\"] * meta[\"V_TILES\"]),\n )\n prologue_grid = lambda meta: (\n triton.cdiv(N_group, meta[\"N_BLOCK_SIZE\"]),\n triton.cdiv(V, meta[\"V_BLOCK_SIZE\"]),\n )\n\n for idx_N_group, x_n_chunk in enumerate(x.split(N_group)):\n with torch.cuda.device(x.device.index):\n linear_xent_fwd_prep_bwd_kernel_matmul_t[fwd_grid](\n x,\n y,\n At,\n z_nv_and_grad,\n losses,\n lse_local,\n x.stride(0),\n x.stride(1),\n At.stride(0),\n At.stride(1),\n z_nv_and_grad.stride(0),\n z_nv_and_grad.stride(1),\n lse_local.stride(0),\n lse_local.stride(1),\n losses.stride(0),\n losses.stride(1),\n idx_N_group=idx_N_group,\n N_group=N_group,\n V=V,\n N=N,\n H=H,\n )\n chosen_tiles = linear_xent_fwd_prep_bwd_kernel_matmul_t.best_config.kwargs[\"V_TILES\"]\n chosen_block = linear_xent_fwd_prep_bwd_kernel_matmul_t.best_config.kwargs[\"V_BLOCK_SIZE\"]\n buffer_extent = V // chosen_block // chosen_tiles\n lse_global = lse_local[:, :buffer_extent].logsumexp(dim=1)\n if x.requires_grad or At.requires_grad:\n linear_xent_mini_bwd_prologue_kernel[prologue_grid](\n z_nv_and_grad,\n y,\n lse_global,\n z_nv_and_grad.stride(0),\n z_nv_and_grad.stride(1),\n idx_N_group=idx_N_group,\n N_group=N_group,\n V=V,\n N=N,\n )\n z_grad = z_nv_and_grad.to(x.dtype)\n\n if x.requires_grad:\n x_grad[N_group * idx_N_group : x_n_chunk.shape[0] * (idx_N_group + 1)] = z_grad @ At.T\n\n if At.requires_grad:\n torch.addmm(\n A_grad,\n z_grad.T,\n x_n_chunk,\n out=A_grad,\n )\n\n ctx.mark_non_differentiable(y)\n ctx.save_for_backward(x_grad, A_grad.T.to(At.dtype))\n return loss\n\n @staticmethod\n def backward(ctx, grad_output):\n x_grad, At_grad = ctx.saved_tensors\n\n return x_grad * grad_output, None, At_grad * grad_output, None, None\n\ndef linear_cross_entropy(x, y, At, ignore_index=-100, N_chunk_size: int = 4096):\n return LinearCrossEntropyLoss.apply(x, y, At, ignore_index, N_chunk_size)\n", - "description_1": "Use triton language to implement a linear cross-entropy loss function with two kernels: one for forward and backward preparation (linear_xent_fwd_prep_bwd_kernel_matmul_t) and another for backward prologue (linear_xent_mini_bwd_prologue_kernel). The forward function takes 5 inputs: x (input tensor), y (target tensor), At (transposed weight matrix), ignore_index (index to ignore), and N_chunk_size (chunk size for processing). The backward function computes gradients for x and At.", - "description_2": "Use triton language to create a linear cross-entropy loss function with forward and backward kernels, processing inputs x, y, and At, and computing gradients.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef linear_xent_fwd_prep_bwd_kernel_matmul_t(\n x_ptr,\n y_ptr,\n A_t_ptr,\n z_nv_ptr,\n losses_ptr,\n lse_ptr,\n stride_x_N,\n stride_x_H,\n stride_A_H,\n stride_A_V,\n stride_z_N,\n stride_z_V,\n stride_lse_N,\n stride_lse_B,\n stride_loss_Nb,\n stride_loss_B,\n idx_N_group,\n N_group: tl.constexpr,\n V: tl.constexpr,\n N: tl.constexpr,\n H: tl.constexpr,\n V_BLOCK_SIZE: tl.constexpr = 16,\n N_BLOCK_SIZE: tl.constexpr = 16,\n H_BLOCK_SIZE: tl.constexpr = 16,\n V_TILES: tl.constexpr = 1,\n GROUP_SIZE: tl.constexpr = 1,\n):\n # Kernel logic for forward pass and preparation for backward pass\n\n@triton.jit()\ndef linear_xent_bwd_kernel_matmul_t_epilogue_dx(\n z_nv_ptr,\n y_ptr,\n A_t_ptr,\n x_grad_ptr,\n lse_ptr,\n stride_x_N,\n stride_x_H,\n stride_A_H,\n stride_A_V,\n stride_z_N,\n stride_z_V,\n idx_N_group,\n N_group: tl.constexpr,\n V: tl.constexpr,\n N: tl.constexpr,\n H: tl.constexpr,\n V_BLOCK_SIZE: tl.constexpr = 16,\n N_BLOCK_SIZE: tl.constexpr = 16,\n H_BLOCK_SIZE: tl.constexpr = 16,\n GROUP_SIZE: tl.constexpr = 1,\n):\n # Kernel logic for calculating gradients with respect to input x\n\n@triton.jit()\ndef linear_xent_bwd_kernel_matmul_t_epilogue_dA(\n z_nv_ptr,\n y_ptr,\n x_ptr,\n A_grad_ptr,\n lse_ptr,\n stride_x_N,\n stride_x_H,\n stride_A_H,\n stride_A_V,\n stride_z_N,\n stride_z_V,\n idx_N_group,\n N_group: tl.constexpr,\n V: tl.constexpr,\n N: tl.constexpr,\n H: tl.constexpr,\n V_BLOCK_SIZE: tl.constexpr = 16,\n N_BLOCK_SIZE: tl.constexpr = 16,\n H_BLOCK_SIZE: tl.constexpr = 16,\n GROUP_SIZE: tl.constexpr = 16,\n):\n # Kernel logic for calculating gradients with respect to weights A\n\nclass LinearCrossEntropyLoss(torch.autograd.Function):\n @staticmethod\n def forward(\n ctx,\n x,\n y,\n At,\n ignore_index=-100,\n N_chunk_size: int = 4096,\n ):\n # Forward pass implementation\n \n fwd_grid = lambda meta: (\n triton.cdiv(N_group, meta[\"N_BLOCK_SIZE\"]),\n triton.cdiv(V, meta[\"V_BLOCK_SIZE\"] * meta[\"V_TILES\"]),\n )\n \n bwd_grid_dx = lambda meta: (triton.cdiv(N_group, meta[\"N_BLOCK_SIZE\"]), triton.cdiv(H, meta[\"H_BLOCK_SIZE\"]))\n bwd_grid_dA = lambda meta: (triton.cdiv(V, meta[\"V_BLOCK_SIZE\"]), triton.cdiv(H, meta[\"H_BLOCK_SIZE\"]))\n\n for idx_N_group in range(math.ceil(N / N_group)):\n with torch.cuda.device(x.device.index): # actually required\n linear_xent_fwd_prep_bwd_kernel_matmul_t[fwd_grid](\n x,\n y,\n At,\n logits,\n losses,\n lse_local,\n x.stride(0),\n x.stride(1),\n At.stride(0),\n At.stride(1),\n logits.stride(0),\n logits.stride(1),\n lse_local.stride(0),\n lse_local.stride(1),\n losses.stride(0),\n losses.stride(1),\n idx_N_group=idx_N_group,\n N_group=N_group,\n V=V,\n N=N,\n H=H,\n )\n # Compute global log-sum-exp and accumulate the loss\n lse_global = lse_local.logsumexp(dim=1)\n loss += losses.sum() + lse_global.sum() / N\n\n if x.requires_grad:\n linear_xent_bwd_kernel_matmul_t_epilogue_dx[bwd_grid_dx](\n logits,\n y,\n At,\n x_grad,\n lse_global,\n x_grad.stride(0),\n x_grad.stride(1),\n At.stride(0),\n At.stride(1),\n logits.stride(0),\n logits.stride(1),\n idx_N_group=idx_N_group,\n N_group=N_group,\n V=V,\n N=N,\n H=H,\n )\n if At.requires_grad:\n linear_xent_bwd_kernel_matmul_t_epilogue_dA[bwd_grid_dA](\n logits,\n y,\n x,\n At_grad,\n lse_global,\n x_grad.stride(0),\n x_grad.stride(1),\n At.stride(0),\n At.stride(1),\n logits.stride(0),\n logits.stride(1),\n idx_N_group=idx_N_group,\n N_group=N_group,\n V=V,\n N=N,\n H=H,\n )\n \n ctx.mark_non_differentiable(y)\n ctx.save_for_backward(x_grad, At_grad.to(At.dtype))\n return loss\n\ndef linear_cross_entropy(x, y, At, ignore_index=-100, N_chunk_size: int = 4096):\n return LinearCrossEntropyLoss.apply(x, y, At, ignore_index, N_chunk_size)\n", - "description_1": "Use triton language to implement a linear cross-entropy loss calculation with the forward and backward propagation kernels. This involves three main triton kernels: linear_xent_fwd_prep_bwd_kernel_matmul_t for the forward pass and preparing backward computation, linear_xent_bwd_kernel_matmul_t_epilogue_dx for computing the gradient with respect to input x, and linear_xent_bwd_kernel_matmul_t_epilogue_dA for computing the gradient with respect to weights A. The function linear_cross_entropy calls these kernels, handling memory management and grid configuration.", - "description_2": "Use triton language to create and integrate kernels for forward and backward propagation of a linear cross-entropy loss, optimizing memory and compute resources.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.autotune(\n configs=[\n triton.Config({\"V_BLOCK_SIZE\": 256, \"N_BLOCK_SIZE\": 64, \"H_BLOCK_SIZE\": 64, \"V_TILES\": 1}, num_warps=8),\n ],\n key=[\"V\", \"N\", \"H\"],\n reset_to_zero=[\"losses_ptr\", \"lse_ptr\", \"z_nv_ptr\"],\n)\n@triton.jit\ndef linear_xent_fwd_prep_bwd_kernel_matmul_t(\n x_ptr,\n y_ptr,\n A_t_ptr,\n z_nv_ptr,\n losses_ptr,\n lse_ptr,\n stride_x_N,\n stride_x_H,\n stride_A_H,\n stride_A_V,\n stride_z_N,\n stride_z_V,\n stride_lse_N,\n stride_lse_B,\n stride_loss_Nb,\n stride_loss_B,\n idx_N_group,\n N_group: tl.constexpr,\n V: tl.constexpr,\n N: tl.constexpr,\n H: tl.constexpr,\n V_BLOCK_SIZE: tl.constexpr = 16,\n N_BLOCK_SIZE: tl.constexpr = 16,\n H_BLOCK_SIZE: tl.constexpr = 16,\n V_TILES: tl.constexpr = 1, # type: ignore\n GROUP_SIZE: tl.constexpr = 1, # type: ignore\n):\n idx_N = tl.program_id(axis=0)\n idx_V_group = tl.program_id(axis=1)\n num_idx_N, num_idx_V_group = tl.num_programs(0), tl.num_programs(1)\n # idx_N, idx_V_group = tl.swizzle2d(idx_N, idx_V_group, num_idx_N, num_idx_V_group, GROUP_SIZE) # type:ignore\n\n V_GROUP_SIZE: tl.constexpr = V_TILES * V_BLOCK_SIZE\n x_block_ptr = tl.make_block_ptr(\n base=x_ptr,\n shape=(N, H),\n strides=(stride_x_N, stride_x_H),\n offsets=(idx_N_group * N_group + idx_N * N_BLOCK_SIZE, 0),\n block_shape=(N_BLOCK_SIZE, H_BLOCK_SIZE),\n order=(1, 0),\n )\n A_block_ptr = tl.make_block_ptr(\n base=A_t_ptr,\n shape=(H, V),\n strides=(stride_A_H, stride_A_V),\n offsets=(0, idx_V_group * V_GROUP_SIZE),\n block_shape=(H_BLOCK_SIZE, V_BLOCK_SIZE),\n order=(1, 0),\n )\n z_block_ptr = tl.make_block_ptr(\n base=z_nv_ptr,\n shape=(N_group, V),\n strides=(stride_z_N, stride_z_V),\n offsets=(idx_N * N_BLOCK_SIZE, idx_V_group * V_GROUP_SIZE),\n block_shape=(N_BLOCK_SIZE, V_BLOCK_SIZE),\n order=(1, 0),\n )\n lse_row_ptr = tl.make_block_ptr(\n base=lse_ptr,\n shape=(N_group, V // 16), # fixed to worst case number assuming max(V_TILES)\n strides=(stride_lse_N, stride_lse_B),\n offsets=(idx_N * N_BLOCK_SIZE, idx_V_group),\n block_shape=(N_BLOCK_SIZE, 1),\n order=(1, 0),\n )\n loss_val_ptr = losses_ptr + idx_N * stride_loss_Nb + idx_V_group * stride_loss_B\n\n N_range = idx_N_group * N_group + idx_N * N_BLOCK_SIZE + tl.arange(0, N_BLOCK_SIZE)\n V_range = idx_V_group * V_GROUP_SIZE + tl.arange(0, V_BLOCK_SIZE)\n y = tl.load(y_ptr + N_range)\n\n m = tl.zeros((N_BLOCK_SIZE,), dtype=tl.float32) - float(10e6)\n s = tl.zeros((N_BLOCK_SIZE,), dtype=tl.float32)\n loss = 0.0\n\n for _ in range(V_TILES):\n z_j_to_k = tl.zeros((N_BLOCK_SIZE, V_BLOCK_SIZE), dtype=tl.float32)\n for _ in range(H // H_BLOCK_SIZE):\n x_chunk = tl.load(x_block_ptr) # Nc x H\n A_v = tl.load(A_block_ptr) # Vc x H\n\n z_j_to_k = tl.dot(x_chunk, A_v, z_j_to_k) # (Nc x H) @ (H x Vc)\n\n x_block_ptr = tl.advance(x_block_ptr, [0, H_BLOCK_SIZE])\n A_block_ptr = tl.advance(A_block_ptr, [H_BLOCK_SIZE, 0])\n\n m_new = tl.maximum(m, tl.max(z_j_to_k, 1))\n\n s_update = tl.sum(tl.exp((z_j_to_k - m_new[:, None])), axis=1)\n s = s * tl.exp(m - m_new) + s_update\n\n mask = y[:, None] == V_range[None, :] # Nc x Vc\n loss -= tl.sum(tl.where(mask, z_j_to_k, float(0.0))) / N\n\n # save z for later\n tl.store(z_block_ptr, z_j_to_k.to(z_nv_ptr.type.element_ty))\n\n # Reset and advance pointers for next step\n m = m_new\n x_block_ptr = tl.advance(x_block_ptr, [0, -H])\n A_block_ptr = tl.advance(A_block_ptr, [-H, V_BLOCK_SIZE])\n z_block_ptr = tl.advance(z_block_ptr, [0, V_BLOCK_SIZE])\n V_range = V_range + V_BLOCK_SIZE\n\n lse = m + tl.log(s)\n # loss += tl.sum(lse) / N # defered until all blocks are done\n tl.store(loss_val_ptr, loss)\n tl.store(lse_row_ptr, lse[:, None])\n\n\n@triton.autotune(\n configs=[\n triton.Config({\"V_BLOCK_SIZE\": 64, \"N_BLOCK_SIZE\": 16}),\n triton.Config({\"V_BLOCK_SIZE\": 256, \"N_BLOCK_SIZE\": 16}),\n triton.Config({\"V_BLOCK_SIZE\": 512, \"N_BLOCK_SIZE\": 16}),\n triton.Config({\"V_BLOCK_SIZE\": 1024, \"N_BLOCK_SIZE\": 16}),\n triton.Config({\"V_BLOCK_SIZE\": 64, \"N_BLOCK_SIZE\": 64}),\n triton.Config({\"V_BLOCK_SIZE\": 256, \"N_BLOCK_SIZE\": 64}),\n triton.Config({\"V_BLOCK_SIZE\": 256, \"N_BLOCK_SIZE\": 128}),\n triton.Config({\"V_BLOCK_SIZE\": 64, \"N_BLOCK_SIZE\": 16}, num_warps=8),\n triton.Config({\"V_BLOCK_SIZE\": 256, \"N_BLOCK_SIZE\": 16}, num_warps=8),\n triton.Config({\"V_BLOCK_SIZE\": 512, \"N_BLOCK_SIZE\": 16}, num_warps=8),\n triton.Config({\"V_BLOCK_SIZE\": 1024, \"N_BLOCK_SIZE\": 16}, num_warps=8),\n triton.Config({\"V_BLOCK_SIZE\": 64, \"N_BLOCK_SIZE\": 64}, num_warps=8),\n triton.Config({\"V_BLOCK_SIZE\": 256, \"N_BLOCK_SIZE\": 64}, num_warps=8),\n triton.Config({\"V_BLOCK_SIZE\": 256, \"N_BLOCK_SIZE\": 128}, num_warps=8),\n triton.Config({\"V_BLOCK_SIZE\": 64, \"N_BLOCK_SIZE\": 16}, num_warps=16),\n triton.Config({\"V_BLOCK_SIZE\": 256, \"N_BLOCK_SIZE\": 16}, num_warps=16),\n triton.Config({\"V_BLOCK_SIZE\": 512, \"N_BLOCK_SIZE\": 16}, num_warps=16),\n triton.Config({\"V_BLOCK_SIZE\": 1024, \"N_BLOCK_SIZE\": 16}, num_warps=16),\n triton.Config({\"V_BLOCK_SIZE\": 64, \"N_BLOCK_SIZE\": 64}, num_warps=16),\n triton.Config({\"V_BLOCK_SIZE\": 256, \"N_BLOCK_SIZE\": 64}, num_warps=16),\n triton.Config({\"V_BLOCK_SIZE\": 256, \"N_BLOCK_SIZE\": 128}, num_warps=16),\n ],\n key=[\"V\", \"N\"],\n restore_value=[\"z_nv_ptr\"], # or reset_to_zero? does this have measurable consequences?\n)\n@triton.jit\ndef linear_xent_mini_bwd_prologue_kernel(\n z_nv_ptr,\n y_ptr,\n lse_ptr,\n stride_z_N,\n stride_z_V,\n idx_N_group,\n N_group: tl.constexpr,\n V: tl.constexpr,\n N: tl.constexpr,\n V_BLOCK_SIZE: tl.constexpr,\n N_BLOCK_SIZE: tl.constexpr,\n):\n idx_N = tl.program_id(axis=0)\n idx_V = tl.program_id(axis=1)\n # tl.static_print(V_BLOCK_SIZE, N_BLOCK_SIZE)\n # tl.static_assert(N % N_BLOCK_SIZE == 0)\n # tl.static_assert(V % V_BLOCK_SIZE == 0)\n z_block_ptr = tl.make_block_ptr(\n base=z_nv_ptr,\n shape=(N_group, V),\n strides=(stride_z_N, stride_z_V),\n offsets=(idx_N * N_BLOCK_SIZE, idx_V * V_BLOCK_SIZE),\n block_shape=(N_BLOCK_SIZE, V_BLOCK_SIZE),\n order=(1, 0),\n )\n N_range = idx_N_group * N_group + idx_N * N_BLOCK_SIZE + tl.arange(0, N_BLOCK_SIZE)\n v_range = idx_V * V_BLOCK_SIZE + tl.arange(0, V_BLOCK_SIZE)\n y = tl.load(y_ptr + N_range)\n lse = tl.load(lse_ptr + N_range)\n z_j_to_k = tl.load(z_block_ptr)\n\n mask = y[:, None] == v_range[None, :]\n softmax_z = (z_j_to_k - lse[:, None]).exp()\n z_grad = (softmax_z - tl.where(mask, 1.0, 0.0)) / N\n\n tl.store(z_block_ptr, z_grad.to(z_nv_ptr.type.element_ty))\n\n\nclass LinearCrossEntropyLoss(torch.autograd.Function):\n\n @staticmethod\n def forward(\n ctx,\n x,\n y,\n At,\n ignore_index=-100, # code ignores all negative integers right now\n N_chunk_size: int = 4096, # N_chunk_size x V is the maximal memory peak\n ):\n N, H = x.shape\n H_A, V = At.shape\n assert H_A == H\n assert y.shape == (N,)\n\n if ignore_index >= 0:\n y[y == ignore_index] = -100\n At_grad = torch.zeros_like(At)\n x_grad = torch.zeros_like(x)\n\n N_group = min(N, N_chunk_size)\n\n loss = torch.zeros(1, dtype=torch.float32, device=x.device)\n z_nv_and_grad = torch.empty((N_group, V), device=x.device, dtype=torch.float32)\n lse_local = -10e5 * torch.ones(N_group, V // 128, dtype=torch.float32, device=x.device)\n losses = torch.zeros(N_group // 64, V // 128, dtype=torch.float32, device=x.device)\n\n fwd_grid = lambda meta: (\n triton.cdiv(N_group, meta[\"N_BLOCK_SIZE\"]),\n triton.cdiv(V, meta[\"V_BLOCK_SIZE\"] * meta[\"V_TILES\"]),\n )\n prologue_grid = lambda meta: (\n triton.cdiv(N_group, meta[\"N_BLOCK_SIZE\"]),\n triton.cdiv(V, meta[\"V_BLOCK_SIZE\"]),\n )\n\n for idx_N_group, x_n_chunk in enumerate(x.split(N_group)):\n with torch.cuda.device(x.device.index): # actually required\n\n linear_xent_fwd_prep_bwd_kernel_matmul_t[fwd_grid](\n x,\n y,\n At,\n z_nv_and_grad,\n losses,\n lse_local,\n x.stride(0),\n x.stride(1),\n At.stride(0),\n At.stride(1),\n z_nv_and_grad.stride(0),\n z_nv_and_grad.stride(1),\n lse_local.stride(0),\n lse_local.stride(1),\n losses.stride(0),\n losses.stride(1),\n idx_N_group=idx_N_group,\n N_group=N_group,\n V=V,\n N=N,\n H=H,\n )\n chosen_tiles = linear_xent_fwd_prep_bwd_kernel_matmul_t.best_config.kwargs[\"V_TILES\"]\n chosen_block = linear_xent_fwd_prep_bwd_kernel_matmul_t.best_config.kwargs[\"V_BLOCK_SIZE\"]\n buffer_extent = V // chosen_block // chosen_tiles\n lse_global = lse_local[:, :buffer_extent].logsumexp(dim=1)\n loss += losses.sum() + lse_global.sum() / N\n losses.zero_()\n if x.requires_grad or At.requires_grad:\n linear_xent_mini_bwd_prologue_kernel[prologue_grid](\n z_nv_and_grad,\n y,\n lse_global,\n z_nv_and_grad.stride(0),\n z_nv_and_grad.stride(1),\n idx_N_group=idx_N_group,\n N_group=N_group,\n V=V,\n N=N,\n )\n z_grad = z_nv_and_grad.to(x.dtype)\n\n if x.requires_grad:\n x_grad[N_group * idx_N_group : x_n_chunk.shape[0] * (idx_N_group + 1)] = z_grad @ At.T\n\n if At.requires_grad:\n torch.addmm(\n At_grad,\n x_n_chunk.T,\n z_grad,\n out=At_grad,\n )\n\n print(\"fwd config:\", linear_xent_fwd_prep_bwd_kernel_matmul_t.best_config)\n print(\"prologue config:\", linear_xent_mini_bwd_prologue_kernel.best_config)\n ctx.mark_non_differentiable(y)\n ctx.save_for_backward(x_grad, At_grad.to(At.dtype))\n return loss\n\n @staticmethod\n def backward(ctx, grad_output):\n x_grad, At_grad = ctx.saved_tensors\n\n return x_grad * grad_output, None, At_grad * grad_output, None, None\n\n\ndef linear_cross_entropy(x, y, At, ignore_index=-100, N_chunk_size: int = 4096):\n return LinearCrossEntropyLoss.apply(x, y, At, ignore_index, N_chunk_size)\n", - "description_1": "Use triton language to implement a forward and backward pass for a linear cross-entropy loss. The first kernel 'linear_xent_fwd_prep_bwd_kernel_matmul_t' computes the forward pass and prepares for the backward pass by storing intermediate results. The second kernel 'linear_xent_mini_bwd_prologue_kernel' computes the partial backward pass for a subset of data. 'LinearCrossEntropyLoss' class wraps these kernels for easy use in a PyTorch-like interface, supporting both forward and backward computations. The 'forward' function accepts 5 parameters: 'x' (input tensor of shape N x H), 'y' (target labels of shape N), 'At' (transposed weight matrix of shape H x V), 'ignore_index' (label to ignore during loss computation), and 'N_chunk_size' (chunk size for processing).", - "description_2": "Use triton language to compute linear cross-entropy loss and its gradient, handling large vocabularies by chunking input data. Implement two main Triton kernels for forward and partial backward operations and wrap them using a PyTorch function class to integrate with autograd.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef linear_xent_fwd_prep_bwd_kernel_matmul_t(\n x_ptr,\n y_ptr,\n A_t_ptr,\n z_nv_ptr,\n losses_ptr,\n lse_ptr,\n stride_x_N,\n stride_x_H,\n stride_A_H,\n stride_A_V,\n stride_z_N,\n stride_z_V,\n stride_lse_N,\n stride_lse_B,\n stride_loss_Nb,\n stride_loss_B,\n idx_N_group,\n N_group: tl.constexpr,\n V: tl.constexpr,\n N: tl.constexpr,\n H: tl.constexpr,\n V_BLOCK_SIZE: tl.constexpr,\n N_BLOCK_SIZE: tl.constexpr,\n H_BLOCK_SIZE: tl.constexpr,\n V_TILES: tl.constexpr = 1,\n):\n idx_N = tl.program_id(axis=0)\n idx_V_group = tl.program_id(axis=1)\n\n V_GROUP_SIZE: tl.constexpr = V_TILES * V_BLOCK_SIZE\n\n x_block_ptr = tl.make_block_ptr(\n base=x_ptr,\n shape=(N, H),\n strides=(stride_x_N, stride_x_H),\n offsets=(idx_N_group * N_group + idx_N * N_BLOCK_SIZE, 0),\n block_shape=(N_BLOCK_SIZE, H_BLOCK_SIZE),\n order=(1, 0),\n )\n A_block_ptr = tl.make_block_ptr(\n base=A_t_ptr,\n shape=(H, V),\n strides=(stride_A_H, stride_A_V),\n offsets=(0, idx_V_group * V_GROUP_SIZE),\n block_shape=(H_BLOCK_SIZE, V_BLOCK_SIZE),\n order=(1, 0),\n )\n z_block_ptr = tl.make_block_ptr(\n base=z_nv_ptr,\n shape=(N_group, V),\n strides=(stride_z_N, stride_z_V),\n offsets=(idx_N * N_BLOCK_SIZE, idx_V_group * V_GROUP_SIZE),\n block_shape=(N_BLOCK_SIZE, V_BLOCK_SIZE),\n order=(1, 0),\n )\n lse_row_ptr = tl.make_block_ptr(\n base=lse_ptr,\n shape=(N, V // 64),\n strides=(stride_lse_N, stride_lse_B),\n offsets=(idx_N_group * N_group + idx_N * N_BLOCK_SIZE, idx_V_group),\n block_shape=(N_BLOCK_SIZE, 1),\n order=(1, 0),\n )\n loss_val_ptr = (\n losses_ptr + (idx_N + idx_N_group * N_group // N_BLOCK_SIZE) * stride_loss_Nb + idx_V_group * stride_loss_B\n )\n\n N_range = idx_N_group * N_group + idx_N * N_BLOCK_SIZE + tl.arange(0, N_BLOCK_SIZE)\n V_range = idx_V_group * V_GROUP_SIZE + tl.arange(0, V_BLOCK_SIZE)\n y = tl.load(y_ptr + N_range)\n\n m = tl.zeros((N_BLOCK_SIZE,), dtype=tl.float32) - float(10e6)\n s = tl.zeros((N_BLOCK_SIZE,), dtype=tl.float32)\n loss = 0.0\n\n for _ in range(V_TILES):\n z_j_to_k = tl.zeros((N_BLOCK_SIZE, V_BLOCK_SIZE), dtype=tl.float32)\n for _ in range(H // H_BLOCK_SIZE):\n x_chunk = tl.load(x_block_ptr)\n A_v = tl.load(A_block_ptr)\n\n z_j_to_k = tl.dot(x_chunk, A_v, z_j_to_k)\n\n x_block_ptr = tl.advance(x_block_ptr, [0, H_BLOCK_SIZE])\n A_block_ptr = tl.advance(A_block_ptr, [H_BLOCK_SIZE, 0])\n\n m_new = tl.maximum(m, tl.max(z_j_to_k, 1))\n\n s_update = tl.sum(tl.exp((z_j_to_k - m_new[:, None])), axis=1)\n s = s * tl.exp(m - m_new) + s_update\n\n mask = y[:, None] == V_range[None, :]\n loss -= tl.sum(tl.where(mask, z_j_to_k, float(0.0))) / N\n\n tl.store(z_block_ptr, z_j_to_k.to(z_nv_ptr.type.element_ty))\n\n m = m_new\n x_block_ptr = tl.advance(x_block_ptr, [0, -H])\n A_block_ptr = tl.advance(A_block_ptr, [-H, V_BLOCK_SIZE])\n z_block_ptr = tl.advance(z_block_ptr, [0, V_BLOCK_SIZE])\n V_range = V_range + V_BLOCK_SIZE\n\n lse = m + tl.log(s)\n tl.store(loss_val_ptr, loss)\n tl.store(lse_row_ptr, lse[:, None])\n\n@triton.jit\ndef linear_xent_mini_bwd_prologue_kernel(\n z_nv_ptr,\n y_ptr,\n lse_ptr,\n stride_z_N,\n stride_z_V,\n idx_N_group,\n N_group: tl.constexpr,\n V: tl.constexpr,\n N: tl.constexpr,\n V_BLOCK_SIZE: tl.constexpr,\n N_BLOCK_SIZE: tl.constexpr,\n):\n idx_N = tl.program_id(axis=0)\n idx_V = tl.program_id(axis=1)\n\n z_block_ptr = tl.make_block_ptr(\n base=z_nv_ptr,\n shape=(N_group, V),\n strides=(stride_z_N, stride_z_V),\n offsets=(idx_N * N_BLOCK_SIZE, idx_V * V_BLOCK_SIZE),\n block_shape=(N_BLOCK_SIZE, V_BLOCK_SIZE),\n order=(1, 0),\n )\n N_range = idx_N_group * N_group + idx_N * N_BLOCK_SIZE + tl.arange(0, N_BLOCK_SIZE)\n v_range = idx_V * V_BLOCK_SIZE + tl.arange(0, V_BLOCK_SIZE)\n y = tl.load(y_ptr + N_range)\n lse = tl.load(lse_ptr + N_range)\n z_j_to_k = tl.load(z_block_ptr)\n\n mask = y[:, None] == v_range[None, :]\n softmax_z = (z_j_to_k - lse[:, None]).exp()\n z_grad = (softmax_z - tl.where(mask, 1.0, 0.0)) / N\n\n tl.store(z_block_ptr, z_grad.to(z_nv_ptr.type.element_ty))\n\nclass LinearCrossEntropyLoss(torch.autograd.Function):\n\n @staticmethod\n def forward(\n ctx,\n x,\n y,\n At,\n ignore_index=-100,\n N_chunk_size: int = 4096,\n ):\n N, H = x.shape\n H_A, V = At.shape\n assert H_A == H\n assert y.shape == (N,)\n if ignore_index >= 0:\n y[y == ignore_index] = -100\n At_grad = torch.zeros_like(At)\n x_grad = torch.zeros_like(x)\n\n N_group = min(N, N_chunk_size)\n\n lse_local = -10e5 * torch.ones(N, V // 64, dtype=torch.float32, device=x.device)\n losses = torch.zeros(N // 64, V // 64, dtype=torch.float32, device=x.device)\n z_nv_and_grad = torch.empty((N_group, V), device=x.device, dtype=torch.float32)\n\n fwd_grid = lambda meta: (\n triton.cdiv(N_group, meta[\"N_BLOCK_SIZE\"]),\n triton.cdiv(V, meta[\"V_BLOCK_SIZE\"] * meta[\"V_TILES\"]),\n )\n prologue_grid = lambda meta: (\n triton.cdiv(N_group, meta[\"N_BLOCK_SIZE\"]),\n triton.cdiv(V, meta[\"V_BLOCK_SIZE\"]),\n )\n\n for idx_N_group, x_n_chunk in enumerate(x.split(N_group)):\n with torch.cuda.device(x.device.index):\n linear_xent_fwd_prep_bwd_kernel_matmul_t[fwd_grid](\n x,\n y,\n At,\n z_nv_and_grad,\n losses,\n lse_local,\n x.stride(0),\n x.stride(1),\n At.stride(0),\n At.stride(1),\n z_nv_and_grad.stride(0),\n z_nv_and_grad.stride(1),\n lse_local.stride(0),\n lse_local.stride(1),\n losses.stride(0),\n losses.stride(1),\n idx_N_group=idx_N_group,\n N_group=N_group,\n V=V,\n N=N,\n H=H,\n )\n chosen_tiles = linear_xent_fwd_prep_bwd_kernel_matmul_t.best_config.kwargs[\"V_TILES\"]\n chosen_block = linear_xent_fwd_prep_bwd_kernel_matmul_t.best_config.kwargs[\"V_BLOCK_SIZE\"]\n buffer_extent = V // chosen_block // chosen_tiles\n lse_global = lse_local[:, :buffer_extent].logsumexp(dim=1)\n if x.requires_grad or At.requires_grad:\n linear_xent_mini_bwd_prologue_kernel[prologue_grid](\n z_nv_and_grad,\n y,\n lse_global,\n z_nv_and_grad.stride(0),\n z_nv_and_grad.stride(1),\n idx_N_group=idx_N_group,\n N_group=N_group,\n V=V,\n N=N,\n )\n z_grad = z_nv_and_grad.to(x.dtype)\n\n if x.requires_grad:\n x_grad[N_group * idx_N_group : x_n_chunk.shape[0] * (idx_N_group + 1)] = z_grad @ At.T\n\n if At.requires_grad:\n torch.addmm(\n At_grad,\n x_n_chunk.T,\n z_grad,\n out=At_grad,\n )\n\n ctx.mark_non_differentiable(y)\n ctx.save_for_backward(x_grad, At_grad.to(At.dtype))\n return losses.sum() + lse_global.sum() / N\n\n @staticmethod\n def backward(ctx, grad_output):\n x_grad, At_grad = ctx.saved_tensors\n\n return x_grad * grad_output, None, At_grad * grad_output, None, None\n\n\ndef linear_cross_entropy(x, y, At, ignore_index=-100, N_chunk_size=4096):\n return LinearCrossEntropyLoss.apply(x, y, At, ignore_index, N_chunk_size)\n\n", - "description_1": "Use triton language to create a linear cross-entropy loss kernel and its backpropagation prologue kernel for efficient matrix operations. The forward kernel computes the forward pass of the linear transformation and cross-entropy loss. It takes in pointers to input tensors (x, y, A_t) and output tensors (z_nv, losses, lse), strides for each tensor, group and block sizes for the computation. The backward kernel (prologue) prepares gradients for backpropagation, processing the tensor z_nv with the softmax function, and adjusting gradients for the inputs. The LinearCrossEntropyLoss class manages forward and backward passes using these kernels with PyTorch compatibility.", - "description_2": "Use triton language to implement efficient forward and backward passes for a linear transformation followed by a cross-entropy loss using Triton kernels for large matrix operations.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.autotune(\n configs=[\n triton.Config({\"V_BLOCK_SIZE\": 256, \"N_BLOCK_SIZE\": 64, \"H_BLOCK_SIZE\": 64, \"V_TILES\": 1}, num_warps=8),\n ],\n key=[\"V\", \"N\", \"H\"],\n reset_to_zero=[\"losses_ptr\", \"sumexp_ptr\", \"z_nv_ptr\"],\n)\n@triton.jit\ndef linear_xent_fwd_prep_bwd_kernel_matmul_t(\n x_ptr,\n y_ptr,\n A_t_ptr,\n z_nv_ptr,\n losses_ptr,\n sumexp_ptr,\n stride_x_N,\n stride_x_H,\n stride_A_H,\n stride_A_V,\n stride_z_N,\n stride_z_V,\n idx_N_group,\n N_group: tl.constexpr,\n V: tl.constexpr,\n N: tl.constexpr,\n H: tl.constexpr,\n V_BLOCK_SIZE: tl.constexpr,\n N_BLOCK_SIZE: tl.constexpr,\n H_BLOCK_SIZE: tl.constexpr,\n V_TILES: tl.constexpr = 4,\n):\n idx_N = tl.program_id(axis=0)\n idx_V_group = tl.program_id(axis=1)\n \n V_GROUP_SIZE: tl.constexpr = V_TILES * V_BLOCK_SIZE\n\n x_block_ptr = tl.make_block_ptr(\n base=x_ptr,\n shape=(N, H),\n strides=(stride_x_N, stride_x_H),\n offsets=(idx_N_group * N_group + idx_N * N_BLOCK_SIZE, 0),\n block_shape=(N_BLOCK_SIZE, H_BLOCK_SIZE),\n order=(1, 0),\n )\n A_block_ptr = tl.make_block_ptr(\n base=A_t_ptr,\n shape=(H, V),\n strides=(stride_A_H, stride_A_V),\n offsets=(0, idx_V_group * V_GROUP_SIZE),\n block_shape=(H_BLOCK_SIZE, V_BLOCK_SIZE),\n order=(1, 0),\n )\n z_block_ptr = tl.make_block_ptr(\n base=z_nv_ptr,\n shape=(N_group, V),\n strides=(stride_z_N, stride_z_V),\n offsets=(idx_N * N_BLOCK_SIZE, idx_V_group * V_GROUP_SIZE),\n block_shape=(N_BLOCK_SIZE, V_BLOCK_SIZE),\n order=(1, 0),\n )\n sumexp_row_ptr = sumexp_ptr + idx_N * N_BLOCK_SIZE + tl.arange(0, N_BLOCK_SIZE)\n\n N_range = idx_N_group * N_group + idx_N * N_BLOCK_SIZE + tl.arange(0, N_BLOCK_SIZE)\n V_range = idx_V_group * V_GROUP_SIZE + tl.arange(0, V_BLOCK_SIZE)\n y = tl.load(y_ptr + N_range)\n\n m = tl.zeros((N_BLOCK_SIZE,), dtype=tl.float32) - float(10e6)\n s = tl.zeros((N_BLOCK_SIZE,), dtype=tl.float32)\n loss = 0.0\n\n for _ in range(V_TILES):\n z_j_to_k = tl.zeros((N_BLOCK_SIZE, V_BLOCK_SIZE), dtype=tl.float32)\n for _ in range(H // H_BLOCK_SIZE):\n x_chunk = tl.load(x_block_ptr)\n A_v = tl.load(A_block_ptr)\n\n z_j_to_k = tl.dot(x_chunk, A_v, z_j_to_k)\n\n x_block_ptr = tl.advance(x_block_ptr, [0, H_BLOCK_SIZE])\n A_block_ptr = tl.advance(A_block_ptr, [H_BLOCK_SIZE, 0])\n\n m_new = tl.maximum(m, tl.max(z_j_to_k, 1))\n\n s_update = tl.sum(tl.exp((z_j_to_k - m_new[:, None])), axis=1)\n s = s * tl.exp(m - m_new) + s_update\n\n mask = y[:, None] == V_range[None, :]\n loss -= tl.sum(tl.where(mask, z_j_to_k, float(0.0))) / N\n\n tl.store(z_block_ptr, z_j_to_k.to(z_nv_ptr.type.element_ty))\n\n m = m_new\n x_block_ptr = tl.advance(x_block_ptr, [0, -H])\n A_block_ptr = tl.advance(A_block_ptr, [-H, V_BLOCK_SIZE])\n z_block_ptr = tl.advance(z_block_ptr, [0, V_BLOCK_SIZE])\n V_range = V_range + V_BLOCK_SIZE\n\n lse = m + tl.log(s)\n sum_exp = tl.exp(lse).to(sumexp_ptr.type.element_ty)\n\n tl.atomic_add(losses_ptr + idx_N, loss)\n tl.atomic_add(sumexp_row_ptr, sum_exp)\n\n\n@triton.autotune(\n configs=[\n triton.Config({\"V_BLOCK_SIZE\": 256, \"N_BLOCK_SIZE\": 128}, num_warps=8),\n ],\n key=[\"V\", \"N\"],\n restore_value=[\"z_nv_ptr\"],\n)\n@triton.jit\ndef linear_xent_mini_bwd_prologue_kernel(\n z_nv_ptr,\n y_ptr,\n sumexp_ptr,\n stride_z_N,\n stride_z_V,\n idx_N_group,\n N_group: tl.constexpr,\n V: tl.constexpr,\n N: tl.constexpr,\n V_BLOCK_SIZE: tl.constexpr,\n N_BLOCK_SIZE: tl.constexpr,\n):\n idx_N = tl.program_id(axis=0)\n idx_V = tl.program_id(axis=1)\n \n z_block_ptr = tl.make_block_ptr(\n base=z_nv_ptr,\n shape=(N_group, V),\n strides=(stride_z_N, stride_z_V),\n offsets=(idx_N * N_BLOCK_SIZE, idx_V * V_BLOCK_SIZE),\n block_shape=(N_BLOCK_SIZE, V_BLOCK_SIZE),\n order=(1, 0),\n )\n N_range = idx_N * N_BLOCK_SIZE + tl.arange(0, N_BLOCK_SIZE)\n v_range = idx_V * V_BLOCK_SIZE + tl.arange(0, V_BLOCK_SIZE)\n y = tl.load(y_ptr + idx_N_group * N_group + N_range)\n lse = tl.log(tl.load(sumexp_ptr + N_range))\n z_j_to_k = tl.load(z_block_ptr)\n\n mask = y[:, None] == v_range[None, :]\n softmax_z = (z_j_to_k - lse[:, None]).exp()\n z_grad = (softmax_z - tl.where(mask, 1.0, 0.0)) / N\n\n tl.store(z_block_ptr, z_grad.to(z_nv_ptr.type.element_ty))\n\n\nclass LinearCrossEntropyLoss(torch.autograd.Function):\n\n @staticmethod\n def forward(\n ctx,\n x,\n y,\n At,\n ignore_index=-100,\n N_chunk_size: int = 4096,\n ):\n N, H = x.shape\n H_A, V = At.shape\n assert H_A == H\n assert y.shape == (N,)\n \n if ignore_index >= 0:\n y[y == ignore_index] = -100\n \n At_grad = torch.zeros_like(At)\n x_grad = torch.zeros_like(x)\n\n N_group = min(N, N_chunk_size)\n\n loss = 0.0\n sumexp = torch.empty(N_group, dtype=torch.float32, device=x.device)\n losses = torch.zeros(N_group // 16, dtype=torch.float32, device=x.device)\n z_nv_and_grad = torch.empty((N_group, V), device=x.device, dtype=torch.float32)\n\n fwd_grid = lambda meta: (\n triton.cdiv(N_group, meta[\"N_BLOCK_SIZE\"]),\n triton.cdiv(V, meta[\"V_BLOCK_SIZE\"] * meta[\"V_TILES\"]),\n )\n prologue_grid = lambda meta: (\n triton.cdiv(N_group, meta[\"N_BLOCK_SIZE\"]),\n triton.cdiv(V, meta[\"V_BLOCK_SIZE\"]),\n )\n\n for idx_N_group, x_n_chunk in enumerate(x.split(N_group)):\n with torch.cuda.device(x.device.index):\n linear_xent_fwd_prep_bwd_kernel_matmul_t[fwd_grid](\n x,\n y,\n At,\n z_nv_and_grad,\n losses,\n sumexp,\n x.stride(0),\n x.stride(1),\n At.stride(0),\n At.stride(1),\n z_nv_and_grad.stride(0),\n z_nv_and_grad.stride(1),\n idx_N_group=idx_N_group,\n N_group=N_group,\n V=V,\n N=N,\n H=H,\n )\n loss += losses.sum() + 40 + sumexp.log().sum() / N\n if x.requires_grad or At.requires_grad:\n linear_xent_mini_bwd_prologue_kernel[prologue_grid](\n z_nv_and_grad,\n y,\n sumexp,\n z_nv_and_grad.stride(0),\n z_nv_and_grad.stride(1),\n idx_N_group=idx_N_group,\n N_group=N_group,\n V=V,\n N=N,\n )\n z_grad = z_nv_and_grad.to(x.dtype)\n\n if x.requires_grad:\n x_grad[N_group * idx_N_group : x_n_chunk.shape[0] * (idx_N_group + 1)] = z_grad @ At.T\n\n if At.requires_grad:\n torch.addmm(\n At_grad,\n x_n_chunk.T,\n z_grad,\n out=At_grad,\n )\n\n ctx.mark_non_differentiable(y)\n ctx.save_for_backward(x_grad, At_grad.to(At.dtype))\n return loss\n\n @staticmethod\n def backward(ctx, grad_output):\n x_grad, At_grad = ctx.saved_tensors\n\n return x_grad * grad_output, None, At_grad * grad_output, None, None\n\n\ndef linear_cross_entropy(x, y, At, ignore_index=-100, N_chunk_size=4096):\n return LinearCrossEntropyLoss.apply(x, y, At, ignore_index, N_chunk_size)\n", - "description_1": "Use triton language to implement two kernels for forward and backward computation of a linear cross-entropy loss. The `linear_xent_fwd_prep_bwd_kernel_matmul_t` kernel handles the forward pass and prepares data for the backward pass. It computes a matrix multiplication of inputs and weights, calculates the maximum and sum of exponentials for stable softmax computation, and stores intermediate results. The `linear_xent_mini_bwd_prologue_kernel` kernel computes the gradient for the backward pass using stored logits and labels. These kernels are called in a `LinearCrossEntropyLoss` class which applies these operations in its forward and backward methods.", - "description_2": "Use triton language to create kernels for the forward and backward passes of a linear cross-entropy loss function, utilizing matrix multiplication and softmax operations.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef fused_cross_entropy_fwd_bwd_kernel(\n output_loss_ptr,\n output_logit_grad_ptr,\n input_logit_ptr,\n input_targ_ptr,\n input_divisor_ptr,\n output_loss_stride,\n output_logit_grad_stride,\n input_logit_stride,\n input_targ_stride,\n n_cols,\n ignore_index: tl.constexpr,\n requires_grad: tl.constexpr,\n BLOCK_SIZE: tl.constexpr,\n):\n # Get pointers to current row for all inputs/outputs\n row_idx = tl.program_id(0)\n logit_grad_row_start_ptr = output_logit_grad_ptr + row_idx * output_logit_grad_stride\n logit_row_start_ptr = input_logit_ptr + row_idx * input_logit_stride\n targ_ptr = input_targ_ptr + row_idx * input_targ_stride\n loss_ptr = output_loss_ptr + row_idx * output_loss_stride\n\n col_offsets = tl.arange(0, BLOCK_SIZE)\n logit_row_ptrs = logit_row_start_ptr + col_offsets\n logit_grad_row_ptrs = logit_grad_row_start_ptr + col_offsets\n\n # Load data into SRAM\n logit_row_unnormalized = tl.load(logit_row_ptrs, mask=col_offsets < n_cols, other=float(\"-Inf\"))\n targ = tl.load(targ_ptr)\n divisor = tl.load(input_divisor_ptr)\n\n # Normalize logits and compute some useful intermediate values\n logit_row = logit_row_unnormalized - tl.max(\n logit_row_unnormalized, axis=0\n ) # Subtract max value for numerical stability\n exp_logit_row = tl.exp(logit_row)\n sum_exp_logit_row = tl.sum(exp_logit_row, axis=0)\n\n # Compute loss\n log_sum_exp_logit_row = tl.log(sum_exp_logit_row)\n logit_gt_logit = tl.sum(tl.where(targ == col_offsets, logit_row, 0.0))\n loss = log_sum_exp_logit_row - logit_gt_logit\n loss = loss / divisor\n loss = tl.where(targ == ignore_index, 0.0, loss)\n tl.store(loss_ptr, loss)\n\n # Compute gradients\n if requires_grad:\n targ_one_hot = tl.where(targ == col_offsets, 1.0, 0.0)\n grad = exp_logit_row / sum_exp_logit_row - targ_one_hot\n grad = grad / divisor\n grad = tl.where(targ == ignore_index, 0.0, grad)\n tl.store(logit_grad_row_ptrs, grad, mask=col_offsets < n_cols)\n\n\nclass FusedCrossEntropyLossFunction(torch.autograd.Function):\n\n @staticmethod\n def forward(\n ctx,\n in_feat: torch.Tensor,\n proj_weight: torch.Tensor,\n targ: torch.Tensor,\n n_loop_iters: int,\n ignore_index: int,\n reduction: str,\n ):\n n_tokens = in_feat.shape[0]\n n_classes = proj_weight.shape[0]\n\n NUM_WARPS = 16\n BLOCK_SIZE = triton.next_power_of_2(n_classes)\n\n loss = torch.empty(n_tokens, dtype=in_feat.dtype, device=in_feat.device)\n dtype = torch.get_autocast_gpu_dtype() if torch.is_autocast_enabled() else in_feat.dtype\n\n if proj_weight.requires_grad:\n grad_proj_weight = torch.zeros_like(proj_weight, dtype=dtype)\n else:\n grad_proj_weight = None\n\n if in_feat.requires_grad:\n grad_in_feat = torch.zeros_like(in_feat)\n else:\n grad_in_feat = None\n\n divisor = (\n (targ != ignore_index).sum().to(dtype)\n if reduction == \"mean\"\n else torch.ones(1, dtype=dtype, device=in_feat.device)\n )\n\n proj_weight_cast = proj_weight.to(dtype)\n\n loop_chunk_size = triton.cdiv(n_tokens, n_loop_iters)\n logits_chunk_cast = torch.zeros((loop_chunk_size, n_classes), dtype=dtype, device=in_feat.device)\n for i, in_feat_chunk in enumerate(torch.split(in_feat, loop_chunk_size)):\n token_start_idx = i * loop_chunk_size\n token_end_idx = (i + 1) * loop_chunk_size\n\n in_feat_chunk = in_feat_chunk.to(dtype)\n\n torch.matmul(in_feat_chunk, proj_weight_cast.T, out=logits_chunk_cast)\n logits_chunk = logits_chunk_cast.float()\n\n loss_chunk = loss[token_start_idx:token_end_idx]\n targ_chunk = targ[token_start_idx:token_end_idx]\n\n n_tokens_chunk = logits_chunk.shape[0]\n grad_logits_chunk = logits_chunk\n fused_cross_entropy_fwd_bwd_kernel[(n_tokens_chunk,)](\n loss_chunk,\n grad_logits_chunk,\n logits_chunk,\n targ_chunk,\n divisor,\n loss_chunk.stride(0),\n grad_logits_chunk.stride(0),\n logits_chunk.stride(0),\n targ_chunk.stride(0),\n n_classes,\n ignore_index,\n requires_grad=in_feat.requires_grad or proj_weight.requires_grad,\n num_warps=NUM_WARPS,\n BLOCK_SIZE=BLOCK_SIZE,\n )\n\n grad_logits_chunk = grad_logits_chunk.to(dtype)\n\n if in_feat.requires_grad:\n grad_in_feat[token_start_idx:token_end_idx] = grad_logits_chunk @ proj_weight_cast\n\n if proj_weight.requires_grad:\n torch.addmm(\n grad_proj_weight,\n grad_logits_chunk.T,\n in_feat_chunk,\n out=grad_proj_weight,\n )\n\n loss = loss.sum()\n\n ctx.in_feat_requires_grad = in_feat.requires_grad\n ctx.proj_weight_requires_grad = proj_weight.requires_grad\n\n if proj_weight.requires_grad and in_feat.requires_grad:\n ctx.save_for_backward(grad_in_feat, grad_proj_weight)\n elif proj_weight.requires_grad and not in_feat.requires_grad:\n ctx.save_for_backward(grad_proj_weight)\n elif not proj_weight.requires_grad and in_feat.requires_grad:\n ctx.save_for_backward(grad_in_feat)\n\n return loss\n\n @staticmethod\n def backward(ctx, grad_output):\n if ctx.in_feat_requires_grad and ctx.proj_weight_requires_grad:\n grad_in_feat, grad_proj_weight = ctx.saved_tensors\n elif not ctx.in_feat_requires_grad and ctx.proj_weight_requires_grad:\n (grad_proj_weight,) = ctx.saved_tensors\n elif ctx.in_feat_requires_grad and not ctx.proj_weight_requires_grad:\n (grad_in_feat,) = ctx.saved_tensors\n\n grad_in_feat *= grad_output\n grad_proj_weight *= grad_output\n\n return grad_in_feat, grad_proj_weight, None, None, None, None\n", - "description_1": "Use triton language to implement a fused cross entropy forward and backward kernel for loss and gradient computation given pointers to output and input memory, strides for accessing the data, and other constants for handling batch size and numerical stability. This involves computing normalized logits, calculating cross entropy loss, and optionally computing gradients, efficiently utilizing the GPU.", - "description_2": "Use triton language to implement a kernel that computes fused forward and backward passes for cross entropy loss, managing data pointers and computational requirements directly on the GPU.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.autotune(\n configs=[\n triton.Config({\"V_BLOCK_SIZE\": 256, \"N_BLOCK_SIZE\": 64, \"H_BLOCK_SIZE\": 64}),\n triton.Config({\"V_BLOCK_SIZE\": 64, \"N_BLOCK_SIZE\": 16, \"H_BLOCK_SIZE\": 64}),\n triton.Config({\"V_BLOCK_SIZE\": 64, \"N_BLOCK_SIZE\": 64, \"H_BLOCK_SIZE\": 64}),\n triton.Config({\"V_BLOCK_SIZE\": 256, \"N_BLOCK_SIZE\": 16, \"H_BLOCK_SIZE\": 256}),\n triton.Config({\"V_BLOCK_SIZE\": 512, \"N_BLOCK_SIZE\": 16, \"H_BLOCK_SIZE\": 512}),\n triton.Config({\"V_BLOCK_SIZE\": 256, \"N_BLOCK_SIZE\": 64, \"H_BLOCK_SIZE\": 64}),\n triton.Config({\"V_BLOCK_SIZE\": 256, \"N_BLOCK_SIZE\": 256, \"H_BLOCK_SIZE\": 64}),\n triton.Config({\"V_BLOCK_SIZE\": 256, \"N_BLOCK_SIZE\": 256, \"H_BLOCK_SIZE\": 256}),\n triton.Config({\"V_BLOCK_SIZE\": 256, \"N_BLOCK_SIZE\": 16, \"H_BLOCK_SIZE\": 16}),\n ],\n key=[\"V\", \"N\", \"H\"],\n reset_to_zero=[\"losses_ptr\", \"lse_ptr\"],\n)\n@triton.jit\ndef linear_xent_fwd_kernel_matmul_t(\n x_ptr,\n y_ptr,\n A_t_ptr,\n losses_ptr,\n lse_ptr,\n stride_x_N,\n stride_x_H,\n stride_A_H,\n stride_A_V,\n V: tl.constexpr,\n N: tl.constexpr,\n H: tl.constexpr,\n V_BLOCK_SIZE: tl.constexpr,\n N_BLOCK_SIZE: tl.constexpr,\n H_BLOCK_SIZE: tl.constexpr,\n):\n idx = tl.program_id(axis=0)\n tl.static_assert(N % N_BLOCK_SIZE == 0)\n tl.static_assert(V % V_BLOCK_SIZE == 0)\n tl.static_assert(H % H_BLOCK_SIZE == 0)\n x_block_ptr = tl.make_block_ptr(\n base=x_ptr,\n shape=(N, H),\n strides=(stride_x_N, stride_x_H),\n offsets=(idx * N_BLOCK_SIZE, 0),\n block_shape=(N_BLOCK_SIZE, H_BLOCK_SIZE),\n order=(1, 0),\n )\n A_block_ptr = tl.make_block_ptr(\n base=A_t_ptr,\n shape=(H, V),\n strides=(stride_A_H, stride_A_V),\n offsets=(0, 0),\n block_shape=(H_BLOCK_SIZE, V_BLOCK_SIZE),\n order=(1, 0),\n )\n offsets = idx * N_BLOCK_SIZE + tl.arange(0, N_BLOCK_SIZE)\n v_range = tl.arange(0, V_BLOCK_SIZE)\n y = tl.load(y_ptr + offsets)\n\n m = tl.zeros((N_BLOCK_SIZE,), dtype=tl.float32) - float(10e5)\n s = tl.zeros((N_BLOCK_SIZE,), dtype=tl.float32)\n loss = 0.0\n\n for _ in range(V // V_BLOCK_SIZE):\n z_j_to_k = tl.zeros((N_BLOCK_SIZE, V_BLOCK_SIZE), dtype=tl.float32)\n local_x_block_ptr = x_block_ptr\n for _ in range(H // H_BLOCK_SIZE):\n x_chunk = tl.load(local_x_block_ptr)\n A_v = tl.load(A_block_ptr)\n\n z_j_to_k = tl.dot(x_chunk, A_v, z_j_to_k)\n\n local_x_block_ptr = tl.advance(local_x_block_ptr, [0, H_BLOCK_SIZE])\n A_block_ptr = tl.advance(A_block_ptr, [H_BLOCK_SIZE, 0])\n\n m_new = tl.maximum(m, tl.max(z_j_to_k, 1))\n\n s_update = tl.sum(tl.exp(z_j_to_k - m_new[:, None]), axis=1)\n s = s * tl.exp(m - m_new) + s_update\n\n mask = y[:, None] == v_range[None, :]\n loss -= tl.sum(tl.where(mask, z_j_to_k, float(0.0))) / N\n\n m = m_new\n A_block_ptr = tl.advance(A_block_ptr, [-H_BLOCK_SIZE * (H // H_BLOCK_SIZE), V_BLOCK_SIZE])\n v_range = v_range + V_BLOCK_SIZE\n\n lse = m + tl.log(s)\n loss += tl.sum(lse) / N\n tl.store(losses_ptr + idx, loss)\n tl.store(lse_ptr + offsets, lse)\n\n\n@triton.autotune(\n configs=[\n triton.Config({\"V_BLOCK_SIZE\": 16, \"N_BLOCK_SIZE\": 16, \"H_BLOCK_SIZE\": 16}),\n triton.Config({\"V_BLOCK_SIZE\": 32, \"N_BLOCK_SIZE\": 32, \"H_BLOCK_SIZE\": 32}),\n triton.Config({\"V_BLOCK_SIZE\": 32, \"N_BLOCK_SIZE\": 32, \"H_BLOCK_SIZE\": 64}),\n triton.Config({\"V_BLOCK_SIZE\": 32, \"N_BLOCK_SIZE\": 64, \"H_BLOCK_SIZE\": 64}),\n triton.Config({\"V_BLOCK_SIZE\": 64, \"N_BLOCK_SIZE\": 64, \"H_BLOCK_SIZE\": 64}),\n triton.Config({\"V_BLOCK_SIZE\": 16, \"N_BLOCK_SIZE\": 16, \"H_BLOCK_SIZE\": 128}),\n triton.Config({\"V_BLOCK_SIZE\": 16, \"N_BLOCK_SIZE\": 16, \"H_BLOCK_SIZE\": 256}),\n triton.Config({\"V_BLOCK_SIZE\": 16, \"N_BLOCK_SIZE\": 16, \"H_BLOCK_SIZE\": 512}),\n triton.Config({\"V_BLOCK_SIZE\": 16, \"N_BLOCK_SIZE\": 64, \"H_BLOCK_SIZE\": 64}),\n triton.Config({\"V_BLOCK_SIZE\": 16, \"N_BLOCK_SIZE\": 128, \"H_BLOCK_SIZE\": 128}),\n triton.Config({\"V_BLOCK_SIZE\": 16, \"N_BLOCK_SIZE\": 256, \"H_BLOCK_SIZE\": 128}),\n ],\n key=[\"V\", \"N\", \"H\"],\n reset_to_zero=[\"A_grad_ptr\"],\n)\n@triton.jit()\ndef linear_xent_bwd_kernel_matmul_t_dA(\n x_ptr,\n y_ptr,\n A_t_ptr,\n lse_global_ptr,\n A_grad_ptr,\n stride_x_N,\n stride_x_H,\n stride_A_H,\n stride_A_V,\n V: tl.constexpr,\n N: tl.constexpr,\n H: tl.constexpr,\n V_BLOCK_SIZE: tl.constexpr = 16,\n N_BLOCK_SIZE: tl.constexpr = 16,\n H_BLOCK_SIZE: tl.constexpr = 16,\n):\n idx_V = tl.program_id(axis=0)\n idx_H_grad = tl.program_id(axis=1)\n\n tl.static_assert(N % N_BLOCK_SIZE == 0)\n tl.static_assert(V % V_BLOCK_SIZE == 0)\n tl.static_assert(H % H_BLOCK_SIZE == 0)\n\n N_offsets = tl.arange(0, N_BLOCK_SIZE)\n V_offsets = idx_V * V_BLOCK_SIZE + tl.arange(0, V_BLOCK_SIZE)\n\n A_fwd_block_ptr = tl.make_block_ptr(\n base=A_t_ptr,\n shape=(H, V),\n strides=(stride_A_H, stride_A_V),\n offsets=(0, idx_V * V_BLOCK_SIZE),\n block_shape=(H_BLOCK_SIZE, V_BLOCK_SIZE),\n order=(1, 0),\n )\n A_grad_block_ptr = tl.make_block_ptr(\n base=A_grad_ptr,\n shape=(H, V),\n strides=(stride_A_H, stride_A_V),\n offsets=(idx_H_grad * H_BLOCK_SIZE, idx_V * V_BLOCK_SIZE),\n block_shape=(H_BLOCK_SIZE, V_BLOCK_SIZE),\n order=(1, 0),\n )\n\n x_fwd_block_ptr = tl.make_block_ptr(\n base=x_ptr,\n shape=(N, H),\n strides=(stride_x_N, stride_x_H),\n offsets=(0 * N_BLOCK_SIZE, 0),\n block_shape=(N_BLOCK_SIZE, H_BLOCK_SIZE),\n order=(1, 0),\n )\n x_bwd_block_ptr = tl.make_block_ptr(\n base=x_ptr,\n shape=(N, H),\n strides=(stride_x_N, stride_x_H),\n offsets=(0 * N_BLOCK_SIZE, idx_H_grad * H_BLOCK_SIZE),\n block_shape=(N_BLOCK_SIZE, H_BLOCK_SIZE),\n order=(1, 0),\n )\n AgradT = tl.zeros((H_BLOCK_SIZE, V_BLOCK_SIZE), tl.float16)\n\n for idx_N in range(N // N_BLOCK_SIZE):\n y = tl.load(y_ptr + N_offsets)\n lse = tl.load(lse_global_ptr + N_offsets)\n\n z_j_to_k = tl.zeros((N_BLOCK_SIZE, V_BLOCK_SIZE), dtype=tl.float32)\n for _ in range(H // H_BLOCK_SIZE):\n x_chunk = tl.load(x_fwd_block_ptr)\n A_v = tl.load(A_fwd_block_ptr)\n\n z_j_to_k = tl.dot(x_chunk, A_v, z_j_to_k)\n\n x_fwd_block_ptr = tl.advance(x_fwd_block_ptr, [0, H_BLOCK_SIZE])\n A_fwd_block_ptr = tl.advance(A_fwd_block_ptr, [H_BLOCK_SIZE, 0])\n\n mask = (y[:, None] == V_offsets[None, :])[:, :, None]\n softmax_z = (z_j_to_k - lse[:, None]).exp().to(tl.float16)\n\n x_chunk_bwd = tl.load(x_bwd_block_ptr)\n AgradT += (tl.dot(x_chunk_bwd.trans(), softmax_z) / N).to(tl.float16)\n AgradT -= (tl.sum(tl.where(mask, x_chunk_bwd[:, None, :], 0.0), axis=0).trans() / N).to(tl.float16)\n\n x_bwd_block_ptr = tl.advance(x_bwd_block_ptr, [N_BLOCK_SIZE, 0])\n x_fwd_block_ptr = tl.advance(x_fwd_block_ptr, [N_BLOCK_SIZE, -H])\n A_fwd_block_ptr = tl.advance(A_fwd_block_ptr, [-H, 0])\n N_offsets += N_BLOCK_SIZE\n\n tl.store(A_grad_block_ptr, AgradT)\n\n\n@triton.autotune(\n configs=[\n triton.Config({\"V_BLOCK_SIZE\": 16, \"N_BLOCK_SIZE\": 16, \"H_BLOCK_SIZE\": 16}),\n triton.Config({\"V_BLOCK_SIZE\": 16, \"N_BLOCK_SIZE\": 16, \"H_BLOCK_SIZE\": 16}),\n triton.Config({\"V_BLOCK_SIZE\": 16, \"N_BLOCK_SIZE\": 16, \"H_BLOCK_SIZE\": 16}),\n triton.Config({\"V_BLOCK_SIZE\": 32, \"N_BLOCK_SIZE\": 32, \"H_BLOCK_SIZE\": 32}),\n triton.Config({\"V_BLOCK_SIZE\": 32, \"N_BLOCK_SIZE\": 32, \"H_BLOCK_SIZE\": 64}),\n triton.Config({\"V_BLOCK_SIZE\": 32, \"N_BLOCK_SIZE\": 64, \"H_BLOCK_SIZE\": 64}),\n triton.Config({\"V_BLOCK_SIZE\": 64, \"N_BLOCK_SIZE\": 64, \"H_BLOCK_SIZE\": 64}),\n triton.Config({\"V_BLOCK_SIZE\": 512, \"N_BLOCK_SIZE\": 16, \"H_BLOCK_SIZE\": 16}),\n triton.Config({\"V_BLOCK_SIZE\": 16, \"N_BLOCK_SIZE\": 16, \"H_BLOCK_SIZE\": 128}),\n triton.Config({\"V_BLOCK_SIZE\": 16, \"N_BLOCK_SIZE\": 16, \"H_BLOCK_SIZE\": 256}),\n triton.Config({\"V_BLOCK_SIZE\": 16, \"N_BLOCK_SIZE\": 16, \"H_BLOCK_SIZE\": 512}),\n triton.Config({\"V_BLOCK_SIZE\": 16, \"N_BLOCK_SIZE\": 16, \"H_BLOCK_SIZE\": 512}),\n triton.Config({\"V_BLOCK_SIZE\": 128, \"N_BLOCK_SIZE\": 16, \"H_BLOCK_SIZE\": 128}),\n triton.Config({\"V_BLOCK_SIZE\": 256, \"N_BLOCK_SIZE\": 16, \"H_BLOCK_SIZE\": 16}),\n ],\n key=[\"V\", \"N\", \"H\"],\n reset_to_zero=[\"x_grad_ptr\"],\n)\n@triton.jit()\ndef linear_xent_bwd_kernel_matmul_t_dx(\n x_ptr,\n y_ptr,\n A_t_ptr,\n lse_global_ptr,\n x_grad_ptr,\n stride_x_N,\n stride_x_H,\n stride_A_H,\n stride_A_V,\n V: tl.constexpr,\n N: tl.constexpr,\n H: tl.constexpr,\n V_BLOCK_SIZE: tl.constexpr = 16,\n N_BLOCK_SIZE: tl.constexpr = 16,\n H_BLOCK_SIZE: tl.constexpr = 16,\n):\n idx_N = tl.program_id(axis=0)\n idx_H_grad = tl.program_id(axis=1)\n\n tl.static_assert(N % N_BLOCK_SIZE == 0)\n tl.static_assert(V % V_BLOCK_SIZE == 0)\n tl.static_assert(H % H_BLOCK_SIZE == 0)\n\n N_offsets = idx_N * N_BLOCK_SIZE + tl.arange(0, N_BLOCK_SIZE)\n V_offsets = tl.arange(0, V_BLOCK_SIZE)\n\n y = tl.load(y_ptr + N_offsets)\n lse = tl.load(lse_global_ptr + N_offsets)\n\n x_grad_block_ptr = tl.make_block_ptr(\n base=x_grad_ptr,\n shape=(N, H),\n strides=(stride_x_N, stride_x_H),\n offsets=(idx_N * N_BLOCK_SIZE, idx_H_grad * H_BLOCK_SIZE),\n block_shape=(N_BLOCK_SIZE, H_BLOCK_SIZE),\n order=(1, 0),\n )\n\n x_block_ptr = tl.make_block_ptr(\n base=x_ptr,\n shape=(N, H),\n strides=(stride_x_N, stride_x_H),\n offsets=(idx_N * N_BLOCK_SIZE, 0 * H_BLOCK_SIZE),\n block_shape=(N_BLOCK_SIZE, H_BLOCK_SIZE),\n order=(1, 0),\n )\n A_fwd_block_ptr = tl.make_block_ptr(\n base=A_t_ptr,\n shape=(H, V),\n strides=(stride_A_H, stride_A_V),\n offsets=(0 * H_BLOCK_SIZE, 0 * V_BLOCK_SIZE),\n block_shape=(H_BLOCK_SIZE, V_BLOCK_SIZE),\n order=(1, 0),\n )\n A_bwd_block_ptr = tl.make_block_ptr(\n base=A_t_ptr,\n shape=(H, V),\n strides=(stride_A_H, stride_A_V),\n offsets=(idx_H_grad * H_BLOCK_SIZE, 0 * V_BLOCK_SIZE),\n block_shape=(H_BLOCK_SIZE, V_BLOCK_SIZE),\n order=(1, 0),\n )\n\n x_grad = tl.zeros((N_BLOCK_SIZE, H_BLOCK_SIZE), tl.float16)\n\n for idx_V in range(V // V_BLOCK_SIZE):\n z_j_to_k = tl.zeros((N_BLOCK_SIZE, V_BLOCK_SIZE), dtype=tl.float32)\n for idx_H_1 in range(H // H_BLOCK_SIZE):\n x_chunk = tl.load(x_block_ptr)\n A_v_fwd = tl.load(A_fwd_block_ptr)\n\n z_j_to_k = tl.dot(x_chunk, A_v_fwd, z_j_to_k)\n\n x_block_ptr = tl.advance(x_block_ptr, [0, H_BLOCK_SIZE])\n A_fwd_block_ptr = tl.advance(A_fwd_block_ptr, [H_BLOCK_SIZE, 0])\n\n mask = (y[:, None] == V_offsets[None, :])[:, :, None]\n softmax_z = (z_j_to_k - lse[:, None]).exp().to(tl.float16)\n\n A_v = tl.load(A_bwd_block_ptr).trans()\n x_grad += (tl.dot(softmax_z, A_v) / N).to(tl.float16)\n x_grad -= (tl.sum(tl.where(mask, A_v[None, :, :], 0.0), axis=1) / N).to(tl.float16)\n\n A_bwd_block_ptr = tl.advance(A_bwd_block_ptr, [0, V_BLOCK_SIZE])\n A_fwd_block_ptr = tl.advance(A_fwd_block_ptr, [-H, V_BLOCK_SIZE])\n x_block_ptr = tl.advance(x_block_ptr, [0, -H])\n V_offsets += V_BLOCK_SIZE\n tl.store(x_grad_block_ptr, x_grad)\n\n\nclass LinearCrossEntropyLoss(torch.autograd.Function):\n\n @staticmethod\n def forward(\n ctx,\n x,\n y,\n At,\n ignore_index=-100,\n ):\n N, H = x.shape\n H_A, V = At.shape\n assert H_A == H\n assert y.shape == (N,)\n x = x.contiguous()\n y = y.contiguous()\n At = At.contiguous()\n\n assert V % 16 == 0, f\"V is {V}\"\n assert N % 16 == 0, f\"N is {N}\"\n assert H % 16 == 0, f\"H is {H}\"\n\n lse_global = torch.zeros(N, dtype=torch.float32, device=x.device)\n losses = torch.zeros(N // 16, dtype=torch.float32, device=x.device)\n\n grid = lambda meta: (triton.cdiv(N, meta[\"N_BLOCK_SIZE\"]),)\n\n with torch.cuda.device(x.device.index):\n linear_xent_fwd_kernel_matmul_t[grid](\n x, y, At, losses, lse_global, x.stride(0), x.stride(1), At.stride(0), At.stride(1), V=V, N=N, H=H\n )\n\n ctx.save_for_backward(x, y, At, lse_global)\n\n return losses.sum()\n\n @staticmethod\n def backward(ctx, grad_output):\n x, y, At, lse_global = ctx.saved_tensors\n N, H = x.shape\n _, V = At.shape\n\n xgrad = torch.zeros_like(x)\n Atgrad = torch.zeros_like(At)\n\n with torch.cuda.device(x.device.index):\n grid = lambda meta: (triton.cdiv(V, meta[\"V_BLOCK_SIZE\"]), triton.cdiv(H, meta[\"H_BLOCK_SIZE\"]))\n linear_xent_bwd_kernel_matmul_t_dA[grid](\n x,\n y,\n At,\n lse_global,\n Atgrad,\n x.stride(0),\n x.stride(1),\n At.stride(0),\n At.stride(1),\n V=V,\n N=N,\n H=H,\n )\n grid = lambda meta: (triton.cdiv(N, meta[\"N_BLOCK_SIZE\"]), triton.cdiv(H, meta[\"H_BLOCK_SIZE\"]))\n linear_xent_bwd_kernel_matmul_t_dx[grid](\n x,\n y,\n At,\n lse_global,\n xgrad,\n x.stride(0),\n x.stride(1),\n At.stride(0),\n At.stride(1),\n V=V,\n N=N,\n H=H,\n )\n\n ctx.mark_non_differentiable(y)\n return xgrad * grad_output, None, Atgrad * grad_output, None\n\n\ndef linear_cross_entropy(x, y, At):\n return LinearCrossEntropyLoss.apply(x, y, At)\n", - "description_1": "Use triton language to implement a linear cross-entropy loss function with forward and backward passes. The forward kernel 'linear_xent_fwd_kernel_matmul_t' computes the loss and log-sum-exp for given inputs and weights. The backward kernels 'linear_xent_bwd_kernel_matmul_t_dA' and 'linear_xent_bwd_kernel_matmul_t_dx' compute the gradients with respect to the weights and inputs, respectively. The function 'linear_cross_entropy' serves as a wrapper for these operations, using the 'LinearCrossEntropyLoss' class to manage the forward and backward passes.", - "description_2": "Use triton language to create a linear cross-entropy loss function with forward and backward kernels for efficient GPU computation.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\nimport math\n\n@triton.autotune(\n configs=fwd_configs,\n key=[\"V\", \"N\", \"H\", \"monitoring\"],\n prune_configs_by={\n \"early_config_prune\": early_config_prune,\n },\n reset_to_zero=[\"z_nv_ptr\", \"logit_norm_ptr\", \"lse_ptr\", \"m_ptr\", \"losses_ptr\"],\n)\n@triton.jit\ndef linear_xent_fwd_prep_bwd_kernel_matmul_t(\n x_ptr,\n y_ptr,\n A_t_ptr,\n z_nv_ptr,\n losses_ptr,\n lse_ptr,\n m_ptr,\n logit_norm_ptr,\n stride_x_N,\n stride_x_H,\n stride_A_H,\n stride_A_V,\n stride_z_N,\n stride_z_V,\n stride_lse_N,\n stride_lse_B,\n stride_loss_Nb,\n stride_loss_B,\n stride_norm_N,\n stride_norm_V,\n reduction_ptr,\n monitoring: tl.constexpr,\n ignore_index: tl.constexpr,\n logit_scale: tl.constexpr,\n idx_N_group,\n N_group: tl.constexpr,\n V: tl.constexpr,\n N: tl.constexpr,\n H: tl.constexpr,\n V_BLOCK_SIZE: tl.constexpr,\n N_BLOCK_SIZE: tl.constexpr,\n H_BLOCK_SIZE: tl.constexpr,\n GROUP_SIZE: tl.constexpr,\n):\n # Kernel logic here...\n\n@triton.jit()\ndef linear_xent_bwd_kernel_matmul_t_epilogue_dx(\n z_nv_ptr,\n y_ptr,\n A_t_ptr,\n x_grad_ptr,\n lse_ptr,\n stride_x_N,\n stride_x_H,\n stride_A_H,\n stride_A_V,\n stride_z_N,\n stride_z_V,\n reduction_ptr,\n logit_scale: tl.constexpr,\n z_regularization: tl.constexpr,\n fp32_grad_accumulators: tl.constexpr,\n ignore_index: tl.constexpr,\n idx_N_group,\n N_group: tl.constexpr,\n V: tl.constexpr,\n N: tl.constexpr,\n H: tl.constexpr,\n V_BLOCK_SIZE: tl.constexpr,\n N_BLOCK_SIZE: tl.constexpr,\n H_BLOCK_SIZE: tl.constexpr,\n GROUP_SIZE: tl.constexpr,\n SPLIT_N: tl.constexpr,\n SPLIT_V: tl.constexpr,\n):\n # Kernel logic here...\n\n@triton.jit()\ndef linear_xent_bwd_kernel_matmul_t_epilogue_dA(\n z_nv_ptr,\n y_ptr,\n x_ptr,\n A_grad_ptr,\n lse_ptr,\n entropy_ptr,\n stride_x_N,\n stride_x_H,\n stride_A_H,\n stride_A_V,\n stride_z_N,\n stride_z_V,\n stride_ent_H,\n stride_ent_V,\n reduction_ptr,\n monitoring: tl.constexpr,\n logit_scale: tl.constexpr,\n z_regularization: tl.constexpr,\n fp32_grad_accumulators: tl.constexpr,\n ignore_index: tl.constexpr,\n idx_N_group,\n N_group: tl.constexpr,\n V: tl.constexpr,\n N: tl.constexpr,\n H: tl.constexpr,\n V_BLOCK_SIZE: tl.constexpr,\n N_BLOCK_SIZE: tl.constexpr,\n H_BLOCK_SIZE: tl.constexpr,\n GROUP_SIZE: tl.constexpr,\n SPLIT_N: tl.constexpr,\n SPLIT_V: tl.constexpr,\n):\n # Kernel logic here...\n\n@triton.autotune(\n configs=bwd_configs,\n key=[\"V\", \"N\", \"H\", \"monitoring\"],\n prune_configs_by={\n \"early_config_prune\": early_config_prune,\n },\n reset_to_zero=[\"x_grad\", \"At_grad\", \"logit_entropy_local\"],\n)\n@triton.jit()\ndef linear_xent_bwd_dispatcher(\n logits_ptr,\n y_ptr,\n x_ptr,\n A_t_ptr,\n x_grad,\n At_grad,\n lse_global,\n logit_entropy_local,\n stride_x_N,\n stride_x_H,\n stride_A_H,\n stride_A_V,\n stride_z_N,\n stride_z_V,\n stride_ent_H,\n stride_ent_V,\n reduction_ptr,\n monitoring: tl.constexpr,\n logit_scale: tl.constexpr,\n z_regularization: tl.constexpr,\n fp32_grad_accumulators: tl.constexpr,\n ignore_index: tl.constexpr,\n idx_N_group,\n N_group,\n V: tl.constexpr,\n N: tl.constexpr,\n H: tl.constexpr,\n V_BLOCK_SIZE: tl.constexpr = 128,\n N_BLOCK_SIZE: tl.constexpr = 128,\n H_BLOCK_SIZE: tl.constexpr = 128,\n GROUP_SIZE: tl.constexpr = 32,\n SPLIT_N: tl.constexpr = 2,\n SPLIT_V: tl.constexpr = 2,\n):\n # Dispatcher logic here...\n\nclass LinearXentImplementation(torch.autograd.Function):\n @staticmethod\n def forward(\n ctx,\n x_in,\n y,\n At,\n ignore_index=-100,\n z_regularization: float = 0.0,\n logit_scale: float = 1.0,\n N_chunk_size: int = 4096,\n monitoring: bool = True,\n ):\n # Forward logic here...\n\n @staticmethod\n def backward(ctx, grad_output, void0, void1, void2, void3):\n x_grad, At_grad = ctx.saved_tensors\n return x_grad.mul_(grad_output), None, At_grad.mul_(grad_output), None, None, None, None, None\n\ndef linear_cross_entropy(\n x,\n y,\n At,\n ignore_index=-100,\n z_regularization: float = 0.0,\n logit_scale: float = 1.0,\n N_chunk_size: int = 4096,\n monitoring: bool = False,\n) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:\n return LinearXentImplementation.apply(\n x, y, At, ignore_index, z_regularization, logit_scale, N_chunk_size, monitoring\n )\n\nclass LinearCrossEntropyLoss(torch.nn.Linear):\n def __init__(\n self,\n in_features: int,\n out_features: int,\n device=None,\n dtype=None,\n ignore_index: int = -100,\n logit_scale: float = 1.0,\n z_regularization: float = 0.0,\n N_chunk_size: int = 4096,\n init_method=None,\n ):\n factory_kwargs = {\"device\": device, \"dtype\": dtype}\n torch.nn.Module.__init__(self)\n\n self.in_features = in_features\n self.out_features = out_features\n self.weight = torch.nn.Parameter(torch.empty((in_features, out_features), **factory_kwargs))\n\n self.logit_scale = logit_scale\n self.ignore_index = ignore_index\n self.z_regularization = z_regularization\n self.N_chunk_size = N_chunk_size\n\n self.monitoring = False\n self.latest_metrics = {}\n self.init_method = init_method\n\n self.reset_parameters()\n\n def reset_parameters(self) -> None:\n if self.init_method is not None:\n self.init_method(self.weight)\n else:\n std = math.sqrt(1 / self.in_features)\n torch.nn.init.trunc_normal_(self.weight, mean=0.0, std=std, a=-3 * std, b=3 * std)\n\n def forward(self, x, y):\n loss, z_reg, logit_max, logit_ent, logit_norm = LinearXentImplementation.apply(\n x,\n y,\n self.weight,\n self.ignore_index,\n self.z_regularization,\n self.logit_scale,\n self.N_chunk_size,\n self.monitoring,\n )\n if self.monitoring:\n metrics = {\n \"logit_norm\": logit_norm,\n \"logit_max\": logit_max,\n \"logit_entropy\": logit_ent,\n \"z_value\": z_reg,\n }\n self.latest_metrics = metrics\n return loss\n", - "description_1": "Use triton language to implement forward and backward passes for a linear cross-entropy loss function, where the kernel computes the loss and gradients efficiently for large input sizes and supports autotuning for optimal performance. The implementation includes three main kernels: a forward kernel that also prepares for the backward pass, a backward kernel for gradient calculation with respect to inputs, and another for weight gradients, along with a dispatcher to manage kernel execution based on input parameters.", - "description_2": "Use triton language to create a linear cross-entropy loss layer with forward and backward kernels optimized for large-scale input, and implement autotuning for performance efficiency.", - "difficulty": 5 - }, - { - "code": "import triton\nimport triton.language as tl\nimport paddle\nfrom paddle import Tensor\n\n@triton.jit\ndef _causal_conv1d_varlen_states(\n X,\n CU_SEQLENS,\n STATES,\n state_len,\n dim,\n stride_x_seqlen,\n stride_x_dim,\n stride_states_batch,\n stride_states_seqlen,\n stride_states_dim,\n BLOCK_M: tl.constexpr,\n BLOCK_N: tl.constexpr,\n):\n batch_idx = tl.program_id(2)\n STATES += batch_idx * stride_states_batch\n end_idx = tl.load(CU_SEQLENS + batch_idx + 1)\n start_idx = tl.maximum(tl.load(CU_SEQLENS + batch_idx), end_idx - state_len)\n rows = end_idx - (tl.program_id(1) + 1) * BLOCK_M + tl.arange(0, BLOCK_M)\n cols = tl.program_id(0) * BLOCK_N + tl.arange(0, BLOCK_N)\n x = tl.load(\n X + rows[:, None] * stride_x_seqlen + cols[None, :] * stride_x_dim,\n mask=(rows[:, None] >= start_idx) & (cols[None, :] < dim),\n other=0,\n )\n rows_states = state_len - (tl.program_id(1) + 1) * BLOCK_M + tl.arange(0, BLOCK_M)\n tl.store(\n STATES + rows_states[:, None] * stride_states_seqlen + cols[None, :] * stride_states_dim,\n x,\n mask=(rows_states[:, None] >= 0) & (cols[None, :] < dim),\n )\n\ndef causal_conv1d_varlen_states(x: Tensor, cu_seqlens: Tensor, state_len: int) -> Tensor:\n \"\"\"\n Forward pass only, does not support backward pass.\n\n Parameters:\n x: (total_tokens, dim)\n cu_seqlens: (batch + 1), must already be sorted. The cumulative sum of the sequence lengths, starting from 0.\n state_len: int. For each cu_seqlens, how many elements from x should be copied to the state.\n If some of those elements belong to a different sequence, the value of the states will be zero.\n Return:\n states: (batch, dim, state_len)\n \"\"\"\n _, dim = x.shape\n batch = cu_seqlens.shape[0] - 1\n cu_seqlens = cu_seqlens.contiguous()\n states = paddle.empty([batch, state_len, dim], dtype=x.dtype).transpose([0, 2, 1])\n BLOCK_M = min(triton.next_power_of_2(state_len), 16)\n BLOCK_N = min(triton.next_power_of_2(dim), 256)\n grid = (triton.cdiv(dim, BLOCK_N), triton.cdiv(state_len, BLOCK_M), batch)\n _causal_conv1d_varlen_states[grid](\n x,\n cu_seqlens,\n states,\n state_len,\n dim,\n x.strides[0],\n x.strides[1],\n states.strides[0],\n states.strides[2],\n states.strides[1],\n BLOCK_M=BLOCK_M,\n BLOCK_N=BLOCK_N,\n )\n return states\n", - "description_1": "Use triton language to implement a causal 1D convolution with variable length states. The kernel function '_causal_conv1d_varlen_states' takes 11 parameters: X (input tensor), CU_SEQLENS (cumulative sequence lengths), STATES (output tensor), state_len (length of the state), dim (dimension of the input), stride_x_seqlen, stride_x_dim, stride_states_batch, stride_states_seqlen, stride_states_dim (stride values for memory access), and two block sizes BLOCK_M and BLOCK_N. The function 'causal_conv1d_varlen_states' prepares the input and output tensors, calculates grid dimensions, and launches the Triton kernel.", - "description_2": "Use triton language to perform a causal 1D convolution operation on input data with variable sequence lengths, storing results in a state tensor. The operation is optimized using Triton's parallel execution capabilities.", - "difficulty": 4 - }, - { - "code": "import triton\nimport triton.language as tl\nimport paddle\n\n@triton.jit\ndef liger_cross_entropy_kernel(\n X_ptr,\n X_stride,\n Y_ptr,\n Y_stride,\n loss_ptr,\n loss_stride,\n n_cols,\n n_non_ignore,\n ignore_index,\n label_smoothing: tl.constexpr,\n BLOCK_SIZE: tl.constexpr,\n):\n # Kernel for computing cross entropy loss and input gradients.\n program_id = tl.program_id(0).to(tl.int64)\n\n Y_ptr += program_id * Y_stride\n y = tl.load(Y_ptr)\n\n X_ptr += program_id * X_stride\n\n if y == ignore_index:\n for i in range(0, n_cols, BLOCK_SIZE):\n X_offsets = i + tl.arange(0, BLOCK_SIZE)\n tl.store(X_ptr + X_offsets, 0.0, mask=X_offsets < n_cols)\n return\n\n loss_ptr += program_id * loss_stride\n\n m = float(\"-inf\")\n d = 0.0\n ori_X_y = tl.load(X_ptr + y)\n\n scaled_x_sum = 0.0\n eps = label_smoothing / n_cols\n\n for i in range(0, n_cols, BLOCK_SIZE):\n X_offsets = i + tl.arange(0, BLOCK_SIZE)\n X_block = tl.load(X_ptr + X_offsets, mask=X_offsets < n_cols, other=float(\"-inf\"))\n block_max = tl.max(X_block)\n if label_smoothing > 0:\n scaled_x_sum += tl.sum(tl.where(X_offsets < n_cols, -eps * X_block, 0.0))\n m_new = tl.maximum(m, block_max)\n d = d * tl.exp(m - m_new) + tl.sum(tl.exp(X_block - m_new))\n m = m_new\n\n for i in range(0, n_cols, BLOCK_SIZE):\n X_offsets = i + tl.arange(0, BLOCK_SIZE)\n X_block = tl.load(X_ptr + X_offsets, mask=X_offsets < n_cols, other=float(\"-inf\"))\n X_block = (tl.exp(X_block - m) / d - eps) / (n_non_ignore)\n tl.store(X_ptr + X_offsets, X_block, mask=X_offsets < n_cols)\n\n tl.debug_barrier()\n\n loss = -(ori_X_y - m - tl.log(d))\n\n if label_smoothing > 0:\n smooth_loss = scaled_x_sum + label_smoothing * (m + tl.log(d))\n loss = loss * (1 - label_smoothing) + smooth_loss\n\n X_y = tl.load(X_ptr + y)\n X_y += -(1 - label_smoothing) / (n_non_ignore)\n\n tl.store(loss_ptr, loss)\n tl.store(X_ptr + y, X_y)\n\n@triton.jit\ndef element_mul_kernel(\n X_ptr,\n X_stride,\n grad_output_ptr,\n n_cols,\n BLOCK_SIZE: tl.constexpr,\n):\n # Kernel for element-wise multiplication of a tensor.\n program_id = tl.program_id(0).to(tl.int64)\n\n X_ptr += program_id * X_stride\n\n grad_output = tl.load(grad_output_ptr)\n\n for i in range(0, n_cols, BLOCK_SIZE):\n X_offsets = i + tl.arange(0, BLOCK_SIZE)\n X_block = tl.load(X_ptr + X_offsets, mask=X_offsets < n_cols)\n tl.store(X_ptr + X_offsets, X_block * grad_output, mask=X_offsets < n_cols)\n\ndef cross_entropy_forward(_input, target, ignore_index, label_smoothing):\n BT, V = _input.shape\n n_rows = BT\n\n BLOCK_SIZE = min(65536 // 2, triton.next_power_of_2(V))\n\n loss_1d = paddle.zeros(n_rows, dtype=_input.dtype)\n\n n_non_ignore = (target != ignore_index).sum().item()\n\n if _input.strides[-1] != 1:\n _input = _input.contiguous()\n if target.strides[-1] != 1:\n target = target.contiguous()\n\n liger_cross_entropy_kernel[(n_rows,)](\n X_ptr=_input,\n X_stride=_input.strides[-2],\n Y_ptr=target,\n Y_stride=target.strides[-1],\n loss_ptr=loss_1d,\n loss_stride=loss_1d.strides[-1],\n n_cols=V,\n n_non_ignore=n_non_ignore,\n ignore_index=ignore_index,\n label_smoothing=label_smoothing,\n BLOCK_SIZE=BLOCK_SIZE,\n num_warps=32,\n )\n\n loss = paddle.sum(loss_1d) / n_non_ignore\n return loss, _input\n\ndef cross_entropy_backward(_input, grad_output):\n if paddle.equal(grad_output, paddle.to_tensor(1.0, dtype=grad_output.dtype)):\n pass\n else:\n BT, V = _input.shape\n n_rows = BT\n BLOCK_SIZE = min(65536 // 2, triton.next_power_of_2(V))\n\n element_mul_kernel[(n_rows,)](\n _input,\n _input.strides[-2],\n grad_output,\n V,\n BLOCK_SIZE=BLOCK_SIZE,\n num_warps=32,\n )\n\n return _input\n", - "description_1": "Use triton language to create a cross entropy kernel with 10 parameters: pointers to input and target tensors, strides, a pointer for loss storage, column count, non-ignore count, ignore index, label smoothing constant, and block size for operations. Use another kernel for element-wise multiplication with 5 parameters: input tensor pointer, stride, output gradient pointer, column count, and block size.", - "description_2": "Use triton language to compute cross entropy loss and gradients with 10 parameters, and perform element-wise multiplication with 5 parameters.", - "difficulty": 3 - }, - { - "code": "import paddle\nimport triton\nimport triton.language as tl\n\ndef is_hip():\n return triton.runtime.driver.active.get_current_target().backend == \"hip\"\n\n@triton.jit\ndef _attn_fwd_inner(\n acc,\n l_i,\n m_i,\n q,\n K_block_ptr,\n V_block_ptr,\n start_m,\n qk_scale,\n BLOCK_M: tl.constexpr,\n HEAD_DIM: tl.constexpr,\n BLOCK_N: tl.constexpr,\n STAGE: tl.constexpr,\n offs_m: tl.constexpr,\n offs_n: tl.constexpr,\n N_CTX: tl.constexpr,\n fp8_v: tl.constexpr,\n):\n if STAGE == 1:\n lo, hi = 0, start_m * BLOCK_M\n elif STAGE == 2:\n lo, hi = start_m * BLOCK_M, (start_m + 1) * BLOCK_M\n lo = tl.multiple_of(lo, BLOCK_M)\n else:\n lo, hi = 0, N_CTX\n K_block_ptr = tl.advance(K_block_ptr, (0, lo))\n V_block_ptr = tl.advance(V_block_ptr, (lo, 0))\n for start_n in range(lo, hi, BLOCK_N):\n start_n = tl.multiple_of(start_n, BLOCK_N)\n k = tl.load(K_block_ptr)\n qk = tl.dot(q, k)\n if STAGE == 2:\n mask = offs_m[:, None] >= (start_n + offs_n[None, :])\n qk = qk * qk_scale + tl.where(mask, 0, -1.0e6)\n m_ij = tl.maximum(m_i, tl.max(qk, 1))\n qk -= m_ij[:, None]\n else:\n m_ij = tl.maximum(m_i, tl.max(qk, 1) * qk_scale)\n qk = qk * qk_scale - m_ij[:, None]\n p = tl.math.exp2(qk)\n l_ij = tl.sum(p, 1)\n alpha = tl.math.exp2(m_i - m_ij)\n l_i = l_i * alpha + l_ij\n acc = acc * alpha[:, None]\n v = tl.load(V_block_ptr)\n if fp8_v:\n p = p.to(tl.float8e5)\n else:\n p = p.to(tl.float16)\n acc = tl.dot(p, v, acc)\n m_i = m_ij\n V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0))\n K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N))\n return acc, l_i, m_i\n\nconfigs = [\n triton.Config({\"BLOCK_M\": BM, \"BLOCK_N\": BN}, num_stages=s, num_warps=w)\n for BM in [64, 128]\n for BN in [32, 64]\n for s in ([1] if is_hip() else [3, 4, 7])\n for w in [4, 8]\n]\n\ndef keep(conf):\n BLOCK_M = conf.kwargs[\"BLOCK_M\"]\n BLOCK_N = conf.kwargs[\"BLOCK_N\"]\n if BLOCK_M * BLOCK_N < 128 * 128 and conf.num_warps == 8:\n return False\n return True\n\n@triton.autotune(list(filter(keep, configs)), key=[\"N_CTX\", \"HEAD_DIM\"])\n@triton.jit\ndef _attn_fwd(\n Q,\n K,\n V,\n sm_scale,\n M,\n Out,\n stride_qz,\n stride_qh,\n stride_qm,\n stride_qk,\n stride_kz,\n stride_kh,\n stride_kn,\n stride_kk,\n stride_vz,\n stride_vh,\n stride_vk,\n stride_vn,\n stride_oz,\n stride_oh,\n stride_om,\n stride_on,\n Z,\n H,\n N_CTX,\n HEAD_DIM: tl.constexpr,\n BLOCK_M: tl.constexpr,\n BLOCK_N: tl.constexpr,\n STAGE: tl.constexpr,\n):\n tl.static_assert(BLOCK_N <= HEAD_DIM)\n start_m = tl.program_id(0)\n off_hz = tl.program_id(1)\n off_z = off_hz // H\n off_h = off_hz % H\n qvk_offset = off_z.to(tl.int64) * stride_qz + off_h.to(tl.int64) * stride_qh\n\n Q_block_ptr = tl.make_block_ptr(\n base=Q + qvk_offset,\n shape=(N_CTX, HEAD_DIM),\n strides=(stride_qm, stride_qk),\n offsets=(start_m * BLOCK_M, 0),\n block_shape=(BLOCK_M, HEAD_DIM),\n order=(1, 0),\n )\n v_order: tl.constexpr = (0, 1) if V.dtype.element_ty == tl.float8e5 else (1, 0)\n V_block_ptr = tl.make_block_ptr(\n base=V + qvk_offset,\n shape=(N_CTX, HEAD_DIM),\n strides=(stride_vk, stride_vn),\n offsets=(0, 0),\n block_shape=(BLOCK_N, HEAD_DIM),\n order=v_order,\n )\n K_block_ptr = tl.make_block_ptr(\n base=K + qvk_offset,\n shape=(HEAD_DIM, N_CTX),\n strides=(stride_kk, stride_kn),\n offsets=(0, 0),\n block_shape=(HEAD_DIM, BLOCK_N),\n order=(0, 1),\n )\n O_block_ptr = tl.make_block_ptr(\n base=Out + qvk_offset,\n shape=(N_CTX, HEAD_DIM),\n strides=(stride_om, stride_on),\n offsets=(start_m * BLOCK_M, 0),\n block_shape=(BLOCK_M, HEAD_DIM),\n order=(1, 0),\n )\n offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)\n offs_n = tl.arange(0, BLOCK_N)\n m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float(\"inf\")\n l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + 1.0\n acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32)\n qk_scale = sm_scale\n qk_scale *= 1.44269504\n q = tl.load(Q_block_ptr)\n if STAGE & 1:\n acc, l_i, m_i = _attn_fwd_inner(\n acc,\n l_i,\n m_i,\n q,\n K_block_ptr,\n V_block_ptr,\n start_m,\n qk_scale,\n BLOCK_M,\n HEAD_DIM,\n BLOCK_N,\n 4 - STAGE,\n offs_m,\n offs_n,\n N_CTX,\n V.dtype.element_ty == tl.float8e5,\n )\n if STAGE & 2:\n acc, l_i, m_i = _attn_fwd_inner(\n acc,\n l_i,\n m_i,\n q,\n K_block_ptr,\n V_block_ptr,\n start_m,\n qk_scale,\n BLOCK_M,\n HEAD_DIM,\n BLOCK_N,\n 2,\n offs_m,\n offs_n,\n N_CTX,\n V.dtype.element_ty == tl.float8e5,\n )\n m_i += tl.math.log2(l_i)\n acc = acc / l_i[:, None]\n m_ptrs = M + off_hz * N_CTX + offs_m\n tl.store(m_ptrs, m_i)\n tl.store(O_block_ptr, acc.to(Out.type.element_ty))\n\nclass fused_attention(paddle.autograd.PyLayer):\n @staticmethod\n @custom_fwd\n def forward(ctx, q, k, v, causal, sm_scale: float = 0.5):\n HEAD_DIM_Q, HEAD_DIM_K = q.shape[-1], k.shape[-1]\n HEAD_DIM_V = v.shape[-1]\n assert HEAD_DIM_Q == HEAD_DIM_K and HEAD_DIM_K == HEAD_DIM_V\n assert HEAD_DIM_K in {16, 32, 64, 128, 256}\n o = paddle.empty_like(q)\n stage = 3 if causal else 1\n extra_kern_args = {}\n if is_hip():\n waves_per_eu = 3 if HEAD_DIM_K <= 64 else 2\n extra_kern_args = {\"waves_per_eu\": waves_per_eu, \"allow_flush_denorm\": True}\n\n grid = lambda args: (triton.cdiv(q.shape[2], args[\"BLOCK_M\"]), q.shape[0] * q.shape[1], 1)\n M = paddle.empty((q.shape[0], q.shape[1], q.shape[2]), dtype=paddle.float32)\n _attn_fwd[grid](\n q,\n k,\n v,\n sm_scale,\n M,\n o,\n q.strides[0],\n q.strides[1],\n q.strides[2],\n q.strides[3],\n k.strides[0],\n k.strides[1],\n k.strides[2],\n k.strides[3],\n v.strides[0],\n v.strides[1],\n v.strides[2],\n v.strides[3],\n o.strides[0],\n o.strides[1],\n o.strides[2],\n o.strides[3],\n q.shape[0],\n q.shape[1],\n N_CTX=q.shape[2],\n HEAD_DIM=HEAD_DIM_K,\n STAGE=stage,\n **extra_kern_args,\n )\n\n ctx.save_for_backward(q, k, v, o, M)\n ctx.grid = grid\n ctx.sm_scale = sm_scale\n ctx.HEAD_DIM = HEAD_DIM_K\n ctx.causal = causal\n return o\n\nattention = fused_attention.apply\n", - "description_1": "Use triton language to implement a fused attention mechanism with forward and backward passes. The forward pass computes the attention output using input tensors Q, K, V, and a scaling factor. The backward pass computes gradients for Q, K, and V. The implementation includes kernel functions for the forward pass (_attn_fwd and _attn_fwd_inner) and a PyLayer class (fused_attention) to integrate with PaddlePaddle's autograd system.", - "description_2": "Use triton language to implement a fused attention mechanism with forward and backward passes, integrating with PaddlePaddle's autograd system.", - "difficulty": 4 - }, - { - "code": "import triton\nimport triton.language as tl\nimport paddle\nfrom ..utils import calculate_settings\n\n@triton.jit\ndef _geglu_tanh_forward_kernel(a, b, c, stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr):\n program_id = tl.program_id(0)\n\n # locate start index\n a += program_id * stride\n b += program_id * stride\n c += program_id * stride\n\n col_offsets = tl.arange(0, BLOCK_SIZE)\n mask = col_offsets < n_cols\n a_row = tl.load(a + col_offsets, mask=mask, other=0).to(tl.float32)\n b_row = tl.load(b + col_offsets, mask=mask, other=0)\n\n # tanh approximation form of GELU is computed with:\n # 0.5 * a * (1 + tanh(sqrt(2 / pi) * (a + 0.044715 * a^3)))\n sqrt_2_over_pi = 0.7978845608028654 # sqrt(2 / pi)\n a_cubed = a_row * a_row * a_row\n tanh_arg = sqrt_2_over_pi * (a_row + 0.044715 * a_cubed)\n tanh_result = tl.tanh(tanh_arg)\n geglu_a = 0.5 * a_row * (1 + tanh_result)\n c_row = geglu_a * b_row\n tl.store(c + col_offsets, c_row, mask=mask)\n\n\n@triton.jit\ndef _geglu_tanh_backward_kernel(dc, a, b, stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr):\n program_id = tl.program_id(0)\n\n # locate start index\n dc += program_id * stride\n a += program_id * stride\n b += program_id * stride\n\n col_offsets = tl.arange(0, BLOCK_SIZE)\n mask = col_offsets < n_cols\n\n dc_row = tl.load(dc + col_offsets, mask=mask, other=0)\n a_row = tl.load(a + col_offsets, mask=mask, other=0).to(tl.float32)\n b_row = tl.load(b + col_offsets, mask=mask, other=0)\n\n # recomputation to save memory\n sqrt_2_over_pi = 0.7978845608028654 # sqrt(2 / pi)\n a_cubed = a_row * a_row * a_row\n tanh_arg = sqrt_2_over_pi * (a_row + 0.044715 * a_cubed)\n tanh_result = tl.tanh(tanh_arg)\n geglu_a = 0.5 * a_row * (1 + tanh_result)\n\n db_row = dc_row * geglu_a\n\n # Gradient w.r.t. a can be computed with:\n # b * (0.5 * (1 + tanh(z)) + 0.5 * a * (1 - tanh(z)^2) * (sqrt(2/pi) * (1 + 3 * 0.044715 * a^2)))\n # where z = sqrt(2/pi) * (a + 0.044715 * a^3)\n term1 = 0.5 * (1 + tanh_result)\n tanh_sq = tanh_result * tanh_result\n term2 = 0.5 * a_row * (1 - tanh_sq) * (sqrt_2_over_pi * (1 + 3 * 0.044715 * a_row * a_row))\n da_row = dc_row * b_row * (term1 + term2)\n\n tl.store(a + col_offsets, da_row, mask=mask)\n tl.store(b + col_offsets, db_row, mask=mask)\n\n\ndef geglu_forward(a, b):\n ori_shape = a.shape\n\n n_cols = ori_shape[-1]\n a = a.reshape([-1, n_cols])\n b = b.reshape([-1, n_cols])\n c = paddle.empty_like(a)\n n_rows = a.shape[0]\n\n BLOCK_SIZE, num_warps = calculate_settings(n_cols)\n\n _geglu_tanh_forward_kernel[(n_rows,)](\n a,\n b,\n c,\n c.strides[-2],\n n_cols=n_cols,\n BLOCK_SIZE=BLOCK_SIZE,\n num_warps=num_warps,\n )\n return a, b, c.reshape(ori_shape)\n\n\ndef geglu_backward(a, b, dc):\n ori_shape = dc.shape\n n_cols = ori_shape[-1]\n dc = dc.reshape([-1, n_cols])\n n_rows = dc.shape[0]\n\n BLOCK_SIZE, num_warps = calculate_settings(n_cols)\n\n _geglu_tanh_backward_kernel[(n_rows,)](\n dc,\n a,\n b,\n dc.strides[-2],\n n_cols=n_cols,\n BLOCK_SIZE=BLOCK_SIZE,\n num_warps=num_warps,\n )\n\n return a.reshape(ori_shape), b.reshape(ori_shape)\n", - "description_1": "Use triton language to implement forward and backward kernels for a GEGLU operation. The forward kernel computes the GEGLU activation using a tanh approximation, and the backward kernel computes gradients for input tensors a and b. Both kernels involve memory access through triton's load and store operations with masks for handling varying column lengths. The forward function reshapes inputs, calculates necessary configuration settings, and launches the forward kernel. The backward function handles the reshaping of gradient tensor dc, recalculates configuration settings, and invokes the backward kernel.", - "description_2": "Use triton language to implement GEGLU activation and its backward operation using tanh approximation, involving memory handling with triton load/store operations and configuration settings for kernel execution.", - "difficulty": 2 - }, - { - "code": "import paddle\nimport triton\nimport triton.language as tl\n\n@triton.autotune(\n configs=[\n triton.Config({\"BLOCK_N\": 32}),\n triton.Config({\"BLOCK_N\": 64}),\n triton.Config({\"BLOCK_N\": 128}),\n triton.Config({\"BLOCK_N\": 256}),\n triton.Config({\"BLOCK_N\": 512}),\n triton.Config({\"BLOCK_N\": 1024}),\n ],\n key=[\"ncols\"],\n)\n@triton.jit\ndef _swiglu_fwd_kernel(\n X, Y, OUT, stride_x_row, stride_y_row, stride_out_row, ncols, BLOCK_N: tl.constexpr\n):\n # Map the program id to the row of X and Y it should compute.\n row = tl.program_id(0)\n start_col = tl.program_id(1) * BLOCK_N\n X += row * stride_x_row\n Y += row * stride_y_row\n OUT += row * stride_out_row\n cols = start_col + tl.arange(0, BLOCK_N)\n x = tl.load(X + cols, mask=cols < ncols, other=0.0).to(tl.float32)\n y = tl.load(Y + cols, mask=cols < ncols, other=0.0).to(tl.float32)\n out = x * tl.sigmoid(x) * y\n tl.store(OUT + cols, out, mask=cols < ncols)\n\ndef _swiglu_fwd(xy, out=None):\n if xy.strides[-1] != 1:\n xy = xy.contiguous()\n batch_shape = xy.shape[:-1]\n xy = xy.reshape([-1, xy.shape[-1]])\n x, y = xy.chunk(2, axis=-1)\n if out is None:\n out = paddle.empty_like(x)\n else:\n out = out.reshape([-1, out.shape[-1]])\n assert out.shape == x.shape\n assert out.strides[-1] == 1\n M, N = x.shape\n grid = lambda META: (M, triton.cdiv(N, META[\"BLOCK_N\"]))\n _swiglu_fwd_kernel[grid](x, y, out, x.strides[0], y.strides[0], out.strides[0], N)\n return out.reshape([*batch_shape, out.shape[-1]])\n\n@triton.autotune(\n configs=[\n triton.Config({\"BLOCK_N\": 32}),\n triton.Config({\"BLOCK_N\": 64}),\n triton.Config({\"BLOCK_N\": 128}),\n triton.Config({\"BLOCK_N\": 256}),\n triton.Config({\"BLOCK_N\": 512}),\n triton.Config({\"BLOCK_N\": 1024}),\n ],\n key=[\"ncols\"],\n)\n@triton.heuristics({\"RECOMPUTE_OUTPUT\": lambda args: args[\"OUT\"] is not None})\n@triton.jit\ndef _swiglu_bwd_kernel(\n X, Y, DOUT, OUT, DX, DY, stride_x_row, stride_y_row, stride_dout_row, stride_out_row, stride_dx_row, stride_dy_row, ncols, BLOCK_N: tl.constexpr, RECOMPUTE_OUTPUT: tl.constexpr\n):\n # Map the program id to the row of X and Y it should compute.\n row = tl.program_id(0)\n start_col = tl.program_id(1) * BLOCK_N\n X += row * stride_x_row\n Y += row * stride_y_row\n DOUT += row * stride_dout_row\n if RECOMPUTE_OUTPUT:\n OUT += row * stride_out_row\n DX += row * stride_dx_row\n DY += row * stride_dy_row\n cols = start_col + tl.arange(0, BLOCK_N)\n x = tl.load(X + cols, mask=cols < ncols, other=0.0).to(tl.float32)\n y = tl.load(Y + cols, mask=cols < ncols, other=0.0).to(tl.float32)\n dout = tl.load(DOUT + cols, mask=cols < ncols, other=0.0).to(tl.float32)\n x_sigmoid = tl.sigmoid(x)\n dx = x_sigmoid * (1 + x * (1 - x_sigmoid)) * y * dout\n dy = x * x_sigmoid * dout\n tl.store(DX + cols, dx, mask=cols < ncols)\n tl.store(DY + cols, dy, mask=cols < ncols)\n if RECOMPUTE_OUTPUT:\n out = x * x_sigmoid * y\n tl.store(OUT + cols, out, mask=cols < ncols)\n\ndef _swiglu_bwd(xy, dout, dxy=None, recompute_output=False, out=None):\n if xy.strides[-1] != 1:\n xy = xy.contiguous()\n if dout.strides[-1] != 1:\n dout = dout.contiguous()\n batch_shape = xy.shape[:-1]\n xy = xy.reshape([-1, xy.shape[-1]])\n x, y = xy.chunk(2, axis=-1)\n dout = dout.reshape([-1, dout.shape[-1]])\n assert dout.shape == x.shape\n if dxy is None:\n dxy = paddle.empty_like(xy)\n else:\n dxy = dxy.reshape([-1, dxy.shape[-1]])\n assert dxy.shape == xy.shape\n dx, dy = dxy.chunk(2, axis=-1)\n assert dx.strides[-1] == 1\n assert dy.strides[-1] == 1\n if recompute_output:\n if out is None:\n out = paddle.empty_like(x)\n else:\n out = out.reshape([-1, out.shape[-1]])\n assert out.shape == x.shape\n assert out.strides[-1] == 1\n M, N = x.shape\n grid = lambda META: (M, triton.cdiv(N, META[\"BLOCK_N\"]))\n _swiglu_bwd_kernel[grid](\n x,\n y,\n dout,\n out if recompute_output else None,\n dx,\n dy,\n x.strides[0],\n y.strides[0],\n dout.strides[0],\n out.strides[0] if recompute_output else 0,\n dx.strides[0],\n dy.strides[0],\n N,\n )\n if not recompute_output:\n return dxy.reshape([*batch_shape, dxy.shape[-1]])\n else:\n return dxy.reshape([*batch_shape, dxy.shape[-1]]), out.reshape([*batch_shape, out.shape[-1]])\n", - "description_1": "Use triton language to implement a forward and backward kernel for the SwiGLU activation function. The forward kernel (_swiglu_fwd_kernel) takes 7 parameters: X, Y, OUT, stride_x_row, stride_y_row, stride_out_row, and ncols. It computes the element-wise product of X and Y after applying the sigmoid function to X, storing the result in OUT. The backward kernel (_swiglu_bwd_kernel) takes 14 parameters: X, Y, DOUT, OUT, DX, DY, stride_x_row, stride_y_row, stride_dout_row, stride_out_row, stride_dx_row, stride_dy_row, ncols, and BLOCK_N. It computes the gradients of X and Y with respect to the output gradient DOUT, storing them in DX and DY, and optionally recomputes the output if RECOMPUTE_OUTPUT is true.", - "description_2": "Use triton language to create a forward kernel for computing the SwiGLU activation and a backward kernel for computing its gradients. The forward kernel should compute the element-wise product of inputs after applying a sigmoid function, and the backward kernel should compute the gradients with respect to the inputs.", - "difficulty": 3 - }, - { - "code": "import triton\nimport triton.language as tl\nimport paddle\n\nconfigs_autotune = [\n triton.Config({}, num_warps=1),\n triton.Config({}, num_warps=2),\n triton.Config({}, num_warps=4),\n triton.Config({}, num_warps=8),\n triton.Config({}, num_warps=16),\n triton.Config({}, num_warps=32),\n]\n\ndef config_prune(configs):\n warp_size = 32\n max_block_sz = 1024\n max_num_warps = max_block_sz // warp_size\n pruned_configs = [config for config in configs if config.num_warps <= max_num_warps]\n return pruned_configs\n\npruned_configs_autotune = config_prune(configs_autotune)\n\n@triton.autotune(\n configs=pruned_configs_autotune,\n key=[\"N\", \"HAS_RESIDUAL\", \"STORE_RESIDUAL_OUT\", \"IS_RMS_NORM\", \"HAS_BIAS\"],\n)\n@triton.heuristics({\"HAS_X1\": lambda args: args[\"X1\"] is not None})\n@triton.heuristics({\"HAS_W1\": lambda args: args[\"W1\"] is not None})\n@triton.heuristics({\"HAS_B1\": lambda args: args[\"B1\"] is not None})\n@triton.jit\ndef _layer_norm_fwd_1pass_kernel(\n X, Y, W, B, RESIDUAL, X1, W1, B1, Y1, RESIDUAL_OUT, ROWSCALE, SEEDS, DROPOUT_MASK, Mean, Rstd,\n stride_x_row, stride_y_row, stride_res_row, stride_res_out_row, stride_x1_row, stride_y1_row,\n M, N, eps, dropout_p, IS_RMS_NORM: tl.constexpr, BLOCK_N: tl.constexpr, HAS_RESIDUAL: tl.constexpr,\n STORE_RESIDUAL_OUT: tl.constexpr, HAS_BIAS: tl.constexpr, HAS_DROPOUT: tl.constexpr,\n STORE_DROPOUT_MASK: tl.constexpr, HAS_ROWSCALE: tl.constexpr, HAS_X1: tl.constexpr,\n HAS_W1: tl.constexpr, HAS_B1: tl.constexpr,\n):\n row = tl.program_id(0)\n X += row * stride_x_row\n Y += row * stride_y_row\n if HAS_RESIDUAL:\n RESIDUAL += row * stride_res_row\n if STORE_RESIDUAL_OUT:\n RESIDUAL_OUT += row * stride_res_out_row\n if HAS_X1:\n X1 += row * stride_x1_row\n if HAS_W1:\n Y1 += row * stride_y1_row\n cols = tl.arange(0, BLOCK_N)\n x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)\n if HAS_ROWSCALE:\n rowscale = tl.load(ROWSCALE + row).to(tl.float32)\n x *= rowscale\n if HAS_DROPOUT:\n keep_mask = tl.rand(tl.load(SEEDS + row).to(tl.uint32), cols, n_rounds=7) > dropout_p\n x = tl.where(keep_mask, x / (1.0 - dropout_p), 0.0)\n if STORE_DROPOUT_MASK:\n tl.store(DROPOUT_MASK + row * N + cols, keep_mask, mask=cols < N)\n if HAS_X1:\n x1 = tl.load(X1 + cols, mask=cols < N, other=0.0).to(tl.float32)\n if HAS_ROWSCALE:\n rowscale = tl.load(ROWSCALE + M + row).to(tl.float32)\n x1 *= rowscale\n if HAS_DROPOUT:\n keep_mask = tl.rand(tl.load(SEEDS + M + row).to(tl.uint32), cols, n_rounds=7) > dropout_p\n x1 = tl.where(keep_mask, x1 / (1.0 - dropout_p), 0.0)\n if STORE_DROPOUT_MASK:\n tl.store(DROPOUT_MASK + (M + row) * N + cols, keep_mask, mask=cols < N)\n x += x1\n if HAS_RESIDUAL:\n residual = tl.load(RESIDUAL + cols, mask=cols < N, other=0.0).to(tl.float32)\n x += residual\n if STORE_RESIDUAL_OUT:\n tl.store(RESIDUAL_OUT + cols, x, mask=cols < N)\n if not IS_RMS_NORM:\n mean = tl.sum(x, axis=0) / N\n tl.store(Mean + row, mean)\n xbar = tl.where(cols < N, x - mean, 0.0)\n var = tl.sum(xbar * xbar, axis=0) / N\n else:\n xbar = tl.where(cols < N, x, 0.0)\n var = tl.sum(xbar * xbar, axis=0) / N\n rstd = 1 / tl.sqrt(var + eps)\n tl.store(Rstd + row, rstd)\n mask = cols < N\n w = tl.load(W + cols, mask=mask).to(tl.float32)\n if HAS_BIAS:\n b = tl.load(B + cols, mask=mask).to(tl.float32)\n x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd\n y = x_hat * w + b if HAS_BIAS else x_hat * w\n tl.store(Y + cols, y, mask=mask)\n if HAS_W1:\n w1 = tl.load(W1 + cols, mask=mask).to(tl.float32)\n if HAS_B1:\n b1 = tl.load(B1 + cols, mask=mask).to(tl.float32)\n y1 = x_hat * w1 + b1 if HAS_B1 else x_hat * w1\n tl.store(Y1 + cols, y1, mask=mask)\n\ndef _layer_norm_fwd(\n x, weight, bias, eps, residual=None, x1=None, weight1=None, bias1=None, dropout_p=0.0,\n rowscale=None, out_dtype=None, residual_dtype=None, is_rms_norm=False, return_dropout_mask=False,\n):\n if residual is not None:\n residual_dtype = residual.dtype\n M, N = x.shape\n assert x.strides[-1] == 1\n if residual is not None:\n assert residual.strides[-1] == 1\n assert tuple(residual.shape) == (M, N)\n assert weight.shape[0] == N\n assert weight.strides[-1] == 1\n if bias is not None:\n assert bias.strides[-1] == 1\n assert bias.shape[0] == N\n if x1 is not None:\n assert x1.shape == x.shape\n assert rowscale is None\n assert x1.strides[-1] == 1\n if weight1 is not None:\n assert weight1.shape[0] == N\n assert weight1.strides[-1] == 1\n if bias1 is not None:\n assert bias1.shape[0] == N\n assert bias1.strides[-1] == 1\n if rowscale is not None:\n assert rowscale.is_contiguous()\n assert rowscale.shape[0] == M\n y = paddle.empty_like(x, dtype=x.dtype if out_dtype is None else out_dtype)\n assert y.strides[-1] == 1\n if weight1 is not None:\n y1 = paddle.empty_like(y)\n assert y1.strides[-1] == 1\n else:\n y1 = None\n if (\n residual is not None\n or (residual_dtype is not None and residual_dtype != x.dtype)\n or dropout_p > 0.0\n or rowscale is not None\n or x1 is not None\n ):\n residual_out = paddle.empty(M, N, dtype=residual_dtype if residual_dtype is not None else x.dtype)\n assert residual_out.strides[-1] == 1\n else:\n residual_out = None\n mean = paddle.empty((M,), dtype=paddle.float32) if not is_rms_norm else None\n rstd = paddle.empty((M,), dtype=paddle.float32)\n if dropout_p > 0.0:\n seeds = paddle.randint(2**32, (M if x1 is None else 2 * M,), dtype=paddle.int64)\n else:\n seeds = None\n if return_dropout_mask and dropout_p > 0.0:\n dropout_mask = paddle.empty(M if x1 is None else 2 * M, N, dtype=paddle.bool)\n else:\n dropout_mask = None\n MAX_FUSED_SIZE = 65536 // x.element_size()\n BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))\n if N > BLOCK_N:\n raise RuntimeError(\"This layer norm doesn't support feature dim >= 64KB.\")\n _layer_norm_fwd_1pass_kernel[(M,)](\n x, y, weight, bias, residual, x1, weight1, bias1, y1, residual_out, rowscale, seeds, dropout_mask,\n mean, rstd, x.strides[0], y.strides[0], residual.strides[0] if residual is not None else 0,\n residual_out.strides[0] if residual_out is not None else 0, x1.strides[0] if x1 is not None else 0,\n y1.strides[0] if y1 is not None else 0, M, N, eps, dropout_p, is_rms_norm, BLOCK_N,\n residual is not None, residual_out is not None, bias is not None, dropout_p > 0.0,\n dropout_mask is not None, rowscale is not None,\n )\n if dropout_mask is not None and x1 is not None:\n dropout_mask, dropout_mask1 = dropout_mask.chunk(2, axis=0)\n else:\n dropout_mask1 = None\n return (\n y, y1, mean, rstd, residual_out if residual_out is not None else x, seeds, dropout_mask, dropout_mask1,\n )\n", - "description_1": "Use triton language to implement a fused layer normalization kernel with support for dropout, residual connections, and optional secondary inputs. The kernel computes the mean and variance for normalization, applies dropout if specified, and performs the normalization and linear transformation using weights and biases. The forward function manages input reshaping, output allocation, and kernel invocation.", - "description_2": "Use triton language to create a layer normalization kernel that supports dropout and residuals, and a forward function to handle input/output management and kernel execution.", - "difficulty": 3 - }, - { - "code": "import math\nimport paddle\nimport paddle.device\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef _layer_norm_forward_kernel(\n Y_ptr, # pointer to output, shape (n_rows, n_cols)\n Y_row_stride, # stride of each row in output\n X_ptr, # pointer to input, shape (n_rows, n_cols)\n X_row_stride, # stride of each row in input\n W_ptr, # pointer to weights, shape (n_cols,)\n W_row_stride, # stride of each row in weights\n B_ptr, # pointer to bias, shape (n_cols,)\n B_row_stride, # stride of each row in bias\n Mean_ptr, # pointer to mean, shape (n_rows,)\n Mean_row_stride, # stride of each row in mean\n RSTD_ptr, # pointer to rstd, shape (n_rows,)\n RSTD_row_stride, # stride of each row in rstd\n n_cols,\n eps,\n BLOCK_SIZE: tl.constexpr,\n):\n \"\"\"\n References: https://arxiv.org/abs/1607.06450\n https://github.com/karpathy/llm.c/blob/master/doc/layernorm/layernorm.md\n \"\"\"\n row_idx = tl.program_id(0)\n col_offsets = tl.arange(0, BLOCK_SIZE)\n mask = col_offsets < n_cols\n\n Y_ptr += row_idx * Y_row_stride\n X_ptr += row_idx * X_row_stride\n Mean_ptr += row_idx * Mean_row_stride\n RSTD_ptr += row_idx * RSTD_row_stride\n\n X_row = tl.load(X_ptr + col_offsets, mask=mask, other=0)\n W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0)\n B_row = tl.load(B_ptr + col_offsets, mask=mask, other=0)\n\n mean = tl.sum(X_row, axis=0) / n_cols\n var = tl.sum((X_row - mean) * (X_row - mean), axis=0) / n_cols\n rstd = rsqrt(var + eps)\n\n tl.store(Mean_ptr, mean)\n tl.store(RSTD_ptr, rstd)\n\n Y_row = (X_row - mean) * rstd * W_row + B_row\n\n tl.store(Y_ptr + col_offsets, Y_row, mask=mask)\n\n\n@triton.jit\ndef _layer_norm_backward_kernel(\n X_ptr, # pointer to input, shape (n_rows, n_cols)\n W_ptr, # pointer to weights, shape (n_cols,)\n Mean_ptr, # pointer to mean, shape (n_rows,)\n RSTD_ptr, # pointer to rstd, shape (n_rows,)\n DX_ptr, # pointer to input grad, shape (n_rows, n_cols)\n DW_ptr, # pointer to weights grad, shape (n_cols,)\n DB_ptr, # pointer to bias grad, shape (n_cols,)\n DY_ptr, # pointer to output grad, shape (n_rows, n_cols)\n stride_x, # stride of each row in input\n stride_dx, # stride of each row in input grad\n stride_dw, # stride of each row in weights grad\n stride_db, # stride of each row in bias grad\n stride_dy, # stride of each row in output grad\n n_rows,\n n_cols,\n rows_per_program: tl.constexpr,\n BLOCK_SIZE: tl.constexpr,\n dtype: tl.constexpr,\n):\n \"\"\"\n References: https://arxiv.org/abs/1607.06450\n https://github.com/karpathy/llm.c/blob/master/doc/layernorm/layernorm.md\n https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html\n https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/ops/triton/layer_norm.py\n \"\"\"\n row_block_id = tl.program_id(0)\n row_start = row_block_id * rows_per_program\n row_end = min((row_block_id + 1) * rows_per_program, n_rows)\n cols = tl.arange(0, BLOCK_SIZE)\n mask = cols < n_cols\n\n dw_row = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)\n db_row = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)\n\n X_ptr += row_start * stride_x\n Mean_ptr += row_start\n RSTD_ptr += row_start\n DX_ptr += row_start * stride_dx\n DY_ptr += row_start * stride_dy\n\n for _ in range(row_start, row_end):\n x = tl.load(X_ptr + cols, mask=mask, other=0.0)\n w = tl.load(W_ptr + cols, mask=mask, other=0.0)\n dy = tl.load(DY_ptr + cols, mask=mask, other=0.0)\n mean = tl.load(Mean_ptr)\n rstd = tl.load(RSTD_ptr)\n\n x_hat = (x - mean) * rstd\n wdy = w * dy\n c1 = tl.sum(x_hat * wdy, axis=0) / n_cols\n c2 = tl.sum(wdy, axis=0) / n_cols\n dx = (wdy - (x_hat * c1 + c2)) * rstd\n tl.store(DX_ptr + cols, dx.to(dtype), mask=mask)\n\n dw_row += dy * x_hat\n db_row += dy\n\n X_ptr += stride_x\n Mean_ptr += 1\n RSTD_ptr += 1\n DX_ptr += stride_dx\n DY_ptr += stride_dy\n\n tl.store(DW_ptr + row_block_id * stride_dw + cols, dw_row.to(dtype), mask=mask)\n tl.store(DB_ptr + row_block_id * stride_db + cols, db_row.to(dtype), mask=mask)\n\n\ndef layer_norm_forward(X, W, B, eps):\n shape = X.shape\n dim = shape[-1]\n X = X.reshape([-1, dim])\n n_rows, n_cols = X.shape\n BLOCK_SIZE, num_warps = calculate_settings(n_cols)\n Y = paddle.empty((n_rows, n_cols), dtype=X.dtype)\n Mean = paddle.empty((n_rows,), dtype=X.dtype)\n RSTD = paddle.empty((n_rows,), dtype=X.dtype)\n assert (\n X.shape[1] == W.shape[0]\n ), f\"Incompatible hidden size dimension between input tensor with shape[1] = {X.shape[1]} and weight tensor with shape[0] = {W.shape[0]}\"\n\n _layer_norm_forward_kernel[(n_rows,)](\n Y,\n Y.strides[0],\n X,\n X.strides[0],\n W,\n W.strides[0],\n B,\n B.strides[0],\n Mean,\n Mean.strides[0],\n RSTD,\n RSTD.strides[0],\n n_cols,\n eps,\n BLOCK_SIZE=BLOCK_SIZE,\n num_warps=num_warps,\n )\n return Y.reshape(shape), X, Mean, RSTD, BLOCK_SIZE, num_warps\n\n\ndef layer_norm_backward(dY, X, W, B, Mean, RSTD):\n shape = dY.shape\n dim = shape[-1]\n dY = dY.reshape([-1, dim])\n n_rows, n_cols = dY.shape\n\n DX = paddle.empty((n_rows, n_cols), dtype=X.dtype)\n sm_count = paddle.device.cuda.get_device_properties().multi_processor_count\n _DW = paddle.empty((sm_count, n_cols), dtype=W.dtype)\n _DB = paddle.empty((sm_count, n_cols), dtype=W.dtype)\n\n BLOCK_SIZE, num_warps = calculate_settings(n_cols)\n if n_cols > BLOCK_SIZE:\n raise RuntimeError(\"This layer norm doesn't support feature dim >= 64KB.\")\n\n rows_per_program = math.ceil(n_rows / sm_count)\n grid = (sm_count,)\n triton_dtype = tl.float32 if X.dtype == paddle.float32 else tl.bfloat16\n _layer_norm_backward_kernel[grid](\n X,\n W,\n Mean,\n RSTD,\n DX,\n _DW,\n _DB,\n dY,\n X.strides[0],\n DX.strides[0],\n _DW.strides[0],\n _DB.strides[0],\n dY.strides[0],\n n_rows,\n n_cols,\n rows_per_program,\n BLOCK_SIZE=BLOCK_SIZE,\n dtype=triton_dtype,\n )\n\n DW = _DW.sum(axis=0).cast(W.dtype)\n DB = _DB.sum(axis=0).cast(W.dtype)\n\n DX = DX.reshape(shape)\n return DX, DW, DB\n", - "description_1": "Use triton language to implement layer normalization forward and backward kernels. The forward kernel has 13 parameters: pointers to input, output, weights, bias, mean, and rstd, strides for each, number of columns, epsilon, and block size. The backward kernel has 18 parameters: pointers to input, weights, mean, rstd, input grad, weights grad, bias grad, output grad, strides for each, number of rows and columns, rows per program, block size, and data type. There are also two functions to call these kernels: layer_norm_forward with 4 parameters (X, W, B, eps) and layer_norm_backward with 6 parameters (dY, X, W, B, Mean, RSTD).", - "description_2": "Use triton language to create kernels for layer normalization. Implement forward kernel with parameters for pointers, strides, columns, epsilon, and block size. Implement backward kernel with parameters for pointers, strides, rows, columns, program rows, block size, and type.", - "difficulty": 3 - }, - { - "code": "import paddle\nimport triton\nimport triton.language as tl\n\n@triton.heuristics({\"HAS_BIAS\": lambda args: args[\"B\"] is not None})\n@triton.heuristics({\"HAS_Z\": lambda args: args[\"Z\"] is not None})\n@triton.jit\ndef _layer_norm_fwd_1pass_kernel(\n X, # pointer to the input\n Y, # pointer to the output\n W, # pointer to the weights\n B, # pointer to the biases\n Z, # pointer to the other branch\n Mean, # pointer to the mean\n Rstd, # pointer to the 1/std\n stride_x_row, # how much to increase the pointer when moving by 1 row\n stride_y_row,\n stride_z_row,\n M, # number of rows in X\n N, # number of columns in X\n eps, # epsilon to avoid division by zero\n BLOCK_N: tl.constexpr,\n HAS_BIAS: tl.constexpr,\n HAS_Z: tl.constexpr,\n NORM_BEFORE_GATE: tl.constexpr,\n IS_RMS_NORM: tl.constexpr,\n):\n # Map the program id to the row of X and Y it should compute.\n row = tl.program_id(0)\n group = tl.program_id(1)\n X += row * stride_x_row + group * N\n Y += row * stride_y_row + group * N\n if HAS_Z:\n Z += row * stride_z_row + group * N\n if not IS_RMS_NORM:\n Mean += group * M\n Rstd += group * M\n W += group * N\n if HAS_BIAS:\n B += group * N\n # Compute mean and variance\n cols = tl.arange(0, BLOCK_N)\n x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)\n if HAS_Z and not NORM_BEFORE_GATE:\n z = tl.load(Z + cols, mask=cols < N).to(tl.float32)\n x *= z * tl.sigmoid(z)\n if not IS_RMS_NORM:\n mean = tl.sum(x, axis=0) / N\n tl.store(Mean + row, mean)\n xbar = tl.where(cols < N, x - mean, 0.0)\n var = tl.sum(xbar * xbar, axis=0) / N\n else:\n xbar = tl.where(cols < N, x, 0.0)\n var = tl.sum(xbar * xbar, axis=0) / N\n rstd = 1 / tl.sqrt(var + eps)\n tl.store(Rstd + row, rstd)\n # Normalize and apply linear transformation\n mask = cols < N\n w = tl.load(W + cols, mask=mask).to(tl.float32)\n if HAS_BIAS:\n b = tl.load(B + cols, mask=mask).to(tl.float32)\n x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd\n y = x_hat * w + b if HAS_BIAS else x_hat * w\n if HAS_Z and NORM_BEFORE_GATE:\n z = tl.load(Z + cols, mask=mask).to(tl.float32)\n y *= z * tl.sigmoid(z)\n # Write output\n tl.store(Y + cols, y, mask=mask)\n\n\ndef _layer_norm_fwd(x, weight, bias, eps, z=None, out=None, group_size=None, norm_before_gate=True, is_rms_norm=False):\n M, N = x.shape\n if group_size is None:\n group_size = N\n assert N % group_size == 0\n ngroups = N // group_size\n assert x.strides[-1] == 1\n if z is not None:\n assert z.strides[-1] == 1\n assert tuple(z.shape) == (M, N)\n assert weight.shape[0] == N\n assert weight.strides[-1] == 1\n if bias is not None:\n assert bias.strides[-1] == 1\n assert bias.shape[0] == N\n # allocate output\n if out is not None:\n assert out.shape == x.shape\n else:\n out = paddle.empty_like(x)\n assert out.strides[-1] == 1\n mean = paddle.empty((ngroups * M,), dtype=paddle.float32) if not is_rms_norm else None\n rstd = paddle.empty((ngroups * M,), dtype=paddle.float32)\n # Less than 64KB per feature: enqueue fused kernel\n MAX_FUSED_SIZE = 65536 // x.element_size()\n BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(group_size))\n if group_size > BLOCK_N:\n raise RuntimeError(\"This layer norm doesn't support feature dim >= 64KB.\")\n # heuristics for number of warps\n num_warps = min(max(BLOCK_N // 256, 1), 8)\n grid = (M, ngroups)\n _layer_norm_fwd_1pass_kernel[grid](\n x,\n out,\n weight,\n bias,\n z,\n mean,\n rstd,\n x.strides[0],\n out.strides[0],\n z.strides[0] if z is not None else 0,\n M,\n group_size,\n eps,\n BLOCK_N=BLOCK_N,\n NORM_BEFORE_GATE=norm_before_gate,\n IS_RMS_NORM=is_rms_norm,\n num_warps=num_warps,\n )\n return out, mean, rstd\n\n\n@triton.heuristics({\"HAS_BIAS\": lambda args: args[\"B\"] is not None})\n@triton.heuristics({\"HAS_Z\": lambda args: args[\"Z\"] is not None})\n@triton.heuristics({\"RECOMPUTE_OUTPUT\": lambda args: args[\"Y\"] is not None})\n@triton.jit\ndef _layer_norm_bwd_kernel(\n X, # pointer to the input\n W, # pointer to the weights\n B, # pointer to the biases\n Z, # pointer to the other branch\n Y, # pointer to the output to be recomputed\n DY, # pointer to the output gradient\n DX, # pointer to the input gradient\n DW, # pointer to the partial sum of weights gradient\n DB, # pointer to the partial sum of biases gradient\n DZ, # pointer to the other branch\n Mean, # pointer to the mean\n Rstd, # pointer to the 1/std\n stride_x_row, # how much to increase the pointer when moving by 1 row\n stride_z_row,\n stride_y_row,\n stride_dy_row,\n stride_dx_row,\n stride_dz_row,\n stride_dw_row,\n stride_db_row,\n M, # number of rows in X\n N, # number of columns in X\n eps, # epsilon to avoid division by zero\n rows_per_program,\n NORM_BEFORE_GATE: tl.constexpr,\n IS_RMS_NORM: tl.constexpr,\n HAS_BIAS: tl.constexpr,\n HAS_Z: tl.constexpr,\n RECOMPUTE_OUTPUT: tl.constexpr,\n BLOCK_N: tl.constexpr,\n):\n # Map the program id to the elements of X, DX, and DY it should compute.\n row_block_id = tl.program_id(0)\n group = tl.program_id(1)\n row_start = row_block_id * rows_per_program\n cols = tl.arange(0, BLOCK_N)\n mask = cols < N\n X += row_start * stride_x_row + group * N\n if HAS_Z:\n Z += row_start * stride_z_row + group * N\n DZ += row_start * stride_dz_row + group * N\n DY += row_start * stride_dy_row + group * N\n DX += row_start * stride_dx_row + group * N\n if RECOMPUTE_OUTPUT:\n Y += row_start * stride_y_row + group * N\n if not IS_RMS_NORM:\n Mean += group * M\n Rstd += group * M\n W += group * N\n w = tl.load(W + cols, mask=mask).to(tl.float32)\n if (RECOMPUTE_OUTPUT or HAS_Z) and HAS_BIAS:\n B += group * N\n b = tl.load(B + cols, mask=mask, other=0.0).to(tl.float32)\n dw = tl.zeros((BLOCK_N,), dtype=tl.float32)\n if HAS_BIAS:\n db = tl.zeros((BLOCK_N,), dtype=tl.float32)\n row_end = min((row_block_id + 1) * rows_per_program, M)\n for row in range(row_start, row_end):\n # Load data to SRAM\n x = tl.load(X + cols, mask=mask, other=0).to(tl.float32)\n dy = tl.load(DY + cols, mask=mask, other=0).to(tl.float32)\n if not IS_RMS_NORM:\n mean = tl.load(Mean + row)\n if HAS_Z and not NORM_BEFORE_GATE:\n z = tl.load(Z + cols, mask=mask, other=0.0).to(tl.float32)\n x_og = x\n x = x_og * z * tl.sigmoid(z)\n rstd = tl.load(Rstd + row)\n # Compute dx\n xhat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd\n xhat = tl.where(mask, xhat, 0.0)\n if HAS_Z and NORM_BEFORE_GATE:\n z = tl.load(Z + cols, mask=mask, other=0.0).to(tl.float32)\n z_sigmoid = tl.sigmoid(z)\n y = xhat * w + b if HAS_BIAS else xhat * w\n if RECOMPUTE_OUTPUT:\n tl.store(Y + cols, y * z * z_sigmoid, mask=mask)\n dz = dy * y * z_sigmoid * (1 + z * (1 - z_sigmoid))\n tl.store(DZ + cols, dz, mask=mask)\n dy *= z * z_sigmoid\n else:\n if RECOMPUTE_OUTPUT:\n y = xhat * w + b if HAS_BIAS else xhat * w\n tl.store(Y + cols, y, mask=mask)\n wdy = w * dy\n c1 = tl.sum(xhat * wdy, axis=0) / N\n if not IS_RMS_NORM:\n c2 = tl.sum(wdy, axis=0) / N\n dx = (wdy - (xhat * c1 + c2)) * rstd\n else:\n dx = (wdy - xhat * c1) * rstd\n dw += dy * xhat\n if HAS_BIAS:\n db += dy\n if HAS_Z and not NORM_BEFORE_GATE:\n z_sigmoid = tl.sigmoid(z)\n dz = dx * x_og * z_sigmoid * (1 + z * (1 - z_sigmoid))\n tl.store(DZ + cols, dz, mask=mask)\n dx *= z * z_sigmoid\n # Write dx\n tl.store(DX + cols, dx, mask=mask)\n\n X += stride_x_row\n if HAS_Z:\n Z += stride_z_row\n DZ += stride_dz_row\n if RECOMPUTE_OUTPUT:\n Y += stride_y_row\n DY += stride_dy_row\n DX += stride_dx_row\n tl.store(DW + row_block_id * stride_dw_row + group * N + cols, dw, mask=mask)\n if HAS_BIAS:\n tl.store(DB + row_block_id * stride_db_row + group * N + cols, db, mask=mask)\n\n\ndef _layer_norm_bwd(\n dy,\n x,\n weight,\n bias,\n eps,\n mean,\n rstd,\n z=None,\n group_size=None,\n norm_before_gate=True,\n is_rms_norm=False,\n recompute_output=False,\n dz=None,\n out=None,\n):\n M, N = x.shape\n if group_size is None:\n group_size = N\n assert N % group_size == 0\n ngroups = N // group_size\n assert x.strides[-1] == 1\n assert dy.strides[-1] == 1\n assert tuple(dy.shape) == (M, N)\n if z is not None:\n assert z.strides[-1] == 1\n assert tuple(z.shape) == (M, N)\n assert weight.shape[0] == N\n assert weight.strides[-1] == 1\n if bias is not None:\n assert bias.strides[-1] == 1\n assert bias.shape[0] == N\n # allocate output\n dx = paddle.empty_like(x)\n if dz is not None:\n assert z is not None\n assert dz.shape == z.shape\n assert dz.strides[-1] == 1\n else:\n dz = paddle.empty_like(z) if z is not None else None\n if recompute_output:\n if out is None:\n out = paddle.empty_like(x)\n assert out.shape == x.shape\n\n # Less than 64KB per feature: enqueue fused kernel\n MAX_FUSED_SIZE = 65536 // x.element_size()\n BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(group_size))\n if group_size > BLOCK_N:\n raise RuntimeError(\"This layer norm doesn't support feature dim >= 64KB.\")\n # heuristics for number of warps\n num_warps = min(max(BLOCK_N // 256, 1), 8)\n sm_count = paddle.device.cuda.get_device_properties(paddle.get_device()).multi_processor_count\n # If group size is small (e.g., 64), we're only using 1 warp. So having just 108 programs\n # would limit the occupancy.\n nrow_groups = math.ceil(sm_count * math.ceil(4 / num_warps) / ngroups)\n _dw = paddle.empty((nrow_groups, N), dtype=paddle.float32)\n _db = paddle.empty((nrow_groups, N), dtype=paddle.float32) if bias is not None else None\n rows_per_program = math.ceil(M / nrow_groups)\n grid = (nrow_groups, ngroups)\n _layer_norm_bwd_kernel[grid](\n x,\n weight,\n bias,\n z,\n out if recompute_output else None,\n dy,\n dx,\n _dw,\n _db,\n dz,\n mean,\n rstd,\n x.strides[0],\n z.strides[0] if z is not None else 0,\n 0 if not recompute_output else out.strides[0],\n dy.strides[0],\n dx.strides[0],\n dz.strides[0] if dz is not None else 0,\n _dw.strides[0],\n _db.strides[0] if _db is not None else 0,\n M,\n group_size,\n eps,\n rows_per_program,\n BLOCK_N=BLOCK_N,\n NORM_BEFORE_GATE=norm_before_gate,\n IS_RMS_NORM=is_rms_norm,\n num_warps=num_warps,\n )\n dw = _dw.sum(0).cast(weight.dtype)\n db = _db.sum(0).cast(bias.dtype) if bias is not None else None\n return (dx, dw, db, dz) if not recompute_output else (dx, dw, db, dz, out)\n", - "description_1": "Use triton language to implement a layer normalization forward and backward pass. The forward kernel (_layer_norm_fwd_1pass_kernel) takes 17 parameters: pointers to input, output, weights, biases, other branch, mean, and 1/std, strides for input, output, and other branch, number of rows and columns in input, epsilon for numerical stability, and several compile-time constants. The backward kernel (_layer_norm_bwd_kernel) takes 28 parameters: pointers to input, weights, biases, other branch, output, output gradient, input gradient, partial sums of weights and biases gradients, other branch gradient, mean, 1/std, strides for various tensors, number of rows and columns in input, epsilon, rows per program, and several compile-time constants. The forward function (_layer_norm_fwd) prepares data and calls the forward kernel, while the backward function (_layer_norm_bwd) prepares data and calls the backward kernel.", - "description_2": "Use triton language to create a layer normalization operation with both forward and backward passes, handling optional bias and additional branch inputs, and supporting RMS normalization.", - "difficulty": 4 - }, - { - "code": "import triton\nimport triton.language as tl\n\n@triton.autotune(\n configs=[\n triton.Config(\n {\"V_BLOCK_SIZE\": 512, \"N_BLOCK_SIZE\": 64, \"H_BLOCK_SIZE\": 64, \"GROUP_SIZE\": 1}, num_warps=8, num_stages=2\n ),\n triton.Config(\n {\"V_BLOCK_SIZE\": 256, \"N_BLOCK_SIZE\": 64, \"H_BLOCK_SIZE\": 64, \"GROUP_SIZE\": 32}, num_warps=4, num_stages=3\n ),\n triton.Config(\n {\"V_BLOCK_SIZE\": 256, \"N_BLOCK_SIZE\": 64, \"H_BLOCK_SIZE\": 64, \"GROUP_SIZE\": 16}, num_warps=8, num_stages=3\n ),\n triton.Config(\n {\"V_BLOCK_SIZE\": 128, \"N_BLOCK_SIZE\": 64, \"H_BLOCK_SIZE\": 64, \"GROUP_SIZE\": 1}, num_warps=4, num_stages=2\n ),\n triton.Config(\n {\"V_BLOCK_SIZE\": 512, \"N_BLOCK_SIZE\": 64, \"H_BLOCK_SIZE\": 64, \"GROUP_SIZE\": 16}, num_warps=8, num_stages=2\n ),\n triton.Config(\n {\"V_BLOCK_SIZE\": 256, \"N_BLOCK_SIZE\": 64, \"H_BLOCK_SIZE\": 64, \"GROUP_SIZE\": 1}, num_warps=4, num_stages=2\n ),\n triton.Config(\n {\"V_BLOCK_SIZE\": 128, \"N_BLOCK_SIZE\": 64, \"H_BLOCK_SIZE\": 64, \"GROUP_SIZE\": 16}, num_warps=4, num_stages=3\n ),\n triton.Config(\n {\"V_BLOCK_SIZE\": 256, \"N_BLOCK_SIZE\": 64, \"H_BLOCK_SIZE\": 64, \"GROUP_SIZE\": 1}, num_warps=8, num_stages=3\n ),\n triton.Config(\n {\"V_BLOCK_SIZE\": 512, \"N_BLOCK_SIZE\": 64, \"H_BLOCK_SIZE\": 64, \"GROUP_SIZE\": 32}, num_warps=8, num_stages=2\n ),\n triton.Config(\n {\"V_BLOCK_SIZE\": 256, \"N_BLOCK_SIZE\": 128, \"H_BLOCK_SIZE\": 64, \"GROUP_SIZE\": 1}, num_warps=8, num_stages=2\n ),\n triton.Config(\n {\"V_BLOCK_SIZE\": 256, \"N_BLOCK_SIZE\": 128, \"H_BLOCK_SIZE\": 64, \"GROUP_SIZE\": 64}, num_warps=8, num_stages=2\n ),\n triton.Config(\n {\"V_BLOCK_SIZE\": 256, \"N_BLOCK_SIZE\": 128, \"H_BLOCK_SIZE\": 64, \"GROUP_SIZE\": 64}, num_warps=16, num_stages=2\n ),\n triton.Config(\n {\"V_BLOCK_SIZE\": 256, \"N_BLOCK_SIZE\": 128, \"H_BLOCK_SIZE\": 128, \"GROUP_SIZE\": 1}, num_warps=8, num_stages=2\n ),\n ],\n key=[\"V\", \"N\", \"H\", \"monitoring\"],\n prune_configs_by={\n \"early_config_prune\": lambda configs, named_args: configs,\n },\n)\n@triton.jit\ndef linear_xent_fwd_prep_bwd_kernel_matmul_t(\n x_ptr,\n y_ptr,\n A_t_ptr,\n z_nv_ptr,\n losses_ptr,\n lse_ptr,\n m_ptr,\n logit_norm_ptr,\n stride_x_N,\n stride_x_H,\n stride_A_H,\n stride_A_V,\n stride_z_N,\n stride_z_V,\n stride_lse_N,\n stride_lse_B,\n stride_loss_Nb,\n stride_loss_B,\n stride_norm_N,\n stride_norm_V,\n reduction_ptr,\n monitoring: tl.constexpr,\n ignore_index: tl.constexpr,\n logit_scale: tl.constexpr,\n idx_N_group,\n N_group: tl.constexpr,\n V: tl.constexpr,\n N: tl.constexpr,\n H: tl.constexpr,\n V_BLOCK_SIZE: tl.constexpr,\n N_BLOCK_SIZE: tl.constexpr,\n H_BLOCK_SIZE: tl.constexpr,\n GROUP_SIZE: tl.constexpr,\n):\n idx_N = tl.program_id(axis=0)\n idx_V_group = tl.program_id(axis=1)\n num_idx_N, num_idx_V_group = tl.num_programs(0), tl.num_programs(1)\n idx_N, idx_V_group = tl.swizzle2d(idx_N, idx_V_group, num_idx_N, num_idx_V_group, GROUP_SIZE) # type:ignore\n tl.static_print(N_group, V_BLOCK_SIZE, N_BLOCK_SIZE, H_BLOCK_SIZE, GROUP_SIZE, monitoring)\n R = tl.load(reduction_ptr, eviction_policy=\"evict_last\")\n V_GROUP_SIZE: tl.constexpr = V_BLOCK_SIZE\n x_block_ptr = tl.make_block_ptr(\n base=x_ptr,\n shape=(N, H),\n strides=(stride_x_N, stride_x_H),\n offsets=(idx_N_group * N_group + idx_N * N_BLOCK_SIZE, 0),\n block_shape=(N_BLOCK_SIZE, H_BLOCK_SIZE),\n order=(1, 0),\n )\n A_block_ptr = tl.make_block_ptr(\n base=A_t_ptr,\n shape=(H, V),\n strides=(stride_A_H, stride_A_V),\n offsets=(0, idx_V_group * V_GROUP_SIZE),\n block_shape=(H_BLOCK_SIZE, V_BLOCK_SIZE),\n order=(1, 0),\n )\n z_block_ptr = tl.make_block_ptr(\n base=z_nv_ptr,\n shape=(N_group, V),\n strides=(stride_z_N, stride_z_V),\n offsets=(idx_N * N_BLOCK_SIZE, idx_V_group * V_GROUP_SIZE),\n block_shape=(N_BLOCK_SIZE, V_BLOCK_SIZE),\n order=(1, 0),\n )\n\n z_j_to_k = tl.zeros((N_BLOCK_SIZE, V_BLOCK_SIZE), dtype=tl.float32)\n for _ in range(H // H_BLOCK_SIZE):\n x_chunk = tl.load(x_block_ptr) # Nc x H\n A_v = tl.load(A_block_ptr) # Vc x H\n\n z_j_to_k = tl.dot(x_chunk, A_v, z_j_to_k) # (Nc x H) @ (H x Vc)\n\n x_block_ptr = tl.advance(x_block_ptr, [0, H_BLOCK_SIZE])\n A_block_ptr = tl.advance(A_block_ptr, [H_BLOCK_SIZE, 0])\n\n z_j_to_k = z_j_to_k * logit_scale\n if monitoring:\n logit_pow2 = tl.sum(z_j_to_k * z_j_to_k, axis=1)\n norm_val_ptr = (\n logit_norm_ptr + idx_V_group * stride_norm_V + idx_N * stride_norm_N + tl.arange(0, N_BLOCK_SIZE)\n )\n tl.store(norm_val_ptr, logit_pow2 / N)\n m = tl.max(z_j_to_k, 1)\n s = tl.sum(tl.exp((z_j_to_k - m[:, None])), axis=1)\n\n N_range = idx_N_group * N_group + idx_N * N_BLOCK_SIZE + tl.arange(0, N_BLOCK_SIZE)\n V_range = idx_V_group * V_GROUP_SIZE + tl.arange(0, V_BLOCK_SIZE)\n y = tl.load(y_ptr + N_range)\n\n mask = y[:, None] == tl.where(V_range != ignore_index, V_range, -1)[None, :]\n loss = -tl.sum(tl.where(mask, z_j_to_k, 0.0)) / R\n\n tl.store(z_block_ptr, z_j_to_k.to(z_nv_ptr.type.element_ty))\n\n zero_lse_constant: tl.constexpr = tl.log(1 / tl.cdiv(V, V_BLOCK_SIZE))\n lse = tl.where(y != ignore_index, m + tl.log(s), zero_lse_constant)\n lse_row_ptr = tl.make_block_ptr(\n base=lse_ptr,\n shape=(N_group, V // 128),\n strides=(stride_lse_N, stride_lse_B),\n offsets=(idx_N * N_BLOCK_SIZE, idx_V_group),\n block_shape=(N_BLOCK_SIZE, 1),\n order=(1, 0),\n )\n tl.store(lse_row_ptr, lse[:, None])\n\n loss_val_ptr = losses_ptr + idx_N * stride_loss_Nb + idx_V_group * stride_loss_B\n tl.store(loss_val_ptr, tl.load(loss_val_ptr) + loss)\n\n if monitoring:\n m_val_ptr = m_ptr + idx_N * stride_loss_Nb + idx_V_group * stride_loss_B\n tl.store(m_val_ptr, tl.maximum(tl.load(m_val_ptr), tl.max(m, 0)))\n\n@triton.jit()\ndef linear_xent_bwd_kernel_matmul_t_epilogue_dx(\n z_nv_ptr,\n y_ptr,\n A_t_ptr,\n x_grad_ptr,\n lse_ptr,\n stride_x_N,\n stride_x_H,\n stride_A_H,\n stride_A_V,\n stride_z_N,\n stride_z_V,\n reduction_ptr,\n logit_scale: tl.constexpr,\n z_regularization: tl.constexpr,\n fp32_grad_accumulators: tl.constexpr,\n ignore_index: tl.constexpr,\n idx_N_group,\n N_group: tl.constexpr,\n V: tl.constexpr,\n N: tl.constexpr,\n H: tl.constexpr,\n V_BLOCK_SIZE: tl.constexpr,\n N_BLOCK_SIZE: tl.constexpr,\n H_BLOCK_SIZE: tl.constexpr,\n GROUP_SIZE: tl.constexpr,\n SPLIT_N: tl.constexpr,\n SPLIT_V: tl.constexpr,\n):\n idx_N = tl.program_id(axis=0) // SPLIT_V\n idx_H = tl.program_id(axis=1)\n idx_V_tile = tl.program_id(axis=0) % SPLIT_V\n\n num_idx_N = tl.num_programs(0) - (triton.cdiv(V, V_BLOCK_SIZE) * SPLIT_N)\n num_idx_H = tl.num_programs(1)\n idx_N, idx_H = tl.swizzle2d(idx_N, idx_H, num_idx_N // SPLIT_V, num_idx_H, GROUP_SIZE)\n\n V_split_offset = idx_V_tile * tl.cdiv(V, SPLIT_V)\n\n A_t_block_ptr = tl.make_block_ptr(\n base=A_t_ptr,\n shape=(H, V),\n strides=(stride_A_H, stride_A_V),\n offsets=(idx_H * H_BLOCK_SIZE, V_split_offset),\n block_shape=(H_BLOCK_SIZE, V_BLOCK_SIZE),\n order=(0, 1),\n )\n\n z_block_ptr = tl.make_block_ptr(\n base=z_nv_ptr,\n shape=(N_group, V),\n strides=(stride_z_N, stride_z_V),\n offsets=(idx_N * N_BLOCK_SIZE, V_split_offset),\n block_shape=(N_BLOCK_SIZE, V_BLOCK_SIZE),\n order=(1, 0),\n )\n N_range = idx_N_group * N_group + idx_N * N_BLOCK_SIZE + tl.arange(0, N_BLOCK_SIZE)\n v_range = V_split_offset + tl.arange(0, V_BLOCK_SIZE)\n R = tl.load(reduction_ptr, eviction_policy=\"evict_last\")\n\n y = tl.load(y_ptr + N_range, eviction_policy=\"evict_last\")\n lse = tl.load(lse_ptr + idx_N * N_BLOCK_SIZE + tl.arange(0, N_BLOCK_SIZE), eviction_policy=\"evict_last\")\n\n acc_dtype = tl.float32 if fp32_grad_accumulators else x_grad_ptr.type.element_ty\n x_grad_acc = tl.zeros((N_BLOCK_SIZE, H_BLOCK_SIZE), acc_dtype)\n for _ in range(0, tl.cdiv(V, V_BLOCK_SIZE * SPLIT_V)):\n mask = y[:, None] == v_range[None, :]\n A_v = tl.load(A_t_block_ptr, eviction_policy=\"evict_first\")\n z_j_to_k = tl.load(z_block_ptr)\n softmax_z = (z_j_to_k - lse[:, None]).exp()\n\n if z_regularization > 0:\n softmax_z += 2.0 * z_regularization * lse[:, None] * softmax_z\n z_grad = softmax_z - tl.where(mask, 1.0, 0.0)\n valid_z_grad = tl.where((y == ignore_index)[:, None], 0.0, z_grad).to(A_v.type.element_ty)\n\n x_grad_acc = tl.dot(valid_z_grad, A_v.trans(), x_grad_acc, out_dtype=acc_dtype)\n\n A_t_block_ptr = tl.advance(A_t_block_ptr, [0, V_BLOCK_SIZE])\n z_block_ptr = tl.advance(z_block_ptr, [0, V_BLOCK_SIZE])\n v_range += V_BLOCK_SIZE\n\n if SPLIT_V == 1:\n x_grad_block_ptr = tl.make_block_ptr(\n base=x_grad_ptr,\n shape=(N, H),\n strides=(stride_x_N, stride_x_H),\n offsets=(idx_N_group * N_group + idx_N * N_BLOCK_SIZE, idx_H * H_BLOCK_SIZE),\n block_shape=(N_BLOCK_SIZE, H_BLOCK_SIZE),\n order=(1, 0),\n )\n tl.store(x_grad_block_ptr, (x_grad_acc / R * logit_scale).to(x_grad_ptr.type.element_ty))\n else:\n row_n = idx_N_group * N_group + idx_N * N_BLOCK_SIZE + tl.arange(0, N_BLOCK_SIZE)\n row_h = idx_H * H_BLOCK_SIZE + tl.arange(0, H_BLOCK_SIZE)\n x_grad_simple_ptr = x_grad_ptr + row_n[:, None] * stride_x_N + row_h[None, :] * stride_x_H\n tl.atomic_add(x_grad_simple_ptr, (x_grad_acc / R * logit_scale).to(x_grad_ptr.type.element_ty))\n\n@triton.jit()\ndef linear_xent_bwd_kernel_matmul_t_epilogue_dA(\n z_nv_ptr,\n y_ptr,\n x_ptr,\n A_grad_ptr,\n lse_ptr,\n entropy_ptr,\n stride_x_N,\n stride_x_H,\n stride_A_H,\n stride_A_V,\n stride_z_N,\n stride_z_V,\n stride_ent_H,\n stride_ent_V,\n reduction_ptr,\n monitoring: tl.constexpr,\n logit_scale: tl.constexpr,\n z_regularization: tl.constexpr,\n fp32_grad_accumulators: tl.constexpr,\n ignore_index: tl.constexpr,\n idx_N_group,\n N_group: tl.constexpr,\n V: tl.constexpr,\n N: tl.constexpr,\n H: tl.constexpr,\n V_BLOCK_SIZE: tl.constexpr,\n N_BLOCK_SIZE: tl.constexpr,\n H_BLOCK_SIZE: tl.constexpr,\n GROUP_SIZE: tl.constexpr,\n SPLIT_N: tl.constexpr,\n SPLIT_V: tl.constexpr,\n):\n idx_V = (tl.program_id(axis=0) - N_group // N_BLOCK_SIZE * SPLIT_V) // SPLIT_N\n idx_H = tl.program_id(axis=1)\n idx_N_tile = (tl.program_id(axis=0) - N_group // N_BLOCK_SIZE * SPLIT_V) % SPLIT_N\n\n num_idx_V, num_idx_H = tl.num_programs(0) - (N_group // N_BLOCK_SIZE * SPLIT_V), tl.num_programs(1)\n idx_V, idx_H = tl.swizzle2d(idx_V, idx_H, num_idx_V // SPLIT_N, num_idx_H, GROUP_SIZE)\n\n N_split_offset = idx_N_tile * tl.cdiv(N_group, SPLIT_N)\n\n x_block_ptr = tl.make_block_ptr(\n base=x_ptr,\n shape=(N, H),\n strides=(stride_x_N, stride_x_H),\n offsets=(idx_N_group * N_group + N_split_offset, idx_H * H_BLOCK_SIZE),\n block_shape=(N_BLOCK_SIZE, H_BLOCK_SIZE),\n order=(1, 0),\n )\n\n z_block_ptr = tl.make_block_ptr(\n base=z_nv_ptr,\n shape=(N_group, V),\n strides=(stride_z_N, stride_z_V),\n offsets=(N_split_offset, idx_V * V_BLOCK_SIZE),\n block_shape=(N_BLOCK_SIZE, V_BLOCK_SIZE),\n order=(1, 0),\n )\n\n N_range = N_split_offset + tl.arange(0, N_BLOCK_SIZE)\n V_range = idx_V * V_BLOCK_SIZE + tl.arange(0, V_BLOCK_SIZE)\n R = tl.load(reduction_ptr, eviction_policy=\"evict_last\")\n logit_entropy = 0.0\n\n acc_dtype = tl.float32 if fp32_grad_accumulators else A_grad_ptr.type.element_ty\n A_grad_acc = tl.zeros((H_BLOCK_SIZE, V_BLOCK_SIZE), acc_dtype)\n for _ in range(0, tl.cdiv(N_group, N_BLOCK_SIZE * SPLIT_N)):\n y = tl.load(y_ptr + idx_N_group * N_group + N_range, eviction_policy=\"evict_last\")\n lse = tl.load(lse_ptr + N_range, eviction_policy=\"evict_last\")\n mask = y[:, None] == V_range[None, :]\n\n x_chunk = tl.load(x_block_ptr, eviction_policy=\"evict_first\")\n z_j_to_k = tl.load(z_block_ptr)\n logprobs = z_j_to_k - lse[:, None]\n softmax_z = logprobs.exp()\n if monitoring:\n logit_entropy += tl.sum(tl.where(y == ignore_index, 0.0, tl.sum(-softmax_z * logprobs, axis=1)))\n if z_regularization > 0:\n softmax_z += 2.0 * z_regularization * lse[:, None] * softmax_z\n z_grad = softmax_z - tl.where(mask, 1.0, 0.0)\n valid_z_grad = tl.where((y == ignore_index)[:, None], 0.0, z_grad).to(x_ptr.type.element_ty)\n\n A_grad_acc = tl.dot(x_chunk.trans(), valid_z_grad, A_grad_acc, out_dtype=acc_dtype)\n\n x_block_ptr = tl.advance(x_block_ptr, [N_BLOCK_SIZE, 0])\n z_block_ptr = tl.advance(z_block_ptr, [N_BLOCK_SIZE, 0])\n N_range += N_BLOCK_SIZE\n\n entropy_val_ptr = entropy_ptr + idx_H * stride_ent_H + idx_V * stride_ent_V\n if SPLIT_N == 1:\n A_grad_T_block_ptr = tl.make_block_ptr(\n base=A_grad_ptr,\n shape=(H, V),\n strides=(stride_A_H, stride_A_V),\n offsets=(idx_H * H_BLOCK_SIZE, idx_V * V_BLOCK_SIZE),\n block_shape=(H_BLOCK_SIZE, V_BLOCK_SIZE),\n order=(0, 1),\n )\n if idx_N_group > 0:\n tl.store(\n A_grad_T_block_ptr,\n tl.load(A_grad_T_block_ptr) + (A_grad_acc / R * logit_scale).to(A_grad_ptr.type.element_ty),\n )\n tl.store(entropy_val_ptr, tl.load(entropy_val_ptr) + logit_entropy / R)\n else:\n tl.store(A_grad_T_block_ptr, (A_grad_acc / R * logit_scale).to(A_grad_ptr.type.element_ty))\n if monitoring:\n tl.store(entropy_val_ptr, logit_entropy / R)\n else:\n row_h = idx_H * H_BLOCK_SIZE + tl.arange(0, H_BLOCK_SIZE)\n row_v = idx_V * V_BLOCK_SIZE + tl.arange(0, V_BLOCK_SIZE)\n A_grad_T_simple_ptr = A_grad_ptr + row_h[:, None] * stride_A_H + row_v[None, :] * stride_A_V\n tl.atomic_add(A_grad_T_simple_ptr, (A_grad_acc / R * logit_scale).to(A_grad_ptr.type.element_ty))\n if monitoring:\n tl.atomic_add(entropy_val_ptr, logit_entropy / R)\n\n@triton.autotune(\n configs=[\n triton.Config(\n {\"V_BLOCK_SIZE\": 128, \"N_BLOCK_SIZE\": 128, \"H_BLOCK_SIZE\": 128, \"GROUP_SIZE\": 16, \"SPLIT_N\": 1, \"SPLIT_V\": 1},\n num_warps=8,\n num_stages=1,\n ),\n triton.Config(\n {\"V_BLOCK_SIZE\": 128, \"N_BLOCK_SIZE\": 128, \"H_BLOCK_SIZE\": 256, \"GROUP_SIZE\": 32, \"SPLIT_N\": 1, \"SPLIT_V\": 4},\n num_warps=8,\n num_stages=2,\n ),\n triton.Config(\n {\"V_BLOCK_SIZE\": 128, \"N_BLOCK_SIZE\": 128, \"H_BLOCK_SIZE\": 256, \"GROUP_SIZE\": 16, \"SPLIT_N\": 1, \"SPLIT_V\": 1},\n num_warps=8,\n num_stages=2,\n ),\n triton.Config(\n {\"V_BLOCK_SIZE\": 128, \"N_BLOCK_SIZE\": 128, \"H_BLOCK_SIZE\": 128, \"GROUP_SIZE\": 32, \"SPLIT_N\": 1, \"SPLIT_V\": 1},\n num_warps=8,\n num_stages=3,\n ),\n triton.Config(\n {\"V_BLOCK_SIZE\": 128, \"N_BLOCK_SIZE\": 128, \"H_BLOCK_SIZE\": 256, \"GROUP_SIZE\": 32, \"SPLIT_N\": 1, \"SPLIT_V\": 1},\n num_warps=8,\n num_stages=2,\n ),\n triton.Config(\n {\"V_BLOCK_SIZE\": 128, \"N_BLOCK_SIZE\": 128, \"H_BLOCK_SIZE\": 128, \"GROUP_SIZE\": 16, \"SPLIT_N\": 1, \"SPLIT_V\": 1},\n num_warps=8,\n num_stages=2,\n ),\n triton.Config(\n {\"V_BLOCK_SIZE\": 128, \"N_BLOCK_SIZE\": 128, \"H_BLOCK_SIZE\": 256, \"GROUP_SIZE\": 16, \"SPLIT_N\": 1, \"SPLIT_V\": 2},\n num_warps=8,\n num_stages=2,\n ),\n triton.Config(\n {\"V_BLOCK_SIZE\": 128, \"N_BLOCK_SIZE\": 128, \"H_BLOCK_SIZE\": 256, \"GROUP_SIZE\": 16, \"SPLIT_N\": 1, \"SPLIT_V\": 4},\n num_warps=8,\n num_stages=2,\n ),\n triton.Config(\n {\"V_BLOCK_SIZE\": 128, \"N_BLOCK_SIZE\": 128, \"H_BLOCK_SIZE\": 256, \"GROUP_SIZE\": 64, \"SPLIT_N\": 1, \"SPLIT_V\": 4},\n num_warps=16,\n num_stages=2,\n ),\n triton.Config(\n {\"V_BLOCK_SIZE\": 128, \"N_BLOCK_SIZE\": 128, \"H_BLOCK_SIZE\": 256, \"GROUP_SIZE\": 64, \"SPLIT_N\": 1, \"SPLIT_V\": 8},\n num_warps=16,\n ),\n triton.Config(\n {\"V_BLOCK_SIZE\": 128, \"N_BLOCK_SIZE\": 128, \"H_BLOCK_SIZE\": 256, \"GROUP_SIZE\": 64, \"SPLIT_N\": 1, \"SPLIT_V\": 8},\n num_warps=8,\n ),\n triton.Config(\n {\"V_BLOCK_SIZE\": 128, \"N_BLOCK_SIZE\": 128, \"H_BLOCK_SIZE\": 256, \"GROUP_SIZE\": 64, \"SPLIT_N\": 1, \"SPLIT_V\": 4},\n num_warps=8,\n num_stages=2,\n ),\n ],\n key=[\"V\", \"N\", \"H\", \"monitoring\"],\n prune_configs_by={\n \"early_config_prune\": lambda configs, named_args: configs,\n },\n)\n@triton.jit()\ndef linear_xent_bwd_dispatcher(\n logits_ptr,\n y_ptr,\n x_ptr,\n A_t_ptr,\n x_grad,\n At_grad,\n lse_global,\n logit_entropy_local,\n stride_x_N,\n stride_x_H,\n stride_A_H,\n stride_A_V,\n stride_z_N,\n stride_z_V,\n stride_ent_H,\n stride_ent_V,\n reduction_ptr,\n monitoring: tl.constexpr,\n logit_scale: tl.constexpr,\n z_regularization: tl.constexpr,\n fp32_grad_accumulators: tl.constexpr,\n ignore_index: tl.constexpr,\n idx_N_group,\n N_group,\n V: tl.constexpr,\n N: tl.constexpr,\n H: tl.constexpr,\n V_BLOCK_SIZE: tl.constexpr = 128,\n N_BLOCK_SIZE: tl.constexpr = 128,\n H_BLOCK_SIZE: tl.constexpr = 128,\n GROUP_SIZE: tl.constexpr = 32,\n SPLIT_N: tl.constexpr = 2,\n SPLIT_V: tl.constexpr = 2,\n):\n idx_NV = tl.program_id(axis=0)\n tl.static_print(V_BLOCK_SIZE, N_BLOCK_SIZE, H_BLOCK_SIZE, GROUP_SIZE, SPLIT_N, SPLIT_V, monitoring)\n if idx_NV < (N_group // N_BLOCK_SIZE * SPLIT_V):\n linear_xent_bwd_kernel_matmul_t_epilogue_dx(\n logits_ptr,\n y_ptr,\n A_t_ptr,\n x_grad,\n lse_global,\n stride_x_N,\n stride_x_H,\n stride_A_H,\n stride_A_V,\n stride_z_N,\n stride_z_V,\n reduction_ptr,\n logit_scale,\n z_regularization,\n fp32_grad_accumulators,\n ignore_index,\n idx_N_group,\n N_group,\n V,\n N,\n H,\n V_BLOCK_SIZE,\n N_BLOCK_SIZE,\n H_BLOCK_SIZE,\n GROUP_SIZE,\n SPLIT_N,\n SPLIT_V,\n )\n else:\n linear_xent_bwd_kernel_matmul_t_epilogue_dA(\n logits_ptr,\n y_ptr,\n x_ptr,\n At_grad,\n lse_global,\n logit_entropy_local,\n stride_x_N,\n stride_x_H,\n stride_A_H,\n stride_A_V,\n stride_z_N,\n stride_z_V,\n stride_ent_H,\n stride_ent_V,\n reduction_ptr,\n monitoring,\n logit_scale,\n z_regularization,\n fp32_grad_accumulators,\n ignore_index,\n idx_N_group,\n N_group,\n V,\n N,\n H,\n V_BLOCK_SIZE,\n N_BLOCK_SIZE,\n H_BLOCK_SIZE,\n GROUP_SIZE,\n SPLIT_N,\n SPLIT_V,\n )\n", - "description_1": "Use triton language to implement forward and backward kernels for a linear cross-entropy computation with support for autotuning and gradient accumulation.", - "description_2": "Use triton language to develop optimized kernels for linear cross-entropy including forward, backward gradients for inputs and weights, with autotuning capabilities.", - "difficulty": 4 - }, - { - "code": "import triton\nimport triton.language as tl\n\n# softplus kernel\n@triton.jit\ndef softplus(dt):\n dt = tl.where(dt <= 20.0, tl.math.log(tl.math.exp(dt) + 1), dt)\n return dt\n\n@triton.jit\ndef softplus(dt):\n dt = tl.where(dt <= 20.0, tl.math.log1p(tl.exp(dt)), dt)\n return dt\n", - "description_1": "Use triton language to implement a softplus function kernel that takes a tensor as input, applies the softplus transformation element-wise, and returns the result.", - "description_2": "Use triton language to create a kernel for the softplus activation function, handling inputs element-wise and returning the modified output.", - "difficulty": 2 - }, - { - "code": "import paddle\nimport triton\nimport triton.language as tl\nfrom .math import rsqrt\nfrom ..utils import calculate_settings, custom_bwd, custom_fwd, ensure_contiguous\n\n_CASTING_MODE_NONE = tl.constexpr(-1)\n_CASTING_MODE_LLAMA = tl.constexpr(0)\n_CASTING_MODE_GEMMA = tl.constexpr(1)\n\n@triton.jit\ndef _rms_norm_forward_kernel(\n Y_ptr, Y_row_stride, X_ptr, X_row_stride, W_ptr, W_row_stride,\n RSTD_ptr, RSTD_row_stride, n_cols, eps, offset, casting_mode: tl.constexpr, BLOCK_SIZE: tl.constexpr,\n):\n \"\"\"\n y_i = (x_i / (RMS)) * (offset + wi), RMS = sqrt(sum(x_i^2) / N)\n \"\"\"\n row_idx = tl.program_id(0)\n col_offsets = tl.arange(0, BLOCK_SIZE)\n mask = col_offsets < n_cols\n\n Y_ptr += row_idx * Y_row_stride\n X_ptr += row_idx * X_row_stride\n RSTD_ptr += row_idx * RSTD_row_stride\n\n X_row = tl.load(X_ptr + col_offsets, mask=mask, other=0)\n X_row_dtype = X_row.dtype\n W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0)\n\n if casting_mode == _CASTING_MODE_LLAMA:\n X_row = X_row.to(tl.float32)\n\n if casting_mode == _CASTING_MODE_GEMMA:\n W_row = W_row.to(tl.float32)\n X_row = X_row.to(tl.float32)\n\n mean_square = tl.sum(X_row * X_row, axis=0) / n_cols\n rstd = rsqrt(mean_square + eps)\n tl.store(RSTD_ptr, rstd)\n\n X_row = X_row * rstd\n\n if casting_mode == _CASTING_MODE_LLAMA:\n X_row = X_row.to(X_row_dtype)\n\n Y_row = X_row * (offset + W_row)\n tl.store(Y_ptr + col_offsets, Y_row, mask=mask)\n\n@triton.jit\ndef _rms_norm_backward_kernel(\n dY_ptr, dY_row_stride, X_ptr, X_row_stride, W_ptr, W_row_stride,\n RSTD_ptr, RSTD_row_stride, dW_ptr, dW_row_stride, n_cols, offset, casting_mode: tl.constexpr, BLOCK_SIZE: tl.constexpr,\n):\n \"\"\"\n dx = (1 / RMS) * [dy * (w + offset - (1 / N) * (1 / RMS^2) * ((dy * (w + offset)) dot x) * x].\n multiplication, whileas dot means dot product dw = sum(dy * (x / RMS)).\n \"\"\"\n row_idx = tl.program_id(0)\n col_offsets = tl.arange(0, BLOCK_SIZE)\n mask = col_offsets < n_cols\n\n dY_ptr += row_idx * dY_row_stride\n X_ptr += row_idx * X_row_stride\n RSTD_ptr += row_idx * RSTD_row_stride\n dW_ptr += row_idx * dW_row_stride\n\n dY_row = tl.load(dY_ptr + col_offsets, mask=mask, other=0)\n X_row = tl.load(X_ptr + col_offsets, mask=mask, other=0)\n W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0)\n original_x_dtype = X_row.dtype\n\n rstd_row = tl.load(RSTD_ptr)\n W_row = W_row + offset\n X_row = X_row.to(tl.float32)\n\n if casting_mode == _CASTING_MODE_LLAMA:\n m = (dY_row * W_row).to(tl.float32)\n\n elif casting_mode == _CASTING_MODE_GEMMA:\n dY_row, W_row = dY_row.to(tl.float32), W_row.to(tl.float32)\n\n m = dY_row * W_row\n dX_row = rstd_row * m\n dX_row += (rstd_row) * (-(1 / n_cols) * rstd_row * rstd_row * tl.sum(m * X_row, axis=0) * X_row)\n\n if casting_mode == _CASTING_MODE_LLAMA:\n dW_row = dY_row * (X_row * rstd_row).to(original_x_dtype)\n else:\n dW_row = dY_row * (X_row * rstd_row)\n\n tl.store(dY_ptr + col_offsets, dX_row, mask=mask)\n tl.store(dW_ptr + col_offsets, dW_row, mask=mask)\n\n_str_to_casting_mode = {\n \"llama\": _CASTING_MODE_LLAMA.value,\n \"gemma\": _CASTING_MODE_GEMMA.value,\n \"none\": _CASTING_MODE_NONE.value,\n}\n\ndef rms_norm_forward(X, W, eps, offset, casting_mode):\n if not isinstance(casting_mode, int):\n assert casting_mode in _str_to_casting_mode, f\"Invalid casting mode: {casting_mode}\"\n casting_mode = _str_to_casting_mode[casting_mode]\n else:\n assert casting_mode in _str_to_casting_mode.values(), f\"Invalid casting mode: {casting_mode}\"\n\n shape = X.shape\n dim = shape[-1]\n X = X.reshape([-1, dim])\n n_rows, n_cols = X.shape\n BLOCK_SIZE, num_warps = calculate_settings(n_cols)\n\n Y = paddle.empty((n_rows, n_cols), dtype=X.dtype)\n rstd_dtype = paddle.float32 if casting_mode in (_CASTING_MODE_LLAMA.value, _CASTING_MODE_GEMMA.value) else X.dtype\n RSTD = paddle.empty((n_rows,), dtype=rstd_dtype)\n\n assert X.shape[1] == W.shape[0], \"Incompatible hidden size dimension between tensor1.shape[1] and tensor2.shape[0]\"\n\n _rms_norm_forward_kernel[(n_rows,)](\n Y, Y.strides[0], X, X.strides[0], W, W.strides[0],\n RSTD, RSTD.strides[0], n_cols, eps, offset, casting_mode,\n BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps,\n )\n return Y.reshape(shape), X, RSTD, BLOCK_SIZE, num_warps, casting_mode\n\ndef rms_norm_backward(dY, X, W, RSTD, offset, casting_mode, BLOCK_SIZE, num_warps):\n shape = dY.shape\n dim = shape[-1]\n dY = dY.reshape([-1, dim])\n n_rows, n_cols = dY.shape\n dW = paddle.empty_like(\n X, dtype=(paddle.float32 if casting_mode == _CASTING_MODE_GEMMA.value else W.dtype),\n )\n\n _rms_norm_backward_kernel[(n_rows,)](\n dY, dY.strides[0], X, X.strides[0], W, W.strides[0],\n RSTD, RSTD.strides[0], dW, dW.strides[0], n_cols, offset, casting_mode,\n BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps,\n )\n dX = dY.reshape(shape)\n dW = dW.sum(axis=0).cast(W.dtype)\n return dX, dW\n", - "description_1": "Use triton language to implement RMSNorm forward and backward kernels. The forward kernel (_rms_norm_forward_kernel) has 12 parameters: output pointer (Y_ptr), row stride of output (Y_row_stride), input pointer (X_ptr), row stride of input (X_row_stride), weight pointer (W_ptr), row stride of weight (W_row_stride), RSTD pointer (RSTD_ptr), row stride of RSTD (RSTD_row_stride), number of columns (n_cols), epsilon (eps), offset, casting mode (casting_mode), and block size (BLOCK_SIZE). The backward kernel (_rms_norm_backward_kernel) also has 14 parameters: gradient of output pointer (dY_ptr), row stride of gradient of output (dY_row_stride), input pointer (X_ptr), row stride of input (X_row_stride), weight pointer (W_ptr), row stride of weight (W_row_stride), RSTD pointer (RSTD_ptr), row stride of RSTD (RSTD_row_stride), gradient of weight pointer (dW_ptr), row stride of gradient of weight (dW_row_stride), number of columns (n_cols), offset, casting mode (casting_mode), and block size (BLOCK_SIZE). These functions are used in rms_norm_forward and rms_norm_backward, which manage memory and launch the kernels.", - "description_2": "Use triton language to perform RMS normalization with customizable forward and backward kernels, handling different data types and casting modes efficiently.", - "difficulty": 3 - }, - { - "code": "import triton\nimport triton.language as tl\n\n@triton.jit\ndef _triton_rope(\n q_ptr,\n q_row_stride,\n k_ptr,\n k_row_stride,\n cos,\n cos_row_stride,\n sin,\n sin_row_stride,\n sl,\n bs: tl.constexpr,\n n_qh: tl.constexpr,\n n_kh: tl.constexpr,\n hd: tl.constexpr,\n pad_n_qh: tl.constexpr,\n pad_n_kh: tl.constexpr,\n pad_hd: tl.constexpr,\n BLOCK_SIZE: tl.constexpr,\n BACKWARD_PASS: tl.constexpr = False,\n):\n # q size: (bsz, seq_len, num_q_heads, head_dim)\n # q stride: (seq_len * num_q_heads * head_dim, num_q_heads * head_dim, head_dim, 1)\n # k size: (bsz, seq_len, num_kv_heads, head_dim)\n # k stride: (seq_len * num_kv_heads * head_dim, num_kv_heads * head_dim, head_dim, 1)\n\n # cos size: (1, seq_len, head_dim)\n # stride: (seq_len * head_dim, head_dim, 1)\n pid = tl.program_id(0)\n\n # locate start address\n q_ptr = q_ptr + pid * q_row_stride\n k_ptr = k_ptr + pid * k_row_stride\n\n # ####################################################################\n # get the cos(mθ_{i...d/2}) and sin(mθ_{i...d/2}) for token position\n # m of this program instance\n # ####################################################################\n\n # 1. program instances are laid out in a 1D vector of size bsz * seq_len, which\n # effectively represents a 2D grid of size [bsz, seq_len] with seq_len dimension\n # being the fastest changing dimension. Thus we can simply do pid // sl to get the batch index\n # and pid % sl to get the sequence index.\n # 2. We only need the left half of cos and sin matrix because the right half is just\n # a clone of the left half.\n cos_row_idx = pid % (sl)\n cos = cos + cos_row_idx * cos_row_stride\n sin = sin + cos_row_idx * sin_row_stride\n cos_offsets = tl.arange(0, pad_hd // 2)\n cos_mask = cos_offsets < hd // 2\n cos_row = tl.load(cos + cos_offsets, mask=cos_mask, other=0)\n sin_row = tl.load(sin + cos_offsets, mask=cos_mask, other=0)\n\n # ####################################################################\n # Load the left and right half of q and k for the current\n # program instance (i.e. for the current token) separately\n # ####################################################################\n # left half of the head\n first_half_q_offsets = tl.arange(0, pad_n_qh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :]\n first_half_k_offsets = tl.arange(0, pad_n_kh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :]\n first_q_mask = (tl.arange(0, pad_n_qh)[:, None] < n_qh) & (tl.arange(0, pad_hd // 2)[None, :] < hd // 2)\n first_k_mask = (tl.arange(0, pad_n_kh)[:, None] < n_kh) & (tl.arange(0, pad_hd // 2)[None, :] < hd // 2)\n q_tile_1 = tl.load(q_ptr + first_half_q_offsets, mask=first_q_mask, other=0).to(sin_row.dtype)\n k_tile_1 = tl.load(k_ptr + first_half_k_offsets, mask=first_k_mask, other=0).to(sin_row.dtype)\n\n # right half of the head\n second_half_q_offsets = first_half_q_offsets + (hd // 2)\n second_half_k_offsets = first_half_k_offsets + (hd // 2)\n second_q_mask = first_q_mask\n second_k_mask = first_k_mask\n q_tile_2 = tl.load(q_ptr + second_half_q_offsets, mask=second_q_mask, other=0).to(sin_row.dtype)\n k_tile_2 = tl.load(k_ptr + second_half_k_offsets, mask=second_k_mask, other=0).to(sin_row.dtype)\n\n if not BACKWARD_PASS:\n # y = [x1, x2] * [cos, cos] + [-x2, x1] * [sin, sin]\n new_q_tile_1 = q_tile_1 * cos_row - q_tile_2 * sin_row\n tl.store(q_ptr + first_half_q_offsets, new_q_tile_1, mask=first_q_mask)\n new_q_tile_2 = q_tile_2 * cos_row + q_tile_1 * sin_row\n tl.store(q_ptr + second_half_q_offsets, new_q_tile_2, mask=second_q_mask)\n\n new_k_tile_1 = k_tile_1 * cos_row - k_tile_2 * sin_row\n tl.store(k_ptr + first_half_k_offsets, new_k_tile_1, mask=first_k_mask)\n new_k_tile_2 = k_tile_2 * cos_row + k_tile_1 * sin_row\n tl.store(k_ptr + second_half_k_offsets, new_k_tile_2, mask=second_k_mask)\n else:\n # with some math, we can get:\n # dy = [dx1, dx2] * [cos, cos] + [-dx2, dx1] * [-sin, -sin]\n new_q_tile_1 = q_tile_1 * cos_row + q_tile_2 * sin_row\n tl.store(q_ptr + first_half_q_offsets, new_q_tile_1, mask=first_q_mask)\n new_q_tile_2 = q_tile_2 * cos_row - q_tile_1 * sin_row\n tl.store(q_ptr + second_half_q_offsets, new_q_tile_2, mask=second_q_mask)\n\n new_k_tile_1 = k_tile_1 * cos_row + k_tile_2 * sin_row\n tl.store(k_ptr + first_half_k_offsets, new_k_tile_1, mask=first_k_mask)\n new_k_tile_2 = k_tile_2 * cos_row - k_tile_1 * sin_row\n tl.store(k_ptr + second_half_k_offsets, new_k_tile_2, mask=second_k_mask)\n\n\ndef rope_forward(q, k, cos, sin):\n perm = list(range(q.shape))\n perm[2], perm[1] = perm[1], perm[2]\n # transpose it back to the physical shape because Triton looks at the physical storage\n # note: q and k are incontiguous before the transformation and will become contiguous after transpose\n q = q.transpose(perm)\n k = k.transpose(perm)\n\n batch_size, seq_len, n_q_head, head_dim = q.shape\n n_kv_head = k.shape[2]\n pad_hd = triton.next_power_of_2(head_dim)\n pad_n_q_head = triton.next_power_of_2(n_q_head)\n pad_n_kv_head = triton.next_power_of_2(n_kv_head)\n BLOCK_SIZE = max(pad_n_q_head, pad_n_kv_head)\n\n n_row = batch_size * seq_len\n\n # ensure tensors passed into the kernel are contiguous. It will be no-op if they are already contiguous\n q = q.contiguous()\n k = k.contiguous()\n cos = cos.contiguous()\n sin = sin.contiguous()\n\n _triton_rope[(n_row,)](\n q,\n q.strides[1],\n k,\n k.strides[1],\n cos,\n cos.strides[-2],\n sin,\n sin.strides[-2],\n seq_len,\n batch_size,\n n_q_head,\n n_kv_head,\n head_dim,\n pad_n_q_head,\n pad_n_kv_head,\n pad_hd,\n BLOCK_SIZE=BLOCK_SIZE,\n BACKWARD_PASS=False,\n )\n return q.transpose(perm), k.transpose(perm), cos, sin\n\n\ndef rope_backward(dq, dk, cos, sin):\n perm = list(range(dq.shape))\n perm[2], perm[1] = perm[1], perm[2]\n dq = dq.transpose(perm)\n dk = dk.transpose(perm)\n\n batch_size, seq_len, n_q_head, head_dim = dq.shape\n n_kv_head = dk.shape[2]\n pad_hd = triton.next_power_of_2(head_dim)\n pad_n_q_head = triton.next_power_of_2(n_q_head)\n pad_n_kv_head = triton.next_power_of_2(n_kv_head)\n BLOCK_SIZE = max(pad_n_q_head, pad_n_kv_head)\n\n n_row = batch_size * seq_len\n\n # ensure dq and dk are contiguous\n dq = dq.contiguous()\n dk = dk.contiguous()\n\n # backward is similar to forward except swapping few ops\n _triton_rope[(n_row,)](\n dq,\n dq.strides[1],\n dk,\n dk.strides[1],\n cos,\n cos.strides[-2],\n sin,\n sin.strides[-2],\n seq_len,\n batch_size,\n n_q_head,\n n_kv_head,\n head_dim,\n pad_n_q_head,\n pad_n_kv_head,\n pad_hd,\n BLOCK_SIZE=BLOCK_SIZE,\n BACKWARD_PASS=True,\n )\n return dq.transpose(perm), dk.transpose(perm)\n", - "description_1": "Use triton language to implement a rotary positional embedding (RoPE) operation. The kernel function '_triton_rope' takes 18 parameters: q_ptr, q_row_stride, k_ptr, k_row_stride, cos, cos_row_stride, sin, sin_row_stride, sl, bs, n_qh, n_kh, hd, pad_n_qh, pad_n_kh, pad_hd, BLOCK_SIZE, and BACKWARD_PASS. It performs a transformation on the input query and key tensors using cosine and sine matrices. The 'rope_forward' function calls this kernel with 16 parameters: q, k, cos, sin, seq_len, batch_size, n_q_head, n_kv_head, head_dim, pad_n_q_head, pad_n_kv_head, pad_hd, BLOCK_SIZE, and BACKWARD_PASS set to False. The 'rope_backward' function calls the kernel with similar parameters but with BACKWARD_PASS set to True.", - "description_2": "Use triton language to create a kernel for rotary positional embedding, transforming input tensors with cosine and sine matrices. Implement forward and backward functions to call this kernel with appropriate parameters.", - "difficulty": 4 - }, - { - "code": "import triton\nimport triton.language as tl\nimport paddle\n\n@triton.jit\ndef _selective_scan_update_kernel(\n state_ptr, x_ptr, dt_ptr, dt_bias_ptr, A_ptr, B_ptr, C_ptr, D_ptr, z_ptr, out_ptr,\n batch, nheads, dim, dstate, nheads_ngroups_ratio,\n stride_state_batch, stride_state_head, stride_state_dim, stride_state_dstate,\n stride_x_batch, stride_x_head, stride_x_dim,\n stride_dt_batch, stride_dt_head, stride_dt_dim,\n stride_dt_bias_head, stride_dt_bias_dim,\n stride_A_head, stride_A_dim, stride_A_dstate,\n stride_B_batch, stride_B_group, stride_B_dstate,\n stride_C_batch, stride_C_group, stride_C_dstate,\n stride_D_head, stride_D_dim,\n stride_z_batch, stride_z_head, stride_z_dim,\n stride_out_batch, stride_out_head, stride_out_dim,\n DT_SOFTPLUS: tl.constexpr, TIE_HDIM: tl.constexpr, BLOCK_SIZE_M: tl.constexpr,\n HAS_DT_BIAS: tl.constexpr, HAS_D: tl.constexpr, HAS_Z: tl.constexpr, BLOCK_SIZE_DSTATE: tl.constexpr,\n):\n pid_m = tl.program_id(axis=0)\n pid_b = tl.program_id(axis=1)\n pid_h = tl.program_id(axis=2)\n state_ptr += pid_b * stride_state_batch + pid_h * stride_state_head\n x_ptr += pid_b * stride_x_batch + pid_h * stride_x_head\n dt_ptr += pid_b * stride_dt_batch + pid_h * stride_dt_head\n if HAS_DT_BIAS:\n dt_bias_ptr += pid_h * stride_dt_bias_head\n A_ptr += pid_h * stride_A_head\n B_ptr += pid_b * stride_B_batch + (pid_h // nheads_ngroups_ratio) * stride_B_group\n C_ptr += pid_b * stride_C_batch + (pid_h // nheads_ngroups_ratio) * stride_C_group\n if HAS_Z:\n z_ptr += pid_b * stride_z_batch + pid_h * stride_z_head\n out_ptr += pid_b * stride_out_batch + pid_h * stride_out_head\n\n offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_n = tl.arange(0, BLOCK_SIZE_DSTATE)\n state_ptrs = state_ptr + (offs_m[:, None] * stride_state_dim + offs_n[None, :] * stride_state_dstate)\n x_ptrs = x_ptr + offs_m * stride_x_dim\n dt_ptrs = dt_ptr + offs_m * stride_dt_dim\n if HAS_DT_BIAS:\n dt_bias_ptrs = dt_bias_ptr + offs_m * stride_dt_bias_dim\n if HAS_D:\n D_ptr += pid_h * stride_D_head\n A_ptrs = A_ptr + (offs_m[:, None] * stride_A_dim + offs_n[None, :] * stride_A_dstate)\n B_ptrs = B_ptr + offs_n * stride_B_dstate\n C_ptrs = C_ptr + offs_n * stride_C_dstate\n if HAS_D:\n D_ptrs = D_ptr + offs_m * stride_D_dim\n if HAS_Z:\n z_ptrs = z_ptr + offs_m * stride_z_dim\n out_ptrs = out_ptr + offs_m * stride_out_dim\n\n state = tl.load(state_ptrs, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate), other=0.0)\n x = tl.load(x_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)\n if not TIE_HDIM:\n dt = tl.load(dt_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)\n if HAS_DT_BIAS:\n dt += tl.load(dt_bias_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)\n if DT_SOFTPLUS:\n dt = softplus(dt)\n A = tl.load(A_ptrs, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate), other=0.0).to(tl.float32)\n dA = tl.exp(A * dt[:, None])\n else:\n dt = tl.load(dt_ptr).to(tl.float32)\n if HAS_DT_BIAS:\n dt += tl.load(dt_bias_ptr).to(tl.float32)\n if DT_SOFTPLUS:\n dt = softplus(dt)\n A = tl.load(A_ptr).to(tl.float32)\n dA = tl.exp(A * dt)\n\n B = tl.load(B_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32)\n C = tl.load(C_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32)\n if HAS_D:\n D = tl.load(D_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)\n if HAS_Z:\n z = tl.load(z_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)\n\n if not TIE_HDIM:\n dB = B[None, :] * dt[:, None]\n else:\n dB = B * dt\n state = state * dA + dB * x[:, None]\n tl.store(state_ptrs, state, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate))\n out = tl.sum(state * C[None, :], axis=1)\n if HAS_D:\n out += x * D\n if HAS_Z:\n out *= z * tl.sigmoid(z)\n tl.store(out_ptrs, out, mask=offs_m < dim)\n\n\ndef selective_state_update(state, x, dt, A, B, C, D=None, z=None, dt_bias=None, dt_softplus=False):\n has_heads = state.dim() > 3\n if state.dim() == 3:\n state = state.unsqueeze(1)\n if x.dim() == 2:\n x = x.unsqueeze(1)\n if dt.dim() == 2:\n dt = dt.unsqueeze(1)\n if A.dim() == 2:\n A = A.unsqueeze(0)\n if B.dim() == 2:\n B = B.unsqueeze(1)\n if C.dim() == 2:\n C = C.unsqueeze(1)\n if D is not None and D.dim() == 1:\n D = D.unsqueeze(0)\n if z is not None and z.dim() == 2:\n z = z.unsqueeze(1)\n if dt_bias is not None and dt_bias.dim() == 1:\n dt_bias = dt_bias.unsqueeze(0)\n batch, nheads, dim, dstate = state.shape\n assert tuple(x.shape) == (batch, nheads, dim)\n assert dt.shape == x.shape\n assert tuple(A.shape) == (nheads, dim, dstate)\n ngroups = B.shape[1]\n assert nheads % ngroups == 0, \"nheads must be divisible by ngroups\"\n assert tuple(B.shape) == (batch, ngroups, dstate)\n assert C.shape == B.shape\n if D is not None:\n assert tuple(D.shape) == (nheads, dim)\n if z is not None:\n assert z.shape == x.shape\n if dt_bias is not None:\n assert tuple(dt_bias.shape) == (nheads, dim)\n out = paddle.empty_like(x)\n grid = lambda META: (triton.cdiv(dim, META[\"BLOCK_SIZE_M\"]), batch, nheads)\n z_strides = (z.strides[0], z.strides[1], z.strides[2]) if z is not None else (0, 0, 0)\n BLOCK_SIZE_M, num_warps = (\n (32, 4)\n if dstate <= 16\n else ((16, 4) if dstate <= 32 else ((8, 4) if dstate <= 64 else ((4, 4) if dstate <= 128 else ((4, 8)))))\n )\n tie_hdim = A.strides[-1] == 0 and A.strides[-2] == 0 and dt.strides[-1] == 0 and dt_bias.strides[-1] == 0\n _selective_scan_update_kernel[grid](\n state, x, dt, dt_bias, A, B, C, D, z, out,\n batch, nheads, dim, dstate, nheads // ngroups,\n state.strides[0], state.strides[1], state.strides[2], state.strides[3],\n x.strides[0], x.strides[1], x.strides[2],\n dt.strides[0], dt.strides[1], dt.strides[2],\n *(dt_bias.strides[0], dt_bias.strides[1]) if dt_bias is not None else 0,\n A.strides[0], A.strides[1], A.strides[2],\n B.strides[0], B.strides[1], B.strides[2],\n C.strides[0], C.strides[1], C.strides[2],\n *(D.strides[0], D.strides[1]) if D is not None else 0,\n z_strides[0], z_strides[1], z_strides[2],\n out.strides[0], out.strides[1], out.strides[2],\n dt_softplus, tie_hdim, BLOCK_SIZE_M, num_warps=num_warps,\n )\n if not has_heads:\n out = out.squeeze(1)\n return out\n", - "description_1": "Use triton language to implement a kernel function '_selective_scan_update_kernel' with 56 parameters for updating state matrices with optional bias and scaling, and a wrapper function 'selective_state_update' with 9 parameters to prepare and call the kernel.", - "description_2": "Use triton language to create a kernel for matrix state updates with optional bias and scaling, and a wrapper to manage inputs and call the kernel.", - "difficulty": 4 - }, - { - "code": "import math\nimport paddle\nimport triton\nimport triton.language as tl\n\n@triton.autotune(\n configs=[\n triton.Config({\"BLOCK_SIZE_M\": 128, \"BLOCK_SIZE_N\": 256, \"BLOCK_SIZE_K\": 64}, num_stages=3, num_warps=8),\n triton.Config({\"BLOCK_SIZE_M\": 64, \"BLOCK_SIZE_N\": 256, \"BLOCK_SIZE_K\": 32}, num_stages=4, num_warps=4),\n triton.Config({\"BLOCK_SIZE_M\": 128, \"BLOCK_SIZE_N\": 128, \"BLOCK_SIZE_K\": 32}, num_stages=4, num_warps=4),\n triton.Config({\"BLOCK_SIZE_M\": 128, \"BLOCK_SIZE_N\": 64, \"BLOCK_SIZE_K\": 32}, num_stages=4, num_warps=4),\n triton.Config({\"BLOCK_SIZE_M\": 64, \"BLOCK_SIZE_N\": 128, \"BLOCK_SIZE_K\": 32}, num_stages=4, num_warps=4),\n triton.Config({\"BLOCK_SIZE_M\": 128, \"BLOCK_SIZE_N\": 32, \"BLOCK_SIZE_K\": 32}, num_stages=4, num_warps=4),\n triton.Config({\"BLOCK_SIZE_M\": 64, \"BLOCK_SIZE_N\": 32, \"BLOCK_SIZE_K\": 32}, num_stages=5, num_warps=2),\n triton.Config({\"BLOCK_SIZE_M\": 32, \"BLOCK_SIZE_N\": 64, \"BLOCK_SIZE_K\": 32}, num_stages=5, num_warps=2),\n triton.Config({\"BLOCK_SIZE_M\": 64, \"BLOCK_SIZE_N\": 64, \"BLOCK_SIZE_K\": 32}, num_stages=4, num_warps=2),\n ],\n key=[\"chunk_size\", \"K\", \"IS_CAUSAL\"],\n)\n@triton.jit\ndef _bmm_chunk_fwd_kernel(\n a_ptr, b_ptr, out_ptr, seq_idx_ptr, seqlen, chunk_size, K, ngroups,\n stride_a_batch, stride_a_seqlen, stride_a_head, stride_ak,\n stride_b_batch, stride_b_seqlen, stride_b_head, stride_bk,\n stride_out_batch, stride_out_chunk, stride_out_head, stride_outm, stride_outn,\n stride_seq_idx_batch, stride_seq_idx_seqlen,\n IS_CAUSAL: tl.constexpr, dot_dtype: tl.constexpr, HAS_SEQ_IDX: tl.constexpr,\n BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,\n):\n # Kernel code...\n\n@triton.autotune(\n configs=[\n triton.Config({\"BLOCK_SIZE_M\": 128, \"BLOCK_SIZE_N\": 256, \"BLOCK_SIZE_CS\": 64}, num_stages=3, num_warps=8),\n triton.Config({\"BLOCK_SIZE_M\": 64, \"BLOCK_SIZE_N\": 256, \"BLOCK_SIZE_CS\": 32}, num_stages=4, num_warps=4),\n triton.Config({\"BLOCK_SIZE_M\": 128, \"BLOCK_SIZE_N\": 128, \"BLOCK_SIZE_CS\": 32}, num_stages=4, num_warps=4),\n triton.Config({\"BLOCK_SIZE_M\": 128, \"BLOCK_SIZE_N\": 64, \"BLOCK_SIZE_CS\": 32}, num_stages=4, num_warps=4),\n triton.Config({\"BLOCK_SIZE_M\": 64, \"BLOCK_SIZE_N\": 128, \"BLOCK_SIZE_CS\": 32}, num_stages=4, num_warps=4),\n triton.Config({\"BLOCK_SIZE_M\": 128, \"BLOCK_SIZE_N\": 32, \"BLOCK_SIZE_CS\": 32}, num_stages=4, num_warps=4),\n triton.Config({\"BLOCK_SIZE_M\": 64, \"BLOCK_SIZE_N\": 32, \"BLOCK_SIZE_CS\": 32}, num_stages=5, num_warps=2),\n triton.Config({\"BLOCK_SIZE_M\": 32, \"BLOCK_SIZE_N\": 64, \"BLOCK_SIZE_CS\": 32}, num_stages=5, num_warps=2),\n triton.Config({\"BLOCK_SIZE_M\": 64, \"BLOCK_SIZE_N\": 64, \"BLOCK_SIZE_CS\": 32}, num_stages=4, num_warps=2),\n ],\n key=[\"chunk_size\", \"K\"],\n)\n@triton.jit\ndef _bmm_chunk_bwd_kernel(\n a_ptr, dout_ptr, db_ptr, res_ptr, seqlen, chunk_size, K, ngroups,\n stride_a_batch, stride_a_seqlen, stride_a_head, stride_ak,\n stride_dout_batch, stride_dout_chunk, stride_dout_head, stride_dout_csize_m, stride_dout_csize_n,\n stride_db_batch, stride_db_seqlen, stride_db_head, stride_db_k,\n stride_res_batch, stride_res_seqlen, stride_res_head, stride_res_k,\n dot_dtype: tl.constexpr, HAS_RESIDUAL: tl.constexpr,\n BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_CS: tl.constexpr,\n):\n # Kernel code...\n\ndef _bmm_chunk_fwd(a, b, chunk_size, seq_idx=None, causal=False, output_dtype=None):\n has_groups = a.dim() == 4\n if not has_groups:\n batch, seqlen, k = a.shape\n else:\n batch, seqlen, ngroups, k = a.shape\n assert b.shape == a.shape\n if seq_idx is not None:\n assert tuple(seq_idx.shape) == (batch, seqlen)\n if a.strides[-1] != 1 and a.strides[1] != 1:\n a = a.contiguous()\n if b.strides[-1] != 1 and b.strides[1] != 1:\n b = b.contiguous()\n nchunks = math.ceil(seqlen / chunk_size)\n out_dtype = a.dtype if output_dtype is None else output_dtype\n out = paddle.empty(\n (batch, nchunks, chunk_size, chunk_size)\n if not has_groups\n else (batch, nchunks, ngroups, chunk_size, chunk_size),\n dtype=out_dtype,\n )\n dot_dtype = (\n tl.bfloat16\n if a.dtype == paddle.bfloat16 or b.dtype == paddle.bfloat16\n else (tl.float16 if a.dtype == paddle.float16 or b.dtype == paddle.float16 else tl.float32)\n )\n grid = lambda META: (\n triton.cdiv(chunk_size, META[\"BLOCK_SIZE_M\"]) * triton.cdiv(chunk_size, META[\"BLOCK_SIZE_N\"]),\n batch,\n nchunks if not has_groups else nchunks * ngroups,\n )\n _bmm_chunk_fwd_kernel[grid](\n a,\n b,\n out,\n seq_idx,\n seqlen,\n chunk_size,\n k,\n ngroups if has_groups else 1,\n a.strides[0],\n a.strides[1],\n 0 if not has_groups else a.strides[2],\n a.strides[-1],\n b.strides[0],\n b.strides[1],\n 0 if not has_groups else b.strides[2],\n b.strides[-1],\n out.strides[0],\n out.strides[1],\n 0 if not has_groups else out.strides[2],\n out.strides[-2],\n out.strides[-1],\n *((seq_idx.strides[0], seq_idx.strides[1]) if seq_idx is not None else (0, 0)),\n causal,\n dot_dtype,\n HAS_SEQ_IDX=seq_idx is not None,\n )\n return out\n\ndef _bmm_chunk_bwd(a, dout, residual=None, out=None):\n has_groups = a.dim() == 4\n if not has_groups:\n batch, seqlen, k = a.shape\n else:\n batch, seqlen, ngroups, k = a.shape\n nchunks, chunk_size = dout.shape[1], dout.shape[-1]\n if a.strides[-1] != 1 and a.strides[-2] != 1:\n a = a.contiguous()\n if dout.strides[-1] != 1 and dout.strides[-2] != 1:\n dout = dout.contiguous()\n if residual is not None:\n assert tuple(residual.shape) == (batch, seqlen, k) if not has_groups else (batch, seqlen, ngroups, k)\n if residual.strides[-1] != 1 and residual.strides[1] != 1:\n residual = residual.contiguous()\n if out is not None:\n assert out.shape == a.shape\n assert out.strides[-1] == 1 or out.strides[1] == 1\n else:\n out = paddle.empty_like(a)\n dot_dtype = (\n tl.bfloat16\n if a.dtype == paddle.bfloat16 or dout.dtype == paddle.bfloat16\n else (tl.float16 if a.dtype == paddle.float16 or dout.dtype == paddle.float16 else tl.float32)\n )\n grid = lambda META: (\n triton.cdiv(chunk_size, META[\"BLOCK_SIZE_M\"]) * triton.cdiv(k, META[\"BLOCK_SIZE_N\"]),\n batch,\n nchunks if not has_groups else nchunks * ngroups,\n )\n residual_strides = (\n (residual.strides[0], residual.strides[1], 0 if not has_groups else residual.strides[2], residual.strides[-1])\n if residual is not None\n else (0, 0, 0, 0)\n )\n _bmm_chunk_bwd_kernel[grid](\n a,\n dout,\n out,\n residual,\n seqlen,\n chunk_size,\n k,\n ngroups if has_groups else 1,\n a.strides[0],\n a.strides[1],\n 0 if not has_groups else a.strides[2],\n a.strides[-1],\n dout.strides[0],\n dout.strides[1],\n 0 if not has_groups else dout.strides[2],\n dout.strides[-2],\n dout.strides[-1],\n out.strides[0],\n out.strides[1],\n 0 if not has_groups else out.strides[2],\n out.strides[-1],\n residual_strides[0],\n residual_strides[1],\n residual_strides[2],\n residual_strides[3],\n dot_dtype,\n HAS_RESIDUAL=residual is not None,\n )\n return out\n", - "description_1": "Use triton language to implement two matrix multiplication kernels for forward and backward propagation. The forward kernel _bmm_chunk_fwd_kernel takes 27 parameters, including pointers to input matrices, matrix dimensions, and stride information, as well as meta-parameters like causality and block sizes. The backward kernel _bmm_chunk_bwd_kernel takes 24 parameters and performs gradient computations for matrix operations. The forward function _bmm_chunk_fwd calls _bmm_chunk_fwd_kernel and manages input/output preparation, while the backward function _bmm_chunk_bwd calls _bmm_chunk_bwd_kernel for gradient calculation, both handling optional grouping in matrices.", - "description_2": "Use triton language to create optimized kernels for batch matrix multiplication with support for chunked processing and optional causality, handling inputs with and without grouping dimensions, and enabling both forward and backward computations using autotuned kernel configurations.", - "difficulty": 4 - }, - { - "code": "import triton\nimport triton.language as tl\nimport paddle\nimport math\nfrom einops import rearrange, repeat\n\nTRITON_22 = True # Assuming Triton version is 2.2.0 or above for this context\n\n@triton.autotune(\n configs=[\n triton.Config({\"BLOCK_SIZE_M\": 128, \"BLOCK_SIZE_N\": 256, \"BLOCK_SIZE_K\": 64}, num_stages=3, num_warps=8),\n triton.Config({\"BLOCK_SIZE_M\": 64, \"BLOCK_SIZE_N\": 256, \"BLOCK_SIZE_K\": 32}, num_stages=4, num_warps=4),\n ],\n key=[\"chunk_size\", \"hdim\", \"dstate\", \"IS_CAUSAL\"],\n)\n@triton.jit\ndef _chunk_scan_fwd_kernel(\n # Pointers to matrices\n cb_ptr, x_ptr, z_ptr, out_ptr, out_x_ptr, dt_ptr, dA_cumsum_ptr,\n seq_idx_ptr, C_ptr, prev_states_ptr, D_ptr,\n # Matrix dimensions\n chunk_size, hdim, dstate, batch, seqlen, nheads_ngroups_ratio,\n # Strides\n stride_cb_batch, stride_cb_chunk, stride_cb_head, stride_cb_csize_m,\n stride_cb_csize_k, stride_x_batch, stride_x_seqlen, stride_x_head,\n stride_x_hdim, stride_z_batch, stride_z_seqlen, stride_z_head,\n stride_z_hdim, stride_out_batch, stride_out_seqlen, stride_out_head,\n stride_out_hdim, stride_dt_batch, stride_dt_chunk, stride_dt_head,\n stride_dt_csize, stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head,\n stride_dA_cs_csize, stride_seq_idx_batch, stride_seq_idx_seqlen,\n stride_C_batch, stride_C_seqlen, stride_C_head, stride_C_dstate,\n stride_states_batch, stride_states_chunk, stride_states_head,\n stride_states_hdim, stride_states_dstate, stride_D_head,\n # Meta-parameters\n IS_CAUSAL: tl.constexpr, HAS_D: tl.constexpr, D_HAS_HDIM: tl.constexpr,\n HAS_Z: tl.constexpr, HAS_SEQ_IDX: tl.constexpr, BLOCK_SIZE_M: tl.constexpr,\n BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,\n BLOCK_SIZE_DSTATE: tl.constexpr, IS_TRITON_22: tl.constexpr,\n):\n # Triton kernel implementation\n # (omitted here due to length and assuming it is given correctly in the context)\n\ndef _chunk_scan_fwd(cb, x, dt, dA_cumsum, C, states, D=None, z=None, seq_idx=None):\n # Function to launch the forward kernel\n batch, seqlen, nheads, headdim = x.shape\n _, _, nchunks, chunk_size = dt.shape\n _, _, ngroups, dstate = C.shape\n assert nheads % ngroups == 0\n assert tuple(C.shape) == (batch, seqlen, ngroups, dstate)\n assert tuple(cb.shape) == (batch, nchunks, ngroups, chunk_size, chunk_size)\n if z is not None:\n assert z.shape == x.shape\n if D is not None:\n assert tuple(D.shape) == (nheads, headdim) or D.shape[0] == nheads\n assert tuple(dt.shape) == (batch, nheads, nchunks, chunk_size)\n assert tuple(dA_cumsum.shape) == (batch, nheads, nchunks, chunk_size)\n assert tuple(states.shape) == (batch, nchunks, nheads, headdim, dstate)\n if seq_idx is not None:\n assert tuple(seq_idx.shape) == (batch, seqlen)\n # Allocates output.\n out = paddle.empty([batch, seqlen, nheads, headdim], dtype=x.dtype)\n if z is not None:\n out_x = paddle.empty([batch, seqlen, nheads, headdim], dtype=x.dtype)\n assert out_x.strides == out.strides\n else:\n out_x = None\n grid = lambda META: (\n triton.cdiv(chunk_size, META[\"BLOCK_SIZE_M\"]) * triton.cdiv(headdim, META[\"BLOCK_SIZE_N\"]),\n batch * nchunks,\n nheads,\n )\n z_strides = (z.strides[0], z.strides[1], z.strides[2], z.strides[3]) if z is not None else (0, 0, 0, 0)\n _chunk_scan_fwd_kernel[grid](\n cb,\n x,\n z,\n out,\n out_x,\n dt,\n dA_cumsum,\n seq_idx,\n C,\n states,\n D,\n chunk_size,\n headdim,\n dstate,\n batch,\n seqlen,\n nheads // ngroups,\n cb.strides[0],\n cb.strides[1],\n cb.strides[2],\n cb.strides[3],\n cb.strides[4],\n x.strides[0],\n x.strides[1],\n x.strides[2],\n x.strides[3],\n z_strides[0],\n z_strides[1],\n z_strides[2],\n z_strides[3],\n out.strides[0],\n out.strides[1],\n out.strides[2],\n out.strides[3],\n dt.strides[0],\n dt.strides[2],\n dt.strides[1],\n dt.strides[3],\n dA_cumsum.strides[0],\n dA_cumsum.strides[2],\n dA_cumsum.strides[1],\n dA_cumsum.strides[3],\n *((seq_idx.strides[0], seq_idx.strides[1]) if seq_idx is not None else (0, 0)),\n C.strides[0],\n C.strides[1],\n C.strides[2],\n C.strides[3],\n states.strides[0],\n states.strides[1],\n states.strides[2],\n states.strides[3],\n states.strides[4],\n D.strides[0] if D is not None else 0,\n True,\n D is not None,\n D.dim() == 2 if D is not None else True,\n BLOCK_SIZE_DSTATE=max(triton.next_power_of_2(dstate), 16),\n HAS_Z=z is not None,\n HAS_SEQ_IDX=seq_idx is not None,\n IS_TRITON_22=TRITON_22,\n )\n return out, out_x\n\ndef chunk_scan(B, C, x, dt, dA_cumsum, prev_states, D=None, z=None):\n \"\"\"\n prev_states contains the initial_states at index 0, and the state for the next-to-last chunk at index -1.\n\n Argument:\n B: (batch, seqlen, ngroups, dstate)\n C: (batch, seqlen, ngroups, dstate)\n x: (batch, seqlen, nheads, headdim)\n dt: (batch, nheads, nchunks, chunk_size)\n dA_cumsum: (batch, nheads, nchunks, chunk_size)\n prev_states: (batch, nchunks, nheads, headdim, dstate)\n D: (nheads, headdim) or (nheads,)\n z: (batch, seqlen, nheads, headdim)\n Return:\n out: (batch, seqlen, nheads, headdim)\n \"\"\"\n return _chunk_scan_fwd(B, C, x, dt, dA_cumsum, prev_states, D, z)\n", - "description_1": "Use triton language to implement a forward kernel that performs a block-scan operation on input matrices, enabling efficient computation of attention scores with optional dropout and state initialization. The kernel supports configurations such as block sizes and tuning for optimal performance. Additionally, implement a corresponding Python function to launch this kernel using input matrices and specified configurations.", - "description_2": "Use triton language to create a block-scan kernel for computing attention scores efficiently, with support for configurations and optimizations. Implement a Python interface to execute this kernel with given inputs.", - "difficulty": 4 - }, - { - "code": "import math\nimport paddle\nimport triton\nimport triton.language as tl\n\ndef init_to_zero(names):\n return lambda nargs: [nargs[name].zero_() for name in names if nargs[name] is not None]\n\n@triton.autotune(\n configs=[\n triton.Config({\"BLOCK_SIZE_H\": 1}),\n triton.Config({\"BLOCK_SIZE_H\": 2}),\n triton.Config({\"BLOCK_SIZE_H\": 4}),\n triton.Config({\"BLOCK_SIZE_H\": 8}),\n triton.Config({\"BLOCK_SIZE_H\": 16}),\n triton.Config({\"BLOCK_SIZE_H\": 32}),\n triton.Config({\"BLOCK_SIZE_H\": 64}),\n ],\n key=[\"chunk_size\", \"nheads\"],\n)\n@triton.jit\ndef _chunk_cumsum_fwd_kernel(\n dt_ptr, A_ptr, dt_bias_ptr, dt_out_ptr, dA_cumsum_ptr, batch, seqlen, nheads, chunk_size,\n dt_min, dt_max, stride_dt_batch, stride_dt_seqlen, stride_dt_head, stride_A_head, stride_dt_bias_head,\n stride_dt_out_batch, stride_dt_out_chunk, stride_dt_out_head, stride_dt_out_csize,\n stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize,\n DT_SOFTPLUS: tl.constexpr, HAS_DT_BIAS: tl.constexpr, BLOCK_SIZE_H: tl.constexpr, BLOCK_SIZE_CHUNK: tl.constexpr,\n):\n pid_b = tl.program_id(axis=0)\n pid_c = tl.program_id(axis=1)\n pid_h = tl.program_id(axis=2)\n dt_ptr += pid_b * stride_dt_batch + pid_c * chunk_size * stride_dt_seqlen\n dt_out_ptr += pid_b * stride_dt_out_batch + pid_c * stride_dt_out_chunk\n dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk\n\n offs_h = pid_h * BLOCK_SIZE_H + tl.arange(0, BLOCK_SIZE_H)\n offs_c = tl.arange(0, BLOCK_SIZE_CHUNK)\n dt_ptrs = dt_ptr + (offs_h[:, None] * stride_dt_head + offs_c[None, :] * stride_dt_seqlen)\n A_ptrs = A_ptr + offs_h * stride_A_head\n dt_out_ptrs = dt_out_ptr + (offs_h[:, None] * stride_dt_out_head + offs_c[None, :] * stride_dt_out_csize)\n dA_cs_ptrs = dA_cumsum_ptr + (offs_h[:, None] * stride_dA_cs_head + offs_c[None, :] * stride_dA_cs_csize)\n chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)\n\n dt = tl.load(dt_ptrs, mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit), other=0.0).to(tl.float32)\n if HAS_DT_BIAS:\n dt_bias = tl.load(dt_bias_ptr + offs_h * stride_dt_bias_head, mask=offs_h < nheads, other=0.0).to(tl.float32)\n dt += dt_bias[:, None]\n if DT_SOFTPLUS:\n dt = softplus(dt)\n dt = tl.minimum(tl.maximum(dt, dt_min), dt_max)\n dt = tl.where((offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit), dt, 0.0)\n tl.store(dt_out_ptrs, dt, mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size))\n A = tl.load(A_ptrs, mask=offs_h < nheads, other=0.0).to(tl.float32)\n dA = dt * A[:, None]\n dA_cs = tl.cumsum(dA, axis=1)\n tl.store(dA_cs_ptrs, dA_cs, mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size))\n\ndef _chunk_cumsum_fwd(dt, A, chunk_size, dt_bias=None, dt_softplus=False, dt_limit=(0.0, float(\"inf\"))):\n batch, seqlen, nheads = dt.shape\n assert A.shape[0] == nheads\n if dt_bias is not None:\n assert dt_bias.shape[0] == nheads\n nchunks = math.ceil(seqlen / chunk_size)\n dt_out = paddle.empty([batch, nheads, nchunks, chunk_size], dtype=paddle.float32)\n dA_cumsum = paddle.empty([batch, nheads, nchunks, chunk_size], dtype=paddle.float32)\n grid_chunk_cs = lambda META: (batch, nchunks, triton.cdiv(nheads, META[\"BLOCK_SIZE_H\"]))\n _chunk_cumsum_fwd_kernel[grid_chunk_cs](\n dt, A, dt_bias, dt_out, dA_cumsum, batch, seqlen, nheads, chunk_size,\n dt_limit[0], dt_limit[1], dt.strides[0], dt.strides[1], dt.strides[2], A.strides[0],\n dt_bias.strides[0] if dt_bias is not None else 0,\n dt_out.strides[0], dt_out.strides[2], dt_out.strides[1], dt_out.strides[3],\n dA_cumsum.strides[0], dA_cumsum.strides[2], dA_cumsum.strides[1], dA_cumsum.strides[3],\n dt_softplus, HAS_DT_BIAS=dt_bias is not None, BLOCK_SIZE_CHUNK=triton.next_power_of_2(chunk_size),\n )\n return dA_cumsum, dt_out\n", - "description_1": "Use triton language to implement a cumulative sum forward operation for a matrix. The kernel function _chunk_cumsum_fwd_kernel takes pointers to input and output matrices, matrix dimensions, strides, and meta-parameters. It performs operations like matrix loading, bias addition, softplus transformation, clamping, and cumulative summation. The _chunk_cumsum_fwd function sets up the grid and invokes the kernel to perform the operation, handling necessary parameters like batch, sequence length, number of heads, chunk size, and optional bias.", - "description_2": "Use triton language to implement a kernel that performs cumulative summation over chunks of a matrix, applying optional bias and softplus transformations.", - "difficulty": 4 - }, - { - "code": "import paddle\nimport triton\nimport triton.language as tl\nfrom einops import rearrange\nimport paddle.nn.functional as F\n\nTRITON_22 = True # Assuming Triton version is 2.2.0 or higher\n\n@triton.autotune(\n configs=[\n triton.Config(\n {\"BLOCK_SIZE_M\": 128, \"BLOCK_SIZE_N\": 256, \"BLOCK_SIZE_K\": 64},\n num_stages=3,\n num_warps=8,\n pre_hook=lambda nargs: [nargs[\"ddt_ptr\"].zero_() if nargs[\"ddt_ptr\"] is not None else None],\n ),\n # Additional configurations omitted for brevity\n ],\n key=[\"chunk_size\", \"hdim\", \"dstate\"],\n)\n@triton.jit\ndef _chunk_scan_chunk_state_bwd_dx_kernel(\n x_ptr, cb_ptr, dout_ptr, dt_ptr, dA_cumsum_ptr, seq_idx_ptr, D_ptr, b_ptr, dstates_ptr, dx_ptr, ddt_ptr, dD_ptr,\n chunk_size, hdim, dstate, batch, seqlen, nheads_ngroups_ratio,\n stride_x_batch, stride_x_seqlen, stride_x_head, stride_x_hdim,\n stride_cb_batch, stride_cb_chunk, stride_cb_head, stride_cb_csize_m, stride_cb_csize_k,\n stride_dout_batch, stride_dout_seqlen, stride_dout_head, stride_dout_hdim,\n stride_dt_batch, stride_dt_chunk, stride_dt_head, stride_dt_csize,\n stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize,\n stride_seq_idx_batch, stride_seq_idx_seqlen,\n stride_D_head,\n stride_b_batch, stride_b_seqlen, stride_b_head, stride_b_dstate,\n stride_dstates_batch, stride_dstates_chunk, stride_dstates_head, stride_dstates_hdim, stride_dstates_dstate,\n stride_dx_batch, stride_dx_seqlen, stride_dx_head, stride_dx_hdim,\n stride_ddt_batch, stride_ddt_chunk, stride_ddt_head, stride_ddt_csize,\n stride_dD_batch, stride_dD_chunk, stride_dD_head, stride_dD_csize, stride_dD_hdim,\n HAS_D: tl.constexpr, D_HAS_HDIM: tl.constexpr, HAS_SEQ_IDX: tl.constexpr,\n BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, BLOCK_SIZE_DSTATE: tl.constexpr,\n IS_TRITON_22: tl.constexpr,\n):\n # Triton kernel implementation\n # The kernel computes the backward pass for a chunked scan operation\n # involving state updates and matrix multiplications.\n\ndef _chunk_scan_chunk_state_bwd_dx(x, dt, dA_cumsum, B, CB, dout, dstates, D=None, seq_idx=None, dx=None):\n batch, seqlen, nheads, headdim = x.shape\n _, _, nchunks, chunk_size = dt.shape\n _, _, ngroups, dstate = B.shape\n assert nheads % ngroups == 0\n assert tuple(B.shape) == (batch, seqlen, ngroups, dstate)\n assert tuple(CB.shape) == (batch, nchunks, ngroups, chunk_size, chunk_size)\n assert tuple(dt.shape) == (batch, nheads, nchunks, chunk_size)\n assert dA_cumsum.shape == dt.shape\n assert dout.shape == x.shape\n assert tuple(dstates.shape) == (batch, nchunks, nheads, headdim, dstate)\n if seq_idx is not None:\n assert tuple(seq_idx.shape) == (batch, seqlen)\n if D is not None:\n assert tuple(D.shape) == (nheads, headdim) or D.shape[0] == nheads\n assert D.strides[-1] == 1\n BLOCK_SIZE_min = 32\n dD = paddle.empty(\n [triton.cdiv(chunk_size, BLOCK_SIZE_min), batch, nchunks, nheads, headdim if D.dim() == 2 else 1],\n dtype=paddle.float32,\n )\n else:\n dD = None\n dD_strides = (\n (dD.strides[0], dD.strides[1], dD.strides[2], dD.strides[3], dD.strides[4])\n if D is not None\n else (0, 0, 0, 0, 0)\n )\n if dx is None:\n dx = paddle.empty_like(x)\n else:\n assert dx.shape == x.shape\n ddt = paddle.empty([batch, nheads, nchunks, chunk_size], dtype=paddle.float32)\n grid_dx = lambda META: (\n triton.cdiv(chunk_size, META[\"BLOCK_SIZE_M\"]) * triton.cdiv(headdim, META[\"BLOCK_SIZE_N\"]),\n batch * nchunks,\n nheads,\n )\n _chunk_scan_chunk_state_bwd_dx_kernel[grid_dx](\n x,\n CB,\n dout,\n dt,\n dA_cumsum,\n seq_idx,\n D,\n B,\n dstates,\n dx,\n ddt,\n dD,\n chunk_size,\n headdim,\n dstate,\n batch,\n seqlen,\n nheads // ngroups,\n x.strides[0],\n x.strides[1],\n x.strides[2],\n x.strides[3],\n CB.strides[0],\n CB.strides[1],\n CB.strides[2],\n CB.strides[-1],\n CB.strides[-2],\n dout.strides[0],\n dout.strides[1],\n dout.strides[2],\n dout.strides[3],\n dt.strides[0],\n dt.strides[2],\n dt.strides[1],\n dt.strides[3],\n dA_cumsum.strides[0],\n dA_cumsum.strides[2],\n dA_cumsum.strides[1],\n dA_cumsum.strides[3],\n *((seq_idx.strides[0], seq_idx.strides[1]) if seq_idx is not None else (0, 0)),\n D.strides[0] if D is not None else 0,\n B.strides[0],\n B.strides[1],\n B.strides[2],\n B.strides[3],\n dstates.strides[0],\n dstates.strides[1],\n dstates.strides[2],\n dstates.strides[3],\n dstates.strides[4],\n dx.strides[0],\n dx.strides[1],\n dx.strides[2],\n dx.strides[3],\n ddt.strides[0],\n ddt.strides[2],\n ddt.strides[1],\n ddt.strides[3],\n dD_strides[1],\n dD_strides[2],\n dD_strides[3],\n dD_strides[0],\n dD_strides[4],\n D is not None,\n D.dim() == 2 if D is not None else True,\n HAS_SEQ_IDX=seq_idx is not None,\n BLOCK_SIZE_DSTATE=max(triton.next_power_of_2(dstate), 16),\n IS_TRITON_22=TRITON_22,\n )\n if D is not None:\n BLOCK_SIZE_actual = _chunk_scan_chunk_state_bwd_dx_kernel.best_config.kwargs[\"BLOCK_SIZE_M\"]\n n_valid_blocks = (chunk_size + BLOCK_SIZE_actual - 1) // BLOCK_SIZE_actual\n dD = dD[:n_valid_blocks].sum(axis=(0, 1, 2)).cast(dtype=D.dtype)\n if D.dim() == 1:\n dD = rearrange(dD, \"h 1 -> h\")\n return dx, ddt.cast(dtype=dt.dtype), dD\n", - "description_1": "Use triton language to implement a backward pass kernel for a chunked scan operation. The kernel should handle state updates and matrix multiplications, taking into account various matrix dimensions, strides, and meta-parameters.", - "description_2": "Use triton language to create a wrapper function that sets up the grid and calls the backward pass kernel for a chunked scan operation. Ensure the function handles input and output tensors correctly, including optional parameters like D and seq_idx.", - "difficulty": 4 - }, - { - "code": "import paddle\nimport triton\nimport triton.language as tl\n\n@triton.autotune(\n configs=[\n triton.Config({\"BLOCK_SIZE\": 64}),\n triton.Config({\"BLOCK_SIZE\": 128}),\n triton.Config({\"BLOCK_SIZE\": 256}),\n triton.Config({\"BLOCK_SIZE\": 512}),\n triton.Config({\"BLOCK_SIZE\": 1024}),\n triton.Config({\"BLOCK_SIZE\": 2048}),\n ],\n key=[\"dim\"],\n)\n@triton.jit\ndef _state_passing_fwd_kernel(\n states_ptr, out_ptr, final_states_ptr, dA_cs_ptr, initstates_ptr, seq_idx_ptr,\n dim, nchunks, seqlen, chunk_size,\n stride_states_batch, stride_states_chunk, stride_states_head, stride_states_dim,\n stride_out_batch, stride_out_chunk, stride_out_head, stride_out_dim,\n stride_final_states_batch, stride_final_states_head, stride_final_states_dim,\n stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head,\n stride_initstates_batch, stride_initstates_head, stride_initstates_dim,\n stride_seq_idx_batch, stride_seq_idx_seqlen,\n HAS_INITSTATES: tl.constexpr, HAS_SEQ_IDX: tl.constexpr, BLOCK_SIZE: tl.constexpr,\n):\n pid_b = tl.program_id(axis=1)\n pid_h = tl.program_id(axis=2)\n pid_m = tl.program_id(axis=0)\n states_ptr += pid_b * stride_states_batch + pid_h * stride_states_head\n dA_cs_ptr += pid_b * stride_dA_cs_batch + pid_h * stride_dA_cs_head\n out_ptr += pid_b * stride_out_batch + pid_h * stride_out_head\n final_states_ptr += pid_b * stride_final_states_batch + pid_h * stride_final_states_head\n if HAS_INITSTATES:\n initstates_ptr += pid_b * stride_initstates_batch + pid_h * stride_initstates_head\n if HAS_SEQ_IDX:\n seq_idx_ptr += pid_b * stride_seq_idx_batch\n\n offs_m = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)\n states_ptrs = states_ptr + offs_m * stride_states_dim\n out_ptrs = out_ptr + offs_m * stride_out_dim\n final_states_ptrs = final_states_ptr + offs_m * stride_final_states_dim\n\n if not HAS_INITSTATES:\n states = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)\n else:\n initstates_ptrs = initstates_ptr + offs_m * stride_initstates_dim\n states = tl.load(initstates_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)\n tl.store(out_ptrs, states, mask=offs_m < dim)\n out_ptrs += stride_out_chunk\n seq_idx = 0\n for c in range(nchunks):\n new_states = tl.load(states_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)\n dA_cs = tl.load(dA_cs_ptr).to(tl.float32)\n scale = tl.exp(dA_cs)\n if HAS_SEQ_IDX:\n seq_idx_new = tl.load(seq_idx_ptr + (min((c + 1) * chunk_size, seqlen) - 1) * stride_seq_idx_seqlen)\n scale = tl.where(seq_idx_new == seq_idx, scale, 0.0)\n seq_idx = seq_idx_new\n states = scale * states + new_states\n if c < nchunks - 1:\n tl.store(out_ptrs, states, mask=offs_m < dim)\n else:\n tl.store(final_states_ptrs, states, mask=offs_m < dim)\n states_ptrs += stride_states_chunk\n dA_cs_ptr += stride_dA_cs_chunk\n out_ptrs += stride_out_chunk\n\n\n@triton.autotune(\n configs=[\n triton.Config({\"BLOCK_SIZE\": 64}),\n triton.Config({\"BLOCK_SIZE\": 128}),\n triton.Config({\"BLOCK_SIZE\": 256}),\n triton.Config({\"BLOCK_SIZE\": 512}),\n triton.Config({\"BLOCK_SIZE\": 1024}),\n triton.Config({\"BLOCK_SIZE\": 2048}),\n ],\n key=[\"dim\"],\n)\n@triton.jit\ndef _state_passing_bwd_kernel(\n dout_ptr, out_ptr, dA_cs_ptr, dfinal_states_ptr, seq_idx_ptr,\n dstates_ptr, ddA_cs_ptr, dinitstates_ptr, states_converted_ptr,\n dim, nchunks, seqlen, chunk_size,\n stride_dout_batch, stride_dout_chunk, stride_dout_head, stride_dout_dim,\n stride_out_batch, stride_out_chunk, stride_out_head, stride_out_dim,\n stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head,\n stride_dfinal_states_batch, stride_dfinal_states_head, stride_dfinal_states_dim,\n stride_seq_idx_batch, stride_seq_idx_seqlen,\n stride_dstates_batch, stride_dstates_chunk, stride_dstates_head, stride_dstates_dim,\n stride_ddA_cs_batch, stride_ddA_cs_chunk, stride_ddA_cs_head,\n stride_dinitstates_batch, stride_dinitstates_head, stride_dinitstates_dim,\n CONVERT_STATES: tl.constexpr, HAS_DFINAL_STATES: tl.constexpr,\n HAS_DINITSTATES: tl.constexpr, HAS_SEQ_IDX: tl.constexpr, BLOCK_SIZE: tl.constexpr,\n):\n pid_b = tl.program_id(axis=1)\n pid_h = tl.program_id(axis=2)\n pid_m = tl.program_id(axis=0)\n dstates_ptr += pid_b * stride_dstates_batch + pid_h * stride_dstates_head + (nchunks - 1) * stride_dstates_chunk\n dA_cs_ptr += pid_b * stride_dA_cs_batch + pid_h * stride_dA_cs_head + (nchunks - 1) * stride_dA_cs_chunk\n ddA_cs_ptr += (\n pid_b * stride_ddA_cs_batch + pid_h * stride_ddA_cs_head + (nchunks - 1) * stride_ddA_cs_chunk + pid_m\n )\n out_ptr += pid_b * stride_out_batch + pid_h * stride_out_head + (nchunks - 1) * stride_out_chunk\n dout_ptr += pid_b * stride_dout_batch + pid_h * stride_dout_head + (nchunks - 1) * stride_dout_chunk\n if CONVERT_STATES:\n states_converted_ptr += pid_b * stride_out_batch + pid_h * stride_out_head + (nchunks - 1) * stride_out_chunk\n if HAS_DFINAL_STATES:\n dfinal_states_ptr += pid_b * stride_dfinal_states_batch + pid_h * stride_dfinal_states_head\n if HAS_DINITSTATES:\n dinitstates_ptr += pid_b * stride_dinitstates_batch + pid_h * stride_dinitstates_head\n if HAS_SEQ_IDX:\n seq_idx_ptr += pid_b * stride_seq_idx_batch\n\n offs_m = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)\n dstates_ptrs = dstates_ptr + offs_m * stride_dstates_dim\n out_ptrs = out_ptr + offs_m * stride_out_dim\n dout_ptrs = dout_ptr + offs_m * stride_dout_dim\n if CONVERT_STATES:\n states_converted_ptrs = states_converted_ptr + offs_m * stride_out_dim\n\n if HAS_DFINAL_STATES:\n dstates = tl.load(dfinal_states_ptr + offs_m * stride_dfinal_states_dim, mask=offs_m < dim, other=0.0).to(\n tl.float32\n )\n else:\n dstates = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)\n tl.store(dstates_ptrs, dstates, mask=offs_m < dim)\n if HAS_SEQ_IDX:\n seq_idx = tl.load(seq_idx_ptr + (seqlen - 1) * stride_seq_idx_seqlen)\n dstates_ptrs -= stride_dstates_chunk\n for c in range(nchunks - 1):\n dA_cs = tl.load(dA_cs_ptr).to(tl.float32)\n scale = tl.exp(dA_cs)\n if HAS_SEQ_IDX:\n seq_idx_new = tl.load(seq_idx_ptr + (((nchunks - c - 1) * chunk_size - 1) * stride_seq_idx_seqlen))\n scale = tl.where(seq_idx_new == seq_idx, scale, 0.0)\n seq_idx = seq_idx_new\n out = tl.load(out_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)\n if CONVERT_STATES:\n tl.store(states_converted_ptrs, out, mask=offs_m < dim)\n ddA = tl.sum(out * dstates) * scale\n tl.store(ddA_cs_ptr, ddA)\n dout = tl.load(dout_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)\n dstates = scale * dstates + dout\n tl.store(dstates_ptrs, dstates, mask=offs_m < dim)\n dout_ptrs -= stride_dout_chunk\n dstates_ptrs -= stride_dstates_chunk\n dA_cs_ptr -= stride_dA_cs_chunk\n ddA_cs_ptr -= stride_ddA_cs_chunk\n out_ptrs -= stride_out_chunk\n if CONVERT_STATES:\n states_converted_ptrs -= stride_out_chunk\n if CONVERT_STATES:\n out = tl.load(out_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)\n tl.store(states_converted_ptrs, out, mask=offs_m < dim)\n if not HAS_DINITSTATES:\n tl.store(ddA_cs_ptr, 0.0)\n else:\n dA_cs = tl.load(dA_cs_ptr).to(tl.float32)\n scale = tl.exp(dA_cs)\n if HAS_SEQ_IDX:\n scale = tl.where(seq_idx == 0, scale, 0.0)\n out = tl.load(out_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)\n ddA = tl.sum(out * dstates) * scale\n tl.store(ddA_cs_ptr, ddA)\n dout = tl.load(dout_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)\n dstates = scale * dstates + dout\n tl.store(dinitstates_ptr + offs_m * stride_dinitstates_dim, dstates, mask=offs_m < dim)\n\n\ndef _state_passing_fwd(states, dA_chunk_cumsum, initial_states=None, seq_idx=None, chunk_size=None, out_dtype=None):\n batch, nchunks, nheads, dim = states.shape\n assert tuple(dA_chunk_cumsum.shape) == (batch, nheads, nchunks)\n if initial_states is not None:\n assert tuple(initial_states.shape) == (batch, nheads, dim)\n if seq_idx is not None:\n assert chunk_size is not None\n seqlen = seq_idx.shape[-1]\n assert tuple(seq_idx.shape) == (batch, seqlen)\n out_dtype = states.dtype if out_dtype is None else out_dtype\n out = paddle.empty((batch, nchunks, nheads, dim), dtype=out_dtype)\n final_states = paddle.empty((batch, nheads, dim), dtype=paddle.float32)\n grid = lambda META: (triton.cdiv(dim, META[\"BLOCK_SIZE\"]), batch, nheads)\n _state_passing_fwd_kernel[grid](\n states,\n out,\n final_states,\n dA_chunk_cumsum,\n initial_states,\n seq_idx,\n dim,\n nchunks,\n seqlen if seq_idx is not None else 0,\n chunk_size if seq_idx is not None else 0,\n states.strides[0],\n states.strides[1],\n states.strides[2],\n states.strides[3],\n out.strides[0],\n out.strides[1],\n out.strides[2],\n out.strides[3],\n final_states.strides[0],\n final_states.strides[1],\n final_states.strides[2],\n dA_chunk_cumsum.strides[0],\n dA_chunk_cumsum.strides[2],\n dA_chunk_cumsum.strides[1],\n *(\n (initial_states.strides[0], initial_states.strides[1], initial_states.strides[2])\n if initial_states is not None\n else (0, 0, 0)\n ),\n *((seq_idx.strides[0], seq_idx.strides[1]) if seq_idx is not None else (0, 0)),\n HAS_INITSTATES=initial_states is not None,\n HAS_SEQ_IDX=seq_idx is not None,\n )\n return out, final_states\n\n\ndef _state_passing_bwd(\n states,\n dA_chunk_cumsum,\n dout,\n dfinal_states=None,\n seq_idx=None,\n has_initial_states=None,\n dstates_dtype=None,\n states_dtype=None,\n chunk_size=None,\n):\n batch, nchunks, nheads, dim = states.shape\n assert tuple(dA_chunk_cumsum.shape) == (batch, nheads, nchunks)\n assert tuple(dout.shape) == (batch, nchunks, nheads, dim)\n if seq_idx is not None:\n assert chunk_size is not None\n seqlen = seq_idx.shape[-1]\n assert tuple(seq_idx.shape) == (batch, seqlen)\n dstates = paddle.empty_like(dout, dtype=dstates_dtype if dstates_dtype is not None else dout.dtype)\n if states_dtype is not None and states_dtype != states.dtype:\n states_converted = paddle.empty_like(states, dtype=dstates_dtype if dstates_dtype is not None else dout.dtype)\n assert states_converted.strides == states.strides\n else:\n states_converted = None\n if has_initial_states:\n dinitstates = paddle.empty_like(dstates[:, 0])\n else:\n dinitstates = None\n if dfinal_states is not None:\n assert tuple(dfinal_states.shape) == (batch, nheads, dim)\n BLOCK_SIZE_min = 64\n n_blocks = (dim + BLOCK_SIZE_min - 1) // BLOCK_SIZE_min\n ddA_chunk_cumsum = paddle.empty([batch, nheads, nchunks, n_blocks], dtype=paddle.float32)\n grid = lambda META: (triton.cdiv(dim, META[\"BLOCK_SIZE\"]), batch, nheads)\n _state_passing_bwd_kernel[grid](\n dout,\n states,\n dA_chunk_cumsum,\n dfinal_states,\n seq_idx,\n dstates,\n ddA_chunk_cumsum,\n dinitstates,\n states_converted,\n dim,\n nchunks,\n seqlen if seq_idx is not None else 0,\n chunk_size if seq_idx is not None else 0,\n dout.strides[0],\n dout.strides[1],\n dout.strides[2],\n dout.strides[3],\n states.strides[0],\n states.strides[1],\n states.strides[2],\n states.strides[3],\n dA_chunk_cumsum.strides[0],\n dA_chunk_cumsum.strides[2],\n dA_chunk_cumsum.strides[1],\n *(\n (dfinal_states.strides[0], dfinal_states.strides[1], dfinal_states.strides[2])\n if dfinal_states is not None\n else (0, 0, 0)\n ),\n *((seq_idx.strides[0], seq_idx.strides[1]) if seq_idx is not None else (0, 0)),\n dstates.strides[0],\n dstates.strides[1],\n dstates.strides[2],\n dstates.strides[3],\n ddA_chunk_cumsum.strides[0],\n ddA_chunk_cumsum.strides[2],\n ddA_chunk_cumsum.strides[1],\n *(\n (dinitstates.strides[0], dinitstates.strides[1], dinitstates.strides[2])\n if dinitstates is not None\n else (0, 0, 0)\n ),\n CONVERT_STATES=states_converted is not None,\n HAS_DFINAL_STATES=dfinal_states is not None,\n HAS_DINITSTATES=dinitstates is not None,\n HAS_SEQ_IDX=seq_idx is not None,\n )\n BLOCK_SIZE_actual = _state_passing_bwd_kernel.best_config.kwargs[\"BLOCK_SIZE\"]\n n_valid_blocks = (dim + BLOCK_SIZE_actual - 1) // BLOCK_SIZE_actual\n ddA_chunk_cumsum = ddA_chunk_cumsum[..., :n_valid_blocks].sum(axis=-1).cast(dtype=dA_chunk_cumsum.dtype)\n if states_dtype is not None and states_dtype == states.dtype:\n states_converted = states\n return (\n (dstates, ddA_chunk_cumsum, dinitstates)\n if states_dtype is None\n else (dstates, ddA_chunk_cumsum, dinitstates, states_converted)\n )\n", - "description_1": "Use triton language to implement forward and backward kernels for state passing. The forward kernel (_state_passing_fwd_kernel) has 28 parameters, handling pointers, dimensions, strides, and meta-parameters for block-wise state passing computation. The backward kernel (_state_passing_bwd_kernel) involves 37 parameters, managing gradients and optional conversions with similar types of inputs and meta-parameters. The functions _state_passing_fwd and _state_passing_bwd encapsulate kernel calls, orchestrating grid setups and invoking respective triton kernels with necessary stride and sequence index calculations for matrix manipulations.", - "description_2": "Use triton language to create kernels for state passing computations, both forward and backward, utilizing block sizes and matrix pointers. Implement functions that prepare and invoke these kernels, managing stride and dimension details for data processing.", - "difficulty": 4 - }, - { - "code": "import triton\nimport triton.language as tl\nimport paddle\nfrom ..utils import calculate_settings\n\n@triton.jit\ndef silu(x):\n return x * tl.sigmoid(x)\n\n@triton.jit\ndef _swiglu_forward_kernel(a_ptr, b_ptr, c_ptr, stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr):\n program_id = tl.program_id(0)\n a_ptr += program_id * stride\n b_ptr += program_id * stride\n c_ptr += program_id * stride\n\n col_offsets = tl.arange(0, BLOCK_SIZE)\n mask = col_offsets < n_cols\n\n a_row = tl.load(a_ptr + col_offsets, mask=mask, other=0).to(tl.float32)\n b_row = tl.load(b_ptr + col_offsets, mask=mask, other=0)\n c_row = silu(a_row) * b_row\n tl.store(c_ptr + col_offsets, c_row, mask=mask)\n\n@triton.jit\ndef _swiglu_backward_kernel(dc_ptr, a_ptr, b_ptr, stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr):\n program_id = tl.program_id(0)\n dc_ptr += program_id * stride\n a_ptr += program_id * stride\n b_ptr += program_id * stride\n\n col_offsets = tl.arange(0, BLOCK_SIZE)\n mask = col_offsets < n_cols\n\n dc_row = tl.load(dc_ptr + col_offsets, mask=mask, other=0)\n a_row = tl.load(a_ptr + col_offsets, mask=mask, other=0).to(tl.float32)\n b_row = tl.load(b_ptr + col_offsets, mask=mask, other=0)\n\n sig_a = tl.sigmoid(a_row)\n silu_a = a_row * sig_a\n db_row = dc_row * silu_a\n da_row = dc_row * (silu_a * (1 - sig_a) + sig_a) * b_row\n\n tl.store(a_ptr + col_offsets, da_row, mask=mask)\n tl.store(b_ptr + col_offsets, db_row, mask=mask)\n\ndef swiglu_forward(a, b):\n ori_shape = a.shape\n n_cols = ori_shape[-1]\n a = a.reshape([-1, n_cols])\n b = b.reshape([-1, n_cols])\n c = paddle.empty_like(a)\n n_rows = a.shape[0]\n\n BLOCK_SIZE, num_warps = calculate_settings(n_cols)\n\n _swiglu_forward_kernel[(n_rows,)](\n a,\n b,\n c,\n c.strides[-2],\n n_cols=n_cols,\n BLOCK_SIZE=BLOCK_SIZE,\n num_warps=num_warps,\n )\n return a, b, c.reshape(ori_shape)\n\ndef swiglu_backward(a, b, dc):\n ori_shape = dc.shape\n n_cols = ori_shape[-1]\n dc = dc.reshape([-1, n_cols])\n n_rows = dc.shape[0]\n\n BLOCK_SIZE, num_warps = calculate_settings(n_cols)\n\n _swiglu_backward_kernel[(n_rows,)](\n dc,\n a,\n b,\n dc.strides[-2],\n n_cols=n_cols,\n BLOCK_SIZE=BLOCK_SIZE,\n num_warps=num_warps,\n )\n return a.reshape(ori_shape), b.reshape(ori_shape)\n", - "description_1": "Use triton language to implement a SiLU activation function and a SWIGLU operation with forward and backward kernels. The forward kernel (_swiglu_forward_kernel) takes pointers to input tensors a and b, an output tensor c, a stride, the number of columns, and a block size. It computes the element-wise product of the SiLU activation of a and b, storing the result in c. The backward kernel (_swiglu_backward_kernel) takes pointers to the gradient tensor dc, input tensors a and b, a stride, the number of columns, and a block size. It computes the gradients with respect to a and b using recomputation to save memory, storing the results back in a and b.", - "description_2": "Use triton language to create a SWIGLU operation with forward and backward kernels, where the forward kernel computes the element-wise product of the SiLU activation of two input tensors, and the backward kernel computes the gradients with respect to the inputs.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\nimport numpy as np\n\n@triton.jit\ndef garbage_pad_ragged_acts_kernel(\n ragged_acts_ptr,\n ragged_acts_offset_per_seq_ptr,\n n_ctx_per_seq_ptr,\n padded_acts_ptr,\n BLOCK_SIZE: tl.constexpr,\n n_ctx_max: tl.constexpr,\n):\n seq_idx = tl.program_id(axis=0)\n ctx_idx = tl.program_id(axis=1)\n\n ragged_acts_offset_ptr = ragged_acts_offset_per_seq_ptr + seq_idx\n ragged_acts_offset = tl.load(ragged_acts_offset_ptr)\n\n n_ctx_in_this_seq_ptr = n_ctx_per_seq_ptr + seq_idx\n n_ctx_in_this_seq = tl.load(n_ctx_in_this_seq_ptr)\n ctx_idx_too_large_mask = ctx_idx < n_ctx_in_this_seq\n\n ragged_acts_offsets = ragged_acts_offset + tl.arange(0, BLOCK_SIZE)\n\n acts = tl.load(ragged_acts_ptr + ragged_acts_offsets, mask=ctx_idx_too_large_mask)\n\n padded_acts_offset = n_ctx_max * seq_idx * BLOCK_SIZE\n\n tl.store(padded_acts_ptr + padded_acts_offset, acts, mask=ctx_idx_too_large_mask)\n\n\nclass RaggedActivations:\n def __init__(self, raw_tensor: torch.Tensor, n_ctx_per_seq: list):\n self.raw_tensor = raw_tensor\n self.n_ctx_per_seq = n_ctx_per_seq\n\n def triton_to_garbage_padded(self) -> torch.Tensor:\n n_seqs = len(self.n_ctx_per_seq)\n n_ctx_max = max(self.n_ctx_per_seq)\n\n ragged_acts = self.raw_tensor\n d_model = ragged_acts.shape[-1]\n padded_acts = torch.empty(\n n_seqs, n_ctx_max, d_model, dtype=ragged_acts.dtype, device=\"cuda\"\n )\n\n assert d_model >= 128, f\"bad {d_model=}\"\n assert d_model <= 8 * 1024, f\"bad {d_model=}\"\n assert d_model % 32 == 0, f\"bad {d_model=}\"\n\n n_ctx_per_seq = self.n_ctx_per_seq\n ragged_acts_offset_per_seq = get_acts_offset_per_seq(n_ctx_per_seq)\n\n grid_2d = (n_seqs, n_ctx_max)\n\n garbage_pad_ragged_acts_kernel[grid_2d](\n ragged_acts,\n torch.tensor(ragged_acts_offset_per_seq, device=\"cuda\"),\n torch.tensor(self.n_ctx_per_seq, device=\"cuda\"),\n padded_acts,\n BLOCK_SIZE=d_model,\n n_ctx_max=n_ctx_max,\n )\n return padded_acts\n\n\ndef get_acts_offset_per_seq(n_ctx_per_seq):\n n_ctx_per_seq_shifted = np.array([0] + n_ctx_per_seq[:-1])\n ragged_acts_offset_per_seq = n_ctx_per_seq_shifted.cumsum(axis=0)\n return ragged_acts_offset_per_seq\n", - "description_1": "Use triton language to implement a kernel that pads ragged sequences with garbage data. The kernel 'garbage_pad_ragged_acts_kernel' takes 6 parameters: ragged_acts_ptr (pointer to the ragged activations), ragged_acts_offset_per_seq_ptr (pointer to offsets for each sequence), n_ctx_per_seq_ptr (pointer to the number of contexts per sequence), padded_acts_ptr (pointer to the output padded activations), BLOCK_SIZE (constant expression for block size), and n_ctx_max (constant expression for maximum context length). The kernel processes each sequence and context index, loads the ragged activations, and stores them into the padded activations tensor, applying a mask to handle out-of-bounds accesses. The 'RaggedActivations' class provides a method 'triton_to_garbage_padded' to invoke this kernel, which prepares the necessary data and launch grid, and returns the padded activations.", - "description_2": "Use triton language to create a kernel that pads sequences with garbage data, handling out-of-bounds accesses with a mask, and provide a class method to invoke this kernel and return the padded result.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef _kernel(\n A, B, C, M, N, K,\n stride_am, stride_ak,\n stride_bk, stride_bn,\n stride_cm, stride_cn,\n BLOCK_M: tl.constexpr,\n BLOCK_N: tl.constexpr,\n BLOCK_K: tl.constexpr,\n):\n # matrix multiplication\n pid = tl.program_id(0)\n\n # Determine the number of blocks in the grid\n grid_n = (N + BLOCK_N - 1) // BLOCK_N\n\n pid_m = pid // grid_n\n pid_n = pid % grid_n\n\n # do matrix multiplication\n rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)\n rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)\n ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)\n rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)\n rk = tl.arange(0, BLOCK_K)\n\n # pointers\n A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)\n B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)\n acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)\n for k in range(K, 0, -BLOCK_K):\n\n a = tl.load(A, mask=rk[None, :] < k, other=0.0)\n b = tl.load(B, mask=rk[:, None] < k, other=0.0)\n\n acc += tl.dot(a, b)\n A += BLOCK_K * stride_ak\n B += BLOCK_K * stride_bk\n acc = acc.to(C.dtype.element_ty)\n\n # rematerialize rm and rn to save registers\n rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)\n rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)\n C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)\n mask = (rm < M)[:, None] & (rn < N)[None, :]\n\n tl.store(C, acc, mask=mask)\n\n\ndef matmul(a, b):\n device = a.device\n # handle non-contiguous inputs if necessary\n if a.stride(0) > 1 and a.stride(1) > 1:\n a = a.contiguous()\n if b.stride(0) > 1 and b.stride(1) > 1:\n b = b.contiguous()\n\n # checks constraints\n assert a.shape[1] == b.shape[0], f\"incompatible dimensions, {a.shape=} {b.shape=}\"\n\n M, K = a.shape\n _, N = b.shape\n\n # allocates output\n c = torch.empty((M, N), device=device, dtype=a.dtype)\n\n # launch kernel\n def grid(META):\n return (triton.cdiv(M, META[\"BLOCK_M\"]) * triton.cdiv(N, META[\"BLOCK_N\"]),)\n\n _kernel[grid](\n a, b, c,\n M, N, K,\n a.stride(0), a.stride(1),\n b.stride(0), b.stride(1),\n c.stride(0), c.stride(1),\n )\n return c\n", - "description_1": "Use triton language to implement a matrix multiplication kernel. The kernel function '_kernel' takes 14 parameters: three matrices A, B, C, three integers M, N, K representing the dimensions of the matrices, six stride values for the matrices, and three block size constants BLOCK_M, BLOCK_N, BLOCK_K. The function performs matrix multiplication using a block-wise approach and stores the result in matrix C. The 'matmul' function is a wrapper that prepares the input matrices, checks their dimensions, allocates the output matrix, and launches the kernel with appropriate grid dimensions.", - "description_2": "Use triton language to create a matrix multiplication operation with block-wise computation, handling input matrices' strides and dimensions, and storing the result in an output matrix.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\nfrom triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time\n\n# Triton kernel for matrix multiplication\n@triton.autotune(\n configs=get_fast_dev_configs(),\n key=[\"n_ctx_q\", \"n_ctx_k\", \"d_model\"],\n prune_configs_by={\n \"early_config_prune\": early_config_prune,\n \"perf_model\": estimate_matmul_time,\n \"top_k\": 10,\n },\n)\n@triton.jit\ndef _kernel(\n q_ptr, k_ptr, scores_ptr,\n n_ctx_q,\n n_ctx_k, # N\n d_model,\n stride_ctx_q, stride_ctx_k,\n stride_d, # Stride along the d_model_per_head dim\n stride_out_q, stride_out_k,\n BLOCK_Q: tl.constexpr,\n BLOCK_K: tl.constexpr,\n BLOCK_D: tl.constexpr,\n):\n pid = tl.program_id(0)\n grid_k = (n_ctx_k + BLOCK_K - 1) // BLOCK_K\n pid_q = pid // grid_k\n pid_k = pid % grid_k\n\n rq = pid_q * BLOCK_Q + tl.arange(0, BLOCK_Q)\n rq = tl.max_contiguous(tl.multiple_of(rq % n_ctx_q, BLOCK_Q), BLOCK_Q)\n\n rk = pid_k * BLOCK_K + tl.arange(0, BLOCK_K)\n rk = tl.max_contiguous(tl.multiple_of(rk % n_ctx_k, BLOCK_K), BLOCK_K)\n\n acc_tile = tl.zeros((BLOCK_Q, BLOCK_K), dtype=tl.float32)\n rd = tl.arange(0, BLOCK_D)\n\n q_ptr_tile = q_ptr + (rq[:, None] * stride_ctx_q + rd[None, :] * stride_d)\n k_ptr_tile = k_ptr + (rd[:, None] * stride_d + rk[None, :] * stride_ctx_k)\n\n for d_max_offset in range(d_model, 0, -BLOCK_D):\n q_tile = tl.load(q_ptr_tile, mask=rd[None, :] < d_max_offset, other=0.0)\n k_tile = tl.load(k_ptr_tile, mask=rd[:, None] < d_max_offset, other=0.0)\n acc_tile += tl.dot(q_tile, k_tile)\n q_ptr_tile += BLOCK_D * stride_d\n k_ptr_tile += BLOCK_D * stride_d\n\n acc_tile = acc_tile.to(scores_ptr.dtype.element_ty)\n\n rq = pid_q * BLOCK_Q + tl.arange(0, BLOCK_Q)\n rk = pid_k * BLOCK_K + tl.arange(0, BLOCK_K)\n\n scores_offset_tile = rq[:, None] * stride_out_q + rk[None, :] * stride_out_k\n scores_ptr_tile = scores_ptr + scores_offset_tile\n\n mask = (rq < n_ctx_q)[:, None] & (rk < n_ctx_k)[None, :]\n tl.store(scores_ptr_tile, acc_tile, mask=mask)\n\n# Function to call the Triton kernel\ndef qk_dotprod(query, key):\n device = query.device\n\n if query.stride(0) > 1 and query.stride(1) > 1:\n query = query.contiguous()\n if key.stride(0) > 1 and key.stride(1) > 1:\n key = key.contiguous()\n\n n_ctx_q, d_model = query.shape\n n_ctx_k, d_model_k = key.shape\n assert d_model == d_model_k, f\"{query.shape=} {key.shape=}\"\n\n scores_out = torch.empty((n_ctx_q, n_ctx_k), device=device, dtype=query.dtype)\n stride_d = query.stride(1)\n assert stride_d == key.stride(1), f\"{stride_d=}, {key.stride(1)=}\"\n\n def grid(META):\n return (\n triton.cdiv(n_ctx_q, META[\"BLOCK_Q\"])\n * triton.cdiv(n_ctx_k, META[\"BLOCK_K\"]),\n )\n\n _kernel[grid](\n query,\n key,\n scores_out,\n n_ctx_q,\n n_ctx_k,\n d_model,\n query.stride(0),\n key.stride(0),\n stride_d,\n scores_out.stride(0),\n scores_out.stride(1),\n )\n return scores_out\n", - "description_1": "Use triton language to perform a matrix multiplication kernel with 14 parameters. The kernel takes pointers to query and key matrices, an output scores pointer, context dimensions, model dimension, and strides for contexts and model dimension. It uses block sizes defined by BLOCK_Q, BLOCK_K, and BLOCK_D to process input matrices and calculate their dot product. The function qk_dotprod prepares these matrices and calls the kernel on them.", - "description_2": "Use triton language to create a kernel function that performs block-wise matrix multiplication between two input matrices and outputs the result, and provide a Python function that prepares the inputs and invokes this kernel.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n# Triton kernel for computing the dot product of query and key matrices\n@triton.jit\ndef _qk_dotprod_kernel(\n q_ptr, k_ptr, scores_ptr,\n pid_to_in_q_token_offset_ptr, pid_to_in_k_token_offset_ptr,\n pid_to_out_q_block_ptr, pid_to_out_k_block_ptr, pid_to_out_seq_idx_ptr,\n max_n_ctx_q_across_seqs, max_n_ctx_k_across_seqs, d_head,\n stride_ctx_q, stride_ctx_k, stride_out_q, stride_out_k, stride_out_seq,\n total_ctx_q_across_all_seqs, total_ctx_k_across_all_seqs,\n BLOCK_Q: tl.constexpr, BLOCK_K: tl.constexpr, BLOCK_D: tl.constexpr,\n):\n pid = tl.program_id(0)\n\n out_q_block = tl.load(pid_to_out_q_block_ptr + pid)\n out_k_block = tl.load(pid_to_out_k_block_ptr + pid)\n out_seq_idx = tl.load(pid_to_out_seq_idx_ptr + pid)\n in_q_token_offset = tl.load(pid_to_in_q_token_offset_ptr + pid)\n in_k_token_offset = tl.load(pid_to_in_k_token_offset_ptr + pid)\n\n rq = in_q_token_offset + tl.arange(0, BLOCK_Q)\n rk = in_k_token_offset + tl.arange(0, BLOCK_K)\n\n q_ctx_in_bounds = rq < total_ctx_q_across_all_seqs\n k_ctx_in_bounds = rk < total_ctx_k_across_all_seqs\n\n acc_tile = tl.zeros((BLOCK_Q, BLOCK_K), dtype=tl.float32)\n\n rd = tl.arange(0, BLOCK_D)\n\n q_ptr_tile = q_ptr + (rq[:, None] * stride_ctx_q + rd[None, :])\n k_ptr_tile = k_ptr + (rd[:, None] + rk[None, :] * stride_ctx_k)\n\n for d_max_offset in range(d_head, 0, -BLOCK_D):\n q_tile = tl.load(\n q_ptr_tile,\n mask=(rd[None, :] < d_max_offset) & q_ctx_in_bounds[:, None],\n other=0.0,\n )\n k_tile = tl.load(\n k_ptr_tile,\n mask=(rd[:, None] < d_max_offset) & k_ctx_in_bounds[None, :],\n other=0.0,\n )\n\n acc_tile += tl.dot(q_tile, k_tile)\n\n q_ptr_tile += BLOCK_D\n k_ptr_tile += BLOCK_D\n\n rq_out = out_q_block * BLOCK_Q + tl.arange(0, BLOCK_Q)\n rk_out = out_k_block * BLOCK_K + tl.arange(0, BLOCK_K)\n\n scores_offset_tile = (\n rq_out[:, None] * stride_out_q\n + rk_out[None, :] * stride_out_k\n + out_seq_idx * stride_out_seq\n )\n scores_ptr_tile = scores_ptr + scores_offset_tile\n\n mask = (rq_out < max_n_ctx_q_across_seqs)[:, None] & (\n rk_out < max_n_ctx_k_across_seqs\n )[None, :]\n\n acc_tile = acc_tile.to(scores_ptr.dtype.element_ty)\n tl.store(scores_ptr_tile, acc_tile, mask=mask)\n\n\ndef ragged_single_seq_qk_dotprod(\n query: torch.Tensor, key: torch.Tensor, lut\n) -> torch.Tensor:\n assert query.ndim == 2 and key.ndim == 2\n device = query.device\n\n if query.stride(0) > 1 and query.stride(1) > 1:\n query = query.contiguous()\n if key.stride(0) > 1 and key.stride(1) > 1:\n key = key.contiguous()\n\n n_ctx_q, d_head = query.shape\n n_ctx_k, d_head_k = key.shape\n assert d_head == d_head_k, f\"{query.shape=} {key.shape=}\"\n\n scores_out = torch.empty((1, n_ctx_q, n_ctx_k), device=device, dtype=query.dtype)\n\n assert query.stride(1) == 1, f\"{query.stride(1)}\"\n assert key.stride(1) == 1, f\"{key.stride(1)}\"\n\n grid = (lut.n_pids_total,)\n _qk_dotprod_kernel[grid](\n q_ptr=query,\n k_ptr=key,\n scores_ptr=scores_out,\n pid_to_in_q_token_offset_ptr=lut.pid_to_in_q_token_offset,\n pid_to_in_k_token_offset_ptr=lut.pid_to_in_k_token_offset,\n pid_to_out_q_block_ptr=lut.pid_to_out_q_block,\n pid_to_out_k_block_ptr=lut.pid_to_out_k_block,\n pid_to_out_seq_idx_ptr=lut.pid_to_out_seq_idx,\n max_n_ctx_q_across_seqs=n_ctx_q,\n max_n_ctx_k_across_seqs=n_ctx_k,\n d_head=d_head,\n stride_ctx_q=query.stride(0),\n stride_ctx_k=key.stride(0),\n stride_out_seq=scores_out.stride(0),\n stride_out_q=scores_out.stride(1),\n stride_out_k=scores_out.stride(2),\n total_ctx_q_across_all_seqs=n_ctx_q,\n total_ctx_k_across_all_seqs=n_ctx_k,\n )\n return scores_out.reshape((n_ctx_q, n_ctx_k))\n\n\ndef ragged_qk_dotprod(\n query, key, lut\n) -> torch.Tensor:\n device = query.device\n\n assert query.raw_tensor.is_contiguous()\n assert key.raw_tensor.is_contiguous()\n\n total_ctx_q_across_all_seqs, d_head = query.raw_tensor.shape\n total_ctx_k_across_all_seqs, d_head_k = key.raw_tensor.shape\n assert d_head == d_head_k, f\"{query.raw_tensor.shape=} {key.raw_tensor.shape=}\"\n\n assert query.n_seqs == key.n_seqs\n\n scores_out = torch.ones(\n (query.n_seqs, query.max_n_ctx_per_seq, key.max_n_ctx_per_seq),\n device=device,\n dtype=query.dtype,\n )\n\n assert query.raw_tensor.stride(1) == 1, f\"{query.raw_tensor.stride(1)}\"\n assert key.raw_tensor.stride(1) == 1, f\"{key.raw_tensor.stride(1)}\"\n\n grid = (lut.n_pids_total,)\n _qk_dotprod_kernel[grid](\n q_ptr=query.raw_tensor,\n k_ptr=key.raw_tensor,\n scores_ptr=scores_out,\n pid_to_in_q_token_offset_ptr=lut.pid_to_in_q_token_offset,\n pid_to_in_k_token_offset_ptr=lut.pid_to_in_k_token_offset,\n pid_to_out_q_block_ptr=lut.pid_to_out_q_block,\n pid_to_out_k_block_ptr=lut.pid_to_out_k_block,\n pid_to_out_seq_idx_ptr=lut.pid_to_out_seq_idx,\n max_n_ctx_q_across_seqs=query.max_n_ctx_per_seq,\n max_n_ctx_k_across_seqs=key.max_n_ctx_per_seq,\n d_head=d_head,\n stride_ctx_q=query.raw_tensor.stride(0),\n stride_ctx_k=key.raw_tensor.stride(0),\n stride_out_seq=scores_out.stride(0),\n stride_out_q=scores_out.stride(1),\n stride_out_k=scores_out.stride(2),\n total_ctx_q_across_all_seqs=total_ctx_q_across_all_seqs,\n total_ctx_k_across_all_seqs=total_ctx_k_across_all_seqs,\n )\n return scores_out\n", - "description_1": "Use triton language to implement a kernel function '_qk_dotprod_kernel' that performs a matrix multiplication of query and key matrices with specific block sizes and accumulates the results. The kernel takes 20 parameters: pointers to query, key, and scores tensors, pointers to lookup tables for token offsets and block indices, integers for context sizes and strides, and block sizes as constexpr. The function 'ragged_single_seq_qk_dotprod' calls this kernel for a single sequence, taking 3 parameters: query tensor, key tensor, and a lookup table. The function 'ragged_qk_dotprod' calls this kernel for multiple sequences, taking 3 parameters: query activations, key activations, and a lookup table.", - "description_2": "Use triton language to create a kernel for batched matrix multiplication with custom block sizes and offsets, and implement functions to call this kernel for single and multiple sequence scenarios.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef k_mean(X, Mean, Var, stride, N, BLOCK_SIZE_N: tl.constexpr):\n \"\"\"\n Fused layernorm kernel over a 3d tensor.\n The layer norm is applied over the last dimension.\n Compute\n y = (x - E(x))/(sqrt(var(x) + epsilon)) * gamma + beta\n \"\"\"\n row = tl.program_id(0)\n cols = tl.arange(0, BLOCK_SIZE_N)\n\n # Move to this row\n x_ptrs = X + row * stride + cols\n x = tl.load(x_ptrs, mask=cols < N, other=0.0).to(tl.float32)\n x = tl.where(cols < N, x, 0.0)\n\n # Compute variance\n x_mean = tl.sum(x, axis=0) / N\n x_zm = x - x_mean\n x_zm = tl.where(cols < N, x_zm, 0.0)\n x_var = tl.sum(x_zm * x_zm, axis=0) / N\n tl.store(Mean + row, x_mean)\n tl.store(Var + row, x_var)\n\ndef stats(x: torch.Tensor):\n # reshape input data into 2D tensor\n x_arg = x.reshape(-1, x.shape[-1])\n M, N = x_arg.shape\n\n MAX_FUSED_SIZE = 65536 // x.element_size()\n BLOCK_SIZE_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))\n if N > BLOCK_SIZE_N:\n raise RuntimeError(\"This layer norm doesn't support feature dim >= 64KB.\")\n\n # heuristics for number of warps.\n num_warps = min(max(BLOCK_SIZE_N // 256, 1), 8)\n\n mean = torch.zeros((M,)).cuda()\n var = torch.zeros((M,)).cuda()\n\n # enqueue kernel\n k_mean[(M,)](\n x_arg, mean, var,\n x_arg.stride(0),\n N,\n num_warps=num_warps,\n BLOCK_SIZE_N=BLOCK_SIZE_N\n )\n\n return mean.reshape(x.shape[:-1]), var.reshape(x.shape[:-1])\n", - "description_1": "Use triton language to create a fused layernorm kernel named k_mean that operates on a 3D tensor with parameters: X (input tensor), Mean (output mean), Var (output variance), stride (memory stride of X), N (size of the last dimension of X), and BLOCK_SIZE_N (block size for loading data). This kernel computes the mean and variance across the last dimension of X. The kernel is invoked by a function named stats which reshapes the input tensor to 2D, calculates parameters like block size and number of warps, and calls the triton kernel.", - "description_2": "Use triton language to create a kernel that computes the mean and variance over the last dimension of a 3D tensor. Invoke this kernel using a wrapper function that prepares input dimensions and manages GPU resources.", - "difficulty": 3 - }, - { - "code": "import triton\nimport triton.language as tl\nfrom xformers.components import Activation\n\n_kAlpha = math.sqrt(2.0 / math.pi)\n\n# A Triton implementation of the most used activations\n\n@triton.jit\ndef tanh(x):\n # Tanh is just a scaled sigmoid\n return 2 * tl.sigmoid(2 * x) - 1\n\n@triton.jit\ndef relu(x):\n \"\"\"\n ReLU_ activation function\n\n .. _ReLU: https://pytorch.org/docs/stable/generated/torch.nn.ReLU.html\n \"\"\"\n zero = 0.0\n return tl.where(x >= 0, x, zero.to(x.dtype))\n\n@triton.jit\ndef relu_grad(x):\n # ReLU is different from other activations\n # in that it does not require the input to retrospectively compute its gradient\n zero = 0.0\n one = 1.0\n return tl.where(x >= 0, one.to(x.dtype), zero.to(x.dtype))\n\n@triton.jit\ndef squared_relu(x):\n \"\"\"\n Squared ReLU activation, as proposed in the Primer_ paper.\n\n .. _Primer: https://arxiv.org/abs/2109.08668\n \"\"\"\n x_ = relu(x)\n return (x_ * x_).to(x.dtype)\n\n@triton.jit\ndef squared_relu_grad(x):\n return tl.where(x >= 0, 2.0 * x, 0.0)\n\n@triton.jit\ndef leaky_relu(x):\n \"\"\"\n LeakyReLU_ activation\n\n .. _LeakyReLU: https://pytorch.org/docs/stable/generated/torch.nn.LeakyReLU.html\n \"\"\"\n scale = 0.01 + 0.0\n scale = scale.to(x.dtype)\n return tl.where(x >= 0, x, scale * x)\n\n@triton.jit\ndef leaky_relu_grad(x):\n min_grad = 0.01\n max_grad = 1\n\n min_grad = min_grad.to(x.dtype)\n max_grad = max_grad.to(x.dtype)\n\n return tl.where(x >= 0, max_grad, min_grad)\n\n@triton.jit\ndef gelu(x):\n \"\"\"\n GeLU_ activation - Gaussian error linear unit\n\n .. _GeLU: https://arxiv.org/pdf/1606.08415.pdf\n \"\"\"\n return 0.5 * x * (1 + tanh(_kAlpha * (x + 0.044715 * x * x * x)))\n\n@triton.jit\ndef gelu_grad(x):\n # CREDITS: Fast implementation proposed in\n # https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/fused_bias_gelu.py#L30\n tanh_out = tanh(0.79788456 * x * (1 + 0.044715 * x * x))\n return 0.5 * x * (\n (1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)\n ) + 0.5 * (1 + tanh_out)\n\n@triton.jit\ndef smelu(x):\n \"\"\"\n SmeLU_ activation - Smooth ReLU with beta=2.0\n\n .. _SmeLU: https://arxiv.org/pdf/2202.06499.pdf\n \"\"\"\n zero = 0.0\n four = 4.0\n two = 2.0\n beta = two.to(x.dtype)\n\n output = (x + beta) * (x + beta) / (four.to(x.dtype) * beta)\n relu = tl.where(x >= beta, x, zero.to(x.dtype))\n return tl.where(tl.abs(x) <= beta, output, relu)\n\n@triton.jit\ndef smelu_grad(x):\n zero = 0.0\n one = 1.0\n two = 2.0\n beta = two.to(x.dtype)\n\n grad = (beta + x) / (two.to(x.dtype) * beta)\n relu_grad = tl.where(x >= beta, one.to(x.dtype), zero.to(x.dtype))\n return tl.where(tl.abs(x) <= beta, grad, relu_grad)\n", - "description_1": "Use triton language to implement several activation functions: tanh, relu, relu_grad, squared_relu, squared_relu_grad, leaky_relu, leaky_relu_grad, gelu, gelu_grad, smelu, and smelu_grad. Each of these kernels takes a single argument x, which is the input tensor, and returns the transformed output tensor. The relu and relu_grad functions handle the ReLU activation and its gradient respectively. Similarly, squared_relu and squared_relu_grad handle the Squared ReLU activation. Leaky ReLU and its gradient are implemented in leaky_relu and leaky_relu_grad. The gelu function implements the Gaussian Error Linear Unit activation, with gelu_grad providing its gradient. Lastly, smelu and smelu_grad implement the Smooth ReLU activation and its gradient.", - "description_2": "Use triton language to create kernels for common activation functions including ReLU, Leaky ReLU, GeLU, and SmeLU, along with their gradients.", - "difficulty": 2 - }, - { - "code": "import triton\nimport triton.language as tl\n\n_configs = [\n triton.Config({}, num_warps=1),\n triton.Config({}, num_warps=2),\n triton.Config({}, num_warps=4),\n triton.Config({}, num_warps=8),\n triton.Config({}, num_warps=16),\n]\n\n\n@triton.jit\ndef _get_4_bin_masks(seed_ptr, rand_offsets, p):\n seed = tl.load(seed_ptr)\n rand1, rand2, rand3, rand4 = tl.randint4x(seed, rand_offsets)\n\n threshold = (4294967296.0 * p).to(tl.int32)\n rand_mask1 = rand1 > threshold\n rand_mask2 = rand2 > threshold\n rand_mask3 = rand3 > threshold\n rand_mask4 = rand4 > threshold\n\n return rand_mask1, rand_mask2, rand_mask3, rand_mask4\n\n\n@triton.jit\ndef _random_prune_and_scale(x, rand_mask, p, p_scale):\n zero = 0.0\n keep = tl.reshape(rand_mask, x.shape)\n x = tl.where(keep, (x * p_scale).to(x.dtype), zero.to(x.dtype))\n return x\n\n\n@triton.jit\ndef tile_random_drop(\n x_ptrs,\n y_ptrs,\n block_mask,\n use_bias,\n bias,\n rand_mask,\n p,\n p_scale,\n ACTIVATION,\n):\n x = tl.load(x_ptrs, mask=block_mask, other=0.0)\n\n if use_bias:\n x += bias\n\n if ACTIVATION:\n x = ACTIVATION(x)\n\n output = _random_prune_and_scale(x, rand_mask, p, p_scale)\n\n tl.store(y_ptrs, output, mask=block_mask)\n\n\n@triton.heuristics({\"SIZE_RAND_BLOCK\": lambda args: args[\"BLOCK_N\"] * args[\"BLOCK_M\"]})\n@triton.autotune(\n configs=_configs,\n key=[\"M\", \"N\", \"is_fp16\"],\n)\n@triton.jit\ndef k_dropout_fw(\n Y, X, BIAS, SEEDS,\n stride,\n M, N,\n p,\n is_fp16,\n BLOCK_M: tl.constexpr,\n BLOCK_N: tl.constexpr,\n SIZE_RAND_BLOCK: tl.constexpr,\n USE_BIAS: tl.constexpr,\n ACTIVATION: tl.constexpr,\n):\n row_id = tl.program_id(axis=0)\n rows = row_id * BLOCK_M * 4 + tl.arange(0, BLOCK_M)\n\n col_id = tl.program_id(axis=1)\n cols = col_id * BLOCK_N + tl.arange(0, BLOCK_N)\n seed = SEEDS + col_id\n\n x_ptrs = X + rows[:, None] * stride + cols[None, :]\n y_ptrs = Y + rows[:, None] * stride + cols[None, :]\n\n rand_offsets = tl.arange(0, SIZE_RAND_BLOCK) + row_id * BLOCK_M * 4\n rand_mask1, rand_mask2, rand_mask3, rand_mask4 = _get_4_bin_masks(seed, rand_offsets, p)\n\n col_mask = cols[None, :] < N\n p_scale = 1 / (1 - p)\n\n if USE_BIAS:\n b_ptrs = BIAS + cols[None, :]\n bias = tl.load(b_ptrs, mask=cols[None, :] < N, other=0.)\n else:\n bias = x_ptrs\n\n for i in range(4):\n if i == 0:\n rand_mask = rand_mask1\n elif i == 1:\n rand_mask = rand_mask2\n elif i == 2:\n rand_mask = rand_mask3\n else:\n rand_mask = rand_mask4\n\n block_mask = (rows[:, None] < M) & col_mask\n tile_random_drop(x_ptrs, y_ptrs, block_mask, USE_BIAS, bias, rand_mask, p, p_scale, ACTIVATION)\n\n rows += BLOCK_M\n x_ptrs += BLOCK_M * stride\n y_ptrs += BLOCK_M * stride\n\n\n@triton.heuristics({\"SIZE_RAND_BLOCK\": lambda args: args[\"BLOCK_N\"] * args[\"BLOCK_M\"]})\n@triton.autotune(\n configs=_configs,\n key=[\"M\", \"N\", \"is_fp16\"],\n)\n@triton.jit\ndef k_dropout_bw(\n GRAD_IN, GRAD_BIAS, GRAD_OUT,\n INPUTS, BIAS, SEEDS,\n stride_grad, stride_inputs,\n M, N,\n p,\n is_fp16,\n BLOCK_M: tl.constexpr,\n BLOCK_N: tl.constexpr,\n SIZE_RAND_BLOCK: tl.constexpr,\n TRAINABLE_BIAS: tl.constexpr,\n USE_BIAS: tl.constexpr,\n ACTIVATION_GRAD: tl.constexpr,\n):\n row_id = tl.program_id(axis=0)\n rows = row_id * BLOCK_M * 4 + tl.arange(0, BLOCK_M)\n\n col_id = tl.program_id(axis=1)\n cols = col_id * BLOCK_N + tl.arange(0, BLOCK_N)\n seed = SEEDS + col_id\n\n grad_out_ptrs = GRAD_OUT + rows[:, None] * stride_grad + cols[None, :]\n grad_in_ptrs = GRAD_IN + rows[:, None] * stride_grad + cols[None, :]\n input_ptrs = INPUTS + rows[:, None] * stride_inputs + cols[None, :]\n\n rand_offsets = tl.arange(0, SIZE_RAND_BLOCK) + row_id * BLOCK_M * 4\n rand_mask1, rand_mask2, rand_mask3, rand_mask4 = _get_4_bin_masks(seed, rand_offsets, p)\n\n grad_bias = tl.zeros((BLOCK_N,), dtype=tl.float32)\n col_mask = cols[None, :] < N\n p_scale = 1 / (1 - p)\n\n if USE_BIAS:\n b_ptrs = BIAS + cols[None, :]\n bias = tl.load(b_ptrs, mask=col_mask, other=0.)\n\n for i in range(4):\n if i == 0:\n rand_mask = rand_mask1\n elif i == 1:\n rand_mask = rand_mask2\n elif i == 2:\n rand_mask = rand_mask3\n else:\n rand_mask = rand_mask4\n\n block_mask = (rows[:, None] < M) & col_mask\n grad_out = tl.load(grad_out_ptrs, mask=block_mask, other=0.)\n\n if ACTIVATION_GRAD:\n inputs = tl.load(input_ptrs, mask=block_mask, other=0.)\n if USE_BIAS:\n inputs += bias\n\n act_grad = ACTIVATION_GRAD(inputs).to(grad_out.dtype)\n grad_out *= act_grad\n\n output = _random_prune_and_scale(grad_out, rand_mask, p, p_scale)\n\n tl.store(grad_in_ptrs, output, mask=block_mask)\n\n if TRAINABLE_BIAS:\n grad_bias += tl.sum(output, axis=0)\n\n rows += BLOCK_M\n grad_out_ptrs += BLOCK_M * stride_grad\n input_ptrs += BLOCK_M * stride_inputs\n grad_in_ptrs += BLOCK_M * stride_grad\n\n if TRAINABLE_BIAS:\n grad_bias_ptr = GRAD_BIAS + row_id * N + cols\n tl.store(grad_bias_ptr, grad_bias, mask=cols < N)\n", - "description_1": "Use triton language to implement a dropout operation on input tensors with both forward and backward passes. The forward pass kernel, `k_dropout_fw`, applies dropout by generating random binary masks, optionally adding bias, applying activation functions, and storing the results. It takes parameters for input/output tensor pointers, bias pointers, seeds, dimensions, and dropout probability. The backward pass kernel, `k_dropout_bw`, computes gradients with similar logic, taking parameters for gradient tensors, input tensors, dimensions, dropout probability, and meta-parameters for autotuning.", - "description_2": "Use triton language to implement a dropout operation with random binary mask generation, optional bias addition, and activation in forward pass; compute gradients in backward pass with similar logic, supporting autotuning.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n# Triton kernel for backward operation\n@triton.jit\ndef kernel_bw(\n GRAD_ACT, GRAD_OUT, ACT_INPUTS,\n N,\n stride_gom, stride_aim,\n BLOCK_N: tl.constexpr,\n EVEN_N: tl.constexpr,\n ACTIVATION_GRAD: tl.constexpr,\n):\n \"\"\"\n Go over all the activation inputs, compute the corresponding gradient\n \"\"\"\n pid_m, pid_n = tl.program_id(axis=0), tl.program_id(axis=1)\n rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)\n act_input_ptrs = ACT_INPUTS + pid_m * stride_aim + rn\n\n if EVEN_N:\n act_in = tl.load(act_input_ptrs)\n else:\n act_in = tl.load(act_input_ptrs, mask=rn < N, other=0.0)\n\n grad_act = ACTIVATION_GRAD(act_in)\n\n grad_out_ptrs = GRAD_OUT + pid_m * stride_gom + rn\n if EVEN_N:\n grad_out = tl.load(grad_out_ptrs)\n else:\n grad_out = tl.load(grad_out_ptrs, mask=rn < N)\n\n grad_act *= grad_out\n\n grad_act_ptrs = GRAD_ACT + pid_m * stride_gom + rn\n tl.store(grad_act_ptrs, grad_act, mask=rn < N)\n\n\ndef fused_matmul_backward(\n grad_out: torch.Tensor,\n inputs: torch.Tensor,\n act_in: Optional[torch.Tensor],\n weight: torch.Tensor,\n trainable_weight: bool,\n trainable_bias: bool,\n activation_grad=None,\n):\n \"\"\"\n Compute grad_in = activation^-1(grad_out) @ weight.transpose()\n \"\"\"\n\n if not grad_out.is_contiguous():\n grad_out = grad_out.contiguous()\n\n grad_out_ = grad_out if grad_out.ndim == 2 else grad_out.flatten(0, 1)\n inputs_ = inputs if inputs.ndim == 2 else inputs.flatten(0, 1)\n\n assert grad_out_.shape[1] == weight.shape[0], \"Incompatible dimensions in between grad_out and weight\"\n\n M, N = grad_out_.shape\n\n if activation_grad is not None:\n grad_act = torch.empty_like(grad_out_)\n\n if act_in is None:\n act_in = grad_out_\n\n grid = lambda META: (M, triton.cdiv(N, META[\"BLOCK_N\"])) # noqa\n\n kernel_bw[grid](\n grad_act, grad_out_, act_in, # data ptrs\n N, # shapes\n grad_act.stride(0), act_in.stride(0), # strides\n ACTIVATION_GRAD=activation_grad, # optional fused activation\n )\n\n grad_out_ = grad_act\n\n grad_in = grad_out_ @ weight\n grad_weight = grad_out_.transpose(1, 0) @ inputs_ if trainable_weight else None\n grad_bias = sum_2d_dim_0(grad_out_) if trainable_bias else None\n\n return grad_in.reshape_as(inputs), grad_weight, grad_bias\n", - "description_1": "Use triton language to define a kernel `kernel_bw` that computes gradients for activation functions. It has 10 parameters: three pointers to input tensors (GRAD_ACT, GRAD_OUT, ACT_INPUTS), an integer N for matrix dimensions, two stride values (stride_gom, stride_aim), and three compile-time constants (BLOCK_N, EVEN_N, ACTIVATION_GRAD). Another function `fused_matmul_backward` orchestrates the Triton kernel call with gradient and input handling. It takes 7 arguments including PyTorch tensors and optional activation gradient.", - "description_2": "Use triton language to define a backward kernel that calculates activation gradients and use PyTorch to call this kernel in a fused matrix multiplication backward operation.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef kernel_fma(\n OUT, ACT_INPUTS, INPUT, WEIGHT, bias,\n M, N, K,\n stride_om, stride_im,\n stride_wn,\n BLOCK_M: tl.constexpr, GROUP_M: tl.constexpr,\n BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,\n BIAS: tl.constexpr,\n SAVE_ACT_INPUTS: tl.constexpr,\n ACTIVATION: tl.constexpr,\n):\n \"\"\"\n Kernel for computing Out = activation(A x W + C)\n\n - Input has shape (M, K)\n - Weight has shape (K, N)\n - Bias has shape (N,)\n - Output has shape (M, N)\n - ActInputs (optional) has shape (M, N)\n\n 'ActInputs' optionally saves the A x W + C intermediate for backward computations\n\n This kernel will consolidate over K\n \"\"\"\n pid = tl.program_id(axis=0)\n\n num_pid_m = tl.cdiv(M, BLOCK_M)\n num_pid_n = tl.cdiv(N, BLOCK_N)\n num_pid_in_group = GROUP_M * num_pid_n\n group_id = pid // num_pid_in_group\n first_pid_m = group_id * GROUP_M\n GROUP_M = min(num_pid_m - first_pid_m, GROUP_M)\n\n pid_m = first_pid_m + (pid % GROUP_M)\n pid_n = (pid % num_pid_in_group) // GROUP_M\n\n rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)\n rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)\n rk = tl.arange(0, BLOCK_K)\n\n input_ptrs = INPUT + rm[:, None] * stride_im\n weight_ptrs = WEIGHT + rn[None, :] * stride_wn\n\n acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)\n\n if BIAS:\n bias = tl.load(bias + rn, mask=rn < N, other=0.0).to(tl.float32)\n acc += bias[None, :]\n\n mask_rn = rn < N\n mask_rm = rm < M\n\n for i in range(0, K, BLOCK_K):\n rk = tl.arange(0, BLOCK_K) + i\n a = tl.load(input_ptrs + rk[None, :], mask=((rk[None, :] < K) & mask_rm[:, None]), other=0.0)\n w = tl.load(weight_ptrs + rk[:, None], mask=((rk[:, None] < K) & mask_rn[None, :]), other=0.0)\n\n acc += tl.dot(a, w)\n\n if SAVE_ACT_INPUTS:\n act_in_ptrs = ACT_INPUTS + rm[:, None] * stride_om + rn[None, :]\n tl.store(act_in_ptrs, acc, mask=mask_rm[:, None] & mask_rn[None, :])\n\n if ACTIVATION:\n acc = ACTIVATION(acc)\n\n out_ptrs = OUT + rm[:, None] * stride_om + rn[None, :]\n tl.store(out_ptrs, acc, mask=mask_rm[:, None] & mask_rn[None, :])\n\n\ndef fused_matmul(\n x: torch.Tensor,\n weight: torch.Tensor,\n bias: Optional[torch.Tensor],\n activation=None,\n save_act_inputs: bool = False\n):\n \"\"\"\n Compute e = activation(x @ weight + bias).\n This wrapper kicks the `kernel_fma` Triton kernel\n \"\"\"\n if not x.is_contiguous():\n x = x.contiguous()\n\n x_ = x if x.ndim == 2 else x.flatten(0, 1)\n\n assert (\n x_.shape[1] == weight.shape[1]\n ), f\"Incompatible dimensions in between inputs and weight, {x_.shape} - {weight.shape}\"\n assert bias is None or bias.is_contiguous()\n assert (\n bias is None or bias.shape[0] == weight.shape[0]\n ), \"Incompatible dimensions in between weight and bias\"\n assert weight.is_contiguous()\n\n M, K = x_.shape\n N, K = weight.shape\n\n outputs = torch.empty((M, N), device=x.device, dtype=x.dtype)\n act_inputs = torch.empty_like(outputs) if save_act_inputs else x\n\n grid = lambda META: (triton.cdiv(M, META[\"BLOCK_M\"]) * triton.cdiv(N, META[\"BLOCK_N\"]),)\n BLOCK_K = 32 if K < 1024 else 64\n\n kernel_fma[grid](\n outputs, act_inputs, x_, weight,\n bias if bias is not None else x,\n M, N, K,\n outputs.stride(0), x_.stride(0),\n weight.stride(0),\n ACTIVATION=activation,\n BIAS=bias is not None,\n GROUP_M=8,\n BLOCK_K=BLOCK_K,\n SAVE_ACT_INPUTS=save_act_inputs\n )\n\n outputs = outputs if x.ndim == 2 else outputs.reshape(x.shape[0], -1, N)\n\n return outputs, act_inputs if save_act_inputs else None\n", - "description_1": "Use triton language to implement a kernel function 'kernel_fma' that performs matrix multiplication with optional bias addition and activation. The kernel takes pointers to input matrices, their dimensions, strides, and meta-parameters for block sizes and operations. It computes the output matrix by iterating over blocks of the input matrices, performing dot products, and optionally applying bias and activation. The 'fused_matmul' function wraps this kernel, preparing input tensors and launching the kernel with appropriate grid dimensions.", - "description_2": "Use triton language to create a matrix multiplication kernel with optional bias and activation, and a wrapper function to prepare inputs and launch the kernel.", - "difficulty": 3 - }, - { - "code": "import triton\nimport triton.language as tl\n\n# Fused layernorm kernel over a 3d tensor\n@triton.jit\ndef layer_norm_fw(X, Y, W, B, M, V, stride, N, eps, affine: tl.constexpr, BLOCK_SIZE_N: tl.constexpr):\n \"\"\"\n Arguments:\n 1. X: Input tensor\n 2. Y: Output tensor\n 3. W: Weight for affine transformation\n 4. B: Bias for affine transformation\n 5. M: Mean storage\n 6. V: Variance storage\n 7. stride: Stride size\n 8. N: Number of elements in the last dimension\n 9. eps: Small epsilon value for numerical stability\n 10. affine: Boolean indicating whether to apply affine transformation\n 11. BLOCK_SIZE_N: Block size for the last dimension\n \"\"\"\n row = tl.program_id(0)\n cols = tl.arange(0, BLOCK_SIZE_N)\n mask = cols < N\n\n # Move to this row\n x_ptrs = X + row * stride + cols\n x = tl.load(x_ptrs, mask=mask, other=0.0).to(tl.float32)\n\n # Compute mean and variance\n mean = tl.sum(x, axis=0) / N\n x_zm = tl.where(mask, x - mean, 0.0)\n tl.store(M + row, mean)\n\n x_var = tl.sum(x_zm * x_zm, axis=0) / N\n rstd = 1.0 / tl.sqrt(x_var + eps)\n\n # Normalize, optionally affine\n y = x_zm * rstd\n tl.store(V + row, rstd)\n\n if affine:\n w = tl.load(W + cols, mask=mask, other=1.0)\n b = tl.load(B + cols, mask=mask, other=0.0)\n y = y * w + b\n\n y_ptrs = Y + row * stride + cols\n tl.store(y_ptrs, y, mask=mask)\n\n# Backward pass: DX + partial DW + partial DB\n@triton.jit\ndef layer_norm_bwd_dx_fused(\n DX, DY, DW, DB,\n X, W, M, V,\n Lock, stride, N,\n affine: tl.constexpr,\n GROUP_SIZE_M: tl.constexpr,\n BLOCK_SIZE_N: tl.constexpr,\n):\n \"\"\"\n Arguments:\n 1. DX: Gradient for input tensor\n 2. DY: Gradient for output tensor\n 3. DW: Weight gradient\n 4. DB: Bias gradient\n 5. X: Input tensor\n 6. W: Weight for affine transformation\n 7. M: Mean storage\n 8. V: Variance storage\n 9. Lock: Lock for synchronization\n 10. stride: Stride size\n 11. N: Number of elements in the last dimension\n 12. affine: Boolean indicating whether affine is applied\n 13. GROUP_SIZE_M: Group size for the rows\n 14. BLOCK_SIZE_N: Block size for the last dimension\n \"\"\"\n row = tl.program_id(0)\n cols = tl.arange(0, BLOCK_SIZE_N)\n mask = cols < N\n\n x_ptrs = X + row * stride + cols\n dy_ptrs = DY + row * stride + cols\n\n x = tl.load(x_ptrs, mask=mask, other=0)\n dy = tl.load(dy_ptrs, mask=mask, other=0)\n mean = tl.load(M + row)\n rstd = tl.load(V + row)\n\n xhat = (x - mean) * rstd\n\n if affine:\n w = tl.load(W + cols, mask=mask, other=0)\n wdy = w * dy\n else:\n wdy = dy\n\n xhat = tl.where(mask, xhat, 0.)\n wdy = tl.where(mask, wdy, 0.)\n mean1 = tl.sum(xhat * wdy, axis=0) / N\n mean2 = tl.sum(wdy, axis=0) / N\n dx = (wdy - (xhat * mean1 + mean2)) * rstd\n\n dx_ptrs = DX + row * stride + cols\n tl.store(dx_ptrs, dx, mask=mask)\n\n if affine:\n partial_dw = (dy * xhat).to(w.dtype)\n partial_db = dy.to(w.dtype)\n\n lock_id = row % GROUP_SIZE_M\n Lock += lock_id\n Count = Lock + GROUP_SIZE_M\n\n while tl.atomic_cas(Lock, 0, 1) == 1:\n pass\n count = tl.load(Count)\n\n dw_ptrs = DW + lock_id * N + cols\n db_ptrs = DB + lock_id * N + cols\n\n if count == 0:\n tl.atomic_xchg(Count, 1)\n else:\n partial_dw += tl.load(dw_ptrs, mask=mask, other=0.)\n partial_db += tl.load(db_ptrs, mask=mask, other=0.)\n\n tl.store(dw_ptrs, partial_dw, mask=mask)\n tl.store(db_ptrs, partial_db, mask=mask)\n\n tl.atomic_xchg(Lock, 0)\n\n# Backward pass: total DW + total DB\n@triton.jit\ndef layer_norm_bwd_dwdb(\n DW, DB, FINAL_DW, FINAL_DB,\n M, N,\n BLOCK_SIZE_M: tl.constexpr,\n BLOCK_SIZE_N: tl.constexpr\n):\n \"\"\"\n Arguments:\n 1. DW: Weight gradient\n 2. DB: Bias gradient\n 3. FINAL_DW: Final accumulated weight gradient\n 4. FINAL_DB: Final accumulated bias gradient\n 5. M: Number of elements in the first dimension\n 6. N: Number of elements in the last dimension\n 7. BLOCK_SIZE_M: Block size for the first dimension\n 8. BLOCK_SIZE_N: Block size for the last dimension\n \"\"\"\n pid = tl.program_id(0)\n\n cols = pid * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n mask_cols = cols < N\n\n dw = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n db = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n\n for i in range(0, M, BLOCK_SIZE_M):\n rows = i + tl.arange(0, BLOCK_SIZE_M)\n offs = rows[:, None] * N + cols[None, :]\n mask_rm = rows < M\n\n dw += tl.load(DW + offs, mask=mask_rm[:, None] & mask_cols[None, :], other=0.0)\n db += tl.load(DB + offs, mask=mask_rm[:, None] & mask_cols[None, :], other=0.0)\n\n sum_dw = tl.sum(dw, axis=0)\n sum_db = tl.sum(db, axis=0)\n\n tl.store(FINAL_DW + cols, sum_dw, mask=mask_cols)\n tl.store(FINAL_DB + cols, sum_db, mask=mask_cols)\n", - "description_1": "Use triton language to implement a fused layer normalization forward and backward kernels. The forward kernel normalizes a 3D tensor across the last dimension and applies an optional affine transformation. The backward kernels compute gradients for inputs and partial sums for weight and bias updates, followed by accumulation of these partial sums into final gradients.", - "description_2": "Use triton language to implement layer normalization with forward and backward passes, including affine transformation and gradient accumulation.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n# CREDITS: This is adapted from the vanilla Triton example. See https://openai.com/blog/triton/\n# and https://triton-lang.org/getting-started/tutorials/02-fused-softmax.html\n\n# autotune: Triton will test out these configurations, and automatically pick the fastest one.\n# heuristic: add arguments to the kernel call automatically given some heuristics. These arguments are passed in \"meta\"\n# fmt: off\n@triton.autotune(\n configs=[\n triton.Config({}, num_warps=1),\n triton.Config({}, num_warps=2),\n triton.Config({}, num_warps=4),\n triton.Config({}, num_warps=8),\n triton.Config({}, num_warps=16),\n triton.Config({}, num_warps=32),\n ],\n key=[\"K\"],\n)\n@triton.heuristics(values={\"depth\": lambda args: triton.next_power_of_2(args[\"K\"]), \"is_fp16\": lambda args: args[\"Y\"].dtype == torch.float16})\n@triton.jit\ndef _softmax(\n Y, X, M,\n stride_ym, stride_yn,\n stride_xm, stride_xn,\n stride_mn,\n K,\n # Meta-params\n depth: tl.constexpr,\n causal: tl.constexpr,\n use_mask: tl.constexpr,\n is_fp16: tl.constexpr,\n log: tl.constexpr,\n):\n # fmt: om\n\n \"\"\"\n Fused softmax kernel over a 3d tensor.\n The softmax is applied over the last dimension, meaning that this is equivalent to torch.softmax(tensor, dim=-1)\n\n Note, if the last dimension is large, say 128K elements, the kernel compile time can shot up to many minutes when\n the kernel is run for the first time.\n \"\"\"\n\n m = tl.program_id(0)\n n = tl.program_id(1)\n\n # col indices\n k = tl.arange(0, depth)\n\n # the memory address of all the elements that we want to load can be computed as follows\n x_ptrs = X + m * stride_xm + n * stride_xn + k\n\n # load input data; pad out-of-bounds elements with 0\n io_mask = k < K\n\n # Causal - 1: skip on the loads directly\n if causal:\n io_mask = io_mask & (k <= n)\n\n x = tl.load(x_ptrs, mask=io_mask, other=float(\"-inf\"))\n\n # Causal - 2: enforce correctness over a couple of misloaded values\n if causal:\n off = float(\"-inf\")\n off = off.to(x.dtype) # type: ignore\n x = tl.where(k > n, off, x)\n\n if use_mask:\n mask_ptrs = M + n * stride_mn + k\n add_mask = tl.load(mask_ptrs, io_mask, other=float(\"-inf\"))\n x += add_mask\n\n # compute numerically-stable softmax\n z = x - tl.max(x, axis=0)\n\n if is_fp16:\n # tl.exp() crashes on fp16 values\n # See https://github.com/openai/triton/issues/241\n z = z.to(tl.float32)\n\n num = tl.exp(z)\n denom = tl.sum(num, axis=0)\n\n if log:\n y = z - tl.log(denom)\n else:\n y = num / denom\n\n # write back to Y.\n # we only write once, hence the \"fused\" softmax naming\n y_ptrs = Y + m * stride_ym + n * stride_yn + k\n\n # technically we could write only the lower triangular matrix in the causal case\n # but this is deemed to error prone\n tl.store(y_ptrs, y, mask=k < K)\n\n\n# fmt: off\n@triton.autotune(\n configs=[\n triton.Config({}, num_warps=1),\n triton.Config({}, num_warps=2),\n triton.Config({}, num_warps=4),\n triton.Config({}, num_warps=8),\n triton.Config({}, num_warps=16),\n ],\n key=[\"K\"],\n)\n@triton.heuristics(values={\"is_fp16\": lambda args: args[\"GradIn\"].dtype == torch.float16})\n@triton.jit\ndef _softmax_backward(\n GradIn, GradOut, Out,\n stride_bm, stride_bn,\n stride_gm, stride_gn,\n stride_om, stride_on,\n K,\n # meta-params\n depth: tl.constexpr,\n causal: tl.constexpr,\n is_fp16: tl.constexpr,\n log: tl.constexpr,\n):\n # fmt: on\n\n \"\"\"\n Compute the softmax gradients.\n ..Note: Not autotuning for now because this would lead to broken accumulated gradients\n \"\"\"\n\n m = tl.program_id(0)\n n = tl.program_id(1)\n\n # col indices\n k = tl.arange(0, depth)\n\n # the memory address of all the elements that we want to load can be computed as follows\n grad_out_ptrs = GradOut + m * stride_gm + n * stride_gn + k\n out_ptrs = Out + m * stride_om + n * stride_on + k\n\n # load input data; pad out-of-bounds elements with 0\n io_mask = k < K\n\n # Causal - 1: skip on the loads directly\n if causal:\n io_mask = io_mask & (k <= n)\n\n g = tl.load(grad_out_ptrs, mask=io_mask, other=float(0))\n o = tl.load(out_ptrs, mask=io_mask, other=float(0))\n\n # Causal - 2: enforce correctness over a couple of misloaded values\n if causal:\n zero = float(0)\n zero = zero.to(g.dtype) # type: ignore\n g = tl.where(k > n, zero, g)\n o = tl.where(k > n, zero, o)\n\n if log:\n s = tl.sum(g, 0)\n if is_fp16:\n o = o.to(tl.float32)\n grad_in = g - tl.exp(o) * s\n else:\n # Step 1: Compute the intermediate sum used for the gradient\n s = tl.sum(g * o, 0)\n\n # Step 2: Compute the gradients\n grad_in = o * (g - s)\n\n # write back to the input gradients\n # technically we could write only the lower triangular matrix in the causal case\n # but this is deemed to error prone\n grad_in_ptrs = GradIn + m * stride_bm + n * stride_bn + k\n tl.store(grad_in_ptrs, grad_in, mask=k < K)\n", - "description_1": "Use triton language to implement a fused softmax kernel and its backward pass for a 3D tensor. The softmax is applied over the last dimension. The kernel is autotuned for different configurations and uses heuristics to determine meta-parameters like depth and data type. The forward kernel (_softmax) takes 13 parameters: output tensor Y, input tensor X, mask tensor M, strides for Y, X, and M, dimension size K, and meta-parameters for depth, causality, mask usage, data type, and log softmax. The backward kernel (_softmax_backward) takes 12 parameters: gradient input GradIn, gradient output GradOut, output tensor Out, strides for GradIn, GradOut, and Out, dimension size K, and meta-parameters for depth, causality, data type, and log softmax.", - "description_2": "Use triton language to create a fused softmax operation with forward and backward kernels, optimized with autotuning and heuristics for 3D tensors.", - "difficulty": 2 - }, - { - "code": "import triton\nimport triton.language as tl\n\n# Sum a 2d tensor over the first (strided) dimension.\n@triton.jit\ndef k_sum_0(\n Y, X,\n stride_xm,\n M, N,\n is_fp16,\n # META-params\n BLOCK_M: tl.constexpr,\n BLOCK_N: tl.constexpr,\n):\n \"\"\"\n Sum a 2d tensor over the first (strided) dimension.\n This extracts some speed through a parallel sum across the second dimension\n \"\"\"\n # partial row indices. We'll reduce over this dimension\n m = tl.arange(0, BLOCK_M)\n\n # To get some extra parallelization, we handle several columns in the same thread block\n rn = tl.program_id(axis=0) * BLOCK_N + tl.arange(0, BLOCK_N)\n\n # the memory address of all the elements that we want to load can be computed as follows\n x_ptrs = X + m[:, None] * stride_xm + rn[None, :]\n x_sum = tl.zeros((BLOCK_N,), dtype=tl.float32)\n\n tiles = M // BLOCK_M\n if M % BLOCK_M > 0:\n tiles += 1\n\n col_mask = (rn[None, :] < N)\n\n for _ in range(tiles):\n # load input data; pad out-of-bounds elements with 0\n # NOTE: make sure to accumulate in fp32 to prevent a trivial overflow\n mask = (m[:, None] < M) & col_mask\n x = tl.load(x_ptrs, mask=mask, other=0.0)\n x_sum += tl.sum(x, 0)\n\n # move the load pointer\n x_ptrs += BLOCK_M * stride_xm\n m += BLOCK_M # update the mask check\n\n tl.store(Y + rn, x_sum, mask=rn < N)\n", - "description_1": "Use triton language to define a kernel function `k_sum_0` that computes the sum of a 2D tensor over the first (strided) dimension using parallelization. The kernel takes 8 parameters: output tensor `Y`, input tensor `X`, stride of the first dimension `stride_xm`, number of rows `M`, number of columns `N`, boolean `is_fp16` to indicate if the operation is on half precision, and block sizes `BLOCK_M` and `BLOCK_N` for dimensions. The kernel performs a load, compute, and store sequence over the dimensions to accumulate the sum.", - "description_2": "Use triton language to implement a parallelized summation of a strided 2D tensor's first dimension, allowing for block-wise computation to enhance performance.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\nclass ForeachKernel:\n def __init__(self):\n self.blocking_2d = False\n self.block_size_1d = 1024\n self.block_size_2d = 32\n self.num_warps = 8\n self.sub_kernels = []\n self.x_block_count = 0\n\n def get_block_size(self):\n return self.block_size_2d if self.blocking_2d else self.block_size_1d\n\n def codegen_pid_range(self, code, x_elems):\n num_x_blocks = (x_elems + self.get_block_size() - 1) // self.get_block_size()\n upper_bound_x_pid = self.x_block_count + num_x_blocks\n lower_bound_x_pid = self.x_block_count\n\n if self.x_block_count == 0:\n cond = \"if\"\n else:\n cond = \"elif\"\n\n x_pid_bounds_check = (\n f\"xpid >= {lower_bound_x_pid} and xpid < {upper_bound_x_pid}\"\n )\n code.append(f\"{cond} {x_pid_bounds_check}:\")\n\n self.x_block_count += num_x_blocks\n\n def codegen_kernel(self, name=None):\n code = []\n\n code.append(\"@triton.jit\")\n code.append(f\"def {name or 'kernel'}(x):\")\n\n code.append(\" xpid = tl.program_id(0)\")\n if self.blocking_2d:\n code.append(\" ypid = tl.program_id(1)\")\n code.append(f\" XBLOCK: tl.constexpr = {self.block_size_2d}\")\n code.append(f\" YBLOCK: tl.constexpr = {self.block_size_2d}\")\n else:\n code.append(f\" XBLOCK: tl.constexpr = {self.block_size_1d}\")\n\n for sub_kernel in self.sub_kernels:\n self.codegen_pid_range(code, int(sub_kernel.numels[0]))\n code.append(\" pass\")\n\n code.append(\"else:\")\n code.append(\" pass\")\n\n return \"\\n\".join(code)\n\n def call_kernel(self, code, name: str):\n call_args_str = \"x\"\n stream_name = \"stream\"\n code.append(\n f\"{name}.run({call_args_str}, grid=({self.grid()}), stream={stream_name})\"\n )\n\n# Example usage\nkernel = ForeachKernel()\nkernel_code = kernel.codegen_kernel(\"example_kernel\")\nprint(kernel_code)\n", - "description_1": "Use triton language to define a kernel with a single argument 'x'. The kernel uses program IDs to determine execution blocks and includes a placeholder for sub-kernel execution. The kernel is called with a grid configuration and a stream.", - "description_2": "Use triton language to create a kernel that processes data in blocks, using program IDs for block management, and execute it with specified grid and stream.", - "difficulty": 2 - }, - { - "code": "import triton\nimport triton.language as tl\nimport torch\n\n@triton.jit\ndef add_kernel(X, Y, Z, N):\n pid = tl.program_id(0)\n block_start = pid * 1024\n offsets = block_start + tl.arange(0, 1024)\n mask = offsets < N\n x = tl.load(X + offsets, mask=mask)\n y = tl.load(Y + offsets, mask=mask)\n z = x + y\n tl.store(Z + offsets, z, mask=mask)\n\ndef call_add_kernel(X, Y, Z, N):\n grid = lambda meta: (triton.cdiv(N, 1024),)\n add_kernel[grid](X, Y, Z, N)\n\n# Example usage\nX = torch.randn(1024, device='cuda')\nY = torch.randn(1024, device='cuda')\nZ = torch.empty(1024, device='cuda')\nN = X.numel()\ncall_add_kernel(X, Y, Z, N)\n", - "description_1": "Use triton language to define a kernel 'add_kernel' that takes four parameters: X, Y, Z, and N. The kernel performs element-wise addition of two input tensors X and Y, storing the result in tensor Z. The parameter N specifies the number of elements to process. The kernel is launched with a grid size determined by the lambda function, which divides N by 1024 to determine the number of blocks needed.", - "description_2": "Use triton language to define a kernel that performs element-wise addition of two input tensors and stores the result in an output tensor, with the number of elements specified as a parameter.", - "difficulty": 2 - }, - { - "code": "import triton\n\n# Example kernel function\n@triton.jit\ndef example_kernel(X, Y, Z):\n # Example computation\n idx = triton.program_id(0)\n if idx < X.size(0):\n Z[idx] = X[idx] + Y[idx]\n\ndef call_example_kernel(x, y, z):\n # Assumed that x, y, z are triton allocated tensors\n example_kernel[(1,)](x, y, z)\n", - "description_1": "Use triton language to define a kernel function 'example_kernel' which takes three arguments: X, Y, Z. It performs element-wise addition of arrays X and Y, storing the result in array Z. This is executed for a single program_id. The 'call_example_kernel' function calls the kernel with a grid of size 1.", - "description_2": "Use triton language to create a kernel that performs element-wise addition on two input arrays.", - "difficulty": 2 - }, - { - "code": "import triton\nimport torch\n\n# Example Triton kernel\n@triton.jit\ndef example_kernel(X, Y, Z, BLOCK_SIZE: tl.constexpr):\n # Triton kernel code\n pass\n\n# Function to call the Triton kernel\ndef call_example_kernel(x, y, z, block_size):\n # Call the Triton kernel\n example_kernel[(1,)](x, y, z, BLOCK_SIZE=block_size)\n\n# Example usage\nx = torch.tensor([1.0, 2.0, 3.0], device='cuda')\ny = torch.tensor([4.0, 5.0, 6.0], device='cuda')\nz = torch.empty_like(x)\ncall_example_kernel(x, y, z, block_size=1024)\n", - "description_1": "Use triton language to define a kernel named 'example_kernel' with three parameters X, Y, Z, and a block size. The kernel performs operations on these parameters. A function 'call_example_kernel' is used to invoke this kernel with specific inputs and block size.", - "description_2": "Use triton language to create a kernel for element-wise operations on tensors with a specified block size.", - "difficulty": 1 - }, - { - "code": "import triton\nimport triton.language as tl\n\n@triton.jit\ndef promote_to_tensor(x):\n # Addition promotes to tensor for us\n return x + tl.zeros((1,), tl.int1)\n\n@triton.jit\ndef is_floating(x):\n return promote_to_tensor(x).dtype.is_floating()\n\n@triton.jit\ndef _prod_accumulate(a, b):\n return a * b\n\n@triton.jit\ndef prod(input, axis):\n return tl.reduce(input, axis, _prod_accumulate)\n\n@triton.jit\ndef minimum(a, b):\n mask = a < b\n if is_floating(a):\n mask |= a != a\n return tl.where(mask, a, b)\n\n@triton.jit\ndef maximum(a, b):\n mask = a > b\n if is_floating(a):\n mask |= a != a\n return tl.where(mask, a, b)\n\n@triton.jit\ndef min2(a, dim):\n return tl.reduce(a, dim, minimum)\n\n@triton.jit\ndef max2(a, dim):\n return tl.reduce(a, dim, maximum)\n\n@triton.jit\ndef minimum_with_index(a_value, a_index, b_value, b_index):\n mask = a_value < b_value\n equal = a_value == b_value\n if is_floating(a_value):\n a_isnan = a_value != a_value\n b_isnan = b_value != b_value\n mask |= a_isnan and not b_isnan\n equal |= a_isnan and b_isnan\n mask |= equal & (a_index < b_index)\n return tl.where(mask, a_value, b_value), tl.where(mask, a_index, b_index)\n\n@triton.jit\ndef maximum_with_index(a_value, a_index, b_value, b_index):\n mask = a_value > b_value\n equal = a_value == b_value\n if is_floating(a_value):\n a_isnan = a_value != a_value\n b_isnan = b_value != b_value\n mask |= a_isnan and not b_isnan\n equal |= a_isnan and b_isnan\n mask |= equal & (a_index < b_index)\n return tl.where(mask, a_value, b_value), tl.where(mask, a_index, b_index)\n\n@triton.jit\ndef min_with_index(value, index, dim):\n return tl.reduce((value, index), dim, minimum_with_index)\n\n@triton.jit\ndef max_with_index(value, index, dim):\n return tl.reduce((value, index), dim, maximum_with_index)\n\n@triton.jit\ndef welford_reduce(value, mean, m2, weight, first_iteration):\n if first_iteration:\n new_weight = tl.full(weight.shape, 1, weight.dtype)\n new_mean = value\n new_m2 = tl.zeros_like(m2)\n else:\n delta = value - mean\n new_weight = weight + 1\n new_mean = mean + delta / new_weight\n new_m2 = m2 + delta * (value - new_mean)\n return new_mean, new_m2, new_weight\n\n@triton.jit\ndef welford_combine(mean_1, m2_1, weight_1, mean_2, m2_2, weight_2):\n delta = mean_2 - mean_1\n new_weight = weight_1 + weight_2\n w2_over_w = tl.where(new_weight == 0.0, 0.0, weight_2 / new_weight)\n return (\n mean_1 + delta * w2_over_w,\n m2_1 + m2_2 + delta * delta * weight_1 * w2_over_w,\n new_weight,\n )\n\n@triton.jit\ndef welford(mean, m2, weight, dim):\n return tl.reduce((mean, m2, weight), dim, welford_combine)\n\n@triton.jit\ndef device_assert_then(cond, msg, r):\n tl.device_assert(cond, msg)\n return r\n\n@triton.jit\ndef randint64(seed, offset, low, high):\n r0, r1, r2, r3 = tl.randint4x(seed, offset)\n r0 = r0.to(tl.uint64)\n r1 = r1.to(tl.uint64)\n result = r0 | (r1 << 32)\n size = high - low\n result = result % size.to(tl.uint64)\n result = result.to(tl.int64) + low\n return result\n\n@triton.jit\ndef _any_combine(a, b):\n return a | b\n\n@triton.jit\ndef any(a, dim):\n return tl.reduce(a, dim, _any_combine)\n\n@triton.jit\ndef bucketize_binary_search(\n values, # 1D tensor\n offsets_ptr,\n indexing_dtype,\n right, # bool: if true, use intervals closed on the left\n OFFSETS_SIZE: int,\n BLOCK_SHAPE, # tuple/list of block shape\n):\n low = tl.zeros(BLOCK_SHAPE, dtype=indexing_dtype)\n high = tl.full(BLOCK_SHAPE, OFFSETS_SIZE, dtype=indexing_dtype)\n\n full_range = OFFSETS_SIZE + 1\n while full_range > 1:\n mid = (high + low) // 2\n mask = mid < OFFSETS_SIZE\n bucket_upper_bound = tl.load(offsets_ptr + mid, mask=mask)\n if right:\n is_above = values >= bucket_upper_bound\n else:\n is_above = values > bucket_upper_bound\n\n low = tl.where(is_above & mask, mid + 1, low)\n high = tl.where(is_above, high, mid)\n\n full_range = (full_range + 1) // 2\n\n return low\n\n@triton.jit\ndef pack_value_flag(\n value,\n flag,\n DTYPE_VALUE_AS_UINT: tl.constexpr,\n DTYPE_PACK: tl.constexpr,\n):\n DTYPE_VALUE_AS_UINT = tl.core._constexpr_to_value(DTYPE_VALUE_AS_UINT)\n bitwidth = DTYPE_VALUE_AS_UINT.primitive_bitwidth\n uv = value.to(DTYPE_VALUE_AS_UINT, bitcast=True).to(DTYPE_PACK)\n return flag.to(DTYPE_PACK) | (uv << bitwidth)\n\n@triton.jit\ndef unpack_value(\n pack,\n DTYPE_VALUE,\n DTYPE_VALUE_AS_UINT,\n):\n DTYPE_VALUE = tl.core._constexpr_to_value(DTYPE_VALUE)\n DTYPE_VALUE_AS_UINT = tl.core._constexpr_to_value(DTYPE_VALUE_AS_UINT)\n bitwidth = DTYPE_VALUE_AS_UINT.primitive_bitwidth\n value_uint = (pack >> bitwidth).to(DTYPE_VALUE_AS_UINT)\n return value_uint.to(DTYPE_VALUE, bitcast=True)\n\n@triton.jit\ndef unpack_flag(pack, DTYPE_FLAG):\n return pack.to(DTYPE_FLAG)\n\n@triton.jit\ndef exclusive_scan_decoupled_lookback(\n scratch_base,\n block_value,\n index,\n combine_fn,\n init,\n DTYPE_VALUE_AS_UINT: tl.constexpr,\n DTYPE_PACK: tl.constexpr,\n):\n DTYPE_VALUE = block_value.dtype\n pack = pack_value_flag(\n block_value,\n tl.full(block_value.shape, 1, DTYPE_VALUE_AS_UINT),\n DTYPE_VALUE_AS_UINT,\n DTYPE_PACK,\n )\n tl.atomic_xchg(scratch_base + index, pack, sem=\"relaxed\")\n\n exclusive_prefix = init\n test_target = index - 1\n while test_target >= 0:\n flag = tl.full([], 0, DTYPE_VALUE_AS_UINT)\n while flag == 0:\n pack = tl.atomic_add(scratch_base + test_target, 0, sem=\"relaxed\")\n flag = unpack_flag(pack, DTYPE_VALUE_AS_UINT)\n\n value = unpack_value(pack, DTYPE_VALUE, DTYPE_VALUE_AS_UINT)\n exclusive_prefix = combine_fn(value, exclusive_prefix)\n\n if flag == 2:\n test_target = -1\n else:\n test_target = test_target - 1\n\n inclusive_prefix = combine_fn(exclusive_prefix, block_value)\n pack = pack_value_flag(\n inclusive_prefix,\n tl.full([], 2, DTYPE_VALUE_AS_UINT),\n DTYPE_VALUE_AS_UINT,\n DTYPE_PACK,\n )\n tl.atomic_xchg(scratch_base + index, pack, sem=\"relaxed\")\n return exclusive_prefix\n\n@triton.jit\ndef exclusive_scan_decoupled_lookback_64(\n scratch_base, block_value, index, combine_fn, init\n):\n block_value_u64 = block_value.to(tl.uint64, bitcast=True)\n tl.store(scratch_base + 3 * index + 1, block_value_u64)\n tl.debug_barrier()\n flag_one = tl.full([], 1, tl.uint64)\n tl.atomic_xchg(scratch_base + 3 * index + 0, flag_one, sem=\"release\")\n\n exclusive_prefix = init\n test_target = index - 1\n while test_target >= 0:\n flag = tl.full([], 0, tl.uint64)\n while flag == 0:\n flag = tl.atomic_add(scratch_base + 3 * test_target + 0, 0, sem=\"acquire\")\n\n value_u64 = tl.load(scratch_base + 3 * test_target + flag.to(tl.int32))\n value = value_u64.to(block_value.dtype, bitcast=True)\n exclusive_prefix = combine_fn(value, exclusive_prefix)\n\n if flag == 2:\n test_target = -1\n else:\n test_target = test_target - 1\n\n inclusive_prefix = combine_fn(exclusive_prefix, block_value)\n inclusive_prefix_u64 = inclusive_prefix.to(tl.uint64, bitcast=True)\n tl.store(scratch_base + 3 * index + 2, inclusive_prefix_u64)\n tl.debug_barrier()\n flag_two = tl.full([], 2, tl.uint64)\n tl.atomic_xchg(scratch_base + 3 * index + 0, flag_two, sem=\"release\")\n\n return exclusive_prefix\n\n@triton.jit\ndef frexp(x):\n y = libdevice.ilogb(x) + 1\n exponent = tl.where(x == 0, 0, y)\n mantissa = tl.where(x == 0, 0, libdevice.ldexp(x, -y))\n return mantissa, exponent\n", - "description_1": "Use triton language to define various operations such as promote to tensor, check if type is floating, product accumulation, product reduction along an axis, minimum and maximum functions, minimum and maximum with index, Welford reduction operations, device assert, random integer generation, any operation, bucketize search, packing and unpacking values and flags, exclusive scan with decoupled lookback methods, and frexp function.", - "description_2": "Use triton language to define various mathematical and logical operations, including tensor promotion, floating type checks, reduction operations, min/max operations with indices, Welford reductions, device asserts, random number generation, value packing/unpacking, scanning operations, and frexp function.", - "difficulty": 4 - }, - { - "code": "import math\nimport torch\nimport triton\nimport triton.language as tl\nfrom torch.utils._triton import has_triton\n\nif has_triton():\n @triton.jit\n def _sampled_addmm_kernel(\n alpha,\n beta,\n IS_BETA_ZERO: tl.constexpr,\n BLOCKSIZE_ROW: tl.constexpr,\n BLOCKSIZE_COL: tl.constexpr,\n k,\n TILE_K: tl.constexpr,\n values_ptr,\n values_batch_stride,\n values_nnz_stride,\n values_row_block_stride,\n values_col_block_stride,\n crow_indices_ptr,\n crow_indices_batch_stride,\n crow_indices_stride,\n col_indices_ptr,\n col_indices_batch_stride,\n col_indices_stride,\n mat1_ptr,\n mat1_batch_stride,\n mat1_tiled_row_stride,\n mat1_tiled_col_stride,\n mat1_row_block_stride,\n mat1_col_block_stride,\n mat2_ptr,\n mat2_batch_stride,\n mat2_tiled_row_stride,\n mat2_tiled_col_stride,\n mat2_row_block_stride,\n mat2_col_block_stride,\n acc_dtype: tl.constexpr,\n allow_tf32: tl.constexpr,\n ):\n batch_pid = tl.program_id(axis=1)\n row_block_pid = tl.program_id(axis=0)\n\n crow_indices_offset_ptr = (\n crow_indices_ptr\n + crow_indices_batch_stride * batch_pid\n + crow_indices_stride * row_block_pid\n )\n nnz_offset = tl.load(crow_indices_offset_ptr)\n nnz_offset_next = tl.load(crow_indices_offset_ptr + crow_indices_stride)\n\n # Compute nnz for the row with number row_block_pid.\n # If it is zero, skip the row.\n row_nnz = nnz_offset_next - nnz_offset\n if row_nnz == 0:\n return\n\n row_block_arange = tl.arange(0, BLOCKSIZE_ROW)\n col_block_arange = tl.arange(0, BLOCKSIZE_COL)\n\n # Pointers are set to the first block of the current row.\n values_block_ptrs = (\n values_ptr\n + values_batch_stride * batch_pid\n + values_nnz_stride * nnz_offset\n + values_row_block_stride * row_block_arange[:, None]\n + values_col_block_stride * col_block_arange[None, :]\n )\n\n col_index_nnz_ptr = (\n col_indices_ptr\n + col_indices_batch_stride * batch_pid\n + col_indices_stride * nnz_offset\n )\n\n # Advance mat1 to the current tiled row, ignore columns.\n mat1_block_ptrs = (\n mat1_ptr\n + mat1_batch_stride * batch_pid\n + mat1_tiled_row_stride * row_block_pid\n + mat1_row_block_stride * row_block_arange[:, None]\n )\n\n # Advance mat2 in batch and block col dimension.\n mat2_block_ptrs = (\n mat2_ptr\n + mat2_batch_stride * batch_pid\n + mat2_col_block_stride * col_block_arange[None, :]\n )\n\n k_tile_arange = tl.arange(0, TILE_K)\n for _ in range(row_nnz):\n acc_block = tl.zeros((BLOCKSIZE_ROW, BLOCKSIZE_COL), dtype=acc_dtype)\n\n # find column block index\n col_block = tl.load(col_index_nnz_ptr)\n\n for k_tile in range(0, k, TILE_K):\n k_offsets = k_tile + k_tile_arange\n mask_k = k_offsets < k\n\n mat1_block = tl.load(\n mat1_block_ptrs\n + mat1_col_block_stride * k_offsets[None, :],\n mask=mask_k[None, :], other=0.0\n )\n\n mat2_block = tl.load(\n mat2_block_ptrs\n + mat2_tiled_col_stride * col_block\n + mat2_row_block_stride * k_offsets[:, None],\n mask=mask_k[:, None], other=0.0\n )\n\n acc_block += tl.dot(mat1_block, mat2_block, allow_tf32=allow_tf32, out_dtype=acc_dtype)\n\n if IS_BETA_ZERO:\n acc_block *= alpha\n else:\n acc_block = alpha * acc_block + beta * tl.load(values_block_ptrs)\n\n # write result\n tl.store(values_block_ptrs, acc_block.to(values_ptr.dtype.element_ty))\n\n # advance val/col_index ptrs to the next block in the row.\n values_block_ptrs += values_nnz_stride\n col_index_nnz_ptr += col_indices_stride\n\n def sampled_addmm(\n input: torch.Tensor,\n mat1: torch.Tensor,\n mat2: torch.Tensor,\n *,\n beta=1.0,\n alpha=1.0,\n out: Optional[torch.Tensor] = None,\n skip_checks: bool = False,\n max_grid: Optional[Tuple[Optional[int], Optional[int], Optional[int]]] = None,\n ):\n f_name = \"sampled_addmm\"\n\n check_bsr_layout(f_name, input)\n input_broadcasted = broadcast_batch_dims_bsr(f_name, input, mat1, mat2)\n\n if not skip_checks:\n check_device(f_name, mat1, input.device)\n check_device(f_name, mat2, input.device)\n if beta != 0.0 and input.dtype is torch.bool:\n check(\n False,\n f\"{f_name}(): having beta == {beta} not equal to 0.0 with boolean mask is not allowed.\"\n )\n if input.dtype is not torch.bool:\n check_dtype(f_name, mat1, input.dtype)\n check_dtype(f_name, mat2, input.dtype)\n else:\n check_dtype(f_name, mat1, mat2.dtype)\n check_mm_compatible_shapes(f_name, mat1, mat2)\n if out is not None:\n check_bsr_layout(f_name, out)\n check_device(f_name, out, mat1.device)\n check_dtype(f_name, out, input.dtype)\n check(\n out.shape == input_broadcasted.shape\n and out._nnz() == input._nnz(),\n f\"{f_name}(): Expects `out` to be of shape {input_broadcasted.shape} \"\n f\"and with nnz equal to {input_broadcasted._nnz()} \"\n f\"but got out.shape = {out.shape} and out.nnz = {out._nnz()}\"\n )\n\n if out is None:\n out = input_broadcasted.to(mat1.dtype, copy=True)\n else:\n out.copy_(input_broadcasted)\n\n if out.numel() == 0 or out._nnz() == 0:\n return out\n\n blocksize = out.values().shape[-2:]\n m = mat1.size(-2)\n n = mat2.size(-1)\n k = mat1.size(-1)\n\n # NOTE: (m, 0) @ (0, n) == zeros(m, n)\n if alpha == 0.0 or k == 0:\n out.values().mul_(beta)\n return out\n\n # prepare inputs by reshaping them to be kernel-compatible\n out_backup = out\n crow_indices, col_indices, values, mat1, mat2 = prepare_inputs(out, mat1, mat2)\n\n mat1 = tile_to_blocksize(mat1, (blocksize[0], k))\n mat2 = tile_to_blocksize(mat2, (k, blocksize[1]))\n tile_k = max(*blocksize)\n\n _run_sampled_addmm_kernel(\n alpha, beta, beta == 0.0,\n blocksize, k, tile_k,\n values, crow_indices, col_indices,\n mat1, mat2,\n max_grid\n )\n\n # If nnz x block strides are not the same in out_backup.values and values,\n # it means that out_backup.values and values are not the views of each other,\n # so we have to copy.\n if out_backup.values().stride()[-3:] != values.stride()[-3:]:\n out_backup.values().copy_(values.reshape(out_backup.values().shape))\n return out_backup\n\n def _scaled_dot_product_attention(\n query: torch.Tensor,\n key: torch.Tensor,\n value: torch.Tensor,\n attn_mask: Optional[torch.Tensor],\n dropout_p: float = 0.0,\n is_causal: bool = False,\n scale: Optional[float] = None\n ):\n f_name = \"_scaled_dot_product_attention\"\n check(\n not is_causal,\n f\"{f_name}(): is_causal == True is not supported.\"\n )\n check(\n attn_mask is not None,\n f\"{f_name}(): attn_mask == None is not supported.\"\n )\n assert attn_mask is not None\n\n check(\n attn_mask.layout == torch.sparse_bsr,\n f\"{f_name}(): \"\n f\"attn_mask.layout must be {torch.sparse_bsr}, but got \"\n f\"attn_mask.layout == {attn_mask.layout}.\"\n )\n\n check_device(f_name, key, query.device)\n check_device(f_name, value, query.device)\n check_device(f_name, attn_mask, query.device)\n\n check_dtype(f_name, key, query.dtype)\n check_dtype(f_name, value, query.dtype)\n if attn_mask.dtype is not torch.bool:\n check_dtype(f_name, attn_mask, query.dtype)\n\n sdpa = sampled_addmm(attn_mask, query, key.transpose(-2, -1), beta=0.0, skip_checks=False)\n if scale is None and query.size(-1) == 0 or scale == 0.0:\n check(\n False,\n f\"{f_name}(): current value of scale == {scale} \"\n \"results in division by zero.\"\n )\n scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale\n sdpa.values().mul_(scale_factor)\n sdpa = bsr_softmax(sdpa)\n torch.nn.functional.dropout(sdpa.values(), p=dropout_p, inplace=True)\n sdpa = bsr_dense_mm(sdpa, value)\n return sdpa\n", - "description_1": "Use triton language to implement a sampled matrix multiplication kernel and a scaled dot product attention function. The kernel performs matrix multiplication on sparse matrices using block sparse row (BSR) format, and the attention function applies a scaled dot product attention mechanism using the kernel.", - "description_2": "Use triton language to create a kernel for sampled matrix multiplication with BSR format and implement a scaled dot product attention function utilizing this kernel.", - "difficulty": 4 - }, - { - "code": "import triton\nfrom triton import language as tl\n\n# A simple add kernel which adds two input arrays element-wise.\n@triton.jit\ndef add_kernel(\n in_ptr0, # Pointer to the first input array\n in_ptr1, # Pointer to the second input array\n out_ptr, # Pointer to the output array\n n_elements, # Total number of elements to process\n BLOCK_SIZE: \"tl.constexpr\", # The size of each block of data processed by a single program instance\n):\n pid = tl.program_id(axis=0)\n block_start = pid * BLOCK_SIZE\n offsets = block_start + tl.arange(0, BLOCK_SIZE)\n mask = offsets < n_elements\n x = tl.load(in_ptr0 + offsets, mask=mask)\n y = tl.load(in_ptr1 + offsets, mask=mask)\n output = x + y\n tl.store(out_ptr + offsets, output, mask=mask)\n\n# A kernel with optional parameters to add two arrays.\n@triton.jit\ndef add_kernel_with_optional_param(\n in_ptr0, # Pointer to the first input array\n in_ptr1, # Pointer to the second input array\n out_ptr, # Pointer to the output array\n n_elements, # Total number of elements to process\n ARGS_PASSED: \"tl.constexpr\", # Optional arguments passed as a string, determines computation logic\n BLOCK_SIZE: \"tl.constexpr\", # The size of each block of data processed by a single program instance\n):\n pid = tl.program_id(axis=0)\n block_start = pid * BLOCK_SIZE\n offsets = block_start + tl.arange(0, BLOCK_SIZE)\n mask = offsets < n_elements\n x = tl.load(in_ptr0 + offsets, mask=mask)\n if ARGS_PASSED == \"two\":\n y = tl.load(in_ptr1 + offsets, mask=mask)\n output = x + y\n else:\n output = x\n tl.store(out_ptr + offsets, output, mask=mask)\n\n# Autotuned add kernel for optimized performance.\n@triton.autotune(\n configs=[\n triton.Config({\"BLOCK_SIZE\": 128}, num_stages=3, num_warps=8),\n triton.Config({\"BLOCK_SIZE\": 128}, num_stages=4, num_warps=4),\n triton.Config({\"BLOCK_SIZE\": 64}, num_stages=3, num_warps=8),\n triton.Config({\"BLOCK_SIZE\": 64}, num_stages=4, num_warps=4),\n ],\n key=[],\n)\n@triton.jit\ndef add_kernel_autotuned(\n in_ptr0, # Pointer to the first input array\n in_ptr1, # Pointer to the second input array\n out_ptr, # Pointer to the output array\n n_elements, # Total number of elements to process\n BLOCK_SIZE: \"tl.constexpr\", # The size of each block of data processed by a single program instance\n):\n pid = tl.program_id(axis=0)\n block_start = pid * BLOCK_SIZE\n offsets = block_start + tl.arange(0, BLOCK_SIZE)\n mask = offsets < n_elements\n x = tl.load(in_ptr0 + offsets, mask=mask)\n y = tl.load(in_ptr1 + offsets, mask=mask)\n output = x + y\n tl.store(out_ptr + offsets, output, mask=mask)\n", - "description_1": "Use triton language to implement a series of kernels. These include: 'add_kernel' for element-wise addition of two input arrays, with four inputs (two pointers to input arrays, one pointer for output array, and integer for total elements) and one constant BLOCK_SIZE for block processing. 'add_kernel_with_optional_param' extends the add kernel by allowing optional parameter ARGS_PASSED, which alters computation behavior. It maintains the same parameter set with the addition of the ARGS_PASSED constant. The 'add_kernel_autotuned' kernel is configured for optimal performance, using automatic tuning of execution parameters such as num_stages and num_warps, while keeping the same function signature as 'add_kernel'.", - "description_2": "Use triton language to create a simple element-wise addition kernel 'add_kernel' with 4 inputs and 1 constant. Also, implement 'add_kernel_with_optional_param' that adjusts behavior based on a constant parameter, and 'add_kernel_autotuned' for performance optimization, with the same interface as 'add_kernel'.", - "difficulty": 2 - }, - { - "code": "import triton\nimport triton.language as tl\nimport torch as th\nfrom torch import Tensor\nfrom torch.autograd.function import Function\n\n_kAlpha = math.sqrt(2 / math.pi)\n\n@triton.jit\ndef tanh(x):\n # Tanh is just a scaled sigmoid\n return 2 * tl.sigmoid(2 * x) - 1\n\n@triton.jit\ndef gelu_forward(x):\n \"\"\"\n GeLU_ activation - Gaussian error linear unit\n \"\"\"\n return 0.5 * x * (1 + tanh(_kAlpha * x * (1 + 0.044715 * x * x)))\n\n@triton.jit\ndef gelu_backward(x):\n x2 = x * x\n tanh_ = tanh(_kAlpha * x * (1 + 0.044715 * x2))\n dx = 0.5 * (x * (1 - tanh_ * tanh_) * (0.1070322244089 * x2 + 0.797884560802865) + tanh_ + 1)\n return dx\n\n@triton.jit\ndef geglu_forward_kernel(x_ptr, y_ptr, N, C, C2, BLK_C: tl.constexpr, BLK_N: tl.constexpr):\n pid_n = tl.program_id(0)\n pid_c = tl.program_id(1)\n offs_n = pid_n * BLK_N + tl.arange(0, BLK_N)\n offs_c = pid_c * BLK_C + tl.arange(0, BLK_C)\n mask_n = offs_n < N\n mask_c = offs_c < C2\n mask = mask_n[:, None] & mask_c[None, :]\n\n x_ptrs = x_ptr + offs_n[:, None] * C + offs_c[None, :]\n x1 = tl.load(x_ptrs, mask=mask)\n x2 = tl.load(x_ptrs + C2, mask=mask)\n y = x1 * gelu_forward(x2)\n\n y_ptrs = y_ptr + offs_n[:, None] * C2 + offs_c[None, :]\n tl.store(y_ptrs, y, mask=mask)\n\n@triton.jit\ndef geglu_backward_kernel(x_ptr, dx_ptr, dy_ptr, N, C, C2, BLK_C: tl.constexpr, BLK_N: tl.constexpr):\n pid_n = tl.program_id(0)\n pid_c = tl.program_id(1)\n offs_n = pid_n * BLK_N + tl.arange(0, BLK_N)\n offs_c = pid_c * BLK_C + tl.arange(0, BLK_C)\n mask_n = offs_n < N\n mask_c = offs_c < C2\n mask = mask_n[:, None] & mask_c[None, :]\n\n x_ptrs = x_ptr + offs_n[:, None] * C + offs_c[None, :]\n x1 = tl.load(x_ptrs, mask=mask)\n x2 = tl.load(x_ptrs + C2, mask=mask)\n\n dy_ptrs = dy_ptr + offs_n[:, None] * C2 + offs_c[None, :]\n dy = tl.load(dy_ptrs, mask=mask)\n\n # x * F.gelu(gates)\n dx1 = dy * gelu_forward(x2)\n dx2 = dy * x1\n\n # F.gelu(gates)\n dx2 *= gelu_backward(x2)\n\n dx_ptrs = dx_ptr + offs_n[:, None] * C + offs_c[None, :]\n tl.store(dx_ptrs, dx1, mask=mask)\n tl.store(dx_ptrs + C2, dx2, mask=mask)\n\nclass GEGLUFunction(Function):\n @staticmethod\n def forward(ctx, x: Tensor):\n \"\"\"\n - x: ... c, contiguous\n \"\"\"\n N, C = cummul(*x.shape[:-1]), x.size(-1)\n C2 = C >> 1\n y = x.new_empty(*x.shape[:-1], C2)\n\n BLK_C = max(8, min(1024, triton.next_power_of_2(C2)))\n BLK_N = max(1, 1024 // BLK_C)\n grid = lambda meta: (triton.cdiv(N, meta[\"BLK_N\"]), triton.cdiv(C2, meta[\"BLK_C\"]))\n geglu_forward_kernel[grid](x, y, N, C, C2, BLK_C=BLK_C, BLK_N=BLK_N)\n\n ctx.save_for_backward(x)\n return y\n\n @staticmethod\n def backward(ctx, dy: Tensor):\n \"\"\"\n - dy: ... c // 2, contiguous\n \"\"\"\n (x,) = ctx.saved_tensors # ... c\n N, C = cummul(*x.shape[:-1]), x.size(-1)\n C2 = C >> 1\n dx = th.empty_like(x) # ... c\n\n BLK_C = max(8, min(1024, triton.next_power_of_2(C2)))\n BLK_N = max(1, 1024 // BLK_C)\n grid = lambda meta: (triton.cdiv(N, meta[\"BLK_N\"]), triton.cdiv(C2, meta[\"BLK_C\"]))\n\n geglu_backward_kernel[grid](x, dx, dy, N, C, C2, BLK_C=BLK_C, BLK_N=BLK_N)\n return dx\n\ndef geglu(x: Tensor):\n \"\"\"\n input:\n - x: ... c\n \"\"\"\n C = x.size(-1)\n assert C & 0x01 == 0, x.shape\n\n if not x.is_contiguous():\n x = x.contiguous()\n\n return GEGLUFunction.apply(x)\n\nclass GEGLU(nn.Module):\n def forward(self, x: Tensor):\n return geglu(x)\n", - "description_1": "Use triton language to implement GEGLU activation function with forward and backward passes. The forward pass computes the GEGLU activation using a Triton kernel that processes input tensor x of shape (..., c) where c is even, splitting it into two halves and applying the GeLU activation to the second half. The backward pass computes the gradient of the input tensor using another Triton kernel. The kernels use block sizes BLK_C and BLK_N to divide the computation across a grid of threads.", - "description_2": "Use triton language to create a GEGLU activation function with efficient forward and backward computations using Triton kernels, handling input tensors with even last dimension.", - "difficulty": 3 - }, - { - "code": "import triton\nimport triton.language as tl\nimport torch as th\nfrom torch import Tensor\nfrom torch.autograd.function import Function\n\n_kAlpha = math.sqrt(2 / math.pi)\n\n@triton.jit\ndef tanh(x):\n # Tanh is just a scaled sigmoid\n return 2 * tl.sigmoid(2 * x) - 1\n\n@triton.jit\ndef gelu_forward(x):\n \"\"\"\n GeLU_ activation - Gaussian error linear unit\n \"\"\"\n return 0.5 * x * (1 + tanh(_kAlpha * x * (1 + 0.044715 * x * x)))\n\n@triton.jit\ndef gelu_backward(x):\n x2 = x * x\n tanh_ = tanh(_kAlpha * x * (1 + 0.044715 * x2))\n dx = 0.5 * (x * (1 - tanh_ * tanh_) * (0.1070322244089 * x2 + 0.797884560802865) + tanh_ + 1)\n return dx\n\n@triton.jit\ndef geglu_forward_kernel(x_ptr, y_ptr, N, C, C2, BLK_C: tl.constexpr, BLK_N: tl.constexpr):\n pid_n = tl.program_id(0)\n pid_c = tl.program_id(1)\n offs_n = pid_n * BLK_N + tl.arange(0, BLK_N)\n offs_c = pid_c * BLK_C + tl.arange(0, BLK_C)\n mask_n = offs_n < N\n mask_c = offs_c < C2\n mask = mask_n[:, None] & mask_c[None, :]\n\n x_ptrs = x_ptr + offs_n[:, None] * C + offs_c[None, :]\n x1 = tl.load(x_ptrs, mask=mask)\n x2 = tl.load(x_ptrs + C2, mask=mask)\n y = x1 * gelu_forward(x2)\n\n y_ptrs = y_ptr + offs_n[:, None] * C2 + offs_c[None, :]\n tl.store(y_ptrs, y, mask=mask)\n\n@triton.jit\ndef geglu_backward_kernel(x_ptr, dx_ptr, dy_ptr, N, C, C2, BLK_C: tl.constexpr, BLK_N: tl.constexpr):\n pid_n = tl.program_id(0)\n pid_c = tl.program_id(1)\n offs_n = pid_n * BLK_N + tl.arange(0, BLK_N)\n offs_c = pid_c * BLK_C + tl.arange(0, BLK_C)\n mask_n = offs_n < N\n mask_c = offs_c < C2\n mask = mask_n[:, None] & mask_c[None, :]\n\n x_ptrs = x_ptr + offs_n[:, None] * C + offs_c[None, :]\n x1 = tl.load(x_ptrs, mask=mask)\n x2 = tl.load(x_ptrs + C2, mask=mask)\n\n dy_ptrs = dy_ptr + offs_n[:, None] * C2 + offs_c[None, :]\n dy = tl.load(dy_ptrs, mask=mask)\n\n dx1 = dy * gelu_forward(x2)\n dx2 = dy * x1\n dx2 *= gelu_backward(x2)\n\n dx_ptrs = dx_ptr + offs_n[:, None] * C + offs_c[None, :]\n tl.store(dx_ptrs, dx1, mask=mask)\n tl.store(dx_ptrs + C2, dx2, mask=mask)\n\nclass GEGLUFunction(Function):\n @staticmethod\n def forward(ctx, x: Tensor):\n \"\"\"\n - x: ... c, contiguous\n \"\"\"\n N, C = cummul(*x.shape[:-1]), x.size(-1)\n C2 = C >> 1\n y = x.new_empty(*x.shape[:-1], C2)\n\n BLK_C = max(8, min(1024, triton.next_power_of_2(C2)))\n BLK_N = max(1, 1024 // BLK_C)\n grid = lambda meta: (triton.cdiv(N, meta[\"BLK_N\"]), triton.cdiv(C2, meta[\"BLK_C\"]))\n geglu_forward_kernel[grid](x, y, N, C, C2, BLK_C=BLK_C, BLK_N=BLK_N)\n\n ctx.save_for_backward(x)\n return y\n\n @staticmethod\n def backward(ctx, dy: Tensor):\n \"\"\"\n - dy: ... c // 2, contiguous\n \"\"\"\n (x,) = ctx.saved_tensors # ... c\n N, C = cummul(*x.shape[:-1]), x.size(-1)\n C2 = C >> 1\n dx = th.empty_like(x) # ... c\n\n BLK_C = max(8, min(1024, triton.next_power_of_2(C2)))\n BLK_N = max(1, 1024 // BLK_C)\n grid = lambda meta: (triton.cdiv(N, meta[\"BLK_N\"]), triton.cdiv(C2, meta[\"BLK_C\"]))\n\n geglu_backward_kernel[grid](x, dx, dy, N, C, C2, BLK_C=BLK_C, BLK_N=BLK_N)\n return dx\n\ndef geglu(x: Tensor):\n \"\"\"\n input:\n - x: ... c\n \"\"\"\n C = x.size(-1)\n assert C & 0x01 == 0, x.shape\n\n if not x.is_contiguous():\n x = x.contiguous()\n\n return GEGLUFunction.apply(x)\n\nclass GEGLU(nn.Module):\n def forward(self, x: Tensor):\n return geglu(x)\n", - "description_1": "Use triton language to define and implement a GEGLU activation function with forward and backward pass kernels. The kernels perform operations on tensor pointers for efficient element-wise activation and gradient computation. The forward pass applies GELU activation to the input tensor split in half, while the backward pass computes gradients for both split tensors using chain rule derivatives.", - "description_2": "Use triton language to implement GEGLU activation function with forward/backward kernels for tensor processing.", - "difficulty": 3 - }, - { - "code": "import torch as th\nimport triton\nimport triton.language as tl\nfrom torch import Tensor\n\n@triton.jit\ndef seqlen_to_index_kernel(seqlen_ptr, idx_ptr, BLK: tl.constexpr):\n pid = tl.program_id(0)\n i = tl.load(seqlen_ptr + pid)\n j = tl.load(seqlen_ptr + pid + 1)\n idx = tl.arange(0, BLK)\n tl.store(idx_ptr + i + idx, idx, mask=idx < (j - i))\n\ndef seqlen_to_index(seqlen: Tensor, max_seqlen: int):\n \"\"\"Convert seqlen into index.\"\"\"\n assert seqlen[0].item() == 0\n\n B = seqlen.size(0) - 1\n idx = seqlen.new_empty(seqlen[-1].item(), dtype=th.int64)\n BLK = triton.next_power_of_2(max_seqlen)\n seqlen_to_index_kernel[(B,)](seqlen, idx, BLK)\n return idx\n\n@triton.jit\ndef seqlen_to_batch_index_kernel(seqlen_ptr, idx_ptr, BLK: tl.constexpr):\n pid = tl.program_id(0)\n i = tl.load(seqlen_ptr + pid)\n j = tl.load(seqlen_ptr + pid + 1)\n idx = tl.arange(0, BLK)\n tl.store(idx_ptr + i + idx, pid, mask=idx < (j - i))\n\ndef seqlen_to_batch_index(seqlen: Tensor, max_seqlen: int):\n \"\"\"Convert seqlen into batch index.\"\"\"\n assert seqlen[0].item() == 0\n\n B = seqlen.size(0) - 1\n idx = seqlen.new_empty(seqlen[-1].item(), dtype=th.int64)\n BLK = triton.next_power_of_2(max_seqlen)\n seqlen_to_batch_index_kernel[(B,)](seqlen, idx, BLK)\n return idx\n", - "description_1": "Use triton language to implement two kernels: one that maps sequence lengths to sequential indices and another that maps sequence lengths to batch indices. Each kernel takes pointers to sequence lengths and an index buffer, and a compile-time constant for block size. The caller functions prepare the indices tensor and determine the grid size based on batch size, then launch the corresponding kernel.", - "description_2": "Use triton language to write kernels for converting sequence lengths into indices and batch indices, including necessary calling functions to manage and execute these kernels with grid setup.", - "difficulty": 3 - }, - { - "code": "import triton\nimport torch\n\n# Triton kernel for matrix multiplication\n@triton.jit\ndef matmul_kernel(A, B, C, M, N, K, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr):\n # Triton kernel code for matrix multiplication\n pass\n\n# Function to call the Triton kernel\ndef call_matmul_kernel(A, B, C, M, N, K):\n # Define block sizes\n BLOCK_SIZE_M = 128\n BLOCK_SIZE_N = 128\n BLOCK_SIZE_K = 32\n\n # Launch the Triton kernel\n matmul_kernel[(M, N)](A, B, C, M, N, K, BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K)\n", - "description_1": "Use triton language to implement a matrix multiplication kernel with parameters A, B, C (input matrices), M, N, K (dimensions), and BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K (block sizes). The kernel performs matrix multiplication and is called using the function call_matmul_kernel.", - "description_2": "Use triton language to create a matrix multiplication kernel and a function to call it, with specified input matrices and dimensions.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef matmul_248_kernel(a_ptr, b_ptr, c_ptr, scales_ptr, zeros_ptr, g_ptr, M,\n N, K, bits, maxq, stride_am, stride_ak, stride_bk,\n stride_bn, stride_cm, stride_cn, stride_scales,\n stride_zeros, BLOCK_SIZE_M: tl.constexpr,\n BLOCK_SIZE_N: tl.constexpr,\n BLOCK_SIZE_K: tl.constexpr,\n GROUP_SIZE_M: tl.constexpr):\n \"\"\"\n Compute the matrix multiplication C = A x B.\n A is of shape (M, K) float16\n B is of shape (K//8, N) int32\n C is of shape (M, N) float16\n scales is of shape (G, N) float16\n zeros is of shape (G, N) float16\n g_ptr is of shape (K) int32\n \"\"\"\n infearure_per_bits = 32 // bits\n\n pid = tl.program_id(axis=0)\n num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)\n num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)\n num_pid_k = tl.cdiv(K, BLOCK_SIZE_K)\n num_pid_in_group = GROUP_SIZE_M * num_pid_n\n group_id = pid // num_pid_in_group\n first_pid_m = group_id * GROUP_SIZE_M\n group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)\n pid_m = first_pid_m + (pid % group_size_m)\n pid_n = (pid % num_pid_in_group) // group_size_m\n\n offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n offs_k = tl.arange(0, BLOCK_SIZE_K)\n a_ptrs = a_ptr + (\n offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak\n ) # (BLOCK_SIZE_M, BLOCK_SIZE_K)\n a_mask = (offs_am[:, None] < M)\n # b_ptrs is set up such that it repeats elements along the K axis 8\n # times\n b_ptrs = b_ptr + (\n (offs_k[:, None] // infearure_per_bits) * stride_bk +\n offs_bn[None, :] * stride_bn) # (BLOCK_SIZE_K, BLOCK_SIZE_N)\n g_ptrs = g_ptr + offs_k\n # shifter is used to extract the N bits of each element in the 32-bit\n # word from B\n scales_ptrs = scales_ptr + offs_bn[None, :]\n zeros_ptrs = zeros_ptr + (offs_bn[None, :] // infearure_per_bits)\n\n shifter = (offs_k % infearure_per_bits) * bits\n zeros_shifter = (offs_bn % infearure_per_bits) * bits\n accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n\n for k in range(0, num_pid_k):\n g_idx = tl.load(g_ptrs)\n\n # Fetch scales and zeros; these are per-outfeature and thus reused\n # in the inner loop\n scales = tl.load(scales_ptrs + g_idx[:, None] *\n stride_scales) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)\n zeros = tl.load(\n zeros_ptrs +\n g_idx[:, None] * stride_zeros) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)\n\n zeros = (zeros >> zeros_shifter[None, :]) & maxq\n zeros = (zeros + 1)\n\n a = tl.load(\n a_ptrs, mask=a_mask, other=0.) # (BLOCK_SIZE_M, BLOCK_SIZE_K)\n b = tl.load(b_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated\n\n # Now we need to unpack b (which is N-bit values) into 32-bit\n # values\n b = (b >> shifter[:, None]) & maxq # Extract the N-bit values\n b = (b - zeros) * scales # Scale and shift\n\n accumulator += tl.dot(a, b)\n a_ptrs += BLOCK_SIZE_K\n b_ptrs += (BLOCK_SIZE_K // infearure_per_bits) * stride_bk\n g_ptrs += BLOCK_SIZE_K\n\n c_ptrs = c_ptr + stride_cm * offs_am[:, None] + stride_cn * offs_bn[\n None, :]\n c_mask = (offs_am[:, None] < M) & (offs_bn[None, :] < N)\n tl.store(c_ptrs, accumulator, mask=c_mask)\n\n@triton.jit\ndef transpose_matmul_248_kernel(\n a_ptr, b_ptr, c_ptr, scales_ptr, zeros_ptr, g_ptr, M, N, K, bits,\n maxq, stride_am, stride_ak, stride_bk, stride_bn, stride_cm,\n stride_cn, stride_scales, stride_zeros, BLOCK_SIZE_M: tl.constexpr,\n BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,\n GROUP_SIZE_M: tl.constexpr):\n \"\"\"\n Compute the matrix multiplication C = A x B.\n A is of shape (M, N) float16\n B is of shape (K//8, N) int32\n C is of shape (M, K) float16\n scales is of shape (G, N) float16\n zeros is of shape (G, N) float16\n g_ptr is of shape (K) int32\n \"\"\"\n infearure_per_bits = 32 // bits\n\n pid = tl.program_id(axis=0)\n num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)\n num_pid_k = tl.cdiv(K, BLOCK_SIZE_K)\n num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)\n num_pid_in_group = GROUP_SIZE_M * num_pid_k\n group_id = pid // num_pid_in_group\n first_pid_m = group_id * GROUP_SIZE_M\n group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)\n pid_m = first_pid_m + (pid % group_size_m)\n pid_k = (pid % num_pid_in_group) // group_size_m\n\n offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_bk = pid_k * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)\n offs_n = tl.arange(0, BLOCK_SIZE_N)\n a_ptrs = a_ptr + (\n offs_am[:, None] * stride_am + offs_n[None, :] * stride_ak\n ) # (BLOCK_SIZE_M, BLOCK_SIZE_N)\n a_mask = (offs_am[:, None] < M)\n # b_ptrs is set up such that it repeats elements along the K axis 8\n # times\n b_ptrs = b_ptr + (\n (offs_bk[:, None] // infearure_per_bits) * stride_bk +\n offs_n[None, :] * stride_bn) # (BLOCK_SIZE_K, BLOCK_SIZE_N)\n g_ptrs = g_ptr + offs_bk\n g_idx = tl.load(g_ptrs)\n\n # shifter is used to extract the N bits of each element in the 32-bit\n # word from B\n scales_ptrs = scales_ptr + offs_n[\n None, :] + g_idx[:, None] * stride_scales\n zeros_ptrs = zeros_ptr + (offs_n[None, :] // infearure_per_bits\n ) + g_idx[:, None] * stride_zeros\n\n shifter = (offs_bk % infearure_per_bits) * bits\n zeros_shifter = (offs_n % infearure_per_bits) * bits\n accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_K), dtype=tl.float32)\n\n for n in range(0, num_pid_n):\n # Fetch scales and zeros; these are per-outfeature and thus reused\n # in the inner loop\n scales = tl.load(scales_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)\n zeros = tl.load(zeros_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)\n\n zeros = (zeros >> zeros_shifter[None, :]) & maxq\n zeros = (zeros + 1)\n\n a = tl.load(\n a_ptrs, mask=a_mask, other=0.) # (BLOCK_SIZE_M, BLOCK_SIZE_N)\n b = tl.load(b_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated\n\n # Now we need to unpack b (which is N-bit values) into 32-bit\n # values\n b = (b >> shifter[:, None]) & maxq # Extract the N-bit values\n b = (b - zeros) * scales # Scale and shift\n b = tl.trans(b)\n\n accumulator += tl.dot(a, b)\n a_ptrs += BLOCK_SIZE_N\n b_ptrs += BLOCK_SIZE_N\n scales_ptrs += BLOCK_SIZE_N\n zeros_ptrs += (BLOCK_SIZE_N // infearure_per_bits)\n\n c_ptrs = c_ptr + stride_cm * offs_am[:, None] + stride_cn * offs_bk[\n None, :]\n c_mask = (offs_am[:, None] < M) & (offs_bk[None, :] < K)\n tl.store(c_ptrs, accumulator, mask=c_mask)\n\ndef matmul248(input, qweight, scales, qzeros, g_idx, bits, maxq):\n \"\"\"matmul248 function with matmul_248_kernel.\"\"\"\n with torch.cuda.device(input.device):\n output = torch.empty((input.shape[0], qweight.shape[1]),\n device=input.device,\n dtype=torch.float16)\n grid = lambda META: ( # noqa: E731\n triton.cdiv( # noqa: E731\n input.shape[0], META['BLOCK_SIZE_M']) * triton. # noqa: E731\n cdiv( # noqa: E731\n qweight.shape[1], META['BLOCK_SIZE_N']), ) # noqa: E731\n matmul_248_kernel[grid](input, qweight, output, scales, qzeros, g_idx,\n input.shape[0], qweight.shape[1],\n input.shape[1], bits, maxq, input.stride(0),\n input.stride(1), qweight.stride(0),\n qweight.stride(1), output.stride(0),\n output.stride(1), scales.stride(0),\n qzeros.stride(0))\n return output\n\ndef transpose_matmul248(input, qweight, scales, qzeros, g_idx, bits, maxq):\n \"\"\"transpose_matmul248 function with transpose_matmul_248_kernel.\"\"\"\n with torch.cuda.device(input.device):\n output_dim = (qweight.shape[0] * 32) // bits\n output = torch.empty((input.shape[0], output_dim),\n device=input.device,\n dtype=torch.float16)\n grid = lambda META: ( # noqa: E731\n triton.cdiv(input.shape[0], META['BLOCK_SIZE_M']) # noqa: E731\n * triton.cdiv(output_dim, META['BLOCK_SIZE_K']), ) # noqa: E731\n transpose_matmul_248_kernel[grid](input, qweight, output, scales,\n qzeros, g_idx, input.shape[0],\n qweight.shape[1], output_dim,\n bits, maxq, input.stride(0),\n input.stride(1), qweight.stride(0),\n qweight.stride(1), output.stride(0),\n output.stride(1), scales.stride(0),\n qzeros.stride(0))\n return output\n", - "description_1": "Use triton language to implement two matrix multiplication kernels: 'matmul_248_kernel' and 'transpose_matmul_248_kernel'. The 'matmul_248_kernel' computes the matrix multiplication C = A x B where A is a float16 matrix of shape (M, K), B is an int32 matrix of shape (K//8, N), and C is a float16 matrix of shape (M, N). It uses additional parameters for scaling and zero-point adjustments. The 'transpose_matmul_248_kernel' performs a similar operation but computes C = A x B where A is a float16 matrix of shape (M, N) and C is a float16 matrix of shape (M, K). Both kernels are optimized for specific block sizes and group sizes, and they handle bit-level operations for quantized matrices.", - "description_2": "Use triton language to create optimized matrix multiplication kernels for quantized matrices, handling bit-level operations and using block and group sizes for performance.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n# Kernel to compute bias\n@triton.jit\ndef bias_kernel(out, weights, stride_om, stride_on, stride_wn,\n N: tl.constexpr, M: tl.constexpr, NH: tl.constexpr, \n BLOCK_N: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_NH: tl.constexpr, \n BIDIRECTIONAL: tl.constexpr, NUM_BUCKETS: tl.constexpr, \n MAX_DISTANCE: tl.constexpr, GROUP_SIZE_M: tl.constexpr):\n \n pid = tl.program_id(axis=0)\n num_pid_m = tl.cdiv(M, BLOCK_M)\n num_pid_n = tl.cdiv(N, BLOCK_N)\n num_pid_in_group = GROUP_SIZE_M * num_pid_n\n group_id = pid // num_pid_in_group\n first_pid_m = group_id * GROUP_SIZE_M\n group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)\n pid_m = first_pid_m + (pid % group_size_m)\n pid_n = (pid % num_pid_in_group) // group_size_m\n offs_m = (pid_m * BLOCK_M + tl.arange(0, BLOCK_M)) % M\n offs_n = (pid_n * BLOCK_N + tl.arange(0, BLOCK_N)) % N\n\n # Compute relative positions\n relative_positions = offs_n[None, :]-offs_m[:, None]\n\n # Compute bucket indices based on relative positions\n relative_buckets = tl.zeros_like(relative_positions)\n num_buckets = NUM_BUCKETS\n if BIDIRECTIONAL:\n num_buckets //= 2\n relative_buckets += (relative_positions > 0).to(tl.int32) * num_buckets\n relative_positions = tl.abs(relative_positions)\n else:\n relative_positions = tl.maximum(-relative_positions, tl.zeros_like(relative_positions))\n\n # Half of the buckets are for exact increments in positions\n max_exact = num_buckets // 2\n is_small = relative_positions < max_exact\n\n # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance\n relative_position_if_large = max_exact + (\n tl.log(relative_positions.to(tl.float32) / max_exact)\n / tl.log(MAX_DISTANCE / max_exact)\n * (num_buckets - max_exact)\n ).to(tl.int32)\n relative_position_if_large = tl.minimum(relative_position_if_large, num_buckets - 1)\n\n relative_buckets += tl.where(is_small, relative_positions, relative_position_if_large)\n\n for i in range(0, NH, BLOCK_NH):\n offs_nh = i + tl.arange(0, BLOCK_NH)\n bucket_offs = relative_buckets[:, :, None] * stride_wn + offs_nh[None, None, :]\n\n # Retrieve bias values from weights tensor\n bias_ptrs = weights + bucket_offs # (BLOCK_M, BLOCK_N, BLOCK_NH)\n bias_values = tl.load(bias_ptrs)\n\n out_offs = (offs_m[:, None] * stride_om + offs_n[None, :] * stride_on)[:, :, None] + offs_nh[None, None, :]\n out_ptrs = out + out_offs\n\n o_mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)[:, :, None] & (offs_nh[None, None, :] < NH)\n\n # Store bias values in the output tensor\n tl.store(out_ptrs, bias_values, mask=o_mask)\n\n# Kernel to compute bias gradient\n@triton.jit\ndef bias_kernel_backward(\n d_weights, d_out, weights, stride_om, stride_on, stride_wn,\n N: tl.constexpr, M: tl.constexpr, NH: tl.constexpr, \n BLOCK_N: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_NH: tl.constexpr, \n BIDIRECTIONAL: tl.constexpr, NUM_BUCKETS: tl.constexpr, \n MAX_DISTANCE: tl.constexpr, GROUP_SIZE_M: tl.constexpr\n):\n pid = tl.program_id(axis=0)\n num_pid_m = tl.cdiv(M, BLOCK_M)\n num_pid_n = tl.cdiv(N, BLOCK_N)\n num_pid_in_group = GROUP_SIZE_M * num_pid_n\n group_id = pid // num_pid_in_group\n first_pid_m = group_id * GROUP_SIZE_M\n group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)\n pid_m = first_pid_m + (pid % group_size_m)\n pid_n = (pid % num_pid_in_group) // group_size_m\n offs_m = (pid_m * BLOCK_M + tl.arange(0, BLOCK_M)) % M\n offs_n = (pid_n * BLOCK_N + tl.arange(0, BLOCK_N)) % N\n\n relative_positions = offs_m[:, None] - offs_n[None, :]\n\n relative_buckets = tl.zeros_like(relative_positions)\n num_buckets = NUM_BUCKETS\n if BIDIRECTIONAL:\n num_buckets //= 2\n relative_buckets += (relative_positions > 0).to(tl.int32) * num_buckets\n relative_positions = tl.abs(relative_positions)\n else:\n relative_positions = tl.maximum(-relative_positions, tl.zeros_like(relative_positions))\n\n # Half of the buckets are for exact increments in positions\n max_exact = num_buckets // 2\n is_small = relative_positions < max_exact\n\n # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance\n relative_position_if_large = max_exact + (\n tl.log(relative_positions.to(tl.float32) / max_exact)\n / tl.log(MAX_DISTANCE / max_exact)\n * (num_buckets - max_exact)\n ).to(tl.int32)\n relative_position_if_large = tl.minimum(relative_position_if_large, num_buckets - 1)\n\n relative_buckets += tl.where(is_small, relative_positions, relative_position_if_large)\n\n for i in range(0, NH, BLOCK_NH):\n offs_nh = i + tl.arange(0, BLOCK_NH)\n bucket_offs = relative_buckets[:, :, None] * stride_wn + offs_nh[None, None, :]\n\n d_out_ptrs = d_out + (offs_m[:, None] * stride_om + offs_n[None, :] * stride_on)[:, :, None] + offs_nh[None, None, :]\n o_mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)[:, :, None] & (offs_nh[None, None, :] < NH)\n d_out_values = tl.load(d_out_ptrs, mask=o_mask, other=0.0)\n\n d_weights_ptrs = d_weights + bucket_offs\n tl.atomic_add(d_weights_ptrs, d_out_values, mask=relative_buckets[:, :, None] < NUM_BUCKETS)\n\n# Function to compute bias using Triton kernels\nclass BiasOp(torch.autograd.Function):\n @staticmethod\n def forward(ctx, weights, M, N, NH, BIDIRECTIONAL, NUM_BUCKETS, MAX_DISTANCE, dtype=torch.float16):\n ctx.save_for_backward(weights)\n ctx.M, ctx.N, ctx.NH = M, N, NH\n ctx.BIDIRECTIONAL, ctx.NUM_BUCKETS, ctx.MAX_DISTANCE = BIDIRECTIONAL, NUM_BUCKETS, MAX_DISTANCE\n ctx.dtype = dtype\n\n out = torch.empty((M, N, NH), device=weights.device, dtype=dtype)\n # Config\n BLOCK_SIZE_N = 32\n BLOCK_SIZE_M = 32\n BLOCK_SIZE_H = 16\n\n # Launch forward kernel\n grid = (triton.cdiv(N, BLOCK_SIZE_N) * triton.cdiv(M, BLOCK_SIZE_M),)\n bias_kernel[grid](\n out,\n weights,\n out.stride(0), out.stride(1), weights.stride(0),\n N, M, NH,\n BLOCK_SIZE_N, BLOCK_SIZE_M, BLOCK_SIZE_H,\n BIDIRECTIONAL, NUM_BUCKETS, MAX_DISTANCE, out.stride(1)\n )\n\n return out\n\n @staticmethod\n def backward(ctx, d_out):\n weights, = ctx.saved_tensors\n M, N, NH = ctx.M, ctx.N, ctx.NH\n BIDIRECTIONAL, NUM_BUCKETS, MAX_DISTANCE = ctx.BIDIRECTIONAL, ctx.NUM_BUCKETS, ctx.MAX_DISTANCE\n dtype = ctx.dtype\n\n d_weights = torch.zeros_like(weights)\n\n # Config\n BLOCK_SIZE_N = 32\n BLOCK_SIZE_M = 32\n BLOCK_SIZE_H = 16\n\n # Launch backward kernel\n grid = (triton.cdiv(N, BLOCK_SIZE_N) * triton.cdiv(M, BLOCK_SIZE_M),)\n bias_kernel_backward[grid](\n d_weights,\n d_out,\n weights,\n d_out.stride(0), d_out.stride(1), weights.stride(0),\n N, M, NH,\n BLOCK_SIZE_N, BLOCK_SIZE_M, BLOCK_SIZE_H,\n BIDIRECTIONAL, NUM_BUCKETS, MAX_DISTANCE, d_out.stride(1)\n )\n\n return d_weights, None, None, None, None, None, None, None\n\ndef triton_compute_bias(weights, M, N, NH, BIDIRECTIONAL, NUM_BUCKETS, MAX_DISTANCE, dtype=torch.float16):\n # Check constraints\n assert weights.shape == (NUM_BUCKETS, NH), \"Incorrect shape of weights tensor\"\n assert weights.is_contiguous(), \"Weights tensor must be contiguous\"\n assert N > 0 and M > 0 and NH > 0, \"Invalid dimensions\"\n assert BIDIRECTIONAL in [True, False], \"BIDIRECTIONAL must be a boolean\"\n assert NUM_BUCKETS > 0, \"NUM_BUCKETS must be positive\"\n assert MAX_DISTANCE > 0, \"MAX_DISTANCE must be positive\"\n return BiasOp.apply(weights, M, N, NH, BIDIRECTIONAL, NUM_BUCKETS, MAX_DISTANCE, dtype)\n", - "description_1": "Use triton language to implement a bias computation kernel and its backward pass. The forward kernel 'bias_kernel' takes 14 parameters: out (output tensor), weights (weights tensor), stride_om, stride_on, stride_wn (stride values), N, M, NH (dimensions), BLOCK_N, BLOCK_M, BLOCK_NH (block sizes), BIDIRECTIONAL (boolean for directionality), NUM_BUCKETS (number of buckets), MAX_DISTANCE (maximum distance), and GROUP_SIZE_M (group size for M dimension). The backward kernel 'bias_kernel_backward' takes similar parameters with an additional d_weights (gradient of weights) and d_out (gradient of output). The 'BiasOp' class encapsulates the forward and backward operations using these kernels, and 'triton_compute_bias' is a helper function to apply the 'BiasOp' with input validation.", - "description_2": "Use triton language to create a forward and backward kernel for bias computation with configurable parameters, and integrate them into a PyTorch autograd function for easy use.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef _fwd_kernel(\n Q, K, V, B, sm_scale,\n L, O,\n stride_qz, stride_qh, stride_qm, stride_qk,\n stride_kz, stride_kh, stride_kn, stride_kk,\n stride_vz, stride_vh, stride_vn, stride_vk,\n stride_oz, stride_oh, stride_om, stride_ok,\n stride_bz, stride_bh, stride_bm, stride_bn,\n Z, H, M, N, P_SEQ,\n BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr,\n IS_CAUSAL: tl.constexpr, LARGER_M: tl.constexpr,\n DIVISIBLE_M: tl.constexpr, DIVISIBLE_N: tl.constexpr,\n HAS_BIAS: tl.constexpr,\n):\n # Triton kernel for forward pass of FlashAttention\n input_dtype = Q.dtype.element_ty\n start_m = tl.program_id(0)\n off_h = tl.program_id(1)\n off_z = tl.program_id(2)\n\n log2e: tl.constexpr = 1.4426950408889634\n\n Q += off_z * stride_qz + off_h * stride_qh\n K += off_z * stride_kz + off_h * stride_kh\n V += off_z * stride_vz + off_h * stride_vh\n O += off_z * stride_oz + off_h * stride_oh\n if HAS_BIAS:\n B += off_z * stride_bz + off_h * stride_bh\n L += (off_z * H + off_h) * M\n\n offs_m_base = tl.arange(0, BLOCK_M)\n offs_m = start_m * BLOCK_M + offs_m_base\n offs_n_base = tl.arange(0, BLOCK_N)\n offs_k = tl.arange(0, BLOCK_DMODEL)\n\n q_ptrs = Q + (offs_m[:, None] * stride_qm + offs_k[None, :] * stride_qk)\n o_ptrs = O + (offs_m[:, None] * stride_om + offs_k[None, :] * stride_ok)\n l_ptrs = L + offs_m\n\n m_i = tl.full([BLOCK_M], value=-float(\"inf\"), dtype=tl.float32)\n l_i = tl.zeros([BLOCK_M], dtype=tl.float32)\n acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)\n\n mask_m = offs_m < M\n if DIVISIBLE_M:\n q = tl.load(q_ptrs, cache_modifier=\".cg\")\n else:\n q = tl.load(q_ptrs, mask=mask_m[:, None], cache_modifier=\".cg\")\n\n if BLOCK_DMODEL < 128:\n I = tl.where(offs_k[:, None] == offs_k,\n tl.full((BLOCK_DMODEL, BLOCK_DMODEL), 1.0, dtype=input_dtype),\n tl.full((BLOCK_DMODEL, BLOCK_DMODEL), 0.0, dtype=input_dtype))\n q = tl.dot(q, I).to(input_dtype)\n\n if IS_CAUSAL:\n hi = tl.minimum(N, P_SEQ + (start_m + 1) * BLOCK_M)\n if LARGER_M:\n hi = tl.maximum(0, hi)\n else:\n hi = N\n\n offs_n_init = offs_n_base\n k_ptrs = K + (offs_k[:, None] * stride_vk + offs_n_init[None, :] * stride_vn)\n v_ptrs = V + (offs_n_init[:, None] * stride_kn + offs_k[None, :] * stride_kk)\n if HAS_BIAS:\n bias_ptrs = B + (offs_m[:, None] * stride_bm + offs_n_init[None, :] * stride_bn)\n\n for start_n in range(0, hi, BLOCK_N):\n start_n = tl.multiple_of(start_n, BLOCK_N)\n offs_n = start_n + offs_n_base\n\n mask_n = offs_n < N\n if DIVISIBLE_N:\n k = tl.load(k_ptrs, cache_modifier=\".cg\")\n v = tl.load(v_ptrs, cache_modifier=\".cg\")\n else:\n k = tl.load(k_ptrs, mask=mask_n[None, :], cache_modifier=\".cg\")\n v = tl.load(v_ptrs, mask=mask_n[:, None], cache_modifier=\".cg\")\n\n if HAS_BIAS:\n if DIVISIBLE_M and DIVISIBLE_N:\n b = tl.load(bias_ptrs)\n else:\n b = tl.load(bias_ptrs, mask_m[:, None] & mask_n[None, :])\n\n s = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)\n s += tl.dot(q, k) * sm_scale\n if HAS_BIAS:\n s += b\n\n if not DIVISIBLE_N:\n s = tl.where(mask_n[None, :], s, float(\"-inf\"))\n if IS_CAUSAL:\n causal_mask = (P_SEQ + offs_m[:, None]) >= offs_n[None, :]\n s = tl.where(causal_mask, s, float(\"-inf\"))\n\n m_i_new = tl.maximum(m_i, tl.max(s, 1))\n alpha = tl.math.exp2((m_i - m_i_new)*log2e)\n p = tl.math.exp2((s - m_i_new[:, None])*log2e)\n\n acc *= alpha[:, None]\n acc += tl.dot(p.to(input_dtype), v)\n\n l_i = l_i * alpha + tl.sum(p, 1)\n m_i = m_i_new\n\n k_ptrs += BLOCK_N * stride_kn\n v_ptrs += BLOCK_N * stride_vn\n if HAS_BIAS:\n bias_ptrs += BLOCK_N * stride_bn\n\n if IS_CAUSAL and LARGER_M:\n is_empty_line = (offs_m + P_SEQ) < 0\n acc = tl.where(is_empty_line[:, None], 0.0, acc * (1.0 / l_i[:, None]))\n l = tl.where(is_empty_line, float(\"-inf\"), m_i + tl.log(l_i))\n else:\n acc = acc * (1.0 / l_i[:, None])\n l = m_i + tl.log(l_i)\n\n if DIVISIBLE_M:\n tl.store(l_ptrs, l, cache_modifier=\".cg\")\n tl.store(o_ptrs, acc.to(input_dtype), cache_modifier=\".cg\")\n else:\n tl.store(l_ptrs, l, mask=mask_m, cache_modifier=\".cg\")\n tl.store(o_ptrs, acc.to(input_dtype), mask=mask_m[:, None], cache_modifier=\".cg\")\n\ndef flash_attn_v2_fwd(q, k, v, bias, causal, sm_scale, BLOCK_M, BLOCK_N, num_warps, num_stages):\n B, H, M, D = q.shape\n N = k.shape[2]\n P_SEQ = N - M\n larger_m = M > N\n\n bias_batch_stride = bias.stride(0) if bias is not None else 0\n bias_heads_stride = bias.stride(1) if bias is not None else 0\n if bias is not None:\n if (bias.shape[0] != q.shape[0]) and (bias.shape[0] == 1):\n bias_batch_stride = 0\n if (bias.shape[1] != q.shape[1]) and (bias.shape[1] == 1):\n bias_heads_stride = 0\n\n divisible_m = M % BLOCK_M == 0\n divisible_n = N % BLOCK_N == 0\n grid = (triton.cdiv(M, BLOCK_M), H, B)\n o = torch.empty_like(q)\n L = torch.empty((B, H, M), device=q.device, dtype=torch.float32)\n\n with torch.cuda.device(q.device.index):\n _fwd_kernel[grid](\n q, k, v, bias, sm_scale,\n L, o,\n q.stride(0), q.stride(1), q.stride(2), q.stride(3),\n k.stride(0), k.stride(1), k.stride(2), k.stride(3),\n v.stride(0), v.stride(1), v.stride(2), v.stride(3),\n o.stride(0), o.stride(1), o.stride(2), o.stride(3),\n bias_batch_stride, bias_heads_stride,\n bias.stride(2) if bias is not None else 0,\n bias.stride(3) if bias is not None else 0,\n B, H, M, N, P_SEQ,\n BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_DMODEL=D,\n IS_CAUSAL=causal, LARGER_M=larger_m,\n DIVISIBLE_M=divisible_m, DIVISIBLE_N=divisible_n,\n HAS_BIAS=(bias is not None),\n num_warps=num_warps, num_stages=num_stages,\n )\n\n return o, L\n", - "description_1": "Use triton language to implement a forward kernel for FlashAttention. The kernel takes 28 parameters: Q, K, V, B, sm_scale, L, O, stride_qz, stride_qh, stride_qm, stride_qk, stride_kz, stride_kh, stride_kn, stride_kk, stride_vz, stride_vh, stride_vn, stride_vk, stride_oz, stride_oh, stride_om, stride_ok, stride_bz, stride_bh, stride_bm, stride_bn, Z, H, M, N, P_SEQ, BLOCK_M, BLOCK_DMODEL, BLOCK_N, IS_CAUSAL, LARGER_M, DIVISIBLE_M, DIVISIBLE_N, HAS_BIAS. It computes the attention output O and log-sum-exp L for a given set of queries Q, keys K, values V, and optional bias B, using a specified scaling factor sm_scale. The kernel supports causal masking and handles different block sizes and divisibility conditions.", - "description_2": "Use triton language to implement a function flash_attn_v2_fwd that calls the forward kernel _fwd_kernel. The function takes 10 parameters: q, k, v, bias, causal, sm_scale, BLOCK_M, BLOCK_N, num_warps, num_stages. It prepares the input tensors and grid configuration for the kernel, handles optional bias strides, and manages device context. The function returns the attention output and log-sum-exp tensors.", - "difficulty": 5 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef _fwd_kernel(\n Q, K, V, B, sm_scale,\n L, ml, O,\n stride_qz, stride_qh, stride_qm, stride_qk,\n stride_kz, stride_kh, stride_kn, stride_kk,\n stride_vz, stride_vh, stride_vn, stride_vk,\n stride_oz, stride_oh, stride_os, stride_om,\n stride_lz, stride_lh, stride_ls, stride_lm,\n stride_bz, stride_bh, stride_bm, stride_bn,\n Z, H, M, N, P_SEQ,\n BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr,\n DIVISIBLE_N: tl.constexpr,\n HAS_BIAS: tl.constexpr, NUM_SPLITS:tl.constexpr\n):\n input_dtype = Q.dtype.element_ty\n # -- grid id --\n off_s = tl.program_id(0)\n off_h = tl.program_id(1)\n off_z = tl.program_id(2)\n\n n_per_split = N//NUM_SPLITS\n split_n_start = off_s*n_per_split\n split_n_end = N if off_s+1 == NUM_SPLITS else split_n_start+n_per_split\n\n log2e: tl.constexpr = 1.4426950408889634\n\n Q += off_z * stride_qz + off_h * stride_qh\n K += off_z * stride_kz + off_h * stride_kh\n V += off_z * stride_vz + off_h * stride_vh\n O += off_z * stride_oz + off_h * stride_oh + off_s*stride_os\n if HAS_BIAS:\n B += off_z * stride_bz + off_h * stride_bh\n L += off_z * stride_lz + off_h * stride_lh + off_s*stride_ls \n ml += off_z * stride_lz + off_h * stride_lh + off_s*stride_ls \n\n offs_m = tl.arange(0, BLOCK_M)\n offs_n_base = tl.arange(0, BLOCK_N)\n offs_k = tl.arange(0, BLOCK_DMODEL)\n\n q_ptrs = Q + (offs_m[:, None] * stride_qm + offs_k[None, :] * stride_qk) \n o_ptrs = O + (offs_m[:, None] * stride_om + offs_k[None, :]) \n l_ptrs = L + offs_m\n ml_ptrs = ml + offs_m\n\n m_i = tl.full([BLOCK_M], value=-float(\"inf\"), dtype=tl.float32)\n l_i = tl.zeros([BLOCK_M], dtype=tl.float32)\n acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)\n\n mask_m = offs_m < M\n q = tl.load(q_ptrs, cache_modifier=\".cg\")\n\n if BLOCK_DMODEL < 128:\n I = tl.where(offs_k[:, None] == offs_k,\n tl.full((BLOCK_DMODEL, BLOCK_DMODEL), 1.0, dtype=input_dtype),\n tl.full((BLOCK_DMODEL, BLOCK_DMODEL), 0.0, dtype=input_dtype))\n q = tl.dot(q, I).to(input_dtype)\n \n offs_n_init = offs_n_base+split_n_start\n k_ptrs = K + (offs_k[:, None] * stride_vk + offs_n_init[None, :] * stride_vn) \n v_ptrs = V + (offs_n_init[:, None] * stride_kn + offs_k[None, :] * stride_kk) \n if HAS_BIAS:\n bias_ptrs = B + (offs_m[:, None] * stride_bm + offs_n_init[None, :] * stride_bn) \n\n for start_n in range(split_n_start, split_n_end, BLOCK_N):\n start_n = tl.multiple_of(start_n, BLOCK_N)\n offs_n = start_n + offs_n_base\n\n mask_n = offs_n < N\n if DIVISIBLE_N:\n k = tl.load(k_ptrs, cache_modifier=\".cg\")\n v = tl.load(v_ptrs, cache_modifier=\".cg\")\n else:\n k = tl.load(k_ptrs, mask=mask_n[None, :], cache_modifier=\".cg\")\n v = tl.load(v_ptrs, mask=mask_n[:, None], cache_modifier=\".cg\")\n\n if HAS_BIAS:\n b = tl.load(bias_ptrs, mask_m[:, None] & mask_n[None, :])\n\n s = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)\n s += tl.dot(q, k) * sm_scale\n if HAS_BIAS:\n s += b\n\n if not DIVISIBLE_N:\n s = tl.where(mask_n[None, :], s, float(\"-inf\"))\n\n m_i_new = tl.maximum(m_i, tl.max(s, 1))\n alpha = tl.math.exp2((m_i - m_i_new)*log2e)\n p = tl.math.exp2((s - m_i_new[:, None])*log2e)\n\n acc *= alpha[:, None]\n acc += tl.dot(p.to(input_dtype), v)\n\n l_i = l_i * alpha + tl.sum(p, 1)\n m_i = m_i_new\n\n k_ptrs += BLOCK_N * stride_kn\n v_ptrs += BLOCK_N * stride_vn\n if HAS_BIAS:\n bias_ptrs += BLOCK_N * stride_bn\n\n acc = acc * (1.0 / l_i[:, None])\n l = l_i\n tl.store(l_ptrs, l, mask=mask_m, cache_modifier=\".cg\")\n tl.store(ml_ptrs, m_i, mask=mask_m, cache_modifier=\".cg\")\n\n tl.store(o_ptrs, acc.to(input_dtype), mask=mask_m[:, None], cache_modifier=\".cg\")\n\ndef flash_attn_v2_fwd(q, k, v, bias, causal, sm_scale, NUM_SPLITS, BLOCK_M, BLOCK_N, num_warps, num_stages):\n B, H, M, D = q.shape\n N = k.shape[2]\n P_SEQ = N - M\n\n bias_batch_stride = bias.stride(0) if bias is not None else 0\n bias_heads_stride = bias.stride(1) if bias is not None else 0\n if bias is not None:\n if (bias.shape[0] != q.shape[0]) and (bias.shape[0] == 1):\n bias_batch_stride = 0\n if (bias.shape[1] != q.shape[1]) and (bias.shape[1] == 1):\n bias_heads_stride = 0\n\n divisible_n = N % BLOCK_N == 0\n grid = (NUM_SPLITS, H, B)\n o = torch.empty_like(q)\n L = torch.zeros((B, H, NUM_SPLITS, M), device=q.device, dtype=torch.float32)\n ml = torch.zeros((B, H, NUM_SPLITS, M), device=q.device, dtype=torch.float32)\n so = torch.empty((B, H, NUM_SPLITS, M, D), device=q.device, dtype=torch.float32)\n with torch.cuda.device(q.device.index):\n _fwd_kernel[grid](\n q, k, v, bias, sm_scale,\n L, ml, so,\n q.stride(0), q.stride(1), q.stride(2), q.stride(3),\n k.stride(0), k.stride(1), k.stride(2), k.stride(3),\n v.stride(0), v.stride(1), v.stride(2), v.stride(3),\n so.stride(0), so.stride(1), so.stride(2), so.stride(3),\n L.stride(0), L.stride(1), L.stride(2), L.stride(3),\n bias_batch_stride, bias_heads_stride,\n bias.stride(2) if bias is not None else 0,\n bias.stride(3) if bias is not None else 0,\n B, H, M, N, P_SEQ,\n BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_DMODEL=D,\n DIVISIBLE_N=divisible_n,\n HAS_BIAS=(bias is not None),\n NUM_SPLITS = NUM_SPLITS,\n num_warps=num_warps, num_stages=num_stages,\n )\n ml = ml.squeeze(-1)\n L = L.squeeze(-1)\n so = so.squeeze(-2)\n a_max = torch.max(ml, dim=-1, keepdim=True).values\n alpha = torch.exp(ml-a_max)\n max_log_scores_ = torch.log(alpha*L)\n weights = torch.softmax(max_log_scores_, dim=-1)\n res = torch.sum(weights.unsqueeze(-1) * so, dim=-2, keepdim=True)\n return res, L, ml\n", - "description_1": "Use triton language to implement a forward kernel (_fwd_kernel) for flash attention mechanism, which takes 36 parameters: Q, K, V, B for input tensors, sm_scale for scaling factor, L, ml, O for output tensors, strides for Q, K, V, O, L, B as inputs for tensor shapes, Z, H, M, N, P_SEQ for dimensions, BLOCK_M, BLOCK_DMODEL, BLOCK_N as constexpr for blocking, DIVISIBLE_N, HAS_BIAS, NUM_SPLITS as constexpr for configuration, and updates attention scores and accumulated values. The kernel is invoked by flash_attn_v2_fwd function, which takes 10 parameters: q, k, v, bias, causal, sm_scale, NUM_SPLITS, BLOCK_M, BLOCK_N, num_warps, num_stages for processing tensor shapes, stride values, and bias conditions, returning computed results.", - "description_2": "Use triton language to implement a forward kernel (_fwd_kernel) for a flash attention mechanism and invoke it using a wrapper function flash_attn_v2_fwd with necessary tensor parameters and configuration values.", - "difficulty": 4 - }, - { - "code": "import math\nimport torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef _fwd_kernel_with_bias_calculation(\n Q, K, V, BW, sm_scale,\n L, O,\n stride_qz, stride_qh, stride_qm, stride_qk,\n stride_kz, stride_kh, stride_kn, stride_kk,\n stride_vz, stride_vh, stride_vn, stride_vk,\n stride_oz, stride_oh, stride_om, stride_ok,\n stride_wn,\n Z, H, M, N, P_SEQ,\n BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr,\n IS_CAUSAL: tl.constexpr, LARGER_M: tl.constexpr,\n DIVISIBLE_M: tl.constexpr, DIVISIBLE_N: tl.constexpr,\n HAS_BIAS: tl.constexpr, NUM_BUCKETS: tl.constexpr, MAX_DISTANCE: tl.constexpr\n):\n input_dtype = Q.dtype.element_ty\n start_m = tl.program_id(0)\n off_h = tl.program_id(1)\n off_z = tl.program_id(2)\n\n log2e: tl.constexpr = 1.4426950408889634\n\n Q += off_z * stride_qz + off_h * stride_qh\n K += off_z * stride_kz + off_h * stride_kh\n V += off_z * stride_vz + off_h * stride_vh\n O += off_z * stride_oz + off_h * stride_oh\n\n L += (off_z * H + off_h) * M\n\n offs_m_base = tl.arange(0, BLOCK_M)\n offs_m = start_m * BLOCK_M + offs_m_base\n offs_n_base = tl.arange(0, BLOCK_N)\n offs_k = tl.arange(0, BLOCK_DMODEL)\n\n q_ptrs = Q + (offs_m[:, None] * stride_qm + offs_k[None, :] * stride_qk)\n o_ptrs = O + (offs_m[:, None] * stride_om + offs_k[None, :] * stride_ok)\n l_ptrs = L + offs_m\n\n m_i = tl.full([BLOCK_M], value=-float(\"inf\"), dtype=tl.float32)\n l_i = tl.zeros([BLOCK_M], dtype=tl.float32)\n acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)\n\n mask_m = offs_m < M\n if DIVISIBLE_M:\n q = tl.load(q_ptrs, cache_modifier=\".cg\")\n else:\n q = tl.load(q_ptrs, mask=mask_m[:, None], cache_modifier=\".cg\")\n\n if BLOCK_DMODEL < 128:\n I = tl.where(offs_k[:, None] == offs_k,\n tl.full((BLOCK_DMODEL, BLOCK_DMODEL), 1.0, dtype=input_dtype),\n tl.full((BLOCK_DMODEL, BLOCK_DMODEL), 0.0, dtype=input_dtype))\n q = tl.dot(q, I).to(input_dtype)\n\n if IS_CAUSAL:\n hi = tl.minimum(N, P_SEQ + (start_m + 1) * BLOCK_M)\n if LARGER_M:\n hi = tl.maximum(0, hi)\n else:\n hi = N\n\n offs_n_init = offs_n_base\n k_ptrs = K + (offs_k[:, None] * stride_vk + offs_n_init[None, :] * stride_vn)\n v_ptrs = V + (offs_n_init[:, None] * stride_kn + offs_k[None, :] * stride_kk)\n\n for start_n in range(0, hi, BLOCK_N):\n start_n = tl.multiple_of(start_n, BLOCK_N)\n offs_n = start_n + offs_n_base\n\n mask_n = offs_n < N\n if DIVISIBLE_N:\n k = tl.load(k_ptrs, cache_modifier=\".cg\")\n v = tl.load(v_ptrs, cache_modifier=\".cg\")\n else:\n k = tl.load(k_ptrs, mask=mask_n[None, :], cache_modifier=\".cg\")\n v = tl.load(v_ptrs, mask=mask_n[:, None], cache_modifier=\".cg\")\n\n if HAS_BIAS:\n relative_positions = offs_n[None, :] - offs_m[:, None]\n relative_buckets = tl.zeros_like(relative_positions)\n num_buckets = NUM_BUCKETS\n if not IS_CAUSAL:\n num_buckets //= 2\n relative_buckets += (relative_positions > 0).to(tl.int32) * num_buckets\n relative_positions = tl.abs(relative_positions)\n else:\n relative_positions = tl.maximum(-relative_positions, tl.zeros_like(relative_positions))\n\n max_exact = num_buckets // 2\n is_small = relative_positions < max_exact\n\n relative_position_if_large = max_exact + (\n tl.log(relative_positions.to(tl.float32) / max_exact)\n / tl.log(MAX_DISTANCE / max_exact)\n * (num_buckets - max_exact)\n ).to(tl.int32)\n relative_position_if_large = tl.minimum(relative_position_if_large, num_buckets - 1)\n\n relative_buckets += tl.where(is_small, relative_positions, relative_position_if_large)\n\n bucket_offs = relative_buckets * stride_wn + off_h\n bias_ptrs = BW + bucket_offs\n bias_values = tl.load(bias_ptrs)\n\n s = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)\n s += tl.dot(q, k) * sm_scale\n if HAS_BIAS:\n s += bias_values\n\n if not DIVISIBLE_N:\n s = tl.where(mask_n[None, :], s, float(\"-inf\"))\n if IS_CAUSAL:\n causal_mask = (P_SEQ + offs_m[:, None]) >= offs_n[None, :]\n s = tl.where(causal_mask, s, float(\"-inf\"))\n\n m_i_new = tl.maximum(m_i, tl.max(s, 1))\n alpha = tl.math.exp2((m_i - m_i_new) * log2e)\n p = tl.math.exp2((s - m_i_new[:, None]) * log2e)\n\n acc *= alpha[:, None]\n acc += tl.dot(p.to(input_dtype), v)\n\n l_i = l_i * alpha + tl.sum(p, 1)\n m_i = m_i_new\n\n k_ptrs += BLOCK_N * stride_kn\n v_ptrs += BLOCK_N * stride_vn\n\n if IS_CAUSAL and LARGER_M:\n is_empty_line = (offs_m + P_SEQ) < 0\n acc = tl.where(is_empty_line[:, None], 0.0, acc * (1.0 / l_i[:, None]))\n l = tl.where(is_empty_line, float(\"-inf\"), m_i + tl.log(l_i))\n else:\n acc = acc * (1.0 / l_i[:, None])\n l = m_i + tl.log(l_i)\n\n if DIVISIBLE_M:\n tl.store(l_ptrs, l, cache_modifier=\".cg\")\n tl.store(o_ptrs, acc.to(input_dtype), cache_modifier=\".cg\")\n else:\n tl.store(l_ptrs, l, mask=mask_m, cache_modifier=\".cg\")\n tl.store(o_ptrs, acc.to(input_dtype), mask=mask_m[:, None], cache_modifier=\".cg\")\n\ndef flash_attn_v2_fwd_bias(q, k, v, bias_weights, causal, sm_scale, BLOCK_M, BLOCK_N,\n NUM_BUCKETS, MAX_DISTANCE, num_warps, num_stages):\n\n B, H, M, D = q.shape\n N = k.shape[2]\n P_SEQ = N - M\n larger_m = M > N\n\n divisible_m = M % BLOCK_M == 0\n divisible_n = N % BLOCK_N == 0\n\n has_bias = (bias_weights is not None)\n\n bw_stride = bias_weights.stride(0) if has_bias else 0\n grid = (triton.cdiv(M, BLOCK_M), H, B)\n o = torch.empty_like(q)\n L = torch.empty((B, H, M), device=q.device, dtype=torch.float32)\n with torch.cuda.device(q.device.index):\n _fwd_kernel_with_bias_calculation[grid](\n q, k, v, bias_weights, sm_scale,\n L, o,\n q.stride(0), q.stride(1), q.stride(2), q.stride(3),\n k.stride(0), k.stride(1), k.stride(2), k.stride(3),\n v.stride(0), v.stride(1), v.stride(2), v.stride(3),\n o.stride(0), o.stride(1), o.stride(2), o.stride(3),\n bw_stride,\n B, H, M, N, P_SEQ,\n BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_DMODEL=D,\n IS_CAUSAL=causal, LARGER_M=larger_m,\n DIVISIBLE_M=divisible_m, DIVISIBLE_N=divisible_n,\n HAS_BIAS=has_bias,\n NUM_BUCKETS=NUM_BUCKETS, MAX_DISTANCE=MAX_DISTANCE,\n num_warps=num_warps, num_stages=num_stages,\n )\n\n return o, L\n\nclass FlashAttentionBias(torch.autograd.Function):\n @staticmethod\n def forward(ctx, q, k, v, bias_weights, causal, sm_scale, NUM_BUCKETS, MAX_DISTANCE):\n Dq, Dk, Dv = q.shape[-1], k.shape[-1], v.shape[-1]\n\n assert Dq == Dk == Dv\n assert Dk in {16, 32, 64, 128}\n\n B, H, M, D = q.shape\n N = k.shape[2]\n\n if sm_scale is None:\n sm_scale = 1. / math.sqrt(D)\n\n config = get_fwd_config(B, H, M, N, D, causal)\n BLOCK_M, BLOCK_N, num_stages, num_warps = config\n\n o, L = flash_attn_v2_fwd_bias(q, k, v, bias_weights, causal, sm_scale, BLOCK_M, BLOCK_N,\n NUM_BUCKETS, MAX_DISTANCE, num_warps, num_stages)\n\n ctx.save_for_backward(q, k, v, bias_weights, o, L)\n ctx.NUM_BUCKETS = NUM_BUCKETS\n ctx.MAX_DISTANCE = MAX_DISTANCE\n ctx.sm_scale = sm_scale\n ctx.causal = causal\n\n return o\n\n @staticmethod\n def backward(ctx, do, *ignored):\n q, k, v, bias_weights, o, L = ctx.saved_tensors\n sm_scale = ctx.sm_scale\n causal = ctx.causal\n NUM_BUCKETS = ctx.NUM_BUCKETS\n MAX_DISTANCE = ctx.MAX_DISTANCE\n\n B, H, M, D = q.shape\n N = k.shape[2]\n\n if sm_scale is None:\n sm_scale = 1. / math.sqrt(D)\n\n config = get_bwd_config(B, H, M, N, D, causal)\n BLOCK_M, BLOCK_N, num_stages, num_warps = config\n\n dq, dk, dv, db = flash_attn_v2_bwd_bias(o, do, q, k, v, bias_weights, \n L, causal, sm_scale, \n BLOCK_M, BLOCK_N, \n NUM_BUCKETS, MAX_DISTANCE,\n num_warps, num_stages)\n\n return dq, dk, dv, db, None, None, None, None, None, None\n\ndef flash_attention_with_fusing_bias(q, k, v, bias, causal=False, sm_scale=None, NUM_BUCKETS=32, MAX_DISTANCE=128):\n return FlashAttentionBias.apply(q, k, v, bias, causal, sm_scale, NUM_BUCKETS, MAX_DISTANCE)\n", - "description_1": "Use triton language to implement a forward and backward pass for Flash Attention with bias. The forward pass takes 18 parameters: Q, K, V (the queries, keys, and values tensors), BW (bias weights), sm_scale (scale for softmax), L and O (output tensors), various strides, dimensions, and block sizes, and constants for configuration. The function calculates attention scores with bias, applies softmax, and stores results in L and O. The backward function computes gradients for Q, K, V, and the bias.", - "description_2": "Use triton language to implement a Flash Attention kernel with bias, allowing for efficient forward and backward passes using Q, K, V, bias weights, and scaling factors, considering specific configurations.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n_kAlpha = math.sqrt(2.0 / math.pi)\n\n@triton.jit\ndef tanh(x):\n # Tanh is just a scaled sigmoid\n return 2 * tl.sigmoid(2 * x) - 1\n\n@triton.jit\ndef relu(x):\n \"\"\"\n ReLU_ activation function\n\n .. _ReLU: https://pytorch.org/docs/stable/generated/torch.nn.ReLU.html\n \"\"\"\n return tl.where(x >= 0, x, 0.0)\n\n@triton.jit\ndef relu_grad(x):\n return tl.where(x >= 0, 1.0, 0.0)\n\n@triton.jit\ndef gelu(x):\n \"\"\"\n GeLU_ activation - Gaussian error linear unit\n\n .. _GeLU: https://arxiv.org/pdf/1606.08415.pdf\n \"\"\"\n return 0.5 * x * (1 + tanh(_kAlpha * (x + 0.044715 * x * x * x)))\n\n@triton.jit\ndef gelu_grad(x):\n tanh_out = tanh(0.79788456 * x * (1 + 0.044715 * x * x))\n return 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out)\n\n@triton.jit\ndef gated_matmul_fwd(out, input, w1, w2, act_input_1, act_input_2,\n M, N, K, stride_om, stride_im, stride_wn,\n dtype: tl.constexpr, BLOCK_M: tl.constexpr, GROUP_M: tl.constexpr,\n BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,\n USE_GELU: tl.constexpr, SAVE_ACTIVATION_INPUTS: tl.constexpr,\n IS_EVEN_MNK: tl.constexpr):\n \"\"\"\n Kernel for computing Out = activation(A x W + C)\n Input has shape (M, K)\n Weight 1 has shape (K, N)\n Weight 2 has shape (K, N)\n Output has shape (M, N)\n \"\"\"\n pid = tl.program_id(0)\n num_pid_m = tl.cdiv(M, BLOCK_M)\n num_pid_n = tl.cdiv(N, BLOCK_N)\n num_pid_in_group = GROUP_M * num_pid_n\n group_id = pid // num_pid_in_group\n first_pid_m = group_id * GROUP_M\n GROUP_M = min(num_pid_m - first_pid_m, GROUP_M)\n\n pid_m = first_pid_m + (pid % GROUP_M)\n pid_n = (pid % num_pid_in_group) // GROUP_M\n\n input_block_ptr = tl.make_block_ptr(\n base=input,\n shape=(M, K),\n strides=(stride_im, 1),\n offsets=(pid_m * BLOCK_M, 0),\n block_shape=(BLOCK_M, BLOCK_K),\n order=(1, 0),\n )\n\n w1_block_ptr = tl.make_block_ptr(\n base=w1,\n shape=(K, N),\n strides=(1, stride_wn),\n offsets=(0, pid_n * BLOCK_N),\n block_shape=(BLOCK_K, BLOCK_N),\n order=(0, 1),\n )\n\n w2_block_ptr = tl.make_block_ptr(\n base=w2,\n shape=(K, N),\n strides=(1, stride_wn),\n offsets=(0, pid_n * BLOCK_N),\n block_shape=(BLOCK_K, BLOCK_N),\n order=(0, 1),\n )\n\n acc1 = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)\n acc2 = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)\n\n for i in range(0, K, BLOCK_K):\n if IS_EVEN_MNK:\n x = tl.load(input_block_ptr)\n w1_blk = tl.load(w1_block_ptr)\n w2_blk = tl.load(w2_block_ptr)\n else:\n x = tl.load(input_block_ptr, boundary_check=(0, 1))\n w1_blk = tl.load(w1_block_ptr, boundary_check=(0, 1))\n w2_blk = tl.load(w2_block_ptr, boundary_check=(0, 1))\n\n acc1 += tl.dot(x, w1_blk)\n acc2 += tl.dot(x, w2_blk)\n\n input_block_ptr = tl.advance(input_block_ptr, (0, BLOCK_K))\n w1_block_ptr = tl.advance(w1_block_ptr, (BLOCK_K, 0))\n w2_block_ptr = tl.advance(w2_block_ptr, (BLOCK_K, 0))\n\n if SAVE_ACTIVATION_INPUTS:\n act_in_1_ptrs = tl.make_block_ptr(\n base=act_input_1,\n shape=(M, N),\n strides=(stride_om, 1),\n offsets=(pid_m * BLOCK_M, pid_n * BLOCK_N),\n block_shape=(BLOCK_M, BLOCK_N),\n order=(1, 0),\n )\n\n act_in_2_ptrs = tl.make_block_ptr(\n base=act_input_2,\n shape=(M, N),\n strides=(stride_om, 1),\n offsets=(pid_m * BLOCK_M, pid_n * BLOCK_N),\n block_shape=(BLOCK_M, BLOCK_N),\n order=(1, 0),\n )\n\n if IS_EVEN_MNK:\n tl.store(act_in_1_ptrs, acc1.to(dtype))\n tl.store(act_in_2_ptrs, acc2.to(dtype))\n else:\n tl.store(act_in_1_ptrs, acc1.to(dtype), boundary_check=(0, 1))\n tl.store(act_in_2_ptrs, acc2.to(dtype), boundary_check=(0, 1))\n\n if USE_GELU:\n acc1 = gelu(acc1)\n else:\n acc1 = relu(acc1)\n\n acc = acc1 * acc2\n\n out_ptrs = tl.make_block_ptr(\n base=out,\n shape=(M, N),\n strides=(stride_om, 1),\n offsets=(pid_m * BLOCK_M, pid_n * BLOCK_N),\n block_shape=(BLOCK_M, BLOCK_N),\n order=(1, 0),\n )\n\n if IS_EVEN_MNK:\n tl.store(out_ptrs, acc.to(dtype))\n else:\n tl.store(out_ptrs, acc.to(dtype), boundary_check=(0, 1))\n\n@triton.jit\ndef gated_matmul_bwd_ygrad(dout, y1_grad, y2_grad, act_input_1, act_input_2, M, N, stride_dom,\n dtype: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,\n USE_GELU: tl.constexpr, IS_EVEN_MNK: tl.constexpr):\n \"\"\"\n Kernel for backward gated MLP\n Ref :\n y2_grad = torch.mul(gelu(x @ w1), dout)\n y1_grad = torch.mul(gelu_grad(x @ w1) * (x @ w2), dout)\n \"\"\"\n pid_m = tl.program_id(0)\n pid_n = tl.program_id(1)\n\n actin_1_block_ptr = tl.make_block_ptr(\n base=act_input_1,\n shape=(M, N),\n strides=(stride_dom, 1),\n offsets=(pid_m * BLOCK_M, pid_n * BLOCK_N),\n block_shape=(BLOCK_M, BLOCK_N),\n order=(1, 0),\n )\n\n actin_2_block_ptr = tl.make_block_ptr(\n base=act_input_2,\n shape=(M, N),\n strides=(stride_dom, 1),\n offsets=(pid_m * BLOCK_M, pid_n * BLOCK_N),\n block_shape=(BLOCK_M, BLOCK_N),\n order=(1, 0),\n )\n\n dout_block_ptr = tl.make_block_ptr(\n base=dout,\n shape=(M, N),\n strides=(stride_dom, 1),\n offsets=(pid_m * BLOCK_M, pid_n * BLOCK_N),\n block_shape=(BLOCK_M, BLOCK_N),\n order=(1, 0),\n )\n\n if IS_EVEN_MNK:\n dout_blk = tl.load(dout_block_ptr)\n actin_1_blk = tl.load(actin_1_block_ptr)\n actin_2_blk = tl.load(actin_2_block_ptr)\n else:\n dout_blk = tl.load(dout_block_ptr, boundary_check=(0, 1))\n actin_1_blk = tl.load(actin_1_block_ptr, boundary_check=(0, 1))\n actin_2_blk = tl.load(actin_2_block_ptr, boundary_check=(0, 1))\n\n if USE_GELU:\n actin_act = gelu(actin_1_blk)\n actin_act_grad = gelu_grad(actin_1_blk)\n else:\n actin_act = relu(actin_1_blk)\n actin_act_grad = relu_grad(actin_1_blk)\n\n actin_act *= dout_blk\n actin_act_grad *= actin_2_blk\n actin_act_grad *= dout_blk\n\n y1_grad_ptrs = tl.make_block_ptr(\n base=y1_grad,\n shape=(M, N),\n strides=(stride_dom, 1),\n offsets=(pid_m * BLOCK_M, pid_n * BLOCK_N),\n block_shape=(BLOCK_M, BLOCK_N),\n order=(1, 0),\n )\n\n y2_grad_ptrs = tl.make_block_ptr(\n base=y2_grad,\n shape=(M, N),\n strides=(stride_dom, 1),\n offsets=(pid_m * BLOCK_M, pid_n * BLOCK_N),\n block_shape=(BLOCK_M, BLOCK_N),\n order=(1, 0),\n )\n\n if IS_EVEN_MNK:\n tl.store(y1_grad_ptrs, actin_act_grad.to(dtype))\n tl.store(y2_grad_ptrs, actin_act.to(dtype))\n else:\n tl.store(y1_grad_ptrs, actin_act_grad.to(dtype), boundary_check=(0, 1))\n tl.store(y2_grad_ptrs, actin_act.to(dtype), boundary_check=(0, 1))\n\n@triton.jit\ndef gated_matmul_bwd_input(w1, w2, y1_grad, y2_grad, din, M, N, K, stride_dom, stride_im, stride_wn,\n dtype: tl.constexpr, BLOCK_M: tl.constexpr, GROUP_M: tl.constexpr,\n BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, IS_EVEN_MNK: tl.constexpr):\n \"\"\"\n Kernel for backward gated MLP\n Ref :\n x_grad = torch.matmul(y2_grad, w2.t()) + torch.matmul(y1_grad, w1.t())\n \"\"\"\n pid = tl.program_id(0)\n num_pid_m = tl.cdiv(M, BLOCK_M)\n num_pid_k = tl.cdiv(K, BLOCK_K)\n num_pid_in_group = GROUP_M * num_pid_k\n group_id = pid // num_pid_in_group\n first_pid_m = group_id * GROUP_M\n GROUP_M = min(num_pid_m - first_pid_m, GROUP_M)\n\n pid_m = first_pid_m + (pid % GROUP_M)\n pid_k = (pid % num_pid_in_group) // GROUP_M\n\n y1_grad_block_ptr = tl.make_block_ptr(\n base=y1_grad,\n shape=(M, N),\n strides=(stride_dom, 1),\n offsets=(pid_m * BLOCK_M, 0),\n block_shape=(BLOCK_M, BLOCK_N),\n order=(1, 0),\n )\n\n y2_grad_block_ptr = tl.make_block_ptr(\n base=y2_grad,\n shape=(M, N),\n strides=(stride_dom, 1),\n offsets=(pid_m * BLOCK_M, 0),\n block_shape=(BLOCK_M, BLOCK_N),\n order=(1, 0),\n )\n\n w1_block_ptr = tl.make_block_ptr(\n base=w1,\n shape=(N, K),\n strides=(stride_wn, 1),\n offsets=(0, pid_k * BLOCK_K),\n block_shape=(BLOCK_N, BLOCK_K),\n order=(1, 0),\n )\n\n w2_block_ptr = tl.make_block_ptr(\n base=w2,\n shape=(N, K),\n strides=(stride_wn, 1),\n offsets=(0, pid_k * BLOCK_K),\n block_shape=(BLOCK_N, BLOCK_K),\n order=(1, 0),\n )\n\n acc_dx = tl.zeros((BLOCK_M, BLOCK_K), dtype=tl.float32)\n\n for i in range(0, N, BLOCK_N):\n if IS_EVEN_MNK:\n w1_blk = tl.load(w1_block_ptr)\n w2_blk = tl.load(w2_block_ptr)\n y1_grad_blk = tl.load(y1_grad_block_ptr)\n y2_grad_blk = tl.load(y2_grad_block_ptr)\n else:\n w1_blk = tl.load(w1_block_ptr, boundary_check=(0, 1))\n w2_blk = tl.load(w2_block_ptr, boundary_check=(0, 1))\n y1_grad_blk = tl.load(y1_grad_block_ptr, boundary_check=(0, 1))\n y2_grad_blk = tl.load(y2_grad_block_ptr, boundary_check=(0, 1))\n\n acc_dx += tl.dot(y2_grad_blk, w2_blk)\n acc_dx += tl.dot(y1_grad_blk, w1_blk)\n\n w1_block_ptr = tl.advance(w1_block_ptr, (BLOCK_N, 0))\n w2_block_ptr = tl.advance(w2_block_ptr, (BLOCK_N, 0))\n y1_grad_block_ptr = tl.advance(y1_grad_block_ptr, (0, BLOCK_N))\n y2_grad_block_ptr = tl.advance(y2_grad_block_ptr, (0, BLOCK_N))\n\n dx_ptrs = tl.make_block_ptr(\n base=din,\n shape=(M, K),\n strides=(stride_im, 1),\n offsets=(pid_m * BLOCK_M, pid_k * BLOCK_K),\n block_shape=(BLOCK_M, BLOCK_K),\n order=(1, 0),\n )\n\n if IS_EVEN_MNK:\n tl.store(dx_ptrs, acc_dx.to(dtype))\n else:\n tl.store(dx_ptrs, acc_dx.to(dtype), boundary_check=(0, 1))\n\n@triton.jit\ndef gated_matmul_bwd_weights(input, y1_grad, y2_grad, dw1, dw2, M, N, K, stride_dom, stride_im, stride_wn,\n dtype: tl.constexpr, BLOCK_M: tl.constexpr, GROUP_N: tl.constexpr,\n BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, IS_EVEN_MNK: tl.constexpr):\n \"\"\"\n Kernel for backward gated MLP\n Ref :\n w1_grad = torch.matmul(y1_grad.t(), x)\n w2_grad = torch.matmul(y2_grad.t(), x)\n \"\"\"\n pid = tl.program_id(0)\n num_pid_n = tl.cdiv(N, BLOCK_N)\n num_pid_k = tl.cdiv(K, BLOCK_K)\n num_pid_in_group = GROUP_N * num_pid_k\n group_id = pid // num_pid_in_group\n first_pid_n = group_id * GROUP_N\n GROUP_N = min(num_pid_n - first_pid_n, GROUP_N)\n\n pid_n = first_pid_n + (pid % GROUP_N)\n pid_k = (pid % num_pid_in_group) // GROUP_N\n\n y1_grad_block_ptr = tl.make_block_ptr(\n base=y1_grad,\n shape=(N, M),\n strides=(1, stride_dom),\n offsets=(pid_n * BLOCK_N, 0),\n block_shape=(BLOCK_N, BLOCK_M),\n order=(0, 1),\n )\n\n y2_grad_block_ptr = tl.make_block_ptr(\n base=y2_grad,\n shape=(N, M),\n strides=(1, stride_dom),\n offsets=(pid_n * BLOCK_N, 0),\n block_shape=(BLOCK_N, BLOCK_M),\n order=(0, 1),\n )\n\n input_block_ptr = tl.make_block_ptr(\n base=input,\n shape=(M, K),\n strides=(stride_im, 1),\n offsets=(0, pid_k * BLOCK_K),\n block_shape=(BLOCK_M, BLOCK_K),\n order=(1, 0),\n )\n\n acc_dw1 = tl.zeros((BLOCK_N, BLOCK_K), dtype=tl.float32)\n acc_dw2 = tl.zeros((BLOCK_N, BLOCK_K), dtype=tl.float32)\n\n for i in range(0, M, BLOCK_M):\n if IS_EVEN_MNK:\n y1grad_blk = tl.load(y1_grad_block_ptr)\n y2grad_blk = tl.load(y2_grad_block_ptr)\n x = tl.load(input_block_ptr)\n else:\n y1grad_blk = tl.load(y1_grad_block_ptr, boundary_check=(0, 1))\n y2grad_blk = tl.load(y2_grad_block_ptr, boundary_check=(0, 1))\n x = tl.load(input_block_ptr, boundary_check=(0, 1))\n\n acc_dw1 += tl.dot(y1grad_blk, x)\n acc_dw2 += tl.dot(y2grad_blk, x)\n\n y1_grad_block_ptr = tl.advance(y1_grad_block_ptr, (0, BLOCK_M))\n y2_grad_block_ptr = tl.advance(y2_grad_block_ptr, (0, BLOCK_M))\n input_block_ptr = tl.advance(input_block_ptr, (BLOCK_M, 0))\n\n dw1_ptrs = tl.make_block_ptr(\n base=dw1,\n shape=(N, K),\n strides=(stride_wn, 1),\n offsets=(pid_n * BLOCK_N, pid_k * BLOCK_K),\n block_shape=(BLOCK_N, BLOCK_K),\n order=(1, 0),\n )\n\n dw2_ptrs = tl.make_block_ptr(\n base=dw2,\n shape=(N, K),\n strides=(stride_wn, 1),\n offsets=(pid_n * BLOCK_N, pid_k * BLOCK_K),\n block_shape=(BLOCK_N, BLOCK_K),\n order=(1, 0),\n )\n\n if IS_EVEN_MNK:\n tl.store(dw1_ptrs, acc_dw1.to(dtype))\n tl.store(dw2_ptrs, acc_dw2.to(dtype))\n else:\n tl.store(dw1_ptrs, acc_dw1.to(dtype), boundary_check=(0, 1))\n tl.store(dw2_ptrs, acc_dw2.to(dtype), boundary_check=(0, 1))\n", - "description_1": "Use triton language to implement a gated matrix multiplication with forward and backward passes using ReLU and GeLU activations. The kernels handle block-wise matrix operations and include gradient calculations for inputs and weights.", - "description_2": "Use triton language to create a gated matrix multiplication with activation functions for efficient forward and backward operations.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef rms_layer_norm_fwd_fused(\n X, \n Y, \n W, \n RMS,\n stride,\n N,\n eps, \n BLOCK_SIZE: tl.constexpr\n ):\n row = tl.program_id(axis=0)\n\n Y += row*stride\n X += row*stride\n\n mean = 0\n mean_ = tl.zeros([BLOCK_SIZE], dtype = tl.float32)\n\n for i in range(0, N, BLOCK_SIZE):\n offset = i + tl.arange(0, BLOCK_SIZE)\n mask = offset BLOCK_SIZE:\n raise RuntimeError(\"This layer norm doesn't support feature dim >= 64KB.\")\n num_warps = min(max(BLOCK_SIZE // 256, 1), 8)\n rms_layer_norm_fwd_fused[(M, )]( #\n x_arg, y, weight, rms, #\n x_arg.stride(0), N, eps, #\n BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps, num_ctas=1)\n ctx.save_for_backward(x, weight, rms)\n ctx.BLOCK_SIZE = BLOCK_SIZE\n ctx.num_warps = num_warps\n ctx.eps = eps\n return y\n\n @staticmethod\n def backward(ctx, dy):\n x, w, rms = ctx.saved_tensors\n N = w.shape[0]\n GROUP_SIZE_M = 64\n if N <= 8192: GROUP_SIZE_M = 96\n if N <= 4096: GROUP_SIZE_M = 128\n if N <= 1024: GROUP_SIZE_M = 256\n locks = torch.zeros(2 * GROUP_SIZE_M, dtype=torch.int32, device='cuda')\n _dw = torch.empty((GROUP_SIZE_M, w.shape[0]), dtype=x.dtype, device=w.device)\n dw = torch.empty((w.shape[0], ), dtype=w.dtype, device=w.device)\n dx = torch.empty_like(dy)\n x_arg = x.reshape(-1, x.shape[-1])\n M, N = x_arg.shape\n _rms_layer_norm_bwd_dx_fused[(M, )]( #\n dx, dy, _dw, x, w, rms, locks, #\n x_arg.stride(0), N, #\n BLOCK_SIZE_N=ctx.BLOCK_SIZE, #\n GROUP_SIZE_M=GROUP_SIZE_M, #\n num_warps=ctx.num_warps)\n grid = lambda meta: [triton.cdiv(N, meta['BLOCK_SIZE_N'])]\n _rms_layer_norm_bwd_dwdb[grid](\n _dw, dw, min(GROUP_SIZE_M, M), N, #\n BLOCK_SIZE_M=32, #\n BLOCK_SIZE_N=128, num_ctas=1)\n return dx, dw, None\n", - "description_1": "Use triton language to implement a fused RMS layer normalization forward and backward pass. The forward kernel 'rms_layer_norm_fwd_fused' takes 8 parameters: X (input tensor), Y (output tensor), W (weights), RMS (output for RMS values), stride (stride for input tensor), N (number of elements per row), eps (epsilon for numerical stability), and BLOCK_SIZE (block size for computation). The backward pass consists of two kernels: '_rms_layer_norm_bwd_dx_fused' and '_rms_layer_norm_bwd_dwdb'. '_rms_layer_norm_bwd_dx_fused' computes the gradient with respect to the input and partial weight gradients, taking 11 parameters: DX (input gradient), DY (output gradient), DW (partial weight gradient), X (input), W (weights), RMS (RMS values), Lock (lock for atomic operations), stride, N, GROUP_SIZE_M, and BLOCK_SIZE_N. '_rms_layer_norm_bwd_dwdb' accumulates the partial weight gradients, taking 6 parameters: DW (partial weight gradient), FINAL_DW (final weight gradient), M (group size), N (number of columns), BLOCK_SIZE_M, and BLOCK_SIZE_N.", - "description_2": "Use triton language to implement a fused RMS layer normalization with forward and backward kernels. The forward kernel computes the RMS normalization and applies weights, while the backward kernels compute gradients for inputs and weights.", - "difficulty": 3 - }, - { - "code": "import math\nimport torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef _fwd_kernel_with_bias_calculation(\n Q, K, V, BW, sm_scale,\n L, O,\n cu_seqlens_q, cu_seqlens_k, mid_batch, mid_start,\n stride_qz, stride_qh, stride_qk,\n stride_kz, stride_kh, stride_kk,\n stride_vz, stride_vh, stride_vk,\n stride_oz, stride_oh, stride_ok,\n stride_wn,\n Z, H, M, N,\n BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr,\n IS_CAUSAL: tl.constexpr,\n HAS_BIAS: tl.constexpr, NUM_BUCKETS: tl.constexpr, MAX_DISTANCE: tl.constexpr\n):\n # Forward kernel implementation for computing attention with bias\n # Arguments:\n # - Q, K, V: query, key, and value tensors\n # - BW: bias weights\n # - sm_scale: scaling factor for softmax\n # - L, O: tensors for intermediate computations\n # - cu_seqlens_q, cu_seqlens_k: cumulative sequence lengths for queries and keys\n # - mid_batch, mid_start: batch and start indices\n # - Various stride and offset values\n # - BLOCK_M, BLOCK_DMODEL, BLOCK_N: block sizes for M, D_MODEL, and N\n # - IS_CAUSAL, HAS_BIAS, NUM_BUCKETS, MAX_DISTANCE: compile-time constants\n\n@triton.jit\ndef _bwd_preprocess(\n Out, DO,\n Delta,\n cu_seqlens_q, mid_batch, mid_start,\n stride_oz, stride_oh, stride_ok,\n stride_doz, stride_doh, stride_dok,\n stride_dz, stride_dh,\n BLOCK_M: tl.constexpr, D_HEAD: tl.constexpr,\n):\n # Backward preprocess kernel for computing delta values\n # Arguments:\n # - Out, DO: output and output gradient tensors\n # - Delta: tensor to store delta values\n # - cu_seqlens_q: cumulative sequence lengths for queries\n # - mid_batch, mid_start: batch and start indices\n # - Various stride values\n # - BLOCK_M, D_HEAD: block size for M and head dimension\n\n@triton.jit\ndef _bwd_kv_bias_kernel(\n Q, K, V, BW, sm_scale, DO,\n DK, DV, DB,\n L,\n D,\n cu_seqlens_q, cu_seqlens_k, nid_batch, nid_start,\n stride_qz, stride_qh, stride_qk,\n stride_kz, stride_kh, stride_kk,\n stride_vz, stride_vh, stride_vk,\n stride_doz, stride_doh, stride_dok,\n stride_dkz, stride_dkh, stride_dkk,\n stride_dvz, stride_dvh, stride_dvk,\n stride_bw,\n Z, H, M, N,\n BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr,\n CAUSAL: tl.constexpr,\n HAS_BIAS: tl.constexpr,\n NUM_BUCKETS: tl.constexpr,\n MAX_DISTANCE: tl.constexpr,\n):\n # Backward kernel for key-value updates with bias\n # Arguments:\n # - Q, K, V: query, key, and value tensors\n # - BW: bias weights\n # - DO: output gradient tensor\n # - DK, DV, DB: gradients for keys, values, and bias\n # - L, D: tensors for intermediate computations\n # - cu_seqlens_q, cu_seqlens_k: cumulative sequence lengths\n # - nid_batch, nid_start: batch and start indices\n # - Various stride values\n # - BLOCK_M, BLOCK_DMODEL, BLOCK_N: block sizes\n # - CAUSAL, HAS_BIAS, NUM_BUCKETS, MAX_DISTANCE: compile-time constants\n\n@triton.jit\ndef _bwd_q_kernel_with_bias_calculation(\n Q, K, V, BW, sm_scale, DO,\n DQ,\n L,\n D,\n cu_seqlens_q, cu_seqlens_k, mid_batch, mid_start,\n stride_qz, stride_qh, stride_qk,\n stride_kz, stride_kh, stride_kk,\n stride_vz, stride_vh, stride_vk,\n stride_doz, stride_doh, stride_dok,\n stride_dqz, stride_dqh, stride_dqk,\n stride_bw,\n Z, H, M, N,\n BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr,\n CAUSAL: tl.constexpr, HAS_BIAS: tl.constexpr,\n NUM_BUCKETS: tl.constexpr,\n MAX_DISTANCE: tl.constexpr,\n):\n # Backward kernel for query updates with bias\n # Arguments:\n # - Q, K, V: query, key, and value tensors\n # - BW: bias weights\n # - DO: output gradient tensor\n # - DQ: gradient for queries\n # - L, D: tensors for intermediate computations\n # - cu_seqlens_q, cu_seqlens_k: cumulative sequence lengths\n # - mid_batch, mid_start: batch and start indices\n # - Various stride values\n # - BLOCK_M, BLOCK_DMODEL, BLOCK_N: block sizes\n # - CAUSAL, HAS_BIAS, NUM_BUCKETS, MAX_DISTANCE: compile-time constants\n", - "description_1": "Use triton language to define forward and backward kernels for computing attention with bias. Forward kernel processes query, key, value, and bias tensors, while backward kernels handle gradients for these tensors. Each kernel utilizes specific block sizes and handles bias computation and causal masking as necessary.", - "description_2": "Use triton language to implement attention mechanisms with bias for forward and backward passes, ensuring efficient computation with specified block sizes and handling of bias and causal settings.", - "difficulty": 4 - }, - { - "code": "import triton\nimport torch\n\n# Triton kernel for matrix multiplication\n@triton.jit\ndef matmul_kernel(A, B, C, M, N, K, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr):\n pid = tl.program_id(0)\n # Compute the block row and column\n block_row = pid // (N // BLOCK_SIZE_N)\n block_col = pid % (N // BLOCK_SIZE_N)\n # Compute the start of the block\n a_start = block_row * BLOCK_SIZE_M * K\n b_start = block_col * BLOCK_SIZE_N\n c_start = block_row * BLOCK_SIZE_M * N + block_col * BLOCK_SIZE_N\n # Load A and B blocks\n a = tl.load(A + a_start + tl.arange(0, BLOCK_SIZE_M)[:, None] * K + tl.arange(0, BLOCK_SIZE_K)[None, :])\n b = tl.load(B + b_start + tl.arange(0, BLOCK_SIZE_K)[:, None] * N + tl.arange(0, BLOCK_SIZE_N)[None, :])\n # Compute the product\n c = tl.dot(a, b)\n # Store the result\n tl.store(C + c_start + tl.arange(0, BLOCK_SIZE_M)[:, None] * N + tl.arange(0, BLOCK_SIZE_N)[None, :], c)\n\n# Function to call the Triton kernel\ndef matmul(A, B, M, N, K):\n C = torch.empty((M, N), device='cuda', dtype=A.dtype)\n grid = (M // 128) * (N // 128)\n matmul_kernel[grid](A, B, C, M, N, K, BLOCK_SIZE_M=128, BLOCK_SIZE_N=128, BLOCK_SIZE_K=32)\n return C\n", - "description_1": "Use triton language to implement a matrix multiplication kernel. The kernel takes 7 parameters: A, B, C (the matrices), M, N, K (the dimensions), and BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K (the block sizes). The kernel computes the product of matrices A and B and stores the result in C. The function matmul calls this kernel with the appropriate grid size.", - "description_2": "Use triton language to create a matrix multiplication kernel and a function to execute it, handling matrices A, B, and C with dimensions M, N, and K.", - "difficulty": 2 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\nfrom . import custom_autotune\n\n# Triton kernel for matrix multiplication\n@custom_autotune.autotune(\n configs=[\n triton.Config(\n {\n 'BLOCK_SIZE_M': 64,\n 'BLOCK_SIZE_N': 256,\n 'BLOCK_SIZE_K': 32,\n 'GROUP_SIZE_M': 8\n },\n num_stages=4,\n num_warps=4),\n triton.Config(\n {\n 'BLOCK_SIZE_M': 128,\n 'BLOCK_SIZE_N': 128,\n 'BLOCK_SIZE_K': 32,\n 'GROUP_SIZE_M': 8\n },\n num_stages=4,\n num_warps=4),\n triton.Config(\n {\n 'BLOCK_SIZE_M': 64,\n 'BLOCK_SIZE_N': 128,\n 'BLOCK_SIZE_K': 32,\n 'GROUP_SIZE_M': 8\n },\n num_stages=4,\n num_warps=4),\n triton.Config(\n {\n 'BLOCK_SIZE_M': 128,\n 'BLOCK_SIZE_N': 32,\n 'BLOCK_SIZE_K': 32,\n 'GROUP_SIZE_M': 8\n },\n num_stages=4,\n num_warps=4),\n triton.Config(\n {\n 'BLOCK_SIZE_M': 64,\n 'BLOCK_SIZE_N': 64,\n 'BLOCK_SIZE_K': 32,\n 'GROUP_SIZE_M': 8\n },\n num_stages=4,\n num_warps=4),\n triton.Config(\n {\n 'BLOCK_SIZE_M': 64,\n 'BLOCK_SIZE_N': 128,\n 'BLOCK_SIZE_K': 32,\n 'GROUP_SIZE_M': 8\n },\n num_stages=2,\n num_warps=8),\n triton.Config(\n {\n 'BLOCK_SIZE_M': 64,\n 'BLOCK_SIZE_N': 64,\n 'BLOCK_SIZE_K': 64,\n 'GROUP_SIZE_M': 8\n },\n num_stages=3,\n num_warps=8),\n triton.Config(\n {\n 'BLOCK_SIZE_M': 32,\n 'BLOCK_SIZE_N': 32,\n 'BLOCK_SIZE_K': 128,\n 'GROUP_SIZE_M': 8\n },\n num_stages=2,\n num_warps=4),\n ],\n key=['M', 'N', 'K'],\n nearest_power_of_two=True,\n prune_configs_by={\n 'early_config_prune': custom_autotune.matmul248_kernel_config_pruner,\n 'perf_model': None,\n 'top_k': None,\n },\n)\n@triton.jit\ndef matmul_248_kernel(a_ptr, b_ptr, c_ptr, scales_ptr, zeros_ptr, g_ptr, M,\n N, K, bits, maxq, stride_am, stride_ak, stride_bk,\n stride_bn, stride_cm, stride_cn, stride_scales,\n stride_zeros, BLOCK_SIZE_M: tl.constexpr,\n BLOCK_SIZE_N: tl.constexpr,\n BLOCK_SIZE_K: tl.constexpr,\n GROUP_SIZE_M: tl.constexpr):\n infearure_per_bits = 32 // bits\n\n pid = tl.program_id(axis=0)\n num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)\n num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)\n num_pid_k = tl.cdiv(K, BLOCK_SIZE_K)\n num_pid_in_group = GROUP_SIZE_M * num_pid_n\n group_id = pid // num_pid_in_group\n first_pid_m = group_id * GROUP_SIZE_M\n group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)\n pid_m = first_pid_m + (pid % group_size_m)\n pid_n = (pid % num_pid_in_group) // group_size_m\n\n offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n offs_k = tl.arange(0, BLOCK_SIZE_K)\n a_ptrs = a_ptr + (\n offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak\n )\n a_mask = (offs_am[:, None] < M)\n\n b_ptrs = b_ptr + (\n (offs_k[:, None] // infearure_per_bits) * stride_bk +\n offs_bn[None, :] * stride_bn)\n g_ptrs = g_ptr + offs_k\n\n scales_ptrs = scales_ptr + offs_bn[None, :]\n zeros_ptrs = zeros_ptr + (offs_bn[None, :] // infearure_per_bits)\n\n shifter = (offs_k % infearure_per_bits) * bits\n zeros_shifter = (offs_bn % infearure_per_bits) * bits\n accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n\n for k in range(0, num_pid_k):\n g_idx = tl.load(g_ptrs)\n\n scales = tl.load(scales_ptrs + g_idx[:, None] *\n stride_scales)\n zeros = tl.load(\n zeros_ptrs +\n g_idx[:, None] * stride_zeros)\n\n zeros = (zeros >> zeros_shifter[None, :]) & maxq\n zeros = (zeros + 1)\n\n a = tl.load(\n a_ptrs, mask=a_mask, other=0.)\n b = tl.load(b_ptrs)\n\n b = (b >> shifter[:, None]) & maxq\n b = (b - zeros) * scales\n\n accumulator += tl.dot(a, b)\n a_ptrs += BLOCK_SIZE_K\n b_ptrs += (BLOCK_SIZE_K // infearure_per_bits) * stride_bk\n g_ptrs += BLOCK_SIZE_K\n\n c_ptrs = c_ptr + stride_cm * offs_am[:, None] + stride_cn * offs_bn[\n None, :]\n c_mask = (offs_am[:, None] < M) & (offs_bn[None, :] < N)\n tl.store(c_ptrs, accumulator, mask=c_mask)\n\n\n@custom_autotune.autotune(\n configs=[\n triton.Config(\n {\n 'BLOCK_SIZE_M': 64,\n 'BLOCK_SIZE_N': 32,\n 'BLOCK_SIZE_K': 256,\n 'GROUP_SIZE_M': 8\n },\n num_stages=4,\n num_warps=4),\n triton.Config(\n {\n 'BLOCK_SIZE_M': 128,\n 'BLOCK_SIZE_N': 32,\n 'BLOCK_SIZE_K': 128,\n 'GROUP_SIZE_M': 8\n },\n num_stages=4,\n num_warps=4),\n triton.Config(\n {\n 'BLOCK_SIZE_M': 64,\n 'BLOCK_SIZE_N': 32,\n 'BLOCK_SIZE_K': 128,\n 'GROUP_SIZE_M': 8\n },\n num_stages=4,\n num_warps=4),\n triton.Config(\n {\n 'BLOCK_SIZE_M': 128,\n 'BLOCK_SIZE_N': 32,\n 'BLOCK_SIZE_K': 32,\n 'GROUP_SIZE_M': 8\n },\n num_stages=4,\n num_warps=4),\n triton.Config(\n {\n 'BLOCK_SIZE_M': 64,\n 'BLOCK_SIZE_N': 32,\n 'BLOCK_SIZE_K': 64,\n 'GROUP_SIZE_M': 8\n },\n num_stages=4,\n num_warps=4),\n triton.Config(\n {\n 'BLOCK_SIZE_M': 64,\n 'BLOCK_SIZE_N': 32,\n 'BLOCK_SIZE_K': 128,\n 'GROUP_SIZE_M': 8\n },\n num_stages=2,\n num_warps=8),\n triton.Config(\n {\n 'BLOCK_SIZE_M': 64,\n 'BLOCK_SIZE_N': 64,\n 'BLOCK_SIZE_K': 64,\n 'GROUP_SIZE_M': 8\n },\n num_stages=3,\n num_warps=8),\n triton.Config(\n {\n 'BLOCK_SIZE_M': 32,\n 'BLOCK_SIZE_N': 128,\n 'BLOCK_SIZE_K': 32,\n 'GROUP_SIZE_M': 8\n },\n num_stages=2,\n num_warps=4),\n ],\n key=['M', 'N', 'K'],\n nearest_power_of_two=True)\n@triton.jit\ndef transpose_matmul_248_kernel(\n a_ptr, b_ptr, c_ptr, scales_ptr, zeros_ptr, g_ptr, M, N, K, bits,\n maxq, stride_am, stride_ak, stride_bk, stride_bn, stride_cm,\n stride_cn, stride_scales, stride_zeros, BLOCK_SIZE_M: tl.constexpr,\n BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,\n GROUP_SIZE_M: tl.constexpr):\n infearure_per_bits = 32 // bits\n\n pid = tl.program_id(axis=0)\n num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)\n num_pid_k = tl.cdiv(K, BLOCK_SIZE_K)\n num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)\n num_pid_in_group = GROUP_SIZE_M * num_pid_k\n group_id = pid // num_pid_in_group\n first_pid_m = group_id * GROUP_SIZE_M\n group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)\n pid_m = first_pid_m + (pid % group_size_m)\n pid_k = (pid % num_pid_in_group) // group_size_m\n\n offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_bk = pid_k * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)\n offs_n = tl.arange(0, BLOCK_SIZE_N)\n a_ptrs = a_ptr + (\n offs_am[:, None] * stride_am + offs_n[None, :] * stride_ak\n )\n a_mask = (offs_am[:, None] < M)\n\n b_ptrs = b_ptr + (\n (offs_bk[:, None] // infearure_per_bits) * stride_bk +\n offs_n[None, :] * stride_bn)\n g_ptrs = g_ptr + offs_bk\n g_idx = tl.load(g_ptrs)\n\n scales_ptrs = scales_ptr + offs_n[\n None, :] + g_idx[:, None] * stride_scales\n zeros_ptrs = zeros_ptr + (offs_n[None, :] // infearure_per_bits\n ) + g_idx[:, None] * stride_zeros\n\n shifter = (offs_bk % infearure_per_bits) * bits\n zeros_shifter = (offs_n % infearure_per_bits) * bits\n accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_K), dtype=tl.float32)\n\n for n in range(0, num_pid_n):\n scales = tl.load(scales_ptrs)\n zeros = tl.load(zeros_ptrs)\n\n zeros = (zeros >> zeros_shifter[None, :]) & maxq\n zeros = (zeros + 1)\n\n a = tl.load(\n a_ptrs, mask=a_mask, other=0.)\n b = tl.load(b_ptrs)\n\n b = (b >> shifter[:, None]) & maxq\n b = (b - zeros) * scales\n b = tl.trans(b)\n\n accumulator += tl.dot(a, b)\n a_ptrs += BLOCK_SIZE_N\n b_ptrs += BLOCK_SIZE_N\n scales_ptrs += BLOCK_SIZE_N\n zeros_ptrs += (BLOCK_SIZE_N // infearure_per_bits)\n\n c_ptrs = c_ptr + stride_cm * offs_am[:, None] + stride_cn * offs_bk[\n None, :]\n c_mask = (offs_am[:, None] < M) & (offs_bk[None, :] < K)\n tl.store(c_ptrs, accumulator, mask=c_mask)\n\n\ndef matmul248(input, qweight, scales, qzeros, g_idx, bits, maxq):\n with torch.cuda.device(input.device):\n output = torch.empty((input.shape[0], qweight.shape[1]),\n device=input.device,\n dtype=torch.float16)\n grid = lambda META: (\n triton.cdiv(input.shape[0], META['BLOCK_SIZE_M']) * triton.\n cdiv(qweight.shape[1], META['BLOCK_SIZE_N']), )\n matmul_248_kernel[grid](input, qweight, output, scales, qzeros, g_idx,\n input.shape[0], qweight.shape[1],\n input.shape[1], bits, maxq, input.stride(0),\n input.stride(1), qweight.stride(0),\n qweight.stride(1), output.stride(0),\n output.stride(1), scales.stride(0),\n qzeros.stride(0))\n return output\n\n\ndef transpose_matmul248(input, qweight, scales, qzeros, g_idx, bits, maxq):\n with torch.cuda.device(input.device):\n output_dim = (qweight.shape[0] * 32) // bits\n output = torch.empty((input.shape[0], output_dim),\n device=input.device,\n dtype=torch.float16)\n grid = lambda META: (\n triton.cdiv(input.shape[0], META['BLOCK_SIZE_M'])\n * triton.cdiv(output_dim, META['BLOCK_SIZE_K']), )\n transpose_matmul_248_kernel[grid](input, qweight, output, scales,\n qzeros, g_idx, input.shape[0],\n qweight.shape[1], output_dim,\n bits, maxq, input.stride(0),\n input.stride(1), qweight.stride(0),\n qweight.stride(1), output.stride(0),\n output.stride(1), scales.stride(0),\n qzeros.stride(0))\n return output\n", - "description_1": "Use triton language to implement two matrix multiplication kernels: matmul_248_kernel and transpose_matmul_248_kernel. The first kernel computes the product of two matrices A and B with shapes (M, K) and (K//8, N) respectively, resulting in matrix C of shape (M, N). The second kernel computes the transposed product with matrix A of shape (M, N) and matrix B of shape (K//8, N), resulting in matrix C of shape (M, K). Both kernels use a quantized representation of matrix B and involve scaling and zero-point adjustments during the computation.", - "description_2": "Use triton language to implement two kernels for quantized matrix multiplication: one for regular matrix multiplication and another for transposed multiplication. Both kernels account for scaling and zero-point adjustments.", - "difficulty": 4 - }, - { - "code": "import triton\nimport triton.language as tl\nimport torch\n\n@triton.jit\ndef apply_clip_kernel(samples_ptr, min, max, output_ptr, n_audios, audio_len, BLOCK_SIZE: tl.constexpr):\n audio_idx = tl.program_id(0)\n if audio_idx >= n_audios:\n return\n for i in range(0, audio_len, BLOCK_SIZE):\n sample_idx = i + tl.arange(0, BLOCK_SIZE)\n mask = sample_idx < audio_len\n samples = tl.load(samples_ptr + audio_idx * audio_len + sample_idx, mask=mask)\n result = tl.where(samples > max, max, samples)\n result = tl.where(result < min, min, result)\n tl.store(output_ptr + audio_idx * audio_len + sample_idx, result, mask=mask)\n\ndef apply_clip(samples: torch.Tensor, min: float, max: float, inplace: bool = False):\n assert min < max\n assert samples.ndim == 2\n n_audios, audio_len = samples.shape\n grid = lambda _: (n_audios,)\n if inplace:\n apply_clip_kernel[grid](samples, min, max, samples, n_audios, audio_len)\n return samples\n else:\n copy = torch.empty_like(samples, dtype=samples.dtype)\n apply_clip_kernel[grid](samples, min, max, copy, n_audios, audio_len)\n return copy\n", - "description_1": "Use triton language to create a kernel `apply_clip_kernel` that clips the audio samples between a minimum and maximum value. The kernel takes 6 parameters: a pointer to the samples, minimum value, maximum value, a pointer for output, number of audios, and audio length. A `BLOCK_SIZE` constant is used for block-wise operations within the kernel. The kernel processes audio samples in chunks, loading, comparing, and storing them based on the clipping range. A function `apply_clip` is provided to set up the grid and execute the kernel, with an option to perform the operation in place or on a copy.", - "description_2": "Use triton language to implement a kernel that clips audio samples to a specified range, and use a wrapper function to execute this operation over multiple audio samples, either in place or on a duplicate.", - "difficulty": 2 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\nimport itertools\nimport math\n\n@triton.jit\ndef sinc_kernel(\n output_ptr,\n cutoffs_ptr,\n indices_ptr,\n num_taps,\n window_ptr,\n half_sample_rate,\n mode: tl.constexpr,\n BLOCK_SIZE: tl.constexpr):\n batch_idx = tl.program_id(1)\n pos = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)\n mask = pos < num_taps\n\n cutoff_val = tl.load(cutoffs_ptr + batch_idx) / half_sample_rate\n index_val = tl.load(indices_ptr + pos, mask=mask)\n window_val = tl.load(window_ptr + pos, mask=mask)\n\n x = index_val * math.pi * cutoff_val\n sinc_val = tl.where(index_val == 0, 1., tl.sin(x) / x)\n windowed_sinc = sinc_val * window_val\n\n # Normalize each filter by the sum of its windowed sinc values\n normalized_sinc = windowed_sinc / tl.sum(windowed_sinc, axis=0)\n if mode == \"high\":\n center_idx = num_taps // 2\n adjusted_val = tl.where(pos == center_idx, 1.0 - normalized_sinc, -normalized_sinc)\n\n tl.store(output_ptr + batch_idx * num_taps + pos, adjusted_val, mask=mask)\n elif mode == \"low\":\n tl.store(output_ptr + batch_idx * num_taps + pos, normalized_sinc, mask=mask)\n else:\n raise ValueError(f\"Unknown mode: {mode}\")\n\ndef create_filters(filter_output, cutoff_freqs, time, window, sample_rate, num_taps, mode):\n grid_size = (1, len(cutoff_freqs))\n\n sinc_kernel[grid_size](\n filter_output,\n cutoff_freqs,\n time,\n num_taps,\n window,\n 0.5 * sample_rate,\n mode,\n triton.next_power_of_2(num_taps)\n )\n\n@triton.autotune(\n configs=[\n triton.Config({'BLOCK_SIZE': block_size}, num_warps=num_warps)\n for (block_size, num_warps) in\n itertools.product([32, 64, 128, 256, 512, 1024, 2048, 4096], [1, 2, 4, 8, 16, 32])\n ],\n key=['length', 'kernel_size', 'stride', 'n_frames']\n)\n@triton.jit\ndef unfold_kernel(input_ptr, output_ptr, length, kernel_size, stride, n_frames, BLOCK_SIZE: tl.constexpr):\n # Compute indices\n batch_idx = tl.program_id(0)\n\n # Global frame index\n frame_idx = tl.program_id(1) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)\n\n # Bounds check for the frame index\n mask = frame_idx < n_frames\n\n # Calculate position in input for each thread\n input_pos = frame_idx * stride\n\n # Each thread processes one frame if within bounds\n for i in range(kernel_size):\n in_bounds = mask & ((input_pos + i) < length)\n\n # Use tl.where to handle in-bounds and out-of-bounds cases\n val = tl.where(in_bounds, tl.load(input_ptr + batch_idx * length + input_pos + i, mask=in_bounds), 0)\n\n out_idx = batch_idx * n_frames * kernel_size + frame_idx * kernel_size + i\n tl.store(output_ptr + out_idx, val, mask=in_bounds)\n\ndef unfold_triton(input, kernel_size, stride):\n assert input.ndim >= 2, \"Input tensor must be at least 2D\"\n length = input.shape[-1]\n n_frames = math.ceil((max(length, kernel_size) - kernel_size) / stride) + 1\n\n # Prepare output tensor\n output_shape = list(input.shape)[:-1] + [n_frames, kernel_size]\n output = torch.empty(output_shape, device=input.device, dtype=input.dtype)\n\n # Grid dimensions\n grid = lambda META: (\n input.shape[0],\n triton.cdiv(n_frames, META['BLOCK_SIZE']) + (n_frames % META['BLOCK_SIZE'] != 0)\n )\n\n # Launch kernel\n unfold_kernel[grid](input, output, length, kernel_size, stride, n_frames)\n\n return output\n\n@triton.autotune(\n configs=[\n triton.Config({}, num_warps=num_warps)\n for (num_warps) in [1, 2, 4, 8, 16, 32]\n ],\n key=['num_batches', 'num_frames', 'fft_size']\n)\n@triton.jit\ndef complex_mul_conjugate_kernel(\n a_real_ptr,\n b_real_ptr,\n a_imag_ptr,\n b_imag_ptr,\n output1_ptr,\n output2_ptr,\n num_batches,\n num_frames,\n fft_size,\n BLOCK_SIZE: tl.constexpr):\n # Compute indices for batch and fft\n batch_idx = tl.program_id(0)\n\n # Ensure we don't go out of bounds for batch index\n if batch_idx >= num_batches:\n return\n\n fft_idx = tl.arange(0, BLOCK_SIZE)\n fft_mask = fft_idx < fft_size\n\n batch_by_fft = batch_idx * fft_size\n\n b_real_val = tl.load(b_real_ptr + batch_by_fft + fft_idx, mask=fft_mask)\n b_imag_val = tl.load(b_imag_ptr + batch_by_fft + fft_idx, mask=fft_mask)\n\n for frame_idx in range(num_frames):\n global_idx = num_frames * batch_by_fft + frame_idx * fft_size + fft_idx\n\n a_real_val = tl.load(a_real_ptr + global_idx, mask=fft_mask)\n a_imag_val = tl.load(a_imag_ptr + global_idx, mask=fft_mask)\n\n result1 = a_real_val * b_real_val + a_imag_val * b_imag_val\n result2 = a_imag_val * b_real_val - a_real_val * b_imag_val\n\n tl.store(output1_ptr + global_idx, result1, mask=fft_mask)\n tl.store(output2_ptr + global_idx, result2, mask=fft_mask)\n\ndef complex_mul_conjugate_triton(a_real, b_real, a_imag, b_imag):\n assert a_real.shape[-1] == b_real.shape[-1] # Ensure last dimensions match for multiplication\n\n num_batches, num_frames, fft_size = a_real.shape\n\n # Output tensor\n output1 = torch.empty_like(a_real)\n output2 = torch.empty_like(a_real)\n\n # Define grid size for the kernel launch\n grid_size = (num_batches,)\n\n # Launch the kernel\n\n complex_mul_conjugate_kernel[grid_size](\n a_real,\n b_real,\n a_imag,\n b_imag,\n output1,\n output2,\n num_batches,\n num_frames,\n fft_size,\n triton.next_power_of_2(fft_size)\n )\n\n return output1, output2\n", - "description_1": "Use triton language to implement three kernels: sinc_kernel, unfold_kernel, and complex_mul_conjugate_kernel. The sinc_kernel computes a windowed sinc filter for each batch, taking 8 parameters: output_ptr, cutoffs_ptr, indices_ptr, num_taps, window_ptr, half_sample_rate, mode, and BLOCK_SIZE. The unfold_kernel extracts frames from an input tensor, taking 6 parameters: input_ptr, output_ptr, length, kernel_size, stride, and n_frames. The complex_mul_conjugate_kernel performs complex multiplication with conjugation, taking 9 parameters: a_real_ptr, b_real_ptr, a_imag_ptr, b_imag_ptr, output1_ptr, output2_ptr, num_batches, num_frames, and fft_size.", - "description_2": "Use triton language to create a sinc filter kernel, an unfold operation kernel, and a complex multiplication with conjugation kernel, each with specific parameters for their respective operations.", - "difficulty": 3 - }, - { - "code": "import triton\nimport triton.language as tl\nimport torch\n\n@triton.jit\ndef apply_gain_kernel(samples_ptr, amplitude_ratios_ptr, output_ptr, n_audios, audio_len, BLOCK_SIZE: tl.constexpr):\n # Get the index of the current audio sample\n audio_idx = tl.program_id(0)\n\n # Check if the audio index is within the number of audios\n if audio_idx >= n_audios:\n return\n\n # Load the gain for the current audio\n gain = tl.load(amplitude_ratios_ptr + audio_idx)\n\n # Iterate over the audio samples in blocks\n for i in range(0, audio_len, BLOCK_SIZE):\n sample_idx = i + tl.arange(0, BLOCK_SIZE)\n mask = sample_idx < audio_len\n # Load the samples with masking\n samples = tl.load(samples_ptr + audio_idx * audio_len + sample_idx, mask=mask)\n # Apply the gain\n result = samples * gain\n # Store the result\n tl.store(output_ptr + audio_idx * audio_len + sample_idx, result, mask=mask)\n\ndef apply_gain(samples: torch.Tensor, amplitude_ratios: torch.Tensor, inplace: bool = False):\n # Ensure the input tensors have the correct dimensions\n assert samples.ndim == 2 and amplitude_ratios.ndim == 1\n n_audios, audio_len = samples.shape\n\n # Define the grid size for the kernel launch\n grid = lambda _: (n_audios,)\n\n # Apply the gain kernel in-place or to a new tensor\n if inplace:\n apply_gain_kernel[grid](samples, amplitude_ratios, samples, n_audios, audio_len)\n return samples\n else:\n copy = torch.empty_like(samples, device='cuda', dtype=samples.dtype)\n apply_gain_kernel[grid](samples, amplitude_ratios, copy, n_audios, audio_len)\n return copy\n", - "description_1": "Use triton language to implement a kernel that applies a gain to audio samples. The kernel function 'apply_gain_kernel' takes 6 parameters: samples_ptr (pointer to audio samples), amplitude_ratios_ptr (pointer to gain values), output_ptr (pointer to output buffer), n_audios (number of audio samples), audio_len (length of each audio sample), and BLOCK_SIZE (block size for processing). The function iterates over audio samples in blocks, applies the gain, and stores the result. The 'apply_gain' function is a wrapper that prepares the input data and calls the kernel, with an option to perform the operation in-place.", - "description_2": "Use triton language to create a kernel that multiplies audio samples by a gain factor, iterating over samples in blocks and storing the results, with a wrapper function to handle input preparation and kernel invocation.", - "difficulty": 2 - }, - { - "code": "import triton\nimport triton.language as tl\nimport torch\n\n@triton.jit\ndef rms_kernel(audios, audios_real_lens, audios_max_len, batch_idx, BLOCK_SIZE_RMS: tl.constexpr):\n audios_real_lens_vals = tl.load(audios_real_lens + batch_idx)\n\n _mean = tl.zeros([BLOCK_SIZE_RMS], dtype=tl.float32)\n for offset in range(0, audios_max_len, BLOCK_SIZE_RMS):\n audios_block_ptr = offset + tl.arange(0, BLOCK_SIZE_RMS)\n audios_mask = audios_block_ptr < audios_real_lens_vals\n\n audios_vals = tl.load(audios + batch_idx * audios_max_len + audios_block_ptr, mask=audios_mask)\n audios_partial_sum_sq = tl.where(audios_mask, tl.math.pow(audios_vals, 2.0), 0)\n _mean += audios_partial_sum_sq\n\n audios_global_sum_sq = tl.sum(_mean, axis=0)\n return tl.sqrt(audios_global_sum_sq / audios_real_lens_vals)\n\n@triton.autotune(\n configs=[\n triton.Config({'BLOCK_SIZE_SUM': block_size_sum}, num_warps=num_warps)\n for (block_size_sum, num_warps) in\n itertools.product(\n [512, 1024],\n [2, 4, 8, 16]\n )\n ],\n key=['clean_audio_max_len', 'noisy_audio_max_len']\n)\n@triton.jit\ndef sum_with_snr_kernel(\n clean_audio, clean_audio_real_lens, clean_audio_max_len, desired_rms,\n noisy_audio_ptr, noisy_audio_real_lens, noisy_audio_max_len,\n output_ptr, BLOCK_SIZE_SUM: tl.constexpr, BLOCK_SIZE_RMS: tl.constexpr):\n batch_idx = tl.program_id(0)\n\n # RMS clean\n clean_audio_real_lens_val = tl.load(clean_audio_real_lens + batch_idx)\n clean_audio_rms = rms_kernel(clean_audio, clean_audio_real_lens, clean_audio_max_len, batch_idx, BLOCK_SIZE_RMS)\n\n # RMS noisy\n noisy_audio_real_lens_val = tl.load(noisy_audio_real_lens + batch_idx)\n\n noisy_audio_rms = rms_kernel(noisy_audio_ptr, noisy_audio_real_lens, noisy_audio_max_len, batch_idx, BLOCK_SIZE_RMS)\n\n # Desired RMS for noisy scale\n desired_rms_val = tl.load(desired_rms + batch_idx)\n relative_rms = clean_audio_rms / tl.math.pow(10.0, desired_rms_val / 20.0)\n\n for offset in range(0, clean_audio_max_len, BLOCK_SIZE_SUM):\n clean_audio_block_ptr = offset + tl.arange(0, BLOCK_SIZE_SUM)\n clean_audio_mask = clean_audio_block_ptr < clean_audio_real_lens_val\n clean_audio_vals = tl.load(\n clean_audio + batch_idx * clean_audio_max_len + clean_audio_block_ptr,\n mask=clean_audio_mask\n )\n\n offset_over_max = offset % noisy_audio_real_lens_val\n\n offset_adjusted = offset_over_max - tl.math.min(\n offset_over_max,\n tl.math.max(0, (offset_over_max + BLOCK_SIZE_SUM) - noisy_audio_real_lens_val)\n )\n\n noisy_audio_block_ptr = offset_adjusted + tl.arange(0, BLOCK_SIZE_SUM)\n\n noisy_audio_val = tl.load(\n noisy_audio_ptr + batch_idx * noisy_audio_max_len + noisy_audio_block_ptr,\n mask=noisy_audio_block_ptr < noisy_audio_real_lens_val\n )\n\n tl.store(\n output_ptr + batch_idx * clean_audio_max_len + clean_audio_block_ptr,\n clean_audio_vals + noisy_audio_val * (relative_rms / noisy_audio_rms),\n mask=clean_audio_mask\n )\n\ndef sum_with_snr_triton(samples: torch.Tensor, samples_lens: torch.Tensor, samples_noise, samples_noise_lens: torch.Tensor, snrs):\n assert samples.is_contiguous() and samples_noise.is_contiguous(), \"Samples must be contiguous\"\n\n B, T = samples.shape\n output = torch.empty_like(samples, device=samples.device, dtype=samples.dtype)\n\n grid = lambda opt: (B,)\n\n sum_with_snr_kernel[grid](\n samples, samples_lens, T, snrs,\n samples_noise, samples_noise_lens, samples_noise.shape[1],\n output, BLOCK_SIZE_RMS=max(1024, triton.next_power_of_2(max(T, samples_noise.shape[1]) // 1024)))\n\n return output\n", - "description_1": "Use triton language to implement two kernels: rms_kernel and sum_with_snr_kernel. The rms_kernel computes the root mean square (RMS) of audio signals. It takes 5 parameters: audios (audio data), audios_real_lens (real lengths of audio), audios_max_len (maximum length of audio), batch_idx (batch index), and BLOCK_SIZE_RMS (block size for RMS computation). The sum_with_snr_kernel adjusts the signal-to-noise ratio (SNR) of audio signals. It takes 9 parameters: clean_audio (clean audio data), clean_audio_real_lens (real lengths of clean audio), clean_audio_max_len (maximum length of clean audio), desired_rms (desired RMS values), noisy_audio_ptr (noisy audio data), noisy_audio_real_lens (real lengths of noisy audio), noisy_audio_max_len (maximum length of noisy audio), output_ptr (output data pointer), BLOCK_SIZE_SUM (block size for sum computation), and BLOCK_SIZE_RMS (block size for RMS computation). The function sum_with_snr_triton is a wrapper that prepares the data and calls the sum_with_snr_kernel.", - "description_2": "Use triton language to compute the RMS of audio signals and adjust their SNR using two kernels: rms_kernel and sum_with_snr_kernel. The rms_kernel calculates the RMS for given audio data, while the sum_with_snr_kernel modifies the audio data to achieve a desired SNR.", - "difficulty": 3 - }, - { - "code": "import triton\nimport triton.language as tl\nimport torch\n\n@triton.jit\ndef kernel_function(input_ptr, output_ptr, BLOCK_SIZE: tl.constexpr):\n # Define the index for the current thread\n idx = tl.arange(0, BLOCK_SIZE) + tl.program_id(0) * BLOCK_SIZE\n\n # Load data from input pointer\n input_data = tl.load(input_ptr + idx)\n\n # Perform some computation (e.g., element-wise addition)\n result = input_data + 1.0\n\n # Store the result back to the output pointer\n tl.store(output_ptr + idx, result)\n\ndef call_kernel(input_tensor, output_tensor):\n # Define the block size\n BLOCK_SIZE = 1024\n\n # Launch the kernel\n grid = lambda meta: (input_tensor.numel() + BLOCK_SIZE - 1) // BLOCK_SIZE\n kernel_function[grid](input_tensor, output_tensor, BLOCK_SIZE)\n\n# Example usage\ninput_tensor = torch.randn(1024, device='cuda')\noutput_tensor = torch.empty_like(input_tensor)\ncall_kernel(input_tensor, output_tensor)\n", - "description_1": "Use triton language to define a kernel that performs element-wise addition on an input tensor. The kernel is decorated with @triton.jit and takes three parameters: input_ptr, output_ptr, and BLOCK_SIZE. The kernel computes the index for each thread, loads data from the input pointer, performs addition, and stores the result in the output pointer. A separate function, call_kernel, is used to launch the kernel with a specified block size and grid configuration.", - "description_2": "Use triton language to create a kernel for element-wise addition on a tensor, and a function to launch this kernel with specified block size and grid.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\nfrom .triton_utils import config_of, signature_to_meta\nfrom ..utils import ceildiv, Placeholder\nfrom ..virtualized import V\nfrom .. import metrics\nfrom .common import IndentedBuffer\nfrom .triton import gen_common_triton_imports\nfrom .triton import TritonKernel\n\nclass ForeachKernel:\n def __init__(self):\n self.blocking_2d = False\n self.block_size_1d = 1024\n self.block_size_2d = 32\n self.num_warps = 8\n self.sub_kernels = []\n self.iter_vars_count = itertools.count()\n self.x_block_count = 0\n self.y_block_count = 0\n\n def get_block_size(self):\n return self.block_size_2d if self.blocking_2d else self.block_size_1d\n\n def create_sub_kernel(self, *groups, index_dtype, mutations, reduction_hint):\n sub_kernel = TritonKernel(\n *groups,\n index_dtype=index_dtype,\n mutations=mutations,\n pid_cache={\n \"tl.program_id(0)\": \"xpid_offset\",\n \"tl.program_id(1)\": \"ypid\",\n },\n reduction_hint=reduction_hint,\n )\n if self.blocking_2d:\n assert len(groups) == 3\n\n self.blocking_2d |= groups[1] != 1 and len(groups) == 3\n metrics.generated_kernel_count -= 1\n sub_kernel.args = self.args\n sub_kernel.iter_vars_count = self.iter_vars_count\n sub_kernel.cse.iter_buffer_ids = self.cse.iter_buffer_ids\n self.sub_kernels.append(sub_kernel)\n return sub_kernel\n\n def jit_lines(self):\n can_use_32bit = all(k.index_dtype == \"tl.int32\" for k in self.sub_kernels)\n size_dtype = \"tl.int32\" if can_use_32bit else \"tl.int64\"\n _, _, signature = self.args.python_argdefs()\n triton_meta = {\n \"signature\": signature_to_meta(signature, size_dtype=size_dtype),\n \"device\": V.graph.scheduler.current_device.index,\n \"device_type\": V.graph.scheduler.current_device.type,\n \"constants\": {},\n }\n triton_meta[\"configs\"] = [config_of(signature)]\n inductor_meta = {\n \"kernel_name\": str(Placeholder.DESCRIPTIVE_NAME),\n \"backend_hash\": torch.utils._triton.triton_hash_with_backend(),\n }\n return f\"\"\"\n @triton_heuristics.foreach(\n num_warps={self.num_warps},\n triton_meta={triton_meta!r},\n inductor_meta={inductor_meta!r},\n )\n @triton.jit\n \"\"\"\n\n def grid(self):\n return (\n self.x_block_count,\n ceildiv(int(self.sub_kernels[0].numels[0]), self.block_size_2d)\n if self.blocking_2d\n else 1,\n 1,\n )\n\n def codegen_kernel(self, name=None):\n code = IndentedBuffer()\n\n code.splice(gen_common_triton_imports())\n argdefs, _, _ = self.args.python_argdefs()\n code.splice(self.jit_lines())\n code.writeline(\n f\"def {name or str(Placeholder.KERNEL_NAME)}({', '.join(argdefs)}):\"\n )\n\n with code.indent():\n code.splice(\"xpid = tl.program_id(0)\")\n if self.blocking_2d:\n code.splice(\"ypid = tl.program_id(1)\")\n code.splice(f\"XBLOCK: tl.constexpr = {self.block_size_2d}\")\n code.splice(f\"YBLOCK: tl.constexpr = {self.block_size_2d}\")\n else:\n code.splice(f\"XBLOCK: tl.constexpr = {self.block_size_1d}\")\n\n for sub_kernel in self.sub_kernels:\n assert len(sub_kernel.numels) <= 3\n numel_ind = 0 if not self.blocking_2d else 1\n self.codegen_pid_range(code, int(sub_kernel.numels[numel_ind]))\n with code.indent():\n if self.blocking_2d:\n code.splice(f\"ynumel = {sub_kernel.numels[0]}\")\n code.splice(f\"xnumel = {sub_kernel.numels[1]}\")\n else:\n code.splice(f\"xnumel = {sub_kernel.numels[0]}\")\n\n sub_kernel.codegen_body()\n code.splice(sub_kernel.body)\n\n code.splice(\"else:\")\n with code.indent():\n code.splice(\"pass\")\n\n return code.getvalue()\n\n def call_kernel(self, code, name: str):\n _, call_args, _ = self.args.python_argdefs()\n for i in range(len(call_args)):\n if V.graph.is_unspec_arg(call_args[i]):\n call_args[i] = call_args[i] + \".item()\"\n if V.graph.cpp_wrapper:\n V.graph.wrapper_code.generate_kernel_call(\n name,\n call_args,\n device_index=V.graph.scheduler.current_device.index,\n grid=self.grid(),\n )\n else:\n call_args_str = \", \".join(call_args)\n stream_name = code.write_get_raw_stream(\n V.graph.scheduler.current_device.index\n )\n code.writeline(\n f\"{name}.run({call_args_str}, grid=({self.grid()}), stream={stream_name})\"\n )\n", - "description_1": "Use triton language to define and manage a triton kernel with configurable parameters such as block size, warps, and grid configuration. It supports 2D blocking and generates kernel code dynamically.", - "description_2": "Use triton language to create and invoke dynamic triton kernels with support for custom block size, grid configuration, and argument management.", - "difficulty": 3 - }, - { - "code": "import triton\nimport triton.language as tl\nimport torch\n\n@triton.jit\ndef add_kernel(X, Y, Z, N):\n pid = tl.program_id(0)\n block_start = pid * 1024\n offsets = block_start + tl.arange(0, 1024)\n mask = offsets < N\n x = tl.load(X + offsets, mask=mask)\n y = tl.load(Y + offsets, mask=mask)\n z = x + y\n tl.store(Z + offsets, z, mask=mask)\n\ndef call_add_kernel(X, Y, Z, N):\n grid = lambda meta: (triton.cdiv(N, 1024),)\n add_kernel[grid](X, Y, Z, N)\n\n# Example usage\nX = torch.randn(1024, device='cuda')\nY = torch.randn(1024, device='cuda')\nZ = torch.empty(1024, device='cuda')\nN = X.numel()\ncall_add_kernel(X, Y, Z, N)\n", - "description_1": "Use triton language to define a kernel 'add_kernel' that takes four parameters: X, Y, Z, and N. X, Y, and Z are pointers to the input and output tensors, and N is the number of elements. The kernel adds corresponding elements of X and Y and stores the result in Z. The kernel is launched with a grid size calculated based on N.", - "description_2": "Use triton language to create a kernel that performs element-wise addition of two input tensors and stores the result in an output tensor, with the number of elements specified as a parameter.", - "difficulty": 2 - }, - { - "code": "import functools\nimport itertools\nimport sympy\nimport torch\nimport triton\n\nfrom .utils import V\nfrom .codegen.triton import gen_common_triton_imports, texpr\nfrom .codegen.triton_utils import config_of, signature_to_meta\n\nclass TritonKernel:\n def __init__(\n self,\n kernel_name,\n input_nodes,\n output_node,\n defines,\n num_stages,\n num_warps,\n grid_fn,\n meta,\n call_sizes,\n use_jit=True,\n prefix_args=0,\n suffix_args=0,\n epilogue_fn=None,\n *,\n index_dtype,\n ):\n self.input_nodes = input_nodes\n self.output_node = output_node\n self.named_input_nodes = {}\n self.defines = defines\n self.kernel_name = kernel_name\n self.template_mask = None\n self.use_jit = use_jit\n self.num_stages = num_stages\n self.num_warps = num_warps\n self.grid_fn = grid_fn\n self.meta = meta\n self.call_sizes = call_sizes\n self.prefix_args = prefix_args\n self.suffix_args = suffix_args\n self.epilogue_fn = epilogue_fn\n self.render_hooks = dict()\n self.triton_meta = None\n\n def jit_lines(self):\n if self.use_jit:\n return \"@triton.jit\"\n # Additional code omitted for brevity\n\n def def_kernel(self, *argnames):\n # Additional code omitted for brevity\n def hook():\n # python_argdefs() cannot be run until after the rest of the template lazily adds more args\n arg_defs, *_ = self.args.python_argdefs()\n code = IndentedBuffer()\n code.splice(gen_common_triton_imports())\n code.splice(self.jit_lines())\n code.writeline(f\"def {self.kernel_name}({', '.join(arg_defs)}):\")\n with code.indent():\n code.splice(self.defines)\n code.splice(renames.getvalue())\n return code.getvalue()\n\n self.render_hooks[\"\"] = hook\n return \"\"\n\n def call_kernel(self, name: str, node: Optional = None):\n wrapper = V.graph.wrapper_code\n _, call_args, _ = self.args.python_argdefs()\n call_args = [str(a) for a in call_args]\n\n for i in range(len(call_args)):\n if V.graph.is_unspec_arg(call_args[i]):\n call_args[i] = call_args[i] + \".item()\"\n if isinstance(call_args[i], sympy.Symbol):\n call_args[i] = texpr(call_args[i])\n\n if V.graph.cpp_wrapper:\n grid_args = [V.graph.sizevars.simplify(s) for s in self.call_sizes] + [self.meta]\n grid = self.grid_fn(*grid_args)\n\n wrapper.generate_kernel_call(\n name,\n call_args,\n device_index=V.graph.scheduler.current_device.index,\n grid=grid,\n triton_meta=self.triton_meta,\n )\n else:\n stream_name = wrapper.write_get_raw_stream(V.graph.scheduler.current_device.index)\n\n wrapper.add_import_once(f\"import {self.grid_fn.__module__}\")\n meta = wrapper.add_meta_once(self.meta)\n\n grid_call = [\n texpr(V.graph.sizevars.simplify(s)) for s in self.call_sizes\n ] + [meta]\n grid_call = f\"{self.grid_fn.__module__}.{self.grid_fn.__name__}({', '.join(grid_call)})\"\n wrapper.writeline(f\"{name}.run({', '.join(call_args)}, grid={grid_call}, stream={stream_name})\")\n\n# Assume a function call_kernel exists to call the triton kernel\ndef call_kernel_function(input_nodes, output_node, grid_fn, meta, call_sizes):\n kernel = TritonKernel(\n kernel_name=\"example_kernel\",\n input_nodes=input_nodes,\n output_node=output_node,\n defines=\"\",\n num_stages=1,\n num_warps=1,\n grid_fn=grid_fn,\n meta=meta,\n call_sizes=call_sizes,\n index_dtype=\"tl.int32\"\n )\n kernel.call_kernel(\"example_kernel\")\n", - "description_1": "Use triton language to define a kernel with input nodes, output node, kernel name, number of stages and warps, grid function, metadata, and call sizes. Apply Triton JIT compilation and execute the kernel using the defined configuration.", - "description_2": "Use triton language to configure a kernel with specific input/output and launch it on a GPU.", - "difficulty": 3 - }, - { - "code": "import triton\nimport triton.language as tl\n\n@triton.jit\ndef promote_to_tensor(x):\n # Addition promotes to tensor for us\n return x + tl.zeros((1,), tl.int1)\n\n@triton.jit\ndef is_floating(x):\n return promote_to_tensor(x).dtype.is_floating()\n\n@triton.jit\ndef _prod_accumulate(a, b):\n return a * b\n\n@triton.jit\ndef prod(input, axis):\n return tl.reduce(input, axis, _prod_accumulate)\n\n@triton.jit\ndef minimum(a, b):\n mask = a < b\n if is_floating(a):\n mask |= a != a\n return tl.where(mask, a, b)\n\n@triton.jit\ndef maximum(a, b):\n mask = a > b\n if is_floating(a):\n mask |= a != a\n return tl.where(mask, a, b)\n\n@triton.jit\ndef min2(a, dim):\n return tl.reduce(a, dim, minimum)\n\n@triton.jit\ndef max2(a, dim):\n return tl.reduce(a, dim, maximum)\n\n@triton.jit\ndef minimum_with_index(a_value, a_index, b_value, b_index):\n mask = a_value < b_value\n equal = a_value == b_value\n if is_floating(a_value):\n a_isnan = a_value != a_value\n b_isnan = b_value != b_value\n mask |= a_isnan and not b_isnan\n # Consider NaNs as equal\n equal |= a_isnan and b_isnan\n\n # Prefer lowest index if values are equal\n mask |= equal & (a_index < b_index)\n return tl.where(mask, a_value, b_value), tl.where(mask, a_index, b_index)\n\n@triton.jit\ndef maximum_with_index(a_value, a_index, b_value, b_index):\n mask = a_value > b_value\n equal = a_value == b_value\n if is_floating(a_value):\n a_isnan = a_value != a_value\n b_isnan = b_value != b_value\n mask |= a_isnan and not b_isnan\n # Consider NaNs as equal\n equal |= a_isnan and b_isnan\n\n # Prefer lowest index if values are equal\n mask |= equal & (a_index < b_index)\n return tl.where(mask, a_value, b_value), tl.where(mask, a_index, b_index)\n\n@triton.jit\ndef min_with_index(value, index, dim):\n return tl.reduce((value, index), dim, minimum_with_index)\n\n@triton.jit\ndef max_with_index(value, index, dim):\n return tl.reduce((value, index), dim, maximum_with_index)\n\n@triton.jit\ndef welford_reduce(value, mean, m2, weight, first_iteration):\n if first_iteration:\n new_weight = tl.full(weight.shape, 1, weight.dtype)\n new_mean = value\n new_m2 = tl.zeros_like(m2)\n else:\n delta = value - mean\n new_weight = weight + 1\n new_mean = mean + delta / new_weight\n new_m2 = m2 + delta * (value - new_mean)\n return new_mean, new_m2, new_weight\n\n@triton.jit\ndef welford_combine(mean_1, m2_1, weight_1, mean_2, m2_2, weight_2):\n delta = mean_2 - mean_1\n new_weight = weight_1 + weight_2\n w2_over_w = tl.where(new_weight == 0.0, 0.0, weight_2 / new_weight)\n return (\n mean_1 + delta * w2_over_w,\n m2_1 + m2_2 + delta * delta * weight_1 * w2_over_w,\n new_weight,\n )\n\n@triton.jit\ndef welford(mean, m2, weight, dim):\n return tl.reduce((mean, m2, weight), dim, welford_combine)\n\n@triton.jit\ndef device_assert_then(cond, msg, r):\n tl.device_assert(cond, msg)\n return r\n\n@triton.jit\ndef randint64(seed, offset, low, high):\n r0, r1, r2, r3 = tl.randint4x(seed, offset)\n r0 = r0.to(tl.uint64)\n r1 = r1.to(tl.uint64)\n result = r0 | (r1 << 32)\n size = high - low\n result = result % size.to(tl.uint64)\n result = result.to(tl.int64) + low\n return result\n\n@triton.jit\ndef _any_combine(a, b):\n return a | b\n\n@triton.jit\ndef any(a, dim):\n return tl.reduce(a, dim, _any_combine)\n\n@triton.jit\ndef bucketize_binary_search(\n values, # 1D tensor\n offsets_ptr,\n indexing_dtype,\n right, # bool: if true, use intervals closed on the left; see [Note: Inductor bucketize op]\n OFFSETS_SIZE: int,\n BLOCK_SHAPE, # tuple/list of block shape\n):\n \"\"\"\n See [Note: Inductor bucketize op]\n \"\"\"\n\n low = tl.zeros(BLOCK_SHAPE, dtype=indexing_dtype)\n high = tl.full(BLOCK_SHAPE, OFFSETS_SIZE, dtype=indexing_dtype)\n\n full_range = OFFSETS_SIZE + 1\n while full_range > 1:\n mid = (high + low) // 2\n mask = mid < OFFSETS_SIZE\n bucket_upper_bound = tl.load(offsets_ptr + mid, mask=mask)\n if right:\n is_above = values >= bucket_upper_bound\n else:\n is_above = values > bucket_upper_bound\n\n low = tl.where(is_above & mask, mid + 1, low)\n high = tl.where(is_above, high, mid)\n\n full_range = (full_range + 1) // 2\n\n return low\n\n@triton.jit\ndef pack_value_flag(\n value,\n flag,\n DTYPE_VALUE_AS_UINT: tl.constexpr,\n DTYPE_PACK: tl.constexpr,\n):\n # Workaround for triton bug, tensor.to doesn't unwrap constexpr values\n DTYPE_VALUE_AS_UINT = tl.core._constexpr_to_value(DTYPE_VALUE_AS_UINT)\n bitwidth = DTYPE_VALUE_AS_UINT.primitive_bitwidth\n uv = value.to(DTYPE_VALUE_AS_UINT, bitcast=True).to(DTYPE_PACK)\n return flag.to(DTYPE_PACK) | (uv << bitwidth)\n\n@triton.jit\ndef unpack_value(\n pack,\n DTYPE_VALUE,\n DTYPE_VALUE_AS_UINT,\n):\n # Workaround for triton bug, tensor.to doesn't unwrap constexpr values\n DTYPE_VALUE = tl.core._constexpr_to_value(DTYPE_VALUE)\n DTYPE_VALUE_AS_UINT = tl.core._constexpr_to_value(DTYPE_VALUE_AS_UINT)\n bitwidth = DTYPE_VALUE_AS_UINT.primitive_bitwidth\n value_uint = (pack >> bitwidth).to(DTYPE_VALUE_AS_UINT)\n return value_uint.to(DTYPE_VALUE, bitcast=True)\n\n@triton.jit\ndef unpack_flag(pack, DTYPE_FLAG):\n return pack.to(DTYPE_FLAG)\n\n@triton.jit\ndef exclusive_scan_decoupled_lookback(\n scratch_base,\n block_value,\n index,\n combine_fn,\n init,\n DTYPE_VALUE_AS_UINT: tl.constexpr,\n DTYPE_PACK: tl.constexpr,\n):\n \"\"\"Compute exclusive scan of a scalar value between blocks\n\n Ref: https://research.nvidia.com/publication/2016-03_single-pass-parallel-prefix-scan-decoupled-look-back\n\n scratch_base: Pointer to scratch space in global memory\n block_value: Scalar value for this block\n index: Scalar index of this block relative to the current scan\n combine_fn: Function ``(value, value) -> value`` which is scanned over\n init: Scalar value equal to the identiy of combine_fn\n DTYPE_VALUE_AS_UINT: A tl.uint{n} type equal in size to ``block_value``\n DTYPE_PACK: Unsigned type twice the width of block_value\n\n NOTE: This function is limited to values which are 32-bits or less.\n \"\"\"\n DTYPE_VALUE = block_value.dtype\n pack = pack_value_flag(\n block_value,\n tl.full(block_value.shape, 1, DTYPE_VALUE_AS_UINT),\n DTYPE_VALUE_AS_UINT,\n DTYPE_PACK,\n )\n tl.atomic_xchg(scratch_base + index, pack, sem=\"relaxed\")\n\n exclusive_prefix = init\n test_target = index - 1\n while test_target >= 0:\n # tl.atomic_load\n flag = tl.full([], 0, DTYPE_VALUE_AS_UINT)\n while flag == 0:\n pack = tl.atomic_add(scratch_base + test_target, 0, sem=\"relaxed\")\n flag = unpack_flag(pack, DTYPE_VALUE_AS_UINT)\n\n value = unpack_value(pack, DTYPE_VALUE, DTYPE_VALUE_AS_UINT)\n exclusive_prefix = combine_fn(value, exclusive_prefix)\n\n if flag == 2:\n test_target = -1\n else:\n test_target = test_target - 1\n\n # Make inclusive block sum visible to other blocks\n inclusive_prefix = combine_fn(exclusive_prefix, block_value)\n pack = pack_value_flag(\n inclusive_prefix,\n tl.full([], 2, DTYPE_VALUE_AS_UINT),\n DTYPE_VALUE_AS_UINT,\n DTYPE_PACK,\n )\n tl.atomic_xchg(scratch_base + index, pack, sem=\"relaxed\")\n return exclusive_prefix\n\n@triton.jit\ndef exclusive_scan_decoupled_lookback_64(\n scratch_base, block_value, index, combine_fn, init\n):\n \"\"\"Compute exclusive scan of a scalar value between blocks\n\n Ref: https://research.nvidia.com/publication/2016-03_single-pass-parallel-prefix-scan-decoupled-look-back\n\n scratch_base: Pointer to scratch space in global memory\n block_value: Scalar value for this block, must be 64-bits wide\n index: Scalar index of this block relative to the current scan\n combine_fn: Function ``(value, value) -> value`` which is scanned over\n init: Scalar value equal to the identiy of combine_fn\n \"\"\"\n block_value_u64 = block_value.to(tl.uint64, bitcast=True)\n tl.store(scratch_base + 3 * index + 1, block_value_u64)\n tl.debug_barrier()\n flag_one = tl.full([], 1, tl.uint64)\n tl.atomic_xchg(scratch_base + 3 * index + 0, flag_one, sem=\"release\")\n\n exclusive_prefix = init\n test_target = index - 1\n while test_target >= 0:\n flag = tl.full([], 0, tl.uint64)\n while flag == 0:\n flag = tl.atomic_add(scratch_base + 3 * test_target + 0, 0, sem=\"acquire\")\n\n value_u64 = tl.load(scratch_base + 3 * test_target + flag.to(tl.int32))\n value = value_u64.to(block_value.dtype, bitcast=True)\n exclusive_prefix = combine_fn(value, exclusive_prefix)\n\n if flag == 2:\n test_target = -1\n else:\n test_target = test_target - 1\n\n # Make inclusive block sum visible to other blocks\n inclusive_prefix = combine_fn(exclusive_prefix, block_value)\n inclusive_prefix_u64 = inclusive_prefix.to(tl.uint64, bitcast=True)\n tl.store(scratch_base + 3 * index + 2, inclusive_prefix_u64)\n tl.debug_barrier()\n flag_two = tl.full([], 2, tl.uint64)\n tl.atomic_xchg(scratch_base + 3 * index + 0, flag_two, sem=\"release\")\n\n return exclusive_prefix\n\n@triton.jit\ndef frexp(x):\n # TODO(isuruf): use inline_asm_elementwise here\n y = libdevice.ilogb(x) + 1\n exponent = tl.where(x == 0, 0, y)\n mantissa = tl.where(x == 0, 0, libdevice.ldexp(x, -y))\n return mantissa, exponent\n", - "description_1": "Use triton language to implement various mathematical and reduction operations, including tensor promotion, floating-point checks, product accumulation, minimum and maximum calculations with and without indices, Welford reduction, random integer generation, and exclusive scan operations. Each function is decorated with @triton.jit and operates on tensors using Triton's language constructs.", - "description_2": "Use triton language to create kernels for mathematical operations and reductions, including min/max, product, and exclusive scans, with support for floating-point checks and random number generation.", - "difficulty": 3 - }, - { - "code": "import triton\nimport triton.language as tl\nimport torch\nimport math\nfrom torch.utils._triton import has_triton\n\nif has_triton():\n @triton.jit\n def _sampled_addmm_kernel(\n alpha,\n beta,\n IS_BETA_ZERO: tl.constexpr,\n BLOCKSIZE_ROW: tl.constexpr,\n BLOCKSIZE_COL: tl.constexpr,\n k,\n TILE_K: tl.constexpr,\n values_ptr,\n values_batch_stride,\n values_nnz_stride,\n values_row_block_stride,\n values_col_block_stride,\n crow_indices_ptr,\n crow_indices_batch_stride,\n crow_indices_stride,\n col_indices_ptr,\n col_indices_batch_stride,\n col_indices_stride,\n mat1_ptr,\n mat1_batch_stride,\n mat1_tiled_row_stride,\n mat1_tiled_col_stride,\n mat1_row_block_stride,\n mat1_col_block_stride,\n mat2_ptr,\n mat2_batch_stride,\n mat2_tiled_row_stride,\n mat2_tiled_col_stride,\n mat2_row_block_stride,\n mat2_col_block_stride,\n acc_dtype: tl.constexpr,\n allow_tf32: tl.constexpr,\n ):\n batch_pid = tl.program_id(axis=1)\n row_block_pid = tl.program_id(axis=0)\n\n crow_indices_offset_ptr = (\n crow_indices_ptr\n + crow_indices_batch_stride * batch_pid\n + crow_indices_stride * row_block_pid\n )\n nnz_offset = tl.load(crow_indices_offset_ptr)\n nnz_offset_next = tl.load(crow_indices_offset_ptr + crow_indices_stride)\n\n # Compute nnz for the row with number row_block_pid.\n # If it is zero, skip the row.\n row_nnz = nnz_offset_next - nnz_offset\n if row_nnz == 0:\n return\n\n row_block_arange = tl.arange(0, BLOCKSIZE_ROW)\n col_block_arange = tl.arange(0, BLOCKSIZE_COL)\n\n # Pointers are set to the first block of the current row.\n values_block_ptrs = (\n values_ptr\n + values_batch_stride * batch_pid\n + values_nnz_stride * nnz_offset\n + values_row_block_stride * row_block_arange[:, None]\n + values_col_block_stride * col_block_arange[None, :]\n )\n\n col_index_nnz_ptr = (\n col_indices_ptr\n + col_indices_batch_stride * batch_pid\n + col_indices_stride * nnz_offset\n )\n\n # Advance mat1 to the current tiled row, ignore columns.\n mat1_block_ptrs = (\n mat1_ptr\n + mat1_batch_stride * batch_pid\n + mat1_tiled_row_stride * row_block_pid\n + mat1_row_block_stride * row_block_arange[:, None]\n )\n\n # Advance mat2 in batch and block col dimension.\n mat2_block_ptrs = (\n mat2_ptr\n + mat2_batch_stride * batch_pid\n + mat2_col_block_stride * col_block_arange[None, :]\n )\n\n k_tile_arange = tl.arange(0, TILE_K)\n for _ in range(row_nnz):\n acc_block = tl.zeros((BLOCKSIZE_ROW, BLOCKSIZE_COL), dtype=acc_dtype)\n\n # find column block index\n col_block = tl.load(col_index_nnz_ptr)\n\n for k_tile in range(0, k, TILE_K):\n k_offsets = k_tile + k_tile_arange\n mask_k = k_offsets < k\n\n mat1_block = tl.load(\n mat1_block_ptrs\n + mat1_col_block_stride * k_offsets[None, :],\n mask=mask_k[None, :], other=0.0\n )\n\n mat2_block = tl.load(\n mat2_block_ptrs\n + mat2_tiled_col_stride * col_block\n + mat2_row_block_stride * k_offsets[:, None],\n mask=mask_k[:, None], other=0.0\n )\n\n acc_block += tl.dot(mat1_block, mat2_block, allow_tf32=allow_tf32, out_dtype=acc_dtype)\n\n if IS_BETA_ZERO:\n acc_block *= alpha\n else:\n acc_block = alpha * acc_block + beta * tl.load(values_block_ptrs)\n\n # write result\n tl.store(values_block_ptrs, acc_block.to(values_ptr.dtype.element_ty))\n\n # advance val/col_index ptrs to the next block in the row.\n values_block_ptrs += values_nnz_stride\n col_index_nnz_ptr += col_indices_stride\n\n def sampled_addmm(\n input: torch.Tensor,\n mat1: torch.Tensor,\n mat2: torch.Tensor,\n *,\n beta=1.0,\n alpha=1.0,\n out: Optional[torch.Tensor] = None,\n skip_checks: bool = False,\n max_grid: Optional[Tuple[Optional[int], Optional[int], Optional[int]]] = None,\n ):\n f_name = \"sampled_addmm\"\n\n check_bsr_layout(f_name, input)\n input_broadcasted = broadcast_batch_dims_bsr(f_name, input, mat1, mat2)\n\n if not skip_checks:\n check_device(f_name, mat1, input.device)\n check_device(f_name, mat2, input.device)\n if beta != 0.0 and input.dtype is torch.bool:\n check(\n False,\n f\"{f_name}(): having beta == {beta} not equal to 0.0 with boolean mask is not allowed.\"\n )\n if input.dtype is not torch.bool:\n check_dtype(f_name, mat1, input.dtype)\n check_dtype(f_name, mat2, input.dtype)\n else:\n check_dtype(f_name, mat1, mat2.dtype)\n check_mm_compatible_shapes(f_name, mat1, mat2)\n if out is not None:\n check_bsr_layout(f_name, out)\n check_device(f_name, out, mat1.device)\n check_dtype(f_name, out, input.dtype)\n check(\n out.shape == input_broadcasted.shape\n and out._nnz() == input._nnz(),\n f\"{f_name}(): Expects `out` to be of shape {input_broadcasted.shape} \"\n f\"and with nnz equal to {input_broadcasted._nnz()} \"\n f\"but got out.shape = {out.shape} and out.nnz = {out._nnz()}\"\n )\n\n if out is None:\n out = input_broadcasted.to(mat1.dtype, copy=True)\n else:\n out.copy_(input_broadcasted)\n\n if out.numel() == 0 or out._nnz() == 0:\n return out\n\n blocksize = out.values().shape[-2:]\n m = mat1.size(-2)\n n = mat2.size(-1)\n k = mat1.size(-1)\n\n # NOTE: (m, 0) @ (0, n) == zeros(m, n)\n if alpha == 0.0 or k == 0:\n out.values().mul_(beta)\n return out\n\n # prepare inputs by reshaping them to be kernel-compatible\n out_backup = out\n crow_indices, col_indices, values, mat1, mat2 = prepare_inputs(out, mat1, mat2)\n\n mat1 = tile_to_blocksize(mat1, (blocksize[0], k))\n mat2 = tile_to_blocksize(mat2, (k, blocksize[1]))\n tile_k = max(*blocksize)\n\n _run_sampled_addmm_kernel(\n alpha, beta, beta == 0.0,\n blocksize, k, tile_k,\n values, crow_indices, col_indices,\n mat1, mat2,\n max_grid\n )\n\n # If nnz x block strides are not the same in out_backup.values and values,\n # it means that out_backup.values and values are not the views of each other,\n # so we have to copy.\n if out_backup.values().stride()[-3:] != values.stride()[-3:]:\n out_backup.values().copy_(values.reshape(out_backup.values().shape))\n return out_backup\n\n def _scaled_dot_product_attention(\n query: torch.Tensor,\n key: torch.Tensor,\n value: torch.Tensor,\n attn_mask: Optional[torch.Tensor],\n dropout_p: float = 0.0,\n is_causal: bool = False,\n scale: Optional[float] = None\n ):\n f_name = \"_scaled_dot_product_attention\"\n check(\n not is_causal,\n f\"{f_name}(): is_causal == True is not supported.\"\n )\n check(\n attn_mask is not None,\n f\"{f_name}(): attn_mask == None is not supported.\"\n )\n assert attn_mask is not None\n\n check(\n attn_mask.layout == torch.sparse_bsr,\n f\"{f_name}(): \"\n f\"attn_mask.layout must be {torch.sparse_bsr}, but got \"\n f\"attn_mask.layout == {attn_mask.layout}.\"\n )\n\n check_device(f_name, key, query.device)\n check_device(f_name, value, query.device)\n check_device(f_name, attn_mask, query.device)\n\n check_dtype(f_name, key, query.dtype)\n check_dtype(f_name, value, query.dtype)\n if attn_mask.dtype is not torch.bool:\n check_dtype(f_name, attn_mask, query.dtype)\n\n sdpa = sampled_addmm(attn_mask, query, key.transpose(-2, -1), beta=0.0, skip_checks=False)\n if scale is None and query.size(-1) == 0 or scale == 0.0:\n check(\n False,\n f\"{f_name}(): current value of scale == {scale} \"\n \"results in division by zero.\"\n )\n scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale\n sdpa.values().mul_(scale_factor)\n sdpa = bsr_softmax(sdpa)\n torch.nn.functional.dropout(sdpa.values(), p=dropout_p, inplace=True)\n sdpa = bsr_dense_mm(sdpa, value)\n return sdpa\n", - "description_1": "Use triton language to implement a sampled matrix multiplication kernel and a scaled dot product attention function. The kernel function '_sampled_addmm_kernel' takes 28 parameters: alpha, beta, IS_BETA_ZERO, BLOCKSIZE_ROW, BLOCKSIZE_COL, k, TILE_K, values_ptr, values_batch_stride, values_nnz_stride, values_row_block_stride, values_col_block_stride, crow_indices_ptr, crow_indices_batch_stride, crow_indices_stride, col_indices_ptr, col_indices_batch_stride, col_indices_stride, mat1_ptr, mat1_batch_stride, mat1_tiled_row_stride, mat1_tiled_col_stride, mat1_row_block_stride, mat1_col_block_stride, mat2_ptr, mat2_batch_stride, mat2_tiled_row_stride, mat2_tiled_col_stride, mat2_row_block_stride, mat2_col_block_stride, acc_dtype, allow_tf32. The function 'sampled_addmm' calls this kernel and takes 8 parameters: input, mat1, mat2, beta, alpha, out, skip_checks, max_grid. The function '_scaled_dot_product_attention' performs scaled dot product attention using the sampled_addmm function and takes 7 parameters: query, key, value, attn_mask, dropout_p, is_causal, scale.", - "description_2": "Use triton language to create a kernel for sampled matrix multiplication and a function for scaled dot product attention. The kernel '_sampled_addmm_kernel' is designed to handle sparse matrix operations efficiently, while the function 'sampled_addmm' manages the setup and execution of this kernel. The '_scaled_dot_product_attention' function leverages 'sampled_addmm' to compute attention scores, apply scaling, and perform dropout, followed by a matrix multiplication with the value tensor.", - "difficulty": 4 - }, - { - "code": "import triton\nfrom triton import language as tl\nfrom triton.language import load, store\n\n# Kernel to add two arrays element-wise\n@triton.jit\ndef add_kernel(\n in_ptr0,\n in_ptr1,\n out_ptr,\n n_elements,\n BLOCK_SIZE: \"tl.constexpr\",\n):\n pid = tl.program_id(axis=0)\n block_start = pid * BLOCK_SIZE\n offsets = block_start + tl.arange(0, BLOCK_SIZE)\n mask = offsets < n_elements\n x = tl.load(in_ptr0 + offsets, mask=mask)\n y = tl.load(in_ptr1 + offsets, mask=mask)\n output = x + y\n tl.store(out_ptr + offsets, output, mask=mask)\n\n# Kernel to add two arrays element-wise with an optional parameter\n@triton.jit\ndef add_kernel_with_optional_param(\n in_ptr0,\n in_ptr1,\n out_ptr,\n n_elements,\n ARGS_PASSED: \"tl.constexpr\",\n BLOCK_SIZE: \"tl.constexpr\",\n):\n pid = tl.program_id(axis=0)\n block_start = pid * BLOCK_SIZE\n offsets = block_start + tl.arange(0, BLOCK_SIZE)\n mask = offsets < n_elements\n x = tl.load(in_ptr0 + offsets, mask=mask)\n if ARGS_PASSED == \"two\":\n y = tl.load(in_ptr1 + offsets, mask=mask)\n output = x + y\n else:\n output = x\n tl.store(out_ptr + offsets, output, mask=mask)\n\n# Autotuned kernel to add two arrays element-wise\n@triton.autotune(\n configs=[\n triton.Config({\"BLOCK_SIZE\": 128}, num_stages=3, num_warps=8),\n triton.Config({\"BLOCK_SIZE\": 128}, num_stages=4, num_warps=4),\n triton.Config({\"BLOCK_SIZE\": 64}, num_stages=3, num_warps=8),\n triton.Config({\"BLOCK_SIZE\": 64}, num_stages=4, num_warps=4),\n ],\n key=[],\n)\n@triton.jit\ndef add_kernel_autotuned(\n in_ptr0,\n in_ptr1,\n out_ptr,\n n_elements,\n BLOCK_SIZE: \"tl.constexpr\",\n):\n pid = tl.program_id(axis=0)\n block_start = pid * BLOCK_SIZE\n offsets = block_start + tl.arange(0, BLOCK_SIZE)\n mask = offsets < n_elements\n x = tl.load(in_ptr0 + offsets, mask=mask)\n y = tl.load(in_ptr1 + offsets, mask=mask)\n output = x + y\n tl.store(out_ptr + offsets, output, mask=mask)\n\n# 2D Autotuned kernel to add two arrays element-wise\n@triton.autotune(\n configs=[\n triton.Config(\n {\"BLOCK_SIZE_X\": 128, \"BLOCK_SIZE_Y\": 128}, num_stages=3, num_warps=8\n ),\n triton.Config(\n {\"BLOCK_SIZE_X\": 128, \"BLOCK_SIZE_Y\": 128}, num_stages=4, num_warps=4\n ),\n triton.Config(\n {\"BLOCK_SIZE_X\": 64, \"BLOCK_SIZE_Y\": 64}, num_stages=3, num_warps=8\n ),\n triton.Config(\n {\"BLOCK_SIZE_X\": 64, \"BLOCK_SIZE_Y\": 64}, num_stages=4, num_warps=4\n ),\n ],\n key=[],\n)\n@triton.jit\ndef add_kernel_2d_autotuned(\n in_ptr0,\n in_ptr1,\n out_ptr,\n x_elements,\n y_elements,\n BLOCK_SIZE_X: \"tl.constexpr\",\n BLOCK_SIZE_Y: \"tl.constexpr\",\n):\n xoffset = tl.program_id(0) * BLOCK_SIZE_X\n xindex = xoffset + tl.arange(0, BLOCK_SIZE_X)[:, None]\n xmask = xindex < x_elements\n yoffset = tl.program_id(1) * BLOCK_SIZE_Y\n yindex = yoffset + tl.arange(0, BLOCK_SIZE_Y)[None, :]\n ymask = yindex < y_elements\n x1 = xindex\n y0 = yindex\n tmp0 = tl.load(in_ptr0 + (x1 + (x_elements * y0)), xmask & ymask)\n tmp1 = tl.load(in_ptr0 + (y0 + (y_elements * x1)), xmask & ymask)\n tmp2 = tmp0 + tmp1\n tl.store(out_ptr + (x1 + (x_elements * y0)), tmp2, xmask & ymask)\n\n# Kernel to multiply an array by 2\n@triton.jit\ndef mul2_kernel(\n in_ptr0,\n out_ptr,\n n_elements,\n BLOCK_SIZE: \"tl.constexpr\",\n):\n pid = tl.program_id(axis=0)\n block_start = pid * BLOCK_SIZE\n offsets = block_start + tl.arange(0, BLOCK_SIZE)\n mask = offsets < n_elements\n x = tl.load(in_ptr0 + offsets, mask=mask)\n output = 2 * x\n tl.store(out_ptr + offsets, output, mask=mask)\n\n# In-place kernel to multiply an array by 2\n@triton.jit\ndef mul2_inplace_kernel(\n ptr,\n n_elements,\n BLOCK_SIZE: \"tl.constexpr\",\n):\n pid = tl.program_id(axis=0)\n block_start = pid * BLOCK_SIZE\n offsets = block_start + tl.arange(0, BLOCK_SIZE)\n mask = offsets < n_elements\n x = tl.load(ptr + offsets, mask=mask)\n output = 2 * x\n tl.store(ptr + offsets, output, mask=mask)\n\n# Kernel with indirection and activation\n@triton.jit\ndef indirection_kernel(\n in_ptr0,\n out_ptr,\n n_elements,\n BLOCK_SIZE: \"tl.constexpr\",\n ACTIVATION: \"tl.constexpr\",\n):\n pid = tl.program_id(axis=0)\n block_start = pid * BLOCK_SIZE\n offsets = block_start + tl.arange(0, BLOCK_SIZE)\n mask = offsets < n_elements\n if ACTIVATION == \"mul2_inplace_kernel\":\n mul2_inplace_kernel(in_ptr0, n_elements, BLOCK_SIZE=BLOCK_SIZE)\n elif ACTIVATION == \"add_kernel\":\n add_kernel(in_ptr0, in_ptr0, out_ptr, n_elements, BLOCK_SIZE=BLOCK_SIZE)\n x = tl.load(in_ptr0 + offsets, mask=mask)\n tl.store(out_ptr + offsets, x, mask=mask)\n\n# Kernel to add two arrays element-wise with import\n@triton.jit\ndef add_kernel_with_import(\n in_ptr0,\n in_ptr1,\n out_ptr,\n n_elements,\n BLOCK_SIZE: \"tl.constexpr\",\n):\n pid = tl.program_id(axis=0)\n block_start = pid * BLOCK_SIZE\n offsets = block_start + tl.arange(0, BLOCK_SIZE)\n mask = offsets < n_elements\n x = load(in_ptr0 + offsets, mask=mask)\n y = load(in_ptr1 + offsets, mask=mask)\n output = x + y\n store(out_ptr + offsets, output, mask=mask)\n", - "description_1": "Use triton language to define multiple kernels for element-wise operations on arrays, including addition, multiplication, and conditional operations. Each kernel is parameterized by pointers to input and output arrays, the number of elements, and block sizes. Some kernels are autotuned for performance.", - "description_2": "Use triton language to create kernels for element-wise addition and multiplication of arrays, with optional parameters and autotuning for performance.", - "difficulty": 4 - }, - { - "code": "import triton\nimport triton.language as tl\nimport torch\n\n@triton.autotune(\n configs=[\n triton.Config(\n {\n \"BLOCK_SIZE_M\": 128,\n \"BLOCK_SIZE_N\": 256,\n \"BLOCK_SIZE_K\": 32,\n \"GROUP_SIZE_M\": 8,\n },\n num_stages=3,\n num_warps=8,\n ),\n triton.Config(\n {\n \"BLOCK_SIZE_M\": 256,\n \"BLOCK_SIZE_N\": 128,\n \"BLOCK_SIZE_K\": 32,\n \"GROUP_SIZE_M\": 8,\n },\n num_stages=3,\n num_warps=8,\n ),\n triton.Config(\n {\n \"BLOCK_SIZE_M\": 256,\n \"BLOCK_SIZE_N\": 64,\n \"BLOCK_SIZE_K\": 32,\n \"GROUP_SIZE_M\": 8,\n },\n num_stages=4,\n num_warps=4,\n ),\n triton.Config(\n {\n \"BLOCK_SIZE_M\": 64,\n \"BLOCK_SIZE_N\": 256,\n \"BLOCK_SIZE_K\": 32,\n \"GROUP_SIZE_M\": 8,\n },\n num_stages=4,\n num_warps=4,\n ),\n triton.Config(\n {\n \"BLOCK_SIZE_M\": 128,\n \"BLOCK_SIZE_N\": 128,\n \"BLOCK_SIZE_K\": 32,\n \"GROUP_SIZE_M\": 8,\n },\n num_stages=4,\n num_warps=4,\n ),\n triton.Config(\n {\n \"BLOCK_SIZE_M\": 128,\n \"BLOCK_SIZE_N\": 64,\n \"BLOCK_SIZE_K\": 32,\n \"GROUP_SIZE_M\": 8,\n },\n num_stages=4,\n num_warps=4,\n ),\n triton.Config(\n {\n \"BLOCK_SIZE_M\": 64,\n \"BLOCK_SIZE_N\": 128,\n \"BLOCK_SIZE_K\": 32,\n \"GROUP_SIZE_M\": 8,\n },\n num_stages=4,\n num_warps=4,\n ),\n triton.Config(\n {\n \"BLOCK_SIZE_M\": 128,\n \"BLOCK_SIZE_N\": 32,\n \"BLOCK_SIZE_K\": 32,\n \"GROUP_SIZE_M\": 8,\n },\n num_stages=4,\n num_warps=4,\n ),\n triton.Config(\n {\n \"BLOCK_SIZE_M\": 64,\n \"BLOCK_SIZE_N\": 32,\n \"BLOCK_SIZE_K\": 32,\n \"GROUP_SIZE_M\": 8,\n },\n num_stages=5,\n num_warps=2,\n ),\n triton.Config(\n {\n \"BLOCK_SIZE_M\": 32,\n \"BLOCK_SIZE_N\": 64,\n \"BLOCK_SIZE_K\": 32,\n \"GROUP_SIZE_M\": 8,\n },\n num_stages=5,\n num_warps=2,\n ),\n ],\n key=[\"M\", \"N\", \"K\"],\n)\n@triton.jit\ndef linear_kernel_4bit_weight(\n a_ptr, b_ptr, c_ptr, bscales_ptr, bzeros_ptr,\n M, N, K,\n stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn,\n BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr\n):\n # Map program ids `pid` to the block of C it should compute.\n pid = tl.program_id(axis=0)\n num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)\n num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)\n num_pid_in_group = GROUP_SIZE_M * num_pid_n\n group_id = pid // num_pid_in_group\n first_pid_m = group_id * GROUP_SIZE_M\n group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)\n pid_m = first_pid_m + (pid % group_size_m)\n pid_n = (pid % num_pid_in_group) // group_size_m\n\n # Create pointers for the first blocks of A and B.\n offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n a_mask = offs_am[:, None] < M\n b_mask = offs_bn[None, :] < N\n offs_k = tl.arange(0, BLOCK_SIZE_K)\n a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)\n b_ptrs = b_ptr + (\n (offs_k[:, None] // 2) * stride_bk + offs_bn[None, :] * stride_bn\n )\n\n bscales_ptrs = bscales_ptr + offs_bn[None, :]\n bzeros_ptrs = bzeros_ptr + offs_bn[None, :]\n\n scale = tl.load(bscales_ptrs)\n zero = tl.load(bzeros_ptrs)\n # Iterate to compute a block of the C matrix\n accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n for k in range(0, K, BLOCK_SIZE_K):\n b12 = tl.load(b_ptrs, mask=b_mask)\n a = tl.load(a_ptrs, mask=a_mask).to(tl.float32)\n b = (\n ((b12.to(tl.uint8) >> ((offs_k[:, None] % 2) * 4)) & 0xF).to(tl.float32)\n - zero\n ) * scale\n accumulator += tl.dot(a, b)\n\n # Advance the ptrs to the next K block\n a_ptrs += BLOCK_SIZE_K * stride_ak\n b_ptrs += (BLOCK_SIZE_K // 2) * stride_bk\n c = accumulator\n\n # Write back the block of the output matrix C\n offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]\n c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)\n tl.store(c_ptrs, c, mask=c_mask)\n\ndef qlinear_4bit_weight(inp, weight, scales, zeros):\n weight = weight.t().contiguous()\n c_shape = inp.shape[:-1] + weight.shape[-1:]\n inp = inp.reshape(-1, inp.shape[-1]).contiguous()\n # we pad the input to amortize triton compilation cost better\n PAD_TO = 256\n if inp.shape[0] % PAD_TO != 0:\n c_crop = inp.shape[0]\n new_inp_shape0 = inp.shape[0] + PAD_TO - inp.shape[0] % PAD_TO\n inp2 = inp.new_empty((new_inp_shape0, inp.shape[1]))\n inp2[: inp.shape[0]] = inp\n inp2[inp.shape[0] :].zero_()\n inp = inp2\n else:\n c_crop = None\n\n assert inp.shape[1] == weight.shape[0] * 2, \"incompatible dimensions\"\n\n assert scales.shape == (weight.shape[1], 1)\n assert zeros.shape == (weight.shape[1], 1)\n scales = scales.contiguous()\n zeros = zeros.contiguous()\n K, N = weight.shape\n M, K = inp.shape\n assert (\n K % 32 == 0\n ), \"We don't check memory-out-of-bounds with K so K must be divisible by BLOCK_SIZE_K\"\n # allocates output\n c = torch.empty((M, N), device=inp.device, dtype=inp.dtype)\n # 1D launch kernel where each block gets its own program.\n grid = lambda META: (\n triton.cdiv(M, META[\"BLOCK_SIZE_M\"]) * triton.cdiv(N, META[\"BLOCK_SIZE_N\"]),\n )\n linear_kernel_4bit_weight[grid](\n inp,\n weight,\n c,\n scales,\n zeros,\n M,\n N,\n K,\n inp.stride(0),\n inp.stride(1),\n weight.stride(0),\n weight.stride(1),\n c.stride(0),\n c.stride(1),\n )\n return c[:c_crop].reshape(c_shape)\n", - "description_1": "Use triton language to implement a 4-bit quantized linear kernel for matrix multiplication. The kernel 'linear_kernel_4bit_weight' takes 17 parameters: pointers to matrices a, b, and c, pointers to scaling factors and zero points, dimensions M, N, K, strides for matrix a, b, and c, and block size and group size meta-parameters. It computes a block of the C matrix resulting from multiplying A and B, where B is transposed, and applies scale and zero-point dequantization to B. The 'qlinear_4bit_weight' function wraps this kernel for higher-level usage, preparing inputs, outputs, and launching the kernel with an appropriate grid.", - "description_2": "Use triton language to create a quantized matrix multiplication kernel with 4-bit weights and a function to manage input/output preparation and kernel execution.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef _cross_entropy_forward(\n logits_ptr, logits_row_stride,\n loss_ptr,\n logsumexp_ptr,\n labels_ptr,\n VOCAB_SIZE : tl.constexpr,\n BLOCK_SIZE : tl.constexpr,\n):\n \"\"\"\n Cross Entropy Loss = 1/n sum [ -yi log(Pi) ]\n Pi = exp(xi) / sum(exp(xi))\n CE_i = -y log(p) = -y log[ exp(x) / sum(exp(x)) ]\n = -y [ x - log[sum(exp(x))] ]\n = y * (log[sum(exp(x))] - x)\n If y == 0: CE_i = 0\n If y == 1: CE_i = logsumexp - x\n\n logsumexp is also stable\n Take y = log[sum(exp(x))]\n exp(y) = sum(exp(x))\n exp(y) = sum(exp(x - c)*exp(c)) Since e^(x-c)*e^c = e^x\n exp(y) = exp(c)*sum(exp(x - c))\n y = log(exp(c)*sum(exp(x - c)))\n y = c + log[sum(exp(x - c))]\n This means we can set c = max(x) to make sure\n exp(x - c) always is exp(x - max(x)).\n This ensures exp(x - max(x))'s maximum is 1 as exp(0) = 1.\n \"\"\"\n row_idx = tl.program_id(0)\n logits_ptr += row_idx * logits_row_stride.to(tl.int64)\n loss_ptr += row_idx\n logsumexp_ptr += row_idx\n labels_ptr += row_idx\n\n col_offsets = tl.arange(0, BLOCK_SIZE)\n mask = col_offsets < VOCAB_SIZE\n\n label_idx = tl.load(labels_ptr).to(tl.int32)\n logits = tl.load(logits_ptr + col_offsets, mask = mask, other = -float(\"inf\")).to(tl.float32)\n c = tl.max(logits, 0)\n logsumexp = c + tl.log(tl.sum(tl.exp(logits - c), 0))\n\n if label_idx != -100:\n x = tl.load(logits_ptr + label_idx).to(tl.float32)\n loss = logsumexp - x\n else:\n loss = 0.0\n tl.store(logsumexp_ptr, logsumexp)\n tl.store(loss_ptr, loss)\npass\n\n\n@triton.jit\ndef _chunked_cross_entropy_forward(\n logits_ptr, logits_row_stride,\n loss_ptr,\n logsumexp_ptr,\n labels_ptr,\n VOCAB_SIZE : tl.constexpr,\n N_CHUNKS : tl.constexpr,\n BLOCK_SIZE : tl.constexpr,\n):\n \"\"\"\n 256K vocab divided in 4 chunks\n\n |-65536-| |-65536-| |-65536-| |-65536-|\n |-------| |-------| |-------| |-------|\n |-------| |-------| |-------| |-------|\n\n If y == 0: CE_i = 0\n If y == 1: CE_i = logsumexp - x\n\n Notice we can do logsumexp for each chunk and then\n logsumexp[chunk_sum(logsumexp)] == logsumexp\n\n chunk_sum = log[chunk_sum(logsumexp)]\n = log[exp(logsumexp(a)) + ... + exp(logsumexp(z))]\n = log[exp(log[sum(exp(a))]) + ... + exp(log[sum(exp(z))])]\n = log[sum(exp(a)) + ... + sum(exp(z))]\n = logsumexp(x)\n\n This means we can perform a logsumexp for each chunk, then do a\n final logsumexp reduction!\n\n Ie do: logsumexp(chunked_logsumexp) - x\n \"\"\"\n row_idx = tl.program_id(0)\n chunk_idx = tl.program_id(1)\n logits_ptr += row_idx * logits_row_stride.to(tl.int64)\n loss_ptr += row_idx\n logsumexp_ptr += row_idx * N_CHUNKS + chunk_idx\n labels_ptr += row_idx\n\n col_offsets = chunk_idx*BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)\n mask = col_offsets < VOCAB_SIZE\n\n label_idx = tl.load(labels_ptr).to(tl.int32)\n logits = tl.load(logits_ptr + col_offsets, mask = mask, other = -float(\"inf\")).to(tl.float32)\n c = tl.max(logits, 0)\n logsumexp = c + tl.log(tl.sum(tl.exp(logits - c), 0))\n\n if chunk_idx == 0:\n # logsumexp(chunked_logsumexp) - x\n # Do the -x separately\n if label_idx != -100:\n x = tl.load(logits_ptr + label_idx).to(tl.float32)\n loss = -1.0 * x\n else:\n loss = 0.0\n tl.store(loss_ptr, loss)\n pass\n tl.store(logsumexp_ptr, logsumexp)\npass\n\n\n@triton.jit\ndef _cross_entropy_backward(\n logits_ptr, logits_row_stride,\n dloss_ptr, dloss_row_stride,\n logsumexp_ptr,\n labels_ptr,\n VOCAB_SIZE : tl.constexpr,\n BLOCK_SIZE : tl.constexpr,\n):\n \"\"\"\n CE_i = -y log(P) = y * (log[sum(exp(x))] - x)\n dC/dx = d/dx (y * log[sum(exp(x))] - x * y)\n\n From https://en.wikipedia.org/wiki/LogSumExp\n d/dx logsumexp = exp(x) / sum(exp(x)) = softmax(x)\n\n dC/dx = y * exp(x) / sum(exp(x)) - d/dx (x * y)\n dC/dx = y * exp[ log[exp(x) / sum(exp(x))] ] using x = exp(log(x)) trick\n dC/dx = y * exp[x - logsumexp] - d/dx (x * y)\n\n If y == 0: dC/dx = 0\n If y == 1 and x == label: dC/dlabel = exp[x - logsumexp] - 1\n If y == 1 and x != label: dC/dx = exp[x - logsumexp]\n \"\"\"\n row_idx = tl.program_id(0)\n block_idx = tl.program_id(1)\n\n logits_ptr += row_idx * logits_row_stride.to(tl.int64)\n dloss_ptr += row_idx * dloss_row_stride\n col_offsets = block_idx*BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)\n mask = col_offsets < VOCAB_SIZE\n label_idx = tl.load(labels_ptr + row_idx).to(tl.int32)\n\n if label_idx != -100:\n dloss = tl.load(dloss_ptr)\n else:\n dloss = 0.0\n\n x = tl.load(logits_ptr + col_offsets, mask = mask, other = -float(\"inf\")).to(tl.float32)\n logsumexp = tl.load(logsumexp_ptr + row_idx)\n y = tl.exp(x - logsumexp)\n y = tl.where(\n col_offsets == label_idx,\n y - 1.0, # exp(x - logsumexp) - 1\n y, # exp(x - logsumexp)\n )\n\n # If y == 0: dC/dx = 0 ==> we already masked it to be = 0, so dloss = 0.\n tl.store(logits_ptr + col_offsets, dloss * y, mask = mask)\npass\n\n\ndef _cross_entropy_forward_impl(logits, labels):\n n_rows, vocab_size = logits.shape\n\n div, mod = divmod(vocab_size, MAX_FUSED_SIZE)\n n_chunks = div + (mod != 0)\n losses = torch.empty(n_rows, dtype = torch.float32, device = \"cuda\")\n\n if n_chunks == 1:\n # For small vocabs <= 65336 like Llama, Mistral\n BLOCK_SIZE, num_warps = calculate_settings(vocab_size)\n logsumexp = torch.empty(n_rows, dtype = torch.float32, device = \"cuda\")\n\n _cross_entropy_forward[(n_rows,)](\n logits, logits.stride(0),\n losses,\n logsumexp,\n labels,\n VOCAB_SIZE = vocab_size,\n BLOCK_SIZE = BLOCK_SIZE,\n num_warps = num_warps,\n )\n else:\n # For large vocabs > 65336 like Gemma 256K\n logsumexp = torch.empty((n_rows, n_chunks,), dtype = torch.float32, device = \"cuda\")\n\n _chunked_cross_entropy_forward[(n_rows, n_chunks,)](\n logits, logits.stride(0),\n losses,\n logsumexp,\n labels,\n VOCAB_SIZE = vocab_size,\n N_CHUNKS = n_chunks,\n BLOCK_SIZE = MAX_FUSED_SIZE,\n num_warps = 32,\n )\n # logsumexp(chunked_logsumexp) - x\n # Do the -x separately\n logsumexp = torch.logsumexp(logsumexp, dim = 1) # Row sum\n losses += logsumexp\n losses.masked_fill_(labels == -100, 0) # Don't forget to mask padding out!\n\n return losses, logsumexp\n\n\ndef _cross_entropy_backward_impl(dlosses, logits, logsumexp, labels):\n n_rows, vocab_size = logits.shape\n\n BLOCK_SIZE = 4096\n div, mod = divmod(vocab_size, BLOCK_SIZE)\n n_blocks = div + (mod != 0)\n\n _cross_entropy_backward[(n_rows, n_blocks,)](\n logits, logits.stride(0),\n dlosses, dlosses.stride(0),\n logsumexp,\n labels,\n VOCAB_SIZE = vocab_size,\n BLOCK_SIZE = BLOCK_SIZE,\n num_warps = 8,\n )\n return logits\n", - "description_1": "Use triton language to implement cross entropy forward and backward kernels. The forward kernel (_cross_entropy_forward) computes the cross entropy loss and logsumexp for each row of logits, given the logits, labels, and other parameters. The chunked version (_chunked_cross_entropy_forward) handles large vocabularies by dividing the computation into chunks. The backward kernel (_cross_entropy_backward) computes the gradient of the cross entropy loss with respect to the logits. The forward and backward implementations (_cross_entropy_forward_impl and _cross_entropy_backward_impl) handle the logic for choosing the appropriate kernel and managing the data.", - "description_2": "Use triton language to create kernels for computing cross entropy loss and its gradient, handling both small and large vocabularies efficiently.", - "difficulty": 4 - }, - { - "code": "import triton\nimport triton.language as tl\n\nROPE_GROUP_SIZE = 4\n\n@triton.jit\ndef _rope_embedding(\n Q, Q_row_stride,\n cos, cos_row_stride,\n sin, sin_row_stride,\n seqlen,\n head_dim : tl.constexpr,\n n_heads : tl.constexpr,\n BACKWARD_PASS : tl.constexpr,\n BLOCK_SIZE : tl.constexpr,\n ROPE_GROUP_SIZE : tl.constexpr = 4,\n):\n \"\"\"\n Calculates the RoPE Embedding quickly\n RoPE is Q * cos + rotate_half(Q) * sin\n See our blog post for more info\n \"\"\"\n row_position = tl.program_id(0)\n group_head_position = tl.program_id(1)\n col_offsets = tl.arange(0, BLOCK_SIZE)\n half_head_dim = head_dim // 2\n mask = col_offsets < half_head_dim\n\n sin1 = tl.load(sin + (row_position % seqlen)*sin_row_stride + \\\n half_head_dim*0 + col_offsets, mask = mask, other = 0)\n cos1 = tl.load(cos + (row_position % seqlen)*cos_row_stride + \\\n half_head_dim*0 + col_offsets, mask = mask, other = 0)\n\n if BACKWARD_PASS:\n # See our blog post for more info.\n sin1 = -sin1\n\n head_start = group_head_position * ROPE_GROUP_SIZE\n head_end = min((head_start + ROPE_GROUP_SIZE), n_heads)\n\n for k in range(head_start, head_end):\n offs_q1 = row_position * Q_row_stride + k * head_dim + col_offsets\n offs_q2 = row_position * Q_row_stride + k * head_dim + col_offsets + half_head_dim\n\n Q1 = tl.load(Q + offs_q1, mask = mask, other = 0).to(sin1.dtype)\n Q2 = tl.load(Q + offs_q2, mask = mask, other = 0).to(sin1.dtype)\n\n tl.store(Q + offs_q1, Q1*cos1 - Q2*sin1, mask = mask)\n tl.store(Q + offs_q2, Q2*cos1 + Q1*sin1, mask = mask)\n\ndef _rope_embedding_forward_impl(Q, cos, sin):\n Q = Q.transpose(1, 2).clone()\n cos, sin = cos.squeeze(), sin.squeeze()\n batch, seq_len, n_heads, head_dim = Q.shape\n Q = Q.reshape(batch*seq_len, n_heads*head_dim)\n n_rows, n_cols = Q.shape\n assert(seq_len <= cos.shape[0])\n\n BLOCK_SIZE, num_warps = calculate_settings(head_dim//2)\n\n div, mod = divmod(n_heads, ROPE_GROUP_SIZE)\n n_groups = div + (mod != 0)\n\n _rope_embedding[(n_rows, n_groups, )](\n Q, Q.stride(0),\n cos, cos.stride(0),\n sin, sin.stride(0),\n seq_len,\n head_dim, n_heads,\n BACKWARD_PASS = False,\n BLOCK_SIZE = BLOCK_SIZE,\n num_warps = num_warps,\n )\n Q = Q.view(batch, seq_len, n_heads, head_dim)\n Q = Q.transpose(1, 2)\n return Q, cos, sin, n_groups, BLOCK_SIZE, num_warps\n\ndef _rope_embedding_backward_impl(dY, cos, sin, n_groups, BLOCK_SIZE, num_warps):\n dY = dY.transpose(1, 2)\n batch, seq_len, n_heads, head_dim = dY.shape\n dY = dY.reshape(batch*seq_len, n_heads*head_dim)\n n_rows, n_cols = dY.shape\n\n _rope_embedding[(n_rows, n_groups, )](\n dY, dY .stride(0),\n cos, cos.stride(0),\n sin, sin.stride(0),\n seq_len, head_dim, n_heads,\n BACKWARD_PASS = True,\n BLOCK_SIZE = BLOCK_SIZE,\n num_warps = num_warps,\n )\n dY = dY.view(batch, seq_len, n_heads, head_dim)\n dY = dY.transpose(1, 2)\n return dY\n", - "description_1": "Use triton language to implement a RoPE embedding kernel that computes the rotary position embedding for input tensor Q using cosine and sine values. The kernel takes 11 parameters: Q, Q_row_stride, cos, cos_row_stride, sin, sin_row_stride, seqlen, head_dim, n_heads, BACKWARD_PASS, BLOCK_SIZE, and an optional ROPE_GROUP_SIZE. The forward and backward implementations reshape and transpose the input tensors, calculate the number of groups, and launch the kernel with appropriate grid dimensions.", - "description_2": "Use triton language to create a kernel for rotary position embedding with forward and backward implementations, handling input reshaping and kernel launching.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n# Kernel to compute f = e * sigmoid(e) and h = f * g\n@triton.jit\ndef _fg_kernel(e, g, h, n_elements, BLOCK_SIZE: tl.constexpr):\n block_idx = tl.program_id(0)\n offsets = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)\n mask = offsets < n_elements\n\n e_row = tl.load(e + offsets, mask=mask, other=0).to(tl.float32)\n g_row = tl.load(g + offsets, mask=mask, other=0)\n\n f_row = e_row * tl.sigmoid(e_row)\n f_row = f_row.to(g_row.dtype)\n h_row = f_row * g_row\n\n tl.store(h + offsets, h_row, mask=mask)\n\n# Function to launch the _fg_kernel\ndef swiglu_fg_kernel(e, g):\n batch, seq_len, hd = e.shape\n n_elements = e.numel()\n h = torch.empty((batch, seq_len, hd), dtype=e.dtype, device=\"cuda\")\n grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)\n _fg_kernel[grid](e, g, h, n_elements, BLOCK_SIZE=1024)\n return h\n\n# Kernel to compute derivatives for backpropagation\n@triton.jit\ndef _DWf_DW_dfg_kernel(DW, e, g, n_elements, BLOCK_SIZE: tl.constexpr):\n block_idx = tl.program_id(0)\n offsets = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)\n mask = offsets < n_elements\n\n DW_row = tl.load(DW + offsets, mask=mask, other=0)\n e_row = tl.load(e + offsets, mask=mask, other=0).to(tl.float32)\n g_row = tl.load(g + offsets, mask=mask, other=0)\n\n se_row = tl.sigmoid(e_row)\n f_row = se_row * e_row\n f_row = f_row.to(DW_row.dtype)\n h_row = f_row * g_row\n df_row = DW_row * f_row\n dg_row = DW_row * g_row\n de_row = dg_row.to(tl.float32) * se_row * (1.0 + e_row * (1.0 - se_row))\n de_row = de_row.to(DW_row.dtype)\n\n tl.store(DW + offsets, h_row, mask=mask)\n tl.store(e + offsets, df_row, mask=mask)\n tl.store(g + offsets, de_row, mask=mask)\n\n# Function to launch the _DWf_DW_dfg_kernel\ndef swiglu_DWf_DW_dfg_kernel(DW, e, g):\n batch_seq_len, hd = e.shape\n n_elements = e.numel()\n grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)\n _DWf_DW_dfg_kernel[grid](DW, e, g, n_elements, BLOCK_SIZE=1024)\n return DW, e, g\n", - "description_1": "Use triton language to implement two kernels: one for computing element-wise operations involving sigmoid and multiplication, and another for computing derivatives for backpropagation. The first kernel (_fg_kernel) takes 5 parameters: e (input tensor), g (input tensor), h (output tensor), n_elements (number of elements to process), and BLOCK_SIZE (block size for parallel execution). The second kernel (_DWf_DW_dfg_kernel) takes the same number of parameters but operates on DW (input tensor for derivatives), e, g, n_elements, and BLOCK_SIZE. Both kernels are launched using their respective wrapper functions swiglu_fg_kernel and swiglu_DWf_DW_dfg_kernel.", - "description_2": "Use triton language to create kernels for element-wise operations with sigmoid and multiplication, and for computing derivatives in backpropagation, each with 5 parameters including input tensors and block size.", - "difficulty": 3 - }, - { - "code": "import math\nimport torch\nimport triton\nimport triton.language as tl\n\n# This is a matmul kernel based on triton.ops.matmul\n# It is modified to support rowwise quantized input and columnwise quantized weight\n# It's purpose is fused matmul then dequantize\n# It does support bias.\n\ndef init_to_zero(name):\n return lambda nargs: nargs[name].zero_()\n\ndef get_configs_io_bound():\n configs = []\n for num_stages in [2, 3, 4, 5, 6]:\n for block_m in [16, 32]:\n for block_k in [32, 64]:\n for block_n in [32, 64, 128, 256]:\n num_warps = 2 if block_n <= 64 else 4\n configs.append(\n triton.Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': 1},\n num_stages=num_stages, num_warps=num_warps))\n # split_k\n for split_k in [2, 4, 8, 16]:\n configs.append(triton.Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': split_k},\n num_stages=num_stages, num_warps=num_warps, pre_hook=init_to_zero('C')))\n return configs\n\n@triton.autotune(\n configs=[\n # basic configs for compute-bound matmuls\n triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8),\n triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8),\n triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=5, num_warps=2),\n *get_configs_io_bound(),\n ],\n key=['M', 'N', 'K'],\n prune_configs_by={\n 'early_config_prune': early_config_prune,\n 'perf_model': estimate_matmul_time,\n 'top_k': 10\n },\n)\n@triton.heuristics({\n 'EVEN_K': lambda args: args['K'] % (args['BLOCK_K'] * args['SPLIT_K']) == 0,\n})\n@triton.jit\ndef _int8_matmul_rowwise_dequantize(A, B, C, bias, state_x_ptr, state_w_ptr, M, N, K, divfactor, has_bias : tl.constexpr,\n stride_am, stride_ak,\n stride_bk, stride_bn,\n stride_cm, stride_cn,\n BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,\n GROUP_M: tl.constexpr, SPLIT_K: tl.constexpr, EVEN_K: tl.constexpr,\n ACC_TYPE: tl.constexpr\n ):\n # matrix multiplication\n pid = tl.program_id(0)\n pid_z = tl.program_id(1)\n grid_m = tl.cdiv(M, BLOCK_M)\n grid_n = tl.cdiv(N, BLOCK_N)\n # re-order program ID for better L2 performance\n width = GROUP_M * grid_n\n group_id = pid // width\n group_size = min(grid_m - group_id * GROUP_M, GROUP_M)\n pid_m = group_id * GROUP_M + (pid % group_size)\n pid_n = (pid % width) // (group_size)\n # do matrix multiplication\n rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)\n rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)\n ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)\n rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)\n rk = pid_z * BLOCK_K + tl.arange(0, BLOCK_K)\n # pointers\n A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)\n B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)\n\n # rematerialize rm and rn to save registers\n rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)\n rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)\n\n w_factor = tl.load(state_w_ptr + rbn)[None, :]\n x_factor = tl.load(state_x_ptr + ram)[:, None]\n\n # acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)\n acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.int32)\n for k in range(0, tl.cdiv(K, BLOCK_K * SPLIT_K)):\n if EVEN_K:\n a = tl.load(A)\n b = tl.load(B)\n else:\n k_remaining = K - k * (BLOCK_K * SPLIT_K)\n a = tl.load(A, mask=rk[None, :] < k_remaining, other=0.)\n b = tl.load(B, mask=rk[:, None] < k_remaining, other=0.)\n acc += tl.dot(a, b)\n A += BLOCK_K * SPLIT_K * stride_ak\n B += BLOCK_K * SPLIT_K * stride_bk\n\n acc = (w_factor * (x_factor * (acc * divfactor)))\n acc = acc.to(C.dtype.element_ty)\n\n if has_bias:\n bias = tl.load(bias + rn).to(C.dtype.element_ty)\n acc = acc + bias[None, :]\n\n C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)\n mask = (rm < M)[:, None] & (rn < N)[None, :]\n # handles write-back with reduction-splitting\n if SPLIT_K == 1:\n tl.store(C, acc, mask=mask)\n else:\n tl.atomic_add(C, acc, mask=mask)\n\ndef int8_matmul_rowwise_dequantize(a, b, state_x, state_w, bias):\n divfactor = 1. / (127. * 127.)\n\n has_bias = 0 if bias is None else 1\n\n device = a.device\n # handle non-contiguous inputs if necessary\n if a.stride(0) > 1 and a.stride(1) > 1:\n a = a.contiguous()\n if b.stride(0) > 1 and b.stride(1) > 1:\n b = b.contiguous()\n # checks constraints\n assert a.shape[1] == b.shape[0], \"incompatible dimensions\"\n M, K = a.shape\n _, N = b.shape\n # allocates output\n c = torch.empty((M, N), device=device, dtype=torch.float16)\n # accumulator types\n ACC_TYPE = tl.float32 #if a.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32\n # launch int8_matmul_rowwise_dequantize kernel\n grid = lambda META: (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), META['SPLIT_K'])\n _int8_matmul_rowwise_dequantize[grid](a, b, c, bias, state_x, state_w, M, N, K, divfactor, has_bias,\n a.stride(0), a.stride(1),\n b.stride(0), b.stride(1),\n c.stride(0), c.stride(1),\n GROUP_M=8, ACC_TYPE=ACC_TYPE)\n return c\n\n@triton.autotune(\n configs=[\n triton.Config({}, num_stages=1, num_warps=8),\n triton.Config({}, num_stages=2, num_warps=8),\n triton.Config({}, num_stages=4, num_warps=8),\n triton.Config({}, num_stages=8, num_warps=8),\n triton.Config({}, num_stages=1),\n triton.Config({}, num_stages=2),\n triton.Config({}, num_stages=4),\n triton.Config({}, num_stages=8),\n triton.Config({}, num_warps=1),\n triton.Config({}, num_warps=2),\n triton.Config({}, num_warps=4),\n triton.Config({}, num_warps=8),\n ],\n key=['n_elements']\n)\n@triton.jit\ndef _quantize_rowwise(\n x_ptr,\n output_ptr,\n output_maxs,\n n_elements,\n BLOCK_SIZE: tl.constexpr,\n P2: tl.constexpr,\n):\n pid = tl.program_id(axis=0)\n block_start = pid * BLOCK_SIZE\n arange = tl.arange(0, P2)\n offsets = block_start + arange\n row_mask = arange < BLOCK_SIZE\n x = tl.load(x_ptr + offsets, mask=row_mask)\n\n abs_x = tl.abs(x)\n max_val = tl.max(tl.where(row_mask, abs_x, 0), axis=0)\n output = tl.math.llrint(127. * (x / max_val))\n tl.store(output_ptr + offsets, output, mask=row_mask)\n tl.store(output_maxs + pid, max_val)\n\ndef quantize_rowwise(x: torch.Tensor):\n output = torch.empty(*x.shape, device=x.device, dtype=torch.int8)\n output_maxs = torch.empty(x.shape[0], device=x.device, dtype=torch.float16)\n\n P2 = int(2 ** (math.ceil(math.log2(x.shape[1]))))\n\n assert x.is_cuda and output.is_cuda\n n_elements = output.numel()\n grid = lambda meta: (x.shape[0],)\n _quantize_rowwise[grid](x, output, output_maxs, n_elements, BLOCK_SIZE=x.shape[1], P2=P2)\n return output, output_maxs\n\ndef matmul(a, b, state_x=None, state_w=None, bias=None):\n if state_x is None:\n a, state_x = quantize_rowwise(a)\n if state_w is None:\n b, state_w = quantize_rowwise(b)\n return int8_matmul_rowwise_dequantize(a, b, state_x, state_w, None)\n", - "description_1": "Use triton language to implement a fused matrix multiplication and dequantization kernel for int8 rowwise quantized inputs and columnwise quantized weights. The kernel supports optional bias addition. The main kernel function '_int8_matmul_rowwise_dequantize' takes 22 parameters: 3 input matrices (A, B, C), a bias vector, two state pointers (state_x_ptr, state_w_ptr), 3 dimensions (M, N, K), a division factor, a bias flag, 6 stride values, and 7 compile-time constants. The auxiliary function 'int8_matmul_rowwise_dequantize' prepares inputs, checks constraints, allocates output, and launches the kernel. Additionally, a rowwise quantization kernel '_quantize_rowwise' is implemented, which takes 6 parameters: input tensor pointer, output tensor pointer, output max values, number of elements, and 2 compile-time constants. The function 'quantize_rowwise' prepares inputs and launches the quantization kernel.", - "description_2": "Use triton language to create a fused int8 matrix multiplication and dequantization kernel with optional bias, and a rowwise quantization kernel. Implement functions to prepare inputs, check constraints, allocate outputs, and launch these kernels.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef fwd_sequential_scan_complex(\n v_real, # Real part of input tensor\n v_imag, # Imaginary part of input tensor\n decay_real, # Real part of decay factor\n decay_imag, # Imaginary part of decay factor\n hidden_real, # Real part of hidden state\n hidden_imag, # Imaginary part of hidden state\n B, # Batch size\n L, # Sequence length\n C, # Hidden dimension size\n BLOCK_M: tl.constexpr, # Block size in the M dimension\n):\n offset_b = tl.program_id(0)\n if offset_b >= B:\n return\n offset_n = tl.program_id(1)\n ptr = tl.arange(0, BLOCK_M) + offset_b * L * C + offset_n * BLOCK_M\n h_real = tl.zeros([BLOCK_M,], dtype=tl.float32)\n h_imag = tl.zeros([BLOCK_M,], dtype=tl.float32)\n\n for _ in range(L):\n x_real = tl.load(v_real + ptr).to(tl.float32)\n x_imag = tl.load(v_imag + ptr).to(tl.float32)\n f_real = tl.load(decay_real + ptr).to(tl.float32)\n f_imag = tl.load(decay_imag + ptr).to(tl.float32)\n h_real_new = h_real * f_real - h_imag * f_imag + x_real\n h_imag_new = h_real * f_imag + h_imag * f_real + x_imag\n tl.store(hidden_real + ptr, h_real_new.to(hidden_real.dtype.element_ty))\n tl.store(hidden_imag + ptr, h_imag_new.to(hidden_imag.dtype.element_ty))\n h_real = h_real_new\n h_imag = h_imag_new\n ptr += C\n\n@triton.jit\ndef bwd_sequential_scan_complex(\n grad_output_real, # Real part of the gradient of output\n grad_output_imag, # Imaginary part of the gradient of output\n v_real, # Real part of input tensor\n v_imag, # Imaginary part of input tensor\n f_real, # Real part of decay factor\n f_imag, # Imaginary part of decay factor\n hidden_real, # Real part of hidden state\n hidden_imag, # Imaginary part of hidden state\n B, # Batch size\n L, # Sequence length\n C, # Hidden dimension size\n BLOCK_M: tl.constexpr, # Block size in the M dimension\n):\n offset_b = tl.program_id(0)\n if offset_b >= B:\n return\n offset_n = tl.program_id(1)\n ptr = tl.arange(0, BLOCK_M) + offset_b * L * C + (L-1) * C + offset_n * BLOCK_M\n grad_h_real = tl.zeros([BLOCK_M,], dtype=tl.float32)\n grad_h_imag = tl.zeros([BLOCK_M,], dtype=tl.float32)\n\n for time_step in range(L-1, -1, -1):\n grad_real = tl.load(grad_output_real + ptr).to(tl.float32)\n grad_imag = tl.load(grad_output_imag + ptr).to(tl.float32)\n grad_h_real += grad_real\n grad_h_imag += grad_imag\n decay_real = tl.load(f_real + ptr).to(tl.float32)\n decay_imag = tl.load(f_imag + ptr).to(tl.float32)\n h_real = tl.load(hidden_real + ptr - C, mask= ptr >= (offset_b * L * C + C), other=0.0).to(tl.float32)\n h_imag = tl.load(hidden_imag + ptr - C, mask= ptr >= (offset_b * L * C + C), other=0.0).to(tl.float32)\n grad_f_real = (grad_h_real * h_real + grad_h_imag * h_imag)\n grad_f_imag = (grad_h_imag * h_real - grad_h_real * h_imag)\n tl.store(f_real + ptr, grad_f_real.to(f_real.dtype.element_ty))\n tl.store(f_imag + ptr, grad_f_imag.to(f_real.dtype.element_ty))\n tl.store(v_real + ptr, grad_h_real.to(v_real.dtype.element_ty))\n tl.store(v_imag + ptr, grad_h_imag.to(v_real.dtype.element_ty))\n grad_h_real_new = grad_h_real * decay_real + grad_h_imag * decay_imag\n grad_h_imag_new = grad_h_imag * decay_real - grad_h_real * decay_imag\n grad_h_real = grad_h_real_new\n grad_h_imag = grad_h_imag_new\n ptr -= C\n\nclass TritonSequentialScan_Complex(Function):\n @staticmethod\n @torch.cuda.amp.custom_fwd\n def forward(ctx, v_real, v_imag, f_real, f_imag):\n B, L, C = v_real.shape\n num_warps = 8\n assert C % 256 == 0, 'Hidden dimension must be multiple of 256'\n v_real = v_real.contiguous()\n v_imag = v_imag.contiguous()\n f_real = f_real.contiguous()\n f_imag = f_imag.contiguous()\n hidden_real = torch.zeros_like(v_real).contiguous()\n hidden_imag = torch.zeros_like(v_imag).contiguous()\n fwd_sequential_scan_complex[(B, int(C/256))](\n v_real,\n v_imag,\n f_real,\n f_imag,\n hidden_real,\n hidden_imag,\n B,\n L,\n C,\n BLOCK_M=256,\n num_warps=num_warps\n )\n ctx.save_for_backward(v_real, v_imag, f_real, f_imag, hidden_real, hidden_imag)\n return hidden_real, hidden_imag\n\n @staticmethod\n @torch.cuda.amp.custom_bwd\n def backward(ctx, grad_output_real, grad_output_imag):\n v_real, v_imag, f_real, f_imag, hidden_real, hidden_imag = ctx.saved_tensors\n B, L, C = v_real.shape\n num_warps = 8\n bwd_sequential_scan_complex[(B, int(C/256))](\n grad_output_real,\n grad_output_imag,\n v_real,\n v_imag,\n f_real,\n f_imag,\n hidden_real,\n hidden_imag,\n B,\n L,\n C,\n BLOCK_M=256,\n num_warps=num_warps\n )\n return v_real, v_imag, f_real, f_imag\n\ncomplex_scan = TritonSequentialScan_Complex.apply\n", - "description_1": "Use triton language to implement forward and backward pass kernels for a sequential scan on complex-valued data. The forward kernel computes new hidden states based on input complex vectors and decay factors, iterating over the sequence length. The backward kernel computes gradients with respect to inputs and decay factors, iterating in reverse. Both kernels are executed with blocks of threads determined by batch size, sequence length, and hidden dimension.", - "description_2": "Use triton language to create forward and backward kernels for sequentially scanning complex data, updating hidden states, and computing gradients.", - "difficulty": 4 - }, - { - "code": "import triton\nimport triton.language as tl\nimport torch\n\n@triton.jit\ndef _fwd_kernel_copy_kv_index_to_req(\n req_to_token_indexs, b_req_idx, b_seq_len, memindex,\n stride_req_to_token_b, stride_req_to_token_s\n):\n cur_index = tl.program_id(0)\n cur_req_idx = tl.load(b_req_idx + cur_index)\n cur_token_index = tl.load(memindex + cur_index)\n cur_seq_len = tl.load(b_seq_len + cur_index)\n dest_offset = req_to_token_indexs + cur_req_idx * stride_req_to_token_b + (cur_seq_len - 1) * stride_req_to_token_s\n tl.store(dest_offset, cur_token_index)\n return\n\n@torch.inference_mode()\ndef copy_kv_index_to_req(req_to_token_indexs, b_req_idx, b_seq_len, memindex):\n \"\"\"\n Copy indices of newly allocated K/V slots to req_to_token_indexs, will be\n invoked in the decoding stage.\n \"\"\"\n seq_len = b_seq_len.shape[0]\n assert b_seq_len.shape[0] == memindex.shape[0] and b_req_idx.shape[0] == b_seq_len.shape[0]\n grid = (seq_len,)\n num_warps = 1\n\n _fwd_kernel_copy_kv_index_to_req[grid](\n req_to_token_indexs, b_req_idx, b_seq_len, memindex,\n req_to_token_indexs.stride(0), req_to_token_indexs.stride(1),\n num_warps=num_warps,\n num_stages=1,\n )\n return\n", - "description_1": "Use triton language to implement a kernel function '_fwd_kernel_copy_kv_index_to_req' that copies token indices from 'memindex' to 'req_to_token_indexs'. The function takes six parameters: 'req_to_token_indexs' (destination for token indices), 'b_req_idx' (batch request indices), 'b_seq_len' (batch sequence lengths), 'memindex' (memory indices), 'stride_req_to_token_b' and 'stride_req_to_token_s' (strides for accessing 'req_to_token_indexs'). The kernel uses 'tl.program_id' to parallelize over the first dimension, loading values from global memory using 'tl.load', computing destination offsets, and storing results using 'tl.store'. Additionally, a wrapper function 'copy_kv_index_to_req' is provided to set up the execution configuration for the kernel, including the grid size and number of warps.", - "description_2": "Use triton language to create a kernel that copies indices based on input strides and uses a wrapper to configure its execution parameters.", - "difficulty": 2 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n# Kernel for copying data based on destination index\n@triton.jit\ndef _fwd_kernel_destindex_copy_kv(\n K, Dest_loc,\n Out,\n stride_k_bs, stride_k_h, stride_k_d,\n stride_o_bs, stride_o_h, stride_o_d,\n head_num,\n BLOCK_DMODEL: tl.constexpr,\n BLOCK_HEAD: tl.constexpr\n):\n cur_index = tl.program_id(0)\n offs_h = tl.arange(0, BLOCK_HEAD)\n offs_d = tl.arange(0, BLOCK_DMODEL)\n\n dest_index = tl.load(Dest_loc + cur_index)\n\n k_ptrs = K + cur_index * stride_k_bs + stride_k_h * offs_h[:, None] + stride_k_d * offs_d[None, :]\n o_ptrs = Out + dest_index * stride_o_bs + stride_o_h * offs_h[:, None] + stride_o_d * offs_d[None, :]\n\n k = tl.load(k_ptrs, mask=offs_h[:, None] < head_num, other=0.0)\n tl.store(o_ptrs, k, mask=offs_h[:, None] < head_num)\n return\n\n# Function to invoke the kernel for copying data\n@torch.inference_mode()\ndef destindex_copy_kv(K, DestLoc, Out):\n seq_len = DestLoc.shape[0]\n head_num = K.shape[1]\n head_dim = K.shape[2]\n assert K.shape[1] == Out.shape[1] and K.shape[2] == Out.shape[2]\n BLOCK_HEAD = triton.next_power_of_2(head_num)\n grid = (seq_len,)\n num_warps = 1\n\n _fwd_kernel_destindex_copy_kv[grid](\n K, DestLoc, Out,\n K.stride(0), K.stride(1), K.stride(2),\n Out.stride(0), Out.stride(1), Out.stride(2),\n head_num,\n BLOCK_DMODEL=head_dim,\n BLOCK_HEAD=BLOCK_HEAD,\n num_warps=num_warps,\n num_stages=1,\n )\n return\n\n# Kernel for copying and quantizing data based on destination index\n@triton.jit\ndef _fwd_kernel_destindex_copy_quantize_kv(\n K, Dest_loc, Out, Out_scale,\n stride_k_bs, stride_k_h, stride_k_d,\n stride_o_bs, stride_o_h, stride_o_d,\n stride_os_bs, stride_os_h, stride_os_d,\n head_num,\n BLOCK_DMODEL: tl.constexpr,\n BLOCK_HEAD: tl.constexpr\n):\n cur_index = tl.program_id(0)\n offs_h = tl.arange(0, BLOCK_HEAD)\n offs_d = tl.arange(0, BLOCK_DMODEL)\n\n dest_index = tl.load(Dest_loc + cur_index)\n src_data = tl.load(K + cur_index * stride_k_bs + offs_h[:, None] * stride_k_h + stride_k_d * offs_d[None, :], \n mask=offs_h[:, None] < head_num, other=0.0)\n abs_data = tl.abs(src_data)\n data_scale = (tl.max(abs_data, axis=1) / 127.).to(tl.float16)[:, None]\n q_src_data = (src_data / data_scale).to(tl.int8)\n o_ptrs = Out + dest_index * stride_o_bs + stride_o_h * offs_h[:, None] + stride_o_d * offs_d[None, :]\n os_ptrs = Out_scale + dest_index * stride_os_bs + stride_os_h * offs_h[:, None]\n tl.store(o_ptrs, q_src_data, mask=offs_h[:, None] < head_num)\n tl.store(os_ptrs, data_scale, mask=offs_h[:, None] < head_num)\n\n# Function to invoke the kernel for copying and quantizing data\n@torch.inference_mode()\ndef destindex_copy_quantize_kv(K, DestLoc, Out, Out_scale):\n seq_len = DestLoc.shape[0]\n head_num = K.shape[1]\n head_dim = K.shape[2]\n assert K.shape[1] == Out.shape[1] and K.shape[2] == Out.shape[2]\n BLOCK_HEAD = triton.next_power_of_2(head_num)\n grid = (seq_len,)\n num_warps = 1\n\n _fwd_kernel_destindex_copy_quantize_kv[grid](\n K, DestLoc, Out, Out_scale,\n K.stride(0), K.stride(1), K.stride(2),\n Out.stride(0), Out.stride(1), Out.stride(2),\n Out_scale.stride(0), Out_scale.stride(1), Out_scale.stride(2),\n head_num,\n BLOCK_DMODEL=head_dim,\n BLOCK_HEAD=BLOCK_HEAD,\n num_warps=num_warps,\n num_stages=1,\n )\n return\n", - "description_1": "Use triton language to implement two kernels: one for copying data from a source tensor to a destination tensor based on a destination index, and another for copying and quantizing data. The first kernel (_fwd_kernel_destindex_copy_kv) takes 10 parameters: source tensor K, destination index Dest_loc, output tensor Out, strides for K and Out, head_num, and block sizes BLOCK_DMODEL and BLOCK_HEAD. The second kernel (_fwd_kernel_destindex_copy_quantize_kv) takes 13 parameters: source tensor K, destination index Dest_loc, output tensor Out, output scale tensor Out_scale, strides for K, Out, and Out_scale, head_num, and block sizes BLOCK_DMODEL and BLOCK_HEAD. Both kernels are invoked by their respective functions destindex_copy_kv and destindex_copy_quantize_kv, which set up the grid and block sizes and call the kernels with the appropriate parameters.", - "description_2": "Use triton language to create kernels for data manipulation: one for copying data based on an index and another for copying with quantization. Implement functions to set up and invoke these kernels with necessary parameters.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef __fwd_kernel_pre_copy_and_register_kv(\n kv_range_begin, # [batch_size, ]\n kv_range_end, # [batch_size, ]\n new_kv_cache_len, # [batch_size, ]\n batch_size,\n kv_cache_index_begin, # [batch_size, ]\n kv_cache_index_end, # [batch_size, ]\n kv_first_token_global_idx, # [batch_size, ]\n num_logical_sp_peers,\n BLOCK_SIZE: tl.constexpr\n):\n offs = tl.program_id(0)*BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)\n cur_kv_cache_index_begin = tl.load(kv_cache_index_begin + offs, mask=offs < batch_size)\n cur_kv_cache_index_end = tl.load(kv_cache_index_end + offs, mask=offs < batch_size)\n cur_kv_first_token_global_idx = tl.load(kv_first_token_global_idx + offs, mask=offs < batch_size)\n\n cur_kv_range_begin = tl.cdiv(cur_kv_cache_index_begin - cur_kv_first_token_global_idx, num_logical_sp_peers)\n cur_kv_range_end = tl.cdiv(cur_kv_cache_index_end - cur_kv_first_token_global_idx, num_logical_sp_peers)\n cur_new_kv_cache_len = cur_kv_range_end - cur_kv_range_begin\n \n tl.store(kv_range_begin + offs, cur_kv_range_begin, mask=offs < batch_size)\n tl.store(kv_range_end + offs, cur_kv_range_end, mask=offs < batch_size)\n tl.store(new_kv_cache_len + offs, cur_new_kv_cache_len, mask=offs < batch_size)\n\n@torch.inference_mode()\ndef pre_copy_and_register_kv(\n kv_cache_index_begin: torch.Tensor, # [batch_size, ]\n kv_cache_index_end: torch.Tensor, # [batch_size, ]\n kv_first_token_global_idx: torch.Tensor, # [batch_size, ]\n num_logical_sp_peers: int\n) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:\n batch_size = kv_cache_index_begin.shape[0]\n kv_range_begin = torch.empty_like(kv_cache_index_begin)\n kv_range_end = torch.empty_like(kv_cache_index_begin)\n new_kv_cache_len = torch.empty_like(kv_cache_index_begin)\n\n BLOCK_SIZE = 64\n grid = ((batch_size+BLOCK_SIZE-1)//BLOCK_SIZE, )\n __fwd_kernel_pre_copy_and_register_kv[grid](\n kv_range_begin, kv_range_end, new_kv_cache_len,\n batch_size,\n kv_cache_index_begin, kv_cache_index_end, kv_first_token_global_idx, num_logical_sp_peers,\n BLOCK_SIZE\n )\n\n return (kv_range_begin, kv_range_end, new_kv_cache_len)\n\n@triton.jit\ndef __fwd_kernel_destindex_copy_and_register_kv(\n new_kv_cache_len, # [batch_size,]\n kv_range_begin, # [batch_size,]\n kv_range_end, # [batch_size,]\n new_kv_cache_len_sum, # [batch_size,]\n mem_index, # [alloc_token_num,]\n kv_b_start_loc, # [batch_size,]\n kv, # [max_token_num, num_head, head_dim]\n total_kv_cache, # [_, num_head, head_dim]\n b_req_idx, # [batch_size,]\n cur_kv_cache_index, # [batch_size,]\n req_to_token_indexs, # [max_request_num, max_token_num]\n stride_kv_bs, stride_kv_h, stride_kv_d,\n stride_total_kv_cache_bs, stride_total_kv_cache_h, stride_total_kv_cache_d,\n stride_req_to_tokens_b, stride_req_to_tokens_s,\n num_used_mem_index,\n head_num,\n should_register_kv: tl.constexpr,\n BLOCK_HEAD: tl.constexpr,\n BLOCK_DMODEL: tl.constexpr,\n BLOCK_N: tl.constexpr,\n):\n cur_batch = tl.program_id(0)\n loop_start = tl.program_id(1) * BLOCK_N\n\n cur_new_kv_cache_len = tl.load(new_kv_cache_len + cur_batch)\n cur_kv_range_begin = tl.load(kv_range_begin + cur_batch)\n cur_kv_range_end = tl.load(kv_range_end + cur_batch)\n\n if (cur_new_kv_cache_len <= 0 or loop_start >= cur_new_kv_cache_len) or (cur_kv_range_begin < 0 or cur_kv_range_end < 0):\n return\n\n cur_mem_index_start = tl.load(new_kv_cache_len_sum + cur_batch - 1, mask=cur_batch>0, other=0) + num_used_mem_index\n cur_mem_index_ptr = mem_index + cur_mem_index_start\n\n cur_kv_b_start_loc = tl.load(kv_b_start_loc + cur_batch)\n offs_h = tl.arange(0, BLOCK_HEAD)\n offs_d = tl.arange(0, BLOCK_DMODEL)\n stride_kv_bs = stride_kv_bs.to(tl.int64)\n stride_total_kv_cache_bs = stride_total_kv_cache_bs.to(tl.int64)\n cur_kv_ptrs = kv + (cur_kv_b_start_loc + cur_kv_range_begin) * stride_kv_bs + offs_h[:, None] * stride_kv_h + offs_d[None, :] * stride_kv_d\n\n total_kv_cache_ptrs = total_kv_cache + offs_h[:, None] * stride_total_kv_cache_h + offs_d[None, :] * stride_total_kv_cache_d\n\n if should_register_kv:\n cur_b_req_idx = tl.load(b_req_idx + cur_batch)\n cur_kv_cache_index_start = tl.load(cur_kv_cache_index + cur_batch)\n req_to_token_indexs_ptr = req_to_token_indexs + cur_b_req_idx * stride_req_to_tokens_b + cur_kv_cache_index_start * stride_req_to_tokens_s\n\n loop_end = tl.where(loop_start + BLOCK_N < cur_new_kv_cache_len, loop_start + BLOCK_N, cur_new_kv_cache_len)\n for start_n in range(loop_start, loop_end):\n cur_kv = tl.load(cur_kv_ptrs + start_n * stride_kv_bs, mask=offs_h[:, None] < head_num, other=0.0)\n cur_mem_index = tl.load(cur_mem_index_ptr + start_n)\n \n tl.store(total_kv_cache_ptrs + cur_mem_index * stride_total_kv_cache_bs, cur_kv, mask=offs_h[:, None] < head_num)\n if should_register_kv:\n tl.store(req_to_token_indexs_ptr + start_n * stride_req_to_tokens_s, cur_mem_index)\n\n@torch.inference_mode()\ndef destindex_copy_and_register_kv(\n batch_size: int,\n new_kv_cache_len: torch.Tensor,\n kv_range_begin: torch.Tensor,\n kv_range_end: torch.Tensor,\n new_kv_cache_len_sum: torch.Tensor,\n kv: torch.Tensor,\n kv_b_start_loc: torch.Tensor,\n cur_kv_cache_index: torch.Tensor,\n mem_index: torch.Tensor,\n total_kv_cache: torch.Tensor,\n req_to_token_indexs: torch.Tensor,\n b_req_idx: torch.Tensor,\n num_used_mem_index: int,\n max_len_in_batch: int,\n should_register_kv: bool\n): \n assert new_kv_cache_len.shape[0] == kv_range_begin.shape[0] == kv_range_end.shape[0] == new_kv_cache_len_sum.shape[0] == kv_b_start_loc.shape[0] == b_req_idx.shape[0] == cur_kv_cache_index.shape[0] == batch_size\n assert kv.shape[1] == total_kv_cache.shape[1] and kv.shape[2] == total_kv_cache.shape[2]\n head_num = kv.shape[1]\n head_dim = kv.shape[2]\n BLOCK_HEAD = triton.next_power_of_2(head_num)\n\n BLOCK_N = 256\n\n grid = (batch_size, (max_len_in_batch + BLOCK_N - 1) // BLOCK_N)\n __fwd_kernel_destindex_copy_and_register_kv[grid](\n new_kv_cache_len,\n kv_range_begin,\n kv_range_end,\n new_kv_cache_len_sum,\n mem_index,\n kv_b_start_loc,\n kv,\n total_kv_cache,\n b_req_idx,\n cur_kv_cache_index,\n req_to_token_indexs,\n kv.stride(0), kv.stride(1), kv.stride(2),\n total_kv_cache.stride(0), total_kv_cache.stride(1), total_kv_cache.stride(2),\n req_to_token_indexs.stride(0), req_to_token_indexs.stride(1),\n num_used_mem_index,\n head_num,\n should_register_kv,\n BLOCK_HEAD,\n BLOCK_DMODEL=head_dim,\n BLOCK_N=BLOCK_N,\n )\n", - "description_1": "Use triton language to implement two kernels: one for calculating key-value range and cache length, and another for copying and registering key-value pairs. The first kernel takes 8 parameters including tensors for kv_range_begin, kv_range_end, new_kv_cache_len, and constants like batch_size and BLOCK_SIZE. The second kernel takes 21 parameters including tensors for new_kv_cache_len, kv_range_begin, kv_range_end, and constants like BLOCK_HEAD, BLOCK_DMODEL, and BLOCK_N.", - "description_2": "Use triton language to implement kernels for key-value range calculation and key-value pair registration with parameters for tensor operations and block configurations.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit()\ndef _fwd_sender_kernel(\n\tsend_buf,\t# [num_layers, num_tokens_sum, 2*head_num, head_dim]\n\tstride_buf_layer,\t# Not marked as tl.constexpr to avoid re-compile\n\tstride_buf_token,\t# Not marked as tl.constexpr to avoid re-compile\n\tstride_buf_head: tl.constexpr,\n\tstride_buf_headdim: tl.constexpr,\n\treq_to_token_indexes,\t# [max_request_num, max_sequence_length]\n\tstride_rtti_reqid: tl.constexpr,\n\tstride_rtti_tokenidx: tl.constexpr,\n\tmem_state,\t\t# [memory_manager_size]\n\n\tkv_cache,\t\t# [num_layers, kvcache_size, 2*head_num, head_dim]\n\tstride_kvc_layer,\t# Not marked as tl.constexpr to be converted to tl.int64\n\tstride_kvc_token: tl.constexpr,\n\tstride_kvc_head: tl.constexpr,\n\tstride_kvc_headdim: tl.constexpr,\n\n\trequest_ids,\t# [batch_size]\n\tnum_tokens,\t\t# [batch_size]\n\tnum_tokens_cumsum,\t# [batch_size]\n\tb_seq_len,\t\t# [batch_size]\n\n\tnum_layers: tl.constexpr,\n\tkvcache_size: tl.constexpr,\n\tnum_heads: tl.constexpr,\n\thead_dim: tl.constexpr\n):\n\t\"\"\"\n\tThe Triton kernel for the sender during decoding stage migration\n\n\tThis kernel performs the following jobs:\n\t- Set mem_state of migrated tokens to 0\n\t- Gather K/Vs from kv_buffer to send_buf\n\n\tgrid: (batch, migrating_token_idx, 2*num_heads)\n\t\"\"\"\n\tbatch_idx = tl.program_id(0)\n\tmigrating_token_idx = tl.program_id(1)\n\tcur_head = tl.program_id(2)\n\n\tcur_num_tokens = tl.load(num_tokens + batch_idx)\n\tif migrating_token_idx >= cur_num_tokens:\n\t\treturn\n\t\n\tcur_b_seq_len = tl.load(b_seq_len + batch_idx)\n\tcur_token_idx_in_req = cur_b_seq_len - cur_num_tokens + migrating_token_idx\n\n\tcur_request_id = tl.load(request_ids + batch_idx)\n\tcur_token_idx_in_kv_cache = tl.load(req_to_token_indexes + cur_request_id*stride_rtti_reqid + cur_token_idx_in_req*stride_rtti_tokenidx)\n\tcur_token_idx_in_send_buf = tl.load(num_tokens_cumsum + batch_idx - 1, mask=batch_idx>0, other=0) + migrating_token_idx\n\tcur_token_idx_in_kv_cache = cur_token_idx_in_kv_cache.to(tl.int64)\n\tcur_token_idx_in_send_buf = cur_token_idx_in_send_buf.to(tl.int64)\n\n\ttl.store(mem_state + cur_token_idx_in_kv_cache, 0)\n\n\tstride_kvc_layer = stride_kvc_layer.to(tl.int64)\n\tstride_buf_layer = stride_buf_layer.to(tl.int64)\n\tkvc_ptrs = kv_cache + cur_token_idx_in_kv_cache*stride_kvc_token + cur_head*stride_kvc_head + tl.arange(0, head_dim)*stride_kvc_headdim\n\tsend_buf_ptrs = send_buf + cur_token_idx_in_send_buf*stride_buf_token + cur_head*stride_buf_head + tl.arange(0, head_dim)*stride_buf_headdim\n\tfor layer in tl.static_range(num_layers):\n\t\t# Need a for-loop here since tl.arange() only accepts power-of-two\n\t\tcur_kvc_ptrs = kvc_ptrs + layer*stride_kvc_layer\n\t\tcur_send_buf_ptrs = send_buf_ptrs + layer*stride_buf_layer\n\t\ttl.store(cur_send_buf_ptrs, tl.load(cur_kvc_ptrs))\n\ndef decoding_stage_migration_sender_kernel(\n\tsend_buf: torch.Tensor,\n\treq_to_token_indexes: torch.Tensor,\n\tmem_state: torch.Tensor,\n\tkv_cache: torch.Tensor,\n\trequest_ids: torch.Tensor,\n\tnum_tokens: torch.Tensor,\n\tb_seq_len: torch.Tensor\n):\n\tnum_tokens_max = torch.max(num_tokens)\n\tnum_tokens_cumsum = torch.cumsum(num_tokens, dim=0)\n\tbatch_size = request_ids.shape[0]\n\tnum_layers = kv_cache.shape[0]\n\tkvcache_size = kv_cache.shape[1]\n\tnum_heads = kv_cache.shape[2] // 2\n\thead_dim = kv_cache.shape[3]\n\n\t_fwd_sender_kernel[(batch_size, num_tokens_max, 2*num_heads)](\n\t\tsend_buf, send_buf.stride(0), send_buf.stride(1), send_buf.stride(2), send_buf.stride(3),\n\t\treq_to_token_indexes, req_to_token_indexes.stride(0), req_to_token_indexes.stride(1),\n\t\tmem_state,\n\t\tkv_cache, kv_cache.stride(0), kv_cache.stride(1), kv_cache.stride(2), kv_cache.stride(3),\n\t\trequest_ids, num_tokens, num_tokens_cumsum, b_seq_len,\n\t\tnum_layers, kvcache_size, num_heads, head_dim\n\t)\n\n@triton.jit()\ndef _fwd_receiver_kernel(\n\trecv_buf,\t# [num_layers, num_tokens_sum, 2*head_num, head_dim]\n\tstride_buf_layer,\t# Not marked as tl.constexpr to avoid re-compile\n\tstride_buf_token,\t# Not marked as tl.constexpr to avoid re-compile\n\tstride_buf_head: tl.constexpr,\n\tstride_buf_headdim: tl.constexpr,\n\treq_to_token_indexes,\t# [max_request_num, max_sequence_length]\n\tstride_rtti_reqid: tl.constexpr,\n\tstride_rtti_tokenidx: tl.constexpr,\n\tkv_cache,\t\t# [num_layers, kvcache_size, 2*head_num, head_dim]\n\tstride_kvc_layer,\t# Not marked as tl.constexpr to be converted to tl.int64\n\tstride_kvc_token: tl.constexpr,\n\tstride_kvc_head: tl.constexpr,\n\tstride_kvc_headdim: tl.constexpr,\n\talloc_mem,\t\t# [num_tokens_sum]\n\n\trequest_ids,\t# [batch_size]\n\tnum_tokens,\t\t# [batch_size]\n\tnum_tokens_cumsum,\t# [batch_size]\n\tb_seq_len,\t\t# [batch_size]\n\n\tnum_layers: tl.constexpr,\n\tkvcache_size: tl.constexpr,\n\tnum_heads: tl.constexpr,\n\thead_dim: tl.constexpr\n):\n\t\"\"\"\n\tThe Triton kernel for the receiver during decoding stage migration\n\n\tThis kernel performs the following jobs:\n\t- Save recv_buf to kv_cache\n\t- Modify req_to_token_indexes\n\n\tgrid: (batch, migrating_token_idx, 2*num_heads)\n\t\"\"\"\n\tbatch_idx = tl.program_id(0)\n\tmigrating_token_idx = tl.program_id(1)\n\tcur_head = tl.program_id(2)\n\n\tcur_num_tokens = tl.load(num_tokens + batch_idx)\n\tif migrating_token_idx >= cur_num_tokens:\n\t\treturn\n\t\n\tcur_b_seq_len = tl.load(b_seq_len + batch_idx)\n\tcur_token_idx_in_recv_buf = tl.load(num_tokens_cumsum + batch_idx - 1, mask=batch_idx>0, other=0) + migrating_token_idx\n\tcur_token_idx_in_kv_cache = tl.load(alloc_mem + cur_token_idx_in_recv_buf)\n\tcur_token_idx_in_recv_buf = cur_token_idx_in_recv_buf.to(tl.int64)\n\tcur_token_idx_in_kv_cache = cur_token_idx_in_kv_cache.to(tl.int64)\n\n\tcur_request_id = tl.load(request_ids + batch_idx)\n\tcur_token_idx_in_req = cur_b_seq_len + migrating_token_idx\n\ttl.store(req_to_token_indexes + cur_request_id*stride_rtti_reqid + cur_token_idx_in_req*stride_rtti_tokenidx, cur_token_idx_in_kv_cache)\n\n\tstride_kvc_layer = stride_kvc_layer.to(tl.int64)\n\tstride_buf_layer = stride_buf_layer.to(tl.int64)\n\tkvc_ptrs = kv_cache + cur_token_idx_in_kv_cache*stride_kvc_token + cur_head*stride_kvc_head + tl.arange(0, head_dim)*stride_kvc_headdim\n\trecv_buf_ptrs = recv_buf + cur_token_idx_in_recv_buf*stride_buf_token + cur_head*stride_buf_head + tl.arange(0, head_dim)*stride_buf_headdim\n\tfor layer in range(num_layers):\n\t\tcur_kvc_ptrs = kvc_ptrs + layer*stride_kvc_layer\n\t\tcur_recv_buf_ptrs = recv_buf_ptrs + layer*stride_buf_layer\n\t\ttl.store(cur_kvc_ptrs, tl.load(cur_recv_buf_ptrs))\n\ndef decoding_stage_migration_receiver_kernel(\n\trecv_buf: torch.Tensor,\n\treq_to_token_indexes: torch.Tensor,\n\tkv_cache: torch.Tensor,\n\talloc_mem: torch.Tensor,\n\trequest_ids: torch.Tensor,\n\tnum_tokens: torch.Tensor,\n\tb_seq_len: torch.Tensor\n):\n\tnum_tokens_max = torch.max(num_tokens)\n\tnum_tokens_cumsum = torch.cumsum(num_tokens, dim=0)\n\tbatch_size = request_ids.shape[0]\n\tnum_layers = kv_cache.shape[0]\n\tkvcache_size = kv_cache.shape[1]\n\tnum_heads = kv_cache.shape[2] // 2\n\thead_dim = kv_cache.shape[3]\n\n\t_fwd_receiver_kernel[(batch_size, num_tokens_max, 2*num_heads)](\n\t\trecv_buf, recv_buf.stride(0), recv_buf.stride(1), recv_buf.stride(2), recv_buf.stride(3),\n\t\treq_to_token_indexes, req_to_token_indexes.stride(0), req_to_token_indexes.stride(1),\n\t\tkv_cache, kv_cache.stride(0), kv_cache.stride(1), kv_cache.stride(2), kv_cache.stride(3),\n\t\talloc_mem,\n\t\trequest_ids, num_tokens, num_tokens_cumsum, b_seq_len,\n\t\tnum_layers, kvcache_size, num_heads, head_dim\n\t)\n", - "description_1": "Use triton language to implement two kernels for decoding stage migration. The first kernel, _fwd_sender_kernel, takes 18 parameters: send_buf, stride_buf_layer, stride_buf_token, stride_buf_head, stride_buf_headdim, req_to_token_indexes, stride_rtti_reqid, stride_rtti_tokenidx, mem_state, kv_cache, stride_kvc_layer, stride_kvc_token, stride_kvc_head, stride_kvc_headdim, request_ids, num_tokens, num_tokens_cumsum, b_seq_len, and constants num_layers, kvcache_size, num_heads, head_dim. It sets mem_state of migrated tokens to 0 and gathers K/Vs from kv_buffer to send_buf. The second kernel, _fwd_receiver_kernel, takes 18 parameters: recv_buf, stride_buf_layer, stride_buf_token, stride_buf_head, stride_buf_headdim, req_to_token_indexes, stride_rtti_reqid, stride_rtti_tokenidx, kv_cache, stride_kvc_layer, stride_kvc_token, stride_kvc_head, stride_kvc_headdim, alloc_mem, request_ids, num_tokens, num_tokens_cumsum, b_seq_len, and constants num_layers, kvcache_size, num_heads, head_dim. It saves recv_buf to kv_cache and modifies req_to_token_indexes.", - "description_2": "Use triton language to create kernels for migrating tokens during decoding. The sender kernel manages memory state and gathers data, while the receiver kernel saves data and updates indexes.", - "difficulty": 3 - }, - { - "code": "import triton\nimport triton.language as tl\n\n@triton.jit\ndef _fwd_kernel_flash_decode_stage1(\n Q, K, V, sm_scale, Req_to_tokens, B_req_idx, B_Seqlen,\n Mid_O, # [batch, head, seq_block_num, head_dim]\n Mid_O_LogExpSum, #[batch, head, seq_block_num]\n stride_req_to_tokens_b, stride_req_to_tokens_s,\n stride_qbs, stride_qh, stride_qd,\n stride_kbs, stride_kh, stride_kd,\n stride_vbs, stride_vh, stride_vd,\n stride_mid_ob, stride_mid_oh, stride_mid_os, stride_mid_od,\n stride_mid_o_eb, stride_mid_o_eh, stride_mid_o_es,\n gqa_group_size,\n BLOCK_SEQ: tl.constexpr, \n BLOCK_DMODEL: tl.constexpr,\n BLOCK_N: tl.constexpr\n):\n cur_batch = tl.program_id(0)\n cur_head = tl.program_id(1)\n seq_start_block = tl.program_id(2)\n cur_kv_head = cur_head // gqa_group_size\n\n offs_d = tl.arange(0, BLOCK_DMODEL)\n cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)\n cur_batch_req_idx = tl.load(B_req_idx + cur_batch)\n cur_batch_start_index = seq_start_block * BLOCK_SEQ\n cur_batch_end_index = tl.minimum(cur_batch_seq_len, cur_batch_start_index + BLOCK_SEQ)\n\n off_q = cur_batch * stride_qbs + cur_head * stride_qh + offs_d\n \n block_n_size = tl.cdiv(\n tl.where(cur_batch_end_index - cur_batch_start_index <= 0, 0, cur_batch_end_index - cur_batch_start_index),\n BLOCK_N\n )\n \n offs_n = cur_batch_start_index + tl.arange(0, BLOCK_N)\n \n q = tl.load(Q + off_q)\n\n sum_exp = 0.0\n max_logic = -float(\"inf\")\n acc = tl.zeros([BLOCK_DMODEL], dtype=tl.float32)\n\n for start_n in range(0, block_n_size, 1):\n offs_n_new = start_n * BLOCK_N + offs_n\n k_loc = tl.load(Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + offs_n_new, mask=offs_n_new < cur_batch_end_index, other=0).to(tl.int64)\n off_k = k_loc[:, None] * stride_kbs + cur_kv_head * stride_kh + offs_d[None, :]\n k = tl.load(K + off_k, mask=offs_n_new[:, None] < cur_batch_end_index, other=0.0)\n att_value = tl.sum(q[None, :] * k, 1)\n att_value *= sm_scale\n att_value = tl.where(offs_n_new < cur_batch_end_index, att_value, float(\"-inf\"))\n v = tl.load(V + off_k, mask=offs_n_new[:, None] < cur_batch_end_index, other=0.0)\n \n cur_max_logic = tl.max(att_value, axis=0)\n new_max_logic = tl.maximum(cur_max_logic, max_logic)\n\n exp_logic = tl.exp(att_value - new_max_logic)\n logic_scale = tl.exp(max_logic - new_max_logic)\n acc *= logic_scale\n acc += tl.sum(exp_logic[:, None] * v, axis=0)\n\n sum_exp = sum_exp * logic_scale + tl.sum(exp_logic, axis=0)\n max_logic = new_max_logic\n \n need_store = tl.where(block_n_size == 0, 0, 1)\n for _ in range(0, need_store, 1):\n off_mid_o = cur_batch * stride_mid_ob + cur_head * stride_mid_oh + seq_start_block * stride_mid_os + offs_d\n off_mid_o_logexpsum = cur_batch * stride_mid_o_eb + cur_head * stride_mid_o_eh + seq_start_block\n tl.store(Mid_O + off_mid_o, acc / sum_exp)\n tl.store(Mid_O_LogExpSum + off_mid_o_logexpsum, max_logic + tl.log(sum_exp))\n return\n", - "description_1": "Use triton language to implement a kernel function '_fwd_kernel_flash_decode_stage1' with 24 parameters. The function performs a forward pass for flash attention decoding. It computes attention scores and updates intermediate outputs 'Mid_O' and 'Mid_O_LogExpSum' based on input tensors 'Q', 'K', 'V', and other parameters like 'sm_scale', 'Req_to_tokens', 'B_req_idx', 'B_Seqlen', and various strides. The function uses block sizes defined by 'BLOCK_SEQ', 'BLOCK_DMODEL', and 'BLOCK_N'.", - "description_2": "Use triton language to implement a kernel for flash attention decoding with 24 parameters, computing attention scores and updating outputs based on input tensors and block sizes.", - "difficulty": 4 - }, - { - "code": "import triton\nimport triton.language as tl\n\n@triton.jit\ndef _fwd_kernel_flash_decode_stage2(\n B_Seqlen,\n Mid_O, # [batch, head, seq_block_num, head_dim]\n Mid_O_LogExpSum, # [batch, head, seq_block_num]\n O, # [batch, head, head_dim]\n out_logexpsum, # [batch, head]\n stride_mid_ob, stride_mid_oh, stride_mid_os, stride_mid_od,\n stride_mid_o_eb, stride_mid_o_eh, stride_mid_o_es,\n stride_obs, stride_oh, stride_od,\n stride_out_logexpsum_b, stride_out_logexpsum_h,\n BLOCK_SEQ: tl.constexpr,\n BLOCK_DMODEL: tl.constexpr):\n cur_batch = tl.program_id(0)\n cur_head = tl.program_id(1)\n\n offs_d = tl.arange(0, BLOCK_DMODEL)\n cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)\n\n block_n_size = tl.where(cur_batch_seq_len <= 0, 0, cur_batch_seq_len + BLOCK_SEQ - 1) // BLOCK_SEQ\n\n sum_exp = 0.0\n max_logic = float(\"-1e20\")\n acc = tl.zeros([BLOCK_DMODEL], dtype=tl.float32)\n\n offs_v = cur_batch * stride_mid_ob + cur_head * stride_mid_oh + offs_d\n offs_logic = cur_batch * stride_mid_o_eb + cur_head * stride_mid_o_eh\n for block_seq_n in range(0, block_n_size, 1):\n tv = tl.load(Mid_O + offs_v + block_seq_n * stride_mid_os)\n tlogic = tl.load(Mid_O_LogExpSum + offs_logic + block_seq_n)\n new_max_logic = tl.maximum(tlogic, max_logic)\n \n old_scale = tl.exp(max_logic - new_max_logic)\n acc *= old_scale\n exp_logic = tl.exp(tlogic - new_max_logic)\n acc += exp_logic * tv\n sum_exp = sum_exp * old_scale + exp_logic\n max_logic = new_max_logic\n \n if block_n_size > 0:\n # Here we check whether block_n_size is 0 in order to avoid \"div by zero\" error\n tl.store(O + cur_batch * stride_obs + cur_head * stride_oh + offs_d, acc / sum_exp)\n tl.store(out_logexpsum + cur_batch * stride_out_logexpsum_b + cur_head * stride_out_logexpsum_h, max_logic + tl.log(sum_exp))\n return\n", - "description_1": "Use triton language to implement a kernel function '_fwd_kernel_flash_decode_stage2' that performs a sequence of operations on input tensors. The function takes 17 parameters: B_Seqlen (tensor), Mid_O (tensor), Mid_O_LogExpSum (tensor), O (tensor), out_logexpsum (tensor), stride_mid_ob (int), stride_mid_oh (int), stride_mid_os (int), stride_mid_od (int), stride_mid_o_eb (int), stride_mid_o_eh (int), stride_mid_o_es (int), stride_obs (int), stride_oh (int), stride_od (int), stride_out_logexpsum_b (int), stride_out_logexpsum_h (int), BLOCK_SEQ (constexpr), and BLOCK_DMODEL (constexpr). The kernel computes a weighted sum of input blocks and stores the result in the output tensor O and out_logexpsum.", - "description_2": "Use triton language to create a kernel that processes input tensors to compute a weighted sum and store results in output tensors.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef _fwd_kernel(\n Q, K, V, sm_scale,\n q_b_start_loc, q_b_seqlen, q_first_token_global_idx,\n kv_b_start_loc, kv_b_seqlen, kv_first_token_global_idx,\n logical_sp_peers_num: tl.constexpr,\n Out, m, l,\n stride_qbs: tl.constexpr, stride_qh: tl.constexpr, stride_qd: tl.constexpr,\n stride_kbs: tl.constexpr, stride_kh: tl.constexpr, stride_kd: tl.constexpr,\n stride_vbs: tl.constexpr, stride_vh: tl.constexpr, stride_vd: tl.constexpr,\n stride_obs: tl.constexpr, stride_oh: tl.constexpr, stride_od: tl.constexpr,\n stride_mbs: tl.constexpr, stride_mh: tl.constexpr,\n stride_lbs: tl.constexpr, stride_lh: tl.constexpr,\n kv_group_num: tl.constexpr,\n BLOCK_M: tl.constexpr, DMODEL: tl.constexpr,\n BLOCK_N: tl.constexpr\n):\n cur_batch = tl.program_id(0)\n cur_head = tl.program_id(1)\n blockIdz = tl.program_id(2)\n \n cur_kv_head = cur_head // kv_group_num\n\n cur_q_seq_len = tl.load(q_b_seqlen + cur_batch)\n if blockIdz*BLOCK_M >= cur_q_seq_len:\n return\n \n cur_kv_seq_len = tl.load(kv_b_seqlen + cur_batch)\n cur_q_start_index = tl.load(q_b_start_loc + cur_batch)\n cur_kv_start_index = tl.load(kv_b_start_loc + cur_batch)\n cur_q_first_token_global_idx = tl.load(q_first_token_global_idx + cur_batch)\n cur_kv_first_token_global_idx = tl.load(kv_first_token_global_idx + cur_batch)\n\n Q += cur_q_start_index*stride_qbs + cur_head*stride_qh\n K += cur_kv_start_index*stride_kbs + cur_kv_head*stride_kh\n V += cur_kv_start_index*stride_vbs + cur_kv_head*stride_vh\n Out += cur_q_start_index*stride_obs + cur_head*stride_oh\n m += cur_q_start_index*stride_mbs + cur_head*stride_mh\n l += cur_q_start_index*stride_lbs + cur_head*stride_lh\n\n offs_n = tl.arange(0, BLOCK_N)\n offs_d = tl.arange(0, DMODEL)\n offs_m = blockIdz*BLOCK_M + tl.arange(0, BLOCK_M)\n multed_offs_m = offs_m * logical_sp_peers_num\n multed_offs_n = offs_n * logical_sp_peers_num\n\n q_ptrs = Q + offs_m[:, None] * stride_qbs + offs_d[None, :] * stride_qd\n q = tl.load(q_ptrs, mask=offs_m[:, None] < cur_q_seq_len, other=0.0, cache_modifier=\".cg\")\n k_ptrs = K + offs_n[None, :] * stride_kbs + offs_d[:, None] * stride_kd\n v_ptrs = V + offs_n[:, None] * stride_vbs + offs_d[None, :] * stride_vd\n\n m_i = tl.full([BLOCK_M], value=-float(\"inf\"), dtype=tl.float32)\n l_i = tl.zeros([BLOCK_M], dtype=tl.float32)\n acc = tl.zeros([BLOCK_M, DMODEL], dtype=tl.float32)\n\n loop_range = tl.minimum(\n tl.cdiv(\n cur_q_first_token_global_idx + ((blockIdz+1)*BLOCK_M-1)*logical_sp_peers_num+1 - cur_kv_first_token_global_idx,\n logical_sp_peers_num\n ),\n cur_kv_seq_len\n ) \n loop_range = tl.maximum(loop_range, 0)\n\n loop1_end = tl.maximum(loop_range-BLOCK_N*tl.cdiv(BLOCK_M, BLOCK_N), 0)\n for start_n in range(0, loop1_end, BLOCK_N):\n start_n = tl.multiple_of(start_n, BLOCK_N)\n\n k = tl.load(k_ptrs + start_n*stride_kbs,\n mask=(start_n + offs_n[None, :]) < cur_kv_seq_len, other=0.0, cache_modifier=\".cg\")\n qk = tl.dot(q, k, out_dtype=tl.float32)\n k = None\n v = tl.load(v_ptrs + start_n*stride_vbs,\n mask=(start_n + offs_n[:, None]) < cur_kv_seq_len, other=0.0, cache_modifier=\".cg\")\n\n m_i_new = tl.maximum(m_i, tl.max(qk, 1)*sm_scale)\n alpha = tl.math.exp2(m_i - m_i_new)\n p = tl.math.exp2(qk*sm_scale - m_i_new[:, None])\n acc *= alpha[:, None]\n acc += tl.dot(p.to(tl.float16), v)\n l_i = l_i * alpha + tl.sum(p, 1)\n m_i = m_i_new\n\n for start_n in range(loop1_end, loop_range, BLOCK_N):\n start_n = tl.multiple_of(start_n, BLOCK_N)\n\n k = tl.load(k_ptrs + start_n*stride_kbs,\n mask=(start_n + offs_n[None, :]) < cur_kv_seq_len, other=0.0, cache_modifier=\".cg\")\n qk = tl.dot(q, k, out_dtype=tl.float32)\n k = None\n v = tl.load(v_ptrs + start_n*stride_vbs,\n mask=(start_n + offs_n[:, None]) < cur_kv_seq_len, other=0.0, cache_modifier=\".cg\")\n\n qk = tl.where(\n ((cur_q_first_token_global_idx + multed_offs_m[:, None]) >= \\\n (cur_kv_first_token_global_idx + start_n*logical_sp_peers_num + multed_offs_n[None, :])) & \\\n ((start_n + offs_n[None, :]) < cur_kv_seq_len),\n qk, float(\"-1e20\")\n )\n\n m_i_new = tl.maximum(m_i, tl.max(qk, 1)*sm_scale)\n alpha = tl.math.exp2(m_i - m_i_new)\n p = tl.math.exp2(qk*sm_scale - m_i_new[:, None])\n acc *= alpha[:, None]\n acc += tl.dot(p.to(tl.float16), v)\n l_i = l_i * alpha + tl.sum(p, 1)\n m_i = m_i_new\n \n out_ptrs = Out + offs_m[:, None] * stride_obs + offs_d[None, :] * stride_od\n m_ptrs = m + offs_m * stride_mbs\n l_ptrs = l + offs_m * stride_lbs\n\n old_out = tl.load(out_ptrs, mask=offs_m[:, None] < cur_q_seq_len)\n m_i_old = tl.load(m_ptrs, mask=offs_m < cur_q_seq_len, other=float(\"-inf\"))\n l_i_old = tl.load(l_ptrs, mask=offs_m < cur_q_seq_len, other=0.)\n\n m_i_new = tl.maximum(m_i, m_i_old)\n l_i_new = l_i_old*tl.math.exp2(m_i_old-m_i_new) + l_i*tl.math.exp2(m_i-m_i_new)\n out = (\n old_out * (l_i_old*tl.math.exp2(m_i_old-m_i_new))[:, None] + \n acc.to(tl.float16) * tl.math.exp2(m_i-m_i_new)[:, None]\n ) / l_i_new[:, None]\n\n tl.store(out_ptrs, out, mask=offs_m[:, None] < cur_q_seq_len, cache_modifier=\".cg\")\n tl.store(m_ptrs, m_i_new, mask=offs_m < cur_q_seq_len, cache_modifier=\".cg\")\n tl.store(l_ptrs, l_i_new, mask=offs_m < cur_q_seq_len, cache_modifier=\".cg\")\n\n@torch.inference_mode()\ndef context_attention_fwd(\n q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, o: torch.Tensor,\n q_b_start_loc: torch.Tensor, q_b_seq_len: torch.Tensor, q_first_token_global_idx: torch.Tensor,\n kv_b_start_loc: torch.Tensor, kv_b_seq_len: torch.Tensor, kv_first_token_global_idx: torch.Tensor,\n logical_sp_peers_num: int, max_q_b_seq_len: int,\n m: torch.Tensor, l: torch.Tensor\n):\n BLOCK_M = 128 if not TESLA and not RTX4090 else 64\n BLOCK_N = 128 if not TESLA and not RTX4090 else 64\n \n if BLOCK_M//2 >= max(max_q_b_seq_len, 16):\n BLOCK_M = triton.next_power_of_2(max(max_q_b_seq_len, 16))\n if BLOCK_N//2 >= max(max_q_b_seq_len, 16):\n BLOCK_N = triton.next_power_of_2(max(max_q_b_seq_len, 16))\n \n Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]\n assert Lq == Lk and Lk == Lv\n assert Lk in {16, 32, 64, 128}\n\n sm_scale = 1.0 / (Lq**0.5)\n sm_scale *= 1.442695040888963\n batch_size = q_b_seq_len.shape[0]\n num_q_heads = q.shape[1]\n num_kv_heads = k.shape[1]\n kv_group_num = num_q_heads // num_kv_heads\n \n grid = (batch_size, num_q_heads, triton.cdiv(max_q_b_seq_len, BLOCK_M))\n\n num_warps = 4 if Lk <= 64 else 8\n _fwd_kernel[grid](\n q, k, v, sm_scale,\n q_b_start_loc, q_b_seq_len, q_first_token_global_idx,\n kv_b_start_loc, kv_b_seq_len, kv_first_token_global_idx,\n logical_sp_peers_num,\n o,\n m, l,\n q.stride(0), q.stride(1), q.stride(2),\n k.stride(0), k.stride(1), k.stride(2),\n v.stride(0), v.stride(1), v.stride(2),\n o.stride(0), o.stride(1), o.stride(2),\n m.stride(0), m.stride(1),\n l.stride(0), l.stride(1),\n kv_group_num=kv_group_num,\n DMODEL=Lk,\n BLOCK_M=BLOCK_M,\n BLOCK_N=BLOCK_N,\n num_warps=num_warps,\n num_stages=3,\n )\n", - "description_1": "Use triton language to implement a forward kernel for attention computation. The kernel (_fwd_kernel) takes 28 parameters: Q, K, V (input tensors), sm_scale (scale factor), q_b_start_loc, q_b_seqlen, q_first_token_global_idx, kv_b_start_loc, kv_b_seqlen, kv_first_token_global_idx (batch-related indices and lengths), logical_sp_peers_num (constant expression), Out (output tensor), m, l (temporary buffers), and several stride and block size parameters. The kernel computes attention scores and updates the output tensor. The context_attention_fwd function wraps this kernel, taking 14 parameters: q, k, v, o (input and output tensors), q_b_start_loc, q_b_seq_len, q_first_token_global_idx, kv_b_start_loc, kv_b_seq_len, kv_first_token_global_idx (batch-related indices and lengths), logical_sp_peers_num, max_q_b_seq_len (maximum sequence length), m, l (temporary buffers). It sets up the grid and block sizes and calls the kernel.", - "description_2": "Use triton language to create a kernel for computing attention scores between query and key-value pairs, and a wrapper function to manage input/output tensors and launch the kernel.", - "difficulty": 5 - }, - { - "code": "import triton\nimport triton.language as tl\nimport torch\n\n@triton.jit\ndef _rotary_kernel(\n Q,\n K,\n Cos,\n Sin,\n stride_qbs,\n stride_qh,\n stride_qd,\n stride_kbs,\n stride_kh,\n stride_kd,\n stride_cosbs,\n stride_cosd,\n stride_sinbs,\n stride_sind,\n max_total_len,\n HEAD_Q,\n HEAD_K,\n BLOCK_HEAD: tl.constexpr,\n BLOCK_SEQ: tl.constexpr,\n BLOCK_DMODEL: tl.constexpr,\n):\n cur_head_index = tl.program_id(0)\n cur_seq_index = tl.program_id(1)\n\n cur_head_range = cur_head_index * BLOCK_HEAD + tl.arange(0, BLOCK_HEAD)\n cur_seq_range = cur_seq_index * BLOCK_SEQ + tl.arange(0, BLOCK_SEQ)\n\n dim_range0 = tl.arange(0, BLOCK_DMODEL // 2)\n dim_range1 = tl.arange(BLOCK_DMODEL // 2, BLOCK_DMODEL)\n\n off_q0 = (\n cur_seq_range[:, None, None] * stride_qbs\n + cur_head_range[None, :, None] * stride_qh\n + dim_range0[None, None, :] * stride_qd\n )\n off_q1 = (\n cur_seq_range[:, None, None] * stride_qbs\n + cur_head_range[None, :, None] * stride_qh\n + dim_range1[None, None, :] * stride_qd\n )\n\n off_dimcos_sin = cur_seq_range[:, None, None] * stride_cosbs + dim_range0[None, None, :] * stride_cosd\n\n q0 = tl.load(\n Q + off_q0,\n mask=(cur_seq_range[:, None, None] < max_total_len) & (cur_head_range[None, :, None] < HEAD_Q),\n other=0.0,\n )\n q1 = tl.load(\n Q + off_q1,\n mask=(cur_seq_range[:, None, None] < max_total_len) & (cur_head_range[None, :, None] < HEAD_Q),\n other=0.0,\n )\n\n cos = tl.load(Cos + off_dimcos_sin, mask=cur_seq_range[:, None, None] < max_total_len, other=0.0)\n sin = tl.load(Sin + off_dimcos_sin, mask=cur_seq_range[:, None, None] < max_total_len, other=0.0)\n\n out0 = q0 * cos - q1 * sin\n out1 = q0 * sin + q1 * cos\n\n tl.store(\n Q + off_q0, out0, mask=(cur_seq_range[:, None, None] < max_total_len) & (cur_head_range[None, :, None] < HEAD_Q)\n )\n tl.store(\n Q + off_q1, out1, mask=(cur_seq_range[:, None, None] < max_total_len) & (cur_head_range[None, :, None] < HEAD_Q)\n )\n\n off_k0 = (\n cur_seq_range[:, None, None] * stride_kbs\n + cur_head_range[None, :, None] * stride_kh\n + dim_range0[None, None, :] * stride_kd\n )\n off_k1 = (\n cur_seq_range[:, None, None] * stride_kbs\n + cur_head_range[None, :, None] * stride_kh\n + dim_range1[None, None, :] * stride_kd\n )\n\n off_dimcos_sin = cur_seq_range[:, None, None] * stride_cosbs + dim_range0[None, None, :] * stride_cosd\n\n k0 = tl.load(\n K + off_k0,\n mask=(cur_seq_range[:, None, None] < max_total_len) & (cur_head_range[None, :, None] < HEAD_K),\n other=0.0,\n )\n k1 = tl.load(\n K + off_k1,\n mask=(cur_seq_range[:, None, None] < max_total_len) & (cur_head_range[None, :, None] < HEAD_K),\n other=0.0,\n )\n cos = tl.load(Cos + off_dimcos_sin, mask=cur_seq_range[:, None, None] < max_total_len, other=0.0)\n sin = tl.load(Sin + off_dimcos_sin, mask=cur_seq_range[:, None, None] < max_total_len, other=0.0)\n\n out_k0 = k0 * cos - k1 * sin\n out_k1 = k0 * sin + k1 * cos\n\n tl.store(\n K + off_k0,\n out_k0,\n mask=(cur_seq_range[:, None, None] < max_total_len) & (cur_head_range[None, :, None] < HEAD_K),\n )\n tl.store(\n K + off_k1,\n out_k1,\n mask=(cur_seq_range[:, None, None] < max_total_len) & (cur_head_range[None, :, None] < HEAD_K),\n )\n return\n\n@torch.inference_mode()\ndef rotary_emb_fwd(q, k, cos, sin):\n total_len = q.shape[0]\n head_num_q, head_num_k = q.shape[1], k.shape[1]\n head_dim = q.shape[2]\n assert q.shape[0] == cos.shape[0] and q.shape[0] == sin.shape[0], f\"q shape {q.shape} cos shape {cos.shape}\"\n assert k.shape[0] == cos.shape[0] and k.shape[0] == sin.shape[0], f\"k shape {k.shape} cos shape {cos.shape}\"\n\n BLOCK_SEQ = 16\n BLOCK_HEAD = 4\n if head_dim >= 128:\n num_warps = 8\n else:\n num_warps = 4\n\n grid = (triton.cdiv(head_num_q, BLOCK_HEAD), triton.cdiv(total_len, BLOCK_SEQ))\n _rotary_kernel[grid](\n q,\n k,\n cos,\n sin,\n q.stride(0),\n q.stride(1),\n q.stride(2),\n k.stride(0),\n k.stride(1),\n k.stride(2),\n cos.stride(0),\n cos.stride(1),\n sin.stride(0),\n sin.stride(1),\n total_len,\n head_num_q,\n head_num_k,\n BLOCK_HEAD=BLOCK_HEAD,\n BLOCK_SEQ=BLOCK_SEQ,\n BLOCK_DMODEL=head_dim,\n num_warps=num_warps,\n num_stages=1,\n )\n return\n", - "description_1": "Use triton language to implement a rotary embedding kernel function (_rotary_kernel) that takes 20 parameters: Q, K, Cos, Sin (tensors), stride_qbs, stride_qh, stride_qd, stride_kbs, stride_kh, stride_kd, stride_cosbs, stride_cosd, stride_sinbs, stride_sind (strides for accessing elements in tensors), max_total_len (maximum sequence length), HEAD_Q, HEAD_K (head dimensions), BLOCK_HEAD, BLOCK_SEQ, BLOCK_DMODEL (block sizes as compile-time constants). The kernel performs element-wise operations on Q and K using Cos and Sin, storing the results back into Q and K. The rotary_emb_fwd function is a wrapper that prepares the input data and launches the kernel with appropriate grid and block configurations.", - "description_2": "Use triton language to create a rotary embedding kernel that processes input tensors Q and K with cosine and sine transformations, and a wrapper function to execute this kernel with specified grid and block settings.", - "difficulty": 3 - }, - { - "code": "import triton\nimport triton.language as tl\nimport torch\n\n@triton.jit\ndef _silu_and_mul_kernel(\n input_ptr,\n stride_input_m,\n stride_input_n,\n stride_output_m,\n stride_output_n,\n size_m,\n size_n,\n BLOCK_M: tl.constexpr,\n BLOCK_N: tl.constexpr,\n):\n tid = tl.program_id(0)\n m_offsets = (tid * BLOCK_M + tl.arange(0, BLOCK_M)).to(tl.int64)\n\n pid = tl.program_id(1)\n n_offsets = (pid * BLOCK_N + tl.arange(0, BLOCK_N)).to(tl.int64)\n\n up_offsets = m_offsets[:, None] * stride_input_m + (n_offsets[None, :] + size_n) * stride_input_n\n gate_offsets = m_offsets[:, None] * stride_input_m + n_offsets[None, :] * stride_input_n\n res_offsets = m_offsets[:, None] * stride_output_m + n_offsets[None, :] * stride_output_n\n\n up = tl.load(\n input_ptr + up_offsets,\n mask=((n_offsets < size_n)[None, :]) & ((m_offsets < size_m)[:, None]),\n other=0.0,\n )\n gate = tl.load(\n input_ptr + gate_offsets,\n mask=((n_offsets < size_n)[None, :]) & ((m_offsets < size_m)[:, None]),\n other=0.0,\n ).to(tl.float32)\n\n gate = gate / (1 + tl.exp(-gate))\n gate = gate.to(tl.float16)\n\n tl.store(\n input_ptr + res_offsets,\n up * gate,\n mask=((n_offsets < size_n)[None, :]) & ((m_offsets < size_m)[:, None]),\n )\n\n\ndef silu_and_mul_fwd(input):\n stride_input_m = input.stride(0)\n stride_input_n = input.stride(1)\n stride_output_m = input.stride(0)\n stride_output_n = input.stride(1)\n size_m = input.shape[0]\n size_n = input.shape[-1] // 2\n BLOCK_M = 128\n BLOCK_N = 128\n grid = (\n triton.cdiv(size_m, BLOCK_M),\n triton.cdiv(size_n, BLOCK_N),\n )\n _silu_and_mul_kernel[grid](\n input,\n stride_input_m,\n stride_input_n,\n stride_output_m,\n stride_output_n,\n size_m,\n size_n,\n BLOCK_M,\n BLOCK_N,\n )\n return input[:, 0 : (input.shape[-1] // 2)]\n", - "description_1": "Use triton language to implement a kernel function '_silu_and_mul_kernel' that performs element-wise SiLU (Sigmoid Linear Unit) activation followed by multiplication on a 2D input tensor. The kernel takes 8 parameters: input_ptr (pointer to input data), stride_input_m (stride for input rows), stride_input_n (stride for input columns), stride_output_m (stride for output rows), stride_output_n (stride for output columns), size_m (number of rows), size_n (number of columns), and two block sizes BLOCK_M and BLOCK_N for tiling. The function 'silu_and_mul_fwd' prepares the input tensor and launches the kernel with appropriate grid dimensions.", - "description_2": "Use triton language to create a kernel that applies SiLU activation and multiplication on a 2D tensor, and a function to launch this kernel with specified grid and block sizes.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\nimport math\n\n@triton.jit\ndef matrix_mult(x, y, B):\n return tl.dot(x, y) if B >= 16 else tl.sum(x[:, :, None] * y, 1)\n\n@triton.jit\ndef sign(x):\n return (x > 0).to(tl.float32) - (x < 0).to(tl.float32)\n\n@triton.jit\ndef scan_add_op(x1, x2):\n return x1 + x2\n\n@triton.jit\ndef mlstm_matmul_kernel(Q, K, V, F, I, M, B, H, NH: tl.constexpr, S: tl.constexpr, D: tl.constexpr, SB: tl.constexpr):\n bh_id = tl.program_id(0)\n sb_id = tl.program_id(1)\n\n batch_id = bh_id // NH\n head_id = bh_id % NH\n\n batch_offset_q = batch_id * NH * S * D + head_id * S * D\n batch_offset_f = batch_id * NH * S + head_id * S\n offset_q = tl.arange(0, SB) + sb_id * SB\n offset_k = tl.arange(0, SB) + sb_id * SB\n d_range = tl.arange(0, D)\n\n q_range = batch_offset_q + offset_q[:, None] * D + d_range[None, :]\n q_mask = (offset_q[:, None] < S) & (d_range[None, :] < D)\n q = tl.load(Q + q_range, q_mask)\n f = tl.load(F + batch_offset_f + offset_q, offset_q < S)\n f = tl.cumsum(tl.log(tl.sigmoid(f)))\n\n c_acc = tl.zeros((SB, D), dtype=tl.float32)\n b_acc = tl.zeros((SB,), dtype=tl.float32)\n m_acc = tl.zeros((SB,), dtype=tl.float32) - float(\"inf\")\n for j in range(sb_id, -1, -1):\n kv_range = batch_offset_q + offset_k[:, None] * D + d_range[None, :]\n kv_mask = (offset_k[:, None] < S) & (d_range[None, :] < D)\n k = tl.load(K + kv_range, kv_mask) / tl.sqrt(tl.full((1,), D, dtype=tl.float32))\n v = tl.load(V + kv_range, kv_mask)\n f_next = tl.load(F + batch_offset_f + offset_k, offset_k < S)\n i = tl.load(I + batch_offset_f + offset_k, offset_k < S)\n\n f_next = tl.log(tl.sigmoid(f_next))\n if j == sb_id:\n f_next = tl.cumsum(f_next)\n d = f[:, None] - f_next[None, :] + i[None, :]\n mask = offset_q[:, None] >= offset_k[None, :]\n d = tl.where(mask, d, -float(\"inf\"))\n else:\n f += tl.sum(f_next)\n f_next = tl.cumsum(f_next)\n d = f[:, None] - f_next[None, :] + i[None, :]\n\n m = tl.maximum(tl.max(d, 1), m_acc)\n d = tl.exp(d - m[:, None])\n\n c = matrix_mult(q, tl.trans(k), SB) * d\n b_acc = b_acc * tl.exp(m_acc - m) + tl.sum(c, 1)\n c = matrix_mult(c, v, SB)\n c_acc = c_acc * tl.exp(m_acc - m)[:, None] + c\n\n m_acc = m\n offset_k -= SB\n\n n = tl.maximum(tl.abs(b_acc), tl.exp(-m_acc)) + 1e-6\n h = c_acc / n[:, None]\n\n tl.store(H + q_range, h, q_mask)\n tl.store(B + batch_offset_f + offset_q, b_acc, offset_q < S)\n tl.store(M + batch_offset_f + offset_q, m_acc, offset_q < S)\n\ndef mlstm_matmul(q, k, v, f, i, SB=16, num_warps=8):\n B, NH, S, D = q.shape\n h = torch.zeros((B, NH, S, D), device=q.device)\n m = torch.zeros((B, NH, S), device=q.device)\n b = torch.zeros((B, NH, S), device=q.device)\n\n grid = (B * NH, triton.cdiv(S, SB))\n mlstm_matmul_kernel[grid](q, k, v, f, i, m, b, h, NH, S, D, SB, num_warps=num_warps)\n return h\n\n@triton.jit\ndef mlstm_matmul_kernel_backward_db(dH, Q, K, V, F, I, M, B, dB,\n NH: tl.constexpr,\n S: tl.constexpr,\n D: tl.constexpr,\n SB: tl.constexpr):\n bh_id = tl.program_id(0)\n sb_id = tl.program_id(1)\n\n batch_id = bh_id // NH\n head_id = bh_id % NH\n\n batch_offset_dh = batch_id * NH * S * D + head_id * S * D\n batch_offset_f = batch_id * NH * S + head_id * S\n offset_dh = tl.arange(0, SB) + sb_id * SB\n offset_vk = tl.arange(0, SB) + sb_id * SB\n d_range = tl.arange(0, D)\n\n dh_range = batch_offset_dh + offset_dh[:, None] * D + d_range[None, :]\n dh_mask = (offset_dh[:, None] < S) & (d_range[None, :] < D)\n dh = tl.load(dH + dh_range, dh_mask)\n q = tl.load(Q + dh_range, dh_mask)\n m = tl.load(M + batch_offset_f + offset_dh, offset_dh < S)\n f = tl.load(F + batch_offset_f + offset_dh, offset_dh < S)\n f = tl.cumsum(tl.log(tl.sigmoid(f)))\n scale = tl.sqrt(tl.full((1,), D, dtype=tl.float32))\n\n dn_acc = tl.zeros((SB,), dtype=tl.float32)\n for j in range(sb_id, -1, -1):\n vk_range = batch_offset_dh + offset_vk[:, None] * D + d_range[None, :]\n vk_mask = (offset_vk[:, None] < S) & (d_range[None, :] < D)\n v = tl.load(V + vk_range, vk_mask)\n f_next = tl.load(F + batch_offset_f + offset_vk, offset_vk < S)\n i = tl.load(I + batch_offset_f + offset_vk, offset_vk < S)\n\n f_next = tl.log(tl.sigmoid(f_next))\n if j == sb_id:\n f_next = tl.cumsum(f_next)\n d = f[:, None] - f_next[None, :] + i[None, :]\n mask = offset_dh[:, None] >= offset_vk[None, :]\n d = tl.where(mask, d, -float('inf'))\n else:\n f += tl.sum(f_next)\n f_next = tl.cumsum(f_next)\n d = f[:, None] - f_next[None, :] + i[None, :]\n\n d = tl.exp(d - m[:, None])\n dc = matrix_mult(dh, tl.trans(v), SB)\n\n k = tl.load(K + vk_range, vk_mask) / scale\n c_tilde = matrix_mult(q, tl.trans(k), SB) * d\n dn_acc += tl.sum(c_tilde * dc, 1)\n\n offset_vk -= SB\n\n b = tl.load(B + batch_offset_f + offset_dh, offset_dh < S)\n n = tl.maximum(tl.abs(b), tl.exp(-m)) + 1e-6\n dn = -dn_acc * (1 / tl.exp(tl.log(n) * 2.0))\n db = sign(b) * dn * tl.where(tl.abs(b) > tl.exp(-m), 1.0, 0.0)\n tl.store(dB + batch_offset_f + offset_dh, db, offset_dh < S)\n\n@triton.jit\ndef mlstm_matmul_kernel_backward(dH, dB, Q, K, V, dQ, dK, dV, F, dF, I, dI, M, B,\n NH: tl.constexpr,\n S: tl.constexpr,\n D: tl.constexpr,\n SB: tl.constexpr):\n bh_id = tl.program_id(0)\n sb_id = tl.program_id(1)\n\n batch_id = bh_id // NH\n head_id = bh_id % NH\n\n batch_offset_dh = batch_id * NH * S * D + head_id * S * D\n batch_offset_f = batch_id * NH * S + head_id * S\n offset_dh = tl.arange(0, SB) + sb_id * SB\n offset_vk = tl.arange(0, SB) + sb_id * SB\n d_range = tl.arange(0, D)\n\n dh_range = batch_offset_dh + offset_dh[:, None] * D + d_range[None, :]\n dh_mask = (offset_dh[:, None] < S) & (d_range[None, :] < D)\n dh = tl.load(dH + dh_range, dh_mask)\n m = tl.load(M + batch_offset_f + offset_dh, offset_dh < S)\n b = tl.load(B + batch_offset_f + offset_dh, offset_dh < S)\n f = tl.load(F + batch_offset_f + offset_dh, offset_dh < S)\n db = tl.load(dB + batch_offset_f + offset_dh, offset_dh < S)\n\n q = tl.load(Q + dh_range, dh_mask)\n scale = tl.sqrt(tl.full((1,), D, dtype=tl.float32))\n n = tl.maximum(tl.abs(b), tl.exp(-m)) + 1e-6\n f = tl.cumsum(tl.log(tl.sigmoid(f)))\n f_low = f\n\n df_acc = tl.zeros((SB,), dtype=tl.float32)\n dq_acc = tl.zeros((SB, D), dtype=tl.float32)\n for j in range(sb_id, -1, -1):\n vk_range = batch_offset_dh + offset_vk[:, None] * D + d_range[None, :]\n vk_mask = (offset_vk[:, None] < S) & (d_range[None, :] < D)\n f_next = tl.load(F + batch_offset_f + offset_vk, offset_vk < S)\n i = tl.load(I + batch_offset_f + offset_vk, offset_vk < S)\n\n f_next = tl.log(tl.sigmoid(f_next))\n if j == sb_id:\n f_next = tl.cumsum(f_next)\n d = f[:, None] - f_next[None, :] + i[None, :]\n mask = offset_dh[:, None] >= offset_vk[None, :]\n d = tl.where(mask, d, -float('inf'))\n else:\n f += tl.sum(f_next)\n f_next = tl.cumsum(f_next)\n d = f[:, None] - f_next[None, :] + i[None, :]\n\n d = tl.exp(d - m[:, None])\n v = tl.load(V + vk_range, vk_mask)\n dc_tilde = matrix_mult(dh, tl.trans(v), SB) * (1 / n)[:, None] + db[:, None]\n\n k = tl.load(K + vk_range, vk_mask) / scale\n dq_acc += matrix_mult(dc_tilde * d, k, SB)\n c_tilde = matrix_mult(q, tl.trans(k), SB) * d\n df_acc += tl.sum(c_tilde * dc_tilde, 1)\n\n offset_vk -= SB\n\n tl.store(dQ + dh_range, dq_acc, dh_mask)\n\n offset_q = tl.arange(0, SB) + sb_id * SB\n f = tl.zeros((1,), dtype=tl.float32)\n\n v = tl.load(V + dh_range, dh_mask)\n k = tl.load(K + dh_range, dh_mask)\n i = tl.load(I + batch_offset_f + offset_dh, offset_dh < S)\n\n dk_acc = tl.zeros((SB, D), dtype=tl.float32)\n dv_acc = tl.zeros((SB, D), dtype=tl.float32)\n di_acc = tl.zeros((SB,), dtype=tl.float32)\n for j in range(sb_id, tl.cdiv(S, SB)):\n q_range = batch_offset_dh + offset_q[:, None] * D + d_range[None, :]\n q_mask = (offset_q[:, None] < S) & (d_range[None, :] < D)\n f_next = tl.load(F + batch_offset_f + offset_q, offset_q < S)\n\n f_next = tl.log(tl.sigmoid(f_next))\n f_next_sum = tl.sum(f_next)\n f_next = f + tl.cumsum(f_next)\n d = f_next[None, :] - f_low[:, None] + i[:, None]\n f += f_next_sum\n\n if j == sb_id:\n mask = offset_dh[:, None] <= offset_q[None, :]\n d = tl.where(mask, d, -float('inf'))\n\n dh = tl.load(dH + q_range, q_mask)\n m = tl.load(M + batch_offset_f + offset_q, offset_q < S)\n b = tl.load(B + batch_offset_f + offset_q, offset_q < S)\n db = tl.load(dB + batch_offset_f + offset_q, offset_q < S)\n\n d = tl.exp(d - m[None, :])\n n = tl.maximum(tl.abs(b), tl.exp(-m)) + 1e-6\n dc_tilde_T = matrix_mult(v, tl.trans(dh), SB) * (1 / n)[None, :] + db[None, :]\n\n q = tl.load(Q + q_range, q_mask) / scale\n dk_acc += matrix_mult(dc_tilde_T * d, q, SB)\n\n c_tilde_T = matrix_mult(k, tl.trans(q), SB) * d\n dv_acc += matrix_mult(c_tilde_T / n[None, :], dh, SB)\n di_acc += tl.sum(c_tilde_T * dc_tilde_T, 1)\n\n offset_q += SB\n\n tl.store(dK + dh_range, dk_acc, dh_mask)\n tl.store(dV + dh_range, dv_acc, dh_mask)\n tl.store(dI + batch_offset_f + offset_dh, di_acc, offset_dh < S)\n tl.store(dF + batch_offset_f + offset_dh + 1, di_acc - df_acc, (offset_dh + 1) < S)\n\n@triton.jit\ndef mlstm_matmul_kernel_df(dF, F, NH: tl.constexpr, S: tl.constexpr):\n bh_id = tl.program_id(0)\n batch_id = bh_id // NH\n head_id = bh_id % NH\n\n batch_offset_f = batch_id * NH * S + head_id * S\n offset_f = tl.arange(0, S)\n\n df = tl.load(dF + batch_offset_f + offset_f, offset_f < S)\n df = tl.associative_scan(df, 0, scan_add_op)\n\n f = tl.load(F + batch_offset_f + offset_f, offset_f < S)\n df = tl.sigmoid(-f) * df\n tl.store(dF + batch_offset_f + offset_f, df, offset_f < S)\n\nclass Triton_mLSTM(torch.autograd.Function):\n @staticmethod\n def forward(ctx, q, k, v, f, i, SB=16, num_warps=8):\n B, NH, S, D = q.shape\n h = torch.zeros((B, NH, S, D), device=q.device)\n m = torch.zeros((B, NH, S), device=q.device)\n b = torch.zeros((B, NH, S), device=q.device)\n\n grid = (B * NH, triton.cdiv(S, SB))\n mlstm_matmul_kernel[grid](q, k, v, f, i, m, b, h, NH, S, D, SB, num_warps=num_warps)\n ctx.save_for_backward(q, k, v, f, i, m, b)\n ctx.sb = SB\n return h\n\n @staticmethod\n def backward(ctx, dh):\n assert dh.is_contiguous()\n q, k, v, f, i, m, b = ctx.saved_tensors\n SB = ctx.sb\n\n dq = torch.zeros_like(q)\n dk = torch.zeros_like(k)\n dv = torch.zeros_like(v)\n df = torch.zeros_like(f)\n di = torch.zeros_like(i)\n db = torch.zeros_like(b)\n\n B, NH, S, D = q.shape\n\n batches = B * NH\n grid = (batches, triton.cdiv(S, SB))\n num_warps = 8\n mlstm_matmul_kernel_backward_db[grid](dh, q, k, v, f, i, m, b, db, NH, S, D, SB, num_warps=num_warps)\n mlstm_matmul_kernel_backward[grid](dh, db, q, k, v, dq, dk, dv, f, df, i, di, m, b, NH, S, D, SB, num_warps=num_warps)\n mlstm_matmul_kernel_df[(batches,)](df, f, NH, S, num_warps=num_warps)\n\n return dq, dk, dv, df, di, None, None\n\nif __name__ == '__main__':\n BATCH = 1\n HEADS = 4\n S = 2048\n D = 64\n SB = 32\n NUM_WARPS = 4\n\n q = torch.randn((BATCH, HEADS, S, D), device=DEVICE, dtype=torch.float32, requires_grad=True)\n k = torch.randn((BATCH, HEADS, S, D), device=DEVICE, dtype=torch.float32, requires_grad=True)\n v = torch.randn((BATCH, HEADS, S, D), device=DEVICE, dtype=torch.float32, requires_grad=True)\n f = torch.randn((BATCH, HEADS, S), device=DEVICE, dtype=torch.float32, requires_grad=True)\n i = torch.randn((BATCH, HEADS, S), device=DEVICE, dtype=torch.float32, requires_grad=True)\n dh = torch.randn((BATCH, HEADS, S, D), device=DEVICE, dtype=torch.float32)\n\n h_triton = Triton_mLSTM.apply(q, k, v, f, i, SB, NUM_WARPS)\n", - "description_1": "Use triton language to define and execute multiple kernels for matrix operations, sign calculation, and a specific multi-head LSTM-like attention mechanism. This includes a kernel for forward computation (mlstm_matmul_kernel) taking 12 arguments where the main matrices involved are Q, K, V for queries, keys, and values. A backward computation is facilitated with separate kernels (mlstm_matmul_kernel_backward_db, mlstm_matmul_kernel_backward, and mlstm_matmul_kernel_df) managing gradient calculations with 9 to 13 arguments each, handling gradients for matrices such as dQ, dK, dV and additional computations for gradients of biases and transformations.", - "description_2": "Use triton language to implement and manage complex multi-head matrix operations, including both forward and backward passes for a custom LSTM-like attention mechanism using matrix multiplication and gradient backpropagation across multiple Triton kernels.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\nfrom vllm.model_executor.layers.ops.sample import _uniform_to_exponential\n\n@triton.jit\ndef _uniform_to_exponential_kernel(input, output, n: tl.constexpr):\n # This kernel function converts uniform distribution values to exponential.\n idx = tl.arange(0, n)\n x = tl.load(input + idx)\n y = _uniform_to_exponential(x)\n tl.store(output + idx, y)\n\ndef test_uniform_to_exponential():\n \"\"\"Test function to validate conversion from uniform to exponential.\"\"\"\n input = torch.tensor([0.0, 1.0 - torch.finfo(torch.float32).eps],\n dtype=torch.float32,\n device=\"cuda\")\n output = torch.zeros(input.shape, dtype=torch.float32, device=\"cuda\")\n # Launches the triton kernel with a grid size of 1.\n _uniform_to_exponential_kernel[(1, )](input, output, 2)\n assert torch.all(torch.isfinite(output))\n assert torch.all(output > 0)\n assert torch.all(torch.isfinite(torch.full_like(output, 1.0) / output))\n", - "description_1": "Use triton language to implement a kernel function _uniform_to_exponential_kernel with 3 parameters: input (the source tensor), output (the target tensor for storing results), and n (an integer constant representing the number of elements to process). This kernel function converts uniform distribution values to exponential distribution using a custom function _uniform_to_exponential. A test function, test_uniform_to_exponential, launches the kernel to validate the conversion by checking for finite and positive results.", - "description_2": "Use triton language to create a kernel that transforms uniform values to exponential distribution, and validate the transformation using a test function.", - "difficulty": 2 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\nif triton.__version__ >= \"2.1.0\":\n\n @triton.jit\n def _fwd_kernel(\n Q,\n K,\n V,\n K_cache,\n V_cache,\n B_Loc,\n sm_scale,\n B_Start_Loc,\n B_Seqlen,\n B_Ctxlen,\n block_size,\n x,\n Out,\n stride_b_loc_b,\n stride_b_loc_s,\n stride_qbs,\n stride_qh,\n stride_qd,\n stride_kbs,\n stride_kh,\n stride_kd,\n stride_vbs,\n stride_vh,\n stride_vd,\n stride_obs,\n stride_oh,\n stride_od,\n stride_k_cache_bs,\n stride_k_cache_h,\n stride_k_cache_d,\n stride_k_cache_bl,\n stride_k_cache_x,\n stride_v_cache_bs,\n stride_v_cache_h,\n stride_v_cache_d,\n stride_v_cache_bl,\n num_queries_per_kv: int,\n BLOCK_M: tl.constexpr,\n BLOCK_DMODEL: tl.constexpr,\n BLOCK_N: tl.constexpr,\n ):\n # Triton kernel for forward attention computation\n\n cur_batch = tl.program_id(0)\n cur_head = tl.program_id(1)\n start_m = tl.program_id(2)\n\n cur_kv_head = cur_head // num_queries_per_kv\n\n cur_batch_ctx_len = tl.load(B_Ctxlen + cur_batch)\n cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)\n cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)\n\n block_start_loc = BLOCK_M * start_m\n\n # initialize offsets\n offs_n = tl.arange(0, BLOCK_N)\n offs_d = tl.arange(0, BLOCK_DMODEL)\n offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)\n off_q = (\n (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs +\n cur_head * stride_qh + offs_d[None, :] * stride_qd)\n\n q = tl.load(\n Q + off_q,\n mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len,\n other=0.0)\n\n # # initialize pointer to m and l\n m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float(\"inf\")\n l_i = tl.zeros([BLOCK_M], dtype=tl.float32)\n acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)\n\n for start_n in range(0, cur_batch_ctx_len, BLOCK_N):\n start_n = tl.multiple_of(start_n, BLOCK_N)\n # -- compute qk ----\n bn = tl.load(B_Loc + cur_batch * stride_b_loc_b +\n ((start_n + offs_n) // block_size) * stride_b_loc_s,\n mask=(start_n + offs_n) < cur_batch_ctx_len,\n other=0)\n off_k = (bn[None, :] * stride_k_cache_bs +\n cur_kv_head * stride_k_cache_h +\n (offs_d[:, None] // x) * stride_k_cache_d +\n ((start_n + offs_n[None, :]) % block_size) *\n stride_k_cache_bl +\n (offs_d[:, None] % x) * stride_k_cache_x)\n off_v = (\n bn[:, None] * stride_v_cache_bs +\n cur_kv_head * stride_v_cache_h +\n offs_d[None, :] * stride_v_cache_d +\n (start_n + offs_n[:, None]) % block_size * stride_v_cache_bl)\n k = tl.load(K_cache + off_k,\n mask=(start_n + offs_n[None, :]) < cur_batch_ctx_len,\n other=0.0)\n\n qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)\n qk += tl.dot(q, k)\n qk = tl.where((start_n + offs_n[None, :]) < cur_batch_ctx_len, qk,\n float(\"-inf\"))\n qk *= sm_scale\n\n # -- compute m_ij, p, l_ij\n m_ij = tl.max(qk, 1)\n p = tl.exp(qk - m_ij[:, None])\n l_ij = tl.sum(p, 1)\n # -- update m_i and l_i\n m_i_new = tl.maximum(m_i, m_ij)\n alpha = tl.exp(m_i - m_i_new)\n beta = tl.exp(m_ij - m_i_new)\n l_i_new = alpha * l_i + beta * l_ij\n # -- update output accumulator --\n # scale p\n p_scale = beta / l_i_new\n p = p * p_scale[:, None]\n # scale acc\n acc_scale = l_i / l_i_new * alpha\n acc = acc * acc_scale[:, None]\n # update acc\n v = tl.load(V_cache + off_v,\n mask=(start_n + offs_n[:, None]) < cur_batch_ctx_len,\n other=0.0)\n\n p = p.to(v.dtype)\n acc += tl.dot(p, v)\n # # update m_i and l_i\n l_i = l_i_new\n m_i = m_i_new\n\n off_k = (offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh +\n offs_d[:, None] * stride_kd)\n off_v = (offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh +\n offs_d[None, :] * stride_vd)\n k_ptrs = K + off_k\n v_ptrs = V + off_v\n\n block_mask = tl.where(\n block_start_loc < cur_batch_seq_len - cur_batch_ctx_len, 1, 0)\n\n for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N):\n start_n = tl.multiple_of(start_n, BLOCK_N)\n # -- compute qk ----\n k = tl.load(k_ptrs +\n (cur_batch_in_all_start_index + start_n) * stride_kbs,\n mask=(start_n + offs_n[None, :]) <\n cur_batch_seq_len - cur_batch_ctx_len,\n other=0.0)\n\n qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)\n qk += tl.dot(q, k)\n qk *= sm_scale\n qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk,\n float(\"-inf\"))\n\n # -- compute m_ij, p, l_ij\n m_ij = tl.max(qk, 1)\n p = tl.exp(qk - m_ij[:, None])\n l_ij = tl.sum(p, 1)\n # -- update m_i and l_i\n m_i_new = tl.maximum(m_i, m_ij)\n alpha = tl.exp(m_i - m_i_new)\n beta = tl.exp(m_ij - m_i_new)\n l_i_new = alpha * l_i + beta * l_ij\n # -- update output accumulator --\n # scale p\n p_scale = beta / l_i_new\n p = p * p_scale[:, None]\n # scale acc\n acc_scale = l_i / l_i_new * alpha\n acc = acc * acc_scale[:, None]\n # update acc\n v = tl.load(v_ptrs +\n (cur_batch_in_all_start_index + start_n) * stride_vbs,\n mask=(start_n + offs_n[:, None]) <\n cur_batch_seq_len - cur_batch_ctx_len,\n other=0.0)\n\n p = p.to(v.dtype)\n acc += tl.dot(p, v)\n # update m_i and l_i\n l_i = l_i_new\n m_i = m_i_new\n # initialize pointers to output\n off_o = (\n (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs +\n cur_head * stride_oh + offs_d[None, :] * stride_od)\n out_ptrs = Out + off_o\n tl.store(out_ptrs,\n acc,\n mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len)\n return\n\n @triton.jit\n def _fwd_kernel_alibi(\n Q,\n K,\n V,\n K_cache,\n V_cache,\n B_Loc,\n sm_scale,\n B_Start_Loc,\n B_Seqlen,\n B_Ctxlen,\n Alibi_slopes,\n block_size,\n x,\n Out,\n stride_b_loc_b,\n stride_b_loc_s,\n stride_qbs,\n stride_qh,\n stride_qd,\n stride_kbs,\n stride_kh,\n stride_kd,\n stride_vbs,\n stride_vh,\n stride_vd,\n stride_obs,\n stride_oh,\n stride_od,\n stride_k_cache_bs,\n stride_k_cache_h,\n stride_k_cache_d,\n stride_k_cache_bl,\n stride_k_cache_x,\n stride_v_cache_bs,\n stride_v_cache_h,\n stride_v_cache_d,\n stride_v_cache_bl,\n num_queries_per_kv: int,\n BLOCK_M: tl.constexpr,\n BLOCK_DMODEL: tl.constexpr,\n BLOCK_N: tl.constexpr,\n ):\n # Triton kernel for forward attention with Alibi\n\n cur_batch = tl.program_id(0)\n cur_head = tl.program_id(1)\n start_m = tl.program_id(2)\n\n cur_kv_head = cur_head // num_queries_per_kv\n\n cur_batch_ctx_len = tl.load(B_Ctxlen + cur_batch)\n cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)\n cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)\n\n block_start_loc = BLOCK_M * start_m\n\n # initialize offsets\n offs_n = tl.arange(0, BLOCK_N)\n offs_d = tl.arange(0, BLOCK_DMODEL)\n offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)\n off_q = (\n (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs +\n cur_head * stride_qh + offs_d[None, :] * stride_qd)\n\n q = tl.load(\n Q + off_q,\n mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len,\n other=0.0)\n\n # # initialize pointer to m and l\n m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float(\"inf\")\n l_i = tl.zeros([BLOCK_M], dtype=tl.float32)\n acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)\n\n alibi_slope = tl.load(Alibi_slopes + cur_head)\n alibi_start_q = tl.arange(\n 0, BLOCK_M) + block_start_loc + cur_batch_ctx_len\n alibi_start_k = 0\n for start_n in range(0, cur_batch_ctx_len, BLOCK_N):\n start_n = tl.multiple_of(start_n, BLOCK_N)\n # -- compute qk ----\n bn = tl.load(B_Loc + cur_batch * stride_b_loc_b +\n ((start_n + offs_n) // block_size) * stride_b_loc_s,\n mask=(start_n + offs_n) < cur_batch_ctx_len,\n other=0)\n off_k = (bn[None, :] * stride_k_cache_bs +\n cur_kv_head * stride_k_cache_h +\n (offs_d[:, None] // x) * stride_k_cache_d +\n ((start_n + offs_n[None, :]) % block_size) *\n stride_k_cache_bl +\n (offs_d[:, None] % x) * stride_k_cache_x)\n off_v = (\n bn[:, None] * stride_v_cache_bs +\n cur_kv_head * stride_v_cache_h +\n offs_d[None, :] * stride_v_cache_d +\n (start_n + offs_n[:, None]) % block_size * stride_v_cache_bl)\n k = tl.load(K_cache + off_k,\n mask=(start_n + offs_n[None, :]) < cur_batch_ctx_len,\n other=0.0)\n\n qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)\n qk += tl.dot(q, k)\n qk = tl.where((start_n + offs_n[None, :]) < cur_batch_ctx_len, qk,\n float(\"-inf\"))\n qk *= sm_scale\n\n # load alibi\n alibi = (tl.arange(0, BLOCK_N)[None, :] + alibi_start_k -\n alibi_start_q[:, None]) * alibi_slope\n alibi = tl.where(\n (alibi <= 0) & (alibi_start_q[:, None] < cur_batch_seq_len),\n alibi, float(\"-inf\"))\n qk += alibi\n alibi_start_k += BLOCK_N\n\n # -- compute m_ij, p, l_ij\n m_ij = tl.max(qk, 1)\n m_i_new = tl.maximum(m_i, m_ij)\n p = tl.math.exp(qk - m_i_new[:, None])\n l_ij = tl.sum(p, 1)\n # -- update m_i and l_i\n\n alpha = tl.math.exp(m_i - m_i_new)\n l_i_new = alpha * l_i + l_ij\n # -- update output accumulator --\n # scale p\n # scale acc\n acc_scale = alpha\n acc = acc * acc_scale[:, None]\n # update acc\n v = tl.load(V_cache + off_v,\n mask=(start_n + offs_n[:, None]) < cur_batch_ctx_len,\n other=0.0)\n\n p = p.to(v.dtype)\n acc += tl.dot(p, v, allow_tf32=False)\n # update m_i and l_i\n l_i = l_i_new\n m_i = m_i_new\n\n off_k = (offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh +\n offs_d[:, None] * stride_kd)\n off_v = (offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh +\n offs_d[None, :] * stride_vd)\n k_ptrs = K + off_k\n v_ptrs = V + off_v\n\n block_mask = tl.where(\n block_start_loc < cur_batch_seq_len - cur_batch_ctx_len, 1, 0)\n\n # init alibi\n alibi_slope = tl.load(Alibi_slopes + cur_head)\n alibi_start_q = tl.arange(\n 0, BLOCK_M) + block_start_loc + cur_batch_ctx_len\n alibi_start_k = cur_batch_ctx_len\n for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N):\n start_n = tl.multiple_of(start_n, BLOCK_N)\n # -- compute qk ----\n k = tl.load(k_ptrs +\n (cur_batch_in_all_start_index + start_n) * stride_kbs,\n mask=(start_n + offs_n[None, :]) <\n cur_batch_seq_len - cur_batch_ctx_len,\n other=0.0)\n\n qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)\n qk += tl.dot(q, k, allow_tf32=False)\n qk *= sm_scale\n qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk,\n float(\"-inf\"))\n\n # load alibi\n alibi = (tl.arange(0, BLOCK_N)[None, :] + alibi_start_k -\n alibi_start_q[:, None]) * alibi_slope\n alibi = tl.where(\n (alibi <= 0) & (alibi_start_q[:, None] < cur_batch_seq_len),\n alibi, float(\"-inf\"))\n qk += alibi\n alibi_start_k += BLOCK_N\n\n # -- compute m_ij, p, l_ij\n m_ij = tl.max(qk, 1)\n m_i_new = tl.maximum(m_i, m_ij)\n p = tl.math.exp(qk - m_i_new[:, None])\n l_ij = tl.sum(p, 1)\n # -- update m_i and l_i\n\n alpha = tl.math.exp(m_i - m_i_new)\n l_i_new = alpha * l_i + l_ij\n # -- update output accumulator --\n # scale p\n # scale acc\n acc_scale = alpha\n acc = acc * acc_scale[:, None]\n # update acc\n v = tl.load(v_ptrs +\n (cur_batch_in_all_start_index + start_n) * stride_vbs,\n mask=(start_n + offs_n[:, None]) <\n cur_batch_seq_len - cur_batch_ctx_len,\n other=0.0)\n\n p = p.to(v.dtype)\n acc += tl.dot(p, v, allow_tf32=False)\n # update m_i and l_i\n l_i = l_i_new\n m_i = m_i_new\n\n acc = acc / l_i[:, None]\n\n # initialize pointers to output\n off_o = (\n (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs +\n cur_head * stride_oh + offs_d[None, :] * stride_od)\n out_ptrs = Out + off_o\n tl.store(out_ptrs,\n acc,\n mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len)\n return\n\n @torch.inference_mode()\n def context_attention_fwd(q,\n k,\n v,\n o,\n k_cache,\n v_cache,\n b_loc,\n b_start_loc,\n b_seq_len,\n b_ctx_len,\n max_input_len,\n alibi_slopes=None):\n # Wrapper function for forward attention computation\n\n cap = torch.cuda.get_device_capability()\n BLOCK = 128 if cap[0] >= 8 else 64\n # shape constraints\n Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]\n assert Lq == Lk and Lk == Lv\n assert Lk in {16, 32, 64, 128}\n\n sm_scale = 1.0 / (Lq**0.5)\n batch, head = b_seq_len.shape[0], q.shape[1]\n num_queries_per_kv = q.shape[1] // k.shape[1]\n\n grid = (batch, head, triton.cdiv(max_input_len, BLOCK)) # batch, head,\n\n num_warps = 8 if Lk <= 64 else 8\n if alibi_slopes is not None:\n _fwd_kernel_alibi[grid](\n q,\n k,\n v,\n k_cache,\n v_cache,\n b_loc,\n sm_scale,\n b_start_loc,\n b_seq_len,\n b_ctx_len,\n alibi_slopes,\n v_cache.shape[3],\n 8,\n o,\n b_loc.stride(0),\n b_loc.stride(1),\n q.stride(0),\n q.stride(1),\n q.stride(2),\n k.stride(0),\n k.stride(1),\n k.stride(2),\n v.stride(0),\n v.stride(1),\n v.stride(2),\n o.stride(0),\n o.stride(1),\n o.stride(2),\n k_cache.stride(0),\n k_cache.stride(1),\n k_cache.stride(2),\n k_cache.stride(3),\n k_cache.stride(\n 4\n ), #[num_blocks, num_kv_heads, head_size/x, block_size, x]\n v_cache.stride(0),\n v_cache.stride(1),\n v_cache.stride(2),\n v_cache.stride(\n 3), #[num_blocks, num_kv_heads, head_size, block_size]\n num_queries_per_kv=num_queries_per_kv,\n BLOCK_M=BLOCK,\n BLOCK_DMODEL=Lk,\n BLOCK_N=BLOCK,\n num_warps=num_warps,\n num_stages=1,\n )\n return\n\n _fwd_kernel[grid](\n q,\n k,\n v,\n k_cache,\n v_cache,\n b_loc,\n sm_scale,\n b_start_loc,\n b_seq_len,\n b_ctx_len,\n v_cache.shape[3],\n 8,\n o,\n b_loc.stride(0),\n b_loc.stride(1),\n q.stride(0),\n q.stride(1),\n q.stride(2),\n k.stride(0),\n k.stride(1),\n k.stride(2),\n v.stride(0),\n v.stride(1),\n v.stride(2),\n o.stride(0),\n o.stride(1),\n o.stride(2),\n k_cache.stride(0),\n k_cache.stride(1),\n k_cache.stride(2),\n k_cache.stride(3),\n k_cache.stride(\n 4), #[num_blocks, num_kv_heads, head_size/x, block_size, x]\n v_cache.stride(0),\n v_cache.stride(1),\n v_cache.stride(2),\n v_cache.stride(\n 3), #[num_blocks, num_kv_heads, head_size, block_size]\n num_queries_per_kv=num_queries_per_kv,\n BLOCK_M=BLOCK,\n BLOCK_DMODEL=Lk,\n BLOCK_N=BLOCK,\n num_warps=num_warps,\n num_stages=1,\n )\n return\n", - "description_1": "Use triton language to implement multiple forward attention kernels. Each kernel has its own logic for attention calculation, supporting different scenarios like using Alibi. The kernels are optimized for execution on a grid of blocks. Each function has a variety of parameters: Q, K, V for query, key, and value matrices; K_cache, V_cache for cached keys and values; B_Loc, B_Start_Loc, B_Seqlen, B_Ctxlen for batching and sequencing information; strides for memory layout; num_queries_per_kv for queries per key-value head count; BLOCK_M, BLOCK_DMODEL, BLOCK_N for block dimensions.", - "description_2": "Implement attention forward pass kernels using Triton for high-performance GPU execution. Support advanced use cases such as Alibi by handling queries, keys, values, and their caches, adjusting for batching and different head configurations.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef fused_moe_kernel(\n a_ptr,\n b_ptr,\n c_ptr,\n topk_weights_ptr,\n sorted_token_ids_ptr,\n expert_ids_ptr,\n num_tokens_post_padded_ptr,\n N,\n K,\n EM,\n num_valid_tokens,\n stride_am,\n stride_ak,\n stride_be,\n stride_bk,\n stride_bn,\n stride_cm,\n stride_cn,\n BLOCK_SIZE_M: tl.constexpr,\n BLOCK_SIZE_N: tl.constexpr,\n BLOCK_SIZE_K: tl.constexpr,\n GROUP_SIZE_M: tl.constexpr,\n MUL_ROUTED_WEIGHT: tl.constexpr,\n top_k: tl.constexpr,\n compute_type: tl.constexpr,\n):\n pid = tl.program_id(axis=0)\n num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M)\n num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)\n num_pid_in_group = GROUP_SIZE_M * num_pid_n\n group_id = pid // num_pid_in_group\n first_pid_m = group_id * GROUP_SIZE_M\n group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)\n pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)\n pid_n = (pid % num_pid_in_group) // group_size_m\n\n num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr)\n if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded:\n return\n offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_token = tl.load(sorted_token_ids_ptr + offs_token_id)\n token_mask = offs_token < num_valid_tokens\n\n offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N\n offs_k = tl.arange(0, BLOCK_SIZE_K)\n a_ptrs = a_ptr + (offs_token[:, None] // top_k * stride_am +\n offs_k[None, :] * stride_ak)\n\n off_experts = tl.load(expert_ids_ptr + pid_m)\n b_ptrs = b_ptr + off_experts * stride_be + (offs_k[:, None] * stride_bk +\n offs_bn[None, :] * stride_bn)\n\n accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n\n for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):\n a = tl.load(a_ptrs,\n mask=token_mask[:, None] &\n (offs_k[None, :] < K - k * BLOCK_SIZE_K),\n other=0.0)\n b = tl.load(b_ptrs,\n mask=offs_k[:, None] < K - k * BLOCK_SIZE_K,\n other=0.0)\n accumulator += tl.dot(a, b)\n a_ptrs += BLOCK_SIZE_K * stride_ak\n b_ptrs += BLOCK_SIZE_K * stride_bk\n\n if MUL_ROUTED_WEIGHT:\n moe_weight = tl.load(topk_weights_ptr + offs_token,\n mask=token_mask,\n other=0)\n accumulator = accumulator * moe_weight[:, None]\n\n accumulator = accumulator.to(compute_type)\n offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[\n None, :]\n c_mask = token_mask[:, None] & (offs_cn[None, :] < N)\n tl.store(c_ptrs, accumulator, mask=c_mask)\n\n\ndef invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,\n topk_weights: torch.Tensor, topk_ids: torch.Tensor,\n sorted_token_ids: torch.Tensor,\n expert_ids: torch.Tensor,\n num_tokens_post_padded: torch.Tensor,\n mul_routed_weight: bool, top_k: int,\n config: Dict[str, Any]) -> None:\n assert topk_weights.stride(1) == 1\n assert sorted_token_ids.stride(0) == 1\n\n grid = lambda META: (triton.cdiv(sorted_token_ids.shape[0], META[\n 'BLOCK_SIZE_M']) * triton.cdiv(B.shape[1], META['BLOCK_SIZE_N']), )\n\n fused_moe_kernel[grid](\n A,\n B,\n C,\n topk_weights,\n sorted_token_ids,\n expert_ids,\n num_tokens_post_padded,\n B.shape[1],\n B.shape[2],\n sorted_token_ids.shape[0],\n topk_ids.numel(),\n A.stride(0),\n A.stride(1),\n B.stride(0),\n B.stride(2),\n B.stride(1),\n C.stride(1),\n C.stride(2),\n MUL_ROUTED_WEIGHT=mul_routed_weight,\n top_k=top_k,\n compute_type=tl.bfloat16 if A.dtype == torch.bfloat16 else tl.float16,\n **config,\n )\n", - "description_1": "Use triton language to implement a fused Mixture of Experts (MoE) kernel. The kernel performs matrix multiplication between input tokens and expert matrices, using top-k routing weights. It handles token sorting and padding to ensure block size alignment for efficient computation. The kernel is invoked with a grid configuration that determines the block sizes and other meta-parameters.", - "description_2": "Use triton language to create a kernel for efficient matrix multiplication in a Mixture of Experts model, utilizing top-k routing and block size alignment.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\ndef seeded_uniform(\n *size,\n seeds: torch.Tensor,\n out: Optional[torch.Tensor] = None,\n dtype: Optional[torch.dtype] = None,\n device: Optional[Union[torch.device, str]] = None,\n pin_memory: Optional[bool] = False,\n) -> torch.Tensor:\n n_dims = len(size)\n\n if n_dims > 3:\n raise ValueError(\"seeded_uniform only supports up to 3D tensors\")\n\n if out is None:\n out = torch.empty(*size,\n dtype=dtype,\n device=device,\n pin_memory=pin_memory)\n elif out.shape != size:\n raise ValueError(\"shape of out and size must be the same\")\n\n if n_dims == 3:\n n_rows, n_3d, n_cols = out.shape\n stride_row = out.stride(0)\n stride_3d = out.stride(1)\n elif n_dims == 2:\n n_rows, n_cols = out.shape\n n_3d = 1\n stride_row = out.stride(0)\n stride_3d = 1\n else:\n n_cols = out.shape[0]\n n_rows = 1\n n_3d = 1\n stride_row = 1\n stride_3d = 1\n\n if seeds.ndim != 1:\n raise ValueError(\"seeds must be a 1D tensor\")\n\n if seeds.numel() != n_rows:\n raise ValueError(\n \"seeds must have the same number of elements as out has rows\")\n\n full_block_size = triton.next_power_of_2(n_cols)\n philox_block_size = max(full_block_size // 4, 1)\n n_slices = full_block_size // philox_block_size\n num_warps = 4\n\n if philox_block_size >= 8192:\n num_warps = 32\n elif philox_block_size >= 4096:\n num_warps = 16\n elif philox_block_size >= 2048:\n num_warps = 8\n\n _seeded_uniform_triton[(n_rows, n_3d)](\n out,\n seeds,\n stride_row,\n stride_3d,\n seeds.stride(0),\n n_rows,\n n_3d,\n n_cols,\n n_slices=n_slices,\n num_warps=num_warps,\n block_size=philox_block_size,\n )\n return out\n\n@triton.jit\ndef _seeded_uniform_triton(\n out_ptr: torch.Tensor,\n seed_ptr: torch.Tensor,\n out_row_stride: int,\n out_3d_stride: int,\n seed_row_stride: int,\n n_rows: int,\n n_3d: int,\n n_cols: int,\n n_slices: tl.constexpr,\n block_size: tl.constexpr,\n):\n \"\"\"\n Generate a random float32 number in [0, 1) for each element in the output\n tensor. The random numbers in a row generated using the seed for that row.\n\n Args:\n out_ptr: The output tensor.\n seed_ptr: The per-row seeds to use for random number generation.\n out_row_stride: The stride between rows of the output tensor.\n out_3d_stride: The stride between 3D slices of the output tensor.\n seed_row_stride: The stride between rows of the seed tensor.\n n_rows: The number of rows in the output tensor.\n n_3d: The size of second dimension of the output tensor,\n if output tensor is 3D.\n n_cols: The number of columns in the output tensor.\n n_slices: The number of philox outputs to use.\n \"\"\"\n tl.static_assert(n_slices > 0 and n_slices <= 4, \"0 < n_slices <= 4\")\n\n row_idx = tl.program_id(axis=0)\n three_d_idx = tl.program_id(axis=1)\n\n philox_offsets = tl.arange(0, block_size)\n seed = tl.load(seed_ptr + row_idx * seed_row_stride)\n if three_d_idx > 0:\n seed ^= three_d_idx\n\n out1, out2, out3, out4 = tl.rand4x(seed, philox_offsets)\n\n output_row_start_ptr = (out_ptr + row_idx * out_row_stride +\n three_d_idx * out_3d_stride)\n out1_offsets = philox_offsets\n tl.store(output_row_start_ptr + out1_offsets,\n out1,\n mask=out1_offsets < n_cols)\n if n_slices > 1:\n out2_offsets = tl.arange(block_size, block_size * 2)\n tl.store(output_row_start_ptr + out2_offsets,\n out2,\n mask=out2_offsets < n_cols)\n if n_slices > 2:\n out3_offsets = tl.arange(block_size * 2, block_size * 3)\n tl.store(output_row_start_ptr + out3_offsets,\n out3,\n mask=out3_offsets < n_cols)\n if n_slices > 3:\n out4_offsets = tl.arange(block_size * 3, block_size * 4)\n tl.store(output_row_start_ptr + out4_offsets,\n out4,\n mask=out4_offsets < n_cols)\n", - "description_1": "Use triton language to implement a kernel that generates random float32 numbers in [0, 1) using a per-row seed. The kernel takes an output tensor, a seed tensor, strides for rows and 3D slices, the number of rows, columns, and 3D size, and constants for block size and number of slices. The wrapper function `seeded_uniform` initializes parameters, calculates block sizes, and calls the kernel.", - "description_2": "Use triton language to create a random number generator kernel using per-row seeds, along with a Python wrapper for parameter setup and execution.", - "difficulty": 2 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n_EPS = 1e-6\n\n@triton.jit\ndef _uniform_to_exponential(uniform_noise):\n \"\"\"Convert uniform samples to exponential samples.\"\"\"\n lb = tl.full(uniform_noise.shape, _EPS, uniform_noise.dtype)\n uniform_noise = tl.maximum(uniform_noise, lb)\n exponential_noise = -tl.log(uniform_noise)\n return exponential_noise\n\n@triton.jit\ndef _sample_triton(\n sample_indices_ptr: torch.Tensor, output_ptr: torch.Tensor,\n output_logprobs_ptr: torch.Tensor,\n output_modified_probs_ptr: torch.Tensor, probs_ptr: torch.Tensor,\n logprobs_ptr: torch.Tensor, seeds_ptr: torch.Tensor,\n uniform_noise_ptr: torch.Tensor, output_row_stride: int,\n probs_row_stride: int, uniform_noise_row_stride: int,\n uniform_noise_best_stride: int, n_samples: int, n_cols: int,\n n_best: int, block_size: tl.constexpr,\n modify_greedy_probs: tl.constexpr, save_logprobs: tl.constexpr,\n save_modified_probs: tl.constexpr):\n sample_idx = tl.program_id(0)\n best_idx = tl.program_id(1)\n row_idx = tl.load(sample_indices_ptr + sample_idx)\n seed = tl.load(seeds_ptr + sample_idx)\n uses_random_sampling = seed != 0\n row_start_ptr = probs_ptr + row_idx * probs_row_stride\n col_offsets = tl.arange(0, block_size)\n row = tl.load(row_start_ptr + col_offsets,\n mask=col_offsets < n_cols,\n other=float(\"-inf\"))\n if uses_random_sampling:\n uniform_noise_start_ptr = (uniform_noise_ptr +\n sample_idx * uniform_noise_row_stride +\n best_idx * uniform_noise_best_stride)\n uniform_noise = tl.load(uniform_noise_start_ptr + col_offsets,\n mask=col_offsets < n_cols,\n other=0.5)\n exponential_noise = _uniform_to_exponential(uniform_noise)\n row /= exponential_noise\n sampled_value, sampled_token = tl.max(row, axis=0, return_indices=True)\n if sampled_token >= n_cols:\n sampled_token = n_cols - 1\n output_row_start_ptr = (output_ptr + sample_idx * output_row_stride +\n best_idx)\n tl.store(output_row_start_ptr, sampled_token)\n if modify_greedy_probs:\n if not uses_random_sampling:\n row = tl.where(col_offsets == sampled_token, 1.0, 0.0)\n tl.store(row_start_ptr + col_offsets,\n row,\n mask=col_offsets < n_cols)\n if save_modified_probs:\n output_row_start_ptr = (output_modified_probs_ptr +\n sample_idx * output_row_stride + best_idx)\n tl.store(output_row_start_ptr, sampled_value)\n if save_logprobs:\n sampled_logprob = tl.load(logprobs_ptr + row_idx * probs_row_stride +\n sampled_token)\n output_row_start_ptr = (output_logprobs_ptr +\n sample_idx * output_row_stride + best_idx)\n tl.store(output_row_start_ptr, sampled_logprob)\n\ndef _sample(probs: torch.Tensor,\n logprobs: torch.Tensor,\n sample_indices: torch.Tensor,\n output_samples: torch.Tensor,\n output_logprobs: torch.Tensor,\n output_modified_probs: torch.Tensor,\n seeds: torch.Tensor,\n uniform_noise: torch.Tensor,\n *,\n modify_greedy_probs: bool = False,\n save_logprobs: bool = True,\n save_modified_probs: bool = False) -> torch.Tensor:\n n_samples = sample_indices.shape[0]\n n_cols = probs.shape[1]\n n_best = output_samples.shape[1] if len(output_samples.shape) > 1 else 1\n block_size = triton.next_power_of_2(n_cols)\n num_warps = 4\n if block_size >= 8192:\n num_warps = 32\n elif block_size >= 4096:\n num_warps = 16\n elif block_size >= 2048:\n num_warps = 8\n _sample_triton[(n_samples, n_best)](\n sample_indices,\n output_samples,\n output_logprobs,\n output_modified_probs,\n probs,\n logprobs,\n seeds,\n uniform_noise,\n output_samples.stride(0),\n probs.stride(0),\n uniform_noise.stride(0),\n uniform_noise.stride(1) if n_best > 1 else 1,\n n_samples,\n n_cols,\n n_best,\n num_warps=num_warps,\n block_size=block_size,\n modify_greedy_probs=modify_greedy_probs,\n save_logprobs=save_logprobs,\n save_modified_probs=save_modified_probs,\n )\n return output_samples, output_logprobs, output_modified_probs\n", - "description_1": "Use triton language to implement two kernels: `_uniform_to_exponential` and `_sample_triton`. `_uniform_to_exponential` takes one parameter `uniform_noise` and converts uniform samples to exponential samples using a clamp and logarithm. `_sample_triton` involves sampling tokens given several parameters: `sample_indices_ptr`, `output_ptr`, `output_logprobs_ptr`, `output_modified_probs_ptr`, `probs_ptr`, `logprobs_ptr`, `seeds_ptr`, `uniform_noise_ptr`, and various strides and control flags. It reads and processes probability rows and applies random sampling with Gumbel noise if needed.", - "description_2": "Use triton language to create a kernel that samples from probability distributions using Gumbel noise. Implement another kernel to convert uniform random noise to exponential noise.", - "difficulty": 4 - }, - { - "code": "import triton\nimport triton.language as tl\nimport torch\nfrom .triton_utils.kernels import silu\n\n@triton.jit\ndef quant_fused_matmul_248_kernel(\n a_ptr,\n c_ptr,\n b1_ptr,\n scales1_ptr,\n zeros1_ptr,\n g1_ptr,\n b2_ptr,\n scales2_ptr,\n zeros2_ptr,\n g2_ptr,\n M,\n N,\n K,\n bits,\n maxq,\n stride_am,\n stride_ak,\n stride_bk,\n stride_bn,\n stride_cm,\n stride_cn,\n stride_scales,\n stride_zeros,\n BLOCK_SIZE_M: tl.constexpr,\n BLOCK_SIZE_N: tl.constexpr,\n BLOCK_SIZE_K: tl.constexpr,\n GROUP_SIZE_M: tl.constexpr,\n):\n \"\"\"\n Computes: C = silu(A * B1) * (A * B2)\n A is of shape (M, K) float16\n B is of shape (K//8, N) int32\n C is of shape (M, N) float16\n scales is of shape (1, N) float16\n zeros is of shape (1, N//8) int32\n \"\"\"\n infearure_per_bits = 32 // bits\n\n pid = tl.program_id(axis=0)\n num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)\n num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)\n num_pid_k = tl.cdiv(K, BLOCK_SIZE_K)\n num_pid_in_group = GROUP_SIZE_M * num_pid_n\n group_id = pid // num_pid_in_group\n first_pid_m = group_id * GROUP_SIZE_M\n group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)\n pid_m = first_pid_m + (pid % group_size_m)\n pid_n = (pid % num_pid_in_group) // group_size_m\n\n offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n offs_k = tl.arange(0, BLOCK_SIZE_K)\n a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) # (BLOCK_SIZE_M, BLOCK_SIZE_K)\n a_mask = offs_am[:, None] < M\n b1_ptrs = b1_ptr + ((offs_k[:, None] // infearure_per_bits) * stride_bk + offs_bn[None, :] * stride_bn)\n b2_ptrs = b2_ptr + ((offs_k[:, None] // infearure_per_bits) * stride_bk + offs_bn[None, :] * stride_bn)\n g1_ptrs = g1_ptr + offs_k\n g2_ptrs = g2_ptr + offs_k\n scales1_ptrs = scales1_ptr + offs_bn[None, :]\n scales2_ptrs = scales2_ptr + offs_bn[None, :]\n zeros1_ptrs = zeros1_ptr + (offs_bn[None, :] // infearure_per_bits)\n zeros2_ptrs = zeros2_ptr + (offs_bn[None, :] // infearure_per_bits)\n\n shifter = (offs_k % infearure_per_bits) * bits\n zeros_shifter = (offs_bn % infearure_per_bits) * bits\n accumulator1 = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n accumulator2 = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n for k in range(0, num_pid_k):\n g1_idx = tl.load(g1_ptrs)\n g2_idx = tl.load(g2_ptrs)\n\n scales1 = tl.load(scales1_ptrs + g1_idx[:, None] * stride_scales) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)\n scales2 = tl.load(scales2_ptrs + g2_idx[:, None] * stride_scales)\n\n zeros1 = tl.load(zeros1_ptrs + g1_idx[:, None] * stride_zeros) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)\n zeros1 = (zeros1 >> zeros_shifter[None, :]) & maxq\n zeros1 = zeros1 + 1\n\n zeros2 = tl.load(zeros2_ptrs + g2_idx[:, None] * stride_zeros) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)\n zeros2 = (zeros2 >> zeros_shifter[None, :]) & maxq\n zeros2 = zeros2 + 1\n\n a = tl.load(a_ptrs, mask=a_mask, other=0.0) # (BLOCK_SIZE_M, BLOCK_SIZE_K)\n b1 = tl.load(b1_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated\n b2 = tl.load(b2_ptrs)\n\n b1 = (b1 >> shifter[:, None]) & maxq # Extract the N-bit values\n b1 = (b1 - zeros1) * scales1 # Scale and shift\n accumulator1 += tl.dot(a, b1)\n\n b2 = (b2 >> shifter[:, None]) & maxq\n b2 = (b2 - zeros2) * scales2\n accumulator2 += tl.dot(a, b2)\n\n a_ptrs += BLOCK_SIZE_K\n b1_ptrs += (BLOCK_SIZE_K // infearure_per_bits) * stride_bk\n b2_ptrs += (BLOCK_SIZE_K // infearure_per_bits) * stride_bk\n g1_ptrs += BLOCK_SIZE_K\n g2_ptrs += BLOCK_SIZE_K\n\n accumulator1 = silu(accumulator1)\n c = accumulator1 * accumulator2\n c = c.to(tl.float16)\n c_ptrs = c_ptr + stride_cm * offs_am[:, None] + stride_cn * offs_bn[None, :]\n c_mask = (offs_am[:, None] < M) & (offs_bn[None, :] < N)\n tl.store(c_ptrs, c, mask=c_mask)\n\nclass FusedLlamaMLPForQuantizedModel:\n def triton_llama_mlp(self, x):\n with torch.cuda.device(x.device):\n out_shape = x.shape[:-1] + (self.intermediate_size,)\n x = x.reshape(-1, x.shape[-1])\n M, K = x.shape\n N = self.intermediate_size\n c = torch.empty((M, N), device=x.device, dtype=torch.float16)\n grid = lambda META: (\n triton.cdiv(M, META[\"BLOCK_SIZE_M\"]) * triton.cdiv(N, META[\"BLOCK_SIZE_N\"]),\n )\n quant_fused_matmul_248_kernel[grid](\n x,\n c,\n self.gate_proj.qweight,\n self.gate_proj.scales,\n self.gate_proj.qzeros,\n self.gate_proj.g_idx,\n self.up_proj.qweight,\n self.up_proj.scales,\n self.up_proj.qzeros,\n self.up_proj.g_idx,\n M,\n N,\n K,\n self.bits,\n self.maxq,\n x.stride(0),\n x.stride(1),\n self.gate_proj.qweight.stride(0),\n self.gate_proj.qweight.stride(1),\n c.stride(0),\n c.stride(1),\n self.gate_proj.scales.stride(0),\n self.gate_proj.qzeros.stride(0),\n )\n c = c.reshape(out_shape)\n return c\n", - "description_1": "Use triton language to implement a kernel called quant_fused_matmul_248_kernel for a fused matrix multiplication operation. The kernel has 22 parameters, including pointers to input and output matrices, dimensions, bitwidth and some constants. It computes the product of two transformed matrices A and B1, applies the SiLU activation function, and multiplies with another product of matrices A and B2.", - "description_2": "Use triton language to implement a fused MLP model with a method that prepares inputs and launches the quant_fused_matmul_248_kernel, passing 22 arguments including matrix pointers, dimensions, and constants for the computation.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef dequant_kernel_248(\n g_idx_ptr,\n scales_ptr,\n qweight_ptr,\n qzeros_ptr,\n out_ptr,\n numels,\n maxq: tl.constexpr,\n bits: tl.constexpr,\n outfeatures: tl.constexpr,\n num_groups: tl.constexpr,\n X_BLOCK: tl.constexpr,\n):\n # Block indexing\n xoffset = tl.program_id(0) * X_BLOCK\n x_index = xoffset + tl.arange(0, X_BLOCK)\n xmask = x_index < numels\n row_idx = x_index // outfeatures\n col_idx = x_index % outfeatures\n\n elements_per_feature: tl.constexpr = 32 // bits\n\n # Load parameters\n g_idx = tl.load(g_idx_ptr + (row_idx), None, eviction_policy=\"evict_last\")\n qweights = tl.load(\n qweight_ptr + (col_idx + (outfeatures * (row_idx // elements_per_feature))),\n None,\n )\n\n wf_weights = (row_idx % elements_per_feature) * bits\n\n wf_zeros = (col_idx % elements_per_feature) * bits\n\n tmp1 = g_idx + num_groups\n tmp2 = g_idx < 0\n tl.device_assert(g_idx >= 0, \"index out of bounds: 0 <= tmp0 < 0\")\n groups = tl.where(tmp2, tmp1, g_idx) # tmp3 are g_idx\n\n scales = tl.load(scales_ptr + (col_idx + (outfeatures * groups)), None).to(\n tl.float32\n )\n\n # Unpack weights\n weights = qweights >> wf_weights # bit shift qweight\n\n weights = weights & maxq\n\n # Unpack zeros\n qzero_ncols: tl.constexpr = outfeatures // elements_per_feature\n qzeros = tl.load(\n qzeros_ptr + ((qzero_ncols * groups) + (col_idx // elements_per_feature)),\n None,\n eviction_policy=\"evict_last\",\n )\n zeros = qzeros >> wf_zeros\n zeros = zeros & maxq\n\n # Dequantize\n zeros = zeros + 1\n weights = weights - zeros\n weights = weights.to(tl.float32)\n weights = scales * weights\n\n tl.store(out_ptr + (x_index), weights, mask=xmask)\n\n\ndef dequant248(qweight, scales, qzeros, g_idx, bits, maxq=None):\n \"\"\"\n Launcher for triton dequant kernel. Only valid for bits = 2, 4, 8\n \"\"\"\n\n num_groups = scales.shape[0]\n outfeatures = scales.shape[1]\n infeatures = g_idx.shape[0]\n\n out = torch.empty((infeatures, outfeatures), device=\"cuda\", dtype=torch.float16)\n numels = out.numel()\n maxq = 2**bits - 1 if maxq is None else maxq\n grid = lambda meta: (triton.cdiv(numels, meta[\"X_BLOCK\"]),) # noqa: E731\n\n dequant_kernel_248[grid](\n g_idx,\n scales,\n qweight,\n qzeros,\n out,\n numels,\n maxq=maxq,\n bits=bits,\n outfeatures=outfeatures,\n num_groups=num_groups,\n )\n return out\n", - "description_1": "Use triton language to implement a dequantization kernel (dequant_kernel_248) that processes quantized weights, scales, and zero points to produce dequantized weights. The kernel takes 11 parameters: pointers to group indices, scales, quantized weights, zero points, and output, along with the number of elements, maximum quantization value, bit width, number of output features, number of groups, and block size. The dequant248 function launches this kernel with 7 parameters: quantized weights, scales, zero points, group indices, bit width, and optionally maximum quantization value, to produce a dequantized output tensor.", - "description_2": "Use triton language to create a kernel for dequantizing weights from quantized format using scales and zero points, and a function to launch this kernel with appropriate parameters.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef quant_matmul_248_kernel(\n a_ptr,\n b_ptr,\n c_ptr,\n scales_ptr,\n zeros_ptr,\n g_ptr,\n M,\n N,\n K,\n bits,\n maxq,\n stride_am,\n stride_ak,\n stride_bk,\n stride_bn,\n stride_cm,\n stride_cn,\n stride_scales,\n stride_zeros,\n BLOCK_SIZE_M: tl.constexpr,\n BLOCK_SIZE_N: tl.constexpr,\n BLOCK_SIZE_K: tl.constexpr,\n GROUP_SIZE_M: tl.constexpr,\n):\n \"\"\"\n Compute the matrix multiplication C = A x B.\n A is of shape (M, K) float16\n B is of shape (K//8, N) int32\n C is of shape (M, N) float16\n scales is of shape (G, N) float16\n zeros is of shape (G, N) float16\n g_ptr is of shape (K) int32\n \"\"\"\n infearure_per_bits = 32 // bits\n\n pid = tl.program_id(axis=0)\n num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)\n num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)\n num_pid_k = tl.cdiv(K, BLOCK_SIZE_K)\n num_pid_in_group = GROUP_SIZE_M * num_pid_n\n group_id = pid // num_pid_in_group\n first_pid_m = group_id * GROUP_SIZE_M\n group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)\n pid_m = first_pid_m + (pid % group_size_m)\n pid_n = (pid % num_pid_in_group) // group_size_m\n\n offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n offs_k = tl.arange(0, BLOCK_SIZE_K)\n a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) # (BLOCK_SIZE_M, BLOCK_SIZE_K)\n a_mask = offs_am[:, None] < M\n # b_ptrs is set up such that it repeats elements along the K axis 8 times\n b_ptrs = b_ptr + (\n (offs_k[:, None] // infearure_per_bits) * stride_bk + offs_bn[None, :] * stride_bn\n ) # (BLOCK_SIZE_K, BLOCK_SIZE_N)\n g_ptrs = g_ptr + offs_k\n # shifter is used to extract the N bits of each element in the 32-bit word from B\n scales_ptrs = scales_ptr + offs_bn[None, :]\n zeros_ptrs = zeros_ptr + (offs_bn[None, :] // infearure_per_bits)\n\n shifter = (offs_k % infearure_per_bits) * bits\n zeros_shifter = (offs_bn % infearure_per_bits) * bits\n accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n\n for k in range(0, num_pid_k):\n g_idx = tl.load(g_ptrs)\n\n # Fetch scales and zeros; these are per-outfeature and thus reused in the inner loop\n scales = tl.load(scales_ptrs + g_idx[:, None] * stride_scales) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)\n zeros = tl.load(zeros_ptrs + g_idx[:, None] * stride_zeros) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)\n\n zeros = (zeros >> zeros_shifter[None, :]) & maxq\n zeros = zeros + 1\n\n a = tl.load(a_ptrs, mask=a_mask, other=0.0) # (BLOCK_SIZE_M, BLOCK_SIZE_K)\n b = tl.load(b_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated\n\n # Now we need to unpack b (which is N-bit values) into 32-bit values\n b = (b >> shifter[:, None]) & maxq # Extract the N-bit values\n b = (b - zeros) * scales # Scale and shift\n\n accumulator += tl.dot(a, b)\n a_ptrs += BLOCK_SIZE_K\n b_ptrs += (BLOCK_SIZE_K // infearure_per_bits) * stride_bk\n g_ptrs += BLOCK_SIZE_K\n\n c_ptrs = c_ptr + stride_cm * offs_am[:, None] + stride_cn * offs_bn[None, :]\n c_mask = (offs_am[:, None] < M) & (offs_bn[None, :] < N)\n tl.store(c_ptrs, accumulator, mask=c_mask)\n\n\n@triton.jit\ndef transpose_quant_matmul_248_kernel(\n a_ptr,\n b_ptr,\n c_ptr,\n scales_ptr,\n zeros_ptr,\n g_ptr,\n M,\n N,\n K,\n bits,\n maxq,\n stride_am,\n stride_ak,\n stride_bk,\n stride_bn,\n stride_cm,\n stride_cn,\n stride_scales,\n stride_zeros,\n BLOCK_SIZE_M: tl.constexpr,\n BLOCK_SIZE_N: tl.constexpr,\n BLOCK_SIZE_K: tl.constexpr,\n GROUP_SIZE_M: tl.constexpr,\n):\n \"\"\"\n Compute the matrix multiplication C = A x B.\n A is of shape (M, N) float16\n B is of shape (K//8, N) int32\n C is of shape (M, K) float16\n scales is of shape (G, N) float16\n zeros is of shape (G, N) float16\n g_ptr is of shape (K) int32\n \"\"\"\n infearure_per_bits = 32 // bits\n\n pid = tl.program_id(axis=0)\n num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)\n num_pid_k = tl.cdiv(K, BLOCK_SIZE_K)\n num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)\n num_pid_in_group = GROUP_SIZE_M * num_pid_k\n group_id = pid // num_pid_in_group\n first_pid_m = group_id * GROUP_SIZE_M\n group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)\n pid_m = first_pid_m + (pid % group_size_m)\n pid_k = (pid % num_pid_in_group) // group_size_m\n\n offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_bk = pid_k * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)\n offs_n = tl.arange(0, BLOCK_SIZE_N)\n a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_n[None, :] * stride_ak) # (BLOCK_SIZE_M, BLOCK_SIZE_N)\n a_mask = offs_am[:, None] < M\n # b_ptrs is set up such that it repeats elements along the K axis 8 times\n b_ptrs = b_ptr + (\n (offs_bk[:, None] // infearure_per_bits) * stride_bk + offs_n[None, :] * stride_bn\n ) # (BLOCK_SIZE_K, BLOCK_SIZE_N)\n g_ptrs = g_ptr + offs_bk\n g_idx = tl.load(g_ptrs)\n\n # shifter is used to extract the N bits of each element in the 32-bit word from B\n scales_ptrs = scales_ptr + offs_n[None, :] + g_idx[:, None] * stride_scales\n zeros_ptrs = zeros_ptr + (offs_n[None, :] // infearure_per_bits) + g_idx[:, None] * stride_zeros\n\n shifter = (offs_bk % infearure_per_bits) * bits\n zeros_shifter = (offs_n % infearure_per_bits) * bits\n accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_K), dtype=tl.float32)\n\n for k in range(0, num_pid_n):\n # Fetch scales and zeros; these are per-outfeature and thus reused in the inner loop\n scales = tl.load(scales_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)\n zeros = tl.load(zeros_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)\n\n zeros = (zeros >> zeros_shifter[None, :]) & maxq\n zeros = zeros + 1\n\n a = tl.load(a_ptrs, mask=a_mask, other=0.0) # (BLOCK_SIZE_M, BLOCK_SIZE_N)\n b = tl.load(b_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated\n\n # Now we need to unpack b (which is N-bit values) into 32-bit values\n b = (b >> shifter[:, None]) & maxq # Extract the N-bit values\n b = (b - zeros) * scales # Scale and shift\n b = tl.trans(b)\n\n accumulator += tl.dot(a, b)\n a_ptrs += BLOCK_SIZE_N\n b_ptrs += BLOCK_SIZE_N\n scales_ptrs += BLOCK_SIZE_N\n zeros_ptrs += BLOCK_SIZE_N // infearure_per_bits\n\n c_ptrs = c_ptr + stride_cm * offs_am[:, None] + stride_cn * offs_bk[None, :]\n c_mask = (offs_am[:, None] < M) & (offs_bk[None, :] < K)\n tl.store(c_ptrs, accumulator, mask=c_mask)\n\n\ndef quant_matmul_248(input, qweight, scales, qzeros, g_idx, bits, maxq):\n with torch.cuda.device(input.device):\n output = torch.empty((input.shape[0], qweight.shape[1]), device=input.device, dtype=input.dtype)\n grid = lambda META: ( # noqa: E731\n triton.cdiv(input.shape[0], META[\"BLOCK_SIZE_M\"]) * triton.cdiv(qweight.shape[1], META[\"BLOCK_SIZE_N\"]),\n )\n quant_matmul_248_kernel[grid](\n input,\n qweight,\n output,\n scales.to(input.dtype),\n qzeros,\n g_idx,\n input.shape[0],\n qweight.shape[1],\n input.shape[1],\n bits,\n maxq,\n input.stride(0),\n input.stride(1),\n qweight.stride(0),\n qweight.stride(1),\n output.stride(0),\n output.stride(1),\n scales.stride(0),\n qzeros.stride(0),\n )\n return output\n\n\ndef transpose_quant_matmul_248(input, qweight, scales, qzeros, g_idx, bits, maxq):\n with torch.cuda.device(input.device):\n output_dim = (qweight.shape[0] * 32) // bits\n output = torch.empty((input.shape[0], output_dim), device=input.device, dtype=input.dtype)\n grid = lambda META: ( # noqa: E731\n triton.cdiv(input.shape[0], META[\"BLOCK_SIZE_M\"]) * triton.cdiv(output_dim, META[\"BLOCK_SIZE_K\"]),\n )\n transpose_quant_matmul_248_kernel[grid](\n input,\n qweight,\n output,\n scales.to(input.dtype),\n qzeros,\n g_idx,\n input.shape[0],\n qweight.shape[1],\n output_dim,\n bits,\n maxq,\n input.stride(0),\n input.stride(1),\n qweight.stride(0),\n qweight.stride(1),\n output.stride(0),\n output.stride(1),\n scales.stride(0),\n qzeros.stride(0),\n )\n return output\n", - "description_1": "Use triton language to implement a quantized matrix multiplication kernel and its transpose variant. Both kernels perform matrix multiplications with matrices of specified dimensions and types (float16, int32) and apply quantization during computation. The implementation involves setting up data pointers, loading necessary values like scales and zeros, computing matrix products while handling bit-shifting, and storing results. The auxiliary Python functions serve as wrappers to facilitate the grid launch of these kernels.", - "description_2": "Use triton language to create a matrix multiplication kernel that includes quantization and a transposed version. The kernels handle matrix dimensions, apply scale and zero point adjustments, and manage data indexing for matrix product computation.", - "difficulty": 3 - }, - { - "code": "import math\nimport torch\nimport triton\nimport triton.language as tl\n\n@triton.heuristics(\n {\n \"EVEN_M\": lambda args: args[\"seqlen_q\"] % args[\"BLOCK_M\"] == 0,\n \"EVEN_N\": lambda args: args[\"seqlen_k\"] % args[\"BLOCK_N\"] == 0,\n \"EVEN_HEADDIM\": lambda args: args[\"headdim\"] == args[\"BLOCK_HEADDIM\"],\n }\n)\n@triton.jit\ndef _fwd_kernel(\n Q,\n K,\n V,\n Bias,\n Out,\n Lse,\n TMP,\n softmax_scale,\n stride_qb,\n stride_qh,\n stride_qm,\n stride_kb,\n stride_kh,\n stride_kn,\n stride_vb,\n stride_vh,\n stride_vn,\n stride_bb,\n stride_bh,\n stride_bm,\n stride_ob,\n stride_oh,\n stride_om,\n nheads,\n seqlen_q,\n seqlen_k,\n seqlen_q_rounded,\n headdim,\n CACHE_KEY_SEQLEN_Q,\n CACHE_KEY_SEQLEN_K,\n BIAS_TYPE: tl.constexpr,\n IS_CAUSAL: tl.constexpr,\n BLOCK_HEADDIM: tl.constexpr,\n EVEN_M: tl.constexpr,\n EVEN_N: tl.constexpr,\n EVEN_HEADDIM: tl.constexpr,\n BLOCK_M: tl.constexpr,\n BLOCK_N: tl.constexpr,\n):\n start_m = tl.program_id(0)\n off_hb = tl.program_id(1)\n off_b = off_hb // nheads\n off_h = off_hb % nheads\n offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)\n offs_n = tl.arange(0, BLOCK_N)\n offs_d = tl.arange(0, BLOCK_HEADDIM)\n q_ptrs = (\n Q + off_b * stride_qb + off_h * stride_qh + (offs_m[:, None] * stride_qm + offs_d[None, :])\n )\n k_ptrs = (\n K + off_b * stride_kb + off_h * stride_kh + (offs_n[:, None] * stride_kn + offs_d[None, :])\n )\n v_ptrs = (\n V + off_b * stride_vb + off_h * stride_vh + (offs_n[:, None] * stride_vn + offs_d[None, :])\n )\n if BIAS_TYPE == \"vector\":\n b_ptrs = Bias + off_b * stride_bb + off_h * stride_bh + offs_n\n elif BIAS_TYPE == \"matrix\":\n b_ptrs = (\n Bias\n + off_b * stride_bb\n + off_h * stride_bh\n + (offs_m[:, None] * stride_bm + offs_n[None, :])\n )\n t_ptrs = TMP + off_hb * seqlen_q_rounded + offs_m\n lse_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float(\"inf\")\n m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float(\"inf\")\n acc_o = tl.zeros([BLOCK_M, BLOCK_HEADDIM], dtype=tl.float32)\n if EVEN_M & EVEN_N:\n if EVEN_HEADDIM:\n q = tl.load(q_ptrs)\n else:\n q = tl.load(q_ptrs, mask=offs_d[None, :] < headdim, other=0.0)\n else:\n if EVEN_HEADDIM:\n q = tl.load(q_ptrs, mask=offs_m[:, None] < seqlen_q, other=0.0)\n else:\n q = tl.load(\n q_ptrs, mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim), other=0.0\n )\n end_n = seqlen_k if not IS_CAUSAL else tl.minimum((start_m + 1) * BLOCK_M, seqlen_k)\n for start_n in range(0, end_n, BLOCK_N):\n start_n = tl.multiple_of(start_n, BLOCK_N)\n if EVEN_N & EVEN_M:\n if EVEN_HEADDIM:\n k = tl.load(k_ptrs + start_n * stride_kn)\n else:\n k = tl.load(k_ptrs + start_n * stride_kn, mask=offs_d[None, :] < headdim, other=0.0)\n else:\n if EVEN_HEADDIM:\n k = tl.load(\n k_ptrs + start_n * stride_kn,\n mask=(start_n + offs_n)[:, None] < seqlen_k,\n other=0.0,\n )\n else:\n k = tl.load(\n k_ptrs + start_n * stride_kn,\n mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim),\n other=0.0,\n )\n qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)\n qk += tl.dot(q, k, trans_b=True)\n if not EVEN_N:\n qk += tl.where((start_n + offs_n)[None, :] < seqlen_k, 0, float(\"-inf\"))\n if IS_CAUSAL:\n qk += tl.where(offs_m[:, None] >= (start_n + offs_n)[None, :], 0, float(\"-inf\"))\n if BIAS_TYPE != \"none\":\n if BIAS_TYPE == \"vector\":\n if EVEN_N:\n bias = tl.load(b_ptrs + start_n).to(tl.float32)\n else:\n bias = tl.load(\n b_ptrs + start_n, mask=(start_n + offs_n) < seqlen_k, other=0.0\n ).to(tl.float32)\n bias = bias[None, :]\n elif BIAS_TYPE == \"matrix\":\n if EVEN_M & EVEN_N:\n bias = tl.load(b_ptrs + start_n).to(tl.float32)\n else:\n bias = tl.load(\n b_ptrs + start_n,\n mask=(offs_m[:, None] < seqlen_q)\n & ((start_n + offs_n)[None, :] < seqlen_k),\n other=0.0,\n ).to(tl.float32)\n qk = qk * softmax_scale + bias\n m_ij = tl.maximum(tl.max(qk, 1), lse_i)\n p = tl.exp(qk - m_ij[:, None])\n else:\n m_ij = tl.maximum(tl.max(qk, 1) * softmax_scale, lse_i)\n p = tl.exp(qk * softmax_scale - m_ij[:, None])\n l_ij = tl.sum(p, 1)\n acc_o_scale = tl.exp(m_i - m_ij)\n tl.store(t_ptrs, acc_o_scale)\n acc_o_scale = tl.load(t_ptrs)\n acc_o = acc_o * acc_o_scale[:, None]\n if EVEN_N & EVEN_M:\n if EVEN_HEADDIM:\n v = tl.load(v_ptrs + start_n * stride_vn)\n else:\n v = tl.load(v_ptrs + start_n * stride_vn, mask=offs_d[None, :] < headdim, other=0.0)\n else:\n if EVEN_HEADDIM:\n v = tl.load(\n v_ptrs + start_n * stride_vn,\n mask=(start_n + offs_n)[:, None] < seqlen_k,\n other=0.0,\n )\n else:\n v = tl.load(\n v_ptrs + start_n * stride_vn,\n mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim),\n other=0.0,\n )\n p = p.to(v.dtype)\n acc_o += tl.dot(p, v)\n m_i = m_ij\n l_i_new = tl.exp(lse_i - m_ij) + l_ij\n lse_i = m_ij + tl.log(l_i_new)\n o_scale = tl.exp(m_i - lse_i)\n tl.store(t_ptrs, o_scale)\n o_scale = tl.load(t_ptrs)\n acc_o = acc_o * o_scale[:, None]\n start_m = tl.program_id(0)\n offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)\n lse_ptrs = Lse + off_hb * seqlen_q_rounded + offs_m\n tl.store(lse_ptrs, lse_i)\n offs_d = tl.arange(0, BLOCK_HEADDIM)\n out_ptrs = (\n Out\n + off_b * stride_ob\n + off_h * stride_oh\n + (offs_m[:, None] * stride_om + offs_d[None, :])\n )\n if EVEN_M:\n if EVEN_HEADDIM:\n tl.store(out_ptrs, acc_o)\n else:\n tl.store(out_ptrs, acc_o, mask=offs_d[None, :] < headdim)\n else:\n if EVEN_HEADDIM:\n tl.store(out_ptrs, acc_o, mask=offs_m[:, None] < seqlen_q)\n else:\n tl.store(\n out_ptrs, acc_o, mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim)\n )\n\ndef _flash_attn_forward(q, k, v, bias=None, causal=False, softmax_scale=None):\n batch, seqlen_q, nheads, d = q.shape\n _, seqlen_k, _, _ = k.shape\n assert k.shape == (batch, seqlen_k, nheads, d)\n assert v.shape == (batch, seqlen_k, nheads, d)\n assert d <= 128, \"FlashAttention only support head dimensions up to 128\"\n assert q.dtype == k.dtype == v.dtype, \"All tensors must have the same type\"\n assert q.dtype in [torch.float16, torch.bfloat16], \"Only support fp16 and bf16\"\n assert q.is_cuda and k.is_cuda and v.is_cuda\n softmax_scale = softmax_scale or 1.0 / math.sqrt(d)\n has_bias = bias is not None\n bias_type = \"none\"\n if has_bias:\n assert bias.dtype in [q.dtype, torch.float]\n assert bias.is_cuda\n assert bias.dim() == 4\n if bias.stride(-1) != 1:\n bias = bias.contiguous()\n if bias.shape[2:] == (1, seqlen_k):\n bias_type = \"vector\"\n elif bias.shape[2:] == (seqlen_q, seqlen_k):\n bias_type = \"matrix\"\n else:\n raise RuntimeError(\n \"Last 2 dimensions of bias must be (1, seqlen_k)\" \" or (seqlen_q, seqlen_k)\"\n )\n bias = bias.expand(batch, nheads, seqlen_q, seqlen_k)\n bias_strides = (bias.stride(0), bias.stride(1), bias.stride(2)) if has_bias else (0, 0, 0)\n seqlen_q_rounded = math.ceil(seqlen_q / 128) * 128\n lse = torch.empty((batch, nheads, seqlen_q_rounded), device=q.device, dtype=torch.float32)\n tmp = torch.empty((batch, nheads, seqlen_q_rounded), device=q.device, dtype=torch.float32)\n o = torch.empty_like(q)\n BLOCK_HEADDIM = max(triton.next_power_of_2(d), 16)\n BLOCK = 128\n num_warps = 4 if d <= 64 else 8\n grid = lambda META: (triton.cdiv(seqlen_q, META[\"BLOCK_M\"]), batch * nheads)\n _fwd_kernel[grid](\n q,\n k,\n v,\n bias,\n o,\n lse,\n tmp,\n softmax_scale,\n q.stride(0),\n q.stride(2),\n q.stride(1),\n k.stride(0),\n k.stride(2),\n k.stride(1),\n v.stride(0),\n v.stride(2),\n v.stride(1),\n *bias_strides,\n o.stride(0),\n o.stride(2),\n o.stride(1),\n nheads,\n seqlen_q,\n seqlen_k,\n seqlen_q_rounded,\n d,\n seqlen_q // 32,\n seqlen_k // 32,\n bias_type,\n causal,\n BLOCK_HEADDIM,\n BLOCK_M=BLOCK,\n BLOCK_N=BLOCK,\n num_warps=num_warps,\n num_stages=1,\n )\n return o, lse, softmax_scale\n", - "description_1": "Use triton language to implement a forward kernel (_fwd_kernel) for the FlashAttention mechanism. This function accepts 36 parameters: Q, K, V (tensors for query, key, and value respectively), Bias (optional tensor for bias), Out (output tensor), Lse (log-sum-exp tensor), TMP (temporary buffer), softmax_scale (scalar for softmax scaling), and multiple stride values for navigating tensors in memory. Additionally, the function receives dimensions for number of heads, sequence lengths, and head dimensions as well as cache keys for sequence lengths, and several constexpr parameters for controlling behavior like bias type and causal masking. The kernel is responsible for loading slices of Q, K, and V, computing attention weights, applying softmax, and updating the output and log-sum-exp tensors.", - "description_2": "Use triton language to implement the FlashAttention forward pass function (_flash_attn_forward) which calls a Triton kernel to perform matrix multiplications and apply attention mechanism efficiently on GPU. This function takes 6 parameters: q, k, v (query, key, value tensors), bias (optional), causal (boolean for causal masking), and softmax_scale (scaling factor for softmax). It asserts constraints on input shapes, types, and device compatibility, prepares the necessary output buffers, sets the kernel execution grid configuration, and invokes the Triton kernel with appropriate meta-parameters.", - "difficulty": 4 - }, - { - "code": "import math\nimport torch\nimport triton\nimport triton.language as tl\n\n@triton.heuristics(\n {\n \"EVEN_M\": lambda args: args[\"seqlen_q\"] % args[\"BLOCK_M\"] == 0,\n \"EVEN_N\": lambda args: args[\"seqlen_k\"] % args[\"BLOCK_N\"] == 0,\n \"EVEN_HEADDIM\": lambda args: args[\"headdim\"] == args[\"BLOCK_HEADDIM\"],\n }\n)\n@triton.jit\ndef _fwd_kernel(\n Q, K, V, Bias, Out, M_in, Lse_in, O_in,\n Lse, M_out, TMP, # NOTE: TMP is a scratchpad buffer to workaround a compiler bug\n softmax_scale,\n stride_qb, stride_qh, stride_qm,\n stride_kb, stride_kh, stride_kn,\n stride_vb, stride_vh, stride_vn,\n stride_bb, stride_bh, stride_bm,\n stride_ob, stride_oh, stride_om,\n nheads, seqlen_q, seqlen_k, seqlen_q_rounded, headdim,\n CACHE_KEY_SEQLEN_Q, CACHE_KEY_SEQLEN_K,\n BIAS_TYPE: tl.constexpr,\n IS_CAUSAL: tl.constexpr,\n BLOCK_HEADDIM: tl.constexpr,\n EVEN_M: tl.constexpr, EVEN_N: tl.constexpr, EVEN_HEADDIM: tl.constexpr,\n BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,\n):\n start_m = tl.program_id(0)\n off_hb = tl.program_id(1)\n off_b = off_hb // nheads\n off_h = off_hb % nheads\n offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)\n offs_n = tl.arange(0, BLOCK_N)\n offs_d = tl.arange(0, BLOCK_HEADDIM)\n q_ptrs = Q + off_b * stride_qb + off_h * stride_qh + (offs_m[:, None] * stride_qm + offs_d[None, :])\n k_ptrs = K + off_b * stride_kb + off_h * stride_kh + (offs_n[:, None] * stride_kn + offs_d[None, :])\n v_ptrs = V + off_b * stride_vb + off_h * stride_vh + (offs_n[:, None] * stride_vn + offs_d[None, :])\n if BIAS_TYPE == 'vector':\n b_ptrs = Bias + off_b * stride_bb + off_h * stride_bh + offs_n\n elif BIAS_TYPE == 'matrix':\n b_ptrs = Bias + off_b * stride_bb + off_h * stride_bh + (offs_m[:, None] * stride_bm + offs_n[None, :])\n t_ptrs = TMP + off_hb * seqlen_q_rounded + offs_m\n lin_ptrs = Lse_in + off_hb * seqlen_q_rounded + offs_m\n acc_o_ptrs = O_in + off_b * stride_qb + off_h * stride_qh + (offs_m[:, None] * stride_qm + offs_d[None, :])\n lse_i = tl.load(lin_ptrs)\n m_ptrs = M_in + off_hb * seqlen_q_rounded + offs_m\n m_i = tl.load(m_ptrs)\n acc_o = tl.load(acc_o_ptrs)\n if EVEN_M & EVEN_N:\n if EVEN_HEADDIM:\n q = tl.load(q_ptrs)\n else:\n q = tl.load(q_ptrs, mask=offs_d[None, :] < headdim, other=0.0)\n else:\n if EVEN_HEADDIM:\n q = tl.load(q_ptrs, mask=offs_m[:, None] < seqlen_q, other=0.0)\n else:\n q = tl.load(q_ptrs, mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim),\n other=0.0)\n end_n = seqlen_k if not IS_CAUSAL else tl.minimum((start_m + 1) * BLOCK_M, seqlen_k)\n for start_n in range(0, end_n, BLOCK_N):\n start_n = tl.multiple_of(start_n, BLOCK_N)\n if EVEN_N & EVEN_M:\n if EVEN_HEADDIM:\n k = tl.load(k_ptrs + start_n * stride_kn)\n else:\n k = tl.load(k_ptrs + start_n * stride_kn, mask=offs_d[None, :] < headdim, other=0.0)\n else:\n if EVEN_HEADDIM:\n k = tl.load(k_ptrs + start_n * stride_kn, mask=(start_n + offs_n)[:, None] < seqlen_k,\n other=0.0)\n else:\n k = tl.load(k_ptrs + start_n * stride_kn,\n mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim),\n other=0.0)\n qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)\n qk += tl.dot(q, k, trans_b=True)\n if not EVEN_N:\n qk += tl.where((start_n + offs_n)[None, :] < seqlen_k, 0, float(\"-inf\"))\n if IS_CAUSAL:\n qk += tl.where(offs_m[:, None] >= (start_n + offs_n)[None, :], 0, float(\"-inf\"))\n if BIAS_TYPE != 'none':\n if BIAS_TYPE == 'vector':\n if EVEN_N:\n bias = tl.load(b_ptrs + start_n).to(tl.float32)\n else:\n bias = tl.load(b_ptrs + start_n, mask=(start_n + offs_n) < seqlen_k, other=0.0).to(tl.float32)\n bias = bias[None, :]\n elif BIAS_TYPE == 'matrix':\n if EVEN_M & EVEN_N:\n bias = tl.load(b_ptrs + start_n).to(tl.float32)\n else:\n bias = tl.load(b_ptrs + start_n,\n mask=(offs_m[:, None] < seqlen_q)\n & ((start_n + offs_n)[None, :] < seqlen_k),\n other=0.0).to(tl.float32)\n qk = qk * softmax_scale + bias\n m_ij = tl.maximum(tl.max(qk, 1), lse_i)\n p = tl.exp(qk - m_ij[:, None])\n else:\n m_ij = tl.maximum(tl.max(qk, 1) * softmax_scale, lse_i)\n p = tl.exp(qk * softmax_scale - m_ij[:, None])\n l_ij = tl.sum(p, 1)\n acc_o_scale = tl.exp(m_i - m_ij)\n tl.store(t_ptrs, acc_o_scale)\n acc_o_scale = tl.load(t_ptrs)\n acc_o = acc_o * acc_o_scale[:, None]\n if EVEN_N & EVEN_M:\n if EVEN_HEADDIM:\n v = tl.load(v_ptrs + start_n * stride_vn)\n else:\n v = tl.load(v_ptrs + start_n * stride_vn, mask=offs_d[None, :] < headdim, other=0.0)\n else:\n if EVEN_HEADDIM:\n v = tl.load(v_ptrs + start_n * stride_vn, mask=(start_n + offs_n)[:, None] < seqlen_k,\n other=0.0)\n else:\n v = tl.load(v_ptrs + start_n * stride_vn,\n mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim),\n other=0.0)\n p = p.to(v.dtype)\n acc_o += tl.dot(p, v)\n m_i = m_ij\n l_i_new = tl.exp(lse_i - m_ij) + l_ij\n lse_i = m_ij + tl.log(l_i_new)\n start_m = tl.program_id(0)\n offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)\n lse_ptrs = Lse + off_hb * seqlen_q_rounded + offs_m\n m_ptrs = M_out + off_hb * seqlen_q_rounded + offs_m\n tl.store(m_ptrs, m_i)\n tl.store(lse_ptrs, lse_i)\n offs_d = tl.arange(0, BLOCK_HEADDIM)\n out_ptrs = Out + off_b * stride_ob + off_h * stride_oh + (offs_m[:, None] * stride_om + offs_d[None, :])\n if EVEN_M:\n if EVEN_HEADDIM:\n tl.store(out_ptrs, acc_o)\n else:\n tl.store(out_ptrs, acc_o, mask=offs_d[None, :] < headdim)\n else:\n if EVEN_HEADDIM:\n tl.store(out_ptrs, acc_o, mask=offs_m[:, None] < seqlen_q)\n else:\n tl.store(out_ptrs, acc_o,\n mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim))\n\n\n@triton.jit\ndef _bwd_kernel_one_col_block(\n start_n,\n Q, K, V, Bias,\n DO, DQ, DK, DV,\n LSE, D,\n softmax_scale,\n stride_qm, stride_kn, stride_vn, stride_bm,\n stride_dom, stride_dqm, stride_dkn, stride_dvn,\n seqlen_q, seqlen_k, headdim,\n ATOMIC_ADD: tl.constexpr,\n BIAS_TYPE: tl.constexpr,\n IS_CAUSAL: tl.constexpr,\n BLOCK_HEADDIM: tl.constexpr,\n EVEN_M: tl.constexpr, EVEN_N: tl.constexpr, EVEN_HEADDIM: tl.constexpr,\n BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,\n):\n begin_m = 0 if not IS_CAUSAL else ((start_n * BLOCK_N) // BLOCK_M) * BLOCK_M\n offs_qm = begin_m + tl.arange(0, BLOCK_M)\n offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N)\n offs_m = tl.arange(0, BLOCK_M)\n offs_d = tl.arange(0, BLOCK_HEADDIM)\n q_ptrs = Q + (offs_qm[:, None] * stride_qm + offs_d[None, :])\n k_ptrs = K + (offs_n[:, None] * stride_kn + offs_d[None, :])\n v_ptrs = V + (offs_n[:, None] * stride_vn + offs_d[None, :])\n do_ptrs = DO + (offs_qm[:, None] * stride_dom + offs_d[None, :])\n dq_ptrs = DQ + (offs_qm[:, None] * stride_dqm + offs_d[None, :])\n if BIAS_TYPE == 'vector':\n b_ptrs = Bias + offs_n\n elif BIAS_TYPE == 'matrix':\n b_ptrs = Bias + (offs_qm[:, None] * stride_bm + offs_n[None, :])\n dv = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32)\n dk = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32)\n if begin_m >= seqlen_q:\n dv_ptrs = DV + (offs_n[:, None] * stride_dvn + offs_d[None, :])\n dk_ptrs = DK + (offs_n[:, None] * stride_dkn + offs_d[None, :])\n _bwd_store_dk_dv(dk_ptrs, dv_ptrs, dk, dv, offs_n, offs_d, seqlen_k, headdim,\n EVEN_M=EVEN_M, EVEN_N=EVEN_N, EVEN_HEADDIM=EVEN_HEADDIM)\n return\n if EVEN_N & EVEN_M:\n if EVEN_HEADDIM:\n k = tl.load(k_ptrs)\n v = tl.load(v_ptrs)\n else:\n k = tl.load(k_ptrs, mask=offs_d[None, :] < headdim, other=0.0)\n v = tl.load(v_ptrs, mask=offs_d[None, :] < headdim, other=0.0)\n else:\n if EVEN_HEADDIM:\n k = tl.load(k_ptrs, mask=offs_n[:, None] < seqlen_k, other=0.0)\n v = tl.load(v_ptrs, mask=offs_n[:, None] < seqlen_k, other=0.0)\n else:\n k = tl.load(k_ptrs, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim),\n other=0.0)\n v = tl.load(v_ptrs, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim),\n other=0.0)\n num_block_m = tl.cdiv(seqlen_q, BLOCK_M)\n for start_m in range(begin_m, num_block_m * BLOCK_M, BLOCK_M):\n start_m = tl.multiple_of(start_m, BLOCK_M)\n offs_m_curr = start_m + offs_m\n if EVEN_M & EVEN_HEADDIM:\n q = tl.load(q_ptrs)\n else:\n if EVEN_HEADDIM:\n q = tl.load(q_ptrs, mask=offs_m_curr[:, None] < seqlen_q, other=0.0)\n else:\n q = tl.load(q_ptrs, mask=(offs_m_curr[:, None] < seqlen_q)\n & (offs_d[None, :] < headdim), other=0.0)\n qk = tl.dot(q, k, trans_b=True)\n if not EVEN_N:\n qk = tl.where(offs_n[None, :] < seqlen_k, qk, float(\"-inf\"))\n if IS_CAUSAL:\n qk = tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), qk, float(\"-inf\"))\n if BIAS_TYPE != 'none':\n tl.debug_barrier()\n if BIAS_TYPE == 'vector':\n if EVEN_N:\n bias = tl.load(b_ptrs).to(tl.float32)\n else:\n bias = tl.load(b_ptrs, mask=offs_n < seqlen_k, other=0.0).to(tl.float32)\n bias = bias[None, :]\n elif BIAS_TYPE == 'matrix':\n if EVEN_M & EVEN_N:\n bias = tl.load(b_ptrs).to(tl.float32)\n else:\n bias = tl.load(b_ptrs,\n mask=(offs_m_curr[:, None] < seqlen_q)\n & (offs_n[None, :] < seqlen_k),\n other=0.0).to(tl.float32)\n qk = qk * softmax_scale + bias\n if not (EVEN_M & EVEN_HEADDIM):\n tl.debug_barrier()\n lse_i = tl.load(LSE + offs_m_curr)\n if BIAS_TYPE == 'none':\n p = tl.exp(qk * softmax_scale - lse_i[:, None])\n else:\n p = tl.exp(qk - lse_i[:, None])\n if EVEN_M & EVEN_HEADDIM:\n do = tl.load(do_ptrs)\n else:\n do = tl.load(do_ptrs, mask=(offs_m_curr[:, None] < seqlen_q)\n & (offs_d[None, :] < headdim), other=0.0)\n dv += tl.dot(p.to(do.dtype), do, trans_a=True)\n if not (EVEN_M & EVEN_HEADDIM):\n tl.debug_barrier()\n dp = tl.dot(do, v, trans_b=True)\n if not EVEN_HEADDIM:\n tl.debug_barrier()\n Di = tl.load(D + offs_m_curr)\n ds = (p * (dp - Di[:, None]) * softmax_scale).to(q.dtype)\n dk += tl.dot(ds, q, trans_a=True)\n if not (EVEN_M & EVEN_HEADDIM):\n tl.debug_barrier()\n if not ATOMIC_ADD:\n if EVEN_M & EVEN_HEADDIM:\n dq = tl.load(dq_ptrs, eviction_policy=\"evict_last\")\n dq += tl.dot(ds, k)\n tl.store(dq_ptrs, dq, eviction_policy=\"evict_last\")\n else:\n if EVEN_HEADDIM:\n dq = tl.load(dq_ptrs, mask=offs_m_curr[:, None] < seqlen_q, other=0.0,\n eviction_policy=\"evict_last\")\n dq += tl.dot(ds, k)\n tl.store(dq_ptrs, dq, mask=offs_m_curr[:, None] < seqlen_q,\n eviction_policy=\"evict_last\")\n else:\n dq = tl.load(dq_ptrs,\n mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim),\n other=0.0, eviction_policy=\"evict_last\")\n dq += tl.dot(ds, k)\n tl.store(dq_ptrs, dq,\n mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim),\n eviction_policy=\"evict_last\")\n else:\n dq = tl.dot(ds, k)\n if EVEN_M & EVEN_HEADDIM:\n tl.atomic_add(dq_ptrs, dq)\n else:\n if EVEN_HEADDIM:\n tl.atomic_add(dq_ptrs, dq, mask=offs_m_curr[:, None] < seqlen_q)\n else:\n tl.atomic_add(dq_ptrs, dq,\n mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim))\n dq_ptrs += BLOCK_M * stride_dqm\n q_ptrs += BLOCK_M * stride_qm\n do_ptrs += BLOCK_M * stride_dom\n if BIAS_TYPE == 'matrix':\n b_ptrs += BLOCK_M * stride_bm\n dv_ptrs = DV + (offs_n[:, None] * stride_dvn + offs_d[None, :])\n dk_ptrs = DK + (offs_n[:, None] * stride_dkn + offs_d[None, :])\n _bwd_store_dk_dv(dk_ptrs, dv_ptrs, dk, dv, offs_n, offs_d, seqlen_k, headdim,\n EVEN_M=EVEN_M, EVEN_N=EVEN_N, EVEN_HEADDIM=EVEN_HEADDIM)\n\n\n@triton.autotune(\n configs=[\n triton.Config({\"BLOCK_M\": 128, \"BLOCK_N\": 128, \"SEQUENCE_PARALLEL\": False}, num_warps=8, num_stages=1, pre_hook=init_to_zero('DQ')),\n triton.Config({\"BLOCK_M\": 128, \"BLOCK_N\": 128, \"SEQUENCE_PARALLEL\": True}, num_warps=8, num_stages=1, pre_hook=init_to_zero('DQ')),\n ],\n key=['CACHE_KEY_SEQLEN_Q', 'CACHE_KEY_SEQLEN_K', 'BIAS_TYPE', 'IS_CAUSAL', 'BLOCK_HEADDIM'],\n)\n@triton.heuristics(\n {\n \"EVEN_M\": lambda args: args[\"seqlen_q\"] % args[\"BLOCK_M\"] == 0,\n \"EVEN_N\": lambda args: args[\"seqlen_k\"] % args[\"BLOCK_N\"] == 0,\n \"EVEN_HEADDIM\": lambda args: args[\"headdim\"] == args[\"BLOCK_HEADDIM\"],\n }\n)\n@triton.jit\ndef _bwd_kernel(\n Q, K, V, Bias,\n DO, DQ, DK, DV,\n LSE, D,\n softmax_scale,\n stride_qb, stride_qh, stride_qm,\n stride_kb, stride_kh, stride_kn,\n stride_vb, stride_vh, stride_vn,\n stride_bb, stride_bh, stride_bm,\n stride_dob, stride_doh, stride_dom,\n stride_dqb, stride_dqh, stride_dqm,\n stride_dkb, stride_dkh, stride_dkn,\n stride_dvb, stride_dvh, stride_dvn,\n nheads, seqlen_q, seqlen_k, seqlen_q_rounded, headdim,\n CACHE_KEY_SEQLEN_Q, CACHE_KEY_SEQLEN_K,\n BIAS_TYPE: tl.constexpr,\n IS_CAUSAL: tl.constexpr,\n BLOCK_HEADDIM: tl.constexpr,\n SEQUENCE_PARALLEL: tl.constexpr,\n EVEN_M: tl.constexpr, EVEN_N: tl.constexpr, EVEN_HEADDIM: tl.constexpr,\n BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,\n):\n off_hb = tl.program_id(1)\n off_b = off_hb // nheads\n off_h = off_hb % nheads\n Q += off_b * stride_qb + off_h * stride_qh\n K += off_b * stride_kb + off_h * stride_kh\n V += off_b * stride_vb + off_h * stride_vh\n DO += off_b * stride_dob + off_h * stride_doh\n DQ += off_b * stride_dqb + off_h * stride_dqh\n DK += off_b * stride_dkb + off_h * stride_dkh\n DV += off_b * stride_dvb + off_h * stride_dvh\n if BIAS_TYPE != 'none':\n Bias += off_b * stride_bb + off_h * stride_bh\n D += off_hb * seqlen_q_rounded\n LSE += off_hb * seqlen_q_rounded\n if not SEQUENCE_PARALLEL:\n num_block_n = tl.cdiv(seqlen_k, BLOCK_N)\n for start_n in range(0, num_block_n):\n _bwd_kernel_one_col_block(\n start_n,\n Q, K, V, Bias,\n DO, DQ, DK, DV,\n LSE, D,\n softmax_scale,\n stride_qm, stride_kn, stride_vn, stride_bm,\n stride_dom, stride_dqm, stride_dkn, stride_dvn,\n seqlen_q, seqlen_k, headdim,\n ATOMIC_ADD=False,\n BIAS_TYPE=BIAS_TYPE,\n IS_CAUSAL=IS_CAUSAL,\n BLOCK_HEADDIM=BLOCK_HEADDIM,\n EVEN_M=EVEN_M, EVEN_N=EVEN_N, EVEN_HEADDIM=EVEN_HEADDIM,\n BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N\n )\n else:\n start_n = tl.program_id(0)\n _bwd_kernel_one_col_block(\n start_n,\n Q, K, V, Bias,\n DO, DQ, DK, DV,\n LSE, D,\n softmax_scale,\n stride_qm, stride_kn, stride_vn, stride_bm,\n stride_dom, stride_dqm, stride_dkn, stride_dvn,\n seqlen_q, seqlen_k, headdim,\n ATOMIC_ADD=True,\n BIAS_TYPE=BIAS_TYPE,\n IS_CAUSAL=IS_CAUSAL,\n BLOCK_HEADDIM=BLOCK_HEADDIM,\n EVEN_M=EVEN_M, EVEN_N=EVEN_N, EVEN_HEADDIM=EVEN_HEADDIM,\n BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N\n )\n\n\ndef _flash_attn_forward(q, k, v, prev_m, prev_lse, prev_o, bias=None, causal=False, softmax_scale=None):\n batch, seqlen_q, nheads, d = q.shape\n _, seqlen_k, _, _ = k.shape\n assert k.shape == (batch, seqlen_k, nheads, d)\n assert v.shape == (batch, seqlen_k, nheads, d)\n assert prev_m.shape == (batch, nheads, seqlen_k)\n assert d <= 128, 'FlashAttention only support head dimensions up to 128'\n assert q.dtype == k.dtype == v.dtype, 'All tensors must have the same type'\n assert q.dtype in [torch.float16, torch.bfloat16], 'Only support fp16 and bf16'\n assert q.is_cuda and k.is_cuda and v.is_cuda\n softmax_scale = softmax_scale or 1.0 / math.sqrt(d)\n has_bias = bias is not None\n bias_type = 'none'\n if has_bias:\n assert bias.dtype in [q.dtype, torch.float]\n assert bias.is_cuda\n assert bias.dim() == 4\n bias = bias.transpose(0,1).contiguous()\n if bias.stride(-1) != 1:\n bias = bias.contiguous()\n if bias.shape[2:] == (1, seqlen_k):\n bias_type = 'vector'\n elif bias.shape[2:] == (seqlen_q, seqlen_k):\n bias_type = 'matrix'\n else:\n raise RuntimeError('Last 2 dimensions of bias must be (1, seqlen_k)'\n ' or (seqlen_q, seqlen_k)')\n bias = bias.expand(batch, nheads, seqlen_q, seqlen_k)\n bias_strides = (bias.stride(0), bias.stride(1), bias.stride(2)) if has_bias else (0, 0, 0)\n seqlen_q_rounded = math.ceil(seqlen_q / 128) * 128\n lse = torch.empty((batch, nheads, seqlen_q_rounded), device=q.device, dtype=torch.float32)\n m = torch.empty((batch, nheads, seqlen_q_rounded), device=q.device, dtype=torch.float32)\n tmp = torch.empty((batch, nheads, seqlen_q_rounded), device=q.device, dtype=torch.float32)\n o = torch.empty_like(q)\n BLOCK_HEADDIM = max(triton.next_power_of_2(d), 16)\n BLOCK = 128\n num_warps = 4 if d <= 64 else 8\n grid = lambda META: (triton.cdiv(seqlen_q, META[\"BLOCK_M\"]), batch * nheads)\n _fwd_kernel[grid](\n q, k, v, bias, o, prev_m, prev_lse, prev_o,\n lse, m, tmp,\n softmax_scale,\n q.stride(0), q.stride(2), q.stride(1),\n k.stride(0), k.stride(2), k.stride(1),\n v.stride(0), v.stride(2), v.stride(1),\n *bias_strides,\n o.stride(0), o.stride(2), o.stride(1),\n nheads, seqlen_q, seqlen_k, seqlen_q_rounded, d,\n seqlen_q // 32, seqlen_k // 32, \n bias_type, causal, BLOCK_HEADDIM,\n BLOCK_M=BLOCK, BLOCK_N=BLOCK,\n num_warps=num_warps,\n num_stages=1,\n )\n return o, lse, m, softmax_scale\n\n\ndef _flash_attn_backward(do, q, k, v, delta, lse, dq, dk, dv, bias=None, causal=False, softmax_scale=None):\n if do.stride(-1) != 1:\n do = do.contiguous()\n batch, seqlen_q, nheads, d = q.shape\n assert do.shape == (batch, seqlen_q, nheads, d) , f'do shape is {do.shape} and q shape is {q.shape}'\n assert k.shape == (batch, seqlen_q, nheads, d), f'k shape is {k.shape} and q shape is {q.shape}'\n assert v.shape == (batch, seqlen_q, nheads, d), f'v shape is {v.shape} and q shape is {q.shape}'\n _, seqlen_k, _, _ = k.shape\n assert d <= 128\n seqlen_q_rounded = math.ceil(seqlen_q / 128) * 128\n assert lse.shape == (batch, nheads, seqlen_q_rounded), f\"lse shape is {lse.shape}\"\n assert delta.shape == (batch, nheads, seqlen_q_rounded), f\"delta shape is {delta.shape}\"\n assert q.stride(-1) == k.stride(-1) == v.stride(-1) == 1\n assert dq.stride(-1) == dk.stride(-1) == dv.stride(-1) == 1\n softmax_scale = softmax_scale or 1.0 / math.sqrt(d)\n dq_accum = torch.empty_like(q, dtype=torch.float32)\n BLOCK_HEADDIM = max(triton.next_power_of_2(d), 16)\n grid = lambda META: (triton.cdiv(seqlen_q, META[\"BLOCK_M\"]), batch * nheads)\n has_bias = bias is not None\n bias_type = 'none'\n if has_bias:\n bias = bias.transpose(0,1).contiguous()\n assert bias.dtype in [q.dtype, torch.float]\n assert bias.is_cuda\n assert bias.dim() == 4\n assert bias.stride(-1) == 1\n if bias.shape[2:] == (1, seqlen_k):\n bias_type = 'vector'\n elif bias.shape[2:] == (seqlen_q, seqlen_k):\n bias_type = 'matrix'\n else:\n raise RuntimeError('Last 2 dimensions of bias must be (1, seqlen_k)'\n ' or (seqlen_q, seqlen_k)')\n bias = bias.expand(batch, nheads, seqlen_q, seqlen_k)\n bias_strides = (bias.stride(0), bias.stride(1), bias.stride(2)) if has_bias else (0, 0, 0)\n grid = lambda META: (triton.cdiv(seqlen_k, META[\"BLOCK_N\"]) if META[\"SEQUENCE_PARALLEL\"] else 1,\n batch * nheads)\n _bwd_kernel[grid](\n q, k, v, bias,\n do, dq_accum, dk, dv,\n lse, delta,\n softmax_scale,\n q.stride(0), q.stride(2), q.stride(1),\n k.stride(0), k.stride(2), k.stride(1),\n v.stride(0), v.stride(2), v.stride(1),\n *bias_strides,\n do.stride(0), do.stride(2), do.stride(1),\n dq_accum.stride(0), dq_accum.stride(2), dq_accum.stride(1),\n dk.stride(0), dk.stride(2), dk.stride(1),\n dv.stride(0), dv.stride(2), dv.stride(1),\n nheads, seqlen_q, seqlen_k, seqlen_q_rounded, d,\n seqlen_q // 32, seqlen_k // 32,\n bias_type, causal, BLOCK_HEADDIM,\n )\n dq.copy_(dq_accum)\n", - "description_1": "Use triton language to implement a FlashAttention mechanism with forward and backward kernels, including support for causal and non-causal attention, self-attention, cross-attention, and optional attention bias. The forward kernel (_fwd_kernel) computes the output and intermediate states (like the LSE and m) for given Q, K, V matrices and biases. The backward kernel (_bwd_kernel) calculates the gradients for Q, K, V, incorporating optional attention bias and handling parallel execution across sequence dimensions if needed.", - "description_2": "Use triton language to implement a FlashAttention mechanism, with forward and backward kernel functions. The forward kernel (_fwd_kernel) processes Q, K, V matrices and outputs with optional bias, while the backward kernel (_bwd_kernel) computes gradients for these matrices considering parallel computation.", - "difficulty": 4 - }, - { - "code": "import math\nimport torch\nimport triton_pre_mlir as triton\nimport triton_pre_mlir.language as tl\n\n@triton.heuristics(\n {\n \"EVEN_M\": lambda args: args[\"seqlen_q\"] % args[\"BLOCK_M\"] == 0,\n \"EVEN_N\": lambda args: args[\"seqlen_k\"] % args[\"BLOCK_N\"] == 0,\n \"EVEN_HEADDIM\": lambda args: args[\"headdim\"] == args[\"BLOCK_HEADDIM\"],\n }\n)\n@triton.jit\ndef _fwd_kernel(\n Q, K, V, Bias, Out, Lse, TMP, softmax_scale, stride_qb, stride_qh, stride_qm,\n stride_kb, stride_kh, stride_kn, stride_vb, stride_vh, stride_vn, stride_bb,\n stride_bh, stride_bm, stride_ob, stride_oh, stride_om, nheads, seqlen_q, seqlen_k,\n seqlen_q_rounded, headdim, CACHE_KEY_SEQLEN_Q, CACHE_KEY_SEQLEN_K, BIAS_TYPE: tl.constexpr,\n IS_CAUSAL: tl.constexpr, BLOCK_HEADDIM: tl.constexpr, EVEN_M: tl.constexpr,\n EVEN_N: tl.constexpr, EVEN_HEADDIM: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,\n):\n # Triton kernel implementation for forward pass of FlashAttention\n\n@triton.jit\ndef _bwd_preprocess_do_o_dot(\n Out, DO, Delta, stride_ob, stride_oh, stride_om, stride_dob, stride_doh, stride_dom,\n nheads, seqlen_q, seqlen_q_rounded, headdim, BLOCK_M: tl.constexpr, BLOCK_HEADDIM: tl.constexpr,\n):\n # Triton kernel for preprocessing in backward pass\n\n@triton.jit\ndef _bwd_store_dk_dv(\n dk_ptrs, dv_ptrs, dk, dv, offs_n, offs_d, seqlen_k, headdim, EVEN_M: tl.constexpr,\n EVEN_N: tl.constexpr, EVEN_HEADDIM: tl.constexpr,\n):\n # Triton kernel for storing gradients of K and V\n\n@triton.jit\ndef _bwd_kernel_one_col_block(\n start_n, Q, K, V, Bias, DO, DQ, DK, DV, LSE, D, softmax_scale, stride_qm, stride_kn,\n stride_vn, stride_bm, stride_dom, stride_dqm, stride_dkn, stride_dvn, seqlen_q,\n seqlen_k, headdim, ATOMIC_ADD: tl.constexpr, BIAS_TYPE: tl.constexpr, IS_CAUSAL: tl.constexpr,\n BLOCK_HEADDIM: tl.constexpr, EVEN_M: tl.constexpr, EVEN_N: tl.constexpr,\n EVEN_HEADDIM: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,\n):\n # Triton kernel for backward pass processing of one column block\n\n@triton.autotune(\n configs=[\n triton.Config(\n {\"BLOCK_M\": 128, \"BLOCK_N\": 128, \"SEQUENCE_PARALLEL\": False},\n num_warps=8,\n num_stages=1,\n pre_hook=init_to_zero(\"DQ\"),\n ),\n triton.Config(\n {\"BLOCK_M\": 128, \"BLOCK_N\": 128, \"SEQUENCE_PARALLEL\": True},\n num_warps=8,\n num_stages=1,\n pre_hook=init_to_zero(\"DQ\"),\n ),\n ],\n key=[\n \"CACHE_KEY_SEQLEN_Q\",\n \"CACHE_KEY_SEQLEN_K\",\n \"BIAS_TYPE\",\n \"IS_CAUSAL\",\n \"BLOCK_HEADDIM\",\n ],\n)\n@triton.heuristics(\n {\n \"EVEN_M\": lambda args: args[\"seqlen_q\"] % args[\"BLOCK_M\"] == 0,\n \"EVEN_N\": lambda args: args[\"seqlen_k\"] % args[\"BLOCK_N\"] == 0,\n \"EVEN_HEADDIM\": lambda args: args[\"headdim\"] == args[\"BLOCK_HEADDIM\"],\n }\n)\n@triton.jit\ndef _bwd_kernel(\n Q, K, V, Bias, DO, DQ, DK, DV, LSE, D, softmax_scale, stride_qb, stride_qh, stride_qm,\n stride_kb, stride_kh, stride_kn, stride_vb, stride_vh, stride_vn, stride_bb,\n stride_bh, stride_bm, stride_dob, stride_doh, stride_dom, stride_dqb, stride_dqh,\n stride_dqm, stride_dkb, stride_dkh, stride_dkn, stride_dvb, stride_dvh, stride_dvn,\n nheads, seqlen_q, seqlen_k, seqlen_q_rounded, headdim, CACHE_KEY_SEQLEN_Q,\n CACHE_KEY_SEQLEN_K, BIAS_TYPE: tl.constexpr, IS_CAUSAL: tl.constexpr,\n BLOCK_HEADDIM: tl.constexpr, SEQUENCE_PARALLEL: tl.constexpr, EVEN_M: tl.constexpr,\n EVEN_N: tl.constexpr, EVEN_HEADDIM: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,\n):\n # Triton kernel implementation for backward pass of FlashAttention\n\ndef _flash_attn_forward(q, k, v, bias=None, causal=False, softmax_scale=None):\n # Function to call the forward Triton kernel\n\ndef _flash_attn_backward(\n do, q, k, v, o, lse, dq, dk, dv, bias=None, causal=False, softmax_scale=None\n):\n # Function to call the backward Triton kernel\n", - "description_1": "Use triton language to implement a FlashAttention mechanism with both forward and backward passes. The forward pass computes the attention output given query, key, value, and optional bias tensors, supporting both causal and non-causal attention. The backward pass computes gradients for query, key, and value tensors. The implementation supports head dimensions up to 128 and uses triton's kernel functions for efficient computation.", - "description_2": "Use triton language to create a FlashAttention mechanism with forward and backward passes, supporting causal and non-causal attention, and head dimensions up to 128.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\nimport triton._C.libtriton as libtriton\nfrom deepspeed.accelerator import get_accelerator\n\n@triton.jit\ndef _kernel(A, B, C, stride_za, stride_ha, stride_ma, stride_ka, stride_zb, stride_hb, stride_kb, stride_nb, stride_zc,\n stride_hc, stride_mc, stride_nc, DS0, DS1, SDD_K, SDD_off_width, lut, locks, nlocks, **meta):\n # Triton kernel for sparse matrix multiplication\n TM = meta['TM']\n TN = meta['TN']\n TK = meta['TK']\n TZ = meta['TZ']\n BLOCK = meta['BLOCK']\n pid0 = tl.program_id(0)\n pid1 = tl.program_id(1)\n pidz = tl.program_id(2)\n if meta['SDD']:\n pid1 = pid1 + SDD_off_width\n blockidm = tl.arange(0, TM) // BLOCK\n blockidn = tl.arange(0, TN) // BLOCK\n offlutm = blockidm * (TN // BLOCK) * 4\n offlutn = blockidn * 4\n header = lut + pid1 * (TM // BLOCK) * (TN // BLOCK) * 4\n z = tl.load(header + 0)\n i = tl.load(header + 1 + offlutm)\n j = tl.load(header + 2 + offlutn)\n AS1 = SDD_K // TZ\n lockid = tl.where(TZ > 1, 1, 0)\n offka = pid0 * AS1\n offkb = pid0 * AS1\n offmc = 0\n offnc = 0\n offpa = 0\n offpb = 0\n maxid = TZ\n offhc = 0\n offha = z\n offhb = z\n ram = i * BLOCK + (tl.arange(0, TM) % BLOCK)\n rbn = j * BLOCK + (tl.arange(0, TN) % BLOCK)\n else:\n header = lut + pid0 * 6\n offset = tl.load(header + 0)\n AS1 = tl.load(header + 1)\n column = tl.load(header + 2)\n depth = tl.load(header + 3)\n lockid = tl.load(header + 4)\n maxid = tl.load(header + 5)\n pinc = lut + offset\n offhc = depth\n if meta['DSD']:\n offnc = pid1 * TN\n offmc = column * TM\n offpc = 0\n offnb = pid1 * TN\n offkb = tl.load(pinc)\n offkb = tl.multiple_of(offkb, 8)\n offpb = 0\n offma = 0\n offka = 0\n offpa = tl.load(pinc + 1)\n offpa = tl.multiple_of(offpa, 8)\n offpa = offpa * BLOCK * BLOCK\n offha = 0\n offhb = depth\n else:\n offmc = pid1 * TM\n offnc = column * TN\n offpc = 0\n offma = pid1 * TM\n offka = tl.load(pinc)\n offka = tl.multiple_of(offka, 8)\n offpa = 0\n offnb = 0\n offkb = 0\n offpb = tl.load(pinc + 1)\n offpb = tl.multiple_of(offpb, 8)\n offpb = offpb * BLOCK * BLOCK\n offha = depth\n offhb = 0\n ram = offma + tl.arange(0, TM)\n rbn = offnb + tl.arange(0, TN)\n rka = offka + tl.arange(0, TK)\n rkb = offkb + tl.arange(0, TK)\n pa = A + pidz * stride_za + offha * stride_ha + offpa + ram[:, None] * stride_ma + rka[None, :] * stride_ka\n pb = B + pidz * stride_zb + offhb * stride_hb + offpb + rbn[None, :] * stride_nb + rkb[:, None] * stride_kb\n if meta['DDS']:\n checkam = ram[:, None] < DS0\n else:\n checkam = AS1 > 0\n if meta['DSD']:\n checkbn = rbn[None, :] < DS0\n else:\n checkbn = AS1 > 0\n a = tl.load(pa, mask=checkam, other=0.)\n b = tl.load(pb, mask=checkbn, other=0.)\n acc = tl.zeros((TM, TN), dtype=tl.float32)\n for k in range(AS1, 0, -TK):\n acc += tl.dot(a, b)\n if meta['SDD']:\n inc_a = TK * stride_ka\n inc_b = TK * stride_kb\n else:\n pinc += 2\n if meta['DSD']:\n inc_b = tl.load(pinc)\n inc_a = tl.load(pinc + 1)\n inc_b = tl.multiple_of(inc_b, 8)\n inc_a = tl.multiple_of(inc_a, 8)\n inc_b = inc_b * stride_kb\n if meta['DDS']:\n inc_a = tl.load(pinc)\n inc_b = tl.load(pinc + 1)\n inc_a = tl.multiple_of(inc_a, 8)\n inc_b = tl.multiple_of(inc_b, 8)\n inc_a = inc_a * stride_ka\n pa += inc_a\n pb += inc_b\n checkak = k > TK\n checkbk = k > TK\n checka = checkam & checkak\n checkb = checkbn & checkbk\n a = tl.load(pa, mask=checka)\n b = tl.load(pb, mask=checkb)\n c = acc.to(C.dtype.element_ty)\n if meta['SDD']:\n checkc = True\n rr_blockidm = tl.arange(0, TM) // BLOCK\n rr_blockidn = tl.arange(0, TN) // BLOCK\n rr_offlutm = rr_blockidm * (TN // BLOCK) * 4\n rr_offlutn = rr_blockidn * 4\n off_bkid = 3 + rr_offlutm[:, None] + rr_offlutn[None, :]\n bkid = tl.load(header + off_bkid)\n offpc = bkid * BLOCK * BLOCK\n rcm = tl.arange(0, TM) % BLOCK\n rcn = tl.arange(0, TN) % BLOCK\n else:\n rcm = offmc + tl.arange(0, TM)\n rcn = offnc + tl.arange(0, TN)\n if meta['DSD']:\n checkc = rcn[None, :] < DS0\n if meta['DDS']:\n checkc = rcm[:, None] < DS0\n pc = C + offpc + offhc * stride_hc + pidz * stride_zc + rcm[:, None] * stride_mc + rcn[None, :] * stride_nc\n if lockid == 0:\n tl.store(pc, c, mask=checkc)\n else:\n plock = locks + tl.program_id(2) * nlocks * tl.num_programs(1) + tl.program_id(1) * nlocks + lockid - 1\n pcount = plock + tl.num_programs(2) * tl.num_programs(1) * nlocks\n while tl.atomic_cas(plock, 0, 1) == 1:\n pass\n count = tl.load(pcount)\n if count == 0:\n tl.store(pc, c, mask=checkc)\n else:\n d = tl.load(pc, mask=checkc)\n tl.store(pc, d + c, mask=checkc)\n tl.atomic_xchg(pcount, (count + 1) % maxid)\n tl.atomic_xchg(plock, 0)\n\n\nclass _sparse_matmul(torch.autograd.Function):\n\n sdd_cache = dict()\n dsd_cache = dict()\n dds_cache = dict()\n locks = dict()\n\n @staticmethod\n def _sdd_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, luts, num_locks, widths, packs, bench, time):\n if trans_c:\n a, b = b, a\n trans_a, trans_b = not trans_b, not trans_a\n AS0 = a.size(0)\n a_dim = -2 if trans_a else -1\n b_dim = -1 if trans_b else -2\n a_inner, b_inner = a.shape[a_dim], b.shape[b_dim]\n if a_inner != b_inner:\n raise ValueError(f\"Size of tensor A along the {a_dim} dim ({a_inner}) must match size \"\n f\"of tensor B along the {b_dim} dim ({b_inner})\")\n if a_inner % 16 != 0:\n raise ValueError('Reduction size for SDD must be a multiple of 16')\n\n batch_size = a.size(0)\n a_outer = a.size(3 if trans_a else 2)\n dtype = a.dtype\n is_16_multiple = a_inner % 16 == 0\n is_32_multiple = a_inner % 32 == 0\n is_64_multiple = a_inner % 64 == 0\n if not is_16_multiple:\n raise ValueError('Reduction size for SDD must be a multiple of 16')\n device = a.device\n total_width = sum([width * pack * pack for width, pack in zip(widths, packs)])\n c = torch.empty((batch_size, total_width, block, block), dtype=dtype, device=a.device)\n for lut, width, pack in zip(luts, widths, packs):\n F32TK = [8, 16]\n F16TK = [16]\n F16TK += [32] if is_32_multiple else []\n F16TK += [64] if is_64_multiple else []\n TK = {torch.float32: F32TK, torch.float16: F16TK}[dtype]\n num_lock = 1\n meta = {\n 'TM': block * pack,\n 'TN': block * pack,\n 'BLOCK': block,\n 'TK': TK[0],\n 'TZ': 1,\n 'SDD': True,\n 'DSD': False,\n 'DDS': False\n }\n locks = _sparse_matmul.get_locks(2 * width * AS0 * num_lock, a.device)\n max_width = 49152\n total = 0 if bench else None\n for off_width in range(0, width, max_width):\n grid = lambda meta: [meta['TZ'], min(max_width, width - off_width), batch_size]\n _kernel[grid](a,\n b,\n c,\n a.stride(0),\n a.stride(1),\n a.stride(3 if trans_a else 2),\n a.stride(2 if trans_a else 3),\n b.stride(0),\n b.stride(1),\n b.stride(3 if trans_b else 2),\n b.stride(2 if trans_b else 3),\n c.stride(0),\n c.stride(0),\n c.stride(2),\n c.stride(3),\n a_outer,\n a_outer,\n a_inner,\n off_width,\n lut,\n locks,\n num_lock,\n num_warps=4,\n **meta)\n return c\n\n\nclass MatMul:\n def __call__(self, a, b):\n c_lut, c_num_locks, c_width, c_packs,\\\n da_lut, da_num_locks, da_width, da_packs,\\\n db_lut, db_num_locks, db_width, db_packs = self.make_lut(a.dtype, a.device)\n time_c = [None]\n time_da = [None]\n time_db = [None]\n\n original_dims = max(a.ndim, b.ndim)\n a, b = self._validate_inputs(a, b)\n\n a = MatMul._pad_shape(a, self.mode == 'dsd')\n b = MatMul._pad_shape(b, self.mode == 'dds')\n\n c = _sparse_matmul.apply(a, b, self.trans_a, self.trans_b, False, self.mode, self.spdims, self.block, c_lut,\n c_num_locks, c_width, c_packs, self.bench, time_c, da_lut, da_num_locks, da_width,\n da_packs, self.bench, time_da, db_lut, db_num_locks, db_width, db_packs, self.bench,\n time_db)\n\n dims_to_trim = c.ndim - original_dims\n for _ in range(dims_to_trim):\n c = c.squeeze(0)\n\n self.time_c = time_c[0]\n self.time_da = time_da[0]\n self.time_db = time_db[0]\n return c\n", - "description_1": "Use triton language to implement a sparse matrix multiplication kernel for block-sparse matrices. The kernel should handle different modes of sparsity (sparse = dense x dense, dense = sparse x dense, dense = dense x sparse) and efficiently perform matrix multiplications using a look-up table (LUT) for sparsity patterns. The function accepts parameters related to matrix dimensions, strides, LUTs, locks, and meta information for the computation.", - "description_2": "Use triton language to define a sparse matrix multiplication kernel with parameters for matrix sizes, strides, and block configurations, utilizing LUT for efficient computation.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef _forward(X, scale, LUT, RPE, KP_M, ATTN_M, sizemax, stride_zx, stride_zrpe, stride_hrpe, stride_srpe, stride_zkpm,\n stride_zattnm, **meta):\n TN = meta['TN']\n BLOCK = meta['BLOCK']\n pidhm = tl.program_id(0)\n pidz = tl.program_id(1)\n # create index ranges\n rxm = pidhm % BLOCK\n rbm = pidhm // BLOCK\n rxn = tl.arange(0, TN) % BLOCK\n rbn = tl.arange(0, TN) // BLOCK\n # extract information from LUT\n header = LUT + rbm * 2\n size = tl.load(header + 0)\n offset = tl.load(header + 1)\n check = rbn < size\n rbmn = tl.where(check, rbn, size - 1)\n # block id and column id\n blockid = tl.load(LUT + offset + rbmn * 4 + 0)\n columnid = tl.load(LUT + offset + rbmn * 4 + 1)\n rowid = tl.load(LUT + offset + rbmn * 4 + 2)\n headid = tl.load(LUT + offset + rbmn * 4 + 3)\n # pointers to X\n px = X + pidz * stride_zx + blockid * BLOCK * BLOCK + rxm * BLOCK + rxn\n x = tl.load(px, mask=check, other=-float('inf'))\n x = x.to(tl.float32)\n # apply scale\n if meta['APPLY_SCALE']:\n x = x * scale\n # apply RPE\n if meta['APPLY_RPE']:\n prpe = RPE + pidz * stride_zrpe + headid * stride_hrpe + columnid * BLOCK + rowid * BLOCK * stride_srpe + rxm * stride_srpe + rxn\n rpe = tl.load(prpe, mask=check, other=0)\n x = x + rpe\n # apply key-padding mask\n if meta['APPLY_KP_MASK']:\n pkp_m = KP_M + pidz * stride_zkpm + columnid * BLOCK + rxn\n kp_m = tl.load(pkp_m, mask=check, other=-float('inf'))\n if meta['KP_MASK_MUL']:\n kp_m = tl.where(kp_m == 0, -float('inf'), 0.)\n x = x + kp_m\n # apply attention mask\n if meta['APPLY_ATTN_MASK']:\n pattn_m = ATTN_M + columnid * BLOCK + rowid * BLOCK * stride_zattnm + rxm * stride_zattnm + rxn\n attn_m = tl.load(pattn_m, mask=check, other=-float('inf'))\n if meta['ATTN_MASK_MUL']:\n attn_m = tl.where(attn_m == 0, -float('inf'), 0.)\n x = x + attn_m\n # computation\n x = tl.softmax(x)\n tl.store(px, x, mask=check)\n\n@triton.jit\ndef _backward(X, scale, DX, LUT, sizemax, stride_zx, stride_zdx, **meta):\n pidhm = tl.program_id(0)\n pidz = tl.program_id(1)\n TN = meta['TN']\n BLOCK = meta['BLOCK']\n # create index ranges\n rxm = pidhm % BLOCK\n rbm = pidhm // BLOCK\n rxn = tl.arange(0, TN) % BLOCK\n rbn = tl.arange(0, TN) // BLOCK\n # extract information from look-up table\n header = LUT + rbm * 2\n size = tl.load(header + 0)\n offset = tl.load(header + 1)\n # bounds checking on lut\n check = rbn < size\n rbmn = tl.where(check, rbn, size - 1)\n # initialize pointers to block-sparse input\n blockid = tl.load(LUT + offset + rbmn * 4)\n X = X + pidz * stride_zx + blockid * BLOCK * BLOCK + rxm * BLOCK + rxn\n DX = DX + pidz * stride_zdx + blockid * BLOCK * BLOCK + rxm * BLOCK + rxn\n # compute fused softmax backward\n x = tl.load(X, mask=check, other=0)\n dx = tl.load(DX, mask=check, other=0)\n x = x.to(tl.float32)\n dx = dx.to(tl.float32)\n y = x * (dx - tl.sum(x * dx, 0)) * scale\n tl.store(DX, y, mask=check)\n\nclass _sparse_softmax(torch.autograd.Function):\n\n @staticmethod\n def forward(ctx, x, scale, rpe, key_padding_mask, attn_mask, kp_mask_mode, attn_mask_mode, spdims, block, lut,\n num_blocks, maxlut, bench, time):\n\n apply_scale = False if scale == 1.0 else True\n\n # handle None rpe\n if rpe is None:\n apply_rpe = False\n stride_zrpe, stride_hrpe, stride_srpe = 0, 0, 0\n rpe = torch.empty(0, dtype=x.dtype, device=x.device)\n else:\n apply_rpe = True\n stride_zrpe, stride_hrpe, stride_srpe = rpe.stride(0), rpe.stride(1), rpe.stride(2)\n\n # handle None key_padding_mask\n if key_padding_mask is None:\n apply_kp_mask = False\n stride_zkpm = 0\n key_padding_mask = torch.empty(0, dtype=x.dtype, device=x.device)\n else:\n apply_kp_mask = True\n stride_zkpm = key_padding_mask.stride(0)\n\n # handle None attention_mask\n if attn_mask is None:\n apply_attn_mask = False\n stride_zattnm = 0\n attn_mask = torch.empty(0, dtype=x.dtype, device=x.device)\n else:\n apply_attn_mask = True\n stride_zattnm = attn_mask.stride(0)\n\n # run kernel\n M = x.shape[0]\n meta = {\n 'BLOCK': block,\n 'APPLY_SCALE': apply_scale,\n 'APPLY_RPE': apply_rpe,\n 'APPLY_KP_MASK': apply_kp_mask,\n 'APPLY_ATTN_MASK': apply_attn_mask,\n 'KP_MASK_MUL': kp_mask_mode == 'mul',\n 'ATTN_MASK_MUL': attn_mask_mode == 'mul',\n }\n grid = lambda opt: [spdims[0] * spdims[1] * block, M]\n _forward[grid](x, scale, lut, rpe, key_padding_mask, attn_mask, maxlut, x.stride(0),\\\n stride_zrpe, stride_hrpe, stride_srpe, stride_zkpm, stride_zattnm, **meta)\n\n # save to context\n ctx.mark_dirty(x)\n ctx.save_for_backward(x, lut)\n ctx.spdims = spdims\n ctx.block = block\n ctx.maxlut = maxlut\n ctx.scale = scale\n ctx.apply_scale = apply_scale\n ctx.apply_rpe = apply_rpe\n ctx.apply_kp_mask = apply_kp_mask\n ctx.apply_attn_mask = apply_attn_mask\n ctx.kp_mask_mode = kp_mask_mode\n ctx.attn_mask_mode = attn_mask_mode\n return x\n\n @staticmethod\n def backward(ctx, dx):\n\n # retrieve from context\n x, lut = ctx.saved_tensors\n # run kernel\n M = x.shape[0]\n grid = lambda opt: [ctx.spdims[0] * ctx.spdims[1] * ctx.block, M]\n _backward[grid](x, ctx.scale, dx, lut, ctx.maxlut, x.stride(0), dx.stride(0), BLOCK=ctx.block)\n return dx, None, None, None, None, None, None, None, None, None, None, None, None, None, None\n", - "description_1": "Use triton language to implement a block-sparse softmax operation with two kernels: _forward and _backward. The _forward kernel takes 13 parameters: X (input tensor), scale (scaling factor), LUT (look-up table), RPE (relative position embedding), KP_M (key padding mask), ATTN_M (attention mask), sizemax (maximum size), stride_zx, stride_zrpe, stride_hrpe, stride_srpe, stride_zkpm, stride_zattnm (stride values for various tensors). The _backward kernel takes 7 parameters: X (input tensor), scale (scaling factor), DX (gradient tensor), LUT (look-up table), sizemax (maximum size), stride_zx, stride_zdx (stride values for input and gradient tensors). The _sparse_softmax class uses these kernels to perform forward and backward passes of the softmax operation on block-sparse matrices, applying optional scaling, relative position embedding, key padding mask, and attention mask.", - "description_2": "Use triton language to create a block-sparse softmax operation with forward and backward kernels, handling optional scaling, relative position embedding, key padding mask, and attention mask.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef _fwd_kernel(\n Q,\n K,\n V,\n sm_scale,\n TMP,\n Out,\n stride_qz,\n stride_qh,\n stride_qm,\n stride_qk,\n stride_kz,\n stride_kh,\n stride_kn,\n stride_kk,\n stride_vz,\n stride_vh,\n stride_vk,\n stride_vn,\n stride_oz,\n stride_oh,\n stride_om,\n stride_on,\n Z,\n H,\n N_CTX,\n BLOCK_M: tl.constexpr,\n BLOCK_DMODEL: tl.constexpr,\n BLOCK_N: tl.constexpr,\n):\n start_m = tl.program_id(0)\n off_hz = tl.program_id(1)\n # initialize offsets\n offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)\n offs_n = tl.arange(0, BLOCK_N)\n offs_d = tl.arange(0, BLOCK_DMODEL)\n off_q = off_hz * stride_qh + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qk\n off_k = off_hz * stride_kh + offs_n[:, None] * stride_kn + offs_d[None, :] * stride_kk\n off_v = off_hz * stride_vh + offs_n[:, None] * stride_qm + offs_d[None, :] * stride_qk\n # Initialize pointers to Q, K, V\n q_ptrs = Q + off_q\n k_ptrs = K + off_k\n v_ptrs = V + off_v\n # initialize pointer to m and l\n t_ptrs = TMP + off_hz * N_CTX + offs_m\n m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float(\"inf\")\n l_i = tl.zeros([BLOCK_M], dtype=tl.float32)\n acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)\n # load q: it will stay in SRAM throughout\n q = tl.load(q_ptrs)\n # loop over k, v and update accumulator\n for start_n in range(0, N_CTX, BLOCK_N):\n start_n = tl.multiple_of(start_n, BLOCK_N)\n # -- compute qk ----\n k = tl.load(k_ptrs + start_n * stride_kn)\n\n qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)\n qk += tl.dot(q, k, trans_b=True)\n qk *= sm_scale\n # -- compute m_ij, p, l_ij\n m_ij = tl.max(qk, 1)\n p = tl.exp(qk - m_ij[:, None])\n l_ij = tl.sum(p, 1)\n # -- update m_i and l_i\n m_i_new = tl.maximum(m_i, m_ij)\n alpha = tl.exp(m_i - m_i_new)\n beta = tl.exp(m_ij - m_i_new)\n l_i_new = alpha * l_i + beta * l_ij\n # -- update output accumulator --\n # scale p\n p_scale = beta / l_i_new\n p = p * p_scale[:, None]\n # scale acc\n acc_scale = l_i / l_i_new * alpha\n tl.store(t_ptrs, acc_scale)\n acc_scale = tl.load(t_ptrs) # BUG: have to store and immediately load\n acc = acc * acc_scale[:, None]\n # update acc\n v = tl.load(v_ptrs + start_n * stride_vk)\n p = p.to(tl.float16)\n acc += tl.dot(p, v)\n # update m_i and l_i\n l_i = l_i_new\n m_i = m_i_new\n # initialize pointers to output\n offs_n = tl.arange(0, BLOCK_DMODEL)\n off_o = off_hz * stride_oh + offs_m[:, None] * stride_om + offs_n[None, :] * stride_on\n out_ptrs = Out + off_o\n tl.store(out_ptrs, acc)\n\n\nclass triton_flash_attn(torch.nn.Module):\n\n def __init__(self, ):\n super(triton_flash_attn, self).__init__()\n\n def forward(self, q, k, v, sm_scale, block_128=True):\n BLOCK = 128 if block_128 else 64\n # shape constraints\n Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]\n o = torch.empty_like(q)\n grid = (triton.cdiv(q.shape[2], BLOCK), q.shape[0] * q.shape[1])\n tmp = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)\n num_warps = 4 if Lk <= 64 else 8\n\n _fwd_kernel[grid](\n q,\n k,\n v,\n sm_scale,\n tmp,\n o,\n q.stride(0),\n q.stride(1),\n q.stride(2),\n q.stride(3),\n k.stride(0),\n k.stride(1),\n k.stride(2),\n k.stride(3),\n v.stride(0),\n v.stride(1),\n v.stride(2),\n v.stride(3),\n o.stride(0),\n o.stride(1),\n o.stride(2),\n o.stride(3),\n k.shape[0],\n k.shape[1],\n k.shape[2],\n BLOCK_M=BLOCK,\n BLOCK_N=BLOCK,\n BLOCK_DMODEL=Lk,\n num_warps=num_warps,\n num_stages=1,\n )\n return o\n", - "description_1": "Use triton language to implement a forward kernel for a flash attention mechanism. The kernel takes 25 parameters: Q, K, V (input matrices), sm_scale (a scaling factor), TMP (temporary storage), Out (output matrix), 16 stride parameters for indexing, Z, H, N_CTX (context size), and three block sizes (BLOCK_M, BLOCK_DMODEL, BLOCK_N). The kernel computes scaled dot-product attention using a loop over the context size, updating accumulators and storing results in the output matrix.", - "description_2": "Use triton language to create a flash attention module in PyTorch. The module's forward method takes 5 parameters: q, k, v (input matrices), sm_scale (a scaling factor), and block_128 (a boolean to determine block size). It prepares the grid and temporary storage, calculates the number of warps, and calls the _fwd_kernel with appropriate parameters to compute the attention output.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n\n@triton.jit\ndef _kernel(A,\n B,\n C,\n stride_za,\n stride_ha,\n stride_ma,\n stride_ka,\n stride_zb,\n stride_hb,\n stride_kb,\n stride_nb,\n stride_zc,\n stride_hc,\n stride_mc,\n stride_nc,\n DS0,\n DS1,\n SDD_K,\n SDD_off_width,\n lut,\n locks,\n nlocks,\n **meta):\n # Triton kernel for block sparse matrix multiplication\n TM = meta['TM']\n TN = meta['TN']\n TK = meta['TK']\n TZ = meta['TZ']\n BLOCK = meta['BLOCK']\n # Prologue\n pid0 = tl.program_id(0)\n pid1 = tl.program_id(1)\n pidz = tl.program_id(2)\n if meta['SDD']:\n pid1 = pid1 + SDD_off_width\n blockidm = tl.arange(0, TM) // BLOCK\n blockidn = tl.arange(0, TN) // BLOCK\n offlutm = blockidm * (TN // BLOCK) * 4\n offlutn = blockidn * 4\n header = lut + pid1 * (TM // BLOCK) * (TN // BLOCK) * 4\n z = tl.load(header + 0)\n i = tl.load(header + 1 + offlutm)\n j = tl.load(header + 2 + offlutn)\n AS1 = SDD_K // TZ\n lockid = tl.where(TZ > 1, 1, 0)\n offka = pid0 * AS1\n offkb = pid0 * AS1\n offmc = 0\n offnc = 0\n offpa = 0\n offpb = 0\n maxid = TZ\n offhc = 0\n offha = z\n offhb = z\n ram = i * BLOCK + (tl.arange(0, TM) % BLOCK)\n rbn = j * BLOCK + (tl.arange(0, TN) % BLOCK)\n else:\n header = lut + pid0 * 6\n offset = tl.load(header + 0)\n AS1 = tl.load(header + 1)\n column = tl.load(header + 2)\n depth = tl.load(header + 3)\n lockid = tl.load(header + 4)\n maxid = tl.load(header + 5)\n pinc = lut + offset\n offhc = depth\n if meta['DSD']:\n # output offset\n offnc = pid1 * TN\n offmc = column * TM\n offpc = 0\n # dense input offset\n offnb = pid1 * TN\n offkb = tl.load(pinc)\n offkb = tl.multiple_of(offkb, 8) # compiler hint\n offpb = 0\n # sparse input offset\n offma = 0\n offka = 0\n offpa = tl.load(pinc + 1)\n offpa = tl.multiple_of(offpa, 8) # compiler hint\n offpa = offpa * BLOCK * BLOCK\n offha = 0\n offhb = depth\n else:\n # output offset\n offmc = pid1 * TM\n offnc = column * TN\n offpc = 0\n # dense input offset\n offma = pid1 * TM\n offka = tl.load(pinc)\n offka = tl.multiple_of(offka, 8) # compiler hint\n offpa = 0\n # sparse input offset\n offnb = 0\n offkb = 0\n offpb = tl.load(pinc + 1)\n offpb = tl.multiple_of(offpb, 8) # compiler hint\n offpb = offpb * BLOCK * BLOCK\n offha = depth\n offhb = 0\n ram = offma + tl.arange(0, TM)\n rbn = offnb + tl.arange(0, TN)\n\n # initialize a, b pointers\n rka = offka + tl.arange(0, TK)\n rkb = offkb + tl.arange(0, TK)\n pa = A + pidz * stride_za + offha * stride_ha + offpa + ram[:, None] * stride_ma + rka[None, :] * stride_ka\n pb = B + pidz * stride_zb + offhb * stride_hb + offpb + rbn[None, :] * stride_nb + rkb[:, None] * stride_kb\n if meta['DDS']:\n checkam = ram[:, None] < DS0\n else:\n checkam = AS1 > 0\n if meta['DSD']:\n checkbn = rbn[None, :] < DS0\n else:\n checkbn = AS1 > 0\n a = tl.load(pa, mask=checkam, other=0.)\n b = tl.load(pb, mask=checkbn, other=0.)\n\n # Inner Loop\n acc = tl.zeros((TM, TN), dtype=tl.float32)\n for k in range(AS1, 0, -TK):\n acc += tl.dot(a, b)\n if meta['SDD']:\n inc_a = TK * stride_ka\n inc_b = TK * stride_kb\n else:\n pinc += 2\n if meta['DSD']:\n inc_b = tl.load(pinc)\n inc_a = tl.load(pinc + 1)\n inc_b = tl.multiple_of(inc_b, 8)\n inc_a = tl.multiple_of(inc_a, 8)\n inc_b = inc_b * stride_kb\n if meta['DDS']:\n inc_a = tl.load(pinc)\n inc_b = tl.load(pinc + 1)\n inc_a = tl.multiple_of(inc_a, 8)\n inc_b = tl.multiple_of(inc_b, 8)\n inc_a = inc_a * stride_ka\n pa += inc_a\n pb += inc_b\n # pre-fetch\n checkak = k > TK\n checkbk = k > TK\n checka = checkam & checkak\n checkb = checkbn & checkbk\n a = tl.load(pa, mask=checka)\n b = tl.load(pb, mask=checkb)\n c = acc.to(C.dtype.element_ty)\n\n if meta['SDD']:\n checkc = True\n rr_blockidm = tl.arange(0, TM) // BLOCK\n rr_blockidn = tl.arange(0, TN) // BLOCK\n rr_offlutm = rr_blockidm * (TN // BLOCK) * 4\n rr_offlutn = rr_blockidn * 4\n off_bkid = 3 + rr_offlutm[:, None] + rr_offlutn[None, :]\n bkid = tl.load(header + off_bkid)\n offpc = bkid * BLOCK * BLOCK\n rcm = tl.arange(0, TM) % BLOCK\n rcn = tl.arange(0, TN) % BLOCK\n else:\n rcm = offmc + tl.arange(0, TM)\n rcn = offnc + tl.arange(0, TN)\n if meta['DSD']:\n checkc = rcn[None, :] < DS0\n if meta['DDS']:\n checkc = rcm[:, None] < DS0\n\n pc = C + offpc + offhc * stride_hc + pidz * stride_zc + rcm[:, None] * stride_mc + rcn[None, :] * stride_nc\n # write-back directly\n if lockid == 0:\n tl.store(pc, c, mask=checkc)\n # accumulate partial results using spin-locks\n else:\n plock = locks + tl.program_id(2) * nlocks * tl.num_programs(1) + tl.program_id(\n 1) * nlocks + lockid - 1\n pcount = plock + tl.num_programs(2) * tl.num_programs(1) * nlocks\n while tl.atomic_cas(plock, 0, 1) == 1:\n pass\n count = tl.load(pcount)\n if count == 0:\n tl.store(pc, c, mask=checkc)\n else:\n d = tl.load(pc, mask=checkc)\n tl.store(pc, d + c, mask=checkc)\n tl.atomic_xchg(pcount, (count + 1) % maxid)\n tl.atomic_xchg(plock, 0)\n\n\ndef _sdd_matmul(a,\n b,\n trans_a,\n trans_b,\n trans_c,\n spdims,\n block,\n luts,\n num_locks,\n widths,\n packs,\n bench,\n time):\n # Parameters\n AS0 = a.size(0)\n a_dim = -2 if trans_a else -1\n b_dim = -1 if trans_b else -2\n a_inner, b_inner = a.shape[a_dim], b.shape[b_dim]\n if a_inner != b_inner:\n raise ValueError(\"Size of tensor A and B must match\")\n if a_inner % 16 != 0:\n raise ValueError('Reduction size for SDD must be a multiple of 16')\n\n batch_size = a.size(0)\n a_outer = a.size(3 if trans_a else 2)\n dtype = a.dtype\n is_32_multiple = a_inner % 32 == 0\n is_64_multiple = a_inner % 64 == 0\n if not is_32_multiple:\n raise ValueError('Reduction size for SDD must be a multiple of 32')\n device = a.device\n\n # Create kernel\n total_width = sum([width * pack * pack for width, pack in zip(widths, packs)])\n c = torch.empty((batch_size,\n total_width,\n block,\n block),\n dtype=dtype,\n device=a.device)\n for lut, width, pack in zip(luts, widths, packs):\n F16TK = [16]\n F16TK += [32] if is_32_multiple else []\n F16TK += [64] if is_64_multiple else []\n TK = F16TK\n num_lock = 1\n meta = {\n 'TM': block * pack,\n 'TN': block * pack,\n 'BLOCK': block,\n 'TK': TK[0],\n 'TZ': 1,\n 'SDD': True,\n 'DSD': False,\n 'DDS': False\n }\n # Create output\n locks = torch.zeros(2 * width * AS0 * num_lock, dtype=torch.int32, device=a.device)\n # Maximum grid size is 65535\n max_width = 49152\n for off_width in range(0, width, max_width):\n grid = lambda meta: [\n meta['TZ'],\n min(max_width, width - off_width),\n batch_size\n ]\n _kernel[grid](a,\n b,\n c,\n a.stride(0),\n a.stride(1),\n a.stride(3 if trans_a else 2),\n a.stride(2 if trans_a else 3),\n b.stride(0),\n b.stride(1),\n b.stride(3 if trans_b else 2),\n b.stride(2 if trans_b else 3),\n c.stride(0),\n c.stride(0),\n c.stride(2),\n c.stride(3),\n a_outer,\n a_outer,\n a_inner,\n off_width,\n lut,\n locks,\n num_lock,\n num_warps=4,\n **meta)\n return c\n", - "description_1": "Use triton language to implement a block sparse matrix multiplication kernel (_kernel) and a wrapper function (_sdd_matmul) to handle the calling logic. The _kernel has 22 main parameters (A, B, C, stride_za, stride_ha, stride_ma, stride_ka, stride_zb, stride_hb, stride_kb, stride_nb, stride_zc, stride_hc, stride_mc, stride_nc, DS0, DS1, SDD_K, SDD_off_width, lut, locks, nlocks) and uses meta information for configuration. The _sdd_matmul function prepares inputs and launches the _kernel with appropriate grid dimensions and parameters.", - "description_2": "Use triton language to create a matrix multiplication kernel with parameters for input matrices, strides, and a lookup table. Implement a function to configure and invoke the kernel with appropriate settings for block sparsity.", - "difficulty": 5 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\ndef next_power_of_2(n):\n n -= 1\n n |= n >> 1\n n |= n >> 2\n n |= n >> 4\n n |= n >> 8\n n |= n >> 16\n n += 1\n return n\n\ndef num_warps(n):\n if n < 512:\n return 4\n if n < 2048:\n return 8\n return 16\n\n@triton.heuristics({\n 'num_warps': lambda *args,\n **meta: num_warps(args[6] * meta['BLOCK'])\n})\n@triton.heuristics({\n 'TN': lambda *args,\n **meta: next_power_of_2(args[6] * meta['BLOCK'])\n})\n@triton.jit\ndef _forward(X,\n scale,\n LUT,\n RPE,\n KP_M,\n ATTN_M,\n sizemax,\n stride_zx,\n stride_zrpe,\n stride_hrpe,\n stride_srpe,\n stride_zkpm,\n stride_zattnm,\n **meta):\n TN = meta['TN']\n BLOCK = meta['BLOCK']\n pidhm = tl.program_id(0)\n pidz = tl.program_id(1)\n # create index ranges\n rxm = pidhm % BLOCK\n rbm = pidhm // BLOCK\n rxn = tl.arange(0, TN) % BLOCK\n rbn = tl.arange(0, TN) // BLOCK\n # extract information from LUT\n header = LUT + rbm * 2\n size = tl.load(header + 0)\n offset = tl.load(header + 1)\n check = rbn < size\n rbmn = tl.where(check, rbn, size - 1)\n # block id and column id\n blockid = tl.load(LUT + offset + rbmn * 4 + 0)\n columnid = tl.load(LUT + offset + rbmn * 4 + 1)\n rowid = tl.load(LUT + offset + rbmn * 4 + 2)\n headid = tl.load(LUT + offset + rbmn * 4 + 3)\n # pointers to X\n px = X + pidz * stride_zx + blockid * BLOCK * BLOCK + rxm * BLOCK + rxn\n x = tl.load(px, mask=check, other=-float('inf'))\n x = x.to(tl.float32)\n # apply scale\n if meta['APPLY_SCALE']:\n x = x * scale\n # apply RPE\n if meta['APPLY_RPE']:\n prpe = RPE + pidz * stride_zrpe + headid * stride_hrpe + columnid * BLOCK + rowid * BLOCK * stride_srpe + rxm * stride_srpe + rxn\n rpe = tl.load(prpe, mask=check, other=0)\n x = x + rpe\n # apply key-padding mask\n if meta['APPLY_KP_MASK']:\n pkp_m = KP_M + pidz * stride_zkpm + columnid * BLOCK + rxn\n kp_m = tl.load(pkp_m, mask=check, other=-float('inf'))\n if meta['KP_MASK_MUL']:\n kp_m = tl.where(kp_m == 0, -float('inf'), 0.)\n x = x + kp_m\n # apply attention mask\n if meta['APPLY_ATTN_MASK']:\n pattn_m = ATTN_M + columnid * BLOCK + rowid * BLOCK * stride_zattnm + rxm * stride_zattnm + rxn\n attn_m = tl.load(pattn_m, mask=check, other=-float('inf'))\n if meta['ATTN_MASK_MUL']:\n attn_m = tl.where(attn_m == 0, -float('inf'), 0.)\n x = x + attn_m\n # computation\n x = tl.softmax(x)\n tl.store(px, x, mask=check)\n\n@triton.heuristics({\n 'num_warps': lambda *args,\n **meta: num_warps(args[4] * meta['BLOCK'])\n})\n@triton.heuristics({\n 'TN': lambda *args,\n **meta: next_power_of_2(args[4]) * meta['BLOCK']\n})\n@triton.jit\ndef _backward(X, scale, DX, LUT, sizemax, stride_zx, stride_zdx, **meta):\n pidhm = tl.program_id(0)\n pidz = tl.program_id(1)\n TN = meta['TN']\n BLOCK = meta['BLOCK']\n # create index ranges\n rxm = pidhm % BLOCK\n rbm = pidhm // BLOCK\n rxn = tl.arange(0, TN) % BLOCK\n rbn = tl.arange(0, TN) // BLOCK\n # extract information from look-up table\n header = LUT + rbm * 2\n size = tl.load(header + 0)\n offset = tl.load(header + 1)\n # bounds checking on lut\n check = rbn < size\n rbmn = tl.where(check, rbn, size - 1)\n # initialize pointers to block-sparse input\n blockid = tl.load(LUT + offset + rbmn * 4)\n X = X + pidz * stride_zx + blockid * BLOCK * BLOCK + rxm * BLOCK + rxn\n DX = DX + pidz * stride_zdx + blockid * BLOCK * BLOCK + rxm * BLOCK + rxn\n # compute fused softmax backward\n x = tl.load(X, mask=check, other=0)\n dx = tl.load(DX, mask=check, other=0)\n x = x.to(tl.float32)\n dx = dx.to(tl.float32)\n y = x * (dx - tl.sum(x * dx, 0)) * scale\n tl.store(DX, y, mask=check)\n\nclass _sparse_softmax(torch.autograd.Function):\n\n bwd_kernels = dict()\n\n @staticmethod\n def make_lut(layout, block, device):\n _empty = torch.tensor([], dtype=torch.int64, device=layout.device)\n sizes = _empty.clone()\n # sizes along rows\n for h in range(layout.shape[0]):\n sizes = torch.cat((sizes, layout[h, :, :].sum(-1)))\n # offsets in block format\n offsets = torch.zeros_like(sizes)\n offsets[1:] = torch.cumsum(sizes[:-1], dim=0)\n # block indices\n idx = torch.arange(layout.sum())\n head = layout.nonzero()[:, 0]\n rows = layout.nonzero()[:, 1]\n columns = layout.nonzero()[:, 2]\n core = torch.stack((idx, columns, rows, head), dim=1).view(-1)\n # construct look-up table\n offsets = offsets * 4 + 2 * sizes.numel()\n header = torch.stack((sizes, offsets), dim=1).view(-1)\n lut = torch.cat((header, core)).type(torch.int32).to(device)\n return lut, int(sizes.max())\n\n @staticmethod\n def forward(ctx,\n x,\n scale,\n rpe,\n key_padding_mask,\n attn_mask,\n kp_mask_mode,\n attn_mask_mode,\n spdims,\n block,\n lut,\n num_blocks,\n maxlut,\n bench,\n time):\n\n apply_scale = False if scale == 1.0 else True\n\n # handle None rpe\n if rpe is None:\n apply_rpe = False\n stride_zrpe, stride_hrpe, stride_srpe = 0, 0, 0\n rpe = torch.empty(0, dtype=x.dtype, device=x.device)\n else:\n apply_rpe = True\n stride_zrpe, stride_hrpe, stride_srpe = rpe.stride(0), rpe.stride(1), rpe.stride(2)\n\n # handle None key_padding_mask\n if key_padding_mask is None:\n apply_kp_mask = False\n stride_zkpm = 0\n key_padding_mask = torch.empty(0, dtype=x.dtype, device=x.device)\n else:\n apply_kp_mask = True\n stride_zkpm = key_padding_mask.stride(0)\n\n # handle None attention_mask\n if attn_mask is None:\n apply_attn_mask = False\n stride_zattnm = 0\n attn_mask = torch.empty(0, dtype=x.dtype, device=x.device)\n else:\n apply_attn_mask = True\n stride_zattnm = attn_mask.stride(0)\n\n # run kernel\n M = x.shape[0]\n meta = {\n 'BLOCK': block,\n 'APPLY_SCALE': apply_scale,\n 'APPLY_RPE': apply_rpe,\n 'APPLY_KP_MASK': apply_kp_mask,\n 'APPLY_ATTN_MASK': apply_attn_mask,\n 'KP_MASK_MUL': kp_mask_mode == 'mul',\n 'ATTN_MASK_MUL': attn_mask_mode == 'mul',\n }\n grid = lambda opt: [spdims[0] * spdims[1] * block, M]\n _forward[grid](x, scale, lut, rpe, key_padding_mask, attn_mask, maxlut, x.stride(0),\\\n stride_zrpe, stride_hrpe, stride_srpe, stride_zkpm, stride_zattnm, **meta)\n\n # save to context\n ctx.mark_dirty(x)\n ctx.save_for_backward(x, lut)\n ctx.spdims = spdims\n ctx.block = block\n ctx.maxlut = maxlut\n ctx.scale = scale\n ctx.apply_scale = apply_scale\n ctx.apply_rpe = apply_rpe\n ctx.apply_kp_mask = apply_kp_mask\n ctx.apply_attn_mask = apply_attn_mask\n ctx.kp_mask_mode = kp_mask_mode\n ctx.attn_mask_mode = attn_mask_mode\n return x\n\n @staticmethod\n def backward(ctx, dx):\n\n # retrieve from context\n x, lut = ctx.saved_tensors\n # run kernel\n M = x.shape[0]\n grid = lambda opt: [ctx.spdims[0] * ctx.spdims[1] * ctx.block, M]\n _backward[grid](x,\n ctx.scale,\n dx,\n lut,\n ctx.maxlut,\n x.stride(0),\n dx.stride(0),\n BLOCK=ctx.block)\n return dx, None, None, None, None, None, None, None, None, None, None, None, None, None, None\n\nclass Softmax:\n \"\"\"Block-Sparse Softmax class; this class computes softmax on a block sparse matrix. It is also able to apply either/all of the following masks:\n - relative position embedding\n - key padding mask\n - attention mask\n\n For more details about sparsity config, please see `Generative Modeling with Sparse Transformers`: https://arxiv.org/abs/1904.10509\n \"\"\"\n def sparse_softmax(*args, **kwargs):\n return _sparse_softmax.apply(*args, **kwargs)\n\n def make_lut(self, device):\n \"\"\"Generates the sparsity layout used in block-sparse softmax\n \"\"\"\n key = (device, )\n if key not in self.lut_cache:\n self.lut_cache[key] = _sparse_softmax.make_lut(self.layout,\n self.block,\n device)\n return self.lut_cache[key]\n\n def __init__(self, layout, block, bench=False):\n \"\"\"Initialize the Block-Sparse Softmax class.\n\n Arguments:\n layout: required: sparsity layout tensor\n block: required: an integer determining the block size.\n bench: optional: set if you want to do benchmarking\n \"\"\"\n\n self.num_blocks = layout.sum().item()\n self.spdims = layout.shape\n self.layout = layout\n self.block = block\n self.bench = bench\n self.lut_cache = dict()\n\n def __call__(self,\n x,\n scale=1.,\n rpe=None,\n key_padding_mask=None,\n attn_mask=None,\n key_padding_mask_mode='add',\n attn_mask_mode='add'):\n \"\"\"Applies softmax on a Block-Sparse input tensor.\n\n For more details about sparsity config, please see `Generative Modeling with Sparse Transformers`: https://arxiv.org/abs/1904.10509\n\n Arguments:\n x: required: a block-sparse tensor that softmax is applied on it; computation will be in place and result will be returned in the same tensor\n scale: optional: a float value; x values will be multiplied by this value before normalization. Default value is 1.0.\n rpe: optional: a tensor same dimension as x that is used as relative position embedding\n key_padding_mask: optional: a mask tensor of size (BatchSize X SequenceLength)\n attn_mask: optional: a mask tensor of size (SequenceLength X SequenceLength); currently only 2D is supported\n key_padding_mask_mode: optional: a boolean determining if key_padding_mask needs to be added or multiplied\n attn_mask_mode: optional: a boolean determining if attn_mask needs to be added or multiplied\n\n Return:\n x: a block-sparse tensor contains normalized input x using softmax; and masks applied if given\n \"\"\"\n\n time_y = [None]\n if rpe is not None and rpe.dtype != x.dtype:\n raise ValueError('relative position embedding must be %s' % x.dtype)\n if attn_mask is not None and attn_mask.dtype != x.dtype:\n raise ValueError('Attention mask must be %s' % x.dtype)\n if key_padding_mask is not None and key_padding_mask.dtype != x.dtype:\n raise ValueError('Key padding mask must be %s' % x.dtype)\n lut, maxlut = self.make_lut(x.device)\n x = Softmax.sparse_softmax(x,\n scale,\n rpe,\n key_padding_mask,\n attn_mask,\n key_padding_mask_mode,\n attn_mask_mode,\n self.spdims,\n self.block,\n lut,\n self.num_blocks,\n maxlut,\n self.bench,\n time_y)\n self.time_y = time_y[0]\n return x\n", - "description_1": "Use triton language to implement block-sparse softmax with optional scaling, relative position embedding, key padding mask, and attention mask. The forward kernel (_forward) takes 13 parameters: X (input tensor), scale (scaling factor), LUT (look-up table), RPE (relative position embedding), KP_M (key padding mask), ATTN_M (attention mask), sizemax (maximum size), stride_zx (stride for X), stride_zrpe (stride for RPE), stride_hrpe (stride for head in RPE), stride_srpe (stride for sequence in RPE), stride_zkpm (stride for key padding mask), and stride_zattnm (stride for attention mask). The backward kernel (_backward) takes 7 parameters: X (input tensor), scale (scaling factor), DX (gradient tensor), LUT (look-up table), sizemax (maximum size), stride_zx (stride for X), and stride_zdx (stride for DX). The Softmax class provides a method to apply the block-sparse softmax using these kernels.", - "description_2": "Use triton language to create a block-sparse softmax operator with support for scaling, relative position embedding, key padding mask, and attention mask. Implement forward and backward kernels to handle the computation and gradient propagation.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef _fwd_kernel(\n Q,\n K,\n V,\n sm_scale,\n TMP,\n Out,\n stride_qz,\n stride_qh,\n stride_qm,\n stride_qk,\n stride_kz,\n stride_kh,\n stride_kn,\n stride_kk,\n stride_vz,\n stride_vh,\n stride_vk,\n stride_vn,\n stride_oz,\n stride_oh,\n stride_om,\n stride_on,\n Z,\n H,\n N_CTX,\n BLOCK_M: tl.constexpr,\n BLOCK_DMODEL: tl.constexpr,\n BLOCK_N: tl.constexpr,\n):\n start_m = tl.program_id(0)\n off_hz = tl.program_id(1)\n # initialize offsets\n offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)\n offs_n = tl.arange(0, BLOCK_N)\n offs_d = tl.arange(0, BLOCK_DMODEL)\n off_q = off_hz * stride_qh + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qk\n off_k = off_hz * stride_kh + offs_n[:, None] * stride_kn + offs_d[None, :] * stride_kk\n off_v = off_hz * stride_vh + offs_n[:, None] * stride_qm + offs_d[None, :] * stride_qk\n # Initialize pointers to Q, K, V\n q_ptrs = Q + off_q\n k_ptrs = K + off_k\n v_ptrs = V + off_v\n # initialize pointer to m and l\n t_ptrs = TMP + off_hz * N_CTX + offs_m\n m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float(\"inf\")\n l_i = tl.zeros([BLOCK_M], dtype=tl.float32)\n acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)\n # load q: it will stay in SRAM throughout\n q = tl.load(q_ptrs)\n # loop over k, v and update accumulator\n for start_n in range(0, N_CTX, BLOCK_N):\n start_n = tl.multiple_of(start_n, BLOCK_N)\n # -- compute qk ----\n k = tl.load(k_ptrs + start_n * stride_kn)\n\n qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)\n qk += tl.dot(q, k, trans_b=True)\n qk *= sm_scale\n # -- compute m_ij, p, l_ij\n m_ij = tl.max(qk, 1)\n p = tl.exp(qk - m_ij[:, None])\n l_ij = tl.sum(p, 1)\n # -- update m_i and l_i\n m_i_new = tl.maximum(m_i, m_ij)\n alpha = tl.exp(m_i - m_i_new)\n beta = tl.exp(m_ij - m_i_new)\n l_i_new = alpha * l_i + beta * l_ij\n # -- update output accumulator --\n # scale p\n p_scale = beta / l_i_new\n p = p * p_scale[:, None]\n # scale acc\n acc_scale = l_i / l_i_new * alpha\n tl.store(t_ptrs, acc_scale)\n acc_scale = tl.load(t_ptrs) # BUG: have to store and immediately load\n acc = acc * acc_scale[:, None]\n # update acc\n v = tl.load(v_ptrs + start_n * stride_vk)\n p = p.to(tl.float16)\n acc += tl.dot(p, v)\n # update m_i and l_i\n l_i = l_i_new\n m_i = m_i_new\n # initialize pointers to output\n offs_n = tl.arange(0, BLOCK_DMODEL)\n off_o = off_hz * stride_oh + offs_m[:, None] * stride_om + offs_n[None, :] * stride_on\n out_ptrs = Out + off_o\n tl.store(out_ptrs, acc)\n\nclass triton_flash_attn(torch.nn.Module):\n def __init__(self, ):\n super(triton_flash_attn, self).__init__()\n\n def forward(self, q, k, v, sm_scale, block_128=True):\n BLOCK = 128 if block_128 else 64\n # shape constraints\n Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]\n o = torch.empty_like(q)\n grid = (triton.cdiv(q.shape[2], BLOCK), q.shape[0] * q.shape[1])\n tmp = torch.empty((q.shape[0] * q.shape[1],\n q.shape[2]),\n device=q.device,\n dtype=torch.float32)\n num_warps = 4 if Lk <= 64 else 8\n\n _fwd_kernel[grid](\n q,\n k,\n v,\n sm_scale,\n tmp,\n o,\n q.stride(0),\n q.stride(1),\n q.stride(2),\n q.stride(3),\n k.stride(0),\n k.stride(1),\n k.stride(2),\n k.stride(3),\n v.stride(0),\n v.stride(1),\n v.stride(2),\n v.stride(3),\n o.stride(0),\n o.stride(1),\n o.stride(2),\n o.stride(3),\n k.shape[0],\n k.shape[1],\n k.shape[2],\n BLOCK_M=BLOCK,\n BLOCK_N=BLOCK,\n BLOCK_DMODEL=Lk,\n num_warps=num_warps,\n num_stages=1,\n )\n return o\n", - "description_1": "Use triton language to implement a forward kernel for flash attention. The kernel takes 24 parameters: Q, K, V (input matrices), sm_scale (scale factor), TMP (temporary storage), Out (output matrix), 16 stride parameters for indexing, Z, H, N_CTX (context size), and 3 block size constants. The kernel computes scaled dot-product attention using a loop over the context size, updating accumulators for the output matrix.", - "description_2": "Use triton language to create a PyTorch module for flash attention. The module has a forward method with 5 parameters: q, k, v (input matrices), sm_scale (scale factor), and block_128 (block size flag). It sets up grid and temporary storage, then calls the triton kernel with appropriate parameters to compute the attention output.", - "difficulty": 3 - }, - { - "code": "import triton\nimport torch\n\n# Kernel function for matrix multiplication\n@triton.jit\ndef matmul_kernel(\n a_ptr, b_ptr, c_ptr, M, N, K, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr\n):\n # Triton kernel code for matrix multiplication\n pass\n\n# Function to call the Triton kernel\ndef call_matmul_kernel(a, b, c, M, N, K):\n # Launch the Triton kernel\n matmul_kernel[(M, N)](a, b, c, M, N, K, BLOCK_SIZE_M=128, BLOCK_SIZE_N=128, BLOCK_SIZE_K=32)\n\n# Example usage\na = torch.randn(128, 128, device='cuda')\nb = torch.randn(128, 128, device='cuda')\nc = torch.empty(128, 128, device='cuda')\ncall_matmul_kernel(a, b, c, 128, 128, 128)\n", - "description_1": "Use triton language to implement a matrix multiplication kernel with parameters for input matrices, output matrix, and dimensions M, N, K. The kernel uses block sizes for M, N, and K as compile-time constants.", - "description_2": "Use triton language to create a matrix multiplication kernel and a function to call this kernel with specified block sizes and dimensions.", - "difficulty": 2 - }, - { - "code": "import triton\nimport triton.language as tl\nimport math\n\n@triton.jit\ndef rotate_half_kernel(\n qk_seq_ptr,\n position_ids_ptr,\n qk_seq_stride,\n position_ids_batch_stride,\n seq_len,\n HEAD_DIM: tl.constexpr,\n BLOCK_HEIGHT: tl.constexpr,\n BLOCK_WIDTH: tl.constexpr,\n INV_BASE: tl.constexpr,\n):\n # qk_seq_ptr: (bsz, seq_len, 2, num_heads, head_dim) -- OK to be discontinuous in 2nd dimension.\n # position ids: (bsz, seq_len) -- must be contiguous in the last dimension.\n\n HALF_HEAD: tl.constexpr = HEAD_DIM // 2\n STEPS_PER_ROW: tl.constexpr = HALF_HEAD // BLOCK_WIDTH\n\n batch_seq = tl.program_id(axis=0)\n row_blk_x_col_blk = tl.program_id(axis=1)\n\n row_blk = row_blk_x_col_blk // STEPS_PER_ROW\n row = row_blk * BLOCK_HEIGHT\n if BLOCK_WIDTH < HALF_HEAD:\n col_blk = row_blk_x_col_blk % STEPS_PER_ROW\n col = col_blk * BLOCK_WIDTH\n else:\n col: tl.constexpr = 0\n\n # A block will never cross a sequence boundary, which simplifies things a lot.\n batch = batch_seq // seq_len\n seq = batch_seq % seq_len\n position_id = tl.load(position_ids_ptr + batch * position_ids_batch_stride + seq)\n # As sometimes happens, just calculating this on the fly is faster than loading it from memory.\n # Use `tl.libdevice.exp` rather than `tl.exp` -- the latter is less accurate.\n freq = (\n tl.libdevice.exp((col + tl.arange(0, BLOCK_WIDTH)).to(tl.float32) * INV_BASE)\n * position_id\n )\n cos = tl.cos(freq).to(tl.float32)\n sin = tl.sin(freq).to(tl.float32)\n\n col_offsets: tl.constexpr = tl.arange(0, BLOCK_WIDTH)\n embed_offsets = (row * HEAD_DIM + col) + col_offsets\n x_ptrs = (qk_seq_ptr + batch_seq * qk_seq_stride) + embed_offsets\n\n for k in range(0, BLOCK_HEIGHT):\n x = tl.load(x_ptrs).to(tl.float32)\n y = tl.load(x_ptrs + HALF_HEAD).to(tl.float32)\n out_x = x * cos - y * sin\n tl.store(x_ptrs, out_x)\n out_y = x * sin + y * cos\n tl.store(x_ptrs + HALF_HEAD, out_y)\n x_ptrs += HEAD_DIM\n\n\ndef triton_rotate_half_(qk, position_ids, config=None):\n batch_size, seq_len, qandk, num_heads, head_dim = qk.shape\n\n # This default is the fastest for most job sizes, at least on my RTX 4090, and when it's not it's within spitting distance of the best option. There are some odd cases where having a block height of 2 or 4 helps but the difference is within 5%. It makes sense that this configuration is fast from a memory bandwidth and caching perspective.\n config = config or {\n \"BLOCK_HEIGHT\": 1,\n \"BLOCK_WIDTH\": min(128, head_dim // 2),\n \"num_warps\": 1,\n }\n config[\"BLOCK_HEIGHT\"] = min(config[\"BLOCK_HEIGHT\"], 2 * num_heads)\n\n assert qk.stride(3) == head_dim\n assert qk.stride(4) == 1\n assert position_ids.shape == (batch_size, seq_len)\n assert (\n position_ids.stride(1) == 1\n ), \"position_ids must be contiguous in the last dimension\"\n assert (2 * num_heads) % config[\n \"BLOCK_HEIGHT\"\n ] == 0, f'number of rows not evenly divisible by {config[\"BLOCK_HEIGHT\"]}'\n assert (head_dim // 2) % config[\n \"BLOCK_WIDTH\"\n ] == 0, f'number of columns ({head_dim // 2}) not evenly divisible by {config[\"BLOCK_WIDTH\"]}'\n\n qk_by_seq = qk.view(batch_size * seq_len, 2 * num_heads * head_dim)\n grid = (\n qk_by_seq.shape[0],\n (2 * num_heads // config[\"BLOCK_HEIGHT\"])\n * (head_dim // 2 // config[\"BLOCK_WIDTH\"]),\n )\n\n # Must be the same as the theta of the frequencies used to train the model.\n BASE = 10000.0\n\n rotate_half_kernel[grid](\n qk_by_seq,\n position_ids,\n qk_by_seq.stride(0),\n position_ids.stride(0),\n seq_len,\n HEAD_DIM=head_dim,\n BLOCK_HEIGHT=config[\"BLOCK_HEIGHT\"],\n BLOCK_WIDTH=config[\"BLOCK_WIDTH\"],\n INV_BASE=-2.0 * math.log(BASE) / head_dim,\n num_warps=config[\"num_warps\"],\n )\n", - "description_1": "Use triton language to implement a kernel function 'rotate_half_kernel' that performs in-place rotation of half of the head dimension of a query-key sequence tensor based on position ids. The kernel takes 9 parameters: qk_seq_ptr (pointer to the query-key sequence), position_ids_ptr (pointer to position ids), qk_seq_stride (stride of the query-key sequence), position_ids_batch_stride (stride of position ids), seq_len (sequence length), HEAD_DIM (head dimension), BLOCK_HEIGHT (block height), BLOCK_WIDTH (block width), and INV_BASE (inverse base for frequency calculation). The function 'triton_rotate_half_' is a wrapper that prepares the input data and configuration for the kernel execution, taking 3 parameters: qk (query-key tensor), position_ids (position ids tensor), and config (optional configuration dictionary).", - "description_2": "Use triton language to create a kernel that rotates half of the head dimension of a tensor based on position ids, with a wrapper function to configure and execute the kernel.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef fusedmatmul_248_kernel(\n a_ptr,\n c_ptr,\n b1_ptr,\n scales1_ptr,\n zeros1_ptr,\n g1_ptr,\n b2_ptr,\n scales2_ptr,\n zeros2_ptr,\n g2_ptr,\n M,\n N,\n K,\n bits,\n maxq,\n stride_am,\n stride_ak,\n stride_bk,\n stride_bn,\n stride_cm,\n stride_cn,\n stride_scales,\n stride_zeros,\n BLOCK_SIZE_M: tl.constexpr,\n BLOCK_SIZE_N: tl.constexpr,\n BLOCK_SIZE_K: tl.constexpr,\n GROUP_SIZE_M: tl.constexpr,\n):\n \"\"\"\n Computes: C = silu(A * B1) * (A * B2)\n A is of shape (M, K) float16\n B is of shape (K//8, N) int32\n C is of shape (M, N) float16\n scales is of shape (1, N) float16\n zeros is of shape (1, N//8) int32\n \"\"\"\n infearure_per_bits = 32 // bits\n\n pid = tl.program_id(axis=0)\n num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)\n num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)\n num_pid_k = tl.cdiv(K, BLOCK_SIZE_K)\n num_pid_in_group = GROUP_SIZE_M * num_pid_n\n group_id = pid // num_pid_in_group\n first_pid_m = group_id * GROUP_SIZE_M\n group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)\n pid_m = first_pid_m + (pid % group_size_m)\n pid_n = (pid % num_pid_in_group) // group_size_m\n\n offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n offs_k = tl.arange(0, BLOCK_SIZE_K)\n a_ptrs = a_ptr + (\n offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak\n ) # (BLOCK_SIZE_M, BLOCK_SIZE_K)\n a_mask = offs_am[:, None] < M\n # b_ptrs is set up such that it repeats elements along the K axis 8 times\n b1_ptrs = b1_ptr + (\n (offs_k[:, None] // infearure_per_bits) * stride_bk\n + offs_bn[None, :] * stride_bn\n )\n b2_ptrs = b2_ptr + (\n (offs_k[:, None] // infearure_per_bits) * stride_bk\n + offs_bn[None, :] * stride_bn\n )\n g1_ptrs = g1_ptr + offs_k\n g2_ptrs = g2_ptr + offs_k\n # shifter is used to extract the N bits of each element in the 32-bit word from B\n scales1_ptrs = scales1_ptr + offs_bn[None, :]\n scales2_ptrs = scales2_ptr + offs_bn[None, :]\n zeros1_ptrs = zeros1_ptr + (offs_bn[None, :] // infearure_per_bits)\n zeros2_ptrs = zeros2_ptr + (offs_bn[None, :] // infearure_per_bits)\n\n shifter = (offs_k % infearure_per_bits) * bits\n zeros_shifter = (offs_bn % infearure_per_bits) * bits\n accumulator1 = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n accumulator2 = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n for k in range(0, num_pid_k):\n g1_idx = tl.load(g1_ptrs)\n g2_idx = tl.load(g2_ptrs)\n\n # Fetch scales and zeros; these are per-outfeature and thus reused in the inner loop\n scales1 = tl.load(\n scales1_ptrs + g1_idx[:, None] * stride_scales\n ) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)\n scales2 = tl.load(scales2_ptrs + g2_idx[:, None] * stride_scales)\n\n zeros1 = tl.load(\n zeros1_ptrs + g1_idx[:, None] * stride_zeros\n ) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)\n zeros1 = (zeros1 >> zeros_shifter[None, :]) & maxq\n zeros1 = zeros1 + 1\n\n zeros2 = tl.load(\n zeros2_ptrs + g2_idx[:, None] * stride_zeros\n ) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)\n zeros2 = (zeros2 >> zeros_shifter[None, :]) & maxq\n zeros2 = zeros2 + 1\n\n a = tl.load(a_ptrs, mask=a_mask, other=0.0) # (BLOCK_SIZE_M, BLOCK_SIZE_K)\n b1 = tl.load(b1_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated\n b2 = tl.load(b2_ptrs)\n\n # Now we need to unpack b (which is N-bit values) into 32-bit values\n b1 = (b1 >> shifter[:, None]) & maxq # Extract the N-bit values\n b1 = (b1 - zeros1) * scales1 # Scale and shift\n accumulator1 += tl.dot(a, b1)\n\n b2 = (b2 >> shifter[:, None]) & maxq\n b2 = (b2 - zeros2) * scales2\n accumulator2 += tl.dot(a, b2)\n\n a_ptrs += BLOCK_SIZE_K\n b1_ptrs += (BLOCK_SIZE_K // infearure_per_bits) * stride_bk\n b2_ptrs += (BLOCK_SIZE_K // infearure_per_bits) * stride_bk\n g1_ptrs += BLOCK_SIZE_K\n g2_ptrs += BLOCK_SIZE_K\n\n accumulator1 = silu(accumulator1)\n c = accumulator1 * accumulator2\n c = c.to(tl.float16)\n c_ptrs = c_ptr + stride_cm * offs_am[:, None] + stride_cn * offs_bn[None, :]\n c_mask = (offs_am[:, None] < M) & (offs_bn[None, :] < N)\n tl.store(c_ptrs, c, mask=c_mask)\n\n@triton.jit\ndef silu(x):\n return x * tl.sigmoid(x)\n\nclass QuantLlamaMLP(nn.Module):\n def __init__(\n self,\n gate_proj,\n down_proj,\n up_proj,\n ):\n super().__init__()\n self.register_buffer(\"gate_proj_qweight\", gate_proj.qweight)\n self.register_buffer(\"gate_proj_scales\", gate_proj.scales)\n self.register_buffer(\"gate_proj_qzeros\", gate_proj.qzeros)\n self.register_buffer(\"gate_proj_g_idx\", gate_proj.g_idx)\n self.register_buffer(\"up_proj_qweight\", up_proj.qweight)\n self.register_buffer(\"up_proj_scales\", up_proj.scales)\n self.register_buffer(\"up_proj_qzeros\", up_proj.qzeros)\n self.register_buffer(\"up_proj_g_idx\", up_proj.g_idx)\n\n self.infeatures = gate_proj.infeatures\n self.intermediate_size = gate_proj.outfeatures\n self.outfeatures = down_proj.outfeatures\n self.bits = gate_proj.bits\n self.maxq = gate_proj.maxq\n\n self.down_proj = down_proj\n\n def forward(self, x):\n return self.down_proj(self.triton_llama_mlp(x))\n\n def triton_llama_mlp(self, x):\n with torch.cuda.device(x.device):\n out_shape = x.shape[:-1] + (self.intermediate_size,)\n x = x.reshape(-1, x.shape[-1])\n M, K = x.shape\n N = self.intermediate_size\n c = torch.empty((M, N), device=x.device, dtype=torch.float16)\n grid = lambda META: (\n triton.cdiv(M, META[\"BLOCK_SIZE_M\"])\n * triton.cdiv(N, META[\"BLOCK_SIZE_N\"]),\n )\n fusedmatmul_248_kernel[grid](\n x,\n c,\n self.gate_proj_qweight,\n self.gate_proj_scales,\n self.gate_proj_qzeros,\n self.gate_proj_g_idx,\n self.up_proj_qweight,\n self.up_proj_scales,\n self.up_proj_qzeros,\n self.up_proj_g_idx,\n M,\n N,\n K,\n self.bits,\n self.maxq,\n x.stride(0),\n x.stride(1),\n self.gate_proj_qweight.stride(0),\n self.gate_proj_qweight.stride(1),\n c.stride(0),\n c.stride(1),\n self.gate_proj_scales.stride(0),\n self.gate_proj_qzeros.stride(0),\n )\n c = c.reshape(out_shape)\n return c\n", - "description_1": "Use triton language to implement a fused matrix multiplication kernel that computes C = silu(A * B1) * (A * B2) where A is a float16 matrix of shape (M, K), B1 and B2 are int32 matrices of shape (K//8, N). The kernel takes 24 parameters including pointers to input matrices, scales, zeros, group indices, dimensions M, N, K, bit width, max quantization value, and strides for memory access. The kernel is optimized for specific block sizes and group sizes.", - "description_2": "Use triton language to implement a SiLU activation function and a fused matrix multiplication kernel for quantized weights, optimized for specific block sizes and group sizes.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef matmul_248_kernel(\n a_ptr,\n b_ptr,\n c_ptr,\n scales_ptr,\n zeros_ptr,\n g_ptr,\n M,\n N,\n K,\n bits,\n maxq,\n stride_am,\n stride_ak,\n stride_bk,\n stride_bn,\n stride_cm,\n stride_cn,\n stride_scales,\n stride_zeros,\n BLOCK_SIZE_M: tl.constexpr,\n BLOCK_SIZE_N: tl.constexpr,\n BLOCK_SIZE_K: tl.constexpr,\n GROUP_SIZE_M: tl.constexpr,\n):\n \"\"\"\n Compute the matrix multiplication C = A x B.\n A is of shape (M, K) float16\n B is of shape (K//8, N) int32\n C is of shape (M, N) float16\n scales is of shape (G, N) float16\n zeros is of shape (G, N) float16\n g_ptr is of shape (K) int32\n \"\"\"\n infearure_per_bits = 32 // bits\n\n pid = tl.program_id(axis=0)\n num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)\n num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)\n num_pid_k = tl.cdiv(K, BLOCK_SIZE_K)\n num_pid_in_group = GROUP_SIZE_M * num_pid_n\n group_id = pid // num_pid_in_group\n first_pid_m = group_id * GROUP_SIZE_M\n group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)\n pid_m = first_pid_m + (pid % group_size_m)\n pid_n = (pid % num_pid_in_group) // group_size_m\n\n offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n offs_k = tl.arange(0, BLOCK_SIZE_K)\n a_ptrs = a_ptr + (\n offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak\n ) # (BLOCK_SIZE_M, BLOCK_SIZE_K)\n a_mask = offs_am[:, None] < M\n # b_ptrs is set up such that it repeats elements along the K axis 8 times\n b_ptrs = b_ptr + (\n (offs_k[:, None] // infearure_per_bits) * stride_bk\n + offs_bn[None, :] * stride_bn\n ) # (BLOCK_SIZE_K, BLOCK_SIZE_N)\n g_ptrs = g_ptr + offs_k\n # shifter is used to extract the N bits of each element in the 32-bit word from B\n scales_ptrs = scales_ptr + offs_bn[None, :]\n zeros_ptrs = zeros_ptr + (offs_bn[None, :] // infearure_per_bits)\n\n shifter = (offs_k % infearure_per_bits) * bits\n zeros_shifter = (offs_bn % infearure_per_bits) * bits\n accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n\n for k in range(0, num_pid_k):\n g_idx = tl.load(g_ptrs)\n\n # Fetch scales and zeros; these are per-outfeature and thus reused in the inner loop\n scales = tl.load(\n scales_ptrs + g_idx[:, None] * stride_scales\n ) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)\n zeros = tl.load(\n zeros_ptrs + g_idx[:, None] * stride_zeros\n ) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)\n\n zeros = (zeros >> zeros_shifter[None, :]) & maxq\n zeros = zeros + 1\n\n a = tl.load(a_ptrs, mask=a_mask, other=0.0) # (BLOCK_SIZE_M, BLOCK_SIZE_K)\n b = tl.load(b_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated\n\n # Now we need to unpack b (which is N-bit values) into 32-bit values\n b = (b >> shifter[:, None]) & maxq # Extract the N-bit values\n b = (b - zeros) * scales # Scale and shift\n\n accumulator += tl.dot(a, b)\n a_ptrs += BLOCK_SIZE_K\n b_ptrs += (BLOCK_SIZE_K // infearure_per_bits) * stride_bk\n g_ptrs += BLOCK_SIZE_K\n\n c_ptrs = c_ptr + stride_cm * offs_am[:, None] + stride_cn * offs_bn[None, :]\n c_mask = (offs_am[:, None] < M) & (offs_bn[None, :] < N)\n tl.store(c_ptrs, accumulator, mask=c_mask)\n\n\n@triton.jit\ndef transpose_matmul_248_kernel(\n a_ptr,\n b_ptr,\n c_ptr,\n scales_ptr,\n zeros_ptr,\n g_ptr,\n M,\n N,\n K,\n bits,\n maxq,\n stride_am,\n stride_ak,\n stride_bk,\n stride_bn,\n stride_cm,\n stride_cn,\n stride_scales,\n stride_zeros,\n BLOCK_SIZE_M: tl.constexpr,\n BLOCK_SIZE_N: tl.constexpr,\n BLOCK_SIZE_K: tl.constexpr,\n GROUP_SIZE_M: tl.constexpr,\n):\n \"\"\"\n Compute the matrix multiplication C = A x B.\n A is of shape (M, N) float16\n B is of shape (K//8, N) int32\n C is of shape (M, K) float16\n scales is of shape (G, N) float16\n zeros is of shape (G, N) float16\n g_ptr is of shape (K) int32\n \"\"\"\n infearure_per_bits = 32 // bits\n\n pid = tl.program_id(axis=0)\n num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)\n num_pid_k = tl.cdiv(K, BLOCK_SIZE_K)\n num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)\n num_pid_in_group = GROUP_SIZE_M * num_pid_k\n group_id = pid // num_pid_in_group\n first_pid_m = group_id * GROUP_SIZE_M\n group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)\n pid_m = first_pid_m + (pid % group_size_m)\n pid_k = (pid % num_pid_in_group) // group_size_m\n\n offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_bk = pid_k * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)\n offs_n = tl.arange(0, BLOCK_SIZE_N)\n a_ptrs = a_ptr + (\n offs_am[:, None] * stride_am + offs_n[None, :] * stride_ak\n ) # (BLOCK_SIZE_M, BLOCK_SIZE_N)\n a_mask = offs_am[:, None] < M\n # b_ptrs is set up such that it repeats elements along the K axis 8 times\n b_ptrs = b_ptr + (\n (offs_bk[:, None] // infearure_per_bits) * stride_bk\n + offs_n[None, :] * stride_bn\n ) # (BLOCK_SIZE_K, BLOCK_SIZE_N)\n g_ptrs = g_ptr + offs_bk\n g_idx = tl.load(g_ptrs)\n\n # shifter is used to extract the N bits of each element in the 32-bit word from B\n scales_ptrs = scales_ptr + offs_n[None, :] + g_idx[:, None] * stride_scales\n zeros_ptrs = (\n zeros_ptr\n + (offs_n[None, :] // infearure_per_bits)\n + g_idx[:, None] * stride_zeros\n )\n\n shifter = (offs_bk % infearure_per_bits) * bits\n zeros_shifter = (offs_n % infearure_per_bits) * bits\n accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_K), dtype=tl.float32)\n\n for n in range(0, num_pid_n):\n # Fetch scales and zeros; these are per-outfeature and thus reused in the inner loop\n scales = tl.load(scales_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)\n zeros = tl.load(zeros_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)\n\n zeros = (zeros >> zeros_shifter[None, :]) & maxq\n zeros = zeros + 1\n\n a = tl.load(a_ptrs, mask=a_mask, other=0.0) # (BLOCK_SIZE_M, BLOCK_SIZE_N)\n b = tl.load(b_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated\n\n # Now we need to unpack b (which is N-bit values) into 32-bit values\n b = (b >> shifter[:, None]) & maxq # Extract the N-bit values\n b = (b - zeros) * scales # Scale and shift\n b = tl.trans(b)\n\n accumulator += tl.dot(a, b)\n a_ptrs += BLOCK_SIZE_N\n b_ptrs += BLOCK_SIZE_N\n scales_ptrs += BLOCK_SIZE_N\n zeros_ptrs += BLOCK_SIZE_N // infearure_per_bits\n\n c_ptrs = c_ptr + stride_cm * offs_am[:, None] + stride_cn * offs_bk[None, :]\n c_mask = (offs_am[:, None] < M) & (offs_bk[None, :] < K)\n tl.store(c_ptrs, accumulator, mask=c_mask)\n\n\ndef matmul248(input, qweight, scales, qzeros, g_idx, bits, maxq):\n with torch.cuda.device(input.device):\n output = torch.empty(\n (input.shape[0], qweight.shape[1]), device=input.device, dtype=torch.float16\n )\n grid = lambda META: (\n triton.cdiv(input.shape[0], META[\"BLOCK_SIZE_M\"])\n * triton.cdiv(qweight.shape[1], META[\"BLOCK_SIZE_N\"]),\n )\n matmul_248_kernel[grid](\n input,\n qweight,\n output,\n scales,\n qzeros,\n g_idx,\n input.shape[0],\n qweight.shape[1],\n input.shape[1],\n bits,\n maxq,\n input.stride(0),\n input.stride(1),\n qweight.stride(0),\n qweight.stride(1),\n output.stride(0),\n output.stride(1),\n scales.stride(0),\n qzeros.stride(0),\n )\n return output\n\n\ndef transpose_matmul248(input, qweight, scales, qzeros, g_idx, bits, maxq):\n with torch.cuda.device(input.device):\n output_dim = (qweight.shape[0] * 32) // bits\n output = torch.empty(\n (input.shape[0], output_dim), device=input.device, dtype=torch.float16\n )\n grid = lambda META: (\n triton.cdiv(input.shape[0], META[\"BLOCK_SIZE_M\"])\n * triton.cdiv(output_dim, META[\"BLOCK_SIZE_K\"]),\n )\n transpose_matmul_248_kernel[grid](\n input,\n qweight,\n output,\n scales,\n qzeros,\n g_idx,\n input.shape[0],\n qweight.shape[1],\n output_dim,\n bits,\n maxq,\n input.stride(0),\n input.stride(1),\n qweight.stride(0),\n qweight.stride(1),\n output.stride(0),\n output.stride(1),\n scales.stride(0),\n qzeros.stride(0),\n )\n return output\n", - "description_1": "Use triton language to define a matrix multiplication kernel for quantized matrices with kernel function `matmul_248_kernel` taking 18 tensor arguments and 4 block size constants to compute the multiplication of A (M, K) with B (K//8, N) and outputting C (M, N) while applying scale and zero-point adjustments.", - "description_2": "Use triton language to define a transpose matrix multiplication kernel with kernel function `transpose_matmul_248_kernel` taking 18 tensor arguments and 4 block size constants to compute the transpose multiplication of A (M, N) with B (K//8, N) and outputting C (M, K) with scale and zero-point corrections.", - "difficulty": 4 - }, - { - "code": "import triton\nimport math\n\n@triton.jit\ndef kernel(x_ptr, x_size, **META):\n BLOCK_SIZE = META['BLOCK_SIZE']\n # Kernel implementation here\n\ndef autotune(\n configs, key, prune_configs_by=None, reset_to_zero=None, nearest_power_of_two=False\n):\n def decorator(fn):\n return Autotuner(\n fn,\n fn.arg_names,\n configs,\n key,\n reset_to_zero,\n prune_configs_by,\n nearest_power_of_two,\n )\n return decorator\n\n@autotune(configs=[\n triton.Config(meta={'BLOCK_SIZE': 128}, num_warps=4),\n triton.Config(meta={'BLOCK_SIZE': 1024}, num_warps=8),\n ],\n key=['x_size']\n)\n@triton.jit\ndef kernel(x_ptr, x_size, **META):\n BLOCK_SIZE = META['BLOCK_SIZE']\n # Kernel implementation here\n", - "description_1": "Use triton language to define a kernel function with two parameters: x_ptr (pointer to data) and x_size (size of data). The kernel uses a meta-parameter BLOCK_SIZE to determine the block size for processing. The kernel is decorated with an autotuner that evaluates different configurations based on changes in x_size.", - "description_2": "Use triton language to create a kernel with autotuning capabilities, adjusting block size based on input size.", - "difficulty": 3 - }, - { - "code": "import triton\nimport triton.language as tl\nimport math\n\n@triton.jit\ndef rotate_half_kernel(\n qk_seq_ptr,\n position_ids_ptr,\n qk_seq_stride,\n position_ids_batch_stride,\n seq_len,\n HEAD_DIM: tl.constexpr,\n BLOCK_HEIGHT: tl.constexpr,\n BLOCK_WIDTH: tl.constexpr,\n INV_BASE: tl.constexpr,\n):\n # qk_seq_ptr: (bsz, seq_len, 2, num_heads, head_dim) -- OK to be discontinuous in 2nd dimension.\n # position ids: (bsz, seq_len) -- must be contiguous in the last dimension.\n\n HALF_HEAD: tl.constexpr = HEAD_DIM // 2\n STEPS_PER_ROW: tl.constexpr = HALF_HEAD // BLOCK_WIDTH\n\n batch_seq = tl.program_id(axis=0)\n row_blk_x_col_blk = tl.program_id(axis=1)\n\n row_blk = row_blk_x_col_blk // STEPS_PER_ROW\n row = row_blk * BLOCK_HEIGHT\n if BLOCK_WIDTH < HALF_HEAD:\n col_blk = row_blk_x_col_blk % STEPS_PER_ROW\n col = col_blk * BLOCK_WIDTH\n else:\n col: tl.constexpr = 0\n\n # A block will never cross a sequence boundary, which simplifies things a lot.\n batch = batch_seq // seq_len\n seq = batch_seq % seq_len\n position_id = tl.load(position_ids_ptr + batch * position_ids_batch_stride + seq)\n # As sometimes happens, just calculating this on the fly is faster than loading it from memory.\n # Use `tl.libdevice.exp` rather than `tl.exp` -- the latter is less accurate.\n freq = (\n tl.libdevice.exp((col + tl.arange(0, BLOCK_WIDTH)).to(tl.float32) * INV_BASE)\n * position_id\n )\n cos = tl.cos(freq).to(tl.float32)\n sin = tl.sin(freq).to(tl.float32)\n\n col_offsets: tl.constexpr = tl.arange(0, BLOCK_WIDTH)\n embed_offsets = (row * HEAD_DIM + col) + col_offsets\n x_ptrs = (qk_seq_ptr + batch_seq * qk_seq_stride) + embed_offsets\n\n for k in range(0, BLOCK_HEIGHT):\n x = tl.load(x_ptrs).to(tl.float32)\n y = tl.load(x_ptrs + HALF_HEAD).to(tl.float32)\n out_x = x * cos - y * sin\n tl.store(x_ptrs, out_x)\n out_y = x * sin + y * cos\n tl.store(x_ptrs + HALF_HEAD, out_y)\n x_ptrs += HEAD_DIM\n\n\ndef triton_rotate_half_(qk, position_ids, config=None):\n batch_size, seq_len, qandk, num_heads, head_dim = qk.shape\n\n config = config or {\n \"BLOCK_HEIGHT\": 1,\n \"BLOCK_WIDTH\": min(128, head_dim // 2),\n \"num_warps\": 1,\n }\n config[\"BLOCK_HEIGHT\"] = min(config[\"BLOCK_HEIGHT\"], 2 * num_heads)\n\n assert qk.stride(3) == head_dim\n assert qk.stride(4) == 1\n assert position_ids.shape == (batch_size, seq_len)\n assert (\n position_ids.stride(1) == 1\n ), \"position_ids must be contiguous in the last dimension\"\n assert (2 * num_heads) % config[\n \"BLOCK_HEIGHT\"\n ] == 0, f'number of rows not evenly divisible by {config[\"BLOCK_HEIGHT\"]}'\n assert (head_dim // 2) % config[\n \"BLOCK_WIDTH\"\n ] == 0, f'number of columns ({head_dim // 2}) not evenly divisible by {config[\"BLOCK_WIDTH\"]}'\n\n qk_by_seq = qk.view(batch_size * seq_len, 2 * num_heads * head_dim)\n grid = (\n qk_by_seq.shape[0],\n (2 * num_heads // config[\"BLOCK_HEIGHT\"])\n * (head_dim // 2 // config[\"BLOCK_WIDTH\"]),\n )\n\n BASE = 10000.0\n\n rotate_half_kernel[grid](\n qk_by_seq,\n position_ids,\n qk_by_seq.stride(0),\n position_ids.stride(0),\n seq_len,\n HEAD_DIM=head_dim,\n BLOCK_HEIGHT=config[\"BLOCK_HEIGHT\"],\n BLOCK_WIDTH=config[\"BLOCK_WIDTH\"],\n INV_BASE=-2.0 * math.log(BASE) / head_dim,\n num_warps=config[\"num_warps\"],\n )\n", - "description_1": "Use triton language to implement a kernel 'rotate_half_kernel' with parameters: qk_seq_ptr, position_ids_ptr, qk_seq_stride, position_ids_batch_stride, seq_len, HEAD_DIM, BLOCK_HEIGHT, BLOCK_WIDTH, INV_BASE. This kernel modifies the input sequence data by performing a rotation transformation using cosine and sine computations. Another function 'triton_rotate_half_' with parameters qk, position_ids, config is implemented to configure and launch this kernel.", - "description_2": "Use triton language to create a kernel that applies rotation on sequence data with configurable block dimensions and then implement a wrapper function to set up and execute this kernel.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef fusedmatmul_248_kernel(\n a_ptr,\n c_ptr,\n b1_ptr,\n scales1_ptr,\n zeros1_ptr,\n g1_ptr,\n b2_ptr,\n scales2_ptr,\n zeros2_ptr,\n g2_ptr,\n M,\n N,\n K,\n bits,\n maxq,\n stride_am,\n stride_ak,\n stride_bk,\n stride_bn,\n stride_cm,\n stride_cn,\n stride_scales,\n stride_zeros,\n BLOCK_SIZE_M: tl.constexpr,\n BLOCK_SIZE_N: tl.constexpr,\n BLOCK_SIZE_K: tl.constexpr,\n GROUP_SIZE_M: tl.constexpr,\n):\n \"\"\"\n Computes: C = silu(A * B1) * (A * B2)\n A is of shape (M, K) float16\n B is of shape (K//8, N) int32\n C is of shape (M, N) float16\n scales is of shape (1, N) float16\n zeros is of shape (1, N//8) int32\n \"\"\"\n infearure_per_bits = 32 // bits\n\n pid = tl.program_id(axis=0)\n num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)\n num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)\n num_pid_k = tl.cdiv(K, BLOCK_SIZE_K)\n num_pid_in_group = GROUP_SIZE_M * num_pid_n\n group_id = pid // num_pid_in_group\n first_pid_m = group_id * GROUP_SIZE_M\n group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)\n pid_m = first_pid_m + (pid % group_size_m)\n pid_n = (pid % num_pid_in_group) // group_size_m\n\n offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n offs_k = tl.arange(0, BLOCK_SIZE_K)\n a_ptrs = a_ptr + (\n offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak\n ) # (BLOCK_SIZE_M, BLOCK_SIZE_K)\n a_mask = offs_am[:, None] < M\n # b_ptrs is set up such that it repeats elements along the K axis 8 times\n b1_ptrs = b1_ptr + (\n (offs_k[:, None] // infearure_per_bits) * stride_bk\n + offs_bn[None, :] * stride_bn\n )\n b2_ptrs = b2_ptr + (\n (offs_k[:, None] // infearure_per_bits) * stride_bk\n + offs_bn[None, :] * stride_bn\n )\n g1_ptrs = g1_ptr + offs_k\n g2_ptrs = g2_ptr + offs_k\n # shifter is used to extract the N bits of each element in the 32-bit word from B\n scales1_ptrs = scales1_ptr + offs_bn[None, :]\n scales2_ptrs = scales2_ptr + offs_bn[None, :]\n zeros1_ptrs = zeros1_ptr + (offs_bn[None, :] // infearure_per_bits)\n zeros2_ptrs = zeros2_ptr + (offs_bn[None, :] // infearure_per_bits)\n\n shifter = (offs_k % infearure_per_bits) * bits\n zeros_shifter = (offs_bn % infearure_per_bits) * bits\n accumulator1 = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n accumulator2 = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n for k in range(0, num_pid_k):\n g1_idx = tl.load(g1_ptrs)\n g2_idx = tl.load(g2_ptrs)\n\n # Fetch scales and zeros; these are per-outfeature and thus reused in the inner loop\n scales1 = tl.load(\n scales1_ptrs + g1_idx[:, None] * stride_scales\n ) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)\n scales2 = tl.load(scales2_ptrs + g2_idx[:, None] * stride_scales)\n\n zeros1 = tl.load(\n zeros1_ptrs + g1_idx[:, None] * stride_zeros\n ) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)\n zeros1 = (zeros1 >> zeros_shifter[None, :]) & maxq\n zeros1 = zeros1 + 1\n\n zeros2 = tl.load(\n zeros2_ptrs + g2_idx[:, None] * stride_zeros\n ) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)\n zeros2 = (zeros2 >> zeros_shifter[None, :]) & maxq\n zeros2 = zeros2 + 1\n\n a = tl.load(a_ptrs, mask=a_mask, other=0.0) # (BLOCK_SIZE_M, BLOCK_SIZE_K)\n b1 = tl.load(b1_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated\n b2 = tl.load(b2_ptrs)\n\n # Now we need to unpack b (which is N-bit values) into 32-bit values\n b1 = (b1 >> shifter[:, None]) & maxq # Extract the N-bit values\n b1 = (b1 - zeros1) * scales1 # Scale and shift\n accumulator1 += tl.dot(a, b1)\n\n b2 = (b2 >> shifter[:, None]) & maxq\n b2 = (b2 - zeros2) * scales2\n accumulator2 += tl.dot(a, b2)\n\n a_ptrs += BLOCK_SIZE_K\n b1_ptrs += (BLOCK_SIZE_K // infearure_per_bits) * stride_bk\n b2_ptrs += (BLOCK_SIZE_K // infearure_per_bits) * stride_bk\n g1_ptrs += BLOCK_SIZE_K\n g2_ptrs += BLOCK_SIZE_K\n\n accumulator1 = silu(accumulator1)\n c = accumulator1 * accumulator2\n c = c.to(tl.float16)\n c_ptrs = c_ptr + stride_cm * offs_am[:, None] + stride_cn * offs_bn[None, :]\n c_mask = (offs_am[:, None] < M) & (offs_bn[None, :] < N)\n tl.store(c_ptrs, c, mask=c_mask)\n\n@triton.jit\ndef silu(x):\n return x * tl.sigmoid(x)\n\nclass QuantLlamaMLP(nn.Module):\n def __init__(\n self,\n gate_proj,\n down_proj,\n up_proj,\n ):\n super().__init__()\n self.register_buffer(\"gate_proj_qweight\", gate_proj.qweight)\n self.register_buffer(\"gate_proj_scales\", gate_proj.scales)\n self.register_buffer(\"gate_proj_qzeros\", gate_proj.qzeros)\n self.register_buffer(\"gate_proj_g_idx\", gate_proj.g_idx)\n self.register_buffer(\"up_proj_qweight\", up_proj.qweight)\n self.register_buffer(\"up_proj_scales\", up_proj.scales)\n self.register_buffer(\"up_proj_qzeros\", up_proj.qzeros)\n self.register_buffer(\"up_proj_g_idx\", up_proj.g_idx)\n\n self.infeatures = gate_proj.infeatures\n self.intermediate_size = gate_proj.outfeatures\n self.outfeatures = down_proj.outfeatures\n self.bits = gate_proj.bits\n self.maxq = gate_proj.maxq\n\n self.down_proj = down_proj\n\n def forward(self, x):\n return self.down_proj(self.triton_llama_mlp(x))\n\n def triton_llama_mlp(self, x):\n with torch.cuda.device(x.device):\n out_shape = x.shape[:-1] + (self.intermediate_size,)\n x = x.reshape(-1, x.shape[-1])\n M, K = x.shape\n N = self.intermediate_size\n c = torch.empty((M, N), device=x.device, dtype=torch.float16)\n grid = lambda META: (\n triton.cdiv(M, META[\"BLOCK_SIZE_M\"])\n * triton.cdiv(N, META[\"BLOCK_SIZE_N\"]),\n )\n fusedmatmul_248_kernel[grid](\n x,\n c,\n self.gate_proj_qweight,\n self.gate_proj_scales,\n self.gate_proj_qzeros,\n self.gate_proj_g_idx,\n self.up_proj_qweight,\n self.up_proj_scales,\n self.up_proj_qzeros,\n self.up_proj_g_idx,\n M,\n N,\n K,\n self.bits,\n self.maxq,\n x.stride(0),\n x.stride(1),\n self.gate_proj_qweight.stride(0),\n self.gate_proj_qweight.stride(1),\n c.stride(0),\n c.stride(1),\n self.gate_proj_scales.stride(0),\n self.gate_proj_qzeros.stride(0),\n )\n c = c.reshape(out_shape)\n return c\n", - "description_1": "Use triton language to implement a fused matrix multiplication kernel that computes C = silu(A * B1) * (A * B2) with inputs A, B1, B2, scales, and zeros, and a helper function silu. The kernel takes 24 parameters including pointers to input and output matrices, dimensions, bit width, and strides.", - "description_2": "Use triton language to implement a fused matrix multiplication kernel with silu activation and a helper function, taking 24 parameters including matrix pointers and dimensions.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef matmul_248_kernel(\n a_ptr,\n b_ptr,\n c_ptr,\n scales_ptr,\n zeros_ptr,\n g_ptr,\n M,\n N,\n K,\n bits,\n maxq,\n stride_am,\n stride_ak,\n stride_bk,\n stride_bn,\n stride_cm,\n stride_cn,\n stride_scales,\n stride_zeros,\n BLOCK_SIZE_M: tl.constexpr,\n BLOCK_SIZE_N: tl.constexpr,\n BLOCK_SIZE_K: tl.constexpr,\n GROUP_SIZE_M: tl.constexpr,\n):\n \"\"\"\n Compute the matrix multiplication C = A x B.\n A is of shape (M, K) float16\n B is of shape (K//8, N) int32\n C is of shape (M, N) float16\n scales is of shape (G, N) float16\n zeros is of shape (G, N) float16\n g_ptr is of shape (K) int32\n \"\"\"\n infearure_per_bits = 32 // bits\n\n pid = tl.program_id(axis=0)\n num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)\n num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)\n num_pid_k = tl.cdiv(K, BLOCK_SIZE_K)\n num_pid_in_group = GROUP_SIZE_M * num_pid_n\n group_id = pid // num_pid_in_group\n first_pid_m = group_id * GROUP_SIZE_M\n group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)\n pid_m = first_pid_m + (pid % group_size_m)\n pid_n = (pid % num_pid_in_group) // group_size_m\n\n offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n offs_k = tl.arange(0, BLOCK_SIZE_K)\n a_ptrs = a_ptr + (\n offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak\n ) # (BLOCK_SIZE_M, BLOCK_SIZE_K)\n a_mask = offs_am[:, None] < M\n # b_ptrs is set up such that it repeats elements along the K axis 8 times\n b_ptrs = b_ptr + (\n (offs_k[:, None] // infearure_per_bits) * stride_bk\n + offs_bn[None, :] * stride_bn\n ) # (BLOCK_SIZE_K, BLOCK_SIZE_N)\n g_ptrs = g_ptr + offs_k\n # shifter is used to extract the N bits of each element in the 32-bit word from B\n scales_ptrs = scales_ptr + offs_bn[None, :]\n zeros_ptrs = zeros_ptr + (offs_bn[None, :] // infearure_per_bits)\n\n shifter = (offs_k % infearure_per_bits) * bits\n zeros_shifter = (offs_bn % infearure_per_bits) * bits\n accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n\n for k in range(0, num_pid_k):\n g_idx = tl.load(g_ptrs)\n\n # Fetch scales and zeros; these are per-outfeature and thus reused in the inner loop\n scales = tl.load(\n scales_ptrs + g_idx[:, None] * stride_scales\n ) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)\n zeros = tl.load(\n zeros_ptrs + g_idx[:, None] * stride_zeros\n ) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)\n\n zeros = (zeros >> zeros_shifter[None, :]) & maxq\n zeros = zeros + 1\n\n a = tl.load(a_ptrs, mask=a_mask, other=0.0) # (BLOCK_SIZE_M, BLOCK_SIZE_K)\n b = tl.load(b_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated\n\n # Now we need to unpack b (which is N-bit values) into 32-bit values\n b = (b >> shifter[:, None]) & maxq # Extract the N-bit values\n b = (b - zeros) * scales # Scale and shift\n\n accumulator += tl.dot(a, b)\n a_ptrs += BLOCK_SIZE_K\n b_ptrs += (BLOCK_SIZE_K // infearure_per_bits) * stride_bk\n g_ptrs += BLOCK_SIZE_K\n\n c_ptrs = c_ptr + stride_cm * offs_am[:, None] + stride_cn * offs_bn[None, :]\n c_mask = (offs_am[:, None] < M) & (offs_bn[None, :] < N)\n tl.store(c_ptrs, accumulator, mask=c_mask)\n\n@triton.jit\ndef transpose_matmul_248_kernel(\n a_ptr,\n b_ptr,\n c_ptr,\n scales_ptr,\n zeros_ptr,\n g_ptr,\n M,\n N,\n K,\n bits,\n maxq,\n stride_am,\n stride_ak,\n stride_bk,\n stride_bn,\n stride_cm,\n stride_cn,\n stride_scales,\n stride_zeros,\n BLOCK_SIZE_M: tl.constexpr,\n BLOCK_SIZE_N: tl.constexpr,\n BLOCK_SIZE_K: tl.constexpr,\n GROUP_SIZE_M: tl.constexpr,\n):\n \"\"\"\n Compute the matrix multiplication C = A x B.\n A is of shape (M, N) float16\n B is of shape (K//8, N) int32\n C is of shape (M, K) float16\n scales is of shape (G, N) float16\n zeros is of shape (G, N) float16\n g_ptr is of shape (K) int32\n \"\"\"\n infearure_per_bits = 32 // bits\n\n pid = tl.program_id(axis=0)\n num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)\n num_pid_k = tl.cdiv(K, BLOCK_SIZE_K)\n num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)\n num_pid_in_group = GROUP_SIZE_M * num_pid_k\n group_id = pid // num_pid_in_group\n first_pid_m = group_id * GROUP_SIZE_M\n group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)\n pid_m = first_pid_m + (pid % group_size_m)\n pid_k = (pid % num_pid_in_group) // group_size_m\n\n offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_bk = pid_k * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)\n offs_n = tl.arange(0, BLOCK_SIZE_N)\n a_ptrs = a_ptr + (\n offs_am[:, None] * stride_am + offs_n[None, :] * stride_ak\n ) # (BLOCK_SIZE_M, BLOCK_SIZE_N)\n a_mask = offs_am[:, None] < M\n # b_ptrs is set up such that it repeats elements along the K axis 8 times\n b_ptrs = b_ptr + (\n (offs_bk[:, None] // infearure_per_bits) * stride_bk\n + offs_n[None, :] * stride_bn\n ) # (BLOCK_SIZE_K, BLOCK_SIZE_N)\n g_ptrs = g_ptr + offs_bk\n g_idx = tl.load(g_ptrs)\n\n # shifter is used to extract the N bits of each element in the 32-bit word from B\n scales_ptrs = scales_ptr + offs_n[None, :] + g_idx[:, None] * stride_scales\n zeros_ptrs = (\n zeros_ptr\n + (offs_n[None, :] // infearure_per_bits)\n + g_idx[:, None] * stride_zeros\n )\n\n shifter = (offs_bk % infearure_per_bits) * bits\n zeros_shifter = (offs_n % infearure_per_bits) * bits\n accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_K), dtype=tl.float32)\n\n for n in range(0, num_pid_n):\n # Fetch scales and zeros; these are per-outfeature and thus reused in the inner loop\n scales = tl.load(scales_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)\n zeros = tl.load(zeros_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)\n\n zeros = (zeros >> zeros_shifter[None, :]) & maxq\n zeros = zeros + 1\n\n a = tl.load(a_ptrs, mask=a_mask, other=0.0) # (BLOCK_SIZE_M, BLOCK_SIZE_N)\n b = tl.load(b_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated\n\n # Now we need to unpack b (which is N-bit values) into 32-bit values\n b = (b >> shifter[:, None]) & maxq # Extract the N-bit values\n b = (b - zeros) * scales # Scale and shift\n b = tl.trans(b)\n\n accumulator += tl.dot(a, b)\n a_ptrs += BLOCK_SIZE_N\n b_ptrs += BLOCK_SIZE_N\n scales_ptrs += BLOCK_SIZE_N\n zeros_ptrs += BLOCK_SIZE_N // infearure_per_bits\n\n c_ptrs = c_ptr + stride_cm * offs_am[:, None] + stride_cn * offs_bk[None, :]\n c_mask = (offs_am[:, None] < M) & (offs_bk[None, :] < K)\n tl.store(c_ptrs, accumulator, mask=c_mask)\n\ndef matmul248(input, qweight, scales, qzeros, g_idx, bits, maxq):\n with torch.cuda.device(input.device):\n output = torch.empty(\n (input.shape[0], qweight.shape[1]), device=input.device, dtype=torch.float16\n )\n grid = lambda META: (\n triton.cdiv(input.shape[0], META[\"BLOCK_SIZE_M\"])\n * triton.cdiv(qweight.shape[1], META[\"BLOCK_SIZE_N\"]),\n )\n matmul_248_kernel[grid](\n input,\n qweight,\n output,\n scales,\n qzeros,\n g_idx,\n input.shape[0],\n qweight.shape[1],\n input.shape[1],\n bits,\n maxq,\n input.stride(0),\n input.stride(1),\n qweight.stride(0),\n qweight.stride(1),\n output.stride(0),\n output.stride(1),\n scales.stride(0),\n qzeros.stride(0),\n )\n return output\n\ndef transpose_matmul248(input, qweight, scales, qzeros, g_idx, bits, maxq):\n with torch.cuda.device(input.device):\n output_dim = (qweight.shape[0] * 32) // bits\n output = torch.empty(\n (input.shape[0], output_dim), device=input.device, dtype=torch.float16\n )\n grid = lambda META: (\n triton.cdiv(input.shape[0], META[\"BLOCK_SIZE_M\"])\n * triton.cdiv(output_dim, META[\"BLOCK_SIZE_K\"]),\n )\n transpose_matmul_248_kernel[grid](\n input,\n qweight,\n output,\n scales,\n qzeros,\n g_idx,\n input.shape[0],\n qweight.shape[1],\n output_dim,\n bits,\n maxq,\n input.stride(0),\n input.stride(1),\n qweight.stride(0),\n qweight.stride(1),\n output.stride(0),\n output.stride(1),\n scales.stride(0),\n qzeros.stride(0),\n )\n return output\n", - "description_1": "Use triton language to implement two matrix multiplication kernels: 'matmul_248_kernel' and 'transpose_matmul_248_kernel'. The first kernel computes C = A x B where A is a float16 matrix of shape (M, K), B is an int32 matrix of shape (K//8, N), and C is a float16 matrix of shape (M, N). The second kernel computes C = A x B where A is a float16 matrix of shape (M, N), B is an int32 matrix of shape (K//8, N), and C is a float16 matrix of shape (M, K). Both kernels use additional parameters for scaling and zero-point adjustments, and they are optimized for specific block sizes and group sizes.", - "description_2": "Use triton language to create optimized matrix multiplication kernels for quantized matrices, handling scaling and zero-point adjustments, with specific block and group sizes.", - "difficulty": 4 - }, - { - "code": "import triton\n\n@triton.jit\ndef kernel(x_ptr, x_size, **META):\n BLOCK_SIZE = META['BLOCK_SIZE']\n\ndef autotune(\n configs, key, prune_configs_by=None, reset_to_zero=None, nearest_power_of_two=False\n):\n def decorator(fn):\n return Autotuner(\n fn,\n fn.arg_names,\n configs,\n key,\n reset_to_zero,\n prune_configs_by,\n nearest_power_of_two,\n )\n return decorator\n", - "description_1": "Use triton language to define a kernel function with 2 parameters. The kernel uses a BLOCK_SIZE defined in META to perform operations on x_ptr of size x_size. Utilize the autotune function to optimize kernel execution using a decorator with parameters for configurations, keys, and pruning logic.", - "description_2": "Use triton language to implement a kernel function for block processing and a decorator for autotuning configurations to optimize execution.", - "difficulty": 1 - }, - { - "code": "import triton\nimport triton.language as tl\nimport math\n\n@triton.jit\ndef rotate_half_kernel(\n qk_seq_ptr,\n position_ids_ptr,\n qk_seq_stride,\n position_ids_batch_stride,\n seq_len,\n HEAD_DIM: tl.constexpr,\n BLOCK_HEIGHT: tl.constexpr,\n BLOCK_WIDTH: tl.constexpr,\n INV_BASE: tl.constexpr,\n):\n # qk_seq_ptr: (bsz, seq_len, 2, num_heads, head_dim) -- OK to be discontinuous in 2nd dimension.\n # position ids: (bsz, seq_len) -- must be contiguous in the last dimension.\n\n HALF_HEAD: tl.constexpr = HEAD_DIM // 2\n STEPS_PER_ROW: tl.constexpr = HALF_HEAD // BLOCK_WIDTH\n\n batch_seq = tl.program_id(axis=0)\n row_blk_x_col_blk = tl.program_id(axis=1)\n\n row_blk = row_blk_x_col_blk // STEPS_PER_ROW\n row = row_blk * BLOCK_HEIGHT\n if BLOCK_WIDTH < HALF_HEAD:\n col_blk = row_blk_x_col_blk % STEPS_PER_ROW\n col = col_blk * BLOCK_WIDTH\n else:\n col: tl.constexpr = 0\n\n # A block will never cross a sequence boundary, which simplifies things a lot.\n batch = batch_seq // seq_len\n seq = batch_seq % seq_len\n position_id = tl.load(position_ids_ptr + batch * position_ids_batch_stride + seq)\n # As sometimes happens, just calculating this on the fly is faster than loading it from memory.\n # Use `tl.libdevice.exp` rather than `tl.exp` -- the latter is less accurate.\n freq = (\n tl.libdevice.exp((col + tl.arange(0, BLOCK_WIDTH)).to(tl.float32) * INV_BASE)\n * position_id\n )\n cos = tl.cos(freq).to(tl.float32)\n sin = tl.sin(freq).to(tl.float32)\n\n col_offsets: tl.constexpr = tl.arange(0, BLOCK_WIDTH)\n embed_offsets = (row * HEAD_DIM + col) + col_offsets\n x_ptrs = (qk_seq_ptr + batch_seq * qk_seq_stride) + embed_offsets\n\n for k in range(0, BLOCK_HEIGHT):\n x = tl.load(x_ptrs).to(tl.float32)\n y = tl.load(x_ptrs + HALF_HEAD).to(tl.float32)\n out_x = x * cos - y * sin\n tl.store(x_ptrs, out_x)\n out_y = x * sin + y * cos\n tl.store(x_ptrs + HALF_HEAD, out_y)\n x_ptrs += HEAD_DIM\n\n\ndef triton_rotate_half_(qk, position_ids, config=None):\n batch_size, seq_len, qandk, num_heads, head_dim = qk.shape\n\n # This default is the fastest for most job sizes, at least on my RTX 4090, and when it's not it's within spitting distance of the best option. There are some odd cases where having a block height of 2 or 4 helps but the difference is within 5%. It makes sense that this configuration is fast from a memory bandwidth and caching perspective.\n config = config or {\n \"BLOCK_HEIGHT\": 1,\n \"BLOCK_WIDTH\": min(128, head_dim // 2),\n \"num_warps\": 1,\n }\n config[\"BLOCK_HEIGHT\"] = min(config[\"BLOCK_HEIGHT\"], 2 * num_heads)\n\n assert qk.stride(3) == head_dim\n assert qk.stride(4) == 1\n assert position_ids.shape == (batch_size, seq_len)\n assert (\n position_ids.stride(1) == 1\n ), \"position_ids must be contiguous in the last dimension\"\n assert (2 * num_heads) % config[\n \"BLOCK_HEIGHT\"\n ] == 0, f'number of rows not evenly divisible by {config[\"BLOCK_HEIGHT\"]}'\n assert (head_dim // 2) % config[\n \"BLOCK_WIDTH\"\n ] == 0, f'number of columns ({head_dim // 2}) not evenly divisible by {config[\"BLOCK_WIDTH\"]}'\n\n qk_by_seq = qk.view(batch_size * seq_len, 2 * num_heads * head_dim)\n grid = (\n qk_by_seq.shape[0],\n (2 * num_heads // config[\"BLOCK_HEIGHT\"])\n * (head_dim // 2 // config[\"BLOCK_WIDTH\"]),\n )\n\n # Must be the same as the theta of the frequencies used to train the model.\n BASE = 10000.0\n\n rotate_half_kernel[grid](\n qk_by_seq,\n position_ids,\n qk_by_seq.stride(0),\n position_ids.stride(0),\n seq_len,\n HEAD_DIM=head_dim,\n BLOCK_HEIGHT=config[\"BLOCK_HEIGHT\"],\n BLOCK_WIDTH=config[\"BLOCK_WIDTH\"],\n INV_BASE=-2.0 * math.log(BASE) / head_dim,\n num_warps=config[\"num_warps\"],\n )\n", - "description_1": "Use triton language to implement a kernel function 'rotate_half_kernel' that performs in-place rotation of half of the head dimension of a query-key sequence tensor based on position ids. The kernel is configured with parameters like head dimension, block height, block width, and inverse base for frequency calculation. The function 'triton_rotate_half_' sets up the grid and configuration for the kernel execution.", - "description_2": "Use triton language to create a kernel for rotating half of the head dimension of a tensor in-place, using position ids and frequency calculations.", - "difficulty": 3 - }, - { - "code": "import triton\nimport triton.language as tl\nimport torch\n\n@triton.jit\ndef fusedmatmul_248_kernel(\n a_ptr,\n c_ptr,\n b1_ptr,\n scales1_ptr,\n zeros1_ptr,\n g1_ptr,\n b2_ptr,\n scales2_ptr,\n zeros2_ptr,\n g2_ptr,\n M,\n N,\n K,\n bits,\n maxq,\n stride_am,\n stride_ak,\n stride_bk,\n stride_bn,\n stride_cm,\n stride_cn,\n stride_scales,\n stride_zeros,\n BLOCK_SIZE_M: tl.constexpr,\n BLOCK_SIZE_N: tl.constexpr,\n BLOCK_SIZE_K: tl.constexpr,\n GROUP_SIZE_M: tl.constexpr,\n):\n \"\"\"\n Computes: C = silu(A * B1) * (A * B2)\n A is of shape (M, K) float16\n B is of shape (K//8, N) int32\n C is of shape (M, N) float16\n scales is of shape (1, N) float16\n zeros is of shape (1, N//8) int32\n \"\"\"\n infearure_per_bits = 32 // bits\n\n pid = tl.program_id(axis=0)\n num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)\n num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)\n num_pid_k = tl.cdiv(K, BLOCK_SIZE_K)\n num_pid_in_group = GROUP_SIZE_M * num_pid_n\n group_id = pid // num_pid_in_group\n first_pid_m = group_id * GROUP_SIZE_M\n group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)\n pid_m = first_pid_m + (pid % group_size_m)\n pid_n = (pid % num_pid_in_group) // group_size_m\n\n offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n offs_k = tl.arange(0, BLOCK_SIZE_K)\n a_ptrs = a_ptr + (\n offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak\n ) # (BLOCK_SIZE_M, BLOCK_SIZE_K)\n a_mask = offs_am[:, None] < M\n # b_ptrs is set up such that it repeats elements along the K axis 8 times\n b1_ptrs = b1_ptr + (\n (offs_k[:, None] // infearure_per_bits) * stride_bk\n + offs_bn[None, :] * stride_bn\n )\n b2_ptrs = b2_ptr + (\n (offs_k[:, None] // infearure_per_bits) * stride_bk\n + offs_bn[None, :] * stride_bn\n )\n g1_ptrs = g1_ptr + offs_k\n g2_ptrs = g2_ptr + offs_k\n # shifter is used to extract the N bits of each element in the 32-bit word from B\n scales1_ptrs = scales1_ptr + offs_bn[None, :]\n scales2_ptrs = scales2_ptr + offs_bn[None, :]\n zeros1_ptrs = zeros1_ptr + (offs_bn[None, :] // infearure_per_bits)\n zeros2_ptrs = zeros2_ptr + (offs_bn[None, :] // infearure_per_bits)\n\n shifter = (offs_k % infearure_per_bits) * bits\n zeros_shifter = (offs_bn % infearure_per_bits) * bits\n accumulator1 = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n accumulator2 = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n for k in range(0, num_pid_k):\n g1_idx = tl.load(g1_ptrs)\n g2_idx = tl.load(g2_ptrs)\n\n # Fetch scales and zeros; these are per-outfeature and thus reused in the inner loop\n scales1 = tl.load(\n scales1_ptrs + g1_idx[:, None] * stride_scales\n ) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)\n scales2 = tl.load(scales2_ptrs + g2_idx[:, None] * stride_scales)\n\n zeros1 = tl.load(\n zeros1_ptrs + g1_idx[:, None] * stride_zeros\n ) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)\n zeros1 = (zeros1 >> zeros_shifter[None, :]) & maxq\n zeros1 = zeros1 + 1\n\n zeros2 = tl.load(\n zeros2_ptrs + g2_idx[:, None] * stride_zeros\n ) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)\n zeros2 = (zeros2 >> zeros_shifter[None, :]) & maxq\n zeros2 = zeros2 + 1\n\n a = tl.load(a_ptrs, mask=a_mask, other=0.0) # (BLOCK_SIZE_M, BLOCK_SIZE_K)\n b1 = tl.load(b1_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated\n b2 = tl.load(b2_ptrs)\n\n # Now we need to unpack b (which is N-bit values) into 32-bit values\n b1 = (b1 >> shifter[:, None]) & maxq # Extract the N-bit values\n b1 = (b1 - zeros1) * scales1 # Scale and shift\n accumulator1 += tl.dot(a, b1)\n\n b2 = (b2 >> shifter[:, None]) & maxq\n b2 = (b2 - zeros2) * scales2\n accumulator2 += tl.dot(a, b2)\n\n a_ptrs += BLOCK_SIZE_K\n b1_ptrs += (BLOCK_SIZE_K // infearure_per_bits) * stride_bk\n b2_ptrs += (BLOCK_SIZE_K // infearure_per_bits) * stride_bk\n g1_ptrs += BLOCK_SIZE_K\n g2_ptrs += BLOCK_SIZE_K\n\n accumulator1 = silu(accumulator1)\n c = accumulator1 * accumulator2\n c = c.to(tl.float16)\n c_ptrs = c_ptr + stride_cm * offs_am[:, None] + stride_cn * offs_bn[None, :]\n c_mask = (offs_am[:, None] < M) & (offs_bn[None, :] < N)\n tl.store(c_ptrs, c, mask=c_mask)\n\n@triton.jit\ndef silu(x):\n return x * tl.sigmoid(x)\n\nclass QuantLlamaMLP(nn.Module):\n def __init__(\n self,\n gate_proj,\n down_proj,\n up_proj,\n ):\n super().__init__()\n self.register_buffer(\"gate_proj_qweight\", gate_proj.qweight)\n self.register_buffer(\"gate_proj_scales\", gate_proj.scales)\n self.register_buffer(\"gate_proj_qzeros\", gate_proj.qzeros)\n self.register_buffer(\"gate_proj_g_idx\", gate_proj.g_idx)\n self.register_buffer(\"up_proj_qweight\", up_proj.qweight)\n self.register_buffer(\"up_proj_scales\", up_proj.scales)\n self.register_buffer(\"up_proj_qzeros\", up_proj.qzeros)\n self.register_buffer(\"up_proj_g_idx\", up_proj.g_idx)\n\n self.infeatures = gate_proj.infeatures\n self.intermediate_size = gate_proj.outfeatures\n self.outfeatures = down_proj.outfeatures\n self.bits = gate_proj.bits\n self.maxq = gate_proj.maxq\n\n self.down_proj = down_proj\n\n def forward(self, x):\n return self.down_proj(self.triton_llama_mlp(x))\n\n def triton_llama_mlp(self, x):\n with torch.cuda.device(x.device):\n out_shape = x.shape[:-1] + (self.intermediate_size,)\n x = x.reshape(-1, x.shape[-1])\n M, K = x.shape\n N = self.intermediate_size\n c = torch.empty((M, N), device=x.device, dtype=torch.float16)\n grid = lambda META: (\n triton.cdiv(M, META[\"BLOCK_SIZE_M\"])\n * triton.cdiv(N, META[\"BLOCK_SIZE_N\"]),\n )\n fusedmatmul_248_kernel[grid](\n x,\n c,\n self.gate_proj_qweight,\n self.gate_proj_scales,\n self.gate_proj_qzeros,\n self.gate_proj_g_idx,\n self.up_proj_qweight,\n self.up_proj_scales,\n self.up_proj_qzeros,\n self.up_proj_g_idx,\n M,\n N,\n K,\n self.bits,\n self.maxq,\n x.stride(0),\n x.stride(1),\n self.gate_proj_qweight.stride(0),\n self.gate_proj_qweight.stride(1),\n c.stride(0),\n c.stride(1),\n self.gate_proj_scales.stride(0),\n self.gate_proj_qzeros.stride(0),\n )\n c = c.reshape(out_shape)\n return c\n", - "description_1": "Use triton language to implement a fused matrix multiplication and activation kernel, performing operations on quantized weights. The kernel, 'fusedmatmul_248_kernel', accepts 26 parameters: pointers to input matrices, scaling factors, zero offsets, dimensions (M, N, K), quantization bits, maximum quantization level, and strides for accessing data in memory. The kernel computes a result matrix C = silu(A * B1) * (A * B2) with efficient memory access and parallel computation techniques. 'silu' is a helper kernel implementing the SiLU activation function.", - "description_2": "Use triton language to create a kernel that efficiently computes fused matrix multiplication with quantized weights and applies the SiLU activation, using specific data pointers and memory strides.", - "difficulty": 5 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef matmul_248_kernel(\n a_ptr,\n b_ptr,\n c_ptr,\n scales_ptr,\n zeros_ptr,\n g_ptr,\n M,\n N,\n K,\n bits,\n maxq,\n stride_am,\n stride_ak,\n stride_bk,\n stride_bn,\n stride_cm,\n stride_cn,\n stride_scales,\n stride_zeros,\n BLOCK_SIZE_M: tl.constexpr,\n BLOCK_SIZE_N: tl.constexpr,\n BLOCK_SIZE_K: tl.constexpr,\n GROUP_SIZE_M: tl.constexpr,\n):\n \"\"\"\n Compute the matrix multiplication C = A x B.\n A is of shape (M, K) float16\n B is of shape (K//8, N) int32\n C is of shape (M, N) float16\n scales is of shape (G, N) float16\n zeros is of shape (G, N) float16\n g_ptr is of shape (K) int32\n \"\"\"\n infearure_per_bits = 32 // bits\n\n pid = tl.program_id(axis=0)\n num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)\n num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)\n num_pid_k = tl.cdiv(K, BLOCK_SIZE_K)\n num_pid_in_group = GROUP_SIZE_M * num_pid_n\n group_id = pid // num_pid_in_group\n first_pid_m = group_id * GROUP_SIZE_M\n group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)\n pid_m = first_pid_m + (pid % group_size_m)\n pid_n = (pid % num_pid_in_group) // group_size_m\n\n offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n offs_k = tl.arange(0, BLOCK_SIZE_K)\n a_ptrs = a_ptr + (\n offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak\n ) # (BLOCK_SIZE_M, BLOCK_SIZE_K)\n a_mask = offs_am[:, None] < M\n # b_ptrs is set up such that it repeats elements along the K axis 8 times\n b_ptrs = b_ptr + (\n (offs_k[:, None] // infearure_per_bits) * stride_bk\n + offs_bn[None, :] * stride_bn\n ) # (BLOCK_SIZE_K, BLOCK_SIZE_N)\n g_ptrs = g_ptr + offs_k\n # shifter is used to extract the N bits of each element in the 32-bit word from B\n scales_ptrs = scales_ptr + offs_bn[None, :]\n zeros_ptrs = zeros_ptr + (offs_bn[None, :] // infearure_per_bits)\n\n shifter = (offs_k % infearure_per_bits) * bits\n zeros_shifter = (offs_bn % infearure_per_bits) * bits\n accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n\n for k in range(0, num_pid_k):\n g_idx = tl.load(g_ptrs)\n\n # Fetch scales and zeros; these are per-outfeature and thus reused in the inner loop\n scales = tl.load(\n scales_ptrs + g_idx[:, None] * stride_scales\n ) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)\n zeros = tl.load(\n zeros_ptrs + g_idx[:, None] * stride_zeros\n ) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)\n\n zeros = (zeros >> zeros_shifter[None, :]) & maxq\n zeros = zeros + 1\n\n a = tl.load(a_ptrs, mask=a_mask, other=0.0) # (BLOCK_SIZE_M, BLOCK_SIZE_K)\n b = tl.load(b_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated\n\n # Now we need to unpack b (which is N-bit values) into 32-bit values\n b = (b >> shifter[:, None]) & maxq # Extract the N-bit values\n b = (b - zeros) * scales # Scale and shift\n\n accumulator += tl.dot(a, b)\n a_ptrs += BLOCK_SIZE_K\n b_ptrs += (BLOCK_SIZE_K // infearure_per_bits) * stride_bk\n g_ptrs += BLOCK_SIZE_K\n\n c_ptrs = c_ptr + stride_cm * offs_am[:, None] + stride_cn * offs_bn[None, :]\n c_mask = (offs_am[:, None] < M) & (offs_bn[None, :] < N)\n tl.store(c_ptrs, accumulator, mask=c_mask)\n\n@triton.jit\ndef transpose_matmul_248_kernel(\n a_ptr,\n b_ptr,\n c_ptr,\n scales_ptr,\n zeros_ptr,\n g_ptr,\n M,\n N,\n K,\n bits,\n maxq,\n stride_am,\n stride_ak,\n stride_bk,\n stride_bn,\n stride_cm,\n stride_cn,\n stride_scales,\n stride_zeros,\n BLOCK_SIZE_M: tl.constexpr,\n BLOCK_SIZE_N: tl.constexpr,\n BLOCK_SIZE_K: tl.constexpr,\n GROUP_SIZE_M: tl.constexpr,\n):\n \"\"\"\n Compute the matrix multiplication C = A x B.\n A is of shape (M, N) float16\n B is of shape (K//8, N) int32\n C is of shape (M, K) float16\n scales is of shape (G, N) float16\n zeros is of shape (G, N) float16\n g_ptr is of shape (K) int32\n \"\"\"\n infearure_per_bits = 32 // bits\n\n pid = tl.program_id(axis=0)\n num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)\n num_pid_k = tl.cdiv(K, BLOCK_SIZE_K)\n num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)\n num_pid_in_group = GROUP_SIZE_M * num_pid_k\n group_id = pid // num_pid_in_group\n first_pid_m = group_id * GROUP_SIZE_M\n group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)\n pid_m = first_pid_m + (pid % group_size_m)\n pid_k = (pid % num_pid_in_group) // group_size_m\n\n offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_bk = pid_k * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)\n offs_n = tl.arange(0, BLOCK_SIZE_N)\n a_ptrs = a_ptr + (\n offs_am[:, None] * stride_am + offs_n[None, :] * stride_ak\n ) # (BLOCK_SIZE_M, BLOCK_SIZE_N)\n a_mask = offs_am[:, None] < M\n # b_ptrs is set up such that it repeats elements along the K axis 8 times\n b_ptrs = b_ptr + (\n (offs_bk[:, None] // infearure_per_bits) * stride_bk\n + offs_n[None, :] * stride_bn\n ) # (BLOCK_SIZE_K, BLOCK_SIZE_N)\n g_ptrs = g_ptr + offs_bk\n g_idx = tl.load(g_ptrs)\n\n # shifter is used to extract the N bits of each element in the 32-bit word from B\n scales_ptrs = scales_ptr + offs_n[None, :] + g_idx[:, None] * stride_scales\n zeros_ptrs = (\n zeros_ptr\n + (offs_n[None, :] // infearure_per_bits)\n + g_idx[:, None] * stride_zeros\n )\n\n shifter = (offs_bk % infearure_per_bits) * bits\n zeros_shifter = (offs_n % infearure_per_bits) * bits\n accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_K), dtype=tl.float32)\n\n for n in range(0, num_pid_n):\n # Fetch scales and zeros; these are per-outfeature and thus reused in the inner loop\n scales = tl.load(scales_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)\n zeros = tl.load(zeros_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)\n\n zeros = (zeros >> zeros_shifter[None, :]) & maxq\n zeros = zeros + 1\n\n a = tl.load(a_ptrs, mask=a_mask, other=0.0) # (BLOCK_SIZE_M, BLOCK_SIZE_N)\n b = tl.load(b_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated\n\n # Now we need to unpack b (which is N-bit values) into 32-bit values\n b = (b >> shifter[:, None]) & maxq # Extract the N-bit values\n b = (b - zeros) * scales # Scale and shift\n b = tl.trans(b)\n\n accumulator += tl.dot(a, b)\n a_ptrs += BLOCK_SIZE_N\n b_ptrs += BLOCK_SIZE_N\n scales_ptrs += BLOCK_SIZE_N\n zeros_ptrs += BLOCK_SIZE_N // infearure_per_bits\n\n c_ptrs = c_ptr + stride_cm * offs_am[:, None] + stride_cn * offs_bk[None, :]\n c_mask = (offs_am[:, None] < M) & (offs_bk[None, :] < K)\n tl.store(c_ptrs, accumulator, mask=c_mask)\n\ndef matmul248(input, qweight, scales, qzeros, g_idx, bits, maxq):\n with torch.cuda.device(input.device):\n output = torch.empty(\n (input.shape[0], qweight.shape[1]), device=input.device, dtype=torch.float16\n )\n grid = lambda META: (\n triton.cdiv(input.shape[0], META[\"BLOCK_SIZE_M\"])\n * triton.cdiv(qweight.shape[1], META[\"BLOCK_SIZE_N\"]),\n )\n matmul_248_kernel[grid](\n input,\n qweight,\n output,\n scales,\n qzeros,\n g_idx,\n input.shape[0],\n qweight.shape[1],\n input.shape[1],\n bits,\n maxq,\n input.stride(0),\n input.stride(1),\n qweight.stride(0),\n qweight.stride(1),\n output.stride(0),\n output.stride(1),\n scales.stride(0),\n qzeros.stride(0),\n )\n return output\n\ndef transpose_matmul248(input, qweight, scales, qzeros, g_idx, bits, maxq):\n with torch.cuda.device(input.device):\n output_dim = (qweight.shape[0] * 32) // bits\n output = torch.empty(\n (input.shape[0], output_dim), device=input.device, dtype=torch.float16\n )\n grid = lambda META: (\n triton.cdiv(input.shape[0], META[\"BLOCK_SIZE_M\"])\n * triton.cdiv(output_dim, META[\"BLOCK_SIZE_K\"]),\n )\n transpose_matmul_248_kernel[grid](\n input,\n qweight,\n output,\n scales,\n qzeros,\n g_idx,\n input.shape[0],\n qweight.shape[1],\n output_dim,\n bits,\n maxq,\n input.stride(0),\n input.stride(1),\n qweight.stride(0),\n qweight.stride(1),\n output.stride(0),\n output.stride(1),\n scales.stride(0),\n qzeros.stride(0),\n )\n return output\n", - "description_1": "Use triton language to implement two matrix multiplication kernels: 'matmul_248_kernel' and 'transpose_matmul_248_kernel'. The first kernel computes C = A x B where A is a float16 matrix of shape (M, K), B is an int32 matrix of shape (K//8, N), and C is a float16 matrix of shape (M, N). The second kernel computes C = A x B where A is a float16 matrix of shape (M, N), B is an int32 matrix of shape (K//8, N), and C is a float16 matrix of shape (M, K). Both kernels use additional parameters for scaling and zero-point adjustments, and they are optimized for specific block sizes and group sizes.", - "description_2": "Use triton language to create optimized matrix multiplication kernels for quantized matrices, handling scaling and zero-point adjustments, with specific block and group sizes for performance tuning.", - "difficulty": 4 - }, - { - "code": "import triton\nimport triton.language as tl\nimport torch\n\n# Triton kernel to copy key-value index to request\n@triton.jit\ndef _fwd_kernel_copy_kv_index_to_req(\n req_to_token_indexs, b_req_idx, b_seq_len, memindex,\n stride_req_to_token_b, stride_req_to_token_s\n):\n # Get the current program index\n cur_index = tl.program_id(0)\n # Load the current request index, token index, and sequence length\n cur_req_idx = tl.load(b_req_idx + cur_index)\n cur_token_index = tl.load(memindex + cur_index)\n cur_seq_len = tl.load(b_seq_len + cur_index)\n # Calculate the destination offset\n dest_offset = req_to_token_indexs + cur_req_idx * stride_req_to_token_b + (cur_seq_len - 1) * stride_req_to_token_s\n # Store the token index at the calculated offset\n tl.store(dest_offset, cur_token_index)\n return\n\n# Function to invoke the Triton kernel\n@torch.no_grad()\ndef copy_kv_index_to_req(req_to_token_indexs, b_req_idx, b_seq_len, memindex):\n # Get the sequence length\n seq_len = b_seq_len.shape[0]\n # Ensure the shapes of the inputs are consistent\n assert b_seq_len.shape[0] == memindex.shape[0] and b_req_idx.shape[0] == b_seq_len.shape[0]\n # Define the grid size for the Triton kernel\n grid = (seq_len,)\n num_warps = 1\n\n # Launch the Triton kernel\n _fwd_kernel_copy_kv_index_to_req[grid](\n req_to_token_indexs, b_req_idx, b_seq_len, memindex,\n req_to_token_indexs.stride(0), req_to_token_indexs.stride(1),\n num_warps=num_warps,\n num_stages=1,\n )\n return\n", - "description_1": "Use triton language to implement a kernel that copies key-value indices to a request tensor. The kernel takes six parameters: req_to_token_indexs (destination tensor), b_req_idx (request indices), b_seq_len (sequence lengths), memindex (memory indices), stride_req_to_token_b (stride for batch dimension), and stride_req_to_token_s (stride for sequence dimension). The kernel calculates the destination offset and stores the token index at this offset. The function copy_kv_index_to_req sets up the grid and launches the kernel.", - "description_2": "Use triton language to create a kernel for copying indices with parameters for destination tensor, request indices, sequence lengths, memory indices, and strides. Implement a function to configure and launch this kernel.", - "difficulty": 1 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.autotune(\n configs=[\n triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=2, num_warps=8),\n triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8),\n triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=2, num_warps=8),\n triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=2, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8), \n ],\n key=['M', 'N', 'K', 'NO_GROUPS'],\n)\n@triton.jit\ndef matmul4_kernel(\n a_ptr, b_ptr, c_ptr,\n scales_ptr, zeros_ptr,\n M, N, K,\n stride_am, stride_ak,\n stride_bk, stride_bn,\n stride_cm, stride_cn,\n stride_scales_g, stride_scales_n,\n stride_zeros_g, stride_zeros_n,\n groupsize, NO_GROUPS: tl.constexpr,\n BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,\n GROUP_SIZE_M: tl.constexpr,\n):\n bits = 4\n infearure_per_bits = 8\n pid = tl.program_id(axis=0)\n num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)\n num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)\n num_pid_k = tl.cdiv(K, BLOCK_SIZE_K)\n num_pid_in_group = GROUP_SIZE_M * num_pid_n\n group_id = pid // num_pid_in_group\n first_pid_m = group_id * GROUP_SIZE_M\n group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)\n pid_m = first_pid_m + (pid % group_size_m)\n pid_n = (pid % num_pid_in_group) // group_size_m \n offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n offs_k = tl.arange(0, BLOCK_SIZE_K)\n a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) \n a_mask = (offs_am[:, None] < M)\n b_ptrs = b_ptr + ((offs_k[:, None] // infearure_per_bits) * stride_bk + offs_bn[None, :] * stride_bn) \n scales_ptrs = scales_ptr + offs_bn * stride_scales_n \n zeros_ptrs = zeros_ptr + ((offs_bn // infearure_per_bits) * stride_zeros_n) \n shifter = (offs_k % infearure_per_bits) * bits\n zeros_shifter = (offs_bn % infearure_per_bits) * bits\n if NO_GROUPS:\n scales = tl.load(scales_ptrs) \n zeros = tl.load(zeros_ptrs) \n zeros = (zeros >> zeros_shifter) & 0xF \n zeros = zeros * scales\n accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n for k in range(0, num_pid_k):\n a = tl.load(a_ptrs, mask=a_mask, other=0.) \n b = tl.load(b_ptrs) \n if not NO_GROUPS:\n g_id = k // (groupsize // BLOCK_SIZE_K)\n ptr = scales_ptrs + g_id * stride_scales_g\n scales = tl.load(ptr) \n ptr = zeros_ptrs + g_id * stride_zeros_g \n zeros = tl.load(ptr) \n zeros = (zeros >> zeros_shifter) & 0xF \n zeros = (zeros) * scales \n b = (b >> shifter[:, None]) & 0xF \n b = b * scales[None, :] - zeros[None, :] \n accumulator += tl.dot(a, b.to(a.dtype))\n a_ptrs += BLOCK_SIZE_K * stride_ak\n b_ptrs += (BLOCK_SIZE_K // infearure_per_bits) * stride_bk \n c = accumulator.to(c_ptr.dtype.element_ty) \n offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]\n c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)\n tl.store(c_ptrs, accumulator, mask=c_mask)\n\n\ndef matmul_dequantize_int4_gptq(x: torch.FloatTensor, qweight: torch.IntTensor, scales: torch.FloatTensor, qzeros: torch.IntTensor, group_size, output=None) -> torch.FloatTensor:\n assert x.shape[-1] == (qweight.shape[0] * 8), \"A must be a multiple of 8 in the last dimension\"\n assert x.is_contiguous(), \"A must be contiguous\"\n\n M, K = x.shape\n N = qweight.shape[1]\n\n if output is None:\n inplace = False\n output = torch.empty((M, N), device=x.device, dtype=x.dtype)\n else:\n inplace = True\n\n grid = lambda META: (\n triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']),\n )\n matmul4_kernel[grid](\n x, qweight, output,\n scales, qzeros,\n M, N, K,\n x.stride(0), x.stride(1),\n qweight.stride(0), qweight.stride(1),\n output.stride(0), output.stride(1),\n scales.stride(0), scales.stride(1),\n qzeros.stride(0), qzeros.stride(1),\n group_size, group_size == K,\n )\n if not inplace:\n return output\n\n\n@triton.autotune(\n configs=[\n triton.Config({'SPLIT_K': 1, 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),\n triton.Config({'SPLIT_K': 1, 'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),\n triton.Config({'SPLIT_K': 1, 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),\n triton.Config({'SPLIT_K': 1, 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=2, num_warps=8),\n triton.Config({'SPLIT_K': 1, 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8),\n triton.Config({'SPLIT_K': 1, 'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=2, num_warps=4),\n triton.Config({'SPLIT_K': 1, 'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 256, 'GROUP_SIZE_M': 8}, num_stages=2, num_warps=4),\n triton.Config({'SPLIT_K': 1, 'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 256, 'GROUP_SIZE_M': 16}, num_stages=2, num_warps=4),\n triton.Config({'SPLIT_K': 1, 'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 512, 'GROUP_SIZE_M': 16}, num_stages=2, num_warps=4),\n triton.Config({'SPLIT_K': 2, 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),\n triton.Config({'SPLIT_K': 2, 'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),\n triton.Config({'SPLIT_K': 2, 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),\n triton.Config({'SPLIT_K': 2, 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=2, num_warps=8),\n triton.Config({'SPLIT_K': 2, 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8),\n triton.Config({'SPLIT_K': 2, 'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=2, num_warps=4),\n triton.Config({'SPLIT_K': 2, 'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 256, 'GROUP_SIZE_M': 8}, num_stages=2, num_warps=4),\n triton.Config({'SPLIT_K': 2, 'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 256, 'GROUP_SIZE_M': 16}, num_stages=2, num_warps=4),\n triton.Config({'SPLIT_K': 2, 'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 512, 'GROUP_SIZE_M': 16}, num_stages=2, num_warps=4),\n triton.Config({'SPLIT_K': 1, 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 16, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),\n triton.Config({'SPLIT_K': 1, 'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 16, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),\n triton.Config({'SPLIT_K': 1, 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 16, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),\n triton.Config({'SPLIT_K': 1, 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 16, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=2, num_warps=8),\n triton.Config({'SPLIT_K': 1, 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 16, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8),\n triton.Config({'SPLIT_K': 1, 'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 16, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=2, num_warps=4),\n triton.Config({'SPLIT_K': 1, 'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 16, 'BLOCK_SIZE_K': 256, 'GROUP_SIZE_M': 8}, num_stages=2, num_warps=4),\n triton.Config({'SPLIT_K': 1, 'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 16, 'BLOCK_SIZE_K': 256, 'GROUP_SIZE_M': 16}, num_stages=2, num_warps=4),\n triton.Config({'SPLIT_K': 1, 'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 16, 'BLOCK_SIZE_K': 512, 'GROUP_SIZE_M': 16}, num_stages=2, num_warps=4),\n triton.Config({'SPLIT_K': 2, 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 16, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),\n triton.Config({'SPLIT_K': 2, 'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 16, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),\n triton.Config({'SPLIT_K': 2, 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 16, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),\n triton.Config({'SPLIT_K': 2, 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 16, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=2, num_warps=8),\n triton.Config({'SPLIT_K': 2, 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 16, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8),\n triton.Config({'SPLIT_K': 2, 'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 16, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=2, num_warps=4),\n triton.Config({'SPLIT_K': 2, 'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 16, 'BLOCK_SIZE_K': 256, 'GROUP_SIZE_M': 8}, num_stages=2, num_warps=4),\n triton.Config({'SPLIT_K': 2, 'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 16, 'BLOCK_SIZE_K': 256, 'GROUP_SIZE_M': 16}, num_stages=2, num_warps=4),\n triton.Config({'SPLIT_K': 2, 'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 16, 'BLOCK_SIZE_K': 512, 'GROUP_SIZE_M': 16}, num_stages=2, num_warps=4),\n ],\n key=['M', 'N', 'K'],\n reset_to_zero=['c_ptr']\n)\n@triton.jit\ndef matmul_kernel(\n a_ptr, b_ptr, c_ptr,\n bs_ptr, bzp_ptr,\n M, N, K,\n stride_am, stride_ak,\n stride_bk, stride_bn,\n stride_cm, stride_cn,\n stride_bsk, stride_bsn,\n stride_bzpk, stride_bzpn,\n group_size,\n BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,\n GROUP_SIZE_M: tl.constexpr, SPLIT_K: tl.constexpr\n ):\n pid = tl.program_id(axis=0)\n pid_sp_k = tl.program_id(axis=1)\n num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)\n num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)\n num_pid_k = tl.cdiv(K, BLOCK_SIZE_K)\n num_pid_in_group = GROUP_SIZE_M * num_pid_n\n group_id = pid // num_pid_in_group\n first_pid_m = group_id * GROUP_SIZE_M\n group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)\n pid_m = first_pid_m + (pid % group_size_m)\n pid_n = (pid % num_pid_in_group) // group_size_m \n offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M\n offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N\n offs_k = pid_sp_k * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)\n\n a_ptrs = a_ptr + offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak\n b_ptrs = b_ptr + (offs_k[:, None] // 8) * stride_bk + offs_bn[None, :] * stride_bn\n\n accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n for k in range(0, tl.cdiv(K, BLOCK_SIZE_K * SPLIT_K)):\n bs_ptrs = bs_ptr + ((offs_k[:, None] + k * BLOCK_SIZE_K * SPLIT_K) // group_size) * stride_bsk \\\n + offs_bn[None, :] * stride_bsn\n bzp_ptrs = bzp_ptr + ((offs_k[:, None] + k * BLOCK_SIZE_K * SPLIT_K) // group_size) * stride_bzpk \\\n + (offs_bn[None, :] // 8) * stride_bzpn\n b_shift_bits = (offs_k[:, None] % 8) * 4\n bzp_shift_bits = (offs_bn[None, :] % 8) * 4\n a = tl.load(a_ptrs)\n b = tl.load(b_ptrs)\n bs = tl.load(bs_ptrs)\n bzp = tl.load(bzp_ptrs)\n int_b = (b >> b_shift_bits) & 0xF\n int_bzp = (bzp >> bzp_shift_bits) & 0xF\n b = ((int_b - int_bzp) * bs).to(a.dtype)\n accumulator += tl.dot(a, b.to(a.dtype))\n a_ptrs += BLOCK_SIZE_K * SPLIT_K * stride_ak\n b_ptrs += (BLOCK_SIZE_K * SPLIT_K * stride_bk // 8) \n c = accumulator.to(c_ptr.dtype.element_ty)\n\n offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]\n c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)\n if SPLIT_K == 1:\n tl.store(c_ptrs, c, mask=c_mask)\n else:\n tl.atomic_add(c_ptrs, c, mask=c_mask)\n\n\ndef matmul_dequantize_int4_s2(x: torch.FloatTensor, qweight: torch.IntTensor, scales: torch.FloatTensor, qzeros: torch.IntTensor, group_size: int = 128, output=None) -> torch.FloatTensor:\n assert x.is_contiguous(), \"A must be contiguous\"\n assert qweight.is_contiguous(), \"B must be contiguous\" \n M, K = x.shape\n N = scales.shape[1]\n if output is None:\n output = torch.zeros((M, N), device=x.device, dtype=x.dtype) \n grid = lambda META: (\n triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']),\n META['SPLIT_K'],\n )\n matmul_kernel[grid](\n x, qweight, output,\n scales, qzeros,\n M, N, K,\n x.stride(0), x.stride(1),\n qweight.stride(0), qweight.stride(1),\n output.stride(0), output.stride(1),\n scales.stride(0), scales.stride(1),\n qzeros.stride(0), qzeros.stride(1),\n group_size,\n )\n return output\n\n\n@triton.autotune(\n configs=[\n triton.Config({'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 64}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2),\n triton.Config({'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 64}, num_stages=5, num_warps=2),\n triton.Config({'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2),\n triton.Config({'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2),\n ],\n key=['K', 'N'],\n)\n@triton.jit\ndef dequantize_kernel(\n b_ptr, b_scale_ptr, b_zp_ptr, fpb_ptr,\n K, N, group_size,\n stride_bk, stride_bn,\n stride_bsk, stride_bsn,\n stride_bzpk, stride_bzpn,\n stride_fpbk, stride_fpbn,\n BLOCK_SIZE_K: tl.constexpr, BLOCK_SIZE_N: tl.constexpr,\n):\n k_block_idx = tl.program_id(axis=0)\n n_block_idx = tl.program_id(axis=1)\n offs_k = k_block_idx * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)\n offs_n = n_block_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n fpb_offs = offs_k[:, None] * stride_fpbk + offs_n[None, :] * stride_fpbn\n b_offs = (offs_k[:, None] // 8) * stride_bk + offs_n[None, :] * stride_bn\n bzp_offs = (offs_k[:, None] // group_size) * stride_bzpk + (offs_n[None, :] // 8) * stride_bzpn\n bs_offs = (offs_k[:, None] // group_size) * stride_bsk + offs_n[None, :] * stride_bsn\n n_mask = offs_n[None, :] < N\n k_mask = offs_k[:, None] < K\n mask = n_mask & k_mask\n int32_b = tl.load(b_ptr + b_offs, mask=mask, other=0.0)\n zp_b = tl.load(b_zp_ptr + bzp_offs, mask=mask, other=0.0)\n scale_b = tl.load(b_scale_ptr + bs_offs, mask=mask, other=0.0)\n b_shift = (offs_k[:, None] % 8) * 4\n bzp_shift = (offs_n[None, :] % 8) * 4\n fp_weight = (((int32_b >> b_shift) & 0xF) - ((zp_b >> bzp_shift) & 0xF)) * scale_b\n tl.store(fpb_ptr + fpb_offs, fp_weight, mask=mask)\n\n\ndef dequantize_int4(b, b_scale, b_zero_point, device, dtype, group_size):\n Kw, N = b.shape\n K = Kw * 8\n fp_b = torch.ones((K, N), device=device, dtype=dtype)\n grid = lambda META: (\n triton.cdiv(K, META['BLOCK_SIZE_K']),\n triton.cdiv(N, META['BLOCK_SIZE_N']), \n )\n dequantize_kernel[grid](\n b, b_scale, b_zero_point, fp_b,\n K, N, group_size,\n b.stride(0), b.stride(1),\n b_scale.stride(0), b_scale.stride(1),\n b_zero_point.stride(0), b_zero_point.stride(1),\n fp_b.stride(0), fp_b.stride(1)\n )\n return fp_b\n\n\ndef matmul_dequantize_int4_s1(a, b, b_scale, b_zero_point, group_size=128, out=None):\n assert a.is_contiguous(), \"Matrix A must be contiguous\"\n M, K = a.shape\n Kw, N = b.shape\n if out is None:\n out = torch.empty((M, N), device=a.device, dtype=a.dtype)\n fp_b = dequantize_int4(b, b_scale, b_zero_point, a.device, a.dtype, group_size)\n torch.mm(a, fp_b, out=out)\n fp_b = None\n return out\n", - "description_1": "Use triton language to implement three different matrix multiplication and dequantization operations for int4 quantized matrices: 1) matmul4_kernel handles int4 quantized B and scales/zeros arrays for dequantization, assumes M, N, K are multiples of respective block sizes, requires scale and zero pointers, computes C = A x B using quantized values; 2) matmul_kernel splits K dimension across multiple blocks, uses scales and zero points for dequantization, handles int4 quantized B; 3) dequantize_kernel dequantizes int4 weight matrices by converting them into full precision for further computation, each kernel assumes specific block sizes and parameters as specified in function signatures.", - "description_2": "Use triton language to implement matrix multiplication and dequantization of int4 quantized matrices by handling different block configurations and optimizing kernel execution with specific constants and constraints.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef dequantize_kernel(\n # Pointers to matrices\n b_ptr, b_scale_ptr, fpb_ptr,\n # Matrix dimensions\n K, N,\n stride_bk, stride_bn,\n stride_fpbk, stride_fpbn,\n # Meta-parameters\n BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,\n):\n \"\"\"Kernel for computing the matmul C = A x B.\n A has shape (M, K), B has shape (K, N) and C has shape (M, N)\n \"\"\"\n k_block_idx = tl.program_id(axis=0)\n n_block_idx = tl.program_id(axis=1)\n offs_k = tl.arange(0, BLOCK_SIZE_K)\n offs_n = tl.arange(0, BLOCK_SIZE_N)\n b_offs = (k_block_idx * BLOCK_SIZE_K + offs_k[:, None]) * stride_bk + \\\n (n_block_idx * BLOCK_SIZE_N + offs_n[None, :]) * stride_bn\n fpb_offs = (k_block_idx * BLOCK_SIZE_K + offs_k[:, None]) * stride_fpbk + \\\n (n_block_idx * BLOCK_SIZE_N + offs_n[None, :]) * stride_fpbn\n bs_offs = n_block_idx * BLOCK_SIZE_N + offs_n[None, :]\n n_mask = n_block_idx * BLOCK_SIZE_N + offs_n[None, :] < N\n mask = (k_block_idx * BLOCK_SIZE_K + offs_k[:, None] < K) & n_mask\n int_b = tl.load(b_ptr + b_offs, mask=mask, other=0.0)\n scale_b = tl.load(b_scale_ptr + bs_offs, mask=n_mask, other=0.0)\n tl.store(fpb_ptr + fpb_offs, int_b * scale_b, mask=mask)\n\ndef matmul_dequantize_int8(a, b, b_scale, out=None):\n # Check constraints.\n assert a.shape[1] == b.shape[0], \"Incompatible dimensions\"\n assert a.is_contiguous(), \"Matrix A must be contiguous\"\n M, K = a.shape\n K, N = b.shape\n if out == None:\n # Allocates output.\n c = torch.empty((M, N), device=a.device, dtype=a.dtype)\n else:\n c = out\n fp_b = torch.empty((K, N), device=a.device, dtype=a.dtype)\n grid = lambda META: (\n triton.cdiv(K, META['BLOCK_SIZE_K']), triton.cdiv(N, META['BLOCK_SIZE_N']),\n )\n dequantize_kernel[grid](\n b, b_scale, fp_b,\n K, N,\n b.stride(0), b.stride(1),\n fp_b.stride(0), fp_b.stride(1)\n )\n torch.mm(a, fp_b, out=c)\n return c\n", - "description_1": "Use triton language to implement a kernel function 'dequantize_kernel' that dequantizes an int8 matrix 'b' using a scale matrix 'b_scale' and stores the result in 'fpb'. The kernel takes 10 parameters: pointers to matrices (b_ptr, b_scale_ptr, fpb_ptr), matrix dimensions (K, N), strides for matrices (stride_bk, stride_bn, stride_fpbk, stride_fpbn), and block sizes (BLOCK_SIZE_N, BLOCK_SIZE_K). The function 'matmul_dequantize_int8' calls this kernel to perform matrix multiplication with dequantization, taking 4 parameters: matrices 'a', 'b', 'b_scale', and an optional output matrix 'out'.", - "description_2": "Use triton language to create a dequantization kernel for int8 matrices and a function to perform matrix multiplication with dequantization.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n# Triton kernel for copying values with destination index\n@triton.jit\ndef _fwd_kernel_destindex_copy_kv(\n K, Dest_loc,\n Out,\n stride_k_bs, stride_k_h, stride_k_d,\n stride_o_bs, stride_o_h, stride_o_d,\n head_num,\n BLOCK_DMODEL: tl.constexpr,\n BLOCK_HEAD: tl.constexpr\n):\n cur_index = tl.program_id(0)\n offs_h = tl.arange(0, BLOCK_HEAD)\n offs_d = tl.arange(0, BLOCK_DMODEL)\n\n dest_index = tl.load(Dest_loc + cur_index)\n\n k_ptrs = K + cur_index * stride_k_bs + stride_k_h * offs_h[:, None] + stride_k_d * offs_d[None, :]\n o_ptrs = Out + dest_index * stride_o_bs + stride_o_h * offs_h[:, None] + stride_o_d * offs_d[None, :]\n\n k = tl.load(k_ptrs, mask=offs_h[:, None] < head_num, other=0.0)\n tl.store(o_ptrs, k, mask=offs_h[:, None] < head_num)\n return\n\n# Python wrapper for the Triton kernel\n@torch.no_grad()\ndef destindex_copy_kv(K, DestLoc, Out):\n seq_len = DestLoc.shape[0]\n head_num = K.shape[1]\n head_dim = K.shape[2]\n assert K.shape[1] == Out.shape[1] and K.shape[2] == Out.shape[2]\n BLOCK_HEAD = triton.next_power_of_2(head_num)\n grid = (seq_len,)\n num_warps = 1\n\n _fwd_kernel_destindex_copy_kv[grid](\n K, DestLoc, Out,\n K.stride(0), K.stride(1), K.stride(2),\n Out.stride(0), Out.stride(1), Out.stride(2),\n head_num,\n BLOCK_DMODEL=head_dim,\n BLOCK_HEAD=BLOCK_HEAD,\n num_warps=num_warps,\n num_stages=1,\n )\n return\n\n# Triton kernel for copying and quantizing values with destination index\n@triton.jit\ndef _fwd_kernel_destindex_copy_quantize_kv(\n K, Dest_loc, Out, Out_scale,\n stride_k_bs, stride_k_h, stride_k_d,\n stride_o_bs, stride_o_h, stride_o_d,\n stride_os_bs, stride_os_h, stride_os_d,\n head_num,\n BLOCK_DMODEL: tl.constexpr,\n BLOCK_HEAD: tl.constexpr\n):\n cur_index = tl.program_id(0)\n offs_h = tl.arange(0, BLOCK_HEAD)\n offs_d = tl.arange(0, BLOCK_DMODEL)\n\n dest_index = tl.load(Dest_loc + cur_index)\n src_data = tl.load(K + cur_index * stride_k_bs + offs_h[:, None] * stride_k_h + stride_k_d * offs_d[None, :], \n mask=offs_h[:, None] < head_num, other=0.0)\n abs_data = tl.abs(src_data)\n data_scale = (tl.max(abs_data, axis=1) / 127.).to(Out_scale.dtype.element_ty)[:, None]\n q_src_data = (src_data / data_scale).to(tl.int8)\n o_ptrs = Out + dest_index * stride_o_bs + stride_o_h * offs_h[:, None] + stride_o_d * offs_d[None, :]\n os_ptrs = Out_scale + dest_index * stride_os_bs + stride_os_h * offs_h[:, None]\n tl.store(o_ptrs, q_src_data, mask=offs_h[:, None] < head_num)\n tl.store(os_ptrs, data_scale, mask=offs_h[:, None] < head_num)\n\n# Python wrapper for the Triton kernel\n@torch.no_grad()\ndef destindex_copy_quantize_kv(K, DestLoc, Out, Out_scale):\n seq_len = DestLoc.shape[0]\n head_num = K.shape[1]\n head_dim = K.shape[2]\n assert K.shape[1] == Out.shape[1] and K.shape[2] == Out.shape[2]\n BLOCK_HEAD = triton.next_power_of_2(head_num)\n grid = (seq_len,)\n num_warps = 1\n\n _fwd_kernel_destindex_copy_quantize_kv[grid](\n K, DestLoc, Out, Out_scale,\n K.stride(0), K.stride(1), K.stride(2),\n Out.stride(0), Out.stride(1), Out.stride(2),\n Out_scale.stride(0), Out_scale.stride(1), Out_scale.stride(2),\n head_num,\n BLOCK_DMODEL=head_dim,\n BLOCK_HEAD=BLOCK_HEAD,\n num_warps=num_warps,\n num_stages=1,\n )\n return\n", - "description_1": "Use triton language to implement kernels for copying values based on destination indices and for quantizing values. The first kernel `_fwd_kernel_destindex_copy_kv` takes 10 parameters (three tensors, three strides for source, three strides for destination, the head number, and two constexprs for block sizes). It copies values from source `K` to destination `Out` using indices from `Dest_loc`. The second kernel `_fwd_kernel_destindex_copy_quantize_kv` extends this functionality by quantizing the values before copying. It takes 13 parameters, including an additional output scale tensor and its strides, performs quantization, and stores both the quantized data and its scale.", - "description_2": "Use triton language to develop a copy kernel with index-based addressing for tensor manipulation. Implement a quantization kernel that scales and converts data to int8 format before storing.", - "difficulty": 2 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef _fwd_kernel(\n Prompt_ids, \n Text_weight_embs,\n Img_embs,\n Out,\n Img_token_lens,\n Img_start_token_ids,\n Img_start_locs,\n stride_text_emb_s, stride_text_emb_d, # text_stride\n stride_img_emb_s, stride_img_emb_d, # img_stride\n stride_out_s, stride_out_d,\n tp_text_start_token_id,\n tp_text_end_token_id,\n hidden_size,\n BLOCK_HIDDEN_DIM: tl.constexpr\n ):\n seq_index = tl.program_id(0).to(tl.int64)\n img_handle_id = tl.program_id(1)\n\n token_id = tl.load(Prompt_ids + seq_index)\n off_d = tl.arange(0, BLOCK_HIDDEN_DIM)\n \n # load store text emb\n for _ in range(0, tl.where((img_handle_id == 0) & (token_id < tp_text_end_token_id) & (token_id >= tp_text_start_token_id), 1, 0), 1):\n load_emb = tl.load(Text_weight_embs + stride_text_emb_s * (token_id - tp_text_start_token_id) + off_d * stride_text_emb_d, mask=off_d < hidden_size, other=0)\n tl.store(Out + stride_out_s * seq_index + stride_out_d * off_d, load_emb, mask=off_d < hidden_size)\n \n img_start_token_id = tl.load(Img_start_token_ids + img_handle_id - 1, mask=img_handle_id >= 1, other=0)\n img_start_loc = tl.load(Img_start_locs + img_handle_id - 1, mask=img_handle_id >= 1, other=0)\n img_token_len = tl.load(Img_token_lens + img_handle_id - 1, mask=img_handle_id >= 1, other=0)\n # load store img emb\n for _ in range(0, tl.where((img_handle_id != 0) & (token_id >= img_start_token_id) & (token_id < img_start_token_id + img_token_len), 1, 0), 1):\n load_emb = tl.load(Img_embs + stride_img_emb_s * (img_start_loc + token_id - img_start_token_id) + off_d * stride_img_emb_d, mask=off_d < hidden_size, other=0)\n tl.store(Out + stride_out_s * seq_index + stride_out_d * off_d, load_emb, mask=off_d < hidden_size)\n return\n\n@torch.no_grad()\ndef multimodal_emb(out: torch.Tensor, prompt_ids: torch.Tensor, text_weight_embs: torch.Tensor, img_embs: torch.Tensor, \n img_token_lens: torch.Tensor, img_start_token_ids: torch.Tensor, img_start_locs: torch.Tensor, \n tp_text_start_token_id,\n tp_text_end_token_id):\n total_len = prompt_ids.shape[0]\n BLOCK = triton.next_power_of_2(out.shape[1])\n grid = (total_len, len(img_token_lens) + 1)\n num_warps = 1\n _fwd_kernel[grid](\n prompt_ids,\n text_weight_embs,\n img_embs,\n out,\n img_token_lens,\n img_start_token_ids,\n img_start_locs,\n text_weight_embs.stride(0), text_weight_embs.stride(1),\n img_embs.stride(0), img_embs.stride(1),\n out.stride(0), out.stride(1),\n tp_text_start_token_id,\n tp_text_end_token_id,\n hidden_size=out.shape[1],\n BLOCK_HIDDEN_DIM=BLOCK,\n num_warps=num_warps,\n num_stages=1,\n )\n return\n", - "description_1": "Use triton language to implement a kernel function '_fwd_kernel' that processes prompt IDs, text and image embeddings, and stores the results in an output tensor. The kernel takes 16 parameters including input tensors, strides, token IDs, hidden size, and a block dimension. The function 'multimodal_emb' is a wrapper that sets up the grid and block dimensions and calls the kernel with appropriate arguments.", - "description_2": "Use triton language to create a kernel for embedding processing with 16 parameters, and a wrapper function to configure and invoke the kernel.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.autotune(\n configs=[\n triton.Config({}, num_stages=2, num_warps=8),\n triton.Config({}, num_stages=2, num_warps=4),\n triton.Config({}, num_stages=2, num_warps=2),\n triton.Config({}, num_stages=2, num_warps=1),\n ],\n key=['K'],\n)\n@triton.jit\ndef quantize_int8_perrow_kernel(\n fpa_ptr, a_ptr, as_ptr,\n M, K, \n stride_fpam, stride_fpak,\n stride_am, stride_ak,\n stride_asm,\n BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,\n):\n pid_m = tl.program_id(axis=0)\n offs_k = tl.arange(0, BLOCK_SIZE_K)\n offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M\n\n fpa_ptrs = fpa_ptr + offs_am[:, None] * stride_fpam + offs_k[None, :] * stride_fpak\n a_ptrs = a_ptr + offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak\n a_max = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32)\n for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):\n fpa = tl.load(fpa_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)\n a_max = tl.maximum(a_max, tl.max(tl.abs(fpa), axis=1))\n fpa_ptrs += BLOCK_SIZE_K * stride_fpak\n a_scale = (a_max / 127.)\n fpa_ptrs = fpa_ptr + offs_am[:, None] * stride_fpam + offs_k[None, :] * stride_fpak\n for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):\n fpa = tl.load(fpa_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)\n inta = (fpa / a_scale[:, None]).to(tl.int8)\n tl.store(a_ptrs, inta, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K)\n fpa_ptrs += BLOCK_SIZE_K * stride_fpak\n a_ptrs += BLOCK_SIZE_K * stride_ak\n as_offs = pid_m * BLOCK_SIZE_M * stride_asm + tl.arange(0, BLOCK_SIZE_M)\n tl.store(as_ptr + as_offs, a_scale)\n\ndef quantize_int8_perrow(fpa):\n a = torch.empty(fpa.shape, device=fpa.device, dtype=torch.int8)\n a_scale = torch.empty(fpa.shape[0], device=fpa.device, dtype=fpa.dtype)\n M, K = fpa.shape\n BLOCK_SIZE_M = 1\n BLOCK_SIZE_K = triton.next_power_of_2(K)\n grid = (M // BLOCK_SIZE_M,)\n quantize_int8_perrow_kernel[grid](\n fpa, a, a_scale,\n M, K,\n fpa.stride(0), fpa.stride(1),\n a.stride(0), a.stride(1),\n a_scale.stride(0),\n BLOCK_SIZE_M, BLOCK_SIZE_K,\n )\n return a, a_scale\n\n@triton.autotune(\n configs=[\n triton.Config({'SPLIT_K': 1, 'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8),\n triton.Config({'SPLIT_K': 1, 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),\n # Additional configs...\n ],\n key=['M', 'N', 'K'],\n reset_to_zero=['c_ptr']\n)\n@triton.jit\ndef matmul_kernel(\n a_ptr, as_ptr, b_ptr, bs_ptr, c_ptr,\n M, N, K,\n stride_am, stride_ak,\n stride_asm,\n stride_bk, stride_bn,\n stride_bsn,\n stride_cm, stride_cn,\n BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,\n GROUP_SIZE_M: tl.constexpr, SPLIT_K: tl.constexpr, \n):\n \"\"\"Kernel for computing the matmul C = A x B.\n A has shape (M, K), B has shape (K, N) and C has shape (M, N)\n \"\"\"\n pid = tl.program_id(axis=0)\n pid_sp_k = tl.program_id(axis=1)\n num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)\n num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)\n num_pid_in_group = GROUP_SIZE_M * num_pid_n\n group_id = pid // num_pid_in_group\n first_pid_m = group_id * GROUP_SIZE_M\n group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)\n pid_m = first_pid_m + (pid % group_size_m)\n pid_n = (pid % num_pid_in_group) // group_size_m\n\n offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M\n offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N\n offs_k = pid_sp_k * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)\n a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)\n b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)\n as_ptrs = as_ptr + offs_am * stride_asm\n bs_ptrs = bs_ptr + offs_bn * stride_bsn\n a_scale = tl.load(as_ptrs, mask=offs_am < M, other=0.0)\n b_scale = tl.load(bs_ptrs, mask=offs_bn < N, other=0.0)\n\n accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.int32)\n for k in range(0, tl.cdiv(K, BLOCK_SIZE_K * SPLIT_K)):\n a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K * SPLIT_K, other=0.0)\n b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K * SPLIT_K, other=0.0)\n accumulator += tl.dot(a, b)\n a_ptrs += BLOCK_SIZE_K * SPLIT_K * stride_ak\n b_ptrs += BLOCK_SIZE_K * SPLIT_K * stride_bk\n \n c = (accumulator.to(tl.float32) * a_scale[:, None] * b_scale[None, :]).to(c_ptr.dtype.element_ty)\n offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]\n c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)\n if SPLIT_K == 1:\n tl.store(c_ptrs, c, mask=c_mask)\n else:\n tl.atomic_add(c_ptrs, c, mask=c_mask)\n\ndef matmul_quantize_int8(fpa, b, b_scale, out=None):\n a, a_scale = quantize_int8_perrow(fpa)\n return matmul_int8(a, a_scale, b, b_scale, out)\n\ndef matmul_int8(a, a_scale, b, b_scale, out=None):\n assert a.shape[1] == b.shape[0], \"Incompatible dimensions\"\n M, K = a.shape\n K, N = b.shape\n if out == None:\n c = torch.zeros((M, N), device=a.device, dtype=torch.float16)\n else:\n c = out.fill_(0.)\n grid = lambda META: (\n triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']),\n META['SPLIT_K'],\n )\n matmul_kernel[grid](\n a, a_scale, b, b_scale, c,\n M, N, K,\n a.stride(0), a.stride(1),\n a_scale.stride(0),\n b.stride(0), b.stride(1),\n b_scale.stride(0),\n c.stride(0), c.stride(1),\n )\n return c\n", - "description_1": "Use triton language to create two kernels: 'quantize_int8_perrow_kernel' and 'matmul_kernel'. The first kernel quantizes a floating-point matrix 'fpa' per row into an int8 matrix 'a' and computes scale factors 'a_scale'. It accepts 10 primary arguments: fpa_ptr, a_ptr, as_ptr, M, K, stride_fpam, stride_fpak, stride_am, stride_ak, stride_asm, and 2 meta-parameters: BLOCK_SIZE_M, BLOCK_SIZE_K. The second kernel performs matrix multiplication on quantized matrices, taking 16 primary arguments: a_ptr, as_ptr, b_ptr, bs_ptr, c_ptr, M, N, K, stride_am, stride_ak, stride_asm, stride_bk, stride_bn, stride_bsn, stride_cm, stride_cn, and 5 meta-parameters: BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K, GROUP_SIZE_M, SPLIT_K.", - "description_2": "Use triton language to quantize a matrix into int8 per row and perform matrix multiplication on quantized matrices.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n# Triton kernel to copy key-value indices to request\n@triton.jit\ndef _fwd_kernel_copy_kv_index_to_req(\n req_to_token_indexs, # Pointer to the output tensor\n b_req_idx, # Pointer to the batch request indices\n b_split_seq_len, # Pointer to the split sequence lengths\n cumsum_split_seq_len, # Pointer to the cumulative sum of split sequence lengths\n b_seq_len, # Pointer to the batch sequence lengths\n memindex, # Pointer to the memory index\n stride_req_to_token_b,# Stride for the batch dimension in the output tensor\n stride_req_to_token_s,# Stride for the sequence dimension in the output tensor\n BLOCK_M: tl.constexpr # Block size for the M dimension\n):\n cur_index = tl.program_id(0)\n cur_req_idx = tl.load(b_req_idx + cur_index)\n q_split_len = tl.load(b_split_seq_len + cur_index)\n q_mem_end = tl.load(cumsum_split_seq_len + cur_index)\n q_mem_start = q_mem_end - q_split_len\n\n store_end = tl.load(b_seq_len + cur_index)\n store_start = store_end - q_split_len\n\n off_m = tl.arange(0, BLOCK_M)\n for block_start in range(0, q_split_len, BLOCK_M):\n read_index = tl.load(\n memindex + q_mem_start + block_start + off_m, mask=q_mem_start + block_start + off_m < q_mem_end, other=0\n )\n tl.store(\n req_to_token_indexs + cur_req_idx * stride_req_to_token_b + (block_start + store_start + off_m),\n read_index,\n mask=block_start + store_start + off_m < store_end,\n )\n return\n\n# Function to invoke the Triton kernel\n@torch.no_grad()\ndef splitfuse_copy_kv_index_to_req(req_to_token_indexs, b_req_idx, b_ready_cache_len, b_seq_len, memindex):\n batch_size = b_seq_len.shape[0]\n grid = (batch_size,)\n num_warps = 1\n b_split_seq_len = b_seq_len - b_ready_cache_len\n cumsum_split_seq_len = torch.cumsum(b_split_seq_len, dim=0)\n _fwd_kernel_copy_kv_index_to_req[grid](\n req_to_token_indexs,\n b_req_idx,\n b_split_seq_len,\n cumsum_split_seq_len,\n b_seq_len,\n memindex,\n req_to_token_indexs.stride(0),\n req_to_token_indexs.stride(1),\n BLOCK_M=32,\n num_warps=num_warps,\n num_stages=1,\n )\n return\n", - "description_1": "Use triton language to implement a kernel that copies key-value indices to a request tensor. The kernel takes 8 parameters: pointers to the output tensor, batch request indices, split sequence lengths, cumulative sum of split sequence lengths, batch sequence lengths, memory index, and strides for the output tensor. It uses a block size for the M dimension to perform the copy operation efficiently. The kernel is invoked by a function that calculates the grid size and prepares the input parameters.", - "description_2": "Use triton language to create a kernel for copying indices with efficient memory access patterns, and a function to set up and launch this kernel.", - "difficulty": 3 - }, - { - "code": "import triton\nimport triton.language as tl\nimport torch\n\n@triton.jit\ndef _fwd_kernel(\n Q,\n K,\n V,\n sm_scale,\n Alibi,\n B_Start_Loc,\n B_Seqlen,\n Out,\n Req_to_tokens,\n B_req_idx,\n stride_qbs,\n stride_qh,\n stride_qd,\n stride_kbs,\n stride_kh,\n stride_kd,\n stride_vbs,\n stride_vh,\n stride_vd,\n stride_obs,\n stride_oh,\n stride_od,\n stride_req_to_tokens_b,\n stride_req_to_tokens_s,\n b_ready_cache_len,\n BLOCK_M: tl.constexpr,\n BLOCK_DMODEL: tl.constexpr,\n BLOCK_N: tl.constexpr,\n):\n cur_batch = tl.program_id(0)\n cur_head = tl.program_id(1)\n start_m = tl.program_id(2)\n\n cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)\n cur_batch_req_idx = tl.load(B_req_idx + cur_batch)\n ready_cache_len = tl.load(b_ready_cache_len + cur_batch)\n cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) - ready_cache_len\n\n block_start_loc = BLOCK_M * start_m\n\n # initialize offsets\n offs_n = tl.arange(0, BLOCK_N)\n offs_d = tl.arange(0, BLOCK_DMODEL)\n offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)\n off_q = (\n (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs\n + cur_head * stride_qh\n + offs_d[None, :] * stride_qd\n )\n\n q = tl.load(Q + off_q, mask=offs_m[:, None] < cur_batch_seq_len, other=0.0)\n\n m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float(\"inf\")\n l_i = tl.zeros([BLOCK_M], dtype=tl.float32)\n acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)\n\n alibi_m = tl.load(Alibi + cur_head)\n\n block_mask = tl.where(block_start_loc < cur_batch_seq_len, 1, 0)\n block_end_loc = tl.minimum((start_m + 1) * BLOCK_M + ready_cache_len, cur_batch_seq_len + ready_cache_len)\n\n for start_n in range(0, block_mask * block_end_loc, BLOCK_N):\n start_n = tl.multiple_of(start_n, BLOCK_N)\n # -- compute qk ----\n kv_loc = tl.load(\n Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + stride_req_to_tokens_s * (start_n + offs_n),\n mask=(start_n + offs_n) < block_end_loc,\n other=0,\n )\n off_k = kv_loc[None, :] * stride_kbs + cur_head * stride_kh + offs_d[:, None] * stride_kd\n k = tl.load(K + off_k, mask=(start_n + offs_n[None, :]) < block_end_loc, other=0.0)\n\n qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)\n qk += tl.dot(q, k)\n qk *= sm_scale\n\n alibi_loc = ready_cache_len + offs_m[:, None] - (start_n + offs_n[None, :])\n qk -= alibi_loc * alibi_m\n\n qk = tl.where((offs_m[:, None] + ready_cache_len) >= (start_n + offs_n[None, :]), qk, -10000000.0)\n\n m_ij = tl.max(qk, 1)\n p = tl.exp(qk - m_ij[:, None])\n l_ij = tl.sum(p, 1)\n # -- update m_i and l_i\n m_i_new = tl.maximum(m_i, m_ij)\n alpha = tl.exp(m_i - m_i_new)\n beta = tl.exp(m_ij - m_i_new)\n l_i_new = alpha * l_i + beta * l_ij\n # -- update output accumulator --\n # scale p\n p_scale = beta / l_i_new\n p = p * p_scale[:, None]\n # scale acc\n acc_scale = l_i / l_i_new * alpha\n acc = acc * acc_scale[:, None]\n # update acc\n off_v = kv_loc[:, None] * stride_vbs + cur_head * stride_vh + offs_d[None, :] * stride_vd\n v = tl.load(V + off_v, mask=(start_n + offs_n[:, None]) < block_end_loc, other=0.0)\n\n p = p.to(v.dtype)\n acc += tl.dot(p, v)\n # update m_i and l_i\n l_i = l_i_new\n m_i = m_i_new\n # initialize pointers to output\n off_o = (\n (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs\n + cur_head * stride_oh\n + offs_d[None, :] * stride_od\n )\n out_ptrs = Out + off_o\n tl.store(out_ptrs, acc, mask=offs_m[:, None] < cur_batch_seq_len)\n return\n\n\n@torch.no_grad()\ndef context_attention_fwd(\n q, k, v, o, b_req_idx, alibi, b_start_loc, b_seq_len, b_ready_cache_len, max_input_len, req_to_token_indexs\n):\n BLOCK = 128\n # shape constraints\n Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]\n assert Lq == Lk and Lk == Lv\n assert Lk in {16, 32, 64, 128}\n\n sm_scale = 1.0 / (Lq ** 0.5)\n batch, head = b_seq_len.shape[0], q.shape[1]\n\n grid = (batch, head, triton.cdiv(max_input_len, BLOCK))\n\n num_warps = 4 if Lk <= 64 else 8\n\n _fwd_kernel[grid](\n q,\n k,\n v,\n sm_scale,\n alibi,\n b_start_loc,\n b_seq_len,\n o,\n req_to_token_indexs,\n b_req_idx,\n q.stride(0),\n q.stride(1),\n q.stride(2),\n k.stride(0),\n k.stride(1),\n k.stride(2),\n v.stride(0),\n v.stride(1),\n v.stride(2),\n o.stride(0),\n o.stride(1),\n o.stride(2),\n req_to_token_indexs.stride(0),\n req_to_token_indexs.stride(1),\n b_ready_cache_len,\n BLOCK_M=BLOCK,\n BLOCK_DMODEL=Lk,\n BLOCK_N=BLOCK,\n num_warps=num_warps,\n num_stages=1,\n )\n return\n", - "description_1": "Use triton language to implement a forward kernel for context attention. The kernel function '_fwd_kernel' takes 27 parameters: Q, K, V (query, key, value tensors), sm_scale (scale for softmax), Alibi (alibi tensor), B_Start_Loc, B_Seqlen, Out (output tensor), Req_to_tokens, B_req_idx, and various stride parameters for Q, K, V, Out, and Req_to_tokens. It also takes b_ready_cache_len and three block size constants (BLOCK_M, BLOCK_DMODEL, BLOCK_N). The function computes the attention scores and updates the output tensor 'Out'. The 'context_attention_fwd' function is a wrapper that sets up the grid and block sizes, and calls the '_fwd_kernel' with the appropriate parameters.", - "description_2": "Use triton language to create a context attention forward kernel that computes attention scores and updates an output tensor using query, key, and value tensors, along with additional parameters for scaling and indexing.", - "difficulty": 4 - }, - { - "code": "import triton\nimport triton.language as tl\nimport torch\n\n@triton.jit\ndef _layer_norm_fwd_fused(\n X, # pointer to the input\n Y, # pointer to the output\n W, # pointer to the weights\n B, # pointer to the biases\n stride, # how much to increase the pointer when moving by 1 row\n N, # number of columns in X\n eps, # epsilon to avoid division by zero\n BLOCK_SIZE: tl.constexpr,\n):\n # Map the program id to the row of X and Y it should compute.\n row = tl.program_id(0)\n Y += row * stride\n X += row * stride\n # Compute mean\n mean = 0\n _mean = tl.zeros([BLOCK_SIZE], dtype=tl.float32)\n for off in range(0, N, BLOCK_SIZE):\n cols = off + tl.arange(0, BLOCK_SIZE)\n a = tl.load(X + cols, mask=cols < N, other=0.).to(tl.float32)\n _mean += a\n mean = tl.sum(_mean, axis=0) / N\n # Compute variance\n _var = tl.zeros([BLOCK_SIZE], dtype=tl.float32)\n for off in range(0, N, BLOCK_SIZE):\n cols = off + tl.arange(0, BLOCK_SIZE)\n x = tl.load(X + cols, mask=cols < N, other=0.).to(tl.float32)\n x = tl.where(cols < N, x - mean, 0.)\n _var += x * x\n var = tl.sum(_var, axis=0) / N\n rstd = 1 / tl.sqrt(var + eps)\n # Normalize and apply linear transformation\n for off in range(0, N, BLOCK_SIZE):\n cols = off + tl.arange(0, BLOCK_SIZE)\n mask = cols < N\n w = tl.load(W + cols, mask=mask).to(tl.float32)\n b = tl.load(B + cols, mask=mask).to(tl.float32)\n x = tl.load(X + cols, mask=mask, other=0.).to(tl.float32)\n x_hat = (x - mean) * rstd\n y = x_hat * w + b\n # Write output\n tl.store(Y + cols, y.to(Y.dtype.element_ty), mask=mask)\n\ndef layernorm_forward(x, weight, bias, eps):\n # allocate output\n y = torch.empty_like(x)\n # reshape input data into 2D tensor\n x_arg = x.view(-1, x.shape[-1])\n M, N = x_arg.shape\n # Less than 64KB per feature: enqueue fused kernel\n MAX_FUSED_SIZE = 65536 // x.element_size()\n BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))\n if N > BLOCK_SIZE:\n raise RuntimeError(\"This layer norm doesn't support feature dim >= 64KB.\")\n # heuristics for number of warps\n num_warps = min(max(BLOCK_SIZE // 256, 1), 8)\n # enqueue kernel\n _layer_norm_fwd_fused[(M,)](x_arg, y, weight, bias,\n x_arg.stride(0), N, eps,\n BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps)\n return y\n", - "description_1": "Use triton language to implement a fused layer normalization kernel. The kernel '_layer_norm_fwd_fused' takes 8 parameters: X (input tensor), Y (output tensor), W (weights), B (biases), stride (stride for row access), N (number of columns), eps (epsilon for numerical stability), and BLOCK_SIZE (block size for computation). The function 'layernorm_forward' prepares the input and output tensors, calculates the block size and number of warps, and enqueues the kernel for execution.", - "description_2": "Use triton language to create a fused layer normalization operation with input, output, weights, biases, and block size parameters. Implement a function to prepare data and execute the kernel.", - "difficulty": 2 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef _fwd_kernel_token_att1(\n Q, K, sm_scale, Alibi, Req_to_tokens, B_req_idx, B_Start_Loc, B_Seqlen,\n Att_Out,\n stride_req_to_tokens_b, stride_req_to_tokens_s,\n stride_qbs, stride_qh, stride_qd,\n stride_kbs, stride_kh, stride_kd,\n att_stride_h, att_stride_bs,\n BLOCK_DMODEL: tl.constexpr,\n BLOCK_N: tl.constexpr\n):\n cur_batch = tl.program_id(0)\n cur_head = tl.program_id(1)\n start_n = tl.program_id(2)\n\n offs_d = tl.arange(0, BLOCK_DMODEL)\n cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)\n cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)\n cur_batch_req_id = tl.load(B_req_idx + cur_batch)\n\n cur_batch_start_index = 0\n cur_batch_end_index = cur_batch_seq_len\n\n off_q = cur_batch * stride_qbs + cur_head * stride_qh + offs_d * stride_qd\n\n offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N)\n\n block_stard_index = start_n * BLOCK_N\n block_mask = tl.where(block_stard_index < cur_batch_seq_len, 1, 0)\n\n for start_mark in range(0, block_mask, 1):\n alibi_m = tl.load(Alibi + cur_head)\n q = tl.load(Q + off_q + start_mark)\n offs_n_new = cur_batch_start_index + offs_n\n k_loc = tl.load(Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_id + stride_req_to_tokens_s * offs_n_new, mask=offs_n_new < cur_batch_end_index, other=0)\n off_k = k_loc[:, None] * stride_kbs + cur_head * stride_kh + offs_d[None, :] * stride_kd\n k = tl.load(K + off_k, mask=offs_n_new[:, None] < cur_batch_end_index, other=0.0)\n att_value = tl.sum(q[None, :] * k, 1)\n att_value *= sm_scale\n att_value -= alibi_m * (cur_batch_seq_len - 1 - offs_n)\n off_o = cur_head * att_stride_h + (cur_batch_in_all_start_index + offs_n) * att_stride_bs\n tl.store(Att_Out + off_o, att_value, mask=offs_n_new < cur_batch_end_index)\n return\n\n@torch.no_grad()\ndef token_att_fwd(q, k, att_out, alibi, Req_to_tokens, B_req_idx, B_Start_Loc, B_Seqlen, max_len_in_batch):\n BLOCK = 32\n Lq, Lk = q.shape[-1], k.shape[-1]\n assert Lq == Lk\n assert Lk in {16, 32, 64, 128}\n sm_scale = 1.0 / (Lk ** 0.5)\n\n batch, head_num = B_req_idx.shape[0], q.shape[1]\n\n grid = (batch, head_num, triton.cdiv(max_len_in_batch, BLOCK))\n\n num_warps = 4 if Lk <= 64 else 8\n num_warps = 2\n\n _fwd_kernel_token_att1[grid](\n q, k, sm_scale, alibi, Req_to_tokens, B_req_idx, B_Start_Loc, B_Seqlen, \n att_out,\n Req_to_tokens.stride(0), Req_to_tokens.stride(1),\n q.stride(0), q.stride(1), q.stride(2),\n k.stride(0), k.stride(1), k.stride(2),\n att_out.stride(0), att_out.stride(1),\n BLOCK_DMODEL=Lk,\n BLOCK_N=BLOCK,\n num_warps=num_warps,\n num_stages=1,\n )\n return\n", - "description_1": "Use triton language to implement a forward kernel for token attention. The kernel function '_fwd_kernel_token_att1' takes 18 parameters: Q, K, sm_scale, Alibi, Req_to_tokens, B_req_idx, B_Start_Loc, B_Seqlen, Att_Out, stride_req_to_tokens_b, stride_req_to_tokens_s, stride_qbs, stride_qh, stride_qd, stride_kbs, stride_kh, stride_kd, att_stride_h, att_stride_bs, BLOCK_DMODEL, and BLOCK_N. It computes the attention values for a batch of queries and keys, applying scaling and alibi adjustments, and stores the results in Att_Out. The function 'token_att_fwd' is a wrapper that sets up the grid and block dimensions, calculates the scaling factor, and calls the kernel function with the appropriate parameters.", - "description_2": "Use triton language to create a token attention forward kernel that computes scaled dot-product attention with alibi adjustments for a batch of queries and keys, and stores the results in the output tensor.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n\n@triton.jit\ndef _fwd_kernel_token_att2(\n Prob,\n V,\n Out,\n Req_to_tokens,\n B_req_idx,\n B_Start_Loc,\n B_Seqlen,\n stride_req_to_tokens_b,\n stride_req_to_tokens_s,\n stride_ph,\n stride_pbs,\n stride_vbs,\n stride_vh,\n stride_vd,\n stride_obs,\n stride_oh,\n stride_od,\n BLOCK_DMODEL: tl.constexpr,\n BLOCK_N: tl.constexpr,\n):\n cur_batch = tl.program_id(0)\n cur_head = tl.program_id(1)\n\n offs_n = tl.arange(0, BLOCK_N)\n offs_d = tl.arange(0, BLOCK_DMODEL)\n cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)\n cur_batch_req_idx = tl.load(B_req_idx + cur_batch)\n cur_batch_start_index = 0\n cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)\n\n v_loc_off = cur_batch_req_idx * stride_req_to_tokens_b + (cur_batch_start_index + offs_n) * stride_req_to_tokens_s\n p_offs = cur_head * stride_ph + (cur_batch_in_all_start_index + offs_n) * stride_pbs\n v_offs = cur_head * stride_vh + offs_d[None, :] * stride_vd\n\n acc = tl.zeros([BLOCK_DMODEL], dtype=tl.float32)\n for start_n in range(0, cur_batch_seq_len, BLOCK_N):\n start_n = tl.multiple_of(start_n, BLOCK_N)\n p_value = tl.load(Prob + p_offs + start_n * stride_pbs, mask=(start_n + offs_n) < cur_batch_seq_len, other=0.0)\n v_loc = tl.load(\n Req_to_tokens + v_loc_off + start_n * stride_req_to_tokens_s,\n mask=(start_n + offs_n) < cur_batch_seq_len,\n other=0.0,\n )\n v_value = tl.load(\n V + v_offs + v_loc[:, None] * stride_vbs, mask=(start_n + offs_n[:, None]) < cur_batch_seq_len, other=0.0\n )\n acc += tl.sum(p_value[:, None] * v_value, 0)\n\n acc = acc.to(Out.dtype.element_ty)\n off_o = cur_batch * stride_obs + cur_head * stride_oh + offs_d * stride_od\n out_ptrs = Out + off_o\n tl.store(out_ptrs, acc)\n return\n\n\n@torch.no_grad()\ndef token_att_fwd2(prob, v, out, Req_to_tokens, B_req_idx, B_Start_Loc, B_Seqlen):\n BLOCK = 128\n batch, head = B_req_idx.shape[0], v.shape[1]\n grid = (batch, head)\n num_warps = 4\n dim = v.shape[-1]\n\n _fwd_kernel_token_att2[grid](\n prob,\n v,\n out,\n Req_to_tokens,\n B_req_idx,\n B_Start_Loc,\n B_Seqlen,\n Req_to_tokens.stride(0),\n Req_to_tokens.stride(1),\n prob.stride(0),\n prob.stride(1),\n v.stride(0),\n v.stride(1),\n v.stride(2),\n out.stride(0),\n out.stride(1),\n out.stride(2),\n BLOCK_DMODEL=dim,\n BLOCK_N=BLOCK,\n num_warps=num_warps,\n num_stages=1,\n )\n return\n", - "description_1": "Use triton language to implement a forward kernel function `_fwd_kernel_token_att2` with 19 parameters. The parameters include input matrices, their strides, and constant block dimensions to perform a specific tensor operation. The kernel is launched in the `token_att_fwd2` function that takes 7 parameters including the input matrices and indices to set up grid dimensions and execute the kernel with specific block and warp settings.", - "description_2": "Use triton language to define a kernel function for tensor operations with specific stride and block parameters and execute it via a higher-level Python function using PyTorch to set execution configuration.", - "difficulty": 3 - }, - { - "code": "import triton\nimport triton.language as tl\n\n@triton.jit\ndef _fwd_kernel(\n Q, K, V, sm_scale, Alibi, B_Loc, B_Seqlen, max_input_len,\n Out,\n stride_qbs, stride_qh, stride_qd,\n stride_kbs, stride_kh, stride_kd,\n stride_vbs, stride_vh, stride_vd,\n stride_obs, stride_oh, stride_od,\n stride_b_loc_b, stride_b_loc_s,\n BLOCK_DMODEL: tl.constexpr,\n BLOCK_N: tl.constexpr,\n):\n cur_batch = tl.program_id(0)\n cur_head = tl.program_id(1)\n\n cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)\n\n # initialize offsets\n offs_n = tl.arange(0, BLOCK_N)\n offs_d = tl.arange(0, BLOCK_DMODEL)\n\n off_q = cur_batch * stride_qbs + cur_head * stride_qh + offs_d * stride_qd\n off_k = cur_head * stride_kh + offs_d[None, :] * stride_kd\n off_v = cur_head * stride_vh + offs_d[None, :] * stride_vd\n off_b_loc = cur_batch * stride_b_loc_b + (max_input_len - cur_batch_seq_len) * stride_b_loc_s\n\n q = tl.load(Q + off_q)\n\n k_ptrs = K + off_k\n v_ptrs = V + off_v\n\n m_i = -float(\"inf\")\n l_i = 0.0\n acc = tl.zeros([BLOCK_DMODEL], dtype=tl.float32)\n\n alibi_m = tl.load(Alibi + cur_head)\n\n for start_n in range(0, cur_batch_seq_len, BLOCK_N):\n start_n = tl.multiple_of(start_n, BLOCK_N)\n # -- compute qk ----\n k_index = tl.load(B_Loc + off_b_loc + (start_n + offs_n) * stride_b_loc_s, mask=(start_n + offs_n) < cur_batch_seq_len, other=0)\n k = tl.load(k_ptrs + k_index[:, None] * stride_kbs, mask=(start_n + offs_n[:, None]) < cur_batch_seq_len, other=0.0)\n\n qk = tl.zeros([BLOCK_N,], dtype=tl.float32)\n qk += tl.sum(q[None, :] * k, 1)\n qk *= sm_scale\n\n alibi_loc = cur_batch_seq_len - 1 - (start_n + offs_n)\n qk -= alibi_loc * alibi_m\n\n qk = tl.where(cur_batch_seq_len > (start_n + offs_n), qk, float(\"-inf\"))\n\n m_ij = tl.max(qk, 0)\n p = tl.exp(qk - m_ij)\n l_ij = tl.sum(p, 0)\n # -- update m_i and l_i\n m_i_new = tl.maximum(m_i, m_ij)\n alpha = tl.exp(m_i - m_i_new)\n beta = tl.exp(m_ij - m_i_new)\n l_i_new = alpha * l_i + beta * l_ij\n # -- update output accumulator --\n # scale p\n p_scale = beta / l_i_new\n p = p * p_scale\n # scale acc\n acc_scale = l_i / l_i_new * alpha\n acc = acc * acc_scale\n # update acc\n v_index = k_index\n v = tl.load(v_ptrs + v_index[:, None] * stride_vbs, mask=(start_n + offs_n[:, None]) < cur_batch_seq_len, other=0.0)\n # print(p)\n acc += tl.sum(p[:, None] * v, 0)\n\n # update m_i and l_i\n l_i = l_i_new\n m_i = m_i_new\n # initialize pointers to output\n\n off_o = cur_batch * stride_obs + cur_head * stride_oh + offs_d * stride_od\n out_ptrs = Out + off_o\n tl.store(out_ptrs, acc)\n return\n", - "description_1": "Use triton language to implement a forward kernel for attention mechanism. The kernel takes 20 parameters: Q, K, V (query, key, value tensors), sm_scale (scale for softmax), Alibi (alibi tensor), B_Loc, B_Seqlen (location and sequence length tensors), max_input_len (maximum input length), Out (output tensor), and various stride parameters for memory access. BLOCK_DMODEL and BLOCK_N are compile-time constants defining block sizes. The kernel computes scaled dot-product attention with alibi adjustment and stores the result in the output tensor.", - "description_2": "Use triton language to create a kernel for computing scaled dot-product attention with alibi adjustment, using 20 input parameters including tensors and stride information, and store the result in an output tensor.", - "difficulty": 4 - }, - { - "code": "import triton\nimport triton.language as tl\nimport torch\n\n@triton.jit\ndef _rotary_kernel(\n Q,\n K,\n Cos,\n Sin,\n stride_qbs,\n stride_qh,\n stride_qd,\n stride_kbs,\n stride_kh,\n stride_kd,\n stride_cosbs,\n stride_cosd,\n stride_sinbs,\n stride_sind,\n max_total_len,\n HEAD_Q,\n HEAD_K,\n BLOCK_HEAD: tl.constexpr,\n BLOCK_SEQ: tl.constexpr,\n BLOCK_DMODEL: tl.constexpr,\n):\n cur_head_index = tl.program_id(0)\n cur_seq_index = tl.program_id(1)\n\n cur_head_range = cur_head_index * BLOCK_HEAD + tl.arange(0, BLOCK_HEAD)\n cur_seq_range = cur_seq_index * BLOCK_SEQ + tl.arange(0, BLOCK_SEQ)\n\n dim_range0 = tl.arange(0, BLOCK_DMODEL // 2) * 2\n dim_range1 = dim_range0 + 1\n\n off_q0 = (\n cur_seq_range[:, None, None] * stride_qbs\n + cur_head_range[None, :, None] * stride_qh\n + dim_range0[None, None, :] * stride_qd\n )\n off_q1 = (\n cur_seq_range[:, None, None] * stride_qbs\n + cur_head_range[None, :, None] * stride_qh\n + dim_range1[None, None, :] * stride_qd\n )\n\n cos_range = tl.arange(0, BLOCK_DMODEL // 2)\n off_dimcos_sin = cur_seq_range[:, None, None] * stride_cosbs + cos_range[None, None, :] * stride_cosd\n\n q0 = tl.load(\n Q + off_q0,\n mask=(cur_seq_range[:, None, None] < max_total_len) & (cur_head_range[None, :, None] < HEAD_Q),\n other=0.0,\n )\n q1 = tl.load(\n Q + off_q1,\n mask=(cur_seq_range[:, None, None] < max_total_len) & (cur_head_range[None, :, None] < HEAD_Q),\n other=0.0,\n )\n\n cos = tl.load(Cos + off_dimcos_sin, mask=cur_seq_range[:, None, None] < max_total_len, other=0.0)\n sin = tl.load(Sin + off_dimcos_sin, mask=cur_seq_range[:, None, None] < max_total_len, other=0.0)\n\n out0 = q0 * cos - q1 * sin\n out1 = q0 * sin + q1 * cos\n\n tl.store(\n Q + off_q0, out0, mask=(cur_seq_range[:, None, None] < max_total_len) & (cur_head_range[None, :, None] < HEAD_Q)\n )\n tl.store(\n Q + off_q1, out1, mask=(cur_seq_range[:, None, None] < max_total_len) & (cur_head_range[None, :, None] < HEAD_Q)\n )\n\n off_k0 = (\n cur_seq_range[:, None, None] * stride_kbs\n + cur_head_range[None, :, None] * stride_kh\n + dim_range0[None, None, :] * stride_kd\n )\n off_k1 = (\n cur_seq_range[:, None, None] * stride_kbs\n + cur_head_range[None, :, None] * stride_kh\n + dim_range1[None, None, :] * stride_kd\n )\n\n off_dimcos_sin = cur_seq_range[:, None, None] * stride_cosbs + cos_range[None, None, :] * stride_cosd\n\n k0 = tl.load(\n K + off_k0,\n mask=(cur_seq_range[:, None, None] < max_total_len) & (cur_head_range[None, :, None] < HEAD_K),\n other=0.0,\n )\n k1 = tl.load(\n K + off_k1,\n mask=(cur_seq_range[:, None, None] < max_total_len) & (cur_head_range[None, :, None] < HEAD_K),\n other=0.0,\n )\n\n cos = tl.load(Cos + off_dimcos_sin, mask=cur_seq_range[:, None, None] < max_total_len, other=0.0)\n sin = tl.load(Sin + off_dimcos_sin, mask=cur_seq_range[:, None, None] < max_total_len, other=0.0)\n\n out_k0 = k0 * cos - k1 * sin\n out_k1 = k0 * sin + k1 * cos\n\n tl.store(\n K + off_k0,\n out_k0,\n mask=(cur_seq_range[:, None, None] < max_total_len) & (cur_head_range[None, :, None] < HEAD_K),\n )\n tl.store(\n K + off_k1,\n out_k1,\n mask=(cur_seq_range[:, None, None] < max_total_len) & (cur_head_range[None, :, None] < HEAD_K),\n )\n return\n\n@torch.no_grad()\ndef rotary_emb_fwd(q, k, cos, sin):\n total_len = q.shape[0]\n head_num_q, head_num_k = q.shape[1], k.shape[1]\n head_dim = q.shape[2] // 2\n assert q.shape[0] == cos.shape[0] and q.shape[0] == sin.shape[0], f\"q shape {q.shape} cos shape {cos.shape}\"\n assert k.shape[0] == cos.shape[0] and k.shape[0] == sin.shape[0], f\"k shape {k.shape} cos shape {cos.shape}\"\n\n BLOCK_SEQ = 16\n BLOCK_HEAD = 4\n if head_dim >= 128:\n num_warps = 8\n else:\n num_warps = 4\n\n grid = (triton.cdiv(head_num_q, BLOCK_HEAD), triton.cdiv(total_len, BLOCK_SEQ))\n _rotary_kernel[grid](\n q,\n k,\n cos,\n sin,\n q.stride(0),\n q.stride(1),\n q.stride(2),\n k.stride(0),\n k.stride(1),\n k.stride(2),\n cos.stride(0),\n cos.stride(1),\n sin.stride(0),\n sin.stride(1),\n total_len,\n head_num_q,\n head_num_k,\n BLOCK_HEAD=BLOCK_HEAD,\n BLOCK_SEQ=BLOCK_SEQ,\n BLOCK_DMODEL=head_dim,\n num_warps=num_warps,\n num_stages=1,\n )\n return\n", - "description_1": "Use triton language to implement a rotary kernel function that performs element-wise operations on input tensors Q and K using cosine and sine values. The kernel function takes 19 parameters: Q, K, Cos, Sin, and various strides and constants for indexing and computation. The rotary_emb_fwd function calls this kernel with 4 input tensors (q, k, cos, sin) and calculates grid dimensions based on input shapes.", - "description_2": "Use triton language to create a kernel for element-wise tensor operations with cosine and sine, and a wrapper function to set up and call this kernel.", - "difficulty": 2 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n\ndef ceil_div(a, b):\n return (a + b - 1) // b\n\n\n@triton.jit\ndef moe_align_block_size_stage1(\n topk_ids_ptr,\n sorted_token_ids_ptr,\n expert_ids_ptr,\n total_tokens_post_pad_ptr,\n tokens_cnts_ptr,\n cumsum_ptr,\n num_experts: tl.constexpr,\n block_size: tl.constexpr,\n numel: tl.constexpr,\n tokens_per_thread: tl.constexpr,\n BLOCK_SIZE: tl.constexpr,\n):\n pid = tl.program_id(0)\n\n start_idx = pid * tokens_per_thread\n\n off_c = (pid + 1) * num_experts\n\n for i in range(tokens_per_thread):\n if start_idx + i < numel:\n idx = tl.load(topk_ids_ptr + start_idx + i)\n token_cnt = tl.load(tokens_cnts_ptr + off_c + idx)\n tl.store(tokens_cnts_ptr + off_c + idx, token_cnt + 1)\n\n\n@triton.jit\ndef moe_align_block_size_stage2(\n topk_ids_ptr,\n sorted_token_ids_ptr,\n expert_ids_ptr,\n total_tokens_post_pad_ptr,\n tokens_cnts_ptr,\n cumsum_ptr,\n num_experts: tl.constexpr,\n block_size: tl.constexpr,\n numel: tl.constexpr,\n tokens_per_thread: tl.constexpr,\n BLOCK_SIZE: tl.constexpr,\n):\n pid = tl.program_id(0)\n\n last_cnt = 0\n for i in range(1, num_experts + 1):\n token_cnt = tl.load(tokens_cnts_ptr + i * num_experts + pid)\n last_cnt = last_cnt + token_cnt\n tl.store(tokens_cnts_ptr + i * num_experts + pid, last_cnt)\n\n\n@triton.jit\ndef moe_align_block_size_stage3(\n topk_ids_ptr,\n sorted_token_ids_ptr,\n expert_ids_ptr,\n total_tokens_post_pad_ptr,\n tokens_cnts_ptr,\n cumsum_ptr,\n num_experts: tl.constexpr,\n block_size: tl.constexpr,\n numel: tl.constexpr,\n tokens_per_thread: tl.constexpr,\n BLOCK_SIZE: tl.constexpr,\n):\n last_cumsum = 0\n off_cnt = num_experts * num_experts\n for i in range(1, num_experts + 1):\n token_cnt = tl.load(tokens_cnts_ptr + off_cnt + i - 1)\n last_cumsum = last_cumsum + tl.cdiv(token_cnt, block_size) * block_size\n tl.store(cumsum_ptr + i, last_cumsum)\n tl.store(total_tokens_post_pad_ptr, last_cumsum)\n\n\n@triton.jit\ndef moe_align_block_size_stage4(\n topk_ids_ptr,\n sorted_token_ids_ptr,\n expert_ids_ptr,\n total_tokens_post_pad_ptr,\n tokens_cnts_ptr,\n cumsum_ptr,\n num_experts: tl.constexpr,\n block_size: tl.constexpr,\n numel: tl.constexpr,\n tokens_per_thread: tl.constexpr,\n BLOCK_SIZE: tl.constexpr,\n):\n pid = tl.program_id(0)\n start_idx = tl.load(cumsum_ptr + pid)\n end_idx = tl.load(cumsum_ptr + pid + 1)\n\n for i in range(start_idx, end_idx, block_size):\n tl.store(expert_ids_ptr + i // block_size, pid)\n\n start_idx = pid * tokens_per_thread\n off_t = pid * num_experts\n\n for i in range(start_idx, tl.minimum(start_idx + tokens_per_thread, numel)):\n expert_id = tl.load(topk_ids_ptr + i)\n token_cnt = tl.load(tokens_cnts_ptr + off_t + expert_id)\n rank_post_pad = token_cnt + tl.load(cumsum_ptr + expert_id)\n tl.store(sorted_token_ids_ptr + rank_post_pad, i)\n tl.store(tokens_cnts_ptr + off_t + expert_id, token_cnt + 1)\n\n\n@torch.no_grad()\ndef moe_align_block_size(\n topk_ids: torch.Tensor,\n num_experts: int,\n block_size: int,\n sorted_token_ids: torch.Tensor,\n expert_ids: torch.Tensor,\n num_tokens_post_pad: torch.Tensor,\n) -> None:\n numel = topk_ids.numel()\n grid = (num_experts,)\n tokens_cnts = torch.zeros((num_experts + 1, num_experts), dtype=torch.int32, device=\"cuda\")\n cumsum = torch.zeros((num_experts + 1,), dtype=torch.int32, device=\"cuda\")\n tokens_per_thread = ceil_div(numel, num_experts)\n\n moe_align_block_size_stage1[grid](\n topk_ids,\n sorted_token_ids,\n expert_ids,\n num_tokens_post_pad,\n tokens_cnts,\n cumsum,\n num_experts,\n block_size,\n numel,\n tokens_per_thread,\n BLOCK_SIZE=num_experts,\n )\n moe_align_block_size_stage2[grid](\n topk_ids,\n sorted_token_ids,\n expert_ids,\n num_tokens_post_pad,\n tokens_cnts,\n cumsum,\n num_experts,\n block_size,\n numel,\n tokens_per_thread,\n BLOCK_SIZE=num_experts,\n )\n moe_align_block_size_stage3[(1,)](\n topk_ids,\n sorted_token_ids,\n expert_ids,\n num_tokens_post_pad,\n tokens_cnts,\n cumsum,\n num_experts,\n block_size,\n numel,\n tokens_per_thread,\n BLOCK_SIZE=num_experts,\n )\n moe_align_block_size_stage4[grid](\n topk_ids,\n sorted_token_ids,\n expert_ids,\n num_tokens_post_pad,\n tokens_cnts,\n cumsum,\n num_experts,\n block_size,\n numel,\n tokens_per_thread,\n BLOCK_SIZE=num_experts,\n )\n", - "description_1": "Use triton language to define four kernels: moe_align_block_size_stage1, moe_align_block_size_stage2, moe_align_block_size_stage3, moe_align_block_size_stage4. Each kernel takes multiple pointers (such as topk_ids_ptr, sorted_token_ids_ptr, etc.) and several triton.constexpr parameters (like num_experts, block_size, etc.). These kernels handle different stages of aligning tokens to blocks with the goal of padding tokens efficiently. The moe_align_block_size function is a wrapper that orchestrates these kernels' execution on a CUDA device, by calculating necessary parameters and passing tensors and constants to each stage.", - "description_2": "Use triton language to create a kernel-based approach that takes input tensors and constant parameters, processes the data in multiple stages using GPU parallelization, and outputs the aligned tokens and necessary metadata in the provided tensors.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\nfrom lightllm.models.llama.triton_kernel.silu_and_mul import silu_and_mul_fwd\n\n\n@triton.jit\ndef fused_moe_kernel(\n a_ptr,\n b_ptr,\n c_ptr,\n a_scale_ptr,\n b_scale_ptr,\n topk_weights_ptr,\n sorted_token_ids_ptr,\n expert_ids_ptr,\n num_tokens_post_padded_ptr,\n N,\n K,\n EM,\n num_valid_tokens,\n stride_am,\n stride_ak,\n stride_be,\n stride_bk,\n stride_bn,\n stride_cm,\n stride_cn,\n BLOCK_SIZE_M: tl.constexpr,\n BLOCK_SIZE_N: tl.constexpr,\n BLOCK_SIZE_K: tl.constexpr,\n GROUP_SIZE_M: tl.constexpr,\n MUL_ROUTED_WEIGHT: tl.constexpr,\n top_k: tl.constexpr,\n compute_type: tl.constexpr,\n use_fp8: tl.constexpr,\n):\n pid = tl.program_id(axis=0)\n num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M)\n num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)\n num_pid_in_group = GROUP_SIZE_M * num_pid_n\n group_id = pid // num_pid_in_group\n first_pid_m = group_id * GROUP_SIZE_M\n group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)\n pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)\n pid_n = (pid % num_pid_in_group) // group_size_m\n\n num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr)\n if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded:\n return\n offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_token = tl.load(sorted_token_ids_ptr + offs_token_id)\n token_mask = offs_token < num_valid_tokens\n\n offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N\n offs_k = tl.arange(0, BLOCK_SIZE_K)\n a_ptrs = a_ptr + (offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak)\n\n off_experts = tl.load(expert_ids_ptr + pid_m)\n b_ptrs = b_ptr + off_experts * stride_be + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)\n\n if use_fp8:\n a_scale = tl.load(a_scale_ptr)\n b_scale = tl.load(b_scale_ptr + off_experts)\n\n accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n\n for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):\n a = tl.load(a_ptrs, mask=token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K), other=0.0)\n b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)\n if use_fp8:\n accumulator = tl.dot(a, b, acc=accumulator)\n else:\n accumulator += tl.dot(a, b)\n a_ptrs += BLOCK_SIZE_K * stride_ak\n b_ptrs += BLOCK_SIZE_K * stride_bk\n\n if MUL_ROUTED_WEIGHT:\n moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0)\n accumulator = accumulator * moe_weight[:, None]\n\n if use_fp8:\n accumulator = (accumulator * a_scale * b_scale).to(compute_type)\n else:\n accumulator = accumulator.to(compute_type)\n offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :]\n c_mask = token_mask[:, None] & (offs_cn[None, :] < N)\n tl.store(c_ptrs, accumulator, mask=c_mask)\n\n\ndef invoke_fused_moe_kernel(\n A: torch.Tensor,\n B: torch.Tensor,\n C: torch.Tensor,\n A_scale: Optional[torch.Tensor],\n B_scale: Optional[torch.Tensor],\n topk_weights: torch.Tensor,\n topk_ids: torch.Tensor,\n sorted_token_ids: torch.Tensor,\n expert_ids: torch.Tensor,\n num_tokens_post_padded: torch.Tensor,\n mul_routed_weight: bool,\n top_k: int,\n config: Dict[str, Any],\n compute_type: tl.dtype,\n use_fp8: bool,\n) -> None:\n assert topk_weights.stride(1) == 1\n assert sorted_token_ids.stride(0) == 1\n\n if not use_fp8:\n assert A_scale is None\n assert B_scale is None\n else:\n assert B_scale is not None\n\n grid = lambda META: (\n triton.cdiv(sorted_token_ids.shape[0], META[\"BLOCK_SIZE_M\"]) * triton.cdiv(B.shape[1], META[\"BLOCK_SIZE_N\"]),\n )\n\n fused_moe_kernel[grid](\n A,\n B,\n C,\n A_scale,\n B_scale,\n topk_weights,\n sorted_token_ids,\n expert_ids,\n num_tokens_post_padded,\n B.shape[1],\n B.shape[2],\n sorted_token_ids.shape[0],\n topk_ids.numel(),\n A.stride(0),\n A.stride(1),\n B.stride(0),\n B.stride(2),\n B.stride(1),\n C.stride(1),\n C.stride(2),\n MUL_ROUTED_WEIGHT=mul_routed_weight,\n top_k=top_k,\n compute_type=compute_type,\n use_fp8=use_fp8,\n **config,\n )\n", - "description_1": "Use triton language to implement a fused computation kernel for a Mixture of Experts (MOE) with token and expert matrices. The kernel function `fused_moe_kernel` requires 21 parameters: a_ptr, b_ptr, c_ptr, a_scale_ptr, b_scale_ptr, topk_weights_ptr, sorted_token_ids_ptr, expert_ids_ptr, num_tokens_post_padded_ptr, N, K, EM, num_valid_tokens, stride_am, stride_ak, stride_be, stride_bk, stride_bn, stride_cm, stride_cn. It utilizes the specified BLOCK_SIZE and GROUP_SIZE constants for efficient matrix multiplication. The `invoke_fused_moe_kernel` function encapsulates calling the kernel with 15 parameters: A, B, C, A_scale, B_scale, topk_weights, topk_ids, sorted_token_ids, expert_ids, num_tokens_post_padded, mul_routed_weight, top_k, config, compute_type, use_fp8.", - "description_2": "Use triton language to create a MOE kernel handling token-expert interaction with customizable block sizes, utilizing both float32 and optionally float8 formats; employ the kernel to execute the MOE operation on given matrices using a predefined configuration and expert-topK information.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef _fwd_kernel(\n Q_nope,\n Q_rope,\n KV_nope,\n KV_rope,\n sm_scale,\n B_Start_Loc,\n B_Seqlen,\n Out,\n Req_to_tokens,\n B_req_idx,\n stride_q_bs,\n stride_q_h,\n stride_q_d,\n stride_q_rope_bs,\n stride_q_rope_h,\n stride_q_rope_d,\n stride_kv_bs,\n stride_kv_h,\n stride_kv_d,\n stride_kv_rope_bs,\n stride_kv_rope_h,\n stride_kv_rope_d,\n stride_obs,\n stride_oh,\n stride_od,\n stride_req_to_tokens_b,\n stride_req_to_tokens_s,\n kv_group_num,\n b_prompt_cache_len,\n BLOCK_M: tl.constexpr,\n BLOCK_DMODEL: tl.constexpr,\n BLOCK_ROPE_DMODEL: tl.constexpr,\n BLOCK_N: tl.constexpr,\n):\n cur_batch = tl.program_id(0)\n cur_head = tl.program_id(1)\n start_m = tl.program_id(2)\n\n cur_kv_head = 0\n\n cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)\n prompt_cache_len = tl.load(b_prompt_cache_len + cur_batch)\n cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) - prompt_cache_len\n cur_batch_req_idx = tl.load(B_req_idx + cur_batch)\n\n block_start_loc = BLOCK_M * start_m\n\n offs_n = tl.arange(0, BLOCK_N)\n offs_d = tl.arange(0, BLOCK_DMODEL)\n offs_rope_d = tl.arange(0, BLOCK_ROPE_DMODEL)\n offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)\n off_q = (\n (cur_batch_in_all_start_index + offs_m[:, None]) * stride_q_bs\n + cur_head * stride_q_h\n + offs_d[None, :] * stride_q_d\n )\n off_q_rope = (\n (cur_batch_in_all_start_index + offs_m[:, None]) * stride_q_rope_bs\n + cur_head * stride_q_rope_h\n + offs_rope_d[None, :] * stride_q_rope_d\n )\n\n q = tl.load(Q_nope + off_q, mask=offs_m[:, None] < cur_batch_seq_len, other=0.0)\n q_rope = tl.load(Q_rope + off_q_rope, mask=offs_m[:, None] < cur_batch_seq_len, other=0.0)\n\n m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float(\"inf\")\n l_i = tl.zeros([BLOCK_M], dtype=tl.float32)\n acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)\n\n block_mask = tl.where(block_start_loc < cur_batch_seq_len, 1, 0)\n block_end_loc = tl.minimum((start_m + 1) * BLOCK_M + prompt_cache_len, cur_batch_seq_len + prompt_cache_len)\n\n for start_n in range(0, block_mask * block_end_loc, BLOCK_N):\n start_n = tl.multiple_of(start_n, BLOCK_N)\n kv_loc = tl.load(\n Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + stride_req_to_tokens_s * (start_n + offs_n),\n mask=(start_n + offs_n) < block_end_loc,\n other=0,\n )\n off_kv = kv_loc[None, :] * stride_kv_bs + cur_kv_head * stride_kv_h + offs_d[:, None] * stride_kv_d\n off_kv_rope = (\n kv_loc[None, :] * stride_kv_rope_bs\n + cur_kv_head * stride_kv_rope_h\n + offs_rope_d[:, None] * stride_kv_rope_d\n )\n kv = tl.load(KV_nope + off_kv, mask=(start_n + offs_n[None, :]) < block_end_loc, other=0.0)\n kv_rope = tl.load(KV_rope + off_kv_rope, mask=(start_n + offs_n[None, :]) < block_end_loc, other=0.0)\n\n qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)\n qk += tl.dot(q, kv)\n qk += tl.dot(q_rope, kv_rope)\n\n qk *= sm_scale\n qk = tl.where(offs_m[:, None] + prompt_cache_len >= start_n + offs_n[None, :], qk, float(\"-100000000.0\"))\n\n m_ij = tl.max(qk, 1)\n p = tl.exp(qk - m_ij[:, None])\n l_ij = tl.sum(p, 1)\n m_i_new = tl.maximum(m_i, m_ij)\n alpha = tl.exp(m_i - m_i_new)\n beta = tl.exp(m_ij - m_i_new)\n l_i_new = alpha * l_i + beta * l_ij\n p_scale = beta / l_i_new\n p = p * p_scale[:, None]\n acc_scale = l_i / l_i_new * alpha\n acc_scale = tl.where(offs_m + prompt_cache_len >= start_n, acc_scale, 1.0)\n acc = acc * acc_scale[:, None]\n v = tl.trans(kv)\n p = p.to(v.dtype)\n acc += tl.dot(p, v)\n l_i = l_i_new\n m_i = m_i_new\n\n off_o = (\n (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs\n + cur_head * stride_oh\n + offs_d[None, :] * stride_od\n )\n out_ptrs = Out + off_o\n tl.store(out_ptrs, acc, mask=offs_m[:, None] < cur_batch_seq_len)\n return\n\n@torch.no_grad()\ndef context_attention_fwd(\n q_nope,\n q_rope,\n kv_nope,\n kv_rope,\n o,\n b_req_idx,\n b_start_loc,\n b_seq_len,\n b_prompt_cache_len,\n max_input_len,\n req_to_token_indexs,\n softmax_scale,\n):\n BLOCK = 128 if not TESLA else 64\n q_nope_dim = q_nope.shape[-1]\n q_rope_dim = q_rope.shape[-1]\n assert q_nope_dim == kv_nope.shape[-1]\n assert q_rope_dim == kv_rope.shape[-1]\n assert q_nope_dim in {16, 32, 64, 128, 256, 512}\n assert q_rope_dim in {16, 32, 64, 128, 256}\n\n if q_nope_dim >= 512:\n BLOCK = 64 if not TESLA else 32\n else:\n BLOCK = 128 if not TESLA else 64\n\n sm_scale = softmax_scale\n batch, head = b_seq_len.shape[0], q_nope.shape[1]\n kv_group_num = q_nope.shape[1]\n\n grid = (batch, head, triton.cdiv(max_input_len, BLOCK))\n num_warps = 4 if q_nope_dim <= 64 else 8\n\n _fwd_kernel[grid](\n q_nope,\n q_rope,\n kv_nope,\n kv_rope,\n sm_scale,\n b_start_loc,\n b_seq_len,\n o,\n req_to_token_indexs,\n b_req_idx,\n q_nope.stride(0),\n q_nope.stride(1),\n q_nope.stride(2),\n q_rope.stride(0),\n q_rope.stride(1),\n q_rope.stride(2),\n kv_nope.stride(0),\n kv_nope.stride(1),\n kv_nope.stride(2),\n kv_rope.stride(0),\n kv_rope.stride(1),\n kv_rope.stride(2),\n o.stride(0),\n o.stride(1),\n o.stride(2),\n req_to_token_indexs.stride(0),\n req_to_token_indexs.stride(1),\n kv_group_num=kv_group_num,\n b_prompt_cache_len=b_prompt_cache_len,\n BLOCK_M=BLOCK,\n BLOCK_DMODEL=q_nope_dim,\n BLOCK_ROPE_DMODEL=q_rope_dim,\n BLOCK_N=BLOCK,\n num_warps=num_warps,\n num_stages=1,\n )\n return\n\n@triton.jit\ndef _fwd_kernel_no_prompt_cache(\n Q_nope,\n Q_rope,\n KV_nope,\n KV_rope,\n sm_scale,\n B_Start_Loc,\n B_Seqlen,\n Out,\n stride_q_bs,\n stride_q_h,\n stride_q_d,\n stride_q_rope_bs,\n stride_q_rope_h,\n stride_q_rope_d,\n stride_kv_bs,\n stride_kv_h,\n stride_kv_d,\n stride_kv_rope_bs,\n stride_kv_rope_h,\n stride_kv_rope_d,\n stride_obs,\n stride_oh,\n stride_od,\n kv_group_num,\n BLOCK_M: tl.constexpr,\n BLOCK_DMODEL: tl.constexpr,\n BLOCK_ROPE_DMODEL: tl.constexpr,\n BLOCK_N: tl.constexpr,\n):\n cur_batch = tl.program_id(0)\n cur_head = tl.program_id(1)\n start_m = tl.program_id(2)\n\n cur_kv_head = cur_head // kv_group_num\n cur_kv_head = 0\n\n cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)\n cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)\n\n block_start_loc = BLOCK_M * start_m\n\n offs_n = tl.arange(0, BLOCK_N)\n offs_d = tl.arange(0, BLOCK_DMODEL)\n offs_rope_d = tl.arange(0, BLOCK_ROPE_DMODEL)\n offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)\n off_q = (\n (cur_batch_in_all_start_index + offs_m[:, None]) * stride_q_bs\n + cur_head * stride_q_h\n + offs_d[None, :] * stride_q_d\n )\n off_rope_q = (\n (cur_batch_in_all_start_index + offs_m[:, None]) * stride_q_rope_bs\n + cur_head * stride_q_rope_h\n + offs_rope_d[None, :] * stride_q_rope_d\n )\n off_kv = offs_n[None, :] * stride_kv_bs + cur_kv_head * stride_kv_h + offs_d[:, None] * stride_kv_d\n off_rope_kv = (\n offs_n[None, :] * stride_kv_rope_bs + cur_kv_head * stride_kv_rope_h + offs_rope_d[:, None] * stride_kv_rope_d\n )\n\n q = tl.load(Q_nope + off_q, mask=offs_m[:, None] < cur_batch_seq_len, other=0.0)\n q_rope = tl.load(Q_rope + off_rope_q, mask=offs_m[:, None] < cur_batch_seq_len, other=0.0)\n\n kv_ptrs = KV_nope + off_kv\n kv_rope_ptrs = KV_rope + off_rope_kv\n\n m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float(\"inf\")\n l_i = tl.zeros([BLOCK_M], dtype=tl.float32)\n acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)\n\n block_mask = tl.where(block_start_loc < cur_batch_seq_len, 1, 0)\n\n for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N):\n start_n = tl.multiple_of(start_n, BLOCK_N)\n kv = tl.load(\n kv_ptrs + (cur_batch_in_all_start_index + start_n) * stride_kv_bs,\n mask=(start_n + offs_n[None, :]) < cur_batch_seq_len,\n other=0.0,\n )\n kv_rope = tl.load(\n kv_rope_ptrs + (cur_batch_in_all_start_index + start_n) * stride_kv_rope_bs,\n mask=(start_n + offs_n[None, :]) < cur_batch_seq_len,\n other=0.0,\n )\n\n qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)\n qk += tl.dot(q, kv)\n qk += tl.dot(q_rope, kv_rope)\n qk *= sm_scale\n qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float(\"-inf\"))\n\n m_ij = tl.max(qk, 1)\n p = tl.exp(qk - m_ij[:, None])\n l_ij = tl.sum(p, 1)\n m_i_new = tl.maximum(m_i, m_ij)\n alpha = tl.exp(m_i - m_i_new)\n beta = tl.exp(m_ij - m_i_new)\n l_i_new = alpha * l_i + beta * l_ij\n p_scale = beta / l_i_new\n p = p * p_scale[:, None]\n acc_scale = l_i / l_i_new * alpha\n acc = acc * acc_scale[:, None]\n v = tl.trans(kv)\n p = p.to(v.dtype)\n acc += tl.dot(p, v)\n l_i = l_i_new\n m_i = m_i_new\n\n off_o = (\n (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs\n + cur_head * stride_oh\n + offs_d[None, :] * stride_od\n )\n out_ptrs = Out + off_o\n tl.store(out_ptrs, acc, mask=offs_m[:, None] < cur_batch_seq_len)\n return\n\n@torch.no_grad()\ndef context_attention_fwd_no_prompt_cache(\n q_nope, q_rope, kv_nope, kv_rope, o, b_start_loc, b_seq_len, max_input_len, softmax_scale\n):\n q_nope_dim = q_nope.shape[-1]\n q_rope_dim = q_rope.shape[-1]\n assert q_nope_dim == kv_nope.shape[-1]\n assert q_rope_dim == kv_rope.shape[-1]\n assert q_nope_dim in {16, 32, 64, 128, 256, 512}\n assert q_rope_dim in {16, 32, 64, 128, 256}\n\n if q_nope_dim >= 512:\n BLOCK = 64 if not TESLA else 32\n else:\n BLOCK = 128 if not TESLA else 64\n\n sm_scale = softmax_scale\n batch, head = b_seq_len.shape[0], q_nope.shape[1]\n kv_group_num = q_nope.shape[1]\n\n grid = (batch, head, triton.cdiv(max_input_len, BLOCK))\n\n num_warps = 4 if q_nope_dim <= 64 else 8\n _fwd_kernel_no_prompt_cache[grid](\n q_nope,\n q_rope,\n kv_nope,\n kv_rope,\n sm_scale,\n b_start_loc,\n b_seq_len,\n o,\n q_nope.stride(0),\n q_nope.stride(1),\n q_nope.stride(2),\n q_rope.stride(0),\n q_rope.stride(1),\n q_rope.stride(2),\n kv_nope.stride(0),\n kv_nope.stride(1),\n kv_nope.stride(2),\n kv_rope.stride(0),\n kv_rope.stride(1),\n kv_rope.stride(2),\n o.stride(0),\n o.stride(1),\n o.stride(2),\n kv_group_num=kv_group_num,\n BLOCK_M=BLOCK,\n BLOCK_DMODEL=q_nope_dim,\n BLOCK_ROPE_DMODEL=q_rope_dim,\n BLOCK_N=BLOCK,\n num_warps=num_warps,\n num_stages=1,\n )\n return\n", - "description_1": "Use triton language to implement two forward kernels for context attention, one with prompt cache and one without. The kernels perform matrix multiplications and softmax operations on input tensors Q_nope, Q_rope, KV_nope, and KV_rope, with scaling by sm_scale. The results are stored in the output tensor Out. The kernels are called by context_attention_fwd and context_attention_fwd_no_prompt_cache functions, which set up the grid and block dimensions based on input tensor shapes and device properties.", - "description_2": "Use triton language to implement forward kernels for context attention with and without prompt cache, performing matrix multiplications and softmax operations on input tensors, and storing results in output tensor.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef _fwd_kernel_destindex_copy_kv(\n KV_nope,\n KV_rope,\n Dest_loc,\n O_nope,\n O_rope,\n stride_kv_nope_bs,\n stride_kv_nope_h,\n stride_kv_nope_d,\n stride_kv_rope_bs,\n stride_kv_rope_h,\n stride_kv_rope_d,\n stride_o_nope_bs,\n stride_o_nope_h,\n stride_o_nope_d,\n stride_o_rope_bs,\n stride_o_rope_h,\n stride_o_rope_d,\n kv_nope_head_num,\n kv_rope_head_num,\n BLOCK_DMODEL_NOPE: tl.constexpr,\n BLOCK_DMODEL_ROPE: tl.constexpr,\n BLOCK_HEAD: tl.constexpr,\n):\n cur_index = tl.program_id(0)\n offs_h = tl.arange(0, BLOCK_HEAD)\n offs_d_nope = tl.arange(0, BLOCK_DMODEL_NOPE)\n offs_d_rope = tl.arange(0, BLOCK_DMODEL_ROPE)\n\n dest_index = tl.load(Dest_loc + cur_index)\n\n kv_nope_ptrs = (\n KV_nope\n + cur_index * stride_kv_nope_bs\n + stride_kv_nope_h * offs_h[:, None]\n + stride_kv_nope_d * offs_d_nope[None, :]\n )\n kv_rope_ptrs = (\n KV_rope\n + cur_index * stride_kv_rope_bs\n + stride_kv_rope_h * offs_h[:, None]\n + stride_kv_rope_d * offs_d_rope[None, :]\n )\n\n o_nope_ptrs = (\n O_nope\n + dest_index * stride_o_nope_bs\n + stride_o_nope_h * offs_h[:, None]\n + stride_o_nope_d * offs_d_nope[None, :]\n )\n o_rope_ptrs = (\n O_rope\n + dest_index * stride_o_rope_bs\n + stride_o_rope_h * offs_h[:, None]\n + stride_o_rope_d * offs_d_rope[None, :]\n )\n\n kv_nope = tl.load(kv_nope_ptrs, mask=offs_h[:, None] < kv_nope_head_num, other=0.0)\n kv_rope = tl.load(kv_rope_ptrs, mask=offs_h[:, None] < kv_rope_head_num, other=0.0)\n\n tl.store(o_nope_ptrs, kv_nope, mask=offs_h[:, None] < kv_nope_head_num)\n tl.store(o_rope_ptrs, kv_rope, mask=offs_h[:, None] < kv_rope_head_num)\n return\n\n@torch.no_grad()\ndef destindex_copy_kv(KV_nope, KV_rope, DestLoc, O_nope, O_rope):\n seq_len = DestLoc.shape[0]\n kv_nope_head_num = KV_nope.shape[1]\n kv_rope_head_num = KV_rope.shape[1]\n\n kv_nope_head_dim = KV_nope.shape[2]\n kv_rope_head_dim = KV_rope.shape[2]\n\n assert KV_nope.shape[1] == O_nope.shape[1]\n assert KV_nope.shape[2] == O_nope.shape[2]\n assert KV_rope.shape[1] == O_rope.shape[1]\n assert KV_rope.shape[2] == O_rope.shape[2]\n\n assert _is_power_of_two(kv_nope_head_dim) and _is_power_of_two(kv_rope_head_dim)\n\n BLOCK_HEAD = triton.next_power_of_2(max(kv_nope_head_num, kv_rope_head_num))\n grid = (seq_len,)\n num_warps = 1\n\n _fwd_kernel_destindex_copy_kv[grid](\n KV_nope,\n KV_rope,\n DestLoc,\n O_nope,\n O_rope,\n KV_nope.stride(0),\n KV_nope.stride(1),\n KV_nope.stride(2),\n KV_rope.stride(0),\n KV_rope.stride(1),\n KV_rope.stride(2),\n O_nope.stride(0),\n O_nope.stride(1),\n O_nope.stride(2),\n O_rope.stride(0),\n O_rope.stride(1),\n O_rope.stride(2),\n kv_nope_head_num,\n kv_rope_head_num,\n BLOCK_DMODEL_NOPE=kv_nope_head_dim,\n BLOCK_DMODEL_ROPE=kv_rope_head_dim,\n BLOCK_HEAD=BLOCK_HEAD,\n num_warps=num_warps,\n num_stages=1,\n )\n return\n", - "description_1": "Use triton language to implement a kernel function '_fwd_kernel_destindex_copy_kv' and a calling function 'destindex_copy_kv'. The kernel function takes 19 arguments: two 3D tensors 'KV_nope' and 'KV_rope' representing input data, a 1D tensor 'Dest_loc' representing destination indices, two 3D tensors 'O_nope' and 'O_rope' for output, and several strides, head counts, and block sizes as input. The purpose of this function is to copy and transform data from input tensors 'KV_nope' and 'KV_rope' to output tensors 'O_nope' and 'O_rope' based on the destination indices, with strides controlling the data reading and writing process. The calling function 'destindex_copy_kv' uses torch to handle input parameters and configure the grid and launch parameters for the Triton kernel call.", - "description_2": "Use triton language to implement a kernel that reads data from input tensors with specified strides and writes it to output tensors based on a given index map, optimizing for parallel execution using triton.jit.", - "difficulty": 3 - }, - { - "code": "import triton\nimport triton.language as tl\nimport torch\n\n@triton.jit\ndef _fwd_kernel_flash_decode_stage1(\n Q_nope,\n Q_rope,\n KV_nope,\n KV_rope,\n sm_scale,\n Req_to_tokens,\n B_req_idx,\n B_Seqlen,\n Mid_O,\n Mid_O_LogExpSum,\n stride_req_to_tokens_b,\n stride_req_to_tokens_s,\n stride_q_bs,\n stride_q_h,\n stride_q_d,\n stride_q_rope_bs,\n stride_q_rope_h,\n stride_q_rope_d,\n stride_kv_bs,\n stride_kv_h,\n stride_kv_d,\n stride_kv_rope_bs,\n stride_kv_rope_h,\n stride_kv_rope_d,\n stride_mid_ob,\n stride_mid_oh,\n stride_mid_os,\n stride_mid_od,\n stride_mid_o_eb,\n stride_mid_o_eh,\n stride_mid_o_es,\n BLOCK_SEQ: tl.constexpr,\n BLOCK_DMODEL: tl.constexpr,\n BLOCK_ROPE_DMODEL: tl.constexpr,\n BLOCK_N: tl.constexpr,\n):\n cur_batch = tl.program_id(0)\n cur_head = tl.program_id(1)\n seq_start_block = tl.program_id(2)\n cur_kv_head = 0\n\n offs_d = tl.arange(0, BLOCK_DMODEL)\n offs_rope_d = tl.arange(0, BLOCK_ROPE_DMODEL)\n cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)\n cur_batch_req_idx = tl.load(B_req_idx + cur_batch)\n cur_batch_start_index = seq_start_block * BLOCK_SEQ\n cur_batch_end_index = tl.minimum(cur_batch_seq_len, cur_batch_start_index + BLOCK_SEQ)\n\n off_q = cur_batch * stride_q_bs + cur_head * stride_q_h + offs_d\n off_q_rope = cur_batch * stride_q_rope_bs + cur_head * stride_q_rope_h + offs_rope_d\n\n block_n_size = (\n tl.where(\n cur_batch_end_index - cur_batch_start_index <= 0,\n 0,\n cur_batch_end_index - cur_batch_start_index + BLOCK_N - 1,\n )\n // BLOCK_N\n )\n\n offs_n = cur_batch_start_index + tl.arange(0, BLOCK_N)\n\n q = tl.load(Q_nope + off_q)\n q_rope = tl.load(Q_rope + off_q_rope)\n\n sum_exp = 0.0\n max_logic = -float(\"inf\")\n acc = tl.zeros([BLOCK_DMODEL], dtype=tl.float32)\n\n for start_n in range(0, block_n_size, 1):\n offs_n_new = start_n * BLOCK_N + offs_n\n kv_loc = tl.load(\n Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + offs_n_new,\n mask=offs_n_new < cur_batch_end_index,\n other=0,\n )\n off_kv = kv_loc[:, None] * stride_kv_bs + cur_kv_head * stride_kv_h + offs_d[None, :]\n off_kv_rope = kv_loc[:, None] * stride_kv_rope_bs + cur_kv_head * stride_kv_rope_h + offs_rope_d[None, :]\n kv = tl.load(KV_nope + off_kv, mask=offs_n_new[:, None] < cur_batch_end_index, other=0.0)\n kv_rope = tl.load(KV_rope + off_kv_rope, mask=offs_n_new[:, None] < cur_batch_end_index, other=0.0)\n att_value = tl.sum(q[None, :] * kv, 1)\n att_value += tl.sum(q_rope[None, :] * kv_rope, 1)\n att_value *= sm_scale\n att_value = tl.where(offs_n_new < cur_batch_end_index, att_value, float(\"-inf\"))\n v = kv\n\n cur_max_logic = tl.max(att_value, axis=0)\n new_max_logic = tl.maximum(cur_max_logic, max_logic)\n\n exp_logic = tl.exp(att_value - new_max_logic)\n logic_scale = tl.exp(max_logic - new_max_logic)\n acc *= logic_scale\n acc += tl.sum(exp_logic[:, None] * v, axis=0)\n\n sum_exp = sum_exp * logic_scale + tl.sum(exp_logic, axis=0)\n max_logic = new_max_logic\n\n need_store = tl.where(block_n_size == 0, 0, 1)\n for _ in range(0, need_store, 1):\n off_mid_o = cur_batch * stride_mid_ob + cur_head * stride_mid_oh + seq_start_block * stride_mid_os + offs_d\n off_mid_o_logexpsum = cur_batch * stride_mid_o_eb + cur_head * stride_mid_o_eh + seq_start_block\n tl.store(Mid_O + off_mid_o, acc / sum_exp)\n tl.store(Mid_O_LogExpSum + off_mid_o_logexpsum, max_logic + tl.log(sum_exp))\n return\n\n\n@torch.no_grad()\ndef flash_decode_stage1(\n q_nope,\n q_rope,\n kv_nope,\n kv_rope,\n Req_to_tokens,\n B_req_idx,\n B_Seqlen,\n max_len_in_batch,\n mid_out,\n mid_out_logsumexp,\n block_seq,\n qk_nope_head_dim,\n softmax_scale,\n):\n BLOCK_SEQ = block_seq\n BLOCK_N = 16\n assert BLOCK_SEQ % BLOCK_N == 0\n # shape constraints\n q_nope_dim = q_nope.shape[-1]\n q_rope_dim = q_rope.shape[-1]\n assert q_nope_dim == kv_nope.shape[-1]\n assert q_rope_dim == kv_rope.shape[-1]\n assert q_nope_dim in {16, 32, 64, 128, 256, 512}\n assert q_rope_dim in {16, 32, 64, 128, 256}\n\n sm_scale = softmax_scale # 计算scale系数\n batch, head_num = B_req_idx.shape[0], q_nope.shape[1]\n grid = (batch, head_num, triton.cdiv(max_len_in_batch, BLOCK_SEQ))\n\n _fwd_kernel_flash_decode_stage1[grid](\n q_nope,\n q_rope,\n kv_nope,\n kv_rope,\n sm_scale,\n Req_to_tokens,\n B_req_idx,\n B_Seqlen,\n mid_out,\n mid_out_logsumexp,\n Req_to_tokens.stride(0),\n Req_to_tokens.stride(1),\n q_nope.stride(0),\n q_nope.stride(1),\n q_nope.stride(2),\n q_rope.stride(0),\n q_rope.stride(1),\n q_rope.stride(2),\n kv_nope.stride(0),\n kv_nope.stride(1),\n kv_nope.stride(2),\n kv_rope.stride(0),\n kv_rope.stride(1),\n kv_rope.stride(2),\n mid_out.stride(0),\n mid_out.stride(1),\n mid_out.stride(2),\n mid_out.stride(3),\n mid_out_logsumexp.stride(0),\n mid_out_logsumexp.stride(1),\n mid_out_logsumexp.stride(2),\n BLOCK_SEQ=BLOCK_SEQ,\n BLOCK_DMODEL=q_nope_dim,\n BLOCK_ROPE_DMODEL=q_rope_dim,\n BLOCK_N=BLOCK_N,\n num_warps=1,\n num_stages=2,\n )\n return\n", - "description_1": "Use triton language to implement a kernel function '_fwd_kernel_flash_decode_stage1' with 30 tensor parameters and 4 constexpr parameters, which performs a series of tensor operations including loading, arithmetic operations, and storing results. The function is called by 'flash_decode_stage1', which prepares the grid and block dimensions and passes the necessary parameters to the kernel.", - "description_2": "Use triton language to create a kernel for tensor operations with 30 tensor parameters and 4 constexpr parameters, and a wrapper function to set up and call this kernel.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef _fwd_kernel_flash_decode_stage2(\n B_Seqlen,\n Mid_O, # [batch, head, seq_block_num, head_dim]\n Mid_O_LogExpSum, # [batch, head, seq_block_num]\n output, # [batch, head, head_dim]\n stride_mid_ob,\n stride_mid_oh,\n stride_mid_os,\n stride_mid_od,\n stride_mid_o_eb,\n stride_mid_o_eh,\n stride_mid_o_es,\n stride_obs,\n stride_oh,\n stride_od,\n BLOCK_SEQ: tl.constexpr,\n BLOCK_DMODEL: tl.constexpr,\n):\n cur_batch = tl.program_id(0)\n cur_head = tl.program_id(1)\n\n offs_d = tl.arange(0, BLOCK_DMODEL)\n cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)\n\n block_n_size = tl.where(cur_batch_seq_len <= 0, 0, cur_batch_seq_len + BLOCK_SEQ - 1) // BLOCK_SEQ\n\n sum_exp = 0.0\n max_logic = -float(\"inf\")\n acc = tl.zeros([BLOCK_DMODEL], dtype=tl.float32)\n\n offs_v = cur_batch * stride_mid_ob + cur_head * stride_mid_oh + offs_d\n offs_logic = cur_batch * stride_mid_o_eb + cur_head * stride_mid_o_eh\n for block_seq_n in range(0, block_n_size, 1):\n tv = tl.load(Mid_O + offs_v + block_seq_n * stride_mid_os)\n tlogic = tl.load(Mid_O_LogExpSum + offs_logic + block_seq_n)\n new_max_logic = tl.maximum(tlogic, max_logic)\n\n old_scale = tl.exp(max_logic - new_max_logic)\n acc *= old_scale\n exp_logic = tl.exp(tlogic - new_max_logic)\n acc += exp_logic * tv\n sum_exp = sum_exp * old_scale + exp_logic\n max_logic = new_max_logic\n\n tl.store(output + cur_batch * stride_obs + cur_head * stride_oh + offs_d, acc / sum_exp)\n return\n\n@torch.no_grad()\ndef flash_decode_stage2(mid_out, mid_out_logexpsum, B_Seqlen, output, block_seq):\n Lk = mid_out.shape[-1]\n assert Lk in {16, 32, 64, 128, 256, 512}\n batch, head_num = mid_out.shape[0], mid_out.shape[1]\n grid = (batch, head_num)\n\n _fwd_kernel_flash_decode_stage2[grid](\n B_Seqlen,\n mid_out,\n mid_out_logexpsum,\n output,\n mid_out.stride(0),\n mid_out.stride(1),\n mid_out.stride(2),\n mid_out.stride(3),\n mid_out_logexpsum.stride(0),\n mid_out_logexpsum.stride(1),\n mid_out_logexpsum.stride(2),\n output.stride(0),\n output.stride(1),\n output.stride(2),\n BLOCK_SEQ=block_seq,\n BLOCK_DMODEL=Lk,\n num_warps=4,\n num_stages=2,\n )\n return\n", - "description_1": "Use triton language to implement a kernel function '_fwd_kernel_flash_decode_stage2' with 16 parameters for decoding operations. The kernel processes input tensors 'Mid_O' and 'Mid_O_LogExpSum' based on batch and head dimensions, computes scaled values, and stores the result in 'output'. The function 'flash_decode_stage2' is a wrapper that sets up the grid and calls the kernel with 13 parameters, including tensor strides and block sizes.", - "description_2": "Use triton language to create a kernel for decoding operations with input tensors, compute scaled values, and store results. Implement a wrapper function to configure and invoke the kernel.", - "difficulty": 3 - }, - { - "code": "import triton\nimport triton.language as tl\n\n@triton.jit\ndef tanh(x):\n # Tanh is just a scaled sigmoid\n return 2 * tl.sigmoid(2 * x) - 1\n\n@triton.jit\ndef gelu(x):\n \"\"\"\n GeLU_ activation - Gaussian error linear unit\n\n .. _GeLU: https://arxiv.org/pdf/1606.08415.pdf\n \"\"\"\n return 0.5 * x * (1 + tanh(_kAlpha * (x + 0.044715 * x * x * x)))\n\n@triton.jit\ndef _gelu_and_mul_kernel(\n input_ptr,\n stride_input_m,\n stride_input_n,\n stride_output_m,\n stride_output_n,\n size_m,\n size_n,\n BLOCK_M: tl.constexpr,\n BLOCK_N: tl.constexpr,\n):\n tid = tl.program_id(0)\n input_m_offsets = tid * BLOCK_M + tl.arange(0, BLOCK_M)\n output_m_offsets = tid * BLOCK_M + tl.arange(0, BLOCK_M)\n\n pid = tl.program_id(1)\n input_n_offsets = pid * BLOCK_N + tl.arange(0, BLOCK_N)\n output_n_offsets = pid * BLOCK_N + tl.arange(0, BLOCK_N)\n\n up_offsets = input_m_offsets[:, None] * stride_input_m + (input_n_offsets[None, :] + size_n) * stride_input_n\n gate_offsets = input_m_offsets[:, None] * stride_input_m + input_n_offsets[None, :] * stride_input_n\n res_offsets = output_m_offsets[:, None] * stride_output_m + output_n_offsets[None, :] * stride_output_n\n\n up = tl.load(\n input_ptr + up_offsets,\n mask=(input_n_offsets < size_n)[None, :] * (input_m_offsets < size_m)[:, None],\n other=0.0,\n )\n gate = tl.load(\n input_ptr + gate_offsets,\n mask=(input_n_offsets < size_n)[None, :] * (input_m_offsets < size_m)[:, None],\n other=0.0,\n ).to(tl.float32)\n\n gate = gelu(gate)\n gate = gate.to(input_ptr.dtype.element_ty)\n\n tl.store(\n input_ptr + res_offsets,\n up * gate,\n mask=(output_n_offsets < size_n)[None, :] * (output_m_offsets < size_m)[:, None],\n )\n\n\ndef gelu_and_mul_fwd(input):\n stride_input_m = input.stride(0)\n stride_input_n = input.stride(1)\n stride_output_m = input.stride(0)\n stride_output_n = input.stride(1)\n size_m = input.shape[0]\n size_n = input.shape[-1] // 2\n BLOCK_M = 128\n BLOCK_N = 128\n grid = (\n triton.cdiv(size_m, BLOCK_M),\n triton.cdiv(size_n, BLOCK_N),\n )\n _gelu_and_mul_kernel[grid](\n input,\n stride_input_m,\n stride_input_n,\n stride_output_m,\n stride_output_n,\n size_m,\n size_n,\n BLOCK_M,\n BLOCK_N,\n )\n return input[:, 0 : (input.shape[-1] // 2)]\n", - "description_1": "Use triton language to implement a GeLU activation function and a kernel that applies GeLU and element-wise multiplication on a 2D input tensor. The kernel function '_gelu_and_mul_kernel' takes 10 parameters: input_ptr (pointer to input data), stride_input_m (stride for input rows), stride_input_n (stride for input columns), stride_output_m (stride for output rows), stride_output_n (stride for output columns), size_m (number of rows), size_n (number of columns divided by 2), BLOCK_M (block size for rows), and BLOCK_N (block size for columns). The function 'gelu_and_mul_fwd' prepares the input tensor and calls the kernel with appropriate grid dimensions.", - "description_2": "Use triton language to create a kernel that performs GeLU activation and element-wise multiplication on a 2D tensor, with parameters for input/output strides, sizes, and block dimensions.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n\n@triton.jit\ndef _fwd_kernel(\n Q,\n K,\n V,\n sm_scale,\n Out,\n B_Start_Loc,\n B_Seqlen,\n Req_to_tokens,\n B_req_idx,\n stride_qbs,\n stride_qh,\n stride_qd,\n stride_kbs,\n stride_kh,\n stride_kd,\n stride_vbs,\n stride_vh,\n stride_vd,\n stride_obs,\n stride_oh,\n stride_od,\n stride_req_to_tokens_b,\n stride_req_to_tokens_s,\n kv_group_num,\n b_prompt_cache_len,\n H: tl.constexpr,\n BLOCK_DMODEL: tl.constexpr,\n BLOCK_M: tl.constexpr,\n BLOCK_N: tl.constexpr,\n):\n start_m = tl.program_id(0)\n cur_bh = tl.program_id(1)\n cur_batch = cur_bh // H\n cur_head = cur_bh % H\n\n cur_kv_head = cur_head // kv_group_num\n\n cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)\n prompt_cache_len = tl.load(b_prompt_cache_len + cur_batch)\n cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) - prompt_cache_len\n cur_batch_req_idx = tl.load(B_req_idx + cur_batch)\n\n block_start_loc = BLOCK_M * start_m\n\n # initialize offsets\n offs_n = tl.arange(0, BLOCK_N)\n offs_d = tl.arange(0, BLOCK_DMODEL)\n offs_m = block_start_loc + tl.arange(0, BLOCK_M)\n off_q = (\n (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs\n + cur_head * stride_qh\n + offs_d[None, :] * stride_qd\n )\n\n q = tl.load(Q + off_q, mask=offs_m[:, None] < cur_batch_seq_len, other=0.0)\n\n # initialize pointer to m and l\n m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float(\"inf\")\n l_i = tl.zeros([BLOCK_M], dtype=tl.float32)\n acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)\n\n block_mask = tl.where(block_start_loc < cur_batch_seq_len, 1, 0)\n block_end_loc = tl.minimum(block_start_loc + BLOCK_M + prompt_cache_len, cur_batch_seq_len + prompt_cache_len)\n\n # causal mask\n for start_n in range(0, block_mask * block_end_loc, BLOCK_N):\n start_n = tl.multiple_of(start_n, BLOCK_N)\n # -- compute qk ----\n kv_loc = tl.load(\n Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + stride_req_to_tokens_s * (start_n + offs_n),\n mask=(start_n + offs_n) < block_end_loc,\n other=0,\n )\n off_k = kv_loc[None, :] * stride_kbs + cur_kv_head * stride_kh + offs_d[:, None] * stride_kd\n k = tl.load(K + off_k, mask=(start_n + offs_n[None, :]) < block_end_loc, other=0.0)\n qk = tl.dot(q, k)\n\n mask = offs_m[:, None] + prompt_cache_len >= (start_n + offs_n[None, :])\n qk = tl.where(mask, qk * sm_scale, -1.0e8)\n m_ij = tl.maximum(m_i, tl.max(qk, 1))\n qk -= m_ij[:, None]\n p = tl.math.exp2(qk)\n l_ij = tl.sum(p, 1)\n\n # -- update m_i and l_i\n alpha = tl.math.exp2(m_i - m_ij)\n l_i = l_i * alpha + l_ij\n # -- update output accumulator --\n acc = acc * alpha[:, None]\n # update acc\n off_v = kv_loc[:, None] * stride_vbs + cur_kv_head * stride_vh + offs_d[None, :] * stride_vd\n v = tl.load(V + off_v, mask=(start_n + offs_n[:, None]) < block_end_loc, other=0.0)\n p = p.to(v.dtype)\n acc = tl.dot(p, v, acc)\n # update m_i and l_i\n m_i = m_ij\n\n acc = acc / l_i[:, None]\n off_o = (\n (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs\n + cur_head * stride_oh\n + offs_d[None, :] * stride_od\n )\n out_ptrs = Out + off_o\n tl.store(out_ptrs, acc, mask=offs_m[:, None] < cur_batch_seq_len)\n\n\n@torch.no_grad()\ndef context_attention_fwd(\n q, k, v, o, b_req_idx, b_start_loc, b_seq_len, b_prompt_cache_len, max_input_len, req_to_token_indexs\n):\n BLOCK_M = 128 if not TESLA else 64\n Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]\n assert Lq == Lk and Lk == Lv\n assert Lk in {16, 32, 64, 128, 256}\n\n sm_scale = 1.0 / (Lq ** 0.5) * 1.4426950408889634\n batch, head = b_seq_len.shape[0], q.shape[1]\n kv_group_num = q.shape[1] // k.shape[1]\n\n grid = lambda meta: (triton.cdiv(max_input_len, meta[\"BLOCK_M\"]), batch * head, 1)\n\n BLOCK_N = BLOCK_M\n num_warps = 4 if Lk <= 64 else 8\n num_stages = 1\n\n _fwd_kernel[grid](\n q,\n k,\n v,\n sm_scale,\n o,\n b_start_loc,\n b_seq_len,\n req_to_token_indexs,\n b_req_idx,\n q.stride(0),\n q.stride(1),\n q.stride(2),\n k.stride(0),\n k.stride(1),\n k.stride(2),\n v.stride(0),\n v.stride(1),\n v.stride(2),\n o.stride(0),\n o.stride(1),\n o.stride(2),\n req_to_token_indexs.stride(0),\n req_to_token_indexs.stride(1),\n kv_group_num=kv_group_num,\n b_prompt_cache_len=b_prompt_cache_len,\n H=head,\n BLOCK_DMODEL=Lk,\n BLOCK_M=BLOCK_M,\n BLOCK_N=BLOCK_N,\n num_warps=num_warps,\n num_stages=num_stages,\n )\n\n\n@triton.jit\ndef _fwd_kernel_no_prompt_cache(\n Q,\n K,\n V,\n sm_scale,\n Out,\n B_Start_Loc,\n B_Seqlen,\n stride_qbs,\n stride_qh,\n stride_qd,\n stride_kbs,\n stride_kh,\n stride_kd,\n stride_vbs,\n stride_vh,\n stride_vd,\n stride_obs,\n stride_oh,\n stride_od,\n kv_group_num,\n H,\n BLOCK_DMODEL: tl.constexpr,\n BLOCK_M: tl.constexpr,\n BLOCK_N: tl.constexpr,\n):\n start_m = tl.program_id(0)\n cur_bh = tl.program_id(1)\n cur_batch = cur_bh // H\n cur_head = cur_bh % H\n\n cur_kv_head = cur_head // kv_group_num\n\n cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)\n cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)\n\n block_start_loc = BLOCK_M * start_m\n\n # initialize offsets\n offs_n = tl.arange(0, BLOCK_N)\n offs_d = tl.arange(0, BLOCK_DMODEL)\n offs_m = block_start_loc + tl.arange(0, BLOCK_M)\n off_q = (\n (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs\n + cur_head * stride_qh\n + offs_d[None, :] * stride_qd\n )\n off_k = offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh + offs_d[:, None] * stride_kd\n off_v = offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh + offs_d[None, :] * stride_vd\n\n q = tl.load(Q + off_q, mask=offs_m[:, None] < cur_batch_seq_len, other=0.0)\n\n k_ptrs = K + off_k\n v_ptrs = V + off_v\n\n # initialize pointer to m and l\n m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float(\"inf\")\n l_i = tl.zeros([BLOCK_M], dtype=tl.float32)\n acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)\n\n block_mask = tl.where(block_start_loc < cur_batch_seq_len, 1, 0)\n block_end_loc = tl.minimum(block_start_loc + BLOCK_M, cur_batch_seq_len)\n\n # causal mask\n for start_n in range(0, block_mask * block_end_loc, BLOCK_N):\n start_n = tl.multiple_of(start_n, BLOCK_N)\n # -- compute qk ----\n k = tl.load(\n k_ptrs + (cur_batch_in_all_start_index + start_n) * stride_kbs,\n mask=(start_n + offs_n[None, :]) < block_end_loc,\n other=0,\n )\n qk = tl.dot(q, k)\n\n mask = offs_m[:, None] >= (start_n + offs_n[None, :])\n qk = tl.where(mask, qk * sm_scale, -1.0e8)\n m_ij = tl.maximum(m_i, tl.max(qk, 1))\n qk -= m_ij[:, None]\n p = tl.math.exp2(qk)\n l_ij = tl.sum(p, 1)\n\n # -- update m_i and l_i\n alpha = tl.math.exp2(m_i - m_ij)\n l_i = l_i * alpha + l_ij\n # -- update output accumulator --\n acc = acc * alpha[:, None]\n # update acc\n v = tl.load(\n v_ptrs + (cur_batch_in_all_start_index + start_n) * stride_vbs,\n mask=(start_n + offs_n[:, None]) < block_end_loc,\n other=0.0,\n )\n p = p.to(v.dtype)\n acc = tl.dot(p, v, acc)\n # update m_i\n m_i = m_ij\n\n acc = acc / l_i[:, None]\n off_o = (\n (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs\n + cur_head * stride_oh\n + offs_d[None, :] * stride_od\n )\n out_ptrs = Out + off_o\n tl.store(out_ptrs, acc, mask=offs_m[:, None] < cur_batch_seq_len)\n\n\n@torch.no_grad()\ndef context_attention_fwd_no_prompt_cache(q, k, v, o, b_start_loc, b_seq_len, max_input_len):\n BLOCK_M = 128 if not TESLA else 64\n Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]\n assert Lq == Lk and Lk == Lv\n assert Lk in {16, 32, 64, 128, 256}\n\n sm_scale = 1.0 / (Lq ** 0.5) * 1.4426950408889634\n batch, head = b_seq_len.shape[0], q.shape[1]\n kv_group_num = q.shape[1] // k.shape[1]\n\n grid = (triton.cdiv(max_input_len, BLOCK_M), batch * head, 1)\n BLOCK_N = BLOCK_M\n num_warps = 4 if Lk <= 64 else 8\n num_stages = 1\n\n _fwd_kernel_no_prompt_cache[grid](\n q,\n k,\n v,\n sm_scale,\n o,\n b_start_loc,\n b_seq_len,\n q.stride(0),\n q.stride(1),\n q.stride(2),\n k.stride(0),\n k.stride(1),\n k.stride(2),\n v.stride(0),\n v.stride(1),\n v.stride(2),\n o.stride(0),\n o.stride(1),\n o.stride(2),\n kv_group_num=kv_group_num,\n H=head,\n BLOCK_DMODEL=Lk,\n BLOCK_M=BLOCK_M,\n BLOCK_N=BLOCK_N,\n num_warps=num_warps,\n num_stages=num_stages,\n )\n return\n", - "description_1": "Use triton language to implement two kernels for fused attention: `_fwd_kernel` and `_fwd_kernel_no_prompt_cache`. Each kernel is invoked by its corresponding wrapper function `context_attention_fwd` or `context_attention_fwd_no_prompt_cache`. The kernels process input queries (Q), keys (K), values (V), and produce an output tensor (Out). Inputs include scaling factors, batch and sequence information, strides for indexing, and other parameters for managing attention mechanisms. The kernels handle masking, accumulation, and normalization to perform attention computation efficiently.", - "description_2": "Use triton language to implement fused attention kernels that perform multi-head scaled dot-product attention with optional prompt caching, using block-wise operations and ensuring efficient computation with specified constraints.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef _fwd_kernel_flash_decode_stage1(\n Q, K, V, sm_scale, Req_to_tokens, B_req_idx, B_Seqlen,\n Mid_O, # [batch, head, seq_block_num, head_dim]\n Mid_O_LogExpSum, #[batch, head, seq_block_num]\n stride_req_to_tokens_b, stride_req_to_tokens_s,\n stride_qbs, stride_qh, stride_qd,\n stride_kbs, stride_kh, stride_kd,\n stride_vbs, stride_vh, stride_vd,\n stride_mid_ob, stride_mid_oh, stride_mid_os, stride_mid_od,\n stride_mid_o_eb, stride_mid_o_eh, stride_mid_o_es,\n gqa_group_size,\n BLOCK_SEQ: tl.constexpr, \n BLOCK_DMODEL: tl.constexpr,\n BLOCK_N: tl.constexpr\n):\n cur_batch = tl.program_id(0)\n cur_head = tl.program_id(1)\n seq_start_block = tl.program_id(2)\n cur_kv_head = cur_head // gqa_group_size\n\n offs_d = tl.arange(0, BLOCK_DMODEL)\n cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)\n cur_batch_req_idx = tl.load(B_req_idx + cur_batch)\n cur_batch_start_index = seq_start_block * BLOCK_SEQ\n cur_batch_end_index = tl.minimum(cur_batch_seq_len, cur_batch_start_index + BLOCK_SEQ)\n\n off_q = cur_batch * stride_qbs + cur_head * stride_qh + offs_d\n \n block_n_size = tl.where(cur_batch_end_index - cur_batch_start_index <= 0, 0, cur_batch_end_index - cur_batch_start_index + BLOCK_N - 1) // BLOCK_N\n \n offs_n = cur_batch_start_index + tl.arange(0, BLOCK_N)\n \n q = tl.load(Q + off_q)\n\n sum_exp = 0.0\n max_logic = -float(\"inf\")\n acc = tl.zeros([BLOCK_DMODEL], dtype=tl.float32)\n\n for start_n in range(0, block_n_size, 1):\n offs_n_new = start_n * BLOCK_N + offs_n\n k_loc = tl.load(Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + offs_n_new, mask=offs_n_new < cur_batch_end_index, other=0)\n off_k = k_loc[:, None] * stride_kbs + cur_kv_head * stride_kh + offs_d[None, :]\n k = tl.load(K + off_k, mask=offs_n_new[:, None] < cur_batch_end_index, other=0.0)\n att_value = tl.sum(q[None, :] * k, 1)\n att_value *= sm_scale\n att_value = tl.where(offs_n_new < cur_batch_end_index, att_value, float(\"-inf\"))\n v = tl.load(V + off_k, mask=offs_n_new[:, None] < cur_batch_end_index, other=0.0)\n \n cur_max_logic = tl.max(att_value, axis=0)\n new_max_logic = tl.maximum(cur_max_logic, max_logic)\n\n exp_logic = tl.exp(att_value - new_max_logic)\n logic_scale = tl.exp(max_logic - new_max_logic)\n acc *= logic_scale\n acc += tl.sum(exp_logic[:, None] * v, axis=0)\n\n sum_exp = sum_exp * logic_scale + tl.sum(exp_logic, axis=0)\n max_logic = new_max_logic\n \n need_store = tl.where(block_n_size == 0, 0, 1)\n for _ in range(0, need_store, 1):\n off_mid_o = cur_batch * stride_mid_ob + cur_head * stride_mid_oh + seq_start_block * stride_mid_os + offs_d\n off_mid_o_logexpsum = cur_batch * stride_mid_o_eb + cur_head * stride_mid_o_eh + seq_start_block\n tl.store(Mid_O + off_mid_o, acc / sum_exp)\n tl.store(Mid_O_LogExpSum + off_mid_o_logexpsum, max_logic + tl.log(sum_exp))\n return\n\n\n@torch.no_grad()\ndef flash_decode_stage1(q, k, v, Req_to_tokens, B_req_idx, B_Seqlen, max_len_in_batch, mid_out, mid_out_logsumexp, block_seq):\n BLOCK_SEQ = block_seq\n BLOCK_N = 16\n assert BLOCK_SEQ % BLOCK_N == 0\n # shape constraints\n Lq, Lk = q.shape[-1], k.shape[-1]\n assert Lq == Lk\n assert Lk in {16, 32, 64, 128}\n sm_scale = 1.0 / (Lk ** 0.5)\n batch, head_num = B_req_idx.shape[0], q.shape[1]\n grid = (batch, head_num, triton.cdiv(max_len_in_batch, BLOCK_SEQ))\n gqa_group_size = q.shape[1] // k.shape[1]\n \n _fwd_kernel_flash_decode_stage1[grid](\n q, k, v, sm_scale, Req_to_tokens, B_req_idx, B_Seqlen,\n mid_out,\n mid_out_logsumexp,\n Req_to_tokens.stride(0), Req_to_tokens.stride(1),\n q.stride(0), q.stride(1), q.stride(2),\n k.stride(0), k.stride(1), k.stride(2),\n v.stride(0), v.stride(1), v.stride(2),\n mid_out.stride(0), mid_out.stride(1), mid_out.stride(2), mid_out.stride(3),\n mid_out_logsumexp.stride(0), mid_out_logsumexp.stride(1), mid_out_logsumexp.stride(2),\n gqa_group_size,\n BLOCK_SEQ=BLOCK_SEQ,\n BLOCK_DMODEL=Lk,\n BLOCK_N=BLOCK_N,\n num_warps=1,\n num_stages=2,\n )\n return\n", - "description_1": "Use triton language to implement a kernel function '_fwd_kernel_flash_decode_stage1' with 24 parameters for performing a flash attention-like operation. The kernel computes scaled dot-product attention over a sequence of blocks, handling multiple heads and batches. It uses triton's parallel programming model to efficiently load and process data in blocks, applying softmax scaling and storing results. The function 'flash_decode_stage1' is a wrapper that sets up the grid and block sizes, and calls the kernel with appropriate parameters.", - "description_2": "Use triton language to implement a flash attention-like operation with a kernel function that computes scaled dot-product attention over sequence blocks, handling multiple heads and batches efficiently.", - "difficulty": 4 - }, - { - "code": "import triton\nimport triton.language as tl\nimport torch\n\n# Kernel for forward pass of GQA decode attention mechanism\n@triton.jit\ndef _fwd_kernel(\n Q, K, V, sm_scale, Req_to_tokens, B_req_idx, B_seqlen, Out,\n stride_qbs, stride_qh, stride_qd, stride_kbs, stride_kh, stride_kd,\n stride_vbs, stride_vh, stride_vd, stride_obs, stride_oh, stride_od,\n stride_req_to_tokens_b, stride_req_to_tokens_s, kv_group_num,\n Q_HEAD_NUM: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr,\n):\n cur_batch = tl.program_id(0)\n cur_kv_head = tl.program_id(1)\n cur_q_head_offs = tl.arange(0, Q_HEAD_NUM)\n cur_batch_req_idx = tl.load(B_req_idx + cur_batch)\n cur_batch_seq_len = tl.load(B_seqlen + cur_batch)\n offs_n = tl.arange(0, BLOCK_N)\n offs_d = tl.arange(0, BLOCK_DMODEL)\n cur_q_head_range = cur_kv_head * kv_group_num + cur_q_head_offs\n off_q = cur_batch * stride_qbs + cur_q_head_range[:, None] * stride_qh + offs_d[None, :]\n off_k = cur_kv_head * stride_kh + offs_d[:, None]\n off_v = cur_kv_head * stride_vh + offs_d[None, :]\n q = tl.load(Q + off_q, mask=cur_q_head_range[:, None] < (cur_kv_head + 1) * kv_group_num, other=0.0)\n k_ptrs = K + off_k\n v_ptrs = V + off_v\n m_i = tl.zeros([Q_HEAD_NUM], dtype=tl.float32) - float(\"inf\")\n l_i = tl.zeros([Q_HEAD_NUM], dtype=tl.float32)\n acc = tl.zeros([Q_HEAD_NUM, BLOCK_DMODEL], dtype=tl.float32)\n\n for start_n in range(0, cur_batch_seq_len, BLOCK_N):\n start_n = tl.multiple_of(start_n, BLOCK_N)\n kv_loc = tl.load(\n Req_to_tokens + cur_batch_req_idx * stride_req_to_tokens_b + start_n + offs_n,\n mask=(start_n + offs_n) < cur_batch_seq_len,\n other=0,\n )\n k = tl.load(\n k_ptrs + kv_loc[None, :] * stride_kbs, mask=(start_n + offs_n[None, :]) < cur_batch_seq_len, other=0.0\n )\n qk = tl.zeros([Q_HEAD_NUM, BLOCK_N], dtype=tl.float32)\n qk += tl.dot(q, k)\n qk *= sm_scale\n qk = tl.where(cur_batch_seq_len - 1 >= (start_n + offs_n[None, :]), qk, float(\"-inf\"))\n m_ij = tl.max(qk, 1)\n p = tl.exp(qk - m_ij[:, None])\n l_ij = tl.sum(p, 1)\n m_i_new = tl.maximum(m_i, m_ij)\n alpha = tl.exp(m_i - m_i_new)\n beta = tl.exp(m_ij - m_i_new)\n l_i_new = alpha * l_i + beta * l_ij\n p_scale = beta / l_i_new\n p = p * p_scale[:, None]\n acc_scale = l_i / l_i_new * alpha\n acc = acc * acc_scale[:, None]\n v = tl.load(\n v_ptrs + kv_loc[:, None] * stride_vbs, mask=(start_n + offs_n[:, None]) < cur_batch_seq_len, other=0.0\n )\n p = p.to(v.dtype)\n acc += tl.dot(p, v)\n l_i = l_i_new\n m_i = m_i_new\n\n off_o = cur_batch * stride_obs + cur_q_head_range[:, None] * stride_oh + offs_d[None, :]\n out_ptrs = Out + off_o\n tl.store(out_ptrs, acc, mask=cur_q_head_range[:, None] < (cur_kv_head + 1) * kv_group_num)\n return\n\n# Wrapper for calling the kernel\n@torch.no_grad()\ndef gqa_decode_attention_fwd(q, k, v, o, req_to_tokens, b_req_idx, b_seq_len):\n BLOCK = 32\n Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]\n assert Lq == Lk and Lk == Lv\n assert Lk in {16, 32, 64, 128}\n sm_scale = 1.0 / (Lq ** 0.5)\n batch = b_req_idx.shape[0]\n kv_group_num = q.shape[1] // k.shape[1]\n kv_head_num = k.shape[1]\n grid = (batch, kv_head_num)\n num_warps = 4\n _fwd_kernel[grid](\n q, k, v, sm_scale, req_to_tokens, b_req_idx, b_seq_len, o,\n q.stride(0), q.stride(1), q.stride(2),\n k.stride(0), k.stride(1), k.stride(2),\n v.stride(0), v.stride(1), v.stride(2),\n o.stride(0), o.stride(1), o.stride(2),\n req_to_tokens.stride(0), req_to_tokens.stride(1),\n kv_group_num=kv_group_num,\n Q_HEAD_NUM=max(16, triton.next_power_of_2(kv_group_num)),\n BLOCK_DMODEL=Lk,\n BLOCK_N=BLOCK,\n num_warps=num_warps,\n num_stages=2,\n )\n return\n", - "description_1": "Use triton language to implement a kernel (_fwd_kernel) for the forward pass of a generalized query attention (GQA) mechanism. The kernel accepts 24 parameters: query (Q), key (K), value (V) matrices, scaling factor (sm_scale), requested tokens matrix (Req_to_tokens), batch request index (B_req_idx), batch sequence length (B_seqlen), output matrix (Out), strides for query, key, value and output matrices (stride_qbs, stride_qh, stride_qd, stride_kbs, stride_kh, stride_kd, stride_vbs, stride_vh, stride_vd, stride_obs, stride_oh, stride_od), strides for requested tokens matrix (stride_req_to_tokens_b, stride_req_to_tokens_s), and a constant expression for the number of query heads (Q_HEAD_NUM), block size for the model dimension (BLOCK_DMODEL), and block size for sequence length (BLOCK_N). The wrapper function (gqa_decode_attention_fwd) initiates the kernel by accepting 7 parameters: query (q), key (k), value (v), output (o), requested tokens matrix (req_to_tokens), batch request index (b_req_idx), and batch sequence length (b_seq_len). It calculates the scale factor, derives the batch size and group numbers, configures the execution grid and number of warps, and dispatches the kernel with the appropriate parameters.", - "description_2": "Use triton language to implement a generalized query attention kernel for forward pass with query, key, value, and scaling matrices. Dispatch kernel execution using torch.no_grad() wrapper.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef _fwd_kernel_flash_decode_stage1(\n Q,\n K,\n V,\n sm_scale,\n Req_to_tokens,\n B_req_idx,\n B_Seqlen,\n Mid_O,\n Mid_O_LogExpSum,\n stride_req_to_tokens_b,\n stride_req_to_tokens_s,\n stride_qbs,\n stride_qh,\n stride_qd,\n stride_kbs,\n stride_kh,\n stride_kd,\n stride_vbs,\n stride_vh,\n stride_vd,\n stride_mid_ob,\n stride_mid_oh,\n stride_mid_os,\n stride_mid_od,\n stride_mid_o_eb,\n stride_mid_o_eh,\n stride_mid_o_es,\n gqa_group_size,\n Q_HEAD_NUM: tl.constexpr,\n BLOCK_SEQ: tl.constexpr,\n BLOCK_DMODEL: tl.constexpr,\n BLOCK_N: tl.constexpr,\n):\n cur_batch = tl.program_id(0)\n cur_kv_head = tl.program_id(1)\n seq_start_block = tl.program_id(2)\n\n cur_q_head_offs = tl.arange(0, Q_HEAD_NUM)\n cur_q_head_range = cur_kv_head * gqa_group_size + cur_q_head_offs\n\n offs_d = tl.arange(0, BLOCK_DMODEL)\n cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)\n cur_batch_req_idx = tl.load(B_req_idx + cur_batch)\n cur_batch_start_index = seq_start_block * BLOCK_SEQ\n cur_batch_end_index = tl.minimum(cur_batch_seq_len, cur_batch_start_index + BLOCK_SEQ)\n\n off_q = cur_batch * stride_qbs + cur_q_head_range[:, None] * stride_qh + offs_d[None, :]\n\n block_n_size = (\n tl.where(\n cur_batch_end_index - cur_batch_start_index <= 0,\n 0,\n cur_batch_end_index - cur_batch_start_index + BLOCK_N - 1,\n )\n // BLOCK_N\n )\n\n offs_n = cur_batch_start_index + tl.arange(0, BLOCK_N)\n\n q = tl.load(Q + off_q, mask=cur_q_head_range[:, None] < (cur_kv_head + 1) * gqa_group_size, other=0.0)\n\n sum_exp = tl.zeros([Q_HEAD_NUM], dtype=tl.float32)\n max_logic = tl.zeros([Q_HEAD_NUM], dtype=tl.float32) - float(\"inf\")\n acc = tl.zeros([Q_HEAD_NUM, BLOCK_DMODEL], dtype=tl.float32)\n\n for start_n in range(0, block_n_size, 1):\n offs_n_new = start_n * BLOCK_N + offs_n\n k_loc = tl.load(\n Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + offs_n_new,\n mask=offs_n_new < cur_batch_end_index,\n other=0,\n )\n off_k = k_loc[None, :] * stride_kbs + cur_kv_head * stride_kh + offs_d[:, None]\n k = tl.load(K + off_k, mask=offs_n_new[None, :] < cur_batch_end_index, other=0.0)\n att_value = tl.dot(q, k)\n att_value *= sm_scale\n att_value = tl.where(offs_n_new[None, :] < cur_batch_end_index, att_value, float(\"-inf\"))\n v = tl.load(\n V + k_loc[:, None] * stride_kbs + cur_kv_head * stride_kh + offs_d[None, :],\n mask=offs_n_new[:, None] < cur_batch_end_index,\n other=0.0,\n )\n\n cur_max_logic = tl.max(att_value, axis=1)\n new_max_logic = tl.maximum(cur_max_logic, max_logic)\n\n exp_logic = tl.exp(att_value - new_max_logic[:, None])\n logic_scale = tl.exp(max_logic - new_max_logic)\n acc *= logic_scale[:, None]\n acc += tl.dot(exp_logic.to(v.dtype), v)\n\n sum_exp = sum_exp * logic_scale + tl.sum(exp_logic, axis=1)\n max_logic = new_max_logic\n\n need_store = tl.where(block_n_size == 0, 0, 1)\n for _ in range(0, need_store, 1):\n off_mid_o = (\n cur_batch * stride_mid_ob\n + cur_q_head_range[:, None] * stride_mid_oh\n + seq_start_block * stride_mid_os\n + offs_d[None, :]\n )\n off_mid_o_logexpsum = cur_batch * stride_mid_o_eb + cur_q_head_range * stride_mid_o_eh + seq_start_block\n tl.store(\n Mid_O + off_mid_o,\n acc / sum_exp[:, None],\n mask=cur_q_head_range[:, None] < (cur_kv_head + 1) * gqa_group_size,\n )\n tl.store(\n Mid_O_LogExpSum + off_mid_o_logexpsum,\n max_logic + tl.log(sum_exp),\n mask=cur_q_head_range < (cur_kv_head + 1) * gqa_group_size,\n )\n return\n\n@torch.no_grad()\ndef flash_decode_stage1(\n q, k, v, Req_to_tokens, B_req_idx, B_Seqlen, max_len_in_batch, mid_out, mid_out_logsumexp, block_seq\n):\n BLOCK_SEQ = block_seq\n BLOCK_N = 16\n assert BLOCK_SEQ % BLOCK_N == 0\n # shape constraints\n Lq, Lk = q.shape[-1], k.shape[-1]\n assert Lq == Lk\n assert Lk in {16, 32, 64, 128}\n sm_scale = 1.0 / (Lk ** 0.5)\n batch, kv_head_num = B_req_idx.shape[0], k.shape[1]\n grid = (batch, kv_head_num, triton.cdiv(max_len_in_batch, BLOCK_SEQ))\n gqa_group_size = q.shape[1] // k.shape[1]\n\n _fwd_kernel_flash_decode_stage1[grid](\n q,\n k,\n v,\n sm_scale,\n Req_to_tokens,\n B_req_idx,\n B_Seqlen,\n mid_out,\n mid_out_logsumexp,\n Req_to_tokens.stride(0),\n Req_to_tokens.stride(1),\n q.stride(0),\n q.stride(1),\n q.stride(2),\n k.stride(0),\n k.stride(1),\n k.stride(2),\n v.stride(0),\n v.stride(1),\n v.stride(2),\n mid_out.stride(0),\n mid_out.stride(1),\n mid_out.stride(2),\n mid_out.stride(3),\n mid_out_logsumexp.stride(0),\n mid_out_logsumexp.stride(1),\n mid_out_logsumexp.stride(2),\n gqa_group_size,\n Q_HEAD_NUM=max(16, triton.next_power_of_2(gqa_group_size)),\n BLOCK_SEQ=BLOCK_SEQ,\n BLOCK_DMODEL=Lk,\n BLOCK_N=BLOCK_N,\n num_warps=2,\n num_stages=2,\n )\n return\n", - "description_1": "Use triton language to implement a kernel function '_fwd_kernel_flash_decode_stage1' with 28 parameters for performing a forward pass of a flash attention mechanism. The kernel processes input tensors Q, K, V, and other parameters to compute attention outputs and log-sum-exp values, storing results in Mid_O and Mid_O_LogExpSum. The function 'flash_decode_stage1' with 10 parameters sets up the grid and block sizes, calculates the scaling factor, and calls the kernel with appropriate strides and constants.", - "description_2": "Use triton language to implement a flash attention forward pass kernel and its calling function, handling input tensors and computing attention outputs with log-sum-exp.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef _fwd_kernel_destindex_copy_quantize_int4_kv(\n K,\n Dest_loc,\n Out,\n Out_scale,\n stride_k_bs,\n stride_k_h,\n stride_k_g,\n stride_k_d,\n stride_o_bs,\n stride_o_h,\n stride_o_g,\n stride_o_d,\n stride_os_bs,\n stride_os_h,\n stride_os_g,\n group_size,\n BLOCK_GROUP_NUM: tl.constexpr,\n BLOCK_GROUP_DIM: tl.constexpr,\n):\n cur_index = tl.program_id(0)\n cur_head = tl.program_id(1)\n\n offs_g = tl.arange(0, BLOCK_GROUP_NUM)\n offs_d = tl.arange(0, BLOCK_GROUP_DIM // 2)\n\n dest_index = tl.load(Dest_loc + cur_index)\n\n src_data_0 = tl.load(\n K + cur_index * stride_k_bs + cur_head * stride_k_h + offs_g[:, None] * stride_k_g + offs_d[None, :] * 2,\n mask=offs_g[:, None] < group_size,\n other=0.0,\n )\n src_data_1 = tl.load(\n K + cur_index * stride_k_bs + cur_head * stride_k_h + offs_g[:, None] * stride_k_g + offs_d[None, :] * 2 + 1,\n mask=offs_g[:, None] < group_size,\n other=0.0,\n )\n\n abs_data_0 = tl.abs(src_data_0)\n abs_data_1 = tl.abs(src_data_1)\n\n data_scale = (tl.maximum(tl.max(abs_data_0, axis=1), tl.max(abs_data_1, axis=1)) / 7.0).to(Out_scale.dtype.element_ty)\n q_src_data_0 = (src_data_0 / data_scale[:, None]).to(tl.int8)\n q_src_data_0 = tl.where(q_src_data_0 > 7, 7, q_src_data_0)\n q_src_data_0 = tl.where(q_src_data_0 < -7, -7, q_src_data_0)\n\n q_src_data_1 = (src_data_1 / data_scale[:, None]).to(tl.int8)\n q_src_data_1 = tl.where(q_src_data_1 > 7, 7, q_src_data_1)\n q_src_data_1 = tl.where(q_src_data_1 < -7, -7, q_src_data_1)\n\n low_4 = ((q_src_data_0 & 0x80) >> 4) | (q_src_data_0 & 0xF)\n high_4 = (((q_src_data_1 & 0x80) >> 4) | (q_src_data_1 & 0xF)) << 4\n\n out_data = low_4 | high_4\n\n o_ptrs = Out + dest_index * stride_o_bs + cur_head * stride_o_h + offs_g[:, None] * stride_o_g + offs_d[None, :]\n os_ptrs = Out_scale + dest_index * stride_os_bs + cur_head * stride_os_h + offs_g\n tl.store(o_ptrs, out_data, mask=offs_g[:, None] < group_size)\n tl.store(os_ptrs, data_scale, mask=offs_g < group_size)\n return\n\n@torch.no_grad()\ndef destindex_copy_int4kv(K, DestLoc, Out, Out_scale):\n head_dim = K.shape[2]\n quant_group_dim = 8\n\n assert head_dim % quant_group_dim == 0, \"error head dim, can not been supported to copy quant kv\"\n\n group_size = head_dim // quant_group_dim\n group_dim = quant_group_dim\n\n K = K.view((K.shape[0], K.shape[1], group_size, group_dim))\n Out = Out.view(\n Out.shape[0], Out.shape[1], group_size, group_dim // 2\n )\n\n # _fwd_kernel_destindex_copy_quantize_int4_kv[grid](\n # K,\n # DestLoc,\n # Out,\n # Out_scale,\n # K.stride(0),\n # K.stride(1),\n # K.stride(2),\n # K.stride(3),\n # Out.stride(0),\n # Out.stride(1),\n # Out.stride(2),\n # Out.stride(3),\n # Out_scale.stride(0),\n # Out_scale.stride(1),\n # Out_scale.stride(2),\n # group_size,\n # BLOCK_GROUP_NUM=triton.next_power_of_2(group_size),\n # BLOCK_GROUP_DIM=group_dim,\n # num_warps=num_warps,\n # num_stages=1,\n # )\n return\n", - "description_1": "Use triton language to implement a kernel function '_fwd_kernel_destindex_copy_quantize_int4_kv' with 17 parameters for quantizing and copying data from a source tensor 'K' to a destination tensor 'Out' using destination indices 'Dest_loc'. The function also computes a scale 'Out_scale' for the quantized data. The kernel uses block sizes defined by 'BLOCK_GROUP_NUM' and 'BLOCK_GROUP_DIM'. The function 'destindex_copy_int4kv' prepares the input tensors and calls the kernel with appropriate strides and grid dimensions.", - "description_2": "Use triton language to create a kernel for quantizing and copying data with destination indices, and a wrapper function to prepare inputs and invoke the kernel.", - "difficulty": 4 - }, - { - "code": "import triton\nimport triton.language as tl\nimport torch\n\n@triton.jit\ndef _rotary_kernel(\n Q, K, Cos, Sin,\n stride_qbs, stride_qh, stride_qd,\n stride_kbs, stride_kh, stride_kd,\n stride_cosbs, stride_cosd,\n stride_sinbs, stride_sind,\n max_total_len, HEAD_Q, HEAD_K,\n BLOCK_HEAD: tl.constexpr, BLOCK_SEQ: tl.constexpr, BLOCK_DMODEL: tl.constexpr,\n):\n cur_head_index = tl.program_id(0)\n cur_seq_index = tl.program_id(1)\n\n cur_head_range = cur_head_index * BLOCK_HEAD + tl.arange(0, BLOCK_HEAD)\n cur_seq_range = cur_seq_index * BLOCK_SEQ + tl.arange(0, BLOCK_SEQ)\n\n dim_range0 = tl.arange(0, BLOCK_DMODEL // 2)\n dim_range1 = tl.arange(BLOCK_DMODEL // 2, BLOCK_DMODEL)\n\n off_q0 = (\n cur_seq_range[:, None, None] * stride_qbs\n + cur_head_range[None, :, None] * stride_qh\n + dim_range0[None, None, :] * stride_qd\n )\n off_q1 = (\n cur_seq_range[:, None, None] * stride_qbs\n + cur_head_range[None, :, None] * stride_qh\n + dim_range1[None, None, :] * stride_qd\n )\n\n off_dimcos_sin = cur_seq_range[:, None, None] * stride_cosbs + dim_range0[None, None, :] * stride_cosd\n\n q0 = tl.load(\n Q + off_q0,\n mask=(cur_seq_range[:, None, None] < max_total_len) & (cur_head_range[None, :, None] < HEAD_Q),\n other=0.0,\n )\n q1 = tl.load(\n Q + off_q1,\n mask=(cur_seq_range[:, None, None] < max_total_len) & (cur_head_range[None, :, None] < HEAD_Q),\n other=0.0,\n )\n\n cos = tl.load(Cos + off_dimcos_sin, mask=cur_seq_range[:, None, None] < max_total_len, other=0.0)\n sin = tl.load(Sin + off_dimcos_sin, mask=cur_seq_range[:, None, None] < max_total_len, other=0.0)\n\n out0 = q0 * cos - q1 * sin\n out1 = q0 * sin + q1 * cos\n\n tl.store(\n Q + off_q0, out0, mask=(cur_seq_range[:, None, None] < max_total_len) & (cur_head_range[None, :, None] < HEAD_Q)\n )\n tl.store(\n Q + off_q1, out1, mask=(cur_seq_range[:, None, None] < max_total_len) & (cur_head_range[None, :, None] < HEAD_Q)\n )\n\n off_k0 = (\n cur_seq_range[:, None, None] * stride_kbs\n + cur_head_range[None, :, None] * stride_kh\n + dim_range0[None, None, :] * stride_kd\n )\n off_k1 = (\n cur_seq_range[:, None, None] * stride_kbs\n + cur_head_range[None, :, None] * stride_kh\n + dim_range1[None, None, :] * stride_kd\n )\n\n off_dimcos_sin = cur_seq_range[:, None, None] * stride_cosbs + dim_range0[None, None, :] * stride_cosd\n\n k0 = tl.load(\n K + off_k0,\n mask=(cur_seq_range[:, None, None] < max_total_len) & (cur_head_range[None, :, None] < HEAD_K),\n other=0.0,\n )\n k1 = tl.load(\n K + off_k1,\n mask=(cur_seq_range[:, None, None] < max_total_len) & (cur_head_range[None, :, None] < HEAD_K),\n other=0.0,\n )\n cos = tl.load(Cos + off_dimcos_sin, mask=cur_seq_range[:, None, None] < max_total_len, other=0.0)\n sin = tl.load(Sin + off_dimcos_sin, mask=cur_seq_range[:, None, None] < max_total_len, other=0.0)\n\n out_k0 = k0 * cos - k1 * sin\n out_k1 = k0 * sin + k1 * cos\n\n tl.store(\n K + off_k0,\n out_k0,\n mask=(cur_seq_range[:, None, None] < max_total_len) & (cur_head_range[None, :, None] < HEAD_K),\n )\n tl.store(\n K + off_k1,\n out_k1,\n mask=(cur_seq_range[:, None, None] < max_total_len) & (cur_head_range[None, :, None] < HEAD_K),\n )\n return\n\n\n@torch.no_grad()\ndef rotary_emb_fwd(q, k, cos, sin, partial_rotary_factor=1.):\n total_len = q.shape[0]\n head_num_q, head_num_k = q.shape[1], k.shape[1]\n head_dim = int(q.shape[2] * partial_rotary_factor)\n assert q.shape[0] == cos.shape[0] and q.shape[0] == sin.shape[0], f\"q shape {q.shape} cos shape {cos.shape}\"\n assert k.shape[0] == cos.shape[0] and k.shape[0] == sin.shape[0], f\"k shape {k.shape} cos shape {cos.shape}\"\n\n BLOCK_SEQ = 16\n BLOCK_HEAD = 4\n if head_dim >= 128:\n num_warps = 8\n else:\n num_warps = 4\n\n grid = (triton.cdiv(head_num_q, BLOCK_HEAD), triton.cdiv(total_len, BLOCK_SEQ))\n _rotary_kernel[grid](\n q,\n k,\n cos,\n sin,\n q.stride(0),\n q.stride(1),\n q.stride(2),\n k.stride(0),\n k.stride(1),\n k.stride(2),\n cos.stride(0),\n cos.stride(1),\n sin.stride(0),\n sin.stride(1),\n total_len,\n head_num_q,\n head_num_k,\n BLOCK_HEAD=BLOCK_HEAD,\n BLOCK_SEQ=BLOCK_SEQ,\n BLOCK_DMODEL=head_dim,\n num_warps=num_warps,\n num_stages=1,\n )\n return\n", - "description_1": "Use triton language to implement a kernel function '_rotary_kernel' that takes 18 parameters including input matrices Q, K, Cos, Sin, strides, maximum context length and head dimensions, and performs a series of loads, computes, and stores with masking based on head and sequence ranges. Accompanying this kernel is a host function 'rotary_emb_fwd' that sets up parameters and grid dimensions to call the kernel.", - "description_2": "Use triton language to build a function '_rotary_kernel' for rotary embeddings, managing tensor operations via Triton kernels, and invoking it through 'rotary_emb_fwd' which handles parameter setup and execution.", - "difficulty": 3 - }, - { - "code": "import triton\nimport triton.language as tl\n\n\n@triton.jit\ndef _silu_and_mul_kernel(\n input_ptr,\n stride_input_m,\n stride_input_n,\n stride_output_m,\n stride_output_n,\n size_m,\n size_n,\n BLOCK_M: tl.constexpr,\n BLOCK_N: tl.constexpr,\n):\n stride_input_m = stride_input_m.to(tl.int64)\n stride_output_m = stride_output_m.to(tl.int64)\n\n tid = tl.program_id(0)\n input_m_offsets = tid * BLOCK_M + tl.arange(0, BLOCK_M)\n output_m_offsets = tid * BLOCK_M + tl.arange(0, BLOCK_M)\n\n pid = tl.program_id(1)\n input_n_offsets = pid * BLOCK_N + tl.arange(0, BLOCK_N)\n output_n_offsets = pid * BLOCK_N + tl.arange(0, BLOCK_N)\n\n up_offsets = input_m_offsets[:, None] * stride_input_m + (input_n_offsets[None, :] + size_n) * stride_input_n\n gate_offsets = input_m_offsets[:, None] * stride_input_m + input_n_offsets[None, :] * stride_input_n\n res_offsets = output_m_offsets[:, None] * stride_output_m + output_n_offsets[None, :] * stride_output_n\n\n up = tl.load(\n input_ptr + up_offsets,\n mask=(input_n_offsets < size_n)[None, :] * (input_m_offsets < size_m)[:, None],\n other=0.0,\n )\n gate = tl.load(\n input_ptr + gate_offsets,\n mask=(input_n_offsets < size_n)[None, :] * (input_m_offsets < size_m)[:, None],\n other=0.0,\n ).to(tl.float32)\n\n gate = gate / (1 + tl.exp(-gate))\n gate = gate.to(input_ptr.dtype.element_ty)\n\n tl.store(\n input_ptr + res_offsets,\n up * gate,\n mask=(output_n_offsets < size_n)[None, :] * (output_m_offsets < size_m)[:, None],\n )\n\n\ndef silu_and_mul_fwd(input):\n stride_input_m = input.stride(0)\n stride_input_n = input.stride(1)\n stride_output_m = input.stride(0)\n stride_output_n = input.stride(1)\n size_m = input.shape[0]\n size_n = input.shape[-1] // 2\n BLOCK_M = 128\n BLOCK_N = 128\n grid = (\n triton.cdiv(size_m, BLOCK_M),\n triton.cdiv(size_n, BLOCK_N),\n )\n _silu_and_mul_kernel[grid](\n input,\n stride_input_m,\n stride_input_n,\n stride_output_m,\n stride_output_n,\n size_m,\n size_n,\n BLOCK_M,\n BLOCK_N,\n )\n return input[:, 0 : (input.shape[-1] // 2)]\n", - "description_1": "Use triton language to implement a kernel function '_silu_and_mul_kernel' and a wrapper function 'silu_and_mul_fwd'. The kernel function takes 8 arguments: input_ptr (pointer to input tensor), stride_input_m and stride_input_n (strides for input tensor), stride_output_m and stride_output_n (strides for output tensor), size_m and size_n (dimensions for processing), and BLOCK_M and BLOCK_N (block size constants). It performs element-wise operations including silu activation and multiplication on a blocked grid. The wrapper function 'silu_and_mul_fwd' calculates strides and size parameters from the input tensor, configures the grid size, and invokes the kernel.", - "description_2": "Use triton language to create a block-based silu activation and multiplication kernel with associated wrapper function, designed for tensor operations using grid configuration.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef _fwd_kernel(\n Q,\n K,\n V,\n sm_scale,\n Req_to_tokens,\n B_req_idx,\n B_split_start_loc,\n B_split_ready_cache_len,\n B_seqlen,\n Out,\n stride_qbs,\n stride_qh,\n stride_qd,\n stride_kbs,\n stride_kh,\n stride_kd,\n stride_vbs,\n stride_vh,\n stride_vd,\n stride_obs,\n stride_oh,\n stride_od,\n stride_req_to_tokens_b,\n stride_req_to_tokens_s,\n kv_group_num,\n BLOCK_M: tl.constexpr,\n BLOCK_DMODEL: tl.constexpr,\n BLOCK_N: tl.constexpr,\n):\n cur_batch = tl.program_id(0)\n cur_head = tl.program_id(1)\n start_m = tl.program_id(2)\n\n cur_kv_head = cur_head // kv_group_num\n\n cur_batch_req_idx = tl.load(B_req_idx + cur_batch)\n cur_batch_q_split_start_loc = tl.load(B_split_start_loc + cur_batch)\n cur_batch_seq_start = tl.load(B_split_ready_cache_len + cur_batch)\n cur_batch_seq_len = tl.load(B_seqlen + cur_batch)\n cur_batch_q_split_seq_len = cur_batch_seq_len - cur_batch_seq_start\n\n # initialize offsets\n offs_n = tl.arange(0, BLOCK_N)\n offs_d = tl.arange(0, BLOCK_DMODEL)\n offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)\n off_q = (cur_batch_q_split_start_loc + offs_m[:, None]) * stride_qbs + cur_head * stride_qh + offs_d[None, :]\n off_k = cur_kv_head * stride_kh + offs_d[:, None]\n off_v = cur_kv_head * stride_vh + offs_d[None, :]\n\n q = tl.load(Q + off_q, mask=offs_m[:, None] < cur_batch_q_split_seq_len, other=0.0)\n\n k_ptrs = K + off_k\n v_ptrs = V + off_v\n\n # initialize pointer to m and l\n m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float(\"inf\")\n l_i = tl.zeros([BLOCK_M], dtype=tl.float32)\n acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)\n\n block_mask = tl.where(start_m * BLOCK_M < cur_batch_q_split_seq_len, 1, 0)\n\n for start_n in range(0, block_mask * (cur_batch_seq_start + (start_m + 1) * BLOCK_M), BLOCK_N):\n start_n = tl.multiple_of(start_n, BLOCK_N)\n # -- compute qk ----\n kv_loc = tl.load(\n Req_to_tokens + cur_batch_req_idx * stride_req_to_tokens_b + start_n + offs_n,\n mask=(start_n + offs_n) < cur_batch_seq_len,\n other=0,\n )\n k = tl.load(\n k_ptrs + kv_loc[None, :] * stride_kbs, mask=(start_n + offs_n[None, :]) < cur_batch_seq_len, other=0.0\n )\n\n qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)\n qk += tl.dot(q, k)\n qk *= sm_scale\n qk = tl.where(cur_batch_seq_start + offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float(\"-100000000.0\"))\n\n # -- compute m_ij, p, l_ij\n m_ij = tl.max(qk, 1)\n p = tl.exp(qk - m_ij[:, None])\n l_ij = tl.sum(p, 1)\n # -- update m_i and l_i\n m_i_new = tl.maximum(m_i, m_ij)\n alpha = tl.exp(m_i - m_i_new)\n beta = tl.exp(m_ij - m_i_new)\n l_i_new = alpha * l_i + beta * l_ij\n # -- update output accumulator --\n # scale p\n p_scale = beta / l_i_new\n p = p * p_scale[:, None]\n # scale acc\n acc_scale = l_i / l_i_new * alpha\n acc = acc * acc_scale[:, None]\n # update acc\n v = tl.load(\n v_ptrs + kv_loc[:, None] * stride_vbs, mask=(start_n + offs_n[:, None]) < cur_batch_seq_len, other=0.0\n )\n\n p = p.to(v.dtype)\n acc += tl.dot(p, v)\n # update m_i and l_i\n l_i = l_i_new\n m_i = m_i_new\n # initialize pointers to output\n off_o = (cur_batch_q_split_start_loc + offs_m[:, None]) * stride_obs + cur_head * stride_oh + offs_d[None, :]\n out_ptrs = Out + off_o\n tl.store(out_ptrs, acc, mask=offs_m[:, None] < cur_batch_q_split_seq_len)\n return\n\n@torch.no_grad()\ndef splitfuse_context_attention_fwd(\n q,\n k,\n v,\n o,\n prefill_req_num,\n req_to_tokens,\n prefill_b_req_idx,\n prefill_b_split_start_loc,\n prefill_b_split_ready_cache_len,\n prefill_b_seq_len,\n prefill_max_split_seq_len_in_batch,\n):\n BLOCK = 128\n # shape constraints\n Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]\n assert Lq == Lk and Lk == Lv\n assert Lk in {16, 32, 64, 128}\n\n sm_scale = 1.0 / (Lq ** 0.5) # 计算scale系数\n _, head = prefill_b_seq_len.shape[0], q.shape[1]\n kv_group_num = q.shape[1] // k.shape[1]\n\n grid = (prefill_req_num, head, triton.cdiv(prefill_max_split_seq_len_in_batch, BLOCK))\n\n num_warps = 4 if Lk <= 64 else 8\n _fwd_kernel[grid](\n q,\n k,\n v,\n sm_scale,\n req_to_tokens,\n prefill_b_req_idx,\n prefill_b_split_start_loc,\n prefill_b_split_ready_cache_len,\n prefill_b_seq_len,\n o,\n q.stride(0),\n q.stride(1),\n q.stride(2),\n k.stride(0),\n k.stride(1),\n k.stride(2),\n v.stride(0),\n v.stride(1),\n v.stride(2),\n o.stride(0),\n o.stride(1),\n o.stride(2),\n req_to_tokens.stride(0),\n req_to_tokens.stride(1),\n kv_group_num=kv_group_num,\n BLOCK_M=BLOCK,\n BLOCK_DMODEL=Lk,\n BLOCK_N=BLOCK,\n num_warps=num_warps,\n num_stages=1,\n )\n return\n\n@triton.jit\ndef _fwd_kernel_int8(\n Q,\n K,\n K_scale,\n V,\n V_scale,\n sm_scale,\n Req_to_tokens,\n B_req_idx,\n B_split_start_loc,\n B_split_ready_cache_len,\n B_seqlen,\n Out,\n stride_qbs,\n stride_qh,\n stride_qd,\n stride_kbs,\n stride_kh,\n stride_kd,\n stride_ksbs,\n stride_ksh,\n stride_ksd,\n stride_vbs,\n stride_vh,\n stride_vd,\n stride_vsbs,\n stride_vsh,\n stride_vsd,\n stride_obs,\n stride_oh,\n stride_od,\n stride_req_to_tokens_b,\n stride_req_to_tokens_s,\n kv_group_num,\n BLOCK_M: tl.constexpr,\n BLOCK_DMODEL: tl.constexpr,\n BLOCK_N: tl.constexpr,\n):\n cur_batch = tl.program_id(0)\n cur_head = tl.program_id(1)\n start_m = tl.program_id(2)\n\n cur_kv_head = cur_head // kv_group_num\n\n cur_batch_req_idx = tl.load(B_req_idx + cur_batch)\n cur_batch_q_split_start_loc = tl.load(B_split_start_loc + cur_batch)\n cur_batch_seq_len = tl.load(B_seqlen + cur_batch)\n cur_batch_seq_start = tl.load(B_split_ready_cache_len + cur_batch)\n cur_batch_q_split_seq_len = cur_batch_seq_len - cur_batch_seq_start\n\n # initialize offsets\n offs_n = tl.arange(0, BLOCK_N)\n offs_d = tl.arange(0, BLOCK_DMODEL)\n offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)\n off_q = (cur_batch_q_split_start_loc + offs_m[:, None]) * stride_qbs + cur_head * stride_qh + offs_d[None, :]\n off_k = cur_kv_head * stride_kh + offs_d[:, None]\n off_v = cur_kv_head * stride_vh + offs_d[None, :]\n\n q = tl.load(Q + off_q, mask=offs_m[:, None] < cur_batch_q_split_seq_len, other=0.0)\n\n k_ptrs = K + off_k\n v_ptrs = V + off_v\n ks_ptrs = K_scale + cur_kv_head * stride_ksh\n vs_ptrs = V_scale + cur_kv_head * stride_vsh\n\n # initialize pointer to m and l\n m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float(\"inf\")\n l_i = tl.zeros([BLOCK_M], dtype=tl.float32)\n acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)\n\n block_mask = tl.where(start_m * BLOCK_M < cur_batch_q_split_seq_len, 1, 0)\n\n for start_n in range(0, block_mask * (cur_batch_seq_start + (start_m + 1) * BLOCK_M), BLOCK_N):\n start_n = tl.multiple_of(start_n, BLOCK_N)\n # -- compute qk ----\n kv_loc = tl.load(\n Req_to_tokens + cur_batch_req_idx * stride_req_to_tokens_b + start_n + offs_n,\n mask=(start_n + offs_n) < cur_batch_seq_len,\n other=0,\n )\n k = tl.load(\n k_ptrs + kv_loc[None, :] * stride_kbs, mask=(start_n + offs_n[None, :]) < cur_batch_seq_len, other=0.0\n )\n k_scale = tl.load(\n ks_ptrs + kv_loc[None, :] * stride_ksbs, mask=(start_n + offs_n[None, :]) < cur_batch_seq_len, other=0.0\n )\n\n qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)\n qk += tl.dot(q, (k_scale * k))\n qk *= sm_scale\n qk = tl.where(cur_batch_seq_start + offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float(\"-100000000.0\"))\n\n # -- compute m_ij, p, l_ij\n m_ij = tl.max(qk, 1)\n p = tl.exp(qk - m_ij[:, None])\n l_ij = tl.sum(p, 1)\n # -- update m_i and l_i\n m_i_new = tl.maximum(m_i, m_ij)\n alpha = tl.exp(m_i - m_i_new)\n beta = tl.exp(m_ij - m_i_new)\n l_i_new = alpha * l_i + beta * l_ij\n # -- update output accumulator --\n # scale p\n p_scale = beta / l_i_new\n p = p * p_scale[:, None]\n # scale acc\n acc_scale = l_i / l_i_new * alpha\n acc = acc * acc_scale[:, None]\n # update acc\n v = tl.load(\n v_ptrs + kv_loc[:, None] * stride_vbs, mask=(start_n + offs_n[:, None]) < cur_batch_seq_len, other=0.0\n )\n v_scale = tl.load(\n vs_ptrs + kv_loc[:, None] * stride_vsbs, mask=(start_n + offs_n)[:, None] < cur_batch_seq_len, other=0.0\n )\n\n p = p.to(V.dtype.element_ty)\n acc += tl.dot(p, v.to(V.dtype.element_ty) * v_scale)\n\n # update m_i and l_i\n l_i = l_i_new\n m_i = m_i_new\n # initialize pointers to output\n off_o = (cur_batch_q_split_start_loc + offs_m[:, None]) * stride_obs + cur_head * stride_oh + offs_d[None, :]\n out_ptrs = Out + off_o\n tl.store(out_ptrs, acc, mask=offs_m[:, None] < cur_batch_q_split_seq_len)\n return\n\n@torch.no_grad()\ndef splitfuse_context_attention_fwd_int8kv(\n q,\n k,\n k_scale,\n v,\n v_scale,\n o,\n prefill_req_num,\n req_to_tokens,\n prefill_b_req_idx,\n prefill_b_split_start_loc,\n prefill_b_split_ready_cache_len,\n prefill_b_seq_len,\n prefill_max_split_seq_len_in_batch,\n):\n\n BLOCK = 128\n # shape constraints\n Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]\n assert Lq == Lk and Lk == Lv\n assert Lk in {16, 32, 64, 128}\n\n sm_scale = 1.0 / (Lq ** 0.5) # 计算scale系数\n _, head = prefill_b_seq_len.shape[0], q.shape[1]\n kv_group_num = q.shape[1] // k.shape[1]\n\n grid = (prefill_req_num, head, triton.cdiv(prefill_max_split_seq_len_in_batch, BLOCK))\n\n num_warps = 4 if Lk <= 64 else 8\n _fwd_kernel_int8[grid](\n q,\n k,\n k_scale,\n v,\n v_scale,\n sm_scale,\n req_to_tokens,\n prefill_b_req_idx,\n prefill_b_split_start_loc,\n prefill_b_split_ready_cache_len,\n prefill_b_seq_len,\n o,\n q.stride(0),\n q.stride(1),\n q.stride(2),\n k.stride(0),\n k.stride(1),\n k.stride(2),\n k_scale.stride(0),\n k_scale.stride(1),\n k_scale.stride(2),\n v.stride(0),\n v.stride(1),\n v.stride(2),\n v_scale.stride(0),\n v_scale.stride(1),\n v_scale.stride(2),\n o.stride(0),\n o.stride(1),\n o.stride(2),\n req_to_tokens.stride(0),\n req_to_tokens.stride(1),\n kv_group_num=kv_group_num,\n BLOCK_M=BLOCK,\n BLOCK_DMODEL=Lk,\n BLOCK_N=BLOCK,\n num_warps=num_warps,\n num_stages=1,\n )\n return\n", - "description_1": "Use triton language to implement two forward kernels for context attention, one for standard precision and one for int8 precision. The kernels take in query, key, and value tensors, along with scaling factors and other parameters, to compute the attention output. The kernels are invoked by their respective wrapper functions which set up the grid and block dimensions.", - "description_2": "Use triton language to create forward kernels for context attention with standard and int8 precision, handling query, key, value tensors, and scaling.", - "difficulty": 4 - }, - { - "code": "import triton\nimport triton.language as tl\nimport torch\n\n# Triton kernel for token attention forward pass\n@triton.jit\ndef _fwd_kernel_token_att1(\n Q, K, sm_scale, Req_to_tokens, B_req_idx, B_Start_Loc, B_Seqlen, Att_Out,\n stride_req_to_tokens_b, stride_req_to_tokens_s, stride_qbs, stride_qh, stride_qd,\n stride_kbs, stride_kh, stride_kd, att_stride_h, att_stride_bs, kv_group_num,\n BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr\n):\n cur_batch = tl.program_id(0)\n cur_head = tl.program_id(1)\n start_n = tl.program_id(2)\n\n cur_kv_head = cur_head // kv_group_num\n\n offs_d = tl.arange(0, BLOCK_DMODEL)\n cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)\n cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)\n cur_batch_req_idx = tl.load(B_req_idx + cur_batch)\n\n cur_batch_start_index = 0\n cur_batch_end_index = cur_batch_seq_len\n\n off_q = cur_batch * stride_qbs + cur_head * stride_qh + offs_d * stride_qd\n\n offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N)\n\n block_stard_index = start_n * BLOCK_N\n block_mask = tl.where(block_stard_index < cur_batch_seq_len, 1, 0)\n\n for start_mark in range(0, block_mask, 1):\n q = tl.load(Q + off_q + start_mark)\n offs_n_new = cur_batch_start_index + offs_n\n k_loc = tl.load(\n Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + stride_req_to_tokens_s * offs_n_new,\n mask=offs_n_new < cur_batch_end_index,\n other=0,\n )\n off_k = k_loc[:, None] * stride_kbs + cur_kv_head * stride_kh + offs_d[None, :] * stride_kd\n k = tl.load(K + off_k, mask=offs_n_new[:, None] < cur_batch_end_index, other=0.0)\n att_value = tl.sum(q[None, :] * k, 1)\n att_value = att_value.to(tl.float32)\n att_value *= sm_scale\n off_o = cur_head * att_stride_h + (cur_batch_in_all_start_index + offs_n) * att_stride_bs\n tl.store(Att_Out + off_o, att_value, mask=offs_n_new < cur_batch_end_index)\n return\n\n# Wrapper function for the Triton kernel\n@torch.no_grad()\ndef token_att_fwd(q, k, att_out, Req_to_tokens, B_req_idx, B_Start_Loc, B_Seqlen, max_len_in_batch):\n BLOCK = 32\n Lq, Lk = q.shape[-1], k.shape[-1]\n assert Lq == Lk\n assert Lk in {16, 32, 64, 128, 256}\n sm_scale = 1.0 / (Lk ** 0.5)\n\n batch, head_num = B_req_idx.shape[0], q.shape[1]\n\n grid = (batch, head_num, triton.cdiv(max_len_in_batch, BLOCK))\n kv_group_num = q.shape[1] // k.shape[1]\n\n if kv_group_num == 1:\n num_warps = 4\n else:\n num_warps = 2\n\n _fwd_kernel_token_att1[grid](\n q, k, sm_scale, Req_to_tokens, B_req_idx, B_Start_Loc, B_Seqlen, att_out,\n Req_to_tokens.stride(0), Req_to_tokens.stride(1), q.stride(0), q.stride(1), q.stride(2),\n k.stride(0), k.stride(1), k.stride(2), att_out.stride(0), att_out.stride(1),\n kv_group_num=kv_group_num, BLOCK_DMODEL=Lk, BLOCK_N=BLOCK, num_warps=num_warps, num_stages=1,\n )\n return\n\n# Triton kernel for token attention forward pass with int8 inputs\n@triton.jit\ndef _fwd_kernel_token_att1_int8(\n Q, K, K_scale, sm_scale, Req_to_tokens, B_req_idx, B_Start_Loc, B_Seqlen, Att_Out,\n stride_req_to_tokens_b, stride_req_to_tokens_s, stride_qbs, stride_qh, stride_qd,\n stride_kbs, stride_kh, stride_kd, stride_ksbs, stride_ksh, stride_ksd,\n att_stride_h, att_stride_bs, kv_group_num, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr\n):\n cur_batch = tl.program_id(0)\n cur_head = tl.program_id(1)\n start_n = tl.program_id(2)\n\n cur_kv_head = cur_head // kv_group_num\n\n offs_d = tl.arange(0, BLOCK_DMODEL)\n cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)\n cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)\n cur_batch_req_idx = tl.load(B_req_idx + cur_batch)\n\n cur_batch_start_index = 0\n cur_batch_end_index = cur_batch_seq_len\n\n off_q = cur_batch * stride_qbs + cur_head * stride_qh + offs_d * stride_qd\n\n offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N)\n\n block_stard_index = start_n * BLOCK_N\n block_mask = tl.where(block_stard_index < cur_batch_seq_len, 1, 0)\n\n for start_mark in range(0, block_mask, 1):\n q = tl.load(Q + off_q + start_mark)\n offs_n_new = cur_batch_start_index + offs_n\n k_loc = tl.load(\n Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + stride_req_to_tokens_s * offs_n_new,\n mask=offs_n_new < cur_batch_end_index,\n other=0,\n )\n off_k = k_loc[:, None] * stride_kbs + cur_kv_head * stride_kh + offs_d[None, :] * stride_kd\n k = tl.load(K + off_k, mask=offs_n_new[:, None] < cur_batch_end_index, other=0.0)\n off_ks = k_loc[:, None] * stride_ksbs + cur_kv_head * stride_ksh\n k_scale = tl.load(K_scale + off_ks, mask=offs_n_new[:, None] < cur_batch_end_index, other=0.0)\n att_value = tl.sum(q[None, :] * k * k_scale, 1)\n att_value *= sm_scale\n off_o = cur_head * att_stride_h + (cur_batch_in_all_start_index + offs_n) * att_stride_bs\n tl.store(Att_Out + off_o, att_value, mask=offs_n_new < cur_batch_end_index)\n return\n\n# Wrapper function for the Triton int8 kernel\n@torch.no_grad()\ndef token_att_fwd_int8k(q, k, k_scale, att_out, Req_to_tokens, B_req_idx, B_Start_Loc, B_Seqlen, max_input_len):\n BLOCK = 32\n Lq, Lk = q.shape[-1], k.shape[-1]\n assert Lq == Lk\n assert Lk in {16, 32, 64, 128}\n sm_scale = 1.0 / (Lk ** 0.5)\n\n batch, head_num = B_req_idx.shape[0], q.shape[1]\n\n grid = (batch, head_num, triton.cdiv(max_input_len, BLOCK))\n\n kv_group_num = q.shape[1] // k.shape[1]\n if kv_group_num == 1:\n num_warps = 4\n else:\n num_warps = 2\n\n _fwd_kernel_token_att1_int8[grid](\n q, k, k_scale, sm_scale, Req_to_tokens, B_req_idx, B_Start_Loc, B_Seqlen, att_out,\n Req_to_tokens.stride(0), Req_to_tokens.stride(1), q.stride(0), q.stride(1), q.stride(2),\n k.stride(0), k.stride(1), k.stride(2), k_scale.stride(0), k_scale.stride(1), k_scale.stride(2),\n att_out.stride(0), att_out.stride(1), kv_group_num=kv_group_num, BLOCK_DMODEL=Lk,\n BLOCK_N=BLOCK, num_warps=num_warps, num_stages=1,\n )\n return\n", - "description_1": "Use triton language to implement two kernels for token attention forward passes, one supporting int8 inputs, with corresponding wrapper functions to handle multi-dimensional data inputs (Q, K) and calculate attention output based on specific parameters, grid size, and block dimensions.", - "description_2": "Use triton language to create kernels for attention calculations and provide wrapper functions for handling grid and block settings, supporting both float32 and int8 input formats.", - "difficulty": 4 - }, - { - "code": "import triton\nimport triton.language as tl\nimport torch\n\n@triton.jit\ndef _fwd_kernel_token_att2(\n Prob,\n V,\n Out,\n Req_to_tokens,\n B_req_idx,\n B_Start_Loc,\n B_Seqlen,\n stride_req_to_tokens_b,\n stride_req_to_tokens_s,\n stride_ph,\n stride_pbs,\n stride_vbs,\n stride_vh,\n stride_vd,\n stride_obs,\n stride_oh,\n stride_od,\n kv_group_num,\n BLOCK_DMODEL: tl.constexpr,\n BLOCK_N: tl.constexpr,\n):\n cur_batch = tl.program_id(0)\n cur_head = tl.program_id(1)\n\n cur_kv_head = cur_head // kv_group_num\n\n offs_n = tl.arange(0, BLOCK_N)\n offs_d = tl.arange(0, BLOCK_DMODEL)\n cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)\n cur_batch_start_index = 0\n cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)\n cur_batch_req_idx = tl.load(B_req_idx + cur_batch)\n\n v_loc_off = cur_batch_req_idx * stride_req_to_tokens_b + (cur_batch_start_index + offs_n) * stride_req_to_tokens_s\n p_offs = cur_head * stride_ph + (cur_batch_in_all_start_index + offs_n) * stride_pbs\n v_offs = cur_kv_head * stride_vh + offs_d[None, :] * stride_vd\n\n acc = tl.zeros([BLOCK_DMODEL], dtype=tl.float32)\n for start_n in range(0, cur_batch_seq_len, BLOCK_N):\n start_n = tl.multiple_of(start_n, BLOCK_N)\n p_value = tl.load(Prob + p_offs + start_n, mask=(start_n + offs_n) < cur_batch_seq_len, other=0.0)\n v_loc = tl.load(\n Req_to_tokens + v_loc_off + start_n * stride_req_to_tokens_s,\n mask=(start_n + offs_n) < cur_batch_seq_len,\n other=0.0,\n )\n v_value = tl.load(\n V + v_offs + v_loc[:, None] * stride_vbs, mask=(start_n + offs_n[:, None]) < cur_batch_seq_len, other=0.0\n )\n acc += tl.sum(p_value[:, None] * v_value, 0)\n\n acc = acc.to(Out.dtype.element_ty)\n off_o = cur_batch * stride_obs + cur_head * stride_oh + offs_d * stride_od\n out_ptrs = Out + off_o\n tl.store(out_ptrs, acc)\n return\n\n\n@torch.no_grad()\ndef token_att_fwd2(prob, v, out, Req_to_tokens, B_req_idx, B_Start_Loc, B_Seqlen):\n BLOCK = 128\n batch, head = B_req_idx.shape[0], prob.shape[0]\n grid = (batch, head)\n num_warps = 4\n dim = v.shape[-1]\n\n kv_group_num = prob.shape[0] // v.shape[1]\n\n _fwd_kernel_token_att2[grid](\n prob,\n v,\n out,\n Req_to_tokens,\n B_req_idx,\n B_Start_Loc,\n B_Seqlen,\n Req_to_tokens.stride(0),\n Req_to_tokens.stride(1),\n prob.stride(0),\n prob.stride(1),\n v.stride(0),\n v.stride(1),\n v.stride(2),\n out.stride(0),\n out.stride(1),\n out.stride(2),\n kv_group_num=kv_group_num,\n BLOCK_DMODEL=dim,\n BLOCK_N=BLOCK,\n num_warps=num_warps,\n num_stages=1,\n )\n return\n\n\n@triton.jit\ndef _fwd_kernel_token_att2_int8v(\n Prob,\n V,\n V_scale,\n Out,\n Req_to_tokens,\n B_req_idx,\n B_Start_Loc,\n B_Seqlen,\n stride_req_to_tokens_b,\n stride_req_to_tokens_s,\n stride_ph,\n stride_pbs,\n stride_vbs,\n stride_vh,\n stride_vd,\n stride_vsbs,\n stride_vsh,\n stride_vsd,\n stride_obs,\n stride_oh,\n stride_od,\n kv_group_num,\n BLOCK_DMODEL: tl.constexpr,\n BLOCK_N: tl.constexpr,\n):\n cur_batch = tl.program_id(0)\n cur_head = tl.program_id(1)\n\n cur_kv_head = cur_head // kv_group_num\n\n offs_n = tl.arange(0, BLOCK_N)\n offs_d = tl.arange(0, BLOCK_DMODEL)\n cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)\n cur_batch_start_index = 0\n cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)\n cur_batch_req_idx = tl.load(B_req_idx + cur_batch)\n\n v_loc_off = cur_batch_req_idx * stride_req_to_tokens_b + (cur_batch_start_index + offs_n) * stride_req_to_tokens_s\n p_offs = cur_head * stride_ph + (cur_batch_in_all_start_index + offs_n) * stride_pbs\n v_offs = cur_kv_head * stride_vh + offs_d[None, :] * stride_vd\n vs_offs = cur_kv_head * stride_vsh\n\n acc = tl.zeros([BLOCK_DMODEL], dtype=tl.float32)\n for start_n in range(0, cur_batch_seq_len, BLOCK_N):\n start_n = tl.multiple_of(start_n, BLOCK_N)\n p_value = tl.load(Prob + p_offs + start_n, mask=(start_n + offs_n) < cur_batch_seq_len, other=0.0)\n v_loc = tl.load(\n Req_to_tokens + v_loc_off + start_n * stride_req_to_tokens_s,\n mask=(start_n + offs_n) < cur_batch_seq_len,\n other=0.0,\n )\n v_value = tl.load(\n V + v_offs + v_loc[:, None] * stride_vbs, mask=(start_n + offs_n[:, None]) < cur_batch_seq_len, other=0.0\n )\n vs_value = tl.load(\n V_scale + vs_offs + v_loc[:, None] * stride_vsbs,\n mask=(start_n + offs_n[:, None]) < cur_batch_seq_len,\n other=0.0,\n )\n acc += tl.sum(p_value[:, None] * v_value * vs_value, 0)\n\n acc = acc.to(Out.dtype.element_ty)\n off_o = cur_batch * stride_obs + cur_head * stride_oh + offs_d * stride_od\n out_ptrs = Out + off_o\n tl.store(out_ptrs, acc)\n return\n\n\n@torch.no_grad()\ndef token_att_fwd2_int8v(prob, v, v_scale, out, Req_to_tokens, B_req_idx, B_Start_Loc, B_Seqlen, max_len_in_batch):\n if max_len_in_batch < 512:\n BLOCK = triton.next_power_of_2(max_len_in_batch)\n else:\n BLOCK = 512\n batch, head = B_req_idx.shape[0], prob.shape[0]\n grid = (batch, head)\n num_warps = 4\n dim = v.shape[-1]\n kv_group_num = prob.shape[0] // v.shape[1]\n\n _fwd_kernel_token_att2_int8v[grid](\n prob,\n v,\n v_scale,\n out,\n Req_to_tokens,\n B_req_idx,\n B_Start_Loc,\n B_Seqlen,\n Req_to_tokens.stride(0),\n Req_to_tokens.stride(1),\n prob.stride(0),\n prob.stride(1),\n v.stride(0),\n v.stride(1),\n v.stride(2),\n v_scale.stride(0),\n v_scale.stride(1),\n v_scale.stride(2),\n out.stride(0),\n out.stride(1),\n out.stride(2),\n kv_group_num=kv_group_num,\n BLOCK_DMODEL=dim,\n BLOCK_N=BLOCK,\n num_warps=num_warps,\n num_stages=1,\n )\n return\n", - "description_1": "Use triton language to implement a kernel function '_fwd_kernel_token_att2' with 19 arguments for processing attention data. Also implement a wrapper 'token_att_fwd2' with 7 arguments to configure and launch the kernel. Another kernel '_fwd_kernel_token_att2_int8v' with 21 arguments processes attention with int8 input, and 'token_att_fwd2_int8v' as its launcher with 8 arguments.", - "description_2": "Use triton language to compute attention by multiplying probabilities with values, both for float and int8, with optional scaling.", - "difficulty": 4 - }, - { - "code": "import triton\nimport triton.language as tl\nimport torch\n\n@triton.jit\ndef _fwd_kernel(\n Logics, V, Out,\n Req_to_tokens, B_req_idx, B_Start_Loc, B_Seqlen,\n stride_logic_h, stride_logic_bs,\n stride_vbs, stride_vh, stride_vd,\n stride_obs, stride_oh, stride_od,\n stride_req_to_token_b, stride_req_to_token_s,\n other_kv_index,\n kv_group_num,\n BLOCK_DMODEL: tl.constexpr,\n BLOCK_N: tl.constexpr,\n):\n cur_batch = tl.program_id(0)\n cur_head = tl.program_id(1)\n\n cur_kv_head = cur_head // kv_group_num\n\n cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)\n cur_batch_start_loc = tl.load(B_Start_Loc + cur_batch)\n cur_batch_req_idx = tl.load(B_req_idx + cur_batch)\n\n offs_n = tl.arange(0, BLOCK_N)\n offs_d = tl.arange(0, BLOCK_DMODEL)\n\n off_v = cur_kv_head * stride_vh + offs_d[None, :] * stride_vd\n v_ptrs = V + off_v\n\n e_max = float(\"-inf\")\n e_sum = 0.0\n acc = tl.zeros([BLOCK_DMODEL], dtype=tl.float32)\n\n for start_n in range(0, cur_batch_seq_len, BLOCK_N):\n start_n = tl.multiple_of(start_n, BLOCK_N)\n v_index = tl.load(Req_to_tokens + cur_batch_req_idx * stride_req_to_token_b + \n (start_n + offs_n) * stride_req_to_token_s, \n mask=(start_n + offs_n) < cur_batch_seq_len, other=other_kv_index)\n\n qk = tl.load(Logics + cur_head * stride_logic_h + (cur_batch_start_loc + start_n + offs_n) * stride_logic_bs, \n mask=start_n + offs_n < cur_batch_seq_len, other=float(\"-inf\"))\n \n n_e_max = tl.maximum(tl.max(qk, 0), e_max)\n old_scale = tl.exp(e_max - n_e_max)\n p = tl.exp(qk - n_e_max)\n e_sum = e_sum * old_scale + tl.sum(p, 0)\n v = tl.load(v_ptrs + v_index[:, None] * stride_vbs)\n acc = acc * old_scale + tl.sum(p[:, None] * v, 0)\n e_max = n_e_max\n\n acc = acc / e_sum\n off_o = cur_batch * stride_obs + cur_head * stride_oh + offs_d * stride_od\n out_ptrs = Out + off_o\n tl.store(out_ptrs, acc)\n return\n\n@torch.no_grad()\ndef token_softmax_reducev_fwd(logics, v, o, req_to_tokens, b_req_idx, b_start_loc, b_seq_len, other_kv_index):\n BLOCK = 64\n batch, head = b_seq_len.shape[0], logics.shape[0]\n grid = (batch, head)\n kv_group_num = logics.shape[0] // v.shape[1]\n\n num_warps = 1\n _fwd_kernel[grid](\n logics, v, o, req_to_tokens, b_req_idx, b_start_loc, b_seq_len,\n logics.stride(0), logics.stride(1),\n v.stride(0), v.stride(1), v.stride(2),\n o.stride(0), o.stride(1), o.stride(2),\n req_to_tokens.stride(0), req_to_tokens.stride(1),\n other_kv_index,\n kv_group_num,\n BLOCK_DMODEL=v.shape[-1],\n BLOCK_N=BLOCK,\n num_warps=num_warps,\n num_stages=3\n )\n return\n", - "description_1": "Use triton language to implement a forward kernel (_fwd_kernel) that performs a softmax reduction over a set of logic values and a value tensor. The kernel takes 20 parameters: Logics, V, Out, Req_to_tokens, B_req_idx, B_Start_Loc, B_Seqlen, stride_logic_h, stride_logic_bs, stride_vbs, stride_vh, stride_vd, stride_obs, stride_oh, stride_od, stride_req_to_token_b, stride_req_to_token_s, other_kv_index, kv_group_num, and two constexpr parameters BLOCK_DMODEL and BLOCK_N. The kernel computes the softmax of the logic values, scales the value tensor accordingly, and stores the result in the output tensor. The function token_softmax_reducev_fwd is a wrapper that sets up the grid and block dimensions and calls the kernel with the appropriate parameters.", - "description_2": "Use triton language to create a kernel that computes the softmax of logic values and scales a value tensor, storing the result in an output tensor. The kernel is called by a wrapper function that sets up execution parameters.", - "difficulty": 3 - }, - { - "code": "import triton\nimport triton.language as tl\nimport torch\n\n@triton.jit\ndef _fwd_kernel_init_att_window_info(\n b_seq_len,\n b_att_seq_len,\n batch_size,\n sliding_window,\n BLOCK_SIZE: tl.constexpr,\n):\n cur_index = tl.program_id(0)\n cur_start = cur_index * BLOCK_SIZE\n offsets = cur_start + tl.arange(0, BLOCK_SIZE)\n mask = offsets < batch_size\n\n cur_seq_len = tl.load(b_seq_len + offsets, mask=mask)\n b_att_seq_len_data = tl.minimum(cur_seq_len, sliding_window)\n\n tl.store(b_att_seq_len + offsets, b_att_seq_len_data, mask=mask)\n return\n\n@torch.no_grad()\ndef init_att_window_info_fwd(batch_size, b_seq_len, b_att_seq_len, sliding_window):\n # shape constraints\n assert batch_size == b_seq_len.shape[0] == b_att_seq_len.shape[0]\n\n BLOCK_SIZE = 32\n num_warps = 1\n grid = (triton.cdiv(batch_size, BLOCK_SIZE),)\n\n _fwd_kernel_init_att_window_info[grid](\n b_seq_len,\n b_att_seq_len,\n batch_size=batch_size,\n sliding_window=sliding_window,\n BLOCK_SIZE=BLOCK_SIZE,\n num_warps=num_warps,\n num_stages=1,\n )\n return\n", - "description_1": "Use triton language to implement a kernel that initializes attention window information. The kernel takes 5 parameters: b_seq_len (sequence lengths), b_att_seq_len (output buffer for attention sequence lengths), batch_size (number of sequences), sliding_window (maximum attention window size), and BLOCK_SIZE (block size for processing). The kernel calculates the attention sequence length for each sequence, ensuring it does not exceed the sliding window size, and stores the result in b_att_seq_len. The function init_att_window_info_fwd is a wrapper that sets up the grid and block size for the kernel execution and ensures the input shapes are consistent.", - "description_2": "Use triton language to create a kernel that computes attention window sizes for sequences, ensuring they do not exceed a given sliding window size, and stores the results.", - "difficulty": 2 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n\n@triton.jit\ndef _fwd_kernel_token_att1(\n Q,\n K,\n sm_scale,\n Req_to_tokens,\n B_req_idx,\n B_Start_Loc,\n B_Seqlen,\n B_Att_Start_Loc,\n B_Att_Seqlen,\n Att_Out,\n stride_req_to_tokens_b,\n stride_req_to_tokens_s,\n stride_qbs,\n stride_qh,\n stride_qd,\n stride_kbs,\n stride_kh,\n stride_kd,\n att_stride_h,\n att_stride_bs,\n kv_group_num,\n sliding_window,\n BLOCK_DMODEL: tl.constexpr,\n BLOCK_N: tl.constexpr,\n):\n cur_batch = tl.program_id(0)\n cur_head = tl.program_id(1)\n start_n = tl.program_id(2)\n\n cur_kv_head = cur_head // kv_group_num\n\n offs_d = tl.arange(0, BLOCK_DMODEL) # [D]\n cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)\n cur_batch_in_all_start_index = tl.load(B_Att_Start_Loc + cur_batch) # use window index\n cur_batch_req_idx = tl.load(B_req_idx + cur_batch)\n cur_att_seq_len = tl.load(B_Att_Seqlen + cur_batch)\n\n # use new start index of k value\n cur_batch_start_index = tl.maximum(cur_batch_seq_len - sliding_window, 0)\n cur_batch_end_index = cur_batch_seq_len\n\n off_q = cur_batch * stride_qbs + cur_head * stride_qh + offs_d * stride_qd # [D]\n\n offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N) # [32]\n\n # use new value to decide block mask\n block_stard_index = start_n * BLOCK_N\n block_mask = tl.where(block_stard_index < cur_att_seq_len, 1, 0) # a number\n\n for start_mark in range(0, block_mask, 1):\n q = tl.load(Q + off_q + start_mark) # [SYM] why here add start_mark\n offs_n_new = cur_batch_start_index + offs_n # the latest window of token\n k_loc = tl.load(\n Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + stride_req_to_tokens_s * offs_n_new,\n mask=offs_n_new < cur_batch_end_index,\n other=0,\n )\n off_k = (\n k_loc[:, None] * stride_kbs + cur_kv_head * stride_kh + offs_d[None, :] * stride_kd\n ) # [32, D], find token index\n k = tl.load(K + off_k, mask=offs_n_new[:, None] < cur_batch_end_index, other=0.0)\n att_value = tl.sum(q[None, :] * k, 1) # [1, D] * [32, D] = [32, D] -> [32]\n att_value = att_value.to(tl.float32)\n att_value *= sm_scale\n off_o = cur_head * att_stride_h + (cur_batch_in_all_start_index + offs_n) * att_stride_bs\n tl.store(Att_Out + off_o, att_value, mask=offs_n_new < cur_batch_end_index)\n return\n\n\n@torch.no_grad()\ndef token_att_fwd(\n q, k, att_out, Req_to_tokens, B_req_idx, B_Start_Loc, B_Seqlen, B_Att_Start_Loc, B_Att_Seqlen, sliding_window\n):\n BLOCK = 32\n # shape constraints\n Lq, Lk = q.shape[-1], k.shape[-1]\n assert Lq == Lk\n assert Lk in {16, 32, 64, 128}\n sm_scale = 1.0 / (Lk ** 0.5)\n\n batch, head_num = B_req_idx.shape[0], q.shape[1]\n\n grid = (batch, head_num, triton.cdiv(sliding_window, BLOCK))\n kv_group_num = q.shape[1] // k.shape[1]\n\n if kv_group_num == 1:\n num_warps = 4\n else:\n num_warps = 2\n\n _fwd_kernel_token_att1[grid](\n q,\n k,\n sm_scale,\n Req_to_tokens,\n B_req_idx,\n B_Start_Loc,\n B_Seqlen,\n B_Att_Start_Loc,\n B_Att_Seqlen,\n att_out,\n Req_to_tokens.stride(0),\n Req_to_tokens.stride(1),\n q.stride(0),\n q.stride(1),\n q.stride(2),\n k.stride(0),\n k.stride(1),\n k.stride(2),\n att_out.stride(0),\n att_out.stride(1),\n kv_group_num=kv_group_num,\n sliding_window=sliding_window,\n BLOCK_DMODEL=Lk,\n BLOCK_N=BLOCK,\n num_warps=num_warps,\n num_stages=1,\n )\n return\n", - "description_1": "Use triton language to implement a forward kernel for token attention. The kernel '_fwd_kernel_token_att1' requires parameters such as query matrix 'Q', key matrix 'K', softmax scaling factor 'sm_scale', a mapping 'Req_to_tokens', and various indexing and stride parameters for efficient loading and storing of attention values. The function 'token_att_fwd' is a wrapper that sets the execution grid and other kernel parameters, ensuring the shapes of Q and K are valid, and calling the Triton kernel to compute attention scores, storing them in 'att_out'.", - "description_2": "Use triton language to define a token attention kernel that computes attention values for each token within a specified sliding window, utilizing grid-based parallelization and vectorized operations.", - "difficulty": 3 - }, - { - "code": "import triton\nimport triton.language as tl\nimport torch\n\n@triton.jit\ndef _fwd_kernel(\n Logics,\n V,\n Out,\n Req_to_tokens,\n B_req_idx,\n B_Start_Loc,\n B_Seqlen,\n B_Att_Start_Loc,\n B_Att_Seqlen,\n stride_logic_h,\n stride_logic_bs,\n stride_vbs,\n stride_vh,\n stride_vd,\n stride_obs,\n stride_oh,\n stride_od,\n stride_req_to_token_b,\n stride_req_to_token_s,\n other_kv_index,\n kv_group_num,\n sliding_window,\n BLOCK_DMODEL: tl.constexpr,\n BLOCK_N: tl.constexpr,\n):\n cur_batch = tl.program_id(0)\n cur_head = tl.program_id(1)\n\n cur_kv_head = cur_head // kv_group_num\n\n cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)\n cur_batch_start_loc = tl.load(B_Att_Start_Loc + cur_batch)\n cur_batch_req_idx = tl.load(B_req_idx + cur_batch)\n cur_att_seq_len = tl.load(B_Att_Seqlen + cur_batch)\n cur_cache_start_loc = tl.maximum(cur_batch_seq_len - sliding_window, 0)\n\n offs_n = tl.arange(0, BLOCK_N)\n offs_d = tl.arange(0, BLOCK_DMODEL)\n\n off_v = cur_kv_head * stride_vh + offs_d[None, :] * stride_vd\n v_ptrs = V + off_v\n\n e_max = float(\"-inf\")\n e_sum = 0.0\n acc = tl.zeros([BLOCK_DMODEL], dtype=tl.float32)\n\n for start_n in range(0, cur_att_seq_len, BLOCK_N):\n start_n = tl.multiple_of(start_n, BLOCK_N)\n v_index = tl.load(\n Req_to_tokens\n + cur_batch_req_idx * stride_req_to_token_b\n + (cur_cache_start_loc + start_n + offs_n) * stride_req_to_token_s,\n mask=(start_n + offs_n) < cur_att_seq_len,\n other=other_kv_index,\n )\n\n qk = tl.load(\n Logics + cur_head * stride_logic_h + (cur_batch_start_loc + start_n + offs_n) * stride_logic_bs,\n mask=(start_n + offs_n) < cur_att_seq_len,\n other=float(\"-inf\"),\n )\n\n n_e_max = tl.maximum(tl.max(qk, 0), e_max)\n old_scale = tl.exp(e_max - n_e_max)\n p = tl.exp(qk - n_e_max)\n e_sum = e_sum * old_scale + tl.sum(p, 0)\n v = tl.load(v_ptrs + v_index[:, None] * stride_vbs)\n acc = acc * old_scale + tl.sum(p[:, None] * v, 0)\n e_max = n_e_max\n\n acc = acc / e_sum\n off_o = cur_batch * stride_obs + cur_head * stride_oh + offs_d * stride_od\n out_ptrs = Out + off_o\n tl.store(out_ptrs, acc)\n return\n\n@torch.no_grad()\ndef token_softmax_reducev_fwd(\n logics,\n v,\n o,\n req_to_tokens,\n b_req_idx,\n b_start_loc,\n b_seq_len,\n b_att_start_loc,\n b_att_seq_len,\n other_kv_index,\n sliding_window,\n):\n BLOCK = 64\n batch, head = b_seq_len.shape[0], logics.shape[0]\n grid = (batch, head)\n kv_group_num = logics.shape[0] // v.shape[1]\n\n num_warps = 1\n _fwd_kernel[grid](\n logics,\n v,\n o,\n req_to_tokens,\n b_req_idx,\n b_start_loc,\n b_seq_len,\n b_att_start_loc,\n b_att_seq_len,\n logics.stride(0),\n logics.stride(1),\n v.stride(0),\n v.stride(1),\n v.stride(2),\n o.stride(0),\n o.stride(1),\n o.stride(2),\n req_to_tokens.stride(0),\n req_to_tokens.stride(1),\n other_kv_index,\n kv_group_num,\n sliding_window,\n BLOCK_DMODEL=v.shape[-1],\n BLOCK_N=BLOCK,\n num_warps=num_warps,\n num_stages=3,\n )\n return\n", - "description_1": "Use triton language to implement a forward kernel (_fwd_kernel) that performs a softmax reduction over a sliding window of attention scores. The kernel takes 22 parameters: Logics, V, Out, Req_to_tokens, B_req_idx, B_Start_Loc, B_Seqlen, B_Att_Start_Loc, B_Att_Seqlen, stride_logic_h, stride_logic_bs, stride_vbs, stride_vh, stride_vd, stride_obs, stride_oh, stride_od, stride_req_to_token_b, stride_req_to_token_s, other_kv_index, kv_group_num, and sliding_window. It also uses two constexpr parameters: BLOCK_DMODEL and BLOCK_N. The kernel computes the softmax of the attention scores and applies it to the value vectors V, storing the result in Out. The function token_softmax_reducev_fwd is a wrapper that sets up the grid and block dimensions and calls the kernel with the appropriate parameters.", - "description_2": "Use triton language to implement a softmax reduction kernel for attention mechanisms, processing data in blocks and handling sliding windows. The kernel is invoked by a wrapper function that configures execution parameters.", - "difficulty": 4 - }, - { - "code": "import triton\nimport triton.language as tl\nimport torch\n\n@triton.jit\ndef _fwd_kernel(\n Q,\n K,\n V,\n sm_scale,\n B_Start_Loc,\n B_Seqlen, # B_LOC records the actual location of each batch input, B_SEQ_len records the actual length of the current input\n Out,\n Req_to_tokens,\n B_req_idx,\n stride_qbs,\n stride_qh,\n stride_qd,\n stride_kbs,\n stride_kh,\n stride_kd,\n stride_vbs,\n stride_vh,\n stride_vd,\n stride_obs,\n stride_oh,\n stride_od,\n stride_req_to_tokens_b,\n stride_req_to_tokens_s,\n kv_group_num,\n b_prompt_cache_len,\n head_dim: tl.constexpr,\n BLOCK_M: tl.constexpr,\n BLOCK_DMODEL: tl.constexpr,\n BLOCK_N: tl.constexpr,\n):\n # Triton kernel for forward pass with prompt cache\n cur_batch = tl.program_id(0)\n cur_head = tl.program_id(1)\n start_m = tl.program_id(2)\n\n cur_kv_head = cur_head // kv_group_num\n\n cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)\n prompt_cache_len = tl.load(b_prompt_cache_len + cur_batch)\n cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) - prompt_cache_len\n cur_batch_req_idx = tl.load(B_req_idx + cur_batch)\n\n block_start_loc = BLOCK_M * start_m\n\n # initialize offsets\n offs_n = tl.arange(0, BLOCK_N)\n offs_d = tl.arange(0, BLOCK_DMODEL)\n offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)\n off_q = (\n (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs\n + cur_head * stride_qh\n + offs_d[None, :] * stride_qd\n )\n\n q = tl.load(Q + off_q, mask=(offs_m[:, None] < cur_batch_seq_len) & (offs_d[None, :] < head_dim), other=0.0)\n\n # initialize pointer to m and l\n m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float(\"inf\")\n l_i = tl.zeros([BLOCK_M], dtype=tl.float32)\n acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)\n\n block_mask = tl.where(block_start_loc < cur_batch_seq_len, 1, 0)\n block_end_loc = tl.minimum((start_m + 1) * BLOCK_M + prompt_cache_len, cur_batch_seq_len + prompt_cache_len)\n\n for start_n in range(0, block_mask * block_end_loc, BLOCK_N):\n start_n = tl.multiple_of(start_n, BLOCK_N)\n # -- compute qk ----\n kv_loc = tl.load(\n Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + stride_req_to_tokens_s * (start_n + offs_n),\n mask=(start_n + offs_n) < block_end_loc,\n other=0,\n )\n off_k = kv_loc[None, :] * stride_kbs + cur_kv_head * stride_kh + offs_d[:, None] * stride_kd\n k = tl.load(\n K + off_k, mask=((start_n + offs_n[None, :]) < block_end_loc) & (offs_d[:, None] < head_dim), other=0.0\n )\n\n qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)\n qk += tl.dot(q, k)\n qk *= sm_scale\n qk = tl.where(offs_m[:, None] + prompt_cache_len >= start_n + offs_n[None, :], qk, float(\"-100000000.0\"))\n\n # -- compute m_ij, p, l_ij\n m_ij = tl.max(qk, 1)\n p = tl.exp(qk - m_ij[:, None])\n l_ij = tl.sum(p, 1)\n # -- update m_i and l_i\n m_i_new = tl.maximum(m_i, m_ij)\n alpha = tl.exp(m_i - m_i_new)\n beta = tl.exp(m_ij - m_i_new)\n l_i_new = alpha * l_i + beta * l_ij\n # -- update output accumulator --\n # scale p\n p_scale = beta / l_i_new\n p = p * p_scale[:, None]\n # scale acc\n acc_scale = l_i / l_i_new * alpha\n acc_scale = tl.where(offs_m + prompt_cache_len >= start_n, acc_scale, 1.0)\n acc = acc * acc_scale[:, None]\n # update acc\n off_v = kv_loc[:, None] * stride_vbs + cur_kv_head * stride_vh + offs_d[None, :] * stride_vd\n v = tl.load(\n V + off_v, mask=((start_n + offs_n[:, None]) < block_end_loc) & (offs_d[None, :] < head_dim), other=0.0\n )\n p = p.to(v.dtype)\n acc += tl.dot(p, v)\n # update m_i and l_i\n l_i = l_i_new\n m_i = m_i_new\n # initialize pointers to output\n off_o = (\n (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs\n + cur_head * stride_oh\n + offs_d[None, :] * stride_od\n )\n out_ptrs = Out + off_o\n tl.store(out_ptrs, acc, mask=(offs_m[:, None] < cur_batch_seq_len) & (offs_d[None, :] < head_dim))\n return\n\n\n@torch.no_grad()\ndef context_attention_fwd(\n q, k, v, o, b_req_idx, b_start_loc, b_seq_len, b_prompt_cache_len, max_input_len, req_to_token_indexs\n):\n # Forward function for Triton kernel with prompt cache\n BLOCK = 128 if not TESLA else 64\n # shape constraints\n Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]\n assert Lq == Lk and Lk == Lv\n head_dim = Lq\n BLOCK_DMODEL = triton.next_power_of_2(head_dim)\n\n sm_scale = 1.0 / (Lq ** 0.5) # calculate scale coefficient\n batch, head = b_seq_len.shape[0], q.shape[1]\n kv_group_num = q.shape[1] // k.shape[1]\n\n grid = (batch, head, triton.cdiv(max_input_len, BLOCK)) # batch, head,\n\n num_warps = 4 if Lk <= 64 else 8\n _fwd_kernel[grid](\n q,\n k,\n v,\n sm_scale,\n b_start_loc,\n b_seq_len,\n o,\n req_to_token_indexs,\n b_req_idx,\n q.stride(0),\n q.stride(1),\n q.stride(2),\n k.stride(0),\n k.stride(1),\n k.stride(2),\n v.stride(0),\n v.stride(1),\n v.stride(2),\n o.stride(0),\n o.stride(1),\n o.stride(2),\n req_to_token_indexs.stride(0),\n req_to_token_indexs.stride(1),\n kv_group_num=kv_group_num,\n b_prompt_cache_len=b_prompt_cache_len,\n head_dim=head_dim,\n BLOCK_M=BLOCK,\n BLOCK_DMODEL=BLOCK_DMODEL,\n BLOCK_N=BLOCK,\n num_warps=num_warps,\n num_stages=1,\n )\n return\n\n\n@triton.jit\ndef _fwd_kernel_no_prompt_cache(\n Q,\n K,\n V,\n sm_scale,\n B_Start_Loc,\n B_Seqlen, # B_LOC records the actual location of each batch input, B_SEQ_len records the actual length of the current input\n Out,\n stride_qbs,\n stride_qh,\n stride_qd,\n stride_kbs,\n stride_kh,\n stride_kd,\n stride_vbs,\n stride_vh,\n stride_vd,\n stride_obs,\n stride_oh,\n stride_od,\n kv_group_num,\n head_dim,\n BLOCK_M: tl.constexpr,\n BLOCK_DMODEL: tl.constexpr,\n BLOCK_N: tl.constexpr,\n):\n # Triton kernel for forward pass without prompt cache\n cur_batch = tl.program_id(0)\n cur_head = tl.program_id(1)\n start_m = tl.program_id(2)\n\n cur_kv_head = cur_head // kv_group_num\n\n cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)\n cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)\n\n block_start_loc = BLOCK_M * start_m\n\n # initialize offsets\n offs_n = tl.arange(0, BLOCK_N)\n offs_d = tl.arange(0, BLOCK_DMODEL)\n offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)\n off_q = (\n (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs\n + cur_head * stride_qh\n + offs_d[None, :] * stride_qd\n )\n off_k = offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh + offs_d[:, None] * stride_kd\n off_v = offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh + offs_d[None, :] * stride_vd\n\n q = tl.load(Q + off_q, mask=(offs_m[:, None] < cur_batch_seq_len) & (offs_d[None, :] < head_dim), other=0.0)\n\n k_ptrs = K + off_k\n v_ptrs = V + off_v\n\n # initialize pointer to m and l\n m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float(\"inf\")\n l_i = tl.zeros([BLOCK_M], dtype=tl.float32)\n acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)\n\n block_mask = tl.where(block_start_loc < cur_batch_seq_len, 1, 0)\n\n for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N):\n start_n = tl.multiple_of(start_n, BLOCK_N)\n # -- compute qk ----\n k = tl.load(\n k_ptrs + (cur_batch_in_all_start_index + start_n) * stride_kbs,\n mask=((start_n + offs_n[None, :]) < cur_batch_seq_len) & (offs_d[:, None] < head_dim),\n other=0.0,\n )\n # mask = tl.load(mask_ptrs + start_n, mask=start_n + offs_n < cur_batch_end_loc, other=0.0)\n\n qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)\n qk += tl.dot(q, k)\n qk *= sm_scale\n qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float(\"-inf\"))\n\n # -- compute m_ij, p, l_ij\n m_ij = tl.max(qk, 1)\n p = tl.exp(qk - m_ij[:, None])\n l_ij = tl.sum(p, 1)\n # -- update m_i and l_i\n m_i_new = tl.maximum(m_i, m_ij)\n alpha = tl.exp(m_i - m_i_new)\n beta = tl.exp(m_ij - m_i_new)\n l_i_new = alpha * l_i + beta * l_ij\n # -- update output accumulator --\n # scale p\n p_scale = beta / l_i_new\n p = p * p_scale[:, None]\n # scale acc\n acc_scale = l_i / l_i_new * alpha\n acc = acc * acc_scale[:, None]\n # update acc\n v = tl.load(\n v_ptrs + (cur_batch_in_all_start_index + start_n) * stride_vbs,\n mask=((start_n + offs_n[:, None]) < cur_batch_seq_len) & (offs_d[None, :] < head_dim),\n other=0.0,\n )\n\n p = p.to(v.dtype)\n acc += tl.dot(p, v)\n # update m_i and l_i\n l_i = l_i_new\n m_i = m_i_new\n # initialize pointers to output\n off_o = (\n (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs\n + cur_head * stride_oh\n + offs_d[None, :] * stride_od\n )\n out_ptrs = Out + off_o\n tl.store(out_ptrs, acc, mask=(offs_m[:, None] < cur_batch_seq_len) & (offs_d[None, :] < head_dim))\n return\n\n\n@torch.no_grad()\ndef context_attention_fwd_no_prompt_cache(q, k, v, o, b_start_loc, b_seq_len, max_input_len):\n # Forward function for Triton kernel without prompt cache\n BLOCK = 128 if not TESLA else 64\n # shape constraints\n Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]\n assert Lq == Lk and Lk == Lv\n head_dim = Lq\n BLOCK_DMODEL = triton.next_power_of_2(head_dim)\n sm_scale = 1.0 / (Lq ** 0.5) # calculate scale coefficient\n batch, head = b_seq_len.shape[0], q.shape[1]\n kv_group_num = q.shape[1] // k.shape[1]\n\n grid = (batch, head, triton.cdiv(max_input_len, BLOCK)) # batch, head,\n\n num_warps = 4 if Lk <= 64 else 8\n _fwd_kernel_no_prompt_cache[grid](\n q,\n k,\n v,\n sm_scale,\n b_start_loc,\n b_seq_len,\n o,\n q.stride(0),\n q.stride(1),\n q.stride(2),\n k.stride(0),\n k.stride(1),\n k.stride(2),\n v.stride(0),\n v.stride(1),\n v.stride(2),\n o.stride(0),\n o.stride(1),\n o.stride(2),\n kv_group_num=kv_group_num,\n head_dim=head_dim,\n BLOCK_M=BLOCK,\n BLOCK_DMODEL=BLOCK_DMODEL,\n BLOCK_N=BLOCK,\n num_warps=num_warps,\n num_stages=1,\n )\n return\n", - "description_1": "Use triton language to implement two attention forward kernels: one with prompt cache and one without, involving batch processing of queries, keys, and values, computing the scaled dot-product attention while utilizing memory optimally. The kernel accepts data and layout descriptors, applies a scaling factor, calculates attention scores, and updates accumulated attention values efficiently across multi-heads and batch dimensions.", - "description_2": "Implement Triton kernels for efficient batch processing of scaled dot-product attention with and without prompt cache across multi-heads, utilizing memory efficiently.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef _fwd_kernel_destindex_copy_kv(\n K,\n Dest_loc,\n Out,\n stride_k_bs,\n stride_k_h,\n stride_k_d,\n stride_o_bs,\n stride_o_h,\n stride_o_d,\n head_num,\n head_dim,\n BLOCK_DMODEL: tl.constexpr,\n BLOCK_HEAD: tl.constexpr,\n):\n cur_index = tl.program_id(0)\n offs_h = tl.arange(0, BLOCK_HEAD)\n offs_d = tl.arange(0, BLOCK_DMODEL)\n\n dest_index = tl.load(Dest_loc + cur_index)\n\n k_ptrs = K + cur_index * stride_k_bs + stride_k_h * offs_h[:, None] + stride_k_d * offs_d[None, :]\n o_ptrs = Out + dest_index * stride_o_bs + stride_o_h * offs_h[:, None] + stride_o_d * offs_d[None, :]\n\n k = tl.load(k_ptrs, mask=(offs_h[:, None] < head_num) & (offs_d[None, :] < head_dim), other=0.0)\n tl.store(o_ptrs, k, mask=(offs_h[:, None] < head_num) & (offs_d[None, :] < head_dim))\n return\n\n@torch.no_grad()\ndef destindex_copy_kv(K, DestLoc, Out):\n seq_len = DestLoc.shape[0]\n head_num = K.shape[1]\n head_dim = K.shape[2]\n assert K.shape[1] == Out.shape[1] and K.shape[2] == Out.shape[2]\n BLOCK_HEAD = triton.next_power_of_2(head_num)\n BLOCK_DMODEL = triton.next_power_of_2(head_dim)\n grid = (seq_len,)\n num_warps = 1\n\n _fwd_kernel_destindex_copy_kv[grid](\n K,\n DestLoc,\n Out,\n K.stride(0),\n K.stride(1),\n K.stride(2),\n Out.stride(0),\n Out.stride(1),\n Out.stride(2),\n head_num,\n head_dim,\n BLOCK_DMODEL=BLOCK_DMODEL,\n BLOCK_HEAD=BLOCK_HEAD,\n num_warps=num_warps,\n num_stages=1,\n )\n return\n\n@triton.jit\ndef _fwd_kernel_destindex_copy_quantize_kv(\n K,\n Dest_loc,\n Out,\n Out_scale,\n stride_k_bs,\n stride_k_h,\n stride_k_d,\n stride_o_bs,\n stride_o_h,\n stride_o_d,\n stride_os_bs,\n stride_os_h,\n stride_os_d,\n head_num,\n head_dim,\n BLOCK_DMODEL: tl.constexpr,\n BLOCK_HEAD: tl.constexpr,\n):\n cur_index = tl.program_id(0)\n offs_h = tl.arange(0, BLOCK_HEAD)\n offs_d = tl.arange(0, BLOCK_DMODEL)\n\n dest_index = tl.load(Dest_loc + cur_index)\n src_data = tl.load(\n K + cur_index * stride_k_bs + offs_h[:, None] * stride_k_h + stride_k_d * offs_d[None, :],\n mask=(offs_h[:, None] < head_num) & (offs_d[None, :] < head_dim),\n other=0.0,\n )\n abs_data = tl.abs(src_data)\n data_scale = (tl.max(abs_data, axis=1) / 127.0).to(Out_scale.dtype.element_ty)[:, None]\n q_src_data = (src_data / data_scale).to(tl.int8)\n o_ptrs = Out + dest_index * stride_o_bs + stride_o_h * offs_h[:, None] + stride_o_d * offs_d[None, :]\n os_ptrs = Out_scale + dest_index * stride_os_bs + stride_os_h * offs_h[:, None]\n tl.store(o_ptrs, q_src_data, mask=(offs_h[:, None] < head_num) & (offs_d[None, :] < head_dim))\n tl.store(os_ptrs, data_scale, mask=(offs_h[:, None] < head_num))\n\n@torch.no_grad()\ndef destindex_copy_quantize_kv(K, DestLoc, Out, Out_scale):\n seq_len = DestLoc.shape[0]\n head_num = K.shape[1]\n head_dim = K.shape[2]\n assert K.shape[1] == Out.shape[1] and K.shape[2] == Out.shape[2]\n BLOCK_HEAD = triton.next_power_of_2(head_num)\n BLOCK_DMODEL = triton.next_power_of_2(head_dim)\n grid = (seq_len,)\n num_warps = 1\n\n _fwd_kernel_destindex_copy_quantize_kv[grid](\n K,\n DestLoc,\n Out,\n Out_scale,\n K.stride(0),\n K.stride(1),\n K.stride(2),\n Out.stride(0),\n Out.stride(1),\n Out.stride(2),\n Out_scale.stride(0),\n Out_scale.stride(1),\n Out_scale.stride(2),\n head_num,\n head_dim,\n BLOCK_DMODEL=BLOCK_DMODEL,\n BLOCK_HEAD=BLOCK_HEAD,\n num_warps=num_warps,\n num_stages=1,\n )\n return\n", - "description_1": "Use triton language to implement two kernels: one for copying data from a source tensor to a destination tensor based on a destination index, and another for copying and quantizing data. The first kernel (_fwd_kernel_destindex_copy_kv) takes 12 parameters: source tensor K, destination index Dest_loc, output tensor Out, strides for K and Out, head number, head dimension, and block sizes. The second kernel (_fwd_kernel_destindex_copy_quantize_kv) takes 15 parameters: source tensor K, destination index Dest_loc, output tensor Out, output scale tensor Out_scale, strides for K, Out, and Out_scale, head number, head dimension, and block sizes. Both kernels use Triton's parallel programming model to perform operations across multiple program instances.", - "description_2": "Use triton language to create kernels for copying and quantizing data with specified block sizes and strides, utilizing parallel execution.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef _fwd_kernel_flash_decode_stage1(\n Q,\n K,\n V,\n sm_scale,\n Req_to_tokens,\n B_req_idx,\n B_Seqlen,\n Mid_O,\n Mid_O_LogExpSum,\n stride_req_to_tokens_b,\n stride_req_to_tokens_s,\n stride_qbs,\n stride_qh,\n stride_qd,\n stride_kbs,\n stride_kh,\n stride_kd,\n stride_vbs,\n stride_vh,\n stride_vd,\n stride_mid_ob,\n stride_mid_oh,\n stride_mid_os,\n stride_mid_od,\n stride_mid_o_eb,\n stride_mid_o_eh,\n stride_mid_o_es,\n gqa_group_size,\n head_dim,\n BLOCK_SEQ: tl.constexpr,\n BLOCK_DMODEL: tl.constexpr,\n BLOCK_N: tl.constexpr,\n):\n cur_batch = tl.program_id(0)\n cur_head = tl.program_id(1)\n seq_start_block = tl.program_id(2)\n cur_kv_head = cur_head // gqa_group_size\n\n offs_d = tl.arange(0, BLOCK_DMODEL)\n cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)\n cur_batch_req_idx = tl.load(B_req_idx + cur_batch)\n cur_batch_start_index = seq_start_block * BLOCK_SEQ\n cur_batch_end_index = tl.minimum(cur_batch_seq_len, cur_batch_start_index + BLOCK_SEQ)\n\n off_q = cur_batch * stride_qbs + cur_head * stride_qh + offs_d\n\n block_n_size = (\n tl.where(\n cur_batch_end_index - cur_batch_start_index <= 0,\n 0,\n cur_batch_end_index - cur_batch_start_index + BLOCK_N - 1,\n )\n // BLOCK_N\n )\n\n offs_n = cur_batch_start_index + tl.arange(0, BLOCK_N)\n\n q = tl.load(Q + off_q, mask=offs_d < head_dim, other=0.0)\n\n sum_exp = 0.0\n max_logic = -float(\"inf\")\n acc = tl.zeros([BLOCK_DMODEL], dtype=tl.float32)\n\n for start_n in range(0, block_n_size, 1):\n offs_n_new = start_n * BLOCK_N + offs_n\n k_loc = tl.load(\n Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + offs_n_new,\n mask=offs_n_new < cur_batch_end_index,\n other=0,\n )\n off_k = k_loc[:, None] * stride_kbs + cur_kv_head * stride_kh + offs_d[None, :]\n k = tl.load(\n K + off_k, mask=(offs_n_new[:, None] < cur_batch_end_index) & (offs_d[None, :] < head_dim), other=0.0\n )\n att_value = tl.sum(q[None, :] * k, 1)\n att_value *= sm_scale\n att_value = tl.where(offs_n_new < cur_batch_end_index, att_value, float(\"-inf\"))\n v = tl.load(\n V + off_k, mask=(offs_n_new[:, None] < cur_batch_end_index) & (offs_d[None, :] < head_dim), other=0.0\n )\n\n cur_max_logic = tl.max(att_value, axis=0)\n new_max_logic = tl.maximum(cur_max_logic, max_logic)\n\n exp_logic = tl.exp(att_value - new_max_logic)\n logic_scale = tl.exp(max_logic - new_max_logic)\n acc *= logic_scale\n acc += tl.sum(exp_logic[:, None] * v, axis=0)\n\n sum_exp = sum_exp * logic_scale + tl.sum(exp_logic, axis=0)\n max_logic = new_max_logic\n\n need_store = tl.where(block_n_size == 0, 0, 1)\n for _ in range(0, need_store, 1):\n off_mid_o = cur_batch * stride_mid_ob + cur_head * stride_mid_oh + seq_start_block * stride_mid_os + offs_d\n off_mid_o_logexpsum = cur_batch * stride_mid_o_eb + cur_head * stride_mid_o_eh + seq_start_block\n tl.store(Mid_O + off_mid_o, acc / sum_exp, mask=offs_d < head_dim)\n tl.store(Mid_O_LogExpSum + off_mid_o_logexpsum, max_logic + tl.log(sum_exp))\n return\n\n@torch.no_grad()\ndef flash_decode_stage1(\n q, k, v, Req_to_tokens, B_req_idx, B_Seqlen, max_len_in_batch, mid_out, mid_out_logsumexp, block_seq\n):\n BLOCK_SEQ = block_seq\n BLOCK_N = 16\n assert BLOCK_SEQ % BLOCK_N == 0\n Lq, Lk = q.shape[-1], k.shape[-1]\n assert Lq == Lk\n head_dim = Lq\n BLOCK_DMODEL = triton.next_power_of_2(head_dim)\n sm_scale = 1.0 / (Lk ** 0.5)\n batch, head_num = B_req_idx.shape[0], q.shape[1]\n grid = (batch, head_num, triton.cdiv(max_len_in_batch, BLOCK_SEQ))\n gqa_group_size = q.shape[1] // k.shape[1]\n\n _fwd_kernel_flash_decode_stage1[grid](\n q,\n k,\n v,\n sm_scale,\n Req_to_tokens,\n B_req_idx,\n B_Seqlen,\n mid_out,\n mid_out_logsumexp,\n Req_to_tokens.stride(0),\n Req_to_tokens.stride(1),\n q.stride(0),\n q.stride(1),\n q.stride(2),\n k.stride(0),\n k.stride(1),\n k.stride(2),\n v.stride(0),\n v.stride(1),\n v.stride(2),\n mid_out.stride(0),\n mid_out.stride(1),\n mid_out.stride(2),\n mid_out.stride(3),\n mid_out_logsumexp.stride(0),\n mid_out_logsumexp.stride(1),\n mid_out_logsumexp.stride(2),\n gqa_group_size,\n head_dim,\n BLOCK_SEQ=BLOCK_SEQ,\n BLOCK_DMODEL=BLOCK_DMODEL,\n BLOCK_N=BLOCK_N,\n num_warps=1,\n num_stages=2,\n )\n return\n", - "description_1": "Use triton language to implement a flash attention decoding kernel for processing query, key, and value matrices in blocks. This kernel computes the scaled dot-product attention over multiple heads and batches. The kernel takes 31 parameters: 6 input tensors (Q, K, V, Req_to_tokens, B_req_idx, B_Seqlen), 2 output tensors (Mid_O, Mid_O_LogExpSum), 17 strides to handle data layout, group size for GQA, head dimension, and 3 block sizes (BLOCK_SEQ, BLOCK_DMODEL, BLOCK_N) specified as constant expressions.", - "description_2": "Use triton language to implement a decoding kernel for multi-head attention, processing query, key, and value matrices in parallel, with a total of 31 parameters.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef _rotary_kernel(\n Q,\n K,\n Cos,\n Sin,\n stride_qbs,\n stride_qh,\n stride_qd,\n stride_kbs,\n stride_kh,\n stride_kd,\n stride_cosbs,\n stride_cosd,\n stride_sinbs,\n stride_sind,\n max_total_len,\n HEAD_Q,\n HEAD_K, # N_CTX 代表要计算的上下文长度\n rot_dim,\n head_dim,\n BLOCK_HEAD: tl.constexpr,\n BLOCK_SEQ: tl.constexpr,\n BLOCK_DMODEL: tl.constexpr,\n):\n cur_head_index = tl.program_id(0)\n cur_seq_index = tl.program_id(1)\n\n cur_head_range = cur_head_index * BLOCK_HEAD + tl.arange(0, BLOCK_HEAD)\n cur_seq_range = cur_seq_index * BLOCK_SEQ + tl.arange(0, BLOCK_SEQ)\n\n # dim_range0 = tl.arange(0, BLOCK_DMODEL // 2)\n # dim_range1 = tl.arange(BLOCK_DMODEL // 2, BLOCK_DMODEL)\n\n dim_range0 = tl.arange(0, BLOCK_DMODEL)\n dim_range1 = rot_dim + tl.arange(0, BLOCK_DMODEL)\n\n off_q0 = (\n cur_seq_range[:, None, None] * stride_qbs\n + cur_head_range[None, :, None] * stride_qh\n + dim_range0[None, None, :] * stride_qd\n )\n off_q1 = (\n cur_seq_range[:, None, None] * stride_qbs\n + cur_head_range[None, :, None] * stride_qh\n + dim_range1[None, None, :] * stride_qd\n )\n\n off_dimcos_sin = cur_seq_range[:, None, None] * stride_cosbs + dim_range0[None, None, :] * stride_cosd\n\n q0 = tl.load(\n Q + off_q0,\n mask=(cur_seq_range[:, None, None] < max_total_len)\n & (cur_head_range[None, :, None] < HEAD_Q)\n & (dim_range0[None, None, :] < rot_dim),\n other=0.0,\n )\n q1 = tl.load(\n Q + off_q1,\n mask=(cur_seq_range[:, None, None] < max_total_len)\n & (cur_head_range[None, :, None] < HEAD_Q)\n & (dim_range1[None, None, :] < head_dim),\n other=0.0,\n )\n\n cos = tl.load(Cos + off_dimcos_sin, mask=cur_seq_range[:, None, None] < max_total_len, other=0.0)\n sin = tl.load(Sin + off_dimcos_sin, mask=cur_seq_range[:, None, None] < max_total_len, other=0.0)\n\n out0 = q0 * cos - q1 * sin\n out1 = q0 * sin + q1 * cos\n\n tl.store(\n Q + off_q0,\n out0,\n mask=(cur_seq_range[:, None, None] < max_total_len)\n & (cur_head_range[None, :, None] < HEAD_Q)\n & (dim_range0[None, None, :] < rot_dim),\n )\n tl.store(\n Q + off_q1,\n out1,\n mask=(cur_seq_range[:, None, None] < max_total_len)\n & (cur_head_range[None, :, None] < HEAD_Q)\n & (dim_range1[None, None, :] < head_dim),\n )\n\n off_k0 = (\n cur_seq_range[:, None, None] * stride_kbs\n + cur_head_range[None, :, None] * stride_kh\n + dim_range0[None, None, :] * stride_kd\n )\n off_k1 = (\n cur_seq_range[:, None, None] * stride_kbs\n + cur_head_range[None, :, None] * stride_kh\n + dim_range1[None, None, :] * stride_kd\n )\n\n off_dimcos_sin = cur_seq_range[:, None, None] * stride_cosbs + dim_range0[None, None, :] * stride_cosd\n\n k0 = tl.load(\n K + off_k0,\n mask=(cur_seq_range[:, None, None] < max_total_len)\n & (cur_head_range[None, :, None] < HEAD_K)\n & (dim_range0[None, None, :] < rot_dim),\n other=0.0,\n )\n k1 = tl.load(\n K + off_k1,\n mask=(cur_seq_range[:, None, None] < max_total_len)\n & (cur_head_range[None, :, None] < HEAD_K)\n & (dim_range1[None, None, :] < head_dim),\n other=0.0,\n )\n cos = tl.load(Cos + off_dimcos_sin, mask=cur_seq_range[:, None, None] < max_total_len, other=0.0)\n sin = tl.load(Sin + off_dimcos_sin, mask=cur_seq_range[:, None, None] < max_total_len, other=0.0)\n\n out_k0 = k0 * cos - k1 * sin\n out_k1 = k0 * sin + k1 * cos\n\n tl.store(\n K + off_k0,\n out_k0,\n mask=(cur_seq_range[:, None, None] < max_total_len)\n & (cur_head_range[None, :, None] < HEAD_K)\n & (dim_range0[None, None, :] < rot_dim),\n )\n tl.store(\n K + off_k1,\n out_k1,\n mask=(cur_seq_range[:, None, None] < max_total_len)\n & (cur_head_range[None, :, None] < HEAD_K)\n & (dim_range1[None, None, :] < head_dim),\n )\n return\n\n@torch.no_grad()\ndef rotary_emb_fwd(q, k, cos, sin, partial_rotary_factor=1.0):\n total_len = q.shape[0]\n head_num_q, head_num_k = q.shape[1], k.shape[1]\n head_dim = int(q.shape[2] * partial_rotary_factor)\n rot_dim = head_dim // 2\n assert q.shape[0] == cos.shape[0] and q.shape[0] == sin.shape[0], f\"q shape {q.shape} cos shape {cos.shape}\"\n assert k.shape[0] == cos.shape[0] and k.shape[0] == sin.shape[0], f\"k shape {k.shape} cos shape {cos.shape}\"\n\n BLOCK_SEQ = 16\n BLOCK_HEAD = 4\n if head_dim >= 128:\n num_warps = 8\n else:\n num_warps = 4\n BLOCK_DMODEL = triton.next_power_of_2(rot_dim)\n grid = (triton.cdiv(head_num_q, BLOCK_HEAD), triton.cdiv(total_len, BLOCK_SEQ))\n _rotary_kernel[grid](\n q,\n k,\n cos,\n sin,\n q.stride(0),\n q.stride(1),\n q.stride(2),\n k.stride(0),\n k.stride(1),\n k.stride(2),\n cos.stride(0),\n cos.stride(1),\n sin.stride(0),\n sin.stride(1),\n total_len,\n head_num_q,\n head_num_k,\n rot_dim,\n head_dim,\n BLOCK_HEAD=BLOCK_HEAD,\n BLOCK_SEQ=BLOCK_SEQ,\n BLOCK_DMODEL=BLOCK_DMODEL,\n num_warps=num_warps,\n num_stages=1,\n )\n return\n", - "description_1": "Use triton language to implement a rotary positional embedding kernel. The kernel _rotary_kernel accepts 23 parameters: two tensors Q and K for input data, Cos and Sin for the trigonometric values, 8 strides for these tensors, max_total_len for limiting sequences, HEAD_Q and HEAD_K for the number of heads, rot_dim and head_dim for dimensions, and BLOCK_HEAD, BLOCK_SEQ, BLOCK_DMODEL for compile-time constants. It performs rotation of vectors by leveraging cosine and sine transformations and writes the output back to Q and K tensors with consideration of specified dimensions and heads. Additionally, a rotary_emb_fwd function invokes this kernel, adjusting block and grid sizes based on the input dimensions.", - "description_2": "Use triton language to build a kernel for rotary positional embeddings in transformers, which processes input tensors Q and K using cosine and sine operations and writes the modified tensors back, controlled by parameters for sequence length and heads.", - "difficulty": 4 - }, - { - "code": "import triton\nimport triton.language as tl\nimport torch\n\n@triton.jit\ndef add_kernel(X, Y, Z, N):\n pid = tl.program_id(0)\n block_start = pid * 1024\n offsets = block_start + tl.arange(0, 1024)\n mask = offsets < N\n x = tl.load(X + offsets, mask=mask)\n y = tl.load(Y + offsets, mask=mask)\n z = x + y\n tl.store(Z + offsets, z, mask=mask)\n\ndef call_add_kernel(X, Y, Z, N):\n grid = lambda meta: (triton.cdiv(N, 1024),)\n add_kernel[grid](X, Y, Z, N)\n\n# Example usage\nX = torch.randn(1024, device='cuda')\nY = torch.randn(1024, device='cuda')\nZ = torch.empty(1024, device='cuda')\nN = X.numel()\ncall_add_kernel(X, Y, Z, N)\n", - "description_1": "Use triton language to define a kernel 'add_kernel' that takes four parameters: X, Y, Z, and N. X, Y, and Z are pointers to the input and output tensors, and N is the number of elements. The kernel adds corresponding elements of X and Y and stores the result in Z. The kernel is launched with a grid size calculated based on N.", - "description_2": "Use triton language to create a kernel that performs element-wise addition of two input tensors and stores the result in an output tensor, with the kernel launch grid size determined by the number of elements.", - "difficulty": 2 - }, - { - "code": "import triton\nimport torch\n\n# Triton kernel decorated with @triton.jit\n@triton.jit\ndef kernel_function(x_ptr, y_ptr, size, BLOCK_SIZE: tl.constexpr):\n pid = tl.program_id(0)\n block_start = pid * BLOCK_SIZE\n offsets = block_start + tl.arange(0, BLOCK_SIZE)\n mask = offsets < size\n x = tl.load(x_ptr + offsets, mask=mask)\n y = x * 2.0 # Example computation\n tl.store(y_ptr + offsets, y, mask=mask)\n\ndef call_kernel(x, y, block_size):\n # Get pointers to the data\n x_ptr = x.data_ptr()\n y_ptr = y.data_ptr()\n size = x.numel()\n # Launch the kernel\n grid = lambda meta: (triton.cdiv(size, meta['BLOCK_SIZE']),)\n kernel_function[grid](x_ptr, y_ptr, size, BLOCK_SIZE=block_size)\n", - "description_1": "Use triton language to define a kernel function 'kernel_function' which takes pointers to input arrays 'x_ptr', 'y_ptr', a size 'size', and a block size 'BLOCK_SIZE'. The kernel computes element-wise multiplication of the input array by 2 and stores the result in the output array. It utilizes triton's parallel execution with a specified block size. The function 'call_kernel' sets up the data pointers and launches 'kernel_function' with appropriate grid configuration.", - "description_2": "Use triton language to create a kernel for element-wise multiplication by 2 of input array elements. The kernel should take pointers and size parameters, leveraging parallel computation with a defined block size. Implement a wrapper function to prepare and execute this kernel with triton's grid configuration.", - "difficulty": 2 - }, - { - "code": "import triton\nimport torch\n\n# Example Triton kernel\n@triton.jit\ndef example_kernel(X, Y, Z, BLOCK: tl.constexpr):\n # Triton kernel code\n pass\n\n# Function to call the Triton kernel\ndef call_example_kernel(x, y, z, block_size):\n # Call the Triton kernel\n example_kernel[(1,)](x, y, z, BLOCK=block_size)\n\n# Example usage\nx = torch.tensor([1.0, 2.0, 3.0], device='cuda')\ny = torch.tensor([4.0, 5.0, 6.0], device='cuda')\nz = torch.empty_like(x)\ncall_example_kernel(x, y, z, block_size=128)\n", - "description_1": "Use triton language to define a kernel named 'example_kernel' with three parameters: X, Y, Z, and a block size constant. The kernel performs operations on these parameters. A function 'call_example_kernel' is defined to launch this kernel with specified block size and input tensors x, y, z.", - "description_2": "Use triton language to define a kernel with three tensor inputs and a block size, and a function to launch this kernel with specified inputs.", - "difficulty": 2 - }, - { - "code": "import triton\nimport triton.language as tl\n\n# Triton kernel to promote a value to a tensor\n@triton.jit\ndef promote_to_tensor(x):\n return x + tl.zeros((1,), tl.int1)\n\n# Triton kernel to check if a tensor is floating point\n@triton.jit\ndef is_floating(x):\n return promote_to_tensor(x).dtype.is_floating()\n\n# Triton kernel to perform element-wise minimum operation\n@triton.jit\ndef minimum(a, b):\n mask = a < b\n if is_floating(a):\n mask |= a != a\n return tl.where(mask, a, b)\n\n# Triton kernel to perform element-wise maximum operation\n@triton.jit\ndef maximum(a, b):\n mask = a > b\n if is_floating(a):\n mask |= a != a\n return tl.where(mask, a, b)\n\n# Triton kernel to compute minimum along a specific dimension\n@triton.jit\ndef min2(a, dim):\n return tl.reduce(a, dim, minimum)\n\n# Triton kernel to compute maximum along a specific dimension\n@triton.jit\ndef max2(a, dim):\n return tl.reduce(a, dim, maximum)\n\n# Triton kernel to perform reduction using Welford's algorithm\n@triton.jit\ndef welford_reduce(value, mean, m2, weight, first_iteration):\n if first_iteration:\n new_weight = tl.full(weight.shape, 1, weight.dtype)\n new_mean = value\n new_m2 = tl.zeros_like(m2)\n else:\n delta = value - mean\n new_weight = weight + 1\n new_mean = mean + delta / new_weight\n new_m2 = m2 + delta * (value - new_mean)\n return new_mean, new_m2, new_weight\n\n# Triton kernel to pack a value and a flag into a single tensor\n@triton.jit\ndef pack_value_flag(\n value,\n flag,\n DTYPE_VALUE_AS_UINT: tl.constexpr,\n DTYPE_PACK: tl.constexpr,\n):\n DTYPE_VALUE_AS_UINT = tl.core._constexpr_to_value(DTYPE_VALUE_AS_UINT)\n bitwidth = DTYPE_VALUE_AS_UINT.primitive_bitwidth\n uv = value.to(DTYPE_VALUE_AS_UINT, bitcast=True).to(DTYPE_PACK)\n return flag.to(DTYPE_PACK) | (uv << bitwidth)\n\n# Triton kernel to unpack a value from a packed tensor\n@triton.jit\ndef unpack_value(\n pack,\n DTYPE_VALUE,\n DTYPE_VALUE_AS_UINT,\n):\n DTYPE_VALUE = tl.core._constexpr_to_value(DTYPE_VALUE)\n DTYPE_VALUE_AS_UINT = tl.core._constexpr_to_value(DTYPE_VALUE_AS_UINT)\n bitwidth = DTYPE_VALUE_AS_UINT.primitive_bitwidth\n value_uint = (pack >> bitwidth).to(DTYPE_VALUE_AS_UINT)\n return value_uint.to(DTYPE_VALUE, bitcast=True)\n\n# Triton kernel to unpack a flag from a packed tensor\n@triton.jit\ndef unpack_flag(pack, DTYPE_FLAG):\n return pack.to(DTYPE_FLAG)\n\n# Triton kernel to compute exclusive scan using decoupled lookback\n@triton.jit\ndef exclusive_scan_decoupled_lookback(\n scratch_base,\n block_value,\n index,\n combine_fn,\n init,\n DTYPE_VALUE_AS_UINT: tl.constexpr,\n DTYPE_PACK: tl.constexpr,\n):\n DTYPE_VALUE = block_value.dtype\n pack = pack_value_flag(\n block_value,\n tl.full(block_value.shape, 1, DTYPE_VALUE_AS_UINT),\n DTYPE_VALUE_AS_UINT,\n DTYPE_PACK,\n )\n tl.atomic_xchg(scratch_base + index, pack, sem=\"relaxed\")\n\n exclusive_prefix = init\n test_target = index - 1\n while test_target >= 0:\n flag = tl.full([], 0, DTYPE_VALUE_AS_UINT)\n while flag == 0:\n pack = tl.atomic_add(scratch_base + test_target, 0, sem=\"relaxed\")\n flag = unpack_flag(pack, DTYPE_VALUE_AS_UINT)\n\n value = unpack_value(pack, DTYPE_VALUE, DTYPE_VALUE_AS_UINT)\n exclusive_prefix = combine_fn(value, exclusive_prefix)\n\n if flag == 2:\n test_target = -1\n else:\n test_target = test_target - 1\n\n inclusive_prefix = combine_fn(exclusive_prefix, block_value)\n pack = pack_value_flag(\n inclusive_prefix,\n tl.full([], 2, DTYPE_VALUE_AS_UINT),\n DTYPE_VALUE_AS_UINT,\n DTYPE_PACK,\n )\n tl.atomic_xchg(scratch_base + index, pack, sem=\"relaxed\")\n return exclusive_prefix\n", - "description_1": "Use triton language to create kernels for: 1) Promoting a scalar to a tensor. 2) Checking if tensor elements are floating-point. 3) Element-wise minimum/maximum operations. 4) Minimum/maximum reduction along a dimension. 5) Welford reduction for mean and variance. 6) Packing and unpacking of values and flags. 7) Exclusive scan with decoupled lookback.", - "description_2": "Use triton language to develop kernels for element-wise operations, reductions, packing/unpacking, and exclusive scans with decoupled lookback.", - "difficulty": 4 - }, - { - "code": "import torch\nfrom torch.testing._internal.triton_utils import requires_cuda\nfrom torch.testing._internal.common_cuda import SM80OrLater\nfrom torch._inductor import config\nimport triton\nfrom torch.testing._internal.triton_utils import (\n add_kernel,\n add_kernel_2d_autotuned,\n add_kernel_autotuned,\n add_kernel_with_optional_param,\n)\n\n@requires_cuda\ndef test_triton_kernel(grid_type, num_dims, dynamic, autotune, device):\n class Model(torch.nn.Module):\n def __init__(self):\n super().__init__()\n\n def forward(self, x, y):\n output = torch.zeros_like(x)\n if autotune and num_dims == 2:\n x_elements = output.size()[0]\n y_elements = output.size()[1]\n else:\n n_elements = output.numel()\n\n if autotune and num_dims == 2:\n if grid_type == 1:\n grid = (x_elements, y_elements)\n elif grid_type == 2:\n grid = lambda meta: (\n triton.cdiv(x_elements, meta[\"BLOCK_SIZE_X\"]),\n triton.cdiv(y_elements, meta[\"BLOCK_SIZE_Y\"]),\n )\n else:\n def grid_fn(meta):\n return (\n triton.cdiv(x_elements, meta[\"BLOCK_SIZE_X\"]),\n triton.cdiv(y_elements, meta[\"BLOCK_SIZE_Y\"]),\n )\n grid = grid_fn\n else:\n if grid_type == 1:\n grid = (n_elements,)\n elif grid_type == 2:\n grid = lambda meta: (\n triton.cdiv(n_elements, meta[\"BLOCK_SIZE\"]),\n )\n else:\n def grid_fn(meta):\n return (triton.cdiv(n_elements, meta[\"BLOCK_SIZE\"]),)\n grid = grid_fn\n\n if autotune:\n if num_dims == 1:\n add_kernel_autotuned[grid](x, y, output, n_elements)\n else:\n add_kernel_2d_autotuned[grid](x, y, output, x_elements, y_elements)\n else:\n add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=16)\n return output\n\n dims = [10] * num_dims\n x = torch.randn(*dims, device=device)\n y = torch.randn(*dims, device=device)\n dynamic_shapes = []\n if dynamic:\n dim0_x = torch.export.Dim(\"dim0_x\", min=1, max=10)\n dim0_y = torch.export.Dim(\"dim0_y\", min=1, max=10)\n dynamic_shapes = {\"x\": {0: dim0_x}, \"y\": {0: dim0_y}}\n example_inputs = (x, y)\n config.patch({\"profile_bandwidth\": \"1\", \"profile_bandwidth_regex\": \"\"})\n model = Model().to(device)\n with torch.no_grad():\n model(*example_inputs)\n\n# Kernel to test grid configuration with optional dynamic shapes\ndef test_triton_kernel_with_none_input(device):\n class Model(torch.nn.Module):\n def __init__(self):\n super().__init__()\n\n def forward(self, x, y):\n n_elements = x.size()[0]\n BLOCK_SIZE = 1024\n\n output_wo_y = torch.empty_like(x)\n output_with_y = torch.empty_like(x)\n\n wo_kernel = add_kernel_with_optional_param[(1,)](\n x,\n None,\n output_wo_y,\n n_elements,\n ARGS_PASSED=\"one\",\n BLOCK_SIZE=BLOCK_SIZE,\n )\n with_kernel = add_kernel_with_optional_param[(1,)](\n x,\n y,\n output_with_y,\n n_elements,\n ARGS_PASSED=\"two\",\n BLOCK_SIZE=BLOCK_SIZE,\n )\n\n return 2.71 * output_wo_y + 3.14 * output_with_y\n\n example_inputs = (\n torch.randn(1023, device=device),\n torch.randn(1023, device=device),\n )\n\n model = Model().to(device)\n with torch.no_grad():\n model(*example_inputs)\n\n# Kernel to test grid configuration with optional dynamic shapes\ndef test_triton_kernel_equal_to_1_arg(device):\n class Model(torch.nn.Module):\n def forward(self, x, y):\n out = torch.empty_like(x)\n n_elements = x.numel()\n add_kernel[(n_elements,)](x, y, out, n_elements, BLOCK_SIZE=16)\n return out\n\n example_inputs = (\n torch.randn(1, device=device),\n torch.randn(1, device=device),\n )\n\n model = Model().to(device)\n with torch.no_grad():\n model(*example_inputs)\n\n# Kernel to test pass kernel\n@triton.jit\ndef pass_kernel(x, num):\n pass\n\ndef test_triton_kernel_dynamic_shape_with_div(device):\n class Model(torch.nn.Module):\n def __init__(self):\n super().__init__()\n\n def forward(self, x):\n num = x.numel() // 4\n grid = lambda meta: (triton.cdiv(num, 16),)\n pass_kernel[grid](x, num)\n return x\n\n x = torch.randn(10, device=device)\n dim0_x = torch.export.Dim(\"dim0_x\", min=1, max=10)\n dynamic_shapes = {\"x\": {0: dim0_x}}\n model = Model().to(device)\n with torch.no_grad():\n model(x)\n\n# Kernel to test pass kernel\ndef test_triton_kernel_reinterpret_view(device):\n @triton.jit\n def pass_kernel(x, y):\n pass\n\n class Model(torch.nn.Module):\n def __init__(self):\n super().__init__()\n\n def forward(self, x):\n out = torch.zeros_like(x[:, 4:])\n add_kernel[(10,)](\n in_ptr0=x[:, 3:-1],\n in_ptr1=x[:, 4:],\n out_ptr=out,\n n_elements=160,\n BLOCK_SIZE=16,\n )\n return out\n\n example_inputs = (torch.randn(10, 20, device=device),)\n model = Model().to(device)\n with torch.no_grad():\n model(*example_inputs)\n\n# Kernel to test scaled dot product attention\n@requires_cuda\n@unittest.skipIf(not SM80OrLater, \"bfloat16 only supported in sm80+\")\ndef test_sdpa(device):\n class Model(torch.nn.Module):\n def __init__(self):\n super().__init__()\n\n def forward(self, q, k, v):\n return torch.nn.functional.scaled_dot_product_attention(q, k, v)[0]\n\n example_inputs = (\n torch.randn(1, 48, 64, 64, dtype=torch.bfloat16, device=device),\n torch.randn(1, 48, 64, 64, dtype=torch.bfloat16, device=device),\n torch.randn(1, 48, 64, 64, dtype=torch.bfloat16, device=device),\n )\n model = Model().to(device)\n with torch.no_grad():\n model(*example_inputs)\n", - "description_1": "Use triton language to test various Triton kernel operations, including grid configuration, dynamic shape handling, kernel with optional and None inputs, reinterpret view, and scaled dot product attention. Each function demonstrates specific capabilities and configurations of Triton kernels, including handling dynamic shapes, optional inputs, and specific data types like bfloat16.", - "description_2": "Use triton language to test Triton kernels for dynamic shape handling and scaled dot product attention.", - "difficulty": 2 - }, - { - "code": "import torch\nimport triton\nfrom triton import language as tl\nfrom torch._C import _cuda_getCurrentRawStream as get_cuda_stream\nfrom torch._inductor.utils import instance_descriptor\nfrom torch._inductor.triton_heuristics import CachingAutotuner, grid, HeuristicType\nfrom torch._dynamo.utils import same\nimport unittest\n\n# Triton kernel to perform element-wise addition of two vectors\ndef autotune(configs, meta):\n def decorator(fn):\n return CachingAutotuner(\n fn,\n triton_meta=meta,\n configs=configs,\n save_cache_hook=False,\n mutated_arg_names=[\"in_out_ptr0\"],\n heuristic_type=HeuristicType.POINTWISE,\n )\n return decorator\n\n@autotune(\n configs=[\n triton.Config({\"XBLOCK\": 1}),\n triton.Config({\"XBLOCK\": 2}),\n ],\n meta={\n \"signature\": {0: \"*fp32\", 1: \"*fp32\", 2: \"i32\"},\n \"device\": 0,\n \"configs\": [instance_descriptor(divisible_by_16=(0, 1), equal_to_1=())],\n \"constants\": {},\n },\n)\n@triton.jit\ndef kernel(in_out_ptr0, in_ptr0, xnumel, XBLOCK: tl.constexpr):\n pid = tl.program_id(0)\n block_start = pid * XBLOCK\n offsets = block_start + tl.arange(0, XBLOCK)\n mask = offsets < xnumel\n x = tl.load(in_out_ptr0 + offsets, mask=mask)\n y = tl.load(in_ptr0 + offsets, mask=mask)\n output = x + y\n tl.store(in_out_ptr0 + offsets, output, mask=mask)\n\ndef test_triton_kernel():\n xnumel = 384\n in0 = torch.rand((xnumel,), device=\"cuda\", dtype=torch.float32)\n inout1 = torch.rand((xnumel,), device=\"cuda\", dtype=torch.float32)\n inout2 = inout1.clone()\n\n stream0 = get_cuda_stream(0)\n kernel.run(inout1, in0, xnumel, grid=grid(xnumel), stream=stream0)\n kernel.run(inout2, in0, xnumel, grid=grid(xnumel), stream=stream0)\n\n assert same(inout1, inout2, tol=0.001, equal_nan=True), \"failed autotune with inplace kernel\"\n\nif __name__ == \"__main__\":\n unittest.main(argv=[''], exit=False)\n", - "description_1": "Use triton language to define a kernel that performs element-wise addition of two vectors, leveraging CachingAutotuner for efficient execution. The kernel takes three parameters: the output/input vector, the input vector, and the number of elements. The XBLOCK parameter is used for parallel execution across blocks.", - "description_2": "Use triton language to create an autotuned kernel for in-place element-wise vector addition on CUDA devices.", - "difficulty": 2 - }, - { - "code": "import triton\nimport triton.language as tl\n\n# Triton kernel for fused addition and summation\n@triton.jit\ndef triton_red_fused_add_sum_2(in_out_ptr0, in_ptr0, xnumel, rnumel, XBLOCK: tl.constexpr, RBLOCK: tl.constexpr):\n xnumel = 1024\n rnumel = 2048\n xoffset = tl.program_id(0) * XBLOCK\n xindex = xoffset + tl.arange(0, XBLOCK)[:, None]\n xmask = xindex < xnumel\n rbase = tl.arange(0, RBLOCK)[None, :]\n x0 = xindex\n _tmp2 = tl.full([XBLOCK, RBLOCK], 0, tl.float32)\n for roffset in range(0, rnumel, RBLOCK):\n rindex = roffset + rbase\n rmask = rindex < rnumel\n r1 = rindex\n tmp0 = tl.load(in_ptr0 + (r1 + (2048 * x0)), rmask & xmask, eviction_policy='evict_first', other=0.0)\n tmp1 = tl.broadcast_to(tmp0, [XBLOCK, RBLOCK])\n tmp3 = _tmp2 + tmp1\n _tmp2 = tl.where(rmask & xmask, tmp3, _tmp2)\n tmp2 = tl.sum(_tmp2, 1)[:, None]\n tmp4 = tl.load(in_out_ptr0 + (x0), xmask, eviction_policy='evict_last')\n tmp5 = tmp4 + tmp2\n tl.debug_barrier()\n tl.store(in_out_ptr0 + (x0), tmp5, xmask)\n", - "description_1": "Use triton language to implement a kernel function 'triton_red_fused_add_sum_2' that performs a fused addition and summation operation. The kernel takes six parameters: two pointers to input/output data, two integer values representing the number of elements in the x and reduction dimensions, and two compile-time constants for block sizes. The kernel iterates over the reduction dimension, loads data, performs element-wise addition, and stores the result back to the output pointer.", - "description_2": "Use triton language to create a kernel for fused addition and summation with parameters for data pointers, element counts, and block sizes.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nfrom triton import language as tl\n\n# This Triton kernel demonstrates passing a kernel as a parameter\n@triton.jit\ndef pass_kernel(kernel):\n pass\n\n@requires_cuda\ndef test_triton_kernel_with_kernel_param():\n @torch.compile(backend=\"eager\")\n def f(x):\n grid = (x.numel(),)\n pass_kernel[grid](kernel=x)\n\n t1 = torch.rand(5, device=\"cuda\")\n f(t1)\n\n# Triton kernel demonstrating the use of an inner Triton function\n@requires_cuda\n@common_utils.parametrize(\"backend\", [\"eager\", \"aot_eager\", \"inductor\"])\ndef test_triton_kernel_inner_triton_function(backend):\n def f(x: torch.Tensor):\n @triton.jit\n def pow2_kernel(\n in_ptr0,\n out_ptr,\n n_elements,\n BLOCK_SIZE: \"tl.constexpr\",\n ):\n pid = tl.program_id(axis=0)\n block_start = pid * BLOCK_SIZE\n offsets = block_start + tl.arange(0, BLOCK_SIZE)\n mask = offsets < n_elements\n x = tl.load(in_ptr0 + offsets, mask=mask)\n output = x * x\n tl.store(out_ptr + offsets, output, mask=mask)\n\n output = torch.zeros_like(x)\n n_elements = output.numel()\n grid = lambda meta: (triton.cdiv(n_elements, meta[\"BLOCK_SIZE\"]),)\n pow2_kernel[grid](x, output, n_elements, BLOCK_SIZE=16)\n return output\n\n t = torch.rand(5, device=\"cuda\")\n compiled_func = torch.compile(f, backend=backend, fullgraph=True)\n\n# Triton kernel with the kernel parameters demonstrating use of multiple operations and conditionals\n@requires_cuda\ndef test_triton_kernel_multi_kernel():\n @triton.jit\n def mul2_and_add_and_zero_negatives_kernel(\n in_ptr0,\n in_ptr1,\n out_ptr,\n n_elements,\n BLOCK_SIZE: \"tl.constexpr\",\n ACTIVATION: \"tl.constexpr\",\n ):\n pid = tl.program_id(axis=0)\n block_start = pid * BLOCK_SIZE\n offsets = block_start + tl.arange(0, BLOCK_SIZE)\n mask = offsets < n_elements\n x = tl.load(in_ptr0 + offsets, mask=mask)\n y = tl.load(in_ptr1 + offsets, mask=mask)\n output = x + y\n if ACTIVATION == \"zero_negs\":\n output = tl.where(output >= 0, output, 0.0)\n tl.store(out_ptr + offsets, output, mask=mask)\n\n @torch.compile\n def call_triton(\n x: torch.Tensor,\n y: torch.Tensor,\n xi: torch.Tensor,\n yi: torch.Tensor,\n output: torch.Tensor,\n outputi: torch.Tensor,\n ):\n n_elements = output.numel()\n grid = (x.numel(),)\n mul2_and_add_and_zero_negatives_kernel[grid](\n x, y, output, n_elements, BLOCK_SIZE=16, ACTIVATION=\"zero_negs\"\n )\n mul2_and_add_and_zero_negatives_kernel[grid](\n xi, yi, outputi, n_elements, BLOCK_SIZE=16, ACTIVATION=None\n )\n return (output, outputi)\n\n t1 = torch.tensor(\n [-2.0, -1.0, 0.0, 1.0, 2.0], device=\"cuda\", requires_grad=False\n )\n t2 = torch.tensor(\n [-2.0, -1.0, 0.0, 1.0, 2.0], device=\"cuda\", requires_grad=False\n )\n float_result = 2 * t1 + 2 * t2\n float_result = float_result.where(float_result >= 0, 0.0)\n\n t1i = torch.randint(-2, 2, (5,), device=\"cuda\")\n t2i = torch.randint(-2, 2, (5,), device=\"cuda\")\n o = torch.zeros_like(t1, requires_grad=False)\n oi = torch.zeros_like(t1i)\n int_result = 2 * t1i + 2 * t2i\n\n (result, resulti) = call_triton(t1, t2, t1i, t2i, o, oi)\n assert torch.equal(float_result, result)\n assert torch.equal(int_result, resulti)\n", - "description_1": "Use triton language to create kernels such as pass_kernel which takes a kernel as an argument and doesn't perform operations on it. Use another kernel, pow2_kernel, within a callable Python function to square elements of a tensor using block and grid strategy in Triton. Additionally, define a more complex kernel, mul2_and_add_and_zero_negatives_kernel, to add elements from two input pointers, multiply them by two, and zero out negative results based on an activation condition. Ensure each kernel uses the grid strategy for execution and outputs are correctly stored in an output tensor.", - "description_2": "Use triton language to create a kernel that processes tensor elements in blocks to achieve operations like element-wise addition and conditionally zeroing negatives. Implement a nested Triton function for squaring tensor elements efficiently in blocks.", - "difficulty": 2 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n\n@triton.jit\ndef _sampled_addmm_kernel(\n alpha,\n beta,\n IS_BETA_ZERO: tl.constexpr,\n BLOCKSIZE_ROW: tl.constexpr,\n BLOCKSIZE_COL: tl.constexpr,\n k,\n TILE_K: tl.constexpr,\n values_ptr,\n values_batch_stride,\n values_nnz_stride,\n values_row_block_stride,\n values_col_block_stride,\n crow_indices_ptr,\n crow_indices_batch_stride,\n crow_indices_stride,\n col_indices_ptr,\n col_indices_batch_stride,\n col_indices_stride,\n mat1_ptr,\n mat1_batch_stride,\n mat1_tiled_row_stride,\n mat1_tiled_col_stride,\n mat1_row_block_stride,\n mat1_col_block_stride,\n mat2_ptr,\n mat2_batch_stride,\n mat2_tiled_row_stride,\n mat2_tiled_col_stride,\n mat2_row_block_stride,\n mat2_col_block_stride,\n acc_dtype: tl.constexpr,\n allow_tf32: tl.constexpr,\n):\n batch_pid = tl.program_id(axis=1)\n row_block_pid = tl.program_id(axis=0)\n\n crow_indices_offset_ptr = (\n crow_indices_ptr\n + crow_indices_batch_stride * batch_pid\n + crow_indices_stride * row_block_pid\n )\n nnz_offset = tl.load(crow_indices_offset_ptr)\n nnz_offset_next = tl.load(crow_indices_offset_ptr + crow_indices_stride)\n\n # Compute nnz for the row with number row_block_pid.\n # If it is zero, skip the row.\n row_nnz = nnz_offset_next - nnz_offset\n if row_nnz == 0:\n return\n\n row_block_arange = tl.arange(0, BLOCKSIZE_ROW)\n col_block_arange = tl.arange(0, BLOCKSIZE_COL)\n\n # Pointers are set to the first block of the current row.\n values_block_ptrs = (\n values_ptr\n + values_batch_stride * batch_pid\n + values_nnz_stride * nnz_offset\n + values_row_block_stride * row_block_arange[:, None]\n + values_col_block_stride * col_block_arange[None, :]\n )\n\n col_index_nnz_ptr = (\n col_indices_ptr\n + col_indices_batch_stride * batch_pid\n + col_indices_stride * nnz_offset\n )\n\n # Advance mat1 to the current tiled row, ignore columns.\n mat1_block_ptrs = (\n mat1_ptr\n + mat1_batch_stride * batch_pid\n + mat1_tiled_row_stride * row_block_pid\n + mat1_row_block_stride * row_block_arange[:, None]\n )\n\n # Advance mat2 in batch and block col dimension.\n mat2_block_ptrs = (\n mat2_ptr\n + mat2_batch_stride * batch_pid\n + mat2_col_block_stride * col_block_arange[None, :]\n )\n\n k_tile_arange = tl.arange(0, TILE_K)\n for _ in range(row_nnz):\n acc_block = tl.zeros((BLOCKSIZE_ROW, BLOCKSIZE_COL), dtype=acc_dtype)\n\n # find column block index\n col_block = tl.load(col_index_nnz_ptr)\n\n for k_tile in range(0, k, TILE_K):\n k_offsets = k_tile + k_tile_arange\n mask_k = k_offsets < k\n\n mat1_block = tl.load(\n mat1_block_ptrs\n + mat1_col_block_stride * k_offsets[None, :],\n mask=mask_k[None, :], other=0.0\n )\n\n mat2_block = tl.load(\n mat2_block_ptrs\n + mat2_tiled_col_stride * col_block\n + mat2_row_block_stride * k_offsets[:, None],\n mask=mask_k[:, None], other=0.0\n )\n\n acc_block += tl.dot(mat1_block, mat2_block, allow_tf32=allow_tf32, out_dtype=acc_dtype)\n\n if IS_BETA_ZERO:\n acc_block *= alpha\n else:\n acc_block = alpha * acc_block + beta * tl.load(values_block_ptrs)\n\n # write result\n tl.store(values_block_ptrs, acc_block.to(values_ptr.dtype.element_ty))\n\n # advance val/col_index ptrs to the next block in the row.\n values_block_ptrs += values_nnz_stride\n col_index_nnz_ptr += col_indices_stride\n\n\ndef _run_sampled_addmm_kernel(\n alpha, beta, is_beta_zero,\n blocksize, k, tile_k,\n values, crow_indices, col_indices,\n mat1, mat2,\n max_grid\n):\n n_batches = values.size(0)\n n_block_rows = crow_indices.size(-1) - 1\n\n full_grid = (n_batches, n_block_rows)\n if max_grid is not None:\n grid_blocks = tuple(max_grid[:2][::-1]) + (None,) * (2 - len(max_grid[:2]))\n else:\n grid_blocks = None\n tensor_dims_map = {\n values: (0, None),\n crow_indices: (0, -1),\n col_indices: (0, None),\n mat1: (0, -4),\n mat2: (0, None),\n }\n if values.dtype in (torch.half, torch.bfloat16):\n acc_dtype = tl.float32\n allow_tf32 = True\n else:\n acc_dtype = tl.float64\n allow_tf32 = False\n\n def kernel(grid, *sliced_tensors):\n _sampled_addmm_kernel[grid](\n alpha, beta, is_beta_zero,\n *blocksize, k, tile_k,\n *ptr_stride_extractor(*sliced_tensors),\n acc_dtype=acc_dtype,\n allow_tf32=allow_tf32,\n num_stages=1,\n num_warps=4\n )\n\n launch_kernel(kernel, tensor_dims_map, full_grid, grid_blocks)\n\n\ndef sampled_addmm(\n input: torch.Tensor,\n mat1: torch.Tensor,\n mat2: torch.Tensor,\n *,\n beta=1.0,\n alpha=1.0,\n out: Optional[torch.Tensor] = None,\n skip_checks: bool = False,\n max_grid: Optional[Tuple[Optional[int], Optional[int], Optional[int]]] = None,\n):\n f_name = \"sampled_addmm\"\n\n check_bsr_layout(f_name, input)\n input_broadcasted = broadcast_batch_dims_bsr(f_name, input, mat1, mat2)\n\n if not skip_checks:\n check_device(f_name, mat1, input.device)\n check_device(f_name, mat2, input.device)\n if beta != 0.0 and input.dtype is torch.bool:\n check(\n False,\n f\"{f_name}(): having beta == {beta} not equal to 0.0 with boolean mask is not allowed.\"\n )\n if input.dtype is not torch.bool:\n check_dtype(f_name, mat1, input.dtype)\n check_dtype(f_name, mat2, input.dtype)\n else:\n check_dtype(f_name, mat1, mat2.dtype)\n check_mm_compatible_shapes(f_name, mat1, mat2)\n if out is not None:\n check_bsr_layout(f_name, out)\n check_device(f_name, out, mat1.device)\n check_dtype(f_name, out, input.dtype)\n check(\n out.shape == input_broadcasted.shape\n and out._nnz() == input._nnz(),\n f\"{f_name}(): Expects `out` to be of shape {input_broadcasted.shape} \"\n f\"and with nnz equal to {input_broadcasted._nnz()} \"\n f\"but got out.shape = {out.shape} and out.nnz = {out._nnz()}\"\n )\n\n if out is None:\n out = input_broadcasted.to(mat1.dtype, copy=True)\n else:\n out.copy_(input_broadcasted)\n\n if out.numel() == 0 or out._nnz() == 0:\n return out\n\n blocksize = out.values().shape[-2:]\n m = mat1.size(-2)\n n = mat2.size(-1)\n k = mat1.size(-1)\n\n # NOTE: (m, 0) @ (0, n) == zeros(m, n)\n if alpha == 0.0 or k == 0:\n out.values().mul_(beta)\n return out\n\n # prepare inputs by reshaping them to be kernel-compatible\n out_backup = out\n crow_indices, col_indices, values, mat1, mat2 = prepare_inputs(out, mat1, mat2)\n\n mat1 = tile_to_blocksize(mat1, (blocksize[0], k))\n mat2 = tile_to_blocksize(mat2, (k, blocksize[1]))\n tile_k = max(*blocksize)\n\n _run_sampled_addmm_kernel(\n alpha, beta, beta == 0.0,\n blocksize, k, tile_k,\n values, crow_indices, col_indices,\n mat1, mat2,\n max_grid\n )\n\n # If nnz x block strides are not the same in out_backup.values and values,\n # it means that out_backup.values and values are not the views of each other,\n # so we have to copy.\n if out_backup.values().stride()[-3:] != values.stride()[-3:]:\n out_backup.values().copy_(values.reshape(out_backup.values().shape))\n return out_backup\n", - "description_1": "Use triton language to implement a sampled addmm kernel function for sparse matrix operations with block sizes, allowing beta and alpha scaling factors, and perform matrix multiplication and addition in a tile-based manner.", - "description_2": "Use triton language to create a block-based matrix multiplication kernel that scales and adds matrices efficiently.", - "difficulty": 4 - }, - { - "code": "import triton\nfrom triton import language as tl\nfrom triton.language import load, store\n\n# Kernel to add two arrays element-wise\n@triton.jit\ndef add_kernel(\n in_ptr0,\n in_ptr1,\n out_ptr,\n n_elements,\n BLOCK_SIZE: \"tl.constexpr\",\n):\n pid = tl.program_id(axis=0)\n block_start = pid * BLOCK_SIZE\n offsets = block_start + tl.arange(0, BLOCK_SIZE)\n mask = offsets < n_elements\n x = tl.load(in_ptr0 + offsets, mask=mask)\n y = tl.load(in_ptr1 + offsets, mask=mask)\n output = x + y\n tl.store(out_ptr + offsets, output, mask=mask)\n\n# Kernel to add two arrays element-wise with an optional parameter\n@triton.jit\ndef add_kernel_with_optional_param(\n in_ptr0,\n in_ptr1,\n out_ptr,\n n_elements,\n ARGS_PASSED: \"tl.constexpr\",\n BLOCK_SIZE: \"tl.constexpr\",\n):\n pid = tl.program_id(axis=0)\n block_start = pid * BLOCK_SIZE\n offsets = block_start + tl.arange(0, BLOCK_SIZE)\n mask = offsets < n_elements\n x = tl.load(in_ptr0 + offsets, mask=mask)\n if ARGS_PASSED == \"two\":\n y = tl.load(in_ptr1 + offsets, mask=mask)\n output = x + y\n else:\n output = x\n tl.store(out_ptr + offsets, output, mask=mask)\n\n# Autotuned kernel to add two arrays element-wise\n@triton.autotune(\n configs=[\n triton.Config({\"BLOCK_SIZE\": 128}, num_stages=3, num_warps=8),\n triton.Config({\"BLOCK_SIZE\": 128}, num_stages=4, num_warps=4),\n triton.Config({\"BLOCK_SIZE\": 64}, num_stages=3, num_warps=8),\n triton.Config({\"BLOCK_SIZE\": 64}, num_stages=4, num_warps=4),\n ],\n key=[],\n)\n@triton.jit\ndef add_kernel_autotuned(\n in_ptr0,\n in_ptr1,\n out_ptr,\n n_elements,\n BLOCK_SIZE: \"tl.constexpr\",\n):\n pid = tl.program_id(axis=0)\n block_start = pid * BLOCK_SIZE\n offsets = block_start + tl.arange(0, BLOCK_SIZE)\n mask = offsets < n_elements\n x = tl.load(in_ptr0 + offsets, mask=mask)\n y = tl.load(in_ptr1 + offsets, mask=mask)\n output = x + y\n tl.store(out_ptr + offsets, output, mask=mask)\n\n# 2D Autotuned kernel to add two arrays element-wise\n@triton.autotune(\n configs=[\n triton.Config(\n {\"BLOCK_SIZE_X\": 128, \"BLOCK_SIZE_Y\": 128}, num_stages=3, num_warps=8\n ),\n triton.Config(\n {\"BLOCK_SIZE_X\": 128, \"BLOCK_SIZE_Y\": 128}, num_stages=4, num_warps=4\n ),\n triton.Config(\n {\"BLOCK_SIZE_X\": 64, \"BLOCK_SIZE_Y\": 64}, num_stages=3, num_warps=8\n ),\n triton.Config(\n {\"BLOCK_SIZE_X\": 64, \"BLOCK_SIZE_Y\": 64}, num_stages=4, num_warps=4\n ),\n ],\n key=[],\n)\n@triton.jit\ndef add_kernel_2d_autotuned(\n in_ptr0,\n in_ptr1,\n out_ptr,\n x_elements,\n y_elements,\n BLOCK_SIZE_X: \"tl.constexpr\",\n BLOCK_SIZE_Y: \"tl.constexpr\",\n):\n xoffset = tl.program_id(0) * BLOCK_SIZE_X\n xindex = xoffset + tl.arange(0, BLOCK_SIZE_X)[:, None]\n xmask = xindex < x_elements\n yoffset = tl.program_id(1) * BLOCK_SIZE_Y\n yindex = yoffset + tl.arange(0, BLOCK_SIZE_Y)[None, :]\n ymask = yindex < y_elements\n x1 = xindex\n y0 = yindex\n tmp0 = tl.load(in_ptr0 + (x1 + (x_elements * y0)), xmask & ymask)\n tmp1 = tl.load(in_ptr0 + (y0 + (y_elements * x1)), xmask & ymask)\n tmp2 = tmp0 + tmp1\n tl.store(out_ptr + (x1 + (x_elements * y0)), tmp2, xmask & ymask)\n\n# Kernel to multiply an array by 2\n@triton.jit\ndef mul2_kernel(\n in_ptr0,\n out_ptr,\n n_elements,\n BLOCK_SIZE: \"tl.constexpr\",\n):\n pid = tl.program_id(axis=0)\n block_start = pid * BLOCK_SIZE\n offsets = block_start + tl.arange(0, BLOCK_SIZE)\n mask = offsets < n_elements\n x = tl.load(in_ptr0 + offsets, mask=mask)\n output = 2 * x\n tl.store(out_ptr + offsets, output, mask=mask)\n\n# In-place kernel to multiply an array by 2\n@triton.jit\ndef mul2_inplace_kernel(\n ptr,\n n_elements,\n BLOCK_SIZE: \"tl.constexpr\",\n):\n pid = tl.program_id(axis=0)\n block_start = pid * BLOCK_SIZE\n offsets = block_start + tl.arange(0, BLOCK_SIZE)\n mask = offsets < n_elements\n x = tl.load(ptr + offsets, mask=mask)\n output = 2 * x\n tl.store(ptr + offsets, output, mask=mask)\n\n# Kernel with indirection and activation\n@triton.jit\ndef indirection_kernel(\n in_ptr0,\n out_ptr,\n n_elements,\n BLOCK_SIZE: \"tl.constexpr\",\n ACTIVATION: \"tl.constexpr\",\n):\n pid = tl.program_id(axis=0)\n block_start = pid * BLOCK_SIZE\n offsets = block_start + tl.arange(0, BLOCK_SIZE)\n mask = offsets < n_elements\n if ACTIVATION == \"mul2_inplace_kernel\":\n mul2_inplace_kernel(in_ptr0, n_elements, BLOCK_SIZE=BLOCK_SIZE)\n elif ACTIVATION == \"add_kernel\":\n add_kernel(in_ptr0, in_ptr0, out_ptr, n_elements, BLOCK_SIZE=BLOCK_SIZE)\n x = tl.load(in_ptr0 + offsets, mask=mask)\n tl.store(out_ptr + offsets, x, mask=mask)\n\n# Kernel to add two arrays element-wise with import\n@triton.jit\ndef add_kernel_with_import(\n in_ptr0,\n in_ptr1,\n out_ptr,\n n_elements,\n BLOCK_SIZE: \"tl.constexpr\",\n):\n pid = tl.program_id(axis=0)\n block_start = pid * BLOCK_SIZE\n offsets = block_start + tl.arange(0, BLOCK_SIZE)\n mask = offsets < n_elements\n x = load(in_ptr0 + offsets, mask=mask)\n y = load(in_ptr1 + offsets, mask=mask)\n output = x + y\n store(out_ptr + offsets, output, mask=mask)\n", - "description_1": "Use triton language to define multiple kernels for element-wise operations on arrays, including addition, multiplication, and conditional operations. Each kernel is parameterized by pointers to input and output arrays, the number of elements to process, and block size. Some kernels are autotuned for performance.", - "description_2": "Use triton language to create kernels for element-wise array operations with parameters for input/output pointers, element count, and block size, including autotuning.", - "difficulty": 3 - }, - { - "code": "import triton\nimport triton.language as tl\nimport torch\n\n@triton.jit\ndef my_kernel(X, Y, BLOCK_SIZE: tl.constexpr):\n pid = tl.program_id(0)\n block_start = pid * BLOCK_SIZE\n offsets = block_start + tl.arange(0, BLOCK_SIZE)\n mask = offsets < X.shape[0]\n x = tl.load(X + offsets, mask=mask)\n y = x * 2\n tl.store(Y + offsets, y, mask=mask)\n\ndef call_my_kernel(X, Y):\n BLOCK_SIZE = 1024\n grid = lambda meta: (triton.cdiv(X.shape[0], meta['BLOCK_SIZE']),)\n my_kernel[grid](X, Y, BLOCK_SIZE=BLOCK_SIZE)\n\n# Example usage\nX = torch.arange(0, 10240, dtype=torch.float32, device='cuda')\nY = torch.empty_like(X)\ncall_my_kernel(X, Y)\n", - "description_1": "Use triton language to define a kernel 'my_kernel' that takes two arguments: X and Y. The kernel multiplies each element of X by 2 and stores the result in Y. The kernel uses a block size of 1024 and handles out-of-bounds accesses with a mask. The kernel is launched with a grid size calculated based on the size of X.", - "description_2": "Use triton language to define a kernel that multiplies each element of an input tensor by 2 and stores the result in an output tensor, handling out-of-bounds accesses with a mask.", - "difficulty": 2 - }, - { - "code": "import triton\n\n# Triton kernel function\n@triton.jit\ndef example_kernel(X, Y, BLOCK_SIZE: tl.constexpr):\n # Kernel code here\n pass\n\n# Function to call the Triton kernel\ndef call_example_kernel(X, Y):\n BLOCK_SIZE = 1024\n grid = (X.size // BLOCK_SIZE,)\n example_kernel[grid](X, Y, BLOCK_SIZE)\n", - "description_1": "Use triton language to define a kernel function 'example_kernel' that takes two arguments X and Y, and a BLOCK_SIZE as a constexpr. The kernel is executed over a grid determined by the size of X divided by BLOCK_SIZE. The function 'call_example_kernel' is used to invoke this kernel with a specified grid size.", - "description_2": "Use triton language to create a kernel with two input tensors and a block size, and execute it over a grid based on the input size.", - "difficulty": 3 - }, - { - "code": "import triton\nimport triton.language as tl\nimport torch\n\n@triton.jit\ndef example_kernel(X, Y, BLOCK_SIZE: tl.constexpr):\n pid = tl.program_id(0)\n block_start = pid * BLOCK_SIZE\n offsets = block_start + tl.arange(0, BLOCK_SIZE)\n x = tl.load(X + offsets)\n y = x * 2\n tl.store(Y + offsets, y)\n\ndef call_kernel(X, Y, BLOCK_SIZE):\n grid = lambda meta: (X.size // meta['BLOCK_SIZE'],)\n example_kernel[(grid,)](X, Y, BLOCK_SIZE)\n\n# Example usage\nX = torch.arange(0, 1024, dtype=torch.float32, device='cuda')\nY = torch.empty_like(X)\ncall_kernel(X, Y, BLOCK_SIZE=128)\n", - "description_1": "Use triton language to define a kernel 'example_kernel' that multiplies each element of input tensor X by 2 and stores the result in tensor Y. The kernel uses a block size specified by the BLOCK_SIZE parameter. The function 'call_kernel' sets up the grid and launches the kernel with the given tensors and block size.", - "description_2": "Use triton language to create a kernel that doubles the elements of a tensor and stores the result in another tensor, with a specified block size.", - "difficulty": 2 - }, - { - "code": "import triton\nimport torch\n\n# Example Triton kernel\n@triton.jit\ndef example_kernel(X, Y, Z, BLOCK_SIZE: tl.constexpr):\n # Triton kernel code\n pass\n\n# Function to call the Triton kernel\ndef call_example_kernel(x, y, z, block_size):\n # Call the Triton kernel\n example_kernel[(1,)](x, y, z, BLOCK_SIZE=block_size)\n\n# Example usage\nx = torch.tensor([1.0, 2.0, 3.0])\ny = torch.tensor([4.0, 5.0, 6.0])\nz = torch.empty_like(x)\ncall_example_kernel(x, y, z, block_size=1024)\n", - "description_1": "Use triton language to define a kernel 'example_kernel' with 4 parameters: X, Y, Z, and BLOCK_SIZE. The kernel performs operations on input tensors X, Y, and Z with a specified block size. A function 'call_example_kernel' is used to invoke this kernel with PyTorch tensors and a block size.", - "description_2": "Use triton language to create a kernel that processes three input tensors with a specified block size, and provide a function to call this kernel with PyTorch tensors.", - "difficulty": 1 - }, - { - "code": "import triton\nimport triton.language as tl\n\n@triton.jit\ndef promote_to_tensor(x):\n # Addition promotes to tensor for us\n return x + tl.zeros((1,), tl.int1)\n\n@triton.jit\ndef is_floating(x):\n return promote_to_tensor(x).dtype.is_floating()\n\n@triton.jit\ndef _prod_accumulate(a, b):\n return a * b\n\n@triton.jit\ndef prod(input, axis):\n return tl.reduce(input, axis, _prod_accumulate)\n\n@triton.jit\ndef minimum(a, b):\n mask = a < b\n if is_floating(a):\n mask |= a != a\n return tl.where(mask, a, b)\n\n@triton.jit\ndef maximum(a, b):\n mask = a > b\n if is_floating(a):\n mask |= a != a\n return tl.where(mask, a, b)\n\n@triton.jit\ndef min2(a, dim):\n return tl.reduce(a, dim, minimum)\n\n@triton.jit\ndef max2(a, dim):\n return tl.reduce(a, dim, maximum)\n\n@triton.jit\ndef minimum_with_index(a_value, a_index, b_value, b_index):\n mask = a_value < b_value\n equal = a_value == b_value\n if is_floating(a_value):\n a_isnan = a_value != a_value\n b_isnan = b_value != b_value\n mask |= a_isnan and not b_isnan\n # Consider NaNs as equal\n equal |= a_isnan and b_isnan\n\n # Prefer lowest index if values are equal\n mask |= equal & (a_index < b_index)\n return tl.where(mask, a_value, b_value), tl.where(mask, a_index, b_index)\n\n@triton.jit\ndef maximum_with_index(a_value, a_index, b_value, b_index):\n mask = a_value > b_value\n equal = a_value == b_value\n if is_floating(a_value):\n a_isnan = a_value != a_value\n b_isnan = b_value != b_value\n mask |= a_isnan and not b_isnan\n # Consider NaNs as equal\n equal |= a_isnan and b_isnan\n\n # Prefer lowest index if values are equal\n mask |= equal & (a_index < b_index)\n return tl.where(mask, a_value, b_value), tl.where(mask, a_index, b_index)\n\n@triton.jit\ndef min_with_index(value, index, dim):\n return tl.reduce((value, index), dim, minimum_with_index)\n\n@triton.jit\ndef max_with_index(value, index, dim):\n return tl.reduce((value, index), dim, maximum_with_index)\n\n@triton.jit\ndef welford_reduce(value, mean, m2, weight, first_iteration):\n if first_iteration:\n new_weight = tl.full(weight.shape, 1, weight.dtype)\n new_mean = value\n new_m2 = tl.zeros_like(m2)\n else:\n delta = value - mean\n new_weight = weight + 1\n new_mean = mean + delta / new_weight\n new_m2 = m2 + delta * (value - new_mean)\n return new_mean, new_m2, new_weight\n\n@triton.jit\ndef welford_combine(mean_1, m2_1, weight_1, mean_2, m2_2, weight_2):\n delta = mean_2 - mean_1\n new_weight = weight_1 + weight_2\n w2_over_w = tl.where(new_weight == 0.0, 0.0, weight_2 / new_weight)\n return (\n mean_1 + delta * w2_over_w,\n m2_1 + m2_2 + delta * delta * weight_1 * w2_over_w,\n new_weight,\n )\n\n@triton.jit\ndef welford(mean, m2, weight, dim):\n return tl.reduce((mean, m2, weight), dim, welford_combine)\n\n@triton.jit\ndef device_assert_then(cond, msg, r):\n tl.device_assert(cond, msg)\n return r\n\n@triton.jit\ndef randint64(seed, offset, low, high):\n r0, r1, r2, r3 = tl.randint4x(seed, offset)\n r0 = r0.to(tl.uint64)\n r1 = r1.to(tl.uint64)\n result = r0 | (r1 << 32)\n size = high - low\n result = result % size.to(tl.uint64)\n result = result.to(tl.int64) + low\n return result\n\n@triton.jit\ndef _any_combine(a, b):\n return a | b\n\n@triton.jit\ndef any(a, dim):\n return tl.reduce(a, dim, _any_combine)\n\n@triton.jit\ndef bucketize_binary_search(\n values, # 1D tensor\n offsets_ptr,\n indexing_dtype,\n right, # bool: if true, use intervals closed on the left; see [Note: Inductor bucketize op]\n OFFSETS_SIZE: int,\n BLOCK_SHAPE, # tuple/list of block shape\n):\n \"\"\"\n See [Note: Inductor bucketize op]\n \"\"\"\n\n low = tl.zeros(BLOCK_SHAPE, dtype=indexing_dtype)\n high = tl.full(BLOCK_SHAPE, OFFSETS_SIZE, dtype=indexing_dtype)\n\n full_range = OFFSETS_SIZE + 1\n while full_range > 1:\n mid = (high + low) // 2\n mask = mid < OFFSETS_SIZE\n bucket_upper_bound = tl.load(offsets_ptr + mid, mask=mask, other=0.0)\n if right:\n is_above = values >= bucket_upper_bound\n else:\n is_above = values > bucket_upper_bound\n\n low = tl.where(is_above & mask, mid + 1, low)\n high = tl.where(is_above, high, mid)\n\n full_range = (full_range + 1) // 2\n\n return low\n\n@triton.jit\ndef pack_value_flag(\n value,\n flag,\n DTYPE_VALUE_AS_UINT: tl.constexpr,\n DTYPE_PACK: tl.constexpr,\n):\n # Workaround for triton bug, tensor.to doesn't unwrap constexpr values\n DTYPE_VALUE_AS_UINT = tl.core._constexpr_to_value(DTYPE_VALUE_AS_UINT)\n bitwidth = DTYPE_VALUE_AS_UINT.primitive_bitwidth\n uv = value.to(DTYPE_VALUE_AS_UINT, bitcast=True).to(DTYPE_PACK)\n return flag.to(DTYPE_PACK) | (uv << bitwidth)\n\n@triton.jit\ndef unpack_value(\n pack,\n DTYPE_VALUE,\n DTYPE_VALUE_AS_UINT,\n):\n # Workaround for triton bug, tensor.to doesn't unwrap constexpr values\n DTYPE_VALUE = tl.core._constexpr_to_value(DTYPE_VALUE)\n DTYPE_VALUE_AS_UINT = tl.core._constexpr_to_value(DTYPE_VALUE_AS_UINT)\n bitwidth = DTYPE_VALUE_AS_UINT.primitive_bitwidth\n value_uint = (pack >> bitwidth).to(DTYPE_VALUE_AS_UINT)\n return value_uint.to(DTYPE_VALUE, bitcast=True)\n\n@triton.jit\ndef unpack_flag(pack, DTYPE_FLAG):\n return pack.to(DTYPE_FLAG)\n\n@triton.jit\ndef exclusive_scan_decoupled_lookback(\n scratch_base,\n block_value,\n index,\n combine_fn,\n DTYPE_VALUE_AS_UINT: tl.constexpr,\n DTYPE_PACK: tl.constexpr,\n):\n \"\"\"Compute exclusive scan of a scalar value between blocks\n\n Ref: https://research.nvidia.com/publication/2016-03_single-pass-parallel-prefix-scan-decoupled-look-back\n\n scratch_base: Pointer to scratch space in global memory\n block_value: Scalar value for this block\n index: Scalar index of this block relative to the current scan\n combine_fn: Function ``(value, value) -> value`` which is scanned over\n DTYPE_VALUE_AS_UINT: A tl.uint{n} type equal in size to ``block_value``\n DTYPE_PACK: Unsigned type twice the width of block_value\n\n NOTE: This function is limited to values which are 32-bits or less because\n we need to pack (value, flag) into a single unsigned int.\n \"\"\"\n # Publish block sum so subsequent blocks don't get stuck waiting for us\n DTYPE_VALUE = block_value.dtype\n pack = pack_value_flag(\n block_value,\n tl.full(block_value.shape, 1, DTYPE_VALUE_AS_UINT),\n DTYPE_VALUE_AS_UINT,\n DTYPE_PACK,\n )\n if index > 0:\n tl.atomic_xchg(scratch_base + index, pack, sem=\"relaxed\")\n\n # Calculate exclusive prefix scan\n exclusive_prefix = tl.zeros([], DTYPE_VALUE)\n prefix_valid = False\n test_target = index - 1\n while test_target >= 0:\n # tl.atomic_load\n flag = tl.full([], 0, DTYPE_VALUE_AS_UINT)\n while flag == 0:\n pack = tl.atomic_add(scratch_base + test_target, 0, sem=\"relaxed\")\n flag = unpack_flag(pack, DTYPE_VALUE_AS_UINT)\n\n value = unpack_value(pack, DTYPE_VALUE, DTYPE_VALUE_AS_UINT)\n if prefix_valid:\n exclusive_prefix = combine_fn(value, exclusive_prefix)\n else:\n exclusive_prefix = value\n prefix_valid = True\n\n if flag == 2:\n test_target = -1\n else:\n test_target = test_target - 1\n\n # Make inclusive block sum visible to other blocks\n if prefix_valid:\n inclusive_prefix = combine_fn(exclusive_prefix, block_value)\n else:\n inclusive_prefix = block_value\n pack = pack_value_flag(\n inclusive_prefix,\n tl.full([], 2, DTYPE_VALUE_AS_UINT),\n DTYPE_VALUE_AS_UINT,\n DTYPE_PACK,\n )\n tl.atomic_xchg(scratch_base + index, pack, sem=\"relaxed\")\n return exclusive_prefix\n\n@triton.jit\ndef exclusive_scan_decoupled_lookback_64(scratch_base, block_value, index, combine_fn):\n \"\"\"Compute exclusive scan of a scalar value between blocks\n\n Ref: https://research.nvidia.com/publication/2016-03_single-pass-parallel-prefix-scan-decoupled-look-back\n\n scratch_base: Pointer to scratch space in global memory\n block_value: Scalar value for this block, must be 64-bits wide\n index: Scalar index of this block relative to the current scan\n combine_fn: Function ``(value, value) -> value`` which is scanned over\n init: Scalar value equal to the identiy of combine_fn\n \"\"\"\n # Publish block sum so subsequent blocks don't get stuck waiting for us\n if index > 0:\n block_value_u64 = block_value.to(tl.uint64, bitcast=True)\n tl.store(scratch_base + 3 * index + 1, block_value_u64)\n tl.debug_barrier()\n flag_one = tl.full([], 1, tl.uint64)\n tl.atomic_xchg(scratch_base + 3 * index + 0, flag_one, sem=\"release\")\n\n # Calculate exclusive prefix scan\n exclusive_prefix = tl.zeros([], block_value.dtype)\n prefix_valid = False\n test_target = index - 1\n while test_target >= 0:\n flag = tl.full([], 0, tl.uint64)\n while flag == 0:\n flag = tl.atomic_add(scratch_base + 3 * test_target + 0, 0, sem=\"acquire\")\n\n value_u64 = tl.load(scratch_base + 3 * test_target + flag.to(tl.int32))\n value = value_u64.to(block_value.dtype, bitcast=True)\n if prefix_valid:\n exclusive_prefix = combine_fn(value, exclusive_prefix)\n else:\n exclusive_prefix = value\n prefix_valid = True\n\n if flag == 2:\n test_target = -1\n else:\n test_target = test_target - 1\n\n # Make inclusive block sum visible to other blocks\n if prefix_valid:\n inclusive_prefix = combine_fn(exclusive_prefix, block_value)\n else:\n inclusive_prefix = block_value\n inclusive_prefix_u64 = inclusive_prefix.to(tl.uint64, bitcast=True)\n tl.store(scratch_base + 3 * index + 2, inclusive_prefix_u64)\n tl.debug_barrier()\n flag_two = tl.full([], 2, tl.uint64)\n tl.atomic_xchg(scratch_base + 3 * index + 0, flag_two, sem=\"release\")\n\n return exclusive_prefix\n\n@triton.jit\ndef frexp(x):\n # TODO(isuruf): use inline_asm_elementwise here\n y = libdevice.ilogb(x) + 1\n exponent = tl.where(x == 0, 0, y)\n mantissa = tl.where(x == 0, 0, libdevice.ldexp(x, -y))\n return mantissa, exponent\n\n@triton.jit\ndef _compare_and_swap_with_index(\n x,\n idxs,\n valid_mask,\n flip,\n i: tl.constexpr,\n n_dims: tl.constexpr,\n stable: tl.constexpr,\n descending: tl.constexpr,\n):\n n_outer: tl.constexpr = x.numel >> n_dims\n shape: tl.constexpr = [n_outer * 2**i, 2, 2 ** (n_dims - i - 1)]\n\n idtype = tl.core.get_int_dtype(bitwidth=x.dtype.primitive_bitwidth, signed=True)\n\n y = tl.reshape(x, shape)\n iy = y.to(idtype, bitcast=True)\n # slice left/right with 'stride' 2**(n_dims - i - 1)\n right_mask = tl.arange(0, 2)[None, :, None].to(idtype)\n left_mask = (1 - right_mask).to(idtype)\n ileft = tl.broadcast_to(tl.sum(iy * left_mask, 1)[:, None, :], shape)\n iright = tl.broadcast_to(tl.sum(iy * right_mask, 1)[:, None, :], shape)\n ileft = tl.reshape(ileft, x.shape)\n iright = tl.reshape(iright, x.shape)\n left = ileft.to(x.dtype, bitcast=True)\n right = iright.to(x.dtype, bitcast=True)\n\n # idx\n y_idx = tl.reshape(idxs, shape)\n left_idx = tl.broadcast_to(\n tl.sum(y_idx * left_mask.to(y_idx.dtype), 1)[:, None, :], shape\n )\n right_idx = tl.broadcast_to(\n tl.sum(y_idx * right_mask.to(y_idx.dtype), 1)[:, None, :], shape\n )\n left_idx = tl.reshape(left_idx, x.shape)\n right_idx = tl.reshape(right_idx, x.shape)\n\n # valid\n if valid_mask is None:\n left_valid_mask = tl.full(x.shape, True, tl.int1)\n right_valid_mask = tl.full(x.shape, True, tl.int1)\n else:\n y_valid_mask = tl.reshape(valid_mask, shape)\n left_valid_mask = tl.broadcast_to(\n tl.sum(y_valid_mask * left_mask.to(tl.int8), 1)[:, None, :], shape\n ).to(tl.int1)\n right_valid_mask = tl.broadcast_to(\n tl.sum(y_valid_mask * right_mask.to(tl.int8), 1)[:, None, :], shape\n ).to(tl.int1)\n left_valid_mask = tl.reshape(left_valid_mask, x.shape)\n right_valid_mask = tl.reshape(right_valid_mask, x.shape)\n\n # actual compare-and-swap\n ix = x.to(idtype, bitcast=True)\n\n if descending:\n cond = left < right\n else:\n cond = left > right\n\n if stable:\n # When stable sorting, tie break by index\n cond = cond | ((left == right) & (left_idx > right_idx))\n\n cond = (right_valid_mask > left_valid_mask) | (\n (right_valid_mask == left_valid_mask) & cond\n )\n cond = cond ^ flip\n ret = ix ^ tl.where(cond, ileft ^ iright, tl.zeros_like(ix))\n new_idxs = idxs ^ tl.where(cond, left_idx ^ right_idx, tl.zeros_like(idxs))\n if valid_mask is None:\n new_valid_mask = tl.full(x.shape, True, tl.int1)\n else:\n new_valid_mask = valid_mask ^ tl.where(\n cond, left_valid_mask ^ right_valid_mask, tl.zeros_like(valid_mask)\n )\n\n return ret.to(x.dtype, bitcast=True), new_idxs, new_valid_mask\n\n@triton.jit\ndef _bitonic_merge_with_index(\n x,\n idxs,\n mask,\n stage: tl.constexpr,\n alternating: tl.constexpr,\n n_dims: tl.constexpr,\n stable: tl.constexpr,\n descending: tl.constexpr,\n):\n n_outer: tl.constexpr = x.numel >> n_dims\n tl.static_assert(stage <= n_dims)\n # flip denotes whether to re-arrange sub-sequences of elements in ascending or\n # descending order.\n # if flip = 00000000... then all elements will be re-arranged ascendingly at this stage\n # if flip = 00110011... then all the elements will be re-arranged alternatingly (with\n # a stride of 2) at this stage\n if alternating:\n shape: tl.constexpr = [n_outer * 2 ** (n_dims - 1 - stage), 2, 2**stage]\n flip = tl.reshape(\n tl.broadcast_to(tl.arange(0, 2)[None, :, None], shape), x.shape\n )\n else:\n flip = False\n # perform `stage` rounds of `compare-and-swap`\n next_mask = mask\n for i in tl.static_range(stage):\n x, idxs, next_mask = _compare_and_swap_with_index(\n x, idxs, mask, flip, i + (n_dims - stage), n_dims, stable, descending\n )\n if mask is not None:\n mask = next_mask\n return x, idxs, next_mask\n\n@triton.jit\ndef sort_with_index(\n x, # value\n idxs, # index\n mask, # mask if current value is valid (invalid values sort to the end)\n dim: tl.constexpr = None,\n stable: tl.constexpr = tl.constexpr(False),\n descending: tl.constexpr = tl.constexpr(False),\n):\n x, idxs = tl.broadcast(x, idxs)\n if mask is not None:\n x, mask = tl.broadcast(x, mask)\n # handle default dimension or check that it is the most minor dim\n _dim: tl.constexpr = len(x.shape) - 1 if dim is None else dim\n tl.static_assert(\n _dim == len(x.shape) - 1, \"only minor dimension is currently supported\"\n )\n # iteratively run bitonic merge-sort steps\n n_dims: tl.constexpr = _log2(x.shape[_dim])\n\n for i in tl.static_range(1, n_dims + 1):\n x, idxs, next_mask = _bitonic_merge_with_index(\n x,\n idxs,\n mask,\n i,\n alternating=i < n_dims,\n n_dims=n_dims,\n stable=stable,\n descending=descending,\n )\n if mask is not None:\n mask = next_mask\n return x, idxs\n", - "description_1": "Use triton language to implement various kernels for tensor operations such as product reduction, min/max operations with and without indices, Welford reduction, random number generation, bucketization via binary search, bitonic merge sorting with index, etc., leveraging Triton's parallel programming capabilities.", - "description_2": "Use triton language to define kernels for parallel tensor operations and sorting with advanced features like stable sorting and bitwise operations.", - "difficulty": 4 - }, - { - "code": "import triton\nimport torch\n\n# Triton kernel for element-wise addition\n@triton.jit\ndef add_kernel(X, Y, Z, N):\n pid = triton.program_id(0)\n block_size = 1024\n offset = pid * block_size + triton.arange(0, block_size)\n mask = offset < N\n x = triton.load(X + offset, mask=mask)\n y = triton.load(Y + offset, mask=mask)\n z = x + y\n triton.store(Z + offset, z, mask=mask)\n\n# Function to call the Triton kernel\ndef add(X, Y):\n assert X.shape == Y.shape\n Z = torch.empty_like(X)\n N = X.numel()\n grid = lambda meta: (triton.cdiv(N, meta['block_size']),)\n add_kernel[grid](X, Y, Z, N)\n return Z\n\n# Example usage\nX = torch.randn(1024, device='cuda')\nY = torch.randn(1024, device='cuda')\nZ = add(X, Y)\n", - "description_1": "Use triton language to implement an element-wise addition kernel. The kernel takes four arguments: X, Y, Z, and N. X and Y are input tensors, Z is the output tensor, and N is the number of elements. The kernel computes the sum of X and Y and stores the result in Z. The function 'add' calls this kernel, ensuring that the input tensors X and Y have the same shape, and returns the result tensor Z.", - "description_2": "Use triton language to implement an element-wise addition kernel with inputs X, Y, and output Z, ensuring the same shape for X and Y, and compute the sum.", - "difficulty": 2 - }, - { - "code": "import torch\nimport triton\nfrom triton import language as tl\nfrom torch._C import _cuda_getCurrentRawStream as get_cuda_stream\nfrom torch._inductor.runtime.hints import DeviceProperties, HeuristicType, instance_descriptor\nfrom torch._inductor.runtime.triton_heuristics import CachingAutotuner, grid\nfrom torch._dynamo.utils import same\nfrom torch._dynamo.testing import rand_strided\n\ndef autotune(configs, meta):\n def decorator(fn):\n return CachingAutotuner(\n fn,\n triton_meta=meta,\n configs=configs,\n save_cache_hook=False,\n mutated_arg_names=[\"in_out_ptr0\"],\n heuristic_type=HeuristicType.POINTWISE,\n )\n return decorator\n\n@autotune(\n configs=[\n triton.Config({\"XBLOCK\": 1}),\n triton.Config({\"XBLOCK\": 2}),\n ],\n meta={\n \"signature\": {0: \"*fp32\", 1: \"*fp32\", 2: \"i32\"},\n \"device\": DeviceProperties.create(torch.device(\"cuda\")),\n \"configs\": [instance_descriptor(divisible_by_16=(0, 1), equal_to_1=())],\n \"constants\": {},\n },\n)\n@triton.jit\ndef kernel(in_out_ptr0, in_ptr0, xnumel, XBLOCK: tl.constexpr):\n pid = tl.program_id(0)\n block_start = pid * XBLOCK\n offsets = block_start + tl.arange(0, XBLOCK)\n mask = offsets < xnumel\n x = tl.load(in_out_ptr0 + offsets, mask=mask, other=0.0)\n y = tl.load(in_ptr0 + offsets, mask=mask, other=0.0)\n output = x + y\n tl.store(in_out_ptr0 + offsets, output, mask=mask)\n\ndef run_kernel():\n xnumel = 384\n in0 = rand_strided((xnumel,), (1,), device=\"cuda\", dtype=torch.float32)\n inout1 = rand_strided((xnumel,), (1,), device=\"cuda\", dtype=torch.float32)\n inout2 = inout1.clone()\n\n stream0 = get_cuda_stream(0)\n kernel.run(inout1, in0, xnumel, grid=grid(xnumel), stream=stream0)\n kernel.run(inout2, in0, xnumel, grid=grid(xnumel), stream=stream0)\n\n assert same(inout1, inout2, tol=0.001, equal_nan=True), \"failed autotune with inplace kernel\"\n\nrun_kernel()\n", - "description_1": "Use triton language to define a kernel that performs element-wise addition on two input tensors. The kernel is autotuned with two configurations for optimal performance. The kernel takes three arguments: in_out_ptr0 (output tensor), in_ptr0 (input tensor), and xnumel (number of elements). The kernel uses a block size defined by XBLOCK to load, compute, and store the results.", - "description_2": "Use triton language to define and autotune a kernel for element-wise addition of two tensors on a CUDA device.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\n\n# Define a Triton kernel with @triton.jit decorator\n@triton.jit\ndef my_kernel(x_ptr, y_ptr, BLOCK_SIZE: tl.constexpr):\n # Triton kernel code\n pass\n\n# Function to call the Triton kernel\ndef call_my_kernel(x, y):\n # Function to invoke the Triton kernel\n my_kernel[(1,)](x_ptr=x, y_ptr=y, BLOCK_SIZE=1024)\n\n", - "description_1": "Use triton language to define a kernel 'my_kernel' that takes two pointers 'x_ptr' and 'y_ptr', and a block size 'BLOCK_SIZE'. It performs operations on these pointers within the kernel. The 'call_my_kernel' function is used to invoke this kernel with input tensors 'x' and 'y'.", - "description_2": "Use triton language to define a kernel with two pointers as input and a block size, then invoke this kernel from a Python function.", - "difficulty": 2 - }, - { - "code": "import triton\nimport triton.language as tl\n\n# Kernel function with 6 parameters: \n# in_out_ptr0 (output pointer), in_ptr0 (input pointer), \n# xnumel (total number of elements in x dimension), rnumel (total number of elements in reduction dimension),\n# XBLOCK and RBLOCK are compile-time constants for block sizes\n@triton.jit\ndef triton_red_fused_add_sum_2(in_out_ptr0, in_ptr0, xnumel, rnumel, XBLOCK: tl.constexpr, RBLOCK: tl.constexpr):\n xnumel = 1024\n rnumel = 2048\n xoffset = tl.program_id(0) * XBLOCK\n xindex = xoffset + tl.arange(0, XBLOCK)[:, None]\n xmask = xindex < xnumel\n rbase = tl.arange(0, RBLOCK)[None, :]\n x0 = xindex\n _tmp2 = tl.full([XBLOCK, RBLOCK], 0, tl.float32)\n for roffset in range(0, rnumel, RBLOCK):\n rindex = roffset + rbase\n rmask = rindex < rnumel\n r1 = rindex\n tmp0 = tl.load(in_ptr0 + (r1 + (2048 * x0)), rmask & xmask, eviction_policy='evict_first', other=0.0)\n tmp1 = tl.broadcast_to(tmp0, [XBLOCK, RBLOCK])\n tmp3 = _tmp2 + tmp1\n _tmp2 = tl.where(rmask & xmask, tmp3, _tmp2)\n tmp2 = tl.sum(_tmp2, 1)[:, None]\n tmp4 = tl.load(in_out_ptr0 + (x0), xmask, eviction_policy='evict_last')\n tmp5 = tmp4 + tmp2\n tl.debug_barrier()\n tl.store(in_out_ptr0 + (x0), tmp5, xmask)\n", - "description_1": "Use triton language to define a kernel 'triton_red_fused_add_sum_2' which performs a reduction operation. It takes 6 parameters: in_out_ptr0 (output pointer), in_ptr0 (input pointer), xnumel (number of elements in x dimension), rnumel (number of elements in reduction dimension), XBLOCK and RBLOCK (block size constants). The kernel performs element-wise addition and reduction with summation over a 2D grid defined by XBLOCK and RBLOCK.", - "description_2": "Use triton language to create a kernel for a fused addition and sum reduction operation across two dimensions, parametrized by pointers to input/output data, element counts, and block size constants.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\n\n# Triton kernel for element-wise addition\n@triton.jit\ndef add_kernel(x_ptr, y_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr):\n pid = tl.program_id(0)\n block_start = pid * BLOCK_SIZE\n offsets = block_start + tl.arange(0, BLOCK_SIZE)\n mask = offsets < n_elements\n x = tl.load(x_ptr + offsets, mask=mask)\n y = tl.load(y_ptr + offsets, mask=mask)\n output = x + y\n tl.store(output_ptr + offsets, output, mask=mask)\n\n# Function to call the Triton kernel\ndef add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:\n assert x.is_cuda and y.is_cuda\n assert x.shape == y.shape\n n_elements = x.numel()\n output = torch.empty_like(x)\n grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)\n add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024)\n return output\n\n# Example usage\nx = torch.randn(1024, device='cuda')\ny = torch.randn(1024, device='cuda')\noutput = add(x, y)\n", - "description_1": "Use triton language to implement an element-wise addition kernel. The kernel takes two input pointers (x_ptr, y_ptr) and an output pointer (output_ptr), along with the number of elements (n_elements) to process. The kernel uses a block size (BLOCK_SIZE) to divide the work among threads. Each thread loads elements from the input pointers, performs addition, and stores the result in the output pointer. The function 'add' calls this kernel, ensuring the inputs are on CUDA and have the same shape, and returns the result.", - "description_2": "Use triton language to create a kernel for element-wise addition of two CUDA tensors, and a function to execute this kernel.", - "difficulty": 2 - }, - { - "code": "import triton\nimport torch\nfrom torch._inductor.utils import get_triton_code\nfrom torch._C import FileCheck\nimport unittest\n\ndef mock_triton_hash_with_backend(*args, **kwargs):\n return \"\".join(random.choices(string.ascii_uppercase + string.digits, k=64))\n\ndef test_open_device_registration():\n device = torch.device(\"cpu\")\n x = torch.empty(2, 16).fill_(1).to(device)\n\n def foo(x):\n return torch.sin(x) + x.min()\n\n opt_fn = torch.compile(foo)\n\n with unittest.mock.patch(\n \"torch.utils._triton.triton_hash_with_backend\",\n new=mock_triton_hash_with_backend,\n ):\n code = get_triton_code(opt_fn, x)\n\n FileCheck().check(\"import triton\").check(\"@triton.jit\").check(\n \"tl_math.sin\"\n ).check(\"device_str='cpu'\").run(code)\n", - "description_1": "Use triton language to implement a kernel that computes the sine of a tensor and adds the minimum value of the tensor to each element. The kernel is invoked using a compiled function. The test checks if the generated code imports Triton, uses @triton.jit, and includes specific computations.", - "description_2": "Use triton language to create a kernel that computes the sine and minimum of a tensor. Invoke it with a compiled function and check the code generation for Triton specific syntax.", - "difficulty": 1 - }, - { - "code": "import torch\nimport triton\nfrom triton import language as tl\n\n# Define shared triton constants here.\nCONSTANT_C: tl.constexpr = 4\nSTRING_CONSTANT_C: tl.constexpr = \"CONSTANT_C\"\nBOOL_CONSTANT_C: tl.constexpr = True\n\n@triton.jit\ndef pass_kernel(kernel):\n pass\n\ndef f(x):\n grid = (x.numel(),)\n pass_kernel[grid](kernel=x)\n\n@triton.jit\ndef pow2_kernel(\n in_ptr0,\n out_ptr,\n n_elements,\n BLOCK_SIZE: \"tl.constexpr\",\n):\n pid = tl.program_id(axis=0)\n block_start = pid * BLOCK_SIZE\n offsets = block_start + tl.arange(0, BLOCK_SIZE)\n mask = offsets < n_elements\n x = tl.load(in_ptr0 + offsets, mask=mask)\n output = x * x\n tl.store(out_ptr + offsets, output, mask=mask)\n\ndef f(x: torch.Tensor):\n output = torch.zeros_like(x)\n n_elements = output.numel()\n grid = lambda meta: (triton.cdiv(n_elements, meta[\"BLOCK_SIZE\"]),)\n pow2_kernel[grid](x, output, n_elements, BLOCK_SIZE=16)\n return output\n\n@triton.jit\ndef add_kernel(\n in_ptr0,\n in_ptr1,\n BLOCK_SIZE: \"tl.constexpr\",\n out_ptr,\n n_elements,\n):\n pid = tl.program_id(axis=0)\n block_start = pid * BLOCK_SIZE\n offsets = block_start + tl.arange(0, BLOCK_SIZE)\n mask = offsets < n_elements\n x = tl.load(in_ptr0 + offsets, mask=mask)\n y = tl.load(in_ptr1 + offsets, mask=mask)\n output = x + y\n tl.store(out_ptr + offsets, output, mask=mask)\n\ndef f(x, y):\n out = torch.zeros_like(x)\n n_elements = x.numel()\n add_kernel[(n_elements,)](x, y, 4, out, n_elements)\n return out\n", - "description_1": "Use triton language to define and execute kernels for element-wise operations. The 'pass_kernel' takes a single tensor as input and does nothing. The 'pow2_kernel' takes two pointers, the number of elements, and a block size, computes the square of each element, and stores the result. The 'add_kernel' takes two input pointers, an output pointer, the number of elements, and a block size, adds corresponding elements from the input pointers, and stores the result in the output pointer.", - "description_2": "Use triton language to define kernels for squaring elements and adding elements from two tensors, and execute these kernels on GPU.", - "difficulty": 3 - }, - { - "code": "import triton\nimport triton.language as tl\nimport torch\nimport math\nfrom typing import Optional, Tuple\n\n@triton.jit\ndef _sampled_addmm_kernel(\n alpha,\n beta,\n IS_BETA_ZERO: tl.constexpr,\n BLOCKSIZE_ROW: tl.constexpr,\n BLOCKSIZE_COL: tl.constexpr,\n k,\n TILE_K: tl.constexpr,\n values_ptr,\n values_batch_stride,\n values_nnz_stride,\n values_row_block_stride,\n values_col_block_stride,\n crow_indices_ptr,\n crow_indices_batch_stride,\n crow_indices_stride,\n col_indices_ptr,\n col_indices_batch_stride,\n col_indices_stride,\n mat1_ptr,\n mat1_batch_stride,\n mat1_tiled_row_stride,\n mat1_tiled_col_stride,\n mat1_row_block_stride,\n mat1_col_block_stride,\n mat2_ptr,\n mat2_batch_stride,\n mat2_tiled_row_stride,\n mat2_tiled_col_stride,\n mat2_row_block_stride,\n mat2_col_block_stride,\n acc_dtype: tl.constexpr,\n allow_tf32: tl.constexpr,\n):\n batch_pid = tl.program_id(axis=1)\n row_block_pid = tl.program_id(axis=0)\n\n crow_indices_offset_ptr = (\n crow_indices_ptr\n + crow_indices_batch_stride * batch_pid\n + crow_indices_stride * row_block_pid\n )\n nnz_offset = tl.load(crow_indices_offset_ptr)\n nnz_offset_next = tl.load(crow_indices_offset_ptr + crow_indices_stride)\n\n row_nnz = nnz_offset_next - nnz_offset\n if row_nnz == 0:\n return\n\n row_block_arange = tl.arange(0, BLOCKSIZE_ROW)\n col_block_arange = tl.arange(0, BLOCKSIZE_COL)\n\n values_block_ptrs = (\n values_ptr\n + values_batch_stride * batch_pid\n + values_nnz_stride * nnz_offset\n + values_row_block_stride * row_block_arange[:, None]\n + values_col_block_stride * col_block_arange[None, :]\n )\n\n col_index_nnz_ptr = (\n col_indices_ptr\n + col_indices_batch_stride * batch_pid\n + col_indices_stride * nnz_offset\n )\n\n mat1_block_ptrs = (\n mat1_ptr\n + mat1_batch_stride * batch_pid\n + mat1_tiled_row_stride * row_block_pid\n + mat1_row_block_stride * row_block_arange[:, None]\n )\n\n mat2_block_ptrs = (\n mat2_ptr\n + mat2_batch_stride * batch_pid\n + mat2_col_block_stride * col_block_arange[None, :]\n )\n\n k_tile_arange = tl.arange(0, TILE_K)\n for _ in range(row_nnz):\n acc_block = tl.zeros((BLOCKSIZE_ROW, BLOCKSIZE_COL), dtype=acc_dtype)\n\n col_block = tl.load(col_index_nnz_ptr)\n\n for k_tile in range(0, k, TILE_K):\n k_offsets = k_tile + k_tile_arange\n mask_k = k_offsets < k\n\n mat1_block = tl.load(\n mat1_block_ptrs\n + mat1_col_block_stride * k_offsets[None, :],\n mask=mask_k[None, :], other=0.0\n )\n\n mat2_block = tl.load(\n mat2_block_ptrs\n + mat2_tiled_col_stride * col_block\n + mat2_row_block_stride * k_offsets[:, None],\n mask=mask_k[:, None], other=0.0\n )\n\n acc_block += tl.dot(mat1_block, mat2_block, allow_tf32=allow_tf32, out_dtype=acc_dtype)\n\n if IS_BETA_ZERO:\n acc_block *= alpha\n else:\n acc_block = alpha * acc_block + beta * tl.load(values_block_ptrs)\n\n tl.store(values_block_ptrs, acc_block.to(values_ptr.dtype.element_ty))\n\n values_block_ptrs += values_nnz_stride\n col_index_nnz_ptr += col_indices_stride\n\ndef sampled_addmm(\n input: torch.Tensor,\n mat1: torch.Tensor,\n mat2: torch.Tensor,\n *,\n beta=1.0,\n alpha=1.0,\n out: Optional[torch.Tensor] = None,\n skip_checks: bool = False,\n max_grid: Optional[Tuple[Optional[int], Optional[int], Optional[int]]] = None,\n):\n f_name = \"sampled_addmm\"\n\n check_bsr_layout(f_name, input)\n input_broadcasted = broadcast_batch_dims_bsr(f_name, input, mat1, mat2)\n\n if not skip_checks:\n check_device(f_name, mat1, input.device)\n check_device(f_name, mat2, input.device)\n if beta != 0.0 and input.dtype is torch.bool:\n check(\n False,\n f\"{f_name}(): having beta == {beta} not equal to 0.0 with boolean mask is not allowed.\"\n )\n if input.dtype is not torch.bool:\n check_dtype(f_name, mat1, input.dtype)\n check_dtype(f_name, mat2, input.dtype)\n else:\n check_dtype(f_name, mat1, mat2.dtype)\n check_mm_compatible_shapes(f_name, mat1, mat2)\n if out is not None:\n check_bsr_layout(f_name, out)\n check_device(f_name, out, mat1.device)\n check_dtype(f_name, out, input.dtype)\n check(\n out.shape == input_broadcasted.shape\n and out._nnz() == input_broadcasted._nnz(),\n f\"{f_name}(): Expects `out` to be of shape {input_broadcasted.shape} \"\n f\"and with nnz equal to {input_broadcasted._nnz()} \"\n f\"but got out.shape = {out.shape} and out.nnz = {out._nnz()}\"\n )\n\n if out is None:\n out = input_broadcasted.to(mat1.dtype, copy=True)\n else:\n out.copy_(input_broadcasted)\n\n if out.numel() == 0 or out._nnz() == 0:\n return out\n\n blocksize = out.values().shape[-2:]\n m = mat1.size(-2)\n n = mat2.size(-1)\n k = mat1.size(-1)\n\n if alpha == 0.0 or k == 0:\n out.values().mul_(beta)\n return out\n\n out_backup = out\n crow_indices, col_indices, values, mat1, mat2 = prepare_inputs(out, mat1, mat2)\n\n mat1 = tile_to_blocksize(mat1, (blocksize[0], k))\n mat2 = tile_to_blocksize(mat2, (k, blocksize[1]))\n tile_k = max(*blocksize)\n\n _run_sampled_addmm_kernel(\n alpha, beta, beta == 0.0,\n blocksize, k, tile_k,\n values, crow_indices, col_indices,\n mat1, mat2,\n max_grid\n )\n\n if out_backup.values().stride()[-3:] != values.stride()[-3:]:\n out_backup.values().copy_(values.reshape(out_backup.values().shape))\n return out_backup\n\ndef _scaled_dot_product_attention(\n query: torch.Tensor,\n key: torch.Tensor,\n value: torch.Tensor,\n attn_mask: Optional[torch.Tensor],\n dropout_p: float = 0.0,\n is_causal: bool = False,\n scale: Optional[float] = None\n):\n f_name = \"_scaled_dot_product_attention\"\n check(\n not is_causal,\n f\"{f_name}(): is_causal == True is not supported.\"\n )\n check(\n attn_mask is not None,\n f\"{f_name}(): attn_mask == None is not supported.\"\n )\n assert attn_mask is not None\n\n check(\n attn_mask.layout == torch.sparse_bsr,\n f\"{f_name}(): \"\n f\"attn_mask.layout must be {torch.sparse_bsr}, but got \"\n f\"attn_mask.layout == {attn_mask.layout}.\"\n )\n\n check_device(f_name, key, query.device)\n check_device(f_name, value, query.device)\n check_device(f_name, attn_mask, query.device)\n\n check_dtype(f_name, key, query.dtype)\n check_dtype(f_name, value, query.dtype)\n if attn_mask.dtype is not torch.bool:\n check_dtype(f_name, attn_mask, query.dtype)\n\n sdpa = sampled_addmm(attn_mask, query, key.transpose(-2, -1), beta=0.0, skip_checks=False)\n if scale is None and query.size(-1) == 0 or scale == 0.0:\n check(\n False,\n f\"{f_name}(): current value of scale == {scale} \"\n \"results in division by zero.\"\n )\n scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale\n sdpa.values().mul_(scale_factor)\n sdpa = bsr_softmax(sdpa)\n torch.nn.functional.dropout(sdpa.values(), p=dropout_p, inplace=True)\n sdpa = bsr_dense_mm(sdpa, value)\n return sdpa\n", - "description_1": "Use triton language to implement a sampled matrix multiplication kernel and a scaled dot product attention function. The sampled_addmm function takes 6 parameters: input (a sparse tensor in BSR format), mat1 (a dense tensor), mat2 (a dense tensor), and optional parameters beta, alpha, and out. It performs a matrix multiplication of mat1 and mat2, scaled by alpha, and adds it to the input scaled by beta. The _scaled_dot_product_attention function takes 7 parameters: query, key, value (all dense tensors), attn_mask (a sparse tensor in BSR format), dropout_p, is_causal, and scale. It computes the scaled dot product attention using the sampled_addmm function, applies a softmax, and optionally applies dropout.", - "description_2": "Use triton language to create a kernel for sampled matrix multiplication and a function for scaled dot product attention with dropout support.", - "difficulty": 4 - }, - { - "code": "import triton\nfrom triton import language as tl\n\n# This kernel performs element-wise addition on two input arrays with optional parameters.\n@triton.jit\ndef add_kernel_with_optional_param(\n in_ptr0,\n in_ptr1,\n out_ptr,\n n_elements,\n ARGS_PASSED: \"tl.constexpr\",\n BLOCK_SIZE: \"tl.constexpr\",\n):\n pid = tl.program_id(axis=0)\n block_start = pid * BLOCK_SIZE\n offsets = block_start + tl.arange(0, BLOCK_SIZE)\n mask = offsets < n_elements\n x = tl.load(in_ptr0 + offsets, mask=mask)\n if ARGS_PASSED == \"two\":\n y = tl.load(in_ptr1 + offsets, mask=mask)\n output = x + y\n else:\n output = x\n tl.store(out_ptr + offsets, output, mask=mask)\n\n# Description for add_kernel_with_optional_param\n# Parameters:\n# in_ptr0, in_ptr1, out_ptr: pointers to input and output data\n# n_elements: total number of elements to process\n# ARGS_PASSED: a constexpr string that dictates if two input arrays should be used\n# BLOCK_SIZE: the size of the block\n\n# This kernel is an autotuned version for adding two arrays.\n@triton.autotune(\n configs=[\n triton.Config({\"BLOCK_SIZE\": 128}, num_stages=3, num_warps=8),\n triton.Config({\"BLOCK_SIZE\": 128}, num_stages=4, num_warps=4),\n triton.Config({\"BLOCK_SIZE\": 64}, num_stages=3, num_warps=8),\n triton.Config({\"BLOCK_SIZE\": 64}, num_stages=4, num_warps=4),\n ],\n key=[],\n)\n@triton.jit\ndef add_kernel_autotuned(\n in_ptr0,\n in_ptr1,\n out_ptr,\n n_elements,\n BLOCK_SIZE: \"tl.constexpr\",\n):\n pid = tl.program_id(axis=0)\n block_start = pid * BLOCK_SIZE\n offsets = block_start + tl.arange(0, BLOCK_SIZE)\n mask = offsets < n_elements\n x = tl.load(in_ptr0 + offsets, mask=mask)\n y = tl.load(in_ptr1 + offsets, mask=mask)\n output = x + y\n tl.store(out_ptr + offsets, output, mask=mask)\n\n# Description for add_kernel_autotuned\n# Parameters:\n# in_ptr0, in_ptr1, out_ptr: pointers to input and output data\n# n_elements: total number of elements to process\n# BLOCK_SIZE: the size of the block\n\n# This kernel performs element-wise multiplication with a scaling factor.\n@triton.jit\ndef add_kernel_with_scaling(\n in_ptr0,\n in_ptr1,\n out_ptr,\n n_elements,\n scaling_factor,\n BLOCK_SIZE: \"tl.constexpr\",\n):\n pid = tl.program_id(axis=0)\n block_start = pid * BLOCK_SIZE\n offsets = block_start + tl.arange(0, BLOCK_SIZE)\n mask = offsets < n_elements\n x = tl.load(in_ptr0 + offsets, mask=mask)\n y = tl.load(in_ptr1 + offsets, mask=mask)\n output = (x + y) * scaling_factor\n tl.store(out_ptr + offsets, output, mask=mask)\n\n# Description for add_kernel_with_scaling\n# Parameters:\n# in_ptr0, in_ptr1, out_ptr: pointers to input and output data\n# n_elements: total number of elements to process\n# scaling_factor: factor to scale the result\n# BLOCK_SIZE: the size of the block\n\n# This kernel uses inline assembly to perform an operation on inputs.\n@triton.jit\ndef inline_asm_kernel(X, Y, Z, n: \"tl.constexpr\", BLOCK: \"tl.constexpr\"):\n x = tl.load(X + tl.arange(0, BLOCK))\n y = tl.load(Y + tl.arange(0, BLOCK))\n s = tl.full([BLOCK], n, tl.int32)\n z = tl.inline_asm_elementwise(\n \"shf.l.wrap.b32 $0, $1, $2, $3;\",\n \"=r,r, r, r\",\n [x, y, s],\n dtype=tl.int32,\n is_pure=True,\n pack=1,\n )\n tl.store(Z + tl.arange(0, BLOCK), z)\n\n# Description for inline_asm_kernel\n# Parameters:\n# X, Y, Z: pointers to input and output data\n# n: a constexpr integer used in inline assembly\n# BLOCK: block size\n", - "description_1": "Use triton language to implement various element-wise operations on input arrays, including addition with optional parameters, addition with scaling, and operations with inline assembly.", - "description_2": "Use triton language to optimize array operations with features like optional parameters, scaling, and inline assembly.", - "difficulty": 3 - }, - { - "code": "import triton\n\n@triton.jit\ndef example_kernel(x_ptr, y_ptr, N):\n # Triton kernel that adds two vectors x and y\n # x_ptr: pointer to the first input vector (float32)\n # y_ptr: pointer to the second input vector (float32)\n # N: size of the vectors (int32)\n i = tl.program_id(0)\n if i < N:\n x_ptr[i] += y_ptr[i]\n\ndef call_example_kernel(x, y, N):\n # Call the Triton kernel\n # x: first input vector (torch.Tensor)\n # y: second input vector (torch.Tensor)\n # N: size of the vectors (int)\n grid = (N,)\n example_kernel[grid](x, y, N)\n", - "description_1": "Use triton language to define a kernel that adds elements of two input vectors x and y of size N, and a function to launch this kernel with specified grid size.", - "description_2": "Use triton language to perform element-wise addition of two vectors using a kernel.", - "difficulty": 1 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef _cross_entropy_forward(\n logits_ptr, logits_row_stride,\n loss_ptr,\n logsumexp_ptr,\n labels_ptr,\n VOCAB_SIZE : tl.constexpr,\n BLOCK_SIZE : tl.constexpr,\n):\n \"\"\"\n Cross Entropy Loss = 1/n sum [ -yi log(Pi) ]\n Pi = exp(xi) / sum(exp(xi))\n CE_i = -y log(p) = -y log[ exp(x) / sum(exp(x)) ]\n = -y [ x - log[sum(exp(x))] ]\n = y * (log[sum(exp(x))] - x)\n If y == 0: CE_i = 0\n If y == 1: CE_i = logsumexp - x\n\n logsumexp is also stable\n Take y = log[sum(exp(x))]\n exp(y) = sum(exp(x))\n exp(y) = sum(exp(x - c)*exp(c)) Since e^(x-c)*e^c = e^x\n exp(y) = exp(c)*sum(exp(x - c))\n y = log(exp(c)*sum(exp(x - c)))\n y = c + log[sum(exp(x - c))]\n This means we can set c = max(x) to make sure\n exp(x - c) always is exp(x - max(x)).\n This ensures exp(x - max(x))'s maximum is 1 as exp(0) = 1.\n \"\"\"\n row_idx = tl.program_id(0)\n logits_ptr += row_idx * logits_row_stride.to(tl.int64)\n loss_ptr += row_idx\n logsumexp_ptr += row_idx\n labels_ptr += row_idx\n\n col_offsets = tl.arange(0, BLOCK_SIZE)\n mask = col_offsets < VOCAB_SIZE\n\n label_idx = tl.load(labels_ptr).to(tl.int32)\n logits = tl.load(logits_ptr + col_offsets, mask = mask, other = -float(\"inf\")).to(tl.float32)\n c = tl.max(logits, 0)\n logsumexp = c + tl.log(tl.sum(tl.exp(logits - c), 0))\n\n if label_idx != -100:\n x = tl.load(logits_ptr + label_idx).to(tl.float32)\n loss = logsumexp - x\n else:\n loss = 0.0\n tl.store(logsumexp_ptr, logsumexp)\n tl.store(loss_ptr, loss)\npass\n\n\n@triton.jit\ndef _chunked_cross_entropy_forward(\n logits_ptr, logits_row_stride,\n loss_ptr,\n logsumexp_ptr,\n labels_ptr,\n VOCAB_SIZE : tl.constexpr,\n N_CHUNKS : tl.constexpr,\n BLOCK_SIZE : tl.constexpr,\n):\n \"\"\"\n 256K vocab divided in 4 chunks\n\n |-65536-| |-65536-| |-65536-| |-65536-|\n |-------| |-------| |-------| |-------|\n |-------| |-------| |-------| |-------|\n\n If y == 0: CE_i = 0\n If y == 1: CE_i = logsumexp - x\n\n Notice we can do logsumexp for each chunk and then\n logsumexp[chunk_sum(logsumexp)] == logsumexp\n\n chunk_sum = log[chunk_sum(logsumexp)]\n = log[exp(logsumexp(a)) + ... + exp(logsumexp(z))]\n = log[exp(log[sum(exp(a))]) + ... + exp(log[sum(exp(z))])]\n = log[sum(exp(a)) + ... + sum(exp(z))]\n = logsumexp(x)\n\n This means we can perform a logsumexp for each chunk, then do a\n final logsumexp reduction!\n\n Ie do: logsumexp(chunked_logsumexp) - x\n \"\"\"\n row_idx = tl.program_id(0)\n chunk_idx = tl.program_id(1)\n logits_ptr += row_idx * logits_row_stride.to(tl.int64)\n loss_ptr += row_idx\n logsumexp_ptr += row_idx * N_CHUNKS + chunk_idx\n labels_ptr += row_idx\n\n col_offsets = chunk_idx*BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)\n mask = col_offsets < VOCAB_SIZE\n\n label_idx = tl.load(labels_ptr).to(tl.int32)\n logits = tl.load(logits_ptr + col_offsets, mask = mask, other = -float(\"inf\")).to(tl.float32)\n c = tl.max(logits, 0)\n logsumexp = c + tl.log(tl.sum(tl.exp(logits - c), 0))\n\n if chunk_idx == 0:\n # logsumexp(chunked_logsumexp) - x\n # Do the -x separately\n if label_idx != -100:\n x = tl.load(logits_ptr + label_idx).to(tl.float32)\n loss = -1.0 * x\n else:\n loss = 0.0\n tl.store(loss_ptr, loss)\n pass\n tl.store(logsumexp_ptr, logsumexp)\npass\n\n\n@triton.jit\ndef _cross_entropy_backward(\n logits_ptr, logits_row_stride,\n dloss_ptr, dloss_row_stride,\n logsumexp_ptr,\n labels_ptr,\n VOCAB_SIZE : tl.constexpr,\n BLOCK_SIZE : tl.constexpr,\n):\n \"\"\"\n CE_i = -y log(P) = y * (log[sum(exp(x))] - x)\n dC/dx = d/dx (y * log[sum(exp(x))] - x * y)\n\n From https://en.wikipedia.org/wiki/LogSumExp\n d/dx logsumexp = exp(x) / sum(exp(x)) = softmax(x)\n\n dC/dx = y * exp(x) / sum(exp(x)) - d/dx (x * y)\n dC/dx = y * exp[ log[exp(x) / sum(exp(x))] ] using x = exp(log(x)) trick\n dC/dx = y * exp[x - logsumexp] - d/dx (x * y)\n\n If y == 0: dC/dx = 0\n If y == 1 and x == label: dC/dlabel = exp[x - logsumexp] - 1\n If y == 1 and x != label: dC/dx = exp[x - logsumexp]\n \"\"\"\n row_idx = tl.program_id(0)\n block_idx = tl.program_id(1)\n\n logits_ptr += row_idx * logits_row_stride.to(tl.int64)\n dloss_ptr += row_idx * dloss_row_stride\n col_offsets = block_idx*BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)\n mask = col_offsets < VOCAB_SIZE\n label_idx = tl.load(labels_ptr + row_idx).to(tl.int32)\n\n if label_idx != -100:\n dloss = tl.load(dloss_ptr)\n else:\n dloss = 0.0\n\n x = tl.load(logits_ptr + col_offsets, mask = mask, other = -float(\"inf\")).to(tl.float32)\n logsumexp = tl.load(logsumexp_ptr + row_idx)\n y = tl.exp(x - logsumexp)\n y = tl.where(\n col_offsets == label_idx,\n y - 1.0, # exp(x - logsumexp) - 1\n y, # exp(x - logsumexp)\n )\n\n # If y == 0: dC/dx = 0 ==> we already masked it to be = 0, so dloss = 0.\n tl.store(logits_ptr + col_offsets, dloss * y, mask = mask)\npass\n\n\ndef _cross_entropy_forward_impl(logits, labels):\n n_rows, vocab_size = logits.shape\n\n div, mod = divmod(vocab_size, MAX_FUSED_SIZE)\n n_chunks = div + (mod != 0)\n losses = torch.empty(n_rows, dtype = torch.float32, device = \"cuda\")\n\n if n_chunks == 1:\n # For small vocabs <= 65336 like Llama, Mistral\n BLOCK_SIZE, num_warps = calculate_settings(vocab_size)\n logsumexp = torch.empty(n_rows, dtype = torch.float32, device = \"cuda\")\n\n _cross_entropy_forward[(n_rows,)](\n logits, logits.stride(0),\n losses,\n logsumexp,\n labels,\n VOCAB_SIZE = vocab_size,\n BLOCK_SIZE = BLOCK_SIZE,\n num_warps = num_warps,\n )\n else:\n # For large vocabs > 65336 like Gemma 256K\n logsumexp = torch.empty((n_rows, n_chunks,), dtype = torch.float32, device = \"cuda\")\n\n _chunked_cross_entropy_forward[(n_rows, n_chunks,)](\n logits, logits.stride(0),\n losses,\n logsumexp,\n labels,\n VOCAB_SIZE = vocab_size,\n N_CHUNKS = n_chunks,\n BLOCK_SIZE = MAX_FUSED_SIZE,\n num_warps = 32,\n )\n # logsumexp(chunked_logsumexp) - x\n # Do the -x separately\n logsumexp = torch.logsumexp(logsumexp, dim = 1) # Row sum\n losses += logsumexp\n losses.masked_fill_(labels == -100, 0) # Don't forget to mask padding out!\n\n return losses, logsumexp\n\n\ndef _cross_entropy_backward_impl(dlosses, logits, logsumexp, labels):\n n_rows, vocab_size = logits.shape\n\n BLOCK_SIZE = 4096\n div, mod = divmod(vocab_size, BLOCK_SIZE)\n n_blocks = div + (mod != 0)\n\n _cross_entropy_backward[(n_rows, n_blocks,)](\n logits, logits.stride(0),\n dlosses, dlosses.stride(0),\n logsumexp,\n labels,\n VOCAB_SIZE = vocab_size,\n BLOCK_SIZE = BLOCK_SIZE,\n num_warps = 8,\n )\n return logits\n", - "description_1": "Use triton language to implement cross-entropy forward and backward kernels. The forward kernel computes the cross-entropy loss and logsumexp for given logits and labels, handling both small and large vocabulary sizes. The backward kernel computes the gradient of the cross-entropy loss with respect to the logits. The forward implementation function decides whether to use a single kernel or a chunked approach based on the vocabulary size, while the backward implementation function applies the backward kernel to compute gradients.", - "description_2": "Use triton language to implement cross-entropy loss computation and its gradient calculation for varying vocabulary sizes.", - "difficulty": 3 - }, - { - "code": "import triton\nimport triton.language as tl\n\nROPE_GROUP_SIZE = 4\n\n@triton.jit\ndef _rope_embedding(\n Q, Q_row_stride,\n cos, cos_row_stride,\n sin, sin_row_stride,\n seqlen,\n head_dim : tl.constexpr,\n n_heads : tl.constexpr,\n BACKWARD_PASS : tl.constexpr,\n BLOCK_SIZE : tl.constexpr,\n ROPE_GROUP_SIZE : tl.constexpr = 4,\n):\n \"\"\"\n Calculates the RoPE Embedding quickly\n RoPE is Q * cos + rotate_half(Q) * sin\n See our blog post for more info\n \"\"\"\n row_position = tl.program_id(0)\n group_head_position = tl.program_id(1)\n col_offsets = tl.arange(0, BLOCK_SIZE)\n half_head_dim = head_dim // 2\n mask = col_offsets < half_head_dim\n\n sin1 = tl.load(sin + (row_position % seqlen)*sin_row_stride + \\\n half_head_dim*0 + col_offsets, mask = mask, other = 0)\n cos1 = tl.load(cos + (row_position % seqlen)*cos_row_stride + \\\n half_head_dim*0 + col_offsets, mask = mask, other = 0)\n\n if BACKWARD_PASS:\n # See our blog post for more info.\n sin1 = -sin1\n\n head_start = group_head_position * ROPE_GROUP_SIZE\n head_end = min((head_start + ROPE_GROUP_SIZE), n_heads)\n\n for k in range(head_start, head_end):\n offs_q1 = row_position * Q_row_stride + k * head_dim + col_offsets\n offs_q2 = row_position * Q_row_stride + k * head_dim + col_offsets + half_head_dim\n\n Q1 = tl.load(Q + offs_q1, mask = mask, other = 0).to(sin1.dtype)\n Q2 = tl.load(Q + offs_q2, mask = mask, other = 0).to(sin1.dtype)\n\n tl.store(Q + offs_q1, Q1*cos1 - Q2*sin1, mask = mask)\n tl.store(Q + offs_q2, Q2*cos1 + Q1*sin1, mask = mask)\n\ndef _rope_embedding_forward_impl(Q, cos, sin):\n Q = Q.transpose(1, 2).clone()\n cos, sin = cos.squeeze(), sin.squeeze()\n batch, seq_len, n_heads, head_dim = Q.shape\n Q = Q.reshape(batch*seq_len, n_heads*head_dim)\n n_rows, n_cols = Q.shape\n assert(seq_len <= cos.shape[0])\n\n BLOCK_SIZE, num_warps = calculate_settings(head_dim//2)\n\n div, mod = divmod(n_heads, ROPE_GROUP_SIZE)\n n_groups = div + (mod != 0)\n\n _rope_embedding[(n_rows, n_groups, )](\n Q, Q.stride(0),\n cos, cos.stride(0),\n sin, sin.stride(0),\n seq_len,\n head_dim, n_heads,\n BACKWARD_PASS = False,\n BLOCK_SIZE = BLOCK_SIZE,\n num_warps = num_warps,\n )\n Q = Q.view(batch, seq_len, n_heads, head_dim)\n Q = Q.transpose(1, 2)\n return Q, cos, sin, n_groups, BLOCK_SIZE, num_warps\n\ndef _rope_embedding_backward_impl(dY, cos, sin, n_groups, BLOCK_SIZE, num_warps):\n dY = dY.transpose(1, 2)\n batch, seq_len, n_heads, head_dim = dY.shape\n dY = dY.reshape(batch*seq_len, n_heads*head_dim)\n n_rows, n_cols = dY.shape\n\n _rope_embedding[(n_rows, n_groups, )](\n dY, dY .stride(0),\n cos, cos.stride(0),\n sin, sin.stride(0),\n seq_len, head_dim, n_heads,\n BACKWARD_PASS = True,\n BLOCK_SIZE = BLOCK_SIZE,\n num_warps = num_warps,\n )\n dY = dY.view(batch, seq_len, n_heads, head_dim)\n dY = dY.transpose(1, 2)\n return dY\n", - "description_1": "Use triton language to implement a RoPE embedding kernel that computes the rotary position embedding for input tensor Q using cosine and sine values. The kernel is parameterized by sequence length, head dimension, number of heads, and block size. It supports both forward and backward passes.", - "description_2": "Use triton language to create a kernel for rotary position embedding with support for forward and backward computation.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n# Triton kernel for element-wise operations on tensors e and g\n@triton.jit\ndef _fg_kernel(e, g, h, n_elements, BLOCK_SIZE: tl.constexpr):\n block_idx = tl.program_id(0)\n offsets = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)\n mask = offsets < n_elements\n\n e_row = tl.load(e + offsets, mask=mask, other=0).to(tl.float32)\n g_row = tl.load(g + offsets, mask=mask, other=0)\n\n # f = e * sigmoid(e)\n f_row = e_row * tl.sigmoid(e_row)\n f_row = f_row.to(g_row.dtype)\n # h = f * g\n h_row = f_row * g_row\n\n # Store h\n tl.store(h + offsets, h_row, mask=mask)\n\n# Function to launch the _fg_kernel\ndef swiglu_fg_kernel(e, g):\n batch, seq_len, hd = e.shape\n n_elements = e.numel()\n h = torch.empty((batch, seq_len, hd), dtype=e.dtype, device=\"cuda\")\n grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)\n _fg_kernel[grid](e, g, h, n_elements, BLOCK_SIZE=1024)\n return h\n\n# Triton kernel for computing derivatives\n@triton.jit\ndef _DWf_DW_dfg_kernel(DW, e, g, n_elements, BLOCK_SIZE: tl.constexpr):\n block_idx = tl.program_id(0)\n offsets = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)\n mask = offsets < n_elements\n\n DW_row = tl.load(DW + offsets, mask=mask, other=0)\n e_row = tl.load(e + offsets, mask=mask, other=0).to(tl.float32)\n g_row = tl.load(g + offsets, mask=mask, other=0)\n\n se_row = tl.sigmoid(e_row)\n f_row = se_row * e_row\n f_row = f_row.to(DW_row.dtype)\n h_row = f_row * g_row\n df_row = DW_row * f_row\n dg_row = DW_row * g_row\n de_row = dg_row.to(tl.float32) * se_row * (1.0 + e_row * (1.0 - se_row))\n de_row = de_row.to(DW_row.dtype)\n\n # Store derivatives in buffers\n tl.store(DW + offsets, h_row, mask=mask)\n tl.store(e + offsets, df_row, mask=mask)\n tl.store(g + offsets, de_row, mask=mask)\n\n# Function to launch the _DWf_DW_dfg_kernel\ndef swiglu_DWf_DW_dfg_kernel(DW, e, g):\n batch_seq_len, hd = e.shape\n n_elements = e.numel()\n grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)\n _DWf_DW_dfg_kernel[grid](DW, e, g, n_elements, BLOCK_SIZE=1024)\n return DW, e, g\n", - "description_1": "Use triton language to implement two kernels: one for element-wise operations on tensors e and g, and another for computing derivatives. The first kernel (_fg_kernel) takes 5 parameters: e, g, h, n_elements, and BLOCK_SIZE. It computes f = e * sigmoid(e) and h = f * g, storing the result in h. The second kernel (_DWf_DW_dfg_kernel) takes 5 parameters: DW, e, g, n_elements, and BLOCK_SIZE. It computes derivatives df, dg, and de based on the input tensors and stores them in the respective buffers.", - "description_2": "Use triton language to create kernels for element-wise tensor operations and derivative computations, with parameters for input tensors, number of elements, and block size.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef triton_cross_scan_flex(\n x: tl.tensor, # (B, C, H, W) | (B, H, W, C) | (B, 4, C, H, W) | (B, H, W, 4, C)\n y: tl.tensor, # (B, 4, C, H, W) | (B, H, W, 4, C)\n x_layout: tl.constexpr,\n y_layout: tl.constexpr,\n operation: tl.constexpr,\n onebyone: tl.constexpr,\n scans: tl.constexpr,\n BC: tl.constexpr,\n BH: tl.constexpr,\n BW: tl.constexpr,\n DC: tl.constexpr,\n DH: tl.constexpr,\n DW: tl.constexpr,\n NH: tl.constexpr,\n NW: tl.constexpr,\n):\n i_hw, i_c, i_b = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n i_h, i_w = (i_hw // NW), (i_hw % NW)\n _mask_h = (i_h * BH + tl.arange(0, BH)) < DH\n _mask_w = (i_w * BW + tl.arange(0, BW)) < DW\n _mask_hw = _mask_h[:, None] & _mask_w[None, :]\n _for_C = min(DC - i_c * BC, BC)\n\n pos_h = (i_h * BH + tl.arange(0, BH)[:, None])\n pos_w = (i_w * BW + tl.arange(0, BW)[None, :])\n neg_h = (DH - i_h * BH - 1 - tl.arange(0, BH)[:, None])\n neg_w = (DW - i_w * BW - 1 - tl.arange(0, BW)[None, :])\n if scans == 0:\n HWRoute0 = pos_h * DW + pos_w\n HWRoute1 = pos_w * DH + pos_h\n HWRoute2 = neg_h * DW + neg_w\n HWRoute3 = neg_w * DH + neg_h\n elif scans == 1:\n HWRoute0 = pos_h * DW + pos_w\n HWRoute1 = HWRoute0\n HWRoute2 = HWRoute0\n HWRoute3 = HWRoute0\n elif scans == 2:\n HWRoute0 = pos_h * DW + pos_w\n HWRoute1 = HWRoute0\n HWRoute2 = neg_h * DW + neg_w\n HWRoute3 = HWRoute2 \n\n _tmp1 = DC * DH * DW\n\n y_ptr_base = y + i_b * 4 * _tmp1 + (i_c * BC * DH * DW if y_layout == 0 else i_c * BC)\n if y_layout == 0:\n p_y1 = y_ptr_base + HWRoute0\n p_y2 = y_ptr_base + _tmp1 + HWRoute1\n p_y3 = y_ptr_base + 2 * _tmp1 + HWRoute2\n p_y4 = y_ptr_base + 3 * _tmp1 + HWRoute3\n else:\n p_y1 = y_ptr_base + HWRoute0 * 4 * DC\n p_y2 = y_ptr_base + DC + HWRoute1 * 4 * DC\n p_y3 = y_ptr_base + 2 * DC + HWRoute2 * 4 * DC\n p_y4 = y_ptr_base + 3 * DC + HWRoute3 * 4 * DC \n \n if onebyone == 0:\n x_ptr_base = x + i_b * _tmp1 + (i_c * BC * DH * DW if x_layout == 0 else i_c * BC)\n if x_layout == 0:\n p_x = x_ptr_base + HWRoute0\n else:\n p_x = x_ptr_base + HWRoute0 * DC\n\n if operation == 0:\n for idxc in range(_for_C):\n _idx_x = idxc * DH * DW if x_layout == 0 else idxc\n _idx_y = idxc * DH * DW if y_layout == 0 else idxc\n _x = tl.load(p_x + _idx_x, mask=_mask_hw)\n tl.store(p_y1 + _idx_y, _x, mask=_mask_hw)\n tl.store(p_y2 + _idx_y, _x, mask=_mask_hw)\n tl.store(p_y3 + _idx_y, _x, mask=_mask_hw)\n tl.store(p_y4 + _idx_y, _x, mask=_mask_hw)\n elif operation == 1:\n for idxc in range(_for_C):\n _idx_x = idxc * DH * DW if x_layout == 0 else idxc\n _idx_y = idxc * DH * DW if y_layout == 0 else idxc\n _y1 = tl.load(p_y1 + _idx_y, mask=_mask_hw)\n _y2 = tl.load(p_y2 + _idx_y, mask=_mask_hw)\n _y3 = tl.load(p_y3 + _idx_y, mask=_mask_hw)\n _y4 = tl.load(p_y4 + _idx_y, mask=_mask_hw)\n tl.store(p_x + _idx_x, _y1 + _y2 + _y3 + _y4, mask=_mask_hw)\n\n else:\n x_ptr_base = x + i_b * 4 * _tmp1 + (i_c * BC * DH * DW if x_layout == 0 else i_c * BC)\n if x_layout == 0:\n p_x1 = x_ptr_base + HWRoute0\n p_x2 = p_x1 + _tmp1\n p_x3 = p_x2 + _tmp1\n p_x4 = p_x3 + _tmp1 \n else:\n p_x1 = x_ptr_base + HWRoute0 * 4 * DC\n p_x2 = p_x1 + DC\n p_x3 = p_x2 + DC\n p_x4 = p_x3 + DC \n \n if operation == 0:\n for idxc in range(_for_C):\n _idx_x = idxc * DH * DW if x_layout == 0 else idxc\n _idx_y = idxc * DH * DW if y_layout == 0 else idxc\n tl.store(p_y1 + _idx_y, tl.load(p_x1 + _idx_x, mask=_mask_hw), mask=_mask_hw)\n tl.store(p_y2 + _idx_y, tl.load(p_x2 + _idx_x, mask=_mask_hw), mask=_mask_hw)\n tl.store(p_y3 + _idx_y, tl.load(p_x3 + _idx_x, mask=_mask_hw), mask=_mask_hw)\n tl.store(p_y4 + _idx_y, tl.load(p_x4 + _idx_x, mask=_mask_hw), mask=_mask_hw)\n else:\n for idxc in range(_for_C):\n _idx_x = idxc * DH * DW if x_layout == 0 else idxc\n _idx_y = idxc * DH * DW if y_layout == 0 else idxc\n tl.store(p_x1 + _idx_x, tl.load(p_y1 + _idx_y), mask=_mask_hw)\n tl.store(p_x2 + _idx_x, tl.load(p_y2 + _idx_y), mask=_mask_hw)\n tl.store(p_x3 + _idx_x, tl.load(p_y3 + _idx_y), mask=_mask_hw)\n tl.store(p_x4 + _idx_x, tl.load(p_y4 + _idx_y), mask=_mask_hw)\n\n\nclass CrossScanTritonF(torch.autograd.Function):\n @staticmethod\n def forward(ctx, x: torch.Tensor, in_channel_first=True, out_channel_first=True, one_by_one=False, scans=0):\n if one_by_one:\n if in_channel_first:\n B, _, C, H, W = x.shape\n else:\n B, H, W, _, C = x.shape\n else:\n if in_channel_first:\n B, C, H, W = x.shape\n else:\n B, H, W, C = x.shape\n B, C, H, W = int(B), int(C), int(H), int(W)\n BC, BH, BW = 1, 32, 32\n NH, NW, NC = triton.cdiv(H, BH), triton.cdiv(W, BW), triton.cdiv(C, BC)\n \n ctx.in_channel_first = in_channel_first\n ctx.out_channel_first = out_channel_first\n ctx.one_by_one = one_by_one\n ctx.scans = scans\n ctx.shape = (B, C, H, W)\n ctx.triton_shape = (BC, BH, BW, NC, NH, NW)\n\n y = x.new_empty((B, 4, C, H * W)) if out_channel_first else x.new_empty((B, H * W, 4, C))\n triton_cross_scan_flex[(NH * NW, NC, B)](\n x.contiguous(), y, \n (0 if in_channel_first else 1), (0 if out_channel_first else 1), 0, (0 if not one_by_one else 1), scans, \n BC, BH, BW, C, H, W, NH, NW\n )\n return y\n \n @staticmethod\n def backward(ctx, y: torch.Tensor):\n in_channel_first = ctx.in_channel_first\n out_channel_first = ctx.out_channel_first\n one_by_one = ctx.one_by_one\n scans = ctx.scans\n B, C, H, W = ctx.shape\n BC, BH, BW, NC, NH, NW = ctx.triton_shape\n if one_by_one:\n x = y.new_empty((B, 4, C, H, W)) if in_channel_first else y.new_empty((B, H, W, 4, C))\n else:\n x = y.new_empty((B, C, H, W)) if in_channel_first else y.new_empty((B, H, W, C))\n \n triton_cross_scan_flex[(NH * NW, NC, B)](\n x, y.contiguous(), \n (0 if in_channel_first else 1), (0 if out_channel_first else 1), 1, (0 if not one_by_one else 1), scans,\n BC, BH, BW, C, H, W, NH, NW\n )\n return x, None, None, None, None\n\n\nclass CrossMergeTritonF(torch.autograd.Function):\n @staticmethod\n def forward(ctx, y: torch.Tensor, in_channel_first=True, out_channel_first=True, one_by_one=False, scans=0):\n if out_channel_first:\n B, _, C, H, W = y.shape\n else:\n B, H, W, _, C = y.shape\n B, C, H, W = int(B), int(C), int(H), int(W)\n BC, BH, BW = 1, 32, 32\n NH, NW, NC = triton.cdiv(H, BH), triton.cdiv(W, BW), triton.cdiv(C, BC)\n ctx.in_channel_first = in_channel_first\n ctx.out_channel_first = out_channel_first\n ctx.one_by_one = one_by_one\n ctx.scans = scans\n ctx.shape = (B, C, H, W)\n ctx.triton_shape = (BC, BH, BW, NC, NH, NW)\n if one_by_one:\n x = y.new_empty((B, 4, C, H * W)) if in_channel_first else y.new_empty((B, H * W, 4, C))\n else:\n x = y.new_empty((B, C, H * W)) if in_channel_first else y.new_empty((B, H * W, C))\n triton_cross_scan_flex[(NH * NW, NC, B)](\n x, y.contiguous(), \n (0 if in_channel_first else 1), (0 if out_channel_first else 1), 1, (0 if not one_by_one else 1), scans,\n BC, BH, BW, C, H, W, NH, NW\n )\n return x\n \n @staticmethod\n def backward(ctx, x: torch.Tensor):\n in_channel_first = ctx.in_channel_first\n out_channel_first = ctx.out_channel_first\n one_by_one = ctx.one_by_one\n scans = ctx.scans\n B, C, H, W = ctx.shape\n BC, BH, BW, NC, NH, NW = ctx.triton_shape\n y = x.new_empty((B, 4, C, H, W)) if out_channel_first else x.new_empty((B, H, W, 4, C))\n triton_cross_scan_flex[(NH * NW, NC, B)](\n x.contiguous(), y, \n (0 if in_channel_first else 1), (0 if out_channel_first else 1), 0, (0 if not one_by_one else 1), scans,\n BC, BH, BW, C, H, W, NH, NW\n )\n return y, None, None, None, None, None\n\n\ndef cross_scan_fn(x: torch.Tensor, in_channel_first=True, out_channel_first=True, one_by_one=False, scans=0, force_torch=False):\n CSF = CrossScanTritonF if WITH_TRITON and x.is_cuda and (not force_torch) else CrossScanF\n with torch.cuda.device(x.device):\n return CSF.apply(x, in_channel_first, out_channel_first, one_by_one, scans)\n\n\ndef cross_merge_fn(y: torch.Tensor, in_channel_first=True, out_channel_first=True, one_by_one=False, scans=0, force_torch=False):\n CMF = CrossMergeTritonF if WITH_TRITON and y.is_cuda and (not force_torch) else CrossMergeF\n with torch.cuda.device(y.device):\n return CMF.apply(y, in_channel_first, out_channel_first, one_by_one, scans)\n", - "description_1": "Use triton language to implement a flexible cross scan and merge operation on tensors. The kernel function 'triton_cross_scan_flex' takes 14 parameters: two tensors (x and y), and 12 constexpr parameters that define the layout, operation type, and dimensions. The function performs different operations based on the 'operation' and 'scans' parameters, storing results in the output tensor y. The 'CrossScanTritonF' and 'CrossMergeTritonF' classes wrap this kernel for forward and backward passes, handling different tensor layouts and operations. The 'cross_scan_fn' and 'cross_merge_fn' functions are used to apply these operations, selecting between Triton and PyTorch implementations based on the environment.", - "description_2": "Use triton language to create a kernel for cross scan and merge operations on tensors, with support for different layouts and operations, and wrap it for use in PyTorch autograd.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n# Forward Triton kernel for Swish-Gated Linear Units (Swiglu)\n@triton.autotune(\n configs=[\n triton.Config({'BLOCK_N': 32}),\n triton.Config({'BLOCK_N': 64}),\n triton.Config({'BLOCK_N': 128}),\n triton.Config({'BLOCK_N': 256}),\n triton.Config({'BLOCK_N': 512}),\n triton.Config({'BLOCK_N': 1024}),\n ],\n key=['ncols'],\n)\n@triton.jit\ndef _swiglu_fwd_kernel(\n X, Y, OUT, stride_x_row, stride_y_row, stride_out_row, ncols, BLOCK_N: tl.constexpr\n):\n # Map the program id to the row of X and Y it should compute.\n row = tl.program_id(0)\n start_col = tl.program_id(1) * BLOCK_N\n X += row * stride_x_row\n Y += row * stride_y_row\n OUT += row * stride_out_row\n cols = start_col + tl.arange(0, BLOCK_N)\n x = tl.load(X + cols, mask=cols < ncols, other=0.).to(tl.float32)\n y = tl.load(Y + cols, mask=cols < ncols, other=0.).to(tl.float32)\n out = x * tl.sigmoid(x) * y\n tl.store(OUT + cols, out, mask=cols < ncols)\n\n# Function to invoke the forward kernel\ndef _swiglu_fwd(xy, out=None):\n if xy.stride(-1) != 1:\n xy = xy.contiguous()\n batch_shape = xy.shape[:-1]\n xy = xy.reshape(-1, xy.shape[-1])\n x, y = xy.chunk(2, dim=-1)\n if out is None:\n out = torch.empty_like(x)\n else:\n out = out.reshape(-1, out.shape[-1])\n assert out.shape == x.shape\n assert out.stride(-1) == 1\n M, N = x.shape\n grid = lambda META: (M, triton.cdiv(N, META['BLOCK_N']))\n with torch.cuda.device(x.device.index):\n _swiglu_fwd_kernel[grid](x, y, out, x.stride(0), y.stride(0), out.stride(0), N)\n return out.reshape(*batch_shape, out.shape[-1])\n\n# Backward Triton kernel for Swish-Gated Linear Units (Swiglu)\n@triton.autotune(\n configs=[\n triton.Config({'BLOCK_N': 32}),\n triton.Config({'BLOCK_N': 64}),\n triton.Config({'BLOCK_N': 128}),\n triton.Config({'BLOCK_N': 256}),\n triton.Config({'BLOCK_N': 512}),\n triton.Config({'BLOCK_N': 1024}),\n ],\n key=['ncols'],\n)\n@triton.heuristics({\"RECOMPUTE_OUTPUT\": lambda args: args[\"OUT\"] is not None})\n@triton.jit\ndef _swiglu_bwd_kernel(\n X, Y, DOUT, OUT, DX, DY, stride_x_row, stride_y_row, stride_dout_row,\n stride_out_row, stride_dx_row, stride_dy_row, ncols, BLOCK_N: tl.constexpr,\n RECOMPUTE_OUTPUT: tl.constexpr\n):\n # Map the program id to the row of X and Y it should compute.\n row = tl.program_id(0)\n start_col = tl.program_id(1) * BLOCK_N\n X += row * stride_x_row\n Y += row * stride_y_row\n DOUT += row * stride_dout_row\n if RECOMPUTE_OUTPUT:\n OUT += row * stride_out_row\n DX += row * stride_dx_row\n DY += row * stride_dy_row\n cols = start_col + tl.arange(0, BLOCK_N)\n x = tl.load(X + cols, mask=cols < ncols, other=0.).to(tl.float32)\n y = tl.load(Y + cols, mask=cols < ncols, other=0.).to(tl.float32)\n dout = tl.load(DOUT + cols, mask=cols < ncols, other=0.).to(tl.float32)\n x_sigmoid = tl.sigmoid(x)\n dx = x_sigmoid * (1 + x * (1 - x_sigmoid)) * y * dout\n dy = x * x_sigmoid * dout\n tl.store(DX + cols, dx, mask=cols < ncols)\n tl.store(DY + cols, dy, mask=cols < ncols)\n if RECOMPUTE_OUTPUT:\n out = x * x_sigmoid * y\n tl.store(OUT + cols, out, mask=cols < ncols)\n\n# Function to invoke the backward kernel\ndef _swiglu_bwd(xy, dout, dxy=None, recompute_output=False, out=None):\n if xy.stride(-1) != 1:\n xy = xy.contiguous()\n if dout.stride(-1) != 1:\n dout = dout.contiguous()\n batch_shape = xy.shape[:-1]\n xy = xy.reshape(-1, xy.shape[-1])\n x, y = xy.chunk(2, dim=-1)\n dout = dout.reshape(-1, dout.shape[-1])\n assert dout.shape == x.shape\n if dxy is None:\n dxy = torch.empty_like(xy)\n else:\n dxy = dxy.reshape(-1, dxy.shape[-1])\n assert dxy.shape == xy.shape\n dx, dy = dxy.chunk(2, dim=-1)\n assert dx.stride(-1) == 1\n assert dy.stride(-1) == 1\n if recompute_output:\n if out is None:\n out = torch.empty_like(x)\n else:\n out = out.reshape(-1, out.shape[-1])\n assert out.shape == x.shape\n assert out.stride(-1) == 1\n M, N = x.shape\n grid = lambda META: (M, triton.cdiv(N, META['BLOCK_N']))\n with torch.cuda.device(x.device.index):\n _swiglu_bwd_kernel[grid](\n x, y, dout, out if recompute_output else None, dx, dy, x.stride(0), y.stride(0),\n dout.stride(0), out.stride(0) if recompute_output else 0, dx.stride(0),\n dy.stride(0), N\n )\n if not recompute_output:\n return dxy.reshape(*batch_shape, dxy.shape[-1])\n else:\n return dxy.reshape(*batch_shape, dxy.shape[-1]), out.reshape(*batch_shape, out.shape[-1])\n", - "description_1": "Use triton language to implement forward and backward kernels for Swish-Gated Linear Units (Swiglu). The forward kernel takes 8 arguments: X, Y, OUT, stride_x_row, stride_y_row, stride_out_row, ncols, and BLOCK_N. It computes the element-wise Swiglu forward operation and stores the result in OUT. The backward kernel takes 14 arguments: X, Y, DOUT, OUT, DX, DY, stride_x_row, stride_y_row, stride_dout_row, stride_out_row, stride_dx_row, stride_dy_row, ncols, and BLOCK_N, and an additional RECOMPUTE_OUTPUT flag. It computes the gradients of X and Y using the Swiglu backward operation.", - "description_2": "Use triton language to compute the Swiglu activation's forward pass with 8 parameters, and its backward pass with 14 parameters and a recompute flag, to efficiently perform deep learning operations on GPUs.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.heuristics({\"HAS_BIAS\": lambda args: args[\"B\"] is not None})\n@triton.heuristics({\"HAS_Z\": lambda args: args[\"Z\"] is not None})\n@triton.jit\ndef _layer_norm_fwd_1pass_kernel(\n X, # pointer to the input\n Y, # pointer to the output\n W, # pointer to the weights\n B, # pointer to the biases\n Z, # pointer to the other branch\n Mean, # pointer to the mean\n Rstd, # pointer to the 1/std\n stride_x_row, # how much to increase the pointer when moving by 1 row\n stride_y_row,\n stride_z_row,\n M, # number of rows in X\n N, # number of columns in X\n eps, # epsilon to avoid division by zero\n BLOCK_N: tl.constexpr,\n HAS_BIAS: tl.constexpr,\n HAS_Z: tl.constexpr,\n NORM_BEFORE_GATE: tl.constexpr,\n IS_RMS_NORM: tl.constexpr,\n):\n # Map the program id to the row of X and Y it should compute.\n row = tl.program_id(0)\n group = tl.program_id(1)\n X += row * stride_x_row + group * N\n Y += row * stride_y_row + group * N\n if HAS_Z:\n Z += row * stride_z_row + group * N\n if not IS_RMS_NORM:\n Mean += group * M\n Rstd += group * M\n W += group * N\n if HAS_BIAS:\n B += group * N\n # Compute mean and variance\n cols = tl.arange(0, BLOCK_N)\n x = tl.load(X + cols, mask=cols < N, other=0.).to(tl.float32)\n if HAS_Z and not NORM_BEFORE_GATE:\n z = tl.load(Z + cols, mask=cols < N).to(tl.float32)\n x *= z * tl.sigmoid(z)\n if not IS_RMS_NORM:\n mean = tl.sum(x, axis=0) / N\n tl.store(Mean + row, mean)\n xbar = tl.where(cols < N, x - mean, 0.)\n var = tl.sum(xbar * xbar, axis=0) / N\n else:\n xbar = tl.where(cols < N, x, 0.)\n var = tl.sum(xbar * xbar, axis=0) / N\n rstd = 1 / tl.sqrt(var + eps)\n tl.store(Rstd + row, rstd)\n # Normalize and apply linear transformation\n mask = cols < N\n w = tl.load(W + cols, mask=mask).to(tl.float32)\n if HAS_BIAS:\n b = tl.load(B + cols, mask=mask).to(tl.float32)\n x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd\n y = x_hat * w + b if HAS_BIAS else x_hat * w\n if HAS_Z and NORM_BEFORE_GATE:\n z = tl.load(Z + cols, mask=mask).to(tl.float32)\n y *= z * tl.sigmoid(z)\n # Write output\n tl.store(Y + cols, y, mask=mask)\n\n\ndef _layer_norm_fwd(x, weight, bias, eps, z=None, out=None, group_size=None, norm_before_gate=True, is_rms_norm=False):\n M, N = x.shape\n if group_size is None:\n group_size = N\n assert N % group_size == 0\n ngroups = N // group_size\n assert x.stride(-1) == 1\n if z is not None:\n assert z.stride(-1) == 1\n assert z.shape == (M, N)\n assert weight.shape == (N,)\n assert weight.stride(-1) == 1\n if bias is not None:\n assert bias.stride(-1) == 1\n assert bias.shape == (N,)\n # allocate output\n if out is not None:\n assert out.shape == x.shape\n else:\n out = torch.empty_like(x)\n assert out.stride(-1) == 1\n mean = torch.empty((ngroups * M, ), dtype=torch.float32, device=x.device) if not is_rms_norm else None\n rstd = torch.empty((ngroups * M, ), dtype=torch.float32, device=x.device)\n # Less than 64KB per feature: enqueue fused kernel\n MAX_FUSED_SIZE = 65536 // x.element_size()\n BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(group_size))\n if group_size > BLOCK_N:\n raise RuntimeError(\"This layer norm doesn't support feature dim >= 64KB.\")\n # heuristics for number of warps\n num_warps = min(max(BLOCK_N // 256, 1), 8)\n grid = (M, ngroups)\n with torch.cuda.device(x.device.index):\n _layer_norm_fwd_1pass_kernel[grid](x, out, weight, bias, z, mean, rstd,\n x.stride(0), out.stride(0), z.stride(0) if z is not None else 0,\n M, group_size, eps,\n BLOCK_N=BLOCK_N,\n NORM_BEFORE_GATE=norm_before_gate,\n IS_RMS_NORM=is_rms_norm,\n num_warps=num_warps)\n return out, mean, rstd\n", - "description_1": "Use triton language to implement a layer normalization forward pass kernel with parameters for input, output, weights, biases, additional branch, mean, and reciprocal standard deviation. The kernel computes mean and variance, normalizes the input, applies a linear transformation, and optionally applies a gating mechanism. The forward function sets up the kernel execution with appropriate grid and block sizes.", - "description_2": "Use triton language to implement a layer normalization forward pass kernel and its corresponding Python function to execute the kernel with specified parameters and configurations.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef _selective_scan_update_kernel(\n # Pointers to matrices\n state_ptr, x_ptr, dt_ptr, dt_bias_ptr, A_ptr, B_ptr, C_ptr, D_ptr, z_ptr, out_ptr,\n # Matrix dimensions\n batch, nheads, dim, dstate, nheads_ngroups_ratio,\n # Strides\n stride_state_batch, stride_state_head, stride_state_dim, stride_state_dstate,\n stride_x_batch, stride_x_head, stride_x_dim,\n stride_dt_batch, stride_dt_head, stride_dt_dim,\n stride_dt_bias_head, stride_dt_bias_dim,\n stride_A_head, stride_A_dim, stride_A_dstate,\n stride_B_batch, stride_B_group, stride_B_dstate,\n stride_C_batch, stride_C_group, stride_C_dstate,\n stride_D_head, stride_D_dim,\n stride_z_batch, stride_z_head, stride_z_dim,\n stride_out_batch, stride_out_head, stride_out_dim,\n # Meta-parameters\n DT_SOFTPLUS: tl.constexpr,\n TIE_HDIM: tl.constexpr,\n BLOCK_SIZE_M: tl.constexpr,\n HAS_DT_BIAS: tl.constexpr,\n HAS_D: tl.constexpr,\n HAS_Z: tl.constexpr,\n BLOCK_SIZE_DSTATE: tl.constexpr,\n):\n pid_m = tl.program_id(axis=0)\n pid_b = tl.program_id(axis=1)\n pid_h = tl.program_id(axis=2)\n state_ptr += pid_b * stride_state_batch + pid_h * stride_state_head\n x_ptr += pid_b * stride_x_batch + pid_h * stride_x_head\n dt_ptr += pid_b * stride_dt_batch + pid_h * stride_dt_head\n if HAS_DT_BIAS:\n dt_bias_ptr += pid_h * stride_dt_bias_head\n A_ptr += pid_h * stride_A_head\n B_ptr += pid_b * stride_B_batch + (pid_h // nheads_ngroups_ratio) * stride_B_group\n C_ptr += pid_b * stride_C_batch + (pid_h // nheads_ngroups_ratio) * stride_C_group\n if HAS_Z:\n z_ptr += pid_b * stride_z_batch + pid_h * stride_z_head\n out_ptr += pid_b * stride_out_batch + pid_h * stride_out_head\n\n offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_n = tl.arange(0, BLOCK_SIZE_DSTATE)\n state_ptrs = state_ptr + (offs_m[:, None] * stride_state_dim + offs_n[None, :] * stride_state_dstate)\n x_ptrs = x_ptr + offs_m * stride_x_dim\n dt_ptrs = dt_ptr + offs_m * stride_dt_dim\n if HAS_DT_BIAS:\n dt_bias_ptrs = dt_bias_ptr + offs_m * stride_dt_bias_dim\n if HAS_D:\n D_ptr += pid_h * stride_D_head\n A_ptrs = A_ptr + (offs_m[:, None] * stride_A_dim + offs_n[None, :] * stride_A_dstate)\n B_ptrs = B_ptr + offs_n * stride_B_dstate\n C_ptrs = C_ptr + offs_n * stride_C_dstate\n if HAS_D:\n D_ptrs = D_ptr + offs_m * stride_D_dim\n if HAS_Z:\n z_ptrs = z_ptr + offs_m * stride_z_dim\n out_ptrs = out_ptr + offs_m * stride_out_dim\n\n state = tl.load(state_ptrs, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate), other=0.0)\n x = tl.load(x_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)\n if not TIE_HDIM:\n dt = tl.load(dt_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)\n if HAS_DT_BIAS:\n dt += tl.load(dt_bias_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)\n if DT_SOFTPLUS:\n dt = tl.where(dt <= 20.0, tl.math.log1p(tl.exp(dt)), dt)\n A = tl.load(A_ptrs, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate), other=0.0).to(tl.float32)\n dA = tl.exp(A * dt[:, None])\n else:\n dt = tl.load(dt_ptr).to(tl.float32)\n if HAS_DT_BIAS:\n dt += tl.load(dt_bias_ptr).to(tl.float32)\n if DT_SOFTPLUS:\n dt = tl.where(dt <= 20.0, tl.math.log1p(tl.exp(dt)), dt)\n A = tl.load(A_ptr).to(tl.float32)\n dA = tl.exp(A * dt) # scalar, not a matrix\n\n B = tl.load(B_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32)\n C = tl.load(C_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32)\n if HAS_D:\n D = tl.load(D_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)\n if HAS_Z:\n z = tl.load(z_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)\n\n if not TIE_HDIM:\n dB = B[None, :] * dt[:, None]\n else:\n dB = B * dt # vector of size (dstate,)\n state = state * dA + dB * x[:, None]\n tl.store(state_ptrs, state, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate))\n out = tl.sum(state * C[None, :], axis=1)\n if HAS_D:\n out += x * D\n if HAS_Z:\n out *= z * tl.sigmoid(z)\n tl.store(out_ptrs, out, mask=offs_m < dim)\n\n\ndef selective_state_update(state, x, dt, A, B, C, D=None, z=None, dt_bias=None, dt_softplus=False):\n \"\"\"\n Argument:\n state: (batch, dim, dstate) or (batch, nheads, dim, dstate)\n x: (batch, dim) or (batch, nheads, dim)\n dt: (batch, dim) or (batch, nheads, dim)\n A: (dim, dstate) or (nheads, dim, dstate)\n B: (batch, dstate) or (batch, ngroups, dstate)\n C: (batch, dstate) or (batch, ngroups, dstate)\n D: (dim,) or (nheads, dim)\n z: (batch, dim) or (batch, nheads, dim)\n dt_bias: (dim,) or (nheads, dim)\n Return:\n out: (batch, dim) or (batch, nheads, dim)\n \"\"\"\n has_heads = state.dim() > 3\n if state.dim() == 3:\n state = state.unsqueeze(1)\n if x.dim() == 2:\n x = x.unsqueeze(1)\n if dt.dim() == 2:\n dt = dt.unsqueeze(1)\n if A.dim() == 2:\n A = A.unsqueeze(0)\n if B.dim() == 2:\n B = B.unsqueeze(1)\n if C.dim() == 2:\n C = C.unsqueeze(1)\n if D is not None and D.dim() == 1:\n D = D.unsqueeze(0)\n if z is not None and z.dim() == 2:\n z = z.unsqueeze(1)\n if dt_bias is not None and dt_bias.dim() == 1:\n dt_bias = dt_bias.unsqueeze(0)\n batch, nheads, dim, dstate = state.shape\n assert x.shape == (batch, nheads, dim)\n assert dt.shape == x.shape\n assert A.shape == (nheads, dim, dstate)\n ngroups = B.shape[1]\n assert nheads % ngroups == 0, \"nheads must be divisible by ngroups\"\n assert B.shape == (batch, ngroups, dstate)\n assert C.shape == B.shape\n if D is not None:\n assert D.shape == (nheads, dim)\n if z is not None:\n assert z.shape == x.shape\n if dt_bias is not None:\n assert dt_bias.shape == (nheads, dim)\n out = torch.empty_like(x)\n grid = lambda META: (triton.cdiv(dim, META['BLOCK_SIZE_M']), batch, nheads)\n z_strides = ((z.stride(0), z.stride(1), z.stride(2)) if z is not None else (0, 0, 0))\n # We don't want autotune since it will overwrite the state\n # We instead tune by hand.\n BLOCK_SIZE_M, num_warps = ((32, 4) if dstate <= 16\n else ((16, 4) if dstate <= 32 else\n ((8, 4) if dstate <= 64 else\n ((4, 4) if dstate <= 128 else\n ((4, 8))))))\n tie_hdim = A.stride(-1) == 0 and A.stride(-2) == 0 and dt.stride(-1) == 0 and dt_bias.stride(-1) == 0\n with torch.cuda.device(x.device.index):\n _selective_scan_update_kernel[grid](\n state, x, dt, dt_bias, A, B, C, D, z, out,\n batch, nheads, dim, dstate, nheads // ngroups,\n state.stride(0), state.stride(1), state.stride(2), state.stride(3),\n x.stride(0), x.stride(1), x.stride(2),\n dt.stride(0), dt.stride(1), dt.stride(2),\n *(dt_bias.stride(0), dt_bias.stride(1)) if dt_bias is not None else 0,\n A.stride(0), A.stride(1), A.stride(2),\n B.stride(0), B.stride(1), B.stride(2),\n C.stride(0), C.stride(1), C.stride(2),\n *(D.stride(0), D.stride(1)) if D is not None else 0,\n z_strides[0], z_strides[1], z_strides[2],\n out.stride(0), out.stride(1), out.stride(2),\n dt_softplus,\n tie_hdim,\n BLOCK_SIZE_M,\n num_warps=num_warps,\n )\n if not has_heads:\n out = out.squeeze(1)\n return out\n", - "description_1": "Use triton language to define a kernel '_selective_scan_update_kernel' that performs a matrix update operation based on input state, x, dt, A, B, C, D, z, dt_bias, and a set of meta-parameters for matrix dimensions and strides. The kernel applies transformations involving dt, A, B, C, D, and z with optional conditions controlled by meta-parameters. Implement a function 'selective_state_update' that prepares the data and grid for the kernel launch, handling broadcasting and optional arguments, and invokes the kernel with the prepared parameters.", - "description_2": "Use triton language to implement a selective matrix update kernel that applies element-wise transformations using pointers and meta-parameters, and a wrapper function to configure and launch this kernel efficiently with given input tensors.", - "difficulty": 4 - }, - { - "code": "import math\nimport torch\nimport triton\nimport triton.language as tl\n\n@triton.autotune(\n configs=[\n triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64}, num_stages=3, num_warps=8),\n triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2),\n triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2),\n triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=2),\n ],\n key=['chunk_size', 'K', 'IS_CAUSAL'],\n)\n@triton.jit\ndef _bmm_chunk_fwd_kernel(\n a_ptr, b_ptr, out_ptr, seq_idx_ptr,\n seqlen, chunk_size, K, ngroups,\n stride_a_batch, stride_a_seqlen, stride_a_head, stride_ak,\n stride_b_batch, stride_b_seqlen, stride_b_head, stride_bk,\n stride_out_batch, stride_out_chunk, stride_out_head, stride_outm, stride_outn,\n stride_seq_idx_batch, stride_seq_idx_seqlen,\n IS_CAUSAL: tl.constexpr,\n dot_dtype: tl.constexpr,\n HAS_SEQ_IDX: tl.constexpr,\n BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,\n):\n pid_b = tl.program_id(axis=1)\n pid_ch = tl.program_id(axis=2)\n pid_c = pid_ch // ngroups\n pid_h = pid_ch - pid_c * ngroups\n num_pid_n = tl.cdiv(chunk_size, BLOCK_SIZE_N)\n pid_m = tl.program_id(axis=0) // num_pid_n\n pid_n = tl.program_id(axis=0) % num_pid_n\n if IS_CAUSAL:\n if pid_n * BLOCK_SIZE_N >= (pid_m + 1) * BLOCK_SIZE_M:\n return\n a_ptr += pid_b * stride_a_batch + pid_c * chunk_size * stride_a_seqlen + pid_h * stride_a_head\n b_ptr += pid_b * stride_b_batch + pid_c * chunk_size * stride_b_seqlen + pid_h * stride_b_head\n if HAS_SEQ_IDX:\n seq_idx_ptr += pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen\n\n offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n offs_k = tl.arange(0, BLOCK_SIZE_K)\n a_ptrs = a_ptr + (offs_m[:, None] * stride_a_seqlen + offs_k[None, :] * stride_ak)\n b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_n[None, :] * stride_b_seqlen)\n chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)\n\n acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):\n a = tl.load(a_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < K - k * BLOCK_SIZE_K), other=0.0).to(dot_dtype)\n b = tl.load(b_ptrs, mask=(offs_k[:, None] < K - k * BLOCK_SIZE_K) & (offs_n[None, :] < chunk_size_limit), other=0.0).to(dot_dtype)\n acc += tl.dot(a, b)\n a_ptrs += BLOCK_SIZE_K * stride_ak\n b_ptrs += BLOCK_SIZE_K * stride_bk\n\n offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n if HAS_SEQ_IDX:\n chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)\n seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen, mask=offs_m < chunk_size_limit, other=-1)\n seq_idx_n = tl.load(seq_idx_ptr + offs_n * stride_seq_idx_seqlen, mask=offs_n < chunk_size_limit, other=-2)\n acc = tl.where(seq_idx_m[:, None] == seq_idx_n[None, :], acc, 0.0)\n out = acc.to(out_ptr.dtype.element_ty)\n\n out_ptr += pid_b * stride_out_batch + pid_c * stride_out_chunk + pid_h * stride_out_head\n out_ptrs = out_ptr + (stride_outm * offs_m[:, None] + offs_n[None, :] * stride_outn)\n tl.store(out_ptrs, out, mask=(offs_m[:, None] < chunk_size) & (offs_n[None, :] < chunk_size))\n\n\n@triton.autotune(\n configs=[\n triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_CS': 64}, num_stages=3, num_warps=8),\n triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_CS': 32}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_CS': 32}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_CS': 32}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_CS': 32}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_CS': 32}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_CS': 32}, num_stages=5, num_warps=2),\n triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_CS': 32}, num_stages=5, num_warps=2),\n triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_CS': 32}, num_stages=4, num_warps=2),\n ],\n key=['chunk_size', 'K'],\n)\n@triton.jit\ndef _bmm_chunk_bwd_kernel(\n a_ptr, dout_ptr, db_ptr, res_ptr,\n seqlen, chunk_size, K, ngroups,\n stride_a_batch, stride_a_seqlen, stride_a_head, stride_ak,\n stride_dout_batch, stride_dout_chunk, stride_dout_head, stride_dout_csize_m, stride_dout_csize_n,\n stride_db_batch, stride_db_seqlen, stride_db_head, stride_db_k,\n stride_res_batch, stride_res_seqlen, stride_res_head, stride_res_k,\n dot_dtype: tl.constexpr,\n HAS_RESIDUAL: tl.constexpr,\n BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_CS: tl.constexpr,\n):\n pid_b = tl.program_id(axis=1)\n pid_ch = tl.program_id(axis=2)\n pid_c = pid_ch // ngroups\n pid_h = pid_ch - pid_c * ngroups\n num_pid_n = tl.cdiv(K, BLOCK_SIZE_N)\n pid_m = tl.program_id(axis=0) // num_pid_n\n pid_n = tl.program_id(axis=0) % num_pid_n\n\n a_ptr += pid_b * stride_a_batch + pid_c * chunk_size * stride_a_seqlen + pid_h * stride_a_head\n dout_ptr += pid_b * stride_dout_batch + pid_c * stride_dout_chunk + pid_h * stride_dout_head\n\n offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n offs_cs = tl.arange(0, BLOCK_SIZE_CS)\n dout_ptrs = dout_ptr + (offs_m[:, None] * stride_dout_csize_n + offs_cs[None, :] * stride_dout_csize_m)\n a_ptrs = a_ptr + (offs_cs[:, None] * stride_a_seqlen + offs_n[None, :] * stride_ak)\n chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)\n\n acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n for cs in range(0, tl.cdiv(chunk_size_limit, BLOCK_SIZE_CS)):\n dout = tl.load(dout_ptrs, mask=(offs_m[:, None] < chunk_size) & (offs_cs[None, :] < chunk_size_limit - cs * BLOCK_SIZE_CS), other=0.0).to(dot_dtype)\n a = tl.load(a_ptrs, mask=(offs_cs[:, None] < chunk_size_limit - cs * BLOCK_SIZE_CS) & (offs_n[None, :] < K), other=0.0).to(dot_dtype)\n acc += tl.dot(dout, a)\n dout_ptrs += BLOCK_SIZE_CS * stride_dout_csize_m\n a_ptrs += BLOCK_SIZE_CS * stride_a_seqlen\n\n offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n if HAS_RESIDUAL:\n res_ptr += pid_b * stride_res_batch + pid_c * chunk_size * stride_res_seqlen + pid_h * stride_res_head\n res_ptrs = res_ptr + (offs_m[:, None] * stride_res_seqlen + offs_n[None, :] * stride_res_k)\n res = tl.load(res_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < K)).to(tl.float32)\n acc += res\n db = acc.to(db_ptr.dtype.element_ty)\n\n db_ptr += pid_b * stride_db_batch + pid_c * chunk_size * stride_db_seqlen + pid_h * stride_db_head\n db_ptrs = db_ptr + (offs_m[:, None] * stride_db_seqlen + offs_n[None, :] * stride_db_k)\n tl.store(db_ptrs, db, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < K))\n\n\ndef _bmm_chunk_fwd(a, b, chunk_size, seq_idx=None, causal=False, output_dtype=None):\n has_groups = a.dim() == 4\n if not has_groups:\n batch, seqlen, k = a.shape\n else:\n batch, seqlen, ngroups, k = a.shape\n assert b.shape == a.shape\n if seq_idx is not None:\n assert seq_idx.shape == (batch, seqlen)\n if a.stride(-1) != 1 and a.stride(1) != 1:\n a = a.contiguous()\n if b.stride(-1) != 1 and b.stride(1) != 1:\n b = b.contiguous()\n nchunks = math.ceil(seqlen / chunk_size)\n out_dtype = a.dtype if output_dtype is None else output_dtype\n out = torch.empty((batch, nchunks, chunk_size, chunk_size) if not has_groups else (batch, nchunks, ngroups, chunk_size, chunk_size),\n device=a.device, dtype=out_dtype)\n dot_dtype = (tl.bfloat16 if a.dtype == torch.bfloat16 or b.dtype == torch.bfloat16 else\n (tl.float16 if a.dtype == torch.float16 or b.dtype == torch.float16 else tl.float32))\n grid = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv(chunk_size, META['BLOCK_SIZE_N']),\n batch, nchunks if not has_groups else nchunks * ngroups)\n with torch.cuda.device(a.device.index):\n _bmm_chunk_fwd_kernel[grid](\n a, b, out, seq_idx,\n int(seqlen), int(chunk_size), int(k), int(ngroups if has_groups else 1),\n a.stride(0), a.stride(1), 0 if not has_groups else a.stride(2), a.stride(-1),\n b.stride(0), b.stride(1), 0 if not has_groups else b.stride(2), b.stride(-1),\n out.stride(0), out.stride(1), 0 if not has_groups else out.stride(2), out.stride(-2), out.stride(-1),\n *((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)),\n causal,\n dot_dtype,\n HAS_SEQ_IDX=seq_idx is not None,\n )\n return out\n\n\ndef _bmm_chunk_bwd(a, dout, residual=None, out=None):\n has_groups = a.dim() == 4\n if not has_groups:\n batch, seqlen, k = a.shape\n else:\n batch, seqlen, ngroups, k = a.shape\n nchunks, chunk_size = dout.shape[1], dout.shape[-1]\n if a.stride(-1) != 1 and a.stride(-2) != 1:\n a = a.contiguous()\n if dout.stride(-1) != 1 and dout.stride(-2) != 1:\n dout = dout.contiguous()\n if residual is not None:\n assert residual.shape == (batch, seqlen, k) if not has_groups else (batch, seqlen, ngroups, k)\n if residual.stride(-1) != 1 and residual.stride(1) != 1:\n residual = residual.contiguous()\n if out is not None:\n assert out.shape == a.shape\n assert out.stride(-1) == 1 or out.stride(1) == 1\n else:\n out = torch.empty_like(a)\n dot_dtype = (tl.bfloat16 if a.dtype == torch.bfloat16 or dout.dtype == torch.bfloat16 else\n (tl.float16 if a.dtype == torch.float16 or dout.dtype == torch.float16 else tl.float32))\n grid = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv(k, META['BLOCK_SIZE_N']), batch,\n nchunks if not has_groups else nchunks * ngroups)\n residual_strides = ((residual.stride(0), residual.stride(1), 0 if not has_groups else residual.stride(2),\n residual.stride(-1))\n if residual is not None else (0, 0, 0, 0))\n with torch.cuda.device(a.device.index):\n _bmm_chunk_bwd_kernel[grid](\n a, dout, out, residual,\n int(seqlen), int(chunk_size), int(k), int(ngroups if has_groups else 1),\n a.stride(0), a.stride(1), 0 if not has_groups else a.stride(2), a.stride(-1),\n dout.stride(0), dout.stride(1), 0 if not has_groups else dout.stride(2), dout.stride(-2), dout.stride(-1),\n out.stride(0), out.stride(1), 0 if not has_groups else out.stride(2), out.stride(-1),\n residual_strides[0], residual_strides[1], residual_strides[2], residual_strides[3],\n dot_dtype,\n HAS_RESIDUAL=residual is not None,\n )\n return out\n", - "description_1": "Use triton language to implement two kernels: _bmm_chunk_fwd_kernel and _bmm_chunk_bwd_kernel. The _bmm_chunk_fwd_kernel performs a batched matrix multiplication with optional sequence index masking and causal masking. It takes 24 parameters: pointers to input matrices, matrix dimensions, strides, and meta-parameters for configuration. The _bmm_chunk_bwd_kernel computes the gradient of the batched matrix multiplication with respect to one of the input matrices. It takes 23 parameters: pointers to input matrices, matrix dimensions, strides, and meta-parameters for configuration. Both kernels are called by their respective wrapper functions _bmm_chunk_fwd and _bmm_chunk_bwd, which handle input preparation and kernel invocation.", - "description_2": "Use triton language to create two kernels for forward and backward batched matrix multiplication with optional sequence and causal masking, and implement wrapper functions to prepare inputs and invoke these kernels.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.autotune(\n configs=[\n triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64}, num_stages=3, num_warps=8),\n triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2),\n triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2),\n triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=2),\n ],\n key=['chunk_size', 'hdim', 'dstate', 'IS_CAUSAL'],\n)\n@triton.jit\ndef _chunk_scan_fwd_kernel(\n cb_ptr, x_ptr, z_ptr, out_ptr, out_x_ptr, dt_ptr, dA_cumsum_ptr, seq_idx_ptr, C_ptr, prev_states_ptr, D_ptr,\n chunk_size, hdim, dstate,\n batch, seqlen, nheads_ngroups_ratio,\n stride_cb_batch, stride_cb_chunk, stride_cb_head, stride_cb_csize_m, stride_cb_csize_k,\n stride_x_batch, stride_x_seqlen, stride_x_head, stride_x_hdim,\n stride_z_batch, stride_z_seqlen, stride_z_head, stride_z_hdim,\n stride_out_batch, stride_out_seqlen, stride_out_head, stride_out_hdim,\n stride_dt_batch, stride_dt_chunk, stride_dt_head, stride_dt_csize,\n stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize,\n stride_seq_idx_batch, stride_seq_idx_seqlen,\n stride_C_batch, stride_C_seqlen, stride_C_head, stride_C_dstate,\n stride_states_batch, stride_states_chunk, stride_states_head, stride_states_hdim, stride_states_dstate,\n stride_D_head,\n IS_CAUSAL: tl.constexpr,\n HAS_D: tl.constexpr,\n D_HAS_HDIM: tl.constexpr,\n HAS_Z: tl.constexpr,\n HAS_SEQ_IDX: tl.constexpr,\n BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,\n BLOCK_SIZE_DSTATE: tl.constexpr,\n IS_TRITON_22: tl.constexpr,\n):\n # Triton kernel implementation\n pass\n\ndef _chunk_scan_fwd(cb, x, dt, dA_cumsum, C, states, D=None, z=None, seq_idx=None):\n batch, seqlen, nheads, headdim = x.shape\n _, _, nchunks, chunk_size = dt.shape\n _, _, ngroups, dstate = C.shape\n assert nheads % ngroups == 0\n assert C.shape == (batch, seqlen, ngroups, dstate)\n assert cb.shape == (batch, nchunks, ngroups, chunk_size, chunk_size)\n if z is not None:\n assert z.shape == x.shape\n if D is not None:\n assert D.shape == (nheads, headdim) or D.shape == (nheads,)\n assert dt.shape == (batch, nheads, nchunks, chunk_size)\n assert dA_cumsum.shape == (batch, nheads, nchunks, chunk_size)\n assert states.shape == (batch, nchunks, nheads, headdim, dstate)\n if seq_idx is not None:\n assert seq_idx.shape == (batch, seqlen)\n out = torch.empty(batch, seqlen, nheads, headdim, device=x.device, dtype=x.dtype)\n if z is not None:\n out_x = torch.empty(batch, seqlen, nheads, headdim, device=x.device, dtype=x.dtype)\n assert out_x.stride() == out.stride()\n else:\n out_x = None\n grid = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv(headdim, META['BLOCK_SIZE_N']),\n batch * nchunks, nheads)\n z_strides = ((z.stride(0), z.stride(1), z.stride(2), z.stride(3))\n if z is not None else (0, 0, 0, 0))\n _chunk_scan_fwd_kernel[grid](\n cb, x, z, out, out_x, dt, dA_cumsum, seq_idx, C, states, D,\n int(chunk_size), int(headdim), int(dstate),\n int(batch), int(seqlen), int(nheads // ngroups),\n cb.stride(0), cb.stride(1), cb.stride(2), cb.stride(3), cb.stride(4),\n x.stride(0), x.stride(1), x.stride(2), x.stride(3),\n z_strides[0], z_strides[1], z_strides[2], z_strides[3],\n out.stride(0), out.stride(1), out.stride(2), out.stride(3),\n dt.stride(0), dt.stride(2), dt.stride(1), dt.stride(3),\n dA_cumsum.stride(0), dA_cumsum.stride(2), dA_cumsum.stride(1), dA_cumsum.stride(3),\n *((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)),\n C.stride(0), C.stride(1), C.stride(2), C.stride(3),\n states.stride(0), states.stride(1), states.stride(2), states.stride(3), states.stride(4),\n D.stride(0) if D is not None else 0,\n True,\n D is not None,\n D.dim() == 2 if D is not None else True,\n BLOCK_SIZE_DSTATE=max(triton.next_power_of_2(int(dstate)), 16),\n HAS_Z=z is not None,\n HAS_SEQ_IDX=seq_idx is not None,\n IS_TRITON_22=TRITON_22,\n )\n return out, out_x\n", - "description_1": "Use triton language to implement a forward kernel for chunked scan operations. The kernel takes pointers to input matrices and performs operations based on the provided dimensions, strides, and meta-parameters. The function _chunk_scan_fwd sets up the necessary parameters and calls the kernel with appropriate grid dimensions.", - "description_2": "Use triton language to implement a forward kernel for chunked scan operations with input matrices and meta-parameters, and a function to set up and call this kernel.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\nfrom einops import rearrange\nfrom torch import Tensor\n\n# Kernel for backward pass of chunk scan with chunk state\n@triton.jit\ndef _chunk_scan_chunk_state_bwd_dx_kernel(\n # Pointers to matrices\n x_ptr, cb_ptr, dout_ptr, dt_ptr, dA_cumsum_ptr, seq_idx_ptr, D_ptr,\n b_ptr, dstates_ptr,\n dx_ptr, ddt_ptr, dD_ptr,\n # Matrix dimensions\n chunk_size, hdim, dstate,\n batch, seqlen, nheads_ngroups_ratio,\n # Strides\n stride_x_batch, stride_x_seqlen, stride_x_head, stride_x_hdim,\n stride_cb_batch, stride_cb_chunk, stride_cb_head, stride_cb_csize_m, stride_cb_csize_k,\n stride_dout_batch, stride_dout_seqlen, stride_dout_head, stride_dout_hdim,\n stride_dt_batch, stride_dt_chunk, stride_dt_head, stride_dt_csize,\n stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize,\n stride_seq_idx_batch, stride_seq_idx_seqlen,\n stride_D_head,\n stride_b_batch, stride_b_seqlen, stride_b_head, stride_b_dstate,\n stride_dstates_batch, stride_dstates_chunk, stride_dstates_head, stride_dstates_hdim, stride_dstates_dstate,\n stride_dx_batch, stride_dx_seqlen, stride_dx_head, stride_dx_hdim,\n stride_ddt_batch, stride_ddt_chunk, stride_ddt_head, stride_ddt_csize,\n stride_dD_batch, stride_dD_chunk, stride_dD_head, stride_dD_csize, stride_dD_hdim,\n # Meta-parameters\n HAS_D: tl.constexpr,\n D_HAS_HDIM: tl.constexpr,\n HAS_SEQ_IDX: tl.constexpr,\n BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,\n BLOCK_SIZE_DSTATE: tl.constexpr,\n IS_TRITON_22: tl.constexpr,\n):\n pid_bc = tl.program_id(axis=1)\n pid_c = pid_bc // batch\n pid_b = pid_bc - pid_c * batch\n pid_h = tl.program_id(axis=2)\n num_pid_n = tl.cdiv(hdim, BLOCK_SIZE_N)\n pid_m = tl.program_id(axis=0) // num_pid_n\n pid_n = tl.program_id(axis=0) % num_pid_n\n x_ptr += pid_b * stride_x_batch + pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head\n cb_ptr += pid_b * stride_cb_batch + pid_c * stride_cb_chunk + (pid_h // nheads_ngroups_ratio) * stride_cb_head\n dout_ptr += pid_b * stride_dout_batch + pid_c * chunk_size * stride_dout_seqlen + pid_h * stride_dout_head\n dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head\n ddt_ptr += pid_b * stride_ddt_batch + pid_c * stride_ddt_chunk + pid_h * stride_ddt_head\n dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head\n b_ptr += pid_b * stride_b_batch + pid_c * chunk_size * stride_b_seqlen + (pid_h // nheads_ngroups_ratio) * stride_b_head\n dstates_ptr += pid_b * stride_dstates_batch + pid_c * stride_dstates_chunk + pid_h * stride_dstates_head\n if HAS_SEQ_IDX:\n seq_idx_ptr += pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen\n\n offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n\n chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)\n\n acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n\n dA_cs_m = tl.load(dA_cumsum_ptr + offs_m * stride_dA_cs_csize, mask=offs_m < chunk_size_limit, other=0.0).to(tl.float32)\n\n dA_cs_last = tl.load(dA_cumsum_ptr + (chunk_size - 1) * stride_dA_cs_csize).to(tl.float32)\n if not HAS_SEQ_IDX:\n scale = tl.exp(dA_cs_last - dA_cs_m)\n else:\n seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen, mask=offs_m < chunk_size_limit, other=-1)\n seq_idx_last = tl.load(seq_idx_ptr + (chunk_size_limit - 1) * stride_seq_idx_seqlen)\n scale = tl.where(seq_idx_m == seq_idx_last, tl.exp(dA_cs_last - dA_cs_m), 0.0)\n offs_dstate = tl.arange(0, BLOCK_SIZE_DSTATE if IS_TRITON_22 and BLOCK_SIZE_DSTATE <= 128 else BLOCK_SIZE_K)\n b_ptrs = b_ptr + (offs_m[:, None] * stride_b_seqlen + offs_dstate[None, :] * stride_b_dstate)\n dstates_ptrs = dstates_ptr + (offs_n[None, :] * stride_dstates_hdim + offs_dstate[:, None] * stride_dstates_dstate)\n if IS_TRITON_22 and BLOCK_SIZE_DSTATE <= 128:\n b = tl.load(b_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_dstate[None, :] < dstate), other=0.0)\n dstates = tl.load(dstates_ptrs, mask=(offs_dstate[:, None] < dstate) & (offs_n[None, :] < hdim), other=0.0)\n dstates = dstates.to(b_ptr.dtype.element_ty)\n acc = tl.dot(b, dstates) * scale[:, None]\n else:\n for k in range(0, dstate, BLOCK_SIZE_K):\n b = tl.load(b_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_dstate[None, :] < dstate - k), other=0.0)\n dstates = tl.load(dstates_ptrs, mask=(offs_dstate[:, None] < dstate - k) & (offs_n[None, :] < hdim), other=0.0)\n dstates = dstates.to(b_ptr.dtype.element_ty)\n acc += tl.dot(b, dstates)\n b_ptrs += BLOCK_SIZE_K * stride_b_dstate\n dstates_ptrs += BLOCK_SIZE_K * stride_dstates_dstate\n acc *= scale[:, None]\n\n offs_k = tl.arange(0, BLOCK_SIZE_K)\n cb_ptrs = cb_ptr + (offs_m[:, None] * stride_cb_csize_m + offs_k[None, :] * stride_cb_csize_k)\n dout_ptrs = dout_ptr + (offs_k[:, None] * stride_dout_seqlen + offs_n[None, :] * stride_dout_hdim)\n dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize\n K_MAX = chunk_size_limit\n K_MIN = pid_m * BLOCK_SIZE_M\n cb_ptrs += K_MIN * stride_cb_csize_k\n dout_ptrs += K_MIN * stride_dout_seqlen\n dA_cumsum_ptrs += K_MIN * stride_dA_cs_csize\n for k in range(K_MIN, K_MAX, BLOCK_SIZE_K):\n k = tl.multiple_of(k, BLOCK_SIZE_K)\n cb = tl.load(cb_ptrs, mask=(offs_m[:, None] < chunk_size) & (offs_k[None, :] < K_MAX - k), other=0.0)\n dout = tl.load(dout_ptrs, mask=(offs_k[:, None] < K_MAX - k) & (offs_n[None, :] < hdim), other=0.0)\n dA_cs_k = tl.load(dA_cumsum_ptrs, mask=offs_k < K_MAX - k, other=0.0).to(tl.float32)\n cb *= tl.exp(dA_cs_k[None, :] - dA_cs_m[:, None])\n mask = k + offs_k[None, :] >= offs_m[:, None]\n cb = tl.where(mask, cb, 0.0)\n cb = cb.to(dout_ptr.dtype.element_ty)\n acc += tl.dot(cb, dout)\n cb_ptrs += BLOCK_SIZE_K * stride_cb_csize_k\n dout_ptrs += BLOCK_SIZE_K * stride_dout_seqlen\n dA_cumsum_ptrs += BLOCK_SIZE_K * stride_dA_cs_csize\n\n offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n dt_ptrs = dt_ptr + offs_m * stride_dt_csize\n dt_m = tl.load(dt_ptrs, mask=offs_m < chunk_size_limit, other=0.0).to(tl.float32)\n dx = acc * dt_m[:, None]\n dx_ptr += pid_b * stride_dx_batch + pid_c * chunk_size * stride_dx_seqlen + pid_h * stride_dx_head\n dx_ptrs = dx_ptr + (offs_m[:, None] * stride_dx_seqlen + offs_n[None, :] * stride_dx_hdim)\n if HAS_D:\n dout_res_ptrs = dout_ptr + (offs_m[:, None] * stride_dout_seqlen + offs_n[None, :] * stride_dout_hdim)\n dout_res = tl.load(dout_res_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0).to(tl.float32)\n if D_HAS_HDIM:\n D = tl.load(D_ptr + pid_h * stride_D_head + offs_n, mask=offs_n < hdim, other=0.0).to(tl.float32)\n else:\n D = tl.load(D_ptr + pid_h * stride_D_head).to(tl.float32)\n dx += dout_res * D\n tl.store(dx_ptrs, dx, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim))\n\n x_ptrs = x_ptr + (offs_m[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim)\n x = tl.load(x_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0).to(tl.float32)\n if HAS_D:\n dD_ptr += pid_b * stride_dD_batch + pid_c * stride_dD_chunk + pid_h * stride_dD_head + pid_m * stride_dD_csize\n if D_HAS_HDIM:\n dD_ptrs = dD_ptr + offs_n * stride_dD_hdim\n dD = tl.sum(dout_res * x, axis=0)\n tl.store(dD_ptrs, dD, mask=offs_n < hdim)\n else:\n dD = tl.sum(dout_res * x)\n tl.store(dD_ptr, dD)\n ddt = tl.sum(acc * x, axis=1)\n ddt_ptrs = ddt_ptr + offs_m * stride_ddt_csize\n tl.atomic_add(ddt_ptrs, ddt, mask=offs_m < chunk_size)\n\n# Function to call the kernel for backward pass\ndef _chunk_scan_chunk_state_bwd_dx(x, dt, dA_cumsum, B, CB, dout, dstates, D=None, seq_idx=None, dx=None):\n batch, seqlen, nheads, headdim = x.shape\n _, _, nchunks, chunk_size = dt.shape\n _, _, ngroups, dstate = B.shape\n assert nheads % ngroups == 0\n assert B.shape == (batch, seqlen, ngroups, dstate)\n assert CB.shape == (batch, nchunks, ngroups, chunk_size, chunk_size)\n assert dt.shape == (batch, nheads, nchunks, chunk_size)\n assert dA_cumsum.shape == dt.shape\n assert dout.shape == x.shape\n assert dstates.shape == (batch, nchunks, nheads, headdim, dstate)\n if seq_idx is not None:\n assert seq_idx.shape == (batch, seqlen)\n if D is not None:\n assert D.shape == (nheads, headdim) or D.shape == (nheads,)\n assert D.stride(-1) == 1\n BLOCK_SIZE_min = 32\n dD = torch.empty(triton.cdiv(chunk_size, BLOCK_SIZE_min), batch, nchunks, nheads,\n headdim if D.dim() == 2 else 1, device=D.device, dtype=torch.float32)\n else:\n dD = None\n dD_strides = ((dD.stride(0), dD.stride(1), dD.stride(2), dD.stride(3), dD.stride(4))\n if D is not None else (0, 0, 0, 0, 0))\n if dx is None:\n dx = torch.empty_like(x)\n else:\n assert dx.shape == x.shape\n ddt = torch.empty(batch, nheads, nchunks, chunk_size, device=dout.device, dtype=torch.float32)\n grid_dx = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv(headdim, META['BLOCK_SIZE_N']),\n batch * nchunks, nheads)\n with torch.cuda.device(x.device.index):\n _chunk_scan_chunk_state_bwd_dx_kernel[grid_dx](\n x, CB, dout, dt, dA_cumsum, seq_idx, D, B, dstates, dx, ddt, dD,\n int(chunk_size), int(headdim), int(dstate),\n int(batch), int(seqlen), int(nheads // ngroups),\n x.stride(0), x.stride(1), x.stride(2), x.stride(3),\n CB.stride(0), CB.stride(1), CB.stride(2), CB.stride(-1), CB.stride(-2),\n dout.stride(0), dout.stride(1), dout.stride(2), dout.stride(3),\n dt.stride(0), dt.stride(2), dt.stride(1), dt.stride(3),\n dA_cumsum.stride(0), dA_cumsum.stride(2), dA_cumsum.stride(1), dA_cumsum.stride(3),\n *((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)),\n D.stride(0) if D is not None else 0,\n B.stride(0), B.stride(1), B.stride(2), B.stride(3),\n dstates.stride(0), dstates.stride(1), dstates.stride(2), dstates.stride(3), dstates.stride(4),\n dx.stride(0), dx.stride(1), dx.stride(2), dx.stride(3),\n ddt.stride(0), ddt.stride(2), ddt.stride(1), ddt.stride(3),\n dD_strides[1], dD_strides[2], dD_strides[3], dD_strides[0], dD_strides[4],\n D is not None,\n D.dim() == 2 if D is not None else True,\n HAS_SEQ_IDX=seq_idx is not None,\n BLOCK_SIZE_DSTATE=max(triton.next_power_of_2(dstate), 16),\n IS_TRITON_22=TRITON_22\n )\n if D is not None:\n BLOCK_SIZE_actual = _chunk_scan_chunk_state_bwd_dx_kernel.best_config.kwargs[\"BLOCK_SIZE_M\"]\n n_valid_blocks = (chunk_size + BLOCK_SIZE_actual - 1) // BLOCK_SIZE_actual\n dD = dD[:n_valid_blocks].sum(dim=(0, 1, 2)).to(dtype=D.dtype)\n if D.dim() == 1:\n dD = rearrange(dD, \"h 1 -> h\")\n return dx, ddt.to(dtype=dt.dtype), dD\n", - "description_1": "Use triton language to implement a kernel and a Python function for the backward pass of a chunk scan operation with chunk state in a neural network. The kernel uses tensor pointers, matrix dimensions, strides, and various meta-parameters to perform computation on input data pointers and stores results in output data pointers. The function initializes and manages these data pointers and executes the kernel on a specific CUDA device, returning processed tensors.", - "description_2": "Implement a backward pass operation for chunk scan using Triton to manage tensor pointers, perform computations, and return results, ensuring compatibility with CUDA devices.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.autotune(\n configs=[\n triton.Config({'BLOCK_SIZE': 64}),\n triton.Config({'BLOCK_SIZE': 128}),\n triton.Config({'BLOCK_SIZE': 256}),\n triton.Config({'BLOCK_SIZE': 512}),\n triton.Config({'BLOCK_SIZE': 1024}),\n triton.Config({'BLOCK_SIZE': 2048}),\n ],\n key=['dim'],\n)\n@triton.jit\ndef _state_passing_fwd_kernel(\n states_ptr, out_ptr, final_states_ptr, dA_cs_ptr, initstates_ptr, seq_idx_ptr,\n dim, nchunks, seqlen, chunk_size,\n stride_states_batch, stride_states_chunk, stride_states_head, stride_states_dim,\n stride_out_batch, stride_out_chunk, stride_out_head, stride_out_dim,\n stride_final_states_batch, stride_final_states_head, stride_final_states_dim,\n stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head,\n stride_initstates_batch, stride_initstates_head, stride_initstates_dim,\n stride_seq_idx_batch, stride_seq_idx_seqlen,\n HAS_INITSTATES: tl.constexpr,\n HAS_SEQ_IDX: tl.constexpr,\n BLOCK_SIZE: tl.constexpr,\n):\n pid_b = tl.program_id(axis=1)\n pid_h = tl.program_id(axis=2)\n pid_m = tl.program_id(axis=0)\n states_ptr += pid_b * stride_states_batch + pid_h * stride_states_head\n dA_cs_ptr += pid_b * stride_dA_cs_batch + pid_h * stride_dA_cs_head\n out_ptr += pid_b * stride_out_batch + pid_h * stride_out_head\n final_states_ptr += pid_b * stride_final_states_batch + pid_h * stride_final_states_head\n if HAS_INITSTATES:\n initstates_ptr += pid_b * stride_initstates_batch + pid_h * stride_initstates_head\n if HAS_SEQ_IDX:\n seq_idx_ptr += pid_b * stride_seq_idx_batch\n\n offs_m = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)\n states_ptrs = states_ptr + offs_m * stride_states_dim\n out_ptrs = out_ptr + offs_m * stride_out_dim\n final_states_ptrs = final_states_ptr + offs_m * stride_final_states_dim\n\n if not HAS_INITSTATES:\n states = tl.zeros((BLOCK_SIZE, ), dtype=tl.float32)\n else:\n initstates_ptrs = initstates_ptr + offs_m * stride_initstates_dim\n states = tl.load(initstates_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)\n tl.store(out_ptrs, states, mask=offs_m < dim)\n out_ptrs += stride_out_chunk\n seq_idx = 0\n for c in range(nchunks):\n new_states = tl.load(states_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)\n dA_cs = tl.load(dA_cs_ptr).to(tl.float32)\n scale = tl.exp(dA_cs)\n if HAS_SEQ_IDX:\n seq_idx_new = tl.load(seq_idx_ptr + (min((c + 1) * chunk_size, seqlen) - 1) * stride_seq_idx_seqlen)\n scale = tl.where(seq_idx_new == seq_idx, scale, 0.0)\n seq_idx = seq_idx_new\n states = scale * states + new_states\n if c < nchunks - 1:\n tl.store(out_ptrs, states, mask=offs_m < dim)\n else:\n tl.store(final_states_ptrs, states, mask=offs_m < dim)\n states_ptrs += stride_states_chunk\n dA_cs_ptr += stride_dA_cs_chunk\n out_ptrs += stride_out_chunk\n\n@triton.autotune(\n configs=[\n triton.Config({'BLOCK_SIZE': 64}),\n triton.Config({'BLOCK_SIZE': 128}),\n triton.Config({'BLOCK_SIZE': 256}),\n triton.Config({'BLOCK_SIZE': 512}),\n triton.Config({'BLOCK_SIZE': 1024}),\n triton.Config({'BLOCK_SIZE': 2048}),\n ],\n key=['dim'],\n)\n@triton.jit\ndef _state_passing_bwd_kernel(\n dout_ptr, out_ptr, dA_cs_ptr, dfinal_states_ptr, seq_idx_ptr,\n dstates_ptr, ddA_cs_ptr, dinitstates_ptr, states_converted_ptr,\n dim, nchunks, seqlen, chunk_size,\n stride_dout_batch, stride_dout_chunk, stride_dout_head, stride_dout_dim,\n stride_out_batch, stride_out_chunk, stride_out_head, stride_out_dim,\n stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head,\n stride_dfinal_states_batch, stride_dfinal_states_head, stride_dfinal_states_dim,\n stride_seq_idx_batch, stride_seq_idx_seqlen,\n stride_dstates_batch, stride_dstates_chunk, stride_dstates_head, stride_dstates_dim,\n stride_ddA_cs_batch, stride_ddA_cs_chunk, stride_ddA_cs_head,\n stride_dinitstates_batch, stride_dinitstates_head, stride_dinitstates_dim,\n CONVERT_STATES: tl.constexpr,\n HAS_DFINAL_STATES: tl.constexpr,\n HAS_DINITSTATES: tl.constexpr,\n HAS_SEQ_IDX: tl.constexpr,\n BLOCK_SIZE: tl.constexpr,\n):\n pid_b = tl.program_id(axis=1)\n pid_h = tl.program_id(axis=2)\n pid_m = tl.program_id(axis=0)\n dstates_ptr += pid_b * stride_dstates_batch + pid_h * stride_dstates_head + (nchunks - 1) * stride_dstates_chunk\n dA_cs_ptr += pid_b * stride_dA_cs_batch + pid_h * stride_dA_cs_head + (nchunks - 1) * stride_dA_cs_chunk\n ddA_cs_ptr += pid_b * stride_ddA_cs_batch + pid_h * stride_ddA_cs_head + (nchunks - 1) * stride_ddA_cs_chunk + pid_m\n out_ptr += pid_b * stride_out_batch + pid_h * stride_out_head + (nchunks - 1) * stride_out_chunk\n dout_ptr += pid_b * stride_dout_batch + pid_h * stride_dout_head + (nchunks - 1) * stride_dout_chunk\n if CONVERT_STATES:\n states_converted_ptr += pid_b * stride_out_batch + pid_h * stride_out_head + (nchunks - 1) * stride_out_chunk\n if HAS_DFINAL_STATES:\n dfinal_states_ptr += pid_b * stride_dfinal_states_batch + pid_h * stride_dfinal_states_head\n if HAS_DINITSTATES:\n dinitstates_ptr += pid_b * stride_dinitstates_batch + pid_h * stride_dinitstates_head\n if HAS_SEQ_IDX:\n seq_idx_ptr += pid_b * stride_seq_idx_batch\n\n offs_m = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)\n dstates_ptrs = dstates_ptr + offs_m * stride_dstates_dim\n out_ptrs = out_ptr + offs_m * stride_out_dim\n dout_ptrs = dout_ptr + offs_m * stride_dout_dim\n if CONVERT_STATES:\n states_converted_ptrs = states_converted_ptr + offs_m * stride_out_dim\n\n if HAS_DFINAL_STATES:\n dstates = tl.load(dfinal_states_ptr + offs_m * stride_dfinal_states_dim, mask=offs_m < dim, other=0.0).to(tl.float32)\n else:\n dstates = tl.zeros((BLOCK_SIZE, ), dtype=tl.float32)\n tl.store(dstates_ptrs, dstates, mask=offs_m < dim)\n if HAS_SEQ_IDX:\n seq_idx = tl.load(seq_idx_ptr + (seqlen - 1) * stride_seq_idx_seqlen)\n dstates_ptrs -= stride_dstates_chunk\n for c in range(nchunks - 1):\n dA_cs = tl.load(dA_cs_ptr).to(tl.float32)\n scale = tl.exp(dA_cs)\n if HAS_SEQ_IDX:\n seq_idx_new = tl.load(seq_idx_ptr + (((nchunks - c - 1) * chunk_size - 1) * stride_seq_idx_seqlen))\n scale = tl.where(seq_idx_new == seq_idx, scale, 0.0)\n seq_idx = seq_idx_new\n out = tl.load(out_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)\n if CONVERT_STATES:\n tl.store(states_converted_ptrs, out, mask=offs_m < dim)\n ddA = tl.sum(out * dstates) * scale\n tl.store(ddA_cs_ptr, ddA)\n dout = tl.load(dout_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)\n dstates = scale * dstates + dout\n tl.store(dstates_ptrs, dstates, mask=offs_m < dim)\n dout_ptrs -= stride_dout_chunk\n dstates_ptrs -= stride_dstates_chunk\n dA_cs_ptr -= stride_dA_cs_chunk\n ddA_cs_ptr -= stride_ddA_cs_chunk\n out_ptrs -= stride_out_chunk\n if CONVERT_STATES:\n states_converted_ptrs -= stride_out_chunk\n if CONVERT_STATES:\n out = tl.load(out_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)\n tl.store(states_converted_ptrs, out, mask=offs_m < dim)\n if not HAS_DINITSTATES:\n tl.store(ddA_cs_ptr, 0.0)\n else:\n dA_cs = tl.load(dA_cs_ptr).to(tl.float32)\n scale = tl.exp(dA_cs)\n if HAS_SEQ_IDX:\n scale = tl.where(seq_idx == 0, scale, 0.0)\n out = tl.load(out_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)\n ddA = tl.sum(out * dstates) * scale\n tl.store(ddA_cs_ptr, ddA)\n dout = tl.load(dout_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)\n dstates = scale * dstates + dout\n tl.store(dinitstates_ptr + offs_m * stride_dinitstates_dim, dstates, mask=offs_m < dim)\n\ndef _state_passing_fwd(states, dA_chunk_cumsum, initial_states=None, seq_idx=None, chunk_size=None,\n out_dtype=None):\n batch, nchunks, nheads, dim = states.shape\n assert dA_chunk_cumsum.shape == (batch, nheads, nchunks)\n if initial_states is not None:\n assert initial_states.shape == (batch, nheads, dim)\n if seq_idx is not None:\n assert chunk_size is not None\n seqlen = seq_idx.shape[-1]\n assert seq_idx.shape == (batch, seqlen)\n out_dtype = states.dtype if out_dtype is None else out_dtype\n out = torch.empty((batch, nchunks, nheads, dim), device=states.device, dtype=out_dtype)\n final_states = torch.empty((batch, nheads, dim), device=states.device, dtype=torch.float32)\n grid = lambda META: (triton.cdiv(dim, META['BLOCK_SIZE']), batch, nheads)\n with torch.cuda.device(states.device.index):\n _state_passing_fwd_kernel[grid](\n states, out, final_states, dA_chunk_cumsum, initial_states, seq_idx,\n int(dim), int(nchunks), int(seqlen if seq_idx is not None else 0), int(chunk_size if seq_idx is not None else 0),\n states.stride(0), states.stride(1), states.stride(2), states.stride(3),\n out.stride(0), out.stride(1), out.stride(2), out.stride(3),\n final_states.stride(0), final_states.stride(1), final_states.stride(2),\n dA_chunk_cumsum.stride(0), dA_chunk_cumsum.stride(2), dA_chunk_cumsum.stride(1),\n *((initial_states.stride(0), initial_states.stride(1), initial_states.stride(2))\n if initial_states is not None else (0, 0, 0)),\n *((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)),\n HAS_INITSTATES=initial_states is not None,\n HAS_SEQ_IDX=seq_idx is not None,\n )\n return out, final_states\n\ndef _state_passing_bwd(\n states, dA_chunk_cumsum, dout, dfinal_states=None, seq_idx=None, has_initial_states=None,\n dstates_dtype=None, states_dtype=None, chunk_size=None\n):\n batch, nchunks, nheads, dim = states.shape\n assert dA_chunk_cumsum.shape == (batch, nheads, nchunks)\n assert dout.shape == (batch, nchunks, nheads, dim)\n if seq_idx is not None:\n assert chunk_size is not None\n seqlen = seq_idx.shape[-1]\n assert seq_idx.shape == (batch, seqlen)\n dstates = torch.empty_like(dout, dtype=dstates_dtype if dstates_dtype is not None else dout.dtype)\n if states_dtype is not None and states_dtype != states.dtype:\n states_converted = torch.empty_like(states, dtype=dstates_dtype if dstates_dtype is not None else dout.dtype)\n assert states_converted.stride() == states.stride()\n else:\n states_converted = None\n if has_initial_states:\n dinitstates = torch.empty_like(dstates[:, 0])\n else:\n dinitstates = None\n if dfinal_states is not None:\n assert dfinal_states.shape == (batch, nheads, dim)\n BLOCK_SIZE_min = 64\n n_blocks = (dim + BLOCK_SIZE_min - 1) // BLOCK_SIZE_min\n ddA_chunk_cumsum = torch.empty(batch, nheads, nchunks, n_blocks,\n dtype=torch.float32, device=dA_chunk_cumsum.device)\n grid = lambda META: (triton.cdiv(dim, META['BLOCK_SIZE']), batch, nheads)\n with torch.cuda.device(dout.device.index):\n _state_passing_bwd_kernel[grid](\n dout, states, dA_chunk_cumsum, dfinal_states, seq_idx,\n dstates, ddA_chunk_cumsum, dinitstates, states_converted,\n int(dim), int(nchunks), int(seqlen if seq_idx is not None else 0), int(chunk_size if seq_idx is not None else 0),\n dout.stride(0), dout.stride(1), dout.stride(2), dout.stride(3),\n states.stride(0), states.stride(1), states.stride(2), states.stride(3),\n dA_chunk_cumsum.stride(0), dA_chunk_cumsum.stride(2), dA_chunk_cumsum.stride(1),\n *((dfinal_states.stride(0), dfinal_states.stride(1), dfinal_states.stride(2))\n if dfinal_states is not None else (0, 0, 0)),\n *((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)),\n dstates.stride(0), dstates.stride(1), dstates.stride(2), dstates.stride(3),\n ddA_chunk_cumsum.stride(0), ddA_chunk_cumsum.stride(2), ddA_chunk_cumsum.stride(1),\n *((dinitstates.stride(0), dinitstates.stride(1), dinitstates.stride(2))\n if dinitstates is not None else (0, 0, 0)),\n CONVERT_STATES=states_converted is not None,\n HAS_DFINAL_STATES=dfinal_states is not None,\n HAS_DINITSTATES=dinitstates is not None,\n HAS_SEQ_IDX=seq_idx is not None,\n )\n BLOCK_SIZE_actual = _state_passing_bwd_kernel.best_config.kwargs[\"BLOCK_SIZE\"]\n n_valid_blocks = (dim + BLOCK_SIZE_actual - 1) // BLOCK_SIZE_actual\n ddA_chunk_cumsum = ddA_chunk_cumsum[..., :n_valid_blocks].sum(dim=-1).to(dtype=dA_chunk_cumsum.dtype)\n if states_dtype is not None and states_dtype == states.dtype:\n states_converted = states\n return (dstates, ddA_chunk_cumsum, dinitstates) if states_dtype is None else (dstates, ddA_chunk_cumsum, dinitstates, states_converted)\n", - "description_1": "Use triton language to implement two kernels: _state_passing_fwd_kernel and _state_passing_bwd_kernel. The forward kernel (_state_passing_fwd_kernel) takes 24 parameters including pointers to matrices, matrix dimensions, strides, and meta-parameters. It performs state passing with optional initial states and sequence indices, storing results in output and final states pointers. The backward kernel (_state_passing_bwd_kernel) takes 28 parameters, including pointers to matrices, matrix dimensions, strides, and meta-parameters. It computes gradients for state passing, handling optional final states, initial states, and sequence indices, storing results in gradient pointers.", - "description_2": "Use triton language to create forward and backward kernels for state passing operations, handling optional initial and final states, and sequence indices, with parameters for matrix pointers, dimensions, strides, and meta-parameters.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n# An OpenAI Triton kernel to both perform the scatter-add and counts of each index\n@triton.jit\ndef scatter_add_kernel(\n self_ptr,\n src_ptr, # Source array\n index_ptr, # Indices\n n_elements, # Number of elements in the source/indices array\n n_labels, # Number of labels (distinct indices)\n counts, # Output counts of each distinct index\n BLOCK_SIZE: tl.constexpr,\n BLOCK_SIZE_C: tl.constexpr,\n):\n pid = tl.program_id(axis=0)\n block_start = pid * BLOCK_SIZE\n offsets = block_start + tl.arange(0, BLOCK_SIZE)\n\n mask = offsets < n_elements\n\n # Load the source values and indices\n src = tl.load(src_ptr + offsets, mask=mask)\n indices = tl.load(index_ptr + offsets, mask=mask)\n\n # Iterate over n_labels\n for i in range(0, BLOCK_SIZE_C):\n idx = i + tl.program_id(1) * BLOCK_SIZE_C + 1\n if idx <= n_labels:\n l_mask = indices == idx\n # Perform the scatter-add operation\n tl.atomic_add(self_ptr + idx - 1, tl.sum(tl.where(l_mask, src, 0)))\n # Update count for idx\n tl.atomic_add(counts + idx - 1, tl.sum(tl.where(l_mask, 1, 0)))\n\n\ndef volume(d):\n return np.prod(d)\n\n\nclass UnownedMemory:\n def __init__(self, ptr, shape, dtype):\n mem = cp.cuda.UnownedMemory(ptr, volume(shape) * cp.dtype(dtype).itemsize, self)\n cupy_ptr = cp.cuda.MemoryPointer(mem, 0)\n self.d = cp.ndarray(shape, dtype=dtype, memptr=cupy_ptr)\n\n\nclass ScatterAddPlugin(\n trt.IPluginV3,\n trt.IPluginV3OneCore,\n trt.IPluginV3OneBuildV2,\n trt.IPluginV3OneRuntime,\n):\n def __init__(self, fc=None):\n trt.IPluginV3.__init__(self)\n trt.IPluginV3OneCore.__init__(self)\n trt.IPluginV3OneBuildV2.__init__(self)\n trt.IPluginV3OneRuntime.__init__(self)\n\n self.plugin_namespace = \"\"\n self.plugin_name = \"ScatterAddPlugin\"\n self.plugin_version = \"1\"\n self.num_outputs = 2\n\n def enqueue(self, input_desc, output_desc, inputs, outputs, workspace, stream):\n\n # No-copy operations to setup torch tensors over the I/O buffers\n inp_mem = UnownedMemory(\n inputs[0], input_desc[0].dims, trt.nptype(input_desc[0].type)\n )\n src_mem = UnownedMemory(\n inputs[1], input_desc[1].dims, trt.nptype(input_desc[1].type)\n )\n idx_mem = UnownedMemory(\n inputs[2], input_desc[2].dims, trt.nptype(input_desc[2].type)\n )\n counts_mem = UnownedMemory(\n outputs[1], output_desc[1].dims, trt.nptype(output_desc[1].type)\n )\n\n inp = torch.as_tensor(inp_mem.d, device=\"cuda\")\n src = torch.as_tensor(src_mem.d, device=\"cuda\")\n idx = torch.as_tensor(idx_mem.d, device=\"cuda\")\n counts = torch.as_tensor(counts_mem.d, device=\"cuda\")\n\n # Zero out the counts before passing to kernel\n counts.zero_()\n\n n_classes = inp.shape[0]\n n_elements = src.numel()\n\n # Block size definitions\n BLOCK_SIZE = 1024\n BLOCK_SIZE_C = 32\n\n # Calculate grid size\n grid_x = (n_elements + BLOCK_SIZE - 1) // BLOCK_SIZE\n grid_y = (n_classes + BLOCK_SIZE_C - 1) // BLOCK_SIZE_C\n\n scatter_add_kernel[(grid_x, grid_y)](\n inp, src, idx, n_elements, n_classes, counts, BLOCK_SIZE, BLOCK_SIZE_C\n )\n", - "description_1": "Use triton language to implement a scatter-add operation with counting distinct indices. The kernel 'scatter_add_kernel' takes 7 parameters: self_ptr (output array), src_ptr (source array), index_ptr (indices array), n_elements (number of elements in source/indices array), n_labels (number of distinct indices), counts (output counts of each distinct index), and two block sizes (BLOCK_SIZE and BLOCK_SIZE_C). The kernel performs scatter-add and counts operations in parallel using Triton's atomic operations. The 'ScatterAddPlugin' class manages the execution of this kernel, setting up the necessary memory and launching the kernel with appropriate grid sizes.", - "description_2": "Use triton language to create a kernel for scatter-add operations with index counting, and manage its execution with a plugin class.", - "difficulty": 3 - }, - { - "code": "import triton\nimport triton.language as tl\nimport torch\nimport numpy as np\nimport cupy as cp\nimport logging\n\nlogger = logging.getLogger(\"CircPadMultiTactic\")\n\n@triton.jit\ndef circ_pad(X,\n all_pads_0, all_pads_2, all_pads_4, all_pads_6,\n orig_dims_0, orig_dims_1, orig_dims_2, orig_dims_3,\n Y,\n Y_shape_1, Y_shape_2, Y_shape_3,\n X_len, Y_len, BLOCK_SIZE: tl.constexpr,):\n pid = tl.program_id(0)\n i = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)\n\n mask_y = i < Y_len\n\n i3 = i % Y_shape_3\n i2 = (i // Y_shape_3) % Y_shape_2\n i1 = (i // Y_shape_3 // Y_shape_2) % Y_shape_1\n i0 = i // Y_shape_3 // Y_shape_2 // Y_shape_1\n\n j0 = (i0 - all_pads_0 + orig_dims_0) % orig_dims_0\n j1 = (i1 - all_pads_2 + orig_dims_1) % orig_dims_1\n j2 = (i2 - all_pads_4 + orig_dims_2) % orig_dims_2\n j3 = (i3 - all_pads_6 + orig_dims_3) % orig_dims_3\n\n load_idx = orig_dims_3 * orig_dims_2 * orig_dims_1 * j0 + orig_dims_3 * orig_dims_2 * j1 + orig_dims_3 * j2 + j3\n mask_x = load_idx < X_len\n\n x = tl.load(X + load_idx, mask=mask_x)\n\n tl.store(Y + i, x, mask=mask_y)\n\ndef enqueue_tactic_TRITON(input_desc, output_desc, inputs, outputs, X_shape, pads):\n inp_dtype = trt.nptype(input_desc[0].type)\n\n a_mem = cp.cuda.UnownedMemory(\n inputs[0], volume(input_desc[0].dims) * cp.dtype(inp_dtype).itemsize, self\n )\n c_mem = cp.cuda.UnownedMemory(\n outputs[0],\n volume(output_desc[0].dims) * cp.dtype(inp_dtype).itemsize,\n self,\n )\n\n a_ptr = cp.cuda.MemoryPointer(a_mem, 0)\n c_ptr = cp.cuda.MemoryPointer(c_mem, 0)\n\n c_d = cp.ndarray((volume(output_desc[0].dims)), dtype=inp_dtype, memptr=c_ptr)\n\n a_d = cp.ndarray((volume(input_desc[0].dims)), dtype=inp_dtype, memptr=a_ptr)\n a_t = torch.as_tensor(a_d, device='cuda')\n c_t = torch.as_tensor(c_d, device='cuda')\n\n N = len(X_shape)\n all_pads = np.zeros((N * 2,), dtype=np.int32)\n orig_dims = np.array(X_shape, dtype=np.int32)\n out_dims = np.array(X_shape, dtype=np.int32)\n\n for i in range(np.size(pads) // 2):\n out_dims[N - i - 1] += pads[i * 2] + pads[i * 2 + 1]\n all_pads[N * 2 - 2 * i - 2] = pads[i * 2]\n all_pads[N * 2 - 2 * i - 1] = pads[i * 2 + 1]\n\n all_pads = all_pads.tolist()\n orig_dims = orig_dims.tolist()\n out_dims = out_dims.tolist()\n\n blockSize = 256\n numBlocks = tuple([int((np.prod(out_dims) + blockSize - 1) // blockSize)])\n\n circ_pad[numBlocks](a_t,\n all_pads[0], all_pads[2], all_pads[4], all_pads[6],\n orig_dims[0], orig_dims[1], orig_dims[2], orig_dims[3],\n c_t,\n out_dims[1], out_dims[2], out_dims[3],\n int(np.prod(orig_dims)), int(np.prod(out_dims)), BLOCK_SIZE=256\n )\n", - "description_1": "Use triton language to implement a circular padding kernel `circ_pad` for 4-dimensional tensors. The kernel takes 15 parameters: input tensor `X`, padding values for each dimension, original and output dimensions of the tensor, and other configuration like block size. It computes indices with modulo operations to perform the circular padding and stores the results in `Y`. The `enqueue_tactic_TRITON` function calls this kernel with appropriate configurations and tensor pointers to perform the padding on GPU.", - "description_2": "Use triton language to implement a kernel for circular padding of 4D tensors and a corresponding function to launch it with specific configurations for GPU execution.", - "difficulty": 3 - }, - { - "code": "import triton\nimport triton.language as tl\nimport torch\nimport numpy as np\nimport cupy as cp\n\n@triton.jit\ndef circ_pad(\n X,\n all_pads_0,\n all_pads_2,\n all_pads_4,\n all_pads_6,\n orig_dims_0,\n orig_dims_1,\n orig_dims_2,\n orig_dims_3,\n Y,\n Y_shape_1,\n Y_shape_2,\n Y_shape_3,\n X_len,\n Y_len,\n BLOCK_SIZE: tl.constexpr,\n):\n pid = tl.program_id(0)\n i = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)\n\n mask_y = i < Y_len\n\n i3 = i % Y_shape_3\n i2 = (i // Y_shape_3) % Y_shape_2\n i1 = (i // Y_shape_3 // Y_shape_2) % Y_shape_1\n i0 = i // Y_shape_3 // Y_shape_2 // Y_shape_1\n\n j0 = (i0 - all_pads_0 + orig_dims_0) % orig_dims_0\n j1 = (i1 - all_pads_2 + orig_dims_1) % orig_dims_1\n j2 = (i2 - all_pads_4 + orig_dims_2) % orig_dims_2\n j3 = (i3 - all_pads_6 + orig_dims_3) % orig_dims_3\n\n load_idx = (\n orig_dims_3 * orig_dims_2 * orig_dims_1 * j0\n + orig_dims_3 * orig_dims_2 * j1\n + orig_dims_3 * j2\n + j3\n )\n mask_x = load_idx < X_len\n\n x = tl.load(X + load_idx, mask=mask_x)\n\n tl.store(Y + i, x, mask=mask_y)\n\ndef call_circ_pad_kernel(inputs, outputs, input_desc, output_desc, pads, X_shape):\n inp_dtype = trt.nptype(input_desc[0].type)\n\n a_mem = cp.cuda.UnownedMemory(\n inputs[0], volume(input_desc[0].dims) * cp.dtype(inp_dtype).itemsize, self\n )\n c_mem = cp.cuda.UnownedMemory(\n outputs[0],\n volume(output_desc[0].dims) * cp.dtype(inp_dtype).itemsize,\n self,\n )\n\n a_ptr = cp.cuda.MemoryPointer(a_mem, 0)\n c_ptr = cp.cuda.MemoryPointer(c_mem, 0)\n\n a_d = cp.ndarray((volume(input_desc[0].dims)), dtype=inp_dtype, memptr=a_ptr)\n c_d = cp.ndarray((volume(output_desc[0].dims)), dtype=inp_dtype, memptr=c_ptr)\n\n a_t = torch.as_tensor(a_d, device=\"cuda\")\n c_t = torch.as_tensor(c_d, device=\"cuda\")\n\n N = len(X_shape)\n all_pads = np.zeros((N * 2,), dtype=np.int32)\n orig_dims = np.array(X_shape, dtype=np.int32)\n out_dims = np.array(X_shape, dtype=np.int32)\n\n for i in range(np.size(pads) // 2):\n out_dims[N - i - 1] += pads[i * 2] + pads[i * 2 + 1]\n all_pads[N * 2 - 2 * i - 2] = pads[i * 2]\n all_pads[N * 2 - 2 * i - 1] = pads[i * 2 + 1]\n\n all_pads = all_pads.tolist()\n orig_dims = orig_dims.tolist()\n out_dims = out_dims.tolist()\n\n blockSize = 256\n numBlocks = (int((np.prod(out_dims) + blockSize - 1) // blockSize),)\n\n circ_pad[numBlocks](\n a_t,\n all_pads[0],\n all_pads[2],\n all_pads[4],\n all_pads[6],\n orig_dims[0],\n orig_dims[1],\n orig_dims[2],\n orig_dims[3],\n c_t,\n out_dims[1],\n out_dims[2],\n out_dims[3],\n int(np.prod(orig_dims)),\n int(np.prod(out_dims)),\n BLOCK_SIZE=256,\n )\n", - "description_1": "Use triton language to implement a circular padding operation. The kernel 'circ_pad' takes 15 parameters: an input tensor 'X', four padding sizes 'all_pads_0', 'all_pads_2', 'all_pads_4', 'all_pads_6', four original dimensions 'orig_dims_0', 'orig_dims_1', 'orig_dims_2', 'orig_dims_3', an output tensor 'Y', three output shape dimensions 'Y_shape_1', 'Y_shape_2', 'Y_shape_3', the length of the input tensor 'X_len', the length of the output tensor 'Y_len', and a constant block size 'BLOCK_SIZE'. The function calculates the padded index for each dimension and performs the load and store operations with masking to avoid out-of-bounds access.", - "description_2": "Use triton language to implement and execute a circular padding operation by defining a kernel with necessary parameters and calling it with input tensors and dimensions.", - "difficulty": 3 - }, - { - "code": "import triton\nimport math\n\n@triton.jit\ndef kernel(x_ptr, x_size, **META):\n BLOCK_SIZE = META['BLOCK_SIZE']\n # Kernel implementation here\n\ndef autotune(\n configs, key, prune_configs_by=None, reset_to_zero=None, nearest_power_of_two=False\n):\n def decorator(fn):\n return Autotuner(\n fn,\n fn.arg_names,\n configs,\n key,\n reset_to_zero,\n prune_configs_by,\n nearest_power_of_two,\n )\n return decorator\n\n@autotune(configs=[\n triton.Config(meta={'BLOCK_SIZE': 128}, num_warps=4),\n triton.Config(meta={'BLOCK_SIZE': 1024}, num_warps=8),\n ],\n key=['x_size']\n)\n@triton.jit\ndef kernel(x_ptr, x_size, **META):\n BLOCK_SIZE = META['BLOCK_SIZE']\n # Kernel implementation here\n", - "description_1": "Use triton language to define a kernel function with two parameters: x_ptr (pointer to data) and x_size (size of data). The kernel uses a meta-parameter BLOCK_SIZE to determine block size. The kernel is autotuned with two configurations, each specifying a different BLOCK_SIZE and number of warps. The autotune decorator uses a key 'x_size' to trigger evaluation of configurations when x_size changes.", - "description_2": "Use triton language to create an autotuned kernel with parameters for data pointer and size, using block size as a meta-parameter, and evaluate configurations based on data size changes.", - "difficulty": 3 - }, - { - "code": "import triton\nimport torch\n\n@triton.jit\ndef matmul_248_kernel(\n a_ptr, b_ptr, c_ptr, scales_ptr, zeros_ptr, g_ptr,\n M, N, K, bits, maxq,\n stride_am, stride_ak, stride_bk, stride_bn,\n stride_cm, stride_cn, stride_scales, stride_zeros,\n BLOCK_SIZE_M: triton.language.constexpr,\n BLOCK_SIZE_N: triton.language.constexpr,\n BLOCK_SIZE_K: triton.language.constexpr,\n GROUP_SIZE_M: triton.language.constexpr\n):\n \"\"\"\n Compute the matrix multiplication C = A x B.\n A is of shape (M, K) float16\n B is of shape (K//8, N) int32\n C is of shape (M, N) float16\n scales is of shape (G, N) float16\n zeros is of shape (G, N) float16\n g_ptr is of shape (K) int32\n \"\"\"\n infearure_per_bits = 32 // bits\n pid = triton.language.program_id(axis=0)\n num_pid_m = triton.language.cdiv(M, BLOCK_SIZE_M)\n num_pid_n = triton.language.cdiv(N, BLOCK_SIZE_N)\n num_pid_k = triton.language.cdiv(K, BLOCK_SIZE_K)\n num_pid_in_group = GROUP_SIZE_M * num_pid_n\n group_id = pid // num_pid_in_group\n first_pid_m = group_id * GROUP_SIZE_M\n group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)\n pid_m = first_pid_m + (pid % group_size_m)\n pid_n = (pid % num_pid_in_group) // group_size_m\n\n offs_am = pid_m * BLOCK_SIZE_M + triton.language.arange(0, BLOCK_SIZE_M)\n offs_bn = pid_n * BLOCK_SIZE_N + triton.language.arange(0, BLOCK_SIZE_N)\n offs_k = triton.language.arange(0, BLOCK_SIZE_K)\n a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)\n a_mask = offs_am[:, None] < M\n b_ptrs = b_ptr + ((offs_k[:, None] // infearure_per_bits) * stride_bk + offs_bn[None, :] * stride_bn)\n g_ptrs = g_ptr + offs_k\n scales_ptrs = scales_ptr + offs_bn[None, :]\n zeros_ptrs = zeros_ptr + (offs_bn[None, :] // infearure_per_bits)\n\n shifter = (offs_k % infearure_per_bits) * bits\n zeros_shifter = (offs_bn % infearure_per_bits) * bits\n accumulator = triton.language.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=triton.language.float32)\n\n for k in range(0, num_pid_k):\n g_idx = triton.language.load(g_ptrs)\n scales = triton.language.load(scales_ptrs + g_idx[:, None] * stride_scales)\n zeros = triton.language.load(zeros_ptrs + g_idx[:, None] * stride_zeros)\n\n zeros = (zeros >> zeros_shifter[None, :]) & maxq\n zeros = zeros + 1\n\n a = triton.language.load(a_ptrs, mask=a_mask, other=0.0)\n b = triton.language.load(b_ptrs)\n\n b = (b >> shifter[:, None]) & maxq\n b = (b - zeros) * scales\n\n accumulator += triton.language.dot(a, b)\n a_ptrs += BLOCK_SIZE_K\n b_ptrs += (BLOCK_SIZE_K // infearure_per_bits) * stride_bk\n g_ptrs += BLOCK_SIZE_K\n\n c_ptrs = c_ptr + stride_cm * offs_am[:, None] + stride_cn * offs_bn[None, :]\n c_mask = (offs_am[:, None] < M) & (offs_bn[None, :] < N)\n triton.language.store(c_ptrs, accumulator, mask=c_mask)\n\ndef matmul248(input, qweight, scales, qzeros, g_idx, bits, maxq):\n with torch.cuda.device(input.device):\n output = torch.empty((input.shape[0], qweight.shape[1]), device=input.device, dtype=torch.float16)\n grid = lambda META: (\n triton.cdiv(input.shape[0], META[\"BLOCK_SIZE_M\"]) * triton.cdiv(qweight.shape[1], META[\"BLOCK_SIZE_N\"]),\n )\n matmul_248_kernel[grid](\n input, qweight, output, scales, qzeros, g_idx,\n input.shape[0], qweight.shape[1], input.shape[1], bits, maxq,\n input.stride(0), input.stride(1), qweight.stride(0), qweight.stride(1),\n output.stride(0), output.stride(1), scales.stride(0), qzeros.stride(0)\n )\n return output\n", - "description_1": "Use triton language to implement a matrix multiplication kernel. The kernel takes 21 arguments, including pointers to matrices A, B, and C, scale and zero-point arrays, a group index pointer, dimensions M, N, K, bit-width for quantization, a maximum quantization level, and strides for accessing elements of A, B, and C. The kernel computes the product of a float16 matrix A and an int32 matrix B, adjusting for scale and zero-point, and stores the result in a float16 matrix C.", - "description_2": "Use triton language to perform quantized matrix multiplication. The kernel should handle inputs A, B, scales, zeros, and dimensions, applying shifts and scales to compute the resulting matrix C.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\nfrom typing import Any, Dict, Tuple\n\n@triton.jit\ndef fused_moe_kernel(\n a_ptr, b_ptr, c_ptr, topk_weights_ptr, sorted_token_ids_ptr, expert_ids_ptr, num_tokens_post_padded_ptr,\n N, K, EM, num_valid_tokens,\n stride_am, stride_ak, stride_be, stride_bk, stride_bn, stride_cm, stride_cn,\n BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr,\n MUL_ROUTED_WEIGHT: tl.constexpr, top_k: tl.constexpr, compute_type: tl.constexpr\n):\n pid = tl.program_id(axis=0)\n num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M)\n num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)\n num_pid_in_group = GROUP_SIZE_M * num_pid_n\n group_id = pid // num_pid_in_group\n first_pid_m = group_id * GROUP_SIZE_M\n group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)\n pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)\n pid_n = (pid % num_pid_in_group) // group_size_m\n\n num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr)\n if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded:\n return\n offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_token = tl.load(sorted_token_ids_ptr + offs_token_id)\n token_mask = offs_token < num_valid_tokens\n\n offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N\n offs_k = tl.arange(0, BLOCK_SIZE_K)\n a_ptrs = a_ptr + (offs_token[:, None] // top_k * stride_am +\n offs_k[None, :] * stride_ak)\n\n off_experts = tl.load(expert_ids_ptr + pid_m)\n b_ptrs = b_ptr + off_experts * stride_be + (offs_k[:, None] * stride_bk +\n offs_bn[None, :] * stride_bn)\n\n accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n\n for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):\n a = tl.load(a_ptrs,\n mask=token_mask[:, None] &\n (offs_k[None, :] < K - k * BLOCK_SIZE_K),\n other=0.0)\n b = tl.load(b_ptrs,\n mask=offs_k[:, None] < K - k * BLOCK_SIZE_K,\n other=0.0)\n accumulator += tl.dot(a, b)\n a_ptrs += BLOCK_SIZE_K * stride_ak\n b_ptrs += BLOCK_SIZE_K * stride_bk\n\n if MUL_ROUTED_WEIGHT:\n moe_weight = tl.load(topk_weights_ptr + offs_token,\n mask=token_mask,\n other=0)\n accumulator = accumulator * moe_weight[:, None]\n\n accumulator = accumulator.to(compute_type)\n\n offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[\n None, :]\n c_mask = token_mask[:, None] & (offs_cn[None, :] < N)\n tl.store(c_ptrs, accumulator, mask=c_mask)\n\ndef invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,\n topk_weights: torch.Tensor, topk_ids: torch.Tensor,\n sorted_token_ids: torch.Tensor,\n expert_ids: torch.Tensor,\n num_tokens_post_padded: torch.Tensor,\n mul_routed_weight: bool, top_k: int,\n config: Dict[str, Any]) -> None:\n assert topk_weights.stride(1) == 1\n assert sorted_token_ids.stride(0) == 1\n\n grid = lambda META: (triton.cdiv(sorted_token_ids.shape[0], META[\n 'BLOCK_SIZE_M']) * triton.cdiv(B.shape[1], META['BLOCK_SIZE_N']), )\n\n fused_moe_kernel[grid](\n A,\n B,\n C,\n topk_weights,\n sorted_token_ids,\n expert_ids,\n num_tokens_post_padded,\n B.shape[1],\n B.shape[2],\n sorted_token_ids.shape[0],\n topk_ids.numel(),\n A.stride(0),\n A.stride(1),\n B.stride(0),\n B.stride(2),\n B.stride(1),\n C.stride(1),\n C.stride(2),\n MUL_ROUTED_WEIGHT=mul_routed_weight,\n top_k=top_k,\n compute_type=tl.bfloat16 if A.dtype == torch.bfloat16 else tl.float16,\n **config,\n )\n", - "description_1": "Use triton language to implement a fused MoE (Mixture of Experts) kernel. The kernel computes the output by multiplying a token matrix (A) by an expert matrix (B) using block-wise matrix multiplication. It takes 24 parameters including pointers to input and output matrices, matrix dimensions, stride variables, and meta-parameters. The kernel ensures compatibility with specific block sizes and can apply a routed weight to the computation.", - "description_2": "Use triton language to invoke a fused MoE kernel. The invocation function requires 12 parameters including input matrices A and B, output matrix C, weights and token IDs, configuration settings, and meta parameters. It sets up a grid for the kernel execution and passes all necessary arguments for performing the Mixture of Experts computation.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\nif triton.__version__ >= \"2.1.0\":\n\n @triton.jit\n def _fwd_kernel(\n Q,\n K,\n V,\n K_cache,\n V_cache,\n B_Loc,\n sm_scale,\n B_Start_Loc,\n B_Seqlen,\n B_Ctxlen,\n block_size,\n x,\n Out,\n stride_b_loc_b,\n stride_b_loc_s,\n stride_qbs,\n stride_qh,\n stride_qd,\n stride_kbs,\n stride_kh,\n stride_kd,\n stride_vbs,\n stride_vh,\n stride_vd,\n stride_obs,\n stride_oh,\n stride_od,\n stride_k_cache_bs,\n stride_k_cache_h,\n stride_k_cache_d,\n stride_k_cache_bl,\n stride_k_cache_x,\n stride_v_cache_bs,\n stride_v_cache_h,\n stride_v_cache_d,\n stride_v_cache_bl,\n num_queries_per_kv: int,\n BLOCK_M: tl.constexpr,\n BLOCK_DMODEL: tl.constexpr,\n BLOCK_N: tl.constexpr,\n ):\n # Triton kernel for forward pass without alibi\n cur_batch = tl.program_id(0)\n cur_head = tl.program_id(1)\n start_m = tl.program_id(2)\n cur_kv_head = cur_head // num_queries_per_kv\n cur_batch_ctx_len = tl.load(B_Ctxlen + cur_batch)\n cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)\n cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)\n block_start_loc = BLOCK_M * start_m\n offs_n = tl.arange(0, BLOCK_N)\n offs_d = tl.arange(0, BLOCK_DMODEL)\n offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)\n off_q = (\n (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs +\n cur_head * stride_qh + offs_d[None, :] * stride_qd)\n\n q = tl.load(\n Q + off_q,\n mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len,\n other=0.0)\n\n m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float(\"inf\")\n l_i = tl.zeros([BLOCK_M], dtype=tl.float32)\n acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)\n\n for start_n in range(0, cur_batch_ctx_len, BLOCK_N):\n start_n = tl.multiple_of(start_n, BLOCK_N)\n bn = tl.load(B_Loc + cur_batch * stride_b_loc_b +\n ((start_n + offs_n) // block_size) * stride_b_loc_s,\n mask=(start_n + offs_n) < cur_batch_ctx_len,\n other=0)\n off_k = (bn[None, :] * stride_k_cache_bs +\n cur_kv_head * stride_k_cache_h +\n (offs_d[:, None] // x) * stride_k_cache_d +\n ((start_n + offs_n[None, :]) % block_size) *\n stride_k_cache_bl +\n (offs_d[:, None] % x) * stride_k_cache_x)\n off_v = (\n bn[:, None] * stride_v_cache_bs +\n cur_kv_head * stride_v_cache_h +\n offs_d[None, :] * stride_v_cache_d +\n (start_n + offs_n[:, None]) % block_size * stride_v_cache_bl)\n k = tl.load(K_cache + off_k,\n mask=(start_n + offs_n[None, :]) < cur_batch_ctx_len,\n other=0.0)\n\n qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)\n qk += tl.dot(q, k)\n qk = tl.where((start_n + offs_n[None, :]) < cur_batch_ctx_len, qk,\n float(\"-inf\"))\n qk *= sm_scale\n\n m_ij = tl.max(qk, 1)\n p = tl.exp(qk - m_ij[:, None])\n l_ij = tl.sum(p, 1)\n m_i_new = tl.maximum(m_i, m_ij)\n alpha = tl.exp(m_i - m_i_new)\n beta = tl.exp(m_ij - m_i_new)\n l_i_new = alpha * l_i + beta * l_ij\n p_scale = beta / l_i_new\n p = p * p_scale[:, None]\n acc_scale = l_i / l_i_new * alpha\n acc = acc * acc_scale[:, None]\n v = tl.load(V_cache + off_v,\n mask=(start_n + offs_n[:, None]) < cur_batch_ctx_len,\n other=0.0)\n\n p = p.to(v.dtype)\n acc += tl.dot(p, v)\n l_i = l_i_new\n m_i = m_i_new\n\n off_k = (offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh +\n offs_d[:, None] * stride_kd)\n off_v = (offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh +\n offs_d[None, :] * stride_vd)\n k_ptrs = K + off_k\n v_ptrs = V + off_v\n\n block_mask = tl.where(\n block_start_loc < cur_batch_seq_len - cur_batch_ctx_len, 1, 0)\n\n for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N):\n start_n = tl.multiple_of(start_n, BLOCK_N)\n k = tl.load(k_ptrs +\n (cur_batch_in_all_start_index + start_n) * stride_kbs,\n mask=(start_n + offs_n[None, :]) <\n cur_batch_seq_len - cur_batch_ctx_len,\n other=0.0)\n\n qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)\n qk += tl.dot(q, k)\n qk *= sm_scale\n qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk,\n float(\"-inf\"))\n\n m_ij = tl.max(qk, 1)\n p = tl.exp(qk - m_ij[:, None])\n l_ij = tl.sum(p, 1)\n m_i_new = tl.maximum(m_i, m_ij)\n alpha = tl.exp(m_i - m_i_new)\n beta = tl.exp(m_ij - m_i_new)\n l_i_new = alpha * l_i + beta * l_ij\n p_scale = beta / l_i_new\n p = p * p_scale[:, None]\n acc_scale = l_i / l_i_new * alpha\n acc = acc * acc_scale[:, None]\n v = tl.load(v_ptrs +\n (cur_batch_in_all_start_index + start_n) * stride_vbs,\n mask=(start_n + offs_n[:, None]) <\n cur_batch_seq_len - cur_batch_ctx_len,\n other=0.0)\n\n p = p.to(v.dtype)\n acc += tl.dot(p, v)\n l_i = l_i_new\n m_i = m_i_new\n\n off_o = (\n (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs +\n cur_head * stride_oh + offs_d[None, :] * stride_od)\n out_ptrs = Out + off_o\n tl.store(out_ptrs,\n acc,\n mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len)\n return\n\n @triton.jit\n def _fwd_kernel_alibi(\n Q,\n K,\n V,\n K_cache,\n V_cache,\n B_Loc,\n sm_scale,\n B_Start_Loc,\n B_Seqlen,\n B_Ctxlen,\n Alibi_slopes,\n block_size,\n x,\n Out,\n stride_b_loc_b,\n stride_b_loc_s,\n stride_qbs,\n stride_qh,\n stride_qd,\n stride_kbs,\n stride_kh,\n stride_kd,\n stride_vbs,\n stride_vh,\n stride_vd,\n stride_obs,\n stride_oh,\n stride_od,\n stride_k_cache_bs,\n stride_k_cache_h,\n stride_k_cache_d,\n stride_k_cache_bl,\n stride_k_cache_x,\n stride_v_cache_bs,\n stride_v_cache_h,\n stride_v_cache_d,\n stride_v_cache_bl,\n num_queries_per_kv: int,\n BLOCK_M: tl.constexpr,\n BLOCK_DMODEL: tl.constexpr,\n BLOCK_N: tl.constexpr,\n ):\n # Triton kernel for forward pass with alibi\n cur_batch = tl.program_id(0)\n cur_head = tl.program_id(1)\n start_m = tl.program_id(2)\n cur_kv_head = cur_head // num_queries_per_kv\n cur_batch_ctx_len = tl.load(B_Ctxlen + cur_batch)\n cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)\n cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)\n block_start_loc = BLOCK_M * start_m\n offs_n = tl.arange(0, BLOCK_N)\n offs_d = tl.arange(0, BLOCK_DMODEL)\n offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)\n off_q = (\n (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs +\n cur_head * stride_qh + offs_d[None, :] * stride_qd)\n\n q = tl.load(\n Q + off_q,\n mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len,\n other=0.0)\n\n m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float(\"inf\")\n l_i = tl.zeros([BLOCK_M], dtype=tl.float32)\n acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)\n\n alibi_slope = tl.load(Alibi_slopes + cur_head)\n alibi_start_q = tl.arange(\n 0, BLOCK_M) + block_start_loc + cur_batch_ctx_len\n alibi_start_k = 0\n for start_n in range(0, cur_batch_ctx_len, BLOCK_N):\n start_n = tl.multiple_of(start_n, BLOCK_N)\n bn = tl.load(B_Loc + cur_batch * stride_b_loc_b +\n ((start_n + offs_n) // block_size) * stride_b_loc_s,\n mask=(start_n + offs_n) < cur_batch_ctx_len,\n other=0)\n off_k = (bn[None, :] * stride_k_cache_bs +\n cur_kv_head * stride_k_cache_h +\n (offs_d[:, None] // x) * stride_k_cache_d +\n ((start_n + offs_n[None, :]) % block_size) *\n stride_k_cache_bl +\n (offs_d[:, None] % x) * stride_k_cache_x)\n off_v = (\n bn[:, None] * stride_v_cache_bs +\n cur_kv_head * stride_v_cache_h +\n offs_d[None, :] * stride_v_cache_d +\n (start_n + offs_n[:, None]) % block_size * stride_v_cache_bl)\n k = tl.load(K_cache + off_k,\n mask=(start_n + offs_n[None, :]) < cur_batch_ctx_len,\n other=0.0)\n\n qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)\n qk += tl.dot(q, k)\n qk = tl.where((start_n + offs_n[None, :]) < cur_batch_ctx_len, qk,\n float(\"-inf\"))\n qk *= sm_scale\n\n alibi = (tl.arange(0, BLOCK_N)[None, :] + alibi_start_k -\n alibi_start_q[:, None]) * alibi_slope\n alibi = tl.where(\n (alibi <= 0) & (alibi_start_q[:, None] < cur_batch_seq_len),\n alibi, float(\"-inf\"))\n qk += alibi\n alibi_start_k += BLOCK_N\n\n m_ij = tl.max(qk, 1)\n m_i_new = tl.maximum(m_i, m_ij)\n p = tl.math.exp(qk - m_i_new[:, None])\n l_ij = tl.sum(p, 1)\n alpha = tl.math.exp(m_i - m_i_new)\n l_i_new = alpha * l_i + l_ij\n acc_scale = alpha\n acc = acc * acc_scale[:, None]\n v = tl.load(V_cache + off_v,\n mask=(start_n + offs_n[:, None]) < cur_batch_ctx_len,\n other=0.0)\n\n p = p.to(v.dtype)\n acc += tl.dot(p, v, allow_tf32=False)\n l_i = l_i_new\n m_i = m_i_new\n\n off_k = (offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh +\n offs_d[:, None] * stride_kd)\n off_v = (offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh +\n offs_d[None, :] * stride_vd)\n k_ptrs = K + off_k\n v_ptrs = V + off_v\n\n block_mask = tl.where(\n block_start_loc < cur_batch_seq_len - cur_batch_ctx_len, 1, 0)\n\n alibi_slope = tl.load(Alibi_slopes + cur_head)\n alibi_start_q = tl.arange(\n 0, BLOCK_M) + block_start_loc + cur_batch_ctx_len\n alibi_start_k = cur_batch_ctx_len\n\n for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N):\n start_n = tl.multiple_of(start_n, BLOCK_N)\n k = tl.load(k_ptrs +\n (cur_batch_in_all_start_index + start_n) * stride_kbs,\n mask=(start_n + offs_n[None, :]) <\n cur_batch_seq_len - cur_batch_ctx_len,\n other=0.0)\n\n qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)\n qk += tl.dot(q, k, allow_tf32=False)\n qk *= sm_scale\n qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk,\n float(\"-inf\"))\n\n alibi = (tl.arange(0, BLOCK_N)[None, :] + alibi_start_k -\n alibi_start_q[:, None]) * alibi_slope\n alibi = tl.where(\n (alibi <= 0) & (alibi_start_q[:, None] < cur_batch_seq_len),\n alibi, float(\"-inf\"))\n qk += alibi\n alibi_start_k += BLOCK_N\n\n m_ij = tl.max(qk, 1)\n m_i_new = tl.maximum(m_i, m_ij)\n p = tl.math.exp(qk - m_i_new[:, None])\n l_ij = tl.sum(p, 1)\n alpha = tl.math.exp(m_i - m_i_new)\n l_i_new = alpha * l_i + l_ij\n acc_scale = alpha\n acc = acc * acc_scale[:, None]\n v = tl.load(v_ptrs +\n (cur_batch_in_all_start_index + start_n) * stride_vbs,\n mask=(start_n + offs_n[:, None]) <\n cur_batch_seq_len - cur_batch_ctx_len,\n other=0.0)\n\n p = p.to(v.dtype)\n acc += tl.dot(p, v, allow_tf32=False)\n l_i = l_i_new\n m_i = m_i_new\n\n acc = acc / l_i[:, None]\n\n off_o = (\n (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs +\n cur_head * stride_oh + offs_d[None, :] * stride_od)\n out_ptrs = Out + off_o\n tl.store(out_ptrs,\n acc,\n mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len)\n return\n\n @torch.inference_mode()\n def context_attention_fwd(q,\n k,\n v,\n o,\n k_cache,\n v_cache,\n b_loc,\n b_start_loc,\n b_seq_len,\n b_ctx_len,\n max_input_len,\n alibi_slopes=None):\n\n cap = torch.cuda.get_device_capability()\n BLOCK = 128 if cap[0] >= 8 else 64\n Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]\n assert Lq == Lk and Lk == Lv\n assert Lk in {16, 32, 64, 128}\n\n sm_scale = 1.0 / (Lq**0.5)\n batch, head = b_seq_len.shape[0], q.shape[1]\n num_queries_per_kv = q.shape[1] // k.shape[1]\n\n grid = (batch, head, triton.cdiv(max_input_len, BLOCK))\n\n num_warps = 8 if Lk <= 64 else 8\n if alibi_slopes is not None:\n _fwd_kernel_alibi[grid](\n q,\n k,\n v,\n k_cache,\n v_cache,\n b_loc,\n sm_scale,\n b_start_loc,\n b_seq_len,\n b_ctx_len,\n alibi_slopes,\n v_cache.shape[3],\n 8,\n o,\n b_loc.stride(0),\n b_loc.stride(1),\n q.stride(0),\n q.stride(1),\n q.stride(2),\n k.stride(0),\n k.stride(1),\n k.stride(2),\n v.stride(0),\n v.stride(1),\n v.stride(2),\n o.stride(0),\n o.stride(1),\n o.stride(2),\n k_cache.stride(0),\n k_cache.stride(1),\n k_cache.stride(2),\n k_cache.stride(3),\n k_cache.stride(\n 4\n ),\n v_cache.stride(0),\n v_cache.stride(1),\n v_cache.stride(2),\n v_cache.stride(\n 3),\n num_queries_per_kv=num_queries_per_kv,\n BLOCK_M=BLOCK,\n BLOCK_DMODEL=Lk,\n BLOCK_N=BLOCK,\n num_warps=num_warps,\n num_stages=1,\n )\n return\n\n _fwd_kernel[grid](\n q,\n k,\n v,\n k_cache,\n v_cache,\n b_loc,\n sm_scale,\n b_start_loc,\n b_seq_len,\n b_ctx_len,\n v_cache.shape[3],\n 8,\n o,\n b_loc.stride(0),\n b_loc.stride(1),\n q.stride(0),\n q.stride(1),\n q.stride(2),\n k.stride(0),\n k.stride(1),\n k.stride(2),\n v.stride(0),\n v.stride(1),\n v.stride(2),\n o.stride(0),\n o.stride(1),\n o.stride(2),\n k_cache.stride(0),\n k_cache.stride(1),\n k_cache.stride(2),\n k_cache.stride(3),\n k_cache.stride(\n 4),\n v_cache.stride(0),\n v_cache.stride(1),\n v_cache.stride(2),\n v_cache.stride(\n 3),\n num_queries_per_kv=num_queries_per_kv,\n BLOCK_M=BLOCK,\n BLOCK_DMODEL=Lk,\n BLOCK_N=BLOCK,\n num_warps=num_warps,\n num_stages=1,\n )\n return\n", - "description_1": "Use triton language to implement efficient context attention kernels. Define `_fwd_kernel` for regular context attention with 45 parameters including query, key, value matrices, cache, and others for dimensions and strides. Define `_fwd_kernel_alibi` similarly, with 46 parameters to include Alibi slopes for biasing. Use `context_attention_fwd` to manage execution with hardware capability checks and kernel calls with 11 parameters plus optional Alibi slopes.", - "description_2": "Use triton language to implement context attention kernels, incorporating regular and Alibi-biasing functions, and manage execution based on GPU capabilities and input dimensions.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef _flash_decoding_fwd_kernel(\n Q, KCache, VCache, mid_o, mid_o_lse, kv_seq_len, q_len: tl.constexpr, batch_size, sm_scale,\n stride_qt, stride_qh, stride_q_qlen, stride_qd,\n stride_kb, stride_kh, stride_kt, stride_kd,\n stride_vb, stride_vh, stride_vt, stride_vd,\n stride_mid_ot, stride_mid_oh, stride_mid_ob, stride_mid_oqlen, stride_mid_od,\n stride_mid_o_lset, stride_mid_o_lseh, stride_mid_o_lseb,\n KV_GROUPS: tl.constexpr, BLOCK_KV: tl.constexpr, HEAD_DIM: tl.constexpr,\n):\n cur_token_idx = tl.program_id(0)\n cur_head_idx = tl.program_id(1)\n block_start_kv = tl.program_id(2)\n cur_kv_seq_len = tl.load(kv_seq_len + cur_token_idx)\n if block_start_kv * BLOCK_KV >= cur_kv_seq_len:\n return\n\n offsets_dmodel = tl.arange(0, HEAD_DIM)\n offsets_q = cur_token_idx * stride_qt + cur_head_idx * stride_qh\n Q_block_ptr = tl.make_block_ptr(\n base=Q + offsets_q, shape=(q_len, HEAD_DIM), strides=(stride_q_qlen, stride_qd),\n offsets=(0, 0), block_shape=(q_len, HEAD_DIM), order=(0, 1)\n )\n q = tl.load(Q_block_ptr)\n cur_kv_head_idx = cur_head_idx // KV_GROUPS\n cur_k_offset = cur_token_idx * stride_kb + cur_kv_head_idx * stride_kh + block_start_kv * BLOCK_KV * stride_kt\n cur_v_offset = cur_token_idx * stride_vb + cur_kv_head_idx * stride_vh + block_start_kv * BLOCK_KV * stride_vt\n\n K_block_ptr = tl.make_block_ptr(\n base=KCache + cur_k_offset, shape=(cur_kv_seq_len, HEAD_DIM), strides=(stride_kd, stride_kt),\n offsets=(0, 0), block_shape=(HEAD_DIM, BLOCK_KV), order=(1, 0)\n )\n V_block_ptr = tl.make_block_ptr(\n base=VCache + cur_v_offset, shape=(cur_kv_seq_len, HEAD_DIM), strides=(stride_vt, stride_vd),\n offsets=(0, 0), block_shape=(BLOCK_KV, HEAD_DIM), order=(0, 1)\n )\n block_mask = block_start_kv * BLOCK_KV + tl.arange(0, BLOCK_KV) < cur_kv_seq_len\n k_cur_block = tl.load(K_block_ptr)\n v_cur_block = tl.load(V_block_ptr)\n\n acc = tl.zeros([q_len, HEAD_DIM], dtype=tl.float32)\n S_ij = tl.zeros([q_len, BLOCK_KV], dtype=tl.float32)\n S_ij += tl.dot(q, k_cur_block)\n S_ij = tl.where(block_mask[None, :], S_ij, float(\"-inf\"))\n S_ij *= sm_scale\n m = tl.max(S_ij, 1)\n S_ij -= m[:, None]\n p_ij_hat = tl.exp(S_ij)\n l_i = tl.sum(p_ij_hat, 1)\n p_ij_hat = p_ij_hat.to(v_cur_block.type.element_ty)\n acc += tl.dot(p_ij_hat, v_cur_block)\n acc = acc / l_i[:, None]\n\n cur_offest_mid = cur_token_idx * stride_mid_ot + cur_head_idx * stride_mid_oh + block_start_kv * stride_mid_ob\n offsets_mid_o = tl.make_block_ptr(\n base=mid_o + cur_offest_mid, shape=(q_len, HEAD_DIM), strides=(stride_mid_oqlen, stride_mid_od),\n offsets=(0, 0), block_shape=(q_len, HEAD_DIM), order=(0, 1)\n )\n tl.store(offsets_mid_o, acc)\n\n offsets_qlen = tl.arange(0, q_len)\n offsets_mid_o_lse = (\n cur_token_idx * stride_mid_o_lset + cur_head_idx * stride_mid_o_lseh + block_start_kv * stride_mid_o_lseb + offsets_qlen\n )\n tl.store(mid_o_lse + offsets_mid_o_lse, m + tl.log(l_i))\n\n@triton.jit\ndef _flash_decoding_fwd_reduce_kernel(\n mid_o, mid_o_lse, O, kv_seq_len, q_len: tl.constexpr, batch_size,\n stride_mid_ot, stride_mid_oh, stride_mid_ob, stride_mid_oqlen, stride_mid_od,\n stride_o_lset, stride_o_lseh, stride_o_lseb, stride_o_lseqlen,\n stride_ot, stride_oh, stride_oqlen,\n BLOCK_KV: tl.constexpr, HEAD_DIM: tl.constexpr,\n):\n cur_token_idx = tl.program_id(0)\n cur_head_idx = tl.program_id(1)\n cur_q_idx = tl.program_id(2)\n cur_kv_seq_len = tl.load(kv_seq_len + cur_token_idx)\n offsets_dmodel = tl.arange(0, HEAD_DIM)\n\n kv_split_num = (cur_kv_seq_len + BLOCK_KV - 1) // BLOCK_KV\n m_i = float(\"-inf\")\n l_i = 0.0\n acc = tl.zeros([HEAD_DIM], dtype=tl.float32)\n offsets_mid_o = cur_token_idx * stride_mid_ot + cur_head_idx * stride_mid_oh + cur_q_idx * stride_mid_oqlen + offsets_dmodel\n offset_mid_lse = cur_token_idx * stride_o_lset + cur_head_idx * stride_o_lseh + cur_q_idx * stride_o_lseqlen\n\n for block_i in range(0, kv_split_num, 1):\n mid_o_block = tl.load(mid_o + offsets_mid_o + block_i * stride_mid_ob)\n lse = tl.load(mid_o_lse + offset_mid_lse + block_i * stride_o_lseb)\n m_ij = tl.maximum(m_i, lse)\n scale = tl.exp(m_i - m_ij)\n acc = acc * scale\n lse -= m_ij\n exp_logic = tl.exp(lse)\n acc += exp_logic * mid_o_block\n l_i = scale * l_i + exp_logic\n m_i = m_ij\n\n acc = acc / l_i\n offsets_O = cur_token_idx * stride_ot + cur_head_idx * stride_oh + cur_q_idx * stride_oqlen + offsets_dmodel\n tl.store(O + offsets_O, acc.to(O.type.element_ty))\n return l_i\n\n\ndef flash_decoding_attention(\n q: torch.Tensor, k_cache: torch.Tensor, v_cache: torch.Tensor, kv_seq_len: torch.Tensor,\n block_size: int = 64, max_seq_len_in_batch: int = None, output: torch.Tensor = None,\n mid_output: torch.Tensor = None, mid_output_lse: torch.Tensor = None, sm_scale: int = None,\n kv_group_num: int = 1\n):\n n_tokens, num_heads, q_len, head_dim = q.shape\n q_len = int(q_len)\n bsz = n_tokens\n\n BLOCK_KV = block_size\n sm_scale = 1.0 / (head_dim**0.5) if sm_scale is None else sm_scale\n max_seq_len_in_batch = kv_seq_len.max().item() if max_seq_len_in_batch is None else max_seq_len_in_batch\n kv_max_split_num = (max_seq_len_in_batch + BLOCK_KV - 1) // BLOCK_KV\n\n if mid_output is None:\n mid_output = torch.empty(\n (bsz, num_heads, kv_max_split_num, q_len, head_dim), dtype=torch.float32, device=q.device\n )\n if mid_output_lse is None:\n mid_output_lse = torch.empty((bsz, num_heads, kv_max_split_num, q_len), dtype=torch.float32, device=q.device)\n if output is None:\n output = torch.empty((bsz, num_heads, q_len, head_dim), dtype=q.dtype, device=q.device)\n\n grid = lambda META: (\n triton.next_power_of_2(bsz),\n num_heads,\n triton.cdiv(triton.next_power_of_2(max_seq_len_in_batch), META[\"BLOCK_KV\"]),\n )\n\n _flash_decoding_fwd_kernel[grid](\n q,\n k_cache,\n v_cache,\n mid_output,\n mid_output_lse,\n kv_seq_len,\n q_len,\n bsz,\n sm_scale,\n q.stride(0),\n q.stride(1),\n q.stride(2),\n q.stride(3),\n k_cache.stride(0),\n k_cache.stride(1),\n k_cache.stride(2),\n k_cache.stride(3),\n v_cache.stride(0),\n v_cache.stride(1),\n v_cache.stride(2),\n v_cache.stride(3),\n mid_output.stride(0),\n mid_output.stride(1),\n mid_output.stride(2),\n mid_output.stride(3),\n mid_output.stride(4),\n mid_output_lse.stride(0),\n mid_output_lse.stride(1),\n mid_output_lse.stride(2),\n KV_GROUPS=kv_group_num,\n BLOCK_KV=block_size,\n HEAD_DIM=head_dim,\n )\n\n grid = (triton.next_power_of_2(bsz), num_heads, q_len)\n _flash_decoding_fwd_reduce_kernel[grid](\n mid_output,\n mid_output_lse,\n output,\n kv_seq_len,\n q_len,\n bsz,\n mid_output.stride(0),\n mid_output.stride(1),\n mid_output.stride(2),\n mid_output.stride(3),\n mid_output.stride(4),\n mid_output_lse.stride(0),\n mid_output_lse.stride(1),\n mid_output_lse.stride(2),\n mid_output_lse.stride(3),\n output.stride(0),\n output.stride(1),\n output.stride(2),\n BLOCK_KV=block_size,\n HEAD_DIM=head_dim,\n )\n\n return output\n", - "description_1": "Use triton language to implement a flash decoding attention mechanism with two kernels: one for forward computation involving queries, key cache, value cache, and storing intermediate results; and another for reducing these intermediate results to final output. The first kernel computes attention scores and accumulates results, while the second reduces these accumulated results. It involves parameters like sequence lengths, batch size, strides, block size, head dimension, scaling factor, and more, to orchestrate memory access patterns and operations within Triton's parallel programming model.", - "description_2": "Use triton language to implement a flash decoding mechanism with two main stages: a forward kernel for computing attention scores and a reduction kernel for final result computation. Employ necessary parameters like strides and block sizes to ensure efficient memory access and parallel computation.", - "difficulty": 4 - }, - { - "code": "import math\nimport torch\nimport triton\nimport triton.language as tl\n\n@triton.heuristics(\n {\n \"EVEN_M\": lambda args: args[\"seqlen_q\"] % args[\"BLOCK_M\"] == 0,\n \"EVEN_N\": lambda args: args[\"seqlen_k\"] % args[\"BLOCK_N\"] == 0,\n \"EVEN_HEADDIM\": lambda args: args[\"headdim\"] == args[\"BLOCK_HEADDIM\"],\n }\n)\n@triton.jit\ndef _fwd_kernel(\n Q, K, V, Bias, Out,\n Lse, TMP,\n softmax_scale,\n stride_qb, stride_qh, stride_qm,\n stride_kb, stride_kh, stride_kn,\n stride_vb, stride_vh, stride_vn,\n stride_bb, stride_bh, stride_bm,\n stride_ob, stride_oh, stride_om,\n nheads, seqlen_q, seqlen_k, seqlen_q_rounded, headdim,\n CACHE_KEY_SEQLEN_Q, CACHE_KEY_SEQLEN_K,\n BIAS_TYPE: tl.constexpr,\n IS_CAUSAL: tl.constexpr,\n BLOCK_HEADDIM: tl.constexpr,\n EVEN_M: tl.constexpr, EVEN_N: tl.constexpr, EVEN_HEADDIM: tl.constexpr,\n BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,\n):\n start_m = tl.program_id(0)\n off_hb = tl.program_id(1)\n off_b = off_hb // nheads\n off_h = off_hb % nheads\n offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)\n offs_n = tl.arange(0, BLOCK_N)\n offs_d = tl.arange(0, BLOCK_HEADDIM)\n q_ptrs = Q + off_b * stride_qb + off_h * stride_qh + (offs_m[:, None] * stride_qm + offs_d[None, :])\n k_ptrs = K + off_b * stride_kb + off_h * stride_kh + (offs_n[:, None] * stride_kn + offs_d[None, :])\n v_ptrs = V + off_b * stride_vb + off_h * stride_vh + (offs_n[:, None] * stride_vn + offs_d[None, :])\n if BIAS_TYPE == 'vector':\n b_ptrs = Bias + off_b * stride_bb + off_h * stride_bh + offs_n\n elif BIAS_TYPE == 'matrix':\n b_ptrs = Bias + off_b * stride_bb + off_h * stride_bh + (offs_m[:, None] * stride_bm + offs_n[None, :])\n t_ptrs = TMP + off_hb * seqlen_q_rounded + offs_m\n lse_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float(\"inf\")\n m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float(\"inf\")\n acc_o = tl.zeros([BLOCK_M, BLOCK_HEADDIM], dtype=tl.float32)\n if EVEN_M & EVEN_N:\n if EVEN_HEADDIM:\n q = tl.load(q_ptrs)\n else:\n q = tl.load(q_ptrs, mask=offs_d[None, :] < headdim, other=0.0)\n else:\n if EVEN_HEADDIM:\n q = tl.load(q_ptrs, mask=offs_m[:, None] < seqlen_q, other=0.0)\n else:\n q = tl.load(q_ptrs, mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim),\n other=0.0)\n end_n = seqlen_k if not IS_CAUSAL else tl.minimum((start_m + 1) * BLOCK_M, seqlen_k)\n for start_n in range(0, end_n, BLOCK_N):\n start_n = tl.multiple_of(start_n, BLOCK_N)\n if EVEN_N & EVEN_M:\n if EVEN_HEADDIM:\n k = tl.load(k_ptrs + start_n * stride_kn)\n else:\n k = tl.load(k_ptrs + start_n * stride_kn, mask=offs_d[None, :] < headdim, other=0.0)\n else:\n if EVEN_HEADDIM:\n k = tl.load(k_ptrs + start_n * stride_kn, mask=(start_n + offs_n)[:, None] < seqlen_k,\n other=0.0)\n else:\n k = tl.load(k_ptrs + start_n * stride_kn,\n mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim),\n other=0.0)\n qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)\n qk += tl.dot(q, k, trans_b=True)\n if not EVEN_N:\n qk += tl.where((start_n + offs_n)[None, :] < seqlen_k, 0, float(\"-inf\"))\n if IS_CAUSAL:\n qk += tl.where(offs_m[:, None] >= (start_n + offs_n)[None, :], 0, float(\"-inf\"))\n if BIAS_TYPE != 'none':\n if BIAS_TYPE == 'vector':\n if EVEN_N:\n bias = tl.load(b_ptrs + start_n).to(tl.float32)\n else:\n bias = tl.load(b_ptrs + start_n, mask=(start_n + offs_n) < seqlen_k, other=0.0).to(tl.float32)\n bias = bias[None, :]\n elif BIAS_TYPE == 'matrix':\n if EVEN_M & EVEN_N:\n bias = tl.load(b_ptrs + start_n).to(tl.float32)\n else:\n bias = tl.load(b_ptrs + start_n,\n mask=(offs_m[:, None] < seqlen_q)\n & ((start_n + offs_n)[None, :] < seqlen_k),\n other=0.0).to(tl.float32)\n qk = qk * softmax_scale + bias\n m_ij = tl.maximum(tl.max(qk, 1), lse_i)\n p = tl.exp(qk - m_ij[:, None])\n else:\n m_ij = tl.maximum(tl.max(qk, 1) * softmax_scale, lse_i)\n p = tl.exp(qk * softmax_scale - m_ij[:, None])\n l_ij = tl.sum(p, 1)\n acc_o_scale = tl.exp(m_i - m_ij)\n tl.store(t_ptrs, acc_o_scale)\n acc_o_scale = tl.load(t_ptrs)\n acc_o = acc_o * acc_o_scale[:, None]\n if EVEN_N & EVEN_M:\n if EVEN_HEADDIM:\n v = tl.load(v_ptrs + start_n * stride_vn)\n else:\n v = tl.load(v_ptrs + start_n * stride_vn, mask=offs_d[None, :] < headdim, other=0.0)\n else:\n if EVEN_HEADDIM:\n v = tl.load(v_ptrs + start_n * stride_vn, mask=(start_n + offs_n)[:, None] < seqlen_k,\n other=0.0)\n else:\n v = tl.load(v_ptrs + start_n * stride_vn,\n mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim),\n other=0.0)\n p = p.to(v.dtype)\n acc_o += tl.dot(p, v)\n m_i = m_ij\n l_i_new = tl.exp(lse_i - m_ij) + l_ij\n lse_i = m_ij + tl.log(l_i_new)\n o_scale = tl.exp(m_i - lse_i)\n tl.store(t_ptrs, o_scale)\n o_scale = tl.load(t_ptrs)\n acc_o = acc_o * o_scale[:, None]\n start_m = tl.program_id(0)\n offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)\n lse_ptrs = Lse + off_hb * seqlen_q_rounded + offs_m\n tl.store(lse_ptrs, lse_i)\n offs_d = tl.arange(0, BLOCK_HEADDIM)\n out_ptrs = Out + off_b * stride_ob + off_h * stride_oh + (offs_m[:, None] * stride_om + offs_d[None, :])\n if EVEN_M:\n if EVEN_HEADDIM:\n tl.store(out_ptrs, acc_o)\n else:\n tl.store(out_ptrs, acc_o, mask=offs_d[None, :] < headdim)\n else:\n if EVEN_HEADDIM:\n tl.store(out_ptrs, acc_o, mask=offs_m[:, None] < seqlen_q)\n else:\n tl.store(out_ptrs, acc_o,\n mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim))\n\n\ndef _flash_attn_forward(q, k, v, bias=None, causal=False, softmax_scale=None):\n batch, seqlen_q, nheads, d = q.shape\n _, seqlen_k, _, _ = k.shape\n assert k.shape == (batch, seqlen_k, nheads, d)\n assert v.shape == (batch, seqlen_k, nheads, d)\n assert d <= 128, 'FlashAttention only support head dimensions up to 128'\n assert q.dtype == k.dtype == v.dtype, 'All tensors must have the same type'\n assert q.dtype in [torch.float16, torch.bfloat16], 'Only support fp16 and bf16'\n assert q.is_cuda and k.is_cuda and v.is_cuda\n softmax_scale = softmax_scale or 1.0 / math.sqrt(d)\n has_bias = bias is not None\n bias_type = 'none'\n if has_bias:\n assert bias.dtype in [q.dtype, torch.float]\n assert bias.is_cuda\n assert bias.dim() == 4\n if bias.stride(-1) != 1:\n bias = bias.contiguous()\n if bias.shape[2:] == (1, seqlen_k):\n bias_type = 'vector'\n elif bias.shape[2:] == (seqlen_q, seqlen_k):\n bias_type = 'matrix'\n else:\n raise RuntimeError('Last 2 dimensions of bias must be (1, seqlen_k)'\n ' or (seqlen_q, seqlen_k)')\n bias = bias.expand(batch, nheads, seqlen_q, seqlen_k)\n bias_strides = (bias.stride(0), bias.stride(1), bias.stride(2)) if has_bias else (0, 0, 0)\n seqlen_q_rounded = math.ceil(seqlen_q / 128) * 128\n lse = torch.empty((batch, nheads, seqlen_q_rounded), device=q.device, dtype=torch.float32)\n tmp = torch.empty((batch, nheads, seqlen_q_rounded), device=q.device, dtype=torch.float32)\n o = torch.empty_like(q)\n BLOCK_HEADDIM = max(triton.next_power_of_2(d), 16)\n BLOCK = 128\n num_warps = 4 if d <= 64 else 8\n grid = lambda META: (triton.cdiv(seqlen_q, META[\"BLOCK_M\"]), batch * nheads)\n _fwd_kernel[grid](\n q, k, v, bias, o,\n lse, tmp,\n softmax_scale,\n q.stride(0), q.stride(2), q.stride(1),\n k.stride(0), k.stride(2), k.stride(1),\n v.stride(0), v.stride(2), v.stride(1),\n *bias_strides,\n o.stride(0), o.stride(2), o.stride(1),\n nheads, seqlen_q, seqlen_k, seqlen_q_rounded, d,\n seqlen_q // 32, seqlen_k // 32,\n bias_type, causal, BLOCK_HEADDIM,\n BLOCK_M=BLOCK, BLOCK_N=BLOCK,\n num_warps=num_warps,\n num_stages=1,\n )\n return o, lse, softmax_scale\n\n", - "description_1": "Use triton language to implement a forward pass of a FlashAttention mechanism with inputs Q, K, V matrices, a bias option, and causal masking. It involves computing a scaled dot-product attention, applying biases (if any), and storing the output and log-sum-exp calculations. The function _flash_attn_forward(q, k, v, bias, causal, softmax_scale) sets up inputs, configurations, and calls the kernel _fwd_kernel for actual computation. The kernel deals with different matrix dimension checks (EVEN_M, EVEN_N, EVEN_HEADDIM) and uses triton's GPU parallel capabilities for efficiency.", - "description_2": "Use triton language to implement a custom forward pass kernel for FlashAttention with Q, K, V matrices, optional bias, and causal flag. Ensure GPU-parallel efficiency and memory safety using Triton's advanced features.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n# Triton kernel for forward pass\n@triton.jit\ndef _fwd_kernel(\n Q, K, V, sm_scale,\n TMP, L, M, # TMP is a scratchpad buffer to workaround a compiler bug\n Out,\n stride_qz, stride_qh, stride_qm, stride_qk,\n stride_kz, stride_kh, stride_kn, stride_kk,\n stride_vz, stride_vh, stride_vk, stride_vn,\n stride_oz, stride_oh, stride_om, stride_on,\n Z, H, N_CTX,\n BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,\n BLOCK_N: tl.constexpr,\n):\n start_m = tl.program_id(0)\n off_hz = tl.program_id(1)\n # initialize offsets\n offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)\n offs_n = tl.arange(0, BLOCK_N)\n offs_d = tl.arange(0, BLOCK_DMODEL)\n off_q = off_hz * stride_qh + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qk\n off_k = off_hz * stride_qh + offs_n[:, None] * stride_kn + offs_d[None, :] * stride_kk\n off_v = off_hz * stride_qh + offs_n[:, None] * stride_qm + offs_d[None, :] * stride_qk\n # Initialize pointers to Q, K, V\n q_ptrs = Q + off_q\n k_ptrs = K + off_k\n v_ptrs = V + off_v\n # initialize pointer to m and l\n t_ptrs = TMP + off_hz * N_CTX + offs_m\n m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float(\"inf\")\n l_i = tl.zeros([BLOCK_M], dtype=tl.float32)\n acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)\n # load q: it will stay in SRAM throughout\n q = tl.load(q_ptrs)\n # loop over k, v and update accumulator\n for start_n in range(0, (start_m + 1) * BLOCK_M, BLOCK_N):\n start_n = tl.multiple_of(start_n, BLOCK_N)\n # -- compute qk ----\n k = tl.load(k_ptrs + start_n * stride_kn)\n qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)\n qk += tl.dot(q, k, trans_b=True)\n qk *= sm_scale\n qk += tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), 0, float(\"-inf\"))\n # -- compute m_ij, p, l_ij\n m_ij = tl.max(qk, 1)\n p = tl.exp(qk - m_ij[:, None])\n l_ij = tl.sum(p, 1)\n # -- update m_i and l_i\n m_i_new = tl.maximum(m_i, m_ij)\n alpha = tl.exp(m_i - m_i_new)\n beta = tl.exp(m_ij - m_i_new)\n l_i_new = alpha * l_i + beta * l_ij\n # -- update output accumulator --\n # scale p\n p_scale = beta / l_i_new\n p = p * p_scale[:, None]\n # scale acc\n acc_scale = l_i / l_i_new * alpha\n tl.store(t_ptrs, acc_scale)\n acc_scale = tl.load(t_ptrs) # BUG: have to store and immediately load\n acc = acc * acc_scale[:, None]\n # update acc\n v = tl.load(v_ptrs + start_n * stride_vk)\n p = p.to(v.dtype)\n acc += tl.dot(p, v)\n # update m_i and l_i\n l_i = l_i_new\n m_i = m_i_new\n # rematerialize offsets to save registers\n start_m = tl.program_id(0)\n offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)\n # write back l and m\n l_ptrs = L + off_hz * N_CTX + offs_m\n m_ptrs = M + off_hz * N_CTX + offs_m\n tl.store(l_ptrs, l_i)\n tl.store(m_ptrs, m_i)\n # initialize pointers to output\n offs_n = tl.arange(0, BLOCK_DMODEL)\n off_o = off_hz * stride_oh + offs_m[:, None] * stride_om + offs_n[None, :] * stride_on\n out_ptrs = Out + off_o\n tl.store(out_ptrs, acc)\n\n# Triton kernel for backward preprocessing\n@triton.jit\ndef _bwd_preprocess(\n Out, DO, L,\n NewDO, Delta,\n BLOCK_M: tl.constexpr, D_HEAD: tl.constexpr,\n):\n off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)\n off_n = tl.arange(0, D_HEAD)\n # load\n o = tl.load(Out + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32)\n do = tl.load(DO + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32)\n denom = tl.load(L + off_m).to(tl.float32)\n # compute\n do = do / denom[:, None]\n delta = tl.sum(o * do, axis=1)\n # write-back\n tl.store(NewDO + off_m[:, None] * D_HEAD + off_n[None, :], do)\n tl.store(Delta + off_m, delta)\n\n# Triton kernel for backward pass\n@triton.jit\ndef _bwd_kernel(\n Q, K, V, sm_scale, Out, DO,\n DQ, DK, DV,\n L, M,\n D,\n stride_qz, stride_qh, stride_qm, stride_qk,\n stride_kz, stride_kh, stride_kn, stride_kk,\n stride_vz, stride_vh, stride_vk, stride_vn,\n Z, H, N_CTX,\n num_block,\n BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,\n BLOCK_N: tl.constexpr,\n):\n off_hz = tl.program_id(0)\n off_z = off_hz // H\n off_h = off_hz % H\n # offset pointers for batch/head\n Q += off_z * stride_qz + off_h * stride_qh\n K += off_z * stride_qz + off_h * stride_qh\n V += off_z * stride_qz + off_h * stride_qh\n DO += off_z * stride_qz + off_h * stride_qh\n DQ += off_z * stride_qz + off_h * stride_qh\n DK += off_z * stride_qz + off_h * stride_qh\n DV += off_z * stride_qz + off_h * stride_qh\n for start_n in range(0, num_block):\n lo = start_n * BLOCK_M\n # initialize row/col offsets\n offs_qm = lo + tl.arange(0, BLOCK_M)\n offs_n = start_n * BLOCK_M + tl.arange(0, BLOCK_M)\n offs_m = tl.arange(0, BLOCK_N)\n offs_k = tl.arange(0, BLOCK_DMODEL)\n # initialize pointers to value-like data\n q_ptrs = Q + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk)\n k_ptrs = K + (offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk)\n v_ptrs = V + (offs_n[:, None] * stride_qm + offs_k[None, :] * stride_qk)\n do_ptrs = DO + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk)\n dq_ptrs = DQ + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk)\n # pointer to row-wise quantities in value-like data\n D_ptrs = D + off_hz * N_CTX\n m_ptrs = M + off_hz * N_CTX\n # initialize dv amd dk\n dv = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)\n dk = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)\n # k and v stay in SRAM throughout\n k = tl.load(k_ptrs)\n v = tl.load(v_ptrs)\n # loop over rows\n for start_m in range(lo, num_block * BLOCK_M, BLOCK_M):\n offs_m_curr = start_m + offs_m\n # load q, k, v, do on-chip\n q = tl.load(q_ptrs)\n # recompute p = softmax(qk, dim=-1).T\n # NOTE: `do` is pre-divided by `l`; no normalization here\n qk = tl.dot(q, k, trans_b=True)\n qk = tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), qk, float(\"-inf\"))\n m = tl.load(m_ptrs + offs_m_curr)\n p = tl.exp(qk * sm_scale - m[:, None])\n # compute dv\n do = tl.load(do_ptrs)\n dv += tl.dot(p.to(do.dtype), do, trans_a=True)\n # compute dp = dot(v, do)\n Di = tl.load(D_ptrs + offs_m_curr)\n dp = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - Di[:, None]\n dp += tl.dot(do, v, trans_b=True)\n # compute ds = p * (dp - delta[:, None])\n ds = p * dp * sm_scale\n # compute dk = dot(ds.T, q)\n dk += tl.dot(ds.to(q.dtype), q, trans_a=True)\n # # compute dq\n dq = tl.load(dq_ptrs, eviction_policy=\"evict_last\")\n dq += tl.dot(ds.to(k.dtype), k)\n tl.store(dq_ptrs, dq, eviction_policy=\"evict_last\")\n # # increment pointers\n dq_ptrs += BLOCK_M * stride_qm\n q_ptrs += BLOCK_M * stride_qm\n do_ptrs += BLOCK_M * stride_qm\n # write-back\n dv_ptrs = DV + (offs_n[:, None] * stride_qm + offs_k[None, :] * stride_qk)\n dk_ptrs = DK + (offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk)\n tl.store(dv_ptrs, dv)\n tl.store(dk_ptrs, dk)\n\n\n# PyTorch function for attention\nclass _attention(torch.autograd.Function):\n\n @staticmethod\n def forward(ctx, q, k, v, sm_scale):\n BLOCK = 128\n # shape constraints\n Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]\n assert Lq == Lk and Lk == Lv\n assert Lk in {16, 32, 64, 128}\n o = torch.empty_like(q)\n grid = (triton.cdiv(q.shape[2], BLOCK), q.shape[0] * q.shape[1])\n tmp = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)\n L = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)\n m = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)\n num_warps = 4 if Lk <= 64 else 8\n\n _fwd_kernel[grid](\n q, k, v, sm_scale,\n tmp, L, m,\n o,\n q.stride(0), q.stride(1), q.stride(2), q.stride(3),\n k.stride(0), k.stride(1), k.stride(2), k.stride(3),\n v.stride(0), v.stride(1), v.stride(2), v.stride(3),\n o.stride(0), o.stride(1), o.stride(2), o.stride(3),\n q.shape[0], q.shape[1], q.shape[2],\n BLOCK_M=BLOCK, BLOCK_N=BLOCK,\n BLOCK_DMODEL=Lk, num_warps=num_warps,\n num_stages=1,\n )\n ctx.save_for_backward(q, k, v, o, L, m)\n ctx.BLOCK = BLOCK\n ctx.grid = grid\n ctx.sm_scale = sm_scale\n ctx.BLOCK_DMODEL = Lk\n return o\n\n @staticmethod\n def backward(ctx, do):\n q, k, v, o, l, m = ctx.saved_tensors\n do = do.contiguous()\n dq = torch.zeros_like(q, dtype=torch.float32)\n dk = torch.empty_like(k)\n dv = torch.empty_like(v)\n do_scaled = torch.empty_like(do)\n delta = torch.empty_like(l)\n _bwd_preprocess[(ctx.grid[0] * ctx.grid[1], )](\n o, do, l,\n do_scaled, delta,\n BLOCK_M=ctx.BLOCK, D_HEAD=ctx.BLOCK_DMODEL,\n )\n\n # NOTE: kernel currently buggy for other values of `num_warps`\n num_warps = 8\n _bwd_kernel[(ctx.grid[1],)](\n q, k, v, ctx.sm_scale,\n o, do_scaled,\n dq, dk, dv,\n l, m,\n delta,\n q.stride(0), q.stride(1), q.stride(2), q.stride(3),\n k.stride(0), k.stride(1), k.stride(2), k.stride(3),\n v.stride(0), v.stride(1), v.stride(2), v.stride(3),\n q.shape[0], q.shape[1], q.shape[2],\n ctx.grid[0],\n BLOCK_M=ctx.BLOCK, BLOCK_N=ctx.BLOCK,\n BLOCK_DMODEL=ctx.BLOCK_DMODEL, num_warps=num_warps,\n num_stages=1,\n )\n return dq.to(q.dtype), dk, dv, None\n\n\nattention = _attention.apply\n", - "description_1": "Use triton language to implement a fused attention mechanism. This includes three kernel functions: `_fwd_kernel` which computes the forward pass of the attention mechanism, `_bwd_preprocess` which prepares data for the backward pass, and `_bwd_kernel` which computes the gradient updates. The `attention` function provides a PyTorch interface for these kernels. `_fwd_kernel` takes 24 parameters including input matrices Q, K, V, and several stride and size parameters. `_bwd_preprocess` and `_bwd_kernel` deal with gradient computations, taking in various intermediate buffers and stride parameters.", - "description_2": "Implement a fused attention mechanism in Triton with three key kernels: a forward pass kernel, a backward preprocessing kernel, and a backward pass kernel, all wrapped in a PyTorch function.", - "difficulty": 4 - }, - { - "code": "import triton\nimport triton.language as tl\n\n# Tanh kernel function\n@triton.jit\ndef tanh(x):\n # Tanh is just a scaled sigmoid\n return 2 * tl.sigmoid(2 * x) - 1\n\n# Cosh kernel function\n@triton.jit\ndef cosh(x):\n exp_x = tl.exp(x)\n return (exp_x + 1.0 / exp_x) * 0.5\n\n# ReLU kernel function\n@triton.jit\ndef relu(x):\n \"\"\"\n ReLU activation function\n \"\"\"\n zero = 0.0\n return tl.where(x >= 0, x, zero.to(x.dtype))\n\n# ReLU gradient kernel function\n@triton.jit\ndef relu_grad(x):\n # Return the upstream gradient directly\n zero = 0.0\n one = 1.0\n return tl.where(x >= 0, one.to(x.dtype), zero.to(x.dtype))\n\n# Squared ReLU kernel function\n@triton.jit\ndef squared_relu(x):\n \"\"\"\n Squared ReLU activation\n \"\"\"\n x_ = relu(x)\n return (x_ * x_).to(x.dtype)\n\n# Squared ReLU gradient kernel function\n@triton.jit\ndef squared_relu_grad(x):\n return tl.where(x >= 0, 2.0 * x, 0.0)\n\n# Leaky ReLU kernel function\n@triton.jit\ndef leaky_relu(x):\n \"\"\"\n LeakyReLU activation\n \"\"\"\n scale = 0.01 + 0.0\n scale = scale.to(x.dtype)\n return tl.where(x >= 0, x, scale * x)\n\n# Leaky ReLU gradient kernel function\n@triton.jit\ndef leaky_relu_grad(x):\n min_grad = 0.01\n max_grad = 1\n\n min_grad = min_grad.to(x.dtype)\n max_grad = max_grad.to(x.dtype)\n\n return tl.where(x >= 0, max_grad, min_grad)\n\n# GELU kernel function\n@triton.jit\ndef gelu(x):\n \"\"\"Gaussian Error Linear Unit (GELU)\"\"\"\n return x * 0.5 * (1.0 + tl.libdevice.erf(x * math.sqrt(1.0 / 2)))\n\n# GELU gradient kernel function\n@triton.jit\ndef gelu_grad(x):\n cdf = 0.5 * (1.0 + tl.libdevice.erf(x * math.sqrt(1.0 / 2)))\n pdf = tl.exp(-0.5 * x * x) * (1.0 / math.sqrt(2 * math.pi))\n return cdf + x * pdf\n\n# Approximate GELU kernel function\n@triton.jit\ndef gelu_approx(x):\n \"\"\"\n GeLU activation - Gaussian error linear unit, with tanh approximation\n \"\"\"\n return 0.5 * x * (1.0 + tanh(math.sqrt(2.0 / math.pi) * x * (1.0 + 0.044715 * x * x)))\n\n# Approximate GELU gradient kernel function\n@triton.jit\ndef gelu_approx_grad(x):\n # Fast implementation of GELU gradient\n tanh_out = tanh(0.79788456 * x * (1 + 0.044715 * x * x))\n return 0.5 * x * (\n (1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)\n ) + 0.5 * (1 + tanh_out)\n", - "description_1": "Use triton language to implement various activation functions and their gradients including ReLU, squared ReLU, leaky ReLU, GELU, and approximate GELU, each taking one argument 'x' as input tensor.", - "description_2": "Use triton language to create activation functions (ReLU, GELU) and compute their gradients.", - "difficulty": 1 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\nfrom flash_attn.ops.triton.k_activations import gelu, gelu_approx, squared_relu\nfrom flash_attn.ops.triton.k_activations import gelu_grad, gelu_approx_grad, squared_relu_grad\n\n@triton.autotune(\n configs=[\n triton.Config({\"BLOCK_M\": 128, \"BLOCK_N\": 256, \"BLOCK_K\": 32, \"SPLIT_K\": 1}, num_stages=3, num_warps=8),\n triton.Config({\"BLOCK_M\": 256, \"BLOCK_N\": 128, \"BLOCK_K\": 32, \"SPLIT_K\": 1}, num_stages=3, num_warps=8),\n triton.Config({\"BLOCK_M\": 256, \"BLOCK_N\": 64, \"BLOCK_K\": 32, \"SPLIT_K\": 1}, num_stages=4, num_warps=4),\n triton.Config({\"BLOCK_M\": 64, \"BLOCK_N\": 256, \"BLOCK_K\": 32, \"SPLIT_K\": 1}, num_stages=4, num_warps=4),\n triton.Config({\"BLOCK_M\": 128, \"BLOCK_N\": 128, \"BLOCK_K\": 32, \"SPLIT_K\": 1}, num_stages=4, num_warps=4),\n triton.Config({\"BLOCK_M\": 128, \"BLOCK_N\": 64, \"BLOCK_K\": 32, \"SPLIT_K\": 1}, num_stages=4, num_warps=4),\n triton.Config({\"BLOCK_M\": 64, \"BLOCK_N\": 128, \"BLOCK_K\": 32, \"SPLIT_K\": 1}, num_stages=4, num_warps=4),\n triton.Config({\"BLOCK_M\": 128, \"BLOCK_N\": 32, \"BLOCK_K\": 32, \"SPLIT_K\": 1}, num_stages=4, num_warps=4),\n triton.Config({\"BLOCK_M\": 64, \"BLOCK_N\": 32, \"BLOCK_K\": 32, \"SPLIT_K\": 1}, num_stages=5, num_warps=2),\n triton.Config({\"BLOCK_M\": 128, \"BLOCK_N\": 256, \"BLOCK_K\": 128, \"SPLIT_K\": 1}, num_stages=3, num_warps=8),\n triton.Config({\"BLOCK_M\": 256, \"BLOCK_N\": 128, \"BLOCK_K\": 128, \"SPLIT_K\": 1}, num_stages=3, num_warps=8),\n triton.Config({\"BLOCK_M\": 256, \"BLOCK_N\": 64, \"BLOCK_K\": 128, \"SPLIT_K\": 1}, num_stages=4, num_warps=4),\n triton.Config({\"BLOCK_M\": 64, \"BLOCK_N\": 256, \"BLOCK_K\": 128, \"SPLIT_K\": 1}, num_stages=4, num_warps=4),\n triton.Config({\"BLOCK_M\": 128, \"BLOCK_N\": 128, \"BLOCK_K\": 128, \"SPLIT_K\": 1}, num_stages=4, num_warps=4),\n triton.Config({\"BLOCK_M\": 128, \"BLOCK_N\": 64, \"BLOCK_K\": 64, \"SPLIT_K\": 1}, num_stages=4, num_warps=4),\n triton.Config({\"BLOCK_M\": 64, \"BLOCK_N\": 128, \"BLOCK_K\": 64, \"SPLIT_K\": 1}, num_stages=4, num_warps=4),\n triton.Config({\"BLOCK_M\": 128, \"BLOCK_N\": 32, \"BLOCK_K\": 64, \"SPLIT_K\": 1}, num_stages=4, num_warps=4),\n triton.Config({\"BLOCK_M\": 64, \"BLOCK_N\": 32, \"BLOCK_K\": 64, \"SPLIT_K\": 1}, num_stages=5, num_warps=2),\n ],\n key=[\"CACHE_KEY_M\", \"CACHE_KEY_N\", \"CACHE_KEY_K\"],\n prune_configs_by={\"early_config_prune\": early_config_prune, \"perf_model\": estimate_matmul_time, \"top_k\": 10},\n)\n@triton.heuristics(\n {\n \"EVEN_K\": lambda args: args[\"K\"] % (args[\"BLOCK_K\"] * args[\"SPLIT_K\"]) == 0,\n }\n)\n@triton.jit\ndef kernel_fwd(\n C,\n ACT_INPUT,\n A,\n B,\n bias,\n M,\n N,\n K,\n CACHE_KEY_M,\n CACHE_KEY_N,\n CACHE_KEY_K,\n stride_cm,\n stride_am,\n stride_ak,\n stride_bn,\n stride_bk,\n BLOCK_M: tl.constexpr,\n GROUP_M: tl.constexpr,\n BLOCK_N: tl.constexpr,\n BLOCK_K: tl.constexpr,\n SPLIT_K: tl.constexpr,\n EVEN_K: tl.constexpr,\n A_ROWMAJOR: tl.constexpr,\n B_COLMAJOR: tl.constexpr,\n BIAS: tl.constexpr,\n SAVE_ACT_INPUT: tl.constexpr,\n ACTIVATION: tl.constexpr,\n):\n pid = tl.program_id(axis=0)\n grid_m = (M + BLOCK_M - 1) // BLOCK_M\n grid_n = (N + BLOCK_N - 1) // BLOCK_N\n width = GROUP_M * grid_n\n group_id = pid // width\n group_size = min(grid_m - group_id * GROUP_M, GROUP_M)\n pid_m = group_id * GROUP_M + (pid % group_size)\n pid_n = (pid % width) // (group_size)\n rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)\n rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)\n ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)\n rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)\n rk = tl.arange(0, BLOCK_K)\n if A_ROWMAJOR:\n A = A + (ram[:, None] * stride_am + rk[None, :])\n else:\n A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)\n if B_COLMAJOR:\n B = B + (rk[:, None] + rbn[None, :] * stride_bn)\n else:\n B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)\n acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)\n for k in range(K, 0, -BLOCK_K):\n if EVEN_K:\n a = tl.load(A)\n b = tl.load(B)\n else:\n a = tl.load(A, mask=rk[None, :] < k, other=0.0)\n b = tl.load(B, mask=rk[:, None] < k, other=0.0)\n acc += tl.dot(a, b)\n if A_ROWMAJOR:\n A += BLOCK_K\n else:\n A += BLOCK_K * stride_ak\n if B_COLMAJOR:\n B += BLOCK_K\n else:\n B += BLOCK_K * stride_bk\n if BIAS:\n bias = tl.load(bias + rn, mask=rn < N, other=0.0).to(tl.float32)\n acc += bias[None, :]\n if SAVE_ACT_INPUT:\n act_in_ptrs = ACT_INPUT + ram[:, None] * stride_cm + rbn[None, :]\n tl.store(act_in_ptrs, acc)\n if ACTIVATION == \"gelu\":\n acc = gelu(acc)\n elif ACTIVATION == \"gelu_approx\":\n acc = gelu_approx(acc)\n elif ACTIVATION == \"squared_relu\":\n acc = squared_relu(acc)\n rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)\n rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)\n C = C + rm[:, None] * stride_cm + rn[None, :]\n mask = (rm < M)[:, None] & (rn < N)[None, :]\n tl.store(C, acc)\n\ndef triton_linear_act(\n x: torch.Tensor,\n weight: torch.Tensor,\n bias: Optional[torch.Tensor] = None,\n activation: str = 'id',\n save_act_input: bool = False,\n) -> torch.Tensor:\n assert activation in ['id', 'gelu', 'gelu_approx', 'squared_relu']\n batch_shape, n = x.shape[:-1], x.shape[-1]\n batch_dim = batch_shape.numel()\n x_reshaped = x.reshape(batch_dim, n)\n if x_reshaped.stride(0) > 1 and x_reshaped.stride(1) > 1:\n x_reshaped = x_reshaped.contiguous()\n if weight.stride(0) > 1 and weight.stride(1) > 1:\n weight = weight.contiguous()\n bias = bias.contiguous() if bias is not None else None\n assert x.dtype == weight.dtype, f\"Input and weight must have the same dtype, got {x.dtype} and {weight.dtype}\"\n if bias is not None:\n assert x.dtype == bias.dtype, f\"Input and bias must have the same dtype, got {x.dtype} and {bias.dtype}\"\n assert x_reshaped.shape[1] == weight.shape[1], f\"Incompatible dimensions: {x_reshaped.shape} - {weight.shape}\"\n assert bias is None or bias.shape[0] == weight.shape[0], \"Incompatible dimensions in between weight and bias\"\n M, K = x_reshaped.shape\n N, K = weight.shape\n output = torch.empty((M, N), device=x.device, dtype=x.dtype)\n act_input = torch.empty_like(output) if save_act_input else None\n grid = lambda META: (triton.cdiv(M, META[\"BLOCK_M\"]) * triton.cdiv(N, META[\"BLOCK_N\"]),)\n kernel_fwd[grid](\n output,\n act_input,\n x_reshaped,\n weight,\n bias if bias is not None else x,\n M,\n N,\n K,\n M // 32,\n N // 32,\n K // 32,\n stride_cm=output.stride(0),\n stride_am=x_reshaped.stride(0),\n stride_ak=x_reshaped.stride(1),\n stride_bk=weight.stride(1),\n stride_bn=weight.stride(0),\n BIAS=bias is not None,\n SAVE_ACT_INPUT=save_act_input,\n ACTIVATION=activation,\n A_ROWMAJOR=x_reshaped.stride(1) == 1,\n B_COLMAJOR=weight.stride(1) == 1,\n GROUP_M=8,\n )\n if not save_act_input:\n return output.reshape(*batch_shape, output.shape[-1])\n else:\n return (output.reshape(*batch_shape, output.shape[-1]),\n act_input.reshape(*batch_shape, act_input.shape[-1]))\n\n\n@triton.autotune(\n configs=[\n triton.Config({\"BLOCK_M\": 128, \"BLOCK_N\": 256, \"BLOCK_K\": 32, \"SPLIT_K\": 1}, num_stages=3, num_warps=8),\n triton.Config({\"BLOCK_M\": 256, \"BLOCK_N\": 128, \"BLOCK_K\": 32, \"SPLIT_K\": 1}, num_stages=3, num_warps=8),\n triton.Config({\"BLOCK_M\": 256, \"BLOCK_N\": 64, \"BLOCK_K\": 32, \"SPLIT_K\": 1}, num_stages=4, num_warps=4),\n triton.Config({\"BLOCK_M\": 64, \"BLOCK_N\": 256, \"BLOCK_K\": 32, \"SPLIT_K\": 1}, num_stages=4, num_warps=4),\n triton.Config({\"BLOCK_M\": 128, \"BLOCK_N\": 128, \"BLOCK_K\": 32, \"SPLIT_K\": 1}, num_stages=4, num_warps=4),\n triton.Config({\"BLOCK_M\": 128, \"BLOCK_N\": 64, \"BLOCK_K\": 32, \"SPLIT_K\": 1}, num_stages=4, num_warps=4),\n triton.Config({\"BLOCK_M\": 64, \"BLOCK_N\": 128, \"BLOCK_K\": 32, \"SPLIT_K\": 1}, num_stages=4, num_warps=4),\n triton.Config({\"BLOCK_M\": 128, \"BLOCK_N\": 32, \"BLOCK_K\": 32, \"SPLIT_K\": 1}, num_stages=4, num_warps=4),\n triton.Config({\"BLOCK_M\": 64, \"BLOCK_N\": 32, \"BLOCK_K\": 32, \"SPLIT_K\": 1}, num_stages=5, num_warps=2),\n triton.Config({\"BLOCK_M\": 128, \"BLOCK_N\": 256, \"BLOCK_K\": 128, \"SPLIT_K\": 1}, num_stages=3, num_warps=8),\n triton.Config({\"BLOCK_M\": 256, \"BLOCK_N\": 128, \"BLOCK_K\": 128, \"SPLIT_K\": 1}, num_stages=3, num_warps=8),\n triton.Config({\"BLOCK_M\": 256, \"BLOCK_N\": 64, \"BLOCK_K\": 128, \"SPLIT_K\": 1}, num_stages=4, num_warps=4),\n triton.Config({\"BLOCK_M\": 64, \"BLOCK_N\": 256, \"BLOCK_K\": 128, \"SPLIT_K\": 1}, num_stages=4, num_warps=4),\n triton.Config({\"BLOCK_M\": 128, \"BLOCK_N\": 128, \"BLOCK_K\": 128, \"SPLIT_K\": 1}, num_stages=4, num_warps=4),\n triton.Config({\"BLOCK_M\": 128, \"BLOCK_N\": 64, \"BLOCK_K\": 64, \"SPLIT_K\": 1}, num_stages=4, num_warps=4),\n triton.Config({\"BLOCK_M\": 64, \"BLOCK_N\": 128, \"BLOCK_K\": 64, \"SPLIT_K\": 1}, num_stages=4, num_warps=4),\n triton.Config({\"BLOCK_M\": 128, \"BLOCK_N\": 32, \"BLOCK_K\": 64, \"SPLIT_K\": 1}, num_stages=4, num_warps=4),\n triton.Config({\"BLOCK_M\": 64, \"BLOCK_N\": 32, \"BLOCK_K\": 64, \"SPLIT_K\": 1}, num_stages=5, num_warps=2),\n ],\n key=[\"CACHE_KEY_M\", \"CACHE_KEY_N\", \"CACHE_KEY_K\"],\n prune_configs_by={\"early_config_prune\": early_config_prune, \"perf_model\": estimate_matmul_time, \"top_k\": 10},\n)\n@triton.heuristics(\n {\n \"EVEN_K\": lambda args: args[\"K\"] % (args[\"BLOCK_K\"] * args[\"SPLIT_K\"]) == 0,\n }\n)\n@triton.jit\ndef kernel_bwd(\n C,\n ACT_INPUT,\n A,\n B,\n M,\n N,\n K,\n CACHE_KEY_M,\n CACHE_KEY_N,\n CACHE_KEY_K,\n stride_cm,\n stride_am,\n stride_ak,\n stride_bk,\n stride_bn,\n BLOCK_M: tl.constexpr,\n GROUP_M: tl.constexpr,\n BLOCK_N: tl.constexpr,\n BLOCK_K: tl.constexpr,\n SPLIT_K: tl.constexpr,\n EVEN_K: tl.constexpr,\n ACTIVATION: tl.constexpr,\n):\n pid = tl.program_id(axis=0)\n grid_m = (M + BLOCK_M - 1) // BLOCK_M\n grid_n = (N + BLOCK_N - 1) // BLOCK_N\n width = GROUP_M * grid_n\n group_id = pid // width\n group_size = min(grid_m - group_id * GROUP_M, GROUP_M)\n pid_m = group_id * GROUP_M + (pid % group_size)\n pid_n = (pid % width) // (group_size)\n rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)\n rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)\n ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)\n rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)\n rk = tl.arange(0, BLOCK_K)\n A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)\n B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)\n acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)\n for k in range(K, 0, -BLOCK_K):\n if EVEN_K:\n a = tl.load(A)\n b = tl.load(B)\n else:\n a = tl.load(A, mask=rk[None, :] < k, other=0.0)\n b = tl.load(B, mask=rk[:, None] < k, other=0.0)\n acc += tl.dot(a, b)\n A += BLOCK_K * stride_ak\n B += BLOCK_K * stride_bk\n if ACTIVATION != 'id':\n act_in_ptrs = ACT_INPUT + ram[:, None] * stride_cm + rbn[None, :]\n act_input = tl.load(act_in_ptrs).to(acc.dtype)\n if ACTIVATION == \"gelu\":\n acc *= gelu_grad(act_input)\n elif ACTIVATION == \"gelu_approx\":\n acc *= gelu_approx_grad(act_input)\n elif ACTIVATION == \"squared_relu\":\n acc *= squared_relu_grad(act_input)\n rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)\n rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)\n C = C + rm[:, None] * stride_cm + rn[None, :]\n mask = (rm < M)[:, None] & (rn < N)[None, :]\n tl.store(C, acc, mask=mask)\n\ndef triton_dgrad_act(\n grad_output: torch.Tensor,\n weight: torch.Tensor,\n activation: str = 'id',\n act_input: Optional[torch.Tensor] = None,\n) -> torch.Tensor:\n assert activation in ['id', 'gelu', 'gelu_approx', 'squared_relu']\n batch_shape, n = grad_output.shape[:-1], grad_output.shape[-1]\n batch_dim = batch_shape.numel()\n grad_output_reshaped = grad_output.reshape(batch_dim, n)\n if grad_output_reshaped.stride(0) > 1 and grad_output_reshaped.stride(1) > 1:\n grad_output_reshaped = grad_output_reshaped.contiguous()\n if weight.stride(0) > 1 and weight.stride(1) > 1:\n weight = weight.contiguous()\n assert grad_output.dtype == weight.dtype, f\"grad_output and weight must have the same dtype, got {grad_output.dtype} and {weight.dtype}\"\n assert grad_output_reshaped.shape[1] == weight.shape[0], f\"Incompatible dimensions: {grad_output_reshaped.shape} - {weight.shape}\"\n if activation != 'id':\n assert act_input is not None, f'act_input is required for activation {activation}'\n M, K = grad_output_reshaped.shape\n K, N = weight.shape\n grad_input = torch.empty((M, N), device=grad_output.device, dtype=grad_output.dtype)\n grid = lambda META: (triton.cdiv(M, META[\"BLOCK_M\"]) * triton.cdiv(N, META[\"BLOCK_N\"]),)\n kernel_bwd[grid](\n grad_input,\n act_input,\n grad_output_reshaped,\n weight,\n M,\n N,\n K,\n M // 32,\n N // 32,\n K // 32,\n stride_cm=grad_input.stride(0),\n stride_am=grad_output_reshaped.stride(0),\n stride_ak=grad_output_reshaped.stride(1),\n stride_bk=weight.stride(0),\n stride_bn=weight.stride(1),\n ACTIVATION=activation,\n GROUP_M=8,\n )\n return grad_input.reshape(*batch_shape, grad_input.shape[-1])\n", - "description_1": "Use triton language to implement a forward and backward kernel for matrix multiplication with activation support. The kernels use 29 parameters each to specify input/output matrices, strides, dimensions, and various configuration flags for optimized computation.", - "description_2": "Use triton language to create a wrapper function that calls the optimized forward and backward matrix multiplication kernels, handling tensor reshaping, activation, and optional input saving.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.autotune(\n configs=[\n triton.Config({'BT': 16}, num_warps=2),\n triton.Config({'BT': 16}, num_warps=4),\n triton.Config({'BT': 16}, num_warps=8),\n triton.Config({'BT': 32}, num_warps=2),\n triton.Config({'BT': 32}, num_warps=4),\n triton.Config({'BT': 32}, num_warps=8),\n triton.Config({'BT': 64}, num_warps=2),\n triton.Config({'BT': 64}, num_warps=4),\n triton.Config({'BT': 64}, num_warps=8),\n triton.Config({'BT': 128}, num_warps=2),\n triton.Config({'BT': 128}, num_warps=4),\n triton.Config({'BT': 128}, num_warps=8),\n triton.Config({'BT': 256}, num_warps=2),\n triton.Config({'BT': 256}, num_warps=4),\n triton.Config({'BT': 256}, num_warps=8)\n ],\n key=['D']\n)\n@triton.jit\ndef logsigmoid_fwd_kernel(\n x,\n y,\n T: tl.constexpr,\n D: tl.constexpr,\n BT: tl.constexpr\n):\n i = tl.program_id(0)\n o_i = i * BT + tl.arange(0, BT)\n\n p_x = x + o_i\n p_y = y + o_i\n mask = o_i < T\n\n # [D,]\n b_x = tl.load(p_x, mask=mask, other=0.).to(tl.float32)\n b_m = tl.minimum(0., b_x)\n b_z = 1. + tl.exp(-tl.abs(b_x))\n b_y = b_m - tl.log(b_z)\n tl.store(p_y, b_y.to(p_y.dtype.element_ty), mask=mask)\n\n\n@triton.autotune(\n configs=[\n triton.Config({'BT': 16}, num_warps=2),\n triton.Config({'BT': 16}, num_warps=4),\n triton.Config({'BT': 16}, num_warps=8),\n triton.Config({'BT': 32}, num_warps=2),\n triton.Config({'BT': 32}, num_warps=4),\n triton.Config({'BT': 32}, num_warps=8),\n triton.Config({'BT': 64}, num_warps=2),\n triton.Config({'BT': 64}, num_warps=4),\n triton.Config({'BT': 64}, num_warps=8),\n triton.Config({'BT': 128}, num_warps=2),\n triton.Config({'BT': 128}, num_warps=4),\n triton.Config({'BT': 128}, num_warps=8),\n triton.Config({'BT': 256}, num_warps=2),\n triton.Config({'BT': 256}, num_warps=4),\n triton.Config({'BT': 256}, num_warps=8)\n ],\n key=['D']\n)\n@triton.jit\ndef logsigmoid_bwd_kernel(\n x,\n dx,\n dy,\n T: tl.constexpr,\n D: tl.constexpr,\n BT: tl.constexpr\n):\n i = tl.program_id(0)\n o_i = i * BT + tl.arange(0, BT)\n\n p_x = x + o_i\n p_dx = dx + o_i\n p_dy = dy + o_i\n mask = o_i < T\n\n # [D,]\n b_x = tl.load(p_x, mask=mask, other=0.).to(tl.float32)\n b_dy = tl.load(p_dy, mask=mask, other=0.).to(tl.float32)\n b_dx = b_dy * (1. - tl.sigmoid(b_x))\n tl.store(p_dx, b_dx.to(p_dx.dtype.element_ty), mask=mask)\n\n\nclass LogSigmoidFunction(torch.autograd.Function):\n\n @staticmethod\n def forward(ctx, x):\n T, D = x.numel(), x.shape[-1]\n y = torch.empty_like(x)\n logsigmoid_fwd_kernel[lambda meta: (triton.cdiv(meta['T'], meta['D']),)](x, y, T=T, D=D)\n ctx.save_for_backward(x,)\n return y\n\n @staticmethod\n def backward(ctx, dy):\n x, = ctx.saved_tensors\n T, D = x.numel(), x.shape[-1]\n dx = torch.empty_like(x)\n logsigmoid_bwd_kernel[lambda meta: (triton.cdiv(meta['T'], meta['D']),)](x, dx, dy, T=T, D=D)\n return dx\n\n\nlogsigmoid = LogSigmoidFunction.apply\n", - "description_1": "Use triton language to implement two kernels, logsigmoid_fwd_kernel and logsigmoid_bwd_kernel, for forward and backward pass of logsigmoid activation function. logsigmoid_fwd_kernel takes five parameters: x (input tensor), y (output tensor), T (total number of elements in input), D (dimension), BT (block size) and computes logsigmoid. logsigmoid_bwd_kernel takes six parameters: x (input tensor), dx (gradient tensor to output), dy (gradient from next layer), T (total number of elements in input), D (dimension), BT (block size) to compute gradient wrt input.", - "description_2": "Implement a Triton-based logsigmoid function, providing both forward and backward operations with optimized kernel configurations.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.autotune(\n configs=[\n triton.Config({}, num_warps=1),\n triton.Config({}, num_warps=2),\n triton.Config({}, num_warps=4),\n triton.Config({}, num_warps=8),\n triton.Config({}, num_warps=16),\n triton.Config({}, num_warps=32),\n ],\n key=[\"N\", \"HAS_RESIDUAL\", \"STORE_RESIDUAL_OUT\", \"IS_RMS_NORM\", \"HAS_BIAS\"],\n)\n@triton.jit\ndef _layer_norm_fwd_quant_kernel(\n X, # pointer to the input\n Y, # pointer to the output\n W, # pointer to the weights\n B, # pointer to the biases\n RESIDUAL, # pointer to the residual\n RESIDUAL_OUT, # pointer to the residual\n Mean, # pointer to the mean\n Rstd, # pointer to the 1/std\n stride_x_row, # how much to increase the pointer when moving by 1 row\n stride_y_row,\n stride_res_row,\n stride_res_out_row,\n N, # number of columns in X\n eps, # epsilon to avoid division by zero\n IS_RMS_NORM: tl.constexpr,\n BLOCK_N: tl.constexpr,\n HAS_RESIDUAL: tl.constexpr,\n STORE_RESIDUAL_OUT: tl.constexpr,\n HAS_WEIGHT: tl.constexpr,\n HAS_BIAS: tl.constexpr\n):\n row = tl.program_id(0)\n X += row * stride_x_row\n Y += row * stride_y_row\n if HAS_RESIDUAL:\n RESIDUAL += row * stride_res_row\n if STORE_RESIDUAL_OUT:\n RESIDUAL_OUT += row * stride_res_out_row\n cols = tl.arange(0, BLOCK_N)\n x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)\n if HAS_RESIDUAL:\n residual = tl.load(RESIDUAL + cols, mask=cols < N, other=0.0).to(tl.float32)\n x += residual\n if STORE_RESIDUAL_OUT:\n tl.store(RESIDUAL_OUT + cols, x, mask=cols < N)\n if not IS_RMS_NORM:\n mean = tl.sum(x, axis=0) / N\n tl.store(Mean + row, mean)\n xbar = tl.where(cols < N, x - mean, 0.0)\n var = tl.sum(xbar * xbar, axis=0) / N\n else:\n xbar = tl.where(cols < N, x, 0.0)\n var = tl.sum(xbar * xbar, axis=0) / N\n rstd = 1 / tl.sqrt(var + eps)\n tl.store(Rstd + row, rstd)\n mask = cols < N\n if HAS_WEIGHT:\n w = tl.load(W + cols, mask=mask).to(tl.float32)\n if HAS_BIAS:\n b = tl.load(B + cols, mask=mask).to(tl.float32)\n x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd\n y = x_hat * w if HAS_WEIGHT else x_hat\n if HAS_BIAS:\n y = y + b\n scale = 127.0 / tl.maximum(tl.max(tl.abs(y), 0), 1e-5)\n y = tl.math.round(y * scale)\n y = tl.maximum(tl.minimum(y, 127), -128) / scale\n tl.store(Y + cols, y, mask=mask)\n\n\ndef _layer_norm_fwd_quant(\n x, weight, bias, eps, residual=None, out_dtype=None, residual_dtype=None, is_rms_norm=False\n):\n if residual is not None:\n residual_dtype = residual.dtype\n M, N = x.shape\n y = torch.empty_like(x, dtype=x.dtype if out_dtype is None else out_dtype)\n if residual is not None or (residual_dtype is not None and residual_dtype != x.dtype):\n residual_out = torch.empty(M, N, device=x.device, dtype=residual_dtype)\n else:\n residual_out = None\n mean = torch.empty((M,), dtype=torch.float32, device=\"cuda\") if not is_rms_norm else None\n rstd = torch.empty((M,), dtype=torch.float32, device=\"cuda\")\n MAX_FUSED_SIZE = 65536 // x.element_size()\n BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))\n if N > BLOCK_N:\n raise RuntimeError(\"This layer norm doesn't support feature dim >= 64KB.\")\n with torch.cuda.device(x.device.index):\n _layer_norm_fwd_quant_kernel[(M,)](\n x,\n y,\n weight,\n bias,\n residual,\n residual_out,\n mean,\n rstd,\n x.stride(0),\n y.stride(0),\n residual.stride(0) if residual is not None else 0,\n residual_out.stride(0) if residual_out is not None else 0,\n N,\n eps,\n is_rms_norm,\n BLOCK_N,\n residual is not None,\n residual_out is not None,\n weight is not None,\n bias is not None,\n )\n return y, mean, rstd, residual_out if residual_out is not None else x\n\n@triton.autotune(\n configs=[\n triton.Config({}, num_warps=1),\n triton.Config({}, num_warps=2),\n triton.Config({}, num_warps=4),\n triton.Config({}, num_warps=8),\n triton.Config({}, num_warps=16),\n triton.Config({}, num_warps=32),\n ],\n key=[\"N\", \"HAS_DRESIDUAL\", \"STORE_DRESIDUAL\", \"IS_RMS_NORM\", \"HAS_BIAS\"],\n)\n@triton.heuristics({\"RECOMPUTE_OUTPUT\": lambda args: args[\"Y\"] is not None})\n@triton.jit\ndef _layer_norm_bwd_kernel(\n X, # pointer to the input\n W, # pointer to the weights\n B, # pointer to the biases\n Y, # pointer to the output to be recomputed\n DY, # pointer to the output gradient\n DX, # pointer to the input gradient\n DW, # pointer to the partial sum of weights gradient\n DB, # pointer to the partial sum of biases gradient\n DRESIDUAL,\n DRESIDUAL_IN,\n Mean, # pointer to the mean\n Rstd, # pointer to the 1/std\n stride_x_row, # how much to increase the pointer when moving by 1 row\n stride_y_row,\n stride_dy_row,\n stride_dx_row,\n stride_dres_row,\n stride_dres_in_row,\n M, # number of rows in X\n N, # number of columns in X\n eps, # epsilon to avoid division by zero\n rows_per_program,\n IS_RMS_NORM: tl.constexpr,\n BLOCK_N: tl.constexpr,\n HAS_DRESIDUAL: tl.constexpr,\n STORE_DRESIDUAL: tl.constexpr,\n HAS_WEIGHT: tl.constexpr,\n HAS_BIAS: tl.constexpr,\n RECOMPUTE_OUTPUT: tl.constexpr,\n):\n row_block_id = tl.program_id(0)\n row_start = row_block_id * rows_per_program\n cols = tl.arange(0, BLOCK_N)\n mask = cols < N\n X += row_start * stride_x_row\n if HAS_DRESIDUAL:\n DRESIDUAL += row_start * stride_dres_row\n if STORE_DRESIDUAL:\n DRESIDUAL_IN += row_start * stride_dres_in_row\n DY += row_start * stride_dy_row\n DX += row_start * stride_dx_row\n if RECOMPUTE_OUTPUT:\n Y += row_start * stride_y_row\n if HAS_WEIGHT:\n w = tl.load(W + cols, mask=mask).to(tl.float32)\n dw = tl.zeros((BLOCK_N,), dtype=tl.float32)\n if RECOMPUTE_OUTPUT and HAS_BIAS:\n b = tl.load(B + cols, mask=mask, other=0.0).to(tl.float32)\n if HAS_BIAS:\n db = tl.zeros((BLOCK_N,), dtype=tl.float32)\n row_end = min((row_block_id + 1) * rows_per_program, M)\n for row in range(row_start, row_end):\n x = tl.load(X + cols, mask=mask, other=0).to(tl.float32)\n dy = tl.load(DY + cols, mask=mask, other=0).to(tl.float32)\n if not IS_RMS_NORM:\n mean = tl.load(Mean + row)\n rstd = tl.load(Rstd + row)\n xhat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd\n xhat = tl.where(mask, xhat, 0.0)\n if RECOMPUTE_OUTPUT:\n y = xhat * w if HAS_WEIGHT else xhat\n if HAS_BIAS:\n y = y + b\n scale = 127.0 / tl.maximum(tl.max(tl.abs(y), 0), 1e-5)\n y = tl.math.round(y * scale)\n y = tl.maximum(tl.minimum(y, 127), -128) / scale\n tl.store(Y + cols, y, mask=mask)\n wdy = dy\n if HAS_WEIGHT:\n wdy = dy * w\n dw += dy * xhat\n if HAS_BIAS:\n db += dy\n if not IS_RMS_NORM:\n c1 = tl.sum(xhat * wdy, axis=0) / N\n c2 = tl.sum(wdy, axis=0) / N\n dx = (wdy - (xhat * c1 + c2)) * rstd\n else:\n c1 = tl.sum(xhat * wdy, axis=0) / N\n dx = (wdy - xhat * c1) * rstd\n if HAS_DRESIDUAL:\n dres = tl.load(DRESIDUAL + cols, mask=mask, other=0).to(tl.float32)\n dx += dres\n if STORE_DRESIDUAL:\n tl.store(DRESIDUAL_IN + cols, dx, mask=mask)\n tl.store(DX + cols, dx, mask=mask)\n X += stride_x_row\n if HAS_DRESIDUAL:\n DRESIDUAL += stride_dres_row\n if STORE_DRESIDUAL:\n DRESIDUAL_IN += stride_dres_in_row\n if RECOMPUTE_OUTPUT:\n Y += stride_y_row\n DY += stride_dy_row\n DX += stride_dx_row\n if HAS_WEIGHT:\n tl.store(DW + row_block_id * N + cols, dw, mask=mask)\n if HAS_BIAS:\n tl.store(DB + row_block_id * N + cols, db, mask=mask)\n\n\ndef _layer_norm_bwd(\n dy,\n x,\n weight,\n bias,\n eps,\n mean,\n rstd,\n dresidual=None,\n has_residual=False,\n is_rms_norm=False,\n x_dtype=None,\n recompute_output=False,\n):\n M, N = x.shape\n dx = torch.empty_like(x) if x_dtype is None else torch.empty(M, N, dtype=x_dtype, device=x.device)\n dresidual_in = torch.empty_like(x) if has_residual and dx.dtype != x.dtype else None\n y = torch.empty(M, N, dtype=dy.dtype, device=dy.device) if recompute_output else None\n MAX_FUSED_SIZE = 65536 // x.element_size()\n BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))\n if N > BLOCK_N:\n raise RuntimeError(\"This layer norm doesn't support feature dim >= 64KB.\")\n sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count\n _dw = torch.empty((sm_count, N), dtype=torch.float32, device=weight.device) if weight is not None else None\n _db = torch.empty((sm_count, N), dtype=torch.float32, device=bias.device) if bias is not None else None\n rows_per_program = math.ceil(M / sm_count)\n grid = (sm_count,)\n with torch.cuda.device(x.device.index):\n _layer_norm_bwd_kernel[grid](\n x,\n weight,\n bias,\n y,\n dy,\n dx,\n _dw,\n _db,\n dresidual,\n dresidual_in,\n mean,\n rstd,\n x.stride(0),\n 0 if not recompute_output else y.stride(0),\n dy.stride(0),\n dx.stride(0),\n dresidual.stride(0) if dresidual is not None else 0,\n dresidual_in.stride(0) if dresidual_in is not None else 0,\n M,\n N,\n eps,\n rows_per_program,\n is_rms_norm,\n BLOCK_N,\n dresidual is not None,\n dresidual_in is not None,\n weight is not None,\n bias is not None,\n )\n dw = _dw.sum(0).to(weight.dtype) if weight is not None else None\n db = _db.sum(0).to(bias.dtype) if bias is not None else None\n if has_residual and dx.dtype == x.dtype:\n dresidual_in = dx\n return (dx, dw, db, dresidual_in) if not recompute_output else (dx, dw, db, dresidual_in, y)\n", - "description_1": "Use triton language to create and execute a fused kernel for layer normalization with quantization. This includes a forward kernel `_layer_norm_fwd_quant_kernel` that applies layer normalization and quantizes the output, and a backward kernel `_layer_norm_bwd_kernel` that computes gradients for the input, weights, and bias. The forward function `_layer_norm_fwd_quant` sets up and launches the forward kernel, while the backward function `_layer_norm_bwd` manages gradient computations and launches the backward kernel.", - "description_2": "Use triton language to implement fused layer norm with quantization and its gradient computation, handling conditions such as having residuals, bias, and different precision settings.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.autotune(\n configs=[\n triton.Config({}, num_warps=1),\n triton.Config({}, num_warps=2),\n triton.Config({}, num_warps=4),\n triton.Config({}, num_warps=8),\n triton.Config({}, num_warps=16),\n triton.Config({}, num_warps=32),\n ],\n key=[\"N\", \"HAS_RESIDUAL\", \"STORE_RESIDUAL_OUT\", \"IS_RMS_NORM\", \"HAS_BIAS\"],\n)\n@triton.jit\ndef _layer_norm_fwd_1pass_kernel(\n X, O, Y, W, B, RESIDUAL, RESIDUAL_OUT, Mean, Rstd,\n stride_x_row, stride_y_row, stride_res_row, stride_res_out_row,\n N, eps, IS_RMS_NORM: tl.constexpr, BLOCK_N: tl.constexpr,\n HAS_RESIDUAL: tl.constexpr, STORE_RESIDUAL_OUT: tl.constexpr,\n HAS_WEIGHT: tl.constexpr, HAS_BIAS: tl.constexpr\n):\n row = tl.program_id(0)\n X += row * stride_x_row\n Y += row * stride_y_row\n O += row * stride_x_row\n if HAS_RESIDUAL:\n RESIDUAL += row * stride_res_row\n if STORE_RESIDUAL_OUT:\n RESIDUAL_OUT += row * stride_res_out_row\n cols = tl.arange(0, BLOCK_N)\n x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)\n if HAS_RESIDUAL:\n residual = tl.load(RESIDUAL + cols, mask=cols < N, other=0.0).to(tl.float32)\n x += residual\n if STORE_RESIDUAL_OUT:\n tl.store(RESIDUAL_OUT + cols, x, mask=cols < N)\n if not IS_RMS_NORM:\n mean = tl.sum(x, axis=0) / N\n tl.store(Mean + row, mean)\n xbar = tl.where(cols < N, x - mean, 0.0)\n var = tl.sum(xbar * xbar, axis=0) / N\n else:\n xbar = tl.where(cols < N, x, 0.0)\n var = tl.sum(xbar * xbar, axis=0) / N\n rstd = 1 / tl.sqrt(var + eps)\n tl.store(Rstd + row, rstd)\n mask = cols < N\n if HAS_WEIGHT:\n w = tl.load(W + cols, mask=mask).to(tl.float32)\n if HAS_BIAS:\n b = tl.load(B + cols, mask=mask).to(tl.float32)\n x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd\n y = x_hat * w if HAS_WEIGHT else x_hat\n if HAS_BIAS:\n y = y + b\n o = tl.load(O + cols, mask=cols < N, other=0.0).to(tl.float32)\n y = y * o * tl.sigmoid(o)\n tl.store(Y + cols, y, mask=mask)\n\ndef _layer_norm_fwd(\n x, o, weight, bias, eps, residual=None, out_dtype=None, residual_dtype=None, is_rms_norm=False\n):\n if residual is not None:\n residual_dtype = residual.dtype\n M, N = x.shape\n assert x.stride(-1) == 1\n if residual is not None:\n assert residual.stride(-1) == 1\n assert residual.shape == (M, N)\n if weight is not None:\n assert weight.shape == (N,)\n assert weight.stride(-1) == 1\n if bias is not None:\n assert bias.stride(-1) == 1\n assert bias.shape == (N,)\n y = torch.empty_like(x, dtype=x.dtype if out_dtype is None else out_dtype)\n assert y.stride(-1) == 1\n if residual is not None or (residual_dtype is not None and residual_dtype != x.dtype):\n residual_out = torch.empty(M, N, device=x.device, dtype=residual_dtype)\n assert residual_out.stride(-1) == 1\n else:\n residual_out = None\n mean = torch.empty((M,), dtype=torch.float32, device=\"cuda\") if not is_rms_norm else None\n rstd = torch.empty((M,), dtype=torch.float32, device=\"cuda\")\n MAX_FUSED_SIZE = 65536 // x.element_size()\n BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))\n if N > BLOCK_N:\n raise RuntimeError(\"This layer norm doesn't support feature dim >= 64KB.\")\n with torch.cuda.device(x.device.index):\n _layer_norm_fwd_1pass_kernel[(M,)](\n x, o, y, weight, bias, residual, residual_out, mean, rstd,\n x.stride(0), y.stride(0), residual.stride(0) if residual is not None else 0,\n residual_out.stride(0) if residual_out is not None else 0, N, eps,\n is_rms_norm, BLOCK_N, residual is not None, residual_out is not None,\n weight is not None, bias is not None,\n )\n return y, mean, rstd, residual_out if residual_out is not None else x\n", - "description_1": "Use triton language to implement a layer normalization forward pass kernel with swish gating. The kernel takes 20 parameters: pointers to input, gate, output, weights, biases, residuals, mean, and rstd, strides for input, output, and residuals, number of columns, epsilon for numerical stability, and several compile-time constants for configuration. The kernel computes the mean and variance of the input, normalizes it, applies weights and biases, and then applies a swish gate using the gate input. The result is stored in the output pointer.", - "description_2": "Use triton language to implement a layer normalization forward pass with swish gating. The function takes 9 parameters: input tensor, gate tensor, weight tensor, bias tensor, epsilon, optional residual tensor, output data type, residual data type, and a flag for RMS normalization. It reshapes inputs, allocates output tensors, and calls the Triton kernel to perform the computation.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.autotune(\n configs=[\n triton.Config({}, num_warps=1),\n triton.Config({}, num_warps=2),\n triton.Config({}, num_warps=4),\n triton.Config({}, num_warps=8),\n triton.Config({}, num_warps=16),\n triton.Config({}, num_warps=32),\n ],\n key=[\"N\"],\n)\n@triton.jit\ndef _l2_norm_fwd_1pass_kernel(\n X, # pointer to the input\n Y, # pointer to the output\n stride_x_row, # how much to increase the pointer when moving by 1 row\n N, # number of columns in X\n eps, # epsilon to avoid division by zero\n BLOCK_N: tl.constexpr,\n):\n row = tl.program_id(0)\n X += row * stride_x_row\n Y += row * stride_x_row\n cols = tl.arange(0, BLOCK_N)\n x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)\n xbar = tl.where(cols < N, x, 0.0)\n var = tl.sum(xbar * xbar, axis=0)\n rstd = 1 / tl.sqrt(var + eps)\n mask = cols < N\n y = x * rstd\n tl.store(Y + cols, y, mask=mask)\n\n@triton.autotune(\n configs=[\n triton.Config({}, num_warps=1),\n triton.Config({}, num_warps=2),\n triton.Config({}, num_warps=4),\n triton.Config({}, num_warps=8),\n triton.Config({}, num_warps=16),\n triton.Config({}, num_warps=32),\n ],\n key=[\"N\"],\n)\n@triton.jit\ndef _l2_norm_bwd_kernel(\n X, # pointer to the input\n DY, # pointer to the output gradient\n DX, # pointer to the input gradient\n stride_x_row, # how much to increase the pointer when moving by 1 row\n N, # number of columns in X\n eps, # epsilon to avoid division by zero\n BLOCK_N: tl.constexpr,\n):\n row = tl.program_id(0)\n X += row * stride_x_row\n DX += row * stride_x_row\n DY += row * stride_x_row\n cols = tl.arange(0, BLOCK_N)\n x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)\n x = tl.where(cols < N, x, 0.0)\n var = tl.sum(x * x)\n rstd = 1 / tl.sqrt(var + eps)\n mask = cols < N\n dy = tl.load(DY + cols, mask=cols < N, other=0.0).to(tl.float32)\n dy = tl.where(cols < N, dy, 0.0)\n dx = dy * rstd - tl.sum(dy * x) * (1 / (var+eps)) * rstd * x\n tl.store(DX + cols, dx, mask=mask)\n\ndef _l2_norm_fwd(x, eps=1e-6):\n x_shape_og = x.shape\n x = x.reshape(-1, x.shape[-1])\n if x.stride(-1) != 1:\n x = x.contiguous()\n M, N = x.shape\n y = torch.empty_like(x)\n N = x.shape[-1]\n M = x.shape[0]\n MAX_FUSED_SIZE = 65536 // x.element_size()\n BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))\n if N > BLOCK_N:\n raise RuntimeError(\n \"This layer norm doesn't support feature dim >= 64KB.\")\n with torch.cuda.device(x.device.index):\n _l2_norm_fwd_1pass_kernel[(M,)](\n x,\n y,\n x.stride(0),\n N,\n eps,\n BLOCK_N,\n )\n return y.reshape(x_shape_og)\n\ndef _l2_norm_bwd(x, dy, eps=1e-5):\n x_shape_og = x.shape\n x = x.reshape(-1, dy.shape[-1])\n dy = dy.reshape(-1, dy.shape[-1])\n if dy.stride(-1) != 1:\n dy = dy.contiguous()\n dx = torch.empty_like(x)\n N = x.shape[-1]\n M = x.shape[0]\n MAX_FUSED_SIZE = 65536 // x.element_size()\n BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))\n if N > BLOCK_N:\n raise RuntimeError(\n \"This layer norm doesn't support feature dim >= 64KB.\")\n with torch.cuda.device(x.device.index):\n _l2_norm_bwd_kernel[(M,)](\n x,\n dy,\n dx,\n x.stride(0),\n N,\n eps,\n BLOCK_N,\n )\n return dx.reshape(x_shape_og)\n", - "description_1": "Use triton language to implement L2 normalization forward and backward pass kernels. The forward kernel computes L2 norm along the last dimension of input tensor X and outputs normalized tensor Y. It requires 6 parameters: input X (pointer), output Y (pointer), stride_x_row (int, stride of X rows), N (int, number of columns in X), eps (float, to avoid division by zero), and BLOCK_N (constant expression, block size). The backward kernel computes gradient DX of input X given DY (gradient of output). It requires similar parameters as the forward kernel, except DX (pointer) replaces output Y.", - "description_2": "Use triton language to create a function for L2 norm computation along the last dimension of a tensor. The forward function normalizes input, and the backward function calculates gradients, requiring input, output, gradients, stride, number of columns, epsilon, and block size as parameters.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.autotune(\n configs=[\n triton.Config({}, num_warps=1),\n triton.Config({}, num_warps=2),\n triton.Config({}, num_warps=4),\n triton.Config({}, num_warps=8),\n triton.Config({}, num_warps=16),\n triton.Config({}, num_warps=32),\n ],\n key=[\"N\", \"HAS_RESIDUAL\", \"STORE_RESIDUAL_OUT\", \"IS_RMS_NORM\", \"HAS_BIAS\"],\n)\n@triton.jit\ndef _layer_norm_fwd_1pass_kernel(\n X, # pointer to the input\n Y, # pointer to the output\n W, # pointer to the weights\n B, # pointer to the biases\n RESIDUAL, # pointer to the residual\n RESIDUAL_OUT, # pointer to the residual\n Mean, # pointer to the mean\n Rstd, # pointer to the 1/std\n stride_x_row, # how much to increase the pointer when moving by 1 row\n stride_y_row,\n stride_res_row,\n stride_res_out_row,\n N, # number of columns in X\n G, # number of groups\n eps, # epsilon to avoid division by zero\n IS_RMS_NORM: tl.constexpr,\n BLOCK_N: tl.constexpr,\n HAS_RESIDUAL: tl.constexpr,\n STORE_RESIDUAL_OUT: tl.constexpr,\n HAS_WEIGHT: tl.constexpr,\n HAS_BIAS: tl.constexpr\n):\n # Map the program id to the row of X and Y it should compute.\n row = tl.program_id(0)\n group = row % G\n X += row * stride_x_row\n Y += row * stride_y_row\n if HAS_RESIDUAL:\n RESIDUAL += row * stride_res_row\n if STORE_RESIDUAL_OUT:\n RESIDUAL_OUT += row * stride_res_out_row\n # Compute mean and variance\n cols = tl.arange(0, BLOCK_N)\n x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)\n if HAS_RESIDUAL:\n residual = tl.load(RESIDUAL + cols, mask=cols < N, other=0.0).to(tl.float32)\n x += residual\n if STORE_RESIDUAL_OUT:\n tl.store(RESIDUAL_OUT + cols, x, mask=cols < N)\n if not IS_RMS_NORM:\n mean = tl.sum(x, axis=0) / N\n tl.store(Mean + row, mean)\n xbar = tl.where(cols < N, x - mean, 0.0)\n var = tl.sum(xbar * xbar, axis=0) / N\n else:\n xbar = tl.where(cols < N, x, 0.0)\n var = tl.sum(xbar * xbar, axis=0) / N\n rstd = 1 / tl.sqrt(var + eps)\n tl.store(Rstd + row, rstd)\n # Normalize and apply linear transformation\n mask = cols < N\n if HAS_WEIGHT:\n w = tl.load(W + group * stride_x_row + cols, mask=mask).to(tl.float32)\n if HAS_BIAS:\n b = tl.load(B + group * stride_x_row + cols, mask=mask).to(tl.float32)\n x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd\n\n y = x_hat * w if HAS_WEIGHT else x_hat\n if HAS_BIAS:\n y = y + b\n # Write output\n tl.store(Y + cols, y, mask=mask)\n\n\ndef _layer_norm_fwd(\n x,\n weight,\n bias,\n eps,\n residual=None,\n out_dtype=None,\n residual_dtype=None,\n is_rms_norm=False,\n num_groups=1\n):\n if residual is not None:\n residual_dtype = residual.dtype\n M, N, G = *x.shape, num_groups\n if residual is not None:\n assert residual.shape == (M, N)\n if weight is not None:\n assert weight.shape == (G * N,)\n if bias is not None:\n assert bias.shape == (G * N,)\n # allocate output\n y = torch.empty_like(x, dtype=x.dtype if out_dtype is None else out_dtype)\n if residual is not None or (residual_dtype is not None and residual_dtype != x.dtype):\n residual_out = torch.empty(M, N, device=x.device, dtype=residual_dtype)\n else:\n residual_out = None\n mean = torch.empty((M,), dtype=torch.float32, device=\"cuda\") if not is_rms_norm else None\n rstd = torch.empty((M,), dtype=torch.float32, device=\"cuda\")\n # Less than 64KB per feature: enqueue fused kernel\n MAX_FUSED_SIZE = 65536 // x.element_size()\n BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))\n if N > BLOCK_N:\n raise RuntimeError(\"This layer norm doesn't support feature dim >= 64KB.\")\n # heuristics for number of warps\n with torch.cuda.device(x.device.index):\n _layer_norm_fwd_1pass_kernel[(M,)](\n x,\n y,\n weight,\n bias,\n residual,\n residual_out,\n mean,\n rstd,\n x.stride(0),\n y.stride(0),\n residual.stride(0) if residual is not None else 0,\n residual_out.stride(0) if residual_out is not None else 0,\n N,\n G,\n eps,\n is_rms_norm,\n BLOCK_N,\n residual is not None,\n residual_out is not None,\n weight is not None,\n bias is not None,\n )\n # residual_out is None if residual is None and residual_dtype == input_dtype\n return y, mean, rstd, residual_out if residual_out is not None else x\n", - "description_1": "Use triton language to implement a layer normalization forward pass kernel. The kernel takes 18 parameters: pointers to input, output, weights, biases, residuals, mean, and rstd, strides for input, output, and residuals, number of columns, number of groups, epsilon for numerical stability, and several compile-time constants indicating the presence of residuals, weights, biases, and whether RMS normalization is used. The kernel computes the mean and variance of the input, normalizes it, applies weights and biases, and stores the result.", - "description_2": "Use triton language to implement a function that calls the layer normalization forward pass kernel. The function takes 9 parameters: input tensor, weight tensor, bias tensor, epsilon, optional residual tensor, output data type, residual data type, a boolean for RMS normalization, and the number of groups. It prepares the output and intermediate tensors, sets up the kernel launch configuration, and invokes the kernel.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef chunk_abc_fwd_kernel_h(\n k, v, z, h, h0, ht,\n s_k_h, s_k_t, s_k_d, s_v_h, s_v_t, s_v_d, s_h_h, s_h_t, s_h_d,\n T: tl.constexpr, K: tl.constexpr, V: tl.constexpr, BT: tl.constexpr,\n BK: tl.constexpr, BV: tl.constexpr, NT: tl.constexpr, NORMK: tl.constexpr,\n USE_INITIAL_STATE: tl.constexpr, STORE_FINAL_STATE: tl.constexpr\n):\n i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n\n b_h = tl.zeros([BK, BV], dtype=tl.float32)\n if USE_INITIAL_STATE:\n p_h = tl.make_block_ptr(h0 + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))\n b_h += tl.load(p_h, boundary_check=(0, 1)).to(tl.float32)\n if NORMK:\n p_z0 = tl.make_block_ptr(z + i_bh * s_k_h, (T * K,), (s_k_d,), (i_k * BK,), (BK,), (0,))\n else:\n p_z0 = tl.make_block_ptr(z + i_bh * s_v_h, (T * V,), (s_v_d,), (i_v * BV,), (BV,), (0,))\n b_zp = tl.load(p_z0).to(tl.float32)\n for i_t in range(NT):\n p_k = tl.make_block_ptr(k + i_bh * s_k_h, (K, T), (s_k_d, s_k_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1))\n p_v = tl.make_block_ptr(v + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))\n p_h = tl.make_block_ptr(h + i_bh * s_h_h + i_t * K * V, (K, V), (s_h_t, s_h_d), (i_k * BK, i_v * BV), (BK, BV), (1, 0))\n\n tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1))\n b_k = tl.load(p_k, boundary_check=(0, 1))\n b_v = tl.load(p_v, boundary_check=(0, 1))\n if NORMK:\n p_zc = tl.make_block_ptr(z + i_bh * s_k_h, (T * K,), (s_k_d,), ((i_t * BT + BT - 1) * K + i_k * BK,), (BK,), (0,))\n b_zc = tl.load(p_zc, boundary_check=(0,))\n b_r, b_zp = tl.exp(b_zp - b_zc), b_zc\n b_h = b_h * b_r[:, None]\n b_k = tl.exp(b_k - b_zc[:, None]).to(b_k.dtype)\n else:\n p_zc = tl.make_block_ptr(z + i_bh * s_v_h, (T * V,), (s_v_d,), ((i_t * BT + BT - 1) * V + i_v * BV,), (BV,), (0,))\n b_zc = tl.load(p_zc, boundary_check=(0,))\n b_r, b_zp = tl.exp(b_zp - b_zc), b_zc\n b_h = b_h * b_r[None, :]\n b_v = tl.exp(b_v - b_zc[None, :]).to(b_v.dtype)\n b_h += tl.dot(b_k, b_v, allow_tf32=False)\n\n if STORE_FINAL_STATE:\n p_h = tl.make_block_ptr(ht + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))\n tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1))\n\n# Other kernels and function definitions...\n\ndef chunk_abc(\n q: torch.Tensor,\n k: torch.Tensor,\n v: torch.Tensor,\n s: torch.Tensor,\n initial_state: Optional[Tuple[torch.Tensor]] = None,\n output_final_state: Optional[bool] = False\n) -> Tuple[torch.Tensor, Tuple[torch.Tensor]]:\n if initial_state is not None:\n initial_state = tuple(i.detach() for i in initial_state)\n ov, final_state = ChunkABCFunction.apply(q, k, v, s, initial_state, output_final_state)\n return ov, final_state\n", - "description_1": "Use triton language to implement a forward kernel that processes input tensors `k`, `v`, `z`, `h` with optional state management, employing matrix operations and normalization. Integrate this kernel in a PyTorch autograd function to handle forward operations with optional initial and final states.", - "description_2": "Use triton language to create a kernel for processing tensors with potential state management. Implement the kernel in a PyTorch function for executing forward operations with optional state considerations.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef chunk_gated_abc_fwd_kernel_cum(\n s,\n o,\n s_s_h,\n s_s_t,\n s_s_d,\n T: tl.constexpr,\n S: tl.constexpr,\n BT: tl.constexpr,\n BS: tl.constexpr,\n):\n i_s, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n o_i = tl.arange(0, BT)\n m_s = tl.where(o_i[:, None] >= o_i[None, :], 1., 0.).to(tl.float32)\n\n p_s = tl.make_block_ptr(s + i_bh * s_s_h, (T, S), (s_s_t, s_s_d), (i_t * BT, i_s * BS), (BT, BS), (1, 0))\n p_o = tl.make_block_ptr(o + i_bh * s_s_h, (T, S), (s_s_t, s_s_d), (i_t * BT, i_s * BS), (BT, BS), (1, 0))\n # [BT, BS]\n b_s = tl.load(p_s, boundary_check=(0, 1)).to(tl.float32)\n b_o = tl.dot(m_s, b_s, allow_tf32=False)\n tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))\n\n\ndef fwd_pre(g, B, H, T, S, BT):\n NT = triton.cdiv(T, BT)\n g_org, g = g, torch.empty_like(g, dtype=torch.float)\n def grid(meta): return (triton.cdiv(meta['S'], meta['BS']), NT, B * H)\n # keep cummulative normalizer in fp32\n # this kernel is equivalent to\n # g = g.view(B, H, NT, BT, -1).cumsum(-2).view(B, H, T, -1)\n chunk_gated_abc_fwd_kernel_cum[grid](\n g_org, g,\n g.stride(1), g.stride(2), g.stride(3),\n T=T, S=S, BT=BT\n )\n return g\n", - "description_1": "Use triton language to implement a kernel 'chunk_gated_abc_fwd_kernel_cum' with 5 parameters plus 4 constexprs that computes cumulative sums along a certain axis using masks and stores the result in an output tensor. Another function 'fwd_pre' prepares the data for the kernel execution.", - "description_2": "Use triton language to implement a kernel that computes masked cumulative sums and a function to prepare data for the kernel execution.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\nfrom torch.cuda.amp import custom_bwd, custom_fwd\n\n@triton.jit\ndef fused_recurrent_gated_abc_inference_kernel(\n q,\n k,\n v,\n s,\n g,\n o,\n hk,\n hv,\n s_k_h,\n s_v_h,\n s_m_h,\n scale,\n K: tl.constexpr,\n V: tl.constexpr,\n M: tl.constexpr,\n BK: tl.constexpr,\n BV: tl.constexpr\n):\n i_bh = tl.program_id(0)\n b_s = tl.load(s + i_bh * s_m_h + tl.arange(0, M))\n b_g = tl.load(g + i_bh * s_m_h + tl.arange(0, M)).to(tl.float32)\n b_g = tl.exp(b_g)\n b_ok = tl.zeros([M], dtype=tl.float32)\n for i_k in range(tl.cdiv(K, BK)):\n p_hk0 = hk + i_bh * K * M + (i_k * BK + tl.arange(0, BK)[None, :]) * M + tl.arange(0, M)[:, None]\n mask = (i_k * BK + tl.arange(0, BK)) < K\n b_hk = tl.load(p_hk0, mask=mask[None, :], other=0).to(tl.float32)\n b_q = tl.load(q + i_bh * s_k_h + i_k * BK + tl.arange(0, BK), mask=mask, other=0).to(tl.float32) * scale\n b_k = tl.load(k + i_bh * s_k_h + i_k * BK + tl.arange(0, BK), mask=mask, other=0).to(tl.float32)\n b_hk = b_hk * b_g[:, None] + b_k[None, :] * b_s[:, None]\n b_ok += tl.sum(b_hk * b_q[None, :], axis=1)\n\n p_hkt = hk + i_bh * K * M + (i_k * BK + tl.arange(0, BK)[None, :]) * M + tl.arange(0, M)[:, None]\n tl.store(p_hkt, b_hk.to(p_hkt.dtype.element_ty), mask=mask[None, :])\n\n b_qv = tl.softmax(b_ok)\n for i_v in range(tl.cdiv(V, BV)):\n p_hv0 = hv + i_bh * M * V + tl.arange(0, M)[None, :] * V + (i_v * BV + tl.arange(0, BV)[:, None])\n mask = (i_v * BV + tl.arange(0, BV)) < V\n b_hv = tl.load(p_hv0, mask=mask[:, None], other=0).to(tl.float32)\n b_v = tl.load(v + i_bh * s_v_h + i_v * BV + tl.arange(0, BV), mask=mask, other=0).to(tl.float32)\n b_hv = b_hv * b_g[None, :] + b_s[None, :] * b_v[:, None]\n b_ov = tl.sum(b_hv * b_qv[None, :], axis=1)\n\n tl.store(o + i_bh * s_v_h + i_v * BV + tl.arange(0, BV), b_ov.to(o.dtype.element_ty), mask=mask)\n\n p_hvt = hv + i_bh * M * V + tl.arange(0, M)[None, :] * V + (i_v * BV + tl.arange(0, BV)[:, None])\n tl.store(p_hvt, b_hv.to(p_hvt.dtype.element_ty), mask=mask[:, None])\n\n\n@triton.jit\ndef fused_recurrent_gated_abc_fwd_kernel(\n q,\n k,\n v,\n gk,\n gv,\n o,\n h0,\n ht,\n s_k_h,\n s_v_h,\n scale,\n B: tl.constexpr,\n H: tl.constexpr,\n T: tl.constexpr,\n K: tl.constexpr,\n V: tl.constexpr,\n BK: tl.constexpr,\n BV: tl.constexpr,\n USE_INITIAL_STATE: tl.constexpr,\n STORE_FINAL_STATE: tl.constexpr,\n REVERSE: tl.constexpr,\n USE_GK: tl.constexpr,\n USE_GV: tl.constexpr,\n):\n i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n\n p_q = q + i_bh * s_k_h + i_k * BK + tl.arange(0, BK) + ((T-1) * K if REVERSE else 0)\n p_k = k + i_bh * s_k_h + i_k * BK + tl.arange(0, BK) + ((T-1) * K if REVERSE else 0)\n p_v = v + i_bh * s_v_h + i_v * BV + tl.arange(0, BV) + ((T-1) * V if REVERSE else 0)\n p_o = o + (i_bh + i_k * B * H) * s_v_h + i_v * BV + tl.arange(0, BV) + ((T-1) * V if REVERSE else 0)\n\n if USE_GK:\n p_gk = gk + i_bh * s_k_h + i_k * BK + tl.arange(0, BK) + ((T-1) * K if REVERSE else 0)\n if USE_GV:\n p_gv = gv + i_bh * s_v_h + i_v * BV + tl.arange(0, BV) + ((T-1) * V if REVERSE else 0)\n\n mask_k = (i_k * BK + tl.arange(0, BK)) < K\n mask_v = (i_v * BV + tl.arange(0, BV)) < V\n\n b_h = tl.zeros([BV, BK], dtype=tl.float32)\n mask_h = mask_k[None, :] & mask_v[:, None]\n\n if USE_INITIAL_STATE:\n p_h0 = h0 + i_bh * K * V + (i_k * BK + tl.arange(0, BK)[None, :]) * V + (i_v * BV + tl.arange(0, BV)[:, None])\n b_h += tl.load(p_h0, mask=mask_h, other=0).to(tl.float32)\n\n for _ in range(0, T):\n b_q = tl.load(p_q, mask=mask_k, other=0).to(tl.float32) * scale\n b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32)\n b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32)\n if USE_GK:\n b_gk = tl.load(p_gk, mask=mask_k, other=0).to(tl.float32)\n b_h = b_h * tl.exp(b_gk)[None, :]\n if USE_GV:\n b_gv = tl.load(p_gv, mask=mask_v, other=0).to(tl.float32)\n b_h = b_h * tl.exp(b_gv)[:, None]\n b_h += b_k[None, :] * b_v[:, None]\n b_o = b_h * b_q[None, :]\n b_o = tl.sum(b_o, axis=1)\n tl.store(p_o, b_o.to(p_o.dtype.element_ty), mask=mask_v)\n p_q += -K if REVERSE else K\n p_k += -K if REVERSE else K\n p_o += -V if REVERSE else V\n p_v += -V if REVERSE else V\n if USE_GK:\n p_gk += -K if REVERSE else K\n if USE_GV:\n p_gv += -V if REVERSE else V\n\n if STORE_FINAL_STATE:\n p_ht = ht + i_bh * K * V + (i_k * BK + tl.arange(0, BK)[None, :]) * V + (i_v * BV + tl.arange(0, BV)[:, None])\n tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_h)\n\n\n@triton.jit\ndef fused_recurrent_gated_abc_bwd_kernel(\n q,\n k,\n v,\n gk,\n gv,\n do,\n dq,\n dk,\n dv,\n h0,\n s_k_h,\n s_v_h,\n scale,\n B: tl.constexpr,\n H: tl.constexpr,\n T: tl.constexpr,\n K: tl.constexpr,\n V: tl.constexpr,\n BK: tl.constexpr,\n BV: tl.constexpr,\n USE_INITIAL_STATE: tl.constexpr,\n REVERSE: tl.constexpr,\n USE_GK: tl.constexpr,\n USE_GV: tl.constexpr,\n):\n i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n\n p_q = q + i_bh * s_k_h + i_k * BK + tl.arange(0, BK) + ((T-1) * K if REVERSE else 0)\n p_k = k + i_bh * s_k_h + i_k * BK + tl.arange(0, BK) + ((T-1) * K if REVERSE else 0)\n p_v = v + i_bh * s_v_h + i_v * BV + tl.arange(0, BV) + ((T-1) * V if REVERSE else 0)\n p_do = do + i_bh * s_v_h + i_v * BV + tl.arange(0, BV) + ((T-1) * V if REVERSE else 0)\n p_dq = dq + (i_bh + i_v * B * H) * s_k_h + i_k * BK + tl.arange(0, BK) + ((T-1) * K if REVERSE else 0)\n if USE_GK:\n p_gk = gk + i_bh * s_k_h + i_k * BK + tl.arange(0, BK) + ((T-1) * K if REVERSE else 0)\n if USE_GV:\n p_gv = gv + i_bh * s_v_h + i_v * BV + tl.arange(0, BV) + ((T-1) * V if REVERSE else 0)\n mask_k = i_k * BK + tl.arange(0, BK) < K\n mask_v = i_v * BV + tl.arange(0, BV) < V\n mask_h = mask_k[:, None] & mask_v[None, :]\n b_h = tl.zeros([BK, BV], dtype=tl.float32)\n\n if USE_INITIAL_STATE:\n p_h0 = h0 + i_bh * K * V + (i_k * BK + tl.arange(0, BK)[:, None]) * V + (i_v * BV + tl.arange(0, BV)[None, :])\n b_h += tl.load(p_h0, mask=mask_h, other=0).to(tl.float32)\n\n for _ in range(0, T):\n b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32)\n b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32)\n b_do = tl.load(p_do, mask=mask_v, other=0).to(tl.float32)\n if USE_GK:\n b_gk = tl.load(p_gk, mask=mask_k, other=0).to(tl.float32)\n b_h = b_h * tl.exp(b_gk)[:, None]\n if USE_GV:\n b_gv = tl.load(p_gv, mask=mask_v, other=0).to(tl.float32)\n b_h = b_h * tl.exp(b_gv)[None, :]\n b_h += b_k[:, None] * b_v[None, :]\n b_dq = tl.sum(b_h * b_do[None, :], axis=1) * scale\n tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), mask=mask_k)\n\n p_k += -K if REVERSE else K\n p_v += -V if REVERSE else V\n p_q += -K if REVERSE else K\n p_do += -V if REVERSE else V\n p_dq += -K if REVERSE else K\n if USE_GK:\n p_gk += -K if REVERSE else K\n if USE_GV:\n p_gv += -V if REVERSE else V\n\n tl.debug_barrier()\n\n p_q = q + i_bh * s_k_h + i_k * BK + tl.arange(0, BK) + ((T - 1) * K if not REVERSE else 0)\n p_k = k + i_bh * s_k_h + i_k * BK + tl.arange(0, BK) + ((T - 1) * K if not REVERSE else 0)\n p_v = v + i_bh * s_v_h + i_v * BV + tl.arange(0, BV) + ((T - 1) * V if not REVERSE else 0)\n p_do = do + i_bh * s_v_h + i_v * BV + tl.arange(0, BV) + ((T - 1) * V if not REVERSE else 0)\n p_dk = dk + (i_bh + i_v * B * H) * s_k_h + i_k * BK + tl.arange(0, BK) + ((T - 1) * K if not REVERSE else 0)\n p_dv = dv + (i_bh + i_k * B * H) * s_v_h + i_v * BV + tl.arange(0, BV) + ((T - 1) * V if not REVERSE else 0)\n if USE_GK:\n p_gk = gk + i_bh * s_k_h + i_k * BK + tl.arange(0, BK) + ((T - 1) * K if not REVERSE else 0)\n if USE_GV:\n p_gv = gv + i_bh * s_v_h + i_v * BV + tl.arange(0, BV) + ((T - 1) * V if not REVERSE else 0)\n\n b_dh = tl.zeros([BK, BV], dtype=tl.float32)\n for _ in range(T):\n b_q = tl.load(p_q, mask=mask_k, other=0).to(tl.float32) * scale\n b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32)\n b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32)\n b_do = tl.load(p_do, mask=mask_v, other=0).to(tl.float32)\n b_dh += b_q[:, None] * b_do[None, :]\n b_dk = tl.sum(b_dh * b_v[None, :], axis=1)\n b_dv = tl.sum(b_dh * b_k[:, None], axis=0)\n if USE_GK:\n b_gk = tl.load(p_gk, mask=mask_k, other=0).to(tl.float32)\n b_dh *= tl.exp(b_gk)[:, None]\n if USE_GV:\n b_gv = tl.load(p_gv, mask=mask_v, other=0).to(tl.float32)\n b_dh *= tl.exp(b_gv)[None, :]\n tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), mask=mask_k)\n tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), mask=mask_v)\n\n p_q += K if REVERSE else -K\n p_k += K if REVERSE else -K\n p_v += V if REVERSE else -V\n p_do += V if REVERSE else -V\n p_dk += K if REVERSE else -K\n p_dv += V if REVERSE else -V\n if USE_GK:\n p_gk += K if REVERSE else -K\n if USE_GV:\n p_gv += V if REVERSE else -V\n\n\nclass FusedRecurrentGatedABCFunction(torch.autograd.Function):\n\n @staticmethod\n @custom_fwd\n def forward(\n ctx,\n q: torch.Tensor,\n k: torch.Tensor,\n v: torch.Tensor,\n s: torch.Tensor,\n g: torch.Tensor,\n scale: Optional[float] = None,\n initial_state: Optional[Tuple[torch.Tensor]] = None,\n output_final_state: bool = False,\n reverse: bool = False,\n inference_mode: bool = False\n ) -> Tuple[torch.Tensor, Tuple[torch.Tensor]]:\n B, H, T, K, V, M = *q.shape, v.shape[-1], s.shape[-1]\n if scale is None:\n scale = K ** -0.5\n\n BK, BV, BM = min(K, 64), min(V, 64), min(M, 64)\n NK, NV, NM = triton.cdiv(K, BK), triton.cdiv(V, BV), triton.cdiv(M, BM)\n num_warps = 1\n num_stages = 1\n\n if initial_state is None:\n initial_state = (None, None)\n final_state = (None, None)\n if output_final_state:\n final_state = initial_state if inference_mode else (q.new_empty(B, H, K, M), q.new_empty(B, H, M, V))\n\n if inference_mode:\n BK, BV, BM = min(K, 64), min(V, 64), min(M, 64)\n NK, NV, NM = triton.cdiv(K, BK), triton.cdiv(V, BV), triton.cdiv(M, BM)\n\n o = torch.empty_like(v)\n grid = (B * H,)\n fused_recurrent_gated_abc_inference_kernel[grid](\n q, k, v, s, g, o, initial_state[0], initial_state[1],\n k.stride(1),\n v.stride(1),\n s.stride(1),\n scale=scale,\n K=K, V=V, M=M, BK=BK, BV=BV,\n num_warps=num_warps,\n num_stages=num_stages\n )\n return o, final_state\n\n ok = q.new_empty(NK, B, H, T, M, dtype=torch.float)\n gk, gv = None, g\n grid = (NM, NK, B * H)\n fused_recurrent_gated_abc_fwd_kernel[grid](\n q, k, s, gk, gv, ok, initial_state[0], final_state[0],\n k.stride(1),\n s.stride(1),\n scale=scale,\n B=B, H=H, T=T, K=K, V=M, BK=BK, BV=BM,\n USE_INITIAL_STATE=initial_state[0] is not None,\n STORE_FINAL_STATE=final_state[0] is not None,\n USE_GK=False,\n USE_GV=True,\n REVERSE=reverse,\n num_warps=num_warps,\n num_stages=num_stages\n )\n ok = ok.sum(0)\n\n qv = ok.softmax(-1, dtype=torch.float)\n ov = q.new_empty(NM, B, H, T, V, dtype=torch.float)\n gk, gv = g, None\n grid = (NV, NM, B * H)\n fused_recurrent_gated_abc_fwd_kernel[grid](\n qv, s, v, gk, gv, ov, initial_state[1], final_state[1],\n s.stride(1),\n v.stride(1),\n scale=1.,\n B=B, H=H, T=T, K=M, V=V, BK=BM, BV=BV,\n USE_INITIAL_STATE=initial_state[0] is not None,\n STORE_FINAL_STATE=final_state[0] is not None,\n USE_GK=True,\n USE_GV=False,\n REVERSE=reverse,\n num_warps=num_warps,\n num_stages=num_stages\n )\n ov = ov.sum(0)\n\n ctx.save_for_backward(q, k, v, s, g, qv, *initial_state, ok)\n ctx.scale = scale\n ctx.reverse = reverse\n return ov.to(q.dtype), final_state\n\n @staticmethod\n @custom_bwd\n def backward(ctx, do, dht=None):\n q, k, v, s, g, qv, *initial_state, ok = ctx.saved_tensors\n B, H, T, K, V, M = *q.shape, v.shape[-1], s.shape[-1]\n scale = ctx.scale\n\n BK, BV, BM = min(K, 64), min(V, 64), min(M, 64)\n NK, NV, NM = triton.cdiv(K, BK), triton.cdiv(V, BV), triton.cdiv(M, BM)\n num_warps = 1\n num_stages = 1\n\n dqv = q.new_empty(NV, B, H, T, M, dtype=torch.float)\n dsv = q.new_empty(NV, B, H, T, M, dtype=torch.float)\n dv = q.new_empty(NM, B, H, T, V, dtype=torch.float)\n gk, gv = g, None\n grid = (NV, NM, B * H)\n fused_recurrent_gated_abc_bwd_kernel[grid](\n qv, s, v, gk, gv, do, dqv, dsv, dv, initial_state[1],\n s.stride(1),\n v.stride(1),\n scale=1.,\n B=B, H=H, T=T, K=M, V=V, BK=BM, BV=BV,\n USE_INITIAL_STATE=initial_state[1] is not None,\n REVERSE=ctx.reverse,\n USE_GK=gk is not None,\n USE_GV=gv is not None,\n num_warps=num_warps,\n num_stages=num_stages\n )\n dqv = dqv.sum(0)\n dsv = dsv.sum(0)\n dv = dv.sum(0)\n dgk = dqv * qv.float() - dsv * s.float()\n dgk_cumsum = dgk.cumsum(-2)\n dgk = dgk + dgk_cumsum[:, :, -1, None] - dgk_cumsum\n\n dok = qv * (dqv - (qv * dqv).sum(-1, True))\n dq = q.new_empty(NM, B, H, T, K, dtype=torch.float)\n dk = q.new_empty(NM, B, H, T, K, dtype=torch.float)\n dsk = q.new_empty(NK, B, H, T, M, dtype=torch.float)\n gk, gv = None, g\n grid = (NM, NK, B * H)\n fused_recurrent_gated_abc_bwd_kernel[grid](\n q, k, s, gk, gv, dok, dq, dk, dsk, initial_state[0],\n q.stride(1),\n s.stride(1),\n scale=scale,\n B=B, H=H, T=T, K=K, V=M, BK=BK, BV=BM,\n USE_INITIAL_STATE=initial_state[0] is not None,\n REVERSE=ctx.reverse,\n USE_GK=gk is not None,\n USE_GV=gv is not None,\n num_warps=num_warps,\n num_stages=num_stages\n )\n dq = dq.sum(0)\n dk = dk.sum(0)\n dsk = dsk.sum(0)\n\n dgv = dok.float() * ok.float() - dsk * s.float()\n dgv_cumsum = dgv.cumsum(-2)\n dgv = dgv + dgv_cumsum[:, :, -1, None] - dgv_cumsum\n\n ds = dsk.add_(dsv)\n dg = dgk.add_(dgv)\n\n return dq.to(q), dk.to(k), dv.to(v), ds.to(s), dg.to(g), None, None, None, None, None\n\n\ndef fused_recurrent_gated_abc(\n q: torch.Tensor,\n k: torch.Tensor,\n v: torch.Tensor,\n s: torch.Tensor,\n g: Optional[torch.Tensor] = None,\n scale: Optional[int] = None,\n initial_state: Optional[Tuple[torch.Tensor]] = None,\n output_final_state: Optional[bool] = False\n) -> Tuple[torch.Tensor, torch.Tensor]:\n if g is None:\n z = s.float().logcumsumexp(2)\n g = torch.cat((z[:, :, :1], z[:, :, :-1]), 2) - z\n s = torch.exp(s - z).to(k.dtype)\n if scale is None:\n scale = q.shape[-1] ** -0.5\n inference_mode = q.shape[2] == 1 and not q.requires_grad\n ov, final_state = FusedRecurrentGatedABCFunction.apply(\n q, k, v, s, g, scale, initial_state, output_final_state, False, inference_mode\n )\n return ov, final_state\n", - "description_1": "Use triton language to implement a series of fused recurrent gated attention kernel functions for inference, forward, and backward passes. These kernels operate on input tensors representing queries (q), keys (k), values (v), and several other parameters including scales and strides. The fused operations compute attention-based outputs and manage recurrent state information efficiently, utilizing Triton's parallel capabilities.", - "description_2": "Use triton language to implement optimized fused recurrent gated attention operations leveraging Triton's parallel computation capabilities. Design kernel functions for both forward and backward passes, handling input tensors, scales, and state management.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\nfrom torch.cuda.amp import custom_bwd, custom_fwd\n\n\n@triton.jit\ndef fused_chunk_based_fwd_kernel(\n q, k, v, o, z, s_qk_h, s_qk_t, s_qk_d, s_vo_h, s_vo_t, s_vo_d, B, H, T, scale,\n BT: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr, DK: tl.constexpr, DV: tl.constexpr\n):\n i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n o_i = tl.arange(0, BT)\n m_s = o_i[:, None] >= o_i[None, :]\n\n b_h_0o = tl.zeros([BV], dtype=tl.float32)\n b_h_1o = tl.zeros([BK, BV], dtype=tl.float32)\n b_h_2o = tl.zeros([BK*BK, BV], dtype=tl.float32)\n\n p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (0, i_k * BK), (BT, BK), (1, 0))\n p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (DK, T), (s_qk_d, s_qk_t), (i_k * BK, 0), (BK, BT), (0, 1))\n p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (0, i_v * BV), (BT, BV), (1, 0))\n p_o = tl.make_block_ptr(o + (i_bh + i_k*B*H) * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (0, i_v * BV), (BT, BV), (1, 0))\n\n p_z = z + (i_bh + i_k * B * H) * T + tl.arange(0, BT)\n k_2o = tl.zeros([1, BK * BK], dtype=tl.float32)\n k_1o = tl.zeros([1, BK], dtype=tl.float32)\n k_0o = 0\n\n for i in range(0, tl.cdiv(T, BT)):\n b_k = tl.load(p_k, boundary_check=(0, 1))\n b_k_2o = b_k[:, None, :] * b_k[None, :, :]\n b_k_2o = tl.reshape(b_k_2o, [BK * BK, BT]).to(b_k.dtype)\n b_v = tl.load(p_v, boundary_check=(0, 1))\n b_q = (tl.load(p_q, boundary_check=(0, 1)) * scale).to(b_k.dtype)\n b_o = tl.zeros([BT, BV], dtype=tl.float32)\n b_z = tl.zeros([BT], dtype=tl.float32)\n\n b_o += b_h_0o\n b_z += k_0o\n b_o += tl.dot(b_q, b_h_1o.to(b_q.dtype), allow_tf32=False)\n b_z += tl.sum(b_q * k_1o, axis=1)\n b_q_2o = b_q[:, :, None] * b_q[:, None, :]\n b_q_2o = tl.reshape(b_q_2o, [BT, BK * BK]).to(b_k.dtype)\n b_o += tl.dot(b_q_2o, b_h_2o.to(b_q_2o.dtype), allow_tf32=False) * 0.5\n b_z += tl.sum(b_q_2o * k_2o, axis=1) * 0.5\n\n k_1o += tl.sum(b_k, axis=1)[None, :]\n k_2o += tl.sum(b_k_2o, axis=1)[None, :]\n k_0o += BT\n\n b_s = tl.dot(b_q, b_k, allow_tf32=False)\n b_s = 1 + b_s + 0.5 * b_s * b_s\n b_s = tl.where(m_s, b_s, 0)\n b_z += tl.sum(b_s, axis=1)\n b_o += tl.dot(b_s.to(b_q.dtype), b_v, allow_tf32=False)\n\n tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))\n tl.store(p_z, b_z.to(p_z.dtype.element_ty), mask=(i * BT + tl.arange(0, BT)) < T)\n\n b_h_2o = b_h_2o + tl.dot(b_k_2o.to(b_v.dtype), b_v, allow_tf32=False)\n b_h_1o = b_h_1o + tl.dot(b_k, b_v, allow_tf32=False)\n b_h_0o = b_h_0o + tl.sum(b_v, axis=0)\n\n p_q = tl.advance(p_q, (BT, 0))\n p_k = tl.advance(p_k, (0, BT))\n p_v = tl.advance(p_v, (BT, 0))\n p_o = tl.advance(p_o, (BT, 0))\n p_z += BT\n\n\n@triton.jit\ndef fused_chunk_based_bwd_kernel(\n q, k, v, do, dz, dq, dk, dv, s_qk_h, s_qk_t, s_qk_d, s_vo_h, s_vo_t, s_vo_d, B, H, T, scale,\n BT: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr, DK: tl.constexpr, DV: tl.constexpr\n):\n i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n o_i = tl.arange(0, BT)\n m_s = o_i[:, None] >= o_i[None, :]\n\n b_h_1o = tl.zeros([BV, BK], dtype=tl.float32)\n b_h_2o = tl.zeros([BV, BK*BK], dtype=tl.float32)\n\n k_1o = tl.zeros([1, BK], dtype=tl.float32)\n k_2o = tl.zeros([1, BK * BK], dtype=tl.float32)\n\n for i in range(0, tl.cdiv(T, BT)):\n p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (i * BT, i_k * BK), (BT, BK), (1, 0))\n p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (i * BT, i_k * BK), (BT, BK), (1, 0))\n p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (DV, T), (s_vo_d, s_vo_t), (i_v * BV, i * BT), (BV, BT), (0, 1))\n p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (i * BT, i_v * BV), (BT, BV), (1, 0))\n p_dq = tl.make_block_ptr(dq + (i_bh + i_v*B*H) * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (i*BT, i_k*BK), (BT, BK), (1, 0))\n p_dz = dz + (i_bh) * T + tl.arange(0, BT) + i * BT\n b_dq = tl.zeros([BT, BK], dtype=tl.float32)\n\n b_dz = tl.load(p_dz, mask=(tl.arange(0, BT) + i * BT) < T)\n b_q = tl.load(p_q, boundary_check=(0, 1))\n b_q = (b_q * scale).to(b_q.dtype)\n b_do = tl.load(p_do, boundary_check=(0, 1)).to(b_q.dtype)\n b_k = tl.load(p_k, boundary_check=(0, 1))\n b_v = tl.load(p_v, boundary_check=(0, 1))\n\n b_dq += tl.dot(b_do, (b_h_1o).to(b_do.dtype), allow_tf32=False)\n if i_v == 0:\n b_dq += b_dz[:, None] * k_1o\n b_dq_2o = tl.dot(b_do, (b_h_2o).to(b_do.dtype), allow_tf32=False) * 0.5\n if i_v == 0:\n b_dq_2o += (b_dz[:, None] * k_2o) * 0.5\n b_dq_2o = tl.reshape(b_dq_2o, [BT, BK, BK])\n b_dq += tl.sum(b_dq_2o * b_q[:, :, None], axis=1)\n b_dq += tl.sum(b_dq_2o * b_q[:, None, :], axis=2)\n b_dq *= scale\n\n b_ds = tl.dot(b_do, b_v, allow_tf32=False)\n if i_v == 0:\n b_ds += b_dz[:, None]\n b_ds = tl.where(m_s, b_ds, 0) * scale\n b_s = tl.dot(b_q, tl.trans(b_k), allow_tf32=False)\n b_s = tl.where(m_s, b_s, 0)\n b_dq += tl.dot((b_ds * (1 + b_s)).to(b_q.dtype), b_k, allow_tf32=False)\n\n tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))\n\n b_k_2o = b_k[:, :, None] * b_k[:, None, :]\n b_k_2o = tl.reshape(b_k_2o, [BT, BK * BK]).to(b_k.dtype)\n b_h_2o = b_h_2o + tl.dot(b_v, b_k_2o.to(b_v.dtype), allow_tf32=False)\n b_h_1o = b_h_1o + tl.dot(b_v, b_k, allow_tf32=False)\n\n if i_v == 0:\n k_1o += tl.sum(b_k, axis=0)[None, :]\n k_2o += tl.sum(b_k_2o, axis=0)[None, :]\n\n tl.debug_barrier()\n b_h_1o = None\n b_h_2o = None\n\n b_dh_1o = tl.zeros([BK, BV], dtype=tl.float32)\n b_dh_2o = tl.zeros([BK*BK, BV], dtype=tl.float32)\n b_dh_0o = tl.zeros([BV], dtype=tl.float32)\n m_s = tl.arange(0, BT)[:, None] <= tl.arange(0, BT)[None, :]\n\n dq_1o = tl.zeros([1, BK], dtype=tl.float32)\n dq_2o = tl.zeros([BK * BK, 1], dtype=tl.float32)\n\n for i in range(tl.cdiv(T, BT) * BT - BT, -BT, -BT):\n p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (DK, T), (s_qk_d, s_qk_t), (i_k * BK, i), (BK, BT), (0, 1))\n p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (i, i_k * BK), (BT, BK), (1, 0))\n p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (i, i_v * BV), (BT, BV), (1, 0))\n p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (i, i_v * BV), (BT, BV), (1, 0))\n p_dk = tl.make_block_ptr(dk + (i_bh+i_v*B*H) * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (i, i_k*BK), (BT, BK), (1, 0))\n p_dv = tl.make_block_ptr(dv + (i_bh+i_k*B*H) * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (i, i_v*BV), (BT, BV), (1, 0))\n p_dz = dz + (i_bh) * T + tl.arange(0, BT) + i\n\n b_dk = tl.zeros([BT, BK], dtype=tl.float32)\n b_dv = tl.zeros([BT, BV], dtype=tl.float32)\n\n b_dz = tl.load(p_dz, mask=(tl.arange(0, BT)+i) < T)\n b_q = tl.load(p_q, boundary_check=(0, 1))\n b_k = tl.load(p_k, boundary_check=(0, 1))\n b_v = tl.load(p_v, boundary_check=(0, 1))\n b_do = tl.load(p_do, boundary_check=(0, 1)).to(b_q.dtype)\n b_q = (b_q * scale).to(b_k.dtype)\n\n b_ds = tl.dot(b_v, tl.trans(b_do), allow_tf32=False)\n if i_v == 0:\n b_ds += b_dz[None, :]\n b_ds = tl.where(m_s, b_ds, 0)\n b_s = tl.dot(b_k, b_q, allow_tf32=False)\n b_s2 = 1 + b_s + 0.5 * b_s * b_s\n b_s = tl.where(m_s, b_s, 0)\n b_s2 = tl.where(m_s, b_s2, 0)\n b_ds *= (1+b_s)\n\n b_dk += tl.dot(b_ds.to(b_k.dtype), tl.trans(b_q), allow_tf32=False)\n b_dv += tl.dot(b_s2.to(b_do.dtype), b_do, allow_tf32=False)\n\n b_k_2o = b_k[:, :, None] * b_k[:, None, :]\n b_k_2o = tl.reshape(b_k_2o, [BT, BK * BK]).to(b_k.dtype)\n\n b_dv += tl.dot(b_k, b_dh_1o.to(b_k.dtype), allow_tf32=False)\n b_dv += tl.dot(b_k_2o, b_dh_2o.to(b_k.dtype), allow_tf32=False)\n b_dv += b_dh_0o\n\n b_dk += tl.dot(b_v, tl.trans(b_dh_1o).to(b_k.dtype), allow_tf32=False)\n\n if i_v == 0:\n b_dk += dq_1o\n\n b_dk_2o = tl.dot(b_dh_2o.to(b_k.dtype), tl.trans(b_v), allow_tf32=False)\n if i_v == 0:\n b_dk_2o += dq_2o\n b_dk_2o = tl.reshape(b_dk_2o, [BK, BK, BT])\n b_k_fp32 = tl.trans(b_k.to(tl.float32))\n b_dk2 = tl.sum(b_dk_2o * b_k_fp32[:, None, :], axis=0)\n b_dk2 += tl.sum(b_dk_2o * b_k_fp32[None, :, :], axis=1)\n b_dk += tl.trans(b_dk2)\n\n b_dh_0o += tl.sum(b_do, axis=0)\n b_dh_1o = b_dh_1o + tl.dot(b_q, b_do, allow_tf32=False)\n b_q_2o = b_q[None, :, :] * b_q[:, None, :]\n b_q_2o = tl.reshape(b_q_2o, [BK * BK, BT]).to(b_k.dtype)\n b_dh_2o = b_dh_2o + tl.dot(b_q_2o, b_do, allow_tf32=False) * 0.5\n\n if i_v == 0:\n dq_1o += (tl.sum(b_dz[None, :] * b_q, axis=1))[None, :]\n dq_2o += (tl.sum(b_dz[None, :] * b_q_2o, axis=1) * 0.5)[:, None]\n\n tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))\n tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))\n\n\nclass FusedChunkBasedFunction(torch.autograd.Function):\n\n @staticmethod\n @custom_fwd\n def forward(ctx, q, k, v, scale=1):\n batch_size, n_heads, seq_len, d_head_qk = q.shape\n d_head_v = v.shape[-1]\n\n scale = scale\n BT = 16\n BK, BV = min(d_head_qk, 16), min(d_head_v, 32)\n BK, BV = max(BK, 16), max(BV, 16)\n NK, NV = triton.cdiv(d_head_qk, BK), triton.cdiv(d_head_v, BV)\n\n num_warps = 4\n\n o = q.new_empty(NK, batch_size, n_heads, seq_len, d_head_v, dtype=torch.float32)\n z = q.new_empty(NK, batch_size, n_heads, seq_len, dtype=torch.float32)\n\n grid = (NV, NK, batch_size * n_heads)\n fused_chunk_based_fwd_kernel[grid](\n q, k, v, o, z,\n q.stride(1), q.stride(2), q.stride(3),\n v.stride(1), v.stride(2), v.stride(3),\n batch_size, n_heads, seq_len, scale,\n BT=BT, DK=d_head_qk, DV=d_head_v, BK=BK, BV=BV,\n num_warps=num_warps,\n )\n o = o.sum(0)\n z = z.sum(0)\n ctx.save_for_backward(q, k, v)\n ctx.scale = scale\n return o.to(q.dtype), z.to(z.dtype)\n\n @staticmethod\n @custom_bwd\n def backward(ctx, do, dz):\n q, k, v = ctx.saved_tensors\n batch_size, n_heads, seq_len, d_head_qk = q.shape\n d_head_v = v.shape[-1]\n scale = ctx.scale\n\n BT = 16\n BK, BV = min(d_head_qk, 16), min(d_head_v, 32)\n BK, BV = max(BK, 16), max(BV, 16)\n NK, NV = triton.cdiv(d_head_qk, BK), triton.cdiv(d_head_v, BV)\n num_stages = 1\n num_warps = 4\n\n dq = q.new_empty(NV, batch_size, n_heads, seq_len, d_head_qk)\n dk = q.new_empty(NV, batch_size, n_heads, seq_len, d_head_qk)\n dv = q.new_empty(NK, batch_size, n_heads, seq_len, d_head_v)\n grid = (NV, NK, batch_size * n_heads)\n\n fused_chunk_based_bwd_kernel[grid](\n q, k, v, do, dz, dq, dk, dv,\n q.stride(1), q.stride(2), q.stride(3),\n v.stride(1), v.stride(2), v.stride(3),\n batch_size, n_heads, seq_len, scale,\n BT=BT, DK=d_head_qk, DV=d_head_v, BK=BK, BV=BV,\n num_warps=num_warps,\n num_stages=num_stages\n )\n dq = dq.sum(0)\n dk = dk.sum(0)\n dv = dv.sum(0)\n return dq.to(q.dtype), dk.to(k.dtype), dv.to(v.dtype), None\n\n\ntriton_fused_chunk_based = FusedChunkBasedFunction.apply\n\n\ndef fused_chunk_based(q, k, v, use_scale=True, use_normalize=True):\n assert q.shape[-1] <= 16, 'only support feature dimension up to 16.'\n if use_scale:\n scale = q.shape[-1] ** -0.5\n else:\n scale = 1\n o, z = triton_fused_chunk_based(q, k, v, scale)\n if use_normalize:\n o = o / (z[..., None] + 1e-6)\n else:\n o = o\n\n return o.to(q.dtype)\n", - "description_1": "Use triton language to implement a fused attention mechanism with forward and backward passes. The forward function takes 17 parameters: q (query tensor), k (key tensor), v (value tensor), scale, batch size (B), number of heads (H), sequence length (T), and block sizes (BT, BK, BV, DK, DV). It computes the output (o) and normalizer (z) using Taylor expansions. The backward function takes 21 parameters, including additional parameters for gradients (do, dz, dq, dk, dv) and follows a similar procedure to compute gradients.", - "description_2": "Use triton language to implement a fused attention mechanism. The operation includes computing forward passes with queries, keys, and values, followed by the backward pass to obtain gradients for queries, keys, and values using triton kernels.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef parallel_based_fwd_kernel(\n q, k, v, o, z,\n s_qk_h, s_qk_t, s_qk_d, s_vo_h, s_vo_t, s_vo_d, B, H, T, scale,\n BTL: tl.constexpr, BTS: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr,\n DK: tl.constexpr, DV: tl.constexpr,\n):\n i_kv, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n NV = tl.cdiv(DV, BV)\n i_k = i_kv // (NV)\n i_v = i_kv % (NV)\n\n p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (T, DK),\n (s_qk_t, s_qk_d), (i_c * BTL, i_k * BK), (BTL, BK), (1, 0))\n p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (DK, T),\n (s_qk_d, s_qk_t), (i_k * BK, 0), (BK, BTS), (0, 1))\n p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, DV),\n (s_vo_t, s_vo_d), (0, i_v * BV), (BTS, BV), (1, 0))\n\n b_q = tl.load(p_q, boundary_check=(0, 1))\n b_q = (b_q * scale).to(b_q.dtype)\n b_o = tl.zeros([BTL, BV], dtype=tl.float32)\n b_z = tl.zeros([BTL], dtype=tl.float32)\n\n for _ in range(0, i_c * BTL, BTS):\n b_k = tl.load(p_k, boundary_check=(0, 1))\n b_v = tl.load(p_v, boundary_check=(0, 1))\n b_s = tl.dot(b_q, (b_k), allow_tf32=False)\n b_s = 1 + b_s + 0.5 * b_s * b_s\n b_z += tl.sum(b_s, axis=1)\n b_o = b_o + tl.dot(b_s.to(b_v.dtype), b_v, allow_tf32=False)\n p_k = tl.advance(p_k, (0, BTS))\n p_v = tl.advance(p_v, (BTS, 0))\n\n tl.debug_barrier()\n o_q = tl.arange(0, BTL)\n o_k = tl.arange(0, BTS)\n p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (DK, T),\n (s_qk_d, s_qk_t), (i_k * BK, i_c * BTL), (BK, BTS), (0, 1))\n p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, DV),\n (s_vo_t, s_vo_d), (i_c * BTL, i_v * BV), (BTS, BV), (1, 0))\n for _ in range(i_c * BTL, (i_c + 1) * BTL, BTS):\n b_k = tl.load(p_k, boundary_check=(0, 1))\n b_v = tl.load(p_v, boundary_check=(0, 1))\n m_s = o_q[:, None] >= o_k[None, :]\n b_s = tl.dot(b_q, b_k, allow_tf32=False)\n b_s = 1 + b_s + 0.5 * b_s * b_s\n b_s = tl.where(m_s, b_s, 0)\n b_z += tl.sum(b_s, axis=1)\n b_o += tl.dot(b_s.to(b_q.dtype), b_v, allow_tf32=False)\n p_k = tl.advance(p_k, (0, BTS))\n p_v = tl.advance(p_v, (BTS, 0))\n o_k += BTS\n\n p_o = tl.make_block_ptr(o + (i_bh + B * H * i_k) * s_vo_h, (T, DV),\n (s_vo_t, s_vo_d), (i_c*BTL, i_v*BV), (BTL, BV), (1, 0))\n p_z = z + (i_bh + B * H * i_k) * T + i_c * BTL + tl.arange(0, BTL)\n tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))\n tl.store(p_z, b_z.to(p_z.dtype.element_ty),\n mask=((i_c * BTL + tl.arange(0, BTL)) < T))\n\n\n@triton.jit\ndef parallel_based_bwd_kernel(\n q, k, v, do, dz, dq, dk, dv, s_qk_h, s_qk_t, s_qk_d, s_vo_h,\n s_vo_t, s_vo_d, B, H, T, scale,\n BTL: tl.constexpr, BTS: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr,\n DK: tl.constexpr, DV: tl.constexpr,\n):\n i_kv, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n NV = tl.cdiv(DV, BV)\n i_k = i_kv // (NV)\n i_v = i_kv % (NV)\n i_h = i_bh % H\n _parallel_based_bwd_dq(\n i_bh, i_c, i_k, i_v, i_h,\n q, k, v, do, dz, dq, s_qk_h, s_qk_t, s_qk_d, s_vo_h,\n s_vo_t, s_vo_d, B, H, T, scale, BTL=BTL, BTS=BTS, BK=BK, BV=BV, DK=DK, DV=DV\n )\n tl.debug_barrier()\n _parallel_based_bwd_dkv(\n i_bh, i_c, i_k, i_v, i_h,\n q, k, v, do, dz, dk, dv, s_qk_h, s_qk_t, s_qk_d, s_vo_h,\n s_vo_t, s_vo_d, B, H, T, scale, BTL, BTS, BK, BV, DK, DV\n )\n\n\nclass ParallelBasedFunction(torch.autograd.Function):\n @staticmethod\n @custom_fwd\n def forward(ctx, q, k, v, scale):\n BTL, BTS = 128, 32\n assert BTL % BTS == 0\n BK = min(128, triton.next_power_of_2(k.shape[-1]))\n BV = min(128, triton.next_power_of_2(v.shape[-1]))\n BK, BV = max(BK, 16), max(BV, 16)\n batch_size, n_heads, seq_len, d_head_qk = q.shape\n d_head_v = v.shape[-1]\n num_stages = 2\n num_warps = 4\n NK = triton.cdiv(d_head_qk, BK)\n NV = triton.cdiv(d_head_v, BV)\n grid = (NK * NV, triton.cdiv(seq_len, BTL), batch_size * n_heads)\n\n assert NK == 1, \"will encounter some synchronization issue if not.\"\n\n o = torch.empty(NK, batch_size, n_heads, seq_len,\n d_head_v, device=q.device)\n z = torch.empty(NK, batch_size, n_heads, seq_len,\n device=q.device)\n parallel_based_fwd_kernel[grid](\n q, k, v, o, z,\n q.stride(1), q.stride(2), q.stride(3),\n v.stride(1), v.stride(2), v.stride(3),\n batch_size, n_heads, seq_len, scale,\n BTL=BTL, BTS=BTS, BK=BK, BV=BV, DK=d_head_qk, DV=d_head_v,\n num_warps=num_warps,\n num_stages=num_stages\n )\n ctx.save_for_backward(q, k, v)\n ctx.scale = scale\n return o.sum(0).to(q.dtype), z.sum(0).to(q.dtype)\n\n @staticmethod\n @custom_bwd\n def backward(ctx, do, dz):\n q, k, v = ctx.saved_tensors\n scale = ctx.scale\n BTL, BTS = 64, 32\n assert BTL % BTS == 0\n BK = min(128, triton.next_power_of_2(k.shape[-1]))\n BV = min(128, triton.next_power_of_2(v.shape[-1]))\n BK, BV = max(BK, 16), max(BV, 16)\n batch_size, n_heads, seq_len, d_head_qk = q.shape\n d_head_v = v.shape[-1]\n num_stages = 2\n num_warps = 4\n NK = triton.cdiv(d_head_qk, BK)\n NV = triton.cdiv(d_head_v, BV)\n grid = (NK * NV, triton.cdiv(seq_len, BTL), batch_size * n_heads)\n\n assert NK == 1, \"will encounter some synchronization issue if not\"\n\n dq = torch.empty(NV, batch_size, n_heads, seq_len,\n d_head_qk, dtype=q.dtype, device=q.device)\n dk = torch.empty(NV, batch_size, n_heads, seq_len,\n d_head_qk, dtype=q.dtype, device=q.device)\n dv = torch.empty(NK, batch_size, n_heads, seq_len,\n d_head_v, dtype=q.dtype, device=q.device)\n\n parallel_based_bwd_kernel[grid](\n q, k, v, do, dz, dq, dk, dv,\n q.stride(1), q.stride(2), q.stride(3),\n v.stride(1), v.stride(2), v.stride(3),\n batch_size, n_heads, seq_len, scale,\n BTL=BTL, BTS=BTS, BK=BK, BV=BV, DK=d_head_qk, DV=d_head_v,\n num_warps=num_warps,\n num_stages=num_stages\n )\n\n return dq.sum(0).to(q.dtype), dk.sum(0).to(k.dtype), dv.sum(0).to(v.dtype), None\n\n\ntriton_parallel_based = ParallelBasedFunction.apply\n\n\ndef parallel_based(q, k, v, use_scale=True, use_normalize=True, return_both=False):\n assert q.shape[-1] <= 128, \"only support feature dim up to 128\"\n if use_scale:\n scale = q.shape[-1] ** -0.5\n else:\n scale = 1\n o, z = triton_parallel_based(q, k, v, scale)\n if return_both:\n return o, z\n if use_normalize:\n o = o / (z[..., None] + 1e-6)\n else:\n o = o\n return o.to(q.dtype)\n", - "description_1": "Use triton language to implement a forward and backward kernel for parallel-based computation in a neural network. The `parallel_based_fwd_kernel` function has 19 parameters, mainly for handling matrix data and computation scale in the forward pass. The `parallel_based_bwd_kernel` function, also with 19 parameters, focuses on the backward pass of the gradient calculation. The functions handle data in batches with specified dimensions for queries, keys, and values, utilizing configurable block sizes and strides. This operation is applied within a custom autograd function `ParallelBasedFunction` in PyTorch.", - "description_2": "Use triton language to implement a parallel-based forward kernel for efficient matrix computations in neural networks, and a backward kernel for gradient calculation with batching and configurable block sizes.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.autotune(\n configs=[\n triton.Config({}, num_warps=1),\n triton.Config({}, num_warps=2),\n triton.Config({}, num_warps=4),\n triton.Config({}, num_warps=8),\n triton.Config({}, num_warps=16),\n triton.Config({}, num_warps=32),\n ],\n key=[\"BT\", \"BK\", \"BV\"],\n)\n@triton.jit\ndef fwd_prepare_dv_kernel(\n q,\n k,\n do,\n dv,\n s_qk_h,\n s_qk_t,\n s_qk_d,\n s_vo_h,\n s_vo_t,\n s_vo_d,\n T,\n K,\n V,\n scale,\n BT: tl.constexpr,\n BK: tl.constexpr,\n BV: tl.constexpr\n):\n i_t, i_bh = tl.program_id(0), tl.program_id(1)\n\n b_A = tl.zeros([BT, BT], dtype=tl.float32)\n\n for i_k in range(tl.cdiv(K, BK)):\n p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1))\n p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))\n b_k = tl.load(p_k, boundary_check=(0, 1))\n b_q = tl.load(p_q, boundary_check=(0, 1))\n b_q = (b_q * scale).to(b_k.dtype)\n b_A += tl.dot(b_k, b_q, allow_tf32=False)\n\n b_A = tl.where(tl.arange(0, BT)[:, None] <= tl.arange(0, BT)[None, :], b_A, 0).to(do.dtype.element_ty)\n\n for i_v in range(tl.cdiv(V, BV)):\n p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))\n b_do = tl.load(p_do, boundary_check=(0, 1))\n p_dv = tl.make_block_ptr(dv + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))\n b_dv = tl.dot(b_A, b_do, allow_tf32=False)\n tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))\n\n\ndef fwd_prepare_dv(q, k, do, BT):\n dv = torch.empty_like(do)\n B, H, T, K, V = *k.shape, do.shape[-1]\n NT = triton.cdiv(T, BT)\n BK = min(triton.next_power_of_2(K), 64)\n BV = min(triton.next_power_of_2(V), 64)\n fwd_prepare_dv_kernel[(NT, B*H)](\n q, k, do, dv,\n k.stride(1), k.stride(2), k.stride(3),\n do.stride(1), do.stride(2), do.stride(3),\n T, K, V, K**-0.5, BT, BK, BV\n )\n return dv\n\n@triton.autotune(\n configs=[\n triton.Config({}, num_warps=1),\n triton.Config({}, num_warps=2),\n triton.Config({}, num_warps=4),\n triton.Config({}, num_warps=8),\n triton.Config({}, num_warps=16),\n triton.Config({}, num_warps=32),\n ],\n key=[\"BT\", \"BK\", \"BV\"],\n)\n@triton.jit\ndef chunk_delta_rule_fwd_kernel_h(\n k,\n v,\n d,\n v_new,\n h,\n initial_state, # initial state of the chunk [B, H, D_head_K, D_head_V]\n final_state, # final state of the chunk [B, H, D_head_K, D_head_V]\n s_qk_h,\n s_qk_t,\n s_qk_d,\n s_vo_h,\n s_vo_t,\n s_vo_d,\n s_h_h,\n s_h_t,\n H: tl.constexpr,\n T: tl.constexpr,\n K: tl.constexpr,\n V: tl.constexpr,\n BT: tl.constexpr,\n BC: tl.constexpr,\n BK: tl.constexpr,\n BV: tl.constexpr,\n NT: tl.constexpr,\n USE_INITIAL_STATE: tl.constexpr,\n STORE_FINAL_STATE: tl.constexpr\n):\n i_k, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n\n # [BK, BV]\n b_h = tl.zeros([BK, BV], dtype=tl.float32)\n\n if USE_INITIAL_STATE:\n p_h0 = tl.make_block_ptr(initial_state + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))\n b_h = tl.load(p_h0, boundary_check=(0, 1)).to(tl.float32)\n\n for i_t in range(NT):\n p_h = tl.make_block_ptr(h + i_bh * s_h_h + i_t * K * V, (K, V), (s_h_t, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))\n tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1))\n b_h_cumsum = tl.zeros([BK, BV], dtype=tl.float32)\n # since we need to make all DK in the SRAM. we face serve SRAM memory burden. By subchunking we allievate such burden\n for i_c in range(tl.cdiv(BT, BC)):\n p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t),\n (i_k * BK, i_t * BT + i_c * BC), (BK, BC), (0, 1))\n p_d = tl.make_block_ptr(d + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d),\n (i_t * BT + i_c * BC, i_k * BK), (BC, BK), (1, 0))\n p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d),\n (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))\n p_v_new = tl.make_block_ptr(v_new + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d),\n (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))\n b_k = tl.load(p_k, boundary_check=(0, 1))\n # [BT, BK]\n b_d = tl.load(p_d, boundary_check=(0, 1))\n # [BT, BV]\n b_v = tl.load(p_v, boundary_check=(0, 1))\n b_v -= tl.dot(b_d, b_h.to(b_k.dtype), allow_tf32=False)\n # [BK, BV]\n tl.store(p_v_new, b_v.to(p_v_new.dtype.element_ty), boundary_check=(0, 1))\n b_h_cumsum += tl.dot(b_k, b_v.to(b_k.dtype), allow_tf32=False)\n b_h += b_h_cumsum\n\n if STORE_FINAL_STATE:\n p_ht = tl.make_block_ptr(final_state + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))\n tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), boundary_check=(0, 1))\n\n\ndef chunk_fwd_h_fn(k, w, u, BT, initial_state, final_state):\n B, H, T, K, V = *k.shape, u.shape[-1]\n\n BK = triton.next_power_of_2(K)\n assert BK <= 256, \"current kernel does not support head dimension larger than 256.\"\n BV = 16 if BK > 128 else 32\n BV = 64 if BK <= 64 else BV\n BC = 16 if BK > 128 else 32\n BC = 64 if BK <= 64 else BC\n BC = min(BT, BC)\n NT, NK, NV = triton.cdiv(T, BT), triton.cdiv(K, BK), triton.cdiv(V, BV)\n assert NK == 1, 'NK > 1 is not supported because it involves time-consuming synchronization'\n\n h = k.new_empty(B, H, NT * K, V)\n grid = (NK, NV, B * H)\n v_new = torch.empty_like(u)\n chunk_delta_rule_fwd_kernel_h[grid](\n k, u, w, v_new, h, initial_state, final_state,\n k.stride(1), k.stride(2), k.stride(3),\n u.stride(1), u.stride(2), u.stride(3),\n h.stride(1), h.stride(2),\n H=H, T=T, K=K, V=V, BT=BT, BC=BC, BK=BK, BV=BV, NT=NT,\n USE_INITIAL_STATE=initial_state is not None,\n STORE_FINAL_STATE=final_state is not None,\n )\n return h, v_new\n", - "description_1": "Use triton language to implement a forward kernel for preparing dv (fwd_prepare_dv_kernel) and a forward kernel for chunk delta rule (chunk_delta_rule_fwd_kernel_h). The fwd_prepare_dv_kernel takes 15 parameters: q, k, do, dv, s_qk_h, s_qk_t, s_qk_d, s_vo_h, s_vo_t, s_vo_d, T, K, V, scale, and BT, BK, BV as constexpr. It computes dv using q, k, and do with given strides and dimensions. The chunk_delta_rule_fwd_kernel_h takes 22 parameters: k, v, d, v_new, h, initial_state, final_state, s_qk_h, s_qk_t, s_qk_d, s_vo_h, s_vo_t, s_vo_d, s_h_h, s_h_t, and H, T, K, V, BT, BC, BK, BV, NT, USE_INITIAL_STATE, STORE_FINAL_STATE as constexpr. It computes the forward pass for a chunk delta rule using k, v, d, and updates v_new and h, considering initial and final states.", - "description_2": "Use triton language to implement kernels for forward computation of dv and chunk delta rule with specific parameters and configurations.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.autotune(\n configs=[\n triton.Config({}, num_warps=1),\n triton.Config({}, num_warps=2),\n triton.Config({}, num_warps=4),\n triton.Config({}, num_warps=8)\n ],\n key=[\"BT\", \"BK\"],\n)\n@triton.jit\ndef fused_chunk_delta_rule_fwd_kernel(\n q, k, v, v_new, d, o, initial_state, final_state,\n s_qk_h, s_qk_t, s_qk_d, s_vo_h, s_vo_t, s_vo_d,\n B, H, T, scale,\n BT: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr,\n DK: tl.constexpr, DV: tl.constexpr,\n USE_INITIAL_STATE: tl.constexpr, STORE_FINAL_STATE: tl.constexpr, CHECK: tl.constexpr\n):\n i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n o_i = tl.arange(0, BT)\n m_s = o_i[:, None] >= o_i[None, :]\n b_h = tl.zeros([BK, BV], dtype=tl.float32)\n p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (0, i_k * BK), (BT, BK), (1, 0))\n p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (DK, T), (s_qk_d, s_qk_t), (i_k * BK, 0), (BK, BT), (0, 1))\n p_d = tl.make_block_ptr(d + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (0, i_k * BK), (BT, BK), (1, 0))\n p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (0, i_v * BV), (BT, BV), (1, 0))\n p_o = tl.make_block_ptr(o + (i_bh+i_k*B*H) * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (0, i_v * BV), (BT, BV), (1, 0))\n p_v_new = tl.make_block_ptr(v_new + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (0, i_v * BV), (BT, BV), (1, 0))\n\n if USE_INITIAL_STATE:\n p_h = tl.make_block_ptr(initial_state + i_bh * DK * DV, (DK, DV), (DV, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))\n b_h = tl.load(p_h, boundary_check=(0, 1)).to(tl.float32)\n\n for i in range(0, tl.cdiv(T, BT)):\n b_k = tl.load(p_k, boundary_check=(0, 1))\n b_v = tl.load(p_v, boundary_check=(0, 1))\n b_q = tl.load(p_q, boundary_check=(0, 1))\n b_d = tl.load(p_d, boundary_check=(0, 1))\n b_q = (b_q * scale).to(b_k.dtype)\n b_s = tl.dot(b_q, b_k, allow_tf32=False)\n b_s = tl.where(m_s, b_s, 0)\n b_v_prime = tl.dot(b_d, b_h.to(b_q.dtype), allow_tf32=False)\n b_v = b_v - b_v_prime\n tl.store(p_v_new, b_v.to(p_v.dtype.element_ty), boundary_check=(0, 1))\n b_o = tl.dot(b_s.to(b_q.dtype), b_v.to(b_q.dtype), allow_tf32=False)\n if CHECK and i == 0:\n b_o += tl.dot(b_q, b_h.to(b_q.dtype), allow_tf32=False)\n b_h = b_h + tl.dot(b_k, b_v.to(b_k.dtype), allow_tf32=False)\n else:\n b_o += tl.dot(b_q, b_h.to(b_q.dtype), allow_tf32=False)\n b_h = b_h + tl.dot(b_k, b_v.to(b_k.dtype), allow_tf32=False)\n tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))\n p_q = tl.advance(p_q, (BT, 0))\n p_k = tl.advance(p_k, (0, BT))\n p_v = tl.advance(p_v, (BT, 0))\n p_v_new = tl.advance(p_v_new, (BT, 0))\n p_o = tl.advance(p_o, (BT, 0))\n p_d = tl.advance(p_d, (BT, 0))\n\n if STORE_FINAL_STATE:\n p_final = tl.make_block_ptr(final_state + i_bh * DK * DV, (DK, DV), (DV, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))\n tl.store(p_final, b_h.to(p_final.dtype.element_ty), boundary_check=(0, 1))\n\n\n@triton.autotune(\n configs=[\n triton.Config({}, num_warps=1),\n triton.Config({}, num_warps=2),\n triton.Config({}, num_warps=4),\n triton.Config({}, num_warps=8),\n triton.Config({}, num_warps=16),\n triton.Config({}, num_warps=32),\n ],\n key=[\"BT\", \"BK\", \"BV\"],\n)\n@triton.jit\ndef fused_chunk_delta_rule_bwd_kernel(\n q, k, v, d, do, dq, dk, dv, dd, initial_state,\n s_qk_h, s_qk_t, s_qk_d, s_vo_h, s_vo_t, s_vo_d,\n B, H, T, scale,\n BT: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr,\n DK: tl.constexpr, DV: tl.constexpr,\n USE_INITIAL_STATE: tl.constexpr, CHECK: tl.constexpr\n):\n i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n o_i = tl.arange(0, BT)\n b_dh = tl.zeros([BK, BV], dtype=tl.float32)\n m_s = o_i[:, None] <= o_i[None, :]\n for i in range(1, tl.cdiv(T, BT) + 1):\n p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (DK, T), (s_qk_d, s_qk_t), (i_k * BK, T - i * BT), (BK, BT), (0, 1))\n p_d = tl.make_block_ptr(d + i_bh * s_qk_h, (DK, T), (s_qk_d, s_qk_t), (i_k * BK, T - i * BT), (BK, BT), (0, 1))\n p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (T - i * BT, i_k * BK), (BT, BK), (1, 0))\n p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (T - i * BT, i_v * BV), (BT, BV), (1, 0))\n p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (T - i * BT, i_v * BV), (BT, BV), (1, 0))\n p_dk = tl.make_block_ptr(dk + (i_bh+i_v*B*H) * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (T - i*BT, i_k*BK), (BT, BK), (1, 0))\n p_dv = tl.make_block_ptr(dv + (i_bh+i_k*B*H) * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (T - i*BT, i_v*BV), (BT, BV), (1, 0))\n b_q = tl.load(p_q, boundary_check=(0, 1))\n b_q = (b_q * scale).to(b_q.dtype)\n b_k = tl.load(p_k, boundary_check=(0, 1))\n b_v = tl.load(p_v, boundary_check=(0, 1))\n b_do = tl.load(p_do, boundary_check=(0, 1))\n b_ds = tl.dot(b_v, tl.trans(b_do), allow_tf32=False)\n b_ds = tl.where(m_s, b_ds, 0).to(b_q.dtype)\n b_s = tl.dot(b_k, b_q, allow_tf32=False)\n b_s = tl.where(m_s, b_s, 0).to(b_q.dtype)\n b_dk = tl.dot(b_ds, tl.trans(b_q), allow_tf32=False)\n b_dv = tl.dot(b_s, b_do, allow_tf32=False)\n b_d = tl.load(p_d, boundary_check=(0, 1))\n if CHECK and i == 1:\n b_dk += tl.dot(b_v, tl.trans(b_dh).to(b_v.dtype), allow_tf32=False)\n b_dv += tl.dot(b_k, b_dh.to(b_k.dtype), allow_tf32=False)\n b_dh += tl.dot(b_q, b_do, allow_tf32=False)\n b_dh -= tl.dot(b_d, b_dv.to(b_d.dtype), allow_tf32=False)\n else:\n b_dk += tl.dot(b_v, tl.trans(b_dh).to(b_v.dtype), allow_tf32=False)\n b_dv += tl.dot(b_k, b_dh.to(b_k.dtype), allow_tf32=False)\n b_dh += tl.dot(b_q, b_do, allow_tf32=False)\n b_dh -= tl.dot(b_d, b_dv.to(b_d.dtype), allow_tf32=False)\n\n tl.store(p_dk, (b_dk).to(p_dk.dtype.element_ty), boundary_check=(0, 1))\n tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))\n\n b_h = None\n tl.debug_barrier()\n m_s = o_i[:, None] >= o_i[None, :]\n b_h = tl.zeros([BV, BK], dtype=tl.float32)\n if USE_INITIAL_STATE:\n p_h = tl.make_block_ptr(initial_state + i_bh * DK * DV, (DV, DK), (1, DV), (i_v * BV, i_k * BK), (BV, BK), (0, 1))\n b_h = tl.load(p_h, boundary_check=(0, 1)).to(tl.float32)\n NT = tl.cdiv(T, BT)\n for i in range(0, NT):\n p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (i * BT, i_k * BK), (BT, BK), (1, 0))\n p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (DV, T), (s_vo_d, s_vo_t), (i_v * BV, i * BT), (BV, BT), (0, 1))\n p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (i * BT, i_v * BV), (BT, BV), (1, 0))\n p_dq = tl.make_block_ptr(dq + (i_bh + i_v*B*H) * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (i*BT, i_k*BK), (BT, BK), (1, 0))\n b_k = tl.load(p_k, boundary_check=(0, 1))\n b_v = tl.load(p_v, boundary_check=(0, 1))\n b_do = tl.load(p_do, boundary_check=(0, 1))\n b_ds = tl.dot(b_do, b_v, allow_tf32=False)\n b_ds = tl.where(m_s, b_ds, 0)\n b_dq = tl.dot(b_ds.to(b_k.dtype), b_k, allow_tf32=False)\n if CHECK and i == 0:\n b_dq += tl.dot(b_do, b_h.to(b_do.dtype), allow_tf32=False)\n b_h = b_h + tl.dot(b_v, b_k, allow_tf32=False)\n else:\n b_dq += tl.dot(b_do, b_h.to(b_do.dtype), allow_tf32=False)\n b_h = b_h + tl.dot(b_v, b_k, allow_tf32=False)\n b_dq *= scale\n tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))\n\n if i < (NT - 1):\n p_dv = tl.make_block_ptr(dv + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), ((i + 1) * BT, i_v * BV), (BT, BV), (1, 0))\n b_dv = tl.load(p_dv, boundary_check=(0, 1))\n b_dd = tl.dot(b_dv.to(b_k.dtype), b_h.to(b_k.dtype), allow_tf32=False)\n p_dd = tl.make_block_ptr(dd + (i_bh + i_v*B*H) * s_qk_h, (T, DK), (s_qk_t, s_qk_d),\n ((i+1) * BT, i_k * BK), (BT, BK), (1, 0))\n tl.store(p_dd, -b_dd.to(p_dd.dtype.element_ty), boundary_check=(0, 1))\n\n\ndef fused_chunk_delta_rule_fwd(q, k, v, d, BT, initial_state, output_final_state):\n batch_size, n_heads, seq_len, d_head_qk = q.shape\n d_head_v = v.shape[-1]\n scale = d_head_qk ** -0.5\n BT = BT\n BK, BV = triton.next_power_of_2(d_head_qk), min(triton.next_power_of_2(d_head_v), 32)\n NK, NV = triton.cdiv(d_head_qk, BK), triton.cdiv(d_head_v, BV)\n assert NK == 1, 'NK should be 1'\n o = q.new_empty(batch_size, n_heads, seq_len, d_head_v)\n if output_final_state:\n final_state = q.new_empty(batch_size, n_heads, d_head_qk, d_head_v, dtype=torch.float32, requires_grad=False)\n else:\n final_state = None\n CHECK = True\n grid = (NV, NK, batch_size * n_heads)\n v_new = torch.empty_like(v)\n fused_chunk_delta_rule_fwd_kernel[grid](\n q, k, v, v_new, d, o, initial_state, final_state,\n q.stride(1), q.stride(2), q.stride(3),\n v.stride(1), v.stride(2), v.stride(3),\n batch_size, n_heads, seq_len, scale,\n BT=BT, DK=d_head_qk, DV=d_head_v, BK=BK, BV=BV,\n USE_INITIAL_STATE=initial_state is not None,\n STORE_FINAL_STATE=output_final_state,\n CHECK=CHECK,\n )\n return o, v_new, CHECK, final_state\n\n\ndef fused_chunk_delta_rule_bwd(q, k, v, d, do, BT, CHECK, initial_state):\n batch_size, n_heads, seq_len, d_head_qk = q.shape\n d_head_v = v.shape[-1]\n scale = d_head_qk ** -0.5\n BK, BV = triton.next_power_of_2(d_head_qk), min(triton.next_power_of_2(d_head_v), 32)\n NK, NV = triton.cdiv(d_head_qk, BK), triton.cdiv(d_head_v, BV)\n assert NK == 1\n dq = q.new_empty(NV, batch_size, n_heads, seq_len, d_head_qk)\n dk = q.new_empty(NV, batch_size, n_heads, seq_len, d_head_qk)\n dd = q.new_empty(NV, batch_size, n_heads, seq_len, d_head_qk)\n dv = q.new_empty(NK, batch_size, n_heads, seq_len, d_head_v)\n grid = (NV, NK, batch_size * n_heads)\n fused_chunk_delta_rule_bwd_kernel[grid](\n q, k, v, d, do, dq, dk, dv, dd, initial_state,\n q.stride(1), q.stride(2), q.stride(3),\n v.stride(1), v.stride(2), v.stride(3),\n batch_size, n_heads, seq_len, scale,\n BT=BT, DK=d_head_qk, DV=d_head_v, BK=BK, BV=BV,\n USE_INITIAL_STATE=initial_state is not None,\n CHECK=CHECK,\n )\n dq = dq.sum(0)\n dk = dk.sum(0)\n dv = dv.sum(0)\n dd = dd.sum(0)\n dd[:, :, 0:BT] = 0\n return dq, dk, dv, dd\n", - "description_1": "Use triton language to implement a forward and backward kernel for a fused chunk delta rule operation. The forward kernel takes 24 parameters: q, k, v, v_new, d, o, initial_state, final_state, s_qk_h, s_qk_t, s_qk_d, s_vo_h, s_vo_t, s_vo_d, B, H, T, scale, BT, BK, BV, DK, DV, USE_INITIAL_STATE, STORE_FINAL_STATE, CHECK. The backward kernel takes 23 parameters: q, k, v, d, do, dq, dk, dv, dd, initial_state, s_qk_h, s_qk_t, s_qk_d, s_vo_h, s_vo_t, s_vo_d, B, H, T, scale, BT, BK, BV, DK, DV, USE_INITIAL_STATE, CHECK.", - "description_2": "Use triton language to create a fused chunk delta rule operation with forward and backward kernels, handling input tensors and various parameters for computation and memory management.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n\n@triton.jit\ndef fused_recurrent_fwd_kernel(\n q, # query [B, H, L, K]\n k, # key [B, H, L, V]\n v, # value [B, H, L, V].\n beta, # beta [B, H, L]\n o, # output [B, H, L, V]\n h0,\n ht, # final hidden state [B, H, K, V]\n s_qk_h, # stride size: L * K\n s_vo_h, # stride size: L * V\n scale, # K ** -0.5\n B, # batch size\n H, # n_heads\n T, # seq_len\n K: tl.constexpr, # K\n V: tl.constexpr, # V\n BK: tl.constexpr, # BLOCK SIZE along the K dimension\n BV: tl.constexpr, # BLOCK SIZE along the V dimension\n USE_INITIAL_STATE: tl.constexpr, # whether to use initial state\n STORE_FINAL_STATE: tl.constexpr, # whether to store final state\n IS_HEADWISE_BETA: tl.constexpr, # whether beta is headwise vector or scalar\n):\n i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n\n p_q = q + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK)\n p_k = k + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK)\n p_v = v + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV)\n if IS_HEADWISE_BETA:\n p_beta = beta + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV)\n else:\n p_beta = beta + i_bh * T\n p_o = o + (i_bh + i_k * B * H) * s_vo_h + i_v * BV + tl.arange(0, BV)\n\n mask_bk = (i_k * BK + tl.arange(0, BK)) < K\n mask_bv = (i_v * BV + tl.arange(0, BV)) < V\n mask_kv = mask_bk[None, :] & mask_bv[:, None]\n\n h = tl.zeros([BV, BK], dtype=tl.float32)\n\n if USE_INITIAL_STATE:\n p_h0 = h0 + i_bh * K * V + (i_k * BK + tl.arange(0, BK)[None, :]) * V + (i_v * BV + tl.arange(0, BV)[:, None])\n h += tl.load(p_h0, mask=mask_kv, other=0).to(tl.float32)\n\n for _ in range(0, T):\n b_k = tl.load(p_k, mask=mask_bk, other=0).to(tl.float32)\n b_v = tl.load(p_v, mask=mask_bv, other=0).to(tl.float32)\n b_q = tl.load(p_q, mask=mask_bk, other=0).to(tl.float32) * scale\n _v_minus = tl.sum(h * b_k[None, :], axis=1)\n b_v -= _v_minus\n if IS_HEADWISE_BETA:\n b_beta = tl.load(p_beta, mask=mask_bv, other=0).to(tl.float32)\n else:\n b_beta = tl.load(p_beta).to(tl.float32)\n tl.store(p_v, b_v.to(p_v.dtype.element_ty), mask=mask_bv)\n b_v *= b_beta\n h += b_k[None, :] * b_v[:, None]\n _o = h * b_q[None, :]\n _o = tl.sum(_o, axis=1)\n tl.store(p_o, _o.to(p_o.dtype.element_ty), mask=mask_bv)\n\n p_q += K\n p_k += K\n p_o += V\n p_v += V\n p_beta += V if IS_HEADWISE_BETA else 1\n\n if STORE_FINAL_STATE:\n p_ht = ht + i_bh * K * V + (i_k * BK + tl.arange(0, BK)[None, :]) * V + (i_v * BV + tl.arange(0, BV)[:, None])\n tl.store(p_ht, h.to(p_ht.dtype.element_ty), mask=mask_kv)\n\n\n@triton.jit\ndef fused_recurrent_bwd_kernel(\n q, # query [B, H, L, K]\n k, # key [B, H, L, V]\n v, # value [B, H, L, V]\n beta, # beta [B, H, L, (V)]\n\n do, # gradient of output [B, H, L, V]\n dq, # gradient of query [NV, B, H, L, K]\n dk, # gradient of key [NV, B, H, L, K]\n dv, # gradient of value [NK, B, H, L, V]\n dbeta, # gradient of beta [NV, (NK), B, H, L]\n\n h0,\n\n s_qk_h, # stride size: L * K\n\n s_vo_h, # stride size: L * V\n\n NK, # NK block size\n scale, # K ** -0.5\n\n B, # batch_size\n H, # n_heads\n T, # seq_len\n K: tl.constexpr, # K\n V: tl.constexpr, # V\n BK: tl.constexpr, # BLOCK SIZE along the K dimension\n BV: tl.constexpr, # BLOCK SIZE along the V dimension\n USE_INITIAL_STATE: tl.constexpr, # whether to use initial state\n IS_HEADWISE_BETA: tl.constexpr, # whether beta is headwise vector or scalar\n):\n i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n mask_bk = i_k * BK + tl.arange(0, BK) < K\n mask_bv = i_v * BV + tl.arange(0, BV) < V\n\n p_q = q + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (T - 1) * K\n p_k = k + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (T - 1) * K\n p_do = do + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV) + (T - 1) * V\n p_v = v + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV) + (T - 1) * V\n if IS_HEADWISE_BETA:\n p_beta = beta + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV) + (T - 1) * V\n else:\n p_beta = beta + i_bh * T + T - 1\n\n p_dk = dk + (i_bh + i_v * B * H) * s_qk_h + i_k * BK + tl.arange(0, BK) + (T - 1) * K\n p_dv = dv + (i_bh + i_k * B * H) * s_vo_h + i_v * BV + tl.arange(0, BV) + (T - 1) * V\n if IS_HEADWISE_BETA:\n p_dbeta = dbeta + (i_bh + i_k * B * H + i_v * B * H * NK) * s_vo_h + tl.arange(0, BV) + (T - 1) * V\n else:\n p_dbeta = dbeta + (i_bh + i_v * B * H) * T + T - 1\n d_h = tl.zeros([BK, BV], dtype=tl.float32)\n\n for _ in range(T):\n b_q = tl.load(p_q, mask=mask_bk, other=0).to(tl.float32) * scale\n b_k = tl.load(p_k, mask=mask_bk, other=0).to(tl.float32)\n b_v = tl.load(p_v, mask=mask_bv, other=0).to(tl.float32)\n b_do = tl.load(p_do, mask=mask_bv, other=0).to(tl.float32)\n if IS_HEADWISE_BETA:\n b_beta = tl.load(p_beta, mask=mask_bv, other=0).to(tl.float32)\n else:\n b_beta = tl.load(p_beta).to(tl.float32)\n d_h += b_q[:, None] * b_do[None, :]\n d_k = tl.sum(d_h * (b_v * b_beta)[None, :], axis=1)\n d_v = tl.sum(d_h * b_k[:, None], axis=0)\n\n d_beta = d_v * b_v if IS_HEADWISE_BETA else tl.sum(d_v * b_v)\n d_v = d_v * b_beta\n\n tl.store(p_dk, d_k.to(p_dk.dtype.element_ty), mask=mask_bk)\n tl.store(p_dv, d_v.to(p_dv.dtype.element_ty), mask=mask_bv)\n if IS_HEADWISE_BETA:\n tl.store(p_dbeta, d_beta.to(p_dbeta.dtype.element_ty), mask=mask_bv)\n else:\n tl.store(p_dbeta, d_beta.to(p_dbeta.dtype.element_ty))\n\n d_h -= b_k[:, None] * d_v[None, :]\n\n p_do -= V\n p_q -= K\n p_k -= K\n p_v -= V\n p_dk -= K\n p_dv -= V\n p_dbeta -= V if IS_HEADWISE_BETA else 1\n p_beta -= V if IS_HEADWISE_BETA else 1\n\n tl.debug_barrier()\n\n h = tl.zeros([BK, BV], dtype=tl.float32)\n\n p_q = q + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK)\n p_k = k + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK)\n p_v = v + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV)\n if IS_HEADWISE_BETA:\n p_beta = beta + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV)\n else:\n p_beta = beta + i_bh * T\n p_do = do + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV)\n p_dq = dq + (i_bh + i_v * B * H) * s_qk_h + i_k * BK + tl.arange(0, BK)\n p_dv = dv + (i_bh + i_k * B * H) * s_vo_h + i_v * BV + tl.arange(0, BV) + V\n p_dk = dk + (i_bh + i_v * B * H) * s_qk_h + i_k * BK + tl.arange(0, BK) + K\n\n if USE_INITIAL_STATE:\n mask_kv = mask_bk[:, None] & mask_bv[None, :]\n p_h0 = h0 + i_bh * K * V + (i_k * BK + tl.arange(0, BK)[:, None]) * V + (i_v * BV + tl.arange(0, BV)[None, :])\n h += tl.load(p_h0, mask=mask_kv, other=0).to(tl.float32)\n\n for i in range(0, T):\n b_k = tl.load(p_k, mask=mask_bk, other=0).to(tl.float32)\n b_v = tl.load(p_v, mask=mask_bv, other=0).to(tl.float32)\n b_do = tl.load(p_do, mask=mask_bv, other=0).to(tl.float32)\n if IS_HEADWISE_BETA:\n b_beta = tl.load(p_beta, mask=mask_bv, other=0).to(tl.float32)\n else:\n b_beta = tl.load(p_beta).to(tl.float32)\n b_v *= b_beta\n\n h += b_k[:, None] * b_v[None, :]\n _d_q = h * b_do[None, :]\n d_q = tl.sum(_d_q, axis=1) * scale\n tl.store(p_dq, d_q.to(p_dq.dtype.element_ty), mask=mask_bk)\n\n if i < T - 1:\n d_k = tl.load(p_dk, mask=mask_bk, other=0).to(tl.float32)\n d_v = tl.load(p_dv, mask=mask_bv, other=0).to(tl.float32)\n d_k -= tl.sum(d_v[None, :] * h, axis=1)\n tl.store(p_dk, d_k.to(p_dk.dtype.element_ty), mask=mask_bk)\n\n p_k += K\n p_do += V\n p_v += V\n p_dk += K\n p_dv += V\n p_dq += K\n p_beta += V if IS_HEADWISE_BETA else 1\n\n\nclass FusedRecurrentFunction(torch.autograd.Function):\n\n @staticmethod\n def forward(ctx, q, k, v, beta, scale=None, initial_state=None, output_final_state=False):\n B, H, T, K, V = *q.shape, v.shape[-1]\n\n BK, BV = triton.next_power_of_2(K), min(triton.next_power_of_2(V), 8)\n NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)\n num_stages = 1\n num_warps = 1\n assert NK == 1, \"NK > 1 is not supported yet\"\n o = q.new_empty(NK, B, H, T, V)\n\n if output_final_state:\n final_state = q.new_empty(B, H, K, V)\n else:\n final_state = None\n\n grid = (NV, NK, B * H)\n fused_recurrent_fwd_kernel[grid](\n q, k, v, beta, o, initial_state, final_state,\n q.stride(1),\n v.stride(1),\n scale,\n B=B, H=H, T=T, K=K, V=V,\n BK=BK, BV=BV,\n USE_INITIAL_STATE=initial_state is not None,\n STORE_FINAL_STATE=final_state is not None,\n IS_HEADWISE_BETA=beta.ndim == v.ndim,\n num_warps=num_warps,\n num_stages=num_stages,\n )\n o = o.sum(0)\n ctx.save_for_backward(q, k, v, beta, initial_state)\n ctx.scale = scale\n return o, final_state\n\n @staticmethod\n def backward(ctx, do, dht=None):\n q, k, v, beta, initial_state = ctx.saved_tensors\n B, H, T, K, V = *q.shape, v.shape[-1]\n scale = ctx.scale\n BK, BV = triton.next_power_of_2(K), min(triton.next_power_of_2(V), 32)\n NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)\n assert NK == 1, \"NK > 1 is not supported yet\"\n num_stages = 1\n num_warps = 2\n\n beta_vector = beta.ndim == v.ndim\n\n dq = q.new_empty(NV, B, H, T, K)\n dk = q.new_empty(NV, B, H, T, K)\n dv = q.new_empty(NK, B, H, T, V)\n if beta_vector:\n dbeta = q.new_empty(NV, NK, B, H, T, V)\n else:\n dbeta = q.new_empty(NV, B, H, T)\n grid = (NV, NK, B * H)\n\n fused_recurrent_bwd_kernel[grid](\n q, k, v, beta, do, dq, dk, dv, dbeta, initial_state,\n q.stride(1),\n v.stride(1),\n NK, scale,\n B=B, H=H, T=T, K=K, V=V,\n BK=BK, BV=BV,\n USE_INITIAL_STATE=initial_state is not None,\n IS_HEADWISE_BETA=beta_vector,\n num_warps=num_warps,\n num_stages=num_stages\n )\n dq = dq.sum(0)\n dk = dk.sum(0)\n dv = dv.sum(0)\n dbeta = dbeta.sum((0, 1)) if beta_vector else dbeta.sum(0)\n return dq.to(q), dk.to(k), dv.to(v), dbeta.to(beta), None, None, None\n\n\ndef fused_recurrent_delta_rule(\n q: torch.Tensor,\n k: torch.Tensor,\n v: torch.Tensor,\n beta: torch.Tensor = None,\n scale: float = -1,\n initial_state: torch.Tensor = None,\n output_final_state: bool = False,\n) -> Tuple[torch.Tensor, torch.Tensor]:\n if scale == -1:\n scale = q.shape[-1] ** -0.5\n if initial_state is not None:\n initial_state = initial_state.detach()\n if beta is None:\n beta = torch.ones_like(q[..., 0])\n o, final_state = FusedRecurrentFunction.apply(q, k, v, beta, scale, initial_state, output_final_state)\n return o, final_state\n", - "description_1": "Use triton language to implement a fused recurrent forward and backward pass in a sequence processing model. The forward kernel 'fused_recurrent_fwd_kernel' has 17 parameters, including input queries, keys, values, beta, initial state, dimensions, and flags for using/storing initial/final states. The backward kernel 'fused_recurrent_bwd_kernel' has 19 parameters, handling gradients of inputs, output gradients, beta gradients, dimensions, and flags. The class 'FusedRecurrentFunction' contains forward and backward static methods utilizing these kernels, and the function 'fused_recurrent_delta_rule' serves as a callable interface.", - "description_2": "Use triton language to create fused forward and backward kernels for sequence processing. Implement an autograd function in PyTorch to integrate the kernels into a model.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.autotune(\n configs=[\n triton.Config({}, num_warps=1),\n triton.Config({}, num_warps=2),\n triton.Config({}, num_warps=4),\n triton.Config({}, num_warps=8),\n triton.Config({}, num_warps=16),\n triton.Config({}, num_warps=32),\n ],\n key=[\"BT\", \"BK\", \"BV\"],\n)\n@triton.jit\ndef fwd_prepare_wy_repr_kernel(\n k,\n v,\n beta,\n o,\n o2,\n T,\n K,\n V,\n BT: tl.constexpr,\n BK: tl.constexpr,\n BV: tl.constexpr\n):\n i_t, i_bh = tl.program_id(0), tl.program_id(1)\n\n p_k = k + i_bh * T * K + (i_t * BT + tl.arange(0, BT)[:, None]) * K + tl.arange(0, BK)[None, :]\n p_v = v + i_bh * T * V + (i_t * BT + tl.arange(0, BT)[:, None]) * V + tl.arange(0, BV)[None, :]\n p_beta = beta + i_bh * T + i_t * BT + tl.arange(0, BT)\n mask_bt = (tl.arange(0, BT) + i_t * BT) < T\n mask_bk = tl.arange(0, BK) < K\n mask_bv = tl.arange(0, BV) < V\n mask_bk = mask_bk[None, :] & mask_bt[:, None]\n mask_bv = mask_bv[None, :] & mask_bt[:, None]\n # [BT, BK]\n b_k = tl.load(p_k, mask=mask_bk, other=0)\n # [BT,]\n b_beta = tl.load(p_beta, mask=mask_bt, other=0).to(tl.float32)\n # [BT, BV]\n b_v = tl.load(p_v, mask=mask_bv, other=0)\n b_v = (b_v * b_beta[:, None]).to(b_v.dtype)\n # [BT, BK]\n b_kb = (b_k * b_beta[:, None]).to(b_k.dtype)\n # [BT, BT]\n b_A = tl.dot(b_kb, tl.trans(b_k), allow_tf32=False)\n b_A = -tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], b_A, 0)\n\n for i in range(BT):\n mask = tl.arange(0, BT) == i\n b_a = tl.sum(tl.where(mask[:, None], b_A, 0), 0)\n b_a = b_a + tl.sum(b_a[:, None] * b_A, 0) * (tl.arange(0, BT) < i)\n b_A = tl.where(mask[:, None], b_a, b_A)\n b_A += tl.arange(0, BT)[:, None] == tl.arange(0, BT)[None, :]\n b_A = b_A.to(b_k.dtype)\n b_w = tl.dot(b_A, b_kb, allow_tf32=False)\n b_u = tl.dot(b_A, b_v, allow_tf32=False)\n\n p_o = o + i_bh * T * K + (i_t * BT + tl.arange(0, BT)[:, None]) * K + tl.arange(0, BK)[None, :]\n tl.store(p_o, b_w.to(p_o.dtype.element_ty), mask=mask_bk)\n p_o2 = o2 + i_bh * T * V + (i_t * BT + tl.arange(0, BT)[:, None]) * V + tl.arange(0, BV)[None, :]\n tl.store(p_o2, b_u.to(p_o2.dtype.element_ty), mask=mask_bv)\n\n\n@triton.autotune(\n configs=[\n triton.Config({}, num_warps=1),\n triton.Config({}, num_warps=2),\n triton.Config({}, num_warps=4),\n triton.Config({}, num_warps=8),\n triton.Config({}, num_warps=16),\n triton.Config({}, num_warps=32),\n ],\n key=[\"BT\", \"BK\", \"BV\"],\n)\n@triton.jit\ndef bwd_prepare_wy_repr_kernel(\n k, v, beta,\n o, o2, do, do2,\n dk, dv, dbeta,\n NT, K, V, T,\n BT: tl.constexpr,\n BK: tl.constexpr,\n BV: tl.constexpr,\n):\n i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n p_k = k + i_bh * T * K + (i_t * BT + tl.arange(0, BT)[:, None]) * K + tl.arange(0, BK)[None, :]\n p_do = do + i_bh * T * K + (i_t * BT + tl.arange(0, BT)[:, None]) * K + tl.arange(0, BK)[None, :]\n p_do2 = do2 + i_bh * T * V + (i_t * BT + tl.arange(0, BT)[:, None]) * V + tl.arange(0, BV)[None, :]\n\n p_beta = beta + i_bh * T + i_t * BT + tl.arange(0, BT)\n mask_bt = (tl.arange(0, BT) + i_t * BT) < T\n mask_bk = (tl.arange(0, BK) < K)[None, :] & mask_bt[:, None]\n mask_bv = (tl.arange(0, BV) < V)[None, :] & mask_bt[:, None]\n b_k, b_beta = tl.load(p_k, mask=mask_bk), tl.load(p_beta, mask=mask_bt)\n\n b_beta = b_beta.to(tl.float32)\n A = tl.dot(b_k, tl.trans(b_k), allow_tf32=False) * b_beta[:, None]\n A = tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], A, 0)\n b_do = tl.load(p_do, mask=mask_bk).to(tl.float32)\n b_dv = tl.load(p_do2, mask=mask_bv).to(tl.float32)\n dA = tl.zeros([BT, BT], dtype=tl.float32)\n b_dk = tl.zeros([BT, BK], dtype=tl.float32)\n for i in range(BT-1, -1, -1):\n mask = tl.arange(0, BT) == i\n attn = tl.sum(tl.where(mask[:, None], A, 0), axis=0)\n do_ = tl.sum(tl.where(mask[:, None], b_do, 0), axis=0)\n dv_ = tl.sum(tl.where(mask[:, None], b_dv, 0), axis=0)\n b_do = b_do - attn[:, None] * do_[None, :]\n b_dv = b_dv - attn[:, None] * dv_[None, :]\n tl.debug_barrier()\n p_v = v + i_bh * T * V + (i_t * BT + tl.arange(0, BT)[:, None]) * V + tl.arange(0, BV)[None, :]\n b_v = tl.load(p_v, mask=mask_bv)\n b_dk += b_do * b_beta[:, None]\n b_dbeta = tl.sum(b_do * b_k, axis=1)\n b_dbeta += tl.sum(b_dv * b_v, axis=1)\n b_v = None\n\n p_o = o + i_bh * T * K + (i_t * BT + tl.arange(0, BT)[:, None]) * K + tl.arange(0, BK)[None, :]\n p_o2 = o2 + i_bh * T * V + (i_t * BT + tl.arange(0, BT)[:, None]) * V + tl.arange(0, BV)[None, :]\n b_o = tl.load(p_o, mask=mask_bk)\n b_o2 = tl.load(p_o2, mask=mask_bv)\n\n dA = -tl.dot(b_do.to(b_o.dtype), tl.trans(b_o), allow_tf32=False)\n dA -= tl.dot(b_dv.to(b_o2.dtype), tl.trans(b_o2).to(b_o.dtype),\n allow_tf32=False)\n dA = tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], dA, 0)\n b_dv *= b_beta[:, None]\n p_dv = dv + i_bh * T * V + (i_t * BT + tl.arange(0, BT)[:, None]) * V + tl.arange(0, BV)[None, :]\n tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), mask=mask_bv)\n\n b_dbeta += tl.sum(dA * tl.dot(b_k, tl.trans(b_k), allow_tf32=False), axis=1)\n dA = dA * b_beta[:, None]\n b_dk += tl.dot(tl.trans(dA.to(b_k.dtype)), b_k, allow_tf32=False)\n b_dk += tl.dot(dA.to(b_k.dtype), b_k, allow_tf32=False)\n p_dk = dk + i_bh * T * K + (i_t * BT + tl.arange(0, BT)[:, None]) * K + tl.arange(0, BK)[None, :]\n tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), mask=mask_bk)\n p_dbeta = dbeta + i_bh * T + i_t * BT + tl.arange(0, BT)\n tl.store(p_dbeta, b_dbeta.to(p_dbeta.dtype.element_ty), mask=mask_bt)\n\n\ndef fwd_prepare_wy_repr(k, v, beta, chunk_size):\n B, H, T, K, V = *k.shape, v.shape[-1]\n v_new = torch.empty_like(v)\n o_cumdecay = torch.empty_like(k)\n BT = chunk_size\n NT = triton.cdiv(T, BT)\n BK = triton.next_power_of_2(K)\n BV = triton.next_power_of_2(V)\n fwd_prepare_wy_repr_kernel[(NT, B*H)](\n k, v, beta, o_cumdecay, v_new,\n T, K, V, BT, BK, BV\n )\n return o_cumdecay, v_new\n\n\ndef bwd_prepare_wy_repr(k, v, beta, o_cumdecay, v_new, do, do2, chunk_size):\n b, h, l, d_k = do.shape\n d_v = v.shape[-1]\n BK = triton.next_power_of_2(d_k)\n BV = triton.next_power_of_2(d_v)\n c = chunk_size\n BK = d_k\n NT = triton.cdiv(l, c)\n dk = torch.empty_like(k)\n dv = torch.empty_like(v)\n dbeta = torch.zeros_like(beta)\n bwd_prepare_wy_repr_kernel[(NT, b*h)](\n k, v, beta,\n o_cumdecay, v_new, do, do2,\n dk, dv, dbeta,\n NT, d_k, d_v, l, chunk_size, BK, BV\n )\n return dk, dv, dbeta\n", - "description_1": "Use triton language to implement two kernels: fwd_prepare_wy_repr_kernel and bwd_prepare_wy_repr_kernel. The fwd_prepare_wy_repr_kernel takes 10 parameters: k, v, beta, o, o2, T, K, V, BT, BK, BV. It computes the forward pass of the WY representation preparation. The bwd_prepare_wy_repr_kernel takes 15 parameters: k, v, beta, o, o2, do, do2, dk, dv, dbeta, NT, K, V, T, BT, BK, BV. It computes the backward pass of the WY representation preparation.", - "description_2": "Use triton language to create forward and backward kernels for WY representation preparation, handling input tensors and computing necessary transformations and gradients.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.autotune(\n configs=[\n triton.Config({}, num_warps=1),\n triton.Config({}, num_warps=2),\n triton.Config({}, num_warps=4),\n triton.Config({}, num_warps=8),\n triton.Config({}, num_warps=16),\n triton.Config({}, num_warps=32),\n ],\n key=[\"BT\", \"BK\", \"BV\"],\n)\n@triton.jit\ndef fwd_prepare_wy_repr_kernel(\n k, v, beta, w, u, A,\n s_qk_h, s_qk_t, s_qk_d,\n s_vo_h, s_vo_t, s_vo_d,\n T, K, V, BT: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr\n):\n # Kernel computation logic here\n\n@triton.autotune(\n configs=[\n triton.Config({}, num_warps=1),\n triton.Config({}, num_warps=2),\n triton.Config({}, num_warps=4),\n triton.Config({}, num_warps=8),\n triton.Config({}, num_warps=16),\n triton.Config({}, num_warps=32),\n ],\n key=[\"BT\", \"BK\", \"BV\"],\n)\n@triton.jit\ndef fwd_recompute_w_u_kernel(\n k, v, beta, w, u, A,\n s_qk_h, s_qk_t, s_qk_d,\n s_vo_h, s_vo_t, s_vo_d,\n T, K, V, BT: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr\n):\n # Kernel computation logic here\n\n@triton.autotune(\n configs=[\n triton.Config({}, num_warps=1),\n triton.Config({}, num_warps=2),\n triton.Config({}, num_warps=4),\n triton.Config({}, num_warps=8),\n triton.Config({}, num_warps=16),\n triton.Config({}, num_warps=32),\n ],\n key=[\"BT\", \"BK\", \"BV\"],\n)\n@triton.jit\ndef bwd_prepare_wy_repr_kernel(\n k, v, beta, A, dw, du, dk, dv, dbeta,\n s_qk_h, s_qk_t, s_qk_d,\n s_vo_h, s_vo_t, s_vo_d,\n T, K, V, BT: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr\n):\n # Kernel computation logic here\n\ndef fwd_prepare_wy_repr(k, v, beta, BT):\n B, H, T, K, V = *k.shape, v.shape[-1]\n u = torch.empty_like(v)\n w = torch.empty_like(k)\n NT = triton.cdiv(T, BT)\n BK = min(triton.next_power_of_2(K), 64)\n BV = min(triton.next_power_of_2(V), 64)\n A = torch.empty(B, H, T, BT, device=k.device, dtype=k.dtype)\n fwd_prepare_wy_repr_kernel[(NT, B*H)](\n k, v, beta, w, u, A,\n k.stride(1), k.stride(2), k.stride(3),\n v.stride(1), v.stride(2), v.stride(3),\n T, K, V, BT, BK, BV\n )\n return w, u, A\n\ndef fwd_recompute_w_u(k, v, beta, A, BT):\n B, H, T, K, V = *k.shape, v.shape[-1]\n u = torch.empty_like(v)\n w = torch.empty_like(k)\n NT = triton.cdiv(T, BT)\n BK = min(triton.next_power_of_2(K), 64)\n BV = min(triton.next_power_of_2(V), 64)\n fwd_recompute_w_u_kernel[(NT, B*H)](\n k, v, beta, w, u, A,\n k.stride(1), k.stride(2), k.stride(3),\n v.stride(1), v.stride(2), v.stride(3),\n T, K, V, BT, BK, BV\n )\n return w, u\n\ndef bwd_prepare_wy_repr(k, v, beta, A, dw, du, BT):\n B, H, T, K, V = *k.shape, v.shape[-1]\n\n NT = triton.cdiv(T, BT)\n BK = min(triton.next_power_of_2(K), 64)\n BV = min(triton.next_power_of_2(V), 64)\n NT = triton.cdiv(T, BT)\n dk = torch.empty_like(k)\n dv = torch.empty_like(v).contiguous()\n dbeta = torch.zeros_like(beta)\n\n bwd_prepare_wy_repr_kernel[(NT, B*H)](\n k, v, beta, A,\n dw, du,\n dk, dv, dbeta,\n k.stride(1), k.stride(2), k.stride(3),\n v.stride(1), v.stride(2), v.stride(3),\n T, K, V, BT, BK, BV\n )\n return dk, dv, dbeta\n", - "description_1": "Use triton language to implement kernels for WY representation preparation, forward and backward propagation in neural networks. These kernels utilize parameters such as k, v, beta, w, u, and matrix A, along with strides and dimensions for T, K, V, and specific block sizes BT, BK, BV. Functions: fwd_prepare_wy_repr_kernel, fwd_recompute_w_u_kernel, bwd_prepare_wy_repr_kernel.", - "description_2": "Use triton language to compute the forward preparation of WY representation and recompute kernels, as well as the backward propagation for deep learning tasks, processing matrices with specific dimensions and strides.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.autotune(\n configs=[\n triton.Config({'BS': 16}, num_warps=2),\n triton.Config({'BS': 16}, num_warps=4),\n triton.Config({'BS': 16}, num_warps=8),\n triton.Config({'BS': 32}, num_warps=2),\n triton.Config({'BS': 32}, num_warps=4),\n triton.Config({'BS': 32}, num_warps=8),\n triton.Config({'BS': 64}, num_warps=2),\n triton.Config({'BS': 64}, num_warps=4),\n triton.Config({'BS': 64}, num_warps=8),\n ],\n key=['S']\n)\n@triton.jit\ndef chunk_gla_fwd_kernel_cum(\n s, o, s_s_h, s_s_t, s_s_d,\n T: tl.constexpr, S: tl.constexpr, BT: tl.constexpr, BS: tl.constexpr\n):\n i_s, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n o_i = tl.arange(0, BT)\n m_s = tl.where(o_i[:, None] >= o_i[None, :], 1., 0.)\n\n p_s = tl.make_block_ptr(s + i_bh * s_s_h, (T, S), (s_s_t, s_s_d), (i_t * BT, i_s * BS), (BT, BS), (1, 0))\n p_o = tl.make_block_ptr(o + i_bh * s_s_h, (T, S), (s_s_t, s_s_d), (i_t * BT, i_s * BS), (BT, BS), (1, 0))\n b_s = tl.load(p_s, boundary_check=(0, 1)).to(tl.float32)\n b_o = tl.dot(m_s, b_s, allow_tf32=False)\n tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))\n\n@triton.jit\ndef chunk_gla_fwd_kernel_h(\n k, v, g, h, h0, ht,\n s_k_h, s_k_t, s_k_d,\n s_v_h, s_v_t, s_v_d,\n s_h_h, s_h_t, s_h_d,\n T: tl.constexpr, K: tl.constexpr, V: tl.constexpr,\n BT: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr,\n NT: tl.constexpr,\n USE_INITIAL_STATE: tl.constexpr, STORE_FINAL_STATE: tl.constexpr\n):\n i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n b_h = tl.zeros([BK, BV], dtype=tl.float32)\n if USE_INITIAL_STATE:\n p_h = tl.make_block_ptr(h0 + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))\n b_h += tl.load(p_h, boundary_check=(0, 1)).to(tl.float32)\n for i_t in range(NT):\n p_k = tl.make_block_ptr(k + i_bh * s_k_h, (K, T), (s_k_d, s_k_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1))\n p_v = tl.make_block_ptr(v + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))\n p_h = tl.make_block_ptr(h + i_bh * s_h_h + i_t * K * V, (K, V), (s_h_t, s_h_d), (i_k * BK, i_v * BV), (BK, BV), (1, 0))\n p_g = tl.make_block_ptr(g + i_bh * s_k_h, (K, T), (s_k_d, s_k_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1))\n p_gn = tl.make_block_ptr(g + i_bh * s_k_h, (T * K,), (s_k_d,), ((i_t * BT + BT - 1) * K + i_k * BK,), (BK,), (0,))\n\n tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1))\n b_k = tl.load(p_k, boundary_check=(0, 1))\n b_v = tl.load(p_v, boundary_check=(0, 1))\n b_g = tl.load(p_g, boundary_check=(0, 1))\n if i_t < NT - 1:\n b_gn = tl.load(p_gn, boundary_check=(0,))\n else:\n b_gn = tl.min(b_g, axis=1)\n b_h *= tl.exp(b_gn)[:, None]\n b_k = (b_k * tl.exp(b_gn[:, None] - b_g)).to(b_k.dtype)\n b_h += tl.dot(b_k, b_v, allow_tf32=False)\n\n if STORE_FINAL_STATE:\n p_h = tl.make_block_ptr(ht + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))\n tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1))\n\n@triton.jit\ndef chunk_gla_fwd_kernel_intra(\n q, k, g, A,\n s_k_h, s_k_t, s_k_d,\n scale,\n T: tl.constexpr, K: tl.constexpr, BT: tl.constexpr, BC: tl.constexpr,\n BK: tl.constexpr, NC: tl.constexpr\n):\n i_k, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n i_t, i_i, i_j = i_c // (NC * NC), (i_c % (NC * NC)) // NC, (i_c % (NC * NC)) % NC\n n_bh = tl.num_programs(2)\n\n if i_i > i_j:\n p_q = tl.make_block_ptr(q + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))\n p_g = tl.make_block_ptr(g + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))\n p_k = tl.make_block_ptr(k + i_bh * s_k_h, (K, T), (s_k_d, s_k_t), (i_k * BK, i_t * BT + i_j * BC), (BK, BC), (0, 1))\n p_gk = tl.make_block_ptr(g + i_bh * s_k_h, (K, T), (s_k_d, s_k_t), (i_k * BK, i_t * BT + i_j * BC), (BK, BC), (0, 1))\n p_gn = tl.make_block_ptr(g + i_bh * s_k_h, (T * K,), (s_k_d,), ((i_t * BT + i_i * BC) * K + i_k * BK,), (BK,), (0,))\n p_A = tl.make_block_ptr(A + (i_k*n_bh+i_bh)*T*BT, (T, BT), (BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0))\n b_gn = tl.load(p_gn, boundary_check=(0,))\n b_q = tl.load(p_q, boundary_check=(0, 1))\n b_g = tl.load(p_g, boundary_check=(0, 1))\n b_qg = (b_q * tl.exp(b_g - b_gn[None, :]) * scale).to(b_q.dtype)\n b_k = tl.load(p_k, boundary_check=(0, 1))\n b_gk = tl.load(p_gk, boundary_check=(0, 1))\n b_kg = (b_k * tl.exp(b_gn[:, None] - b_gk)).to(b_k.dtype)\n b_A = tl.dot(b_qg, b_kg, allow_tf32=False)\n tl.store(p_A, b_A.to(A.dtype.element_ty), boundary_check=(0, 1))\n elif i_i == i_j:\n p_q = tl.make_block_ptr(q + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))\n p_g = tl.make_block_ptr(g + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))\n p_k = tl.make_block_ptr(k + i_bh * s_k_h, (T * K,), (s_k_d,), ((i_t * BT + i_j * BC) * K + i_k * BK,), (BK,), (0,))\n p_gk = tl.make_block_ptr(g + i_bh * s_k_h, (T * K,), (s_k_d,), ((i_t * BT + i_j * BC) * K + i_k * BK,), (BK,), (0,))\n b_q = tl.load(p_q, boundary_check=(0, 1))\n b_g = tl.load(p_g, boundary_check=(0, 1))\n\n o_i = tl.arange(0, BC)\n o_A = (i_bh + i_k * n_bh) * T * BT + (i_t * BT + i_i * BC + tl.arange(0, BC)) * BT + i_j * BC\n m_A = (i_t * BT + i_i * BC + tl.arange(0, BC)) < T\n for j in range(0, BC):\n b_k = tl.load(p_k, boundary_check=(0,)).to(tl.float32)\n b_gk = tl.load(p_gk, boundary_check=(0,)).to(tl.float32)\n b_A = tl.sum(b_q * b_k[None, :] * tl.exp(b_g - b_gk[None, :]) * scale, 1)\n b_A = tl.where(o_i >= j, b_A, 0.)\n tl.store(A + o_A + j, b_A.to(b_q.dtype), mask=m_A)\n\n p_k = tl.advance(p_k, (K,))\n p_gk = tl.advance(p_gk, (K,))\n\n@triton.jit\ndef chunk_gla_fwd_kernel_inter(\n q, v, g, h, o, A,\n s_k_h, s_k_t, s_k_d,\n s_v_h, s_v_t, s_v_d,\n s_h_h, s_h_t, s_h_d,\n scale,\n T: tl.constexpr, K: tl.constexpr, V: tl.constexpr,\n BT: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr\n):\n i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n\n b_o = tl.zeros([BT, BV], dtype=tl.float32)\n for i_k in range(tl.cdiv(K, BK)):\n p_q = tl.make_block_ptr(q + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))\n p_g = tl.make_block_ptr(g + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))\n p_h = tl.make_block_ptr(h + i_bh * s_h_h + i_t * K * V, (K, V), (s_h_t, s_h_d), (i_k * BK, i_v * BV), (BK, BV), (1, 0))\n\n b_q = tl.load(p_q, boundary_check=(0, 1))\n b_q = (b_q * scale).to(b_q.dtype)\n b_g = tl.load(p_g, boundary_check=(0, 1))\n b_qg = (b_q * tl.exp(b_g)).to(b_q.dtype)\n b_h = tl.load(p_h, boundary_check=(0, 1))\n if i_k >= 0:\n b_o += tl.dot(b_qg, b_h, allow_tf32=False)\n p_v = tl.make_block_ptr(v + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))\n p_o = tl.make_block_ptr(o + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))\n p_A = tl.make_block_ptr(A + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))\n b_v = tl.load(p_v, boundary_check=(0, 1))\n b_A = tl.load(p_A, boundary_check=(0, 1))\n b_o += tl.dot(b_A, b_v, allow_tf32=False)\n tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))\n\nclass ChunkGLAFunction(torch.autograd.Function):\n\n @staticmethod\n def forward(ctx, q, k, v, g, scale, initial_state, output_final_state, checkpoint_level):\n B, H, T, K, V = *q.shape, v.shape[-1]\n BT, BC = 64, 16\n BK = min(64, triton.next_power_of_2(K))\n BV = min(64, triton.next_power_of_2(V))\n NT, NC = triton.cdiv(T, BT), triton.cdiv(BT, BC)\n NK = triton.cdiv(K, BK)\n NV = triton.cdiv(V, BV)\n num_warps = 4 if BK == 64 else 2\n num_stages = 1\n\n def fwd_inner(q, k, v, g, B, H, T, K, V, BT, BK, BV, NT, h0=None, ht=None):\n NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)\n h = q.new_empty(B, H, NT * K, V)\n grid = (NV, NK, B * H)\n chunk_gla_fwd_kernel_h[grid](\n k, v, g, h, h0, ht,\n k.stride(1), k.stride(2), k.stride(3),\n v.stride(1), v.stride(2), v.stride(3),\n h.stride(1), h.stride(2), h.stride(3),\n T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT,\n USE_INITIAL_STATE=h0 is not None,\n STORE_FINAL_STATE=ht is not None,\n num_warps=num_warps,\n num_stages=num_stages\n )\n return h\n\n final_state = None\n if output_final_state:\n final_state = q.new_empty(B, H, K, V, dtype=torch.float)\n\n g_org, g = g, torch.empty_like(g, dtype=torch.float)\n def grid(meta): return ((triton.cdiv(meta['S'], meta['BS']), NT, B * H))\n chunk_gla_fwd_kernel_cum[grid](\n g_org, g,\n g.stride(1), g.stride(2), g.stride(3),\n T=T, S=K, BT=BT\n )\n h = fwd_inner(\n q=q, k=k, v=v, g=g,\n B=B, H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT,\n h0=initial_state if initial_state is not None else None,\n ht=final_state if final_state is not None else None\n )\n A = q.new_zeros(NK, B, H, T, BT)\n grid = (NK, NT * NC * NC, B * H)\n chunk_gla_fwd_kernel_intra[grid](\n q, k, g, A,\n k.stride(1), k.stride(2), k.stride(3),\n scale,\n T=T, K=K, BT=BT, BC=BC, BK=BK, NC=NC,\n num_warps=num_warps,\n num_stages=num_stages\n )\n A = A.sum(0, dtype=A.dtype)\n o = torch.empty_like(v)\n grid = (NV, NT, B * H)\n chunk_gla_fwd_kernel_inter[grid](\n q, v, g, h, o, A,\n k.stride(1), k.stride(2), k.stride(3),\n v.stride(1), v.stride(2), v.stride(3),\n h.stride(1), h.stride(2), h.stride(3),\n scale,\n T=T, K=K, V=V, BT=BT, BK=BK, BV=BV,\n num_warps=num_warps,\n num_stages=num_stages\n )\n if checkpoint_level >= 1:\n del g\n g = g_org\n if checkpoint_level > 1:\n del h\n h, initial_state = None, None\n\n ctx.save_for_backward(q, k, v, g, h, initial_state, A)\n ctx.BT = BT\n ctx.scale = scale\n ctx.checkpoint_level = checkpoint_level\n return o, final_state\n\ndef chunk_gla(\n q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, g: torch.Tensor,\n scale: Optional[int] = None, initial_state: torch.Tensor = None,\n output_final_state: bool = False, checkpoint_level: Optional[int] = 2\n) -> Tuple[torch.Tensor, torch.Tensor]:\n assert checkpoint_level in [0, 1, 2]\n if scale is None:\n scale = q.shape[-1] ** -0.5\n if initial_state is not None:\n initial_state = initial_state.detach()\n o, final_state = ChunkGLAFunction.apply(q, k, v, g, scale, initial_state, output_final_state, checkpoint_level)\n return o, final_state\n", - "description_1": "Use triton language to implement the chunk_gla function with kernels: chunk_gla_fwd_kernel_cum, chunk_gla_fwd_kernel_h, chunk_gla_fwd_kernel_intra, chunk_gla_fwd_kernel_inter. These kernels handle forward pass operations over the input tensors q, k, v, and g with dimensions for different blocks, computing intermediate results and storing them in tensor o. The kernels take into account parameters such as strides, scales, dimensions (T, S, BT, etc.), and configuration settings (e.g., num_warps). These parameters are crucial to execute the correct operations in parallel across the input tensors.", - "description_2": "Use triton language to define kernels that perform block-wise tensor operations for attention mechanisms, optimizing the forward pass by executing cumulative sums, intra, and inter-block computations.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef fused_chunk_gla_fwd_kernel(\n q, k, v, g, o, initial_state, final_state,\n s_qk_h, s_qk_t, s_qk_d, s_vo_h, s_vo_t, s_vo_d, B, H, T, scale,\n BT: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr,\n DK: tl.constexpr, DV: tl.constexpr,\n USE_INITIAL_STATE: tl.constexpr, STORE_FINAL_STATE: tl.constexpr, CHECK: tl.constexpr\n):\n # Triton kernel for forward pass\n\n@triton.jit\ndef fused_chunk_gla_bwd_kernel(\n q, k, v, g, do, dq, dk, dv, initial_state,\n s_qk_h, s_qk_t, s_qk_d, s_vo_h, s_vo_t, s_vo_d, B, H, T, scale,\n BT: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr,\n DK: tl.constexpr, DV: tl.constexpr, USE_INITIAL_STATE: tl.constexpr, CHECK: tl.constexpr\n):\n # Triton kernel for backward pass\n\nclass FusedChunkGLAFunction(torch.autograd.Function):\n @staticmethod\n def forward(ctx, q, k, v, g, scale, initial_state, output_final_state):\n # Forward function for FusedChunkGLA\n ctx.g_dtype = g.dtype\n g_original = g\n g = torch.empty_like(g, dtype=torch.float32)\n batch_size, n_heads, seq_len, d_head_qk = q.shape\n d_head_v = v.shape[-1]\n ctx.scale = scale\n BT = 16\n BK, BV = min(d_head_qk, 64), min(d_head_v, 64)\n num_stages = 1\n num_warps = 2\n NK, NV = triton.cdiv(d_head_qk, BK), triton.cdiv(d_head_v, BV)\n o = q.new_empty(NK, batch_size, n_heads, seq_len, d_head_v)\n q_g = torch.empty_like(q)\n k_g = torch.empty_like(k)\n grid = (NV, NK, batch_size * n_heads)\n fused_chunk_gla_fwd_kernel[grid](\n q_g, k_g, v, g, o, initial_state, final_state,\n q.stride(1), q.stride(2), q.stride(3),\n v.stride(1), v.stride(2), v.stride(3),\n batch_size, n_heads, seq_len, scale,\n BT=BT, DK=d_head_qk, DV=d_head_v, BK=BK, BV=BV,\n USE_INITIAL_STATE=initial_state is not None,\n STORE_FINAL_STATE=output_final_state,\n CHECK=True,\n num_warps=num_warps,\n num_stages=num_stages\n )\n o = o.sum(0)\n ctx.save_for_backward(q, k, v, g_original, o, initial_state)\n return o.to(v), final_state\n\n @staticmethod\n def backward(ctx, do, d_final_state=None):\n q, k, v, g_origin, o, initial_state = ctx.saved_tensors\n batch_size, n_heads, seq_len, d_head_qk = q.shape\n d_head_v = v.shape[-1]\n scale = ctx.scale\n BT = 16\n g = torch.empty_like(g_origin, dtype=torch.float32)\n BK, BV = min(d_head_qk, 64), min(d_head_v, 64)\n NK, NV = triton.cdiv(d_head_qk, BK), triton.cdiv(d_head_v, BV)\n q_g = torch.empty_like(q)\n k_g = torch.empty_like(k)\n grid = (NV, NK, batch_size * n_heads)\n dq = q.new_empty(NV, batch_size, n_heads, seq_len, d_head_qk)\n dk = q.new_empty(NV, batch_size, n_heads, seq_len, d_head_qk)\n dv = q.new_empty(NK, batch_size, n_heads, seq_len, d_head_v)\n fused_chunk_gla_bwd_kernel[grid](\n q_g, k_g, v, g, do, dq, dk, dv, initial_state,\n q.stride(1), q.stride(2), q.stride(3),\n v.stride(1), v.stride(2), v.stride(3),\n batch_size, n_heads, seq_len, scale,\n BT=BT, DK=d_head_qk, DV=d_head_v, BK=BK, BV=BV,\n USE_INITIAL_STATE=initial_state is not None,\n CHECK=True,\n num_warps=num_warps,\n num_stages=num_stages,\n )\n dq = dq.sum(0)\n dk = dk.sum(0)\n dv = dv.sum(0)\n return dq.to(q), dk.to(k), dv.to(v), None, None, None, None\n\ndef fused_chunk_gla(\n q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, g: torch.Tensor,\n scale: int = -1, initial_state: torch.Tensor = None, output_final_state: bool = False\n) -> Tuple[torch.Tensor, torch.Tensor]:\n if scale == -1:\n scale = q.shape[-1] ** -0.5\n if initial_state is not None:\n initial_state = initial_state.detach()\n seq_len = q.shape[-2]\n q, k, v, g = map(lambda x: pad(x), [q, k, v, g])\n o, final_state = FusedChunkGLAFunction.apply(\n q, k, v, g, scale, initial_state, output_final_state)\n o = o[..., :seq_len, :]\n return o, final_state\n", - "description_1": "Use triton language to define kernels for forward and backward passes of a fused chunk GLA function with specific parameters. Implement the forward and backward methods of a torch.autograd.Function class utilizing these kernels, and create a high-level fused_chunk_gla function to execute the operations.", - "description_2": "Use triton language to define forward and backward kernels for chunk-based attention and implement them in a custom PyTorch autograd function.", - "difficulty": 4 - }, - { - "code": "import triton\nimport triton.language as tl\n\ninv_ln2 = 1.44269504\n\n# Forward decay cumulative sum kernel\n@triton.jit\ndef fwd_decay_cumsum(\n g, g_o, s_qk_h, s_qk_t, s_qk_d, B, H, T, scale,\n BT: tl.constexpr, BK: tl.constexpr, DK: tl.constexpr\n):\n i_k, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n p_g = g + i_bh * s_qk_h + i_c * BT * DK + i_k * BK + tl.arange(0, BK)\n p_go = g_o + i_bh * s_qk_h + i_c * BT * DK + i_k * BK + tl.arange(0, BK)\n cum_decay = tl.zeros([BK], dtype=tl.float32)\n mask = (i_k * BK + tl.arange(0, BK)) < DK\n\n for i in range(BT):\n _g = tl.load(p_g, mask=mask, other=0).to(tl.float32)\n cum_decay += _g * inv_ln2\n tl.store(p_go, cum_decay.to(p_go.dtype.element_ty), mask=mask)\n p_g += DK\n p_go += DK\n\n# Prepare QG KG kernel\n@triton.jit\ndef prepare_qg_kg(\n q, k, g, qg, kg, s_qk_h, s_qk_t, s_qk_d, B, H, T, scale,\n BT: tl.constexpr, BK: tl.constexpr, DK: tl.constexpr\n):\n i_k, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n p_q = q + i_bh * s_qk_h + i_c * BT * DK + i_k * BK + tl.arange(0, BK)\n p_g = g + i_bh * s_qk_h + i_c * BT * DK + i_k * BK + tl.arange(0, BK)\n p_k = k + i_bh * s_qk_h + i_c * BT * DK + i_k * BK + tl.arange(0, BK)\n p_qg = qg + i_bh * s_qk_h + i_c * BT * DK + i_k * BK + tl.arange(0, BK)\n p_kg = kg + i_bh * s_qk_h + i_c * BT * DK + i_k * BK + tl.arange(0, BK)\n\n mask = (i_k * BK + tl.arange(0, BK)) < DK\n\n last_decay = tl.load(g + i_bh * s_qk_h + (i_c * BT + BT - 1) * DK + i_k * BK + tl.arange(0, BK))\n\n for i in range(BT):\n _q = tl.load(p_q, mask=mask, other=0)\n _k = tl.load(p_k, mask=mask, other=0)\n _g = tl.load(p_g, mask=mask, other=0).to(tl.float32)\n _q *= tl.math.exp2(_g) * scale\n _k *= tl.math.exp2(last_decay - _g)\n tl.store(p_kg, _k.to(p_kg.dtype.element_ty), mask=mask)\n tl.store(p_qg, _q.to(p_qg.dtype.element_ty), mask=mask)\n p_q += DK\n p_g += DK\n p_k += DK\n p_kg += DK\n p_qg += DK\n\n# Backward decay global cumulative sum kernel\n@triton.jit\ndef bwd_decay_global_cumsum(\n dq_inner, dq_inter, dk_inner, dk_inter, q, k, g, dg,\n s_qk_h, s_qk_t, s_qk_d, B, H, T, scale,\n BT: tl.constexpr, BK: tl.constexpr, DK: tl.constexpr\n):\n i_k, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n p_q = q + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (i_c * BT + BT - 1) * DK\n p_k = k + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (i_c * BT + BT - 1) * DK\n p_g = g + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (i_c * BT + BT - 1) * DK\n p_dg = dg + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (i_c * BT + BT - 1) * DK\n p_dq_inner = dq_inner + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (i_c * BT + BT - 1) * DK\n p_dk_inner = dk_inner + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (i_c * BT + BT - 1) * DK\n p_dq_inter = dq_inter + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (i_c * BT + BT - 1) * DK\n p_dk_inter = dk_inter + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (i_c * BT + BT - 1) * DK\n cum_grad_dg = tl.zeros([BK], dtype=tl.float32)\n mask = (i_k * BK + tl.arange(0, BK)) < DK\n last_g = tl.zeros([BK], dtype=tl.float32)\n for j in range(BT-1, -1, -1):\n _g = tl.load(p_g, mask=mask, other=0).to(tl.float32)\n if j == (BT-1):\n last_g = _g\n _dq1 = tl.load(p_dq_inner, mask=mask, other=0)\n _dq2 = tl.load(p_dq_inter, mask=mask, other=0)\n _dq2 *= tl.math.exp2(_g)\n _dq = _dq1 + _dq2\n tl.store(p_dq_inter, _dq, mask=mask)\n _dk1 = tl.load(p_dk_inner, mask=mask, other=0)\n _dk2 = tl.load(p_dk_inter, mask=mask, other=0)\n _dk2 *= tl.math.exp2(last_g - _g)\n _dk = _dk1 + _dk2\n tl.store(p_dk_inter, _dk, mask=mask)\n _q = tl.load(p_q, mask=mask, other=0)\n _k = tl.load(p_k, mask=mask, other=0)\n _dg = _dq * _q - _dk * _k\n cum_grad_dg += _dg\n tl.store(p_dg, cum_grad_dg.to(p_dg.dtype.element_ty), mask=mask)\n p_g -= DK\n p_k -= DK\n p_q -= DK\n p_dq_inner -= DK\n p_dk_inner -= DK\n p_dq_inter -= DK\n p_dk_inter -= DK\n p_dg -= DK\n", - "description_1": "Use triton language to define three kernels. The first kernel, 'fwd_decay_cumsum', performs a forward decay cumulative sum on the input tensor 'g' with shape parameters 's_qk_h', 's_qk_t', 's_qk_d', and stores the result in 'g_o'. The second kernel, 'prepare_qg_kg', modifies input tensors 'q' and 'k' based on another tensor 'g' and stores the results in 'qg' and 'kg', respectively. The third kernel, 'bwd_decay_global_cumsum', computes the backward cumulative sum using input gradient tensors and modifies the gradients of input tensors 'q', 'k', and 'g'. These kernels utilize block sizes 'BT', 'BK', and 'DK', with loops iterating over 'BT'.", - "description_2": "Use triton language to create forward, transformation, and backward kernels for processing tensors in a block-wise manner, handling cumulative sums and element-wise operations with triton's API.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\nfrom torch.cuda.amp import custom_bwd, custom_fwd\n\n@triton.jit\ndef fused_recurrent_gla_fwd_kernel(\n q, k, v, gk, gv, o, h0, ht, s_qk_h, s_vo_h, scale,\n B: tl.constexpr, H: tl.constexpr, T: tl.constexpr, K: tl.constexpr, V: tl.constexpr,\n BK: tl.constexpr, BV: tl.constexpr, USE_INITIAL_STATE: tl.constexpr, STORE_FINAL_STATE: tl.constexpr,\n REVERSE: tl.constexpr, USE_GK: tl.constexpr, USE_GV: tl.constexpr\n):\n i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n p_q = q + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + ((T-1) * K if REVERSE else 0)\n p_k = k + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + ((T-1) * K if REVERSE else 0)\n p_v = v + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV) + ((T-1) * V if REVERSE else 0)\n p_o = o + (i_bh + i_k * B * H) * s_vo_h + i_v * BV + tl.arange(0, BV) + ((T-1) * V if REVERSE else 0)\n\n if USE_GK:\n p_gk = gk + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + ((T-1) * K if REVERSE else 0)\n if USE_GV:\n p_gv = gv + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV) + ((T-1) * V if REVERSE else 0)\n\n mask_bk = (i_k * BK + tl.arange(0, BK)) < K\n mask_bv = (i_v * BV + tl.arange(0, BV)) < V\n\n h = tl.zeros([BV, BK], dtype=tl.float32)\n mask_kv = mask_bk[None, :] & mask_bv[:, None]\n\n if USE_INITIAL_STATE:\n p_h0 = h0 + i_bh * K * V + (i_k * BK + tl.arange(0, BK)[None, :]) * V + (i_v * BV + tl.arange(0, BV)[:, None])\n h += tl.load(p_h0, mask=mask_kv, other=0).to(tl.float32)\n\n for _ in range(0, T):\n b_k = tl.load(p_k, mask=mask_bk, other=0).to(tl.float32)\n b_v = tl.load(p_v, mask=mask_bv, other=0).to(tl.float32)\n b_q = tl.load(p_q, mask=mask_bk, other=0).to(tl.float32) * scale\n if USE_GK:\n b_gk = tl.load(p_gk, mask=mask_bk, other=0).to(tl.float32)\n h = h * tl.exp(b_gk[None, :])\n if USE_GV:\n b_gv = tl.load(p_gv, mask=mask_bv, other=0).to(tl.float32)\n h = h * tl.exp(b_gv[:, None])\n h += b_k[None, :] * b_v[:, None]\n _o = h * b_q[None, :]\n _o = tl.sum(_o, axis=1)\n tl.store(p_o, _o.to(p_o.dtype.element_ty), mask=mask_bv)\n p_q += -K if REVERSE else K\n p_k += -K if REVERSE else K\n p_o += -V if REVERSE else V\n p_v += -V if REVERSE else V\n if USE_GK:\n p_gk += -K if REVERSE else K\n if USE_GV:\n p_gv += -V if REVERSE else V\n\n if STORE_FINAL_STATE:\n p_ht = ht + i_bh * K * V + (i_k * BK + tl.arange(0, BK)[None, :]) * V + (i_v * BV + tl.arange(0, BV)[:, None])\n tl.store(p_ht, h.to(p_ht.dtype.element_ty), mask=mask_kv)\n\n@triton.jit\ndef fused_recurrent_gla_bwd_kernel(\n q, k, v, gk, gv, do, dq, dk, dv, h0, s_qk_h, s_vo_h, scale,\n B, H, T, K: tl.constexpr, V: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr,\n USE_INITIAL_STATE: tl.constexpr, REVERSE: tl.constexpr, USE_GK: tl.constexpr, USE_GV: tl.constexpr\n):\n i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n p_q = q + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + ((T-1) * K if REVERSE else 0)\n p_k = k + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + ((T-1) * K if REVERSE else 0)\n p_v = v + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV) + ((T-1) * V if REVERSE else 0)\n p_do = do + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV) + ((T-1) * V if REVERSE else 0)\n p_dq = dq + (i_bh + i_v * B * H) * s_qk_h + i_k * BK + tl.arange(0, BK) + ((T-1) * K if REVERSE else 0)\n if USE_GK:\n p_gk = gk + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + ((T-1) * K if REVERSE else 0)\n if USE_GV:\n p_gv = gv + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV) + ((T-1) * V if REVERSE else 0)\n mask_bk = i_k * BK + tl.arange(0, BK) < K\n mask_bv = i_v * BV + tl.arange(0, BV) < V\n mask_kv = mask_bk[:, None] & mask_bv[None, :]\n h = tl.zeros([BK, BV], dtype=tl.float32)\n\n if USE_INITIAL_STATE:\n p_h0 = h0 + i_bh * K * V + (i_k * BK + tl.arange(0, BK)[:, None]) * V + (i_v * BV + tl.arange(0, BV)[None, :])\n h += tl.load(p_h0, mask=mask_kv, other=0).to(tl.float32)\n\n for i in range(0, T):\n b_k = tl.load(p_k, mask=mask_bk, other=0).to(tl.float32)\n b_v = tl.load(p_v, mask=mask_bv, other=0).to(tl.float32)\n b_do = tl.load(p_do, mask=mask_bv, other=0).to(tl.float32)\n if USE_GK:\n b_gk = tl.load(p_gk, mask=mask_bk, other=0).to(tl.float32)\n h = h * tl.exp(b_gk[:, None])\n if USE_GV:\n b_gv = tl.load(p_gv, mask=mask_bv, other=0).to(tl.float32)\n h = h * tl.exp(b_gv[None, :])\n h += b_k[:, None] * b_v[None, :]\n b_dq = h * b_do[None, :]\n d_q = tl.sum(b_dq, axis=1) * scale\n tl.store(p_dq, d_q.to(p_dq.dtype.element_ty), mask=mask_bk)\n\n p_k += -K if REVERSE else K\n p_v += -V if REVERSE else V\n p_q += -K if REVERSE else K\n p_do += -V if REVERSE else V\n p_dq += -K if REVERSE else K\n if USE_GK:\n p_gk += -K if REVERSE else K\n if USE_GV:\n p_gv += -V if REVERSE else V\n\n tl.debug_barrier()\n\n p_q = q + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + ((T - 1) * K if not REVERSE else 0)\n p_k = k + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + ((T - 1) * K if not REVERSE else 0)\n p_v = v + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV) + ((T - 1) * V if not REVERSE else 0)\n p_do = do + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV) + ((T - 1) * V if not REVERSE else 0)\n p_dk = dk + (i_bh + i_v * B * H) * s_qk_h + i_k * BK + tl.arange(0, BK) + ((T - 1) * K if not REVERSE else 0)\n p_dv = dv + (i_bh + i_k * B * H) * s_vo_h + i_v * BV + tl.arange(0, BV) + ((T - 1) * V if not REVERSE else 0)\n if USE_GK:\n p_gk = gk + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + ((T - 1) * K if not REVERSE else 0)\n if USE_GV:\n p_gv = gv + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV) + ((T - 1) * V if not REVERSE else 0)\n\n d_h = tl.zeros([BK, BV], dtype=tl.float32)\n\n for _ in range(T):\n b_do = tl.load(p_do, mask=mask_bv, other=0).to(tl.float32)\n b_q = tl.load(p_q, mask=mask_bk, other=0).to(tl.float32) * scale\n b_k = tl.load(p_k, mask=mask_bk, other=0).to(tl.float32)\n b_v = tl.load(p_v, mask=mask_bv, other=0).to(tl.float32)\n d_h += b_q[:, None] * b_do[None, :]\n d_k = tl.sum(d_h * b_v[None, :], axis=1)\n d_v = tl.sum(d_h * b_k[:, None], axis=0)\n if USE_GK:\n b_gk = tl.load(p_gk, mask=mask_bk, other=0).to(tl.float32)\n d_h *= tl.exp(b_gk)[:, None]\n if USE_GV:\n b_gv = tl.load(p_gv, mask=mask_bv, other=0).to(tl.float32)\n d_h *= tl.exp(b_gv)[None, :]\n tl.store(p_dk, d_k.to(p_dk.dtype.element_ty), mask=mask_bk)\n tl.store(p_dv, d_v.to(p_dv.dtype.element_ty), mask=mask_bv)\n\n p_q += K if REVERSE else -K\n p_k += K if REVERSE else -K\n p_v += V if REVERSE else -V\n p_do += V if REVERSE else -V\n p_dk += K if REVERSE else -K\n p_dv += V if REVERSE else -V\n if USE_GK:\n p_gk += K if REVERSE else -K\n if USE_GV:\n p_gv += V if REVERSE else -V\n\nclass FusedRecurrentGLAFunction(torch.autograd.Function):\n\n @staticmethod\n @custom_fwd\n def forward(ctx, q, k, v, gk, gv, scale=None, initial_state=None, output_final_state=False, reverse=False):\n B, H, T, K, V = *q.shape, v.shape[-1]\n if scale is None:\n scale = K ** -0.5\n\n BK, BV = min(K, 64), min(V, 64)\n NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)\n num_stages = 1\n num_warps = 1\n\n o = q.new_empty(NK, B, H, T, V, dtype=torch.float32)\n\n if output_final_state:\n final_state = q.new_empty(B, H, K, V)\n else:\n final_state = None\n\n grid = (NV, NK, B * H)\n fused_recurrent_gla_fwd_kernel[grid](\n q, k, v, gk, gv, o, initial_state, final_state,\n q.stride(1), v.stride(1),\n scale,\n B=B, H=H, T=T, K=K, V=V,\n BK=BK, BV=BV,\n USE_INITIAL_STATE=initial_state is not None,\n STORE_FINAL_STATE=final_state is not None,\n USE_GK=gk is not None,\n USE_GV=gv is not None,\n REVERSE=reverse,\n num_warps=num_warps,\n num_stages=num_stages\n )\n\n o = o.sum(0)\n ctx.save_for_backward(q, k, v, gk, gv, initial_state, o)\n ctx.scale = scale\n ctx.reverse = reverse\n if final_state is not None:\n final_state = final_state.detach()\n return o.to(q.dtype), final_state\n\n @staticmethod\n @custom_bwd\n def backward(ctx, do, dht=None):\n q, k, v, gk, gv, initial_state, o = ctx.saved_tensors\n batch_size, n_heads, seq_len, K = q.shape\n V = v.shape[-1]\n scale = ctx.scale\n\n BK, BV = min(K, 64), min(V, 64)\n NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)\n num_stages = 1\n num_warps = 1\n\n dq = q.new_empty(NV, batch_size, n_heads, seq_len, K, dtype=torch.float32)\n dk = q.new_empty(NV, batch_size, n_heads, seq_len, K, dtype=torch.float32)\n dv = q.new_empty(NK, batch_size, n_heads, seq_len, V, dtype=torch.float32)\n grid = (NV, NK, batch_size * n_heads)\n\n fused_recurrent_gla_bwd_kernel[grid](\n q, k, v, gk, gv, do, dq, dk, dv, initial_state,\n q.stride(1),\n v.stride(1), scale,\n B=batch_size, H=n_heads, T=seq_len, K=K, V=V, BK=BK, BV=BV,\n num_warps=num_warps,\n num_stages=num_stages,\n USE_INITIAL_STATE=initial_state is not None,\n REVERSE=ctx.reverse,\n USE_GK=gk is not None,\n USE_GV=gv is not None\n )\n dq = dq.sum(0)\n dk = dk.sum(0)\n dv = dv.sum(0)\n if gk is not None:\n _dgk = dq * q.float() - dk * k.float()\n if ctx.reverse:\n dgk = _dgk.cumsum(-2)\n else:\n _dgk_cumsum = _dgk.cumsum(-2)\n dgk = _dgk + _dgk_cumsum[:, :, -1, None] - _dgk_cumsum\n else:\n dgk = None\n\n if gv is not None:\n _dgv = do.float() * o.float() - dv * v.float()\n if ctx.reverse:\n dgv = _dgv.cumsum(-2)\n else:\n _dgv_cumsum = _dgv.cumsum(-2)\n dgv = _dgv + _dgv_cumsum[:, :, -1, None] - _dgv_cumsum\n else:\n dgv = None\n\n return dq.to(q.dtype), dk.to(k.dtype), dv.to(v.dtype), dgk, dgv, None, None, None, None\n\ndef fused_recurrent_gla(\n q: torch.Tensor,\n k: torch.Tensor,\n v: torch.Tensor,\n gk: torch.Tensor = None,\n gv: torch.Tensor = None,\n scale: int = -1,\n initial_state: torch.Tensor = None,\n output_final_state: bool = False,\n causal: bool = True\n) -> Tuple[torch.Tensor, torch.Tensor]:\n if scale == -1:\n scale = q.shape[-1] ** -0.5\n if initial_state is not None:\n initial_state = initial_state.detach()\n if causal:\n o, final_state = FusedRecurrentGLAFunction.apply(q, k, v, gk, gv, scale, initial_state, output_final_state)\n return o, final_state\n else:\n assert initial_state is None\n assert output_final_state is False\n o, final_state = FusedRecurrentGLAFunction.apply(q, k, v, gk, gv, scale, initial_state, output_final_state, False)\n o_reversed, final_state = FusedRecurrentGLAFunction.apply(\n q, k, v, gk, gv, scale, initial_state, output_final_state, True)\n return o, o_reversed\n", - "description_1": "Use triton language to implement a fused recurrent gated linear attention (GLA) forward and backward kernel. The forward kernel takes 20 parameters: q, k, v, gk, gv, o, h0, ht, s_qk_h, s_vo_h, scale, B, H, T, K, V, BK, BV, USE_INITIAL_STATE, STORE_FINAL_STATE, REVERSE, USE_GK, USE_GV. The backward kernel takes 21 parameters: q, k, v, gk, gv, do, dq, dk, dv, h0, s_qk_h, s_vo_h, scale, B, H, T, K, V, BK, BV, USE_INITIAL_STATE, REVERSE, USE_GK, USE_GV. The kernels are used in a custom autograd function to compute the forward and backward passes of the GLA operation.", - "description_2": "Use triton language to create a fused recurrent GLA operation with forward and backward passes, utilizing custom autograd functions in PyTorch.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.autotune(\n configs=[\n triton.Config({'BD': 32}, num_warps=1),\n triton.Config({'BD': 32}, num_warps=2),\n triton.Config({'BD': 32}, num_warps=4),\n triton.Config({'BD': 32}, num_warps=8),\n triton.Config({'BD': 64}, num_warps=1),\n triton.Config({'BD': 64}, num_warps=2),\n triton.Config({'BD': 64}, num_warps=4),\n triton.Config({'BD': 64}, num_warps=8),\n triton.Config({'BD': 128}, num_warps=1),\n triton.Config({'BD': 128}, num_warps=2),\n triton.Config({'BD': 128}, num_warps=4),\n triton.Config({'BD': 128}, num_warps=8),\n ],\n key=['D']\n)\n@triton.jit\ndef chunk_hgrn_fwd_kernel_h(\n x,\n g,\n gc,\n o,\n h0,\n T: tl.constexpr,\n D: tl.constexpr,\n BT: tl.constexpr,\n BD: tl.constexpr,\n USE_INITIAL_STATE: tl.constexpr\n):\n i_d, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n o_d = i_d * BD + tl.arange(0, BD)\n mask = o_d < D\n\n p_x = x + i_bh * T * D + i_t * BT * D + o_d\n p_g = g + i_bh * T * D + i_t * BT * D + o_d\n p_gc = gc + i_bh * T * D + i_t * BT * D + o_d\n p_o = o + i_bh * T * D + i_t * BT * D + o_d\n\n b_h = tl.zeros([BD], dtype=tl.float32)\n b_gc = tl.zeros([BD], dtype=tl.float32)\n if USE_INITIAL_STATE:\n if i_t == 0:\n b_h += tl.load(h0 + i_bh * D + o_d, mask=mask, other=0).to(tl.float32)\n for i in range(0, BT):\n mask_t = mask & ((i_t * BT + i) < T)\n b_x = tl.load(p_x, mask=mask_t, other=0).to(tl.float32)\n b_g = tl.load(p_g, mask=mask_t, other=0).to(tl.float32)\n b_h = tl.exp(b_g) * b_h + b_x\n b_gc = b_gc + b_g\n tl.store(p_gc, b_gc.to(p_o.dtype.element_ty), mask=mask_t)\n tl.store(p_o, b_h.to(p_o.dtype.element_ty), mask=mask_t)\n\n p_x += D\n p_g += D\n p_gc += D\n p_o += D\n\n\n@triton.jit\ndef chunk_hgrn_fwd_kernel_o(\n gc,\n o,\n s_h,\n s_t,\n s_d,\n T: tl.constexpr,\n D: tl.constexpr,\n BT: tl.constexpr,\n BD: tl.constexpr\n):\n i_d, i_bh = tl.program_id(0), tl.program_id(1)\n o_d = i_d * BD + tl.arange(0, BD)\n mask = o_d < D\n\n for i_t in range(1, tl.cdiv(T, BT)):\n p_gc = tl.make_block_ptr(gc + i_bh * s_h, (T, D), (s_t, s_d), (i_t * BT, i_d * BD), (BT, BD), (1, 0))\n p_o = tl.make_block_ptr(o + i_bh * s_h, (T, D), (s_t, s_d), (i_t * BT, i_d * BD), (BT, BD), (1, 0))\n\n b_h0 = tl.load(o + i_bh * T * D + i_t * BT * D - D + o_d, mask=mask, other=0).to(tl.float32)\n b_gc = tl.load(p_gc, boundary_check=(0, 1)).to(tl.float32)\n b_o = tl.load(p_o, boundary_check=(0, 1)).to(tl.float32)\n b_o = b_o + tl.exp(b_gc) * b_h0[None, :]\n tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))\n\n\n@triton.autotune(\n configs=[\n triton.Config({'BD': 32}, num_warps=1),\n triton.Config({'BD': 32}, num_warps=2),\n triton.Config({'BD': 32}, num_warps=4),\n triton.Config({'BD': 32}, num_warps=8),\n triton.Config({'BD': 64}, num_warps=1),\n triton.Config({'BD': 64}, num_warps=2),\n triton.Config({'BD': 64}, num_warps=4),\n triton.Config({'BD': 64}, num_warps=8),\n triton.Config({'BD': 128}, num_warps=1),\n triton.Config({'BD': 128}, num_warps=2),\n triton.Config({'BD': 128}, num_warps=4),\n triton.Config({'BD': 128}, num_warps=8),\n ],\n key=['D']\n)\n@triton.jit\ndef chunk_hgrn_bwd_kernel_h(\n g,\n gc,\n dx,\n do,\n T: tl.constexpr,\n D: tl.constexpr,\n BT: tl.constexpr,\n BD: tl.constexpr\n):\n i_d, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n o_d = i_d * BD + tl.arange(0, BD)\n mask = o_d < D\n BC = min(BT, T - i_t * BT)\n NT = tl.num_programs(1)\n\n p_g = g + (i_bh * T + i_t * BT + BC - 1) * D + o_d\n p_gc = gc + (i_bh * T + i_t * BT + BC - 1) * D + o_d\n p_dx = dx + (i_bh * T + i_t * BT + BC - 1) * D + o_d\n p_do = do + (i_bh * T + i_t * BT + BC - 1) * D + o_d\n\n if i_t == NT - 1:\n b_gc = tl.zeros([BD], dtype=tl.float32)\n else:\n b_gc = tl.load(g + (i_bh * T + i_t * BT + BT) * D + o_d, mask=mask, other=0).to(tl.float32)\n b_dh = tl.zeros([BD], dtype=tl.float32)\n for _ in range(BC - 1, -1, -1):\n tl.store(p_gc, b_gc.to(p_gc.dtype.element_ty), mask=mask)\n\n b_g = tl.load(p_g, mask=mask, other=0).to(tl.float32)\n b_do = tl.load(p_do, mask=mask, other=0).to(tl.float32)\n\n b_gc = b_gc + b_g\n b_dh = b_dh + b_do\n b_dx = b_dh\n b_dh = b_dh * tl.exp(b_g)\n\n tl.store(p_dx, b_dx.to(p_dx.dtype.element_ty), mask=mask)\n\n p_g -= D\n p_gc -= D\n p_dx -= D\n p_do -= D\n\n\n@triton.jit\ndef chunk_hgrn_bwd_kernel_o(\n g,\n gc,\n o,\n dx,\n dg,\n s_h,\n s_t,\n s_d,\n T: tl.constexpr,\n D: tl.constexpr,\n BT: tl.constexpr,\n BD: tl.constexpr\n):\n i_d, i_bh = tl.program_id(0), tl.program_id(1)\n o_d = i_d * BD + tl.arange(0, BD)\n mask = o_d < D\n\n for i_t in range(tl.cdiv(T, BT) - 1, -1, -1):\n p_g = tl.make_block_ptr(g + i_bh * s_h, (T, D), (s_t, s_d), (i_t * BT, i_d * BD), (BT, BD), (1, 0))\n p_gc = tl.make_block_ptr(gc + i_bh * s_h, (T, D), (s_t, s_d), (i_t * BT, i_d * BD), (BT, BD), (1, 0))\n p_o = tl.make_block_ptr(o + i_bh * s_h, (T, D), (s_t, s_d), (i_t * BT - 1, i_d * BD), (BT, BD), (1, 0))\n p_dx = tl.make_block_ptr(dx + i_bh * s_h, (T, D), (s_t, s_d), (i_t * BT, i_d * BD), (BT, BD), (1, 0))\n p_dg = tl.make_block_ptr(dg + i_bh * s_h, (T, D), (s_t, s_d), (i_t * BT, i_d * BD), (BT, BD), (1, 0))\n\n mask_t = mask & ((i_t + 1) * BT < T)\n b_ht = tl.load(dx + i_bh * T * D + (i_t + 1) * BT * D + o_d, mask=mask_t, other=0).to(tl.float32)\n b_g = tl.load(p_g, boundary_check=(0, 1)).to(tl.float32)\n b_gc = tl.load(p_gc, boundary_check=(0, 1)).to(tl.float32)\n b_o = tl.load(p_o, boundary_check=(0, 1)).to(tl.float32)\n b_dx = tl.load(p_dx, boundary_check=(0, 1)).to(tl.float32)\n b_dg = tl.load(p_dg, boundary_check=(0, 1)).to(tl.float32)\n b_dx = b_dx + tl.exp(b_gc) * b_ht[None, :]\n b_dg = b_o * b_dx * tl.exp(b_g)\n tl.store(p_dx, b_dx.to(p_dx.dtype.element_ty), boundary_check=(0, 1))\n tl.store(p_dg, b_dg.to(p_dg.dtype.element_ty), boundary_check=(0, 1))\n\n\nclass ChunkHGRNFunction(torch.autograd.Function):\n\n @staticmethod\n def forward(ctx, x, g, initial_state=None, output_final_state=False):\n B, H, T, D = x.shape\n BT, BD = 128, min(64, triton.next_power_of_2(D))\n num_warps = 8 if BD == 64 else 4\n\n gc = torch.empty_like(g, dtype=torch.float)\n o = torch.empty_like(x, dtype=torch.float)\n def grid(meta): return (triton.cdiv(D, meta['BD']), triton.cdiv(T, meta['BT']), B * H)\n chunk_hgrn_fwd_kernel_h[grid](\n x, g, gc, o, initial_state,\n T, D,\n BT=BT,\n USE_INITIAL_STATE=initial_state is not None\n )\n def grid(meta): return (triton.cdiv(D, meta['BD']), B * H)\n chunk_hgrn_fwd_kernel_o[grid](\n gc, o,\n o.stride(1), o.stride(2), o.stride(3),\n T, D,\n BT=BT, BD=BD,\n num_warps=num_warps\n )\n final_state = None\n if output_final_state:\n final_state = o[:, :, -1].clone()\n o = o.to(x.dtype)\n ctx.save_for_backward(g, o, initial_state)\n return o, final_state\n\n @staticmethod\n def backward(ctx, do, dht=None):\n g, o, initial_state = ctx.saved_tensors\n B, H, T, D = do.shape\n BT, BD = 128, min(64, triton.next_power_of_2(D))\n num_warps = 8 if BD == 64 else 4\n\n gc = torch.empty_like(g, dtype=torch.float)\n dx = torch.empty_like(o, dtype=torch.float)\n def grid(meta): return (triton.cdiv(D, meta['BD']), triton.cdiv(T, meta['BT']), B * H)\n chunk_hgrn_bwd_kernel_h[grid](\n g, gc, dx, do,\n T, D,\n BT=BT\n )\n\n dg = torch.empty_like(g)\n def grid(meta): return (triton.cdiv(D, meta['BD']), B * H)\n chunk_hgrn_bwd_kernel_o[grid](\n g, gc, o, dx, dg,\n o.stride(1), o.stride(2), o.stride(3),\n T, D,\n BT=BT, BD=BD,\n num_warps=num_warps\n )\n if initial_state is not None:\n dg[:, :, 0] = (initial_state * dx[:, :, 0] * g[:, :, 0].float().exp()).to(dg.dtype)\n\n return dx.to(o.dtype), dg, None, None\n\n\ndef chunk_hgrn(\n x: torch.Tensor,\n g: torch.Tensor,\n initial_state: torch.Tensor = None,\n output_final_state: bool = False\n) -> Tuple[torch.Tensor, torch.Tensor]:\n return ChunkHGRNFunction.apply(x, g, initial_state, output_final_state)\n", - "description_1": "Use triton language to implement a chunkwise HGRN forward and backward pass. The forward kernel 'chunk_hgrn_fwd_kernel_h' takes 9 parameters: x (input tensor), g (gate tensor), gc (intermediate tensor), o (output tensor), h0 (initial state), T (sequence length), D (feature dimension), BT (block size for time), BD (block size for dimension), and USE_INITIAL_STATE (flag for initial state usage). The forward kernel 'chunk_hgrn_fwd_kernel_o' takes 8 parameters: gc, o, s_h, s_t, s_d (strides), T, D, BT, and BD. The backward kernel 'chunk_hgrn_bwd_kernel_h' takes 8 parameters: g, gc, dx, do, T, D, BT, and BD. The backward kernel 'chunk_hgrn_bwd_kernel_o' takes 9 parameters: g, gc, o, dx, dg, s_h, s_t, s_d, T, D, BT, and BD. The function 'chunk_hgrn' wraps these kernels for use in a PyTorch autograd function.", - "description_2": "Use triton language to create a chunkwise HGRN with forward and backward kernels for efficient sequence processing. Implement forward and backward passes with triton.jit kernels, handling input, gate, and output tensors, and supporting optional initial state.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef fused_recurrent_hgrn_fwd_kernel(\n x,\n g,\n o,\n h0,\n ht,\n T: tl.constexpr,\n D: tl.constexpr,\n BD: tl.constexpr,\n USE_INITIAL_STATE: tl.constexpr,\n STORE_FINAL_STATE: tl.constexpr\n):\n i_d, i_bh = tl.program_id(0), tl.program_id(1)\n o_d = i_d * BD + tl.arange(0, BD)\n mask = o_d < D\n\n p_x = x + i_bh * T * D + o_d\n p_g = g + i_bh * T * D + o_d\n p_o = o + i_bh * T * D + o_d\n\n b_h = tl.zeros([BD], dtype=tl.float32)\n if USE_INITIAL_STATE:\n p_h0 = h0 + i_bh * D + o_d\n b_h += tl.load(p_h0, mask=mask, other=0).to(tl.float32)\n for _ in range(0, T):\n b_x = tl.load(p_x, mask=mask, other=0).to(tl.float32)\n b_g = tl.load(p_g, mask=mask, other=0).to(tl.float32)\n b_h = tl.exp(b_g) * b_h + b_x\n tl.store(p_o, b_h.to(p_o.dtype.element_ty), mask=mask)\n\n p_x += D\n p_g += D\n p_o += D\n\n if STORE_FINAL_STATE:\n p_ht = ht + i_bh * D + o_d\n tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask)\n\n@triton.jit\ndef fused_recurrent_hgrn_bwd_kernel(\n g,\n o,\n dx,\n dg,\n do,\n h0,\n T: tl.constexpr,\n D: tl.constexpr,\n BD: tl.constexpr,\n USE_INITIAL_STATE: tl.constexpr\n):\n i_d, i_bh = tl.program_id(0), tl.program_id(1)\n o_d = i_d * BD + tl.arange(0, BD)\n mask = o_d < D\n\n p_g = g + (i_bh * T + T - 1) * D + o_d\n p_o = o + (i_bh * T + T - 2) * D + o_d\n p_dx = dx + (i_bh * T + T - 1) * D + o_d\n p_dg = dg + (i_bh * T + T - 1) * D + o_d\n p_do = do + (i_bh * T + T - 1) * D + o_d\n\n b_dh = tl.zeros([BD], dtype=tl.float32)\n for i in range(T - 1, -1, -1):\n b_g = tl.load(p_g, mask=mask, other=0).to(tl.float32)\n b_do = tl.load(p_do, mask=mask, other=0).to(tl.float32)\n if i > 0:\n b_o = tl.load(p_o, mask=mask, other=0).to(tl.float32)\n elif USE_INITIAL_STATE:\n b_o = tl.load(h0 + i_bh * D + o_d, mask=mask, other=0).to(tl.float32)\n else:\n b_o = tl.zeros([BD], dtype=tl.float32)\n\n b_dh = b_dh + b_do\n b_dx = b_dh\n b_dh = b_dh * tl.exp(b_g)\n b_dg = b_dh * b_o\n tl.store(p_dx, b_dx.to(p_dx.dtype.element_ty), mask=mask)\n tl.store(p_dg, b_dg.to(p_dg.dtype.element_ty), mask=mask)\n\n p_g -= D\n p_o -= D\n p_dx -= D\n p_dg -= D\n p_do -= D\n\n\nclass FusedRecurrentHGRNFunction(torch.autograd.Function):\n\n @staticmethod\n def forward(ctx, x, g, initial_state=None, output_final_state=False):\n B, H, T, D = x.shape\n\n final_state = None\n if output_final_state:\n final_state = x.new_empty(B, H, D)\n\n o = torch.empty_like(x)\n def grid(meta): return (triton.cdiv(D, meta['BD']), B * H)\n fused_recurrent_hgrn_fwd_kernel[grid](\n x, g, o, initial_state, final_state,\n T, D,\n USE_INITIAL_STATE=initial_state is not None,\n STORE_FINAL_STATE=final_state is not None\n )\n ctx.save_for_backward(g, o, initial_state)\n return o, final_state\n\n @staticmethod\n def backward(ctx, do, dht=None):\n g, o, initial_state = ctx.saved_tensors\n B, H, T, D = do.shape\n\n dx = torch.empty_like(o)\n dg = torch.empty_like(g)\n def grid(meta): return (triton.cdiv(D, meta['BD']), B * H)\n fused_recurrent_hgrn_bwd_kernel[grid](\n g, o, dx, dg, do, initial_state,\n T, D,\n USE_INITIAL_STATE=initial_state is not None,\n )\n\n return dx, dg, None, None\n\n\ndef fused_recurrent_hgrn(\n x: torch.Tensor,\n g: torch.Tensor,\n initial_state: torch.Tensor = None,\n output_final_state: bool = False\n) -> Tuple[torch.Tensor, torch.Tensor]:\n return FusedRecurrentHGRNFunction.apply(x, g, initial_state, output_final_state)\n", - "description_1": "Use triton language to implement forward and backward kernels for a fused recurrent operation. The forward kernel takes 10 arguments: x, g, o, h0, ht (all tensors), T, D, BD, USE_INITIAL_STATE, STORE_FINAL_STATE (all constants) and computes recurrent updates with optional initial and final state handling. The backward kernel also takes 10 arguments: g, o, dx, dg, do, h0 (all tensors), T, D, BD, USE_INITIAL_STATE (all constants) and computes gradients with respect to x and g. A Python class wraps these kernels for PyTorch autograd compatibility.", - "description_2": "Use triton language to create a fused recurrent network with forward and backward operations, utilizing kernel functions for efficient computation and a Python class to integrate with PyTorch's autograd mechanism.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\nfrom packaging import version\n\n@triton.jit\ndef fused_chunk_linear_attn_fwd_kernel(\n q, k, v, o, initial_state, final_state,\n s_qk_h, s_qk_t, s_qk_d,\n s_vo_h, s_vo_t, s_vo_d,\n B, H, T, scale,\n BT: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr,\n DK: tl.constexpr, DV: tl.constexpr,\n USE_INITIAL_STATE: tl.constexpr,\n STORE_FINAL_STATE: tl.constexpr,\n CHECK: tl.constexpr\n):\n i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n o_i = tl.arange(0, BT)\n m_s = o_i[:, None] >= o_i[None, :]\n b_h = tl.zeros([BK, BV], dtype=tl.float32)\n\n p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (0, i_k * BK), (BT, BK), (1, 0))\n p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (DK, T), (s_qk_d, s_qk_t), (i_k * BK, 0), (BK, BT), (0, 1))\n p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (0, i_v * BV), (BT, BV), (1, 0))\n p_o = tl.make_block_ptr(o + (i_bh+i_k*B*H) * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (0, i_v * BV), (BT, BV), (1, 0))\n\n if USE_INITIAL_STATE:\n p_h = tl.make_block_ptr(initial_state + i_bh * DK * DV, (DK, DV), (DV, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))\n b_h = tl.load(p_h, boundary_check=(0, 1)).to(tl.float32)\n\n for i in range(0, tl.cdiv(T, BT)):\n b_k = tl.load(p_k, boundary_check=(0, 1))\n b_v = tl.load(p_v, boundary_check=(0, 1))\n b_q = tl.load(p_q, boundary_check=(0, 1))\n b_q = (b_q * scale).to(b_k.dtype)\n\n b_s = tl.dot(b_q, b_k, allow_tf32=False)\n b_s = tl.where(m_s, b_s, 0)\n b_o = tl.dot(b_s.to(b_q.dtype), b_v, allow_tf32=False)\n if CHECK and i == 0:\n b_o += tl.dot(b_q, b_h.to(b_q.dtype), allow_tf32=False)\n b_h = b_h + tl.dot(b_k, b_v, allow_tf32=False)\n else:\n b_o += tl.dot(b_q, b_h.to(b_q.dtype), allow_tf32=False)\n b_h = b_h + tl.dot(b_k, b_v, allow_tf32=False)\n tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))\n p_q = tl.advance(p_q, (BT, 0))\n p_k = tl.advance(p_k, (0, BT))\n p_v = tl.advance(p_v, (BT, 0))\n p_o = tl.advance(p_o, (BT, 0))\n\n if STORE_FINAL_STATE:\n p_final = tl.make_block_ptr(final_state + i_bh * DK * DV, (DK, DV), (DV, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))\n tl.store(p_final, b_h.to(p_final.dtype.element_ty), boundary_check=(0, 1))\n\n\n@triton.jit\ndef fused_chunk_linear_attn_bwd_kernel(\n q, k, v, do, dq, dk, dv, initial_state,\n s_qk_h, s_qk_t, s_qk_d,\n s_vo_h, s_vo_t, s_vo_d,\n B, H, T, scale,\n BT: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr,\n DK: tl.constexpr, DV: tl.constexpr,\n USE_INITIAL_STATE: tl.constexpr,\n CHECK: tl.constexpr\n):\n i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n o_i = tl.arange(0, BT)\n m_s = o_i[:, None] >= o_i[None, :]\n b_h = tl.zeros([BV, BK], dtype=tl.float32)\n if USE_INITIAL_STATE:\n p_h = tl.make_block_ptr(initial_state + i_bh * DK * DV, (DV, DK), (1, DV), (i_v * BV, i_k * BK), (BV, BK), (0, 1))\n b_h = tl.load(p_h, boundary_check=(0, 1)).to(tl.float32)\n\n for i in range(0, tl.cdiv(T, BT)):\n p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (i * BT, i_k * BK), (BT, BK), (1, 0))\n p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (DV, T), (s_vo_d, s_vo_t), (i_v * BV, i * BT), (BV, BT), (0, 1))\n p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (i * BT, i_v * BV), (BT, BV), (1, 0))\n p_dq = tl.make_block_ptr(dq + (i_bh + i_v*B*H) * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (i*BT, i_k*BK), (BT, BK), (1, 0))\n\n b_k = tl.load(p_k, boundary_check=(0, 1))\n b_v = tl.load(p_v, boundary_check=(0, 1))\n b_do = tl.load(p_do, boundary_check=(0, 1))\n\n b_ds = tl.dot(b_do, b_v, allow_tf32=False)\n b_ds = tl.where(m_s, b_ds, 0)\n b_dq = tl.dot(b_ds.to(b_k.dtype), b_k, allow_tf32=False)\n if CHECK and i == 0:\n b_dq += tl.dot(b_do, b_h.to(b_do.dtype), allow_tf32=False)\n b_h = b_h + tl.dot(b_v, b_k, allow_tf32=False)\n else:\n b_dq += tl.dot(b_do, b_h.to(b_do.dtype), allow_tf32=False)\n b_h = b_h + tl.dot(b_v, b_k, allow_tf32=False)\n b_dq *= scale\n tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))\n\n b_h = None\n tl.debug_barrier()\n b_dh = tl.zeros([BK, BV], dtype=tl.float32)\n m_s = o_i[:, None] <= o_i[None, :]\n for i in range(1, tl.cdiv(T, BT) + 1):\n p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (DK, T), (s_qk_d, s_qk_t), (i_k * BK, T - i * BT), (BK, BT), (0, 1))\n p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (T - i * BT, i_k * BK), (BT, BK), (1, 0))\n p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (T - i * BT, i_v * BV), (BT, BV), (1, 0))\n p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (T - i * BT, i_v * BV), (BT, BV), (1, 0))\n p_dk = tl.make_block_ptr(dk + (i_bh+i_v*B*H) * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (T - i*BT, i_k*BK), (BT, BK), (1, 0))\n p_dv = tl.make_block_ptr(dv + (i_bh+i_k*B*H) * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (T - i*BT, i_v*BV), (BT, BV), (1, 0))\n\n b_q = tl.load(p_q, boundary_check=(0, 1))\n b_k = tl.load(p_k, boundary_check=(0, 1))\n b_v = tl.load(p_v, boundary_check=(0, 1))\n b_do = tl.load(p_do, boundary_check=(0, 1))\n\n b_ds = tl.dot(b_v, tl.trans(b_do), allow_tf32=False)\n b_ds = tl.where(m_s, b_ds, 0).to(b_q.dtype)\n b_s = tl.dot(b_k, b_q, allow_tf32=False) * scale\n b_s = tl.where(m_s, b_s, 0).to(b_q.dtype)\n b_dk = tl.dot(b_ds, tl.trans(b_q), allow_tf32=False)\n b_dv = tl.dot(b_s, b_do, allow_tf32=False)\n if CHECK and i == 1:\n b_dk += tl.dot(b_v, tl.trans(b_dh).to(b_v.dtype), allow_tf32=False)\n b_dv += tl.dot(b_k, b_dh.to(b_k.dtype), allow_tf32=False)\n b_dh += tl.dot(b_q, b_do, allow_tf32=False)\n else:\n b_dk += tl.dot(b_v, tl.trans(b_dh).to(b_v.dtype), allow_tf32=False)\n b_dv += tl.dot(b_k, b_dh.to(b_k.dtype), allow_tf32=False)\n b_dh += tl.dot(b_q, b_do, allow_tf32=False)\n\n tl.store(p_dk, (b_dk * scale).to(p_dk.dtype.element_ty), boundary_check=(0, 1))\n tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))\n\n\nclass FusedChunkLinearAttentionFunction(torch.autograd.Function):\n @staticmethod\n def forward(ctx, q, k, v, scale, initial_state, output_final_state):\n batch_size, n_heads, seq_len, d_head_qk = q.shape\n d_head_v = v.shape[-1]\n ctx.scale = scale\n BT = 64\n BK, BV = min(triton.next_power_of_2(d_head_qk), 64), min(triton.next_power_of_2(d_head_v), 64)\n NK, NV = triton.cdiv(d_head_qk, BK), triton.cdiv(d_head_v, BV)\n num_stages = 1\n num_warps = 4\n o = q.new_empty(NK, batch_size, n_heads, seq_len, d_head_v)\n if output_final_state:\n final_state = q.new_empty(batch_size, n_heads, d_head_qk, d_head_v, dtype=torch.float32, requires_grad=False)\n else:\n final_state = None\n CHECK = True\n if version.parse(triton.__version__) < version.parse('2.2.0'):\n import warnings\n warnings.warn(\n \"Triton<2.2.0 detected for running this kernel, \"\n \"which is known to have some weird compiler issues (refer to https://github.com/openai/triton/issues/2852) \"\n \"that lead to significant precision loss. \"\n \"We've add some initial condition checks to resolve this, sadly at the sacrifice of the speed. \"\n \"For optimal performance, it is recommended to install Triton>=2.2.0 (if possible).\"\n )\n CHECK = True\n\n grid = (NV, NK, batch_size * n_heads)\n fused_chunk_linear_attn_fwd_kernel[grid](\n q, k, v, o, initial_state, final_state,\n q.stride(1), q.stride(2), q.stride(3),\n v.stride(1), v.stride(2), v.stride(3),\n batch_size, n_heads, seq_len, scale,\n BT=BT, DK=d_head_qk, DV=d_head_v, BK=BK, BV=BV,\n USE_INITIAL_STATE=initial_state is not None,\n STORE_FINAL_STATE=output_final_state,\n CHECK=CHECK,\n num_warps=num_warps,\n num_stages=num_stages\n )\n\n o = o.sum(0)\n ctx.save_for_backward(q, k, v, initial_state)\n ctx.CHECK = CHECK\n return o.to(q.dtype), final_state\n\n @staticmethod\n def backward(ctx, do, d_final_state=None):\n q, k, v, initial_state = ctx.saved_tensors\n batch_size, n_heads, seq_len, d_head_qk = q.shape\n d_head_v = v.shape[-1]\n scale = ctx.scale\n\n BT = 64\n BK, BV = min(triton.next_power_of_2(d_head_qk), 64), min(triton.next_power_of_2(d_head_v), 64)\n NK, NV = triton.cdiv(d_head_qk, BK), triton.cdiv(d_head_v, BV)\n num_stages = 1\n num_warps = 4\n\n dq = q.new_empty(NV, batch_size, n_heads, seq_len, d_head_qk)\n dk = q.new_empty(NV, batch_size, n_heads, seq_len, d_head_qk)\n dv = q.new_empty(NK, batch_size, n_heads, seq_len, d_head_v)\n grid = (NV, NK, batch_size * n_heads)\n\n fused_chunk_linear_attn_bwd_kernel[grid](\n q, k, v, do, dq, dk, dv, initial_state,\n q.stride(1), q.stride(2), q.stride(3),\n v.stride(1), v.stride(2), v.stride(3),\n batch_size, n_heads, seq_len, scale,\n BT=BT, DK=d_head_qk, DV=d_head_v, BK=BK, BV=BV,\n USE_INITIAL_STATE=initial_state is not None,\n CHECK=ctx.CHECK,\n num_warps=num_warps,\n num_stages=num_stages\n )\n dq = dq.sum(0)\n dk = dk.sum(0)\n dv = dv.sum(0)\n return dq.to(q.dtype), dk.to(k.dtype), dv.to(v.dtype), None, None, None\n\n\ndef fused_chunk_linear_attn(\n q: torch.Tensor,\n k: torch.Tensor,\n v: torch.Tensor,\n scale: float = -1,\n initial_state: torch.Tensor = None,\n output_final_state: bool = False,\n normalize: bool = True\n) -> Tuple[torch.Tensor, torch.Tensor]:\n if initial_state is not None:\n initial_state = initial_state.detach()\n if scale == -1:\n scale = q.shape[-1] ** -0.5\n o, final_state = FusedChunkLinearAttentionFunction.apply(q, k, v, scale, initial_state, output_final_state)\n return o, final_state\n", - "description_1": "Use triton language to implement a fused chunk linear attention mechanism with forward and backward kernels. The forward kernel takes 20 parameters: q, k, v, o, initial_state, final_state, s_qk_h, s_qk_t, s_qk_d, s_vo_h, s_vo_t, s_vo_d, B, H, T, scale, BT, BK, BV, DK, DV, USE_INITIAL_STATE, STORE_FINAL_STATE, CHECK. The backward kernel takes 22 parameters: q, k, v, do, dq, dk, dv, initial_state, s_qk_h, s_qk_t, s_qk_d, s_vo_h, s_vo_t, s_vo_d, B, H, T, scale, BT, BK, BV, DK, DV, USE_INITIAL_STATE, CHECK. The function fused_chunk_linear_attn wraps these kernels for use in PyTorch, taking 7 parameters: q, k, v, scale, initial_state, output_final_state, normalize.", - "description_2": "Use triton language to create a fused chunk linear attention mechanism with both forward and backward operations, optimized for performance on GPUs. The implementation involves defining two kernels for forward and backward passes, and a PyTorch function to integrate these kernels into a neural network workflow.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\nfrom typing import Tuple\n\n@triton.jit\ndef fused_recurrent_linear_attn_fwd_kernel(\n q, k, v, o, initial_state, final_state,\n s_qk_h, s_qk_t, s_qk_d,\n s_vo_h, s_vo_t, s_vo_d,\n B, H, T, scale,\n BK: tl.constexpr, BV: tl.constexpr,\n DK: tl.constexpr, DV: tl.constexpr,\n USE_INITIAL_STATE: tl.constexpr, STORE_FINAL_STATE: tl.constexpr,\n):\n i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n p_q = q + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK)\n p_k = k + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK)\n p_v = v + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV)\n p_o = o + (i_bh + i_k * B * H) * s_vo_h + i_v * BV + tl.arange(0, BV)\n mask_bk = (i_k * BK + tl.arange(0, BK)) < DK\n mask_bv = (i_v * BV + tl.arange(0, BV)) < DV\n mask_kv = mask_bk[None, :] & mask_bv[:, None]\n h = tl.zeros([BV, BK], dtype=tl.float32)\n if USE_INITIAL_STATE:\n p_init_s = initial_state + i_bh * DK * DV + \\\n (i_k * BK + tl.arange(0, BK)[None, :]) * \\\n DV + (i_v * BV + tl.arange(0, BV)[:, None])\n h += tl.load(p_init_s, mask=mask_kv, other=0).to(tl.float32)\n for _ in range(0, T):\n _k = tl.load(p_k, mask=mask_bk, other=0).to(tl.float32)\n _v = tl.load(p_v, mask=mask_bv, other=0).to(tl.float32)\n _q = tl.load(p_q, mask=mask_bk, other=0).to(tl.float32) * scale\n h += _k[None, :] * _v[:, None]\n _o = h * _q[None, :]\n _o = tl.sum(_o, axis=1)\n tl.store(p_o, _o.to(p_o.dtype.element_ty), mask=mask_bv)\n p_q += DK\n p_k += DK\n p_o += DV\n p_v += DV\n if STORE_FINAL_STATE:\n p_final_s = final_state + i_bh * DK * DV + \\\n (i_k * BK + tl.arange(0, BK)[None, :]) * \\\n DV + (i_v * BV + tl.arange(0, BV)[:, None])\n tl.store(p_final_s, h.to(p_final_s.dtype.element_ty), mask=mask_kv)\n\n@triton.jit\ndef fused_recurrent_linear_attn_bwd_kernel(\n q, k, v, do, dq, dk, dv, initial_state,\n s_qk_h, s_qk_t, s_qk_d,\n s_vo_h, s_vo_t, s_vo_d,\n B, H, T, scale,\n BK: tl.constexpr, BV: tl.constexpr,\n DK: tl.constexpr, DV: tl.constexpr,\n USE_INITIAL_STATE: tl.constexpr,\n):\n i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n p_q = q + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK)\n p_k = k + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK)\n p_v = v + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV)\n p_do = do + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV)\n p_dq = dq + (i_bh + i_v * B * H) * s_qk_h + i_k * BK + tl.arange(0, BK)\n mask_bk = i_k * BK + tl.arange(0, BK) < DK\n mask_bv = i_v * BV + tl.arange(0, BV) < DV\n h = tl.zeros([BK, BV], dtype=tl.float32)\n if USE_INITIAL_STATE:\n mask_kv = mask_bk[:, None] & mask_bv[None, :]\n p_init_s = initial_state + i_bh * DK * DV + \\\n (i_k * BK + tl.arange(0, BK)[:, None]) * \\\n DV + (i_v * BV + tl.arange(0, BV)[None, :])\n h += tl.load(p_init_s, mask=mask_kv, other=0).to(tl.float32)\n for i in range(0, T):\n _k = tl.load(p_k, mask=mask_bk, other=0).to(tl.float32)\n _v = tl.load(p_v, mask=mask_bv, other=0).to(tl.float32)\n _do = tl.load(p_do, mask=mask_bv, other=0).to(tl.float32)\n h += _k[:, None] * _v[None, :]\n _d_q = h * _do[None, :]\n d_q = tl.sum(_d_q, axis=1) * scale\n tl.store(p_dq, d_q.to(p_dq.dtype.element_ty), mask=mask_bk)\n p_k += DK\n p_do += DV\n p_v += DV\n p_dq += DK\n tl.debug_barrier()\n p_q = q + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (T - 1) * DK\n p_k = k + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (T - 1) * DK\n p_do = do + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV) + (T - 1) * DV\n p_v = v + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV) + (T - 1) * DV\n p_dk = dk + (i_bh + i_v * B * H) * s_qk_h + i_k * \\\n BK + tl.arange(0, BK) + (T - 1) * DK\n p_dv = dv + (i_bh + i_k * B * H) * s_vo_h + i_v * \\\n BV + tl.arange(0, BV) + (T - 1) * DV\n d_h = tl.zeros([BK, BV], dtype=tl.float32)\n for _ in range(T):\n _do = tl.load(p_do, mask=mask_bv, other=0).to(tl.float32)\n _q = tl.load(p_q, mask=mask_bk, other=0).to(tl.float32) * scale\n _k = tl.load(p_k, mask=mask_bk, other=0).to(tl.float32)\n _v = tl.load(p_v, mask=mask_bv, other=0).to(tl.float32)\n d_h += _q[:, None] * _do[None, :]\n d_k = tl.sum(d_h * _v[None, :], axis=1)\n d_v = tl.sum(d_h * _k[:, None], axis=0)\n tl.store(p_dk, d_k.to(p_dk.dtype.element_ty), mask=mask_bk)\n tl.store(p_dv, d_v.to(p_dv.dtype.element_ty), mask=mask_bv)\n p_do -= DV\n p_q -= DK\n p_k -= DK\n p_v -= DV\n p_dk -= DK\n p_dv -= DV\n\nclass FusedRecurrentLinearAttentionFunction(torch.autograd.Function):\n @staticmethod\n def forward(ctx, q, k, v, initial_state=None, output_final_state=False):\n batch_size, n_heads, seq_len, d_head_qk = q.shape\n d_head_v = v.shape[-1]\n scale = d_head_qk ** -0.5\n BK, BV = min(d_head_qk, 32), min(d_head_v, 32)\n NK, NV = triton.cdiv(d_head_qk, BK), triton.cdiv(d_head_v, BV)\n num_stages = 1\n num_warps = 1\n o = q.new_empty(NK, batch_size, n_heads, seq_len, d_head_v)\n if output_final_state:\n final_state = q.new_empty(batch_size, n_heads, d_head_qk, d_head_v)\n else:\n final_state = None\n grid = (NV, NK, batch_size * n_heads)\n fused_recurrent_linear_attn_fwd_kernel[grid](\n q, k, v, o, initial_state, final_state,\n q.stride(1), q.stride(2), q.stride(3),\n v.stride(1), v.stride(2), v.stride(3),\n batch_size, n_heads, seq_len, scale,\n DK=d_head_qk, DV=d_head_v, BK=BK, BV=BV,\n num_warps=num_warps,\n num_stages=num_stages,\n USE_INITIAL_STATE=initial_state is not None,\n STORE_FINAL_STATE=final_state is not None\n )\n o = o.sum(0)\n ctx.save_for_backward(q, k, v, initial_state)\n return o, final_state\n\n @staticmethod\n def backward(ctx, do, d_final_state=None):\n q, k, v, initial_state = ctx.saved_tensors\n batch_size, n_heads, seq_len, d_head_qk = q.shape\n d_head_v = v.shape[-1]\n scale = d_head_qk ** -0.5\n BK, BV = min(d_head_qk, 32), min(d_head_v, 32)\n NK, NV = triton.cdiv(d_head_qk, BK), triton.cdiv(d_head_v, BV)\n num_stages = 1\n num_warps = 1\n dq = q.new_empty(NV, batch_size, n_heads, seq_len, d_head_qk)\n dk = q.new_empty(NV, batch_size, n_heads, seq_len, d_head_qk)\n dv = q.new_empty(NK, batch_size, n_heads, seq_len, d_head_v)\n grid = (NV, NK, batch_size * n_heads)\n fused_recurrent_linear_attn_bwd_kernel[grid](\n q, k, v, do, dq, dk, dv, initial_state,\n q.stride(1), q.stride(2), q.stride(3),\n v.stride(1), v.stride(2), v.stride(3),\n batch_size, n_heads, seq_len, scale,\n DK=d_head_qk, DV=d_head_v, BK=BK, BV=BV,\n num_warps=num_warps,\n num_stages=num_stages,\n USE_INITIAL_STATE=initial_state is not None\n )\n dq = dq.sum(0)\n dk = dk.sum(0)\n dv = dv.sum(0)\n return dq, dk, dv, None, None\n\ndef fused_recurrent_linear_attn(\n q: torch.Tensor,\n k: torch.Tensor,\n v: torch.Tensor,\n initial_state: torch.Tensor = None,\n output_final_state: bool = False,\n normalize: bool = False\n) -> Tuple[torch.Tensor, torch.Tensor]:\n if initial_state is not None:\n initial_state = initial_state.detach()\n o, final_state = FusedRecurrentLinearAttentionFunction.apply(\n q, k, v, initial_state, output_final_state)\n if normalize:\n o = normalize_output(q, k, o)\n return o, final_state\n", - "description_1": "Use triton language to implement a fused recurrent linear attention mechanism with forward and backward kernels. The forward kernel takes 22 parameters, including query, key, value tensors, output tensor, initial and final state tensors, stride sizes, batch size, number of heads, sequence length, scaling factor, block sizes, dimensions, and constants to determine state usage and storage. It computes a linear attention result for each time step, optionally using an initial state and storing the final state. The backward kernel takes 21 parameters, including query, key, value tensors, gradient of output, and gradients of query, key, value, initial state, stride sizes, batch size, number of heads, sequence length, scaling factor, block sizes, dimensions, and a constant to determine initial state usage. It computes gradients for the input tensors using the stored intermediate states.", - "description_2": "Use triton language to create forward and backward kernels for fused recurrent linear attention, where the forward pass calculates the attention output and the backward pass computes the gradients.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef parallel_rebased_fwd_kernel(\n q, k, v, o, z, s_qk_h, s_qk_t, s_qk_d, s_vo_h, s_vo_t, s_vo_d,\n B, H, T, scale, BTL: tl.constexpr, BTS: tl.constexpr, BK: tl.constexpr, \n BV: tl.constexpr, DK: tl.constexpr, DV: tl.constexpr\n):\n i_kv, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n NV = tl.cdiv(DV, BV)\n i_k = i_kv // NV\n i_v = i_kv % NV\n p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d),\n (i_c * BTL, i_k * BK), (BTL, BK), (1, 0))\n p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (DK, T), (s_qk_d, s_qk_t),\n (i_k * BK, 0), (BK, BTS), (0, 1))\n p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d),\n (0, i_v * BV), (BTS, BV), (1, 0))\n b_q = tl.load(p_q, boundary_check=(0, 1))\n b_q = (b_q * scale).to(b_q.dtype)\n b_o = tl.zeros([BTL, BV], dtype=tl.float32)\n b_z = tl.zeros([BTL], dtype=tl.float32)\n for _ in range(0, i_c * BTL, BTS):\n b_k = tl.load(p_k, boundary_check=(0, 1))\n b_v = tl.load(p_v, boundary_check=(0, 1))\n b_s = tl.dot(b_q, b_k, allow_tf32=False)\n b_s = b_s * b_s\n b_z += tl.sum(b_s, axis=1)\n b_o = b_o + tl.dot(b_s.to(b_v.dtype), b_v, allow_tf32=False)\n p_k = tl.advance(p_k, (0, BTS))\n p_v = tl.advance(p_v, (BTS, 0))\n tl.debug_barrier()\n o_q = tl.arange(0, BTL)\n o_k = tl.arange(0, BTS)\n p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (DK, T), (s_qk_d, s_qk_t),\n (i_k * BK, i_c * BTL), (BK, BTS), (0, 1))\n p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d),\n (i_c * BTL, i_v * BV), (BTS, BV), (1, 0))\n for _ in range(i_c * BTL, (i_c + 1) * BTL, BTS):\n b_k = tl.load(p_k, boundary_check=(0, 1))\n b_v = tl.load(p_v, boundary_check=(0, 1))\n m_s = o_q[:, None] >= o_k[None, :]\n b_s = tl.dot(b_q, b_k, allow_tf32=False)\n b_s = b_s * b_s\n b_s = tl.where(m_s, b_s, 0)\n b_z += tl.sum(b_s, axis=1)\n b_o += tl.dot(b_s.to(b_q.dtype), b_v, allow_tf32=False)\n p_k = tl.advance(p_k, (0, BTS))\n p_v = tl.advance(p_v, (BTS, 0))\n o_k += BTS\n p_o = tl.make_block_ptr(o + (i_bh + B * H * i_k) * s_vo_h, (T, DV),\n (s_vo_t, s_vo_d), (i_c * BTL, i_v * BV), (BTL, BV), (1, 0))\n p_z = z + (i_bh + B * H * i_k) * T + i_c * BTL + tl.arange(0, BTL)\n tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))\n tl.store(p_z, b_z.to(p_z.dtype.element_ty), mask=((i_c * BTL + tl.arange(0, BTL)) < T))\n\n@triton.jit\ndef parallel_rebased_bwd_kernel(\n q, k, v, do, dz, dq, dk, dv, s_qk_h, s_qk_t, s_qk_d, s_vo_h,\n s_vo_t, s_vo_d, B, H, T, scale,\n BTL: tl.constexpr, BTS: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr,\n DK: tl.constexpr, DV: tl.constexpr,\n):\n i_kv, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n NV = tl.cdiv(DV, BV)\n i_k = i_kv // NV\n i_v = i_kv % NV\n i_h = i_bh % H\n _parallel_rebased_bwd_dq(\n i_bh, i_c, i_k, i_v, i_h,\n q, k, v, do, dz, dq, s_qk_h, s_qk_t, s_qk_d, s_vo_h,\n s_vo_t, s_vo_d, B, H, T, scale, BTL=BTL, BTS=BTS, BK=BK, BV=BV, DK=DK, DV=DV\n )\n tl.debug_barrier()\n _parallel_rebased_bwd_dkv(\n i_bh, i_c, i_k, i_v, i_h,\n q, k, v, do, dz, dk, dv, s_qk_h, s_qk_t, s_qk_d, s_vo_h,\n s_vo_t, s_vo_d, B, H, T, scale, BTL, BTS, BK, BV, DK, DV\n )\n\nclass ParallelBasedFunction(torch.autograd.Function):\n @staticmethod\n def forward(ctx, q, k, v, scale):\n BTL, BTS = 128, 32\n assert BTL % BTS == 0\n BK = min(128, triton.next_power_of_2(k.shape[-1]))\n BV = min(128, triton.next_power_of_2(v.shape[-1]))\n BK, BV = max(BK, 16), max(BV, 16)\n batch_size, n_heads, seq_len, d_head_qk = q.shape\n d_head_v = v.shape[-1]\n num_stages = 2\n num_warps = 4\n NK = triton.cdiv(d_head_qk, BK)\n NV = triton.cdiv(d_head_v, BV)\n grid = (NK * NV, triton.cdiv(seq_len, BTL), batch_size * n_heads)\n assert NK == 1, \"will encounter some synchronization issue if not.\"\n\n o = torch.empty(NK, batch_size, n_heads, seq_len, d_head_v, device=q.device)\n z = torch.empty(NK, batch_size, n_heads, seq_len, device=q.device)\n parallel_rebased_fwd_kernel[grid](\n q, k, v, o, z,\n q.stride(1), q.stride(2), q.stride(3),\n v.stride(1), v.stride(2), v.stride(3),\n batch_size, n_heads, seq_len, scale,\n BTL=BTL, BTS=BTS, BK=BK, BV=BV, DK=d_head_qk, DV=d_head_v,\n num_warps=num_warps,\n num_stages=num_stages\n )\n ctx.save_for_backward(q, k, v)\n ctx.scale = scale\n return o.sum(0).to(q.dtype), z.sum(0).to(q.dtype)\n\n @staticmethod\n def backward(ctx, do, dz):\n q, k, v = ctx.saved_tensors\n scale = ctx.scale\n BTL, BTS = 64, 32\n assert BTL % BTS == 0\n BK = min(128, triton.next_power_of_2(k.shape[-1]))\n BV = min(128, triton.next_power_of_2(v.shape[-1]))\n BK, BV = max(BK, 16), max(BV, 16)\n batch_size, n_heads, seq_len, d_head_qk = q.shape\n d_head_v = v.shape[-1]\n num_stages = 2\n num_warps = 4\n NK = triton.cdiv(d_head_qk, BK)\n NV = triton.cdiv(d_head_v, BV)\n grid = (NK * NV, triton.cdiv(seq_len, BTL), batch_size * n_heads)\n assert NK == 1, \"will encounter some synchronization issue if not\"\n\n dq = torch.empty(NV, batch_size, n_heads, seq_len, d_head_qk, dtype=q.dtype, device=q.device)\n dk = torch.empty(NV, batch_size, n_heads, seq_len, d_head_qk, dtype=q.dtype, device=q.device)\n dv = torch.empty(NK, batch_size, n_heads, seq_len, d_head_v, dtype=q.dtype, device=q.device)\n\n parallel_rebased_bwd_kernel[grid](\n q, k, v, do, dz, dq, dk, dv,\n q.stride(1), q.stride(2), q.stride(3),\n v.stride(1), v.stride(2), v.stride(3),\n batch_size, n_heads, seq_len, scale,\n BTL=BTL, BTS=BTS, BK=BK, BV=BV, DK=d_head_qk, DV=d_head_v,\n num_warps=num_warps,\n num_stages=num_stages\n )\n return dq.sum(0).to(q.dtype), dk.sum(0).to(k.dtype), dv.sum(0).to(v.dtype), None\n\ntriton_parallel_based = ParallelBasedFunction.apply\n\ndef parallel_rebased(q, k, v, eps=1e-5, use_scale=True, use_normalize=True, return_both=False):\n assert q.shape[-1] <= 128, \"only support feature dim up to 128\"\n if use_scale:\n scale = q.shape[-1] ** -0.5\n else:\n scale = 1\n o, z = triton_parallel_based(q, k, v, scale)\n if return_both:\n return o, z\n if use_normalize:\n o = o / (z[..., None] + eps)\n return o.to(q.dtype)\n", - "description_1": "Use triton language to implement a parallel rebased forward and backward kernel for a transformer-like operation. The forward kernel computes attention scores and outputs by reading blocks of queries, keys, and values, then accumulating results in shared memory. The backward kernel calculates gradients for query, key, and value tensors. Both kernels handle stride and data indexing for block-wise processing in large tensors, ensuring efficient memory usage.", - "description_2": "Use triton language to implement transformer-like operations with parallelized forward and backward kernels for efficient memory management and processing.", - "difficulty": 5 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\nfrom packaging import version\nfrom torch.cuda.amp import custom_bwd, custom_fwd\n\n@triton.jit\ndef fused_chunk_retention_fwd_kernel(\n q, k, v, o, initial_state, final_state,\n s_qk_h, s_qk_t, s_qk_d, s_vo_h, s_vo_t, s_vo_d,\n B, H, T, scale,\n BT: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr,\n DK: tl.constexpr, DV: tl.constexpr,\n USE_INITIAL_STATE: tl.constexpr,\n STORE_FINAL_STATE: tl.constexpr,\n CHECK: tl.constexpr\n):\n i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n i_h = i_bh % H\n\n o_i = tl.arange(0, BT)\n b_b = tl.math.log2(1 - tl.math.pow(2, -5 - i_h * 1.0))\n\n d_b, d_o, d_h = tl.math.exp2(BT * b_b), tl.math.exp2((o_i + 1) * b_b), tl.math.exp2((BT - o_i - 1) * b_b)\n\n m_s = o_i[:, None] >= o_i[None, :]\n d_s = tl.where(m_s, tl.math.exp2((o_i[:, None] - o_i[None, :]) * b_b), 0)\n b_h = tl.zeros([BK, BV], dtype=tl.float32)\n\n p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (0, i_k * BK), (BT, BK), (1, 0))\n p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (DK, T), (s_qk_d, s_qk_t), (i_k * BK, 0), (BK, BT), (0, 1))\n p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (0, i_v * BV), (BT, BV), (1, 0))\n p_o = tl.make_block_ptr(o + (i_bh+i_k*B*H) * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (0, i_v * BV), (BT, BV), (1, 0))\n\n if USE_INITIAL_STATE:\n p_h = tl.make_block_ptr(initial_state + i_bh * DK * DV, (DK, DV), (DV, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))\n b_h = tl.load(p_h, boundary_check=(0, 1)).to(tl.float32)\n\n NT = tl.cdiv(T, BT)\n for i in range(0, NT):\n b_k = tl.load(p_k, boundary_check=(0, 1))\n b_v = tl.load(p_v, boundary_check=(0, 1))\n b_q = tl.load(p_q, boundary_check=(0, 1))\n b_q = (b_q * scale).to(b_k.dtype)\n\n b_s = tl.dot(b_q, b_k, allow_tf32=False) * d_s\n b_o = tl.dot(b_s.to(b_q.dtype), b_v, allow_tf32=False)\n if CHECK and i == 0:\n b_o += tl.dot(b_q, b_h.to(b_q.dtype), allow_tf32=False) * d_o[:, None]\n b_h = d_b * b_h + tl.dot(b_k, (b_v * d_h[:, None]).to(b_k.dtype), allow_tf32=False)\n else:\n b_o += tl.dot(b_q, b_h.to(b_q.dtype), allow_tf32=False) * d_o[:, None]\n if i == NT - 1 and (T % BT) != 0:\n d_b = tl.math.exp2((T % BT) * b_b)\n d_h = tl.math.exp2(((T % BT) - o_i - 1) * b_b)\n b_h = d_b * b_h + tl.dot(b_k, (b_v * d_h[:, None]).to(b_k.dtype), allow_tf32=False)\n tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))\n\n p_q = tl.advance(p_q, (BT, 0))\n p_k = tl.advance(p_k, (0, BT))\n p_v = tl.advance(p_v, (BT, 0))\n p_o = tl.advance(p_o, (BT, 0))\n\n if STORE_FINAL_STATE:\n p_final = tl.make_block_ptr(final_state + i_bh * DK * DV, (DK, DV), (DV, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))\n tl.store(p_final, b_h.to(p_final.dtype.element_ty), boundary_check=(0, 1))\n\n\n@triton.jit\ndef fused_chunk_retention_bwd_kernel(\n q, k, v, do, dq, dk, dv, initial_state,\n s_qk_h, s_qk_t, s_qk_d, s_vo_h, s_vo_t, s_vo_d,\n B, H, T, scale,\n BT: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr,\n DK: tl.constexpr, DV: tl.constexpr,\n USE_INITIAL_STATE: tl.constexpr,\n CHECK: tl.constexpr\n):\n i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n i_h = i_bh % H\n\n o_i = tl.arange(0, BT)\n b_b = tl.math.log2(1 - tl.math.pow(2, -5 - i_h * 1.0))\n d_q, d_k = tl.math.exp2((o_i+1) * b_b) * scale, tl.math.exp2((BT - o_i - 1) * b_b)\n d_b = tl.math.exp2(BT * b_b)\n\n m_s = o_i[:, None] >= o_i[None, :]\n d_s = tl.where(m_s, tl.math.exp2((o_i[:, None] - o_i[None, :]) * b_b), 0) * scale\n b_h = tl.zeros([BV, BK], dtype=tl.float32)\n if USE_INITIAL_STATE:\n p_h = tl.make_block_ptr(initial_state + i_bh * DK * DV, (DV, DK), (1, DV), (i_v * BV, i_k * BK), (BV, BK), (0, 1))\n b_h = tl.load(p_h, boundary_check=(0, 1)).to(tl.float32)\n\n for i in range(0, tl.cdiv(T, BT)):\n p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (i * BT, i_k * BK), (BT, BK), (1, 0))\n p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (DV, T), (s_vo_d, s_vo_t), (i_v * BV, i * BT), (BV, BT), (0, 1))\n p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (i * BT, i_v * BV), (BT, BV), (1, 0))\n p_dq = tl.make_block_ptr(dq + (i_bh + i_v*B*H) * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (i*BT, i_k*BK), (BT, BK), (1, 0))\n\n b_k = tl.load(p_k, boundary_check=(0, 1))\n b_v = tl.load(p_v, boundary_check=(0, 1))\n b_do = tl.load(p_do, boundary_check=(0, 1))\n b_dd = (b_do * d_q[:, None]).to(b_do.dtype)\n\n b_ds = tl.dot(b_do, b_v, allow_tf32=False)\n b_ds = (b_ds * d_s).to(b_k.dtype)\n b_dq = tl.dot(b_ds, b_k, allow_tf32=False)\n if CHECK and i == 0:\n b_dq += tl.dot(b_dd, b_h.to(b_k.dtype), allow_tf32=False)\n b_h = d_b * b_h + tl.dot((b_v * d_k[None, :]).to(b_k.dtype), b_k, allow_tf32=False)\n else:\n b_dq += tl.dot(b_dd, b_h.to(b_k.dtype), allow_tf32=False)\n b_h = d_b * b_h + tl.dot((b_v * d_k[None, :]).to(b_k.dtype), b_k, allow_tf32=False)\n\n tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))\n\n b_h = None\n tl.debug_barrier()\n d_s = tl.trans(d_s)\n b_dh = tl.zeros([BK, BV], dtype=tl.float32)\n for i in range(1, tl.cdiv(T, BT) + 1):\n p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (DK, T), (s_qk_d, s_qk_t), (i_k * BK, T - i * BT), (BK, BT), (0, 1))\n p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (T - i * BT, i_k * BK), (BT, BK), (1, 0))\n p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (T - i * BT, i_v * BV), (BT, BV), (1, 0))\n p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (T - i * BT, i_v * BV), (BT, BV), (1, 0))\n p_dk = tl.make_block_ptr(dk + (i_bh+i_v*B*H) * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (T - i*BT, i_k*BK), (BT, BK), (1, 0))\n p_dv = tl.make_block_ptr(dv + (i_bh+i_k*B*H) * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (T - i*BT, i_v*BV), (BT, BV), (1, 0))\n\n b_q = tl.load(p_q, boundary_check=(0, 1))\n b_k = tl.load(p_k, boundary_check=(0, 1))\n b_v = tl.load(p_v, boundary_check=(0, 1))\n b_do = tl.load(p_do, boundary_check=(0, 1))\n b_dd = (b_do * d_q[:, None]).to(b_do.dtype)\n\n b_ds = tl.dot(b_v, tl.trans(b_do), allow_tf32=False)\n b_ds = (b_ds * d_s).to(b_k.dtype)\n\n b_s = tl.dot(b_k, b_q, allow_tf32=False) * d_s\n b_dk = tl.dot(b_ds, tl.trans(b_q), allow_tf32=False)\n b_dv = tl.dot(b_s.to(b_q.dtype), b_do, allow_tf32=False)\n if CHECK and i == 1:\n b_dk += tl.dot(b_v, tl.trans(b_dh).to(b_v.dtype), allow_tf32=False) * d_k[:, None]\n b_dv += tl.dot(b_k, b_dh.to(b_k.dtype), allow_tf32=False) * d_k[:, None]\n b_dh = d_b * b_dh + tl.dot(b_q, b_dd, allow_tf32=False)\n else:\n b_dk += tl.dot(b_v, tl.trans(b_dh).to(b_v.dtype), allow_tf32=False) * d_k[:, None]\n b_dv += tl.dot(b_k, b_dh.to(b_k.dtype), allow_tf32=False) * d_k[:, None]\n b_dh = d_b * b_dh + tl.dot(b_q, b_dd, allow_tf32=False)\n\n tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))\n tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))\n\n\nclass FusedChunkRetentionFunction(torch.autograd.Function):\n\n @staticmethod\n @custom_fwd\n def forward(ctx, q, k, v, initial_state, output_final_state):\n batch_size, n_heads, seq_len, d_head_qk = q.shape\n d_head_v = v.shape[-1]\n\n scale = d_head_qk ** -0.5\n BT = 64\n BK, BV = min(triton.next_power_of_2(d_head_qk), 64), min(triton.next_power_of_2(d_head_v), 64)\n NK, NV = triton.cdiv(d_head_qk, BK), triton.cdiv(d_head_v, BV)\n num_stages = 1\n num_warps = 4\n\n o = q.new_empty(NK, batch_size, n_heads, seq_len, d_head_v)\n\n if output_final_state:\n final_state = q.new_empty(batch_size, n_heads, d_head_qk, d_head_v, dtype=torch.float32, requires_grad=False)\n else:\n final_state = None\n\n CHECK = True\n if version.parse(triton.__version__) < version.parse('2.2.0'):\n import warnings\n warnings.warn(\n \"Triton<2.2.0 detected for running this kernel, \"\n \"which is known to have some weird compiler issues (refer to https://github.com/openai/triton/issues/2852) \"\n \"that lead to significant precision loss. \"\n \"We've add some initial condition checks to resolve this, sadly at the sacrifice of the speed. \"\n \"For optimal performance, it is recommended to install Triton>=2.2.0 (if possible).\"\n )\n CHECK = True\n\n grid = (NV, NK, batch_size * n_heads)\n fused_chunk_retention_fwd_kernel[grid](\n q, k, v, o, initial_state, final_state,\n q.stride(1), q.stride(2), q.stride(3),\n v.stride(1), v.stride(2), v.stride(3),\n batch_size, n_heads, seq_len, scale,\n BT=BT, DK=d_head_qk, DV=d_head_v, BK=BK, BV=BV,\n USE_INITIAL_STATE=initial_state is not None,\n STORE_FINAL_STATE=output_final_state,\n CHECK=CHECK,\n num_warps=num_warps,\n num_stages=num_stages\n )\n\n o = o.sum(0)\n ctx.save_for_backward(q, k, v, initial_state)\n ctx.CHECK = CHECK\n return o.to(q.dtype), final_state\n\n @staticmethod\n @custom_bwd\n def backward(ctx, do, d_final_state=None):\n q, k, v, initial_state = ctx.saved_tensors\n batch_size, n_heads, seq_len, d_head_qk = q.shape\n d_head_v = v.shape[-1]\n scale = d_head_qk ** -0.5\n\n BT = 64\n BK, BV = min(triton.next_power_of_2(d_head_qk), 64), min(triton.next_power_of_2(d_head_v), 64)\n NK, NV = triton.cdiv(d_head_qk, BK), triton.cdiv(d_head_v, BV)\n num_stages = 1\n num_warps = 4\n\n dq = q.new_empty(NV, batch_size, n_heads, seq_len, d_head_qk)\n dk = q.new_empty(NV, batch_size, n_heads, seq_len, d_head_qk)\n dv = q.new_empty(NK, batch_size, n_heads, seq_len, d_head_v)\n grid = (NV, NK, batch_size * n_heads)\n\n fused_chunk_retention_bwd_kernel[grid](\n q, k, v, do, dq, dk, dv, initial_state,\n q.stride(1), q.stride(2), q.stride(3),\n v.stride(1), v.stride(2), v.stride(3),\n batch_size, n_heads, seq_len, scale,\n BT=BT, DK=d_head_qk, DV=d_head_v, BK=BK, BV=BV,\n USE_INITIAL_STATE=initial_state is not None,\n CHECK=ctx.CHECK,\n num_warps=num_warps,\n num_stages=num_stages\n )\n dq = dq.sum(0)\n dk = dk.sum(0)\n dv = dv.sum(0)\n return dq.to(q.dtype), dk.to(k.dtype), dv.to(v.dtype), None, None\n\n\ndef fused_chunk_retention(\n q: torch.Tensor,\n k: torch.Tensor,\n v: torch.Tensor,\n initial_state: torch.Tensor = None,\n output_final_state: bool = False\n) -> Tuple[torch.Tensor, torch.Tensor]:\n if initial_state is not None:\n initial_state = initial_state.detach()\n o, final_state = FusedChunkRetentionFunction.apply(q, k, v, initial_state, output_final_state)\n return o, final_state\n", - "description_1": "Use triton language to implement a fused chunk retention forward and backward kernel for a transformer model. The forward kernel takes 20 parameters: q, k, v, o, initial_state, final_state, s_qk_h, s_qk_t, s_qk_d, s_vo_h, s_vo_t, s_vo_d, B, H, T, scale, BT, BK, BV, DK, DV, USE_INITIAL_STATE, STORE_FINAL_STATE, CHECK. The backward kernel takes 21 parameters: q, k, v, do, dq, dk, dv, initial_state, s_qk_h, s_qk_t, s_qk_d, s_vo_h, s_vo_t, s_vo_d, B, H, T, scale, BT, BK, BV, DK, DV, USE_INITIAL_STATE, CHECK. The kernels perform operations on blocks of data to compute attention scores and gradients efficiently.", - "description_2": "Use triton language to create a fused chunk retention function for a transformer model, which includes both forward and backward passes. The function should handle input tensors q, k, v, and optionally initial_state, and return the output tensor and final_state. The function should be optimized for performance using Triton's grid and block mechanisms.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\nfrom torch.cuda.amp import custom_bwd, custom_fwd\n\n@triton.jit\ndef parallel_retention_fwd_kernel(\n q, k, v, o, s_qk_h, s_qk_t, s_qk_d,\n s_vo_h, s_vo_t, s_vo_d, B, H, T, scale,\n BTL: tl.constexpr, BTS: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr,\n DK: tl.constexpr, DV: tl.constexpr,\n):\n i_kv, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n NV = tl.cdiv(DV, BV)\n i_k = i_kv // (NV)\n i_v = i_kv % (NV)\n i_h = i_bh % H\n b_b = tl.math.log2(1 - tl.math.pow(2, -5 - i_h * 1.0))\n o_k = tl.arange(0, BTS)\n d_h = tl.math.exp2((BTS - o_k) * b_b)\n\n p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (T, DK),\n (s_qk_t, s_qk_d), (i_c * BTL, i_k * BK), (BTL, BK), (1, 0))\n p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (DK, T),\n (s_qk_d, s_qk_t), (i_k * BK, 0), (BK, BTS), (0, 1))\n p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, DV),\n (s_vo_t, s_vo_d), (0, i_v * BV), (BTS, BV), (1, 0))\n\n b_q = tl.load(p_q, boundary_check=(0, 1))\n b_q = (b_q * scale).to(b_q.dtype)\n b_o = tl.zeros([BTL, BV], dtype=tl.float32)\n\n for _ in range(0, i_c * BTL, BTS):\n b_k = tl.load(p_k, boundary_check=(0, 1))\n b_v = tl.load(p_v, boundary_check=(0, 1))\n b_s = tl.dot(b_q, (b_k), allow_tf32=False) * d_h[None, :]\n b_o = b_o * tl.math.exp2(b_b * BTS)\n b_o = b_o + tl.dot(b_s.to(b_v.dtype), b_v, allow_tf32=False)\n p_k = tl.advance(p_k, (0, BTS))\n p_v = tl.advance(p_v, (BTS, 0))\n\n tl.debug_barrier()\n o_q = tl.arange(0, BTL)\n d_q = tl.math.exp2(tl.arange(0, BTL) * b_b)\n b_o *= d_q[:, None]\n\n o_k = tl.arange(0, BTS)\n p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (DK, T),\n (s_qk_d, s_qk_t), (i_k * BK, i_c * BTL), (BK, BTS), (0, 1))\n p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, DV),\n (s_vo_t, s_vo_d), (i_c * BTL, i_v * BV), (BTS, BV), (1, 0))\n\n for _ in range(i_c * BTL, (i_c + 1) * BTL, BTS):\n b_k = tl.load(p_k, boundary_check=(0, 1))\n b_v = tl.load(p_v, boundary_check=(0, 1))\n m_s = o_q[:, None] >= o_k[None, :]\n d_s = tl.where(m_s, tl.math.exp2(\n (o_q[:, None] - o_k[None, :]) * b_b), 0)\n b_s = tl.dot(b_q, b_k, allow_tf32=False) * d_s\n b_o += tl.dot(b_s.to(b_q.dtype), b_v, allow_tf32=False)\n\n p_k = tl.advance(p_k, (0, BTS))\n p_v = tl.advance(p_v, (BTS, 0))\n o_k += BTS\n\n p_o = tl.make_block_ptr(o + (i_bh + B * H * i_k) * s_vo_h, (T, DV),\n (s_vo_t, s_vo_d), (i_c*BTL, i_v*BV), (BTL, BV), (1, 0))\n tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))\n\n\n@triton.jit\ndef parallel_retention_bwd_kernel(\n q, k, v, do, dq, dk, dv, s_qk_h, s_qk_t, s_qk_d, s_vo_h,\n s_vo_t, s_vo_d, B, H, T, scale,\n BTL: tl.constexpr, BTS: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr,\n DK: tl.constexpr, DV: tl.constexpr,\n):\n i_kv, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n NV = tl.cdiv(DV, BV)\n i_k = i_kv // (NV)\n i_v = i_kv % (NV)\n i_h = i_bh % H\n _parallel_retention_bwd_dq(\n i_bh, i_c, i_k, i_v, i_h,\n k, v, do, dq, s_qk_h, s_qk_t, s_qk_d, s_vo_h,\n s_vo_t, s_vo_d, B, H, T, scale, BTL=BTL, BTS=BTS, BK=BK, BV=BV, DK=DK, DV=DV\n )\n tl.debug_barrier()\n _parallel_retention_bwd_dkv(\n i_bh, i_c, i_k, i_v, i_h,\n q, k, v, do, dk, dv, s_qk_h, s_qk_t, s_qk_d, s_vo_h,\n s_vo_t, s_vo_d, B, H, T, scale, BTL, BTS, BK, BV, DK, DV\n )\n\n\nclass ParallelRetentionFunction(torch.autograd.Function):\n @staticmethod\n @custom_fwd\n def forward(ctx, q, k, v):\n BTL, BTS = 128, 32\n assert BTL % BTS == 0\n BK = min(128, triton.next_power_of_2(k.shape[-1]))\n BV = min(128, triton.next_power_of_2(v.shape[-1]))\n batch_size, n_heads, seq_len, d_head_qk = q.shape\n d_head_v = v.shape[-1]\n num_stages = 3 if d_head_qk <= 64 else 2\n num_warps = 4\n NK = triton.cdiv(d_head_qk, BK)\n NV = triton.cdiv(d_head_v, BV)\n\n grid = (NK * NV, triton.cdiv(seq_len, BTL), batch_size * n_heads)\n scale = d_head_qk ** -0.5\n o = torch.empty(NK, batch_size, n_heads, seq_len,\n d_head_v, dtype=q.dtype, device=q.device)\n parallel_retention_fwd_kernel[grid](\n q, k, v, o,\n q.stride(1), q.stride(2), q.stride(3),\n v.stride(1), v.stride(2), v.stride(3),\n batch_size, n_heads, seq_len, scale,\n BTL=BTL, BTS=BTS, BK=BK, BV=BV, DK=d_head_qk, DV=d_head_v,\n num_warps=num_warps,\n num_stages=num_stages\n )\n ctx.save_for_backward(q, k, v)\n return o.sum(0).to(q.dtype)\n\n @staticmethod\n @custom_bwd\n def backward(ctx, do):\n q, k, v = ctx.saved_tensors\n BTL, BTS = 64, 32\n assert BTL % BTS == 0\n BK = min(128, triton.next_power_of_2(k.shape[-1]))\n BV = min(128, triton.next_power_of_2(v.shape[-1]))\n batch_size, n_heads, seq_len, d_head_qk = q.shape\n d_head_v = v.shape[-1]\n num_stages = 3 if d_head_qk <= 64 else 2\n num_warps = 4\n NK = triton.cdiv(d_head_qk, BK)\n NV = triton.cdiv(d_head_v, BV)\n grid = (NK * NV, triton.cdiv(seq_len, BTL), batch_size * n_heads)\n scale = d_head_qk ** -0.5\n\n dq = torch.empty(NV, batch_size, n_heads, seq_len,\n d_head_qk, dtype=q.dtype, device=q.device)\n dk = torch.empty(NV, batch_size, n_heads, seq_len,\n d_head_qk, dtype=q.dtype, device=q.device)\n dv = torch.empty(NK, batch_size, n_heads, seq_len,\n d_head_v, dtype=q.dtype, device=q.device)\n\n parallel_retention_bwd_kernel[grid](\n q, k, v, do, dq, dk, dv,\n q.stride(1), q.stride(2), q.stride(3),\n v.stride(1), v.stride(2), v.stride(3),\n batch_size, n_heads, seq_len, scale,\n BTL=BTL, BTS=BTS, BK=BK, BV=BV, DK=d_head_qk, DV=d_head_v,\n num_warps=num_warps,\n num_stages=num_stages\n )\n\n return dq.sum(0).to(q.dtype), dk.sum(0).to(k.dtype), dv.sum(0).to(v.dtype)\n\n\nparallel_retention = ParallelRetentionFunction.apply\n", - "description_1": "Use triton language to implement parallel retention mechanism with four kernels: forward and backward kernels that handle the parallelism in sequence and head dimensions. The kernels perform operations on query, key, value, and output tensors with multiple constant parameters for block size and dimensionality. Each kernel requires careful memory access through block pointers, cumulative decay application, and dot products for matrix operations. The backward kernel also calculates gradients with respect to input tensors.", - "description_2": "Use triton language to create a parallel retention mechanism for handling forward and backward operations in sequence and head dimensions with careful memory and computational management.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\nfrom torch import Tensor\n\n@triton.jit\ndef fused_recurrent_rwkv4_forward_kernel(\n # W\n w_ptr,\n w_s_c,\n # U\n u_ptr,\n u_s_c,\n # K\n k_ptr,\n k_s_b,\n k_s_t,\n k_s_c,\n # V\n v_ptr,\n v_s_b,\n v_s_t,\n v_s_c,\n # State\n state_ptr,\n state_s_b,\n state_s_abe,\n state_s_c,\n # WKV\n wkv_ptr,\n wkv_s_b,\n wkv_s_t,\n wkv_s_c,\n # Output state\n state_out_ptr,\n state_out_s_b,\n state_out_s_abe,\n state_out_s_t,\n state_out_s_c,\n # Params\n chans,\n tsz,\n BLOCK_SIZE_C: tl.constexpr,\n):\n # Parallelize over the batch dimension.\n b_idx = tl.program_id(0)\n c_idx = tl.program_id(1)\n\n cs = (c_idx * BLOCK_SIZE_C) + tl.arange(0, BLOCK_SIZE_C)\n cmask = cs < chans\n\n # Pointers to the batch (and possibly channel) for the input tensors.\n k_ptr = k_ptr + b_idx * k_s_b\n v_ptr = v_ptr + b_idx * v_s_b\n alpha_ptr = state_ptr + b_idx * state_s_b\n beta_ptr = state_ptr + b_idx * state_s_b + state_s_abe\n eps_ptr = state_ptr + b_idx * state_s_b + 2 * state_s_abe\n\n # Pointers to the batch (and possibly channel) for the output tensors.\n wkv_ptr = wkv_ptr + b_idx * wkv_s_b\n alpha_out_ptr = state_out_ptr + b_idx * state_out_s_b\n beta_out_ptr = state_out_ptr + b_idx * state_out_s_b + state_out_s_abe\n eps_out_ptr = state_out_ptr + b_idx * state_out_s_b + 2 * state_out_s_abe\n\n # Loads parameters.\n alpha = tl.load(alpha_ptr + cs * state_s_c, mask=cmask).to(tl.float32)\n beta = tl.load(beta_ptr + cs * state_s_c, mask=cmask).to(tl.float32)\n eps = tl.load(eps_ptr + cs * state_s_c, mask=cmask).to(tl.float32)\n w = tl.load(w_ptr + cs * w_s_c, mask=cmask).to(tl.float32)\n u = tl.load(u_ptr + cs * u_s_c, mask=cmask).to(tl.float32)\n\n for t in range(tsz):\n kt = tl.load(k_ptr + t * k_s_t + cs * k_s_c, mask=cmask).to(tl.float32)\n vt = tl.load(v_ptr + t * v_s_t + cs * v_s_c, mask=cmask).to(tl.float32)\n\n ukt = u + kt\n tau = tl.maximum(ukt, eps)\n e1a = tl.exp(eps - tau)\n e2a = tl.exp(ukt - tau)\n wkv = (e1a * alpha + e2a * vt) / (e1a * beta + e2a)\n tl.store(wkv_ptr + t * wkv_s_t + cs * wkv_s_c, wkv, mask=cmask)\n\n w_eps = w + eps\n eps = tl.maximum(w_eps, kt)\n e1b = tl.exp(w_eps - eps)\n e2b = tl.exp(kt - eps)\n alpha = e1b * alpha + e2b * vt\n beta = e1b * beta + e2b\n tl.store(alpha_out_ptr + t * state_out_s_t + cs * state_out_s_c, alpha, mask=cmask)\n tl.store(beta_out_ptr + t * state_out_s_t + cs * state_out_s_c, beta, mask=cmask)\n tl.store(eps_out_ptr + t * state_out_s_t + cs * state_out_s_c, eps, mask=cmask)\n\n\ndef fused_recurrent_rwkv4_forward(\n w: Tensor,\n u: Tensor,\n k: Tensor,\n v: Tensor,\n state: Tensor,\n) -> tuple[Tensor, Tensor]:\n (bsz, tsz, chans) = k.shape\n\n # New tensors to output.\n wkvs = k.new_empty(bsz, tsz, chans)\n state_out = k.new_empty(bsz, 3, tsz, chans)\n\n # Constants.\n block_size_c = get_block_size_c(chans)\n\n def grid(meta: dict[str, Any]) -> tuple[int, ...]:\n return (bsz, triton.cdiv(chans, meta[\"BLOCK_SIZE_C\"]))\n\n fused_recurrent_rwkv4_forward_kernel[grid](\n # W\n w,\n w.stride(0),\n # U\n u,\n u.stride(0),\n # K\n k,\n k.stride(0),\n k.stride(1),\n k.stride(2),\n # V\n v,\n v.stride(0),\n v.stride(1),\n v.stride(2),\n # State\n state,\n state.stride(0),\n state.stride(1),\n state.stride(3),\n # WKV\n wkvs,\n wkvs.stride(0),\n wkvs.stride(1),\n wkvs.stride(2),\n # Output state\n state_out,\n state_out.stride(0),\n state_out.stride(1),\n state_out.stride(2),\n state_out.stride(3),\n # Params\n chans,\n tsz,\n BLOCK_SIZE_C=block_size_c,\n )\n\n state_out = torch.cat((state, state_out), dim=2)\n\n return wkvs, state_out\n\n\n@triton.jit\ndef fused_recurrent_rwkv4_backward_kernel(\n # W\n w_ptr,\n w_s_c,\n # U\n u_ptr,\n u_s_c,\n # K\n k_ptr,\n k_s_b,\n k_s_t,\n k_s_c,\n # V\n v_ptr,\n v_s_b,\n v_s_t,\n v_s_c,\n # State\n state_ptr,\n state_s_b,\n state_s_abe,\n state_s_t,\n state_s_c,\n # WKV grad\n gwkv_ptr,\n gwkv_s_b,\n gwkv_s_t,\n gwkv_s_c,\n # Output state grad\n gstate_out_ptr,\n gstate_out_s_b,\n gstate_out_s_abe,\n gstate_out_s_c,\n # W grad\n gw_ptr,\n gw_s_c,\n # U grad\n gu_ptr,\n gu_s_c,\n # K grad\n gk_ptr,\n gk_s_b,\n gk_s_t,\n gk_s_c,\n # V grad\n gv_ptr,\n gv_s_b,\n gv_s_t,\n gv_s_c,\n # State grad\n gstate_ptr,\n gstate_s_b,\n gstate_s_abe,\n gstate_s_c,\n # Params\n tsz,\n chans,\n BLOCK_SIZE_C: tl.constexpr,\n):\n # Parallelize over the batch dimension.\n b_idx = tl.program_id(0)\n c_idx = tl.program_id(1)\n\n cs = (c_idx * BLOCK_SIZE_C) + tl.arange(0, BLOCK_SIZE_C)\n cmask = cs < chans\n\n # Pointers to the batch (and possibly channel) for the input tensors.\n k_ptr = k_ptr + b_idx * k_s_b\n v_ptr = v_ptr + b_idx * v_s_b\n alpha_ptr = state_ptr + b_idx * state_s_b\n beta_ptr = state_ptr + b_idx * state_s_b + state_s_abe\n eps_ptr = state_ptr + b_idx * state_s_b + 2 * state_s_abe\n\n # Pointers to the batch (and possibly channel) for the output tensors.\n gk_ptr = gk_ptr + b_idx * gk_s_b\n gv_ptr = gv_ptr + b_idx * gv_s_b\n\n # Pointers to gradients which were recieved by the function.\n gwkv_ptr = gwkv_ptr + b_idx * gwkv_s_b\n galpha_out_ptr = gstate_out_ptr + b_idx * gstate_out_s_b\n gbeta_out_ptr = gstate_out_ptr + b_idx * gstate_out_s_b + gstate_out_s_abe\n geps_out_ptr = gstate_out_ptr + b_idx * gstate_out_s_b + 2 * gstate_out_s_abe\n\n # Loads parameters.\n galpha = tl.load(galpha_out_ptr + gstate_out_s_c * cs, mask=cmask).to(tl.float32)\n gbeta = tl.load(gbeta_out_ptr + gstate_out_s_c * cs, mask=cmask).to(tl.float32)\n geps = tl.load(geps_out_ptr + gstate_out_s_c * cs, mask=cmask).to(tl.float32)\n w = tl.load(w_ptr + w_s_c * cs, mask=cmask).to(tl.float32)\n u = tl.load(u_ptr + u_s_c * cs, mask=cmask).to(tl.float32)\n\n # Gradient accumulators.\n gw = tl.zeros_like(w)\n gu = tl.zeros_like(u)\n\n alpha_prev = tl.load(alpha_ptr + tsz * state_s_t + state_s_c * cs, mask=cmask).to(tl.float32)\n beta_prev = tl.load(beta_ptr + tsz * state_s_t + state_s_c * cs, mask=cmask).to(tl.float32)\n eps_prev = tl.load(eps_ptr + tsz * state_s_t + state_s_c * cs, mask=cmask).to(tl.float32)\n\n for t in range(tsz):\n tc = tsz - t - 1\n\n kt = tl.load(k_ptr + tc * k_s_t + k_s_c * cs, mask=cmask).to(tl.float32)\n vt = tl.load(v_ptr + tc * v_s_t + v_s_c * cs, mask=cmask).to(tl.float32)\n\n alpha_curr = alpha_prev\n beta_curr = beta_prev\n eps_curr = eps_prev\n\n alpha_prev = tl.load(alpha_ptr + tc * state_s_t + state_s_c * cs, mask=cmask).to(tl.float32)\n beta_prev = tl.load(beta_ptr + tc * state_s_t + state_s_c * cs, mask=cmask).to(tl.float32)\n eps_prev = tl.load(eps_ptr + tc * state_s_t + state_s_c * cs, mask=cmask).to(tl.float32)\n\n ukt = u + kt\n tau = tl.maximum(ukt, eps_prev)\n e1 = tl.exp(eps_prev - tau)\n e2 = tl.exp(ukt - tau)\n\n euke = tl.exp(ukt + eps_prev - 2 * tau)\n\n denom = e1 * beta_prev + e2\n denom_sq = denom * denom\n\n gwkvt = tl.load(gwkv_ptr + tc * gwkv_s_t + gwkv_s_c * cs, mask=cmask).to(tl.float32)\n\n # Backpropagates wkv gradients.\n guk = gwkvt * e2 * (e1 * beta_prev * vt - e1 * alpha_prev) / denom_sq\n gu += guk\n gk = guk\n gv = gwkvt * e2 / denom\n\n galpha_wkv = gwkvt * e1 / denom\n gbeta_wkv = -gwkvt * e1 * (e2 * vt + e1 * alpha_prev) / denom_sq\n geps_wkv_denom = e1 * beta_prev + e2\n geps_wkv = gwkvt * euke * (alpha_prev - vt * beta_prev) / (geps_wkv_denom * geps_wkv_denom)\n\n e1 = tl.exp(w + eps_prev - eps_curr)\n e2 = tl.exp(kt - eps_curr)\n\n # Backpropagates alpha gradients.\n galpha_we = galpha * e1 * alpha_prev\n gw += galpha_we\n gk += galpha * e2 * vt\n gv += galpha * e2\n geps += galpha * -alpha_curr\n\n # Backpropagates beta gradients.\n gbeta_we = gbeta * e1 * beta_prev\n gw += gbeta_we\n gk += gbeta * e2\n geps += gbeta * -beta_curr\n\n # Backpropagates epsilon gradients.\n geps_mask = w + eps_prev > kt\n geps_we = tl.where(geps_mask, geps, tl.zeros_like(geps))\n gw += geps_we\n gk += tl.where(geps_mask, tl.zeros_like(geps), geps)\n\n # Stores the gradients for k and v.\n tl.store(gk_ptr + tc * gk_s_t + gk_s_c * cs, gk, mask=cmask)\n tl.store(gv_ptr + tc * gv_s_t + gv_s_c * cs, gv, mask=cmask)\n\n # Computes new gradients for alpha and beta.\n galpha = galpha * e1 + galpha_wkv\n gbeta = gbeta * e1 + gbeta_wkv\n geps = galpha_we + gbeta_we + geps_we + geps_wkv\n\n # Stores final gradients for alpha and beta.\n galpha_ptr = gstate_ptr + b_idx * gstate_s_b\n gbeta_ptr = gstate_ptr + b_idx * gstate_s_b + gstate_s_abe\n geps_ptr = gstate_ptr + b_idx * gstate_s_b + 2 * gstate_s_abe\n tl.store(galpha_ptr + gstate_s_c * cs, galpha, mask=cmask)\n tl.store(gbeta_ptr + gstate_s_c * cs, gbeta, mask=cmask)\n tl.store(geps_ptr + gstate_s_c * cs, geps, mask=cmask)\n\n # Stores final gradients for w and u.\n gw_temp = tl.load(gw_ptr + gw_s_c * cs, mask=cmask).to(tl.float32)\n gw_temp += gw\n tl.store(gw_ptr + gw_s_c * cs, gw_temp, mask=cmask)\n gu_temp = tl.load(gu_ptr + gu_s_c * cs, mask=cmask).to(tl.float32)\n gu_temp += gu\n tl.store(gu_ptr + gu_s_c * cs, gu_temp, mask=cmask)\n\n\ndef fused_recurrent_rwkv4_backward(\n w: Tensor,\n u: Tensor,\n k: Tensor,\n v: Tensor,\n state: Tensor,\n grad_wkv: Tensor,\n grad_state: Tensor,\n) -> tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:\n bsz, tsz, chans = k.shape\n\n gw = torch.zeros_like(w) # New tensors to output.\n gu = torch.zeros_like(u)\n gk = torch.empty_like(k)\n gv = torch.empty_like(v)\n gstate = k.new_empty(bsz, 3, 1, chans)\n\n block_size_c = get_block_size_c(chans) # Constants.\n\n def grid(meta: dict[str, Any]) -> tuple[int, ...]:\n return (bsz, triton.cdiv(chans, meta[\"BLOCK_SIZE_C\"]))\n\n fused_recurrent_rwkv4_backward_kernel[grid](\n # W\n w,\n w.stride(0),\n # U\n u,\n u.stride(0),\n # K\n k,\n k.stride(0),\n k.stride(1),\n k.stride(2),\n # V\n v,\n v.stride(0),\n v.stride(1),\n v.stride(2),\n # State\n state,\n state.stride(0),\n state.stride(1),\n state.stride(2),\n state.stride(3),\n # WKV grad\n grad_wkv,\n grad_wkv.stride(0),\n grad_wkv.stride(1),\n grad_wkv.stride(2),\n # Output state grad\n grad_state,\n grad_state.stride(0),\n grad_state.stride(1),\n grad_state.stride(3),\n # W grad\n gw,\n gw.stride(0),\n # U grad\n gu,\n gu.stride(0),\n # K grad\n gk,\n gk.stride(0),\n gk.stride(1),\n gk.stride(2),\n # V grad\n gv,\n gv.stride(0),\n gv.stride(1),\n gv.stride(2),\n # State grad\n gstate,\n gstate.stride(0),\n gstate.stride(1),\n gstate.stride(3),\n # Params\n tsz,\n chans,\n BLOCK_SIZE_C=block_size_c,\n )\n\n return gw, gu, gk, gv, gstate\n", - "description_1": "Use triton language to implement a fused recurrent RWKV forward and backward kernel. The forward kernel takes 25 parameters: pointers to tensors w, u, k, v, state, wkv, state_out, and their respective strides, along with the number of channels, time size, and block size. It computes the RWKV forward pass by iterating over the time dimension and updating the state and wkv tensors. The backward kernel takes 34 parameters: pointers to tensors w, u, k, v, state, gwkv, gstate_out, gw, gu, gk, gv, gstate, and their respective strides, along with the number of channels, time size, and block size. It computes the gradients for the RWKV backward pass by iterating over the time dimension in reverse and updating the gradient tensors.", - "description_2": "Use triton language to create a fused recurrent RWKV forward kernel with 25 parameters for computing the forward pass, and a backward kernel with 34 parameters for computing the backward pass, both iterating over the time dimension.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef chunk_rwkv6_fwd_kernel_cum(\n s,\n o,\n o_minus_s,\n s_s_h,\n s_s_t,\n s_s_d,\n T: tl.constexpr,\n S: tl.constexpr,\n BT: tl.constexpr,\n BS: tl.constexpr\n):\n i_s, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n o_i = tl.arange(0, BT)\n m_s = tl.where(o_i[:, None] >= o_i[None, :], 1., 0.)\n\n p_s = tl.make_block_ptr(s + i_bh * s_s_h, (T, S), (s_s_t, s_s_d), (i_t * BT, i_s * BS), (BT, BS), (1, 0))\n p_o = tl.make_block_ptr(o + i_bh * s_s_h, (T, S), (s_s_t, s_s_d), (i_t * BT, i_s * BS), (BT, BS), (1, 0))\n p_o_minus_s = tl.make_block_ptr(o_minus_s + i_bh * s_s_h, (T, S), (s_s_t, s_s_d), (i_t * BT, i_s * BS), (BT, BS), (1, 0))\n # [BT, BS]\n b_s = tl.load(p_s, boundary_check=(0, 1)).to(tl.float32)\n b_o = tl.dot(m_s, b_s, allow_tf32=False)\n tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))\n tl.store(p_o_minus_s, (b_o - b_s).to(p_o_minus_s.dtype.element_ty), boundary_check=(0, 1))\n\n@triton.jit\ndef post_process_grad(\n q,\n k,\n v,\n u,\n do,\n dk,\n dq,\n du,\n scale,\n s_k_h,\n s_k_t,\n s_k_d,\n s_v_h,\n s_v_t,\n s_v_d,\n H,\n T: tl.constexpr,\n BT: tl.constexpr,\n K: tl.constexpr,\n V: tl.constexpr,\n BK: tl.constexpr,\n BV: tl.constexpr,\n):\n i_t, i_bh = tl.program_id(0), tl.program_id(1)\n i_h = i_bh % H\n\n # Note that BK = tl.next_power_of_2(K), BV = tl.next_power_of_2(V)\n p_q = tl.make_block_ptr(q + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, 0), (BT, BK), (1, 0))\n p_dq = tl.make_block_ptr(dq + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, 0), (BT, BK), (1, 0))\n p_k = tl.make_block_ptr(k + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, 0), (BT, BK), (1, 0))\n p_dk = tl.make_block_ptr(dk + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, 0), (BT, BK), (1, 0))\n p_du = tl.make_block_ptr(du + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, 0), (BT, BK), (1, 0))\n p_v = tl.make_block_ptr(v + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, 0), (BT, BV), (1, 0))\n p_do = tl.make_block_ptr(do + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, 0), (BT, BV), (1, 0))\n p_u = tl.make_block_ptr(u + i_h * K, (K,), (1,), (0,), (BK,), (0,))\n\n b_q = tl.load(p_q, boundary_check=(0, 1))\n b_k = tl.load(p_k, boundary_check=(0, 1))\n b_v = tl.load(p_v, boundary_check=(0, 1))\n b_do = tl.load(p_do, boundary_check=(0, 1))\n b_u = tl.load(p_u, boundary_check=(0,))\n\n b_vdo = tl.sum(b_v * b_do, axis=1)\n b_du = b_vdo[:, None] * b_k * b_q * scale\n b_dq = b_vdo[:, None] * b_k * b_u[None, :] * scale\n b_dk = b_vdo[:, None] * b_q * b_u[None, :] * scale\n\n b_dq += tl.load(p_dq, boundary_check=(0, 1))\n tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))\n\n b_dk += tl.load(p_dk, boundary_check=(0, 1))\n tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))\n\n tl.store(p_du, b_du.to(p_du.dtype.element_ty), boundary_check=(0, 1))\n\nclass ChunkRWKV6Function(torch.autograd.Function):\n\n @staticmethod\n def forward(ctx, r, k, v, g, u, scale, initial_state, output_final_state, checkpoint_level):\n q = r # alias\n B, H, T, K, V = *q.shape, v.shape[-1]\n BT, BC = 64, 16\n BK = min(64, triton.next_power_of_2(K))\n BV = min(64, triton.next_power_of_2(V))\n NT, NC = triton.cdiv(T, BT), triton.cdiv(BT, BC)\n NK = triton.cdiv(K, BK)\n NV = triton.cdiv(V, BV)\n num_warps = 4 if BK == 64 else 2\n num_stages = 1\n\n def fwd_inner(q, k, v, g, B, H, T, K, V, BT, BK, BV, NT, h0=None, ht=None):\n NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)\n h = q.new_empty(B, H, NT * K, V)\n grid = (NV, NK, B * H)\n chunk_rwkv6_fwd_kernel_h[grid](\n k, v, g, h, h0, ht,\n k.stride(1), k.stride(2), k.stride(3),\n v.stride(1), v.stride(2), v.stride(3),\n h.stride(1), h.stride(2), h.stride(3),\n T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT,\n USE_INITIAL_STATE=h0 is not None,\n STORE_FINAL_STATE=ht is not None,\n num_warps=num_warps,\n num_stages=num_stages\n )\n return h\n\n final_state = None\n if output_final_state:\n final_state = q.new_empty(B, H, K, V, dtype=torch.float)\n\n g_org, g, gs = g, torch.empty_like(g, dtype=torch.float), torch.empty_like(g, dtype=torch.float)\n def grid(meta): return ((triton.cdiv(meta['S'], meta['BS']), NT, B * H))\n # keep cummulative normalizer in fp32\n chunk_rwkv6_fwd_kernel_cum[grid](\n g_org, g, gs,\n g.stride(1), g.stride(2), g.stride(3),\n T=T, S=K, BT=BT\n )\n h = fwd_inner(\n q=q, k=k, v=v, g=g,\n B=B, H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT,\n h0=initial_state if initial_state is not None else None,\n ht=final_state if final_state is not None else None\n )\n A = q.new_zeros(NK, B, H, T, BT)\n grid = (NK, NT * NC * NC, B * H)\n chunk_rwkv6_fwd_kernel_intra[grid](\n q, k, g, gs, u, A,\n k.stride(1), k.stride(2), k.stride(3),\n scale,\n H=H, T=T, K=K, BT=BT, BC=BC, BK=BK, NC=NC, DK=K,\n num_warps=num_warps,\n num_stages=num_stages\n )\n A = A.sum(0, dtype=A.dtype)\n o = torch.empty_like(v)\n\n grid = (NV, NT, B * H)\n chunk_rwkv6_fwd_kernel_inter[grid](\n q, v, gs, h, o, A,\n k.stride(1), k.stride(2), k.stride(3),\n v.stride(1), v.stride(2), v.stride(3),\n h.stride(1), h.stride(2), h.stride(3),\n scale,\n T=T, K=K, V=V, BT=BT, BK=BK, BV=BV,\n num_warps=num_warps,\n num_stages=num_stages\n )\n\n if checkpoint_level > 1:\n del h\n h, initial_state = None, None\n del g, gs\n return o, final_state\n\n @staticmethod\n def backward(ctx, do, dht=None):\n q, k, v, g, u, h, initial_state, A = ctx.saved_tensors\n B, H, T, K, V = *q.shape, v.shape[-1]\n BT, BC = ctx.BT, 16\n BK = min(64, triton.next_power_of_2(K))\n BV = min(64, triton.next_power_of_2(V))\n NT, NC = triton.cdiv(T, BT), triton.cdiv(BT, BC)\n NK = triton.cdiv(K, BK)\n num_warps = 4 if BK == 64 else 2\n num_stages = 1\n\n def fwd_inner(q, k, v, g, B, H, T, K, V, BT, BK, BV, NT, h0=None, ht=None):\n NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)\n h = q.new_empty(B, H, NT * K, V)\n grid = (NV, NK, B * H)\n chunk_rwkv6_fwd_kernel_h[grid](\n k, v, g, h, h0, ht,\n k.stride(1), k.stride(2), k.stride(3),\n v.stride(1), v.stride(2), v.stride(3),\n h.stride(1), h.stride(2), h.stride(3),\n T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT,\n USE_INITIAL_STATE=h0 is not None,\n STORE_FINAL_STATE=ht is not None,\n num_warps=num_warps,\n num_stages=num_stages\n )\n return h\n\n def bwd_inner(q, g, gs, h0, do, B, H, T, K, V, BT, BK, BV, NT, scale):\n NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)\n dh = q.new_empty(B, H, NT * K, V)\n dh0 = torch.empty_like(h0) if h0 is not None else None\n grid = (NK, NV, B * H)\n chunk_rwkv6_bwd_kernel_dh[grid](\n q, g, gs, do, dh, dh0,\n q.stride(1), q.stride(2), q.stride(3),\n do.stride(1), do.stride(2), do.stride(3),\n dh.stride(1), dh.stride(2), dh.stride(3),\n scale,\n T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT,\n USE_INITIAL_STATE=h0 is not None,\n num_warps=num_warps,\n num_stages=num_stages\n )\n return dh, dh0\n\n # recompute cumulative log decays.\n g_org, g, gs = g, torch.empty_like(g, dtype=torch.float), torch.empty_like(g, dtype=torch.float)\n def grid(meta): return ((triton.cdiv(meta['S'], meta['BS']), NT, B * H))\n # keep cummulative normalizer in fp32\n chunk_rwkv6_fwd_kernel_cum[grid](\n g_org, g, gs,\n g.stride(1), g.stride(2), g.stride(3),\n T=T, S=K, BT=BT\n )\n\n # rerun the forward pass to get h if checkpoint_level >= 1\n if ctx.checkpoint_level == 1:\n h = fwd_inner(\n q=q, k=k, v=v, g=g,\n B=B, H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT,\n h0=initial_state if initial_state is not None else None,\n ht=None\n )\n\n scale = ctx.scale\n # g, gs: torch.float32\n dh, dh0 = bwd_inner(\n q.to(torch.float), g, gs, initial_state, do.to(torch.float),\n B=B, H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT,\n scale=scale\n )\n dh, dh0 = dh.to(q), dh0.to(q)\n dq = torch.empty_like(q, dtype=torch.float)\n dk = torch.empty_like(k, dtype=torch.float)\n dv = v.new_empty(NK, *v.shape)\n dA = q.new_zeros(B, H, T, BT)\n grid = (NK, NT, B * H)\n chunk_rwkv6_bwd_kernel_inter[grid](\n k, v, h, g, gs, A, do, dh, dq, dk, dv, dA,\n k.stride(1), k.stride(2), k.stride(3),\n v.stride(1), v.stride(2), v.stride(3),\n h.stride(1), h.stride(2), h.stride(3),\n scale,\n T=T, K=K, V=V, BT=BT, BK=BK, BV=BV,\n num_warps=num_warps,\n num_stages=num_stages\n )\n dv = dv.sum(0, dtype=dv.dtype)\n grid = (NK, NT * NC, B * H)\n chunk_rwkv6_bwd_kernel_intra[grid](\n q, k, g, gs, dA, dq, dk,\n k.stride(1), k.stride(2), k.stride(3),\n T=T, K=K, BT=BT, BC=BC, BK=BK, NC=NC,\n num_warps=num_warps,\n num_stages=num_stages\n )\n\n # TODO: fuse?\n dg = (dq * q)[:, :, 1:] - (dk * k)[:, :, 0:-1]\n dg = torch.nn.functional.pad(dg, (0, 0, 0, 1, 0, 0, 0, 0), value=0)\n dg = chunk_reversed_cumsum_fwd(dg).to(g)\n # equivalent to the following pytorch code.\n # du = ((do * v).sum(-1)[..., None] * k * q * scale).sum(-2).to(u)\n # dq += ((do * v).sum(-1)[..., None] * k * scale * u[:, :, None, :])\n # dk += ((do * v).sum(-1)[..., None] * q * scale * u[:, :, None, :])\n BT = 64\n grid = (triton.cdiv(T, BT), B * H)\n du = torch.empty_like(g, dtype=torch.float)\n post_process_grad[grid](\n q, k, v, u, do, dk, dq, du, scale,\n q.stride(1), q.stride(2), q.stride(3),\n v.stride(1), v.stride(2), v.stride(3), H=H,\n T=T, BT=BT, K=K, V=V, BK=triton.next_power_of_2(K), BV=triton.next_power_of_2(V),\n num_warps=4\n )\n du = du.sum([0, 2])\n return dq.to(q), dk.to(k), dv.to(v), dg.to(g), du.to(u), None, dh0, None, None\n\ndef chunk_rwkv6(\n r: torch.Tensor,\n k: torch.Tensor,\n v: torch.Tensor,\n g: torch.Tensor,\n u: torch.Tensor,\n scale: Optional[int] = None,\n initial_state: torch.Tensor = None,\n output_final_state: bool = False,\n checkpoint_level: Optional[int] = 0\n) -> Tuple[torch.Tensor, torch.Tensor]:\n r\"\"\"\n Args:\n r (torch.Tensor):\n reception of shape `(B, H, T, K)`. Alias: q, query in linear attention.\n k (torch.Tensor):\n keys of shape `(B, H, T, K)`\n v (torch.Tensor):\n values of shape `(B, H, T, V)`\n w (torch.Tensor):\n data-dependent decays of shape `(B, H, T, K)` in log space! Alias: g.\n u (torch.Tensor):\n bonus of shape `(H, K)`\n scale (Optional[int]):\n Scale factor for the RWKV6 attention scores.\n If not provided, it will default to `1 / sqrt(K)`. Default: `None`.\n initial_state (Optional[torch.Tensor]):\n Initial state of shape `(B, H, K, V)`. Default: `None`.\n output_final_state (Optional[bool]):\n Whether to output the final state of shape `(B, H, K, V)`. Default: `False`.\n checkpoint_level (Optional[int]):\n Checkpointing level; higher values will save more memories and do more recomputations during backward.\n Default: `0`:\n - Level `0`: store forward hidden states for backprop.\n - Level `1`: recompute the forward hidden states during backward.\n \"\"\"\n assert checkpoint_level in [0, 1]\n if scale is None:\n scale = r.shape[-1] ** -0.5\n o, final_state = ChunkRWKV6Function.apply(r, k, v, g, u, scale, initial_state, output_final_state, checkpoint_level)\n return o, final_state\n", - "description_1": "Use triton language to implement several kernels for computing a forward and backward pass for the ChunkRWKV6 function in RWKV-based attention. The kernels handle cumulative operations, forward computations, backward gradient computations, and post-processing gradients.", - "description_2": "Use triton language to implement a kernel for computing cumulative operations and a kernel for post-processing gradients in a machine learning attention mechanism.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\nfrom torch.cuda.amp import custom_fwd\n\n\n@triton.jit\ndef fused_recurrent_rwkv6_fwd_kernel32(\n q, k, v, w, u, o, h0, ht, s_k_h, s_v_h, scale,\n B: tl.constexpr, H: tl.constexpr, T: tl.constexpr,\n K: tl.constexpr, V: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr,\n USE_INITIAL_STATE: tl.constexpr, STORE_FINAL_STATE: tl.constexpr, REVERSE: tl.constexpr,\n):\n i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n i_h = i_bh % H\n p_q = q + i_bh * s_k_h + i_k * BK + tl.arange(0, BK) + ((T-1) * K if REVERSE else 0)\n p_k = k + i_bh * s_k_h + i_k * BK + tl.arange(0, BK) + ((T-1) * K if REVERSE else 0)\n p_v = v + i_bh * s_v_h + i_v * BV + tl.arange(0, BV) + ((T-1) * V if REVERSE else 0)\n p_o = o + (i_bh + i_k * B * H) * s_v_h + i_v * BV + tl.arange(0, BV) + ((T-1) * V if REVERSE else 0)\n p_w = w + i_bh * s_k_h + i_k * BK + tl.arange(0, BK) + ((T-1) * K if REVERSE else 0)\n p_u = u + i_h * K + tl.arange(0, BK) + i_k * BK\n mask_bk = (i_k * BK + tl.arange(0, BK)) < K\n mask_bv = (i_v * BV + tl.arange(0, BV)) < V\n mask_kv = mask_bv[:, None] & mask_bk[None, :]\n b_h = tl.zeros([BV, BK], dtype=tl.float32)\n if USE_INITIAL_STATE:\n p_h0 = h0 + i_bh * K * V + (i_k * BK + tl.arange(0, BK)[None, :]) * V + (i_v * BV + tl.arange(0, BV)[:, None])\n b_h += tl.load(p_h0, mask=mask_kv, other=0).to(tl.float32)\n b_u = tl.load(p_u, mask=mask_bk, other=0).to(tl.float32)\n for _ in range(0, T):\n b_k = tl.load(p_k, mask=mask_bk, other=0).to(tl.float32)\n b_v = tl.load(p_v, mask=mask_bv, other=0).to(tl.float32)\n b_q = tl.load(p_q, mask=mask_bk, other=0).to(tl.float32) * scale\n b_w = tl.load(p_w, mask=mask_bk, other=0).to(tl.float32)\n b_w = tl.exp(b_w)\n b_kv = b_k[None, :] * b_v[:, None]\n b_o = (b_h + b_kv * b_u[None, :]) * b_q[None, :]\n b_o = tl.sum(b_o, axis=1)\n b_h = b_h * b_w[None, :]\n b_h += b_kv\n tl.store(p_o, b_o.to(p_o.dtype.element_ty), mask=mask_bv)\n p_q += -K if REVERSE else K\n p_k += -K if REVERSE else K\n p_o += -V if REVERSE else V\n p_v += -V if REVERSE else V\n p_w += -K if REVERSE else K\n if STORE_FINAL_STATE:\n p_ht = ht + i_bh * K * V + (i_k * BK + tl.arange(0, BK)[None, :]) * V + (i_v * BV + tl.arange(0, BV)[:, None])\n tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_kv)\n\n\n@triton.jit\ndef fused_recurrent_rwkv6_fwd_kernel(\n q, k, v, w, u, o, h0, ht, s_k_h, s_v_h, scale,\n B: tl.constexpr, H: tl.constexpr, T: tl.constexpr,\n K: tl.constexpr, V: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr,\n USE_INITIAL_STATE: tl.constexpr, STORE_FINAL_STATE: tl.constexpr, REVERSE: tl.constexpr,\n):\n TargetDType = tl.bfloat16\n i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n i_h = i_bh % H\n p_q = q + i_bh * s_k_h + i_k * BK + tl.arange(0, BK) + ((T-1) * K if REVERSE else 0)\n p_k = k + i_bh * s_k_h + i_k * BK + tl.arange(0, BK) + ((T-1) * K if REVERSE else 0)\n p_v = v + i_bh * s_v_h + i_v * BV + tl.arange(0, BV) + ((T-1) * V if REVERSE else 0)\n p_o = o + (i_bh + i_k * B * H) * s_v_h + i_v * BV + tl.arange(0, BV) + ((T-1) * V if REVERSE else 0)\n p_w = w + i_bh * s_k_h + i_k * BK + tl.arange(0, BK) + ((T-1) * K if REVERSE else 0)\n p_u = u + i_h * K + tl.arange(0, BK) + i_k * BK\n mask_bk = (i_k * BK + tl.arange(0, BK)) < K\n mask_bv = (i_v * BV + tl.arange(0, BV)) < V\n mask_kv = mask_bv[:, None] & mask_bk[None, :]\n b_h = tl.zeros([BV, BK], dtype=TargetDType)\n if USE_INITIAL_STATE:\n p_h0 = h0 + i_bh * K * V + (i_k * BK + tl.arange(0, BK)[None, :]) * V + (i_v * BV + tl.arange(0, BV)[:, None])\n b_h += tl.load(p_h0, mask=mask_kv, other=0).to(TargetDType)\n b_u = tl.load(p_u, mask=mask_bk, other=0).to(TargetDType)\n for _ in range(0, T):\n b_k = tl.load(p_k, mask=mask_bk, other=0).to(TargetDType)\n b_v = tl.load(p_v, mask=mask_bv, other=0).to(TargetDType)\n b_q = tl.load(p_q, mask=mask_bk, other=0).to(TargetDType) * scale\n b_w = tl.load(p_w, mask=mask_bk, other=0).to(TargetDType)\n b_w = tl.exp(b_w.to(tl.float32)).to(TargetDType)\n b_kv = b_k[None, :] * b_v[:, None]\n b_o = (b_h + b_kv * b_u[None, :]) * b_q[None, :]\n b_o = tl.sum(b_o, axis=1)\n b_h = b_h * b_w[None, :]\n b_h += b_kv\n tl.store(p_o, b_o.to(p_o.dtype.element_ty), mask=mask_bv)\n p_q += -K if REVERSE else K\n p_k += -K if REVERSE else K\n p_o += -V if REVERSE else V\n p_v += -V if REVERSE else V\n p_w += -K if REVERSE else K\n if STORE_FINAL_STATE:\n p_ht = ht + i_bh * K * V + (i_k * BK + tl.arange(0, BK)[None, :]) * V + (i_v * BV + tl.arange(0, BV)[:, None])\n tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_kv)\n\n\n@triton.jit\ndef fused_recurrent_rwkv6_fwd_kernel16(\n q, k, v, w, u, o, h0, ht, s_k_h, s_v_h, scale,\n B: tl.constexpr, H: tl.constexpr, T: tl.constexpr,\n K: tl.constexpr, V: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr,\n USE_INITIAL_STATE: tl.constexpr, STORE_FINAL_STATE: tl.constexpr, REVERSE: tl.constexpr,\n):\n TargetDType = tl.bfloat16\n TargetDType2 = tl.bfloat16\n i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n i_h = i_bh % H\n p_q = q + i_bh * s_k_h + i_k * BK + tl.arange(0, BK) + ((T-1) * K if REVERSE else 0)\n p_k = k + i_bh * s_k_h + i_k * BK + tl.arange(0, BK) + ((T-1) * K if REVERSE else 0)\n p_v = v + i_bh * s_v_h + i_v * BV + tl.arange(0, BV) + ((T-1) * V if REVERSE else 0)\n p_o = o + (i_bh + i_k * B * H) * s_v_h + i_v * BV + tl.arange(0, BV) + ((T-1) * V if REVERSE else 0)\n p_w = w + i_bh * s_k_h + i_k * BK + tl.arange(0, BK) + ((T-1) * K if REVERSE else 0)\n p_u = u + i_h * K + tl.arange(0, BK) + i_k * BK\n mask_bk = (i_k * BK + tl.arange(0, BK)) < K\n mask_bv = (i_v * BV + tl.arange(0, BV)) < V\n mask_kv = mask_bv[:, None] & mask_bk[None, :]\n b_h = tl.zeros([BV, BK], dtype=TargetDType)\n if USE_INITIAL_STATE:\n p_h0 = h0 + i_bh * K * V + (i_k * BK + tl.arange(0, BK)[None, :]) * V + (i_v * BV + tl.arange(0, BV)[:, None])\n b_h += tl.load(p_h0, mask=mask_kv, other=0).to(TargetDType)\n b_u = tl.load(p_u, mask=mask_bk, other=0).to(TargetDType)\n for _ in range(0, T):\n b_k = tl.load(p_k, mask=mask_bk, other=0).to(TargetDType2)\n b_v = tl.load(p_v, mask=mask_bv, other=0).to(TargetDType2)\n b_q = tl.load(p_q, mask=mask_bk, other=0).to(TargetDType2) * scale\n b_w = tl.load(p_w, mask=mask_bk, other=0).to(TargetDType2)\n b_w = tl.exp(b_w.to(tl.float32)).to(TargetDType)\n b_kv = b_k[None, :] * b_v[:, None]\n b_o = (b_h + b_kv * b_u[None, :]) * b_q[None, :]\n b_o = tl.sum(b_o, axis=1)\n b_h = b_h * b_w[None, :]\n b_h += b_kv.to(TargetDType)\n tl.store(p_o, b_o.to(p_o.dtype.element_ty), mask=mask_bv)\n p_q += -K if REVERSE else K\n p_k += -K if REVERSE else K\n p_o += -V if REVERSE else V\n p_v += -V if REVERSE else V\n p_w += -K if REVERSE else K\n if STORE_FINAL_STATE:\n p_ht = ht + i_bh * K * V + (i_k * BK + tl.arange(0, BK)[None, :]) * V + (i_v * BV + tl.arange(0, BV)[:, None])\n tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_kv)\n\n\nclass FusedRecurrentRWKV6Function(torch.autograd.Function):\n\n @staticmethod\n @custom_fwd\n def forward(ctx, r, k, v, w, u, scale=None, initial_state=None, output_final_state=False, reverse=False):\n q = r\n B, H, T, K, V = *q.shape, v.shape[-1]\n BK, BV = min(triton.next_power_of_2(K), 128), min(triton.next_power_of_2(V), 128)\n NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)\n num_stages = 1\n num_warps = 1\n if output_final_state:\n final_state = q.new_empty(B, H, K, V)\n else:\n final_state = None\n if r.dtype == torch.float16 and 0:\n o = q.new_empty(NK, B, H, T, V, dtype=torch.float16)\n grid = (NV, NK, B * H)\n fused_recurrent_rwkv6_fwd_kernel16[grid](\n q, k, v, w, u, o, initial_state, final_state,\n k.stride(1),\n v.stride(1),\n scale,\n B=B, H=H, T=T, K=K, V=V, BK=BK, BV=BV,\n USE_INITIAL_STATE=initial_state is not None,\n STORE_FINAL_STATE=final_state is not None,\n REVERSE=reverse,\n num_warps=num_warps,\n num_stages=num_stages\n )\n else:\n o = q.new_empty(NK, B, H, T, V, dtype=torch.bfloat16)\n grid = (NV, NK, B * H)\n fused_recurrent_rwkv6_fwd_kernel[grid](\n q, k, v, w, u, o, initial_state, final_state,\n k.stride(1),\n v.stride(1),\n scale,\n B=B, H=H, T=T, K=K, V=V, BK=BK, BV=BV,\n USE_INITIAL_STATE=initial_state is not None,\n STORE_FINAL_STATE=final_state is not None,\n REVERSE=reverse,\n num_warps=num_warps,\n num_stages=num_stages\n )\n o = o.sum(0)\n if final_state is not None:\n final_state = final_state.detach()\n return o.to(q.dtype), final_state\n\n\ndef fused_recurrent_rwkv6(\n r: torch.Tensor,\n k: torch.Tensor,\n v: torch.Tensor,\n w: torch.Tensor,\n u: torch.Tensor,\n scale: float = -1,\n initial_state: torch.Tensor = None,\n output_final_state: bool = False,\n causal: bool = True\n) -> Tuple[torch.Tensor, torch.Tensor]:\n if scale == -1:\n scale = r.shape[-1] ** -0.5\n o, final_state = FusedRecurrentRWKV6Function.apply(r, k, v, w, u, scale, initial_state, output_final_state)\n return o, final_state\n", - "description_1": "Use triton language to create a fused recurrent kernel that processes a query, key, value, log gate, and bonus inputs for RWKV6 forward operations. It handles optional initial and final states, stride sizes, scaling, and allows autoregressive modeling in the reverse direction.", - "description_2": "Use triton language to perform fused recurrent computations for RWKV6, including query, key, value transformations and state management.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\nfrom torch.cuda.amp import custom_bwd, custom_fwd\n\n@triton.jit\ndef chunk_simple_gla_fwd_kernel_h(\n k,\n v,\n h,\n g,\n initial_state, # initial state of the chunk [B, H, D_head_K, D_head_V]\n final_state, # final state of the chunk [B, H, D_head_K, D_head_V]\n s_qk_h,\n s_qk_t,\n s_qk_d,\n s_vo_h,\n s_vo_t,\n s_vo_d,\n s_h_h,\n s_h_t,\n H: tl.constexpr,\n T: tl.constexpr,\n K: tl.constexpr,\n V: tl.constexpr,\n BT: tl.constexpr,\n BK: tl.constexpr,\n BV: tl.constexpr,\n NT: tl.constexpr,\n USE_INITIAL_STATE: tl.constexpr,\n STORE_FINAL_STATE: tl.constexpr\n):\n i_k, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n\n # [BK, BV]\n b_h = tl.zeros([BK, BV], dtype=tl.float32)\n\n if USE_INITIAL_STATE:\n p_h0 = tl.make_block_ptr(initial_state + i_bh * K * V,\n (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))\n b_h = tl.load(p_h0, boundary_check=(0, 1)).to(tl.float32)\n\n for i_t in range(NT):\n p_k = tl.make_block_ptr(\n k + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1))\n p_v = tl.make_block_ptr(\n v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))\n p_h = tl.make_block_ptr(h + i_bh * s_h_h + i_t * K * V,\n (K, V), (s_h_t, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))\n\n tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1))\n # [BK, BT]\n b_k = tl.load(p_k, boundary_check=(0, 1))\n # [BT, BV]\n b_v = tl.load(p_v, boundary_check=(0, 1))\n # [BK, BV]\n b_g_last = tl.load(g + i_bh * T + i_t * BT + BT - 1)\n b_h *= tl.math.exp2(b_g_last)\n b_g = tl.load(g + i_bh * T + i_t * BT + tl.arange(0, BT))\n b_h += tl.dot(b_k, (b_v * tl.math.exp2(b_g_last - b_g)[:, None]).to(b_k.dtype), allow_tf32=False)\n\n if STORE_FINAL_STATE:\n p_ht = tl.make_block_ptr(\n final_state + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))\n tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), boundary_check=(0, 1))\n\n\n@triton.jit\ndef chunk_simple_gla_fwd_kernel_o(\n q,\n k,\n v,\n h,\n g,\n o,\n s_qk_h,\n s_qk_t,\n s_qk_d,\n s_vo_h,\n s_vo_t,\n s_vo_d,\n s_h_h,\n s_h_t,\n scale,\n H: tl.constexpr,\n T: tl.constexpr,\n K: tl.constexpr,\n V: tl.constexpr,\n BT: tl.constexpr,\n BK: tl.constexpr,\n BV: tl.constexpr\n):\n i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n\n o_i = tl.arange(0, BT)\n m_s = o_i[:, None] >= o_i[None, :]\n\n b_o = tl.zeros([BT, BV], dtype=tl.float32)\n b_s = tl.zeros([BT, BT], dtype=tl.float32)\n for i_k in range(tl.cdiv(K, BK)):\n p_q = tl.make_block_ptr(\n q + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))\n p_k = tl.make_block_ptr(\n k + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1))\n p_h = tl.make_block_ptr(h + i_bh * s_h_h + i_t * K * V,\n (K, V), (s_h_t, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))\n\n # [BT, BK]\n b_q = tl.load(p_q, boundary_check=(0, 1))\n # [BK, BT]\n b_k = tl.load(p_k, boundary_check=(0, 1))\n # [BT]\n\n # [BK, BV]\n b_h = tl.load(p_h, boundary_check=(0, 1))\n b_o += tl.dot(b_q, b_h, allow_tf32=False)\n b_s += tl.dot(b_q, b_k, allow_tf32=False)\n\n p_g = g + i_bh * T + i_t * BT + tl.arange(0, BT)\n b_g = tl.load(p_g)\n b_o = b_o * tl.math.exp2(b_g)[:, None]\n b_s = b_s * tl.math.exp2(b_g[:, None] - b_g[None, :])\n b_s = tl.where(m_s, b_s, 0)\n\n p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V),\n (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))\n b_v = tl.load(p_v, boundary_check=(0, 1))\n b_o = (b_o + tl.dot(b_s.to(b_v.dtype), b_v, allow_tf32=False)) * scale\n p_o = tl.make_block_ptr(o + i_bh * s_vo_h, (T, V),\n (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))\n tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))\n\n\n@triton.jit\ndef chunk_simple_gla_bwd_kernel_dh(\n q,\n g,\n do,\n dh,\n s_qk_h,\n s_qk_t,\n s_qk_d,\n s_vo_h,\n s_vo_t,\n s_vo_d,\n s_h_h,\n s_h_t,\n scale,\n H: tl.constexpr,\n T: tl.constexpr,\n K: tl.constexpr,\n V: tl.constexpr,\n BT: tl.constexpr,\n BK: tl.constexpr,\n BV: tl.constexpr,\n NT: tl.constexpr\n):\n i_k, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n\n # [BK, BV]\n b_dh = tl.zeros([BK, BV], dtype=tl.float32)\n for i_t in range(NT - 1, -1, -1):\n p_q = tl.make_block_ptr(\n q + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1))\n p_do = tl.make_block_ptr(\n do + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))\n p_dh = tl.make_block_ptr(dh + i_bh * s_h_h + i_t * K * V,\n (K, V), (s_h_t, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))\n\n tl.store(p_dh, b_dh.to(p_dh.dtype.element_ty), boundary_check=(0, 1))\n # [BK, BT]\n b_q = tl.load(p_q, boundary_check=(0, 1))\n b_q = (b_q * scale * tl.math.exp2(tl.load(g + i_bh * T +\n i_t * BT + tl.arange(0, BT)))[None, :]).to(b_q.dtype)\n # [BT, V]\n b_do = tl.load(p_do, boundary_check=(0, 1))\n # [BK, BV]\n b_dh *= tl.math.exp2(tl.load(g + i_bh * T + i_t * BT + BT - 1))\n b_dh += tl.dot(b_q, b_do.to(b_q.dtype), allow_tf32=False)\n\n\n@triton.jit\ndef chunk_simple_gla_bwd_kernel_dqkv(\n q,\n k,\n v,\n h,\n g,\n do,\n dh,\n dq,\n dk,\n dv,\n s_qk_h,\n s_qk_t,\n s_qk_d,\n s_vo_h,\n s_vo_t,\n s_vo_d,\n s_h_h,\n s_h_t,\n scale,\n B: tl.constexpr,\n H: tl.constexpr,\n T: tl.constexpr,\n K: tl.constexpr,\n V: tl.constexpr,\n BT: tl.constexpr,\n BK: tl.constexpr,\n BV: tl.constexpr,\n NT: tl.constexpr\n):\n i_k, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n n_bh = tl.num_programs(2)\n o_i = tl.arange(0, BT)\n\n p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (K, T),\n (s_qk_d, s_qk_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1))\n p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K),\n (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))\n\n b_q = tl.load(p_q, boundary_check=(0, 1))\n b_k = tl.load(p_k, boundary_check=(0, 1))\n b_s = tl.dot(b_k, b_q, allow_tf32=False)\n p_g = g + i_bh * T + i_t * BT + tl.arange(0, BT)\n b_g = tl.load(p_g)\n b_g_last = tl.load(g + i_bh * T + i_t * BT + BT - 1)\n mask = tl.math.exp2(b_g[None, :] - b_g[:, None])\n mask = tl.where(o_i[:, None] <= o_i[None, :], mask * scale, 0)\n b_s = b_s * mask\n\n b_dq = tl.zeros([BT, BK], dtype=tl.float32)\n b_dk = tl.zeros([BT, BK], dtype=tl.float32)\n b_ds = tl.zeros([BT, BT], dtype=tl.float32)\n for i_v in range(tl.cdiv(V, BV)):\n p_v = tl.make_block_ptr(\n v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))\n p_h = tl.make_block_ptr(h + i_bh * s_h_h, (V, NT * K), (1, s_h_t),\n (i_v * BV, i_t * K + i_k * BK), (BV, BK), (0, 1))\n p_do = tl.make_block_ptr(\n do + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))\n p_dh = tl.make_block_ptr(dh + i_bh * s_h_h, (NT * K, V),\n (s_h_t, 1), (i_t * K + i_k * BK, i_v * BV), (BK, BV), (1, 0))\n p_dv = tl.make_block_ptr(dv + (i_k*n_bh+i_bh)*s_vo_h, (T, V),\n (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))\n # [BT, BV]\n b_v = tl.load(p_v, boundary_check=(0, 1))\n b_do = tl.load(p_do, boundary_check=(0, 1))\n # [BV, BK]\n b_h = tl.load(p_h, boundary_check=(0, 1))\n # [BK, BV]\n b_dh = tl.load(p_dh, boundary_check=(0, 1))\n # [BT, BT]\n b_ds += tl.dot(b_do, tl.trans(b_v), allow_tf32=False)\n # [BT, BK]\n b_dq += tl.dot(b_do, b_h, allow_tf32=False) * scale\n b_dk += tl.dot(b_v, tl.trans(b_dh), allow_tf32=False)\n # [BT, BV]\n b_dv = tl.dot(b_k, b_dh, allow_tf32=False) * tl.math.exp2(-b_g + b_g_last)[:, None] + \\\n tl.dot(b_s.to(b_q.dtype), b_do, allow_tf32=False)\n tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))\n\n b_dq = b_dq * tl.math.exp2(b_g)[:, None]\n b_dk = b_dk * tl.math.exp2(-b_g + b_g_last)[:, None]\n b_ds = b_ds * tl.trans(mask)\n b_ds = b_ds.to(b_k.dtype)\n # [BT, BK]\n b_dq += tl.dot(b_ds, b_k, allow_tf32=False)\n b_dk += tl.trans(tl.dot(b_q, b_ds, allow_tf32=False))\n p_dq = tl.make_block_ptr(dq + i_bh * s_qk_h, (T, K),\n (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))\n p_dk = tl.make_block_ptr(dk + i_bh * s_qk_h, (T, K),\n (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))\n tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))\n tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))\n\n\nclass SimpleGLAFunction(torch.autograd.Function):\n\n @staticmethod\n @custom_fwd\n def forward(ctx, q, k, v, g, initial_state, output_final_state):\n B, H, T, K, V = *q.shape, v.shape[-1]\n BT = 64\n BK, BV = min(64, triton.next_power_of_2(K)), min(\n 64, triton.next_power_of_2(V))\n NT, NK, NV = triton.cdiv(T, BT), triton.cdiv(K, BK), triton.cdiv(V, BV)\n num_stages = 1\n num_warps = 4 if BK == 64 else 2\n scale = K ** -0.5\n\n BT = 64\n assert T % BT == 0, 'sequence length must be divisible by BT'\n g = g.reshape(B, H, -1, BT)\n g = g.cumsum(-1) * 1.44269504\n g = g.reshape(B, H, -1)\n\n final_state = None\n if output_final_state:\n final_state = q.new_empty(B, H, K, V, dtype=torch.float32, requires_grad=False)\n\n h = q.new_empty(B, H, NT * K, V)\n grid = (NK, NV, B * H)\n chunk_simple_gla_fwd_kernel_h[grid](\n k, v, h, g, initial_state, final_state,\n q.stride(1), q.stride(2), q.stride(3),\n v.stride(1), v.stride(2), v.stride(3),\n h.stride(1), h.stride(2),\n H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT,\n USE_INITIAL_STATE=initial_state is not None,\n STORE_FINAL_STATE=output_final_state,\n num_warps=num_warps,\n num_stages=num_stages\n )\n grid = (NV, NT, B * H)\n o = torch.empty_like(v)\n chunk_simple_gla_fwd_kernel_o[grid](\n q, k, v, h, g, o,\n q.stride(1), q.stride(2), q.stride(3),\n v.stride(1), v.stride(2), v.stride(3),\n h.stride(1), h.stride(2),\n scale,\n H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV,\n num_warps=num_warps,\n num_stages=num_stages\n )\n\n ctx.save_for_backward(q, k, v, h, g)\n return o.to(q.dtype), final_state\n\n @staticmethod\n @custom_bwd\n def backward(ctx, do, d_ht=None):\n q, k, v, h, g = ctx.saved_tensors\n\n B, H, T, K, V = *q.shape, v.shape[-1]\n BT = 64\n BK, BV = min(32 if q.dtype == torch.float32 else 64, triton.next_power_of_2(K)), min(\n 32 if q.dtype == torch.float32 else 64, triton.next_power_of_2(V))\n NT, NK, NV = triton.cdiv(T, BT), triton.cdiv(K, BK), triton.cdiv(V, BV)\n num_stages = 1\n num_warps = 4 if BK == 64 else 2\n scale = K ** -0.5\n\n dh = q.new_empty(B, H, NT * K, V)\n grid = (NK, NV, B * H)\n chunk_simple_gla_bwd_kernel_dh[grid](\n q, g, do, dh,\n q.stride(1), q.stride(2), q.stride(3),\n v.stride(1), v.stride(2), v.stride(3),\n dh.stride(1), dh.stride(2),\n scale,\n H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT,\n num_warps=num_warps,\n num_stages=num_stages\n )\n grid = (NK, NT, B * H)\n dq = torch.empty_like(q)\n dk = torch.empty_like(k)\n dv = v.new_empty(NK, *v.shape)\n num_stages = 1\n num_warps = 4 if BK == 64 else 2\n chunk_simple_gla_bwd_kernel_dqkv[grid](\n q, k, v, h, g, do, dh, dq, dk, dv,\n q.stride(1), q.stride(2), q.stride(3),\n v.stride(1), v.stride(2), v.stride(3),\n dh.stride(1), dh.stride(2),\n scale,\n B=B, H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT,\n num_warps=num_warps,\n num_stages=num_stages\n )\n dv = dv.sum(0)\n dg = (dq * q - dk * k).sum(-1)\n\n def rev_cumsum(x):\n cumsum_x = x.cumsum(-1)\n rev_cumsum_x = cumsum_x[..., -1, None] - cumsum_x\n return rev_cumsum_x + x\n dg = rev_cumsum(dg)\n return dq.to(q.dtype), dk.to(k.dtype), dv.to(v.dtype), dg.to(g.dtype), None, None\n\n\ndef chunk_simple_gla(\n q: torch.Tensor,\n k: torch.Tensor,\n v: torch.Tensor,\n g: torch.Tensor, # log decay\n initial_state: torch.Tensor = None,\n output_final_state: bool = False\n) -> Tuple[torch.Tensor, torch.Tensor]:\n if initial_state is not None:\n initial_state = initial_state.detach()\n g = g.float()\n o, final_state = SimpleGLAFunction.apply(q, k, v, g, initial_state, output_final_state)\n return o, final_state\n", - "description_1": "Use triton language to implement a series of kernels that perform generalized linear attention with optional initial and final state saving. The forward kernel (chunk_simple_gla_fwd_kernel_h) handles the initial forward pass, optionally using an initial state tensor, and computes intermediate results in a chunk-wise fashion. The backward kernels (chunk_simple_gla_bwd_kernel_dh and chunk_simple_gla_bwd_kernel_dqkv) compute gradients of hidden states and input tensors using a similar chunk-based approach.", - "description_2": "Use triton language to implement generalized linear attention with chunk-based forward and backward passes and optional state handling.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\nfrom typing import Optional\n\n@triton.autotune(\n configs=[\n triton.Config({'BT': 16}, num_warps=2),\n triton.Config({'BT': 16}, num_warps=4),\n triton.Config({'BT': 16}, num_warps=8),\n triton.Config({'BT': 32}, num_warps=2),\n triton.Config({'BT': 32}, num_warps=4),\n triton.Config({'BT': 32}, num_warps=8),\n triton.Config({'BT': 64}, num_warps=2),\n triton.Config({'BT': 64}, num_warps=4),\n triton.Config({'BT': 64}, num_warps=8),\n ],\n key=['S']\n)\n@triton.jit\ndef logcumsumexp_fwd_kernel(\n s, z, s_s_h, s_s_t, s_s_d, T: tl.constexpr, S: tl.constexpr, BT: tl.constexpr\n):\n i_bh = tl.program_id(0)\n o_i = tl.arange(0, BT)\n m_s = tl.where(o_i[:, None] >= o_i[None, :], 1., 0.)\n\n b_mp = tl.full([S,], float('-inf'), dtype=tl.float32)\n b_zp = tl.zeros([S,], dtype=tl.float32)\n for i_t in range(tl.cdiv(T, BT)):\n p_s = tl.make_block_ptr(s + i_bh * s_s_h, (T, S), (s_s_t, s_s_d), (i_t * BT, 0), (BT, S), (1, 0))\n p_z = tl.make_block_ptr(z + i_bh * s_s_h, (T, S), (s_s_t, s_s_d), (i_t * BT, 0), (BT, S), (1, 0))\n\n b_s = tl.load(p_s, boundary_check=(0, 1)).to(tl.float32)\n b_mc = tl.max(b_s, 0)\n if i_t > 0:\n b_mc = tl.maximum(b_mp, b_mc)\n b_zp = b_zp * tl.exp(b_mp - b_mc)\n b_s = tl.exp(b_s - b_mc)\n b_z = tl.dot(m_s, b_s, allow_tf32=False) + b_zp\n b_zc = tl.max(b_z, 0)\n b_mp = b_mc\n b_zp = b_zc\n b_z = tl.log(tl.where(b_z != 0, b_z, 1e-20)) + b_mc\n tl.store(p_z, b_z.to(p_z.dtype.element_ty), boundary_check=(0, 1))\n\n@triton.autotune(\n configs=[\n triton.Config({}, num_warps=2),\n triton.Config({}, num_warps=4),\n triton.Config({}, num_warps=8),\n ],\n key=['S']\n)\n@triton.jit\ndef softmax_fwd_kernel(\n s, p, s_s_h, s_s_t, s_s_d, T: tl.constexpr, S: tl.constexpr, BT: tl.constexpr\n):\n i_t, i_bh = tl.program_id(0), tl.program_id(1)\n\n p_s = tl.make_block_ptr(s + i_bh * s_s_h, (T, S), (s_s_t, s_s_d), (i_t * BT, 0), (BT, S), (1, 0))\n p_p = tl.make_block_ptr(p + i_bh * s_s_h, (T, S), (s_s_t, s_s_d), (i_t * BT, 0), (BT, S), (1, 0))\n\n b_s = tl.load(p_s, boundary_check=(0, 1)).to(tl.float32)\n b_m = tl.max(b_s, 1)\n b_s = tl.exp(b_s - b_m[:, None])\n b_z = tl.sum(b_s, 1)\n b_p = tl.where(b_s != 0, b_s / b_z[:, None], 0.)\n tl.store(p_p, b_p.to(p_p.dtype.element_ty), boundary_check=(0, 1))\n\n@triton.autotune(\n configs=[\n triton.Config({'BT': 16}, num_warps=2),\n triton.Config({'BT': 16}, num_warps=4),\n triton.Config({'BT': 16}, num_warps=8),\n triton.Config({'BT': 32}, num_warps=2),\n triton.Config({'BT': 32}, num_warps=4),\n triton.Config({'BT': 32}, num_warps=8),\n triton.Config({'BT': 64}, num_warps=2),\n triton.Config({'BT': 64}, num_warps=4),\n triton.Config({'BT': 64}, num_warps=8),\n ],\n key=['S']\n)\n@triton.jit\ndef chunk_cumsum_fwd_kernel(\n s, z, s_s_h, s_s_t, s_s_d, T: tl.constexpr, S: tl.constexpr, BT: tl.constexpr, BS: tl.constexpr\n):\n i_s, i_bh = tl.program_id(0), tl.program_id(1)\n o_i = tl.arange(0, BT)\n m_s = tl.where(o_i[:, None] >= o_i[None, :], 1., 0.)\n\n b_z = tl.zeros([BS], dtype=tl.float32)\n for i_t in range(tl.cdiv(T, BT)):\n p_s = tl.make_block_ptr(s + i_bh * s_s_h, (T, S), (s_s_t, s_s_d), (i_t * BT, i_s * BS), (BT, BS), (1, 0))\n p_z = tl.make_block_ptr(z + i_bh * s_s_h, (T, S), (s_s_t, s_s_d), (i_t * BT, i_s * BS), (BT, BS), (1, 0))\n b_s = tl.load(p_s, boundary_check=(0, 1)).to(tl.float32)\n b_c = b_z[None, :] + tl.dot(m_s, b_s, allow_tf32=False)\n tl.store(p_z, b_c.to(p_z.dtype.element_ty), boundary_check=(0, 1))\n\n if i_t >= 0:\n b_z += tl.sum(b_s, 0)\n\ndef chunk_cumsum_fwd(s: torch.Tensor, dtype: Optional[torch.dtype] = None) -> torch.Tensor:\n B, H, T, S = s.shape\n BS = 32\n\n dtype = dtype or s.dtype\n grid = (triton.cdiv(S, BS), B * H)\n z = torch.empty_like(s, dtype=dtype)\n chunk_cumsum_fwd_kernel[grid](\n s, z,\n s.stride(1), s.stride(2), s.stride(3),\n T=T, S=S, BS=BS\n )\n return z\n", - "description_1": "Use triton language to implement multiple kernels and wrapper functions: `logcumsumexp_fwd_kernel` computes the forward pass of log cumulative sum of exponentials, `softmax_fwd_kernel` performs softmax computation, `chunk_cumsum_fwd_kernel` computes the chunk-wise cumulative sum. Each function operates on input tensors, utilizing triton's block mapping and parallel processing capabilities. The functions use grid indexing to handle different chunks of data in parallel.", - "description_2": "Use triton language to implement kernels for log cumulative sum of exponentials, softmax, and chunk-wise cumulative sum, each processing data blocks in parallel using grid indexing.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.autotune(\n configs=[\n triton.Config({'BT': 16}, num_warps=2),\n triton.Config({'BT': 16}, num_warps=4),\n triton.Config({'BT': 16}, num_warps=8),\n triton.Config({'BT': 32}, num_warps=2),\n triton.Config({'BT': 32}, num_warps=4),\n triton.Config({'BT': 32}, num_warps=8),\n triton.Config({'BT': 64}, num_warps=2),\n triton.Config({'BT': 64}, num_warps=4),\n triton.Config({'BT': 64}, num_warps=8),\n triton.Config({'BT': 128}, num_warps=2),\n triton.Config({'BT': 128}, num_warps=4),\n triton.Config({'BT': 128}, num_warps=8),\n triton.Config({'BT': 256}, num_warps=2),\n triton.Config({'BT': 256}, num_warps=4),\n triton.Config({'BT': 256}, num_warps=8)\n ],\n key=['D']\n)\n@triton.jit\ndef logsigmoid_fwd_kernel(\n x,\n y,\n T: tl.constexpr,\n D: tl.constexpr,\n BT: tl.constexpr\n):\n i = tl.program_id(0)\n o_i = i * BT + tl.arange(0, BT)\n\n p_x = x + o_i\n p_y = y + o_i\n mask = o_i < T\n\n # [D,]\n b_x = tl.load(p_x, mask=mask, other=0.).to(tl.float32)\n b_m = tl.minimum(0., b_x)\n b_z = 1. + tl.exp(-tl.abs(b_x))\n b_y = b_m - tl.log(b_z)\n tl.store(p_y, b_y.to(p_y.dtype.element_ty), mask=mask)\n\n\n@triton.autotune(\n configs=[\n triton.Config({'BT': 16}, num_warps=2),\n triton.Config({'BT': 16}, num_warps=4),\n triton.Config({'BT': 16}, num_warps=8),\n triton.Config({'BT': 32}, num_warps=2),\n triton.Config({'BT': 32}, num_warps=4),\n triton.Config({'BT': 32}, num_warps=8),\n triton.Config({'BT': 64}, num_warps=2),\n triton.Config({'BT': 64}, num_warps=4),\n triton.Config({'BT': 64}, num_warps=8),\n triton.Config({'BT': 128}, num_warps=2),\n triton.Config({'BT': 128}, num_warps=4),\n triton.Config({'BT': 128}, num_warps=8),\n triton.Config({'BT': 256}, num_warps=2),\n triton.Config({'BT': 256}, num_warps=4),\n triton.Config({'BT': 256}, num_warps=8)\n ],\n key=['D']\n)\n@triton.jit\ndef logsigmoid_bwd_kernel(\n x,\n dx,\n dy,\n T: tl.constexpr,\n D: tl.constexpr,\n BT: tl.constexpr\n):\n i = tl.program_id(0)\n o_i = i * BT + tl.arange(0, BT)\n\n p_x = x + o_i\n p_dx = dx + o_i\n p_dy = dy + o_i\n mask = o_i < T\n\n # [D,]\n b_x = tl.load(p_x, mask=mask, other=0.).to(tl.float32)\n b_dy = tl.load(p_dy, mask=mask, other=0.).to(tl.float32)\n b_dx = b_dy * (1. - tl.sigmoid(b_x))\n tl.store(p_dx, b_dx.to(p_dx.dtype.element_ty), mask=mask)\n\n\nclass LogSigmoidFunction(torch.autograd.Function):\n\n @staticmethod\n def forward(ctx, x):\n T, D = x.numel(), x.shape[-1]\n y = torch.empty_like(x)\n logsigmoid_fwd_kernel[lambda meta: (triton.cdiv(meta['T'], meta['D']),)](x, y, T=T, D=D)\n ctx.save_for_backward(x,)\n return y\n\n @staticmethod\n def backward(ctx, dy):\n x, = ctx.saved_tensors\n T, D = x.numel(), x.shape[-1]\n dx = torch.empty_like(x)\n logsigmoid_bwd_kernel[lambda meta: (triton.cdiv(meta['T'], meta['D']),)](x, dx, dy, T=T, D=D)\n return dx\n\n\nlogsigmoid = LogSigmoidFunction.apply\n", - "description_1": "Use triton language to implement a forward and backward kernel for the logsigmoid function. The forward kernel, logsigmoid_fwd_kernel, takes 5 parameters: x (input tensor), y (output tensor), T (total number of elements), D (dimension size), and BT (block size). It computes the logsigmoid of the input tensor and stores the result in the output tensor. The backward kernel, logsigmoid_bwd_kernel, also takes 5 parameters: x (input tensor), dx (gradient of input), dy (gradient of output), T (total number of elements), and D (dimension size). It computes the gradient of the logsigmoid function with respect to the input tensor.", - "description_2": "Use triton language to create a logsigmoid function with forward and backward passes, utilizing triton.jit for kernel compilation and triton.autotune for performance optimization.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.autotune(\n configs=[\n triton.Config({}, num_warps=1),\n triton.Config({}, num_warps=2),\n triton.Config({}, num_warps=4),\n triton.Config({}, num_warps=8),\n triton.Config({}, num_warps=16),\n triton.Config({}, num_warps=32),\n ],\n key=[\"N\", \"HAS_RESIDUAL\", \"STORE_RESIDUAL_OUT\", \"IS_RMS_NORM\", \"HAS_BIAS\"],\n)\n@triton.jit\ndef _layer_norm_fwd_quant_kernel(\n X, # pointer to the input\n Y, # pointer to the output\n W, # pointer to the weights\n B, # pointer to the biases\n RESIDUAL, # pointer to the residual\n RESIDUAL_OUT, # pointer to the residual\n Mean, # pointer to the mean\n Rstd, # pointer to the 1/std\n stride_x_row, # how much to increase the pointer when moving by 1 row\n stride_y_row,\n stride_res_row,\n stride_res_out_row,\n N, # number of columns in X\n eps, # epsilon to avoid division by zero\n IS_RMS_NORM: tl.constexpr,\n BLOCK_N: tl.constexpr,\n HAS_RESIDUAL: tl.constexpr,\n STORE_RESIDUAL_OUT: tl.constexpr,\n HAS_WEIGHT: tl.constexpr,\n HAS_BIAS: tl.constexpr\n):\n # Map the program id to the row of X and Y it should compute.\n row = tl.program_id(0)\n X += row * stride_x_row\n Y += row * stride_y_row\n if HAS_RESIDUAL:\n RESIDUAL += row * stride_res_row\n if STORE_RESIDUAL_OUT:\n RESIDUAL_OUT += row * stride_res_out_row\n # Compute mean and variance\n cols = tl.arange(0, BLOCK_N)\n x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)\n if HAS_RESIDUAL:\n residual = tl.load(RESIDUAL + cols, mask=cols < N, other=0.0).to(tl.float32)\n x += residual\n if STORE_RESIDUAL_OUT:\n tl.store(RESIDUAL_OUT + cols, x, mask=cols < N)\n if not IS_RMS_NORM:\n mean = tl.sum(x, axis=0) / N\n tl.store(Mean + row, mean)\n xbar = tl.where(cols < N, x - mean, 0.0)\n var = tl.sum(xbar * xbar, axis=0) / N\n else:\n xbar = tl.where(cols < N, x, 0.0)\n var = tl.sum(xbar * xbar, axis=0) / N\n rstd = 1 / tl.sqrt(var + eps)\n tl.store(Rstd + row, rstd)\n # Normalize and apply linear transformation\n mask = cols < N\n if HAS_WEIGHT:\n w = tl.load(W + cols, mask=mask).to(tl.float32)\n if HAS_BIAS:\n b = tl.load(B + cols, mask=mask).to(tl.float32)\n x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd\n\n y = x_hat * w if HAS_WEIGHT else x_hat\n if HAS_BIAS:\n y = y + b\n\n # Aply quantization to the output\n scale = 127.0 / tl.maximum(tl.max(tl.abs(y), 0), 1e-5)\n # Quantize and then de-quantize the tensor\n y = tl.math.round(y * scale)\n y = tl.maximum(tl.minimum(y, 127), -128) / scale\n\n # Write output\n tl.store(Y + cols, y, mask=mask)\n\n\ndef _layer_norm_fwd_quant(\n x, weight, bias, eps, residual=None, out_dtype=None, residual_dtype=None, is_rms_norm=False\n):\n if residual is not None:\n residual_dtype = residual.dtype\n M, N = x.shape\n # allocate output\n y = torch.empty_like(x, dtype=x.dtype if out_dtype is None else out_dtype)\n if residual is not None or (residual_dtype is not None and residual_dtype != x.dtype):\n residual_out = torch.empty(M, N, device=x.device, dtype=residual_dtype)\n else:\n residual_out = None\n mean = torch.empty((M,), dtype=torch.float32, device=\"cuda\") if not is_rms_norm else None\n rstd = torch.empty((M,), dtype=torch.float32, device=\"cuda\")\n # Less than 64KB per feature: enqueue fused kernel\n MAX_FUSED_SIZE = 65536 // x.element_size()\n BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))\n if N > BLOCK_N:\n raise RuntimeError(\"This layer norm doesn't support feature dim >= 64KB.\")\n # heuristics for number of warps\n with torch.cuda.device(x.device.index):\n _layer_norm_fwd_quant_kernel[(M,)](\n x,\n y,\n weight,\n bias,\n residual,\n residual_out,\n mean,\n rstd,\n x.stride(0),\n y.stride(0),\n residual.stride(0) if residual is not None else 0,\n residual_out.stride(0) if residual_out is not None else 0,\n N,\n eps,\n is_rms_norm,\n BLOCK_N,\n residual is not None,\n residual_out is not None,\n weight is not None,\n bias is not None,\n )\n # residual_out is None if residual is None and residual_dtype == input_dtype\n return y, mean, rstd, residual_out if residual_out is not None else x\n\n@triton.autotune(\n configs=[\n triton.Config({}, num_warps=1),\n triton.Config({}, num_warps=2),\n triton.Config({}, num_warps=4),\n triton.Config({}, num_warps=8),\n triton.Config({}, num_warps=16),\n triton.Config({}, num_warps=32),\n ],\n key=[\"N\", \"HAS_DRESIDUAL\", \"STORE_DRESIDUAL\", \"IS_RMS_NORM\", \"HAS_BIAS\"],\n)\n@triton.heuristics({\"RECOMPUTE_OUTPUT\": lambda args: args[\"Y\"] is not None})\n@triton.jit\ndef _layer_norm_bwd_kernel(\n X, # pointer to the input\n W, # pointer to the weights\n B, # pointer to the biases\n Y, # pointer to the output to be recomputed\n DY, # pointer to the output gradient\n DX, # pointer to the input gradient\n DW, # pointer to the partial sum of weights gradient\n DB, # pointer to the partial sum of biases gradient\n DRESIDUAL,\n DRESIDUAL_IN,\n Mean, # pointer to the mean\n Rstd, # pointer to the 1/std\n stride_x_row, # how much to increase the pointer when moving by 1 row\n stride_y_row,\n stride_dy_row,\n stride_dx_row,\n stride_dres_row,\n stride_dres_in_row,\n M, # number of rows in X\n N, # number of columns in X\n eps, # epsilon to avoid division by zero\n rows_per_program,\n IS_RMS_NORM: tl.constexpr,\n BLOCK_N: tl.constexpr,\n HAS_DRESIDUAL: tl.constexpr,\n STORE_DRESIDUAL: tl.constexpr,\n HAS_WEIGHT: tl.constexpr,\n HAS_BIAS: tl.constexpr,\n RECOMPUTE_OUTPUT: tl.constexpr,\n):\n # Map the program id to the elements of X, DX, and DY it should compute.\n row_block_id = tl.program_id(0)\n row_start = row_block_id * rows_per_program\n cols = tl.arange(0, BLOCK_N)\n mask = cols < N\n X += row_start * stride_x_row\n if HAS_DRESIDUAL:\n DRESIDUAL += row_start * stride_dres_row\n if STORE_DRESIDUAL:\n DRESIDUAL_IN += row_start * stride_dres_in_row\n DY += row_start * stride_dy_row\n DX += row_start * stride_dx_row\n if RECOMPUTE_OUTPUT:\n Y += row_start * stride_y_row\n if HAS_WEIGHT:\n w = tl.load(W + cols, mask=mask).to(tl.float32)\n dw = tl.zeros((BLOCK_N,), dtype=tl.float32)\n if RECOMPUTE_OUTPUT and HAS_BIAS:\n b = tl.load(B + cols, mask=mask, other=0.0).to(tl.float32)\n if HAS_BIAS:\n db = tl.zeros((BLOCK_N,), dtype=tl.float32)\n row_end = min((row_block_id + 1) * rows_per_program, M)\n for row in range(row_start, row_end):\n # Load data to SRAM\n x = tl.load(X + cols, mask=mask, other=0).to(tl.float32)\n dy = tl.load(DY + cols, mask=mask, other=0).to(tl.float32)\n if not IS_RMS_NORM:\n mean = tl.load(Mean + row)\n rstd = tl.load(Rstd + row)\n # Compute dx\n xhat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd\n xhat = tl.where(mask, xhat, 0.0)\n if RECOMPUTE_OUTPUT:\n y = xhat * w if HAS_WEIGHT else xhat\n if HAS_BIAS:\n y = y + b\n\n # Aply quantization to the output\n scale = 127.0 / tl.maximum(tl.max(tl.abs(y), 0), 1e-5)\n # Quantize and then de-quantize the tensor\n y = tl.math.round(y * scale)\n y = tl.maximum(tl.minimum(y, 127), -128) / scale\n\n tl.store(Y + cols, y, mask=mask)\n wdy = dy\n if HAS_WEIGHT:\n wdy = dy * w\n dw += dy * xhat\n if HAS_BIAS:\n db += dy\n if not IS_RMS_NORM:\n c1 = tl.sum(xhat * wdy, axis=0) / N\n c2 = tl.sum(wdy, axis=0) / N\n dx = (wdy - (xhat * c1 + c2)) * rstd\n else:\n c1 = tl.sum(xhat * wdy, axis=0) / N\n dx = (wdy - xhat * c1) * rstd\n if HAS_DRESIDUAL:\n dres = tl.load(DRESIDUAL + cols, mask=mask, other=0).to(tl.float32)\n dx += dres\n # Write dx\n if STORE_DRESIDUAL:\n tl.store(DRESIDUAL_IN + cols, dx, mask=mask)\n tl.store(DX + cols, dx, mask=mask)\n\n X += stride_x_row\n if HAS_DRESIDUAL:\n DRESIDUAL += stride_dres_row\n if STORE_DRESIDUAL:\n DRESIDUAL_IN += stride_dres_in_row\n if RECOMPUTE_OUTPUT:\n Y += stride_y_row\n DY += stride_dy_row\n DX += stride_dx_row\n if HAS_WEIGHT:\n tl.store(DW + row_block_id * N + cols, dw, mask=mask)\n if HAS_BIAS:\n tl.store(DB + row_block_id * N + cols, db, mask=mask)\n\n\ndef _layer_norm_bwd(\n dy,\n x,\n weight,\n bias,\n eps,\n mean,\n rstd,\n dresidual=None,\n has_residual=False,\n is_rms_norm=False,\n x_dtype=None,\n recompute_output=False,\n):\n M, N = x.shape\n # allocate output\n dx = torch.empty_like(x) if x_dtype is None else torch.empty(M, N, dtype=x_dtype, device=x.device)\n dresidual_in = torch.empty_like(x) if has_residual and dx.dtype != x.dtype else None\n y = torch.empty(M, N, dtype=dy.dtype, device=dy.device) if recompute_output else None\n\n # Less than 64KB per feature: enqueue fused kernel\n MAX_FUSED_SIZE = 65536 // x.element_size()\n BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))\n if N > BLOCK_N:\n raise RuntimeError(\"This layer norm doesn't support feature dim >= 64KB.\")\n sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count\n _dw = torch.empty((sm_count, N), dtype=torch.float32, device=weight.device) if weight is not None else None\n _db = torch.empty((sm_count, N), dtype=torch.float32, device=bias.device) if bias is not None else None\n rows_per_program = math.ceil(M / sm_count)\n grid = (sm_count,)\n with torch.cuda.device(x.device.index):\n _layer_norm_bwd_kernel[grid](\n x,\n weight,\n bias,\n y,\n dy,\n dx,\n _dw,\n _db,\n dresidual,\n dresidual_in,\n mean,\n rstd,\n x.stride(0),\n 0 if not recompute_output else y.stride(0),\n dy.stride(0),\n dx.stride(0),\n dresidual.stride(0) if dresidual is not None else 0,\n dresidual_in.stride(0) if dresidual_in is not None else 0,\n M,\n N,\n eps,\n rows_per_program,\n is_rms_norm,\n BLOCK_N,\n dresidual is not None,\n dresidual_in is not None,\n weight is not None,\n bias is not None,\n )\n dw = _dw.sum(0).to(weight.dtype) if weight is not None else None\n db = _db.sum(0).to(bias.dtype) if bias is not None else None\n # Don't need to compute dresidual_in separately in this case\n if has_residual and dx.dtype == x.dtype:\n dresidual_in = dx\n return (dx, dw, db, dresidual_in) if not recompute_output else (dx, dw, db, dresidual_in, y)\n", - "description_1": "Use triton language to implement a fused layer normalization and quantization kernel. The forward kernel (_layer_norm_fwd_quant_kernel) takes 18 parameters: pointers to input, output, weights, biases, residuals, mean, and rstd, strides for input, output, and residuals, number of columns, epsilon for numerical stability, and several compile-time constants. It computes the mean and variance, normalizes the input, applies weights and biases, and quantizes the output. The backward kernel (_layer_norm_bwd_kernel) takes 27 parameters: pointers to input, weights, biases, output, gradients, mean, rstd, strides, number of rows and columns, epsilon, rows per program, and several compile-time constants. It computes gradients for input, weights, biases, and residuals, and optionally recomputes the output.", - "description_2": "Use triton language to create a fused layer normalization and quantization operator with forward and backward passes. The forward pass normalizes input, applies linear transformation, and quantizes the result. The backward pass computes gradients for input, weights, biases, and optionally recomputes the output.", - "difficulty": 5 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.heuristics(\n {\n \"HAS_SMOOTHING\": lambda args: args[\"smoothing\"] > 0.0,\n }\n)\n@triton.jit\ndef cross_entropy_fwd_kernel(\n loss_ptr, # data ptrs\n lse_ptr,\n z_loss_ptr,\n logits_ptr,\n labels_ptr,\n smoothing,\n logit_scale,\n lse_square_scale,\n ignored_index,\n total_classes,\n class_start_idx, # Useful for tensor parallel when each rank only has a subset of classes\n n_cols, # shapes\n n_rows,\n logits_row_stride, # strides\n BLOCK_SIZE: tl.constexpr,\n HAS_SMOOTHING: tl.constexpr,\n SPLIT: tl.constexpr,\n):\n row_idx = tl.program_id(0)\n col_block_idx = tl.program_id(1)\n logits_ptr = logits_ptr + row_idx * logits_row_stride.to(tl.int64)\n col_offsets = col_block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)\n label_idx = tl.load(labels_ptr + row_idx)\n logits = tl.load(logits_ptr + col_offsets, mask=col_offsets < n_cols, other=-float(\"inf\")).to(\n tl.float32\n ) * logit_scale\n max_logits = tl.max(logits, 0)\n if HAS_SMOOTHING:\n sum_logits = tl.sum(tl.where(col_offsets < n_cols, logits, 0.0), 0)\n lse = tl.log(tl.sum(tl.exp(logits - max_logits), 0)) + max_logits\n tl.store(lse_ptr + col_block_idx * n_rows + row_idx, lse)\n if label_idx == ignored_index:\n loss = 0.0\n z_loss = 0.0\n else:\n label_idx -= class_start_idx\n if label_idx >= col_block_idx * BLOCK_SIZE and label_idx < min(\n n_cols, (col_block_idx + 1) * BLOCK_SIZE\n ):\n logits_label = tl.load(logits_ptr + label_idx) * logit_scale\n if HAS_SMOOTHING:\n loss = (\n (lse if not SPLIT else 0.0)\n - smoothing * sum_logits / total_classes\n - (1 - smoothing) * logits_label\n )\n else:\n loss = (lse if not SPLIT else 0.0) - logits_label\n else:\n if HAS_SMOOTHING:\n loss = smoothing * ((lse if not SPLIT else 0.0) - sum_logits / total_classes)\n else:\n loss = 0.0\n if not SPLIT:\n z_loss = lse_square_scale * lse * lse\n loss += z_loss\n else:\n z_loss = 0.0\n tl.store(loss_ptr + col_block_idx * n_rows + row_idx, loss)\n if not SPLIT:\n tl.store(z_loss_ptr + col_block_idx * n_rows + row_idx, z_loss)\n\n@triton.heuristics(\n {\n \"HAS_SMOOTHING\": lambda args: args[\"smoothing\"] > 0.0,\n }\n)\n@triton.jit\ndef cross_entropy_bwd_kernel(\n dlogits_ptr, # data ptrs\n dloss_ptr,\n logits_ptr,\n lse_ptr,\n labels_ptr,\n smoothing,\n logit_scale,\n lse_square_scale,\n ignored_index,\n total_classes,\n class_start_idx, # Useful for tensor parallel when each rank only has a subset of classes\n n_cols, # shapes\n logits_row_stride, # strides\n dlogits_row_stride,\n dloss_row_stride,\n BLOCK_SIZE: tl.constexpr,\n HAS_SMOOTHING: tl.constexpr,\n):\n row_idx = tl.program_id(0)\n col_block_idx = tl.program_id(1)\n logits_ptr = logits_ptr + row_idx * logits_row_stride.to(tl.int64)\n dlogits_ptr = dlogits_ptr + row_idx * dlogits_row_stride.to(tl.int64)\n col_offsets = col_block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)\n label_idx = tl.load(labels_ptr + row_idx)\n if label_idx != ignored_index:\n dloss = tl.load(dloss_ptr + row_idx * dloss_row_stride)\n else:\n dloss = 0.0\n logits = tl.load(logits_ptr + col_offsets, mask=col_offsets < n_cols, other=-float(\"inf\")).to(\n tl.float32\n ) * logit_scale\n lse = tl.load(lse_ptr + row_idx)\n probs = tl.exp(logits - lse)\n probs += 2.0 * lse_square_scale * lse * probs\n label_idx -= class_start_idx\n if HAS_SMOOTHING:\n smooth_negative = smoothing / total_classes\n probs = tl.where(col_offsets == label_idx, probs - (1 - smoothing), probs) - smooth_negative\n else:\n probs = tl.where(col_offsets == label_idx, probs - 1.0, probs)\n tl.store(dlogits_ptr + col_offsets, (dloss * logit_scale) * probs, mask=col_offsets < n_cols)\n\ndef cross_entropy_loss(\n logits: torch.Tensor,\n labels: torch.Tensor,\n label_smoothing: float = 0.0,\n logit_scale: float = 1.0,\n lse_square_scale: float = 0.0,\n ignored_index=-100,\n inplace_backward: bool = False,\n process_group=None,\n) -> Tuple[torch.Tensor, torch.Tensor]:\n return CrossEntropyLossFunction.apply(\n logits,\n labels,\n label_smoothing,\n logit_scale,\n lse_square_scale,\n ignored_index,\n inplace_backward,\n process_group,\n )\n", - "description_1": "Use triton language to implement forward and backward kernels for a cross-entropy loss function with optional label smoothing, scaling, and z-loss. The forward kernel computes the loss for each input batch by considering different splits and smoothing techniques, while the backward kernel calculates gradients based on precomputed losses. Both kernels are customized by constants like BLOCK_SIZE and whether smoothing is applied.", - "description_2": "Use triton language to implement cross-entropy loss with label smoothing and scaling.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.autotune(\n configs=[\n triton.Config({}, num_warps=1),\n triton.Config({}, num_warps=2),\n triton.Config({}, num_warps=4),\n triton.Config({}, num_warps=8),\n triton.Config({}, num_warps=16),\n triton.Config({}, num_warps=32),\n ],\n key=[\"N\", \"HAS_RESIDUAL\", \"STORE_RESIDUAL_OUT\", \"IS_RMS_NORM\", \"HAS_BIAS\"],\n)\n@triton.jit\ndef _layer_norm_fwd_1pass_kernel(\n X, O, Y, W, B, RESIDUAL, RESIDUAL_OUT, Mean, Rstd,\n stride_x_row, stride_y_row, stride_res_row, stride_res_out_row,\n N, eps, IS_RMS_NORM: tl.constexpr, BLOCK_N: tl.constexpr,\n HAS_RESIDUAL: tl.constexpr, STORE_RESIDUAL_OUT: tl.constexpr,\n HAS_WEIGHT: tl.constexpr, HAS_BIAS: tl.constexpr\n):\n row = tl.program_id(0)\n X += row * stride_x_row\n Y += row * stride_y_row\n O += row * stride_x_row\n if HAS_RESIDUAL:\n RESIDUAL += row * stride_res_row\n if STORE_RESIDUAL_OUT:\n RESIDUAL_OUT += row * stride_res_out_row\n cols = tl.arange(0, BLOCK_N)\n x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)\n if HAS_RESIDUAL:\n residual = tl.load(RESIDUAL + cols, mask=cols < N, other=0.0).to(tl.float32)\n x += residual\n if STORE_RESIDUAL_OUT:\n tl.store(RESIDUAL_OUT + cols, x, mask=cols < N)\n if not IS_RMS_NORM:\n mean = tl.sum(x, axis=0) / N\n tl.store(Mean + row, mean)\n xbar = tl.where(cols < N, x - mean, 0.0)\n var = tl.sum(xbar * xbar, axis=0) / N\n else:\n xbar = tl.where(cols < N, x, 0.0)\n var = tl.sum(xbar * xbar, axis=0) / N\n rstd = 1 / tl.sqrt(var + eps)\n tl.store(Rstd + row, rstd)\n mask = cols < N\n if HAS_WEIGHT:\n w = tl.load(W + cols, mask=mask).to(tl.float32)\n if HAS_BIAS:\n b = tl.load(B + cols, mask=mask).to(tl.float32)\n x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd\n y = x_hat * w if HAS_WEIGHT else x_hat\n if HAS_BIAS:\n y = y + b\n o = tl.load(O + cols, mask=cols < N, other=0.0).to(tl.float32)\n y = y * o * tl.sigmoid(o)\n tl.store(Y + cols, y, mask=mask)\n\ndef _layer_norm_fwd(\n x, o, weight, bias, eps, residual=None, out_dtype=None, residual_dtype=None, is_rms_norm=False\n):\n if residual is not None:\n residual_dtype = residual.dtype\n M, N = x.shape\n assert x.stride(-1) == 1\n if residual is not None:\n assert residual.stride(-1) == 1\n assert residual.shape == (M, N)\n if weight is not None:\n assert weight.shape == (N,)\n assert weight.stride(-1) == 1\n if bias is not None:\n assert bias.stride(-1) == 1\n assert bias.shape == (N,)\n y = torch.empty_like(x, dtype=x.dtype if out_dtype is None else out_dtype)\n assert y.stride(-1) == 1\n if residual is not None or (residual_dtype is not None and residual_dtype != x.dtype):\n residual_out = torch.empty(M, N, device=x.device, dtype=residual_dtype)\n assert residual_out.stride(-1) == 1\n else:\n residual_out = None\n mean = torch.empty((M,), dtype=torch.float32, device=\"cuda\") if not is_rms_norm else None\n rstd = torch.empty((M,), dtype=torch.float32, device=\"cuda\")\n MAX_FUSED_SIZE = 65536 // x.element_size()\n BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))\n if N > BLOCK_N:\n raise RuntimeError(\"This layer norm doesn't support feature dim >= 64KB.\")\n with torch.cuda.device(x.device.index):\n _layer_norm_fwd_1pass_kernel[(M,)](\n x, o, y, weight, bias, residual, residual_out, mean, rstd,\n x.stride(0), y.stride(0), residual.stride(0) if residual is not None else 0,\n residual_out.stride(0) if residual_out is not None else 0,\n N, eps, is_rms_norm, BLOCK_N, residual is not None, residual_out is not None,\n weight is not None, bias is not None,\n )\n return y, mean, rstd, residual_out if residual_out is not None else x\n\n@triton.autotune(\n configs=[\n triton.Config({}, num_warps=1),\n triton.Config({}, num_warps=2),\n triton.Config({}, num_warps=4),\n triton.Config({}, num_warps=8),\n triton.Config({}, num_warps=16),\n triton.Config({}, num_warps=32),\n ],\n key=[\"N\", \"HAS_DRESIDUAL\", \"STORE_DRESIDUAL\", \"IS_RMS_NORM\", \"HAS_BIAS\"],\n)\n@triton.heuristics({\"RECOMPUTE_OUTPUT\": lambda args: args[\"Y\"] is not None})\n@triton.jit\ndef _layer_norm_bwd_kernel(\n X, O, W, B, Y, DY, DX, DO, DW, DB, DRESIDUAL, DRESIDUAL_IN, Mean, Rstd,\n stride_x_row, stride_y_row, stride_dy_row, stride_dx_row, stride_dres_row,\n stride_dres_in_row, M, N, eps, rows_per_program, IS_RMS_NORM: tl.constexpr,\n BLOCK_N: tl.constexpr, HAS_DRESIDUAL: tl.constexpr, STORE_DRESIDUAL: tl.constexpr,\n HAS_WEIGHT: tl.constexpr, HAS_BIAS: tl.constexpr, RECOMPUTE_OUTPUT: tl.constexpr,\n):\n row_block_id = tl.program_id(0)\n row_start = row_block_id * rows_per_program\n cols = tl.arange(0, BLOCK_N)\n mask = cols < N\n X += row_start * stride_x_row\n O += row_start * stride_x_row\n if HAS_DRESIDUAL:\n DRESIDUAL += row_start * stride_dres_row\n if STORE_DRESIDUAL:\n DRESIDUAL_IN += row_start * stride_dres_in_row\n DY += row_start * stride_dy_row\n DX += row_start * stride_dx_row\n DO += row_start * stride_dx_row\n if RECOMPUTE_OUTPUT:\n Y += row_start * stride_y_row\n if HAS_WEIGHT:\n w = tl.load(W + cols, mask=mask).to(tl.float32)\n dw = tl.zeros((BLOCK_N,), dtype=tl.float32)\n if RECOMPUTE_OUTPUT and HAS_BIAS:\n b = tl.load(B + cols, mask=mask, other=0.0).to(tl.float32)\n if HAS_BIAS:\n db = tl.zeros((BLOCK_N,), dtype=tl.float32)\n row_end = min((row_block_id + 1) * rows_per_program, M)\n for row in range(row_start, row_end):\n x = tl.load(X + cols, mask=mask, other=0).to(tl.float32)\n o = tl.load(O + cols, mask=mask, other=0).to(tl.float32)\n dy = tl.load(DY + cols, mask=mask, other=0).to(tl.float32)\n if not IS_RMS_NORM:\n mean = tl.load(Mean + row)\n rstd = tl.load(Rstd + row)\n xhat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd\n xhat = tl.where(mask, xhat, 0.0)\n y = xhat * w if HAS_WEIGHT else xhat\n if HAS_BIAS:\n y = y + b\n if RECOMPUTE_OUTPUT:\n tl.store(Y + cols, y, mask=mask)\n sigmoid_o = tl.sigmoid(o)\n do = dy * y * (sigmoid_o + o * sigmoid_o * (1 - sigmoid_o))\n dy = dy * o * sigmoid_o\n wdy = dy\n if HAS_WEIGHT:\n wdy = dy * w\n dw += dy * xhat\n if HAS_BIAS:\n db += dy\n if not IS_RMS_NORM:\n c1 = tl.sum(xhat * wdy, axis=0) / N\n c2 = tl.sum(wdy, axis=0) / N\n dx = (wdy - (xhat * c1 + c2)) * rstd\n else:\n c1 = tl.sum(xhat * wdy, axis=0) / N\n dx = (wdy - xhat * c1) * rstd\n if HAS_DRESIDUAL:\n dres = tl.load(DRESIDUAL + cols, mask=mask, other=0).to(tl.float32)\n dx += dres\n if STORE_DRESIDUAL:\n tl.store(DRESIDUAL_IN + cols, dx, mask=mask)\n tl.store(DX + cols, dx, mask=mask)\n tl.store(DO + cols, do, mask=mask)\n X += stride_x_row\n O += stride_x_row\n if HAS_DRESIDUAL:\n DRESIDUAL += stride_dres_row\n if STORE_DRESIDUAL:\n DRESIDUAL_IN += stride_dres_in_row\n if RECOMPUTE_OUTPUT:\n Y += stride_y_row\n DY += stride_dy_row\n DX += stride_dx_row\n DO += stride_dx_row\n if HAS_WEIGHT:\n tl.store(DW + row_block_id * N + cols, dw, mask=mask)\n if HAS_BIAS:\n tl.store(DB + row_block_id * N + cols, db, mask=mask)\n\ndef _layer_norm_bwd(\n dy, x, o, weight, bias, eps, mean, rstd, dresidual=None,\n has_residual=False, is_rms_norm=False, x_dtype=None, recompute_output=False,\n):\n M, N = x.shape\n assert x.stride(-1) == 1\n assert dy.stride(-1) == 1\n assert dy.shape == (M, N)\n if dresidual is not None:\n assert dresidual.stride(-1) == 1\n assert dresidual.shape == (M, N)\n if weight is not None:\n assert weight.shape == (N,)\n assert weight.stride(-1) == 1\n if bias is not None:\n assert bias.stride(-1) == 1\n assert bias.shape == (N,)\n dx = (\n torch.empty_like(x)\n if x_dtype is None\n else torch.empty(M, N, dtype=x_dtype, device=x.device)\n )\n do = (\n torch.empty_like(o)\n if x_dtype is None\n else torch.empty(M, N, dtype=x_dtype, device=x.device)\n )\n dresidual_in = torch.empty_like(x) if has_residual and dx.dtype != x.dtype else None\n y = torch.empty(M, N, dtype=dy.dtype, device=dy.device) if recompute_output else None\n MAX_FUSED_SIZE = 65536 // x.element_size()\n BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))\n if N > BLOCK_N:\n raise RuntimeError(\"This layer norm doesn't support feature dim >= 64KB.\")\n sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count\n _dw = (\n torch.empty((sm_count, N), dtype=torch.float32, device=weight.device)\n if weight is not None\n else None\n )\n _db = (\n torch.empty((sm_count, N), dtype=torch.float32, device=bias.device)\n if bias is not None\n else None\n )\n rows_per_program = math.ceil(M / sm_count)\n grid = (sm_count,)\n with torch.cuda.device(x.device.index):\n _layer_norm_bwd_kernel[grid](\n x, o, weight, bias, y, dy, dx, do, _dw, _db, dresidual, dresidual_in,\n mean, rstd, x.stride(0), 0 if not recompute_output else y.stride(0),\n dy.stride(0), dx.stride(0), dresidual.stride(0) if dresidual is not None else 0,\n dresidual_in.stride(0) if dresidual_in is not None else 0,\n M, N, eps, rows_per_program, is_rms_norm, BLOCK_N,\n dresidual is not None, dresidual_in is not None, weight is not None, bias is not None,\n )\n dw = _dw.sum(0).to(weight.dtype) if weight is not None else None\n db = _db.sum(0).to(bias.dtype) if bias is not None else None\n if has_residual and dx.dtype == x.dtype:\n dresidual_in = dx\n return (dx, do, dw, db, dresidual_in) if not recompute_output else (dx, do, dw, db, dresidual_in, y)\n", - "description_1": "Use triton language to implement a fused layer normalization with Swish gate, supporting both forward and backward passes. The forward kernel (_layer_norm_fwd_1pass_kernel) takes 19 parameters: pointers to input, gate, output, weights, biases, residuals, mean, and rstd, along with strides, number of columns, epsilon, and several constexpr flags. The backward kernel (_layer_norm_bwd_kernel) takes 30 parameters: pointers to input, gate, weights, biases, output, gradients, mean, rstd, and several strides, along with dimensions, epsilon, rows per program, and several constexpr flags. The forward function (_layer_norm_fwd) and backward function (_layer_norm_bwd) handle the setup and invocation of these kernels.", - "description_2": "Use triton language to create a fused layer normalization with Swish gate, including both forward and backward operations, optimized for performance with configurable parameters and support for residual connections.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n\n@triton.autotune(\n configs=[\n triton.Config({}, num_warps=1),\n triton.Config({}, num_warps=2),\n triton.Config({}, num_warps=4),\n triton.Config({}, num_warps=8),\n triton.Config({}, num_warps=16),\n triton.Config({}, num_warps=32),\n ],\n key=[\"N\"],\n)\n@triton.jit\ndef _l2_norm_fwd_1pass_kernel(\n X, # pointer to the input\n Y, # pointer to the output\n stride_x_row, # how much to increase the pointer when moving by 1 row\n N, # number of columns in X\n eps, # epsilon to avoid division by zero\n BLOCK_N: tl.constexpr,\n):\n # Map the program id to the row of X and Y it should compute.\n row = tl.program_id(0)\n X += row * stride_x_row\n Y += row * stride_x_row\n # Compute mean and variance\n cols = tl.arange(0, BLOCK_N)\n x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)\n xbar = tl.where(cols < N, x, 0.0)\n var = tl.sum(xbar * xbar, axis=0)\n rstd = 1 / tl.sqrt(var + eps)\n # Normalize and apply linear transformation\n mask = cols < N\n y = x * rstd\n # Write output\n tl.store(Y + cols, y, mask=mask)\n\n\n@triton.autotune(\n configs=[\n triton.Config({}, num_warps=1),\n triton.Config({}, num_warps=2),\n triton.Config({}, num_warps=4),\n triton.Config({}, num_warps=8),\n triton.Config({}, num_warps=16),\n triton.Config({}, num_warps=32),\n ],\n key=[\"N\"],\n)\n@triton.jit\ndef _l2_norm_bwd_kernel(\n X, # pointer to the input\n DY, # pointer to the output gradient\n DX, # pointer to the input gradient\n stride_x_row, # how much to increase the pointer when moving by 1 row\n N, # number of columns in X\n eps, # epsilon to avoid division by zero\n BLOCK_N: tl.constexpr,\n):\n # Map the program id to the elements of X, DX, and DY it should compute.\n row = tl.program_id(0)\n X += row * stride_x_row\n DX += row * stride_x_row\n DY += row * stride_x_row\n cols = tl.arange(0, BLOCK_N)\n x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)\n x = tl.where(cols < N, x, 0.0)\n var = tl.sum(x * x)\n rstd = 1 / tl.sqrt(var + eps)\n mask = cols < N\n dy = tl.load(DY + cols, mask=cols < N, other=0.0).to(tl.float32)\n dy = tl.where(cols < N, dy, 0.0)\n dx = dy * rstd - tl.sum(dy * x) * (1 / (var+eps)) * rstd * x\n tl.store(DX + cols, dx, mask=mask)\n\n\ndef _l2_norm_fwd(\n x, eps=1e-6\n):\n x_shape_og = x.shape\n x = x.reshape(-1, x.shape[-1])\n if x.stride(-1) != 1:\n x = x.contiguous()\n M, N = x.shape\n assert x.stride(-1) == 1\n # allocate output\n y = torch.empty_like(x)\n assert y.stride(-1) == 1\n N = x.shape[-1]\n M = x.shape[0]\n # Less than 64KB per feature: enqueue fused kernel\n MAX_FUSED_SIZE = 65536 // x.element_size()\n BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))\n if N > BLOCK_N:\n raise RuntimeError(\n \"This layer norm doesn't support feature dim >= 64KB.\")\n # heuristics for number of warps\n with torch.cuda.device(x.device.index):\n _l2_norm_fwd_1pass_kernel[(M,)](\n x,\n y,\n x.stride(0),\n N,\n eps,\n BLOCK_N,\n )\n return y.reshape(x_shape_og)\n\n\ndef _l2_norm_bwd(\n x, dy, eps=1e-5,\n):\n x_shape_og = x.shape\n x = x.reshape(-1, dy.shape[-1])\n dy = dy.reshape(-1, dy.shape[-1])\n if dy.stride(-1) != 1:\n dy = dy.contiguous()\n assert dy.shape == x.shape\n # allocate output\n dx = torch.empty_like(x)\n N = x.shape[-1]\n M = x.shape[0]\n assert x.stride(-1) == 1\n assert dy.stride(-1) == 1\n # Less than 64KB per feature: enqueue fused kernel\n MAX_FUSED_SIZE = 65536 // x.element_size()\n BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))\n if N > BLOCK_N:\n raise RuntimeError(\n \"This layer norm doesn't support feature dim >= 64KB.\")\n # heuristics for number of warps\n with torch.cuda.device(x.device.index):\n _l2_norm_bwd_kernel[(M,)](\n x,\n dy,\n dx,\n x.stride(0),\n N,\n eps,\n BLOCK_N,\n )\n return dx.reshape(x_shape_og)\n", - "description_1": "Use triton language to implement forward and backward kernels for L2 normalization. The forward kernel (_l2_norm_fwd_1pass_kernel) takes 6 arguments: X (input pointer), Y (output pointer), stride_x_row (stride for rows in X), N (number of columns in X), eps (epsilon for numerical stability), and BLOCK_N (block size for computation). The backward kernel (_l2_norm_bwd_kernel) takes 7 arguments: X (input pointer), DY (output gradient pointer), DX (input gradient pointer), stride_x_row (stride for rows in X), N (number of columns in X), eps (epsilon for numerical stability), and BLOCK_N (block size for computation). Both kernels are decorated with @triton.jit for compilation and optimization with multiple configurations.", - "description_2": "Use triton language to define L2 normalization forward and backward kernels with configurable execution parameters for optimized computation.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.autotune(\n configs=[\n triton.Config({}, num_warps=1),\n triton.Config({}, num_warps=2),\n triton.Config({}, num_warps=4),\n triton.Config({}, num_warps=8),\n triton.Config({}, num_warps=16),\n triton.Config({}, num_warps=32),\n ],\n key=[\"N\", \"HAS_RESIDUAL\", \"STORE_RESIDUAL_OUT\", \"IS_RMS_NORM\", \"HAS_BIAS\"],\n)\n@triton.jit\ndef _layer_norm_fwd_1pass_kernel(\n X, Y, W, B, RESIDUAL, RESIDUAL_OUT, Mean, Rstd, stride_x_row, stride_y_row,\n stride_res_row, stride_res_out_row, N, G, eps, IS_RMS_NORM: tl.constexpr, BLOCK_N: tl.constexpr,\n HAS_RESIDUAL: tl.constexpr, STORE_RESIDUAL_OUT: tl.constexpr, HAS_WEIGHT: tl.constexpr, \n HAS_BIAS: tl.constexpr\n):\n row = tl.program_id(0)\n group = row % G\n X += row * stride_x_row\n Y += row * stride_y_row\n if HAS_RESIDUAL:\n RESIDUAL += row * stride_res_row\n if STORE_RESIDUAL_OUT:\n RESIDUAL_OUT += row * stride_res_out_row\n cols = tl.arange(0, BLOCK_N)\n x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)\n if HAS_RESIDUAL:\n residual = tl.load(RESIDUAL + cols, mask=cols < N, other=0.0).to(tl.float32)\n x += residual\n if STORE_RESIDUAL_OUT:\n tl.store(RESIDUAL_OUT + cols, x, mask=cols < N)\n if not IS_RMS_NORM:\n mean = tl.sum(x, axis=0) / N\n tl.store(Mean + row, mean)\n xbar = tl.where(cols < N, x - mean, 0.0)\n var = tl.sum(xbar * xbar, axis=0) / N\n else:\n xbar = tl.where(cols < N, x, 0.0)\n var = tl.sum(xbar * xbar, axis=0) / N\n rstd = 1 / tl.sqrt(var + eps)\n tl.store(Rstd + row, rstd)\n mask = cols < N\n if HAS_WEIGHT:\n w = tl.load(W + group * stride_x_row + cols, mask=mask).to(tl.float32)\n if HAS_BIAS:\n b = tl.load(B + group * stride_x_row + cols, mask=mask).to(tl.float32)\n x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd\n y = x_hat * w if HAS_WEIGHT else x_hat\n if HAS_BIAS:\n y = y + b\n tl.store(Y + cols, y, mask=mask)\n\ndef _layer_norm_fwd(\n x, weight, bias, eps, residual=None, out_dtype=None, residual_dtype=None, is_rms_norm=False, \n num_groups=1\n):\n if residual is not None:\n residual_dtype = residual.dtype\n M, N, G = *x.shape, num_groups\n if residual is not None:\n assert residual.shape == (M, N)\n if weight is not None:\n assert weight.shape == (G * N,)\n if bias is not None:\n assert bias.shape == (G * N,)\n y = torch.empty_like(x, dtype=x.dtype if out_dtype is None else out_dtype)\n if residual is not None or (residual_dtype is not None and residual_dtype != x.dtype):\n residual_out = torch.empty(M, N, device=x.device, dtype=residual_dtype)\n else:\n residual_out = None\n mean = torch.empty((M,), dtype=torch.float32, device=\"cuda\") if not is_rms_norm else None\n rstd = torch.empty((M,), dtype=torch.float32, device=\"cuda\")\n MAX_FUSED_SIZE = 65536 // x.element_size()\n BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))\n if N > BLOCK_N:\n raise RuntimeError(\"This layer norm doesn't support feature dim >= 64KB.\")\n with torch.cuda.device(x.device.index):\n _layer_norm_fwd_1pass_kernel[(M,)](\n x, y, weight, bias, residual, residual_out, mean, rstd, x.stride(0), y.stride(0), \n residual.stride(0) if residual is not None else 0, residual_out.stride(0) if residual_out is not None else 0, \n N, G, eps, is_rms_norm, BLOCK_N, residual is not None, residual_out is not None, \n weight is not None, bias is not None\n )\n return y, mean, rstd, residual_out if residual_out is not None else x\n\n@triton.autotune(\n configs=[\n triton.Config({}, num_warps=1),\n triton.Config({}, num_warps=2),\n triton.Config({}, num_warps=4),\n triton.Config({}, num_warps=8),\n triton.Config({}, num_warps=16),\n triton.Config({}, num_warps=32),\n ],\n key=[\"N\", \"HAS_DRESIDUAL\", \"STORE_DRESIDUAL\", \"IS_RMS_NORM\", \"HAS_BIAS\"],\n)\n@triton.heuristics({\"RECOMPUTE_OUTPUT\": lambda args: args[\"Y\"] is not None})\n@triton.jit\ndef _layer_norm_bwd_kernel(\n X, W, B, Y, DY, DX, DW, DB, DRESIDUAL, DRESIDUAL_IN, Mean, Rstd, stride_x_row, stride_y_row,\n stride_dy_row, stride_dx_row, stride_dres_row, stride_dres_in_row, M, N, G, rows_per_program,\n programs_per_group, IS_RMS_NORM: tl.constexpr, BLOCK_N: tl.constexpr, HAS_DRESIDUAL: tl.constexpr,\n STORE_DRESIDUAL: tl.constexpr, HAS_WEIGHT: tl.constexpr, HAS_BIAS: tl.constexpr, \n RECOMPUTE_OUTPUT: tl.constexpr\n):\n row_block_id = tl.program_id(0)\n group_id, program_id_in_group = row_block_id // programs_per_group, row_block_id % programs_per_group\n row_start = group_id + program_id_in_group * G * rows_per_program\n row_end = min(row_start + G * rows_per_program, M)\n cols = tl.arange(0, BLOCK_N)\n mask = cols < N\n if HAS_WEIGHT:\n w = tl.load(W + group_id * stride_x_row + cols, mask=mask).to(tl.float32)\n dw = tl.zeros((BLOCK_N,), dtype=tl.float32)\n if RECOMPUTE_OUTPUT and HAS_BIAS:\n b = tl.load(B + group_id * stride_x_row + cols, mask=mask, other=0.0).to(tl.float32)\n if HAS_BIAS:\n db = tl.zeros((BLOCK_N,), dtype=tl.float32)\n\n for row in range(row_start, row_end, G):\n x = tl.load(X + row * stride_x_row + cols, mask=mask, other=0).to(tl.float32)\n dy = tl.load(DY + row * stride_dy_row + cols, mask=mask, other=0).to(tl.float32)\n if not IS_RMS_NORM:\n mean = tl.load(Mean + row)\n rstd = tl.load(Rstd + row)\n xhat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd\n xhat = tl.where(mask, xhat, 0.0)\n if RECOMPUTE_OUTPUT:\n y = xhat * w if HAS_WEIGHT else xhat\n if HAS_BIAS:\n y = y + b\n tl.store(Y + row * stride_y_row + cols, y, mask=mask)\n wdy = dy\n if HAS_WEIGHT:\n wdy = dy * w\n dw += dy * xhat\n if HAS_BIAS:\n db += dy\n if not IS_RMS_NORM:\n c1 = tl.sum(xhat * wdy, axis=0) / N\n c2 = tl.sum(wdy, axis=0) / N\n dx = (wdy - (xhat * c1 + c2)) * rstd\n else:\n c1 = tl.sum(xhat * wdy, axis=0) / N\n dx = (wdy - xhat * c1) * rstd\n if HAS_DRESIDUAL:\n dres = tl.load(DRESIDUAL + row * stride_dres_row + cols, mask=mask, other=0).to(tl.float32)\n dx += dres\n if STORE_DRESIDUAL:\n tl.store(DRESIDUAL_IN + row * stride_dres_in_row + cols, dx, mask=mask)\n tl.store(DX + row * stride_dx_row + cols, dx, mask=mask)\n\n if HAS_WEIGHT:\n tl.store(DW + row_block_id * N + cols, dw, mask=mask)\n if HAS_BIAS:\n tl.store(DB + row_block_id * N + cols, db, mask=mask)\n\ndef _layer_norm_bwd(\n dy, x, weight, bias, eps, mean, rstd, dresidual=None, has_residual=False, is_rms_norm=False, \n x_dtype=None, recompute_output=False, num_groups=1\n):\n M, N, G = *x.shape, num_groups\n assert dy.shape == (M, N)\n if dresidual is not None:\n assert dresidual.shape == (M, N)\n if weight is not None:\n assert weight.shape == (G * N,)\n if bias is not None:\n assert bias.shape == (G * N,)\n dx = torch.empty_like(x) if x_dtype is None else torch.empty(M, N, dtype=x_dtype, device=x.device)\n dresidual_in = torch.empty_like(x) if has_residual and dx.dtype != x.dtype else None\n y = torch.empty(M, N, dtype=dy.dtype, device=dy.device) if recompute_output else None\n MAX_FUSED_SIZE = 65536 // x.element_size()\n BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))\n if N > BLOCK_N:\n raise RuntimeError(\"This layer norm doesn't support feature dim >= 64KB.\")\n S = triton.cdiv(torch.cuda.get_device_properties(x.device).multi_processor_count, G) * G\n dw = torch.empty((S, N), dtype=torch.float32, device=weight.device) if weight is not None else None\n db = torch.empty((S, N), dtype=torch.float32, device=bias.device) if bias is not None else None\n rows_per_program = triton.cdiv(M, S)\n programs_per_group = S // G\n grid = (S,)\n with torch.cuda.device(x.device.index):\n _layer_norm_bwd_kernel[grid](\n x, weight, bias, y, dy, dx, dw, db, dresidual, dresidual_in, mean, rstd, \n x.stride(0), 0 if not recompute_output else y.stride(0), dy.stride(0), dx.stride(0), \n dresidual.stride(0) if dresidual is not None else 0, dresidual_in.stride(0) if dresidual_in is not None else 0, \n M, N, G, rows_per_program, programs_per_group, is_rms_norm, BLOCK_N, \n dresidual is not None, dresidual_in is not None, weight is not None, bias is not None\n )\n dw = dw.view(G, -1, N).sum(1).to(weight).view_as(weight) if weight is not None else None\n db = db.view(G, -1, N).sum(1).to(bias).view_as(bias) if bias is not None else None\n if has_residual and dx.dtype == x.dtype:\n dresidual_in = dx\n return (dx, dw, db, dresidual_in) if not recompute_output else (dx, dw, db, dresidual_in, y)\n", - "description_1": "Use triton language to implement forward and backward pass kernels for layer normalization, supporting optional residual connections and RMS normalization. The forward kernel (_layer_norm_fwd_1pass_kernel) requires 18 parameters: pointers to input, output, weights, biases, residual, residual output, mean, rstd, strides, number of columns, groups, epsilon, and constexpr flags for RMS, block size, and presence of residual, weight, and bias. The backward kernel (_layer_norm_bwd_kernel) requires 29 parameters: pointers to input, weights, biases, output, gradients, partial sums, strides, matrix dimensions, group size, rows per program, programs per group, and constexpr flags for RMS, block size, presence of derivatives, recomputation, weight, and bias.", - "description_2": "Use triton language to create layer normalization kernels with forward and backward passes, allowing optional residuals and RMS normalization. Implement forward (_layer_norm_fwd_1pass_kernel) with 18 parameters including pointers and flags, and backward (_layer_norm_bwd_kernel) with 29 parameters for input, gradients, dimensions, and additional control flags.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef chunk_abc_fwd_kernel_h(\n k, v, z, h, h0, ht, s_k_h, s_k_t, s_k_d, s_v_h, s_v_t, s_v_d, s_h_h, s_h_t, s_h_d,\n T: tl.constexpr, K: tl.constexpr, V: tl.constexpr, BT: tl.constexpr, BK: tl.constexpr,\n BV: tl.constexpr, NT: tl.constexpr, NORMK: tl.constexpr, USE_INITIAL_STATE: tl.constexpr,\n STORE_FINAL_STATE: tl.constexpr\n):\n i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n b_h = tl.zeros([BK, BV], dtype=tl.float32)\n if USE_INITIAL_STATE:\n p_h = tl.make_block_ptr(h0 + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))\n b_h += tl.load(p_h, boundary_check=(0, 1)).to(tl.float32)\n if NORMK:\n p_z0 = tl.make_block_ptr(z + i_bh * s_k_h, (T * K,), (s_k_d,), (i_k * BK,), (BK,), (0,))\n else:\n p_z0 = tl.make_block_ptr(z + i_bh * s_v_h, (T * V,), (s_v_d,), (i_v * BV,), (BV,), (0,))\n b_zp = tl.load(p_z0).to(tl.float32)\n for i_t in range(NT):\n p_k = tl.make_block_ptr(k + i_bh * s_k_h, (K, T), (s_k_d, s_k_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1))\n p_v = tl.make_block_ptr(v + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))\n p_h = tl.make_block_ptr(h + i_bh * s_h_h + i_t * K * V, (K, V), (s_h_t, s_h_d), (i_k * BK, i_v * BV), (BK, BV), (1, 0))\n tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1))\n b_k = tl.load(p_k, boundary_check=(0, 1))\n b_v = tl.load(p_v, boundary_check=(0, 1))\n if NORMK:\n p_zc = tl.make_block_ptr(z + i_bh * s_k_h, (T * K,), (s_k_d,), ((i_t * BT + BT - 1) * K + i_k * BK,), (BK,), (0,))\n b_zc = tl.load(p_zc, boundary_check=(0,))\n b_r, b_zp = tl.exp(b_zp - b_zc), b_zc\n b_h = b_h * b_r[:, None]\n b_k = tl.exp(b_k - b_zc[:, None]).to(b_k.dtype)\n else:\n p_zc = tl.make_block_ptr(z + i_bh * s_v_h, (T * V,), (s_v_d,), ((i_t * BT + BT - 1) * V + i_v * BV,), (BV,), (0,))\n b_zc = tl.load(p_zc, boundary_check=(0,))\n b_r, b_zp = tl.exp(b_zp - b_zc), b_zc\n b_h = b_h * b_r[None, :]\n b_v = tl.exp(b_v - b_zc[None, :]).to(b_v.dtype)\n b_h += tl.dot(b_k, b_v, allow_tf32=False)\n if STORE_FINAL_STATE:\n p_h = tl.make_block_ptr(ht + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))\n tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1))\n\n# Function chunk_abc_fwd_kernel_h parameters and meaning:\n# k, v, z, h, h0, ht: input/output tensors\n# s_k_h, s_k_t, s_k_d, s_v_h, s_v_t, s_v_d, s_h_h, s_h_t, s_h_d: strides for different tensors\n# T, K, V, BT, BK, BV, NT: constexprs representing various dimensions and tile sizes\n# NORMK, USE_INITIAL_STATE, STORE_FINAL_STATE: flags for conditional execution\n\n@triton.jit\ndef chunk_abc_fwd_kernel_intra_K(\n v, z, o, A, s_v_h, s_v_t, s_v_d, T: tl.constexpr, V: tl.constexpr,\n BT: tl.constexpr, BC: tl.constexpr, BV: tl.constexpr, NC: tl.constexpr\n):\n i_v, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n i_t, i_i = i_c // NC, i_c % NC\n p_z = tl.make_block_ptr(z + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0))\n p_zn = tl.make_block_ptr(z + i_bh * s_v_h, (T * V,), (s_v_d,), ((i_t * BT + i_i * BC) * V + i_v * BV,), (BV,), (0,))\n b_zn = tl.load(p_zn, boundary_check=(0,))\n b_o = tl.zeros([BC, BV], dtype=tl.float32)\n for i_j in range(0, i_i):\n p_A = tl.make_block_ptr(A + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0))\n p_v = tl.make_block_ptr(v + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT + i_j * BC, i_v * BV), (BC, BV), (1, 0))\n b_v = tl.load(p_v, boundary_check=(0, 1))\n b_A = tl.load(p_A, boundary_check=(0, 1))\n b_o += tl.dot(b_A, tl.exp(b_v - b_zn[None, :]).to(b_v.dtype), allow_tf32=False)\n b_z = tl.load(p_z, boundary_check=(0, 1))\n b_o *= tl.exp(b_zn[None, :] - b_z)\n o_i = tl.arange(0, BC)\n o_A = i_bh * T * BT + (i_t * BT + i_i * BC + tl.arange(0, BC)) * BT + i_i * BC\n m_A = (i_t * BT + i_i * BC + tl.arange(0, BC)) < T\n for j in range(0, BC):\n p_v = tl.make_block_ptr(v + i_bh * s_v_h, (T * V,), (1,), ((i_t * BT + i_i * BC + j) * V + i_v * BV,), (BV,), (0,))\n b_A = tl.load(A + o_A + j, mask=m_A, other=0)\n b_v = tl.load(p_v, boundary_check=(0,)).to(tl.float32)\n m_i = o_i[:, None] >= j\n b_o += tl.where(m_i, b_A[:, None] * tl.exp(b_v[None, :] - b_z), 0)\n p_o = tl.make_block_ptr(o + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0))\n tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))\n\n# Function chunk_abc_fwd_kernel_intra_K parameters and meaning:\n# v, z, o, A: input/output tensors\n# s_v_h, s_v_t, s_v_d: strides for different tensors\n# T, V, BT, BC, BV, NC: constexprs representing various dimensions and tile sizes\n\n@triton.jit\ndef chunk_abc_fwd_kernel_K(\n q, k, z, h, o, A, s_k_h, s_k_t, s_k_d, s_v_h, s_v_t, s_v_d, s_h_h, s_h_t, s_h_d, scale,\n T: tl.constexpr, K: tl.constexpr, V: tl.constexpr, BT: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr\n):\n i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n i_p = tl.maximum(i_t * BT - 1, 0)\n o_i = tl.arange(0, BT)\n m_s = o_i[:, None] >= o_i[None, :]\n b_o = tl.zeros([BT, BV], dtype=tl.float32)\n b_A = tl.zeros([BT, BT], dtype=tl.float32)\n for i_k in range(tl.cdiv(K, BK)):\n p_q = tl.make_block_ptr(q + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))\n p_k = tl.make_block_ptr(k + i_bh * s_k_h, (K, T), (s_k_d, s_k_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1))\n p_h = tl.make_block_ptr(h + i_bh * s_h_h + i_t * K * V, (K, V), (s_h_t, s_h_d), (i_k * BK, i_v * BV), (BK, BV), (1, 0))\n b_q = tl.load(p_q, boundary_check=(0, 1))\n b_q = (b_q * scale).to(b_q.dtype)\n b_k = tl.load(p_k, boundary_check=(0, 1))\n b_h = tl.load(p_h, boundary_check=(0, 1))\n b_o += tl.dot(b_q, b_h, allow_tf32=False)\n b_A += tl.dot(b_q, b_k, allow_tf32=False)\n p_z = tl.make_block_ptr(z + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))\n p_o = tl.make_block_ptr(o + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))\n b_z = tl.load(p_z, boundary_check=(0, 1))\n p_zp = tl.make_block_ptr(z + i_bh * s_v_h, (T * V,), (s_v_d,), (i_p * V + i_v * BV,), (BV,), (0,))\n b_zp = tl.load(p_zp, boundary_check=(0,))\n b_o = b_o * tl.exp(b_zp[None, :] - b_z)\n tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))\n p_A = tl.make_block_ptr(A + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))\n b_A = tl.where(m_s, b_A, 0)\n if i_v == 0:\n tl.store(p_A, b_A.to(p_A.dtype.element_ty), boundary_check=(0, 1))\n\n# Function chunk_abc_fwd_kernel_K parameters and meaning:\n# q, k, z, h, o, A: input/output tensors\n# s_k_h, s_k_t, s_k_d, s_v_h, s_v_t, s_v_d, s_h_h, s_h_t, s_h_d: strides for different tensors\n# scale: scaling factor for query tensor\n# T, K, V, BT, BK, BV: constexprs representing various dimensions and tile sizes\n\ndef chunk_abc(\n q: torch.Tensor,\n k: torch.Tensor,\n v: torch.Tensor,\n s: torch.Tensor,\n initial_state: Optional[Tuple[torch.Tensor]] = None,\n output_final_state: Optional[bool] = False\n) -> Tuple[torch.Tensor, Tuple[torch.Tensor]]:\n if initial_state is not None:\n initial_state = tuple(i.detach() for i in initial_state)\n ov, final_state = ChunkABCFunction.apply(q, k, v, s, initial_state, output_final_state)\n return ov, final_state\n\n# Function chunk_abc parameters and meaning:\n# q, k, v, s: input tensors for query, key, value, and some state\n# initial_state: optional initial state for calculations\n# output_final_state: flag to determine if the final state should be output\n", - "description_1": "Use triton language to implement forward kernels for the ABC algorithm handling specific matrix operations on inputs such as query, key, value, and state tensors, optionally managing initial and final states with triton's parallel execution.", - "description_2": "Use triton language to execute matrix operations with potential initial state input and produce a final state output for tensor computations.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n\n@triton.autotune(\n configs=[\n triton.Config({'BS': 16}, num_warps=2),\n triton.Config({'BS': 16}, num_warps=4),\n triton.Config({'BS': 16}, num_warps=8),\n triton.Config({'BS': 32}, num_warps=2),\n triton.Config({'BS': 32}, num_warps=4),\n triton.Config({'BS': 32}, num_warps=8),\n triton.Config({'BS': 64}, num_warps=2),\n triton.Config({'BS': 64}, num_warps=4),\n triton.Config({'BS': 64}, num_warps=8),\n ],\n key=['S']\n)\n@triton.jit\ndef chunk_gated_abc_fwd_kernel_cum(\n s,\n o,\n s_s_h,\n s_s_t,\n s_s_d,\n T: tl.constexpr,\n S: tl.constexpr,\n BT: tl.constexpr,\n BS: tl.constexpr,\n):\n i_s, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n o_i = tl.arange(0, BT)\n m_s = tl.where(o_i[:, None] >= o_i[None, :], 1., 0.).to(tl.float32)\n\n p_s = tl.make_block_ptr(s + i_bh * s_s_h, (T, S), (s_s_t, s_s_d), (i_t * BT, i_s * BS), (BT, BS), (1, 0))\n p_o = tl.make_block_ptr(o + i_bh * s_s_h, (T, S), (s_s_t, s_s_d), (i_t * BT, i_s * BS), (BT, BS), (1, 0))\n # [BT, BS]\n b_s = tl.load(p_s, boundary_check=(0, 1)).to(tl.float32)\n b_o = tl.dot(m_s, b_s, allow_tf32=False)\n tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))\n\n\ndef fwd_pre(g, B, H, T, S, BT):\n NT = triton.cdiv(T, BT)\n g_org, g = g, torch.empty_like(g, dtype=torch.float)\n def grid(meta): return (triton.cdiv(meta['S'], meta['BS']), NT, B * H)\n # keep cummulative normalizer in fp32\n # this kernel is equivalent to\n # g = g.view(B, H, NT, BT, -1).cumsum(-2).view(B, H, T, -1)\n chunk_gated_abc_fwd_kernel_cum[grid](\n g_org, g,\n g.stride(1), g.stride(2), g.stride(3),\n T=T, S=S, BT=BT\n )\n return g\n", - "description_1": "Use triton language to define and use a kernel 'chunk_gated_abc_fwd_kernel_cum'. This kernel computes cumulative sums with specific constraints and writes them to an output tensor. It is decorated with @triton.autotune and @triton.jit, accepting 8 parameters: three tensors and five integers. The function 'fwd_pre' calls this kernel with specific grid settings, organizing the input tensor 'g' into cumulative format across dimensions.", - "description_2": "Use triton language to implement a cumulative sum kernel with input constraints and invoke it through a helper function that organizes input data into a specific cumulative format.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef fused_recurrent_gated_abc_inference_kernel(\n q, k, v, s, g, o, hk, hv, s_k_h, s_v_h, s_m_h, scale,\n K: tl.constexpr, V: tl.constexpr, M: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr\n):\n i_bh = tl.program_id(0)\n b_s = tl.load(s + i_bh * s_m_h + tl.arange(0, M))\n b_g = tl.load(g + i_bh * s_m_h + tl.arange(0, M)).to(tl.float32)\n b_g = tl.exp(b_g)\n\n b_ok = tl.zeros([M], dtype=tl.float32)\n for i_k in range(tl.cdiv(K, BK)):\n p_hk0 = hk + i_bh * K * M + (i_k * BK + tl.arange(0, BK)[None, :]) * M + tl.arange(0, M)[:, None]\n mask = (i_k * BK + tl.arange(0, BK)) < K\n b_hk = tl.load(p_hk0, mask=mask[None, :], other=0).to(tl.float32)\n b_q = tl.load(q + i_bh * s_k_h + i_k * BK + tl.arange(0, BK), mask=mask, other=0).to(tl.float32) * scale\n b_k = tl.load(k + i_bh * s_k_h + i_k * BK + tl.arange(0, BK), mask=mask, other=0).to(tl.float32)\n b_hk = b_hk * b_g[:, None] + b_k[None, :] * b_s[:, None]\n b_ok += tl.sum(b_hk * b_q[None, :], axis=1)\n\n p_hkt = hk + i_bh * K * M + (i_k * BK + tl.arange(0, BK)[None, :]) * M + tl.arange(0, M)[:, None]\n tl.store(p_hkt, b_hk.to(p_hkt.dtype.element_ty), mask=mask[None, :])\n\n b_qv = tl.softmax(b_ok)\n for i_v in range(tl.cdiv(V, BV)):\n p_hv0 = hv + i_bh * M * V + tl.arange(0, M)[None, :] * V + (i_v * BV + tl.arange(0, BV)[:, None])\n mask = (i_v * BV + tl.arange(0, BV)) < V\n b_hv = tl.load(p_hv0, mask=mask[:, None], other=0).to(tl.float32)\n b_v = tl.load(v + i_bh * s_v_h + i_v * BV + tl.arange(0, BV), mask=mask, other=0).to(tl.float32)\n b_hv = b_hv * b_g[None, :] + b_s[None, :] * b_v[:, None]\n b_ov = tl.sum(b_hv * b_qv[None, :], axis=1)\n\n tl.store(o + i_bh * s_v_h + i_v * BV + tl.arange(0, BV), b_ov.to(o.dtype.element_ty), mask=mask)\n\n p_hvt = hv + i_bh * M * V + tl.arange(0, M)[None, :] * V + (i_v * BV + tl.arange(0, BV)[:, None])\n tl.store(p_hvt, b_hv.to(p_hvt.dtype.element_ty), mask=mask[:, None])\n\n\nclass FusedRecurrentGatedABCFunction(torch.autograd.Function):\n\n @staticmethod\n def forward(ctx, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, s: torch.Tensor, g: torch.Tensor, scale: Optional[float] = None, initial_state: Optional[Tuple[torch.Tensor]] = None, output_final_state: bool = False, reverse: bool = False, inference_mode: bool = False) -> Tuple[torch.Tensor, Tuple[torch.Tensor]]:\n B, H, T, K, V, M = *q.shape, v.shape[-1], s.shape[-1]\n if scale is None:\n scale = K ** -0.5\n\n BK, BV, BM = min(K, 64), min(V, 64), min(M, 64)\n NK, NV, NM = triton.cdiv(K, BK), triton.cdiv(V, BV), triton.cdiv(M, BM)\n num_warps = 1\n num_stages = 1\n\n if initial_state is None:\n initial_state = (None, None)\n final_state = (None, None)\n if output_final_state:\n final_state = initial_state if inference_mode else (q.new_empty(B, H, K, M), q.new_empty(B, H, M, V))\n\n if inference_mode:\n BK, BV, BM = min(K, 64), min(V, 64), min(M, 64)\n NK, NV, NM = triton.cdiv(K, BK), triton.cdiv(V, BV), triton.cdiv(M, BM)\n\n o = torch.empty_like(v)\n grid = (B * H,)\n fused_recurrent_gated_abc_inference_kernel[grid](\n q, k, v, s, g, o, initial_state[0], initial_state[1],\n k.stride(1),\n v.stride(1),\n s.stride(1),\n scale=scale,\n K=K, V=V, M=M, BK=BK, BV=BV,\n num_warps=num_warps,\n num_stages=num_stages\n )\n return o, final_state\n\ndef fused_recurrent_gated_abc(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, s: torch.Tensor, g: Optional[torch.Tensor] = None, scale: Optional[int] = None, initial_state: Optional[Tuple[torch.Tensor]] = None, output_final_state: Optional[bool] = False) -> Tuple[torch.Tensor, torch.Tensor]:\n if g is None:\n z = s.float().logcumsumexp(2)\n g = torch.cat((z[:, :, :1], z[:, :, :-1]), 2) - z\n s = torch.exp(s - z).to(k.dtype)\n if scale is None:\n scale = q.shape[-1] ** -0.5\n inference_mode = q.shape[2] == 1 and not q.requires_grad\n ov, final_state = FusedRecurrentGatedABCFunction.apply(\n q, k, v, s, g, scale, initial_state, output_final_state, False, inference_mode\n )\n return ov, final_state\n", - "description_1": "Use triton language to define three kernels fused_recurrent_gated_abc_inference_kernel, fused_recurrent_gated_abc_fwd_kernel, and fused_recurrent_gated_abc_bwd_kernel, which perform the computations required for a gated recurrent unit (GRU)-like model. The kernels interact with tensors representing queries, keys, values, and several other intermediate matrices. The function fused_recurrent_gated_abc manages the input/output tensors and invokes the appropriate kernel based on mode (forward/inference).", - "description_2": "Use triton language to create kernels for a GRU-like model using matrices for queries, keys, and values. Manage tensor inputs and outputs.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\nfrom torch.cuda.amp import custom_bwd, custom_fwd\n\n@triton.jit\ndef fused_chunk_based_fwd_kernel(\n q, k, v, o, z,\n s_qk_h, s_qk_t, s_qk_d,\n s_vo_h, s_vo_t, s_vo_d,\n B, H, T, scale,\n BT: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr,\n DK: tl.constexpr, DV: tl.constexpr\n):\n i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n o_i = tl.arange(0, BT)\n m_s = o_i[:, None] >= o_i[None, :]\n\n b_h_0o = tl.zeros([BV], dtype=tl.float32)\n b_h_1o = tl.zeros([BK, BV], dtype=tl.float32)\n b_h_2o = tl.zeros([BK*BK, BV], dtype=tl.float32)\n\n p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (T, DK),\n (s_qk_t, s_qk_d), (0, i_k * BK), (BT, BK), (1, 0))\n p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (DK, T),\n (s_qk_d, s_qk_t), (i_k * BK, 0), (BK, BT), (0, 1))\n p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, DV),\n (s_vo_t, s_vo_d), (0, i_v * BV), (BT, BV), (1, 0))\n p_o = tl.make_block_ptr(o + (i_bh + i_k*B*H) * s_vo_h, (T, DV),\n (s_vo_t, s_vo_d), (0, i_v * BV), (BT, BV), (1, 0))\n\n p_z = z + (i_bh + i_k * B * H) * T + tl.arange(0, BT)\n k_2o = tl.zeros([1, BK * BK], dtype=tl.float32)\n k_1o = tl.zeros([1, BK], dtype=tl.float32)\n k_0o = 0\n\n for i in range(0, tl.cdiv(T, BT)):\n b_k = tl.load(p_k, boundary_check=(0, 1))\n b_k_2o = b_k[:, None, :] * b_k[None, :, :]\n b_k_2o = tl.reshape(b_k_2o, [BK * BK, BT]).to(b_k.dtype)\n b_v = tl.load(p_v, boundary_check=(0, 1))\n b_q = (tl.load(p_q, boundary_check=(0, 1)) * scale).to(b_k.dtype)\n b_o = tl.zeros([BT, BV], dtype=tl.float32)\n b_z = tl.zeros([BT], dtype=tl.float32)\n\n b_o += b_h_0o\n b_z += k_0o\n b_o += tl.dot(b_q, b_h_1o.to(b_q.dtype), allow_tf32=False)\n b_z += tl.sum(b_q * k_1o, axis=1)\n b_q_2o = b_q[:, :, None] * b_q[:, None, :]\n b_q_2o = tl.reshape(b_q_2o, [BT, BK * BK]).to(b_k.dtype)\n b_o += tl.dot(b_q_2o, b_h_2o.to(b_q_2o.dtype), allow_tf32=False) * 0.5\n b_z += tl.sum(b_q_2o * k_2o, axis=1) * 0.5\n\n k_1o += tl.sum(b_k, axis=1)[None, :]\n k_2o += tl.sum(b_k_2o, axis=1)[None, :]\n k_0o += BT\n\n b_s = tl.dot(b_q, b_k, allow_tf32=False)\n b_s = 1 + b_s + 0.5 * b_s * b_s\n b_s = tl.where(m_s, b_s, 0)\n b_z += tl.sum(b_s, axis=1)\n b_o += tl.dot(b_s.to(b_q.dtype), b_v, allow_tf32=False)\n tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))\n tl.store(p_z, b_z.to(p_z.dtype.element_ty),\n mask=(i * BT + tl.arange(0, BT)) < T)\n\n b_h_2o = b_h_2o + tl.dot(b_k_2o.to(b_v.dtype), b_v, allow_tf32=False)\n b_h_1o = b_h_1o + tl.dot(b_k, b_v, allow_tf32=False)\n b_h_0o = b_h_0o + tl.sum(b_v, axis=0)\n\n p_q = tl.advance(p_q, (BT, 0))\n p_k = tl.advance(p_k, (0, BT))\n p_v = tl.advance(p_v, (BT, 0))\n p_o = tl.advance(p_o, (BT, 0))\n p_z += BT\n\n\n@triton.jit\ndef fused_chunk_based_bwd_kernel(\n q, k, v, do, dz, dq, dk, dv,\n s_qk_h, s_qk_t, s_qk_d,\n s_vo_h, s_vo_t, s_vo_d,\n B, H, T, scale,\n BT: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr,\n DK: tl.constexpr, DV: tl.constexpr\n):\n i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n o_i = tl.arange(0, BT)\n m_s = o_i[:, None] >= o_i[None, :]\n\n b_h_1o = tl.zeros([BV, BK], dtype=tl.float32)\n b_h_2o = tl.zeros([BV, BK*BK], dtype=tl.float32)\n\n k_1o = tl.zeros([1, BK], dtype=tl.float32)\n k_2o = tl.zeros([1, BK * BK], dtype=tl.float32)\n\n for i in range(0, tl.cdiv(T, BT)):\n p_q = tl.make_block_ptr(\n q + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (i * BT, i_k * BK), (BT, BK), (1, 0))\n p_k = tl.make_block_ptr(\n k + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (i * BT, i_k * BK), (BT, BK), (1, 0))\n p_v = tl.make_block_ptr(\n v + i_bh * s_vo_h, (DV, T), (s_vo_d, s_vo_t), (i_v * BV, i * BT), (BV, BT), (0, 1))\n p_do = tl.make_block_ptr(\n do + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (i * BT, i_v * BV), (BT, BV), (1, 0))\n p_dq = tl.make_block_ptr(dq + (i_bh + i_v*B*H) * s_qk_h,\n (T, DK), (s_qk_t, s_qk_d), (i*BT, i_k*BK), (BT, BK), (1, 0))\n p_dz = dz + (i_bh) * T + tl.arange(0, BT) + i * BT\n b_dq = tl.zeros([BT, BK], dtype=tl.float32)\n\n b_dz = tl.load(p_dz, mask=(tl.arange(0, BT) + i * BT) < T)\n b_q = tl.load(p_q, boundary_check=(0, 1))\n b_q = (b_q * scale).to(b_q.dtype)\n b_do = tl.load(p_do, boundary_check=(0, 1)).to(b_q.dtype)\n b_k = tl.load(p_k, boundary_check=(0, 1))\n b_v = tl.load(p_v, boundary_check=(0, 1))\n\n b_dq += tl.dot(b_do, (b_h_1o).to(b_do.dtype), allow_tf32=False)\n if i_v == 0:\n b_dq += b_dz[:, None] * k_1o\n b_dq_2o = tl.dot(b_do, (b_h_2o).to(b_do.dtype), allow_tf32=False) * 0.5\n if i_v == 0:\n b_dq_2o += (b_dz[:, None] * k_2o) * 0.5\n b_dq_2o = tl.reshape(b_dq_2o, [BT, BK, BK])\n b_dq += tl.sum(b_dq_2o * b_q[:, :, None], axis=1)\n b_dq += tl.sum(b_dq_2o * b_q[:, None, :], axis=2)\n b_dq *= scale\n\n b_ds = tl.dot(b_do, b_v, allow_tf32=False)\n if i_v == 0:\n b_ds += b_dz[:, None]\n b_ds = tl.where(m_s, b_ds, 0) * scale\n b_s = tl.dot(b_q, tl.trans(b_k), allow_tf32=False)\n b_s = tl.where(m_s, b_s, 0)\n b_dq += tl.dot((b_ds * (1 + b_s)).to(b_q.dtype), b_k, allow_tf32=False)\n\n tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))\n\n b_k_2o = b_k[:, :, None] * b_k[:, None, :]\n b_k_2o = tl.reshape(b_k_2o, [BT, BK * BK]).to(b_k.dtype)\n b_h_2o = b_h_2o + tl.dot(b_v, b_k_2o.to(b_v.dtype), allow_tf32=False)\n b_h_1o = b_h_1o + tl.dot(b_v, b_k, allow_tf32=False)\n\n if i_v == 0:\n k_1o += tl.sum(b_k, axis=0)[None, :]\n k_2o += tl.sum(b_k_2o, axis=0)[None, :]\n\n tl.debug_barrier()\n b_h_1o = None\n b_h_2o = None\n\n b_dh_1o = tl.zeros([BK, BV], dtype=tl.float32)\n b_dh_2o = tl.zeros([BK*BK, BV], dtype=tl.float32)\n b_dh_0o = tl.zeros([BV], dtype=tl.float32)\n m_s = tl.arange(0, BT)[:, None] <= tl.arange(0, BT)[None, :]\n\n dq_1o = tl.zeros([1, BK], dtype=tl.float32)\n dq_2o = tl.zeros([BK * BK, 1], dtype=tl.float32)\n\n for i in range(tl.cdiv(T, BT) * BT - BT, -BT, -BT):\n p_q = tl.make_block_ptr(\n q + i_bh * s_qk_h, (DK, T), (s_qk_d, s_qk_t), (i_k * BK, i), (BK, BT), (0, 1))\n p_k = tl.make_block_ptr(\n k + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (i, i_k * BK), (BT, BK), (1, 0))\n p_v = tl.make_block_ptr(\n v + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (i, i_v * BV), (BT, BV), (1, 0))\n p_do = tl.make_block_ptr(\n do + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (i, i_v * BV), (BT, BV), (1, 0))\n p_dk = tl.make_block_ptr(dk + (i_bh+i_v*B*H) * s_qk_h, (T, DK),\n (s_qk_t, s_qk_d), (i, i_k*BK), (BT, BK), (1, 0))\n p_dv = tl.make_block_ptr(dv + (i_bh+i_k*B*H) * s_vo_h, (T, DV),\n (s_vo_t, s_vo_d), (i, i_v*BV), (BT, BV), (1, 0))\n p_dz = dz + (i_bh) * T + tl.arange(0, BT) + i\n\n b_dk = tl.zeros([BT, BK], dtype=tl.float32)\n b_dv = tl.zeros([BT, BV], dtype=tl.float32)\n\n b_dz = tl.load(p_dz, mask=(tl.arange(0, BT)+i) < T)\n b_q = tl.load(p_q, boundary_check=(0, 1))\n b_k = tl.load(p_k, boundary_check=(0, 1))\n b_v = tl.load(p_v, boundary_check=(0, 1))\n b_do = tl.load(p_do, boundary_check=(0, 1)).to(b_q.dtype)\n b_q = (b_q * scale).to(b_k.dtype)\n\n b_ds = tl.dot(b_v, tl.trans(b_do), allow_tf32=False)\n if i_v == 0:\n b_ds += b_dz[None, :]\n b_ds = tl.where(m_s, b_ds, 0)\n b_s = tl.dot(b_k, b_q, allow_tf32=False)\n b_s2 = 1 + b_s + 0.5 * b_s * b_s\n b_s = tl.where(m_s, b_s, 0)\n b_s2 = tl.where(m_s, b_s2, 0)\n b_ds *= (1+b_s)\n\n b_dk += tl.dot(b_ds.to(b_k.dtype), tl.trans(b_q), allow_tf32=False)\n b_dv += tl.dot(b_s2.to(b_do.dtype), b_do, allow_tf32=False)\n\n b_k_2o = b_k[:, :, None] * b_k[:, None, :]\n b_k_2o = tl.reshape(b_k_2o, [BT, BK * BK]).to(b_k.dtype)\n\n b_dv += tl.dot(b_k, b_dh_1o.to(b_k.dtype), allow_tf32=False)\n b_dv += tl.dot(b_k_2o, b_dh_2o.to(b_k.dtype), allow_tf32=False)\n b_dv += b_dh_0o\n\n b_dk += tl.dot(b_v, tl.trans(b_dh_1o).to(b_k.dtype), allow_tf32=False)\n\n if i_v == 0:\n b_dk += dq_1o\n\n b_dk_2o = tl.dot(b_dh_2o.to(b_k.dtype),\n tl.trans(b_v), allow_tf32=False)\n if i_v == 0:\n b_dk_2o += dq_2o\n b_dk_2o = tl.reshape(b_dk_2o, [BK, BK, BT])\n b_k_fp32 = tl.trans(b_k.to(tl.float32))\n b_dk2 = tl.sum(b_dk_2o * b_k_fp32[:, None, :], axis=0)\n b_dk2 += tl.sum(b_dk_2o * b_k_fp32[None, :, :], axis=1)\n b_dk += tl.trans(b_dk2)\n\n b_dh_0o += tl.sum(b_do, axis=0)\n b_dh_1o = b_dh_1o + tl.dot(b_q, b_do, allow_tf32=False)\n b_q_2o = b_q[None, :, :] * b_q[:, None, :]\n b_q_2o = tl.reshape(b_q_2o, [BK * BK, BT]).to(b_k.dtype)\n b_dh_2o = b_dh_2o + tl.dot(b_q_2o, b_do, allow_tf32=False) * 0.5\n\n if i_v == 0:\n dq_1o += (tl.sum(b_dz[None, :] * b_q, axis=1))[None, :]\n dq_2o += (tl.sum(b_dz[None, :] * b_q_2o, axis=1) * 0.5)[:, None]\n\n tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))\n tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))\n\n\nclass FusedChunkBasedFunction(torch.autograd.Function):\n\n @staticmethod\n @custom_fwd\n def forward(ctx, q, k, v, scale=1):\n batch_size, n_heads, seq_len, d_head_qk = q.shape\n d_head_v = v.shape[-1]\n\n scale = scale\n BT = 16\n BK, BV = min(d_head_qk, 16), min(d_head_v, 32)\n BK, BV = max(BK, 16), max(BV, 16)\n NK, NV = triton.cdiv(d_head_qk, BK), triton.cdiv(d_head_v, BV)\n\n num_warps = 4\n\n o = q.new_empty(NK, batch_size, n_heads, seq_len,\n d_head_v, dtype=torch.float32)\n z = q.new_empty(NK, batch_size, n_heads, seq_len, dtype=torch.float32)\n\n grid = (NV, NK, batch_size * n_heads)\n fused_chunk_based_fwd_kernel[grid](\n q, k, v, o, z,\n q.stride(1), q.stride(2), q.stride(3),\n v.stride(1), v.stride(2), v.stride(3),\n batch_size, n_heads, seq_len, scale,\n BT=BT, DK=d_head_qk, DV=d_head_v, BK=BK, BV=BV,\n num_warps=num_warps,\n )\n o = o.sum(0)\n z = z.sum(0)\n ctx.save_for_backward(q, k, v)\n ctx.scale = scale\n return o.to(q.dtype), z.to(z.dtype)\n\n @staticmethod\n @custom_bwd\n def backward(ctx, do, dz):\n q, k, v = ctx.saved_tensors\n batch_size, n_heads, seq_len, d_head_qk = q.shape\n d_head_v = v.shape[-1]\n scale = ctx.scale\n\n BT = 16\n BK, BV = min(d_head_qk, 16), min(d_head_v, 32)\n BK, BV = max(BK, 16), max(BV, 16)\n NK, NV = triton.cdiv(d_head_qk, BK), triton.cdiv(d_head_v, BV)\n num_stages = 1\n num_warps = 4\n\n dq = q.new_empty(NV, batch_size, n_heads, seq_len, d_head_qk)\n dk = q.new_empty(NV, batch_size, n_heads, seq_len, d_head_qk)\n dv = q.new_empty(NK, batch_size, n_heads, seq_len, d_head_v)\n grid = (NV, NK, batch_size * n_heads)\n\n fused_chunk_based_bwd_kernel[grid](\n q, k, v, do, dz, dq, dk, dv,\n q.stride(1), q.stride(2), q.stride(3),\n v.stride(1), v.stride(2), v.stride(3),\n batch_size, n_heads, seq_len, scale,\n BT=BT, DK=d_head_qk, DV=d_head_v, BK=BK, BV=BV,\n num_warps=num_warps,\n num_stages=num_stages\n )\n dq = dq.sum(0)\n dk = dk.sum(0)\n dv = dv.sum(0)\n return dq.to(q.dtype), dk.to(k.dtype), dv.to(v.dtype), None\n\n\ntriton_fused_chunk_based = FusedChunkBasedFunction.apply\n\n\ndef fused_chunk_based(q, k, v, use_scale=True, use_normalize=True):\n assert q.shape[-1] <= 16, 'only support feature dimension up to 16.'\n if use_scale:\n scale = q.shape[-1] ** -0.5\n else:\n scale = 1\n o, z = triton_fused_chunk_based(q, k, v, scale)\n if use_normalize:\n o = o / (z[..., None] + 1e-6)\n else:\n o = o\n\n return o.to(q.dtype)\n", - "description_1": "Use triton language to implement forward and backward kernels for a chunk-based fusion of query, key, and value matrices in a sequence. The forward kernel calculates attention scores using zero, first, and second-order Taylor expansions and normalizes the output. The backward kernel calculates gradients for query, key, and value by propagating errors back through the computed attention. Both kernels iterate over sequence chunks and perform blockwise computations for efficiency. Parameters for these kernels include batch size, number of heads, sequence length, and scaling factors.", - "description_2": "Use triton language to perform chunk-based computation for the fusion of QKV matrices in a sequence and compute gradients during the backward pass.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef parallel_based_fwd_kernel(\n q, k, v, o, z,\n s_qk_h, s_qk_t, s_qk_d, s_vo_h, s_vo_t, s_vo_d,\n B, H, T, scale,\n BTL: tl.constexpr, BTS: tl.constexpr, BK: tl.constexpr,\n BV: tl.constexpr, DK: tl.constexpr, DV: tl.constexpr,\n):\n i_kv, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n NV = tl.cdiv(DV, BV)\n i_k = i_kv // NV\n i_v = i_kv % NV\n\n p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (T, DK),\n (s_qk_t, s_qk_d), (i_c * BTL, i_k * BK), (BTL, BK), (1, 0))\n p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (DK, T),\n (s_qk_d, s_qk_t), (i_k * BK, 0), (BK, BTS), (0, 1))\n p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, DV),\n (s_vo_t, s_vo_d), (0, i_v * BV), (BTS, BV), (1, 0))\n\n b_q = tl.load(p_q, boundary_check=(0, 1))\n b_q = (b_q * scale).to(b_q.dtype)\n b_o = tl.zeros([BTL, BV], dtype=tl.float32)\n b_z = tl.zeros([BTL], dtype=tl.float32)\n\n for _ in range(0, i_c * BTL, BTS):\n b_k = tl.load(p_k, boundary_check=(0, 1))\n b_v = tl.load(p_v, boundary_check=(0, 1))\n b_s = tl.dot(b_q, b_k, allow_tf32=False)\n b_s = 1 + b_s + 0.5 * b_s * b_s\n b_z += tl.sum(b_s, axis=1)\n b_o += tl.dot(b_s.to(b_v.dtype), b_v, allow_tf32=False)\n p_k = tl.advance(p_k, (0, BTS))\n p_v = tl.advance(p_v, (BTS, 0))\n\n tl.debug_barrier()\n o_q = tl.arange(0, BTL)\n o_k = tl.arange(0, BTS)\n p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (DK, T),\n (s_qk_d, s_qk_t), (i_k * BK, i_c * BTL), (BK, BTS), (0, 1))\n p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, DV),\n (s_vo_t, s_vo_d), (i_c * BTL, i_v * BV), (BTS, BV), (1, 0))\n for _ in range(i_c * BTL, (i_c + 1) * BTL, BTS):\n b_k = tl.load(p_k, boundary_check=(0, 1))\n b_v = tl.load(p_v, boundary_check=(0, 1))\n m_s = o_q[:, None] >= o_k[None, :]\n b_s = tl.dot(b_q, b_k, allow_tf32=False)\n b_s = 1 + b_s + 0.5 * b_s * b_s\n b_s = tl.where(m_s, b_s, 0)\n b_z += tl.sum(b_s, axis=1)\n b_o += tl.dot(b_s.to(b_q.dtype), b_v, allow_tf32=False)\n p_k = tl.advance(p_k, (0, BTS))\n p_v = tl.advance(p_v, (BTS, 0))\n o_k += BTS\n\n p_o = tl.make_block_ptr(o + (i_bh + B * H * i_k) * s_vo_h, (T, DV),\n (s_vo_t, s_vo_d), (i_c * BTL, i_v * BV), (BTL, BV), (1, 0))\n p_z = z + (i_bh + B * H * i_k) * T + i_c * BTL + tl.arange(0, BTL)\n tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))\n tl.store(p_z, b_z.to(p_z.dtype.element_ty),\n mask=((i_c * BTL + tl.arange(0, BTL)) < T))\n\n\n@triton.jit\ndef parallel_based_bwd_kernel(\n q, k, v, do, dz, dq, dk, dv, s_qk_h, s_qk_t, s_qk_d, s_vo_h,\n s_vo_t, s_vo_d, B, H, T, scale,\n BTL: tl.constexpr, BTS: tl.constexpr, BK: tl.constexpr,\n BV: tl.constexpr, DK: tl.constexpr, DV: tl.constexpr,\n):\n i_kv, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n NV = tl.cdiv(DV, BV)\n i_k = i_kv // NV\n i_v = i_kv % NV\n i_h = i_bh % H\n\n _parallel_based_bwd_dq(\n i_bh, i_c, i_k, i_v, i_h,\n q, k, v, do, dz, dq, s_qk_h, s_qk_t, s_qk_d, s_vo_h,\n s_vo_t, s_vo_d, B, H, T, scale, BTL=BTL, BTS=BTS, BK=BK, BV=BV, DK=DK, DV=DV\n )\n tl.debug_barrier()\n _parallel_based_bwd_dkv(\n i_bh, i_c, i_k, i_v, i_h,\n q, k, v, do, dz, dk, dv, s_qk_h, s_qk_t, s_qk_d, s_vo_h,\n s_vo_t, s_vo_d, B, H, T, scale, BTL, BTS, BK, BV, DK, DV\n )\n\n\nclass ParallelBasedFunction(torch.autograd.Function):\n @staticmethod\n def forward(ctx, q, k, v, scale):\n BTL, BTS = 128, 32\n assert BTL % BTS == 0\n BK = min(128, triton.next_power_of_2(k.shape[-1]))\n BV = min(128, triton.next_power_of_2(v.shape[-1]))\n BK, BV = max(BK, 16), max(BV, 16)\n batch_size, n_heads, seq_len, d_head_qk = q.shape\n d_head_v = v.shape[-1]\n num_stages = 2\n num_warps = 4\n NK = triton.cdiv(d_head_qk, BK)\n NV = triton.cdiv(d_head_v, BV)\n grid = (NK * NV, triton.cdiv(seq_len, BTL), batch_size * n_heads)\n\n assert NK == 1, \"will encounter some synchronization issue if not.\"\n\n o = torch.empty(NK, batch_size, n_heads, seq_len,\n d_head_v, device=q.device)\n z = torch.empty(NK, batch_size, n_heads, seq_len,\n device=q.device)\n parallel_based_fwd_kernel[grid](\n q, k, v, o, z,\n q.stride(1), q.stride(2), q.stride(3),\n v.stride(1), v.stride(2), v.stride(3),\n batch_size, n_heads, seq_len, scale,\n BTL=BTL, BTS=BTS, BK=BK, BV=BV, DK=d_head_qk, DV=d_head_v,\n num_warps=num_warps,\n num_stages=num_stages\n )\n ctx.save_for_backward(q, k, v)\n ctx.scale = scale\n return o.sum(0).to(q.dtype), z.sum(0).to(q.dtype)\n\n @staticmethod\n def backward(ctx, do, dz):\n q, k, v = ctx.saved_tensors\n scale = ctx.scale\n BTL, BTS = 64, 32\n assert BTL % BTS == 0\n BK = min(128, triton.next_power_of_2(k.shape[-1]))\n BV = min(128, triton.next_power_of_2(v.shape[-1]))\n BK, BV = max(BK, 16), max(BV, 16)\n batch_size, n_heads, seq_len, d_head_qk = q.shape\n d_head_v = v.shape[-1]\n num_stages = 2\n num_warps = 4\n NK = triton.cdiv(d_head_qk, BK)\n NV = triton.cdiv(d_head_v, BV)\n grid = (NK * NV, triton.cdiv(seq_len, BTL), batch_size * n_heads)\n\n assert NK == 1, \"will encounter some synchronization issue if not\"\n\n dq = torch.empty(NV, batch_size, n_heads, seq_len,\n d_head_qk, dtype=q.dtype, device=q.device)\n dk = torch.empty(NV, batch_size, n_heads, seq_len,\n d_head_qk, dtype=q.dtype, device=q.device)\n dv = torch.empty(NK, batch_size, n_heads, seq_len,\n d_head_v, dtype=q.dtype, device=q.device)\n\n parallel_based_bwd_kernel[grid](\n q, k, v, do, dz, dq, dk, dv,\n q.stride(1), q.stride(2), q.stride(3),\n v.stride(1), v.stride(2), v.stride(3),\n batch_size, n_heads, seq_len, scale,\n BTL=BTL, BTS=BTS, BK=BK, BV=BV, DK=d_head_qk, DV=d_head_v,\n num_warps=num_warps,\n num_stages=num_stages\n )\n\n return dq.sum(0).to(q.dtype), dk.sum(0).to(k.dtype), dv.sum(0).to(v.dtype), None\n\n\ntriton_parallel_based = ParallelBasedFunction.apply\n\n\ndef parallel_based(q, k, v, use_scale=True, use_normalize=True, return_both=False):\n assert q.shape[-1] <= 128, \"only support feature dim up to 128\"\n if use_scale:\n scale = q.shape[-1] ** -0.5\n else:\n scale = 1\n o, z = triton_parallel_based(q, k, v, scale)\n if return_both:\n return o, z\n if use_normalize:\n o = o / (z[..., None] + 1e-6)\n return o.to(q.dtype)\n", - "description_1": "Use triton language to implement a parallel attention mechanism. The forward kernel ('parallel_based_fwd_kernel') computes the attention output and normalization factors from query, key, and value tensors with given stride sizes, batch size, number of heads, sequence length, and scaling factor. The backward kernel ('parallel_based_bwd_kernel') computes gradients for these tensors using additional inputs such as output gradients and normalization gradients. Both kernels require block sizes and constant dimensions as additional parameters.", - "description_2": "Use triton language to implement a parallel attention mechanism with both forward and backward passes, computing attention outputs and gradients with specified parameters.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.autotune(\n configs=[\n triton.Config({}, num_warps=1),\n triton.Config({}, num_warps=2),\n triton.Config({}, num_warps=4),\n triton.Config({}, num_warps=8),\n triton.Config({}, num_warps=16),\n triton.Config({}, num_warps=32),\n ],\n key=[\"BT\", \"BK\", \"BV\"],\n)\n@triton.jit\ndef fwd_prepare_dv_kernel(\n q,\n k,\n do,\n dv,\n s_qk_h,\n s_qk_t,\n s_qk_d,\n s_vo_h,\n s_vo_t,\n s_vo_d,\n T,\n K,\n V,\n scale,\n BT: tl.constexpr,\n BK: tl.constexpr,\n BV: tl.constexpr\n):\n i_t, i_bh = tl.program_id(0), tl.program_id(1)\n\n b_A = tl.zeros([BT, BT], dtype=tl.float32)\n\n for i_k in range(tl.cdiv(K, BK)):\n p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1))\n p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))\n b_k = tl.load(p_k, boundary_check=(0, 1))\n b_q = tl.load(p_q, boundary_check=(0, 1))\n b_q = (b_q * scale).to(b_k.dtype)\n b_A += tl.dot(b_k, b_q, allow_tf32=False)\n\n b_A = tl.where(tl.arange(0, BT)[:, None] <= tl.arange(0, BT)[None, :], b_A, 0).to(do.dtype.element_ty)\n\n for i_v in range(tl.cdiv(V, BV)):\n p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))\n b_do = tl.load(p_do, boundary_check=(0, 1))\n p_dv = tl.make_block_ptr(dv + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))\n b_dv = tl.dot(b_A, b_do, allow_tf32=False)\n tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))\n\n\ndef fwd_prepare_dv(q, k, do, BT):\n dv = torch.empty_like(do)\n B, H, T, K, V = *k.shape, do.shape[-1]\n NT = triton.cdiv(T, BT)\n BK = min(triton.next_power_of_2(K), 64)\n BV = min(triton.next_power_of_2(V), 64)\n fwd_prepare_dv_kernel[(NT, B*H)](\n q, k, do, dv,\n k.stride(1), k.stride(2), k.stride(3),\n do.stride(1), do.stride(2), do.stride(3),\n T, K, V, K**-0.5, BT, BK, BV\n )\n return dv\n", - "description_1": "Use triton language to implement a kernel function 'fwd_prepare_dv_kernel' that computes the forward pass for a delta rule operation. The kernel takes 15 parameters: q, k, do, dv (all tensors), s_qk_h, s_qk_t, s_qk_d, s_vo_h, s_vo_t, s_vo_d (all strides), T, K, V (dimensions), scale (a scaling factor), and BT, BK, BV (block sizes). It computes a matrix multiplication and stores the result in dv. The function 'fwd_prepare_dv' is a wrapper that prepares the input and calls the kernel.", - "description_2": "Use triton language to implement a kernel function 'fwd_prepare_dv_kernel' for computing matrix multiplication in a delta rule operation, and a wrapper function 'fwd_prepare_dv' to set up and invoke the kernel.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.autotune(\n configs=[\n triton.Config({}, num_warps=1),\n triton.Config({}, num_warps=2),\n triton.Config({}, num_warps=4),\n triton.Config({}, num_warps=8)\n ],\n key=[\"BT\", \"BK\"],\n)\n@triton.jit\ndef fused_chunk_delta_rule_fwd_kernel(\n q, k, v, v_new, d, o, initial_state, final_state,\n s_qk_h, s_qk_t, s_qk_d, s_vo_h, s_vo_t, s_vo_d,\n B, H, T, scale,\n BT: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr,\n DK: tl.constexpr, DV: tl.constexpr, USE_INITIAL_STATE: tl.constexpr,\n STORE_FINAL_STATE: tl.constexpr, CHECK: tl.constexpr\n):\n i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n o_i = tl.arange(0, BT)\n m_s = o_i[:, None] >= o_i[None, :]\n b_h = tl.zeros([BK, BV], dtype=tl.float32)\n p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (0, i_k * BK), (BT, BK), (1, 0))\n p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (DK, T), (s_qk_d, s_qk_t), (i_k * BK, 0), (BK, BT), (0, 1))\n p_d = tl.make_block_ptr(d + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (0, i_k * BK), (BT, BK), (1, 0))\n p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (0, i_v * BV), (BT, BV), (1, 0))\n p_o = tl.make_block_ptr(o + (i_bh+i_k*B*H) * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (0, i_v * BV), (BT, BV), (1, 0))\n p_v_new = tl.make_block_ptr(v_new + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (0, i_v * BV), (BT, BV), (1, 0))\n\n if USE_INITIAL_STATE:\n p_h = tl.make_block_ptr(initial_state + i_bh * DK * DV, (DK, DV), (DV, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))\n b_h = tl.load(p_h, boundary_check=(0, 1)).to(tl.float32)\n\n for i in range(0, tl.cdiv(T, BT)):\n b_k = tl.load(p_k, boundary_check=(0, 1))\n b_v = tl.load(p_v, boundary_check=(0, 1))\n b_q = tl.load(p_q, boundary_check=(0, 1))\n b_d = tl.load(p_d, boundary_check=(0, 1))\n b_q = (b_q * scale).to(b_k.dtype)\n b_s = tl.dot(b_q, b_k, allow_tf32=False)\n b_s = tl.where(m_s, b_s, 0)\n b_v_prime = tl.dot(b_d, b_h.to(b_q.dtype), allow_tf32=False)\n b_v = b_v - b_v_prime\n tl.store(p_v_new, b_v.to(p_v.dtype.element_ty), boundary_check=(0, 1))\n b_o = tl.dot(b_s.to(b_q.dtype), b_v.to(b_q.dtype), allow_tf32=False)\n if CHECK and i == 0:\n b_o += tl.dot(b_q, b_h.to(b_q.dtype), allow_tf32=False)\n b_h = b_h + tl.dot(b_k, b_v.to(b_k.dtype), allow_tf32=False)\n else:\n b_o += tl.dot(b_q, b_h.to(b_q.dtype), allow_tf32=False)\n b_h = b_h + tl.dot(b_k, b_v.to(b_k.dtype), allow_tf32=False)\n tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))\n p_q = tl.advance(p_q, (BT, 0))\n p_k = tl.advance(p_k, (0, BT))\n p_v = tl.advance(p_v, (BT, 0))\n p_v_new = tl.advance(p_v_new, (BT, 0))\n p_o = tl.advance(p_o, (BT, 0))\n p_d = tl.advance(p_d, (BT, 0))\n\n if STORE_FINAL_STATE:\n p_final = tl.make_block_ptr(final_state + i_bh * DK * DV, (DK, DV), (DV, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))\n tl.store(p_final, b_h.to(p_final.dtype.element_ty), boundary_check=(0, 1))\n\n\n@triton.autotune(\n configs=[\n triton.Config({}, num_warps=1),\n triton.Config({}, num_warps=2),\n triton.Config({}, num_warps=4),\n triton.Config({}, num_warps=8),\n triton.Config({}, num_warps=16),\n triton.Config({}, num_warps=32),\n ],\n key=[\"BT\", \"BK\", \"BV\"],\n)\n@triton.jit\ndef fused_chunk_delta_rule_bwd_kernel(\n q, k, v, d, do, dq, dk, dv, dd, initial_state,\n s_qk_h, s_qk_t, s_qk_d, s_vo_h, s_vo_t, s_vo_d,\n B, H, T, scale,\n BT: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr,\n DK: tl.constexpr, DV: tl.constexpr, USE_INITIAL_STATE: tl.constexpr,\n CHECK: tl.constexpr\n):\n i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n o_i = tl.arange(0, BT)\n b_dh = tl.zeros([BK, BV], dtype=tl.float32)\n m_s = o_i[:, None] <= o_i[None, :]\n for i in range(1, tl.cdiv(T, BT) + 1):\n p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (DK, T), (s_qk_d, s_qk_t), (i_k * BK, T - i * BT), (BK, BT), (0, 1))\n p_d = tl.make_block_ptr(d + i_bh * s_qk_h, (DK, T), (s_qk_d, s_qk_t), (i_k * BK, T - i * BT), (BK, BT), (0, 1))\n p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (T - i * BT, i_k * BK), (BT, BK), (1, 0))\n p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (T - i * BT, i_v * BV), (BT, BV), (1, 0))\n p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (T - i * BT, i_v * BV), (BT, BV), (1, 0))\n p_dk = tl.make_block_ptr(dk + (i_bh+i_v*B*H) * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (T - i*BT, i_k*BK), (BT, BK), (1, 0))\n p_dv = tl.make_block_ptr(dv + (i_bh+i_k*B*H) * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (T - i*BT, i_v*BV), (BT, BV), (1, 0))\n b_q = tl.load(p_q, boundary_check=(0, 1))\n b_q = (b_q * scale).to(b_q.dtype)\n b_k = tl.load(p_k, boundary_check=(0, 1))\n b_v = tl.load(p_v, boundary_check=(0, 1))\n b_do = tl.load(p_do, boundary_check=(0, 1))\n b_ds = tl.dot(b_v, tl.trans(b_do), allow_tf32=False)\n b_ds = tl.where(m_s, b_ds, 0).to(b_q.dtype)\n b_s = tl.dot(b_k, b_q, allow_tf32=False)\n b_s = tl.where(m_s, b_s, 0).to(b_q.dtype)\n b_dk = tl.dot(b_ds, tl.trans(b_q), allow_tf32=False)\n b_dv = tl.dot(b_s, b_do, allow_tf32=False)\n b_d = tl.load(p_d, boundary_check=(0, 1))\n if CHECK and i == 1:\n b_dk += tl.dot(b_v, tl.trans(b_dh).to(b_v.dtype), allow_tf32=False)\n b_dv += tl.dot(b_k, b_dh.to(b_k.dtype), allow_tf32=False)\n b_dh += tl.dot(b_q, b_do, allow_tf32=False)\n b_dh -= tl.dot(b_d, b_dv.to(b_d.dtype), allow_tf32=False)\n else:\n b_dk += tl.dot(b_v, tl.trans(b_dh).to(b_v.dtype), allow_tf32=False)\n b_dv += tl.dot(b_k, b_dh.to(b_k.dtype), allow_tf32=False)\n b_dh += tl.dot(b_q, b_do, allow_tf32=False)\n b_dh -= tl.dot(b_d, b_dv.to(b_d.dtype), allow_tf32=False)\n\n tl.store(p_dk, (b_dk).to(p_dk.dtype.element_ty), boundary_check=(0, 1))\n tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))\n\n b_h = None\n tl.debug_barrier()\n m_s = o_i[:, None] >= o_i[None, :]\n b_h = tl.zeros([BV, BK], dtype=tl.float32)\n if USE_INITIAL_STATE:\n p_h = tl.make_block_ptr(initial_state + i_bh * DK * DV, (DV, DK), (1, DV), (i_v * BV, i_k * BK), (BV, BK), (0, 1))\n b_h = tl.load(p_h, boundary_check=(0, 1)).to(tl.float32)\n NT = tl.cdiv(T, BT)\n for i in range(0, NT):\n p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (i * BT, i_k * BK), (BT, BK), (1, 0))\n p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (DV, T), (s_vo_d, s_vo_t), (i_v * BV, i * BT), (BV, BT), (0, 1))\n p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (i * BT, i_v * BV), (BT, BV), (1, 0))\n p_dq = tl.make_block_ptr(dq + (i_bh + i_v*B*H) * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (i*BT, i_k*BK), (BT, BK), (1, 0))\n b_k = tl.load(p_k, boundary_check=(0, 1))\n b_v = tl.load(p_v, boundary_check=(0, 1))\n b_do = tl.load(p_do, boundary_check=(0, 1))\n b_ds = tl.dot(b_do, b_v, allow_tf32=False)\n b_ds = tl.where(m_s, b_ds, 0)\n b_dq = tl.dot(b_ds.to(b_k.dtype), b_k, allow_tf32=False)\n if CHECK and i == 0:\n b_dq += tl.dot(b_do, b_h.to(b_do.dtype), allow_tf32=False)\n b_h = b_h + tl.dot(b_v, b_k, allow_tf32=False)\n else:\n b_dq += tl.dot(b_do, b_h.to(b_do.dtype), allow_tf32=False)\n b_h = b_h + tl.dot(b_v, b_k, allow_tf32=False)\n b_dq *= scale\n tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))\n\n if i < (NT - 1):\n p_dv = tl.make_block_ptr(dv + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), ((i + 1) * BT, i_v * BV), (BT, BV), (1, 0))\n b_dv = tl.load(p_dv, boundary_check=(0, 1))\n b_dd = tl.dot(b_dv.to(b_k.dtype), b_h.to(b_k.dtype), allow_tf32=False)\n p_dd = tl.make_block_ptr(dd + (i_bh + i_v*B*H) * s_qk_h, (T, DK), (s_qk_t, s_qk_d),\n ((i+1) * BT, i_k * BK), (BT, BK), (1, 0))\n tl.store(p_dd, -b_dd.to(p_dd.dtype.element_ty), boundary_check=(0, 1))\n\n\ndef fused_chunk_delta_rule_fwd(q, k, v, d, BT, initial_state, output_final_state):\n batch_size, n_heads, seq_len, d_head_qk = q.shape\n d_head_v = v.shape[-1]\n scale = d_head_qk ** -0.5\n BT = BT\n BK, BV = triton.next_power_of_2(d_head_qk), min(triton.next_power_of_2(d_head_v), 32)\n NK, NV = triton.cdiv(d_head_qk, BK), triton.cdiv(d_head_v, BV)\n assert NK == 1, 'NK should be 1'\n o = q.new_empty(batch_size, n_heads, seq_len, d_head_v)\n if output_final_state:\n final_state = q.new_empty(batch_size, n_heads, d_head_qk, d_head_v, dtype=torch.float32, requires_grad=False)\n else:\n final_state = None\n CHECK = True\n grid = (NV, NK, batch_size * n_heads)\n v_new = torch.empty_like(v)\n fused_chunk_delta_rule_fwd_kernel[grid](\n q, k, v, v_new, d, o, initial_state, final_state,\n q.stride(1), q.stride(2), q.stride(3),\n v.stride(1), v.stride(2), v.stride(3),\n batch_size, n_heads, seq_len, scale,\n BT=BT, DK=d_head_qk, DV=d_head_v, BK=BK, BV=BV,\n USE_INITIAL_STATE=initial_state is not None,\n STORE_FINAL_STATE=output_final_state,\n CHECK=CHECK,\n )\n return o, v_new, CHECK, final_state\n\n\ndef fused_chunk_delta_rule_bwd(q, k, v, d, do, BT, CHECK, initial_state):\n batch_size, n_heads, seq_len, d_head_qk = q.shape\n d_head_v = v.shape[-1]\n scale = d_head_qk ** -0.5\n BK, BV = triton.next_power_of_2(d_head_qk), min(triton.next_power_of_2(d_head_v), 32)\n NK, NV = triton.cdiv(d_head_qk, BK), triton.cdiv(d_head_v, BV)\n assert NK == 1\n dq = q.new_empty(NV, batch_size, n_heads, seq_len, d_head_qk)\n dk = q.new_empty(NV, batch_size, n_heads, seq_len, d_head_qk)\n dd = q.new_empty(NV, batch_size, n_heads, seq_len, d_head_qk)\n dv = q.new_empty(NK, batch_size, n_heads, seq_len, d_head_v)\n grid = (NV, NK, batch_size * n_heads)\n fused_chunk_delta_rule_bwd_kernel[grid](\n q, k, v, d, do, dq, dk, dv, dd, initial_state,\n q.stride(1), q.stride(2), q.stride(3),\n v.stride(1), v.stride(2), v.stride(3),\n batch_size, n_heads, seq_len, scale,\n BT=BT, DK=d_head_qk, DV=d_head_v, BK=BK, BV=BV,\n USE_INITIAL_STATE=initial_state is not None,\n CHECK=CHECK,\n )\n dq = dq.sum(0)\n dk = dk.sum(0)\n dv = dv.sum(0)\n dd = dd.sum(0)\n dd[:, :, 0:BT] = 0\n return dq, dk, dv, dd\n", - "description_1": "Use triton language to implement forward and backward kernels for the Fused Chunk Delta Rule. The forward kernel computes new values (v_new) and outputs (o) based on input tensors q, k, v, and d, with optional initial and final states. It includes multiple configurations for num_warps and uses grid configurations based on input dimensions. The backward kernel computes gradients for inputs q, k, v, and d based on the gradient of the output (do). Both kernels utilize block pointers and allow boundary checks.", - "description_2": "Use triton language to implement kernels for the Fused Chunk Delta Rule, providing both forward computations (output and new values) and backward gradients for the specified inputs.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.autotune(\n configs=[\n triton.Config({}, num_warps=1),\n triton.Config({}, num_warps=2),\n triton.Config({}, num_warps=4),\n triton.Config({}, num_warps=8),\n triton.Config({}, num_warps=16),\n triton.Config({}, num_warps=32),\n ],\n key=[\"BT\", \"BK\", \"BV\"],\n)\n@triton.jit\ndef fwd_prepare_wy_repr_kernel(\n k,\n v,\n beta,\n o,\n o2,\n T,\n K,\n V,\n BT: tl.constexpr,\n BK: tl.constexpr,\n BV: tl.constexpr\n):\n i_t, i_bh = tl.program_id(0), tl.program_id(1)\n\n p_k = k + i_bh * T * K + (i_t * BT + tl.arange(0, BT)[:, None]) * K + tl.arange(0, BK)[None, :]\n p_v = v + i_bh * T * V + (i_t * BT + tl.arange(0, BT)[:, None]) * V + tl.arange(0, BV)[None, :]\n p_beta = beta + i_bh * T + i_t * BT + tl.arange(0, BT)\n mask_bt = (tl.arange(0, BT) + i_t * BT) < T\n mask_bk = tl.arange(0, BK) < K\n mask_bv = tl.arange(0, BV) < V\n mask_bk = mask_bk[None, :] & mask_bt[:, None]\n mask_bv = mask_bv[None, :] & mask_bt[:, None]\n # [BT, BK]\n b_k = tl.load(p_k, mask=mask_bk, other=0)\n # [BT,]\n b_beta = tl.load(p_beta, mask=mask_bt, other=0).to(tl.float32)\n # [BT, BV]\n b_v = tl.load(p_v, mask=mask_bv, other=0)\n b_v = (b_v * b_beta[:, None]).to(b_v.dtype)\n # [BT, BK]\n b_kb = (b_k * b_beta[:, None]).to(b_k.dtype)\n # [BT, BT]\n b_A = tl.dot(b_kb, tl.trans(b_k), allow_tf32=False)\n b_A = -tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], b_A, 0)\n\n for i in range(BT):\n mask = tl.arange(0, BT) == i\n b_a = tl.sum(tl.where(mask[:, None], b_A, 0), 0)\n b_a = b_a + tl.sum(b_a[:, None] * b_A, 0) * (tl.arange(0, BT) < i)\n b_A = tl.where(mask[:, None], b_a, b_A)\n b_A += tl.arange(0, BT)[:, None] == tl.arange(0, BT)[None, :]\n b_A = b_A.to(b_k.dtype)\n b_w = tl.dot(b_A, b_kb, allow_tf32=False)\n b_u = tl.dot(b_A, b_v, allow_tf32=False)\n\n p_o = o + i_bh * T * K + (i_t * BT + tl.arange(0, BT)[:, None]) * K + tl.arange(0, BK)[None, :]\n tl.store(p_o, b_w.to(p_o.dtype.element_ty), mask=mask_bk)\n p_o2 = o2 + i_bh * T * V + (i_t * BT + tl.arange(0, BT)[:, None]) * V + tl.arange(0, BV)[None, :]\n tl.store(p_o2, b_u.to(p_o2.dtype.element_ty), mask=mask_bv)\n\n@triton.autotune(\n configs=[\n triton.Config({}, num_warps=1),\n triton.Config({}, num_warps=2),\n triton.Config({}, num_warps=4),\n triton.Config({}, num_warps=8),\n triton.Config({}, num_warps=16),\n triton.Config({}, num_warps=32),\n ],\n key=[\"BT\", \"BK\", \"BV\"],\n)\n@triton.jit\ndef bwd_prepare_wy_repr_kernel(\n k, v, beta,\n o, o2, do, do2,\n dk, dv, dbeta,\n NT, K, V, T,\n BT: tl.constexpr,\n BK: tl.constexpr,\n BV: tl.constexpr,\n):\n i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n p_k = k + i_bh * T * K + (i_t * BT + tl.arange(0, BT)[:, None]) * K + tl.arange(0, BK)[None, :]\n p_do = do + i_bh * T * K + (i_t * BT + tl.arange(0, BT)[:, None]) * K + tl.arange(0, BK)[None, :]\n p_do2 = do2 + i_bh * T * V + (i_t * BT + tl.arange(0, BT)[:, None]) * V + tl.arange(0, BV)[None, :]\n\n p_beta = beta + i_bh * T + i_t * BT + tl.arange(0, BT)\n mask_bt = (tl.arange(0, BT) + i_t * BT) < T\n mask_bk = (tl.arange(0, BK) < K)[None, :] & mask_bt[:, None]\n mask_bv = (tl.arange(0, BV) < V)[None, :] & mask_bt[:, None]\n b_k, b_beta = tl.load(p_k, mask=mask_bk), tl.load(p_beta, mask=mask_bt)\n\n b_beta = b_beta.to(tl.float32)\n A = tl.dot(b_k, tl.trans(b_k), allow_tf32=False) * b_beta[:, None]\n A = tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], A, 0)\n b_do = tl.load(p_do, mask=mask_bk).to(tl.float32)\n b_dv = tl.load(p_do2, mask=mask_bv).to(tl.float32)\n dA = tl.zeros([BT, BT], dtype=tl.float32)\n b_dk = tl.zeros([BT, BK], dtype=tl.float32)\n for i in range(BT-1, -1, -1):\n mask = tl.arange(0, BT) == i\n attn = tl.sum(tl.where(mask[:, None], A, 0), axis=0)\n do_ = tl.sum(tl.where(mask[:, None], b_do, 0), axis=0)\n dv_ = tl.sum(tl.where(mask[:, None], b_dv, 0), axis=0)\n b_do = b_do - attn[:, None] * do_[None, :]\n b_dv = b_dv - attn[:, None] * dv_[None, :]\n tl.debug_barrier()\n p_v = v + i_bh * T * V + (i_t * BT + tl.arange(0, BT)[:, None]) * V + tl.arange(0, BV)[None, :]\n b_v = tl.load(p_v, mask=mask_bv)\n b_dk += b_do * b_beta[:, None]\n b_dbeta = tl.sum(b_do * b_k, axis=1)\n b_dbeta += tl.sum(b_dv * b_v, axis=1)\n b_v = None\n\n p_o = o + i_bh * T * K + (i_t * BT + tl.arange(0, BT)[:, None]) * K + tl.arange(0, BK)[None, :]\n p_o2 = o2 + i_bh * T * V + (i_t * BT + tl.arange(0, BT)[:, None]) * V + tl.arange(0, BV)[None, :]\n b_o = tl.load(p_o, mask=mask_bk)\n b_o2 = tl.load(p_o2, mask=mask_bv)\n\n dA = -tl.dot(b_do.to(b_o.dtype), tl.trans(b_o), allow_tf32=False)\n dA -= tl.dot(b_dv.to(b_o2.dtype), tl.trans(b_o2).to(b_o.dtype),\n allow_tf32=False)\n dA = tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], dA, 0)\n b_dv *= b_beta[:, None]\n p_dv = dv + i_bh * T * V + (i_t * BT + tl.arange(0, BT)[:, None]) * V + tl.arange(0, BV)[None, :]\n tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), mask=mask_bv)\n\n b_dbeta += tl.sum(dA * tl.dot(b_k, tl.trans(b_k), allow_tf32=False), axis=1)\n dA = dA * b_beta[:, None]\n b_dk += tl.dot(tl.trans(dA.to(b_k.dtype)), b_k, allow_tf32=False)\n b_dk += tl.dot(dA.to(b_k.dtype), b_k, allow_tf32=False)\n p_dk = dk + i_bh * T * K + (i_t * BT + tl.arange(0, BT)[:, None]) * K + tl.arange(0, BK)[None, :]\n tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), mask=mask_bk)\n p_dbeta = dbeta + i_bh * T + i_t * BT + tl.arange(0, BT)\n tl.store(p_dbeta, b_dbeta.to(p_dbeta.dtype.element_ty), mask=mask_bt)\n\ndef fwd_prepare_wy_repr(k, v, beta, chunk_size):\n B, H, T, K, V = *k.shape, v.shape[-1]\n v_new = torch.empty_like(v)\n o_cumdecay = torch.empty_like(k)\n BT = chunk_size\n NT = triton.cdiv(T, BT)\n BK = triton.next_power_of_2(K)\n BV = triton.next_power_of_2(V)\n fwd_prepare_wy_repr_kernel[(NT, B*H)](\n k, v, beta, o_cumdecay, v_new,\n T, K, V, BT, BK, BV\n )\n return o_cumdecay, v_new\n\ndef bwd_prepare_wy_repr(k, v, beta, o_cumdecay, v_new, do, do2, chunk_size):\n b, h, l, d_k = do.shape\n d_v = v.shape[-1]\n BK = triton.next_power_of_2(d_k)\n BV = triton.next_power_of_2(d_v)\n c = chunk_size\n BK = d_k\n NT = triton.cdiv(l, c)\n dk = torch.empty_like(k)\n dv = torch.empty_like(v)\n dbeta = torch.zeros_like(beta)\n bwd_prepare_wy_repr_kernel[(NT, b*h)](\n k, v, beta,\n o_cumdecay, v_new, do, do2,\n dk, dv, dbeta,\n NT, d_k, d_v, l, chunk_size, BK, BV\n )\n return dk, dv, dbeta\n", - "description_1": "Use triton language to implement forward and backward WY representation preparation kernels. The forward kernel (fwd_prepare_wy_repr_kernel) takes 10 arguments: 'k' (input key tensor), 'v' (input value tensor), 'beta' (decay factor tensor), 'o' (output tensor for cumulative decay), 'o2' (output tensor for new values), 'T' (length of the sequence), 'K' (dimension of keys), 'V' (dimension of values), 'BT', 'BK', 'BV' (block sizes for triton kernels). It computes the WY representation using provided inputs and stores the results in 'o' and 'o2'. The backward kernel (bwd_prepare_wy_repr_kernel) takes 18 arguments: 'k', 'v', 'beta', 'o', 'o2', 'do' (gradient of 'o'), 'do2' (gradient of 'o2'), 'dk' (gradient to be computed for 'k'), 'dv' (gradient to be computed for 'v'), 'dbeta' (gradient to be computed for 'beta'), 'NT', 'K', 'V', 'T', 'BT', 'BK', 'BV'. It computes gradients for the inputs based on the forward pass and stores them in 'dk', 'dv', and 'dbeta'.", - "description_2": "Use triton language to create forward and backward kernels for WY representation preparation, enabling computation of WY matrices and their gradients.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.autotune(\n configs=[\n triton.Config({}, num_warps=1),\n triton.Config({}, num_warps=2),\n triton.Config({}, num_warps=4),\n triton.Config({}, num_warps=8),\n triton.Config({}, num_warps=16),\n triton.Config({}, num_warps=32),\n ],\n key=[\"BT\", \"BK\", \"BV\"],\n)\n@triton.jit\ndef fwd_prepare_wy_repr_kernel(\n k,\n v,\n beta,\n w,\n u,\n A,\n s_qk_h,\n s_qk_t,\n s_qk_d,\n s_vo_h,\n s_vo_t,\n s_vo_d,\n T,\n K,\n V,\n BT: tl.constexpr,\n BK: tl.constexpr,\n BV: tl.constexpr\n):\n i_t, i_bh = tl.program_id(0), tl.program_id(1)\n b_A = tl.zeros([BT, BT], dtype=tl.float32)\n p_beta = tl.make_block_ptr(beta + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,))\n b_beta = tl.load(p_beta, boundary_check=(0,))\n for i_k in range(tl.cdiv(K, BK)):\n p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))\n b_k = tl.load(p_k, boundary_check=(0, 1))\n b_kb = (b_k * b_beta[:, None]).to(b_k.dtype)\n b_A += tl.dot(b_kb, tl.trans(b_k), allow_tf32=False)\n b_A = -tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], b_A, 0)\n for i in range(1, BT):\n mask = tl.arange(0, BT) == i\n b_a = tl.sum(tl.where(mask[:, None], b_A, 0), 0)\n b_a = b_a + tl.sum(b_a[:, None] * b_A, 0) * (tl.arange(0, BT) < i)\n b_A = tl.where(mask[:, None], b_a, b_A)\n b_A += tl.arange(0, BT)[:, None] == tl.arange(0, BT)[None, :]\n p_A = tl.make_block_ptr(A + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))\n tl.store(p_A, (b_A).to(p_A.dtype.element_ty), boundary_check=(0, 1))\n b_A = b_A.to(k.dtype.element_ty)\n for i_v in range(tl.cdiv(V, BV)):\n p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))\n b_v = tl.load(p_v, boundary_check=(0, 1))\n b_vb = (b_v * b_beta[:, None]).to(b_v.dtype)\n b_u = tl.dot(b_A, b_vb, allow_tf32=False)\n p_u = tl.make_block_ptr(u + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))\n tl.store(p_u, (b_u).to(p_u.dtype.element_ty), boundary_check=(0, 1))\n for i_k in range(tl.cdiv(K, BK)):\n p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))\n b_k = tl.load(p_k, boundary_check=(0, 1))\n b_kb = (b_k * b_beta[:, None]).to(b_k.dtype)\n b_w = tl.dot(b_A, b_kb, allow_tf32=False)\n p_w = tl.make_block_ptr(w + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))\n tl.store(p_w, b_w.to(p_w.dtype.element_ty), boundary_check=(0, 1))\n\n\n@triton.autotune(\n configs=[\n triton.Config({}, num_warps=1),\n triton.Config({}, num_warps=2),\n triton.Config({}, num_warps=4),\n triton.Config({}, num_warps=8),\n triton.Config({}, num_warps=16),\n triton.Config({}, num_warps=32),\n ],\n key=[\"BT\", \"BK\", \"BV\"],\n)\n@triton.jit\ndef fwd_recompute_w_u_kernel(\n k,\n v,\n beta,\n w,\n u,\n A,\n s_qk_h,\n s_qk_t,\n s_qk_d,\n s_vo_h,\n s_vo_t,\n s_vo_d,\n T,\n K,\n V,\n BT: tl.constexpr,\n BK: tl.constexpr,\n BV: tl.constexpr\n):\n i_t, i_bh = tl.program_id(0), tl.program_id(1)\n p_beta = tl.make_block_ptr(beta + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,))\n b_beta = tl.load(p_beta, boundary_check=(0,))\n p_A = tl.make_block_ptr(A + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))\n b_A = tl.load(p_A, boundary_check=(0, 1)).to(k.dtype.element_ty)\n for i_v in range(tl.cdiv(V, BV)):\n p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))\n b_v = tl.load(p_v, boundary_check=(0, 1))\n b_vb = (b_v * b_beta[:, None]).to(b_v.dtype)\n b_u = tl.dot(b_A, b_vb, allow_tf32=False)\n p_u = tl.make_block_ptr(u + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))\n tl.store(p_u, (b_u).to(p_u.dtype.element_ty), boundary_check=(0, 1))\n for i_k in range(tl.cdiv(K, BK)):\n p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))\n b_k = tl.load(p_k, boundary_check=(0, 1))\n b_kb = (b_k * b_beta[:, None]).to(b_k.dtype)\n b_w = tl.dot(b_A, b_kb, allow_tf32=False)\n p_w = tl.make_block_ptr(w + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))\n tl.store(p_w, b_w.to(p_w.dtype.element_ty), boundary_check=(0, 1))\n\n\n@triton.autotune(\n configs=[\n triton.Config({}, num_warps=1),\n triton.Config({}, num_warps=2),\n triton.Config({}, num_warps=4),\n triton.Config({}, num_warps=8),\n triton.Config({}, num_warps=16),\n triton.Config({}, num_warps=32),\n ],\n key=[\"BT\", \"BK\", \"BV\"],\n)\n@triton.jit\ndef bwd_prepare_wy_repr_kernel(\n k, v, beta, A,\n dw, du,\n dk, dv, dbeta,\n s_qk_h,\n s_qk_t,\n s_qk_d,\n s_vo_h,\n s_vo_t,\n s_vo_d,\n T,\n K,\n V,\n BT: tl.constexpr,\n BK: tl.constexpr,\n BV: tl.constexpr\n):\n i_t, i_bh = tl.program_id(0), tl.program_id(1)\n p_A = tl.make_block_ptr(A + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))\n b_A = tl.load(p_A, boundary_check=(0, 1)).to(k.dtype.element_ty)\n b_dbeta = tl.zeros([BT], dtype=tl.float32)\n b_dA = tl.zeros([BT, BT], dtype=tl.float32)\n p_beta = tl.make_block_ptr(beta + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,))\n b_beta = tl.load(p_beta, boundary_check=(0,))\n for i_v in range(tl.cdiv(V, BV)):\n p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))\n p_du = tl.make_block_ptr(du + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))\n b_v = tl.load(p_v, boundary_check=(0, 1))\n b_v_beta = (b_v * b_beta[:, None]).to(b_v.dtype)\n b_du = tl.load(p_du, boundary_check=(0, 1))\n b_dA += tl.dot(b_du, tl.trans(b_v_beta), allow_tf32=False)\n b_dv_beta = tl.dot(tl.trans(b_A), b_du, allow_tf32=False)\n b_dv = b_dv_beta * b_beta[:, None]\n b_dbeta += tl.sum(b_dv_beta * b_v, 1)\n p_dv = tl.make_block_ptr(dv + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))\n tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))\n tl.debug_barrier()\n b_A2 = tl.zeros([BT, BT], dtype=tl.float32)\n for i_k in range(tl.cdiv(K, BK)):\n p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))\n p_dw = tl.make_block_ptr(dw + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))\n b_k = tl.load(p_k, boundary_check=(0, 1))\n b_k_beta = (b_k * b_beta[:, None]).to(b_k.dtype)\n b_dw = tl.load(p_dw, boundary_check=(0, 1))\n b_dA += tl.dot(b_dw, tl.trans(b_k_beta), allow_tf32=False)\n b_A2 += tl.dot(b_k_beta, tl.trans(b_k), allow_tf32=False)\n b_dk_beta = tl.dot(tl.trans(b_A), b_dw, allow_tf32=False)\n b_dk = b_dk_beta * b_beta[:, None]\n b_dbeta += tl.sum(b_dk_beta * b_k, 1)\n p_dk = tl.make_block_ptr(dk + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))\n tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))\n b_A -= (tl.arange(0, BT)[:, None] == tl.arange(0, BT)[None, :])\n b_A2 = tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], -b_A2, 0)\n b_dA = tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], b_dA, 0)\n tl.debug_barrier()\n for i in range(BT-1, 0, -1):\n mask = tl.arange(0, BT) == i\n b_da = tl.sum(tl.where(mask[:, None], b_dA, 0), 0)\n b_a = tl.sum(tl.where(mask[:, None], b_A2, 0), 0)\n b_da2 = b_da + tl.sum(b_da[None, :] * b_A, 1)\n b_dA = tl.where(mask[:, None], b_da2, b_dA)\n b_dA += b_da[None, :] * b_a[:, None]\n b_dA = tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], -b_dA, 0).to(k.dtype.element_ty)\n tl.debug_barrier()\n for i_k in range(tl.cdiv(K, BK)):\n p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))\n p_dk = tl.make_block_ptr(dk + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))\n b_k = tl.load(p_k, boundary_check=(0, 1))\n b_dk = tl.load(p_dk, boundary_check=(0, 1))\n b_k_beta = (b_k * b_beta[:, None]).to(b_k.dtype)\n b_dk_beta = tl.dot(b_dA, b_k, allow_tf32=False)\n b_dbeta += tl.sum(b_dk_beta * b_k, 1)\n b_dk += tl.dot(tl.trans(b_dA), b_k_beta, allow_tf32=False)\n b_dk += b_dk_beta * b_beta[:, None]\n tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))\n p_dbeta = tl.make_block_ptr(dbeta + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,))\n tl.store(p_dbeta, b_dbeta.to(p_dbeta.dtype.element_ty), boundary_check=(0, 1))\n\n\ndef fwd_prepare_wy_repr(k, v, beta, BT):\n B, H, T, K, V = *k.shape, v.shape[-1]\n u = torch.empty_like(v)\n w = torch.empty_like(k)\n NT = triton.cdiv(T, BT)\n BK = min(triton.next_power_of_2(K), 64)\n BV = min(triton.next_power_of_2(V), 64)\n A = torch.empty(B, H, T, BT, device=k.device, dtype=k.dtype)\n fwd_prepare_wy_repr_kernel[(NT, B*H)](\n k, v, beta, w, u, A,\n k.stride(1), k.stride(2), k.stride(3),\n v.stride(1), v.stride(2), v.stride(3),\n T, K, V, BT, BK, BV\n )\n return w, u, A\n\n\ndef fwd_recompute_w_u(k, v, beta, A, BT):\n B, H, T, K, V = *k.shape, v.shape[-1]\n u = torch.empty_like(v)\n w = torch.empty_like(k)\n NT = triton.cdiv(T, BT)\n BK = min(triton.next_power_of_2(K), 64)\n BV = min(triton.next_power_of_2(V), 64)\n fwd_recompute_w_u_kernel[(NT, B*H)](\n k, v, beta, w, u, A,\n k.stride(1), k.stride(2), k.stride(3),\n v.stride(1), v.stride(2), v.stride(3),\n T, K, V, BT, BK, BV\n )\n return w, u\n\n\ndef bwd_prepare_wy_repr(k, v, beta, A, dw, du, BT):\n B, H, T, K, V = *k.shape, v.shape[-1]\n NT = triton.cdiv(T, BT)\n BK = min(triton.next_power_of_2(K), 64)\n BV = min(triton.next_power_of_2(V), 64)\n NT = triton.cdiv(T, BT)\n dk = torch.empty_like(k)\n dv = torch.empty_like(v).contiguous()\n dbeta = torch.zeros_like(beta)\n bwd_prepare_wy_repr_kernel[(NT, B*H)](\n k, v, beta, A,\n dw, du,\n dk, dv, dbeta,\n k.stride(1), k.stride(2), k.stride(3),\n v.stride(1), v.stride(2), v.stride(3),\n T, K, V, BT, BK, BV\n )\n return dk, dv, dbeta\n\n\nclass WYRepresentationPrepration(torch.autograd.Function):\n @staticmethod\n def forward(ctx, k, v, beta, chunk_size):\n ctx.BT = chunk_size\n w, u, A = fwd_prepare_wy_repr(k, v, beta, ctx.BT)\n ctx.save_for_backward(k, v, beta, A)\n return w, u\n\n @staticmethod\n def backward(ctx, dw, du):\n k, v, beta, A = ctx.saved_tensors\n BT = ctx.BT\n dk, dv, dbeta = bwd_prepare_wy_repr(k, v, beta, A, dw, du, BT)\n return dk, dv, dbeta, None\n\n\nprepare_wy_repr = WYRepresentationPrepration.apply\n", - "description_1": "Use triton language to implement three kernels for WY representation preparation. The 'fwd_prepare_wy_repr_kernel' has 16 parameters, computing matrix operations and storing results. The 'fwd_recompute_w_u_kernel' with 16 parameters, recomputes similar operations as the first kernel. The 'bwd_prepare_wy_repr_kernel' takes 19 parameters, handling backward pass computations for WY preparation. Functions 'fwd_prepare_wy_repr', 'fwd_recompute_w_u', and 'bwd_prepare_wy_repr' serve as wrappers, preparing inputs for these kernels and executing them. 'WYRepresentationPrepration' is a custom autograd function for this computation.", - "description_2": "Use triton language to create forward and backward kernels for WY representation, which are wrapped in functions that prepare and execute these kernels, and an autograd function for tensor operations.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef chunk_gla_fwd_kernel_cum(\n s,\n o,\n s_s_h,\n s_s_t,\n s_s_d,\n T: tl.constexpr,\n S: tl.constexpr,\n BT: tl.constexpr,\n BS: tl.constexpr\n):\n i_s, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n o_i = tl.arange(0, BT)\n m_s = tl.where(o_i[:, None] >= o_i[None, :], 1., 0.)\n\n p_s = tl.make_block_ptr(s + i_bh * s_s_h, (T, S), (s_s_t, s_s_d), (i_t * BT, i_s * BS), (BT, BS), (1, 0))\n p_o = tl.make_block_ptr(o + i_bh * s_s_h, (T, S), (s_s_t, s_s_d), (i_t * BT, i_s * BS), (BT, BS), (1, 0))\n # [BT, BS]\n b_s = tl.load(p_s, boundary_check=(0, 1)).to(tl.float32)\n b_o = tl.dot(m_s, b_s, allow_tf32=False)\n tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))\n\n@triton.jit\ndef chunk_gla_fwd_kernel_h(\n k,\n v,\n g,\n h,\n h0,\n ht,\n s_k_h,\n s_k_t,\n s_k_d,\n s_v_h,\n s_v_t,\n s_v_d,\n s_h_h,\n s_h_t,\n s_h_d,\n T: tl.constexpr,\n K: tl.constexpr,\n V: tl.constexpr,\n BT: tl.constexpr,\n BK: tl.constexpr,\n BV: tl.constexpr,\n NT: tl.constexpr,\n USE_INITIAL_STATE: tl.constexpr,\n STORE_FINAL_STATE: tl.constexpr\n):\n i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n b_h = tl.zeros([BK, BV], dtype=tl.float32)\n if USE_INITIAL_STATE:\n p_h = tl.make_block_ptr(h0 + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))\n b_h += tl.load(p_h, boundary_check=(0, 1)).to(tl.float32)\n for i_t in range(NT):\n p_k = tl.make_block_ptr(k + i_bh * s_k_h, (K, T), (s_k_d, s_k_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1))\n p_v = tl.make_block_ptr(v + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))\n p_h = tl.make_block_ptr(h + i_bh * s_h_h + i_t * K * V, (K, V), (s_h_t, s_h_d), (i_k * BK, i_v * BV), (BK, BV), (1, 0))\n p_g = tl.make_block_ptr(g + i_bh * s_k_h, (K, T), (s_k_d, s_k_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1))\n p_gn = tl.make_block_ptr(g + i_bh * s_k_h, (T * K,), (s_k_d,), ((i_t * BT + BT - 1) * K + i_k * BK,), (BK,), (0,))\n\n tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1))\n # [BK, BT]\n b_k = tl.load(p_k, boundary_check=(0, 1))\n # [BT, BV]\n b_v = tl.load(p_v, boundary_check=(0, 1))\n # [BK, BT]\n b_g = tl.load(p_g, boundary_check=(0, 1))\n if i_t < NT - 1:\n # [BK,]\n b_gn = tl.load(p_gn, boundary_check=(0,))\n else:\n b_gn = tl.min(b_g, axis=1)\n b_h *= tl.exp(b_gn)[:, None]\n b_k = (b_k * tl.exp(b_gn[:, None] - b_g)).to(b_k.dtype)\n b_h += tl.dot(b_k, b_v, allow_tf32=False)\n\n if STORE_FINAL_STATE:\n p_h = tl.make_block_ptr(ht + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))\n tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1))\n\n@triton.jit\ndef chunk_gla_fwd_kernel_intra(\n q,\n k,\n g,\n A,\n s_k_h,\n s_k_t,\n s_k_d,\n scale,\n T: tl.constexpr,\n K: tl.constexpr,\n BT: tl.constexpr,\n BC: tl.constexpr,\n BK: tl.constexpr,\n NC: tl.constexpr\n):\n i_k, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n i_t, i_i, i_j = i_c // (NC * NC), (i_c % (NC * NC)) // NC, (i_c % (NC * NC)) % NC\n n_bh = tl.num_programs(2)\n\n if i_i > i_j:\n p_q = tl.make_block_ptr(q + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))\n p_g = tl.make_block_ptr(g + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))\n p_k = tl.make_block_ptr(k + i_bh * s_k_h, (K, T), (s_k_d, s_k_t), (i_k * BK, i_t * BT + i_j * BC), (BK, BC), (0, 1))\n p_gk = tl.make_block_ptr(g + i_bh * s_k_h, (K, T), (s_k_d, s_k_t), (i_k * BK, i_t * BT + i_j * BC), (BK, BC), (0, 1))\n p_gn = tl.make_block_ptr(g + i_bh * s_k_h, (T * K,), (s_k_d,), ((i_t * BT + i_i * BC) * K + i_k * BK,), (BK,), (0,))\n p_A = tl.make_block_ptr(A + (i_k*n_bh+i_bh)*T*BT, (T, BT), (BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0))\n # [BK,]\n b_gn = tl.load(p_gn, boundary_check=(0,))\n # [BC, BK]\n b_q = tl.load(p_q, boundary_check=(0, 1))\n b_g = tl.load(p_g, boundary_check=(0, 1))\n b_qg = (b_q * tl.exp(b_g - b_gn[None, :]) * scale).to(b_q.dtype)\n # [BK, BC]\n b_k = tl.load(p_k, boundary_check=(0, 1))\n b_gk = tl.load(p_gk, boundary_check=(0, 1))\n b_kg = (b_k * tl.exp(b_gn[:, None] - b_gk)).to(b_k.dtype)\n # [BC, BC]\n b_A = tl.dot(b_qg, b_kg, allow_tf32=False)\n tl.store(p_A, b_A.to(A.dtype.element_ty), boundary_check=(0, 1))\n elif i_i == i_j:\n p_q = tl.make_block_ptr(q + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))\n p_g = tl.make_block_ptr(g + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))\n p_k = tl.make_block_ptr(k + i_bh * s_k_h, (T * K,), (s_k_d,), ((i_t * BT + i_j * BC) * K + i_k * BK,), (BK,), (0,))\n p_gk = tl.make_block_ptr(g + i_bh * s_k_h, (T * K,), (s_k_d,), ((i_t * BT + i_j * BC) * K + i_k * BK,), (BK,), (0,))\n # [BC, BK]\n b_q = tl.load(p_q, boundary_check=(0, 1))\n b_g = tl.load(p_g, boundary_check=(0, 1))\n\n o_i = tl.arange(0, BC)\n o_A = (i_bh + i_k * n_bh) * T * BT + (i_t * BT + i_i * BC + tl.arange(0, BC)) * BT + i_j * BC\n m_A = (i_t * BT + i_i * BC + tl.arange(0, BC)) < T\n for j in range(0, BC):\n # [BK,]\n b_k = tl.load(p_k, boundary_check=(0,)).to(tl.float32)\n b_gk = tl.load(p_gk, boundary_check=(0,)).to(tl.float32)\n # [BC,]\n b_A = tl.sum(b_q * b_k[None, :] * tl.exp(b_g - b_gk[None, :]) * scale, 1)\n b_A = tl.where(o_i >= j, b_A, 0.)\n tl.store(A + o_A + j, b_A.to(b_q.dtype), mask=m_A)\n\n p_k = tl.advance(p_k, (K,))\n p_gk = tl.advance(p_gk, (K,))\n\n@triton.jit\ndef chunk_gla_fwd_kernel_inter(\n q,\n v,\n g,\n h,\n o,\n A,\n s_k_h,\n s_k_t,\n s_k_d,\n s_v_h,\n s_v_t,\n s_v_d,\n s_h_h,\n s_h_t,\n s_h_d,\n scale,\n T: tl.constexpr,\n K: tl.constexpr,\n V: tl.constexpr,\n BT: tl.constexpr,\n BK: tl.constexpr,\n BV: tl.constexpr\n):\n i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n\n b_o = tl.zeros([BT, BV], dtype=tl.float32)\n for i_k in range(tl.cdiv(K, BK)):\n p_q = tl.make_block_ptr(q + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))\n p_g = tl.make_block_ptr(g + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))\n p_h = tl.make_block_ptr(h + i_bh * s_h_h + i_t * K * V, (K, V), (s_h_t, s_h_d), (i_k * BK, i_v * BV), (BK, BV), (1, 0))\n\n # [BT, BK]\n b_q = tl.load(p_q, boundary_check=(0, 1))\n b_q = (b_q * scale).to(b_q.dtype)\n # [BT, BK]\n b_g = tl.load(p_g, boundary_check=(0, 1))\n # [BT, BK]\n b_qg = (b_q * tl.exp(b_g)).to(b_q.dtype)\n # [BK, BV]\n b_h = tl.load(p_h, boundary_check=(0, 1))\n # works but dkw, owing to divine benevolence\n # [BT, BV]\n if i_k >= 0:\n b_o += tl.dot(b_qg, b_h, allow_tf32=False)\n p_v = tl.make_block_ptr(v + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))\n p_o = tl.make_block_ptr(o + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))\n p_A = tl.make_block_ptr(A + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))\n # [BT, BV]\n b_v = tl.load(p_v, boundary_check=(0, 1))\n # [BT, BT]\n b_A = tl.load(p_A, boundary_check=(0, 1))\n b_o += tl.dot(b_A, b_v, allow_tf32=False)\n tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))\n\n\ndef chunk_gla(\n q: torch.Tensor,\n k: torch.Tensor,\n v: torch.Tensor,\n g: torch.Tensor,\n scale: Optional[int] = None,\n initial_state: torch.Tensor = None,\n output_final_state: bool = False,\n checkpoint_level: Optional[int] = 2\n) -> Tuple[torch.Tensor, torch.Tensor]:\n B, H, T, K, V = *q.shape, v.shape[-1]\n BT, BC = 64, 16\n BK = min(64, triton.next_power_of_2(K))\n BV = min(64, triton.next_power_of_2(V))\n NT, NC = triton.cdiv(T, BT), triton.cdiv(BT, BC)\n NK = triton.cdiv(K, BK)\n NV = triton.cdiv(V, BV)\n num_warps = 4 if BK == 64 else 2\n num_stages = 1\n\n def fwd_inner(q, k, v, g, B, H, T, K, V, BT, BK, BV, NT, h0=None, ht=None):\n NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)\n h = q.new_empty(B, H, NT * K, V)\n grid = (NV, NK, B * H)\n chunk_gla_fwd_kernel_h[grid](\n k, v, g, h, h0, ht,\n k.stride(1), k.stride(2), k.stride(3),\n v.stride(1), v.stride(2), v.stride(3),\n h.stride(1), h.stride(2), h.stride(3),\n T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT,\n USE_INITIAL_STATE=h0 is not None,\n STORE_FINAL_STATE=ht is not None,\n num_warps=num_warps,\n num_stages=num_stages\n )\n return h\n\n final_state = None\n if output_final_state:\n final_state = q.new_empty(B, H, K, V, dtype=torch.float)\n\n g_org, g = g, torch.empty_like(g, dtype=torch.float)\n def grid(meta): return ((triton.cdiv(meta['S'], meta['BS']), NT, B * H))\n # keep cummulative normalizer in fp32\n # this kernel is equivalent to\n # g = g.view(B, H, NT, BT, -1).cumsum(-2).view(B, H, T, -1)\n chunk_gla_fwd_kernel_cum[grid](\n g_org, g,\n g.stride(1), g.stride(2), g.stride(3),\n T=T, S=K, BT=BT\n )\n h = fwd_inner(\n q=q, k=k, v=v, g=g,\n B=B, H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT,\n h0=initial_state if initial_state is not None else None,\n ht=final_state if final_state is not None else None\n )\n A = q.new_zeros(NK, B, H, T, BT)\n grid = (NK, NT * NC * NC, B * H)\n chunk_gla_fwd_kernel_intra[grid](\n q, k, g, A,\n k.stride(1), k.stride(2), k.stride(3),\n scale,\n T=T, K=K, BT=BT, BC=BC, BK=BK, NC=NC,\n num_warps=num_warps,\n num_stages=num_stages\n )\n A = A.sum(0, dtype=A.dtype)\n o = torch.empty_like(v)\n grid = (NV, NT, B * H)\n chunk_gla_fwd_kernel_inter[grid](\n q, v, g, h, o, A,\n k.stride(1), k.stride(2), k.stride(3),\n v.stride(1), v.stride(2), v.stride(3),\n h.stride(1), h.stride(2), h.stride(3),\n scale,\n T=T, K=K, V=V, BT=BT, BK=BK, BV=BV,\n num_warps=num_warps,\n num_stages=num_stages\n )\n if checkpoint_level >= 1:\n del g\n g = g_org\n if checkpoint_level > 1:\n del h\n h, initial_state = None, None\n\n return o, final_state\n", - "description_1": "Use triton language to implement a series of forward kernels: `chunk_gla_fwd_kernel_cum`, `chunk_gla_fwd_kernel_h`, `chunk_gla_fwd_kernel_intra`, and `chunk_gla_fwd_kernel_inter`. These functions perform tensor operations for generalized linear attention (GLA) computation. The kernels manage tasks such as cumulative summation, block operations, intra-chunk, and inter-chunk calculations. These functions require various parameters including tensor shapes, strides, block sizes, and scales to process input tensors like queries, keys, values, and forget gates.", - "description_2": "Use triton language to implement Triton kernels for GLA attention. Develop forward kernels to handle cumulative operations, intra-chunk, and inter-chunk processing for tensors, incorporating input parameter management.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\nfrom einops import rearrange\nfrom packaging import version\n\n@triton.jit\ndef fused_chunk_gla_fwd_kernel(\n q, k, v, g, o, initial_state, final_state,\n s_qk_h, s_qk_t, s_qk_d, s_vo_h, s_vo_t, s_vo_d,\n B, H, T, scale, BT: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr,\n DK: tl.constexpr, DV: tl.constexpr, USE_INITIAL_STATE: tl.constexpr,\n STORE_FINAL_STATE: tl.constexpr, CHECK: tl.constexpr\n):\n i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n b_h = tl.zeros([BK, BV], dtype=tl.float32)\n p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (0, i_k * BK), (BT, BK), (1, 0))\n p_db = g + i_bh * s_qk_h + (BT - 1) * s_qk_t + i_k * BK + tl.arange(0, BK)\n p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (DK, T), (s_qk_d, s_qk_t), (i_k * BK, 0), (BK, BT), (0, 1))\n p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (0, i_v * BV), (BT, BV), (1, 0))\n p_o = tl.make_block_ptr(o + (i_bh + i_k * B * H) * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (0, i_v * BV), (BT, BV), (1, 0))\n\n if USE_INITIAL_STATE:\n p_h = tl.make_block_ptr(initial_state + i_bh * DK * DV, (DK, DV), (DV, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))\n b_h += tl.load(p_h, boundary_check=(0, 1)).to(tl.float32)\n\n mask = (i_k * BK + tl.arange(0, BK)) < DK\n\n for i in range(0, tl.cdiv(T, BT)):\n b_k = tl.load(p_k, boundary_check=(0, 1))\n b_o = tl.zeros([BT, BV], dtype=tl.float32)\n b_v = tl.load(p_v, boundary_check=(0, 1))\n b_q = tl.load(p_q, boundary_check=(0, 1))\n d_b = tl.load(p_db, mask=mask, other=0).to(tl.float32)\n if CHECK and i == 0:\n b_o = tl.dot(b_q.to(b_v.dtype), b_h.to(b_v.dtype), allow_tf32=False)\n b_h = b_h * tl.math.exp2(d_b)[:, None] + tl.dot(b_k.to(b_v.dtype), b_v, allow_tf32=False)\n else:\n b_o = tl.dot(b_q.to(b_v.dtype), b_h.to(b_v.dtype), allow_tf32=False)\n b_h = b_h * tl.math.exp2(d_b)[:, None] + tl.dot(b_k.to(b_v.dtype), b_v, allow_tf32=False)\n\n tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))\n p_q = tl.advance(p_q, (BT, 0))\n p_k = tl.advance(p_k, (0, BT))\n p_v = tl.advance(p_v, (BT, 0))\n p_o = tl.advance(p_o, (BT, 0))\n p_db += BT * DK\n\n if STORE_FINAL_STATE:\n p_final = tl.make_block_ptr(final_state + i_bh * DK * DV, (DK, DV), (DV, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))\n tl.store(p_final, b_h.to(p_final.dtype.element_ty), boundary_check=(0, 1))\n\n@triton.jit\ndef fused_chunk_gla_bwd_kernel(\n q, k, v, g, do, dq, dk, dv, initial_state,\n s_qk_h, s_qk_t, s_qk_d, s_vo_h, s_vo_t, s_vo_d,\n B, H, T, scale, BT: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr,\n DK: tl.constexpr, DV: tl.constexpr, USE_INITIAL_STATE: tl.constexpr, CHECK: tl.constexpr\n):\n i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n b_h = tl.zeros([BV, BK], dtype=tl.float32)\n\n if USE_INITIAL_STATE:\n p_h = tl.make_block_ptr(initial_state + i_bh * DK * DV, (DV, DK), (1, DV), (i_v * BV, i_k * BK), (BV, BK), (0, 1))\n b_h += tl.load(p_h, boundary_check=(0, 1)).to(tl.float32)\n\n mask = (i_k * BK + tl.arange(0, BK)) < DK\n for i in range(0, tl.cdiv(T, BT)):\n p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (i * BT, i_k * BK), (BT, BK), (1, 0))\n p_db = g + i_bh * s_qk_h + ((i+1) * BT - 1) * s_qk_t + i_k * BK + tl.arange(0, BK)\n p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (DV, T), (s_vo_d, s_vo_t), (i_v * BV, i * BT), (BV, BT), (0, 1))\n p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (i * BT, i_v * BV), (BT, BV), (1, 0))\n p_dq = tl.make_block_ptr(dq + (i_bh+i_v*B*H)*s_qk_h, (T, DK), (s_qk_t, s_qk_d), (i * BT, i_k * BK), (BT, BK), (1, 0))\n b_dq = tl.zeros([BT, BK], dtype=tl.float32)\n b_k = tl.load(p_k, boundary_check=(0, 1))\n d_b = tl.load(p_db, mask=mask, other=0).to(tl.float32)\n b_v = tl.load(p_v, boundary_check=(0, 1))\n b_do = tl.load(p_do, boundary_check=(0, 1))\n if CHECK and i == 0:\n b_dq += tl.dot(b_do, b_h.to(b_do.dtype), allow_tf32=False)\n b_h = b_h * tl.math.exp2(d_b)[None, :] + tl.dot(b_v, b_k.to(b_v.dtype), allow_tf32=False)\n else:\n b_dq += tl.dot(b_do, b_h.to(b_do.dtype), allow_tf32=False)\n b_h = b_h * tl.math.exp2(d_b)[None, :] + tl.dot(b_v, b_k.to(b_v.dtype), allow_tf32=False)\n b_dq *= scale\n tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))\n\n b_h = None\n tl.debug_barrier()\n b_dh = tl.zeros([BK, BV], dtype=tl.float32)\n\n for i in range(1, tl.cdiv(T, BT) + 1):\n p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (DK, T), (s_qk_d, s_qk_t), (i_k * BK, T - i * BT), (BK, BT), (0, 1))\n p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (T - i * BT, i_k * BK), (BT, BK), (1, 0))\n p_db = g + i_bh * s_qk_h + (T - (i-1) * BT - 1) * s_qk_t + i_k * BK + tl.arange(0, BK)\n p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (T - i * BT, i_v * BV), (BT, BV), (1, 0))\n p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (T - i * BT, i_v * BV), (BT, BV), (1, 0))\n p_dk = tl.make_block_ptr(dk + (i_bh + i_v * B * H) * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (T - i * BT, i_k * BK), (BT, BK), (1, 0))\n p_dv = tl.make_block_ptr(dv + (i_bh + i_k * B * H) * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (T - i * BT, i_v * BV), (BT, BV), (1, 0))\n b_q = tl.load(p_q, boundary_check=(0, 1))\n b_k = tl.load(p_k, boundary_check=(0, 1))\n b_v = tl.load(p_v, boundary_check=(0, 1))\n b_do = tl.load(p_do, boundary_check=(0, 1))\n b_db = tl.load(p_db, mask=mask, other=0).to(tl.float32)\n\n if CHECK and i == 1:\n b_dk = tl.trans(tl.dot(b_dh.to(b_v.dtype), tl.trans(b_v), allow_tf32=False))\n b_dv = tl.dot((b_k).to(b_v.dtype), b_dh.to(b_v.dtype), allow_tf32=False)\n b_dh = b_dh * tl.math.exp2(b_db)[:, None] + tl.dot(b_q.to(b_do.dtype), b_do, allow_tf32=False)\n else:\n b_dk = tl.trans(tl.dot(b_dh.to(b_v.dtype), tl.trans(b_v), allow_tf32=False))\n b_dv = tl.dot((b_k).to(b_v.dtype), b_dh.to(b_v.dtype), allow_tf32=False)\n b_dh = b_dh * tl.math.exp2(b_db)[:, None] + tl.dot(b_q.to(b_do.dtype), b_do, allow_tf32=False)\n\n tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))\n tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))\n\nclass FusedChunkGLAFunction(torch.autograd.Function):\n\n @staticmethod\n def forward(ctx, q, k, v, g, scale, initial_state, output_final_state):\n ctx.g_dtype = g.dtype\n g_original = g\n g = torch.empty_like(g, dtype=torch.float32)\n batch_size, n_heads, seq_len, d_head_qk = q.shape\n d_head_v = v.shape[-1]\n ctx.scale = scale\n\n BT = 16\n BK, BV = min(d_head_qk, 64), min(d_head_v, 64)\n num_stages = 1\n num_warps = 2\n\n NK, NV = triton.cdiv(d_head_qk, BK), triton.cdiv(d_head_v, BV)\n o = q.new_empty(NK, batch_size, n_heads, seq_len, d_head_v)\n q_g = torch.empty_like(q)\n k_g = torch.empty_like(k)\n grid = (NK, triton.cdiv(seq_len, BT), batch_size * n_heads)\n prepare_qg_kg[grid](\n q, k, g, q_g, k_g,\n q.stride(1), q.stride(2), q.stride(3),\n batch_size, n_heads, seq_len, scale,\n BT=BT, BK=BK, DK=d_head_qk, num_warps=1\n )\n\n if output_final_state:\n final_state = q.new_empty(batch_size, n_heads, d_head_qk, d_head_v, dtype=torch.float, requires_grad=False)\n else:\n final_state = None\n\n CHECK = True\n if version.parse(triton.__version__) < version.parse('2.2.0'):\n import warnings\n warnings.warn(\n \"Triton<2.2.0 detected for running this kernel, \"\n \"which is known to have some weird compiler issues (refer to https://github.com/openai/triton/issues/2852) \"\n \"that lead to significant precision loss. \"\n \"We've add some initial condition checks to resolve this, sadly at the sacrifice of the speed. \"\n \"For optimal performance, it is recommended to install Triton>=2.2.0 (if possible).\"\n )\n CHECK = True\n\n grid = (NV, NK, batch_size * n_heads)\n fused_chunk_gla_fwd_kernel[grid](\n q_g, k_g, v, g, o, initial_state, final_state,\n q.stride(1), q.stride(2), q.stride(3),\n v.stride(1), v.stride(2), v.stride(3),\n batch_size, n_heads, seq_len, scale,\n BT=BT, DK=d_head_qk, DV=d_head_v, BK=BK, BV=BV,\n USE_INITIAL_STATE=initial_state is not None,\n STORE_FINAL_STATE=output_final_state,\n CHECK=CHECK,\n num_warps=num_warps,\n num_stages=num_stages\n )\n\n o = o.sum(0)\n\n chunk_size = 16\n num_chunk = seq_len // chunk_size\n v2 = rearrange(v, 'b h (n c) d -> b h n c d', n=num_chunk)\n BK = min(d_head_qk, 64)\n NK = triton.cdiv(d_head_qk, BK)\n A = q.new_empty(NK, batch_size, n_heads, triton.cdiv(seq_len, BT), BT, BT)\n grid = (NK, triton.cdiv(seq_len, BT), batch_size * n_heads)\n fwd_inner_chunk[grid](\n q, k, g, A,\n q.stride(1), q.stride(2), q.stride(3),\n batch_size, n_heads, seq_len, scale, BT=BT, BK=BK, DK=d_head_qk, num_stages=3,\n num_warps=4\n )\n A = A.sum(0)\n o2 = A @ v2\n o2 = rearrange(o2, 'b h n c d -> b h (n c) d')\n o.add_(o2)\n ctx.save_for_backward(q, k, v, g_original, A, initial_state)\n ctx.CHECK = CHECK\n return o.to(v), final_state\n\n @staticmethod\n def backward(ctx, do, d_final_state=None):\n q, k, v, g_origin, A, initial_state = ctx.saved_tensors\n batch_size, n_heads, seq_len, d_head_qk = q.shape\n d_head_v = v.shape[-1]\n scale = ctx.scale\n\n BT = 16\n g = torch.empty_like(g_origin, dtype=torch.float32)\n BK, BV = min(d_head_qk, 64), min(d_head_v, 64)\n NK, NV = triton.cdiv(d_head_qk, BK), triton.cdiv(d_head_v, BV)\n q_g = torch.empty_like(q)\n k_g = torch.empty_like(k)\n grid = (NK, triton.cdiv(seq_len, BT), batch_size * n_heads)\n prepare_qg_kg[grid](\n q, k, g, q_g, k_g,\n q.stride(1), q.stride(2), q.stride(3),\n batch_size, n_heads, seq_len, scale,\n BT=BT, BK=BK, DK=d_head_qk, num_warps=1\n )\n\n BT = 16\n BK, BV = min(triton.next_power_of_2(d_head_qk), 64), min(triton.next_power_of_2(d_head_v), 64)\n NK, NV = triton.cdiv(d_head_qk, BK), triton.cdiv(d_head_v, BV)\n num_stages = 1\n num_warps = 2\n dq = q.new_empty(NV, batch_size, n_heads, seq_len, d_head_qk)\n dk = q.new_empty(NV, batch_size, n_heads, seq_len, d_head_qk)\n dv = q.new_empty(NK, batch_size, n_heads, seq_len, d_head_v)\n\n grid = (NV, NK, batch_size * n_heads)\n\n fused_chunk_gla_bwd_kernel[grid](\n q_g, k_g, v, g, do, dq, dk, dv, initial_state,\n q.stride(1), q.stride(2), q.stride(3),\n v.stride(1), v.stride(2), v.stride(3),\n batch_size, n_heads, seq_len, scale,\n BT=BT, DK=d_head_qk, DV=d_head_v, BK=BK, BV=BV,\n USE_INITIAL_STATE=initial_state is not None,\n CHECK=ctx.CHECK,\n num_warps=num_warps,\n num_stages=num_stages,\n )\n dq = dq.sum(0)\n dk = dk.sum(0)\n dv = dv.sum(0)\n\n num_chunk = seq_len // BT\n v2 = rearrange(v, 'b h (n c) d -> b h n c d', n=num_chunk)\n do2 = rearrange(do, 'b h (n c) d -> b h n c d', n=num_chunk)\n dA2 = (do2 @ v2.transpose(-2, -1)) * scale\n dv2 = A.transpose(-1, -2) @ do2\n dv2 = rearrange(dv2, 'b h n c d -> b h (n c) d', n=num_chunk)\n\n BK = min(triton.next_power_of_2(d_head_qk), 16)\n NK = triton.cdiv(d_head_qk, BK)\n dk2 = torch.empty_like(k)\n dq2 = torch.empty_like(q)\n\n grid = (NK, triton.cdiv(seq_len, BT), batch_size * n_heads)\n bwd_inner_chunk[grid](\n q, k, g, dA2, dq2, dk2,\n q.stride(1), q.stride(2), q.stride(3),\n batch_size, n_heads, seq_len, scale,\n BT=BT, DK=d_head_qk, BK=BK,\n num_warps=1,\n num_stages=3\n )\n\n BK = min(triton.next_power_of_2(d_head_qk), 32)\n NK = triton.cdiv(d_head_qk, BK)\n dg = torch.empty_like(g, dtype=torch.float32)\n grid = (NK, triton.cdiv(seq_len, BT), batch_size * n_heads)\n bwd_decay_global_cumsum[grid](\n dq2, dq, dk2, dk, q, k, g, dg,\n q.stride(1), q.stride(2), q.stride(3),\n batch_size, n_heads, seq_len, scale,\n BT=BT, DK=d_head_qk, BK=BK,\n num_warps=1,\n num_stages=1\n )\n dg = rearrange(dg, 'b h (n c) d -> b h n c d', c=BT)\n\n def rev_cumsum_exclusive(x):\n cumsum_x = x.cumsum(-2)\n rev_cumsum_x = cumsum_x[..., -1, None, :] - cumsum_x\n return rev_cumsum_x\n\n rev_cumsum_dg = rev_cumsum_exclusive(dg[..., 0, :])\n dg.add_(rev_cumsum_dg.unsqueeze(-2))\n dv.add_(dv2)\n dg = rearrange(dg, 'b h n c d -> b h (n c) d')\n\n return dq.to(q), dk.to(k), dv.to(v), dg.to(ctx.g_dtype), None, None, None\n\ndef fused_chunk_gla(\n q: torch.Tensor,\n k: torch.Tensor,\n v: torch.Tensor,\n g: torch.Tensor,\n scale: int = -1,\n initial_state: torch.Tensor = None,\n output_final_state: bool = False\n) -> Tuple[torch.Tensor, torch.Tensor]:\n if scale == -1:\n scale = q.shape[-1] ** -0.5\n if initial_state is not None:\n initial_state = initial_state.detach()\n seq_len = q.shape[-2]\n q, k, v, g = map(lambda x: pad(x), [q, k, v, g])\n o, final_state = FusedChunkGLAFunction.apply(\n q, k, v, g, scale, initial_state, output_final_state)\n o = o[..., :seq_len, :]\n return o, final_state\n", - "description_1": "Use triton language to implement fused_chunk_gla_fwd_kernel and fused_chunk_gla_bwd_kernel. The forward kernel takes 24 parameters, including input tensors q, k, v, and g, output tensor o, and various constants and strides. The backward kernel also takes 24 parameters, including input tensors q, k, v, g, and do, and output tensors dq, dk, and dv. The kernels perform computations for a fused chunked gated linear attention mechanism, leveraging triton's GPU acceleration.", - "description_2": "Use triton language to create fused forward and backward kernels for a chunked gated linear attention mechanism, operating on inputs q, k, v, and g, with specific attention to memory layout and block sizes for GPU optimization.", - "difficulty": 4 - }, - { - "code": "import triton\nimport triton.language as tl\n\ninv_ln2 = 1.44269504\n\n# Kernel for forward decay cumulative sum\n@triton.jit\ndef fwd_decay_cumsum(\n g, g_o, s_qk_h, s_qk_t, s_qk_d, B, H, T, scale,\n BT: tl.constexpr, BK: tl.constexpr, DK: tl.constexpr\n):\n i_k, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n p_g = g + i_bh * s_qk_h + i_c * BT * DK + i_k * BK + tl.arange(0, BK)\n p_go = g_o + i_bh * s_qk_h + i_c * BT * DK + i_k * BK + tl.arange(0, BK)\n cum_decay = tl.zeros([BK], dtype=tl.float32)\n mask = (i_k * BK + tl.arange(0, BK)) < DK\n\n for i in range(BT):\n _g = tl.load(p_g, mask=mask, other=0).to(tl.float32)\n cum_decay += _g * inv_ln2\n tl.store(p_go, cum_decay.to(p_go.dtype.element_ty), mask=mask)\n p_g += DK\n p_go += DK\n\n# Kernel for preparing qg and kg\n@triton.jit\ndef prepare_qg_kg(\n q, k, g, qg, kg, s_qk_h, s_qk_t, s_qk_d, B, H, T, scale,\n BT: tl.constexpr, BK: tl.constexpr, DK: tl.constexpr\n):\n i_k, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n p_q = q + i_bh * s_qk_h + i_c * BT * DK + i_k * BK + tl.arange(0, BK)\n p_g = g + i_bh * s_qk_h + i_c * BT * DK + i_k * BK + tl.arange(0, BK)\n p_k = k + i_bh * s_qk_h + i_c * BT * DK + i_k * BK + tl.arange(0, BK)\n p_qg = qg + i_bh * s_qk_h + i_c * BT * DK + i_k * BK + tl.arange(0, BK)\n p_kg = kg + i_bh * s_qk_h + i_c * BT * DK + i_k * BK + tl.arange(0, BK)\n\n mask = (i_k * BK + tl.arange(0, BK)) < DK\n\n last_decay = tl.load(g + i_bh * s_qk_h + (i_c * BT + BT - 1) * DK + i_k * BK + tl.arange(0, BK))\n\n for i in range(BT):\n _q = tl.load(p_q, mask=mask, other=0)\n _k = tl.load(p_k, mask=mask, other=0)\n _g = tl.load(p_g, mask=mask, other=0).to(tl.float32)\n _q *= tl.math.exp2(_g) * scale\n _k *= tl.math.exp2(last_decay - _g)\n tl.store(p_kg, _k.to(p_kg.dtype.element_ty), mask=mask)\n tl.store(p_qg, _q.to(p_qg.dtype.element_ty), mask=mask)\n p_q += DK\n p_g += DK\n p_k += DK\n p_kg += DK\n p_qg += DK\n\n# Kernel for backward decay global cumulative sum\n@triton.jit\ndef bwd_decay_global_cumsum(\n dq_inner, dq_inter, dk_inner, dk_inter, q, k, g, dg,\n s_qk_h, s_qk_t, s_qk_d, B, H, T, scale,\n BT: tl.constexpr, BK: tl.constexpr, DK: tl.constexpr\n):\n i_k, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n p_q = q + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (i_c * BT + BT - 1) * DK\n p_k = k + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (i_c * BT + BT - 1) * DK\n p_g = g + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (i_c * BT + BT - 1) * DK\n p_dg = dg + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (i_c * BT + BT - 1) * DK\n p_dq_inner = dq_inner + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (i_c * BT + BT - 1) * DK\n p_dk_inner = dk_inner + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (i_c * BT + BT - 1) * DK\n p_dq_inter = dq_inter + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (i_c * BT + BT - 1) * DK\n p_dk_inter = dk_inter + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (i_c * BT + BT - 1) * DK\n cum_grad_dg = tl.zeros([BK], dtype=tl.float32)\n mask = (i_k * BK + tl.arange(0, BK)) < DK\n last_g = tl.zeros([BK], dtype=tl.float32)\n for j in range(BT-1, -1, -1):\n _g = tl.load(p_g, mask=mask, other=0).to(tl.float32)\n if j == (BT-1):\n last_g = _g\n _dq1 = tl.load(p_dq_inner, mask=mask, other=0)\n _dq2 = tl.load(p_dq_inter, mask=mask, other=0)\n _dq2 *= tl.math.exp2(_g)\n _dq = _dq1 + _dq2\n tl.store(p_dq_inter, _dq, mask=mask)\n _dk1 = tl.load(p_dk_inner, mask=mask, other=0)\n _dk2 = tl.load(p_dk_inter, mask=mask, other=0)\n _dk2 *= tl.math.exp2(last_g - _g)\n _dk = _dk1 + _dk2\n tl.store(p_dk_inter, _dk, mask=mask)\n _q = tl.load(p_q, mask=mask, other=0)\n _k = tl.load(p_k, mask=mask, other=0)\n _dg = _dq * _q - _dk * _k\n cum_grad_dg += _dg\n tl.store(p_dg, cum_grad_dg.to(p_dg.dtype.element_ty), mask=mask)\n p_g -= DK\n p_k -= DK\n p_q -= DK\n p_dq_inner -= DK\n p_dk_inner -= DK\n p_dq_inter -= DK\n p_dk_inter -= DK\n p_dg -= DK\n", - "description_1": "Use triton language to implement three kernels: fwd_decay_cumsum, prepare_qg_kg, and bwd_decay_global_cumsum. Each kernel processes data in parallel using triton's program_id to handle different dimensions. The fwd_decay_cumsum kernel computes a cumulative sum with decay, prepare_qg_kg prepares qg and kg tensors by applying transformations based on input tensors q, k, and g, and bwd_decay_global_cumsum computes gradients for decay using backward pass logic.", - "description_2": "Use triton language to create kernels for forward and backward cumulative sum operations with decay, and to prepare transformed tensors for further computation.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef fused_recurrent_gla_fwd_kernel(\n q, k, v, gk, gv, o, h0, ht, s_qk_h, s_vo_h, scale,\n B: tl.constexpr, H: tl.constexpr, T: tl.constexpr, K: tl.constexpr, V: tl.constexpr,\n BK: tl.constexpr, BV: tl.constexpr, USE_INITIAL_STATE: tl.constexpr,\n STORE_FINAL_STATE: tl.constexpr, REVERSE: tl.constexpr, USE_GK: tl.constexpr,\n USE_GV: tl.constexpr\n):\n # Kernel code\n i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n p_q = q + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + ((T-1) * K if REVERSE else 0)\n p_k = k + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + ((T-1) * K if REVERSE else 0)\n p_v = v + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV) + ((T-1) * V if REVERSE else 0)\n p_o = o + (i_bh + i_k * B * H) * s_vo_h + i_v * BV + tl.arange(0, BV) + ((T-1) * V if REVERSE else 0)\n if USE_GK:\n p_gk = gk + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + ((T-1) * K if REVERSE else 0)\n if USE_GV:\n p_gv = gv + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV) + ((T-1) * V if REVERSE else 0)\n\n mask_bk = (i_k * BK + tl.arange(0, BK)) < K\n mask_bv = (i_v * BV + tl.arange(0, BV)) < V\n\n h = tl.zeros([BV, BK], dtype=tl.float32)\n mask_kv = mask_bk[None, :] & mask_bv[:, None]\n\n if USE_INITIAL_STATE:\n p_h0 = h0 + i_bh * K * V + (i_k * BK + tl.arange(0, BK)[None, :]) * V + (i_v * BV + tl.arange(0, BV)[:, None])\n h += tl.load(p_h0, mask=mask_kv, other=0).to(tl.float32)\n\n for _ in range(0, T):\n b_k = tl.load(p_k, mask=mask_bk, other=0).to(tl.float32)\n b_v = tl.load(p_v, mask=mask_bv, other=0).to(tl.float32)\n b_q = tl.load(p_q, mask=mask_bk, other=0).to(tl.float32) * scale\n if USE_GK:\n b_gk = tl.load(p_gk, mask=mask_bk, other=0).to(tl.float32)\n h = h * tl.exp(b_gk[None, :])\n if USE_GV:\n b_gv = tl.load(p_gv, mask=mask_bv, other=0).to(tl.float32)\n h = h * tl.exp(b_gv[:, None])\n h += b_k[None, :] * b_v[:, None]\n _o = h * b_q[None, :]\n _o = tl.sum(_o, axis=1)\n tl.store(p_o, _o.to(p_o.dtype.element_ty), mask=mask_bv)\n p_q += -K if REVERSE else K\n p_k += -K if REVERSE else K\n p_o += -V if REVERSE else V\n p_v += -V if REVERSE else V\n if USE_GK:\n p_gk += -K if REVERSE else K\n if USE_GV:\n p_gv += -V if REVERSE else V\n\n if STORE_FINAL_STATE:\n p_ht = ht + i_bh * K * V + (i_k * BK + tl.arange(0, BK)[None, :]) * V + (i_v * BV + tl.arange(0, BV)[:, None])\n tl.store(p_ht, h.to(p_ht.dtype.element_ty), mask=mask_kv)\n\n@triton.jit\ndef fused_recurrent_gla_bwd_kernel(\n q, k, v, gk, gv, do, dq, dk, dv, h0, s_qk_h, s_vo_h, scale, B, H, T,\n K: tl.constexpr, V: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr,\n USE_INITIAL_STATE: tl.constexpr, REVERSE: tl.constexpr, USE_GK: tl.constexpr, USE_GV: tl.constexpr\n):\n # Kernel code\n i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n p_q = q + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + ((T-1) * K if REVERSE else 0)\n p_k = k + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + ((T-1) * K if REVERSE else 0)\n p_v = v + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV) + ((T-1) * V if REVERSE else 0)\n p_do = do + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV) + ((T-1) * V if REVERSE else 0)\n p_dq = dq + (i_bh + i_v * B * H) * s_qk_h + i_k * BK + tl.arange(0, BK) + ((T-1) * K if REVERSE else 0)\n if USE_GK:\n p_gk = gk + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + ((T-1) * K if REVERSE else 0)\n if USE_GV:\n p_gv = gv + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV) + ((T-1) * V if REVERSE else 0)\n mask_bk = i_k * BK + tl.arange(0, BK) < K\n mask_bv = i_v * BV + tl.arange(0, BV) < V\n mask_kv = mask_bk[:, None] & mask_bv[None, :]\n h = tl.zeros([BK, BV], dtype=tl.float32)\n\n if USE_INITIAL_STATE:\n p_h0 = h0 + i_bh * K * V + (i_k * BK + tl.arange(0, BK)[:, None]) * V + (i_v * BV + tl.arange(0, BV)[None, :])\n h += tl.load(p_h0, mask=mask_kv, other=0).to(tl.float32)\n\n for i in range(0, T):\n b_k = tl.load(p_k, mask=mask_bk, other=0).to(tl.float32)\n b_v = tl.load(p_v, mask=mask_bv, other=0).to(tl.float32)\n b_do = tl.load(p_do, mask=mask_bv, other=0).to(tl.float32)\n if USE_GK:\n b_gk = tl.load(p_gk, mask=mask_bk, other=0).to(tl.float32)\n h = h * tl.exp(b_gk[:, None])\n if USE_GV:\n b_gv = tl.load(p_gv, mask=mask_bv, other=0).to(tl.float32)\n h = h * tl.exp(b_gv[None, :])\n h += b_k[:, None] * b_v[None, :]\n b_dq = h * b_do[None, :]\n d_q = tl.sum(b_dq, axis=1) * scale\n tl.store(p_dq, d_q.to(p_dq.dtype.element_ty), mask=mask_bk)\n\n p_k += -K if REVERSE else K\n p_v += -V if REVERSE else V\n p_q += -K if REVERSE else K\n p_do += -V if REVERSE else V\n p_dq += -K if REVERSE else K\n if USE_GK:\n p_gk += -K if REVERSE else K\n if USE_GV:\n p_gv += -V if REVERSE else V\n\n tl.debug_barrier()\n\n p_q = q + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + ((T - 1) * K if not REVERSE else 0)\n p_k = k + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + ((T - 1) * K if not REVERSE else 0)\n p_v = v + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV) + ((T - 1) * V if not REVERSE else 0)\n p_do = do + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV) + ((T - 1) * V if not REVERSE else 0)\n p_dk = dk + (i_bh + i_v * B * H) * s_qk_h + i_k * BK + tl.arange(0, BK) + ((T - 1) * K if not REVERSE else 0)\n p_dv = dv + (i_bh + i_k * B * H) * s_vo_h + i_v * BV + ((T - 1) * V if not REVERSE else 0)\n if USE_GK:\n p_gk = gk + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + ((T - 1) * K if not REVERSE else 0)\n if USE_GV:\n p_gv = gv + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV) + ((T - 1) * V if not REVERSE else 0)\n\n d_h = tl.zeros([BK, BV], dtype=tl.float32)\n\n for _ in range(T):\n b_do = tl.load(p_do, mask=mask_bv, other=0).to(tl.float32)\n b_q = tl.load(p_q, mask=mask_bk, other=0).to(tl.float32) * scale\n b_k = tl.load(p_k, mask=mask_bk, other=0).to(tl.float32)\n b_v = tl.load(p_v, mask=mask_bv, other=0).to(tl.float32)\n d_h += b_q[:, None] * b_do[None, :]\n d_k = tl.sum(d_h * b_v[None, :], axis=1)\n d_v = tl.sum(d_h * b_k[:, None], axis=0)\n if USE_GK:\n b_gk = tl.load(p_gk, mask=mask_bk, other=0).to(tl.float32)\n d_h *= tl.exp(b_gk)[:, None]\n if USE_GV:\n b_gv = tl.load(p_gv, mask=mask_bv, other=0).to(tl.float32)\n d_h *= tl.exp(b_gv)[None, :]\n tl.store(p_dk, d_k.to(p_dk.dtype.element_ty), mask=mask_bk)\n tl.store(p_dv, d_v.to(p_dv.dtype.element_ty), mask=mask_bv)\n\n p_q += K if REVERSE else -K\n p_k += K if REVERSE else -K\n p_v += V if REVERSE else -V\n p_do += V if REVERSE else -V\n p_dk += K if REVERSE else -K\n p_dv += V if REVERSE else -V\n if USE_GK:\n p_gk += K if REVERSE else -K\n if USE_GV:\n p_gv += V if REVERSE else -V\n\nclass FusedRecurrentGLAFunction(torch.autograd.Function):\n\n @staticmethod\n def forward(ctx, q, k, v, gk, gv, scale=None, initial_state=None, output_final_state=False, reverse=False):\n B, H, T, K, V = *q.shape, v.shape[-1]\n if scale is None:\n scale = K ** -0.5\n BK, BV = min(K, 64), min(V, 64)\n NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)\n num_stages = 1\n num_warps = 1\n\n o = q.new_empty(NK, B, H, T, V, dtype=torch.float32)\n\n if output_final_state:\n final_state = q.new_empty(B, H, K, V)\n else:\n final_state = None\n\n grid = (NV, NK, B * H)\n fused_recurrent_gla_fwd_kernel[grid](\n q, k, v, gk, gv, o, initial_state, final_state,\n q.stride(1), v.stride(1),\n scale,\n B=B, H=H, T=T, K=K, V=V,\n BK=BK, BV=BV,\n USE_INITIAL_STATE=initial_state is not None,\n STORE_FINAL_STATE=final_state is not None,\n USE_GK=gk is not None,\n USE_GV=gv is not None,\n REVERSE=reverse,\n num_warps=num_warps,\n num_stages=num_stages\n )\n\n o = o.sum(0)\n ctx.save_for_backward(q, k, v, gk, gv, initial_state, o)\n ctx.scale = scale\n ctx.reverse = reverse\n if final_state is not None:\n final_state = final_state.detach()\n return o.to(q.dtype), final_state\n\n @staticmethod\n def backward(ctx, do, dht=None):\n q, k, v, gk, gv, initial_state, o = ctx.saved_tensors\n batch_size, n_heads, seq_len, K = q.shape\n V = v.shape[-1]\n scale = ctx.scale\n\n BK, BV = min(K, 64), min(V, 64)\n NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)\n num_stages = 1\n num_warps = 1\n\n dq = q.new_empty(NV, batch_size, n_heads, seq_len, K, dtype=torch.float32)\n dk = q.new_empty(NV, batch_size, n_heads, seq_len, K, dtype=torch.float32)\n dv = q.new_empty(NK, batch_size, n_heads, seq_len, V, dtype=torch.float32)\n grid = (NV, NK, batch_size * n_heads)\n\n fused_recurrent_gla_bwd_kernel[grid](\n q, k, v, gk, gv, do, dq, dk, dv, initial_state,\n q.stride(1),\n v.stride(1), scale,\n B=batch_size, H=n_heads, T=seq_len, K=K, V=V, BK=BK, BV=BV,\n num_warps=num_warps,\n num_stages=num_stages,\n USE_INITIAL_STATE=initial_state is not None,\n REVERSE=ctx.reverse,\n USE_GK=gk is not None,\n USE_GV=gv is not None\n )\n dq = dq.sum(0)\n dk = dk.sum(0)\n dv = dv.sum(0)\n if gk is not None:\n _dgk = dq * q.float() - dk * k.float()\n if ctx.reverse:\n dgk = _dgk.cumsum(-2)\n else:\n _dgk_cumsum = _dgk.cumsum(-2)\n dgk = _dgk + _dgk_cumsum[:, :, -1, None] - _dgk_cumsum\n else:\n dgk = None\n\n if gv is not None:\n _dgv = do.float() * o.float() - dv * v.float()\n if ctx.reverse:\n dgv = _dgv.cumsum(-2)\n else:\n _dgv_cumsum = _dgv.cumsum(-2)\n dgv = _dgv + _dgv_cumsum[:, :, -1, None] - _dgv_cumsum\n else:\n dgv = None\n\n return dq.to(q.dtype), dk.to(k.dtype), dv.to(v.dtype), dgk, dgv, None, None, None, None\n\ndef fused_recurrent_gla(\n q: torch.Tensor,\n k: torch.Tensor,\n v: torch.Tensor,\n gk: torch.Tensor = None,\n gv: torch.Tensor = None,\n scale: int = -1,\n initial_state: torch.Tensor = None,\n output_final_state: bool = False,\n causal: bool = True\n) -> Tuple[torch.Tensor, torch.Tensor]:\n if scale == -1:\n scale = q.shape[-1] ** -0.5\n if initial_state is not None:\n initial_state = initial_state.detach()\n if causal:\n o, final_state = FusedRecurrentGLAFunction.apply(q, k, v, gk, gv, scale, initial_state, output_final_state)\n return o, final_state\n else:\n assert initial_state is None\n assert output_final_state is False\n o, final_state = FusedRecurrentGLAFunction.apply(q, k, v, gk, gv, scale, initial_state, output_final_state, False)\n o_reversed, final_state = FusedRecurrentGLAFunction.apply(\n q, k, v, gk, gv, scale, initial_state, output_final_state, True)\n return o, o_reversed\n", - "description_1": "Use triton language to implement a forward kernel (fused_recurrent_gla_fwd_kernel) with 21 parameters, executing operations on tensors with optional initial states, masks, and various operations like exponentiation and summation. Another kernel (fused_recurrent_gla_bwd_kernel) implements a backward operation with 21 parameters, computing gradients with respect to inputs using similar operations. These kernels are invoked by a PyTorch custom autograd function (FusedRecurrentGLAFunction) with methods 'forward' and 'backward', managing the computation flow and tensor operations for a recurrent attention mechanism. Finally, the function 'fused_recurrent_gla' serves as an interface for users, handling parameter checks, optional states, and invoking the autograd function appropriately.", - "description_2": "Use triton language to implement forward and backward kernels for a fused recurrent attention mechanism, utilizing tensor operations and optional states. Interface these kernels with a PyTorch custom autograd function to manage the execution flow.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.autotune(\n configs=[\n triton.Config({'BD': 32}, num_warps=1),\n triton.Config({'BD': 32}, num_warps=2),\n triton.Config({'BD': 32}, num_warps=4),\n triton.Config({'BD': 32}, num_warps=8),\n triton.Config({'BD': 64}, num_warps=1),\n triton.Config({'BD': 64}, num_warps=2),\n triton.Config({'BD': 64}, num_warps=4),\n triton.Config({'BD': 64}, num_warps=8),\n triton.Config({'BD': 128}, num_warps=1),\n triton.Config({'BD': 128}, num_warps=2),\n triton.Config({'BD': 128}, num_warps=4),\n triton.Config({'BD': 128}, num_warps=8),\n ],\n key=['D']\n)\n@triton.jit\ndef chunk_hgrn_fwd_kernel_h(\n x,\n g,\n gc,\n o,\n h0,\n T: tl.constexpr,\n D: tl.constexpr,\n BT: tl.constexpr,\n BD: tl.constexpr,\n USE_INITIAL_STATE: tl.constexpr\n):\n i_d, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n o_d = i_d * BD + tl.arange(0, BD)\n mask = o_d < D\n\n p_x = x + i_bh * T * D + i_t * BT * D + o_d\n p_g = g + i_bh * T * D + i_t * BT * D + o_d\n p_gc = gc + i_bh * T * D + i_t * BT * D + o_d\n p_o = o + i_bh * T * D + i_t * BT * D + o_d\n\n b_h = tl.zeros([BD], dtype=tl.float32)\n b_gc = tl.zeros([BD], dtype=tl.float32)\n if USE_INITIAL_STATE:\n if i_t == 0:\n b_h += tl.load(h0 + i_bh * D + o_d, mask=mask, other=0).to(tl.float32)\n for i in range(0, BT):\n mask_t = mask & ((i_t * BT + i) < T)\n b_x = tl.load(p_x, mask=mask_t, other=0).to(tl.float32)\n b_g = tl.load(p_g, mask=mask_t, other=0).to(tl.float32)\n b_h = tl.exp(b_g) * b_h + b_x\n b_gc = b_gc + b_g\n tl.store(p_gc, b_gc.to(p_o.dtype.element_ty), mask=mask_t)\n tl.store(p_o, b_h.to(p_o.dtype.element_ty), mask=mask_t)\n\n p_x += D\n p_g += D\n p_gc += D\n p_o += D\n\n@triton.jit\ndef chunk_hgrn_fwd_kernel_o(\n gc,\n o,\n s_h,\n s_t,\n s_d,\n T: tl.constexpr,\n D: tl.constexpr,\n BT: tl.constexpr,\n BD: tl.constexpr\n):\n i_d, i_bh = tl.program_id(0), tl.program_id(1)\n o_d = i_d * BD + tl.arange(0, BD)\n mask = o_d < D\n\n for i_t in range(1, tl.cdiv(T, BT)):\n p_gc = tl.make_block_ptr(gc + i_bh * s_h, (T, D), (s_t, s_d), (i_t * BT, i_d * BD), (BT, BD), (1, 0))\n p_o = tl.make_block_ptr(o + i_bh * s_h, (T, D), (s_t, s_d), (i_t * BT, i_d * BD), (BT, BD), (1, 0))\n\n b_h0 = tl.load(o + i_bh * T * D + i_t * BT * D - D + o_d, mask=mask, other=0).to(tl.float32)\n b_gc = tl.load(p_gc, boundary_check=(0, 1)).to(tl.float32)\n b_o = tl.load(p_o, boundary_check=(0, 1)).to(tl.float32)\n b_o = b_o + tl.exp(b_gc) * b_h0[None, :]\n tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))\n\n@triton.autotune(\n configs=[\n triton.Config({'BD': 32}, num_warps=1),\n triton.Config({'BD': 32}, num_warps=2),\n triton.Config({'BD': 32}, num_warps=4),\n triton.Config({'BD': 32}, num_warps=8),\n triton.Config({'BD': 64}, num_warps=1),\n triton.Config({'BD': 64}, num_warps=2),\n triton.Config({'BD': 64}, num_warps=4),\n triton.Config({'BD': 64}, num_warps=8),\n triton.Config({'BD': 128}, num_warps=1),\n triton.Config({'BD': 128}, num_warps=2),\n triton.Config({'BD': 128}, num_warps=4),\n triton.Config({'BD': 128}, num_warps=8),\n ],\n key=['D']\n)\n@triton.jit\ndef chunk_hgrn_bwd_kernel_h(\n g,\n gc,\n dx,\n do,\n T: tl.constexpr,\n D: tl.constexpr,\n BT: tl.constexpr,\n BD: tl.constexpr\n):\n i_d, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n o_d = i_d * BD + tl.arange(0, BD)\n mask = o_d < D\n BC = min(BT, T - i_t * BT)\n NT = tl.num_programs(1)\n\n p_g = g + (i_bh * T + i_t * BT + BC - 1) * D + o_d\n p_gc = gc + (i_bh * T + i_t * BT + BC - 1) * D + o_d\n p_dx = dx + (i_bh * T + i_t * BT + BC - 1) * D + o_d\n p_do = do + (i_bh * T + i_t * BT + BC - 1) * D + o_d\n\n if i_t == NT - 1:\n b_gc = tl.zeros([BD], dtype=tl.float32)\n else:\n b_gc = tl.load(g + (i_bh * T + i_t * BT + BT) * D + o_d, mask=mask, other=0).to(tl.float32)\n b_dh = tl.zeros([BD], dtype=tl.float32)\n for _ in range(BC - 1, -1, -1):\n tl.store(p_gc, b_gc.to(p_gc.dtype.element_ty), mask=mask)\n\n b_g = tl.load(p_g, mask=mask, other=0).to(tl.float32)\n b_do = tl.load(p_do, mask=mask, other=0).to(tl.float32)\n\n b_gc = b_gc + b_g\n b_dh = b_dh + b_do\n b_dx = b_dh\n b_dh = b_dh * tl.exp(b_g)\n\n tl.store(p_dx, b_dx.to(p_dx.dtype.element_ty), mask=mask)\n\n p_g -= D\n p_gc -= D\n p_dx -= D\n p_do -= D\n\n@triton.jit\ndef chunk_hgrn_bwd_kernel_o(\n g,\n gc,\n o,\n dx,\n dg,\n s_h,\n s_t,\n s_d,\n T: tl.constexpr,\n D: tl.constexpr,\n BT: tl.constexpr,\n BD: tl.constexpr\n):\n i_d, i_bh = tl.program_id(0), tl.program_id(1)\n o_d = i_d * BD + tl.arange(0, BD)\n mask = o_d < D\n\n for i_t in range(tl.cdiv(T, BT) - 1, -1, -1):\n p_g = tl.make_block_ptr(g + i_bh * s_h, (T, D), (s_t, s_d), (i_t * BT, i_d * BD), (BT, BD), (1, 0))\n p_gc = tl.make_block_ptr(gc + i_bh * s_h, (T, D), (s_t, s_d), (i_t * BT, i_d * BD), (BT, BD), (1, 0))\n p_o = tl.make_block_ptr(o + i_bh * s_h, (T, D), (s_t, s_d), (i_t * BT - 1, i_d * BD), (BT, BD), (1, 0))\n p_dx = tl.make_block_ptr(dx + i_bh * s_h, (T, D), (s_t, s_d), (i_t * BT, i_d * BD), (BT, BD), (1, 0))\n p_dg = tl.make_block_ptr(dg + i_bh * s_h, (T, D), (s_t, s_d), (i_t * BT, i_d * BD), (BT, BD), (1, 0))\n\n mask_t = mask & ((i_t + 1) * BT < T)\n b_ht = tl.load(dx + i_bh * T * D + (i_t + 1) * BT * D + o_d, mask=mask_t, other=0).to(tl.float32)\n b_g = tl.load(p_g, boundary_check=(0, 1)).to(tl.float32)\n b_gc = tl.load(p_gc, boundary_check=(0, 1)).to(tl.float32)\n b_o = tl.load(p_o, boundary_check=(0, 1)).to(tl.float32)\n b_dx = tl.load(p_dx, boundary_check=(0, 1)).to(tl.float32)\n b_dg = tl.load(p_dg, boundary_check=(0, 1)).to(tl.float32)\n b_dx = b_dx + tl.exp(b_gc) * b_ht[None, :]\n b_dg = b_o * b_dx * tl.exp(b_g)\n tl.store(p_dx, b_dx.to(p_dx.dtype.element_ty), boundary_check=(0, 1))\n tl.store(p_dg, b_dg.to(p_dg.dtype.element_ty), boundary_check=(0, 1))\n\nclass ChunkHGRNFunction(torch.autograd.Function):\n\n @staticmethod\n def forward(ctx, x, g, initial_state=None, output_final_state=False):\n B, H, T, D = x.shape\n BT, BD = 128, min(64, triton.next_power_of_2(D))\n num_warps = 8 if BD == 64 else 4\n\n gc = torch.empty_like(g, dtype=torch.float)\n o = torch.empty_like(x, dtype=torch.float)\n def grid(meta): return (triton.cdiv(D, meta['BD']), triton.cdiv(T, meta['BT']), B * H)\n chunk_hgrn_fwd_kernel_h[grid](\n x, g, gc, o, initial_state,\n T, D,\n BT=BT,\n USE_INITIAL_STATE=initial_state is not None\n )\n def grid(meta): return (triton.cdiv(D, meta['BD']), B * H)\n chunk_hgrn_fwd_kernel_o[grid](\n gc, o,\n o.stride(1), o.stride(2), o.stride(3),\n T, D,\n BT=BT, BD=BD,\n num_warps=num_warps\n )\n final_state = None\n if output_final_state:\n final_state = o[:, :, -1].clone()\n o = o.to(x.dtype)\n ctx.save_for_backward(g, o, initial_state)\n return o, final_state\n\n @staticmethod\n def backward(ctx, do, dht=None):\n g, o, initial_state = ctx.saved_tensors\n B, H, T, D = do.shape\n BT, BD = 128, min(64, triton.next_power_of_2(D))\n num_warps = 8 if BD == 64 else 4\n\n gc = torch.empty_like(g, dtype=torch.float)\n dx = torch.empty_like(o, dtype=torch.float)\n def grid(meta): return (triton.cdiv(D, meta['BD']), triton.cdiv(T, meta['BT']), B * H)\n chunk_hgrn_bwd_kernel_h[grid](\n g, gc, dx, do,\n T, D,\n BT=BT\n )\n\n dg = torch.empty_like(g)\n def grid(meta): return (triton.cdiv(D, meta['BD']), B * H)\n chunk_hgrn_bwd_kernel_o[grid](\n g, gc, o, dx, dg,\n o.stride(1), o.stride(2), o.stride(3),\n T, D,\n BT=BT, BD=BD,\n num_warps=num_warps\n )\n if initial_state is not None:\n dg[:, :, 0] = (initial_state * dx[:, :, 0] * g[:, :, 0].float().exp()).to(dg.dtype)\n\n return dx.to(o.dtype), dg, None, None\n\ndef chunk_hgrn(\n x: torch.Tensor,\n g: torch.Tensor,\n initial_state: torch.Tensor = None,\n output_final_state: bool = False\n) -> Tuple[torch.Tensor, torch.Tensor]:\n return ChunkHGRNFunction.apply(x, g, initial_state, output_final_state)\n", - "description_1": "Use triton language to implement HGRN forward and backward pass with parameters: (1) x, input tensor of shape (B, H, T, D); (2) g, gating tensor of the same shape as x; (3) initial_state, optional tensor of shape (B, H, D) for initial state; (4) output_final_state, boolean flag to return final state.", - "description_2": "Use triton language to perform HGRN's forward and backward passes using x, g, optional initial_state, and output_final_state flag.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\nfrom packaging import version\nfrom torch.cuda.amp import custom_bwd, custom_fwd\n\n@triton.jit\ndef fused_chunk_linear_attn_fwd_kernel(\n q, k, v, o, initial_state, final_state,\n s_qk_h, s_qk_t, s_qk_d, s_vo_h, s_vo_t, s_vo_d,\n B, H, T, scale,\n BT: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr,\n DK: tl.constexpr, DV: tl.constexpr,\n USE_INITIAL_STATE: tl.constexpr,\n STORE_FINAL_STATE: tl.constexpr,\n CHECK: tl.constexpr\n):\n # indices\n i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n\n o_i = tl.arange(0, BT)\n\n # [BT, BT]\n m_s = o_i[:, None] >= o_i[None, :]\n # [BK, BV]\n b_h = tl.zeros([BK, BV], dtype=tl.float32)\n\n # make block pointers\n p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (0, i_k * BK), (BT, BK), (1, 0))\n p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (DK, T), (s_qk_d, s_qk_t), (i_k * BK, 0), (BK, BT), (0, 1))\n p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (0, i_v * BV), (BT, BV), (1, 0))\n p_o = tl.make_block_ptr(o + (i_bh+i_k*B*H) * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (0, i_v * BV), (BT, BV), (1, 0))\n\n if USE_INITIAL_STATE:\n p_h = tl.make_block_ptr(initial_state + i_bh * DK * DV, (DK, DV), (DV, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))\n b_h = tl.load(p_h, boundary_check=(0, 1)).to(tl.float32)\n\n for i in range(0, tl.cdiv(T, BT)):\n # [BK, BT]\n b_k = tl.load(p_k, boundary_check=(0, 1))\n # [BT, BV]\n b_v = tl.load(p_v, boundary_check=(0, 1))\n # [BT, BK]\n b_q = tl.load(p_q, boundary_check=(0, 1))\n b_q = (b_q * scale).to(b_k.dtype)\n\n # [BT, BT]\n b_s = tl.dot(b_q, b_k, allow_tf32=False)\n b_s = tl.where(m_s, b_s, 0)\n # [BT, BV]\n b_o = tl.dot(b_s.to(b_q.dtype), b_v, allow_tf32=False)\n if CHECK and i == 0:\n b_o += tl.dot(b_q, b_h.to(b_q.dtype), allow_tf32=False)\n b_h = b_h + tl.dot(b_k, b_v, allow_tf32=False)\n else:\n b_o += tl.dot(b_q, b_h.to(b_q.dtype), allow_tf32=False)\n b_h = b_h + tl.dot(b_k, b_v, allow_tf32=False)\n tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))\n p_q = tl.advance(p_q, (BT, 0))\n p_k = tl.advance(p_k, (0, BT))\n p_v = tl.advance(p_v, (BT, 0))\n p_o = tl.advance(p_o, (BT, 0))\n\n if STORE_FINAL_STATE:\n p_final = tl.make_block_ptr(final_state + i_bh * DK * DV, (DK, DV), (DV, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))\n tl.store(p_final, b_h.to(p_final.dtype.element_ty), boundary_check=(0, 1))\n\n\n@triton.jit\ndef fused_chunk_linear_attn_bwd_kernel(\n q, k, v, do, dq, dk, dv, initial_state,\n s_qk_h, s_qk_t, s_qk_d, s_vo_h, s_vo_t, s_vo_d,\n B, H, T, scale,\n BT: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr,\n DK: tl.constexpr, DV: tl.constexpr,\n USE_INITIAL_STATE: tl.constexpr,\n CHECK: tl.constexpr\n):\n i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n o_i = tl.arange(0, BT)\n\n m_s = o_i[:, None] >= o_i[None, :]\n # [BV, BK]\n b_h = tl.zeros([BV, BK], dtype=tl.float32)\n if USE_INITIAL_STATE:\n p_h = tl.make_block_ptr(initial_state + i_bh * DK * DV, (DV, DK), (1, DV), (i_v * BV, i_k * BK), (BV, BK), (0, 1))\n b_h = tl.load(p_h, boundary_check=(0, 1)).to(tl.float32)\n\n for i in range(0, tl.cdiv(T, BT)):\n p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (i * BT, i_k * BK), (BT, BK), (1, 0))\n p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (DV, T), (s_vo_d, s_vo_t), (i_v * BV, i * BT), (BV, BT), (0, 1))\n p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (i * BT, i_v * BV), (BT, BV), (1, 0))\n p_dq = tl.make_block_ptr(dq + (i_bh + i_v*B*H) * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (i*BT, i_k*BK), (BT, BK), (1, 0))\n\n # [BT, DK]\n b_k = tl.load(p_k, boundary_check=(0, 1))\n # [DV, BT]\n b_v = tl.load(p_v, boundary_check=(0, 1))\n # [BT, DV]\n b_do = tl.load(p_do, boundary_check=(0, 1))\n\n # [BT, BT]\n b_ds = tl.dot(b_do, b_v, allow_tf32=False)\n b_ds = tl.where(m_s, b_ds, 0)\n # [BT, DK]\n b_dq = tl.dot(b_ds.to(b_k.dtype), b_k, allow_tf32=False)\n # [DV, DK]\n if CHECK and i == 0:\n b_dq += tl.dot(b_do, b_h.to(b_do.dtype), allow_tf32=False)\n b_h = b_h + tl.dot(b_v, b_k, allow_tf32=False)\n else:\n b_dq += tl.dot(b_do, b_h.to(b_do.dtype), allow_tf32=False)\n b_h = b_h + tl.dot(b_v, b_k, allow_tf32=False)\n b_dq *= scale\n tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))\n\n # sync threads\n b_h = None\n tl.debug_barrier()\n # [BK, BV]\n b_dh = tl.zeros([BK, BV], dtype=tl.float32)\n m_s = o_i[:, None] <= o_i[None, :]\n for i in range(1, tl.cdiv(T, BT) + 1):\n p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (DK, T), (s_qk_d, s_qk_t), (i_k * BK, T - i * BT), (BK, BT), (0, 1))\n p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (T - i * BT, i_k * BK), (BT, BK), (1, 0))\n p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (T - i * BT, i_v * BV), (BT, BV), (1, 0))\n p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (T - i * BT, i_v * BV), (BT, BV), (1, 0))\n p_dk = tl.make_block_ptr(dk + (i_bh+i_v*B*H) * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (T - i*BT, i_k*BK), (BT, BK), (1, 0))\n p_dv = tl.make_block_ptr(dv + (i_bh+i_k*B*H) * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (T - i*BT, i_v*BV), (BT, BV), (1, 0))\n # [DK, BT]\n b_q = tl.load(p_q, boundary_check=(0, 1))\n # [BT, DK]\n b_k = tl.load(p_k, boundary_check=(0, 1))\n # [BT, DV]\n b_v = tl.load(p_v, boundary_check=(0, 1))\n b_do = tl.load(p_do, boundary_check=(0, 1))\n\n # [BT, BT]\n b_ds = tl.dot(b_v, tl.trans(b_do), allow_tf32=False)\n b_ds = tl.where(m_s, b_ds, 0).to(b_q.dtype)\n # [BT, BT]\n b_s = tl.dot(b_k, b_q, allow_tf32=False) * scale\n b_s = tl.where(m_s, b_s, 0).to(b_q.dtype)\n # [BT, DK]\n b_dk = tl.dot(b_ds, tl.trans(b_q), allow_tf32=False)\n # [BT, DV]\n b_dv = tl.dot(b_s, b_do, allow_tf32=False)\n if CHECK and i == 1:\n b_dk += tl.dot(b_v, tl.trans(b_dh).to(b_v.dtype), allow_tf32=False)\n b_dv += tl.dot(b_k, b_dh.to(b_k.dtype), allow_tf32=False)\n b_dh += tl.dot(b_q, b_do, allow_tf32=False)\n else:\n b_dk += tl.dot(b_v, tl.trans(b_dh).to(b_v.dtype), allow_tf32=False)\n b_dv += tl.dot(b_k, b_dh.to(b_k.dtype), allow_tf32=False)\n b_dh += tl.dot(b_q, b_do, allow_tf32=False)\n\n tl.store(p_dk, (b_dk * scale).to(p_dk.dtype.element_ty), boundary_check=(0, 1))\n tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))\n\n\nclass FusedChunkLinearAttentionFunction(torch.autograd.Function):\n @staticmethod\n @custom_fwd\n def forward(ctx, q, k, v, scale, initial_state, output_final_state):\n batch_size, n_heads, seq_len, d_head_qk = q.shape\n d_head_v = v.shape[-1]\n ctx.scale = scale\n BT = 64\n BK, BV = min(triton.next_power_of_2(d_head_qk), 64), min(triton.next_power_of_2(d_head_v), 64)\n NK, NV = triton.cdiv(d_head_qk, BK), triton.cdiv(d_head_v, BV)\n num_stages = 1\n num_warps = 4\n o = q.new_empty(NK, batch_size, n_heads, seq_len, d_head_v)\n if output_final_state:\n final_state = q.new_empty(batch_size, n_heads, d_head_qk, d_head_v, dtype=torch.float32, requires_grad=False)\n else:\n final_state = None\n\n CHECK = True\n if version.parse(triton.__version__) < version.parse('2.2.0'):\n import warnings\n warnings.warn(\n \"Triton<2.2.0 detected for running this kernel, \"\n \"which is known to have some weird compiler issues (refer to https://github.com/openai/triton/issues/2852) \"\n \"that lead to significant precision loss. \"\n \"We've add some initial condition checks to resolve this, sadly at the sacrifice of the speed. \"\n \"For optimal performance, it is recommended to install Triton>=2.2.0 (if possible).\"\n )\n CHECK = True\n\n grid = (NV, NK, batch_size * n_heads)\n fused_chunk_linear_attn_fwd_kernel[grid](\n q, k, v, o, initial_state, final_state,\n q.stride(1), q.stride(2), q.stride(3),\n v.stride(1), v.stride(2), v.stride(3),\n batch_size, n_heads, seq_len, scale,\n BT=BT, DK=d_head_qk, DV=d_head_v, BK=BK, BV=BV,\n USE_INITIAL_STATE=initial_state is not None,\n STORE_FINAL_STATE=output_final_state,\n CHECK=CHECK,\n num_warps=num_warps,\n num_stages=num_stages\n )\n\n o = o.sum(0)\n ctx.save_for_backward(q, k, v, initial_state)\n ctx.CHECK = CHECK\n return o.to(q.dtype), final_state\n\n @staticmethod\n @custom_bwd\n def backward(ctx, do, d_final_state=None):\n q, k, v, initial_state = ctx.saved_tensors\n batch_size, n_heads, seq_len, d_head_qk = q.shape\n d_head_v = v.shape[-1]\n scale = ctx.scale\n\n BT = 64\n BK, BV = min(triton.next_power_of_2(d_head_qk), 64), min(triton.next_power_of_2(d_head_v), 64)\n NK, NV = triton.cdiv(d_head_qk, BK), triton.cdiv(d_head_v, BV)\n num_stages = 1\n num_warps = 4\n\n dq = q.new_empty(NV, batch_size, n_heads, seq_len, d_head_qk)\n dk = q.new_empty(NV, batch_size, n_heads, seq_len, d_head_qk)\n dv = q.new_empty(NK, batch_size, n_heads, seq_len, d_head_v)\n grid = (NV, NK, batch_size * n_heads)\n\n fused_chunk_linear_attn_bwd_kernel[grid](\n q, k, v, do, dq, dk, dv, initial_state,\n q.stride(1), q.stride(2), q.stride(3),\n v.stride(1), v.stride(2), v.stride(3),\n batch_size, n_heads, seq_len, scale,\n BT=BT, DK=d_head_qk, DV=d_head_v, BK=BK, BV=BV,\n USE_INITIAL_STATE=initial_state is not None,\n CHECK=ctx.CHECK,\n num_warps=num_warps,\n num_stages=num_stages\n )\n dq = dq.sum(0)\n dk = dk.sum(0)\n dv = dv.sum(0)\n return dq.to(q.dtype), dk.to(k.dtype), dv.to(v.dtype), None, None, None\n\n\ndef fused_chunk_linear_attn(\n q: torch.Tensor,\n k: torch.Tensor,\n v: torch.Tensor,\n scale: float = -1,\n initial_state: torch.Tensor = None,\n output_final_state: bool = False,\n normalize: bool = True\n) -> Tuple[torch.Tensor, torch.Tensor]:\n if initial_state is not None:\n initial_state = initial_state.detach()\n if scale == -1:\n scale = q.shape[-1] ** -0.5\n o, final_state = FusedChunkLinearAttentionFunction.apply(q, k, v, scale, initial_state, output_final_state)\n if normalize:\n o = normalize_output(q * scale, k, o)\n return o, final_state\n", - "description_1": "Use triton language to implement a fused chunk linear attention mechanism, with both forward and backward kernels. The forward function takes in 12 arguments (queries, keys, values, output, initial state, final state, strides, batch size, heads, sequence length, scale, and various constant expressions). The backward function takes in 15 arguments (queries, keys, values, output gradient, gradients for queries, keys, values, initial state, strides, batch size, heads, sequence length, scale, and various constant expressions).", - "description_2": "Use triton language to implement a fused forward and backward chunk linear attention mechanism with scale and state management.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\nfrom torch.cuda.amp import custom_bwd, custom_fwd\nfrom fla.utils import contiguous\n\n@triton.jit\ndef parallel_rebased_fwd_kernel(\n q, k, v, o, z, s_qk_h, s_qk_t, s_qk_d, s_vo_h, s_vo_t, s_vo_d, B, H, T, scale, \n BTL: tl.constexpr, BTS: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr, \n DK: tl.constexpr, DV: tl.constexpr\n):\n i_kv, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n NV = tl.cdiv(DV, BV)\n i_k = i_kv // (NV)\n i_v = i_kv % (NV)\n\n p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (i_c * BTL, i_k * BK), (BTL, BK), (1, 0))\n p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (DK, T), (s_qk_d, s_qk_t), (i_k * BK, 0), (BK, BTS), (0, 1))\n p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (0, i_v * BV), (BTS, BV), (1, 0))\n\n b_q = tl.load(p_q, boundary_check=(0, 1))\n b_q = (b_q * scale).to(b_q.dtype)\n b_o = tl.zeros([BTL, BV], dtype=tl.float32)\n b_z = tl.zeros([BTL], dtype=tl.float32)\n\n for _ in range(0, i_c * BTL, BTS):\n b_k = tl.load(p_k, boundary_check=(0, 1))\n b_v = tl.load(p_v, boundary_check=(0, 1))\n b_s = tl.dot(b_q, (b_k), allow_tf32=False)\n b_s = b_s * b_s\n b_z += tl.sum(b_s, axis=1)\n b_o = b_o + tl.dot(b_s.to(b_v.dtype), b_v, allow_tf32=False)\n p_k = tl.advance(p_k, (0, BTS))\n p_v = tl.advance(p_v, (BTS, 0))\n\n tl.debug_barrier()\n o_q = tl.arange(0, BTL)\n o_k = tl.arange(0, BTS)\n p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (DK, T), (s_qk_d, s_qk_t), (i_k * BK, i_c * BTL), (BK, BTS), (0, 1))\n p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (i_c * BTL, i_v * BV), (BTS, BV), (1, 0))\n \n for _ in range(i_c * BTL, (i_c + 1) * BTL, BTS):\n b_k = tl.load(p_k, boundary_check=(0, 1))\n b_v = tl.load(p_v, boundary_check=(0, 1))\n m_s = o_q[:, None] >= o_k[None, :]\n b_s = tl.dot(b_q, b_k, allow_tf32=False)\n b_s = b_s * b_s\n b_s = tl.where(m_s, b_s, 0)\n b_z += tl.sum(b_s, axis=1)\n b_o += tl.dot(b_s.to(b_q.dtype), b_v, allow_tf32=False)\n p_k = tl.advance(p_k, (0, BTS))\n p_v = tl.advance(p_v, (BTS, 0))\n o_k += BTS\n\n p_o = tl.make_block_ptr(o + (i_bh + B * H * i_k) * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (i_c*BTL, i_v*BV), (BTL, BV), (1, 0))\n p_z = z + (i_bh + B * H * i_k) * T + i_c * BTL + tl.arange(0, BTL)\n tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))\n tl.store(p_z, b_z.to(p_z.dtype.element_ty), mask=((i_c * BTL + tl.arange(0, BTL)) < T))\n\n\n@triton.jit\ndef _parallel_rebased_bwd_dq(\n i_bh, i_c, i_k, i_v, i_h, q, k, v, do, dz, dq, s_qk_h, s_qk_t, s_qk_d, s_vo_h,\n s_vo_t, s_vo_d, B, H, T, scale, BTL: tl.constexpr, BTS: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr,\n DK: tl.constexpr, DV: tl.constexpr,\n):\n p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (i_c * BTL, i_v * BV), (BTL, BV), (1, 0))\n p_q = tl.make_block_ptr(q + (i_bh) * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (i_c*BTL, i_k*BK), (BTL, BK), (1, 0))\n b_q = tl.load(p_q, boundary_check=(0, 1))\n b_do = tl.load(p_do, boundary_check=(0, 1)).to(b_q.dtype)\n b_q = (b_q * scale).to(b_q.dtype)\n b_dq = tl.zeros([BTL, BK], dtype=tl.float32)\n p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (0, i_k * BK), (BTS, BK), (1, 0))\n p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (DV, T), (s_vo_d, s_vo_t), (i_v * BV, 0), (BV, BTS), (0, 1))\n p_dz = dz + i_bh * T + i_c * BTL + tl.arange(0, BTL)\n b_dz = tl.load(p_dz, mask=(i_c * BTL + tl.arange(0, BTL)) < T)\n\n for _ in range(0, i_c * BTL, BTS):\n b_k = tl.load(p_k, boundary_check=(0, 1))\n b_v = tl.load(p_v, boundary_check=(0, 1))\n b_ds = tl.dot(b_do, b_v, allow_tf32=False)\n if i_v == 0:\n b_ds += b_dz[:, None]\n b_s = tl.dot(b_q, tl.trans(b_k), allow_tf32=False)\n b_dq += tl.dot((2 * b_ds * b_s).to(b_v.dtype), b_k, allow_tf32=False)\n p_k = tl.advance(p_k, (BTS, 0))\n p_v = tl.advance(p_v, (0, BTS))\n\n b_dq *= scale\n o_q = tl.arange(0, BTL)\n o_k = tl.arange(0, BTS)\n p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, DK), (s_qk_d, s_qk_t), (i_c * BTL, i_k * BK), (BTS, BK), (1, 0))\n p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (DV, T), (s_vo_d, s_vo_t), (i_v * BV, i_c * BTL), (BV, BTS), (0, 1))\n \n for _ in range(i_c * BTL, (i_c + 1) * BTL, BTS):\n b_k = tl.load(p_k, boundary_check=(0, 1))\n b_v = tl.load(p_v, boundary_check=(0, 1))\n m_s = o_q[:, None] >= o_k[None, :]\n b_ds = tl.dot(b_do, b_v, allow_tf32=False)\n if i_v == 0:\n b_ds += b_dz[:, None]\n b_ds = tl.where(m_s, b_ds, 0) * scale\n b_s = tl.dot(b_q, tl.trans(b_k), allow_tf32=False)\n b_s = tl.where(m_s, b_s, 0)\n b_dq += tl.dot((2 * b_ds * b_s).to(b_k.dtype), b_k, allow_tf32=False)\n p_k = tl.advance(p_k, (BTS, 0))\n p_v = tl.advance(p_v, (0, BTS))\n o_k += BTS\n p_dq = tl.make_block_ptr(dq + (i_bh + B * H * i_v) * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (i_c*BTL, i_k*BK), (BTL, BK), (1, 0))\n tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))\n return\n\n\n@triton.jit\ndef _parallel_rebased_bwd_dkv(\n i_bh, i_c, i_k, i_v, i_h, q, k, v, do, dz, dk, dv, s_qk_h, s_qk_t, s_qk_d, s_vo_h,\n s_vo_t, s_vo_d, B, H, T, scale, BTL: tl.constexpr, BTS: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr,\n DK: tl.constexpr, DV: tl.constexpr,\n):\n p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (i_c * BTL, i_k * BK), (BTL, BK), (1, 0))\n p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (i_c * BTL, i_v * BV), (BTL, BV), (1, 0))\n b_k, b_v = tl.load(p_k, boundary_check=(0, 1)), tl.load(p_v, boundary_check=(0, 1))\n b_dk, b_dv = tl.zeros([BTL, BK], dtype=tl.float32), tl.zeros([BTL, BV], dtype=tl.float32)\n\n for i in range((tl.cdiv(T, BTS) * BTS)-BTS, (i_c + 1) * BTL - BTS, -BTS):\n p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (DK, T), (s_qk_d, s_qk_t), (i_k * BK, i), (BK, BTS), (0, 1))\n p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (DV, T), (s_vo_d, s_vo_t), (i_v * BV, i), (BV, BTS), (0, 1))\n p_dz = dz + i_bh * T + i + tl.arange(0, BTS)\n b_q = tl.load(p_q, boundary_check=(0, 1))\n b_do = tl.load(p_do, boundary_check=(0, 1)).to(b_q.dtype)\n b_dz = tl.load(p_dz, mask=(i + tl.arange(0, BTS)) < T)\n b_s = tl.dot(b_k.to(b_q.dtype), b_q, allow_tf32=False) * scale\n b_s2 = b_s * b_s\n b_dv += tl.dot(b_s2.to(b_q.dtype), tl.trans(b_do), allow_tf32=False)\n b_ds = tl.dot(b_v, b_do, allow_tf32=False) * scale\n if i_v == 0:\n b_ds += b_dz[None, :] * scale\n b_dk += tl.dot((2 * b_ds * b_s).to(b_q.dtype), tl.trans(b_q), allow_tf32=False)\n\n tl.debug_barrier()\n o_q, o_k = tl.arange(0, BTS), tl.arange(0, BTL)\n for i in range(i_c*BTL, (i_c+1)*BTL, BTS):\n p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (DK, T), (s_qk_d, s_qk_t), (i_k * BK, i), (BK, BTS), (0, 1))\n p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (DV, T), (s_vo_d, s_vo_t), (i_v * BV, i), (BV, BTS), (0, 1))\n p_dz = dz + i_bh * T + i + tl.arange(0, BTS)\n b_q = tl.load(p_q, boundary_check=(0, 1))\n b_do = tl.load(p_do, boundary_check=(0, 1)).to(b_q.dtype)\n b_dz = tl.load(p_dz, mask=(i + tl.arange(0, BTS)) < T)\n m_s = o_k[:, None] <= o_q[None, :]\n b_s = tl.dot(b_k, b_q, allow_tf32=False) * scale\n b_s2 = b_s * b_s\n b_s = tl.where(m_s, b_s, 0)\n b_s2 = tl.where(m_s, b_s2, 0)\n\n b_ds = tl.dot(b_v, b_do, allow_tf32=False)\n if i_v == 0:\n b_ds += b_dz[None, :]\n b_ds = tl.where(m_s, b_ds, 0) * scale\n b_dv += tl.dot(b_s2.to(b_q.dtype), tl.trans(b_do), allow_tf32=False)\n b_dk += tl.dot((2 * b_ds * b_s).to(b_q.dtype), tl.trans(b_q), allow_tf32=False)\n o_q += BTS\n\n p_dk = tl.make_block_ptr(dk + (i_bh + B * H * i_v) * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (i_c*BTL, i_k*BK), (BTL, BK), (1, 0))\n p_dv = tl.make_block_ptr(dv + (i_bh + B * H * i_k) * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (i_c*BTL, i_v*BV), (BTL, BV), (1, 0))\n tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))\n tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))\n return\n\n\n@triton.jit\ndef parallel_rebased_bwd_kernel(\n q, k, v, do, dz, dq, dk, dv, s_qk_h, s_qk_t, s_qk_d, s_vo_h,\n s_vo_t, s_vo_d, B, H, T, scale, BTL: tl.constexpr, BTS: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr,\n DK: tl.constexpr, DV: tl.constexpr,\n):\n i_kv, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n NV = tl.cdiv(DV, BV)\n i_k = i_kv // (NV)\n i_v = i_kv % (NV)\n i_h = i_bh % H\n _parallel_rebased_bwd_dq(\n i_bh, i_c, i_k, i_v, i_h,\n q, k, v, do, dz, dq, s_qk_h, s_qk_t, s_qk_d, s_vo_h,\n s_vo_t, s_vo_d, B, H, T, scale, BTL=BTL, BTS=BTS, BK=BK, BV=BV, DK=DK, DV=DV\n )\n tl.debug_barrier()\n _parallel_rebased_bwd_dkv(\n i_bh, i_c, i_k, i_v, i_h,\n q, k, v, do, dz, dk, dv, s_qk_h, s_qk_t, s_qk_d, s_vo_h,\n s_vo_t, s_vo_d, B, H, T, scale, BTL, BTS, BK, BV, DK, DV\n )\n\n\nclass ParallelBasedFunction(torch.autograd.Function):\n @staticmethod\n @contiguous\n @custom_fwd\n def forward(ctx, q, k, v, scale):\n BTL, BTS = 128, 32\n assert BTL % BTS == 0\n BK = min(128, triton.next_power_of_2(k.shape[-1]))\n BV = min(128, triton.next_power_of_2(v.shape[-1]))\n BK, BV = max(BK, 16), max(BV, 16)\n batch_size, n_heads, seq_len, d_head_qk = q.shape\n d_head_v = v.shape[-1]\n num_stages = 2\n num_warps = 4\n NK = triton.cdiv(d_head_qk, BK)\n NV = triton.cdiv(d_head_v, BV)\n grid = (NK * NV, triton.cdiv(seq_len, BTL), batch_size * n_heads)\n\n assert NK == 1, \"will encounter some synchronization issue if not.\"\n\n o = torch.empty(NK, batch_size, n_heads, seq_len, d_head_v, device=q.device)\n z = torch.empty(NK, batch_size, n_heads, seq_len, device=q.device)\n parallel_rebased_fwd_kernel[grid](\n q, k, v, o, z, q.stride(1), q.stride(2), q.stride(3),\n v.stride(1), v.stride(2), v.stride(3), batch_size, n_heads, seq_len, scale,\n BTL=BTL, BTS=BTS, BK=BK, BV=BV, DK=d_head_qk, DV=d_head_v,\n num_warps=num_warps, num_stages=num_stages\n )\n ctx.save_for_backward(q, k, v)\n ctx.scale = scale\n return o.sum(0).to(q.dtype), z.sum(0).to(q.dtype)\n\n @staticmethod\n @custom_bwd\n @contiguous\n def backward(ctx, do, dz):\n q, k, v = ctx.saved_tensors\n scale = ctx.scale\n BTL, BTS = 64, 32\n assert BTL % BTS == 0\n BK = min(128, triton.next_power_of_2(k.shape[-1]))\n BV = min(128, triton.next_power_of_2(v.shape[-1]))\n BK, BV = max(BK, 16), max(BV, 16)\n batch_size, n_heads, seq_len, d_head_qk = q.shape\n d_head_v = v.shape[-1]\n num_stages = 2\n num_warps = 4\n NK = triton.cdiv(d_head_qk, BK)\n NV = triton.cdiv(d_head_v, BV)\n grid = (NK * NV, triton.cdiv(seq_len, BTL), batch_size * n_heads)\n\n assert NK == 1, \"will encounter some synchronization issue if not\"\n\n dq = torch.empty(NV, batch_size, n_heads, seq_len, d_head_qk, dtype=q.dtype, device=q.device)\n dk = torch.empty(NV, batch_size, n_heads, seq_len, d_head_qk, dtype=q.dtype, device=q.device)\n dv = torch.empty(NK, batch_size, n_heads, seq_len, d_head_v, dtype=q.dtype, device=q.device)\n\n parallel_rebased_bwd_kernel[grid](\n q, k, v, do, dz, dq, dk, dv, q.stride(1), q.stride(2), q.stride(3),\n v.stride(1), v.stride(2), v.stride(3), batch_size, n_heads, seq_len, scale,\n BTL=BTL, BTS=BTS, BK=BK, BV=BV, DK=d_head_qk, DV=d_head_v,\n num_warps=num_warps, num_stages=num_stages\n )\n\n return dq.sum(0).to(q.dtype), dk.sum(0).to(k.dtype), dv.sum(0).to(v.dtype), None\n\n\ntriton_parallel_based = ParallelBasedFunction.apply\n\n\ndef parallel_rebased(q, k, v, eps=1e-5, use_scale=True, use_normalize=True, return_both=False):\n assert q.shape[-1] <= 128, \"only support feature dim up to 128\"\n if use_scale:\n scale = q.shape[-1] ** -0.5\n else:\n scale = 1\n o, z = triton_parallel_based(q, k, v, scale)\n if return_both:\n return o, z\n if use_normalize:\n o = o / (z[..., None] + eps)\n return o.to(q.dtype)\n", - "description_1": "Use triton language to implement a parallel forward and backward pass for a rebased attention mechanism. The forward kernel ('parallel_rebased_fwd_kernel') computes the attention output and a normalizer for a batch of query, key, and value tensors ('q', 'k', 'v'), using block sizes defined by 'BTL', 'BTS', 'BK', and 'BV'. The backward kernel ('parallel_rebased_bwd_kernel') computes gradients for the input tensors ('q', 'k', 'v') given the gradients of the output ('do') and normalizer ('dz'). The parallel-based function class integrates these kernels into a PyTorch autograd.Function with methods for forward and backward passes.", - "description_2": "Use triton language to optimize the computation of a rebased attention mechanism by executing forward and backward passes in parallel. Leverage block-wise computations and constraints like block sizes and scaling factors to ensure efficient tensor operations and gradient computations in the context of deep learning.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef chunk_retention_fwd_kernel_h(\n k, v, h, initial_state, final_state,\n s_qk_h, s_qk_t, s_qk_d, s_vo_h, s_vo_t, s_vo_d, s_h_h, s_h_t,\n H: tl.constexpr, T: tl.constexpr, K: tl.constexpr, V: tl.constexpr,\n BT: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr, NT: tl.constexpr,\n USE_INITIAL_STATE: tl.constexpr, STORE_FINAL_STATE: tl.constexpr\n):\n # Kernel operations\n\n@triton.jit\ndef chunk_retention_fwd_kernel_o(\n q, k, v, h, o, s_qk_h, s_qk_t, s_qk_d, s_vo_h, s_vo_t, s_vo_d, s_h_h, s_h_t, scale,\n H: tl.constexpr, T: tl.constexpr, K: tl.constexpr, V: tl.constexpr,\n BT: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr\n):\n # Kernel operations\n\n@triton.jit\ndef chunk_retention_bwd_kernel_dh(\n q, do, dh, s_qk_h, s_qk_t, s_qk_d, s_vo_h, s_vo_t, s_vo_d, s_h_h, s_h_t, scale,\n H: tl.constexpr, T: tl.constexpr, K: tl.constexpr, V: tl.constexpr,\n BT: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr, NT: tl.constexpr\n):\n # Kernel operations\n\n@triton.jit\ndef chunk_retention_bwd_kernel_dqkv(\n q, k, v, h, do, dh, dq, dk, dv,\n s_qk_h, s_qk_t, s_qk_d, s_vo_h, s_vo_t, s_vo_d, s_h_h, s_h_t, scale,\n H: tl.constexpr, T: tl.constexpr, K: tl.constexpr, V: tl.constexpr,\n BT: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr, NT: tl.constexpr\n):\n # Kernel operations\n\nclass ChunkRetentionFunction(torch.autograd.Function):\n\n @staticmethod\n @triton.jit\n def forward(ctx, q, k, v, initial_state, output_final_state):\n # Forward pass function implementation\n pass\n\n @staticmethod\n @triton.jit\n def backward(ctx, do, d_ht=None):\n # Backward pass function implementation\n pass\n\ndef chunk_retention(q, k, v, initial_state=None, output_final_state=False):\n o, final_state = ChunkRetentionFunction.apply(q, k, v, initial_state, output_final_state)\n return o, final_state\n", - "description_1": "Use triton language to implement a chunk retention mechanism that involves forward and backward passes through multiple Triton kernels. The forward kernels perform computations using queries (q), keys (k), values (v), and states, while the backward kernels handle gradient computations. The number of dimensions (H, T, K, V) and chunk sizes (BT, BK, BV) are expressed as constexpr constants.", - "description_2": "Use triton language to create forward and backward kernels for a chunk retention function, employing multiple constexpr constants for dimensional and chunk size parameters.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\nfrom torch.cuda.amp import custom_bwd, custom_fwd\n\n@triton.jit\ndef fused_chunk_retention_fwd_kernel(\n q, k, v, o, initial_state, final_state,\n s_qk_h, s_qk_t, s_qk_d,\n s_vo_h, s_vo_t, s_vo_d,\n B, H, T, scale,\n BT: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr,\n DK: tl.constexpr, DV: tl.constexpr,\n USE_INITIAL_STATE: tl.constexpr,\n STORE_FINAL_STATE: tl.constexpr,\n CHECK: tl.constexpr\n):\n i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n i_h = i_bh % H\n\n o_i = tl.arange(0, BT)\n b_b = tl.math.log2(1 - tl.math.pow(2, -5 - i_h * 1.0))\n\n d_b, d_o, d_h = tl.math.exp2(BT * b_b), tl.math.exp2((o_i + 1) * b_b), tl.math.exp2((BT - o_i - 1) * b_b)\n\n m_s = o_i[:, None] >= o_i[None, :]\n d_s = tl.where(m_s, tl.math.exp2((o_i[:, None] - o_i[None, :]) * b_b), 0)\n b_h = tl.zeros([BK, BV], dtype=tl.float32)\n\n p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (0, i_k * BK), (BT, BK), (1, 0))\n p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (DK, T), (s_qk_d, s_qk_t), (i_k * BK, 0), (BK, BT), (0, 1))\n p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (0, i_v * BV), (BT, BV), (1, 0))\n p_o = tl.make_block_ptr(o + (i_bh+i_k*B*H) * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (0, i_v * BV), (BT, BV), (1, 0))\n\n if USE_INITIAL_STATE:\n p_h = tl.make_block_ptr(initial_state + i_bh * DK * DV, (DK, DV), (DV, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))\n b_h = tl.load(p_h, boundary_check=(0, 1)).to(tl.float32)\n\n NT = tl.cdiv(T, BT)\n for i in range(0, NT):\n b_k = tl.load(p_k, boundary_check=(0, 1))\n b_v = tl.load(p_v, boundary_check=(0, 1))\n b_q = tl.load(p_q, boundary_check=(0, 1))\n b_q = (b_q * scale).to(b_k.dtype)\n\n b_s = tl.dot(b_q, b_k, allow_tf32=False) * d_s\n b_o = tl.dot(b_s.to(b_q.dtype), b_v, allow_tf32=False)\n if CHECK and i == 0:\n b_o += tl.dot(b_q, b_h.to(b_q.dtype), allow_tf32=False) * d_o[:, None]\n b_h = d_b * b_h + tl.dot(b_k, (b_v * d_h[:, None]).to(b_k.dtype), allow_tf32=False)\n else:\n b_o += tl.dot(b_q, b_h.to(b_q.dtype), allow_tf32=False) * d_o[:, None]\n if i == NT - 1 and (T % BT) != 0:\n d_b = tl.math.exp2((T % BT) * b_b)\n d_h = tl.math.exp2(((T % BT) - o_i - 1) * b_b)\n b_h = d_b * b_h + tl.dot(b_k, (b_v * d_h[:, None]).to(b_k.dtype), allow_tf32=False)\n tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))\n\n p_q = tl.advance(p_q, (BT, 0))\n p_k = tl.advance(p_k, (0, BT))\n p_v = tl.advance(p_v, (BT, 0))\n p_o = tl.advance(p_o, (BT, 0))\n\n if STORE_FINAL_STATE:\n p_final = tl.make_block_ptr(final_state + i_bh * DK * DV, (DK, DV), (DV, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))\n tl.store(p_final, b_h.to(p_final.dtype.element_ty), boundary_check=(0, 1))\n\n\n@triton.jit\ndef fused_chunk_retention_bwd_kernel(\n q, k, v, do, dq, dk, dv, initial_state,\n s_qk_h, s_qk_t, s_qk_d,\n s_vo_h, s_vo_t, s_vo_d,\n B, H, T, scale,\n BT: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr,\n DK: tl.constexpr, DV: tl.constexpr,\n USE_INITIAL_STATE: tl.constexpr,\n CHECK: tl.constexpr\n):\n i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n i_h = i_bh % H\n\n o_i = tl.arange(0, BT)\n b_b = tl.math.log2(1 - tl.math.pow(2, -5 - i_h * 1.0))\n d_q, d_k = tl.math.exp2((o_i+1) * b_b) * scale, tl.math.exp2((BT - o_i - 1) * b_b)\n d_b = tl.math.exp2(BT * b_b)\n\n m_s = o_i[:, None] >= o_i[None, :]\n d_s = tl.where(m_s, tl.math.exp2((o_i[:, None] - o_i[None, :]) * b_b), 0) * scale\n b_h = tl.zeros([BV, BK], dtype=tl.float32)\n if USE_INITIAL_STATE:\n p_h = tl.make_block_ptr(initial_state + i_bh * DK * DV, (DV, DK), (1, DV), (i_v * BV, i_k * BK), (BV, BK), (0, 1))\n b_h = tl.load(p_h, boundary_check=(0, 1)).to(tl.float32)\n\n for i in range(0, tl.cdiv(T, BT)):\n p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (i * BT, i_k * BK), (BT, BK), (1, 0))\n p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (DV, T), (s_vo_d, s_vo_t), (i_v * BV, i * BT), (BV, BT), (0, 1))\n p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (i * BT, i_v * BV), (BT, BV), (1, 0))\n p_dq = tl.make_block_ptr(dq + (i_bh + i_v*B*H) * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (i*BT, i_k*BK), (BT, BK), (1, 0))\n\n b_k = tl.load(p_k, boundary_check=(0, 1))\n b_v = tl.load(p_v, boundary_check=(0, 1))\n b_do = tl.load(p_do, boundary_check=(0, 1))\n b_dd = (b_do * d_q[:, None]).to(b_do.dtype)\n\n b_ds = tl.dot(b_do, b_v, allow_tf32=False)\n b_ds = (b_ds * d_s).to(b_k.dtype)\n b_dq = tl.dot(b_ds, b_k, allow_tf32=False)\n if CHECK and i == 0:\n b_dq += tl.dot(b_dd, b_h.to(b_k.dtype), allow_tf32=False)\n b_h = d_b * b_h + tl.dot((b_v * d_k[None, :]).to(b_k.dtype), b_k, allow_tf32=False)\n else:\n b_dq += tl.dot(b_dd, b_h.to(b_k.dtype), allow_tf32=False)\n b_h = d_b * b_h + tl.dot((b_v * d_k[None, :]).to(b_k.dtype), b_k, allow_tf32=False)\n\n tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))\n\n b_h = None\n tl.debug_barrier()\n d_s = tl.trans(d_s)\n b_dh = tl.zeros([BK, BV], dtype=tl.float32)\n for i in range(1, tl.cdiv(T, BT) + 1):\n p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (DK, T), (s_qk_d, s_qk_t), (i_k * BK, T - i * BT), (BK, BT), (0, 1))\n p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (T - i * BT, i_k * BK), (BT, BK), (1, 0))\n p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (T - i * BT, i_v * BV), (BT, BV), (1, 0))\n p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (T - i * BT, i_v * BV), (BT, BV), (1, 0))\n p_dk = tl.make_block_ptr(dk + (i_bh+i_v*B*H) * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (T - i*BT, i_k*BK), (BT, BK), (1, 0))\n p_dv = tl.make_block_ptr(dv + (i_bh+i_k*B*H) * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (T - i*BT, i_v*BV), (BT, BV), (1, 0))\n b_q = tl.load(p_q, boundary_check=(0, 1))\n b_k = tl.load(p_k, boundary_check=(0, 1))\n b_v = tl.load(p_v, boundary_check=(0, 1))\n b_do = tl.load(p_do, boundary_check=(0, 1))\n b_dd = (b_do * d_q[:, None]).to(b_do.dtype)\n\n b_ds = tl.dot(b_v, tl.trans(b_do), allow_tf32=False)\n b_ds = (b_ds * d_s).to(b_k.dtype)\n\n b_s = tl.dot(b_k, b_q, allow_tf32=False) * d_s\n b_dk = tl.dot(b_ds, tl.trans(b_q), allow_tf32=False)\n b_dv = tl.dot(b_s.to(b_q.dtype), b_do, allow_tf32=False)\n if CHECK and i == 1:\n b_dk += tl.dot(b_v, tl.trans(b_dh).to(b_v.dtype), allow_tf32=False) * d_k[:, None]\n b_dv += tl.dot(b_k, b_dh.to(b_k.dtype), allow_tf32=False) * d_k[:, None]\n b_dh = d_b * b_dh + tl.dot(b_q, b_dd, allow_tf32=False)\n else:\n b_dk += tl.dot(b_v, tl.trans(b_dh).to(b_v.dtype), allow_tf32=False) * d_k[:, None]\n b_dv += tl.dot(b_k, b_dh.to(b_k.dtype), allow_tf32=False) * d_k[:, None]\n b_dh = d_b * b_dh + tl.dot(b_q, b_dd, allow_tf32=False)\n\n tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))\n tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))\n\n\nclass FusedChunkRetentionFunction(torch.autograd.Function):\n\n @staticmethod\n @custom_fwd\n def forward(ctx, q, k, v, initial_state, output_final_state):\n batch_size, n_heads, seq_len, d_head_qk = q.shape\n d_head_v = v.shape[-1]\n\n scale = d_head_qk ** -0.5\n BT = 64\n BK, BV = min(triton.next_power_of_2(d_head_qk), 64), min(triton.next_power_of_2(d_head_v), 64)\n NK, NV = triton.cdiv(d_head_qk, BK), triton.cdiv(d_head_v, BV)\n num_stages = 1\n num_warps = 4\n\n o = q.new_empty(NK, batch_size, n_heads, seq_len, d_head_v)\n\n if output_final_state:\n final_state = q.new_empty(batch_size, n_heads, d_head_qk, d_head_v, dtype=torch.float32, requires_grad=False)\n else:\n final_state = None\n\n CHECK = True\n if version.parse(triton.__version__) < version.parse('2.2.0'):\n import warnings\n warnings.warn(\n \"Triton<2.2.0 detected for running this kernel, \"\n \"which is known to have some weird compiler issues (refer to https://github.com/openai/triton/issues/2852) \"\n \"that lead to significant precision loss. \"\n \"We've add some initial condition checks to resolve this, sadly at the sacrifice of the speed. \"\n \"For optimal performance, it is recommended to install Triton>=2.2.0 (if possible).\"\n )\n CHECK = True\n\n grid = (NV, NK, batch_size * n_heads)\n fused_chunk_retention_fwd_kernel[grid](\n q, k, v, o, initial_state, final_state,\n q.stride(1), q.stride(2), q.stride(3),\n v.stride(1), v.stride(2), v.stride(3),\n batch_size, n_heads, seq_len, scale,\n BT=BT, DK=d_head_qk, DV=d_head_v, BK=BK, BV=BV,\n USE_INITIAL_STATE=initial_state is not None,\n STORE_FINAL_STATE=output_final_state,\n CHECK=CHECK,\n num_warps=num_warps,\n num_stages=num_stages\n )\n\n o = o.sum(0)\n ctx.save_for_backward(q, k, v, initial_state)\n ctx.CHECK = CHECK\n return o.to(q.dtype), final_state\n\n @staticmethod\n @custom_bwd\n def backward(ctx, do, d_final_state=None):\n q, k, v, initial_state = ctx.saved_tensors\n batch_size, n_heads, seq_len, d_head_qk = q.shape\n d_head_v = v.shape[-1]\n scale = d_head_qk ** -0.5\n\n BT = 64\n BK, BV = min(triton.next_power_of_2(d_head_qk), 64), min(triton.next_power_of_2(d_head_v), 64)\n NK, NV = triton.cdiv(d_head_qk, BK), triton.cdiv(d_head_v, BV)\n num_stages = 1\n num_warps = 4\n\n dq = q.new_empty(NV, batch_size, n_heads, seq_len, d_head_qk)\n dk = q.new_empty(NV, batch_size, n_heads, seq_len, d_head_qk)\n dv = q.new_empty(NK, batch_size, n_heads, seq_len, d_head_v)\n grid = (NV, NK, batch_size * n_heads)\n\n fused_chunk_retention_bwd_kernel[grid](\n q, k, v, do, dq, dk, dv, initial_state,\n q.stride(1), q.stride(2), q.stride(3),\n v.stride(1), v.stride(2), v.stride(3),\n batch_size, n_heads, seq_len, scale,\n BT=BT, DK=d_head_qk, DV=d_head_v, BK=BK, BV=BV,\n USE_INITIAL_STATE=initial_state is not None,\n CHECK=ctx.CHECK,\n num_warps=num_warps,\n num_stages=num_stages\n )\n dq = dq.sum(0)\n dk = dk.sum(0)\n dv = dv.sum(0)\n return dq.to(q.dtype), dk.to(k.dtype), dv.to(v.dtype), None, None\n\n\ndef fused_chunk_retention(\n q: torch.Tensor,\n k: torch.Tensor,\n v: torch.Tensor,\n initial_state: torch.Tensor = None,\n output_final_state: bool = False\n) -> Tuple[torch.Tensor, torch.Tensor]:\n if initial_state is not None:\n initial_state = initial_state.detach()\n o, final_state = FusedChunkRetentionFunction.apply(q, k, v, initial_state, output_final_state)\n return o, final_state\n", - "description_1": "Use triton language to implement a fused chunk retention forward and backward kernel that computes attention outputs and gradients efficiently. The forward kernel takes 24 parameters including query, key, value, and other configurations, while the backward kernel takes 25 parameters including gradients and initial states.", - "description_2": "Use triton language to create a custom autograd function in PyTorch for efficient chunk-based attention mechanisms. This includes defining forward and backward operations using triton kernels.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\nfrom torch.cuda.amp import custom_bwd, custom_fwd\nfrom fla.utils import contiguous\n\n@triton.jit\ndef parallel_retention_fwd_kernel(\n q, k, v, o,\n s_qk_h, s_qk_t, s_qk_d,\n s_vo_h, s_vo_t, s_vo_d,\n B, H, T, scale,\n BTL: tl.constexpr, BTS: tl.constexpr, BK: tl.constexpr,\n BV: tl.constexpr, DK: tl.constexpr, DV: tl.constexpr,\n):\n i_kv, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n NV = tl.cdiv(DV, BV)\n i_k = i_kv // (NV)\n i_v = i_kv % (NV)\n i_h = i_bh % H\n b_b = tl.math.log2(1 - tl.math.pow(2, -5 - i_h * 1.0))\n o_k = tl.arange(0, BTS)\n d_h = tl.math.exp2((BTS - o_k) * b_b)\n\n p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (T, DK),\n (s_qk_t, s_qk_d), (i_c * BTL, i_k * BK), (BTL, BK), (1, 0))\n p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (DK, T),\n (s_qk_d, s_qk_t), (i_k * BK, 0), (BK, BTS), (0, 1))\n p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, DV),\n (s_vo_t, s_vo_d), (0, i_v * BV), (BTS, BV), (1, 0))\n\n b_q = tl.load(p_q, boundary_check=(0, 1))\n b_q = (b_q * scale).to(b_q.dtype)\n b_o = tl.zeros([BTL, BV], dtype=tl.float32)\n\n for _ in range(0, i_c * BTL, BTS):\n b_k = tl.load(p_k, boundary_check=(0, 1))\n b_v = tl.load(p_v, boundary_check=(0, 1))\n b_s = tl.dot(b_q, (b_k), allow_tf32=False) * d_h[None, :]\n b_o = b_o * tl.math.exp2(b_b * BTS)\n b_o = b_o + tl.dot(b_s.to(b_v.dtype), b_v, allow_tf32=False)\n p_k = tl.advance(p_k, (0, BTS))\n p_v = tl.advance(p_v, (BTS, 0))\n\n tl.debug_barrier()\n o_q = tl.arange(0, BTL)\n d_q = tl.math.exp2(tl.arange(0, BTL) * b_b)\n b_o *= d_q[:, None]\n\n o_k = tl.arange(0, BTS)\n p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (DK, T),\n (s_qk_d, s_qk_t), (i_k * BK, i_c * BTL), (BK, BTS), (0, 1))\n p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, DV),\n (s_vo_t, s_vo_d), (i_c * BTL, i_v * BV), (BTS, BV), (1, 0))\n\n for _ in range(i_c * BTL, (i_c + 1) * BTL, BTS):\n b_k = tl.load(p_k, boundary_check=(0, 1))\n b_v = tl.load(p_v, boundary_check=(0, 1))\n m_s = o_q[:, None] >= o_k[None, :]\n d_s = tl.where(m_s, tl.math.exp2(\n (o_q[:, None] - o_k[None, :]) * b_b), 0)\n b_s = tl.dot(b_q, b_k, allow_tf32=False) * d_s\n b_o += tl.dot(b_s.to(b_q.dtype), b_v, allow_tf32=False)\n\n p_k = tl.advance(p_k, (0, BTS))\n p_v = tl.advance(p_v, (BTS, 0))\n o_k += BTS\n\n p_o = tl.make_block_ptr(o + (i_bh + B * H * i_k) * s_vo_h, (T, DV),\n (s_vo_t, s_vo_d), (i_c*BTL, i_v*BV), (BTL, BV), (1, 0))\n tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))\n\n\n@triton.jit\ndef parallel_retention_bwd_kernel(\n q, k, v, do, dq, dk, dv, s_qk_h, s_qk_t, s_qk_d, s_vo_h,\n s_vo_t, s_vo_d, B, H, T, scale,\n BTL: tl.constexpr, BTS: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr,\n DK: tl.constexpr, DV: tl.constexpr,\n):\n i_kv, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n NV = tl.cdiv(DV, BV)\n i_k = i_kv // (NV)\n i_v = i_kv % (NV)\n i_h = i_bh % H\n _parallel_retention_bwd_dq(\n i_bh, i_c, i_k, i_v, i_h,\n k, v, do, dq, s_qk_h, s_qk_t, s_qk_d, s_vo_h,\n s_vo_t, s_vo_d, B, H, T, scale, BTL=BTL, BTS=BTS, BK=BK, BV=BV, DK=DK, DV=DV\n )\n tl.debug_barrier()\n _parallel_retention_bwd_dkv(\n i_bh, i_c, i_k, i_v, i_h,\n q, k, v, do, dk, dv, s_qk_h, s_qk_t, s_qk_d, s_vo_h,\n s_vo_t, s_vo_d, B, H, T, scale, BTL, BTS, BK, BV, DK, DV\n )\n\nclass ParallelRetentionFunction(torch.autograd.Function):\n @staticmethod\n @contiguous\n @custom_fwd\n def forward(ctx, q, k, v):\n BTL, BTS = 128, 32\n assert BTL % BTS == 0\n BK = min(128, triton.next_power_of_2(k.shape[-1]))\n BV = min(128, triton.next_power_of_2(v.shape[-1]))\n batch_size, n_heads, seq_len, d_head_qk = q.shape\n d_head_v = v.shape[-1]\n num_stages = 3 if d_head_qk <= 64 else 2\n num_warps = 4\n NK = triton.cdiv(d_head_qk, BK)\n NV = triton.cdiv(d_head_v, BV)\n\n grid = (NK * NV, triton.cdiv(seq_len, BTL), batch_size * n_heads)\n scale = d_head_qk ** -0.5\n o = torch.empty(NK, batch_size, n_heads, seq_len,\n d_head_v, dtype=q.dtype, device=q.device)\n parallel_retention_fwd_kernel[grid](\n q, k, v, o,\n q.stride(1), q.stride(2), q.stride(3),\n v.stride(1), v.stride(2), v.stride(3),\n batch_size, n_heads, seq_len, scale,\n BTL=BTL, BTS=BTS, BK=BK, BV=BV, DK=d_head_qk, DV=d_head_v,\n num_warps=num_warps,\n num_stages=num_stages\n )\n ctx.save_for_backward(q, k, v)\n return o.sum(0).to(q.dtype)\n\n @staticmethod\n @contiguous\n @custom_bwd\n def backward(ctx, do):\n q, k, v = ctx.saved_tensors\n BTL, BTS = 64, 32\n assert BTL % BTS == 0\n BK = min(128, triton.next_power_of_2(k.shape[-1]))\n BV = min(128, triton.next_power_of_2(v.shape[-1]))\n batch_size, n_heads, seq_len, d_head_qk = q.shape\n d_head_v = v.shape[-1]\n num_stages = 3 if d_head_qk <= 64 else 2\n num_warps = 4\n NK = triton.cdiv(d_head_qk, BK)\n NV = triton.cdiv(d_head_v, BV)\n grid = (NK * NV, triton.cdiv(seq_len, BTL), batch_size * n_heads)\n scale = d_head_qk ** -0.5\n\n dq = torch.empty(NV, batch_size, n_heads, seq_len,\n d_head_qk, dtype=q.dtype, device=q.device)\n dk = torch.empty(NV, batch_size, n_heads, seq_len,\n d_head_qk, dtype=q.dtype, device=q.device)\n dv = torch.empty(NK, batch_size, n_heads, seq_len,\n d_head_v, dtype=q.dtype, device=q.device)\n\n parallel_retention_bwd_kernel[grid](\n q, k, v, do, dq, dk, dv,\n q.stride(1), q.stride(2), q.stride(3),\n v.stride(1), v.stride(2), v.stride(3),\n batch_size, n_heads, seq_len, scale,\n BTL=BTL, BTS=BTS, BK=BK, BV=BV, DK=d_head_qk, DV=d_head_v,\n num_warps=num_warps,\n num_stages=num_stages\n )\n\n return dq.sum(0).to(q.dtype), dk.sum(0).to(k.dtype), dv.sum(0).to(v.dtype)\n\nparallel_retention = ParallelRetentionFunction.apply\n", - "description_1": "Use triton language to implement a parallel retention forward and backward kernel for attention-like mechanisms. The kernels take tensors `q`, `k`, `v` for the forward pass, and additionally `do`, `dq`, `dk`, `dv` for the backward pass, each with respective strides and dimensions `B`, `H`, `T`, `DK`, `DV`. The forward kernel computes an attention output tensor `o` using a scaling factor, while the backward kernel computes gradients for `q`, `k`, `v` using the outputs and inputs.", - "description_2": "Use triton language to perform forward and backward computations for a parallel retention mechanism, involving operations on input tensors to compute outputs and gradients with attention-like behavior.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\nfrom torch import Tensor\n\n@triton.jit\ndef fused_recurrent_rwkv4_forward_kernel(\n # W\n w_ptr,\n w_s_c,\n # U\n u_ptr,\n u_s_c,\n # K\n k_ptr,\n k_s_b,\n k_s_t,\n k_s_c,\n # V\n v_ptr,\n v_s_b,\n v_s_t,\n v_s_c,\n # State\n state_ptr,\n state_s_b,\n state_s_abe,\n state_s_c,\n # WKV\n wkv_ptr,\n wkv_s_b,\n wkv_s_t,\n wkv_s_c,\n # Output state\n state_out_ptr,\n state_out_s_b,\n state_out_s_abe,\n state_out_s_t,\n state_out_s_c,\n # Params\n chans,\n tsz,\n BLOCK_SIZE_C: tl.constexpr,\n):\n # Parallelize over the batch dimension.\n b_idx = tl.program_id(0)\n c_idx = tl.program_id(1)\n\n cs = (c_idx * BLOCK_SIZE_C) + tl.arange(0, BLOCK_SIZE_C)\n cmask = cs < chans\n\n # Pointers to the batch (and possibly channel) for the input tensors.\n k_ptr = k_ptr + b_idx * k_s_b\n v_ptr = v_ptr + b_idx * v_s_b\n alpha_ptr = state_ptr + b_idx * state_s_b\n beta_ptr = state_ptr + b_idx * state_s_b + state_s_abe\n eps_ptr = state_ptr + b_idx * state_s_b + 2 * state_s_abe\n\n # Pointers to the batch (and possibly channel) for the output tensors.\n wkv_ptr = wkv_ptr + b_idx * wkv_s_b\n alpha_out_ptr = state_out_ptr + b_idx * state_out_s_b\n beta_out_ptr = state_out_ptr + b_idx * state_out_s_b + state_out_s_abe\n eps_out_ptr = state_out_ptr + b_idx * state_out_s_b + 2 * state_out_s_abe\n\n # Loads parameters.\n alpha = tl.load(alpha_ptr + cs * state_s_c, mask=cmask).to(tl.float32)\n beta = tl.load(beta_ptr + cs * state_s_c, mask=cmask).to(tl.float32)\n eps = tl.load(eps_ptr + cs * state_s_c, mask=cmask).to(tl.float32)\n w = tl.load(w_ptr + cs * w_s_c, mask=cmask).to(tl.float32)\n u = tl.load(u_ptr + cs * u_s_c, mask=cmask).to(tl.float32)\n\n for t in range(tsz):\n kt = tl.load(k_ptr + t * k_s_t + cs * k_s_c, mask=cmask).to(tl.float32)\n vt = tl.load(v_ptr + t * v_s_t + cs * v_s_c, mask=cmask).to(tl.float32)\n\n ukt = u + kt\n tau = tl.maximum(ukt, eps)\n e1a = tl.exp(eps - tau)\n e2a = tl.exp(ukt - tau)\n wkv = (e1a * alpha + e2a * vt) / (e1a * beta + e2a)\n tl.store(wkv_ptr + t * wkv_s_t + cs * wkv_s_c, wkv, mask=cmask)\n\n w_eps = w + eps\n eps = tl.maximum(w_eps, kt)\n e1b = tl.exp(w_eps - eps)\n e2b = tl.exp(kt - eps)\n alpha = e1b * alpha + e2b * vt\n beta = e1b * beta + e2b\n tl.store(alpha_out_ptr + t * state_out_s_t + cs * state_out_s_c, alpha, mask=cmask)\n tl.store(beta_out_ptr + t * state_out_s_t + cs * state_out_s_c, beta, mask=cmask)\n tl.store(eps_out_ptr + t * state_out_s_t + cs * state_out_s_c, eps, mask=cmask)\n\n\ndef fused_recurrent_rwkv4_forward(\n w: Tensor,\n u: Tensor,\n k: Tensor,\n v: Tensor,\n state: Tensor,\n) -> tuple[Tensor, Tensor]:\n (bsz, tsz, chans) = k.shape\n\n # New tensors to output.\n wkvs = k.new_empty(bsz, tsz, chans)\n state_out = k.new_empty(bsz, 3, tsz, chans)\n\n # Constants.\n block_size_c = get_block_size_c(chans)\n\n def grid(meta: dict[str, Any]) -> tuple[int, ...]:\n return (bsz, triton.cdiv(chans, meta[\"BLOCK_SIZE_C\"]))\n\n fused_recurrent_rwkv4_forward_kernel[grid](\n # W\n w,\n w.stride(0),\n # U\n u,\n u.stride(0),\n # K\n k,\n k.stride(0),\n k.stride(1),\n k.stride(2),\n # V\n v,\n v.stride(0),\n v.stride(1),\n v.stride(2),\n # State\n state,\n state.stride(0),\n state.stride(1),\n state.stride(3),\n # WKV\n wkvs,\n wkvs.stride(0),\n wkvs.stride(1),\n wkvs.stride(2),\n # Output state\n state_out,\n state_out.stride(0),\n state_out.stride(1),\n state_out.stride(2),\n state_out.stride(3),\n # Params\n chans,\n tsz,\n BLOCK_SIZE_C=block_size_c,\n )\n\n state_out = torch.cat((state, state_out), dim=2)\n\n return wkvs, state_out\n\n\n@triton.jit\ndef fused_recurrent_rwkv4_backward_kernel(\n # W\n w_ptr,\n w_s_c,\n # U\n u_ptr,\n u_s_c,\n # K\n k_ptr,\n k_s_b,\n k_s_t,\n k_s_c,\n # V\n v_ptr,\n v_s_b,\n v_s_t,\n v_s_c,\n # State\n state_ptr,\n state_s_b,\n state_s_abe,\n state_s_t,\n state_s_c,\n # WKV grad\n gwkv_ptr,\n gwkv_s_b,\n gwkv_s_t,\n gwkv_s_c,\n # Output state grad\n gstate_out_ptr,\n gstate_out_s_b,\n gstate_out_s_abe,\n gstate_out_s_c,\n # W grad\n gw_ptr,\n gw_s_c,\n # U grad\n gu_ptr,\n gu_s_c,\n # K grad\n gk_ptr,\n gk_s_b,\n gk_s_t,\n gk_s_c,\n # V grad\n gv_ptr,\n gv_s_b,\n gv_s_t,\n gv_s_c,\n # State grad\n gstate_ptr,\n gstate_s_b,\n gstate_s_abe,\n gstate_s_c,\n # Params\n tsz,\n chans,\n BLOCK_SIZE_C: tl.constexpr,\n):\n # Parallelize over the batch dimension.\n b_idx = tl.program_id(0)\n c_idx = tl.program_id(1)\n\n cs = (c_idx * BLOCK_SIZE_C) + tl.arange(0, BLOCK_SIZE_C)\n cmask = cs < chans\n\n # Pointers to the batch (and possibly channel) for the input tensors.\n k_ptr = k_ptr + b_idx * k_s_b\n v_ptr = v_ptr + b_idx * v_s_b\n alpha_ptr = state_ptr + b_idx * state_s_b\n beta_ptr = state_ptr + b_idx * state_s_b + state_s_abe\n eps_ptr = state_ptr + b_idx * state_s_b + 2 * state_s_abe\n\n # Pointers to the batch (and possibly channel) for the output tensors.\n gk_ptr = gk_ptr + b_idx * gk_s_b\n gv_ptr = gv_ptr + b_idx * gv_s_b\n\n # Pointers to gradients which were recieved by the function.\n gwkv_ptr = gwkv_ptr + b_idx * gwkv_s_b\n galpha_out_ptr = gstate_out_ptr + b_idx * gstate_out_s_b\n gbeta_out_ptr = gstate_out_ptr + b_idx * gstate_out_s_b + gstate_out_s_abe\n geps_out_ptr = gstate_out_ptr + b_idx * gstate_out_s_b + 2 * gstate_out_s_abe\n\n # Loads parameters.\n galpha = tl.load(galpha_out_ptr + gstate_out_s_c * cs, mask=cmask).to(tl.float32)\n gbeta = tl.load(gbeta_out_ptr + gstate_out_s_c * cs, mask=cmask).to(tl.float32)\n geps = tl.load(geps_out_ptr + gstate_out_s_c * cs, mask=cmask).to(tl.float32)\n w = tl.load(w_ptr + w_s_c * cs, mask=cmask).to(tl.float32)\n u = tl.load(u_ptr + u_s_c * cs, mask=cmask).to(tl.float32)\n\n # Gradient accumulators.\n gw = tl.zeros_like(w)\n gu = tl.zeros_like(u)\n\n alpha_prev = tl.load(alpha_ptr + tsz * state_s_t + state_s_c * cs, mask=cmask).to(tl.float32)\n beta_prev = tl.load(beta_ptr + tsz * state_s_t + state_s_c * cs, mask=cmask).to(tl.float32)\n eps_prev = tl.load(eps_ptr + tsz * state_s_t + state_s_c * cs, mask=cmask).to(tl.float32)\n\n for t in range(tsz):\n tc = tsz - t - 1\n\n kt = tl.load(k_ptr + tc * k_s_t + k_s_c * cs, mask=cmask).to(tl.float32)\n vt = tl.load(v_ptr + tc * v_s_t + v_s_c * cs, mask=cmask).to(tl.float32)\n\n alpha_curr = alpha_prev\n beta_curr = beta_prev\n eps_curr = eps_prev\n\n alpha_prev = tl.load(alpha_ptr + tc * state_s_t + state_s_c * cs, mask=cmask).to(tl.float32)\n beta_prev = tl.load(beta_ptr + tc * state_s_t + state_s_c * cs, mask=cmask).to(tl.float32)\n eps_prev = tl.load(eps_ptr + tc * state_s_t + state_s_c * cs, mask=cmask).to(tl.float32)\n\n ukt = u + kt\n tau = tl.maximum(ukt, eps_prev)\n e1 = tl.exp(eps_prev - tau)\n e2 = tl.exp(ukt - tau)\n\n euke = tl.exp(ukt + eps_prev - 2 * tau)\n\n denom = e1 * beta_prev + e2\n denom_sq = denom * denom\n\n gwkvt = tl.load(gwkv_ptr + tc * gwkv_s_t + gwkv_s_c * cs, mask=cmask).to(tl.float32)\n\n # Backpropagates wkv gradients.\n guk = gwkvt * e2 * (e1 * beta_prev * vt - e1 * alpha_prev) / denom_sq\n gu += guk\n gk = guk\n gv = gwkvt * e2 / denom\n\n galpha_wkv = gwkvt * e1 / denom\n gbeta_wkv = -gwkvt * e1 * (e2 * vt + e1 * alpha_prev) / denom_sq\n geps_wkv_denom = e1 * beta_prev + e2\n geps_wkv = gwkvt * euke * (alpha_prev - vt * beta_prev) / (geps_wkv_denom * geps_wkv_denom)\n\n e1 = tl.exp(w + eps_prev - eps_curr)\n e2 = tl.exp(kt - eps_curr)\n\n # Backpropagates alpha gradients.\n galpha_we = galpha * e1 * alpha_prev\n gw += galpha_we\n gk += galpha * e2 * vt\n gv += galpha * e2\n geps += galpha * -alpha_curr\n\n # Backpropagates beta gradients.\n gbeta_we = gbeta * e1 * beta_prev\n gw += gbeta_we\n gk += gbeta * e2\n geps += gbeta * -beta_curr\n\n # Backpropagates epsilon gradients.\n geps_mask = w + eps_prev > kt\n geps_we = tl.where(geps_mask, geps, tl.zeros_like(geps))\n gw += geps_we\n gk += tl.where(geps_mask, tl.zeros_like(geps), geps)\n\n # Stores the gradients for k and v.\n tl.store(gk_ptr + tc * gk_s_t + gk_s_c * cs, gk, mask=cmask)\n tl.store(gv_ptr + tc * gv_s_t + gv_s_c * cs, gv, mask=cmask)\n\n # Computes new gradients for alpha and beta.\n galpha = galpha * e1 + galpha_wkv\n gbeta = gbeta * e1 + gbeta_wkv\n geps = galpha_we + gbeta_we + geps_we + geps_wkv\n\n # Stores final gradients for alpha and beta.\n galpha_ptr = gstate_ptr + b_idx * gstate_s_b\n gbeta_ptr = gstate_ptr + b_idx * gstate_s_b + gstate_s_abe\n geps_ptr = gstate_ptr + b_idx * gstate_s_b + 2 * gstate_s_abe\n tl.store(galpha_ptr + gstate_s_c * cs, galpha, mask=cmask)\n tl.store(gbeta_ptr + gstate_s_c * cs, gbeta, mask=cmask)\n tl.store(geps_ptr + gstate_s_c * cs, geps, mask=cmask)\n\n # Stores final gradients for w and u.\n gw_temp = tl.load(gw_ptr + gw_s_c * cs, mask=cmask).to(tl.float32)\n gw_temp += gw\n tl.store(gw_ptr + gw_s_c * cs, gw_temp, mask=cmask)\n gu_temp = tl.load(gu_ptr + gu_s_c * cs, mask=cmask).to(tl.float32)\n gu_temp += gu\n tl.store(gu_ptr + gu_s_c * cs, gu_temp, mask=cmask)\n\n\ndef fused_recurrent_rwkv4_backward(\n w: Tensor,\n u: Tensor,\n k: Tensor,\n v: Tensor,\n state: Tensor,\n grad_wkv: Tensor,\n grad_state: Tensor,\n) -> tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:\n bsz, tsz, chans = k.shape\n\n gw = torch.zeros_like(w) # New tensors to output.\n gu = torch.zeros_like(u)\n gk = torch.empty_like(k)\n gv = torch.empty_like(v)\n gstate = k.new_empty(bsz, 3, 1, chans)\n\n block_size_c = get_block_size_c(chans) # Constants.\n\n def grid(meta: dict[str, Any]) -> tuple[int, ...]:\n return (bsz, triton.cdiv(chans, meta[\"BLOCK_SIZE_C\"]))\n\n fused_recurrent_rwkv4_backward_kernel[grid](\n # W\n w,\n w.stride(0),\n # U\n u,\n u.stride(0),\n # K\n k,\n k.stride(0),\n k.stride(1),\n k.stride(2),\n # V\n v,\n v.stride(0),\n v.stride(1),\n v.stride(2),\n # State\n state,\n state.stride(0),\n state.stride(1),\n state.stride(2),\n state.stride(3),\n # WKV grad\n grad_wkv,\n grad_wkv.stride(0),\n grad_wkv.stride(1),\n grad_wkv.stride(2),\n # Output state grad\n grad_state,\n grad_state.stride(0),\n grad_state.stride(1),\n grad_state.stride(3),\n # W grad\n gw,\n gw.stride(0),\n # U grad\n gu,\n gu.stride(0),\n # K grad\n gk,\n gk.stride(0),\n gk.stride(1),\n gk.stride(2),\n # V grad\n gv,\n gv.stride(0),\n gv.stride(1),\n gv.stride(2),\n # State grad\n gstate,\n gstate.stride(0),\n gstate.stride(1),\n gstate.stride(3),\n # Params\n tsz,\n chans,\n BLOCK_SIZE_C=block_size_c,\n )\n\n return gw, gu, gk, gv, gstate\n", - "description_1": "Use triton language to implement a fused recurrent RWKV forward and backward kernel. The forward kernel takes 25 parameters: pointers to tensors w, u, k, v, state, wkv, and state_out, along with their strides, the number of channels, the time size, and a block size constant. It computes the RWKV forward pass by iterating over the time dimension and updating the state and wkv tensors. The backward kernel takes 35 parameters: pointers to tensors w, u, k, v, state, gwkv, gstate_out, gw, gu, gk, gv, gstate, along with their strides, the number of channels, the time size, and a block size constant. It computes the gradients for the RWKV backward pass by iterating over the time dimension in reverse and updating the gradient tensors.", - "description_2": "Use triton language to create a fused recurrent RWKV kernel for forward and backward passes, handling tensor operations and gradient computations efficiently.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n# Kernel function for cumulative RWKV forward pass\n@triton.jit\ndef chunk_rwkv6_fwd_kernel_cum(\n s, o, o_minus_s, s_s_h, s_s_t, s_s_d, \n T: tl.constexpr, S: tl.constexpr, BT: tl.constexpr, BS: tl.constexpr\n):\n i_s, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n o_i = tl.arange(0, BT)\n m_s = tl.where(o_i[:, None] >= o_i[None, :], 1., 0.)\n\n p_s = tl.make_block_ptr(s + i_bh * s_s_h, (T, S), (s_s_t, s_s_d), (i_t * BT, i_s * BS), (BT, BS), (1, 0))\n p_o = tl.make_block_ptr(o + i_bh * s_s_h, (T, S), (s_s_t, s_s_d), (i_t * BT, i_s * BS), (BT, BS), (1, 0))\n p_o_minus_s = tl.make_block_ptr(o_minus_s + i_bh * s_s_h, (T, S), (s_s_t, s_s_d), (i_t * BT, i_s * BS), (BT, BS), (1, 0))\n b_s = tl.load(p_s, boundary_check=(0, 1)).to(tl.float32)\n b_o = tl.dot(m_s, b_s, allow_tf32=False)\n tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))\n tl.store(p_o_minus_s, (b_o - b_s).to(p_o_minus_s.dtype.element_ty), boundary_check=(0, 1))\n\n# Kernel function for post-processing the gradient\n@triton.jit\ndef post_process_grad(\n q, k, v, u, do, dk, dq, du, scale, s_k_h, s_k_t, s_k_d, s_v_h, s_v_t, s_v_d, \n H, T: tl.constexpr, BT: tl.constexpr, K: tl.constexpr, V: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr,\n):\n i_t, i_bh = tl.program_id(0), tl.program_id(1)\n i_h = i_bh % H\n\n # Note that BK = tl.next_power_of_2(K), BV = tl.next_power_of_2(V)\n p_q = tl.make_block_ptr(q + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, 0), (BT, BK), (1, 0))\n p_dq = tl.make_block_ptr(dq + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, 0), (BT, BK), (1, 0))\n p_k = tl.make_block_ptr(k + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, 0), (BT, BK), (1, 0))\n p_dk = tl.make_block_ptr(dk + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, 0), (BT, BK), (1, 0))\n p_du = tl.make_block_ptr(du + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, 0), (BT, BK), (1, 0))\n p_v = tl.make_block_ptr(v + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, 0), (BT, BV), (1, 0))\n p_do = tl.make_block_ptr(do + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, 0), (BT, BV), (1, 0))\n p_u = tl.make_block_ptr(u + i_h * K, (K,), (1,), (0,), (BK,), (0,))\n\n b_q = tl.load(p_q, boundary_check=(0, 1))\n b_k = tl.load(p_k, boundary_check=(0, 1))\n b_v = tl.load(p_v, boundary_check=(0, 1))\n b_do = tl.load(p_do, boundary_check=(0, 1))\n b_u = tl.load(p_u, boundary_check=(0,))\n\n b_vdo = tl.sum(b_v * b_do, axis=1)\n b_du = b_vdo[:, None] * b_k * b_q * scale\n b_dq = b_vdo[:, None] * b_k * b_u[None, :] * scale\n b_dk = b_vdo[:, None] * b_q * b_u[None, :] * scale\n\n b_dq += tl.load(p_dq, boundary_check=(0, 1))\n tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))\n\n b_dk += tl.load(p_dk, boundary_check=(0, 1))\n tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))\n\n tl.store(p_du, b_du.to(p_du.dtype.element_ty), boundary_check=(0, 1))\n\n# Forward pass kernel function for chunked RWKV\n@triton.jit\ndef chunk_rwkv6_fwd_kernel_h(\n k, v, g, h, h0, ht, \n s_k_h, s_k_t, s_k_d, s_v_h, s_v_t, s_v_d, s_h_h, s_h_t, s_h_d, \n T: tl.constexpr, K: tl.constexpr, V: tl.constexpr, BT: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr, \n NT: tl.constexpr, USE_INITIAL_STATE: tl.constexpr, STORE_FINAL_STATE: tl.constexpr\n):\n i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n b_h = tl.zeros([BK, BV], dtype=tl.float32)\n\n if USE_INITIAL_STATE:\n p_h = tl.make_block_ptr(h0 + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))\n b_h += tl.load(p_h, boundary_check=(0, 1)).to(tl.float32)\n for i_t in range(NT):\n o_t = min(i_t * BT + BT, T)\n\n p_k = tl.make_block_ptr(k + i_bh * s_k_h, (K, T), (s_k_d, s_k_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1))\n p_v = tl.make_block_ptr(v + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))\n p_h = tl.make_block_ptr(h + i_bh * s_h_h + i_t * K * V, (K, V), (s_h_t, s_h_d), (i_k * BK, i_v * BV), (BK, BV), (1, 0))\n p_g = tl.make_block_ptr(g + i_bh * s_k_h, (K, T), (s_k_d, s_k_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1))\n p_gn = tl.make_block_ptr(g + i_bh * s_k_h, (T * K,), (s_k_d,), ((o_t - 1) * K + i_k * BK,), (BK,), (0,))\n\n tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1))\n # [BK, BT]\n b_k = tl.load(p_k, boundary_check=(0, 1))\n # [BT, BV]\n b_v = tl.load(p_v, boundary_check=(0, 1))\n # [BK, BT]\n b_g = tl.load(p_g, boundary_check=(0, 1))\n if i_t < NT - 1:\n # [BK,]\n b_gn = tl.load(p_gn, boundary_check=(0,))\n else:\n b_gn = tl.min(b_g, axis=1)\n b_h *= tl.exp(b_gn)[:, None]\n b_k = (b_k * tl.exp(b_gn[:, None] - b_g)).to(b_k.dtype)\n b_h += tl.dot(b_k, b_v, allow_tf32=False)\n\n if STORE_FINAL_STATE:\n p_h = tl.make_block_ptr(ht + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))\n tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1))\n\n# Inter-chunk kernel for RWKV forward pass\n@triton.jit\ndef chunk_rwkv6_fwd_kernel_inter(\n q, v, gs, h, o, A, s_k_h, s_k_t, s_k_d, s_v_h, s_v_t, s_v_d, s_h_h, s_h_t, s_h_d, \n scale, T: tl.constexpr, K: tl.constexpr, V: tl.constexpr, BT: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr\n):\n i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n\n b_o = tl.zeros([BT, BV], dtype=tl.float32)\n for i_k in range(tl.cdiv(K, BK)):\n p_q = tl.make_block_ptr(q + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))\n p_gs = tl.make_block_ptr(gs + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))\n p_h = tl.make_block_ptr(h + i_bh * s_h_h + i_t * K * V, (K, V), (s_h_t, s_h_d), (i_k * BK, i_v * BV), (BK, BV), (1, 0))\n\n # [BT, BK]\n b_q = tl.load(p_q, boundary_check=(0, 1))\n b_q = (b_q * scale).to(b_q.dtype)\n # [BT, BK]\n b_gs = tl.load(p_gs, boundary_check=(0, 1))\n # [BT, BK]\n b_qg = (b_q * tl.exp(b_gs)).to(b_q.dtype)\n # [BK, BV]\n b_h = tl.load(p_h, boundary_check=(0, 1))\n # works but dkw, owing to divine benevolence\n # [BT, BV]\n if i_k >= 0:\n b_o += tl.dot(b_qg, b_h, allow_tf32=False)\n p_v = tl.make_block_ptr(v + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))\n p_o = tl.make_block_ptr(o + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))\n p_A = tl.make_block_ptr(A + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))\n # [BT, BV]\n b_v = tl.load(p_v, boundary_check=(0, 1))\n # [BT, BT]\n b_A = tl.load(p_A, boundary_check=(0, 1))\n b_o += tl.dot(b_A, b_v, allow_tf32=False)\n tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))\n", - "description_1": "Use triton language to implement a series of kernel functions for the RWKV model's forward pass, post-processing of gradients, and inter-chunk operations. Functions include: 1. chunk_rwkv6_fwd_kernel_cum: Computes cumulative sums for forward pass. 2. post_process_grad: Processes gradients for the RWKV model. 3. chunk_rwkv6_fwd_kernel_h: Handles intra-chunk operations for forward pass. 4. chunk_rwkv6_fwd_kernel_inter: Manages inter-chunk operations for forward pass.", - "description_2": "Use triton language to create kernels to compute cumulative sums and process gradients in the RWKV model, focusing on intra- and inter-chunk operations.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\nfrom torch.cuda.amp import custom_fwd\n\n@triton.jit\ndef fused_recurrent_rwkv6_fwd_kernel(\n q, # query [B, H, T, K]\n k, # key [B, H, T, K]\n v, # value [B, H, T, V]\n w, # log gate [B, H, T, K]\n u, # bonus [B, H, K]\n o, # output [B, H, T, V]\n h0, # initial hidden state initialization [B, H, K, V]\n ht, # final hidden state [B, H, K, V]\n s_k_h, # stride size: T * K\n s_v_h, # stride size: T * V\n scale, # K ** -0.5\n B: tl.constexpr,\n H: tl.constexpr,\n T: tl.constexpr,\n K: tl.constexpr,\n V: tl.constexpr,\n BK: tl.constexpr, # BLOCK SIZE along the K dimension\n BV: tl.constexpr, # BLOCK SIZE along the V dimension\n USE_INITIAL_STATE: tl.constexpr, # whether to use initial state\n STORE_FINAL_STATE: tl.constexpr, # whether to store final state\n REVERSE: tl.constexpr, # whether to do autoregressive modeling in the reverse direction\n):\n TargetDType = tl.bfloat16\n i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n i_h = i_bh % H\n\n p_q = q + i_bh * s_k_h + i_k * BK + tl.arange(0, BK) + ((T-1) * K if REVERSE else 0)\n p_k = k + i_bh * s_k_h + i_k * BK + tl.arange(0, BK) + ((T-1) * K if REVERSE else 0)\n p_v = v + i_bh * s_v_h + i_v * BV + tl.arange(0, BV) + ((T-1) * V if REVERSE else 0)\n p_o = o + (i_bh + i_k * B * H) * s_v_h + i_v * BV + tl.arange(0, BV) + ((T-1) * V if REVERSE else 0)\n p_w = w + i_bh * s_k_h + i_k * BK + tl.arange(0, BK) + ((T-1) * K if REVERSE else 0)\n p_u = u + i_h * K + tl.arange(0, BK) + i_k * BK\n\n mask_bk = (i_k * BK + tl.arange(0, BK)) < K\n mask_bv = (i_v * BV + tl.arange(0, BV)) < V\n mask_kv = mask_bv[:, None] & mask_bk[None, :]\n\n b_h = tl.zeros([BV, BK], dtype=TargetDType)\n if USE_INITIAL_STATE:\n p_h0 = h0 + i_bh * K * V + (i_k * BK + tl.arange(0, BK)[None, :]) * V + (i_v * BV + tl.arange(0, BV)[:, None])\n b_h += tl.load(p_h0, mask=mask_kv, other=0).to(TargetDType)\n\n b_u = tl.load(p_u, mask=mask_bk, other=0).to(TargetDType)\n for _ in range(0, T):\n b_k = tl.load(p_k, mask=mask_bk, other=0).to(TargetDType)\n b_v = tl.load(p_v, mask=mask_bv, other=0).to(TargetDType)\n b_q = tl.load(p_q, mask=mask_bk, other=0).to(TargetDType) * scale\n b_w = tl.load(p_w, mask=mask_bk, other=0).to(TargetDType)\n b_w = tl.exp(b_w.to(tl.float32)).to(TargetDType)\n b_kv = b_k[None, :] * b_v[:, None]\n b_o = (b_h + b_kv * b_u[None, :]) * b_q[None, :]\n b_o = tl.sum(b_o, axis=1)\n b_h = b_h * b_w[None, :]\n b_h += b_kv\n tl.store(p_o, b_o.to(p_o.dtype.element_ty), mask=mask_bv)\n p_q += -K if REVERSE else K\n p_k += -K if REVERSE else K\n p_o += -V if REVERSE else V\n p_v += -V if REVERSE else V\n p_w += -K if REVERSE else K\n\n if STORE_FINAL_STATE:\n p_ht = ht + i_bh * K * V + (i_k * BK + tl.arange(0, BK)[None, :]) * V + (i_v * BV + tl.arange(0, BV)[:, None])\n tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_kv)\n\nclass FusedRecurrentRWKV6Function(torch.autograd.Function):\n\n @staticmethod\n @custom_fwd\n def forward(ctx, r, k, v, w, u, scale=None, initial_state=None, output_final_state=False, reverse=False):\n q = r\n B, H, T, K, V = *q.shape, v.shape[-1]\n\n BK, BV = min(triton.next_power_of_2(K), 128), min(triton.next_power_of_2(V), 128)\n NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)\n num_stages = 1\n num_warps = 1\n\n if output_final_state:\n final_state = q.new_empty(B, H, K, V)\n else:\n final_state = None\n\n o = q.new_empty(NK, B, H, T, V, dtype=torch.bfloat16)\n grid = (NV, NK, B * H)\n fused_recurrent_rwkv6_fwd_kernel[grid](\n q, k, v, w, u, o, initial_state, final_state,\n k.stride(1),\n v.stride(1),\n scale,\n B=B, H=H, T=T, K=K, V=V, BK=BK, BV=BV,\n USE_INITIAL_STATE=initial_state is not None,\n STORE_FINAL_STATE=final_state is not None,\n REVERSE=reverse,\n num_warps=num_warps,\n num_stages=num_stages\n )\n\n o = o.sum(0)\n if final_state is not None:\n final_state = final_state.detach()\n return o.to(q.dtype), final_state\n\ndef fused_recurrent_rwkv6(\n r: torch.Tensor,\n k: torch.Tensor,\n v: torch.Tensor,\n w: torch.Tensor,\n u: torch.Tensor,\n scale: float = -1,\n initial_state: torch.Tensor = None,\n output_final_state: bool = False,\n causal: bool = True\n) -> Tuple[torch.Tensor, torch.Tensor]:\n if scale == -1:\n scale = r.shape[-1] ** -0.5\n o, final_state = FusedRecurrentRWKV6Function.apply(r, k, v, w, u, scale, initial_state, output_final_state)\n return o, final_state\n", - "description_1": "Use triton language to implement a fused recurrent operation with RWKV pattern for a set of inputs and parameters, allowing for hidden state manipulation and directional control.", - "description_2": "Use triton language to define and execute a fused recurrent kernel with custom data types and block sizes.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\nfrom torch.cuda.amp import custom_bwd, custom_fwd\n\n@triton.jit\ndef chunk_simple_gla_fwd_kernel_h(\n k,\n v,\n h,\n g,\n initial_state,\n final_state,\n s_qk_h,\n s_qk_t,\n s_qk_d,\n s_vo_h,\n s_vo_t,\n s_vo_d,\n s_h_h,\n s_h_t,\n H: tl.constexpr,\n T: tl.constexpr,\n K: tl.constexpr,\n V: tl.constexpr,\n BT: tl.constexpr,\n BK: tl.constexpr,\n BV: tl.constexpr,\n NT: tl.constexpr,\n USE_INITIAL_STATE: tl.constexpr,\n STORE_FINAL_STATE: tl.constexpr\n):\n i_k, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n b_h = tl.zeros([BK, BV], dtype=tl.float32)\n\n if USE_INITIAL_STATE:\n p_h0 = tl.make_block_ptr(initial_state + i_bh * K * V,\n (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))\n b_h = tl.load(p_h0, boundary_check=(0, 1)).to(tl.float32)\n\n for i_t in range(NT):\n p_k = tl.make_block_ptr(\n k + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1))\n p_v = tl.make_block_ptr(\n v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))\n p_h = tl.make_block_ptr(h + i_bh * s_h_h + i_t * K * V,\n (K, V), (s_h_t, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))\n\n tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1))\n b_k = tl.load(p_k, boundary_check=(0, 1))\n b_v = tl.load(p_v, boundary_check=(0, 1))\n b_g_last = tl.load(g + i_bh * T + i_t * BT + BT - 1)\n b_h *= tl.math.exp2(b_g_last)\n b_g = tl.load(g + i_bh * T + i_t * BT + tl.arange(0, BT))\n b_h += tl.dot(b_k, (b_v * tl.math.exp2(b_g_last - b_g)[:, None]).to(b_k.dtype), allow_tf32=False)\n\n if STORE_FINAL_STATE:\n p_ht = tl.make_block_ptr(\n final_state + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))\n tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), boundary_check=(0, 1))\n\n\n@triton.jit\ndef chunk_simple_gla_fwd_kernel_o(\n q,\n k,\n v,\n h,\n g,\n o,\n s_qk_h,\n s_qk_t,\n s_qk_d,\n s_vo_h,\n s_vo_t,\n s_vo_d,\n s_h_h,\n s_h_t,\n scale,\n H: tl.constexpr,\n T: tl.constexpr,\n K: tl.constexpr,\n V: tl.constexpr,\n BT: tl.constexpr,\n BK: tl.constexpr,\n BV: tl.constexpr\n):\n i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n\n o_i = tl.arange(0, BT)\n m_s = o_i[:, None] >= o_i[None, :]\n\n b_o = tl.zeros([BT, BV], dtype=tl.float32)\n b_s = tl.zeros([BT, BT], dtype=tl.float32)\n for i_k in range(tl.cdiv(K, BK)):\n p_q = tl.make_block_ptr(\n q + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))\n p_k = tl.make_block_ptr(\n k + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1))\n p_h = tl.make_block_ptr(h + i_bh * s_h_h + i_t * K * V,\n (K, V), (s_h_t, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))\n\n b_q = tl.load(p_q, boundary_check=(0, 1))\n b_k = tl.load(p_k, boundary_check=(0, 1))\n b_h = tl.load(p_h, boundary_check=(0, 1))\n b_o += tl.dot(b_q, b_h, allow_tf32=False)\n b_s += tl.dot(b_q, b_k, allow_tf32=False)\n\n p_g = g + i_bh * T + i_t * BT + tl.arange(0, BT)\n b_g = tl.load(p_g)\n b_o = b_o * tl.math.exp2(b_g)[:, None]\n b_s = b_s * tl.math.exp2(b_g[:, None] - b_g[None, :])\n b_s = tl.where(m_s, b_s, 0)\n\n p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V),\n (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))\n b_v = tl.load(p_v, boundary_check=(0, 1))\n b_o = (b_o + tl.dot(b_s.to(b_v.dtype), b_v, allow_tf32=False)) * scale\n p_o = tl.make_block_ptr(o + i_bh * s_vo_h, (T, V),\n (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))\n tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))\n\n\n@triton.jit\ndef chunk_simple_gla_bwd_kernel_dh(\n q,\n g,\n do,\n dh,\n s_qk_h,\n s_qk_t,\n s_qk_d,\n s_vo_h,\n s_vo_t,\n s_vo_d,\n s_h_h,\n s_h_t,\n scale,\n H: tl.constexpr,\n T: tl.constexpr,\n K: tl.constexpr,\n V: tl.constexpr,\n BT: tl.constexpr,\n BK: tl.constexpr,\n BV: tl.constexpr,\n NT: tl.constexpr\n):\n i_k, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n\n b_dh = tl.zeros([BK, BV], dtype=tl.float32)\n for i_t in range(NT - 1, -1, -1):\n p_q = tl.make_block_ptr(\n q + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1))\n p_do = tl.make_block_ptr(\n do + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))\n p_dh = tl.make_block_ptr(dh + i_bh * s_h_h + i_t * K * V,\n (K, V), (s_h_t, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))\n\n tl.store(p_dh, b_dh.to(p_dh.dtype.element_ty), boundary_check=(0, 1))\n b_q = tl.load(p_q, boundary_check=(0, 1))\n b_q = (b_q * scale * tl.math.exp2(tl.load(g + i_bh * T +\n i_t * BT + tl.arange(0, BT)))[None, :]).to(b_q.dtype)\n b_do = tl.load(p_do, boundary_check=(0, 1))\n b_dh *= tl.math.exp2(tl.load(g + i_bh * T + i_t * BT + BT - 1))\n b_dh += tl.dot(b_q, b_do.to(b_q.dtype), allow_tf32=False)\n\n\n@triton.jit\ndef chunk_simple_gla_bwd_kernel_dqkv(\n q,\n k,\n v,\n h,\n g,\n do,\n dh,\n dq,\n dk,\n dv,\n s_qk_h,\n s_qk_t,\n s_qk_d,\n s_vo_h,\n s_vo_t,\n s_vo_d,\n s_h_h,\n s_h_t,\n scale,\n B: tl.constexpr,\n H: tl.constexpr,\n T: tl.constexpr,\n K: tl.constexpr,\n V: tl.constexpr,\n BT: tl.constexpr,\n BK: tl.constexpr,\n BV: tl.constexpr,\n NT: tl.constexpr\n):\n i_k, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n n_bh = tl.num_programs(2)\n o_i = tl.arange(0, BT)\n\n p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (K, T),\n (s_qk_d, s_qk_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1))\n p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K),\n (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))\n\n b_q = tl.load(p_q, boundary_check=(0, 1))\n b_k = tl.load(p_k, boundary_check=(0, 1))\n b_s = tl.dot(b_k, b_q, allow_tf32=False)\n p_g = g + i_bh * T + i_t * BT + tl.arange(0, BT)\n b_g = tl.load(p_g)\n b_g_last = tl.load(g + i_bh * T + i_t * BT + BT - 1)\n mask = tl.math.exp2(b_g[None, :] - b_g[:, None])\n mask = tl.where(o_i[:, None] <= o_i[None, :], mask * scale, 0)\n b_s = b_s * mask\n\n b_dq = tl.zeros([BT, BK], dtype=tl.float32)\n b_dk = tl.zeros([BT, BK], dtype=tl.float32)\n b_ds = tl.zeros([BT, BT], dtype=tl.float32)\n for i_v in range(tl.cdiv(V, BV)):\n p_v = tl.make_block_ptr(\n v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))\n p_h = tl.make_block_ptr(h + i_bh * s_h_h, (V, NT * K), (1, s_h_t),\n (i_v * BV, i_t * K + i_k * BK), (BV, BK), (0, 1))\n p_do = tl.make_block_ptr(\n do + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))\n p_dh = tl.make_block_ptr(dh + i_bh * s_h_h, (NT * K, V),\n (s_h_t, 1), (i_t * K + i_k * BK, i_v * BV), (BK, BV), (1, 0))\n p_dv = tl.make_block_ptr(dv + (i_k*n_bh+i_bh)*s_vo_h, (T, V),\n (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))\n b_v = tl.load(p_v, boundary_check=(0, 1))\n b_do = tl.load(p_do, boundary_check=(0, 1))\n b_h = tl.load(p_h, boundary_check=(0, 1))\n b_dh = tl.load(p_dh, boundary_check=(0, 1))\n b_ds += tl.dot(b_do, tl.trans(b_v), allow_tf32=False)\n b_dq += tl.dot(b_do, b_h, allow_tf32=False) * scale\n b_dk += tl.dot(b_v, tl.trans(b_dh), allow_tf32=False)\n b_dv = tl.dot(b_k, b_dh, allow_tf32=False) * tl.math.exp2(-b_g + b_g_last)[:, None] + \\\n tl.dot(b_s.to(b_q.dtype), b_do, allow_tf32=False)\n tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))\n\n b_dq = b_dq * tl.math.exp2(b_g)[:, None]\n b_dk = b_dk * tl.math.exp2(-b_g + b_g_last)[:, None]\n b_ds = b_ds * tl.trans(mask)\n b_ds = b_ds.to(b_k.dtype)\n b_dq += tl.dot(b_ds, b_k, allow_tf32=False)\n b_dk += tl.trans(tl.dot(b_q, b_ds, allow_tf32=False))\n p_dq = tl.make_block_ptr(dq + i_bh * s_qk_h, (T, K),\n (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))\n p_dk = tl.make_block_ptr(dk + i_bh * s_qk_h, (T, K),\n (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))\n tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))\n tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))\n\n\nclass SimpleGLAFunction(torch.autograd.Function):\n\n @staticmethod\n @custom_fwd\n def forward(ctx, q, k, v, g, initial_state, output_final_state):\n B, H, T, K, V = *q.shape, v.shape[-1]\n BT = 64\n BK, BV = min(64, triton.next_power_of_2(K)), min(\n 64, triton.next_power_of_2(V))\n NT, NK, NV = triton.cdiv(T, BT), triton.cdiv(K, BK), triton.cdiv(V, BV)\n num_stages = 1\n num_warps = 4 if BK == 64 else 2\n scale = K ** -0.5\n\n BT = 64\n assert T % BT == 0, 'sequence length must be divisible by BT'\n g = g.reshape(B, H, -1, BT)\n g = g.cumsum(-1) * 1.44269504\n g = g.reshape(B, H, -1)\n\n final_state = None\n if output_final_state:\n final_state = q.new_empty(B, H, K, V, dtype=torch.float32, requires_grad=False)\n\n h = q.new_empty(B, H, NT * K, V)\n grid = (NK, NV, B * H)\n chunk_simple_gla_fwd_kernel_h[grid](\n k, v, h, g, initial_state, final_state,\n q.stride(1), q.stride(2), q.stride(3),\n v.stride(1), v.stride(2), v.stride(3),\n h.stride(1), h.stride(2),\n H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT,\n USE_INITIAL_STATE=initial_state is not None,\n STORE_FINAL_STATE=output_final_state,\n num_warps=num_warps,\n num_stages=num_stages\n )\n grid = (NV, NT, B * H)\n o = torch.empty_like(v)\n chunk_simple_gla_fwd_kernel_o[grid](\n q, k, v, h, g, o,\n q.stride(1), q.stride(2), q.stride(3),\n v.stride(1), v.stride(2), v.stride(3),\n h.stride(1), h.stride(2),\n scale,\n H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV,\n num_warps=num_warps,\n num_stages=num_stages\n )\n\n ctx.save_for_backward(q, k, v, h, g)\n return o.to(q.dtype), final_state\n\n @staticmethod\n @custom_bwd\n def backward(ctx, do, d_ht=None):\n q, k, v, h, g = ctx.saved_tensors\n\n B, H, T, K, V = *q.shape, v.shape[-1]\n BT = 64\n BK, BV = min(32 if q.dtype == torch.float32 else 64, triton.next_power_of_2(K)), min(\n 32 if q.dtype == torch.float32 else 64, triton.next_power_of_2(V))\n NT, NK, NV = triton.cdiv(T, BT), triton.cdiv(K, BK), triton.cdiv(V, BV)\n num_stages = 1\n num_warps = 4 if BK == 64 else 2\n scale = K ** -0.5\n\n dh = q.new_empty(B, H, NT * K, V)\n grid = (NK, NV, B * H)\n chunk_simple_gla_bwd_kernel_dh[grid](\n q, g, do, dh,\n q.stride(1), q.stride(2), q.stride(3),\n v.stride(1), v.stride(2), v.stride(3),\n dh.stride(1), dh.stride(2),\n scale,\n H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT,\n num_warps=num_warps,\n num_stages=num_stages\n )\n grid = (NK, NT, B * H)\n dq = torch.empty_like(q)\n dk = torch.empty_like(k)\n dv = v.new_empty(NK, *v.shape)\n num_stages = 1\n num_warps = 4 if BK == 64 else 2\n chunk_simple_gla_bwd_kernel_dqkv[grid](\n q, k, v, h, g, do, dh, dq, dk, dv,\n q.stride(1), q.stride(2), q.stride(3),\n v.stride(1), v.stride(2), v.stride(3),\n dh.stride(1), dh.stride(2),\n scale,\n B=B, H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT,\n num_warps=num_warps,\n num_stages=num_stages\n )\n dv = dv.sum(0)\n dg = (dq * q - dk * k).sum(-1)\n\n def rev_cumsum(x):\n cumsum_x = x.cumsum(-1)\n rev_cumsum_x = cumsum_x[..., -1, None] - cumsum_x\n return rev_cumsum_x + x\n dg = rev_cumsum(dg)\n return dq.to(q.dtype), dk.to(k.dtype), dv.to(v.dtype), dg.to(g.dtype), None, None\n\n\ndef chunk_simple_gla(\n q: torch.Tensor,\n k: torch.Tensor,\n v: torch.Tensor,\n g: torch.Tensor,\n initial_state: torch.Tensor = None,\n output_final_state: bool = False\n) -> Tuple[torch.Tensor, torch.Tensor]:\n if initial_state is not None:\n initial_state = initial_state.detach()\n g = g.float()\n o, final_state = SimpleGLAFunction.apply(q, k, v, g, initial_state, output_final_state)\n return o, final_state\n", - "description_1": "Use triton language to implement a series of kernels for a generalized linear attention mechanism. The kernels include forward and backward passes for handling input tensors q, k, v, and g, with optional initial and final states. The forward kernels compute intermediate states and outputs, while the backward kernels compute gradients for q, k, v, and g. The kernels are optimized for specific block sizes and use triton's block pointer and program id features.", - "description_2": "Use triton language to create kernels for forward and backward passes of a linear attention mechanism, handling tensors q, k, v, and g, with optional state management.", - "difficulty": 5 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\nfrom typing import Optional\n\n@triton.autotune(\n configs=[\n triton.Config({'BT': 16}, num_warps=2),\n triton.Config({'BT': 16}, num_warps=4),\n triton.Config({'BT': 16}, num_warps=8),\n triton.Config({'BT': 32}, num_warps=2),\n triton.Config({'BT': 32}, num_warps=4),\n triton.Config({'BT': 32}, num_warps=8),\n triton.Config({'BT': 64}, num_warps=2),\n triton.Config({'BT': 64}, num_warps=4),\n triton.Config({'BT': 64}, num_warps=8),\n ],\n key=['S']\n)\n@triton.jit\ndef chunk_cumsum_fwd_kernel(\n s,\n z,\n s_s_h,\n s_s_t,\n s_s_d,\n T: tl.constexpr,\n S: tl.constexpr,\n BT: tl.constexpr,\n BS: tl.constexpr\n):\n i_s, i_bh = tl.program_id(0), tl.program_id(1)\n o_i = tl.arange(0, BT)\n m_s = tl.where(o_i[:, None] >= o_i[None, :], 1., 0.)\n\n b_z = tl.zeros([BS], dtype=tl.float32)\n for i_t in range(tl.cdiv(T, BT)):\n p_s = tl.make_block_ptr(s + i_bh * s_s_h, (T, S), (s_s_t, s_s_d), (i_t * BT, i_s * BS), (BT, BS), (1, 0))\n p_z = tl.make_block_ptr(z + i_bh * s_s_h, (T, S), (s_s_t, s_s_d), (i_t * BT, i_s * BS), (BT, BS), (1, 0))\n # [BT, BS]\n b_s = tl.load(p_s, boundary_check=(0, 1)).to(tl.float32)\n b_c = b_z[None, :] + tl.dot(m_s, b_s, allow_tf32=False)\n tl.store(p_z, b_c.to(p_z.dtype.element_ty), boundary_check=(0, 1))\n\n if i_t >= 0:\n b_z += tl.sum(b_s, 0)\n\n\n@triton.autotune(\n configs=[\n triton.Config({'BT': 16}, num_warps=2),\n triton.Config({'BT': 16}, num_warps=4),\n triton.Config({'BT': 16}, num_warps=8),\n triton.Config({'BT': 32}, num_warps=2),\n triton.Config({'BT': 32}, num_warps=4),\n triton.Config({'BT': 32}, num_warps=8),\n triton.Config({'BT': 64}, num_warps=2),\n triton.Config({'BT': 64}, num_warps=4),\n triton.Config({'BT': 64}, num_warps=8),\n ],\n key=['S']\n)\n@triton.jit\ndef chunk_cumsum_bwd_kernel(\n ds,\n dz,\n s_s_h,\n s_s_t,\n s_s_d,\n T: tl.constexpr,\n S: tl.constexpr,\n BT: tl.constexpr,\n BS: tl.constexpr\n):\n i_s, i_bh = tl.program_id(0), tl.program_id(1)\n o_i = tl.arange(0, BT)\n m_s = tl.where(o_i[:, None] <= o_i[None, :], 1., 0.)\n\n b_ds = tl.zeros([BS], dtype=tl.float32)\n for i_t in range(tl.cdiv(T, BT) - 1, -1, -1):\n p_ds = tl.make_block_ptr(ds + i_bh * s_s_h, (T, S), (s_s_t, s_s_d), (i_t * BT, i_s * BS), (BT, BS), (1, 0))\n p_dz = tl.make_block_ptr(dz + i_bh * s_s_h, (T, S), (s_s_t, s_s_d), (i_t * BT, i_s * BS), (BT, BS), (1, 0))\n # [BT, BS]\n b_dz = tl.load(p_dz, boundary_check=(0, 1)).to(tl.float32)\n b_c = b_ds[None, :] + tl.dot(m_s, b_dz, allow_tf32=False)\n tl.store(p_ds, b_c.to(p_ds.dtype.element_ty), boundary_check=(0, 1))\n\n if i_t >= 0:\n b_ds += tl.sum(b_dz, 0)\n\n\ndef chunk_cumsum_fwd(\n s: torch.Tensor,\n dtype: Optional[torch.dtype] = None,\n) -> torch.Tensor:\n B, H, T, S = s.shape\n BS = 32\n\n dtype = dtype or s.dtype\n grid = (triton.cdiv(S, BS), B * H)\n z = torch.empty_like(s, dtype=dtype)\n chunk_cumsum_fwd_kernel[grid](\n s, z,\n s.stride(1), s.stride(2), s.stride(3),\n T=T, S=S, BS=BS\n )\n return z\n\n\ndef chunk_cumsum_bwd(\n dz: torch.Tensor,\n dtype: Optional[torch.dtype] = None,\n) -> torch.Tensor:\n B, H, T, S = dz.shape\n BS = 32\n\n dtype = dtype or dz.dtype\n grid = (triton.cdiv(S, BS), B * H)\n ds = torch.empty_like(dz, dtype=dtype)\n chunk_cumsum_bwd_kernel[grid](\n ds, dz,\n ds.stride(1), ds.stride(2), ds.stride(3),\n T=T, S=S, BS=BS\n )\n return ds\n\n\nclass CumsumFunction(torch.autograd.Function):\n\n @staticmethod\n def forward(ctx, s, dtype):\n z = chunk_cumsum_fwd(s, dtype)\n ctx.dtype = dtype\n return z\n\n @staticmethod\n def backward(ctx, dz):\n ds = chunk_cumsum_bwd(dz, ctx.dtype)\n return ds, None\n\n\ndef cumsum(\n s: torch.Tensor,\n dtype: Optional[torch.dtype] = None,\n) -> torch.Tensor:\n return CumsumFunction.apply(s, dtype)\n", - "description_1": "Use triton language to implement forward and backward kernels for chunk-based cumulative sum operations. The forward kernel 'chunk_cumsum_fwd_kernel' has parameters: s (input tensor), z (output tensor), s_s_h, s_s_t, s_s_d (stride values), T, S, BT, and BS (block sizes). The backward kernel 'chunk_cumsum_bwd_kernel' has parameters: ds (input gradient tensor), dz (output gradient tensor), s_s_h, s_s_t, s_s_d (stride values), T, S, BT, and BS (block sizes). The 'chunk_cumsum_fwd' function prepares the grid and launches the forward kernel, while 'chunk_cumsum_bwd' does similarly for the backward kernel. Both operate over 4D tensors of dimensions (B, H, T, S).", - "description_2": "Use triton language to perform chunk-based cumulative sum forward and backward operations on 4D tensors with specific stride and block size configurations.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\nfrom fla.utils import contiguous\n\n@triton.autotune(\n configs=[\n triton.Config({'BT': 16}, num_warps=2),\n triton.Config({'BT': 16}, num_warps=4),\n triton.Config({'BT': 16}, num_warps=8),\n triton.Config({'BT': 32}, num_warps=2),\n triton.Config({'BT': 32}, num_warps=4),\n triton.Config({'BT': 32}, num_warps=8),\n triton.Config({'BT': 64}, num_warps=2),\n triton.Config({'BT': 64}, num_warps=4),\n triton.Config({'BT': 64}, num_warps=8),\n triton.Config({'BT': 128}, num_warps=2),\n triton.Config({'BT': 128}, num_warps=4),\n triton.Config({'BT': 128}, num_warps=8),\n triton.Config({'BT': 256}, num_warps=2),\n triton.Config({'BT': 256}, num_warps=4),\n triton.Config({'BT': 256}, num_warps=8)\n ],\n key=['D']\n)\n@triton.jit\ndef logsigmoid_fwd_kernel(\n x,\n y,\n T: tl.constexpr,\n D: tl.constexpr,\n BT: tl.constexpr\n):\n i = tl.program_id(0)\n o_i = i * BT + tl.arange(0, BT)\n\n p_x = x + o_i\n p_y = y + o_i\n mask = o_i < T\n\n # [D,]\n b_x = tl.load(p_x, mask=mask, other=0.).to(tl.float32)\n b_m = tl.minimum(0., b_x)\n b_z = 1. + tl.exp(-tl.abs(b_x))\n b_y = b_m - tl.log(b_z)\n tl.store(p_y, b_y.to(p_y.dtype.element_ty), mask=mask)\n\n\n@triton.autotune(\n configs=[\n triton.Config({'BT': 16}, num_warps=2),\n triton.Config({'BT': 16}, num_warps=4),\n triton.Config({'BT': 16}, num_warps=8),\n triton.Config({'BT': 32}, num_warps=2),\n triton.Config({'BT': 32}, num_warps=4),\n triton.Config({'BT': 32}, num_warps=8),\n triton.Config({'BT': 64}, num_warps=2),\n triton.Config({'BT': 64}, num_warps=4),\n triton.Config({'BT': 64}, num_warps=8),\n triton.Config({'BT': 128}, num_warps=2),\n triton.Config({'BT': 128}, num_warps=4),\n triton.Config({'BT': 128}, num_warps=8),\n triton.Config({'BT': 256}, num_warps=2),\n triton.Config({'BT': 256}, num_warps=4),\n triton.Config({'BT': 256}, num_warps=8)\n ],\n key=['D']\n)\n@triton.jit\ndef logsigmoid_bwd_kernel(\n x,\n dx,\n dy,\n T: tl.constexpr,\n D: tl.constexpr,\n BT: tl.constexpr\n):\n i = tl.program_id(0)\n o_i = i * BT + tl.arange(0, BT)\n\n p_x = x + o_i\n p_dx = dx + o_i\n p_dy = dy + o_i\n mask = o_i < T\n\n # [D,]\n b_x = tl.load(p_x, mask=mask, other=0.).to(tl.float32)\n b_dy = tl.load(p_dy, mask=mask, other=0.).to(tl.float32)\n b_dx = b_dy * (1. - tl.sigmoid(b_x))\n tl.store(p_dx, b_dx.to(p_dx.dtype.element_ty), mask=mask)\n\n\nclass LogSigmoidFunction(torch.autograd.Function):\n\n @staticmethod\n @contiguous\n def forward(ctx, x):\n T, D = x.numel(), x.shape[-1]\n y = torch.empty_like(x)\n logsigmoid_fwd_kernel[lambda meta: (triton.cdiv(meta['T'], meta['D']),)](x, y, T=T, D=D)\n ctx.save_for_backward(x,)\n return y\n\n @staticmethod\n @contiguous\n def backward(ctx, dy):\n x, = ctx.saved_tensors\n T, D = x.numel(), x.shape[-1]\n dx = torch.empty_like(x)\n logsigmoid_bwd_kernel[lambda meta: (triton.cdiv(meta['T'], meta['D']),)](x, dx, dy, T=T, D=D)\n return dx\n\n\nlogsigmoid = LogSigmoidFunction.apply\n", - "description_1": "Use triton language to create two kernels, logsigmoid_fwd_kernel and logsigmoid_bwd_kernel, for forward and backward log-sigmoid operations. The forward kernel computes the log-sigmoid of input tensor x and stores the result in tensor y. The backward kernel computes the gradient with respect to x and stores it in tensor dx using the input gradient dy. These kernels are configured with various block sizes using triton.autotune. A PyTorch autograd Function LogSigmoidFunction is implemented to use these kernels in forward and backward passes.", - "description_2": "Use triton language to create logsigmoid forward and backward kernels with autotuning, and implement a PyTorch Function to utilize these kernels.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.heuristics(\n {\n \"HAS_SMOOTHING\": lambda args: args[\"smoothing\"] > 0.0,\n }\n)\n@triton.jit\ndef cross_entropy_fwd_kernel(\n loss_ptr, # data ptrs\n lse_ptr,\n z_loss_ptr,\n logits_ptr,\n labels_ptr,\n smoothing,\n logit_scale,\n lse_square_scale,\n ignored_index,\n total_classes,\n class_start_idx, # Useful for tensor parallel when each rank only has a subset of classes\n n_cols, # shapes\n n_rows,\n logits_row_stride, # strides\n BLOCK_SIZE: tl.constexpr,\n HAS_SMOOTHING: tl.constexpr,\n SPLIT: tl.constexpr,\n):\n # Kernel implementation here\n\n@triton.heuristics(\n {\n \"HAS_SMOOTHING\": lambda args: args[\"smoothing\"] > 0.0,\n }\n)\n@triton.jit\ndef cross_entropy_bwd_kernel(\n dlogits_ptr, # data ptrs\n dloss_ptr,\n logits_ptr,\n lse_ptr,\n labels_ptr,\n smoothing,\n logit_scale,\n lse_square_scale,\n ignored_index,\n total_classes,\n class_start_idx, # Useful for tensor parallel when each rank only has a subset of classes\n n_cols, # shapes\n logits_row_stride, # strides\n dlogits_row_stride,\n dloss_row_stride,\n BLOCK_SIZE: tl.constexpr,\n HAS_SMOOTHING: tl.constexpr,\n):\n # Kernel implementation here\n\nclass CrossEntropyLossFunction(torch.autograd.Function):\n\n @staticmethod\n def forward(\n ctx,\n logits,\n labels,\n smoothing=0.0,\n logit_scale=1.0,\n lse_square_scale=0.0,\n ignored_index=-100,\n inplace_backward=False,\n process_group=None,\n ):\n # Forward implementation here\n with torch.cuda.device(logits.device.index):\n cross_entropy_fwd_kernel[(n_rows, n_splits)](\n # Kernel call with all parameters\n )\n\n @staticmethod\n def backward(ctx, grad_losses, grad_z_losses):\n del grad_z_losses\n\n logits, lse, labels = ctx.saved_tensors\n dlogits = logits if ctx.inplace_backward else torch.empty_like(logits)\n BLOCK_SIZE = min(triton.next_power_of_2(n_cols), 4 * 1024)\n num_warps = 4 if BLOCK_SIZE < 2048 else (8 if BLOCK_SIZE < 8192 else 16)\n \n def grid(META): \n return (n_rows, triton.cdiv(n_cols, META[\"BLOCK_SIZE\"])) # noqa\n \n with torch.cuda.device(logits.device.index):\n cross_entropy_bwd_kernel[grid](\n dlogits, # Kernel call with all parameters\n )\n return dlogits, None, None, None, None, None, None, None, None\n\n\ndef cross_entropy_loss(\n logits: torch.Tensor,\n labels: torch.Tensor,\n label_smoothing: float = 0.0,\n logit_scale: float = 1.0,\n lse_square_scale: float = 0.0,\n ignored_index=-100,\n inplace_backward: bool = False,\n process_group=None,\n) -> Tuple[torch.Tensor, torch.Tensor]:\n return CrossEntropyLossFunction.apply(\n logits,\n labels,\n label_smoothing,\n logit_scale,\n lse_square_scale,\n ignored_index,\n inplace_backward,\n process_group,\n )\n", - "description_1": "Use triton language to define two kernels for cross-entropy loss computation. The first kernel, `cross_entropy_fwd_kernel`, computes the forward pass of cross-entropy loss with optional label smoothing and tensor parallel capabilities. It takes parameters such as pointers to data, smoothing factor, logit scaling, ignored index, and constants for configuration. The second kernel, `cross_entropy_bwd_kernel`, computes the backward pass for the gradients with respect to the logits. It handles label smoothing and various dimensions. The calling function `cross_entropy_loss` uses a custom autograd function to apply these kernels on given input tensors, managing both forward and backward operations.", - "description_2": "Use triton language to create forward and backward kernels for calculating cross-entropy loss with support for label smoothing and parallel processing, then integrate them into a PyTorch autograd function for automatic differentiation.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.autotune(\n configs=[\n triton.Config({}, num_warps=1),\n triton.Config({}, num_warps=2),\n triton.Config({}, num_warps=4),\n triton.Config({}, num_warps=8),\n triton.Config({}, num_warps=16),\n triton.Config({}, num_warps=32),\n ],\n key=[\"N\", \"HAS_RESIDUAL\", \"STORE_RESIDUAL_OUT\", \"IS_RMS_NORM\", \"HAS_BIAS\"],\n)\n@triton.jit\ndef _layer_norm_fwd_1pass_kernel(\n X, O, Y, W, B, RESIDUAL, RESIDUAL_OUT, Mean, Rstd,\n stride_x_row, stride_y_row, stride_res_row, stride_res_out_row,\n N, eps, IS_RMS_NORM: tl.constexpr, BLOCK_N: tl.constexpr,\n HAS_RESIDUAL: tl.constexpr, STORE_RESIDUAL_OUT: tl.constexpr,\n HAS_WEIGHT: tl.constexpr, HAS_BIAS: tl.constexpr\n):\n row = tl.program_id(0)\n X += row * stride_x_row\n Y += row * stride_y_row\n O += row * stride_x_row\n if HAS_RESIDUAL:\n RESIDUAL += row * stride_res_row\n if STORE_RESIDUAL_OUT:\n RESIDUAL_OUT += row * stride_res_out_row\n cols = tl.arange(0, BLOCK_N)\n x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)\n if HAS_RESIDUAL:\n residual = tl.load(RESIDUAL + cols, mask=cols < N, other=0.0).to(tl.float32)\n x += residual\n if STORE_RESIDUAL_OUT:\n tl.store(RESIDUAL_OUT + cols, x, mask=cols < N)\n if not IS_RMS_NORM:\n mean = tl.sum(x, axis=0) / N\n tl.store(Mean + row, mean)\n xbar = tl.where(cols < N, x - mean, 0.0)\n var = tl.sum(xbar * xbar, axis=0) / N\n else:\n xbar = tl.where(cols < N, x, 0.0)\n var = tl.sum(xbar * xbar, axis=0) / N\n rstd = 1 / tl.sqrt(var + eps)\n tl.store(Rstd + row, rstd)\n mask = cols < N\n if HAS_WEIGHT:\n w = tl.load(W + cols, mask=mask).to(tl.float32)\n if HAS_BIAS:\n b = tl.load(B + cols, mask=mask).to(tl.float32)\n x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd\n y = x_hat * w if HAS_WEIGHT else x_hat\n if HAS_BIAS:\n y = y + b\n o = tl.load(O + cols, mask=cols < N, other=0.0).to(tl.float32)\n y = y * o * tl.sigmoid(o)\n tl.store(Y + cols, y, mask=mask)\n\ndef _layer_norm_fwd(\n x, o, weight, bias, eps, residual=None, out_dtype=None, residual_dtype=None, is_rms_norm=False\n):\n if residual is not None:\n residual_dtype = residual.dtype\n M, N = x.shape\n assert x.stride(-1) == 1\n if residual is not None:\n assert residual.stride(-1) == 1\n assert residual.shape == (M, N)\n if weight is not None:\n assert weight.shape == (N,)\n assert weight.stride(-1) == 1\n if bias is not None:\n assert bias.stride(-1) == 1\n assert bias.shape == (N,)\n y = torch.empty_like(x, dtype=x.dtype if out_dtype is None else out_dtype)\n assert y.stride(-1) == 1\n if residual is not None or (residual_dtype is not None and residual_dtype != x.dtype):\n residual_out = torch.empty(M, N, device=x.device, dtype=residual_dtype)\n assert residual_out.stride(-1) == 1\n else:\n residual_out = None\n mean = torch.empty((M,), dtype=torch.float32, device=\"cuda\") if not is_rms_norm else None\n rstd = torch.empty((M,), dtype=torch.float32, device=\"cuda\")\n MAX_FUSED_SIZE = 65536 // x.element_size()\n BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))\n if N > BLOCK_N:\n raise RuntimeError(\"This layer norm doesn't support feature dim >= 64KB.\")\n with torch.cuda.device(x.device.index):\n _layer_norm_fwd_1pass_kernel[(M,)](\n x, o, y, weight, bias, residual, residual_out, mean, rstd,\n x.stride(0), y.stride(0), residual.stride(0) if residual is not None else 0,\n residual_out.stride(0) if residual_out is not None else 0,\n N, eps, is_rms_norm, BLOCK_N, residual is not None, residual_out is not None,\n weight is not None, bias is not None,\n )\n return y, mean, rstd, residual_out if residual_out is not None else x\n", - "description_1": "Use triton language to implement a forward pass kernel for layer normalization with optional residuals, weights, and biases. The kernel computes the mean and variance for normalization, applies a linear transformation, and includes a Swish activation function. The function _layer_norm_fwd is used to set up and call this kernel with appropriate parameters.", - "description_2": "Use triton language to create a layer normalization kernel with Swish activation, supporting optional residuals, weights, and biases.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.autotune(\n configs=[\n triton.Config({}, num_warps=1),\n triton.Config({}, num_warps=2),\n triton.Config({}, num_warps=4),\n triton.Config({}, num_warps=8),\n triton.Config({}, num_warps=16),\n triton.Config({}, num_warps=32),\n ],\n key=[\"N\", \"HAS_RESIDUAL\", \"STORE_RESIDUAL_OUT\", \"IS_RMS_NORM\", \"HAS_BIAS\"],\n)\n@triton.jit\ndef _layer_norm_fwd_1pass_kernel(\n X, Y, W, B, RESIDUAL, RESIDUAL_OUT, Mean, Rstd, stride_x_row, stride_y_row, stride_res_row, stride_res_out_row, N, eps, IS_RMS_NORM: tl.constexpr, BLOCK_N: tl.constexpr, HAS_RESIDUAL: tl.constexpr, STORE_RESIDUAL_OUT: tl.constexpr, HAS_WEIGHT: tl.constexpr, HAS_BIAS: tl.constexpr\n):\n row = tl.program_id(0)\n X += row * stride_x_row\n Y += row * stride_y_row\n if HAS_RESIDUAL:\n RESIDUAL += row * stride_res_row\n if STORE_RESIDUAL_OUT:\n RESIDUAL_OUT += row * stride_res_out_row\n cols = tl.arange(0, BLOCK_N)\n x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)\n if HAS_RESIDUAL:\n residual = tl.load(RESIDUAL + cols, mask=cols < N, other=0.0).to(tl.float32)\n x += residual\n if STORE_RESIDUAL_OUT:\n tl.store(RESIDUAL_OUT + cols, x, mask=cols < N)\n if not IS_RMS_NORM:\n mean = tl.sum(x, axis=0) / N\n tl.store(Mean + row, mean)\n xbar = tl.where(cols < N, x - mean, 0.0)\n var = tl.sum(xbar * xbar, axis=0) / N\n else:\n xbar = tl.where(cols < N, x, 0.0)\n var = tl.sum(xbar * xbar, axis=0) / N\n rstd = 1 / tl.sqrt(var + eps)\n tl.store(Rstd + row, rstd)\n mask = cols < N\n if HAS_WEIGHT:\n w = tl.load(W + cols, mask=mask).to(tl.float32)\n if HAS_BIAS:\n b = tl.load(B + cols, mask=mask).to(tl.float32)\n x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd\n y = x_hat * w if HAS_WEIGHT else x_hat\n if HAS_BIAS:\n y = y + b\n tl.store(Y + cols, y, mask=mask)\n\n\ndef _layer_norm_fwd(\n x, weight, bias, eps, residual=None, out_dtype=None, residual_dtype=None, is_rms_norm=False\n):\n if residual is not None:\n residual_dtype = residual.dtype\n M, N = x.shape\n assert x.stride(-1) == 1\n if residual is not None:\n assert residual.stride(-1) == 1\n assert residual.shape == (M, N)\n if weight is not None:\n assert weight.shape == (N,)\n assert weight.stride(-1) == 1\n if bias is not None:\n assert bias.stride(-1) == 1\n assert bias.shape == (N,)\n y = torch.empty_like(x, dtype=x.dtype if out_dtype is None else out_dtype)\n assert y.stride(-1) == 1\n if residual is not None or (residual_dtype is not None and residual_dtype != x.dtype):\n residual_out = torch.empty(M, N, device=x.device, dtype=residual_dtype)\n assert residual_out.stride(-1) == 1\n else:\n residual_out = None\n mean = torch.empty((M,), dtype=torch.float32, device=\"cuda\") if not is_rms_norm else None\n rstd = torch.empty((M,), dtype=torch.float32, device=\"cuda\")\n MAX_FUSED_SIZE = 65536 // x.element_size()\n BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))\n if N > BLOCK_N:\n raise RuntimeError(\"This layer norm doesn't support feature dim >= 64KB.\")\n with torch.cuda.device(x.device.index):\n _layer_norm_fwd_1pass_kernel[(M,)](\n x,\n y,\n weight,\n bias,\n residual,\n residual_out,\n mean,\n rstd,\n x.stride(0),\n y.stride(0),\n residual.stride(0) if residual is not None else 0,\n residual_out.stride(0) if residual_out is not None else 0,\n N,\n eps,\n is_rms_norm,\n BLOCK_N,\n residual is not None,\n residual_out is not None,\n weight is not None,\n bias is not None,\n )\n return y, mean, rstd, residual_out if residual_out is not None else x\n", - "description_1": "Use triton language to implement a kernel function for forward pass of layer normalization, which handles mean and variance computation, residual addition, and normalization of input data, supporting both standard and RMS layer normalization with optional weight and bias. The function is parameterized to handle varying input sizes and configurations through constants and conditions.", - "description_2": "Use triton language to create a fused kernel for layer normalization forward pass, efficiently computing normalized outputs with optional weight and bias, supporting configurations for RMS norm and handling varying input dimensions and residuals.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n\n@triton.jit\ndef chunk_abc_fwd_kernel_h(\n k,\n v,\n z,\n h,\n h0,\n ht,\n s_k_h,\n s_k_t,\n s_k_d,\n s_v_h,\n s_v_t,\n s_v_d,\n s_h_h,\n s_h_t,\n s_h_d,\n T: tl.constexpr,\n K: tl.constexpr,\n V: tl.constexpr,\n BT: tl.constexpr,\n BK: tl.constexpr,\n BV: tl.constexpr,\n NT: tl.constexpr,\n NORMK: tl.constexpr,\n USE_INITIAL_STATE: tl.constexpr,\n STORE_FINAL_STATE: tl.constexpr\n):\n # Kernel implementation...\n\n@triton.jit\ndef chunk_abc_fwd_kernel_intra_K(\n v,\n z,\n o,\n A,\n s_v_h,\n s_v_t,\n s_v_d,\n T: tl.constexpr,\n V: tl.constexpr,\n BT: tl.constexpr,\n BC: tl.constexpr,\n BV: tl.constexpr,\n NC: tl.constexpr\n):\n # Kernel implementation...\n\n@triton.jit\ndef chunk_abc_fwd_kernel_K(\n q,\n k,\n z,\n h,\n o,\n A,\n s_k_h,\n s_k_t,\n s_k_d,\n s_v_h,\n s_v_t,\n s_v_d,\n s_h_h,\n s_h_t,\n s_h_d,\n scale,\n T: tl.constexpr,\n K: tl.constexpr,\n V: tl.constexpr,\n BT: tl.constexpr,\n BK: tl.constexpr,\n BV: tl.constexpr\n):\n # Kernel implementation...\n\n@triton.jit\ndef chunk_abc_fwd_kernel_V(\n q,\n v,\n z,\n h,\n o,\n A,\n s_k_h,\n s_k_t,\n s_k_d,\n s_v_h,\n s_v_t,\n s_v_d,\n s_h_h,\n s_h_t,\n s_h_d,\n scale,\n T: tl.constexpr,\n K: tl.constexpr,\n V: tl.constexpr,\n BT: tl.constexpr,\n BK: tl.constexpr,\n BV: tl.constexpr\n):\n # Kernel implementation...\n\n# Function to launch Triton kernels\ndef chunk_abc(\n q: torch.Tensor,\n k: torch.Tensor,\n v: torch.Tensor,\n s: torch.Tensor,\n initial_state: Optional[Tuple[torch.Tensor]] = None,\n output_final_state: Optional[bool] = False\n) -> Tuple[torch.Tensor, Tuple[torch.Tensor]]:\n if initial_state is not None:\n initial_state = tuple(i.detach() for i in initial_state)\n ov, final_state = ChunkABCFunction.apply(q, k, v, s, initial_state, output_final_state)\n return ov, final_state\n", - "description_1": "Use triton language to implement forward kernels for chunked attention computation, with multiple kernels handling different parts of the computation (e.g., intra-chunk, inter-chunk). Each kernel is parameterized with tensor inputs and strides, tensor constants for dimensions, and configuration flags for initial and final state handling. The overall operation includes computation for query-key-value interaction and state updates.", - "description_2": "Use triton language to implement backward kernels for chunked attention computation, focusing on handling gradients for intra- and inter-chunk interactions. These kernels backpropagate through the attention operation considering normalization and state transitions.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n# This kernel is used for cumulative sum operation along the input tensor.\n@triton.autotune(\n configs=[\n triton.Config({'BS': 16}, num_warps=2),\n triton.Config({'BS': 16}, num_warps=4),\n triton.Config({'BS': 16}, num_warps=8),\n triton.Config({'BS': 32}, num_warps=2),\n triton.Config({'BS': 32}, num_warps=4),\n triton.Config({'BS': 32}, num_warps=8),\n triton.Config({'BS': 64}, num_warps=2),\n triton.Config({'BS': 64}, num_warps=4),\n triton.Config({'BS': 64}, num_warps=8),\n ],\n key=['S']\n)\n@triton.jit\ndef chunk_gated_abc_fwd_kernel_cum(\n s,\n o,\n s_s_h,\n s_s_t,\n s_s_d,\n T: tl.constexpr,\n S: tl.constexpr,\n BT: tl.constexpr,\n BS: tl.constexpr,\n):\n # The kernel function performs a cumulative sum operation with gating.\n i_s, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n o_i = tl.arange(0, BT)\n m_s = tl.where(o_i[:, None] >= o_i[None, :], 1., 0.).to(tl.float32)\n\n p_s = tl.make_block_ptr(s + i_bh * s_s_h, (T, S), (s_s_t, s_s_d), (i_t * BT, i_s * BS), (BT, BS), (1, 0))\n p_o = tl.make_block_ptr(o + i_bh * s_s_h, (T, S), (s_s_t, s_s_d), (i_t * BT, i_s * BS), (BT, BS), (1, 0))\n # [BT, BS]\n b_s = tl.load(p_s, boundary_check=(0, 1)).to(tl.float32)\n b_o = tl.dot(m_s, b_s, allow_tf32=False)\n tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))\n\n\n# This function wraps the Triton kernel, preparing and launching it.\ndef fwd_pre(g, B, H, T, S, BT):\n NT = triton.cdiv(T, BT)\n g_org, g = g, torch.empty_like(g, dtype=torch.float)\n def grid(meta): return (triton.cdiv(meta['S'], meta['BS']), NT, B * H)\n # Perform a cumulative sum along a specific dimension of the input tensor.\n chunk_gated_abc_fwd_kernel_cum[grid](\n g_org, g,\n g.stride(1), g.stride(2), g.stride(3),\n T=T, S=S, BT=BT\n )\n return g\n", - "description_1": "Use triton language to implement a cumulative sum with gating on a tensor, involving kernel configurations and loading/storing data using block pointers.", - "description_2": "Implement a cumulative sum operation on a tensor using triton language, optimized with autotuned configurations for block sizes and kernel parameters.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef fused_recurrent_gated_abc_fwd_kernel(\n q,\n k,\n v,\n gk,\n gv,\n o,\n h0,\n ht,\n s_k_h,\n s_v_h,\n scale,\n B: tl.constexpr,\n H: tl.constexpr,\n T: tl.constexpr,\n K: tl.constexpr,\n V: tl.constexpr,\n BK: tl.constexpr,\n BV: tl.constexpr,\n USE_INITIAL_STATE: tl.constexpr,\n STORE_FINAL_STATE: tl.constexpr,\n REVERSE: tl.constexpr,\n USE_GK: tl.constexpr,\n USE_GV: tl.constexpr,\n):\n # indices\n i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n\n p_q = q + i_bh * s_k_h + i_k * BK + tl.arange(0, BK) + ((T-1) * K if REVERSE else 0)\n p_k = k + i_bh * s_k_h + i_k * BK + tl.arange(0, BK) + ((T-1) * K if REVERSE else 0)\n p_v = v + i_bh * s_v_h + i_v * BV + tl.arange(0, BV) + ((T-1) * V if REVERSE else 0)\n p_o = o + (i_bh + i_k * B * H) * s_v_h + i_v * BV + tl.arange(0, BV) + ((T-1) * V if REVERSE else 0)\n\n if USE_GK:\n p_gk = gk + i_bh * s_k_h + i_k * BK + tl.arange(0, BK) + ((T-1) * K if REVERSE else 0)\n if USE_GV:\n p_gv = gv + i_bh * s_v_h + i_v * BV + tl.arange(0, BV) + ((T-1) * V if REVERSE else 0)\n\n mask_bk = (i_k * BK + tl.arange(0, BK)) < K\n mask_bv = (i_v * BV + tl.arange(0, BV)) < V\n\n h = tl.zeros([BV, BK], dtype=tl.float32)\n mask_kv = mask_bk[None, :] & mask_bv[:, None]\n\n if USE_INITIAL_STATE:\n p_h0 = h0 + i_bh * K * V + (i_k * BK + tl.arange(0, BK)[None, :]) * V + (i_v * BV + tl.arange(0, BV)[:, None])\n h += tl.load(p_h0, mask=mask_kv, other=0).to(tl.float32)\n\n for _ in range(0, T):\n b_q = tl.load(p_q, mask=mask_bk, other=0).to(tl.float32) * scale\n b_k = tl.load(p_k, mask=mask_bk, other=0).to(tl.float32)\n b_v = tl.load(p_v, mask=mask_bv, other=0).to(tl.float32)\n if USE_GK:\n b_gk = tl.load(p_gk, mask=mask_bk, other=0).to(tl.float32)\n h = h * b_gk[None, :]\n if USE_GV:\n b_gv = tl.load(p_gv, mask=mask_bv, other=0).to(tl.float32)\n h = h * b_gv[:, None]\n h += b_k[None, :] * b_v[:, None]\n b_o = h * b_q[None, :]\n b_o = tl.sum(b_o, axis=1)\n tl.store(p_o, b_o.to(p_o.dtype.element_ty), mask=mask_bv)\n p_q += -K if REVERSE else K\n p_k += -K if REVERSE else K\n p_o += -V if REVERSE else V\n p_v += -V if REVERSE else V\n if USE_GK:\n p_gk += -K if REVERSE else K\n if USE_GV:\n p_gv += -V if REVERSE else V\n\n if STORE_FINAL_STATE:\n p_ht = ht + i_bh * K * V + (i_k * BK + tl.arange(0, BK)[None, :]) * V + (i_v * BV + tl.arange(0, BV)[:, None])\n tl.store(p_ht, h.to(p_ht.dtype.element_ty), mask=mask_kv)\n\n@triton.jit\ndef fused_recurrent_gated_abc_bwd_kernel(\n q,\n k,\n v,\n gk,\n gv,\n do,\n dq,\n dk,\n dv,\n h0,\n s_k_h,\n s_v_h,\n scale,\n B: tl.constexpr,\n H: tl.constexpr,\n T: tl.constexpr,\n K: tl.constexpr,\n V: tl.constexpr,\n BK: tl.constexpr,\n BV: tl.constexpr,\n USE_INITIAL_STATE: tl.constexpr,\n REVERSE: tl.constexpr,\n USE_GK: tl.constexpr,\n USE_GV: tl.constexpr,\n):\n i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n\n p_q = q + i_bh * s_k_h + i_k * BK + tl.arange(0, BK) + ((T-1) * K if REVERSE else 0)\n p_k = k + i_bh * s_k_h + i_k * BK + tl.arange(0, BK) + ((T-1) * K if REVERSE else 0)\n p_v = v + i_bh * s_v_h + i_v * BV + tl.arange(0, BV) + ((T-1) * V if REVERSE else 0)\n p_do = do + i_bh * s_v_h + i_v * BV + tl.arange(0, BV) + ((T-1) * V if REVERSE else 0)\n p_dq = dq + (i_bh + i_v * B * H) * s_k_h + i_k * BK + tl.arange(0, BK) + ((T-1) * K if REVERSE else 0)\n if USE_GK:\n p_gk = gk + i_bh * s_k_h + i_k * BK + tl.arange(0, BK) + ((T-1) * K if REVERSE else 0)\n if USE_GV:\n p_gv = gv + i_bh * s_v_h + i_v * BV + tl.arange(0, BV) + ((T-1) * V if REVERSE else 0)\n mask_bk = i_k * BK + tl.arange(0, BK) < K\n mask_bv = i_v * BV + tl.arange(0, BV) < V\n mask_kv = mask_bk[:, None] & mask_bv[None, :]\n h = tl.zeros([BK, BV], dtype=tl.float32)\n\n if USE_INITIAL_STATE:\n p_h0 = h0 + i_bh * K * V + (i_k * BK + tl.arange(0, BK)[:, None]) * V + (i_v * BV + tl.arange(0, BV)[None, :])\n h += tl.load(p_h0, mask=mask_kv, other=0).to(tl.float32)\n\n for _ in range(0, T):\n b_k = tl.load(p_k, mask=mask_bk, other=0).to(tl.float32)\n b_v = tl.load(p_v, mask=mask_bv, other=0).to(tl.float32)\n b_do = tl.load(p_do, mask=mask_bv, other=0).to(tl.float32)\n if USE_GK:\n b_gk = tl.load(p_gk, mask=mask_bk, other=0).to(tl.float32)\n h = h * b_gk[:, None]\n if USE_GV:\n b_gv = tl.load(p_gv, mask=mask_bv, other=0).to(tl.float32)\n h = h * b_gv[None, :]\n h += b_k[:, None] * b_v[None, :]\n b_dq = tl.sum(h * b_do[None, :], axis=1) * scale\n tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), mask=mask_bk)\n\n p_k += -K if REVERSE else K\n p_v += -V if REVERSE else V\n p_q += -K if REVERSE else K\n p_do += -V if REVERSE else V\n p_dq += -K if REVERSE else K\n if USE_GK:\n p_gk += -K if REVERSE else K\n if USE_GV:\n p_gv += -V if REVERSE else V\n\n # sync threads\n tl.debug_barrier()\n\n p_q = q + i_bh * s_k_h + i_k * BK + tl.arange(0, BK) + ((T - 1) * K if not REVERSE else 0)\n p_k = k + i_bh * s_k_h + i_k * BK + tl.arange(0, BK) + ((T - 1) * K if not REVERSE else 0)\n p_v = v + i_bh * s_v_h + i_v * BV + tl.arange(0, BV) + ((T - 1) * V if not REVERSE else 0)\n p_do = do + i_bh * s_v_h + i_v * BV + tl.arange(0, BV) + ((T - 1) * V if not REVERSE else 0)\n p_dk = dk + (i_bh + i_v * B * H) * s_k_h + i_k * BK + tl.arange(0, BK) + ((T - 1) * K if not REVERSE else 0)\n p_dv = dv + (i_bh + i_k * B * H) * s_v_h + i_v * BV + tl.arange(0, BV) + ((T - 1) * V if not REVERSE else 0)\n if USE_GK:\n p_gk = gk + i_bh * s_k_h + i_k * BK + tl.arange(0, BK) + ((T - 1) * K if not REVERSE else 0)\n if USE_GV:\n p_gv = gv + i_bh * s_v_h + i_v * BV + tl.arange(0, BV) + ((T - 1) * V if not REVERSE else 0)\n\n b_dh = tl.zeros([BK, BV], dtype=tl.float32)\n for _ in range(T):\n b_q = tl.load(p_q, mask=mask_bk, other=0).to(tl.float32) * scale\n b_k = tl.load(p_k, mask=mask_bk, other=0).to(tl.float32)\n b_v = tl.load(p_v, mask=mask_bv, other=0).to(tl.float32)\n b_do = tl.load(p_do, mask=mask_bv, other=0).to(tl.float32)\n b_dh += b_q[:, None] * b_do[None, :]\n b_dk = tl.sum(b_dh * b_v[None, :], axis=1)\n b_dv = tl.sum(b_dh * b_k[:, None], axis=0)\n if USE_GK:\n b_gk = tl.load(p_gk, mask=mask_bk, other=0).to(tl.float32)\n b_dh *= b_gk[:, None]\n if USE_GV:\n b_gv = tl.load(p_gv, mask=mask_bv, other=0).to(tl.float32)\n b_dh *= b_gv[None, :]\n tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), mask=mask_bk)\n tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), mask=mask_bv)\n\n p_q += K if REVERSE else -K\n p_k += K if REVERSE else -K\n p_v += V if REVERSE else -V\n p_do += V if REVERSE else -V\n p_dk += K if REVERSE else -K\n p_dv += V if REVERSE else -V\n if USE_GK:\n p_gk += K if REVERSE else -K\n if USE_GV:\n p_gv += V if REVERSE else -V\n\n\ndef fused_recurrent_gated_abc(\n q: torch.Tensor,\n k: torch.Tensor,\n v: torch.Tensor,\n s: torch.Tensor,\n g: Optional[torch.Tensor] = None,\n scale: Optional[int] = None,\n initial_state: Optional[Tuple[torch.Tensor]] = None,\n output_final_state: Optional[bool] = False\n) -> Tuple[torch.Tensor, torch.Tensor]:\n r\"\"\"\n Args:\n q (torch.Tensor):\n queries of shape `(B, H, T, K)`\n k (torch.Tensor):\n keys of shape `(B, H, T, K)`\n v (torch.Tensor):\n values of shape `(B, H, T, V)`\n g (torch.Tensor):\n Forget gates of shape `(B, H, T, M)` applied to keys.\n If not provided, this function is equivalent to vanilla ABC.\n scale (Optional[int]):\n Scale factor for attention scores.\n If not provided, it will default to `1 / sqrt(K)`. Default: `None`.\n initial_state (Optional[Tuple[torch.Tensor]]):\n Initial state tuple having tensors of shape `(B, H, K, V)`. Default: `None`.\n output_final_state (Optional[bool]):\n Whether to output the final state tuple, having tensors of shape `(B, H, K, V)`. Default: `False`.\n \"\"\"\n if initial_state is not None:\n initial_state = tuple(i.detach() for i in initial_state)\n if g is None:\n # TODO: this 3 steps took huge amount of time, ought to be optimized\n z = s.float().logcumsumexp(2)\n g = torch.cat((z[:, :, :1], z[:, :, :-1]), 2) - z\n s = torch.exp(s - z).to(k.dtype)\n if scale is None:\n scale = q.shape[-1] ** -0.5\n ov, final_state = FusedRecurrentGatedABCFunction.apply(q, k, v, s, g, scale, initial_state, output_final_state)\n return ov, final_state\n", - "description_1": "Use triton language to implement two kernels: 'fused_recurrent_gated_abc_fwd_kernel' and 'fused_recurrent_gated_abc_bwd_kernel'. The forward kernel computes recurrent gated operations on inputs q, k, v with optional gating factors gk and gv, utilizing block sizes BK and BV, respecting various configurations such as REVERSE and USE_INITIAL_STATE. The backward kernel computes the gradient of inputs based on the output gradient do and involves similar configurations.", - "description_2": "Use triton language to create kernels for forward and backward passes of a recurrent gated operation on tensors, handling additional gating factors and initial states, with attention to reversing sequences and storing final states.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\nfrom torch.cuda.amp import custom_bwd, custom_fwd\n\n@triton.jit\ndef fused_chunk_based_fwd_kernel(\n q, k, v, o, z,\n s_qk_h, s_qk_t, s_qk_d, s_vo_h, s_vo_t, s_vo_d,\n B, H, T, scale,\n BT: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr,\n DK: tl.constexpr, DV: tl.constexpr\n):\n i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n o_i = tl.arange(0, BT)\n m_s = o_i[:, None] >= o_i[None, :]\n b_h_0o = tl.zeros([BV], dtype=tl.float32)\n b_h_1o = tl.zeros([BK, BV], dtype=tl.float32)\n b_h_2o = tl.zeros([BK*BK, BV], dtype=tl.float32)\n p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (T, DK),\n (s_qk_t, s_qk_d), (0, i_k * BK), (BT, BK), (1, 0))\n p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (DK, T),\n (s_qk_d, s_qk_t), (i_k * BK, 0), (BK, BT), (0, 1))\n p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, DV),\n (s_vo_t, s_vo_d), (0, i_v * BV), (BT, BV), (1, 0))\n p_o = tl.make_block_ptr(o + (i_bh + i_k*B*H) * s_vo_h, (T, DV),\n (s_vo_t, s_vo_d), (0, i_v * BV), (BT, BV), (1, 0))\n p_z = z + (i_bh + i_k * B * H) * T + tl.arange(0, BT)\n k_2o = tl.zeros([1, BK * BK], dtype=tl.float32)\n k_1o = tl.zeros([1, BK], dtype=tl.float32)\n k_0o = 0\n\n for i in range(0, tl.cdiv(T, BT)):\n b_k = tl.load(p_k, boundary_check=(0, 1))\n b_k_2o = b_k[:, None, :] * b_k[None, :, :]\n b_k_2o = tl.reshape(b_k_2o, [BK * BK, BT]).to(b_k.dtype)\n b_v = tl.load(p_v, boundary_check=(0, 1))\n b_q = (tl.load(p_q, boundary_check=(0, 1)) * scale).to(b_k.dtype)\n b_o = tl.zeros([BT, BV], dtype=tl.float32)\n b_z = tl.zeros([BT], dtype=tl.float32)\n\n b_o += b_h_0o\n b_z += k_0o\n b_o += tl.dot(b_q, b_h_1o.to(b_q.dtype), allow_tf32=False)\n b_z += tl.sum(b_q * k_1o, axis=1)\n b_q_2o = b_q[:, :, None] * b_q[:, None, :]\n b_q_2o = tl.reshape(b_q_2o, [BT, BK * BK]).to(b_k.dtype)\n b_o += tl.dot(b_q_2o, b_h_2o.to(b_q_2o.dtype), allow_tf32=False) * 0.5\n b_z += tl.sum(b_q_2o * k_2o, axis=1) * 0.5\n\n k_1o += tl.sum(b_k, axis=1)[None, :]\n k_2o += tl.sum(b_k_2o, axis=1)[None, :]\n k_0o += BT\n\n b_s = tl.dot(b_q, b_k, allow_tf32=False)\n b_s = 1 + b_s + 0.5 * b_s * b_s\n b_s = tl.where(m_s, b_s, 0)\n b_z += tl.sum(b_s, axis=1)\n b_o += tl.dot(b_s.to(b_q.dtype), b_v, allow_tf32=False)\n tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))\n tl.store(p_z, b_z.to(p_z.dtype.element_ty),\n mask=(i * BT + tl.arange(0, BT)) < T)\n\n b_h_2o = b_h_2o + tl.dot(b_k_2o.to(b_v.dtype), b_v, allow_tf32=False)\n b_h_1o = b_h_1o + tl.dot(b_k, b_v, allow_tf32=False)\n b_h_0o = b_h_0o + tl.sum(b_v, axis=0)\n\n p_q = tl.advance(p_q, (BT, 0))\n p_k = tl.advance(p_k, (0, BT))\n p_v = tl.advance(p_v, (BT, 0))\n p_o = tl.advance(p_o, (BT, 0))\n p_z += BT\n\n@triton.jit\ndef fused_chunk_based_bwd_kernel(\n q, k, v, do, dz, dq, dk, dv,\n s_qk_h, s_qk_t, s_qk_d, s_vo_h, s_vo_t, s_vo_d,\n B, H, T, scale,\n BT: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr,\n DK: tl.constexpr, DV: tl.constexpr\n):\n i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n o_i = tl.arange(0, BT)\n m_s = o_i[:, None] >= o_i[None, :]\n b_h_1o = tl.zeros([BV, BK], dtype=tl.float32)\n b_h_2o = tl.zeros([BV, BK*BK], dtype=tl.float32)\n k_1o = tl.zeros([1, BK], dtype=tl.float32)\n k_2o = tl.zeros([1, BK * BK], dtype=tl.float32)\n\n for i in range(0, tl.cdiv(T, BT)):\n p_q = tl.make_block_ptr(\n q + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (i * BT, i_k * BK), (BT, BK), (1, 0))\n p_k = tl.make_block_ptr(\n k + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (i * BT, i_k * BK), (BT, BK), (1, 0))\n p_v = tl.make_block_ptr(\n v + i_bh * s_vo_h, (DV, T), (s_vo_d, s_vo_t), (i_v * BV, i * BT), (BV, BT), (0, 1))\n p_do = tl.make_block_ptr(\n do + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (i * BT, i_v * BV), (BT, BV), (1, 0))\n p_dq = tl.make_block_ptr(dq + (i_bh + i_v*B*H) * s_qk_h,\n (T, DK), (s_qk_t, s_qk_d), (i*BT, i_k*BK), (BT, BK), (1, 0))\n p_dz = dz + (i_bh) * T + tl.arange(0, BT) + i * BT\n b_dq = tl.zeros([BT, BK], dtype=tl.float32)\n\n b_dz = tl.load(p_dz, mask=(tl.arange(0, BT) + i * BT) < T)\n b_q = tl.load(p_q, boundary_check=(0, 1))\n b_q = (b_q * scale).to(b_q.dtype)\n b_do = tl.load(p_do, boundary_check=(0, 1)).to(b_q.dtype)\n b_k = tl.load(p_k, boundary_check=(0, 1))\n b_v = tl.load(p_v, boundary_check=(0, 1))\n\n b_dq += tl.dot(b_do, (b_h_1o).to(b_do.dtype), allow_tf32=False)\n if i_v == 0:\n b_dq += b_dz[:, None] * k_1o\n b_dq_2o = tl.dot(b_do, (b_h_2o).to(b_do.dtype), allow_tf32=False) * 0.5\n if i_v == 0:\n b_dq_2o += (b_dz[:, None] * k_2o) * 0.5\n b_dq_2o = tl.reshape(b_dq_2o, [BT, BK, BK])\n b_dq += tl.sum(b_dq_2o * b_q[:, :, None], axis=1)\n b_dq += tl.sum(b_dq_2o * b_q[:, None, :], axis=2)\n b_dq *= scale\n\n b_ds = tl.dot(b_do, b_v, allow_tf32=False)\n if i_v == 0:\n b_ds += b_dz[:, None]\n b_ds = tl.where(m_s, b_ds, 0) * scale\n b_s = tl.dot(b_q, tl.trans(b_k), allow_tf32=False)\n b_s = tl.where(m_s, b_s, 0)\n b_dq += tl.dot((b_ds * (1 + b_s)).to(b_q.dtype), b_k, allow_tf32=False)\n\n tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))\n\n b_k_2o = b_k[:, :, None] * b_k[:, None, :]\n b_k_2o = tl.reshape(b_k_2o, [BT, BK * BK]).to(b_k.dtype)\n b_h_2o = b_h_2o + tl.dot(b_v, b_k_2o.to(b_v.dtype), allow_tf32=False)\n b_h_1o = b_h_1o + tl.dot(b_v, b_k, allow_tf32=False)\n\n if i_v == 0:\n k_1o += tl.sum(b_k, axis=0)[None, :]\n k_2o += tl.sum(b_k_2o, axis=0)[None, :]\n\n tl.debug_barrier()\n b_h_1o = None\n b_h_2o = None\n b_dh_1o = tl.zeros([BK, BV], dtype=tl.float32)\n b_dh_2o = tl.zeros([BK*BK, BV], dtype=tl.float32)\n b_dh_0o = tl.zeros([BV], dtype=tl.float32)\n m_s = tl.arange(0, BT)[:, None] <= tl.arange(0, BT)[None, :]\n dq_1o = tl.zeros([1, BK], dtype=tl.float32)\n dq_2o = tl.zeros([BK * BK, 1], dtype=tl.float32)\n\n for i in range(tl.cdiv(T, BT) * BT - BT, -BT, -BT):\n p_q = tl.make_block_ptr(\n q + i_bh * s_qk_h, (DK, T), (s_qk_d, s_qk_t), (i_k * BK, i), (BK, BT), (0, 1))\n p_k = tl.make_block_ptr(\n k + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (i, i_k * BK), (BT, BK), (1, 0))\n p_v = tl.make_block_ptr(\n v + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (i, i_v * BV), (BT, BV), (1, 0))\n p_do = tl.make_block_ptr(\n do + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (i, i_v * BV), (BT, BV), (1, 0))\n p_dk = tl.make_block_ptr(dk + (i_bh+i_v*B*H) * s_qk_h, (T, DK),\n (s_qk_t, s_qk_d), (i, i_k*BK), (BT, BK), (1, 0))\n p_dv = tl.make_block_ptr(dv + (i_bh+i_k*B*H) * s_vo_h, (T, DV),\n (s_vo_t, s_vo_d), (i, i_v*BV), (BT, BV), (1, 0))\n p_dz = dz + (i_bh) * T + tl.arange(0, BT) + i\n\n b_dk = tl.zeros([BT, BK], dtype=tl.float32)\n b_dv = tl.zeros([BT, BV], dtype=tl.float32)\n\n b_dz = tl.load(p_dz, mask=(tl.arange(0, BT)+i) < T)\n b_q = tl.load(p_q, boundary_check=(0, 1))\n b_k = tl.load(p_k, boundary_check=(0, 1))\n b_v = tl.load(p_v, boundary_check=(0, 1))\n b_do = tl.load(p_do, boundary_check=(0, 1)).to(b_q.dtype)\n b_q = (b_q * scale).to(b_k.dtype)\n\n b_ds = tl.dot(b_v, tl.trans(b_do), allow_tf32=False)\n if i_v == 0:\n b_ds += b_dz[None, :]\n b_ds = tl.where(m_s, b_ds, 0)\n b_s = tl.dot(b_k, b_q, allow_tf32=False)\n b_s2 = 1 + b_s + 0.5 * b_s * b_s\n b_s = tl.where(m_s, b_s, 0)\n b_s2 = tl.where(m_s, b_s2, 0)\n b_ds *= (1+b_s)\n\n b_dk += tl.dot(b_ds.to(b_k.dtype), tl.trans(b_q), allow_tf32=False)\n b_dv += tl.dot(b_s2.to(b_do.dtype), b_do, allow_tf32=False)\n\n b_k_2o = b_k[:, :, None] * b_k[:, None, :]\n b_k_2o = tl.reshape(b_k_2o, [BT, BK * BK]).to(b_k.dtype)\n\n b_dv += tl.dot(b_k, b_dh_1o.to(b_k.dtype), allow_tf32=False)\n b_dv += tl.dot(b_k_2o, b_dh_2o.to(b_k.dtype), allow_tf32=False)\n b_dv += b_dh_0o\n\n b_dk += tl.dot(b_v, tl.trans(b_dh_1o).to(b_k.dtype), allow_tf32=False)\n\n if i_v == 0:\n b_dk += dq_1o\n\n b_dk_2o = tl.dot(b_dh_2o.to(b_k.dtype),\n tl.trans(b_v), allow_tf32=False)\n if i_v == 0:\n b_dk_2o += dq_2o\n b_dk_2o = tl.reshape(b_dk_2o, [BK, BK, BT])\n b_k_fp32 = tl.trans(b_k.to(tl.float32))\n b_dk2 = tl.sum(b_dk_2o * b_k_fp32[:, None, :], axis=0)\n b_dk2 += tl.sum(b_dk_2o * b_k_fp32[None, :, :], axis=1)\n b_dk += tl.trans(b_dk2)\n\n b_dh_0o += tl.sum(b_do, axis=0)\n b_dh_1o = b_dh_1o + tl.dot(b_q, b_do, allow_tf32=False)\n b_q_2o = b_q[None, :, :] * b_q[:, None, :]\n b_q_2o = tl.reshape(b_q_2o, [BK * BK, BT]).to(b_k.dtype)\n b_dh_2o = b_dh_2o + tl.dot(b_q_2o, b_do, allow_tf32=False) * 0.5\n\n if i_v == 0:\n dq_1o += (tl.sum(b_dz[None, :] * b_q, axis=1))[None, :]\n dq_2o += (tl.sum(b_dz[None, :] * b_q_2o, axis=1) * 0.5)[:, None]\n\n tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))\n tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))\n\nclass FusedChunkBasedFunction(torch.autograd.Function):\n\n @staticmethod\n @custom_fwd\n def forward(ctx, q, k, v, scale=1):\n batch_size, n_heads, seq_len, d_head_qk = q.shape\n d_head_v = v.shape[-1]\n\n scale = scale\n BT = 16\n BK, BV = min(d_head_qk, 16), min(d_head_v, 32)\n BK, BV = max(BK, 16), max(BV, 16)\n NK, NV = triton.cdiv(d_head_qk, BK), triton.cdiv(d_head_v, BV)\n\n num_warps = 4\n\n o = q.new_empty(NK, batch_size, n_heads, seq_len,\n d_head_v, dtype=torch.float32)\n z = q.new_empty(NK, batch_size, n_heads, seq_len, dtype=torch.float32)\n\n grid = (NV, NK, batch_size * n_heads)\n fused_chunk_based_fwd_kernel[grid](\n q, k, v, o, z,\n q.stride(1), q.stride(2), q.stride(3),\n v.stride(1), v.stride(2), v.stride(3),\n batch_size, n_heads, seq_len, scale,\n BT=BT, DK=d_head_qk, DV=d_head_v, BK=BK, BV=BV,\n num_warps=num_warps,\n )\n o = o.sum(0)\n z = z.sum(0)\n ctx.save_for_backward(q, k, v)\n ctx.scale = scale\n return o.to(q.dtype), z.to(z.dtype)\n\n @staticmethod\n @custom_bwd\n def backward(ctx, do, dz):\n q, k, v = ctx.saved_tensors\n batch_size, n_heads, seq_len, d_head_qk = q.shape\n d_head_v = v.shape[-1]\n scale = ctx.scale\n\n BT = 16\n BK, BV = min(d_head_qk, 16), min(d_head_v, 32)\n BK, BV = max(BK, 16), max(BV, 16)\n NK, NV = triton.cdiv(d_head_qk, BK), triton.cdiv(d_head_v, BV)\n num_stages = 1\n num_warps = 4\n\n dq = q.new_empty(NV, batch_size, n_heads, seq_len, d_head_qk)\n dk = q.new_empty(NV, batch_size, n_heads, seq_len, d_head_qk)\n dv = q.new_empty(NK, batch_size, n_heads, seq_len, d_head_v)\n grid = (NV, NK, batch_size * n_heads)\n\n fused_chunk_based_bwd_kernel[grid](\n q, k, v, do, dz, dq, dk, dv,\n q.stride(1), q.stride(2), q.stride(3),\n v.stride(1), v.stride(2), v.stride(3),\n batch_size, n_heads, seq_len, scale,\n BT=BT, DK=d_head_qk, DV=d_head_v, BK=BK, BV=BV,\n num_warps=num_warps,\n num_stages=num_stages\n )\n dq = dq.sum(0)\n dk = dk.sum(0)\n dv = dv.sum(0)\n return dq.to(q.dtype), dk.to(k.dtype), dv.to(v.dtype), None\n\ntriton_fused_chunk_based = FusedChunkBasedFunction.apply\n\ndef fused_chunk_based(q, k, v, use_scale=True, use_normalize=True):\n assert q.shape[-1] <= 16, 'only support feature dimension up to 16.'\n if use_scale:\n scale = q.shape[-1] ** -0.5\n else:\n scale = 1\n o, z = triton_fused_chunk_based(q, k, v, scale)\n if use_normalize:\n o = o / (z[..., None] + 1e-6)\n else:\n o = o\n\n return o.to(q.dtype)\n", - "description_1": "Use triton language to implement a fused chunk-based forward and backward kernel for a transformer-like operation. The forward kernel takes query, key, and value tensors along with stride sizes and other parameters to compute an output tensor and a normalizer tensor using Taylor expansion for fast matrix multiplication. The backward kernel calculates the gradients of query, key, and value tensors using the outputs from the forward pass. The main function, fused_chunk_based, interfaces these kernels with PyTorch's autograd functionality.", - "description_2": "Use triton language to create kernels for efficient matrix multiplication in a transformer model, handling both forward and backward passes with gradient computation, and integrate with PyTorch's autograd system.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\nfrom torch.cuda.amp import custom_bwd, custom_fwd\n\n@triton.jit\ndef parallel_based_fwd_kernel(\n q, k, v, o, z, s_qk_h, s_qk_t, s_qk_d, s_vo_h, s_vo_t, s_vo_d, B, H, T, scale,\n BTL: tl.constexpr, BTS: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr,\n DK: tl.constexpr, DV: tl.constexpr,\n):\n i_kv, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n NV = tl.cdiv(DV, BV)\n i_k = i_kv // (NV)\n i_v = i_kv % (NV)\n\n p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (T, DK),\n (s_qk_t, s_qk_d), (i_c * BTL, i_k * BK), (BTL, BK), (1, 0))\n p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (DK, T),\n (s_qk_d, s_qk_t), (i_k * BK, 0), (BK, BTS), (0, 1))\n p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, DV),\n (s_vo_t, s_vo_d), (0, i_v * BV), (BTS, BV), (1, 0))\n\n b_q = tl.load(p_q, boundary_check=(0, 1))\n b_q = (b_q * scale).to(b_q.dtype)\n b_o = tl.zeros([BTL, BV], dtype=tl.float32)\n b_z = tl.zeros([BTL], dtype=tl.float32)\n\n for _ in range(0, i_c * BTL, BTS):\n b_k = tl.load(p_k, boundary_check=(0, 1))\n b_v = tl.load(p_v, boundary_check=(0, 1))\n b_s = tl.dot(b_q, (b_k), allow_tf32=False)\n b_s = 1 + b_s + 0.5 * b_s * b_s\n b_z += tl.sum(b_s, axis=1)\n b_o = b_o + tl.dot(b_s.to(b_v.dtype), b_v, allow_tf32=False)\n p_k = tl.advance(p_k, (0, BTS))\n p_v = tl.advance(p_v, (BTS, 0))\n\n tl.debug_barrier()\n o_q = tl.arange(0, BTL)\n o_k = tl.arange(0, BTS)\n p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (DK, T),\n (s_qk_d, s_qk_t), (i_k * BK, i_c * BTL), (BK, BTS), (0, 1))\n p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, DV),\n (s_vo_t, s_vo_d), (i_c * BTL, i_v * BV), (BTS, BV), (1, 0))\n for _ in range(i_c * BTL, (i_c + 1) * BTL, BTS):\n b_k = tl.load(p_k, boundary_check=(0, 1))\n b_v = tl.load(p_v, boundary_check=(0, 1))\n m_s = o_q[:, None] >= o_k[None, :]\n b_s = tl.dot(b_q, b_k, allow_tf32=False)\n b_s = 1 + b_s + 0.5 * b_s * b_s\n b_s = tl.where(m_s, b_s, 0)\n b_z += tl.sum(b_s, axis=1)\n b_o += tl.dot(b_s.to(b_q.dtype), b_v, allow_tf32=False)\n p_k = tl.advance(p_k, (0, BTS))\n p_v = tl.advance(p_v, (BTS, 0))\n o_k += BTS\n\n p_o = tl.make_block_ptr(o + (i_bh + B * H * i_k) * s_vo_h, (T, DV),\n (s_vo_t, s_vo_d), (i_c*BTL, i_v*BV), (BTL, BV), (1, 0))\n p_z = z + (i_bh + B * H * i_k) * T + i_c * BTL + tl.arange(0, BTL)\n tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))\n tl.store(p_z, b_z.to(p_z.dtype.element_ty),\n mask=((i_c * BTL + tl.arange(0, BTL)) < T))\n\n\n@triton.jit\ndef parallel_based_bwd_kernel(\n q, k, v, do, dz, dq, dk, dv, s_qk_h, s_qk_t, s_qk_d, s_vo_h,\n s_vo_t, s_vo_d, B, H, T, scale,\n BTL: tl.constexpr, BTS: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr,\n DK: tl.constexpr, DV: tl.constexpr,\n):\n i_kv, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n NV = tl.cdiv(DV, BV)\n i_k = i_kv // (NV)\n i_v = i_kv % (NV)\n i_h = i_bh % H\n _parallel_based_bwd_dq(\n i_bh, i_c, i_k, i_v, i_h,\n q, k, v, do, dz, dq, s_qk_h, s_qk_t, s_qk_d, s_vo_h,\n s_vo_t, s_vo_d, B, H, T, scale, BTL=BTL, BTS=BTS, BK=BK, BV=BV, DK=DK, DV=DV\n )\n tl.debug_barrier()\n _parallel_based_bwd_dkv(\n i_bh, i_c, i_k, i_v, i_h,\n q, k, v, do, dz, dk, dv, s_qk_h, s_qk_t, s_qk_d, s_vo_h,\n s_vo_t, s_vo_d, B, H, T, scale, BTL, BTS, BK, BV, DK, DV\n )\n\n\nclass ParallelBasedFunction(torch.autograd.Function):\n @staticmethod\n @contiguous\n @custom_fwd\n def forward(ctx, q, k, v, scale):\n BTL, BTS = 128, 32\n assert BTL % BTS == 0\n BK = min(128, triton.next_power_of_2(k.shape[-1]))\n BV = min(128, triton.next_power_of_2(v.shape[-1]))\n BK, BV = max(BK, 16), max(BV, 16)\n batch_size, n_heads, seq_len, d_head_qk = q.shape\n d_head_v = v.shape[-1]\n num_stages = 2\n num_warps = 4\n NK = triton.cdiv(d_head_qk, BK)\n NV = triton.cdiv(d_head_v, BV)\n grid = (NK * NV, triton.cdiv(seq_len, BTL), batch_size * n_heads)\n\n assert NK == 1, \"will encounter some synchronization issue if not.\"\n\n o = torch.empty(NK, batch_size, n_heads, seq_len,\n d_head_v, device=q.device)\n z = torch.empty(NK, batch_size, n_heads, seq_len,\n device=q.device)\n parallel_based_fwd_kernel[grid](\n q, k, v, o, z,\n q.stride(1), q.stride(2), q.stride(3),\n v.stride(1), v.stride(2), v.stride(3),\n batch_size, n_heads, seq_len, scale,\n BTL=BTL, BTS=BTS, BK=BK, BV=BV, DK=d_head_qk, DV=d_head_v,\n num_warps=num_warps,\n num_stages=num_stages\n )\n ctx.save_for_backward(q, k, v)\n ctx.scale = scale\n return o.sum(0).to(q.dtype), z.sum(0).to(q.dtype)\n\n @staticmethod\n @custom_bwd\n @contiguous\n def backward(ctx, do, dz):\n q, k, v = ctx.saved_tensors\n scale = ctx.scale\n BTL, BTS = 64, 32\n assert BTL % BTS == 0\n BK = min(128, triton.next_power_of_2(k.shape[-1]))\n BV = min(128, triton.next_power_of_2(v.shape[-1]))\n BK, BV = max(BK, 16), max(BV, 16)\n batch_size, n_heads, seq_len, d_head_qk = q.shape\n d_head_v = v.shape[-1]\n num_stages = 2\n num_warps = 4\n NK = triton.cdiv(d_head_qk, BK)\n NV = triton.cdiv(d_head_v, BV)\n grid = (NK * NV, triton.cdiv(seq_len, BTL), batch_size * n_heads)\n\n assert NK == 1, \"will encounter some synchronization issue if not\"\n\n dq = torch.empty(NV, batch_size, n_heads, seq_len,\n d_head_qk, dtype=q.dtype, device=q.device)\n dk = torch.empty(NV, batch_size, n_heads, seq_len,\n d_head_qk, dtype=q.dtype, device=q.device)\n dv = torch.empty(NK, batch_size, n_heads, seq_len,\n d_head_v, dtype=q.dtype, device=q.device)\n\n parallel_based_bwd_kernel[grid](\n q, k, v, do, dz, dq, dk, dv,\n q.stride(1), q.stride(2), q.stride(3),\n v.stride(1), v.stride(2), v.stride(3),\n batch_size, n_heads, seq_len, scale,\n BTL=BTL, BTS=BTS, BK=BK, BV=BV, DK=d_head_qk, DV=d_head_v,\n num_warps=num_warps,\n num_stages=num_stages\n )\n\n return dq.sum(0).to(q.dtype), dk.sum(0).to(k.dtype), dv.sum(0).to(v.dtype), None\n\n\ntriton_parallel_based = ParallelBasedFunction.apply\n\n\ndef parallel_based(q, k, v, use_scale=True, use_normalize=True, return_both=False):\n assert q.shape[-1] <= 128, \"only support feature dim up to 128\"\n if use_scale:\n scale = q.shape[-1] ** -0.5\n else:\n scale = 1\n o, z = triton_parallel_based(q, k, v, scale)\n if return_both:\n return o, z\n if use_normalize:\n o = o / (z[..., None] + 1e-6)\n else:\n o = o\n return o.to(q.dtype)\n", - "description_1": "Use triton language to implement a parallel-based forward and backward kernel for a sequence mixer. The forward kernel takes 18 parameters: q, k, v, o, z, s_qk_h, s_qk_t, s_qk_d, s_vo_h, s_vo_t, s_vo_d, B, H, T, scale, BTL, BTS, BK, BV, DK, DV. The backward kernel takes 20 parameters: q, k, v, do, dz, dq, dk, dv, s_qk_h, s_qk_t, s_qk_d, s_vo_h, s_vo_t, s_vo_d, B, H, T, scale, BTL, BTS, BK, BV, DK, DV.", - "description_2": "Use triton language to create a custom autograd function in PyTorch for a parallel-based sequence mixer with forward and backward passes, utilizing triton kernels for efficient computation.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.autotune(\n configs=[\n triton.Config({}, num_warps=1),\n triton.Config({}, num_warps=2),\n triton.Config({}, num_warps=4),\n triton.Config({}, num_warps=8),\n triton.Config({}, num_warps=16),\n triton.Config({}, num_warps=32),\n ],\n key=[\"BT\", \"BK\", \"BV\"], \n)\n@triton.jit\ndef fwd_prepare_dv_kernel(\n q,\n k,\n do,\n dv,\n s_qk_h,\n s_qk_t,\n s_qk_d,\n s_vo_h,\n s_vo_t,\n s_vo_d,\n T,\n K,\n V,\n scale,\n BT: tl.constexpr,\n BK: tl.constexpr,\n BV: tl.constexpr\n):\n i_t, i_bh = tl.program_id(0), tl.program_id(1)\n \n b_A = tl.zeros([BT, BT], dtype=tl.float32)\n\n for i_k in range(tl.cdiv(K, BK)):\n p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1))\n p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))\n b_k = tl.load(p_k, boundary_check=(0, 1))\n b_q = tl.load(p_q, boundary_check=(0, 1)) \n b_q = (b_q * scale).to(b_k.dtype)\n b_A += tl.dot(b_k, b_q, allow_tf32=False)\n\n b_A = tl.where(tl.arange(0, BT)[:, None] <= tl.arange(0, BT)[None, :], b_A , 0).to(do.dtype.element_ty)\n\n for i_v in range(tl.cdiv(V, BV)):\n p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))\n b_do = tl.load(p_do, boundary_check=(0, 1))\n p_dv = tl.make_block_ptr(dv + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))\n b_dv = tl.dot(b_A, b_do, allow_tf32=False)\n tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))\n\n\ndef fwd_prepare_dv(q, k, do, BT):\n dv = torch.empty_like(do)\n B, H, T, K, V = *k.shape, do.shape[-1]\n NT = triton.cdiv(T, BT)\n BK = min(triton.next_power_of_2(K), 64)\n BV = min(triton.next_power_of_2(V), 64)\n fwd_prepare_dv_kernel[(NT, B*H)](\n q, k, do, dv,\n k.stride(1), k.stride(2), k.stride(3), \n do.stride(1), do.stride(2), do.stride(3),\n T, K, V, K**-0.5, BT, BK, BV\n )\n return dv\n\n@triton.autotune(\n configs=[\n triton.Config({}, num_warps=1),\n triton.Config({}, num_warps=2),\n triton.Config({}, num_warps=4),\n triton.Config({}, num_warps=8),\n triton.Config({}, num_warps=16),\n triton.Config({}, num_warps=32),\n ],\n key=[\"BT\", \"BK\", \"BV\"], \n)\n@triton.jit\ndef chunk_delta_rule_fwd_kernel_h(\n k,\n v,\n d, \n v_new,\n h,\n initial_state,\n final_state,\n s_qk_h,\n s_qk_t,\n s_qk_d,\n s_vo_h,\n s_vo_t,\n s_vo_d,\n s_h_h,\n s_h_t,\n H: tl.constexpr,\n T: tl.constexpr,\n K: tl.constexpr,\n V: tl.constexpr,\n BT: tl.constexpr,\n BC: tl.constexpr,\n BK: tl.constexpr,\n BV: tl.constexpr,\n NT: tl.constexpr,\n USE_INITIAL_STATE: tl.constexpr,\n STORE_FINAL_STATE: tl.constexpr\n):\n i_k, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n\n b_h = tl.zeros([BK, BV], dtype=tl.float32)\n\n if USE_INITIAL_STATE:\n p_h0 = tl.make_block_ptr(initial_state + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))\n b_h = tl.load(p_h0, boundary_check=(0, 1)).to(tl.float32)\n\n for i_t in range(NT):\n p_h = tl.make_block_ptr(h + i_bh * s_h_h + i_t * K * V, (K, V), (s_h_t, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))\n tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1))\n b_h_cumsum = tl.zeros([BK, BV], dtype=tl.float32)\n for i_c in range(tl.cdiv(BT, BC)):\n p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), (i_k * BK, i_t * BT + i_c * BC), (BK, BC), (0, 1))\n p_d = tl.make_block_ptr(d + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT + i_c * BC, i_k * BK), (BC, BK), (1, 0))\n p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))\n p_v_new = tl.make_block_ptr(v_new + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0)) \n b_k = tl.load(p_k, boundary_check=(0, 1))\n b_d = tl.load(p_d, boundary_check=(0, 1))\n b_v = tl.load(p_v, boundary_check=(0, 1))\n b_v -= tl.dot(b_d, b_h.to(b_k.dtype), allow_tf32=False)\n tl.store(p_v_new, b_v.to(p_v_new.dtype.element_ty), boundary_check=(0, 1))\n b_h_cumsum += tl.dot(b_k, b_v.to(b_k.dtype), allow_tf32=False)\n b_h += b_h_cumsum \n \n if STORE_FINAL_STATE:\n p_ht = tl.make_block_ptr(final_state + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))\n tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), boundary_check=(0, 1))\n\ndef chunk_fwd_h_fn(k, w, u, BT, initial_state, final_state):\n B, H, T, K, V = *k.shape, u.shape[-1]\n\n BK = triton.next_power_of_2(K)\n assert BK <= 256, \"current kernel does not support head dimension larger than 256.\"\n BV = 16 if BK > 128 else 32 \n BV = 64 if BK <= 64 else BV\n BC = 16 if BK > 128 else 32 \n BC = 64 if BK <= 64 else BC\n BC = min(BT, BC)\n NT, NK, NV = triton.cdiv(T, BT), triton.cdiv(K, BK), triton.cdiv(V, BV)\n assert NK == 1, 'NK > 1 is not supported because it involves time-consuming synchronization'\n\n h = k.new_empty(B, H, NT * K, V)\n grid = (NK, NV, B * H)\n v_new = torch.empty_like(u)\n chunk_delta_rule_fwd_kernel_h[grid](\n k, u, w, v_new, h, initial_state, final_state,\n k.stride(1), k.stride(2), k.stride(3),\n u.stride(1), u.stride(2), u.stride(3),\n h.stride(1), h.stride(2),\n H=H, T=T, K=K, V=V, BT=BT, BC=BC, BK=BK, BV=BV, NT=NT,\n USE_INITIAL_STATE=initial_state is not None,\n STORE_FINAL_STATE=final_state is not None,\n )\n return h, v_new\n", - "description_1": "Use triton language to implement a kernel (fwd_prepare_dv_kernel) with parameters q, k, do, dv, s_qk_h, s_qk_t, s_qk_d, s_vo_h, s_vo_t, s_vo_d, T, K, V, scale, BT, BK, BV. The kernel computes dv as the result of matrix multiplications and element-wise operations involving q, k, and do. Another kernel (chunk_delta_rule_fwd_kernel_h) takes k, v, d, v_new, h, initial_state, final_state, and several strides and constants as parameters. It computes a transformation of h and v_new based on k, v, and d, supporting state tracking via initial_state and final_state with specific memory layouts.", - "description_2": "Use triton language to define and call kernels for computing tensor transformations with state tracking and accumulation of results.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.autotune(\n configs=[\n triton.Config({}, num_warps=1),\n triton.Config({}, num_warps=2),\n triton.Config({}, num_warps=4),\n triton.Config({}, num_warps=8)\n ],\n key=[\"BT\", \"BK\"],\n)\n@triton.jit\ndef fused_chunk_delta_rule_fwd_kernel(\n q, k, v, v_new, d, o, initial_state, final_state,\n s_qk_h, s_qk_t, s_qk_d, s_vo_h, s_vo_t, s_vo_d,\n B, H, T, scale,\n BT: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr,\n DK: tl.constexpr, DV: tl.constexpr,\n USE_INITIAL_STATE: tl.constexpr, STORE_FINAL_STATE: tl.constexpr, CHECK: tl.constexpr\n):\n # Kernel implementation...\n\n@triton.autotune(\n configs=[\n triton.Config({}, num_warps=1),\n triton.Config({}, num_warps=2),\n triton.Config({}, num_warps=4),\n triton.Config({}, num_warps=8),\n triton.Config({}, num_warps=16),\n triton.Config({}, num_warps=32),\n ],\n key=[\"BT\", \"BK\", \"BV\"],\n)\n@triton.jit\ndef fused_chunk_delta_rule_bwd_kernel(\n q, k, v, d, do, dq, dk, dv, dd, initial_state,\n s_qk_h, s_qk_t, s_qk_d, s_vo_h, s_vo_t, s_vo_d,\n B, H, T, scale,\n BT: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr,\n DK: tl.constexpr, DV: tl.constexpr,\n USE_INITIAL_STATE: tl.constexpr, CHECK: tl.constexpr\n):\n # Kernel implementation...\n\ndef fused_chunk_delta_rule_fwd(q, k, v, d, BT, initial_state, output_final_state):\n batch_size, n_heads, seq_len, d_head_qk = q.shape\n d_head_v = v.shape[-1]\n scale = d_head_qk ** -0.5\n BT = BT\n BK, BV = triton.next_power_of_2(d_head_qk), min(triton.next_power_of_2(d_head_v), 32)\n NK, NV = triton.cdiv(d_head_qk, BK), triton.cdiv(d_head_v, BV)\n assert NK == 1, 'NK should be 1'\n o = q.new_empty(batch_size, n_heads, seq_len, d_head_v)\n if output_final_state:\n final_state = q.new_empty(batch_size, n_heads, d_head_qk, d_head_v, dtype=torch.float32, requires_grad=False)\n else:\n final_state = None\n CHECK = True\n grid = (NV, NK, batch_size * n_heads)\n v_new = torch.empty_like(v)\n fused_chunk_delta_rule_fwd_kernel[grid](\n q, k, v, v_new, d, o, initial_state, final_state,\n q.stride(1), q.stride(2), q.stride(3),\n v.stride(1), v.stride(2), v.stride(3),\n batch_size, n_heads, seq_len, scale,\n BT=BT, DK=d_head_qk, DV=d_head_v, BK=BK, BV=BV,\n USE_INITIAL_STATE=initial_state is not None,\n STORE_FINAL_STATE=output_final_state,\n CHECK=CHECK,\n )\n return o, v_new, CHECK, final_state\n\ndef fused_chunk_delta_rule_bwd(q, k, v, d, do, BT, CHECK, initial_state):\n batch_size, n_heads, seq_len, d_head_qk = q.shape\n d_head_v = v.shape[-1]\n scale = d_head_qk ** -0.5\n BK, BV = triton.next_power_of_2(d_head_qk), min(triton.next_power_of_2(d_head_v), 32)\n NK, NV = triton.cdiv(d_head_qk, BK), triton.cdiv(d_head_v, BV)\n assert NK == 1\n dq = q.new_empty(NV, batch_size, n_heads, seq_len, d_head_qk)\n dk = q.new_empty(NV, batch_size, n_heads, seq_len, d_head_qk)\n dd = q.new_empty(NV, batch_size, n_heads, seq_len, d_head_qk)\n dv = q.new_empty(NK, batch_size, n_heads, seq_len, d_head_v)\n grid = (NV, NK, batch_size * n_heads)\n fused_chunk_delta_rule_bwd_kernel[grid](\n q, k, v, d, do, dq, dk, dv, dd, initial_state,\n q.stride(1), q.stride(2), q.stride(3),\n v.stride(1), v.stride(2), v.stride(3),\n batch_size, n_heads, seq_len, scale,\n BT=BT, DK=d_head_qk, DV=d_head_v, BK=BK, BV=BV,\n USE_INITIAL_STATE=initial_state is not None,\n CHECK=CHECK,\n )\n dq = dq.sum(0)\n dk = dk.sum(0)\n dv = dv.sum(0)\n dd = dd.sum(0)\n dd[:, :, 0:BT] = 0\n return dq, dk, dv, dd\n", - "description_1": "Use triton language to implement fused_chunk_delta_rule_fwd_kernel with 25 parameters for forward pass calculation of the fused chunk delta rule, considering various input tensors, stride sizes, constants and configuration for block sizes, enabling gradient accumulation. Use fused_chunk_delta_rule_bwd_kernel with 28 parameters for backward pass calculation with similar inputs and additional tensors for gradients, again accounting for stride sizes, constants, and configuration for block sizes, enabling state handling.", - "description_2": "Use triton language to create forward and backward kernels for fused chunk delta rule, managing queries, keys, values, decays, and gradients, with block size configuration and optional state management.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n\n@triton.jit\ndef fused_recurrent_fwd_kernel(\n q, k, v, beta, o, initial_state, final_state,\n s_qk_h, s_qk_t, s_qk_d, s_vo_h, s_vo_t, s_vo_d,\n B, H, T, scale, BK: tl.constexpr, BV: tl.constexpr,\n DK: tl.constexpr, DV: tl.constexpr,\n USE_INITIAL_STATE: tl.constexpr,\n STORE_FINAL_STATE: tl.constexpr,\n):\n # Kernel function for forward pass of recurrent network\n i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n\n p_q = q + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK)\n p_k = k + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK)\n p_v = v + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV)\n p_beta = beta + i_bh * T\n p_o = o + (i_bh + i_k * B * H) * s_vo_h + i_v * BV + tl.arange(0, BV)\n\n mask_bk = (i_k * BK + tl.arange(0, BK)) < DK\n mask_bv = (i_v * BV + tl.arange(0, BV)) < DV\n mask_kv = mask_bk[None, :] & mask_bv[:, None]\n\n h = tl.zeros([BV, BK], dtype=tl.float32)\n\n if USE_INITIAL_STATE:\n p_init_s = initial_state + i_bh * DK * DV + \\\n (i_k * BK + tl.arange(0, BK)[None, :]) * \\\n DV + (i_v * BV + tl.arange(0, BV)[:, None])\n h += tl.load(p_init_s, mask=mask_kv, other=0).to(tl.float32)\n\n for _ in range(0, T):\n _k = tl.load(p_k, mask=mask_bk, other=0).to(tl.float32)\n _v = tl.load(p_v, mask=mask_bv, other=0).to(tl.float32)\n _q = tl.load(p_q, mask=mask_bk, other=0).to(tl.float32) * scale\n _v_minus = tl.sum(h * _k[None, :], axis=1)\n _v -= _v_minus\n _beta = tl.load(p_beta).to(tl.float32)\n tl.store(p_v, _v.to(p_v.dtype.element_ty), mask=mask_bv)\n _v *= _beta\n h += _k[None, :] * _v[:, None]\n _o = h * _q[None, :]\n _o = tl.sum(_o, axis=1)\n tl.store(p_o, _o.to(p_o.dtype.element_ty), mask=mask_bv)\n\n p_q += DK\n p_k += DK\n p_o += DV\n p_v += DV\n p_beta += 1\n\n if STORE_FINAL_STATE:\n p_final_s = final_state + i_bh * DK * DV + \\\n (i_k * BK + tl.arange(0, BK)[None, :]) * \\\n DV + (i_v * BV + tl.arange(0, BV)[:, None])\n tl.store(p_final_s, h.to(p_final_s.dtype.element_ty), mask=mask_kv)\n\n\n@triton.jit\ndef fused_recurrent_bwd_kernel(\n q, k, v, beta, do, dq, dk, dv, dbeta, initial_state,\n s_qk_h, s_qk_t, s_qk_d, s_vo_h, s_vo_t, s_vo_d,\n B, H, T, scale, BK: tl.constexpr, BV: tl.constexpr,\n DK: tl.constexpr, DV: tl.constexpr,\n USE_INITIAL_STATE: tl.constexpr,\n):\n # Kernel function for backward pass of recurrent network\n i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n mask_bk = i_k * BK + tl.arange(0, BK) < DK\n mask_bv = i_v * BV + tl.arange(0, BV) < DV\n\n p_q = q + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (T - 1) * DK\n p_k = k + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (T - 1) * DK\n p_do = do + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV) + (T - 1) * DV\n p_v = v + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV) + (T - 1) * DV\n p_beta = beta + i_bh * T + T - 1\n p_dbeta = dbeta + (i_bh + i_v * B * H) * T + T - 1\n\n p_dk = dk + (i_bh + i_v * B * H) * s_qk_h + i_k * \\\n BK + tl.arange(0, BK) + (T - 1) * DK\n p_dv = dv + (i_bh + i_k * B * H) * s_vo_h + i_v * \\\n BV + tl.arange(0, BV) + (T - 1) * DV\n d_h = tl.zeros([BK, BV], dtype=tl.float32)\n\n for _ in range(T):\n _do = tl.load(p_do, mask=mask_bv, other=0).to(tl.float32)\n _q = tl.load(p_q, mask=mask_bk, other=0).to(tl.float32) * scale\n _k = tl.load(p_k, mask=mask_bk, other=0).to(tl.float32)\n _v = tl.load(p_v, mask=mask_bv, other=0).to(tl.float32)\n _beta = tl.load(p_beta).to(tl.float32)\n d_h += _q[:, None] * _do[None, :]\n d_k = tl.sum(d_h * _v[None, :] * _beta, axis=1)\n d_v = tl.sum(d_h * _k[:, None], axis=0)\n\n d_beta = tl.sum(d_v * _v)\n d_v = d_v * _beta\n\n tl.store(p_dk, d_k.to(p_dk.dtype.element_ty), mask=mask_bk)\n tl.store(p_dv, d_v.to(p_dv.dtype.element_ty), mask=mask_bv)\n tl.store(p_dbeta, d_beta.to(p_dbeta.dtype.element_ty))\n\n d_h -= _k[:, None] * d_v[None, :]\n\n p_do -= DV\n p_q -= DK\n p_k -= DK\n p_v -= DV\n p_dk -= DK\n p_dv -= DV\n p_dbeta -= 1\n p_beta -= 1\n\n tl.debug_barrier()\n\n h = tl.zeros([BK, BV], dtype=tl.float32)\n\n p_q = q + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK)\n p_k = k + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK)\n p_v = v + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV)\n p_beta = beta + i_bh * T\n p_do = do + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV)\n p_dq = dq + (i_bh + i_v * B * H) * s_qk_h + i_k * BK + tl.arange(0, BK)\n p_dv = dv + (i_bh + i_k * B * H) * s_vo_h + i_v * BV + tl.arange(0, BV) + DV\n p_dk = dk + (i_bh + i_v * B * H) * s_qk_h + i_k * BK + tl.arange(0, BK) + DK\n\n if USE_INITIAL_STATE:\n mask_kv = mask_bk[:, None] & mask_bv[None, :]\n p_init_s = initial_state + i_bh * DK * DV + \\\n (i_k * BK + tl.arange(0, BK)[:, None]) * \\\n DV + (i_v * BV + tl.arange(0, BV)[None, :])\n h += tl.load(p_init_s, mask=mask_kv, other=0).to(tl.float32)\n\n for i in range(0, T):\n _k = tl.load(p_k, mask=mask_bk, other=0).to(tl.float32)\n _v = tl.load(p_v, mask=mask_bv, other=0).to(tl.float32)\n _do = tl.load(p_do, mask=mask_bv, other=0).to(tl.float32)\n _beta = tl.load(p_beta).to(tl.float32)\n _v *= _beta\n\n h += _k[:, None] * _v[None, :]\n _d_q = h * _do[None, :]\n d_q = tl.sum(_d_q, axis=1) * scale\n tl.store(p_dq, d_q.to(p_dq.dtype.element_ty), mask=mask_bk)\n\n if i < T - 1:\n d_k = tl.load(p_dk, mask=mask_bk, other=0).to(tl.float32)\n d_v = tl.load(p_dv, mask=mask_bv, other=0).to(tl.float32)\n d_k -= tl.sum(d_v[None, :] * h, axis=1)\n tl.store(p_dk, d_k.to(p_dk.dtype.element_ty), mask=mask_bk)\n\n p_k += DK\n p_do += DV\n p_v += DV\n p_dk += DK\n p_dv += DV\n p_dq += DK\n p_beta += 1\n\n\nclass FusedRecurrentFunction(torch.autograd.Function):\n\n @staticmethod\n def forward(ctx, q, k, v, beta, initial_state=None, output_final_state=False):\n batch_size, n_heads, seq_len, d_head_qk = q.shape\n d_head_v = v.shape[-1]\n\n scale = d_head_qk ** -0.5\n BK, BV = triton.next_power_of_2(d_head_qk), min(triton.next_power_of_2(d_head_v), 8)\n NK, NV = triton.cdiv(d_head_qk, BK), triton.cdiv(d_head_v, BV)\n num_stages = 1\n num_warps = 1\n assert NK == 1, \"NK > 1 is not supported yet\"\n o = q.new_empty(NK, batch_size, n_heads, seq_len, d_head_v)\n\n if output_final_state:\n final_state = q.new_empty(batch_size, n_heads, d_head_qk, d_head_v)\n else:\n final_state = None\n\n grid = (NV, NK, batch_size * n_heads)\n fused_recurrent_fwd_kernel[grid](\n q, k, v, beta, o, initial_state, final_state,\n q.stride(1), q.stride(2), q.stride(3),\n v.stride(1), v.stride(2), v.stride(3),\n batch_size, n_heads, seq_len, scale,\n DK=d_head_qk, DV=d_head_v, BK=BK, BV=BV,\n num_warps=num_warps,\n num_stages=num_stages,\n USE_INITIAL_STATE=initial_state is not None,\n STORE_FINAL_STATE=final_state is not None\n )\n o = o.sum(0)\n ctx.save_for_backward(q, k, v, beta, initial_state)\n return o, final_state\n\n @staticmethod\n def backward(ctx, do, d_final_state=None):\n q, k, v, beta, initial_state = ctx.saved_tensors\n batch_size, n_heads, seq_len, d_head_qk = q.shape\n d_head_v = v.shape[-1]\n scale = d_head_qk ** -0.5\n BK, BV = triton.next_power_of_2(d_head_qk), min(triton.next_power_of_2(d_head_v), 32)\n NK, NV = triton.cdiv(d_head_qk, BK), triton.cdiv(d_head_v, BV)\n assert NK == 1, \"NK > 1 is not supported yet\"\n num_stages = 1\n num_warps = 2\n\n dq = q.new_empty(NV, batch_size, n_heads, seq_len, d_head_qk)\n dk = q.new_empty(NV, batch_size, n_heads, seq_len, d_head_qk)\n dv = q.new_empty(NK, batch_size, n_heads, seq_len, d_head_v)\n grid = (NV, NK, batch_size * n_heads)\n dbeta = q.new_empty(NV, batch_size, n_heads, seq_len)\n\n fused_recurrent_bwd_kernel[grid](\n q, k, v, beta, do, dq, dk, dv, dbeta, initial_state,\n q.stride(1), q.stride(2), q.stride(3),\n v.stride(1), v.stride(2), v.stride(3),\n batch_size, n_heads, seq_len, scale,\n DK=d_head_qk, DV=d_head_v, BK=BK, BV=BV,\n num_warps=num_warps,\n num_stages=num_stages,\n USE_INITIAL_STATE=initial_state is not None\n )\n dq = dq.sum(0)\n dk = dk.sum(0)\n dv = dv.sum(0)\n dbeta = dbeta.sum(0)\n return dq.to(q), dk.to(k), dv.to(v), dbeta.to(beta), None, None\n\n\ndef fused_recurrent_linear_attn_delta_rule(\n q: torch.Tensor,\n k: torch.Tensor,\n v: torch.Tensor,\n beta: torch.Tensor = None,\n initial_state: torch.Tensor = None,\n output_final_state: bool = False,\n normalize: bool = False\n) -> Tuple[torch.Tensor, torch.Tensor]:\n if initial_state is not None:\n initial_state = initial_state.detach()\n if beta is None:\n beta = torch.ones_like(q[..., 0])\n o, final_state = FusedRecurrentFunction.apply(q, k, v, beta, initial_state, output_final_state)\n return o, final_state\n", - "description_1": "Use triton language to implement fused recurrent forward and backward kernels for a recurrent network with parameters such as query, key, value tensors, beta tensor for scaling, output and hidden state initialization, along with strides and dimensions for batch size, number of heads, sequence length, scaling factor, block sizes, dimensions of head, and constants for initial and final state usage. The forward kernel computes the weighted sum of queries and keys and modifies the value tensor in place, while the backward kernel computes gradients for the query, key, value, and beta tensors. The function 'FusedRecurrentFunction' calls these kernels with appropriate settings and is used in the 'fused_recurrent_linear_attn_delta_rule' to return the output tensor and optionally the final state tensor.", - "description_2": "Use triton language to create forward and backward kernels for a fused recurrent network, handling query, key, value tensors with scaling and state management, for efficient computation and gradient calculation.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.autotune(\n configs=[\n triton.Config({}, num_warps=1),\n triton.Config({}, num_warps=2),\n triton.Config({}, num_warps=4),\n triton.Config({}, num_warps=8),\n triton.Config({}, num_warps=16),\n triton.Config({}, num_warps=32),\n ],\n key=[\"BT\", \"BK\", \"BV\"],\n)\n@triton.jit\ndef fwd_prepare_wy_repr_kernel(\n k,\n v,\n beta,\n o,\n o2,\n T,\n K,\n V,\n BT: tl.constexpr,\n BK: tl.constexpr,\n BV: tl.constexpr\n):\n i_t, i_bh = tl.program_id(0), tl.program_id(1)\n\n p_k = k + i_bh * T * K + (i_t * BT + tl.arange(0, BT)[:, None]) * K + tl.arange(0, BK)[None, :]\n p_v = v + i_bh * T * V + (i_t * BT + tl.arange(0, BT)[:, None]) * V + tl.arange(0, BV)[None, :]\n p_beta = beta + i_bh * T + i_t * BT + tl.arange(0, BT)\n mask_bt = (tl.arange(0, BT) + i_t * BT) < T\n mask_bk = tl.arange(0, BK) < K\n mask_bv = tl.arange(0, BV) < V\n mask_bk = mask_bk[None, :] & mask_bt[:, None]\n mask_bv = mask_bv[None, :] & mask_bt[:, None]\n # [BT, BK]\n b_k = tl.load(p_k, mask=mask_bk, other=0)\n # [BT,]\n b_beta = tl.load(p_beta, mask=mask_bt, other=0).to(tl.float32)\n # [BT, BV]\n b_v = tl.load(p_v, mask=mask_bv, other=0)\n b_v = (b_v * b_beta[:, None]).to(b_v.dtype)\n # [BT, BK]\n b_kb = (b_k * b_beta[:, None]).to(b_k.dtype)\n # [BT, BT]\n b_A = tl.dot(b_kb, tl.trans(b_k), allow_tf32=False)\n b_A = -tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], b_A, 0)\n\n for i in range(BT):\n mask = tl.arange(0, BT) == i\n b_a = tl.sum(tl.where(mask[:, None], b_A, 0), 0)\n b_a = b_a + tl.sum(b_a[:, None] * b_A, 0) * (tl.arange(0, BT) < i)\n b_A = tl.where(mask[:, None], b_a, b_A)\n b_A += tl.arange(0, BT)[:, None] == tl.arange(0, BT)[None, :]\n b_A = b_A.to(b_k.dtype)\n b_w = tl.dot(b_A, b_kb, allow_tf32=False)\n b_u = tl.dot(b_A, b_v, allow_tf32=False)\n\n p_o = o + i_bh * T * K + (i_t * BT + tl.arange(0, BT)[:, None]) * K + tl.arange(0, BK)[None, :]\n tl.store(p_o, b_w.to(p_o.dtype.element_ty), mask=mask_bk)\n p_o2 = o2 + i_bh * T * V + (i_t * BT + tl.arange(0, BT)[:, None]) * V + tl.arange(0, BV)[None, :]\n tl.store(p_o2, b_u.to(p_o2.dtype.element_ty), mask=mask_bv)\n\n\n@triton.autotune(\n configs=[\n triton.Config({}, num_warps=1),\n triton.Config({}, num_warps=2),\n triton.Config({}, num_warps=4),\n triton.Config({}, num_warps=8),\n triton.Config({}, num_warps=16),\n triton.Config({}, num_warps=32),\n ],\n key=[\"BT\", \"BK\", \"BV\"],\n)\n@triton.jit\ndef bwd_prepare_wy_repr_kernel(\n k, v, beta,\n o, o2, do, do2,\n dk, dv, dbeta,\n NT, K, V, T,\n BT: tl.constexpr,\n BK: tl.constexpr,\n BV: tl.constexpr,\n):\n i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n p_k = k + i_bh * T * K + (i_t * BT + tl.arange(0, BT)[:, None]) * K + tl.arange(0, BK)[None, :]\n p_do = do + i_bh * T * K + (i_t * BT + tl.arange(0, BT)[:, None]) * K + tl.arange(0, BK)[None, :]\n p_do2 = do2 + i_bh * T * V + (i_t * BT + tl.arange(0, BT)[:, None]) * V + tl.arange(0, BV)[None, :]\n\n p_beta = beta + i_bh * T + i_t * BT + tl.arange(0, BT)\n mask_bt = (tl.arange(0, BT) + i_t * BT) < T\n mask_bk = (tl.arange(0, BK) < K)[None, :] & mask_bt[:, None]\n mask_bv = (tl.arange(0, BV) < V)[None, :] & mask_bt[:, None]\n b_k, b_beta = tl.load(p_k, mask=mask_bk), tl.load(p_beta, mask=mask_bt)\n\n b_beta = b_beta.to(tl.float32)\n A = tl.dot(b_k, tl.trans(b_k), allow_tf32=False) * b_beta[:, None]\n A = tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], A, 0)\n b_do = tl.load(p_do, mask=mask_bk).to(tl.float32)\n b_dv = tl.load(p_do2, mask=mask_bv).to(tl.float32)\n dA = tl.zeros([BT, BT], dtype=tl.float32)\n b_dk = tl.zeros([BT, BK], dtype=tl.float32)\n for i in range(BT-1, -1, -1):\n mask = tl.arange(0, BT) == i\n attn = tl.sum(tl.where(mask[:, None], A, 0), axis=0)\n do_ = tl.sum(tl.where(mask[:, None], b_do, 0), axis=0)\n dv_ = tl.sum(tl.where(mask[:, None], b_dv, 0), axis=0)\n b_do = b_do - attn[:, None] * do_[None, :]\n b_dv = b_dv - attn[:, None] * dv_[None, :]\n tl.debug_barrier()\n p_v = v + i_bh * T * V + (i_t * BT + tl.arange(0, BT)[:, None]) * V + tl.arange(0, BV)[None, :]\n b_v = tl.load(p_v, mask=mask_bv)\n b_dk += b_do * b_beta[:, None]\n b_dbeta = tl.sum(b_do * b_k, axis=1)\n b_dbeta += tl.sum(b_dv * b_v, axis=1)\n b_v = None\n\n p_o = o + i_bh * T * K + (i_t * BT + tl.arange(0, BT)[:, None]) * K + tl.arange(0, BK)[None, :]\n p_o2 = o2 + i_bh * T * V + (i_t * BT + tl.arange(0, BT)[:, None]) * V + tl.arange(0, BV)[None, :]\n b_o = tl.load(p_o, mask=mask_bk)\n b_o2 = tl.load(p_o2, mask=mask_bv)\n\n dA = -tl.dot(b_do.to(b_o.dtype), tl.trans(b_o), allow_tf32=False)\n dA -= tl.dot(b_dv.to(b_o2.dtype), tl.trans(b_o2).to(b_o.dtype),\n allow_tf32=False)\n dA = tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], dA, 0)\n b_dv *= b_beta[:, None]\n p_dv = dv + i_bh * T * V + (i_t * BT + tl.arange(0, BT)[:, None]) * V + tl.arange(0, BV)[None, :]\n tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), mask=mask_bv)\n\n b_dbeta += tl.sum(dA * tl.dot(b_k, tl.trans(b_k), allow_tf32=False), axis=1)\n dA = dA * b_beta[:, None]\n b_dk += tl.dot(tl.trans(dA.to(b_k.dtype)), b_k, allow_tf32=False)\n b_dk += tl.dot(dA.to(b_k.dtype), b_k, allow_tf32=False)\n p_dk = dk + i_bh * T * K + (i_t * BT + tl.arange(0, BT)[:, None]) * K + tl.arange(0, BK)[None, :]\n tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), mask=mask_bk)\n p_dbeta = dbeta + i_bh * T + i_t * BT + tl.arange(0, BT)\n tl.store(p_dbeta, b_dbeta.to(p_dbeta.dtype.element_ty), mask=mask_bt)\n\n\ndef fwd_prepare_wy_repr(k, v, beta, chunk_size):\n B, H, T, K, V = *k.shape, v.shape[-1]\n v_new = torch.empty_like(v)\n o_cumdecay = torch.empty_like(k)\n BT = chunk_size\n NT = triton.cdiv(T, BT)\n BK = triton.next_power_of_2(K)\n BV = triton.next_power_of_2(V)\n fwd_prepare_wy_repr_kernel[(NT, B*H)](\n k, v, beta, o_cumdecay, v_new,\n T, K, V, BT, BK, BV\n )\n return o_cumdecay, v_new\n\n\ndef bwd_prepare_wy_repr(k, v, beta, o_cumdecay, v_new, do, do2, chunk_size):\n b, h, l, d_k = do.shape\n d_v = v.shape[-1]\n BK = triton.next_power_of_2(d_k)\n BV = triton.next_power_of_2(d_v)\n c = chunk_size\n BK = d_k\n NT = triton.cdiv(l, c)\n dk = torch.empty_like(k)\n dv = torch.empty_like(v)\n dbeta = torch.zeros_like(beta)\n bwd_prepare_wy_repr_kernel[(NT, b*h)](\n k, v, beta,\n o_cumdecay, v_new, do, do2,\n dk, dv, dbeta,\n NT, d_k, d_v, l, chunk_size, BK, BV\n )\n return dk, dv, dbeta\n\nclass WYRepresentationPrepration(torch.autograd.Function):\n @staticmethod\n def forward(ctx, k, v, beta, chunk_size):\n o_cumdecay, v_new = fwd_prepare_wy_repr(k, v, beta, chunk_size)\n ctx.chunk_size = chunk_size\n ctx.save_for_backward(k.to(v), v, beta, o_cumdecay, v_new)\n return o_cumdecay, v_new\n\n @staticmethod\n def backward(ctx, do, do2):\n k, v, beta, o_cumdecay, v_new = ctx.saved_tensors\n dk, dv, dbeta = bwd_prepare_wy_repr(k, v, beta, o_cumdecay, v_new, do, do2, ctx.chunk_size)\n return dk, dv, dbeta, None\n\nprepare_wy_repr = WYRepresentationPrepration.apply\n", - "description_1": "Use triton language to implement two kernels: fwd_prepare_wy_repr_kernel and bwd_prepare_wy_repr_kernel. The fwd_prepare_wy_repr_kernel takes 10 parameters: k, v, beta, o, o2, T, K, V, BT, BK, BV. It computes the forward pass of the WY representation preparation. The bwd_prepare_wy_repr_kernel takes 15 parameters: k, v, beta, o, o2, do, do2, dk, dv, dbeta, NT, K, V, T, BT, BK, BV. It computes the backward pass of the WY representation preparation. Both kernels are used in the functions fwd_prepare_wy_repr and bwd_prepare_wy_repr, which are called in the WYRepresentationPrepration class.", - "description_2": "Use triton language to create forward and backward kernels for WY representation preparation, and integrate them into a PyTorch autograd function.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.autotune(\n configs=[\n triton.Config({}, num_warps=1),\n triton.Config({}, num_warps=2),\n triton.Config({}, num_warps=4),\n triton.Config({}, num_warps=8),\n triton.Config({}, num_warps=16),\n triton.Config({}, num_warps=32),\n ],\n key=[\"BT\", \"BK\", \"BV\"], \n)\n@triton.jit\ndef fwd_prepare_wy_repr_kernel(\n k,\n v,\n beta,\n w, \n u,\n A, \n s_qk_h,\n s_qk_t,\n s_qk_d,\n s_vo_h,\n s_vo_t,\n s_vo_d,\n T,\n K,\n V,\n BT: tl.constexpr,\n BK: tl.constexpr,\n BV: tl.constexpr\n):\n i_t, i_bh = tl.program_id(0), tl.program_id(1)\n \n b_A = tl.zeros([BT, BT], dtype=tl.float32)\n p_beta = tl.make_block_ptr(beta + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,))\n b_beta = tl.load(p_beta, boundary_check=(0,))\n\n for i_k in range(tl.cdiv(K, BK)):\n p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))\n b_k = tl.load(p_k, boundary_check=(0, 1))\n b_kb = (b_k * b_beta[:, None]).to(b_k.dtype)\n b_A += tl.dot(b_kb, tl.trans(b_k), allow_tf32=False)\n\n b_A = -tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], b_A, 0)\n\n for i in range(1, BT):\n mask = tl.arange(0, BT) == i\n b_a = tl.sum(tl.where(mask[:, None], b_A, 0), 0)\n b_a = b_a + tl.sum(b_a[:, None] * b_A, 0) * (tl.arange(0, BT) < i)\n b_A = tl.where(mask[:, None], b_a, b_A)\n\n b_A += tl.arange(0, BT)[:, None] == tl.arange(0, BT)[None, :]\n\n p_A = tl.make_block_ptr(A + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))\n tl.store(p_A, (b_A).to(p_A.dtype.element_ty), boundary_check=(0, 1))\n b_A = b_A.to(k.dtype.element_ty)\n\n for i_v in range(tl.cdiv(V, BV)):\n p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))\n b_v = tl.load(p_v, boundary_check=(0, 1))\n b_vb = (b_v * b_beta[:, None]).to(b_v.dtype)\n b_u = tl.dot(b_A, b_vb, allow_tf32=False)\n p_u = tl.make_block_ptr(u + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))\n tl.store(p_u, (b_u).to(p_u.dtype.element_ty), boundary_check=(0, 1))\n\n for i_k in range(tl.cdiv(K, BK)):\n p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))\n b_k = tl.load(p_k, boundary_check=(0, 1))\n b_kb = (b_k * b_beta[:, None]).to(b_k.dtype)\n b_w = tl.dot(b_A, b_kb, allow_tf32=False)\n p_w = tl.make_block_ptr(w + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))\n tl.store(p_w, b_w.to(p_w.dtype.element_ty), boundary_check=(0, 1))\n\n\n@triton.autotune(\n configs=[\n triton.Config({}, num_warps=1),\n triton.Config({}, num_warps=2),\n triton.Config({}, num_warps=4),\n triton.Config({}, num_warps=8),\n triton.Config({}, num_warps=16),\n triton.Config({}, num_warps=32),\n ],\n key=[\"BT\", \"BK\", \"BV\"], \n)\n@triton.jit\ndef fwd_recompute_w_u_kernel(\n k,\n v,\n beta,\n w, \n u,\n A, \n s_qk_h,\n s_qk_t,\n s_qk_d,\n s_vo_h,\n s_vo_t,\n s_vo_d,\n T,\n K,\n V,\n BT: tl.constexpr,\n BK: tl.constexpr,\n BV: tl.constexpr\n):\n i_t, i_bh = tl.program_id(0), tl.program_id(1)\n \n p_beta = tl.make_block_ptr(beta + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,))\n b_beta = tl.load(p_beta, boundary_check=(0,))\n\n p_A = tl.make_block_ptr(A + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))\n b_A = tl.load(p_A, boundary_check=(0, 1)).to(k.dtype.element_ty)\n\n for i_v in range(tl.cdiv(V, BV)):\n p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))\n b_v = tl.load(p_v, boundary_check=(0, 1))\n b_vb = (b_v * b_beta[:, None]).to(b_v.dtype)\n b_u = tl.dot(b_A, b_vb, allow_tf32=False)\n p_u = tl.make_block_ptr(u + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))\n tl.store(p_u, (b_u).to(p_u.dtype.element_ty), boundary_check=(0, 1))\n\n for i_k in range(tl.cdiv(K, BK)):\n p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))\n b_k = tl.load(p_k, boundary_check=(0, 1))\n b_kb = (b_k * b_beta[:, None]).to(b_k.dtype)\n b_w = tl.dot(b_A, b_kb, allow_tf32=False)\n p_w = tl.make_block_ptr(w + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))\n tl.store(p_w, b_w.to(p_w.dtype.element_ty), boundary_check=(0, 1))\n\n\n@triton.autotune(\n configs=[\n triton.Config({}, num_warps=1),\n triton.Config({}, num_warps=2),\n triton.Config({}, num_warps=4),\n triton.Config({}, num_warps=8),\n triton.Config({}, num_warps=16),\n triton.Config({}, num_warps=32),\n ],\n key=[\"BT\", \"BK\", \"BV\"],\n)\n@triton.jit\ndef bwd_prepare_wy_repr_kernel(\n k, v, beta, A, \n dw, du,\n dk, dv, dbeta,\n s_qk_h,\n s_qk_t,\n s_qk_d,\n s_vo_h,\n s_vo_t,\n s_vo_d,\n T,\n K,\n V,\n BT: tl.constexpr,\n BK: tl.constexpr,\n BV: tl.constexpr\n):\n i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n p_A = tl.make_block_ptr(A + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))\n b_A = tl.load(p_A, boundary_check=(0, 1)).to(k.dtype.element_ty)\n\n b_dbeta = tl.zeros([BT], dtype=tl.float32)\n b_dA = tl.zeros([BT, BT], dtype=tl.float32)\n p_beta = tl.make_block_ptr(beta + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,))\n b_beta = tl.load(p_beta, boundary_check=(0,))\n\n for i_v in range(tl.cdiv(V, BV)):\n p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))\n p_du = tl.make_block_ptr(du + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))\n b_v = tl.load(p_v, boundary_check=(0, 1))\n b_v_beta = (b_v * b_beta[:, None]).to(b_v.dtype)\n b_du = tl.load(p_du, boundary_check=(0, 1))\n b_dA += tl.dot(b_du, tl.trans(b_v_beta), allow_tf32=False)\n b_dv_beta = tl.dot(tl.trans(b_A), b_du, allow_tf32=False)\n b_dv = b_dv_beta * b_beta[:, None]\n b_dbeta += tl.sum(b_dv_beta * b_v, 1)\n # store\n p_dv = tl.make_block_ptr(dv + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))\n tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))\n\n tl.debug_barrier() \n b_A2 = tl.zeros([BT, BT], dtype=tl.float32)\n for i_k in range(tl.cdiv(K, BK)):\n p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))\n p_dw = tl.make_block_ptr(dw + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))\n b_k = tl.load(p_k, boundary_check=(0, 1)) \n b_k_beta = (b_k * b_beta[:, None]).to(b_k.dtype)\n b_dw = tl.load(p_dw, boundary_check=(0, 1))\n b_dA += tl.dot(b_dw, tl.trans(b_k_beta), allow_tf32=False) \n b_A2 += tl.dot(b_k_beta, tl.trans(b_k), allow_tf32=False)\n b_dk_beta = tl.dot(tl.trans(b_A), b_dw, allow_tf32=False)\n b_dk = b_dk_beta * b_beta[:, None]\n b_dbeta += tl.sum(b_dk_beta * b_k, 1)\n # store \n p_dk = tl.make_block_ptr(dk + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))\n tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))\n\n b_A -= (tl.arange(0, BT)[:, None] == tl.arange(0, BT)[None, :])\n b_A2 = tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], -b_A2, 0)\n b_dA = tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], b_dA, 0)\n tl.debug_barrier()\n\n for i in range(BT-1, 0, -1):\n mask = tl.arange(0, BT) == i\n b_da = tl.sum(tl.where(mask[:, None], b_dA, 0), 0) \n b_a = tl.sum(tl.where(mask[:, None], b_A2, 0), 0) \n b_da2 = b_da + tl.sum(b_da[None, :] * b_A, 1) \n b_dA = tl.where(mask[:, None], b_da2, b_dA)\n b_dA += b_da[None, :] * b_a[:, None]\n\n b_dA = tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], -b_dA, 0).to(k.dtype.element_ty)\n tl.debug_barrier()\n\n for i_k in range(tl.cdiv(K, BK)):\n p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))\n p_dk = tl.make_block_ptr(dk + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))\n b_k = tl.load(p_k, boundary_check=(0, 1)) \n b_dk = tl.load(p_dk, boundary_check=(0, 1))\n b_k_beta = (b_k * b_beta[:, None]).to(b_k.dtype)\n\n b_dk_beta = tl.dot(b_dA, b_k, allow_tf32=False)\n b_dbeta += tl.sum(b_dk_beta * b_k, 1)\n b_dk += tl.dot(tl.trans(b_dA), b_k_beta, allow_tf32=False) \n b_dk += b_dk_beta * b_beta[:, None] \n tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))\n \n p_dbeta = tl.make_block_ptr(dbeta + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,))\n tl.store(p_dbeta, b_dbeta.to(p_dbeta.dtype.element_ty),boundary_check=(0,))\n\n\ndef fwd_prepare_wy_repr(k, v, beta, BT):\n B, H, T, K, V = *k.shape, v.shape[-1]\n u = torch.empty_like(v)\n w = torch.empty_like(k)\n NT = triton.cdiv(T, BT)\n BK = min(triton.next_power_of_2(K), 64)\n BV = min(triton.next_power_of_2(V), 64)\n A = torch.empty(B, H, T, BT, device=k.device, dtype=k.dtype)\n fwd_prepare_wy_repr_kernel[(NT, B*H)](\n k, v, beta, w, u, A,\n k.stride(1), k.stride(2), k.stride(3), \n v.stride(1), v.stride(2), v.stride(3),\n T, K, V, BT, BK, BV\n )\n return w, u, A\n\n\ndef fwd_recompute_w_u(k, v, beta, A, BT):\n B, H, T, K, V = *k.shape, v.shape[-1]\n u = torch.empty_like(v)\n w = torch.empty_like(k)\n NT = triton.cdiv(T, BT)\n BK = min(triton.next_power_of_2(K), 64)\n BV = min(triton.next_power_of_2(V), 64)\n fwd_recompute_w_u_kernel[(NT, B*H)](\n k, v, beta, w, u, A,\n k.stride(1), k.stride(2), k.stride(3), \n v.stride(1), v.stride(2), v.stride(3),\n T, K, V, BT, BK, BV\n )\n return w, u\n\n\ndef bwd_prepare_wy_repr(k, v, beta, A, dw, du, BT):\n B, H, T, K, V = *k.shape, v.shape[-1]\n\n NT = triton.cdiv(T, BT)\n BK = min(triton.next_power_of_2(K), 64)\n BV = min(triton.next_power_of_2(V), 64)\n NT = triton.cdiv(T, BT)\n dk = torch.empty_like(k)\n dv = torch.empty_like(v).contiguous()\n dbeta = torch.zeros_like(beta)\n\n bwd_prepare_wy_repr_kernel[(NT, B*H)](\n k, v, beta, A,\n dw, du, \n dk, dv, dbeta,\n k.stride(1), k.stride(2), k.stride(3), \n v.stride(1), v.stride(2), v.stride(3),\n T, K, V, BT, BK, BV\n )\n return dk, dv, dbeta\n\n\nclass WYRepresentationPrepration(torch.autograd.Function):\n @staticmethod\n def forward(ctx, k, v, beta, chunk_size):\n ctx.BT = chunk_size\n w, u, A = fwd_prepare_wy_repr(k, v, beta, ctx.BT)\n ctx.save_for_backward(k, v, beta, A)\n return w, u\n\n @staticmethod\n def backward(ctx, dw, du):\n k, v, beta, A = ctx.saved_tensors\n BT = ctx.BT\n dk, dv, dbeta = bwd_prepare_wy_repr(k, v, beta, A, dw, du, BT)\n return dk, dv, dbeta, None\n\n\nprepare_wy_repr = WYRepresentationPrepration.apply\n", - "description_1": "Use triton language to implement three kernels: fwd_prepare_wy_repr_kernel, fwd_recompute_w_u_kernel, and bwd_prepare_wy_repr_kernel. Each kernel is decorated with @triton.jit and performs matrix operations on input tensors k, v, beta, and others. The kernels are used in functions fwd_prepare_wy_repr, fwd_recompute_w_u, and bwd_prepare_wy_repr, which prepare and recompute matrices for WY representation and its backward pass. The kernels handle block-wise operations and use triton's block pointers and dot products for efficient computation.", - "description_2": "Use triton language to create kernels for forward and backward WY representation preparation, utilizing block-wise matrix operations and triton's efficient computation features.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n\n@triton.jit\ndef chunk_gla_fwd_kernel_cum(\n s,\n o,\n s_s_h,\n s_s_t,\n s_s_d,\n T: tl.constexpr,\n S: tl.constexpr,\n BT: tl.constexpr,\n BS: tl.constexpr\n):\n i_s, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n o_i = tl.arange(0, BT)\n m_s = tl.where(o_i[:, None] >= o_i[None, :], 1., 0.)\n\n p_s = tl.make_block_ptr(s + i_bh * s_s_h, (T, S), (s_s_t, s_s_d), (i_t * BT, i_s * BS), (BT, BS), (1, 0))\n p_o = tl.make_block_ptr(o + i_bh * s_s_h, (T, S), (s_s_t, s_s_d), (i_t * BT, i_s * BS), (BT, BS), (1, 0))\n b_s = tl.load(p_s, boundary_check=(0, 1)).to(tl.float32)\n b_o = tl.dot(m_s, b_s, allow_tf32=False)\n tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))\n\n\n@triton.jit\ndef chunk_gla_fwd_kernel_h(\n k,\n v,\n g,\n h,\n h0,\n ht,\n s_k_h,\n s_k_t,\n s_k_d,\n s_v_h,\n s_v_t,\n s_v_d,\n s_h_h,\n s_h_t,\n s_h_d,\n T: tl.constexpr,\n K: tl.constexpr,\n V: tl.constexpr,\n BT: tl.constexpr,\n BK: tl.constexpr,\n BV: tl.constexpr,\n NT: tl.constexpr,\n USE_INITIAL_STATE: tl.constexpr,\n STORE_FINAL_STATE: tl.constexpr\n):\n i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n b_h = tl.zeros([BK, BV], dtype=tl.float32)\n if USE_INITIAL_STATE:\n p_h = tl.make_block_ptr(h0 + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))\n b_h += tl.load(p_h, boundary_check=(0, 1)).to(tl.float32)\n for i_t in range(NT):\n p_k = tl.make_block_ptr(k + i_bh * s_k_h, (K, T), (s_k_d, s_k_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1))\n p_v = tl.make_block_ptr(v + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))\n p_h = tl.make_block_ptr(h + i_bh * s_h_h + i_t * K * V, (K, V), (s_h_t, s_h_d), (i_k * BK, i_v * BV), (BK, BV), (1, 0))\n p_g = tl.make_block_ptr(g + i_bh * s_k_h, (K, T), (s_k_d, s_k_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1))\n p_gn = tl.make_block_ptr(g + i_bh * s_k_h, (T * K,), (s_k_d,), ((i_t * BT + BT - 1) * K + i_k * BK,), (BK,), (0,))\n\n tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1))\n b_k = tl.load(p_k, boundary_check=(0, 1))\n b_v = tl.load(p_v, boundary_check=(0, 1))\n b_g = tl.load(p_g, boundary_check=(0, 1))\n if i_t < NT - 1:\n b_gn = tl.load(p_gn, boundary_check=(0,))\n else:\n b_gn = tl.min(b_g, axis=1)\n b_h *= tl.exp(b_gn)[:, None]\n b_k = (b_k * tl.exp(b_gn[:, None] - b_g)).to(b_k.dtype)\n b_h += tl.dot(b_k, b_v, allow_tf32=False)\n\n if STORE_FINAL_STATE:\n p_h = tl.make_block_ptr(ht + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))\n tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1))\n\n\n@triton.jit\ndef chunk_gla_fwd_kernel_intra(\n q,\n k,\n g,\n A,\n s_k_h,\n s_k_t,\n s_k_d,\n scale,\n T: tl.constexpr,\n K: tl.constexpr,\n BT: tl.constexpr,\n BC: tl.constexpr,\n BK: tl.constexpr,\n NC: tl.constexpr\n):\n i_k, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n i_t, i_i, i_j = i_c // (NC * NC), (i_c % (NC * NC)) // NC, (i_c % (NC * NC)) % NC\n n_bh = tl.num_programs(2)\n\n if i_i > i_j:\n p_q = tl.make_block_ptr(q + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))\n p_g = tl.make_block_ptr(g + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))\n p_k = tl.make_block_ptr(k + i_bh * s_k_h, (K, T), (s_k_d, s_k_t), (i_k * BK, i_t * BT + i_j * BC), (BK, BC), (0, 1))\n p_gk = tl.make_block_ptr(g + i_bh * s_k_h, (K, T), (s_k_d, s_k_t), (i_k * BK, i_t * BT + i_j * BC), (BK, BC), (0, 1))\n p_gn = tl.make_block_ptr(g + i_bh * s_k_h, (T * K,), (s_k_d,), ((i_t * BT + i_i * BC) * K + i_k * BK,), (BK,), (0,))\n p_A = tl.make_block_ptr(A + (i_k*n_bh+i_bh)*T*BT, (T, BT), (BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0))\n b_gn = tl.load(p_gn, boundary_check=(0,))\n b_q = tl.load(p_q, boundary_check=(0, 1))\n b_g = tl.load(p_g, boundary_check=(0, 1))\n b_qg = (b_q * tl.exp(b_g - b_gn[None, :]) * scale).to(b_q.dtype)\n b_k = tl.load(p_k, boundary_check=(0, 1))\n b_gk = tl.load(p_gk, boundary_check=(0, 1))\n b_kg = (b_k * tl.exp(b_gn[:, None] - b_gk)).to(b_k.dtype)\n b_A = tl.dot(b_qg, b_kg, allow_tf32=False)\n tl.store(p_A, b_A.to(A.dtype.element_ty), boundary_check=(0, 1))\n elif i_i == i_j:\n p_q = tl.make_block_ptr(q + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))\n p_g = tl.make_block_ptr(g + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))\n p_k = tl.make_block_ptr(k + i_bh * s_k_h, (T * K,), (s_k_d,), ((i_t * BT + i_j * BC) * K + i_k * BK,), (BK,), (0,))\n p_gk = tl.make_block_ptr(g + i_bh * s_k_h, (T * K,), (s_k_d,), ((i_t * BT + i_j * BC) * K + i_k * BK,), (BK,), (0,))\n b_q = tl.load(p_q, boundary_check=(0, 1))\n b_g = tl.load(p_g, boundary_check=(0, 1))\n\n o_i = tl.arange(0, BC)\n o_A = (i_bh + i_k * n_bh) * T * BT + (i_t * BT + i_i * BC + tl.arange(0, BC)) * BT + i_j * BC\n m_A = (i_t * BT + i_i * BC + tl.arange(0, BC)) < T\n for j in range(0, BC):\n b_k = tl.load(p_k, boundary_check=(0,)).to(tl.float32)\n b_gk = tl.load(p_gk, boundary_check=(0,)).to(tl.float32)\n b_A = tl.sum(b_q * b_k[None, :] * tl.exp(b_g - b_gk[None, :]) * scale, 1)\n b_A = tl.where(o_i >= j, b_A, 0.)\n tl.store(A + o_A + j, b_A.to(b_q.dtype), mask=m_A)\n\n p_k = tl.advance(p_k, (K,))\n p_gk = tl.advance(p_gk, (K,))\n\n\n@triton.jit\ndef chunk_gla_fwd_kernel_inter(\n q,\n v,\n g,\n h,\n o,\n A,\n s_k_h,\n s_k_t,\n s_k_d,\n s_v_h,\n s_v_t,\n s_v_d,\n s_h_h,\n s_h_t,\n s_h_d,\n scale,\n T: tl.constexpr,\n K: tl.constexpr,\n V: tl.constexpr,\n BT: tl.constexpr,\n BK: tl.constexpr,\n BV: tl.constexpr\n):\n i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n\n b_o = tl.zeros([BT, BV], dtype=tl.float32)\n for i_k in range(tl.cdiv(K, BK)):\n p_q = tl.make_block_ptr(q + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))\n p_g = tl.make_block_ptr(g + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))\n p_h = tl.make_block_ptr(h + i_bh * s_h_h + i_t * K * V, (K, V), (s_h_t, s_h_d), (i_k * BK, i_v * BV), (BK, BV), (1, 0))\n\n b_q = tl.load(p_q, boundary_check=(0, 1))\n b_q = (b_q * scale).to(b_q.dtype)\n b_g = tl.load(p_g, boundary_check=(0, 1))\n b_qg = (b_q * tl.exp(b_g)).to(b_q.dtype)\n b_h = tl.load(p_h, boundary_check=(0, 1))\n b_o += tl.dot(b_qg, b_h, allow_tf32=False)\n p_v = tl.make_block_ptr(v + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))\n p_o = tl.make_block_ptr(o + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))\n p_A = tl.make_block_ptr(A + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))\n b_v = tl.load(p_v, boundary_check=(0, 1))\n b_A = tl.load(p_A, boundary_check=(0, 1))\n b_o += tl.dot(b_A, b_v, allow_tf32=False)\n tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))\n\n\nclass ChunkGLAFunction(torch.autograd.Function):\n\n @staticmethod\n def forward(ctx, q, k, v, g, scale, initial_state, output_final_state, checkpoint_level):\n B, H, T, K, V = *q.shape, v.shape[-1]\n BT, BC = 64, 16\n BK = min(64, triton.next_power_of_2(K))\n BV = min(64, triton.next_power_of_2(V))\n NT, NC = triton.cdiv(T, BT), triton.cdiv(BT, BC)\n NK = triton.cdiv(K, BK)\n NV = triton.cdiv(V, BV)\n num_warps = 4 if BK == 64 else 2\n num_stages = 1\n\n def fwd_inner(q, k, v, g, B, H, T, K, V, BT, BK, BV, NT, h0=None, ht=None):\n NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)\n h = q.new_empty(B, H, NT * K, V)\n grid = (NV, NK, B * H)\n chunk_gla_fwd_kernel_h[grid](\n k, v, g, h, h0, ht,\n k.stride(1), k.stride(2), k.stride(3),\n v.stride(1), v.stride(2), v.stride(3),\n h.stride(1), h.stride(2), h.stride(3),\n T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT,\n USE_INITIAL_STATE=h0 is not None,\n STORE_FINAL_STATE=ht is not None,\n num_warps=num_warps,\n num_stages=num_stages\n )\n return h\n\n final_state = None\n if output_final_state:\n final_state = q.new_empty(B, H, K, V, dtype=torch.float)\n\n g_org, g = g, torch.empty_like(g, dtype=torch.float)\n def grid(meta): return ((triton.cdiv(meta['S'], meta['BS']), NT, B * H))\n chunk_gla_fwd_kernel_cum[grid](\n g_org, g,\n g.stride(1), g.stride(2), g.stride(3),\n T=T, S=K, BT=BT\n )\n h = fwd_inner(\n q=q, k=k, v=v, g=g,\n B=B, H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT,\n h0=initial_state if initial_state is not None else None,\n ht=final_state if final_state is not None else None\n )\n A = q.new_zeros(NK, B, H, T, BT)\n grid = (NK, NT * NC * NC, B * H)\n chunk_gla_fwd_kernel_intra[grid](\n q, k, g, A,\n k.stride(1), k.stride(2), k.stride(3),\n scale,\n T=T, K=K, BT=BT, BC=BC, BK=BK, NC=NC,\n num_warps=num_warps,\n num_stages=num_stages\n )\n A = A.sum(0, dtype=A.dtype)\n o = torch.empty_like(v)\n grid = (NV, NT, B * H)\n chunk_gla_fwd_kernel_inter[grid](\n q, v, g, h, o, A,\n k.stride(1), k.stride(2), k.stride(3),\n v.stride(1), v.stride(2), v.stride(3),\n h.stride(1), h.stride(2), h.stride(3),\n scale,\n T=T, K=K, V=V, BT=BT, BK=BK, BV=BV,\n num_warps=num_warps,\n num_stages=num_stages\n )\n if checkpoint_level >= 1:\n del g\n g = g_org\n if checkpoint_level > 1:\n del h\n h, initial_state = None, None\n\n ctx.save_for_backward(q, k, v, g, h, initial_state, A)\n ctx.BT = BT\n ctx.scale = scale\n ctx.checkpoint_level = checkpoint_level\n return o, final_state\n\n @staticmethod\n def backward(ctx, do, dht=None):\n q, k, v, g, h, initial_state, A = ctx.saved_tensors\n B, H, T, K, V = *q.shape, v.shape[-1]\n BT, BC = ctx.BT, 16\n BK = min(64, triton.next_power_of_2(K))\n BV = min(64, triton.next_power_of_2(V))\n NT, NC = triton.cdiv(T, BT), triton.cdiv(BT, BC)\n NK = triton.cdiv(K, BK)\n num_warps = 4 if BK == 64 else 2\n num_stages = 1\n\n def fwd_inner(q, k, v, g, B, H, T, K, V, BT, BK, BV, NT, h0=None, ht=None):\n NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)\n h = q.new_empty(B, H, NT * K, V)\n grid = (NV, NK, B * H)\n chunk_gla_fwd_kernel_h[grid](\n k, v, g, h, h0, ht,\n k.stride(1), k.stride(2), k.stride(3),\n v.stride(1), v.stride(2), v.stride(3),\n h.stride(1), h.stride(2), h.stride(3),\n T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT,\n USE_INITIAL_STATE=h0 is not None,\n STORE_FINAL_STATE=ht is not None,\n num_warps=num_warps,\n num_stages=num_stages\n )\n return h\n\n def bwd_inner(q, g, do, B, H, T, K, V, BT, BK, BV, NT, scale):\n NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)\n dh = q.new_empty(B, H, NT * K, V)\n grid = (NK, NV, B * H)\n chunk_gla_bwd_kernel_dh[grid](\n q, g, do, dh,\n q.stride(1), q.stride(2), q.stride(3),\n do.stride(1), do.stride(2), do.stride(3),\n dh.stride(1), dh.stride(2), dh.stride(3),\n scale,\n T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT,\n num_warps=num_warps,\n num_stages=num_stages\n )\n return dh\n\n if ctx.checkpoint_level >= 1:\n g_org, g = g, torch.zeros_like(g, dtype=torch.float)\n def grid(meta): return ((triton.cdiv(meta['S'], meta['BS']), NT, B * H))\n chunk_gla_fwd_kernel_cum[grid](\n g_org, g,\n g.stride(1), g.stride(2), g.stride(3),\n T=T, S=K, BT=BT\n )\n\n if ctx.checkpoint_level > 1:\n h = fwd_inner(\n q=q, k=k, v=v, g=g,\n B=B, H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT,\n h0=initial_state if initial_state is not None else None,\n ht=None\n )\n\n scale = ctx.scale\n dh = bwd_inner(\n q, g, do,\n B=B, H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT,\n scale=scale\n )\n dq = torch.empty_like(q, dtype=torch.float)\n dk = torch.empty_like(k, dtype=torch.float)\n dg = torch.empty_like(k, dtype=torch.float)\n dv = v.new_empty(NK, *v.shape)\n dA = q.new_zeros(B, H, T, BT)\n grid = (NK, NT, B * H)\n chunk_gla_bwd_kernel_inter[grid](\n k, v, h, g, A, do, dh, dq, dk, dv, dA,\n k.stride(1), k.stride(2), k.stride(3),\n v.stride(1), v.stride(2), v.stride(3),\n h.stride(1), h.stride(2), h.stride(3),\n scale,\n T=T, K=K, V=V, BT=BT, BK=BK, BV=BV,\n num_warps=num_warps,\n num_stages=num_stages\n )\n dv = dv.sum(0, dtype=dv.dtype)\n grid = (NK, NT * NC, B * H)\n chunk_gla_bwd_kernel_intra[grid](\n q, k, g, dA, dq, dk, dg,\n k.stride(1), k.stride(2), k.stride(3),\n T=T, K=K, BT=BT, BC=BC, BK=BK, NC=NC,\n num_warps=num_warps,\n num_stages=num_stages\n )\n\n dq = dq.to(q.dtype)\n dk = dk.to(q.dtype)\n dg = chunk_reversed_cumsum_fwd(dg).to(k.dtype)\n return dq, dk, dv, dg, None, None, None, None\n\n\ndef chunk_gla(\n q: torch.Tensor,\n k: torch.Tensor,\n v: torch.Tensor,\n g: torch.Tensor,\n scale: Optional[int] = None,\n initial_state: torch.Tensor = None,\n output_final_state: bool = False,\n checkpoint_level: Optional[int] = 2\n) -> Tuple[torch.Tensor, torch.Tensor]:\n assert checkpoint_level in [0, 1, 2]\n if scale is None:\n scale = q.shape[-1] ** -0.5\n if initial_state is not None:\n initial_state = initial_state.detach()\n o, final_state = ChunkGLAFunction.apply(q, k, v, g, scale, initial_state, output_final_state, checkpoint_level)\n return o, final_state\n", - "description_1": "Use triton language to implement various kernels for a custom attention mechanism. The kernels include a cumulative sum operation, forward and backward propagation operations. The forward function accepts inputs such as queries, keys, values, forget gates, scale factor, initial and final states, and computes an output and optionally the final hidden states. Backward function computes gradients with respect to the inputs. These functions are integrated into a PyTorch autograd-compatible function for efficient computation on GPU.", - "description_2": "Use triton language to create forward and backward kernels for an attention mechanism, implementing cumulative sums, data transformations, and matrix multiplications using PyTorch autograd.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\nfrom einops import rearrange\nfrom packaging import version\nfrom torch.cuda.amp import custom_bwd, custom_fwd\n\ninv_ln2 = 1.44269504\n\n@triton.jit\ndef fused_chunk_gla_fwd_kernel(\n q, k, v, g, o, \n initial_state, final_state, \n s_qk_h, s_qk_t, s_qk_d, \n s_vo_h, s_vo_t, s_vo_d, \n B, H, T, scale, \n BT: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr, \n DK: tl.constexpr, DV: tl.constexpr, \n USE_INITIAL_STATE: tl.constexpr, STORE_FINAL_STATE: tl.constexpr, \n CHECK: tl.constexpr\n):\n # indices\n i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n\n b_h = tl.zeros([BK, BV], dtype=tl.float32)\n\n # make block pointers\n p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (0, i_k * BK), (BT, BK), (1, 0))\n p_db = g + i_bh * s_qk_h + (BT - 1) * s_qk_t + i_k * BK + tl.arange(0, BK)\n p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (DK, T), (s_qk_d, s_qk_t), (i_k * BK, 0), (BK, BT), (0, 1))\n p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (0, i_v * BV), (BT, BV), (1, 0))\n p_o = tl.make_block_ptr(o + (i_bh + i_k * B * H) * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (0, i_v * BV), (BT, BV), (1, 0))\n\n if USE_INITIAL_STATE:\n p_h = tl.make_block_ptr(initial_state + i_bh * DK * DV, (DK, DV), (DV, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))\n b_h += tl.load(p_h, boundary_check=(0, 1)).to(tl.float32)\n \n mask = (i_k * BK + tl.arange(0, BK)) < DK\n\n for i in range(0, tl.cdiv(T, BT)):\n # [BK, BT]\n b_k = tl.load(p_k, boundary_check=(0, 1))\n # [BT, BV]\n b_o = tl.zeros([BT, BV], dtype=tl.float32)\n b_v = tl.load(p_v, boundary_check=(0, 1))\n # [BT, BK]\n b_q = tl.load(p_q, boundary_check=(0, 1))\n d_b = tl.load(p_db, mask=mask, other=0).to(tl.float32)\n if CHECK and i == 0:\n b_o = tl.dot(b_q.to(b_v.dtype), b_h.to(b_v.dtype), allow_tf32=False)\n b_h = b_h * tl.math.exp2(d_b)[:, None] + tl.dot(b_k.to(b_v.dtype), b_v, allow_tf32=False)\n else:\n b_o = tl.dot(b_q.to(b_v.dtype), b_h.to(b_v.dtype), allow_tf32=False)\n b_h = b_h * tl.math.exp2(d_b)[:, None] + tl.dot(b_k.to(b_v.dtype), b_v, allow_tf32=False)\n\n tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))\n p_q = tl.advance(p_q, (BT, 0))\n p_k = tl.advance(p_k, (0, BT))\n p_v = tl.advance(p_v, (BT, 0))\n p_o = tl.advance(p_o, (BT, 0))\n p_db += BT * DK\n\n if STORE_FINAL_STATE:\n p_final = tl.make_block_ptr(final_state + i_bh * DK * DV, (DK, DV), (DV, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))\n tl.store(p_final, b_h.to(p_final.dtype.element_ty), boundary_check=(0, 1))\n\n\n@triton.jit\ndef fused_chunk_gla_bwd_kernel(\n q, k, v, g, do, dq, dk, dv, \n initial_state, s_qk_h, s_qk_t, s_qk_d, \n s_vo_h, s_vo_t, s_vo_d, \n B, H, T, scale, \n BT: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr, \n DK: tl.constexpr, DV: tl.constexpr, \n USE_INITIAL_STATE: tl.constexpr, CHECK: tl.constexpr\n):\n i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n # [BV, BK]\n b_h = tl.zeros([BV, BK], dtype=tl.float32)\n\n if USE_INITIAL_STATE:\n p_h = tl.make_block_ptr(initial_state + i_bh * DK * DV, (DV, DK), (1, DV), (i_v * BV, i_k * BK), (BV, BK), (0, 1))\n b_h += tl.load(p_h, boundary_check=(0, 1)).to(tl.float32)\n \n mask = (i_k * BK + tl.arange(0, BK)) < DK \n for i in range(0, tl.cdiv(T, BT)):\n p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (i * BT, i_k * BK), (BT, BK), (1, 0))\n p_db = g + i_bh * s_qk_h + ((i+1) * BT - 1) * s_qk_t + i_k * BK + tl.arange(0, BK)\n p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (DV, T), (s_vo_d, s_vo_t), (i_v * BV, i * BT), (BV, BT), (0, 1))\n p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (i * BT, i_v * BV), (BT, BV), (1, 0))\n p_dq = tl.make_block_ptr(dq + (i_bh+i_v*B*H)*s_qk_h, (T, DK), (s_qk_t, s_qk_d), (i * BT, i_k * BK), (BT, BK), (1, 0))\n b_dq = tl.zeros([BT, BK], dtype=tl.float32)\n # [BT, DK]\n b_k = tl.load(p_k, boundary_check=(0, 1))\n d_b = tl.load(p_db, mask=mask, other=0).to(tl.float32)\n\n # [DV, BT]\n b_v = tl.load(p_v, boundary_check=(0, 1))\n # [BT, DV]\n b_do = tl.load(p_do, boundary_check=(0, 1))\n # [DV, DK]\n if CHECK and i == 0:\n b_dq += tl.dot(b_do, b_h.to(b_do.dtype), allow_tf32=False)\n b_h = b_h * tl.math.exp2(d_b)[None, :] + tl.dot(b_v, b_k.to(b_v.dtype), allow_tf32=False)\n else:\n b_dq += tl.dot(b_do, b_h.to(b_do.dtype), allow_tf32=False)\n b_h = b_h * tl.math.exp2(d_b)[None, :] + tl.dot(b_v, b_k.to(b_v.dtype), allow_tf32=False)\n b_dq *= scale\n tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))\n\n # sync threads\n b_h = None\n tl.debug_barrier()\n # [BK, BV]\n b_dh = tl.zeros([BK, BV], dtype=tl.float32)\n\n # cum = tl.zeros([BK], dtype=tl.float32)\n for i in range(1, tl.cdiv(T, BT) + 1):\n p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (DK, T), (s_qk_d, s_qk_t), (i_k * BK, T - i * BT), (BK, BT), (0, 1))\n p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (T - i * BT, i_k * BK), (BT, BK), (1, 0))\n p_db = g + i_bh * s_qk_h + (T - (i-1) * BT - 1) * s_qk_t + i_k * BK + tl.arange(0, BK)\n p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (T - i * BT, i_v * BV), (BT, BV), (1, 0))\n p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (T - i * BT, i_v * BV), (BT, BV), (1, 0))\n p_dk = tl.make_block_ptr(dk + (i_bh + i_v * B * H) * s_qk_h, (T, DK),\n (s_qk_t, s_qk_d), (T - i * BT, i_k * BK), (BT, BK), (1, 0))\n p_dv = tl.make_block_ptr(dv + (i_bh + i_k * B * H) * s_vo_h, (T, DV),\n (s_vo_t, s_vo_d), (T - i * BT, i_v * BV), (BT, BV), (1, 0))\n # [DK, BT]\n b_q = tl.load(p_q, boundary_check=(0, 1))\n # [BT, DK]\n b_k = tl.load(p_k, boundary_check=(0, 1))\n # [BT, DV]\n b_v = tl.load(p_v, boundary_check=(0, 1))\n b_do = tl.load(p_do, boundary_check=(0, 1))\n b_db = tl.load(p_db, mask=mask, other=0).to(tl.float32)\n\n # inter-chunk\n # [DK, DV]\n if CHECK and i == 1:\n b_dk = tl.trans(tl.dot(b_dh.to(b_v.dtype), tl.trans(b_v), allow_tf32=False))\n b_dv = tl.dot((b_k).to(b_v.dtype), b_dh.to(b_v.dtype), allow_tf32=False)\n b_dh = b_dh * tl.math.exp2(b_db)[:, None] + tl.dot(b_q.to(b_do.dtype), b_do, allow_tf32=False)\n else:\n b_dk = tl.trans(tl.dot(b_dh.to(b_v.dtype), tl.trans(b_v), allow_tf32=False))\n b_dv = tl.dot((b_k).to(b_v.dtype), b_dh.to(b_v.dtype), allow_tf32=False)\n b_dh = b_dh * tl.math.exp2(b_db)[:, None] + tl.dot(b_q.to(b_do.dtype), b_do, allow_tf32=False)\n\n tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))\n tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))\n\n\n@triton.jit\ndef fwd_inner_chunk(\n q, k, g, A,\n s_qk_h, s_qk_t, s_qk_d,\n B, H, T, scale,\n BT: tl.constexpr, BK: tl.constexpr, DK: tl.constexpr,\n):\n\n i_k, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n\n p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))\n\n b_k = tl.load(p_k, boundary_check=(0, 1))\n\n p_g = tl.make_block_ptr(g + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))\n\n b_g = tl.load(p_g, boundary_check=(0, 1)).to(tl.float32)\n\n mask = (i_k * BK + tl.arange(0, BK)) < DK\n o_i = tl.arange(0, BT)\n\n p_q = q + i_bh * s_qk_h + i_k * BK + i_t * BT * DK + tl.arange(0, BK)\n p_gq = g + i_bh * s_qk_h + i_k * BK + i_t * BT * DK + tl.arange(0, BK)\n p_A = A + (i_bh + (i_k * B * H)) * (tl.cdiv(T, BT) * BT * BT) + i_t * BT * BT + tl.arange(0, BT)\n\n for i in range(BT):\n _q = tl.load(p_q, mask=mask, other=0) * scale\n gq = tl.load(p_gq, mask=mask, other=0).to(tl.float32)\n s = _q[None, :] * b_k * tl.math.exp2(gq[None, :] - b_g)\n score = tl.sum(s, axis=1)\n score = tl.where(o_i <= i, score, 0)\n tl.store(p_A, score.to(p_A.dtype.element_ty))\n p_q += DK\n p_gq += DK\n p_A += BT\n\n\n@triton.jit\ndef bwd_inner_chunk(\n q, k, g, dA, dq, dk,\n s_qk_h, s_qk_t, s_qk_d,\n B, H, T, scale,\n BT: tl.constexpr, BK: tl.constexpr, DK: tl.constexpr,\n):\n i_k, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))\n b_k = tl.load(p_k, boundary_check=(0, 1))\n p_g = tl.make_block_ptr(g + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))\n b_g = tl.load(p_g, boundary_check=(0, 1)).to(tl.float32)\n\n mask = (i_k * BK + tl.arange(0, BK)) < DK\n o_i = tl.arange(0, BT)\n\n p_q = q + i_bh * s_qk_h + i_k * BK + i_t * BT * DK + tl.arange(0, BK)\n p_dq = dq + (i_bh) * s_qk_h + i_k * BK + i_t * BT * DK + tl.arange(0, BK)\n p_gq = g + i_bh * s_qk_h + i_k * BK + i_t * BT * DK + tl.arange(0, BK)\n p_dA = dA + i_bh * (tl.cdiv(T, BT) * BT * BT) + i_t * BT * BT + tl.arange(0, BT)\n\n b_dk = tl.zeros([BT, BK], dtype=tl.float32)\n\n for i in range(BT):\n _q = tl.load(p_q, mask=mask, other=0)\n gq = tl.load(p_gq, mask=mask, other=0).to(tl.float32)\n score = tl.math.exp2(gq[None, :] - b_g)\n score = tl.where(o_i[:, None] <= i, score, 0)\n _dA = tl.load(p_dA)\n _dA = tl.where(o_i <= i, _dA, 0)\n b_dk += (_dA[:, None] * score * _q[None, :])\n b_dq = tl.sum(_dA[:, None] * score * b_k, axis=0)\n tl.store(p_dq, b_dq, mask=mask)\n p_q += DK\n p_dq += DK\n p_gq += DK\n p_dA += BT\n\n p_dk = tl.make_block_ptr(dk + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))\n tl.store(p_dk, b_dk.to(dk.dtype.element_ty), boundary_check=(0, 1))\n", - "description_1": "Use triton language to define multiple kernels for processing tensors, where the kernels perform forward and backward computations for a transformer-like architecture with Gated Linear Attention (GLA). Each kernel has multiple parameters including tensors like query, key, value, gradients, etc., and various strides and block sizes used for efficient computation.", - "description_2": "Use triton language to create kernels for GLA in transformers, focusing on handling queries, keys, values, and their gradients efficiently with block pointers and custom parameters.", - "difficulty": 4 - }, - { - "code": "import triton\nimport triton.language as tl\n\ninv_ln2 = 1.44269504\n\n# Kernel to compute the forward decay cumulative sum\n@triton.jit\ndef fwd_decay_cumsum(\n g,\n g_o, \n s_qk_h,\n s_qk_t,\n s_qk_d,\n B,\n H,\n T,\n scale,\n BT: tl.constexpr,\n BK: tl.constexpr,\n DK: tl.constexpr\n):\n i_k, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n p_g = g + i_bh * s_qk_h + i_c * BT * DK + i_k * BK + tl.arange(0, BK)\n p_go = g_o + i_bh * s_qk_h + i_c * BT * DK + i_k * BK + tl.arange(0, BK)\n cum_decay = tl.zeros([BK], dtype=tl.float32)\n mask = (i_k * BK + tl.arange(0, BK)) < DK\n\n for i in range(BT):\n _g = tl.load(p_g, mask=mask, other=0).to(tl.float32)\n cum_decay += _g * inv_ln2\n tl.store(p_go, cum_decay.to(p_go.dtype.element_ty), mask=mask)\n p_g += DK\n p_go += DK\n\n# Kernel to prepare qg and kg\n@triton.jit\ndef prepare_qg_kg(\n q,\n k,\n g,\n qg,\n kg,\n s_qk_h,\n s_qk_t,\n s_qk_d,\n B,\n H,\n T,\n scale,\n BT: tl.constexpr,\n BK: tl.constexpr,\n DK: tl.constexpr\n):\n i_k, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n p_q = q + i_bh * s_qk_h + i_c * BT * DK + i_k * BK + tl.arange(0, BK)\n p_g = g + i_bh * s_qk_h + i_c * BT * DK + i_k * BK + tl.arange(0, BK)\n p_k = k + i_bh * s_qk_h + i_c * BT * DK + i_k * BK + tl.arange(0, BK)\n p_qg = qg + i_bh * s_qk_h + i_c * BT * DK + i_k * BK + tl.arange(0, BK)\n p_kg = kg + i_bh * s_qk_h + i_c * BT * DK + i_k * BK + tl.arange(0, BK)\n \n mask = (i_k * BK + tl.arange(0, BK)) < DK\n\n last_decay = tl.load(g + i_bh * s_qk_h + (i_c * BT + BT - 1) * DK + i_k * BK + tl.arange(0, BK))\n\n for i in range(BT):\n _q = tl.load(p_q, mask=mask, other=0)\n _k = tl.load(p_k, mask=mask, other=0)\n _g = tl.load(p_g, mask=mask, other=0).to(tl.float32)\n _q *= tl.math.exp2(_g) * scale\n _k *= tl.math.exp2(last_decay - _g)\n tl.store(p_kg, _k.to(p_kg.dtype.element_ty), mask=mask)\n tl.store(p_qg, _q.to(p_qg.dtype.element_ty), mask=mask)\n p_q += DK\n p_g += DK\n p_k += DK\n p_kg += DK\n p_qg += DK\n\n# Kernel for backward decay global cumulative sum\n@triton.jit\ndef bwd_decay_global_cumsum(\n dq_inner,\n dq_inter,\n dk_inner,\n dk_inter,\n q, k, g, dg,\n s_qk_h,\n s_qk_t,\n s_qk_d,\n B,\n H,\n T,\n scale,\n BT: tl.constexpr,\n BK: tl.constexpr,\n DK: tl.constexpr\n):\n i_k, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n p_q = q + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (i_c * BT + BT - 1) * DK\n p_k = k + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (i_c * BT + BT - 1) * DK\n p_g = g + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (i_c * BT + BT - 1) * DK\n p_dg = dg + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (i_c * BT + BT - 1) * DK\n p_dq_inner = dq_inner + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (i_c * BT + BT - 1) * DK\n p_dk_inner = dk_inner + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (i_c * BT + BT - 1) * DK\n p_dq_inter = dq_inter + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (i_c * BT + BT - 1) * DK\n p_dk_inter = dk_inter + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (i_c * BT + BT - 1) * DK\n cum_grad_dg = tl.zeros([BK], dtype=tl.float32)\n mask = (i_k * BK + tl.arange(0, BK)) < DK\n last_g = tl.zeros([BK], dtype=tl.float32)\n for j in range(BT-1, -1, -1):\n _g = tl.load(p_g, mask=mask, other=0).to(tl.float32)\n if j == (BT-1):\n last_g = _g\n _dq1 = tl.load(p_dq_inner, mask=mask, other=0)\n _dq2 = tl.load(p_dq_inter, mask=mask, other=0)\n _dq2 *= tl.math.exp2(_g)\n _dq = _dq1 + _dq2\n tl.store(p_dq_inter, _dq, mask=mask)\n _dk1 = tl.load(p_dk_inner, mask=mask, other=0)\n _dk2 = tl.load(p_dk_inter, mask=mask, other=0)\n _dk2 *= tl.math.exp2(last_g - _g)\n _dk = _dk1 + _dk2\n tl.store(p_dk_inter, _dk, mask=mask)\n _q = tl.load(p_q, mask=mask, other=0)\n _k = tl.load(p_k, mask=mask, other=0)\n _dg = _dq * _q - _dk * _k\n cum_grad_dg += _dg\n tl.store(p_dg, cum_grad_dg.to(p_dg.dtype.element_ty), mask=mask)\n p_g -= DK\n p_k -= DK\n p_q -= DK\n p_dq_inner -= DK\n p_dk_inner -= DK\n p_dq_inter -= DK\n p_dk_inter -= DK\n p_dg -= DK\n", - "description_1": "Use triton language to create three kernels: 1) fwd_decay_cumsum: Computes the cumulative sum of decays for given inputs. Parameters include input pointers and dimensions. 2) prepare_qg_kg: Prepares qg and kg tensors based on input q, k, g tensors and other parameters for scaling and transformation. 3) bwd_decay_global_cumsum: Calculates the backward cumulative sum of global decay using inner and inter-component derivatives for q and k.", - "description_2": "Use triton language to define kernels for forward and backward decay operations on tensors with specific scaling, masking, and cumulative operations.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef fused_recurrent_gla_fwd_kernel(\n q, k, v, gk, gv, o, initial_state, final_state,\n s_qk_h, s_qk_t, s_qk_d,\n s_vo_h, s_vo_t, s_vo_d,\n B, H, T, scale,\n BK: tl.constexpr, BV: tl.constexpr, DK: tl.constexpr, DV: tl.constexpr,\n USE_INITIAL_STATE: tl.constexpr, STORE_FINAL_STATE: tl.constexpr,\n REVERSE: tl.constexpr, USE_GK: tl.constexpr, USE_GV: tl.constexpr,\n):\n i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n\n p_q = q + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + ((T-1) * DK if REVERSE else 0)\n p_k = k + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + ((T-1) * DK if REVERSE else 0)\n p_v = v + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV) + ((T-1) * DV if REVERSE else 0)\n p_o = o + (i_bh + i_k * B * H) * s_vo_h + i_v * BV + tl.arange(0, BV) + ((T-1) * DV if REVERSE else 0)\n\n if USE_GK:\n p_gk = gk + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + ((T-1) * DK if REVERSE else 0)\n if USE_GV:\n p_gv = gv + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV) + ((T-1) * DV if REVERSE else 0)\n\n mask_bk = (i_k * BK + tl.arange(0, BK)) < DK\n mask_bv = (i_v * BV + tl.arange(0, BV)) < DV\n\n h = tl.zeros([BV, BK], dtype=tl.float32)\n\n mask_kv = mask_bk[None, :] & mask_bv[:, None]\n\n if USE_INITIAL_STATE:\n p_init_s = initial_state + i_bh * DK * DV + (i_k * BK + tl.arange(0, BK)[None, :]) * DV + (i_v * BV + tl.arange(0, BV)[:, None])\n h += tl.load(p_init_s, mask=mask_kv, other=0).to(tl.float32)\n\n for _ in range(0, T):\n _k = tl.load(p_k, mask=mask_bk, other=0).to(tl.float32)\n _v = tl.load(p_v, mask=mask_bv, other=0).to(tl.float32)\n _q = tl.load(p_q, mask=mask_bk, other=0).to(tl.float32) * scale\n if USE_GK:\n _gk = tl.load(p_gk, mask=mask_bk, other=0).to(tl.float32)\n h = h * _gk[None, :]\n if USE_GV:\n _gv = tl.load(p_gv, mask=mask_bv, other=0).to(tl.float32)\n h = h * _gv[:, None]\n h += _k[None, :] * _v[:, None]\n _o = h * _q[None, :]\n _o = tl.sum(_o, axis=1)\n tl.store(p_o, _o.to(p_o.dtype.element_ty), mask=mask_bv)\n p_q += -DK if REVERSE else DK\n p_k += -DK if REVERSE else DK\n p_o += -DV if REVERSE else DV\n p_v += -DV if REVERSE else DV\n if USE_GK:\n p_gk += -DK if REVERSE else DK\n if USE_GV:\n p_gv += -DV if REVERSE else DV\n\n if STORE_FINAL_STATE:\n p_final_s = final_state + i_bh * DK * DV + (i_k * BK + tl.arange(0, BK)[None, :]) * DV + (i_v * BV + tl.arange(0, BV)[:, None])\n tl.store(p_final_s, h.to(p_final_s.dtype.element_ty), mask=mask_kv)\n\n@triton.jit\ndef fused_recurrent_gla_bwd_kernel(\n q, k, v, gk, gv, do, dq, dk, dv, initial_state,\n s_qk_h, s_qk_t, s_qk_d,\n s_vo_h, s_vo_t, s_vo_d,\n B, H, T, scale,\n BK: tl.constexpr, BV: tl.constexpr, DK: tl.constexpr, DV: tl.constexpr,\n USE_INITIAL_STATE: tl.constexpr, REVERSE: tl.constexpr, USE_GK: tl.constexpr, USE_GV: tl.constexpr,\n):\n i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n\n p_q = q + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + ((T-1) * DK if REVERSE else 0)\n p_k = k + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + ((T-1) * DK if REVERSE else 0)\n p_v = v + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV) + ((T-1) * DV if REVERSE else 0)\n p_do = do + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV) + ((T-1) * DV if REVERSE else 0)\n p_dq = dq + (i_bh + i_v * B * H) * s_qk_h + i_k * BK + tl.arange(0, BK) + ((T-1) * DK if REVERSE else 0)\n if USE_GK:\n p_gk = gk + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + ((T-1) * DK if REVERSE else 0)\n if USE_GV:\n p_gv = gv + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV) + ((T-1) * DV if REVERSE else 0)\n mask_bk = i_k * BK + tl.arange(0, BK) < DK\n mask_bv = i_v * BV + tl.arange(0, BV) < DV\n mask_kv = mask_bk[:, None] & mask_bv[None, :]\n h = tl.zeros([BK, BV], dtype=tl.float32)\n\n if USE_INITIAL_STATE:\n p_init_s = initial_state + i_bh * DK * DV + (i_k * BK + tl.arange(0, BK)[:, None]) * DV + (i_v * BV + tl.arange(0, BV)[None, :])\n h += tl.load(p_init_s, mask=mask_kv, other=0).to(tl.float32)\n\n for i in range(0, T):\n _k = tl.load(p_k, mask=mask_bk, other=0).to(tl.float32)\n _v = tl.load(p_v, mask=mask_bv, other=0).to(tl.float32)\n _do = tl.load(p_do, mask=mask_bv, other=0).to(tl.float32)\n if USE_GK:\n _gk = tl.load(p_gk, mask=mask_bk, other=0).to(tl.float32)\n h = h * _gk[:, None]\n if USE_GV:\n _gv = tl.load(p_gv, mask=mask_bv, other=0).to(tl.float32)\n h = h * _gv[None, :]\n h += _k[:, None] * _v[None, :]\n _d_q = h * _do[None, :]\n d_q = tl.sum(_d_q, axis=1) * scale\n tl.store(p_dq, d_q.to(p_dq.dtype.element_ty), mask=mask_bk)\n\n p_k += -DK if REVERSE else DK\n p_v += -DV if REVERSE else DV\n p_q += -DK if REVERSE else DK\n p_do += -DV if REVERSE else DV\n p_dq += -DK if REVERSE else DK\n if USE_GK:\n p_gk += -DK if REVERSE else DK\n if USE_GV:\n p_gv += -DV if REVERSE else DV\n\n tl.debug_barrier()\n\n p_q = q + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + ((T - 1) * DK if not REVERSE else 0)\n p_k = k + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + ((T - 1) * DK if not REVERSE else 0)\n p_do = do + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV) + ((T - 1) * DV if not REVERSE else 0)\n p_v = v + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV) + ((T - 1) * DV if not REVERSE else 0)\n p_dk = dk + (i_bh + i_v * B * H) * s_qk_h + i_k * BK + tl.arange(0, BK) + ((T - 1) * DK if not REVERSE else 0)\n p_dv = dv + (i_bh + i_k * B * H) * s_vo_h + i_v * BV + tl.arange(0, BV) + ((T - 1) * DV if not REVERSE else 0)\n if USE_GK:\n p_gk = gk + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + ((T - 1) * DK if not REVERSE else 0)\n if USE_GV:\n p_gv = gv + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV) + ((T - 1) * DV if not REVERSE else 0)\n\n d_h = tl.zeros([BK, BV], dtype=tl.float32)\n\n for _ in range(T):\n _do = tl.load(p_do, mask=mask_bv, other=0).to(tl.float32)\n _q = tl.load(p_q, mask=mask_bk, other=0).to(tl.float32) * scale\n _k = tl.load(p_k, mask=mask_bk, other=0).to(tl.float32)\n _v = tl.load(p_v, mask=mask_bv, other=0).to(tl.float32)\n d_h += _q[:, None] * _do[None, :]\n d_k = tl.sum(d_h * _v[None, :], axis=1)\n d_v = tl.sum(d_h * _k[:, None], axis=0)\n if USE_GK:\n _gk = tl.load(p_gk, mask=mask_bk, other=0).to(tl.float32)\n d_h *= _gk[:, None]\n if USE_GV:\n _gv = tl.load(p_gv, mask=mask_bv, other=0).to(tl.float32)\n d_h *= _gv[None, :]\n tl.store(p_dk, d_k.to(p_dk.dtype.element_ty), mask=mask_bk)\n tl.store(p_dv, d_v.to(p_dv.dtype.element_ty), mask=mask_bv)\n\n p_do += DV if REVERSE else -DV\n p_q += DK if REVERSE else -DK\n p_k += DK if REVERSE else -DK\n p_v += DV if REVERSE else -DV\n p_dk += DK if REVERSE else -DK\n p_dv += DV if REVERSE else -DV\n if USE_GK:\n p_gk += DK if REVERSE else -DK\n if USE_GV:\n p_gv += DV if REVERSE else -DV\n\nclass FusedRecurrentGLAFunction(torch.autograd.Function):\n\n @staticmethod\n def forward(ctx, q, k, v, gk, gv, scale=None, initial_state=None, output_final_state=False, reverse=False):\n batch_size, n_heads, seq_len, d_head_qk = q.shape\n d_head_v = v.shape[-1]\n if scale is None:\n scale = d_head_qk ** -0.5\n if gk is not None:\n gk = gk.float().exp()\n if gv is not None:\n gv = gv.float().exp()\n\n BK, BV = min(d_head_qk, 32), min(d_head_v, 32)\n NK, NV = triton.cdiv(d_head_qk, BK), triton.cdiv(d_head_v, BV)\n num_stages = 1\n num_warps = 1\n\n o = q.new_empty(NK, batch_size, n_heads, seq_len, d_head_v, dtype=torch.float32)\n\n if output_final_state:\n final_state = q.new_empty(batch_size, n_heads, d_head_qk, d_head_v)\n else:\n final_state = None\n\n grid = (NV, NK, batch_size * n_heads)\n fused_recurrent_gla_fwd_kernel[grid](\n q, k, v, gk, gv, o, initial_state, final_state,\n q.stride(1), q.stride(2), q.stride(3),\n v.stride(1), v.stride(2), v.stride(3),\n batch_size, n_heads, seq_len, scale,\n DK=d_head_qk, DV=d_head_v, BK=BK, BV=BV,\n USE_INITIAL_STATE=initial_state is not None,\n STORE_FINAL_STATE=final_state is not None,\n USE_GK=gk is not None,\n USE_GV=gv is not None,\n REVERSE=reverse,\n num_warps=num_warps,\n num_stages=num_stages\n )\n\n o = o.sum(0)\n ctx.save_for_backward(q, k, v, gk, gv, initial_state, o)\n ctx.scale = scale\n ctx.reverse = reverse\n if final_state is not None:\n final_state = final_state.detach()\n return o.to(q.dtype), final_state\n\n @staticmethod\n def backward(ctx, do, d_final_state=None):\n q, k, v, gk, gv, initial_state, o = ctx.saved_tensors\n batch_size, n_heads, seq_len, d_head_qk = q.shape\n d_head_v = v.shape[-1]\n scale = ctx.scale\n\n BK, BV = min(d_head_qk, 32), min(d_head_v, 32)\n NK, NV = triton.cdiv(d_head_qk, BK), triton.cdiv(d_head_v, BV)\n num_stages = 1\n num_warps = 1\n\n dq = q.new_empty(NV, batch_size, n_heads, seq_len, d_head_qk, dtype=torch.float32)\n dk = q.new_empty(NV, batch_size, n_heads, seq_len, d_head_qk, dtype=torch.float32)\n dv = q.new_empty(NK, batch_size, n_heads, seq_len, d_head_v, dtype=torch.float32)\n grid = (NV, NK, batch_size * n_heads)\n\n fused_recurrent_gla_bwd_kernel[grid](\n q, k, v, gk, gv, do, dq, dk, dv, initial_state,\n q.stride(1), q.stride(2), q.stride(3),\n v.stride(1), v.stride(2), v.stride(3),\n batch_size, n_heads, seq_len, scale,\n DK=d_head_qk, DV=d_head_v, BK=BK, BV=BV,\n num_warps=num_warps,\n num_stages=num_stages,\n USE_INITIAL_STATE=initial_state is not None,\n REVERSE=ctx.reverse,\n USE_GK=gk is not None,\n USE_GV=gv is not None\n )\n dq = dq.sum(0)\n dk = dk.sum(0)\n dv = dv.sum(0)\n if gk is not None:\n _dgk = dq * q.float() - dk * k.float()\n if ctx.reverse:\n dgk = _dgk.cumsum(-2)\n else:\n _dgk_cumsum = _dgk.cumsum(-2)\n dgk = _dgk + _dgk_cumsum[:, :, -1, None] - _dgk_cumsum\n else:\n dgk = None\n\n if gv is not None:\n _dgv = do.float() * o.float() - dv * v.float()\n if ctx.reverse:\n dgv = _dgv.cumsum(-2)\n else:\n _dgv_cumsum = _dgv.cumsum(-2)\n dgv = _dgv + _dgv_cumsum[:, :, -1, None] - _dgv_cumsum\n else:\n dgv = None\n\n return dq.to(q.dtype), dk.to(k.dtype), dv.to(v.dtype), dgk, dgv, None, None, None, None\n\ndef fused_recurrent_gla(\n q: torch.Tensor,\n k: torch.Tensor,\n v: torch.Tensor,\n gk: torch.Tensor = None,\n gv: torch.Tensor = None,\n scale: int = -1,\n initial_state: torch.Tensor = None,\n output_final_state: bool = False,\n causal: bool = True\n) -> Tuple[torch.Tensor, torch.Tensor]:\n if scale == -1:\n scale = q.shape[-1] ** -0.5\n if initial_state is not None:\n initial_state = initial_state.detach()\n if causal:\n o, final_state = FusedRecurrentGLAFunction.apply(q, k, v, gk, gv, scale, initial_state, output_final_state)\n return o, final_state\n else:\n assert initial_state is None\n assert output_final_state is False\n o, final_state = FusedRecurrentGLAFunction.apply(q, k, v, gk, gv, scale, initial_state, output_final_state, False)\n o_reversed, final_state = FusedRecurrentGLAFunction.apply(q, k, v, gk, gv, scale, initial_state, output_final_state, True)\n return [o, o_reversed]\n", - "description_1": "Use triton language to define two kernels, `fused_recurrent_gla_fwd_kernel` and `fused_recurrent_gla_bwd_kernel`, for forward and backward passes of a recurrent neural network layer with gating mechanisms. The forward kernel computes a weighted sum of key and value vectors, optionally modulated by gate values and an initial hidden state. It stores the output and optionally the final state. The backward kernel computes gradients of queries, keys, and values based on the derivatives of the output and adjusts for gating. Both kernels involve block operations with respect to query, key, and value dimensions using triton's parallel programming capabilities. Define a `FusedRecurrentGLAFunction` which implements the forward and backward passes using these kernels and supports autograd. The `fused_recurrent_gla` function serves as a wrapper for ease of use, exposing the necessary parameters for operation.", - "description_2": "Use triton language to implement the forward and backward kernels of a gated recurrent network layer, supporting block-wise parallel computation across queries, keys, and values.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\nfrom typing import Tuple\n\n@triton.autotune(\n configs=[\n triton.Config({'BD': 32}, num_warps=1),\n triton.Config({'BD': 32}, num_warps=2),\n triton.Config({'BD': 32}, num_warps=4),\n triton.Config({'BD': 32}, num_warps=8),\n triton.Config({'BD': 64}, num_warps=1),\n triton.Config({'BD': 64}, num_warps=2),\n triton.Config({'BD': 64}, num_warps=4),\n triton.Config({'BD': 64}, num_warps=8),\n triton.Config({'BD': 128}, num_warps=1),\n triton.Config({'BD': 128}, num_warps=2),\n triton.Config({'BD': 128}, num_warps=4),\n triton.Config({'BD': 128}, num_warps=8),\n ],\n key=['D']\n)\n@triton.jit\ndef chunk_hgrn_fwd_kernel_h(\n x,\n g,\n gc,\n o,\n h0,\n T: tl.constexpr,\n D: tl.constexpr,\n BT: tl.constexpr,\n BD: tl.constexpr,\n USE_INITIAL_STATE: tl.constexpr\n):\n i_d, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n o_d = i_d * BD + tl.arange(0, BD)\n mask = o_d < D\n\n p_x = x + i_bh * T * D + i_t * BT * D + o_d\n p_g = g + i_bh * T * D + i_t * BT * D + o_d\n p_gc = gc + i_bh * T * D + i_t * BT * D + o_d\n p_o = o + i_bh * T * D + i_t * BT * D + o_d\n\n b_h = tl.zeros([BD], dtype=tl.float32)\n b_gc = tl.zeros([BD], dtype=tl.float32)\n if USE_INITIAL_STATE:\n if i_t == 0:\n b_h += tl.load(h0 + i_bh * D + o_d, mask=mask, other=0).to(tl.float32)\n for i in range(0, BT):\n mask_t = mask & ((i_t * BT + i) < T)\n b_x = tl.load(p_x, mask=mask_t, other=0).to(tl.float32)\n b_g = tl.load(p_g, mask=mask_t, other=0).to(tl.float32)\n b_h = tl.exp(b_g) * b_h + b_x\n b_gc = b_gc + b_g\n tl.store(p_gc, b_gc.to(p_o.dtype.element_ty), mask=mask_t)\n tl.store(p_o, b_h.to(p_o.dtype.element_ty), mask=mask_t)\n\n p_x += D\n p_g += D\n p_gc += D\n p_o += D\n\n\n@triton.jit\ndef chunk_hgrn_fwd_kernel_o(\n gc,\n o,\n s_h,\n s_t,\n s_d,\n T: tl.constexpr,\n D: tl.constexpr,\n BT: tl.constexpr,\n BD: tl.constexpr\n):\n i_d, i_bh = tl.program_id(0), tl.program_id(1)\n o_d = i_d * BD + tl.arange(0, BD)\n mask = o_d < D\n\n for i_t in range(1, tl.cdiv(T, BT)):\n p_gc = tl.make_block_ptr(gc + i_bh * s_h, (T, D), (s_t, s_d), (i_t * BT, i_d * BD), (BT, BD), (1, 0))\n p_o = tl.make_block_ptr(o + i_bh * s_h, (T, D), (s_t, s_d), (i_t * BT, i_d * BD), (BT, BD), (1, 0))\n\n # [BD,]\n b_h0 = tl.load(o + i_bh * T * D + i_t * BT * D - D + o_d, mask=mask, other=0).to(tl.float32)\n # [BT, BD]\n b_gc = tl.load(p_gc, boundary_check=(0, 1)).to(tl.float32)\n b_o = tl.load(p_o, boundary_check=(0, 1)).to(tl.float32)\n b_o = b_o + tl.exp(b_gc) * b_h0[None, :]\n tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))\n\n\n@triton.autotune(\n configs=[\n triton.Config({'BD': 32}, num_warps=1),\n triton.Config({'BD': 32}, num_warps=2),\n triton.Config({'BD': 32}, num_warps=4),\n triton.Config({'BD': 32}, num_warps=8),\n triton.Config({'BD': 64}, num_warps=1),\n triton.Config({'BD': 64}, num_warps=2),\n triton.Config({'BD': 64}, num_warps=4),\n triton.Config({'BD': 64}, num_warps=8),\n triton.Config({'BD': 128}, num_warps=1),\n triton.Config({'BD': 128}, num_warps=2),\n triton.Config({'BD': 128}, num_warps=4),\n triton.Config({'BD': 128}, num_warps=8),\n ],\n key=['D']\n)\n@triton.jit\ndef chunk_hgrn_bwd_kernel_h(\n g,\n gc,\n dx,\n do,\n T: tl.constexpr,\n D: tl.constexpr,\n BT: tl.constexpr,\n BD: tl.constexpr\n):\n i_d, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n o_d = i_d * BD + tl.arange(0, BD)\n mask = o_d < D\n BC = min(BT, T - i_t * BT)\n NT = tl.num_programs(1)\n\n p_g = g + (i_bh * T + i_t * BT + BC - 1) * D + o_d\n p_gc = gc + (i_bh * T + i_t * BT + BC - 1) * D + o_d\n p_dx = dx + (i_bh * T + i_t * BT + BC - 1) * D + o_d\n p_do = do + (i_bh * T + i_t * BT + BC - 1) * D + o_d\n\n if i_t == NT - 1:\n b_gc = tl.zeros([BD], dtype=tl.float32)\n else:\n b_gc = tl.load(g + (i_bh * T + i_t * BT + BT) * D + o_d, mask=mask, other=0).to(tl.float32)\n b_dh = tl.zeros([BD], dtype=tl.float32)\n for _ in range(BC - 1, -1, -1):\n tl.store(p_gc, b_gc.to(p_gc.dtype.element_ty), mask=mask)\n\n b_g = tl.load(p_g, mask=mask, other=0).to(tl.float32)\n b_do = tl.load(p_do, mask=mask, other=0).to(tl.float32)\n\n b_gc = b_gc + b_g\n b_dh = b_dh + b_do\n b_dx = b_dh\n b_dh = b_dh * tl.exp(b_g)\n\n tl.store(p_dx, b_dx.to(p_dx.dtype.element_ty), mask=mask)\n\n p_g -= D\n p_gc -= D\n p_dx -= D\n p_do -= D\n\n\n@triton.jit\ndef chunk_hgrn_bwd_kernel_o(\n g,\n gc,\n o,\n dx,\n dg,\n s_h,\n s_t,\n s_d,\n T: tl.constexpr,\n D: tl.constexpr,\n BT: tl.constexpr,\n BD: tl.constexpr\n):\n i_d, i_bh = tl.program_id(0), tl.program_id(1)\n o_d = i_d * BD + tl.arange(0, BD)\n mask = o_d < D\n\n for i_t in range(tl.cdiv(T, BT) - 1, -1, -1):\n p_g = tl.make_block_ptr(g + i_bh * s_h, (T, D), (s_t, s_d), (i_t * BT, i_d * BD), (BT, BD), (1, 0))\n p_gc = tl.make_block_ptr(gc + i_bh * s_h, (T, D), (s_t, s_d), (i_t * BT, i_d * BD), (BT, BD), (1, 0))\n p_o = tl.make_block_ptr(o + i_bh * s_h, (T, D), (s_t, s_d), (i_t * BT - 1, i_d * BD), (BT, BD), (1, 0))\n p_dx = tl.make_block_ptr(dx + i_bh * s_h, (T, D), (s_t, s_d), (i_t * BT, i_d * BD), (BT, BD), (1, 0))\n p_dg = tl.make_block_ptr(dg + i_bh * s_h, (T, D), (s_t, s_d), (i_t * BT, i_d * BD), (BT, BD), (1, 0))\n\n # [BD,]\n mask_t = mask & ((i_t + 1) * BT < T)\n b_ht = tl.load(dx + i_bh * T * D + (i_t + 1) * BT * D + o_d, mask=mask_t, other=0).to(tl.float32)\n # [BT, BD]\n b_g = tl.load(p_g, boundary_check=(0, 1)).to(tl.float32)\n b_gc = tl.load(p_gc, boundary_check=(0, 1)).to(tl.float32)\n b_o = tl.load(p_o, boundary_check=(0, 1)).to(tl.float32)\n b_dx = tl.load(p_dx, boundary_check=(0, 1)).to(tl.float32)\n b_dg = tl.load(p_dg, boundary_check=(0, 1)).to(tl.float32)\n b_dx = b_dx + tl.exp(b_gc) * b_ht[None, :]\n b_dg = b_o * b_dx * tl.exp(b_g)\n tl.store(p_dx, b_dx.to(p_dx.dtype.element_ty), boundary_check=(0, 1))\n tl.store(p_dg, b_dg.to(p_dg.dtype.element_ty), boundary_check=(0, 1))\n\n\nclass ChunkHGRNFunction(torch.autograd.Function):\n\n @staticmethod\n def forward(ctx, x, g, initial_state=None, output_final_state=False):\n B, H, T, D = x.shape\n BT, BD = 128, min(64, triton.next_power_of_2(D))\n num_warps = 8 if BD == 64 else 4\n\n gc = torch.empty_like(g, dtype=torch.float)\n o = torch.empty_like(x, dtype=torch.float)\n def grid(meta): return (triton.cdiv(D, meta['BD']), triton.cdiv(T, meta['BT']), B * H)\n chunk_hgrn_fwd_kernel_h[grid](\n x, g, gc, o, initial_state,\n T, D,\n BT=BT,\n USE_INITIAL_STATE=initial_state is not None\n )\n def grid(meta): return (triton.cdiv(D, meta['BD']), B * H)\n chunk_hgrn_fwd_kernel_o[grid](\n gc, o,\n o.stride(1), o.stride(2), o.stride(3),\n T, D,\n BT=BT, BD=BD,\n num_warps=num_warps\n )\n final_state = None\n if output_final_state:\n final_state = o[:, :, -1].clone()\n o = o.to(x.dtype)\n ctx.save_for_backward(g, o, initial_state)\n return o, final_state\n\n @staticmethod\n def backward(ctx, do, dht=None):\n g, o, initial_state = ctx.saved_tensors\n B, H, T, D = do.shape\n BT, BD = 128, min(64, triton.next_power_of_2(D))\n num_warps = 8 if BD == 64 else 4\n\n gc = torch.empty_like(g, dtype=torch.float)\n dx = torch.empty_like(o)\n dg = torch.empty_like(g)\n def grid(meta): return (triton.cdiv(D, meta['BD']), triton.cdiv(T, meta['BT']), B * H)\n chunk_hgrn_bwd_kernel_h[grid](\n g, gc, dx, do,\n T, D,\n BT=BT\n )\n def grid(meta): return (triton.cdiv(D, meta['BD']), B * H)\n chunk_hgrn_bwd_kernel_o[grid](\n g, gc, o, dx, dg,\n o.stride(1), o.stride(2), o.stride(3),\n T, D,\n BT=BT, BD=BD,\n num_warps=num_warps\n )\n if initial_state is not None:\n dg[:, :, 0] = initial_state * dx[:, :, 0] * g[:, :, 0].exp()\n\n return dx, dg, None, None\n\n\ndef chunk_hgrn(\n x: torch.Tensor,\n g: torch.Tensor,\n initial_state: torch.Tensor = None,\n output_final_state: bool = False\n) -> Tuple[torch.Tensor, torch.Tensor]:\n if initial_state is not None:\n initial_state = initial_state.detach()\n o, final_state = ChunkHGRNFunction.apply(x, g, initial_state, output_final_state)\n return o, final_state\n", - "description_1": "Use triton language to implement forward and backward kernels for the chunk-wise HGRN. The forward kernel `chunk_hgrn_fwd_kernel_h` takes inputs x (input tensor), g (gating tensor), gc (cumulative gating tensor), o (output tensor), h0 (initial hidden state), and several compile-time constants (T, D, BT, BD, USE_INITIAL_STATE) to perform computation over each chunk. The kernel computes updates in hidden states and output for each time step using exponential smoothing with gating values. The output is stored in o. The `chunk_hgrn_fwd_kernel_o` further processes the cumulative gating for subsequent chunks. The backward kernels `chunk_hgrn_bwd_kernel_h` and `chunk_hgrn_bwd_kernel_o` compute gradients with respect to the input and gate tensors in a similar manner by reversing the forward computations. The kernels are autotuned for different configurations to optimize performance.", - "description_2": "Use triton language to develop efficient kernels for chunk-wise HGRN forward and backward passes, leveraging Triton’s autotuning for performance optimization.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef fused_recurrent_hgrn_fwd_kernel(\n x,\n g,\n o,\n h0,\n ht,\n T: tl.constexpr,\n D: tl.constexpr,\n BD: tl.constexpr,\n USE_INITIAL_STATE: tl.constexpr,\n STORE_FINAL_STATE: tl.constexpr\n):\n i_d, i_bh = tl.program_id(0), tl.program_id(1)\n o_d = i_d * BD + tl.arange(0, BD)\n mask = o_d < D\n\n p_x = x + i_bh * T * D + o_d\n p_g = g + i_bh * T * D + o_d\n p_o = o + i_bh * T * D + o_d\n\n b_h = tl.zeros([BD], dtype=tl.float32)\n if USE_INITIAL_STATE:\n p_h0 = h0 + i_bh * D + o_d\n b_h += tl.load(p_h0, mask=mask, other=0).to(tl.float32)\n for _ in range(0, T):\n b_x = tl.load(p_x, mask=mask, other=0).to(tl.float32)\n b_g = tl.load(p_g, mask=mask, other=0).to(tl.float32)\n b_h = tl.exp(b_g) * b_h + b_x\n tl.store(p_o, b_h.to(p_o.dtype.element_ty), mask=mask)\n\n p_x += D\n p_g += D\n p_o += D\n\n if STORE_FINAL_STATE:\n p_ht = ht + i_bh * D + o_d\n tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask)\n\n@triton.jit\ndef fused_recurrent_hgrn_bwd_kernel(\n g,\n o,\n dx,\n dg,\n do,\n h0,\n T: tl.constexpr,\n D: tl.constexpr,\n BD: tl.constexpr,\n USE_INITIAL_STATE: tl.constexpr\n):\n i_d, i_bh = tl.program_id(0), tl.program_id(1)\n o_d = i_d * BD + tl.arange(0, BD)\n mask = o_d < D\n\n p_g = g + (i_bh * T + T - 1) * D + o_d\n p_o = o + (i_bh * T + T - 2) * D + o_d\n p_dx = dx + (i_bh * T + T - 1) * D + o_d\n p_dg = dg + (i_bh * T + T - 1) * D + o_d\n p_do = do + (i_bh * T + T - 1) * D + o_d\n\n b_dh = tl.zeros([BD], dtype=tl.float32)\n for i in range(T - 1, -1, -1):\n b_g = tl.load(p_g, mask=mask, other=0).to(tl.float32)\n b_do = tl.load(p_do, mask=mask, other=0).to(tl.float32)\n if i > 0:\n b_o = tl.load(p_o, mask=mask, other=0).to(tl.float32)\n elif USE_INITIAL_STATE:\n b_o = tl.load(h0 + i_bh * D + o_d, mask=mask, other=0).to(tl.float32)\n else:\n b_o = tl.zeros([BD], dtype=tl.float32)\n\n b_dh = b_dh + b_do\n b_dx = b_dh\n b_dh = b_dh * tl.exp(b_g)\n b_dg = b_dh * b_o\n tl.store(p_dx, b_dx.to(p_dx.dtype.element_ty), mask=mask)\n tl.store(p_dg, b_dg.to(p_dg.dtype.element_ty), mask=mask)\n\n p_g -= D\n p_o -= D\n p_dx -= D\n p_dg -= D\n p_do -= D\n\n\nclass FusedRecurrentHGRNFunction(torch.autograd.Function):\n\n @staticmethod\n def forward(ctx, x, g, initial_state=None, output_final_state=False):\n B, H, T, D = x.shape\n\n final_state = None\n if output_final_state:\n final_state = x.new_empty(B, H, D)\n\n o = torch.empty_like(x)\n def grid(meta): return (triton.cdiv(D, meta['BD']), B * H)\n fused_recurrent_hgrn_fwd_kernel[grid](\n x, g, o, initial_state, final_state,\n T, D,\n USE_INITIAL_STATE=initial_state is not None,\n STORE_FINAL_STATE=final_state is not None\n )\n ctx.save_for_backward(g, o, initial_state)\n return o, final_state\n\n @staticmethod\n def backward(ctx, do, dht=None):\n g, o, initial_state = ctx.saved_tensors\n B, H, T, D = do.shape\n\n dx = torch.empty_like(o)\n dg = torch.empty_like(g)\n def grid(meta): return (triton.cdiv(D, meta['BD']), B * H)\n fused_recurrent_hgrn_bwd_kernel[grid](\n g, o, dx, dg, do, initial_state,\n T, D,\n USE_INITIAL_STATE=initial_state is not None,\n )\n\n return dx, dg, None, None\n\n\ndef fused_recurrent_hgrn(\n x: torch.Tensor,\n g: torch.Tensor,\n initial_state: torch.Tensor = None,\n output_final_state: bool = False\n) -> Tuple[torch.Tensor, torch.Tensor]:\n if initial_state is not None:\n initial_state = initial_state.detach()\n o, final_state = FusedRecurrentHGRNFunction.apply(x, g, initial_state, output_final_state)\n return o, final_state\n", - "description_1": "Use triton language to implement two kernels for fused recurrent computation. The forward kernel takes 9 parameters: x, g, o, h0, ht, and 4 constant expressions: T, D, BD, USE_INITIAL_STATE, and STORE_FINAL_STATE, which control the dimensions and the use of initial and final states. The backward kernel uses 8 parameters: g, o, dx, dg, do, h0, and 3 constant expressions: T, D, BD, USE_INITIAL_STATE, for calculating gradients. The operation computes recurrent updates in both forward and backward passes, handling initial and final states if specified.", - "description_2": "Use triton language to develop a fused recurrent computation kernel with forward and backward passes, handling initial and final states, using constant expressions for dimensions.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\nfrom packaging import version\nfrom torch.cuda.amp import custom_bwd, custom_fwd\n\n@triton.jit\ndef fused_chunk_linear_attn_fwd_kernel(\n q, # query [B, H, L, D_head_K]\n k, # key [B, H, L, D_head_V]\n v, # value [B, H, L, D_head_V]\n o, # output [B, H, L, D_head_V]\n initial_state, # initial state of the chunk [B, H, D_head_K, D_head_V]\n final_state, # final state of the chunk [B, H, D_head_K, D_head_V]\n s_qk_h, # stride size: L * D_head_K\n s_qk_t, # stride size: D_head_K\n s_qk_d, # stride size: 1\n s_vo_h, # stride size: L * D_head_V\n s_vo_t, # stride size: D_head_V\n s_vo_d, # stride size: 1\n B, # batch size\n H, # n_heads\n T, # seq_len\n scale, # D_head_K ** -0.5\n BT: tl.constexpr, # BLOCK SIZE along the sequence dimension, a.k.a. chunk size\n BK: tl.constexpr, # BLOCK SIZE along the K dimension\n BV: tl.constexpr, # BLOCK SIZE along the V dimension\n DK: tl.constexpr, # D_head_K\n DV: tl.constexpr, # D_head_V\n USE_INITIAL_STATE: tl.constexpr,\n STORE_FINAL_STATE: tl.constexpr,\n CHECK: tl.constexpr\n):\n # indices\n i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n\n o_i = tl.arange(0, BT)\n\n # [BT, BT]\n m_s = o_i[:, None] >= o_i[None, :]\n # [BK, BV]\n b_h = tl.zeros([BK, BV], dtype=tl.float32)\n\n # make block pointers\n p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (0, i_k * BK), (BT, BK), (1, 0))\n p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (DK, T), (s_qk_d, s_qk_t), (i_k * BK, 0), (BK, BT), (0, 1))\n p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (0, i_v * BV), (BT, BV), (1, 0))\n p_o = tl.make_block_ptr(o + (i_bh+i_k*B*H) * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (0, i_v * BV), (BT, BV), (1, 0))\n\n if USE_INITIAL_STATE:\n p_h = tl.make_block_ptr(initial_state + i_bh * DK * DV, (DK, DV), (DV, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))\n b_h = tl.load(p_h, boundary_check=(0, 1)).to(tl.float32)\n\n for i in range(0, tl.cdiv(T, BT)):\n # [BK, BT]\n b_k = tl.load(p_k, boundary_check=(0, 1))\n # [BT, BV]\n b_v = tl.load(p_v, boundary_check=(0, 1))\n # [BT, BK]\n b_q = tl.load(p_q, boundary_check=(0, 1))\n b_q = (b_q * scale).to(b_k.dtype)\n\n # [BT, BT]\n b_s = tl.dot(b_q, b_k, allow_tf32=False)\n b_s = tl.where(m_s, b_s, 0)\n # [BT, BV]\n b_o = tl.dot(b_s.to(b_q.dtype), b_v, allow_tf32=False)\n if CHECK and i == 0:\n b_o += tl.dot(b_q, b_h.to(b_q.dtype), allow_tf32=False)\n b_h = b_h + tl.dot(b_k, b_v, allow_tf32=False)\n else:\n b_o += tl.dot(b_q, b_h.to(b_q.dtype), allow_tf32=False)\n b_h = b_h + tl.dot(b_k, b_v, allow_tf32=False)\n tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))\n p_q = tl.advance(p_q, (BT, 0))\n p_k = tl.advance(p_k, (0, BT))\n p_v = tl.advance(p_v, (BT, 0))\n p_o = tl.advance(p_o, (BT, 0))\n\n if STORE_FINAL_STATE:\n p_final = tl.make_block_ptr(final_state + i_bh * DK * DV, (DK, DV), (DV, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))\n tl.store(p_final, b_h.to(p_final.dtype.element_ty), boundary_check=(0, 1))\n\n\n@triton.jit\ndef fused_chunk_linear_attn_bwd_kernel(\n q, # query [B, H, L, D_head_K]\n k, # key [B, H, L, D_head_V]\n v, # value [B, H, L, D_head_V]\n do, # gradient of output [B, H, L, D_head_V]\n dq, # gradient of query [NV, B, H, L, D_head_K]\n dk, # gradient of key [NV, B, H, L, D_head_K]\n dv, # gradient of value [NK, B, H, L, D_head_V]\n\n initial_state, # initial state of the chunk [B, H, D_head_K, D_head_V]\n\n s_qk_h, # stride size: L * D_head_K\n s_qk_t, # stride size: D_head_K\n s_qk_d, # stride size: 1\n s_vo_h, # stride size: L * D_head_V\n s_vo_t, # stride size: D_head_V\n s_vo_d, # stride size: 1\n B, # batch_size\n H, # n_heads\n T, # seq_len\n scale, # D_head_K ** -0.5\n BT: tl.constexpr, # BLOCK SIZE along the sequence dimension, a.k.a. chunk size\n BK: tl.constexpr, # BLOCK SIZE along the K dimension\n BV: tl.constexpr, # BLOCK SIZE along the V dimension\n DK: tl.constexpr, # D_head_K\n DV: tl.constexpr, # D_head_V\n USE_INITIAL_STATE: tl.constexpr,\n CHECK: tl.constexpr\n):\n i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n o_i = tl.arange(0, BT)\n\n m_s = o_i[:, None] >= o_i[None, :]\n # [BV, BK]\n b_h = tl.zeros([BV, BK], dtype=tl.float32)\n if USE_INITIAL_STATE:\n p_h = tl.make_block_ptr(initial_state + i_bh * DK * DV, (DV, DK), (1, DV), (i_v * BV, i_k * BK), (BV, BK), (0, 1))\n b_h = tl.load(p_h, boundary_check=(0, 1)).to(tl.float32)\n\n for i in range(0, tl.cdiv(T, BT)):\n p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (i * BT, i_k * BK), (BT, BK), (1, 0))\n p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (DV, T), (s_vo_d, s_vo_t), (i_v * BV, i * BT), (BV, BT), (0, 1))\n p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (i * BT, i_v * BV), (BT, BV), (1, 0))\n p_dq = tl.make_block_ptr(dq + (i_bh + i_v*B*H) * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (i*BT, i_k*BK), (BT, BK), (1, 0))\n\n # [BT, DK]\n b_k = tl.load(p_k, boundary_check=(0, 1))\n # [DV, BT]\n b_v = tl.load(p_v, boundary_check=(0, 1))\n # [BT, DV]\n b_do = tl.load(p_do, boundary_check=(0, 1))\n\n # [BT, BT]\n b_ds = tl.dot(b_do, b_v, allow_tf32=False)\n b_ds = tl.where(m_s, b_ds, 0)\n # [BT, DK]\n b_dq = tl.dot(b_ds.to(b_k.dtype), b_k, allow_tf32=False)\n # [DV, DK]\n if CHECK and i == 0:\n b_dq += tl.dot(b_do, b_h.to(b_do.dtype), allow_tf32=False)\n b_h = b_h + tl.dot(b_v, b_k, allow_tf32=False)\n else:\n b_dq += tl.dot(b_do, b_h.to(b_do.dtype), allow_tf32=False)\n b_h = b_h + tl.dot(b_v, b_k, allow_tf32=False)\n b_dq *= scale\n tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))\n\n # sync threads\n b_h = None\n tl.debug_barrier()\n # [BK, BV]\n b_dh = tl.zeros([BK, BV], dtype=tl.float32)\n m_s = o_i[:, None] <= o_i[None, :]\n for i in range(1, tl.cdiv(T, BT) + 1):\n p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (DK, T), (s_qk_d, s_qk_t), (i_k * BK, T - i * BT), (BK, BT), (0, 1))\n p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (T - i * BT, i_k * BK), (BT, BK), (1, 0))\n p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (T - i * BT, i_v * BV), (BT, BV), (1, 0))\n p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (T - i * BT, i_v * BV), (BT, BV), (1, 0))\n p_dk = tl.make_block_ptr(dk + (i_bh+i_v*B*H) * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (T - i*BT, i_k*BK), (BT, BK), (1, 0))\n p_dv = tl.make_block_ptr(dv + (i_bh+i_k*B*H) * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (T - i*BT, i_v*BV), (BT, BV), (1, 0))\n # [DK, BT]\n b_q = tl.load(p_q, boundary_check=(0, 1))\n # [BT, DK]\n b_k = tl.load(p_k, boundary_check=(0, 1))\n # [BT, DV]\n b_v = tl.load(p_v, boundary_check=(0, 1))\n b_do = tl.load(p_do, boundary_check=(0, 1))\n\n # b_dd = (b_do]).to(b_do.dtype)\n\n # [BT, BT]\n b_ds = tl.dot(b_v, tl.trans(b_do), allow_tf32=False)\n b_ds = tl.where(m_s, b_ds, 0).to(b_q.dtype)\n # [BT, BT]\n b_s = tl.dot(b_k, b_q, allow_tf32=False) * scale\n b_s = tl.where(m_s, b_s, 0).to(b_q.dtype)\n # [BT, DK]\n b_dk = tl.dot(b_ds, tl.trans(b_q), allow_tf32=False)\n # [BT, DV]\n b_dv = tl.dot(b_s, b_do, allow_tf32=False)\n if CHECK and i == 1:\n b_dk += tl.dot(b_v, tl.trans(b_dh).to(b_v.dtype), allow_tf32=False)\n b_dv += tl.dot(b_k, b_dh.to(b_k.dtype), allow_tf32=False)\n b_dh += tl.dot(b_q, b_do, allow_tf32=False)\n else:\n b_dk += tl.dot(b_v, tl.trans(b_dh).to(b_v.dtype), allow_tf32=False)\n b_dv += tl.dot(b_k, b_dh.to(b_k.dtype), allow_tf32=False)\n b_dh += tl.dot(b_q, b_do, allow_tf32=False)\n\n tl.store(p_dk, (b_dk * scale).to(p_dk.dtype.element_ty), boundary_check=(0, 1))\n tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))\n\n\nclass FusedChunkLinearAttentionFunction(torch.autograd.Function):\n @staticmethod\n @custom_fwd\n def forward(ctx, q, k, v, scale, initial_state, output_final_state):\n batch_size, n_heads, seq_len, d_head_qk = q.shape\n d_head_v = v.shape[-1]\n ctx.scale = scale\n BT = 64\n BK, BV = min(triton.next_power_of_2(d_head_qk), 64), min(triton.next_power_of_2(d_head_v), 64)\n NK, NV = triton.cdiv(d_head_qk, BK), triton.cdiv(d_head_v, BV)\n num_stages = 1\n num_warps = 4\n o = q.new_empty(NK, batch_size, n_heads, seq_len, d_head_v)\n if output_final_state:\n final_state = q.new_empty(batch_size, n_heads, d_head_qk, d_head_v, dtype=torch.float32, requires_grad=False)\n else:\n final_state = None\n CHECK = True\n if version.parse(triton.__version__) < version.parse('2.2.0'):\n import warnings\n warnings.warn(\n \"Triton<2.2.0 detected for running this kernel, \"\n \"which is known to have some weird compiler issues (refer to https://github.com/openai/triton/issues/2852) \"\n \"that lead to significant precision loss. \"\n \"We've add some initial condition checks to resolve this, sadly at the sacrifice of the speed. \"\n \"For optimal performance, it is recommended to install Triton>=2.2.0 (if possible).\"\n )\n CHECK = True\n\n grid = (NV, NK, batch_size * n_heads)\n fused_chunk_linear_attn_fwd_kernel[grid](\n q, k, v, o, initial_state, final_state,\n q.stride(1), q.stride(2), q.stride(3),\n v.stride(1), v.stride(2), v.stride(3),\n batch_size, n_heads, seq_len, scale,\n BT=BT, DK=d_head_qk, DV=d_head_v, BK=BK, BV=BV,\n USE_INITIAL_STATE=initial_state is not None,\n STORE_FINAL_STATE=output_final_state,\n CHECK=CHECK,\n num_warps=num_warps,\n num_stages=num_stages\n )\n\n o = o.sum(0)\n ctx.save_for_backward(q, k, v, initial_state)\n ctx.CHECK = CHECK\n return o.to(q.dtype), final_state\n\n @staticmethod\n @custom_bwd\n def backward(ctx, do, d_final_state=None):\n q, k, v, initial_state = ctx.saved_tensors\n batch_size, n_heads, seq_len, d_head_qk = q.shape\n d_head_v = v.shape[-1]\n scale = ctx.scale\n\n BT = 64\n BK, BV = min(triton.next_power_of_2(d_head_qk), 64), min(triton.next_power_of_2(d_head_v), 64)\n NK, NV = triton.cdiv(d_head_qk, BK), triton.cdiv(d_head_v, BV)\n num_stages = 1\n num_warps = 4\n\n dq = q.new_empty(NV, batch_size, n_heads, seq_len, d_head_qk)\n dk = q.new_empty(NV, batch_size, n_heads, seq_len, d_head_qk)\n dv = q.new_empty(NK, batch_size, n_heads, seq_len, d_head_v)\n grid = (NV, NK, batch_size * n_heads)\n\n fused_chunk_linear_attn_bwd_kernel[grid](\n q, k, v, do, dq, dk, dv, initial_state,\n q.stride(1), q.stride(2), q.stride(3),\n v.stride(1), v.stride(2), v.stride(3),\n batch_size, n_heads, seq_len, scale,\n BT=BT, DK=d_head_qk, DV=d_head_v, BK=BK, BV=BV,\n USE_INITIAL_STATE=initial_state is not None,\n CHECK=ctx.CHECK,\n num_warps=num_warps,\n num_stages=num_stages\n )\n dq = dq.sum(0)\n dk = dk.sum(0)\n dv = dv.sum(0)\n return dq.to(q.dtype), dk.to(k.dtype), dv.to(v.dtype), None, None, None\n\n\ndef fused_chunk_linear_attn(\n q: torch.Tensor,\n k: torch.Tensor,\n v: torch.Tensor,\n scale: float = -1,\n initial_state: torch.Tensor = None,\n output_final_state: bool = False,\n normalize: bool = True\n) -> Tuple[torch.Tensor, torch.Tensor]:\n if initial_state is not None:\n initial_state = initial_state.detach()\n if scale == -1:\n scale = q.shape[-1] ** -0.5\n o, final_state = FusedChunkLinearAttentionFunction.apply(q, k, v, scale, initial_state, output_final_state)\n return o, final_state\n", - "description_1": "Use triton language to implement forward and backward kernels for a fused chunked linear attention mechanism. The forward kernel processes query, key, and value tensors and outputs a result tensor and optionally a final state tensor. The backward kernel computes the gradients for query, key, and value based on the gradient of the output and optionally an initial state. This is implemented using triton for high-performance computation on GPUs, leveraging block pointers and efficient memory operations.", - "description_2": "Use triton language to create a high-performance fused chunked linear attention operator, including both forward and backward passes, for GPU acceleration.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef fused_recurrent_linear_attn_fwd_kernel(\n q, # query [B, H, L, D_head_K]\n k, # key [B, H, L, D_head_V]\n v, # value [B, H, L, D_head_V]\n o, # output [B, H, L, D_head_V]\n initial_state,\n final_state, # final hidden state [B, H, D_head_K, D_head_V]\n\n s_qk_h, # stride size: L * D_head_K\n s_qk_t, # stride size: D_head_K\n s_qk_d, # stride size: 1\n\n s_vo_h, # stride size: L * D_head_V\n s_vo_t, # stride size: D_head_V\n s_vo_d, # stride size: 1\n\n B, # batch size\n H, # n_heads\n T, # seq_len\n scale, # D_head_K ** -0.5\n BK: tl.constexpr, # BLOCK SIZE along the K dimension\n BV: tl.constexpr, # BLOCK SIZE along the V dimension\n DK: tl.constexpr, # D_head_K\n DV: tl.constexpr, # D_head_V\n USE_INITIAL_STATE: tl.constexpr, # whether to use initial state\n STORE_FINAL_STATE: tl.constexpr, # whether to store final state\n):\n # indices\n i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n\n p_q = q + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK)\n p_k = k + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK)\n p_v = v + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV)\n p_o = o + (i_bh + i_k * B * H) * s_vo_h + i_v * BV + tl.arange(0, BV)\n\n mask_bk = (i_k * BK + tl.arange(0, BK)) < DK\n mask_bv = (i_v * BV + tl.arange(0, BV)) < DV\n mask_kv = mask_bk[None, :] & mask_bv[:, None]\n\n h = tl.zeros([BV, BK], dtype=tl.float32)\n\n if USE_INITIAL_STATE:\n p_init_s = initial_state + i_bh * DK * DV + \\\n (i_k * BK + tl.arange(0, BK)[None, :]) * \\\n DV + (i_v * BV + tl.arange(0, BV)[:, None])\n h += tl.load(p_init_s, mask=mask_kv, other=0).to(tl.float32)\n\n for _ in range(0, T):\n _k = tl.load(p_k, mask=mask_bk, other=0).to(tl.float32)\n _v = tl.load(p_v, mask=mask_bv, other=0).to(tl.float32)\n _q = tl.load(p_q, mask=mask_bk, other=0).to(tl.float32) * scale\n\n h += _k[None, :] * _v[:, None]\n _o = h * _q[None, :]\n _o = tl.sum(_o, axis=1)\n tl.store(p_o, _o.to(p_o.dtype.element_ty), mask=mask_bv)\n\n p_q += DK\n p_k += DK\n p_o += DV\n p_v += DV\n\n if STORE_FINAL_STATE:\n p_final_s = final_state + i_bh * DK * DV + \\\n (i_k * BK + tl.arange(0, BK)[None, :]) * \\\n DV + (i_v * BV + tl.arange(0, BV)[:, None])\n tl.store(p_final_s, h.to(p_final_s.dtype.element_ty), mask=mask_kv)\n\n\n@triton.jit\ndef fused_recurrent_linear_attn_bwd_kernel(\n q, # query [B, H, L, D_head_K]\n k, # key [B, H, L, D_head_V]\n v, # value [B, H, L, D_head_V]\n\n do, # gradient of output [B, H, L, D_head_V]\n dq, # gradient of query [NV, B, H, L, D_head_K]\n dk, # gradient of key [NV, B, H, L, D_head_K]\n dv, # gradient of value [NK, B, H, L, D_head_V]\n\n # initial hidden state initialization [B, H, D_head_K, D_head_V]\n initial_state,\n\n s_qk_h, # stride size: L * D_head_K\n s_qk_t, # stride size: D_head_K\n s_qk_d, # stride size: 1\n\n s_vo_h, # stride size: L * D_head_V\n s_vo_t, # stride size: D_head_V\n s_vo_d, # stride size: 1\n\n B, # batch_size\n H, # n_heads\n T, # seq_len\n scale, # D_head_K ** -0.5\n BK: tl.constexpr, # BLOCK SIZE along the K dimension\n BV: tl.constexpr, # BLOCK SIZE along the V dimension\n DK: tl.constexpr, # D_head_K\n DV: tl.constexpr, # D_head_V\n USE_INITIAL_STATE: tl.constexpr, # whether to use initial state\n):\n i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n\n p_q = q + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK)\n p_k = k + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK)\n p_v = v + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV)\n p_do = do + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV)\n\n p_dq = dq + (i_bh + i_v * B * H) * s_qk_h + i_k * BK + tl.arange(0, BK)\n mask_bk = i_k * BK + tl.arange(0, BK) < DK\n mask_bv = i_v * BV + tl.arange(0, BV) < DV\n\n h = tl.zeros([BK, BV], dtype=tl.float32)\n\n if USE_INITIAL_STATE:\n mask_kv = mask_bk[:, None] & mask_bv[None, :]\n p_init_s = initial_state + i_bh * DK * DV + \\\n (i_k * BK + tl.arange(0, BK)[:, None]) * \\\n DV + (i_v * BV + tl.arange(0, BV)[None, :])\n h += tl.load(p_init_s, mask=mask_kv, other=0).to(tl.float32)\n\n for i in range(0, T):\n _k = tl.load(p_k, mask=mask_bk, other=0).to(tl.float32)\n _v = tl.load(p_v, mask=mask_bv, other=0).to(tl.float32)\n _do = tl.load(p_do, mask=mask_bv, other=0).to(tl.float32)\n\n h += _k[:, None] * _v[None, :]\n _d_q = h * _do[None, :]\n d_q = tl.sum(_d_q, axis=1) * scale\n tl.store(p_dq, d_q.to(p_dq.dtype.element_ty), mask=mask_bk)\n\n p_k += DK\n p_do += DV\n p_v += DV\n p_dq += DK\n\n # sync threads\n tl.debug_barrier()\n\n p_q = q + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (T - 1) * DK\n p_k = k + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (T - 1) * DK\n p_do = do + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV) + (T - 1) * DV\n p_v = v + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV) + (T - 1) * DV\n p_dk = dk + (i_bh + i_v * B * H) * s_qk_h + i_k * \\\n BK + tl.arange(0, BK) + (T - 1) * DK\n p_dv = dv + (i_bh + i_k * B * H) * s_vo_h + i_v * \\\n BV + tl.arange(0, BV) + (T - 1) * DV\n d_h = tl.zeros([BK, BV], dtype=tl.float32)\n\n for _ in range(T):\n _do = tl.load(p_do, mask=mask_bv, other=0).to(tl.float32)\n _q = tl.load(p_q, mask=mask_bk, other=0).to(tl.float32) * scale\n _k = tl.load(p_k, mask=mask_bk, other=0).to(tl.float32)\n _v = tl.load(p_v, mask=mask_bv, other=0).to(tl.float32)\n d_h += _q[:, None] * _do[None, :]\n d_k = tl.sum(d_h * _v[None, :], axis=1)\n d_v = tl.sum(d_h * _k[:, None], axis=0)\n\n tl.store(p_dk, d_k.to(p_dk.dtype.element_ty), mask=mask_bk)\n tl.store(p_dv, d_v.to(p_dv.dtype.element_ty), mask=mask_bv)\n\n p_do -= DV\n p_q -= DK\n p_k -= DK\n p_v -= DV\n p_dk -= DK\n p_dv -= DV\n\n\nclass FusedRecurrentLinearAttentionFunction(torch.autograd.Function):\n\n @staticmethod\n def forward(ctx, q, k, v, initial_state=None, output_final_state=False):\n batch_size, n_heads, seq_len, d_head_qk = q.shape\n d_head_v = v.shape[-1]\n\n scale = d_head_qk ** -0.5\n BK, BV = min(d_head_qk, 32), min(d_head_v, 32)\n NK, NV = triton.cdiv(d_head_qk, BK), triton.cdiv(d_head_v, BV)\n num_stages = 1\n num_warps = 1\n\n o = q.new_empty(NK, batch_size, n_heads, seq_len, d_head_v)\n\n if output_final_state:\n final_state = q.new_empty(batch_size, n_heads, d_head_qk, d_head_v)\n else:\n final_state = None\n\n grid = (NV, NK, batch_size * n_heads)\n fused_recurrent_linear_attn_fwd_kernel[grid](\n q, k, v, o, initial_state, final_state,\n q.stride(1), q.stride(2), q.stride(3),\n v.stride(1), v.stride(2), v.stride(3),\n batch_size, n_heads, seq_len, scale,\n DK=d_head_qk, DV=d_head_v, BK=BK, BV=BV,\n num_warps=num_warps,\n num_stages=num_stages,\n USE_INITIAL_STATE=initial_state is not None,\n STORE_FINAL_STATE=final_state is not None\n )\n\n o = o.sum(0)\n ctx.save_for_backward(q, k, v, initial_state)\n return o, final_state\n\n @staticmethod\n def backward(ctx, do, d_final_state=None):\n q, k, v, initial_state = ctx.saved_tensors\n batch_size, n_heads, seq_len, d_head_qk = q.shape\n d_head_v = v.shape[-1]\n scale = d_head_qk ** -0.5\n\n BK, BV = min(d_head_qk, 32), min(d_head_v, 32)\n NK, NV = triton.cdiv(d_head_qk, BK), triton.cdiv(d_head_v, BV)\n num_stages = 1\n num_warps = 1\n\n dq = q.new_empty(NV, batch_size, n_heads, seq_len, d_head_qk)\n dk = q.new_empty(NV, batch_size, n_heads, seq_len, d_head_qk)\n dv = q.new_empty(NK, batch_size, n_heads, seq_len, d_head_v)\n grid = (NV, NK, batch_size * n_heads)\n\n fused_recurrent_linear_attn_bwd_kernel[grid](\n q, k, v, do, dq, dk, dv, initial_state,\n q.stride(1), q.stride(2), q.stride(3),\n v.stride(1), v.stride(2), v.stride(3),\n batch_size, n_heads, seq_len, scale,\n DK=d_head_qk, DV=d_head_v, BK=BK, BV=BV,\n num_warps=num_warps,\n num_stages=num_stages,\n USE_INITIAL_STATE=initial_state is not None\n )\n dq = dq.sum(0)\n dk = dk.sum(0)\n dv = dv.sum(0)\n return dq, dk, dv, None, None\n\n\ndef fused_recurrent_linear_attn(\n q: torch.Tensor,\n k: torch.Tensor,\n v: torch.Tensor,\n initial_state: torch.Tensor = None,\n output_final_state: bool = False,\n normalize: bool = False\n) -> Tuple[torch.Tensor, torch.Tensor]:\n if initial_state is not None:\n initial_state = initial_state.detach()\n o, final_state = FusedRecurrentLinearAttentionFunction.apply(\n q, k, v, initial_state, output_final_state)\n return o, final_state\n", - "description_1": "Use triton language to implement a fused recurrent linear attention forward and backward kernel. The forward kernel takes 20 parameters: q, k, v, o, initial_state, final_state, s_qk_h, s_qk_t, s_qk_d, s_vo_h, s_vo_t, s_vo_d, B, H, T, scale, BK, BV, DK, DV, USE_INITIAL_STATE, STORE_FINAL_STATE. The backward kernel takes 21 parameters: q, k, v, do, dq, dk, dv, initial_state, s_qk_h, s_qk_t, s_qk_d, s_vo_h, s_vo_t, s_vo_d, B, H, T, scale, BK, BV, DK, DV, USE_INITIAL_STATE. The forward and backward functions work together to compute the attention mechanism while optionally using and storing states.", - "description_2": "Use triton language to create forward and backward kernels for fused recurrent linear attention, handling inputs q, k, v with optional state usage. The forward kernel computes attention outputs and optionally final states, while the backward kernel computes gradients for q, k, v based on the gradients of the output.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef parallel_rebased_fwd_kernel(\n q, k, v, o, z,\n s_qk_h, s_qk_t, s_qk_d,\n s_vo_h, s_vo_t, s_vo_d,\n B, H, T, scale,\n BTL: tl.constexpr, BTS: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr,\n DK: tl.constexpr, DV: tl.constexpr,\n):\n i_kv, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n NV = tl.cdiv(DV, BV)\n i_k = i_kv // NV\n i_v = i_kv % NV\n\n p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (T, DK),\n (s_qk_t, s_qk_d), (i_c * BTL, i_k * BK), (BTL, BK), (1, 0))\n p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (DK, T),\n (s_qk_d, s_qk_t), (i_k * BK, 0), (BK, BTS), (0, 1))\n p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, DV),\n (s_vo_t, s_vo_d), (0, i_v * BV), (BTS, BV), (1, 0))\n\n b_q = tl.load(p_q, boundary_check=(0, 1))\n b_q = (b_q * scale).to(b_q.dtype)\n b_o = tl.zeros([BTL, BV], dtype=tl.float32)\n b_z = tl.zeros([BTL], dtype=tl.float32)\n\n for _ in range(0, i_c * BTL, BTS):\n b_k = tl.load(p_k, boundary_check=(0, 1))\n b_v = tl.load(p_v, boundary_check=(0, 1))\n b_s = tl.dot(b_q, b_k, allow_tf32=False)\n b_s = b_s * b_s\n b_z += tl.sum(b_s, axis=1)\n b_o = b_o + tl.dot(b_s.to(b_v.dtype), b_v, allow_tf32=False)\n p_k = tl.advance(p_k, (0, BTS))\n p_v = tl.advance(p_v, (BTS, 0))\n\n tl.debug_barrier()\n o_q = tl.arange(0, BTL)\n o_k = tl.arange(0, BTS)\n p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (DK, T),\n (s_qk_d, s_qk_t), (i_k * BK, i_c * BTL), (BK, BTS), (0, 1))\n p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, DV),\n (s_vo_t, s_vo_d), (i_c * BTL, i_v * BV), (BTS, BV), (1, 0))\n\n for _ in range(i_c * BTL, (i_c + 1) * BTL, BTS):\n b_k = tl.load(p_k, boundary_check=(0, 1))\n b_v = tl.load(p_v, boundary_check=(0, 1))\n m_s = o_q[:, None] >= o_k[None, :]\n b_s = tl.dot(b_q, b_k, allow_tf32=False)\n b_s = b_s * b_s\n b_s = tl.where(m_s, b_s, 0)\n b_z += tl.sum(b_s, axis=1)\n b_o += tl.dot(b_s.to(b_q.dtype), b_v, allow_tf32=False)\n p_k = tl.advance(p_k, (0, BTS))\n p_v = tl.advance(p_v, (BTS, 0))\n o_k += BTS\n\n p_o = tl.make_block_ptr(o + (i_bh + B * H * i_k) * s_vo_h, (T, DV),\n (s_vo_t, s_vo_d), (i_c*BTL, i_v*BV), (BTL, BV), (1, 0))\n p_z = z + (i_bh + B * H * i_k) * T + i_c * BTL + tl.arange(0, BTL)\n tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))\n tl.store(p_z, b_z.to(p_z.dtype.element_ty),\n mask=((i_c * BTL + tl.arange(0, BTL)) < T))\n\n\n@triton.jit\ndef parallel_rebased_bwd_kernel(\n q, k, v, do, dz, dq, dk, dv, s_qk_h, s_qk_t, s_qk_d, s_vo_h,\n s_vo_t, s_vo_d, B, H, T, scale,\n BTL: tl.constexpr, BTS: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr,\n DK: tl.constexpr, DV: tl.constexpr,\n):\n i_kv, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n NV = tl.cdiv(DV, BV)\n i_k = i_kv // NV\n i_v = i_kv % NV\n i_h = i_bh % H\n _parallel_rebased_bwd_dq(\n i_bh, i_c, i_k, i_v, i_h,\n q, k, v, do, dz, dq, s_qk_h, s_qk_t, s_qk_d, s_vo_h,\n s_vo_t, s_vo_d, B, H, T, scale, BTL=BTL, BTS=BTS, BK=BK, BV=BV, DK=DK, DV=DV\n )\n tl.debug_barrier()\n _parallel_rebased_bwd_dkv(\n i_bh, i_c, i_k, i_v, i_h,\n q, k, v, do, dz, dk, dv, s_qk_h, s_qk_t, s_qk_d, s_vo_h,\n s_vo_t, s_vo_d, B, H, T, scale, BTL, BTS, BK, BV, DK, DV\n )\n\n\nclass ParallelBasedFunction(torch.autograd.Function):\n @staticmethod\n def forward(ctx, q, k, v, scale):\n BTL, BTS = 128, 32\n BK = min(128, triton.next_power_of_2(k.shape[-1]))\n BV = min(128, triton.next_power_of_2(v.shape[-1]))\n BK, BV = max(BK, 16), max(BV, 16)\n batch_size, n_heads, seq_len, d_head_qk = q.shape\n d_head_v = v.shape[-1]\n num_stages = 2\n num_warps = 4\n NK = triton.cdiv(d_head_qk, BK)\n NV = triton.cdiv(d_head_v, BV)\n grid = (NK * NV, triton.cdiv(seq_len, BTL), batch_size * n_heads)\n\n assert NK == 1, \"will encounter some synchronization issue if not.\"\n\n o = torch.empty(NK, batch_size, n_heads, seq_len,\n d_head_v, device=q.device)\n z = torch.empty(NK, batch_size, n_heads, seq_len,\n device=q.device)\n parallel_rebased_fwd_kernel[grid](\n q, k, v, o, z,\n q.stride(1), q.stride(2), q.stride(3),\n v.stride(1), v.stride(2), v.stride(3),\n batch_size, n_heads, seq_len, scale,\n BTL=BTL, BTS=BTS, BK=BK, BV=BV, DK=d_head_qk, DV=d_head_v,\n num_warps=num_warps,\n num_stages=num_stages\n )\n ctx.save_for_backward(q, k, v)\n ctx.scale = scale\n return o.sum(0).to(q.dtype), z.sum(0).to(q.dtype)\n\n @staticmethod\n def backward(ctx, do, dz):\n q, k, v = ctx.saved_tensors\n scale = ctx.scale\n BTL, BTS = 64, 32\n BK = min(128, triton.next_power_of_2(k.shape[-1]))\n BV = min(128, triton.next_power_of_2(v.shape[-1]))\n BK, BV = max(BK, 16), max(BV, 16)\n batch_size, n_heads, seq_len, d_head_qk = q.shape\n d_head_v = v.shape[-1]\n num_stages = 2\n num_warps = 4\n NK = triton.cdiv(d_head_qk, BK)\n NV = triton.cdiv(d_head_v, BV)\n grid = (NK * NV, triton.cdiv(seq_len, BTL), batch_size * n_heads)\n\n assert NK == 1, \"will encounter some synchronization issue if not\"\n\n dq = torch.empty(NV, batch_size, n_heads, seq_len,\n d_head_qk, dtype=q.dtype, device=q.device)\n dk = torch.empty(NV, batch_size, n_heads, seq_len,\n d_head_qk, dtype=q.dtype, device=q.device)\n dv = torch.empty(NK, batch_size, n_heads, seq_len,\n d_head_v, dtype=q.dtype, device=q.device)\n\n parallel_rebased_bwd_kernel[grid](\n q, k, v, do, dz, dq, dk, dv,\n q.stride(1), q.stride(2), q.stride(3),\n v.stride(1), v.stride(2), v.stride(3),\n batch_size, n_heads, seq_len, scale,\n BTL=BTL, BTS=BTS, BK=BK, BV=BV, DK=d_head_qk, DV=d_head_v,\n num_warps=num_warps,\n num_stages=num_stages\n )\n\n return dq.sum(0).to(q.dtype), dk.sum(0).to(k.dtype), dv.sum(0).to(v.dtype), None\n\n\ntriton_parallel_based = ParallelBasedFunction.apply\n\n\ndef parallel_rebased(q, k, v, eps=1e-5, use_scale=True, use_normalize=True, return_both=False):\n if use_scale:\n scale = q.shape[-1] ** -0.5\n else:\n scale = 1\n o, z = triton_parallel_based(q, k, v, scale)\n if return_both:\n return o, z\n if use_normalize:\n o = o / (z[..., None] + eps)\n else:\n o = o\n return o.to(q.dtype)\n", - "description_1": "Use triton language to implement forward and backward kernels for parallel rebased operations. The forward kernel 'parallel_rebased_fwd_kernel' has parameters: q (query tensor), k (key tensor), v (value tensor), o (output tensor), z (normalizer), and various strides and block sizes for tensor dimensions (B, H, T, etc.). The backward kernel 'parallel_rebased_bwd_kernel' utilizes saved tensors from forward pass and computes gradients for q, k, and v. The class 'ParallelBasedFunction' applies these kernels in the forward and backward methods, allowing use in autograd. These functions require specific memory layouts and compute scales.", - "description_2": "Use triton language to implement parallel forward and backward kernels for efficient tensor operations in autograd. These kernels operate on query, key, and value tensors with specific stride and block size configurations, computing outputs and gradients.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef chunk_retention_fwd_kernel_h(\n k, v, h, initial_state, final_state,\n s_qk_h, s_qk_t, s_qk_d,\n s_vo_h, s_vo_t, s_vo_d,\n s_h_h, s_h_t,\n H: tl.constexpr, T: tl.constexpr, K: tl.constexpr, V: tl.constexpr,\n BT: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr, NT: tl.constexpr,\n USE_INITIAL_STATE: tl.constexpr, STORE_FINAL_STATE: tl.constexpr\n):\n i_k, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n i_h = i_bh % H\n b_b = tl.math.log2(1 - tl.math.pow(2, -5 - i_h * 1.0))\n\n o_i = tl.arange(0, BT)\n d_b, d_i = tl.math.exp2(BT * b_b), tl.math.exp2((BT - o_i - 1) * b_b)\n b_h = tl.zeros([BK, BV], dtype=tl.float32)\n\n if USE_INITIAL_STATE:\n p_h0 = tl.make_block_ptr(initial_state + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))\n b_h = tl.load(p_h0, boundary_check=(0, 1)).to(tl.float32)\n\n for i_t in range(NT):\n p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1))\n p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))\n p_h = tl.make_block_ptr(h + i_bh * s_h_h + i_t * K * V, (K, V), (s_h_t, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))\n\n tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1))\n b_k = tl.load(p_k, boundary_check=(0, 1))\n b_v = tl.load(p_v, boundary_check=(0, 1))\n if i_t == NT - 1 and (T % BT) != 0:\n d_b = tl.math.exp2((T % BT) * b_b)\n d_i = tl.math.exp2(((T % BT) - o_i - 1) * b_b)\n b_h = d_b * b_h + tl.dot(b_k, (b_v * d_i[:, None]).to(b_k.dtype), allow_tf32=False)\n\n if STORE_FINAL_STATE:\n p_ht = tl.make_block_ptr(final_state + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))\n tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), boundary_check=(0, 1))\n\n@triton.jit\ndef chunk_retention_fwd_kernel_o(\n q, k, v, h, o,\n s_qk_h, s_qk_t, s_qk_d,\n s_vo_h, s_vo_t, s_vo_d,\n s_h_h, s_h_t, scale,\n H: tl.constexpr, T: tl.constexpr, K: tl.constexpr, V: tl.constexpr,\n BT: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr\n):\n i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n i_h = i_bh % H\n b_b = tl.math.log2(1 - tl.math.pow(2, -5 - i_h * 1.0))\n\n o_i = tl.arange(0, BT)\n d_i = tl.math.exp2((o_i + 1) * b_b)\n m_s = o_i[:, None] >= o_i[None, :]\n d_s = tl.where(m_s, tl.math.exp2((o_i[:, None] - o_i[None, :]) * b_b), 0)\n\n b_o = tl.zeros([BT, BV], dtype=tl.float32)\n b_s = tl.zeros([BT, BT], dtype=tl.float32)\n for i_k in range(tl.cdiv(K, BK)):\n p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))\n p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1))\n p_h = tl.make_block_ptr(h + i_bh * s_h_h + i_t * K * V, (K, V), (s_h_t, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))\n b_q = tl.load(p_q, boundary_check=(0, 1))\n b_k = tl.load(p_k, boundary_check=(0, 1))\n b_h = tl.load(p_h, boundary_check=(0, 1))\n b_o += tl.dot((b_q * d_i[:, None]).to(b_q.dtype), b_h, allow_tf32=False)\n b_s += tl.dot(b_q, b_k, allow_tf32=False)\n\n b_s *= d_s\n p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))\n b_v = tl.load(p_v, boundary_check=(0, 1))\n b_o = (b_o + tl.dot(b_s.to(b_v.dtype), b_v, allow_tf32=False)) * scale\n p_o = tl.make_block_ptr(o + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))\n tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))\n\n@triton.jit\ndef chunk_retention_bwd_kernel_dh(\n q, do, dh,\n s_qk_h, s_qk_t, s_qk_d,\n s_vo_h, s_vo_t, s_vo_d,\n s_h_h, s_h_t, scale,\n H: tl.constexpr, T: tl.constexpr, K: tl.constexpr, V: tl.constexpr,\n BT: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr, NT: tl.constexpr\n):\n i_k, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n i_h = i_bh % H\n b_b = tl.math.log2(1 - tl.math.pow(2, -5 - i_h * 1.0))\n\n o_i = tl.arange(0, BT)\n d_b, d_i = tl.math.exp2(BT * b_b), tl.math.exp2((o_i + 1) * b_b)\n b_dh = tl.zeros([BK, BV], dtype=tl.float32)\n for i_t in range(NT - 1, -1, -1):\n p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1))\n p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))\n p_dh = tl.make_block_ptr(dh + i_bh * s_h_h + i_t * K * V, (K, V), (s_h_t, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))\n\n tl.store(p_dh, b_dh.to(p_dh.dtype.element_ty), boundary_check=(0, 1))\n b_q = tl.load(p_q, boundary_check=(0, 1))\n b_q = (b_q * scale).to(b_q.dtype)\n b_do = tl.load(p_do, boundary_check=(0, 1))\n b_dh = d_b * b_dh + tl.dot(b_q, (b_do * d_i[:, None]).to(b_q.dtype), allow_tf32=False)\n\n@triton.jit\ndef chunk_retention_bwd_kernel_dqkv(\n q, k, v, h, do, dh, dq, dk, dv,\n s_qk_h, s_qk_t, s_qk_d,\n s_vo_h, s_vo_t, s_vo_d,\n s_h_h, s_h_t, scale,\n H: tl.constexpr, T: tl.constexpr, K: tl.constexpr, V: tl.constexpr,\n BT: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr, NT: tl.constexpr\n):\n i_k, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n i_h = i_bh % H\n n_bh = tl.num_programs(2)\n b_b = tl.math.log2(1 - tl.math.pow(2, -5 - i_h * 1.0))\n\n o_i = tl.arange(0, BT)\n d_q, d_k = tl.math.exp2((o_i + 1) * b_b), tl.math.exp2((BT - o_i - 1) * b_b)\n d_q = (d_q * scale).to(d_q.dtype)\n m_s = o_i[:, None] >= o_i[None, :]\n d_s = tl.where(m_s, tl.math.exp2((o_i[:, None] - o_i[None, :]) * b_b), 0) * scale\n\n p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1))\n p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))\n\n b_q = tl.load(p_q, boundary_check=(0, 1))\n b_k = tl.load(p_k, boundary_check=(0, 1))\n b_s = tl.dot(b_k, b_q, allow_tf32=False) * tl.trans(d_s)\n\n b_dq = tl.zeros([BT, BK], dtype=tl.float32)\n b_dk = tl.zeros([BT, BK], dtype=tl.float32)\n b_ds = tl.zeros([BT, BT], dtype=tl.float32)\n for i_v in range(tl.cdiv(V, BV)):\n p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))\n p_h = tl.make_block_ptr(h + i_bh * s_h_h, (V, NT * K), (1, s_h_t), (i_v * BV, i_t * K + i_k * BK), (BV, BK), (0, 1))\n p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))\n p_dh = tl.make_block_ptr(dh + i_bh * s_h_h, (NT * K, V), (s_h_t, 1), (i_t * K + i_k * BK, i_v * BV), (BK, BV), (1, 0))\n p_dv = tl.make_block_ptr(dv + (i_k*n_bh+i_bh)*s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))\n \n b_v = tl.load(p_v, boundary_check=(0, 1))\n b_do = tl.load(p_do, boundary_check=(0, 1))\n b_h = tl.load(p_h, boundary_check=(0, 1))\n b_dh = tl.load(p_dh, boundary_check=(0, 1))\n\n b_ds += tl.dot(b_do, tl.trans(b_v), allow_tf32=False)\n b_dq += tl.dot(b_do, b_h, allow_tf32=False)\n b_dk += tl.dot(b_v, tl.trans(b_dh), allow_tf32=False)\n b_dv = tl.dot(b_k, b_dh, allow_tf32=False) * d_k[:, None] + tl.dot(b_s.to(b_q.dtype), b_do, allow_tf32=False)\n tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))\n\n b_ds = (b_ds * d_s).to(b_q.dtype)\n b_dq = b_dq * d_q[:, None] + tl.dot(b_ds, b_k, allow_tf32=False)\n b_dk = b_dk * d_k[:, None] + tl.trans(tl.dot(b_q, b_ds, allow_tf32=False))\n\n p_dq = tl.make_block_ptr(dq + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))\n p_dk = tl.make_block_ptr(dk + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))\n tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))\n tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))\n\nclass ChunkRetentionFunction(torch.autograd.Function):\n\n @staticmethod\n def forward(ctx, q, k, v, initial_state, output_final_state):\n B, H, T, K, V = *q.shape, v.shape[-1]\n BT = 64\n BK, BV = min(64, triton.next_power_of_2(K)), min(64, triton.next_power_of_2(V))\n NT, NK, NV = triton.cdiv(T, BT), triton.cdiv(K, BK), triton.cdiv(V, BV)\n num_stages = 1\n num_warps = 4 if BK == 64 else 2\n scale = K ** -0.5\n\n final_state = None\n if output_final_state:\n final_state = q.new_empty(B, H, K, V, dtype=torch.float32, requires_grad=False)\n\n h = q.new_empty(B, H, NT * K, V)\n grid = (NK, NV, B * H)\n chunk_retention_fwd_kernel_h[grid](\n k, v, h, initial_state, final_state,\n q.stride(1), q.stride(2), q.stride(3),\n v.stride(1), v.stride(2), v.stride(3),\n h.stride(1), h.stride(2),\n H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT,\n USE_INITIAL_STATE=initial_state is not None,\n STORE_FINAL_STATE=output_final_state,\n num_warps=num_warps,\n num_stages=num_stages\n )\n grid = (NV, NT, B * H)\n o = torch.empty_like(v)\n chunk_retention_fwd_kernel_o[grid](\n q, k, v, h, o,\n q.stride(1), q.stride(2), q.stride(3),\n v.stride(1), v.stride(2), v.stride(3),\n h.stride(1), h.stride(2),\n scale,\n H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV,\n num_warps=num_warps,\n num_stages=num_stages\n )\n\n ctx.save_for_backward(q, k, v, h)\n return o.to(q.dtype), final_state\n\n @staticmethod\n def backward(ctx, do, d_ht=None):\n q, k, v, h = ctx.saved_tensors\n\n B, H, T, K, V = *q.shape, v.shape[-1]\n BT = 64\n BK, BV = min(64, triton.next_power_of_2(K)), min(64, triton.next_power_of_2(V))\n NT, NK, NV = triton.cdiv(T, BT), triton.cdiv(K, BK), triton.cdiv(V, BV)\n num_stages = 1\n num_warps = 4 if BK == 64 else 2\n scale = K ** -0.5\n\n dh = q.new_empty(B, H, NT * K, V)\n grid = (NK, NV, B * H)\n chunk_retention_bwd_kernel_dh[grid](\n q, do, dh,\n q.stride(1), q.stride(2), q.stride(3),\n v.stride(1), v.stride(2), v.stride(3),\n dh.stride(1), dh.stride(2),\n scale,\n H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT,\n num_warps=num_warps,\n num_stages=num_stages\n )\n\n grid = (NK, NT, B * H)\n dq = torch.empty_like(q)\n dk = torch.empty_like(k)\n dv = v.new_empty(NK, *v.shape)\n num_stages = 1\n num_warps = 4 if BK == 64 else 2\n chunk_retention_bwd_kernel_dqkv[grid](\n q, k, v, h, do, dh, dq, dk, dv,\n q.stride(1), q.stride(2), q.stride(3),\n v.stride(1), v.stride(2), v.stride(3),\n dh.stride(1), dh.stride(2),\n scale,\n H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT,\n num_warps=num_warps,\n num_stages=num_stages\n )\n dv = dv.sum(0)\n return dq.to(q.dtype), dk.to(k.dtype), dv.to(v.dtype), None, None\n\ndef chunk_retention(q, k, v, initial_state=None, output_final_state=False):\n if initial_state is not None:\n initial_state = initial_state.detach()\n o, final_state = ChunkRetentionFunction.apply(q, k, v, initial_state, output_final_state)\n return o, final_state\n", - "description_1": "Use triton language to define forward and backward kernels for chunk retention operation. The forward kernels compute intermediate states and the output by processing input tensors q, k, v, and optional initial state. The backward kernels compute gradients for q, k, v using computed output gradient do. The function chunk_retention serves as the interface by handling inputs, invoking kernels with proper grid sizes, and returning the result.", - "description_2": "Use triton language to perform parallel computation on tensors with dimensions specified. Implement forward and backward operations for neural network layers, efficiently computing necessary state and gradients.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\nfrom torch.cuda.amp import custom_bwd, custom_fwd\n\n@triton.jit\ndef fused_chunk_retention_fwd_kernel(\n q, k, v, o, initial_state, final_state,\n s_qk_h, s_qk_t, s_qk_d,\n s_vo_h, s_vo_t, s_vo_d,\n B, H, T, scale,\n BT: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr,\n DK: tl.constexpr, DV: tl.constexpr,\n USE_INITIAL_STATE: tl.constexpr, STORE_FINAL_STATE: tl.constexpr, CHECK: tl.constexpr\n):\n i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n i_h = i_bh % H\n\n o_i = tl.arange(0, BT)\n b_b = tl.math.log2(1 - tl.math.pow(2, -5 - i_h * 1.0))\n\n d_b, d_o, d_h = tl.math.exp2(BT * b_b), tl.math.exp2((o_i + 1) * b_b), tl.math.exp2((BT - o_i - 1) * b_b)\n\n m_s = o_i[:, None] >= o_i[None, :]\n d_s = tl.where(m_s, tl.math.exp2((o_i[:, None] - o_i[None, :]) * b_b), 0)\n b_h = tl.zeros([BK, BV], dtype=tl.float32)\n\n p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (0, i_k * BK), (BT, BK), (1, 0))\n p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (DK, T), (s_qk_d, s_qk_t), (i_k * BK, 0), (BK, BT), (0, 1))\n p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (0, i_v * BV), (BT, BV), (1, 0))\n p_o = tl.make_block_ptr(o + (i_bh+i_k*B*H) * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (0, i_v * BV), (BT, BV), (1, 0))\n\n if USE_INITIAL_STATE:\n p_h = tl.make_block_ptr(initial_state + i_bh * DK * DV, (DK, DV), (DV, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))\n b_h = tl.load(p_h, boundary_check=(0, 1)).to(tl.float32)\n \n NT = tl.cdiv(T, BT)\n for i in range(0, NT):\n b_k = tl.load(p_k, boundary_check=(0, 1))\n b_v = tl.load(p_v, boundary_check=(0, 1))\n b_q = tl.load(p_q, boundary_check=(0, 1))\n b_q = (b_q * scale).to(b_k.dtype)\n\n b_s = tl.dot(b_q, b_k, allow_tf32=False) * d_s\n b_o = tl.dot(b_s.to(b_q.dtype), b_v, allow_tf32=False)\n if CHECK and i == 0:\n b_o += tl.dot(b_q, b_h.to(b_q.dtype), allow_tf32=False) * d_o[:, None]\n b_h = d_b * b_h + tl.dot(b_k, (b_v * d_h[:, None]).to(b_k.dtype), allow_tf32=False)\n else:\n b_o += tl.dot(b_q, b_h.to(b_q.dtype), allow_tf32=False) * d_o[:, None]\n if i == NT - 1 and (T % BT) != 0:\n d_b = tl.math.exp2((T % BT) * b_b)\n d_h = tl.math.exp2(((T % BT) - o_i - 1) * b_b)\n b_h = d_b * b_h + tl.dot(b_k, (b_v * d_h[:, None]).to(b_k.dtype), allow_tf32=False)\n tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))\n\n p_q = tl.advance(p_q, (BT, 0))\n p_k = tl.advance(p_k, (0, BT))\n p_v = tl.advance(p_v, (BT, 0))\n p_o = tl.advance(p_o, (BT, 0))\n\n if STORE_FINAL_STATE:\n p_final = tl.make_block_ptr(final_state + i_bh * DK * DV, (DK, DV), (DV, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))\n tl.store(p_final, b_h.to(p_final.dtype.element_ty), boundary_check=(0, 1))\n\n\n@triton.jit\ndef fused_chunk_retention_bwd_kernel(\n q, k, v, do, dq, dk, dv, initial_state,\n s_qk_h, s_qk_t, s_qk_d,\n s_vo_h, s_vo_t, s_vo_d,\n B, H, T, scale,\n BT: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr,\n DK: tl.constexpr, DV: tl.constexpr,\n USE_INITIAL_STATE: tl.constexpr, CHECK: tl.constexpr\n):\n i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n i_h = i_bh % H\n\n o_i = tl.arange(0, BT)\n b_b = tl.math.log2(1 - tl.math.pow(2, -5 - i_h * 1.0))\n d_q, d_k = tl.math.exp2((o_i+1) * b_b) * scale, tl.math.exp2((BT - o_i - 1) * b_b)\n d_b = tl.math.exp2(BT * b_b)\n\n m_s = o_i[:, None] >= o_i[None, :]\n d_s = tl.where(m_s, tl.math.exp2((o_i[:, None] - o_i[None, :]) * b_b), 0) * scale\n b_h = tl.zeros([BV, BK], dtype=tl.float32)\n if USE_INITIAL_STATE:\n p_h = tl.make_block_ptr(initial_state + i_bh * DK * DV, (DV, DK), (1, DV), (i_v * BV, i_k * BK), (BV, BK), (0, 1))\n b_h = tl.load(p_h, boundary_check=(0, 1)).to(tl.float32)\n\n for i in range(0, tl.cdiv(T, BT)):\n p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (i * BT, i_k * BK), (BT, BK), (1, 0))\n p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (DV, T), (s_vo_d, s_vo_t), (i_v * BV, i * BT), (BV, BT), (0, 1))\n p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (i * BT, i_v * BV), (BT, BV), (1, 0))\n p_dq = tl.make_block_ptr(dq + (i_bh + i_v*B*H) * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (i*BT, i_k*BK), (BT, BK), (1, 0))\n\n b_k = tl.load(p_k, boundary_check=(0, 1))\n b_v = tl.load(p_v, boundary_check=(0, 1))\n b_do = tl.load(p_do, boundary_check=(0, 1))\n b_dd = (b_do * d_q[:, None]).to(b_do.dtype)\n\n b_ds = tl.dot(b_do, b_v, allow_tf32=False)\n b_ds = (b_ds * d_s).to(b_k.dtype)\n b_dq = tl.dot(b_ds, b_k, allow_tf32=False)\n if CHECK and i == 0:\n b_dq += tl.dot(b_dd, b_h.to(b_k.dtype), allow_tf32=False)\n b_h = d_b * b_h + tl.dot((b_v * d_k[None, :]).to(b_k.dtype), b_k, allow_tf32=False)\n else:\n b_dq += tl.dot(b_dd, b_h.to(b_k.dtype), allow_tf32=False)\n b_h = d_b * b_h + tl.dot((b_v * d_k[None, :]).to(b_k.dtype), b_k, allow_tf32=False)\n\n tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))\n\n b_h = None\n tl.debug_barrier()\n d_s = tl.trans(d_s)\n b_dh = tl.zeros([BK, BV], dtype=tl.float32)\n for i in range(1, tl.cdiv(T, BT) + 1):\n p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (DK, T), (s_qk_d, s_qk_t), (i_k * BK, T - i * BT), (BK, BT), (0, 1))\n p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (T - i * BT, i_k * BK), (BT, BK), (1, 0))\n p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (T - i * BT, i_v * BV), (BT, BV), (1, 0))\n p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (T - i * BT, i_v * BV), (BT, BV), (1, 0))\n p_dk = tl.make_block_ptr(dk + (i_bh+i_v*B*H) * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (T - i*BT, i_k*BK), (BT, BK), (1, 0))\n p_dv = tl.make_block_ptr(dv + (i_bh+i_k*B*H) * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (T - i*BT, i_v*BV), (BT, BV), (1, 0))\n \n b_q = tl.load(p_q, boundary_check=(0, 1))\n b_k = tl.load(p_k, boundary_check=(0, 1))\n b_v = tl.load(p_v, boundary_check=(0, 1))\n b_do = tl.load(p_do, boundary_check=(0, 1))\n b_dd = (b_do * d_q[:, None]).to(b_do.dtype)\n\n b_ds = tl.dot(b_v, tl.trans(b_do), allow_tf32=False)\n b_ds = (b_ds * d_s).to(b_k.dtype)\n\n b_s = tl.dot(b_k, b_q, allow_tf32=False) * d_s\n b_dk = tl.dot(b_ds, tl.trans(b_q), allow_tf32=False)\n b_dv = tl.dot(b_s.to(b_q.dtype), b_do, allow_tf32=False)\n if CHECK and i == 1:\n b_dk += tl.dot(b_v, tl.trans(b_dh).to(b_v.dtype), allow_tf32=False) * d_k[:, None]\n b_dv += tl.dot(b_k, b_dh.to(b_k.dtype), allow_tf32=False) * d_k[:, None]\n b_dh = d_b * b_dh + tl.dot(b_q, b_dd, allow_tf32=False)\n else:\n b_dk += tl.dot(b_v, tl.trans(b_dh).to(b_v.dtype), allow_tf32=False) * d_k[:, None]\n b_dv += tl.dot(b_k, b_dh.to(b_k.dtype), allow_tf32=False) * d_k[:, None]\n b_dh = d_b * b_dh + tl.dot(b_q, b_dd, allow_tf32=False)\n\n tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))\n tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))\n\nclass FusedChunkRetentionFunction(torch.autograd.Function):\n\n @staticmethod\n @custom_fwd\n def forward(ctx, q, k, v, initial_state, output_final_state):\n batch_size, n_heads, seq_len, d_head_qk = q.shape\n d_head_v = v.shape[-1]\n\n scale = d_head_qk ** -0.5\n BT = 64\n BK, BV = min(triton.next_power_of_2(d_head_qk), 64), min(triton.next_power_of_2(d_head_v), 64)\n NK, NV = triton.cdiv(d_head_qk, BK), triton.cdiv(d_head_v, BV)\n num_stages = 1\n num_warps = 4\n\n o = q.new_empty(NK, batch_size, n_heads, seq_len, d_head_v)\n\n if output_final_state:\n final_state = q.new_empty(batch_size, n_heads, d_head_qk, d_head_v, dtype=torch.float32, requires_grad=False)\n else:\n final_state = None\n\n CHECK = True\n if version.parse(triton.__version__) < version.parse('2.2.0'):\n import warnings\n warnings.warn(\n \"Triton<2.2.0 detected for running this kernel, \"\n \"which is known to have some weird compiler issues (refer to https://github.com/openai/triton/issues/2852) \"\n \"that lead to significant precision loss. \"\n \"We've add some initial condition checks to resolve this, sadly at the sacrifice of the speed. \"\n \"For optimal performance, it is recommended to install Triton>=2.2.0 (if possible).\"\n )\n CHECK = True\n\n grid = (NV, NK, batch_size * n_heads)\n fused_chunk_retention_fwd_kernel[grid](\n q, k, v, o, initial_state, final_state,\n q.stride(1), q.stride(2), q.stride(3),\n v.stride(1), v.stride(2), v.stride(3),\n batch_size, n_heads, seq_len, scale,\n BT=BT, DK=d_head_qk, DV=d_head_v, BK=BK, BV=BV,\n USE_INITIAL_STATE=initial_state is not None,\n STORE_FINAL_STATE=output_final_state,\n CHECK=CHECK,\n num_warps=num_warps,\n num_stages=num_stages\n )\n\n o = o.sum(0)\n ctx.save_for_backward(q, k, v, initial_state)\n ctx.CHECK = CHECK\n return o.to(q.dtype), final_state\n\n @staticmethod\n @custom_bwd\n def backward(ctx, do, d_final_state=None):\n q, k, v, initial_state = ctx.saved_tensors\n batch_size, n_heads, seq_len, d_head_qk = q.shape\n d_head_v = v.shape[-1]\n scale = d_head_qk ** -0.5\n\n BT = 64\n BK, BV = min(triton.next_power_of_2(d_head_qk), 64), min(triton.next_power_of_2(d_head_v), 64)\n NK, NV = triton.cdiv(d_head_qk, BK), triton.cdiv(d_head_v, BV)\n num_stages = 1\n num_warps = 4\n\n dq = q.new_empty(NV, batch_size, n_heads, seq_len, d_head_qk)\n dk = q.new_empty(NV, batch_size, n_heads, seq_len, d_head_qk)\n dv = q.new_empty(NK, batch_size, n_heads, seq_len, d_head_v)\n grid = (NV, NK, batch_size * n_heads)\n\n fused_chunk_retention_bwd_kernel[grid](\n q, k, v, do, dq, dk, dv, initial_state,\n q.stride(1), q.stride(2), q.stride(3),\n v.stride(1), v.stride(2), v.stride(3),\n batch_size, n_heads, seq_len, scale,\n BT=BT, DK=d_head_qk, DV=d_head_v, BK=BK, BV=BV,\n USE_INITIAL_STATE=initial_state is not None,\n CHECK=ctx.CHECK,\n num_warps=num_warps,\n num_stages=num_stages\n )\n dq = dq.sum(0)\n dk = dk.sum(0)\n dv = dv.sum(0)\n return dq.to(q.dtype), dk.to(k.dtype), dv.to(v.dtype), None, None\n\n\ndef fused_chunk_retention(\n q: torch.Tensor,\n k: torch.Tensor,\n v: torch.Tensor,\n initial_state: torch.Tensor = None,\n output_final_state: bool = False\n) -> Tuple[torch.Tensor, torch.Tensor]:\n if initial_state is not None:\n initial_state = initial_state.detach()\n o, final_state = FusedChunkRetentionFunction.apply(q, k, v, initial_state, output_final_state)\n return o, final_state\n", - "description_1": "Use triton language to implement fused forward and backward kernels for a chunk retention mechanism in a transformer model. The forward kernel computes the attention output and optionally updates the state for each chunk of the input sequence. It requires parameters such as query, key, value tensors, initial and final states, strides, batch size, number of heads, sequence length, scaling factor, and block sizes. The backward kernel computes gradients for the query, key, and value tensors using similar parameters as the forward kernel.", - "description_2": "Use triton language to implement fused kernels for chunk retention in transformers, calculating both forward attention outputs and backward gradients with given tensor parameters and block sizes.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\nfrom torch import Tensor\n\n@triton.jit\ndef fused_recurrent_rwkv4_forward_kernel(\n w_ptr, w_s_c, u_ptr, u_s_c, k_ptr, k_s_b, k_s_t, k_s_c,\n v_ptr, v_s_b, v_s_t, v_s_c, state_ptr, state_s_b, state_s_abe,\n state_s_c, wkv_ptr, wkv_s_b, wkv_s_t, wkv_s_c, state_out_ptr,\n state_out_s_b, state_out_s_abe, state_out_s_t, state_out_s_c,\n chans, tsz, BLOCK_SIZE_C: tl.constexpr,\n):\n b_idx = tl.program_id(0)\n c_idx = tl.program_id(1)\n cs = (c_idx * BLOCK_SIZE_C) + tl.arange(0, BLOCK_SIZE_C)\n cmask = cs < chans\n k_ptr = k_ptr + b_idx * k_s_b\n v_ptr = v_ptr + b_idx * v_s_b\n alpha_ptr = state_ptr + b_idx * state_s_b\n beta_ptr = state_ptr + b_idx * state_s_b + state_s_abe\n eps_ptr = state_ptr + b_idx * state_s_b + 2 * state_s_abe\n wkv_ptr = wkv_ptr + b_idx * wkv_s_b\n alpha_out_ptr = state_out_ptr + b_idx * state_out_s_b\n beta_out_ptr = state_out_ptr + b_idx * state_out_s_b + state_out_s_abe\n eps_out_ptr = state_out_ptr + b_idx * state_out_s_b + 2 * state_out_s_abe\n alpha = tl.load(alpha_ptr + cs * state_s_c, mask=cmask).to(tl.float32)\n beta = tl.load(beta_ptr + cs * state_s_c, mask=cmask).to(tl.float32)\n eps = tl.load(eps_ptr + cs * state_s_c, mask=cmask).to(tl.float32)\n w = tl.load(w_ptr + cs * w_s_c, mask=cmask).to(tl.float32)\n u = tl.load(u_ptr + cs * u_s_c, mask=cmask).to(tl.float32)\n \n for t in range(tsz):\n kt = tl.load(k_ptr + t * k_s_t + cs * k_s_c, mask=cmask).to(tl.float32)\n vt = tl.load(v_ptr + t * v_s_t + cs * v_s_c, mask=cmask).to(tl.float32)\n ukt = u + kt\n tau = tl.maximum(ukt, eps)\n e1a = tl.exp(eps - tau)\n e2a = tl.exp(ukt - tau)\n wkv = (e1a * alpha + e2a * vt) / (e1a * beta + e2a)\n tl.store(wkv_ptr + t * wkv_s_t + cs * wkv_s_c, wkv, mask=cmask)\n w_eps = w + eps\n eps = tl.maximum(w_eps, kt)\n e1b = tl.exp(w_eps - eps)\n e2b = tl.exp(kt - eps)\n alpha = e1b * alpha + e2b * vt\n beta = e1b * beta + e2b\n tl.store(alpha_out_ptr + t * state_out_s_t + cs * state_out_s_c, alpha, mask=cmask)\n tl.store(beta_out_ptr + t * state_out_s_t + cs * state_out_s_c, beta, mask=cmask)\n tl.store(eps_out_ptr + t * state_out_s_t + cs * state_out_s_c, eps, mask=cmask)\n\ndef fused_recurrent_rwkv4_forward(\n w: Tensor, u: Tensor, k: Tensor, v: Tensor, state: Tensor,\n) -> tuple[Tensor, Tensor]:\n (bsz, tsz, chans) = k.shape\n wkvs = k.new_empty(bsz, tsz, chans)\n state_out = k.new_empty(bsz, 3, tsz, chans)\n block_size_c = get_block_size_c(chans)\n \n def grid(meta: dict[str, Any]) -> tuple[int, ...]:\n return (bsz, triton.cdiv(chans, meta[\"BLOCK_SIZE_C\"]))\n \n fused_recurrent_rwkv4_forward_kernel[grid](\n w, w.stride(0), u, u.stride(0), k, k.stride(0), k.stride(1), k.stride(2),\n v, v.stride(0), v.stride(1), v.stride(2), state, state.stride(0), state.stride(1),\n state.stride(3), wkvs, wkvs.stride(0), wkvs.stride(1), wkvs.stride(2),\n state_out, state_out.stride(0), state_out.stride(1), state_out.stride(2),\n state_out.stride(3), chans, tsz, BLOCK_SIZE_C=block_size_c,\n )\n \n state_out = torch.cat((state, state_out), dim=2)\n return wkvs, state_out\n\n@triton.jit\ndef fused_recurrent_rwkv4_backward_kernel(\n w_ptr, w_s_c, u_ptr, u_s_c, k_ptr, k_s_b, k_s_t, k_s_c,\n v_ptr, v_s_b, v_s_t, v_s_c, state_ptr, state_s_b, state_s_abe, state_s_t,\n state_s_c, gwkv_ptr, gwkv_s_b, gwkv_s_t, gwkv_s_c, gstate_out_ptr,\n gstate_out_s_b, gstate_out_s_abe, gstate_out_s_c, gw_ptr, gw_s_c, gu_ptr,\n gu_s_c, gk_ptr, gk_s_b, gk_s_t, gk_s_c, gv_ptr, gv_s_b, gv_s_t, gv_s_c,\n gstate_ptr, gstate_s_b, gstate_s_abe, gstate_s_c, tsz, chans,\n BLOCK_SIZE_C: tl.constexpr,\n):\n b_idx = tl.program_id(0)\n c_idx = tl.program_id(1)\n cs = (c_idx * BLOCK_SIZE_C) + tl.arange(0, BLOCK_SIZE_C)\n cmask = cs < chans\n k_ptr = k_ptr + b_idx * k_s_b\n v_ptr = v_ptr + b_idx * v_s_b\n alpha_ptr = state_ptr + b_idx * state_s_b\n beta_ptr = state_ptr + b_idx * state_s_b + state_s_abe\n eps_ptr = state_ptr + b_idx * state_s_b + 2 * state_s_abe\n gk_ptr = gk_ptr + b_idx * gk_s_b\n gv_ptr = gv_ptr + b_idx * gv_s_b\n gwkv_ptr = gwkv_ptr + b_idx * gwkv_s_b\n galpha_out_ptr = gstate_out_ptr + b_idx * gstate_out_s_b\n gbeta_out_ptr = gstate_out_ptr + b_idx * gstate_out_s_b + gstate_out_s_abe\n geps_out_ptr = gstate_out_ptr + b_idx * gstate_out_s_b + 2 * gstate_out_s_abe\n galpha = tl.load(galpha_out_ptr + gstate_out_s_c * cs, mask=cmask).to(tl.float32)\n gbeta = tl.load(gbeta_out_ptr + gstate_out_s_c * cs, mask=cmask).to(tl.float32)\n geps = tl.load(geps_out_ptr + gstate_out_s_c * cs, mask=cmask).to(tl.float32)\n w = tl.load(w_ptr + w_s_c * cs, mask=cmask).to(tl.float32)\n u = tl.load(u_ptr + u_s_c * cs, mask=cmask).to(tl.float32)\n gw = tl.zeros_like(w)\n gu = tl.zeros_like(u)\n alpha_prev = tl.load(alpha_ptr + tsz * state_s_t + state_s_c * cs, mask=cmask).to(tl.float32)\n beta_prev = tl.load(beta_ptr + tsz * state_s_t + state_s_c * cs, mask=cmask).to(tl.float32)\n eps_prev = tl.load(eps_ptr + tsz * state_s_t + state_s_c * cs, mask=cmask).to(tl.float32)\n \n for t in range(tsz):\n tc = tsz - t - 1\n kt = tl.load(k_ptr + tc * k_s_t + k_s_c * cs, mask=cmask).to(tl.float32)\n vt = tl.load(v_ptr + tc * v_s_t + v_s_c * cs, mask=cmask).to(tl.float32)\n alpha_curr = alpha_prev\n beta_curr = beta_prev\n eps_curr = eps_prev\n alpha_prev = tl.load(alpha_ptr + tc * state_s_t + state_s_c * cs, mask=cmask).to(tl.float32)\n beta_prev = tl.load(beta_ptr + tc * state_s_t + state_s_c * cs, mask=cmask).to(tl.float32)\n eps_prev = tl.load(eps_ptr + tc * state_s_t + state_s_c * cs, mask=cmask).to(tl.float32)\n ukt = u + kt\n tau = tl.maximum(ukt, eps_prev)\n e1 = tl.exp(eps_prev - tau)\n e2 = tl.exp(ukt - tau)\n euke = tl.exp(ukt + eps_prev - 2 * tau)\n denom = e1 * beta_prev + e2\n denom_sq = denom * denom\n gwkvt = tl.load(gwkv_ptr + tc * gwkv_s_t + gwkv_s_c * cs, mask=cmask).to(tl.float32)\n guk = gwkvt * e2 * (e1 * beta_prev * vt - e1 * alpha_prev) / denom_sq\n gu += guk\n gk = guk\n gv = gwkvt * e2 / denom\n galpha_wkv = gwkvt * e1 / denom\n gbeta_wkv = -gwkvt * e1 * (e2 * vt + e1 * alpha_prev) / denom_sq\n geps_wkv_denom = e1 * beta_prev + e2\n geps_wkv = gwkvt * euke * (alpha_prev - vt * beta_prev) / (geps_wkv_denom * geps_wkv_denom)\n e1 = tl.exp(w + eps_prev - eps_curr)\n e2 = tl.exp(kt - eps_curr)\n galpha_we = galpha * e1 * alpha_prev\n gw += galpha_we\n gk += galpha * e2 * vt\n gv += galpha * e2\n geps += galpha * -alpha_curr\n gbeta_we = gbeta * e1 * beta_prev\n gw += gbeta_we\n gk += gbeta * e2\n geps += gbeta * -beta_curr\n geps_mask = w + eps_prev > kt\n geps_we = tl.where(geps_mask, geps, tl.zeros_like(geps))\n gw += geps_we\n gk += tl.where(geps_mask, tl.zeros_like(geps), geps)\n tl.store(gk_ptr + tc * gk_s_t + gk_s_c * cs, gk, mask=cmask)\n tl.store(gv_ptr + tc * gv_s_t + gv_s_c * cs, gv, mask=cmask)\n galpha = galpha * e1 + galpha_wkv\n gbeta = gbeta * e1 + gbeta_wkv\n geps = galpha_we + gbeta_we + geps_we + geps_wkv\n\n galpha_ptr = gstate_ptr + b_idx * gstate_s_b\n gbeta_ptr = gstate_ptr + b_idx * gstate_s_b + gstate_s_abe\n geps_ptr = gstate_ptr + b_idx * gstate_s_b + 2 * gstate_s_abe\n tl.store(galpha_ptr + gstate_s_c * cs, galpha, mask=cmask)\n tl.store(gbeta_ptr + gstate_s_c * cs, gbeta, mask=cmask)\n tl.store(geps_ptr + gstate_s_c * cs, geps, mask=cmask)\n gw_temp = tl.load(gw_ptr + gw_s_c * cs, mask=cmask).to(tl.float32)\n gw_temp += gw\n tl.store(gw_ptr + gw_s_c * cs, gw_temp, mask=cmask)\n gu_temp = tl.load(gu_ptr + gu_s_c * cs, mask=cmask).to(tl.float32)\n gu_temp += gu\n tl.store(gu_ptr + gu_s_c * cs, gu_temp, mask=cmask)\n\ndef fused_recurrent_rwkv4_backward(\n w: Tensor, u: Tensor, k: Tensor, v: Tensor, state: Tensor,\n grad_wkv: Tensor, grad_state: Tensor,\n) -> tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:\n bsz, tsz, chans = k.shape\n gw = torch.zeros_like(w)\n gu = torch.zeros_like(u)\n gk = torch.empty_like(k)\n gv = torch.empty_like(v)\n gstate = k.new_empty(bsz, 3, 1, chans)\n block_size_c = get_block_size_c(chans)\n\n def grid(meta: dict[str, Any]) -> tuple[int, ...]:\n return (bsz, triton.cdiv(chans, meta[\"BLOCK_SIZE_C\"]))\n\n fused_recurrent_rwkv4_backward_kernel[grid](\n w, w.stride(0), u, u.stride(0), k, k.stride(0), k.stride(1), k.stride(2),\n v, v.stride(0), v.stride(1), v.stride(2), state, state.stride(0),\n state.stride(1), state.stride(2), state.stride(3), grad_wkv, grad_wkv.stride(0),\n grad_wkv.stride(1), grad_wkv.stride(2), grad_state, grad_state.stride(0),\n grad_state.stride(1), grad_state.stride(3), gw, gw.stride(0), gu, gu.stride(0),\n gk, gk.stride(0), gk.stride(1), gk.stride(2), gv, gv.stride(0), gv.stride(1),\n gv.stride(2), gstate, gstate.stride(0), gstate.stride(1), gstate.stride(3),\n tsz, chans, BLOCK_SIZE_C=block_size_c,\n )\n\n return gw, gu, gk, gv, gstate\n", - "description_1": "Use triton language to define kernels for the RWKV model. The forward kernel takes 25 arguments: 9 tensor pointers (w_ptr, u_ptr, k_ptr, v_ptr, state_ptr, wkv_ptr, state_out_ptr), their respective strides (w_s_c, u_s_c, k_s_b, k_s_t, k_s_c, v_s_b, v_s_t, v_s_c, state_s_b, state_s_abe, state_s_c, wkv_s_b, wkv_s_t, wkv_s_c, state_out_s_b, state_out_s_abe, state_out_s_t, state_out_s_c), and constant parameters (chans, tsz, BLOCK_SIZE_C). It performs element-wise operations across batches and channels to compute WKV and update state tensors. The backward kernel takes 39 arguments, similar to the forward kernel, plus additional pointers and strides for gradients. It computes gradients for inputs and updates state gradient tensors using pre-computed WKV gradients.", - "description_2": "Use triton language to implement RWKV model kernels for forward and backward passes, handling batch and channel dimensions with element-wise operations. Ensure gradient computation for model parameters and state updates.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.autotune(\n configs=[\n triton.Config({'BS': 16}, num_warps=2),\n triton.Config({'BS': 16}, num_warps=4),\n triton.Config({'BS': 16}, num_warps=8),\n triton.Config({'BS': 32}, num_warps=2),\n triton.Config({'BS': 32}, num_warps=4),\n triton.Config({'BS': 32}, num_warps=8),\n triton.Config({'BS': 64}, num_warps=2),\n triton.Config({'BS': 64}, num_warps=4),\n triton.Config({'BS': 64}, num_warps=8),\n ],\n key=['S']\n)\n@triton.jit\ndef chunk_rwkv6_fwd_kernel_cum(\n s,\n o,\n o_minus_s,\n s_s_h,\n s_s_t,\n s_s_d,\n T: tl.constexpr,\n S: tl.constexpr,\n BT: tl.constexpr,\n BS: tl.constexpr\n):\n i_s, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n o_i = tl.arange(0, BT)\n m_s = tl.where(o_i[:, None] >= o_i[None, :], 1., 0.)\n\n p_s = tl.make_block_ptr(s + i_bh * s_s_h, (T, S), (s_s_t, s_s_d), (i_t * BT, i_s * BS), (BT, BS), (1, 0))\n p_o = tl.make_block_ptr(o + i_bh * s_s_h, (T, S), (s_s_t, s_s_d), (i_t * BT, i_s * BS), (BT, BS), (1, 0))\n p_o_minus_s = tl.make_block_ptr(o_minus_s + i_bh * s_s_h, (T, S), (s_s_t, s_s_d), (i_t * BT, i_s * BS), (BT, BS), (1, 0))\n # [BT, BS]\n b_s = tl.load(p_s, boundary_check=(0, 1)).to(tl.float32)\n b_o = tl.dot(m_s, b_s, allow_tf32=False)\n tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))\n tl.store(p_o_minus_s, (b_o - b_s).to(p_o_minus_s.dtype.element_ty), boundary_check=(0, 1))\n\n\n@triton.jit\ndef post_process_grad(\n q,\n k,\n v,\n u,\n do,\n dk,\n dq,\n du,\n scale,\n s_k_h,\n s_k_t,\n s_k_d,\n s_v_h,\n s_v_t,\n s_v_d,\n H,\n T: tl.constexpr,\n BT: tl.constexpr,\n K: tl.constexpr,\n V: tl.constexpr,\n BK: tl.constexpr,\n BV: tl.constexpr,\n):\n i_t, i_bh = tl.program_id(0), tl.program_id(1)\n i_h = i_bh % H\n\n # Note that BK = tl.next_power_of_2(K), BV = tl.next_power_of_2(V)\n p_q = tl.make_block_ptr(q + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, 0), (BT, BK), (1, 0))\n p_dq = tl.make_block_ptr(dq + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, 0), (BT, BK), (1, 0))\n p_k = tl.make_block_ptr(k + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, 0), (BT, BK), (1, 0))\n p_dk = tl.make_block_ptr(dk + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, 0), (BT, BK), (1, 0))\n p_du = tl.make_block_ptr(du + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, 0), (BT, BK), (1, 0))\n p_v = tl.make_block_ptr(v + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, 0), (BT, BV), (1, 0))\n p_do = tl.make_block_ptr(do + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, 0), (BT, BV), (1, 0))\n p_u = tl.make_block_ptr(u + i_h * K, (K,), (1,), (0,), (BK,), (0,))\n\n b_q = tl.load(p_q, boundary_check=(0, 1))\n b_k = tl.load(p_k, boundary_check=(0, 1))\n b_v = tl.load(p_v, boundary_check=(0, 1))\n b_do = tl.load(p_do, boundary_check=(0, 1))\n b_u = tl.load(p_u, boundary_check=(0,))\n\n b_vdo = tl.sum(b_v * b_do, axis=1)\n b_du = b_vdo[:, None] * b_k * b_q * scale\n b_dq = b_vdo[:, None] * b_k * b_u[None, :] * scale\n b_dk = b_vdo[:, None] * b_q * b_u[None, :] * scale\n\n b_dq += tl.load(p_dq, boundary_check=(0, 1))\n tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))\n\n b_dk += tl.load(p_dk, boundary_check=(0, 1))\n tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))\n\n tl.store(p_du, b_du.to(p_du.dtype.element_ty), boundary_check=(0, 1))\n\n\nclass ChunkRWKV6Function(torch.autograd.Function):\n\n @staticmethod\n def forward(ctx, r, k, v, g, u, scale, initial_state, output_final_state, checkpoint_level):\n q = r # alias\n B, H, T, K, V = *q.shape, v.shape[-1]\n BT, BC = 64, 16\n BK = min(64, triton.next_power_of_2(K))\n BV = min(64, triton.next_power_of_2(V))\n NT, NC = triton.cdiv(T, BT), triton.cdiv(BT, BC)\n NK = triton.cdiv(K, BK)\n NV = triton.cdiv(V, BV)\n num_warps = 4 if BK == 64 else 2\n num_stages = 1\n\n def fwd_inner(q, k, v, g, B, H, T, K, V, BT, BK, BV, NT, h0=None, ht=None):\n NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)\n h = q.new_empty(B, H, NT * K, V)\n grid = (NV, NK, B * H)\n chunk_rwkv6_fwd_kernel_h[grid](\n k, v, g, h, h0, ht,\n k.stride(1), k.stride(2), k.stride(3),\n v.stride(1), v.stride(2), v.stride(3),\n h.stride(1), h.stride(2), h.stride(3),\n T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT,\n USE_INITIAL_STATE=h0 is not None,\n STORE_FINAL_STATE=ht is not None,\n num_warps=num_warps,\n num_stages=num_stages\n )\n return h\n\n final_state = None\n if output_final_state:\n final_state = q.new_empty(B, H, K, V, dtype=torch.float)\n\n g_org, g, gs = g, torch.empty_like(g, dtype=torch.float), torch.empty_like(g, dtype=torch.float)\n def grid(meta): return ((triton.cdiv(meta['S'], meta['BS']), NT, B * H))\n # keep cummulative normalizer in fp32\n # this kernel is equivalent to\n # g_org = g_org.view(B, H, NT, BT, -1)\n # g = g_org.cumsum(-2).view(B, H, T, -1)\n # gs = g - g_org\n chunk_rwkv6_fwd_kernel_cum[grid](\n g_org, g, gs,\n g.stride(1), g.stride(2), g.stride(3),\n T=T, S=K, BT=BT\n )\n h = fwd_inner(\n q=q, k=k, v=v, g=g,\n B=B, H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT,\n h0=initial_state if initial_state is not None else None,\n ht=final_state if final_state is not None else None\n )\n A = q.new_zeros(NK, B, H, T, BT)\n grid = (NK, NT * NC * NC, B * H)\n chunk_rwkv6_fwd_kernel_intra[grid](\n q, k, g, gs, u, A,\n k.stride(1), k.stride(2), k.stride(3),\n scale,\n H=H, T=T, K=K, BT=BT, BC=BC, BK=BK, NC=NC, DK=K,\n num_warps=num_warps,\n num_stages=num_stages\n )\n A = A.sum(0, dtype=A.dtype)\n o = torch.empty_like(v)\n\n grid = (NV, NT, B * H)\n chunk_rwkv6_fwd_kernel_inter[grid](\n q, v, gs, h, o, A,\n k.stride(1), k.stride(2), k.stride(3),\n v.stride(1), v.stride(2), v.stride(3),\n h.stride(1), h.stride(2), h.stride(3),\n scale,\n T=T, K=K, V=V, BT=BT, BK=BK, BV=BV,\n num_warps=num_warps,\n num_stages=num_stages\n )\n\n if checkpoint_level > 1:\n del h\n h, initial_state = None, None\n del g, gs\n ctx.save_for_backward(q, k, v, g_org, u, h, initial_state, A)\n ctx.BT = BT\n ctx.scale = scale\n ctx.checkpoint_level = checkpoint_level\n return o, final_state\n\n @staticmethod\n def backward(ctx, do, dht=None):\n q, k, v, g, u, h, initial_state, A = ctx.saved_tensors\n B, H, T, K, V = *q.shape, v.shape[-1]\n BT, BC = ctx.BT, 16\n BK = min(64, triton.next_power_of_2(K))\n BV = min(64, triton.next_power_of_2(V))\n NT, NC = triton.cdiv(T, BT), triton.cdiv(BT, BC)\n NK = triton.cdiv(K, BK)\n num_warps = 4 if BK == 64 else 2\n num_stages = 1\n\n def fwd_inner(q, k, v, g, B, H, T, K, V, BT, BK, BV, NT, h0=None, ht=None):\n NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)\n h = q.new_empty(B, H, NT * K, V)\n grid = (NV, NK, B * H)\n chunk_rwkv6_fwd_kernel_h[grid](\n k, v, g, h, h0, ht,\n k.stride(1), k.stride(2), k.stride(3),\n v.stride(1), v.stride(2), v.stride(3),\n h.stride(1), h.stride(2), h.stride(3),\n T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT,\n USE_INITIAL_STATE=h0 is not None,\n STORE_FINAL_STATE=ht is not None,\n num_warps=num_warps,\n num_stages=num_stages\n )\n return h\n\n def bwd_inner(q, g, gs, h0, do, B, H, T, K, V, BT, BK, BV, NT, scale):\n NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)\n dh = q.new_empty(B, H, NT * K, V)\n dh0 = torch.empty_like(h0) if h0 is not None else None\n grid = (NK, NV, B * H)\n chunk_rwkv6_bwd_kernel_dh[grid](\n q, g, gs, do, dh, dh0,\n q.stride(1), q.stride(2), q.stride(3),\n do.stride(1), do.stride(2), do.stride(3),\n dh.stride(1), dh.stride(2), dh.stride(3),\n scale,\n T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT,\n USE_INITIAL_STATE=h0 is not None,\n num_warps=num_warps,\n num_stages=num_stages\n )\n return dh, dh0\n\n # recompute cumulative log decays.\n g_org, g, gs = g, torch.empty_like(g, dtype=torch.float), torch.empty_like(g, dtype=torch.float)\n def grid(meta): return ((triton.cdiv(meta['S'], meta['BS']), NT, B * H))\n # keep cummulative normalizer in fp32\n # this kernel is equivalent to\n # g = g.view(B, H, NT, BT, -1).cumsum(-2).view(B, H, T, -1)\n chunk_rwkv6_fwd_kernel_cum[grid](\n g_org, g, gs,\n g.stride(1), g.stride(2), g.stride(3),\n T=T, S=K, BT=BT\n )\n\n # rerun the forward pass to get h if checkpoint_level >= 1\n if ctx.checkpoint_level == 1:\n h = fwd_inner(\n q=q, k=k, v=v, g=g,\n B=B, H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT,\n h0=initial_state if initial_state is not None else None,\n ht=None\n )\n\n scale = ctx.scale\n dh, dh0 = bwd_inner(\n q, g, gs, initial_state, do,\n B=B, H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT,\n scale=scale\n )\n dq = torch.empty_like(q, dtype=torch.float)\n dk = torch.empty_like(k, dtype=torch.float)\n dv = v.new_empty(NK, *v.shape)\n dA = q.new_zeros(B, H, T, BT)\n grid = (NK, NT, B * H)\n chunk_rwkv6_bwd_kernel_inter[grid](\n k, v, h, g, gs, A, do, dh, dq, dk, dv, dA,\n k.stride(1), k.stride(2), k.stride(3),\n v.stride(1), v.stride(2), v.stride(3),\n h.stride(1), h.stride(2), h.stride(3),\n scale,\n T=T, K=K, V=V, BT=BT, BK=BK, BV=BV,\n num_warps=num_warps,\n num_stages=num_stages\n )\n dv = dv.sum(0, dtype=dv.dtype)\n grid = (NK, NT * NC, B * H)\n chunk_rwkv6_bwd_kernel_intra[grid](\n q, k, g, gs, dA, dq, dk,\n k.stride(1), k.stride(2), k.stride(3),\n T=T, K=K, BT=BT, BC=BC, BK=BK, NC=NC,\n num_warps=num_warps,\n num_stages=num_stages\n )\n\n # TODO: fuse?\n dg = (dq * q)[:, :, 1:] - (dk * k)[:, :, 0:-1]\n dg = torch.nn.functional.pad(dg, (0, 0, 0, 1, 0, 0, 0, 0), value=0)\n dg = chunk_reversed_cumsum_fwd(dg).to(g)\n # equivalent to the following pytorch code.\n # du = ((do * v).sum(-1)[..., None] * k * q * scale).sum(-2).to(u)\n # dq += ((do * v).sum(-1)[..., None] * k * scale * u[:, :, None, :])\n # dk += ((do * v).sum(-1)[..., None] * q * scale * u[:, :, None, :])\n BT = 64\n grid = (triton.cdiv(T, BT), B * H)\n du = torch.empty_like(g, dtype=torch.float)\n post_process_grad[grid](\n q, k, v, u, do, dk, dq, du, scale,\n q.stride(1), q.stride(2), q.stride(3),\n v.stride(1), v.stride(2), v.stride(3), H=H,\n T=T, BT=BT, K=K, V=V, BK=triton.next_power_of_2(K), BV=triton.next_power_of_2(V),\n num_warps=4\n )\n du = du.sum([0, 2])\n return dq.to(q), dk.to(k), dv.to(v), dg.to(g), du.to(u), None, dh0, None, None\n\n\ndef chunk_rwkv6(\n r: torch.Tensor,\n k: torch.Tensor,\n v: torch.Tensor,\n g: torch.Tensor,\n u: torch.Tensor,\n scale: Optional[int] = None,\n initial_state: torch.Tensor = None,\n output_final_state: bool = False,\n checkpoint_level: Optional[int] = 0\n) -> Tuple[torch.Tensor, torch.Tensor]:\n r\"\"\"\n Args:\n r (torch.Tensor):\n reception of shape `(B, H, T, K)`. Alias: q, query in linear attention.\n k (torch.Tensor):\n keys of shape `(B, H, T, K)`\n v (torch.Tensor):\n values of shape `(B, H, T, V)`\n w (torch.Tensor):\n data-dependent decays of shape `(B, H, T, K)` in log space! Alias: g.\n u (torch.Tensor):\n bonus of shape `(H, K)`\n scale (Optional[int]):\n Scale factor for the RWKV6 attention scores.\n If not provided, it will default to `1 / sqrt(K)`. Default: `None`.\n initial_state (Optional[torch.Tensor]):\n Initial state of shape `(B, H, K, V)`. Default: `None`.\n output_final_state (Optional[bool]):\n Whether to output the final state of shape `(B, H, K, V)`. Default: `False`.\n checkpoint_level (Optional[int]):\n Checkpointing level; higher values will save more memories and do more recomputations during backward.\n Default: `0`:\n - Level `0`: store forward hidden states for backprop.\n - Level `1`: recompute the forward hidden states during backward.\n \"\"\"\n assert checkpoint_level in [0, 1]\n if scale is None:\n scale = r.shape[-1] ** -0.5\n o, final_state = ChunkRWKV6Function.apply(r, k, v, g, u, scale, initial_state, output_final_state, checkpoint_level)\n return o, final_state\n", - "description_1": "Use triton language to implement a forward and backward pass of a RWKV (Receptance Weight Key Value) neural network function. The function takes tensors r, k, v, g, u, along with optional parameters scale, initial_state, output_final_state, and checkpoint_level to perform matrix operations and store necessary gradients for backpropagation. It uses triton kernels to execute operations efficiently on GPUs.", - "description_2": "Use triton language to develop optimized forward and backward passes of a specific neural network layer, leveraging custom kernels for GPU efficiency.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\nfrom torch.cuda.amp import custom_bwd, custom_fwd\nfrom fla.ops.utils import chunk_reversed_cumsum_fwd\n\n@triton.jit\ndef fused_recurrent_rwkv6_fwd_kernel(\n q, # query [B, H, T, K]\n k, # key [B, H, T, K]\n v, # value [B, H, T, V]\n w, # log gate [B, H, T, K]\n u, # bonus [B, H, K]\n o, # output [B, H, T, V]\n h0, # initial hidden state initialization [B, H, K, V]\n ht, # final hidden state [B, H, K, V]\n s_k_h, s_v_h, scale, \n B: tl.constexpr, H: tl.constexpr, T: tl.constexpr, K: tl.constexpr, V: tl.constexpr,\n BK: tl.constexpr, BV: tl.constexpr, \n USE_INITIAL_STATE: tl.constexpr, STORE_FINAL_STATE: tl.constexpr, REVERSE: tl.constexpr,\n):\n i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n i_h = i_bh % H\n\n p_q = q + i_bh * s_k_h + i_k * BK + tl.arange(0, BK) + ((T-1) * K if REVERSE else 0)\n p_k = k + i_bh * s_k_h + i_k * BK + tl.arange(0, BK) + ((T-1) * K if REVERSE else 0)\n p_v = v + i_bh * s_v_h + i_v * BV + tl.arange(0, BV) + ((T-1) * V if REVERSE else 0)\n p_o = o + (i_bh + i_k * B * H) * s_v_h + i_v * BV + tl.arange(0, BV) + ((T-1) * V if REVERSE else 0)\n p_w = w + i_bh * s_k_h + i_k * BK + tl.arange(0, BK) + ((T-1) * K if REVERSE else 0)\n p_u = u + i_h * K + tl.arange(0, BK) + i_k * BK\n\n mask_bk = (i_k * BK + tl.arange(0, BK)) < K\n mask_bv = (i_v * BV + tl.arange(0, BV)) < V\n mask_kv = mask_bv[:, None] & mask_bk[None, :]\n\n b_h = tl.zeros([BV, BK], dtype=tl.float32)\n if USE_INITIAL_STATE:\n p_h0 = h0 + i_bh * K * V + (i_k * BK + tl.arange(0, BK)[None, :]) * V + (i_v * BV + tl.arange(0, BV)[:, None])\n b_h += tl.load(p_h0, mask=mask_kv, other=0).to(tl.float32)\n\n b_u = tl.load(p_u, mask=mask_bk, other=0).to(tl.float32)\n for _ in range(0, T):\n b_k = tl.load(p_k, mask=mask_bk, other=0).to(tl.float32)\n b_v = tl.load(p_v, mask=mask_bv, other=0).to(tl.float32)\n b_q = tl.load(p_q, mask=mask_bk, other=0).to(tl.float32) * scale\n b_w = tl.load(p_w, mask=mask_bk, other=0).to(tl.float32)\n b_w = tl.exp(b_w)\n b_kv = b_k[None, :] * b_v[:, None]\n b_o = (b_h + b_kv * b_u[None, :]) * b_q[None, :]\n b_o = tl.sum(b_o, axis=1)\n b_h = b_h * b_w[None, :]\n b_h += b_kv\n tl.store(p_o, b_o.to(p_o.dtype.element_ty), mask=mask_bv)\n p_q += -K if REVERSE else K\n p_k += -K if REVERSE else K\n p_o += -V if REVERSE else V\n p_v += -V if REVERSE else V\n p_w += -K if REVERSE else K\n\n if STORE_FINAL_STATE:\n p_ht = ht + i_bh * K * V + (i_k * BK + tl.arange(0, BK)[None, :]) * V + (i_v * BV + tl.arange(0, BV)[:, None])\n tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_kv)\n\n@triton.jit\ndef fused_recurrent_rwkv6_bwd_kernel_dq(\n k, v, w, u, do, dq, dq_aux, h0,\n s_k_h, s_v_h, scale, \n B: tl.constexpr, H: tl.constexpr, T: tl.constexpr, BK: tl.constexpr, \n BV: tl.constexpr, K: tl.constexpr, V: tl.constexpr, \n USE_INITIAL_STATE: tl.constexpr, REVERSE: tl.constexpr,\n):\n i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n i_h = i_bh % H\n p_k = k + i_bh * s_k_h + i_k * BK + tl.arange(0, BK) + ((T-1) * K if REVERSE else 0)\n p_v = v + i_bh * s_v_h + i_v * BV + tl.arange(0, BV) + ((T-1) * V if REVERSE else 0)\n p_do = do + i_bh * s_v_h + i_v * BV + tl.arange(0, BV) + ((T-1) * V if REVERSE else 0)\n p_dq = dq + (i_bh + i_v * B * H) * s_k_h + i_k * BK + tl.arange(0, BK) + ((T-1) * K if REVERSE else 0)\n p_dq_aux = dq_aux + (i_bh + i_v * B * H) * s_k_h + i_k * BK + tl.arange(0, BK) + ((T-1) * K if REVERSE else 0)\n p_w = w + i_bh * s_k_h + i_k * BK + tl.arange(0, BK) + ((T-1) * K if REVERSE else 0)\n p_u = u + i_h * K + tl.arange(0, BK) + i_k * BK\n\n mask_bk = i_k * BK + tl.arange(0, BK) < K\n mask_bv = i_v * BV + tl.arange(0, BV) < V\n mask_kv = mask_bv[:, None] & mask_bk[None, :]\n b_u = tl.load(p_u, mask=mask_bk, other=0).to(tl.float32)\n b_h = tl.zeros([BV, BK], dtype=tl.float32)\n\n if USE_INITIAL_STATE:\n p_h0 = h0 + i_bh * K * V + (i_k * BK + tl.arange(0, BK)[None, :]) * V + (i_v * BV + tl.arange(0, BV)[:, None])\n b_h += tl.load(p_h0, mask=mask_kv, other=0).to(tl.float32)\n\n for _ in range(0, T):\n b_k = tl.load(p_k, mask=mask_bk, other=0).to(tl.float32)\n b_v = tl.load(p_v, mask=mask_bv, other=0).to(tl.float32)\n b_kv = b_k[None, :] * b_v[:, None]\n b_do = tl.load(p_do, mask=mask_bv, other=0).to(tl.float32)\n b_w = tl.load(p_w, mask=mask_bk, other=0).to(tl.float32)\n b_w = tl.exp(b_w)\n h_q = b_h * b_do[:, None]\n b_dq = tl.sum(h_q + b_kv * b_u[None, :] * b_do[:, None], axis=0)\n b_dq *= scale\n b_dq_aux = tl.sum(h_q, axis=0)\n b_h = b_h * b_w[None, :]\n b_h += b_kv\n tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), mask=mask_bk)\n tl.store(p_dq_aux, b_dq_aux.to(p_dq_aux.dtype.element_ty), mask=mask_bk)\n p_k += -K if REVERSE else K\n p_do += -V if REVERSE else V\n p_v += -V if REVERSE else V\n p_w += -K if REVERSE else K\n p_dq += -K if REVERSE else K\n p_dq_aux += -K if REVERSE else K\n\n@triton.jit\ndef fused_recurrent_rwkv6_bwd_kernel_dkv(\n q, k, v, w, u, do, dk, dk_aux, dv, dh0,\n s_k_h, s_v_h, scale, B, H, T, BK: tl.constexpr, BV: tl.constexpr, \n K: tl.constexpr, V: tl.constexpr, USE_INITIAL_STATE: tl.constexpr, REVERSE: tl.constexpr,\n):\n i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n i_h = i_bh % H\n p_q = q + i_bh * s_k_h + i_k * BK + tl.arange(0, BK) + ((T - 1) * K if not REVERSE else 0)\n p_k = k + i_bh * s_k_h + i_k * BK + tl.arange(0, BK) + ((T - 1) * K if not REVERSE else 0)\n p_do = do + i_bh * s_v_h + i_v * BV + tl.arange(0, BV) + ((T - 1) * V if not REVERSE else 0)\n p_v = v + i_bh * s_v_h + i_v * BV + tl.arange(0, BV) + ((T - 1) * V if not REVERSE else 0)\n p_dk = dk + (i_bh + i_v * B * H) * s_k_h + i_k * BK + tl.arange(0, BK) + ((T - 1) * K if not REVERSE else 0)\n p_dk_aux = dk_aux + (i_bh + i_v * B * H) * s_k_h + i_k * BK + tl.arange(0, BK) + ((T - 1) * K if not REVERSE else 0)\n p_dv = dv + (i_bh + i_k * B * H) * s_v_h + i_v * BV + tl.arange(0, BV) + ((T - 1) * V if not REVERSE else 0)\n p_w = w + i_bh * s_k_h + i_k * BK + tl.arange(0, BK) + ((T - 1) * K if not REVERSE else 0)\n b_dh = tl.zeros([BK, BV], dtype=tl.float32)\n mask_bk = i_k * BK + tl.arange(0, BK) < K\n mask_bv = i_v * BV + tl.arange(0, BV) < V\n mask_kv = mask_bk[:, None] & mask_bv[None, :]\n\n p_u = u + i_h * K + tl.arange(0, BK) + i_k * BK\n b_u = tl.load(p_u, mask=mask_bk, other=0).to(tl.float32)\n\n for _ in range(T-1, -1, -1):\n b_q = tl.load(p_q, mask=mask_bk, other=0).to(tl.float32) * scale\n b_k = tl.load(p_k, mask=mask_bk, other=0).to(tl.float32)\n b_v = tl.load(p_v, mask=mask_bv, other=0).to(tl.float32)\n b_w = tl.load(p_w, mask=mask_bk, other=0).to(tl.float32)\n b_do = tl.load(p_do, mask=mask_bv, other=0).to(tl.float32)\n b_dkv = b_q[:, None] * b_do[None, :]\n b_dk = tl.sum(b_dh * b_v[None, :], axis=1)\n tl.store(p_dk_aux, b_dk.to(p_dk_aux.dtype.element_ty), mask=mask_bk)\n b_dk += tl.sum(b_dkv * b_u[:, None] * b_v[None, :], axis=1)\n b_dv = tl.sum((b_dh + (b_dkv * b_u[:, None])) * b_k[:, None], axis=0)\n\n tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), mask=mask_bk)\n tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), mask=mask_bv)\n b_dh *= tl.exp(b_w)[:, None]\n b_dh += b_dkv\n\n p_q += K if REVERSE else -K\n p_k += K if REVERSE else -K\n p_v += V if REVERSE else -V\n p_w += K if REVERSE else -K\n p_do += V if REVERSE else -V\n p_dk += K if REVERSE else -K\n p_dk_aux += K if REVERSE else -K\n p_dv += V if REVERSE else -V\n\n if USE_INITIAL_STATE:\n p_dh0 = dh0 + i_bh * K * V + (i_k * BK + tl.arange(0, BK)[:, None]) * V + (i_v * BV + tl.arange(0, BV)[None, :])\n tl.store(p_dh0, b_dh.to(p_dh0.dtype.element_ty), mask=mask_kv)\n\nclass FusedRecurrentRWKV6Function(torch.autograd.Function):\n\n @staticmethod\n @custom_fwd\n def forward(ctx, r, k, v, w, u, scale=None, initial_state=None, output_final_state=False, reverse=False):\n q = r\n B, H, T, K, V = *q.shape, v.shape[-1]\n\n BK, BV = min(triton.next_power_of_2(K), 32), min(triton.next_power_of_2(V), 32)\n NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)\n num_stages = 1\n num_warps = 1\n\n if output_final_state:\n final_state = q.new_empty(B, H, K, V)\n else:\n final_state = None\n\n o = q.new_empty(NK, B, H, T, V, dtype=torch.float32)\n grid = (NV, NK, B * H)\n fused_recurrent_rwkv6_fwd_kernel[grid](\n q, k, v, w, u, o, initial_state, final_state,\n k.stride(1),\n v.stride(1),\n scale,\n B=B, H=H, T=T, K=K, V=V, BK=BK, BV=BV,\n USE_INITIAL_STATE=initial_state is not None,\n STORE_FINAL_STATE=final_state is not None,\n REVERSE=reverse,\n num_warps=num_warps,\n num_stages=num_stages\n )\n\n o = o.sum(0)\n ctx.save_for_backward(q, k, v, w, u, initial_state, o)\n ctx.scale = scale\n ctx.reverse = reverse\n if final_state is not None:\n final_state = final_state.detach()\n return o.to(q.dtype), final_state\n\n @staticmethod\n @custom_bwd\n def backward(ctx, do, d_final_state=None):\n q, k, v, w, u, initial_state, o = ctx.saved_tensors\n B, H, T, K, V = *q.shape, v.shape[-1]\n scale = ctx.scale\n\n BK, BV = min(triton.next_power_of_2(K), 16), min(triton.next_power_of_2(V), 64)\n NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)\n num_stages = 1\n num_warps = 1\n dq = q.new_empty(NV, B, H, T, K, dtype=torch.float32)\n dq_aux = torch.empty_like(dq)\n grid = (NV, NK, B * H)\n\n fused_recurrent_rwkv6_bwd_kernel_dq[grid](\n k, v, w, u, do, dq, dq_aux, initial_state,\n q.stride(1),\n v.stride(1),\n scale,\n B=B, H=H, T=T, K=K, V=V, BK=BK, BV=BV,\n num_warps=num_warps,\n num_stages=num_stages,\n USE_INITIAL_STATE=initial_state is not None,\n REVERSE=ctx.reverse,\n )\n dq = dq.sum(0).to(q)\n dq_aux = dq_aux.sum(0)\n\n BK, BV = min(triton.next_power_of_2(K), 32), min(triton.next_power_of_2(V), 32)\n NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)\n\n dk = q.new_empty(NV, B, H, T, K, dtype=torch.float32)\n dk_aux = q.new_empty(NV, B, H, T, K, dtype=torch.float32)\n dv = q.new_empty(NK, B, H, T, V, dtype=torch.float32)\n dh0 = initial_state.new_empty(B, H, K, V) if initial_state is not None else None\n grid = (NV, NK, B * H)\n fused_recurrent_rwkv6_bwd_kernel_dkv[grid](\n q, k, v, w, u, do, dk, dk_aux, dv, dh0,\n q.stride(1),\n v.stride(1),\n scale,\n B=B, H=H, T=T, K=K, V=V, BK=BK, BV=BV,\n num_warps=num_warps,\n num_stages=num_stages,\n USE_INITIAL_STATE=initial_state is not None,\n REVERSE=ctx.reverse,\n )\n dk = dk.sum(0).to(k)\n dv = dv.sum(0).to(v)\n dk_aux = dk_aux.sum(0)\n\n dw = (dq_aux * q * scale)[:, :, 1:] - (dk_aux * k)[:, :, 0:-1]\n dw = torch.nn.functional.pad(dw, (0, 0, 0, 1, 0, 0, 0, 0), value=0)\n dw = chunk_reversed_cumsum_fwd(dw).to(w)\n\n du = ((do * v).sum(-1)[..., None] * k * q * scale).sum([0, -2]).to(u)\n return dq, dk, dv, dw, du, None, dh0, None, None\n\n\ndef fused_recurrent_rwkv6(\n r: torch.Tensor,\n k: torch.Tensor,\n v: torch.Tensor,\n w: torch.Tensor,\n u: torch.Tensor,\n scale: int = -1,\n initial_state: torch.Tensor = None,\n output_final_state: bool = False,\n causal: bool = True\n) -> Tuple[torch.Tensor, torch.Tensor]:\n if scale == -1:\n scale = r.shape[-1] ** -0.5\n o, final_state = FusedRecurrentRWKV6Function.apply(r, k, v, w, u, scale, initial_state, output_final_state)\n return o, final_state\n", - "description_1": "Use triton language to implement a series of kernels for a fused recurrent RWKV6 forward and backward pass. The forward kernel accepts 19 parameters: q, k, v, w, u, o, h0, ht, s_k_h, s_v_h, scale, B, H, T, K, V, BK, BV, USE_INITIAL_STATE, STORE_FINAL_STATE, REVERSE. It computes an attention-like operation with state management. The backward kernel for dq accepts 21 parameters: k, v, w, u, do, dq, dq_aux, h0, s_k_h, s_v_h, scale, B, H, T, BK, BV, K, V, USE_INITIAL_STATE, REVERSE. The backward kernel for dkv accepts 23 parameters: q, k, v, w, u, do, dk, dk_aux, dv, dh0, s_k_h, s_v_h, scale, B, H, T, BK, BV, K, V, USE_INITIAL_STATE, REVERSE. It computes gradients with respect to the input tensors.", - "description_2": "Use triton language to implement a forward and backward kernel for fused recurrent RWKV6 that performs tensor operations for an attention mechanism with optional initial and final state handling.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\nfrom torch.cuda.amp import custom_bwd, custom_fwd\n\n@triton.jit\ndef chunk_simple_gla_fwd_kernel_h(\n k,\n v,\n h,\n g,\n initial_state, \n final_state, \n s_qk_h,\n s_qk_t,\n s_qk_d,\n s_vo_h,\n s_vo_t,\n s_vo_d,\n s_h_h,\n s_h_t,\n H: tl.constexpr,\n T: tl.constexpr,\n K: tl.constexpr,\n V: tl.constexpr,\n BT: tl.constexpr,\n BK: tl.constexpr,\n BV: tl.constexpr,\n NT: tl.constexpr,\n USE_INITIAL_STATE: tl.constexpr,\n STORE_FINAL_STATE: tl.constexpr\n):\n i_k, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n b_h = tl.zeros([BK, BV], dtype=tl.float32)\n\n if USE_INITIAL_STATE:\n p_h0 = tl.make_block_ptr(initial_state + i_bh * K * V,\n (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))\n b_h = tl.load(p_h0, boundary_check=(0, 1)).to(tl.float32)\n\n for i_t in range(NT):\n p_k = tl.make_block_ptr(\n k + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1))\n p_v = tl.make_block_ptr(\n v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))\n p_h = tl.make_block_ptr(h + i_bh * s_h_h + i_t * K * V,\n (K, V), (s_h_t, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))\n\n tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1))\n b_k = tl.load(p_k, boundary_check=(0, 1))\n b_v = tl.load(p_v, boundary_check=(0, 1))\n b_g_last = tl.load(g + i_bh * T + i_t * BT + BT - 1)\n b_h *= tl.math.exp2(b_g_last)\n b_g = tl.load(g + i_bh * T + i_t * BT + tl.arange(0, BT))\n b_h += tl.dot(b_k, (b_v * tl.math.exp2(b_g_last - b_g)[:, None]).to(b_k.dtype), allow_tf32=False)\n\n if STORE_FINAL_STATE:\n p_ht = tl.make_block_ptr(\n final_state + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))\n tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), boundary_check=(0, 1))\n\n\n@triton.jit\ndef chunk_simple_gla_fwd_kernel_o(\n q,\n k,\n v,\n h,\n g,\n o,\n s_qk_h,\n s_qk_t,\n s_qk_d,\n s_vo_h,\n s_vo_t,\n s_vo_d,\n s_h_h,\n s_h_t,\n scale,\n H: tl.constexpr,\n T: tl.constexpr,\n K: tl.constexpr,\n V: tl.constexpr,\n BT: tl.constexpr,\n BK: tl.constexpr,\n BV: tl.constexpr\n):\n i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n o_i = tl.arange(0, BT)\n m_s = o_i[:, None] >= o_i[None, :]\n\n b_o = tl.zeros([BT, BV], dtype=tl.float32)\n b_s = tl.zeros([BT, BT], dtype=tl.float32)\n for i_k in range(tl.cdiv(K, BK)):\n p_q = tl.make_block_ptr(\n q + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))\n p_k = tl.make_block_ptr(\n k + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1))\n p_h = tl.make_block_ptr(h + i_bh * s_h_h + i_t * K * V,\n (K, V), (s_h_t, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))\n\n b_q = tl.load(p_q, boundary_check=(0, 1))\n b_k = tl.load(p_k, boundary_check=(0, 1))\n b_h = tl.load(p_h, boundary_check=(0, 1))\n b_o += tl.dot(b_q, b_h, allow_tf32=False)\n b_s += tl.dot(b_q, b_k, allow_tf32=False)\n\n p_g = g + i_bh * T + i_t * BT + tl.arange(0, BT)\n b_g = tl.load(p_g)\n b_o = b_o * tl.math.exp2(b_g)[:, None]\n b_s = b_s * tl.math.exp2(b_g[:, None] - b_g[None, :])\n b_s = tl.where(m_s, b_s, 0)\n\n p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V),\n (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))\n b_v = tl.load(p_v, boundary_check=(0, 1))\n b_o = (b_o + tl.dot(b_s.to(b_v.dtype), b_v, allow_tf32=False)) * scale\n p_o = tl.make_block_ptr(o + i_bh * s_vo_h, (T, V),\n (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))\n tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))\n\n\n@triton.jit\ndef chunk_simple_gla_bwd_kernel_dh(\n q,\n g,\n do,\n dh,\n s_qk_h,\n s_qk_t,\n s_qk_d,\n s_vo_h,\n s_vo_t,\n s_vo_d,\n s_h_h,\n s_h_t,\n scale,\n H: tl.constexpr,\n T: tl.constexpr,\n K: tl.constexpr,\n V: tl.constexpr,\n BT: tl.constexpr,\n BK: tl.constexpr,\n BV: tl.constexpr,\n NT: tl.constexpr\n):\n i_k, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n b_dh = tl.zeros([BK, BV], dtype=tl.float32)\n for i_t in range(NT - 1, -1, -1):\n p_q = tl.make_block_ptr(\n q + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1))\n p_do = tl.make_block_ptr(\n do + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))\n p_dh = tl.make_block_ptr(dh + i_bh * s_h_h + i_t * K * V,\n (K, V), (s_h_t, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))\n\n tl.store(p_dh, b_dh.to(p_dh.dtype.element_ty), boundary_check=(0, 1))\n b_q = tl.load(p_q, boundary_check=(0, 1))\n b_q = (b_q * scale * tl.math.exp2(tl.load(g + i_bh * T +\n i_t * BT + tl.arange(0, BT)))[None, :]).to(b_q.dtype)\n b_do = tl.load(p_do, boundary_check=(0, 1))\n b_dh *= tl.math.exp2(tl.load(g + i_bh * T + i_t * BT + BT - 1))\n b_dh += tl.dot(b_q, b_do.to(b_q.dtype), allow_tf32=False)\n\n\n@triton.jit\ndef chunk_simple_gla_bwd_kernel_dqkv(\n q,\n k,\n v,\n h,\n g,\n do,\n dh,\n dq,\n dk,\n dv,\n s_qk_h,\n s_qk_t,\n s_qk_d,\n s_vo_h,\n s_vo_t,\n s_vo_d,\n s_h_h,\n s_h_t,\n scale,\n B: tl.constexpr,\n H: tl.constexpr,\n T: tl.constexpr,\n K: tl.constexpr,\n V: tl.constexpr,\n BT: tl.constexpr,\n BK: tl.constexpr,\n BV: tl.constexpr,\n NT: tl.constexpr\n):\n i_k, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n n_bh = tl.num_programs(2)\n o_i = tl.arange(0, BT)\n\n p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (K, T),\n (s_qk_d, s_qk_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1))\n p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K),\n (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))\n\n b_q = tl.load(p_q, boundary_check=(0, 1))\n b_k = tl.load(p_k, boundary_check=(0, 1))\n b_s = tl.dot(b_k, b_q, allow_tf32=False)\n p_g = g + i_bh * T + i_t * BT + tl.arange(0, BT)\n b_g = tl.load(p_g)\n b_g_last = tl.load(g + i_bh * T + i_t * BT + BT - 1)\n mask = tl.math.exp2(b_g[None, :] - b_g[:, None])\n mask = tl.where(o_i[:, None] <= o_i[None, :], mask * scale, 0)\n b_s = b_s * mask\n\n b_dq = tl.zeros([BT, BK], dtype=tl.float32)\n b_dk = tl.zeros([BT, BK], dtype=tl.float32)\n b_ds = tl.zeros([BT, BT], dtype=tl.float32)\n for i_v in range(tl.cdiv(V, BV)):\n p_v = tl.make_block_ptr(\n v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))\n p_h = tl.make_block_ptr(h + i_bh * s_h_h, (V, NT * K), (1, s_h_t),\n (i_v * BV, i_t * K + i_k * BK), (BV, BK), (0, 1))\n p_do = tl.make_block_ptr(\n do + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))\n p_dh = tl.make_block_ptr(dh + i_bh * s_h_h, (NT * K, V),\n (s_h_t, 1), (i_t * K + i_k * BK, i_v * BV), (BK, BV), (1, 0))\n p_dv = tl.make_block_ptr(dv + (i_k*n_bh+i_bh)*s_vo_h, (T, V),\n (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))\n b_v = tl.load(p_v, boundary_check=(0, 1))\n b_do = tl.load(p_do, boundary_check=(0, 1))\n b_h = tl.load(p_h, boundary_check=(0, 1))\n b_dh = tl.load(p_dh, boundary_check=(0, 1))\n b_ds += tl.dot(b_do, tl.trans(b_v), allow_tf32=False)\n b_dq += tl.dot(b_do, b_h, allow_tf32=False) * scale\n b_dk += tl.dot(b_v, tl.trans(b_dh), allow_tf32=False)\n b_dv = tl.dot(b_k, b_dh, allow_tf32=False) * tl.math.exp2(-b_g + b_g_last)[:, None] + \\\n tl.dot(b_s.to(b_q.dtype), b_do, allow_tf32=False)\n tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))\n\n b_dq = b_dq * tl.math.exp2(b_g)[:, None]\n b_dk = b_dk * tl.math.exp2(-b_g + b_g_last)[:, None]\n b_ds = b_ds * tl.trans(mask)\n b_ds = b_ds.to(b_k.dtype)\n b_dq += tl.dot(b_ds, b_k, allow_tf32=False)\n b_dk += tl.trans(tl.dot(b_q, b_ds, allow_tf32=False))\n p_dq = tl.make_block_ptr(dq + i_bh * s_qk_h, (T, K),\n (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))\n p_dk = tl.make_block_ptr(dk + i_bh * s_qk_h, (T, K),\n (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))\n tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))\n tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))\n\n\nclass SimpleGLAFunction(torch.autograd.Function):\n\n @staticmethod\n @custom_fwd\n def forward(ctx, q, k, v, g, initial_state, output_final_state):\n B, H, T, K, V = *q.shape, v.shape[-1]\n BT = 64\n BK, BV = min(64, triton.next_power_of_2(K)), min(\n 64, triton.next_power_of_2(V))\n NT, NK, NV = triton.cdiv(T, BT), triton.cdiv(K, BK), triton.cdiv(V, BV)\n num_stages = 1\n num_warps = 4 if BK == 64 else 2\n scale = K ** -0.5\n\n assert T % BT == 0, 'sequence length must be divisible by BT'\n g = g.reshape(B, H, -1, BT)\n g = g.cumsum(-1) * 1.44269504\n g = g.reshape(B, H, -1)\n\n final_state = None\n if output_final_state:\n final_state = q.new_empty(B, H, K, V, dtype=torch.float32, requires_grad=False)\n\n h = q.new_empty(B, H, NT * K, V)\n grid = (NK, NV, B * H)\n chunk_simple_gla_fwd_kernel_h[grid](\n k, v, h, g, initial_state, final_state,\n q.stride(1), q.stride(2), q.stride(3),\n v.stride(1), v.stride(2), v.stride(3),\n h.stride(1), h.stride(2),\n H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT,\n USE_INITIAL_STATE=initial_state is not None,\n STORE_FINAL_STATE=output_final_state,\n num_warps=num_warps,\n num_stages=num_stages\n )\n grid = (NV, NT, B * H)\n o = torch.empty_like(v)\n chunk_simple_gla_fwd_kernel_o[grid](\n q, k, v, h, g, o,\n q.stride(1), q.stride(2), q.stride(3),\n v.stride(1), v.stride(2), v.stride(3),\n h.stride(1), h.stride(2),\n scale,\n H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV,\n num_warps=num_warps,\n num_stages=num_stages\n )\n\n ctx.save_for_backward(q, k, v, h, g)\n return o.to(q.dtype), final_state\n\n @staticmethod\n @custom_bwd\n def backward(ctx, do, d_ht=None):\n q, k, v, h, g = ctx.saved_tensors\n\n B, H, T, K, V = *q.shape, v.shape[-1]\n BT = 64\n BK, BV = min(32 if q.dtype == torch.float32 else 64, triton.next_power_of_2(K)), min(\n 32 if q.dtype == torch.float32 else 64, triton.next_power_of_2(V))\n NT, NK, NV = triton.cdiv(T, BT), triton.cdiv(K, BK), triton.cdiv(V, BV)\n num_stages = 1\n num_warps = 4 if BK == 64 else 2\n scale = K ** -0.5\n\n dh = q.new_empty(B, H, NT * K, V)\n grid = (NK, NV, B * H)\n chunk_simple_gla_bwd_kernel_dh[grid](\n q, g, do, dh,\n q.stride(1), q.stride(2), q.stride(3),\n v.stride(1), v.stride(2), v.stride(3),\n dh.stride(1), dh.stride(2),\n scale,\n H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT,\n num_warps=num_warps,\n num_stages=num_stages\n )\n grid = (NK, NT, B * H)\n dq = torch.empty_like(q)\n dk = torch.empty_like(k)\n dv = v.new_empty(NK, *v.shape)\n chunk_simple_gla_bwd_kernel_dqkv[grid](\n q, k, v, h, g, do, dh, dq, dk, dv,\n q.stride(1), q.stride(2), q.stride(3),\n v.stride(1), v.stride(2), v.stride(3),\n dh.stride(1), dh.stride(2),\n scale,\n B=B, H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT,\n num_warps=num_warps,\n num_stages=num_stages\n )\n dv = dv.sum(0)\n dg = (dq * q - dk * k).sum(-1)\n\n def rev_cumsum(x):\n cumsum_x = x.cumsum(-1)\n rev_cumsum_x = cumsum_x[..., -1, None] - cumsum_x\n return rev_cumsum_x + x\n dg = rev_cumsum(dg)\n return dq.to(q.dtype), dk.to(k.dtype), dv.to(v.dtype), dg.to(g.dtype), None, None\n\n\ndef chunk_simple_gla(\n q: torch.Tensor,\n k: torch.Tensor,\n v: torch.Tensor,\n g: torch.Tensor, \n initial_state: torch.Tensor = None,\n output_final_state: bool = False\n) -> Tuple[torch.Tensor, torch.Tensor]:\n if initial_state is not None:\n initial_state = initial_state.detach()\n g = g.float()\n o, final_state = SimpleGLAFunction.apply(q, k, v, g, initial_state, output_final_state)\n return o, final_state\n", - "description_1": "Use triton language to implement forward and backward kernels for a custom attention mechanism with chunking. The forward kernels take in tensors q, k, v, g, and optionally initial and final states, along with various strides and dimensions for tensor manipulation. The backward kernels compute gradients for q, k, v, g based on the provided forward output and a gradient tensor. They all involve block-level operations and tensor contractions with parameters H, T, K, V, BT, BK, BV, and NT.", - "description_2": "Use triton language to implement a series of custom kernels for forward and backward passes of an attention-like mechanism, utilizing tensor chunking and block-level operations, designed to efficiently compute outputs and gradients with given tensor dimensions and strides.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\nfrom typing import Optional\n\n@triton.autotune(\n configs=[\n triton.Config({'BT': 16}, num_warps=2),\n triton.Config({'BT': 16}, num_warps=4),\n triton.Config({'BT': 16}, num_warps=8),\n triton.Config({'BT': 32}, num_warps=2),\n triton.Config({'BT': 32}, num_warps=4),\n triton.Config({'BT': 32}, num_warps=8),\n triton.Config({'BT': 64}, num_warps=2),\n triton.Config({'BT': 64}, num_warps=4),\n triton.Config({'BT': 64}, num_warps=8),\n ],\n key=['S']\n)\n@triton.jit\ndef chunk_cumsum_fwd_kernel(\n s,\n z,\n s_s_h,\n s_s_t,\n s_s_d,\n T: tl.constexpr,\n S: tl.constexpr,\n BT: tl.constexpr,\n BS: tl.constexpr\n):\n i_s, i_bh = tl.program_id(0), tl.program_id(1)\n o_i = tl.arange(0, BT)\n m_s = tl.where(o_i[:, None] >= o_i[None, :], 1., 0.)\n\n b_z = tl.zeros([BS], dtype=tl.float32)\n for i_t in range(tl.cdiv(T, BT)):\n p_s = tl.make_block_ptr(s + i_bh * s_s_h, (T, S), (s_s_t, s_s_d), (i_t * BT, i_s * BS), (BT, BS), (1, 0))\n p_z = tl.make_block_ptr(z + i_bh * s_s_h, (T, S), (s_s_t, s_s_d), (i_t * BT, i_s * BS), (BT, BS), (1, 0))\n # [BT, BS]\n b_s = tl.load(p_s, boundary_check=(0, 1)).to(tl.float32)\n b_c = b_z[None, :] + tl.dot(m_s, b_s, allow_tf32=False)\n tl.store(p_z, b_c.to(p_z.dtype.element_ty), boundary_check=(0, 1))\n\n if i_t >= 0:\n b_z += tl.sum(b_s, 0)\n\n@triton.autotune(\n configs=[\n triton.Config({'BT': 16}, num_warps=2),\n triton.Config({'BT': 16}, num_warps=4),\n triton.Config({'BT': 16}, num_warps=8),\n triton.Config({'BT': 32}, num_warps=2),\n triton.Config({'BT': 32}, num_warps=4),\n triton.Config({'BT': 32}, num_warps=8),\n triton.Config({'BT': 64}, num_warps=2),\n triton.Config({'BT': 64}, num_warps=4),\n triton.Config({'BT': 64}, num_warps=8),\n ],\n key=['S']\n)\n@triton.jit\ndef chunk_cumsum_bwd_kernel(\n ds,\n dz,\n s_s_h,\n s_s_t,\n s_s_d,\n T: tl.constexpr,\n S: tl.constexpr,\n BT: tl.constexpr,\n BS: tl.constexpr\n):\n i_s, i_bh = tl.program_id(0), tl.program_id(1)\n o_i = tl.arange(0, BT)\n m_s = tl.where(o_i[:, None] <= o_i[None, :], 1., 0.)\n\n b_ds = tl.zeros([BS], dtype=tl.float32)\n for i_t in range(tl.cdiv(T, BT) - 1, -1, -1):\n p_ds = tl.make_block_ptr(ds + i_bh * s_s_h, (T, S), (s_s_t, s_s_d), (i_t * BT, i_s * BS), (BT, BS), (1, 0))\n p_dz = tl.make_block_ptr(dz + i_bh * s_s_h, (T, S), (s_s_t, s_s_d), (i_t * BT, i_s * BS), (BT, BS), (1, 0))\n # [BT, BS]\n b_dz = tl.load(p_dz, boundary_check=(0, 1)).to(tl.float32)\n b_c = b_ds[None, :] + tl.dot(m_s, b_dz, allow_tf32=False)\n tl.store(p_ds, b_c.to(p_ds.dtype.element_ty), boundary_check=(0, 1))\n\n if i_t >= 0:\n b_ds += tl.sum(b_dz, 0)\n\ndef chunk_cumsum_fwd(\n s: torch.Tensor,\n dtype: Optional[torch.dtype] = None,\n) -> torch.Tensor:\n B, H, T, S = s.shape\n BS = 32\n\n dtype = dtype or s.dtype\n grid = (triton.cdiv(S, BS), B * H)\n z = torch.empty_like(s, dtype=dtype)\n chunk_cumsum_fwd_kernel[grid](\n s, z,\n s.stride(1), s.stride(2), s.stride(3),\n T=T, S=S, BS=BS\n )\n return z\n\ndef chunk_cumsum_bwd(\n dz: torch.Tensor,\n dtype: Optional[torch.dtype] = None,\n) -> torch.Tensor:\n B, H, T, S = dz.shape\n BS = 32\n\n dtype = dtype or dz.dtype\n grid = (triton.cdiv(S, BS), B * H)\n ds = torch.empty_like(dz, dtype=dtype)\n chunk_cumsum_bwd_kernel[grid](\n ds, dz,\n ds.stride(1), ds.stride(2), ds.stride(3),\n T=T, S=S, BS=BS\n )\n return ds\n", - "description_1": "Use triton language to implement a forward and backward cumulative sum operation on a 4D tensor. The forward kernel 'chunk_cumsum_fwd_kernel' takes 8 parameters: input tensor 's', output tensor 'z', strides 's_s_h', 's_s_t', 's_s_d', and constants 'T', 'S', 'BT', 'BS'. It computes the cumulative sum along the last dimension in chunks. The backward kernel 'chunk_cumsum_bwd_kernel' takes the same parameters but computes the gradient of the cumulative sum. The functions 'chunk_cumsum_fwd' and 'chunk_cumsum_bwd' are Python wrappers that prepare the grid and launch the kernels.", - "description_2": "Use triton language to create a forward and backward cumulative sum operation for 4D tensors, with kernels handling chunked operations and Python functions managing kernel execution.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n\n@triton.jit\ndef _fwd_kernel(\n Q,\n K,\n V,\n Out,\n S,\n KV,\n b: tl.constexpr,\n h: tl.constexpr,\n n: tl.constexpr,\n d: tl.constexpr,\n e: tl.constexpr,\n BLOCK: tl.constexpr,\n NUM_BLOCK: tl.constexpr,\n BLOCK_MODEL: tl.constexpr,\n):\n off_bh = tl.program_id(0)\n off_h = off_bh % h\n off_e = tl.program_id(1)\n # get the (b, h) location\n qk_offset = off_bh * n * d\n v_offset = off_bh * n * e\n o_offset = off_bh * n * e\n kv_offset = off_bh * d * e\n\n e_offset = off_e * BLOCK_MODEL\n\n Q_block_ptr = (\n Q + qk_offset + tl.arange(0, BLOCK)[:, None] * d + tl.arange(0, d)[None, :]\n )\n K_trans_block_ptr = (\n K + qk_offset + tl.arange(0, BLOCK)[None, :] * d + tl.arange(0, d)[:, None]\n )\n V_block_ptr = (\n V\n + v_offset\n + e_offset\n + tl.arange(0, BLOCK)[:, None] * e\n + tl.arange(0, BLOCK_MODEL)[None, :]\n )\n O_block_ptr = (\n Out\n + o_offset\n + e_offset\n + tl.arange(0, BLOCK)[:, None] * e\n + tl.arange(0, BLOCK_MODEL)[None, :]\n )\n KV_block_ptr = (\n KV\n + kv_offset\n + e_offset\n + tl.arange(0, d)[:, None] * e\n + tl.arange(0, BLOCK_MODEL)[None, :]\n )\n\n S_block_ptr = S + off_h\n s = tl.load(S_block_ptr)\n\n array = tl.arange(0, BLOCK).to(tl.float32)\n q_decay = tl.exp(-s.to(tl.float32) * array[:, None])\n k_trans_decay = tl.exp(-s.to(tl.float32) * (BLOCK - array[None, :]))\n block_decay = tl.exp(-s.to(tl.float32) * BLOCK)\n # diag\n index = array[:, None] - array[None, :]\n s_index = s * index\n s_index = tl.where(index >= 0, -s_index, float(\"-inf\"))\n diag_decay = tl.exp(s_index)\n\n kv = tl.zeros([d, BLOCK_MODEL], dtype=tl.float32)\n for i in range(NUM_BLOCK):\n q = tl.load(Q_block_ptr).to(tl.float32)\n k_trans = tl.load(K_trans_block_ptr).to(tl.float32)\n v = tl.load(V_block_ptr).to(tl.float32)\n\n qkv_none_diag = tl.dot(q, kv) * q_decay\n qk = tl.dot(q, k_trans) * diag_decay\n qkv_diag = tl.dot(qk, v)\n\n qkv = qkv_none_diag + qkv_diag\n\n tl.store(O_block_ptr, qkv.to(O_block_ptr.dtype.element_ty))\n kv = block_decay * kv + tl.dot(k_trans * k_trans_decay, v)\n\n Q_block_ptr += BLOCK * d\n K_trans_block_ptr += BLOCK * d\n V_block_ptr += BLOCK * e\n O_block_ptr += BLOCK * e\n\n KV = tl.load(KV_block_ptr).to(tl.float32)\n KV = tl.exp(-s.to(tl.float32) * n) * KV + kv\n tl.store(KV_block_ptr, KV.to(KV_block_ptr.dtype.element_ty))\n\n\n@triton.jit\ndef _bwd_diag_kernel(\n Q,\n K,\n V,\n S,\n DO,\n DQ,\n DK,\n DV,\n b: tl.constexpr,\n h: tl.constexpr,\n n: tl.constexpr,\n d: tl.constexpr,\n e: tl.constexpr,\n BLOCK: tl.constexpr,\n NUM_BLOCK: tl.constexpr,\n CBLOCK: tl.constexpr,\n NUM_CBLOCK: tl.constexpr,\n):\n off_bh = tl.program_id(0)\n off_block = tl.program_id(1)\n off_h = off_bh % h\n\n qk_offset = off_bh * n * d\n v_offset = off_bh * n * e\n o_offset = off_bh * n * e\n\n block_offset = off_block * BLOCK\n qk_block_offset = block_offset * d\n v_block_offset = block_offset * e\n o_block_offset = block_offset * e\n\n Q_trans_block_ptr = (\n Q\n + qk_offset\n + qk_block_offset\n + tl.arange(0, BLOCK)[None, :] * d\n + tl.arange(0, d)[:, None]\n )\n K_block_ptr = (\n K\n + qk_offset\n + qk_block_offset\n + tl.arange(0, BLOCK)[:, None] * d\n + tl.arange(0, d)[None, :]\n )\n V_trans_block_ptr = (\n V\n + v_offset\n + v_block_offset\n + tl.arange(0, BLOCK)[None, :] * e\n + tl.arange(0, e)[:, None]\n )\n\n DQ_block_ptr = (\n DQ\n + qk_offset\n + qk_block_offset\n + tl.arange(0, BLOCK)[:, None] * d\n + tl.arange(0, d)[None, :]\n )\n DK_trans_block_ptr = (\n DK\n + qk_offset\n + qk_block_offset\n + tl.arange(0, BLOCK)[None, :] * d\n + tl.arange(0, d)[:, None]\n )\n DV_block_ptr = (\n DV\n + v_offset\n + v_block_offset\n + tl.arange(0, BLOCK)[:, None] * e\n + tl.arange(0, e)[None, :]\n )\n DO_block_ptr = (\n DO\n + o_offset\n + o_block_offset\n + tl.arange(0, BLOCK)[:, None] * e\n + tl.arange(0, e)[None, :]\n )\n\n S_block_ptr = S + off_h\n s = tl.load(S_block_ptr)\n array = tl.arange(0, BLOCK).to(tl.float32)\n\n index = array[:, None] - array[None, :]\n s_index = s * index\n s_index = tl.where(index >= 0, -s_index, float(\"-inf\"))\n diag_decay = tl.exp(s_index)\n diag_decay_trans = tl.trans(diag_decay)\n\n k = tl.load(K_block_ptr).to(tl.float32)\n v_trans = tl.load(V_trans_block_ptr).to(tl.float32)\n do = tl.load(DO_block_ptr).to(tl.float32)\n q_trans = tl.load(Q_trans_block_ptr).to(tl.float32)\n\n dqk = tl.dot(do, v_trans) * diag_decay\n dq_diag = tl.dot(dqk, k)\n\n dq = dq_diag\n\n dk_diag_trans = tl.dot(q_trans, dqk)\n\n qk_trans = tl.dot(k, q_trans) * diag_decay_trans\n dv_diag = tl.dot(qk_trans, do)\n\n dk_trans = dk_diag_trans\n dv = dv_diag\n\n tl.store(DQ_block_ptr, dq.to(DQ_block_ptr.dtype.element_ty))\n tl.store(DK_trans_block_ptr, dk_trans.to(DK_trans_block_ptr.dtype.element_ty))\n tl.store(DV_block_ptr, dv.to(DV_block_ptr.dtype.element_ty))\n\n\n@triton.jit\ndef _bwd_none_diag_kernel(\n Q,\n K,\n V,\n S,\n DO,\n DQ,\n DK,\n DV,\n DKV,\n b: tl.constexpr,\n h: tl.constexpr,\n n: tl.constexpr,\n d: tl.constexpr,\n e: tl.constexpr,\n BLOCK: tl.constexpr,\n NUM_BLOCK: tl.constexpr,\n CBLOCK: tl.constexpr,\n NUM_CBLOCK: tl.constexpr,\n):\n off_bh = tl.program_id(0)\n off_block = tl.program_id(1)\n off_h = off_bh % h\n\n qk_offset = off_bh * n * d\n v_offset = off_bh * n * e\n o_offset = off_bh * n * e\n kv_offset = off_bh * d * e\n\n block_offset = off_block * BLOCK\n qk_block_offset = block_offset * d\n v_block_offset = block_offset * e\n o_block_offset = block_offset * e\n\n qk_offset = off_bh * n * d\n v_offset = off_bh * n * e\n o_offset = off_bh * n * e\n\n block_offset = off_block * BLOCK\n qk_block_offset = block_offset * d\n v_block_offset = block_offset * e\n o_block_offset = block_offset * e\n\n S_block_ptr = S + off_h\n s = tl.load(S_block_ptr)\n block_decay = tl.exp(-s.to(tl.float32) * BLOCK)\n\n DQ_block_ptr = (\n DQ\n + qk_offset\n + qk_block_offset\n + tl.arange(0, CBLOCK)[:, None] * d\n + tl.arange(0, d)[None, :]\n )\n K_block_ptr = (\n K\n + qk_offset\n + qk_block_offset\n + tl.arange(0, CBLOCK)[:, None] * d\n + tl.arange(0, d)[None, :]\n )\n V_trans_block_ptr = (\n V\n + v_offset\n + v_block_offset\n + tl.arange(0, CBLOCK)[None, :] * e\n + tl.arange(0, e)[:, None]\n )\n DO_block_ptr = (\n DO\n + o_offset\n + o_block_offset\n + tl.arange(0, CBLOCK)[:, None] * e\n + tl.arange(0, e)[None, :]\n )\n\n DKV_block_ptr = (\n DKV + kv_offset + tl.arange(0, d)[:, None] * e + tl.arange(0, e)[None, :]\n )\n\n # compute block array\n c_array = tl.arange(0, CBLOCK)\n\n kv_trans = tl.zeros([e, d], dtype=tl.float32)\n for i in range(NUM_BLOCK):\n for j in range(NUM_CBLOCK):\n q_decay = tl.exp(-s.to(tl.float32) * (j * CBLOCK + c_array[:, None]))\n do = tl.load(DO_block_ptr).to(tl.float32)\n dq_none_diag = tl.dot(do, kv_trans) * q_decay\n dq = dq_none_diag + tl.load(DQ_block_ptr)\n tl.store(DQ_block_ptr, dq.to(DQ_block_ptr.dtype.element_ty))\n\n DQ_block_ptr += CBLOCK * d\n DO_block_ptr += CBLOCK * e\n\n kv_trans_current = tl.zeros([e, d], dtype=tl.float32)\n for j in range(NUM_CBLOCK):\n v_trans = tl.load(V_trans_block_ptr).to(tl.float32)\n k = tl.load(K_block_ptr).to(tl.float32)\n k_decay = tl.exp(\n -s.to(tl.float32) * (BLOCK - (j * CBLOCK + c_array[:, None]))\n )\n kv_trans_current += tl.dot(v_trans, k * k_decay)\n\n K_block_ptr += CBLOCK * d\n V_trans_block_ptr += CBLOCK * e\n\n kv_trans = block_decay * kv_trans + kv_trans_current\n\n Q_trans_block_ptr = (\n Q\n + qk_offset\n + qk_block_offset\n + n * d\n + tl.arange(0, CBLOCK)[None, :] * d\n + tl.arange(0, d)[:, None]\n )\n K_block_ptr = (\n K\n + qk_offset\n + qk_block_offset\n + n * d\n + tl.arange(0, CBLOCK)[:, None] * d\n + tl.arange(0, d)[None, :]\n )\n V_trans_block_ptr = (\n V\n + v_offset\n + v_block_offset\n + n * e\n + tl.arange(0, CBLOCK)[None, :] * e\n + tl.arange(0, e)[None, :]\n )\n\n DK_trans_block_ptr = (\n DK\n + qk_offset\n + qk_block_offset\n + n * d\n + tl.arange(0, CBLOCK)[None, :] * d\n + tl.arange(0, d)[:, None]\n )\n DV_block_ptr = (\n DV\n + v_offset\n + v_block_offset\n + n * e\n + tl.arange(0, CBLOCK)[:, None] * e\n + tl.arange(0, e)[None, :]\n )\n DO_block_ptr = (\n DO\n + o_offset\n + o_block_offset\n + n * e\n + tl.arange(0, CBLOCK)[:, None] * e\n + tl.arange(0, e)[None, :]\n )\n\n dkv = tl.zeros([d, e], dtype=tl.float32)\n for i in range(NUM_BLOCK - 1, -1, -1):\n for j in range(NUM_CBLOCK - 1, -1, -1):\n K_block_ptr -= CBLOCK * d\n V_trans_block_ptr -= CBLOCK * e\n DK_trans_block_ptr -= CBLOCK * d\n DV_block_ptr -= CBLOCK * e\n\n k = tl.load(K_block_ptr).to(tl.float32)\n v_trans = tl.load(V_trans_block_ptr).to(tl.float32)\n\n k_decay_trans = tl.exp(\n -s.to(tl.float32) * (BLOCK - (j * CBLOCK + c_array[None, :]))\n )\n k_decay = tl.exp(\n -s.to(tl.float32) * (BLOCK - (j * CBLOCK + c_array[:, None]))\n )\n dk_none_diag_trans = tl.dot(dkv, v_trans) * k_decay_trans\n dv_none_diag = tl.dot(k, dkv) * k_decay\n\n dk_trans = dk_none_diag_trans + tl.load(DK_trans_block_ptr)\n dv = dv_none_diag + tl.load(DV_block_ptr)\n\n tl.store(\n DK_trans_block_ptr, dk_trans.to(DK_trans_block_ptr.dtype.element_ty)\n )\n tl.store(DV_block_ptr, dv.to(DV_block_ptr.dtype.element_ty))\n\n dkv_current = tl.zeros([d, e], dtype=tl.float32)\n for j in range(NUM_CBLOCK - 1, -1, -1):\n DO_block_ptr -= CBLOCK * e\n Q_trans_block_ptr -= CBLOCK * d\n do = tl.load(DO_block_ptr).to(tl.float32)\n q_trans = tl.load(Q_trans_block_ptr).to(tl.float32)\n q_decay_trans = tl.exp(-s.to(tl.float32) * (j * CBLOCK + c_array[None, :]))\n dkv_current += tl.dot(q_trans * q_decay_trans, do)\n\n dkv = block_decay * dkv + dkv_current\n tl.store(DKV_block_ptr, dkv.to(DKV_block_ptr.dtype.element_ty))\n\n\ndef lasp_forward(q, k, v, s):\n q = q.contiguous()\n k = k.contiguous()\n v = v.contiguous()\n s = s.contiguous()\n\n # shape constraints\n b, h, n, d = q.shape\n e = v.shape[-1]\n # right\n o = torch.empty((b, h, n, e), dtype=q.dtype, device=q.device)\n kv = torch.empty((b, h, d, e), dtype=q.dtype, device=q.device)\n\n BLOCK = 64\n NUM_BLOCK = q.shape[2] // BLOCK\n\n BLOCK_MODEL = 32\n\n grid = (b * h, e // BLOCK_MODEL)\n\n with torch.cuda.device(q.device.index):\n _fwd_kernel[grid](\n q,\n k,\n v,\n o,\n s,\n kv,\n b,\n h,\n n,\n d,\n e,\n BLOCK=BLOCK,\n NUM_BLOCK=NUM_BLOCK,\n BLOCK_MODEL=BLOCK_MODEL,\n )\n\n return o, kv\n\n\ndef lasp_backward(q, k, v, s, do):\n q = q.contiguous()\n k = k.contiguous()\n v = v.contiguous()\n s = s.contiguous()\n\n do = do.contiguous()\n dq = torch.empty_like(q)\n dk = torch.empty_like(k)\n dv = torch.empty_like(v)\n\n b, h, n, d = q.shape\n e = v.shape[-1]\n BLOCK = 32\n NUM_BLOCK = triton.cdiv(n, BLOCK)\n\n CBLOCK = 16\n\n assert BLOCK % CBLOCK == 0\n NUM_CBLOCK = BLOCK // CBLOCK\n\n dkv = torch.empty((b, h, d, e), dtype=q.dtype, device=q.device)\n\n with torch.cuda.device(q.device.index):\n grid = (b * h, NUM_BLOCK)\n _bwd_diag_kernel[grid](\n q,\n k,\n v,\n s,\n do,\n dq,\n dk,\n dv,\n b,\n h,\n n,\n d,\n e,\n BLOCK=BLOCK,\n NUM_BLOCK=NUM_BLOCK,\n CBLOCK=CBLOCK,\n NUM_CBLOCK=NUM_CBLOCK,\n )\n\n grid = (b * h,)\n\n _bwd_none_diag_kernel[grid](\n q,\n k,\n v,\n s,\n do,\n dq,\n dk,\n dv,\n dkv,\n b,\n h,\n n,\n d,\n e,\n BLOCK=BLOCK,\n NUM_BLOCK=NUM_BLOCK,\n CBLOCK=CBLOCK,\n NUM_CBLOCK=NUM_CBLOCK,\n )\n\n return dq, dk, dv, None, dkv\n", - "description_1": "Use triton language to implement a sequence of kernels for the forward and backward pass of a custom attention mechanism. This involves three kernels: _fwd_kernel, _bwd_diag_kernel, and _bwd_none_diag_kernel. Each kernel takes varying numbers of parameters to perform matrix operations and decay calculations for neural network layers. The forward kernel (_fwd_kernel) takes 14 parameters, including Q, K, V matrices, and outputs to calculate the attention matrix. The backward kernel for diagonal elements (_bwd_diag_kernel) takes 15 parameters, performing backward calculations on gradients. The non-diagonal backward kernel (_bwd_none_diag_kernel) also takes 16 parameters, further handling gradient calculations with additional decay and accumulation logic.", - "description_2": "Use triton language to create efficient kernels that handle both forward and backward operations in a custom attention mechanism, focusing on memory and computation optimization through the use of block decay and transposed matrix operations.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n\n@triton.jit\ndef _fwd_kernel(\n Q,\n K,\n V,\n Out,\n S,\n KV,\n b: tl.constexpr,\n h: tl.constexpr,\n n: tl.constexpr,\n d: tl.constexpr,\n e: tl.constexpr,\n BLOCK: tl.constexpr,\n NUM_BLOCK: tl.constexpr,\n DBLOCK: tl.constexpr,\n NUM_DBLOCK: tl.constexpr,\n EBLOCK: tl.constexpr,\n NUM_EBLOCK: tl.constexpr,\n):\n off_d = tl.program_id(0)\n off_e = tl.program_id(1)\n off_bh = tl.program_id(2)\n off_h = off_bh % h\n # get the (b, h) location\n qk_offset = off_bh * n * d\n v_offset = off_bh * n * e\n o_offset = off_d * b * h * n * e + off_bh * n * e\n kv_offset = off_bh * d * e\n\n d_offset = off_d * DBLOCK\n e_offset = off_e * EBLOCK\n\n kv_d_offset = d_offset * e\n\n Q_block_ptr = (\n Q\n + qk_offset\n + d_offset\n + tl.arange(0, BLOCK)[:, None] * d\n + tl.arange(0, DBLOCK)[None, :]\n )\n K_trans_block_ptr = (\n K\n + qk_offset\n + d_offset\n + tl.arange(0, BLOCK)[None, :] * d\n + tl.arange(0, DBLOCK)[:, None]\n )\n V_block_ptr = (\n V\n + v_offset\n + e_offset\n + tl.arange(0, BLOCK)[:, None] * e\n + tl.arange(0, EBLOCK)[None, :]\n )\n O_block_ptr = (\n Out\n + o_offset\n + e_offset\n + tl.arange(0, BLOCK)[:, None] * e\n + tl.arange(0, EBLOCK)[None, :]\n )\n KV_block_ptr = (\n KV\n + kv_offset\n + kv_d_offset\n + e_offset\n + tl.arange(0, DBLOCK)[:, None] * e\n + tl.arange(0, EBLOCK)[None, :]\n )\n\n S_block_ptr = S + off_h\n s = tl.load(S_block_ptr)\n\n array = tl.arange(0, BLOCK).to(tl.float32)\n q_decay = tl.exp(-s.to(tl.float32) * array[:, None])\n k_trans_decay = tl.exp(-s.to(tl.float32) * (BLOCK - array[None, :]))\n block_decay = tl.exp(-s.to(tl.float32) * BLOCK)\n # diag\n index = array[:, None] - array[None, :]\n s_index = s * index\n s_index = tl.where(index >= 0, -s_index, float(\"-inf\"))\n diag_decay = tl.exp(s_index)\n\n # load global KV\n KV = tl.load(KV_block_ptr).to(tl.float32)\n\n kv = tl.zeros([DBLOCK, EBLOCK], dtype=tl.float32)\n for i in range(NUM_BLOCK):\n q = tl.load(Q_block_ptr).to(tl.float32)\n k_trans = tl.load(K_trans_block_ptr).to(tl.float32)\n v = tl.load(V_block_ptr).to(tl.float32)\n\n qkv_none_diag = tl.dot(q, kv) * q_decay + tl.dot(q, KV) * tl.exp(\n -s.to(tl.float32) * (array[:, None] + i * BLOCK)\n )\n # diag\n qk = tl.dot(q, k_trans) * diag_decay\n qkv_diag = tl.dot(qk, v)\n\n qkv = qkv_none_diag + qkv_diag\n\n tl.store(O_block_ptr, qkv.to(O_block_ptr.dtype.element_ty))\n kv = block_decay * kv + tl.dot(k_trans * k_trans_decay, v)\n\n Q_block_ptr += BLOCK * d\n K_trans_block_ptr += BLOCK * d\n V_block_ptr += BLOCK * e\n O_block_ptr += BLOCK * e\n\n KV = tl.exp(-s.to(tl.float32) * n) * KV + kv\n tl.store(KV_block_ptr, KV.to(KV_block_ptr.dtype.element_ty))\n\n\n@triton.jit\ndef _bwd_kernel(\n Q,\n K,\n V,\n S,\n DO,\n DQ,\n DK,\n DV,\n KV,\n DKV,\n b: tl.constexpr,\n h: tl.constexpr,\n n: tl.constexpr,\n d: tl.constexpr,\n e: tl.constexpr,\n BLOCK: tl.constexpr,\n NUM_BLOCK: tl.constexpr,\n DBLOCK: tl.constexpr,\n NUM_DBLOCK: tl.constexpr,\n EBLOCK: tl.constexpr,\n NUM_EBLOCK: tl.constexpr,\n):\n off_d = tl.program_id(0)\n off_e = tl.program_id(1)\n off_bh = tl.program_id(2)\n off_h = off_bh % h\n\n qk_offset = off_bh * n * d\n v_offset = off_bh * n * e\n o_offset = off_bh * n * e\n kv_offset = off_bh * d * e\n\n d_offset = off_d * DBLOCK\n e_offset = off_e * EBLOCK\n\n dqk_offset = off_e * b * h * n * d\n dv_offset = off_d * b * h * n * e\n\n d_offset = off_d * DBLOCK\n e_offset = off_e * EBLOCK\n kv_d_offset = d_offset * e\n\n S_block_ptr = S + off_h\n s = tl.load(S_block_ptr)\n block_decay = tl.exp(-s.to(tl.float32) * BLOCK)\n\n DQ_block_ptr = (\n DQ\n + qk_offset\n + dqk_offset\n + d_offset\n + tl.arange(0, BLOCK)[:, None] * d\n + tl.arange(0, DBLOCK)[None, :]\n )\n K_block_ptr = (\n K\n + qk_offset\n + d_offset\n + tl.arange(0, BLOCK)[:, None] * d\n + tl.arange(0, DBLOCK)[None, :]\n )\n V_trans_block_ptr = (\n V\n + v_offset\n + e_offset\n + tl.arange(0, BLOCK)[None, :] * e\n + tl.arange(0, EBLOCK)[:, None]\n )\n DO_block_ptr = (\n DO\n + o_offset\n + e_offset\n + tl.arange(0, BLOCK)[:, None] * e\n + tl.arange(0, EBLOCK)[None, :]\n )\n\n KV_trans_block_ptr = (\n KV\n + kv_offset\n + kv_d_offset\n + e_offset\n + tl.arange(0, DBLOCK)[None, :] * e\n + tl.arange(0, EBLOCK)[None, :]\n )\n DKV_block_ptr = (\n DKV\n + kv_offset\n + kv_d_offset\n + e_offset\n + tl.arange(0, DBLOCK)[:, None] * e\n + tl.arange(0, EBLOCK)[None, :]\n )\n\n # compute block array\n array = tl.arange(0, BLOCK)\n\n # diag\n index = array[:, None] - array[None, :]\n s_index = s * index\n s_index = tl.where(index >= 0, -s_index, float(\"-inf\"))\n diag_decay = tl.exp(s_index)\n diag_decay_trans = tl.trans(diag_decay)\n\n KV_trans = tl.load(KV_trans_block_ptr).to(tl.float32)\n kv_trans = tl.zeros([EBLOCK, DBLOCK], dtype=tl.float32)\n for i in range(NUM_BLOCK):\n q_decay = tl.exp(-s.to(tl.float32) * array[:, None])\n k_decay = tl.exp(-s.to(tl.float32) * (BLOCK - array[:, None]))\n do = tl.load(DO_block_ptr).to(tl.float32)\n k = tl.load(K_block_ptr).to(tl.float32)\n v_trans = tl.load(V_trans_block_ptr).to(tl.float32)\n\n dq_none_diag = tl.dot(do, kv_trans) * q_decay + tl.dot(do, KV_trans) * tl.exp(\n -s.to(tl.float32) * (i * BLOCK + array[:, None])\n )\n\n dqk = tl.dot(do, v_trans) * diag_decay\n dq_diag = tl.dot(dqk, k)\n\n dq = dq_none_diag + dq_diag\n\n tl.store(DQ_block_ptr, dq.to(DQ_block_ptr.dtype.element_ty))\n\n DQ_block_ptr += BLOCK * d\n DO_block_ptr += BLOCK * e\n K_block_ptr += BLOCK * d\n V_trans_block_ptr += BLOCK * e\n\n kv_trans = block_decay * kv_trans + tl.dot(v_trans, k * k_decay)\n\n Q_trans_block_ptr = (\n Q\n + qk_offset\n + d_offset\n + n * d\n + tl.arange(0, BLOCK)[None, :] * d\n + tl.arange(0, DBLOCK)[:, None]\n )\n K_block_ptr = (\n K\n + qk_offset\n + d_offset\n + n * d\n + tl.arange(0, BLOCK)[:, None] * d\n + tl.arange(0, DBLOCK)[None, :]\n )\n V_trans_block_ptr = (\n V\n + v_offset\n + e_offset\n + n * e\n + tl.arange(0, BLOCK)[None, :] * e\n + tl.arange(0, EBLOCK)[:, None]\n )\n\n DK_trans_block_ptr = (\n DK\n + qk_offset\n + dqk_offset\n + d_offset\n + n * d\n + tl.arange(0, BLOCK)[None, :] * d\n + tl.arange(0, DBLOCK)[:, None]\n )\n DV_block_ptr = (\n DV\n + v_offset\n + dv_offset\n + e_offset\n + n * e\n + tl.arange(0, BLOCK)[:, None] * e\n + tl.arange(0, EBLOCK)[None, :]\n )\n DO_block_ptr = (\n DO\n + o_offset\n + e_offset\n + n * e\n + tl.arange(0, BLOCK)[:, None] * e\n + tl.arange(0, EBLOCK)[None, :]\n )\n\n DKV = tl.load(DKV_block_ptr)\n dkv = tl.zeros([DBLOCK, EBLOCK], dtype=tl.float32)\n for i in range(NUM_BLOCK - 1, -1, -1):\n K_block_ptr -= BLOCK * d\n V_trans_block_ptr -= BLOCK * e\n DK_trans_block_ptr -= BLOCK * d\n DV_block_ptr -= BLOCK * e\n DO_block_ptr -= BLOCK * e\n Q_trans_block_ptr -= BLOCK * d\n\n k = tl.load(K_block_ptr).to(tl.float32)\n v_trans = tl.load(V_trans_block_ptr).to(tl.float32)\n do = tl.load(DO_block_ptr).to(tl.float32)\n q_trans = tl.load(Q_trans_block_ptr).to(tl.float32)\n\n k_decay_trans = tl.exp(-s.to(tl.float32) * (BLOCK - array[None, :]))\n k_decay = tl.exp(-s.to(tl.float32) * (BLOCK - array[:, None]))\n q_decay_trans = tl.exp(-s.to(tl.float32) * array[None, :])\n\n dqk = tl.dot(do, v_trans) * diag_decay\n dk_diag_trans = tl.dot(q_trans, dqk)\n dk_none_diag_trans = tl.dot(dkv, v_trans) * k_decay_trans + tl.dot(\n DKV, v_trans.to(DKV.dtype)\n ) * tl.exp(-s.to(tl.float32) * (n - i * BLOCK - array[None, :]))\n dk_trans = dk_none_diag_trans + dk_diag_trans\n\n qk_trans = tl.dot(k, q_trans) * diag_decay_trans\n dv_diag = tl.dot(qk_trans, do)\n dv_none_diag = tl.dot(k, dkv) * k_decay + tl.dot(k.to(DKV.dtype), DKV) * tl.exp(\n -s.to(tl.float32) * (n - i * BLOCK - array[:, None])\n )\n dv = dv_none_diag + dv_diag\n\n tl.store(DK_trans_block_ptr, dk_trans.to(DK_trans_block_ptr.dtype.element_ty))\n tl.store(DV_block_ptr, dv.to(DV_block_ptr.dtype.element_ty))\n\n dkv = block_decay * dkv + tl.dot(q_trans * q_decay_trans, do)\n\n DKV = tl.exp(-s.to(tl.float32) * n) * DKV + dkv\n tl.store(DKV_block_ptr, DKV.to(DKV_block_ptr.dtype.element_ty))\n\n\ndef lasp_forward(q, k, v, s, KV):\n q = q.contiguous()\n k = k.contiguous()\n v = v.contiguous()\n s = s.contiguous()\n KV = KV.contiguous()\n\n # shape constraints\n b, h, n, d = q.shape\n e = v.shape[-1]\n # split over head\n cd = 64\n ce = 64\n d_, e_ = min(triton.next_power_of_2(d), cd), min(triton.next_power_of_2(e), ce)\n nd, ne = d // d_, e // e_\n # right\n o = torch.empty((nd, b, h, n, e), dtype=q.dtype, device=q.device)\n\n BLOCK = 64\n\n NUM_BLOCK = q.shape[2] // BLOCK\n\n grid = (nd, ne, b * h)\n\n with torch.cuda.device(q.device.index):\n _fwd_kernel[grid](\n q,\n k,\n v,\n o,\n s,\n KV,\n b,\n h,\n n,\n d,\n e,\n BLOCK=BLOCK,\n NUM_BLOCK=NUM_BLOCK,\n DBLOCK=d_,\n NUM_DBLOCK=nd,\n EBLOCK=e_,\n NUM_EBLOCK=ne,\n )\n\n if nd > 1:\n o = o.sum(0)\n else:\n o.squeeze_()\n\n return o\n\n\ndef lasp_backward(q, k, v, s, do, KV, DKV):\n q = q.contiguous()\n k = k.contiguous()\n v = v.contiguous()\n s = s.contiguous()\n do = do.contiguous()\n KV = KV.contiguous()\n DKV = DKV.contiguous()\n\n b, h, n, d = q.shape\n e = v.shape[-1]\n BLOCK = 32\n NUM_BLOCK = triton.cdiv(n, BLOCK)\n\n cd = 64\n ce = 64\n d_, e_ = min(triton.next_power_of_2(d), cd), min(triton.next_power_of_2(e), ce)\n nd, ne = d // d_, e // e_\n\n dq = torch.empty((ne, b, h, n, d), dtype=q.dtype, device=q.device)\n dk = torch.empty((ne, b, h, n, d), dtype=q.dtype, device=q.device)\n dv = torch.empty((nd, b, h, n, e), dtype=q.dtype, device=q.device)\n\n grid = (\n nd,\n ne,\n b * h,\n )\n\n with torch.cuda.device(q.device.index):\n _bwd_kernel[grid](\n q,\n k,\n v,\n s,\n do,\n dq,\n dk,\n dv,\n KV,\n DKV,\n b,\n h,\n n,\n d,\n e,\n BLOCK=BLOCK,\n NUM_BLOCK=NUM_BLOCK,\n DBLOCK=d_,\n NUM_DBLOCK=nd,\n EBLOCK=e_,\n NUM_EBLOCK=ne,\n )\n\n if ne > 1:\n dq = dq.sum(0)\n dk = dk.sum(0)\n else:\n dq.squeeze_(0)\n dk.squeeze_(0)\n\n if nd > 1:\n dv = dv.sum(0)\n else:\n dv.squeeze_(0)\n\n return dq, dk, dv\n", - "description_1": "Use triton language to implement a forward and backward kernel for a custom layer with attention-like operations. The forward kernel (_fwd_kernel) takes in Q, K, V, Out, S, and KV tensors, along with several block and dimension size constants. It computes outputs by performing operations like element-wise exponential decay, diagonal decay, and matrix multiplication. The backward kernel (_bwd_kernel) computes gradients of Q, K, and V given DO, utilizing similar exponential decay and matrix operations. The lasp_forward function manages input data shape and calls _fwd_kernel, while lasp_backward manages gradient computation and calls _bwd_kernel.", - "description_2": "Use triton language to create a custom forward and backward kernel for a layer performing decay and matrix operations, and utilize these in lasp_forward and lasp_backward functions to process input tensors and compute gradients.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef _fwd_diag_kernel(\n Q,\n K,\n V,\n Out,\n S,\n b: tl.constexpr,\n h: tl.constexpr,\n n: tl.constexpr,\n d: tl.constexpr,\n e: tl.constexpr,\n BLOCK: tl.constexpr,\n NUM_BLOCK: tl.constexpr,\n CBLOCK: tl.constexpr,\n NUM_CBLOCK: tl.constexpr,\n):\n off = tl.program_id(0)\n off_bh = off // NUM_BLOCK\n off_block = off % NUM_BLOCK\n off_cblock = tl.program_id(1)\n\n off_h = off_bh % h\n\n qk_offset = off_bh * n * d\n v_offset = off_bh * n * e\n o_offset = off_bh * n * e\n\n block_offset = off_block * BLOCK\n qk_block_offset = block_offset * d\n v_block_offset = block_offset * e\n o_block_offset = block_offset * e\n\n cblock_offset = off_cblock * CBLOCK\n q_cblock_offset = cblock_offset * d\n o_cblock_offset = cblock_offset * e\n\n Q_block_ptr = (\n Q\n + qk_offset\n + qk_block_offset\n + q_cblock_offset\n + tl.arange(0, CBLOCK)[:, None] * d\n + tl.arange(0, d)[None, :]\n )\n K_trans_block_ptr = (\n K\n + qk_offset\n + qk_block_offset\n + tl.arange(0, CBLOCK)[None, :] * d\n + tl.arange(0, d)[:, None]\n )\n V_block_ptr = (\n V\n + v_offset\n + v_block_offset\n + tl.arange(0, CBLOCK)[:, None] * e\n + tl.arange(0, e)[None, :]\n )\n O_block_ptr = (\n Out\n + o_offset\n + o_block_offset\n + o_cblock_offset\n + tl.arange(0, CBLOCK)[:, None] * e\n + tl.arange(0, e)[None, :]\n )\n\n S_block_ptr = S + off_h\n s = tl.load(S_block_ptr)\n\n i = off_cblock\n q_index = tl.arange(0, CBLOCK) + i * CBLOCK\n\n q = tl.load(Q_block_ptr, mask=q_index[:, None] < n, other=0.0).to(tl.float32)\n\n qkv = tl.zeros([CBLOCK, e], dtype=tl.float32)\n\n for j in range(i + 1):\n kv_index = tl.arange(0, CBLOCK) + j * CBLOCK\n diff = q_index[:, None] - kv_index[None, :]\n s_index = s * diff\n s_index = tl.where(diff >= 0, -s_index, float(\"-inf\"))\n decay = tl.exp(s_index)\n\n k_trans = tl.load(K_trans_block_ptr, mask=kv_index[None, :] < n, other=0.0).to(\n tl.float32\n )\n v = tl.load(V_block_ptr, mask=kv_index[:, None] < n, other=0.0).to(tl.float32)\n\n qk = tl.dot(q, k_trans) * decay\n\n qkv += tl.dot(qk, v)\n\n K_trans_block_ptr += CBLOCK * d\n V_block_ptr += CBLOCK * e\n\n tl.store(\n O_block_ptr, qkv.to(O_block_ptr.dtype.element_ty), mask=q_index[:, None] < n\n )\n\n\n@triton.jit\ndef _fwd_kv_parallel(\n K,\n V,\n S,\n KV,\n b: tl.constexpr,\n h: tl.constexpr,\n n: tl.constexpr,\n d: tl.constexpr,\n e: tl.constexpr,\n BLOCK: tl.constexpr,\n NUM_BLOCK: tl.constexpr,\n D_FBLOCK: tl.constexpr,\n E_FBLOCK: tl.constexpr,\n NUM_FBLOCK: tl.constexpr,\n CBLOCK: tl.constexpr,\n NUM_CBLOCK: tl.constexpr,\n):\n off_bh = tl.program_id(0)\n off_block = tl.program_id(1)\n off_de = tl.program_id(2)\n\n off_h = off_bh % h\n off_d = off_de // NUM_FBLOCK\n off_e = off_de % NUM_FBLOCK\n\n block_offset = off_block * BLOCK\n\n k_block_offset = block_offset * d\n v_block_offset = block_offset * e\n kv_block_offset = off_block * d * e\n\n k_offset = off_bh * n * d\n v_offset = off_bh * n * e\n kv_offset = off_bh * (NUM_BLOCK + 1) * d * e\n d_offset = off_d * D_FBLOCK\n e_offset = off_e * E_FBLOCK\n\n # (CBLOCK, FBLOCK)\n K_trans_block_ptr = (\n K\n + k_offset\n + k_block_offset\n + d_offset\n + tl.arange(0, CBLOCK)[None, :] * d\n + tl.arange(0, D_FBLOCK)[:, None]\n )\n V_block_ptr = (\n V\n + v_offset\n + v_block_offset\n + e_offset\n + tl.arange(0, CBLOCK)[:, None] * e\n + tl.arange(0, E_FBLOCK)[None, :]\n )\n KV_block_ptr = (\n KV\n + kv_offset\n + kv_block_offset\n + d_offset * e\n + e_offset\n + tl.arange(0, D_FBLOCK)[:, None] * e\n + tl.arange(0, E_FBLOCK)[None, :]\n )\n\n s_ptrs = S + off_h\n s = tl.load(s_ptrs)\n\n # compute block array\n c_array = tl.arange(0, CBLOCK)\n\n kv = tl.zeros([D_FBLOCK, E_FBLOCK], dtype=tl.float32)\n for j in range(NUM_CBLOCK):\n k_trans = tl.load(K_trans_block_ptr).to(tl.float32)\n v = tl.load(V_block_ptr).to(tl.float32)\n k_decay = tl.exp(-s.to(tl.float32) * (BLOCK - (j * CBLOCK + c_array[None, :])))\n\n kv += tl.dot(k_trans * k_decay, v)\n\n K_trans_block_ptr += CBLOCK * d\n V_block_ptr += CBLOCK * e\n\n tl.store(KV_block_ptr, kv.to(KV_block_ptr.dtype.element_ty))\n\n\n@triton.jit\ndef _fwd_kv_reduce(\n K,\n V,\n S,\n KV,\n b: tl.constexpr,\n h: tl.constexpr,\n n: tl.constexpr,\n d: tl.constexpr,\n e: tl.constexpr,\n BLOCK: tl.constexpr,\n NUM_BLOCK: tl.constexpr,\n D_FBLOCK: tl.constexpr,\n E_FBLOCK: tl.constexpr,\n NUM_FBLOCK: tl.constexpr,\n CBLOCK: tl.constexpr,\n NUM_CBLOCK: tl.constexpr,\n):\n off_bh = tl.program_id(0)\n off_h = off_bh % h\n off_d = tl.program_id(1)\n off_e = tl.program_id(2)\n\n kv_offset = off_bh * (NUM_BLOCK + 1) * d * e\n d_offset = off_d * D_FBLOCK\n e_offset = off_e * E_FBLOCK\n\n # (CBLOCK, FBLOCK)\n KV_block_ptr = (\n KV\n + kv_offset\n + d_offset * e\n + e_offset\n + tl.arange(0, D_FBLOCK)[:, None] * e\n + tl.arange(0, E_FBLOCK)[None, :]\n )\n\n s_ptrs = S + off_h\n s = tl.load(s_ptrs)\n\n block_decay = tl.exp(-s.to(tl.float32) * BLOCK)\n\n # compute block array\n\n kv = tl.zeros([D_FBLOCK, E_FBLOCK], dtype=tl.float32)\n for i in range(NUM_BLOCK):\n kv_current = tl.load(KV_block_ptr).to(tl.float32)\n tl.store(KV_block_ptr, kv.to(KV_block_ptr.dtype.element_ty))\n\n kv = block_decay * kv + kv_current\n KV_block_ptr += d * e\n\n # for GKV\n tl.store(KV_block_ptr, kv.to(KV_block_ptr.dtype.element_ty))\n\n\n@triton.jit\ndef _fwd_none_diag_kernel(\n Q,\n K,\n V,\n Out,\n S,\n KV,\n GKV,\n b: tl.constexpr,\n h: tl.constexpr,\n n: tl.constexpr,\n d: tl.constexpr,\n e: tl.constexpr,\n BLOCK: tl.constexpr,\n NUM_BLOCK: tl.constexpr,\n D_FBLOCK: tl.constexpr,\n E_FBLOCK: tl.constexpr,\n NUM_FBLOCK: tl.constexpr,\n CBLOCK: tl.constexpr,\n NUM_CBLOCK: tl.constexpr,\n):\n off_bh = tl.program_id(0)\n off_h = off_bh % h\n\n off_nc = tl.program_id(1)\n off_n = off_nc // NUM_CBLOCK\n off_c = off_nc % NUM_CBLOCK\n off_e = tl.program_id(2)\n\n n_offset = off_n * BLOCK\n c_offset = off_c * CBLOCK\n e_offset = off_e * E_FBLOCK\n\n q_offset = off_bh * n * d + (n_offset + c_offset) * d\n o_offset = off_bh * n * e + (n_offset + c_offset) * e + e_offset\n\n kv_offset = off_bh * (NUM_BLOCK + 1) * d * e + off_n * d * e + e_offset\n gkv_offset = off_bh * d * e + e_offset\n\n Q_block_ptr = (\n Q + q_offset + tl.arange(0, CBLOCK)[:, None] * d + tl.arange(0, d)[None, :]\n )\n O_block_ptr = (\n Out\n + o_offset\n + tl.arange(0, CBLOCK)[:, None] * e\n + tl.arange(0, E_FBLOCK)[None, :]\n )\n KV_block_ptr = (\n KV + kv_offset + tl.arange(0, d)[:, None] * e + tl.arange(0, E_FBLOCK)[None, :]\n )\n GKV_block_ptr = (\n GKV\n + gkv_offset\n + tl.arange(0, d)[:, None] * e\n + tl.arange(0, E_FBLOCK)[None, :]\n )\n\n S_block_ptr = S + off_h\n s = tl.load(S_block_ptr)\n\n c_array = tl.arange(0, CBLOCK)\n\n GKV = tl.load(GKV_block_ptr).to(tl.float32)\n kv = tl.load(KV_block_ptr).to(tl.float32)\n q = tl.load(Q_block_ptr).to(tl.float32)\n q_decay = tl.exp(-s.to(tl.float32) * (c_offset + c_array[:, None]))\n qkv_none_diag = tl.dot(q, kv) * q_decay + tl.dot(q, GKV) * tl.exp(\n -s.to(tl.float32) * (c_offset + c_array[:, None] + n_offset)\n )\n qkv_diag = tl.load(O_block_ptr).to(tl.float32)\n\n qkv = qkv_diag + qkv_none_diag\n\n tl.store(O_block_ptr, qkv.to(O_block_ptr.dtype.element_ty))\n\n\ndef lasp_forward(q, k, v, s, KV, BLOCK=128, CBLOCK=64):\n q = q.contiguous()\n k = k.contiguous()\n v = v.contiguous()\n s = s.contiguous()\n\n # shape constraints\n b, h, n, d = q.shape\n e = v.shape[-1]\n # right\n o = torch.empty((b, h, n, e), dtype=q.dtype, device=q.device)\n\n NUM_BLOCK = q.shape[2] // BLOCK\n\n NUM_CBLOCK = BLOCK // CBLOCK\n\n grid = (b * h * NUM_BLOCK, NUM_CBLOCK)\n\n with torch.cuda.device(q.device.index):\n _fwd_diag_kernel[grid](\n q,\n k,\n v,\n o,\n s,\n b,\n h,\n n,\n d,\n e,\n BLOCK=BLOCK,\n NUM_BLOCK=NUM_BLOCK,\n CBLOCK=CBLOCK,\n NUM_CBLOCK=NUM_CBLOCK,\n )\n\n NUM_FBLOCK = 1\n D_FBLOCK = d // NUM_FBLOCK\n E_FBLOCK = e // NUM_FBLOCK\n assert d % NUM_FBLOCK == 0\n assert e % NUM_FBLOCK == 0\n grid = (b * h, NUM_FBLOCK, NUM_FBLOCK)\n\n kv = torch.empty((b, h, NUM_BLOCK + 1, d, e), dtype=torch.float32, device=q.device)\n\n with torch.cuda.device(q.device.index):\n grid = (b * h, NUM_BLOCK, NUM_FBLOCK * NUM_FBLOCK)\n _fwd_kv_parallel[grid](\n k,\n v,\n s,\n kv,\n b,\n h,\n n,\n d,\n e,\n BLOCK=BLOCK,\n NUM_BLOCK=NUM_BLOCK,\n D_FBLOCK=D_FBLOCK,\n E_FBLOCK=E_FBLOCK,\n NUM_FBLOCK=NUM_FBLOCK,\n CBLOCK=CBLOCK,\n NUM_CBLOCK=NUM_CBLOCK,\n )\n\n grid = (b * h, NUM_FBLOCK, NUM_FBLOCK)\n _fwd_kv_reduce[grid](\n k,\n v,\n s,\n kv,\n b,\n h,\n n,\n d,\n e,\n BLOCK=BLOCK,\n NUM_BLOCK=NUM_BLOCK,\n D_FBLOCK=D_FBLOCK,\n E_FBLOCK=E_FBLOCK,\n NUM_FBLOCK=NUM_FBLOCK,\n CBLOCK=CBLOCK,\n NUM_CBLOCK=NUM_CBLOCK,\n )\n\n grid = (b * h, NUM_BLOCK * NUM_CBLOCK, NUM_FBLOCK)\n _fwd_none_diag_kernel[grid](\n q,\n k,\n v,\n o,\n s,\n kv,\n KV,\n b,\n h,\n n,\n d,\n e,\n BLOCK=BLOCK,\n NUM_BLOCK=NUM_BLOCK,\n D_FBLOCK=D_FBLOCK,\n E_FBLOCK=E_FBLOCK,\n NUM_FBLOCK=NUM_FBLOCK,\n CBLOCK=CBLOCK,\n NUM_CBLOCK=NUM_CBLOCK,\n )\n block_decay = torch.exp(-s.to(torch.float32) * n)\n KV = block_decay * KV + kv[:, :, -1]\n\n return o, kv, KV\n", - "description_1": "Use triton language to implement forward pass kernels for a custom attention mechanism. The kernels include _fwd_diag_kernel with 14 arguments for processing diagonal blocks, _fwd_kv_parallel with 15 arguments for parallel KV computation, and _fwd_kv_reduce with 14 arguments for reducing KV results. These kernels are called in lasp_forward which has 6 arguments and manages data grids for computation.", - "description_2": "Use triton language to implement custom attention mechanism kernels, including diagonal block processing and KV parallel computation, integrated in a forward function with grid management.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n\n@triton.jit\ndef _fwd_kernel(\n Q,\n K,\n V,\n Out,\n S,\n KV,\n b: tl.constexpr,\n h: tl.constexpr,\n n: tl.constexpr,\n d: tl.constexpr,\n e: tl.constexpr,\n BLOCK: tl.constexpr,\n NUM_BLOCK: tl.constexpr,\n BLOCK_MODEL: tl.constexpr,\n):\n off_bh = tl.program_id(0)\n off_h = off_bh % h\n off_e = tl.program_id(1)\n qk_offset = off_bh * n * d\n v_offset = off_bh * n * e\n o_offset = off_bh * n * e\n kv_offset = off_bh * d * e\n\n e_offset = off_e * BLOCK_MODEL\n\n Q_block_ptr = (\n Q + qk_offset + tl.arange(0, BLOCK)[:, None] * d + tl.arange(0, d)[None, :]\n )\n K_trans_block_ptr = (\n K + qk_offset + tl.arange(0, BLOCK)[None, :] * d + tl.arange(0, d)[:, None]\n )\n V_block_ptr = (\n V\n + v_offset\n + e_offset\n + tl.arange(0, BLOCK)[:, None] * e\n + tl.arange(0, BLOCK_MODEL)[None, :]\n )\n O_block_ptr = (\n Out\n + o_offset\n + e_offset\n + tl.arange(0, BLOCK)[:, None] * e\n + tl.arange(0, BLOCK_MODEL)[None, :]\n )\n KV_block_ptr = (\n KV\n + kv_offset\n + e_offset\n + tl.arange(0, d)[:, None] * e\n + tl.arange(0, BLOCK_MODEL)[None, :]\n )\n\n S_block_ptr = S + off_h\n s = tl.load(S_block_ptr)\n\n array = tl.arange(0, BLOCK).to(tl.float32)\n q_decay = tl.exp(-s.to(tl.float32) * array[:, None])\n k_trans_decay = tl.exp(-s.to(tl.float32) * (BLOCK - array[None, :]))\n block_decay = tl.exp(-s.to(tl.float32) * BLOCK)\n\n index = array[:, None] - array[None, :]\n s_index = s * index\n s_index = tl.where(index >= 0, -s_index, float(\"-inf\"))\n diag_decay = tl.exp(s_index)\n\n kv = tl.zeros([d, BLOCK_MODEL], dtype=tl.float32)\n for i in range(NUM_BLOCK):\n q = tl.load(Q_block_ptr).to(tl.float32)\n k_trans = tl.load(K_trans_block_ptr).to(tl.float32)\n v = tl.load(V_block_ptr).to(tl.float32)\n\n qkv_none_diag = tl.dot(q, kv) * q_decay\n qk = tl.dot(q, k_trans) * diag_decay\n qkv_diag = tl.dot(qk, v)\n\n qkv = qkv_none_diag + qkv_diag\n\n tl.store(O_block_ptr, qkv.to(O_block_ptr.dtype.element_ty))\n kv = block_decay * kv + tl.dot(k_trans * k_trans_decay, v)\n\n Q_block_ptr += BLOCK * d\n K_trans_block_ptr += BLOCK * d\n V_block_ptr += BLOCK * e\n O_block_ptr += BLOCK * e\n\n tl.store(KV_block_ptr, kv.to(KV_block_ptr.dtype.element_ty))\n\n\n@triton.jit\ndef _bwd_diag_kernel(\n Q,\n K,\n V,\n S,\n DO,\n DQ,\n DK,\n DV,\n b: tl.constexpr,\n h: tl.constexpr,\n n: tl.constexpr,\n d: tl.constexpr,\n e: tl.constexpr,\n BLOCK: tl.constexpr,\n NUM_BLOCK: tl.constexpr,\n CBLOCK: tl.constexpr,\n NUM_CBLOCK: tl.constexpr,\n):\n off_bh = tl.program_id(0)\n off_block = tl.program_id(1)\n off_h = off_bh % h\n\n qk_offset = off_bh * n * d\n v_offset = off_bh * n * e\n o_offset = off_bh * n * e\n\n block_offset = off_block * BLOCK\n qk_block_offset = block_offset * d\n v_block_offset = block_offset * e\n o_block_offset = block_offset * e\n\n Q_trans_block_ptr = (\n Q\n + qk_offset\n + qk_block_offset\n + tl.arange(0, BLOCK)[None, :] * d\n + tl.arange(0, d)[:, None]\n )\n K_block_ptr = (\n K\n + qk_offset\n + qk_block_offset\n + tl.arange(0, BLOCK)[:, None] * d\n + tl.arange(0, d)[None, :]\n )\n V_trans_block_ptr = (\n V\n + v_offset\n + v_block_offset\n + tl.arange(0, BLOCK)[None, :] * e\n + tl.arange(0, e)[:, None]\n )\n\n DQ_block_ptr = (\n DQ\n + qk_offset\n + qk_block_offset\n + tl.arange(0, BLOCK)[:, None] * d\n + tl.arange(0, d)[None, :]\n )\n DK_trans_block_ptr = (\n DK\n + qk_offset\n + qk_block_offset\n + tl.arange(0, BLOCK)[None, :] * d\n + tl.arange(0, d)[:, None]\n )\n DV_block_ptr = (\n DV\n + v_offset\n + v_block_offset\n + tl.arange(0, BLOCK)[:, None] * e\n + tl.arange(0, e)[None, :]\n )\n DO_block_ptr = (\n DO\n + o_offset\n + o_block_offset\n + tl.arange(0, BLOCK)[:, None] * e\n + tl.arange(0, e)[None, :]\n )\n\n S_block_ptr = S + off_h\n s = tl.load(S_block_ptr)\n array = tl.arange(0, BLOCK).to(tl.float32)\n index = array[:, None] - array[None, :]\n s_index = s * index\n s_index = tl.where(index >= 0, -s_index, float(\"-inf\"))\n diag_decay = tl.exp(s_index)\n diag_decay_trans = tl.trans(diag_decay)\n\n k = tl.load(K_block_ptr).to(tl.float32)\n v_trans = tl.load(V_trans_block_ptr).to(tl.float32)\n do = tl.load(DO_block_ptr).to(tl.float32)\n q_trans = tl.load(Q_trans_block_ptr).to(tl.float32)\n\n dqk = tl.dot(do, v_trans) * diag_decay\n dq_diag = tl.dot(dqk, k)\n\n dq = dq_diag\n\n dk_diag_trans = tl.dot(q_trans, dqk)\n\n qk_trans = tl.dot(k, q_trans) * diag_decay_trans\n dv_diag = tl.dot(qk_trans, do)\n\n dk_trans = dk_diag_trans\n dv = dv_diag\n\n tl.store(DQ_block_ptr, dq.to(DQ_block_ptr.dtype.element_ty))\n tl.store(DK_trans_block_ptr, dk_trans.to(DK_trans_block_ptr.dtype.element_ty))\n tl.store(DV_block_ptr, dv.to(DV_block_ptr.dtype.element_ty))\n\n\n@triton.jit\ndef _bwd_none_diag_kernel(\n Q,\n K,\n V,\n S,\n DO,\n DQ,\n DK,\n DV,\n DKV,\n b: tl.constexpr,\n h: tl.constexpr,\n n: tl.constexpr,\n d: tl.constexpr,\n e: tl.constexpr,\n BLOCK: tl.constexpr,\n NUM_BLOCK: tl.constexpr,\n CBLOCK: tl.constexpr,\n NUM_CBLOCK: tl.constexpr,\n):\n off_bh = tl.program_id(0)\n off_block = tl.program_id(1)\n off_h = off_bh % h\n\n qk_offset = off_bh * n * d\n v_offset = off_bh * n * e\n o_offset = off_bh * n * e\n kv_offset = off_bh * d * e\n\n block_offset = off_block * BLOCK\n qk_block_offset = block_offset * d\n v_block_offset = block_offset * e\n o_block_offset = block_offset * e\n\n S_block_ptr = S + off_h\n s = tl.load(S_block_ptr)\n block_decay = tl.exp(-s.to(tl.float32) * BLOCK)\n\n DQ_block_ptr = (\n DQ\n + qk_offset\n + qk_block_offset\n + tl.arange(0, CBLOCK)[:, None] * d\n + tl.arange(0, d)[None, :]\n )\n K_block_ptr = (\n K\n + qk_offset\n + qk_block_offset\n + tl.arange(0, CBLOCK)[:, None] * d\n + tl.arange(0, d)[None, :]\n )\n V_trans_block_ptr = (\n V\n + v_offset\n + v_block_offset\n + tl.arange(0, CBLOCK)[None, :] * e\n + tl.arange(0, e)[:, None]\n )\n DO_block_ptr = (\n DO\n + o_offset\n + o_block_offset\n + tl.arange(0, CBLOCK)[:, None] * e\n + tl.arange(0, e)[None, :]\n )\n\n DKV_block_ptr = (\n DKV + kv_offset + tl.arange(0, d)[:, None] * e + tl.arange(0, e)[None, :]\n )\n\n c_array = tl.arange(0, CBLOCK)\n\n kv_trans = tl.zeros([e, d], dtype=tl.float32)\n for i in range(NUM_BLOCK):\n for j in range(NUM_CBLOCK):\n q_decay = tl.exp(-s.to(tl.float32) * (j * CBLOCK + c_array[:, None]))\n do = tl.load(DO_block_ptr).to(tl.float32)\n dq_none_diag = tl.dot(do, kv_trans) * q_decay\n dq = dq_none_diag + tl.load(DQ_block_ptr)\n tl.store(DQ_block_ptr, dq.to(DQ_block_ptr.dtype.element_ty))\n\n DQ_block_ptr += CBLOCK * d\n DO_block_ptr += CBLOCK * e\n\n kv_trans_current = tl.zeros([e, d], dtype=tl.float32)\n for j in range(NUM_CBLOCK):\n v_trans = tl.load(V_trans_block_ptr).to(tl.float32)\n k = tl.load(K_block_ptr).to(tl.float32)\n k_decay = tl.exp(\n -s.to(tl.float32) * (BLOCK - (j * CBLOCK + c_array[:, None]))\n )\n kv_trans_current += tl.dot(v_trans, k * k_decay)\n\n K_block_ptr += CBLOCK * d\n V_trans_block_ptr += CBLOCK * e\n\n kv_trans = block_decay * kv_trans + kv_trans_current\n\n Q_trans_block_ptr = (\n Q\n + qk_offset\n + qk_block_offset\n + n * d\n + tl.arange(0, CBLOCK)[None, :] * d\n + tl.arange(0, d)[:, None]\n )\n K_block_ptr = (\n K\n + qk_offset\n + qk_block_offset\n + n * d\n + tl.arange(0, CBLOCK)[:, None] * d\n + tl.arange(0, d)[None, :]\n )\n V_trans_block_ptr = (\n V\n + v_offset\n + v_block_offset\n + n * e\n + tl.arange(0, CBLOCK)[None, :] * e\n + tl.arange(0, e)[None, :]\n )\n\n DK_trans_block_ptr = (\n DK\n + qk_offset\n + qk_block_offset\n + n * d\n + tl.arange(0, CBLOCK)[None, :] * d\n + tl.arange(0, d)[:, None]\n )\n DV_block_ptr = (\n DV\n + v_offset\n + v_block_offset\n + n * e\n + tl.arange(0, CBLOCK)[:, None] * e\n + tl.arange(0, e)[None, :]\n )\n DO_block_ptr = (\n DO\n + o_offset\n + o_block_offset\n + n * e\n + tl.arange(0, CBLOCK)[:, None] * e\n + tl.arange(0, e)[None, :]\n )\n\n dkv = tl.zeros([d, e], dtype=tl.float32)\n for i in range(NUM_BLOCK - 1, -1, -1):\n for j in range(NUM_CBLOCK - 1, -1, -1):\n K_block_ptr -= CBLOCK * d\n V_trans_block_ptr -= CBLOCK * e\n DK_trans_block_ptr -= CBLOCK * d\n DV_block_ptr -= CBLOCK * e\n\n k = tl.load(K_block_ptr).to(tl.float32)\n v_trans = tl.load(V_trans_block_ptr).to(tl.float32)\n\n k_decay_trans = tl.exp(\n -s.to(tl.float32) * (BLOCK - (j * CBLOCK + c_array[None, :]))\n )\n k_decay = tl.exp(\n -s.to(tl.float32) * (BLOCK - (j * CBLOCK + c_array[:, None]))\n )\n dk_none_diag_trans = tl.dot(dkv, v_trans) * k_decay_trans\n dv_none_diag = tl.dot(k, dkv) * k_decay\n\n dk_trans = dk_none_diag_trans + tl.load(DK_trans_block_ptr)\n dv = dv_none_diag + tl.load(DV_block_ptr)\n\n tl.store(\n DK_trans_block_ptr, dk_trans.to(DK_trans_block_ptr.dtype.element_ty)\n )\n tl.store(DV_block_ptr, dv.to(DV_block_ptr.dtype.element_ty))\n\n dkv_current = tl.zeros([d, e], dtype=tl.float32)\n for j in range(NUM_CBLOCK - 1, -1, -1):\n DO_block_ptr -= CBLOCK * e\n Q_trans_block_ptr -= CBLOCK * d\n do = tl.load(DO_block_ptr).to(tl.float32)\n q_trans = tl.load(Q_trans_block_ptr).to(tl.float32)\n q_decay_trans = tl.exp(-s.to(tl.float32) * (j * CBLOCK + c_array[None, :]))\n dkv_current += tl.dot(q_trans * q_decay_trans, do)\n\n dkv = block_decay * dkv + dkv_current\n tl.store(DKV_block_ptr, dkv.to(DKV_block_ptr.dtype.element_ty))\n\n\ndef lasp_forward(q, k, v, s, kv):\n q = q.contiguous()\n k = k.contiguous()\n v = v.contiguous()\n s = s.contiguous()\n kv = kv.contiguous()\n\n b, h, n, d = q.shape\n e = v.shape[-1]\n o = torch.empty((b, h, n, e), dtype=q.dtype, device=q.device)\n\n BLOCK = 64\n NUM_BLOCK = q.shape[2] // BLOCK\n\n BLOCK_MODEL = 32\n\n grid = (b * h, e // BLOCK_MODEL)\n\n with torch.cuda.device(q.device.index):\n _fwd_kernel[grid](\n q,\n k,\n v,\n o,\n s,\n kv,\n b,\n h,\n n,\n d,\n e,\n BLOCK=BLOCK,\n NUM_BLOCK=NUM_BLOCK,\n BLOCK_MODEL=BLOCK_MODEL,\n )\n\n return o\n\n\ndef lasp_backward(q, k, v, s, do):\n q = q.contiguous()\n k = k.contiguous()\n v = v.contiguous()\n s = s.contiguous()\n\n do = do.contiguous()\n dq = torch.empty_like(q)\n dk = torch.empty_like(k)\n dv = torch.empty_like(v)\n\n b, h, n, d = q.shape\n e = v.shape[-1]\n BLOCK = 32\n NUM_BLOCK = triton.cdiv(n, BLOCK)\n\n CBLOCK = 16\n\n assert BLOCK % CBLOCK == 0\n NUM_CBLOCK = BLOCK // CBLOCK\n\n dkv = torch.empty((b, h, d, e), dtype=q.dtype, device=q.device)\n\n with torch.cuda.device(q.device.index):\n grid = (b * h, NUM_BLOCK)\n _bwd_diag_kernel[grid](\n q,\n k,\n v,\n s,\n do,\n dq,\n dk,\n dv,\n b,\n h,\n n,\n d,\n e,\n BLOCK=BLOCK,\n NUM_BLOCK=NUM_BLOCK,\n CBLOCK=CBLOCK,\n NUM_CBLOCK=NUM_CBLOCK,\n )\n\n grid = (b * h,)\n\n _bwd_none_diag_kernel[grid](\n q,\n k,\n v,\n s,\n do,\n dq,\n dk,\n dv,\n dkv,\n b,\n h,\n n,\n d,\n e,\n BLOCK=BLOCK,\n NUM_BLOCK=NUM_BLOCK,\n CBLOCK=CBLOCK,\n NUM_CBLOCK=NUM_CBLOCK,\n )\n\n return dq, dk, dv, None, dkv\n", - "description_1": "Use triton language to implement forward and backward kernel functions for a custom neural network operation. The forward function '_fwd_kernel' takes 8 tensor inputs (Q, K, V, Out, S, KV) and 7 configuration constants (b, h, n, d, e, BLOCK, NUM_BLOCK, BLOCK_MODEL) to compute an output tensor. The backward functions '_bwd_diag_kernel' and '_bwd_none_diag_kernel' also involve 10 tensor inputs (Q, K, V, S, DO, DQ, DK, DV, DKV) and 7 configuration constants (b, h, n, d, e, BLOCK, NUM_BLOCK, CBLOCK, NUM_CBLOCK) to calculate gradients for these tensors.", - "description_2": "Use triton language to create kernels for a forward and backward pass in a neural network, handling tensors and configuration constants to compute outputs and gradients.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef _fwd_kernel(\n Q,\n K,\n V,\n Out,\n S, # log lambda\n b: tl.constexpr,\n h: tl.constexpr,\n n: tl.constexpr,\n d: tl.constexpr,\n e: tl.constexpr,\n BLOCK: tl.constexpr,\n NUM_BLOCK: tl.constexpr,\n BLOCK_MODEL: tl.constexpr,\n):\n ##### get offset\n off_bh = tl.program_id(0)\n off_h = off_bh % h\n off_e = tl.program_id(1)\n qk_offset = off_bh * n * d\n v_offset = off_bh * n * e\n o_offset = off_bh * n * e\n # channel offset\n e_offset = off_e * BLOCK_MODEL\n\n ##### get block ptr\n Q_block_ptr = Q + qk_offset + tl.arange(0, d)[None, :]\n K_trans_block_ptr = K + qk_offset + tl.arange(0, d)[:, None]\n V_block_ptr = V + v_offset + e_offset + tl.arange(0, BLOCK_MODEL)[None, :]\n O_block_ptr = Out + o_offset + e_offset + tl.arange(0, BLOCK_MODEL)[None, :]\n S_block_ptr = S + off_h\n\n ##### init diag decay(Lambda); q, k decay; kv\n s = tl.load(S_block_ptr)\n # q, k decay\n off_block = tl.arange(\n 0, BLOCK\n ) # Not bug, this is a bit different from algorithm 1, but is mathematically equivalent\n q_decay = tl.exp(-s.to(tl.float32) * off_block[:, None])\n k_trans_decay = tl.exp(-s.to(tl.float32) * (BLOCK - off_block[None, :]))\n block_decay = tl.exp(-s.to(tl.float32) * BLOCK)\n # diag decay\n index = off_block[:, None] - off_block[None, :]\n s_index = s * index\n s_index = tl.where(index >= 0, -s_index, float(\"-inf\"))\n diag_decay = tl.exp(s_index)\n kv = tl.zeros([d, BLOCK_MODEL], dtype=tl.float32)\n\n ##### compute\n for i in range(NUM_BLOCK):\n # load\n q = tl.load(\n Q_block_ptr + off_block[:, None] * d, mask=off_block[:, None] < n, other=0.0\n ).to(tl.float32)\n k_trans = tl.load(\n K_trans_block_ptr + off_block[None, :] * d,\n mask=off_block[None, :] < n,\n other=0.0,\n ).to(tl.float32)\n v = tl.load(\n V_block_ptr + off_block[:, None] * e, mask=off_block[:, None] < n, other=0.0\n ).to(tl.float32)\n\n # compute\n qk = tl.dot(q, k_trans) * diag_decay\n o_intra = tl.dot(qk, v)\n o_inter = tl.dot(q, kv) * q_decay\n o = o_intra + o_inter\n\n # save and update\n tl.store(\n O_block_ptr + off_block[:, None] * e,\n o.to(O_block_ptr.dtype.element_ty),\n mask=off_block[:, None] < n,\n )\n kv = block_decay * kv + tl.dot(k_trans * k_trans_decay, v)\n off_block += BLOCK\n\n@triton.jit\ndef _bwd_intra_kernel(\n Q,\n K,\n V,\n S,\n DO,\n DQ,\n DK,\n DV,\n b: tl.constexpr,\n h: tl.constexpr,\n n: tl.constexpr,\n d: tl.constexpr,\n e: tl.constexpr,\n BLOCK: tl.constexpr,\n NUM_BLOCK: tl.constexpr,\n CBLOCK: tl.constexpr,\n NUM_CBLOCK: tl.constexpr,\n):\n ##### get offset\n off_bh = tl.program_id(0)\n off_block = tl.program_id(1)\n off_h = off_bh % h\n qk_offset = off_bh * n * d\n v_offset = off_bh * n * e\n o_offset = off_bh * n * e\n block_offset = off_block * BLOCK + tl.arange(0, BLOCK)\n\n ##### get block ptr\n Q_trans_block_ptr = (\n Q + qk_offset + block_offset[None, :] * d + tl.arange(0, d)[:, None]\n )\n K_block_ptr = K + qk_offset + block_offset[:, None] * d + tl.arange(0, d)[None, :]\n V_trans_block_ptr = (\n V + v_offset + block_offset[None, :] * e + tl.arange(0, e)[:, None]\n )\n\n DQ_block_ptr = DQ + qk_offset + block_offset[:, None] * d + tl.arange(0, d)[None, :]\n DK_trans_block_ptr = (\n DK + qk_offset + block_offset[None, :] * d + tl.arange(0, d)[:, None]\n )\n DV_block_ptr = DV + v_offset + block_offset[:, None] * e + tl.arange(0, e)[None, :]\n DO_block_ptr = DO + o_offset + block_offset[:, None] * e + tl.arange(0, e)[None, :]\n\n S_block_ptr = S + off_h\n\n ##### init diag decay(Lambda)\n s = tl.load(S_block_ptr)\n array = tl.arange(0, BLOCK).to(tl.float32)\n # diag\n index = array[:, None] - array[None, :]\n s_index = s * index\n s_index = tl.where(index >= 0, -s_index, float(\"-inf\"))\n diag_decay = tl.exp(s_index)\n diag_decay_trans = tl.trans(diag_decay)\n\n ##### load block\n k = tl.load(K_block_ptr, mask=block_offset[:, None] < n, other=0.0).to(tl.float32)\n v_trans = tl.load(V_trans_block_ptr, mask=block_offset[None, :] < n, other=0.0).to(\n tl.float32\n )\n do = tl.load(DO_block_ptr, mask=block_offset[:, None] < n, other=0.0).to(tl.float32)\n q_trans = tl.load(Q_trans_block_ptr, mask=block_offset[None, :] < n, other=0.0).to(\n tl.float32\n )\n\n ##### compute\n dqk = tl.dot(do, v_trans) * diag_decay\n dq_intra = tl.dot(dqk, k)\n\n dk_intra_trans = tl.dot(q_trans, dqk)\n\n qk_trans = tl.dot(k, q_trans) * diag_decay_trans\n dv_intra = tl.dot(qk_trans, do)\n\n dq = dq_intra\n dk_trans = dk_intra_trans\n dv = dv_intra\n\n # save\n tl.store(\n DQ_block_ptr,\n dq.to(DQ_block_ptr.dtype.element_ty),\n mask=block_offset[:, None] < n,\n )\n tl.store(\n DK_trans_block_ptr,\n dk_trans.to(DK_trans_block_ptr.dtype.element_ty),\n mask=block_offset[None, :] < n,\n )\n tl.store(\n DV_block_ptr,\n dv.to(DV_block_ptr.dtype.element_ty),\n mask=block_offset[:, None] < n,\n )\n\n@triton.jit\ndef _bwd_inter_kernel(\n Q,\n K,\n V,\n S,\n DO,\n DQ,\n DK,\n DV,\n b: tl.constexpr,\n h: tl.constexpr,\n n: tl.constexpr,\n d: tl.constexpr,\n e: tl.constexpr,\n BLOCK: tl.constexpr,\n NUM_BLOCK: tl.constexpr,\n CBLOCK: tl.constexpr,\n NUM_CBLOCK: tl.constexpr,\n):\n ##### get offset\n off_bh = tl.program_id(0)\n off_h = off_bh % h\n\n qk_offset = off_bh * n * d\n v_offset = off_bh * n * e\n o_offset = off_bh * n * e\n S_block_ptr = S + off_h\n\n ##### get block ptr\n DQ_block_ptr = (\n DQ + qk_offset + tl.arange(0, CBLOCK)[:, None] * d + tl.arange(0, d)[None, :]\n )\n K_block_ptr = (\n K + qk_offset + tl.arange(0, CBLOCK)[:, None] * d + tl.arange(0, d)[None, :]\n )\n V_trans_block_ptr = (\n V + v_offset + tl.arange(0, CBLOCK)[None, :] * e + tl.arange(0, e)[:, None]\n )\n DO_block_ptr = (\n DO + o_offset + tl.arange(0, CBLOCK)[:, None] * e + tl.arange(0, e)[None, :]\n )\n # mask\n off_block1 = tl.arange(0, CBLOCK)\n off_block2 = tl.arange(0, CBLOCK)\n # compute block array\n c_array = tl.arange(0, CBLOCK)\n\n ##### init lambda; kv\n s = tl.load(S_block_ptr)\n block_decay = tl.exp(-s.to(tl.float32) * BLOCK)\n kv_trans = tl.zeros([e, d], dtype=tl.float32)\n\n ##### compute dq inter\n for i in range(NUM_BLOCK):\n # compute in subblock\n for j in range(NUM_CBLOCK):\n if i > 0: # if not add this, may have bug\n q_decay = tl.exp(-s.to(tl.float32) * (j * CBLOCK + c_array[:, None]))\n do = tl.load(DO_block_ptr, mask=off_block1[:, None] < n, other=0.0).to(\n tl.float32\n )\n dq_inter = tl.dot(do, kv_trans) * q_decay\n dq = dq_inter + tl.load(\n DQ_block_ptr, mask=off_block1[:, None] < n, other=0.0\n )\n tl.store(\n DQ_block_ptr,\n dq.to(DQ_block_ptr.dtype.element_ty),\n mask=off_block1[:, None] < n,\n )\n\n DQ_block_ptr += CBLOCK * d\n DO_block_ptr += CBLOCK * e\n off_block1 += CBLOCK\n\n # update kv in subblock\n kv_trans_current = tl.zeros([e, d], dtype=tl.float32)\n for j in range(NUM_CBLOCK):\n v_trans = tl.load(\n V_trans_block_ptr, mask=off_block2[None, :] < n, other=0.0\n ).to(tl.float32)\n k = tl.load(K_block_ptr, mask=off_block2[:, None] < n, other=0.0).to(\n tl.float32\n )\n k_decay = tl.exp(\n -s.to(tl.float32) * (BLOCK - (j * CBLOCK + c_array[:, None]))\n )\n kv_trans_current += tl.dot(v_trans, k * k_decay)\n\n K_block_ptr += CBLOCK * d\n V_trans_block_ptr += CBLOCK * e\n off_block2 += CBLOCK\n\n kv_trans = block_decay * kv_trans + kv_trans_current\n\n ##### get block ptr\n m = NUM_BLOCK * BLOCK\n off_block1 = m + tl.arange(0, CBLOCK)\n off_block2 = m + tl.arange(0, CBLOCK)\n\n Q_trans_block_ptr = (\n Q\n + qk_offset\n + m * d\n + tl.arange(0, CBLOCK)[None, :] * d\n + tl.arange(0, d)[:, None]\n )\n K_block_ptr = (\n K\n + qk_offset\n + m * d\n + tl.arange(0, CBLOCK)[:, None] * d\n + tl.arange(0, d)[None, :]\n )\n V_trans_block_ptr = (\n V\n + v_offset\n + m * e\n + tl.arange(0, CBLOCK)[None, :] * e\n + tl.arange(0, e)[:, None]\n )\n\n DK_trans_block_ptr = (\n DK\n + qk_offset\n + m * d\n + tl.arange(0, CBLOCK)[None, :] * d\n + tl.arange(0, d)[:, None]\n )\n DV_block_ptr = (\n DV\n + v_offset\n + m * e\n + tl.arange(0, CBLOCK)[:, None] * e\n + tl.arange(0, e)[None, :]\n )\n DO_block_ptr = (\n DO\n + o_offset\n + m * e\n + tl.arange(0, CBLOCK)[:, None] * e\n + tl.arange(0, e)[None, :]\n )\n\n ##### init dkv\n dkv = tl.zeros([d, e], dtype=tl.float32)\n\n ##### compute dk, dv inter\n for i in range(NUM_BLOCK - 1, -1, -1):\n # compute in subblock\n for j in range(NUM_CBLOCK - 1, -1, -1):\n K_block_ptr -= CBLOCK * d\n V_trans_block_ptr -= CBLOCK * e\n DK_trans_block_ptr -= CBLOCK * d\n DV_block_ptr -= CBLOCK * e\n off_block1 -= CBLOCK\n\n if i < NUM_BLOCK - 1: # if not add this, may have bug\n k = tl.load(K_block_ptr, mask=off_block1[:, None] < n, other=0.0).to(\n tl.float32\n )\n v_trans = tl.load(\n V_trans_block_ptr, mask=off_block1[None, :] < n, other=0.0\n ).to(tl.float32)\n\n k_decay_trans = tl.exp(\n -s.to(tl.float32) * (BLOCK - (j * CBLOCK + c_array[None, :]))\n )\n k_decay = tl.exp(\n -s.to(tl.float32) * (BLOCK - (j * CBLOCK + c_array[:, None]))\n )\n dk_inter_trans = tl.dot(dkv, v_trans) * k_decay_trans\n dv_inter = tl.dot(k, dkv) * k_decay\n\n dk_trans = dk_inter_trans + tl.load(\n DK_trans_block_ptr, mask=off_block1[None, :] < n, other=0.0\n )\n dv = dv_inter + tl.load(\n DV_block_ptr, mask=off_block1[:, None] < n, other=0.0\n )\n\n tl.store(\n DK_trans_block_ptr,\n dk_trans.to(DK_trans_block_ptr.dtype.element_ty),\n mask=off_block1[None, :] < n,\n )\n tl.store(\n DV_block_ptr,\n dv.to(DV_block_ptr.dtype.element_ty),\n mask=off_block1[:, None] < n,\n )\n\n # update dkv in subblock\n dkv_current = tl.zeros([d, e], dtype=tl.float32)\n for j in range(NUM_CBLOCK - 1, -1, -1):\n DO_block_ptr -= CBLOCK * e\n Q_trans_block_ptr -= CBLOCK * d\n off_block2 -= CBLOCK\n\n do = tl.load(DO_block_ptr, mask=off_block2[:, None] < n, other=0.0).to(\n tl.float32\n )\n q_trans = tl.load(\n Q_trans_block_ptr, mask=off_block2[None, :] < n, other=0.0\n ).to(tl.float32)\n q_decay_trans = tl.exp(-s.to(tl.float32) * (j * CBLOCK + c_array[None, :]))\n dkv_current += tl.dot(q_trans * q_decay_trans, do)\n\n dkv = block_decay * dkv + dkv_current\n\nclass LightningAttention(torch.autograd.Function):\n @staticmethod\n def forward(ctx, q, k, v, s):\n q = q.contiguous()\n k = k.contiguous()\n v = v.contiguous()\n s = s.contiguous()\n\n b, h, n, d = q.shape\n e = v.shape[-1]\n o = torch.empty((b, h, n, e), dtype=q.dtype, device=q.device)\n\n BLOCK = 64\n NUM_BLOCK = triton.cdiv(q.shape[2], BLOCK)\n # parallel over channel\n BLOCK_MODEL = min(triton.next_power_of_2(e), 32)\n grid = (b * h, triton.cdiv(e, BLOCK_MODEL))\n\n with torch.cuda.device(q.device.index):\n _fwd_kernel[grid](\n q,\n k,\n v,\n o,\n s,\n b,\n h,\n n,\n d,\n e,\n BLOCK=BLOCK,\n NUM_BLOCK=NUM_BLOCK,\n BLOCK_MODEL=BLOCK_MODEL,\n )\n\n ctx.save_for_backward(q, k, v, s)\n\n return o\n\n @staticmethod\n def backward(ctx, do):\n q, k, v, s = ctx.saved_tensors\n\n q = q.contiguous()\n k = k.contiguous()\n v = v.contiguous()\n s = s.contiguous()\n do = do.contiguous()\n\n dq = torch.empty_like(q)\n dk = torch.empty_like(k)\n dv = torch.empty_like(v)\n\n b, h, n, d = q.shape\n e = v.shape[-1]\n\n # block size\n BLOCK = 64\n NUM_BLOCK = triton.cdiv(n, BLOCK)\n # compute block size\n CBLOCK = 32\n NUM_CBLOCK = BLOCK // CBLOCK\n\n with torch.cuda.device(q.device.index):\n # for intra part, compute in parallel\n grid = (b * h, NUM_BLOCK)\n _bwd_intra_kernel[grid](\n q,\n k,\n v,\n s,\n do,\n dq,\n dk,\n dv,\n b,\n h,\n n,\n d,\n e,\n BLOCK=BLOCK,\n NUM_BLOCK=NUM_BLOCK,\n CBLOCK=CBLOCK,\n NUM_CBLOCK=NUM_CBLOCK,\n )\n\n # for inter part, compute in sequencial\n grid = (b * h,)\n _bwd_inter_kernel[grid](\n q,\n k,\n v,\n s,\n do,\n dq,\n dk,\n dv,\n b,\n h,\n n,\n d,\n e,\n BLOCK=BLOCK,\n NUM_BLOCK=NUM_BLOCK,\n CBLOCK=CBLOCK,\n NUM_CBLOCK=NUM_CBLOCK,\n )\n\n return dq, dk, dv, None, None\n\nlightning_attn_ = LightningAttention.apply\n\ndef lightning_attn(q, k, v, ed):\n d = q.shape[-1]\n e = v.shape[-1]\n if d >= 128:\n m = 128\n else:\n m = 64\n arr = [m * i for i in range(d // m + 1)]\n if arr[-1] != d:\n arr.append(d)\n n = len(arr)\n output = 0\n for i in range(n - 1):\n s = arr[i]\n e = arr[i + 1]\n q1 = q[..., s:e]\n k1 = k[..., s:e]\n\n o = lightning_attn_(q1, k1, v, ed)\n output = output + o\n\n return output\n", - "description_1": "Use triton language to implement a forward and backward pass for a custom attention mechanism. The forward kernel (_fwd_kernel) takes 13 parameters: Q, K, V, Out, S, and 8 constexpr parameters (b, h, n, d, e, BLOCK, NUM_BLOCK, BLOCK_MODEL). It computes the attention output using block-wise operations. The backward kernels (_bwd_intra_kernel and _bwd_inter_kernel) take 15 parameters: Q, K, V, S, DO, DQ, DK, DV, and 7 constexpr parameters (b, h, n, d, e, BLOCK, NUM_BLOCK, CBLOCK, NUM_CBLOCK). They compute the gradients for Q, K, and V using intra-block and inter-block operations. The LightningAttention class wraps these kernels for use in PyTorch's autograd system, with forward and backward methods handling the data preparation and kernel invocation.", - "description_2": "Use triton language to create a custom attention mechanism with forward and backward passes. The forward pass computes attention using block-wise operations, while the backward pass calculates gradients for input tensors. Implement this using triton kernels and integrate with PyTorch autograd.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n# Triton kernel for forward pass of Simple RMS Norm\n@triton.jit\ndef srms_norm_fw(X, Y, V, stride, N, eps, BLOCK_SIZE_N: tl.constexpr):\n row = tl.program_id(0)\n cols = tl.arange(0, BLOCK_SIZE_N)\n mask = cols < N\n x_ptrs = X + row * stride + cols\n x = tl.load(x_ptrs, mask=mask, other=0.0).to(tl.float32)\n x_zm = tl.where(mask, x, 0.0)\n x_var = tl.sum(x_zm * x_zm, axis=0) / N\n rstd = 1.0 / tl.sqrt(x_var + eps)\n y = x_zm * rstd\n tl.store(V + row, rstd)\n y_ptrs = Y + row * stride + cols\n tl.store(y_ptrs, y, mask=mask)\n\n# Triton kernel for backward pass (DX) of Simple RMS Norm\n@triton.jit\ndef srms_norm_bwd_dx_fused(\n DX, DY, X, V, stride, N, BLOCK_SIZE_N: tl.constexpr,\n):\n row = tl.program_id(0)\n cols = tl.arange(0, BLOCK_SIZE_N)\n mask = cols < N\n x_ptrs = X + row * stride + cols\n dy_ptrs = DY + row * stride + cols\n x = tl.load(x_ptrs, mask=mask, other=0)\n dy = tl.load(dy_ptrs, mask=mask, other=0)\n rstd = tl.load(V + row)\n xhat = x * rstd\n wdy = dy\n xhat = tl.where(mask, xhat, 0.)\n wdy = tl.where(mask, wdy, 0.)\n mean1 = tl.sum(xhat * wdy, axis=0) / N\n dx = (wdy - (xhat * mean1)) * rstd\n mask = cols < N\n dx_ptrs = DX + row * stride + cols\n tl.store(dx_ptrs, dx, mask=mask)\n\nclass _SrmsNorm(torch.autograd.Function):\n @staticmethod\n def forward(ctx, x, eps):\n if x.dtype == torch.float16:\n eps = max(eps, 1.6e-5)\n y = torch.empty_like(x)\n x_arg = x.reshape(-1, x.shape[-1])\n M, N = x_arg.shape\n rstd = torch.empty((M,), dtype=torch.float32, device=x.device)\n MAX_FUSED_SIZE = 65536 // x.element_size()\n BLOCK_SIZE_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))\n if N > BLOCK_SIZE_N:\n raise RuntimeError(\"This layer norm doesn't support feature dim >= 64KB.\")\n if not x_arg.is_contiguous() or not y.is_contiguous():\n x_arg = x_arg.contiguous()\n y = y.contiguous()\n num_warps = min(max(BLOCK_SIZE_N // 256, 1), 16)\n srms_norm_fw[(M,)](\n x_arg, y, rstd, x_arg.stride(0), N, eps, num_warps=num_warps, BLOCK_SIZE_N=BLOCK_SIZE_N,\n )\n ctx.save_for_backward(x, rstd)\n ctx.BLOCK_SIZE_N = BLOCK_SIZE_N\n ctx.num_warps = num_warps\n return y.reshape_as(x)\n\n @staticmethod\n def backward(ctx, dy):\n x, rstd = ctx.saved_tensors\n x = x.reshape(-1, x.size(-1))\n M, N = x.size()\n GROUP_SIZE_M = 32\n if N <= 8192:\n GROUP_SIZE_M = 64\n if N <= 4096:\n GROUP_SIZE_M = 96\n if N <= 2048:\n GROUP_SIZE_M = 128\n if N <= 1024:\n GROUP_SIZE_M = 256\n if dy.dtype == torch.float32:\n GROUP_SIZE_M = GROUP_SIZE_M // 2\n dy = dy.contiguous()\n dx = torch.empty_like(dy)\n assert (\n dy.numel() == x.numel()\n ), \"Something is wrong in the backward graph, possibly because of an inplace operation after the layernorm\"\n num_warps = min(max(ctx.BLOCK_SIZE_N // 256, 1), 16)\n srms_norm_bwd_dx_fused[(M,)](\n dx, dy, x, rstd, x.stride(0), N, BLOCK_SIZE_N=ctx.BLOCK_SIZE_N, num_warps=num_warps\n )\n dx = dx.reshape_as(dy)\n return dx, None, None\n\nclass SimpleRMSNorm(torch.nn.Module):\n def __init__(self, dim: int, eps: float = 1e-6):\n super().__init__()\n self.eps = eps\n self.dim = dim\n\n def forward(self, x):\n return _SrmsNorm.apply(x, self.eps)\n", - "description_1": "Use triton language to implement a simple RMS normalization operation. This includes two kernels: one for the forward pass (srms_norm_fw) which normalizes input tensors based on their variance and a given epsilon, and another for the backward pass (srms_norm_bwd_dx_fused) which computes the gradients of the input tensors based on the gradients of the output tensors. The kernels require handling of tensor shapes, masking to handle non-square tensors, and efficient memory access patterns. The operation also involves using a custom PyTorch autograd function for integration with PyTorch, encapsulating the forward and backward passes with proper context management. Each kernel needs specific tuning parameters like BLOCK_SIZE_N to determine the number of elements processed in parallel and num_warps for controlling parallel execution.", - "description_2": "Use triton language to implement a forward and backward RMS normalization operation in PyTorch using custom autograd functions, optimizing for memory access and parallel execution.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef triton_cross_scan(\n x, # (B, C, H, W)\n y, # (B, 4, C, H, W)\n BC: tl.constexpr,\n BH: tl.constexpr,\n BW: tl.constexpr,\n DC: tl.constexpr,\n DH: tl.constexpr,\n DW: tl.constexpr,\n NH: tl.constexpr,\n NW: tl.constexpr,\n):\n i_hw, i_c, i_b = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n i_h, i_w = (i_hw // NW), (i_hw % NW)\n _mask_h = (i_h * BH + tl.arange(0, BH)) < DH\n _mask_w = (i_w * BW + tl.arange(0, BW)) < DW\n _mask_hw = _mask_h[:, None] & _mask_w[None, :]\n _for_C = min(DC - i_c * BC, BC)\n\n _tmp0 = i_c * BC * DH * DW\n _tmp1 = DC * DH * DW\n _tmp2 = _tmp0 + i_h * BH * DW + tl.arange(0, BH)[:, None] * DW + i_w * BW + tl.arange(0, BW)[None, :]\n p_x = x + i_b * _tmp1 + _tmp2\n p_y1 = y + i_b * 4 * _tmp1 + _tmp2 # same\n p_y2 = y + i_b * 4 * _tmp1 + _tmp1 + _tmp0 + i_w * BW * DH + tl.arange(0, BW)[None, :] * DH + i_h * BH + tl.arange(\n 0, BH)[:, None] # trans\n p_y3 = y + i_b * 4 * _tmp1 + 2 * _tmp1 + _tmp0 + (NH - i_h - 1) * BH * DW + (\n BH - 1 - tl.arange(0, BH)[:, None]) * DW + (NW - i_w - 1) * BW + (\n BW - 1 - tl.arange(0, BW)[None, :]) + (DH - NH * BH) * DW + (DW - NW * BW) # flip\n p_y4 = y + i_b * 4 * _tmp1 + 3 * _tmp1 + _tmp0 + (NW - i_w - 1) * BW * DH + (\n BW - 1 - tl.arange(0, BW)[None, :]) * DH + (NH - i_h - 1) * BH + (\n BH - 1 - tl.arange(0, BH)[:, None]) + (DH - NH * BH) + (DW - NW * BW) * DH # trans + flip\n\n for idxc in range(_for_C):\n _idx = idxc * DH * DW\n _x = tl.load(p_x + _idx, mask=_mask_hw)\n tl.store(p_y1 + _idx, _x, mask=_mask_hw)\n tl.store(p_y2 + _idx, _x, mask=_mask_hw)\n tl.store(p_y3 + _idx, _x, mask=_mask_hw)\n tl.store(p_y4 + _idx, _x, mask=_mask_hw)\n\n\n@triton.jit\ndef triton_cross_merge(\n x, # (B, C, H, W)\n y, # (B, 4, C, H, W)\n BC: tl.constexpr,\n BH: tl.constexpr,\n BW: tl.constexpr,\n DC: tl.constexpr,\n DH: tl.constexpr,\n DW: tl.constexpr,\n NH: tl.constexpr,\n NW: tl.constexpr,\n):\n i_hw, i_c, i_b = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n i_h, i_w = (i_hw // NW), (i_hw % NW)\n _mask_h = (i_h * BH + tl.arange(0, BH)) < DH\n _mask_w = (i_w * BW + tl.arange(0, BW)) < DW\n _mask_hw = _mask_h[:, None] & _mask_w[None, :]\n _for_C = min(DC - i_c * BC, BC)\n\n _tmp0 = i_c * BC * DH * DW\n _tmp1 = DC * DH * DW\n _tmp2 = _tmp0 + i_h * BH * DW + tl.arange(0, BH)[:, None] * DW + i_w * BW + tl.arange(0, BW)[None, :]\n p_x = x + i_b * _tmp1 + _tmp2\n p_y1 = y + i_b * 4 * _tmp1 + _tmp2 # same\n p_y2 = y + i_b * 4 * _tmp1 + _tmp1 + _tmp0 + i_w * BW * DH + tl.arange(0, BW)[None, :] * DH + i_h * BH + tl.arange(\n 0, BH)[:, None] # trans\n p_y3 = y + i_b * 4 * _tmp1 + 2 * _tmp1 + _tmp0 + (NH - i_h - 1) * BH * DW + (\n BH - 1 - tl.arange(0, BH)[:, None]) * DW + (NW - i_w - 1) * BW + (\n BW - 1 - tl.arange(0, BW)[None, :]) + (DH - NH * BH) * DW + (DW - NW * BW) # flip\n p_y4 = y + i_b * 4 * _tmp1 + 3 * _tmp1 + _tmp0 + (NW - i_w - 1) * BW * DH + (\n BW - 1 - tl.arange(0, BW)[None, :]) * DH + (NH - i_h - 1) * BH + (\n BH - 1 - tl.arange(0, BH)[:, None]) + (DH - NH * BH) + (DW - NW * BW) * DH # trans + flip\n\n for idxc in range(_for_C):\n _idx = idxc * DH * DW\n _y1 = tl.load(p_y1 + _idx, mask=_mask_hw)\n _y2 = tl.load(p_y2 + _idx, mask=_mask_hw)\n _y3 = tl.load(p_y3 + _idx, mask=_mask_hw)\n _y4 = tl.load(p_y4 + _idx, mask=_mask_hw)\n tl.store(p_x + _idx, _y1 + _y2 + _y3 + _y4, mask=_mask_hw)\n\nclass CrossScanTriton(torch.autograd.Function):\n @staticmethod\n def forward(ctx, x: torch.Tensor):\n B, C, H, W = x.shape\n B, C, H, W = int(B), int(C), int(H), int(W)\n BC, BH, BW = min(triton.next_power_of_2(C), 2), min(triton.next_power_of_2(H), 32), min(\n triton.next_power_of_2(W), 32)\n NH, NW, NC = triton.cdiv(H, BH), triton.cdiv(W, BW), triton.cdiv(C, BC)\n ctx.shape = (B, C, H, W)\n ctx.triton_shape = (BC, BH, BW, NC, NH, NW)\n x = x.contiguous()\n y = x.new_empty((B, 4, C, H, W))\n triton_cross_scan[(NH * NW, NC, B)](x, y, BC, BH, BW, C, H, W, NH, NW)\n return y.view(B, 4, C, -1)\n\n @staticmethod\n def backward(ctx, y: torch.Tensor):\n B, C, H, W = ctx.shape\n BC, BH, BW, NC, NH, NW = ctx.triton_shape\n y = y.contiguous().view(B, 4, C, H, W)\n x = y.new_empty((B, C, H, W))\n triton_cross_merge[(NH * NW, NC, B)](x, y, BC, BH, BW, C, H, W, NH, NW)\n return x\n\nclass CrossMergeTriton(torch.autograd.Function):\n @staticmethod\n def forward(ctx, y: torch.Tensor):\n B, K, C, H, W = y.shape\n B, C, H, W = int(B), int(C), int(H), int(W)\n BC, BH, BW = min(triton.next_power_of_2(C), 2), min(triton.next_power_of_2(H), 32), min(\n triton.next_power_of_2(W), 32)\n NH, NW, NC = triton.cdiv(H, BH), triton.cdiv(W, BW), triton.cdiv(C, BC)\n ctx.shape = (B, C, H, W)\n ctx.triton_shape = (BC, BH, BW, NC, NH, NW)\n y = y.contiguous().view(B, 4, C, H, W)\n x = y.new_empty((B, C, H, W))\n triton_cross_merge[(NH * NW, NC, B)](x, y, BC, BH, BW, C, H, W, NH, NW)\n return x.view(B, C, -1)\n\n @staticmethod\n def backward(ctx, x: torch.Tensor):\n B, C, H, W = ctx.shape\n BC, BH, BW, NC, NH, NW = ctx.triton_shape\n x = x.contiguous()\n y = x.new_empty((B, 4, C, H, W))\n triton_cross_scan[(NH * NW, NC, B)](x, y, BC, BH, BW, C, H, W, NH, NW)\n return y\n", - "description_1": "Use triton language to implement two kernels: triton_cross_scan and triton_cross_merge. The triton_cross_scan kernel takes 10 parameters: x (input tensor of shape (B, C, H, W)), y (output tensor of shape (B, 4, C, H, W)), and 8 constexpr parameters (BC, BH, BW, DC, DH, DW, NH, NW) which define block sizes and dimensions. It performs a cross scan operation on the input tensor x and stores the result in y. The triton_cross_merge kernel also takes 10 parameters with the same meanings and performs a merge operation on the input tensor y, storing the result in x. Both kernels use triton's parallel programming model to efficiently handle large tensor operations.", - "description_2": "Use triton language to create two kernels for cross scan and merge operations on tensors. The kernels should handle input and output tensors with specific shapes and use constexpr parameters to define block sizes and dimensions for efficient parallel computation.", - "difficulty": 4 - }, - { - "code": "import triton\nimport triton.language as tl\n\n# Kernel function to perform element-wise addition of two vectors\n@triton.jit\ndef vector_add_kernel(X, Y, Z, N, BLOCK_SIZE: tl.constexpr):\n pid = tl.program_id(axis=0)\n block_start = pid * BLOCK_SIZE\n offsets = block_start + tl.arange(0, BLOCK_SIZE)\n mask = offsets < N\n x = tl.load(X + offsets, mask=mask)\n y = tl.load(Y + offsets, mask=mask)\n z = x + y\n tl.store(Z + offsets, z, mask=mask)\n\n# Function to launch the kernel\ndef vector_add(X, Y, Z, N, BLOCK_SIZE=1024):\n grid = lambda meta: (triton.cdiv(N, meta['BLOCK_SIZE']),)\n vector_add_kernel[grid](X, Y, Z, N, BLOCK_SIZE)\n", - "description_1": "Use triton language to implement a kernel function 'vector_add_kernel' that performs element-wise addition of two vectors X and Y, storing the result in vector Z. The kernel is launched using the 'vector_add' function, which calculates the grid size based on the input size N and a specified BLOCK_SIZE.", - "description_2": "Use triton language to create a vector addition kernel and a corresponding launch function to perform element-wise addition of two input vectors.", - "difficulty": 1 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\nfrom exposer.ops.triton.matmul_perf_model import early_config_prune, estimate_matmul_time\n\n@triton.autotune(\n configs=[\n triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_M': 256, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8),\n triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),\n ],\n key=['M', 'N', 'K'],\n prune_configs_by={\n 'early_config_prune': early_config_prune,\n 'perf_model': estimate_matmul_time,\n 'top_k': 10,\n },\n)\n@triton.heuristics({\n 'EVEN_K': lambda args: args['K'] % (args['BLOCK_K'] * args['SPLIT_K']) == 0,\n})\n@triton.jit\ndef _fc1_fwd_kernel(A, B, C, M, N, K,\n stride_am, stride_ak,\n stride_bk, stride_bn,\n stride_cm, stride_cn,\n NZ_BLOCK_INDICES,\n NUM_NZ_BLOCKS: tl.constexpr,\n acc_dtype: tl.constexpr,\n allow_tf32: tl.constexpr,\n fp8_fast_accum: tl.constexpr,\n BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,\n GROUP_M: tl.constexpr, SPLIT_K: tl.constexpr, EVEN_K: tl.constexpr, AB_DTYPE: tl.constexpr\n ):\n # matrix multiplication\n pid_z = tl.program_id(1)\n pid = tl.program_id(0)\n grid_m = tl.cdiv(M, BLOCK_M)\n grid_n = NUM_NZ_BLOCKS\n width = GROUP_M * grid_n\n group_id = pid // width\n group_size = min(grid_m - group_id * GROUP_M, GROUP_M)\n pid_m = group_id * GROUP_M + (pid % group_size)\n pid_n = (pid % width) // (group_size)\n rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)\n block_offset = tl.load(NZ_BLOCK_INDICES + pid_n)\n rn = block_offset * BLOCK_N + tl.arange(0, BLOCK_N)\n ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)\n rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)\n rk = pid_z * BLOCK_K + tl.arange(0, BLOCK_K)\n A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)\n B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)\n acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=acc_dtype)\n for k in range(0, tl.cdiv(K, BLOCK_K * SPLIT_K)):\n if EVEN_K:\n a = tl.load(A)\n b = tl.load(B)\n else:\n k_remaining = K - k * (BLOCK_K * SPLIT_K)\n _0 = tl.zeros((1, 1), dtype=C.dtype.element_ty)\n a = tl.load(A, mask=rk[None, :] < k_remaining, other=_0)\n b = tl.load(B, mask=rk[:, None] < k_remaining, other=_0)\n if AB_DTYPE is not None:\n a = a.to(AB_DTYPE)\n b = b.to(AB_DTYPE)\n if fp8_fast_accum:\n acc = tl.dot(a, b, acc, out_dtype=acc_dtype, allow_tf32=allow_tf32)\n else:\n acc += tl.dot(a, b, out_dtype=acc_dtype, allow_tf32=allow_tf32)\n A += BLOCK_K * SPLIT_K * stride_ak\n B += BLOCK_K * SPLIT_K * stride_bk\n acc = acc.to(C.dtype.element_ty)\n rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)\n rn = block_offset * BLOCK_N + tl.arange(0, BLOCK_N)\n C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)\n mask = (rm < M)[:, None] & (rn < N)[None, :]\n if SPLIT_K == 1:\n tl.store(C, acc, mask=mask)\n else:\n tl.atomic_add(C, acc, mask=mask)\n\n\n@triton.autotune(\n configs=[\n triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_M': 256, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8),\n triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),\n ],\n key=['M', 'N', 'K'],\n prune_configs_by={\n 'early_config_prune': early_config_prune,\n 'perf_model': estimate_matmul_time,\n 'top_k': 10,\n },\n)\n@triton.heuristics({\n 'EVEN_K': lambda args: args['K'] % (args['BLOCK_K'] * args['SPLIT_K']) == 0,\n})\n@triton.jit\ndef _fc1_bwd_kernel(A, B, C, M, N, K,\n stride_am, stride_ak,\n stride_bk, stride_bn,\n stride_cm, stride_cn,\n NZ_BLOCK_INDICES,\n NUM_NZ_BLOCKS: tl.constexpr,\n acc_dtype: tl.constexpr,\n allow_tf32: tl.constexpr,\n fp8_fast_accum: tl.constexpr,\n BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,\n GROUP_M: tl.constexpr, SPLIT_K: tl.constexpr, EVEN_K: tl.constexpr, AB_DTYPE: tl.constexpr\n ):\n pid_z = tl.program_id(1)\n pid = tl.program_id(0)\n grid_m = tl.cdiv(M, BLOCK_M)\n grid_n = tl.cdiv(N, BLOCK_N)\n width = GROUP_M * grid_n\n group_id = pid // width\n group_size = min(grid_m - group_id * GROUP_M, GROUP_M)\n pid_m = group_id * GROUP_M + (pid % group_size)\n pid_n = (pid % width) // (group_size)\n rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)\n rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)\n ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)\n rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)\n rk = pid_z * BLOCK_K + tl.arange(0, BLOCK_K)\n A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)\n B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)\n acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=acc_dtype)\n STRIDE_A = BLOCK_K * SPLIT_K * stride_ak\n STRIDE_B = BLOCK_K * SPLIT_K * stride_bk\n for k in range(0, NUM_NZ_BLOCKS):\n block_offset = tl.load(NZ_BLOCK_INDICES + k)\n _A = A + STRIDE_A * block_offset\n _B = B + STRIDE_B * block_offset\n if EVEN_K:\n a = tl.load(_A)\n b = tl.load(_B)\n else:\n k_remaining = K - k * (BLOCK_K * SPLIT_K)\n _0 = tl.zeros((1, 1), dtype=C.dtype.element_ty)\n a = tl.load(_A, mask=rk[None, :] < k_remaining, other=_0)\n b = tl.load(_B, mask=rk[:, None] < k_remaining, other=_0)\n if AB_DTYPE is not None:\n a = a.to(AB_DTYPE)\n b = b.to(AB_DTYPE)\n if fp8_fast_accum:\n acc = tl.dot(a, b, acc, out_dtype=acc_dtype, allow_tf32=allow_tf32)\n else:\n acc += tl.dot(a, b, out_dtype=acc_dtype, allow_tf32=allow_tf32)\n acc = acc.to(C.dtype.element_ty)\n rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)\n rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)\n C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)\n mask = (rm < M)[:, None] & (rn < N)[None, :]\n if SPLIT_K == 1:\n tl.store(C, acc, mask=mask)\n else:\n tl.atomic_add(C, acc, mask=mask)\n\nclass _fc1_matmul(torch.autograd.Function):\n @staticmethod\n def _fwd(a, b, NZ_BLOCK_INDICES, NUM_NZ_BLOCKS, acc_dtype, allow_tf32, fp8_fast_accum, output_dtype):\n device = a.device\n if a.stride(0) > 1 and a.stride(1) > 1:\n a = a.contiguous()\n if b.stride(0) > 1 and b.stride(1) > 1:\n b = b.contiguous()\n assert a.shape[1] == b.shape[0], \"incompatible dimensions\"\n M, K = a.shape\n _, N = b.shape\n ab_dtype = get_higher_dtype(a.dtype, b.dtype)\n if (output_dtype is None):\n output_dtype = ab_dtype\n c = torch.empty((M, N), device=device, dtype=output_dtype)\n supported_acc_dtypes = {\n torch.float16: (torch.float32, torch.float16), torch.bfloat16: (torch.float32, torch.bfloat16),\n torch.float32: (torch.float32, ), torch.int8: (torch.int32, )\n }\n if acc_dtype is None:\n acc_dtype = supported_acc_dtypes[ab_dtype][0]\n else:\n assert isinstance(acc_dtype, torch.dtype), \"acc_dtype must be a torch.dtype\"\n assert acc_dtype in supported_acc_dtypes[a.dtype], \"acc_dtype not compatible with the type of a\"\n assert acc_dtype in supported_acc_dtypes[b.dtype], \"acc_dtype not compatible with the type of b\"\n\n def to_tl_type(ty):\n return getattr(tl, str(ty).split(\".\")[-1])\n\n acc_dtype = to_tl_type(acc_dtype)\n ab_dtype = to_tl_type(ab_dtype)\n output_dtype = to_tl_type(output_dtype)\n if a.dtype in [tl.float8e4nv, tl.float8e5] and b.dtype in [tl.float8e4nv, tl.float8e5]:\n ab_dtype = None\n grid = lambda META: (triton.cdiv(M, META['BLOCK_M']) * NUM_NZ_BLOCKS, META['SPLIT_K'])\n _fc1_fwd_kernel[grid](\n a, b, c, M, N, K,\n a.stride(0), a.stride(1),\n b.stride(0), b.stride(1),\n c.stride(0), c.stride(1),\n NZ_BLOCK_INDICES,\n NUM_NZ_BLOCKS=NUM_NZ_BLOCKS,\n acc_dtype=acc_dtype,\n allow_tf32=allow_tf32,\n fp8_fast_accum=fp8_fast_accum,\n GROUP_M=8, AB_DTYPE=ab_dtype\n )\n return c\n\n @staticmethod\n def _bwd(a, b, NZ_BLOCK_INDICES, NUM_NZ_BLOCKS, acc_dtype, allow_tf32, fp8_fast_accum, output_dtype):\n device = a.device\n if a.stride(0) > 1 and a.stride(1) > 1:\n a = a.contiguous()\n if b.stride(0) > 1 and b.stride(1) > 1:\n b = b.contiguous()\n assert a.shape[1] == b.shape[0], \"incompatible dimensions\"\n M, K = a.shape\n _, N = b.shape\n ab_dtype = get_higher_dtype(a.dtype, b.dtype)\n if (output_dtype is None):\n output_dtype = ab_dtype\n c = torch.empty((M, N), device=device, dtype=output_dtype)\n supported_acc_dtypes = {\n torch.float16: (torch.float32, torch.float16), torch.bfloat16: (torch.float32, torch.bfloat16),\n torch.float32: (torch.float32, ), torch.int8: (torch.int32, )\n }\n if acc_dtype is None:\n acc_dtype = supported_acc_dtypes[ab_dtype][0]\n else:\n assert isinstance(acc_dtype, torch.dtype), \"acc_dtype must be a torch.dtype\"\n assert acc_dtype in supported_acc_dtypes[a.dtype], \"acc_dtype not compatible with the type of a\"\n assert acc_dtype in supported_acc_dtypes[b.dtype], \"acc_dtype not compatible with the type of b\"\n\n def to_tl_type(ty):\n return getattr(tl, str(ty).split(\".\")[-1])\n\n acc_dtype = to_tl_type(acc_dtype)\n ab_dtype = to_tl_type(ab_dtype)\n output_dtype = to_tl_type(output_dtype)\n if a.dtype in [tl.float8e4nv, tl.float8e5] and b.dtype in [tl.float8e4nv, tl.float8e5]:\n ab_dtype = None\n grid = lambda META: (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), META['SPLIT_K'])\n _fc1_bwd_kernel[grid](\n a, b, c, M, N, K,\n a.stride(0), a.stride(1),\n b.stride(0), b.stride(1),\n c.stride(0), c.stride(1),\n NZ_BLOCK_INDICES,\n NUM_NZ_BLOCKS=NUM_NZ_BLOCKS,\n acc_dtype=acc_dtype,\n allow_tf32=allow_tf32,\n fp8_fast_accum=fp8_fast_accum,\n GROUP_M=8, AB_DTYPE=ab_dtype\n )\n return c\n\n @staticmethod\n def forward(ctx, a, b, NZ_BLOCK_INDICES, NUM_NZ_BLOCKS, acc_dtype=None, allow_tf32=True, fp8_fast_accum=True, output_dtype=None):\n ctx.save_for_backward(a, b)\n ctx.NZ_BLOCK_INDICES = NZ_BLOCK_INDICES\n ctx.NUM_NZ_BLOCKS = NUM_NZ_BLOCKS\n return _fc1_matmul._fwd(a, b, NZ_BLOCK_INDICES, NUM_NZ_BLOCKS, \n acc_dtype=acc_dtype, allow_tf32=allow_tf32, \n fp8_fast_accum=fp8_fast_accum, output_dtype=output_dtype)\n\n @staticmethod\n def backward(ctx, dc):\n a, b = ctx.saved_tensors\n NZ_BLOCK_INDICES = ctx.NZ_BLOCK_INDICES\n NUM_NZ_BLOCKS = ctx.NUM_NZ_BLOCKS\n grad_a = None\n if ctx.needs_input_grad[0]:\n grad_a = _fc1_matmul._bwd(dc, b.t(), NZ_BLOCK_INDICES, NUM_NZ_BLOCKS, acc_dtype=dc.dtype, allow_tf32=True, fp8_fast_accum=True, output_dtype=a.dtype)\n if ctx.needs_input_grad[1]:\n NotImplementedError(\"backward for b is not implemented, as it is not needed in PEFT\")\n return grad_a, None, None, None\n\nfc1_matmul = _fc1_matmul.apply\n", - "description_1": "Use triton language to implement a matrix multiplication kernel '_fc1_fwd_kernel' with parameters: 23 parameters including input matrices A, B, output matrix C, dimensions M, N, K, and several stride and block-related settings. The kernel includes options for data type accumulation, tensor core usage, and other optimizations. Implement another kernel '_fc1_bwd_kernel' with 23 similar parameters for backward computation in matrix multiplication. Create a PyTorch autograd Function class '_fc1_matmul' to wrap these kernels, including a forward method for computation and a backward method to calculate gradients.", - "description_2": "Use triton language to implement forward and backward matrix multiplication kernels with configurable parameters and PyTorch integration using custom autograd functions.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\nfrom exposer.ops.triton.matmul_perf_model import early_config_prune, estimate_matmul_time\n\n\n@triton.autotune(\n configs=[\n triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8),\n triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),\n ],\n key=['M', 'N', 'K'],\n prune_configs_by={\n 'early_config_prune': early_config_prune,\n 'perf_model': estimate_matmul_time,\n 'top_k': 10,\n },\n)\n@triton.heuristics({\n 'EVEN_K': lambda args: args['K'] % (args['BLOCK_K'] * args['SPLIT_K']) == 0,\n})\n@triton.jit\ndef _fc2_fwd_kernel(A, B, C, M, N, K,\n stride_am, stride_ak,\n stride_bk, stride_bn,\n stride_cm, stride_cn,\n NZ_BLOCK_INDICES,\n NUM_NZ_BLOCKS: tl.constexpr,\n acc_dtype: tl.constexpr,\n allow_tf32: tl.constexpr,\n fp8_fast_accum: tl.constexpr,\n BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,\n GROUP_M: tl.constexpr, SPLIT_K: tl.constexpr, EVEN_K: tl.constexpr, AB_DTYPE: tl.constexpr\n ):\n pid_z = tl.program_id(1)\n pid = tl.program_id(0)\n grid_m = tl.cdiv(M, BLOCK_M)\n grid_n = tl.cdiv(N, BLOCK_N)\n width = GROUP_M * grid_n\n group_id = pid // width\n group_size = min(grid_m - group_id * GROUP_M, GROUP_M)\n pid_m = group_id * GROUP_M + (pid % group_size)\n pid_n = (pid % width) // (group_size)\n rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)\n rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)\n ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)\n rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)\n rk = pid_z * BLOCK_K + tl.arange(0, BLOCK_K)\n A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)\n B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)\n acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=acc_dtype)\n STRIDE_A = BLOCK_K * SPLIT_K * stride_ak\n STRIDE_B = BLOCK_K * SPLIT_K * stride_bk\n for k in range(0, NUM_NZ_BLOCKS):\n block_offset = tl.load(NZ_BLOCK_INDICES + k)\n _A = A + STRIDE_A * block_offset\n _B = B + STRIDE_B * block_offset\n if EVEN_K:\n a = tl.load(_A)\n b = tl.load(_B)\n else:\n k_remaining = K - k * (BLOCK_K * SPLIT_K)\n _0 = tl.zeros((1, 1), dtype=C.dtype.element_ty)\n a = tl.load(_A, mask=rk[None, :] < k_remaining, other=_0)\n b = tl.load(_B, mask=rk[:, None] < k_remaining, other=_0)\n if AB_DTYPE is not None:\n a = a.to(AB_DTYPE)\n b = b.to(AB_DTYPE)\n if fp8_fast_accum:\n acc = tl.dot(a, b, acc, out_dtype=acc_dtype, allow_tf32=allow_tf32)\n else:\n acc += tl.dot(a, b, out_dtype=acc_dtype, allow_tf32=allow_tf32)\n acc = acc.to(C.dtype.element_ty)\n rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)\n rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)\n C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)\n mask = (rm < M)[:, None] & (rn < N)[None, :]\n if SPLIT_K == 1:\n tl.store(C, acc, mask=mask)\n else:\n tl.atomic_add(C, acc, mask=mask)\n\n\n@triton.autotune(\n configs=[\n triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_M': 256, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8),\n triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),\n ],\n key=['M', 'N', 'K'],\n prune_configs_by={\n 'early_config_prune': early_config_prune,\n 'perf_model': estimate_matmul_time,\n 'top_k': 10,\n },\n)\n@triton.heuristics({\n 'EVEN_K': lambda args: args['K'] % (args['BLOCK_K'] * args['SPLIT_K']) == 0,\n})\n@triton.jit\ndef _fc2_bwd_kernel(A, B, C, M, N, K,\n stride_am, stride_ak,\n stride_bk, stride_bn,\n stride_cm, stride_cn,\n NZ_BLOCK_INDICES,\n NUM_NZ_BLOCKS: tl.constexpr,\n acc_dtype: tl.constexpr,\n allow_tf32: tl.constexpr,\n fp8_fast_accum: tl.constexpr,\n BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,\n GROUP_M: tl.constexpr, SPLIT_K: tl.constexpr, EVEN_K: tl.constexpr, AB_DTYPE: tl.constexpr\n ):\n pid_z = tl.program_id(1)\n pid = tl.program_id(0)\n grid_m = tl.cdiv(M, BLOCK_M)\n grid_n = NUM_NZ_BLOCKS\n width = GROUP_M * grid_n\n group_id = pid // width\n group_size = min(grid_m - group_id * GROUP_M, GROUP_M)\n pid_m = group_id * GROUP_M + (pid % group_size)\n pid_n = (pid % width) // (group_size)\n rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)\n block_offset = tl.load(NZ_BLOCK_INDICES + pid_n)\n rn = block_offset * BLOCK_N + tl.arange(0, BLOCK_N)\n ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)\n rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)\n rk = pid_z * BLOCK_K + tl.arange(0, BLOCK_K)\n A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)\n B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)\n acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=acc_dtype)\n for k in range(0, tl.cdiv(K, BLOCK_K * SPLIT_K)):\n if EVEN_K:\n a = tl.load(A)\n b = tl.load(B)\n else:\n k_remaining = K - k * (BLOCK_K * SPLIT_K)\n _0 = tl.zeros((1, 1), dtype=C.dtype.element_ty)\n a = tl.load(A, mask=rk[None, :] < k_remaining, other=_0)\n b = tl.load(B, mask=rk[:, None] < k_remaining, other=_0)\n if AB_DTYPE is not None:\n a = a.to(AB_DTYPE)\n b = b.to(AB_DTYPE)\n if fp8_fast_accum:\n acc = tl.dot(a, b, acc, out_dtype=acc_dtype, allow_tf32=allow_tf32)\n else:\n acc += tl.dot(a, b, out_dtype=acc_dtype, allow_tf32=allow_tf32)\n A += BLOCK_K * SPLIT_K * stride_ak\n B += BLOCK_K * SPLIT_K * stride_bk\n acc = acc.to(C.dtype.element_ty)\n rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)\n rn = block_offset * BLOCK_N + tl.arange(0, BLOCK_N)\n C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)\n mask = (rm < M)[:, None] & (rn < N)[None, :]\n if SPLIT_K == 1:\n tl.store(C, acc, mask=mask)\n else:\n tl.atomic_add(C, acc, mask=mask)\n\n\nclass _fc2_matmul(torch.autograd.Function):\n @staticmethod\n def _fwd(a, b, NZ_BLOCK_INDICES, NUM_NZ_BLOCKS, acc_dtype, allow_tf32, fp8_fast_accum, output_dtype):\n device = a.device\n if a.stride(0) > 1 and a.stride(1) > 1:\n a = a.contiguous()\n if b.stride(0) > 1 and b.stride(1) > 1:\n b = b.contiguous()\n assert a.shape[1] == b.shape[0], \"incompatible dimensions\"\n M, K = a.shape\n _, N = b.shape\n\n ab_dtype = get_higher_dtype(a.dtype, b.dtype)\n\n if (output_dtype is None):\n output_dtype = ab_dtype\n\n c = torch.empty((M, N), device=device, dtype=output_dtype)\n\n supported_acc_dtypes = {\n torch.float16: (torch.float32, torch.float16), torch.bfloat16: (torch.float32, torch.bfloat16),\n torch.float32: (torch.float32, ), torch.int8: (torch.int32, )\n }\n\n if acc_dtype is None:\n acc_dtype = supported_acc_dtypes[ab_dtype][0]\n else:\n assert isinstance(acc_dtype, torch.dtype), \"acc_dtype must be a torch.dtype\"\n assert acc_dtype in supported_acc_dtypes[a.dtype], \"acc_dtype not compatible with the type of a\"\n assert acc_dtype in supported_acc_dtypes[b.dtype], \"acc_dtype not compatible with the type of b\"\n\n def to_tl_type(ty):\n return getattr(tl, str(ty).split(\".\")[-1])\n\n acc_dtype = to_tl_type(acc_dtype)\n ab_dtype = to_tl_type(ab_dtype)\n output_dtype = to_tl_type(output_dtype)\n\n if a.dtype in [tl.float8e4nv, tl.float8e5] and b.dtype in [tl.float8e4nv, tl.float8e5]:\n ab_dtype = None\n\n grid = lambda META: (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), META['SPLIT_K'])\n _fc2_fwd_kernel[grid](\n a, b, c, M, N, K,\n a.stride(0), a.stride(1),\n b.stride(0), b.stride(1),\n c.stride(0), c.stride(1),\n NZ_BLOCK_INDICES,\n NUM_NZ_BLOCKS=NUM_NZ_BLOCKS,\n acc_dtype=acc_dtype,\n allow_tf32=allow_tf32,\n fp8_fast_accum=fp8_fast_accum,\n GROUP_M=8, AB_DTYPE=ab_dtype\n )\n return c\n\n @staticmethod\n def _bwd(a, b, NZ_BLOCK_INDICES, NUM_NZ_BLOCKS, acc_dtype, allow_tf32, fp8_fast_accum, output_dtype):\n device = a.device\n if a.stride(0) > 1 and a.stride(1) > 1:\n a = a.contiguous()\n if b.stride(0) > 1 and b.stride(1) > 1:\n b = b.contiguous()\n assert a.shape[1] == b.shape[0], \"incompatible dimensions\"\n M, K = a.shape\n _, N = b.shape\n\n ab_dtype = get_higher_dtype(a.dtype, b.dtype)\n\n if (output_dtype is None):\n output_dtype = ab_dtype\n\n c = torch.empty((M, N), device=device, dtype=output_dtype)\n\n supported_acc_dtypes = {\n torch.float16: (torch.float32, torch.float16), torch.bfloat16: (torch.float32, torch.bfloat16),\n torch.float32: (torch.float32, ), torch.int8: (torch.int32, )\n }\n\n if acc_dtype is None:\n acc_dtype = supported_acc_dtypes[ab_dtype][0]\n else:\n assert isinstance(acc_dtype, torch.dtype), \"acc_dtype must be a torch.dtype\"\n assert acc_dtype in supported_acc_dtypes[a.dtype], \"acc_dtype not compatible with the type of a\"\n assert acc_dtype in supported_acc_dtypes[b.dtype], \"acc_dtype not compatible with the type of b\"\n\n def to_tl_type(ty):\n return getattr(tl, str(ty).split(\".\")[-1])\n\n acc_dtype = to_tl_type(acc_dtype)\n ab_dtype = to_tl_type(ab_dtype)\n output_dtype = to_tl_type(output_dtype)\n\n if a.dtype in [tl.float8e4nv, tl.float8e5] and b.dtype in [tl.float8e4nv, tl.float8e5]:\n ab_dtype = None\n\n grid = lambda META: (triton.cdiv(M, META['BLOCK_M']) * NUM_NZ_BLOCKS, META['SPLIT_K'])\n _fc2_bwd_kernel[grid](\n a, b, c, M, N, K,\n a.stride(0), a.stride(1),\n b.stride(0), b.stride(1),\n c.stride(0), c.stride(1),\n NZ_BLOCK_INDICES,\n NUM_NZ_BLOCKS=NUM_NZ_BLOCKS,\n acc_dtype=acc_dtype,\n allow_tf32=allow_tf32,\n fp8_fast_accum=fp8_fast_accum,\n GROUP_M=8, AB_DTYPE=ab_dtype\n )\n return c\n\n @staticmethod\n def forward(ctx, a, b, NZ_BLOCK_INDICES, NUM_NZ_BLOCKS, acc_dtype=None, allow_tf32=True, fp8_fast_accum=True, output_dtype=None):\n ctx.save_for_backward(a, b)\n ctx.NZ_BLOCK_INDICES = NZ_BLOCK_INDICES\n ctx.NUM_NZ_BLOCKS = NUM_NZ_BLOCKS\n return _fc2_matmul._fwd(a, b, NZ_BLOCK_INDICES, NUM_NZ_BLOCKS, \n acc_dtype=acc_dtype, allow_tf32=allow_tf32, \n fp8_fast_accum=fp8_fast_accum, output_dtype=output_dtype)\n\n @staticmethod\n def backward(ctx, dc):\n a, b = ctx.saved_tensors\n NZ_BLOCK_INDICES = ctx.NZ_BLOCK_INDICES\n NUM_NZ_BLOCKS = ctx.NUM_NZ_BLOCKS\n grad_a = None\n if ctx.needs_input_grad[0]:\n grad_a = _fc2_matmul._bwd(dc, b.t(), NZ_BLOCK_INDICES, NUM_NZ_BLOCKS, acc_dtype=dc.dtype, allow_tf32=True, fp8_fast_accum=True, output_dtype=a.dtype)\n if ctx.needs_input_grad[1]:\n NotImplementedError(\"backward for b is not implemented, as it is not needed in PEFT\")\n return grad_a, None, None, None\n\n\nfc2_matmul = _fc2_matmul.apply\n", - "description_1": "Use triton language to implement matrix multiplication kernels (_fc2_fwd_kernel and _fc2_bwd_kernel) that handle forward and backward passes, respectively. These kernels take as input matrices A and B, produce output matrix C, and incorporate block sizes, strides, and non-zero block indices. The forward kernel supports mixed-precision operations with optional accumulation datatype, while the backward kernel computes gradients only for matrix A.", - "description_2": "Use triton language to create optimized matrix multiplication kernels for forward and backward passes, supporting mixed-precision and block-wise computation.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.heuristics({\n 'EVEN_K': lambda nargs: nargs['K'] % nargs['TILE_K'] == 0,\n})\n@triton.jit\ndef _sdd_kernel(A, B, C, #\n stride_za, stride_ha, stride_ma, stride_ak, #\n stride_zb, stride_hb, stride_bk, stride_nb, #\n stride_zc, stride_hc, stride_mc, stride_nc, #\n K, grid_offset, lut, #\n TILE_M: tl.constexpr, TILE_N: tl.constexpr, TILE_K: tl.constexpr, #\n BLOCK: tl.constexpr, EVEN_K: tl.constexpr #\n ):\n # Triton SDD Kernel: performs sparse-dense-dense matrix multiplication\n block_id = tl.program_id(0) + grid_offset\n lut += block_id * 3\n off_z = tl.program_id(2) # batch\n off_h = tl.load(lut + 0) # head\n start_am = tl.load(lut + 1)\n offs_am = start_am * BLOCK + (tl.arange(0, TILE_M) % BLOCK)\n offs_ak = tl.arange(0, TILE_K)\n a_ptrs = A \\\n + off_z * stride_za \\\n + off_h * stride_ha \\\n + offs_am[:, None] * stride_ma \\\n + offs_ak[None, :] * stride_ak\n start_bn = tl.load(lut + 2)\n offs_bn = start_bn * BLOCK + (tl.arange(0, TILE_N) % BLOCK)\n offs_bk = tl.arange(0, TILE_K)\n b_ptrs = B \\\n + off_z * stride_zb \\\n + off_h * stride_hb \\\n + offs_bn[None, :] * stride_nb \\\n + offs_bk[:, None] * stride_bk\n acc = tl.zeros((TILE_M, TILE_N), dtype=tl.float32)\n for k in range(K, 0, -TILE_K):\n if EVEN_K:\n a = tl.load(a_ptrs)\n b = tl.load(b_ptrs)\n else:\n a = tl.load(a_ptrs, mask=offs_ak[None, :] < k, other=0.)\n b = tl.load(b_ptrs, mask=offs_bk[:, None] < k, other=0.)\n acc += tl.dot(a, b, out_dtype=tl.float32)\n a_ptrs += TILE_K * stride_ak\n b_ptrs += TILE_K * stride_bk\n c = acc.to(C.dtype.element_ty)\n offs_cm = tl.arange(0, TILE_M) % BLOCK\n offs_cn = tl.arange(0, TILE_N) % BLOCK\n pc = C \\\n + off_z * stride_zc \\\n + block_id * stride_hc \\\n + offs_cm[:, None] * stride_mc \\\n + offs_cn[None, :] * stride_nc\n tl.store(pc, c, mask=True)\n\ndef sdd_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, widths, out=None):\n # SDD Matmul: orchestrates the execution of the SDD kernel\n if a.stride(2) != 1 and a.stride(3) != 1:\n a = a.contiguous()\n if b.stride(2) != 1 and b.stride(3) != 1:\n b = b.contiguous()\n if trans_c:\n a, b = b, a\n trans_a, trans_b = not trans_b, not trans_a\n a_dim = -2 if trans_a else -1\n b_dim = -1 if trans_b else -2\n Ka, Kb = a.shape[a_dim], b.shape[b_dim]\n if Ka != Kb:\n raise ValueError(f\"Inner dimension mismatch (A: {Ka} vs B: {Kb})\")\n if out is None:\n c = torch.empty((a.shape[0], lut.shape[0], block, block), dtype=a.dtype, device=a.device)\n else:\n assert out.shape == (a.shape[0], lut.shape[0], block, block)\n c = out\n grid = [c.shape[1], 1, c.shape[0]]\n _sdd_kernel[grid](\n a, b, c, #\n a.stride(0), a.stride(1), a.stride(3 if trans_a else 2), a.stride(2 if trans_a else 3), #\n b.stride(0), b.stride(1), b.stride(3 if trans_b else 2), b.stride(2 if trans_b else 3), #\n c.stride(0), c.stride(1), c.stride(2), c.stride(3), #\n Ka, 0, lut, #\n TILE_M=block, TILE_N=block, TILE_K=32, BLOCK=block, num_stages=4, #\n num_warps=4 #\n )\n return c\n\n@triton.jit\ndef _dsd_kernel(A, B, C, #\n stride_az, stride_ha, stride_am, stride_ak, #\n stride_zb, stride_hb, stride_bk, stride_bn, #\n stride_zc, stride_hc, stride_cm, stride_cn, #\n DS0, DS1, lut, #\n TILE_M: tl.constexpr, TILE_N: tl.constexpr, TILE_K: tl.constexpr, #\n GROUP_SIZE_M: tl.constexpr, BLOCK: tl.constexpr #\n ):\n # Triton DSD Kernel: performs dense-sparse-dense matrix multiplication\n pid_m = tl.program_id(0)\n pid_n = tl.program_id(1)\n num_pid_m = tl.num_programs(0)\n num_pid_n = tl.num_programs(1)\n pid_n, pid_m = tl.swizzle2d(pid_n, pid_m, num_pid_n, num_pid_m, GROUP_SIZE_M)\n pidz = tl.program_id(2)\n header = lut + pid_n * 4\n offset = tl.load(header + 0)\n K = tl.load(header + 1)\n column = tl.load(header + 2)\n off_h = tl.load(header + 3)\n pinc = lut + offset\n block_id = tl.load(pinc + 1)\n block_id = tl.multiple_of(block_id, 8) # compiler hint\n offs_am = tl.arange(0, TILE_M)\n offs_ak = tl.arange(0, TILE_K)\n pa = A + pidz * stride_az \\\n + block_id * stride_ha \\\n + offs_am[:, None] * stride_am \\\n + offs_ak[None, :] * stride_ak\n offs_bn = pid_m * TILE_N + tl.arange(0, TILE_N)\n offs_bn = tl.max_contiguous(tl.multiple_of(offs_bn % DS0, TILE_N), TILE_N)\n start_bk = tl.load(pinc)\n start_bk = tl.multiple_of(start_bk, 8) # compiler hint\n offs_bk = start_bk + tl.arange(0, TILE_K)\n pb = B + pidz * stride_zb \\\n + off_h * stride_hb \\\n + offs_bn[None, :] * stride_bn \\\n + offs_bk[:, None] * stride_bk\n acc = tl.zeros((TILE_M, TILE_N), dtype=tl.float32)\n pinc += 2\n inc_a = tl.load(pinc + 1)\n inc_a = tl.multiple_of(inc_a, 8)\n inc_b = tl.load(pinc)\n inc_b = tl.multiple_of(inc_b, 8)\n for k in range(K, 0, -TILE_K):\n a = tl.load(pa)\n b = tl.load(pb)\n acc += tl.dot(a, b, out_dtype=tl.float32)\n pa += inc_a\n pb += inc_b * stride_bk\n pinc += 2\n inc_a = tl.load(pinc + 1)\n inc_a = tl.multiple_of(inc_a, 8)\n inc_b = tl.load(pinc)\n inc_b = tl.multiple_of(inc_b, 8)\n c = acc.to(C.dtype.element_ty)\n offs_cm = column * TILE_M + tl.arange(0, TILE_M)\n offs_cn = pid_m * TILE_N + tl.arange(0, TILE_N)\n pc = C \\\n + off_h * stride_hc \\\n + pidz * stride_zc \\\n + offs_cm[:, None] * stride_cm \\\n + offs_cn[None, :] * stride_cn\n tl.store(pc, c, mask=offs_cn[None, :] < DS0)\n\ndef dsd_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, width, out=None):\n # DSD Matmul: orchestrates the execution of the DSD kernel\n if a.stride(2) != 1 and a.stride(3) != 1:\n a = a.contiguous()\n if b.stride(2) != 1 and b.stride(3) != 1:\n b = b.contiguous()\n AS1 = block * spdims[2 if trans_a else 1]\n BS0 = b.size(0)\n BS1 = b.size(1)\n BS3 = b.size(2 if trans_b else 3)\n dtype = a.dtype\n CS0 = BS0\n CS1 = BS1\n CS2 = BS3 if trans_c else AS1\n CS3 = AS1 if trans_c else BS3\n if out is None:\n c = torch.empty((CS0, CS1, CS2, CS3), dtype=dtype, device=a.device)\n else:\n assert out.shape == (CS0, CS1, CS2, CS3)\n c = out\n TILE_N = 128\n grid = lambda meta: [triton.cdiv(BS3, meta['TILE_N']), width, BS0]\n _dsd_kernel[grid](\n a, b, c, #\n a.stride(0), a.stride(1), a.stride(3 if trans_a else 2), a.stride(2 if trans_a else 3), #\n b.stride(0), b.stride(1), b.stride(3 if trans_b else 2), b.stride(2 if trans_b else 3), #\n c.stride(0), c.stride(1), c.stride(3 if trans_c else 2), c.stride(2 if trans_c else 3), #\n BS3, AS1, lut, #\n TILE_M=block, TILE_N=TILE_N, TILE_K=min(block, 32), BLOCK=block, num_stages=4, #\n num_warps=4, GROUP_SIZE_M=4 #\n )\n return c\n", - "description_1": "Use triton language to implement SDD and DSD matrix multiplication. The SDD kernel takes 16 arguments: A, B, C, 4 strides for A, 4 strides for B, 4 strides for C, integer K, grid_offset, LUT, and 4 constexpr TILE_M, TILE_N, TILE_K, BLOCK, EVEN_K, to perform sparse-dense-dense matmul. The DSD kernel takes 17 arguments: A, B, C, 4 strides for A, 4 strides for B, 4 strides for C, integers DS0, DS1, LUT, and 5 constexpr TILE_M, TILE_N, TILE_K, GROUP_SIZE_M, BLOCK, to perform dense-sparse-dense matmul. SDD and DSD matmul functions call their respective kernels, configuring grid and parameters.", - "description_2": "Use triton language to create kernels for SDD and DSD matrix multiplication and provide wrapper functions to execute these kernels with the appropriate grid and parameters.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef _blocksparse_softmax_fwd(Out, A, stride_xz, LUT, #\n R, extent, stride_zr, stride_hr, # relative attention\n scale, is_causal, #\n ROW_SIZE: tl.constexpr, #\n BLOCK_SIZE: tl.constexpr, #\n IS_DENSE: tl.constexpr #\n ):\n h = tl.program_id(0)\n m = tl.program_id(1)\n z = tl.program_id(2)\n # create index ranges\n hm = h * tl.num_programs(1) + m\n lane_n = tl.arange(0, ROW_SIZE) % BLOCK_SIZE\n block_n = tl.arange(0, ROW_SIZE) // BLOCK_SIZE\n # extract information from LUT\n header = LUT + (hm // BLOCK_SIZE) * 2\n size = tl.load(header + 0)\n offset = tl.load(header + 1)\n # pointer offset\n off_a = z * stride_xz\n off_a += (offset + block_n) * BLOCK_SIZE * BLOCK_SIZE # block indx\n off_a += (m % BLOCK_SIZE) * BLOCK_SIZE # row indx\n # do not need to read column indices in the dense case\n if IS_DENSE:\n ns = tl.arange(0, ROW_SIZE)\n else:\n off_lut = offset + 2 * tl.num_programs(0) * tl.num_programs(1) // BLOCK_SIZE\n start_n = tl.load(LUT + off_lut + block_n, mask=block_n < size, other=0)\n ns = start_n * BLOCK_SIZE + lane_n\n # load X\n mask = block_n < size\n a = tl.load(A + off_a + lane_n, mask=mask, other=-float(\"inf\"))\n a = a.to(tl.float32)\n # compute\n out = a\n out *= scale\n # apply relative attention\n if R is not None:\n R += z * stride_zr\n R += h * stride_hr\n off_lo = (extent - m - 1) + ns\n mask_lo = (off_lo >= 0) & (off_lo < extent)\n rel_logits = tl.load(R + m * extent + off_lo, mask=mask_lo, other=0.0)\n out += rel_logits\n out = out.to(tl.float32)\n # apply causal mask\n out = tl.where((ns > m) & is_causal, -float(\"inf\"), out)\n # computation\n out = tl.softmax(out)\n # write-back\n tl.store(Out + off_a + lane_n, out, mask=mask)\n\n@triton.jit\ndef _blocksparse_softmax_bwd(DA, stride_zdx, #\n DOut, stride_zdout, #\n Out, stride_zout, #\n scale, #\n LUT, #\n DR, extent, stride_zr, stride_hr, stride_er, #\n is_causal, #\n ROW_SIZE: tl.constexpr, #\n BLOCK_SIZE: tl.constexpr, #\n IS_DENSE: tl.constexpr):\n h = tl.program_id(0)\n m = tl.program_id(1)\n z = tl.program_id(2)\n # create index ranges\n hm = h * tl.num_programs(1) + m\n lane_n = tl.arange(0, ROW_SIZE) % BLOCK_SIZE\n block_n = tl.arange(0, ROW_SIZE) // BLOCK_SIZE\n # extract information from LUT\n header = LUT + (hm // BLOCK_SIZE) * 2\n size = tl.load(header + 0)\n offset = tl.load(header + 1)\n # row-col offset\n off_mn = (offset + block_n) * BLOCK_SIZE * BLOCK_SIZE\n off_mn += (m % BLOCK_SIZE) * BLOCK_SIZE\n mask = block_n < size\n # pointers\n As = Out + z * stride_zout + off_mn\n DOuts = DOut + z * stride_zdout + off_mn\n # do not need to read column indices in the dense case\n if IS_DENSE:\n ns = tl.arange(0, ROW_SIZE)\n else:\n off_lut = offset + 2 * tl.num_programs(0) * tl.num_programs(1) // BLOCK_SIZE\n start_n = tl.load(LUT + off_lut + block_n, mask=mask, other=0)\n ns = start_n * BLOCK_SIZE + lane_n\n # load data\n a = tl.load(As + lane_n, mask=mask, other=0.0)\n a = a.to(tl.float32)\n dout = tl.load(DOuts + lane_n, mask=mask, other=0.0)\n dout = dout.to(tl.float32)\n # compute\n a = tl.where((ns > m) & is_causal & (a == a), 0., a)\n da = a * (dout - tl.sum(a * dout, 0))\n # apply relative attention\n if DR is not None:\n DR += z * stride_zr\n DR += h * stride_hr\n off_lo = (extent - m - 1) + ns\n mask_lo = (off_lo >= 0) & (off_lo < extent) & mask\n tl.store(DR + m * extent + off_lo, da, mask=mask_lo)\n da = da * scale\n # convert da\n # write-back\n DAs = DA + z * stride_zdx + off_mn\n tl.store(DAs + lane_n, da, mask=mask)\n\nclass _softmax(torch.autograd.Function):\n\n @staticmethod\n def forward(ctx, a, scale, rel_logits, is_causal, spdims, block, lut, maxlut, is_dense):\n if scale is not None and isinstance(scale, torch.Tensor):\n assert scale.device.type == \"cpu\"\n scale = scale.item()\n M = a.shape[0]\n grid = [spdims[0], spdims[1] * block, M]\n rel_shape = (1, 1, 1, 1) if rel_logits is None else rel_logits.shape\n rel_strides = (1, 1, 1, 1) if rel_logits is None else rel_logits.stride()\n # enqueue kernel\n out = torch.empty_like(a)\n _blocksparse_softmax_fwd[grid](\n out, a, a.stride(0), lut, #\n rel_logits, rel_shape[-1], rel_strides[0], rel_strides[1], # relative attn#\n scale, #\n is_causal, #\n BLOCK_SIZE=block, #\n ROW_SIZE=triton.next_power_of_2(maxlut), #\n IS_DENSE=is_dense, #\n num_warps=num_warps(maxlut) #\n )\n # save to context\n # ctx.mark_dirty(x)\n ctx.save_for_backward(out, lut)\n ctx.spdims = spdims\n ctx.block = block\n ctx.maxlut = maxlut\n ctx.scale = scale\n ctx.rel_shape = rel_shape\n ctx.rel_strides = rel_strides\n ctx.rel_dtype = a.dtype\n ctx.is_dense = is_dense\n ctx.is_causal = is_causal\n return out\n\n @staticmethod\n def backward(ctx, dout):\n # retrieve from context\n out, lut = ctx.saved_tensors\n # relative logits gradients\n dr = None\n if ctx.needs_input_grad[3]:\n dr = torch.zeros(ctx.rel_shape, dtype=ctx.rel_dtype, device=out.device)\n # run kernel\n M = out.shape[0]\n grid = (ctx.spdims[0], ctx.spdims[1] * ctx.block, M)\n da = torch.empty_like(dout)\n _blocksparse_softmax_bwd[grid](\n da, da.stride(0), #\n dout, dout.stride(0), #\n out, out.stride(0), #\n ctx.scale, #\n lut, #\n dr, ctx.rel_shape[-1], ctx.rel_strides[0], ctx.rel_strides[1], ctx.rel_strides[2], #\n ctx.is_causal, #\n BLOCK_SIZE=ctx.block, #\n ROW_SIZE=triton.next_power_of_2(ctx.maxlut), #\n IS_DENSE=ctx.is_dense, #\n num_warps=num_warps(ctx.maxlut) #\n )\n return (da, None, None, dr, None, None, None, None, None, None, None, None, None, None, None, None, None, None)\n", - "description_1": "Use triton language to implement a block-sparse softmax forward and backward kernel. The forward kernel (_blocksparse_softmax_fwd) takes 12 parameters: Out (output tensor), A (input tensor), stride_xz (stride for input tensor), LUT (lookup table), R (relative attention tensor), extent (extent of relative attention), stride_zr (stride for relative attention), stride_hr (stride for relative attention), scale (scaling factor), is_causal (causal flag), ROW_SIZE (row size as constexpr), BLOCK_SIZE (block size as constexpr), and IS_DENSE (dense flag as constexpr). The backward kernel (_blocksparse_softmax_bwd) takes 15 parameters: DA (gradient of input tensor), stride_zdx (stride for DA), DOut (gradient of output tensor), stride_zdout (stride for DOut), Out (output tensor), stride_zout (stride for Out), scale (scaling factor), LUT (lookup table), DR (gradient of relative attention), extent (extent of relative attention), stride_zr (stride for relative attention), stride_hr (stride for relative attention), stride_er (stride for relative attention), is_causal (causal flag), ROW_SIZE (row size as constexpr), BLOCK_SIZE (block size as constexpr), and IS_DENSE (dense flag as constexpr).", - "description_2": "Use triton language to create a block-sparse softmax operation with forward and backward passes, handling relative attention and causal masking.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef _fwd_kernel(Q, K, V, sm_scale, \n L, \n Out, \n stride_qz, stride_qh, stride_qm, stride_qk, \n stride_kz, stride_kh, stride_kn, stride_kk, \n stride_vz, stride_vh, stride_vn, stride_vk, \n stride_oz, stride_oh, stride_om, stride_on, \n Z, H, N_CTX, \n Z_H_N_CTX, \n BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, \n BLOCK_N: tl.constexpr, \n IS_CAUSAL: tl.constexpr \n ):\n # Kernel logic\n\n@triton.jit\ndef _bwd_preprocess(\n Out,\n DO,\n Delta,\n BLOCK_M: tl.constexpr,\n D_HEAD: tl.constexpr,\n):\n # Kernel logic\n\n@triton.jit\ndef _bwd_kernel_one_col_block(Q, K, V, sm_scale, qk_scale, \n Out, DO, \n DQ, DK, DV, \n L, \n D, \n Q_block_ptr, K_block_ptr, V_block_ptr, \n DO_block_ptr, DQ_block_ptr, DK_block_ptr, DV_block_ptr, \n stride_dqa, stride_qz, stride_qh, stride_qm, stride_qk, \n stride_kz, stride_kh, stride_kn, stride_kk, \n stride_vz, stride_vh, stride_vn, stride_vk, \n Z, H, N_CTX, \n off_h, off_z, off_hz, start_n, num_block, \n BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, \n BLOCK_N: tl.constexpr, \n SEQUENCE_PARALLEL: tl.constexpr, \n CAUSAL: tl.constexpr, \n MMA_V3: tl.constexpr \n ):\n # Kernel logic\n\n@triton.jit\ndef _bwd_kernel(Q, K, V, sm_scale, \n Out, DO, \n DQ, DK, DV, \n L, \n D, \n stride_dqa, stride_qz, stride_qh, stride_qm, stride_qk, \n stride_kz, stride_kh, stride_kn, stride_kk, \n stride_vz, stride_vh, stride_vn, stride_vk, \n Z, H, N_CTX, \n Z_H_N_CTX, \n SQ_Z_H_N_CTX, \n BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, \n BLOCK_N: tl.constexpr, \n SEQUENCE_PARALLEL: tl.constexpr, \n CAUSAL: tl.constexpr, \n MMA_V3: tl.constexpr \n ):\n # Kernel logic\n", - "description_1": "Use triton language to implement forward and backward kernels for a sequence-to-sequence model, with optimizations for parallelism, causal masking, and various block configurations.", - "description_2": "Use triton language to implement forward and backward kernels for sequence modeling tasks, with optimizations for GPU parallelism and memory efficiency.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\ndef get_higher_dtype(a, b):\n _ordered_datatypes = [torch.int8, torch.float16, torch.bfloat16, torch.float32]\n def upcast_if_fp8(a):\n if \"fp8\" in str(a):\n return torch.float16\n return a\n\n a = upcast_if_fp8(a)\n b = upcast_if_fp8(b)\n if a is b:\n return a\n\n assert a in _ordered_datatypes\n assert b in _ordered_datatypes\n\n for d in _ordered_datatypes:\n if a is d:\n return b\n if b is d:\n return a\n\n@triton.autotune(\n configs=[\n triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8),\n triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8),\n triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=5, num_warps=2),\n triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=3, num_warps=8),\n triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=3, num_warps=8),\n triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=5, num_warps=2),\n ],\n key=['M', 'N', 'K'],\n prune_configs_by={\n 'early_config_prune': lambda *args: None,\n 'perf_model': lambda *args: None,\n 'top_k': 10,\n },\n)\n@triton.heuristics({\n 'EVEN_K': lambda args: args['K'] % (args['BLOCK_K'] * args['SPLIT_K']) == 0,\n})\n@triton.jit\ndef _kernel(A, B, C, M, N, K, #\n stride_am, stride_ak, #\n stride_bk, stride_bn, #\n stride_cm, stride_cn, #\n acc_dtype: tl.constexpr, #\n allow_tf32: tl.constexpr, #\n fp8_fast_accum: tl.constexpr, #\n BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, #\n GROUP_M: tl.constexpr, SPLIT_K: tl.constexpr, EVEN_K: tl.constexpr, AB_DTYPE: tl.constexpr #\n ):\n pid = tl.program_id(0)\n pid_z = tl.program_id(1)\n grid_m = tl.cdiv(M, BLOCK_M)\n grid_n = tl.cdiv(N, BLOCK_N)\n width = GROUP_M * grid_n\n group_id = pid // width\n group_size = min(grid_m - group_id * GROUP_M, GROUP_M)\n pid_m = group_id * GROUP_M + (pid % group_size)\n pid_n = (pid % width) // (group_size)\n rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)\n rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)\n ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)\n rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)\n rk = pid_z * BLOCK_K + tl.arange(0, BLOCK_K)\n A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)\n B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)\n acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=acc_dtype)\n for k in range(0, tl.cdiv(K, BLOCK_K * SPLIT_K)):\n if EVEN_K:\n a = tl.load(A)\n b = tl.load(B)\n else:\n k_remaining = K - k * (BLOCK_K * SPLIT_K)\n _0 = tl.zeros((1, 1), dtype=C.dtype.element_ty)\n a = tl.load(A, mask=rk[None, :] < k_remaining, other=_0)\n b = tl.load(B, mask=rk[:, None] < k_remaining, other=_0)\n if AB_DTYPE is not None:\n a = a.to(AB_DTYPE)\n b = b.to(AB_DTYPE)\n if fp8_fast_accum:\n acc = tl.dot(a, b, acc, out_dtype=acc_dtype, allow_tf32=allow_tf32)\n else:\n acc += tl.dot(a, b, out_dtype=acc_dtype, allow_tf32=allow_tf32)\n A += BLOCK_K * SPLIT_K * stride_ak\n B += BLOCK_K * SPLIT_K * stride_bk\n acc = acc.to(C.dtype.element_ty)\n rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)\n rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)\n C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)\n mask = (rm < M)[:, None] & (rn < N)[None, :]\n if SPLIT_K == 1:\n tl.store(C, acc, mask=mask)\n else:\n tl.atomic_add(C, acc, mask=mask)\n\nclass _matmul(torch.autograd.Function):\n kernel = _kernel\n\n @staticmethod\n def _call(a, b, acc_dtype, allow_tf32, fp8_fast_accum, output_dtype):\n device = a.device\n if a.stride(0) > 1 and a.stride(1) > 1:\n a = a.contiguous()\n if b.stride(0) > 1 and b.stride(1) > 1:\n b = b.contiguous()\n assert a.shape[1] == b.shape[0], \"incompatible dimensions\"\n M, K = a.shape\n _, N = b.shape\n ab_dtype = get_higher_dtype(a.dtype, b.dtype)\n if (output_dtype is None):\n output_dtype = ab_dtype\n c = torch.empty((M, N), device=device, dtype=output_dtype)\n supported_acc_dtypes = {\n torch.float16: (torch.float32, torch.float16), torch.bfloat16: (torch.float32, torch.bfloat16),\n torch.float32: (torch.float32, ), torch.int8: (torch.int32, )\n }\n if acc_dtype is None:\n acc_dtype = supported_acc_dtypes[ab_dtype][0]\n else:\n assert isinstance(acc_dtype, torch.dtype), \"acc_dtype must be a torch.dtype\"\n assert acc_dtype in supported_acc_dtypes[a.dtype], \"acc_dtype not compatible with the type of a\"\n assert acc_dtype in supported_acc_dtypes[b.dtype], \"acc_dtype not compatible with the type of b\"\n\n def to_tl_type(ty):\n return getattr(tl, str(ty).split(\".\")[-1])\n\n acc_dtype = to_tl_type(acc_dtype)\n ab_dtype = to_tl_type(ab_dtype)\n output_dtype = to_tl_type(output_dtype)\n if a.dtype in [tl.float8e4nv, tl.float8e5] and b.dtype in [tl.float8e4nv, tl.float8e5]:\n ab_dtype = None\n grid = lambda META: (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), META['SPLIT_K'])\n _kernel[grid](\n a, b, c, M, N, K, #\n a.stride(0), a.stride(1), #\n b.stride(0), b.stride(1), #\n c.stride(0), c.stride(1), #\n acc_dtype=acc_dtype, #\n allow_tf32=allow_tf32, #\n fp8_fast_accum=fp8_fast_accum, #\n GROUP_M=8, AB_DTYPE=ab_dtype)\n return c\n\n @staticmethod\n def forward(ctx, a, b, acc_dtype=None, allow_tf32=True, fp8_fast_accum=True, output_dtype=None):\n return _matmul._call(a, b, acc_dtype=acc_dtype, allow_tf32=allow_tf32, fp8_fast_accum=fp8_fast_accum,\n output_dtype=output_dtype)\n\nmatmul = _matmul.apply\n", - "description_1": "Use triton language to implement a matrix multiplication kernel with parameters for input matrices A and B, output matrix C, dimensions M, N, K, strides for A, B, C, and various compile-time constants for optimization. The kernel supports different data types and accumulation strategies, including support for TensorFloat-32 and fast accumulation for float8 types.", - "description_2": "Use triton language to create a matrix multiplication function that handles different data types and optimizes performance using compile-time constants and heuristics.", - "difficulty": 4 - }, - { - "code": "import math\nimport triton\nimport triton.testing\n\n\n@triton.jit\ndef kernel(x_ptr, x_size, **META):\n BLOCK_SIZE = META['BLOCK_SIZE']\n\n\n@triton.autotune(configs=[\n triton.Config(meta={'BLOCK_SIZE': 128}, num_warps=4),\n triton.Config(meta={'BLOCK_SIZE': 1024}, num_warps=8),\n ],\n key=['x_size']\n)\ndef call_kernel(x_ptr, x_size):\n kernel[(1,)](x_ptr, x_size)\n", - "description_1": "Use triton language to define a kernel that processes data based on block sizes, with two configurations (BLOCK_SIZE = 128 with 4 warps and BLOCK_SIZE = 1024 with 8 warps). An autotuner decorates the kernel, using x_size as a key for tuning configurations.", - "description_2": "Use triton language to define a kernel with autotuning based on x_size to choose between two block size configurations.", - "difficulty": 1 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef matmul_248_kernel(\n a_ptr,\n b_ptr,\n c_ptr,\n scales_ptr,\n zeros_ptr,\n g_ptr,\n M,\n N,\n K,\n bits,\n maxq,\n stride_am,\n stride_ak,\n stride_bk,\n stride_bn,\n stride_cm,\n stride_cn,\n stride_scales,\n stride_zeros,\n BLOCK_SIZE_M: tl.constexpr,\n BLOCK_SIZE_N: tl.constexpr,\n BLOCK_SIZE_K: tl.constexpr,\n GROUP_SIZE_M: tl.constexpr,\n):\n \"\"\"\n Compute the matrix multiplication C = A x B.\n A is of shape (M, K) float16\n B is of shape (K//8, N) int32\n C is of shape (M, N) float16\n scales is of shape (G, N) float16\n zeros is of shape (G, N) float16\n g_ptr is of shape (K) int32\n \"\"\"\n infearure_per_bits = 32 // bits\n\n pid = tl.program_id(axis=0)\n num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)\n num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)\n num_pid_k = tl.cdiv(K, BLOCK_SIZE_K)\n num_pid_in_group = GROUP_SIZE_M * num_pid_n\n group_id = pid // num_pid_in_group\n first_pid_m = group_id * GROUP_SIZE_M\n group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)\n pid_m = first_pid_m + (pid % group_size_m)\n pid_n = (pid % num_pid_in_group) // group_size_m\n\n offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n offs_k = tl.arange(0, BLOCK_SIZE_K)\n a_ptrs = a_ptr + (\n offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak\n ) # (BLOCK_SIZE_M, BLOCK_SIZE_K)\n a_mask = offs_am[:, None] < M\n # b_ptrs is set up such that it repeats elements along the K axis 8 times\n b_ptrs = b_ptr + (\n (offs_k[:, None] // infearure_per_bits) * stride_bk\n + offs_bn[None, :] * stride_bn\n ) # (BLOCK_SIZE_K, BLOCK_SIZE_N)\n g_ptrs = g_ptr + offs_k\n # shifter is used to extract the N bits of each element in the 32-bit word from B\n scales_ptrs = scales_ptr + offs_bn[None, :]\n zeros_ptrs = zeros_ptr + (offs_bn[None, :] // infearure_per_bits)\n\n shifter = (offs_k % infearure_per_bits) * bits\n zeros_shifter = (offs_bn % infearure_per_bits) * bits\n accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n\n for k in range(0, num_pid_k):\n g_idx = tl.load(g_ptrs)\n\n # Fetch scales and zeros; these are per-outfeature and thus reused in the inner loop\n scales = tl.load(\n scales_ptrs + g_idx[:, None] * stride_scales\n ) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)\n zeros = tl.load(\n zeros_ptrs + g_idx[:, None] * stride_zeros\n ) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)\n\n zeros = (zeros >> zeros_shifter[None, :]) & maxq\n zeros = zeros + 1\n\n a = tl.load(a_ptrs, mask=a_mask, other=0.0) # (BLOCK_SIZE_M, BLOCK_SIZE_K)\n b = tl.load(b_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated\n\n # Now we need to unpack b (which is N-bit values) into 32-bit values\n b = (b >> shifter[:, None]) & maxq # Extract the N-bit values\n b = (b - zeros) * scales # Scale and shift\n\n accumulator += tl.dot(a, b)\n a_ptrs += BLOCK_SIZE_K\n b_ptrs += (BLOCK_SIZE_K // infearure_per_bits) * stride_bk\n g_ptrs += BLOCK_SIZE_K\n\n c_ptrs = c_ptr + stride_cm * offs_am[:, None] + stride_cn * offs_bn[None, :]\n c_mask = (offs_am[:, None] < M) & (offs_bn[None, :] < N)\n tl.store(c_ptrs, accumulator, mask=c_mask)\n\ndef matmul248(input, qweight, scales, qzeros, g_idx, bits, maxq):\n with torch.cuda.device(input.device):\n output = torch.empty(\n (input.shape[0], qweight.shape[1]), device=input.device, dtype=torch.float16\n )\n grid = lambda META: (\n triton.cdiv(input.shape[0], META[\"BLOCK_SIZE_M\"])\n * triton.cdiv(qweight.shape[1], META[\"BLOCK_SIZE_N\"]),\n )\n matmul_248_kernel[grid](\n input,\n qweight,\n output,\n scales,\n qzeros,\n g_idx,\n input.shape[0],\n qweight.shape[1],\n input.shape[1],\n bits,\n maxq,\n input.stride(0),\n input.stride(1),\n qweight.stride(0),\n qweight.stride(1),\n output.stride(0),\n output.stride(1),\n scales.stride(0),\n qzeros.stride(0),\n )\n return output\n", - "description_1": "Use triton language to implement a matrix multiplication kernel (matmul_248_kernel) that computes C = A x B, where A is a float16 matrix of shape (M, K), B is an int32 matrix of shape (K//8, N), and C is a float16 matrix of shape (M, N). The kernel also uses scales and zeros for quantization, which are float16 matrices of shape (G, N). The function matmul248 is a wrapper that prepares the input and output tensors and launches the kernel with appropriate grid dimensions.", - "description_2": "Use triton language to create a matrix multiplication kernel with quantization support, and a wrapper function to execute it.", - "difficulty": 4 - }, - { - "code": "import triton\nimport triton.language as tl\n\n# Kernel function with @triton.jit decorator\n@triton.jit\ndef kernel(x_ptr, x_size, **META):\n BLOCK_SIZE = META['BLOCK_SIZE']\n # Kernel implementation here\n\n# Function to call the kernel\ndef call_kernel(x_ptr, x_size):\n # Example of calling the kernel with specific configurations\n kernel[(1,)](x_ptr, x_size, BLOCK_SIZE=128)\n", - "description_1": "Use triton language to define a kernel function 'kernel' with 2 parameters: x_ptr (pointer to data) and x_size (size of data). The kernel uses a meta-parameter BLOCK_SIZE. A separate function 'call_kernel' is used to invoke this kernel with specific configurations.", - "description_2": "Use triton language to create a kernel with parameters for data pointer and size, and a meta-parameter for block size. Implement a function to call this kernel with specific configurations.", - "difficulty": 2 - }, - { - "code": "import math\nimport torch\nimport triton\nimport triton.language as tl\n\n@triton.heuristics(\n {\n \"EVEN_M\": lambda args: args[\"seqlen_q\"] % args[\"BLOCK_M\"] == 0,\n \"EVEN_N\": lambda args: args[\"seqlen_k\"] % args[\"BLOCK_N\"] == 0,\n \"EVEN_HEADDIM\": lambda args: args[\"headdim\"] == args[\"BLOCK_HEADDIM\"],\n }\n)\n@triton.jit\ndef _fwd_kernel(\n Q, K, V, Bias, Out,\n Lse, TMP, # NOTE: TMP is a scratchpad buffer to workaround a compiler bug\n softmax_scale,\n stride_qb, stride_qh, stride_qm,\n stride_kb, stride_kh, stride_kn,\n stride_vb, stride_vh, stride_vn,\n stride_bb, stride_bh, stride_bm,\n stride_ob, stride_oh, stride_om,\n nheads, seqlen_q, seqlen_k, seqlen_q_rounded, headdim,\n CACHE_KEY_SEQLEN_Q, CACHE_KEY_SEQLEN_K,\n BIAS_TYPE: tl.constexpr,\n IS_CAUSAL: tl.constexpr,\n BLOCK_HEADDIM: tl.constexpr,\n EVEN_M: tl.constexpr, EVEN_N: tl.constexpr, EVEN_HEADDIM: tl.constexpr,\n BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,\n):\n start_m = tl.program_id(0)\n off_hb = tl.program_id(1)\n off_b = off_hb // nheads\n off_h = off_hb % nheads\n offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)\n offs_n = tl.arange(0, BLOCK_N)\n offs_d = tl.arange(0, BLOCK_HEADDIM)\n q_ptrs = Q + off_b * stride_qb + off_h * stride_qh + (offs_m[:, None] * stride_qm + offs_d[None, :])\n k_ptrs = K + off_b * stride_kb + off_h * stride_kh + (offs_n[:, None] * stride_kn + offs_d[None, :])\n v_ptrs = V + off_b * stride_vb + off_h * stride_vh + (offs_n[:, None] * stride_vn + offs_d[None, :])\n if BIAS_TYPE == 'vector':\n b_ptrs = Bias + off_b * stride_bb + off_h * stride_bh + offs_n\n elif BIAS_TYPE == 'matrix':\n b_ptrs = Bias + off_b * stride_bb + off_h * stride_bh + (offs_m[:, None] * stride_bm + offs_n[None, :])\n t_ptrs = TMP + off_hb * seqlen_q_rounded + offs_m\n lse_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float(\"inf\")\n m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float(\"inf\")\n acc_o = tl.zeros([BLOCK_M, BLOCK_HEADDIM], dtype=tl.float32)\n if EVEN_M & EVEN_N:\n if EVEN_HEADDIM:\n q = tl.load(q_ptrs)\n else:\n q = tl.load(q_ptrs, mask=offs_d[None, :] < headdim, other=0.0)\n else:\n if EVEN_HEADDIM:\n q = tl.load(q_ptrs, mask=offs_m[:, None] < seqlen_q, other=0.0)\n else:\n q = tl.load(q_ptrs, mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim),\n other=0.0)\n end_n = seqlen_k if not IS_CAUSAL else tl.minimum((start_m + 1) * BLOCK_M, seqlen_k)\n for start_n in range(0, end_n, BLOCK_N):\n start_n = tl.multiple_of(start_n, BLOCK_N)\n if EVEN_N & EVEN_M:\n if EVEN_HEADDIM:\n k = tl.load(k_ptrs + start_n * stride_kn)\n else:\n k = tl.load(k_ptrs + start_n * stride_kn, mask=offs_d[None, :] < headdim, other=0.0)\n else:\n if EVEN_HEADDIM:\n k = tl.load(k_ptrs + start_n * stride_kn, mask=(start_n + offs_n)[:, None] < seqlen_k,\n other=0.0)\n else:\n k = tl.load(k_ptrs + start_n * stride_kn,\n mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim),\n other=0.0)\n qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)\n qk += tl.dot(q, k, trans_b=True)\n if not EVEN_N:\n qk += tl.where((start_n + offs_n)[None, :] < seqlen_k, 0, float(\"-inf\"))\n if IS_CAUSAL:\n qk += tl.where(offs_m[:, None] >= (start_n + offs_n)[None, :], 0, float(\"-inf\"))\n if BIAS_TYPE != 'none':\n if BIAS_TYPE == 'vector':\n if EVEN_N:\n bias = tl.load(b_ptrs + start_n).to(tl.float32)\n else:\n bias = tl.load(b_ptrs + start_n, mask=(start_n + offs_n) < seqlen_k, other=0.0).to(tl.float32)\n bias = bias[None, :]\n elif BIAS_TYPE == 'matrix':\n if EVEN_M & EVEN_N:\n bias = tl.load(b_ptrs + start_n).to(tl.float32)\n else:\n bias = tl.load(b_ptrs + start_n,\n mask=(offs_m[:, None] < seqlen_q)\n & ((start_n + offs_n)[None, :] < seqlen_k),\n other=0.0).to(tl.float32)\n qk = qk * softmax_scale + bias\n m_ij = tl.maximum(tl.max(qk, 1), lse_i)\n p = tl.exp(qk - m_ij[:, None])\n else:\n m_ij = tl.maximum(tl.max(qk, 1) * softmax_scale, lse_i)\n p = tl.exp(qk * softmax_scale - m_ij[:, None])\n l_ij = tl.sum(p, 1)\n acc_o_scale = tl.exp(m_i - m_ij)\n tl.store(t_ptrs, acc_o_scale)\n acc_o_scale = tl.load(t_ptrs)\n acc_o = acc_o * acc_o_scale[:, None]\n if EVEN_N & EVEN_M:\n if EVEN_HEADDIM:\n v = tl.load(v_ptrs + start_n * stride_vn)\n else:\n v = tl.load(v_ptrs + start_n * stride_vn, mask=offs_d[None, :] < headdim, other=0.0)\n else:\n if EVEN_HEADDIM:\n v = tl.load(v_ptrs + start_n * stride_vn, mask=(start_n + offs_n)[:, None] < seqlen_k,\n other=0.0)\n else:\n v = tl.load(v_ptrs + start_n * stride_vn,\n mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim),\n other=0.0)\n p = p.to(v.dtype)\n acc_o += tl.dot(p, v)\n m_i = m_ij\n l_i_new = tl.exp(lse_i - m_ij) + l_ij\n lse_i = m_ij + tl.log(l_i_new)\n o_scale = tl.exp(m_i - lse_i)\n tl.store(t_ptrs, o_scale)\n o_scale = tl.load(t_ptrs)\n acc_o = acc_o * o_scale[:, None]\n start_m = tl.program_id(0)\n offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)\n lse_ptrs = Lse + off_hb * seqlen_q_rounded + offs_m\n tl.store(lse_ptrs, lse_i)\n offs_d = tl.arange(0, BLOCK_HEADDIM)\n out_ptrs = Out + off_b * stride_ob + off_h * stride_oh + (offs_m[:, None] * stride_om + offs_d[None, :])\n if EVEN_M:\n if EVEN_HEADDIM:\n tl.store(out_ptrs, acc_o)\n else:\n tl.store(out_ptrs, acc_o, mask=offs_d[None, :] < headdim)\n else:\n if EVEN_HEADDIM:\n tl.store(out_ptrs, acc_o, mask=offs_m[:, None] < seqlen_q)\n else:\n tl.store(out_ptrs, acc_o,\n mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim))\n\n\ndef _flash_attn_forward(q, k, v, bias=None, causal=False, softmax_scale=None):\n batch, seqlen_q, nheads, d = q.shape\n _, seqlen_k, _, _ = k.shape\n assert k.shape == (batch, seqlen_k, nheads, d)\n assert v.shape == (batch, seqlen_k, nheads, d)\n assert d <= 128, 'FlashAttention only support head dimensions up to 128'\n assert q.dtype == k.dtype == v.dtype, 'All tensors must have the same type'\n assert q.dtype in [torch.float16, torch.bfloat16], 'Only support fp16 and bf16'\n assert q.is_cuda and k.is_cuda and v.is_cuda\n softmax_scale = softmax_scale or 1.0 / math.sqrt(d)\n\n has_bias = bias is not None\n bias_type = 'none'\n if has_bias:\n assert bias.dtype in [q.dtype, torch.float]\n assert bias.is_cuda\n assert bias.dim() == 4\n if bias.stride(-1) != 1:\n bias = bias.contiguous()\n if bias.shape[2:] == (1, seqlen_k):\n bias_type = 'vector'\n elif bias.shape[2:] == (seqlen_q, seqlen_k):\n bias_type = 'matrix'\n else:\n raise RuntimeError('Last 2 dimensions of bias must be (1, seqlen_k)'\n ' or (seqlen_q, seqlen_k)')\n bias = bias.expand(batch, nheads, seqlen_q, seqlen_k)\n bias_strides = (bias.stride(0), bias.stride(1), bias.stride(2)) if has_bias else (0, 0, 0)\n\n seqlen_q_rounded = math.ceil(seqlen_q / 128) * 128\n lse = torch.empty((batch, nheads, seqlen_q_rounded), device=q.device, dtype=torch.float32)\n tmp = torch.empty((batch, nheads, seqlen_q_rounded), device=q.device, dtype=torch.float32)\n o = torch.empty_like(q)\n\n BLOCK_HEADDIM = max(triton.next_power_of_2(d), 16)\n BLOCK = 128\n num_warps = 4 if d <= 64 else 8\n grid = lambda META: (triton.cdiv(seqlen_q, META[\"BLOCK_M\"]), batch * nheads)\n _fwd_kernel[grid](\n q, k, v, bias, o,\n lse, tmp,\n softmax_scale,\n q.stride(0), q.stride(2), q.stride(1),\n k.stride(0), k.stride(2), k.stride(1),\n v.stride(0), v.stride(2), v.stride(1),\n *bias_strides,\n o.stride(0), o.stride(2), o.stride(1),\n nheads, seqlen_q, seqlen_k, seqlen_q_rounded, d,\n seqlen_q // 32, seqlen_k // 32,\n bias_type, causal, BLOCK_HEADDIM,\n BLOCK_M=BLOCK, BLOCK_N=BLOCK,\n num_warps=num_warps,\n num_stages=1,\n )\n return o, lse, softmax_scale\n\n\nclass FlashAttnFunc(torch.autograd.Function):\n\n @staticmethod\n def forward(ctx, q, k, v, bias=None, causal=False, softmax_scale=None):\n q, k, v = [x if x.stride(-1) == 1 else x.contiguous() for x in [q, k, v]]\n o, lse, ctx.softmax_scale = _flash_attn_forward(q, k, v, bias=bias, causal=causal, softmax_scale=softmax_scale)\n ctx.save_for_backward(q, k, v, o, lse, bias)\n ctx.causal = causal\n return o\n\n @staticmethod\n def backward(ctx, do):\n q, k, v, o, lse, bias = ctx.saved_tensors\n assert not ctx.needs_input_grad[3], 'FlashAttention does not support bias gradient yet'\n with torch.inference_mode():\n dq = torch.empty_like(q)\n dk = torch.empty_like(k)\n dv = torch.empty_like(v)\n _flash_attn_backward(do, q, k, v, o, lse, dq, dk, dv,\n bias=bias, causal=ctx.causal, softmax_scale=ctx.softmax_scale)\n return dq, dk, dv, None, None, None\n\n\nflash_attn_func = FlashAttnFunc.apply\n", - "description_1": "Use triton language to implement a FlashAttention forward kernel and its corresponding backward function. The forward kernel (_fwd_kernel) takes 30 parameters: Q, K, V, Bias, Out, Lse, TMP, softmax_scale, stride_qb, stride_qh, stride_qm, stride_kb, stride_kh, stride_kn, stride_vb, stride_vh, stride_vn, stride_bb, stride_bh, stride_bm, stride_ob, stride_oh, stride_om, nheads, seqlen_q, seqlen_k, seqlen_q_rounded, headdim, CACHE_KEY_SEQLEN_Q, CACHE_KEY_SEQLEN_K, and several constexpr parameters. It computes the attention output using the provided Q, K, V matrices and optional bias, with support for causal masking and different head dimensions. The backward function (FlashAttnFunc) computes gradients for Q, K, and V given the gradient of the output.", - "description_2": "Use triton language to create a FlashAttention operator with a forward kernel that computes attention outputs from Q, K, V matrices, and a backward function to compute gradients. The operator supports optional bias and causal masking.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef matmul_248_kernel(a_ptr, b_ptr, c_ptr,\n scales_ptr, zeros_ptr, g_ptr,\n M, N, K, bits, maxq,\n stride_am, stride_ak,\n stride_bk, stride_bn,\n stride_cm, stride_cn,\n stride_scales, stride_zeros,\n BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,\n GROUP_SIZE_M: tl.constexpr):\n \"\"\"\n Compute the matrix multiplication C = A x B.\n A is of shape (M, K) float16\n B is of shape (K//8, N) int32\n C is of shape (M, N) float16\n scales is of shape (G, N) float16\n zeros is of shape (G, N) float16\n g_ptr is of shape (K) int32\n \"\"\"\n infearure_per_bits = 32 // bits\n\n pid = tl.program_id(axis=0)\n num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)\n num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)\n num_pid_k = tl.cdiv(K, BLOCK_SIZE_K)\n num_pid_in_group = GROUP_SIZE_M * num_pid_n\n group_id = pid // num_pid_in_group\n first_pid_m = group_id * GROUP_SIZE_M\n group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)\n pid_m = first_pid_m + (pid % group_size_m)\n pid_n = (pid % num_pid_in_group) // group_size_m\n\n offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n offs_k = tl.arange(0, BLOCK_SIZE_K)\n a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) # (BLOCK_SIZE_M, BLOCK_SIZE_K)\n a_mask = (offs_am[:, None] < M)\n # b_ptrs is set up such that it repeats elements along the K axis 8 times\n b_ptrs = b_ptr + ((offs_k[:, None] // infearure_per_bits) * stride_bk + offs_bn[None, :] * stride_bn) # (BLOCK_SIZE_K, BLOCK_SIZE_N)\n g_ptrs = g_ptr + offs_k\n # shifter is used to extract the N bits of each element in the 32-bit word from B\n scales_ptrs = scales_ptr + offs_bn[None, :]\n zeros_ptrs = zeros_ptr + (offs_bn[None, :] // infearure_per_bits)\n\n shifter = (offs_k % infearure_per_bits) * bits\n zeros_shifter = (offs_bn % infearure_per_bits) * bits\n accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n\n for k in range(0, num_pid_k):\n g_idx = tl.load(g_ptrs)\n\n # Fetch scales and zeros; these are per-outfeature and thus reused in the inner loop\n scales = tl.load(scales_ptrs + g_idx[:, None] * stride_scales) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)\n zeros = tl.load(zeros_ptrs + g_idx[:, None] * stride_zeros) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)\n\n zeros = (zeros >> zeros_shifter[None, :]) & maxq\n zeros = (zeros + 1)\n\n a = tl.load(a_ptrs, mask=a_mask, other=0.0) # (BLOCK_SIZE_M, BLOCK_SIZE_K)\n b = tl.load(b_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated\n\n # Now we need to unpack b (which is N-bit values) into 32-bit values\n b = (b >> shifter[:, None]) & maxq # Extract the N-bit values\n b = (b - zeros) * scales # Scale and shift\n # ! Convert to fp16\n b = b.to(tl.float16)\n a = a.to(tl.float16)\n\n accumulator += tl.dot(a, b)\n a_ptrs += BLOCK_SIZE_K\n b_ptrs += (BLOCK_SIZE_K // infearure_per_bits) * stride_bk\n g_ptrs += BLOCK_SIZE_K\n\n c = accumulator.to(tl.float16)\n c_ptrs = c_ptr + stride_cm * offs_am[:, None] + stride_cn * offs_bn[None, :]\n c_mask = (offs_am[:, None] < M) & (offs_bn[None, :] < N)\n tl.store(c_ptrs, c, mask=c_mask)\n\n\n@triton.jit\ndef trans_matmul_248_kernel(a_ptr, b_ptr, c_ptr,\n scales_ptr, zeros_ptr, g_ptr,\n M, N, K, bits, maxq,\n stride_am, stride_ak,\n stride_bk, stride_bn,\n stride_cm, stride_cn,\n stride_scales, stride_zeros,\n BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,\n GROUP_SIZE_M: tl.constexpr):\n \"\"\"\n Compute the matrix multiplication C = A x B.\n A is of shape (M, N) float16\n B is of shape (K//8, N) int32\n C is of shape (M, K) float16\n scales is of shape (G, N) float16\n zeros is of shape (G, N) float16\n g_ptr is of shape (K) int32\n \"\"\"\n infearure_per_bits = 32 // bits\n\n pid = tl.program_id(axis=0)\n num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)\n num_pid_k = tl.cdiv(K, BLOCK_SIZE_K)\n num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)\n num_pid_in_group = GROUP_SIZE_M * num_pid_k\n group_id = pid // num_pid_in_group\n first_pid_m = group_id * GROUP_SIZE_M\n group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)\n pid_m = first_pid_m + (pid % group_size_m)\n pid_k = (pid % num_pid_in_group) // group_size_m\n\n offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_bk = pid_k * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)\n offs_n = tl.arange(0, BLOCK_SIZE_N)\n a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_n[None, :] * stride_ak) # (BLOCK_SIZE_M, BLOCK_SIZE_N)\n a_mask = (offs_am[:, None] < M)\n # b_ptrs is set up such that it repeats elements along the K axis 8 times\n b_ptrs = b_ptr + ((offs_bk[:, None] // infearure_per_bits) * stride_bk + offs_n[None, :] * stride_bn) # (BLOCK_SIZE_K, BLOCK_SIZE_N)\n g_ptrs = g_ptr + offs_bk\n g_idx = tl.load(g_ptrs)\n\n # shifter is used to extract the N bits of each element in the 32-bit word from B\n scales_ptrs = scales_ptr + offs_n[None, :] + g_idx[:, None] * stride_scales\n zeros_ptrs = zeros_ptr + (offs_n[None, :] // infearure_per_bits) + g_idx[:, None] * stride_zeros\n\n shifter = (offs_bk % infearure_per_bits) * bits\n zeros_shifter = (offs_n % infearure_per_bits) * bits\n accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_K), dtype=tl.float32)\n\n for k in range(0, num_pid_n):\n # Fetch scales and zeros; these are per-outfeature and thus reused in the inner loop\n scales = tl.load(scales_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)\n zeros = tl.load(zeros_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)\n\n zeros = (zeros >> zeros_shifter[None, :]) & maxq\n zeros = (zeros + 1)\n\n a = tl.load(a_ptrs, mask=a_mask, other=0.) # (BLOCK_SIZE_M, BLOCK_SIZE_N)\n b = tl.load(b_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated\n\n # Now we need to unpack b (which is N-bit values) into 32-bit values\n b = (b >> shifter[:, None]) & maxq # Extract the N-bit values\n b = (b - zeros) * scales # Scale and shift\n b = tl.trans(b)\n # ! Convert to fp16\n b = b.to(tl.float16)\n a = a.to(tl.float16)\n\n accumulator += tl.dot(a, b)\n a_ptrs += BLOCK_SIZE_N\n b_ptrs += BLOCK_SIZE_N\n scales_ptrs += BLOCK_SIZE_N\n zeros_ptrs += (BLOCK_SIZE_N // infearure_per_bits)\n\n c = accumulator.to(tl.float16)\n c_ptrs = c_ptr + stride_cm * offs_am[:, None] + stride_cn * offs_bk[None, :]\n c_mask = (offs_am[:, None] < M) & (offs_bk[None, :] < K)\n tl.store(c_ptrs, c, mask=c_mask)\n\n\ndef triton_matmul(input, qweight, scales, qzeros, g_idx, bits, maxq):\n assert input.shape[-1] == qweight.shape[0] * 32 // bits\n outshape = input.shape[:-1] + (qweight.shape[1],)\n input = input.reshape(-1, input.shape[-1])\n output = torch.empty((input.shape[0], qweight.shape[1]), device=scales.device, dtype=torch.float16)\n grid = lambda META: (\n triton.cdiv(input.shape[0], META['BLOCK_SIZE_M']) * triton.cdiv(qweight.shape[1], META['BLOCK_SIZE_N']),)\n matmul_248_kernel[grid](input, qweight, output,\n scales, qzeros, g_idx,\n input.shape[0], qweight.shape[1], input.shape[1], bits, maxq,\n input.stride(0), input.stride(1),\n qweight.stride(0), qweight.stride(1),\n output.stride(0), output.stride(1),\n scales.stride(0), qzeros.stride(0))\n output = output.reshape(outshape)\n return output\n\n\ndef triton_matmul_transpose(input, qweight, scales, qzeros, g_idx, bits, maxq):\n assert input.shape[-1] == qweight.shape[1]\n out_dim = qweight.shape[0] * 32 // bits\n outshape = input.shape[:-1] + (out_dim,)\n input = input.reshape(-1, input.shape[-1])\n output_shape_mid = (input.shape[0], out_dim)\n output = torch.empty((output_shape_mid[0], output_shape_mid[1]), device=scales.device, dtype=torch.float16)\n grid = lambda META: (\n triton.cdiv(input.shape[0], META['BLOCK_SIZE_M']) * triton.cdiv(output_shape_mid[1], META['BLOCK_SIZE_K']),)\n trans_matmul_248_kernel[grid](input, qweight, output,\n scales, qzeros, g_idx,\n input.shape[0], qweight.shape[1], output_shape_mid[1], bits, maxq,\n input.stride(0), input.stride(1),\n qweight.stride(0), qweight.stride(1),\n output.stride(0), output.stride(1),\n scales.stride(0), qzeros.stride(0))\n output = output.reshape(outshape)\n return output\n", - "description_1": "Use triton language to implement two kernels and their respective calling functions for matrix multiplication. The first kernel, matmul_248_kernel, computes matrix C as a product of A and B with scaling and zero offset adjustments, where A is float16, B is int32, and scales and zeros are float16. The kernel uses parameters to manage strides, block sizes, and groups. The second kernel, trans_matmul_248_kernel, computes C for transposed A. Both kernels involve bit manipulation of B for matrix computation. The triton_matmul and triton_matmul_transpose functions call these kernels, passing appropriate tensor arguments.", - "description_2": "Use triton language to implement kernels for matrix multiplication, including bit manipulation and scaling, with functions to handle matrix input and output.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nfrom triton import language as tl\n\n@triton.jit\ndef add_kernel(\n in_ptr0,\n in_ptr1,\n out_ptr,\n n_elements,\n BLOCK_SIZE: \"tl.constexpr\",\n):\n pid = tl.program_id(axis=0)\n block_start = pid * BLOCK_SIZE\n offsets = block_start + tl.arange(0, BLOCK_SIZE)\n mask = offsets < n_elements\n x = tl.load(in_ptr0 + offsets, mask=mask)\n y = tl.load(in_ptr1 + offsets, mask=mask)\n output = x + y\n tl.store(out_ptr + offsets, output, mask=mask)\n\n@torch.compile(fullgraph=True)\ndef add_fn(x, y):\n output = torch.zeros_like(x)\n n_elements = output.numel()\n grid = lambda meta: (triton.cdiv(n_elements, meta[\"BLOCK_SIZE\"]),)\n add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=4)\n return output\n\nx = torch.randn(4, device=\"cuda\")\ny = torch.randn(4, device=\"cuda\")\nout = add_fn(x, y)\nprint(f\"Vector addition of\\nX:\\t{x}\\nY:\\t{y}\\nis equal to\\n{out}\")\n\n@triton.autotune(\n configs=[\n triton.Config({\"BLOCK_SIZE\": 4}, num_stages=3, num_warps=8),\n triton.Config({\"BLOCK_SIZE\": 4}, num_stages=4, num_warps=4),\n triton.Config({\"BLOCK_SIZE\": 2}, num_stages=3, num_warps=8),\n triton.Config({\"BLOCK_SIZE\": 2}, num_stages=4, num_warps=4),\n ],\n key=[],\n)\n@triton.jit\ndef add_kernel_autotuned(\n in_ptr0,\n in_ptr1,\n out_ptr,\n n_elements,\n BLOCK_SIZE: \"tl.constexpr\",\n):\n pid = tl.program_id(axis=0)\n block_start = pid * BLOCK_SIZE\n offsets = block_start + tl.arange(0, BLOCK_SIZE)\n mask = offsets < n_elements\n x = tl.load(in_ptr0 + offsets, mask=mask)\n y = tl.load(in_ptr1 + offsets, mask=mask)\n output = x + y\n tl.store(out_ptr + offsets, output, mask=mask)\n\n@torch.compile(fullgraph=True)\ndef add_fn_autotuned(x, y):\n output = torch.zeros_like(x)\n n_elements = output.numel()\n grid = lambda meta: (triton.cdiv(n_elements, meta[\"BLOCK_SIZE\"]),)\n add_kernel_autotuned[grid](x, y, output, n_elements)\n return output\n\nx = torch.randn(4, device=\"cuda\")\ny = torch.randn(4, device=\"cuda\")\nout = add_fn_autotuned(x, y)\nprint(f\"Vector addition of\\nX:\\t{x}\\nY:\\t{y}\\nis equal to\\n{out}\")\n", - "description_1": "Use triton language to define a vector addition kernel with a block size of 4, performing element-wise addition of two input vectors and storing the result in an output vector. The kernel considers the number of elements and uses a program ID to calculate block start and offsets for loading and storing data. The kernel is called using a torch-compiled function 'add_fn'. Similarly, define an autotuned version of the vector addition kernel with multiple configuration options, leveraging Triton's autotune feature to optimize kernel execution. This autotuned kernel is called using a torch-compiled function 'add_fn_autotuned'.", - "description_2": "Use triton language to create a vector addition kernel with specified block sizes, performing element-wise addition of two vectors. Employ triton.autotune to explore different configurations for optimal performance, integrating with torch.compile for seamless execution in PyTorch.", - "difficulty": 2 - }, - { - "code": "import torch\nimport triton\nfrom triton import language as tl\n\n@triton.jit\ndef add_kernel(\n in_ptr0,\n in_ptr1,\n out_ptr,\n n_elements,\n BLOCK_SIZE: \"tl.constexpr\",\n):\n pid = tl.program_id(axis=0)\n block_start = pid * BLOCK_SIZE\n offsets = block_start + tl.arange(0, BLOCK_SIZE)\n mask = offsets < n_elements\n x = tl.load(in_ptr0 + offsets, mask=mask)\n y = tl.load(in_ptr1 + offsets, mask=mask)\n output = x + y\n tl.store(out_ptr + offsets, output, mask=mask)\n\n@torch.compile(fullgraph=True)\ndef add_fn(x, y):\n output = torch.zeros_like(x)\n n_elements = output.numel()\n grid = lambda meta: (triton.cdiv(n_elements, meta[\"BLOCK_SIZE\"]),)\n add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=4)\n return output\n\nx = torch.randn(4, device=\"cuda\")\ny = torch.randn(4, device=\"cuda\")\nout = add_fn(x, y)\nprint(f\"Vector addition of\\nX:\\t{x}\\nY:\\t{y}\\nis equal to\\n{out}\")\n\n@triton.autotune(\n configs=[\n triton.Config({\"BLOCK_SIZE\": 4}, num_stages=3, num_warps=8),\n triton.Config({\"BLOCK_SIZE\": 4}, num_stages=4, num_warps=4),\n triton.Config({\"BLOCK_SIZE\": 2}, num_stages=3, num_warps=8),\n triton.Config({\"BLOCK_SIZE\": 2}, num_stages=4, num_warps=4),\n ],\n key=[],\n)\n@triton.jit\ndef add_kernel_autotuned(\n in_ptr0,\n in_ptr1,\n out_ptr,\n n_elements,\n BLOCK_SIZE: \"tl.constexpr\",\n):\n pid = tl.program_id(axis=0)\n block_start = pid * BLOCK_SIZE\n offsets = block_start + tl.arange(0, BLOCK_SIZE)\n mask = offsets < n_elements\n x = tl.load(in_ptr0 + offsets, mask=mask)\n y = tl.load(in_ptr1 + offsets, mask=mask)\n output = x + y\n tl.store(out_ptr + offsets, output, mask=mask)\n\n@torch.compile(fullgraph=True)\ndef add_fn(x, y):\n output = torch.zeros_like(x)\n n_elements = output.numel()\n grid = lambda meta: (triton.cdiv(n_elements, meta[\"BLOCK_SIZE\"]),)\n add_kernel_autotuned[grid](x, y, output, n_elements)\n return output\n\nx = torch.randn(4, device=\"cuda\")\ny = torch.randn(4, device=\"cuda\")\nout = add_fn(x, y)\nprint(f\"Vector addition of\\nX:\\t{x}\\nY:\\t{y}\\nis equal to\\n{out}\")\n", - "description_1": "Use triton language to implement a vector addition kernel with two versions: a basic version and an autotuned version. The basic version, 'add_kernel', takes five parameters: two input pointers, an output pointer, the number of elements, and a block size. It performs element-wise addition of two vectors. The autotuned version, 'add_kernel_autotuned', is similar but uses triton's autotuning feature to optimize performance. The 'add_fn' function wraps these kernels for use with torch.compile, taking two input tensors and returning their element-wise sum.", - "description_2": "Use triton language to create a vector addition kernel with basic and autotuned versions, wrapped for torch.compile.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef _fwd_kernel_inner(\n acc,\n l_i,\n m_i,\n q,\n Q,\n k_block_col_idx,\n layout_col_ptr,\n layout_col_stride_h,\n layout_col_stride_m,\n k_ptrs,\n v_ptrs,\n off_h,\n offs_m,\n offs_n,\n offs_d,\n stride_kt,\n stride_vt,\n sm_scale,\n k_seqlen,\n past_len,\n LAST_K_BLOCK: tl.constexpr,\n BLOCK_M_LOADING: tl.constexpr,\n BLOCK_N: tl.constexpr,\n D_HEAD: tl.constexpr,\n EVEN_D: tl.constexpr,\n M_LT_N: tl.constexpr,\n):\n k_block_id = tl.load(layout_col_ptr + off_h * layout_col_stride_h +\n k_block_col_idx * layout_col_stride_m).to(tl.int32)\n start_n = k_block_id * BLOCK_N\n if LAST_K_BLOCK:\n if EVEN_D:\n k = tl.load(\n k_ptrs + start_n * stride_kt,\n mask=offs_n[None, :] + start_n < k_seqlen,\n )\n else:\n k = tl.load(\n k_ptrs + start_n * stride_kt,\n mask=(offs_n[None, :] + start_n < k_seqlen) &\n (offs_d[:, None] < D_HEAD),\n )\n else:\n if EVEN_D:\n k = tl.load(k_ptrs + start_n * stride_kt)\n else:\n k = tl.load(k_ptrs + start_n * stride_kt,\n mask=offs_d[:, None] < D_HEAD)\n\n qk = tl.zeros([BLOCK_M_LOADING, BLOCK_N], dtype=tl.float32)\n qk += tl.dot(q, k)\n qk *= sm_scale\n\n if LAST_K_BLOCK | M_LT_N:\n qk += tl.where(\n offs_m[:, None] + past_len >= (start_n + offs_n[None, :]),\n 0,\n float(\"-inf\"),\n )\n\n m_ij = tl.maximum(m_i, tl.max(qk, 1))\n p = tl.math.exp2(qk - m_ij[:, None])\n l_ij = tl.sum(p, 1)\n alpha = tl.math.exp2(m_i - m_ij)\n acc = acc * alpha[:, None]\n m_i = m_ij\n l_i = l_i * alpha + l_ij\n\n p = p.to(Q.dtype.element_ty)\n if LAST_K_BLOCK:\n if EVEN_D:\n v = tl.load(\n v_ptrs + start_n * stride_vt,\n mask=offs_n[:, None] + start_n < k_seqlen,\n )\n else:\n v = tl.load(\n v_ptrs + start_n * stride_vt,\n mask=(offs_n[:, None] + start_n < k_seqlen) &\n (offs_d[None, :] < D_HEAD),\n )\n else:\n if EVEN_D:\n v = tl.load(v_ptrs + start_n * stride_vt)\n else:\n v = tl.load(v_ptrs + start_n * stride_vt,\n mask=offs_d[None, :] < D_HEAD)\n\n acc += tl.dot(p, v)\n\n return acc, l_i, m_i\n\n\n@triton.heuristics({\n \"M_LT_N\":\n lambda kwargs: kwargs[\"BLOCK_M\"] < kwargs[\"BLOCK_N\"],\n})\n@triton.jit\ndef _fwd_kernel_batch_inference(\n Q,\n K,\n V,\n Out,\n sm_scale,\n q_batch_starts,\n q_batch_ends,\n k_batch_starts,\n k_batch_ends,\n q_batch_ids,\n q_start_sids,\n stride_qb,\n stride_qt,\n stride_qh,\n stride_qd,\n stride_kb,\n stride_kt,\n stride_kh,\n stride_kd,\n stride_vb,\n stride_vt,\n stride_vh,\n stride_vd,\n stride_ob,\n stride_ot,\n stride_oh,\n stride_od,\n layout_crow_ptr,\n layout_col_ptr,\n layout_crow_stride_h,\n layout_crow_stride_m,\n layout_col_stride_h,\n layout_col_stride_m,\n q_k_ratio,\n HAS_BATCH_DIM: tl.constexpr,\n D_HEAD: tl.constexpr,\n BLOCK_M: tl.constexpr,\n BLOCK_N: tl.constexpr,\n BLOCK_D: tl.constexpr,\n BLOCK_M_LOADING: tl.constexpr,\n EVEN_D: tl.constexpr,\n M_LT_N: tl.constexpr,\n):\n off_zm = tl.program_id(0)\n off_h = tl.program_id(1)\n\n off_h_for_kv = off_h // q_k_ratio\n\n if HAS_BATCH_DIM:\n off_z = tl.program_id(2)\n Q += off_z * stride_qb\n K += off_z * stride_kb\n V += off_z * stride_vb\n Out += off_z * stride_ob\n start_m = off_zm\n q_start_sid = start_m * BLOCK_M\n else:\n off_z = tl.load(q_batch_ids + off_zm).to(tl.int32)\n q_start_sid = tl.load(q_start_sids + off_zm)\n start_m = q_start_sid // BLOCK_M\n\n offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M_LOADING)\n offs_n = tl.arange(0, BLOCK_N)\n offs_d = tl.arange(0, BLOCK_D)\n\n q_cu_start = tl.load(q_batch_starts + off_z).to(tl.int32)\n q_seqlen = tl.load(q_batch_ends + off_z).to(tl.int32) - q_cu_start\n k_cu_start = tl.load(k_batch_starts + off_z).to(tl.int32)\n k_seqlen = tl.load(k_batch_ends + off_z).to(tl.int32) - k_cu_start\n past_len = k_seqlen - q_seqlen\n\n Q += q_cu_start * stride_qt + off_h * stride_qh\n K += k_cu_start * stride_kt + off_h_for_kv * stride_kh\n V += k_cu_start * stride_vt + off_h_for_kv * stride_vh\n Out += q_cu_start * stride_ot + off_h * stride_oh\n\n q_pbid = (past_len + q_start_sid) // BLOCK_M\n\n if EVEN_D:\n q = tl.load(\n Q + offs_m[:, None] * stride_qt + offs_d[None, :] * stride_qd,\n mask=offs_m[:, None] < q_seqlen,\n )\n else:\n q = tl.load(\n Q + offs_m[:, None] * stride_qt + offs_d[None, :] * stride_qd,\n mask=(offs_m[:, None] < q_seqlen) & (offs_d[None, :] < D_HEAD),\n other=0,\n )\n\n sparse_crow_ptr = (layout_crow_ptr + off_h * layout_crow_stride_h +\n q_pbid * layout_crow_stride_m)\n\n k_block_start = tl.load(sparse_crow_ptr).to(tl.int32)\n k_block_end = tl.load(sparse_crow_ptr + 1).to(tl.int32)\n\n m_i = tl.zeros([BLOCK_M_LOADING], dtype=tl.float32) - float(\"inf\")\n l_i = tl.zeros([BLOCK_M_LOADING], dtype=tl.float32)\n acc = tl.zeros([BLOCK_M_LOADING, BLOCK_D], dtype=tl.float32)\n\n k_ptrs = K + offs_n[None, :] * stride_kt + offs_d[:, None] * stride_kd\n v_ptrs = V + offs_n[:, None] * stride_vt + offs_d[None, :] * stride_vd\n\n sm_scale *= (\n 1.44269504\n )\n\n for k_block_col_idx in range(k_block_start, k_block_end - 1):\n acc, l_i, m_i = _fwd_kernel_inner(\n acc,\n l_i,\n m_i,\n q,\n Q,\n k_block_col_idx,\n layout_col_ptr,\n layout_col_stride_h,\n layout_col_stride_m,\n k_ptrs,\n v_ptrs,\n off_h,\n offs_m,\n offs_n,\n offs_d,\n stride_kt,\n stride_vt,\n sm_scale,\n k_seqlen,\n past_len,\n False,\n BLOCK_M_LOADING,\n BLOCK_N,\n D_HEAD,\n EVEN_D,\n M_LT_N,\n )\n\n acc, l_i, m_i = _fwd_kernel_inner(\n acc,\n l_i,\n m_i,\n q,\n Q,\n k_block_end - 1,\n layout_col_ptr,\n layout_col_stride_h,\n layout_col_stride_m,\n k_ptrs,\n v_ptrs,\n off_h,\n offs_m,\n offs_n,\n offs_d,\n stride_kt,\n stride_vt,\n sm_scale,\n k_seqlen,\n past_len,\n True,\n BLOCK_M_LOADING,\n BLOCK_N,\n D_HEAD,\n EVEN_D,\n M_LT_N,\n )\n\n m_i += tl.math.log2(l_i)\n acc = acc / l_i[:, None]\n\n if EVEN_D:\n tl.store(\n Out + offs_m[:, None] * stride_ot + offs_d[None, :] * stride_od,\n acc,\n mask=offs_m[:, None] < q_seqlen,\n )\n else:\n tl.store(\n Out + offs_m[:, None] * stride_ot + offs_d[None, :] * stride_od,\n acc,\n mask=(offs_m[:, None] < q_seqlen) & (offs_d[None, :] < D_HEAD),\n )\n\ndef blocksparse_flash_attn_varlen_fwd(\n q,\n k,\n v,\n cu_seqlens_k,\n cu_seqlens_q,\n sm_scale,\n sparse_layout,\n *,\n block_size=64,\n q_block_size=None,\n max_seqlen=None):\n assert isinstance(sparse_layout, (list, tuple))\n\n _, n_heads, head_size = q.shape\n batch_size = cu_seqlens_k.size(0) - 1\n q_block_size = q_block_size or block_size\n\n assert q.dim() == k.dim() == v.dim() == 3\n assert q.size(1) % k.size(1) == 0\n assert q.size(2) == k.size(2)\n assert k.shape == v.shape\n assert cu_seqlens_k.dim() == 1\n\n q_k_ratio = q.size(1) // k.size(1)\n\n if cu_seqlens_q is None:\n if q.size(0) == batch_size:\n cu_seqlens_q = torch.arange(\n 0,\n batch_size + 1,\n dtype=cu_seqlens_k.dtype,\n device=cu_seqlens_k.device,\n )\n elif q.size(0) == k.size(0):\n cu_seqlens_q = cu_seqlens_k\n else:\n raise ValueError(\"cu_seqlens_q must be specified\\\n if it mix of prefilling and decoding.\")\n else:\n assert cu_seqlens_k.size(0) == cu_seqlens_q.size(0)\n\n q_lens = (cu_seqlens_q[1:] - cu_seqlens_q[:-1]).cpu()\n k_lens = (cu_seqlens_k[1:] - cu_seqlens_k[:-1]).cpu()\n\n assert torch.logical_or(q_lens == 1, k_lens == q_lens).all(), (\n \"length of q should either be 1 (decoding) or same as k (prefilling).\")\n\n if max_seqlen:\n assert k_lens.max() <= max_seqlen\n\n n_blocks = (q_lens + q_block_size - 1) // q_block_size\n\n q_batch_ids = torch.tensor(\n [i for i, n in enumerate(n_blocks) for _ in range(n)],\n dtype=cu_seqlens_q.dtype,\n device=cu_seqlens_q.device,\n )\n q_start_sids = torch.tensor(\n [i * q_block_size for n in n_blocks for i in range(n)],\n dtype=cu_seqlens_q.dtype,\n device=cu_seqlens_q.device,\n )\n\n out = q.new_empty(q.shape)\n cu_seqlens_q = cu_seqlens_q.contiguous()\n cu_seqlens_k = cu_seqlens_k.contiguous()\n\n layout_crow_indices, layout_col_indices = sparse_layout\n block_d = triton.next_power_of_2(head_size)\n\n decoding_only = (q_lens == 1).all().item()\n grid = (len(q_start_sids), n_heads, 1)\n\n _fwd_kernel_batch_inference[grid](\n q,\n k,\n v,\n out,\n sm_scale,\n cu_seqlens_q[:-1],\n cu_seqlens_q[1:],\n cu_seqlens_k[:-1],\n cu_seqlens_k[1:],\n q_batch_ids,\n q_start_sids,\n 0,\n *q.stride(),\n 0,\n *k.stride(),\n 0,\n *v.stride(),\n 0,\n *out.stride(),\n layout_crow_indices,\n layout_col_indices,\n *layout_crow_indices.stride(),\n *layout_col_indices.stride(),\n q_k_ratio,\n HAS_BATCH_DIM=False,\n D_HEAD=head_size,\n BLOCK_M=q_block_size,\n BLOCK_N=block_size,\n BLOCK_D=block_d,\n BLOCK_M_LOADING=(16 if decoding_only else\n q_block_size),\n EVEN_D=block_d == head_size,\n num_warps=1 if decoding_only else 4,\n num_stages=3)\n\n return out\n", - "description_1": "Use triton language to implement a blocksparse flash attention mechanism with variable length sequences. The kernel '_fwd_kernel_inner' takes 22 parameters including tensors for accumulation, scaling, and layout indices, and constants for block sizes and dimensions. The kernel '_fwd_kernel_batch_inference' takes 38 parameters including input and output tensors, scaling factors, sequence lengths, and layout indices, along with constants for block sizes and dimensions. The function 'blocksparse_flash_attn_varlen_fwd' orchestrates the process by preparing input data and launching the Triton kernel with appropriate grid dimensions.", - "description_2": "Use triton language to implement a blocksparse flash attention mechanism with variable length sequences, utilizing two kernels for computation and a function to manage data preparation and kernel execution.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\nfrom aphrodite.platforms import current_platform\n\n@triton.jit\ndef _fwd_kernel(\n Q,\n K,\n V,\n K_cache,\n V_cache,\n B_Loc,\n sm_scale,\n k_scale,\n v_scale,\n B_Start_Loc,\n B_Seqlen,\n B_Ctxlen,\n block_size,\n x,\n Out,\n stride_b_loc_b,\n stride_b_loc_s,\n stride_qbs,\n stride_qh,\n stride_qd,\n stride_kbs,\n stride_kh,\n stride_kd,\n stride_vbs,\n stride_vh,\n stride_vd,\n stride_obs,\n stride_oh,\n stride_od,\n stride_k_cache_bs,\n stride_k_cache_h,\n stride_k_cache_d,\n stride_k_cache_bl,\n stride_k_cache_x,\n stride_v_cache_bs,\n stride_v_cache_h,\n stride_v_cache_d,\n stride_v_cache_bl,\n num_queries_per_kv: int,\n BLOCK_M: tl.constexpr,\n BLOCK_DMODEL: tl.constexpr,\n BLOCK_DMODEL_PADDED: tl.constexpr,\n BLOCK_N: tl.constexpr,\n SLIDING_WINDOW: tl.constexpr,\n):\n # Kernel implementation\n ...\n\n@triton.jit\ndef _fwd_kernel_alibi(\n Q,\n K,\n V,\n K_cache,\n V_cache,\n B_Loc,\n sm_scale,\n k_scale,\n v_scale,\n B_Start_Loc,\n B_Seqlen,\n B_Ctxlen,\n Alibi_slopes,\n block_size,\n x,\n Out,\n stride_b_loc_b,\n stride_b_loc_s,\n stride_qbs,\n stride_qh,\n stride_qd,\n stride_kbs,\n stride_kh,\n stride_kd,\n stride_vbs,\n stride_vh,\n stride_vd,\n stride_obs,\n stride_oh,\n stride_od,\n stride_k_cache_bs,\n stride_k_cache_h,\n stride_k_cache_d,\n stride_k_cache_bl,\n stride_k_cache_x,\n stride_v_cache_bs,\n stride_v_cache_h,\n stride_v_cache_d,\n stride_v_cache_bl,\n num_queries_per_kv: int,\n BLOCK_M: tl.constexpr,\n BLOCK_DMODEL: tl.constexpr,\n BLOCK_DMODEL_PADDED: tl.constexpr,\n BLOCK_N: tl.constexpr,\n):\n # Kernel implementation\n ...\n\n@torch.inference_mode()\ndef context_attention_fwd(q,\n k,\n v,\n o,\n kv_cache_dtype: str,\n k_cache,\n v_cache,\n b_loc,\n b_start_loc,\n b_seq_len,\n b_ctx_len,\n max_input_len,\n k_scale: float = 1.0,\n v_scale: float = 1.0,\n alibi_slopes=None,\n sliding_window=None):\n # Function implementation\n ...\n", - "description_1": "Use triton language to implement forward kernels for context attention with optional alibi bias. The kernels handle query, key, and value matrices, caching, and output computation. The main function, context_attention_fwd, sets up the grid and calls the appropriate kernel based on the presence of alibi slopes.", - "description_2": "Use triton language to implement forward kernels for context attention with optional alibi bias, handling query, key, and value matrices, caching, and output computation.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef cdiv_fn(x, y):\n return (x + y - 1) // y\n\n@triton.jit\ndef max_fn(x, y):\n return tl.math.max(x, y)\n\n@triton.jit\ndef dropout_offsets(philox_seed, philox_offset, dropout_p, m, n, stride):\n ms = tl.arange(0, m)\n ns = tl.arange(0, n)\n return philox_offset + ms[:, None] * stride + ns[None, :]\n\n@triton.jit\ndef dropout_rng(philox_seed, philox_offset, dropout_p, m, n, stride):\n rng_offsets = dropout_offsets(philox_seed, philox_offset, dropout_p, m, n, stride).to(tl.uint32)\n return tl.rand(philox_seed, rng_offsets)\n\n@triton.jit\ndef dropout_mask(philox_seed, philox_offset, dropout_p, m, n, stride):\n rng_output = dropout_rng(philox_seed, philox_offset, dropout_p, m, n, stride)\n rng_keep = rng_output > dropout_p\n return rng_keep\n\n@triton.jit\ndef load_fn(block_ptr, first, second, pad):\n if first and second:\n tensor = tl.load(block_ptr, boundary_check=(0, 1), padding_option=pad)\n elif first:\n tensor = tl.load(block_ptr, boundary_check=(0, ), padding_option=pad)\n elif second:\n tensor = tl.load(block_ptr, boundary_check=(1, ), padding_option=pad)\n else:\n tensor = tl.load(block_ptr)\n return tensor\n\n@triton.jit\ndef _attn_fwd_inner(\n acc,\n l_i,\n m_i,\n q,\n K_block_ptr,\n V_block_ptr,\n start_m,\n actual_seqlen_k,\n dropout_p,\n philox_seed,\n batch_philox_offset,\n encoded_softmax_block_ptr,\n block_min,\n block_max,\n offs_n_causal,\n masked_blocks,\n n_extra_tokens,\n bias_ptr,\n IS_CAUSAL: tl.constexpr,\n BLOCK_M: tl.constexpr,\n BLOCK_DMODEL: tl.constexpr,\n BLOCK_N: tl.constexpr,\n OFFS_M: tl.constexpr,\n OFFS_N: tl.constexpr,\n PRE_LOAD_V: tl.constexpr,\n MASK_STEPS: tl.constexpr,\n ENABLE_DROPOUT: tl.constexpr,\n RETURN_ENCODED_SOFTMAX: tl.constexpr,\n PADDED_HEAD: tl.constexpr,\n):\n for start_n in range(block_min, block_max, BLOCK_N):\n k = load_fn(\n K_block_ptr,\n PADDED_HEAD,\n MASK_STEPS and (n_extra_tokens != 0),\n \"zero\",\n )\n if PRE_LOAD_V:\n v = load_fn(\n V_block_ptr,\n MASK_STEPS and (n_extra_tokens != 0),\n PADDED_HEAD,\n \"zero\",\n )\n qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)\n if MASK_STEPS:\n if (start_n + BLOCK_N == block_max) and (n_extra_tokens != 0):\n boundary_m = tl.full([BLOCK_M], actual_seqlen_k, dtype=tl.int32)\n size_n = start_n + OFFS_N[None, :]\n mask = size_n < boundary_m[:, None]\n qk = tl.where(mask, qk, float(\"-inf\"))\n if IS_CAUSAL:\n causal_boundary = start_n + offs_n_causal\n causal_mask = OFFS_M[:, None] >= causal_boundary[None, :]\n qk = tl.where(causal_mask, qk, float(\"-inf\"))\n qk += tl.dot(q, k)\n if bias_ptr is not None:\n bias = load_fn(bias_ptr, False, MASK_STEPS and (n_extra_tokens != 0), \"zero\")\n qk += bias * 1.44269504089\n m_ij = tl.maximum(m_i, tl.max(qk, 1))\n qk = qk - m_ij[:, None]\n p = tl.math.exp2(qk)\n\n l_ij = tl.sum(p, 1)\n if ENABLE_DROPOUT:\n philox_offset = (batch_philox_offset +\n start_m * BLOCK_M * actual_seqlen_k + start_n -\n BLOCK_N)\n keep = dropout_mask(\n philox_seed,\n philox_offset,\n dropout_p,\n BLOCK_M,\n BLOCK_N,\n actual_seqlen_k,\n )\n if RETURN_ENCODED_SOFTMAX:\n tl.store(\n encoded_softmax_block_ptr,\n tl.where(keep, p,\n -p).to(encoded_softmax_block_ptr.type.element_ty),\n )\n p = tl.where(keep, p, 0.0)\n elif RETURN_ENCODED_SOFTMAX:\n tl.store(\n encoded_softmax_block_ptr,\n p.to(encoded_softmax_block_ptr.type.element_ty),\n )\n alpha = tl.math.exp2(m_i - m_ij)\n acc = acc * alpha[:, None]\n if not PRE_LOAD_V:\n v = load_fn(\n V_block_ptr,\n MASK_STEPS and (n_extra_tokens != 0),\n PADDED_HEAD,\n \"zero\",\n )\n l_i = l_i * alpha + l_ij\n m_i = m_ij\n acc += tl.dot(p.to(V_block_ptr.type.element_ty), v)\n V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0))\n K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N))\n if bias_ptr is not None:\n bias_ptr = tl.advance(bias_ptr, (0, BLOCK_N))\n if RETURN_ENCODED_SOFTMAX:\n encoded_softmax_block_ptr = tl.advance(encoded_softmax_block_ptr,\n (0, BLOCK_N))\n return acc, l_i, m_i\n\n@triton.autotune(\n configs=[\n triton.Config(\n {\n \"BLOCK_M\": 256,\n \"BLOCK_N\": 64,\n \"waves_per_eu\": 2,\n \"PRE_LOAD_V\": False,\n },\n num_stages=1,\n num_warps=8,\n ),\n triton.Config(\n {\n \"BLOCK_M\": 128,\n \"BLOCK_N\": 128,\n \"waves_per_eu\": 2,\n \"PRE_LOAD_V\": False,\n },\n num_stages=1,\n num_warps=4,\n ),\n triton.Config(\n {\n \"BLOCK_M\": 256,\n \"BLOCK_N\": 128,\n \"waves_per_eu\": 2,\n \"PRE_LOAD_V\": False,\n },\n num_stages=1,\n num_warps=8,\n ),\n triton.Config(\n {\n \"BLOCK_M\": 128,\n \"BLOCK_N\": 64,\n \"waves_per_eu\": 1,\n \"PRE_LOAD_V\": False,\n },\n num_stages=1,\n num_warps=4,\n ),\n triton.Config(\n {\n \"BLOCK_M\": 128,\n \"BLOCK_N\": 64,\n \"waves_per_eu\": 3,\n \"PRE_LOAD_V\": True,\n },\n num_stages=1,\n num_warps=4,\n ),\n triton.Config(\n {\n \"BLOCK_M\": 128,\n \"BLOCK_N\": 64,\n \"waves_per_eu\": 3,\n \"PRE_LOAD_V\": False,\n },\n num_stages=1,\n num_warps=4,\n ),\n triton.Config(\n {\n \"BLOCK_M\": 64,\n \"BLOCK_N\": 64,\n \"waves_per_eu\": 4,\n \"PRE_LOAD_V\": False,\n },\n num_stages=1,\n num_warps=8,\n ),\n triton.Config(\n {\n \"BLOCK_M\": 32,\n \"BLOCK_N\": 32,\n \"waves_per_eu\": 4,\n \"PRE_LOAD_V\": False,\n },\n num_stages=1,\n num_warps=8,\n ),\n triton.Config(\n {\n \"BLOCK_M\": 16,\n \"BLOCK_N\": 16,\n \"waves_per_eu\": 1,\n \"PRE_LOAD_V\": False,\n },\n num_stages=1,\n num_warps=4,\n ),\n ],\n key=['IS_CAUSAL', 'dropout_p', 'BLOCK_DMODEL'],\n)\n@triton.jit\ndef attn_fwd(\n Q,\n K,\n V,\n bias,\n sm_scale,\n L,\n Out,\n stride_qz,\n stride_qh,\n stride_qm,\n stride_qk,\n stride_kz,\n stride_kh,\n stride_kn,\n stride_kk,\n stride_vz,\n stride_vh,\n stride_vk,\n stride_vn,\n stride_oz,\n stride_oh,\n stride_om,\n stride_on,\n stride_bz,\n stride_bh,\n stride_bm,\n stride_bn,\n cu_seqlens_q,\n cu_seqlens_k,\n dropout_p,\n philox_seed,\n philox_offset_base,\n encoded_softmax,\n HQ: tl.constexpr,\n HK: tl.constexpr,\n ACTUAL_BLOCK_DMODEL: tl.constexpr,\n MAX_SEQLENS_Q: tl.constexpr,\n MAX_SEQLENS_K: tl.constexpr,\n VARLEN: tl.constexpr,\n IS_CAUSAL: tl.constexpr,\n BLOCK_M: tl.constexpr,\n BLOCK_DMODEL: tl.constexpr,\n BLOCK_N: tl.constexpr,\n PRE_LOAD_V: tl.constexpr,\n BIAS_TYPE: tl.constexpr,\n ENABLE_DROPOUT: tl.constexpr,\n RETURN_ENCODED_SOFTMAX: tl.constexpr,\n):\n start_m = tl.program_id(0)\n off_h_q = tl.program_id(1)\n off_z = tl.program_id(2)\n offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)\n offs_n = tl.arange(0, BLOCK_N)\n if VARLEN:\n cu_seqlens_q_start = tl.load(cu_seqlens_q + off_z)\n cu_seqlens_q_end = tl.load(cu_seqlens_q + off_z + 1)\n seqlen_q = cu_seqlens_q_end - cu_seqlens_q_start\n if start_m * BLOCK_M > seqlen_q:\n return\n cu_seqlens_k_start = tl.load(cu_seqlens_k + off_z)\n cu_seqlens_k_end = tl.load(cu_seqlens_k + off_z + 1)\n seqlen_k = cu_seqlens_k_end - cu_seqlens_k_start\n else:\n cu_seqlens_q_start = 0\n cu_seqlens_k_start = 0\n seqlen_q = MAX_SEQLENS_Q\n seqlen_k = MAX_SEQLENS_K\n\n n_blocks = cdiv_fn(seqlen_k, BLOCK_N)\n if IS_CAUSAL:\n n_blocks_seqlen = cdiv_fn((start_m + 1) * BLOCK_M + seqlen_k - seqlen_q, BLOCK_N)\n n_blocks = min(n_blocks, n_blocks_seqlen)\n if n_blocks <= 0:\n o_offset = (off_z * stride_oz + cu_seqlens_q_start * stride_om +\n off_h_q * stride_oh)\n O_block_ptr = tl.make_block_ptr(\n base=Out + o_offset,\n shape=(seqlen_q, BLOCK_DMODEL),\n strides=(stride_om, stride_on),\n offsets=(start_m * BLOCK_M, 0),\n block_shape=(BLOCK_M, BLOCK_DMODEL),\n order=(1, 0),\n )\n acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=Out.type.element_ty)\n return\n\n GROUP_SIZE: tl.constexpr = HQ // HK\n off_h_k = off_h_q // GROUP_SIZE if GROUP_SIZE != 1 else off_h_q\n\n n_extra_tokens = 0\n if seqlen_k < BLOCK_N:\n n_extra_tokens = BLOCK_N - seqlen_k\n elif seqlen_k % BLOCK_N:\n n_extra_tokens = seqlen_k % BLOCK_N\n padded_head = ACTUAL_BLOCK_DMODEL != BLOCK_DMODEL\n\n q_offset = (off_z * stride_qz + off_h_q * stride_qh +\n cu_seqlens_q_start * stride_qm)\n Q_block_ptr = tl.make_block_ptr(\n base=Q + q_offset,\n shape=(seqlen_q, ACTUAL_BLOCK_DMODEL),\n strides=(stride_qm, stride_qk),\n offsets=(start_m * BLOCK_M, 0),\n block_shape=(BLOCK_M, BLOCK_DMODEL),\n order=(1, 0),\n )\n k_offset = (off_z * stride_kz + off_h_k * stride_kh +\n cu_seqlens_k_start * stride_kn)\n K_block_ptr = tl.make_block_ptr(\n base=K + k_offset,\n shape=(ACTUAL_BLOCK_DMODEL, seqlen_k),\n strides=(stride_kk, stride_kn),\n offsets=(0, 0),\n block_shape=(BLOCK_DMODEL, BLOCK_N),\n order=(0, 1),\n )\n v_offset = (off_z * stride_vz + off_h_k * stride_vh +\n cu_seqlens_k_start * stride_vk)\n V_block_ptr = tl.make_block_ptr(\n base=V + v_offset,\n shape=(seqlen_k, ACTUAL_BLOCK_DMODEL),\n strides=(stride_vk, stride_vn),\n offsets=(0, 0),\n block_shape=(BLOCK_N, BLOCK_DMODEL),\n order=(1, 0),\n )\n if BIAS_TYPE != 0:\n bias_ptr = tl.make_block_ptr(\n base=bias + off_h_q * stride_bh,\n shape=(seqlen_q, seqlen_k),\n strides=(stride_bm, stride_bn),\n offsets=(start_m * BLOCK_M, 0),\n block_shape=(BLOCK_M, BLOCK_N),\n order=(1, 0),\n )\n else:\n bias_ptr = None\n if ENABLE_DROPOUT:\n batch_philox_offset = philox_offset_base \\\n + (off_z * HQ + off_h_q) \\\n * seqlen_q * seqlen_k\n else:\n batch_philox_offset = 0\n\n if RETURN_ENCODED_SOFTMAX:\n encoded_softmax_block_ptr = tl.make_block_ptr(\n base=encoded_softmax + off_h_q * seqlen_q * seqlen_k,\n shape=(seqlen_q, seqlen_k),\n strides=(seqlen_k, 1),\n offsets=(start_m * BLOCK_M, 0),\n block_shape=(BLOCK_M, BLOCK_N),\n order=(1, 0),\n )\n else:\n encoded_softmax_block_ptr = 0\n\n m_i = tl.full([BLOCK_M], float(\"-inf\"), dtype=tl.float32)\n l_i = tl.full([BLOCK_M], 1.0, dtype=tl.float32)\n acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)\n qk_scale = sm_scale * 1.44269504089\n q = load_fn(Q_block_ptr, True, padded_head, \"zero\")\n q = (q * qk_scale).to(Q_block_ptr.type.element_ty)\n\n padded_block_k = n_extra_tokens != 0\n is_modulo_mn = not padded_block_k and (seqlen_q % BLOCK_M == 0)\n if IS_CAUSAL:\n masked_blocks = BLOCK_M // BLOCK_N + (not is_modulo_mn)\n else:\n masked_blocks = padded_block_k\n masked_blocks = min(masked_blocks, n_blocks)\n n_full_blocks = n_blocks - masked_blocks\n block_min = 0\n block_max = n_blocks * BLOCK_N\n\n if n_full_blocks > 0:\n block_max = (n_blocks - masked_blocks) * BLOCK_N\n acc, l_i, m_i = _attn_fwd_inner(\n acc,\n l_i,\n m_i,\n q,\n K_block_ptr,\n V_block_ptr,\n start_m,\n seqlen_k,\n dropout_p,\n philox_seed,\n batch_philox_offset,\n encoded_softmax_block_ptr,\n block_min,\n block_max,\n 0,\n 0,\n 0,\n bias_ptr,\n False,\n BLOCK_M,\n BLOCK_DMODEL,\n BLOCK_N,\n offs_m,\n offs_n,\n PRE_LOAD_V,\n False,\n ENABLE_DROPOUT,\n RETURN_ENCODED_SOFTMAX,\n padded_head,\n )\n block_min = block_max\n block_max = n_blocks * BLOCK_N\n\n tl.debug_barrier()\n if masked_blocks > 0:\n offs_n_causal = offs_n + (seqlen_q - seqlen_k) if IS_CAUSAL else 0\n K_block_ptr = tl.advance(K_block_ptr, (0, n_full_blocks * BLOCK_N))\n V_block_ptr = tl.advance(V_block_ptr, (n_full_blocks * BLOCK_N, 0))\n if bias_ptr is not None:\n bias_ptr = tl.advance(bias_ptr, (0, n_full_blocks * BLOCK_N))\n if RETURN_ENCODED_SOFTMAX:\n encoded_softmax_block_ptr = tl.advance(encoded_softmax_block_ptr, (0, n_full_blocks))\n acc, l_i, m_i = _attn_fwd_inner(\n acc,\n l_i,\n m_i,\n q,\n K_block_ptr,\n V_block_ptr,\n start_m,\n seqlen_k,\n dropout_p,\n philox_seed,\n batch_philox_offset,\n encoded_softmax_block_ptr,\n block_min,\n block_max,\n offs_n_causal,\n masked_blocks,\n n_extra_tokens,\n bias_ptr,\n IS_CAUSAL,\n BLOCK_M,\n BLOCK_DMODEL,\n BLOCK_N,\n offs_m,\n offs_n,\n PRE_LOAD_V,\n True,\n ENABLE_DROPOUT,\n RETURN_ENCODED_SOFTMAX,\n padded_head,\n )\n acc = acc / l_i[:, None]\n if ENABLE_DROPOUT:\n acc = acc / (1 - dropout_p)\n end_m_idx = (start_m + 1) * BLOCK_M\n start_m_idx = start_m * BLOCK_M\n causal_start_idx = seqlen_q - seqlen_k\n acc = acc.to(Out.type.element_ty)\n if IS_CAUSAL:\n if causal_start_idx > start_m_idx and causal_start_idx < end_m_idx:\n out_mask_boundary = tl.full((BLOCK_DMODEL, ),\n causal_start_idx,\n dtype=tl.int32)\n mask_m_offsets = start_m_idx + tl.arange(0, BLOCK_M)\n out_ptrs_mask = (mask_m_offsets[:, None] >=\n out_mask_boundary[None, :])\n z = 0.0\n acc = tl.where(out_ptrs_mask, acc, z.to(acc.type.element_ty))\n o_offset = (off_z * stride_oz + cu_seqlens_q_start * stride_om +\n off_h_q * stride_oh)\n O_block_ptr = tl.make_block_ptr(\n base=Out + o_offset,\n shape=(seqlen_q, ACTUAL_BLOCK_DMODEL),\n strides=(stride_om, stride_on),\n offsets=(start_m * BLOCK_M, 0),\n block_shape=(BLOCK_M, BLOCK_DMODEL),\n order=(1, 0),\n )\n tl.store(O_block_ptr, acc, boundary_check=(0, 1))\n\nclass _attention(torch.autograd.Function):\n\n @staticmethod\n def forward(\n ctx,\n q,\n k,\n v,\n o,\n cu_seqlens_q,\n cu_seqlens_k,\n max_seqlens_q,\n max_seqlens_k,\n causal=False,\n sm_scale=1.0,\n bias=None,\n ):\n if o is None:\n o = torch.empty_like(q, dtype=v.dtype)\n\n check_args(\n q,\n k,\n v,\n o,\n varlen=True,\n cu_seqlens_q=cu_seqlens_q,\n cu_seqlens_k=cu_seqlens_k,\n )\n if True: # varlen\n total_q, nheads_q, head_size = q.shape\n total_k, nheads_k, _ = k.shape\n batch = len(cu_seqlens_q) - 1\n q_strides = (0, q.stride(1), q.stride(0), q.stride(2))\n k_strides = (0, k.stride(1), k.stride(0), k.stride(2))\n v_strides = (0, v.stride(1), v.stride(0), v.stride(2))\n o_strides = (0, o.stride(1), o.stride(0), o.stride(2))\n else:\n batch, seqlen_q, nheads_q, head_size = q.shape\n _, seqlen_k, nheads_k, _ = k.shape\n q_strides = (q.stride(0), q.stride(2), q.stride(1), q.stride(3))\n k_strides = (k.stride(0), k.stride(2), k.stride(1), k.stride(3))\n v_strides = (v.stride(0), v.stride(2), v.stride(1), v.stride(3))\n o_strides = (o.stride(0), o.stride(2), o.stride(1), o.stride(3))\n\n unpadded_head_dims = {32, 64, 128, 256}\n if head_size not in unpadded_head_dims:\n padded_d_model = None\n for i in unpadded_head_dims:\n if i > head_size:\n padded_d_model = i\n break\n assert padded_d_model is not None\n else:\n padded_d_model = head_size\n\n grid = lambda META: (\n triton.cdiv(max_seqlens_q, META[\"BLOCK_M\"]),\n nheads_q,\n batch,\n )\n\n encoded_softmax = None\n\n philox_seed = 0x1BF52\n philox_offset = 0x1D4B42\n\n if bias is not None:\n bias_strides = (\n bias.stride(0),\n bias.stride(1),\n bias.stride(2),\n bias.stride(3),\n )\n else:\n bias_strides = (0, 0, 0, 0)\n\n attn_fwd[grid](\n q,\n k,\n v,\n bias,\n sm_scale,\n None,\n o,\n *q_strides,\n *k_strides,\n *v_strides,\n *o_strides,\n *bias_strides,\n cu_seqlens_q,\n cu_seqlens_k,\n dropout_p=0.0,\n philox_seed=philox_seed,\n philox_offset_base=philox_offset,\n encoded_softmax=encoded_softmax,\n HQ=nheads_q,\n HK=nheads_k,\n ACTUAL_BLOCK_DMODEL=head_size,\n MAX_SEQLENS_Q=max_seqlens_q,\n MAX_SEQLENS_K=max_seqlens_k,\n IS_CAUSAL=causal,\n VARLEN=True,\n BLOCK_DMODEL=padded_d_model,\n BIAS_TYPE=0 if bias is None else 1,\n ENABLE_DROPOUT=False,\n RETURN_ENCODED_SOFTMAX=False,\n )\n\n ctx.grid = grid\n ctx.sm_scale = sm_scale\n ctx.BLOCK_DMODEL = head_size\n ctx.causal = causal\n ctx.dropout_p = 0.0\n ctx.philox_seed = philox_seed\n ctx.philox_offset = philox_offset\n ctx.encoded_softmax = encoded_softmax\n ctx.return_encoded_softmax = False\n return o, encoded_softmax\n\ntriton_attention = _attention.apply\n", - "description_1": "Use triton language to implement fused attention kernels and associated functions for calculating dropout, loading data, and masked matrix operations in Flash Attention algorithm.", - "description_2": "Use triton language to create attention kernels with dropout and causal masking functionalities.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef _bgmv_expand_slice_kernel(\n input_ptr,\n lora_ptr,\n out_ptr,\n N,\n K,\n lora_indices,\n xm_stride,\n xk_stride,\n l0_stride,\n lora_k_stride,\n lora_n_stride,\n cm_stride,\n cn_stride,\n slice_offset,\n BLOCK_N: tl.constexpr,\n BLOCK_K: tl.constexpr,\n SPLIT_N: tl.constexpr,\n EVEN_K: tl.constexpr,\n ADD_INPUTS: tl.constexpr,\n CAST_TYPE: tl.constexpr,\n):\n pid_sn = tl.program_id(axis=0)\n cur_batch = tl.program_id(axis=1)\n lora_index = tl.load(lora_indices + cur_batch)\n if lora_index == -1:\n return\n offset_k = tl.arange(0, BLOCK_K)\n offset_n = tl.arange(0, BLOCK_N)\n if EVEN_K:\n tiled_a = tl.load(input_ptr + cur_batch * xm_stride + offset_k * xk_stride)\n else:\n tiled_a = tl.load(\n input_ptr + cur_batch * xm_stride + offset_k * xk_stride,\n mask=offset_k < K,\n other=0,\n )\n split_n_length = tl.cdiv(N, SPLIT_N)\n if CAST_TYPE:\n tiled_a = tiled_a.to(lora_ptr.dtype.element_ty)\n b_ptr = (lora_ptr + l0_stride * lora_index + pid_sn * split_n_length * lora_k_stride)\n c_ptr = (out_ptr + cur_batch * cm_stride + pid_sn * split_n_length + slice_offset * cn_stride)\n\n for n in range(0, split_n_length, BLOCK_N):\n current_n = n + offset_n\n b_ptr_mask = (current_n[:, None] < split_n_length) & (offset_k[None, :] < K)\n c_mask = current_n < split_n_length\n tiled_b = tl.load(\n b_ptr + current_n[:, None] * lora_k_stride + offset_k[None, :] * lora_n_stride,\n mask=b_ptr_mask,\n other=0.0,\n )\n\n if ADD_INPUTS:\n tiled_out = tl.load(c_ptr + current_n * cn_stride, mask=c_mask)\n accumulator = tl.sum(tiled_a * tiled_b, 1) + tiled_out\n else:\n accumulator = tl.sum(tiled_a * tiled_b, 1)\n\n tl.store(c_ptr + current_n * cn_stride, accumulator, mask=c_mask)\n\n\n@torch.inference_mode()\ndef _bgmv_expand_slice(\n inputs: torch.Tensor,\n lora_b_weights: torch.Tensor,\n output_tensor: torch.Tensor,\n lora_indices_tensor: torch.Tensor,\n slice_offset: int,\n slice_size: int,\n add_inputs: bool = True,\n) -> None:\n assert inputs.dtype in [torch.float16, torch.bfloat16, torch.float32]\n assert lora_b_weights.dtype in [torch.float16, torch.bfloat16]\n assert inputs.size(1) == lora_b_weights.size(-1)\n assert slice_size == lora_b_weights.size(-2)\n assert inputs.is_contiguous()\n assert output_tensor.is_contiguous()\n\n if lora_b_weights.ndim == 4:\n assert lora_b_weights.size(1) == 1\n lora_b_weights = lora_b_weights.squeeze(dim=1)\n else:\n assert lora_b_weights.ndim == 3\n\n assert lora_b_weights.is_contiguous()\n\n N, K = lora_b_weights.shape[-2:]\n BLOCK_K = triton.next_power_of_2(K)\n EVEN_K = K % BLOCK_K == 0\n ADD_INPUTS = add_inputs\n CAST_TYPE = False\n if inputs.dtype == torch.float32 and lora_b_weights.dtype in [torch.float16, torch.bfloat16]:\n CAST_TYPE = True\n\n batches = lora_indices_tensor.size(0)\n\n config = get_lora_op_configs(\"expand\", batches, N)\n\n grid = lambda META: (META[\"SPLIT_N\"], batches)\n _bgmv_expand_slice_kernel[grid](\n inputs,\n lora_b_weights,\n output_tensor,\n N,\n K,\n lora_indices_tensor,\n inputs.stride(0),\n inputs.stride(1),\n lora_b_weights.stride(0),\n lora_b_weights.stride(1),\n lora_b_weights.stride(2),\n output_tensor.stride(0),\n output_tensor.stride(1),\n slice_offset,\n BLOCK_K=BLOCK_K,\n EVEN_K=EVEN_K,\n ADD_INPUTS=ADD_INPUTS,\n CAST_TYPE=CAST_TYPE,\n **config,\n )\n return\n", - "description_1": "Use triton language to create a kernel function _bgmv_expand_slice_kernel with 19 parameters for optimized matrix-vector multiplications using LoRA weights and a callable function _bgmv_expand_slice with 7 parameters to manage the data and computation process in a PyTorch environment.", - "description_2": "Use triton language to implement a Grouped GEMV kernel with optimizations for handling LoRA indices and slicing across batches, and provide a corresponding Python interface function for setup and execution.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\nfrom .utils import get_lora_op_configs\n\n@triton.jit\ndef _bgmv_shrink_kernel(\n input_ptr,\n lora_ptr,\n out_ptr,\n N,\n K,\n lora_indices,\n scaling,\n xm_stride,\n xk_stride,\n l0_stride,\n lora_k_stride,\n lora_n_stride,\n cm_stride,\n cn_stride,\n BLOCK_N: tl.constexpr,\n BLOCK_K: tl.constexpr,\n SPLIT_K: tl.constexpr,\n):\n \"\"\"\n GroupGEMV, additionally, introducing SPLIT-K can improve large hidden_size's\n performance\n \"\"\"\n pid_sk = tl.program_id(axis=0)\n cur_batch = tl.program_id(axis=1)\n lora_index = tl.load(lora_indices + cur_batch)\n if lora_index == -1:\n return\n\n offset_n = tl.arange(0, BLOCK_N)\n offset_k = tl.arange(0, BLOCK_K) + pid_sk * BLOCK_K\n a_ptr = input_ptr + cur_batch * xm_stride\n b_ptr = lora_ptr + l0_stride * lora_index\n accumulator = tl.zeros((BLOCK_N, ), dtype=tl.float32)\n for k in range(0, K, BLOCK_K * SPLIT_K):\n current_k = k + offset_k\n current_k_c = tl.max_contiguous(current_k, BLOCK_K)\n tiled_a = tl.load(\n a_ptr + current_k_c,\n mask=current_k < K,\n other=0.0,\n ) # [BLOCK_K]\n b_ptr_mask = (offset_n[:, None] < N) & (current_k[None, :] < K)\n\n tiled_b = tl.load(\n b_ptr + offset_n[:, None] * lora_k_stride +\n current_k[None, :] * lora_n_stride,\n mask=b_ptr_mask,\n other=0.0,\n ) # [BLOCK_N,BLOCK_K]\n\n accumulator += tl.sum(tiled_a * tiled_b, 1)\n accumulator *= scaling\n offset_cn = tl.arange(0, BLOCK_N)\n c_ptr = out_ptr + cur_batch * cm_stride + offset_cn * cn_stride\n c_mask = offset_cn < N\n if SPLIT_K == 1:\n tl.store(c_ptr, accumulator, mask=c_mask)\n else:\n tl.atomic_add(c_ptr, accumulator, mask=c_mask)\n\n\n@torch.inference_mode()\ndef _bgmv_shrink(\n inputs: torch.Tensor,\n lora_a_weights: torch.Tensor,\n output_tensor: torch.Tensor,\n lora_indices_tensor: torch.Tensor,\n scaling: float = 1.0,\n) -> None:\n \"\"\"\n Args:\n inputs (torch.Tensor): input tensor\n lora_a_weights (torch.Tensor): lora'a weight\n output_tensor (torch.Tensor): output tensor\n lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index\n corresponding to each batch. An index of -1 means no lora should be\n applied.\n batches (int): batch size\n scaling (float): Scaling factor.\n \"\"\"\n assert inputs.dtype == lora_a_weights.dtype\n assert inputs.dtype in [torch.float16, torch.bfloat16]\n assert lora_a_weights.dtype in [\n torch.float16,\n torch.bfloat16,\n ]\n assert inputs.size(1) == lora_a_weights.size(-1)\n assert inputs.is_contiguous()\n\n if lora_a_weights.ndim == 4: # shape:(lora_num,1,rank, size)\n assert lora_a_weights.size(1) == 1\n lora_a_weights = lora_a_weights.squeeze(dim=1)\n else:\n assert lora_a_weights.ndim == 3 # shape:(lora_num,rank, size)\n assert lora_a_weights.is_contiguous()\n assert output_tensor.is_contiguous()\n # TODO tuning this config\n batches = lora_indices_tensor.size(0)\n N, K = lora_a_weights.shape[-2:] # K=hidden_size,N=rank\n BLOCK_N = triton.next_power_of_2(N)\n # First try to load optimal config from the file\n config = get_lora_op_configs(\"bgmv_shrink\", batches, K)\n\n grid = lambda META: (\n META[\"SPLIT_K\"],\n batches,\n )\n _bgmv_shrink_kernel[grid](\n inputs,\n lora_a_weights,\n output_tensor,\n N,\n K,\n lora_indices_tensor,\n scaling,\n inputs.stride(0),\n inputs.stride(1),\n lora_a_weights.stride(0),\n lora_a_weights.stride(1),\n lora_a_weights.stride(2),\n output_tensor.stride(0),\n output_tensor.stride(1),\n BLOCK_N=BLOCK_N,\n **config,\n )\n return\n", - "description_1": "Use triton language to implement a kernel function '_bgmv_shrink_kernel' with 15 parameters for performing a batched generalized matrix-vector multiplication (GroupGEMV) with optional LoRA (Low-Rank Adaptation) weights. The kernel uses block-wise operations and supports split-K optimization for large hidden sizes. The function '_bgmv_shrink' is a wrapper that prepares the input tensors and launches the Triton kernel with appropriate configurations.", - "description_2": "Use triton language to create a GroupGEMV kernel with LoRA support and a wrapper function to handle input preparation and kernel execution.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n\n@triton.jit\ndef _sgmv_expand_kernel(\n input_ptr,\n lora_ptr,\n out_ptr,\n N,\n K,\n b_seq_start_loc,\n seq_lens,\n lora_indices,\n xm_stride,\n xk_stride, # 1\n l0_stride, # hidden_size*max_rank\n lora_k_stride,\n lora_n_stride,\n cm_stride,\n cn_stride,\n BLOCK_M: tl.constexpr,\n BLOCK_N: tl.constexpr,\n BLOCK_K: tl.constexpr,\n EVEN_K: tl.constexpr,\n ADD_INPUTS: tl.constexpr,\n CAST_TYPE: tl.constexpr,\n):\n \"\"\"\n The sgmv's expand triton kernel is based on GroupGEMM.\n \"\"\"\n pid = tl.program_id(axis=0)\n cur_batch = tl.program_id(axis=1)\n cta_n_num = tl.cdiv(N, BLOCK_N)\n pid_m = pid // cta_n_num\n pid_n = pid % cta_n_num\n M = tl.load(seq_lens + cur_batch)\n if pid_m * BLOCK_M > M:\n return\n lora_index = tl.load(lora_indices + cur_batch)\n if lora_index == -1:\n return\n cur_seq_start = tl.load(b_seq_start_loc + cur_batch)\n offset_m = tl.arange(0, BLOCK_M) + pid_m * BLOCK_M\n offset_n = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N\n offset_k = tl.arange(0, BLOCK_K)\n ram = tl.max_contiguous(tl.multiple_of(offset_m % M, BLOCK_M), BLOCK_M)\n rbn = tl.max_contiguous(tl.multiple_of(offset_n % N, BLOCK_N), BLOCK_N)\n\n a_ptr = (input_ptr + cur_seq_start * xm_stride + ram[:, None] * xm_stride +\n offset_k[None, :] * xk_stride, )\n b_ptr = (lora_ptr + l0_stride * lora_index +\n offset_k[:, None] * lora_n_stride + rbn[None, :] * lora_k_stride)\n accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)\n for k in range(tl.cdiv(K, BLOCK_K)):\n if EVEN_K:\n tiled_a = tl.load(a_ptr)\n tiled_b = tl.load(b_ptr)\n else:\n tiled_a = tl.load(a_ptr,\n mask=offset_k[None, :] < K - k * BLOCK_K,\n other=0)\n tiled_b = tl.load(b_ptr,\n mask=offset_k[:, None] < K - k * BLOCK_K,\n other=0)\n if CAST_TYPE:\n tiled_a = tiled_a.to(lora_ptr.dtype.element_ty)\n accumulator += tl.dot(\n tiled_a,\n tiled_b,\n )\n a_ptr += BLOCK_K * xk_stride\n b_ptr += BLOCK_K * lora_n_stride\n tiled_c = accumulator.to(lora_ptr.dtype.element_ty)\n offset_cm = cur_seq_start + tl.arange(0, BLOCK_M) + pid_m * BLOCK_M\n offset_cn = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N\n c_ptr = (out_ptr + offset_cm[:, None] * cm_stride +\n offset_cn[None, :] * cn_stride)\n M = tl.load(seq_lens + cur_batch)\n c_mask = (offset_cm[:, None] <\n (cur_seq_start + M)) & (offset_cn[None, :] < N)\n if ADD_INPUTS:\n tiled_out = tl.load(c_ptr, mask=c_mask)\n tiled_c += tiled_out\n tl.store(c_ptr, tiled_c, mask=c_mask)\n\n\n@torch.inference_mode()\ndef _sgmv_expand(\n inputs: torch.Tensor,\n lora_b_weights: torch.Tensor,\n output_tensor: torch.Tensor,\n b_seq_start_loc: torch.Tensor,\n seq_len_tensor: torch.Tensor,\n lora_indices_tensor: torch.Tensor,\n batches: int,\n max_seq_length: int,\n add_inputs: bool = False,\n) -> None:\n \"\"\"\n Args:\n inputs (torch.Tensor): input tensor\n lora_b_weights (torch.Tensor): lora'a weight\n output_tensor (torch.Tensor): output tensor\n b_seq_start_loc (torch.Tensor): (batch_size,). The cumulative\n sequence lengths of the sequences in the batch, used to index\n into sequence. E.g.,if the sequence length is [4, 6], it is\n [0, 4, 10].\n seq_len_tensor (torch.Tensor): (batch_size,). record the sequence\n length of the sequences in the batch\n lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index\n corresponding to each batch. An index of -1 means no lora should be\n applied.\n batches (int): batch size\n max_seq_length (int): The max sequence lengths of the sequences\n in the batch\n add_inputs (bool, optional): Defaults to False. adds the final lora \n results to the output.\n \"\"\"\n\n assert inputs.dtype in [torch.float16, torch.bfloat16, torch.float32]\n assert lora_b_weights.dtype in [\n torch.float16,\n torch.bfloat16,\n ]\n assert inputs.size(1) == lora_b_weights.size(-1)\n assert b_seq_start_loc.size(0) == batches\n assert lora_indices_tensor.size(0) == batches\n assert inputs.is_contiguous()\n assert output_tensor.is_contiguous()\n\n if lora_b_weights.ndim == 4: # shape:(lora_num,1,size,rank)\n assert lora_b_weights.size(1) == 1\n lora_b_weights = lora_b_weights.squeeze(dim=1)\n else:\n assert lora_b_weights.ndim == 3 # shape:(lora_num,size,rank)\n\n assert lora_b_weights.is_contiguous()\n\n N, K = lora_b_weights.shape[-2:] # K= rank,N=hidden_size\n BLOCK_M = 32\n BLOCK_N = 32\n BLOCK_K = 16\n EVEN_K = K % BLOCK_K == 0\n ADD_INPUTS = add_inputs\n CAST_TYPE = False\n if inputs.dtype == torch.float32 and lora_b_weights.dtype in [\n torch.float16,\n torch.bfloat16,\n ]:\n CAST_TYPE = True\n grid = (\n triton.cdiv(max_seq_length, BLOCK_M) * triton.cdiv(N, BLOCK_N),\n batches,\n )\n _sgmv_expand_kernel[grid](\n inputs,\n lora_b_weights,\n output_tensor,\n N,\n K,\n b_seq_start_loc,\n seq_len_tensor,\n lora_indices_tensor,\n inputs.stride(0),\n inputs.stride(1),\n lora_b_weights.stride(0),\n lora_b_weights.stride(1),\n lora_b_weights.stride(2),\n output_tensor.stride(0),\n output_tensor.stride(1),\n BLOCK_M,\n BLOCK_N,\n BLOCK_K,\n EVEN_K,\n ADD_INPUTS,\n CAST_TYPE,\n )\n return\n\n\nsgmv_expand = torch.library.custom_op(\"lora::sgmv_expand\",\n _sgmv_expand,\n mutates_args=[\"output_tensor\"])\n", - "description_1": "Use triton language to implement a kernel function, _sgmv_expand_kernel, which performs an operation based on GroupGEMM with 23 parameters including pointers for input, lora, and output, dimensions N, K, block sizes BLOCK_M, BLOCK_N, BLOCK_K, configuration flags EVEN_K, ADD_INPUTS, CAST_TYPE, and various stride and index tensors. This kernel is invoked by the wrapper function _sgmv_expand, which takes in 9 parameters: inputs, lora_b_weights, output_tensor, b_seq_start_loc, seq_len_tensor, lora_indices_tensor, batches, max_seq_length, and add_inputs, to set up the grid and launch the Triton kernel appropriately.", - "description_2": "Use triton language to implement a GroupGEMM-based kernel with 23 parameters, invoked by a wrapper with 9 parameters to set up and launch the Triton kernel.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef _sgmv_shrink_kernel(\n input_ptr,\n lora_ptr,\n out_ptr,\n N,\n K,\n b_seq_start_loc,\n seq_lens,\n lora_indices,\n scaling,\n xm_stride, # hidden_size\n xk_stride, # 1\n l0_stride, # hidden_size*max_rank\n lora_k_stride,\n lora_n_stride,\n cm_stride,\n cn_stride,\n BLOCK_M: tl.constexpr,\n BLOCK_N: tl.constexpr,\n BLOCK_K: tl.constexpr,\n EVEN_K: tl.constexpr,\n SPLIT_K: tl.constexpr,\n):\n \"\"\"\n The sgmv's shrink triton kernel is based on GroupGEMM+SPLIT-K.\n The GEMM of Multi-LoRA can be considered as GroupGEMM. Additionally,\n introducing SPLIT-K can improve performance\n \"\"\"\n pid = tl.program_id(axis=0)\n pid_sk = tl.program_id(axis=1)\n cur_batch = tl.program_id(axis=2)\n cta_n_num = tl.cdiv(N, BLOCK_N)\n pid_m = pid // cta_n_num\n pid_n = pid % cta_n_num\n\n M = tl.load(seq_lens + cur_batch)\n if pid_m * BLOCK_M > M:\n return\n lora_index = tl.load(lora_indices + cur_batch)\n if lora_index == -1:\n return\n cur_seq_start = tl.load(b_seq_start_loc + cur_batch)\n offset_m = tl.arange(0, BLOCK_M) + pid_m * BLOCK_M\n offset_n = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N\n offset_k = pid_sk * BLOCK_K + tl.arange(0, BLOCK_K)\n\n ram = tl.max_contiguous(tl.multiple_of(offset_m % M, BLOCK_M), BLOCK_M)\n rbn = tl.max_contiguous(tl.multiple_of(offset_n % N, BLOCK_N), BLOCK_N)\n\n a_ptr = (input_ptr + cur_seq_start * xm_stride + ram[:, None] * xm_stride +\n offset_k[None, :] * xk_stride)\n b_ptr = (lora_ptr + l0_stride * lora_index + rbn[None, :] * lora_k_stride +\n offset_k[:, None] * lora_n_stride)\n\n accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)\n for k in range(0, tl.cdiv(K, BLOCK_K * SPLIT_K)):\n if EVEN_K:\n tiled_a = tl.load(a_ptr)\n tiled_b = tl.load(b_ptr)\n else:\n k_remaining = K - k * (BLOCK_K * SPLIT_K)\n tiled_a = tl.load(a_ptr,\n mask=offset_k[None, :] < k_remaining,\n other=0.0)\n tiled_b = tl.load(b_ptr,\n mask=offset_k[:, None] < k_remaining,\n other=0.0)\n accumulator += tl.dot(tiled_a, tiled_b)\n\n a_ptr += BLOCK_K * SPLIT_K * xk_stride\n b_ptr += BLOCK_K * SPLIT_K * lora_n_stride\n offset_cm = cur_seq_start + tl.arange(0, BLOCK_M) + pid_m * BLOCK_M\n\n offset_cn = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N\n c_ptr = (out_ptr + offset_cm[:, None] * cm_stride +\n offset_cn[None, :] * cn_stride)\n c_mask = (offset_cm[:, None] <\n (cur_seq_start + M)) & (offset_cn[None, :] < N)\n accumulator *= scaling\n # handles write-back with reduction-splitting\n if SPLIT_K == 1:\n tl.store(c_ptr, accumulator, mask=c_mask)\n else:\n tl.atomic_add(c_ptr, accumulator, mask=c_mask)\n\n\n@torch.inference_mode()\ndef _sgmv_shrink(\n inputs: torch.Tensor,\n lora_a_weights: torch.Tensor,\n output_tensor: torch.Tensor,\n b_seq_start_loc: torch.Tensor,\n seq_len_tensor: torch.Tensor,\n lora_indices_tensor: torch.Tensor,\n batches: int,\n max_seq_length: int,\n scaling: float,\n) -> None:\n \"\"\"\n\n Args:\n inputs (torch.Tensor): input tensor\n lora_a_weights (torch.Tensor): lora'a weight\n output_tensor (torch.Tensor): output tensor\n b_seq_start_loc (torch.Tensor): (batch_size,). The cumulative\n sequence lengths of the sequences in the batch, used to index\n into sequence. E.g.,if the sequence length is [4, 6], it is\n [0, 4].\n seq_len_tensor (torch.Tensor): (batch_size,). record the sequence\n length of the sequences in the batch\n lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index\n corresponding to each batch. An index of -1 means no lora should be\n applied.\n batches (int): batch size\n max_seq_length (int): The max sequence lengths of the sequences\n in the batch\n scaling (float): Scaling factor.\n \"\"\"\n assert inputs.dtype == lora_a_weights.dtype\n assert inputs.dtype in [torch.float16, torch.bfloat16]\n assert lora_a_weights.dtype in [\n torch.float16,\n torch.bfloat16,\n ]\n assert inputs.size(1) == lora_a_weights.size(-1)\n assert b_seq_start_loc.size(0) == batches\n assert lora_indices_tensor.size(0) == batches\n assert inputs.is_contiguous()\n\n if lora_a_weights.ndim == 4: # shape:(lora_num,1,rank, size)\n assert lora_a_weights.size(1) == 1\n lora_a_weights = lora_a_weights.squeeze(dim=1)\n else:\n assert lora_a_weights.ndim == 3 # shape:(lora_num,rank, size)\n assert lora_a_weights.is_contiguous()\n assert output_tensor.is_contiguous()\n # TODO tuning this config\n N, K = lora_a_weights.shape[-2:] # K=hidden_size,N=rank\n BLOCK_M = 32\n BLOCK_N = 16\n BLOCK_K = 32\n SPLIT_K = 8\n EVEN_K = K % (BLOCK_K * SPLIT_K) == 0\n grid = (\n triton.cdiv(max_seq_length, BLOCK_M) * triton.cdiv(N, BLOCK_N),\n SPLIT_K,\n batches,\n )\n\n _sgmv_shrink_kernel[grid](\n inputs,\n lora_a_weights,\n output_tensor,\n N,\n K,\n b_seq_start_loc,\n seq_len_tensor,\n lora_indices_tensor,\n scaling,\n inputs.stride(0),\n inputs.stride(1),\n lora_a_weights.stride(0),\n lora_a_weights.stride(1),\n lora_a_weights.stride(2),\n output_tensor.stride(0),\n output_tensor.stride(1),\n BLOCK_M,\n BLOCK_N,\n BLOCK_K,\n EVEN_K,\n SPLIT_K,\n )\n return\n", - "description_1": "Use triton language to implement a kernel function called _sgmv_shrink_kernel, which performs a specialized matrix-vector multiplication with a focus on handling LoRA (low-rank adaptation) indices. The kernel takes 20 parameters, including pointers to input and output tensors, as well as constants for block sizes and reduction settings. This kernel is then called by a Python function, _sgmv_shrink, that prepares the grid and block configuration, validates input types and shapes, and manages tensor strides.", - "description_2": "Use triton language to create a kernel that performs matrix-vector multiplication with LoRA indices handling, and call it within a Python function for tensor preparation and validation.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef fused_moe_kernel(\n a_ptr, b_ptr, c_ptr, a_scale_ptr, b_scale_ptr, topk_weights_ptr,\n sorted_token_ids_ptr, expert_ids_ptr, num_tokens_post_padded_ptr,\n N, K, EM, num_valid_tokens, stride_am, stride_ak, stride_be,\n stride_bk, stride_bn, stride_cm, stride_cn, stride_bse, stride_bsn,\n BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr,\n BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr,\n MUL_ROUTED_WEIGHT: tl.constexpr, top_k: tl.constexpr,\n compute_type: tl.constexpr, use_fp8_w8a8: tl.constexpr,\n use_int8_w8a16: tl.constexpr):\n \"\"\"\n Implements the fused computation for a Mixture of Experts (MOE) using\n token and expert matrices.\n \"\"\"\n pid = tl.program_id(axis=0)\n num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M)\n num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)\n num_pid_in_group = GROUP_SIZE_M * num_pid_n\n group_id = pid // num_pid_in_group\n first_pid_m = group_id * GROUP_SIZE_M\n group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)\n pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)\n pid_n = (pid % num_pid_in_group) // group_size_m\n\n num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr)\n if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded:\n return\n offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_token = tl.load(sorted_token_ids_ptr + offs_token_id)\n token_mask = offs_token < num_valid_tokens\n\n offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N\n offs_k = tl.arange(0, BLOCK_SIZE_K)\n a_ptrs = a_ptr + (offs_token[:, None] // top_k * stride_am +\n offs_k[None, :] * stride_ak)\n\n off_experts = tl.load(expert_ids_ptr + pid_m)\n b_ptrs = b_ptr + off_experts * stride_be + (offs_k[:, None] * stride_bk +\n offs_bn[None, :] * stride_bn)\n if use_int8_w8a16:\n b_scale_ptrs = b_scale_ptr + off_experts * stride_bse + offs_bn[\n None, :] * stride_bsn\n b_scale = tl.load(b_scale_ptrs)\n\n if use_fp8_w8a8:\n a_scale = tl.load(a_scale_ptr)\n b_scale = tl.load(b_scale_ptr + off_experts)\n\n accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n\n for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):\n a = tl.load(a_ptrs,\n mask=token_mask[:, None] &\n (offs_k[None, :] < K - k * BLOCK_SIZE_K),\n other=0.0)\n b = tl.load(b_ptrs,\n mask=offs_k[:, None] < K - k * BLOCK_SIZE_K,\n other=0.0)\n if use_int8_w8a16:\n accumulator = tl.dot(a, b.to(compute_type), acc=accumulator)\n elif use_fp8_w8a8:\n accumulator = tl.dot(a, b, acc=accumulator)\n else:\n accumulator += tl.dot(a, b)\n a_ptrs += BLOCK_SIZE_K * stride_ak\n b_ptrs += BLOCK_SIZE_K * stride_bk\n\n if MUL_ROUTED_WEIGHT:\n moe_weight = tl.load(topk_weights_ptr + offs_token,\n mask=token_mask,\n other=0)\n accumulator = accumulator * moe_weight[:, None]\n if use_int8_w8a16:\n accumulator = (accumulator * b_scale).to(compute_type)\n elif use_fp8_w8a8:\n accumulator = (accumulator * a_scale * b_scale).to(compute_type)\n else:\n accumulator = accumulator.to(compute_type)\n\n offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[\n None, :]\n c_mask = token_mask[:, None] & (offs_cn[None, :] < N)\n tl.store(c_ptrs, accumulator, mask=c_mask)\n\n\ndef invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,\n A_scale: Optional[torch.Tensor],\n B_scale: Optional[torch.Tensor],\n topk_weights: torch.Tensor, topk_ids: torch.Tensor,\n sorted_token_ids: torch.Tensor,\n expert_ids: torch.Tensor,\n num_tokens_post_padded: torch.Tensor,\n mul_routed_weight: bool, top_k: int,\n config: Dict[str, Any], compute_type: tl.dtype,\n use_fp8_w8a8: bool, use_int8_w8a16: bool) -> None:\n assert topk_weights.stride(1) == 1\n assert sorted_token_ids.stride(0) == 1\n\n grid = lambda META: (triton.cdiv(sorted_token_ids.shape[0], META[\n 'BLOCK_SIZE_M']) * triton.cdiv(B.shape[1], META['BLOCK_SIZE_N']), )\n\n fused_moe_kernel[grid](\n A,\n B,\n C,\n A_scale,\n B_scale,\n topk_weights,\n sorted_token_ids,\n expert_ids,\n num_tokens_post_padded,\n B.shape[1],\n B.shape[2],\n sorted_token_ids.shape[0],\n topk_ids.numel(),\n A.stride(0),\n A.stride(1),\n B.stride(0),\n B.stride(2),\n B.stride(1),\n C.stride(1),\n C.stride(2),\n B_scale.stride(0) if B_scale is not None and use_int8_w8a16 else 0,\n B_scale.stride(1) if B_scale is not None and use_int8_w8a16 else 0,\n MUL_ROUTED_WEIGHT=mul_routed_weight,\n top_k=top_k,\n compute_type=compute_type,\n use_fp8_w8a8=use_fp8_w8a8,\n use_int8_w8a16=use_int8_w8a16,\n **config,\n )\n", - "description_1": "Use triton language to implement a fused MoE kernel. The kernel takes 22 parameters, including pointers to matrices, matrix dimensions, stride variables, and meta-parameters. It computes a mixture of experts using token and expert matrices, performs block matrix multiplication, and writes back the results. The kernel is invoked through the 'invoke_fused_moe_kernel' function, which prepares grid and meta-parameters before calling the kernel.", - "description_2": "Use triton language to implement and invoke a fused MoE kernel for block matrix multiplication with mixed precision support.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\nTRITON3 = version.parse(triton.__version__) >= version.parse(\"3.0.0\")\n\nif TRITON3:\n\n @triton.jit\n def softplus(dt):\n dt = tl.where(dt <= 20.0, tl.math.log(tl.math.exp(dt) + 1), dt)\n return dt\nelse:\n\n @triton.jit\n def softplus(dt):\n dt = tl.where(dt <= 20.0, tl.math.log1p(tl.exp(dt)), dt)\n return dt\n\n\n@triton.jit\ndef _selective_scan_update_kernel(\n state_ptr,\n x_ptr,\n dt_ptr,\n dt_bias_ptr,\n A_ptr,\n B_ptr,\n C_ptr,\n D_ptr,\n z_ptr,\n out_ptr,\n batch,\n nheads,\n dim,\n dstate,\n nheads_ngroups_ratio,\n stride_state_batch,\n stride_state_head,\n stride_state_dim,\n stride_state_dstate,\n stride_x_batch,\n stride_x_head,\n stride_x_dim,\n stride_dt_batch,\n stride_dt_head,\n stride_dt_dim,\n stride_dt_bias_head,\n stride_dt_bias_dim,\n stride_A_head,\n stride_A_dim,\n stride_A_dstate,\n stride_B_batch,\n stride_B_group,\n stride_B_dstate,\n stride_C_batch,\n stride_C_group,\n stride_C_dstate,\n stride_D_head,\n stride_D_dim,\n stride_z_batch,\n stride_z_head,\n stride_z_dim,\n stride_out_batch,\n stride_out_head,\n stride_out_dim,\n DT_SOFTPLUS: tl.constexpr,\n TIE_HDIM: tl.constexpr,\n BLOCK_SIZE_M: tl.constexpr,\n HAS_DT_BIAS: tl.constexpr,\n HAS_D: tl.constexpr,\n HAS_Z: tl.constexpr,\n BLOCK_SIZE_DSTATE: tl.constexpr,\n):\n pid_m = tl.program_id(axis=0)\n pid_b = tl.program_id(axis=1)\n pid_h = tl.program_id(axis=2)\n state_ptr += pid_b * stride_state_batch + pid_h * stride_state_head\n x_ptr += pid_b * stride_x_batch + pid_h * stride_x_head\n dt_ptr += pid_b * stride_dt_batch + pid_h * stride_dt_head\n if HAS_DT_BIAS:\n dt_bias_ptr += pid_h * stride_dt_bias_head\n A_ptr += pid_h * stride_A_head\n B_ptr += pid_b * stride_B_batch + (pid_h //\n nheads_ngroups_ratio) * stride_B_group\n C_ptr += pid_b * stride_C_batch + (pid_h //\n nheads_ngroups_ratio) * stride_C_group\n if HAS_Z:\n z_ptr += pid_b * stride_z_batch + pid_h * stride_z_head\n out_ptr += pid_b * stride_out_batch + pid_h * stride_out_head\n\n offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_n = tl.arange(0, BLOCK_SIZE_DSTATE)\n state_ptrs = state_ptr + (offs_m[:, None] * stride_state_dim +\n offs_n[None, :] * stride_state_dstate)\n x_ptrs = x_ptr + offs_m * stride_x_dim\n dt_ptrs = dt_ptr + offs_m * stride_dt_dim\n if HAS_DT_BIAS:\n dt_bias_ptrs = dt_bias_ptr + offs_m * stride_dt_bias_dim\n if HAS_D:\n D_ptr += pid_h * stride_D_head\n A_ptrs = A_ptr + (offs_m[:, None] * stride_A_dim +\n offs_n[None, :] * stride_A_dstate)\n B_ptrs = B_ptr + offs_n * stride_B_dstate\n C_ptrs = C_ptr + offs_n * stride_C_dstate\n if HAS_D:\n D_ptrs = D_ptr + offs_m * stride_D_dim\n if HAS_Z:\n z_ptrs = z_ptr + offs_m * stride_z_dim\n out_ptrs = out_ptr + offs_m * stride_out_dim\n\n state = tl.load(state_ptrs,\n mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate),\n other=0.0)\n x = tl.load(x_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)\n if not TIE_HDIM:\n dt = tl.load(dt_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)\n if HAS_DT_BIAS:\n dt += tl.load(dt_bias_ptrs, mask=offs_m < dim,\n other=0.0).to(tl.float32)\n if DT_SOFTPLUS:\n dt = softplus(dt)\n A = tl.load(A_ptrs,\n mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate),\n other=0.0).to(tl.float32)\n dA = tl.exp(A * dt[:, None])\n else:\n dt = tl.load(dt_ptr).to(tl.float32)\n if HAS_DT_BIAS:\n dt += tl.load(dt_bias_ptr).to(tl.float32)\n if DT_SOFTPLUS:\n dt = softplus(dt)\n A = tl.load(A_ptr).to(tl.float32)\n dA = tl.exp(A * dt) # scalar, not a matrix\n\n B = tl.load(B_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32)\n C = tl.load(C_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32)\n if HAS_D:\n D = tl.load(D_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)\n if HAS_Z:\n z = tl.load(z_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)\n\n dB = B[None, :] * dt[:, None] if not TIE_HDIM else B * dt\n state = state * dA + dB * x[:, None]\n tl.store(state_ptrs,\n state,\n mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate))\n out = tl.sum(state * C[None, :], axis=1)\n if HAS_D:\n out += x * D\n if HAS_Z:\n out *= z * tl.sigmoid(z)\n tl.store(out_ptrs, out, mask=offs_m < dim)\n\n\ndef selective_state_update(state,\n x,\n dt,\n A,\n B,\n C,\n D=None,\n z=None,\n dt_bias=None,\n dt_softplus=False):\n has_heads = state.dim() > 3\n if state.dim() == 3:\n state = state.unsqueeze(1)\n if x.dim() == 2:\n x = x.unsqueeze(1)\n if dt.dim() == 2:\n dt = dt.unsqueeze(1)\n if A.dim() == 2:\n A = A.unsqueeze(0)\n if B.dim() == 2:\n B = B.unsqueeze(1)\n if C.dim() == 2:\n C = C.unsqueeze(1)\n if D is not None and D.dim() == 1:\n D = D.unsqueeze(0)\n if z is not None and z.dim() == 2:\n z = z.unsqueeze(1)\n if dt_bias is not None and dt_bias.dim() == 1:\n dt_bias = dt_bias.unsqueeze(0)\n batch, nheads, dim, dstate = state.shape\n assert x.shape == (batch, nheads, dim)\n assert dt.shape == x.shape\n assert A.shape == (nheads, dim, dstate)\n ngroups = B.shape[1]\n assert nheads % ngroups == 0, \"nheads must be divisible by ngroups\"\n assert B.shape == (batch, ngroups, dstate)\n assert C.shape == B.shape\n if D is not None:\n assert D.shape == (nheads, dim)\n if z is not None:\n assert z.shape == x.shape\n if dt_bias is not None:\n assert dt_bias.shape == (nheads, dim)\n out = torch.empty_like(x)\n grid = lambda META: (triton.cdiv(dim, META['BLOCK_SIZE_M']), batch, nheads)\n z_strides = ((z.stride(0), z.stride(1), z.stride(2)) if z is not None else\n (0, 0, 0))\n BLOCK_SIZE_M, num_warps = ((32, 4) if dstate <= 16 else\n ((16, 4) if dstate <= 32 else\n ((8, 4) if dstate <= 64 else\n ((4, 4) if dstate <= 128 else ((4, 8))))))\n tie_hdim = A.stride(-1) == 0 and A.stride(-2) == 0 and dt.stride(\n -1) == 0 and dt_bias.stride(-1) == 0\n with torch.cuda.device(x.device.index):\n _selective_scan_update_kernel[grid](\n state,\n x,\n dt,\n dt_bias,\n A,\n B,\n C,\n D,\n z,\n out,\n batch,\n nheads,\n dim,\n dstate,\n nheads // ngroups,\n state.stride(0),\n state.stride(1),\n state.stride(2),\n state.stride(3),\n x.stride(0),\n x.stride(1),\n x.stride(2),\n dt.stride(0),\n dt.stride(1),\n dt.stride(2),\n *(dt_bias.stride(0),\n dt_bias.stride(1)) if dt_bias is not None else 0,\n A.stride(0),\n A.stride(1),\n A.stride(2),\n B.stride(0),\n B.stride(1),\n B.stride(2),\n C.stride(0),\n C.stride(1),\n C.stride(2),\n *(D.stride(0), D.stride(1)) if D is not None else 0,\n z_strides[0],\n z_strides[1],\n z_strides[2],\n out.stride(0),\n out.stride(1),\n out.stride(2),\n dt_softplus,\n tie_hdim,\n BLOCK_SIZE_M,\n num_warps=num_warps,\n )\n if not has_heads:\n out = out.squeeze(1)\n return out\n", - "description_1": "Use triton language to implement a selective scan update kernel with 48 parameters, including pointers to matrices, matrix dimensions, strides, and meta-parameters. The kernel performs operations on input matrices and stores the result in an output matrix. The selective_state_update function, with 10 parameters, prepares the input data, sets up the grid for kernel execution, and calls the kernel with appropriate arguments.", - "description_2": "Use triton language to implement a softplus function with 1 parameter, which applies the softplus operation on the input tensor. The function is conditionally defined based on the Triton version.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\ndef seeded_uniform(\n *size,\n seeds: torch.Tensor,\n out: Optional[torch.Tensor] = None,\n dtype: Optional[torch.dtype] = None,\n device: Optional[Union[torch.device, str]] = None,\n pin_memory: Optional[bool] = False,\n) -> torch.Tensor:\n n_dims = len(size)\n\n if n_dims > 3:\n raise ValueError(\"seeded_uniform only supports up to 3D tensors\")\n\n if out is None:\n out = torch.empty(*size,\n dtype=dtype,\n device=device,\n pin_memory=pin_memory)\n elif out.shape != size:\n raise ValueError(\"shape of out and size must be the same\")\n\n if n_dims == 3:\n n_rows, n_3d, n_cols = out.shape\n stride_row = out.stride(0)\n stride_3d = out.stride(1)\n elif n_dims == 2:\n n_rows, n_cols = out.shape\n n_3d = 1\n stride_row = out.stride(0)\n stride_3d = 1\n else:\n n_cols = out.shape[0]\n n_rows = 1\n n_3d = 1\n stride_row = 1\n stride_3d = 1\n\n if seeds.ndim != 1:\n raise ValueError(\"seeds must be a 1D tensor\")\n\n if seeds.numel() != n_rows:\n raise ValueError(\n \"seeds must have the same number of elements as out has rows\")\n\n full_block_size = triton.next_power_of_2(n_cols)\n philox_block_size = max(full_block_size // 4, 1)\n n_slices = full_block_size // philox_block_size\n num_warps = 4\n if philox_block_size >= 8192:\n num_warps = 32\n elif philox_block_size >= 4096:\n num_warps = 16\n elif philox_block_size >= 2048:\n num_warps = 8\n\n _seeded_uniform_triton[(n_rows, n_3d)](\n out,\n seeds,\n stride_row,\n stride_3d,\n seeds.stride(0),\n n_rows,\n n_3d,\n n_cols,\n n_slices=n_slices,\n num_warps=num_warps,\n block_size=philox_block_size,\n )\n return out\n\n@triton.jit\ndef _seeded_uniform_triton(\n out_ptr: torch.Tensor,\n seed_ptr: torch.Tensor,\n out_row_stride: int,\n out_3d_stride: int,\n seed_row_stride: int,\n n_rows: int,\n n_3d: int,\n n_cols: int,\n n_slices: tl.constexpr,\n block_size: tl.constexpr,\n):\n tl.static_assert(n_slices > 0 and n_slices <= 4, \"0 < n_slices <= 4\")\n row_idx = tl.program_id(axis=0)\n three_d_idx = tl.program_id(axis=1)\n\n philox_offsets = tl.arange(0, block_size)\n seed = tl.load(seed_ptr + row_idx * seed_row_stride)\n if three_d_idx > 0:\n seed ^= three_d_idx\n out1, out2, out3, out4 = tl.rand4x(seed, philox_offsets)\n\n output_row_start_ptr = (out_ptr + row_idx * out_row_stride +\n three_d_idx * out_3d_stride)\n out1_offsets = philox_offsets\n tl.store(output_row_start_ptr + out1_offsets,\n out1,\n mask=out1_offsets < n_cols)\n if n_slices > 1:\n out2_offsets = tl.arange(block_size, block_size * 2)\n tl.store(output_row_start_ptr + out2_offsets,\n out2,\n mask=out2_offsets < n_cols)\n if n_slices > 2:\n out3_offsets = tl.arange(block_size * 2, block_size * 3)\n tl.store(output_row_start_ptr + out3_offsets,\n out3,\n mask=out3_offsets < n_cols)\n if n_slices > 3:\n out4_offsets = tl.arange(block_size * 3, block_size * 4)\n tl.store(output_row_start_ptr + out4_offsets,\n out4,\n mask=out4_offsets < n_cols)\n", - "description_1": "Use triton language to implement a seeded uniform random number generator that takes in a tensor size and seed values to generate a corresponding output tensor with random numbers. The function _seeded_uniform_triton is a kernel that generates random numbers for each element in a given output tensor using per-row seeds and storing results in a parallelized manner.", - "description_2": "Use triton language to generate a tensor of random float numbers using given seed values, and store results in an output tensor efficiently utilizing parallelism.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n_EPS = 1e-6\n\n@triton.jit\ndef _uniform_to_exponential(uniform_noise):\n \"\"\"Convert uniform samples to exponential samples.\"\"\"\n lb = tl.full(uniform_noise.shape, _EPS, uniform_noise.dtype)\n uniform_noise = tl.maximum(uniform_noise, lb)\n exponential_noise = -tl.log(uniform_noise)\n return exponential_noise\n\n@triton.jit\ndef _sample_triton(\n sample_indices_ptr: torch.Tensor, output_ptr: torch.Tensor,\n output_logprobs_ptr: torch.Tensor,\n output_modified_probs_ptr: torch.Tensor, probs_ptr: torch.Tensor,\n logprobs_ptr: torch.Tensor, seeds_ptr: torch.Tensor,\n uniform_noise_ptr: torch.Tensor, output_row_stride: int,\n probs_row_stride: int, uniform_noise_row_stride: int,\n uniform_noise_best_stride: int, n_samples: int, n_cols: int,\n n_best: int, block_size: tl.constexpr,\n modify_greedy_probs: tl.constexpr, save_logprobs: tl.constexpr,\n save_modified_probs: tl.constexpr):\n sample_idx = tl.program_id(0)\n best_idx = tl.program_id(1)\n\n row_idx = tl.load(sample_indices_ptr + sample_idx)\n seed = tl.load(seeds_ptr + sample_idx)\n uses_random_sampling = seed != 0\n\n row_start_ptr = probs_ptr + row_idx * probs_row_stride\n col_offsets = tl.arange(0, block_size)\n row = tl.load(row_start_ptr + col_offsets,\n mask=col_offsets < n_cols,\n other=float(\"-inf\"))\n\n if uses_random_sampling:\n uniform_noise_start_ptr = (uniform_noise_ptr +\n sample_idx * uniform_noise_row_stride +\n best_idx * uniform_noise_best_stride)\n uniform_noise = tl.load(uniform_noise_start_ptr + col_offsets,\n mask=col_offsets < n_cols,\n other=0.5)\n exponential_noise = _uniform_to_exponential(uniform_noise)\n row /= exponential_noise\n\n sampled_value, sampled_token = tl.max(row, axis=0, return_indices=True)\n if sampled_token >= n_cols:\n sampled_token = n_cols - 1\n output_row_start_ptr = (output_ptr + sample_idx * output_row_stride +\n best_idx)\n tl.store(output_row_start_ptr, sampled_token)\n\n if modify_greedy_probs:\n if not uses_random_sampling:\n row = tl.where(col_offsets == sampled_token, 1.0, 0.0)\n tl.store(row_start_ptr + col_offsets,\n row,\n mask=col_offsets < n_cols)\n\n if save_modified_probs:\n output_row_start_ptr = (output_modified_probs_ptr +\n sample_idx * output_row_stride + best_idx)\n tl.store(output_row_start_ptr, sampled_value)\n\n if save_logprobs:\n sampled_logprob = tl.load(logprobs_ptr + row_idx * probs_row_stride +\n sampled_token)\n output_row_start_ptr = (output_logprobs_ptr +\n sample_idx * output_row_stride + best_idx)\n tl.store(output_row_start_ptr, sampled_logprob)\n", - "description_1": "Use triton language to implement a sampling kernel that converts uniform noise to exponential noise and samples tokens from a probability distribution. The kernel takes 18 parameters: sample_indices_ptr, output_ptr, output_logprobs_ptr, output_modified_probs_ptr, probs_ptr, logprobs_ptr, seeds_ptr, uniform_noise_ptr, output_row_stride, probs_row_stride, uniform_noise_row_stride, uniform_noise_best_stride, n_samples, n_cols, n_best, block_size, modify_greedy_probs, save_logprobs, and save_modified_probs. It processes each row independently, applies noise if needed, and stores the sampled tokens and their log probabilities.", - "description_2": "Use triton language to create a kernel for sampling tokens from a probability distribution with optional noise application and log probability storage.", - "difficulty": 3 - }, - { - "code": "import triton\nimport triton.language as tl\nfrom triton import Config, autotune, heuristics\n\n@autotune(\n configs=[\n Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1, 'BLOCK_Kfp': 16}, num_stages=3, num_warps=8),\n Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1, 'BLOCK_Kfp': 16}, num_stages=3, num_warps=8),\n Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1, 'BLOCK_Kfp': 16}, num_stages=4, num_warps=4),\n Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1, 'BLOCK_Kfp': 16}, num_stages=4, num_warps=4),\n Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1, 'BLOCK_Kfp': 16}, num_stages=4, num_warps=4),\n Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1, 'BLOCK_Kfp': 16}, num_stages=4, num_warps=4),\n Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1, 'BLOCK_Kfp': 16}, num_stages=4, num_warps=4),\n Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1, 'BLOCK_Kfp': 16}, num_stages=4, num_warps=4),\n Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1, 'BLOCK_Kfp': 16}, num_stages=5, num_warps=2),\n Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 128, 'SPLIT_K': 1, 'BLOCK_Kfp': 16}, num_stages=3, num_warps=8),\n Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 128, 'SPLIT_K': 1, 'BLOCK_Kfp': 16}, num_stages=3, num_warps=8),\n Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 128, 'SPLIT_K': 1, 'BLOCK_Kfp': 16}, num_stages=4, num_warps=4),\n Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 128, 'SPLIT_K': 1, 'BLOCK_Kfp': 16}, num_stages=4, num_warps=4),\n Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 128, 'SPLIT_K': 1, 'BLOCK_Kfp': 16}, num_stages=4, num_warps=4),\n Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 64, 'SPLIT_K': 1, 'BLOCK_Kfp': 16}, num_stages=4, num_warps=4),\n Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 64, 'SPLIT_K': 1, 'BLOCK_Kfp': 16}, num_stages=4, num_warps=4),\n Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 64, 'SPLIT_K': 1, 'BLOCK_Kfp': 16}, num_stages=4, num_warps=4),\n Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 64, 'SPLIT_K': 1, 'BLOCK_Kfp': 16}, num_stages=5, num_warps=2),\n ],\n key=['M', 'N'],\n prune_configs_by={\n 'early_config_prune': lambda args: True,\n 'perf_model': lambda args, config: 0,\n 'top_k': 10,\n },\n)\n@heuristics({\n 'EVEN_K': lambda args: args['K'] % (args['BLOCK_K'] * args['SPLIT_K']) == 0,\n})\n@triton.jit\ndef matmul_kernelint8(x, w, A, B, C, M, N, K, \n stride_am, stride_ak, \n stride_bk, stride_bn,\n stride_cm, stride_cn,\n Afp, Bfp, Cfp, Kfp,\n stride_amfp, stride_akfp, \n stride_bkfp, stride_bnfp, \n stride_cmfp, stride_cnfp,\n acc_dtype: tl.constexpr, \n allow_tf32: tl.constexpr, \n fp8_fast_accum: tl.constexpr, \n BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, \n BLOCK_Kfp: tl.constexpr, \n GROUP_M: tl.constexpr, SPLIT_K: tl.constexpr, EVEN_K: tl.constexpr, AB_DTYPE: tl.constexpr):\n # matrix multiplication\n pid = tl.program_id(0)\n pid_z = tl.program_id(1)\n grid_m = tl.cdiv(M, BLOCK_M)\n grid_n = tl.cdiv(N, BLOCK_N)\n # re-order program ID for better L2 performance\n width = GROUP_M * grid_n\n group_id = pid // width\n group_size = min(grid_m - group_id * GROUP_M, GROUP_M)\n pid_m = group_id * GROUP_M + (pid % group_size)\n pid_n = (pid % width) // group_size\n # do matrix multiplication\n rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)\n rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)\n ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)\n rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)\n rk = pid_z * BLOCK_K + tl.arange(0, BLOCK_K)\n # pointers\n A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)\n B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)\n acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=acc_dtype)\n for k in range(0, tl.cdiv(K, BLOCK_K * SPLIT_K)):\n if EVEN_K:\n a = tl.load(A)\n b = tl.load(B)\n else:\n k_remaining = K - k * (BLOCK_K * SPLIT_K)\n _0 = tl.zeros((1, 1), dtype=C.dtype.element_ty)\n a = tl.load(A, mask=rk[None, :] < k_remaining, other=_0)\n b = tl.load(B, mask=rk[:, None] < k_remaining, other=_0)\n acc = tl.dot(a, b, acc, out_dtype=acc_dtype, allow_tf32=False)\n A += BLOCK_K * SPLIT_K * stride_ak\n B += BLOCK_K * SPLIT_K * stride_bk\n # rematerialize rm and rn to save registers\n rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)\n rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)\n C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)\n mask = (rm < M)[:, None] & (rn < N)[None, :]\n if SPLIT_K == 1:\n tl.store(C, acc, mask=mask)\n else:\n tl.atomic_add(C, acc, mask=mask)\n\ndef matmulint8_fused_dequant(x, w, a, b, afp, bfp, c, cfp16, M, N, K, Kfp):\n grid = lambda META: (cdiv(M, META['BLOCK_M']) * cdiv(N, META['BLOCK_N']), META['SPLIT_K'])\n allow_tf32 = True\n fp8_fast_accum = True\n matmul_kernelint8[grid](\n x, w,\n a, b, c,\n M, N, K,\n K, 1,\n 1, K,\n N, 1,\n afp, bfp, cfp16, Kfp[0],\n Kfp, 1,\n 1, Kfp,\n N, 1,\n allow_tf32=allow_tf32,\n fp8_fast_accum=fp8_fast_accum,\n GROUP_M=8, acc_dtype=tl.int32, AB_DTYPE=None\n )\n return c, cfp16\n", - "description_1": "Use triton language to implement a kernel `matmul_kernelint8` that performs matrix multiplication for INT8 data types with optional fused dequantization. The kernel has 32 parameters, including pointers to input matrices `A`, `B`, and output matrix `C`, matrix dimensions `M`, `N`, `K`, strides for memory layout, and configuration constants for block sizes and types. A helper function `matmulint8_fused_dequant` sets up the grid and calls the kernel with the appropriate parameters, performing matrix multiplication with optional FP8 fast accumulation and allowing TF32 computation.", - "description_2": "Use triton language to create a matrix multiplication kernel with dequantization for INT8 data, using parameters for matrix pointers, dimensions, strides, and block configurations.", - "difficulty": 3 - }, - { - "code": "import triton\nimport triton.language as tl\nfrom triton import Config, autotune, cdiv\n\n@autotune(\n configs=[\n Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 16, 'SPLIT_K': 1}, num_stages=3, num_warps=8),\n Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 16, 'SPLIT_K': 1}, num_stages=3, num_warps=8),\n Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 16, 'SPLIT_K': 1}, num_stages=4, num_warps=4),\n Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 16, 'SPLIT_K': 1}, num_stages=4, num_warps=4),\n Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 16, 'SPLIT_K': 1}, num_stages=4, num_warps=4),\n Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 16, 'SPLIT_K': 1}, num_stages=4, num_warps=4),\n Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 16, 'SPLIT_K': 1}, num_stages=4, num_warps=4),\n Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 16, 'SPLIT_K': 1}, num_stages=4, num_warps=4),\n Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 16, 'SPLIT_K': 1}, num_stages=5, num_warps=2),\n ],\n key=['M', 'N'],\n)\n@triton.jit\ndef matmul_kernelfp16(A, B, C, M, N, K,\n stride_amfp, stride_akfp,\n stride_bkfp, stride_bnfp,\n stride_cmfp, stride_cnfp,\n BLOCK_M: tl.constexpr, \n BLOCK_N: tl.constexpr,\n BLOCK_K: tl.constexpr, \n SPLIT_K: tl.constexpr, \n GROUP_M: tl.constexpr):\n # matrix multiplication\n pid = tl.program_id(0)\n pid_z = tl.program_id(1)\n grid_m = tl.cdiv(M, BLOCK_M)\n grid_n = tl.cdiv(N, BLOCK_N)\n # re-order program ID for better L2 performance\n width = GROUP_M * grid_n\n group_id = pid // width\n group_size = min(grid_m - group_id * GROUP_M, GROUP_M)\n pid_m = group_id * GROUP_M + (pid % group_size)\n pid_n = (pid % width) // (group_size)\n # do matrix multiplication\n rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)\n rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)\n ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)\n rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)\n \n # pointers\n rkfp = tl.arange(0, BLOCK_K)\n A = A + (ram[:, None] * stride_amfp + rkfp[None, :] * stride_akfp)\n B = B + (rkfp[:, None] * stride_bkfp + rbn[None, :] * stride_bnfp)\n\n accfp = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)\n\n rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)\n rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)\n\n rmfp = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)\n rnfp = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)\n\n afp = tl.zeros((BLOCK_M, BLOCK_K), dtype=C.dtype.element_ty)\n bfp = tl.zeros((BLOCK_K, BLOCK_N), dtype=C.dtype.element_ty)\n C = C + (rmfp[:, None] * stride_cmfp + rnfp[None, :] * stride_cnfp) \n mask = (rm < M)[:, None] & (rn < N)[None, :]\n\n K_ = tl.load(K + 0)\n if K_ == 0:\n return \n\n maxK = tl.cdiv(K_, BLOCK_K)\n for k in range(0, maxK - 1):\n afp = tl.load(A)\n bfp = tl.load(B)\n A += BLOCK_K * stride_akfp\n B += BLOCK_K * stride_bkfp \n accfp = tl.dot(afp, bfp, accfp, out_dtype=tl.float32, allow_tf32=False)\n\n k = maxK - 1\n if K_ % BLOCK_K == 0:\n afp = tl.load(A)\n bfp = tl.load(B)\n else:\n k_remainingfp = K_ - k * BLOCK_K \n afp = tl.load(A, mask=rkfp[None, :] < k_remainingfp, other=0.0)\n bfp = tl.load(B, mask=rkfp[:, None] < k_remainingfp, other=0.0)\n\n accfp = tl.dot(afp, bfp, accfp, out_dtype=tl.float32, allow_tf32=False)\n accfp = accfp.to(tl.float16)\n\n # rematerialize rm and rn to save registers\n tl.store(C, accfp, mask=mask)\n\ndef matmulfp16(afp, bfp, cfp16, M, N, K):\n grid = lambda META: (cdiv(M, META['BLOCK_M']) * cdiv(N, META['BLOCK_N']), META['SPLIT_K'])\n matmul_kernelfp16[grid](\n afp, bfp, cfp16, M, N, K,\n 1, M,\n N, 1,\n N, 1,\n GROUP_M=8\n )\n return\n", - "description_1": "Use triton language to implement a matrix multiplication kernel for half-precision floating-point (fp16) matrices. The kernel 'matmul_kernelfp16' takes 15 parameters: three matrices A, B, C, and their dimensions M, N, K, along with stride values for A, B, and C, and several block and group constants. The function 'matmulfp16' is a wrapper that sets up the grid and calls the kernel with appropriate parameters.", - "description_2": "Use triton language to perform matrix multiplication on fp16 matrices with configurable block sizes and grid dimensions.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\nfrom .triton_utils.kernels import silu\n\n@triton.jit\ndef quant_fused_matmul_248_kernel(\n a_ptr, c_ptr, b1_ptr,\n scales1_ptr, zeros1_ptr,\n g1_ptr, b2_ptr,\n scales2_ptr, zeros2_ptr,\n g2_ptr,\n M, N, K,\n bits, maxq,\n stride_am, stride_ak,\n stride_bk, stride_bn,\n stride_cm, stride_cn,\n stride_scales, stride_zeros,\n BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,\n GROUP_SIZE_M: tl.constexpr\n):\n \"\"\"\n Computes: C = silu(A * B1) * (A * B2)\n A is of shape (M, K) float16\n B is of shape (K//8, N) int32\n C is of shape (M, N) float16\n scales is of shape (1, N) float16\n zeros is of shape (1, N//8) int32\n \"\"\"\n infearure_per_bits = 32 // bits\n\n pid = tl.program_id(axis=0)\n num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)\n num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)\n num_pid_k = tl.cdiv(K, BLOCK_SIZE_K)\n num_pid_in_group = GROUP_SIZE_M * num_pid_n\n group_id = pid // num_pid_in_group\n first_pid_m = group_id * GROUP_SIZE_M\n group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)\n pid_m = first_pid_m + (pid % group_size_m)\n pid_n = (pid % num_pid_in_group) // group_size_m\n\n offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n offs_k = tl.arange(0, BLOCK_SIZE_K)\n a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) # (BLOCK_SIZE_M, BLOCK_SIZE_K)\n a_mask = (offs_am[:, None] < M)\n # b_ptrs is set up such that it repeats elements along the K axis 8 times\n b1_ptrs = b1_ptr + ((offs_k[:, None] // infearure_per_bits) * stride_bk + offs_bn[None, :] * stride_bn)\n b2_ptrs = b2_ptr + ((offs_k[:, None] // infearure_per_bits) * stride_bk + offs_bn[None, :] * stride_bn)\n g1_ptrs = g1_ptr + offs_k\n g2_ptrs = g2_ptr + offs_k\n # shifter is used to extract the N bits of each element in the 32-bit word from B\n scales1_ptrs = scales1_ptr + offs_bn[None, :]\n scales2_ptrs = scales2_ptr + offs_bn[None, :]\n zeros1_ptrs = zeros1_ptr + (offs_bn[None, :] // infearure_per_bits)\n zeros2_ptrs = zeros2_ptr + (offs_bn[None, :] // infearure_per_bits)\n\n shifter = (offs_k % infearure_per_bits) * bits\n zeros_shifter = (offs_bn % infearure_per_bits) * bits\n accumulator1 = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n accumulator2 = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n for k in range(0, num_pid_k):\n g1_idx = tl.load(g1_ptrs)\n g2_idx = tl.load(g2_ptrs)\n\n # Fetch scales and zeros; these are per-outfeature and thus reused in the inner loop\n scales1 = tl.load(scales1_ptrs + g1_idx[:, None] * stride_scales) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)\n scales2 = tl.load(scales2_ptrs + g2_idx[:, None] * stride_scales)\n\n zeros1 = tl.load(zeros1_ptrs + g1_idx[:, None] * stride_zeros) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)\n zeros1 = (zeros1 >> zeros_shifter[None, :]) & maxq\n zeros1 = (zeros1 + 1)\n\n zeros2 = tl.load(zeros2_ptrs + g2_idx[:, None] * stride_zeros) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)\n zeros2 = (zeros2 >> zeros_shifter[None, :]) & maxq\n zeros2 = (zeros2 + 1)\n\n a = tl.load(a_ptrs, mask=a_mask, other=0.) # (BLOCK_SIZE_M, BLOCK_SIZE_K)\n b1 = tl.load(b1_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated\n b2 = tl.load(b2_ptrs)\n\n # Now we need to unpack b (which is N-bit values) into 32-bit values\n b1 = (b1 >> shifter[:, None]) & maxq # Extract the N-bit values\n b1 = (b1 - zeros1) * scales1 # Scale and shift\n accumulator1 += tl.dot(a, b1)\n\n b2 = (b2 >> shifter[:, None]) & maxq\n b2 = (b2 - zeros2) * scales2\n accumulator2 += tl.dot(a, b2)\n\n a_ptrs += BLOCK_SIZE_K\n b1_ptrs += (BLOCK_SIZE_K // infearure_per_bits) * stride_bk\n b2_ptrs += (BLOCK_SIZE_K // infearure_per_bits) * stride_bk\n g1_ptrs += BLOCK_SIZE_K\n g2_ptrs += BLOCK_SIZE_K\n\n accumulator1 = silu(accumulator1)\n c = accumulator1 * accumulator2\n c = c.to(tl.float16)\n c_ptrs = c_ptr + stride_cm * offs_am[:, None] + stride_cn * offs_bn[None, :]\n c_mask = (offs_am[:, None] < M) & (offs_bn[None, :] < N)\n tl.store(c_ptrs, c, mask=c_mask)\n\nclass FusedLlamaMLPForQuantizedModel:\n def __init__(\n self,\n gate_proj,\n down_proj,\n up_proj,\n ):\n self.infeatures = gate_proj.infeatures\n self.intermediate_size = gate_proj.outfeatures\n self.outfeatures = down_proj.outfeatures\n self.bits = gate_proj.bits\n self.maxq = gate_proj.maxq\n\n self.gate_proj = gate_proj\n self.up_proj = up_proj\n self.down_proj = down_proj\n\n def triton_llama_mlp(self, x):\n with torch.cuda.device(x.device):\n out_shape = x.shape[:-1] + (self.intermediate_size, )\n x = x.reshape(-1, x.shape[-1])\n M, K = x.shape\n N = self.intermediate_size\n c = torch.empty((M, N), device=x.device, dtype=torch.float16)\n grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), )\n quant_fused_matmul_248_kernel[grid](\n x, c, self.gate_proj.qweight,\n self.gate_proj.scales, self.gate_proj.qzeros, self.gate_proj.g_idx,\n self.up_proj.qweight,\n self.up_proj.scales, self.up_proj.qzeros, self.up_proj.g_idx,\n M, N, K,\n self.bits, self.maxq,\n x.stride(0), x.stride(1),\n self.gate_proj.qweight.stride(0), self.gate_proj.qweight.stride(1),\n c.stride(0), c.stride(1),\n self.gate_proj.scales.stride(0), self.gate_proj.qzeros.stride(0)\n )\n c = c.reshape(out_shape)\n return c\n", - "description_1": "Use triton language to implement a fused matrix multiplication kernel that computes C = silu(A * B1) * (A * B2) with quantization. The kernel takes 28 parameters: pointers to input matrices, scales, zeros, group indices, dimensions M, N, K, bit width, max quantization value, strides for input and output matrices, and block sizes for tiling. The kernel performs matrix multiplication with quantization and stores the result in the output matrix.", - "description_2": "Use triton language to create a fused quantized matrix multiplication kernel with silu activation, handling input matrices, scales, zeros, and group indices, and outputting the result after applying quantization and activation.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef quant_matmul_248_kernel(\n a_ptr, b_ptr, c_ptr,\n scales_ptr, zeros_ptr, g_ptr,\n M, N, K,\n bits, maxq,\n stride_am, stride_ak,\n stride_bk, stride_bn,\n stride_cm, stride_cn,\n stride_scales, stride_zeros,\n BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,\n GROUP_SIZE_M: tl.constexpr\n):\n \"\"\"\n Compute the matrix multiplication C = A x B.\n A is of shape (M, K) float16\n B is of shape (K//8, N) int32\n C is of shape (M, N) float16\n scales is of shape (G, N) float16\n zeros is of shape (G, N) float16\n g_ptr is of shape (K) int32\n \"\"\"\n infearure_per_bits = 32 // bits\n\n pid = tl.program_id(axis=0)\n num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)\n num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)\n num_pid_k = tl.cdiv(K, BLOCK_SIZE_K)\n num_pid_in_group = GROUP_SIZE_M * num_pid_n\n group_id = pid // num_pid_in_group\n first_pid_m = group_id * GROUP_SIZE_M\n group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)\n pid_m = first_pid_m + (pid % group_size_m)\n pid_n = (pid % num_pid_in_group) // group_size_m\n\n offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n offs_k = tl.arange(0, BLOCK_SIZE_K)\n a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) # (BLOCK_SIZE_M, BLOCK_SIZE_K)\n a_mask = (offs_am[:, None] < M)\n # b_ptrs is set up such that it repeats elements along the K axis 8 times\n b_ptrs = b_ptr + (\n (offs_k[:, None] // infearure_per_bits) * stride_bk + offs_bn[None, :] * stride_bn\n ) # (BLOCK_SIZE_K, BLOCK_SIZE_N)\n g_ptrs = g_ptr + offs_k\n # shifter is used to extract the N bits of each element in the 32-bit word from B\n scales_ptrs = scales_ptr + offs_bn[None, :]\n zeros_ptrs = zeros_ptr + (offs_bn[None, :] // infearure_per_bits)\n\n shifter = (offs_k % infearure_per_bits) * bits\n zeros_shifter = (offs_bn % infearure_per_bits) * bits\n accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n\n for k in range(0, num_pid_k):\n g_idx = tl.load(g_ptrs)\n\n # Fetch scales and zeros; these are per-outfeature and thus reused in the inner loop\n scales = tl.load(scales_ptrs + g_idx[:, None] * stride_scales) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)\n zeros = tl.load(zeros_ptrs + g_idx[:, None] * stride_zeros) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)\n\n zeros = (zeros >> zeros_shifter[None, :]) & maxq\n zeros = (zeros + 1)\n\n a = tl.load(a_ptrs, mask=a_mask, other=0.) # (BLOCK_SIZE_M, BLOCK_SIZE_K)\n b = tl.load(b_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated\n\n # Now we need to unpack b (which is N-bit values) into 32-bit values\n b = (b >> shifter[:, None]) & maxq # Extract the N-bit values\n b = (b - zeros) * scales # Scale and shift\n\n accumulator += tl.dot(a, b)\n a_ptrs += BLOCK_SIZE_K\n b_ptrs += (BLOCK_SIZE_K // infearure_per_bits) * stride_bk\n g_ptrs += BLOCK_SIZE_K\n\n c_ptrs = c_ptr + stride_cm * offs_am[:, None] + stride_cn * offs_bn[None, :]\n c_mask = (offs_am[:, None] < M) & (offs_bn[None, :] < N)\n tl.store(c_ptrs, accumulator, mask=c_mask)\n\n\n@triton.jit\ndef transpose_quant_matmul_248_kernel(\n a_ptr, b_ptr, c_ptr,\n scales_ptr, zeros_ptr, g_ptr,\n M, N, K,\n bits, maxq,\n stride_am, stride_ak,\n stride_bk, stride_bn,\n stride_cm, stride_cn,\n stride_scales, stride_zeros,\n BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,\n GROUP_SIZE_M: tl.constexpr\n):\n \"\"\"\n Compute the matrix multiplication C = A x B.\n A is of shape (M, N) float16\n B is of shape (K//8, N) int32\n C is of shape (M, K) float16\n scales is of shape (G, N) float16\n zeros is of shape (G, N) float16\n g_ptr is of shape (K) int32\n \"\"\"\n infearure_per_bits = 32 // bits\n\n pid = tl.program_id(axis=0)\n num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)\n num_pid_k = tl.cdiv(K, BLOCK_SIZE_K)\n num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)\n num_pid_in_group = GROUP_SIZE_M * num_pid_k\n group_id = pid // num_pid_in_group\n first_pid_m = group_id * GROUP_SIZE_M\n group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)\n pid_m = first_pid_m + (pid % group_size_m)\n pid_k = (pid % num_pid_in_group) // group_size_m\n\n offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_bk = pid_k * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)\n offs_n = tl.arange(0, BLOCK_SIZE_N)\n a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_n[None, :] * stride_ak) # (BLOCK_SIZE_M, BLOCK_SIZE_N)\n a_mask = (offs_am[:, None] < M)\n # b_ptrs is set up such that it repeats elements along the K axis 8 times\n b_ptrs = b_ptr + (\n (offs_bk[:, None] // infearure_per_bits) * stride_bk + offs_n[None, :] * stride_bn\n ) # (BLOCK_SIZE_K, BLOCK_SIZE_N)\n g_ptrs = g_ptr + offs_bk\n g_idx = tl.load(g_ptrs)\n\n # shifter is used to extract the N bits of each element in the 32-bit word from B\n scales_ptrs = scales_ptr + offs_n[None, :] + g_idx[:, None] * stride_scales\n zeros_ptrs = zeros_ptr + (offs_n[None, :] // infearure_per_bits) + g_idx[:, None] * stride_zeros\n\n shifter = (offs_bk % infearure_per_bits) * bits\n zeros_shifter = (offs_n % infearure_per_bits) * bits\n accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_K), dtype=tl.float32)\n\n for k in range(0, num_pid_n):\n # Fetch scales and zeros; these are per-outfeature and thus reused in the inner loop\n scales = tl.load(scales_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)\n zeros = tl.load(zeros_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)\n\n zeros = (zeros >> zeros_shifter[None, :]) & maxq\n zeros = (zeros + 1)\n\n a = tl.load(a_ptrs, mask=a_mask, other=0.) # (BLOCK_SIZE_M, BLOCK_SIZE_N)\n b = tl.load(b_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated\n\n # Now we need to unpack b (which is N-bit values) into 32-bit values\n b = (b >> shifter[:, None]) & maxq # Extract the N-bit values\n b = (b - zeros) * scales # Scale and shift\n b = tl.trans(b)\n\n accumulator += tl.dot(a, b)\n a_ptrs += BLOCK_SIZE_N\n b_ptrs += BLOCK_SIZE_N\n scales_ptrs += BLOCK_SIZE_N\n zeros_ptrs += (BLOCK_SIZE_N // infearure_per_bits)\n\n c_ptrs = c_ptr + stride_cm * offs_am[:, None] + stride_cn * offs_bk[None, :]\n c_mask = (offs_am[:, None] < M) & (offs_bk[None, :] < K)\n tl.store(c_ptrs, accumulator, mask=c_mask)\n\n\ndef quant_matmul_248(input, qweight, scales, qzeros, g_idx, bits, maxq):\n with torch.cuda.device(input.device):\n output = torch.empty((input.shape[0], qweight.shape[1]), device=input.device, dtype=input.dtype)\n grid = lambda META: (\n triton.cdiv(input.shape[0], META['BLOCK_SIZE_M']) * triton.cdiv(qweight.shape[1], META['BLOCK_SIZE_N']),\n )\n quant_matmul_248_kernel[grid](\n input, qweight, output,\n scales.to(input.dtype), qzeros, g_idx,\n input.shape[0], qweight.shape[1], input.shape[1],\n bits, maxq,\n input.stride(0), input.stride(1),\n qweight.stride(0), qweight.stride(1),\n output.stride(0), output.stride(1),\n scales.stride(0), qzeros.stride(0)\n )\n return output\n\n\ndef transpose_quant_matmul_248(input, qweight, scales, qzeros, g_idx, bits, maxq):\n with torch.cuda.device(input.device):\n output_dim = (qweight.shape[0] * 32) // bits\n output = torch.empty((input.shape[0], output_dim), device=input.device, dtype=input.dtype)\n grid = lambda META: (\n triton.cdiv(input.shape[0], META['BLOCK_SIZE_M']) * triton.cdiv(output_dim, META['BLOCK_SIZE_K']),)\n transpose_quant_matmul_248_kernel[grid](\n input, qweight, output,\n scales.to(input.dtype), qzeros, g_idx,\n input.shape[0], qweight.shape[1], output_dim,\n bits, maxq,\n input.stride(0), input.stride(1),\n qweight.stride(0), qweight.stride(1),\n output.stride(0), output.stride(1),\n scales.stride(0), qzeros.stride(0)\n )\n return output\n", - "description_1": "Use triton language to implement two matrix multiplication kernels: 'quant_matmul_248_kernel' and 'transpose_quant_matmul_248_kernel'. The first kernel computes C = A x B where A is a float16 matrix of shape (M, K), B is an int32 matrix of shape (K//8, N), and C is a float16 matrix of shape (M, N). The second kernel computes C = A x B where A is a float16 matrix of shape (M, N), B is an int32 matrix of shape (K//8, N), and C is a float16 matrix of shape (M, K). Both kernels use additional parameters for quantization, including scales, zeros, and g_ptr, and are optimized for specific block sizes and group sizes.", - "description_2": "Use triton language to create optimized matrix multiplication kernels with quantization support, handling different input and output shapes and using specific block and group sizes for performance.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef fused_attention_kernel(\n Out, L, M, # outputs\n Q, K, V,\n sm_scale,\n batch_size, num_heads, seq_len,\n BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,\n BLOCK_N: tl.constexpr,\n):\n start_m = tl.program_id(0)\n off_hz = tl.program_id(1)\n stride_h = BLOCK_DMODEL * seq_len\n\n # initialize offsets\n offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)\n offs_n = tl.arange(0, BLOCK_N)\n offs_d = tl.arange(0, BLOCK_DMODEL)\n off_q = off_hz * stride_h + offs_m[:, None] * BLOCK_DMODEL + offs_d[None, :]\n off_k = off_hz * stride_h + offs_n[None, :] * BLOCK_DMODEL + offs_d[:, None]\n off_v = off_hz * stride_h + offs_n[:, None] * BLOCK_DMODEL + offs_d[None, :]\n # Initialize pointers to Q, K, V\n q_ptrs = Q + off_q\n k_ptrs = K + off_k\n v_ptrs = V + off_v\n # initialize pointer to m and l\n m_prev = tl.zeros([BLOCK_M], dtype=tl.float32) - float(\"inf\")\n l_prev = tl.zeros([BLOCK_M], dtype=tl.float32)\n acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)\n # load q: it will stay in SRAM throughout\n q = tl.load(q_ptrs)\n # loop over k, v and update accumulator\n for start_n in range(0, (start_m + 1) * BLOCK_M, BLOCK_N):\n # -- compute qk ----\n k = tl.load(k_ptrs)\n qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)\n qk += tl.dot(q, k)\n qk *= sm_scale\n qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float(\"-inf\"))\n # compute new m\n m_curr = tl.maximum(tl.max(qk, 1), m_prev)\n # correct old l\n l_prev *= tl.exp(m_prev - m_curr)\n # attention weights\n p = tl.exp(qk - m_curr[:, None])\n l_curr = tl.sum(p, 1) + l_prev\n # rescale operands of matmuls\n l_rcp = 1. / l_curr\n p *= l_rcp[:, None]\n acc *= (l_prev * l_rcp)[:, None]\n # update acc\n p = p.to(Q.dtype.element_ty)\n v = tl.load(v_ptrs)\n acc += tl.dot(p, v)\n # update m_i and l_i\n l_prev = l_curr\n m_prev = m_curr\n # update pointers\n k_ptrs += BLOCK_N * BLOCK_DMODEL\n v_ptrs += BLOCK_N * BLOCK_DMODEL\n # rematerialize offsets to save registers\n start_m = tl.program_id(0)\n offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)\n # write back l and m\n l_ptrs = L + off_hz * seq_len + offs_m\n m_ptrs = M + off_hz * seq_len + offs_m\n tl.store(l_ptrs, l_prev)\n tl.store(m_ptrs, m_prev)\n # initialize pointers to output\n offs_n = tl.arange(0, BLOCK_DMODEL)\n off_o = off_hz * stride_h + offs_m[:, None] * BLOCK_DMODEL + offs_n[None, :]\n out_ptrs = Out + off_o\n tl.store(out_ptrs, acc)\n\n\ndef fused_attention(q, k, v, sm_scale, o_buf=None, l_buf=None, m_buf=None):\n BLOCK = 128 if q.dtype == torch.float16 else 64\n # shape constraints\n Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]\n assert Lq == Lk and Lk == Lv\n assert Lk in {16, 32, 64, 128}\n o = torch.empty_like(q) if o_buf is None else o_buf\n grid = (triton.cdiv(q.shape[2], BLOCK), q.shape[0] * q.shape[1], 1)\n shape = (q.shape[0] * q.shape[1], q.shape[2])\n L = torch.empty(shape, device=q.device, dtype=torch.float32) if l_buf is None else l_buf\n m = torch.empty(shape, device=q.device, dtype=torch.float32) if m_buf is None else m_buf\n\n num_warps = 4 if Lk <= 64 else 8\n # Adjust num_stages for limited resource cases.\n num_stages = 2 if torch.cuda.get_device_capability() >= (8, 0) else 1\n\n fused_attention_kernel[grid](\n o, L, m,\n q, k, v,\n sm_scale,\n q.shape[0], q.shape[1], q.shape[2],\n # tl.constexpr\n BLOCK_M=BLOCK,\n BLOCK_N=BLOCK,\n BLOCK_DMODEL=Lk,\n num_warps=num_warps,\n num_stages=num_stages,\n )\n\n return o\n", - "description_1": "Use triton language to implement a fused attention kernel for the Flash Attention algorithm. The kernel takes 12 parameters: Out, L, M (output tensors), Q, K, V (input tensors), sm_scale (a scaling factor), batch_size, num_heads, seq_len (dimensions of the input), and BLOCK_M, BLOCK_DMODEL, BLOCK_N (block sizes for computation). The kernel computes the attention scores and updates the output tensor. The fused_attention function wraps this kernel, taking 7 parameters: q, k, v (input tensors), sm_scale (scaling factor), and optional o_buf, l_buf, m_buf (buffers for outputs). It sets up the grid and block sizes, and calls the kernel.", - "description_2": "Use triton language to create a fused attention operator for Flash Attention, involving a kernel with 12 parameters for computation and a wrapper function with 7 parameters for setup and execution.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef fused_attention_kernel(\n Out, L, M, # outputs\n Q, K, V,\n sm_scale,\n batch_size, num_heads, seq_len,\n BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,\n BLOCK_N: tl.constexpr,\n):\n start_m = tl.program_id(0)\n off_hz = tl.program_id(1)\n stride_h = BLOCK_DMODEL * seq_len\n\n # initialize offsets\n offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)\n offs_n = tl.arange(0, BLOCK_N)\n offs_d = tl.arange(0, BLOCK_DMODEL)\n off_q = off_hz * stride_h + offs_m[:, None] * BLOCK_DMODEL + offs_d[None, :]\n off_k = off_hz * stride_h + offs_n[None, :] * BLOCK_DMODEL + offs_d[:, None]\n off_v = off_hz * stride_h + offs_n[:, None] * BLOCK_DMODEL + offs_d[None, :]\n # Initialize pointers to Q, K, V\n q_ptrs = Q + off_q\n k_ptrs = K + off_k\n v_ptrs = V + off_v\n # initialize pointer to m and l\n m_prev = tl.zeros([BLOCK_M], dtype=tl.float32) - float(\"inf\")\n l_prev = tl.zeros([BLOCK_M], dtype=tl.float32)\n acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)\n # load q: it will stay in SRAM throughout\n q = tl.load(q_ptrs)\n # loop over k, v and update accumulator\n for start_n in range(0, (start_m + 1) * BLOCK_M, BLOCK_N):\n # -- compute qk ----\n k = tl.load(k_ptrs)\n qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)\n qk += tl.dot(q, k)\n qk *= sm_scale\n qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float(\"-inf\"))\n # compute new m\n m_curr = tl.maximum(tl.max(qk, 1), m_prev)\n # correct old l\n l_prev *= tl.exp(m_prev - m_curr)\n # attention weights\n p = tl.exp(qk - m_curr[:, None])\n l_curr = tl.sum(p, 1) + l_prev\n # rescale operands of matmuls\n l_rcp = 1. / l_curr\n p *= l_rcp[:, None]\n acc *= (l_prev * l_rcp)[:, None]\n # update acc\n p = p.to(Q.dtype.element_ty)\n v = tl.load(v_ptrs)\n acc += tl.dot(p, v)\n # update m_i and l_i\n l_prev = l_curr\n m_prev = m_curr\n # update pointers\n k_ptrs += BLOCK_N * BLOCK_DMODEL\n v_ptrs += BLOCK_N * BLOCK_DMODEL\n # rematerialize offsets to save registers\n start_m = tl.program_id(0)\n offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)\n # write back l and m\n l_ptrs = L + off_hz * seq_len + offs_m\n m_ptrs = M + off_hz * seq_len + offs_m\n tl.store(l_ptrs, l_prev)\n tl.store(m_ptrs, m_prev)\n # initialize pointers to output\n offs_n = tl.arange(0, BLOCK_DMODEL)\n off_o = off_hz * stride_h + offs_m[:, None] * BLOCK_DMODEL + offs_n[None, :]\n out_ptrs = Out + off_o\n tl.store(out_ptrs, acc)\n\n\ndef fused_attention(q, k, v, sm_scale, o_buf=None, l_buf=None, m_buf=None):\n BLOCK = 128 if q.dtype == torch.float16 else 64\n # shape constraints\n Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]\n assert Lq == Lk and Lk == Lv\n assert Lk in {16, 32, 64, 128}\n o = torch.empty_like(q) if o_buf is None else o_buf\n grid = (triton.cdiv(q.shape[2], BLOCK), q.shape[0] * q.shape[1], 1)\n shape = (q.shape[0] * q.shape[1], q.shape[2])\n L = torch.empty(shape, device=q.device, dtype=torch.float32) if l_buf is None else l_buf\n m = torch.empty(shape, device=q.device, dtype=torch.float32) if m_buf is None else m_buf\n\n num_warps = 4 if Lk <= 64 else 8\n # Adjust num_stages for limited resource cases.\n num_stages = 2 if torch.cuda.get_device_capability() >= (8, 0) else 1\n\n fused_attention_kernel[grid](\n o, L, m,\n q, k, v,\n sm_scale,\n q.shape[0], q.shape[1], q.shape[2],\n # tl.constexpr\n BLOCK_M=BLOCK,\n BLOCK_N=BLOCK,\n BLOCK_DMODEL=Lk,\n num_warps=num_warps,\n num_stages=num_stages,\n )\n\n return o\n", - "description_1": "Use triton language to implement a fused attention kernel that computes the attention mechanism for given query (Q), key (K), and value (V) matrices. The kernel takes into account the scaling factor (sm_scale) and processes the data in blocks defined by BLOCK_M, BLOCK_N, and BLOCK_DMODEL. The function fused_attention serves as a wrapper to set up the necessary parameters and launch the kernel.", - "description_2": "Use triton language to create a fused attention mechanism with block processing for Q, K, V matrices, considering a scaling factor.", - "difficulty": 3 - }, - { - "code": "import triton\nimport triton.language as tl\nfrom triton import Config, autotune, heuristics\n\n@autotune(\n configs=[\n Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1, 'BLOCK_Kfp': 16}, num_stages=3, num_warps=8),\n Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1, 'BLOCK_Kfp': 16}, num_stages=3, num_warps=8),\n Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1, 'BLOCK_Kfp': 16}, num_stages=4, num_warps=4),\n Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1, 'BLOCK_Kfp': 16}, num_stages=4, num_warps=4),\n Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1, 'BLOCK_Kfp': 16}, num_stages=4, num_warps=4),\n Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1, 'BLOCK_Kfp': 16}, num_stages=4, num_warps=4),\n Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1, 'BLOCK_Kfp': 16}, num_stages=4, num_warps=4),\n Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1, 'BLOCK_Kfp': 16}, num_stages=4, num_warps=4),\n Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1, 'BLOCK_Kfp': 16}, num_stages=5, num_warps=2),\n Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 128, 'SPLIT_K': 1, 'BLOCK_Kfp': 16}, num_stages=3, num_warps=8),\n Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 128, 'SPLIT_K': 1, 'BLOCK_Kfp': 16}, num_stages=3, num_warps=8),\n Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 128, 'SPLIT_K': 1, 'BLOCK_Kfp': 16}, num_stages=4, num_warps=4),\n Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 128, 'SPLIT_K': 1, 'BLOCK_Kfp': 16}, num_stages=4, num_warps=4),\n Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 128, 'SPLIT_K': 1, 'BLOCK_Kfp': 16}, num_stages=4, num_warps=4),\n Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 64, 'SPLIT_K': 1, 'BLOCK_Kfp': 16}, num_stages=4, num_warps=4),\n Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 64, 'SPLIT_K': 1, 'BLOCK_Kfp': 16}, num_stages=4, num_warps=4),\n Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 64, 'SPLIT_K': 1, 'BLOCK_Kfp': 16}, num_stages=4, num_warps=4),\n Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 64, 'SPLIT_K': 1, 'BLOCK_Kfp': 16}, num_stages=5, num_warps=2),\n ],\n key=['M', 'N'],\n prune_configs_by={\n 'early_config_prune': early_config_prune,\n 'perf_model': estimate_matmul_time,\n 'top_k': 10,\n },\n)\n@heuristics({\n 'EVEN_K': lambda args: args['K'] % (args['BLOCK_K'] * args['SPLIT_K']) == 0,\n})\n@triton.jit\ndef matmul_kernelint8(x, w, A, B, C, M, N, K, stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn, Afp, Bfp, Cfp, Kfp, stride_amfp, stride_akfp, stride_bkfp, stride_bnfp, stride_cmfp, stride_cnfp, acc_dtype: tl.constexpr, allow_tf32: tl.constexpr, fp8_fast_accum: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, BLOCK_Kfp: tl.constexpr, GROUP_M: tl.constexpr, SPLIT_K: tl.constexpr, EVEN_K: tl.constexpr, AB_DTYPE: tl.constexpr):\n pid = tl.program_id(0)\n pid_z = tl.program_id(1)\n grid_m = tl.cdiv(M, BLOCK_M)\n grid_n = tl.cdiv(N, BLOCK_N)\n width = GROUP_M * grid_n\n group_id = pid // width\n group_size = min(grid_m - group_id * GROUP_M, GROUP_M)\n pid_m = group_id * GROUP_M + (pid % group_size)\n pid_n = (pid % width) // (group_size)\n rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)\n rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)\n ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)\n rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)\n rk = pid_z * BLOCK_K + tl.arange(0, BLOCK_K)\n A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)\n B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)\n acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=acc_dtype)\n for k in range(0, tl.cdiv(K, BLOCK_K * SPLIT_K)):\n if EVEN_K:\n a = tl.load(A)\n b = tl.load(B)\n else:\n k_remaining = K - k * (BLOCK_K * SPLIT_K)\n _0 = tl.zeros((1, 1), dtype=C.dtype.element_ty)\n a = tl.load(A, mask=rk[None, :] < k_remaining, other=_0)\n b = tl.load(B, mask=rk[:, None] < k_remaining, other=_0)\n acc = tl.dot(a, b, acc, out_dtype=acc_dtype, allow_tf32=False)\n A += BLOCK_K * SPLIT_K * stride_ak\n B += BLOCK_K * SPLIT_K * stride_bk\n rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)\n rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)\n C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)\n mask = (rm < M)[:, None] & (rn < N)[None, :]\n if SPLIT_K == 1:\n tl.store(C, acc, mask=mask)\n else:\n tl.atomic_add(C, acc, mask=mask)\n\ndef matmulint8_fused_dequant(x, w, a, b, afp, bfp, c, cfp16, M, N, K, Kfp):\n grid = lambda META: (cdiv(M, META['BLOCK_M']) * cdiv(N, META['BLOCK_N']), META['SPLIT_K'])\n allow_tf32 = True\n fp8_fast_accum = True\n matmul_kernelint8[grid](\n x, w,\n a, b, c,\n M, N, K,\n K, 1,\n 1, K,\n N, 1,\n afp, bfp, cfp16, Kfp[0],\n Kfp, 1,\n 1, Kfp,\n N, 1,\n allow_tf32=allow_tf32,\n fp8_fast_accum=fp8_fast_accum,\n GROUP_M=8, acc_dtype=tl.int32, AB_DTYPE=None\n )\n return c, cfp16\n", - "description_1": "Use triton language to implement a matrix multiplication kernel for int8 data types with support for fused dequantization. The kernel function 'matmul_kernelint8' takes 30 parameters: x, w, A, B, C, M, N, K, stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn, Afp, Bfp, Cfp, Kfp, stride_amfp, stride_akfp, stride_bkfp, stride_bnfp, stride_cmfp, stride_cnfp, acc_dtype, allow_tf32, fp8_fast_accum, BLOCK_M, BLOCK_N, BLOCK_K, BLOCK_Kfp, GROUP_M, SPLIT_K, EVEN_K, AB_DTYPE. The function 'matmulint8_fused_dequant' is a wrapper that sets up the grid and calls the kernel with 11 parameters: x, w, a, b, afp, bfp, c, cfp16, M, N, K, Kfp.", - "description_2": "Use triton language to create a matrix multiplication kernel optimized for int8 data with fused dequantization, utilizing autotuning and heuristics for performance optimization.", - "difficulty": 4 - }, - { - "code": "import triton\nimport triton.language as tl\n\n@triton.jit\ndef matmul_kernelfp16(A, B, C, M, N, K,\n stride_amfp, stride_akfp, #\n stride_bkfp, stride_bnfp, #\n stride_cmfp, stride_cnfp,\n BLOCK_M: tl.constexpr, \n BLOCK_N: tl.constexpr,\n BLOCK_K: tl.constexpr, \n SPLIT_K: tl.constexpr, \n GROUP_M: tl.constexpr):\n # Matrix multiplication kernel\n pid = tl.program_id(0)\n pid_z = tl.program_id(1)\n grid_m = tl.cdiv(M, BLOCK_M)\n grid_n = tl.cdiv(N, BLOCK_N)\n width = GROUP_M * grid_n\n group_id = pid // width\n group_size = min(grid_m - group_id * GROUP_M, GROUP_M)\n pid_m = group_id * GROUP_M + (pid % group_size)\n pid_n = (pid % width) // (group_size)\n \n rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)\n rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)\n ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)\n rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)\n \n rkfp = tl.arange(0, BLOCK_K)\n A = A + (ram[:, None] * stride_amfp + rkfp[None, :] * stride_akfp)\n B = B + (rkfp[:, None] * stride_bkfp + rbn[None, :] * stride_bnfp)\n accfp = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)\n \n rmfp = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)\n rnfp = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)\n\n afp = tl.zeros((BLOCK_M, BLOCK_K), dtype=C.dtype.element_ty)\n bfp = tl.zeros((BLOCK_K, BLOCK_N), dtype=C.dtype.element_ty)\n C = C + (rmfp[:, None] * stride_cmfp + rnfp[None, :] * stride_cnfp) \n mask = (rm < M)[:, None] & (rn < N)[None, :]\n\n K_ = tl.load(K + 0)\n if K_ == 0:\n return \n\n maxK = tl.cdiv(K_, BLOCK_K)\n for k in range(0, maxK - 1):\n afp = tl.load(A)\n bfp = tl.load(B)\n A += BLOCK_K * stride_akfp\n B += BLOCK_K * stride_bkfp \n accfp = tl.dot(afp, bfp, accfp, out_dtype=tl.float32, allow_tf32=False)\n\n k = maxK - 1\n if K_ % BLOCK_K == 0:\n afp = tl.load(A)\n bfp = tl.load(B)\n else:\n k_remainingfp = K_ - k * BLOCK_K \n afp = tl.load(A, mask=rkfp[None, :] < k_remainingfp, other=0.0)\n bfp = tl.load(B, mask=rkfp[:, None] < k_remainingfp, other=0.0)\n\n accfp = tl.dot(afp, bfp, accfp, out_dtype=tl.float32, allow_tf32=False)\n accfp = accfp.to(tl.float16)\n tl.store(C, accfp, mask=mask)\n\ndef matmulfp16(afp, bfp, cfp16, M, N, K):\n # Function to call the Triton kernel for FP16 matrix multiplication\n grid = lambda META: (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), META['SPLIT_K'])\n \n matmul_kernelfp16[grid](\n afp, bfp, cfp16, M, N, K,\n 1, M, # Strides for A\n N, 1, # Strides for B\n N, 1, # Strides for C\n GROUP_M=8\n )\n return\n", - "description_1": "Use triton language to define a matrix multiplication kernel (matmul_kernelfp16) with 14 parameters where A, B, and C are the input and output matrices, M, N, K define the dimensions, stride_amfp, stride_akfp, stride_bkfp, stride_bnfp, stride_cmfp, stride_cnfp are stride values, BLOCK_M, BLOCK_N, BLOCK_K are block sizes, SPLIT_K, and GROUP_M are optimization constants. Additionally, define a function (matmulfp16) to call this kernel with matrices A, B, C, dimensions M, N, K, and predefined block and group sizes for FP16 matrix multiplication.", - "description_2": "Use triton language to implement a matrix multiplication kernel for FP16 data types with specific block size and stride parameters and provide a function to execute this kernel.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\nfrom .triton_utils.kernels import silu\n\n@triton.jit\ndef quant_fused_matmul_248_kernel(\n a_ptr, c_ptr, b1_ptr,\n scales1_ptr, zeros1_ptr,\n g1_ptr, b2_ptr,\n scales2_ptr, zeros2_ptr,\n g2_ptr,\n M, N, K,\n bits, maxq,\n stride_am, stride_ak,\n stride_bk, stride_bn,\n stride_cm, stride_cn,\n stride_scales, stride_zeros,\n BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,\n GROUP_SIZE_M: tl.constexpr\n):\n \"\"\"\n Computes: C = silu(A * B1) * (A * B2)\n A is of shape (M, K) float16\n B is of shape (K//8, N) int32\n C is of shape (M, N) float16\n scales is of shape (1, N) float16\n zeros is of shape (1, N//8) int32\n \"\"\"\n infearure_per_bits = 32 // bits\n\n pid = tl.program_id(axis=0)\n num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)\n num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)\n num_pid_k = tl.cdiv(K, BLOCK_SIZE_K)\n num_pid_in_group = GROUP_SIZE_M * num_pid_n\n group_id = pid // num_pid_in_group\n first_pid_m = group_id * GROUP_SIZE_M\n group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)\n pid_m = first_pid_m + (pid % group_size_m)\n pid_n = (pid % num_pid_in_group) // group_size_m\n\n offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n offs_k = tl.arange(0, BLOCK_SIZE_K)\n a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) # (BLOCK_SIZE_M, BLOCK_SIZE_K)\n a_mask = (offs_am[:, None] < M)\n # b_ptrs is set up such that it repeats elements along the K axis 8 times\n b1_ptrs = b1_ptr + ((offs_k[:, None] // infearure_per_bits) * stride_bk + offs_bn[None, :] * stride_bn)\n b2_ptrs = b2_ptr + ((offs_k[:, None] // infearure_per_bits) * stride_bk + offs_bn[None, :] * stride_bn)\n g1_ptrs = g1_ptr + offs_k\n g2_ptrs = g2_ptr + offs_k\n # shifter is used to extract the N bits of each element in the 32-bit word from B\n scales1_ptrs = scales1_ptr + offs_bn[None, :]\n scales2_ptrs = scales2_ptr + offs_bn[None, :]\n zeros1_ptrs = zeros1_ptr + (offs_bn[None, :] // infearure_per_bits)\n zeros2_ptrs = zeros2_ptr + (offs_bn[None, :] // infearure_per_bits)\n\n shifter = (offs_k % infearure_per_bits) * bits\n zeros_shifter = (offs_bn % infearure_per_bits) * bits\n accumulator1 = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n accumulator2 = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n for k in range(0, num_pid_k):\n g1_idx = tl.load(g1_ptrs)\n g2_idx = tl.load(g2_ptrs)\n\n # Fetch scales and zeros; these are per-outfeature and thus reused in the inner loop\n scales1 = tl.load(scales1_ptrs + g1_idx[:, None] * stride_scales) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)\n scales2 = tl.load(scales2_ptrs + g2_idx[:, None] * stride_scales)\n\n zeros1 = tl.load(zeros1_ptrs + g1_idx[:, None] * stride_zeros) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)\n zeros1 = (zeros1 >> zeros_shifter[None, :]) & maxq\n zeros1 = (zeros1 + 1)\n\n zeros2 = tl.load(zeros2_ptrs + g2_idx[:, None] * stride_zeros) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)\n zeros2 = (zeros2 >> zeros_shifter[None, :]) & maxq\n zeros2 = (zeros2 + 1)\n\n a = tl.load(a_ptrs, mask=a_mask, other=0.) # (BLOCK_SIZE_M, BLOCK_SIZE_K)\n b1 = tl.load(b1_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated\n b2 = tl.load(b2_ptrs)\n\n # Now we need to unpack b (which is N-bit values) into 32-bit values\n b1 = (b1 >> shifter[:, None]) & maxq # Extract the N-bit values\n b1 = (b1 - zeros1) * scales1 # Scale and shift\n accumulator1 += tl.dot(a, b1)\n\n b2 = (b2 >> shifter[:, None]) & maxq\n b2 = (b2 - zeros2) * scales2\n accumulator2 += tl.dot(a, b2)\n\n a_ptrs += BLOCK_SIZE_K\n b1_ptrs += (BLOCK_SIZE_K // infearure_per_bits) * stride_bk\n b2_ptrs += (BLOCK_SIZE_K // infearure_per_bits) * stride_bk\n g1_ptrs += BLOCK_SIZE_K\n g2_ptrs += BLOCK_SIZE_K\n\n accumulator1 = silu(accumulator1)\n c = accumulator1 * accumulator2\n c = c.to(tl.float16)\n c_ptrs = c_ptr + stride_cm * offs_am[:, None] + stride_cn * offs_bn[None, :]\n c_mask = (offs_am[:, None] < M) & (offs_bn[None, :] < N)\n tl.store(c_ptrs, c, mask=c_mask)\n\nclass FusedLlamaMLPForQuantizedModel:\n def triton_llama_mlp(self, x):\n with torch.cuda.device(x.device):\n out_shape = x.shape[:-1] + (self.intermediate_size, )\n x = x.reshape(-1, x.shape[-1])\n M, K = x.shape\n N = self.intermediate_size\n c = torch.empty((M, N), device=x.device, dtype=torch.float16)\n grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), )\n quant_fused_matmul_248_kernel[grid](\n x, c, self.gate_proj.qweight,\n self.gate_proj.scales, self.gate_proj.qzeros, self.gate_proj.g_idx,\n self.up_proj.qweight,\n self.up_proj.scales, self.up_proj.qzeros, self.up_proj.g_idx,\n M, N, K,\n self.bits, self.maxq,\n x.stride(0), x.stride(1),\n self.gate_proj.qweight.stride(0), self.gate_proj.qweight.stride(1),\n c.stride(0), c.stride(1),\n self.gate_proj.scales.stride(0), self.gate_proj.qzeros.stride(0)\n )\n c = c.reshape(out_shape)\n return c\n", - "description_1": "Use triton language to implement a quantized fused matrix multiplication kernel for a specific computation C = silu(A * B1) * (A * B2). The kernel handles data with specific bit manipulation and scaling.", - "description_2": "Use triton language to implement a class that utilizes the quantized fused matrix multiplication kernel to perform computations on tensors.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef quant_matmul_248_kernel(\n a_ptr, b_ptr, c_ptr,\n scales_ptr, zeros_ptr, g_ptr,\n M, N, K,\n bits, maxq,\n stride_am, stride_ak,\n stride_bk, stride_bn,\n stride_cm, stride_cn,\n stride_scales, stride_zeros,\n BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,\n GROUP_SIZE_M: tl.constexpr\n):\n \"\"\"\n Compute the matrix multiplication C = A x B.\n A is of shape (M, K) float16\n B is of shape (K//8, N) int32\n C is of shape (M, N) float16\n scales is of shape (G, N) float16\n zeros is of shape (G, N) float16\n g_ptr is of shape (K) int32\n \"\"\"\n infearure_per_bits = 32 // bits\n\n pid = tl.program_id(axis=0)\n num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)\n num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)\n num_pid_k = tl.cdiv(K, BLOCK_SIZE_K)\n num_pid_in_group = GROUP_SIZE_M * num_pid_n\n group_id = pid // num_pid_in_group\n first_pid_m = group_id * GROUP_SIZE_M\n group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)\n pid_m = first_pid_m + (pid % group_size_m)\n pid_n = (pid % num_pid_in_group) // group_size_m\n\n offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n offs_k = tl.arange(0, BLOCK_SIZE_K)\n a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) # (BLOCK_SIZE_M, BLOCK_SIZE_K)\n a_mask = (offs_am[:, None] < M)\n # b_ptrs is set up such that it repeats elements along the K axis 8 times\n b_ptrs = b_ptr + (\n (offs_k[:, None] // infearure_per_bits) * stride_bk + offs_bn[None, :] * stride_bn\n ) # (BLOCK_SIZE_K, BLOCK_SIZE_N)\n g_ptrs = g_ptr + offs_k\n # shifter is used to extract the N bits of each element in the 32-bit word from B\n scales_ptrs = scales_ptr + offs_bn[None, :]\n zeros_ptrs = zeros_ptr + (offs_bn[None, :] // infearure_per_bits)\n\n shifter = (offs_k % infearure_per_bits) * bits\n zeros_shifter = (offs_bn % infearure_per_bits) * bits\n accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n\n for k in range(0, num_pid_k):\n g_idx = tl.load(g_ptrs)\n\n # Fetch scales and zeros; these are per-outfeature and thus reused in the inner loop\n scales = tl.load(scales_ptrs + g_idx[:, None] * stride_scales) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)\n zeros = tl.load(zeros_ptrs + g_idx[:, None] * stride_zeros) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)\n\n zeros = (zeros >> zeros_shifter[None, :]) & maxq\n zeros = (zeros + 1)\n\n a = tl.load(a_ptrs, mask=a_mask, other=0.) # (BLOCK_SIZE_M, BLOCK_SIZE_K)\n b = tl.load(b_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated\n\n # Now we need to unpack b (which is N-bit values) into 32-bit values\n b = (b >> shifter[:, None]) & maxq # Extract the N-bit values\n b = (b - zeros) * scales # Scale and shift\n\n accumulator += tl.dot(a, b)\n a_ptrs += BLOCK_SIZE_K\n b_ptrs += (BLOCK_SIZE_K // infearure_per_bits) * stride_bk\n g_ptrs += BLOCK_SIZE_K\n\n c_ptrs = c_ptr + stride_cm * offs_am[:, None] + stride_cn * offs_bn[None, :]\n c_mask = (offs_am[:, None] < M) & (offs_bn[None, :] < N)\n tl.store(c_ptrs, accumulator, mask=c_mask)\n\n\n@triton.jit\ndef transpose_quant_matmul_248_kernel(\n a_ptr, b_ptr, c_ptr,\n scales_ptr, zeros_ptr, g_ptr,\n M, N, K,\n bits, maxq,\n stride_am, stride_ak,\n stride_bk, stride_bn,\n stride_cm, stride_cn,\n stride_scales, stride_zeros,\n BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,\n GROUP_SIZE_M: tl.constexpr\n):\n \"\"\"\n Compute the matrix multiplication C = A x B.\n A is of shape (M, N) float16\n B is of shape (K//8, N) int32\n C is of shape (M, K) float16\n scales is of shape (G, N) float16\n zeros is of shape (G, N) float16\n g_ptr is of shape (K) int32\n \"\"\"\n infearure_per_bits = 32 // bits\n\n pid = tl.program_id(axis=0)\n num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)\n num_pid_k = tl.cdiv(K, BLOCK_SIZE_K)\n num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)\n num_pid_in_group = GROUP_SIZE_M * num_pid_k\n group_id = pid // num_pid_in_group\n first_pid_m = group_id * GROUP_SIZE_M\n group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)\n pid_m = first_pid_m + (pid % group_size_m)\n pid_k = (pid % num_pid_in_group) // group_size_m\n\n offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_bk = pid_k * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)\n offs_n = tl.arange(0, BLOCK_SIZE_N)\n a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_n[None, :] * stride_ak) # (BLOCK_SIZE_M, BLOCK_SIZE_N)\n a_mask = (offs_am[:, None] < M)\n # b_ptrs is set up such that it repeats elements along the K axis 8 times\n b_ptrs = b_ptr + (\n (offs_bk[:, None] // infearure_per_bits) * stride_bk + offs_n[None, :] * stride_bn\n ) # (BLOCK_SIZE_K, BLOCK_SIZE_N)\n g_ptrs = g_ptr + offs_bk\n g_idx = tl.load(g_ptrs)\n\n # shifter is used to extract the N bits of each element in the 32-bit word from B\n scales_ptrs = scales_ptr + offs_n[None, :] + g_idx[:, None] * stride_scales\n zeros_ptrs = zeros_ptr + (offs_n[None, :] // infearure_per_bits) + g_idx[:, None] * stride_zeros\n\n shifter = (offs_bk % infearure_per_bits) * bits\n zeros_shifter = (offs_n % infearure_per_bits) * bits\n accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_K), dtype=tl.float32)\n\n for k in range(0, num_pid_n):\n # Fetch scales and zeros; these are per-outfeature and thus reused in the inner loop\n scales = tl.load(scales_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)\n zeros = tl.load(zeros_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)\n\n zeros = (zeros >> zeros_shifter[None, :]) & maxq\n zeros = (zeros + 1)\n\n a = tl.load(a_ptrs, mask=a_mask, other=0.) # (BLOCK_SIZE_M, BLOCK_SIZE_N)\n b = tl.load(b_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated\n\n # Now we need to unpack b (which is N-bit values) into 32-bit values\n b = (b >> shifter[:, None]) & maxq # Extract the N-bit values\n b = (b - zeros) * scales # Scale and shift\n b = tl.trans(b)\n\n accumulator += tl.dot(a, b)\n a_ptrs += BLOCK_SIZE_N\n b_ptrs += BLOCK_SIZE_N\n scales_ptrs += BLOCK_SIZE_N\n zeros_ptrs += (BLOCK_SIZE_N // infearure_per_bits)\n\n c_ptrs = c_ptr + stride_cm * offs_am[:, None] + stride_cn * offs_bk[None, :]\n c_mask = (offs_am[:, None] < M) & (offs_bk[None, :] < K)\n tl.store(c_ptrs, accumulator, mask=c_mask)\n\n\ndef quant_matmul_248(input, qweight, scales, qzeros, g_idx, bits, maxq):\n with torch.cuda.device(input.device):\n output = torch.empty((input.shape[0], qweight.shape[1]), device=input.device, dtype=input.dtype)\n grid = lambda META: (\n triton.cdiv(input.shape[0], META['BLOCK_SIZE_M']) * triton.cdiv(qweight.shape[1], META['BLOCK_SIZE_N']),\n )\n quant_matmul_248_kernel[grid](\n input, qweight, output,\n scales.to(input.dtype), qzeros, g_idx,\n input.shape[0], qweight.shape[1], input.shape[1],\n bits, maxq,\n input.stride(0), input.stride(1),\n qweight.stride(0), qweight.stride(1),\n output.stride(0), output.stride(1),\n scales.stride(0), qzeros.stride(0)\n )\n return output\n\n\ndef transpose_quant_matmul_248(input, qweight, scales, qzeros, g_idx, bits, maxq):\n with torch.cuda.device(input.device):\n output_dim = (qweight.shape[0] * 32) // bits\n output = torch.empty((input.shape[0], output_dim), device=input.device, dtype=input.dtype)\n grid = lambda META: (\n triton.cdiv(input.shape[0], META['BLOCK_SIZE_M']) * triton.cdiv(output_dim, META['BLOCK_SIZE_K']),)\n transpose_quant_matmul_248_kernel[grid](\n input, qweight, output,\n scales.to(input.dtype), qzeros, g_idx,\n input.shape[0], qweight.shape[1], output_dim,\n bits, maxq,\n input.stride(0), input.stride(1),\n qweight.stride(0), qweight.stride(1),\n output.stride(0), output.stride(1),\n scales.stride(0), qzeros.stride(0)\n )\n return output\n", - "description_1": "Use triton language to implement two matrix multiplication kernels: quant_matmul_248_kernel and transpose_quant_matmul_248_kernel. The first kernel performs matrix multiplication C = A x B where A is a float16 matrix of shape (M, K), B is an int32 matrix of shape (K//8, N), and C is a float16 matrix of shape (M, N). The second kernel performs a similar operation but with transposed dimensions. Both kernels use quantization parameters scales, zeros, and g_ptr to adjust the computation. The kernels are called by quant_matmul_248 and transpose_quant_matmul_248 functions, respectively, which set up the output tensor and grid configuration for execution.", - "description_2": "Use triton language to create quantized matrix multiplication kernels with support for custom block sizes and quantization parameters, and provide Python functions to execute these kernels on input tensors.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\nfrom triton import Config, autotune, heuristics\n\n@autotune(\n configs=[\n Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1, 'BLOCK_Kfp': 16}, num_stages=3, num_warps=8),\n Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1, 'BLOCK_Kfp': 16}, num_stages=3, num_warps=8),\n Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1, 'BLOCK_Kfp': 16}, num_stages=4, num_warps=4),\n Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1, 'BLOCK_Kfp': 16}, num_stages=4, num_warps=4),\n Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1, 'BLOCK_Kfp': 16}, num_stages=4, num_warps=4),\n Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1, 'BLOCK_Kfp': 16}, num_stages=4, num_warps=4),\n Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1, 'BLOCK_Kfp': 16}, num_stages=4, num_warps=4),\n Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1, 'BLOCK_Kfp': 16}, num_stages=4, num_warps=4),\n Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1, 'BLOCK_Kfp': 16}, num_stages=5, num_warps=2),\n Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 128, 'SPLIT_K': 1, 'BLOCK_Kfp': 16}, num_stages=3, num_warps=8),\n Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 128, 'SPLIT_K': 1, 'BLOCK_Kfp': 16}, num_stages=3, num_warps=8),\n Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 128, 'SPLIT_K': 1, 'BLOCK_Kfp': 16}, num_stages=4, num_warps=4),\n Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 128, 'SPLIT_K': 1, 'BLOCK_Kfp': 16}, num_stages=4, num_warps=4),\n Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 128, 'SPLIT_K': 1, 'BLOCK_Kfp': 16}, num_stages=4, num_warps=4),\n Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 64, 'SPLIT_K': 1, 'BLOCK_Kfp': 16}, num_stages=4, num_warps=4),\n Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 64, 'SPLIT_K': 1, 'BLOCK_Kfp': 16}, num_stages=4, num_warps=4),\n Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 64, 'SPLIT_K': 1, 'BLOCK_Kfp': 16}, num_stages=4, num_warps=4),\n Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 64, 'SPLIT_K': 1, 'BLOCK_Kfp': 16}, num_stages=5, num_warps=2),\n ],\n key=['M', 'N'],\n prune_configs_by={\n 'early_config_prune': early_config_prune,\n 'perf_model': estimate_matmul_time,\n 'top_k': 10,\n },\n)\n@heuristics({\n 'EVEN_K': lambda args: args['K'] % (args['BLOCK_K'] * args['SPLIT_K']) == 0,\n})\n@triton.jit\ndef matmul_kernelint8(x, w, A, B, C, M, N, K, stride_am, stride_ak, stride_bk, stride_bn,\n stride_cm, stride_cn, Afp, Bfp, Cfp, Kfp, stride_amfp, stride_akfp,\n stride_bkfp, stride_bnfp, stride_cmfp, stride_cnfp, acc_dtype: tl.constexpr,\n allow_tf32: tl.constexpr, fp8_fast_accum: tl.constexpr,\n BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,\n BLOCK_Kfp: tl.constexpr, GROUP_M: tl.constexpr, SPLIT_K: tl.constexpr,\n EVEN_K: tl.constexpr, AB_DTYPE: tl.constexpr):\n # matrix multiplication\n pid = tl.program_id(0)\n pid_z = tl.program_id(1)\n grid_m = tl.cdiv(M, BLOCK_M)\n grid_n = tl.cdiv(N, BLOCK_N)\n width = GROUP_M * grid_n\n group_id = pid // width\n group_size = min(grid_m - group_id * GROUP_M, GROUP_M)\n pid_m = group_id * GROUP_M + (pid % group_size)\n pid_n = (pid % width) // (group_size)\n rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)\n rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)\n ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)\n rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)\n rk = pid_z * BLOCK_K + tl.arange(0, BLOCK_K)\n A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)\n B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)\n acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=acc_dtype)\n for k in range(0, tl.cdiv(K, BLOCK_K * SPLIT_K)):\n if EVEN_K:\n a = tl.load(A)\n b = tl.load(B)\n else:\n k_remaining = K - k * (BLOCK_K * SPLIT_K)\n _0 = tl.zeros((1, 1), dtype=C.dtype.element_ty)\n a = tl.load(A, mask=rk[None, :] < k_remaining, other=_0)\n b = tl.load(B, mask=rk[:, None] < k_remaining, other=_0)\n acc = tl.dot(a, b, acc, out_dtype=acc_dtype, allow_tf32=False)\n A += BLOCK_K * SPLIT_K * stride_ak\n B += BLOCK_K * SPLIT_K * stride_bk\n rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)\n rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)\n C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)\n mask = (rm < M)[:, None] & (rn < N)[None, :]\n if SPLIT_K == 1:\n tl.store(C, acc, mask=mask)\n else:\n tl.atomic_add(C, acc, mask=mask)\n\n\ndef matmulint8_fused_dequant(x, w, a, b, afp, bfp, c, cfp16, M, N, K, Kfp):\n grid = lambda META: (cdiv(M, META['BLOCK_M']) * cdiv(N, META['BLOCK_N']), META['SPLIT_K'])\n allow_tf32 = True\n fp8_fast_accum = True\n matmul_kernelint8[grid](\n x, w, a, b, c,\n M, N, K, K, 1,\n 1, K, N, 1,\n afp, bfp, cfp16, Kfp[0],\n Kfp, 1, 1, Kfp, N, 1,\n allow_tf32=allow_tf32,\n fp8_fast_accum=fp8_fast_accum,\n GROUP_M=8, acc_dtype=tl.int32, AB_DTYPE=None\n )\n return c, cfp16\n", - "description_1": "Use triton language to define a matrix multiplication kernel `matmul_kernelint8` that performs computation on matrices using integer 8-bit values. The kernel requires parameters such as input matrices `x` and `w`, matrices `A`, `B`, `C`, their dimensions `M`, `N`, `K`, various strides, accumulators and block size constants. The kernel is highly configurable with multiple autotune configurations for optimizing performance on different matrix sizes and shapes. The function `matmulint8_fused_dequant` serves as a wrapper to define grid size and call the kernel with specific parameters, using optional float16 accumulation fast path and tf32 operations.", - "description_2": "Use triton language to perform integer 8-bit matrix multiplication with configurable autotuning and optional fast paths.", - "difficulty": 3 - }, - { - "code": "import triton\nimport triton.language as tl\nfrom triton import cdiv\n\n@triton.jit\ndef matmul_kernelfp16(A, B, C, M, N, K,\n stride_amfp, stride_akfp, #\n stride_bkfp, stride_bnfp, #\n stride_cmfp, stride_cnfp,\n BLOCK_M: tl.constexpr, \n BLOCK_N: tl.constexpr,\n BLOCK_K: tl.constexpr, \n SPLIT_K: tl.constexpr, \n GROUP_M: tl.constexpr):\n # matrix multiplication\n pid = tl.program_id(0)\n pid_z = tl.program_id(1)\n grid_m = tl.cdiv(M, BLOCK_M)\n grid_n = tl.cdiv(N, BLOCK_N)\n # re-order program ID for better L2 performance\n width = GROUP_M * grid_n\n group_id = pid // width\n group_size = min(grid_m - group_id * GROUP_M, GROUP_M)\n pid_m = group_id * GROUP_M + (pid % group_size)\n pid_n = (pid % width) // group_size\n # do matrix multiplication\n rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)\n rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)\n ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)\n rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)\n\n # pointers\n rkfp = tl.arange(0, BLOCK_K)\n A = A + (ram[:, None] * stride_amfp + rkfp[None, :] * stride_akfp)\n B = B + (rkfp[:, None] * stride_bkfp + rbn[None, :] * stride_bnfp)\n\n accfp = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)\n\n rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)\n rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)\n\n rmfp = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)\n rnfp = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)\n\n afp = tl.zeros((BLOCK_M, BLOCK_K), dtype=C.dtype.element_ty)\n bfp = tl.zeros((BLOCK_K, BLOCK_N), dtype=C.dtype.element_ty)\n C = C + (rmfp[:, None] * stride_cmfp + rnfp[None, :] * stride_cnfp)\n mask = (rm < M)[:, None] & (rn < N)[None, :]\n\n K_ = tl.load(K + 0)\n if K_ == 0:\n return\n\n maxK = tl.cdiv(K_, BLOCK_K)\n for k in range(0, maxK - 1):\n afp = tl.load(A)\n bfp = tl.load(B)\n\n A += BLOCK_K * stride_akfp\n B += BLOCK_K * stride_bkfp\n\n accfp = tl.dot(afp, bfp, accfp, out_dtype=tl.float32, allow_tf32=False)\n\n k = maxK - 1\n if K_ % BLOCK_K == 0:\n afp = tl.load(A)\n bfp = tl.load(B)\n else:\n k_remainingfp = K_ - k * BLOCK_K\n afp = tl.load(A, mask=rkfp[None, :] < k_remainingfp, other=0.0)\n bfp = tl.load(B, mask=rkfp[:, None] < k_remainingfp, other=0.0)\n\n accfp = tl.dot(afp, bfp, accfp, out_dtype=tl.float32, allow_tf32=False)\n\n accfp = accfp.to(tl.float16)\n\n tl.store(C, accfp, mask=mask)\n\n\ndef matmulfp16(afp, bfp, cfp16, M, N, K):\n grid = lambda META: (cdiv(M, META['BLOCK_M']) * cdiv(N, META['BLOCK_N']), META['SPLIT_K'])\n\n matmul_kernelfp16[grid](\n afp, bfp, cfp16, M, N, K,\n 1, M, #\n N, 1, #\n N, 1, #\n GROUP_M=8\n )\n return\n", - "description_1": "Use triton language to implement a matrix multiplication kernel 'matmul_kernelfp16' with inputs A, B, C and integer parameters M, N, K, and strides for each dimension. The kernel performs matrix multiplication on blocks of size BLOCK_M x BLOCK_N, while handling edge cases when the matrix dimensions are not multiples of block sizes. The function 'matmulfp16' is a wrapper that sets up grid dimensions for the kernel execution.", - "description_2": "Use triton language to perform block-wise matrix multiplication with custom tiling sizes and striding.", - "difficulty": 3 - }, - { - "code": "import triton\nimport triton.language as tl\nimport torch\nfrom .triton_utils.kernels import silu\n\n@triton.jit\ndef quant_fused_matmul_248_kernel(\n a_ptr, c_ptr, b1_ptr,\n scales1_ptr, zeros1_ptr,\n g1_ptr, b2_ptr,\n scales2_ptr, zeros2_ptr,\n g2_ptr,\n M, N, K,\n bits, maxq,\n stride_am, stride_ak,\n stride_bk, stride_bn,\n stride_cm, stride_cn,\n stride_scales, stride_zeros,\n BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,\n GROUP_SIZE_M: tl.constexpr\n):\n \"\"\"\n Computes: C = silu(A * B1) * (A * B2)\n A is of shape (M, K) float16\n B is of shape (K//8, N) int32\n C is of shape (M, N) float16\n scales is of shape (1, N) float16\n zeros is of shape (1, N//8) int32\n \"\"\"\n infearure_per_bits = 32 // bits\n\n pid = tl.program_id(axis=0)\n num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)\n num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)\n num_pid_k = tl.cdiv(K, BLOCK_SIZE_K)\n num_pid_in_group = GROUP_SIZE_M * num_pid_n\n group_id = pid // num_pid_in_group\n first_pid_m = group_id * GROUP_SIZE_M\n group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)\n pid_m = first_pid_m + (pid % group_size_m)\n pid_n = (pid % num_pid_in_group) // group_size_m\n\n offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n offs_k = tl.arange(0, BLOCK_SIZE_K)\n a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) # (BLOCK_SIZE_M, BLOCK_SIZE_K)\n a_mask = (offs_am[:, None] < M)\n # b_ptrs is set up such that it repeats elements along the K axis 8 times\n b1_ptrs = b1_ptr + ((offs_k[:, None] // infearure_per_bits) * stride_bk + offs_bn[None, :] * stride_bn)\n b2_ptrs = b2_ptr + ((offs_k[:, None] // infearure_per_bits) * stride_bk + offs_bn[None, :] * stride_bn)\n g1_ptrs = g1_ptr + offs_k\n g2_ptrs = g2_ptr + offs_k\n # shifter is used to extract the N bits of each element in the 32-bit word from B\n scales1_ptrs = scales1_ptr + offs_bn[None, :]\n scales2_ptrs = scales2_ptr + offs_bn[None, :]\n zeros1_ptrs = zeros1_ptr + (offs_bn[None, :] // infearure_per_bits)\n zeros2_ptrs = zeros2_ptr + (offs_bn[None, :] // infearure_per_bits)\n\n shifter = (offs_k % infearure_per_bits) * bits\n zeros_shifter = (offs_bn % infearure_per_bits) * bits\n accumulator1 = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n accumulator2 = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n for k in range(0, num_pid_k):\n g1_idx = tl.load(g1_ptrs)\n g2_idx = tl.load(g2_ptrs)\n\n # Fetch scales and zeros; these are per-outfeature and thus reused in the inner loop\n scales1 = tl.load(scales1_ptrs + g1_idx[:, None] * stride_scales) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)\n scales2 = tl.load(scales2_ptrs + g2_idx[:, None] * stride_scales)\n\n zeros1 = tl.load(zeros1_ptrs + g1_idx[:, None] * stride_zeros) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)\n zeros1 = (zeros1 >> zeros_shifter[None, :]) & maxq\n zeros1 = (zeros1 + 1)\n\n zeros2 = tl.load(zeros2_ptrs + g2_idx[:, None] * stride_zeros) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)\n zeros2 = (zeros2 >> zeros_shifter[None, :]) & maxq\n zeros2 = (zeros2 + 1)\n\n a = tl.load(a_ptrs, mask=a_mask, other=0.) # (BLOCK_SIZE_M, BLOCK_SIZE_K)\n b1 = tl.load(b1_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated\n b2 = tl.load(b2_ptrs)\n\n # Now we need to unpack b (which is N-bit values) into 32-bit values\n b1 = (b1 >> shifter[:, None]) & maxq # Extract the N-bit values\n b1 = (b1 - zeros1) * scales1 # Scale and shift\n accumulator1 += tl.dot(a, b1)\n\n b2 = (b2 >> shifter[:, None]) & maxq\n b2 = (b2 - zeros2) * scales2\n accumulator2 += tl.dot(a, b2)\n\n a_ptrs += BLOCK_SIZE_K\n b1_ptrs += (BLOCK_SIZE_K // infearure_per_bits) * stride_bk\n b2_ptrs += (BLOCK_SIZE_K // infearure_per_bits) * stride_bk\n g1_ptrs += BLOCK_SIZE_K\n g2_ptrs += BLOCK_SIZE_K\n\n accumulator1 = silu(accumulator1)\n c = accumulator1 * accumulator2\n c = c.to(tl.float16)\n c_ptrs = c_ptr + stride_cm * offs_am[:, None] + stride_cn * offs_bn[None, :]\n c_mask = (offs_am[:, None] < M) & (offs_bn[None, :] < N)\n tl.store(c_ptrs, c, mask=c_mask)\n\nclass FusedLlamaMLPForQuantizedModel:\n def triton_llama_mlp(self, x):\n with torch.cuda.device(x.device):\n out_shape = x.shape[:-1] + (self.intermediate_size, )\n x = x.reshape(-1, x.shape[-1])\n M, K = x.shape\n N = self.intermediate_size\n c = torch.empty((M, N), device=x.device, dtype=torch.float16)\n grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), )\n quant_fused_matmul_248_kernel[grid](\n x, c, self.gate_proj.qweight,\n self.gate_proj.scales, self.gate_proj.qzeros, self.gate_proj.g_idx,\n self.up_proj.qweight,\n self.up_proj.scales, self.up_proj.qzeros, self.up_proj.g_idx,\n M, N, K,\n self.bits, self.maxq,\n x.stride(0), x.stride(1),\n self.gate_proj.qweight.stride(0), self.gate_proj.qweight.stride(1),\n c.stride(0), c.stride(1),\n self.gate_proj.scales.stride(0), self.gate_proj.qzeros.stride(0)\n )\n c = c.reshape(out_shape)\n return c\n", - "description_1": "Use triton language to implement a kernel function 'quant_fused_matmul_248_kernel' that performs a fused matrix multiplication and element-wise operations. The kernel takes 28 parameters: pointers to input matrices, scales, zeros, group indices, dimensions M, N, K, bit width, max quantization value, and strides for accessing memory. It computes the output matrix C by applying the silu activation function to the product of input matrices A and B1, and then multiplies it with the product of A and B2. The function 'triton_llama_mlp' calls this kernel with appropriate grid configuration and reshapes the output.", - "description_2": "Use triton language to create a fused matrix multiplication kernel with silu activation and quantization support, and implement a function to call this kernel with specific input configurations.", - "difficulty": 4 - }, - { - "code": "import triton\nimport triton.language as tl\nfrom triton import Config, autotune, heuristics\n\n@autotune(\n configs=[\n Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1, 'BLOCK_Kfp': 16}, num_stages=3, num_warps=8),\n Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1, 'BLOCK_Kfp': 16}, num_stages=3, num_warps=8),\n Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1, 'BLOCK_Kfp': 16}, num_stages=4, num_warps=4),\n Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1, 'BLOCK_Kfp': 16}, num_stages=4, num_warps=4),\n Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1, 'BLOCK_Kfp': 16}, num_stages=4, num_warps=4),\n Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1, 'BLOCK_Kfp': 16}, num_stages=4, num_warps=4),\n Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1, 'BLOCK_Kfp': 16}, num_stages=4, num_warps=4),\n Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1, 'BLOCK_Kfp': 16}, num_stages=4, num_warps=4),\n Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1, 'BLOCK_Kfp': 16}, num_stages=5, num_warps=2),\n Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 128, 'SPLIT_K': 1, 'BLOCK_Kfp': 16}, num_stages=3, num_warps=8),\n Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 128, 'SPLIT_K': 1, 'BLOCK_Kfp': 16}, num_stages=3, num_warps=8),\n Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 128, 'SPLIT_K': 1, 'BLOCK_Kfp': 16}, num_stages=4, num_warps=4),\n Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 128, 'SPLIT_K': 1, 'BLOCK_Kfp': 16}, num_stages=4, num_warps=4),\n Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 128, 'SPLIT_K': 1, 'BLOCK_Kfp': 16}, num_stages=4, num_warps=4),\n Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 64, 'SPLIT_K': 1, 'BLOCK_Kfp': 16}, num_stages=4, num_warps=4),\n Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 64, 'SPLIT_K': 1, 'BLOCK_Kfp': 16}, num_stages=4, num_warps=4),\n Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 64, 'SPLIT_K': 1, 'BLOCK_Kfp': 16}, num_stages=4, num_warps=4),\n Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 64, 'SPLIT_K': 1, 'BLOCK_Kfp': 16}, num_stages=5, num_warps=2),\n ],\n key=['M', 'N'],\n prune_configs_by={\n 'early_config_prune': early_config_prune,\n 'perf_model': estimate_matmul_time,\n 'top_k': 10,\n },\n)\n@heuristics({\n 'EVEN_K': lambda args: args['K'] % (args['BLOCK_K'] * args['SPLIT_K']) == 0,\n})\n@triton.jit\ndef matmul_kernelint8(x, w, A, B, C, M, N, K, stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn, Afp, Bfp, Cfp, Kfp, stride_amfp, stride_akfp, stride_bkfp, stride_bnfp, stride_cmfp, stride_cnfp, acc_dtype: tl.constexpr, allow_tf32: tl.constexpr, fp8_fast_accum: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, BLOCK_Kfp: tl.constexpr, GROUP_M: tl.constexpr, SPLIT_K: tl.constexpr, EVEN_K: tl.constexpr, AB_DTYPE: tl.constexpr):\n pid = tl.program_id(0)\n pid_z = tl.program_id(1)\n grid_m = tl.cdiv(M, BLOCK_M)\n grid_n = tl.cdiv(N, BLOCK_N)\n width = GROUP_M * grid_n\n group_id = pid // width\n group_size = min(grid_m - group_id * GROUP_M, GROUP_M)\n pid_m = group_id * GROUP_M + (pid % group_size)\n pid_n = (pid % width) // (group_size)\n rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)\n rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)\n ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)\n rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)\n rk = pid_z * BLOCK_K + tl.arange(0, BLOCK_K)\n A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)\n B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)\n acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=acc_dtype)\n for k in range(0, tl.cdiv(K, BLOCK_K * SPLIT_K)):\n if EVEN_K:\n a = tl.load(A)\n b = tl.load(B)\n else:\n k_remaining = K - k * (BLOCK_K * SPLIT_K)\n _0 = tl.zeros((1, 1), dtype=C.dtype.element_ty)\n a = tl.load(A, mask=rk[None, :] < k_remaining, other=_0)\n b = tl.load(B, mask=rk[:, None] < k_remaining, other=_0)\n acc = tl.dot(a, b, acc, out_dtype=acc_dtype, allow_tf32=False)\n A += BLOCK_K * SPLIT_K * stride_ak\n B += BLOCK_K * SPLIT_K * stride_bk\n rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)\n rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)\n C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)\n mask = (rm < M)[:, None] & (rn < N)[None, :]\n if SPLIT_K == 1:\n tl.store(C, acc, mask=mask)\n else:\n tl.atomic_add(C, acc, mask=mask)\n\ndef matmulint8_fused_dequant(x, w, a, b, afp, bfp, c, cfp16, M, N, K, Kfp):\n grid = lambda META: (cdiv(M, META['BLOCK_M']) * cdiv(N, META['BLOCK_N']), META['SPLIT_K'])\n allow_tf32 = True\n fp8_fast_accum = True\n matmul_kernelint8[grid](\n x, w,\n a, b, c,\n M, N, K,\n K, 1,\n 1, K,\n N, 1,\n afp, bfp, cfp16, Kfp[0],\n Kfp, 1,\n 1, Kfp,\n N, 1,\n allow_tf32=allow_tf32,\n fp8_fast_accum=fp8_fast_accum,\n GROUP_M=8, acc_dtype=tl.int32, AB_DTYPE=None\n )\n return c, cfp16\n", - "description_1": "Use triton language to implement a matrix multiplication kernel for int8 data types with support for fused dequantization. The kernel 'matmul_kernelint8' takes 30 parameters including input matrices, dimensions, strides, and configuration constants. The function 'matmulint8_fused_dequant' sets up the grid and calls the kernel with 12 parameters including input matrices, dimensions, and configuration flags.", - "description_2": "Use triton language to create a matrix multiplication kernel for int8 with fused dequantization, and a function to configure and call this kernel.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n# Function to generate configurations for FP IO-bound cases\ndef get_configs_fp_io_bound():\n configs = []\n for num_stages in [2, 3, 4, 5, 6]:\n for block_m in [16, 32]:\n for block_kfp in [32, 64]:\n for block_n in [32, 64, 128, 256]:\n num_warps = 2 if block_n <= 64 else 4\n configs.append(\n Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_kfp, 'SPLIT_K': 1},\n num_stages=num_stages, num_warps=num_warps))\n return configs\n\n# Triton kernel for matrix multiplication\n@triton.jit\ndef matmul_kernelfp16(A, B, C, M, N, K,\n stride_amfp, stride_akfp, \n stride_bkfp, stride_bnfp, \n stride_cmfp, stride_cnfp,\n BLOCK_M: tl.constexpr, \n BLOCK_N: tl.constexpr,\n BLOCK_K: tl.constexpr, \n SPLIT_K: tl.constexpr, \n GROUP_M: tl.constexpr):\n # Matrix multiplication logic\n pid = tl.program_id(0)\n pid_z = tl.program_id(1)\n grid_m = tl.cdiv(M, BLOCK_M)\n grid_n = tl.cdiv(N, BLOCK_N)\n \n # Re-order program ID for better L2 performance\n width = GROUP_M * grid_n\n group_id = pid // width\n group_size = min(grid_m - group_id * GROUP_M, GROUP_M)\n pid_m = group_id * GROUP_M + (pid % group_size)\n pid_n = (pid % width) // (group_size)\n \n # Matrix row and column index calculations\n rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)\n rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)\n ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)\n rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)\n\n # Pointer arithmetic\n rkfp = tl.arange(0, BLOCK_K)\n A = A + (ram[:, None] * stride_amfp + rkfp[None, :] * stride_akfp)\n B = B + (rkfp[:, None] * stride_bkfp + rbn[None, :] * stride_bnfp)\n\n # Initialize accumulator\n accfp = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)\n\n # Load A and B matrices\n afp = tl.zeros((BLOCK_M, BLOCK_K), dtype=C.dtype.element_ty)\n bfp = tl.zeros((BLOCK_K, BLOCK_N), dtype=C.dtype.element_ty)\n C = C + (ram[:, None] * stride_cmfp + rbn[None, :] * stride_cnfp) \n mask = (rm < M)[:, None] & (rn < N)[None, :]\n\n # Loop over K dimension\n K_ = tl.load(K + 0)\n if K_ == 0:\n return \n\n maxK = tl.cdiv(K_, BLOCK_K)\n for k in range(0, maxK - 1):\n afp = tl.load(A)\n bfp = tl.load(B)\n A += BLOCK_K * stride_akfp\n B += BLOCK_K * stride_bkfp \n accfp = tl.dot(afp, bfp, accfp, out_dtype=tl.float32, allow_tf32=False)\n\n # Process the last K block\n k = maxK - 1\n if K_ % BLOCK_K == 0:\n afp = tl.load(A)\n bfp = tl.load(B)\n else:\n k_remainingfp = K_ - k * BLOCK_K\n afp = tl.load(A, mask=rkfp[None, :] < k_remainingfp, other=0.0)\n bfp = tl.load(B, mask=rkfp[:, None] < k_remainingfp, other=0.0)\n\n # Final accumulation step\n accfp = tl.dot(afp, bfp, accfp, out_dtype=tl.float32, allow_tf32=False)\n accfp = accfp.to(tl.float16)\n\n # Store the result in C matrix\n tl.store(C, accfp, mask=mask)\n\n# Wrapper function to launch the kernel\ndef matmulfp16(afp, bfp, cfp16, M, N, K):\n grid = lambda META: (cdiv(M, META['BLOCK_M']) * cdiv(N, META['BLOCK_N']), META['SPLIT_K'])\n\n matmul_kernelfp16[grid](\n afp, bfp, cfp16, M, N, K,\n 1, M, # stride_amfp, stride_akfp\n N, 1, # stride_bkfp, stride_bnfp\n N, 1, # stride_cmfp, stride_cnfp\n GROUP_M=8\n )\n return\n", - "description_1": "Use triton language to perform matrix multiplication with fp16 precision, supporting tiling and parallelization across the M, N, and K dimensions. The kernel computes the dot product of two matrices A and B, storing the result in matrix C, and handles different block sizes and stride configurations for efficient memory access.", - "description_2": "Use triton language to perform fp16 matrix multiplication with tiling, parallelization, and custom block sizes and strides for efficient GPU execution.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef quant_matmul_248_kernel(\n a_ptr, b_ptr, c_ptr,\n scales_ptr, zeros_ptr, g_ptr,\n M, N, K,\n bits, maxq,\n stride_am, stride_ak,\n stride_bk, stride_bn,\n stride_cm, stride_cn,\n stride_scales, stride_zeros,\n BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,\n GROUP_SIZE_M: tl.constexpr\n):\n \"\"\"\n Compute the matrix multiplication C = A x B.\n A is of shape (M, K) float16\n B is of shape (K//8, N) int32\n C is of shape (M, N) float16\n scales is of shape (G, N) float16\n zeros is of shape (G, N) float16\n g_ptr is of shape (K) int32\n \"\"\"\n infearure_per_bits = 32 // bits\n\n pid = tl.program_id(axis=0)\n num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)\n num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)\n num_pid_k = tl.cdiv(K, BLOCK_SIZE_K)\n num_pid_in_group = GROUP_SIZE_M * num_pid_n\n group_id = pid // num_pid_in_group\n first_pid_m = group_id * GROUP_SIZE_M\n group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)\n pid_m = first_pid_m + (pid % group_size_m)\n pid_n = (pid % num_pid_in_group) // group_size_m\n\n offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n offs_k = tl.arange(0, BLOCK_SIZE_K)\n a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) # (BLOCK_SIZE_M, BLOCK_SIZE_K)\n a_mask = (offs_am[:, None] < M)\n # b_ptrs is set up such that it repeats elements along the K axis 8 times\n b_ptrs = b_ptr + (\n (offs_k[:, None] // infearure_per_bits) * stride_bk + offs_bn[None, :] * stride_bn\n ) # (BLOCK_SIZE_K, BLOCK_SIZE_N)\n g_ptrs = g_ptr + offs_k\n # shifter is used to extract the N bits of each element in the 32-bit word from B\n scales_ptrs = scales_ptr + offs_bn[None, :]\n zeros_ptrs = zeros_ptr + (offs_bn[None, :] // infearure_per_bits)\n\n shifter = (offs_k % infearure_per_bits) * bits\n zeros_shifter = (offs_bn % infearure_per_bits) * bits\n accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n\n for k in range(0, num_pid_k):\n g_idx = tl.load(g_ptrs)\n\n # Fetch scales and zeros; these are per-outfeature and thus reused in the inner loop\n scales = tl.load(scales_ptrs + g_idx[:, None] * stride_scales) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)\n zeros = tl.load(zeros_ptrs + g_idx[:, None] * stride_zeros) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)\n\n zeros = (zeros >> zeros_shifter[None, :]) & maxq\n zeros = (zeros + 1)\n\n a = tl.load(a_ptrs, mask=a_mask, other=0.) # (BLOCK_SIZE_M, BLOCK_SIZE_K)\n b = tl.load(b_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated\n\n # Now we need to unpack b (which is N-bit values) into 32-bit values\n b = (b >> shifter[:, None]) & maxq # Extract the N-bit values\n b = (b - zeros) * scales # Scale and shift\n\n accumulator += tl.dot(a, b)\n a_ptrs += BLOCK_SIZE_K\n b_ptrs += (BLOCK_SIZE_K // infearure_per_bits) * stride_bk\n g_ptrs += BLOCK_SIZE_K\n\n c_ptrs = c_ptr + stride_cm * offs_am[:, None] + stride_cn * offs_bn[None, :]\n c_mask = (offs_am[:, None] < M) & (offs_bn[None, :] < N)\n tl.store(c_ptrs, accumulator, mask=c_mask)\n\n\n@triton.jit\ndef transpose_quant_matmul_248_kernel(\n a_ptr, b_ptr, c_ptr,\n scales_ptr, zeros_ptr, g_ptr,\n M, N, K,\n bits, maxq,\n stride_am, stride_ak,\n stride_bk, stride_bn,\n stride_cm, stride_cn,\n stride_scales, stride_zeros,\n BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,\n GROUP_SIZE_M: tl.constexpr\n):\n \"\"\"\n Compute the matrix multiplication C = A x B.\n A is of shape (M, N) float16\n B is of shape (K//8, N) int32\n C is of shape (M, K) float16\n scales is of shape (G, N) float16\n zeros is of shape (G, N) float16\n g_ptr is of shape (K) int32\n \"\"\"\n infearure_per_bits = 32 // bits\n\n pid = tl.program_id(axis=0)\n num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)\n num_pid_k = tl.cdiv(K, BLOCK_SIZE_K)\n num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)\n num_pid_in_group = GROUP_SIZE_M * num_pid_k\n group_id = pid // num_pid_in_group\n first_pid_m = group_id * GROUP_SIZE_M\n group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)\n pid_m = first_pid_m + (pid % group_size_m)\n pid_k = (pid % num_pid_in_group) // group_size_m\n\n offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_bk = pid_k * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)\n offs_n = tl.arange(0, BLOCK_SIZE_N)\n a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_n[None, :] * stride_ak) # (BLOCK_SIZE_M, BLOCK_SIZE_N)\n a_mask = (offs_am[:, None] < M)\n # b_ptrs is set up such that it repeats elements along the K axis 8 times\n b_ptrs = b_ptr + (\n (offs_bk[:, None] // infearure_per_bits) * stride_bk + offs_n[None, :] * stride_bn\n ) # (BLOCK_SIZE_K, BLOCK_SIZE_N)\n g_ptrs = g_ptr + offs_bk\n g_idx = tl.load(g_ptrs)\n\n # shifter is used to extract the N bits of each element in the 32-bit word from B\n scales_ptrs = scales_ptr + offs_n[None, :] + g_idx[:, None] * stride_scales\n zeros_ptrs = zeros_ptr + (offs_n[None, :] // infearure_per_bits) + g_idx[:, None] * stride_zeros\n\n shifter = (offs_bk % infearure_per_bits) * bits\n zeros_shifter = (offs_n % infearure_per_bits) * bits\n accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_K), dtype=tl.float32)\n\n for k in range(0, num_pid_n):\n # Fetch scales and zeros; these are per-outfeature and thus reused in the inner loop\n scales = tl.load(scales_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)\n zeros = tl.load(zeros_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)\n\n zeros = (zeros >> zeros_shifter[None, :]) & maxq\n zeros = (zeros + 1)\n\n a = tl.load(a_ptrs, mask=a_mask, other=0.) # (BLOCK_SIZE_M, BLOCK_SIZE_N)\n b = tl.load(b_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated\n\n # Now we need to unpack b (which is N-bit values) into 32-bit values\n b = (b >> shifter[:, None]) & maxq # Extract the N-bit values\n b = (b - zeros) * scales # Scale and shift\n b = tl.trans(b)\n\n accumulator += tl.dot(a, b)\n a_ptrs += BLOCK_SIZE_N\n b_ptrs += BLOCK_SIZE_N\n scales_ptrs += BLOCK_SIZE_N\n zeros_ptrs += (BLOCK_SIZE_N // infearure_per_bits)\n\n c_ptrs = c_ptr + stride_cm * offs_am[:, None] + stride_cn * offs_bk[None, :]\n c_mask = (offs_am[:, None] < M) & (offs_bk[None, :] < K)\n tl.store(c_ptrs, accumulator, mask=c_mask)\n\n\ndef quant_matmul_248(input, qweight, scales, qzeros, g_idx, bits, maxq):\n with torch.cuda.device(input.device):\n output = torch.empty((input.shape[0], qweight.shape[1]), device=input.device, dtype=input.dtype)\n grid = lambda META: (\n triton.cdiv(input.shape[0], META['BLOCK_SIZE_M']) * triton.cdiv(qweight.shape[1], META['BLOCK_SIZE_N']),\n )\n quant_matmul_248_kernel[grid](\n input, qweight, output,\n scales.to(input.dtype), qzeros, g_idx,\n input.shape[0], qweight.shape[1], input.shape[1],\n bits, maxq,\n input.stride(0), input.stride(1),\n qweight.stride(0), qweight.stride(1),\n output.stride(0), output.stride(1),\n scales.stride(0), qzeros.stride(0)\n )\n return output\n\n\ndef transpose_quant_matmul_248(input, qweight, scales, qzeros, g_idx, bits, maxq):\n with torch.cuda.device(input.device):\n output_dim = (qweight.shape[0] * 32) // bits\n output = torch.empty((input.shape[0], output_dim), device=input.device, dtype=input.dtype)\n grid = lambda META: (\n triton.cdiv(input.shape[0], META['BLOCK_SIZE_M']) * triton.cdiv(output_dim, META['BLOCK_SIZE_K']),)\n transpose_quant_matmul_248_kernel[grid](\n input, qweight, output,\n scales.to(input.dtype), qzeros, g_idx,\n input.shape[0], qweight.shape[1], output_dim,\n bits, maxq,\n input.stride(0), input.stride(1),\n qweight.stride(0), qweight.stride(1),\n output.stride(0), output.stride(1),\n scales.stride(0), qzeros.stride(0)\n )\n return output\n", - "description_1": "Use triton language to implement two matrix multiplication kernels: quant_matmul_248_kernel and transpose_quant_matmul_248_kernel. The first kernel computes C = A x B where A is a float16 matrix of shape (M, K), B is an int32 matrix of shape (K//8, N), and C is a float16 matrix of shape (M, N). The second kernel computes C = A x B where A is a float16 matrix of shape (M, N), B is an int32 matrix of shape (K//8, N), and C is a float16 matrix of shape (M, K). Both kernels use quantization parameters scales and zeros, and a group index g_ptr. The kernels are called by quant_matmul_248 and transpose_quant_matmul_248 functions respectively, which handle the setup of the output tensor and grid configuration.", - "description_2": "Use triton language to create two kernels for quantized matrix multiplication with support for different input and output shapes, utilizing quantization parameters and group indexing.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef _rescale_kernel(\n peer_m,\n m,\n peer_l,\n l,\n peer_o,\n o,\n L,\n stride_oz, stride_oh, stride_om, stride_on,\n Z, H, N_CTX,\n BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,\n BLOCK_N: tl.constexpr,\n LAST_STEP: tl.constexpr,\n):\n start_m = tl.program_id(0)\n off_hz = tl.program_id(1)\n o_offset = off_hz * stride_oh\n peer_o_block_ptr = tl.make_block_ptr(\n base=peer_o + o_offset,\n shape=(N_CTX, BLOCK_DMODEL),\n strides=(stride_om, stride_on),\n offsets=(start_m * BLOCK_M, 0),\n block_shape=(BLOCK_M, BLOCK_DMODEL),\n order=(1, 0)\n )\n o_block_ptr = tl.make_block_ptr(\n base=o + o_offset,\n shape=(N_CTX, BLOCK_DMODEL),\n strides=(stride_om, stride_on),\n offsets=(start_m * BLOCK_M, 0),\n block_shape=(BLOCK_M, BLOCK_DMODEL),\n order=(1, 0)\n )\n offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)\n offs_n = tl.arange(0, BLOCK_N)\n peer_m_ptrs = peer_m + off_hz * N_CTX + offs_m\n m_ptrs = m + off_hz * N_CTX + offs_m\n peer_l_ptrs = peer_l + off_hz * N_CTX + offs_m\n l_ptrs = l + off_hz * N_CTX + offs_m\n peer_m_i = tl.load(peer_m_ptrs) \n peer_m_i = peer_m_i.to(tl.float32)\n m_i = tl.load(m_ptrs) \n m_i = m_i.to(tl.float32)\n peer_l_i = tl.load(peer_l_ptrs) \n peer_l_i = peer_l_i.to(tl.float32)\n l_i = tl.load(l_ptrs) \n l_i = l_i.to(tl.float32)\n peer_acc = tl.load(peer_o_block_ptr)\n peer_acc = peer_acc.to(tl.float32)\n acc = tl.load(o_block_ptr) \n acc = acc.to(tl.float32)\n lo = 0\n hi = N_CTX\n m_i_sync = tl.maximum(m_i, peer_m_i)\n alpha = tl.math.exp2(m_i - m_i_sync)\n peer_alpha = tl.math.exp2(peer_m_i - m_i_sync)\n acc_scale = l_i * 0 + alpha\n peer_acc_scale = peer_l_i * 0 + peer_alpha\n acc *= acc_scale[:, None]\n peer_acc *= peer_acc_scale[:, None]\n acc += peer_acc\n l_i = l_i * acc_scale + peer_l_i * peer_acc_scale\n tl.store(m_ptrs, m_i_sync)\n tl.store(l_ptrs, l_i)\n if LAST_STEP:\n acc = acc / l_i[:, None]\n L_ptrs = L + off_hz * N_CTX + offs_m\n tl.store(L_ptrs, m_i_sync / 1.44269504 + tl.math.log(l_i))\n tl.store(o_block_ptr, acc.to(tl.bfloat16))\n\n@triton.jit\ndef _fwd_kernel(\n Q, K, V, sm_scale,\n m,\n l,\n O,\n L,\n stride_qz, stride_qh, stride_qm, stride_qk,\n stride_kz, stride_kh, stride_kn, stride_kk,\n stride_vz, stride_vh, stride_vk, stride_vn,\n stride_oz, stride_oh, stride_om, stride_on,\n Z, H, N_CTX,\n BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,\n BLOCK_N: tl.constexpr,\n IS_CAUSAL: tl.constexpr,\n LAST_STEP: tl.constexpr\n):\n start_m = tl.program_id(0)\n off_hz = tl.program_id(1)\n qvk_offset = off_hz * stride_qh\n Q_block_ptr = tl.make_block_ptr(\n base=Q + qvk_offset,\n shape=(N_CTX, BLOCK_DMODEL),\n strides=(stride_qm, stride_qk),\n offsets=(start_m * BLOCK_M, 0),\n block_shape=(BLOCK_M, BLOCK_DMODEL),\n order=(1, 0)\n )\n K_block_ptr = tl.make_block_ptr(\n base=K + qvk_offset,\n shape=(BLOCK_DMODEL, N_CTX),\n strides=(stride_kk, stride_kn),\n offsets=(0, 0),\n block_shape=(BLOCK_DMODEL, BLOCK_N),\n order=(0, 1)\n )\n V_block_ptr = tl.make_block_ptr(\n base=V + qvk_offset,\n shape=(N_CTX, BLOCK_DMODEL),\n strides=(stride_vk, stride_vn),\n offsets=(0, 0),\n block_shape=(BLOCK_N, BLOCK_DMODEL),\n order=(1, 0)\n )\n O_block_ptr = tl.make_block_ptr(\n base=O + qvk_offset,\n shape=(N_CTX, BLOCK_DMODEL),\n strides=(stride_om, stride_on),\n offsets=(start_m * BLOCK_M, 0),\n block_shape=(BLOCK_M, BLOCK_DMODEL),\n order=(1, 0)\n )\n offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)\n offs_n = tl.arange(0, BLOCK_N)\n m_ptrs = m + off_hz * N_CTX + offs_m\n l_ptrs = l + off_hz * N_CTX + offs_m\n m_i = tl.load(m_ptrs) \n m_i = m_i.to(tl.float32)\n l_i = tl.load(l_ptrs) \n l_i = l_i.to(tl.float32)\n acc = tl.load(O_block_ptr) \n acc = acc.to(tl.float32)\n qk_scale = sm_scale * 1.44269504\n q = tl.load(Q_block_ptr)\n q = (q * qk_scale).to(tl.bfloat16)\n lo = 0\n hi = (start_m + 1) * BLOCK_M if IS_CAUSAL else N_CTX\n for start_n in range(lo, hi, BLOCK_N):\n k = tl.load(K_block_ptr)\n v = tl.load(V_block_ptr)\n qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)\n if IS_CAUSAL:\n qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float(\"-inf\"))\n qk += tl.dot(q, k)\n m_i_new = tl.maximum(m_i, tl.max(qk, 1))\n alpha = tl.math.exp2(m_i - m_i_new)\n p = tl.math.exp2(qk - m_i_new[:, None])\n acc_scale = l_i * 0 + alpha\n acc *= acc_scale[:, None]\n acc += tl.dot(p.to(tl.bfloat16), v)\n l_i = l_i * alpha + tl.sum(p, 1)\n m_i = m_i_new\n K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N))\n V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0))\n tl.store(m_ptrs, m_i)\n tl.store(l_ptrs, l_i)\n if LAST_STEP:\n acc = acc / l_i[:, None]\n L_ptrs = L + off_hz * N_CTX + offs_m\n tl.store(L_ptrs, m_i / 1.44269504 + tl.math.log(l_i))\n tl.store(O_block_ptr, acc.to(tl.bfloat16))\n\ndef _lightseq_forward(q, k, v, causal, sm_scale, comm_mode):\n Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]\n assert Lq == Lk and Lk == Lv\n assert Lk in {16, 32, 64, 128}\n BLOCK_M = 32\n BLOCK_N = 32\n bsz, nh, seq_len, hdim = q.shape\n m = torch.full((bsz * nh, seq_len), fill_value=-float(\"inf\"), device=q.device, dtype=torch.float32)\n l = torch.zeros_like(m)\n L = torch.zeros_like(m)\n o = torch.zeros_like(q)\n grid = (triton.cdiv(seq_len, BLOCK_M), bsz * nh, 1)\n num_warps = 4 if Lk <= 64 else 8\n seq_rank = get_sequence_parallel_rank()\n seq_world_size = get_sequence_parallel_size()\n peer_q, peer_k, peer_v, peer_m, peer_l, peer_o = maybe_get_set_global_memory_buffer(q, k, v, m, l, o)\n fwd_launch_helper = lambda q, k, v, m, l, o, L, IS_CAUSAL, LAST_STEP: _fwd_kernel[grid](\n q, k, v, sm_scale,\n m,\n l,\n o,\n L,\n q.stride(0), q.stride(1), q.stride(2), q.stride(3),\n k.stride(0), k.stride(1), k.stride(2), k.stride(3),\n v.stride(0), v.stride(1), v.stride(2), v.stride(3),\n o.stride(0), o.stride(1), o.stride(2), o.stride(3),\n q.shape[0], q.shape[1], q.shape[2],\n BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_DMODEL=Lk,\n IS_CAUSAL=IS_CAUSAL,\n LAST_STEP=LAST_STEP,\n num_warps=num_warps,\n num_stages=4)\n for time_step in range(seq_world_size // 2 + 1):\n torch.cuda.synchronize()\n buffer_idx_1 = time_step % 2\n buffer_idx_2 = (time_step - 1) % 2\n reqs = maybe_send_recv_fwd_qkvo(q, peer_q[buffer_idx_1], k, peer_k[buffer_idx_1], v, peer_v[buffer_idx_1], \n [peer_o[buffer_idx_1], peer_m[buffer_idx_1], peer_l[buffer_idx_1]], time_step, comm_mode)\n if comm_mode == \"sync\":\n wait_async_handles(reqs)\n if is_compute_for_local_query(time_step):\n if time_step == 0:\n fwd_launch_helper(q, maybe_repeat_kv_fwd(q.shape[1], k), maybe_repeat_kv_fwd(q.shape[1], v), m, l, o, L, True, is_last_time(time_step))\n else:\n fwd_launch_helper(q, maybe_repeat_kv_fwd(q.shape[1], peer_k[buffer_idx_2]), maybe_repeat_kv_fwd(q.shape[1], peer_v[buffer_idx_2]), m, l, o, L, False, not is_sync_from_remote(time_step) and is_last_time(time_step))\n elif is_idle(time_step):\n pass\n else:\n peer_m[buffer_idx_2] = torch.full_like(m, fill_value=-float(\"inf\"))\n peer_l[buffer_idx_2] = torch.zeros_like(l)\n peer_o[buffer_idx_2] = torch.zeros_like(o)\n fwd_launch_helper(peer_q[buffer_idx_2], maybe_repeat_kv_fwd(q.shape[1], k), maybe_repeat_kv_fwd(q.shape[1], v), peer_m[buffer_idx_2], peer_l[buffer_idx_2], peer_o[buffer_idx_2], None, False, False)\n if comm_mode == \"lightseq\":\n wait_async_handles(reqs)\n if is_sync_from_remote(time_step):\n _rescale_kernel[grid](\n peer_m[buffer_idx_1],\n m,\n peer_l[buffer_idx_1],\n l,\n peer_o[buffer_idx_1],\n o,\n L,\n o.stride(0), o.stride(1), o.stride(2), o.stride(3),\n o.shape[0], o.shape[1], o.shape[2],\n BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_DMODEL=Lk,\n LAST_STEP=is_last_time(time_step),\n num_warps=num_warps,\n num_stages=4)\n return q, k, v, o, L\n", - "description_1": "Use triton language to implement two kernels: _rescale_kernel and _fwd_kernel. The _rescale_kernel function takes 18 parameters: peer_m, m, peer_l, l, peer_o, o, L, stride_oz, stride_oh, stride_om, stride_on, Z, H, N_CTX, BLOCK_M, BLOCK_DMODEL, BLOCK_N, and LAST_STEP. The function initializes offsets, loads various tensors into memory, computes scaling and updates the accumulator, and writes back results to memory. The _fwd_kernel function has 26 parameters: Q, K, V, sm_scale, m, l, O, L, stride_qz, stride_qh, stride_qm, stride_qk, stride_kz, stride_kh, stride_kn, stride_kk, stride_vz, stride_vh, stride_vk, stride_vn, stride_oz, stride_oh, stride_om, stride_on, Z, H, N_CTX, BLOCK_M, BLOCK_DMODEL, BLOCK_N, IS_CAUSAL, and LAST_STEP. This function is responsible for initializing offsets, loading qkv blocks, computing qk and scaling constants, updating accumulators, and storing results. The _lightseq_forward function is a wrapper around these kernels and orchestrates the computation by dividing the task into a grid and handles synchronization using auxiliary functions.", - "description_2": "Use triton language to implement a function for rescaling operations with memory offsets, and another for forward kernel computations involving loading of blocks, accumulation, and results storing. Both require careful handling of memory and synchronization.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\nimport math\nfrom .async_communication import (\n is_last_time, is_compute_for_local_query, is_sync_from_remote, is_idle, \n maybe_send_recv_fwd_qkvo, maybe_get_set_global_memory_buffer,\n get_sequence_parallel_rank, get_sequence_parallel_size\n)\n\n@triton.jit\ndef _rescale_kernel(\n peer_m, m, peer_l, l, peer_o, o, L,\n stride_oz, stride_oh, stride_om, stride_on,\n Z, H, N_CTX, seqlen_q_rounded, seqlen_peer_q_rounded,\n BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, \n LAST_STEP: tl.constexpr,\n):\n start_m = tl.program_id(0)\n off_hz = tl.program_id(1)\n o_offset = off_hz * stride_oh\n peer_o_block_ptr = tl.make_block_ptr(\n base=peer_o + o_offset,\n shape=(N_CTX, BLOCK_DMODEL),\n strides=(stride_om, stride_on),\n offsets=(start_m * BLOCK_M, 0),\n block_shape=(BLOCK_M, BLOCK_DMODEL),\n order=(1, 0)\n )\n o_block_ptr = tl.make_block_ptr(\n base=o + o_offset,\n shape=(N_CTX, BLOCK_DMODEL),\n strides=(stride_om, stride_on),\n offsets=(start_m * BLOCK_M, 0),\n block_shape=(BLOCK_M, BLOCK_DMODEL),\n order=(1, 0)\n )\n offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)\n offs_n = tl.arange(0, BLOCK_N)\n\n peer_m_ptrs = peer_m + off_hz * seqlen_peer_q_rounded + offs_m\n m_ptrs = m + off_hz * seqlen_q_rounded + offs_m\n peer_l_ptrs = peer_l + off_hz * seqlen_peer_q_rounded + offs_m\n l_ptrs = l + off_hz * seqlen_q_rounded + offs_m\n \n peer_m_i = tl.load(peer_m_ptrs) \n peer_m_i = peer_m_i.to(tl.float32)\n m_i = tl.load(m_ptrs) \n m_i = m_i.to(tl.float32)\n peer_l_i = tl.load(peer_l_ptrs) \n peer_l_i = peer_l_i.to(tl.float32)\n l_i = tl.load(l_ptrs) \n l_i = l_i.to(tl.float32)\n\n peer_acc = tl.load(peer_o_block_ptr)\n peer_acc = peer_acc.to(tl.float32)\n acc = tl.load(o_block_ptr)\n acc = acc.to(tl.float32)\n lo = 0\n hi = N_CTX\n m_i_sync = tl.maximum(m_i, peer_m_i)\n alpha = tl.math.exp2(m_i - m_i_sync)\n peer_alpha = tl.math.exp2(peer_m_i - m_i_sync)\n acc_scale = l_i * 0 + alpha\n peer_acc_scale = peer_l_i * 0 + peer_alpha\n \n acc *= acc_scale[:, None]\n peer_acc *= peer_acc_scale[:, None]\n acc += peer_acc\n l_i = l_i * acc_scale + peer_l_i * peer_acc_scale\n tl.store(m_ptrs, m_i_sync)\n tl.store(l_ptrs, l_i)\n if LAST_STEP:\n acc = acc / l_i[:, None]\n L_ptrs = L + off_hz * N_CTX + offs_m\n tl.store(L_ptrs, m_i_sync / 1.44269504 + tl.math.log(l_i))\n tl.store(o_block_ptr, acc.to(tl.bfloat16), boundary_check=(0, 1))\n\n@triton.jit\ndef _fwd_kernel(\n Q, K, V, sm_scale, m, l, O, L,\n stride_qz, stride_qh, stride_qm, stride_qk,\n stride_kz, stride_kh, stride_kn, stride_kk,\n stride_vz, stride_vh, stride_vk, stride_vn,\n stride_oz, stride_oh, stride_om, stride_on,\n Z, H, N_CTX, seqlen_q_rounded,\n BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr,\n IS_CAUSAL: tl.constexpr, LAST_STEP: tl.constexpr\n):\n start_m = tl.program_id(0)\n off_hz = tl.program_id(1)\n qvk_offset = off_hz * stride_qh\n Q_block_ptr = tl.make_block_ptr(\n base=Q + qvk_offset,\n shape=(N_CTX, BLOCK_DMODEL),\n strides=(stride_qm, stride_qk),\n offsets=(start_m * BLOCK_M, 0),\n block_shape=(BLOCK_M, BLOCK_DMODEL),\n order=(1, 0)\n )\n K_block_ptr = tl.make_block_ptr(\n base=K + qvk_offset,\n shape=(BLOCK_DMODEL, N_CTX),\n strides=(stride_kk, stride_kn),\n offsets=(0, 0),\n block_shape=(BLOCK_DMODEL, BLOCK_N),\n order=(0, 1)\n )\n V_block_ptr = tl.make_block_ptr(\n base=V + qvk_offset,\n shape=(N_CTX, BLOCK_DMODEL),\n strides=(stride_vk, stride_vn),\n offsets=(0, 0),\n block_shape=(BLOCK_N, BLOCK_DMODEL),\n order=(1, 0)\n )\n O_block_ptr = tl.make_block_ptr(\n base=O + qvk_offset,\n shape=(N_CTX, BLOCK_DMODEL),\n strides=(stride_om, stride_on),\n offsets=(start_m * BLOCK_M, 0),\n block_shape=(BLOCK_M, BLOCK_DMODEL),\n order=(1, 0)\n )\n offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)\n offs_n = tl.arange(0, BLOCK_N)\n m_ptrs = m + off_hz * seqlen_q_rounded + offs_m\n l_ptrs = l + off_hz * seqlen_q_rounded + offs_m\n m_i = tl.load(m_ptrs) \n m_i = m_i.to(tl.float32)\n l_i = tl.load(l_ptrs) \n l_i = l_i.to(tl.float32)\n acc = tl.load(O_block_ptr) \n acc = acc.to(tl.float32)\n qk_scale = sm_scale * 1.44269504\n q = tl.load(Q_block_ptr, boundary_check=(0,), padding_option='zero')\n q = (q * qk_scale).to(tl.bfloat16)\n lo = 0\n hi = (start_m + 1) * BLOCK_M if IS_CAUSAL else N_CTX\n for start_n in range(lo, hi, BLOCK_N):\n k = tl.load(K_block_ptr, boundary_check=(1,), padding_option='zero')\n v = tl.load(V_block_ptr, boundary_check=(0,), padding_option='zero')\n qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)\n if IS_CAUSAL:\n qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float(\"-inf\"))\n qk += tl.dot(q, k)\n m_i_new = tl.maximum(m_i, tl.max(qk, 1))\n alpha = tl.math.exp2(m_i - m_i_new)\n p = tl.math.exp2(qk - m_i_new[:, None])\n acc_scale = l_i * 0 + alpha\n acc *= acc_scale[:, None]\n acc += tl.dot(p.to(tl.bfloat16), v)\n l_i = l_i * alpha + tl.sum(p, 1)\n m_i = m_i_new\n K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N))\n V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0))\n tl.store(m_ptrs, m_i)\n tl.store(l_ptrs, l_i)\n if LAST_STEP:\n acc = acc / l_i[:, None]\n L_ptrs = L + off_hz * seqlen_q_rounded + offs_m\n tl.store(L_ptrs, m_i / 1.44269504 + tl.math.log(l_i))\n tl.store(O_block_ptr, acc.to(tl.bfloat16), boundary_check=(0, 1))\n\ndef _lightseq_forward_varlen(q, k, v, causal, sm_scale, comm_mode):\n BLOCK_M = 128\n BLOCK_N = 64\n\n bsz, nh, unpadded_seq_len, hdim = q.shape\n cu_seq_lens = torch.arange(0, (bsz+1) * unpadded_seq_len, unpadded_seq_len, dtype=torch.int32, device=q.device)\n max_seqlen = unpadded_seq_len\n seqlen_q_rounded = math.ceil(q.shape[2] / BLOCK_M) * BLOCK_M\n\n m = torch.full((bsz * nh, seqlen_q_rounded), fill_value=-float(\"inf\"), device=q.device, dtype=torch.float32)\n l = torch.zeros((bsz * nh, seqlen_q_rounded), device=q.device, dtype=torch.float32)\n L = torch.zeros((bsz * nh, seqlen_q_rounded), device=q.device, dtype=torch.float32)\n o = torch.zeros_like(q)\n \n grid = (triton.cdiv(q.shape[2], BLOCK_M), bsz * nh, 1)\n num_warps = 4 if q.shape[-1] <= 64 else 8\n \n seq_rank = get_sequence_parallel_rank()\n seq_world_size = get_sequence_parallel_size()\n\n peer_q, peer_k, peer_v, peer_m, peer_l, peer_o = maybe_get_set_global_memory_buffer(q, k, v, m, l, o)\n \n fwd_launch_helper = lambda q, k, v, m, l, o, L, IS_CAUSAL, LAST_STEP: _fwd_kernel[grid](\n q, k, v, sm_scale,\n m, l, o, L,\n q.stride(0), q.stride(1), q.stride(2), q.stride(3),\n k.stride(0), k.stride(1), k.stride(2), k.stride(3),\n v.stride(0), v.stride(1), v.stride(2), v.stride(3),\n o.stride(0), o.stride(1), o.stride(2), o.stride(3),\n q.shape[0], q.shape[1], q.shape[2],\n seqlen_q_rounded,\n BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_DMODEL=q.shape[-1],\n IS_CAUSAL=IS_CAUSAL,\n LAST_STEP=LAST_STEP,\n num_warps=num_warps,\n num_stages=4\n )\n \n for time_step in range(seq_world_size // 2 + 1):\n torch.cuda.synchronize()\n buffer_idx_1 = time_step % 2\n buffer_idx_2 = (time_step - 1) % 2\n\n reqs = maybe_send_recv_fwd_qkvo(q, peer_q[buffer_idx_1], k, peer_k[buffer_idx_1], v, peer_v[buffer_idx_1], \n [peer_o[buffer_idx_1], peer_m[buffer_idx_1], peer_l[buffer_idx_1]], time_step, comm_mode)\n if comm_mode == \"sync\":\n wait_async_handles(reqs)\n if is_compute_for_local_query(time_step):\n if time_step == 0:\n fwd_launch_helper(q, k, v, m, l, o, L, True, is_last_time(time_step))\n else:\n fwd_launch_helper(q, peer_k[buffer_idx_2], peer_v[buffer_idx_2], m, l, o, L, False, not is_sync_from_remote(time_step) and is_last_time(time_step))\n elif is_idle(time_step):\n pass\n else:\n peer_m[buffer_idx_2] = torch.full_like(m, fill_value=-float(\"inf\"))\n peer_l[buffer_idx_2] = torch.zeros_like(l)\n peer_o[buffer_idx_2] = torch.zeros_like(o)\n fwd_launch_helper(peer_q[buffer_idx_2], k, v, peer_m[buffer_idx_2], peer_l[buffer_idx_2], peer_o[buffer_idx_2], None, False, False)\n\n if comm_mode == \"lightseq\":\n wait_async_handles(reqs)\n if is_sync_from_remote(time_step):\n seqlen_peer_q_rounded = peer_l[buffer_idx_1].shape[-1]\n _rescale_kernel[grid](\n peer_m[buffer_idx_1], m, peer_l[buffer_idx_1], l, peer_o[buffer_idx_1], o, L,\n o.stride(0), o.stride(1), o.stride(2), o.stride(3),\n o.shape[0], o.shape[1], o.shape[2],\n seqlen_q_rounded, seqlen_peer_q_rounded,\n BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_DMODEL=q.shape[-1],\n LAST_STEP=is_last_time(time_step),\n num_warps=num_warps,\n num_stages=4\n )\n return q, k, v, o, L, cu_seq_lens, max_seqlen\n\nclass _attention_varlen(torch.autograd.Function):\n @staticmethod\n def forward(ctx, q, k, v, causal, sm_scale):\n comm_mode = 'lightseq'\n q, k, v, o, L, cu_seq_lens, max_seqlen = _lightseq_forward_varlen(q, k, v, causal, sm_scale, comm_mode)\n\n ctx.save_for_backward(q, k, v, o, L, cu_seq_lens)\n ctx.max_seqlen = max_seqlen\n ctx.sm_scale = sm_scale\n ctx.comm_mode = comm_mode\n return o\n\ndist_attn_varlen = _attention_varlen.apply\n", - "description_1": "Use triton language to implement forward attention mechanism with optional causal mask and scaling in a distributed environment.", - "description_2": "Use triton language to implement kernel functions for scaling and computing forward pass in attention layers.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n\n@triton.jit\ndef flatten_kernel(\n OUT,\n LSE,\n CU_SEQLENS,\n stride_out_nheads,\n stride_out_seqlen,\n stride_lse_batch,\n stride_lse_nheads,\n stride_lse_seqlen,\n BLOCK_M: tl.constexpr,\n):\n pid_m = tl.program_id(axis=0)\n pid_batch = tl.program_id(axis=1)\n pid_head = tl.program_id(axis=2)\n\n start_idx = tl.load(CU_SEQLENS + pid_batch)\n seqlen = tl.load(CU_SEQLENS + pid_batch + 1) - start_idx\n LSE = LSE + pid_batch * stride_lse_batch + pid_head * stride_lse_nheads\n OUT = OUT + pid_head * stride_out_nheads + start_idx * stride_out_seqlen\n\n rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)\n\n LSE = LSE + rm[:, None] * stride_lse_seqlen\n x = tl.load(LSE, mask=rm[:, None] < seqlen, other=0.0)\n\n OUT = OUT + rm[:, None] * stride_out_seqlen\n tl.store(OUT, x, mask=rm[:, None] < seqlen)\n\n\ndef flatten_varlen_lse(lse, cu_seqlens):\n total_seqlen = cu_seqlens[-1]\n batch_size, nheads, max_seqlen = lse.shape\n output = torch.empty((nheads, total_seqlen), dtype=lse.dtype, device=lse.device)\n\n grid = lambda META: (triton.cdiv(max_seqlen, META[\"BLOCK_M\"]), batch_size, nheads)\n BLOCK_M = 4\n\n with torch.cuda.device(lse.device.index):\n flatten_kernel[grid](\n output,\n lse,\n cu_seqlens,\n output.stride(0),\n output.stride(1),\n lse.stride(0),\n lse.stride(1),\n lse.stride(2),\n BLOCK_M,\n )\n return output\n\n\n@triton.jit\ndef unflatten_kernel(\n OUT,\n LSE,\n CU_SEQLENS,\n stride_out_batch,\n stride_out_nheads,\n stride_out_seqlen,\n stride_lse_seqlen,\n stride_lse_nheads,\n BLOCK_M: tl.constexpr,\n):\n pid_m = tl.program_id(axis=0)\n pid_batch = tl.program_id(axis=1)\n pid_head = tl.program_id(axis=2)\n\n start_idx = tl.load(CU_SEQLENS + pid_batch)\n seqlen = tl.load(CU_SEQLENS + pid_batch + 1) - start_idx\n LSE = LSE + pid_head * stride_lse_nheads + start_idx * stride_lse_seqlen\n OUT = OUT + pid_batch * stride_out_batch + pid_head * stride_out_nheads\n\n rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)\n\n LSE = LSE + rm[:, None] * stride_lse_seqlen\n x = tl.load(LSE, mask=rm[:, None] < seqlen, other=0.0)\n\n OUT = OUT + rm[:, None] * stride_out_seqlen\n tl.store(OUT, x, mask=rm[:, None] < seqlen)\n\n\ndef unflatten_varlen_lse(lse, cu_seqlens, max_seqlen: int):\n lse = lse.unsqueeze(dim=-1)\n batch_size = len(cu_seqlens) - 1\n nheads = lse.shape[1]\n output = torch.empty(\n (batch_size, nheads, max_seqlen),\n dtype=lse.dtype,\n device=lse.device,\n )\n\n grid = lambda META: (triton.cdiv(max_seqlen, META[\"BLOCK_M\"]), batch_size, nheads)\n BLOCK_M = 4\n\n with torch.cuda.device(lse.device.index):\n unflatten_kernel[grid](\n output,\n lse,\n cu_seqlens,\n output.stride(0),\n output.stride(1),\n output.stride(2),\n lse.stride(0),\n lse.stride(1),\n BLOCK_M,\n )\n return output\n", - "description_1": "Use triton language to create two kernels and their wrapper functions. The first kernel 'flatten_kernel' has 8 parameters, which processes a 2D matrix to create a flattened sequence using stride and block parameters for batch and head indices. The 'flatten_varlen_lse' function wraps this kernel and prepares the inputs. The second kernel 'unflatten_kernel' also has 8 parameters, which converts the flattened sequence back to a 3D matrix format using similar parameters, and 'unflatten_varlen_lse' wraps this kernel for input preparation.", - "description_2": "Use triton language to create a kernel that flattens a variable length sequence by adjusting memory strides and block sizes. Another kernel should then unflatten the sequence to its original multi-dimensional structure.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\nif triton.__version__ >= \"2.1.0\":\n\n @triton.jit\n def _fwd_kernel(\n Q,\n K,\n V,\n K_cache,\n V_cache,\n B_Loc,\n sm_scale,\n B_Start_Loc,\n B_Seqlen,\n B_Ctxlen,\n block_size,\n x,\n Out,\n stride_b_loc_b,\n stride_b_loc_s,\n stride_qbs,\n stride_qh,\n stride_qd,\n stride_kbs,\n stride_kh,\n stride_kd,\n stride_vbs,\n stride_vh,\n stride_vd,\n stride_obs,\n stride_oh,\n stride_od,\n stride_k_cache_bs,\n stride_k_cache_h,\n stride_k_cache_d,\n stride_k_cache_bl,\n stride_k_cache_x,\n stride_v_cache_bs,\n stride_v_cache_h,\n stride_v_cache_d,\n stride_v_cache_bl,\n num_queries_per_kv: int,\n BLOCK_M: tl.constexpr,\n BLOCK_DMODEL: tl.constexpr,\n BLOCK_N: tl.constexpr,\n ):\n # Kernel implementation\n pass\n\n @triton.jit\n def _fwd_kernel_alibi(\n Q,\n K,\n V,\n K_cache,\n V_cache,\n B_Loc,\n sm_scale,\n B_Start_Loc,\n B_Seqlen,\n B_Ctxlen,\n Alibi_slopes,\n block_size,\n x,\n Out,\n stride_b_loc_b,\n stride_b_loc_s,\n stride_qbs,\n stride_qh,\n stride_qd,\n stride_kbs,\n stride_kh,\n stride_kd,\n stride_vbs,\n stride_vh,\n stride_vd,\n stride_obs,\n stride_oh,\n stride_od,\n stride_k_cache_bs,\n stride_k_cache_h,\n stride_k_cache_d,\n stride_k_cache_bl,\n stride_k_cache_x,\n stride_v_cache_bs,\n stride_v_cache_h,\n stride_v_cache_d,\n stride_v_cache_bl,\n num_queries_per_kv: int,\n BLOCK_M: tl.constexpr,\n BLOCK_DMODEL: tl.constexpr,\n BLOCK_N: tl.constexpr,\n ):\n # Kernel implementation\n pass\n\n @torch.inference_mode()\n def context_attention_fwd(q,\n k,\n v,\n o,\n k_cache,\n v_cache,\n b_loc,\n b_start_loc,\n b_seq_len,\n b_ctx_len,\n max_input_len,\n alibi_slopes=None):\n\n cap = torch.cuda.get_device_capability()\n BLOCK = 128 if cap[0] >= 8 else 64\n # shape constraints\n Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]\n assert Lq == Lk and Lk == Lv\n assert Lk in {16, 32, 64, 128}\n\n sm_scale = 1.0 / (Lq**0.5)\n batch, head = b_seq_len.shape[0], q.shape[1]\n num_queries_per_kv = q.shape[1] // k.shape[1]\n\n grid = (batch, head, triton.cdiv(max_input_len, BLOCK)) # batch, head,\n\n num_warps = 8 if Lk <= 64 else 8\n if alibi_slopes is not None:\n _fwd_kernel_alibi[grid](\n q,\n k,\n v,\n k_cache,\n v_cache,\n b_loc,\n sm_scale,\n b_start_loc,\n b_seq_len,\n b_ctx_len,\n alibi_slopes,\n v_cache.shape[3],\n 8,\n o,\n b_loc.stride(0),\n b_loc.stride(1),\n q.stride(0),\n q.stride(1),\n q.stride(2),\n k.stride(0),\n k.stride(1),\n k.stride(2),\n v.stride(0),\n v.stride(1),\n v.stride(2),\n o.stride(0),\n o.stride(1),\n o.stride(2),\n k_cache.stride(0),\n k_cache.stride(1),\n k_cache.stride(2),\n k_cache.stride(3),\n k_cache.stride(\n 4\n ), #[num_blocks, num_kv_heads, head_size/x, block_size, x]\n v_cache.stride(0),\n v_cache.stride(1),\n v_cache.stride(2),\n v_cache.stride(\n 3), #[num_blocks, num_kv_heads, head_size, block_size]\n num_queries_per_kv=num_queries_per_kv,\n BLOCK_M=BLOCK,\n BLOCK_DMODEL=Lk,\n BLOCK_N=BLOCK,\n num_warps=num_warps,\n num_stages=1,\n )\n return\n\n _fwd_kernel[grid](\n q,\n k,\n v,\n k_cache,\n v_cache,\n b_loc,\n sm_scale,\n b_start_loc,\n b_seq_len,\n b_ctx_len,\n v_cache.shape[3],\n 8,\n o,\n b_loc.stride(0),\n b_loc.stride(1),\n q.stride(0),\n q.stride(1),\n q.stride(2),\n k.stride(0),\n k.stride(1),\n k.stride(2),\n v.stride(0),\n v.stride(1),\n v.stride(2),\n o.stride(0),\n o.stride(1),\n o.stride(2),\n k_cache.stride(0),\n k_cache.stride(1),\n k_cache.stride(2),\n k_cache.stride(3),\n k_cache.stride(\n 4), #[num_blocks, num_kv_heads, head_size/x, block_size, x]\n v_cache.stride(0),\n v_cache.stride(1),\n v_cache.stride(2),\n v_cache.stride(\n 3), #[num_blocks, num_kv_heads, head_size, block_size]\n num_queries_per_kv=num_queries_per_kv,\n BLOCK_M=BLOCK,\n BLOCK_DMODEL=Lk,\n BLOCK_N=BLOCK,\n num_warps=num_warps,\n num_stages=1,\n )\n return\n", - "description_1": "Use triton language to implement forward attention kernels and their invocations. The kernels '_fwd_kernel', '_fwd_kernel_alibi' take 38 and 39 parameters respectively including inputs for queries, keys, values, caches, masks, strides, and constants. The function 'context_attention_fwd' orchestrates these kernels with 11 inputs including tensors for q, k, v, outputs, caches, and additional configurations like context lengths and alibi slopes.", - "description_2": "Use triton language to create and invoke forward attention kernels for transformer models, managing inputs, caches, and configuration parameters efficiently.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef _uniform_to_exponential_kernel(input, output, n: tl.constexpr):\n idx = tl.arange(0, n)\n x = tl.load(input + idx)\n y = _uniform_to_exponential(x)\n tl.store(output + idx, y)\n\ndef test_uniform_to_exponential():\n \"\"\"Test that we can convert uniform to exponential without div by 0.\"\"\"\n input = torch.tensor([0.0, 1.0 - torch.finfo(torch.float32).eps],\n dtype=torch.float32,\n device=\"cuda\")\n output = torch.zeros(input.shape, dtype=torch.float32, device=\"cuda\")\n _uniform_to_exponential_kernel[(1, )](input, output, 2)\n assert torch.all(torch.isfinite(output))\n assert torch.all(output > 0)\n assert torch.all(torch.isfinite(torch.full_like(output, 1.0) / output))\n", - "description_1": "Use triton language to implement a kernel that converts uniform random numbers to exponential random numbers. The kernel function '_uniform_to_exponential_kernel' takes three parameters: 'input' (a pointer to the input tensor), 'output' (a pointer to the output tensor), and 'n' (a compile-time constant representing the number of elements to process). The function uses Triton's parallel programming model to load elements from the input tensor, apply the '_uniform_to_exponential' transformation, and store the results in the output tensor. The 'test_uniform_to_exponential' function tests this kernel by creating a tensor of uniform random numbers, invoking the kernel, and verifying that the output tensor contains valid exponential random numbers.", - "description_2": "Use triton language to create a kernel that transforms uniform random numbers into exponential random numbers, and test its correctness by ensuring the output is finite and positive.", - "difficulty": 2 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef fused_moe_kernel(\n a_ptr,\n b_ptr,\n c_ptr,\n topk_weights_ptr,\n sorted_token_ids_ptr,\n expert_ids_ptr,\n num_tokens_post_padded_ptr,\n N,\n K,\n EM,\n num_valid_tokens,\n stride_am,\n stride_ak,\n stride_be,\n stride_bk,\n stride_bn,\n stride_cm,\n stride_cn,\n BLOCK_SIZE_M: tl.constexpr,\n BLOCK_SIZE_N: tl.constexpr,\n BLOCK_SIZE_K: tl.constexpr,\n GROUP_SIZE_M: tl.constexpr,\n MUL_ROUTED_WEIGHT: tl.constexpr,\n top_k: tl.constexpr,\n compute_type: tl.constexpr,\n):\n pid = tl.program_id(axis=0)\n num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M)\n num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)\n num_pid_in_group = GROUP_SIZE_M * num_pid_n\n group_id = pid // num_pid_in_group\n first_pid_m = group_id * GROUP_SIZE_M\n group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)\n pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)\n pid_n = (pid % num_pid_in_group) // group_size_m\n\n num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr)\n if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded:\n return\n offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_token = tl.load(sorted_token_ids_ptr + offs_token_id)\n token_mask = offs_token < num_valid_tokens\n\n offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N\n offs_k = tl.arange(0, BLOCK_SIZE_K)\n a_ptrs = a_ptr + (offs_token[:, None] // top_k * stride_am +\n offs_k[None, :] * stride_ak)\n\n off_experts = tl.load(expert_ids_ptr + pid_m)\n b_ptrs = b_ptr + off_experts * stride_be + (offs_k[:, None] * stride_bk +\n offs_bn[None, :] * stride_bn)\n\n accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n\n for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):\n a = tl.load(a_ptrs,\n mask=token_mask[:, None] &\n (offs_k[None, :] < K - k * BLOCK_SIZE_K),\n other=0.0)\n b = tl.load(b_ptrs,\n mask=offs_k[:, None] < K - k * BLOCK_SIZE_K,\n other=0.0)\n accumulator += tl.dot(a, b)\n a_ptrs += BLOCK_SIZE_K * stride_ak\n b_ptrs += BLOCK_SIZE_K * stride_bk\n\n if MUL_ROUTED_WEIGHT:\n moe_weight = tl.load(topk_weights_ptr + offs_token,\n mask=token_mask,\n other=0)\n accumulator = accumulator * moe_weight[:, None]\n\n accumulator = accumulator.to(compute_type)\n offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[\n None, :]\n c_mask = token_mask[:, None] & (offs_cn[None, :] < N)\n tl.store(c_ptrs, accumulator, mask=c_mask)\n\n\ndef invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,\n topk_weights: torch.Tensor, topk_ids: torch.Tensor,\n sorted_token_ids: torch.Tensor,\n expert_ids: torch.Tensor,\n num_tokens_post_padded: torch.Tensor,\n mul_routed_weight: bool, top_k: int,\n config: Dict[str, Any]) -> None:\n assert topk_weights.stride(1) == 1\n assert sorted_token_ids.stride(0) == 1\n\n grid = lambda META: (triton.cdiv(sorted_token_ids.shape[0], META[\n 'BLOCK_SIZE_M']) * triton.cdiv(B.shape[1], META['BLOCK_SIZE_N']), )\n\n fused_moe_kernel[grid](\n A,\n B,\n C,\n topk_weights,\n sorted_token_ids,\n expert_ids,\n num_tokens_post_padded,\n B.shape[1],\n B.shape[2],\n sorted_token_ids.shape[0],\n topk_ids.numel(),\n A.stride(0),\n A.stride(1),\n B.stride(0),\n B.stride(2),\n B.stride(1),\n C.stride(1),\n C.stride(2),\n MUL_ROUTED_WEIGHT=mul_routed_weight,\n top_k=top_k,\n compute_type=tl.bfloat16 if A.dtype == torch.bfloat16 else tl.float16,\n **config,\n )\n", - "description_1": "Use triton language to implement a fused Mixture of Experts (MoE) kernel. The kernel function 'fused_moe_kernel' takes 23 parameters: pointers to input matrices, matrix dimensions, stride variables, and meta-parameters for block sizes and computation type. It performs block matrix multiplication using token and expert matrices, with optional weighting. The 'invoke_fused_moe_kernel' function calls this kernel with 11 parameters: input tensors, configuration settings, and meta-parameters, setting up the grid for execution.", - "description_2": "Use triton language to create a kernel for block matrix multiplication in a Mixture of Experts model, and a function to invoke this kernel with specific configurations.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\ndef seeded_uniform(\n *size,\n seeds: torch.Tensor,\n out: Optional[torch.Tensor] = None,\n dtype: Optional[torch.dtype] = None,\n device: Optional[Union[torch.device, str]] = None,\n pin_memory: Optional[bool] = False,\n) -> torch.Tensor:\n n_dims = len(size)\n\n if n_dims > 3:\n raise ValueError(\"seeded_uniform only supports up to 3D tensors\")\n\n if out is None:\n out = torch.empty(*size,\n dtype=dtype,\n device=device,\n pin_memory=pin_memory)\n elif out.shape != size:\n raise ValueError(\"shape of out and size must be the same\")\n\n if n_dims == 3:\n n_rows, n_3d, n_cols = out.shape\n stride_row = out.stride(0)\n stride_3d = out.stride(1)\n elif n_dims == 2:\n n_rows, n_cols = out.shape\n n_3d = 1\n stride_row = out.stride(0)\n stride_3d = 1\n else:\n n_cols = out.shape[0]\n n_rows = 1\n n_3d = 1\n stride_row = 1\n stride_3d = 1\n\n if seeds.ndim != 1:\n raise ValueError(\"seeds must be a 1D tensor\")\n\n if seeds.numel() != n_rows:\n raise ValueError(\n \"seeds must have the same number of elements as out has rows\")\n\n full_block_size = triton.next_power_of_2(n_cols)\n philox_block_size = max(full_block_size // 4, 1)\n n_slices = full_block_size // philox_block_size\n num_warps = 4\n if philox_block_size >= 8192:\n num_warps = 32\n elif philox_block_size >= 4096:\n num_warps = 16\n elif philox_block_size >= 2048:\n num_warps = 8\n\n _seeded_uniform_triton[(n_rows, n_3d)](\n out,\n seeds,\n stride_row,\n stride_3d,\n seeds.stride(0),\n n_rows,\n n_3d,\n n_cols,\n n_slices=n_slices,\n num_warps=num_warps,\n block_size=philox_block_size,\n )\n return out\n\n\n@triton.jit\ndef _seeded_uniform_triton(\n out_ptr: torch.Tensor,\n seed_ptr: torch.Tensor,\n out_row_stride: int,\n out_3d_stride: int,\n seed_row_stride: int,\n n_rows: int,\n n_3d: int,\n n_cols: int,\n n_slices: tl.constexpr,\n block_size: tl.constexpr,\n):\n \"\"\"\n Generate a random float32 number in [0, 1) for each element in the output\n tensor. The random numbers in a row generated using the seed for that row.\n\n Args:\n out_ptr: The output tensor.\n seed_ptr: The per-row seeds to use for random number generation.\n out_row_stride: The stride between rows of the output tensor.\n out_3d_stride: The stride between 3D slices of the output tensor.\n seed_row_stride: The stride between rows of the seed tensor.\n n_rows: The number of rows in the output tensor.\n n_3d: The size of second dimension of the output tensor,\n if output tensor is 3D.\n n_cols: The number of columns in the output tensor.\n n_slices: The number of philox outputs to use.\n \"\"\"\n tl.static_assert(n_slices > 0 and n_slices <= 4, \"0 < n_slices <= 4\")\n\n # Get the row index.\n row_idx = tl.program_id(axis=0)\n three_d_idx = tl.program_id(axis=1)\n\n philox_offsets = tl.arange(0, block_size)\n # Get the seed for the current element.\n seed = tl.load(seed_ptr + row_idx * seed_row_stride)\n if three_d_idx > 0:\n seed ^= three_d_idx\n # Generate random numbers in [0, 1).\n out1, out2, out3, out4 = tl.rand4x(seed, philox_offsets)\n\n output_row_start_ptr = (out_ptr + row_idx * out_row_stride +\n three_d_idx * out_3d_stride)\n out1_offsets = philox_offsets\n tl.store(output_row_start_ptr + out1_offsets,\n out1,\n mask=out1_offsets < n_cols)\n if n_slices > 1:\n out2_offsets = tl.arange(block_size, block_size * 2)\n tl.store(output_row_start_ptr + out2_offsets,\n out2,\n mask=out2_offsets < n_cols)\n if n_slices > 2:\n out3_offsets = tl.arange(block_size * 2, block_size * 3)\n tl.store(output_row_start_ptr + out3_offsets,\n out3,\n mask=out3_offsets < n_cols)\n if n_slices > 3:\n out4_offsets = tl.arange(block_size * 3, block_size * 4)\n tl.store(output_row_start_ptr + out4_offsets,\n out4,\n mask=out4_offsets < n_cols)\n", - "description_1": "Use triton language to implement a seeded uniform random number generator. The main kernel '_seeded_uniform_triton' accepts 9 parameters: 'out_ptr' (output tensor), 'seed_ptr' (seed tensor), 'out_row_stride' (stride between output rows), 'out_3d_stride' (stride between 3D slices), 'seed_row_stride' (stride between seed rows), 'n_rows' (number of rows), 'n_3d' (3D dimension size), 'n_cols' (number of columns), 'n_slices' (number of philox outputs), and 'block_size' (block size for philox). The random numbers for each row are generated based on the seed for that row and stored in the output tensor. The 'seeded_uniform' function prepares the parameters and invokes the kernel. It calculates the number of rows, columns, and slices based on the tensor dimensions, and sets the block size and warps for optimal performance. The function finally calls the Triton kernel with the prepared arguments.", - "description_2": "Use triton language to implement a random number generator kernel '_seeded_uniform_triton' with 9 parameters to produce float32 numbers for a specified tensor size. Use a wrapper function 'seeded_uniform' to set up tensor dimensions, strides, and kernel parameters, and then launch the Triton kernel.", - "difficulty": 3 - }, - { - "code": "import triton\nimport triton.language as tl\nimport torch\n\n@triton.jit\ndef _uniform_to_exponential(uniform_noise):\n \"\"\"Convert uniform samples to exponential samples.\"\"\"\n lb = tl.full(uniform_noise.shape, _EPS, uniform_noise.dtype)\n uniform_noise = tl.maximum(uniform_noise, lb)\n exponential_noise = -tl.log(uniform_noise)\n return exponential_noise\n\n@triton.jit\ndef _sample_triton(\n sample_indices_ptr: torch.Tensor, output_ptr: torch.Tensor,\n output_logprobs_ptr: torch.Tensor,\n output_modified_probs_ptr: torch.Tensor, probs_ptr: torch.Tensor,\n logprobs_ptr: torch.Tensor, seeds_ptr: torch.Tensor,\n uniform_noise_ptr: torch.Tensor, output_row_stride: int,\n probs_row_stride: int, uniform_noise_row_stride: int,\n uniform_noise_best_stride: int, n_samples: int, n_cols: int,\n n_best: int, block_size: tl.constexpr,\n modify_greedy_probs: tl.constexpr, save_logprobs: tl.constexpr,\n save_modified_probs: tl.constexpr):\n # The rows are independent, so we parallelize across those\n sample_idx = tl.program_id(0)\n best_idx = tl.program_id(1)\n\n # Load the row index from DRAM\n row_idx = tl.load(sample_indices_ptr + sample_idx)\n seed = tl.load(seeds_ptr + sample_idx)\n uses_random_sampling = seed != 0\n\n # The stride represents how much we need to increase the\n # pointer to advance 1 row\n row_start_ptr = probs_ptr + row_idx * probs_row_stride\n\n # The block size is the next power of two greater than n_cols,\n # so we can fit each row in a single block\n col_offsets = tl.arange(0, block_size)\n\n # Load the row into SRAM, using a mask since block_size may be > than n_cols\n row = tl.load(row_start_ptr + col_offsets,\n mask=col_offsets < n_cols,\n other=float(\"-inf\"))\n\n if uses_random_sampling:\n uniform_noise_start_ptr = (uniform_noise_ptr +\n sample_idx * uniform_noise_row_stride +\n best_idx * uniform_noise_best_stride)\n uniform_noise = tl.load(uniform_noise_start_ptr + col_offsets,\n mask=col_offsets < n_cols,\n other=0.5)\n exponential_noise = _uniform_to_exponential(uniform_noise)\n row /= exponential_noise\n\n sampled_value, sampled_token = tl.max(row, axis=0, return_indices=True)\n # clamp sampled token to n_cols - 1\n if sampled_token >= n_cols:\n sampled_token = n_cols - 1\n # Write back output to DRAM\n output_row_start_ptr = (output_ptr + sample_idx * output_row_stride +\n best_idx)\n tl.store(output_row_start_ptr, sampled_token)\n\n if modify_greedy_probs:\n if not uses_random_sampling:\n row = tl.where(col_offsets == sampled_token, 1.0, 0.0)\n tl.store(row_start_ptr + col_offsets,\n row,\n mask=col_offsets < n_cols)\n\n if save_modified_probs:\n output_row_start_ptr = (output_modified_probs_ptr +\n sample_idx * output_row_stride + best_idx)\n tl.store(output_row_start_ptr, sampled_value)\n\n if save_logprobs:\n sampled_logprob = tl.load(logprobs_ptr + row_idx * probs_row_stride +\n sampled_token)\n output_row_start_ptr = (output_logprobs_ptr +\n sample_idx * output_row_stride + best_idx)\n tl.store(output_row_start_ptr, sampled_logprob)\n", - "description_1": "Use triton language to implement two kernels: _uniform_to_exponential and _sample_triton. The first kernel converts uniform noise into exponential noise using an inversion method to avoid log(0) error. It takes one parameter: uniform_noise. The second kernel, _sample_triton, samples tokens from a distribution with optional logprob and noise saving. It takes 18 parameters: sample_indices_ptr (the indices for sampling), output_ptr (tensor for output samples), output_logprobs_ptr (tensor for log probabilities of samples), output_modified_probs_ptr (tensor for modified probabilities), probs_ptr (probability distribution), logprobs_ptr (log probabilities of distribution), seeds_ptr (seed for random sampling), uniform_noise_ptr (uniform noise for sampling), output_row_stride, probs_row_stride, uniform_noise_row_stride, uniform_noise_best_stride, n_samples (number of samples), n_cols (number of columns), n_best (number of best samples), block_size, modify_greedy_probs (flag to modify greedy probabilities), save_logprobs (flag to save log probabilities), save_modified_probs (flag to save modified probabilities).", - "description_2": "Use triton language to create a kernel that converts uniform distribution to exponential and another kernel that performs token sampling from a probability matrix with options to save log and modified probabilities.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\nfrom flash import (\n bare_attn_fwd,\n bare_attn_bwd,\n)\n\nTRITON_CONFIG_LIST_FWD = [\n triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 0, 'pre_load_v': True}, num_stages=1, num_warps=4),\n triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 1, 'pre_load_v': True}, num_stages=1, num_warps=4),\n triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 2, 'pre_load_v': True}, num_stages=1, num_warps=4),\n triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 3, 'pre_load_v': True}, num_stages=1, num_warps=4),\n triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 4, 'pre_load_v': True}, num_stages=1, num_warps=4),\n triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 0, 'pre_load_v': False}, num_stages=1, num_warps=4),\n triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 1, 'pre_load_v': False}, num_stages=1, num_warps=4),\n triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 2, 'pre_load_v': False}, num_stages=1, num_warps=4),\n triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 3, 'pre_load_v': False}, num_stages=1, num_warps=4),\n triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 4, 'pre_load_v': False}, num_stages=1, num_warps=4),\n]\n\n@triton.autotune(\n configs=TRITON_CONFIG_LIST_FWD,\n key=['max_seqlen_q', 'max_seqlen_k', 'CAUSAL'],\n)\n@triton.jit\ndef tuned_attn_fwd(\n Q, K, V, B, sm_scale, M, Out,\n stride_qz, stride_qh, stride_qm, stride_qk,\n stride_kz, stride_kh, stride_kn, stride_kk,\n stride_vz, stride_vh, stride_vk, stride_vn,\n stride_bz, stride_bh, stride_bm, stride_bn,\n stride_oz, stride_oh, stride_om, stride_on,\n num_head_q,\n num_head_k,\n cu_seqlens_q,\n cu_seqlens_k,\n num_seqlens,\n max_seqlen_q,\n max_seqlen_k,\n head_dim,\n dropout_p,\n philox_seed_ptr,\n philox_offset1,\n philox_offset2,\n philox_seed_output,\n philox_offset_output,\n encoded_softmax,\n CAUSAL: tl.constexpr,\n BLOCK_M: tl.constexpr,\n BLOCK_DMODEL: tl.constexpr,\n BLOCK_N: tl.constexpr,\n pre_load_v: tl.constexpr,\n ENABLE_DROPOUT: tl.constexpr,\n RETURN_ENCODED_SOFTMAX: tl.constexpr,\n PADDED_HEAD: tl.constexpr,\n BIAS_TYPE: tl.constexpr,\n):\n bare_attn_fwd(\n Q, K, V, B, sm_scale, M, Out,\n stride_qz, stride_qh, stride_qm, stride_qk,\n stride_kz, stride_kh, stride_kn, stride_kk,\n stride_vz, stride_vh, stride_vk, stride_vn,\n stride_bz, stride_bh, stride_bm, stride_bn,\n stride_oz, stride_oh, stride_om, stride_on,\n num_head_q,\n num_head_k,\n cu_seqlens_q,\n cu_seqlens_k,\n num_seqlens,\n max_seqlen_q,\n max_seqlen_k,\n head_dim,\n dropout_p,\n philox_seed_ptr,\n philox_offset1,\n philox_offset2,\n philox_seed_output,\n philox_offset_output,\n encoded_softmax,\n CAUSAL,\n BLOCK_M,\n BLOCK_DMODEL,\n BLOCK_N,\n pre_load_v,\n ENABLE_DROPOUT,\n RETURN_ENCODED_SOFTMAX,\n PADDED_HEAD,\n BIAS_TYPE=BIAS_TYPE,\n )\n\nTRITON_CONFIG_LIST_BWD_FUSED = []\nfor BLOCK_M1 in [16, 32, 64]:\n for BLOCK_N1 in [16, 32, 64, 128, 256]:\n if BLOCK_N1 % BLOCK_M1 != 0:\n continue\n for BLOCK_M2 in [16, 32]:\n for BLOCK_N2 in [16, 32]:\n if BLOCK_M2 % BLOCK_N2 != 0:\n continue\n dic = {'BLOCK_M1': BLOCK_M1, 'BLOCK_N1': BLOCK_N1}\n dic['BLOCK_M2'] = BLOCK_M2\n dic['BLOCK_N2'] = BLOCK_N2\n dic['BLK_SLICE_FACTOR'] = 2\n for waves_per_eu in range(0, 4+1):\n dic['waves_per_eu'] = waves_per_eu\n for num_stages in [0, 1]:\n for num_warps in [1,2,4,8]:\n cfg = triton.Config(dic, num_stages=num_stages, num_warps=num_warps)\n TRITON_CONFIG_LIST_BWD_FUSED.append(cfg)\n\n@triton.autotune(\n configs=TRITON_CONFIG_LIST_BWD_FUSED,\n key=['max_seqlen_q', 'max_seqlen_k', 'head_dim'],\n)\n@triton.jit\ndef tuned_attn_bwd(\n Q, K, V, B, sm_scale, Out, DO,\n DK, DV, DQ, DB,\n L, D,\n stride_qz, stride_qh, stride_qm, stride_qk,\n stride_kz, stride_kh, stride_kn, stride_kk,\n stride_vz, stride_vh, stride_vk, stride_vn,\n stride_bz, stride_bh, stride_bm, stride_bn,\n stride_oz, stride_oh, stride_om, stride_ok,\n stride_dkz, stride_dkh, stride_dkn, stride_dkk,\n stride_dvz, stride_dvh, stride_dvk, stride_dvn,\n stride_dqz, stride_dqh, stride_dqm, stride_dqk,\n stride_dbz, stride_dbh, stride_dbm, stride_dbn,\n num_head_q,\n num_head_k,\n cu_seqlens_q,\n cu_seqlens_k,\n num_seqlens,\n max_seqlen_q,\n max_seqlen_k,\n head_dim,\n dropout_p,\n philox_seed_ptr,\n philox_offset1,\n philox_offset2,\n BLOCK_DMODEL: tl.constexpr,\n CAUSAL: tl.constexpr,\n ENABLE_DROPOUT: tl.constexpr,\n PADDED_HEAD: tl.constexpr,\n BIAS_TYPE: tl.constexpr,\n BLOCK_M1: tl.constexpr,\n BLOCK_N1: tl.constexpr,\n BLOCK_M2: tl.constexpr,\n BLOCK_N2: tl.constexpr,\n BLK_SLICE_FACTOR: tl.constexpr,\n):\n bare_attn_bwd(\n Q, K, V, B, sm_scale, Out, DO,\n DK, DV, DQ, DB,\n L, D,\n stride_qz, stride_qh, stride_qm, stride_qk,\n stride_kz, stride_kh, stride_kn, stride_kk,\n stride_vz, stride_vh, stride_vk, stride_vn,\n stride_bz, stride_bh, stride_bm, stride_bn,\n stride_oz, stride_oh, stride_om, stride_ok,\n stride_dkz, stride_dkh, stride_dkn, stride_dkk,\n stride_dvz, stride_dvh, stride_dvk, stride_dvn,\n stride_dqz, stride_dqh, stride_dqm, stride_dqk,\n stride_dbz, stride_dbh, stride_dbm, stride_dbn,\n num_head_q,\n num_head_k,\n cu_seqlens_q,\n cu_seqlens_k,\n num_seqlens,\n max_seqlen_q,\n max_seqlen_k,\n head_dim,\n dropout_p,\n philox_seed_ptr,\n philox_offset1,\n BLOCK_DMODEL,\n CAUSAL,\n ENABLE_DROPOUT,\n PADDED_HEAD,\n BIAS_TYPE,\n BLOCK_M1,\n BLOCK_N1,\n BLOCK_M2,\n BLOCK_N2,\n BLK_SLICE_FACTOR,\n )\n", - "description_1": "Use triton language to implement two kernels: tuned_attn_fwd and tuned_attn_bwd. The tuned_attn_fwd kernel takes 39 parameters including input tensors Q, K, V, B, and output tensor Out, along with various strides, dimensions, and configuration constants. It performs forward attention computation using the bare_attn_fwd function. The tuned_attn_bwd kernel takes 50 parameters including input tensors Q, K, V, B, and output tensors DK, DV, DQ, DB, along with various strides, dimensions, and configuration constants. It performs backward attention computation using the bare_attn_bwd function.", - "description_2": "Use triton language to create forward and backward attention kernels with autotuning capabilities, leveraging triton's configuration system to optimize performance for different input sizes and configurations.", - "difficulty": 4 - }, - { - "code": "import triton\nimport triton.language as tl\nfrom masked_load_store import mload2d, mstore2d\nfrom bwd_inner_dkdv import bwd_kernel_dk_dv\nfrom bwd_inner_dq import bwd_kernel_dq\n\n@triton.jit\ndef attn_bwd(\n Q, K, V, B, sm_scale, Out, DO,\n DK, DV, DQ, DB,\n L, D,\n stride_qz, stride_qh, stride_qm, stride_qk,\n stride_kz, stride_kh, stride_kn, stride_kk,\n stride_vz, stride_vh, stride_vk, stride_vn,\n stride_bz, stride_bh, stride_bm, stride_bn,\n stride_oz, stride_oh, stride_om, stride_ok,\n stride_dkz, stride_dkh, stride_dkn, stride_dkk,\n stride_dvz, stride_dvh, stride_dvk, stride_dvn,\n stride_dqz, stride_dqh, stride_dqm, stride_dqk,\n stride_dbz, stride_dbh, stride_dbm, stride_dbn,\n num_head_q : 'i32',\n num_head_k : 'i32',\n cu_seqlens_q,\n cu_seqlens_k,\n num_seqlens : 'i32', # set num_seqlens to zero to ignore cu_seqlens_q/k\n max_seqlen_q, # and use max_seqlen_q/k for all seqlen_q/k\n max_seqlen_k,\n head_dim,\n dropout_p,\n philox_seed,\n philox_offset_base,\n BLOCK_DMODEL: tl.constexpr,\n CAUSAL: tl.constexpr,\n ENABLE_DROPOUT: tl.constexpr,\n PADDED_HEAD: tl.constexpr,\n BIAS_TYPE: tl.constexpr,\n BLOCK_M1: tl.constexpr,\n BLOCK_N1: tl.constexpr,\n BLOCK_M2: tl.constexpr,\n BLOCK_N2: tl.constexpr,\n BLK_SLICE_FACTOR: tl.constexpr,\n):\n LN2: tl.constexpr = 0.6931471824645996 # = ln(2)\n qk_scale = sm_scale * 1.44269504089\n\n off_h = tl.program_id(1) # head index\n off_z = tl.program_id(2) # batch index, for varlen it indicates index in cu_seqlens_q/k\n pid = tl.program_id(0)\n\n cu_seqlens_q_start = 0\n cu_seqlens_k_start = 0\n seqlen_q = max_seqlen_q\n seqlen_k = max_seqlen_k\n batch_index = off_z\n\n if num_seqlens > 0:\n cu_seqlens_q_start = tl.load(cu_seqlens_q + off_z)\n cu_seqlens_q_end = tl.load(cu_seqlens_q + off_z + 1)\n seqlen_q = cu_seqlens_q_end - cu_seqlens_q_start\n cu_seqlens_k_start = tl.load(cu_seqlens_k + off_z)\n cu_seqlens_k_end = tl.load(cu_seqlens_k + off_z + 1)\n seqlen_k = cu_seqlens_k_end - cu_seqlens_k_start\n batch_index = 0\n\n if num_seqlens < 0: # for padded seqlen\n cu_seqlens_q_start = tl.load(cu_seqlens_q + off_z)\n cu_seqlens_q_end = tl.load(cu_seqlens_q + off_z + 1)\n seqlen_q = cu_seqlens_q_end - cu_seqlens_q_start\n cu_seqlens_k_start = tl.load(cu_seqlens_k + off_z)\n cu_seqlens_k_end = tl.load(cu_seqlens_k + off_z + 1)\n seqlen_k = cu_seqlens_k_end - cu_seqlens_k_start\n cu_seqlens_q_start = 0\n cu_seqlens_k_start = 0\n batch_index = off_z\n\n off_zh = batch_index * num_head_q + off_h * 1\n\n q_offset = off_h * stride_qh + batch_index * stride_qz + cu_seqlens_q_start * stride_qm\n Q += q_offset\n k_offset = off_h * stride_kh + batch_index * stride_kz + cu_seqlens_k_start * stride_kn\n K += k_offset\n v_offset = off_h * stride_vh + batch_index * stride_vz + cu_seqlens_k_start * stride_vk\n V += v_offset\n do_offset = off_h * stride_oh + batch_index * stride_oz + cu_seqlens_q_start * stride_om\n DO += do_offset\n dk_offset = off_h * stride_dkh + batch_index * stride_dkz + cu_seqlens_k_start * stride_dkn\n DK += dk_offset\n dv_offset = off_h * stride_dvh + batch_index * stride_dvz + cu_seqlens_k_start * stride_dvk\n DV += dv_offset\n dq_offset = off_h * stride_dqh + batch_index * stride_dqz + cu_seqlens_q_start * stride_dqm\n DQ += dq_offset\n\n L += off_zh * max_seqlen_q\n D += off_zh * max_seqlen_q\n\n alibi_slope = None\n\n start_n = pid * BLOCK_N1\n start_m = start_n if CAUSAL else 0\n\n if start_n < seqlen_k:\n MASK_BLOCK_M1: tl.constexpr = BLOCK_M1 // BLK_SLICE_FACTOR\n\n dv = tl.zeros([BLOCK_N1, BLOCK_DMODEL], dtype=tl.float32)\n dk = tl.zeros([BLOCK_N1, BLOCK_DMODEL], dtype=tl.float32)\n\n k = mload2d(BLOCK_N1, BLOCK_DMODEL,\n i_base=K,\n i_start_row=start_n,\n i_start_col=0,\n i_rows=seqlen_k,\n i_cols=head_dim,\n stride_row=stride_kn,\n stride_col=stride_kk,\n )\n k = (k * qk_scale).to(K.dtype.element_ty)\n v = mload2d(BLOCK_N1, BLOCK_DMODEL,\n i_base=V,\n i_start_row=start_n,\n i_start_col=0,\n i_rows=seqlen_k,\n i_cols=head_dim,\n stride_row=stride_vk,\n stride_col=stride_vn,\n )\n\n if CAUSAL:\n num_steps = BLOCK_N1 // MASK_BLOCK_M1\n dk, dv = bwd_kernel_dk_dv(dk, dv, Q, k, v, sm_scale, alibi_slope,\n DO, L, D,\n stride_qm, stride_qk,\n stride_om, stride_ok,\n seqlen_q,\n seqlen_k,\n head_dim,\n MASK_BLOCK_M1, BLOCK_N1, BLOCK_DMODEL,\n start_n, start_m, num_steps,\n MASK=True, PADDED_HEAD=PADDED_HEAD)\n start_m += num_steps * MASK_BLOCK_M1\n\n num_steps = (seqlen_q - start_m) // BLOCK_M1\n\n dk, dv = bwd_kernel_dk_dv(dk, dv, Q, k, v, sm_scale, alibi_slope,\n DO, L, D,\n stride_qm, stride_qk,\n stride_om, stride_ok,\n seqlen_q,\n seqlen_k,\n head_dim,\n BLOCK_M1, BLOCK_N1, BLOCK_DMODEL,\n start_n, start_m, num_steps,\n MASK=False, PADDED_HEAD=PADDED_HEAD)\n\n mstore2d(dv.to(v.dtype),\n BLOCK_N1,\n BLOCK_DMODEL,\n o_base=DV,\n o_start_row=start_n,\n o_start_col=0,\n o_rows=seqlen_k,\n o_cols=head_dim,\n stride_row=stride_dvk,\n stride_col=stride_dvn)\n\n mstore2d((dk * sm_scale).to(k.dtype),\n BLOCK_N1,\n BLOCK_DMODEL,\n o_base=DK,\n o_start_row=start_n,\n o_start_col=0,\n o_rows=seqlen_k,\n o_cols=head_dim,\n stride_row=stride_dkn,\n stride_col=stride_dkk)\n\n start_m = pid * BLOCK_M2\n end_n = start_m + BLOCK_M2 if CAUSAL else seqlen_k\n\n if start_m < seqlen_q:\n MASK_BLOCK_N2: tl.constexpr = BLOCK_N2 // BLK_SLICE_FACTOR\n offs_m = start_m + tl.arange(0, BLOCK_M2)\n\n Q_block_ptr = tl.make_block_ptr(base=Q, shape=(seqlen_q, head_dim), strides=(stride_qm, stride_qk),\n offsets=(start_m, 0), block_shape=(BLOCK_M2, BLOCK_DMODEL), order=(1, 0))\n\n DO_block_ptr = tl.make_block_ptr(base=DO, shape=(seqlen_q, head_dim), strides=(stride_om, stride_ok),\n offsets=(start_m, 0), block_shape=(BLOCK_M2, BLOCK_DMODEL), order=(1, 0))\n q = tl.load(Q_block_ptr)\n q = (q * qk_scale).to(Q_block_ptr.type.element_ty)\n do = tl.load(DO_block_ptr)\n dq = tl.zeros([BLOCK_M2, BLOCK_DMODEL], dtype=tl.float32)\n\n m = tl.load(L + offs_m)\n m = m[:, None]\n\n num_steps = BLOCK_M2 // MASK_BLOCK_N2\n if CAUSAL:\n dq = bwd_kernel_dq(dq, q, K, V, alibi_slope,\n do, m, D,\n stride_kn, stride_kk,\n stride_vk, stride_vn,\n seqlen_q,\n seqlen_k,\n head_dim,\n BLOCK_M2,\n MASK_BLOCK_N2,\n BLOCK_DMODEL,\n start_m, end_n - num_steps * MASK_BLOCK_N2, num_steps,\n MASK=True, PADDED_HEAD=PADDED_HEAD)\n end_n -= num_steps * MASK_BLOCK_N2\n\n num_steps = end_n // BLOCK_N2\n dq = bwd_kernel_dq(dq, q, K, V, alibi_slope,\n do, m, D,\n stride_kn, stride_kk,\n stride_vk, stride_vn,\n seqlen_q,\n seqlen_k,\n head_dim,\n BLOCK_M2,\n BLOCK_N2,\n BLOCK_DMODEL,\n start_m, end_n - num_steps * BLOCK_N2, num_steps,\n MASK=False, PADDED_HEAD=PADDED_HEAD)\n\n mstore2d((dq * sm_scale).to(q.dtype),\n BLOCK_M2,\n BLOCK_DMODEL,\n o_base=DQ,\n o_start_row=start_m,\n o_start_col=0,\n o_rows=seqlen_q,\n o_cols=head_dim,\n stride_row=stride_dqm,\n stride_col=stride_dqk)\n", - "description_1": "Use triton language to implement an attention backward pass kernel named attn_bwd. The kernel takes multiple parameters: Q, K, V, B, sm_scale, Out, DO, DK, DV, DQ, DB, L, D, various strides for input tensors, num_head_q, num_head_k, cu_seqlens_q, cu_seqlens_k, num_seqlens, max_seqlen_q, max_seqlen_k, head_dim, dropout_p, philox_seed, philox_offset_base, and several constants: BLOCK_DMODEL, CAUSAL, ENABLE_DROPOUT, PADDED_HEAD, BIAS_TYPE, BLOCK_M1, BLOCK_N1, BLOCK_M2, BLOCK_N2, and BLK_SLICE_FACTOR. The function performs backward computations for keys and values (DK and DV) and queries (DQ) using the specified attention mechanism, handling various cases such as causality, padding, and variable sequence lengths.", - "description_2": "Implement a triton attention backward kernel for handling DK, DV, and DQ with support for causal and padded sequences.", - "difficulty": 4 - }, - { - "code": "import triton\nimport triton.language as tl\nfrom masked_load_store import mload1d, mload2d\n\n@triton.jit\ndef bwd_kernel_dk_dv(dk, dv, Q, k, v, sm_scale, alibi_slope,\n DO, M, D,\n stride_qm, stride_qk,\n stride_om, stride_ok,\n seqlen_q,\n seqlen_k,\n head_dim,\n BLOCK_M1: tl.constexpr,\n BLOCK_N1: tl.constexpr,\n BLOCK_DMODEL: tl.constexpr,\n start_n, start_m, num_steps,\n MASK: tl.constexpr,\n PADDED_HEAD: tl.constexpr,\n ):\n offs_n = start_n + tl.arange(0, BLOCK_N1)\n QT_block_ptr = tl.make_block_ptr(base=Q, shape=(head_dim, seqlen_q), strides=(stride_qk, stride_qm),\n offsets=(0, start_m), block_shape=(BLOCK_DMODEL, BLOCK_M1), order=(0, 1))\n DO_block_ptr = tl.make_block_ptr(base=DO, shape=(seqlen_q, head_dim), strides=(stride_om, stride_ok),\n offsets=(start_m, 0), block_shape=(BLOCK_M1, BLOCK_DMODEL), order=(1, 0))\n tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0)\n curr_m = start_m\n step_m = BLOCK_M1\n for blk_idx in range(num_steps):\n qT = mload2d(BLOCK_DMODEL, BLOCK_M1,\n i_base=Q,\n i_start_row=0,\n i_start_col=curr_m,\n i_rows=head_dim,\n i_cols=seqlen_q,\n stride_row=stride_qk,\n stride_col=stride_qm,\n )\n offs_m = curr_m + tl.arange(0, BLOCK_M1)\n if curr_m + BLOCK_M1 <= seqlen_q:\n m = tl.load(M + offs_m)\n else:\n m = mload1d(BLOCK_M1, i_base=M, i_start=curr_m, i_nums=seqlen_q)\n kqT = tl.dot(k, qT)\n pT = tl.math.exp2(kqT - m[None, :])\n if MASK:\n mask = (offs_m[None, :] >= offs_n[:, None])\n pT = tl.where(mask, pT, 0.0)\n do = tl.load(DO_block_ptr)\n ppT = pT\n ppT = ppT.to(DO_block_ptr.dtype.element_ty)\n dv += tl.dot(ppT, do)\n Di = tl.load(D + offs_m)\n dpT = tl.dot(v, tl.trans(do))\n dsT = (dpT - Di[None, :]) * pT\n dk += tl.dot(dsT.to(QT_block_ptr.dtype.element_ty), tl.trans(qT))\n curr_m += step_m\n QT_block_ptr = tl.advance(QT_block_ptr, (0, step_m))\n DO_block_ptr = tl.advance(DO_block_ptr, (step_m, 0))\n return dk, dv\n", - "description_1": "Use triton language to implement a backward kernel function 'bwd_kernel_dk_dv' with 24 parameters. The function computes gradients for dk and dv using inputs Q, k, v, DO, M, D, and other parameters like strides, sequence lengths, head dimension, block sizes, and constants for masking and padding. The kernel performs matrix operations and uses masked loading to handle data efficiently.", - "description_2": "Use triton language to create a backward kernel for computing gradients of dk and dv with matrix operations and masked loading.", - "difficulty": 3 - }, - { - "code": "import triton\nimport triton.language as tl\n\n@triton.jit\ndef bwd_kernel_dq(dq, q, K, V, alibi_slope,\n do, m, D,\n stride_kn, stride_kk,\n stride_vk, stride_vn,\n seqlen_q,\n seqlen_k,\n head_dim,\n BLOCK_M2: tl.constexpr,\n BLOCK_N2: tl.constexpr,\n BLOCK_DMODEL: tl.constexpr,\n start_m, start_n, num_steps,\n MASK: tl.constexpr,\n PADDED_HEAD: tl.constexpr):\n offs_m = start_m + tl.arange(0, BLOCK_M2)\n offs_n = start_n + tl.arange(0, BLOCK_N2)\n KT_block_ptr = tl.make_block_ptr(base=K, shape=(head_dim, seqlen_k), strides=(stride_kk, stride_kn),\n offsets=(0, start_n), block_shape=(BLOCK_DMODEL, BLOCK_N2), order=(0, 1))\n VT_block_ptr = tl.make_block_ptr(base=V, shape=(head_dim, seqlen_k), strides=(stride_vn, stride_vk),\n offsets=(0, start_n), block_shape=(BLOCK_DMODEL, BLOCK_N2), order=(0, 1))\n Di = tl.load(D + offs_m)\n tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0)\n curr_n = start_n\n step_n = BLOCK_N2\n for blk_idx in range(num_steps):\n if PADDED_HEAD:\n kT = tl.load(KT_block_ptr, boundary_check=(0,1), padding_option=\"zero\")\n else:\n kT = tl.load(KT_block_ptr)\n qk = tl.dot(q, kT)\n p = tl.math.exp2(qk - m)\n if MASK:\n offs_n = curr_n + tl.arange(0, BLOCK_N2)\n mask = (offs_m[:, None] >= offs_n[None, :])\n p = tl.where(mask, p, 0.0)\n if PADDED_HEAD:\n vT = tl.load(VT_block_ptr, boundary_check=(0,1), padding_option=\"zero\")\n else:\n vT = tl.load(VT_block_ptr)\n dp = tl.dot(do, vT).to(tl.float32)\n ds = p * (dp - Di[:, None])\n ds = ds.to(KT_block_ptr.type.element_ty)\n dq += tl.dot(ds, tl.trans(kT))\n curr_n += step_n\n KT_block_ptr = tl.advance(KT_block_ptr, (0, step_n))\n VT_block_ptr = tl.advance(VT_block_ptr, (0, step_n))\n return dq\n", - "description_1": "Use triton language to implement a backward kernel for computing gradients for Q (dq) in an attention mechanism. The kernel takes 27 arguments: 9 tensor arguments including dq, q, K, V, alibi_slope, do, m, and D for inputs, and 4 strides stride_kn, stride_kk, stride_vk, stride_vn; 4 scalar parameters: seqlen_q, seqlen_k, head_dim, and 8 constants: BLOCK_M2, BLOCK_N2, BLOCK_DMODEL for block sizes, start_m, start_n for starting indices, num_steps for loop iterations, MASK, and PADDED_HEAD for conditional behavior within the kernel. The main computations include making block pointers for tensors, performing matrix multiplications, masking, and updating pointers for next blocks.", - "description_2": "Use triton language to implement a backward kernel for computing gradients for Q in attention mechanism using block-wise matrix operations and masking.", - "difficulty": 4 - }, - { - "code": "import triton\nimport triton.language as tl\nfrom masked_load_store import load_fn\nfrom dropout import dropout_mask\n\n# Kernel to compute dot products with support for small blocks\n@triton.jit\ndef dot(BLOCK_M: tl.constexpr, QDIM: tl.constexpr, KDIM: tl.constexpr, q, k):\n if BLOCK_M == 1:\n return tl.sum(tl.view(q, [QDIM]) * tl.view(k, [KDIM]))\n else:\n return tl.dot(q, k)\n\n# Backward kernel for computing dk and dv in attention mechanism\n@triton.jit\ndef bwd_kernel_dk_dv_common(\n q_ptrs, q_stride, kt, vt, B_block_ptr,\n sm_scale, do_ptrs, do_stride,\n l_ptrs,\n D_ptrs,\n seqlen_q,\n seqlen_k,\n start_m,\n head_dim,\n dropout_p,\n philox_seed,\n batch_philox_offset,\n max_seqlen_k,\n BLOCK_M: tl.constexpr,\n BLOCK_DMODEL: tl.constexpr,\n BLOCK_N: tl.constexpr,\n CAUSAL: tl.constexpr,\n ENABLE_DROPOUT: tl.constexpr,\n PADDED_HEAD: tl.constexpr,\n BIAS_TYPE: tl.constexpr,\n):\n # Kernel computation logic...\n return (dk * sm_scale).to(kt.type.element_ty), dv.to(vt.type.element_ty)\n\n# Backward kernel for computing dq and db in attention mechanism\n@triton.jit\ndef bwd_kernel_dq_db_common(\n q, kt_ptrs, k_stride, vt_ptrs, v_stride, B_block_ptr,\n sm_scale, do,\n dq, DB_block_ptr, store_db,\n l_ptrs,\n D_ptrs,\n seqlen_q,\n seqlen_k,\n start_m,\n head_dim,\n dropout_p,\n philox_seed,\n batch_philox_offset,\n max_seqlen_k,\n BLOCK_M: tl.constexpr,\n BLOCK_DMODEL: tl.constexpr,\n BLOCK_N: tl.constexpr,\n CAUSAL: tl.constexpr,\n ENABLE_DROPOUT: tl.constexpr,\n PADDED_HEAD: tl.constexpr,\n BIAS_TYPE: tl.constexpr,\n):\n # Kernel computation logic...\n return (dq * sm_scale).to(dq.type.element_ty)\n", - "description_1": "Use triton language to implement two backward kernel functions for attention mechanisms. The first kernel, `bwd_kernel_dk_dv_common`, computes gradients with respect to keys and values (dk, dv) and supports dropout and bias adjustments. The second kernel, `bwd_kernel_dq_db_common`, computes gradients with respect to queries and bias (dq, db) and handles causal attention and dropout.", - "description_2": "Use triton language to implement backward kernels for computing gradients in attention mechanisms, including support for dropout and bias.", - "difficulty": 4 - }, - { - "code": "import triton\nimport triton.language as tl\n\n@triton.jit\ndef bwd_preprocess(\n Out, DO,\n Delta,\n stride_oz, stride_oh, stride_om, stride_on,\n stride_doz, stride_doh, stride_dom, stride_don,\n seqlen_q,\n BLOCK_M: tl.constexpr,\n D_HEAD: tl.constexpr,\n):\n # Calculate offsets\n off_m = tl.program_id(0) * BLOCK_M\n off_h = tl.program_id(1) # head index\n off_z = tl.program_id(2) # batch index\n num_h = tl.num_programs(1)\n o_offset = off_h * stride_oh + off_z * stride_oz\n # Create block pointers\n O_block_ptr = tl.make_block_ptr(\n base=Out + o_offset,\n shape=(seqlen_q, D_HEAD),\n strides=(stride_om, stride_on),\n offsets=(off_m, 0),\n block_shape=(BLOCK_M, D_HEAD),\n order=(1, 0)\n )\n do_offset = off_h * stride_doh + off_z * stride_doz\n DO_block_ptr = tl.make_block_ptr(\n base=DO + do_offset,\n shape=(seqlen_q, D_HEAD),\n strides=(stride_dom, stride_don),\n offsets=(off_m, 0),\n block_shape=(BLOCK_M, D_HEAD),\n order=(1, 0)\n )\n # Load tensors\n o = tl.load(O_block_ptr).to(tl.float32)\n do = tl.load(DO_block_ptr).to(tl.float32)\n # Compute delta\n delta = tl.sum(o * do, axis=1)\n # Write-back result\n off_zh = off_z * num_h + off_h * 1\n tl.store(Delta + off_zh * seqlen_q + off_m + tl.arange(0, BLOCK_M), delta)\n", - "description_1": "Use triton language to implement a backward preprocessing kernel for a fused attention mechanism. The kernel takes in output tensors, a delta tensor, and various stride parameters to calculate a delta for gradient updates. It has 11 arguments: Out (output tensor), DO (gradient output tensor), Delta (result delta tensor), stride_oz, stride_oh, stride_om, stride_on (strides for the output tensor), stride_doz, stride_doh, stride_dom, stride_don (strides for the gradient output tensor), seqlen_q (sequence length), BLOCK_M and D_HEAD (constant block and head dimensions).", - "description_2": "Use triton language to create a kernel that computes delta for gradient updates in a fused attention model, using output and gradient tensors with specified strides and dimensions.", - "difficulty": 3 - }, - { - "code": "import triton\nimport triton.language as tl\n\n# Helper function, but not always usable due to compiler bugs (esp. used with tl.trans)\n@triton.jit\ndef dot(BLOCK_M : tl.constexpr, QDIM : tl.constexpr, KDIM : tl.constexpr, q, k):\n if BLOCK_M == 1:\n return tl.sum(tl.view(q, [QDIM]) * tl.view(k, [KDIM]))\n else:\n return tl.dot(q, k)\n\n# TODO: Remove Unused 'Out' Argument from kernels below\n@triton.jit\ndef bwd_kernel_dk_dv(\n Q, K, V, sm_scale, Out, DO,\n DK, DV,\n L,\n D,\n stride_qz, stride_qh, stride_qm, stride_qk,\n stride_kz, stride_kh, stride_kn, stride_kk,\n stride_vz, stride_vh, stride_vk, stride_vn,\n seqlen_q, seqlen_k,\n dropout_p,\n philox_seed,\n philox_offset_base,\n BLOCK_M: tl.constexpr,\n BLOCK_DMODEL: tl.constexpr,\n BLOCK_N: tl.constexpr,\n CAUSAL: tl.constexpr,\n ENABLE_DROPOUT: tl.constexpr,\n):\n start_m = tl.program_id(0) * BLOCK_N\n off_h = tl.program_id(1) # head index\n off_z = tl.program_id(2) # batch index\n num_h = tl.num_programs(1)\n num_z = tl.num_programs(2)\n # initialize offsets\n offs_m = start_m + tl.arange(0, BLOCK_N)\n offs_n = tl.arange(0, BLOCK_M)\n # Initialize pointers to Q, K, V\n # Q is consumed depending on block ID. Every block uses\n # previous block offset by BLOCK_M x D_HEAD.\n q_offset = off_h * stride_qh + off_z * stride_qz\n Q_block_ptr = tl.make_block_ptr(\n base=Q + q_offset,\n shape=(seqlen_q, BLOCK_DMODEL),\n strides=(stride_qm, stride_qk),\n offsets=(0, 0),\n block_shape=(BLOCK_M, BLOCK_DMODEL),\n order=(1, 0)\n )\n k_offset = off_h * stride_kh + off_z * stride_kz\n K_block_ptr = tl.make_block_ptr(\n base=K + k_offset,\n shape=(BLOCK_DMODEL, seqlen_k),\n strides=(stride_kk, stride_kn),\n offsets=(0, start_m),\n block_shape=(BLOCK_DMODEL, BLOCK_N),\n order=(0, 1)\n )\n v_offset = off_h * stride_vh + off_z * stride_vz\n VT_block_ptr = tl.make_block_ptr(\n base=V + v_offset,\n shape=(BLOCK_DMODEL, seqlen_k),\n strides=(stride_vn, stride_vk),\n offsets=(0, start_m),\n block_shape=(BLOCK_DMODEL, BLOCK_N),\n order=(0, 1)\n )\n do_offset = q_offset\n DO_block_ptr = tl.make_block_ptr(\n base=DO + do_offset,\n shape=(seqlen_q, BLOCK_DMODEL),\n strides=(stride_qm, stride_qk),\n offsets=(0, 0),\n block_shape=(BLOCK_M, BLOCK_DMODEL),\n order=(1, 0)\n )\n off_zh = off_z * num_h + off_h * 1\n # pointer to row-wise quantities in value-like data\n D_ptrs = D + off_zh * seqlen_q\n l_ptrs = L + off_zh * seqlen_q\n qk_scale = sm_scale * 1.44269504\n # load k and v: they will stay in SRAM throughout\n k = tl.load(K_block_ptr) # (BLOCK_DMODEL, BLOCK_N)\n k = (k * qk_scale).to(K_block_ptr.type.element_ty)\n vt = tl.load(VT_block_ptr) # (BLOCK_DMODEL, BLOCK_N)\n dv = tl.zeros([BLOCK_N, BLOCK_DMODEL], dtype=tl.float32)\n dk = tl.zeros([BLOCK_N, BLOCK_DMODEL], dtype=tl.float32)\n # This lower loop bound is because of the causal mask. We create a lower triangular\n # result. The upper triangular is -inf (becomes 0 when we do e^x). As such, it can\n # be ignored in the GEMM.\n lo = (start_m // BLOCK_M) * BLOCK_M if CAUSAL else 0\n hi = seqlen_q\n Q_block_ptr = tl.advance(Q_block_ptr, (lo, 0))\n DO_block_ptr = tl.advance(DO_block_ptr, (lo, 0))\n batch_philox_offset = philox_offset_base + off_zh * seqlen_q * seqlen_k\n '''\n K1 K2 (d)V dO\n Q1 qk11 qk12 (d)v1 dO1\n Q2 qk21 qk22 (d)v2 dO2\n\n QK: (seqlen_q, seqlen_k)\n dO: (seqlen_q, hdim)\n dV: (seqlen_k, hdim)\n\n dV = (QK)^T dO\n\n dV1 = qk11 dO1 + qk21 dO2 = q1 k1 dO1 + q2 k1 dO2\n dV2 = qk12 dO1 + qk22 dO2 = q1 k2 dO1 + q2 k2 dO2\n ~~~~~ = 0\n start_m: select k and dV\n start_n: select q and dO\n '''\n # loop over q (seqlen_q, dhead), do (seqlen_q, d_head)\n for start_n in range(lo, hi, BLOCK_M):\n offs_m_curr = offs_n[:, None] + start_n # (BLOCK_M, 1)\n # -- load q, do --\n q = tl.load(Q_block_ptr) # (BLOCK_M, BLOCK_DMODEL), offs = (BLOCK_M * iter, 0) = (start_n, 0)\n do = tl.load(DO_block_ptr) # (BLOCK_M, BLOCK_DMODEL)\n # -- compute qk ----\n qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)\n # q.offs = (start_n, 0), k.offs = (0, start_m)\n qk += dot(BLOCK_M, BLOCK_DMODEL, BLOCK_DMODEL, q, k) # (BLOCK_M, BLOCK_N)\n if CAUSAL:\n qk = tl.where(offs_m_curr >= offs_m[None, :], qk, float(\"-inf\"))\n l_i = tl.load(l_ptrs + offs_m_curr) # (BLOCK_M, 1)\n p = tl.math.exp2(qk - l_i) # (BLOCK_M, BLOCK_N)\n # -- compute dv ----\n if ENABLE_DROPOUT:\n philox_offset = batch_philox_offset + start_n * seqlen_k + start_m\n keep = dropout_mask(philox_seed, philox_offset, dropout_p, BLOCK_M, BLOCK_N, seqlen_k)\n # CAVEAT: do NOT update p, ds needs the original p\n if BLOCK_M == 1:\n dv += tl.where(keep, p / (1 - dropout_p), 0.0).to(Q.dtype.element_ty) * do\n else:\n dv += tl.dot(tl.where(tl.trans(keep), tl.trans(p) / (1 - dropout_p), 0.0).to(Q.dtype.element_ty), do)\n else:\n if BLOCK_M == 1:\n dv += p.to(Q.dtype.element_ty) * do\n else:\n # dv += tl.dot(tl.trans(p.to(do.dtype)), do)\n dv += tl.dot(tl.trans(p).to(do.dtype), do)\n # compute dp = dot(v, do)\n Di = tl.load(D_ptrs + offs_m_curr) # (BLOCK_M, 1)\n dp = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)\n # dp += dot(BLOCK_M, BLOCK_DMODEL, BLOCK_DMODEL, do, vt)\n # do.shape = (BLOCK_M, BLOCK_DMODEL) vt.shape = (BLOCK_DMODEL, BLOCK_N)\n dp += tl.dot(do, vt)\n if ENABLE_DROPOUT:\n dp = tl.where(keep, dp / (1 - dropout_p), 0)\n # compute ds = p * (dp - delta[:, None])\n ds = p * (dp - Di) # (BLOCK_M, BLOCK_N)\n # compute dk\n if BLOCK_M == 1:\n dk += ds.to(Q.dtype.element_ty) * q\n else:\n # ds.shape = (BLOCK_M, BLOCK_N), q.shape = (BLOCK_M, BLOCK_DMODEL)\n dk += tl.dot(tl.trans(ds.to(Q.dtype.element_ty)), q) # (BLOCK_N, BLOCK_DMODEL)\n # update pointers\n Q_block_ptr = tl.advance(Q_block_ptr, (BLOCK_M, 0))\n DO_block_ptr = tl.advance(DO_block_ptr, (BLOCK_M, 0)) # Debug DO accessing problems\n # initialize pointers to output\n DK_block_ptr = tl.make_block_ptr(\n base=DK + k_offset,\n shape=(seqlen_k, BLOCK_DMODEL),\n strides=(stride_kn, stride_kk),\n offsets=(start_m, 0),\n block_shape=(BLOCK_N, BLOCK_DMODEL),\n order=(1, 0)\n )\n DV_block_ptr = tl.make_block_ptr(\n base=DV + v_offset,\n shape=(seqlen_k, BLOCK_DMODEL),\n strides=(stride_vk, stride_vn),\n offsets=(start_m, 0),\n block_shape=(BLOCK_N, BLOCK_DMODEL),\n order=(1, 0)\n )\n tl.store(DK_block_ptr, (dk * sm_scale).to(DK.type.element_ty))\n tl.store(DV_block_ptr, dv.to(DV.type.element_ty))\n\n@triton.jit\ndef bwd_kernel_dq(\n Q, K, V, sm_scale, Out, DO,\n DQ,\n L,\n D,\n stride_qz, stride_qh, stride_qm, stride_qk,\n stride_kz, stride_kh, stride_kn, stride_kk,\n stride_vz, stride_vh, stride_vk, stride_vn,\n seqlen_q, seqlen_k,\n dropout_p,\n philox_seed,\n philox_offset_base,\n BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,\n BLOCK_N: tl.constexpr,\n CAUSAL: tl.constexpr,\n ENABLE_DROPOUT: tl.constexpr,\n):\n start_m = tl.program_id(0) * BLOCK_M\n off_h = tl.program_id(1) # head index\n off_z = tl.program_id(2) # batch index\n num_h = tl.num_programs(1)\n num_z = tl.num_programs(2)\n # initialize offsets\n offs_m = start_m + tl.arange(0, BLOCK_M)\n offs_n = tl.arange(0, BLOCK_N)\n # Initialize pointers to Q, K, V\n q_offset = off_h * stride_qh + off_z * stride_qz\n Q_block_ptr = tl.make_block_ptr(\n base=Q + q_offset,\n shape=(seqlen_q, BLOCK_DMODEL),\n strides=(stride_qm, stride_qk),\n offsets=(start_m, 0),\n block_shape=(BLOCK_M, BLOCK_DMODEL),\n order=(1, 0)\n )\n k_offset = off_h * stride_kh + off_z * stride_kz\n K_block_ptr = tl.make_block_ptr(\n base=K + k_offset,\n shape=(BLOCK_DMODEL, seqlen_k),\n strides=(stride_kk, stride_kn),\n offsets=(0, 0),\n block_shape=(BLOCK_DMODEL, BLOCK_N),\n order=(0, 1)\n )\n v_offset = off_h * stride_vh + off_z * stride_vz\n V_block_ptr = tl.make_block_ptr(\n base=V + v_offset,\n shape=(BLOCK_DMODEL, seqlen_k),\n strides=(stride_vn, stride_vk),\n offsets=(0, 0),\n block_shape=(BLOCK_DMODEL, BLOCK_N),\n order=(0, 1)\n )\n DO_block_ptr = tl.make_block_ptr(\n base=DO + q_offset,\n shape=(seqlen_q, BLOCK_DMODEL),\n strides=(stride_qm, stride_qk),\n offsets=(start_m, 0),\n block_shape=(BLOCK_M, BLOCK_DMODEL),\n order=(1, 0)\n )\n off_zh = off_z * num_h + off_h * 1\n # pointer to row-wise quantities in value-like data\n D_ptrs = D + off_zh * seqlen_q\n l_ptrs = L + off_zh * seqlen_q\n qk_scale = sm_scale * 1.44269504\n # load q and do: they will stay in SRAM throughout\n q = tl.load(Q_block_ptr)\n q = (q * qk_scale).to(Q_block_ptr.type.element_ty)\n do = tl.load(DO_block_ptr)\n Di = tl.load(D_ptrs + offs_m)\n l_i = tl.load(l_ptrs + offs_m)\n dq = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)\n # loop over k, v\n lo = 0\n hi = min(start_m + BLOCK_M, seqlen_k) if CAUSAL else seqlen_k\n batch_philox_offset = philox_offset_base + off_zh * seqlen_q * seqlen_k\n '''\n K1 K2 (d)V dO\n Q1 qk11 qk12 (d)v1 dO1\n Q2 qk21 qk22 (d)v2 dO2\n\n QK: (seqlen_q, seqlen_k)\n dO: (seqlen_q, hdim)\n dV: (seqlen_k, hdim)\n '''\n for start_n in range(lo, hi, BLOCK_N):\n # -- load k, v --\n kt = tl.load(K_block_ptr) # shape = (BLOCK_DMODEL, BLOCK_N), offs = (0, BLOCK_N * iter) = (0, start_n)\n vt = tl.load(V_block_ptr)\n # -- compute qk ----\n # q.offs = (start_m, 0), k.offs = (0, start_n)\n qk = dot(BLOCK_M, BLOCK_DMODEL, BLOCK_DMODEL, q, kt)\n if CAUSAL:\n qk = tl.where(offs_m[:, None] >= (offs_n[None, :] + start_n), qk, float(\"-inf\"))\n p = tl.math.exp2(qk - l_i[:, None])\n # compute dp = dot(v, do)\n dp = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)\n dp += dot(BLOCK_M, BLOCK_DMODEL, BLOCK_DMODEL, do, vt)\n if ENABLE_DROPOUT:\n philox_offset = batch_philox_offset + start_m * seqlen_k + start_n\n keep = dropout_mask(philox_seed, philox_offset, dropout_p, BLOCK_M, BLOCK_N, seqlen_k)\n dp = tl.where(keep, dp / (1 - dropout_p), 0)\n # compute ds = p * (dp - delta[:, None])\n ds = p * (dp - Di[:, None])\n # compute dq. Unfortunately we cannot avoid transpose here as this loop\n # uses k both normal and transpose.\n if BLOCK_M == 1:\n dq += tl.view(kt, [BLOCK_DMODEL]) * ds.to(Q.type.element_ty)\n else:\n # ds.shape = (BLOCK_M, BLOCK_N), kt.shape = (BLOCK_DMODEL, BLOCK_N)\n dq += tl.dot(ds.to(Q.type.element_ty), tl.trans(kt)) # (BLOCK_M, BLOCK_DMODEL)\n # update pointers\n K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N))\n V_block_ptr = tl.advance(V_block_ptr, (0, BLOCK_N))\n # initialize pointers to output\n DQ_block_ptr = tl.make_block_ptr(\n base=DQ + q_offset,\n shape=(seqlen_q, BLOCK_DMODEL),\n strides=(stride_qm, stride_qk),\n offsets=(start_m, 0),\n block_shape=(BLOCK_M, BLOCK_DMODEL),\n order=(1, 0)\n )\n tl.store(DQ_block_ptr, (dq * sm_scale).to(DQ_block_ptr.type.element_ty))\n", - "description_1": "Use triton language to implement three kernels: 'dot', 'bwd_kernel_dk_dv', and 'bwd_kernel_dq'. The 'dot' function takes BLOCK_M, QDIM, KDIM, q, and k as inputs to perform matrix multiplication or dot product based on BLOCK_M. The 'bwd_kernel_dk_dv' function takes 28 arguments including matrices Q, K, V, dropout parameters, scale factors, and block dimensions to compute backward gradients dk and dv. The 'bwd_kernel_dq' function takes similar parameters to compute the gradient dq with respect to the query matrix Q.", - "description_2": "Use triton language to create kernels for computing backward gradients in a flash attention mechanism, allowing dropout and causal attention masks.", - "difficulty": 4 - }, - { - "code": "import triton\nimport triton.language as tl\n\n@triton.jit\ndef dropout_offsets(philox_seed, philox_offset, m, n, stride):\n ms = tl.arange(0, m)\n ns = tl.arange(0, n)\n return philox_offset + ms[:, None] * stride + ns[None, :]\n\n@triton.jit\ndef dropout_rng(philox_seed, philox_offset, m, n, stride):\n rng_offsets = dropout_offsets(philox_seed, philox_offset, m, n, stride).to(tl.uint32)\n # TODO: use tl.randint for better performance\n return tl.rand(philox_seed, rng_offsets)\n\n@triton.jit\ndef dropout_mask(philox_seed, philox_offset, dropout_p, m, n, stride):\n rng_output = dropout_rng(philox_seed, philox_offset, m, n, stride)\n rng_keep = rng_output > dropout_p\n return rng_keep\n", - "description_1": "Use triton language to implement a series of functions for dropout operations. The first function, dropout_offsets, takes 5 parameters: philox_seed, philox_offset, m, n, and stride. It calculates offsets for random number generation. The second function, dropout_rng, also takes 5 parameters: philox_seed, philox_offset, m, n, and stride. It generates random numbers using the offsets calculated by dropout_offsets. The third function, dropout_mask, takes 6 parameters: philox_seed, philox_offset, dropout_p, m, n, and stride. It generates a mask for dropout by comparing random numbers to a dropout probability.", - "description_2": "Use triton language to create functions for generating random offsets, random numbers, and dropout masks using given seeds, offsets, dimensions, and stride.", - "difficulty": 3 - }, - { - "code": "import triton\nimport triton.language as tl\nfrom dropout import dropout_rng\n\n# Kernel to initialize a random number generator for dropout\n@triton.jit\ndef debug_fill_dropout_rng(R,\n stride_rz, stride_rh, stride_rm, stride_rn,\n seqlen_q, seqlen_k,\n philox_seed,\n philox_offset_base,\n BLOCK_M: tl.constexpr,\n BLOCK_N: tl.constexpr,\n ):\n start_m = tl.program_id(0)\n off_h = tl.program_id(1) # head index\n off_z = tl.program_id(2) # batch index\n d_offset = off_h * stride_rh + off_z * stride_rz\n num_h = tl.num_programs(1)\n off_zh = off_z * num_h + off_h * 1\n batch_philox_offset = philox_offset_base + off_zh * seqlen_q * seqlen_k\n R_block_ptr = tl.make_block_ptr(\n base=R + d_offset,\n shape=(seqlen_q, seqlen_k),\n strides=(stride_rm, stride_rn),\n offsets=(start_m * BLOCK_M, 0),\n block_shape=(BLOCK_M, BLOCK_N),\n order=(1, 0)\n )\n for start_n in range(0, seqlen_k, BLOCK_N):\n philox_offset = batch_philox_offset + start_m * BLOCK_M * seqlen_k + start_n\n rng = dropout_rng(philox_seed, philox_offset, BLOCK_M, BLOCK_N, seqlen_k)\n tl.store(R_block_ptr, rng.to(R_block_ptr.type.element_ty), boundary_check=(0, 1))\n R_block_ptr = tl.advance(R_block_ptr, (0, BLOCK_N))\n\n# Wrapper kernel to initialize a random number generator for dropout using tensor inputs\n@triton.jit\ndef debug_fill_dropout_rng_tensor(R,\n stride_rz, stride_rh, stride_rm, stride_rn,\n seqlen_q, seqlen_k,\n philox_seed_ptr,\n philox_offset_base_ptr,\n BLOCK_M: tl.constexpr,\n BLOCK_N: tl.constexpr,\n ):\n philox_seed = tl.load(philox_seed_ptr)\n philox_offset_base = tl.load(philox_offset_base_ptr)\n debug_fill_dropout_rng(R,\n stride_rz, stride_rh, stride_rm, stride_rn,\n seqlen_q, seqlen_k,\n philox_seed,\n philox_offset_base,\n BLOCK_M,\n BLOCK_N,\n )\n", - "description_1": "Use triton language to implement a random number generator initialization for dropout, with two kernels. The first kernel, debug_fill_dropout_rng, has 10 parameters: R (output tensor), stride_rz, stride_rh, stride_rm, stride_rn (stride sizes for tensor R), seqlen_q, seqlen_k (sequence lengths), philox_seed (random seed), philox_offset_base (offset base for RNG), BLOCK_M, BLOCK_N (block dimensions). It uses these parameters to calculate offsets and store generated random numbers in tensor R. The second kernel, debug_fill_dropout_rng_tensor, acts as a wrapper to load seed and offset values from pointers and calls the first kernel.", - "description_2": "Use triton language to create two kernels: one initializes random numbers for dropout directly, and the other loads seed and offset from pointers to invoke the first kernel.", - "difficulty": 3 - }, - { - "code": "import triton\nimport triton.language as tl\n\n@triton.jit\ndef max_fn(x, y):\n return tl.math.max(x, y)\n\n@triton.jit\ndef dropout_offsets(philox_seed, philox_offset, dropout_p, m, n, stride):\n ms = tl.arange(0, m)\n ns = tl.arange(0, n)\n return philox_offset + ms[:, None] * stride + ns[None, :]\n\n@triton.jit\ndef dropout_rng(philox_seed, philox_offset, dropout_p, m, n, stride):\n rng_offsets = dropout_offsets(philox_seed, philox_offset, dropout_p, m, n, stride).to(tl.uint32)\n return tl.rand(philox_seed, rng_offsets)\n\n@triton.jit\ndef dropout_mask(philox_seed, philox_offset, dropout_p, m, n, stride):\n rng_output = dropout_rng(philox_seed, philox_offset, dropout_p, m, n, stride)\n rng_keep = rng_output > dropout_p\n return rng_keep\n\n@triton.jit\ndef attn_fwd_inner(\n acc, l_i, m_i, q,\n K_block_ptr, V_block_ptr,\n start_m,\n seqlen_q,\n seqlen_k,\n dropout_p,\n philox_seed,\n batch_philox_offset,\n encoded_softmax_block_ptr,\n BLOCK_M: tl.constexpr,\n BLOCK_DMODEL: tl.constexpr,\n BLOCK_N: tl.constexpr,\n STAGE: tl.constexpr,\n offs_m: tl.constexpr,\n offs_n: tl.constexpr,\n pre_load_v: tl.constexpr,\n ENABLE_DROPOUT: tl.constexpr,\n RETURN_ENCODED_SOFTMAX: tl.constexpr,\n):\n if STAGE == 1:\n lo, hi = 0, min(seqlen_k, start_m * BLOCK_M)\n elif STAGE == 2:\n lo, hi = start_m * BLOCK_M, min(seqlen_k, start_m * BLOCK_M + BLOCK_M)\n lo = tl.multiple_of(lo, BLOCK_M)\n K_block_ptr = tl.advance(K_block_ptr, (0, lo))\n V_block_ptr = tl.advance(V_block_ptr, (lo, 0))\n if RETURN_ENCODED_SOFTMAX:\n encoded_softmax_block_ptr = tl.advance(encoded_softmax_block_ptr, (0, lo))\n else:\n lo, hi = 0, seqlen_k\n\n for start_n in range(lo, hi, BLOCK_N):\n if STAGE == 1 or STAGE == 3:\n start_n = tl.multiple_of(start_n, BLOCK_N)\n k = tl.load(K_block_ptr)\n if pre_load_v:\n v = tl.load(V_block_ptr)\n qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)\n if STAGE == 2:\n mask = offs_m[:, None] >= (start_n + offs_n[None, :])\n qk = tl.where(mask, qk, float(\"-inf\"))\n if BLOCK_M == 1:\n qk += tl.sum(tl.view(q, [BLOCK_DMODEL]) * tl.view(k, [BLOCK_DMODEL]))\n else:\n qk += tl.dot(q, k)\n m_ij = tl.maximum(m_i, tl.max(qk, 1))\n qk = qk - m_ij[:, None]\n p = tl.math.exp2(qk)\n l_ij = tl.sum(p, 1)\n if ENABLE_DROPOUT:\n philox_offset = batch_philox_offset + start_m * BLOCK_M * seqlen_k + start_n\n keep = dropout_mask(philox_seed, philox_offset, dropout_p, BLOCK_M, BLOCK_N, seqlen_k)\n if RETURN_ENCODED_SOFTMAX:\n tl.store(encoded_softmax_block_ptr, tl.where(keep, p, -p).to(encoded_softmax_block_ptr.type.element_ty))\n p = tl.where(keep, p, 0.0)\n elif RETURN_ENCODED_SOFTMAX:\n tl.store(encoded_softmax_block_ptr, p.to(encoded_softmax_block_ptr.type.element_ty))\n\n alpha = tl.math.exp2(m_i - m_ij)\n acc = acc * alpha[:, None]\n if not pre_load_v:\n v = tl.load(V_block_ptr)\n\n l_i = l_i * alpha + l_ij\n m_i = m_ij\n if BLOCK_M == 1:\n acc += tl.view(p.to(V_block_ptr.type.element_ty), [1]) * tl.view(v, [BLOCK_DMODEL])\n else:\n acc += tl.dot(p.to(V_block_ptr.type.element_ty), v)\n V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0))\n K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N))\n if RETURN_ENCODED_SOFTMAX:\n encoded_softmax_block_ptr = tl.advance(encoded_softmax_block_ptr, (0, BLOCK_N))\n return acc, l_i, m_i\n\n\n@triton.jit\ndef attn_fwd(\n Q, K, V, sm_scale, M, Out,\n stride_qz, stride_qh, stride_qm, stride_qk,\n stride_kz, stride_kh, stride_kn, stride_kk,\n stride_vz, stride_vh, stride_vk, stride_vn,\n stride_oz, stride_oh, stride_om, stride_on,\n seqlen_q,\n seqlen_k,\n dropout_p,\n philox_seed,\n philox_offset_base,\n encoded_softmax,\n STAGE: tl.constexpr,\n BLOCK_M: tl.constexpr,\n BLOCK_DMODEL: tl.constexpr,\n BLOCK_N: tl.constexpr,\n pre_load_v: tl.constexpr,\n ENABLE_DROPOUT: tl.constexpr,\n RETURN_ENCODED_SOFTMAX: tl.constexpr,\n):\n start_m = tl.program_id(0)\n off_h = tl.program_id(1)\n off_z = tl.program_id(2)\n num_h = tl.num_programs(1)\n num_z = tl.num_programs(2)\n q_offset = off_h * stride_qh + off_z * stride_qz\n Q_block_ptr = tl.make_block_ptr(\n base=Q + q_offset,\n shape=(seqlen_q, BLOCK_DMODEL),\n strides=(stride_qm, stride_qk),\n offsets=(start_m * BLOCK_M, 0),\n block_shape=(BLOCK_M, BLOCK_DMODEL),\n order=(1, 0)\n )\n k_offset = off_h * stride_kh + off_z * stride_kz\n K_block_ptr = tl.make_block_ptr(\n base=K + k_offset,\n shape=(BLOCK_DMODEL, seqlen_k),\n strides=(stride_kk, stride_kn),\n offsets=(0, 0),\n block_shape=(BLOCK_DMODEL, BLOCK_N),\n order=(0, 1)\n )\n v_offset = off_h * stride_vh + off_z * stride_vz\n V_block_ptr = tl.make_block_ptr(\n base=V + v_offset,\n shape=(seqlen_k, BLOCK_DMODEL),\n strides=(stride_vk, stride_vn),\n offsets=(0, 0),\n block_shape=(BLOCK_N, BLOCK_DMODEL),\n order=(1, 0)\n )\n offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)\n offs_n = tl.arange(0, BLOCK_N)\n m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float(\"inf\")\n l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + 1.0\n acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)\n qk_scale = sm_scale * 1.44269504\n q = tl.load(Q_block_ptr)\n q = (q * qk_scale).to(Q_block_ptr.type.element_ty)\n off_zh = off_z * num_h + off_h * 1\n if ENABLE_DROPOUT:\n batch_philox_offset = philox_offset_base + off_zh * seqlen_q * seqlen_k\n else:\n batch_philox_offset = 0\n if RETURN_ENCODED_SOFTMAX:\n encoded_softmax_block_ptr = tl.make_block_ptr(\n base=encoded_softmax + off_zh * seqlen_q * seqlen_k,\n shape=(seqlen_q, seqlen_k),\n strides=(seqlen_k, 1),\n offsets=(start_m * BLOCK_M, 0),\n block_shape=(BLOCK_M, BLOCK_N),\n order=(1, 0)\n )\n else:\n encoded_softmax_block_ptr = 0\n if STAGE & 1:\n acc, l_i, m_i = attn_fwd_inner(\n acc, l_i, m_i, q, K_block_ptr, V_block_ptr,\n start_m, seqlen_q, seqlen_k,\n dropout_p, philox_seed, batch_philox_offset, encoded_softmax_block_ptr,\n BLOCK_M, BLOCK_DMODEL, BLOCK_N,\n 4 - STAGE, offs_m, offs_n,\n pre_load_v,\n ENABLE_DROPOUT,\n RETURN_ENCODED_SOFTMAX)\n if STAGE & 2:\n tl.debug_barrier()\n acc, l_i, m_i = attn_fwd_inner(\n acc, l_i, m_i, q, K_block_ptr, V_block_ptr,\n start_m, seqlen_q, seqlen_k,\n dropout_p, philox_seed, batch_philox_offset, encoded_softmax_block_ptr,\n BLOCK_M, BLOCK_DMODEL, BLOCK_N,\n 2, offs_m, offs_n,\n pre_load_v,\n ENABLE_DROPOUT,\n RETURN_ENCODED_SOFTMAX,\n )\n acc = acc / l_i[:, None]\n if ENABLE_DROPOUT:\n acc = acc / (1 - dropout_p)\n m_ptrs = M + off_zh * seqlen_q + offs_m\n tl.store(m_ptrs, m_i + tl.math.log2(l_i))\n o_offset = off_h * stride_oh + off_z * stride_oz\n O_block_ptr = tl.make_block_ptr(\n base=Out + o_offset,\n shape=(seqlen_q, BLOCK_DMODEL),\n strides=(stride_om, stride_on),\n offsets=(start_m * BLOCK_M, 0),\n block_shape=(BLOCK_M, BLOCK_DMODEL),\n order=(1, 0)\n )\n tl.store(O_block_ptr, acc.to(Out.type.element_ty))\n", - "description_1": "Use triton language to define several kernels and functions for fused attention mechanism including dropout offset and random number generation for dropout, attention forward inner function with support for dropout and softmax calculation, and attention forward function setting up pointers and managing stages with dropout option. The functions have various parameters for dimensions, strides, constants, and flags controlling behavior.", - "description_2": "Use triton language to implement fused attention with dropout and stage management using triton kernels and functions.", - "difficulty": 3 - }, - { - "code": "import triton\nimport triton.language as tl\nfrom fwd_kernel_inner import attn_fwd_inner\n\n@triton.jit\ndef store_a(O_block_ptr, acc, q_padded):\n if not q_padded:\n tl.store(O_block_ptr, acc)\n else:\n tl.store(O_block_ptr, acc, boundary_check=(0,))\n\n@triton.jit\ndef store_b(O_block_ptr, acc, q_padded):\n if not q_padded:\n tl.store(O_block_ptr, acc, boundary_check=(1,))\n else:\n tl.store(O_block_ptr, acc, boundary_check=(1,0,))\n\n@triton.jit\ndef attn_fwd_common(\n Q_block_ptr,\n K_block_ptr,\n V_block_ptr,\n B_block_ptr,\n O_block_ptr,\n M_ptr_base,\n sm_scale,\n start_m,\n seqlen_q,\n seqlen_k,\n seqlen_k_faligned,\n q_padded,\n dropout_p,\n philox_seed,\n batch_philox_offset,\n max_seqlen_k,\n encoded_softmax_block_ptr,\n CAUSAL: tl.constexpr,\n BLOCK_M: tl.constexpr,\n BLOCK_DMODEL: tl.constexpr,\n BLOCK_N: tl.constexpr,\n pre_load_v: tl.constexpr,\n ENABLE_DROPOUT: tl.constexpr,\n RETURN_ENCODED_SOFTMAX: tl.constexpr,\n PADDED_HEAD: tl.constexpr,\n BIAS_TYPE: tl.constexpr,\n ):\n k_padded = seqlen_k != seqlen_k_faligned\n offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)\n offs_n = tl.arange(0, BLOCK_N)\n m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float(\"inf\")\n l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + 1.0\n acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)\n qk_scale = sm_scale * 1.44269504089\n q = tl.load(Q_block_ptr, boundary_check=(0,1), padding_option=\"zero\")\n q = (q * qk_scale).to(Q_block_ptr.type.element_ty)\n\n if CAUSAL:\n seqlen_k_low = 0\n seqlen_k_high = min(seqlen_k_faligned, start_m * BLOCK_M)\n else:\n seqlen_k_low = 0\n seqlen_k_high = seqlen_k_faligned\n\n acc, l_i, m_i = attn_fwd_inner(\n acc, l_i, m_i, q, K_block_ptr, V_block_ptr, B_block_ptr,\n start_m, seqlen_q, q_padded, seqlen_k_low, seqlen_k_high, False,\n dropout_p, max_seqlen_k, philox_seed, batch_philox_offset, encoded_softmax_block_ptr,\n BLOCK_M, BLOCK_DMODEL, BLOCK_N,\n False, offs_m, offs_n,\n pre_load_v,\n ENABLE_DROPOUT,\n RETURN_ENCODED_SOFTMAX,\n MARGINAL_BLOCK=False,\n PADDED_HEAD=PADDED_HEAD,\n BIAS_TYPE=BIAS_TYPE,\n )\n\n if CAUSAL or k_padded:\n seqlen_k_low = seqlen_k_high\n if CAUSAL:\n seqlen_k_high = min(seqlen_k, start_m * BLOCK_M + BLOCK_M)\n else:\n seqlen_k_high = seqlen_k\n\n tl.debug_barrier()\n acc, l_i, m_i = attn_fwd_inner(\n acc, l_i, m_i, q, K_block_ptr, V_block_ptr, B_block_ptr,\n start_m, seqlen_q, q_padded, seqlen_k_low, seqlen_k_high, k_padded,\n dropout_p, max_seqlen_k, philox_seed, batch_philox_offset, encoded_softmax_block_ptr,\n BLOCK_M, BLOCK_DMODEL, BLOCK_N,\n CAUSAL, offs_m, offs_n,\n pre_load_v,\n ENABLE_DROPOUT,\n RETURN_ENCODED_SOFTMAX,\n MARGINAL_BLOCK=True,\n PADDED_HEAD=PADDED_HEAD,\n BIAS_TYPE=BIAS_TYPE,\n )\n\n acc = acc / l_i[:, None]\n if ENABLE_DROPOUT:\n acc = acc / (1 - dropout_p)\n\n m_ptrs = M_ptr_base + offs_m\n if q_padded:\n overflow_size = (start_m * BLOCK_M + BLOCK_M) - seqlen_q\n boundary = tl.full((BLOCK_M,), BLOCK_M - overflow_size, dtype=tl.int32)\n m_ptrs_mask = boundary > tl.arange(0, BLOCK_M)\n tl.store(m_ptrs, m_i + tl.math.log2(l_i), mask=m_ptrs_mask)\n else:\n tl.store(m_ptrs, m_i + tl.math.log2(l_i))\n\n acc = acc.to(O_block_ptr.type.element_ty)\n tl.store(O_block_ptr, acc, boundary_check=(1,0,))\n", - "description_1": "Use triton language to implement two kernels: 'store_a' and 'store_b', each takes 3 arguments: O_block_ptr, acc, q_padded, where O_block_ptr is the output pointer, acc is the accumulated value to store, and q_padded is a boolean indicating whether q is padded. 'store_a' stores acc with optional boundary check, while 'store_b' stores acc with different boundary checks. A third kernel 'attn_fwd_common' is implemented to perform an attention forward pass. It accepts 26 arguments including pointers, scalars, and constants, operating on blocks of data with configurable dimensions and parameters for attention calculations.", - "description_2": "Use triton language to define kernels for storing values with boundary checks and to compute an attention forward operation on data blocks with specified dimensions and parameters.", - "difficulty": 4 - }, - { - "code": "import triton\nimport triton.language as tl\nfrom dropout import dropout_mask\nfrom masked_load_store import load_fn, mstore2d\n\n@triton.jit\ndef attn_fwd_inner(\n acc, l_i, m_i,\n q, k_ptrs, v_ptrs, bias_ptrs,\n stride_kn, stride_vk, stride_bn,\n seqlen_q, seqlen_k, head_dim,\n start_m, block_min, block_max,\n dropout_p, philox_seed, batch_philox_offset, max_seqlen_k,\n encoded_sm_base,\n offs_n_causal, masked_blocks, n_extra_tokens,\n alibi_slope,\n CAUSAL: tl.constexpr,\n BLOCK_M: tl.constexpr,\n BLOCK_DMODEL: tl.constexpr,\n BLOCK_N: tl.constexpr,\n PRE_LOAD_V: tl.constexpr,\n MASK_STEPS: tl.constexpr,\n ENABLE_DROPOUT: tl.constexpr,\n RETURN_ENCODED_SOFTMAX: tl.constexpr,\n PADDED_HEAD: tl.constexpr,\n):\n offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)\n offs_n = tl.arange(0, BLOCK_N)\n for start_n in range(block_min, block_max, BLOCK_N):\n if MASK_STEPS:\n k_offs_n = start_n + tl.arange(0, BLOCK_N)\n else:\n k_offs_n = None\n k_offs_d = None if not PADDED_HEAD else tl.arange(0, BLOCK_DMODEL)\n k = load_fn(k_ptrs, k_offs_d, k_offs_n, head_dim, seqlen_k)\n if PRE_LOAD_V:\n v = load_fn(v_ptrs, k_offs_n, k_offs_d, seqlen_k, head_dim)\n qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)\n if MASK_STEPS:\n if (start_n + BLOCK_N == block_max) and (n_extra_tokens != 0):\n boundary_m = tl.full([BLOCK_M], seqlen_k, dtype=tl.int32)\n size_n = start_n + offs_n[None, :]\n mask = size_n < boundary_m[:, None]\n qk = tl.where(mask, qk, float(\"-inf\"))\n if CAUSAL:\n causal_boundary = start_n + offs_n_causal\n causal_mask = offs_m[:, None] >= causal_boundary[None, :]\n qk = tl.where(causal_mask, qk, float(\"-inf\"))\n qk += tl.dot(q, k)\n if bias_ptrs is not None:\n bias_offs_n = start_n + tl.arange(0, BLOCK_N) if MASK_STEPS else None\n bias = load_fn(bias_ptrs, offs_m, bias_offs_n, seqlen_q, seqlen_k)\n qk += (bias * 1.44269504089)\n\n if alibi_slope is not None:\n global_m_positions = start_m * BLOCK_M + tl.arange(0, BLOCK_M)\n global_n_positions = start_n + tl.arange(0, BLOCK_N)\n alibi_block = compute_alibi_block(alibi_slope, seqlen_q, seqlen_k, global_m_positions,\n global_n_positions)\n qk += (alibi_block * 1.44269504089)\n\n m_ij = tl.maximum(m_i, tl.max(qk, 1))\n qk = qk - m_ij[:, None]\n p = tl.math.exp2(qk)\n\n l_ij = tl.sum(p, 1)\n if ENABLE_DROPOUT:\n philox_offset = batch_philox_offset + start_m * BLOCK_M * max_seqlen_k + start_n\n keep = dropout_mask(philox_seed, philox_offset, dropout_p, BLOCK_M, BLOCK_N, max_seqlen_k)\n if RETURN_ENCODED_SOFTMAX:\n mstore2d(tl.where(keep, p, -p).to(q.type.element_ty),\n BLOCK_M,\n BLOCK_N,\n o_base=encoded_sm_base,\n o_start_row=start_m * BLOCK_M,\n o_start_col=start_n,\n o_rows=seqlen_q,\n o_cols=seqlen_k,\n stride_row=max_seqlen_k,\n stride_col=1)\n p = tl.where(keep, p, 0.0)\n elif RETURN_ENCODED_SOFTMAX:\n mstore2d(p.to(q.type.element_ty),\n BLOCK_M,\n BLOCK_N,\n o_base=encoded_sm_base,\n o_start_row=start_m * BLOCK_M,\n o_start_col=start_n,\n o_rows=seqlen_q,\n o_cols=seqlen_k,\n stride_row=max_seqlen_k,\n stride_col=1)\n alpha = tl.math.exp2(m_i - m_ij)\n acc = acc * alpha[:, None]\n if not PRE_LOAD_V:\n v = load_fn(v_ptrs, k_offs_n, k_offs_d, seqlen_k, head_dim)\n l_i = l_i * alpha + l_ij\n m_i = m_ij\n acc += tl.dot(p.to(v.type.element_ty), v)\n k_ptrs += BLOCK_N * stride_kn\n v_ptrs += BLOCK_N * stride_vk\n if bias_ptrs is not None:\n bias_ptrs += BLOCK_N * stride_bn\n return acc, l_i, m_i\n", - "description_1": "Use triton language to implement a forward attention kernel with parameters for accumulation, sequence lengths, head dimensions, dropout, and optional bias and alibi slope. The kernel processes blocks of data with configurable block sizes and supports dropout and causal masking.", - "description_2": "Use triton language to create a forward attention kernel that handles dropout, bias, and causal masking with configurable block sizes and sequence lengths.", - "difficulty": 4 - }, - { - "code": "import triton\nimport triton.language as tl\n\n# Triton kernel for conditional loading with boundary checks.\n@triton.jit\ndef load_fn(ptrs, offset_first, offset_second, _in_boundary_first, _in_boundary_second):\n boundary_first = _in_boundary_first\n boundary_second = _in_boundary_second\n if offset_first is not None and offset_second is not None:\n mask = (offset_first[:, None] < boundary_first) & \\\n (offset_second[None, :] < boundary_second)\n tensor = tl.load(ptrs, mask=mask, other=0.0)\n elif offset_first is not None:\n mask = offset_first[:, None] < boundary_first\n tensor = tl.load(ptrs, mask=mask, other=0.0)\n elif offset_second is not None:\n mask = offset_second[None, :] < boundary_second\n tensor = tl.load(ptrs, mask=mask, other=0.0)\n else:\n tensor = tl.load(ptrs)\n return tensor\n\n# Triton kernel for 1D memory loading.\n@triton.jit\ndef mload1d(\n REGS: tl.constexpr, # Number of registers to load\n i_base, # Base pointer\n i_start, # Start index\n i_nums, # Number of elements\n):\n offs = tl.arange(0, REGS) + i_start\n i_ptrs = i_base + offs\n overflow = i_start + REGS - i_nums\n i_ptrs_mask = tl.full([REGS], 1, dtype=tl.int1)\n i_ptrs_mask = i_ptrs_mask & (offs < i_nums)\n return tl.load(i_ptrs, mask=i_ptrs_mask, other=0.0)\n\n# Triton kernel for 2D memory loading with boundary checks.\n@triton.jit\ndef mload2d(\n REG_ROWS: tl.constexpr, # Number of register rows to load\n REG_COLS: tl.constexpr, # Number of register cols to load\n i_base, # Base pointer\n i_start_row, # Start row index\n i_start_col, # Start col index\n i_rows, # Number of rows\n i_cols, # Number of cols\n stride_row, # Row stride\n stride_col, # Column stride\n):\n off_rows = tl.arange(0, REG_ROWS) + i_start_row\n off_cols = tl.arange(0, REG_COLS) + i_start_col\n i_ptrs = i_base + off_rows[:, None] * stride_row + off_cols[None, :] * stride_col\n row_overflow = i_start_row + REG_ROWS - i_rows\n col_overflow = i_start_col + REG_COLS - i_cols\n i_ptrs_mask = tl.full([REG_ROWS, REG_COLS], 1, dtype=tl.int1)\n if row_overflow > 0:\n i_ptrs_mask = i_ptrs_mask & (off_rows[:, None] < i_rows)\n if col_overflow > 0:\n i_ptrs_mask = i_ptrs_mask & (off_cols[None, :] < i_cols)\n return tl.load(i_ptrs, mask=i_ptrs_mask, other=0.0)\n\n# Triton kernel for 2D memory storing with boundary checks.\n@triton.jit\ndef mstore2d(\n registers, # Data to store\n REG_ROWS: tl.constexpr, # Number of register rows\n REG_COLS: tl.constexpr, # Number of register cols\n o_base, # Base pointer\n o_start_row, # Start row index\n o_start_col, # Start col index\n o_rows, # Number of rows\n o_cols, # Number of cols\n stride_row, # Row stride\n stride_col, # Column stride\n):\n off_rows = tl.arange(0, REG_ROWS) + o_start_row\n off_cols = tl.arange(0, REG_COLS) + o_start_col\n o_ptrs = o_base + off_rows[:, None] * stride_row + off_cols[None, :] * stride_col\n o_ptrs_mask = tl.full([REG_ROWS, REG_COLS], 1, dtype=tl.int1)\n row_overflow = o_start_row + REG_ROWS - o_rows\n if row_overflow > 0:\n o_ptrs_mask = o_ptrs_mask & (off_rows[:, None] < o_rows)\n col_overflow = o_start_col + REG_COLS - o_cols\n if col_overflow > 0:\n o_ptrs_mask = o_ptrs_mask & (off_cols[None, :] < o_cols)\n tl.store(o_ptrs, registers, mask=o_ptrs_mask)\n return o_ptrs, o_ptrs_mask\n", - "description_1": "Use triton language to implement kernels for: 1) Conditional memory loading with boundary checks (4 parameters): loads data from a pointer with conditions on first and second offsets compared against respective boundaries. 2) 1D memory loading (3 parameters): loads a 1D array from a base pointer starting at a given index up to the number of elements specified. 3) 2D memory loading with boundary checks (9 parameters): loads data from a base pointer with conditions on the starting row and column indices and their respective strides. 4) 2D memory storing with boundary checks (9 parameters): stores data to a memory location from a base pointer with conditions on the starting row and column indices and their respective strides.", - "description_2": "Use triton language to implement memory operations with boundary checks: 1) Load 1D and 2D tensors from memory with specified start indices and conditions. 2) Store 2D tensors to memory with specified start indices and conditions.", - "difficulty": 2 - }, - { - "code": "import triton\nimport triton.language as tl\n\n@triton.jit\ndef attn_fwd(\n Q, K, V, B, sm_scale, M, Out,\n stride_qz, stride_qh, stride_qm, stride_qk,\n stride_kz, stride_kh, stride_kn, stride_kk,\n stride_vz, stride_vh, stride_vk, stride_vn,\n stride_bz, stride_bh, stride_bm, stride_bn,\n stride_oz, stride_oh, stride_om, stride_on,\n cu_seqlens_q,\n cu_seqlens_k,\n num_seqlens, # set num_seqlens to zero to ignore cu_seqlens_q/k\n max_seqlen_q, # and use max_seqlen_q/k for all seqlen_q/k \n max_seqlen_k,\n head_dim,\n dropout_p,\n philox_seed,\n philox_offset_base,\n encoded_softmax,\n CAUSAL: tl.constexpr,\n BLOCK_M: tl.constexpr,\n BLOCK_DMODEL: tl.constexpr,\n BLOCK_N: tl.constexpr,\n pre_load_v: tl.constexpr,\n ENABLE_DROPOUT: tl.constexpr,\n RETURN_ENCODED_SOFTMAX: tl.constexpr,\n PADDED_HEAD: tl.constexpr,\n BIAS_TYPE: tl.constexpr,\n):\n start_m = tl.program_id(0)\n off_h = tl.program_id(1) # head index\n off_z = tl.program_id(2) # batch index\n num_h = tl.num_programs(1)\n num_z = tl.num_programs(2)\n off_zh = off_z * num_h + off_h * 1\n # FIXME: Better pattern for this branch? It's copied into three kernels\n if num_seqlens > 0:\n cu_seqlens_q_start = tl.load(cu_seqlens_q + off_z)\n cu_seqlens_q_end = tl.load(cu_seqlens_q + off_z + 1)\n seqlen_q = cu_seqlens_q_end - cu_seqlens_q_start\n if start_m * BLOCK_M >= seqlen_q:\n return\n cu_seqlens_k_start = tl.load(cu_seqlens_k + off_z)\n cu_seqlens_k_end = tl.load(cu_seqlens_k + off_z + 1)\n seqlen_k = cu_seqlens_k_end - cu_seqlens_k_start\n batch_index = 0\n elif num_seqlens == 0:\n cu_seqlens_q_start = 0\n cu_seqlens_k_start = 0\n seqlen_q = max_seqlen_q\n seqlen_k = max_seqlen_k\n batch_index = off_z\n else: # < 0 for padded seqlen\n cu_seqlens_q_start = tl.load(cu_seqlens_q + off_z)\n cu_seqlens_q_end = tl.load(cu_seqlens_q + off_z + 1)\n seqlen_q = cu_seqlens_q_end - cu_seqlens_q_start\n if start_m * BLOCK_M >= seqlen_q:\n return\n cu_seqlens_k_start = tl.load(cu_seqlens_k + off_z)\n cu_seqlens_k_end = tl.load(cu_seqlens_k + off_z + 1)\n seqlen_k = cu_seqlens_k_end - cu_seqlens_k_start\n # Varlen, but padded to Rank 4 tensor\n cu_seqlens_q_start = 0\n cu_seqlens_k_start = 0\n batch_index = off_z\n\n if start_m * BLOCK_M + BLOCK_M > seqlen_q:\n q_padded = True\n else:\n q_padded = False\n if seqlen_k < BLOCK_N:\n seqlen_k_faligned = 0 # floor aligned\n elif seqlen_k % BLOCK_N:\n extra_tokens_n = seqlen_k % BLOCK_N\n seqlen_k_faligned = seqlen_k - extra_tokens_n\n else:\n seqlen_k_faligned = seqlen_k\n\n q_offset = off_h * stride_qh + batch_index * stride_qz + cu_seqlens_q_start * stride_qm\n Q_block_ptr = tl.make_block_ptr(\n base=Q + q_offset,\n shape=(seqlen_q, head_dim),\n strides=(stride_qm, stride_qk),\n offsets=(start_m * BLOCK_M, 0),\n block_shape=(BLOCK_M, BLOCK_DMODEL),\n order=(1, 0)\n )\n k_offset = off_h * stride_kh + batch_index * stride_kz + cu_seqlens_k_start * stride_kn\n K_block_ptr = tl.make_block_ptr(\n base=K + k_offset,\n shape=(head_dim, seqlen_k),\n strides=(stride_kk, stride_kn),\n offsets=(0, 0),\n block_shape=(BLOCK_DMODEL, BLOCK_N),\n order=(0, 1)\n )\n v_offset = off_h * stride_vh + batch_index * stride_vz + cu_seqlens_k_start * stride_vk\n V_block_ptr = tl.make_block_ptr(\n base=V + v_offset,\n shape=(seqlen_k, head_dim),\n strides=(stride_vk, stride_vn),\n offsets=(0, 0),\n block_shape=(BLOCK_N, BLOCK_DMODEL),\n order=(1, 0)\n )\n if BIAS_TYPE == 0:\n B_block_ptr = 0\n elif BIAS_TYPE == 1:\n B_block_ptr = tl.make_block_ptr(\n base=B + off_h * stride_bh + batch_index * stride_bz,\n shape=(seqlen_q, seqlen_k),\n strides=(stride_bm, stride_bn),\n offsets=(start_m * BLOCK_M, 0),\n block_shape=(BLOCK_M, BLOCK_N),\n order=(1, 0)\n )\n else:\n tl.static_assert(False, f'Unsupported BIAS_TYPE {BIAS_TYPE}')\n if RETURN_ENCODED_SOFTMAX:\n encoded_softmax_block_ptr = tl.make_block_ptr(\n base=encoded_softmax + off_zh * max_seqlen_q * max_seqlen_k,\n shape=(seqlen_q, seqlen_k),\n strides=(max_seqlen_k, 1),\n offsets=(start_m * BLOCK_M, 0),\n block_shape=(BLOCK_M, BLOCK_N),\n order=(1, 0)\n )\n else:\n encoded_softmax_block_ptr = 0\n # write back O\n o_offset = off_h * stride_oh + batch_index * stride_oz + cu_seqlens_q_start * stride_om\n O_block_ptr = tl.make_block_ptr(\n base=Out + o_offset,\n shape=(seqlen_q, head_dim),\n strides=(stride_om, stride_on),\n offsets=(start_m * BLOCK_M, 0),\n block_shape=(BLOCK_M, BLOCK_DMODEL),\n order=(1, 0)\n )\n\n M_ptr_base = M + off_zh * max_seqlen_q\n if ENABLE_DROPOUT:\n batch_philox_offset = philox_offset_base + off_zh * max_seqlen_q * max_seqlen_k\n else:\n batch_philox_offset = 0\n\n attn_fwd_common(Q_block_ptr,\n K_block_ptr,\n V_block_ptr,\n B_block_ptr,\n O_block_ptr,\n M_ptr_base,\n sm_scale,\n start_m,\n seqlen_q,\n seqlen_k,\n seqlen_k_faligned,\n q_padded,\n dropout_p,\n philox_seed,\n batch_philox_offset,\n max_seqlen_k,\n encoded_softmax_block_ptr,\n CAUSAL=CAUSAL,\n BLOCK_M=BLOCK_M,\n BLOCK_DMODEL=BLOCK_DMODEL,\n BLOCK_N=BLOCK_N,\n pre_load_v=pre_load_v,\n ENABLE_DROPOUT=ENABLE_DROPOUT,\n RETURN_ENCODED_SOFTMAX=RETURN_ENCODED_SOFTMAX,\n PADDED_HEAD=PADDED_HEAD,\n BIAS_TYPE=BIAS_TYPE)\n", - "description_1": "Use triton language to implement the attn_fwd kernel function for fused attention, handling multiple parameters including Q, K, V, B, scales, offsets, seqlens, and additional constants. The kernel computes attention outputs with potential dropout and encoded softmax support, and handles variable sequence lengths.", - "description_2": "Use triton language to implement a kernel function for attention computation with support for dropout, variable sequence lengths, and optional encoded softmax.", - "difficulty": 4 - }, - { - "code": "import triton\nimport triton.language as tl\nfrom dropout import dropout_mask\n\n@triton.jit\ndef attn_fwd_inner(\n acc, l_i, m_i, q,\n K_block_ptr, V_block_ptr, B_block_ptr,\n start_m,\n seqlen_q,\n q_padded,\n seqlen_k_low,\n seqlen_k_high,\n k_padded,\n dropout_p,\n dropout_seqlen_k,\n philox_seed,\n batch_philox_offset,\n encoded_softmax_block_ptr,\n BLOCK_M: tl.constexpr,\n BLOCK_DMODEL: tl.constexpr,\n BLOCK_N: tl.constexpr,\n CAUSAL: tl.constexpr,\n offs_m: tl.constexpr,\n offs_n: tl.constexpr,\n pre_load_v: tl.constexpr,\n ENABLE_DROPOUT: tl.constexpr,\n RETURN_ENCODED_SOFTMAX: tl.constexpr,\n MARGINAL_BLOCK: tl.constexpr,\n PADDED_HEAD: tl.constexpr,\n BIAS_TYPE: tl.constexpr,\n):\n lo, hi = seqlen_k_low, seqlen_k_high\n if MARGINAL_BLOCK:\n K_block_ptr = tl.advance(K_block_ptr, (0, lo))\n V_block_ptr = tl.advance(V_block_ptr, (lo, 0))\n if RETURN_ENCODED_SOFTMAX:\n encoded_softmax_block_ptr = tl.advance(encoded_softmax_block_ptr, (0, lo))\n if BIAS_TYPE == 1:\n B_block_ptr = tl.advance(B_block_ptr, (0, lo))\n for start_n in range(lo, hi, BLOCK_N):\n k = tl.load(K_block_ptr, boundary_check=(1,0), padding_option=\"zero\")\n if pre_load_v:\n v = tl.load(V_block_ptr, boundary_check=(0,1), padding_option=\"zero\")\n qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)\n if MARGINAL_BLOCK:\n if CAUSAL:\n mask = offs_m[:, None] >= (start_n + offs_n[None, :])\n qk = tl.where(mask, qk, float(\"-inf\"))\n if k_padded:\n boundary_m = tl.full([BLOCK_M], seqlen_k_high, dtype=tl.int32)\n size_n = start_n + offs_n[None,:]\n mask = size_n < boundary_m[:,None]\n qk = tl.where(mask, qk, float(\"-inf\"))\n if BIAS_TYPE == 1:\n bias = tl.load(B_block_ptr, boundary_check=(0,1), padding_option=\"zero\")\n qk += bias * 1.44269504089\n qk += tl.dot(q, k)\n m_ij = tl.maximum(m_i, tl.max(qk, 1))\n qk = qk - m_ij[:, None]\n p = tl.math.exp2(qk)\n l_ij = tl.sum(p, 1)\n if ENABLE_DROPOUT:\n philox_offset = batch_philox_offset + start_m * BLOCK_M * dropout_seqlen_k + start_n\n keep = dropout_mask(philox_seed, philox_offset, dropout_p, BLOCK_M, BLOCK_N, dropout_seqlen_k)\n if RETURN_ENCODED_SOFTMAX:\n tl.store(encoded_softmax_block_ptr, tl.where(keep, p, -p).to(encoded_softmax_block_ptr.type.element_ty), boundary_check=(0,1))\n p = tl.where(keep, p, 0.0)\n elif RETURN_ENCODED_SOFTMAX:\n tl.store(encoded_softmax_block_ptr,\n p.to(encoded_softmax_block_ptr.type.element_ty),\n boundary_check=(0,1))\n alpha = tl.math.exp2(m_i - m_ij)\n acc = acc * alpha[:, None]\n if not pre_load_v:\n v = tl.load(V_block_ptr, boundary_check=(0,1), padding_option=\"zero\")\n l_i = l_i * alpha + l_ij\n m_i = m_ij\n acc += tl.dot(p.to(V_block_ptr.type.element_ty), v)\n V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0))\n K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N))\n if RETURN_ENCODED_SOFTMAX:\n encoded_softmax_block_ptr = tl.advance(encoded_softmax_block_ptr, (0, BLOCK_N))\n if BIAS_TYPE == 1:\n B_block_ptr = tl.advance(B_block_ptr, (0, BLOCK_N))\n return acc, l_i, m_i\n", - "description_1": "Use triton language to implement a forward attention kernel with dropout and optional bias. The kernel takes 24 parameters: acc, l_i, m_i, q, K_block_ptr, V_block_ptr, B_block_ptr, start_m, seqlen_q, q_padded, seqlen_k_low, seqlen_k_high, k_padded, dropout_p, dropout_seqlen_k, philox_seed, batch_philox_offset, encoded_softmax_block_ptr, and several compile-time constants. It computes the attention scores, applies dropout if enabled, and updates the accumulator, l_i, and m_i.", - "description_2": "Use triton language to create a kernel for forward attention computation with support for dropout and bias, processing blocks of queries, keys, and values.", - "difficulty": 4 - }, - { - "code": "import triton\nimport triton.language as tl\nfrom flash import bwd_kernel_dk_dv as bare_bwd_kernel_dk_dv\nfrom flash import bwd_kernel_dq as bare_bwd_kernel_dq\n\nTRITON_CONFIG_LIST_BWD_SIZED = [\n triton.Config({'waves_per_eu': 0}, num_stages=1, num_warps=4),\n triton.Config({'waves_per_eu': 1}, num_stages=1, num_warps=4),\n triton.Config({'waves_per_eu': 2}, num_stages=1, num_warps=4),\n triton.Config({'waves_per_eu': 3}, num_stages=1, num_warps=4),\n triton.Config({'waves_per_eu': 4}, num_stages=1, num_warps=4),\n]\n\n@triton.autotune(\n configs=TRITON_CONFIG_LIST_BWD_SIZED,\n key=['BLOCK_DMODEL', 'max_seqlen_q', 'max_seqlen_k'],\n)\n@triton.jit\ndef sized_tuned_bwd_kernel_dk_dv(\n Q, K, V, B, sm_scale, Out, DO,\n DK, DV,\n L,\n D,\n stride_qz, stride_qh, stride_qm, stride_qk,\n stride_kz, stride_kh, stride_kn, stride_kk,\n stride_vz, stride_vh, stride_vk, stride_vn,\n stride_bz, stride_bh, stride_bm, stride_bn,\n stride_oz, stride_oh, stride_om, stride_ok,\n stride_dkz, stride_dkh, stride_dkn, stride_dkk,\n stride_dvz, stride_dvh, stride_dvk, stride_dvn,\n cu_seqlens_q,\n cu_seqlens_k,\n num_seqlens,\n max_seqlen_q,\n max_seqlen_k,\n head_dim,\n dropout_p,\n philox_seed,\n philox_offset_base,\n BLOCK_M: tl.constexpr,\n BLOCK_DMODEL: tl.constexpr,\n BLOCK_N: tl.constexpr,\n CAUSAL: tl.constexpr,\n ENABLE_DROPOUT: tl.constexpr,\n PADDED_HEAD: tl.constexpr,\n BIAS_TYPE: tl.constexpr,\n):\n bare_bwd_kernel_dk_dv(\n Q, K, V, B, sm_scale, Out, DO,\n DK, DV,\n L,\n D,\n stride_qz, stride_qh, stride_qm, stride_qk,\n stride_kz, stride_kh, stride_kn, stride_kk,\n stride_vz, stride_vh, stride_vk, stride_vn,\n stride_bz, stride_bh, stride_bm, stride_bn,\n stride_oz, stride_oh, stride_om, stride_ok,\n stride_dkz, stride_dkh, stride_dkn, stride_dkk,\n stride_dvz, stride_dvh, stride_dvk, stride_dvn,\n cu_seqlens_q,\n cu_seqlens_k,\n num_seqlens,\n max_seqlen_q,\n max_seqlen_k,\n head_dim,\n dropout_p,\n philox_seed,\n philox_offset_base,\n BLOCK_M,\n BLOCK_DMODEL,\n BLOCK_N,\n CAUSAL,\n ENABLE_DROPOUT,\n PADDED_HEAD=PADDED_HEAD,\n BIAS_TYPE=BIAS_TYPE,\n )\n\n@triton.autotune(\n configs=TRITON_CONFIG_LIST_BWD_SIZED,\n key=['BLOCK_DMODEL', 'max_seqlen_q', 'max_seqlen_k'],\n)\n@triton.jit\ndef sized_tuned_bwd_kernel_dq(\n Q, K, V, B, sm_scale, Out, DO,\n DQ, DB,\n L,\n D,\n stride_qz, stride_qh, stride_qm, stride_qk,\n stride_kz, stride_kh, stride_kn, stride_kk,\n stride_vz, stride_vh, stride_vk, stride_vn,\n stride_bz, stride_bh, stride_bm, stride_bn,\n stride_oz, stride_oh, stride_om, stride_ok,\n stride_dqz, stride_dqh, stride_dqm, stride_dqk,\n stride_dbz, stride_dbh, stride_dbm, stride_dbn,\n cu_seqlens_q,\n cu_seqlens_k,\n num_seqlens,\n max_seqlen_q,\n max_seqlen_k,\n head_dim,\n dropout_p,\n philox_seed,\n philox_offset_base,\n BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,\n BLOCK_N: tl.constexpr,\n CAUSAL: tl.constexpr,\n ENABLE_DROPOUT: tl.constexpr,\n PADDED_HEAD: tl.constexpr,\n BIAS_TYPE: tl.constexpr,\n):\n bare_bwd_kernel_dq(Q, K, V, B, sm_scale, Out, DO,\n DQ, DB,\n L,\n D,\n stride_qz, stride_qh, stride_qm, stride_qk,\n stride_kz, stride_kh, stride_kn, stride_kk,\n stride_vz, stride_vh, stride_vk, stride_vn,\n stride_bz, stride_bh, stride_bm, stride_bn,\n stride_oz, stride_oh, stride_om, stride_ok,\n stride_dqz, stride_dqh, stride_dqm, stride_dqk,\n stride_dbz, stride_dbh, stride_dbm, stride_dbn,\n cu_seqlens_q,\n cu_seqlens_k,\n num_seqlens,\n max_seqlen_q,\n max_seqlen_k,\n head_dim,\n dropout_p,\n philox_seed,\n philox_offset_base,\n BLOCK_M, BLOCK_DMODEL,\n BLOCK_N,\n CAUSAL,\n ENABLE_DROPOUT,\n PADDED_HEAD=PADDED_HEAD,\n BIAS_TYPE=BIAS_TYPE,\n )\n", - "description_1": "Use triton language to define two backward kernels for a neural network. The first kernel, sized_tuned_bwd_kernel_dk_dv, computes gradients with respect to keys and values. It takes 54 parameters including input tensors Q, K, V, B, output tensors DK, DV, and various strides and constants. The second kernel, sized_tuned_bwd_kernel_dq, computes gradients with respect to queries. It also takes 54 parameters including input tensors Q, K, V, B, output tensors DQ, DB, and various strides and constants. Both kernels are optimized using triton's autotune feature with a list of configurations.", - "description_2": "Use triton language to define and autotune two backward kernels for computing gradients in a neural network, one for keys and values, and another for queries, each with 54 parameters.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\nfrom fwd_kernel import attn_fwd as bare_attn_fwd\n\n@triton.autotune(\n configs=[\n triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 0, 'pre_load_v': True}, num_stages=1, num_warps=4),\n triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 1, 'pre_load_v': True}, num_stages=1, num_warps=4),\n triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 2, 'pre_load_v': True}, num_stages=1, num_warps=4),\n triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 3, 'pre_load_v': True}, num_stages=1, num_warps=4),\n triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 4, 'pre_load_v': True}, num_stages=1, num_warps=4),\n triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 0, 'pre_load_v': False}, num_stages=1, num_warps=4),\n triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 1, 'pre_load_v': False}, num_stages=1, num_warps=4),\n triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 2, 'pre_load_v': False}, num_stages=1, num_warps=4),\n triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 3, 'pre_load_v': False}, num_stages=1, num_warps=4),\n triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 4, 'pre_load_v': False}, num_stages=1, num_warps=4),\n ],\n key=['seqlen_q', 'seqlen_k', 'STAGE'],\n)\n@triton.jit\ndef tuned_attn_fwd(\n Q, K, V, sm_scale, M, Out,\n stride_qz, stride_qh, stride_qm, stride_qk,\n stride_kz, stride_kh, stride_kn, stride_kk,\n stride_vz, stride_vh, stride_vk, stride_vn,\n stride_oz, stride_oh, stride_om, stride_on,\n seqlen_q,\n seqlen_k,\n dropout_p,\n philox_seed,\n philox_offset_base,\n encoded_softmax,\n STAGE: tl.constexpr,\n BLOCK_M: tl.constexpr,\n BLOCK_DMODEL: tl.constexpr,\n BLOCK_N: tl.constexpr,\n pre_load_v: tl.constexpr,\n ENABLE_DROPOUT: tl.constexpr,\n RETURN_ENCODED_SOFTMAX: tl.constexpr,\n):\n bare_attn_fwd(\n Q, K, V, sm_scale, M, Out,\n stride_qz, stride_qh, stride_qm, stride_qk,\n stride_kz, stride_kh, stride_kn, stride_kk,\n stride_vz, stride_vh, stride_vk, stride_vn,\n stride_oz, stride_oh, stride_om, stride_on,\n seqlen_q,\n seqlen_k,\n dropout_p,\n philox_seed,\n philox_offset_base,\n encoded_softmax,\n STAGE,\n BLOCK_M,\n BLOCK_DMODEL,\n BLOCK_N,\n pre_load_v,\n ENABLE_DROPOUT,\n RETURN_ENCODED_SOFTMAX,\n )\n\nclass _attention(torch.autograd.Function):\n\n @staticmethod\n def forward(ctx, q, k, v, causal, sm_scale, dropout_p, return_encoded_softmax,\n autotune=False, return_autotune=False):\n Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]\n assert Lq == Lk and Lk == Lv\n assert Lk in {16, 32, 64, 128}\n seqlen_q = q.shape[2]\n seqlen_k = k.shape[2]\n o = torch.empty_like(q)\n if torch.version.hip is None:\n BLOCK_M = 128\n BLOCK_N = 64 if Lk <= 64 else 32\n num_stages = 4 if Lk <= 64 else 3\n num_warps = 4 if Lk <= 64 else 8\n\n stage = 3 if causal else 1\n grid = lambda META: (\n triton.cdiv(q.shape[2], META['BLOCK_M']),\n q.shape[1],\n q.shape[0],\n )\n M = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)\n if return_encoded_softmax:\n encoded_softmax = torch.zeros((q.shape[0], q.shape[1], q.shape[2], k.shape[2]), device=q.device, dtype=_attention.DEBUG_MASK_DTYPE)\n else:\n encoded_softmax = None\n\n philox_seed = 114514\n philox_offset = 1919810\n MAX_BLOCK_M = 128 if dropout_p == 0 else 64\n MAX_BLOCK_N = 32 if dropout_p == 0 else 32\n MAX_BLOCK_M = MAX_BLOCK_M if is_supported_by_tl_dot(seqlen_q) else 1\n MAX_BLOCK_N = MAX_BLOCK_N if is_supported_by_tl_dot(seqlen_k) else 1\n BLOCK_M=min(MAX_BLOCK_M, q.shape[2], k.shape[2])\n BLOCK_N=min(MAX_BLOCK_N, q.shape[2], k.shape[2])\n\n bare_attn_fwd[grid](\n q, k, v, sm_scale, M, o,\n q.stride(0), q.stride(1), q.stride(2), q.stride(3),\n k.stride(0), k.stride(1), k.stride(2), k.stride(3),\n v.stride(0), v.stride(1), v.stride(2), v.stride(3),\n o.stride(0), o.stride(1), o.stride(2), o.stride(3),\n seqlen_q=q.shape[2],\n seqlen_k=k.shape[2],\n dropout_p=dropout_p,\n philox_seed=philox_seed,\n philox_offset_base=philox_offset,\n encoded_softmax=encoded_softmax,\n STAGE=stage,\n BLOCK_M=BLOCK_M,\n BLOCK_DMODEL=Lk,\n BLOCK_N=BLOCK_N,\n pre_load_v=False,\n num_stages=1,\n num_warps=4,\n ENABLE_DROPOUT=dropout_p > 0.0,\n RETURN_ENCODED_SOFTMAX=encoded_softmax is not None,\n )\n\n tuning_result = None\n block_m = min(128, q.shape[2], k.shape[2])\n grid = (triton.cdiv(q.shape[2], block_m), q.shape[1], q.shape[0])\n ctx.save_for_backward(q, k, v, o, M)\n ctx.grid = grid\n ctx.sm_scale = sm_scale\n ctx.BLOCK_DMODEL = Lk\n ctx.causal = causal\n ctx.dropout_p = dropout_p\n ctx.philox_seed = philox_seed\n ctx.philox_offset = philox_offset\n ctx.encoded_softmax = encoded_softmax\n return o, encoded_softmax, tuning_result\n\nattention = _attention.apply\n", - "description_1": "Use triton language to implement a forward attention kernel 'tuned_attn_fwd' with 24 parameters including input tensors Q, K, V, scaling factor sm_scale, output tensor Out, and various strides and constants. The kernel is autotuned with different configurations. The '_attention' class is a PyTorch autograd function that uses this kernel in its forward method, taking 8 parameters including input tensors q, k, v, and additional parameters for scaling, dropout, and tuning.", - "description_2": "Use triton language to create an autotuned forward attention kernel and integrate it into a PyTorch autograd function for efficient computation of attention mechanisms.", - "difficulty": 3 - }, - { - "code": "import triton\nimport triton.language as tl\nfrom flash import bwd_kernel_dk_dv as bare_bwd_kernel_dk_dv, bwd_kernel_dq as bare_bwd_kernel_dq\n\nTRITON_CONFIG_LIST_BWD = []\nfor BLOCK_M, BLOCK_N in [(32, 64), (64, 16)]:\n dic = {}\n dic['BLOCK_M'] = BLOCK_M\n dic['BLOCK_N'] = BLOCK_N\n for waves_per_eu in [0, 3]:\n dic['waves_per_eu'] = waves_per_eu\n for num_stages in [1]:\n for num_warps in [1, 2]:\n cfg = triton.Config(dict(dic), num_stages=num_stages, num_warps=num_warps)\n TRITON_CONFIG_LIST_BWD.append(cfg)\n\n@triton.autotune(\n configs=TRITON_CONFIG_LIST_BWD,\n key=['BLOCK_DMODEL', 'max_seqlen_q', 'max_seqlen_k'],\n)\n@triton.jit\ndef tuned_bwd_kernel_dk_dv(\n Q, K, V, B, sm_scale, Out, DO,\n DK, DV,\n L,\n D,\n stride_qz, stride_qh, stride_qm, stride_qk,\n stride_kz, stride_kh, stride_kn, stride_kk,\n stride_vz, stride_vh, stride_vk, stride_vn,\n stride_bz, stride_bh, stride_bm, stride_bn,\n stride_oz, stride_oh, stride_om, stride_ok,\n stride_dkz, stride_dkh, stride_dkn, stride_dkk,\n stride_dvz, stride_dvh, stride_dvk, stride_dvn,\n cu_seqlens_q,\n cu_seqlens_k,\n num_seqlens,\n max_seqlen_q,\n max_seqlen_k,\n head_dim,\n dropout_p,\n philox_seed,\n philox_offset_base,\n BLOCK_M: tl.constexpr,\n BLOCK_DMODEL: tl.constexpr,\n BLOCK_N: tl.constexpr,\n CAUSAL: tl.constexpr,\n ENABLE_DROPOUT: tl.constexpr,\n PADDED_HEAD: tl.constexpr,\n BIAS_TYPE: tl.constexpr,\n):\n bare_bwd_kernel_dk_dv(\n Q, K, V, B, sm_scale, Out, DO,\n DK, DV,\n L,\n D,\n stride_qz, stride_qh, stride_qm, stride_qk,\n stride_kz, stride_kh, stride_kn, stride_kk,\n stride_vz, stride_vh, stride_vk, stride_vn,\n stride_bz, stride_bh, stride_bm, stride_bn,\n stride_oz, stride_oh, stride_om, stride_ok,\n stride_dkz, stride_dkh, stride_dkn, stride_dkk,\n stride_dvz, stride_dvh, stride_dvk, stride_dvn,\n cu_seqlens_q,\n cu_seqlens_k,\n num_seqlens,\n max_seqlen_q,\n max_seqlen_k,\n head_dim,\n dropout_p,\n philox_seed,\n philox_offset_base,\n BLOCK_M,\n BLOCK_DMODEL,\n BLOCK_N,\n CAUSAL,\n ENABLE_DROPOUT,\n PADDED_HEAD=PADDED_HEAD,\n BIAS_TYPE=BIAS_TYPE,\n )\n\n@triton.autotune(\n configs=TRITON_CONFIG_LIST_BWD,\n key=['BLOCK_DMODEL', 'max_seqlen_q', 'max_seqlen_k'],\n)\n@triton.jit\ndef tuned_bwd_kernel_dq(\n Q, K, V, B, sm_scale, Out, DO,\n DQ, DB,\n L,\n D,\n stride_qz, stride_qh, stride_qm, stride_qk,\n stride_kz, stride_kh, stride_kn, stride_kk,\n stride_vz, stride_vh, stride_vk, stride_vn,\n stride_bz, stride_bh, stride_bm, stride_bn,\n stride_oz, stride_oh, stride_om, stride_ok,\n stride_dqz, stride_dqh, stride_dqm, stride_dqk,\n stride_dbz, stride_dbh, stride_dbm, stride_dbn,\n cu_seqlens_q,\n cu_seqlens_k,\n num_seqlens,\n max_seqlen_q,\n max_seqlen_k,\n head_dim,\n dropout_p,\n philox_seed,\n philox_offset_base,\n BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,\n BLOCK_N: tl.constexpr,\n CAUSAL: tl.constexpr,\n ENABLE_DROPOUT: tl.constexpr,\n PADDED_HEAD: tl.constexpr,\n BIAS_TYPE: tl.constexpr,\n):\n bare_bwd_kernel_dq(Q, K, V, B, sm_scale, Out, DO,\n DQ, DB,\n L,\n D,\n stride_qz, stride_qh, stride_qm, stride_qk,\n stride_kz, stride_kh, stride_kn, stride_kk,\n stride_vz, stride_vh, stride_vk, stride_vn,\n stride_bz, stride_bh, stride_bm, stride_bn,\n stride_oz, stride_oh, stride_om, stride_ok,\n stride_dqz, stride_dqh, stride_dqm, stride_dqk,\n stride_dbz, stride_dbh, stride_dbm, stride_dbn,\n cu_seqlens_q,\n cu_seqlens_k,\n num_seqlens,\n max_seqlen_q,\n max_seqlen_k,\n head_dim,\n dropout_p,\n philox_seed,\n philox_offset_base,\n BLOCK_M, BLOCK_DMODEL,\n BLOCK_N,\n CAUSAL,\n ENABLE_DROPOUT,\n PADDED_HEAD=PADDED_HEAD,\n BIAS_TYPE=BIAS_TYPE,\n )\n", - "description_1": "Use triton language to define two backward kernels for computing gradients with respect to keys/values and queries in a transformer model. The kernels are optimized using autotuning with different configurations of block sizes, number of warps, and other parameters. The kernels take multiple tensor inputs and parameters related to sequence lengths, dropout, and other model-specific constants.", - "description_2": "Use triton language to define and autotune backward kernels for gradient computation in transformers.", - "difficulty": 4 - }, - { - "code": "import triton\nimport triton.language as tl\n\n# Backward preprocessing kernel\n@triton.jit\ndef bwd_preprocess(\n Out, DO,\n NewDO, Delta,\n BLOCK_M: tl.constexpr, D_HEAD: tl.constexpr,\n):\n off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)\n off_n = tl.arange(0, D_HEAD)\n # Load\n o = tl.load(Out + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32)\n do = tl.load(DO + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32)\n # Compute\n delta = tl.sum(o * do, axis=1)\n # Write-back\n tl.store(NewDO + off_m[:, None] * D_HEAD + off_n[None, :], do)\n tl.store(Delta + off_m, delta)\n\n", - "description_1": "Use triton language to implement a backward preprocessing kernel for fused attention. The kernel, 'bwd_preprocess', takes 4 tensor arguments (Out, DO, NewDO, Delta) and 2 constant expression arguments (BLOCK_M, D_HEAD). It loads data from 'Out' and 'DO', computes a delta by summing the product of 'o' and 'do', and writes back to 'NewDO' and 'Delta'.", - "description_2": "Use triton language to implement a kernel that preprocesses data for backward pass in fused attention by loading input tensors, computing a product-sum, and writing results.", - "difficulty": 2 - }, - { - "code": "import triton\nimport triton.language as tl\n\n# Dot product kernel function\n@triton.jit\ndef dot(BLOCK_M : tl.constexpr, QDIM : tl.constexpr, KDIM : tl.constexpr, q, k):\n if BLOCK_M == 1:\n return tl.sum(tl.view(q, [QDIM]) * tl.view(k, [KDIM]))\n else:\n return tl.dot(q, k)\n\n# Backward kernel function for computing dK and dV\n@triton.jit\ndef bwd_kernel_dk_dv(\n Q, K, V, sm_scale, Out, DO,\n DK, DV,\n L,\n D,\n stride_qz, stride_qh, stride_qm, stride_qk,\n stride_kz, stride_kh, stride_kn, stride_kk,\n stride_vz, stride_vh, stride_vk, stride_vn,\n Z, H, seqlen_q, seqlen_k,\n dropout_p,\n philox_seed,\n philox_offset_base,\n BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,\n BLOCK_N: tl.constexpr,\n CAUSAL: tl.constexpr,\n ENABLE_DROPOUT: tl.constexpr,\n):\n start_m = tl.program_id(0) * BLOCK_N\n off_hz = tl.program_id(1)\n # Q is consumed depending on block ID. Every block uses\n # previous block offset by BLOCK_M x D_HEAD.\n qvk_offset = off_hz * stride_qh\n # initialize offsets\n offs_m = start_m + tl.arange(0, BLOCK_M)\n offs_n = tl.arange(0, BLOCK_N)\n # Initialize pointers to Q, K, V\n q_offset = off_hz * stride_qh\n Q_block_ptr = tl.make_block_ptr(\n base=Q + q_offset,\n shape=(seqlen_q, BLOCK_DMODEL),\n strides=(stride_qm, stride_qk),\n offsets=(0, 0),\n block_shape=(BLOCK_M, BLOCK_DMODEL),\n order=(1, 0)\n )\n k_offset = off_hz * stride_kh\n K_block_ptr = tl.make_block_ptr(\n base=K + k_offset,\n shape=(BLOCK_DMODEL, seqlen_k),\n strides=(stride_kk, stride_kn),\n offsets=(0, start_m),\n block_shape=(BLOCK_DMODEL, BLOCK_N),\n order=(0, 1)\n )\n v_offset = off_hz * stride_vh\n VT_block_ptr = tl.make_block_ptr(\n base=V + v_offset,\n shape=(BLOCK_DMODEL, seqlen_k),\n strides=(stride_vn, stride_vk),\n offsets=(0, start_m),\n block_shape=(BLOCK_DMODEL, BLOCK_N),\n order=(0, 1)\n )\n do_offset = q_offset\n DO_block_ptr = tl.make_block_ptr(\n base=DO + do_offset,\n shape=(seqlen_q, BLOCK_DMODEL),\n strides=(stride_qm, stride_qk),\n offsets=(0, 0),\n block_shape=(BLOCK_M, BLOCK_DMODEL),\n order=(1, 0)\n )\n # pointer to row-wise quantities in value-like data\n D_ptrs = D + off_hz * seqlen_q\n l_ptrs = L + off_hz * seqlen_q\n qk_scale = sm_scale * 1.44269504\n # load k and v: they will stay in SRAM throughout\n k = tl.load(K_block_ptr)\n k = (k * qk_scale).to(K_block_ptr.type.element_ty)\n vt = tl.load(VT_block_ptr)\n dv = tl.zeros([BLOCK_N, BLOCK_DMODEL], dtype=tl.float32)\n dk = tl.zeros([BLOCK_N, BLOCK_DMODEL], dtype=tl.float32)\n # This lower loop bound is because of the causal mask. We create a lower triangular\n # result. The upper triangular is -inf (becomes 0 when we do e^x). As such, it can\n # be ignored in the GEMM.\n lo = start_m if CAUSAL else 0\n hi = seqlen_q\n Q_block_ptr = tl.advance(Q_block_ptr, (lo, 0))\n DO_block_ptr = tl.advance(DO_block_ptr, (lo, 0))\n batch_philox_offset = philox_offset_base + off_hz * seqlen_q * seqlen_k\n\n for start_n in range(lo, hi, BLOCK_M):\n offs_m_curr = offs_n[:, None] + start_n\n # -- load q, do --\n q = tl.load(Q_block_ptr)\n do = tl.load(DO_block_ptr)\n # -- compute qk ----\n qk = dot(BLOCK_M, BLOCK_DMODEL, BLOCK_DMODEL, q, k) # BLOCK_M x BLOCK_N\n if CAUSAL:\n qk = tl.where(offs_m_curr >= offs_m[None, :], qk, float(\"-inf\"))\n l_i = tl.load(l_ptrs + offs_m_curr)\n p = tl.math.exp2(qk - l_i)\n # -- compute dv ----\n if ENABLE_DROPOUT:\n philox_offset = batch_philox_offset + start_n * seqlen_k + start_m\n keep = dropout_mask(philox_seed, philox_offset, dropout_p, BLOCK_M, BLOCK_N, seqlen_k)\n # CAVEAT: do NOT update p, ds needs the original p\n if BLOCK_M == 1:\n dv += tl.where(keep, p / (1 - dropout_p), 0.0).to(Q.dtype.element_ty) * do\n else:\n dv += tl.dot(tl.where(tl.trans(keep), tl.trans(p) / (1 - dropout_p), 0.0).to(Q.dtype.element_ty), do)\n else:\n if BLOCK_M == 1:\n dv += p.to(Q.dtype.element_ty) * do\n else:\n dv += dot(BLOCK_M, BLOCK_M, BLOCK_DMODEL, tl.trans(p.to(Q.dtype.element_ty)), do)\n # compute dp = dot(v, do)\n Di = tl.load(D_ptrs + offs_m_curr)\n dp = tl.zeros([BLOCK_M, BLOCK_M], dtype=tl.float32)\n dp += dot(BLOCK_M, BLOCK_DMODEL, BLOCK_DMODEL, do, vt)\n if ENABLE_DROPOUT:\n dp = tl.where(keep, dp / (1 - dropout_p), 0)\n # compute ds = p * (dp - delta[:, None])\n ds = p * (dp - Di)\n # compute dk\n if BLOCK_M == 1:\n dk += ds.to(Q.dtype.element_ty) * q\n else:\n dk += tl.dot(tl.trans(ds.to(Q.dtype.element_ty)), q)\n # update pointers\n Q_block_ptr = tl.advance(Q_block_ptr, (BLOCK_M, 0))\n DO_block_ptr = tl.advance(DO_block_ptr, (BLOCK_M, 0))\n # initialize pointers to output\n DK_block_ptr = tl.make_block_ptr(\n base=DK + k_offset,\n shape=(seqlen_k, BLOCK_DMODEL),\n strides=(stride_kn, stride_kk),\n offsets=(start_m, 0),\n block_shape=(BLOCK_N, BLOCK_DMODEL),\n order=(1, 0)\n )\n DV_block_ptr = tl.make_block_ptr(\n base=DV + v_offset,\n shape=(seqlen_k, BLOCK_DMODEL),\n strides=(stride_vk, stride_vn),\n offsets=(start_m, 0),\n block_shape=(BLOCK_N, BLOCK_DMODEL),\n order=(1, 0)\n )\n tl.store(DK_block_ptr, (dk * sm_scale).to(DK.type.element_ty))\n tl.store(DV_block_ptr, dv.to(DK.type.element_ty))\n\n# Backward kernel function for computing dQ\n@triton.jit\ndef bwd_kernel_dq(\n Q, K, V, sm_scale, Out, DO,\n DQ,\n L,\n D,\n stride_qz, stride_qh, stride_qm, stride_qk,\n stride_kz, stride_kh, stride_kn, stride_kk,\n stride_vz, stride_vh, stride_vk, stride_vn,\n Z, H, seqlen_q, seqlen_k,\n dropout_p,\n philox_seed,\n philox_offset_base,\n BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,\n BLOCK_N: tl.constexpr,\n CAUSAL: tl.constexpr,\n ENABLE_DROPOUT: tl.constexpr,\n):\n start_m = tl.program_id(0) * BLOCK_N\n off_hz = tl.program_id(1)\n qvk_offset = off_hz * stride_qh\n # initialize offsets\n offs_m = start_m + tl.arange(0, BLOCK_M)\n offs_n = tl.arange(0, BLOCK_N)\n # Initialize pointers to Q, K, V\n q_offset = off_hz * stride_qh\n Q_block_ptr = tl.make_block_ptr(\n base=Q + q_offset,\n shape=(seqlen_q, BLOCK_DMODEL),\n strides=(stride_qm, stride_qk),\n offsets=(start_m, 0),\n block_shape=(BLOCK_M, BLOCK_DMODEL),\n order=(1, 0)\n )\n k_offset = off_hz * stride_kh\n K_block_ptr = tl.make_block_ptr(\n base=K + k_offset,\n shape=(BLOCK_DMODEL, seqlen_k),\n strides=(stride_kk, stride_kn),\n offsets=(0, 0),\n block_shape=(BLOCK_DMODEL, BLOCK_N),\n order=(0, 1)\n )\n v_offset = off_hz * stride_vh\n V_block_ptr = tl.make_block_ptr(\n base=V + v_offset,\n shape=(BLOCK_DMODEL, seqlen_k),\n strides=(stride_vn, stride_vk),\n offsets=(0, 0),\n block_shape=(BLOCK_DMODEL, BLOCK_N),\n order=(0, 1)\n )\n DO_block_ptr = tl.make_block_ptr(\n base=DO + q_offset,\n shape=(seqlen_q, BLOCK_DMODEL),\n strides=(stride_qm, stride_qk),\n offsets=(start_m, 0),\n block_shape=(BLOCK_M, BLOCK_DMODEL),\n order=(1, 0)\n )\n # pointer to row-wise quantities in value-like data\n D_ptrs = D + off_hz * seqlen_q\n l_ptrs = L + off_hz * seqlen_q\n qk_scale = sm_scale * 1.44269504\n # load q and do: they will stay in SRAM throughout\n q = tl.load(Q_block_ptr)\n q = (q * qk_scale).to(Q_block_ptr.type.element_ty)\n do = tl.load(DO_block_ptr)\n Di = tl.load(D_ptrs + offs_m)\n l_i = tl.load(l_ptrs + offs_m)\n dq = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)\n # loop over k, v\n lo = 0\n hi = start_m + BLOCK_M if CAUSAL else seqlen_k\n batch_philox_offset = philox_offset_base + off_hz * seqlen_q * seqlen_k\n\n for start_n in range(lo, hi, BLOCK_N):\n # -- load k, v --\n kt = tl.load(K_block_ptr)\n vt = tl.load(V_block_ptr)\n # -- compute qk ----\n qk = dot(BLOCK_M, BLOCK_DMODEL, BLOCK_DMODEL, q, kt)\n if CAUSAL:\n qk = tl.where(offs_m[:, None] >= (offs_n[None, :] + start_n), qk, float(\"-inf\"))\n p = tl.math.exp2(qk - l_i[:, None])\n # compute dp = dot(v, do)\n dp = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)\n dp += dot(BLOCK_M, BLOCK_DMODEL, BLOCK_DMODEL, do, vt)\n if ENABLE_DROPOUT:\n philox_offset = batch_philox_offset + start_m * seqlen_k + start_n\n keep = dropout_mask(philox_seed, philox_offset, dropout_p, BLOCK_M, BLOCK_N, seqlen_k)\n dp = tl.where(keep, dp / (1 - dropout_p), 0)\n # compute ds = p * (dp - delta[:, None])\n ds = p * (dp - Di[:, None])\n # compute dq. Unfortunately we cannot avoid transpose here as this loop\n # uses k both normal and transpose.\n if BLOCK_M == 1:\n dq += tl.view(kt, [BLOCK_DMODEL]) * ds.to(Q.type.element_ty)\n else:\n dq += tl.dot(ds.to(Q.type.element_ty), tl.trans(kt))\n # update pointers\n K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N))\n V_block_ptr = tl.advance(V_block_ptr, (0, BLOCK_N))\n # initialize pointers to output\n DQ_block_ptr = tl.make_block_ptr(\n base=DQ + q_offset,\n shape=(seqlen_q, BLOCK_DMODEL),\n strides=(stride_qm, stride_qk),\n offsets=(start_m, 0),\n block_shape=(BLOCK_M, BLOCK_DMODEL),\n order=(1, 0)\n )\n tl.store(DQ_block_ptr, (dq * sm_scale).to(DQ_block_ptr.type.element_ty))\n", - "description_1": "Use triton language to define three kernels: 'dot', 'bwd_kernel_dk_dv', and 'bwd_kernel_dq'. The 'dot' kernel takes four parameters including BLOCK_M, QDIM, KDIM, and q, k, performing a dot operation and returning the result based on BLOCK_M condition. The 'bwd_kernel_dk_dv' kernel has 27 parameters including multiple tensors (Q, K, V, etc.), strides, constants, dropout settings, and causal flags, designed to compute backward pass for dK and dV. It involves setting up pointers, computing products and performing dropout operations if enabled. The 'bwd_kernel_dq' kernel, similarly, has 26 parameters. It computes the backward pass for dQ by loading tensors, performing dot products and applying dropout settings if needed.", - "description_2": "Use triton language to define the 'dot' kernel, performing conditional vector dot products based on block size, and two backward kernels, 'bwd_kernel_dk_dv' and 'bwd_kernel_dq', to calculate gradient updates for dK, dV, and dQ with optional dropout.", - "difficulty": 4 - }, - { - "code": "import triton\nimport triton.language as tl\n\n@triton.jit\ndef max_fn(x, y):\n return tl.math.max(x, y)\n\n@triton.jit\ndef dropout_offsets(philox_seed, philox_offset, dropout_p, m, n, stride):\n ms = tl.arange(0, m)\n ns = tl.arange(0, n)\n return philox_offset + ms[:, None] * stride + ns[None, :]\n\n@triton.jit\ndef dropout_rng(philox_seed, philox_offset, dropout_p, m, n, stride):\n rng_offsets = dropout_offsets(philox_seed, philox_offset, dropout_p, m, n, stride).to(tl.uint32)\n # TODO: use tl.randint for better performance\n return tl.rand(philox_seed, rng_offsets)\n\n@triton.jit\ndef dropout_mask(philox_seed, philox_offset, dropout_p, m, n, stride):\n rng_output = dropout_rng(philox_seed, philox_offset, dropout_p, m, n, stride)\n rng_keep = rng_output > dropout_p\n return rng_keep\n\n@triton.jit\ndef attn_fwd_inner(\n acc, l_i, m_i, q,\n K_block_ptr, V_block_ptr,\n start_m,\n seqlen_q,\n seqlen_k,\n dropout_p,\n philox_seed,\n batch_philox_offset,\n encoded_softmax_block_ptr,\n BLOCK_M: tl.constexpr,\n BLOCK_DMODEL: tl.constexpr,\n BLOCK_N: tl.constexpr,\n STAGE: tl.constexpr,\n offs_m: tl.constexpr,\n offs_n: tl.constexpr,\n pre_load_v: tl.constexpr,\n ENABLE_DROPOUT: tl.constexpr,\n RETURN_ENCODED_SOFTMAX: tl.constexpr,\n):\n if STAGE == 1:\n lo, hi = 0, min(seqlen_k, start_m * BLOCK_M)\n elif STAGE == 2:\n lo, hi = start_m * BLOCK_M, min(seqlen_k, start_m * BLOCK_M + BLOCK_M)\n lo = tl.multiple_of(lo, BLOCK_M)\n K_block_ptr = tl.advance(K_block_ptr, (0, lo))\n V_block_ptr = tl.advance(V_block_ptr, (lo, 0))\n if RETURN_ENCODED_SOFTMAX:\n encoded_softmax_block_ptr = tl.advance(encoded_softmax_block_ptr, (0, lo))\n else:\n lo, hi = 0, seqlen_k\n for start_n in range(lo, hi, BLOCK_N):\n if STAGE == 1 or STAGE == 3:\n start_n = tl.multiple_of(start_n, BLOCK_N)\n k = tl.load(K_block_ptr)\n if pre_load_v:\n v = tl.load(V_block_ptr)\n qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)\n if STAGE == 2:\n mask = offs_m[:, None] >= (start_n + offs_n[None, :])\n qk = tl.where(mask, qk, float(\"-inf\"))\n if BLOCK_M == 1:\n qk += tl.sum(tl.view(q, [BLOCK_DMODEL]) * tl.view(k, [BLOCK_DMODEL]))\n else:\n qk += tl.dot(q, k)\n m_ij = tl.maximum(m_i, tl.max(qk, 1))\n qk = qk - m_ij[:, None]\n p = tl.math.exp2(qk)\n l_ij = tl.sum(p, 1)\n if ENABLE_DROPOUT:\n philox_offset = batch_philox_offset + start_m * BLOCK_M * seqlen_k + start_n\n keep = dropout_mask(philox_seed, philox_offset, dropout_p, BLOCK_M, BLOCK_N, seqlen_k)\n if RETURN_ENCODED_SOFTMAX:\n tl.store(encoded_softmax_block_ptr, tl.where(keep, p, -p).to(encoded_softmax_block_ptr.type.element_ty))\n p = tl.where(keep, p, 0.0)\n elif RETURN_ENCODED_SOFTMAX:\n tl.store(encoded_softmax_block_ptr, p.to(encoded_softmax_block_ptr.type.element_ty))\n alpha = tl.math.exp2(m_i - m_ij)\n acc = acc * alpha[:, None]\n if not pre_load_v:\n v = tl.load(V_block_ptr)\n l_i = l_i * alpha + l_ij\n m_i = m_ij\n if BLOCK_M == 1:\n acc += tl.view(p.to(V_block_ptr.type.element_ty), [1]) * tl.view(v, [BLOCK_DMODEL])\n else:\n acc += tl.dot(p.to(V_block_ptr.type.element_ty), v)\n V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0))\n K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N))\n if RETURN_ENCODED_SOFTMAX:\n encoded_softmax_block_ptr = tl.advance(encoded_softmax_block_ptr, (0, BLOCK_N))\n return acc, l_i, m_i\n\n@triton.jit\ndef attn_fwd(\n Q, K, V, sm_scale, M, Out,\n stride_qz, stride_qh, stride_qm, stride_qk,\n stride_kz, stride_kh, stride_kn, stride_kk,\n stride_vz, stride_vh, stride_vk, stride_vn,\n stride_oz, stride_oh, stride_om, stride_on,\n Z, H,\n seqlen_q,\n seqlen_k,\n dropout_p,\n philox_seed,\n philox_offset_base,\n encoded_softmax,\n STAGE: tl.constexpr,\n BLOCK_M: tl.constexpr,\n BLOCK_DMODEL: tl.constexpr,\n BLOCK_N: tl.constexpr,\n pre_load_v: tl.constexpr,\n ENABLE_DROPOUT: tl.constexpr,\n RETURN_ENCODED_SOFTMAX: tl.constexpr,\n):\n start_m = tl.program_id(0)\n off_hz = tl.program_id(1)\n q_offset = off_hz * stride_qh\n Q_block_ptr = tl.make_block_ptr(\n base=Q + q_offset,\n shape=(seqlen_q, BLOCK_DMODEL),\n strides=(stride_qm, stride_qk),\n offsets=(start_m * BLOCK_M, 0),\n block_shape=(BLOCK_M, BLOCK_DMODEL),\n order=(1, 0)\n )\n k_offset = off_hz * stride_kh\n K_block_ptr = tl.make_block_ptr(\n base=K + k_offset,\n shape=(BLOCK_DMODEL, seqlen_k),\n strides=(stride_kk, stride_kn),\n offsets=(0, 0),\n block_shape=(BLOCK_DMODEL, BLOCK_N),\n order=(0, 1)\n )\n v_offset = off_hz * stride_vh\n V_block_ptr = tl.make_block_ptr(\n base=V + v_offset,\n shape=(seqlen_k, BLOCK_DMODEL),\n strides=(stride_vk, stride_vn),\n offsets=(0, 0),\n block_shape=(BLOCK_N, BLOCK_DMODEL),\n order=(1, 0)\n )\n offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)\n offs_n = tl.arange(0, BLOCK_N)\n m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float(\"inf\")\n l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + 1.0\n acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)\n qk_scale = sm_scale * 1.44269504\n q = tl.load(Q_block_ptr)\n q = (q * qk_scale).to(Q_block_ptr.type.element_ty)\n if ENABLE_DROPOUT:\n batch_philox_offset = philox_offset_base + off_hz * seqlen_q * seqlen_k\n else:\n batch_philox_offset = 0\n if RETURN_ENCODED_SOFTMAX:\n encoded_softmax_block_ptr = tl.make_block_ptr(\n base=encoded_softmax + off_hz * seqlen_q * seqlen_k,\n shape=(seqlen_q, seqlen_k),\n strides=(seqlen_k, 1),\n offsets=(start_m * BLOCK_M, 0),\n block_shape=(BLOCK_M, BLOCK_N),\n order=(1, 0)\n )\n else:\n encoded_softmax_block_ptr = 0\n if STAGE & 1:\n acc, l_i, m_i = attn_fwd_inner(\n acc, l_i, m_i, q, K_block_ptr, V_block_ptr,\n start_m, seqlen_q, seqlen_k,\n dropout_p, philox_seed, batch_philox_offset, encoded_softmax_block_ptr,\n BLOCK_M, BLOCK_DMODEL, BLOCK_N,\n 4 - STAGE, offs_m, offs_n,\n pre_load_v,\n ENABLE_DROPOUT,\n RETURN_ENCODED_SOFTMAX)\n if STAGE & 2:\n tl.debug_barrier()\n acc, l_i, m_i = attn_fwd_inner(\n acc, l_i, m_i, q, K_block_ptr, V_block_ptr,\n start_m, seqlen_q, seqlen_k,\n dropout_p, philox_seed, batch_philox_offset, encoded_softmax_block_ptr,\n BLOCK_M, BLOCK_DMODEL, BLOCK_N,\n 2, offs_m, offs_n,\n pre_load_v,\n ENABLE_DROPOUT,\n RETURN_ENCODED_SOFTMAX,\n )\n acc = acc / l_i[:, None]\n if ENABLE_DROPOUT:\n acc = acc / (1 - dropout_p)\n m_ptrs = M + off_hz * seqlen_q + offs_m\n tl.store(m_ptrs, m_i + tl.math.log2(l_i))\n o_offset = off_hz * stride_oh\n O_block_ptr = tl.make_block_ptr(\n base=Out + o_offset,\n shape=(seqlen_q, BLOCK_DMODEL),\n strides=(stride_om, stride_on),\n offsets=(start_m * BLOCK_M, 0),\n block_shape=(BLOCK_M, BLOCK_DMODEL),\n order=(1, 0)\n )\n tl.store(O_block_ptr, acc.to(Out.type.element_ty))\n", - "description_1": "Use triton language to implement a fused attention mechanism with optional dropout and encoded softmax. The kernel includes five main functions: 'max_fn' computes the element-wise maximum of two inputs; 'dropout_offsets' calculates offsets for dropout random number generation; 'dropout_rng' generates random numbers for dropout; 'dropout_mask' creates a mask for applying dropout; and 'attn_fwd_inner' along with 'attn_fwd' perform the forward pass of attention computation using these utilities. Key features include handling causal masking and optional dropout with encoded softmax, using various constants and strides for input manipulation.", - "description_2": "Use triton language to implement fused attention with encoded softmax and dropout options. Key components include random number generation for dropout and a forward attention pass with optional causal masking.", - "difficulty": 4 - }, - { - "code": "import jax\nimport jax.numpy as jnp\nimport triton\nimport triton.language as tl\nimport jax_triton as jt\n\n@triton.jit\ndef add_kernel(\n x_ptr,\n y_ptr,\n output_ptr,\n block_size: tl.constexpr,\n):\n \"\"\"Adds two vectors.\"\"\"\n pid = tl.program_id(axis=0)\n block_start = pid * block_size\n offsets = block_start + tl.arange(0, block_size)\n mask = offsets < 8\n x = tl.load(x_ptr + offsets, mask=mask)\n y = tl.load(y_ptr + offsets, mask=mask)\n output = x + y\n tl.store(output_ptr + offsets, output, mask=mask)\n\ndef add(x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:\n out_shape = jax.ShapeDtypeStruct(shape=x.shape, dtype=x.dtype)\n block_size = 8\n grid = (triton.cdiv(x.size, block_size),)\n return jt.triton_call(\n x,\n y,\n kernel=add_kernel,\n out_shape=out_shape,\n grid=grid,\n block_size=block_size)\n", - "description_1": "Use triton language to define a kernel function 'add_kernel' that adds two vectors. The kernel takes four parameters: x_ptr, y_ptr, output_ptr, and block_size. It calculates the program ID, determines the block start, and computes offsets. It uses these offsets to load elements from x_ptr and y_ptr, adds them, and stores the result in output_ptr. The 'add' function wraps this kernel call, setting up the output shape, block size, and grid configuration, and then calls the kernel using jt.triton_call.", - "description_2": "Use triton language to create a kernel for vector addition and a wrapper function to execute it with specified grid and block size.", - "difficulty": 1 - }, - { - "code": "import jax.numpy as jnp\nimport jax_triton as jt\nimport triton\nimport triton.language as tl\n\ndef _dummy_fn(x):\n assert x.size % 4 == 0\n\n @triton.jit\n def dummy_kernel(x_ptr, o_ptr):\n offs = tl.program_id(axis=0) * 4 + tl.arange(0, 4)\n tl.store(o_ptr + offs, tl.load(x_ptr + offs))\n\n return jt.triton_call(x, kernel=dummy_kernel, out_shape=x, grid=(x.size // 4))\n", - "description_1": "Use triton language to define a kernel `dummy_kernel` with two parameters: x_ptr (input pointer) and o_ptr (output pointer). This kernel uses the program ID and range to calculate offsets for loading from x_ptr and storing to o_ptr. The kernel is invoked by `_dummy_fn`, which accepts a tensor x, ensures its size is a multiple of 4, and calls `jt.triton_call` with the kernel, output shape, and grid size.", - "description_2": "Use triton language to define and invoke a kernel that transfers blocks of data from an input pointer to an output pointer with specified offsets.", - "difficulty": 1 - }, - { - "code": "import functools\nimport jax\nimport jax.numpy as jnp\nimport jax_triton as jt\nimport triton\nimport triton.language as tl\n\n\n@triton.jit\ndef fused_attention_kernel(\n Q, K, V,\n stride_qz, stride_qh, stride_qm, stride_qk,\n stride_kz, stride_kh, stride_kn, stride_kk,\n stride_vz, stride_vh, stride_vk, stride_vn,\n stride_oz, stride_oh, stride_om, stride_on,\n Z, H, N_CTX,\n L, M,\n Out,\n BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,\n BLOCK_N: tl.constexpr,\n):\n start_m = tl.program_id(0)\n off_hz = tl.program_id(1)\n # initialize offsets\n offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)\n offs_n = tl.arange(0, BLOCK_N)\n offs_d = tl.arange(0, BLOCK_DMODEL)\n off_q = off_hz * stride_qh + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qk\n off_k = off_hz * stride_qh + offs_n[None, :] * stride_kn + offs_d[:, None] * stride_kk\n off_v = off_hz * stride_qh + offs_n[:, None] * stride_qm + offs_d[None, :] * stride_qk\n # Initialize pointers to Q, K, V\n q_ptrs = Q + off_q\n k_ptrs = K + off_k\n v_ptrs = V + off_v\n # initialize pointer to m and l\n m_prev = tl.zeros([BLOCK_M], dtype=tl.float32) - float(\"inf\")\n l_prev = tl.zeros([BLOCK_M], dtype=tl.float32)\n acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)\n # load q: it will stay in SRAM throughout\n q = tl.load(q_ptrs)\n # loop over k, v and update accumulator\n for start_n in range(0, (start_m + 1) * BLOCK_M, BLOCK_N):\n # -- compute qk ----\n k = tl.load(k_ptrs)\n qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)\n qk += tl.dot(q, k)\n # compute new m\n m_curr = tl.maximum(tl.max(qk, 1), m_prev)\n # correct old l\n l_prev *= tl.exp(m_prev - m_curr)\n # attention weights\n p = tl.exp(qk - m_curr[:, None])\n l_curr = tl.sum(p, 1) + l_prev\n # rescale operands of matmuls\n l_rcp = 1. / l_curr\n p *= l_rcp\n acc *= (l_prev * l_rcp)[:, None]\n # update acc\n p = p.to(tl.float16)\n v = tl.load(v_ptrs)\n acc += tl.dot(p, v)\n # update m_i and l_i\n l_prev = l_curr\n m_prev = m_curr\n # update pointers\n k_ptrs += BLOCK_N * stride_kn\n v_ptrs += BLOCK_N * stride_vk\n # rematerialize offsets to save registers\n start_m = tl.program_id(0)\n offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)\n # write back l and m\n l_ptrs = L + off_hz * N_CTX + offs_m\n m_ptrs = M + off_hz * N_CTX + offs_m\n tl.store(l_ptrs, l_prev)\n tl.store(m_ptrs, m_prev)\n # initialize pointers to output\n offs_n = tl.arange(0, BLOCK_DMODEL)\n off_o = off_hz * stride_oh + offs_m[:, None] * stride_om + offs_n[None, :] * stride_on\n out_ptrs = Out + off_o\n tl.store(out_ptrs, acc)\n\n@functools.partial(jax.jit, static_argnames=[\"sm_scale\"])\ndef fused_attention(q: jnp.ndarray, k: jnp.ndarray,\n v: jnp.ndarray) -> jnp.ndarray:\n \"\"\"Flash attention.\"\"\"\n block_size = 128\n grid = (jt.cdiv(q.shape[2], block_size), q.shape[0] * q.shape[1])\n out_shape = [\n jax.ShapeDtypeStruct(\n shape=(q.shape[0] * q.shape[1], q.shape[2]), dtype=jnp.float32),\n jax.ShapeDtypeStruct(\n shape=(q.shape[0] * q.shape[1], q.shape[2]), dtype=jnp.float32),\n jax.ShapeDtypeStruct(shape=q.shape, dtype=q.dtype)\n ]\n\n metaparams = dict(\n BLOCK_M=block_size,\n BLOCK_N=block_size,\n BLOCK_DMODEL=q.shape[-1],\n num_warps=4,\n num_stages=2)\n _, _, output = jt.triton_call(\n q, k, v,\n *jt.strides_from_shape(q.shape),\n *jt.strides_from_shape(k.shape),\n *jt.strides_from_shape(v.shape),\n *jt.strides_from_shape(q.shape),\n q.shape[0], q.shape[1], q.shape[2],\n kernel=fused_attention_kernel,\n out_shape=out_shape,\n grid=grid,\n **metaparams)\n return output\n", - "description_1": "Use triton language to implement a fused attention kernel that computes attention scores and outputs weighted values. The kernel accepts tensors Q, K, V, and their strides, along with output tensor L, M, and Out, and uses BLOCK_M, BLOCK_DMODEL, and BLOCK_N as tile sizes. It calculates dot products in a loop over block tiles, applies softmax for attention weights, and accumulates results into the output tensor.", - "description_2": "Use triton language to perform matrix multiplications and apply the softmax operation in a block-wise manner for attention score calculation in neural networks.", - "difficulty": 3 - }, - { - "code": "import triton\nimport triton.language as tl\nimport jax_triton as jt\n\n@triton.jit\ndef matmul_kernel(\n a_ptr,\n b_ptr,\n c_ptr,\n m: tl.constexpr,\n n: tl.constexpr,\n k: tl.constexpr,\n stride_am: tl.constexpr,\n stride_ak: tl.constexpr,\n stride_bk: tl.constexpr,\n stride_bn: tl.constexpr,\n stride_cm: tl.constexpr,\n stride_cn: tl.constexpr,\n block_size_m: tl.constexpr,\n block_size_n: tl.constexpr,\n block_size_k: tl.constexpr,\n group_size_m: tl.constexpr,\n activation: tl.constexpr,\n):\n pid = tl.program_id(axis=0)\n num_pid_m = tl.cdiv(m, block_size_m)\n num_pid_n = tl.cdiv(n, block_size_n)\n num_pid_in_group = group_size_m * num_pid_n\n group_id = pid // num_pid_in_group\n first_pid_m = group_id * group_size_m\n group_size_m = min(num_pid_m - first_pid_m, group_size_m)\n pid_m = first_pid_m + (pid % group_size_m)\n pid_n = (pid % num_pid_in_group) // group_size_m\n\n offs_am = pid_m * block_size_m + tl.arange(0, block_size_m)\n offs_bn = pid_n * block_size_n + tl.arange(0, block_size_n)\n offs_k = tl.arange(0, block_size_k)\n a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)\n b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)\n\n accumulator = tl.zeros((block_size_m, block_size_n), dtype=tl.float32)\n for k in range(0, k, block_size_k):\n a = tl.load(a_ptrs)\n b = tl.load(b_ptrs)\n accumulator += tl.dot(a, b)\n a_ptrs += block_size_k * stride_ak\n b_ptrs += block_size_k * stride_bk\n if activation:\n accumulator = activation(accumulator)\n c = accumulator.to(tl.float16)\n\n offs_cm = pid_m * block_size_m + tl.arange(0, block_size_m)\n offs_cn = pid_n * block_size_n + tl.arange(0, block_size_n)\n c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]\n c_mask = (offs_cm[:, None] < m) & (offs_cn[None, :] < n)\n tl.store(c_ptrs, c, mask=c_mask)\n\n@triton.jit\ndef relu(x):\n return tl.where(x >= 0, x, 0)\n\ndef matmul(a, b, activation=None):\n block_size_m = 128\n block_size_n = 256\n block_size_k = 32\n group_size_m = 8\n m, k = a.shape\n n, _ = b.shape\n out_shape = jax.ShapeDtypeStruct(shape=(m, n), dtype=a.dtype)\n grid = (m // block_size_m * n // block_size_n,)\n return jt.triton_call(\n a,\n b,\n kernel=matmul_kernel,\n out_shape=out_shape,\n grid=grid,\n num_warps=8,\n num_stages=3,\n m=m,\n n=n,\n k=k,\n stride_am=k,\n stride_ak=1,\n stride_bk=n,\n stride_bn=1,\n stride_cm=n,\n stride_cn=1,\n block_size_m=block_size_m,\n block_size_n=block_size_n,\n block_size_k=block_size_k,\n group_size_m=group_size_m,\n activation=activation)\n", - "description_1": "Use triton language to implement a matrix multiplication (matmul) kernel that computes C = A x B, where A is a matrix of shape (M, K), B is a matrix of shape (K, N), and C is the resulting matrix of shape (M, N). The kernel is parameterized by various block sizes and strides for efficient computation. Additionally, an optional activation function can be applied to the results. A separate ReLU function is implemented and can be passed as an activation to the matmul function.", - "description_2": "Use triton language to define a matrix multiplication kernel for multiplying matrices A and B, and an optional activation function such as ReLU on the output matrix.", - "difficulty": 3 - }, - { - "code": "import triton\nimport triton.language as tl\nimport jax\nimport jax.numpy as jnp\nimport jax_triton as jt\nimport math\n\nnext_pow2 = lambda x: int(math.pow(2, math.ceil(math.log(x, 2))))\n\n@triton.jit\ndef softmax_kernel(\n input_ptr, output_ptr,\n input_row_stride: tl.constexpr, output_row_stride: tl.constexpr, n_cols:\n tl.constexpr, block_size: tl.constexpr\n):\n # The rows of the softmax are independent, so we parallelize across those\n row_idx = tl.program_id(0)\n # The stride represents how much we need to increase the pointer to advance 1 row\n row_start_ptr = input_ptr + row_idx * input_row_stride\n # The block size is the next power of two greater than n_cols, so we can fit each\n # row in a single block\n col_offsets = tl.arange(0, block_size)\n input_ptrs = row_start_ptr + col_offsets\n # Load the row into SRAM, using a mask since BLOCK_SIZE may be > than n_cols\n row = tl.load(input_ptrs, mask=col_offsets < n_cols, other=-float('inf'))\n # Substract maximum for numerical stability\n row_minus_max = row - tl.max(row, axis=0)\n # Note that exponentials in Triton are fast but approximate (i.e., think __expf in CUDA)\n numerator = tl.exp(row_minus_max)\n denominator = tl.sum(numerator, axis=0)\n softmax_output = numerator / denominator\n # Write back output to DRAM\n output_row_start_ptr = output_ptr + row_idx * output_row_stride\n output_ptrs = output_row_start_ptr + col_offsets\n tl.store(output_ptrs, softmax_output, mask=col_offsets < n_cols)\n\ndef softmax(x: jnp.ndarray) -> jnp.ndarray:\n out_shape = jax.ShapeDtypeStruct(shape=x.shape, dtype=x.dtype)\n block_size = next_pow2(x.shape[1])\n strides = jt.strides_from_shape(x.shape)\n return jt.triton_call(\n x,\n kernel=softmax_kernel,\n out_shape=out_shape,\n input_row_stride=strides[0],\n output_row_stride=strides[0],\n n_cols=x.shape[1],\n grid=x.shape[0],\n block_size=block_size)\n", - "description_1": "Use triton language to implement a softmax operation on a 2D input tensor. The kernel function 'softmax_kernel' takes 6 parameters: input_ptr (pointer to input data), output_ptr (pointer to output data), input_row_stride (stride for input rows), output_row_stride (stride for output rows), n_cols (number of columns in the input), and block_size (size of the block for parallel processing). The function computes the softmax for each row independently by loading the row, subtracting the maximum for numerical stability, computing the exponentials, and normalizing by the sum of exponentials. The result is stored back to the output pointer. The 'softmax' function is a wrapper that prepares the input parameters and calls the Triton kernel.", - "description_2": "Use triton language to implement a softmax operation on a 2D tensor with independent row processing and numerical stability.", - "difficulty": 2 - }, - { - "code": "import triton\nimport jax_triton as jt\nimport triton.language as tl\nimport jax\nimport jax.numpy as jnp\n\n@triton.jit\ndef add_kernel(x_ptr, y_ptr, n_elements, output_ptr, BLOCK_SIZE: tl.constexpr):\n pid = tl.program_id(axis=0)\n block_start = pid * BLOCK_SIZE\n offsets = block_start + tl.arange(0, BLOCK_SIZE)\n mask = offsets < n_elements\n x = tl.load(x_ptr + offsets, mask=mask)\n y = tl.load(y_ptr + offsets, mask=mask)\n output = x + y\n tl.store(output_ptr + offsets, output, mask=mask)\n\n\ndef add(x, y, *, kernel=add_kernel, **kwargs):\n if kernel is add_kernel:\n kwargs.setdefault(\"BLOCK_SIZE\", 8)\n\n default_grid = lambda meta: triton.cdiv(x.size, meta[\"BLOCK_SIZE\"])\n return jt.triton_call(\n x,\n y,\n x.size,\n kernel=kernel,\n out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),\n grid=kwargs.pop(\"grid\", default_grid),\n **kwargs,\n )\n\n\n@triton.jit\ndef matmul_kernel(\n a_ptr,\n b_ptr,\n M,\n N,\n K,\n stride_am,\n stride_ak,\n stride_bk,\n stride_bn,\n stride_cm,\n stride_cn,\n c_ptr,\n BLOCK_SIZE_M: tl.constexpr,\n BLOCK_SIZE_N: tl.constexpr,\n BLOCK_SIZE_K: tl.constexpr,\n GROUP_SIZE_M: tl.constexpr,\n K_EXACTLY_DIVISIBLE_BY_BLOCK: tl.constexpr,\n):\n pid = tl.program_id(axis=0)\n num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)\n num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)\n num_pid_in_group = GROUP_SIZE_M * num_pid_n\n group_id = pid // num_pid_in_group\n first_pid_m = group_id * GROUP_SIZE_M\n group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)\n pid_m = first_pid_m + (pid % group_size_m)\n pid_n = (pid % num_pid_in_group) // group_size_m\n\n offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n offs_k = tl.arange(0, BLOCK_SIZE_K)\n a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)\n b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)\n\n accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n for k_remaining in range(K, 0, -BLOCK_SIZE_K):\n if K_EXACTLY_DIVISIBLE_BY_BLOCK:\n a = tl.load(a_ptrs)\n b = tl.load(b_ptrs)\n else:\n mask = tl.arange(0, BLOCK_SIZE_K) < k_remaining\n a = tl.load(a_ptrs, mask=mask[None, :], other=0.0)\n b = tl.load(b_ptrs, mask=mask[:, None], other=0.0)\n accumulator += tl.dot(a, b)\n a_ptrs += BLOCK_SIZE_K * stride_ak\n b_ptrs += BLOCK_SIZE_K * stride_bk\n c = accumulator.to(tl.float16)\n offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]\n c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)\n tl.store(c_ptrs, c, mask=c_mask)\n\n\ndef matmul(x, y, *, kernel=matmul_kernel, **kwargs):\n m, k = x.shape\n _, n = y.shape\n\n def grid(meta):\n cdiv = triton.cdiv\n return cdiv(m, meta[\"BLOCK_SIZE_M\"]) * cdiv(n, meta[\"BLOCK_SIZE_N\"])\n\n return jt.triton_call(\n x,\n y,\n m,\n n,\n k,\n k, # stride_am\n 1, # stride_ak\n n, # stride_bk\n 1, # stride_bn\n n, # stride_cm\n 1, # stride_cn\n kernel=kernel,\n out_shape=jax.ShapeDtypeStruct((m, n), dtype=x.dtype),\n grid=grid,\n GROUP_SIZE_M=8,\n **kwargs,\n )\n", - "description_1": "Use triton language to implement two kernels. The first kernel, 'add_kernel', adds two arrays element-wise. It requires pointers to input arrays, a pointer for output, number of elements, and a block size constant. The calling function 'add' calculates a default grid size and invokes the kernel. The second kernel, 'matmul_kernel', performs matrix multiplication. It takes pointers to matrices, dimensions, strides, a pointer for the result matrix, block size constants, group size, and a flag indicating exact block divisibility. The function 'matmul' prepares grid dimensions and calls the kernel.", - "description_2": "Use triton language to create an element-wise addition kernel and a matrix multiplication kernel, with respective calling functions handling grid computations and kernel invocations.", - "difficulty": 4 - }, - { - "code": "import triton\nimport triton.language as tl\nfrom triton.language.extra.cuda import libdevice\nimport jax\nimport jax.numpy as jnp\nimport jax_triton as jt\nimport numpy as np\n\n@triton.jit\ndef add_kernel(\n x_ptr, # *Pointer* to first input vector\n y_ptr, # *Pointer* to second input vector\n length, # Length of input and output vectors.\n output_ptr, # *Pointer* to output vector\n BLOCK_SIZE: tl.constexpr,\n):\n pid = tl.program_id(axis=0)\n block_start = pid * BLOCK_SIZE\n offsets = block_start + tl.arange(0, BLOCK_SIZE)\n mask = offsets < length\n x = tl.load(x_ptr + offsets, mask=mask)\n y = tl.load(y_ptr + offsets, mask=mask)\n output = x + y\n tl.store(output_ptr + offsets, output, mask=mask)\n\n@triton.jit\ndef tanh_kernel(\n x_ptr, # *Pointer* to first input vector\n length, # Length of input and output vectors.\n output_ptr, # *Pointer* to output vector\n BLOCK_SIZE: tl.constexpr,\n):\n pid = tl.program_id(axis=0)\n block_start = pid * BLOCK_SIZE\n offsets = block_start + tl.arange(0, BLOCK_SIZE)\n mask = offsets < length\n x = tl.load(x_ptr + offsets, mask=mask)\n output = libdevice.tanh(x)\n tl.store(output_ptr + offsets, output, mask=mask)\n\ndef add(x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:\n out_shape = jax.ShapeDtypeStruct(shape=x.shape, dtype=x.dtype)\n grid = lambda meta: (triton.cdiv(x.size, meta['BLOCK_SIZE']),)\n return jt.triton_call(\n x,\n y,\n x.size,\n kernel=add_kernel,\n out_shape=out_shape,\n grid=grid,\n BLOCK_SIZE=8,\n )\n\ndef tanh(x: jnp.ndarray) -> jnp.ndarray:\n out_shape = jax.ShapeDtypeStruct(shape=x.shape, dtype=x.dtype)\n grid = lambda meta: (triton.cdiv(x.size, meta['BLOCK_SIZE']),)\n return jt.triton_call(\n x,\n x.size,\n kernel=tanh_kernel,\n out_shape=out_shape,\n grid=grid,\n BLOCK_SIZE=8,\n )\n", - "description_1": "Use triton language to implement two kernels: 'add_kernel' and 'tanh_kernel'. 'add_kernel' takes four arguments: two pointers to input vectors, the length of the vectors, and a pointer to the output vector. It adds the input vectors element-wise and stores the result in the output vector. 'tanh_kernel' takes three arguments: a pointer to the input vector, the length of the vector, and a pointer to the output vector. It applies the hyperbolic tangent function to each element of the input vector and stores the result in the output vector. Both kernels use a BLOCK_SIZE parameter to determine the size of data each program processes, and they use masks to handle out-of-bounds memory accesses. The kernels are called using the 'jt.triton_call' function, which requires specifying the output shape and grid configuration.", - "description_2": "Use triton language to create an 'add_kernel' that performs element-wise addition of two input vectors and a 'tanh_kernel' that applies the tanh function to an input vector. Both kernels should handle out-of-bounds accesses using masks and be called with 'jt.triton_call' specifying output shape and grid.", - "difficulty": 2 - }, - { - "code": "import triton\nimport math\n\n@triton.jit\ndef kernel(x_ptr, x_size, **META):\n BLOCK_SIZE = META['BLOCK_SIZE']\n # Kernel implementation here\n\ndef matmul248_kernel_config_pruner(configs, nargs):\n \"\"\"\n The main purpose of this function is to shrink BLOCK_SIZE_* when the corresponding dimension is smaller.\n \"\"\"\n m = max(2 ** int(math.ceil(math.log2(nargs[\"M\"]))), 16)\n n = max(2 ** int(math.ceil(math.log2(nargs[\"N\"]))), 16)\n k = max(2 ** int(math.ceil(math.log2(nargs[\"K\"]))), 16)\n\n used = set()\n for config in configs:\n block_size_m = min(m, config.kwargs[\"BLOCK_SIZE_M\"])\n block_size_n = min(n, config.kwargs[\"BLOCK_SIZE_N\"])\n block_size_k = min(k, config.kwargs[\"BLOCK_SIZE_K\"])\n group_size_m = config.kwargs[\"GROUP_SIZE_M\"]\n\n if (\n block_size_m,\n block_size_n,\n block_size_k,\n group_size_m,\n config.num_stages,\n config.num_warps,\n ) in used:\n continue\n\n used.add(\n (\n block_size_m,\n block_size_n,\n block_size_k,\n group_size_m,\n config.num_stages,\n config.num_warps,\n )\n )\n yield triton.Config(\n {\n \"BLOCK_SIZE_M\": block_size_m,\n \"BLOCK_SIZE_N\": block_size_n,\n \"BLOCK_SIZE_K\": block_size_k,\n \"GROUP_SIZE_M\": group_size_m,\n },\n num_stages=config.num_stages,\n num_warps=config.num_warps,\n )\n", - "description_1": "Use triton language to define a kernel function 'kernel' with two parameters: x_ptr (pointer to data) and x_size (size of data). The kernel uses a meta-parameter 'BLOCK_SIZE' for its operation. Additionally, implement a function 'matmul248_kernel_config_pruner' that prunes kernel configurations based on the dimensions M, N, and K, ensuring BLOCK_SIZE_* values are appropriate for the given dimensions.", - "description_2": "Use triton language to create a kernel with parameters for data pointer and size, utilizing a BLOCK_SIZE meta-parameter. Implement a configuration pruner function to adjust kernel configurations based on input dimensions.", - "difficulty": 3 - }, - { - "code": "import triton\nimport triton.language as tl\nimport math\n\n\n@triton.jit\ndef rotate_half_kernel(\n qk_seq_ptr,\n position_ids_ptr,\n qk_seq_stride,\n position_ids_batch_stride,\n seq_len,\n HEAD_DIM: tl.constexpr,\n BLOCK_HEIGHT: tl.constexpr,\n BLOCK_WIDTH: tl.constexpr,\n INV_BASE: tl.constexpr,\n):\n # qk_seq_ptr: (bsz, seq_len, 2, num_heads, head_dim) -- OK to be discontinuous in 2nd dimension.\n # position ids: (bsz, seq_len) -- must be contiguous in the last dimension.\n\n HALF_HEAD: tl.constexpr = HEAD_DIM // 2\n STEPS_PER_ROW: tl.constexpr = HALF_HEAD // BLOCK_WIDTH\n\n batch_seq = tl.program_id(axis=0)\n row_blk_x_col_blk = tl.program_id(axis=1)\n\n row_blk = row_blk_x_col_blk // STEPS_PER_ROW\n row = row_blk * BLOCK_HEIGHT\n if BLOCK_WIDTH < HALF_HEAD:\n col_blk = row_blk_x_col_blk % STEPS_PER_ROW\n col = col_blk * BLOCK_WIDTH\n else:\n col: tl.constexpr = 0\n\n # A block will never cross a sequence boundary, which simplifies things a lot.\n batch = batch_seq // seq_len\n seq = batch_seq % seq_len\n position_id = tl.load(position_ids_ptr + batch * position_ids_batch_stride + seq)\n # As sometimes happens, just calculating this on the fly is faster than loading it from memory.\n # Use `tl.libdevice.exp` rather than `tl.exp` -- the latter is less accurate.\n freq = (\n tl.libdevice.exp((col + tl.arange(0, BLOCK_WIDTH)).to(tl.float32) * INV_BASE)\n * position_id\n )\n cos = tl.cos(freq).to(tl.float32)\n sin = tl.sin(freq).to(tl.float32)\n\n col_offsets: tl.constexpr = tl.arange(0, BLOCK_WIDTH)\n embed_offsets = (row * HEAD_DIM + col) + col_offsets\n x_ptrs = (qk_seq_ptr + batch_seq * qk_seq_stride) + embed_offsets\n\n for k in range(0, BLOCK_HEIGHT):\n x = tl.load(x_ptrs).to(tl.float32)\n y = tl.load(x_ptrs + HALF_HEAD).to(tl.float32)\n out_x = x * cos - y * sin\n tl.store(x_ptrs, out_x)\n out_y = x * sin + y * cos\n tl.store(x_ptrs + HALF_HEAD, out_y)\n x_ptrs += HEAD_DIM\n\n\ndef triton_rotate_half_(qk, position_ids, config=None):\n batch_size, seq_len, qandk, num_heads, head_dim = qk.shape\n\n config = config or {\n \"BLOCK_HEIGHT\": 1,\n \"BLOCK_WIDTH\": min(128, head_dim // 2),\n \"num_warps\": 1,\n }\n config[\"BLOCK_HEIGHT\"] = min(config[\"BLOCK_HEIGHT\"], 2 * num_heads)\n\n assert qk.stride(3) == head_dim\n assert qk.stride(4) == 1\n assert position_ids.shape == (batch_size, seq_len)\n assert (\n position_ids.stride(1) == 1\n ), \"position_ids must be contiguous in the last dimension\"\n assert (2 * num_heads) % config[\n \"BLOCK_HEIGHT\"\n ] == 0, f'number of rows not evenly divisible by {config[\"BLOCK_HEIGHT\"]}'\n assert (head_dim // 2) % config[\n \"BLOCK_WIDTH\"\n ] == 0, f'number of columns ({head_dim // 2}) not evenly divisible by {config[\"BLOCK_WIDTH\"]}'\n\n qk_by_seq = qk.view(batch_size * seq_len, 2 * num_heads * head_dim)\n grid = (\n qk_by_seq.shape[0],\n (2 * num_heads // config[\"BLOCK_HEIGHT\"])\n * (head_dim // 2 // config[\"BLOCK_WIDTH\"]),\n )\n\n BASE = 10000.0\n\n rotate_half_kernel[grid](\n qk_by_seq,\n position_ids,\n qk_by_seq.stride(0),\n position_ids.stride(0),\n seq_len,\n HEAD_DIM=head_dim,\n BLOCK_HEIGHT=config[\"BLOCK_HEIGHT\"],\n BLOCK_WIDTH=config[\"BLOCK_WIDTH\"],\n INV_BASE=-2.0 * math.log(BASE) / head_dim,\n num_warps=config[\"num_warps\"],\n )\n", - "description_1": "Use triton language to implement a kernel function rotate_half_kernel with 10 parameters to perform a rotation operation on input sequences based on positional IDs. The function is invoked by triton_rotate_half_ which takes 3 parameters: qk (5D tensor), position_ids (2D tensor), and an optional config dictionary. The function reshapes and prepares the input data, sets up a computation grid, and calls the Triton kernel to apply the rotation in-place.", - "description_2": "Use triton language to create a kernel that rotates input sequences with given position IDs and invoke it using a helper function with tensor inputs and configuration.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef fusedmatmul_248_kernel(\n a_ptr,\n c_ptr,\n b1_ptr,\n scales1_ptr,\n zeros1_ptr,\n g1_ptr,\n b2_ptr,\n scales2_ptr,\n zeros2_ptr,\n g2_ptr,\n M,\n N,\n K,\n bits,\n maxq,\n stride_am,\n stride_ak,\n stride_bk,\n stride_bn,\n stride_cm,\n stride_cn,\n stride_scales,\n stride_zeros,\n BLOCK_SIZE_M: tl.constexpr,\n BLOCK_SIZE_N: tl.constexpr,\n BLOCK_SIZE_K: tl.constexpr,\n GROUP_SIZE_M: tl.constexpr,\n):\n \"\"\"\n Computes: C = silu(A * B1) * (A * B2)\n A is of shape (M, K) float16\n B is of shape (K//8, N) int32\n C is of shape (M, N) float16\n scales is of shape (1, N) float16\n zeros is of shape (1, N//8) int32\n \"\"\"\n infearure_per_bits = 32 // bits\n\n pid = tl.program_id(axis=0)\n num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)\n num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)\n num_pid_k = tl.cdiv(K, BLOCK_SIZE_K)\n num_pid_in_group = GROUP_SIZE_M * num_pid_n\n group_id = pid // num_pid_in_group\n first_pid_m = group_id * GROUP_SIZE_M\n group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)\n pid_m = first_pid_m + (pid % group_size_m)\n pid_n = (pid % num_pid_in_group) // group_size_m\n\n offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n offs_k = tl.arange(0, BLOCK_SIZE_K)\n a_ptrs = a_ptr + (\n offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak\n )\n a_mask = offs_am[:, None] < M\n b1_ptrs = b1_ptr + (\n (offs_k[:, None] // infearure_per_bits) * stride_bk\n + offs_bn[None, :] * stride_bn\n )\n b2_ptrs = b2_ptr + (\n (offs_k[:, None] // infearure_per_bits) * stride_bk\n + offs_bn[None, :] * stride_bn\n )\n g1_ptrs = g1_ptr + offs_k\n g2_ptrs = g2_ptr + offs_k\n scales1_ptrs = scales1_ptr + offs_bn[None, :]\n scales2_ptrs = scales2_ptr + offs_bn[None, :]\n zeros1_ptrs = zeros1_ptr + (offs_bn[None, :] // infearure_per_bits)\n zeros2_ptrs = zeros2_ptr + (offs_bn[None, :] // infearure_per_bits)\n\n shifter = (offs_k % infearure_per_bits) * bits\n zeros_shifter = (offs_bn % infearure_per_bits) * bits\n accumulator1 = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n accumulator2 = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n for k in range(0, num_pid_k):\n g1_idx = tl.load(g1_ptrs)\n g2_idx = tl.load(g2_ptrs)\n\n scales1 = tl.load(\n scales1_ptrs + g1_idx[:, None] * stride_scales\n )\n scales2 = tl.load(scales2_ptrs + g2_idx[:, None] * stride_scales)\n\n zeros1 = tl.load(\n zeros1_ptrs + g1_idx[:, None] * stride_zeros\n )\n zeros1 = (zeros1 >> zeros_shifter[None, :]) & maxq\n zeros1 = zeros1 + 1\n\n zeros2 = tl.load(\n zeros2_ptrs + g2_idx[:, None] * stride_zeros\n )\n zeros2 = (zeros2 >> zeros_shifter[None, :]) & maxq\n zeros2 = zeros2 + 1\n\n a = tl.load(a_ptrs, mask=a_mask, other=0.0)\n b1 = tl.load(b1_ptrs)\n b2 = tl.load(b2_ptrs)\n\n b1 = (b1 >> shifter[:, None]) & maxq\n b1 = (b1 - zeros1) * scales1\n accumulator1 += tl.dot(a, b1)\n\n b2 = (b2 >> shifter[:, None]) & maxq\n b2 = (b2 - zeros2) * scales2\n accumulator2 += tl.dot(a, b2)\n\n a_ptrs += BLOCK_SIZE_K\n b1_ptrs += (BLOCK_SIZE_K // infearure_per_bits) * stride_bk\n b2_ptrs += (BLOCK_SIZE_K // infearure_per_bits) * stride_bk\n g1_ptrs += BLOCK_SIZE_K\n g2_ptrs += BLOCK_SIZE_K\n\n accumulator1 = silu(accumulator1)\n c = accumulator1 * accumulator2\n c = c.to(tl.float16)\n c_ptrs = c_ptr + stride_cm * offs_am[:, None] + stride_cn * offs_bn[None, :]\n c_mask = (offs_am[:, None] < M) & (offs_bn[None, :] < N)\n tl.store(c_ptrs, c, mask=c_mask)\n\n@triton.jit\ndef silu(x):\n return x * tl.sigmoid(x)\n\n\nclass QuantLlamaMLP(nn.Module):\n def triton_llama_mlp(self, x):\n with torch.cuda.device(x.device):\n out_shape = x.shape[:-1] + (self.intermediate_size,)\n x = x.reshape(-1, x.shape[-1])\n M, K = x.shape\n N = self.intermediate_size\n c = torch.empty((M, N), device=x.device, dtype=torch.float16)\n grid = lambda META: (\n triton.cdiv(M, META[\"BLOCK_SIZE_M\"])\n * triton.cdiv(N, META[\"BLOCK_SIZE_N\"]),\n )\n fusedmatmul_248_kernel[grid](\n x,\n c,\n self.gate_proj_qweight,\n self.gate_proj_scales,\n self.gate_proj_qzeros,\n self.gate_proj_g_idx,\n self.up_proj_qweight,\n self.up_proj_scales,\n self.up_proj_qzeros,\n self.up_proj_g_idx,\n M,\n N,\n K,\n self.bits,\n self.maxq,\n x.stride(0),\n x.stride(1),\n self.gate_proj_qweight.stride(0),\n self.gate_proj_qweight.stride(1),\n c.stride(0),\n c.stride(1),\n self.gate_proj_scales.stride(0),\n self.gate_proj_qzeros.stride(0),\n )\n c = c.reshape(out_shape)\n return c\n", - "description_1": "Use triton language to implement a fused matrix multiplication kernel called 'fusedmatmul_248_kernel', which computes the product of an input matrix A with two other matrices B1 and B2, applies the SiLU activation function, and returns the result. The kernel is optimized using various block sizes and operates on input matrices of specific data types and dimensions. There is also a helper function 'silu' to apply the SiLU activation function on a tensor.", - "description_2": "Use triton language to create a fused matrix multiplication operation with SiLU activation for quantized weights, optimizing kernel execution with block size configurations.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef matmul_248_kernel(\n a_ptr,\n b_ptr,\n c_ptr,\n scales_ptr,\n zeros_ptr,\n g_ptr,\n M,\n N,\n K,\n bits,\n maxq,\n stride_am,\n stride_ak,\n stride_bk,\n stride_bn,\n stride_cm,\n stride_cn,\n stride_scales,\n stride_zeros,\n BLOCK_SIZE_M: tl.constexpr,\n BLOCK_SIZE_N: tl.constexpr,\n BLOCK_SIZE_K: tl.constexpr,\n GROUP_SIZE_M: tl.constexpr,\n):\n \"\"\"\n Compute the matrix multiplication C = A x B.\n A is of shape (M, K) float16\n B is of shape (K//8, N) int32\n C is of shape (M, N) float16\n scales is of shape (G, N) float16\n zeros is of shape (G, N) float16\n g_ptr is of shape (K) int32\n \"\"\"\n infearure_per_bits = 32 // bits\n\n pid = tl.program_id(axis=0)\n num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)\n num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)\n num_pid_k = tl.cdiv(K, BLOCK_SIZE_K)\n num_pid_in_group = GROUP_SIZE_M * num_pid_n\n group_id = pid // num_pid_in_group\n first_pid_m = group_id * GROUP_SIZE_M\n group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)\n pid_m = first_pid_m + (pid % group_size_m)\n pid_n = (pid % num_pid_in_group) // group_size_m\n\n offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n offs_k = tl.arange(0, BLOCK_SIZE_K)\n a_ptrs = a_ptr + (\n offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak\n ) # (BLOCK_SIZE_M, BLOCK_SIZE_K)\n a_mask = offs_am[:, None] < M\n # b_ptrs is set up such that it repeats elements along the K axis 8 times\n b_ptrs = b_ptr + (\n (offs_k[:, None] // infearure_per_bits) * stride_bk\n + offs_bn[None, :] * stride_bn\n ) # (BLOCK_SIZE_K, BLOCK_SIZE_N)\n g_ptrs = g_ptr + offs_k\n # shifter is used to extract the N bits of each element in the 32-bit word from B\n scales_ptrs = scales_ptr + offs_bn[None, :]\n zeros_ptrs = zeros_ptr + (offs_bn[None, :] // infearure_per_bits)\n\n shifter = (offs_k % infearure_per_bits) * bits\n zeros_shifter = (offs_bn % infearure_per_bits) * bits\n accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n\n for k in range(0, num_pid_k):\n g_idx = tl.load(g_ptrs)\n\n # Fetch scales and zeros; these are per-outfeature and thus reused in the inner loop\n scales = tl.load(\n scales_ptrs + g_idx[:, None] * stride_scales\n ) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)\n zeros = tl.load(\n zeros_ptrs + g_idx[:, None] * stride_zeros\n ) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)\n\n zeros = (zeros >> zeros_shifter[None, :]) & maxq\n zeros = zeros + 1\n\n a = tl.load(a_ptrs, mask=a_mask, other=0.0) # (BLOCK_SIZE_M, BLOCK_SIZE_K)\n b = tl.load(b_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated\n\n # Now we need to unpack b (which is N-bit values) into 32-bit values\n b = (b >> shifter[:, None]) & maxq # Extract the N-bit values\n b = (b - zeros) * scales # Scale and shift\n\n accumulator += tl.dot(a, b)\n a_ptrs += BLOCK_SIZE_K\n b_ptrs += (BLOCK_SIZE_K // infearure_per_bits) * stride_bk\n g_ptrs += BLOCK_SIZE_K\n\n c_ptrs = c_ptr + stride_cm * offs_am[:, None] + stride_cn * offs_bn[None, :]\n c_mask = (offs_am[:, None] < M) & (offs_bn[None, :] < N)\n tl.store(c_ptrs, accumulator, mask=c_mask)\n\n@triton.jit\ndef transpose_matmul_248_kernel(\n a_ptr,\n b_ptr,\n c_ptr,\n scales_ptr,\n zeros_ptr,\n g_ptr,\n M,\n N,\n K,\n bits,\n maxq,\n stride_am,\n stride_ak,\n stride_bk,\n stride_bn,\n stride_cm,\n stride_cn,\n stride_scales,\n stride_zeros,\n BLOCK_SIZE_M: tl.constexpr,\n BLOCK_SIZE_N: tl.constexpr,\n BLOCK_SIZE_K: tl.constexpr,\n GROUP_SIZE_M: tl.constexpr,\n):\n \"\"\"\n Compute the matrix multiplication C = A x B.\n A is of shape (M, N) float16\n B is of shape (K//8, N) int32\n C is of shape (M, K) float16\n scales is of shape (G, N) float16\n zeros is of shape (G, N) float16\n g_ptr is of shape (K) int32\n \"\"\"\n infearure_per_bits = 32 // bits\n\n pid = tl.program_id(axis=0)\n num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)\n num_pid_k = tl.cdiv(K, BLOCK_SIZE_K)\n num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)\n num_pid_in_group = GROUP_SIZE_M * num_pid_k\n group_id = pid // num_pid_in_group\n first_pid_m = group_id * GROUP_SIZE_M\n group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)\n pid_m = first_pid_m + (pid % group_size_m)\n pid_k = (pid % num_pid_in_group) // group_size_m\n\n offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_bk = pid_k * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)\n offs_n = tl.arange(0, BLOCK_SIZE_N)\n a_ptrs = a_ptr + (\n offs_am[:, None] * stride_am + offs_n[None, :] * stride_ak\n ) # (BLOCK_SIZE_M, BLOCK_SIZE_N)\n a_mask = offs_am[:, None] < M\n # b_ptrs is set up such that it repeats elements along the K axis 8 times\n b_ptrs = b_ptr + (\n (offs_bk[:, None] // infearure_per_bits) * stride_bk\n + offs_n[None, :] * stride_bn\n ) # (BLOCK_SIZE_K, BLOCK_SIZE_N)\n g_ptrs = g_ptr + offs_bk\n g_idx = tl.load(g_ptrs)\n\n # shifter is used to extract the N bits of each element in the 32-bit word from B\n scales_ptrs = scales_ptr + offs_n[None, :] + g_idx[:, None] * stride_scales\n zeros_ptrs = (\n zeros_ptr\n + (offs_n[None, :] // infearure_per_bits)\n + g_idx[:, None] * stride_zeros\n )\n\n shifter = (offs_bk % infearure_per_bits) * bits\n zeros_shifter = (offs_n % infearure_per_bits) * bits\n accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_K), dtype=tl.float32)\n\n for n in range(0, num_pid_n):\n # Fetch scales and zeros; these are per-outfeature and thus reused in the inner loop\n scales = tl.load(scales_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)\n zeros = tl.load(zeros_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)\n\n zeros = (zeros >> zeros_shifter[None, :]) & maxq\n zeros = zeros + 1\n\n a = tl.load(a_ptrs, mask=a_mask, other=0.0) # (BLOCK_SIZE_M, BLOCK_SIZE_N)\n b = tl.load(b_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated\n\n # Now we need to unpack b (which is N-bit values) into 32-bit values\n b = (b >> shifter[:, None]) & maxq # Extract the N-bit values\n b = (b - zeros) * scales # Scale and shift\n b = tl.trans(b)\n\n accumulator += tl.dot(a, b)\n a_ptrs += BLOCK_SIZE_N\n b_ptrs += BLOCK_SIZE_N\n scales_ptrs += BLOCK_SIZE_N\n zeros_ptrs += BLOCK_SIZE_N // infearure_per_bits\n\n c_ptrs = c_ptr + stride_cm * offs_am[:, None] + stride_cn * offs_bk[None, :]\n c_mask = (offs_am[:, None] < M) & (offs_bk[None, :] < K)\n tl.store(c_ptrs, accumulator, mask=c_mask)\n\ndef matmul248(input, qweight, scales, qzeros, g_idx, bits, maxq):\n with torch.cuda.device(input.device):\n output = torch.empty(\n (input.shape[0], qweight.shape[1]), device=input.device, dtype=torch.float16\n )\n grid = lambda META: (\n triton.cdiv(input.shape[0], META[\"BLOCK_SIZE_M\"])\n * triton.cdiv(qweight.shape[1], META[\"BLOCK_SIZE_N\"]),\n )\n matmul_248_kernel[grid](\n input,\n qweight,\n output,\n scales,\n qzeros,\n g_idx,\n input.shape[0],\n qweight.shape[1],\n input.shape[1],\n bits,\n maxq,\n input.stride(0),\n input.stride(1),\n qweight.stride(0),\n qweight.stride(1),\n output.stride(0),\n output.stride(1),\n scales.stride(0),\n qzeros.stride(0),\n )\n return output\n\ndef transpose_matmul248(input, qweight, scales, qzeros, g_idx, bits, maxq):\n with torch.cuda.device(input.device):\n output_dim = (qweight.shape[0] * 32) // bits\n output = torch.empty(\n (input.shape[0], output_dim), device=input.device, dtype=torch.float16\n )\n grid = lambda META: (\n triton.cdiv(input.shape[0], META[\"BLOCK_SIZE_M\"])\n * triton.cdiv(output_dim, META[\"BLOCK_SIZE_K\"]),\n )\n transpose_matmul_248_kernel[grid](\n input,\n qweight,\n output,\n scales,\n qzeros,\n g_idx,\n input.shape[0],\n qweight.shape[1],\n output_dim,\n bits,\n maxq,\n input.stride(0),\n input.stride(1),\n qweight.stride(0),\n qweight.stride(1),\n output.stride(0),\n output.stride(1),\n scales.stride(0),\n qzeros.stride(0),\n )\n return output\n", - "description_1": "Use triton language to implement two matrix multiplication kernels: 'matmul_248_kernel' and 'transpose_matmul_248_kernel'. The first kernel computes C = A x B where A is a float16 matrix of shape (M, K), B is an int32 matrix of shape (K//8, N), and C is a float16 matrix of shape (M, N). The second kernel computes C = A x B where A is a float16 matrix of shape (M, N), B is an int32 matrix of shape (K//8, N), and C is a float16 matrix of shape (M, K). Both kernels use additional parameters for scaling and zero-point adjustments, and they are optimized for specific block sizes and group sizes.", - "description_2": "Use triton language to create optimized matrix multiplication kernels for quantized matrices, handling scaling and zero-point adjustments, with specific block and group size configurations.", - "difficulty": 4 - }, - { - "code": "import triton\nimport triton.language as tl\nimport torch\n\n@triton.jit\ndef kernel_fn1(X, Y, stride_x, stride_y, BLOCK: tl.constexpr):\n pid = tl.program_id(0)\n offset = pid * BLOCK\n x = tl.load(X + offset * stride_x)\n y = tl.load(Y + offset * stride_y)\n result = x + y\n tl.store(X + offset * stride_x, result)\n\n@triton.jit\ndef kernel_fn2(Z, stride_z, BLOCK: tl.constexpr):\n pid = tl.program_id(0)\n offset = pid * BLOCK\n z = tl.load(Z + offset * stride_z)\n result = z * 2\n tl.store(Z + offset * stride_z, result)\n\ndef call_kernel1(X, Y, stride_x, stride_y, grid_size, BLOCK):\n grid = (grid_size,)\n kernel_fn1[(grid,)](X, Y, stride_x, stride_y, BLOCK)\n\ndef call_kernel2(Z, stride_z, grid_size, BLOCK):\n grid = (grid_size,)\n kernel_fn2[(grid,)](Z, stride_z, BLOCK)\n", - "description_1": "Use triton language to create two kernels: kernel_fn1 adds elements from two input arrays X and Y with respective strides and stores the result back into X. It requires BLOCK as a compile-time constant, and is parallelized over a single grid dimension. kernel_fn2 doubles the elements from an input array Z with stride and stores the result back into Z. It also requires BLOCK as a compile-time constant, and is parallelized over a single grid dimension. There are two corresponding functions, call_kernel1 and call_kernel2, that set up the grid and call these kernels.", - "description_2": "Use triton language to create a kernel that adds elements from two arrays with strides and stores the result, and another kernel that doubles elements from an array with a stride and stores the result. Each kernel uses a BLOCK size as a compile-time constant and operates over a single grid dimension.", - "difficulty": 2 - }, - { - "code": "import triton\nimport torch\n\n# Example of a Triton kernel\n@triton.jit\ndef example_kernel(input_ptr, output_ptr, BLOCK_SIZE: tl.constexpr):\n pid = tl.program_id(0)\n block_start = pid * BLOCK_SIZE\n offsets = block_start + tl.arange(0, BLOCK_SIZE)\n mask = offsets < input_ptr.shape[0]\n x = tl.load(input_ptr + offsets, mask=mask)\n tl.store(output_ptr + offsets, x * x, mask=mask)\n\n# Function to call the Triton kernel\ndef call_example_kernel(input_tensor, output_tensor, block_size):\n grid = lambda meta: (triton.cdiv(input_tensor.shape[0], block_size),)\n example_kernel[grid](input_tensor, output_tensor, BLOCK_SIZE=block_size)\n", - "description_1": "Use triton language to create a kernel that squares the elements of an input tensor and stores the result in an output tensor. The kernel should be executed in blocks defined by a block size.", - "description_2": "Use triton language to define a kernel that processes an input tensor by squaring its values, then stores the squared values in an output tensor, with execution managed in specified block sizes.", - "difficulty": 2 - }, - { - "code": "import triton\nimport triton.language as tl\nimport torch\n\n@triton.jit\ndef my_kernel(X, Y, Z, N):\n pid = tl.program_id(axis=0)\n offset = pid * N\n X = X + offset\n Y = Y + offset\n Z = Z + offset\n # ... Kernel computations ...\n\ndef my_kernel_wrapper(X, Y, Z, N):\n grid = (N,)\n my_kernel[grid](X, Y, Z, N)\n\n# Kernel call\nx = torch.tensor([1.0, 2.0, 3.0], device='cuda')\ny = torch.tensor([4.0, 5.0, 6.0], device='cuda')\nz = torch.tensor([0.0, 0.0, 0.0], device='cuda')\nmy_kernel_wrapper(x, y, z, 3)\n", - "description_1": "Use triton language to define a kernel my_kernel with 4 parameters: X, Y, Z (pointers to data), and N (size). The kernel computes offsets based on the program_id. Additionally, define a wrapper my_kernel_wrapper that sets the grid size and calls my_kernel with torch tensors x, y, z as inputs.", - "description_2": "Use triton language to define and call a kernel for vector operations on GPU using torch tensors.", - "difficulty": 2 - }, - { - "code": "import triton\n\n@triton.jit\ndef get_offset_for_next_block(loop_iter, col_indices, total_blocks, SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK):\n cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE\n cur_block = tl.load(col_indices + cur_block_idx, eviction_policy=\"evict_last\")\n next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy=\"evict_last\", mask=cur_block_idx + 1 < total_blocks)\n needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0\n jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK\n\n offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK\n return offset\n\n@triton.jit\ndef forward_inner(\n q, K_block_ptr, V_block_ptr, Q_LEN, KV_LEN,\n acc, l_i, m_i,\n off_z, off_h, offs_m, offs_n,\n kv_indices, kv_num_blocks,\n block_n_start, block_n_end,\n MATMUL_PRECISION,\n IS_FULL_BLOCKS,\n):\n SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N)\n RCP_LN2: tl.constexpr = 1.44269504\n\n if PRESCALE_QK:\n q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)\n\n for start_n in range(block_n_start, block_n_end):\n if IS_DIVISIBLE:\n acc, l_i, m_i = forward_block_mn(\n q, K_block_ptr, V_block_ptr, Q_LEN, KV_LEN,\n acc, l_i, m_i,\n off_z, off_h, offs_m, offs_n,\n MATMUL_PRECISION, RCP_LN2,\n IS_FULL_BLOCKS,\n )\n else:\n acc, l_i, m_i = forward_block_mn(\n q, K_block_ptr, V_block_ptr, Q_LEN, KV_LEN,\n acc, l_i, m_i,\n off_z, off_h, offs_m, offs_n,\n MATMUL_PRECISION, RCP_LN2,\n IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True,\n )\n\n offset = get_offset_for_next_block(\n start_n, kv_indices, kv_num_blocks,\n SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N\n )\n\n V_block_ptr = tl.advance(V_block_ptr, (offset, 0))\n K_block_ptr = tl.advance(K_block_ptr, (0, offset))\n\n offs_n = offs_n + offset\n\n return acc, l_i, m_i\n\n@triton.jit\ndef forward_block_mn(\n q, K_block_ptr, V_block_ptr, Q_LEN, KV_LEN,\n acc, l_i, m_i,\n off_z, off_h, offs_m, offs_n,\n MATMUL_PRECISION, RCP_LN2,\n IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False,\n):\n if IS_DIVISIBLE:\n k = tl.load(K_block_ptr)\n else:\n k = tl.load(K_block_ptr, boundary_check=(1,), padding_option = \"zero\")\n qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION)\n if not PRESCALE_QK:\n qk *= SM_SCALE\n\n if CHECK_BLOCK_BOUNDARY:\n m = offs_m % Q_LEN\n n = offs_n % KV_LEN\n else:\n m = offs_m\n n = offs_n\n\n if CHECK_BLOCK_BOUNDARY:\n post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float(\"-inf\"))\n\n if not IS_FULL_BLOCKS:\n if CHECK_BLOCK_BOUNDARY:\n mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, float(\"-inf\"))\n post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float(\"-inf\"))\n\n if not PRESCALE_QK:\n post_mod_scores *= RCP_LN2\n\n m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1))\n if not ROWS_GUARANTEED_SAFE:\n masked_out_rows = (m_ij == float(\"-inf\"))\n m_ij_masked = tl.where(masked_out_rows, 0, m_ij)\n else:\n m_ij_masked = m_ij\n\n alpha = tl.math.exp2(m_i - m_ij_masked)\n p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None])\n\n l_i = l_i * alpha + tl.sum(p, 1)\n acc = acc * alpha[:, None]\n\n if IS_DIVISIBLE:\n v = tl.load(V_block_ptr)\n else:\n v = tl.load(V_block_ptr, boundary_check=(0,), padding_option = \"zero\")\n acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION)\n\n m_i = m_ij\n\n return acc, l_i, m_i\n", - "description_1": "Use triton language to implement a kernel for computing the next block offset in a loop, and kernels for forward pass in a neural network layer. The forward pass involves loading data, computing dot products, applying modifications, and updating accumulators.", - "description_2": "Use triton language to implement kernels for computing block offsets and performing forward pass operations in neural networks, including data loading, dot product computation, and accumulator updates.", - "difficulty": 4 - }, - { - "code": "import triton\nimport torch\n\n# Example Triton kernel\n@triton.jit\ndef example_kernel(X, Y, Z, BLOCK_SIZE: tl.constexpr):\n # Triton kernel code\n pass\n\n# Function to call the Triton kernel\ndef call_example_kernel(x, y, z, block_size):\n # Call the Triton kernel\n example_kernel[(1,)](x, y, z, BLOCK_SIZE=block_size)\n\n# Example usage\nx = torch.tensor([1, 2, 3])\ny = torch.tensor([4, 5, 6])\nz = torch.empty_like(x)\ncall_example_kernel(x, y, z, block_size=1024)\n", - "description_1": "Use triton language to define a kernel 'example_kernel' with 4 parameters: X, Y, Z, and BLOCK_SIZE. The kernel performs operations on input tensors X, Y, and Z with a specified block size. A function 'call_example_kernel' is used to invoke this kernel with PyTorch tensors and a block size.", - "description_2": "Use triton language to create a kernel for tensor operations with a block size parameter, and provide a function to call this kernel with PyTorch tensors.", - "difficulty": 1 - }, - { - "code": "import triton\nimport triton.language as tl\n\n# Promote to tensor\n@triton.jit\ndef promote_to_tensor(x):\n return x + tl.zeros((1,), tl.int1)\n\n# Floor division\n@triton.jit\ndef div_floor_integer(a, b):\n quot = a // b\n remainder = a % b\n fixed = tl.where(remainder != 0, quot - 1, quot)\n return tl.where((a < 0) != (b < 0), fixed, quot)\n\n# Remainder calculation\n@triton.jit\ndef remainder_integer(a, b):\n remainder = a % b\n return tl.where(remainder != 0 and ((a < 0) != (b < 0)), remainder + b, remainder)\n\n# Check if floating\n@triton.jit\ndef is_floating(x):\n return promote_to_tensor(x).dtype.is_floating()\n\n# Element-wise product accumulate\n@triton.jit\ndef _prod_accumulate(a, b):\n return a * b\n\n# Reduce product over axis\n@triton.jit\ndef prod(input, axis):\n return tl.reduce(input, axis, _prod_accumulate)\n\n# Minimum calculation\n@triton.jit\ndef minimum(a, b):\n mask = a < b\n if is_floating(a):\n mask |= a != a\n return tl.where(mask, a, b)\n\n# Maximum calculation\n@triton.jit\ndef maximum(a, b):\n mask = a > b\n if is_floating(a):\n mask |= a != a\n return tl.where(mask, a, b)\n\n# Reduce minimum over dimension\n@triton.jit\ndef min2(a, dim):\n return tl.reduce(a, dim, minimum)\n\n# Reduce maximum over dimension\n@triton.jit\ndef max2(a, dim):\n return tl.reduce(a, dim, maximum)\n\n# Minimum with index\n@triton.jit\ndef minimum_with_index(a_value, a_index, b_value, b_index):\n mask = a_value < b_value\n equal = a_value == b_value\n if is_floating(a_value):\n a_isnan = a_value != a_value\n b_isnan = b_value != b_value\n mask |= a_isnan and not b_isnan\n equal |= a_isnan and b_isnan\n\n mask |= equal & (a_index < b_index)\n return tl.where(mask, a_value, b_value), tl.where(mask, a_index, b_index)\n\n# Maximum with index\n@triton.jit\ndef maximum_with_index(a_value, a_index, b_value, b_index):\n mask = a_value > b_value\n equal = a_value == b_value\n if is_floating(a_value):\n a_isnan = a_value != a_value\n b_isnan = b_value != b_value\n mask |= a_isnan and not b_isnan\n equal |= a_isnan and b_isnan\n\n mask |= equal & (a_index < b_index)\n return tl.where(mask, a_value, b_value), tl.where(mask, a_index, b_index)\n\n# Minimum with index reduction\n@triton.jit\ndef min_with_index(value, index, dim):\n return tl.reduce((value, index), dim, minimum_with_index)\n\n# Maximum with index reduction\n@triton.jit\ndef max_with_index(value, index, dim):\n return tl.reduce((value, index), dim, maximum_with_index)\n\n# Welford reduction step\n@triton.jit\ndef welford_reduce(value, mean, m2, weight, first_iteration):\n if first_iteration:\n new_weight = tl.full(weight.shape, 1, weight.dtype)\n new_mean = value\n new_m2 = tl.zeros_like(m2)\n else:\n delta = value - mean\n new_weight = weight + 1\n new_mean = mean + delta / new_weight\n new_m2 = m2 + delta * (value - new_mean)\n return new_mean, new_m2, new_weight\n\n# Combine two Welford results\n@triton.jit\ndef welford_combine(mean_1, m2_1, weight_1, mean_2, m2_2, weight_2):\n delta = mean_2 - mean_1\n new_weight = weight_1 + weight_2\n w2_over_w = tl.where(new_weight == 0.0, 0.0, weight_2 / new_weight)\n return (\n mean_1 + delta * w2_over_w,\n m2_1 + m2_2 + delta * delta * weight_1 * w2_over_w,\n new_weight,\n )\n\n# Welford reduction\n@triton.jit\ndef welford(mean, m2, weight, dim):\n return tl.reduce((mean, m2, weight), dim, welford_combine)\n\n# Device assert and return\n@triton.jit\ndef device_assert_then(cond, msg, r):\n tl.device_assert(cond, msg)\n return r\n\n# Random integer in range\n@triton.jit\ndef randint64(seed, offset, low, high):\n r0, r1, r2, r3 = tl.randint4x(seed, offset)\n r0 = r0.to(tl.uint64)\n r1 = r1.to(tl.uint64)\n result = r0 | (r1 << 32)\n size = high - low\n result = result % size.to(tl.uint64)\n result = result.to(tl.int64) + low\n return result\n\n# Bitwise any operation\n@triton.jit\ndef _any_combine(a, b):\n return a | b\n\n# Reduce any over dimension\n@triton.jit\ndef any(a, dim):\n return tl.reduce(a, dim, _any_combine)\n\n# Binary search for bucketize\n@triton.jit\ndef bucketize_binary_search(\n values,\n offsets_ptr,\n indexing_dtype,\n right,\n OFFSETS_SIZE: int,\n BLOCK_SHAPE,\n):\n low = tl.zeros(BLOCK_SHAPE, dtype=indexing_dtype)\n high = tl.full(BLOCK_SHAPE, OFFSETS_SIZE, dtype=indexing_dtype)\n\n full_range = OFFSETS_SIZE + 1\n while full_range > 1:\n mid = (high + low) // 2\n mask = mid < OFFSETS_SIZE\n bucket_upper_bound = tl.load(offsets_ptr + mid, mask=mask, other=0.0)\n if right:\n is_above = values >= bucket_upper_bound\n else:\n is_above = values > bucket_upper_bound\n\n low = tl.where(is_above & mask, mid + 1, low)\n high = tl.where(is_above, high, mid)\n\n full_range = (full_range + 1) // 2\n\n return low\n\n# Pack value and flag\n@triton.jit\ndef pack_value_flag(\n value,\n flag,\n DTYPE_VALUE_AS_UINT: tl.constexpr,\n DTYPE_PACK: tl.constexpr,\n):\n DTYPE_VALUE_AS_UINT = tl.core._constexpr_to_value(DTYPE_VALUE_AS_UINT)\n bitwidth = DTYPE_VALUE_AS_UINT.primitive_bitwidth\n uv = value.to(DTYPE_VALUE_AS_UINT, bitcast=True).to(DTYPE_PACK)\n return flag.to(DTYPE_PACK) | (uv << bitwidth)\n\n# Unpack value\n@triton.jit\ndef unpack_value(\n pack,\n DTYPE_VALUE,\n DTYPE_VALUE_AS_UINT,\n):\n DTYPE_VALUE = tl.core._constexpr_to_value(DTYPE_VALUE)\n DTYPE_VALUE_AS_UINT = tl.core._constexpr_to_value(DTYPE_VALUE_AS_UINT)\n bitwidth = DTYPE_VALUE_AS_UINT.primitive_bitwidth\n value_uint = (pack >> bitwidth).to(DTYPE_VALUE_AS_UINT)\n return value_uint.to(DTYPE_VALUE, bitcast=True)\n\n# Unpack flag\n@triton.jit\ndef unpack_flag(pack, DTYPE_FLAG):\n return pack.to(DTYPE_FLAG)\n\n# Exclusive scan with decoupled lookback\n@triton.jit\ndef exclusive_scan_decoupled_lookback(\n scratch_base,\n block_value,\n index,\n combine_fn,\n DTYPE_VALUE_AS_UINT: tl.constexpr,\n DTYPE_PACK: tl.constexpr,\n):\n DTYPE_VALUE = block_value.dtype\n pack = pack_value_flag(\n block_value,\n tl.full(block_value.shape, 1, DTYPE_VALUE_AS_UINT),\n DTYPE_VALUE_AS_UINT,\n DTYPE_PACK,\n )\n if index > 0:\n tl.atomic_xchg(scratch_base + index, pack, sem=\"relaxed\")\n\n exclusive_prefix = tl.zeros([], DTYPE_VALUE)\n prefix_valid = False\n test_target = index - 1\n while test_target >= 0:\n flag = tl.full([], 0, DTYPE_VALUE_AS_UINT)\n while flag == 0:\n pack = tl.atomic_add(scratch_base + test_target, 0, sem=\"relaxed\")\n flag = unpack_flag(pack, DTYPE_VALUE_AS_UINT)\n\n value = unpack_value(pack, DTYPE_VALUE, DTYPE_VALUE_AS_UINT)\n if prefix_valid:\n exclusive_prefix = combine_fn(value, exclusive_prefix)\n else:\n exclusive_prefix = value\n prefix_valid = True\n\n if flag == 2:\n test_target = -1\n else:\n test_target = test_target - 1\n\n if prefix_valid:\n inclusive_prefix = combine_fn(exclusive_prefix, block_value)\n else:\n inclusive_prefix = block_value\n pack = pack_value_flag(\n inclusive_prefix,\n tl.full([], 2, DTYPE_VALUE_AS_UINT),\n DTYPE_VALUE_AS_UINT,\n DTYPE_PACK,\n )\n tl.atomic_xchg(scratch_base + index, pack, sem=\"relaxed\")\n return exclusive_prefix\n\n# Exclusive scan for 64-bit block\n@triton.jit\ndef exclusive_scan_decoupled_lookback_64(scratch_base, block_value, index, combine_fn):\n if index > 0:\n block_value_u64 = block_value.to(tl.uint64, bitcast=True)\n tl.store(scratch_base + 3 * index + 1, block_value_u64)\n tl.debug_barrier()\n flag_one = tl.full([], 1, tl.uint64)\n tl.atomic_xchg(scratch_base + 3 * index + 0, flag_one, sem=\"release\")\n\n exclusive_prefix = tl.zeros([], block_value.dtype)\n prefix_valid = False\n test_target = index - 1\n while test_target >= 0:\n flag = tl.full([], 0, tl.uint64)\n while flag == 0:\n flag = tl.atomic_add(scratch_base + 3 * test_target + 0, 0, sem=\"acquire\")\n\n value_u64 = tl.load(scratch_base + 3 * test_target + flag.to(tl.int32))\n value = value_u64.to(block_value.dtype, bitcast=True)\n if prefix_valid:\n exclusive_prefix = combine_fn(value, exclusive_prefix)\n else:\n exclusive_prefix = value\n prefix_valid = True\n\n if flag == 2:\n test_target = -1\n else:\n test_target = test_target - 1\n\n if prefix_valid:\n inclusive_prefix = combine_fn(exclusive_prefix, block_value)\n else:\n inclusive_prefix = block_value\n inclusive_prefix_u64 = inclusive_prefix.to(tl.uint64, bitcast=True)\n tl.store(scratch_base + 3 * index + 2, inclusive_prefix_u64)\n tl.debug_barrier()\n flag_two = tl.full([], 2, tl.uint64)\n tl.atomic_xchg(scratch_base + 3 * index + 0, flag_two, sem=\"release\")\n\n return exclusive_prefix\n\n# Frexp decomposition\n@triton.jit\ndef frexp(x):\n y = libdevice.ilogb(x) + 1\n exponent = tl.where(x == 0, 0, y)\n mantissa = tl.where(x == 0, 0, libdevice.ldexp(x, -y))\n return mantissa, exponent\n\n# Compare and swap with index\n@triton.jit\ndef _compare_and_swap_with_index(\n x,\n idxs,\n rnumel,\n flip,\n i: tl.constexpr,\n n_dims: tl.constexpr,\n stable: tl.constexpr,\n descending: tl.constexpr,\n):\n n_outer: tl.constexpr = x.numel >> n_dims\n shape: tl.constexpr = [n_outer * 2**i, 2, 2 ** (n_dims - i - 1)]\n\n idtype = tl.core.get_int_dtype(bitwidth=x.dtype.primitive_bitwidth, signed=True)\n\n y = tl.reshape(x, shape)\n iy = y.to(idtype, bitcast=True)\n right_mask = tl.arange(0, 2)[None, :, None].to(idtype)\n left_mask = (1 - right_mask).to(idtype)\n ileft = tl.broadcast_to(tl.sum(iy * left_mask, 1)[:, None, :], shape)\n iright = tl.broadcast_to(tl.sum(iy * right_mask, 1)[:, None, :], shape)\n ileft = tl.reshape(ileft, x.shape)\n iright = tl.reshape(iright, x.shape)\n left = ileft.to(x.dtype, bitcast=True)\n right = iright.to(x.dtype, bitcast=True)\n\n y_idx = tl.reshape(idxs, shape)\n left_idx = tl.broadcast_to(\n tl.sum(y_idx * left_mask.to(y_idx.dtype), 1)[:, None, :], shape\n )\n right_idx = tl.broadcast_to(\n tl.sum(y_idx * right_mask.to(y_idx.dtype), 1)[:, None, :], shape\n )\n left_idx = tl.reshape(left_idx, x.shape)\n right_idx = tl.reshape(right_idx, x.shape)\n\n if rnumel is None:\n left_valid_mask = tl.full(x.shape, True, tl.int1)\n right_valid_mask = tl.full(x.shape, True, tl.int1)\n else:\n left_valid_mask = left_idx < rnumel\n right_valid_mask = right_idx < rnumel\n\n ix = x.to(idtype, bitcast=True)\n\n if descending:\n cond = left < right\n else:\n cond = left > right\n\n if stable:\n cond = cond | ((left == right) & (left_idx > right_idx))\n\n cond = (right_valid_mask > left_valid_mask) | (\n (right_valid_mask == left_valid_mask) & cond\n )\n cond = cond ^ flip\n ret = ix ^ tl.where(cond, ileft ^ iright, tl.zeros_like(ix))\n new_idxs = idxs ^ tl.where(cond, left_idx ^ right_idx, tl.zeros_like(idxs))\n\n return ret.to(x.dtype, bitcast=True), new_idxs\n\n# Bitonic merge with index\n@triton.jit\ndef _bitonic_merge_with_index(\n x,\n idxs,\n rnumel,\n stage: tl.constexpr,\n alternating: tl.constexpr,\n n_dims: tl.constexpr,\n stable: tl.constexpr,\n descending: tl.constexpr,\n):\n n_outer: tl.constexpr = x.numel >> n_dims\n tl.static_assert(stage <= n_dims)\n if alternating:\n shape: tl.constexpr = [n_outer * 2 ** (n_dims - 1 - stage), 2, 2**stage]\n flip = tl.reshape(\n tl.broadcast_to(tl.arange(0, 2)[None, :, None], shape), x.shape\n )\n else:\n flip = False\n for i in tl.static_range(stage):\n x, idxs = _compare_and_swap_with_index(\n x, idxs, rnumel, flip, i + (n_dims - stage), n_dims, stable, descending\n )\n return x, idxs\n\n# Sort with index\n@triton.jit\ndef sort_with_index(\n x,\n idxs,\n rnumel,\n dim: tl.constexpr = None,\n stable: tl.constexpr = tl.constexpr(False),\n descending: tl.constexpr = tl.constexpr(False),\n):\n x, idxs = tl.broadcast(x, idxs)\n _dim: tl.constexpr = len(x.shape) - 1 if dim is None else dim\n tl.static_assert(\n _dim == len(x.shape) - 1, \"only minor dimension is currently supported\"\n )\n n_dims: tl.constexpr = _log2(x.shape[_dim])\n\n for i in tl.static_range(1, n_dims + 1):\n x, idxs = _bitonic_merge_with_index(\n x,\n idxs,\n rnumel,\n i,\n alternating=i < n_dims,\n n_dims=n_dims,\n stable=stable,\n descending=descending,\n )\n return x, idxs\n\n# Select one element based on mask\n@triton.jit\ndef select_one(x, mask, dim, keep_dims=False):\n idtype = tl.core.get_int_dtype(x.dtype.primitive_bitwidth, signed=False)\n ix = x.to(idtype, bitcast=True)\n iy = tl.sum(ix * mask, dim, keep_dims=keep_dims)\n return iy.to(x.dtype, bitcast=True)\n", - "description_1": "Use triton language to implement a set of utilities including tensor promotion, floor division, minimum and maximum calculations, statistical reduction (like Welford), sorting, and more. These functions operate on tensors and indices and cover element-wise operations, reduction operations, random generation, and scan operations.", - "description_2": "Use triton language to build tensor manipulation functions including tensor promotion, minimum/maximum calculations, sorting, and exclusive scan operations.", - "difficulty": 5 - }, - { - "code": "import triton\nimport torch\n\n# Triton kernel for element-wise addition\n@triton.jit\ndef add_kernel(X, Y, Z, N):\n pid = triton.program_id(0)\n block_size = 1024\n offset = pid * block_size + triton.arange(0, block_size)\n mask = offset < N\n x = triton.load(X + offset, mask=mask)\n y = triton.load(Y + offset, mask=mask)\n z = x + y\n triton.store(Z + offset, z, mask=mask)\n\n# Function to call the Triton kernel\ndef add_tensors(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:\n assert x.is_cuda and y.is_cuda\n assert x.shape == y.shape\n z = torch.empty_like(x)\n N = x.numel()\n grid = lambda meta: (triton.cdiv(N, meta['block_size']),)\n add_kernel[grid](x, y, z, N)\n return z\n", - "description_1": "Use triton language to implement an element-wise addition kernel for two CUDA tensors. The kernel is decorated with @triton.jit and takes four arguments: two input tensors X and Y, an output tensor Z, and the number of elements N. The kernel computes the sum of X and Y and stores the result in Z. The function add_tensors calls this kernel, ensuring the input tensors are on CUDA and have the same shape, and returns the result tensor.", - "description_2": "Use triton language to create a CUDA kernel for element-wise addition of two tensors, and implement a function to execute this kernel.", - "difficulty": 2 - }, - { - "code": "import torch\nimport triton\nfrom triton import language as tl\nfrom torch._library import capture_triton\n\n@triton.jit\ndef add_kernel(\n in_ptr0,\n in_ptr1,\n out_ptr,\n n_elements,\n BLOCK_SIZE: \"tl.constexpr\",\n):\n pid = tl.program_id(axis=0)\n block_start = pid * BLOCK_SIZE\n offsets = block_start + tl.arange(0, BLOCK_SIZE)\n mask = offsets < n_elements\n x = tl.load(in_ptr0 + offsets, mask=mask)\n y = tl.load(in_ptr1 + offsets, mask=mask)\n output = x + y\n tl.store(out_ptr + offsets, output, mask=mask)\n\ndef add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:\n output = torch.empty_like(x)\n n_elements = output.numel()\n\n def grid(meta):\n return (triton.cdiv(n_elements, meta[\"BLOCK_SIZE\"]),)\n\n capture_triton(add_kernel)[grid](x, y, output, n_elements, 16)\n return output\n", - "description_1": "Use triton language to define a kernel function 'add_kernel' that performs element-wise addition of two input tensors. The kernel takes five parameters: two input pointers, one output pointer, the number of elements, and a block size. It calculates the program ID, determines the block start, and computes offsets. It uses these offsets to load elements from the input tensors, adds them, and stores the result in the output tensor. The 'add' function wraps this kernel, prepares the output tensor, calculates the number of elements, defines a grid function for execution, and calls the kernel using 'capture_triton'.", - "description_2": "Use triton language to create a kernel for element-wise addition of two tensors, and wrap it in a function that prepares inputs and executes the kernel.", - "difficulty": 2 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\nfrom math import prod\nfrom torch.utils.flop_counter import register_flop_formula\nfrom torch.utils._triton import has_triton\nimport torch._functorch.config as config\n\n\nif has_triton():\n \n @triton.jit\n def relu_kernel_(inp_ptr, out_ptr, sz, BLOCK_SIZE: tl.constexpr):\n pid = tl.program_id(0)\n block = tl.arange(0, BLOCK_SIZE) + pid * BLOCK_SIZE\n msk = block < sz\n inp = tl.load(inp_ptr + block, mask=msk)\n relu = tl.where(inp < 0, 0, inp)\n tl.store(out_ptr + block, relu, mask=msk)\n\n @torch._library.triton_op(\"testac::triton_relu\", mutates_args=())\n def triton_relu(x: torch.Tensor) -> torch.Tensor:\n y = torch.empty_like(x)\n sz = y.numel()\n BLOCK_SIZE = 256\n grid = (triton.cdiv(sz, BLOCK_SIZE),)\n torch._library.capture_triton(relu_kernel_)[grid](x, y, sz, BLOCK_SIZE)\n return y\n\n @torch._library.triton_op(\"testac::triton_relu_backward\", mutates_args=())\n def triton_relu_backward(grad_out: torch.Tensor) -> torch.Tensor:\n grad_x = torch.empty_like(grad_out)\n sz = grad_out.numel()\n BLOCK_SIZE = 256\n grid = (triton.cdiv(sz, BLOCK_SIZE),)\n torch._library.capture_triton(relu_kernel_)[grid](grad_out, grad_x, sz, BLOCK_SIZE)\n return grad_x\n\n def f(x, ws):\n x = torch.ops.testac.triton_relu(x)\n for w in ws:\n x = torch.ops.testac.triton_relu(torch.mm(x, w))\n return x.sum()\n\n x = torch.randn(512, 512, requires_grad=True, device=\"cuda\")\n ws = [torch.randn(512, 512, requires_grad=True, device=\"cuda\") for _ in range(5)]\n\n def call():\n return f(x, ws)\n\n @register_flop_formula(\n [torch.ops.testac.triton_relu, torch.ops.testac.triton_relu_backward]\n )\n def triton_relu_flops(inp_shape, *args, **kwargs):\n return prod(inp_shape)\n", - "description_1": "Use triton language to implement a ReLU activation kernel and its backward operation. The kernel has four parameters: inp_ptr (input pointer), out_ptr (output pointer), sz (size of the tensor), and BLOCK_SIZE (constant expression defining block size). The function computes the ReLU activation element-wise on blocks of the input tensor and writes the result to the output tensor. The triton_relu function initializes an empty tensor y and configures the grid size for execution, then calls the ReLU kernel to perform activation on input tensor x, returning the result as y. Similarly, the triton_relu_backward function computes the backward pass using the same kernel.", - "description_2": "Use triton language to create a ReLU activation function and its gradient calculation. Configure grid and block sizes for efficient parallel execution on the input tensor.", - "difficulty": 2 - }, - { - "code": "import torch\nimport triton\n\n@triton.jit\ndef add_kernel(x, y, output, n_elements, BLOCK_SIZE: tl.constexpr):\n pass\n\n@triton.jit\ndef add_kernel_2d_autotuned(x, y, output, x_elements, y_elements):\n pass\n\n@triton.jit\ndef add_kernel_autotuned(x, y, output, n_elements):\n pass\n\n@triton.jit\ndef add_kernel_autotuned_weird_param_order(in_ptr0, in_ptr1, n_elements, out_ptr):\n pass\n\n@triton.jit\ndef add_kernel_with_optional_param(x, y, output, n_elements, ARGS_PASSED, BLOCK_SIZE: tl.constexpr):\n pass\n\n@triton.jit\ndef add_kernel_with_scaling(x, y, output, n_elements, scaling_factor, BLOCK_SIZE: tl.constexpr):\n pass\n\n@triton.jit\ndef mul2_inplace_kernel(x, n_elements, BLOCK_SIZE: tl.constexpr):\n pass\n\nclass Model(torch.nn.Module):\n def forward(self, x, y):\n output = torch.zeros_like(x)\n n_elements = output.numel()\n add_kernel[(n_elements,)](x, y, output, n_elements, BLOCK_SIZE=16)\n return output\n\nx = torch.randn(10, device='cuda')\ny = torch.randn(10, device='cuda')\nmodel = Model()\noutput = model(x, y)\n", - "description_1": "Use triton language to define several kernels for element-wise addition and multiplication with optional parameters and scaling. Implement a PyTorch model that uses these kernels to perform operations on input tensors.", - "description_2": "Use triton language to define kernels for element-wise operations and integrate them into a PyTorch model for tensor computations.", - "difficulty": 2 - }, - { - "code": "import torch\nimport triton\nfrom triton import language as tl\nfrom torch._C import _cuda_getCurrentRawStream as get_cuda_stream\nfrom torch._dynamo.utils import same\n\ndef autotune(configs, meta):\n def decorator(fn):\n return CachingAutotuner(\n fn,\n triton_meta=meta,\n configs=configs,\n save_cache_hook=False,\n mutated_arg_names=[\"in_out_ptr0\"],\n heuristic_type=HeuristicType.POINTWISE,\n )\n\n return decorator\n\n@autotune(\n configs=[\n triton.Config({\"XBLOCK\": 1}),\n triton.Config({\"XBLOCK\": 2}),\n ],\n meta={\n \"signature\": {0: \"*fp32\", 1: \"*fp32\", 2: \"i32\"},\n \"device\": DeviceProperties.create(torch.device(\"cuda\")),\n \"configs\": [instance_descriptor(divisible_by_16=(0, 1), equal_to_1=())],\n \"constants\": {},\n },\n)\n@triton.jit\ndef kernel(in_out_ptr0, in_ptr0, xnumel, XBLOCK: tl.constexpr):\n pid = tl.program_id(0)\n block_start = pid * XBLOCK\n offsets = block_start + tl.arange(0, XBLOCK)\n mask = offsets < xnumel\n x = tl.load(in_out_ptr0 + offsets, mask=mask, other=0.0)\n y = tl.load(in_ptr0 + offsets, mask=mask, other=0.0)\n output = x + y\n tl.store(in_out_ptr0 + offsets, output, mask=mask)\n\ndef test_kernel():\n xnumel = 384\n in0 = torch.rand(xnumel, device=\"cuda\", dtype=torch.float32)\n inout1 = torch.rand(xnumel, device=\"cuda\", dtype=torch.float32)\n inout2 = inout1.clone()\n\n stream0 = get_cuda_stream(0)\n kernel.run(inout1, in0, xnumel, grid=(xnumel//XBLOCK, ), stream=stream0)\n kernel.run(inout2, in0, xnumel, grid=(xnumel//XBLOCK, ), stream=stream0)\n\n assert same(inout1, inout2, tol=0.001, equal_nan=True), \"failed autotune with inplace kernel\"\n\n", - "description_1": "Use triton language to define an element-wise addition kernel, autotuned for multiple configurations to improve performance. The kernel takes two input pointers and a size, computes their sum, and writes the result to the first input pointer, with masking to handle out-of-bounds memory accesses.", - "description_2": "Use triton language to perform element-wise addition with autotuning, handling memory alignment and using CUDA streams for execution on GPU.", - "difficulty": 2 - }, - { - "code": "import torch\nimport triton\n\n# Triton kernel for matrix multiplication\n@triton.jit\ndef matmul_kernel(a_ptr, b_ptr, c_ptr, M, N, K, stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr):\n pid = tl.program_id(axis=0)\n # Compute the block row and column indices\n block_row = pid // (N // BLOCK_SIZE_N)\n block_col = pid % (N // BLOCK_SIZE_N)\n # Compute the start of the block in the output matrix\n c_start = block_row * BLOCK_SIZE_M * stride_cm + block_col * BLOCK_SIZE_N * stride_cn\n # Initialize the accumulator\n acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n # Loop over the K dimension\n for k in range(0, K, BLOCK_SIZE_K):\n # Compute the start of the block in the input matrices\n a_start = block_row * BLOCK_SIZE_M * stride_am + k * stride_ak\n b_start = k * stride_bk + block_col * BLOCK_SIZE_N * stride_bn\n # Load the blocks from the input matrices\n a = tl.load(a_ptr + a_start, shape=(BLOCK_SIZE_M, BLOCK_SIZE_K))\n b = tl.load(b_ptr + b_start, shape=(BLOCK_SIZE_K, BLOCK_SIZE_N))\n # Compute the matrix multiplication for the block\n acc += tl.dot(a, b)\n # Store the result in the output matrix\n tl.store(c_ptr + c_start, acc)\n\n# Function to call the Triton kernel\ndef matmul(a, b):\n assert a.shape[1] == b.shape[0], \"Incompatible dimensions for matrix multiplication\"\n M, K = a.shape\n K, N = b.shape\n c = torch.empty((M, N), device=a.device, dtype=a.dtype)\n grid = (M // 32, N // 32)\n matmul_kernel[grid](a, b, c, M, N, K, a.stride(0), a.stride(1), b.stride(0), b.stride(1), c.stride(0), c.stride(1), BLOCK_SIZE_M=32, BLOCK_SIZE_N=32, BLOCK_SIZE_K=32)\n return c\n", - "description_1": "Use triton language to implement a matrix multiplication kernel. The kernel takes two input matrices 'a' and 'b', and computes their product 'c'. The kernel is parameterized by the dimensions M, N, K, and the block sizes BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K. The function 'matmul' calls this kernel with appropriate grid and block sizes.", - "description_2": "Use triton language to create a matrix multiplication kernel and a function to execute it, handling input matrices of compatible dimensions.", - "difficulty": 3 - }, - { - "code": "import triton\nimport triton.language as tl\n\n@triton.jit\ndef triton_red_fused_add_sum_2(in_out_ptr0, in_ptr0, xnumel, rnumel, XBLOCK: tl.constexpr, RBLOCK: tl.constexpr):\n \"\"\"\n Kernel function for performing a fused addition and reduction sum operation.\n\n Parameters:\n - in_out_ptr0: Pointer to the input/output buffer where results are stored.\n - in_ptr0: Pointer to the input buffer containing data to be processed.\n - xnumel: The number of elements in the x dimension.\n - rnumel: The number of elements in the reduction dimension.\n - XBLOCK: The size of blocks in the x dimension (compile-time constant).\n - RBLOCK: The size of blocks in the reduction dimension (compile-time constant).\n \"\"\"\n xnumel = 1024\n rnumel = 2048\n xoffset = tl.program_id(0) * XBLOCK\n xindex = xoffset + tl.arange(0, XBLOCK)[:, None]\n xmask = xindex < xnumel\n rbase = tl.arange(0, RBLOCK)[None, :]\n x0 = xindex\n _tmp2 = tl.full([XBLOCK, RBLOCK], 0, tl.float32)\n for roffset in range(0, rnumel, RBLOCK):\n rindex = roffset + rbase\n rmask = rindex < rnumel\n r1 = rindex\n tmp0 = tl.load(in_ptr0 + (r1 + (2048*x0)), rmask & xmask, eviction_policy='evict_first', other=0.0)\n tmp1 = tl.broadcast_to(tmp0, [XBLOCK, RBLOCK])\n tmp3 = _tmp2 + tmp1\n _tmp2 = tl.where(rmask & xmask, tmp3, _tmp2)\n tmp2 = tl.sum(_tmp2, 1)[:, None]\n tmp4 = tl.load(in_out_ptr0 + (x0), xmask, eviction_policy='evict_last')\n tmp5 = tmp4 + tmp2\n tl.debug_barrier()\n tl.store(in_out_ptr0 + (x0), tmp5, xmask)\n", - "description_1": "Use triton language to define a kernel that performs fused addition and reduction sum operations. It processes elements with x and reduction dimensions specified, using compile-time constants to determine block sizes.", - "description_2": "Use triton language to write a kernel for fused addition and reduction sums, with defined input pointers and block sizes as compile-time constants.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\nfrom torch.testing._internal.triton_utils import add_kernel, HAS_CUDA\nfrom torch.testing._internal.triton_utils import requires_cuda\n\n@requires_cuda\ndef test_inplace_triton_kernel_training():\n @triton.jit\n def sin_kernel(\n in_ptr0,\n out_ptr,\n n_elements,\n BLOCK_SIZE: \"tl.constexpr\",\n ):\n pid = tl.program_id(axis=0)\n block_start = pid * BLOCK_SIZE\n offsets = block_start + tl.arange(0, BLOCK_SIZE)\n mask = offsets < n_elements\n x = tl.load(in_ptr0 + offsets, mask=mask)\n output = tl.sin(x)\n tl.store(out_ptr + offsets, output, mask=mask)\n\n def sin_triton(x, out):\n n_elements = x.numel()\n sin_kernel[(n_elements,)](x, out, n_elements, BLOCK_SIZE=4)\n\n factory_op = torch.empty_like\n\n class MySin(torch.autograd.Function):\n @staticmethod\n def forward(ctx, x):\n out = factory_op(x)\n sin_triton(x, out)\n ctx.save_for_backward(out)\n return out\n\n @staticmethod\n def backward(ctx, grad):\n (saved,) = ctx.saved_tensors\n out = factory_op(grad)\n sin_triton(saved, out)\n return out\n\n def f(x):\n return MySin.apply(x)\n\n x = torch.randn(3, device=\"cuda\", requires_grad=True)\n count_numel_train(f, x)\n\n@requires_cuda\ndef test_inplace_triton_kernel_v1():\n def f(x: torch.Tensor, y: torch.Tensor):\n output = torch.zeros_like(x)\n n_elements = output.numel()\n grid = (n_elements,)\n add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=16)\n return output\n\n inp = (torch.randn(10, device=\"cuda\"), torch.randn(10, device=\"cuda\"))\n count_numel(f, *inp)\n\n@requires_cuda\ndef test_inplace_triton_kernel_v2():\n def f(x: torch.Tensor, y: torch.Tensor):\n output = torch.zeros_like(x)\n n_elements = output.numel()\n grid = (n_elements,)\n tmp = torch.add(x, 1)\n add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=16)\n return output, tmp\n\n inp = (torch.randn(10, device=\"cuda\"), torch.randn(10, device=\"cuda\"))\n count_numel(f, *inp)\n\n@requires_cuda\ndef test_inplace_triton_kernel_v3():\n def f(x: torch.Tensor, y: torch.Tensor):\n output = torch.zeros_like(x)\n n_elements = output.numel()\n grid = (n_elements,)\n add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=16)\n x.add_(1)\n return output\n\n inp = (torch.randn(10, device=\"cuda\"), torch.randn(10, device=\"cuda\"))\n count_numel(f, *inp)\n\n@requires_cuda\ndef test_inplace_triton_kernel_v4():\n def f(x: torch.Tensor, y: torch.Tensor):\n x_view = x.view(-1)\n output = torch.zeros_like(x)\n n_elements = output.numel()\n grid = (n_elements,)\n add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=16)\n output2 = x_view.mul(2)\n return output, output2\n\n inp = (torch.randn(10, device=\"cuda\"), torch.randn(10, device=\"cuda\"))\n count_numel(f, *inp)\n\n@requires_cuda\ndef test_inplace_triton_kernel_v5():\n def f(x: torch.Tensor, y: torch.Tensor):\n x_view = x.view(-1)\n output = torch.zeros_like(x)\n n_elements = output.numel()\n grid = (n_elements,)\n add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=16)\n x_view.mul_(2)\n return output\n\n inp = (torch.randn(10, device=\"cuda\"), torch.randn(10, device=\"cuda\"))\n count_numel(f, *inp)\n\n@requires_cuda\ndef test_inplace_triton_kernel_v6():\n def f(x: torch.Tensor, y: torch.Tensor):\n output = torch.zeros_like(x)\n n_elements = output.numel()\n grid = (n_elements,)\n add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=16)\n return output\n\n t = torch.randn(10, device=\"cuda\")\n inp = (t, t.view(-1))\n count_numel(f, *inp)\n", - "description_1": "Use triton language to implement kernels for performing element-wise sine computations and addition of tensors. The sine computation uses a triton kernel that loads input elements, applies the sine function, and stores the result. The addition uses a triton kernel that adds elements from two input tensors and stores the result in an output tensor. Various variations of these operations are tested, including in-place operations and tensor view manipulations.", - "description_2": "Use triton language to create and test kernels for element-wise operations such as sine and addition, handling variations like in-place operations.", - "difficulty": 2 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef add_kernel(x_ptr, y_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr):\n # Kernel to add two vectors\n pid = tl.program_id(0)\n block_start = pid * BLOCK_SIZE\n offsets = block_start + tl.arange(0, BLOCK_SIZE)\n mask = offsets < n_elements\n\n x = tl.load(x_ptr + offsets, mask=mask)\n y = tl.load(y_ptr + offsets, mask=mask)\n result = x + y\n\n tl.store(output_ptr + offsets, result, mask=mask)\n\ndef add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:\n assert x.is_cuda and y.is_cuda\n assert x.shape == y.shape\n output = torch.empty_like(x)\n \n n_elements = x.numel()\n BLOCK_SIZE = 1024\n grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)\n \n add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=BLOCK_SIZE)\n \n return output\n\n# Example call to the Triton add function\nx = torch.randn(1024, device='cuda')\ny = torch.randn(1024, device='cuda')\noutput = add(x, y)\n", - "description_1": "Use triton language to create a kernel `add_kernel` that takes pointers to two input tensors `x_ptr` and `y_ptr`, an output tensor `output_ptr`, and the number of elements `n_elements`. The kernel adds these tensors element-wise using a block size of `BLOCK_SIZE`. Another function `add` is defined to set up and launch this kernel with specified grid size based on input tensors `x` and `y` on the GPU.", - "description_2": "Use triton language to define a vector addition kernel with specific grid configuration, callable from a PyTorch function that prepares and launches this kernel on CUDA tensors.", - "difficulty": 3 - }, - { - "code": "import triton\nimport torch\n\n# Kernel function decorated with @triton.jit\n@triton.jit\ndef kernel_function(x_ptr, y_ptr, n_elements, BLOCK_SIZE: tl.constexpr):\n # Triton kernel code\n pass\n\n# Function to call the Triton kernel\ndef call_kernel(x):\n # Assuming x is a torch tensor\n n_elements = x.numel()\n y = torch.empty_like(x)\n # Launch the Triton kernel\n grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)\n kernel_function[grid](x, y, n_elements, BLOCK_SIZE=1024)\n return y\n", - "description_1": "Use triton language to define a kernel function 'kernel_function' with 4 parameters: x_ptr (pointer to input tensor), y_ptr (pointer to output tensor), n_elements (number of elements in the tensor), and BLOCK_SIZE (block size for execution). The function 'call_kernel' is used to launch this kernel, taking a torch tensor 'x' as input, creating an output tensor 'y', and executing the kernel with a grid size calculated based on the number of elements and block size.", - "description_2": "Use triton language to define a kernel for element-wise operations on tensors, and a function to execute this kernel on a given input tensor using a specified block size.", - "difficulty": 2 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\nfrom torch._inductor.runtime.triton_helpers import math as tl_math\nfrom torch._inductor.runtime.triton_heuristics import triton_config\nfrom torch._inductor.runtime.hints import DeviceProperties, HeuristicType\n\ndef _get_cos_kernel_caching_autotuner_args():\n @triton.jit\n def triton_(in_ptr0, out_ptr0, xnumel, XBLOCK: tl.constexpr):\n # Kernel to compute the cosine of input elements\n xnumel = 16\n xoffset = tl.program_id(0) * XBLOCK\n xindex = xoffset + tl.arange(0, XBLOCK)[:]\n xmask = xindex < xnumel\n x0 = xindex\n tmp0 = tl.load(in_ptr0 + (x0), xmask)\n tmp1 = tl_math.cos(tmp0)\n tl.store(out_ptr0 + (x0), tmp1, xmask)\n\n triton_meta = {\n \"signature\": {0: \"*fp32\", 1: \"*fp32\", 2: \"i32\"},\n \"device\": DeviceProperties.create(torch.device(\"cuda\")),\n \"constants\": {},\n \"configs\": [AttrsDescriptor(divisible_by_16=(0, 1, 2), equal_to_1=())],\n }\n\n configs = [\n triton_config([16], 64),\n triton_config([256], 64),\n ]\n\n inductor_meta = {}\n\n return {\n \"fn\": triton_,\n \"triton_meta\": triton_meta,\n \"configs\": configs,\n \"save_cache_hook\": False,\n \"mutated_arg_names\": [],\n \"heuristic_type\": HeuristicType.POINTWISE,\n \"inductor_meta\": inductor_meta,\n }\n", - "description_1": "Use triton language to define a kernel function 'triton_' that computes the cosine of input elements. The kernel takes four parameters: 'in_ptr0' (input pointer), 'out_ptr0' (output pointer), 'xnumel' (number of elements), and 'XBLOCK' (block size, a compile-time constant). The kernel uses Triton's parallel programming model to load input data, compute the cosine using 'tl_math.cos', and store the result.", - "description_2": "Use triton language to create a kernel that calculates the cosine of elements from an input pointer and stores the results in an output pointer, using a specified block size for parallel execution.", - "difficulty": 2 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef pass_kernel(kernel):\n pass\n\n@torch.compile(backend=\"eager\")\ndef f(x):\n grid = (x.numel(),)\n pass_kernel[grid](kernel=x)\n\n@triton.jit\ndef add_one_kernel(\n in_ptr0,\n out_ptr,\n n_elements,\n BLOCK_SIZE: \"tl.constexpr\",\n):\n pid = tl.program_id(axis=0)\n block_start = pid * BLOCK_SIZE\n offsets = block_start + tl.arange(0, BLOCK_SIZE)\n mask = offsets < n_elements\n x = tl.load(in_ptr0 + offsets, mask=mask)\n output = x + 1\n tl.store(out_ptr + offsets, output, mask=mask)\n\ndef add_one(x, out):\n n_elements = x.numel()\n add_one_kernel[(n_elements,)](x, out, n_elements, BLOCK_SIZE=4)\n\nclass AddOne(torch.autograd.Function):\n @staticmethod\n def forward(ctx, x):\n out = torch.empty_like(x)\n add_one(x, out)\n ctx.save_for_backward(out)\n return out\n\n @staticmethod\n def backward(ctx, grad):\n (saved,) = ctx.saved_tensors\n out = torch.empty_like(grad)\n add_one(saved, out)\n return out\n\n@torch.compile\ndef f(x):\n return AddOne.apply(x)\n\n@triton.jit\ndef pow2_kernel(\n in_ptr0,\n out_ptr,\n n_elements,\n BLOCK_SIZE: \"tl.constexpr\",\n):\n pid = tl.program_id(axis=0)\n block_start = pid * BLOCK_SIZE\n offsets = block_start + tl.arange(0, BLOCK_SIZE)\n mask = offsets < n_elements\n x = tl.load(in_ptr0 + offsets, mask=mask)\n output = x * x\n tl.store(out_ptr + offsets, output, mask=mask)\n\ndef f(x: torch.Tensor):\n output = torch.zeros_like(x)\n n_elements = output.numel()\n grid = lambda meta: (triton.cdiv(n_elements, meta[\"BLOCK_SIZE\"]),)\n pow2_kernel[grid](x, output, n_elements, BLOCK_SIZE=16)\n return output\n\n@triton.jit\ndef add_kernel(\n in_ptr0,\n in_ptr1,\n out_ptr,\n n_elements,\n BLOCK_SIZE: \"tl.constexpr\",\n):\n pid = tl.program_id(axis=0)\n block_start = pid * BLOCK_SIZE\n offsets = block_start + tl.arange(0, BLOCK_SIZE)\n mask = offsets < n_elements\n x = tl.load(in_ptr0 + offsets, mask=mask)\n y = tl.load(in_ptr1 + offsets, mask=mask)\n output = x + y\n tl.store(out_ptr + offsets, output, mask=mask)\n\ndef call_triton_add(\n x: torch.Tensor,\n y: torch.Tensor,\n):\n output = torch.zeros_like(x)\n n_elements = output.numel()\n grid = lambda meta: (triton.cdiv(n_elements, meta[\"BLOCK_SIZE\"]),)\n add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=16)\n return output\n\nt1 = torch.rand(5, device=\"cuda\")\nt2 = torch.rand(5, device=\"cuda\")\n\ncompiled_func = torch.compile(call_triton_add)\ncompiled_func(t1, t2)\n", - "description_1": "Use triton language to implement a kernel 'pass_kernel' that accepts a kernel object as input. Another kernel 'add_one_kernel' increments each element of the input tensor by 1. Implement a function 'add_one' that utilizes 'add_one_kernel'. Another kernel 'pow2_kernel' calculates the square of each element of the input tensor. Implement a function that utilizes 'pow2_kernel' to return the squared output. Implement a Triton kernel 'add_kernel' that adds two input tensors element-wise. Implement a function 'call_triton_add' that utilizes 'add_kernel'.", - "description_2": "Use triton language to create kernels for element-wise operations such as adding one to each element and squaring each element. Implement a Triton kernel to perform element-wise addition of two tensors.", - "difficulty": 2 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef _sampled_addmm_kernel(\n alpha,\n beta,\n IS_BETA_ZERO: tl.constexpr,\n BLOCKSIZE_ROW: tl.constexpr,\n BLOCKSIZE_COL: tl.constexpr,\n k,\n TILE_K: tl.constexpr,\n values_ptr,\n values_batch_stride,\n values_nnz_stride,\n values_row_block_stride,\n values_col_block_stride,\n crow_indices_ptr,\n crow_indices_batch_stride,\n crow_indices_stride,\n col_indices_ptr,\n col_indices_batch_stride,\n col_indices_stride,\n mat1_ptr,\n mat1_batch_stride,\n mat1_tiled_row_stride,\n mat1_tiled_col_stride,\n mat1_row_block_stride,\n mat1_col_block_stride,\n mat2_ptr,\n mat2_batch_stride,\n mat2_tiled_row_stride,\n mat2_tiled_col_stride,\n mat2_row_block_stride,\n mat2_col_block_stride,\n acc_dtype: tl.constexpr,\n allow_tf32: tl.constexpr,\n):\n batch_pid = tl.program_id(axis=1)\n row_block_pid = tl.program_id(axis=0)\n\n crow_indices_offset_ptr = (\n crow_indices_ptr\n + crow_indices_batch_stride * batch_pid\n + crow_indices_stride * row_block_pid\n )\n nnz_offset = tl.load(crow_indices_offset_ptr)\n nnz_offset_next = tl.load(crow_indices_offset_ptr + crow_indices_stride)\n\n # Compute nnz for the row with number row_block_pid.\n # If it is zero, skip the row.\n row_nnz = nnz_offset_next - nnz_offset\n if row_nnz == 0:\n return\n\n row_block_arange = tl.arange(0, BLOCKSIZE_ROW)\n col_block_arange = tl.arange(0, BLOCKSIZE_COL)\n\n # Pointers are set to the first block of the current row.\n values_block_ptrs = (\n values_ptr\n + values_batch_stride * batch_pid\n + values_nnz_stride * nnz_offset\n + values_row_block_stride * row_block_arange[:, None]\n + values_col_block_stride * col_block_arange[None, :]\n )\n\n col_index_nnz_ptr = (\n col_indices_ptr\n + col_indices_batch_stride * batch_pid\n + col_indices_stride * nnz_offset\n )\n\n # Advance mat1 to the current tiled row, ignore columns.\n mat1_block_ptrs = (\n mat1_ptr\n + mat1_batch_stride * batch_pid\n + mat1_tiled_row_stride * row_block_pid\n + mat1_row_block_stride * row_block_arange[:, None]\n )\n\n # Advance mat2 in batch and block col dimension.\n mat2_block_ptrs = (\n mat2_ptr\n + mat2_batch_stride * batch_pid\n + mat2_col_block_stride * col_block_arange[None, :]\n )\n\n k_tile_arange = tl.arange(0, TILE_K)\n for _ in range(row_nnz):\n acc_block = tl.zeros((BLOCKSIZE_ROW, BLOCKSIZE_COL), dtype=acc_dtype)\n\n # find column block index\n col_block = tl.load(col_index_nnz_ptr)\n\n for k_tile in range(0, k, TILE_K):\n k_offsets = k_tile + k_tile_arange\n mask_k = k_offsets < k\n\n mat1_block = tl.load(\n mat1_block_ptrs + mat1_col_block_stride * k_offsets[None, :],\n mask=mask_k[None, :],\n other=0.0,\n )\n\n mat2_block = tl.load(\n mat2_block_ptrs\n + mat2_tiled_col_stride * col_block\n + mat2_row_block_stride * k_offsets[:, None],\n mask=mask_k[:, None],\n other=0.0,\n )\n\n acc_block += tl.dot(\n mat1_block, mat2_block, allow_tf32=allow_tf32, out_dtype=acc_dtype\n )\n\n if IS_BETA_ZERO:\n acc_block *= alpha\n else:\n acc_block = alpha * acc_block + beta * tl.load(values_block_ptrs)\n\n # write result\n tl.store(values_block_ptrs, acc_block.to(values_ptr.dtype.element_ty))\n\n # advance val/col_index ptrs to the next block in the row.\n values_block_ptrs += values_nnz_stride\n col_index_nnz_ptr += col_indices_stride\n\n\ndef sampled_addmm(\n input: torch.Tensor,\n mat1: torch.Tensor,\n mat2: torch.Tensor,\n *,\n beta=1.0,\n alpha=1.0,\n out: Optional[torch.Tensor] = None,\n skip_checks: bool = False,\n max_grid: Optional[Tuple[Optional[int], Optional[int], Optional[int]]] = None,\n):\n f_name = \"sampled_addmm\"\n\n check_bsr_layout(f_name, input)\n input_broadcasted = broadcast_batch_dims_bsr(f_name, input, mat1, mat2)\n\n if not skip_checks:\n check_device(f_name, mat1, input.device)\n check_device(f_name, mat2, input.device)\n if beta != 0.0 and input.dtype is torch.bool:\n check(\n False,\n f\"{f_name}(): having beta == {beta} not equal to 0.0 with boolean mask is not allowed.\",\n )\n if input.dtype is not torch.bool:\n check_dtype(f_name, mat1, input.dtype)\n check_dtype(f_name, mat2, input.dtype)\n else:\n check_dtype(f_name, mat1, mat2.dtype)\n check_mm_compatible_shapes(f_name, mat1, mat2)\n if out is not None:\n check_bsr_layout(f_name, out)\n check_device(f_name, out, mat1.device)\n check_dtype(f_name, out, input.dtype)\n check(\n out.shape == input_broadcasted.shape and out._nnz() == input._nnz(),\n f\"{f_name}(): Expects `out` to be of shape {input_broadcasted.shape} \"\n f\"and with nnz equal to {input_broadcasted._nnz()} \"\n f\"but got out.shape = {out.shape} and out.nnz = {out._nnz()}\",\n )\n\n if out is None:\n out = input_broadcasted.to(mat1.dtype, copy=True)\n else:\n out.copy_(input_broadcasted)\n\n if out.numel() == 0 or out._nnz() == 0:\n return out\n\n blocksize = out.values().shape[-2:]\n m = mat1.size(-2)\n n = mat2.size(-1)\n k = mat1.size(-1)\n\n # NOTE: (m, 0) @ (0, n) == zeros(m, n)\n if alpha == 0.0 or k == 0:\n out.values().mul_(beta)\n return out\n\n # prepare inputs by reshaping them to be kernel-compatible\n out_backup = out\n crow_indices, col_indices, values, mat1, mat2 = prepare_inputs(out, mat1, mat2)\n\n mat1 = tile_to_blocksize(mat1, (blocksize[0], k))\n mat2 = tile_to_blocksize(mat2, (k, blocksize[1]))\n tile_k = max(*blocksize)\n\n _run_sampled_addmm_kernel(\n alpha,\n beta,\n beta == 0.0,\n blocksize,\n k,\n tile_k,\n values,\n crow_indices,\n col_indices,\n mat1,\n mat2,\n max_grid,\n )\n\n # If nnz x block strides are not the same in out_backup.values and values,\n # it means that out_backup.values and values are not the views of each other,\n # so we have to copy.\n if out_backup.values().stride()[-3:] != values.stride()[-3:]:\n out_backup.values().copy_(values.reshape(out_backup.values().shape))\n return out_backup\n\ndef _scaled_dot_product_attention(\n query: torch.Tensor,\n key: torch.Tensor,\n value: torch.Tensor,\n attn_mask: Optional[torch.Tensor],\n dropout_p: float = 0.0,\n is_causal: bool = False,\n scale: Optional[float] = None,\n):\n f_name = \"_scaled_dot_product_attention\"\n check(not is_causal, f\"{f_name}(): is_causal == True is not supported.\")\n check(attn_mask is not None, f\"{f_name}(): attn_mask == None is not supported.\")\n assert attn_mask is not None\n\n check(\n attn_mask.layout == torch.sparse_bsr,\n f\"{f_name}(): \"\n f\"attn_mask.layout must be {torch.sparse_bsr}, but got \"\n f\"attn_mask.layout == {attn_mask.layout}.\",\n )\n\n check_device(f_name, key, query.device)\n check_device(f_name, value, query.device)\n check_device(f_name, attn_mask, query.device)\n\n check_dtype(f_name, key, query.dtype)\n check_dtype(f_name, value, query.dtype)\n if attn_mask.dtype is not torch.bool:\n check_dtype(f_name, attn_mask, query.dtype)\n\n sdpa = sampled_addmm(\n attn_mask, query, key.transpose(-2, -1), beta=0.0, skip_checks=False\n )\n if scale is None and query.size(-1) == 0 or scale == 0.0:\n check(\n False,\n f\"{f_name}(): current value of scale == {scale} \"\n \"results in division by zero.\",\n )\n scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale\n sdpa.values().mul_(scale_factor)\n sdpa = bsr_softmax(sdpa)\n torch.nn.functional.dropout(sdpa.values(), p=dropout_p, inplace=True)\n sdpa = bsr_dense_mm(sdpa, value)\n return sdpa\n", - "description_1": "Use triton language to implement a sparse matrix addition and multiplication kernel (_sampled_addmm_kernel) that computes on block-sparse row-major matrices, and use it for a scaled dot-product attention computation (_scaled_dot_product_attention). The kernel takes pointers to matrices, block sizes, and parameters such as alpha and beta for scaling the output, calculates the dot product in a tiled manner, and returns the accumulated results.", - "description_2": "Use triton language to perform block-sparse matrix operations including addition, multiplication, and scaled dot-product attention with support for dropout, using block-sparse row-major matrix representation.", - "difficulty": 4 - }, - { - "code": "import triton\nfrom triton import language as tl\nfrom triton.language import load, store\n\n# Kernel to add two input arrays element-wise\n@triton.jit\ndef add_kernel(\n in_ptr0,\n in_ptr1,\n out_ptr,\n n_elements,\n BLOCK_SIZE: \"tl.constexpr\",\n):\n pid = tl.program_id(axis=0)\n block_start = pid * BLOCK_SIZE\n offsets = block_start + tl.arange(0, BLOCK_SIZE)\n mask = offsets < n_elements\n x = tl.load(in_ptr0 + offsets, mask=mask)\n y = tl.load(in_ptr1 + offsets, mask=mask)\n output = x + y\n tl.store(out_ptr + offsets, output, mask=mask)\n\n# Kernel with conditional optional parameter handling\n@triton.jit\ndef add_kernel_with_optional_param(\n in_ptr0,\n in_ptr1,\n out_ptr,\n n_elements,\n ARGS_PASSED: \"tl.constexpr\",\n BLOCK_SIZE: \"tl.constexpr\",\n):\n pid = tl.program_id(axis=0)\n block_start = pid * BLOCK_SIZE\n offsets = block_start + tl.arange(0, BLOCK_SIZE)\n mask = offsets < n_elements\n x = tl.load(in_ptr0 + offsets, mask=mask)\n if ARGS_PASSED == \"two\":\n y = tl.load(in_ptr1 + offsets, mask=mask)\n output = x + y\n else:\n output = x\n tl.store(out_ptr + offsets, output, mask=mask)\n\n# Autotuned kernel to add two input arrays\n@triton.autotune(\n configs=[\n triton.Config({\"BLOCK_SIZE\": 128}, num_stages=3, num_warps=8),\n triton.Config({\"BLOCK_SIZE\": 128}, num_stages=4, num_warps=4),\n triton.Config({\"BLOCK_SIZE\": 64}, num_stages=3, num_warps=8),\n triton.Config({\"BLOCK_SIZE\": 64}, num_stages=4, num_warps=4),\n ],\n key=[],\n)\n@triton.jit\ndef add_kernel_autotuned(\n in_ptr0,\n in_ptr1,\n out_ptr,\n n_elements,\n BLOCK_SIZE: \"tl.constexpr\",\n):\n pid = tl.program_id(axis=0)\n block_start = pid * BLOCK_SIZE\n offsets = block_start + tl.arange(0, BLOCK_SIZE)\n mask = offsets < n_elements\n x = tl.load(in_ptr0 + offsets, mask=mask)\n y = tl.load(in_ptr1 + offsets, mask=mask)\n output = x + y\n tl.store(out_ptr + offsets, output, mask=mask)\n\n# Autotuned kernel with block pointer handling\n@triton.jit\ndef add_kernel_with_block_ptr(\n x_ptr,\n y_ptr,\n output_ptr,\n n_elements,\n BLOCK_SIZE: tl.constexpr,\n):\n pid = tl.program_id(axis=0)\n block_start = pid * BLOCK_SIZE\n x = tl.load(\n tl.make_block_ptr(\n base=x_ptr,\n shape=[n_elements],\n strides=[1],\n offsets=[block_start],\n block_shape=[BLOCK_SIZE],\n order=[0],\n ),\n boundary_check=[0],\n )\n y = tl.load(\n tl.make_block_ptr(\n base=y_ptr,\n shape=[n_elements],\n strides=[1],\n offsets=[block_start],\n block_shape=[BLOCK_SIZE],\n order=[0],\n ),\n boundary_check=[0],\n )\n output = x + y\n tl.store(\n tl.make_block_ptr(\n base=output_ptr,\n shape=[n_elements],\n strides=[1],\n offsets=[block_start],\n block_shape=[BLOCK_SIZE],\n order=[0],\n ),\n output,\n boundary_check=[0],\n )\n\n# Kernel to perform element-wise addition with imported load/store\n@triton.jit\ndef add_kernel_with_import(\n in_ptr0,\n in_ptr1,\n out_ptr,\n n_elements,\n BLOCK_SIZE: \"tl.constexpr\",\n):\n pid = tl.program_id(axis=0)\n block_start = pid * BLOCK_SIZE\n offsets = block_start + tl.arange(0, BLOCK_SIZE)\n mask = offsets < n_elements\n x = load(in_ptr0 + offsets, mask=mask)\n y = load(in_ptr1 + offsets, mask=mask)\n output = x + y\n store(out_ptr + offsets, output, mask=mask)\n", - "description_1": "Use triton language to define multiple kernels that perform element-wise operations, such as addition, on input arrays using a given block size for parallel execution. These kernels take pointers to input/output data, number of elements, and block size as inputs, with some kernels supporting additional compile-time parameters to adjust their behavior or tuning.", - "description_2": "Use triton language to create parallelized kernels for element-wise operations with flexible parameterization and autotuning.", - "difficulty": 3 - }, - { - "code": "import triton\nimport torch\n\n# Triton kernel definition\n@triton.jit\ndef kernel_example(A, B, C, D):\n # Kernel logic here\n pass\n\n# Function that calls the Triton kernel\ndef call_kernel_example(A, B, C, D):\n # Define the grid and block sizes for the Triton kernel launch\n grid = (A.size(0), )\n kernel_example[grid](A, B, C, D)\n", - "description_1": "Use triton language to define a kernel named `kernel_example` with four parameters (A, B, C, D) for executing custom logic. Utilize a function `call_kernel_example` to launch this kernel with specified grid configuration.", - "description_2": "Define a Triton kernel with four parameters and call it with a specified grid using a wrapper function.", - "difficulty": 3 - }, - { - "code": "import triton\nimport torch\n\n# Triton kernel for matrix multiplication\n@triton.jit\ndef matmul_kernel(A, B, C, M, N, K, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr):\n # Triton kernel code for matrix multiplication\n pass\n\n# Function to call the Triton kernel\ndef call_matmul_kernel(A, B, C, M, N, K):\n # Call the Triton kernel with appropriate grid and block sizes\n matmul_kernel[(M, N)](A, B, C, M, N, K, BLOCK_SIZE_M=128, BLOCK_SIZE_N=128, BLOCK_SIZE_K=32)\n\n# Example usage\nA = torch.randn(128, 128, device='cuda')\nB = torch.randn(128, 128, device='cuda')\nC = torch.empty(128, 128, device='cuda')\ncall_matmul_kernel(A, B, C, 128, 128, 128)\n", - "description_1": "Use triton language to implement a matrix multiplication kernel with parameters A, B, C (input matrices), M, N, K (dimensions), and BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K (block sizes). The kernel performs matrix multiplication and stores the result in C. The function call_matmul_kernel sets up the grid and block sizes and invokes the kernel.", - "description_2": "Use triton language to create a matrix multiplication kernel and a function to execute it with specified input matrices and dimensions.", - "difficulty": 3 - }, - { - "code": "import triton\nimport triton.language as tl\nimport torch\n\n@triton.jit\ndef matmul_248_kernel(a_ptr, b_ptr, c_ptr, scales_ptr, zeros_ptr, g_ptr, M,\n N, K, bits, maxq, stride_am, stride_ak, stride_bk,\n stride_bn, stride_cm, stride_cn, stride_scales,\n stride_zeros, BLOCK_SIZE_M: tl.constexpr,\n BLOCK_SIZE_N: tl.constexpr,\n BLOCK_SIZE_K: tl.constexpr,\n GROUP_SIZE_M: tl.constexpr):\n \"\"\"\n Compute the matrix multiplication C = A x B.\n A is of shape (M, K) float16\n B is of shape (K//8, N) int32\n C is of shape (M, N) float16\n scales is of shape (G, N) float16\n zeros is of shape (G, N) float16\n g_ptr is of shape (K) int32\n \"\"\"\n infearure_per_bits = 32 // bits\n\n pid = tl.program_id(axis=0)\n num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)\n num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)\n num_pid_k = tl.cdiv(K, BLOCK_SIZE_K)\n num_pid_in_group = GROUP_SIZE_M * num_pid_n\n group_id = pid // num_pid_in_group\n first_pid_m = group_id * GROUP_SIZE_M\n group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)\n pid_m = first_pid_m + (pid % group_size_m)\n pid_n = (pid % num_pid_in_group) // group_size_m\n\n offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n offs_k = tl.arange(0, BLOCK_SIZE_K)\n a_ptrs = a_ptr + (\n offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak\n ) # (BLOCK_SIZE_M, BLOCK_SIZE_K)\n a_mask = (offs_am[:, None] < M)\n # b_ptrs is set up such that it repeats elements along the K axis 8\n # times\n b_ptrs = b_ptr + (\n (offs_k[:, None] // infearure_per_bits) * stride_bk +\n offs_bn[None, :] * stride_bn) # (BLOCK_SIZE_K, BLOCK_SIZE_N)\n g_ptrs = g_ptr + offs_k\n # shifter is used to extract the N bits of each element in the 32-bit\n # word from B\n scales_ptrs = scales_ptr + offs_bn[None, :]\n zeros_ptrs = zeros_ptr + (offs_bn[None, :] // infearure_per_bits)\n\n shifter = (offs_k % infearure_per_bits) * bits\n zeros_shifter = (offs_bn % infearure_per_bits) * bits\n accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n\n for k in range(0, num_pid_k):\n g_idx = tl.load(g_ptrs)\n\n # Fetch scales and zeros; these are per-outfeature and thus reused\n # in the inner loop\n scales = tl.load(scales_ptrs + g_idx[:, None] *\n stride_scales) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)\n zeros = tl.load(\n zeros_ptrs +\n g_idx[:, None] * stride_zeros) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)\n\n zeros = (zeros >> zeros_shifter[None, :]) & maxq\n zeros = (zeros + 1)\n\n a = tl.load(\n a_ptrs, mask=a_mask, other=0.) # (BLOCK_SIZE_M, BLOCK_SIZE_K)\n b = tl.load(b_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated\n\n # Now we need to unpack b (which is N-bit values) into 32-bit\n # values\n b = (b >> shifter[:, None]) & maxq # Extract the N-bit values\n b = (b - zeros) * scales # Scale and shift\n\n accumulator += tl.dot(a, b)\n a_ptrs += BLOCK_SIZE_K\n b_ptrs += (BLOCK_SIZE_K // infearure_per_bits) * stride_bk\n g_ptrs += BLOCK_SIZE_K\n\n c_ptrs = c_ptr + stride_cm * offs_am[:, None] + stride_cn * offs_bn[\n None, :]\n c_mask = (offs_am[:, None] < M) & (offs_bn[None, :] < N)\n tl.store(c_ptrs, accumulator, mask=c_mask)\n\n@triton.jit\ndef transpose_matmul_248_kernel(\n a_ptr, b_ptr, c_ptr, scales_ptr, zeros_ptr, g_ptr, M, N, K, bits,\n maxq, stride_am, stride_ak, stride_bk, stride_bn, stride_cm,\n stride_cn, stride_scales, stride_zeros, BLOCK_SIZE_M: tl.constexpr,\n BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,\n GROUP_SIZE_M: tl.constexpr):\n \"\"\"\n Compute the matrix multiplication C = A x B.\n A is of shape (M, N) float16\n B is of shape (K//8, N) int32\n C is of shape (M, K) float16\n scales is of shape (G, N) float16\n zeros is of shape (G, N) float16\n g_ptr is of shape (K) int32\n \"\"\"\n infearure_per_bits = 32 // bits\n\n pid = tl.program_id(axis=0)\n num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)\n num_pid_k = tl.cdiv(K, BLOCK_SIZE_K)\n num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)\n num_pid_in_group = GROUP_SIZE_M * num_pid_k\n group_id = pid // num_pid_in_group\n first_pid_m = group_id * GROUP_SIZE_M\n group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)\n pid_m = first_pid_m + (pid % group_size_m)\n pid_k = (pid % num_pid_in_group) // group_size_m\n\n offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_bk = pid_k * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)\n offs_n = tl.arange(0, BLOCK_SIZE_N)\n a_ptrs = a_ptr + (\n offs_am[:, None] * stride_am + offs_n[None, :] * stride_ak\n ) # (BLOCK_SIZE_M, BLOCK_SIZE_N)\n a_mask = (offs_am[:, None] < M)\n # b_ptrs is set up such that it repeats elements along the K axis 8\n # times\n b_ptrs = b_ptr + (\n (offs_bk[:, None] // infearure_per_bits) * stride_bk +\n offs_n[None, :] * stride_bn) # (BLOCK_SIZE_K, BLOCK_SIZE_N)\n g_ptrs = g_ptr + offs_bk\n g_idx = tl.load(g_ptrs)\n\n # shifter is used to extract the N bits of each element in the 32-bit\n # word from B\n scales_ptrs = scales_ptr + offs_n[\n None, :] + g_idx[:, None] * stride_scales\n zeros_ptrs = zeros_ptr + (offs_n[None, :] // infearure_per_bits\n ) + g_idx[:, None] * stride_zeros\n\n shifter = (offs_bk % infearure_per_bits) * bits\n zeros_shifter = (offs_n % infearure_per_bits) * bits\n accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_K), dtype=tl.float32)\n\n for n in range(0, num_pid_n):\n # Fetch scales and zeros; these are per-outfeature and thus reused\n # in the inner loop\n scales = tl.load(scales_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)\n zeros = tl.load(zeros_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)\n\n zeros = (zeros >> zeros_shifter[None, :]) & maxq\n zeros = (zeros + 1)\n\n a = tl.load(\n a_ptrs, mask=a_mask, other=0.) # (BLOCK_SIZE_M, BLOCK_SIZE_N)\n b = tl.load(b_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated\n\n # Now we need to unpack b (which is N-bit values) into 32-bit\n # values\n b = (b >> shifter[:, None]) & maxq # Extract the N-bit values\n b = (b - zeros) * scales # Scale and shift\n b = tl.trans(b)\n\n accumulator += tl.dot(a, b)\n a_ptrs += BLOCK_SIZE_N\n b_ptrs += BLOCK_SIZE_N\n scales_ptrs += BLOCK_SIZE_N\n zeros_ptrs += (BLOCK_SIZE_N // infearure_per_bits)\n\n c_ptrs = c_ptr + stride_cm * offs_am[:, None] + stride_cn * offs_bk[\n None, :]\n c_mask = (offs_am[:, None] < M) & (offs_bk[None, :] < K)\n tl.store(c_ptrs, accumulator, mask=c_mask)\n\ndef matmul248(input, qweight, scales, qzeros, g_idx, bits, maxq):\n \"\"\"matmul248 function with matmul_248_kernel.\"\"\"\n with torch.cuda.device(input.device):\n output = torch.empty((input.shape[0], qweight.shape[1]),\n device=input.device,\n dtype=torch.float16)\n grid = lambda META: ( # noqa: E731\n triton.cdiv( # noqa: E731\n input.shape[0], META['BLOCK_SIZE_M']) * triton. # noqa: E731\n cdiv( # noqa: E731\n qweight.shape[1], META['BLOCK_SIZE_N']), ) # noqa: E731\n matmul_248_kernel[grid](input, qweight, output, scales, qzeros, g_idx,\n input.shape[0], qweight.shape[1],\n input.shape[1], bits, maxq, input.stride(0),\n input.stride(1), qweight.stride(0),\n qweight.stride(1), output.stride(0),\n output.stride(1), scales.stride(0),\n qzeros.stride(0))\n return output\n\ndef transpose_matmul248(input, qweight, scales, qzeros, g_idx, bits, maxq):\n \"\"\"transpose_matmul248 function with transpose_matmul_248_kernel.\"\"\"\n with torch.cuda.device(input.device):\n output_dim = (qweight.shape[0] * 32) // bits\n output = torch.empty((input.shape[0], output_dim),\n device=input.device,\n dtype=torch.float16)\n grid = lambda META: ( # noqa: E731\n triton.cdiv(input.shape[0], META['BLOCK_SIZE_M']) # noqa: E731\n * triton.cdiv(output_dim, META['BLOCK_SIZE_K']), ) # noqa: E731\n transpose_matmul_248_kernel[grid](input, qweight, output, scales,\n qzeros, g_idx, input.shape[0],\n qweight.shape[1], output_dim,\n bits, maxq, input.stride(0),\n input.stride(1), qweight.stride(0),\n qweight.stride(1), output.stride(0),\n output.stride(1), scales.stride(0),\n qzeros.stride(0))\n return output\n", - "description_1": "Use triton language to define two kernels for matrix multiplication. The first kernel, matmul_248_kernel, computes C = A x B with A shaped (M, K), B shaped (K//8, N), and C shaped (M, N). It involves multiple parameters like pointers for input and output matrices, dimensions, bit settings, strides, block sizes, and group size. The kernel processes the data in blocks and uses loops to handle different parts of the matrices. The second kernel, transpose_matmul_248_kernel, also computes C = A x B with A shaped (M, N), B shaped (K//8, N), and C shaped (M, K), sharing a similar setup as the first kernel but transposing 'b' in the process. Both kernels are used in their respective functions, matmul248 and transpose_matmul248, which prepare the output tensor, calculate the grid size, and launch the corresponding kernel.", - "description_2": "Use triton language to create kernels for specialized matrix multiplication with additional parameters for quantization, including handling input matrices in different block configurations and processing using loops, tailored for matrix shapes and device specifics.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef _kernel(A, B, C, stride_za, stride_ha, stride_ma, stride_ka, stride_zb, stride_hb, stride_kb, stride_nb, stride_zc,\n stride_hc, stride_mc, stride_nc, DS0, DS1, SDD_K, SDD_off_width, lut, locks, nlocks, **meta):\n TM = meta['TM']\n TN = meta['TN']\n TK = meta['TK']\n TZ = meta['TZ']\n BLOCK = meta['BLOCK']\n # Prologue\n pid0 = tl.program_id(0)\n pid1 = tl.program_id(1)\n pidz = tl.program_id(2)\n if meta['SDD']:\n pid1 = pid1 + SDD_off_width\n blockidm = tl.arange(0, TM) // BLOCK\n blockidn = tl.arange(0, TN) // BLOCK\n offlutm = blockidm * (TN // BLOCK) * 4\n offlutn = blockidn * 4\n header = lut + pid1 * (TM // BLOCK) * (TN // BLOCK) * 4\n z = tl.load(header + 0)\n i = tl.load(header + 1 + offlutm)\n j = tl.load(header + 2 + offlutn)\n AS1 = SDD_K // TZ\n lockid = tl.where(TZ > 1, 1, 0)\n offka = pid0 * AS1\n offkb = pid0 * AS1\n offmc = 0\n offnc = 0\n offpa = 0\n offpb = 0\n maxid = TZ\n offhc = 0\n offha = z\n offhb = z\n ram = i * BLOCK + (tl.arange(0, TM) % BLOCK)\n rbn = j * BLOCK + (tl.arange(0, TN) % BLOCK)\n else:\n header = lut + pid0 * 6\n offset = tl.load(header + 0)\n AS1 = tl.load(header + 1)\n column = tl.load(header + 2)\n depth = tl.load(header + 3)\n lockid = tl.load(header + 4)\n maxid = tl.load(header + 5)\n pinc = lut + offset\n offhc = depth\n if meta['DSD']:\n # output offset\n offnc = pid1 * TN\n offmc = column * TM\n offpc = 0\n # dense input offset\n offnb = pid1 * TN\n offkb = tl.load(pinc)\n offkb = tl.multiple_of(offkb, 8) # compiler hint\n offpb = 0\n # sparse input offset\n offma = 0\n offka = 0\n offpa = tl.load(pinc + 1)\n offpa = tl.multiple_of(offpa, 8) # compiler hint\n offpa = offpa * BLOCK * BLOCK\n offha = 0\n offhb = depth\n else:\n # output offset\n offmc = pid1 * TM\n offnc = column * TN\n offpc = 0\n # dense input offset\n offma = pid1 * TM\n offka = tl.load(pinc)\n offka = tl.multiple_of(offka, 8) # compiler hint\n offpa = 0\n # sparse input offset\n offnb = 0\n offkb = 0\n offpb = tl.load(pinc + 1)\n offpb = tl.multiple_of(offpb, 8) # compiler hint\n offpb = offpb * BLOCK * BLOCK\n offha = depth\n offhb = 0\n ram = offma + tl.arange(0, TM)\n rbn = offnb + tl.arange(0, TN)\n\n # initialize a, b pointers\n rka = offka + tl.arange(0, TK)\n rkb = offkb + tl.arange(0, TK)\n pa = A + pidz * stride_za + offha * stride_ha + offpa + ram[:, None] * stride_ma + rka[None, :] * stride_ka\n pb = B + pidz * stride_zb + offhb * stride_hb + offpb + rbn[None, :] * stride_nb + rkb[:, None] * stride_kb\n if meta['DDS']:\n checkam = ram[:, None] < DS0\n else:\n checkam = AS1 > 0\n if meta['DSD']:\n checkbn = rbn[None, :] < DS0\n else:\n checkbn = AS1 > 0\n a = tl.load(pa, mask=checkam, other=0.)\n b = tl.load(pb, mask=checkbn, other=0.)\n\n # Inner Loop\n acc = tl.zeros((TM, TN), dtype=tl.float32)\n for k in range(AS1, 0, -TK):\n acc += tl.dot(a, b)\n if meta['SDD']:\n inc_a = TK * stride_ka\n inc_b = TK * stride_kb\n else:\n pinc += 2\n if meta['DSD']:\n inc_b = tl.load(pinc)\n inc_a = tl.load(pinc + 1)\n inc_b = tl.multiple_of(inc_b, 8)\n inc_a = tl.multiple_of(inc_a, 8)\n inc_b = inc_b * stride_kb\n if meta['DDS']:\n inc_a = tl.load(pinc)\n inc_b = tl.load(pinc + 1)\n inc_a = tl.multiple_of(inc_a, 8)\n inc_b = tl.multiple_of(inc_b, 8)\n inc_a = inc_a * stride_ka\n pa += inc_a\n pb += inc_b\n # pre-fetch\n checkak = k > TK\n checkbk = k > TK\n checka = checkam & checkak\n checkb = checkbn & checkbk\n a = tl.load(pa, mask=checka)\n b = tl.load(pb, mask=checkb)\n c = acc.to(C.dtype.element_ty)\n\n if meta['SDD']:\n checkc = True\n rr_blockidm = tl.arange(0, TM) // BLOCK\n rr_blockidn = tl.arange(0, TN) // BLOCK\n rr_offlutm = rr_blockidm * (TN // BLOCK) * 4\n rr_offlutn = rr_blockidn * 4\n off_bkid = 3 + rr_offlutm[:, None] + rr_offlutn[None, :]\n bkid = tl.load(header + off_bkid)\n offpc = bkid * BLOCK * BLOCK\n rcm = tl.arange(0, TM) % BLOCK\n rcn = tl.arange(0, TN) % BLOCK\n else:\n rcm = offmc + tl.arange(0, TM)\n rcn = offnc + tl.arange(0, TN)\n if meta['DSD']:\n checkc = rcn[None, :] < DS0\n if meta['DDS']:\n checkc = rcm[:, None] < DS0\n\n pc = C + offpc + offhc * stride_hc + pidz * stride_zc + rcm[:, None] * stride_mc + rcn[None, :] * stride_nc\n # write-back directly\n if lockid == 0:\n tl.store(pc, c, mask=checkc)\n # accumulate partial results using spin-locks\n else:\n plock = locks + tl.program_id(2) * nlocks * tl.num_programs(1) + tl.program_id(1) * nlocks + lockid - 1\n pcount = plock + tl.num_programs(2) * tl.num_programs(1) * nlocks\n while tl.atomic_cas(plock, 0, 1) == 1:\n pass\n count = tl.load(pcount)\n if count == 0:\n tl.store(pc, c, mask=checkc)\n else:\n d = tl.load(pc, mask=checkc)\n tl.store(pc, d + c, mask=checkc)\n tl.atomic_xchg(pcount, (count + 1) % maxid)\n tl.atomic_xchg(plock, 0)\n\n\nclass _sparse_matmul(torch.autograd.Function):\n\n sdd_cache = dict()\n dsd_cache = dict()\n dds_cache = dict()\n locks = dict()\n\n @staticmethod\n def _sdd_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, luts, num_locks, widths, packs, bench, time):\n if trans_c:\n a, b = b, a\n trans_a, trans_b = not trans_b, not trans_a\n AS0 = a.size(0)\n a_dim = -2 if trans_a else -1\n b_dim = -1 if trans_b else -2\n a_inner, b_inner = a.shape[a_dim], b.shape[b_dim]\n if a_inner != b_inner:\n raise ValueError(f\"Size of tensor A along the {a_dim} dim ({a_inner}) must match size \"\n f\"of tensor B along the {b_dim} dim ({b_inner})\")\n if a_inner % 16 != 0:\n raise ValueError('Reduction size for SDD must be a multiple of 16')\n\n batch_size = a.size(0)\n a_outer = a.size(3 if trans_a else 2)\n dtype = a.dtype\n is_16_multiple = a_inner % 16 == 0\n is_32_multiple = a_inner % 32 == 0\n is_64_multiple = a_inner % 64 == 0\n if not is_16_multiple:\n raise ValueError('Reduction size for SDD must be a multiple of 16')\n device = a.device\n total_width = sum([width * pack * pack for width, pack in zip(widths, packs)])\n c = torch.empty((batch_size, total_width, block, block), dtype=dtype, device=a.device)\n for lut, width, pack in zip(luts, widths, packs):\n F32TK = [8, 16]\n F16TK = [16]\n F16TK += [32] if is_32_multiple else []\n F16TK += [64] if is_64_multiple else []\n TK = {torch.float32: F32TK, torch.float16: F16TK}[dtype]\n num_lock = 1\n meta = {\n 'TM': block * pack,\n 'TN': block * pack,\n 'BLOCK': block,\n 'TK': TK[0],\n 'TZ': 1,\n 'SDD': True,\n 'DSD': False,\n 'DDS': False\n }\n locks = _sparse_matmul.get_locks(2 * width * AS0 * num_lock, a.device)\n max_width = 49152\n total = 0 if bench else None\n for off_width in range(0, width, max_width):\n grid = lambda meta: [meta['TZ'], min(max_width, width - off_width), batch_size]\n _kernel[grid](a,\n b,\n c,\n a.stride(0),\n a.stride(1),\n a.stride(3 if trans_a else 2),\n a.stride(2 if trans_a else 3),\n b.stride(0),\n b.stride(1),\n b.stride(3 if trans_b else 2),\n b.stride(2 if trans_b else 3),\n c.stride(0),\n c.stride(0),\n c.stride(2),\n c.stride(3),\n a_outer,\n a_outer,\n a_inner,\n off_width,\n lut,\n locks,\n num_lock,\n num_warps=4,\n **meta)\n return c\n\n fn = {'sdd': _sdd_matmul.__get__(object), 'dsd': _sdd_matmul.__get__(object), 'dds': _sdd_matmul.__get__(object)}\n\n @staticmethod\n def forward(ctx, a, b, trans_a, trans_b, trans_c, mode, spdims, block, c_lut, c_num_locks, c_width, c_packs,\n c_bench, c_time, da_lut, da_num_locks, da_width, da_packs, da_bench, da_time, db_lut, db_num_locks,\n db_width, db_packs, db_bench, db_time):\n c = _sparse_matmul.fn[mode](a, b, trans_a, trans_b, trans_c, spdims, block, c_lut, c_num_locks, c_width,\n c_packs, c_bench, c_time)\n ctx.save_for_backward(a, b)\n ctx.da_num_locks = da_num_locks\n ctx.da_lut = da_lut\n ctx.da_width = da_width\n ctx.da_packs = da_packs\n ctx.da_bench = da_bench\n ctx.da_time = da_time\n ctx.db_lut = db_lut\n ctx.db_num_locks = db_num_locks\n ctx.db_width = db_width\n ctx.db_bench = db_bench\n ctx.db_packs = db_packs\n ctx.db_time = db_time\n ctx.mode = mode\n ctx.spdims = spdims\n ctx.block = block\n ctx.trans_a = trans_a\n ctx.trans_b = trans_b\n return c\n\n @staticmethod\n def backward(ctx, dc):\n a, b = ctx.saved_tensors\n mode = ctx.mode\n if ctx.needs_input_grad[0]:\n mode_da = mode[1] + mode[0] + mode[2]\n da = _sparse_matmul.fn[mode_da](dc, b, False, not ctx.trans_b, ctx.trans_a, ctx.spdims, ctx.block,\n ctx.da_lut, ctx.da_num_locks, ctx.da_width, ctx.da_packs, ctx.da_bench,\n ctx.da_time)\n if ctx.needs_input_grad[1]:\n mode_db = mode[2] + mode[1] + mode[0]\n db = _sparse_matmul.fn[mode_db](a, dc, not ctx.trans_a, False, ctx.trans_b, ctx.spdims, ctx.block,\n ctx.db_lut, ctx.db_num_locks, ctx.db_width, ctx.db_packs, ctx.db_bench,\n ctx.db_time)\n return da, db, None, None, None,\\\n None, None, None, None,\\\n None, None, None, None, None, None,\\\n None, None, None, None, None, None,\\\n None, None, None, None, None, None\n\n\nclass MatMul:\n def __init__(self, layout, block, mode, trans_a=False, trans_b=False, bench=False):\n if mode not in ['sdd', 'dsd', 'dds']:\n raise NotImplementedError('Supported modes are: sdd, dsd, dds')\n self.lut_cache = dict()\n self.trans_a = trans_a\n self.trans_b = trans_b\n self.mode = mode\n self.block = block\n self.layout = layout\n layout_dim = layout.ndim\n assert layout_dim in (2, 3), \"Layout should be a 2 or 3 dimensional tensor of 0s and 1s\"\n if not mode == 'sdd':\n trans_dense, trans_sparse, sparse_inner = (trans_b, trans_a, -1) if mode == 'dsd' else (trans_a, trans_b,\n -2)\n self.dense_inner_dim = -((sparse_inner % 2) + 1) if not trans_dense else sparse_inner\n sparse_inner = sparse_inner if not trans_sparse else -((sparse_inner % 2) + 1)\n self.dense_inner_size = layout.shape[sparse_inner] * block\n self.sparse_shape = (layout.sum().item(), block, block)\n if layout_dim == 2:\n layout = layout.unsqueeze(0)\n layout = layout.long()\n self.spdims = layout.shape\n self.bench = bench\n self.time_c = None\n self.time_da = None\n self.time_db = None\n\n def make_lut(self, dtype, device):\n key = (dtype, device)\n if key in self.lut_cache:\n return self.lut_cache[key]\n layout, block = self.layout, self.block\n step = 16\n if self.mode == 'sdd':\n c_lut, c_num_locks, c_width, c_packs = _sparse_matmul.make_sdd_lut(layout, block, dtype, device)\n elif self.mode == 'dsd':\n c_lut, c_num_locks, c_width, c_packs = _sparse_matmul.make_dxx_lut(layout, block, step, not self.trans_a,\n device)\n elif self.mode == 'dds':\n c_lut, c_num_locks, c_width, c_packs = _sparse_matmul.make_dxx_lut(layout, block, step, self.trans_b,\n device)\n if self.mode == 'sdd':\n da_lut, da_num_locks, da_width, da_packs = _sparse_matmul.make_dxx_lut(layout, block, step, True, device)\n elif self.mode == 'dsd':\n da_lut, da_num_locks, da_width, da_packs = _sparse_matmul.make_sdd_lut(layout, block, dtype, device)\n elif self.mode == 'dds':\n da_lut, da_num_locks, da_width, da_packs = _sparse_matmul.make_dxx_lut(layout, block, step,\n not self.trans_b, device)\n if self.mode == 'sdd':\n db_lut, db_num_locks, db_width, db_packs = _sparse_matmul.make_dxx_lut(layout, block, step, False, device)\n elif self.mode == 'dsd':\n db_lut, db_num_locks, db_width, db_packs = _sparse_matmul.make_dxx_lut(layout, block, step, self.trans_a,\n device)\n elif self.mode == 'dds':\n db_lut, db_num_locks, db_width, db_packs = _sparse_matmul.make_sdd_lut(layout, block, dtype, device)\n self.lut_cache[key] = (c_lut, c_num_locks, c_width, c_packs,\\\n da_lut, da_num_locks, da_width, da_packs,\\\n db_lut, db_num_locks, db_width, db_packs)\n return self.lut_cache[key]\n\n @staticmethod\n def _pad_shape(x, is_sparse):\n max_dim = 3 if is_sparse else 4\n for i in range(max_dim - x.dim()):\n x = x.unsqueeze(0)\n return x\n\n def _validate_inputs(self, a, b):\n if a.device != b.device:\n raise ValueError(f\"Inputs must be on the same device; got {a.device} for tensor A \"\n f\"and {b.device} for tensor B\")\n if not get_accelerator().on_accelerator(a):\n raise ValueError(\"Only GPU devices are supported for now\")\n if torch.is_autocast_enabled():\n a, b = a.half(), b.half()\n elif a.dtype != b.dtype:\n raise ValueError(f\"Inputs must be the same dtype; got {a.dtype} for A and {b.dtype} for B\")\n mode, trans_a, trans_b = self.mode, self.trans_a, self.trans_b\n if mode != 'sdd':\n dense, dense_name, sparse, sparse_name = (a, 'A', b, 'B') if mode == 'dds' else (b, 'B', a, 'A')\n dense_inner = dense.shape[self.dense_inner_dim]\n if dense_inner != self.dense_inner_size:\n raise ValueError(f\"Expected tensor {dense_name} to have size {self.dense_inner_size} at dim \"\n f\"{self.dense_inner_dim % dense.ndim}, got {dense_inner}.\")\n if sparse.shape[-len(self.sparse_shape):] != self.sparse_shape:\n raise ValueError(f\"Expected tensor with trailing dimensions of shape {self.sparse_shape} for argument \"\n f\"{sparse_name}, got {sparse.shape}\")\n\n def add_extra_dims(x):\n dims_needed = 4 - x.ndim\n if dims_needed > 0:\n singletons = [1] * dims_needed\n x = x.view(*singletons, *x.shape)\n elif dims_needed < 0:\n raise ValueError(\"Tensors with more than 4 dimensions are not currently supported\")\n return x\n\n a = add_extra_dims(a)\n b = add_extra_dims(b)\n return a, b\n\n def __call__(self, a, b):\n c_lut, c_num_locks, c_width, c_packs,\\\n da_lut, da_num_locks, da_width, da_packs,\\\n db_lut, db_num_locks, db_width, db_packs = self.make_lut(a.dtype, a.device)\n time_c = [None]\n time_da = [None]\n time_db = [None]\n original_dims = max(a.ndim, b.ndim)\n a, b = self._validate_inputs(a, b)\n a = MatMul._pad_shape(a, self.mode == 'dsd')\n b = MatMul._pad_shape(b, self.mode == 'dds')\n c = _sparse_matmul.apply(a, b, self.trans_a, self.trans_b, False, self.mode, self.spdims, self.block, c_lut,\n c_num_locks, c_width, c_packs, self.bench, time_c, da_lut, da_num_locks, da_width,\n da_packs, self.bench, time_da, db_lut, db_num_locks, db_width, db_packs, self.bench,\n time_db)\n dims_to_trim = c.ndim - original_dims\n for _ in range(dims_to_trim):\n c = c.squeeze(0)\n self.time_c = time_c[0]\n self.time_da = time_da[0]\n self.time_db = time_db[0]\n return c\n", - "description_1": "Use triton language to define a kernel for block sparse matrix multiplication with parameters such as tensors A, B, and C, along with their strides and metadata for controlling matrix dimensions and operations. Implement the functionality using a main API class _sparse_matmul with methods for forward and backward passes.", - "description_2": "Use triton language to create a block sparse matrix multiplication operator with kernels for various sparsity configurations, handling operations like sparse-dense-dense multiplication.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n\ndef next_power_of_2(n):\n n -= 1\n n |= n >> 1\n n |= n >> 2\n n |= n >> 4\n n |= n >> 8\n n |= n >> 16\n n += 1\n return n\n\n\ndef num_warps(n):\n if n < 512:\n return 4\n if n < 2048:\n return 8\n return 16\n\n\n@triton.heuristics({'num_warps': lambda *args, **meta: num_warps(args[6] * meta['BLOCK'])})\n@triton.heuristics({'TN': lambda *args, **meta: next_power_of_2(args[6] * meta['BLOCK'])})\n@triton.jit\ndef _forward(X, scale, LUT, RPE, KP_M, ATTN_M, sizemax, stride_zx, stride_zrpe, stride_hrpe, stride_srpe, stride_zkpm,\n stride_zattnm, **meta):\n TN = meta['TN']\n BLOCK = meta['BLOCK']\n pidhm = tl.program_id(0)\n pidz = tl.program_id(1)\n # create index ranges\n rxm = pidhm % BLOCK\n rbm = pidhm // BLOCK\n rxn = tl.arange(0, TN) % BLOCK\n rbn = tl.arange(0, TN) // BLOCK\n # extract information from LUT\n header = LUT + rbm * 2\n size = tl.load(header + 0)\n offset = tl.load(header + 1)\n check = rbn < size\n rbmn = tl.where(check, rbn, size - 1)\n # block id and column id\n blockid = tl.load(LUT + offset + rbmn * 4 + 0)\n columnid = tl.load(LUT + offset + rbmn * 4 + 1)\n rowid = tl.load(LUT + offset + rbmn * 4 + 2)\n headid = tl.load(LUT + offset + rbmn * 4 + 3)\n # pointers to X\n px = X + pidz * stride_zx + blockid * BLOCK * BLOCK + rxm * BLOCK + rxn\n x = tl.load(px, mask=check, other=-float('inf'))\n x = x.to(tl.float32)\n # apply scale\n if meta['APPLY_SCALE']:\n x = x * scale\n # apply RPE\n if meta['APPLY_RPE']:\n prpe = RPE + pidz * stride_zrpe + headid * stride_hrpe + columnid * BLOCK + rowid * BLOCK * stride_srpe + rxm * stride_srpe + rxn\n rpe = tl.load(prpe, mask=check, other=0)\n x = x + rpe\n # apply key-padding mask\n if meta['APPLY_KP_MASK']:\n pkp_m = KP_M + pidz * stride_zkpm + columnid * BLOCK + rxn\n kp_m = tl.load(pkp_m, mask=check, other=-float('inf'))\n if meta['KP_MASK_MUL']:\n kp_m = tl.where(kp_m == 0, -float('inf'), 0.)\n x = x + kp_m\n # apply attention mask\n if meta['APPLY_ATTN_MASK']:\n pattn_m = ATTN_M + columnid * BLOCK + rowid * BLOCK * stride_zattnm + rxm * stride_zattnm + rxn\n attn_m = tl.load(pattn_m, mask=check, other=-float('inf'))\n if meta['ATTN_MASK_MUL']:\n attn_m = tl.where(attn_m == 0, -float('inf'), 0.)\n x = x + attn_m\n # computation\n x = tl.softmax(x)\n tl.store(px, x, mask=check)\n\n\n@triton.heuristics({'num_warps': lambda *args, **meta: num_warps(args[4] * meta['BLOCK'])})\n@triton.heuristics({'TN': lambda *args, **meta: next_power_of_2(args[4]) * meta['BLOCK']})\n@triton.jit\ndef _backward(X, scale, DX, LUT, sizemax, stride_zx, stride_zdx, **meta):\n pidhm = tl.program_id(0)\n pidz = tl.program_id(1)\n TN = meta['TN']\n BLOCK = meta['BLOCK']\n # create index ranges\n rxm = pidhm % BLOCK\n rbm = pidhm // BLOCK\n rxn = tl.arange(0, TN) % BLOCK\n rbn = tl.arange(0, TN) // BLOCK\n # extract information from look-up table\n header = LUT + rbm * 2\n size = tl.load(header + 0)\n offset = tl.load(header + 1)\n # bounds checking on lut\n check = rbn < size\n rbmn = tl.where(check, rbn, size - 1)\n # initialize pointers to block-sparse input\n blockid = tl.load(LUT + offset + rbmn * 4)\n X = X + pidz * stride_zx + blockid * BLOCK * BLOCK + rxm * BLOCK + rxn\n DX = DX + pidz * stride_zdx + blockid * BLOCK * BLOCK + rxm * BLOCK + rxn\n # compute fused softmax backward\n x = tl.load(X, mask=check, other=0)\n dx = tl.load(DX, mask=check, other=0)\n x = x.to(tl.float32)\n dx = dx.to(tl.float32)\n y = x * (dx - tl.sum(x * dx, 0)) * scale\n tl.store(DX, y, mask=check)\n\n\nclass _sparse_softmax(torch.autograd.Function):\n\n bwd_kernels = dict()\n\n @staticmethod\n def make_lut(layout, block, device):\n _empty = torch.tensor([], dtype=torch.int64, device=layout.device)\n sizes = _empty.clone()\n # sizes along rows\n for h in range(layout.shape[0]):\n sizes = torch.cat((sizes, layout[h, :, :].sum(-1)))\n # offsets in block format\n offsets = torch.zeros_like(sizes)\n offsets[1:] = torch.cumsum(sizes[:-1], dim=0)\n # block indices\n idx = torch.arange(layout.sum())\n head = layout.nonzero()[:, 0]\n rows = layout.nonzero()[:, 1]\n columns = layout.nonzero()[:, 2]\n core = torch.stack((idx, columns, rows, head), dim=1).view(-1)\n # construct look-up table\n offsets = offsets * 4 + 2 * sizes.numel()\n header = torch.stack((sizes, offsets), dim=1).view(-1)\n lut = torch.cat((header, core)).type(torch.int32).to(device)\n return lut, int(sizes.max())\n\n @staticmethod\n def forward(ctx, x, scale, rpe, key_padding_mask, attn_mask, kp_mask_mode, attn_mask_mode, spdims, block, lut,\n num_blocks, maxlut, bench, time):\n\n apply_scale = False if scale == 1.0 else True\n\n # handle None rpe\n if rpe is None:\n apply_rpe = False\n stride_zrpe, stride_hrpe, stride_srpe = 0, 0, 0\n rpe = torch.empty(0, dtype=x.dtype, device=x.device)\n else:\n apply_rpe = True\n stride_zrpe, stride_hrpe, stride_srpe = rpe.stride(0), rpe.stride(1), rpe.stride(2)\n\n # handle None key_padding_mask\n if key_padding_mask is None:\n apply_kp_mask = False\n stride_zkpm = 0\n key_padding_mask = torch.empty(0, dtype=x.dtype, device=x.device)\n else:\n apply_kp_mask = True\n stride_zkpm = key_padding_mask.stride(0)\n\n # handle None attention_mask\n if attn_mask is None:\n apply_attn_mask = False\n stride_zattnm = 0\n attn_mask = torch.empty(0, dtype=x.dtype, device=x.device)\n else:\n apply_attn_mask = True\n stride_zattnm = attn_mask.stride(0)\n\n # run kernel\n M = x.shape[0]\n meta = {\n 'BLOCK': block,\n 'APPLY_SCALE': apply_scale,\n 'APPLY_RPE': apply_rpe,\n 'APPLY_KP_MASK': apply_kp_mask,\n 'APPLY_ATTN_MASK': apply_attn_mask,\n 'KP_MASK_MUL': kp_mask_mode == 'mul',\n 'ATTN_MASK_MUL': attn_mask_mode == 'mul',\n }\n grid = lambda opt: [spdims[0] * spdims[1] * block, M]\n _forward[grid](x, scale, lut, rpe, key_padding_mask, attn_mask, maxlut, x.stride(0),\\\n stride_zrpe, stride_hrpe, stride_srpe, stride_zkpm, stride_zattnm, **meta)\n\n # save to context\n ctx.mark_dirty(x)\n ctx.save_for_backward(x, lut)\n ctx.spdims = spdims\n ctx.block = block\n ctx.maxlut = maxlut\n ctx.scale = scale\n ctx.apply_scale = apply_scale\n ctx.apply_rpe = apply_rpe\n ctx.apply_kp_mask = apply_kp_mask\n ctx.apply_attn_mask = apply_attn_mask\n ctx.kp_mask_mode = kp_mask_mode\n ctx.attn_mask_mode = attn_mask_mode\n return x\n\n @staticmethod\n def backward(ctx, dx):\n\n # retrieve from context\n x, lut = ctx.saved_tensors\n # run kernel\n M = x.shape[0]\n grid = lambda opt: [ctx.spdims[0] * ctx.spdims[1] * ctx.block, M]\n _backward[grid](x, ctx.scale, dx, lut, ctx.maxlut, x.stride(0), dx.stride(0), BLOCK=ctx.block)\n return dx, None, None, None, None, None, None, None, None, None, None, None, None, None, None\n\n", - "description_1": "Use triton language to implement block-sparse softmax forward and backward operations. The forward operation (_forward) requires 13 parameters: X, scale, LUT, RPE, KP_M, ATTN_M, sizemax, stride_zx, stride_zrpe, stride_hrpe, stride_srpe, stride_zkpm, stride_zattnm. The backward operation (_backward) requires 7 parameters: X, scale, DX, LUT, sizemax, stride_zx, stride_zdx. Both operations use meta-parameters for controlling behavior related to block sizes, applying scale and masks.", - "description_2": "Use triton language to perform block-sparse softmax computations with configurable parameters for scale, masks, and block sizes, using separate kernels for forward and backward passes.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef _fwd_kernel(\n Q,\n K,\n V,\n sm_scale,\n TMP,\n Out,\n stride_qz,\n stride_qh,\n stride_qm,\n stride_qk,\n stride_kz,\n stride_kh,\n stride_kn,\n stride_kk,\n stride_vz,\n stride_vh,\n stride_vk,\n stride_vn,\n stride_oz,\n stride_oh,\n stride_om,\n stride_on,\n Z,\n H,\n N_CTX,\n BLOCK_M: tl.constexpr,\n BLOCK_DMODEL: tl.constexpr,\n BLOCK_N: tl.constexpr,\n):\n start_m = tl.program_id(0)\n off_hz = tl.program_id(1)\n offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)\n offs_n = tl.arange(0, BLOCK_N)\n offs_d = tl.arange(0, BLOCK_DMODEL)\n off_q = off_hz * stride_qh + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qk\n off_k = off_hz * stride_kh + offs_n[:, None] * stride_kn + offs_d[None, :] * stride_kk\n off_v = off_hz * stride_vh + offs_n[:, None] * stride_qm + offs_d[None, :] * stride_qk\n q_ptrs = Q + off_q\n k_ptrs = K + off_k\n v_ptrs = V + off_v\n t_ptrs = TMP + off_hz * N_CTX + offs_m\n m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float(\"inf\")\n l_i = tl.zeros([BLOCK_M], dtype=tl.float32)\n acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)\n q = tl.load(q_ptrs)\n for start_n in range(0, N_CTX, BLOCK_N):\n start_n = tl.multiple_of(start_n, BLOCK_N)\n k = tl.load(k_ptrs + start_n * stride_kn)\n\n qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)\n qk += tl.dot(q, k, trans_b=True)\n qk *= sm_scale\n m_ij = tl.max(qk, 1)\n p = tl.exp(qk - m_ij[:, None])\n l_ij = tl.sum(p, 1)\n m_i_new = tl.maximum(m_i, m_ij)\n alpha = tl.exp(m_i - m_i_new)\n beta = tl.exp(m_ij - m_i_new)\n l_i_new = alpha * l_i + beta * l_ij\n p_scale = beta / l_i_new\n p = p * p_scale[:, None]\n acc_scale = l_i / l_i_new * alpha\n tl.store(t_ptrs, acc_scale)\n acc_scale = tl.load(t_ptrs)\n acc = acc * acc_scale[:, None]\n v = tl.load(v_ptrs + start_n * stride_vk)\n p = p.to(tl.float16)\n acc += tl.dot(p, v)\n l_i = l_i_new\n m_i = m_i_new\n offs_n = tl.arange(0, BLOCK_DMODEL)\n off_o = off_hz * stride_oh + offs_m[:, None] * stride_om + offs_n[None, :] * stride_on\n out_ptrs = Out + off_o\n tl.store(out_ptrs, acc)\n\nclass triton_flash_attn(torch.nn.Module):\n\n def __init__(self, ):\n super(triton_flash_attn, self).__init__()\n\n def forward(self, q, k, v, sm_scale, block_128=True):\n BLOCK = 128 if block_128 else 64\n Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]\n o = torch.empty_like(q)\n grid = (triton.cdiv(q.shape[2], BLOCK), q.shape[0] * q.shape[1])\n tmp = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)\n num_warps = 4 if Lk <= 64 else 8\n\n _fwd_kernel[grid](\n q,\n k,\n v,\n sm_scale,\n tmp,\n o,\n q.stride(0),\n q.stride(1),\n q.stride(2),\n q.stride(3),\n k.stride(0),\n k.stride(1),\n k.stride(2),\n k.stride(3),\n v.stride(0),\n v.stride(1),\n v.stride(2),\n v.stride(3),\n o.stride(0),\n o.stride(1),\n o.stride(2),\n o.stride(3),\n k.shape[0],\n k.shape[1],\n k.shape[2],\n BLOCK_M=BLOCK,\n BLOCK_N=BLOCK,\n BLOCK_DMODEL=Lk,\n num_warps=num_warps,\n num_stages=1,\n )\n return o\n", - "description_1": "Use triton language to implement a forward kernel for attention mechanism. This kernel computes the scaled dot-product attention using input tensors Q (queries), K (keys), V (values), and an attention scaling factor (sm_scale). The kernel handles multi-dimensional tensor strides for Q, K, V, and outputs the result into an output tensor 'Out'. The kernel also utilizes temporary storage 'TMP' and employs blocking strategies with parameters BLOCK_M, BLOCK_N, and BLOCK_DMODEL to efficiently perform matrix multiplications and updates. The kernel accommodates variable context sizes (N_CTX) and multi-head attention by processing separate heads and contexts in a parallelized manner.", - "description_2": "Use triton language to perform scaled dot-product attention with blocking optimization in a forward pass.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.autotune(\n configs=[\n triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8),\n triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2),\n triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2),\n triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),\n ],\n key=['M', 'N', 'K', 'dtype_id', 'allow_tf32']\n)\n@triton.jit\ndef cvmm_kernel(\n a_ptr, b_ptr, c_ptr, index_ptr, sel_ptr, out_index_ptr,\n M, N, K,\n stride_am, stride_ak,\n stride_bo, stride_bk, stride_bn,\n stride_cm, stride_cn,\n stride_index, stride_sel, stride_out_index,\n out_index_is_none: tl.constexpr,\n dtype_id: tl.constexpr, allow_tf32: tl.constexpr,\n BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,\n GROUP_SIZE_M: tl.constexpr\n):\n pid = tl.program_id(axis=0)\n\n num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)\n num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)\n num_pid_in_group = GROUP_SIZE_M * num_pid_n\n group_id = pid // num_pid_in_group\n first_pid_m = group_id * GROUP_SIZE_M\n group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)\n pid_n = (pid % num_pid_in_group) // group_size_m\n\n pid_m = first_pid_m + (pid % group_size_m)\n\n sel_first = tl.load(sel_ptr + pid_m * BLOCK_SIZE_M * stride_sel)\n sel_last = tl.load(sel_ptr + (min((pid_m + 1) * BLOCK_SIZE_M, M) - 1) * stride_sel)\n sel_all = tl.load(sel_ptr + stride_sel * ((pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M))\n\n for matrix_id in range(sel_first, sel_last + 1):\n offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M\n offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N\n\n remap_offs_am = tl.load(index_ptr + stride_index * offs_am)\n\n offs_k = tl.arange(0, BLOCK_SIZE_K)\n a_ptrs = a_ptr + (remap_offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)\n b_ptrs = b_ptr + matrix_id * stride_bo + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)\n\n accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):\n a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)\n b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)\n\n if dtype_id == 1:\n a = a.to(tl.float16)\n b = b.to(tl.float16)\n elif dtype_id == 2:\n a = a.to(tl.bfloat16)\n b = b.to(tl.bfloat16)\n\n accumulator += tl.dot(a, b, allow_tf32=allow_tf32)\n\n a_ptrs += BLOCK_SIZE_K * stride_ak\n b_ptrs += BLOCK_SIZE_K * stride_bk\n\n if dtype_id == 1:\n c = accumulator.to(tl.float16)\n elif dtype_id == 2:\n c = accumulator.to(tl.bfloat16)\n else:\n c = accumulator\n\n offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n\n if out_index_is_none:\n remap_offs_cm = remap_offs_am\n else:\n remap_offs_cm = tl.load(out_index_ptr + stride_out_index * offs_am)\n\n offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n c_ptrs = c_ptr + stride_cm * remap_offs_cm[:, None] + stride_cn * offs_cn[None, :]\n c_mask = ((offs_cm[:, None] < M) & (sel_all[:, None] == matrix_id)) & (offs_cn[None, :] < N)\n tl.store(c_ptrs, c, mask=c_mask)\n\ndef cvmm_triton(\n x: torch.Tensor,\n sel_index: torch.Tensor,\n sel: torch.Tensor,\n keys: torch.Tensor,\n out_dtype: torch.dtype,\n out_index: torch.Tensor\n):\n x = x.flatten(end_dim=-2)\n assert x.shape[-1] == keys.shape[1]\n\n sel_shape = sel.shape\n sel = sel.flatten()\n\n M = sel.shape[0]\n O, K, N = keys.shape\n out = torch.empty((M, N), device=x.device, dtype=out_dtype)\n\n grid = lambda META: (\n triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']),\n )\n\n out_index_is_none = False\n if out_index.numel() == 1 and out_index == -1:\n out_index_is_none = True\n\n cvmm_kernel[grid](\n x, keys, out, sel_index, sel, out_index,\n M, N, K,\n x.stride(0), x.stride(1),\n keys.stride(0), keys.stride(1), keys.stride(2),\n out.stride(0), out.stride(1),\n sel_index.stride(0), sel.stride(0), 0 if out_index_is_none else out_index.stride(0),\n out_index_is_none=out_index_is_none,\n dtype_id = dtype_to_type_id(out.dtype), allow_tf32=False,\n )\n\n return out.view(*sel_shape, N)\n", - "description_1": "Use triton language to implement a matrix multiplication kernel (cvmm_kernel) that computes C = A x B, where A has shape (M, K), B has shape (K, N), and C has shape (M, N). The kernel uses block sizes for M, N, and K dimensions and supports different data types. The cvmm_triton function prepares the input tensors and launches the kernel with appropriate grid dimensions.", - "description_2": "Use triton language to create a matrix multiplication kernel with configurable block sizes and data types, and a function to launch this kernel with given input tensors.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\nfrom packaging import version\n\n@triton.autotune(\n configs=[\n triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8),\n triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2),\n triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2),\n triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),\n ],\n key=['M', 'N', 'K', 'dtype_id', 'allow_tf32']\n)\n@triton.jit\ndef cvmm_kernel(\n # Pointers to matrices\n a_ptr, b_ptr, c_ptr, index_ptr, sel_ptr, out_index_ptr,\n # Matrix dimensions\n M, N, K,\n stride_am, stride_ak,\n stride_bo, stride_bk, stride_bn,\n stride_cm, stride_cn,\n stride_index, stride_sel, stride_out_index,\n out_index_is_none: tl.constexpr,\n dtype_id: tl.constexpr, allow_tf32: tl.constexpr,\n # Meta-parameters\n BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,\n GROUP_SIZE_M: tl.constexpr\n):\n pid = tl.program_id(axis=0)\n\n num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)\n num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)\n num_pid_in_group = GROUP_SIZE_M * num_pid_n\n group_id = pid // num_pid_in_group\n first_pid_m = group_id * GROUP_SIZE_M\n group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)\n pid_n = (pid % num_pid_in_group) // group_size_m\n\n pid_m = first_pid_m + (pid % group_size_m)\n\n sel_first = tl.load(sel_ptr + pid_m * BLOCK_SIZE_M * stride_sel)\n sel_last = tl.load(sel_ptr + (min((pid_m + 1) * BLOCK_SIZE_M, M) - 1) * stride_sel)\n sel_all = tl.load(sel_ptr + stride_sel * ((pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M))\n\n for matrix_id in range(sel_first, sel_last + 1):\n offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M\n offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N\n\n remap_offs_am = tl.load(index_ptr + stride_index * offs_am)\n\n offs_k = tl.arange(0, BLOCK_SIZE_K)\n a_ptrs = a_ptr + (remap_offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)\n b_ptrs = b_ptr + matrix_id * stride_bo + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)\n\n accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):\n a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)\n b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)\n\n if dtype_id == 1:\n a = a.to(tl.float16)\n b = b.to(tl.float16)\n elif dtype_id == 2:\n a = a.to(tl.bfloat16)\n b = b.to(tl.bfloat16)\n\n accumulator += tl.dot(a, b, allow_tf32=allow_tf32)\n\n a_ptrs += BLOCK_SIZE_K * stride_ak\n b_ptrs += BLOCK_SIZE_K * stride_bk\n\n if dtype_id == 1:\n c = accumulator.to(tl.float16)\n elif dtype_id == 2:\n c = accumulator.to(tl.bfloat16)\n else:\n c = accumulator\n\n offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n\n if out_index_is_none:\n remap_offs_cm = remap_offs_am\n else:\n remap_offs_cm = tl.load(out_index_ptr + stride_out_index * offs_am)\n\n offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n c_ptrs = c_ptr + stride_cm * remap_offs_cm[:, None] + stride_cn * offs_cn[None, :]\n c_mask = ((offs_cm[:, None] < M) & (sel_all[:, None] == matrix_id)) & (offs_cn[None, :] < N)\n tl.store(c_ptrs, c, mask=c_mask)\n\nif version.parse(torch.__version__) >= version.parse(\"2.2.0\"):\n torch.library.define(\"mylib::cvmm_triton\", \"(Tensor x, Tensor sel_index, Tensor sel, Tensor keys, ScalarType out_dtype, Tensor out_index) -> Tensor\")\n lib_decorator = torch.library.impl(\"mylib::cvmm_triton\", \"default\")\nelse:\n lib_decorator = lambda x: x\n\n@lib_decorator\ndef cvmm_triton(\n x: torch.Tensor,\n sel_index: torch.Tensor,\n sel: torch.Tensor,\n keys: torch.Tensor,\n out_dtype: torch.dtype,\n out_index: torch.Tensor\n):\n x = x.flatten(end_dim=-2)\n assert x.shape[-1] == keys.shape[1]\n\n sel_shape = sel.shape\n sel = sel.flatten()\n\n M = sel.shape[0]\n O, K, N = keys.shape\n out = torch.empty((M, N), device=x.device, dtype=out_dtype)\n\n grid = lambda META: (\n triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']),\n )\n\n out_index_is_none = False\n if out_index.numel() == 1 and out_index == -1:\n out_index_is_none = True\n\n cvmm_kernel[grid](\n x, keys, out, sel_index, sel, out_index,\n M, N, K,\n x.stride(0), x.stride(1),\n keys.stride(0), keys.stride(1), keys.stride(2),\n out.stride(0), out.stride(1),\n sel_index.stride(0), sel.stride(0), 0 if out_index_is_none else out_index.stride(0),\n out_index_is_none=out_index_is_none,\n dtype_id = dtype_to_type_id(out.dtype), allow_tf32=False,\n )\n\n return out.view(*sel_shape, N)\n\ndef dtype_to_type_id(dtype: torch.dtype):\n if dtype == torch.float32:\n return 0\n elif dtype == torch.float16:\n return 1\n elif dtype == torch.bfloat16:\n return 2\n raise ValueError(\"Unknown dtype\")\n", - "description_1": "Use triton language to implement a kernel cvmm_kernel with 26 parameters for matrix multiplication C = A x B, where A has shape (M, K), B has shape (K, N), and C has shape (M, N). Parameters include pointers to matrices, matrix dimensions, strides, and various constexpr meta-parameters for block sizes and grouping. The kernel features L2 cache optimizations and pointer arithmetic for efficient computation, as well as accumulation in fp32 for higher accuracy and conditional logic for dtype conversions.", - "description_2": "Use triton language to implement a callable function cvmm_triton with 6 parameters to invoke the cvmm_kernel for efficient matrix multiplication with mixed precision, capable of handling flattened input matrices and supporting various output data types, leveraging Triton configurations for autotuning.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\nfrom packaging import version\n\n@triton.autotune(\n configs=[\n triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8),\n triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2),\n triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2),\n triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),\n ],\n key=['M', 'N', 'K', 'dtype_id', 'allow_tf32']\n)\n@triton.jit\ndef cvmm_kernel(\n a_ptr, b_ptr, c_ptr, index_ptr, sel_ptr, out_index_ptr,\n M, N, K,\n stride_am, stride_ak,\n stride_bo, stride_bk, stride_bn,\n stride_cm, stride_cn,\n stride_index, stride_sel, stride_out_index,\n out_index_is_none: tl.constexpr,\n dtype_id: tl.constexpr, allow_tf32: tl.constexpr,\n BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,\n GROUP_SIZE_M: tl.constexpr\n):\n \"\"\"Kernel for computing the matmul C = A x B.\n A has shape (M, K), B has shape (K, N) and C has shape (M, N)\n \"\"\"\n pid = tl.program_id(axis=0)\n num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)\n num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)\n num_pid_in_group = GROUP_SIZE_M * num_pid_n\n group_id = pid // num_pid_in_group\n first_pid_m = group_id * GROUP_SIZE_M\n group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)\n pid_n = (pid % num_pid_in_group) // group_size_m\n pid_m = first_pid_m + (pid % group_size_m)\n\n sel_first = tl.load(sel_ptr + pid_m * BLOCK_SIZE_M * stride_sel)\n sel_last = tl.load(sel_ptr + (min((pid_m + 1) * BLOCK_SIZE_M, M) - 1) * stride_sel)\n sel_all = tl.load(sel_ptr + stride_sel * ((pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M))\n\n for matrix_id in range(sel_first, sel_last + 1):\n offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M\n offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N\n remap_offs_am = tl.load(index_ptr + stride_index * offs_am)\n offs_k = tl.arange(0, BLOCK_SIZE_K)\n a_ptrs = a_ptr + (remap_offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)\n b_ptrs = b_ptr + matrix_id * stride_bo + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)\n\n accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):\n a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)\n b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)\n if dtype_id == 1:\n a = a.to(tl.float16)\n b = b.to(tl.float16)\n elif dtype_id == 2:\n a = a.to(tl.bfloat16)\n b = b.to(tl.bfloat16)\n\n accumulator += tl.dot(a, b, allow_tf32=allow_tf32)\n a_ptrs += BLOCK_SIZE_K * stride_ak\n b_ptrs += BLOCK_SIZE_K * stride_bk\n\n if dtype_id == 1:\n c = accumulator.to(tl.float16)\n elif dtype_id == 2:\n c = accumulator.to(tl.bfloat16)\n else:\n c = accumulator\n\n offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n if out_index_is_none:\n remap_offs_cm = remap_offs_am\n else:\n remap_offs_cm = tl.load(out_index_ptr + stride_out_index * offs_am)\n\n offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n c_ptrs = c_ptr + stride_cm * remap_offs_cm[:, None] + stride_cn * offs_cn[None, :]\n c_mask = ((offs_cm[:, None] < M) & (sel_all[:, None] == matrix_id)) & (offs_cn[None, :] < N)\n tl.store(c_ptrs, c, mask=c_mask)\n\n\nif version.parse(torch.__version__) >= version.parse(\"2.2.0\"):\n torch.library.define(\"mylib::cvmm_triton\", \"(Tensor x, Tensor sel_index, Tensor sel, Tensor keys, ScalarType out_dtype, Tensor out_index) -> Tensor\")\n lib_decorator = torch.library.impl(\"mylib::cvmm_triton\", \"default\")\nelse:\n lib_decorator = lambda x: x\n\n@lib_decorator\ndef cvmm_triton(\n x: torch.Tensor,\n sel_index: torch.Tensor,\n sel: torch.Tensor,\n keys: torch.Tensor,\n out_dtype: torch.dtype,\n out_index: torch.Tensor\n):\n x = x.flatten(end_dim=-2)\n assert x.shape[-1] == keys.shape[1]\n\n sel_shape = sel.shape\n sel = sel.flatten()\n\n M = sel.shape[0]\n O, K, N = keys.shape\n out = torch.empty((M, N), device=x.device, dtype=out_dtype)\n\n grid = lambda META: (\n triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']),\n )\n\n out_index_is_none = False\n if out_index.numel() == 1 and out_index == -1:\n out_index_is_none = True\n\n cvmm_kernel[grid](\n x, keys, out, sel_index, sel, out_index,\n M, N, K,\n x.stride(0), x.stride(1),\n keys.stride(0), keys.stride(1), keys.stride(2),\n out.stride(0), out.stride(1),\n sel_index.stride(0), sel.stride(0), 0 if out_index_is_none else out_index.stride(0),\n out_index_is_none=out_index_is_none,\n dtype_id = dtype_to_type_id(out.dtype), allow_tf32=False,\n )\n\n return out.view(*sel_shape, N)\n\n\nif version.parse(torch.__version__) >= version.parse(\"2.2.0\"):\n cvmm_triton_call = torch.ops.mylib.cvmm_triton\nelse:\n cvmm_triton_call = cvmm_triton\n", - "description_1": "Use triton language to implement matrix multiplication kernel `cvmm_kernel` which computes C = A x B, where A has dimensions (M, K), B has dimensions (K, N), and C has dimensions (M, N). The function has 25 parameters: 6 pointers to matrices, 3 matrix dimension parameters, 10 stride parameters for accessing elements in matrices, 3 constexpr parameters indicating index state and data types, and 4 meta-parameters defining block and group sizes for optimization purposes.", - "description_2": "Use triton language to create a callable function `cvmm_triton` that uses the triton kernel `cvmm_kernel` for efficient matrix multiplication on GPU. This function prepares and flattens input matrices, sets up the execution grid for the kernel, and manages output reshaping. It has 6 parameters: 4 tensor inputs including two selection indices and a set of keys, an output data type, and an optional output index for reshaping.", - "difficulty": 3 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef _softmax_fwd_kernel(\n output_ptr,\n stride_output_row,\n input_ptr,\n stride_input_row,\n num_cols,\n block_size: tl.constexpr,\n):\n # setup input ptrs\n row_index = tl.program_id(0)\n\n row_start_ptr = input_ptr + (row_index * stride_input_row)\n col_offsets = tl.arange(0, block_size)\n input_pointers = row_start_ptr + col_offsets\n\n row_mask = col_offsets < num_cols\n\n # move to SRAM\n row = tl.load(input_pointers, mask=row_mask, other=float(\"-inf\"))\n\n # softmax itself\n safe_row = row - tl.max(row, axis=0)\n numerator = tl.exp(safe_row)\n denominator = tl.sum(numerator, axis=0)\n sm_out = numerator / denominator\n\n # write back to HBM\n output_row_ptr = output_ptr + (row_index * stride_output_row)\n output_pointers = output_row_ptr + col_offsets\n tl.store(output_pointers, sm_out, mask=row_mask)\n\ndef softmax(x: torch.Tensor) -> torch.Tensor:\n \"\"\" Triton impl of Softmax, fwd pass only \"\"\"\n rows, cols = x.shape\n assert x.dim() == 2, f\"only accepts 2D tensors for now\"\n block_size = triton.next_power_of_2(cols)\n num_warps = 4 # *32 \n if block_size > 2047: # 2048\n num_warps = 8\n if block_size > 4095: # 4096\n num_warps = 16\n \n grid = (rows,)\n\n # allocate our output buffer\n sm_out = torch.empty_like(x)\n\n _softmax_fwd_kernel[grid](\n sm_out,\n sm_out.stride(0),\n x,\n x.stride(0),\n cols,\n block_size=block_size,\n num_warps=num_warps\n )\n\n return sm_out\n", - "description_1": "Use triton language to implement a softmax function. The kernel '_softmax_fwd_kernel' takes 6 parameters: output_ptr (pointer to output tensor), stride_output_row (stride of output tensor rows), input_ptr (pointer to input tensor), stride_input_row (stride of input tensor rows), num_cols (number of columns in the input tensor), and block_size (block size for processing). The 'softmax' function takes a single parameter 'x' (a 2D torch tensor) and computes the softmax using the Triton kernel.", - "description_2": "Use triton language to create a softmax operation for 2D tensors, utilizing a kernel to handle row-wise computation with configurable block size and warps.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef _attn_fwd_inner(acc, l_i, m_i, q, #\n K_block_ptr, V_block_ptr, #\n start_m, qk_scale, #\n BLOCK_M: tl.constexpr, HEAD_DIM: tl.constexpr, BLOCK_N: tl.constexpr, #\n STAGE: tl.constexpr, offs_m: tl.constexpr, offs_n: tl.constexpr, #\n N_CTX: tl.constexpr, fp8_v: tl.constexpr):\n if STAGE == 1:\n lo, hi = 0, start_m * BLOCK_M\n elif STAGE == 2:\n lo, hi = start_m * BLOCK_M, (start_m + 1) * BLOCK_M\n lo = tl.multiple_of(lo, BLOCK_M)\n else:\n lo, hi = 0, N_CTX\n K_block_ptr = tl.advance(K_block_ptr, (0, lo))\n V_block_ptr = tl.advance(V_block_ptr, (lo, 0))\n for start_n in range(lo, hi, BLOCK_N):\n start_n = tl.multiple_of(start_n, BLOCK_N)\n k = tl.load(K_block_ptr)\n qk = tl.dot(q, k)\n if STAGE == 2:\n mask = offs_m[:, None] >= (start_n + offs_n[None, :])\n qk = qk * qk_scale + tl.where(mask, 0, -1.0e6)\n m_ij = tl.maximum(m_i, tl.max(qk, 1))\n qk -= m_ij[:, None]\n else:\n m_ij = tl.maximum(m_i, tl.max(qk, 1) * qk_scale)\n qk = qk * qk_scale - m_ij[:, None]\n p = tl.math.exp2(qk)\n l_ij = tl.sum(p, 1)\n alpha = tl.math.exp2(m_i - m_ij)\n l_i = l_i * alpha + l_ij\n acc = acc * alpha[:, None]\n v = tl.load(V_block_ptr)\n if fp8_v:\n p = p.to(tl.float8e5)\n else:\n p = p.to(tl.float16)\n acc = tl.dot(p, v, acc)\n m_i = m_ij\n V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0))\n K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N))\n return acc, l_i, m_i\n\n@triton.autotune(list(filter(lambda conf: conf.kwargs[\"BLOCK_M\"] * conf.kwargs[\"BLOCK_N\"] >= 128 * 128 or conf.num_warps != 8, [\n triton.Config({'BLOCK_M': BM, 'BLOCK_N': BN}, num_stages=s, num_warps=w) \\\n for BM in [64, 128]\\\n for BN in [32, 64]\\\n for s in ([1] if triton.runtime.driver.active.get_current_target().backend == \"hip\" else [3, 4, 7])\\\n for w in [4, 8]\\\n])), key=[\"N_CTX\", \"HEAD_DIM\"])\n@triton.jit\ndef _attn_fwd(Q, K, V, sm_scale, M, Out, #\n stride_qz, stride_qh, stride_qm, stride_qk, #\n stride_kz, stride_kh, stride_kn, stride_kk, #\n stride_vz, stride_vh, stride_vk, stride_vn, #\n stride_oz, stride_oh, stride_om, stride_on, #\n Z, H, N_CTX, #\n HEAD_DIM: tl.constexpr, #\n BLOCK_M: tl.constexpr, #\n BLOCK_N: tl.constexpr, #\n STAGE: tl.constexpr #\n ):\n tl.static_assert(BLOCK_N <= HEAD_DIM)\n start_m = tl.program_id(0)\n off_hz = tl.program_id(1)\n off_z = off_hz // H\n off_h = off_hz % H\n qvk_offset = off_z.to(tl.int64) * stride_qz + off_h.to(tl.int64) * stride_qh\n\n Q_block_ptr = tl.make_block_ptr(\n base=Q + qvk_offset,\n shape=(N_CTX, HEAD_DIM),\n strides=(stride_qm, stride_qk),\n offsets=(start_m * BLOCK_M, 0),\n block_shape=(BLOCK_M, HEAD_DIM),\n order=(1, 0),\n )\n v_order: tl.constexpr = (0, 1) if V.dtype.element_ty == tl.float8e5 else (1, 0)\n V_block_ptr = tl.make_block_ptr(\n base=V + qvk_offset,\n shape=(N_CTX, HEAD_DIM),\n strides=(stride_vk, stride_vn),\n offsets=(0, 0),\n block_shape=(BLOCK_N, HEAD_DIM),\n order=v_order,\n )\n K_block_ptr = tl.make_block_ptr(\n base=K + qvk_offset,\n shape=(HEAD_DIM, N_CTX),\n strides=(stride_kk, stride_kn),\n offsets=(0, 0),\n block_shape=(HEAD_DIM, BLOCK_N),\n order=(0, 1),\n )\n O_block_ptr = tl.make_block_ptr(\n base=Out + qvk_offset,\n shape=(N_CTX, HEAD_DIM),\n strides=(stride_om, stride_on),\n offsets=(start_m * BLOCK_M, 0),\n block_shape=(BLOCK_M, HEAD_DIM),\n order=(1, 0),\n )\n offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)\n offs_n = tl.arange(0, BLOCK_N)\n m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float(\"inf\")\n l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + 1.0\n acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32)\n qk_scale = sm_scale\n qk_scale *= 1.44269504 # 1/log(2)\n q = tl.load(Q_block_ptr)\n if STAGE & 1:\n acc, l_i, m_i = _attn_fwd_inner(acc, l_i, m_i, q, K_block_ptr, V_block_ptr, #\n start_m, qk_scale, #\n BLOCK_M, HEAD_DIM, BLOCK_N, #\n 4 - STAGE, offs_m, offs_n, N_CTX, V.dtype.element_ty == tl.float8e5 #\n )\n if STAGE & 2:\n acc, l_i, m_i = _attn_fwd_inner(acc, l_i, m_i, q, K_block_ptr, V_block_ptr, #\n start_m, qk_scale, #\n BLOCK_M, HEAD_DIM, BLOCK_N, #\n 2, offs_m, offs_n, N_CTX, V.dtype.element_ty == tl.float8e5 #\n )\n m_i += tl.math.log2(l_i)\n acc = acc / l_i[:, None]\n m_ptrs = M + off_hz * N_CTX + offs_m\n tl.store(m_ptrs, m_i)\n tl.store(O_block_ptr, acc.to(Out.type.element_ty))\n\nclass _attention(torch.autograd.Function):\n\n @staticmethod\n def forward(ctx, q, k, v, causal, sm_scale):\n HEAD_DIM_Q, HEAD_DIM_K = q.shape[-1], k.shape[-1]\n HEAD_DIM_V = v.shape[-1]\n assert HEAD_DIM_Q == HEAD_DIM_K and HEAD_DIM_K == HEAD_DIM_V\n assert HEAD_DIM_K in {16, 32, 64}\n o = torch.empty_like(q)\n stage = 3 if causal else 1\n extra_kern_args = {}\n if triton.runtime.driver.active.get_current_target().backend == \"hip\":\n waves_per_eu = 3 if HEAD_DIM_K <= 64 else 2\n extra_kern_args = {\"waves_per_eu\": waves_per_eu, \"allow_flush_denorm\": True}\n\n grid = lambda args: (triton.cdiv(q.shape[2], args[\"BLOCK_M\"]), q.shape[0] * q.shape[1], 1)\n M = torch.empty((q.shape[0], q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)\n _attn_fwd[grid](\n q, k, v, sm_scale, M, o, #\n q.stride(0), q.stride(1), q.stride(2), q.stride(3), #\n k.stride(0), k.stride(1), k.stride(2), k.stride(3), #\n v.stride(0), v.stride(1), v.stride(2), v.stride(3), #\n o.stride(0), o.stride(1), o.stride(2), o.stride(3), #\n q.shape[0], q.shape[1], #\n N_CTX=q.shape[2], #\n HEAD_DIM=HEAD_DIM_K, #\n STAGE=stage, #\n **extra_kern_args)\n\n ctx.save_for_backward(q, k, v, o, M)\n ctx.grid = grid\n ctx.sm_scale = sm_scale\n ctx.HEAD_DIM = HEAD_DIM_K\n ctx.causal = causal\n return o\n\n @staticmethod\n def backward(ctx, do):\n q, k, v, o, M = ctx.saved_tensors\n assert do.is_contiguous()\n assert q.stride() == k.stride() == v.stride() == o.stride() == do.stride()\n dq = torch.empty_like(q)\n dk = torch.empty_like(k)\n dv = torch.empty_like(v)\n BATCH, N_HEAD, N_CTX = q.shape[:3]\n PRE_BLOCK = 128\n NUM_WARPS, NUM_STAGES = 4, 5\n BLOCK_M1, BLOCK_N1, BLOCK_M2, BLOCK_N2 = 32, 128, 128, 32\n BLK_SLICE_FACTOR = 2\n RCP_LN2 = 1.4426950408889634 # = 1.0 / ln(2)\n arg_k = k\n arg_k = arg_k * (ctx.sm_scale * RCP_LN2)\n PRE_BLOCK = 128\n assert N_CTX % PRE_BLOCK == 0\n pre_grid = (N_CTX // PRE_BLOCK, BATCH * N_HEAD)\n delta = torch.empty_like(M)\n _attn_bwd_preprocess[pre_grid](\n o, do, #\n delta, #\n BATCH, N_HEAD, N_CTX, #\n BLOCK_M=PRE_BLOCK, HEAD_DIM=ctx.HEAD_DIM #\n )\n grid = (N_CTX // BLOCK_N1, 1, BATCH * N_HEAD)\n _attn_bwd[grid](\n q, arg_k, v, ctx.sm_scale, do, dq, dk, dv, #\n M, delta, #\n q.stride(0), q.stride(1), q.stride(2), q.stride(3), #\n N_HEAD, N_CTX, #\n BLOCK_M1=BLOCK_M1, BLOCK_N1=BLOCK_N1, #\n BLOCK_M2=BLOCK_M2, BLOCK_N2=BLOCK_N2, #\n BLK_SLICE_FACTOR=BLK_SLICE_FACTOR, #\n HEAD_DIM=ctx.HEAD_DIM, #\n num_warps=NUM_WARPS, #\n num_stages=NUM_STAGES #\n )\n\n return dq, dk, dv, None, None\n\nattention = _attention.apply\n", - "description_1": "Use triton language to implement a fused attention mechanism with forward and backward passes. The forward pass (_attn_fwd) computes the attention output given query (Q), key (K), and value (V) tensors, along with scaling and other parameters. The backward pass (_attn_bwd) computes gradients for Q, K, and V given the gradient of the output. The kernels are optimized for different block sizes and stages, and the function is wrapped in a PyTorch autograd function for easy integration.", - "description_2": "Use triton language to create a fused attention operator with both forward and backward computation capabilities, optimized for performance with configurable block sizes and stages.", - "difficulty": 4 - }, - { - "code": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef _fwd_kernel_apply_penalty(\n Logits, presence_penalty, freqency_penalty,\n p_token_ids, p_token_counts, p_cumsum_seq_len, \n stride_logit_b, stride_logit_s,\n BLOCK_P: tl.constexpr\n):\n # Get the current batch index\n cur_batch = tl.program_id(0)\n # Load frequency and presence penalties for the current batch\n cur_freqency = tl.load(freqency_penalty + cur_batch)\n cur_presence = tl.load(presence_penalty + cur_batch)\n\n # Calculate start and end indices for the current batch\n cur_batch_start_index = tl.load(p_cumsum_seq_len + cur_batch)\n cur_batch_end_index = tl.load(p_cumsum_seq_len + cur_batch + 1)\n\n # Calculate offsets for token IDs and counts\n cur_batch_id_offset = cur_batch_start_index + tl.arange(0, BLOCK_P)\n batch_ids = tl.load(p_token_ids + cur_batch_id_offset, mask=cur_batch_id_offset