Releases: trnsci/trnblas

v0.5.4 — chunked batched-pair dispatch + PySCF FP32 precision gate

21 Apr 20:04

Chunked batched-pair dispatch (#46). Extends nki_batched_pair_energy to medium/large shapes. One @nki.jit call per i-row (64 calls for nocc=64), each processing all nocc j-pairs. Per-chunk XLA graph: ~1.4 GB (vs 18 GB full-batch). Same NEFF reused across all i-calls.
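
The chunking scheme can be illustrated without NKI. The sketch below is a pure-Python stand-in, not the trnblas kernel: `row_pair_energy` is a hypothetical placeholder for one per-row dispatch, and the energy expression is the textbook closed-shell DF-MP2 form, assumed here rather than taken from the trnblas source. It only shows that summing nocc per-row calls reproduces the full-batch result.

```python
import random

def pair_energy(B, eps_occ, eps_vir, i, j):
    """Closed-shell DF-MP2 energy contribution of one occupied pair (i, j)."""
    nvir, naux = len(eps_vir), len(B[0][0])
    # T[a][b] = sum_P B[i][a][P] * B[j][b][P]
    T = [[sum(B[i][a][p] * B[j][b][p] for p in range(naux))
          for b in range(nvir)] for a in range(nvir)]
    e = 0.0
    for a in range(nvir):
        for b in range(nvir):
            denom = eps_occ[i] + eps_occ[j] - eps_vir[a] - eps_vir[b]
            e += T[a][b] * (2.0 * T[a][b] - T[b][a]) / denom
    return e

def row_pair_energy(B, eps_occ, eps_vir, i):
    """Hypothetical stand-in for one per-row dispatch: all j-pairs for fixed i."""
    return sum(pair_energy(B, eps_occ, eps_vir, i, j) for j in range(len(eps_occ)))

def chunked_energy(B, eps_occ, eps_vir):
    """One call per i-row: nocc calls in total, each covering nocc pairs."""
    return sum(row_pair_energy(B, eps_occ, eps_vir, i) for i in range(len(eps_occ)))

def full_batch_energy(B, eps_occ, eps_vir):
    """Reference: every (i, j) pair handled in a single pass."""
    nocc = len(eps_occ)
    return sum(pair_energy(B, eps_occ, eps_vir, i, j)
               for i in range(nocc) for j in range(nocc))

# Tiny random problem, far smaller than the shapes quoted in the release.
random.seed(0)
nocc, nvir, naux = 3, 4, 5
B = [[[random.uniform(-1, 1) for _ in range(naux)]
      for _ in range(nvir)] for _ in range(nocc)]
eps_occ = [-1.0 - 0.1 * i for i in range(nocc)]
eps_vir = [0.5 + 0.1 * a for a in range(nvir)]
```

Because every row call has the same shape, a real per-row kernel can be compiled once and reused, which is the property the release relies on for NEFF reuse.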

Medium-shape results (trn1.2xlarge, nbasis=512, nocc=64, nvir=448, naux=1536):

  • Warm energy: 1.536 s (5.2× faster than torch, 3.4× faster than v0.5.2 CPU fallback)
  • Cold energy: 34 min (77 NEFF compilations, paid once per instance lifetime)

PySCF FP32 precision envelope (#20) — CLOSED. All 8 hardware test cases pass below 1 µHartree. Key gate values: glycine/cc-pVDZ = 3.51×10⁻⁷ Ha, H₂O/cc-pVTZ = 1.99×10⁻⁷ Ha. Double-double emulation (#10, #22) deferred indefinitely — FP32 is sufficient for DF-MP2 at target molecule/basis combinations.
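
As a small illustration of the gate, the check below compares the two quoted values against the 1 µHartree threshold. It assumes the gate values are absolute deviations from the PySCF reference energy; `passes_gate` is a hypothetical helper, not part of trnblas.

```python
MICROHARTREE = 1.0e-6  # 1 µHartree expressed in Hartree

# Gate values quoted in the release notes, in Hartree.
reported_errors_ha = {
    "glycine/cc-pVDZ": 3.51e-7,
    "H2O/cc-pVTZ": 1.99e-7,
}

def passes_gate(error_ha, threshold_ha=MICROHARTREE):
    """True when the absolute error is strictly below the threshold."""
    return abs(error_ha) < threshold_ha

results = {case: passes_gate(err) for case, err in reported_errors_ha.items()}
```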

See CHANGELOG.md for full details.

v0.5.3 — medium-shape investigation + TMPDIR fix

21 Apr 20:04

Investigation: medium-shape batched-pair NEFF compile fails (18 GB XLA graph). nl.affine_range traces all loop iterations eagerly at compile time. At nocc=64 / 4096 pairs, the XLA JSON reaches 18 GB; the trn1 root EBS volume had 16 GB free. TMPDIR fix added as hygiene (does not unblock compilation — both /tmp and /var/tmp share the same filesystem).
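
Whether redirecting TMPDIR can help at all depends on the target directory sitting on a different filesystem with enough free space. A minimal standard-library sketch of that check follows; the helper names are illustrative, and nothing here describes how the Neuron compiler itself picks its scratch directory.

```python
import os
import shutil
import tempfile

def same_filesystem(path_a, path_b):
    """True when both paths sit on the same device (equal st_dev)."""
    return os.stat(path_a).st_dev == os.stat(path_b).st_dev

def free_gib(path):
    """Free space on the filesystem that holds `path`, in GiB."""
    return shutil.disk_usage(path).free / 2**30

def first_dir_with_room(candidates, required_gib):
    """Return the first existing directory with enough free space, else None."""
    for path in candidates:
        if os.path.isdir(path) and free_gib(path) >= required_gib:
            return path
    return None

scratch = tempfile.gettempdir()
print(scratch, f"{free_gib(scratch):.1f} GiB free")
```

On the instance described above, `same_filesystem("/tmp", "/var/tmp")` would be expected to return `True`, so neither location could hold an 18 GB graph with 16 GB free.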

Fix: chunked dispatch (one @nki.jit per i-row, processing all nocc j-pairs). Tracked in #46; ships in v0.5.4.

See CHANGELOG.md for full details.

v0.5.2 — batched-pair energy kernel

21 Apr 20:04
fd64f78

Batched-pair energy kernel (#43, nki_batched_pair_energy). A single @nki.jit dispatch computes the full DF-MP2 pair energy for all nocc² orbital pairs, eliminating the ~100 ms × nocc² Neuron XLA per-dispatch overhead that made nki_fused_gemm_energy impractical (215× slower than chunk-GEMM at nocc=16).

Hardware-validated on trn1.2xlarge:

  • Small shape (nocc=16, 256 pairs): warm energy 0.005 s, 3.6× faster than torch baseline
  • Spike B (nocc=4, 16 pairs): 1.9 ms vs 25.4 ms per-pair loop — 13.5× speedup

Medium shape (nocc=64, 4096 pairs): NEFF compile blocked (18 GB XLA graph exceeds disk). Fix in v0.5.4.

See CHANGELOG.md for full details.

v0.5.1 — fused GEMM+energy kernel (per-pair)

21 Apr 20:03
d02786a

Fused GEMM+energy kernel (#38, nki_fused_gemm_energy). One @nki.jit kernel handles one DF-MP2 orbital pair — both GEMMs (T and T_T) and the Vector Engine energy expression — without writing the (nvir, nvir) T_flat intermediate to HBM.

Note: hardware benchmarks revealed this kernel is 215× slower than the chunk-GEMM baseline at nocc=16 due to ~100 ms Neuron XLA per-dispatch overhead. The fix (single-dispatch batched kernel) ships in v0.5.2.

See CHANGELOG.md for full details.

v0.5.0 — GEMM tile-shape autotuner

21 Apr 20:03

GEMM tile-shape autotuner (#26). _nki_gemm_impl sweeps six tile candidates {64,128} × {128} × {128,256,512} on first call per shape bucket and caches the winner to disk (/var/tmp/trnblas-autotune/cache.json). Subsequent calls hit the in-process cache first, then the NEFF cache.

  • _make_gemm_kernel(tile_m, tile_k, tile_n) — factory returning a separately-cached @nki.jit closure per tile config
  • Opt out with TRNBLAS_AUTOTUNE=0 to restore v0.4.x fixed (128,128,512) behaviour
  • Backward compatible: _gemm_kernel alias preserved

See CHANGELOG.md for full details.

v0.4.3 — correct silent NKI fallback across v0.4.x

13 Apr 20:57

Correction release. Every "trn1 NKI" perf number published in v0.4.0 / v0.4.1 / v0.4.2 was actually trn1's 8-vCPU Xeon running torch.matmul, not the Trainium Tensor Engine.

What went wrong

Our SSM runners launched the Neuron venv's python directly without prepending its bin/ to $PATH. torch_neuronx.initializer calls subprocess.run(["libneuronpjrt-path"]) to locate the PJRT plugin library — that binary lives in the venv's bin/ and was unresolvable. Every NKI dispatch raised FileNotFoundError, and our _nki_*_impl try/except wrappers swallowed the exception and fell back to torch.matmul.

Correctness tests kept passing because torch.matmul gives the same answer as nki_gemm; only perf attribution was wrong. The v0.4.2 cross-vendor comparison was A10G's Ampere GPU vs trn1's Xeon, not vs trn1's Tensor Engine.
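
A pre-flight check along these lines would have caught the misconfiguration before any kernel ran. It is a sketch using only the standard library; `env_with_venv_bin` and `resolvable` are hypothetical helpers, and the only name taken from the text is the `libneuronpjrt-path` binary.

```python
import os
import shutil

def env_with_venv_bin(venv_dir, base_env=None):
    """Copy of the environment with <venv>/bin prepended to PATH."""
    env = dict(os.environ if base_env is None else base_env)
    bin_dir = os.path.join(venv_dir, "bin")
    env["PATH"] = bin_dir + os.pathsep + env.get("PATH", "")
    return env

def resolvable(tool, env):
    """True when `tool` can be found on the PATH carried by `env`."""
    return shutil.which(tool, path=env.get("PATH", "")) is not None
```

A runner could build `env = env_with_venv_bin(venv_dir)`, refuse to start unless `resolvable("libneuronpjrt-path", env)` holds, and then pass `env=env` to `subprocess.run`.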

What's fixed

  • scripts/run_neuron_tests.sh, scripts/run_df_mp2_bench.sh — prepend $NEURON_VENV/bin to $PATH + set TRNBLAS_REQUIRE_NKI=1 in the test runner.
  • trnblas.nki.NkiFallbackWarning — emitted once per distinct error when the fallback triggers. Makes future misconfigurations visible.
  • tests/test_nki_really_runs.py — anti-regression test that forces TRNBLAS_REQUIRE_NKI=1 and asserts a GEMM dispatch completes.
  • Re-measured trn1 numbers on docs/benchmarks.md with retraction banner.

Side finding

trnblas.nki.nki_mp2_energy kernel tests had a partition-limit bug that was masked by the silent fallback (nl.load(eps_vir[0:NVIR]) exceeds 128 partitions for nvir > 128). Tests skipped pending kernel rewrite under #15. Not in the production DF-MP2 path (examples/df_mp2.py uses the torch reduction).
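
The usual remedy for a partition-limit overflow is to load in chunks of at most 128 rows. The index arithmetic is sketched here in plain Python; `partition_tiles` is a hypothetical helper and the actual kernel rewrite under #15 may be structured differently.

```python
PMAX = 128  # partition limit quoted above

def partition_tiles(n, pmax=PMAX):
    """Split range(n) into contiguous (start, stop) chunks of at most pmax rows."""
    return [(start, min(start + pmax, n)) for start in range(0, n, pmax)]

# nvir = 448 (the medium shape) needs four loads instead of one.
tiles = partition_tiles(448)
```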

Re-measured (trn1.2xlarge, real NKI, warm cache)

| Op | Shape | v0.4.x (was CPU) | v0.4.3 (real NKI) |
|---|---|---|---|
| GEMM | 1024³ | 4.5 ms | 2.3 ms |
| SYRK | 1024² | 7.91 ms | 5.71 ms |
| TRSM | 2048×512 | 27.75 ms | 35.82 ms |
| DF-MP2 | medium warm | 9.77 s | 9.91 s |

Relative A10G vs trn1 ratios land in the same 19–45× range; cross-vendor story is unchanged, attribution is correct.

See the CHANGELOG.

v0.4.2 — cuBLAS head-to-head on A10G

13 Apr 03:10

Patch release. First cross-vendor DF-MP2 numbers published — closes #4.

Highlights

| Shape | trn1 warm | A10G warm | A10G vs trn1 |
|---|---|---|---|
| small (128/16/384) | 0.008 s | 0.001 s | |
| medium (512/64/1536) | 9.77 s | 0.266 s | 37× |
| large (768/96/2304) | 62.8 s | 2.018 s | 31× |

Energies match across platforms, up to fp32 reduction-order noise in the last ULP for medium/large.

What landed

  • infra/terraform-cuda/ — provisions a single-A10G g5.xlarge CI instance vintage-matched to trn1 (GA102 Ampere 2021 vs Trainium1 2022).
  • scripts/run_cuda_bench.sh — SSM runner mirroring the Trainium one.
  • examples/df_mp2.py --device cuda — moves inputs to GPU HBM; existing CPU path unchanged.
  • Honest cross-vendor table in docs/benchmarks.md with the vintage-parity rationale and the "close the gap" target for v0.5.0+ NKI kernel work.

What this tells us

Raw cuBLAS on 2021-vintage Ampere is ~30× faster than the current trnblas path on trn1 at medium/large DF-MP2 shapes. Trainium's Tensor Engine is being under-utilized in the current pipeline — closing the gap is exactly what #15, #18 (syrk), and #19 (trsm) are for in v0.5.0.

See the full CHANGELOG.


⚠️ Erratum (v0.4.3): The "trn1 vs A10G" table in this release was comparing A10G's Ampere GPU to trn1's Xeon CPU, not its Tensor Engine. A PATH misconfiguration caused silent NKI fallback to torch.matmul. Fixed in v0.4.3. Re-measured numbers are on the benchmarks page.

v0.4.1 — documentation stabilization

13 Apr 02:16

Patch release. Documentation drift against the v0.4.0 code fixed; no functional changes.

  • trnblas.__version__ was stuck at 0.3.0, now tracks the current release.
  • Installation docs: new [pyscf] extra section, TRNBLAS_REQUIRE_NKI env var table, updated neuronxcc >= 2.24 pin.
  • NKI API docs: GEMM HBM padding + measured per-call timings; new nki_batched_gemm and nki_mp2_energy sections.
  • Architecture page: current Level 3 coverage state with issue cross-references.
  • Benchmarks page: TBD placeholders replaced with measured trn1.2xlarge numbers.
  • Index: PySCF real-molecule demo pointer.

See CHANGELOG.


⚠️ Erratum (v0.4.3): The benchmark tables this release pointed at on the docs site attributed trn1 numbers to NKI; they were actually trn1 Xeon torch.matmul. Fixed in v0.4.3.

v0.4.0 — DF-MP2 end-to-end validation + NKI energy kernel

13 Apr 01:43

Highlights

  • Real-molecule DF-MP2 validation against PySCF (#11) — trnblas matches PySCF's own mp.dfmp2.DFMP2 reference to nanohartree precision on H2O/STO-3G, H2O/cc-pvdz, CH4/cc-pvdz, NH3/cc-pvdz. New pip install trnblas[pyscf] extra, runnable examples/df_mp2_pyscf.py demo.

  • Fused MP2 energy-reduction NKI kernel (#15, Phase 1): trnblas.nki.nki_mp2_energy with partition-dim sub-tiling. Validated on trn1 across nvir ∈ {8, 16, 64, 256, 448}. Scaffold landed; further perf work tracked under #15.

  • DF-MP2 step-4 collapse (#14) — the energy reduction no longer issues nocc² sequential batched dispatches; it now runs as one chunked GEMM via the algebraic identity T_full = X @ X.T. On trn1.2xlarge:

    | Shape | Flops | Cold | Warm | TFLOPS |
    |---|---|---|---|---|
    | small (128/16/384) | 3.4 G | 0.025 s | 0.008 s | 0.43 |
    | medium (512/64/1536) | 2757 G | 12.9 s | 9.77 s | 0.28 |
    | large (768/96/2304) | 20352 G | 65.9 s | 62.8 s | 0.32 |
  • Trainium CI infrastructure — Terraform module for a persistent trn1 test instance, SSM-driven runners (scripts/run_neuron_tests.sh, scripts/run_df_mp2_bench.sh), docs at docs/aws_setup.md.

  • NKI GEMM kernel wired to real nisa.nc_matmul with stationary tile reuse + HBM padding for arbitrary shapes. nki_batched_gemm dispatches per-slice through the cached kernel. 17/17 hardware tests pass.

  • Repository transfer — now at trnsci/trnblas. Docs at https://trnsci.dev/trnblas/.

  • neuronxcc floor bumped >=2.15 → >=2.24 (NKI 2.24+ nc_matmul calling convention) to unify with the rest of the trnsci suite.
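
The step-4 collapse listed in the highlights rests on a simple observation: if the DF vectors for every (i, a) are stacked as rows of one matrix X, then each pair's (nvir, nvir) amplitude block is a sub-block of the Gram matrix X @ X.T. A pure-Python check on a tiny random problem is sketched below; the names are illustrative and the row chunking used by the real code is not shown.

```python
import random

def matmul_nt(X, Y):
    """Return X @ Y.T for row-major lists of lists."""
    return [[sum(x * y for x, y in zip(rx, ry)) for ry in Y] for rx in X]

random.seed(1)
nocc, nvir, naux = 3, 4, 5
# B[i][a] is the length-naux DF vector for occupied i and virtual a.
B = [[[random.uniform(-1, 1) for _ in range(naux)]
      for _ in range(nvir)] for _ in range(nocc)]

# Per-pair route: one small GEMM per (i, j), nocc**2 of them.
T_pairs = {(i, j): matmul_nt(B[i], B[j])
           for i in range(nocc) for j in range(nocc)}

# Collapsed route: stack rows (i, a) into X, then a single Gram product.
X = [B[i][a] for i in range(nocc) for a in range(nvir)]
T_full = matmul_nt(X, X)

def block(T, i, j):
    """Extract the (nvir, nvir) block of T that belongs to pair (i, j)."""
    return [row[j * nvir:(j + 1) * nvir] for row in T[i * nvir:(i + 1) * nvir]]
```

T_full has (nocc·nvir)² entries, so at larger shapes it would normally be produced a row panel at a time, which is presumably why the release describes the GEMM as chunked.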

See the full CHANGELOG for details.


⚠️ Erratum (v0.4.3): The "GEMM per-call kernel timing" and "DF-MP2 end-to-end" tables here reported trn1 numbers that were silently torch.matmul fallback on trn1's Xeon, not NKI on the Tensor Engine. Fixed in v0.4.3.

v0.3.0

12 Apr 03:59

See CHANGELOG.md for full details.

Highlights

NKI on Trainium — validated end-to-end.

  • Real NKI GEMM kernel (trnblas/nki/dispatch.py:_gemm_kernel) with PSUM accumulation across K-tiles and stationary A-tile reuse via nisa.nc_matmul. NKI 2.24 calling convention.
  • Arbitrary-shape support via HBM padding wrapper (M/K → 128, N → 512), with TILE_N = min(N, 512) for small-N single-tile case.
  • Batched GEMM dispatch (trnblas.batched_gemm) loops over the batch dim through the cached 2D kernel — the natural shape for DF-MP2 contractions over auxiliary basis indices.
  • Local-flow Trainium CI — Terraform (infra/terraform/) provisions a stopped trn1 instance; scripts/run_neuron_tests.sh starts → runs pytest -m neuron via SSM → trap-stops. --warm flag runs twice to expose NEFF cache deltas.
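
The padding rule in the second bullet can be read as: round M and K up to multiples of 128, and round N up to a multiple of TILE_N = min(N, 512). The helper below encodes that reading; it is an interpretation of the bullet, not the wrapper's actual code, and `padded_gemm_shape` is a hypothetical name.

```python
def round_up(x, multiple):
    """Smallest multiple of `multiple` that is >= x."""
    return -(-x // multiple) * multiple

def padded_gemm_shape(m, k, n, tile_mk=128, max_tile_n=512):
    """Return (padded_m, padded_k, padded_n, tile_n) under the assumed rule."""
    tile_n = min(n, max_tile_n)  # small-N case: one tile exactly N wide
    return round_up(m, tile_mk), round_up(k, tile_mk), round_up(n, tile_n), tile_n
```

For example, `padded_gemm_shape(200, 130, 1000)` gives `(256, 256, 1024, 512)`, while a narrow problem such as `(128, 128, 300)` keeps N at 300 with a single 300-wide tile.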

Performance (trn1.2xlarge, neuronxcc 2.24.5133)

17/17 pytest -m neuron tests pass.

| Pass | Wall time |
|---|---|
| Cold | 7.01 s |
| Warm (NEFF cache) | 2.52 s (2.8× faster) |

| Operation | Per-call (warm) |
|---|---|
| GEMM 512³ | 1.6 ms |
| GEMM 1024³ | 4.5 ms |
| Batched 32 × 256×128×256 | 1.23 ms / slice |

NEFF cache survives instance stop/start (EBS-backed) — kernel compile cost is paid exactly once per shape.