Skip to content

Commit d9669a2

Browse files
committed
fix(P3): Manually unroll LogExpFunctions.logsumexp
- To improve compilation of some GPU kernels, provide an unrolled implementation of logsumexp
1 parent 4d2e2ef commit d9669a2

File tree

5 files changed

+294
-151
lines changed

5 files changed

+294
-151
lines changed

src/P3.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ import SpecialFunctions as SF
1111
import RootSolvers as RS
1212
import LogExpFunctions
1313
import StaticArrays as SA
14+
import UnrolledUtilities as UU
1415

1516
import ClimaParams as CP
1617

@@ -20,6 +21,7 @@ import CloudMicrophysics.Common as CO
2021
import CloudMicrophysics.DistributionTools as DT
2122
import CloudMicrophysics.HetIceNucleation as CM_HetIce
2223
import CloudMicrophysics.Microphysics2M as CM2
24+
import CloudMicrophysics.Utilities as UT
2325
import CloudMicrophysics: ShowMethods
2426

2527
include("P3_particle_properties.jl")

src/P3_size_distribution.jl

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -139,11 +139,12 @@ Compute `log(∫_0^∞ Dⁿ m(D) N′(D) dD)` given the `state` and `logλ`.
139139
"""
140140
function logmass_gamma_moment(state::P3State, μ, logλ; n = 0)
141141
segments = get_segments(state)
142-
return LogExpFunctions.logsumexp(
143-
let (D_min, D_max) = segment, (a, b) = ice_mass_coeffs(state, (D_min + D_max) / 2)
144-
loggamma_inc_moment(D_min, D_max, μ, logλ, b + n, a)
145-
end for segment in segments
146-
)
142+
moments = UU.unrolled_map(segments) do segment
143+
(D_min, D_max) = segment
144+
(a, b) = ice_mass_coeffs(state, (D_min + D_max) / 2)
145+
loggamma_inc_moment(D_min, D_max, μ, logλ, b + n, a)
146+
end
147+
return UT.unrolled_logsumexp(moments)
147148
end
148149

149150
"""

src/Utilities.jl

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,10 @@ Contains pure numerical operations with no physics dependencies.
66
"""
77
module Utilities
88

9+
import UnrolledUtilities as UU
10+
911
export clamp_to_nonneg, ϵ_numerics, ϵ_numerics_2M_M, ϵ_numerics_2M_N
12+
export unrolled_logsumexp
1013

1114
"""
1215
clamp_to_nonneg(x)
@@ -44,4 +47,77 @@ Numerical epsilon for 2-moment number calculations.
4447
"""
4548
@inline ϵ_numerics_2M_N(FT) = eps(FT)
4649

50+
51+
"""
52+
unrolled_logsumexp(x)
53+
54+
Compute `log(sum(exp, x))` in a statically unrolled fashion.
55+
56+
This method uses [`UnrolledUtilities`](https://github.com/CliMA/UnrolledUtilities.jl)
57+
to produce fully unrolled code with no dynamic dispatch or reductions,
58+
making it transparent to GPU compilers.
59+
60+
The standard shift-by-max trick is used for numerical stability.
61+
62+
Note: This code is two-pass (find max, then sum shifted exponentials).
63+
LogExpFunctions.jl implements a one-pass version, but is not unrolled,
64+
so may result in more complicated GPU code.
65+
66+
## Extended help
67+
68+
Other implementation options were considered, detailed below, and may be revisited in the future.
69+
For now, this implementation is sufficient.
70+
71+
### Naive implementation
72+
73+
This is the most straightforward implementation, but it is not numerically stable.
74+
75+
```julia
76+
log(UU.unrolled_sum(exp, x))
77+
```
78+
79+
### One-pass unrolled implementation
80+
81+
This is reaches into LogExpFunctions.jl internals,
82+
83+
```julia
84+
FT = eltype(x)
85+
return LogExpFunctions._logsumexp_onepass_result(
86+
UU.unrolled_reduce(LogExpFunctions._logsumexp_onepass_op, x, (FT(-Inf), zero(FT)))
87+
)
88+
```
89+
90+
### Dispatch-wrapper for reduce
91+
92+
Pass a wrapper to compile to unrolled reduce
93+
```julia
94+
# Note: This is a sketch, not tested
95+
struct UnrolledWrapper{T}
96+
x::T
97+
end
98+
Base.iterate(w::UnrolledWrapper) = iterate(w.x)
99+
Base.iterate(w::UnrolledWrapper, state) = iterate(w.x, state)
100+
Base.length(w::UnrolledWrapper) = length(w.x)
101+
Base.eltype(w::UnrolledWrapper) = eltype(w.x)
102+
Base.reduce(op, w::UnrolledWrapper) = UU.unrolled_reduce(op, w.x) # use unrolled reduce
103+
# ... then call:
104+
LogExpFunctions.logsumexp(UnrolledWrapper(x))
105+
```
106+
"""
107+
function unrolled_logsumexp(x)
108+
# Find the maximum (ps: if any element is NaN, then xmax = NaN)
109+
xmax = UU.unrolled_maximum(x)
110+
111+
# Handle non-finite values: if xmax is +Inf or -Inf or NaN, return it directly
112+
# (avoids Inf - Inf = NaN and x - NaN = NaN in the shifted exponentials below)
113+
isfinite(xmax) || return xmax
114+
115+
# Sum shifted exponentials
116+
shifted_exp(xi) = exp(xi - xmax)
117+
s = UU.unrolled_sum(shifted_exp, x)
118+
119+
return xmax + log(s)
120+
end
121+
122+
47123
end # module

0 commit comments

Comments
 (0)