Releases: trnsci/trnblas
v0.5.4 — chunked batched-pair dispatch + PySCF FP32 precision gate
Chunked batched-pair dispatch (#46). Extends nki_batched_pair_energy to medium/large shapes. One @nki.jit call per i-row (64 calls for nocc=64), each processing all nocc j-pairs. Per-chunk XLA graph: ~1.4 GB (vs 18 GB full-batch). Same NEFF reused across all i-calls.
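Chunking changes how the pair loop is split across calls, not the energy it produces. Below is a minimal NumPy sketch of that equivalence. It assumes the standard closed-shell DF-MP2 pair-energy expression and a DF tensor `B` of shape `(nocc, nvir, naux)`. `row_energy`, `chunked_energy`, and `full_batch_energy` are hypothetical names, and one `row_energy` call stands in for one `@nki.jit` call.

```python
import numpy as np


def row_energy(B, eps_occ, eps_vir, i):
    """Energy of all (i, j) pairs for one fixed i: the work of one chunk."""
    # (ia|jb) = sum_P B[i,a,P] * B[j,b,P], laid out as [j, a, b]
    iajb = np.einsum("aP,jbP->jab", B[i], B)
    # Exchange integral (ib|ja) is the same tensor with a and b swapped
    ibja = iajb.transpose(0, 2, 1)
    denom = (eps_occ[i] + eps_occ[:, None, None]
             - eps_vir[None, :, None] - eps_vir[None, None, :])
    return float(np.sum(iajb * (2.0 * iajb - ibja) / denom))


def chunked_energy(B, eps_occ, eps_vir):
    """One call per i-row; each call covers all nocc j-pairs."""
    return sum(row_energy(B, eps_occ, eps_vir, i) for i in range(B.shape[0]))


def full_batch_energy(B, eps_occ, eps_vir):
    """All nocc**2 pairs in a single call (the layout whose graph grew too large)."""
    iajb = np.einsum("iaP,jbP->ijab", B, B)
    ibja = iajb.transpose(0, 1, 3, 2)
    denom = (eps_occ[:, None, None, None] + eps_occ[None, :, None, None]
             - eps_vir[None, None, :, None] - eps_vir[None, None, None, :])
    return float(np.sum(iajb * (2.0 * iajb - ibja) / denom))


# Toy problem: negative occupied and positive virtual orbital energies
rng = np.random.default_rng(0)
nocc, nvir, naux = 4, 6, 10
B = rng.standard_normal((nocc, nvir, naux))
eps_occ = -1.0 - rng.random(nocc)
eps_vir = 0.5 + rng.random(nvir)
e_chunked = chunked_energy(B, eps_occ, eps_vir)
e_full = full_batch_energy(B, eps_occ, eps_vir)
```

In this NumPy version, each chunk materializes an `nocc × nvir × nvir` intermediate instead of `nocc² × nvir × nvir`. The XLA graph sizes quoted above shrink by less than that factor (~1.4 GB vs 18 GB), so read the sketch as an illustration of the split, not as a size model.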
Medium-shape results (trn1.2xlarge, nbasis=512, nocc=64, nvir=448, naux=1536):
- Warm energy: 1.536 s (5.2× faster than torch, 3.4× faster than v0.5.2 CPU fallback)
- Cold energy: 34 min (77 NEFF compilations, paid once per instance lifetime)
PySCF FP32 precision envelope (#20) — CLOSED. All 8 hardware test cases pass below 1 µHartree. Key gate values: glycine/cc-pVDZ = 3.51×10⁻⁷ Ha, H₂O/cc-pVTZ = 1.99×10⁻⁷ Ha. Double-double emulation (#10, #22) deferred indefinitely — FP32 is sufficient for DF-MP2 at target molecule/basis combinations.
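The gate is an absolute-error check against the PySCF reference. The minimal sketch below assumes "below 1 µHartree" means |E_trnblas − E_PySCF| < 1×10⁻⁶ Ha. `precision_gate` is a hypothetical helper, not part of trnblas.

```python
MICRO_HARTREE = 1e-6  # gate threshold in Hartree


def precision_gate(e_fp32, e_ref, tol=MICRO_HARTREE):
    """Return (|error| in Ha, pass flag, headroom factor below the gate)."""
    err = abs(e_fp32 - e_ref)
    headroom = tol / err if err > 0 else float("inf")
    return err, err < tol, headroom


# Reported gate values from this release, expressed as errors against a zero reference
for name, delta in [("glycine/cc-pVDZ", 3.51e-7), ("H2O/cc-pVTZ", 1.99e-7)]:
    err, ok, room = precision_gate(delta, 0.0)
    print(f"{name}: {err * 1e6:.3f} uHa, pass={ok}, {room:.1f}x under gate")
```

With the reported values, glycine/cc-pVDZ sits roughly 2.8× under the gate and H₂O/cc-pVTZ roughly 5× under it.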
See CHANGELOG.md for full details.
v0.5.3 — medium-shape investigation + TMPDIR fix
Investigation: medium-shape batched-pair NEFF compile fails (18 GB XLA graph). nl.affine_range traces all loop iterations eagerly at compile time. At nocc=64 / 4096 pairs, the XLA JSON reaches 18 GB; the trn1 root EBS volume had 16 GB free. TMPDIR fix added as hygiene (does not unblock compilation — both /tmp and /var/tmp share the same filesystem).
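Redirecting TMPDIR only buys space if the new directory sits on a different filesystem with enough free room. Below is a standard-library sketch of that check. The helper names are hypothetical, and the ~18 GB target comes from the investigation above.

```python
import os
import shutil
import tempfile


def same_filesystem(path_a, path_b):
    """True if both paths are on the same device, i.e. they share free space."""
    return os.stat(path_a).st_dev == os.stat(path_b).st_dev


def free_gib(path):
    """Free space at `path` in GiB."""
    return shutil.disk_usage(path).free / 2**30


def pick_tmpdir(candidates, need_gib):
    """First existing directory with at least `need_gib` GiB free, else None."""
    for path in candidates:
        if os.path.isdir(path) and free_gib(path) >= need_gib:
            return path
    return None


def redirect_tmpdir(path):
    """Point TMPDIR (and Python's tempfile module) at `path` for child processes."""
    os.environ["TMPDIR"] = path
    tempfile.tempdir = None  # force tempfile to re-read TMPDIR on next use
    return tempfile.gettempdir()
```

Consider the case where `pick_tmpdir(["/var/tmp", "/tmp"], need_gib=18)` returns `None` and `same_filesystem("/tmp", "/var/tmp")` is `True`. That is the situation described above: moving TMPDIR cannot create space the shared volume does not have.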
Fix: chunked dispatch (one @nki.jit per i-row, processing all nocc j-pairs). Tracked in #46; ships in v0.5.4.
See CHANGELOG.md for full details.
v0.5.2 — batched-pair energy kernel
Batched-pair energy kernel (#43, nki_batched_pair_energy). A single @nki.jit dispatch computes the full DF-MP2 pair energy for all nocc² orbital pairs, eliminating the ~100 ms × nocc² Neuron XLA per-dispatch overhead that made nki_fused_gemm_energy impractical (215× slower than chunk-GEMM at nocc=16).
Hardware-validated on trn1.2xlarge:
- Small shape (nocc=16, 256 pairs): warm energy 0.005 s, 3.6× faster than torch baseline
- Spike B (nocc=4, 16 pairs): 1.9 ms vs 25.4 ms per-pair loop — 13.5× speedup
Medium shape (nocc=64, 4096 pairs): NEFF compile blocked (18 GB XLA graph exceeds disk). Fix in v0.5.4.
See CHANGELOG.md for full details.
v0.5.1 — fused GEMM+energy kernel (per-pair)
Fused GEMM+energy kernel (#38, nki_fused_gemm_energy). One @nki.jit kernel handles one DF-MP2 orbital pair — both GEMMs (T and T_T) and the Vector Engine energy expression — without writing the (nvir, nvir) T_flat intermediate to HBM.
Note: hardware benchmarks revealed this kernel is 215× slower than the chunk-GEMM baseline at nocc=16 due to ~100 ms Neuron XLA per-dispatch overhead. The fix (single-dispatch batched kernel) ships in v0.5.2.
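Per pair, the computation reduces to two small GEMMs and an elementwise reduction. The NumPy sketch below assumes the closed-shell DF-MP2 expression, a DF tensor `B` of shape `(nocc, nvir, naux)`, and that `T_T` is the GEMM producing the transpose of `T`. `pair_energy` is a hypothetical name. NumPy still materializes `T`; avoiding that HBM round trip is what the fused kernel adds on hardware.

```python
import numpy as np


def pair_energy(B, eps_occ, eps_vir, i, j):
    """DF-MP2 energy of a single (i, j) orbital pair.

    B has shape (nocc, nvir, naux). T[a, b] = (ia|jb) comes from one GEMM;
    T_T[a, b] = (ib|ja) comes from a second GEMM rather than a transpose.
    """
    T = B[i] @ B[j].T      # (nvir, nvir): (ia|jb)
    T_T = B[j] @ B[i].T    # (nvir, nvir): (ja|ib), equal to T transposed
    denom = eps_occ[i] + eps_occ[j] - eps_vir[:, None] - eps_vir[None, :]
    # Elementwise energy expression, reduced straight to a scalar
    return float(np.sum(T * (2.0 * T - T_T) / denom))
```

Summing `pair_energy` over all `nocc²` pairs gives the total energy, which is also why this layout needs `nocc²` separate dispatches.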
See CHANGELOG.md for full details.
v0.5.0 — GEMM tile-shape autotuner
GEMM tile-shape autotuner (#26). _nki_gemm_impl sweeps six tile candidates {64,128} × {128} × {128,256,512} on first call per shape bucket and caches the winner to disk (/var/tmp/trnblas-autotune/cache.json). Subsequent calls hit the in-process cache first, then the NEFF cache.
- `_make_gemm_kernel(tile_m, tile_k, tile_n)`: factory returning a separately cached `@nki.jit` closure per tile config
- Opt out with `TRNBLAS_AUTOTUNE=0` to restore the v0.4.x fixed `(128, 128, 512)` behaviour
- Backward compatible: `_gemm_kernel` alias preserved
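The sketch below shows the sweep-and-cache flow in pure Python, with NumPy stubs in place of the real `@nki.jit` kernels. `TileAutotuner`, `make_gemm_kernel`, and the power-of-two shape bucketing are all hypothetical. The sketch writes to a temporary cache file instead of `/var/tmp/trnblas-autotune/cache.json` and does not model the NEFF-cache layer.

```python
import itertools
import json
import os
import tempfile
import time
from functools import lru_cache

import numpy as np

# Six candidates: {64,128} x {128} x {128,256,512}
TILE_CANDIDATES = list(itertools.product((64, 128), (128,), (128, 256, 512)))
DEFAULT_TILE = (128, 128, 512)  # v0.4.x fixed configuration


@lru_cache(maxsize=None)
def make_gemm_kernel(tile_m, tile_k, tile_n):
    """Hypothetical stand-in for a kernel factory: one cached closure per tile config."""
    def kernel(a, b):
        # A real kernel would tile by (tile_m, tile_k, tile_n); this stub just multiplies
        return a @ b
    return kernel


class TileAutotuner:
    def __init__(self, cache_path):
        self.cache_path = cache_path
        self.memory = {}  # in-process cache, consulted first
        if os.path.exists(cache_path):
            with open(cache_path) as f:
                self.memory.update(json.load(f))

    @staticmethod
    def _bucket(m, k, n):
        # Hypothetical bucketing: round each dimension up to a power of two
        return "x".join(str(1 << (d - 1).bit_length()) for d in (m, k, n))

    def tile_for(self, a, b):
        if os.environ.get("TRNBLAS_AUTOTUNE") == "0":
            return DEFAULT_TILE
        (m, k), n = a.shape, b.shape[1]
        key = self._bucket(m, k, n)
        if key not in self.memory:  # first call for this bucket: sweep and time
            timings = {}
            for tile in TILE_CANDIDATES:
                kernel = make_gemm_kernel(*tile)
                start = time.perf_counter()
                kernel(a, b)
                timings[tile] = time.perf_counter() - start
            self.memory[key] = list(min(timings, key=timings.get))
            self._save()
        return tuple(self.memory[key])

    def _save(self):
        os.makedirs(os.path.dirname(self.cache_path) or ".", exist_ok=True)
        with open(self.cache_path, "w") as f:
            json.dump(self.memory, f)


# Usage: tune once for a shape bucket, then reuse the cached winner
cache_file = os.path.join(tempfile.mkdtemp(), "cache.json")
a, b = np.ones((100, 128)), np.ones((128, 256))
winner = TileAutotuner(cache_file).tile_for(a, b)
reloaded = TileAutotuner(cache_file).tile_for(a, b)  # served from disk, no re-sweep
```

Caching each closure per tile config means the sweep pays for each candidate's compile once. After that, a second process reads the winner straight from the JSON file.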
See CHANGELOG.md for full details.
v0.4.3 — correct silent NKI fallback across v0.4.x
Correction release. Every "trn1 NKI" perf number published in v0.4.0 / v0.4.1 / v0.4.2 was actually trn1's 8-vCPU Xeon running torch.matmul, not the Trainium Tensor Engine.
What went wrong
Our SSM runners launched the Neuron venv's python directly without prepending its bin/ to $PATH. torch_neuronx.initializer calls subprocess.run(["libneuronpjrt-path"]) to locate the PJRT plugin library — that binary lives in the venv's bin/ and was unresolvable. Every NKI dispatch raised FileNotFoundError, and our _nki_*_impl try/except wrappers swallowed the exception and fell back to torch.matmul.
Correctness tests kept passing because torch.matmul gives the same answer as nki_gemm; only perf attribution was wrong. The v0.4.2 cross-vendor comparison was A10G's Ampere GPU vs trn1's Xeon, not vs trn1's Tensor Engine.
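A cheap guard is to check, before any NKI dispatch, that the PJRT locator binary actually resolves on the PATH the process will see. Below is a standard-library sketch. `env_with_venv_bin` and `pjrt_locator_resolvable` are hypothetical helpers, and the only detail taken from the text above is that `libneuronpjrt-path` lives in the venv's `bin/`.

```python
import os
import shutil

PJRT_LOCATOR = "libneuronpjrt-path"  # binary that torch_neuronx shells out to


def env_with_venv_bin(venv, base_env=None):
    """Copy of the environment with <venv>/bin prepended to PATH."""
    env = dict(os.environ if base_env is None else base_env)
    venv_bin = os.path.join(venv, "bin")
    env["PATH"] = venv_bin + os.pathsep + env.get("PATH", "")
    return env


def pjrt_locator_resolvable(env):
    """True if the PJRT locator binary can be found on env['PATH']."""
    return shutil.which(PJRT_LOCATOR, path=env.get("PATH", "")) is not None
```

For example, a runner could pass `env_with_venv_bin(os.environ["NEURON_VENV"])` as the `env=` argument to `subprocess.run`. It could also abort early if `pjrt_locator_resolvable` returns `False`.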
What's fixed
- `scripts/run_neuron_tests.sh`, `scripts/run_df_mp2_bench.sh`: prepend `$NEURON_VENV/bin` to `$PATH`, and set `TRNBLAS_REQUIRE_NKI=1` in the test runner.
- `trnblas.nki.NkiFallbackWarning`: emitted once per distinct error when the fallback triggers, so future misconfigurations are visible.
- `tests/test_nki_really_runs.py`: anti-regression test that forces `TRNBLAS_REQUIRE_NKI=1` and asserts a GEMM dispatch completes.
- Re-measured trn1 numbers on docs/benchmarks.md with a retraction banner.
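Here is the general shape of a fallback that is no longer silent, sketched in plain Python. It shows one way such a wrapper can be structured and is not trnblas's code. `dispatch_with_fallback` is hypothetical, and `NkiFallbackWarning` below is a local stand-in for `trnblas.nki.NkiFallbackWarning`.

```python
import os
import warnings


class NkiFallbackWarning(RuntimeWarning):
    """Local stand-in for trnblas.nki.NkiFallbackWarning."""


_seen_errors = set()


def dispatch_with_fallback(nki_fn, fallback_fn, *args):
    """Try the NKI path; fall back visibly, or re-raise when TRNBLAS_REQUIRE_NKI=1."""
    try:
        return nki_fn(*args)
    except Exception as exc:
        if os.environ.get("TRNBLAS_REQUIRE_NKI") == "1":
            raise  # in CI, a CPU fallback should fail the run
        key = (type(exc).__name__, str(exc))
        if key not in _seen_errors:  # warn once per distinct error
            _seen_errors.add(key)
            warnings.warn(f"NKI dispatch failed ({key[0]}: {key[1]}); "
                          "falling back to torch", NkiFallbackWarning, stacklevel=2)
        return fallback_fn(*args)
```

Correctness tests still pass through the fallback, which is why the `TRNBLAS_REQUIRE_NKI=1` branch matters. It turns a mis-attributed benchmark into a hard failure.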
Side finding
trnblas.nki.nki_mp2_energy kernel tests had a partition-limit bug that was masked by the silent fallback (nl.load(eps_vir[0:NVIR]) exceeds 128 partitions for nvir > 128). Tests skipped pending kernel rewrite under #15. Not in the production DF-MP2 path (examples/df_mp2.py uses the torch reduction).
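Avoiding the limit means sub-tiling the partition dimension so no single load spans more than 128 partitions. The plain-Python sketch below covers only the index arithmetic. `partition_tiles` and `tiled_sum` are hypothetical, and on hardware each range would correspond to one bounded `nl.load`.

```python
PMAX = 128  # partition limit cited in the bug above


def partition_tiles(n, pmax=PMAX):
    """Split [0, n) into consecutive (start, stop) ranges of at most pmax rows."""
    return [(start, min(start + pmax, n)) for start in range(0, n, pmax)]


def tiled_sum(values, pmax=PMAX):
    """Reduce a 1-D sequence tile by tile, never touching more than pmax rows at once."""
    total = 0.0
    for start, stop in partition_tiles(len(values), pmax):
        total += sum(values[start:stop])  # stands in for one bounded load + reduction
    return total
```

For `nvir = 448`, `partition_tiles(448)` yields four ranges: `(0, 128)`, `(128, 256)`, `(256, 384)`, `(384, 448)`.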
Re-measured (trn1.2xlarge, real NKI, warm cache)
| Op | Shape | v0.4.x (was CPU) | v0.4.3 (real NKI) |
|---|---|---|---|
| GEMM | 1024³ | 4.5 ms | 2.3 ms |
| SYRK | 1024² | 7.91 ms | 5.71 ms |
| TRSM | 2048×512 | 27.75 ms | 35.82 ms |
| DF-MP2 medium warm | — | 9.77 s | 9.91 s |
Relative A10G vs trn1 ratios land in the same 19–45× range; cross-vendor story is unchanged, attribution is correct.
See the CHANGELOG.
v0.4.2 — cuBLAS head-to-head on A10G
Patch release. First cross-vendor DF-MP2 numbers published — closes #4.
Highlights
| Shape | trn1 warm | A10G warm | A10G vs trn1 |
|---|---|---|---|
| small (128/16/384) | 0.008s | 0.001s | 8× |
| medium (512/64/1536) | 9.77s | 0.266s | 37× |
| large (768/96/2304) | 62.8s | 2.018s | 31× |
Energies bit-exact across platforms (fp32 reduction-order noise in the last ULP for medium/large).
What landed
- `infra/terraform-cuda/`: provisions a single-A10G `g5.xlarge` CI instance vintage-matched to trn1 (GA102 Ampere 2021 vs Trainium1 2022).
- `scripts/run_cuda_bench.sh`: SSM runner mirroring the Trainium one.
- `examples/df_mp2.py --device cuda`: moves inputs to GPU HBM; the existing CPU path is unchanged.
- Honest cross-vendor table in docs/benchmarks.md with the vintage-parity rationale and the "close the gap" target for v0.5.0+ NKI kernel work.
What this tells us
Raw cuBLAS on 2021-vintage Ampere is ~30× faster than the current trnblas path on trn1 at medium/large DF-MP2 shapes. Trainium's Tensor Engine is being under-utilized in the current pipeline — closing the gap is exactly what #15, #18 (syrk), and #19 (trsm) are for in v0.5.0.
See the full CHANGELOG.
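Correction: the trn1 numbers published in v0.4.0 through v0.4.2, including the table above, were measured on trn1's Xeon CPU through a silent fallback to torch.matmul. Fixed in v0.4.3. Re-measured numbers are on the benchmarks page.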
v0.4.1 — documentation stabilization
Patch release. Documentation drift against the v0.4.0 code fixed; no functional changes.
- `trnblas.__version__` was stuck at `0.3.0`; it now tracks the current release.
- Installation docs: new `[pyscf]` extra section, `TRNBLAS_REQUIRE_NKI` env var table, updated `neuronxcc >= 2.24` pin.
- NKI API docs: GEMM HBM padding + measured per-call timings; new `nki_batched_gemm` and `nki_mp2_energy` sections.
- Architecture page: current Level 3 coverage state with issue cross-references.
- Benchmarks page: TBD placeholders replaced with measured trn1.2xlarge numbers.
- Index: PySCF real-molecule demo pointer.
See CHANGELOG.
v0.4.0 — DF-MP2 end-to-end validation + NKI energy kernel
Highlights
- Real-molecule DF-MP2 validation against PySCF (#11): trnblas matches PySCF's own `mp.dfmp2.DFMP2` reference to nanohartree precision on H2O/STO-3G, H2O/cc-pvdz, CH4/cc-pvdz, NH3/cc-pvdz. New `pip install trnblas[pyscf]` extra and runnable `examples/df_mp2_pyscf.py` demo.
- Fused MP2 energy-reduction NKI kernel (#15, Phase 1): `trnblas.nki.nki_mp2_energy` with partition-dim sub-tiling. Validated on trn1 across `nvir ∈ {8, 16, 64, 256, 448}`. Scaffold landed; further perf work tracked under #15.
- DF-MP2 step-4 collapse (#14): the energy reduction's `nocc²` sequential batched dispatches are replaced by one chunked GEMM via the algebraic identity `T_full = X @ X.T`. On `trn1.2xlarge`:

  | Shape | Flops | Cold | Warm | TFLOPS |
  |---|---|---|---|---|
  | small (128/16/384) | 3.4 G | 0.025s | 0.008s | 0.43 |
  | medium (512/64/1536) | 2757 G | 12.9s | 9.77s | 0.28 |
  | large (768/96/2304) | 20352 G | 65.9s | 62.8s | 0.32 |

- Trainium CI infrastructure: Terraform module for a persistent trn1 test instance, SSM-driven runners (`scripts/run_neuron_tests.sh`, `scripts/run_df_mp2_bench.sh`), docs at `docs/aws_setup.md`.
- NKI GEMM kernel wired to real `nisa.nc_matmul` with stationary tile reuse + HBM padding for arbitrary shapes. `nki_batched_gemm` dispatches per-slice through the cached kernel. 17/17 hardware tests pass.
- Repository transfer: now at `trnsci/trnblas`. Docs at https://trnsci.dev/trnblas/.
- `neuronxcc` floor bumped `>=2.15` → `>=2.24` (NKI 2.24+ `nc_matmul` calling convention) to unify with the rest of the trnsci suite.
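The step-4 collapse rests on one identity. Flattening the DF tensor to `X`, with rows indexed by `(i, a)`, lets a single GEMM produce every `(ia|jb)` at once. In the NumPy sketch below, the names are illustrative. It assumes "chunked" means blocking over rows of `X`; the actual chunking in `examples/df_mp2.py` may differ.

```python
import numpy as np

rng = np.random.default_rng(1)
nocc, nvir, naux = 3, 5, 8
B = rng.standard_normal((nocc, nvir, naux))   # DF tensor (ia|P)

# Step-4 collapse: flatten (i, a) into rows, then one GEMM yields every (ia|jb)
X = B.reshape(nocc * nvir, naux)
T_full = (X @ X.T).reshape(nocc, nvir, nocc, nvir)   # indexed [i, a, j, b]

# Reference: the nocc**2 per-pair products that the collapse replaces
T_pairs = np.empty_like(T_full)
for i in range(nocc):
    for j in range(nocc):
        T_pairs[i, :, j, :] = B[i] @ B[j].T


def chunked_gram(X, rows_per_chunk):
    """X @ X.T computed in row blocks, bounding the size of each partial product."""
    out = np.empty((X.shape[0], X.shape[0]), dtype=X.dtype)
    for r0 in range(0, X.shape[0], rows_per_chunk):
        out[r0:r0 + rows_per_chunk] = X[r0:r0 + rows_per_chunk] @ X.T
    return out
```

The GEMM form trades `nocc²` small dispatches for one large, well-shaped matrix multiply, which is the kind of work the Tensor Engine is built for.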
See the full CHANGELOG for details.
v0.3.0
See CHANGELOG.md for full details.
Highlights
NKI on Trainium — validated end-to-end.
- Real NKI GEMM kernel (`trnblas/nki/dispatch.py:_gemm_kernel`) with PSUM accumulation across K-tiles and stationary A-tile reuse via `nisa.nc_matmul`. NKI 2.24 calling convention.
- Arbitrary-shape support via an HBM padding wrapper (M/K → 128, N → 512), with `TILE_N = min(N, 512)` for the small-N single-tile case.
- Batched GEMM dispatch (`trnblas.batched_gemm`) loops over the batch dim through the cached 2D kernel, the natural shape for DF-MP2 contractions over auxiliary basis indices.
- Local-flow Trainium CI: Terraform (`infra/terraform/`) provisions a stopped trn1 instance; `scripts/run_neuron_tests.sh` starts it, runs `pytest -m neuron` via SSM, then trap-stops it. The `--warm` flag runs twice to expose NEFF cache deltas.
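Below is a NumPy sketch of the padding and K-tile accumulation scheme described above. It assumes M and K are padded to multiples of 128 and N to a multiple of `TILE_N = min(N, 512)`. The accumulator plays the role of PSUM, and stationary-tile reuse is not modeled. `padded_tiled_gemm` and `batched_gemm` are hypothetical stand-ins, not the `trnblas` API.

```python
import numpy as np

TILE_M, TILE_K, TILE_N_MAX = 128, 128, 512


def round_up(x, multiple):
    """Smallest multiple of `multiple` that is >= x."""
    return -(-x // multiple) * multiple


def padded_tiled_gemm(a, b):
    """C = A @ B via zero padding plus K-tile accumulation (float32, NumPy only)."""
    m, k = a.shape
    k2, n = b.shape
    assert k == k2, "inner dimensions must match"
    tile_n = min(n, TILE_N_MAX)  # small N: a single N tile, no N padding
    mp, kp, npad = round_up(m, TILE_M), round_up(k, TILE_K), round_up(n, tile_n)
    a_pad = np.zeros((mp, kp), dtype=np.float32)
    a_pad[:m, :k] = a
    b_pad = np.zeros((kp, npad), dtype=np.float32)
    b_pad[:k, :n] = b
    c_pad = np.empty((mp, npad), dtype=np.float32)
    for m0 in range(0, mp, TILE_M):
        for n0 in range(0, npad, tile_n):
            acc = np.zeros((TILE_M, tile_n), dtype=np.float32)  # PSUM-like accumulator
            for k0 in range(0, kp, TILE_K):
                acc += a_pad[m0:m0 + TILE_M, k0:k0 + TILE_K] @ b_pad[k0:k0 + TILE_K, n0:n0 + tile_n]
            c_pad[m0:m0 + TILE_M, n0:n0 + tile_n] = acc
    return c_pad[:m, :n]  # strip the padding before returning


def batched_gemm(a_batch, b_batch):
    """Loop the 2-D routine over the leading batch dimension."""
    return np.stack([padded_tiled_gemm(a, b) for a, b in zip(a_batch, b_batch)])
```

Zero padding leaves the product unchanged: the extra rows and columns contribute only zeros, and slicing `[:m, :n]` discards them. Any shape can therefore reuse one fixed-tile kernel.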
Performance (trn1.2xlarge, neuronxcc 2.24.5133)
17/17 pytest -m neuron tests pass.
| Pass | Wall time |
|---|---|
| Cold | 7.01s |
| Warm (NEFF cache) | 2.52s (2.8× faster) |
| Operation | Per-call (warm) |
|---|---|
| GEMM 512³ | 1.6 ms |
| GEMM 1024³ | 4.5 ms |
| Batched 32 × 256×128×256 | 1.23 ms / slice |
NEFF cache survives instance stop/start (EBS-backed) — kernel compile cost is paid exactly once per shape.