Skip to content

Commit 6dd483b

Browse files
authored
Merge pull request #36 from fverdugo/gridap_distributed
Misc devs for Gridapdistributed
2 parents 06f1304 + 13035ef commit 6dd483b

File tree

6 files changed

+262
-24
lines changed

6 files changed

+262
-24
lines changed

Project.toml

+2
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ authors = ["Francesc Verdugo <[email protected]> and contributors"]
44
version = "0.2.2"
55

66
[deps]
7+
Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
78
IterativeSolvers = "42fd0dbc-a981-5370-80f2-aaf504508153"
89
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
910
MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195"
@@ -12,6 +13,7 @@ SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
1213
SparseMatricesCSR = "a0a7dd2c-ebf4-11e9-1f05-cf50bc540ca1"
1314

1415
[compat]
16+
Distances = "0.10"
1517
IterativeSolvers = "0.9"
1618
MPI = "0.16, 0.17, 0.18, 0.19"
1719
SparseMatricesCSR = "0.6"

src/Interfaces.jl

+196-17
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ num_parts(a::AbstractPData) = length(a)
7373
"""
7474
get_backend(a::AbstractPData) -> AbstractBackend
7575
76-
Get the back-end associated with `a`.
76+
Get the back-end associated with `a`.
7777
"""
7878
get_backend(a::AbstractPData) = @abstractmethod
7979

@@ -639,6 +639,16 @@ function lids_are_equal(a::AbstractIndexSet,b::AbstractIndexSet)
639639
a.lid_to_gid == b.lid_to_gid
640640
end
641641

