
[WIP] Move CUDA support to a package extension #2132

Closed
wants to merge 13 commits

Conversation

IanButterworth
Contributor

@IanButterworth commented Dec 9, 2022

Uses the new package extensions feature (due to land in Julia 1.9) to make CUDA functionality optional.
On Julia versions that don't have package extensions, Flux will work as before.

I've only done the obvious things here to get loading working and haven't run tests locally.
I kept the original file names so it's easy to cross-reference where each function came from. It could all be consolidated if preferred.
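For orientation, a minimal sketch of what the extension entry point could look like under this approach (the module name and layout are illustrative, not necessarily what the PR uses):

# Hypothetical ext/FluxCUDAExt.jl: Julia loads this module automatically once
# Flux and the trigger package(s) are all present in the session.
module FluxCUDAExt

using Flux, CUDA, NNlibCUDA

# The CUDA-specific methods keep their original Flux file names and are
# `include`d here, so definitions are easy to cross-reference.

end # module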

updated:

julia> @time using Flux
  5.154103 seconds (5.45 M allocations: 381.721 MiB, 2.06% gc time, 99.50% compilation time: 37% of which was recompilation)

julia> x = rand(2)
2-element Vector{Float64}:
 0.8415303775713626
 0.26837175189895346

julia> gpu(x)
[ Info: The GPU functionality is being called via `Flux.gpu` but `NNlibCUDA` must be loaded to access GPU functionality
2-element Vector{Float64}:
 0.8415303775713626
 0.26837175189895346

julia> @time using NNlibCUDA
  7.160284 seconds (20.20 M allocations: 1.142 GiB, 3.41% gc time, 42.59% compilation time: 82% of which was recompilation)

julia> gpu(x)
┌ Info: The GPU function is being called but the GPU is not accessible.
└ Defaulting back to the CPU. (No action is required if you want to run on the CPU).
2-element Vector{Float64}:
 0.8415303775713626
 0.26837175189895346

For reference, Flux master on Julia master:

julia> @time using Flux
 12.092398 seconds (25.81 M allocations: 1.526 GiB, 3.02% gc time, 65.40% compilation time: 57% of which was recompilation)

Helps with #1961

cc. @mcabbott

[Edit: closes #2155]

PR Checklist

  • Tests are added
  • Entry in NEWS.md
  • Documentation, if applicable

@DhairyaLGandhi
Member

Excellent! We should still have gpu default back to its current behaviour, and instead add the message to the warning it produces. Really excited to see this come through!

@IanButterworth
Contributor Author

One question: I structured this so that you need to load both CUDA and NNlibCUDA as trigger packages, which means you can avoid downloading and Pkg.precompile-ing both of those.

Alternatively, NNlibCUDA could be made a strong dep so you only have to `using CUDA`, and NNlibCUDA could then be loaded inside the extension.

@DhairyaLGandhi
Member

I prefer the latter since it enables a nicer user experience. Does the added precompilation cost a tonne of time?

@IanButterworth
Contributor Author

IanButterworth commented Dec 9, 2022

Ah, CUDA is a dep of NNlibCUDA, so if there were only one trigger package it would have to be NNlibCUDA. Updated above now.

@ToucheSir
Member

Is there a way we could have CUDA.jl as the only trigger dep? Up to this point we've been acting on the model that users shouldn't have to know NNlibCUDA even exists, so if possible it would be great to keep that.

@IanButterworth
Contributor Author

@KristofferC would that be possible? NNlibCUDA depends on CUDA, so AFAIK it's not possible.

@ToucheSir
Member

If needed, any and all changes required on the NNlibCUDA side should be on the table. Would making its CUDA dependency also weak be enough?

@IanButterworth
Contributor Author

Maybe absorb it into NNlib and make CUDA a weak dep of NNlib?

@KristofferC
Contributor

Note that Registrator is not updated yet to handle weak deps.

@CarloLucibello
Member

Two suggestions, although the first one is likely unrealizable:

  1. Can gpu automatically load the extension if CUDA is available in the environment (although not loaded)?

  2. If the answer to 1) is no, I would prefer we have a registered (sub)package named FluxCUDA which loads the extension, and have the warning point to that. One should also be able to do using FluxCUDA directly instead of using Flux, FluxCUDA and get all the functionality (see the sketch after this list).
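For suggestion 2, a minimal sketch of what such a stub package could look like (the FluxCUDA package here is hypothetical; nothing with this name existed at the time):

# Hypothetical FluxCUDA.jl: `using FluxCUDA` pulls in Flux plus the trigger
# packages, so the CUDA extension loads and all of Flux is re-exported.
module FluxCUDA

using Reexport
@reexport using Flux      # make `using FluxCUDA` behave like `using Flux`
using CUDA, NNlibCUDA     # trigger packages that activate the CUDA extension

end # module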

@IanButterworth
Contributor Author

IanButterworth commented Dec 9, 2022

  1. Can gpu automatically load the extension if CUDA is available in the environment (although not loaded)?

I haven't thought through the implications of this design, but yes it should be possible via something like

Base.locate_package(Base.PkgId({uuid}, "NNlibCUDA")) !== nothing && @eval Main using NNlibCUDA

There was one test failure remaining that I didn't understand; now fixed. I forgot to import Flux.dropout_mask, so it wasn't being extended.

Dropout RNGs: Test Failed at /var/lib/buildkite-agent/builds/gpuci-17/julialang/flux-dot-jl/test/cuda/layers.jl:285
  Expression: Flux.dropout(MersenneTwister(), CUDA.rand(Float32, 2, 3), 0.1)
  Expected: ArgumentError
  No exception thrown

@CarloLucibello
Member

CarloLucibello commented Dec 9, 2022

I haven't thought through the implications of this design, but yes it should be possible via something like

Wouldn't that run into world-age issues? That is the limitation of the (similar, though not identical) approach in LazyModules.jl.
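For illustration, a toy sketch (not Flux code; the imported package is arbitrary) of the pattern that loading code mid-call requires:

# A module imported while the current function is already running lives in a
# newer world age, so its methods must be called via `Base.invokelatest`.
function lazy_norm(x)
    @eval Main import LinearAlgebra        # load the module at runtime
    LA = getfield(Main, :LinearAlgebra)    # fetch the freshly created binding
    return Base.invokelatest(LA.norm, x)   # a plain `LA.norm(x)` can hit a world-age error
end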

@KristofferC
Contributor

It would be good to have a concrete usage example here. Like maybe a toy session in a REPL to see how a user would use it. And then based on that it might be easier to come up with a good design.

@CarloLucibello
Member

Ideally, the following could happen, which is also non-breaking:

julia> using Flux  # only cpu code imported, CUDA.jl and NNlibCUDA.jl are not loaded yet

julia> x = [1.0,2.0,3.0];
3-element Vector{Float64}:
 1.0
 2.0
 3.0

julia> gpu(x) # here the gpu code and libraries are loaded if the device has cuda capabilities            
3-element CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}:
 1.0
 2.0
 3.0

# successive calls to `gpu` are fast

@ToucheSir
Member

ToucheSir commented Dec 9, 2022

Even something like the current behaviour of warning and passing the value through when CUDA isn't loaded could work, IMO.

Maybe absorb it into NNlib and make CUDA a weak dep of NNlib?

This is theoretically possible and something we always planned to do, but nobody expected weak deps to land this quickly. The main considerations would be whether a) base NNlib and Flux load times are affected (with or without CUDA loaded), and b) whether TTFX for CUDA functionality in NNlib is impacted. If neither regresses, then it should be an easy sell.

@mcabbott
Member

mcabbott commented Dec 9, 2022

Even something like the current behaviour of warning and passing the value through when CUDA isn't loaded could work, IMO.

This seems like the obvious goal. Right now, with no graphics card, |> gpu prints a warning once and then does nothing. With the weak dep, it would ideally print a different warning if you have not loaded CUDA at all; but that might be what you wanted, which is fine.

I don't like the idea of making a function call magically load the CUDA package. Making behaviour depend on your environment, beyond what you've actually loaded, seems weird.
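For reference, a hedged sketch of the kind of pass-through fallback being discussed (not the actual Flux method):

# If no CUDA extension is loaded, warn once and return the input unchanged;
# loading CUDA would replace this with the real device conversion.
function gpu_fallback(x)
    @warn "GPU support requested, but no CUDA extension is loaded; returning the input unchanged." maxlog=1
    return x
end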

@IanButterworth
Contributor Author

Tests are passing now, btw

@mcabbott
Member

Locally, it looks like I can have just CUDAExt = "CUDA", and CUDAExt is able to import NNlibCUDA and call functions from it, or from other packages that I load only within the extension. But I didn't test very carefully; does this fail for some reason?

@IanButterworth
Contributor Author

Did you remove NNlibCUDA from [weakdeps]?

@mcabbott
Member

Yes, but I don't know if it matters. It must remain in [deps] for now I think.

@IanButterworth
Contributor Author

Well, that means it's not a weak dep, so it and CUDA will still be downloaded, installed and precompiled. You do avoid loading it until the extension loads, but personally I'm motivated by avoiding both installing CUDA etc. and loading it.

@IanButterworth
Contributor Author

For more context, using Flux on a resource-constrained system like a small embedded SBC without a GPU is painful because of all of the above, not just load time.

@mcabbott
Member

So a package listed under both deps and weakdeps (as in the present state of this PR) is effectively deleted from deps.

I'm a bit surprised that the design doesn't allow an extension to load third packages, but maybe this is hard; I didn't follow closely.

Not downloading would be nice too. But requiring anyone to know about this obscure thing called NNlibCUDA seems like a step backwards. Moving it to be an extension of NNlib is probably the way to go, then.

@CarloLucibello
Member

Having NNlibCUDA as an extension of NNlib seems the way to go. Also, we would finally move that code back to the original repo, as we have wanted to do for some time. Since NNlibCUDA was not part of NNlib, I guess we should have

# In NNlib.jl
if !isdefined(Base, :get_extension)
  # do nothing
end

and continue to import NNlibCUDA in Flux on old Julia versions.
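A hedged sketch of that complementary fallback on the Flux side (illustrative only, not the PR's exact code):

# In Flux.jl: on Julia versions without package extensions, fall back to the
# old unconditional import so behaviour is unchanged.
@static if !isdefined(Base, :get_extension)
    import NNlibCUDA
end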

@mcabbott
Member

mcabbott commented Dec 10, 2022

So the proposal is to have NNlibCUDA under both [deps] and [weakdeps] here, load it unconditionally on Julia 1.6, but don't try to load it at all on 1.10.

Then give NNlib an extension for CUDA, which has another copy of the exact same code as the package NNlibCUDA? If the registered package is moved to the same repository as NNlib, then both the package and NNlib's CUDAext can include the same source file.

One issue is that any non-Flux project which loads both NNlib and NNlibCUDA will, I think, get two copies of all the definitions.

@ToucheSir
Member

ToucheSir commented Dec 10, 2022

FluxML/NNlib.jl#445 is now up for anyone who wants to kick the tires on this. I haven't tested import times or TTFX, so any data would be appreciated.

@mcabbott
Member

BTW, timing this (on a slow machine which has a GPU, Julia nightly):

julia> @time using Flux
 21.362717 seconds (24.70 M allocations: 1.449 GiB, 5.16% gc time, 46.46% compilation time: 61% of which was recompilation)  # before
  9.096614 seconds (5.34 M allocations: 373.462 MiB, 4.08% gc time, 81.04% compilation time: 56% of which was recompilation)  # after

or, loading everything:

julia> @time using Flux, NNlibCUDA, CUDA
 21.134435 seconds (24.76 M allocations: 1.453 GiB, 3.64% gc time, 47.75% compilation time: 62% of which was recompilation)  # after

After this, the biggest offenders in @time_imports are FoldsThreads & StaticArrays, details below.

Before:

julia> @time_imports using Flux
      1.2 ms  Statistics
     16.7 ms  MacroTools
      0.2 ms  Reexport
      6.8 ms  ProgressLogging
      8.7 ms  IrrationalConstants
      0.2 ms  Compat
    138.7 ms  ChainRulesCore
      8.1 ms  DocStringExtensions 80.65% compilation time
      1.1 ms  ChangesOfVariables
      1.4 ms  InverseFunctions
      1.2 ms  LogExpFunctions
      0.2 ms  OpenLibm_jll
     31.0 ms  Preferences
      0.3 ms  JLLWrappers
    461.8 ms  OpenSpecFun_jll 115.01% compilation time (86% recompilation)
     44.9 ms  SpecialFunctions
      0.3 ms  Requires
      0.3 ms  Adapt
     56.2 ms  NNlib 62.75% compilation time (12% recompilation)
     13.8 ms  ShowCases
      1.8 ms  ConstructionBase
    180.8 ms  InitialValues
      0.2 ms  DataValueInterfaces
      1.3 ms  DataAPI
      0.1 ms  IteratorInterfaceExtensions
      0.1 ms  TableTraits
     13.4 ms  OrderedCollections
     25.2 ms  Tables
      0.2 ms  ZygoteRules
      4.7 ms  StaticArraysCore
     19.4 ms  Setfield
     87.5 ms  BangBang 58.56% compilation time
      1.6 ms  ContextVariablesX
      0.1 ms  FLoopsBase
      1.3 ms  PrettyPrint
      0.5 ms  NameResolution
     31.8 ms  MLStyle
      2.4 ms  JuliaVariables
      0.5 ms  ArgCheck
     22.5 ms  Baselet
      0.1 ms  CompositionsBase
      0.1 ms  DefineSingletons
     17.1 ms  MicroCollections
     16.4 ms  SplittablesBase
    116.1 ms  Transducers 36.62% compilation time
      7.6 ms  FLoops
     41.7 ms  Accessors 39.28% compilation time
     22.2 ms  FunctionWrappers
   1017.0 ms  FoldsThreads 193.39% compilation time
    161.1 ms  DataStructures
      0.5 ms  SortingAlgorithms
     73.7 ms  Missings
      0.4 ms  StatsAPI
     51.0 ms  StatsBase
      5.5 ms  SimpleTraits
      1.0 ms  DelimitedFiles
     11.3 ms  MLUtils
      5.0 ms  Functors
     16.4 ms  Optimisers
      2.4 ms  GPUArraysCore
      0.1 ms  RealDot
     46.6 ms  StructArrays
     37.5 ms  ChainRules
     20.1 ms  IRTools
      0.8 ms  DiffRules
      0.3 ms  NaNMath
    420.6 ms  FillArrays
     20.7 ms  AbstractFFTs
      4.2 ms  DiffResults
   1271.6 ms  StaticArrays
      0.5 ms  CommonSubexpressions
    198.6 ms  ForwardDiff
    174.3 ms  Zygote 52.62% compilation time
      8.9 ms  CEnum
    474.3 ms  LLVMExtra_jll 117.69% compilation time (84% recompilation)
    406.2 ms  LLVM 35.54% compilation time
      0.4 ms  ExprTools
    180.9 ms  TimerOutputs 19.49% compilation time
    823.8 ms  GPUCompiler 22.62% compilation time (68% recompilation)
   1797.9 ms  GPUArrays
     13.6 ms  BFloat16s
     47.7 ms  RandomNumbers 37.30% compilation time
     12.4 ms  Random123
   6794.8 ms  CUDA 2.02% compilation time
     84.9 ms  OneHotArrays
    136.9 ms  NNlibCUDA
    100.9 ms  Flux

After:

julia> @time_imports using Flux
      1.2 ms  Statistics
     41.7 ms  MacroTools
      0.2 ms  Reexport
      6.7 ms  ProgressLogging
      8.5 ms  IrrationalConstants
      0.1 ms  Compat
    136.0 ms  ChainRulesCore
      8.2 ms  DocStringExtensions 82.53% compilation time
      1.1 ms  ChangesOfVariables
      1.3 ms  InverseFunctions
      1.1 ms  LogExpFunctions
      0.2 ms  OpenLibm_jll
     29.5 ms  Preferences
      0.3 ms  JLLWrappers
    453.6 ms  OpenSpecFun_jll 113.71% compilation time (87% recompilation)
     31.0 ms  SpecialFunctions
      0.3 ms  Requires
      0.3 ms  Adapt
     54.2 ms  NNlib 66.10% compilation time (12% recompilation)
     13.3 ms  ShowCases
      1.8 ms  ConstructionBase
     31.1 ms  InitialValues
      0.1 ms  DataValueInterfaces
      1.3 ms  DataAPI
      0.1 ms  IteratorInterfaceExtensions
      0.1 ms  TableTraits
     14.2 ms  OrderedCollections
     25.8 ms  Tables
      0.2 ms  ZygoteRules
      4.9 ms  StaticArraysCore
     20.1 ms  Setfield
     89.9 ms  BangBang 57.72% compilation time
      1.6 ms  ContextVariablesX
      0.1 ms  FLoopsBase
      1.3 ms  PrettyPrint
      0.4 ms  NameResolution
     33.0 ms  MLStyle
      2.5 ms  JuliaVariables
      0.5 ms  ArgCheck
     23.0 ms  Baselet
      0.1 ms  CompositionsBase
      0.1 ms  DefineSingletons
     17.6 ms  MicroCollections
     17.3 ms  SplittablesBase
    117.3 ms  Transducers 36.25% compilation time
      7.8 ms  FLoops
     42.9 ms  Accessors 40.04% compilation time
     22.9 ms  FunctionWrappers
   1079.6 ms  FoldsThreads 178.99% compilation time
    158.7 ms  DataStructures
      0.5 ms  SortingAlgorithms
     17.1 ms  Missings
      0.4 ms  StatsAPI
     53.0 ms  StatsBase
      5.7 ms  SimpleTraits
      1.1 ms  DelimitedFiles
     11.5 ms  MLUtils
      4.9 ms  Functors
     16.3 ms  Optimisers
      2.4 ms  GPUArraysCore
      0.1 ms  RealDot
     48.5 ms  StructArrays
     44.9 ms  ChainRules
     29.6 ms  IRTools
      0.9 ms  DiffRules
      0.3 ms  NaNMath
    375.0 ms  FillArrays
     19.0 ms  AbstractFFTs
      4.5 ms  DiffResults
   1331.8 ms  StaticArrays
      0.6 ms  CommonSubexpressions
    217.7 ms  ForwardDiff
    222.6 ms  Zygote 42.27% compilation time
     78.4 ms  OneHotArrays
     58.7 ms  Flux

@ToucheSir
Member

In the scenario where we're unable to resolve the import time issues in FluxML/NNlib.jl#445, I feel the incremental solution would be to keep NNlibCUDA as a normal dep and not have it be a weak dep. Precompilation of CUDA.jl will still be unavoidable, but at least SBC users can still benefit from the reduced import times (alongside the much larger group of non-SBC users on non-GPU machines).

@IanButterworth
Contributor Author

I was just wondering if the plan here has settled? It would be great to not have to install the GPU stack on SBCs that I use!

@ToucheSir
Member

ToucheSir commented Apr 20, 2023

Then as now, I don't foresee any path for us to save users from installing the CUDA.jl stack before dropping support for Julia versions <1.9. That said, NNlibCUDA is likely to become an extension in the near future. The current hold-up there is figuring out what the fallback path would look like on the Flux side. I've seen some people using trigger packages, but word from above suggests those are verboten...

For this particular PR, my recommendation would be to remove anything related to NNlibCUDA and convert the Flux-specific CUDA bits to an extension like we have for AMDGPU. It won't stop the GPU stack from being installed, but it'll get us one step closer.
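To illustrate that recommendation, a hedged sketch of what a Flux-specific CUDA extension might contain (names are illustrative, similar in shape to the sketch near the top of this thread; the extension that eventually landed may differ):

# Hypothetical ext/FluxCUDAExt.jl holding only the Flux-specific CUDA bits.
module FluxCUDAExt

using Flux, CUDA

# Move an array to the GPU only when a CUDA device is actually usable,
# otherwise return it unchanged.
maybe_cu(x::AbstractArray) = CUDA.functional() ? CUDA.cu(x) : x

end # module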

@IanButterworth
Contributor Author

I just tried to rebase this and it has become pretty messy, unfortunately. Is CUDA lazily loaded now?

Also, my main motivation was to avoid CUDA installation and Pkg.precompilation, so if that's not coming in the near future I think it's best if someone else takes this over afresh.

@IanButterworth deleted the ib/cudaext branch on June 13, 2023 at 19:04
@ToucheSir
Member

Thanks for pushing this, @IanButterworth! I think now that FluxML/NNlib.jl#492 is merged and #2265 has some active effort behind it, we might be able to release something sooner than later.
