@@ -6,7 +6,10 @@ Contains pure numerical operations with no physics dependencies.
66"""
77module Utilities
88
9+ import UnrolledUtilities as UU
10+
911export clamp_to_nonneg, ϵ_numerics, ϵ_numerics_2M_M, ϵ_numerics_2M_N
12+ export unrolled_logsumexp
1013
1114"""
1215 clamp_to_nonneg(x)
@@ -44,4 +47,77 @@ Numerical epsilon for 2-moment number calculations.
4447"""
@inline ϵ_numerics_2M_N(FT) = eps(FT)  # machine epsilon of the floating-point type `FT`
4649
50+
"""
    unrolled_logsumexp(x)

Compute `log(sum(exp, x))` in a statically unrolled fashion.

This method uses [`UnrolledUtilities`](https://github.com/CliMA/UnrolledUtilities.jl)
to produce fully unrolled code with no dynamic dispatch or reductions,
making it transparent to GPU compilers.

The standard shift-by-max trick is used for numerical stability: the maximum of `x`
is subtracted from every element before exponentiating and added back after the
logarithm.

If the maximum of `x` is not finite (`+Inf`, `-Inf` or `NaN`), it is returned directly
instead of evaluating the shifted sum.

Note: This code is two-pass (find the maximum, then sum the shifted exponentials).
LogExpFunctions.jl implements a one-pass version, but it is not unrolled,
so it may result in more complicated GPU code.

## Extended help

Other implementation options were considered. They are detailed below and may be
revisited in the future. For now, this implementation is sufficient.

### Naive implementation

This is the most straightforward implementation, but it is not numerically stable.

```julia
log(UU.unrolled_sum(exp, x))
```

### One-pass unrolled implementation

This reaches into LogExpFunctions.jl internals (underscore-prefixed functions):

```julia
FT = eltype(x)
return LogExpFunctions._logsumexp_onepass_result(
    UU.unrolled_reduce(LogExpFunctions._logsumexp_onepass_op, x, (FT(-Inf), zero(FT)))
)
```

### Dispatch-wrapper for reduce

Wrap the input in a type whose `Base.reduce` method forwards to `UU.unrolled_reduce`,
with the intent that `LogExpFunctions.logsumexp` then compiles to an unrolled reduction.

```julia
# Note: This is a sketch, not tested
struct UnrolledWrapper{T}
    x::T
end
Base.iterate(w::UnrolledWrapper) = iterate(w.x)
Base.iterate(w::UnrolledWrapper, state) = iterate(w.x, state)
Base.length(w::UnrolledWrapper) = length(w.x)
Base.eltype(w::UnrolledWrapper) = eltype(w.x)
Base.reduce(op, w::UnrolledWrapper) = UU.unrolled_reduce(op, w.x)  # use unrolled reduce
# ... then call:
LogExpFunctions.logsumexp(UnrolledWrapper(x))
```
"""
function unrolled_logsumexp(x)
    # Find the maximum (if any element is NaN, then xmax = NaN)
    xmax = UU.unrolled_maximum(x)

    # Handle non-finite values: if xmax is +Inf, -Inf or NaN, return it directly
    # (avoids Inf - Inf = NaN and x - NaN = NaN in the shifted exponentials below)
    isfinite(xmax) || return xmax

    # Sum the exponentials shifted by the maximum, so every exponent is <= 0
    shifted_exp(xi) = exp(xi - xmax)
    s = UU.unrolled_sum(shifted_exp, x)

    # Undo the shift: log(sum(exp, x)) == xmax + log(sum(exp(xi - xmax)))
    return xmax + log(s)
end
121+
122+
47123end # module
0 commit comments