Description
One major difficulty of using CUDA.jl + anything Bayesian is that you immediately need to define all the (at the very least, univariate) distributions all over again. But for starters we don't need all of the functionality of a `Distribution` to do something like Bayesian inference, e.g. for AdvancedHMC.jl we really only need `logpdf` and its adjoint, as the only call to `rand` is going to be for the momentum (which can be sampled directly on the GPU using `CUDA.randn`).
But even trying to redefine the `logpdf` of a `Distribution` to work on the GPU is often non-trivial.
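For example (a sketch of the end goal; the arrays `α`, `θ`, `x` and the `logdensity` function below are hypothetical), this is essentially all AdvancedHMC.jl needs from the model:

```julia
using CUDA, StatsFuns

α, θ = CUDA.rand(Float32, 1000), CUDA.rand(Float32, 1000)  # hypothetical parameters
x = CUDA.rand(Float32, 1000)                               # hypothetical observations

# The only `logpdf`-style call we need (plus its adjoint via AD) -- and exactly
# the kind of broadcast this issue is about getting to work on the GPU:
logdensity(x) = sum(StatsFuns.gammalogpdf.(α, θ, x))

# The only `rand` we need: the momentum refresh, sampled directly on the GPU.
momentum = CUDA.randn(Float32, 1000)
```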
Issue #1
Before going into the issue, it's important to know the following:
- All functions that you want to vectorize/broadcast need to be replaced with the corresponding CUDA function, e.g. `Base.log` and `CUDA.log` are actually different methods.
- AFAIK, CUDA.jl achieves this by effectively converting all expressions of the form `f.(args...)` to `CUDA.cufunc(f).(args...)` whenever any of the `args` is a `CUDA.CuArray`, by overloading `Broadcast.broadcasted`. E.g. `CUDA.cufunc(::typeof(log)) = CUDA.log`.
- Therefore, in the case where a function `f` does not have a `cufunc` already defined and you do `f.(args...)`, you'll, if you're lucky, get an error, but sometimes the entire Julia session will crash.
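As a concrete (and heavily simplified) sketch of this mechanism, assuming the CUDA.jl version discussed in this issue (where `CUDA.cufunc` still exists) and a hypothetical scalar function `mylogit`:

```julia
using CUDA

# A scalar function defined in terms of Base.log:
mylogit(x) = log(x / (1 - x))

# A GPU-friendly twin that calls the CUDA intrinsic instead. Registering it via
# `cufunc` is what makes broadcasting over a CuArray pick it up, since CUDA.jl's
# `Broadcast.broadcasted` overload rewrites `f` to `cufunc(f)`:
cu_mylogit(x) = CUDA.log(x / (1 - x))
CUDA.cufunc(::typeof(mylogit)) = cu_mylogit

xs = CUDA.rand(Float32, 8)
mylogit.(xs)  # runs cu_mylogit element-wise on the GPU
```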
So what do we do?
- Well, we define a `cufunc(::typeof(f)) = ...` which will allow you to broadcast over
  - simple custom-defined functions (e.g. `digamma` from this PR by @xukai92: Gamma family function support JuliaGPU/CuArrays.jl#321 (comment)),
  - other native CUDA functions, e.g. `loggamma` can be replaced by `CUDA.lgamma`.
- Great stuff! Then you go and try to take the gradient of this thing and everything breaks again. So you need to also define rules for this `f` (see the sketch after this list).
  - AFAIK, for `f` on GPU, Zygote.jl uses ForwardDiff.jl to obtain the adjoints for broadcasting, and so we gotta define these rules using `DiffRules` and evaluate them using `ForwardDiff`.
  - You then try to figure out where the current `DiffRules` definition is for `f` and you copy-paste it, replacing methods with `cufunc` methods.
- Okay, so at this point we have all these lower-level functions with their corresponding `cufunc` definitions and their corresponding `@define_diffrule`, so we're good to go, right? Now we can just call `StatsFuns.gammalogpdf.(α, θ, x)`, right?
  - Answer: No. AFAIK, there are several things that can fail:
    - A lot of the functions used within the Distributions.jl ecosystem are not pure Julia under the hood, but there are often pure-Julia versions for more generic number types so that one can, for example, just AD through them without issues, e.g. https://github.com/JuliaStats/StatsFuns.jl/blob/8dfda2c0ee33d5f85eca5c039d31d85c90f363f2/src/distrs/gamma.jl#L19. BUT this doesn't help compat with CUDA.jl because element types of a `CUDA.CuArray` aren't special, i.e. it's just a `Float32`. And so the function we dispatch on when broadcasting over a `CUDA.CuArray` will be some function outside of the Julia ecosystem, and so things start blowing up. EDIT: this is only an issue for eltype `Float64`, not `Float32`, as pointed out by @devmotion!
    - Observing this, we specialize by overloading the method for `Float32` and so on to use the pure-Julia implementation, e.g.
      ```julia
      StatsFuns.gammalogpdf(k::Float32, θ::Float32, x::Float32) = -SpecialFunctions.loggamma(k) - k * log(θ) + (k - 1) * log(x) - x / θ
      ```
    - BUT it still fails because the overloading of `broadcasted` won't be properly nested, so `cufunc` will only be called on `StatsFuns.gammalogpdf`, not on the methods used within! So, we instead do
      ```julia
      cugammalogpdf(k, θ, x) = -CUDA.cufunc(SpecialFunctions.loggamma)(k) - k * CUDA.cufunc(log)(θ) + (k - 1) * CUDA.cufunc(log)(x) - x / θ
      CUDA.cufunc(::typeof(StatsFuns.gammalogpdf)) = cugammalogpdf
      ```
      which is really not fun.
    - I.e. it's not enough to define `cufunc` for the leaves of the method hierarchy! This sucks.
  - Upshot: for most AD-frameworks (Zygote.jl and ForwardDiff.jl, because Zygote.jl uses ForwardDiff.jl for broadcasting) we get the AD-rules for all non-leaves in the method hierarchy this way.
- Of course there is work on doing automatic method substitution (Method overrides using a method overlay table, JuliaGPU/GPUCompiler.jl#122), but in the meantime we need to do the above process.
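To make the `DiffRules`/`ForwardDiff` step from the list above concrete, here is a rough sketch for `loggamma`. The name `culoggamma` is made up, the rule assumes a `cufunc` exists for `digamma` (e.g. from the PR mentioned above), and `ForwardDiff.unary_dual_definition` is an internal helper whose exact name and behaviour depend on the ForwardDiff.jl version:

```julia
using CUDA, DiffRules, ForwardDiff, SpecialFunctions

# 1. A GPU-friendly replacement for the scalar function itself:
culoggamma(x) = CUDA.lgamma(x)
CUDA.cufunc(::typeof(SpecialFunctions.loggamma)) = culoggamma

# 2. Copy-paste the existing scalar rule (d/dx loggamma(x) = digamma(x)),
#    swapping the functions inside the rule for their `cufunc` counterparts:
DiffRules.@define_diffrule Main.culoggamma(x) = :(CUDA.cufunc(SpecialFunctions.digamma)($x))

# 3. Generate the corresponding Dual-number method so that ForwardDiff (which
#    Zygote falls back to for broadcasted adjoints) actually uses the new rule:
eval(ForwardDiff.unary_dual_definition(:Main, :culoggamma))
```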
Potential solutions
1. Do it by hand, but with the help of an "improved" version of the `CUDA.@cufunc` macro that I've implemented (https://github.com/JuliaGPU/CUDA.jl/blob/c011ffc0971ab1089f9d56dd338ef4b31e24ecc7/src/broadcast.jl#L101-L112), which has the following additional features:
   - Replaces all functions `f` in the body with `cufunc(f)`, with the default impl of `cufunc(f) = f`. I.e. do nothing to almost all methods, but those which have a `cufunc` impl we replace.
   - Correctly handles namespaces, e.g. `@cufunc SpecialFunctions.gamma(x) = ...` is converted into `cugamma(x) = ...; cufunc(::typeof(SpecialFunctions.gamma)) = cugamma`.
   - [OPTIONAL] If `f` is present in `DiffRules.diffrules()`, then we extract the corresponding diffrule and replace all functions `g` within the diffrule with `cufunc(g)`. I.e. IF there is a scalar rule for `f`, then we make it CUDA compatible (assuming the methods in the rule have a `cufunc` implementation), otherwise we leave it to ForwardDiff.jl to figure it out.
2. Wait until the work on method substitution is done.
Personally, I'm in favour of solution (1).
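As an illustration of what solution (1) buys you, a sketch for the `gammalogpdf` example from Issue #1; the usage and the expansion shown here are the proposed behaviour, not what the linked `CUDA.@cufunc` currently does:

```julia
using CUDA, SpecialFunctions, StatsFuns

# With the improved macro you would write something like:
#
#     @cufunc StatsFuns.gammalogpdf(k, θ, x) =
#         -SpecialFunctions.loggamma(k) - k * log(θ) + (k - 1) * log(x) - x / θ
#
# ...and it would expand to roughly the following: a `cu`-prefixed function with
# every call wrapped in `cufunc` (a no-op for functions without a GPU-specific
# replacement, thanks to the fallback `cufunc(f) = f`), plus the overload that
# makes broadcasting over CuArrays dispatch to it:
cugammalogpdf(k, θ, x) =
    -CUDA.cufunc(SpecialFunctions.loggamma)(k) - k * CUDA.cufunc(log)(θ) +
    (k - 1) * CUDA.cufunc(log)(x) - x / θ
CUDA.cufunc(::typeof(StatsFuns.gammalogpdf)) = cugammalogpdf
```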
Issue #2
Now, there's also an additional "annoyance" even after solving the above issue. We cannot do something like `logpdf.(Gamma.(α, θ), x)` because this will first do `map(Gamma, ...)` before calling `logpdf`. There's the possibility that this could have been inlined, completely removing the call to `Gamma` once it's sufficiently lowered, but GPUCompiler.jl will complain before it reaches that stage (as this is not always a guarantee + I believe it will try to fuse all the broadcasts together into a single operation for efficiency). Therefore we either need to:
- Use the underlying methods directly, e.g. `gammalogpdf.(α, θ, x)`.
- Define a `Vectorize(D, args)`, e.g. `Vectorize(Gamma, (α, θ))`, which has a `logpdf` that lazily calls the underlying method, e.g. `logpdf(v::Vectorize{Gamma}, x) = gammalogpdf.(v.args..., x)` (see the sketch below). Equipped with this, we can speed up implementation quite a bit by potentially doing something like:
  - Overload `broadcasted` so that if we're using the `CUDA.CuArrayStyle` and `f <: UnivariateDistribution` we can `materialize` the `args` earlier and then wrap them in `Vectorize`, i.e. `Vectorize(f, args)`.
  - Define a macro similar to `Distributions.@__delegate_statsfuns` or whatever to more easily define `logpdf(v::Vectorize{D}, x)` for different `D`.

  Worth mentioning that this requires a small redef of this method in Zygote (https://github.com/FluxML/Zygote.jl/blob/2b17256e79b2eca9a6512207284219d279398fc9/src/lib/broadcast.jl#L225-L228), though it should def. be possible to make it work even though we're overloading `broadcasted`.
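A minimal sketch of what such a `Vectorize` wrapper could look like (the type and the `Gamma` specialization follow the proposal above and are not an existing API; the GPU broadcast at the end still assumes Issue #1 is solved for `gammalogpdf`):

```julia
using CUDA, StatsFuns
using Distributions: Gamma
import Distributions: logpdf

# Lazy "vectorized distribution": store the constructor and its parameter arrays
# instead of materializing an array of Distribution objects inside the broadcast.
struct Vectorize{D,T}
    args::T
end
Vectorize(::Type{D}, args::T) where {D,T} = Vectorize{D,T}(args)

# Delegate straight to the underlying StatsFuns method, so the kernel only ever
# sees plain arrays and scalar functions:
logpdf(v::Vectorize{Gamma}, x) = StatsFuns.gammalogpdf.(v.args..., x)

# Usage -- instead of `logpdf.(Gamma.(α, θ), x)`:
α, θ, x = CUDA.rand(Float32, 100), CUDA.rand(Float32, 100), CUDA.rand(Float32, 100)
logpdf(Vectorize(Gamma, (α, θ)), x)
```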