4
4
from functools import lru_cache , partial
5
5
from typing_extensions import Callable
6
6
from numpy .typing import NDArray
7
- import sys
7
+ import warnings
8
8
import numdifftools as ndt
9
9
import numpy as np
10
10
import pyfftw
11
11
import so3g
12
- from so3g .proj import Ranges , RangesMatrix
12
+ from so3g .proj import Ranges
13
13
from scipy .optimize import minimize
14
14
from scipy .signal import welch
15
+ from scipy .stats import chi2
15
16
from sotodlib import core , hwp
16
17
from sotodlib .tod_ops import detrend_tod
17
18
@@ -269,10 +270,11 @@ def calc_psd(
269
270
max_samples = 2 ** 18 ,
270
271
prefer = 'center' ,
271
272
freq_spacing = None ,
272
- merge = False ,
273
+ merge = False ,
273
274
merge_suffix = None ,
274
- overwrite = True ,
275
+ overwrite = True ,
275
276
subscan = False ,
277
+ full_output = False ,
276
278
** kwargs
277
279
):
278
280
"""Calculates the power spectrum density of an input signal using signal.welch().
@@ -295,17 +297,38 @@ def calc_psd(
295
297
merge_suffix (str, optional): Suffix to append to the Pxx field name in aman. Defaults to None (merged as Pxx).
296
298
overwrite (bool): if true will overwrite f, Pxx axes.
297
299
subscan (bool): if True, compute psd on subscans.
300
+ full_output: if True this also outputs nseg, the number of segments used for
301
+ welch, for correcting bias of median white noise estimation by calc_wn.
298
302
**kwargs: keyword args to be passed to signal.welch().
299
303
300
304
Returns:
301
305
freqs: array of frequencies corresponding to PSD calculated from welch.
302
306
Pxx: array of PSD values.
307
+ nseg: number of segments used for welch. this is returned if full_output is True.
303
308
"""
304
309
if signal is None :
305
310
signal = aman .signal
311
+
312
+ if ("noverlap" not in kwargs ) or \
313
+ ("noverlap" in kwargs and kwargs ["noverlap" ] != 0 ):
314
+ warnings .warn ('calc_wn will be biased. noverlap argument of welch '
315
+ 'needs to be 0 to get unbiased median white noise estimate.' )
316
+ if not full_output :
317
+ warnings .warn ('calc_wn will be biased. full_output argument of calc_psd '
318
+ 'needs to be True to get unbiased median white noise estimate.' )
319
+
306
320
if subscan :
307
- freqs , Pxx = _calc_psd_subscan (aman , signal = signal , freq_spacing = freq_spacing , ** kwargs )
321
+ if full_output :
322
+ freqs , Pxx , nseg = _calc_psd_subscan (aman , signal = signal ,
323
+ freq_spacing = freq_spacing ,
324
+ full_output = True ,
325
+ ** kwargs )
326
+ else :
327
+ freqs , Pxx = _calc_psd_subscan (aman , signal = signal ,
328
+ freq_spacing = freq_spacing ,
329
+ ** kwargs )
308
330
axis_map_pxx = [(0 , "dets" ), (1 , "nusamps" ), (2 , "subscans" )]
331
+ axis_map_nseg = [(0 , "subscans" )]
309
332
else :
310
333
if timestamps is None :
311
334
timestamps = aman .timestamps
@@ -334,8 +357,14 @@ def calc_psd(
334
357
nperseg = int (2 ** (np .around (np .log2 ((stop - start ) / 50.0 ))))
335
358
kwargs ["nperseg" ] = nperseg
336
359
360
+ if kwargs ["nperseg" ] > max_samples :
361
+ nseg = 1
362
+ else :
363
+ nseg = int (max_samples / kwargs ["nperseg" ])
364
+
337
365
freqs , Pxx = welch (signal [:, start :stop ], fs , ** kwargs )
338
366
axis_map_pxx = [(0 , aman .dets ), (1 , "nusamps" )]
367
+ axis_map_nseg = None
339
368
340
369
if merge :
341
370
if 'nusamps' not in aman :
@@ -345,19 +374,29 @@ def calc_psd(
345
374
if len (freqs ) != aman .nusamps .count :
346
375
raise ValueError ('New freqs does not match the shape of nusamps\
347
376
To avoid this, use the same value for nperseg' )
348
-
377
+
349
378
if merge_suffix is None :
350
379
Pxx_name = 'Pxx'
351
380
else :
352
381
Pxx_name = f'Pxx_{ merge_suffix } '
353
-
382
+
354
383
if overwrite :
355
384
if Pxx_name in aman ._fields :
356
385
aman .move ("Pxx" , None )
357
386
aman .wrap (Pxx_name , Pxx , axis_map_pxx )
358
- return freqs , Pxx
359
387
360
- def _calc_psd_subscan (aman , signal = None , freq_spacing = None , ** kwargs ):
388
+ if full_output :
389
+ if overwrite and "nseg" in aman ._fields :
390
+ aman .move ("nseg" , None )
391
+ aman .wrap ("nseg" , nseg , axis_map_nseg )
392
+
393
+ if full_output :
394
+ return freqs , Pxx , nseg
395
+ else :
396
+ return freqs , Pxx
397
+
398
+
399
+ def _calc_psd_subscan (aman , signal = None , freq_spacing = None , full_output = False , ** kwargs ):
361
400
"""
362
401
Calculate the power spectrum density of subscans using signal.welch().
363
402
Data defaults to aman.signal. aman.timestamps is used for times.
@@ -378,20 +417,27 @@ def _calc_psd_subscan(aman, signal=None, freq_spacing=None, **kwargs):
378
417
nperseg = int (2 ** (np .around (np .log2 (np .median (duration_samps ) / 4 ))))
379
418
kwargs ["nperseg" ] = nperseg
380
419
381
- Pxx = []
420
+ Pxx , nseg = [], []
382
421
for iss in range (aman .subscan_info .subscans .count ):
383
422
signal_ss = get_subscan_signal (aman , signal , iss )
384
423
axis = - 1 if "axis" not in kwargs else kwargs ["axis" ]
385
- if signal_ss .shape [axis ] >= kwargs ["nperseg" ]:
424
+ nsamps = signal_ss .shape [axis ]
425
+ if nsamps >= kwargs ["nperseg" ]:
386
426
freqs , pxx_sub = welch (signal_ss , fs , ** kwargs )
387
427
Pxx .append (pxx_sub )
428
+ nseg .append (int (nsamps / kwargs ["nperseg" ]))
388
429
else :
389
430
Pxx .append (np .full ((signal .shape [0 ], kwargs ["nperseg" ]// 2 + 1 ), np .nan )) # Add nans if subscan is too short
431
+ nseg .append (np .nan )
432
+ nseg = np .array (nseg )
390
433
Pxx = np .array (Pxx )
391
434
Pxx = Pxx .transpose (1 , 2 , 0 ) # Dets, nusamps, subscans
392
- return freqs , Pxx
435
+ if full_output :
436
+ return freqs , Pxx , nseg
437
+ else :
438
+ return freqs , Pxx
393
439
394
- def calc_wn (aman , pxx = None , freqs = None , low_f = 5 , high_f = 10 ):
440
+ def calc_wn (aman , pxx = None , freqs = None , nseg = None , low_f = 5 , high_f = 10 ):
395
441
"""
396
442
Function that calculates the white noise level as a median PSD value between
397
443
two frequencies. Defaults to calculation of white noise between 5 and 10Hz.
@@ -408,6 +454,13 @@ def calc_wn(aman, pxx=None, freqs=None, low_f=5, high_f=10):
408
454
freqs (1d Float array):
409
455
frequency information related to the psd. Defaults to aman.freqs
410
456
457
+ nseg (Int or 1d Int array):
458
+ number of segmnents used for welch. Defaults to aman.nseg. This is
459
+ necessary for debiasing median white noise estimation. welch PSD with
460
+ non-overlapping n segments follows chi square distribution with
461
+ 2 * nseg degrees of freedom. The median of chi square distribution is
462
+ biased from its average.
463
+
411
464
low_f (Float):
412
465
low frequency cutoff to calculate median psd value. Defaults to 5Hz
413
466
@@ -424,12 +477,28 @@ def calc_wn(aman, pxx=None, freqs=None, low_f=5, high_f=10):
424
477
if pxx is None :
425
478
pxx = aman .Pxx
426
479
480
+ if nseg is None :
481
+ nseg = aman .get ('nseg' )
482
+
483
+ if nseg is None :
484
+ warnings .warn ('white noise level estimated by median PSD is biased. '
485
+ 'nseg is necessary to debias. Need to use following '
486
+ 'arguments in calc_psd to get correct nseg. '
487
+ '`noverlap=0, full_output=True`' )
488
+ debias = None
489
+ else :
490
+ debias = 2 * nseg / chi2 .ppf (0.5 , 2 * nseg )
491
+
427
492
fmsk = np .all ([freqs >= low_f , freqs <= high_f ], axis = 0 )
428
493
if pxx .ndim == 1 :
429
494
wn2 = np .median (pxx [fmsk ])
430
495
else :
431
496
wn2 = np .median (pxx [:, fmsk ], axis = 1 )
432
-
497
+ if debias is not None :
498
+ if pxx .ndim == 3 :
499
+ wn2 *= debias [None , :]
500
+ else :
501
+ wn2 *= debias
433
502
wn = np .sqrt (wn2 )
434
503
return wn
435
504
0 commit comments