Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 51 additions & 26 deletions nuance/periodic_search.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,20 @@
"""
The periodic search module provides functions to compute the probability of
a periodic signal to be present in the data, using quantities computed from single
The periodic search module provides functions to compute the probability of
a periodic signal to be present in the data, using quantities computed from single
events statistics.
"""

import os

import multiprocess as mp
from functools import partial
import jax
import numpy as np
from tqdm.auto import tqdm

from nuance import core
from nuance import DEVICES_COUNT, core
from nuance.utils import interp_split_times


def periodic_search(epochs, durations, ls, snr_f, progress=True,
processes=os.cpu_count()):
def periodic_search(epochs, durations, ls, snr_f, progress=True):
"""Returns a function that performs the periodic search given an array of periods.

Parameters
Expand All @@ -30,30 +29,40 @@ def periodic_search(epochs, durations, ls, snr_f, progress=True,
Function that computes the SNR given the epoch, duration and period.
progress : bool, optional
wether to show progress bar, by default True
processes : int, optional
Number of processes to use, by default mp.cpu_count()

Returns
-------
callable
Function that computes the SNR and parameters for each period.
"""
global fold_f

fold_f = _fold_ll(epochs, *ls)

def _progress(x, **kwargs):
return tqdm(x, **kwargs) if progress else x

def function(periods):
def function(periods, processes=DEVICES_COUNT, batch_size=DEVICES_COUNT):
snr = np.zeros(len(periods))
params = np.zeros((len(periods), 3))

with mp.Pool(processes) as pool:
for p, (epoch, duration_i, period) in enumerate(
_progress(pool.imap(_solve, periods), total=len(periods))
):
Dj = durations[duration_i]
snr[p], params[p] = float(snr_f(epoch, Dj, period)), (epoch, Dj, period)
# Use multiprocessing to get the optimal epoch and duration at each period.
solve_f = partial(_solve, fold_f)
ctx = mp.get_context('spawn') # Can't use fork with jax.
with ctx.Pool(processes=processes) as pool:
period_chunks = [periods[i::processes] for i in range(processes)]

for i, result in enumerate(_progress(pool.imap(solve_f, period_chunks), total=processes)):
epochs_chunk, duration_idx_chunk, periods_chunk = result
params[i::processes, 0] = epochs_chunk
params[i::processes, 1] = durations[duration_idx_chunk]
params[i::processes, 2] = periods_chunk

# Use jax.vmap to get the SNR at each period.
snr_vmap = jax.pmap(snr_f, in_axes=(0, 0, 0))
for i in _progress(range(0, len(periods), batch_size), unit_scale=batch_size):
imin = i
imax = i + batch_size
snr[imin:imax] = snr_vmap(params[imin:imax, 0], params[imin:imax, 1], params[imin:imax, 2])

return snr, params

Expand All @@ -66,11 +75,10 @@ def _fold_ll(epochs, lls, z, vz):
f_dz2 = core.nearest_neighbors(epochs, vz)

def _fold(times):
lls = np.array([f_ll(time) for time in times])
zs = np.array([f_z(time) for time in times])
vzs = np.array([f_dz2(time) for time in times])
lls = f_ll(times)
zs = f_z(times)
vzs = f_dz2(times)

P1 = np.sum(lls, 0)
vZ = 1 / np.sum(1 / vzs, 0)
Z = vZ * np.sum(zs / vzs, 0)
P1 = np.sum(lls, 0)
Expand All @@ -89,8 +97,25 @@ def fun(period):
return fun


def _solve(period):
phase, lls = fold_f(period)
epoch_i, duration_i = np.unravel_index(np.argmax(lls), lls.shape)
epoch = phase[epoch_i] * period
return epoch, duration_i, period
def _solve(fold_f, periods):

epochs = np.zeros_like(periods)
duration_idx = np.zeros_like(periods, dtype='int')

for i, period in enumerate(periods):
phase, lls = fold_f(period)
epoch_i, duration_i = np.unravel_index(np.argmax(lls), lls.shape)
epoch = phase[epoch_i] * period

epochs[i] = epoch
duration_idx[i] = duration_i

return epochs, duration_idx, periods


def main():
return


if __name__ == '__main__':
main()