@@ -84,6 +84,35 @@ def _energy_reduction(
8484 return float (e_mp2 )
8585
8686
def _energy_reduction_fused_gemm(
    B: torch.Tensor,
    eps_occ: torch.Tensor,
    eps_vir: torch.Tensor,
) -> float:
    """Energy via the fused GEMM+energy kernel (#38, v0.5.1).

    Calls `nki_fused_gemm_energy` once per (i, j) orbital pair. Each
    call fuses the GEMM (B[i] @ B[j].T) and the VE energy expression
    into one @nki.jit kernel — eliminating the T_flat HBM round-trip
    present in `_energy_reduction`.

    This is the per-pair loop the RFC (fused_df_mp2_energy_kernel.md)
    describes. The NEFF cache amortises the two-GEMM compile cost
    across all nocc² pairs since every (i, j) invocation has the same
    shape and hits the same NEFF.

    Args:
        B: Rank-3 tensor unpacked as (nocc, nvir, naux); `B[i]` is the
            slice handed to the kernel for occupied orbital i.
        eps_occ: Occupied orbital energies, indexed 0..nocc-1.
        eps_vir: Virtual orbital energies, forwarded unchanged to the
            kernel on every call.

    Returns:
        The sum of the kernel's result over all ordered (i, j) pairs,
        converted to a Python float.
    """
    from trnblas.nki import nki_fused_gemm_energy

    # Only nocc drives the loops; unpacking all three dims still enforces
    # that B is rank-3.
    nocc, _nvir, _naux = B.shape

    # Convert the occupied energies to Python floats once. Doing this inside
    # the pair loop repeated the same scalar read nocc² times, and each read
    # of a device tensor forces a host transfer.
    occ_energies = [float(eps_occ[k]) for k in range(nocc)]

    e_mp2 = torch.zeros((), dtype=B.dtype, device=B.device)
    for i in range(nocc):
        B_i = B[i]
        eps_occ_i = occ_energies[i]
        for j in range(nocc):
            e_mp2 = e_mp2 + nki_fused_gemm_energy(
                B_i, B[j], eps_occ_i, occ_energies[j], eps_vir
            )
    return float(e_mp2)
114+
115+
87116def df_mp2_energy (
88117 C_occ : torch .Tensor , # (nbasis, nocc) — occupied MO coefficients
89118 C_vir : torch .Tensor , # (nbasis, nvir) — virtual MO coefficients
@@ -93,12 +122,17 @@ def df_mp2_energy(
93122 eps_vir : torch .Tensor , # (nvir,) — virtual orbital energies
94123 timings : dict | None = None ,
95124 use_fused : bool = False ,
125+ use_fused_gemm : bool = False ,
96126) -> float :
97127 """Compute DF-MP2 correlation energy.
98128
99129 Returns E_MP2 (scalar). Optionally fills `timings` with per-step seconds.
100- When `use_fused=True`, the energy-reduction step routes through
101- `trnblas.nki.nki_mp2_energy`.
130+
131+ use_fused: Route energy-reduction through `nki_mp2_energy`
132+ (fused chunk-level kernel, #15 M2 — 1.48× on energy step).
133+ use_fused_gemm: Route energy through `nki_fused_gemm_energy` per (i,j)
134+ pair (fused GEMM+energy kernel, #38 v0.5.1 — eliminates
135+ T_flat HBM round-trip).
102136 """
103137 nbasis , nocc = C_occ .shape
104138 naux = J_metric .shape [0 ]
@@ -135,14 +169,18 @@ def df_mp2_energy(
135169 B = trnblas .batched_gemm (1.0 , ia_P , J_b ) # (nocc, nvir, naux)
136170 t_metric = time .perf_counter () - t0
137171
138- # Step 4: Energy via one GEMM (chunked over i if memory-tight) .
172+ # Step 4: Energy.
139173 # T(i,j)_{ab} = Σ_P B[i,a,P] B[j,b,P]
140- # Reshape B → X of shape (nocc·nvir, naux); then T_full = X @ X.T is
141- # one GEMM, and T_full[i·nvir+a, j·nvir+b] = T(i,j)_{ab}. No batching
142- # over (i,j) needed — that was the wrong shape for this contraction.
143- # For shapes where the full T_full doesn't fit HBM, chunk over i.
174+ #
175+ # Three paths in order of increasing fusion:
176+ # default: chunk-GEMM (B_flat @ B_flat.T) + torch reduction
177+ # --fused-energy: chunk-GEMM + fused NKI energy kernel (#15 M2)
178+ # --fused-gemm-energy: per-pair fused GEMM+energy NKI kernel (#38 v0.5.1)
144179 t0 = time .perf_counter ()
145- e_mp2 = _energy_reduction (B , eps_occ , eps_vir , use_fused = use_fused )
180+ if use_fused_gemm :
181+ e_mp2 = _energy_reduction_fused_gemm (B , eps_occ , eps_vir )
182+ else :
183+ e_mp2 = _energy_reduction (B , eps_occ , eps_vir , use_fused = use_fused )
146184 t_energy = time .perf_counter () - t0
147185
148186 if timings is not None :
@@ -195,22 +233,28 @@ def _make_inputs(nbasis: int, nocc: int, naux: int, seed: int = 42, device: str
195233}
196234
197235
198- def bench (shape_name : str , device : str = "cpu" , use_fused : bool = False ):
236+ def bench (
237+ shape_name : str ,
238+ device : str = "cpu" ,
239+ use_fused : bool = False ,
240+ use_fused_gemm : bool = False ,
241+ ):
199242 nbasis , nocc , naux = _BENCH_SHAPES [shape_name ]
200243 nvir = nbasis - nocc
201244 flops = _flops (nbasis , nocc , naux )
202245 inputs = _make_inputs (nbasis , nocc , naux , device = device )
203246
247+ energy_mode = "fused-gemm" if use_fused_gemm else ("fused" if use_fused else "torch" )
204248 print (f"[shape={ shape_name } nbasis={ nbasis } nocc={ nocc } nvir={ nvir } naux={ naux } ]" )
205249 print (
206250 f" approx flops: { flops / 1e9 :.1f} G backend: { trnblas .get_backend ()} "
207- f"device: { device } fused_energy : { use_fused } "
251+ f"device: { device } energy_mode : { energy_mode } "
208252 )
209253
210254 for label in ("cold" , "warm" ):
211255 t = {}
212256 t0 = time .perf_counter ()
213- e = df_mp2_energy (* inputs , timings = t , use_fused = use_fused )
257+ e = df_mp2_energy (* inputs , timings = t , use_fused = use_fused , use_fused_gemm = use_fused_gemm )
214258 # Ensure async GPU work completes before stopping the timer.
215259 if device != "cpu" and torch .cuda .is_available ():
216260 torch .cuda .synchronize ()
@@ -246,14 +290,25 @@ def main():
246290 "--fused-energy" ,
247291 action = "store_true" ,
248292 help = "Route the energy-reduction step through trnblas.nki.nki_mp2_energy "
249- "(fused NKI kernel, #15 M2)." ,
293+ "(fused chunk-level kernel, #15 M2 — 1.48× on energy step)." ,
294+ )
295+ parser .add_argument (
296+ "--fused-gemm-energy" ,
297+ action = "store_true" ,
298+ help = "Route the energy step through nki_fused_gemm_energy (per-pair fused "
299+ "GEMM+energy kernel, #38 v0.5.1 — eliminates T_flat HBM round-trip)." ,
250300 )
251301 args = parser .parse_args ()
252302
253303 if args .bench :
254304 shapes = [args .shape ] if args .shape else list (_BENCH_SHAPES )
255305 for s in shapes :
256- bench (s , device = args .device , use_fused = args .fused_energy )
306+ bench (
307+ s ,
308+ device = args .device ,
309+ use_fused = args .fused_energy ,
310+ use_fused_gemm = args .fused_gemm_energy ,
311+ )
257312 return
258313
259314 if args .demo :
@@ -273,7 +328,12 @@ def main():
273328
274329 timings : dict = {}
275330 t0 = time .perf_counter ()
276- e_mp2 = df_mp2_energy (* inputs , timings = timings , use_fused = args .fused_energy )
331+ e_mp2 = df_mp2_energy (
332+ * inputs ,
333+ timings = timings ,
334+ use_fused = args .fused_energy ,
335+ use_fused_gemm = args .fused_gemm_energy ,
336+ )
277337 total = time .perf_counter () - t0
278338 for k , v in timings .items ():
279339 print (f" { k :15s} : { v :.3f} s" )
0 commit comments