using FFTW
using AbstractFFTs            # ScaledPlan used as a field type below
using LinearAlgebra: mul!     # in-place application of precomputed plans

| 3 | +""" |
| 4 | + TemplateFFTCache |
| 5 | +
|
| 6 | +FFT-domain cache for templates whose in-transit length `nin` is large |
| 7 | +enough that an FFT-based circular cross-correlation beats the direct |
| 8 | +O(Nphase * nin) inner loop. |
| 9 | +
|
| 10 | +For each cached template `k` we store |
| 11 | +- `F_signals[k] = conj(rfft(pad(s_k, Nphase)))` |
| 12 | +- `F_signal_sq[k] = conj(rfft(pad(s2_k, Nphase)))` |
| 13 | +
|
| 14 | +`fft_indices[i]` gives the index into the parent `TemplateCache` for |
| 15 | +the i-th cached template; templates with `nin < threshold` are not |
| 16 | +cached and continue to use the direct path. |
| 17 | +
|
| 18 | +We deliberately store kernels as separate `Vector{ComplexF64}` rather |
| 19 | +than columns of a `Matrix`: in microbenchmarks the matrix-column form |
| 20 | +inhibits @simd vectorization of the per-template product loop and |
| 21 | +costs ~25% throughput. We also avoid FFTW's batched IFFT plan; for |
| 22 | +non-power-of-two `Nphase` (e.g. 2500) the batched codepath is up to |
| 23 | +7× slower than a loop of independent 1D IFFTs. |
| 24 | +""" |
struct TemplateFFTCache
    Nphase::Int
    threshold::Int
    fft_indices::Vector{Int}
    F_signals::Vector{Vector{ComplexF64}}
    F_signal_sq::Vector{Vector{ComplexF64}}
    rfft_plan::FFTW.rFFTWPlan{Float64,-1,false,1}
    irfft_plan::AbstractFFTs.ScaledPlan
end
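
# Illustrative only (hypothetical helper, not used by the cache API): a toy check
# that the conjugated-rfft kernels described above reproduce the wrapped dot
# products that the direct SIMD path computes at every offset.
function _check_fft_correlation_identity(N::Int = 16, nin::Int = 5)
    y = randn(N)                      # stands in for folded_y
    s = randn(nin)                    # stands in for one template signal
    padded = vcat(s, zeros(N - nin))  # zero-pad to the full phase grid
    corr = irfft(rfft(y) .* conj.(rfft(padded)), N)
    # Direct wrapped dot product at 0-based offset o, as in the SIMD path.
    direct(o) = sum(y[mod(o + j - 1, N) + 1] * s[j] for j in 1:nin)
    return all(isapprox(corr[o + 1], direct(o); atol = 1e-9) for o in 0:N-1)
end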

"""
    build_template_fft(templates; threshold) -> TemplateFFTCache

Precompute FFT-domain copies of every template in `templates` whose
`intransit_counts[k] >= threshold`. Per period the two forward FFTs
of the folded data are amortized across all FFT-cached templates, so
the marginal cost per template is two IFFTs (one for the weighted-flux
correlation, one for the weight · signal² correlation) plus a
length-`Nphase` scan: ~`Nphase * log2(Nphase)` ops, vs `Nphase * nin`
for the direct SIMD loop. The asymptotic crossover sits near
`nin ≈ log2(Nphase)`, but FFTW's per-IFFT constant is larger than the
@simd direct-loop constant, so the empirical optimum on a power-of-2
`Nphase` is closer to `1.5 · log2(Nphase)`.

Pass `threshold = 0` to force every template through the FFT path,
or `typemax(Int)` to disable the FFT path entirely.
"""
function build_template_fft(templates::TemplateCache;
                            threshold::Integer = ceil(Int, 1.5 * log2(max(2, templates.Nphase))))
    Nphase = templates.Nphase
    nf = Nphase ÷ 2 + 1
    rplan = plan_rfft(Vector{Float64}(undef, Nphase); flags = FFTW.MEASURE)
    iplan = plan_irfft(Vector{ComplexF64}(undef, nf), Nphase; flags = FFTW.MEASURE)

    indices = Int[]
    F_sigs = Vector{Vector{ComplexF64}}()
    F_sig_sq = Vector{Vector{ComplexF64}}()

    pad_buf = Vector{Float64}(undef, Nphase)
    for k in eachindex(templates.signals)
        nin = templates.intransit_counts[k]
        nin >= threshold || continue
        s = templates.signals[k]
        s2 = templates.signal_sq[k]

        # Zero-pad the template to the full phase grid and store the conjugated
        # spectrum, so the per-period product F_y .* Fs is a cross-correlation.
        fill!(pad_buf, 0.0)
        @inbounds for j in 1:nin
            pad_buf[j] = s[j]
        end
        Fs = rplan * pad_buf
        @inbounds for i in eachindex(Fs)
            Fs[i] = conj(Fs[i])
        end

        # Same for the squared template used in the sum_ws2 correlation.
        fill!(pad_buf, 0.0)
        @inbounds for j in 1:nin
            pad_buf[j] = s2[j]
        end
        Fs2 = rplan * pad_buf
        @inbounds for i in eachindex(Fs2)
            Fs2[i] = conj(Fs2[i])
        end

        push!(indices, k)
        push!(F_sigs, Fs)
        push!(F_sig_sq, Fs2)
    end

    return TemplateFFTCache(Nphase, Int(threshold), indices, F_sigs, F_sig_sq,
                            rplan, iplan)
end
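
# Usage sketch (assumes a `templates::TemplateCache` built elsewhere in the
# package; the variable names are illustrative, not part of the API):
#
#     fftc      = build_template_fft(templates)                            # default cutoff,
#                                                                          # e.g. ceil(Int, 1.5 * log2(2500)) == 17
#     fftc_all  = build_template_fft(templates; threshold = 0)             # every template via FFT
#     fftc_none = build_template_fft(templates; threshold = typemax(Int))  # direct path only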

"""
    fft_scratch(Nphase) -> NamedTuple

Per-thread scratch buffers for the FFT inner-loop path. The single
`F_tmp` complex buffer is reused across templates inside one call to
`fold_and_score_hybrid!`; `wys` / `ws2` hold the IFFT outputs for the
current template before the χ² scan consumes them.
"""
function fft_scratch(Nphase::Integer)
    nf = Nphase ÷ 2 + 1
    (F_y = Vector{ComplexF64}(undef, nf),
     F_w = Vector{ComplexF64}(undef, nf),
     F_tmp = Vector{ComplexF64}(undef, nf),
     wys = Vector{Float64}(undef, Nphase),
     ws2 = Vector{Float64}(undef, Nphase))
end
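
# For example (illustrative numbers): with Nphase = 2500 the three complex
# buffers have length 2500 ÷ 2 + 1 = 1251 and the two real buffers length 2500.
# One scratch tuple is intended per worker thread, e.g.
#
#     scratches = [fft_scratch(2500) for _ in 1:Threads.nthreads()]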

"""
    fold_and_score_hybrid!(folded_y, folded_w, scratch, time, y, w,
                           sum_wy2, period, templates, fft_cache) -> PeriodBest

Phase-fold the data and search every template-offset pair for the
minimum χ². Templates with `nin >= fft_cache.threshold` are evaluated
through FFT circular cross-correlations that share the two forward
FFTs of the folded data; the rest use the direct SIMD path identical
to `fold_and_score!`.
"""
function fold_and_score_hybrid!(folded_y::Vector{Float64},
                                folded_w::Vector{Float64},
                                scratch,
                                time::AbstractVector{<:Real},
                                y::AbstractVector{<:Real},
                                w::AbstractVector{<:Real},
                                sum_wy2::Real,
                                period::Real,
                                templates::TemplateCache,
                                fft_cache::TemplateFFTCache)
    Nphase = templates.Nphase
    @assert fft_cache.Nphase == Nphase
    fill!(folded_y, 0.0)
    fill!(folded_w, 0.0)

    invP = 1.0 / period
    @inbounds for i in eachindex(time)
        φ = time[i] * invP
        φ -= floor(φ)
        bin = unsafe_trunc(Int, φ * Nphase) + 1
        if bin > Nphase        # guard against φ * Nphase rounding up to Nphase
            bin = Nphase
        end
        folded_w[bin] += w[i]
        folded_y[bin] += y[i] * w[i]
    end


    # Track the best candidate via the partial form `-sum_wys^2 / sum_ws2`.
    # At the depth that minimizes χ² for a given (template, offset),
    # d = sum_wys / sum_ws2, the full statistic is
    #   χ² = sum_wy2 - 2·d·sum_wys + d²·sum_ws2 = sum_wy2 - sum_wys²/sum_ws2,
    # so minimizing the partial is equivalent to minimizing χ²; the constant
    # `sum_wy2` is added back once at the end.
    best_partial = Inf
    best_offset = 0
    best_k = 1
    best_depth = 0.0

    K = length(fft_cache.fft_indices)
    threshold = fft_cache.threshold

    if K > 0
        F_y, F_w = scratch.F_y, scratch.F_w
        F_tmp = scratch.F_tmp
        wys_buf, ws2_buf = scratch.wys, scratch.ws2
        nf = length(F_y)

        # Two forward FFTs of the folded data, shared by all cached templates.
        mul!(F_y, fft_cache.rfft_plan, folded_y)
        mul!(F_w, fft_cache.rfft_plan, folded_w)

        @inbounds for ki in 1:K
            k = fft_cache.fft_indices[ki]
            Fs = fft_cache.F_signals[ki]
            Fs2 = fft_cache.F_signal_sq[ki]

            # Circular cross-correlation: after the IFFT,
            # wys_buf[o] = Σⱼ folded_y[(o + j - 2) mod Nphase + 1] * s[j],
            # the same wrapped dot product the direct path computes at offset o - 1.
            @simd for i in 1:nf
                F_tmp[i] = F_y[i] * Fs[i]
            end
            mul!(wys_buf, fft_cache.irfft_plan, F_tmp)

            @simd for i in 1:nf
                F_tmp[i] = F_w[i] * Fs2[i]
            end
            mul!(ws2_buf, fft_cache.irfft_plan, F_tmp)

            @simd for o in 1:Nphase
                sum_wys = wys_buf[o]
                sum_ws2 = ws2_buf[o]
                if sum_ws2 > 0
                    depth = sum_wys / sum_ws2
                    partial = -sum_wys * depth
                    if partial < best_partial
                        best_partial = partial
                        best_offset = o - 1
                        best_k = k
                        best_depth = depth
                    end
                end
            end
        end
    end

    @inbounds for k in eachindex(templates.signals)
        nin = templates.intransit_counts[k]
        nin >= threshold && continue   # already handled by the FFT path above
        s = templates.signals[k]
        s2 = templates.signal_sq[k]
        for o in 0:(Nphase - 1)
            start = o + 1
            last_no_wrap = min(start + nin - 1, Nphase)
            len1 = last_no_wrap - start + 1
            len2 = nin - len1          # tail that wraps past the end of the grid

            sum_wys = 0.0
            sum_ws2 = 0.0
            @simd for j in 1:len1
                bin = start + j - 1
                sum_wys += folded_y[bin] * s[j]
                sum_ws2 += folded_w[bin] * s2[j]
            end
            if len2 > 0
                @simd for j in 1:len2
                    bin = j
                    sum_wys += folded_y[bin] * s[len1 + j]
                    sum_ws2 += folded_w[bin] * s2[len1 + j]
                end
            end
            sum_ws2 > 0 || continue
            depth = sum_wys / sum_ws2
            partial = -sum_wys * depth
            if partial < best_partial
                best_partial = partial
                best_offset = o
                best_k = k
                best_depth = depth
            end
        end
    end

    nin = templates.intransit_counts[best_k]
    # Center of the best-fitting transit, in (possibly fractional) bins.
    mid_bin = mod(best_offset + (nin + 1) / 2 - 0.5, Nphase)
    best_t0 = (mid_bin / Nphase) * period
    best_chi2 = sum_wy2 + best_partial

    return PeriodBest(best_chi2, best_t0, best_k, best_depth)
end
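
# Hypothetical end-to-end driver (not part of this diff) showing how the pieces
# above are meant to be combined over a period grid. `TemplateCache`, `PeriodBest`,
# `build_template_fft`, `fft_scratch`, and `fold_and_score_hybrid!` come from this
# file / the surrounding package; everything else here is illustrative.
function _demo_period_scan(time, y, w, periods, templates::TemplateCache)
    Nphase = templates.Nphase
    fft_cache = build_template_fft(templates)
    # Constant term of χ² = Σᵢ wᵢ (yᵢ - d·sᵢ)², computed once for all periods.
    sum_wy2 = sum(w[i] * y[i]^2 for i in eachindex(y))

    nthreads = Threads.nthreads()
    folded_y = [Vector{Float64}(undef, Nphase) for _ in 1:nthreads]
    folded_w = [Vector{Float64}(undef, Nphase) for _ in 1:nthreads]
    scratches = [fft_scratch(Nphase) for _ in 1:nthreads]

    results = Vector{PeriodBest}(undef, length(periods))
    # :static schedule keeps each iteration pinned to one thread, so indexing the
    # per-thread buffers by threadid() is safe.
    Threads.@threads :static for p in eachindex(periods)
        tid = Threads.threadid()
        results[p] = fold_and_score_hybrid!(folded_y[tid], folded_w[tid],
                                            scratches[tid], time, y, w,
                                            sum_wy2, periods[p], templates,
                                            fft_cache)
    end
    return results
end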