Skip to content

Commit 08cfeb7

Browse files
committed
AD test
1 parent 0d0431c commit 08cfeb7

File tree

11 files changed

+1346
-74
lines changed

11 files changed

+1346
-74
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "LatticeMatrices"
22
uuid = "dd6a91e4-736f-4540-ac85-13822ca7b545"
33
authors = ["Yuki Nagai <cometscome@gmail.com>"]
4-
version = "0.2.6"
4+
version = "0.2.7"
55

66
[deps]
77
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"

ext/AD/AD.jl

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ function kernel_Wiltinger!(i, A, dindexer, ::Val{NC1}, ::Val{NC2}, ::Val{nw}) wh
1616
for ic = 1:NC1
1717
X = real(A[ic, jc, indices...])
1818
Y = imag(A[ic, jc, indices...])
19+
#A[jc, ic, indices...] = Complex(0.5 * X, -0.5 * Y) # ∂/∂A
1920
A[ic, jc, indices...] = Complex(0.5 * X, -0.5 * Y) # ∂/∂A
2021
end
2122
end
@@ -191,7 +192,7 @@ function Enzyme.EnzymeRules.reverse(::RevConfig,
191192
dAstruct = A.dval isa Base.RefValue ? A.dval[] : A.dval
192193
dAstruct === nothing && return (nothing, nothing, nothing)
193194

194-
# 便利ハンドル
195+
# Handy handles
195196
dAval = dAstruct.A
196197
dBval = dB.val.A
197198
N1 = Val(A.val.NC1)
@@ -226,7 +227,7 @@ function Enzyme.EnzymeRules.augmented_primal(
226227
return AugmentedReturn(nothing, A.dval, nothing)
227228
end
228229

229-
# 出力Aがconstantな場合
230+
# When output A is constant
230231
function Enzyme.EnzymeRules.augmented_primal(
231232
::RevConfig,
232233
::Const{typeof(traceless_antihermitian!)},
@@ -270,11 +271,11 @@ end
270271

271272
function Enzyme.EnzymeRules.reverse(::RevConfig,
272273
::Const{typeof(traceless_antihermitian!)},
273-
::Type{<:Const}, _tape, # ← 第3引数は戻り値のアクティビティ型(Const{Nothing}
274+
::Type{<:Const}, _tape, # Third arg is the return activity type (Const{Nothing})
274275
A::Annotation{<:LatticeMatrix},
275276
B::Annotation{<:LatticeMatrix})
276277

277-
# 上流は「出力Aのshadow」に溜まって返ってくる(dAoutは渡されない)
278+
# Upstream accumulates in the "shadow of output A" (dAout is not passed)
278279
dA = A.dval
279280
dA = dA isa Base.RefValue ? dA[] : dA
280281
dA === nothing && return (nothing, nothing)
@@ -291,4 +292,4 @@ function Enzyme.EnzymeRules.reverse(::RevConfig,
291292
# dB += Π_ah,0(dA)
292293
JACC.parallel_for(Nsites, kernel_traceless_antihermitian_add!, dB.A, dA.A, NC1, nw, idx)
293294
return (nothing, nothing)
294-
end
295+
end

src/Latticeindices.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,7 @@ end
133133
@generated function delinearize(
134134
::DIndexer{D,dims,strides}, L::Integer, offset::Integer=0
135135
) where {D,dims,strides}
136-
# 返すブロックを組み立て
136+
# Build the block to return
137137
body = Expr(:block,
138138
:(Base.@_inline_meta true),
139139
:(r = Int(L) - 1),
@@ -142,21 +142,21 @@ end
142142
143143
comps = Vector{Any}(undef, D)
144144
145-
# d = D..2 まで割り算
145+
# Divide for d = D..2
146146
for d = D:-1:2
147147
sd = Int(strides[d])
148148
push!(body.args, :(q, r = Base.divrem(r, $(sd))))
149149
comps[d] = :(q + off + 1)
150150
end
151-
# d = 1 は余り
151+
# Remainder for d = 1
152152
comps[1] = :(r + off + 1)
153153
154154
push!(body.args, Expr(:tuple, comps...))
155155
return body
156156
end
157157
=#
158158

159-
# ラッパ(ここに @inline を付けるのはOK)
159+
# Wrapper (OK to add @inline here)
160160
@inline function delinearize(
161161
idx::DIndexer{D,dims,strides}, L::Integer, ::Val{nw}
162162
) where {D,dims,strides,nw}
@@ -198,4 +198,4 @@ export shiftindices
198198

199199

200200
export delinearize
201-
export linearize
201+
export linearize

0 commit comments

Comments
 (0)