Skip to content

Commit e82bc60

Browse files
authored
Merge pull request #21 from fverdugo/misc_improvements
Misc improvements
2 parents d0eb100 + e2d64fb commit e82bc60

File tree

4 files changed

+172
-54
lines changed

4 files changed

+172
-54
lines changed

src/Helpers.jl

+1
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,7 @@ function Base.getindex(a::Table{T},i::Integer) where T
8080
end
8181
v
8282
end
83+
Base.copy(a::Table) = Table(copy(a.data),copy(a.ptrs))
8384

8485
function Table(a::AbstractArray{<:AbstractArray})
8586
data, ptrs = generate_data_and_ptrs(a)

src/Interfaces.jl

+156-51
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,8 @@ Base.eltype(::Type{<:PData{T}}) where T = T
4343
Base.ndims(a::PData{T,N}) where {T,N} = N
4444
Base.ndims(::Type{<:PData{T,N}}) where {T,N} = N
4545

46+
Base.copy(a::PData) = map_parts(copy,a)
47+
4648
#function map_parts(task,a...)
4749
# map_parts(task,map(PData,a)...)
4850
#end
@@ -227,6 +229,7 @@ function xscan_main(op,a::PData;init)
227229
b
228230
end
229231

232+
# TODO improve the mechanism for waiting
230233
# Non-blocking in-place exchange
231234
# In this version, sending a number per part is enough
232235
# We have another version below to send a vector of numbers per part (compressed in a Table)
@@ -421,6 +424,7 @@ function discover_parts_snd(parts_rcv::PData,::Nothing)
421424
discover_parts_snd(parts_rcv)
422425
end
423426

427+
# TODO simplify type signature
424428
# Arbitrary set of global indices stored in a part
425429
# gid_to_part can be omitted with nothing since only for some particular parallel
426430
# data layouts (e.g. uniform partitions) it is efficient to recover this information.
@@ -468,6 +472,19 @@ struct IndexSet{A,B,C,D,E,F,G}
468472
end
469473
end
470474

475+
function Base.copy(a::IndexSet)
476+
IndexSet(
477+
copy(a.part),
478+
copy(a.ngids),
479+
copy(a.lid_to_gid),
480+
copy(a.lid_to_part),
481+
a.gid_to_part === nothing ? nothing : copy(a.gid_to_part),
482+
copy(a.oid_to_lid),
483+
copy(a.hid_to_lid),
484+
copy(a.lid_to_ohid),
485+
copy(a.gid_to_lid))
486+
end
487+
471488
num_gids(a::IndexSet) = a.ngids
472489
num_lids(a::IndexSet) = length(a.lid_to_part)
473490
num_oids(a::IndexSet) = length(a.oid_to_lid)
@@ -617,6 +634,14 @@ struct Exchanger{B,C}
617634
end
618635
end
619636

637+
function Base.copy(a::Exchanger)
638+
Exchanger(
639+
copy(a.parts_rcv),
640+
copy(a.parts_snd),
641+
copy(a.lids_rcv),
642+
copy(a.lids_snd))
643+
end
644+
620645
function Exchanger(ids::PData{<:IndexSet},neighbors=nothing)
621646

622647
parts = get_part_ids(ids)
@@ -707,8 +732,18 @@ end
707732
function async_exchange!(
708733
values::PData{<:AbstractVector{T}},
709734
exchanger::Exchanger,
710-
t0::PData=_empty_tasks(exchanger.parts_rcv);
711-
reduce_op=_replace) where T
735+
t0::PData=_empty_tasks(exchanger.parts_rcv)) where T
736+
737+
async_exchange!(_replace,values,exchanger,t0)
738+
end
739+
740+
_replace(x,y) = y
741+
742+
function async_exchange!(
743+
combine_op,
744+
values::PData{<:AbstractVector{T}},
745+
exchanger::Exchanger,
746+
t0::PData=_empty_tasks(exchanger.parts_rcv)) where T
712747

