diff --git a/nuance/periodic_search.py b/nuance/periodic_search.py index 3b2546a..08e88f6 100644 --- a/nuance/periodic_search.py +++ b/nuance/periodic_search.py @@ -1,21 +1,20 @@ """ -The periodic search module provides functions to compute the probability of -a periodic signal to be present in the data, using quantities computed from single +The periodic search module provides functions to compute the probability of +a periodic signal to be present in the data, using quantities computed from single events statistics. """ -import os - import multiprocess as mp +from functools import partial +import jax import numpy as np from tqdm.auto import tqdm -from nuance import core +from nuance import DEVICES_COUNT, core from nuance.utils import interp_split_times -def periodic_search(epochs, durations, ls, snr_f, progress=True, - processes=os.cpu_count()): +def periodic_search(epochs, durations, ls, snr_f, progress=True): """Returns a function that performs the periodic search given an array of periods. Parameters @@ -30,30 +29,40 @@ def periodic_search(epochs, durations, ls, snr_f, progress=True, Function that computes the SNR given the epoch, duration and period. progress : bool, optional wether to show progress bar, by default True - processes : int, optional - Number of processes to use, by default mp.cpu_count() Returns ------- callable Function that computes the SNR and parameters for each period. 
""" - global fold_f + fold_f = _fold_ll(epochs, *ls) def _progress(x, **kwargs): return tqdm(x, **kwargs) if progress else x - def function(periods): + def function(periods, processes=DEVICES_COUNT, batch_size=DEVICES_COUNT): snr = np.zeros(len(periods)) params = np.zeros((len(periods), 3)) - with mp.Pool(processes) as pool: - for p, (epoch, duration_i, period) in enumerate( - _progress(pool.imap(_solve, periods), total=len(periods)) - ): - Dj = durations[duration_i] - snr[p], params[p] = float(snr_f(epoch, Dj, period)), (epoch, Dj, period) + # Use multiprocessing to get the optimal epoch and duration at each period. + solve_f = partial(_solve, fold_f) + ctx = mp.get_context('spawn') # Can't use fork with jax. + with ctx.Pool(processes=processes) as pool: + period_chunks = [periods[i::processes] for i in range(processes)] + + for i, result in enumerate(_progress(pool.imap(solve_f, period_chunks), total=processes)): + epochs_chunk, duration_idx_chunk, periods_chunk = result + params[i::processes, 0] = epochs_chunk + params[i::processes, 1] = durations[duration_idx_chunk] + params[i::processes, 2] = periods_chunk + + # Use jax.pmap to get the SNR at each period. 
+ snr_vmap = jax.pmap(snr_f, in_axes=(0, 0, 0)) + for i in _progress(range(0, len(periods), batch_size), unit_scale=batch_size): + imin = i + imax = i + batch_size + snr[imin:imax] = snr_vmap(params[imin:imax, 0], params[imin:imax, 1], params[imin:imax, 2]) return snr, params @@ -66,11 +75,10 @@ def _fold_ll(epochs, lls, z, vz): f_dz2 = core.nearest_neighbors(epochs, vz) def _fold(times): - lls = np.array([f_ll(time) for time in times]) - zs = np.array([f_z(time) for time in times]) - vzs = np.array([f_dz2(time) for time in times]) + lls = f_ll(times) + zs = f_z(times) + vzs = f_dz2(times) - P1 = np.sum(lls, 0) vZ = 1 / np.sum(1 / vzs, 0) Z = vZ * np.sum(zs / vzs, 0) P1 = np.sum(lls, 0) @@ -89,8 +97,25 @@ def fun(period): return fun -def _solve(period): - phase, lls = fold_f(period) - epoch_i, duration_i = np.unravel_index(np.argmax(lls), lls.shape) - epoch = phase[epoch_i] * period - return epoch, duration_i, period +def _solve(fold_f, periods): + + epochs = np.zeros_like(periods) + duration_idx = np.zeros_like(periods, dtype='int') + + for i, period in enumerate(periods): + phase, lls = fold_f(period) + epoch_i, duration_i = np.unravel_index(np.argmax(lls), lls.shape) + epoch = phase[epoch_i] * period + + epochs[i] = epoch + duration_idx[i] = duration_i + + return epochs, duration_idx, periods + + +def main(): + return + + +if __name__ == '__main__': + main()