Skip to content

Commit 575f622

Browse files
michel2323vchuravylcwgiordano
authored
Test for MPI.Irecv/MPI.Isend/MPI.Wait (#518)
Co-authored-by: Valentin Churavy <v.churavy@gmail.com> Co-authored-by: Lucas Wilcox <lucas@swirlee.com> Co-authored-by: Mosè Giordano <765740+giordano@users.noreply.github.com>
1 parent e099b56 commit 575f622

File tree

2 files changed

+90
-0
lines changed

2 files changed

+90
-0
lines changed
Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
using MPI
2+
using Enzyme
3+
using Test
4+
5+
6+
function halo(x)
7+
np = MPI.Comm_size(MPI.COMM_WORLD)
8+
rank = MPI.Comm_rank(MPI.COMM_WORLD)
9+
requests = Vector{MPI.Request}()
10+
if rank != 0
11+
buf = @view x[1:1]
12+
push!(requests, MPI.Isend(x[2:2], MPI.COMM_WORLD; dest = rank - 1, tag = 0))
13+
push!(requests, MPI.Irecv!(buf, MPI.COMM_WORLD; source = rank - 1, tag = 0))
14+
end
15+
if rank != np - 1
16+
buf = @view x[end:end]
17+
push!(requests, MPI.Isend(x[(end - 1):(end - 1)], MPI.COMM_WORLD; dest = rank + 1, tag = 0))
18+
push!(requests, MPI.Irecv!(buf, MPI.COMM_WORLD; source = rank + 1, tag = 0))
19+
end
20+
for request in requests
21+
MPI.Wait(request)
22+
end
23+
return nothing
24+
end
25+
26+
MPI.Init()
27+
np = MPI.Comm_size(MPI.COMM_WORLD)
28+
rank = MPI.Comm_rank(MPI.COMM_WORLD)
29+
nl = rank == 0 ? 0 : 2
30+
nr = rank == np - 1 ? 0 : 2
31+
nlocal = nr + nl + 1
32+
33+
x = zeros(nlocal)
34+
fill!(x, Float64(rank))
35+
halo(x)
36+
MPI.Barrier(MPI.COMM_WORLD)
37+
38+
@test x[nl + 1] == Float64(rank) # Local
39+
if rank != 0
40+
@test x[1] == Float64(rank - 1) # Recv
41+
@test x[2] == Float64(rank) # Send
42+
end
43+
if rank != np - 1
44+
@test x[end] == Float64(rank + 1) # Recv
45+
@test x[end - 1] == Float64(rank) # Send
46+
end
47+
48+
dx = zeros(nlocal)
49+
fill!(dx, Float64(rank))
50+
autodiff(Reverse, halo, Duplicated(x, dx))
51+
MPI.Barrier(MPI.COMM_WORLD)
52+
53+
@test dx[nl + 1] == Float64(rank) # Local -> no change
54+
if rank != 0
55+
@test dx[1] == 0.0 # Recv -> Send & zero'd
56+
@test dx[2] == Float64(rank + rank - 1) # Send -> += Recv
57+
end
58+
if rank != np - 1
59+
@test dx[end] == 0.0 # Recv -> Send & zero'd
60+
@test dx[end - 1] == Float64(rank + rank + 1) # Send -> += Recv
61+
end
62+
63+
fill!(dx, Float64(rank))
64+
autodiff(Forward, halo, Duplicated(x, dx))
65+
MPI.Barrier(MPI.COMM_WORLD)
66+
67+
@test dx[nl + 1] == Float64(rank)
68+
if rank != 0
69+
@test dx[1] == Float64(rank - 1)
70+
@test dx[2] == Float64(rank)
71+
end
72+
if rank != np - 1
73+
@test dx[end] == Float64(rank + 1)
74+
@test dx[end - 1] == Float64(rank)
75+
end

test/integration/MPI/runtests.jl

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,16 @@ using MPI
22
using Enzyme
33
using Test
44

5+
# Current MPI support (needs to be tested from Julia)
6+
# - MPI_Ssend
7+
# - MPI_Waitall
8+
# - MPI_Barrier/MPI_Probe
9+
# - MPI_Allreduce
10+
# - MPI_Bcast
11+
# - MPI_Reduce
12+
# - MPI_Gather/MPI_Scatter
13+
# - MPI_Allgather
14+
515
# Query functions MPI_Comm_size/MPI_Comm_rank
616
@testset "queries" for np in (1, 2, 4)
717
run(`$(mpiexec()) -n $np $(Base.julia_cmd()) --project=$(@__DIR__) $(joinpath(@__DIR__, "queries.jl"))`)
@@ -11,3 +21,8 @@ end
1121
@testset "blocking_ring" for np in (1, 2, 4)
1222
run(`$(mpiexec()) -n $np $(Base.julia_cmd()) --project=$(@__DIR__) $(joinpath(@__DIR__, "blocking_ring.jl"))`)
1323
end
24+
25+
# Test MPI_Irecv/MPI_Isend/MPI_Wait with a non-blocking halo exchange pattern
26+
VERSION >= v"1.11.0" && @testset "nonblocking_halo" for np in (1, 2, 4)
27+
run(`$(mpiexec()) -n $np $(Base.julia_cmd()) --project=$(@__DIR__) $(joinpath(@__DIR__, "nonblocking_halo.jl"))`)
28+
end

0 commit comments

Comments
 (0)