713748
# Allocate buffers
714749
data_rcv = allocate_rcv_buffer(T,exchanger)
@@ -740,25 +775,31 @@ function async_exchange!(
740775
wait(schedule(t2))
741776
for p in 1:length(lids_rcv.data)
742777
lid = lids_rcv.data[p]
743-
values[lid] = reduce_op(values[lid],data_rcv.data[p])
778+
values[lid] = combine_op(values[lid],data_rcv.data[p])
744779
end
745780
end
746781
end
747782

748783
t3
749784
end
750785

751-
_replace(x,y) = y
786+
function async_exchange!(
787+
values::PData{<:Table},
788+
exchanger::Exchanger,
789+
t0::PData=_empty_tasks(exchanger.parts_rcv))
790+
791+
async_exchange!(_replace,values,exchanger,t0)
792+
end
752793

753794
function async_exchange!(
795+
combine_op,
754796
values::PData{<:Table},
755797
exchanger::Exchanger,
756-
t0::PData=_empty_tasks(exchanger.parts_rcv);
757-
reduce_op=_replace)
798+
t0::PData=_empty_tasks(exchanger.parts_rcv))
758799

759800
data, ptrs = map_parts(t->(t.data,t.ptrs),values)
760801
t_exchanger = _table_exchanger(exchanger,ptrs)
761-
async_exchange!(data,t_exchanger,t0;reduce_op=reduce_op)
802+
async_exchange!(combine_op,data,t_exchanger,t0)
762803
end
763804

764805
function _table_exchanger(exchanger,values)
@@ -804,7 +845,7 @@ function _table_lids_snd(lids_snd,tptrs)
804845
k_snd
805846
end
806847

807-
# TODO mutable is needed to correctly implement add_gid!
848+
# mutable is needed to correctly implement add_gid!
808849
mutable struct PRange{A,B} <: AbstractUnitRange{Int}
809850
ngids::Int
810851
lids::A
@@ -826,14 +867,12 @@ mutable struct PRange{A,B} <: AbstractUnitRange{Int}
826867
end
827868
end
828869

829-
# TODO in MPI this causes to copy the world comm
830-
# and makes some assertions to fail.
831870
function Base.copy(a::PRange)
832-
ngids = copy(a.ngids)
833-
lids = deepcopy(a.lids)
834-
ghost = copy(a.ghost)
835-
exchanger = deepcopy(a.exchanger)
836-
PRange(ngids,lids,ghost,exchanger)
871+
PRange(
872+
copy(a.ngids),
873+
copy(a.lids),
874+
copy(a.ghost),
875+
copy(a.exchanger))
837876
end
838877

839878
function PRange(
@@ -902,30 +941,66 @@ function PRange(parts::PData{<:Integer},noids::PData{<:Integer})
902941
PRange(ngids,lids,ghost)
903942
end
904943

944+
function PRange(
945+
parts::PData{<:Integer,N},
946+
ngids::NTuple{N,<:Integer}) where N
947+
948+
np = size(parts)
949+
lids = map_parts(parts) do part
950+
gids = _oid_to_gid(ngids,np,part)
951+
lid_to_gid = gids
952+
lid_to_part = fill(part,length(gids))
953+
oid_to_lid = Int32(1):Int32(length(gids))
954+
hid_to_lid = collect(Int32(1):Int32(0))
955+
part_to_gid = _part_to_gid(ngids,np)
956+
gid_to_part = GidToPart(ngids,part_to_gid)
957+
IndexSet(
958+
part,
959+
prod(ngids),
960+
lid_to_gid,
961+
lid_to_part,
962+
gid_to_part,
963+
oid_to_lid,
964+
hid_to_lid)
965+
end
966+
ghost = false
967+
PRange(prod(ngids),lids,ghost)
968+
end
969+
970+
function PCartesianIndices(
971+
parts::PData{<:Integer,N},
972+
ngids::NTuple{N,<:Integer}) where N
973+
974+
np = size(parts)
975+
lids = map_parts(parts) do part
976+
cis_parts = CartesianIndices(np)
977+
p = Tuple(cis_parts[part])
978+
d_to_odid_to_gdid = map(_oid_to_gid,ngids,np,p)
979+
CartesianIndices(d_to_odid_to_gdid)
980+
end
981+
lids
982+
end
983+
984+
struct WithGhost end
985+
with_ghost = WithGhost()
986+
987+
struct NoGhost end
988+
no_ghost = NoGhost()
905989

906-
# TODO this is type instable
907990
function PRange(
908991
parts::PData{<:Integer,N},
909-
ngids::NTuple{N,<:Integer};
910-
ghost::Bool=false) where N
992+
ngids::NTuple{N,<:Integer},
993+
::WithGhost) where N
911994

