From d1ff71455edcaf3a4b45749cef82333dcd85ce70 Mon Sep 17 00:00:00 2001 From: Agata Skorupka <45850123+askorupka@users.noreply.github.com> Date: Mon, 19 Aug 2024 13:13:37 +0200 Subject: [PATCH] feat: Distributed data parallel training support (#2464) * first experiment distributed * feat: add DistributedUtils (MPI&NCCL working) * feat: add DistributedUtils (MPI&NCCL working) * fix: no need for amdgpu now * chore: cleanup&propose how to use amdgpu * chore: add preferences for CUDA-awareness * feat: fix devices for CUDA-awareness * chore: add tests * chore: get rid of unnecessary deps * chore: update NEWS.md * chore: cleanup env * chore: update docs * chore: update docs & cleanup * chore: update docs & cleanup * Update docs/src/guide/gpu.md Co-authored-by: Carlo Lucibello * Update docs/src/guide/gpu.md Co-authored-by: Carlo Lucibello * Update docs/src/guide/gpu.md Co-authored-by: Carlo Lucibello * Update docs/src/guide/gpu.md Co-authored-by: Carlo Lucibello * Update docs/src/guide/gpu.md Co-authored-by: Carlo Lucibello * Update docs/src/guide/gpu.md Co-authored-by: Carlo Lucibello * Update docs/src/guide/gpu.md Co-authored-by: Carlo Lucibello * Update docs/src/guide/gpu.md Co-authored-by: Carlo Lucibello * Update docs/src/guide/gpu.md Co-authored-by: Carlo Lucibello * Update docs/src/guide/gpu.md * Update docs/src/guide/gpu.md * chore: add PR review suggestions * chore: fix docs * fix: add runtests.jl * chore: small docs update * chore: remove pkgs from deps --------- Co-authored-by: CarloLucibello Co-authored-by: Carlo Lucibello --- NEWS.md | 4 + Project.toml | 8 + docs/src/guide/gpu.md | 117 +++++++++++ ext/FluxMPIExt/FluxMPIExt.jl | 183 +++++++++++++++++ ext/FluxMPINCCLExt/FluxMPINCCLExt.jl | 109 ++++++++++ src/Flux.jl | 5 + src/distributed/backend.jl | 44 +++++ src/distributed/public_api.jl | 284 +++++++++++++++++++++++++++ test/ext_distributed/common.jl | 89 +++++++++ test/ext_distributed/data.jl | 26 +++ test/ext_distributed/optimizer.jl | 28 +++ test/ext_distributed/runtests.jl | 35 ++++ test/ext_distributed/synchronized.jl | 91 +++++++++ test/runtests.jl | 19 ++ 14 files changed, 1042 insertions(+) create mode 100644 ext/FluxMPIExt/FluxMPIExt.jl create mode 100644 ext/FluxMPINCCLExt/FluxMPINCCLExt.jl create mode 100644 src/distributed/backend.jl create mode 100644 src/distributed/public_api.jl create mode 100644 test/ext_distributed/common.jl create mode 100644 test/ext_distributed/data.jl create mode 100644 test/ext_distributed/optimizer.jl create mode 100644 test/ext_distributed/runtests.jl create mode 100644 test/ext_distributed/synchronized.jl diff --git a/NEWS.md b/NEWS.md index a4b0856327..0448b74d77 100644 --- a/NEWS.md +++ b/NEWS.md @@ -2,6 +2,10 @@ See also [github's page](https://github.com/FluxML/Flux.jl/releases) for a complete list of PRs merged before each release. +## v0.14.18 +* Add [support for distributed data parallel training](https://github.com/FluxML/Flux.jl/pull/2446). +* MPI and NCCL backend available with `FluxMPIExt` and `FluxMPINCCLExt` extensions respectively. + ## v0.14.17 * Add [support for Enzyme](https://github.com/FluxML/Flux.jl/pull/2446) with `Flux.train!`. 
diff --git a/Project.toml b/Project.toml index c5b76f1800..d5605b7beb 100644 --- a/Project.toml +++ b/Project.toml @@ -17,6 +17,7 @@ Preferences = "21216c6a-2e73-6563-6e65-726566657250" ProgressLogging = "33c8b6b6-d38a-422a-b730-caa89a2f386c" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" +Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" @@ -26,7 +27,9 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" +MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195" Metal = "dde4c033-4e86-420c-a63e-0dd931031962" +NCCL = "3fe64909-d7a1-4096-9b7d-7a0f12cf0f6b" cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd" [extensions] @@ -34,6 +37,8 @@ FluxAMDGPUExt = "AMDGPU" FluxCUDAExt = "CUDA" FluxCUDAcuDNNExt = ["CUDA", "cuDNN"] FluxEnzymeExt = "Enzyme" +FluxMPIExt = "MPI" +FluxMPINCCLExt = ["CUDA", "MPI", "NCCL"] FluxMetalExt = "Metal" [compat] @@ -45,14 +50,17 @@ Compat = "4.10.0" Enzyme = "0.12" Functors = "0.4" MLUtils = "0.4" +MPI = "0.20.19" MacroTools = "0.5" Metal = "0.5, 1" +NCCL = "0.1.1" NNlib = "0.9.22" OneHotArrays = "0.2.4" Optimisers = "0.3.3" Preferences = "1" ProgressLogging = "0.1" Reexport = "1.0" +Setfield = "1.1" SpecialFunctions = "2.1.2" Statistics = "1" Zygote = "0.6.67" diff --git a/docs/src/guide/gpu.md b/docs/src/guide/gpu.md index ffce90b055..b9cd6d1f8c 100644 --- a/docs/src/guide/gpu.md +++ b/docs/src/guide/gpu.md @@ -385,3 +385,120 @@ Flux.supported_devices Flux.get_device Flux.gpu_backend! ``` + +## Distributed data parallel training + +!!! danger "Experimental" + + Distributed support is experimental and could change in the future. + + +Flux now supports distributed data parallel training through the `DistributedUtils` module. +If you want to run your code on multiple GPUs, you first have to install `MPI.jl` (see the [docs](https://juliaparallel.org/MPI.jl/stable/usage/) for more info). + +```julia-repl +julia> using MPI + +julia> MPI.install_mpiexecjl() +``` + +You can then run your code from the CLI with `mpiexecjl --project=. -n <np> julia <your_script>.jl`. + +You can use either the `MPIBackend` or the `NCCLBackend`; the latter is available only if `NCCL.jl` is also loaded. First, initialize a backend with `DistributedUtils.initialize`, e.g. + +```julia-repl +julia> using Flux, MPI, NCCL, CUDA + +julia> CUDA.allowscalar(false) + +julia> DistributedUtils.initialize(NCCLBackend) + +julia> backend = DistributedUtils.get_distributed_backend(NCCLBackend) +NCCLBackend{Communicator, MPIBackend{MPI.Comm}}(Communicator(Ptr{NCCL.LibNCCL.ncclComm} @0x000000000607a660), MPIBackend{MPI.Comm}(MPI.Comm(1140850688))) +``` + +Pass your model, as well as any data, to the GPU device. +```julia-repl +julia> model = Chain(Dense(1 => 256, tanh), Dense(256 => 1)) |> gpu +Chain( + Dense(1 => 256, tanh), # 512 parameters + Dense(256 => 1), # 257 parameters +) # Total: 4 arrays, 769 parameters, 744 bytes.
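+ +julia> # Each process can also query its rank and the total number of workers, for example to split data manually or to log only from rank 0; the values shown below are from a single-process run and will differ across ranks. + +julia> DistributedUtils.local_rank(backend) +0 + +julia> DistributedUtils.total_workers(backend) +1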
+ +julia> x = rand(Float32, 1, 16) |> gpu +1×16 CUDA.CuArray{Float32, 2, CUDA.DeviceMemory}: + 0.239324 0.331029 0.924996 0.55593 0.853093 0.874513 0.810269 0.935858 0.477176 0.564591 0.678907 0.729682 0.96809 0.115833 0.66191 0.75822 + +julia> y = x .^ 3 +1×16 CUDA.CuArray{Float32, 2, CUDA.DeviceMemory}: + 0.0137076 0.0362744 0.791443 0.171815 0.620854 0.668804 0.53197 0.819654 0.108651 0.179971 0.312918 0.388508 0.907292 0.00155418 0.29 0.435899 +``` + +In this case, we are training on a total of `16 * number of processes` samples. You can also use `DistributedUtils.DistributedDataContainer` to split the data uniformly across processes (or do it manually). + +```julia-repl +julia> data = DistributedUtils.DistributedDataContainer(backend, x) +Flux.DistributedUtils.DistributedDataContainer(Float32[0.23932439 0.33102947 … 0.66191036 0.75822026], [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]) +``` + +You have to wrap your model in `DistributedUtils.FluxDistributedModel` and synchronize it (broadcast accross all processes): +```julia-repl +julia> model = DistributedUtils.synchronize!!(backend, DistributedUtils.FluxDistributedModel(model); root=0) +Chain( + Dense(1 => 256, tanh), # 512 parameters + + Dense(256 => 1), # 257 parameters +) # Total: 4 arrays, 769 parameters, 744 bytes. +``` + +Time to set up an optimizer by using `DistributedUtils.DistributedOptimizer` and synchronize it as well. +```julia-repl +julia> using Optimisers + +julia> opt = DistributedUtils.DistributedOptimizer(backend, Optimisers.Adam(0.001f0)) +DistributedOptimizer{MPIBackend{Comm}}(MPIBackend{Comm}(Comm(1140850688)), Adam(0.001, (0.9, 0.999), 1.0e-8)) + +julia> st_opt = Optimisers.setup(opt, model) +(layers = ((weight = Leaf(DistributedOptimizer{MPIBackend{Comm}}(MPIBackend{Comm}(Comm(1140850688)), Adam(0.001, (0.9, 0.999), 1.0e-8)), (Float32[0.0; 0.0; … ; 0.0; 0.0;;], Float32[0.0; 0.0; … ; 0.0; 0.0;;], (0.9, 0.999))), bias = Leaf(DistributedOptimizer{MPIBackend{Comm}}(MPIBackend{Comm}(Comm(1140850688)), Adam(0.001, (0.9, 0.999), 1.0e-8)), (Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0 … 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0 … 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], (0.9, 0.999))), σ = ()), (weight = Leaf(DistributedOptimizer{MPIBackend{Comm}}(MPIBackend{Comm}(Comm(1140850688)), Adam(0.001, (0.9, 0.999), 1.0e-8)), (Float32[0.0 0.0 … 0.0 0.0], Float32[0.0 0.0 … 0.0 0.0], (0.9, 0.999))), bias = Leaf(DistributedOptimizer{MPIBackend{Comm}}(MPIBackend{Comm}(Comm(1140850688)), Adam(0.001, (0.9, 0.999), 1.0e-8)), (Float32[0.0], Float32[0.0], (0.9, 0.999))), σ = ())),) + +julia> st_opt = DistributedUtils.synchronize!!(backend, st_opt; root=0) +(layers = ((weight = Leaf(DistributedOptimizer{MPIBackend{Comm}}(MPIBackend{Comm}(Comm(1140850688)), Adam(0.001, (0.9, 0.999), 1.0e-8)), (Float32[0.0; 0.0; … ; 0.0; 0.0;;], Float32[0.0; 0.0; … ; 0.0; 0.0;;], (0.9, 0.999))), bias = Leaf(DistributedOptimizer{MPIBackend{Comm}}(MPIBackend{Comm}(Comm(1140850688)), Adam(0.001, (0.9, 0.999), 1.0e-8)), (Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0 … 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0 … 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], (0.9, 0.999))), σ = ()), (weight = Leaf(DistributedOptimizer{MPIBackend{Comm}}(MPIBackend{Comm}(Comm(1140850688)), Adam(0.001, (0.9, 0.999), 1.0e-8)), (Float32[0.0 0.0 … 0.0 0.0], Float32[0.0 0.0 … 0.0 0.0], (0.9, 0.999))), bias 
= Leaf(DistributedOptimizer{MPIBackend{Comm}}(MPIBackend{Comm}(Comm(1140850688)), Adam(0.001, (0.9, 0.999), 1.0e-8)), (Float32[0.0], Float32[0.0], (0.9, 0.999))), σ = ())),) +``` + +Now you can define the loss and train the model. +```julia-repl +julia> loss(model) = mean((model(x) .- y).^2) +loss (generic function with 1 method) + +julia> for epoch in 1:100 + global model, st_opt + l, grad = Zygote.withgradient(loss, model) + println("Epoch $epoch: Loss $l") + st_opt, model = Optimisers.update(st_opt, model, grad[1]) + end +Epoch 1: Loss 0.011638729 +Epoch 2: Loss 0.0116432225 +Epoch 3: Loss 0.012763695 +... +``` + +Remember that in order to run it on multiple GPUs you have to run it from the CLI with `mpiexecjl --project=. -n <np> julia <your_script>.jl`, +where `<np>` is the number of processes that you want to use. The number of processes usually corresponds to the number of GPUs. + +By default, the MPI installation used by `MPI.jl` is not CUDA-aware, so if you want to run it in CUDA-aware mode, read more [here](https://juliaparallel.org/MPI.jl/stable/usage/#CUDA-aware-MPI-support) on custom installation and rebuilding `MPI.jl`. +Then test whether your MPI installation is CUDA-aware: +```julia-repl +julia> import Pkg +julia> Pkg.test("MPI"; test_args=["--backend=CUDA"]) +``` + +If it is, set your local preference as below: +```julia-repl +julia> using Preferences +julia> set_preferences!("Flux", "FluxDistributedMPICUDAAware" => true) +``` + +!!! warning "Known shortcomings" + + We don't run CUDA-aware tests, so you're running it at your own risk. + diff --git a/ext/FluxMPIExt/FluxMPIExt.jl b/ext/FluxMPIExt/FluxMPIExt.jl new file mode 100644 index 0000000000..57cb338479 --- /dev/null +++ b/ext/FluxMPIExt/FluxMPIExt.jl @@ -0,0 +1,183 @@ +module FluxMPIExt + +using CUDA +using Flux: MPIBackend, NCCLBackend, DistributedUtils, + AbstractDevice, FluxCUDADevice, FluxAMDGPUDevice, cpu, gpu, + get_device, MPI_CUDA_AWARE, MPI_ROCM_AWARE +using MPI: MPI + +if Base.find_package("AMDGPU") !== nothing + using AMDGPU +end + + +function DistributedUtils.__initialize( + ::Type{MPIBackend}; cuda_devices=nothing, amdgpu_devices=nothing, + force_cuda::Bool=false, caller::String="", force_amdgpu::Bool=false) # Undocumented internal kwarg + !MPI.Initialized() && MPI.Init() + DistributedUtils.MPI_Initialized[] = true + + local_rank = MPI.Comm_rank(MPI.COMM_WORLD) + + if cuda_devices !== missing && CUDA.functional() + if cuda_devices === nothing + CUDA.device!((local_rank + 1) % length(CUDA.devices())) + else + CUDA.device!(cuda_devices[local_rank + 1]) + end + elseif force_cuda + error(lazy"CUDA devices are not functional and `force_cuda` is set to `true`. This is caused by backend: $(caller).") + end + + if Base.find_package("AMDGPU") !== nothing + if amdgpu_devices !== missing && AMDGPU.functional() + if amdgpu_devices === nothing + AMDGPU.device!((local_rank + 1) % length(AMDGPU.devices())) + else + AMDGPU.device!(amdgpu_devices[local_rank + 1]) + end + elseif force_amdgpu + error(lazy"AMDGPU devices are not functional (or `AMDGPU.jl` not loaded) and `force_amdgpu` is set to `true`.
This is caused by backend: $(caller).") + end + end + + return +end + +DistributedUtils.__get_distributed_backend(::Type{MPIBackend}) = MPIBackend(MPI.COMM_WORLD) + +DistributedUtils.local_rank(backend::MPIBackend) = MPI.Comm_rank(backend.comm) + +DistributedUtils.total_workers(backend::MPIBackend) = MPI.Comm_size(backend.comm) + +# Broadcast +# Union with Function is because of Flux.cpu istypeof Function +# We need CPU in case of non CUDA-aware implementation +function DistributedUtils.__bcast!( + backend::MPIBackend, sendrecvbuf, dev::Union{AbstractDevice, Function}; root=0) + MPI.Bcast!(sendrecvbuf, backend.comm; root) + return sendrecvbuf +end + +function DistributedUtils.__bcast!( + backend::MPIBackend, sendbuf, recvbuf, dev::Union{AbstractDevice, Function}; root=0) + return DistributedUtils.__bcast!( + backend, ifelse(DistributedUtils.local_rank(backend) == root, sendbuf, recvbuf), + dev; root) +end + +# if MPI implementation is not CUDA-aware +# we have to move data to CPU first +for (aware, dType) in ((MPI_CUDA_AWARE, FluxCUDADevice), (MPI_ROCM_AWARE, FluxAMDGPUDevice)) + if !aware + @eval begin + function DistributedUtils.__bcast!( + backend::MPIBackend, sendrecvbuf, dev::$dType; root=0) + sendrecvbuf_ = sendrecvbuf |> cpu + DistributedUtils.__bcast!(backend, sendrecvbuf_, cpu; root) + sendrecvbuf |> gpu + return sendrecvbuf + end + + function DistributedUtils.__bcast!( + backend::MPIBackend, sendbuf, recvbuf, dev::$dType; root=0) + sendbuf_ = sendbuf |> cpu + recvbuf_ = recvbuf |> cpu + DistributedUtils.__bcast!(backend, sendbuf_, recvbuf_, cpu; root) + recvbuf |> gpu + return recvbuf + end + end + end +end + + +# Allreduce +function DistributedUtils.__allreduce!( + backend::MPIBackend, sendrecvbuf, op::F, dev::Union{AbstractDevice, Function};) where {F} + mpiop = ifelse(op === DistributedUtils.avg, +, op) + MPI.Allreduce!(sendrecvbuf, mpiop, backend.comm) + if op === DistributedUtils.avg + sendrecvbuf ./= DistributedUtils.total_workers(backend) + end + return sendrecvbuf +end + +function DistributedUtils.__allreduce!( + backend::MPIBackend, sendbuf, recvbuf, op::F, dev::Union{AbstractDevice, Function};) where {F} + mpiop = ifelse(op === DistributedUtils.avg, +, op) + MPI.Allreduce!(sendbuf, recvbuf, mpiop, backend.comm) + if op === DistributedUtils.avg + recvbuf ./= DistributedUtils.total_workers(backend) + end + return recvbuf +end + +for (aware, dType) in ((MPI_CUDA_AWARE, FluxCUDADevice), (MPI_ROCM_AWARE, FluxAMDGPUDevice)) + if !aware + @eval begin + function DistributedUtils.__allreduce!( + backend::MPIBackend, sendrecvbuf, op::F, dev::$dType) where {F} + sendrecvbuf_ = sendrecvbuf |> cpu + DistributedUtils.__allreduce!(backend, sendrecvbuf_, op, cpu) + sendrecvbuf |> gpu + return sendrecvbuf + end + + function DistributedUtils.__allreduce!( + backend::MPIBackend, sendbuf, recvbuf, op::F, dev::$dType) where {F} + sendbuf_ = sendbuf |> cpu + recvbuf_ = recvbuf |> cpu + DistributedUtils.__allreduce!(backend, sendbuf_, recvbuf_, op, cpu) + recvbuf |> gpu + return recvbuf + end + end + end +end + +# Reduce +function DistributedUtils.__reduce!(backend::MPIBackend, sendrecvbuf, op::F, + dev::Union{AbstractDevice, Function}; root::Int) where {F} + mpiop = ifelse(op === DistributedUtils.avg, +, op) + MPI.Reduce!(sendrecvbuf, mpiop, backend.comm; root) + if op === DistributedUtils.avg + sendrecvbuf ./= DistributedUtils.total_workers(backend) + end + return sendrecvbuf +end + +function DistributedUtils.__reduce!(backend::MPIBackend, sendbuf, recvbuf, op::F, + 
dev::Union{AbstractDevice, Function}; root::Int) where {F} + mpiop = ifelse(op === DistributedUtils.avg, +, op) + MPI.Reduce!(sendbuf, recvbuf, mpiop, backend.comm; root) + if op === DistributedUtils.avg + recvbuf ./= DistributedUtils.total_workers(backend) + end + return recvbuf +end + +for (aware, dType) in ((MPI_CUDA_AWARE, FluxCUDADevice), (MPI_ROCM_AWARE, FluxAMDGPUDevice)) + if !aware + @eval begin + function DistributedUtils.__reduce!(backend::MPIBackend, sendrecvbuf, op::F, + dev::$dType; root::Int) where {F} + sendrecvbuf_ = sendrecvbuf |> cpu + DistributedUtils.__reduce!(backend, sendrecvbuf_, op, cpu; root) + sendrecvbuf |> gpu + return sendrecvbuf + end + + function DistributedUtils.__reduce!(backend::MPIBackend, sendbuf, recvbuf, + op::F, dev::$dType; root::Int) where {F} + sendbuf_ = sendbuf |> cpu + recvbuf_ = recvbuf |> cpu + DistributedUtils.__reduce!(backend, sendbuf_, recvbuf_, op, cpu; root) + recvbuf |> gpu + return recvbuf + end + end + end +end + +end \ No newline at end of file diff --git a/ext/FluxMPINCCLExt/FluxMPINCCLExt.jl b/ext/FluxMPINCCLExt/FluxMPINCCLExt.jl new file mode 100644 index 0000000000..754a6c74c6 --- /dev/null +++ b/ext/FluxMPINCCLExt/FluxMPINCCLExt.jl @@ -0,0 +1,109 @@ +module FluxMPINCCLExt + +using Flux: MPIBackend, NCCLBackend, DistributedUtils, FluxCUDADevice, FluxAMDGPUDevice, AbstractDevice +using MPI: MPI +using NCCL: NCCL +using Setfield: @set! +using CUDA + +function DistributedUtils.__initialize( + ::Type{NCCLBackend}; cuda_devices=nothing, amdgpu_devices=missing) + @assert amdgpu_devices===missing "`AMDGPU` is not supported by `NCCL`." + DistributedUtils.__initialize( + MPIBackend; cuda_devices, force_cuda=true, caller="NCCLBackend", amdgpu_devices) + DistributedUtils.NCCL_Initialized[] = true + return +end + +function DistributedUtils.__get_distributed_backend(::Type{NCCLBackend}) + unique_id = NCCL.UniqueID() # Generate on all ranks to know the type + mpi_backend = DistributedUtils.__get_distributed_backend(MPIBackend) + buf = [unique_id.internal...] + DistributedUtils.bcast!(mpi_backend, buf; root=0) + @set! 
unique_id.internal = Tuple(buf) + + nranks = DistributedUtils.total_workers(mpi_backend) + rank = DistributedUtils.local_rank(mpi_backend) + + return NCCLBackend(NCCL.Communicator(nranks, rank; unique_id), mpi_backend) +end + +DistributedUtils.local_rank(backend::NCCLBackend) = NCCL.rank(backend.comm) + +DistributedUtils.total_workers(backend::NCCLBackend) = NCCL.size(backend.comm) + +# For non-CUDA Arrays, fallback to MPI +# Broadcast +function DistributedUtils.__bcast!( + backend::NCCLBackend, sendrecvbuf::CuArray, ::FluxCUDADevice; root=0) + NCCL.Broadcast!(sendrecvbuf, backend.comm; root) + return sendrecvbuf +end + +function DistributedUtils.__bcast!( + backend::NCCLBackend, sendrecvbuf, dev::AbstractDevice; root=0) + return DistributedUtils.__bcast!(backend.mpi_backend, sendrecvbuf, dev; root) +end + +function DistributedUtils.__bcast!( + backend::NCCLBackend, sendbuf, recvbuf, ::FluxCUDADevice; root=0) + NCCL.Broadcast!(sendbuf, recvbuf, backend.comm; root) + return recvbuf +end + +function DistributedUtils.__bcast!( + backend::NCCLBackend, sendbuf, recvbuf, dev::AbstractDevice; root=0) + return DistributedUtils.__bcast!(backend.mpi_backend, sendbuf, recvbuf, dev; root) +end + +# Allreduce +function DistributedUtils.__allreduce!( + backend::NCCLBackend, sendrecvbuf::CuArray, op::F, dev::FluxCUDADevice) where {F} + op = ifelse(op === DistributedUtils.avg, NCCL.avg, op) + NCCL.Allreduce!(sendrecvbuf, op, backend.comm) + return sendrecvbuf +end + +function DistributedUtils.__allreduce!( + backend::NCCLBackend, sendrecvbuf, op::F, dev::AbstractDevice) where {F} + return DistributedUtils.__allreduce!(backend.mpi_backend, sendrecvbuf, op, dev) +end + +function DistributedUtils.__allreduce!( + backend::NCCLBackend, sendbuf, recvbuf, op::F, ::FluxCUDADevice) where {F} + op = ifelse(op === DistributedUtils.avg, NCCL.avg, op) + NCCL.Allreduce!(sendbuf, recvbuf, op, backend.comm) + return recvbuf +end + +function DistributedUtils.__allreduce!( + backend::NCCLBackend, sendbuf, recvbuf, op::F, dev::AbstractDevice) where {F} + return DistributedUtils.__allreduce!(backend.mpi_backend, sendbuf, recvbuf, op, dev) +end + +# Reduce +function DistributedUtils.__reduce!( + backend::NCCLBackend, sendrecvbuf, op::F, ::FluxCUDADevice; root::Int) where {F} + op = ifelse(op === DistributedUtils.avg, NCCL.avg, op) + NCCL.Reduce!(sendrecvbuf, op, backend.comm; root) + return sendrecvbuf +end + +function DistributedUtils.__reduce!(backend::NCCLBackend, sendrecvbuf, op::F, + dev::AbstractDevice; root::Int) where {F} + return DistributedUtils.__reduce!(backend.mpi_backend, sendrecvbuf, op, dev; root) +end + +function DistributedUtils.__reduce!( + backend::NCCLBackend, sendbuf, recvbuf, op::F, ::FluxCUDADevice; root::Int) where {F} + op = ifelse(op === DistributedUtils.avg, NCCL.avg, op) + NCCL.Reduce!(sendbuf, recvbuf, op, backend.comm; root) + return recvbuf +end + +function DistributedUtils.__reduce!(backend::NCCLBackend, sendbuf, recvbuf, op::F, + dev::AbstractDevice; root::Int) where {F} + return DistributedUtils.__reduce!(backend.mpi_backend, sendbuf, recvbuf, op, dev; root) +end + +end \ No newline at end of file diff --git a/src/Flux.jl b/src/Flux.jl index 2681ea6729..7eac8ee7d6 100644 --- a/src/Flux.jl +++ b/src/Flux.jl @@ -92,6 +92,11 @@ include("deprecations.jl") include("losses/Losses.jl") using .Losses +# Distributed Training +include("distributed/backend.jl") +include("distributed/public_api.jl") +export MPIBackend, NCCLBackend, DistributedUtils + @compat(public, ( # init glorot_uniform, diff --git 
a/src/distributed/backend.jl b/src/distributed/backend.jl new file mode 100644 index 0000000000..56f6e5a993 --- /dev/null +++ b/src/distributed/backend.jl @@ -0,0 +1,44 @@ +# ============================================== +# inspired by avik-pal's Lux.jl +# ============================================== + +abstract type AbstractFluxDistributedBackend end + +""" + MPIBackend(comm = nothing) + +Create an MPI backend for distributed training. Users should not use this function directly. +Instead use [`DistributedUtils.get_distributed_backend(MPIBackend)`](@ref). +""" +struct MPIBackend{C} <: AbstractFluxDistributedBackend + comm::C + + function MPIBackend(comm=nothing) + if Base.get_extension(@__MODULE__, :FluxMPIExt) === nothing + error("`MPIBackend` requires `MPI.jl` to be loaded.") + end + return new{typeof(comm)}(comm) + end +end + +""" + NCCLBackend(comm = nothing, mpi_backend = nothing) + +Create an NCCL backend for distributed training. Users should not use this function +directly. Instead use [`DistributedUtils.get_distributed_backend(NCCLBackend)`](@ref). +""" +struct NCCLBackend{C, M <: Union{Nothing, MPIBackend}} <: AbstractFluxDistributedBackend + comm::C + mpi_backend::M + + function NCCLBackend(comm=nothing, mpi_backend=nothing) + if Base.get_extension(@__MODULE__, :FluxMPINCCLExt) === nothing + error("`NCCLBackend` requires `CUDA.jl`, `MPI.jl` and `NCCL.jl` to be loaded.") + end + return new{typeof(comm), typeof(mpi_backend)}(comm, mpi_backend) + end +end + +# Preferences for GPU-Aware MPI +const MPI_CUDA_AWARE = @load_preference("FluxDistributedMPICUDAAware", false) +const MPI_ROCM_AWARE = @load_preference("FluxDistributedMPIROCMAware", false) \ No newline at end of file diff --git a/src/distributed/public_api.jl b/src/distributed/public_api.jl new file mode 100644 index 0000000000..38176a2e63 --- /dev/null +++ b/src/distributed/public_api.jl @@ -0,0 +1,284 @@ +# ============================================== +# inspired by avik-pal's Lux.jl +# ============================================== + +module DistributedUtils + +using ChainRulesCore: ChainRulesCore +using ..Flux: AbstractFluxDistributedBackend, MPIBackend, NCCLBackend, AbstractDevice, get_device +using Functors: fmap +using MLUtils: MLUtils, numobs +using Optimisers: Optimisers, AbstractRule, Leaf +using Random: Random +using Setfield: @set! + +const CRC = ChainRulesCore + +const NCCL_Initialized = Ref(false) +const MPI_Initialized = Ref(false) + +""" + initialized(backend::Type{<:AbstractFluxDistributedBackend}) + +Check if the given backend is initialized. +""" +initialized(::Type{<:MPIBackend}) = MPI_Initialized[] +initialized(::Type{<:NCCLBackend}) = NCCL_Initialized[] + +""" + initialize(backend::Type{<:AbstractFluxDistributedBackend}; kwargs...) + +Initialize the given backend. Users can supply `cuda_devices` and `amdgpu_devices` to +initialize the backend with the given devices. These can be set to `missing` to prevent +initialization of the given device type. If set to `nothing`, and the backend is functional +we assign GPUs in a round-robin fashion. Finally, a list of integers can be supplied to +initialize the backend with the given devices. + +Possible values for `backend` are: + + - `MPIBackend`: MPI backend for distributed training. Requires `MPI.jl` to be installed. + - `NCCLBackend`: NCCL backend for CUDA distributed training. Requires `CUDA.jl`, + `MPI.jl`, and `NCCL.jl` to be installed. This also wraps `MPI` backend for non-CUDA + communications. 
+""" +function initialize(backend::Type{<:AbstractFluxDistributedBackend}; kwargs...) + # initialized(backend) && return + __initialize(backend; kwargs...) + return +end + +function __initialize end + +""" + get_distributed_backend(backend::Type{<:AbstractFluxDistributedBackend}) + +Get the distributed backend for the given backend type. Possible values are: + + - `MPIBackend`: MPI backend for distributed training. Requires `MPI.jl` to be installed. + - `NCCLBackend`: NCCL backend for CUDA distributed training. Requires `CUDA.jl`, + `MPI.jl`, and `NCCL.jl` to be installed. This also wraps `MPI` backend for non-CUDA + communications. + +!!! danger + + `initialize(backend; kwargs...)` must be called before calling this function. +""" +function get_distributed_backend(backend::Type{<:AbstractFluxDistributedBackend}) + initialized(backend) || + error("Backend `$(backend)` is not initialized. Call `DistributedUtils.initialize` first.") + return __get_distributed_backend(backend) +end + +function __get_distributed_backend end + +CRC.@non_differentiable get_distributed_backend(::Any...) + +""" + local_rank(backend::AbstractFluxDistributedBackend) + +Get the local rank for the given backend. +""" +function local_rank end + +CRC.@non_differentiable local_rank(::Any...) + +""" + total_workers(backend::AbstractFluxDistributedBackend) + +Get the total number of workers for the given backend. +""" +function total_workers end + +CRC.@non_differentiable total_workers(::Any...) + +""" + bcast!(backend::AbstractFluxDistributedBackend, sendrecvbuf; root::Int=0) + bcast!(backend::AbstractFluxDistributedBackend, sendbuf, recvbuf; root::Int=0) + +Backend Agnostic API to broadcast the given buffer `sendrecvbuf` or `sendbuf` to all +workers into `recvbuf`. The value at `root` will be broadcasted to all other workers. +""" +function bcast!(backend::AbstractFluxDistributedBackend, sendrecvbuf; root::Int=0) + return __bcast!(backend, sendrecvbuf, get_device(); root) +end + +function bcast!(backend::AbstractFluxDistributedBackend, sendbuf, recvbuf; root::Int=0) + dev = ifelse(get_device() == FluxCPUDevice, cpu, gpu) + return __bcast!(backend, sendbuf, recvbuf, dev; root) +end + +function __bcast! end + +CRC.@non_differentiable bcast!(::Any...) + +function avg end + +""" + allreduce!(backend::AbstractFluxDistributedBackend, sendrecvbuf, op) + allreduce!(backend::AbstractFluxDistributedBackend, sendbuf, recvbuf, op) + +Backend Agnostic API to perform an allreduce operation on the given buffer `sendrecvbuf` or +`sendbuf` and store the result in `recvbuf`. + +`op` allows a special `DistributedUtils.avg` operation that averages the result across all +workers. +""" +function allreduce!(backend::AbstractFluxDistributedBackend, sendrecvbuf, op::F) where {F} + return __allreduce!(backend, sendrecvbuf, op, get_device()) +end + +function allreduce!( + backend::AbstractFluxDistributedBackend, sendbuf, recvbuf, op::F) where {F} + dev = ifelse(get_device() == FluxCPUDevice, cpu, gpu) + return __allreduce!(backend, sendbuf, recvbuf, op, dev) +end + +function __allreduce! end + +CRC.@non_differentiable allreduce!(::Any...) + +""" + reduce!(backend::AbstractFluxDistributedBackend, sendrecvbuf, op; root::Int=0) + reduce!(backend::AbstractFluxDistributedBackend, sendbuf, recvbuf, op; root::Int=0) + +Backend Agnostic API to perform a reduce operation on the given buffer `sendrecvbuf` or +`sendbuf` and store the result in `recvbuf`. + +`op` allows a special `DistributedUtils.avg` operation that averages the result across all +workers. 
+""" +function reduce!( + backend::AbstractFluxDistributedBackend, sendrecvbuf, op::F; root::Int=0) where {F} + return __reduce!(backend, sendrecvbuf, op, get_device(); root) +end + +function reduce!(backend::AbstractFluxDistributedBackend, + sendbuf, recvbuf, op::F; root::Int=0) where {F} + dev = ifelse(get_device() == FluxCPUDevice, cpu, gpu) + return __reduce!(backend, sendbuf, recvbuf, op, dev; root) +end + +function __reduce! end + +CRC.@non_differentiable reduce!(::Any...) + +## As Flux model is an arbitrary type it's not possible to dispatch `synchronize!!` +## end user needs to wrap Flux model into `FluxDistributedModel` +## e.g. model = DistributedUtils.synchronize!!(backend, FluxDistributedModel(model); root=0) +struct FluxDistributedModel{M} + model::M +end + +# synchronize! +""" + synchronize!!(backend::AbstractFluxDistributedBackend, ps; root::Int=0) + +Synchronize the given structure `ps` using the given backend. The value at `root` will be +broadcasted to all other workers. +""" +function synchronize!!(backend::AbstractFluxDistributedBackend, model::FluxDistributedModel; root::Int=0) + return fmap(x -> synchronize!!(backend, x; root), model.model) +end + +function synchronize!!(backend::AbstractFluxDistributedBackend, ps::Tuple; root::Int=0) + length(ps) == 0 && return ps + return map(x -> synchronize!!(backend, x; root), ps) +end + +function synchronize!!(backend::AbstractFluxDistributedBackend, + ps::NamedTuple{fields}; root::Int=0) where {fields} + length(ps) == 0 && return ps + return NamedTuple{fields}(map(x -> synchronize!!(backend, x; root), values(ps))) +end + +function synchronize!!( + backend::AbstractFluxDistributedBackend, ps::AbstractArray{T}; root::Int=0) where {T} + if isbitstype(T) + bcast!(backend, ps; root) + return ps + end + return map(x -> synchronize!!(backend, x; root), ps) +end + +# if no method for a given type, just return the value +function synchronize!!(backend::AbstractFluxDistributedBackend, ps::T; root::Int=0) where {T} + isbitstype(T) && return bcast!(backend, [ps]; root)[] + return ps +end + +# data container +""" + DistributedDataContainer(backend::AbstractFluxDistributedBackend, data) + +`data` must be compatible with `MLUtils` interface. The returned container is compatible +with `MLUtils` interface and is used to partition the dataset across the available +processes. + +!!! danger + + `MLUtils.jl` must be installed and loaded before using this. +""" +struct DistributedDataContainer + data + idxs +end + +function DistributedDataContainer(backend::AbstractFluxDistributedBackend, data) + return __construct_distributed_data_container(backend, data) +end + +Base.length(ddc::DistributedDataContainer) = length(ddc.idxs) + +Base.getindex(ddc::DistributedDataContainer, i) = getindex(ddc.data, ddc.idxs[i]) + +function MLUtils.getobs(dc::DistributedDataContainer, idx) + return MLUtils.getobs(dc.data, dc.idxs[idx]) +end + +function __construct_distributed_data_container( + backend::AbstractFluxDistributedBackend, data) + total_size = numobs(data) + split_across = total_workers(backend) + size_per_worker = Int(ceil(total_size / split_across)) + + partitions = collect(Iterators.partition(1:total_size, size_per_worker)) + idxs = collect(partitions[local_rank(backend) + 1]) + + return DistributedDataContainer(data, idxs) +end + +# Distributed Optimizer +""" + DistributedOptimizer(backend::AbstractFluxDistributedBacked, optimizer) + +Wrap the `optimizer` in a `DistributedOptimizer`. 
Before updating the parameters, this +averages the gradients across the processes using Allreduce. + +## Arguments + + - `optimizer`: An Optimizer compatible with the Optimisers.jl package + +""" +struct DistributedOptimizer{B <: AbstractFluxDistributedBackend} <: AbstractRule + backend::B + opt +end + +function Optimisers.apply!(opt::DistributedOptimizer, state, x, y) + y_avg = DistributedUtils.allreduce!(opt.backend, y, DistributedUtils.avg) + return Optimisers.apply!(opt.opt, state, x, y_avg) +end + +Optimisers.init(opt::DistributedOptimizer, x::AbstractArray) = Optimisers.init(opt.opt, x) + +function Optimisers._adjust(opt::DistributedOptimizer, nt::NamedTuple) + return DistributedOptimizer(opt.backend, Optimisers._adjust(opt.opt, nt)) +end + +function DistributedUtils.synchronize!!( + backend::AbstractFluxDistributedBackend, ps::Leaf; root::Int=0) + @set! ps.state = DistributedUtils.synchronize!!(backend, ps.state; root) + return ps +end + +end \ No newline at end of file diff --git a/test/ext_distributed/common.jl b/test/ext_distributed/common.jl new file mode 100644 index 0000000000..5ccad67597 --- /dev/null +++ b/test/ext_distributed/common.jl @@ -0,0 +1,89 @@ +using Flux, MPI, NCCL, Test, CUDA + +const input_args = length(ARGS) == 2 ? ARGS : ("CPU", "mpi") +const backend_type = input_args[2] == "nccl" ? NCCLBackend : MPIBackend +const dev = input_args[1] == "CPU" ? Flux.cpu : Flux.gpu +const aType = input_args[1] == "CPU" ? Array : + (input_args[1] == "CUDA" ? CuArray : ROCArray) + +DistributedUtils.initialize(backend_type) +backend = DistributedUtils.get_distributed_backend(backend_type) + +@test DistributedUtils.initialized(backend_type) + +# Should always hold true +rank = DistributedUtils.local_rank(backend) +nworkers = DistributedUtils.total_workers(backend) +@test rank < nworkers + +# Test the communication primitives +## broacast! +for arrType in (Array, aType) + sendbuf = (rank == 0) ? arrType(ones(512)) : arrType(zeros(512)) + recvbuf = arrType(zeros(512)) + + DistributedUtils.bcast!(backend, sendbuf, recvbuf; root=0) + + rank != 0 && @test all(recvbuf .== 1) + + sendrecvbuf = (rank == 0) ? arrType(ones(512)) : arrType(zeros(512)) + DistributedUtils.bcast!(backend, sendrecvbuf; root=0) + + @test all(sendrecvbuf .== 1) +end + +## reduce! +for arrType in (Array, aType) + sendbuf = arrType(fill(Float64(rank + 1), 512)) + recvbuf = arrType(zeros(512)) + + DistributedUtils.reduce!(backend, sendbuf, recvbuf, +; root=0) + + rank == 0 && @test all(recvbuf .≈ sum(1:nworkers)) + + sendbuf .= rank + 1 + + DistributedUtils.reduce!(backend, sendbuf, recvbuf, DistributedUtils.avg; root=0) + + rank == 0 && @test all(recvbuf .≈ sum(1:nworkers) / nworkers) + + sendrecvbuf = arrType(fill(Float64(rank + 1), 512)) + + DistributedUtils.reduce!(backend, sendrecvbuf, +; root=0) + + rank == 0 && @test all(sendrecvbuf .≈ sum(1:nworkers)) + + sendrecvbuf .= rank + 1 + + DistributedUtils.reduce!(backend, sendrecvbuf, DistributedUtils.avg; root=0) + + rank == 0 && @test all(sendrecvbuf .≈ sum(1:nworkers) / nworkers) +end + +## allreduce! 
+for arrType in (Array, aType) + sendbuf = arrType(fill(Float64(rank + 1), 512)) + recvbuf = arrType(zeros(512)) + + DistributedUtils.allreduce!(backend, sendbuf, recvbuf, +) + + @test all(recvbuf .≈ sum(1:nworkers)) + + sendbuf .= rank + 1 + + DistributedUtils.allreduce!(backend, sendbuf, recvbuf, DistributedUtils.avg) + + @test all(recvbuf .≈ sum(1:nworkers) / nworkers) + + sendrecvbuf = arrType(fill(Float64(rank + 1), 512)) + + DistributedUtils.allreduce!(backend, sendrecvbuf, +) + + @test all(sendrecvbuf .≈ sum(1:nworkers)) + + sendrecvbuf .= rank + 1 + + DistributedUtils.allreduce!(backend, sendrecvbuf, DistributedUtils.avg) + + @test all(sendrecvbuf .≈ sum(1:nworkers) / nworkers) +end \ No newline at end of file diff --git a/test/ext_distributed/data.jl b/test/ext_distributed/data.jl new file mode 100644 index 0000000000..ddcac5b9a3 --- /dev/null +++ b/test/ext_distributed/data.jl @@ -0,0 +1,26 @@ +using Flux, MLUtils, MPI, NCCL, Random, Test, CUDA + +const input_args = length(ARGS) == 2 ? ARGS : ("CPU", "mpi") +const backend_type = input_args[2] == "nccl" ? NCCLBackend : MPIBackend +const dev = input_args[1] == "CPU" ? Flux.cpu : Flux.gpu + +rng = Xoshiro(1234) + +DistributedUtils.initialize(backend_type) +backend = DistributedUtils.get_distributed_backend(backend_type) + +data = randn(rng, Float32, 10) +dcontainer = DistributedUtils.DistributedDataContainer(backend, data) + +rank = DistributedUtils.local_rank(backend) +tworkers = DistributedUtils.total_workers(backend) + +if rank != tworkers - 1 + @test length(dcontainer) == ceil(length(data) / tworkers) +else + @test length(dcontainer) == + length(data) - (tworkers - 1) * ceil(length(data) / tworkers) +end + +dsum = sum(Base.Fix1(MLUtils.getobs, dcontainer), 1:MLUtils.numobs(dcontainer)) +@test DistributedUtils.allreduce!(backend, [dsum], +)[1] ≈ sum(data) \ No newline at end of file diff --git a/test/ext_distributed/optimizer.jl b/test/ext_distributed/optimizer.jl new file mode 100644 index 0000000000..dce2d8638e --- /dev/null +++ b/test/ext_distributed/optimizer.jl @@ -0,0 +1,28 @@ +using Flux, MPI, NCCL, Optimisers, Random, Test, CUDA + +const input_args = length(ARGS) == 2 ? ARGS : ("CPU", "mpi") +const backend_type = input_args[2] == "nccl" ? NCCLBackend : MPIBackend +const dev = input_args[1] == "CPU" ? Flux.cpu : Flux.gpu + +DistributedUtils.initialize(backend_type) +backend = DistributedUtils.get_distributed_backend(backend_type) + +opt = Optimisers.Adam(0.001f0) +ps = (a=zeros(4), b=zeros(4)) |> dev +st_opt = Optimisers.setup(opt, ps) + +dopt = DistributedUtils.DistributedOptimizer(backend, opt) +st_dopt = Optimisers.setup(dopt, ps) + +@test st_dopt.a.state == st_opt.a.state +@test st_dopt.b.state == st_opt.b.state + +@test_nowarn DistributedUtils.synchronize!!(backend, st_dopt) + +gs = (a=ones(4), b=ones(4)) |> dev + +_, ps_dopt = Optimisers.update(st_dopt, ps, gs) +_, ps_opt = Optimisers.update(st_opt, ps, gs) + +@test ps_dopt.a≈ps_opt.a atol=1.0e-5 rtol=1.0e-5 +@test ps_dopt.b≈ps_opt.b atol=1.0e-5 rtol=1.0e-5 \ No newline at end of file diff --git a/test/ext_distributed/runtests.jl b/test/ext_distributed/runtests.jl new file mode 100644 index 0000000000..686ce04048 --- /dev/null +++ b/test/ext_distributed/runtests.jl @@ -0,0 +1,35 @@ +# Distributed Tests +using MPI, Pkg, Test + +nprocs_str = get(ENV, "JULIA_MPI_TEST_NPROCS", "") +nprocs = nprocs_str == "" ? 
clamp(Sys.CPU_THREADS, 2, 4) : parse(Int, nprocs_str) +testdir = @__DIR__ +isdistributedtest(f) = endswith(f, "_distributedtest.jl") +distributedtestfiles = String[] +for (root, dirs, files) in walkdir(testdir) + for file in files + if isdistributedtest(file) + push!(distributedtestfiles, joinpath(root, file)) + end + end +end + +@info "Running Distributed Tests with $nprocs processes" + +cur_proj = dirname(Pkg.project().path) + +@testset "Distributed" begin + backends = get(ENV, "FLUX_TEST_DISTRIBUTED_NCCL", "false") == "true" ? ("mpi", "nccl") : ("mpi",) + for backend_type in backends + np = backend_type == "nccl" ? min(nprocs, length(CUDA.devices())) : nprocs + @testset "Backend: $(backend_type)" begin + @testset "$(basename(file))" for file in distributedtestfiles + @info "Running $file with $backend_type backend" + run(`$(MPI.mpiexec()) -n $(np) $(Base.julia_cmd()) --color=yes \ + --code-coverage=user --project=$(cur_proj) --startup-file=no $(file) \ + $(backend_type)`) + Test.@test true + end + end + end +end \ No newline at end of file diff --git a/test/ext_distributed/synchronized.jl b/test/ext_distributed/synchronized.jl new file mode 100644 index 0000000000..f2e6d6a431 --- /dev/null +++ b/test/ext_distributed/synchronized.jl @@ -0,0 +1,91 @@ +using Flux, MPI, NCCL, Optimisers, Random, Test, CUDA + +const input_args = length(ARGS) == 2 ? ARGS : ("CPU", "mpi") +const backend_type = input_args[2] == "nccl" ? NCCLBackend : MPIBackend +const dev = input_args[1] == "CPU" ? Flux.cpu : Flux.gpu + +function __get_array_based_on_rank(backend, dims; root) + DistributedUtils.local_rank(backend) == root && return ones(dims...) + return zeros(dims...) +end + +root = 0 + +DistributedUtils.initialize(backend_type) +backend = DistributedUtils.get_distributed_backend(backend_type) + +# Named Tuple +gs = ( + a=(b=__get_array_based_on_rank(backend, (2, 3); root), + c=__get_array_based_on_rank(backend, (2, 3); root)), + d=__get_array_based_on_rank(backend, (2, 3); root)) |> dev + +gs_ = DistributedUtils.synchronize!!(backend, gs; root) + +@test all(gs_.a.b .== 1) +@test all(gs_.a.c .== 1) +@test all(gs_.d .== 1) + +## optimisers +opt = Optimisers.Adam(0.001f0) +st_opt = Optimisers.setup(opt, gs) + +if DistributedUtils.local_rank(backend) == root + st_opt.a.b.state[1] .= 1 + st_opt.a.b.state[2] .= 1 + st_opt.a.c.state[1] .= 1 + st_opt.a.c.state[2] .= 1 + st_opt.d.state[1] .= 1 + st_opt.d.state[2] .= 1 +end + +st_opt = DistributedUtils.synchronize!!(backend, st_opt; root) + +@test all(st_opt.a.b.state[1] .== 1) +@test all(st_opt.a.b.state[2] .== 1) +@test all(st_opt.a.c.state[1] .== 1) +@test all(st_opt.a.c.state[2] .== 1) +@test all(st_opt.d.state[1] .== 1) +@test all(st_opt.d.state[2] .== 1) + +# Has no state +opt = Optimisers.Descent(0.001f0) +st_opt = Optimisers.setup(opt, gs) + +@test_nowarn DistributedUtils.synchronize!!(backend, st_opt; root) + +## ComponentArrays +gs = ( + a=(b=__get_array_based_on_rank(backend, (2, 3); root), + c=__get_array_based_on_rank(backend, (2, 3); root)), + d=__get_array_based_on_rank(backend, (2, 3); root)) +cgs_ = DistributedUtils.synchronize!!(backend, gs; root) + +@test all(cgs_.a.b .== 1) +@test all(cgs_.a.c .== 1) +@test all(cgs_.d .== 1) + +# Tuple +gs = ( + (__get_array_based_on_rank(backend, (2, 3); root), + __get_array_based_on_rank(backend, (2, 3); root)), + __get_array_based_on_rank(backend, (2, 3); root)) |> dev + +gs = DistributedUtils.synchronize!!(backend, gs; root) + +@test all(gs[1][1] .== 1) +@test all(gs[1][2] .== 1) +@test all(gs[2] .== 1) + +# 
Miscellaneous +x = nothing +x = DistributedUtils.synchronize!!(backend, x; root) +@test x === nothing + +x = ifelse(root == DistributedUtils.local_rank(backend), :x, :y) +x_ = DistributedUtils.synchronize!!(backend, x; root) +# Symbol should not change +@test x_ == x + +x = DistributedUtils.synchronize!!(backend, DistributedUtils.local_rank(backend); root) +@test x == root \ No newline at end of file diff --git a/test/runtests.jl index 9243e59f0d..7a4e46fd38 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -11,6 +11,8 @@ using Pkg # ENV["FLUX_TEST_CUDA"] = "true" # ENV["FLUX_TEST_METAL"] = "true" # ENV["FLUX_TEST_CPU"] = "false" +# ENV["FLUX_TEST_DISTRIBUTED_MPI"] = "true" +# ENV["FLUX_TEST_DISTRIBUTED_NCCL"] = "true" include("test_utils.jl") @@ -121,6 +123,23 @@ Random.seed!(0) @info "Skipping Metal tests, set FLUX_TEST_METAL=true to run them." end + if get(ENV, "FLUX_TEST_DISTRIBUTED_MPI", "false") == "true" || get(ENV, "FLUX_TEST_DISTRIBUTED_NCCL", "false") == "true" + Pkg.add(["MPI"]) + using MPI + + if get(ENV, "FLUX_TEST_DISTRIBUTED_NCCL", "false") == "true" + Pkg.add(["NCCL"]) + using NCCL + end + + @testset "Distributed" begin + include("ext_distributed/runtests.jl") + end + + else + @info "Skipping Distributed tests, set FLUX_TEST_DISTRIBUTED_MPI or FLUX_TEST_DISTRIBUTED_NCCL=true to run them." + end + @testset "Enzyme" begin Pkg.add(["CUDA", "cuDNN"]) import Enzyme
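For reference, the REPL walkthrough added to `docs/src/guide/gpu.md` above can be collected into a single script and launched with `mpiexecjl`. The sketch below only rearranges the steps from that guide: the file name `distributed_train.jl`, the synthetic data, and the hyperparameters are illustrative, and it assumes `MPI.jl`, `NCCL.jl`, `CUDA.jl`, `Optimisers.jl`, `Zygote.jl`, and `Statistics` are available in the active project.

```julia
# distributed_train.jl: run with `mpiexecjl --project=. -n <np> julia distributed_train.jl`
using Flux, MPI, NCCL, CUDA
using Optimisers, Zygote, Statistics

CUDA.allowscalar(false)

# Initialize the NCCL backend (it wraps an MPI backend for non-CUDA communication).
DistributedUtils.initialize(NCCLBackend)
backend = DistributedUtils.get_distributed_backend(NCCLBackend)
rank = DistributedUtils.local_rank(backend)

# Each process draws its own local batch, so the effective batch size is 16 * <np>.
x = rand(Float32, 1, 16) |> gpu
y = x .^ 3

# Build the model and broadcast the parameters from rank 0 to all workers.
model = Chain(Dense(1 => 256, tanh), Dense(256 => 1)) |> gpu
model = DistributedUtils.synchronize!!(backend, DistributedUtils.FluxDistributedModel(model); root=0)

# Wrap the optimiser so gradients are averaged across workers before each update.
opt = DistributedUtils.DistributedOptimizer(backend, Optimisers.Adam(0.001f0))
st_opt = Optimisers.setup(opt, model)
st_opt = DistributedUtils.synchronize!!(backend, st_opt; root=0)

loss(m) = mean((m(x) .- y) .^ 2)

for epoch in 1:100
    global model, st_opt
    l, grad = Zygote.withgradient(loss, model)
    st_opt, model = Optimisers.update(st_opt, model, grad[1])
    rank == 0 && println("Epoch $epoch: Loss $l")
end
```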