Add MPI extension with Allreduce! forward rule #2745
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,37 @@ | ||
| module EnzymeMPIExt | ||
|
|
||
| using MPI | ||
| using Enzyme | ||
|
|
||
| import Enzyme.EnzymeCore: EnzymeRules | ||
|
|
||
| function EnzymeRules.forward(config, ::Const{typeof(MPI.Allreduce!)}, rt, v, op::Const, comm::Const) | ||
| op = op.val | ||
| comm = comm.val | ||
|
|
||
| if !(op == MPI.SUM || op == +) | ||
| error("Forward mode MPI.Allreduce! is only implemented for MPI.SUM.") | ||
| end | ||
|
|
||
| if EnzymeRules.needs_primal(config) | ||
|
Member
`needs_primal` is only relevant to returning the result. If something is mutated, the mutation still needs to happen unless the argument is marked noneed. |
||
| MPI.Allreduce!(v.val, op, comm) | ||
| end | ||
|
|
||
| if EnzymeRules.width(config) == 1 | ||
| MPI.Allreduce!(v.dval, op, comm) | ||
| else | ||
| # would be nice to use MPI non-blocking collectives | ||
| foreach(v.dval) do dval | ||
| MPI.Allreduce!(dval, op, comm) | ||
|
Collaborator
Are the
Member
Yes the dvals are all independent memory allocations. The structure here is a I was thinking about derived types as well, but I have not yet managed to convince MPI to understand a |
||
| end | ||
| end | ||
|
|
||
| if EnzymeRules.needs_primal(config) | ||
| return v | ||
| else | ||
| return v.dval | ||
| end | ||
| end | ||
|
|
||
|
|
||
| end | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,26 @@ | ||
| using MPI | ||
| using Enzyme | ||
| using Test | ||
|
|
||
| MPI.Init() | ||
|
|
||
| @show Base.get_extension(Enzyme, :EnzymeMPIExt) | ||
|
|
||
| buff = Ref(3.0) | ||
| comm = MPI.COMM_WORLD | ||
|
|
||
| MPI.Allreduce!(buff, MPI.SUM, comm) | ||
|
|
||
| @test buff[] == MPI.Comm_size(comm) * 3.0 | ||
|
|
||
| buff[] = 3.0 | ||
| dbuff = Ref(0.0) | ||
|
|
||
| if MPI.Comm_rank(comm) == 0 | ||
| dbuff[] = 1.0 | ||
| end | ||
|
|
||
| autodiff(ForwardWithPrimal, MPI.Allreduce!, Duplicated(buff, dbuff), Const(MPI.SUM), Const(comm)) | ||
|
|
||
| @test buff[] == MPI.Comm_size(comm) * 3.0 | ||
| @test dbuff[] == 1.0 |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,3 +1,7 @@ | ||
| using MPI | ||
| using Enzyme | ||
| using Test | ||
|
|
||
| @testset "collectives" for np in (1, 2, 4) | ||
| run(`$(mpiexec()) -n $np $(Base.julia_cmd()) --project=$(@__DIR__) $(joinpath(@__DIR__, "collectives.jl"))`) | ||
| end |
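The derivative check in the test above relies on the sum reduction being linear: the tangent of a sum-Allreduce is the sum-Allreduce of the tangents, which is presumably why the rule accepts only `MPI.SUM`. The following is a minimal, dependency-free sketch of that reasoning. It does not call MPI or Enzyme; `simulated_allreduce_sum!` and `simulated_forward!` are hypothetical helpers that model each rank's buffer as one entry of a vector.

```julia
# Hypothetical stand-in for a sum-Allreduce: vals[r] models the buffer
# held by rank r-1. Not part of MPI.jl.
function simulated_allreduce_sum!(vals::Vector{Float64})
    total = sum(vals)
    fill!(vals, total)  # every rank ends up with the global sum
    return vals
end

# Forward-mode sketch: since the sum is linear, the tangent buffers are
# reduced with the same operation as the primal buffers.
function simulated_forward!(vals::Vector{Float64}, dvals::Vector{Float64})
    simulated_allreduce_sum!(vals)   # primal mutation
    simulated_allreduce_sum!(dvals)  # tangent mutation
    return vals, dvals
end

# Mirror the test: every rank holds 3.0, only rank 0 seeds a tangent of 1.0.
for np in (1, 2, 4)
    vals = fill(3.0, np)
    dvals = zeros(np)
    dvals[1] = 1.0
    simulated_forward!(vals, dvals)
    @assert all(==(np * 3.0), vals)
    @assert all(==(1.0), dvals)
end
```

Under these assumptions, every simulated rank ends with a primal of `np * 3.0` and a tangent of `1.0`, matching the two `@test` lines in the test file.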
For the sake of understanding, is there a reason why this wasn't already handled LLVM-side?
On Julia 1.10
Okay, we should fix that; it needs EnzymeAD/Enzyme#2530.
We may also need callingconv handling like
https://github.com/EnzymeAD/Enzyme/blob/2e6f771cf4570d2800a67e9da9298c0d612d5020/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp#L5076, and perhaps elsewhere.