|
| 1 | +--- |
| 2 | +name: atomr-accel-cutlass |
| 3 | +description: Use when wiring or extending CUTLASS kernel templates through `atomr-accel-cutlass` — the `CutlassActor`, `GemmRequest<T>` / `GroupedGemmRequest<T>` / `ConvFwdRequest<T>` / `Dgrad` / `Wgrad`, the EVT (epilogue visitor tree) emitter, the `(template, shape, dtype, arch)` plan cache, and the Strategy A (NVRTC at runtime) vs Strategy B (`cutlass-prebuilt`, nvcc at build time) compilation choice. Triggers on adding a CUTLASS template, picking arch×dtype, hitting a plan-cache miss, choosing fp8 vs fp4, or fitting an EVT chain. |
| 4 | +--- |
| 5 | + |
| 6 | +# CUTLASS templates |
| 7 | + |
| 8 | +This skill covers the Phase 6 sibling crate. Enable the `cutlass` |
| 9 | +feature on `atomr-accel-cuda` and `CutlassActor` becomes available |
| 10 | +alongside the other kernel actors. For the per-library kernel |
| 11 | +actor pattern see [`atomr-accel-kernels`](../atomr-accel-kernels/SKILL.md); |
| 12 | +for portable trait surface considerations see |
| 13 | +[`atomr-accel-backends`](../atomr-accel-backends/SKILL.md). |
| 14 | + |
| 15 | +## Compilation strategies |
| 16 | + |
| 17 | +| Strategy | When | Trade-off | |
| 18 | +|---|---|---| |
| 19 | +| **A — NVRTC at runtime** (default) | First call to a new `(template, shape, dtype, arch)` triggers an NVRTC compile, then the cubin is cached on disk via the Phase 0.6 cache. Subsequent calls are warm. | First-call latency 30–60s per kernel; downstream builds run on no-GPU hosts. | |
| 20 | +| **B — nvcc at build time** (`cutlass-prebuilt` feature) | `build.rs` walks a generator and emits a static archive of pre-instantiated kernels for a fixed `(op × dtype × arch)` matrix. | Fast cold start, no NVRTC at runtime. Requires `nvcc` on the build host — CI on no-GPU runners breaks. | |
| 21 | + |
| 22 | +Default to A. Switch to B for production deployments where every |
| 23 | +serving instance hits the same kernel matrix. |
| 24 | + |
| 25 | +## Cargo features |
| 26 | + |
| 27 | +Add to `atomr-accel-cuda` features: |
| 28 | + |
| 29 | +```toml |
| 30 | +features = ["cutlass", "f16"] # GEMM only |
| 31 | +features = ["cutlass", "cutlass-grouped", "f16"] # + grouped GEMM |
| 32 | +features = ["cutlass", "cutlass-evt", "f16"] # + EVT epilogues |
| 33 | +features = ["cutlass", "cutlass-prebuilt", "f16"] # Strategy B |
| 34 | +``` |
| 35 | + |
| 36 | +## arch × dtype support matrix |
| 37 | + |
| 38 | +| dtype | sm_80 | sm_86 | sm_89 | sm_90a | sm_100 | |
| 39 | +|---|:-:|:-:|:-:|:-:|:-:| |
| 40 | +| f32, f64, f16, bf16 | ✔ | ✔ | ✔ | ✔ | ✔ | |
| 41 | +| fp8 e4m3 / e5m2 | | | ✔ | ✔ | ✔ | |
| 42 | +| fp4 e2m1 | | | | | ✔ | |
| 43 | +| int8 → int32 | ✔ | ✔ | ✔ | ✔ | ✔ | |
| 44 | + |
| 45 | +Use `is_supported_for(dtype, arch)` (or `is_fp8_supported` / |
| 46 | +`is_fp4_supported`) before constructing a request — building a |
| 47 | +`GemmRequest` in an unsupported cell still succeeds, but the |
| 48 | +NVRTC compile will reject the template instantiation. |
| 49 | + |
| 50 | +## Request types |
| 51 | + |
| 52 | +Every request is generic over `T: GemmSupported` (currently `f32`, |
| 53 | +`f64`, `f16`, `bf16`, plus the fp8 / fp4 markers under the matching |
| 54 | +feature) and produces a `PlanKey` for the plan cache. |
| 55 | + |
| 56 | +| Module | Request | Dispatch trait | Gate | |
| 57 | +|---|---|---|---| |
| 58 | +| `gemm` | `GemmRequest<T>` | `CutlassGemmDispatch` | always-on | |
| 59 | +| `grouped_gemm` | `GroupedGemmRequest<T>` | `CutlassGroupedGemmDispatch` | `grouped` | |
| 60 | +| `conv` | `ConvFwdRequest<T>` / `ConvDgradRequest<T>` / `ConvWgradRequest<T>` | `CutlassConvDispatch` | always-on | |
| 61 | +| `evt` | `EpilogueVisitorTree`, `EvtBuilder`, `EpilogueOp` | n/a (composes onto `GemmRequest`) | `evt` | |
| 62 | + |
| 63 | +## A simple GEMM |
| 64 | + |
| 65 | +```rust |
| 66 | +use atomr_accel_cutlass::{ |
| 67 | + CutlassMsg, GemmEpilogue, GemmLayout, GemmRequest, GemmShape, SmArch, |
| 68 | +}; |
| 69 | +use half::f16; |
| 70 | + |
| 71 | +let req = GemmRequest::<f16> { |
| 72 | + arch: SmArch::Sm90a, |
| 73 | + shape: GemmShape::new(4096, 4096, 4096), |
| 74 | + layout_a: GemmLayout::RowMajor, |
| 75 | + layout_b: GemmLayout::ColMajor, |
| 76 | + layout_c: GemmLayout::RowMajor, |
| 77 | + epilogue: GemmEpilogue::LinearReLU { alpha: 1.0, beta: 0.0 }, |
| 78 | + /* a/b/c GpuRefs, reply channel … */ |
| 79 | +}; |
| 80 | + |
| 81 | +cutlass.tell(CutlassMsg::Gemm(Box::new(req))); |
| 82 | +``` |
| 83 | + |
| 84 | +## EVT — fused epilogue chains |
| 85 | + |
| 86 | +`cutlass-evt` unlocks the epilogue visitor tree emitter — the way |
| 87 | +to chain post-GEMM ops (bias-add, activation, dropout, scale, |
| 88 | +quantize, reduce) into a single launch. Build with `EvtBuilder`: |
| 89 | + |
| 90 | +```rust |
| 91 | +#[cfg(feature = "cutlass-evt")] |
| 92 | +use atomr_accel_cutlass::{EpilogueOp, EpilogueVisitorTree, EvtBuilder}; |
| 93 | + |
| 94 | +let tree: EpilogueVisitorTree = EvtBuilder::new() |
| 95 | + .scale(1.0 / 8.0) |
| 96 | + .add_bias(/* bias GpuRef */) |
| 97 | + .activation(EpilogueOp::Gelu) |
| 98 | + .quantize_to_fp8() |
| 99 | + .build()?; |
| 100 | + |
| 101 | +let req = GemmRequest { /* … */, epilogue: tree.into_epilogue() }; |
| 102 | +``` |
| 103 | + |
| 104 | +Each EVT chain produces a unique `PlanKey` — the cache discriminates |
| 105 | +GEMM-with-EVT-A from GEMM-with-EVT-B without collision. |
| 106 | + |
| 107 | +## The plan cache |
| 108 | + |
| 109 | +`PlanCache` (LRU, capacity set at `CutlassActor` construction) |
| 110 | +stores rendered `.cu` source + lowered kernel name keyed by |
| 111 | +`(template_id, shape, dtype, arch, layout, epilogue)`. The cache |
| 112 | +saves the per-call NVRTC compile — under Strategy A a warm cache |
| 113 | +hit is microseconds, a miss is tens of seconds. |
| 114 | + |
| 115 | +```rust |
| 116 | +let props = atomr_accel_cutlass::props(/* plan_cache_capacity */ 256); |
| 117 | +let cutlass: ActorRef<CutlassMsg> = system.actor_of(props, "cutlass"); |
| 118 | +``` |
| 119 | + |
| 120 | +The cache is **per-actor**, not global. If you spawn multiple |
| 121 | +`CutlassActor`s for parallelism, each gets its own cache. The |
| 122 | +underlying NVRTC disk cache is shared (Phase 0.6), so the second |
| 123 | +actor's first call reads from disk — fast, but not as fast as an |
| 124 | +in-process LRU hit. |
| 125 | + |
| 126 | +## Refitting weights without recompile |
| 127 | + |
| 128 | +```rust |
| 129 | +use atomr_accel_cutlass::{CutlassMsg, RefitMsg}; |
| 130 | + |
| 131 | +cutlass.tell(CutlassMsg::Refit { |
| 132 | + msg: RefitMsg { |
| 133 | + plan_key: cached_key, // from a previous Gemm dispatch |
| 134 | + weights: new_bytes, // host-side; the actor stages them |
| 135 | + }, |
| 136 | + reply: Box::new(|res| { /* … */ }), |
| 137 | +}); |
| 138 | +``` |
| 139 | + |
| 140 | +Refit is for already-compiled plans. The plan key carries the |
| 141 | +template + shape + dtype + arch fingerprint; new weight bytes are |
| 142 | +copied into the kernel's bound workspace. No NVRTC pass. |
| 143 | + |
| 144 | +## Wiring into `ContextActor` |
| 145 | + |
| 146 | +```rust |
| 147 | +let cutlass = system.actor_of(atomr_accel_cutlass::props(64), "cutlass"); |
| 148 | +context.tell(ContextMsg::RegisterExtra { |
| 149 | + name: "cutlass", |
| 150 | + actor: cutlass.clone().into_dyn(), |
| 151 | +}); |
| 152 | +``` |
| 153 | + |
| 154 | +`KernelChildren::register_extra` exists exactly for siblings like |
| 155 | +this — the cutlass actor lives next to `BlasActor` / `CudnnActor` |
| 156 | +and dies with them when the context rebuilds. |
| 157 | + |
| 158 | +## Mock vs real |
| 159 | + |
| 160 | +`CutlassInner::compile_sink` is `Option<...>` so the actor records |
| 161 | +rendered `.cu` source + lowered kernel name into the plan cache |
| 162 | +even without an NVRTC actor wired in. This is the host-only test |
| 163 | +path — the smoke test exercises plan-cache discrimination without a |
| 164 | +GPU. In production set `compile_sink` to a closure that forwards |
| 165 | +to `atomr_accel_cuda::kernel::NvrtcActor`. |
| 166 | + |
| 167 | +## Canonical references |
| 168 | + |
| 169 | +- `crates/atomr-accel-cutlass/src/lib.rs` — public surface, |
| 170 | + Strategy A/B explainer, arch×dtype matrix. |
| 171 | +- `crates/atomr-accel-cutlass/src/{gemm,grouped_gemm,conv,evt}.rs` |
| 172 | + — one request type per file. |
| 173 | +- `crates/atomr-accel-cutlass/src/plan_cache.rs` — `PlanCache` |
| 174 | + + `PlanKey` (`(template_id, shape, dtype, arch, layout, |
| 175 | + epilogue)`). |
| 176 | +- `crates/atomr-accel-cutlass/src/dtype.rs` — `CutlassDtype`, |
| 177 | + `is_supported_for`, `GemmSupported`, `SmArch`. |
| 178 | +- `crates/atomr-accel-cutlass/cutlass/include/` — vendored CUTLASS |
| 179 | + headers (BSD-3-Clause). |
| 180 | +- `crates/atomr-accel-cutlass/tests/cutlass_smoke.rs` — arch×dtype |
| 181 | + smoke test (host-only). |
| 182 | +- [`docs/features-matrix.md`](../../../docs/features-matrix.md) § |
| 183 | + `atomr-accel-cutlass` — feature flags + transitive deps. |
| 184 | + |
| 185 | +## Common pitfalls |
| 186 | + |
| 187 | +- **Cold-start latency under Strategy A.** The first call to a new |
| 188 | + shape kicks off a 30–60s NVRTC compile. Pre-warm at startup by |
| 189 | + issuing a no-op `GemmRequest` for each canonical shape, or |
| 190 | + switch to Strategy B if your shape catalogue is fixed. |
| 191 | +- **Forgetting `cutlass-prebuilt` requires nvcc.** CI fails on |
| 192 | + no-GPU runners. Either keep Strategy A in CI and B in production, |
| 193 | + or self-host a CUDA-equipped builder. |
| 194 | +- **Mixing fp8 with sm_80 / sm_86.** `is_fp8_supported(arch)` is |
| 195 | + false there. The smoke test enforces this; production code |
| 196 | + should call `is_supported_for` before submitting. |
| 197 | +- **fp4 outside Blackwell.** Only sm_100 / sm_120 supports |
| 198 | + `F4E2m1`. `is_fp4_supported(arch)` returns false elsewhere. |
| 199 | +- **EVT without the feature.** Building an `EvtBuilder` chain |
| 200 | + errors at compile time when `cutlass-evt` is off — it's not |
| 201 | + plumbed through plain `GemmEpilogue`. Add the feature explicitly. |
| 202 | +- **Plan-cache reuse across GPUs of different arch.** `PlanKey` |
| 203 | + includes `arch`, so swapping a sm_80 cubin into a sm_90a context |
| 204 | + is a cache miss (correctly). Don't try to lift a cached plan to |
| 205 | + a different arch by editing the key. |
| 206 | +- **Holding a `PlanKey` past a context rebuild.** Same `KernelHandle` |
| 207 | + story as NVRTC actor — re-resolve through the actor after |
| 208 | + `ContextReady` cycles. |
0 commit comments