Skip to content

Commit 8f3182b

Browse files
authored
Debias wn (#1238)
* debias wn * draft of test_psd.py * update test * making changes to be non-breaking * fix * address comments
1 parent 18cdb49 commit 8f3182b

File tree

4 files changed

+223
-68
lines changed

4 files changed

+223
-68
lines changed

sotodlib/preprocess/processes.py

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -356,41 +356,55 @@ class PSDCalc(_Preprocess):
356356
""" Calculate the PSD of the data and add it to the Preprocessing AxisManager under the
357357
"psd" field.
358358
359+
Note: noverlap = 0 amd full_output = True are recommended to get unbiased
360+
median white noise estimation by Noise.
361+
359362
Example config block::
360363
361364
- "name : "psd"
362365
"signal: "signal" # optional
363366
"wrap": "psd" # optional
364367
"calc":
365-
"psd_cfgs": # optional, kwargs to scipy.welch
366-
"nperseg": 1024
368+
"nperseg": 1024 # optional
369+
"noverlap": 0 # optional
367370
"wrap_name": "psd" # optional
368-
"subscan": False
371+
"subscan": False # optional
372+
"full_output": True # optional
369373
"save": True
370374
371375
.. autofunction:: sotodlib.tod_ops.fft_ops.calc_psd
372376
"""
373377
name = "psd"
374-
378+
375379
def __init__(self, step_cfgs):
376380
self.signal = step_cfgs.get('signal', 'signal')
377381
self.wrap = step_cfgs.get('wrap', 'psd')
378382

379383
super().__init__(step_cfgs)
380384

381385
def calc_and_save(self, aman, proc_aman):
382-
freqs, Pxx = tod_ops.fft_ops.calc_psd(aman, signal=aman[self.signal],
383-
**self.calc_cfgs)
386+
full_output = self.calc_cfgs.get('full_output')
387+
if full_output:
388+
freqs, Pxx, nseg = tod_ops.fft_ops.calc_psd(aman, signal=aman[self.signal],
389+
**self.calc_cfgs)
390+
else:
391+
freqs, Pxx = tod_ops.fft_ops.calc_psd(aman, signal=aman[self.signal],
392+
**self.calc_cfgs)
384393

385394
fft_aman = core.AxisManager(aman.dets,
386395
core.OffsetAxis("nusamps", len(freqs)))
387396
pxx_axis_map = [(0, "dets"), (1, "nusamps")]
388397
if self.calc_cfgs.get('subscan', False):
389398
fft_aman.wrap("Pxx_ss", Pxx, pxx_axis_map+[(2, aman.subscans)])
390399
Pxx = np.nanmean(Pxx, axis=-1) # Mean of subscans
400+
if full_output:
401+
fft_aman.wrap("nseg_ss", nseg, [(0, aman.subscans)])
402+
nseg = np.nansum(nseg)
391403

392404
fft_aman.wrap("freqs", freqs, [(0,"nusamps")])
393405
fft_aman.wrap("Pxx", Pxx, pxx_axis_map)
406+
if full_output:
407+
fft_aman.wrap("nseg", nseg)
394408

395409
self.save(proc_aman, fft_aman)
396410

@@ -559,6 +573,7 @@ def calc_and_save(self, aman, proc_aman):
559573
wn_f_low, wn_f_high = self.calc_cfgs.get('fwhite', (5, 10))
560574
self.calc_cfgs['wn_est'] = tod_ops.fft_ops.calc_wn(aman, pxx=pxx,
561575
freqs=psd.freqs,
576+
nseg=psd.get('nseg'),
562577
low_f=wn_f_low,
563578
high_f=wn_f_high)
564579
if self.calc_cfgs.get('subscan') is None:
@@ -578,6 +593,7 @@ def calc_and_save(self, aman, proc_aman):
578593
wn_f_high = self.calc_cfgs.get("high_f", 10)
579594
wn = tod_ops.fft_ops.calc_wn(aman, pxx=pxx,
580595
freqs=psd.freqs,
596+
nseg=psd.get('nseg'),
581597
low_f=wn_f_low,
582598
high_f=wn_f_high)
583599
if not self.subscan:

sotodlib/tod_ops/fft_ops.py

Lines changed: 83 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,15 @@
44
from functools import lru_cache, partial
55
from typing_extensions import Callable
66
from numpy.typing import NDArray
7-
import sys
7+
import warnings
88
import numdifftools as ndt
99
import numpy as np
1010
import pyfftw
1111
import so3g
12-
from so3g.proj import Ranges, RangesMatrix
12+
from so3g.proj import Ranges
1313
from scipy.optimize import minimize
1414
from scipy.signal import welch
15+
from scipy.stats import chi2
1516
from sotodlib import core, hwp
1617
from sotodlib.tod_ops import detrend_tod
1718

@@ -269,10 +270,11 @@ def calc_psd(
269270
max_samples=2**18,
270271
prefer='center',
271272
freq_spacing=None,
272-
merge=False,
273+
merge=False,
273274
merge_suffix=None,
274-
overwrite=True,
275+
overwrite=True,
275276
subscan=False,
277+
full_output=False,
276278
**kwargs
277279
):
278280
"""Calculates the power spectrum density of an input signal using signal.welch().
@@ -295,17 +297,38 @@ def calc_psd(
295297
merge_suffix (str, optional): Suffix to append to the Pxx field name in aman. Defaults to None (merged as Pxx).
296298
overwrite (bool): if true will overwrite f, Pxx axes.
297299
subscan (bool): if True, compute psd on subscans.
300+
full_output: if True this also outputs nseg, the number of segments used for
301+
welch, for correcting bias of median white noise estimation by calc_wn.
298302
**kwargs: keyword args to be passed to signal.welch().
299303
300304
Returns:
301305
freqs: array of frequencies corresponding to PSD calculated from welch.
302306
Pxx: array of PSD values.
307+
nseg: number of segments used for welch. this is returned if full_output is True.
303308
"""
304309
if signal is None:
305310
signal = aman.signal
311+
312+
if ("noverlap" not in kwargs) or \
313+
("noverlap" in kwargs and kwargs["noverlap"] != 0):
314+
warnings.warn('calc_wn will be biased. noverlap argument of welch '
315+
'needs to be 0 to get unbiased median white noise estimate.')
316+
if not full_output:
317+
warnings.warn('calc_wn will be biased. full_output argument of calc_psd '
318+
'needs to be True to get unbiased median white noise estimate.')
319+
306320
if subscan:
307-
freqs, Pxx = _calc_psd_subscan(aman, signal=signal, freq_spacing=freq_spacing, **kwargs)
321+
if full_output:
322+
freqs, Pxx, nseg = _calc_psd_subscan(aman, signal=signal,
323+
freq_spacing=freq_spacing,
324+
full_output=True,
325+
**kwargs)
326+
else:
327+
freqs, Pxx = _calc_psd_subscan(aman, signal=signal,
328+
freq_spacing=freq_spacing,
329+
**kwargs)
308330
axis_map_pxx = [(0, "dets"), (1, "nusamps"), (2, "subscans")]
331+
axis_map_nseg = [(0, "subscans")]
309332
else:
310333
if timestamps is None:
311334
timestamps = aman.timestamps
@@ -334,8 +357,14 @@ def calc_psd(
334357
nperseg = int(2 ** (np.around(np.log2((stop - start) / 50.0))))
335358
kwargs["nperseg"] = nperseg
336359

360+
if kwargs["nperseg"] > max_samples:
361+
nseg = 1
362+
else:
363+
nseg = int(max_samples / kwargs["nperseg"])
364+
337365
freqs, Pxx = welch(signal[:, start:stop], fs, **kwargs)
338366
axis_map_pxx = [(0, aman.dets), (1, "nusamps")]
367+
axis_map_nseg = None
339368

340369
if merge:
341370
if 'nusamps' not in aman:
@@ -345,19 +374,29 @@ def calc_psd(
345374
if len(freqs) != aman.nusamps.count:
346375
raise ValueError('New freqs does not match the shape of nusamps\
347376
To avoid this, use the same value for nperseg')
348-
377+
349378
if merge_suffix is None:
350379
Pxx_name = 'Pxx'
351380
else:
352381
Pxx_name = f'Pxx_{merge_suffix}'
353-
382+
354383
if overwrite:
355384
if Pxx_name in aman._fields:
356385
aman.move("Pxx", None)
357386
aman.wrap(Pxx_name, Pxx, axis_map_pxx)
358-
return freqs, Pxx
359387

360-
def _calc_psd_subscan(aman, signal=None, freq_spacing=None, **kwargs):
388+
if full_output:
389+
if overwrite and "nseg" in aman._fields:
390+
aman.move("nseg", None)
391+
aman.wrap("nseg", nseg, axis_map_nseg)
392+
393+
if full_output:
394+
return freqs, Pxx, nseg
395+
else:
396+
return freqs, Pxx
397+
398+
399+
def _calc_psd_subscan(aman, signal=None, freq_spacing=None, full_output=False, **kwargs):
361400
"""
362401
Calculate the power spectrum density of subscans using signal.welch().
363402
Data defaults to aman.signal. aman.timestamps is used for times.
@@ -378,20 +417,27 @@ def _calc_psd_subscan(aman, signal=None, freq_spacing=None, **kwargs):
378417
nperseg = int(2 ** (np.around(np.log2(np.median(duration_samps) / 4))))
379418
kwargs["nperseg"] = nperseg
380419

381-
Pxx = []
420+
Pxx, nseg = [], []
382421
for iss in range(aman.subscan_info.subscans.count):
383422
signal_ss = get_subscan_signal(aman, signal, iss)
384423
axis = -1 if "axis" not in kwargs else kwargs["axis"]
385-
if signal_ss.shape[axis] >= kwargs["nperseg"]:
424+
nsamps = signal_ss.shape[axis]
425+
if nsamps >= kwargs["nperseg"]:
386426
freqs, pxx_sub = welch(signal_ss, fs, **kwargs)
387427
Pxx.append(pxx_sub)
428+
nseg.append(int(nsamps / kwargs["nperseg"]))
388429
else:
389430
Pxx.append(np.full((signal.shape[0], kwargs["nperseg"]//2+1), np.nan)) # Add nans if subscan is too short
431+
nseg.append(np.nan)
432+
nseg = np.array(nseg)
390433
Pxx = np.array(Pxx)
391434
Pxx = Pxx.transpose(1, 2, 0) # Dets, nusamps, subscans
392-
return freqs, Pxx
435+
if full_output:
436+
return freqs, Pxx, nseg
437+
else:
438+
return freqs, Pxx
393439

394-
def calc_wn(aman, pxx=None, freqs=None, low_f=5, high_f=10):
440+
def calc_wn(aman, pxx=None, freqs=None, nseg=None, low_f=5, high_f=10):
395441
"""
396442
Function that calculates the white noise level as a median PSD value between
397443
two frequencies. Defaults to calculation of white noise between 5 and 10Hz.
@@ -408,6 +454,13 @@ def calc_wn(aman, pxx=None, freqs=None, low_f=5, high_f=10):
408454
freqs (1d Float array):
409455
frequency information related to the psd. Defaults to aman.freqs
410456
457+
nseg (Int or 1d Int array):
458+
number of segmnents used for welch. Defaults to aman.nseg. This is
459+
necessary for debiasing median white noise estimation. welch PSD with
460+
non-overlapping n segments follows chi square distribution with
461+
2 * nseg degrees of freedom. The median of chi square distribution is
462+
biased from its average.
463+
411464
low_f (Float):
412465
low frequency cutoff to calculate median psd value. Defaults to 5Hz
413466
@@ -424,12 +477,28 @@ def calc_wn(aman, pxx=None, freqs=None, low_f=5, high_f=10):
424477
if pxx is None:
425478
pxx = aman.Pxx
426479

480+
if nseg is None:
481+
nseg = aman.get('nseg')
482+
483+
if nseg is None:
484+
warnings.warn('white noise level estimated by median PSD is biased. '
485+
'nseg is necessary to debias. Need to use following '
486+
'arguments in calc_psd to get correct nseg. '
487+
'`noverlap=0, full_output=True`')
488+
debias = None
489+
else:
490+
debias = 2 * nseg / chi2.ppf(0.5, 2 * nseg)
491+
427492
fmsk = np.all([freqs >= low_f, freqs <= high_f], axis=0)
428493
if pxx.ndim == 1:
429494
wn2 = np.median(pxx[fmsk])
430495
else:
431496
wn2 = np.median(pxx[:, fmsk], axis=1)
432-
497+
if debias is not None:
498+
if pxx.ndim == 3:
499+
wn2 *= debias[None, :]
500+
else:
501+
wn2 *= debias
433502
wn = np.sqrt(wn2)
434503
return wn
435504

tests/test_psd.py

Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
""" Test psd calculation
2+
"""
3+
4+
5+
import unittest
6+
import numpy as np
7+
from numpy.fft import rfftfreq, irfft
8+
9+
from sotodlib import core
10+
from sotodlib.tod_ops import detrend_tod
11+
from sotodlib.tod_ops.flags import get_turnaround_flags
12+
from sotodlib.tod_ops.fft_ops import (
13+
calc_psd, calc_wn, fit_noise_model, noise_model)
14+
15+
from .test_azss import get_scan
16+
17+
TOL_BIAS = 0.005
18+
19+
20+
class PSDTest(unittest.TestCase):
21+
def test_psd_fit(self):
22+
fs = 200.
23+
dets = core.LabelAxis('dets', [f'det{di:003}' for di in range(20)])
24+
nsamps = 200 * 3600
25+
26+
aman = core.AxisManager(dets)
27+
ndets = aman.dets.count
28+
29+
white_noise_amp_input = 50 + np.random.randn(ndets) # W/sqrt{Hz}
30+
fknee_input = 1 + 0.1 * np.random.randn(ndets)
31+
alpha_input = 3 + 0.2 * np.random.randn(ndets)
32+
33+
freqs = rfftfreq(nsamps, d=1/fs)
34+
params = [white_noise_amp_input[:, np.newaxis],
35+
fknee_input[:, np.newaxis],
36+
alpha_input[:, np.newaxis]]
37+
pxx_input = noise_model(freqs, params)
38+
39+
pxx_input[:, 0] = 0
40+
41+
T = nsamps/fs
42+
ft_amps = np.sqrt(pxx_input * T * fs**2 / 2)
43+
44+
ft_phases = np.random.uniform(0, 2 * np.pi, size=ft_amps.shape)
45+
ft_coefs = ft_amps * np.exp(1.0j * ft_phases)
46+
realized_noise = irfft(ft_coefs)
47+
timestamps = 1700000000 + np.arange(0, realized_noise.shape[1])/fs
48+
aman.add_axis(core.OffsetAxis('samps', len(timestamps)))
49+
aman.wrap('timestamps', timestamps, [(0, 'samps')])
50+
aman.wrap('signal', realized_noise, [(0, 'dets'), (1, 'samps')])
51+
52+
detrend_tod(aman)
53+
freqs_output, Pxx_output = calc_psd(aman, nperseg=200*100)
54+
fit_result = fit_noise_model(aman, wn_est=50, fknee_est=1.0,
55+
alpha_est=3.3, lowf=0.05,
56+
f_max=5, binning=True,
57+
psdargs={'nperseg': 200*1000})
58+
wnl_fit = fit_result.fit[:, 0]
59+
fk_fit = fit_result.fit[:, 1]
60+
alpha_fit = fit_result.fit[:, 2]
61+
62+
self.assertTrue(np.abs(np.median(white_noise_amp_input - wnl_fit)) < 1)
63+
self.assertTrue(np.abs(np.median(fknee_input - fk_fit)) < 0.1)
64+
self.assertTrue(np.abs(np.median(alpha_input - alpha_fit)) < 0.1)
65+
66+
def test_wn_debias(self):
67+
# prep
68+
timestamps, az = get_scan(
69+
n_scans=20, scan_accel=0.25, scanrate=0.5, az0=0, az1=40)
70+
71+
nsamps = len(timestamps)
72+
ndets = 100
73+
np.random.seed(0)
74+
signal = np.random.normal(0, 1, size=(ndets, nsamps))
75+
76+
dets = [f"det{i}" for i in range(ndets)]
77+
aman = core.AxisManager(
78+
core.LabelAxis("dets", dets),
79+
core.IndexAxis("samps", nsamps)
80+
)
81+
aman.wrap("timestamps", timestamps, [(0, "samps")])
82+
aman.wrap("signal", signal, [(0, "dets"), (1, "samps")])
83+
boresight = core.AxisManager(aman.samps)
84+
boresight.wrap("az", az, [(0, "samps")])
85+
aman.wrap('boresight', boresight)
86+
aman.wrap('flags', core.AxisManager(aman.dets, aman.samps))
87+
get_turnaround_flags(aman)
88+
89+
# test default arguments, this is biased
90+
calc_psd(aman, merge=True, nperseg=2**18)
91+
wn = calc_wn(aman)
92+
ratio = np.average(wn) / np.sqrt(np.average(aman.Pxx))
93+
self.assertTrue(abs(ratio - 1) > TOL_BIAS)
94+
# test debias, full_output=True, noverlap=0
95+
freqs, Pxx, nseg = calc_psd(aman, merge=False, full_output=True,
96+
noverlap=0, nperseg=2**18)
97+
wn = calc_wn(aman, Pxx, freqs, nseg)
98+
ratio = np.average(wn) / np.sqrt(np.average(Pxx))
99+
self.assertAlmostEqual(ratio, 1, delta=TOL_BIAS)
100+
# test quarter nperseg
101+
freqs, Pxx, nseg = calc_psd(aman, merge=False, full_output=True,
102+
noverlap=0, nperseg=2**16)
103+
wn = calc_wn(aman, Pxx, freqs, nseg)
104+
ratio = np.average(wn) / np.sqrt(np.average(Pxx))
105+
self.assertAlmostEqual(ratio, 1, delta=TOL_BIAS)
106+
# test defulat nperseg
107+
freqs, Pxx, nseg = calc_psd(aman, merge=False, full_output=True,
108+
noverlap=0)
109+
wn = calc_wn(aman, Pxx, freqs, nseg)
110+
ratio = np.average(wn) / np.sqrt(np.average(Pxx))
111+
self.assertAlmostEqual(ratio, 1, delta=TOL_BIAS)
112+
# test subscan
113+
freqs, Pxx, nseg = calc_psd(aman, merge=False, full_output=True,
114+
noverlap=0, subscan=True)
115+
wn = calc_wn(aman, Pxx, freqs, nseg)
116+
ratio = np.average(wn) / np.sqrt(np.average(Pxx))
117+
self.assertAlmostEqual(ratio, 1, delta=TOL_BIAS)

0 commit comments

Comments
 (0)