642+
function find_lid_map(a::AbstractIndexSet,b::AbstractIndexSet)
643+
alid_to_blid = fill(Int32(-1),num_lids(a))
644+
for blid in 1:num_lids(b)
645+
gid = b.lid_to_gid[blid]
646+
alid = a.gid_to_lid[gid]
647+
alid_to_blid[alid] = blid
648+
end
649+
alid_to_blid
650+
end
651+
642652
# The given ids are assumed to be a sub-set of the lids
643653
function touched_hids(a::AbstractIndexSet,gids::AbstractVector{<:Integer})
644654
i = 0
@@ -828,7 +838,7 @@ function async_exchange!(
828838
data_snd = allocate_snd_buffer(Tsnd,exchanger)
829839

830840
# Fill snd buffer
831-
t1 = map_parts(t0,values_snd,data_snd,exchanger.lids_snd) do t0,values_snd,data_snd,lids_snd
841+
t1 = map_parts(t0,values_snd,data_snd,exchanger.lids_snd) do t0,values_snd,data_snd,lids_snd
832842
@task begin
833843
wait(schedule(t0))
834844
for p in 1:length(lids_snd.data)
@@ -848,7 +858,7 @@ function async_exchange!(
848858

849859
# Fill values_rcv from rcv buffer
850860
# asynchronously
851-
t3 = map_parts(t2,values_rcv,data_rcv,exchanger.lids_rcv) do t2,values_rcv,data_rcv,lids_rcv
861+
t3 = map_parts(t2,values_rcv,data_rcv,exchanger.lids_rcv) do t2,values_rcv,data_rcv,lids_rcv
852862
@task begin
853863
wait(schedule(t2))
854864
for p in 1:length(lids_rcv.data)
@@ -946,7 +956,7 @@ mutable struct PRange{A,B,C} <: AbstractUnitRange{Int}
946956
exchanger::Exchanger,
947957
gid_to_part::Union{AbstractPData{<:AbstractArray{<:Integer}},Nothing}=nothing,
948958
ghost::Bool=true)
949-
959+
950960
A = typeof(partition)
951961
B = typeof(exchanger)
952962
C = typeof(gid_to_part)
@@ -1455,12 +1465,22 @@ function Base.similar(
14551465
end
14561466

14571467
function Base.copy!(a::PVector,b::PVector)
1458-
map_parts(copy!,a.values,b.values)
1468+
@check oids_are_equal(a.rows,b.rows)
1469+
if a.rows.partition === b.rows.partition
1470+
map_parts(copy!,a.values,b.values)
1471+
else
1472+
map_parts(copy!,a.owned_values,b.owned_values)
1473+
end
14591474
a
14601475
end
14611476

14621477
function Base.copyto!(a::PVector,b::PVector)
1463-
map_parts(copyto!,a.values,b.values)
1478+
@check oids_are_equal(a.rows,b.rows)
1479+
if a.rows.partition === b.rows.partition
1480+
map_parts(copyto!,a.values,b.values)
1481+
else
1482+
map_parts(copyto!,a.owned_values,b.owned_values)
1483+
end
14641484
a
14651485
end
14661486

@@ -1470,6 +1490,19 @@ function Base.copy(b::PVector)
14701490
a
14711491
end
14721492

1493+
function LinearAlgebra.rmul!(a::PVector,v::Number)
1494+
map_parts(a.values) do l
1495+
rmul!(l,v)
1496+
end
1497+
a
1498+
end
1499+
1500+
function Base.:(==)(a::PVector,b::PVector)
1501+
length(a) == length(b) &&
1502+
num_parts(a.values) == num_parts(b.values) &&
1503+
reduce(&,map_parts(==,a.owned_values,b.owned_values),init=true)
1504+
end
1505+
14731506
struct DistributedBroadcasted{A,B,C}
14741507
owned_values::A
14751508
ghost_values::B
@@ -1556,6 +1589,101 @@ function LinearAlgebra.norm(a::PVector,p::Real=2)
15561589
reduce(+,contibs;init=zero(eltype(contibs)))^(1/p)
15571590
end
15581591

1592+
# Distances.jl related (needed eg for non-linear solvers)
1593+
1594+
for M in Distances.metrics
1595+
@eval begin
1596+
function (dist::$M)(a::PVector,b::PVector)
1597+
_eval_dist(dist,a,b,Distances.parameters(dist))
1598+
end
1599+
end
1600+
end
1601+
1602+
function _eval_dist(d,a,b,::Nothing)
1603+
partials = map_parts(a.owned_values,b.owned_values) do a,b
1604+
_eval_dist_local(d,a,b,nothing)
1605+
end
1606+
s = reduce(
1607+
(i,j)->Distances.eval_reduce(d,i,j),
1608+
partials,
1609+
init=Distances.eval_start(d, a, b))
1610+
Distances.eval_end(d,s)
1611+
end
1612+
1613+
Base.@propagate_inbounds function _eval_dist_local(d,a,b,::Nothing)
1614+
@boundscheck if length(a) != length(b)
1615+
throw(DimensionMismatch("first array has length $(length(a)) which does not match the length of the second, $(length(b))."))
1616+
end
1617+
if length(a) == 0
1618+
return zero(Distances.result_type(d, a, b))
1619+
end
1620+
@inbounds begin
1621+
s = Distances.eval_start(d, a, b)
1622+
if (IndexStyle(a, b) === IndexLinear() && eachindex(a) == eachindex(b)) || axes(a) == axes(b)
1623+
@simd for I in eachindex(a, b)
1624+
ai = a[I]
1625+
bi = b[I]
1626+
s = Distances.eval_reduce(d, s, Distances.eval_op(d, ai, bi))
1627+
end
1628+
else
1629+
for (ai, bi) in zip(a, b)
1630+
s = Distances.eval_reduce(d, s, Distances.eval_op(d, ai, bi))
1631+
end
1632+
end
1633+
return s
1634+
end
1635+
end
1636+
1637+
function _eval_dist(d,a,b,p)
1638+
@notimplemented
1639+
end
1640+
1641+
function _eval_dist_local(d,a,b,p)
1642+
@notimplemented
1643+
end
1644+
1645+
function Base.any(f::Function,x::PVector)
1646+
partials = map_parts(x.owned_values) do o
1647+
any(f,o)
1648+
end
1649+
reduce(|,partials,init=false)
1650+
end
1651+
1652+
function Base.all(f::Function,x::PVector)
1653+
partials = map_parts(x.owned_values) do o
1654+
all(f,o)
1655+
end
1656+
reduce(&,partials,init=true)
1657+
end
1658+
1659+
function Base.maximum(x::PVector)
1660+
partials = map_parts(maximum,x.owned_values)
1661+
reduce(max,partials,init=typemin(eltype(x)))
1662+
end
1663+
1664+
function Base.maximum(f::Function,x::PVector)
1665+
partials = map_parts(x.owned_values) do o
1666+
maximum(f,o)
1667+
end
1668+
reduce(max,partials,init=typemin(eltype(x)))
1669+
end
1670+
1671+
function Base.minimum(x::PVector)
1672+
partials = map_parts(minimum,x.owned_values)
1673+
reduce(min,partials,init=typemax(eltype(x)))
1674+
end
1675+
1676+
function Base.minimum(f::Function,x::PVector)
1677+
partials = map_parts(x.owned_values) do o
1678+
minimum(f,o)
1679+
end
1680+
reduce(min,partials,init=typemax(eltype(x)))
1681+
end
1682+
1683+
function Base.findall(f::Function,x::PVector)
1684+
@notimplemented
1685+
end
1686+
15591687
function PVector{T}(
15601688
::UndefInitializer,
15611689
rows::PRange) where T
@@ -1678,8 +1806,46 @@ function LinearAlgebra.dot(a::PVector,b::PVector)
16781806
end
16791807

16801808
function local_view(a::PVector,rows::PRange)
1681-
@notimplementedif a.rows !== rows
1682-
a.values
1809+
if a.rows === rows
1810+
a.values
1811+
else
1812+
map_parts(a.values,rows.partition,a.rows.partition) do values,rows,arows
1813+
LocalView(values,(find_lid_map(rows,arows),))
1814+
end
1815+
end
1816+
end
1817+
1818+
struct LocalView{T,N,A,B} <:AbstractArray{T,N}
1819+
plids_to_value::A
1820+
d_to_lid_to_plid::B
1821+
local_size::NTuple{N,Int}
1822+
function LocalView(
1823+
plids_to_value::AbstractArray{T,N},d_to_lid_to_plid::NTuple{N}) where {T,N}
1824+
A = typeof(plids_to_value)
1825+
B = typeof(d_to_lid_to_plid)
1826+
local_size = map(length,d_to_lid_to_plid)
1827+
new{T,N,A,B}(plids_to_value,d_to_lid_to_plid,local_size)
1828+
end
1829+
end
1830+
1831+
Base.size(a::LocalView) = a.local_size
1832+
Base.IndexStyle(::Type{<:LocalView}) = IndexCartesian()
1833+
function Base.getindex(a::LocalView{T,N},lids::Vararg{Integer,N}) where {T,N}
1834+
plids = map(_lid_to_plid,lids,a.d_to_lid_to_plid)
1835+
if all(i->i>0,plids)
1836+
a.plids_to_value[plids...]
1837+
else
1838+
zero(T)
1839+
end
1840+
end
1841+
function Base.setindex!(a::LocalView{T,N},v,lids::Vararg{Integer,N}) where {T,N}
1842+
plids = map(_lid_to_plid,lids,a.d_to_lid_to_plid)
1843+
@check all(i->i>0,plids) "You are trying to set a value that is not stored in the local portion"
1844+
a.plids_to_value[plids...] = v
1845+
end
1846+
function _lid_to_plid(lid,lid_to_plid)
1847+
plid = lid_to_plid[lid]
1848+
plid
16831849
end
16841850

16851851
function global_view(a::PVector,rows::PRange)
@@ -1772,6 +1938,13 @@ struct PSparseMatrix{T,A,B,C,D} <: AbstractMatrix{T}
17721938
end
17731939
end
17741940

1941+
function LinearAlgebra.fillstored!(a::PSparseMatrix,v)
1942+
map_parts(a.values) do values
1943+
LinearAlgebra.fillstored!(values,v)
1944+
end
1945+
a
1946+
end
1947+
17751948
function Base.copy(a::PSparseMatrix)
17761949
PSparseMatrix(
17771950
copy(a.values),
@@ -1916,9 +2089,16 @@ function LinearAlgebra.mul!(
19162089
end
19172090

19182091
function local_view(a::PSparseMatrix,rows::PRange,cols::PRange)
1919-
@notimplementedif a.rows !== rows
1920-
@notimplementedif a.cols !== cols
1921-
a.values
2092+
if a.rows === rows && a.cols === cols
2093+
a.values
2094+
else
2095+
map_parts(
2096+
a.values,rows.partition,cols.partition,a.rows.partition,a.cols.partition) do values,rows,cols,arows,acols
2097+
rmap = find_lid_map(rows,arows)
2098+
cmap = (cols === rows && acols === arows) ? rmap : find_lid_map(cols,acols)
2099+
LocalView(values,(rmap,cmap))
2100+
end
2101+
end
19222102
end
19232103

19242104
function global_view(a::PSparseMatrix,rows::PRange,cols::PRange)
@@ -1978,7 +2158,7 @@ function matrix_exchanger(values,row_exchanger,row_lids,col_lids)
19782158
gj_rcv = Table(gj_rcv_data,ptrs)
19792159
k_rcv, gi_rcv, gj_rcv
19802160
end
1981-
2161+
19822162
k_rcv, gi_rcv, gj_rcv = map_parts(setup_rcv,part,parts_rcv,row_lids,col_lids,values)
19832163

19842164
gi_snd = exchange(gi_rcv,parts_snd,parts_rcv)
@@ -2087,9 +2267,9 @@ function async_assemble!(
20872267
gi_rcv = Table(gi_rcv_data,ptrs)
20882268
gj_rcv = Table(gj_rcv_data,ptrs)
20892269
v_rcv = Table(v_rcv_data,ptrs)
2090-
gi_rcv, gj_rcv, v_rcv
2270+
gi_rcv, gj_rcv, v_rcv
20912271
end
2092-
2272+
20932273
gi_rcv, gj_rcv, v_rcv = map_parts(setup_rcv,part,parts_rcv,rows.partition,coo_values)
20942274

20952275
gi_snd, t1 = async_exchange(gi_rcv,parts_snd,parts_rcv)
@@ -2186,9 +2366,9 @@ function async_exchange!(
21862366
gi_snd = Table(gi_snd_data,ptrs)
21872367
gj_snd = Table(gj_snd_data,ptrs)
21882368
v_snd = Table(v_snd_data,ptrs)
2189-
gi_snd, gj_snd, v_snd
2369+
gi_snd, gj_snd, v_snd
21902370
end
2191-
2371+
21922372
gi_snd, gj_snd, v_snd = map_parts(
21932373
setup_snd,part,parts_snd,lids_snd,rows.partition,coo_values)
21942374

@@ -2365,4 +2545,3 @@ function IterativeSolvers.zerox(A::PSparseMatrix,b::PVector)
23652545
fill!(x, zero(T))
23662546
return x
23672547
end
2368-

src/MPIBackend.jl

+6-2
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,13 @@ struct MPIBackend <: AbstractBackend end
77
const mpi = MPIBackend()
88

99
function get_part_ids(b::MPIBackend,nparts::Integer)
10-
comm = MPI.COMM_WORLD
10+
comm = MPI.Comm_dup(MPI.COMM_WORLD)
1111
@notimplementedif num_parts(comm) != nparts
1212
MPIData(get_part_id(comm),comm,(nparts,))
1313
end
1414

1515
function get_part_ids(b::MPIBackend,nparts::Tuple)
16-
comm = MPI.COMM_WORLD
16+
comm = MPI.Comm_dup(MPI.COMM_WORLD)
1717
@notimplementedif num_parts(comm) != prod(nparts)
1818
MPIData(get_part_id(comm),comm,nparts)
1919
end
@@ -48,6 +48,10 @@ get_part_id(a::MPIData) = get_part_id(a.comm)
4848
get_backend(a::MPIData) = mpi
4949
i_am_main(a::MPIData) = get_part_id(a.comm) == MAIN
5050

51+
function get_part_ids(a::MPIData)
52+
MPIData(get_part_id(a.comm),a.comm,a.size)
53+
end
54+
5155
function Base.iterate(a::MPIData)
5256
next = iterate(a.part)
5357
if next == nothing

src/PartitionedArrays.jl

+1
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ using LinearAlgebra
66
using Printf
77
import MPI
88
import IterativeSolvers
9+
import Distances
910

1011
export AbstractBackend
1112
export prun

test/mpi/driver_mpi_backend.jl

+5
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
using PartitionedArrays
22
using Test
3+
using MPI
34

45
function main(parts)
56

@@ -8,6 +9,10 @@ function main(parts)
89
nparts = num_parts(parts)
910
@assert nparts == 4
1011

12+
@test MPI.COMM_WORLD !== parts.comm
13+
_parts = get_part_ids(parts)
14+
@test _parts.comm === parts.comm
15+
1116
#s = size(parts)
1217
#display(map_parts(part->s,parts))
1318

0 commit comments

Comments
 (0)