Support MultiRequest #2749

@vchuravy

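MPI.jl has a useful MultiRequest API that stores the raw MPI.API.MPI_Request handles in a plain array under the hood. Differentiating code that uses it with Enzyme currently fails. Reproducer: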

using MPI
using Enzyme
using Test  # provides @test, used below


MPI.Init()
comm = MPI.COMM_WORLD


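function ring(token, comm)
    rank = MPI.Comm_rank(comm)
    N = MPI.Comm_size(comm)
    reqs = MPI.MultiRequest(2)  # reqs[1]: send, reqs[2]: receive
    # reqs = MPI.UnsafeMultiRequest(2)

    buf = Ref(token)

    # Every rank except 0 first waits for the token from its left neighbour.
    if rank != 0
        MPI.Irecv!(buf, comm, reqs[2]; source = rank - 1)
        MPI.Wait(reqs[2])
    end

    # Pass the token on to the right neighbour (wrapping around to rank 0).
    MPI.Isend(buf, comm, reqs[1]; dest = mod(rank + 1, N))

    # Rank 0 closes the ring by receiving from the last rank.
    if rank == 0
        MPI.Irecv!(buf, comm, reqs[2]; source = N - 1)
        MPI.Wait(reqs[2])
    end
    MPI.Wait(reqs[1])

    return buf[]
end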

token = MPI.Comm_rank(comm) == 0 ? 1.0 : NaN
@test ring(token, comm) == 1.0

autodiff(Forward, ring, Duplicated(1.0, 1.0), Const(comm))

The forward-mode autodiff call currently aborts with:

julia: /workspace/srcdir/Enzyme/enzyme/Enzyme/CallDerivatives.cpp:307: void AdjointGenerator::handleMPI(llvm::CallInst&, llvm::Function*, llvm::StringRef): Assertion `!gutils->isConstantValue(call.getOperand(6))' failed.

[14280] signal (6.-6): Aborted
in expression starting at /home/vchuravy/src/Enzyme/test/integration/MPI/multi_request.jl:98
unknown function (ip: 0x7f16e8a9894c)
gsignal at /usr/lib/libc.so.6 (unknown line)
abort at /usr/lib/libc.so.6 (unknown line)
unknown function (ip: 0x7f16e8a254e2)
handleMPI at /workspace/srcdir/Enzyme/enzyme/Enzyme/CallDerivatives.cpp:307
handleKnownCallDerivatives at /workspace/srcdir/Enzyme/enzyme/Enzyme/CallDerivatives.cpp:2254
visitCallInst at /workspace/srcdir/Enzyme/enzyme/Enzyme/AdjointGenerator.h:6405
visit at /opt/x86_64-linux-gnu/x86_64-linux-gnu/sys-root/usr/local/include/llvm/IR/InstVisitor.h:111 [inlined]
CreateForwardDiff at /workspace/srcdir/Enzyme/enzyme/Enzyme/EnzymeLogic.cpp:5062
recursivelyHandleSubfunction at /workspace/srcdir/Enzyme/enzyme/Enzyme/AdjointGenerator.h:4984
visitCallInst at /workspace/srcdir/Enzyme/enzyme/Enzyme/AdjointGenerator.h:6608
visit at /opt/x86_64-linux-gnu/x86_64-linux-gnu/sys-root/usr/local/include/llvm/IR/InstVisitor.h:111 [inlined]
CreateForwardDiff at /workspace/srcdir/Enzyme/enzyme/Enzyme/EnzymeLogic.cpp:5062
EnzymeCreateForwardDiff at /workspace/srcdir/Enzyme/enzyme/Enzyme/CApi.cpp:661
EnzymeCreateForwardDiff at /home/vchuravy/src/Enzyme/src/api.jl:342
unknown function (ip: 0x7f16e1732298)
_jl_invoke at /cache/build/builder-amdci5-7/julialang/julia-release-1-dot-10/src/gf.c:2895 [inlined]
ijl_apply_generic at /cache/build/builder-amdci5-7/julialang/julia-release-1-dot-10/src/gf.c:3077
macro expansion at /home/vchuravy/src/Enzyme/src/compiler.jl:2687 [inlined]
macro expansion at /home/vchuravy/.julia/packages/LLVM/iza6e/src/base.jl:97 [inlined]
enzyme! at /home/vchuravy/src/Enzyme/src/compiler.jl:2512

A MultiRequest looks like this:

julia> reqs = MPI.MultiRequest(4)
4-element MPI.MultiRequest:
 null request
 null request
 null request
 null request

julia> dump(reqs)
MPI.MultiRequest
  vals: Array{Int32}((4,)) Int32[738197504, 738197504, 738197504, 738197504]
  buffers: Array{Any}((4,))
    1: Nothing nothing
    2: Nothing nothing
    3: Nothing nothing
    4: Nothing nothing

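The crux is that the raw MPI.API.MPI_Request is just an Int32, so the handles in vals are plain integers. The failed assertion is on operand 6 of the MPI call, which is the request argument of MPI_Isend and MPI_Irecv, so Enzyme presumably treats this integer storage as constant.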

julia> MPI.API.MPI_Request
Int32
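
To make the layout concrete without needing an MPI installation, here is a minimal sketch that mirrors the two fields printed by dump(reqs). The names SketchMultiRequest and NULL_HANDLE are hypothetical and exist only for illustration; this is not MPI.jl's actual definition, and the null-handle value is simply copied from the dump output above.

```julia
# Hypothetical mirror of the field layout printed by dump(reqs);
# this is NOT MPI.jl's actual definition.
struct SketchMultiRequest
    vals::Vector{Int32}    # raw request handles, one integer per request
    buffers::Vector{Any}   # per-request buffer slots (nothing for null requests)
end

# Null-handle value as it appears in the dump output above (assumed, not queried from MPI).
const NULL_HANDLE = Int32(738197504)

# Convenience constructor: n slots, all holding the null handle.
SketchMultiRequest(n::Integer) =
    SketchMultiRequest(fill(NULL_HANDLE, n), Any[nothing for _ in 1:n])

sketch = SketchMultiRequest(4)

# Each handle is plain integer data: an isbits Integer with no float payload.
handles_are_plain_ints =
    isbitstype(eltype(sketch.vals)) && eltype(sketch.vals) <: Integer
```

Seen this way, the request that MPI.Isend and MPI.Irecv! write to is a slot in an integer array, which may be why it ends up looking constant to the differentiation pass.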

#2747 adds a realistic test.
