Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
DynamicPPL = "366bfd00-2699-11ea-058f-f148b4cae6d8"
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"

Expand All @@ -35,6 +36,7 @@ EnzymeChainRulesCoreExt = "ChainRulesCore"
EnzymeDynamicPPLExt = ["ADTypes", "DynamicPPL"]
EnzymeGPUArraysCoreExt = "GPUArraysCore"
EnzymeLogExpFunctionsExt = "LogExpFunctions"
EnzymeMPIExt = "MPI"
EnzymeSpecialFunctionsExt = "SpecialFunctions"
EnzymeStaticArraysExt = "StaticArrays"

Expand All @@ -50,6 +52,7 @@ GPUArraysCore = "0.1.6, 0.2"
GPUCompiler = "1.6.2"
LLVM = "9.1"
LogExpFunctions = "0.3"
MPI = "0.20"
ObjectFile = "0.4, 0.5"
PrecompileTools = "1"
Preferences = "1.4"
Expand All @@ -65,6 +68,7 @@ ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
DynamicPPL = "366bfd00-2699-11ea-058f-f148b4cae6d8"
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"

Expand Down
37 changes: 37 additions & 0 deletions ext/EnzymeMPIExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
module EnzymeMPIExt

using MPI
using Enzyme

import Enzyme.EnzymeCore: EnzymeRules

function EnzymeRules.forward(config, ::Const{typeof(MPI.Allreduce!)}, rt, v, op::Const, comm::Const)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For the sake of understanding, is there a reason why this wasn't already handled LLVM-side?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

On Julia 1.10

  %14 = call i32 @PMPI_Allreduce(i64 %6, i64 %bitcast_coercion, i32 noundef 1, i32 %9, i32 %11, i32 %13) #12 [ "jl_roots"({} addrspace(10)* %2, {} addrspace(10)* %1, {} addrspace(10)* addrspacecast ({}* inttoptr (i64 140676541099488 to {}*) to {} addrspace(10)*), {} addrspace(10)* %0) ], !dbg !25
0x3bb3f008
Unhandled MPI FUNCTION
UNREACHABLE executed at /workspace/srcdir/Enzyme/enzyme/Enzyme/CallDerivatives.cpp:2211!

[63249] signal (6.-6): Aborted
in expression starting at REPL[15]:1
unknown function (ip: 0x7ff25629894c)
gsignal at /usr/lib/libc.so.6 (unknown line)
abort at /usr/lib/libc.so.6 (unknown line)
_ZN4llvm25llvm_unreachable_internalEPKcS1_j at /home/vchuravy/.julia/juliaup/julia-1.10.10+0.x64.linux.gnu/bin/../lib/julia/libLLVM-15jl.so (unknown line)
handleMPI at /workspace/srcdir/Enzyme/enzyme/Enzyme/CallDerivatives.cpp:2211
handleKnownCallDerivatives at /workspace/srcdir/Enzyme/enzyme/Enzyme/CallDerivatives.cpp:2249
visitCallInst at /workspace/srcdir/Enzyme/enzyme/Enzyme/AdjointGenerator.h:6402
visit at /opt/x86_64-linux-gnu/x86_64-linux-gnu/sys-root/usr/local/include/llvm/IR/InstVisitor.h:111 [inlined]
CreateForwardDiff at /workspace/srcdir/Enzyme/enzyme/Enzyme/EnzymeLogic.cpp:5062
EnzymeCreateForwardDiff at /workspace/srcdir/Enzyme/enzyme/Enzyme/CApi.cpp:661
EnzymeCreateForwardDiff at /home/vchuravy/.julia/packages/Enzyme/rsnI8/src/api.jl:342

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

op = op.val
comm = comm.val

if !(op == MPI.SUM || op == +)
error("Forward mode MPI.Allreduce! is only implemented for MPI.SUM.")
end

if EnzymeRules.needs_primal(config)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

`needs_primal` is only relevant to returning the result; if something is mutated, that mutation still needs to happen unless the argument is marked noneed.

MPI.Allreduce!(v.val, op, comm)
end

if EnzymeRules.width(config) == 1
MPI.Allreduce!(v.dval, op, comm)
else
# would be nice to use MPI non-blocking collectives
foreach(v.dval) do dval
MPI.Allreduce!(dval, op, comm)
Copy link
Collaborator

@michel2323 michel2323 Nov 6, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are the dvals really non-contiguous? If so, I'd think about a derived datatype for v.dval to get this into one Allreduce and get the benefits of vector-mode.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes the dvals are all independent memory allocations.

The structure here is an NTuple{N, Vector{Float64}}. I have been wanting a variant where we use contiguous memory, but we can only do that from 1.11 and we likely can't guarantee it.

I was thinking about derived types as well, but I have not yet managed to convince MPI to understand a Vector{Float64}

end
end

if EnzymeRules.needs_primal(config)
return v
else
return v.dval
end
end


end
26 changes: 26 additions & 0 deletions test/integration/MPI/collectives.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
using MPI
using Enzyme
using Test

# Integration check for `MPI.Allreduce!`, first called directly and then
# differentiated with Enzyme in forward mode.
MPI.Init()

@show Base.get_extension(Enzyme, :EnzymeMPIExt)

comm = MPI.COMM_WORLD
buff = Ref(3.0)

# Plain sum-reduction: every rank contributes 3.0.
MPI.Allreduce!(buff, MPI.SUM, comm)
@test buff[] == MPI.Comm_size(comm) * 3.0

# Restore the primal input and seed the tangent with 1.0 on rank 0 only
# (0.0 on every other rank).
buff[] = 3.0
dbuff = Ref(MPI.Comm_rank(comm) == 0 ? 1.0 : 0.0)

autodiff(
    ForwardWithPrimal,
    MPI.Allreduce!,
    Duplicated(buff, dbuff),
    Const(MPI.SUM),
    Const(comm),
)

# The primal is reduced as before; the tangent is expected to equal the
# single seed placed on rank 0.
@test buff[] == MPI.Comm_size(comm) * 3.0
@test dbuff[] == 1.0
4 changes: 4 additions & 0 deletions test/integration/MPI/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
using MPI
using Enzyme
using Test

# Launch the collectives script under `mpiexec` once per process count.
# `run` throws when the spawned command exits unsuccessfully.
@testset "collectives" for np in (1, 2, 4)
    script = joinpath(@__DIR__, "collectives.jl")
    launch = `$(mpiexec()) -n $np $(Base.julia_cmd()) --project=$(@__DIR__) $script`
    run(launch)
end
Loading