1313
1414using MPI, StaticArrays, JACC
1515
16+ abstract type LatticeMatrix{D,T,AT,NC1,NC2,nw,DI} <: Lattice{D,T,AT,NC1,NC2,nw} end
17+
1618# ---------------------------------------------------------------------------
1719# container (faces / derived datatypes are GONE)
1820# ---------------------------------------------------------------------------
1921# struct LatticeMatrix{D,T,AT,NC1,NC2,nw} <: Lattice{D,T,AT}
20- struct LatticeMatrix {D,T,AT,NC1,NC2,nw,DI} <: Lattice{D,T,AT,NC1,NC2,nw}
22+ struct LatticeMatrix_standard {D,T,AT,NC1,NC2,nw,DI} <: LatticeMatrix{D,T,AT,NC1,NC2,nw,DI} # Lattice{D,T,AT,NC1,NC2,nw}
2123 nw:: Int # ghost width
2224 phases:: SVector{D,T} # phases
2325 NC1:: Int
@@ -38,12 +40,21 @@ struct LatticeMatrix{D,T,AT,NC1,NC2,nw,DI} <: Lattice{D,T,AT,NC1,NC2,nw}
3840 # stride::NTuple{D,Int}
3941end
4042
43+ # ---------------------------------------------------------------------------
44+ # constructor + heavy init (still cheap to call)
45+ # ---------------------------------------------------------------------------
46+ function LatticeMatrix(NC1, NC2, dim, gsize, PEs; nw= 1 , elementtype= ComplexF64, phases= ones(dim), comm0= MPI. COMM_WORLD)
47+ return LatticeMatrix_standard(NC1, NC2, dim, gsize, PEs; nw, elementtype, phases, comm0)
48+ end
4149
50+ function LatticeMatrix(A, dim, PEs; nw= 1 , phases= ones(dim), comm0= MPI. COMM_WORLD)
51+ return LatticeMatrix_standard(A, dim, PEs; nw, phases, comm0)
52+ end
4253
4354# ---------------------------------------------------------------------------
4455# constructor + heavy init (still cheap to call)
4556# ---------------------------------------------------------------------------
46- function LatticeMatrix (NC1, NC2, dim, gsize, PEs; nw= 1 , elementtype= ComplexF64, phases= ones(dim), comm0= MPI. COMM_WORLD)
57+ function LatticeMatrix_standard (NC1, NC2, dim, gsize, PEs; nw= 1 , elementtype= ComplexF64, phases= ones(dim), comm0= MPI. COMM_WORLD)
4758
4859 # Cartesian grid
4960 D = dim
@@ -84,12 +95,12 @@ function LatticeMatrix(NC1, NC2, dim, gsize, PEs; nw=1, elementtype=ComplexF64,
8495 # return LatticeMatrix{D,T,typeof(A),NC1,NC2,nw}(nw, phases, NC1, NC2, gsize,
8596 # cart, Tuple(coords), dims, nbr,
8697 # A, buf, MPI.Comm_rank(cart), PN, comm0)
87- return LatticeMatrix {D,T,typeof(A),NC1,NC2,nw,DI}(nw, phases, NC1, NC2, gsize,
98+ return LatticeMatrix_standard {D,T,typeof(A),NC1,NC2,nw,DI}(nw, phases, NC1, NC2, gsize,
8899 cart, Tuple(coords), dims, nbr,
89100 A, buf, MPI. Comm_rank(cart), PN, comm0, indexer)
90101end
91102
92- function LatticeMatrix (A, dim, PEs; nw= 1 , phases= ones(dim), comm0= MPI. COMM_WORLD)
103+ function LatticeMatrix_standard (A, dim, PEs; nw= 1 , phases= ones(dim), comm0= MPI. COMM_WORLD)
93104
94105 NC1, NC2, NN... = size(A)
95106 # println(NN)
189200
190201export allsum
191202
192- function get_globalrange(ls:: LatticeMatrix{D,T,TA} , dim) where {D,T,TA}
203+ function get_globalrange(ls:: LatticeMatrix , dim)
193204 coords_r = MPI. Cart_coords(ls. cart, ls. myrank)
194205 istart = get_globalindex(ls, 1 , dim, coords_r[dim])
195206 # if dim == 1
@@ -199,14 +210,14 @@ function get_globalrange(ls::LatticeMatrix{D,T,TA}, dim) where {D,T,TA}
199210 return istart: iend
200211end
201212
202- function get_globalindex(ls:: LatticeMatrix{D,T,TA } , i, dim, myrank_dim) where {D,T,TA }
213+ function get_globalindex(ls:: LatticeMatrix{D,T,AT,NC1,NC2,nw,DI } , i, dim, myrank_dim) where {D,T,AT,NC1,NC2,nw,DI }
203214 ix = i + ls. PN[dim] * myrank_dim
204215 return ix
205216end
206217
207218
208219
209- function set_halo!(ls:: LatticeMatrix{D,T,TA } ) where {D,T,TA }
220+ function set_halo!(ls:: LatticeMatrix{D,T,AT,NC1,NC2,nw,DI } ) where {D,T,AT,NC1,NC2,nw,DI }
210221 for id = 1 : D
211222 exchange_dim!(ls, id)
212223 end
0 commit comments