912995
np = size(parts)
913996
lids = map_parts(parts) do part
914-
if ghost
915-
cp = Tuple(CartesianIndices(np)[part])
916-
d_to_ldid_to_gdid = map(_lid_to_gid,ngids,np,cp)
917-
lid_to_gid = _id_tensor_product(Int,d_to_ldid_to_gdid,ngids)
918-
d_to_nldids = map(length,d_to_ldid_to_gdid)
919-
lid_to_part = _lid_to_part(d_to_nldids,np,cp)
920-
oid_to_lid = collect(Int32,findall(lid_to_part .== part))
921-
hid_to_lid = collect(Int32,findall(lid_to_part .!= part))
922-
else
923-
gids = _oid_to_gid(ngids,np,part)
924-
lid_to_gid = gids
925-
lid_to_part = fill(part,length(gids))
926-
oid_to_lid = Int32(1):Int32(length(gids))
927-
hid_to_lid = collect(Int32(1):Int32(0))
928-
end
997+
cp = Tuple(CartesianIndices(np)[part])
998+
d_to_ldid_to_gdid = map(_lid_to_gid,ngids,np,cp)
999+
lid_to_gid = _id_tensor_product(Int,d_to_ldid_to_gdid,ngids)
1000+
d_to_nldids = map(length,d_to_ldid_to_gdid)
1001+
lid_to_part = _lid_to_part(d_to_nldids,np,cp)
1002+
oid_to_lid = collect(Int32,findall(lid_to_part .== part))
1003+
hid_to_lid = collect(Int32,findall(lid_to_part .!= part))
9291004
part_to_gid = _part_to_gid(ngids,np)
9301005
gid_to_part = GidToPart(ngids,part_to_gid)
9311006
IndexSet(
@@ -937,29 +1012,41 @@ function PRange(
9371012
oid_to_lid,
9381013
hid_to_lid)
9391014
end
1015+
ghost = true
9401016
PRange(prod(ngids),lids,ghost)
9411017
end
9421018

943-
# TODO this is type instable
1019+
function PRange(
1020+
parts::PData{<:Integer,N},
1021+
ngids::NTuple{N,<:Integer},
1022+
::NoGhost) where N
1023+
1024+
PRange(parts,ngids)
1025+
end
1026+
9441027
function PCartesianIndices(
9451028
parts::PData{<:Integer,N},
946-
ngids::NTuple{N,<:Integer};
947-
ghost::Bool= false) where N
1029+
ngids::NTuple{N,<:Integer},
1030+
::WithGhost) where N
9481031

9491032
np = size(parts)
9501033
lids = map_parts(parts) do part
9511034
cis_parts = CartesianIndices(np)
9521035
p = Tuple(cis_parts[part])
953-
if ghost
954-
d_to_odid_to_gdid = map(_lid_to_gid,ngids,np,p)
955-
else
956-
d_to_odid_to_gdid = map(_oid_to_gid,ngids,np,p)
957-
end
1036+
d_to_odid_to_gdid = map(_lid_to_gid,ngids,np,p)
9581037
CartesianIndices(d_to_odid_to_gdid)
9591038
end
9601039
lids
9611040
end
9621041

1042+
function PCartesianIndices(
1043+
parts::PData{<:Integer,N},
1044+
ngids::NTuple{N,<:Integer},
1045+
::NoGhost) where N
1046+
1047+
PCartesianIndices(parts,ngids)
1048+
end
1049+
9631050
function _oid_to_gid(ngids::Integer,np::Integer,p::Integer)
9641051
_olength = ngids ÷ np
9651052
_offset = _olength * (p-1)
@@ -1132,7 +1219,7 @@ function add_gid!(a::PRange,gids::PData{<:AbstractArray{<:Integer}})
11321219
end
11331220

