All notable changes to this project will be documented in this file.
The format is based on Keep a Changelog, and this project adheres to Semantic Versioning.
- Medium-shape bench numbers (`docs/benchmarks.md`). Three-way comparison at nbasis=512, nocc=64, nvir=448, naux=1536 (4096 pairs): torch warm total 9.795 s; fused-gemm 10.877 s (+11%); batched-pair 7.111 s (CPU fallback, because the NEFF compile is blocked; see below). See `benchmarks.md` for the full table.
- Medium-shape batched-pair NEFF compile fails: XLA graph is 18 GB.
  `nl.affine_range` traces all loop iterations eagerly into the XLA IR at compile time. At nocc=64, nvir=448, naux=1536 (4096 pairs × ~192 inner tile operations), the NKI source JSON reaches 18 GB. The trn1 root volume had 16 GB free. Both `/tmp` and `/var/tmp` are on the same 96 GB EBS root volume, so no filesystem routing fix is possible. For context, the small shape (nocc=16, 256 pairs) produces a ~240 MB JSON and compiles fine. Medium is 75× larger because it has more pairs and larger nvir/naux.

  Fix (deferred, issue #47): chunked dispatch. Call the batched-pair kernel with ~256 pairs per dispatch (16 calls for nocc=64). Each chunk's XLA graph is expected to be ~240 MB, and total dispatch overhead ~1.6 s (vs 409 s per-pair).
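The chunked-dispatch plan comes down to index arithmetic. Below is a minimal sketch. It assumes row-major flattening of (i, j) into `i * nocc + j` and the ~100 ms fixed per-dispatch cost measured for the per-pair kernel. `chunk_pairs`, `pair_indices`, and `dispatch_overhead_s` are hypothetical helpers, not trnblas API.

```python
def chunk_pairs(nocc: int, pairs_per_chunk: int = 256) -> list:
    """Split the nocc*nocc flattened (i, j) pair index space into chunks."""
    n_pairs = nocc * nocc
    return [
        range(start, min(start + pairs_per_chunk, n_pairs))
        for start in range(0, n_pairs, pairs_per_chunk)
    ]


def pair_indices(flat: int, nocc: int) -> tuple:
    """Undo row-major flattening: flat = i * nocc + j."""
    return divmod(flat, nocc)


def dispatch_overhead_s(n_dispatches: int, per_dispatch_s: float = 0.1) -> float:
    """First-order model: a fixed cost per dispatch, kernel compute ignored."""
    return n_dispatches * per_dispatch_s


chunks = chunk_pairs(nocc=64)
print(len(chunks))                                 # 16
print(round(dispatch_overhead_s(len(chunks)), 1))  # 1.6
print(round(dispatch_overhead_s(64 * 64), 1))      # 409.6
```

The model ignores kernel compute time, so it only bounds the overhead term. The last chunk is simply shorter when nocc² is not a multiple of the chunk size.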
- `TMPDIR=/var/tmp` added to SSM runners (`run_df_mp2_bench.sh`, `run_neuron_tests.sh`, `run_pyscf_tests.sh`) and Terraform user-data as a defensive measure. It does not unblock medium-shape batched-pair compilation (both paths share the same filesystem). It is still correct hygiene for future NKI kernels that produce large intermediate files.
- Batched-pair energy kernel (#43, `nki_batched_pair_energy`). A single `@nki.jit` dispatch computes the full DF-MP2 pair energy for all NOCC² orbital pairs. This eliminates the ~100 ms × nocc² Neuron XLA per-dispatch overhead that made `nki_fused_gemm_energy` impractical (215× slower than chunk-GEMM at NOCC=16 / 256 pairs).

  Design: the kernel has five levels of `nl.affine_range` (i → j → a-strip → b-strip → k-tile). For each (i, j, a, b) combination, two GEMMs are computed as in `_fused_gemm_energy_kernel`:
  - GEMM 1: `T[a,b] = B[i][a_strip,:] @ B[j][b_strip,:].T`
  - GEMM 2: `T_T[a,b] = B[j][a_strip,:] @ B[i][b_strip,:].T`

  Both land in SBUF via `tensor_copy`. The VE energy expression and free-dim reductions all run in SBUF/PSUM. Output is `(TILE, NOCC²)`, and the host calls `.sum()` for the scalar energy. There are no partition-axis (axis=0) reductions inside the kernel.

  3D NKI indexing validated on trn1 (2026-04-17): `nl.load_transpose2d(B[i, a_off:a_off+T, k_off:k_off+K])` with `i` as an `nl.affine_range` loop variable compiles correctly and produces accurate results. Both `nl.load` and `nl.load_transpose2d` were confirmed via the Spike A / Spike B scripts. Spike B warm time: 2 ms for NOCC=4 (16 pairs), vs ~1.6 s for the per-pair loop. That is an 800× reduction in dispatch overhead at the same NOCC.

  Public API: `trnblas.nki.nki_batched_pair_energy(B, eps_occ, eps_vir)` → 0-d scalar. `B`: `(nocc, nvir, naux)`, `eps_occ`: `(nocc,)`, `eps_vir`: `(nvir,)`.

  Example integration: `examples/df_mp2.py --batched-pair-energy` routes the energy step through the new kernel.

  Tests: `TestBatchedPairEnergy` in `tests/test_nki_gemm.py` covers:
  - aligned and unaligned correctness (`atol=1e-2` over the full NOCC² sum);
  - `test_matches_fused_gemm_energy` (against the per-pair sum);
  - zero-B;
  - `test_dispatch_overhead` (cold/warm/per-pair-loop timing).
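A host-side NumPy reference for the same quantity can help when checking the kernel's output. The sketch below assumes the standard closed-shell DF-MP2 expression implied by the `T*(2T-Tᵀ)/denom` reduction, with `denom = eps_occ[i] + eps_occ[j] - eps_vir[a] - eps_vir[b]`. Both function names are hypothetical, and neither reproduces the kernel's tiling.

```python
import numpy as np


def pair_energy_reference(B, eps_occ, eps_vir):
    """Loop over all (i, j) pairs with the two-GEMM formulation.

    B: (nocc, nvir, naux); eps_occ: (nocc,); eps_vir: (nvir,).
    """
    nocc = B.shape[0]
    energy = 0.0
    for i in range(nocc):
        for j in range(nocc):
            T = B[i] @ B[j].T      # GEMM 1: T[a, b]
            T_T = B[j] @ B[i].T    # GEMM 2: equals T[b, a], computed rather than transposed
            denom = eps_occ[i] + eps_occ[j] - eps_vir[:, None] - eps_vir[None, :]
            energy += float(np.sum(T * (2.0 * T - T_T) / denom))
    return energy


def pair_energy_vectorized(B, eps_occ, eps_vir):
    """Same quantity via one einsum over all pairs (memory grows as nocc² nvir²)."""
    T = np.einsum("iaP,jbP->ijab", B, B)
    denom = (eps_occ[:, None, None, None] + eps_occ[None, :, None, None]
             - eps_vir[None, None, :, None] - eps_vir[None, None, None, :])
    return float(np.sum(T * (2.0 * T - T.transpose(0, 1, 3, 2)) / denom))


rng = np.random.default_rng(0)
B = rng.standard_normal((4, 6, 10))
eps_occ = -1.0 - rng.random(4)   # occupied orbital energies: negative
eps_vir = 1.0 + rng.random(6)    # virtual orbital energies: positive
print(np.isclose(pair_energy_reference(B, eps_occ, eps_vir),
                 pair_energy_vectorized(B, eps_occ, eps_vir)))  # True
```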
- `[tool.uv]` dev-machine support. Added `exclude-dependencies` for `neuronxcc`, `torch-neuronx`, and `nki` so `uv run` / `uv sync` resolves on machines without Trainium hardware.
- Fused GEMM+energy kernel (#38, `nki_fused_gemm_energy`). A single `@nki.jit` kernel handles one DF-MP2 orbital pair: both GEMMs (T and T_T) and the VE energy expression. It does this without writing the `(nvir, nvir)` `T_flat` intermediate to HBM.

  Two-GEMM T_T strategy: `T.T[a,b] = T[b,a] = (B_j @ B_i.T)[a,b]`. Rather than using `nl.load_transpose2d` to read T back from HBM (which would reintroduce the HBM round-trip), the kernel computes T_T as a second GEMM tile in the same kernel body. Both T and T_T land in SBUF via `tensor_copy`, so neither intermediate is written to HBM.

  Kernel design:
  - `TILE = 128` everywhere (`nl.load_transpose2d` constrains both dims to ≤ 128).
  - Outer a-loop, inner b-loop. There are two sequential PSUM allocations per (a, b) tile, one for T and one for T_T. The VE energy expression is fully SBUF-resident.
  - Cross-b batching: `acc_b (TILE, N_B_TILES)` in SBUF accumulates all b-strip partials before one `nl.store` per a-strip. This is the same pattern as `_mp2_energy_kernel`.
  - The NEFF cache amortises the two-GEMM compile across all `nocc²` pairs (same shape on every invocation).
  - NKI 0.3.0 broadcast fix applied to `denom` construction (same as `_mp2_energy_kernel`).

  Public API: `trnblas.nki.nki_fused_gemm_energy(b_i, b_j, eps_occ_i, eps_occ_j, eps_vir)` → scalar.

  Example integration: `examples/df_mp2.py --fused-gemm-energy` routes the energy step through the per-pair kernel. The default path remains the chunk-GEMM path; see the benchmark note below.

  On-hardware benchmark (trn1, small shape: nbasis=128, nocc=16, nvir=112, naux=384; 256 pairs):

  | Step | Baseline warm | Fused warm |
  |---|---|---|
  | energy | 0.13 s | 27.8 s |
  | total | 3.98 s | 31.5 s |

  The fused kernel is correct (energies agree to 6 significant figures), but the per-pair loop is 215× slower on the energy step. Root cause: Neuron XLA imposes ~100 ms of per-NEFF-dispatch overhead, independent of kernel compute time. 256 pairs × 100 ms = 25.6 s, close to the 27.8 s observed. The chunk-GEMM baseline amortises this overhead over only two dispatches in total.

  Pre-transferring B to the XLA device and accumulating on-device (eliminating per-pair CPU syncs) produces the same warm timing. This is because Neuron XLA's per-dispatch overhead is in the dispatch pipeline itself, not in the CPU→XLA transfer.

  Follow-on: a production speedup requires a batched kernel that processes all nocc² pairs in one `@nki.jit` invocation, tracked in #43.

  Tests: `TestFusedGemmEnergy` in `tests/test_nki_gemm.py`: aligned/unaligned correctness (`atol=1e-2`), symmetry (`E(i,j) == E(j,i)`), zero-B_i, NEFF cache reuse (cold vs warm timing).
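A NumPy model of the per-pair tiling makes the two-GEMM T_T strategy and cross-b batching concrete. It is an illustration only:
- it runs on the host;
- `fused_pair_energy_tiled` is a hypothetical name;
- array slicing stands in for SBUF tiles, PSUM, and `nl.store`.

```python
import numpy as np

TILE = 128  # the kernel's tile size; nl.load_transpose2d caps both dims at 128


def fused_pair_energy_tiled(b_i, b_j, eps_occ_i, eps_occ_j, eps_vir, tile=TILE):
    """Host-side model of one (i, j) pair: tile GEMMs plus cross-b batching."""
    nvir = b_i.shape[0]
    n_b_tiles = -(-nvir // tile)                     # ceil(nvir / tile)
    strip_totals = []
    for a0 in range(0, nvir, tile):
        a = slice(a0, min(a0 + tile, nvir))
        # acc_b: one column per b-strip, filled before a single "store" per a-strip.
        acc_b = np.zeros((a.stop - a.start, n_b_tiles))
        for nb, b0 in enumerate(range(0, nvir, tile)):
            b = slice(b0, min(b0 + tile, nvir))
            T = b_i[a] @ b_j[b].T                    # GEMM 1: T[a, b]
            T_T = b_j[a] @ b_i[b].T                  # GEMM 2: T[b, a] for the same tile
            denom = eps_occ_i + eps_occ_j - eps_vir[a][:, None] - eps_vir[b][None, :]
            # Free-dim (axis=1) reduction only; each row keeps its own partial.
            acc_b[:, nb] = np.sum(T * (2.0 * T - T_T) / denom, axis=1)
        strip_totals.append(acc_b.sum())             # stands in for one nl.store per a-strip
    return float(sum(strip_totals))
```

With nvir=300 and `tile=128`, the last strip is only 44 wide. This is the unaligned case that the tests exercise.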
- NKI closure variable limitation in autotuner (#26 regression). The v0.5.0 `_make_gemm_kernel` factory returned a `@nki.jit` closure that referenced tile sizes (`tm`, `tk`, `tn`) as Python free variables. NKI's AST-based compiler reads source from the on-disk file and resolves names from the local namespace only. It cannot traverse closure cells, so every tile config produced `error: unbound variable 'tm'`.

  Fix: replaced the factory with six static `@nki.jit` kernel definitions at module level (`_gemm_kernel_64_128_128` … `_gemm_kernel_128_128_512`), each with literal integer tile constants. All six are registered in `_gemm_kernel_registry` at import time, and `_get_gemm_kernel()` is now a dict lookup. `_make_gemm_kernel` is removed. Autotuner behaviour (sweep, cache, escape hatch) is unchanged.

  Root cause note: NKI `@nki.jit` functions must have tile constants visible as literal integers or module-level globals at AST trace time. Closure variables from an enclosing factory scope are not reachable.
- GEMM tile-shape autotuner (#26). `_nki_gemm_impl` now sweeps six tile candidates `{64,128} × {128} × {128,256,512}` on the first call per shape bucket and caches the winner to disk (`/var/tmp/trnblas-autotune/cache.json`, overridable via `TRNBLAS_AUTOTUNE_CACHE`). Subsequent calls for the same bucket hit the in-process dict first, then the NEFF cache, with no re-sweep. Padding is now computed after tile selection, so alignment always matches the chosen tile sizes. Opt out with `TRNBLAS_AUTOTUNE=0` to restore the v0.4.x fixed `(128,128,512)` behaviour. Backward compatible: the `_gemm_kernel` alias is preserved, and the `_TILE_M/K/N` fallback constants are kept for SYRK/TRSM.
  - `_make_gemm_kernel(tile_m, tile_k, tile_n)` — factory returning a new `@nki.jit` closure with tile values baked in at trace time; each config produces a separately cached NEFF.
  - `_get_gemm_kernel(...)` — registry wrapper (call once per config, reuse).
  - `_autotune_bucket(M, K, N)` — `ceil_pow2` coarse bucket; all shapes in a DF-MP2 run land in the same bucket.
  - `_sweep_tile_configs` / `_sweep_on_default_pad` — hardware timing (3 warm runs each), with a safe fallback to the default on per-config errors.
  - Persistent JSON cache with atomic directory creation.
  - `TestAutotuner` class in `tests/test_nki_gemm.py`: escape hatch, cache hit (no re-sweep), persistent-cache round-trip, per-config correctness, hardware sweep.
- #33 resolved — `_mp2_energy_kernel` profile via Neuron Profiler 2.0. The April-14 profiling attempt was blocked by the deprecated Neuron 2.29 API (`inspect` / `show-session` → NTFF v130 format mismatch). Retried April-15 using the new `neuron-profile capture` + `view --output-format summary-text` API (no InfluxDB, no browser). Key findings at medium bench shape (trn1.2xlarge, neuronxcc 2.24.5133):
  - Vector Engine: 96.45% active. The entire reduction (`T*(2T-Tᵀ)/denom`) is element-wise arithmetic; the Tensor Engine runs 21 instructions in 0.48 µs.
  - HBM reads: 6.58 GB, which matches the analytical 2-pass prediction exactly. The previous napkin estimate of 33 GB was for the unfused torch path.
  - Kernel wall time: ~0.21 s vs ~2.83 s for the torch reduction, a ~13× speedup on the reduction step alone.
  - Amdahl ceiling: 1.48×. The GEMM (`T_flat = B_chunk @ B_flat.T`) takes ~5.2 s in both paths and is 96% of the fused energy step time. The 1.48× overall result is an exact Amdahl prediction (f=0.35, s=13.5 → 1.48×), not a tuning failure. Reaching 3× requires fusing the GEMM into the energy pass (Phase 3 RFC) or a different algorithm.

  See `docs/design/mp2_energy_profile_findings.md` for full data and next steps. `scripts/run_neuron_profile.sh` is updated to the Neuron Profiler 2.0 API and sends base64-encoded SSM commands. This bypasses all shell-quoting issues and supports a `--probe` mode for API discovery.
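The 1.48× figure can be reproduced from numbers in this changelog. The sketch reads f as the reduction's share of the torch energy step (2.83 s of the 8.03 s medium energy time under #15 M2.2) and s as the kernel's speedup on that share (2.83 s / 0.21 s). Both readings are inferred, not stated in the profile notes.

```python
def amdahl_speedup(f: float, s: float) -> float:
    """Overall speedup when a fraction f of the runtime is sped up by a factor s."""
    return 1.0 / ((1.0 - f) + f / s)


f = 2.83 / 8.03    # torch reduction share of the medium energy step (~0.35)
s = 2.83 / 0.21    # kernel speedup on the reduction alone (~13.5)
print(round(amdahl_speedup(0.35, 13.5), 2))   # 1.48
print(round(amdahl_speedup(f, s), 2))          # 1.48
print(round(1.0 / (1.0 - 0.35), 2))           # 1.54, the limit as s grows without bound
```

The last line shows that even an infinitely fast reduction would cap the energy step near 1.54×, which is why the entry points at the GEMM.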
- #35 — cross-pair batching in `_mp2_energy_kernel` (store-fence hypothesis falsified). Restructured the kernel to accumulate all NOCC pair partials for each `i`-row in SBUF before a single batched HBM store. This replaces IC×NOCC per-pair stores with IC stores per chunk (4,096 → 64 stores at medium; 9,216 → 96 at large). Measured result on trn1 (warm NEFF cache):
  - medium: energy 5.43 s → 5.38 s (within noise; 1.49× vs torch)
  - large: energy 30.27 s → 30.13 s (within noise; 1.47× vs torch)

  The per-pair HBM-store fence hypothesis is now falsified: the compiler appears to tolerate the store traffic without serializing across pairs. The kernel keeps the batched stores (no functional regression, 64× fewer stores at medium), and the 1.47× ceiling stands. The ceiling is now explained by Amdahl (#33 profile), leaving no unexplained hypotheses. Tracked on #31.
- Migrated to NKI 0.3.0 / Neuron SDK 2.29. Canonical `nki.*` namespace; the legacy `neuronxcc.nki.*` shim is no longer used. The `[neuron]` extra now requires `nki>=0.3.0`. Kernels updated for the NKI 0.3.0 breaking-change surface:
  - `nc_matmul` → `nisa.nc_matmul(dst=, stationary=, moving=, accumulate=)`. All arguments are keyword arguments, and the internal `accumulate` replaces the external `psum[...] += ...`.
  - `nl.copy(psum, ...)` returns a view; use `nl.ndarray` + `nisa.tensor_copy` to move PSUM → SBUF before `nl.store`.
  - Tensor-tensor `nl.divide` dropped; use `multiply` × `reciprocal`.

  Kernels migrated: `_gemm_kernel`, `_syrk_kernel`, and the kernel factory in `scripts/autotune_gemm.py`. 32 neuron tests pass.
- `_mp2_energy_kernel` re-skipped pending the #15 M2 redesign. NKI 0.3.0's stricter tensor-tensor broadcast rules reject the current kernel's `(1,1) - (P_TILE,1)` partition-dim pattern. M1 work (namespace + free-dim reduction + partition-major HBM output) is preserved in the kernel source; only the scalar-vs-partition subtract needs rewriting.
- #15 M2.1a — `_mp2_energy_kernel` broadcast correctness fix. `denom` construction now lifts all three eps operands to `(P_TILE, NVIR)` via `nl.broadcast_to` before subtracting, so every op sees matching partition dims. 5 `TestNkiKernel` MP2 cases re-enabled (37/37 neuron tests pass on trn1). The kernel structure is the same as M1: same tile geometry, per-strip accumulator, and `(P_TILE, IC, NOCC)` output layout.
- #15 M2.2 — measured DF-MP2 perf on trn1 (warm NEFF cache). `trnblas.nki.nki_mp2_energy` now beats the torch reduction but misses the RFC's 3–5× target:
  - medium (nbasis=512, nocc=64, nvir=448): energy 8.03 s → 5.43 s (1.48×); total 9.79 s → 7.29 s.
  - large (nbasis=768, nocc=96, nvir=672): energy 44.57 s → 30.27 s (1.47×); total 50.00 s → 35.76 s.
  - Agreement with the torch reference within `atol=1e-4, rtol=1e-4` at both shapes.

  The speedup ratio is roughly shape-invariant, because per-(i,j) launch cost scales with `nocc²` the same way torch does. The RFC is updated to "Shipped (M2.1)" status. Perf follow-ups (larger free-dim tiles, an atomic-add variant, multi-engine pipeline evidence) are tracked on the same #15 issue. `examples/df_mp2.py --fused-energy` exposes the kernel; the default path remains torch until a future milestone hits 3×.
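NumPy would silently broadcast the `(1, 1) - (P_TILE, 1)` pattern that NKI 0.3.0 rejects, so the M2.1a fix is easiest to see with the lifting written out. This is a sketch only. `np.broadcast_to` stands in for `nl.broadcast_to`. The three argument shapes are an assumed layout for the three eps operands, and `denom_tile` is a hypothetical name.

```python
import numpy as np


def denom_tile(eps_ij, eps_vir_strip, eps_vir_row):
    """Return the (P_TILE, NVIR) MP2 denominator with every operand lifted first.

    eps_ij:        (1, 1)       eps_occ[i] + eps_occ[j]
    eps_vir_strip: (P_TILE, 1)  eps_vir for this strip's a-rows (partition axis)
    eps_vir_row:   (1, NVIR)    eps_vir for all b-columns (free axis)
    """
    shape = (eps_vir_strip.shape[0], eps_vir_row.shape[1])
    occ = np.broadcast_to(eps_ij, shape)          # stands in for nl.broadcast_to
    vir_a = np.broadcast_to(eps_vir_strip, shape)
    vir_b = np.broadcast_to(eps_vir_row, shape)
    # Every subtraction now sees identical partition (axis 0) extents.
    return occ - vir_a - vir_b
```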
- `examples/df_mp2.py` — `--fused-energy` opt-in flag; threads `use_fused` through `df_mp2_energy` / `_energy_reduction`.
- `scripts/run_df_mp2_bench.sh` — `--compare` runs torch + fused back-to-back in one SSM session for A/B perf measurement. A `stopping` instance state no longer blocks back-to-back runs.
- NKI CPU simulator dispatch via `TRNBLAS_USE_SIMULATOR=1`. Routes kernels through `nki.simulate(kernel)(numpy_args)` on CPU, bypassing `torch_xla` + NEFF compile. The iteration loop drops from ~8–12 min per cycle to seconds. `_nki_{gemm,syrk,mp2_energy}_impl` all carry the simulator branch; `_nki_trsm_left` plumbs through transitively via `nki_gemm`. Correctness-only: no perf modelling and no SBUF capacity checks. See `docs/developing_kernels.md`.
- `tests/test_nki_sim.py` — curated simulator-backed correctness suite, marker `nki_simulator`. Skips unless `TRNBLAS_USE_SIMULATOR=1` is set and `nki` is importable.
- `scripts/run_simulator_tests.sh` — SSM runner that runs the simulator suite on the trn1 DLAMI.
- `nki-simulator` CI job on `ubuntu-latest`. Runs the `nki_simulator`-marked suite against `nki>=0.3.0` from the AWS pip index (`--extra-index-url https://pip.repos.neuron.amazonaws.com`) on every push and PR.
  - The correctness gate costs nothing on AWS; hardware SSM is now reserved for perf and MLIR verification.
  - Of the five NKI 0.3.0 breaking changes trnblas navigated, four would have surfaced on this gate pre-merge. The fifth (partition-broadcast strictness, MLIR-level) still requires hardware, because NKI 0.3.0 has no documented device-free NEFF compile API.
  - The main `test` matrix now excludes the `nki_simulator` marker to avoid running the suite twice.
- `docs/developing_kernels.md` — kernel authoring guide: three dispatch modes (pytorch / hardware / simulator), simulator limitations, NKI 0.3.0 migration reference, architecture-exploitation design discipline.
- `nki_mp2_energy` M1 landed: the kernel now correctly produces `(P_TILE, IC, NOCC)` per-partition partials on real NKI under `TRNBLAS_REQUIRE_NKI=1`. All 5 previously skipped `TestNkiKernel` tests pass on trn1 across `nvir ∈ {8, 16, 64, 256, 448}`. A host `.sum()` reduces them to the final scalar energy; the partial is ≤ 258 KB, so this host step costs next to nothing.
- `docs/design/fused_df_mp2_energy_kernel.md` — architectural RFC for the M2 fused pair-energy kernel. It is a Phase 3 follow-up that uses M1's reduction pattern as a building block. M1's reduction pattern:
  - SBUF persistence across the strip loop (the per-partition buffer lives on-chip across all NSTRIP iterations).
  - Scalar Engine free-dim reduction via `nl.sum(axis=1)`, the only reduction axis NKI permits.
  - Partition-major HBM output, so the `(P_TILE, 1)` SBUF tile stores with axis-to-axis alignment (no partition-dim reshape, which the BIR verifier rejects).
  - Amortised dispatch: IC × NOCC (i, j) pairs per kernel launch.
| Error | Pattern rejected | Pattern that works |
|---|---|---|
| `partitions … exceed 128` | `nl.load(1D_tensor)` of length > 128 | Reshape to `(1, N)` at the caller |
| `Reduction on partition axes is not supported` | `nl.sum(tile, axis=(0,1))` | Free-dim reduce only; accumulate per partition |
| `illegal partition step` (BIR) | Reshape SBUF partition ↔ free | Plan the tensor layout so the partition axis aligns throughout; never reshape |
| `Unexpected output dependencies, missing indices in dst` | `acc[...] = nl.add(acc, x)` inside `affine_range` | Per-iteration SBUF slots, reduce after the loop |

Each constraint was surfaced by real hardware under
`TRNBLAS_REQUIRE_NKI=1`. The silent-fallback era would have masked
all of them.
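The right-hand column of the table can be modelled on the host with NumPy. The model uses distinct per-iteration slots filled by free-dim reductions and defers the partition-axis sum to the host, as M1 does with its `(P_TILE, IC, NOCC)` output. NumPy enforces none of these constraints, so this illustrates data layout only. `strip_partials` is a hypothetical name.

```python
import numpy as np


def strip_partials(tiles):
    """Model of the working pattern: one slot per loop iteration, free-dim sums only.

    tiles: sequence of (P_TILE, F) arrays, one per affine_range iteration.
    Returns (P_TILE, n_iter) partials; nothing here reduces across axis 0.
    """
    p_tile = tiles[0].shape[0]
    partials = np.empty((p_tile, len(tiles)))
    for it, tile in enumerate(tiles):
        # Write a distinct slot instead of acc = acc + x (the rejected pattern).
        partials[:, it] = tile.sum(axis=1)
    return partials


rng = np.random.default_rng(0)
tiles = [rng.standard_normal((128, 64)) for _ in range(5)]
partials = strip_partials(tiles)
energy = partials.sum()   # the partition-axis reduction happens on the host
```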
## 0.4.3 — 2026-04-13
The SSM runners in v0.4.0–v0.4.2 invoked the Neuron venv's `python`
directly without prepending its `bin/` to `$PATH`. `torch_neuronx`'s
initializer calls `subprocess.run(["libneuronpjrt-path"])` to locate
the PJRT plugin library. That binary lives in the venv's `bin/`, so it
couldn't be resolved. Every NKI dispatch raised `FileNotFoundError`,
which our `_nki_*_impl` try/except wrappers swallowed before falling back
to `torch.matmul`. As a result, every "trn1 NKI" perf number
published in v0.4.0 / v0.4.1 / v0.4.2 measured trn1's 8-vCPU Xeon, not the
Tensor Engine.

Correctness tests still passed because `torch.matmul` gives the same
answer as `nki_gemm`; only perf attribution was wrong. The v0.4.2
cross-vendor comparison vs A10G was also mislabeled: we were
comparing A10G's GPU to trn1's Xeon.

Real NKI dispatch is now verified (commit d1b481f). The cold call
includes NEFF compile (seconds), and warm dispatches show real Tensor
Engine execution. The `docs/benchmarks.md` tables are re-measured and
prefaced with a retraction banner.
- `scripts/run_neuron_tests.sh`, `scripts/run_df_mp2_bench.sh` — prepend `$NEURON_VENV/bin` to `$PATH` in the SSM `env` line so `torch_neuronx`'s PJRT plugin lookup can resolve. Tests now also run with `TRNBLAS_REQUIRE_NKI=1`, so future silent-fallback regressions fail loudly.
- `trnblas.nki.nki_mp2_energy` kernel tests skipped (#15). The kernel has a partition-limit bug (`nl.load(eps_vir[0:NVIR])` exceeds 128 partitions for `nvir > 128`) that the silent fallback masked. It is not in the production DF-MP2 path; the kernel rewrite is tracked under #15.
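The runner fix is a shell `env` change. The Python illustration below shows the same PATH manipulation and a preflight check that `libneuronpjrt-path` now resolves. It assumes a POSIX host. Both helper names and the example venv path are hypothetical. `shutil.which` with an explicit `path=` argument is standard library.

```python
import os
import shutil


def env_with_venv_bin(venv: str, base_env=None) -> dict:
    """Copy the environment and put <venv>/bin first on PATH."""
    env = dict(os.environ if base_env is None else base_env)
    venv_bin = os.path.join(venv, "bin")
    existing = env.get("PATH")
    env["PATH"] = venv_bin + os.pathsep + existing if existing else venv_bin
    return env


def pjrt_lookup_resolves(env: dict) -> bool:
    """torch_neuronx runs `libneuronpjrt-path` by name, so it must be on PATH."""
    return shutil.which("libneuronpjrt-path", path=env.get("PATH")) is not None
```

Running `pjrt_lookup_resolves` before any NKI dispatch turns the silent `FileNotFoundError` class of failure into an explicit check.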
- `trnblas.nki.NkiFallbackWarning` — emitted once per distinct error when the NKI path silently falls back to torch. This makes misconfigured environments visible without requiring `TRNBLAS_REQUIRE_NKI=1`. Emitted via `warnings.warn` with a custom category.
- `tests/test_nki_really_runs.py` — anti-regression test that forces `TRNBLAS_REQUIRE_NKI=1` and asserts a GEMM dispatch completes. It would have caught the v0.4.0 regression on day one.
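The fallback contract described in these entries can be sketched with the standard `warnings` module:
- warn once per distinct error;
- re-raise when `TRNBLAS_REQUIRE_NKI=1`;
- otherwise use the torch path.

The class here is a stand-in (its `UserWarning` base is assumed). `run_with_fallback` and the error-key format are hypothetical, not trnblas's actual wrapper.

```python
import os
import warnings


class NkiFallbackWarning(UserWarning):
    """Stand-in for trnblas.nki.NkiFallbackWarning; the base class is assumed."""


_already_warned = set()


def run_with_fallback(nki_fn, torch_fn, *args):
    """Try the NKI path; on failure warn once per distinct error, then use torch."""
    try:
        return nki_fn(*args)
    except Exception as exc:
        if os.environ.get("TRNBLAS_REQUIRE_NKI") == "1":
            raise                                   # validation runs fail loudly
        key = f"{type(exc).__name__}: {exc}"
        if key not in _already_warned:
            _already_warned.add(key)
            warnings.warn(f"NKI dispatch failed ({key}); falling back to torch",
                          NkiFallbackWarning, stacklevel=2)
        return torch_fn(*args)
```

Keying on the exception text means each distinct root cause is reported once, whatever Python's own warning filters would otherwise do.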
Under real NKI dispatch (commit fd56274, trn1.2xlarge, neuronxcc
2.24.5133):
| Op | Shape | v0.4.x "trn1 NKI" (was CPU fallback) | v0.4.3 trn1 NKI (real) |
|---|---|---|---|
| GEMM warm | 1024³ | 4.5 ms | 2.3 ms |
| SYRK warm | 512² | 2.45 ms | 2.14 ms |
| SYRK warm | 1024² | 7.91 ms | 5.71 ms |
| TRSM warm | 512² | 6.05 ms | 5.59 ms |
| TRSM warm | 2048×512 | 27.75 ms | 35.82 ms |
| DF-MP2 medium warm | — | 9.77 s | 9.91 s |
The relative A10G vs trn1 ratios are in a similar 19–45× range; the cross-vendor story's shape is unchanged, only the attribution is fixed.
### Added (carried from Unreleased into this release)
- `trnblas.nki.nki_trsm` — blocked panel TRSM (#19). Diagonal panels are solved via `torch.linalg.solve_triangular` (small, sequential). Trailing off-diagonal updates run through `nki_gemm`, which carries the dominant work for large M. Covers all `{lower, upper} × {trans, not} × {unit, nonunit}` combinations for `side="left"`; `side="right"` falls back to torch. 7/7 `@pytest.mark.neuron` tests pass on trn1 under real NKI dispatch.
- `trnblas.nki.nki_syrk` — NKI SYRK kernel (#18). Loads `A` once from HBM and reuses it for both operand roles via two `load_transpose2d` calls. This avoids the materialised `A.T.contiguous()` that `nki_gemm(A, A.T)` would otherwise write. 7/7 `@pytest.mark.neuron` tests pass on trn1 under real NKI.
- `examples/bench_syrk.py`, `examples/bench_trsm.py` — per-op timing scripts (cpu / cuda / trn1) feeding the cross-vendor table.
- `scripts/autotune_gemm.py` — GEMM tile-config study harness (#26). Paused during this correction release; resume in v0.5.0.
- `scripts/probe_nki.py` — one-shot NKI health probe (diagnostic for the silent-fallback class of bug).
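The panel structure of `nki_trsm` can be sketched in NumPy for one of the eight cases (lower, no transpose, non-unit, `side="left"`). In the sketch, `np.linalg.solve` stands in for `torch.linalg.solve_triangular` on the diagonal blocks, and `@` stands in for `nki_gemm` on the trailing update. The panel size of 64 and the function name are arbitrary choices for illustration.

```python
import numpy as np


def blocked_trsm_lower(L, B, panel=64):
    """Solve L @ X = B for lower-triangular L (side='left', no transpose, non-unit).

    Diagonal panels get a small dense solve; the off-diagonal work is a GEMM
    update of the rows below each panel, which dominates for large M.
    """
    X = np.array(B, dtype=float)                 # work on a copy
    m = L.shape[0]
    for k0 in range(0, m, panel):
        k1 = min(k0 + panel, m)
        # Diagonal panel (trnblas uses torch.linalg.solve_triangular here).
        X[k0:k1] = np.linalg.solve(L[k0:k1, k0:k1], X[k0:k1])
        # Trailing update: remove this panel's contribution from later rows.
        X[k1:] -= L[k1:, k0:k1] @ X[k0:k1]
    return X
```

The diagonal solves stay sequential, but all the off-diagonal flops go through a GEMM. This matches where the entry places the work on the Tensor Engine.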
## 0.4.2 — 2026-04-13
- cuBLAS head-to-head infrastructure (#4). New `infra/terraform-cuda/` module provisions a single-A10G `g5.xlarge` CI instance (GA102 Ampere, Apr 2021, vintage-matched to Trainium1, Oct 2022). New `scripts/run_cuda_bench.sh` SSM runner mirrors `run_df_mp2_bench.sh` with trap-stop cleanup.
- `examples/df_mp2.py --device {cpu,cuda}` flag. Inputs are built on CPU with a fixed seed and then moved to the requested device, so GPU energies match CPU up to fp32 reduction-order noise. Added `torch.cuda.synchronize()` before stopping the wall clock so async kernels complete.
- `df_mp2_energy` now respects the input tensor device. The `torch.eye` in the metric inversion step and the scalar energy accumulator previously hardcoded CPU, which broke `--device cuda`.
- `docs/benchmarks.md` — DF-MP2 table replaced with a side-by-side trn1 vs A10G comparison. New headline: A10G is 30–37× faster than the current trn1 torch-matmul path at medium/large, with bit-exact energies. This is the gap to close via NKI kernels in v0.5.0+.
- `docs/aws_setup.md` — new "GPU companion instance" subsection + g5.xlarge cost row.
- `.gitignore` terraform-state rule extended to cover all `infra/terraform*/` dirs (previously scoped to the Trainium module only).
## 0.4.1 — 2026-04-13
- `trnblas.__version__` was stuck at `"0.3.0"` while `pyproject.toml` advanced through `0.3.1`/`0.4.0`. It now tracks the current release (`"0.4.1"`).
- Documentation site stabilised to match the v0.4.0 state:
  - `docs/installation.md` — new `[pyscf]` extra section, `TRNBLAS_REQUIRE_NKI` env var table, updated `neuronxcc >= 2.24` and `torch-neuronx >= 2.9` pins.
  - `docs/api/nki.md` — expanded GEMM section with HBM padding behaviour + measured per-call timings; new sections for `nki_batched_gemm` and `nki_mp2_energy` with perf caveats.
  - `docs/architecture.md` — "Known gaps" refreshed with current Level 3 coverage status and issue cross-references.
  - `docs/benchmarks.md` — placeholder replaced with measured trn1.2xlarge numbers (GEMM per-call, batched GEMM per-slice, DF-MP2 small/medium/large, NEFF cache warmup).
  - `docs/index.md` — pointer to the PySCF real-molecule demo.
## 0.4.0 — 2026-04-12
- Real-molecule DF-MP2 validation against PySCF (#11). New `examples/_pyscf_bridge.py` runs RHF and builds DF integrals; `examples/df_mp2_pyscf.py` is a runnable demo comparing trnblas against PySCF's own `mp.dfmp2.DFMP2` reference. New `tests/test_df_mp2_pyscf.py` (marker: `pyscf`, skipped if PySCF isn't installed) parameterises H2O/STO-3G, H2O/cc-pvdz, CH4/cc-pvdz, NH3/cc-pvdz. Measured agreement on all four: |E_trnblas - E_pyscf| < 10⁻⁷ Hartree (sub-microhartree agreement). New optional extra: `pip install trnblas[pyscf]`.
- `trnblas.nki.nki_mp2_energy` — fused MP2 energy-reduction NKI kernel (#15).
  - Streams T_flat tiles on-chip via partition-dim sub-tiling. P_TILE is picked as the largest divisor of nvir ≤ 128, which covers all bench shapes.
  - Loads a `(P_TILE, nvir)` strip plus its within-block transpose (via `nl.load_transpose2d`) and builds `denom` on-chip.
  - Reduces into a per-(i,j) SBUF accumulator, with a single HBM store per (i,j).
  - Five on-hardware correctness tests (`tests/test_nki_mp2_energy.py`, `@pytest.mark.neuron`) cover `nvir ∈ {8, 16, 64, 256, 448}`; all pass on trn1.
  - Perf status: bit-exact with the torch reference, but it only matches (does not beat) torch at medium on trn1, because the per-(i,j) dispatch/load chain swamps the compute savings.
  - `examples/df_mp2.py` keeps the torch path for now. The NKI dispatch re-wire is deferred to a kernel restructuring pass (batch multiple (i,j) per dispatch).
- `examples/df_mp2.py` refactored to use `trnblas.batched_gemm` for all per-occupied-orbital loops (steps 2b, 3, and 4-per-i). The energy reduction in step 4 is fully vectorised over (j, a, b), eliminating the per-pair `.item()` round-trip.
- `--bench` mode in the example: runs cold + warm passes for three synthetic shapes (small/medium/large) and reports per-step timings and effective TFLOPS.
- `scripts/run_df_mp2_bench.sh` — SSM-driven runner for the bench, parallel to `run_neuron_tests.sh`.
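The P_TILE rule in the `nki_mp2_energy` entry (largest divisor of nvir that is ≤ 128) is simple enough to state as code. `pick_p_tile` is a hypothetical name, and the 128 limit is the SBUF partition count.

```python
def pick_p_tile(nvir: int, max_partitions: int = 128) -> int:
    """Largest divisor of nvir that fits in 128 SBUF partitions."""
    if nvir <= 0:
        raise ValueError("nvir must be positive")
    for p in range(min(nvir, max_partitions), 0, -1):
        if nvir % p == 0:
            return p
    return 1  # unreachable: p = 1 always divides nvir


for nvir in (8, 16, 64, 256, 448):
    print(nvir, pick_p_tile(nvir))
# 8 8
# 16 16
# 64 64
# 256 128
# 448 112
```

Of the tested sizes, only 448 lands below 128. A prime nvir above 128 (for example 131) would fall all the way to P_TILE = 1, which is why the rule is described as covering the bench shapes rather than arbitrary nvir.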
Results after collapsing the energy step (#14). The step was previously 4096 sequential batched dispatches via a Python loop. It is now one chunked GEMM via the algebraic identity `T_full = X @ X.T`, where `X = B.reshape(nocc·nvir, naux)`:

| Shape (nbasis/nocc/naux) | FLOPs | Cold | Warm | TFLOPS | Speedup vs prior |
|---|---|---|---|---|---|
| small (128/16/384) | 3.4 G | 0.025 s | 0.008 s | 0.43 | — |
| medium (512/64/1536) | 2757 G | 12.92 s | 9.77 s | 0.28 | 2.2× |
| large (768/96/2304) | 20352 G | 65.88 s | 62.84 s | 0.32 | (newly feasible) |
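The identity behind the collapsed energy step can be checked numerically. After flattening (occupied, virtual) into one row index, every (i, j) block of `X @ X.T` is the per-pair GEMM `B[i] @ B[j].T`. The sketch forms the product in row chunks, mirroring `T_flat = B_chunk @ B_flat.T`. The chunk size and the name `t_full_chunked` are illustrative.

```python
import numpy as np


def t_full_chunked(B, rows_per_chunk):
    """T_full = X @ X.T with X = B.reshape(nocc*nvir, naux), built in row chunks."""
    X = B.reshape(-1, B.shape[-1])
    chunks = [X[r:r + rows_per_chunk] @ X.T
              for r in range(0, X.shape[0], rows_per_chunk)]
    return np.vstack(chunks)


rng = np.random.default_rng(0)
nocc, nvir, naux = 3, 5, 7
B = rng.standard_normal((nocc, nvir, naux))
T_full = t_full_chunked(B, rows_per_chunk=4)

# Viewed as T4[i, a, j, b], each (i, j) block is the per-pair GEMM B[i] @ B[j].T.
T4 = T_full.reshape(nocc, nvir, nocc, nvir)
print(all(np.allclose(T4[i, :, j, :], B[i] @ B[j].T)
          for i in range(nocc) for j in range(nocc)))   # True
```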
Energy reproducible bit-for-bit across runs:
- small: -1.619250e-04
- medium: -2.487221
- large: -4.351183e+01
The energy step still dominates large's wall (57s of 63s = 92%) — it's memory-bandwidth bound on the huge T tensor + intermediates. Fusing it into a custom NKI kernel would be the next optimisation; tracked as a future v0.4 follow-up to #14.
- Terraform module (`infra/terraform/`) provisioning a Trainium CI instance with SSM access. The instance is kept stopped between runs (~$10/mo, EBS only).
- `scripts/run_neuron_tests.sh` — local SSM-driven runner for `pytest -m neuron`; starts the instance, runs the tests, and always stops it via trap.
- AWS setup docs (`docs/aws_setup.md`) covering provisioning, running tests, cost, and troubleshooting.
- `.github/workflows/neuron.yml` `workflow_dispatch` scaffold. Per the trnfft pattern, GitHub Actions does not touch AWS; all Neuron testing is human-initiated locally with `AWS_PROFILE=aws`.
- NKI GEMM kernel (`trnblas/nki/dispatch.py:_gemm_kernel`) wired to actual `nisa.nc_matmul` calls, with PSUM accumulation across K-tiles and stationary A-tile reuse. This supersedes the previous stub, which overwrote the output on every K-tile.
- Dispatch wrapper now handles arbitrary shapes via HBM padding: M/K rounded up to 128, N rounded up to 512 (when N > 512). The kernel uses `TILE_N = min(N, 512)` for single-tile small-N. The result is sliced back to the original (M, N). This removes the alignment-rejection fallback path.
- `TRNBLAS_REQUIRE_NKI=1` env var added. It re-raises on kernel exceptions instead of silently falling back to `torch.matmul`, which lets the validation suite surface kernel breakage.
- `trnblas.batched_gemm` dispatches per-slice through the cached 2D `_gemm_kernel` via the new `nki_batched_gemm` wrapper. Every slice after the first hits the NEFF cache (identical signature), so per-slice cost is only HBM transfer + Tensor Engine dispatch. This is the natural batched dispatch shape for DF-MP2 contractions over auxiliary basis indices.
- Bumped the `neuronxcc` floor from `>=2.15` to `>=2.24` to unify with the rest of the trnsci suite (matches trnfft / trnrand). The `torch-neuronx` floor is bumped to `>=2.9` to match.
- Repository transferred from `scttfrdmn/trnblas` to the `trnsci` GitHub organisation (`trnsci/trnblas`). Documentation is now served at https://trnsci.dev/trnblas/. Canonical `CONTRIBUTING.md` and `CODE_OF_CONDUCT.md` adopted to match the trnsci suite.
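The dispatch wrapper's padding rule (M/K up to a multiple of 128, N up to a multiple of 512 only when N > 512, `TILE_N = min(N, 512)`) can be written out as shape arithmetic. The function names are hypothetical. The wrapper would then slice the padded result back with `C_pad[:M, :N]`.

```python
def round_up(x: int, multiple: int) -> int:
    return -(-x // multiple) * multiple


def padded_gemm_shape(M: int, K: int, N: int):
    """(M_pad, K_pad, N_pad, TILE_N) under the padding rule described above."""
    M_pad = round_up(M, 128)
    K_pad = round_up(K, 128)
    N_pad = round_up(N, 512) if N > 512 else N   # small N stays a single tile
    tile_n = min(N, 512)
    return M_pad, K_pad, N_pad, tile_n


print(padded_gemm_shape(1000, 300, 700))   # (1024, 384, 1024, 512)
print(padded_gemm_shape(256, 128, 96))     # (256, 128, 96, 96)
```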
17/17 `pytest -m neuron` tests pass. Cached-NEFF speedup measured by
running the suite twice on the same instance:
| Pass | Wall time |
|---|---|
| Cold (first run after instance start) | 7.01s |
| Warm (NEFF cache hit + warm XLA graph) | 2.52s (2.8× faster) |
Per-call kernel timing (warm cache, mean of 5):
| Shape (M×K×N) | Per-call |
|---|---|
| 512×512×512 | 1.6 ms |
| 1024×1024×1024 | 4.5 ms |
Batched dispatch (warm, batch=32 of 256×128×256):
| Metric | Value |
|---|---|
| Total | 39.3 ms |
| Per-slice | 1.23 ms |
The NEFF cache at `/var/tmp/neuron-compile-cache/` persists across instance
stop/start (EBS-backed), so kernel compile cost is paid exactly once per
shape per cache lifetime.
## 0.2.0 - 2026-04-11
- MkDocs Material documentation site at scttfrdmn.github.io/trnblas with Installation, Quickstart, API reference (Level 1/2/3, NKI backend), and Architecture pages.
- GitHub Actions CI matrix (Python 3.10, 3.11, 3.12).
- Neuron hardware CI workflow scaffold (`workflow_dispatch`). SSM wiring is deferred until a persistent CI Trainium instance is available.
- PyPI publishing workflow (OIDC trusted publishers, sdist + wheel on release), matching the trnfft pattern.
- Benchmark suite scaffold (`benchmarks/bench_blas.py`, pytest-benchmark).
- Issue and PR templates under `.github/`.
- README badges: CI status, PyPI version, Python versions, License, Docs.
- Cross-link to trnblas in trnfft's Related Projects table (`scttfrdmn/trnfft@7330b3f`).
## 0.1.0 - 2026-04-11
- Level 1 BLAS: `axpy`, `dot`, `nrm2`, `scal`, `asum`, `iamax`.
- Level 2 BLAS: `gemv`, `symv`, `trmv`, `ger`.
- Level 3 BLAS: `gemm`, `batched_gemm`, `symm`, `syrk`, `trsm`, `trmm`.
- NKI dispatch layer with `auto`, `pytorch`, and `nki` backend selection.
- NKI GEMM kernel stub with a stationary tile reuse strategy (scaffolded for on-hardware validation on trn1/trn2).
- DF-MP2 example (`examples/df_mp2.py`) demonstrating the quantum-chemistry use case with half-transform GEMMs, Cholesky, triangular solve, and energy evaluation.
- Test suite covering Level 1/2/3 BLAS correctness against PyTorch/NumPy references, with SPD matrix fixtures for symmetric/triangular routines.