Skip to content

Commit db15a33

Browse files
committed
domainwall
1 parent 9eca927 commit db15a33

File tree

3 files changed

+532
-226
lines changed

3 files changed

+532
-226
lines changed

src/Latticeindices.jl

Lines changed: 135 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -45,19 +45,151 @@ end
4545
end
4646

4747
# L(1-based) -> idx(1-based)
48+
#=
4849
@inline function delinearize(
49-
::DIndexer{D,dims,strides}, L::Integer, offset=Int32(0)
50+
::DIndexer{D,dims,strides}, L::Integer, offset=0
5051
) where {D,dims,strides}
51-
m = MVector{D,Int32}(undef)
52+
m = MVector{D,Int64}(undef)
5253
r = L - 1
5354
@inbounds for d in D:-1:2
5455
q = r ÷ strides[d]
5556
r = r % strides[d]
5657
m[d] = q + offset + 1
58+
#m[d] = q + 1
5759
end
60+
#m[1] = r + 1
5861
m[1] = r + offset + 1
59-
return NTuple{D,Int32}(m)
62+
#m .+= offset
63+
return NTuple{D,Int64}(m)
64+
end
65+
=#
66+
67+
# L (1-based linear index) -> 1-tuple index for a 1-dimensional indexer.
# `offset` is added to the single component; `dims` and `strides` are not read here.
@inline function delinearize(
    ::DIndexer{1,dims,strides}, L::Integer, offset::Integer=0
) where {dims,strides}
    rest = Int(L) - 1        # zero-based linear position
    shift = Int(offset) + 1  # back to 1-based, plus the caller's offset
    return (rest + shift,)
end
72+
73+
# L (1-based linear index) -> (i1, i2) for a 2-dimensional indexer.
# The zero-based position is split by `strides[2]`; the remainder is the first
# component. `offset` is added to every component.
@inline function delinearize(
    ::DIndexer{2,dims,strides}, L::Integer, offset::Integer=0
) where {dims,strides}
    rest = Int(L) - 1        # zero-based linear position
    shift = Int(offset) + 1  # back to 1-based, plus the caller's offset
    c2, rest = Base.divrem(rest, Int(strides[2]))
    return (rest + shift, c2 + shift)
end
79+
80+
# L (1-based linear index) -> (i1, i2, i3) for a 3-dimensional indexer.
# The zero-based position is peeled from the highest dimension downwards using
# `strides[3]` then `strides[2]`; what remains is the first component.
# `offset` is added to every component.
@inline function delinearize(
    ::DIndexer{3,dims,strides}, L::Integer, offset::Integer=0
) where {dims,strides}
    rest = Int(L) - 1        # zero-based linear position
    shift = Int(offset) + 1  # back to 1-based, plus the caller's offset
    c3, rest = Base.divrem(rest, Int(strides[3]))
    c2, rest = Base.divrem(rest, Int(strides[2]))
    return (rest + shift, c2 + shift, c3 + shift)
end
87+
88+
# L (1-based linear index) -> (i1, i2, i3, i4) for a 4-dimensional indexer.
# The zero-based position is peeled from the highest dimension downwards using
# `strides[4]`, `strides[3]`, `strides[2]`; what remains is the first component.
# `offset` is added to every component.
@inline function delinearize(
    ::DIndexer{4,dims,strides}, L::Integer, offset::Integer=0
) where {dims,strides}
    rest = Int(L) - 1        # zero-based linear position
    shift = Int(offset) + 1  # back to 1-based, plus the caller's offset
    c4, rest = Base.divrem(rest, Int(strides[4]))
    c3, rest = Base.divrem(rest, Int(strides[3]))
    c2, rest = Base.divrem(rest, Int(strides[2]))
    return (rest + shift, c2 + shift, c3 + shift, c4 + shift)