11341221
function add_gid(a::PRange,gids::PData{<:AbstractArray{<:Integer}})
1135-
lids = map_parts(deepcopy,a.lids)
1222+
lids = map_parts(copy,a.lids)
11361223
b = PRange(a.ngids,lids)
11371224
add_gid!(b,gids)
11381225
b
@@ -1512,15 +1599,20 @@ function async_exchange!(
15121599
end
15131600

15141601
# Non-blocking assembly
1515-
# TODO reduce op as first argument and init kwargument
15161602
function async_assemble!(
15171603
a::PVector,
1518-
t0::PData=_empty_tasks(a.rows.exchanger.parts_rcv);
1519-
reduce_op=+)
1604+
t0::PData=_empty_tasks(a.rows.exchanger.parts_rcv))
1605+
async_assemble!(+,a,t0)
1606+
end
1607+
1608+
function async_assemble!(
1609+
combine_op,
1610+
a::PVector,
1611+
t0::PData=_empty_tasks(a.rows.exchanger.parts_rcv))
15201612

15211613
exchanger_rcv = a.rows.exchanger # receives data at ghost ids from remote parts
15221614
exchanger_snd = reverse(exchanger_rcv) # sends data at ghost ids to remote parts
1523-
t1 = async_exchange!(a.values,exchanger_snd,t0;reduce_op=reduce_op)
1615+
t1 = async_exchange!(combine_op,a.values,exchanger_snd,t0)
15241616
map_parts(t1,a.values,a.rows.lids) do t1,values,lids
15251617
@task begin
15261618
wait(schedule(t1))
@@ -1556,6 +1648,14 @@ struct PSparseMatrix{T,A,B,C,D} <: AbstractMatrix{T}
15561648
end
15571649
end
15581650

1651+
function Base.copy(a::PSparseMatrix)
1652+
PSparseMatrix(
1653+
copy(a.values),
1654+
copy(a.rows),
1655+
copy(a.cols),
1656+
copy(a.exchanger))
1657+
end
1658+
15591659
function PSparseMatrix(
15601660
values::PData{<:AbstractSparseMatrix{T}},
15611661
rows::PRange,
@@ -1924,13 +2024,19 @@ end
19242024
# Non-blocking assembly
19252025
function async_assemble!(
19262026
a::PSparseMatrix,
1927-
t0::PData=_empty_tasks(a.exchanger.parts_rcv);
1928-
reduce_op=+)
2027+
t0::PData=_empty_tasks(a.exchanger.parts_rcv))
2028+
async_assemble!(+,a,t0)
2029+
end
2030+
2031+
function async_assemble!(
2032+
combine_op,
2033+
a::PSparseMatrix,
2034+
t0::PData=_empty_tasks(a.exchanger.parts_rcv))
19292035

19302036
exchanger_rcv = a.exchanger # receives data at ghost ids from remote parts
19312037
exchanger_snd = reverse(exchanger_rcv) # sends data at ghost ids to remote parts
19322038
nzval = map_parts(nonzeros,a.values)
1933-
t1 = async_exchange!(nzval,exchanger_snd,t0;reduce_op=reduce_op)
2039+
t1 = async_exchange!(combine_op,nzval,exchanger_snd,t0)
19342040
map_parts(t1,nzval,exchanger_snd.lids_snd) do t1,nzval,lids_snd
19352041
@task begin
19362042
wait(schedule(t1))
@@ -1944,8 +2050,7 @@ function async_assemble!(
19442050
J::PData{<:AbstractVector{<:Integer}},
19452051
V::PData{<:AbstractVector},
19462052
rows::PRange,
1947-
t0::PData=_empty_tasks(rows.exchanger.parts_rcv);
1948-
reduce_op=+)
2053+
t0::PData=_empty_tasks(rows.exchanger.parts_rcv))
19492054

19502055
map_parts(waitschedule,t0)
19512056

src/PartitionedArrays.jl

+4
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,10 @@ export Exchanger
5959
export allocate_rcv_buffer
6060
export allocate_snd_buffer
6161
export PRange
62+
export WithGhost
63+
export with_ghost
64+
export NoGhost
65+
export no_ghost
6266
export add_gid
6367
export add_gid!
6468
export to_lid!

0 commit comments

Comments
 (0)