Skip to content

Commit 7ab447f

Browse files
committed
AD
1 parent 143e1fc commit 7ab447f

File tree

2 files changed

+38
-13
lines changed

2 files changed

+38
-13
lines changed

README.md

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -227,6 +227,10 @@ mul!(M1, M2s, M3) # M1 = (M2 shifted) * M3
227227
228228
The shift is applied with periodic wrapping across the global lattice size.
229229
230+
---
231+
232+
233+
230234
### 3) Multiplication with conjugate-transposed matrices
231235
232236
```julia
@@ -239,6 +243,27 @@ All combinations of shifted and adjoint operands are supported and tested in `te
239243
240244
---
241245
246+
## Automatic differentiation (Enzyme)
247+
248+
We provide Enzyme-based AD extensions and test cases. See `test/adtest/ad.jl` for a concrete comparison between
249+
automatic differentiation and numerical differentiation using `calc_action_loopfn`. The loop body is factored
250+
into a small helper function (`_calc_action_step!`), which makes Enzyme AD more reliable for loop-heavy code.
251+
252+
Example (runs the AD vs numerical comparison with `calc_action_loopfn`):
253+
254+
```julia
255+
using Enzyme
256+
using LatticeMatrices, MPI, JACC
257+
JACC.@init_backend
258+
MPI.Init()
259+
260+
include("test/adtest/ad.jl") # runs main() in the script
261+
```
262+
263+
Note: the AD result here follows Enzyme's complex differentiation convention. For a complex variable
264+
`U = X + iY`, the gradient reported by Enzyme is
265+
`dS/dUij = dS/dXij + i dS/dYij`.
266+
242267

243268
---
244269

ext/AD/AD.jl

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -43,17 +43,17 @@ function ER.augmented_primal(cfg::ER.RevConfig,
4343
mul_AshiftB!(C.val, A.val, B.val, shift.val)
4444

4545
# Always tape A and parent(B) primals to survive workspace reuse.
46-
# tapeA_obj, itA = get_block(A.val.temps)
47-
# tapeA_obj .= A.val.A
48-
# tapeA = (tapeA_obj, itA)
49-
tapeA_obj = deepcopy(A.val.A)
50-
tapeA = (tapeA_obj, nothing)
51-
52-
# tapeB_obj, itB = get_block(B.val.temps)
53-
# tapeB_obj .= B.val.A
54-
# tapeB = (tapeB_obj, itB)
55-
tapeB_obj = deepcopy(B.val.A)
56-
tapeB = (tapeB_obj, nothing)
46+
tapeA_obj, itA = get_block(A.val.temps)
47+
tapeA_obj .= A.val.A
48+
tapeA = (tapeA_obj, itA)
49+
#tapeA_obj = deepcopy(A.val.A)
50+
#tapeA = (tapeA_obj, nothing)
51+
52+
tapeB_obj, itB = get_block(B.val.temps)
53+
tapeB_obj .= B.val.A
54+
tapeB = (tapeB_obj, itB)
55+
#tapeB_obj = deepcopy(B.val.A)
56+
#tapeB = (tapeB_obj, nothing)
5757

5858
tape_shift = shift.val
5959
if get(ENV, "LM_DEBUG_MULASHIFTB", "") == "1"
@@ -188,10 +188,10 @@ function _rev_mul_AshiftB!(
188188

189189
# Release tape blocks(早期 return があっても必ずやりたいなら try/finally 化推奨)
190190
if tapeA !== nothing
191-
# unused!(A.val.temps, tapeA[2])
191+
unused!(A.val.temps, tapeA[2])
192192
end
193193
if tapeB !== nothing
194-
# unused!(B.val.temps, tapeB[2])
194+
unused!(B.val.temps, tapeB[2])
195195
end
196196

197197
if _should_zero_dC(dCout)

0 commit comments

Comments
 (0)