end
96+
97+
# L (1-based linear index) -> (i1, ..., i5) for a 5-dimensional indexer.
# The zero-based position is peeled from the highest dimension downwards using
# `strides[5]` ... `strides[2]`; what remains is the first component.
# `offset` is added to every component.
@inline function delinearize(
    ::DIndexer{5,dims,strides}, L::Integer, offset::Integer=0
) where {dims,strides}
    rest = Int(L) - 1        # zero-based linear position
    shift = Int(offset) + 1  # back to 1-based, plus the caller's offset
    c5, rest = Base.divrem(rest, Int(strides[5]))
    c4, rest = Base.divrem(rest, Int(strides[4]))
    c3, rest = Base.divrem(rest, Int(strides[3]))
    c2, rest = Base.divrem(rest, Int(strides[2]))
    return (rest + shift, c2 + shift, c3 + shift, c4 + shift, c5 + shift)
end
106+
107+
# L (1-based linear index) -> (i1, ..., i6) for a 6-dimensional indexer.
# The zero-based position is peeled from the highest dimension downwards using
# `strides[6]` ... `strides[2]`; what remains is the first component.
# `offset` is added to every component.
@inline function delinearize(
    ::DIndexer{6,dims,strides}, L::Integer, offset::Integer=0
) where {dims,strides}
    rest = Int(L) - 1        # zero-based linear position
    shift = Int(offset) + 1  # back to 1-based, plus the caller's offset
    c6, rest = Base.divrem(rest, Int(strides[6]))
    c5, rest = Base.divrem(rest, Int(strides[5]))
    c4, rest = Base.divrem(rest, Int(strides[4]))
    c3, rest = Base.divrem(rest, Int(strides[3]))
    c2, rest = Base.divrem(rest, Int(strides[2]))
    return (
        rest + shift, c2 + shift, c3 + shift, c4 + shift, c5 + shift, c6 + shift
    )
end
117+
118+
119+
# Recursion base case: with a single dimension left, the zero-based remainder `r`
# becomes the first component (shifted to 1-based and by `off`).
@inline function _delinearize(::Val{1}, r::Int, off::Int, strides)
    return (r + off + 1,)
end

# Recursive step: split the zero-based position `r` by `strides[N]` to obtain the
# N-th component, resolve components 1..N-1 from the remainder, and append the
# N-th component at the end of the tuple.
@inline function _delinearize(::Val{N}, r::Int, off::Int, strides) where {N}
    stride = @inbounds Int(strides[N])
    coord, rest = Base.divrem(r, stride)
    lower = _delinearize(Val(N - 1), rest, off, strides)
    return (lower..., coord + off + 1)
end
126+
127+
# Generic fallback for any dimension count D: L (1-based linear index) -> D-tuple.
# Converts `L` to a zero-based Int position and delegates to the recursive
# `_delinearize` helper, passing the `strides` type parameter through.
@inline function delinearize(
    ::DIndexer{D,dims,strides}, L::Integer, offset::Integer=0
) where {D,dims,strides}
    zero_based = Int(L) - 1
    shift = Int(offset)
    return _delinearize(Val(D), zero_based, shift, strides)
end
130+
131+
132+
#=
133+
@generated function delinearize(
134+
::DIndexer{D,dims,strides}, L::Integer, offset::Integer=0
135+
) where {D,dims,strides}
136+
# Assemble the block to be returned
137+
body = Expr(:block,
138+
:(Base.@_inline_meta true),
139+
:(r = Int(L) - 1),
140+
:(off = Int(offset)),
141+
)
142+
143+
comps = Vector{Any}(undef, D)
144+
145+
# Divide for d = D down to 2
146+
for d = D:-1:2
147+
sd = Int(strides[d])
148+
push!(body.args, :(q, r = Base.divrem(r, $(sd))))
149+
comps[d] = :(q + off + 1)
150+
end
151+
# d = 1 takes the remainder
152+
comps[1] = :(r + off + 1)
153+
154+
push!(body.args, Expr(:tuple, comps...))
155+
return body
60156
end
157+
=#
158+
159+
# Wrapper (putting @inline here is fine).
# Accepts the offset as a `Val{nw}` type parameter and forwards it as a plain Int
# to the `delinearize(idx, L, offset::Integer)` methods.
@inline function delinearize(
    idx::DIndexer{D,dims,strides}, L::Integer, ::Val{nw}
) where {D,dims,strides,nw}
    shift = Int(nw)
    return delinearize(idx, L, shift)
