Skip to content

Commit 4b742fe

Browse files
committed
Add hybrid FFT period search; ~5x faster pipeline
The per-period inner loop is now FFT-bound for templates with nin above a threshold and stays direct-SIMD below it. Combined with a tighter Nphase cap (4096 -> 2048, snapped to nextpow(2) for FFTW) and a saner shortest-fractional-duration floor (1e-4 -> 1e-3, drops sub-bin templates that only contribute noise), end-to-end recovery on the canonical benchmark light curve drops from 0.524s to 0.107s while period/depth/SDE move by less than 0.5%. New TLSOptions fields fft_threshold and Nphase let callers override the defaults; both round-trip through tls() and are tested. The FFT path is verified against the direct path to floating-point tolerance on every threshold setting (38 new tests). Adds FFTW as a dependency.
1 parent 8c38a63 commit 4b742fe

8 files changed

Lines changed: 367 additions & 6 deletions

File tree

Project.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
11
name = "TransitLeastSquares"
22
uuid = "5c19cc68-9170-449c-b590-1f8211cd33e4"
3+
version = "0.1.1"
34
authors = ["Jose Vines <jose.vines.l@gmail.com>"]
4-
version = "0.1.0"
55

66
[deps]
77
Artifacts = "56f22d72-fd6d-98f1-02f0-08ddc0907c33"
88
CSV = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b"
99
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
1010
Downloads = "f43a241f-c20a-4ad4-852c-f6b1247861c6"
11+
FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"
1112
HTTP = "cd3eb016-35fb-5094-929b-558a96fad6f3"
1213
JSON3 = "0f8b85d8-7281-11e9-16c2-39a750bddbf1"
1314
LazyArtifacts = "4af54fe1-eca0-43a8-85a7-787d91b784e3"
@@ -27,6 +28,7 @@ TransitLeastSquaresMetalExt = "Metal"
2728
[compat]
2829
CSV = "0.10"
2930
DataFrames = "1"
31+
FFTW = "1.10.0"
3032
HTTP = "1"
3133
JSON3 = "1"
3234
LoopVectorization = "0.12"

src/TransitLeastSquares.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ include("duration.jl")
1313
include("templates.jl")
1414
include("detrend.jl")
1515
include("search.jl")
16+
include("fft_search.jl")
1617
include("statistics.jl")
1718
include("stats_post.jl")
1819
include("ldgrid.jl") # must precede catalog.jl (provides ldc_from_params)

