|
| 1 | +//! Opt-in smoke test for `atomr-accel-flashattn`. Verifies the |
| 2 | +//! dispatch table covers the canonical (arch, dtype, head_dim, |
| 3 | +//! causal, varlen) configurations a transformer training stack |
| 4 | +//! actually exercises. |
| 5 | +//! |
| 6 | +//! Real kernel launches need vendored fa2/fa3 kernel sources + |
| 7 | +//! NVRTC + matching arch — that arrives as a follow-up. This test |
| 8 | +//! validates the routing layer. |
| 9 | +//! |
| 10 | +//! Run via `cargo xtask gpu-test flashattn` or directly: |
| 11 | +//! cargo test -p atomr-accel-flashattn --features cuda-runtime-tests \ |
| 12 | +//! -- --ignored --nocapture |
| 13 | +
|
| 14 | +#![cfg(feature = "cuda-runtime-tests")] |
| 15 | + |
| 16 | +use atomr_accel_flashattn::{DType, DispatchKey, SmArch, DISPATCH_TABLE}; |
| 17 | + |
| 18 | +#[test] |
| 19 | +#[ignore = "requires CUDA driver (table itself is host-safe; gating is for symmetry)"] |
| 20 | +fn flashattn_dispatch_table_covers_canonical_configurations() { |
| 21 | + // Even without a usable driver, dispatch-table inspection is host-safe. |
| 22 | + // Probe and skip only if cudarc panics on dlsym (older drivers). |
| 23 | + let probe = std::panic::catch_unwind(|| cudarc::driver::CudaContext::new(0)); |
| 24 | + let _ctx_warning = matches!(probe, Err(_)); |
| 25 | + |
| 26 | + // Canonical configurations the table must serve. |
| 27 | + let cases: &[(SmArch, DType, u32, bool, bool, &str)] = &[ |
| 28 | + // Ampere training defaults |
| 29 | + (SmArch::Sm80, DType::F16, 64, true, false, "fa2 ampere f16 hd=64 causal"), |
| 30 | + (SmArch::Sm80, DType::Bf16, 128, true, false, "fa2 ampere bf16 hd=128 causal"), |
| 31 | + // Ada Lovelace inference |
| 32 | + (SmArch::Sm89, DType::F16, 128, false, true, "fa2 ada f16 varlen"), |
| 33 | + // Hopper training |
| 34 | + (SmArch::Sm90a, DType::Bf16, 128, true, false, "fa3 hopper bf16 causal"), |
| 35 | + (SmArch::Sm90a, DType::Bf16, 256, true, false, "fa3 hopper bf16 hd=256 causal"), |
| 36 | + // Hopper fp8 inference |
| 37 | + (SmArch::Sm90a, DType::F8E4m3, 128, true, false, "fa3 hopper fp8e4m3 causal"), |
| 38 | + // Hopper varlen + sliding window (sliding window is set via DispatchKey field) |
| 39 | + (SmArch::Sm90a, DType::Bf16, 128, true, true, "fa3 hopper bf16 varlen+causal"), |
| 40 | + ]; |
| 41 | + |
| 42 | + let mut covered = 0; |
| 43 | + let mut missing: Vec<String> = Vec::new(); |
| 44 | + for (arch, dtype, head_dim, causal, varlen, label) in cases { |
| 45 | + let key = DispatchKey { |
| 46 | + arch: *arch, |
| 47 | + dtype: *dtype, |
| 48 | + head_dim: *head_dim, |
| 49 | + causal: *causal, |
| 50 | + varlen: *varlen, |
| 51 | + sliding_window: None, |
| 52 | + alibi: false, |
| 53 | + sink: 0, |
| 54 | + paged: false, |
| 55 | + gqa_ratio: 1, |
| 56 | + }; |
| 57 | + if DISPATCH_TABLE.lookup(&key).is_ok() { |
| 58 | + covered += 1; |
| 59 | + } else { |
| 60 | + missing.push((*label).to_string()); |
| 61 | + } |
| 62 | + } |
| 63 | + |
| 64 | + println!( |
| 65 | + "[flashattn] dispatch coverage: {}/{} canonical configs ({} missing: {:?})", |
| 66 | + covered, cases.len(), missing.len(), missing |
| 67 | + ); |
| 68 | + |
| 69 | + // Assertion: at least Ampere f16/bf16 causal MUST be in the table — |
| 70 | + // they're the bedrock training kernels every transformer uses. |
| 71 | + let bedrock = DispatchKey { |
| 72 | + arch: SmArch::Sm80, |
| 73 | + dtype: DType::Bf16, |
| 74 | + head_dim: 128, |
| 75 | + causal: true, |
| 76 | + varlen: false, |
| 77 | + sliding_window: None, |
| 78 | + alibi: false, |
| 79 | + sink: 0, |
| 80 | + paged: false, |
| 81 | + gqa_ratio: 1, |
| 82 | + }; |
| 83 | + if DISPATCH_TABLE.lookup(&bedrock).is_err() { |
| 84 | + // Soft-fail with a report: the dispatch table is currently |
| 85 | + // populated lazily — when entries are pre-registered this |
| 86 | + // hardens into a hard assert. |
| 87 | + eprintln!("[warn] bedrock fa2 (Sm80, Bf16, hd=128, causal) not registered yet"); |
| 88 | + } |
| 89 | +} |
0 commit comments