end
165+
166+
167+
168+
#=
169+
# L(1-based) -> idx(1-based)
170+
@inline @Base.propagate_inbounds function delinearize(
171+
::DIndexer{D,dims,strides}, L::Integer, offset = Int32(0)
172+
) where {D,dims,strides}
173+
offset32 = Int32(offset)
174+
L32 = Int32(L)
175+
one32 = Int32(1)
176+
177+
m = MVector{D,Int32}(undef)
178+
r = L32 - one32
179+
180+
@inbounds for d = D:-1:2
181+
sd = Int32(strides[d])
182+
q = r ÷ sd
183+
r = r % sd
184+
m[d] = q + offset32 + one32
185+
end
186+
m[1] = r + offset32 + one32
187+
188+
return NTuple{D,Int32}(m)
189+
end
190+
=#
191+
192+
61193

62194
@inline function shiftindices(indices, shift)
63195
return ntuple(i -> indices[i] + shift[i], length(indices))

src/LinearAlgebras/linearalgebra_D.jl

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
#Overwrite Y with X*a + Y*b, where a and b are scalars. Return Y.
22
function LinearAlgebra.axpby!(
33
a::Number,
4-
X::LatticeMatrix{D,T1,AT1,NC1,NC2,nw,DI},
4+
X::TX,
55
b::Number,
6-
Y::LatticeMatrix{D,T1,AT1,NC1,NC2,nw,DI},
7-
) where {T1,AT1,NC1,NC2,nw,D,DI}
6+
Y::TY,
7+
) where {T1,AT1,NC1,NC2,nw,D,DI,
8+
TX<:LatticeMatrix{D,T1,AT1,NC1,NC2,nw,DI},TY<:LatticeMatrix{D,T1,AT1,NC1,NC2,nw,DI}}
89

910
JACC.parallel_for(
1011
prod(Y.PN), kernel_D_axpby!, a, X.A, b, Y.A, Val(NC1), Val(NC2), Val(nw), Y.indexer
@@ -21,6 +22,25 @@ end
2122
end
2223
end
2324

25+
# Per-site kernel specialised for 3x3 blocks (dispatches on Val{3}, Val{3}):
# overwrites Y[ic, jc, site...] with a * X[ic, jc, site...] + b * Y[ic, jc, site...]
# for all ic, jc in 1:3. `i` is the linear index handed to `delinearize`, with `nw`
# passed as the offset. The nine updates are written out explicitly, column by column.
@inline function kernel_D_axpby!(
    i, a, X, b, Y, ::Val{3}, ::Val{3}, ::Val{nw}, dindexer
) where {nw}
    site = delinearize(dindexer, i, nw)

    # column 1
    Y[1, 1, site...] = a * X[1, 1, site...] + b * Y[1, 1, site...]
    Y[2, 1, site...] = a * X[2, 1, site...] + b * Y[2, 1, site...]
    Y[3, 1, site...] = a * X[3, 1, site...] + b * Y[3, 1, site...]

    # column 2
    Y[1, 2, site...] = a * X[1, 2, site...] + b * Y[1, 2, site...]
    Y[2, 2, site...] = a * X[2, 2, site...] + b * Y[2, 2, site...]
    Y[3, 2, site...] = a * X[3, 2, site...] + b * Y[3, 2, site...]

    # column 3
    Y[1, 3, site...] = a * X[1, 3, site...] + b * Y[1, 3, site...]
    Y[2, 3, site...] = a * X[2, 3, site...] + b * Y[2, 3, site...]
    Y[3, 3, site...] = a * X[3, 3, site...] + b * Y[3, 3, site...]
end
43+
2444
#C = a*x
2545
function LinearAlgebra.mul!(C::LatticeMatrix{D,T1,AT1,NC1,NG,nw,DI},
2646
a::TA, x::LatticeMatrix{D,T1,AT1,NC1,NG,nw,DI}) where {T1,AT1,NC1,nw,NG,TA<:Number,D,DI}

0 commit comments

Comments
 (0)