Skip to content

Commit 1ced809

Browse files
committed
bug fix in partial_tr
1 parent 1c08fe5 commit 1ced809

File tree

3 files changed

+112
-68
lines changed

3 files changed

+112
-68
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "LatticeMatrices"
22
uuid = "dd6a91e4-736f-4540-ac85-13822ca7b545"
33
authors = ["Yuki Nagai <cometscome@gmail.com>"]
4-
version = "0.1.0"
4+
version = "0.1.1"
55

66
[deps]
77
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"

src/LatticeMatrices_core.jl

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,13 @@
1313

1414
using MPI, StaticArrays, JACC
1515

16+
abstract type LatticeMatrix{D,T,AT,NC1,NC2,nw,DI} <: Lattice{D,T,AT,NC1,NC2,nw} end
17+
1618
# ---------------------------------------------------------------------------
1719
# container (faces / derived datatypes are GONE)
1820
# ---------------------------------------------------------------------------
1921
#struct LatticeMatrix{D,T,AT,NC1,NC2,nw} <: Lattice{D,T,AT}
20-
struct LatticeMatrix{D,T,AT,NC1,NC2,nw,DI} <: Lattice{D,T,AT,NC1,NC2,nw}
22+
struct LatticeMatrix_standard{D,T,AT,NC1,NC2,nw,DI} <: LatticeMatrix{D,T,AT,NC1,NC2,nw,DI} #Lattice{D,T,AT,NC1,NC2,nw}
2123
nw::Int # ghost width
2224
phases::SVector{D,T} # phases
2325
NC1::Int
@@ -38,12 +40,21 @@ struct LatticeMatrix{D,T,AT,NC1,NC2,nw,DI} <: Lattice{D,T,AT,NC1,NC2,nw}
3840
#stride::NTuple{D,Int}
3941
end
4042

43+
# ---------------------------------------------------------------------------
44+
# constructor + heavy init (still cheap to call)
45+
# ---------------------------------------------------------------------------
46+
function LatticeMatrix(NC1, NC2, dim, gsize, PEs; nw=1, elementtype=ComplexF64, phases=ones(dim), comm0=MPI.COMM_WORLD)
47+
return LatticeMatrix_standard(NC1, NC2, dim, gsize, PEs; nw, elementtype, phases, comm0)
48+
end
4149

50+
function LatticeMatrix(A, dim, PEs; nw=1, phases=ones(dim), comm0=MPI.COMM_WORLD)
51+
return LatticeMatrix_standard(A, dim, PEs; nw, phases, comm0)
52+
end
4253

4354
# ---------------------------------------------------------------------------
4455
# constructor + heavy init (still cheap to call)
4556
# ---------------------------------------------------------------------------
46-
function LatticeMatrix(NC1, NC2, dim, gsize, PEs; nw=1, elementtype=ComplexF64, phases=ones(dim), comm0=MPI.COMM_WORLD)
57+
function LatticeMatrix_standard(NC1, NC2, dim, gsize, PEs; nw=1, elementtype=ComplexF64, phases=ones(dim), comm0=MPI.COMM_WORLD)
4758

4859
# Cartesian grid
4960
D = dim
@@ -84,12 +95,12 @@ function LatticeMatrix(NC1, NC2, dim, gsize, PEs; nw=1, elementtype=ComplexF64,
8495
#return LatticeMatrix{D,T,typeof(A),NC1,NC2,nw}(nw, phases, NC1, NC2, gsize,
8596
# cart, Tuple(coords), dims, nbr,
8697
# A, buf, MPI.Comm_rank(cart), PN, comm0)
87-
return LatticeMatrix{D,T,typeof(A),NC1,NC2,nw,DI}(nw, phases, NC1, NC2, gsize,
98+
return LatticeMatrix_standard{D,T,typeof(A),NC1,NC2,nw,DI}(nw, phases, NC1, NC2, gsize,
8899
cart, Tuple(coords), dims, nbr,
89100
A, buf, MPI.Comm_rank(cart), PN, comm0, indexer)
90101
end
91102

92-
function LatticeMatrix(A, dim, PEs; nw=1, phases=ones(dim), comm0=MPI.COMM_WORLD)
103+
function LatticeMatrix_standard(A, dim, PEs; nw=1, phases=ones(dim), comm0=MPI.COMM_WORLD)
93104

94105
NC1, NC2, NN... = size(A)
95106
#println(NN)
@@ -189,7 +200,7 @@ end
189200

190201
export allsum
191202

192-
function get_globalrange(ls::LatticeMatrix{D,T,TA}, dim) where {D,T,TA}
203+
function get_globalrange(ls::LatticeMatrix, dim)
193204
coords_r = MPI.Cart_coords(ls.cart, ls.myrank)
194205
istart = get_globalindex(ls, 1, dim, coords_r[dim])
195206
#if dim == 1
@@ -199,14 +210,14 @@ function get_globalrange(ls::LatticeMatrix{D,T,TA}, dim) where {D,T,TA}
199210
return istart:iend
200211
end
201212

202-
function get_globalindex(ls::LatticeMatrix{D,T,TA}, i, dim, myrank_dim) where {D,T,TA}
213+
function get_globalindex(ls::LatticeMatrix{D,T,AT,NC1,NC2,nw,DI}, i, dim, myrank_dim) where {D,T,AT,NC1,NC2,nw,DI}
203214
ix = i + ls.PN[dim] * myrank_dim
204215
return ix
205216
end
206217

207218

208219

209-
function set_halo!(ls::LatticeMatrix{D,T,TA}) where {D,T,TA}
220+
function set_halo!(ls::LatticeMatrix{D,T,AT,NC1,NC2,nw,DI}) where {D,T,AT,NC1,NC2,nw,DI}
210221
for id = 1:D
211222
exchange_dim!(ls, id)
212223
end

0 commit comments

Comments
 (0)