Skip to content

Commit d02786a

Browse files
authored
feat(nki): fused GEMM+energy kernel (#38, v0.5.1)
Implements nki_fused_gemm_energy — one @nki.jit per DF-MP2 orbital pair combining T and T_T GEMMs + VE energy expression, eliminating the T_flat HBM round-trip. Correctness validated on trn1 (52/52 tests pass). Also fixes NKI closure variable limitation in autotuner (_make_gemm_kernel closure → 6 static module-level @nki.jit variants), and adds stopping-state handling to run_neuron_tests.sh. Benchmark (trn1 small shape, nocc=16, 256 pairs): per-pair dispatch overhead ~100ms/NEFF on Neuron XLA dominates; warm energy 27.8s vs 0.13s baseline. Production speedup requires batched-pair kernel (all pairs in one @nki.jit), tracked in #43. Closes #38.
2 parents a59dbbc + 01586e2 commit d02786a

10 files changed

Lines changed: 760 additions & 85 deletions

File tree

CHANGELOG.md

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,83 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
77

88
## [Unreleased]
99

10+
## [0.5.1] — 2026-04-15
11+
12+
### Added
13+
14+
- **Fused GEMM+energy kernel (#38, `nki_fused_gemm_energy`).** A single
15+
`@nki.jit` kernel handles one DF-MP2 orbital pair — both GEMMs (T and T_T)
16+
and the VE energy expression — without writing the `(nvir, nvir)` T_flat
17+
intermediate to HBM.
18+
19+
**Two-GEMM T_T strategy:** `T.T[a,b] = T[b,a] = (B_j @ B_i.T)[a,b]`.
20+
Rather than `nl.load_transpose2d` of T from HBM (which re-introduces the
21+
HBM round-trip), T_T is computed as a second GEMM tile in the same kernel
22+
body. Both T and T_T land in SBUF via `tensor_copy` — no HBM write for
23+
either intermediate.
24+
25+
**Kernel design:**
26+
- `TILE = 128` everywhere (`nl.load_transpose2d` constrains both dims to ≤ 128).
27+
- Outer a-loop, inner b-loop; two sequential PSUM allocations per (a, b) tile
28+
(one for T, one for T_T); VE energy expression fully SBUF-resident.
29+
- Cross-b batching: `acc_b (TILE, N_B_TILES)` in SBUF accumulates all b-strip
30+
partials before one `nl.store` per a-strip — same pattern as `_mp2_energy_kernel`.
31+
- NEFF cache amortises the two-GEMM compile across all `nocc²` pairs (same
32+
shape every invocation).
33+
- NKI 0.3.0 broadcast fix applied to `denom` construction (same as `_mp2_energy_kernel`).
34+
35+
**Public API:** `trnblas.nki.nki_fused_gemm_energy(b_i, b_j, eps_occ_i, eps_occ_j, eps_vir)` → scalar.
36+
37+
**Example integration:** `examples/df_mp2.py --fused-gemm-energy` routes the
38+
energy step through the per-pair kernel. Default path remains the chunk-GEMM
39+
path — see benchmark note below.
40+
41+
**On-hardware benchmark (trn1, small shape: nbasis=128, nocc=16, nvir=112,
42+
naux=384; 256 pairs):**
43+
44+
| Step | Baseline warm | Fused warm |
45+
|---|---|---|
46+
| energy | **0.13s** | **27.8s** |
47+
| total | 3.98s | 31.5s |
48+
49+
The fused kernel is correct (energies agree to 6 significant figures) but
50+
the per-pair loop is **215× slower** on the energy step. Root cause:
51+
Neuron XLA imposes ~100ms per-NEFF-dispatch overhead, independent of kernel
52+
compute time. With 256 pairs × 100ms = 25.6s ≈ 27.8s observed. The
53+
chunk-GEMM baseline amortises this with two dispatches total.
54+
55+
Pre-transferring B to the XLA device and accumulating on-device (eliminating
56+
per-pair CPU syncs) produces the same warm timing because Neuron XLA's
57+
per-dispatch overhead is in the dispatch pipeline itself, not in the
58+
CPU→XLA transfer.
59+
60+
**Follow-on:** production speedup requires a batched kernel that processes
61+
all nocc² pairs in one `@nki.jit` invocation — tracked in #43.
62+
63+
**Tests:** `TestFusedGemmEnergy` in `tests/test_nki_gemm.py`:
64+
aligned/unaligned correctness (atol=1e-2), symmetry (`E(i,j) == E(j,i)`),
65+
zero-B_i, NEFF cache reuse (cold vs warm timing).
66+
67+
### Fixed
68+
69+
- **NKI closure variable limitation in autotuner (#26 regression).** The
70+
v0.5.0 `_make_gemm_kernel` factory returned a `@nki.jit` closure that
71+
referenced tile sizes (`tm`, `tk`, `tn`) as Python free variables.
72+
NKI's AST-based compiler reads source from the on-disk file and resolves
73+
names from the local namespace only — it cannot traverse closure cells,
74+
producing `error: unbound variable 'tm'` for every tile config.
75+
76+
**Fix:** replaced the factory with six static `@nki.jit` kernel definitions
77+
at module level (`_gemm_kernel_64_128_128``_gemm_kernel_128_128_512`),
78+
each with literal integer tile constants. All six are registered in
79+
`_gemm_kernel_registry` at import time; `_get_gemm_kernel()` is now a
80+
dict lookup. `_make_gemm_kernel` is removed. Autotuner behaviour
81+
(sweep, cache, escape-hatch) is unchanged.
82+
83+
**Root cause note:** NKI `@nki.jit` functions must have tile constants
84+
visible as literal integers or module-level globals at AST trace time.
85+
Closure variables from an enclosing factory scope are not reachable.
86+
1087
## [0.5.0] — 2026-04-15
1188

1289
### Added

examples/df_mp2.py

Lines changed: 74 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,35 @@ def _energy_reduction(
8484
return float(e_mp2)
8585

8686

87+
def _energy_reduction_fused_gemm(
88+
B: torch.Tensor,
89+
eps_occ: torch.Tensor,
90+
eps_vir: torch.Tensor,
91+
) -> float:
92+
"""Energy via the fused GEMM+energy kernel (#38, v0.5.1).
93+
94+
Calls `nki_fused_gemm_energy` once per (i, j) orbital pair. Each
95+
call fuses the GEMM (B[i] @ B[j].T) and the VE energy expression
96+
into one @nki.jit kernel — eliminating the T_flat HBM round-trip
97+
present in `_energy_reduction`.
98+
99+
This is the per-pair loop the RFC (fused_df_mp2_energy_kernel.md)
100+
describes. The NEFF cache amortises the two-GEMM compile cost
101+
across all nocc² pairs since every (i, j) invocation has the same
102+
shape and hits the same NEFF.
103+
"""
104+
from trnblas.nki import nki_fused_gemm_energy
105+
106+
nocc, nvir, naux = B.shape
107+
e_mp2 = torch.zeros((), dtype=B.dtype, device=B.device)
108+
for i in range(nocc):
109+
eps_occ_i = float(eps_occ[i])
110+
for j in range(nocc):
111+
eps_occ_j = float(eps_occ[j])
112+
e_mp2 = e_mp2 + nki_fused_gemm_energy(B[i], B[j], eps_occ_i, eps_occ_j, eps_vir)
113+
return float(e_mp2)
114+
115+
87116
def df_mp2_energy(
88117
C_occ: torch.Tensor, # (nbasis, nocc) — occupied MO coefficients
89118
C_vir: torch.Tensor, # (nbasis, nvir) — virtual MO coefficients
@@ -93,12 +122,17 @@ def df_mp2_energy(
93122
eps_vir: torch.Tensor, # (nvir,) — virtual orbital energies
94123
timings: dict | None = None,
95124
use_fused: bool = False,
125+
use_fused_gemm: bool = False,
96126
) -> float:
97127
"""Compute DF-MP2 correlation energy.
98128
99129
Returns E_MP2 (scalar). Optionally fills `timings` with per-step seconds.
100-
When `use_fused=True`, the energy-reduction step routes through
101-
`trnblas.nki.nki_mp2_energy`.
130+
131+
use_fused: Route energy-reduction through `nki_mp2_energy`
132+
(fused chunk-level kernel, #15 M2 — 1.48× on energy step).
133+
use_fused_gemm: Route energy through `nki_fused_gemm_energy` per (i,j)
134+
pair (fused GEMM+energy kernel, #38 v0.5.1 — eliminates
135+
T_flat HBM round-trip).
102136
"""
103137
nbasis, nocc = C_occ.shape
104138
naux = J_metric.shape[0]
@@ -135,14 +169,18 @@ def df_mp2_energy(
135169
B = trnblas.batched_gemm(1.0, ia_P, J_b) # (nocc, nvir, naux)
136170
t_metric = time.perf_counter() - t0
137171

138-
# Step 4: Energy via one GEMM (chunked over i if memory-tight).
172+
# Step 4: Energy.
139173
# T(i,j)_{ab} = Σ_P B[i,a,P] B[j,b,P]
140-
# Reshape B → X of shape (nocc·nvir, naux); then T_full = X @ X.T is
141-
# one GEMM, and T_full[i·nvir+a, j·nvir+b] = T(i,j)_{ab}. No batching
142-
# over (i,j) needed — that was the wrong shape for this contraction.
143-
# For shapes where the full T_full doesn't fit HBM, chunk over i.
174+
#
175+
# Three paths in order of increasing fusion:
176+
# default: chunk-GEMM (B_flat @ B_flat.T) + torch reduction
177+
# --fused-energy: chunk-GEMM + fused NKI energy kernel (#15 M2)
178+
# --fused-gemm-energy: per-pair fused GEMM+energy NKI kernel (#38 v0.5.1)
144179
t0 = time.perf_counter()
145-
e_mp2 = _energy_reduction(B, eps_occ, eps_vir, use_fused=use_fused)
180+
if use_fused_gemm:
181+
e_mp2 = _energy_reduction_fused_gemm(B, eps_occ, eps_vir)
182+
else:
183+
e_mp2 = _energy_reduction(B, eps_occ, eps_vir, use_fused=use_fused)
146184
t_energy = time.perf_counter() - t0
147185

148186
if timings is not None:
@@ -195,22 +233,28 @@ def _make_inputs(nbasis: int, nocc: int, naux: int, seed: int = 42, device: str
195233
}
196234

197235

198-
def bench(shape_name: str, device: str = "cpu", use_fused: bool = False):
236+
def bench(
237+
shape_name: str,
238+
device: str = "cpu",
239+
use_fused: bool = False,
240+
use_fused_gemm: bool = False,
241+
):
199242
nbasis, nocc, naux = _BENCH_SHAPES[shape_name]
200243
nvir = nbasis - nocc
201244
flops = _flops(nbasis, nocc, naux)
202245
inputs = _make_inputs(nbasis, nocc, naux, device=device)
203246

247+
energy_mode = "fused-gemm" if use_fused_gemm else ("fused" if use_fused else "torch")
204248
print(f"[shape={shape_name} nbasis={nbasis} nocc={nocc} nvir={nvir} naux={naux}]")
205249
print(
206250
f" approx flops: {flops / 1e9:.1f} G backend: {trnblas.get_backend()} "
207-
f"device: {device} fused_energy: {use_fused}"
251+
f"device: {device} energy_mode: {energy_mode}"
208252
)
209253

210254
for label in ("cold", "warm"):
211255
t = {}
212256
t0 = time.perf_counter()
213-
e = df_mp2_energy(*inputs, timings=t, use_fused=use_fused)
257+
e = df_mp2_energy(*inputs, timings=t, use_fused=use_fused, use_fused_gemm=use_fused_gemm)
214258
# Ensure async GPU work completes before stopping the timer.
215259
if device != "cpu" and torch.cuda.is_available():
216260
torch.cuda.synchronize()
@@ -246,14 +290,25 @@ def main():
246290
"--fused-energy",
247291
action="store_true",
248292
help="Route the energy-reduction step through trnblas.nki.nki_mp2_energy "
249-
"(fused NKI kernel, #15 M2).",
293+
"(fused chunk-level kernel, #15 M2 — 1.48× on energy step).",
294+
)
295+
parser.add_argument(
296+
"--fused-gemm-energy",
297+
action="store_true",
298+
help="Route the energy step through nki_fused_gemm_energy (per-pair fused "
299+
"GEMM+energy kernel, #38 v0.5.1 — eliminates T_flat HBM round-trip).",
250300
)
251301
args = parser.parse_args()
252302

253303
if args.bench:
254304
shapes = [args.shape] if args.shape else list(_BENCH_SHAPES)
255305
for s in shapes:
256-
bench(s, device=args.device, use_fused=args.fused_energy)
306+
bench(
307+
s,
308+
device=args.device,
309+
use_fused=args.fused_energy,
310+
use_fused_gemm=args.fused_gemm_energy,
311+
)
257312
return
258313

259314
if args.demo:
@@ -273,7 +328,12 @@ def main():
273328

274329
timings: dict = {}
275330
t0 = time.perf_counter()
276-
e_mp2 = df_mp2_energy(*inputs, timings=timings, use_fused=args.fused_energy)
331+
e_mp2 = df_mp2_energy(
332+
*inputs,
333+
timings=timings,
334+
use_fused=args.fused_energy,
335+
use_fused_gemm=args.fused_gemm_energy,
336+
)
277337
total = time.perf_counter() - t0
278338
for k, v in timings.items():
279339
print(f" {k:15s}: {v:.3f}s")

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
44

55
[project]
66
name = "trnblas"
7-
version = "0.5.0"
7+
version = "0.5.1"
88
description = "BLAS operations for AWS Trainium via NKI"
99
readme = "README.md"
1010
license = "Apache-2.0"

scripts/run_neuron_tests.sh

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ SHA="$(git rev-parse HEAD)"
3434
echo "Looking up instance with Name=$TAG in $REGION..."
3535
INSTANCE_ID=$(aws ec2 describe-instances \
3636
--filters "Name=tag:Name,Values=$TAG" \
37-
"Name=instance-state-name,Values=stopped,running,pending" \
37+
"Name=instance-state-name,Values=stopped,stopping,running,pending" \
3838
--query 'Reservations[0].Instances[0].InstanceId' \
3939
--output text \
4040
--region "$REGION")
@@ -59,6 +59,11 @@ trap cleanup EXIT
5959
STATE=$(aws ec2 describe-instances --instance-ids "$INSTANCE_ID" --region "$REGION" \
6060
--query 'Reservations[0].Instances[0].State.Name' --output text)
6161

62+
if [[ "$STATE" == "stopping" ]]; then
63+
echo "Instance is stopping — waiting for stopped..."
64+
aws ec2 wait instance-stopped --instance-ids "$INSTANCE_ID" --region "$REGION"
65+
STATE=stopped
66+
fi
6267
if [[ "$STATE" == "stopped" ]]; then
6368
echo "Starting instance..."
6469
aws ec2 start-instances --instance-ids "$INSTANCE_ID" --region "$REGION" >/dev/null

scripts/run_phase3_spike.sh

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -114,9 +114,8 @@ device = xm.xla_device()
114114
TILE, NPAIRS = 128, 8
115115
B = torch.randn(NPAIRS, TILE, TILE).to(device)
116116
D = torch.ones(TILE, TILE).to(device)
117-
O = torch.zeros(NPAIRS, TILE, 1).to(device)
118117
print("Compiling spike C...", flush=True)
119-
_spike_c_te_ve_overlap(B, D, O)
118+
_spike_c_te_ve_overlap(B, D)
120119
print("Done.", flush=True)
121120
')
122121
printf '%s' "$WARMUP_PY" > /tmp/spike_c_warmup.py

0 commit comments

Comments
 (0)