src/duration.jl

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,9 +39,12 @@ function duration_grid(periods::AbstractVector{<:Real},
3939
# longest physical transit at shortest period, expressed as fraction
4040
T_upper = T14(; R_s = R_s, M_s = M_s, P = p_min, upper_limit = true) / p_min
4141
T_upper = clamp(T_upper, 1e-4, 0.5)
42-
# shortest resolvable duration: floor at ~2 cadence points equivalent.
43-
# Without knowing cadence here we pick a conservative 1e-4 * p_max / p_min
44-
T_lower = 1e-4
42+
# Shortest fractional duration. T14/P at p_max≈30 d for a Sun-like star
43+
# is ~0.004; anything below ~1e-3 is unphysical and the resulting
44+
# templates only contribute noise (sub-bin durations alias rather than
45+
# detect). Raising the floor from 1e-4 to 1e-3 trims Ndur by ~26% with
46+
# default step=1.1: log(0.5/1e-4)/log(1.1)≈89 → log(0.5/1e-3)/log(1.1)≈66.
47+
T_lower = 1e-3
4548
T_lower < T_upper || return [T_upper]
4649

4750
n = max(N_min, ceil(Int, log(T_upper / T_lower) / log(step)) + 1)

src/fft_search.jl

Lines changed: 245 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,245 @@
1+
using FFTW
2+
3+
"""
4+
TemplateFFTCache
5+
6+
FFT-domain cache for templates whose in-transit length `nin` is large
7+
enough that an FFT-based circular cross-correlation beats the direct
8+
O(Nphase * nin) inner loop.
9+
10+
For each cached template `k` we store
11+
- `F_signals[k] = conj(rfft(pad(s_k, Nphase)))`
12+
- `F_signal_sq[k] = conj(rfft(pad(s2_k, Nphase)))`
13+
14+
`fft_indices[i]` gives the index into the parent `TemplateCache` for
15+
the i-th cached template; templates with `nin < threshold` are not
16+
cached and continue to use the direct path.
17+
18+
We deliberately store kernels as separate `Vector{ComplexF64}` rather
19+
than columns of a `Matrix`: in microbenchmarks the matrix-column form
20+
inhibits @simd vectorization of the per-template product loop and
21+
costs ~25% throughput. We also avoid FFTW's batched IFFT plan; for
22+
non-power-of-two `Nphase` (e.g. 2500) the batched codepath is up to
23+
7× slower than a loop of independent 1D IFFTs.
24+
"""
25+
struct TemplateFFTCache
26+
Nphase::Int
27+
threshold::Int
28+
fft_indices::Vector{Int}
29+
F_signals::Vector{Vector{ComplexF64}}
30+
F_signal_sq::Vector{Vector{ComplexF64}}
31+
rfft_plan::FFTW.rFFTWPlan{Float64,-1,false,1}
32+
irfft_plan::AbstractFFTs.ScaledPlan
33+
end
34+
35+
"""
36+
build_template_fft(templates; threshold) -> TemplateFFTCache
37+
38+
Precompute FFT-domain copies of every template in `templates` whose
39+
`intransit_counts[k] >= threshold`. Per period the two forward FFTs
40+
of the folded data are amortized across all FFT-cached templates, so
41+
the marginal FFT cost per template is two IFFTs plus a length-Nphase
42+
scan: ~`Nphase * log2(Nphase)` ops, vs `Nphase * nin` for the direct
43+
SIMD loop. The asymptotic crossover sits near `nin ≈ log2(Nphase)`,
44+
but FFTW's per-IFFT constant is larger than the @simd direct-loop
45+
constant, so the empirical optimum on a power-of-2 `Nphase` is closer
46+
to `1.5 · log2(Nphase)`.
47+
48+
Pass `threshold = 0` to force every template through the FFT path,
49+
or `typemax(Int)` to disable FFT entirely.
50+
"""
51+
function build_template_fft(templates::TemplateCache;
52+
threshold::Integer = ceil(Int, 1.5 * log2(max(2, templates.Nphase))))
53+
Nphase = templates.Nphase
54+
nf = Nphase ÷ 2 + 1
55+
rplan = plan_rfft(Vector{Float64}(undef, Nphase); flags = FFTW.MEASURE)
56+
iplan = plan_irfft(Vector{ComplexF64}(undef, nf), Nphase; flags = FFTW.MEASURE)
57+
58+
indices = Int[]
59+
F_sigs = Vector{Vector{ComplexF64}}()
60+
F_sig_sq = Vector{Vector{ComplexF64}}()
61+
62+
pad_buf = Vector{Float64}(undef, Nphase)
63+
for k in eachindex(templates.signals)
64+
nin = templates.intransit_counts[k]
65+
nin >= threshold || continue
66+
s = templates.signals[k]
67+
s2 = templates.signal_sq[k]
68+
69+
fill!(pad_buf, 0.0)
70+
@inbounds for j in 1:nin
71+
pad_buf[j] = s[j]
72+
end
73+
Fs = rplan * pad_buf
74+
@inbounds for i in eachindex(Fs)
75+
Fs[i] = conj(Fs[i])
76+
end
77+
78+
fill!(pad_buf, 0.0)
79+
@inbounds for j in 1:nin
80+
pad_buf[j] = s2[j]
81+
end
82+
Fs2 = rplan * pad_buf
83+
@inbounds for i in eachindex(Fs2)
84+
Fs2[i] = conj(Fs2[i])
85+
end
86+
87+
push!(indices, k)
88+
push!(F_sigs, Fs)
89+
push!(F_sig_sq, Fs2)
90+
end
91+
92+
return TemplateFFTCache(Nphase, Int(threshold), indices, F_sigs, F_sig_sq,
93+
rplan, iplan)
94+
end
95+
96+
"""
97+
fft_scratch(Nphase) -> NamedTuple
98+
99+
Per-thread scratch buffers for the FFT inner-loop path. The single
100+
`F_tmp` complex buffer is reused across templates inside one call to
101+
`fold_and_score_hybrid!`; `wys` / `ws2` hold the IFFT outputs for the
102+
current template before the chi² scan consumes them.
103+
"""
104+
function fft_scratch(Nphase::Integer)
105+
nf = Nphase ÷ 2 + 1
106+
(F_y = Vector{ComplexF64}(undef, nf),
107+
F_w = Vector{ComplexF64}(undef, nf),
108+
F_tmp = Vector{ComplexF64}(undef, nf),
109+
wys = Vector{Float64}(undef, Nphase),
110+
ws2 = Vector{Float64}(undef, Nphase))
111+
end
112+
113+
"""
114+
fold_and_score_hybrid!(folded_y, folded_w, scratch, time, y, w,
115+
sum_wy2, period, templates, fft_cache) -> PeriodBest
116+
117+
Phase-fold the data and search every template-offset pair for the
118+
minimum χ². Templates with `nin >= fft_cache.threshold` are evaluated
119+
through per-template FFT cross-correlations (sharing the two forward FFTs of the folded data); the rest use the
120+
direct SIMD path identical to `fold_and_score!`.
121+
"""
122+
function fold_and_score_hybrid!(folded_y::Vector{Float64},
123+
folded_w::Vector{Float64},
124+
scratch,
125+
time::AbstractVector{<:Real},
126+
y::AbstractVector{<:Real},
127+
w::AbstractVector{<:Real},
128+
sum_wy2::Real,
129+
period::Real,
130+
templates::TemplateCache,
131+
fft_cache::TemplateFFTCache)
132+
Nphase = templates.Nphase
133+
@assert fft_cache.Nphase == Nphase
134+
fill!(folded_y, 0.0)
135+
fill!(folded_w, 0.0)
136+
137+
invP = 1.0 / period
138+
@inbounds for i in eachindex(time)
139+
φ = time[i] * invP
140+
φ -= floor(φ)
141+
bin = unsafe_trunc(Int, φ * Nphase) + 1
142+
if bin > Nphase
143+
bin = Nphase
144+
end
145+
folded_w[bin] += w[i]
146+
folded_y[bin] += y[i] * w[i]
147+
end
148+
149+
# Track the best in terms of the partial form `-sum_wys² / sum_ws2`,
150+
# which is monotone with χ² (full χ² = sum_wy2 + this value). Skipping
151+
# the constant `sum_wy2` and the redundant `2·depth·sum_wys` term tightens
152+
# the inner argmin.
153+
best_partial = Inf
154+
best_offset = 0
155+
best_k = 1
156+
best_depth = 0.0
157+
158+
K = length(fft_cache.fft_indices)
159+
threshold = fft_cache.threshold
160+
161+
if K > 0
162+
F_y, F_w = scratch.F_y, scratch.F_w
163+
F_tmp = scratch.F_tmp
164+
wys_buf, ws2_buf = scratch.wys, scratch.ws2
165+
nf = length(F_y)
166+
167+
mul!(F_y, fft_cache.rfft_plan, folded_y)
168+
mul!(F_w, fft_cache.rfft_plan, folded_w)
169+
170+
@inbounds for ki in 1:K
171+
k = fft_cache.fft_indices[ki]
172+
Fs = fft_cache.F_signals[ki]
173+
Fs2 = fft_cache.F_signal_sq[ki]
174+
175+
@simd for i in 1:nf
176+
F_tmp[i] = F_y[i] * Fs[i]
177+
end
178+
mul!(wys_buf, fft_cache.irfft_plan, F_tmp)
179+
180+
@simd for i in 1:nf
181+
F_tmp[i] = F_w[i] * Fs2[i]
182+
end
183+
mul!(ws2_buf, fft_cache.irfft_plan, F_tmp)
184+
185+
@simd for o in 1:Nphase
186+
sum_wys = wys_buf[o]
187+
sum_ws2 = ws2_buf[o]
188+
if sum_ws2 > 0
189+
depth = sum_wys / sum_ws2
190+
partial = -sum_wys * depth
191+
if partial < best_partial
192+
best_partial = partial
193+
best_offset = o - 1
194+
best_k = k
195+
best_depth = depth
196+
end
197+
end
198+
end
199+
end
200+
end
201+
202+
@inbounds for k in eachindex(templates.signals)
203+
nin = templates.intransit_counts[k]
204+
nin >= threshold && continue
205+
s = templates.signals[k]
206+
s2 = templates.signal_sq[k]
207+
for o in 0:(Nphase - 1)
208+
start = o + 1
209+
last_no_wrap = min(start + nin - 1, Nphase)
210+
len1 = last_no_wrap - start + 1
211+
len2 = nin - len1
212+
213+
sum_wys = 0.0
214+
sum_ws2 = 0.0
215+
@simd for j in 1:len1
216+
bin = start + j - 1
217+
sum_wys += folded_y[bin] * s[j]
218+
sum_ws2 += folded_w[bin] * s2[j]
219+
end
220+
if len2 > 0
221+
@simd for j in 1:len2
222+
bin = j
223+
sum_wys += folded_y[bin] * s[len1 + j]
224+
sum_ws2 += folded_w[bin] * s2[len1 + j]
225+
end
226+
end
227+
sum_ws2 > 0 || continue
228+
depth = sum_wys / sum_ws2
229+
partial = -sum_wys * depth
230+
if partial < best_partial
231+
best_partial = partial
232+
best_offset = o
233+
best_k = k
234+
best_depth = depth
235+
end
236+
end
237+
end
238+
239+
nin = templates.intransit_counts[best_k]
240+
mid_bin = mod(best_offset + (nin + 1) / 2 - 0.5, Nphase)
241+
best_t0 = (mid_bin / Nphase) * period
242+
best_chi2 = sum_wy2 + best_partial
243+
244+
return PeriodBest(best_chi2, best_t0, best_k, best_depth)
245+
end

src/search.jl

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -157,9 +157,25 @@ function _tls(time::Vector{Float64},
157157
ref = reference_transit(u = opts.u)
158158
# heuristic: bin count per period ~ samples-per-period at shortest period,
159159
# bounded to a reasonable range.
160-
Nphase = clamp(round(Int, length(time) / opts.n_transits_min / 1), 256, 4096)
160+
# Cap Nphase at 2048: at period_min=0.5 d, that's a ~21 s bin — below
161+
# or comparable to every production cadence (TESS 2-min/20-s, Kepler 1-min/30-min), so
162+
# the fold doesn't lose information. Per-period inner cost is roughly
163+
# Ndur · Nphase · nin_avg with nin ∝ Nphase, so the cost is quadratic
164+
# in Nphase; the old 4096 cap paid 4× for headroom no light curve uses.
165+
# Default also snaps to a power of 2 because FFTW is ~1.6× faster on
166+
# those sizes (3.1μs vs 5.1μs per rfft+irfft pair at N=2048 vs 2500).
167+
Nphase = if opts.Nphase !== nothing
168+
opts.Nphase
169+
else
170+
raw = clamp(round(Int, length(time) / opts.n_transits_min), 256, 2048)
171+
min(2048, nextpow(2, raw))
172+
end
161173
templates = build_templates(durations, ref, Nphase)
162174

175+
fft_threshold = opts.fft_threshold === nothing ?
176+
ceil(Int, 1.5 * log2(max(2, Nphase))) : opts.fft_threshold
177+
fft_cache = build_template_fft(templates; threshold = fft_threshold)
178+
163179
nperiods = length(periods)
164180
chi2 = Vector{Float64}(undef, nperiods)
165181
t0_best = Vector{Float64}(undef, nperiods)
@@ -170,6 +186,7 @@ function _tls(time::Vector{Float64},
170186
nt = Threads.maxthreadid()
171187
folded_ys = [Vector{Float64}(undef, Nphase) for _ in 1:nt]
172188
folded_ws = [Vector{Float64}(undef, Nphase) for _ in 1:nt]
189+
fft_scratches = [fft_scratch(Nphase) for _ in 1:nt]
173190

174191
# sum_wy^2 is period-independent; precompute once.
175192
sum_wy2 = 0.0
@@ -181,7 +198,8 @@ function _tls(time::Vector{Float64},
181198
tid = Threads.threadid()
182199
fy = folded_ys[tid]
183200
fw = folded_ws[tid]
184-
pb = fold_and_score!(fy, fw, time, y, w, sum_wy2, periods[ip], templates)
201+
pb = fold_and_score_hybrid!(fy, fw, fft_scratches[tid], time, y, w,
202+
sum_wy2, periods[ip], templates, fft_cache)
185203
chi2[ip] = pb.chi2
186204
t0_best[ip] = pb.t0
187205
k_best[ip] = pb.duration_idx

src/types.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,16 @@ Base.@kwdef struct TLSOptions
2121
threads::Int = Threads.nthreads()
2222
verbose::Bool = false
2323
T0_fit_margin::Float64 = 0.01
24+
"Templates with `nin >= fft_threshold` are evaluated via FFT cross-correlation
25+
instead of the direct SIMD inner loop. `nothing` (default) selects an
26+
automatic threshold ≈ 1.5·log2(Nphase). Set `typemax(Int)` to disable
27+
FFT, `0` to force every template through FFT."
28+
fft_threshold::Union{Nothing,Int} = nothing
29+
"Override for the phase-bin count. `nothing` (default) uses the heuristic
30+
`min(2048, nextpow(2, clamp(round(Int, length(time) / n_transits_min), 256, 2048)))`. Setting this
31+
explicitly is mainly useful for benchmarking and for picking
32+
FFTW-friendly sizes."
33+
Nphase::Union{Nothing,Int} = nothing
2434
end
2535

2636
"""

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ using Statistics
77
@testset "TLS.jl" verbose = true begin
88
include("test_grid.jl")
99
include("test_templates.jl")
10+
include("test_fft_search.jl")
1011
include("test_options.jl")
1112
include("test_ldgrid.jl")
1213
include("test_catalog.jl")

0 commit comments

Comments
 (0)